Source code for pykt.models.sakt

import torch

from torch.nn import Module, Embedding, Linear, MultiheadAttention, LayerNorm, Dropout
from .utils import transformer_FFN, pos_encode, ut_mask, get_clones
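
# Assumed behavior of the .utils helpers used below (a sketch, not guaranteed
# by this file): transformer_FFN is a position-wise feed-forward block,
# pos_encode(seq_len) builds the index tensor fed to the position embedding,
# ut_mask(seq_len) builds an upper-triangular boolean mask hiding future time
# steps, and get_clones(module, n) returns a ModuleList of n deep copies.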

class SAKT(Module):
    def __init__(self, num_c, seq_len, emb_size, num_attn_heads, dropout, num_en=2,
                 emb_type="qid", emb_path="", pretrain_dim=768):
        super().__init__()
        self.model_name = "sakt"
        self.emb_type = emb_type
        self.num_c = num_c
        self.seq_len = seq_len
        self.emb_size = emb_size
        self.num_attn_heads = num_attn_heads
        self.dropout = dropout
        self.num_en = num_en

        if emb_type.startswith("qid"):
            # one row per (question, response) pair, hence num_c * 2
            self.interaction_emb = Embedding(num_c * 2, emb_size)
            self.exercise_emb = Embedding(num_c, emb_size)
            # self.P = Parameter(torch.Tensor(self.seq_len, self.emb_size))
            self.position_emb = Embedding(seq_len, emb_size)

        self.blocks = get_clones(Blocks(emb_size, num_attn_heads, dropout), self.num_en)
        self.dropout_layer = Dropout(dropout)
        self.pred = Linear(self.emb_size, 1)
    def base_emb(self, q, r, qry):
        # fold the response into the question id: interaction id in [0, 2 * num_c)
        x = q + self.num_c * r
        qshftemb, xemb = self.exercise_emb(qry), self.interaction_emb(x)
        posemb = self.position_emb(pos_encode(xemb.shape[1]))
        xemb = xemb + posemb
        return qshftemb, xemb
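    # Worked example of the encoding above (illustrative values, not from the
    # original source): with num_c = 100, question q = 7 answered incorrectly
    # (r = 0) yields interaction id 7, while the same question answered
    # correctly (r = 1) yields 107, so every (question, response) pair owns a
    # distinct row of interaction_emb.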
    def forward(self, q, r, qry, qtest=False):
        emb_type = self.emb_type
        qemb, qshftemb, xemb = None, None, None
        if emb_type == "qid":
            qshftemb, xemb = self.base_emb(q, r, qry)
        # print(f"qemb: {qemb.shape}, xemb: {xemb.shape}, qshftemb: {qshftemb.shape}")
        for i in range(self.num_en):
            xemb = self.blocks[i](qshftemb, xemb, xemb)

        p = torch.sigmoid(self.pred(self.dropout_layer(xemb))).squeeze(-1)
        if not qtest:
            return p
        else:
            return p, xemb
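# In the forward pass above, the next exercises (qshftemb) act as the attention
# queries while the past interactions (xemb) supply both keys and values, which
# is the core self-attentive design of SAKT (Pandey & Karypis, 2019).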
class Blocks(Module):
    def __init__(self, emb_size, num_attn_heads, dropout) -> None:
        super().__init__()
        self.attn = MultiheadAttention(emb_size, num_attn_heads, dropout=dropout)
        self.attn_dropout = Dropout(dropout)
        self.attn_layer_norm = LayerNorm(emb_size)

        self.FFN = transformer_FFN(emb_size, dropout)
        self.FFN_dropout = Dropout(dropout)
        self.FFN_layer_norm = LayerNorm(emb_size)
    def forward(self, q=None, k=None, v=None):
        # MultiheadAttention defaults to (seq_len, batch, emb_size) inputs
        q, k, v = q.permute(1, 0, 2), k.permute(1, 0, 2), v.permute(1, 0, 2)
        # transformer default ordering: attn -> dropout -> residual -> norm
        causal_mask = ut_mask(seq_len=k.shape[0])
        attn_emb, _ = self.attn(q, k, v, attn_mask=causal_mask)
        attn_emb = self.attn_dropout(attn_emb)
        attn_emb, q = attn_emb.permute(1, 0, 2), q.permute(1, 0, 2)
        attn_emb = self.attn_layer_norm(q + attn_emb)

        emb = self.FFN(attn_emb)
        emb = self.FFN_dropout(emb)
        emb = self.FFN_layer_norm(attn_emb + emb)
        return emb
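
# A minimal smoke test (an illustrative sketch, not part of pykt itself).
# It assumes q, r, qry are (batch, seq_len) LongTensors, with qry the question
# sequence shifted one step ahead, and that the .utils helpers keep tensors on
# the default device. Because of the relative imports above, run it as a
# module, e.g. `python -m pykt.models.sakt`.
if __name__ == "__main__":
    batch_size, length, num_c = 2, 50, 100
    model = SAKT(num_c=num_c, seq_len=200, emb_size=64, num_attn_heads=4, dropout=0.2)
    q = torch.randint(0, num_c, (batch_size, length))    # questions at steps 0..L-1
    r = torch.randint(0, 2, (batch_size, length))        # binary responses at steps 0..L-1
    qry = torch.randint(0, num_c, (batch_size, length))  # questions at steps 1..L
    p = model(q, r, qry)
    print(p.shape)  # expected: torch.Size([2, 50]), one probability per step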