# Source code for pykt.models.evaluate_model

import numpy as np
import torch
from torch import nn
from torch.nn.functional import one_hot
from sklearn import metrics
from ..datasets.lpkt_utils import generate_time2idx

import pandas as pd

device = "cpu" if not torch.cuda.is_available() else "cuda"

def save_cur_predict_result(dres, q, r, d, t, m, sm, p):
    """Record per-sequence predictions and metrics for one batch.

    ``evaluate`` calls this as
    ``save_cur_predict_result(dres, c, r, cshft, rshft, m, sm, y)``.

    Args:
        dres: Dict accumulating results; each sequence is stored under the
            next integer key (``len(dres)``).
        q: Batch of input ids, filtered row by row with ``m``.
        r: Batch of input responses, filtered row by row with ``m``.
        d: Batch of shifted ids, filtered row by row with ``sm``.
        t: Batch of target responses, filtered row by row with ``sm``.
        m: Mask applied to ``q`` and ``r``.
        sm: Mask applied to ``d``, ``t`` and ``p``.
        p: Batch of predicted scores, filtered row by row with ``sm``.

    Returns:
        One ``str(list)`` line per sequence, joined with newlines.
    """

    def _selected(values, mask):
        return torch.masked_select(values, mask).detach().cpu()

    lines = []
    for row in range(t.shape[0]):
        scores = _selected(p[row], sm[row])
        targets = _selected(t[row], sm[row])
        in_ids = _selected(q[row], m[row])
        in_resps = _selected(r[row], m[row])
        out_ids = _selected(d[row], sm[row])

        # zip() keeps the original "stop at the shortest" pairing.
        inputs = list(zip(in_ids.int().tolist(), in_resps.int().tolist()))
        outputs = list(zip(targets.int().tolist(), scores.tolist(), out_ids.int().tolist()))
        qs = [qid for qid, _ in inputs]
        rs = [resp for _, resp in inputs]
        ts = [true for true, _, _ in outputs]
        ps = [score for _, score, _ in outputs]
        ds = [did for _, _, did in outputs]

        try:
            auc = metrics.roc_auc_score(y_true=np.array(ts), y_score=np.array(ps))
        except Exception:
            # AUC could not be computed for this sequence; -1 marks that case.
            auc = -1
        prelabels = [int(score >= 0.5) for score in ps]
        acc = metrics.accuracy_score(ts, prelabels)

        record = [qs, rs, ds, ts, ps, prelabels, auc, acc]
        dres[len(dres)] = record
        lines.append(str(record))
    return "\n".join(lines)
def evaluate(model, test_loader, model_name, save_path=""):
    """Evaluate a model on ``test_loader`` and return ``(auc, acc)``.

    Args:
        model: Trained model; it is switched to eval mode for every batch.
        test_loader: Iterable of batches. For ``dkt_forget`` each batch is a
            ``(dcur, dgaps)`` pair, otherwise a single dict holding the
            ``qseqs``/``cseqs``/``rseqs`` sequences, their ``shft_*``
            counterparts, and ``masks``/``smasks``.
        model_name: Selects how the model is called.
        save_path: When non-empty, per-sequence predictions are written there.

    Returns:
        Tuple ``(auc, acc)`` computed over all positions selected by ``smasks``.

    Raises:
        ValueError: If ``model_name`` is not one of the handled models.
    """
    akt_family = ["akt", "akt_vector", "akt_norasch", "akt_mono", "akt_attn",
                  "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy",
                  "aktvec_raschx"]
    fout = open(save_path, "w", encoding="utf8") if save_path != "" else None
    try:
        with torch.no_grad():
            y_trues = []
            y_scores = []
            dres = dict()
            for data in test_loader:
                if model_name in ["dkt_forget"]:
                    dcur, dgaps = data
                else:
                    dcur = data
                q, c, r = dcur["qseqs"], dcur["cseqs"], dcur["rseqs"]
                qshft, cshft, rshft = dcur["shft_qseqs"], dcur["shft_cseqs"], dcur["shft_rseqs"]
                m, sm = dcur["masks"], dcur["smasks"]
                q, c, r = q.to(device), c.to(device), r.to(device)
                qshft, cshft, rshft = qshft.to(device), cshft.to(device), rshft.to(device)
                m, sm = m.to(device), sm.to(device)
                if model.model_name == 'iekt':
                    model.model.eval()
                else:
                    model.eval()

                # Full sequences: first input step followed by the shifted sequence.
                cq = torch.cat((q[:, 0:1], qshft), dim=1)
                cc = torch.cat((c[:, 0:1], cshft), dim=1)
                cr = torch.cat((r[:, 0:1], rshft), dim=1)
                if model_name in ["dkt", "dkt+"]:
                    y = model(c.long(), r.long())
                    y = (y * one_hot(cshft.long(), model.num_c)).sum(-1)
                elif model_name in ["dkt_forget"]:
                    y = model(c.long(), r.long(), dgaps)
                    y = (y * one_hot(cshft.long(), model.num_c)).sum(-1)
                elif model_name in ["dkvmn", "skvmn"]:
                    y = model(cc.long(), cr.long())
                    y = y[:, 1:]
                elif model_name in ["kqn", "sakt"]:
                    y = model(c.long(), r.long(), cshft.long())
                elif model_name == "saint":
                    y = model(cq.long(), cc.long(), r.long())
                    y = y[:, 1:]
                elif model_name in akt_family:
                    y, reg_loss = model(cc.long(), cr.long(), cq.long())
                    y = y[:, 1:]
                elif model_name in ["atkt", "atktfix"]:
                    y, _ = model(c.long(), r.long())
                    y = (y * one_hot(cshft.long(), model.num_c)).sum(-1)
                elif model_name == "gkt":
                    y = model(cc.long(), cr.long())
                elif model_name == "lpkt":
                    # Moved to `device` like every other model input.
                    cit = torch.cat((dcur["itseqs"][:, 0:1], dcur["shft_itseqs"]), dim=1).to(device)
                    y = model(cq.long(), cr.long(), cit.long())
                    y = y[:, 1:]
                elif model_name == "hawkes":
                    ct = torch.cat((dcur["tseqs"][:, 0:1], dcur["shft_tseqs"]), dim=1).to(device)
                    y = model(cc.long(), cq.long(), ct.long(), cr.long())
                    y = y[:, 1:]
                elif model_name == "iekt":
                    y = model.predict_one_step(data)
                    c, cshft = q, qshft  # question level
                else:
                    raise ValueError(f"evaluate does not support model_name={model_name!r}")

                # save predict result
                if fout is not None:
                    result = save_cur_predict_result(dres, c, r, cshft, rshft, m, sm, y)
                    fout.write(result + "\n")

                y = torch.masked_select(y, sm).detach().cpu()
                t = torch.masked_select(rshft, sm).detach().cpu()
                y_trues.append(t.numpy())
                y_scores.append(y.numpy())

            ts = np.concatenate(y_trues, axis=0)
            ps = np.concatenate(y_scores, axis=0)
            print(f"ts.shape: {ts.shape}, ps.shape: {ps.shape}")
            auc = metrics.roc_auc_score(y_true=ts, y_score=ps)
            prelabels = [1 if p >= 0.5 else 0 for p in ps]
            acc = metrics.accuracy_score(ts, prelabels)
        return auc, acc
    finally:
        # The original never closed the output file.
        if fout is not None:
            fout.close()
