#!/usr/bin/env python
# coding=utf-8
import torch
from torch import nn
# from models.utils import RobertaEncode  # only needed when emb_type contains "roberta"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class LPKT(nn.Module):
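    """LPKT (Learning Process-consistent Knowledge Tracing) model.

    For each interaction the model computes a learning gain from the previous
    interaction, the current interaction and (optionally) the answer/interval
    time, applies a per-skill forget gate to the previous knowledge state, and
    uses the updated state to predict the response to the next exercise.
    """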
    def __init__(self, n_at, n_it, n_exercise, n_question, d_a, d_e, d_k, gamma=0.03, dropout=0.2, q_matrix="", emb_type="qid", emb_path="", pretrain_dim=768, use_time=True):
        super(LPKT, self).__init__()
        self.model_name = "lpkt"
        self.d_k = d_k
        self.d_a = d_a
        self.d_e = d_e
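        # give zero entries of the Q-matrix a small weight gamma so that every
        # (exercise, skill) pair receives some credit; note that this mutates
        # the tensor passed in by the caller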
        q_matrix[q_matrix==0] = gamma
        self.q_matrix = q_matrix
        self.n_question = n_question
        self.emb_type = emb_type
        self.use_time = use_time
        self.at_embed = nn.Embedding(n_at + 10, d_k)
        torch.nn.init.xavier_uniform_(self.at_embed.weight)
        self.it_embed = nn.Embedding(n_it + 10, d_k)
        torch.nn.init.xavier_uniform_(self.it_embed.weight)
        self.e_embed = nn.Embedding(n_exercise + 10, d_e)
        torch.nn.init.xavier_uniform_(self.e_embed.weight)
if emb_type.startswith("qidcatr"):
self.interaction_emb = nn.Embedding(self.num_exercise * 2, self.d_k)
self.catrlinear = nn.Linear(self.d_k * 2, self.d_k)
self.pooling = nn.MaxPool1d(2, stride=2)
self.avg_pooling = nn.AvgPool1d(2, stride=2)
if emb_type.startswith("qidrobertacatr"):
self.catrlinear = nn.Linear(self.d_k * 3, self.d_k)
self.pooling = nn.MaxPool1d(3, stride=3)
self.avg_pooling = nn.AvgPool1d(3, stride=3)
if emb_type.find("roberta") != -1:
self.roberta_emb = RobertaEncode(self.d_k, emb_path, pretrain_dim)
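        # linear_0 / linear_1 build the interaction embedding without / with
        # answer time; linear_2 + linear_3 (time-aware) and linear_6 + linear_7
        # (time-free) produce the learning gain and its gate; linear_4 / linear_8
        # are the time-aware / time-free forget gates; linear_5 is the prediction layer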
        self.linear_0 = nn.Linear(d_a + d_e, d_k)
        torch.nn.init.xavier_uniform_(self.linear_0.weight)
        self.linear_1 = nn.Linear(d_a + d_e + d_k, d_k)
        torch.nn.init.xavier_uniform_(self.linear_1.weight)
        self.linear_2 = nn.Linear(4 * d_k, d_k)
        torch.nn.init.xavier_uniform_(self.linear_2.weight)
        self.linear_3 = nn.Linear(4 * d_k, d_k)
        torch.nn.init.xavier_uniform_(self.linear_3.weight)
        self.linear_4 = nn.Linear(3 * d_k, d_k)
        torch.nn.init.xavier_uniform_(self.linear_4.weight)
        self.linear_5 = nn.Linear(d_e + d_k, d_k)
        torch.nn.init.xavier_uniform_(self.linear_5.weight)
        self.linear_6 = nn.Linear(3 * d_k, d_k)
        torch.nn.init.xavier_uniform_(self.linear_6.weight)
        self.linear_7 = nn.Linear(3 * d_k, d_k)
        torch.nn.init.xavier_uniform_(self.linear_7.weight)
        self.linear_8 = nn.Linear(2 * d_k, d_k)
        torch.nn.init.xavier_uniform_(self.linear_8.weight)
        self.tanh = nn.Tanh()
        self.sig = nn.Sigmoid()
        self.dropout = nn.Dropout(dropout)
    def forward(self, e_data, a_data, it_data=None, at_data=None, qtest=False):
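        """Run LPKT over a batch of interaction sequences.

        Args:
            e_data: (batch_size, seq_len) exercise ids.
            a_data: (batch_size, seq_len) answers (correctness), float-valued.
            it_data: (batch_size, seq_len) interval-time ids, used when use_time is True.
            at_data: (batch_size, seq_len) answer-time ids, optional.
            qtest: if True, also return the per-step hidden states and e_data.

        Returns:
            pred of shape (batch_size, seq_len); position 0 is never predicted.
        """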
        emb_type = self.emb_type
        batch_size, seq_len = e_data.size(0), e_data.size(1)
        e_embed_data = self.e_embed(e_data)
        if self.use_time:
            if at_data is not None:
                at_embed_data = self.at_embed(at_data)
            it_embed_data = self.it_embed(it_data)
        a_data = a_data.view(-1, 1).repeat(1, self.d_a).view(batch_size, -1, self.d_a)
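        # h_pre: previous per-skill knowledge state, shape (batch_size, n_question + 1, d_k)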
        h_pre = nn.init.xavier_uniform_(torch.zeros(self.n_question + 1, self.d_k)).repeat(batch_size, 1, 1).to(device)
        h_tilde_pre = None
if emb_type == "qid":
if self.use_time and at_data != None:
all_learning = self.linear_1(torch.cat((e_embed_data, at_embed_data, a_data), 2))
else:
all_learning = self.linear_0(torch.cat((e_embed_data, a_data), 2))
learning_pre = torch.zeros(batch_size, self.d_k).to(device)
pred = torch.zeros(batch_size, seq_len).to(device)
hidden_state = torch.zeros(batch_size, seq_len, self.d_k).to(device)
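        # walk over the sequence: the interaction at step t updates the knowledge
        # state, which is then used to predict the response at step t + 1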
        for t in range(0, seq_len - 1):
            e = e_data[:, t]
            # q_e: (bs, 1, n_skill)
            q_e = self.q_matrix[e].view(batch_size, 1, -1).to(device)
            if self.use_time:
                it = it_embed_data[:, t]
                # Learning Module
                if h_tilde_pre is None:
                    h_tilde_pre = q_e.bmm(h_pre).view(batch_size, self.d_k)
                learning = all_learning[:, t]
                learning_gain = self.linear_2(torch.cat((learning_pre, it, learning, h_tilde_pre), 1))
                learning_gain = self.tanh(learning_gain)
                gamma_l = self.linear_3(torch.cat((learning_pre, it, learning, h_tilde_pre), 1))
            else:
                # Learning Module
                if h_tilde_pre is None:
                    h_tilde_pre = q_e.bmm(h_pre).view(batch_size, self.d_k)
                learning = all_learning[:, t]
                learning_gain = self.linear_6(torch.cat((learning_pre, learning, h_tilde_pre), 1))
                learning_gain = self.tanh(learning_gain)
                gamma_l = self.linear_7(torch.cat((learning_pre, learning, h_tilde_pre), 1))
            gamma_l = self.sig(gamma_l)
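            # map the tanh-activated gain from [-1, 1] to [0, 1], gate it, and
            # spread it over the skills of the current exercise via the Q-matrix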
            LG = gamma_l * ((learning_gain + 1) / 2)
            LG_tilde = self.dropout(q_e.transpose(1, 2).bmm(LG.view(batch_size, 1, -1)))
            # Forgetting Module
            # h_pre: (bs, n_skill, d_k)
            # LG: (bs, d_k)
            # it: (bs, d_k)
            n_skill = LG_tilde.size(1)
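            # forget gate: one value per skill, conditioned on the previous state,
            # the current learning gain and (when use_time) the interval time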
            if self.use_time:
                gamma_f = self.sig(self.linear_4(torch.cat((
                    h_pre,
                    LG.repeat(1, n_skill).view(batch_size, -1, self.d_k),
                    it.repeat(1, n_skill).view(batch_size, -1, self.d_k)
                ), 2)))
            else:
                gamma_f = self.sig(self.linear_8(torch.cat((
                    h_pre,
                    LG.repeat(1, n_skill).view(batch_size, -1, self.d_k)
                ), 2)))
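            # new knowledge state: forget-gated previous state plus the distributed learning gain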
            h = LG_tilde + gamma_f * h_pre
            # Predicting Module
            h_tilde = self.q_matrix[e_data[:, t + 1]].view(batch_size, 1, -1).bmm(h).view(batch_size, self.d_k)
            # print(f"h_tilde: {h_tilde.shape}")
            y = self.sig(self.linear_5(torch.cat((e_embed_data[:, t + 1], h_tilde), 1))).sum(1) / self.d_k
            pred[:, t + 1] = y
            hidden_state[:, t + 1, :] = h_tilde
            # prepare for next prediction
            learning_pre = learning
            h_pre = h
            h_tilde_pre = h_tilde
        if not qtest:
            return pred
        else:
            return pred, hidden_state, e_data
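

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): the bin counts, dimensions and the
# random q_matrix below are made-up values, not the pykt training pipeline.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    n_at, n_it = 100, 100              # number of answer-time / interval-time bins
    n_exercise, n_question = 50, 20
    d_a, d_e, d_k = 50, 64, 64

    # Q-matrix: (n_exercise + 10) x (n_question + 1) exercise-to-skill mapping
    q_matrix = torch.randint(0, 2, (n_exercise + 10, n_question + 1)).float().to(device)

    model = LPKT(n_at, n_it, n_exercise, n_question, d_a, d_e, d_k,
                 q_matrix=q_matrix).to(device)

    batch_size, seq_len = 4, 10
    e_data = torch.randint(1, n_exercise, (batch_size, seq_len)).to(device)
    a_data = torch.randint(0, 2, (batch_size, seq_len)).float().to(device)
    it_data = torch.randint(0, n_it, (batch_size, seq_len)).to(device)
    at_data = torch.randint(0, n_at, (batch_size, seq_len)).to(device)

    pred = model(e_data, a_data, it_data=it_data, at_data=at_data)
    print(pred.shape)  # torch.Size([4, 10]); pred[:, 0] is never filled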