import torch
from torch.nn import Module, Parameter, Embedding, Linear, Dropout
from torch.nn.init import kaiming_normal_
class DKVMN(Module):
    """Dynamic Key-Value Memory Networks (DKVMN) for knowledge tracing."""
    def __init__(self, num_c, dim_s, size_m, dropout=0.2, emb_type='qid', emb_path="", pretrain_dim=768):
        super().__init__()
        self.model_name = "dkvmn"
        self.num_c = num_c      # number of concepts (skills)
        self.dim_s = dim_s      # embedding / memory state dimension
        self.size_m = size_m    # number of memory slots
        self.emb_type = emb_type

        if emb_type.startswith("qid"):
            # Question key embedding, the static key memory Mk, and the
            # learned initial value memory Mv0.
            self.k_emb_layer = Embedding(self.num_c, self.dim_s)
            self.Mk = Parameter(torch.Tensor(self.size_m, self.dim_s))
            self.Mv0 = Parameter(torch.Tensor(self.size_m, self.dim_s))
            kaiming_normal_(self.Mk)
            kaiming_normal_(self.Mv0)

        # Value embedding over (concept, response) pairs: x = q + num_c * r.
        self.v_emb_layer = Embedding(self.num_c * 2, self.dim_s)

        # Read head: summary vector f and prediction p.
        self.f_layer = Linear(self.dim_s * 2, self.dim_s)
        self.dropout_layer = Dropout(dropout)
        self.p_layer = Linear(self.dim_s, 1)

        # Write head: erase (e) and add (a) transformations.
        self.e_layer = Linear(self.dim_s, self.dim_s)
        self.a_layer = Linear(self.dim_s, self.dim_s)
    def forward(self, q, r, qtest=False):
        emb_type = self.emb_type
        batch_size = q.shape[0]
        if emb_type == "qid":
            # Fold the binary response into the concept id: x = q + num_c * r.
            x = q + self.num_c * r
            k = self.k_emb_layer(q)   # question keys, (batch, seq, dim_s)
            v = self.v_emb_layer(x)   # (concept, response) values, (batch, seq, dim_s)

        # Start every sequence from the learned initial value memory Mv0.
        Mvt = self.Mv0.unsqueeze(0).repeat(batch_size, 1, 1)
        Mv = [Mvt]

        # Correlation weights: attention of each question key over the key memory slots.
        w = torch.softmax(torch.matmul(k, self.Mk.T), dim=-1)

        # Write Process: at each time step, erase from then add into the value memory.
        e = torch.sigmoid(self.e_layer(v))  # erase gates
        a = torch.tanh(self.a_layer(v))     # add vectors
        for et, at, wt in zip(
            e.permute(1, 0, 2), a.permute(1, 0, 2), w.permute(1, 0, 2)
        ):
            Mvt = Mvt * (1 - (wt.unsqueeze(-1) * et.unsqueeze(1))) + \
                (wt.unsqueeze(-1) * at.unsqueeze(1))
            Mv.append(Mvt)
        Mv = torch.stack(Mv, dim=1)  # (batch, seq + 1, size_m, dim_s)

        # Read Process: attend over the memory state *before* each write (Mv[:, :-1]),
        # concatenate with the question key, and predict a correct-response probability.
        f = torch.tanh(
            self.f_layer(
                torch.cat(
                    [
                        (w.unsqueeze(-1) * Mv[:, :-1]).sum(-2),
                        k,
                    ],
                    dim=-1,
                )
            )
        )
        p = self.p_layer(self.dropout_layer(f))
        p = torch.sigmoid(p)
        p = p.squeeze(-1)  # (batch, seq)
        if not qtest:
            return p
        else:
            return p, f
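

# A minimal smoke-test sketch, not part of the model itself: the hyperparameters
# (50 concepts, dim_s=64, 20 memory slots) and the random sequences are illustrative
# assumptions, chosen only to show the expected input and output shapes.
if __name__ == "__main__":
    torch.manual_seed(0)
    model = DKVMN(num_c=50, dim_s=64, size_m=20)
    q = torch.randint(0, 50, (8, 30))  # (batch, seq) concept ids
    r = torch.randint(0, 2, (8, 30))   # (batch, seq) binary responses
    p = model(q, r)
    print(p.shape)  # torch.Size([8, 30]); per-step probability of a correct answer
    p, f = model(q, r, qtest=True)
    print(f.shape)  # torch.Size([8, 30, 64]); summary vectors also returned when qtest=True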