import torch
import numpy as np
import os
from .dkt import DKT
from .dkt_plus import DKTPlus
from .dkvmn import DKVMN
from .sakt import SAKT
from .saint import SAINT
from .kqn import KQN
from .atkt import ATKT
from .dkt_forget import DKTForget
from .akt import AKT
from .gkt import GKT
from .gkt_utils import get_gkt_graph
from .lpkt import LPKT
from .lpkt_utils import generate_qmatrix
from .skvmn import SKVMN
from .hawkes import HawkesKT
from .iekt import IEKT
# Module-wide compute device, chosen once at import time: use CUDA when a GPU
# is visible to torch, otherwise fall back to CPU. Every model constructed in
# init_model() is moved onto this device.
device = "cpu" if not torch.cuda.is_available() else "cuda"
def init_model(model_name, model_config, data_config, emb_type):
    """Instantiate a knowledge-tracing model by name and move it to `device`.

    Args:
        model_name: one of "dkt", "dkt+", "dkvmn", "sakt", "saint",
            "dkt_forget", "akt", "kqn", "atkt", "atktfix", "gkt", "lpkt",
            "skvmn", "hawkes", "iekt".
        model_config: keyword arguments forwarded to the model constructor
            (must match the chosen model's signature).
        data_config: dataset statistics and paths; keys read depend on the
            model (e.g. "num_c", "num_q", "dpath", "emb_path", ...).
        emb_type: embedding-type tag forwarded to most model constructors.

    Returns:
        The constructed model on `device`, or None when the name is unknown
        or the dataset lacks the information the model requires.
    """
    if model_name == "dkt":
        model = DKT(data_config["num_c"], **model_config, emb_type=emb_type, emb_path=data_config["emb_path"]).to(device)
    elif model_name == "dkt+":
        model = DKTPlus(data_config["num_c"], **model_config, emb_type=emb_type, emb_path=data_config["emb_path"]).to(device)
    elif model_name == "dkvmn":
        model = DKVMN(data_config["num_c"], **model_config, emb_type=emb_type, emb_path=data_config["emb_path"]).to(device)
    elif model_name == "sakt":
        model = SAKT(data_config["num_c"], **model_config, emb_type=emb_type, emb_path=data_config["emb_path"]).to(device)
    elif model_name == "saint":
        model = SAINT(data_config["num_q"], data_config["num_c"], **model_config, emb_type=emb_type, emb_path=data_config["emb_path"]).to(device)
    elif model_name == "dkt_forget":
        # dkt_forget additionally needs the time/attempt-gap vocabulary sizes.
        model = DKTForget(data_config["num_c"], data_config["num_rgap"], data_config["num_sgap"], data_config["num_pcount"], **model_config).to(device)
    elif model_name == "akt":
        model = AKT(data_config["num_c"], data_config["num_q"], **model_config, emb_type=emb_type, emb_path=data_config["emb_path"]).to(device)
    elif model_name == "kqn":
        model = KQN(data_config["num_c"], **model_config, emb_type=emb_type, emb_path=data_config["emb_path"]).to(device)
    elif model_name == "atkt":
        model = ATKT(data_config["num_c"], **model_config, emb_type=emb_type, emb_path=data_config["emb_path"], fix=False).to(device)
    elif model_name == "atktfix":
        # Same architecture as "atkt" but with the fixed adversarial-training variant.
        model = ATKT(data_config["num_c"], **model_config, emb_type=emb_type, emb_path=data_config["emb_path"], fix=True).to(device)
    elif model_name == "gkt":
        # GKT needs a concept-relation graph; reuse a cached one when present,
        # otherwise build it from the raw train/valid/test files and cache it.
        graph_type = model_config["graph_type"]
        fname = f"gkt_graph_{graph_type}.npz"
        graph_path = os.path.join(data_config["dpath"], fname)
        if os.path.exists(graph_path):
            graph = torch.tensor(np.load(graph_path, allow_pickle=True)["matrix"]).float()
        else:
            graph = get_gkt_graph(data_config["num_c"], data_config["dpath"],
                    data_config["train_valid_original_file"], data_config["test_original_file"], graph_type=graph_type, tofile=fname)
            graph = torch.tensor(graph).float()
        model = GKT(data_config["num_c"], **model_config, graph=graph, emb_type=emb_type, emb_path=data_config["emb_path"]).to(device)
    elif model_name == "lpkt":
        # LPKT needs the question-concept Q-matrix; load the cached copy or generate it.
        qmatrix_path = os.path.join(data_config["dpath"], "qmatrix.npz")
        if os.path.exists(qmatrix_path):
            q_matrix = np.load(qmatrix_path, allow_pickle=True)["matrix"]
        else:
            q_matrix = generate_qmatrix(data_config)
        q_matrix = torch.tensor(q_matrix).float().to(device)
        model = LPKT(data_config["num_at"], data_config["num_it"], data_config["num_q"], data_config["num_c"], **model_config, q_matrix=q_matrix, emb_type=emb_type, emb_path=data_config["emb_path"]).to(device)
    elif model_name == "skvmn":
        model = SKVMN(data_config["num_c"], **model_config, emb_type=emb_type, emb_path=data_config["emb_path"]).to(device)
    elif model_name == "hawkes":
        # HawkesKT requires both question and concept vocabularies.
        if data_config["num_q"] == 0 or data_config["num_c"] == 0:
            print(f"model: {model_name} needs both questions and concepts, but the dataset lacks one of them")
            return None
        model = HawkesKT(data_config["num_c"], data_config["num_q"], **model_config)
        # HawkesKT runs in double precision and initializes its own weights.
        model = model.double()
        model.apply(model.init_weights)
        model = model.to(device)
    elif model_name == "iekt":
        model = IEKT(num_q=data_config["num_q"], num_c=data_config["num_c"],
                max_concepts=data_config["max_concepts"], **model_config, emb_type=emb_type, emb_path=data_config["emb_path"], device=device).to(device)
    else:
        print("The wrong model name was used...")
        return None
    return model
def load_model(model_name, model_config, data_config, emb_type, ckpt_path):
    """Re-create a model and restore its weights from a saved checkpoint.

    Args:
        model_name, model_config, data_config, emb_type: forwarded unchanged
            to :func:`init_model`.
        ckpt_path: directory containing the checkpoint file named
            ``f"{emb_type}_model.ckpt"``.

    Returns:
        The model with the checkpoint's state dict loaded, on `device`.
    """
    model = init_model(model_name, model_config, data_config, emb_type)
    # map_location ensures a checkpoint saved on GPU can be restored on a
    # CPU-only machine (and vice versa) instead of raising a device error.
    net = torch.load(os.path.join(ckpt_path, emb_type + "_model.ckpt"), map_location=device)
    model.load_state_dict(net)
    return model