def early_fusion(curhs, model, model_name):
    """Turn fused (per-question averaged) hidden states into predictions.

    Args:
        curhs: List of hidden-state tensors. Only ``curhs[0]`` is used, except
            for ``kqn`` and ``lpkt`` which also use ``curhs[1]``.
        model: Model whose output layers are reused.
        model_name: Selects which output head is applied.

    Returns:
        Tensor of predictions.

    Raises:
        ValueError: If ``model_name`` has no early-fusion head here.
    """
    akt_family = ["akt", "akt_vector", "akt_norasch", "akt_mono", "akt_attn",
                  "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy",
                  "aktvec_raschx"]
    if model_name in ["dkvmn", "skvmn"]:
        p = model.p_layer(model.dropout_layer(curhs[0]))
        p = torch.sigmoid(p)
        p = p.squeeze(-1)
    elif model_name in akt_family:
        output = model.out(curhs[0]).squeeze(-1)
        p = torch.sigmoid(output)
    elif model_name == "saint":
        p = model.out(model.dropout(curhs[0]))
        p = torch.sigmoid(p).squeeze(-1)
    elif model_name == "sakt":
        p = torch.sigmoid(model.pred(model.dropout_layer(curhs[0]))).squeeze(-1)
    elif model_name == "kqn":
        logits = torch.sum(curhs[0] * curhs[1], dim=1)
        p = model.sigmoid(logits)
    elif model_name == "hawkes":
        p = curhs[0].sigmoid()
    elif model_name == "lpkt":
        # The original read `self.d_k` (NameError: there is no `self` here) and
        # then discarded the result in favour of `curhs[0].sigmoid()`.
        # NOTE(review): returning this head's output is the presumed intent;
        # confirm against the LPKT model's forward pass.
        p = model.sig(model.linear_5(torch.cat((curhs[1], curhs[0]), 1))).sum(1) / model.d_k
    else:
        raise ValueError(f"early_fusion does not support model_name={model_name!r}")
    return p
def late_fusion(dcur, curdf, fusion_type=("mean", "vote", "all")):
    """Append the late-fusion scores of one question to ``dcur``.

    Args:
        dcur: Dict of result lists, updated in place. Depending on
            ``fusion_type`` the keys ``late_mean``, ``late_vote`` and
            ``late_all`` are created and appended to.
        curdf: Object whose ``"preds"`` entry holds the concept-level
            predictions of one question (a DataFrame group in this file).
        fusion_type: Any combination of ``"mean"``, ``"vote"`` and ``"all"``.

    Returns:
        None.
    """
    preds = curdf["preds"]
    high, low = [], []
    for pred in preds:
        if pred >= 0.5:
            high.append(pred)
        else:
            low.append(pred)
    # Number of predictions >= 0.5. The original only computed this inside the
    # "vote" branch, so fusion_type=["all"] failed with a NameError.
    correctnum = len(high)

    if "mean" in fusion_type:
        dcur.setdefault("late_mean", [])
        dcur["late_mean"].append(round(float(preds.mean()), 4))
    if "vote" in fusion_type:
        dcur.setdefault("late_vote", [])
        # Majority vote: average the winning side.
        late_vote = np.mean(high) if correctnum / len(preds) >= 0.5 else np.mean(low)
        dcur["late_vote"].append(late_vote)
    if "all" in fusion_type:
        dcur.setdefault("late_all", [])
        # Positive only when every prediction is >= 0.5.
        late_all = np.mean(high) if correctnum == len(preds) else np.mean(low)
        dcur["late_all"].append(late_all)
    return
def effective_fusion(df, model, model_name, fusion_type):
    """Fuse concept-level rows into question-level results.

    Rows of ``df`` are grouped by ``qidx``. Late-fusion scores are always
    computed; early fusion is added when requested and supported.

    Args:
        df: DataFrame with the columns ``qidx``, ``response``, ``row``,
            ``questions``, ``concepts`` and ``preds``, plus the hidden-state
            columns of the model (``hidden``, or ``ek``/``es`` for kqn, or
            ``h``/``e_data`` for lpkt).
        model: Model passed on to ``early_fusion``.
        model_name: Name of the model.
        fusion_type: Container of fusion names; ``"early_fusion"`` enables the
            early-fusion outputs.

    Returns:
        Dict mapping each result name to a one-element list holding an array.
    """
    dres = dict()
    df = df.groupby("qidx", as_index=True, sort=True)

    curhs, curr = [[], []], []
    dcur = {"late_trues": [], "qidxs": [], "questions": [], "concepts": [], "row": [], "concept_preds": []}
    # Models that expose a single hidden-state column named "hidden".
    hasearly = ["dkvmn", "skvmn", "akt", "saint", "sakt", "hawkes", "akt_vector", "akt_norasch",
                "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy",
                "aktvec_raschx"]
    # kqn and lpkt carry two hidden-state columns. They were collected below
    # but left out of the early-fusion gate, while group_fusion /
    # save_question_res expect early_trues / early_preds for them.
    early_models = hasearly + ["kqn", "lpkt"]
    for ui in df:
        # process one question at a time
        curdf = ui[1]
        if model_name in hasearly:
            curhs[0].append(curdf["hidden"].mean().astype(float))
        elif model_name == "kqn":
            curhs[0].append(curdf["ek"].mean().astype(float))
            curhs[1].append(curdf["es"].mean().astype(float))
        elif model_name == "lpkt":
            curhs[0].append(curdf["h"].mean().astype(float))
            curhs[1].append(curdf["e_data"].mean().astype(float))
        else:
            # this model has no early fusion result
            pass

        curr.append(curdf["response"].mean().astype(int))
        dcur["late_trues"].append(curdf["response"].mean().astype(int))
        dcur["qidxs"].append(ui[0])
        dcur["row"].append(curdf["row"].mean().astype(int))
        dcur["questions"].append(",".join([str(int(s)) for s in curdf["questions"].tolist()]))
        dcur["concepts"].append(",".join([str(int(s)) for s in curdf["concepts"].tolist()]))
        late_fusion(dcur, curdf)
        # save original predres in concepts
        dcur["concept_preds"].append(",".join([str(round(s, 4)) for s in (curdf["preds"].tolist())]))

    for key in dcur:
        dres.setdefault(key, [])
        dres[key].append(np.array(dcur[key]))

    # early fusion
    if "early_fusion" in fusion_type and model_name in early_models:
        curhs = [torch.tensor(np.array(curh)).float().to(device) for curh in curhs]
        curr = torch.tensor(curr).long().to(device)
        p = early_fusion(curhs, model, model_name)
        dres.setdefault("early_trues", [])
        dres["early_trues"].append(curr.cpu().numpy())
        dres.setdefault("early_preds", [])
        dres["early_preds"].append(p.cpu().numpy())
    return dres
