Source code for pykt.models.dkt_forget

import torch
from torch.nn import Module, Embedding, LSTM, Linear, Dropout

device = "cpu" if not torch.cuda.is_available() else "cuda"

class DKTForget(Module):
    def __init__(self, num_c, num_rgap, num_sgap, num_pcount, emb_size, dropout=0.1, emb_type='qid', emb_path=""):
        super().__init__()
        self.model_name = "dkt_forget"
        self.num_c = num_c
        self.emb_size = emb_size
        self.hidden_size = emb_size
        self.emb_type = emb_type

        if emb_type.startswith("qid"):
            # one embedding row per (concept, response) pair, hence num_c * 2 rows
            self.interaction_emb = Embedding(self.num_c * 2, self.emb_size)

        # integrates the three forgetting features (repeated gap, sequence gap, past count)
        self.c_integration = CIntegration(num_rgap, num_sgap, num_pcount, emb_size)
        ntotal = num_rgap + num_sgap + num_pcount

        self.lstm_layer = LSTM(self.emb_size + ntotal, self.hidden_size, batch_first=True)
        self.dropout_layer = Dropout(dropout)
        self.out_layer = Linear(self.hidden_size + ntotal, self.num_c)
    def forward(self, q, r, dgaps):
        emb_type = self.emb_type
        if emb_type == "qid":
            # encode each (concept, response) pair as a single interaction id
            x = q + self.num_c * r
            xemb = self.interaction_emb(x)

        # fuse the current-step forgetting features into the interaction embeddings
        theta_in = self.c_integration(xemb, dgaps["rgaps"].long(), dgaps["sgaps"].long(), dgaps["pcounts"].long())

        h, _ = self.lstm_layer(theta_in)

        # fuse the shifted (next-step) forgetting features into the hidden states
        theta_out = self.c_integration(h, dgaps["shft_rgaps"].long(), dgaps["shft_sgaps"].long(), dgaps["shft_pcounts"].long())
        theta_out = self.dropout_layer(theta_out)
        y = self.out_layer(theta_out)
        y = torch.sigmoid(y)

        return y
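A minimal usage sketch, not part of the module itself: the constructor arguments, the dgaps keys, and the expected tensor shapes are taken from the code above, while the concrete sizes (num_c=50, batch of 2, sequence length 5, emb_size=64) are made-up illustrative values and the random inputs stand in for real preprocessed data.

    import torch
    from pykt.models.dkt_forget import DKTForget

    # illustrative sizes only; real values come from the dataset preprocessing
    num_c, num_rgap, num_sgap, num_pcount = 50, 10, 10, 10
    bz, seqlen = 2, 5

    dev = "cuda" if torch.cuda.is_available() else "cpu"
    model = DKTForget(num_c, num_rgap, num_sgap, num_pcount, emb_size=64).to(dev)

    q = torch.randint(0, num_c, (bz, seqlen)).to(dev)   # concept ids
    r = torch.randint(0, 2, (bz, seqlen)).to(dev)       # binary responses

    # forgetting-feature indices; kept on the CPU because CIntegration indexes its
    # CPU identity matrices with them and moves the one-hot results to `device` itself
    dgaps = {k: torch.randint(0, n, (bz, seqlen))
             for k, n in [("rgaps", num_rgap), ("sgaps", num_sgap), ("pcounts", num_pcount),
                          ("shft_rgaps", num_rgap), ("shft_sgaps", num_sgap), ("shft_pcounts", num_pcount)]}

    y = model(q, r, dgaps)
    print(y.shape)  # (bz, seqlen, num_c): per-step predicted probability for every concept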
class CIntegration(Module):
    def __init__(self, num_rgap, num_sgap, num_pcount, emb_dim) -> None:
        super().__init__()
        # identity matrices used as one-hot lookup tables for the three forgetting features
        self.rgap_eye = torch.eye(num_rgap)
        self.sgap_eye = torch.eye(num_sgap)
        self.pcount_eye = torch.eye(num_pcount)

        ntotal = num_rgap + num_sgap + num_pcount
        self.cemb = Linear(ntotal, emb_dim, bias=False)
        # print(f"total: {ntotal}, self.cemb.weight: {self.cemb.weight.shape}")
    def forward(self, vt, rgap, sgap, pcount):
        # one-hot encode each feature and move it to the model device
        rgap, sgap, pcount = self.rgap_eye[rgap].to(device), self.sgap_eye[sgap].to(device), self.pcount_eye[pcount].to(device)
        # print(f"vt: {vt.shape}, rgap: {rgap.shape}, sgap: {sgap.shape}, pcount: {pcount.shape}")

        ct = torch.cat((rgap, sgap, pcount), -1)  # bz * seq_len * num_fea
        # print(f"ct: {ct.shape}, self.cemb.weight: {self.cemb.weight.shape}")
        # project the concatenated features, then gate the input embedding element-wise
        Cct = self.cemb(ct)  # bz * seq_len * emb
        # print(f"ct: {ct.shape}, Cct: {Cct.shape}")
        theta = torch.mul(vt, Cct)
        theta = torch.cat((theta, ct), -1)
        return theta
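CIntegration can also be exercised on its own: it one-hot encodes the three features, projects them to the embedding dimension, gates the input embedding element-wise, and appends the raw one-hot features, so the output width is emb_dim + num_rgap + num_sgap + num_pcount. A minimal sketch with made-up sizes (the specific numbers below are not from the source):

    import torch
    from pykt.models.dkt_forget import CIntegration

    num_rgap, num_sgap, num_pcount, emb_dim = 8, 8, 8, 16   # illustrative sizes only
    dev = "cuda" if torch.cuda.is_available() else "cpu"

    ci = CIntegration(num_rgap, num_sgap, num_pcount, emb_dim).to(dev)

    vt = torch.rand(2, 5, emb_dim).to(dev)        # bz * seq_len * emb, e.g. interaction embeddings
    rgap = torch.randint(0, num_rgap, (2, 5))     # feature indices stay on the CPU (see forward above)
    sgap = torch.randint(0, num_sgap, (2, 5))
    pcount = torch.randint(0, num_pcount, (2, 5))

    theta = ci(vt, rgap, sgap, pcount)
    print(theta.shape)  # (2, 5, emb_dim + num_rgap + num_sgap + num_pcount) = (2, 5, 40)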