Source code for pykt.models.train_model

import os, sys
import torch
import torch.nn as nn
from torch.nn.functional import one_hot, binary_cross_entropy
import numpy as np
from .evaluate_model import evaluate
from torch.autograd import Variable, grad
from .atkt import _l2_normalize_adv
from ..utils.utils import debug_print

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def cal_loss(model, ys, r, rshft, sm, preloss=[]):
    """Compute the batch training loss, dispatching on model.model_name."""
    model_name = model.model_name

    if model_name in ["dkt", "dkt_forget", "dkvmn", "kqn", "sakt", "saint", "atkt", "atktfix", "gkt", "skvmn", "hawkes"]:
        y = torch.masked_select(ys[0], sm)
        t = torch.masked_select(rshft, sm)
        loss = binary_cross_entropy(y.double(), t.double())
    elif model_name == "dkt+":
        y_curr = torch.masked_select(ys[1], sm)
        y_next = torch.masked_select(ys[0], sm)
        r_curr = torch.masked_select(r, sm)
        r_next = torch.masked_select(rshft, sm)
        loss = binary_cross_entropy(y_next.double(), r_next.double())
        # if the answer for a concept was wrong at t-1, the current answer for it should be wrong too
        loss_r = binary_cross_entropy(y_curr.double(), r_curr.double())
        loss_w1 = torch.masked_select(torch.norm(ys[2][:, 1:] - ys[2][:, :-1], p=1, dim=-1), sm[:, 1:])
        loss_w1 = loss_w1.mean() / model.num_c
        loss_w2 = torch.masked_select(torch.norm(ys[2][:, 1:] - ys[2][:, :-1], p=2, dim=-1) ** 2, sm[:, 1:])
        loss_w2 = loss_w2.mean() / model.num_c
        loss = loss + model.lambda_r * loss_r + model.lambda_w1 * loss_w1 + model.lambda_w2 * loss_w2
    elif model_name in ["akt", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx"]:
        y = torch.masked_select(ys[0], sm)
        t = torch.masked_select(rshft, sm)
        loss = binary_cross_entropy(y.double(), t.double()) + preloss[0]
    elif model_name == "lpkt":
        y = torch.masked_select(ys[0], sm)
        t = torch.masked_select(rshft, sm)
        criterion = nn.BCELoss(reduction='none')
        loss = criterion(y, t).sum()
    return loss
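
# A minimal sketch (not part of the original module) of how cal_loss is fed for
# the plain-BCE branch such as "dkt": ys[0] holds next-step prediction
# probabilities, rshft the shifted ground-truth responses, and sm a boolean mask
# selecting the valid (non-padding) positions. The SimpleNamespace below is a
# hypothetical stand-in exposing only the attribute this branch reads.
def _demo_cal_loss():
    from types import SimpleNamespace
    fake_model = SimpleNamespace(model_name="dkt")
    ys = [torch.sigmoid(torch.randn(2, 5))]        # predicted probabilities, (batch, seq)
    rshft = torch.randint(0, 2, (2, 5)).double()   # shifted 0/1 responses
    r = rshft.clone()                              # unused by the "dkt" branch
    sm = torch.ones(2, 5, dtype=torch.bool)        # keep every position
    return cal_loss(fake_model, ys, r, rshft, sm)  # scalar masked BCE loss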
def model_forward(model, data):
    """Run one forward pass for a batch and return the training loss."""
    model_name = model.model_name
    # if model_name in ["dkt_forget", "lpkt"]:
    #     q, c, r, qshft, cshft, rshft, m, sm, d, dshft = data
    if model_name in ["dkt_forget"]:
        dcur, dgaps = data
    else:
        dcur = data
    q, c, r, t = dcur["qseqs"], dcur["cseqs"], dcur["rseqs"], dcur["tseqs"]
    qshft, cshft, rshft, tshft = dcur["shft_qseqs"], dcur["shft_cseqs"], dcur["shft_rseqs"], dcur["shft_tseqs"]
    m, sm = dcur["masks"], dcur["smasks"]

    ys, preloss = [], []
    cq = torch.cat((q[:, 0:1], qshft), dim=1)
    cc = torch.cat((c[:, 0:1], cshft), dim=1)
    cr = torch.cat((r[:, 0:1], rshft), dim=1)
    if model_name in ["hawkes"]:
        ct = torch.cat((t[:, 0:1], tshft), dim=1)
    if model_name in ["lpkt"]:
        # cat = torch.cat((d["at_seqs"][:, 0:1], dshft["at_seqs"]), dim=1)
        cit = torch.cat((dcur["itseqs"][:, 0:1], dcur["shft_itseqs"]), dim=1)

    if model_name in ["dkt"]:
        y = model(c.long(), r.long())
        y = (y * one_hot(cshft.long(), model.num_c)).sum(-1)
        ys.append(y)  # first: yshft
    elif model_name == "dkt+":
        y = model(c.long(), r.long())
        y_next = (y * one_hot(cshft.long(), model.num_c)).sum(-1)
        y_curr = (y * one_hot(c.long(), model.num_c)).sum(-1)
        ys = [y_next, y_curr, y]
    elif model_name in ["dkt_forget"]:
        y = model(c.long(), r.long(), dgaps)
        y = (y * one_hot(cshft.long(), model.num_c)).sum(-1)
        ys.append(y)
    elif model_name in ["dkvmn", "skvmn"]:
        y = model(cc.long(), cr.long())
        ys.append(y[:, 1:])
    elif model_name in ["kqn", "sakt"]:
        y = model(c.long(), r.long(), cshft.long())
        ys.append(y)
    elif model_name in ["saint"]:
        y = model(cq.long(), cc.long(), r.long())
        ys.append(y[:, 1:])
    elif model_name in ["akt", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx"]:
        y, reg_loss = model(cc.long(), cr.long(), cq.long())
        ys.append(y[:, 1:])
        preloss.append(reg_loss)
    elif model_name in ["atkt", "atktfix"]:
        y, features = model(c.long(), r.long())
        y = (y * one_hot(cshft.long(), model.num_c)).sum(-1)
        loss = cal_loss(model, [y], r, rshft, sm)
        # adversarial training: perturb the features along the loss gradient
        features_grad = grad(loss, features, retain_graph=True)
        p_adv = torch.FloatTensor(model.epsilon * _l2_normalize_adv(features_grad[0].data))
        p_adv = Variable(p_adv).to(device)
        pred_res, _ = model(c.long(), r.long(), p_adv)
        # second loss
        pred_res = (pred_res * one_hot(cshft.long(), model.num_c)).sum(-1)
        adv_loss = cal_loss(model, [pred_res], r, rshft, sm)
        loss = loss + model.beta * adv_loss
    elif model_name == "gkt":
        y = model(cc.long(), cr.long())
        ys.append(y)
        # cal loss
    elif model_name == "lpkt":
        # y = model(cq.long(), cr.long(), cat, cit.long())
        y = model(cq.long(), cr.long(), cit.long())
        ys.append(y[:, 1:])
    elif model_name == "hawkes":
        # ct = torch.cat((dcur["tseqs"][:, 0:1], dcur["shft_tseqs"]), dim=1)
        # csm = torch.cat((dcur["smasks"][:, 0:1], dcur["smasks"]), dim=1)
        # y = model(cc[0:1, 0:5].long(), cq[0:1, 0:5].long(), ct[0:1, 0:5].long(), cr[0:1, 0:5].long(), csm[0:1, 0:5].long())
        y = model(cc.long(), cq.long(), ct.long(), cr.long())  # , csm.long()
        ys.append(y[:, 1:])
    elif model_name == "iekt":
        y, loss = model.train_one_step(data)

    if model_name not in ["atkt", "atktfix", "iekt"]:
        loss = cal_loss(model, ys, r, rshft, sm, preloss)
    return loss
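
# A minimal sketch (illustrative, not from the original source) of the batch
# layout model_forward expects for most models: dcur maps each key read above to
# a (batch_size, seq_len) tensor, where the "shft_*" sequences are the originals
# shifted left by one step and "masks"/"smasks" flag the valid positions. The
# shapes and num_c below are arbitrary example values.
def _demo_batch(batch_size=2, seq_len=5, num_c=10):
    dcur = {
        "qseqs": torch.randint(0, num_c, (batch_size, seq_len)),
        "cseqs": torch.randint(0, num_c, (batch_size, seq_len)),
        "rseqs": torch.randint(0, 2, (batch_size, seq_len)),
        "tseqs": torch.zeros(batch_size, seq_len, dtype=torch.long),
        "shft_qseqs": torch.randint(0, num_c, (batch_size, seq_len)),
        "shft_cseqs": torch.randint(0, num_c, (batch_size, seq_len)),
        "shft_rseqs": torch.randint(0, 2, (batch_size, seq_len)).double(),
        "shft_tseqs": torch.zeros(batch_size, seq_len, dtype=torch.long),
        "masks": torch.ones(batch_size, seq_len, dtype=torch.bool),
        "smasks": torch.ones(batch_size, seq_len, dtype=torch.bool),
    }
    return dcur  # e.g. loss = model_forward(model, _demo_batch()) for a "dkt" model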
def train_model(model, train_loader, valid_loader, num_epochs, opt, ckpt_path, test_loader=None, test_window_loader=None, save_model=False):
    """Train the model, checkpointing on best validation AUC; early-stop after 10 epochs without improvement."""
    max_auc, best_epoch = 0, -1
    train_step = 0
    if model.model_name == 'lpkt':
        scheduler = torch.optim.lr_scheduler.StepLR(opt, 10, gamma=0.5)
    for i in range(1, num_epochs + 1):
        loss_mean = []
        for data in train_loader:
            train_step += 1
            if model.model_name == 'iekt':
                model.model.train()
            else:
                model.train()
            loss = model_forward(model, data)
            opt.zero_grad()
            loss.backward()  # compute gradients
            opt.step()  # update the model's parameters
            loss_mean.append(loss.detach().cpu().numpy())
            if model.model_name == "gkt" and train_step % 10 == 0:
                text = f"Total train step is {train_step}, the loss is {loss.item():.5}"
                debug_print(text=text, fuc_name="train_model")
        if model.model_name == 'lpkt':
            scheduler.step()  # step once per epoch
        loss_mean = np.mean(loss_mean)
        auc, acc = evaluate(model, valid_loader, model.model_name)
        ### atkt results differ; caused by the code below
        ### auc, acc = round(auc, 4), round(acc, 4)

        if auc > max_auc:
            if save_model:
                torch.save(model.state_dict(), os.path.join(ckpt_path, model.emb_type + "_model.ckpt"))
            max_auc = auc
            best_epoch = i
            testauc, testacc = -1, -1
            window_testauc, window_testacc = -1, -1
            if not save_model:
                if test_loader is not None:
                    save_test_path = os.path.join(ckpt_path, model.emb_type + "_test_predictions.txt")
                    testauc, testacc = evaluate(model, test_loader, model.model_name, save_test_path)
                if test_window_loader is not None:
                    save_test_path = os.path.join(ckpt_path, model.emb_type + "_test_window_predictions.txt")
                    window_testauc, window_testacc = evaluate(model, test_window_loader, model.model_name, save_test_path)
            # window_testauc, window_testacc = -1, -1
            validauc, validacc = round(auc, 4), round(acc, 4)  # model.evaluate(valid_loader, emb_type)
            # trainauc, trainacc = model.evaluate(train_loader, emb_type)
            testauc, testacc, window_testauc, window_testacc = round(testauc, 4), round(testacc, 4), round(window_testauc, 4), round(window_testacc, 4)
            max_auc = round(max_auc, 4)
        print(f"Epoch: {i}, validauc: {validauc}, validacc: {validacc}, best epoch: {best_epoch}, best auc: {max_auc}, loss: {loss_mean}, emb_type: {model.emb_type}, model: {model.model_name}, save_dir: {ckpt_path}")
        print(f"            testauc: {testauc}, testacc: {testacc}, window_testauc: {window_testauc}, window_testacc: {window_testacc}")
        if i - best_epoch >= 10:
            break
    return testauc, testacc, window_testauc, window_testacc, validauc, validacc, best_epoch
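
# A minimal usage sketch (assumptions flagged below, not from the original
# source): train_model only needs a pykt model exposing .model_name and
# .emb_type (built e.g. via pykt.models.init_model), a torch optimizer, and
# loaders yielding batches in the dict format sketched above. The learning
# rate, epoch count, and checkpoint path are placeholder choices.
def _demo_train(model, train_loader, valid_loader, ckpt_path="ckpt"):
    os.makedirs(ckpt_path, exist_ok=True)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    return train_model(model, train_loader, valid_loader, num_epochs=200,
                       opt=opt, ckpt_path=ckpt_path, save_model=True)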