def group_fusion(dmerge, model, model_name, fusion_type, fout):
    """Fuse the sequences whose questions are complete; carry over the rest.

    Args:
        dmerge: Dict with the keys ``hs``, ``sm``, ``cq``, ``cc``, ``cr``,
            ``y``, ``qidxs``, ``rests`` and ``orirow`` (history merged with the
            current batch).
        model: Model passed on to ``effective_fusion``.
        model_name: Name of the model.
        fusion_type: Container of fusion names.
        fout: Writable file object passed to ``save_question_res``.

    Returns:
        Tuple ``(dfinal, drest)``: the fused results (``{}`` when nothing
        could be fused) and the slice of ``dmerge`` kept for the next call.
    """
    hs, sms, cq, cc = dmerge["hs"], dmerge["sm"], dmerge["cq"], dmerge["cc"]
    rs, ps = dmerge["cr"], dmerge["y"]
    qidxs, rests, orirows = dmerge["qidxs"], dmerge["rests"], dmerge["orirow"]
    if cq.shape[1] == 0:
        cq = cc

    hasearly = ["dkvmn", "skvmn", "kqn", "akt", "saint", "sakt", "hawkes", "akt_vector", "akt_norasch",
                "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy",
                "aktvec_raschx", "lpkt"]
    two_hidden = model_name in ("kqn", "lpkt")
    one_hidden = model_name in hasearly and not two_hidden

    def _padded(first, tensor_row):
        # One leading placeholder so the row lines up with the full sequence.
        return [first] + tensor_row.cpu().tolist()

    def _as_arrays(tensor_2d):
        return [np.array(a) for a in tensor_2d.cpu().tolist()]

    alldfs = []
    for bz in range(rs.shape[0]):
        frame = pd.DataFrame({
            "qidx": _padded(-1, qidxs[bz]),
            "rest": _padded(-1, rests[bz]),
            "row": _padded(-1, orirows[bz]),
            "select": _padded(0, sms[bz]),
            "questions": cq[bz].cpu().tolist(),
            "concepts": cc[bz].cpu().tolist(),
            "response": rs[bz].cpu().tolist(),
            "preds": _padded(-1, ps[bz]),
        })
        if one_hidden:
            frame["hidden"] = _as_arrays(hs[0][bz])
        elif model_name == "kqn":
            frame["ek"] = _as_arrays(hs[0][bz])
            frame["es"] = _as_arrays(hs[1][bz])
        elif model_name == "lpkt":
            frame["h"] = _as_arrays(hs[0][bz])
            frame["e_data"] = _as_arrays(hs[1][bz])
        alldfs.append(frame[frame["select"] != 0])

    # Walk backwards to the last sequence whose selected rows all have
    # rest == 0; that sequence and everything before it are fused now.
    effective_dfs, rest_start = [], -1
    found = False
    for pos in reversed(range(len(alldfs))):
        frame = alldfs[pos]
        if found:
            effective_dfs.append(frame)
        elif (frame["rest"] == 0).all():
            found = True
            effective_dfs.append(frame)
            rest_start = pos + 1
    if rest_start == -1:
        rest_start = 0

    # merge rest
    drest = dict()
    for key in dmerge.keys():
        if key != "hs":
            drest[key] = dmerge[key][rest_start:]
        elif one_hidden:
            drest[key] = [dmerge[key][0][rest_start:]]
        elif two_hidden:
            drest[key] = [dmerge[key][0][rest_start:], dmerge[key][1][rest_start:]]
        else:
            drest[key] = []

    merged = dict()
    for frame in effective_dfs:
        for _, row in frame.iterrows():
            for key in row.keys():
                merged.setdefault(key, []).append(row[key])
    df = pd.DataFrame(merged)

    if df.shape[0] == 0:
        return {}, drest

    dres = effective_fusion(df, model, model_name, fusion_type)
    dfinal = {key: np.concatenate(dres[key], axis=0) for key in dres}
    early = model_name in hasearly and "early_fusion" in fusion_type
    save_question_res(dfinal, fout, early)
    return dfinal, drest
def save_question_res(dres, fout, early=False):
    """Write one tab-separated line per question to ``fout``.

    Args:
        dres: Dict of equally long sequences keyed by ``row``, ``qidxs``,
            ``questions``, ``concepts``, ``concept_preds``, ``late_trues``,
            ``late_mean``, ``late_vote`` and ``late_all`` (plus
            ``early_trues``/``early_preds`` when ``early`` is true).
        fout: Object with a ``write`` method.
        early: Whether to append the early-fusion columns.
    """
    columns = ["row", "qidxs", "questions", "concepts", "concept_preds",
               "late_trues", "late_mean", "late_vote", "late_all"]
    if early:
        columns = columns + ["early_trues", "early_preds"]

    for i in range(len(dres["qidxs"])):
        cells = []
        for name in columns:
            value = dres[name][i]
            # Only plain floats and np.float32 are rounded; anything else
            # (np.float64 included) is written with str() unchanged.
            if type(value) in (float, np.float32):
                cells.append(str(round(value, 4)))
            else:
                cells.append(str(value))
        fout.write("\t".join(cells) + "\n")
