import torch
from torch.nn import Module, Embedding, LSTM, Linear, Dropout
class DKTPlus(Module):
    """LSTM-based knowledge-tracing model identified as ``"dkt+"``.

    A (concept, response) pair is folded into a single interaction index,
    embedded, passed through a single-layer LSTM, dropout and a linear head,
    and finally squashed with a sigmoid so that every concept gets a value
    in ``[0, 1]`` at every time step.

    Args:
        num_c: Number of distinct concepts. The interaction embedding holds
            ``2 * num_c`` rows and the output head has ``num_c`` units.
        emb_size: Embedding width; also used as the LSTM hidden size.
        lambda_r: Stored on the instance only; not used inside this class.
            Presumably a loss-regularisation weight consumed by training
            code elsewhere -- TODO confirm against the caller.
        lambda_w1: Stored on the instance only (see ``lambda_r``).
        lambda_w2: Stored on the instance only (see ``lambda_r``).
        dropout: Drop probability applied to the LSTM outputs.
        emb_type: Embedding scheme. Only values starting with ``"qid"``
            create the interaction embedding and are usable in ``forward``.
        emb_path: Accepted for interface compatibility; unused here.
        pretrain_dim: Accepted for interface compatibility; unused here.
    """

    def __init__(self, num_c, emb_size, lambda_r, lambda_w1, lambda_w2,
                 dropout=0.1, emb_type="qid", emb_path="", pretrain_dim=768):
        super().__init__()
        self.model_name = "dkt+"
        self.num_c = num_c
        self.emb_size = emb_size
        # The recurrent state deliberately shares the embedding width.
        self.hidden_size = emb_size
        self.lambda_r = lambda_r
        self.lambda_w1 = lambda_w1
        self.lambda_w2 = lambda_w2
        self.emb_type = emb_type

        if emb_type.startswith("qid"):
            # One row per (concept, response) combination: index
            # q + num_c * r must land in [0, 2 * num_c).
            self.interaction_emb = Embedding(self.num_c * 2, self.emb_size)

        self.lstm_layer = LSTM(self.emb_size, self.hidden_size, batch_first=True)
        self.dropout_layer = Dropout(dropout)
        self.out_layer = Linear(self.hidden_size, self.num_c)

    def forward(self, q, r):
        """Return per-step, per-concept sigmoid outputs.

        Args:
            q: Integer tensor of concept ids. Assumed to be shaped
                ``(batch, seq_len)`` because the LSTM is ``batch_first``
                -- TODO confirm against the data loader.
            r: Integer tensor of responses, same shape as ``q``. Values must
                keep ``q + num_c * r`` inside ``[0, 2 * num_c)`` (i.e. 0/1
                responses for concept ids in ``[0, num_c)``).

        Returns:
            Tensor with the shape of ``q`` plus a trailing ``num_c`` axis,
            every entry in ``[0, 1]``.

        Raises:
            ValueError: If ``self.emb_type`` does not start with ``"qid"``,
                in which case no interaction embedding was built.
        """
        emb_type = self.emb_type
        # Mirror the check used in __init__ so that every configuration that
        # built ``interaction_emb`` can actually run, and fail loudly for the
        # rest instead of hitting an unbound local variable.
        if not emb_type.startswith("qid"):
            raise ValueError(
                f"Unsupported emb_type {emb_type!r}: DKTPlus only supports "
                "embedding types starting with 'qid'."
            )

        # Fold concept id and response into one interaction index.
        x = q + self.num_c * r
        xemb = self.interaction_emb(x)

        h, _ = self.lstm_layer(xemb)
        h = self.dropout_layer(h)
        y = self.out_layer(h)
        return torch.sigmoid(y)