def evaluate_question(model, test_loader, model_name, fusion_type=("early_fusion", "late_fusion"), save_path=""):
    """Evaluate at concept level and, through fusion, at question level.

    Input conventions per model family (from the original notes):
      * dkt / dkt+ / dkt_forget / atkt: given the past, predict all; these
        have no early fusion.
      * dkvmn / akt / saint: given the current step, predict the current step.
      * sakt / kqn: given past and current, predict the current step.

    Args:
        model: Trained model; switched to eval mode for every batch.
        test_loader: Iterable of batches: ``(dcur, dgaps, dqtest)`` for
            ``dkt_forget``, otherwise ``(dcur, dqtest)``.
        model_name: Selects how the model is called.
        fusion_type: Container of fusion names, passed on to ``group_fusion``.
        save_path: File receiving per-question results. When empty the
            results are written to ``os.devnull``.

    Returns:
        Tuple ``(aucs, accs)`` of dicts keyed by ``"concepts"`` and by each
        available fusion score (``late_mean``, ``late_vote``, ``late_all``,
        ``early_preds``).

    Raises:
        ValueError: If ``model_name`` is not one of the handled models.
    """
    import os

    akt_family = ["akt", "akt_vector", "akt_norasch", "akt_mono", "akt_attn",
                  "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy",
                  "aktvec_raschx"]
    hasearly = ["dkvmn", "skvmn", "kqn", "akt", "saint", "sakt", "hawkes", "akt_vector", "akt_norasch",
                "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy",
                "aktvec_raschx", "lpkt"]
    header = ["orirow", "qidx", "questions", "concepts", "concept_preds",
              "late_trues", "late_mean", "late_vote", "late_all"]
    if model_name in hasearly:
        header = header + ["early_trues", "early_preds"]
    history_keys = ["hs", "sm", "cq", "cc", "cr", "y", "qidxs", "rests", "orirow"]

    # group_fusion always writes to `fout`. The original left it unbound when
    # save_path was empty and never closed it; fall back to the null device.
    out_path = save_path if save_path != "" else os.devnull
    with open(out_path, "w", encoding="utf8") as fout, torch.no_grad():
        fout.write("\t".join(header) + "\n")
        dinfos = dict()
        dhistory = dict()
        y_trues, y_scores = [], []
        for data in test_loader:
            if model_name in ["dkt_forget"]:
                dcurori, dgaps, dqtest = data
            else:
                dcurori, dqtest = data

            q, c, r = dcurori["qseqs"], dcurori["cseqs"], dcurori["rseqs"]
            qshft, cshft, rshft = dcurori["shft_qseqs"], dcurori["shft_cseqs"], dcurori["shft_rseqs"]
            m, sm = dcurori["masks"], dcurori["smasks"]
            q, c, r = q.to(device), c.to(device), r.to(device)
            qshft, cshft, rshft = qshft.to(device), cshft.to(device), rshft.to(device)
            m, sm = m.to(device), sm.to(device)
            qidxs, rests, orirow = dqtest["qidxs"], dqtest["rests"], dqtest["orirow"]
            model.eval()

            # Full sequences: first input step followed by the shifted sequence.
            cq = torch.cat((q[:, 0:1], qshft), dim=1)
            cc = torch.cat((c[:, 0:1], cshft), dim=1)
            cr = torch.cat((r[:, 0:1], rshft), dim=1)
            dcur = dict()
            if model_name in ["dkvmn", "skvmn"]:
                y, h = model(cc.long(), cr.long(), True)
                y = y[:, 1:]
            elif model_name in akt_family:
                y, reg_loss, h = model(cc.long(), cr.long(), cq.long(), True)
                y = y[:, 1:]
            elif model_name == "saint":
                y, h = model(cq.long(), cc.long(), r.long(), True)
                y = y[:, 1:]
            elif model_name == "sakt":
                y, h = model(c.long(), r.long(), cshft.long(), True)
                start_hemb = torch.tensor([-1] * (h.shape[0] * h.shape[2])).reshape(h.shape[0], 1, h.shape[2]).to(device)
                h = torch.cat((start_hemb, h), dim=1)  # add the first hidden emb
            elif model_name == "kqn":
                y, ek, es = model(c.long(), r.long(), cshft.long(), True)
                start_hemb = torch.tensor([-1] * (ek.shape[0] * ek.shape[2])).reshape(ek.shape[0], 1, ek.shape[2]).to(device)
                ek = torch.cat((start_hemb, ek), dim=1)  # add the first hidden emb
                es = torch.cat((start_hemb, es), dim=1)  # add the first hidden emb
            elif model_name in ["dkt", "dkt+"]:
                y = model(c.long(), r.long())
                y = (y * one_hot(cshft.long(), model.num_c)).sum(-1)
            elif model_name in ["dkt_forget"]:
                y = model(c.long(), r.long(), dgaps)
                y = (y * one_hot(cshft.long(), model.num_c)).sum(-1)
            elif model_name in ["atkt", "atktfix"]:
                y, _ = model(c.long(), r.long())
                y = (y * one_hot(cshft.long(), model.num_c)).sum(-1)
            elif model_name == "gkt":
                y = model(cc.long(), cr.long())
            elif model_name == "hawkes":
                ct = torch.cat((dcurori["tseqs"][:, 0:1], dcurori["shft_tseqs"]), dim=1).to(device)
                # The original unpacked `y, h` from `y[:, 1:]`; the hidden
                # state comes from the model call, as in the other branches.
                y, h = model(cc.long(), cq.long(), ct.long(), cr.long(), True)
                y = y[:, 1:]
            elif model_name == "lpkt":
                # The original read these from the freshly created, empty
                # `dcur` (KeyError); the batch dict is `dcurori`.
                cit = torch.cat((dcurori["itseqs"][:, 0:1], dcurori["shft_itseqs"]), dim=1).to(device)
                y, h, e_data = model(cq.long(), cr.long(), cit.long(), True)
                start_hemb = torch.tensor([-1] * (h.shape[0] * h.shape[2])).reshape(h.shape[0], 1, h.shape[2]).to(device)
                h = torch.cat((start_hemb, h), dim=1)  # add the first hidden emb
                y = y[:, 1:]
            else:
                raise ValueError(f"evaluate_question does not support model_name={model_name!r}")

            concepty = torch.masked_select(y, sm).detach().cpu()
            conceptt = torch.masked_select(rshft, sm).detach().cpu()
            y_trues.append(conceptt.numpy())
            y_scores.append(concepty.numpy())

            hs = []
            if model_name == "kqn":
                hs = [ek, es]
            elif model_name == "lpkt":
                hs = [h, e_data]
            elif model_name in hasearly:
                hs = [h]
            dcur["hs"], dcur["sm"], dcur["cq"], dcur["cc"], dcur["cr"] = hs, sm, cq, cc, cr
            dcur["y"], dcur["qidxs"], dcur["rests"], dcur["orirow"] = y, qidxs, rests, orirow

            # merge history
            dmerge = dict()
            for key in history_keys:
                if len(dhistory) == 0:
                    dmerge[key] = dcur[key]
                elif key == "hs":
                    dmerge[key] = []
                    if model_name == "kqn" or model_name == "lpkt":
                        dmerge[key] = [
                            torch.cat((dhistory[key][0], dcur[key][0]), dim=0),
                            torch.cat((dhistory[key][1], dcur[key][1]), dim=0),
                        ]
                    elif model_name in hasearly:
                        dmerge[key] = [torch.cat((dhistory[key][0], dcur[key][0]), dim=0)]
                else:
                    dmerge[key] = torch.cat((dhistory[key], dcur[key]), dim=0)

            dgroup, dhistory = group_fusion(dmerge, model, model_name, fusion_type, fout)
            for key in dgroup:
                dinfos.setdefault(key, [])
                dinfos[key].append(dgroup[key])
            # NOTE(review): dinfos is keyed by result names ("late_trues",
            # "early_trues", ...), so this condition looks like it can never
            # hold; kept as written.
            if "early_fusion" in dinfos and "late_fusion" in dinfos:
                assert dinfos["early_trues"][-1].all() == dinfos["late_trues"][-1].all()

        # ori concept eval
        aucs, accs = dict(), dict()
        ts = np.concatenate(y_trues, axis=0)
        ps = np.concatenate(y_scores, axis=0)
        auc = metrics.roc_auc_score(y_true=ts, y_score=ps)
        prelabels = [1 if p >= 0.5 else 0 for p in ps]
        acc = metrics.accuracy_score(ts, prelabels)
        aucs["concepts"] = auc
        accs["concepts"] = acc

        for key in dinfos:
            if key not in ["late_mean", "late_vote", "late_all", "early_preds"]:
                continue
            ts = np.concatenate(dinfos['late_trues'], axis=0)  # early_trues == late_trues
            ps = np.concatenate(dinfos[key], axis=0)
            auc = metrics.roc_auc_score(y_true=ts, y_score=ps)
            prelabels = [1 if p >= 0.5 else 0 for p in ps]
            acc = metrics.accuracy_score(ts, prelabels)
            aucs[key] = auc
            accs[key] = acc
    return aucs, accs
def log2(t):
    """Return ``log2(t + 1)`` rounded to the nearest integer."""
    from math import log

    shifted = t + 1
    return round(log(shifted, 2))
def calC(row, data_config):
    """Compute the three gap/count features of one sequence.

    Args:
        row: Mapping with ``"uid"`` and the comma-separated strings
            ``"concepts"`` and ``"timestamps"``.
        data_config: Mapping providing ``"num_pcount"``, the exclusive upper
            bound of the past-count bucket.

    Returns:
        Tuple ``(repeated_gap, sequence_gap, past_counts)`` of lists with one
        entry per (concept, timestamp) pair.
    """
    row["uid"]  # value unused; the lookup is kept so a row without "uid" still raises
    concept_ids = row["concepts"].split(",")
    stamps = row["timestamps"].split(",")

    repeated_gap, sequence_gap, past_counts = [], [], []
    last_seen, seen_count = dict(), dict()
    prev_time = None
    for raw_skill, raw_time in zip(concept_ids, stamps):
        skill, now = int(raw_skill), int(raw_time)

        # Gap since the same concept was last seen (ms -> minutes, bucketed).
        if skill in last_seen and skill != -1:
            repeated_gap.append(log2((now - last_seen[skill]) / 1000 / 60) + 1)
        else:
            repeated_gap.append(0)
        last_seen[skill] = now

        # Gap since the previous interaction.
        if prev_time is not None and now != -1:
            sequence_gap.append(log2((now - prev_time) / 1000 / 60) + 1)
        else:
            sequence_gap.append(0)
        prev_time = now

        # Bucketed count of earlier occurrences, capped at num_pcount - 1.
        seen_count.setdefault(skill, 0)
        bucket = log2(seen_count[skill])
        if bucket >= data_config["num_pcount"]:
            bucket = data_config["num_pcount"] - 1
        past_counts.append(bucket)
        seen_count[skill] += 1
    return repeated_gap, sequence_gap, past_counts
def get_info_dkt_forget(row, data_config):
    """Return the ``calC`` features as a dict keyed ``rgaps``/``sgaps``/``pcounts``."""
    feature_names = ("rgaps", "sgaps", "pcounts")
    return dict(zip(feature_names, calC(row, data_config)))
def evaluate_splitpred_question(model, data_config, testf, model_name, save_path="", use_pred=False, train_ratio=0.2, atkt_pad=False):
    """Evaluate by splitting each test sequence into a prefix and a suffix.

    The first ``train_ratio`` share of questions in every row of ``testf`` is
    used as history and the rest is predicted.

    Args:
        model: Trained model; switched to eval mode for every row.
        data_config: Dataset configuration, passed to ``get_info_dkt_forget``
            and ``generate_time2idx``.
        testf: Path of a CSV file with comma-separated ``concepts`` and
            ``responses`` columns, a ``uid`` column, and optionally
            ``is_repeat``, ``questions`` and ``timestamps``.
        model_name: Name of the model.
        save_path: File receiving detailed results. When empty the results
            are written to ``os.devnull``.
        use_pred: If true, feed each question's predictions back as history
            for the next question; otherwise predict the whole suffix at once.
        train_ratio: Share of questions used as history.
        atkt_pad: Passed through to the prediction helpers.

    Returns:
        The dict produced by ``cal_predres``.
    """
    import os

    if model_name == "lpkt":
        at2idx, it2idx = generate_time2idx(data_config)

    # The original left `fout` unbound when save_path was empty and never
    # closed it; fall back to the null device.
    out_path = save_path if save_path != "" else os.devnull
    with open(out_path, "w", encoding="utf8") as fout, torch.no_grad():
        idx = 0
        df = pd.read_csv(testf)
        dcres, dqres = {"trues": [], "preds": []}, {"trues": [], "late_mean": [], "late_vote": [], "late_all": []}
        for i, row in df.iterrows():
            model.eval()

            dforget = dict() if model_name != "dkt_forget" else get_info_dkt_forget(row, data_config)
            concepts, responses = row["concepts"].split(","), row["responses"].split(",")
            curl = len(responses)

            is_repeat = ["0"] * curl if "is_repeat" not in row else row["is_repeat"].split(",")
            is_repeat = [int(s) for s in is_repeat]
            questions = [] if "questions" not in row else row["questions"].split(",")
            times = [] if "timestamps" not in row else row["timestamps"].split(",")
            if model_name == "lpkt":
                # The original referenced the undefined names `timestamps`,
                # `shft_timestamps` and `it` here.
                timestamps = [int(s) for s in times]
                shft_timestamps = [0] + timestamps[:-1]
                it = np.maximum(np.minimum((np.array(timestamps) - np.array(shft_timestamps)) // 60, 43200), -1)
                it_times = [it2idx[str(s)] for s in it]

            qlen, qtrainlen, ctrainlen = get_cur_teststart(is_repeat, train_ratio)
            cq = torch.tensor([int(s) for s in questions]).to(device)
            cc = torch.tensor([int(s) for s in concepts]).to(device)
            cr = torch.tensor([int(s) for s in responses]).to(device)
            ct = torch.tensor([int(s) for s in times]).to(device)
            dtotal = {"cq": cq, "cc": cc, "cr": cr, "ct": ct}
            if model_name == "lpkt":
                cit = torch.tensor([int(s) for s in it_times]).to(device)
                dtotal["cit"] = cit

            curcin, currin = cc[0:ctrainlen].unsqueeze(0), cr[0:ctrainlen].unsqueeze(0)
            curqin = cq[0:ctrainlen].unsqueeze(0) if cq.shape[0] > 0 else cq
            curtin = ct[0:ctrainlen].unsqueeze(0) if ct.shape[0] > 0 else ct
            dcur = {"curqin": curqin, "curcin": curcin, "currin": currin, "curtin": curtin}
            if model_name == "lpkt":
                # The original sliced `ct` here instead of the interval
                # sequence `cit`.
                curitin = cit[0:ctrainlen].unsqueeze(0) if cit.shape[0] > 0 else cit
                dcur["curitin"] = curitin

            curdforget = dict()
            for key in dforget:
                dforget[key] = torch.tensor(dforget[key]).to(device)
                curdforget[key] = dforget[key][0:ctrainlen].unsqueeze(0)
            t = ctrainlen

            # If predictions are not fed back, this part could be parallelised.
            if not use_pred:
                uid, end = row["uid"], curl
                qidx = qtrainlen
                qidxs, ctrues, cpreds = predict_each_group2(dtotal, dcur, dforget, curdforget, is_repeat, qidx, uid, idx, model_name, model, t, end, fout, atkt_pad)
                # compute and store the metrics of this row
                save_currow_question_res(idx, dcres, dqres, qidxs, ctrues, cpreds, uid, fout)
            else:
                qidx = qtrainlen
                while t < curl:
                    # Collect the positions of one question: t plus the
                    # following entries flagged as repeats.
                    rtmp = [t]
                    for k in range(t + 1, curl):
                        if is_repeat[k] != 0:
                            rtmp.append(k)
                        else:
                            break
                    end = rtmp[-1] + 1
                    uid = row["uid"]
                    curqin, curcin, currin, curdforget, ctrues, cpreds = predict_each_group(dtotal, dcur, dforget, curdforget, is_repeat, qidx, uid, idx, model_name, model, t, end, fout, atkt_pad)
                    # predict_each_group reads its history from `dcur`, so the
                    # returned inputs have to be stored back; the original
                    # dropped them and kept predicting from the initial prefix.
                    dcur["curqin"], dcur["curcin"], dcur["currin"] = curqin, curcin, currin
                    # Time inputs are not returned; advance them to the same
                    # length (`end`) as the concept history.
                    if ct.shape[0] > 0:
                        dcur["curtin"] = ct[0:end].unsqueeze(0)
                    if model_name == "lpkt" and cit.shape[0] > 0:
                        dcur["curitin"] = cit[0:end].unsqueeze(0)
                    late_mean, late_vote, late_all = save_each_question_res(dcres, dqres, ctrues, cpreds)
                    fout.write("\t".join([str(idx), str(uid), str(qidx), str(late_mean), str(late_vote), str(late_all)]) + "\n")
                    t = end
                    qidx += 1
            idx += 1

        dfinal = cal_predres(dcres, dqres)
        for key in dfinal:
            fout.write(key + "\t" + str(dfinal[key]) + "\n")
    return dfinal
[docs]def get_cur_teststart(is_repeat, train_ratio): curl = len(is_repeat) # print(is_repeat) qlen = is_repeat.count(0) qtrainlen = int(qlen * train_ratio) qtrainlen = 1 if qtrainlen == 0 else qtrainlen qtrainlen = qtrainlen - 1 if qtrainlen == qlen else qtrainlen # get real concept len ctrainlen, qidx = 0, 0 i = 0 while i < curl: if is_repeat[i] == 0: qidx += 1 # print(f"i: {i}, curl: {curl}, qidx: {qidx}, qtrainlen: {qtrainlen}") # qtrainlen = 7 if qlen>7 else qtrainlen if qidx == qtrainlen: break i += 1 for j in range(i+1, curl): if is_repeat[j] == 0: ctrainlen = j break return qlen, qtrainlen, ctrainlen
# def predict_each_group(curdforget, dforget, is_repeat, qidx, uid, idx, curqin, curcin, currin, model_name, model, t, cq, cc, cr, end, fout, atkt_pad=False, maxlen=200):
def predict_each_group(dtotal, dcur, dforget, curdforget, is_repeat, qidx, uid, idx, model_name, model, t, end, fout, atkt_pad=False, maxlen=200):
    """Predict the entries ``t..end-1`` of one question.

    Thresholded predictions are appended to the response history so they can
    serve as input for the next question.

    Args:
        dtotal: Full sequences under ``cq``, ``cc``, ``cr``, ``ct`` (and
            ``cit`` for lpkt).
        dcur: Current history under ``curqin``, ``curcin``, ``currin``,
            ``curtin`` (and ``curitin`` for lpkt).
        dforget: Full dkt_forget feature sequences, keyed by feature name.
        curdforget: Current dkt_forget feature history.
        is_repeat: Repeat flags of the whole sequence (only logged).
        qidx: Index of the question (only logged).
        uid: User id (only logged).
        idx: Row index (only logged).
        model_name: Selects how the model is called.
        model: Model to call.
        t: First position to predict.
        end: One past the last position to predict.
        fout: Writable file object receiving one line per prediction.
        atkt_pad: For atkt models, pad the inputs to ``maxlen - 1``.
        maxlen: Inputs are truncated to their last ``maxlen - 1`` steps.

    Returns:
        Tuple ``(nextqin, nextcin, nextrin, nextdforget, ctrues, cpreds)``.
    """
    import copy

    def _cell(value):
        # 1x1 tensor on the evaluation device.
        return torch.tensor([[value]]).to(device)

    akt_family = ("akt", "akt_vector", "akt_norasch", "akt_mono", "akt_attn",
                  "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy",
                  "aktvec_raschx")
    is_lpkt = model_name == "lpkt"

    curqin, curcin, currin, curtin = dcur["curqin"], dcur["curcin"], dcur["currin"], dcur["curtin"]
    cq, cc, cr, ct = dtotal["cq"], dtotal["cc"], dtotal["cr"], dtotal["ct"]
    if is_lpkt:
        curitin, cit = dcur["curitin"], dtotal["cit"]
    has_q, has_t = cq.shape[0] > 0, ct.shape[0] > 0

    nextcin, nextrin = curcin, currin
    nextdforget = copy.deepcopy(curdforget)
    ctrues, cpreds = [], []
    for k in range(t, end):
        # Truncate when the input is longer than maxlen - 1.
        offset = max(curcin.shape[1] - maxlen + 1, 0)
        cin, rin = curcin[:, offset:], currin[:, offset:]
        qin = curqin[:, offset:] if has_q else curqin
        tin = curtin[:, offset:] if has_t else curtin
        if is_lpkt:
            itin = curitin[:, offset:]

        # The k-th interaction is the one being predicted.
        cout, true = cc.long()[k], cr.long()[k]
        qout = cq.long()[k] if has_q else None
        tout = ct.long()[k] if has_t else None
        if is_lpkt:
            itout = cit.long()[k] if cit.shape[0] > 0 else None

        if model_name in ["dkt", "dkt+"]:
            y = model(cin.long(), rin.long())
            pred = y[0][-1][cout.item()]
        elif model_name == "dkt_forget":
            din = {key: curdforget[key][:, offset:] for key in curdforget}
            dshift = dict()
            for key in dforget:
                curd = torch.tensor([[dforget[key][k]]]).long().to(device)
                dshift[key] = torch.cat((din[key][:, 1:], curd), axis=1)
            y = model(cin.long(), rin.long(), din, dshift)
            pred = y[0][-1][cout.item()]
        elif model_name in ["kqn", "sakt"]:
            cshft = torch.cat((cin[:, 1:], _cell(cout.item())), axis=1)
            y = model(cin.long(), rin.long(), cshft.long())
            pred = y[0][-1]
        elif model_name == "saint":
            # the input includes questions
            if qout is not None:
                qin = torch.cat((qin, _cell(qout.item())), axis=1)
            cin = torch.cat((cin, _cell(cout.item())), axis=1)
            y = model(qin.long(), cin.long(), rin.long())
            pred = y[0][-1]
        elif model_name in ["atkt", "atktfix"]:
            if atkt_pad == True:
                oricinlen = cin.shape[1]
                padlen = maxlen - 1 - oricinlen
                pad = torch.tensor([0] * (padlen)).unsqueeze(0).to(device)
                cin = torch.cat((cin, pad), axis=1)
                rin = torch.cat((rin, pad), axis=1)
            y, _ = model(cin.long(), rin.long())
            if atkt_pad == True:
                pred = y[0][oricinlen - 1][cout.item()]
            else:
                pred = y[0][-1][cout.item()]
        elif model_name in ["dkvmn", "skvmn"]:
            # The predicted r should update the memory value, but prediction is
            # done concept by concept here, so the appended response has no effect.
            cin = torch.cat((cin, _cell(cout.item())), axis=1)
            rin = torch.cat((rin, _cell(true.item())), axis=1)
            y = model(cin.long(), rin.long())
            pred = y[0][-1]
        elif model_name in akt_family:
            # the input includes questions
            if qout is not None:
                qin = torch.cat((qin, _cell(qout.item())), axis=1)
            # The appended response has no effect: the current prediction does
            # not use it; the history used is rin, which may hold predictions
            # or true responses.
            cin = torch.cat((cin, _cell(cout.item())), axis=1)
            rin = torch.cat((rin, _cell(1)), axis=1)
            y, reg_loss = model(cin.long(), rin.long(), qin.long())
            pred = y[0][-1]
        elif model_name == "lpkt":
            if itout is not None:
                itin = torch.cat((itin, _cell(itout.item())), axis=1)
            cin = torch.cat((cin, _cell(cout.item())), axis=1)
            rin = torch.cat((rin, _cell(1)), axis=1)
            y = model(cin.long(), rin.long(), itin.long())
            pred = y[0][-1]
        elif model_name == "gkt":
            cin = torch.cat((cin, _cell(cout.item())), axis=1)
            rin = torch.cat((rin, _cell(1)), axis=1)
            y = model(cin.long(), rin.long())
            pred = y[0][-1]
        elif model_name == "hawkes":
            if tout is not None:
                tin = torch.cat((tin, _cell(tout.item())), axis=1)
            if qout is not None:
                qin = torch.cat((qin, _cell(qout.item())), axis=1)
            cin = torch.cat((cin, _cell(cout.item())), axis=1)
            rin = torch.cat((rin, _cell(1)), axis=1)
            y = model(cin.long(), qin.long(), tin.long(), rin.long())
            pred = y[0][-1]

        predl = 1 if pred.item() >= 0.5 else 0
        nextqin = cq[0:k + 1].unsqueeze(0) if has_q else qin
        nextcin = cc[0:k + 1].unsqueeze(0)
        # The thresholded prediction, not the true response, extends the history.
        nextrin = torch.cat((nextrin, _cell(predl)), axis=1)

        # update nextdforget
        if model_name == "dkt_forget":
            for key in nextdforget:
                curd = torch.tensor([[dforget[key][k]]]).long().to(device)
                nextdforget[key] = torch.cat((nextdforget[key], curd), axis=1)

        # save pred res
        ctrues.append(true.item())
        cpreds.append(pred.item())

        # output
        clist = cin.squeeze(0).long().tolist()[0:k]
        rlist = rin.squeeze(0).long().tolist()[0:k]
        fields = [idx, uid, k, qidx, is_repeat[t:end], len(clist), clist, rlist,
                  cout.item(), true.item(), pred.item(), predl]
        fout.write("\t".join(map(str, fields)) + "\n")
    return nextqin, nextcin, nextrin, nextdforget, ctrues, cpreds
def save_each_question_res(dcres, dqres, ctrues, cpreds):
    """Store the concept-level and fused question-level results of one question.

    Args:
        dcres: Dict with ``"trues"``/``"preds"`` lists, extended in place with
            the concept-level values.
        dqres: Dict with ``"trues"``, ``"late_mean"``, ``"late_vote"`` and
            ``"late_all"`` lists; one value is appended to each.
        ctrues: True responses of the question's concepts; all must be equal.
        cpreds: Predicted scores of the question's concepts.

    Returns:
        Tuple ``(late_mean, late_vote, late_all)``.
    """
    pairs = list(zip(ctrues, cpreds))
    dcres["trues"].extend(true for true, _ in pairs)
    dcres["preds"].extend(pred for _, pred in pairs)
    high = [pred for _, pred in pairs if pred >= 0.5]
    low = [pred for _, pred in pairs if not pred >= 0.5]

    scores = np.array(cpreds)
    late_mean = np.mean(scores)
    n_high = list(scores >= 0.5).count(True)

    # Majority vote: average the winning side.
    if n_high / len(scores) >= 0.5:
        late_vote = np.mean(high)
    else:
        late_vote = np.mean(low)
    # Positive only when every concept prediction is >= 0.5.
    if n_high == len(scores):
        late_all = np.mean(high)
    else:
        late_all = np.mean(low)

    assert len(set(ctrues)) == 1
    dqres["trues"].append(dcres["trues"][-1])
    dqres["late_mean"].append(late_mean)
    dqres["late_vote"].append(late_vote)
    dqres["late_all"].append(late_all)
    return late_mean, late_vote, late_all
def cal_predres(dcres, dqres):
    """Compute AUC and accuracy at concept level and for each fusion score.

    Args:
        dcres: Dict with the concept-level ``"trues"`` and ``"preds"`` lists.
        dqres: Dict with the question-level ``"trues"`` list plus one list of
            scores per fusion key.

    Returns:
        Dict mapping ``"concepts"`` and every fusion key of ``dqres`` to
        ``[count, auc, acc]``.
    """

    def _summary(trues, scores):
        auc = metrics.roc_auc_score(y_true=trues, y_score=scores)
        labels = [1 if s >= 0.5 else 0 for s in scores]
        acc = metrics.accuracy_score(trues, labels)
        return [len(scores), auc, acc]

    dres = dict()
    dres["concepts"] = _summary(np.array(dcres["trues"]), np.array(dcres["preds"]))

    qtrues = np.array(dqres["trues"])
    for key in dqres:
        if key != "trues":
            dres[key] = _summary(qtrues, np.array(dqres[key]))
    return dres
def prepare_data(model_name, is_repeat, qidx, dcur, curdforget, dtotal, dforget, t, end, maxlen=200):
    """Build one (history, shifted history) sample per position in ``[t, end)``.

    Every sample uses the same history taken from ``dcur``; the shifted
    sequence is that history moved one step to the left with the value found at
    position ``k`` of the matching ``dtotal`` sequence appended.  The history
    is cut to its last ``maxlen - 1`` columns when it is at least that long.

    Args:
        model_name: model identifier; ``"lpkt"`` adds interval-time outputs and
            ``"dkt_forget"`` adds the forgetting features.
        is_repeat: indexable flags; a value of 0 at ``k`` starts a new question
            (the question index is incremented), any other value keeps it.
        qidx: question index assigned to the first new question in the range.
        dcur: dict with the histories ``"curqin"``, ``"curcin"``, ``"currin"``,
            ``"curtin"`` (and ``"curitin"`` for lpkt).  They are sliced as
            ``[:, start:]`` and concatenated with a ``(1, 1)`` tensor on dim 1,
            so a leading dimension of 1 is required.
        curdforget: dict of forgetting-feature histories (dkt_forget only).
        dtotal: dict with the full sequences ``"cq"``, ``"cc"``, ``"cr"``,
            ``"ct"`` (and ``"cit"`` for lpkt), indexed by ``k``.
        dforget: dict of full forgetting-feature sequences (dkt_forget only).
        t: first position (inclusive).
        end: last position (exclusive).
        maxlen: maximum sequence length; at most ``maxlen - 1`` history columns
            are kept.

    Returns:
        tuple: for every model except lpkt
        ``(qidxs, finalqs, finalcs, finalrs, finalts, finalqshfts,
        finalcshfts, finalrshfts, finaltshfts, finald, finaldshft)``; for lpkt
        ``(qidxs, finalqs, finalcs, finalrs, finalts, finalits, finalqshfts,
        finalcshfts, finalrshfts, finaltshfts, finalitshfts, finald,
        finaldshft)``.  Tensors stack the per-position samples along dim 0;
        optional sequences whose source in ``dtotal`` is empty are returned as
        ``torch.tensor([])``.
    """
    curqin, curcin, currin, curtin = dcur["curqin"], dcur["curcin"], dcur["currin"], dcur["curtin"]
    cq, cc, cr, ct = dtotal["cq"], dtotal["cc"], dtotal["cr"], dtotal["ct"]
    is_lpkt = model_name == "lpkt"
    if is_lpkt:
        curitin = dcur["curitin"]
        cit = dtotal["cit"]

    # Keep only the most recent ``maxlen - 1`` history columns.  The history is
    # identical for every k, so the window start is computed once.
    cinlen = curcin.shape[1]
    start = cinlen - maxlen + 1 if cinlen >= maxlen - 1 else 0

    def _window_pair(history, values, k):
        """Return the history window and the window shifted by one with values[k] appended."""
        nxt = torch.tensor([[values.long()[k].item()]]).to(device)
        return history[:, start:], torch.cat((history[:, start + 1:], nxt), dim=1)

    dqs, dcs, drs, dts, dits = [], [], [], [], []
    dqshfts, dcshfts, drshfts, dtshfts, ditshfts = [], [], [], [], []
    dds, ddshfts = dict(), dict()

    qidxs = []
    qstart = qidx - 1
    for k in range(t, end):
        # A non-repeated row opens a new question; repeated rows reuse the index.
        if is_repeat[k] == 0:
            qstart += 1
        qidxs.append(qstart)

        win, shft = _window_pair(curcin, cc, k)
        dcs.append(win)
        dcshfts.append(shft)
        win, shft = _window_pair(currin, cr, k)
        drs.append(win)
        drshfts.append(shft)
        if cq.shape[0] > 0:
            win, shft = _window_pair(curqin, cq, k)
            dqs.append(win)
            dqshfts.append(shft)
        if ct.shape[0] > 0:
            win, shft = _window_pair(curtin, ct, k)
            dts.append(win)
            dtshfts.append(shft)
        if is_lpkt and cit.shape[0] > 0:
            win, shft = _window_pair(curitin, cit, k)
            dits.append(win)
            ditshfts.append(shft)

        if model_name == "dkt_forget":
            d = dict()
            for key in curdforget:
                d[key] = curdforget[key][:, start:]
                dds.setdefault(key, []).append(d[key])
            for key in dforget:
                curd = torch.tensor([[dforget[key][k]]]).long().to(device)
                ddshfts.setdefault(key, []).append(torch.cat((d[key][:, 1:], curd), dim=1))

    finalcs, finalrs = torch.cat(dcs, dim=0), torch.cat(drs, dim=0)
    finalcshfts, finalrshfts = torch.cat(dcshfts, dim=0), torch.cat(drshfts, dim=0)
    finalqs, finalqshfts = torch.tensor([]), torch.tensor([])
    finalts, finaltshfts = torch.tensor([]), torch.tensor([])
    if cq.shape[0] > 0:
        finalqs = torch.cat(dqs, dim=0)
        finalqshfts = torch.cat(dqshfts, dim=0)
    if ct.shape[0] > 0:
        finalts = torch.cat(dts, dim=0)
        finaltshfts = torch.cat(dtshfts, dim=0)

    finald, finaldshft = dict(), dict()
    for key in dds:
        finald[key] = torch.cat(dds[key], dim=0)
        finaldshft[key] = torch.cat(ddshfts[key], dim=0)

    # The original condition compared two undefined names (``model_names`` and
    # ``lpkt``) and raised NameError on every call.
    if model_name != "lpkt":
        return qidxs, finalqs, finalcs, finalrs, finalts, finalqshfts, finalcshfts, finalrshfts, finaltshfts, finald, finaldshft

    finalits, finalitshfts = torch.tensor([]), torch.tensor([])
    if cit.shape[0] > 0:
        finalits = torch.cat(dits, dim=0)
        finalitshfts = torch.cat(ditshfts, dim=0)
    return qidxs, finalqs, finalcs, finalrs, finalts, finalits, finalqshfts, finalcshfts, finalrshfts, finaltshfts, finalitshfts, finald, finaldshft
# def predict_each_group2(curdforget, dforget, is_repeat, qidx, uid, idx, curqin, curcin, currin, model_name, model, t, cq, cc, cr, end, fout, atkt_pad=False, maxlen=200):
def predict_each_group2(dtotal, dcur, dforget, curdforget, is_repeat, qidx, uid, idx, model_name, model, t, end, fout, atkt_pad=False, maxlen=200):
    """Predict every position in ``[t, end)`` from the same history.

    The model's own predictions are not fed back into the history: all samples
    are built up front by ``prepare_data`` and scored in batches of 128.  One
    tab-separated line per sample is written to ``fout``.

    Args:
        dtotal, dcur, dforget, curdforget, is_repeat, qidx, t, end: forwarded
            unchanged to ``prepare_data``.
        uid: user identifier, only written to ``fout``.
        idx: row identifier, only written to ``fout``.
        model_name: selects how ``model`` is called and how its output is sliced.
        model: callable model; ``model.num_c`` is read for the dkt/atkt families.
        fout: object with a ``write`` method receiving the per-sample lines.
        atkt_pad: when ``== True`` and the model is ``atkt``/``atktfix``, inputs
            are right-padded with zeros to ``maxlen - 1`` columns and the
            prediction is read at the last unpadded position.
        maxlen: maximum sequence length, used for the history window and for
            the ATKT padding.

    Returns:
        tuple: ``(qidxs, ctrues, cpreds)`` -- question index, label and
        predicted score for every sample, in sample order.
    """
    ctrues, cpreds = [], []

    # All samples share the same history, so they can be scored in parallel
    # (batched); predictions are never used as inputs.
    # ``maxlen`` is forwarded so that the history window and the ATKT padding
    # below agree on the sequence length.
    prepared = prepare_data(model_name, is_repeat, qidx, dcur, curdforget, dtotal, dforget, t, end, maxlen)
    finalits, finalitshfts = None, None
    if model_name == "lpkt":
        (qidxs, finalqs, finalcs, finalrs, finalts, finalits, finalqshfts, finalcshfts,
         finalrshfts, finaltshfts, finalitshfts, finald, finaldshft) = prepared
    else:
        (qidxs, finalqs, finalcs, finalrs, finalts, finalqshfts, finalcshfts,
         finalrshfts, finaltshfts, finald, finaldshft) = prepared

    use_atkt_pad = model_name in ["atkt", "atktfix"] and atkt_pad == True

    bz = 128
    for bidx in range(0, finalcs.shape[0], bz):
        batch = slice(bidx, bidx + bz)
        curc, curr = finalcs[batch], finalrs[batch]
        curcshft, currshft = finalcshfts[batch], finalrshfts[batch]
        curqidxs = qidxs[batch]

        # Optional sequences default to an empty (1, 0) tensor so that the
        # concatenations below stay valid when the dataset lacks them.
        curq, curqshft = torch.tensor([[]]), torch.tensor([[]])
        if finalqs.shape[0] > 0:
            curq = finalqs[batch]
            curqshft = finalqshfts[batch]
        # Previously left unbound when no timestamps were available, which made
        # the ``cct`` concatenation fail with UnboundLocalError.
        curt, curtshft = torch.tensor([[]]), torch.tensor([[]])
        if finalts.shape[0] > 0:
            curt = finalts[batch]
            curtshft = finaltshfts[batch]

        curd, curdshft = dict(), dict()
        if model_name == "dkt_forget":
            for key in finald:
                curd[key] = finald[key][batch]
                curdshft[key] = finaldshft[key][batch]
        if model_name == "lpkt":
            curit = finalits[batch]
            curitshft = finalitshfts[batch]

        ## start predict
        # Full sequences: first history element followed by the shifted sequence.
        ccq = torch.cat((curq[:, 0:1], curqshft), dim=1)
        ccc = torch.cat((curc[:, 0:1], curcshft), dim=1)
        ccr = torch.cat((curr[:, 0:1], currshft), dim=1)
        cct = torch.cat((curt[:, 0:1], curtshft), dim=1)

        if model_name in ["dkt", "dkt+"]:
            y = model(curc.long(), curr.long())
            y = (y * one_hot(curcshft.long(), model.num_c)).sum(-1)
        elif model_name in ["dkt_forget"]:
            y = model(curc.long(), curr.long(), curd, curdshft)
            y = (y * one_hot(curcshft.long(), model.num_c)).sum(-1)
        elif model_name in ["dkvmn", "skvmn"]:
            y = model(ccc.long(), ccr.long())
            y = y[:, 1:]
        elif model_name in ["kqn", "sakt"]:
            y = model(curc.long(), curr.long(), curcshft.long())
        elif model_name == "saint":
            y = model(ccq.long(), ccc.long(), curr.long())
            y = y[:, 1:]
        elif model_name in ["akt", "akt_vector", "akt_norasch", "akt_mono", "akt_attn", "aktattn_pos", "aktmono_pos", "akt_raschx", "akt_raschy", "aktvec_raschx"]:
            y, reg_loss = model(ccc.long(), ccr.long(), ccq.long())
            y = y[:, 1:]
        elif model_name in ["atkt", "atktfix"]:
            if use_atkt_pad:
                # Right-pad with zeros up to maxlen - 1 columns and remember
                # where the real sequence ends.
                oricurclen = curc.shape[1]
                padlen = maxlen - 1 - oricurclen
                pad = torch.tensor([0] * padlen).unsqueeze(0).expand(curc.shape[0], padlen).to(device)
                curc = torch.cat((curc, pad), dim=1)
                curr = torch.cat((curr, pad), dim=1)
                curcshft = torch.cat((curcshft, pad), dim=1)
            y, _ = model(curc.long(), curr.long())
            y = (y * one_hot(curcshft.long(), model.num_c)).sum(-1)
        elif model_name == "lpkt":
            ccit = torch.cat((curit[:, 0:1], curitshft), dim=1)
            y = model(ccq.long(), ccr.long(), ccit.long())
            y = y[:, 1:]
        elif model_name == "gkt":
            y = model(ccc.long(), ccr.long())
        elif model_name == "hawkes":
            y = model(ccc.long(), ccq.long(), cct.long(), ccr.long())

        if use_atkt_pad:
            # With padding the target sits at the last unpadded position.
            pred = y[:, oricurclen - 1].tolist()
            true = currshft[:, oricurclen - 1].tolist()
        else:
            pred = y[:, -1].tolist()
            true = ccr[:, -1].tolist()

        # save pred res
        ctrues.extend(true)
        cpreds.extend(pred)

        # output
        for i in range(curc.shape[0]):
            clist, rlist = curc[i].long().tolist()[0:t], curr[i].long().tolist()[0:t]
            cshftlist, rshftlist = curcshft[i].long().tolist()[0:t], currshft[i].long().tolist()[0:t]
            sample_qidx = curqidxs[i]
            predl = 1 if pred[i] >= 0.5 else 0
            fout.write("\t".join([str(idx), str(uid), str(bidx+i), str(sample_qidx), str(len(clist)), str(clist), str(rlist), str(cshftlist), str(rshftlist), str(true[i]), str(pred[i]), str(predl)]) + "\n")
    return qidxs, ctrues, cpreds
def save_currow_question_res(idx, dcres, dqres, qidxs, ctrues, cpreds, uid, fout):
    """Group one row's predictions by question and record the fused scores.

    Labels and predictions sharing the same entry in ``qidxs`` are collected
    together (questions keep their first-seen order), each group is passed to
    ``save_each_question_res`` and a tab-separated line
    ``idx, uid, qidx, late_mean, late_vote, late_all`` is written to ``fout``.

    Args:
        idx: row identifier, only written to ``fout``.
        dcres: concept-level accumulator forwarded to ``save_each_question_res``.
        dqres: question-level accumulator forwarded to ``save_each_question_res``.
        qidxs: question index of every prediction.
        ctrues: label of every prediction, aligned with ``qidxs``.
        cpreds: score of every prediction, aligned with ``qidxs``.
        uid: user identifier, only written to ``fout``.
        fout: object with a ``write`` method.
    """
    per_question = {}
    for pos, question in enumerate(qidxs):
        label, score = ctrues[pos], cpreds[pos]
        group = per_question.setdefault(question, {"trues": [], "preds": []})
        group["trues"].append(label)
        group["preds"].append(score)

    for question, group in per_question.items():
        late_mean, late_vote, late_all = save_each_question_res(
            dcres, dqres, group["trues"], group["preds"]
        )
        fields = (idx, uid, question, late_mean, late_vote, late_all)
        fout.write("\t".join(str(field) for field in fields) + "\n")