# Source code for pykt.datasets.init_dataset

import os, sys
import json

from torch.utils.data import DataLoader
import numpy as np
from .data_loader import KTDataset
from .dkt_forget_dataloader import DktForgetDataset
from .lpkt_dataloader import LPKTDataset
from .lpkt_utils import generate_time2idx
from .que_data_loader import KTQueDataset

def init_test_datasets(data_config, model_name, batch_size):
    """Build the test-time DataLoaders for a trained model.

    Args:
        data_config (dict): dataset-level config; must contain "dpath",
            "input_type" and the test-file keys used by the chosen model.
        model_name (str): model identifier; selects the dataset class
            ("dkt_forget", "lpkt", "iekt", or the generic KTDataset fallback).
        batch_size (int): batch size for every returned DataLoader.

    Returns:
        tuple: (test_loader, test_window_loader,
                test_question_loader, test_question_window_loader) —
        the question-level loaders are None when the config has no
        "test_question_file" or the model does not support them (iekt).
    """
    print(f"model_name is {model_name}")
    test_question_loader, test_question_window_loader = None, None
    if model_name in ["dkt_forget"]:
        test_dataset = DktForgetDataset(os.path.join(data_config["dpath"], data_config["test_file"]),
                                        data_config["input_type"], {-1})
        test_window_dataset = DktForgetDataset(os.path.join(data_config["dpath"], data_config["test_window_file"]),
                                               data_config["input_type"], {-1})
        if "test_question_file" in data_config:
            test_question_dataset = DktForgetDataset(os.path.join(data_config["dpath"], data_config["test_question_file"]),
                                                     data_config["input_type"], {-1}, True)
            test_question_window_dataset = DktForgetDataset(os.path.join(data_config["dpath"], data_config["test_question_window_file"]),
                                                            data_config["input_type"], {-1}, True)
    elif model_name in ["lpkt"]:
        # FIX: at2idx/it2idx were referenced here without ever being defined
        # (NameError at runtime); build them from the config exactly as
        # init_dataset4train does.
        at2idx, it2idx = generate_time2idx(data_config)
        test_dataset = LPKTDataset(os.path.join(data_config["dpath"], data_config["test_file"]),
                                   at2idx, it2idx, data_config["input_type"], {-1})
        test_window_dataset = LPKTDataset(os.path.join(data_config["dpath"], data_config["test_window_file"]),
                                          at2idx, it2idx, data_config["input_type"], {-1})
        if "test_question_file" in data_config:
            # FIX: the question-level datasets previously re-loaded test_file /
            # test_window_file; they must load the question-level files.
            test_question_dataset = LPKTDataset(os.path.join(data_config["dpath"], data_config["test_question_file"]),
                                                at2idx, it2idx, data_config["input_type"], {-1}, True)
            test_question_window_dataset = LPKTDataset(os.path.join(data_config["dpath"], data_config["test_question_window_file"]),
                                                       at2idx, it2idx, data_config["input_type"], {-1}, True)
    elif model_name == "iekt":
        test_dataset = KTQueDataset(os.path.join(data_config["dpath"], data_config["test_file_quelevel"]),
                                    input_type=data_config["input_type"], folds=[-1],
                                    concept_num=data_config['num_c'], max_concepts=data_config['max_concepts'])
        test_window_dataset = KTQueDataset(os.path.join(data_config["dpath"], data_config["test_window_file_quelevel"]),
                                           input_type=data_config["input_type"], folds=[-1],
                                           concept_num=data_config['num_c'], max_concepts=data_config['max_concepts'])
        # iekt has no question-level evaluation datasets.
        test_question_dataset = None
        test_question_window_dataset = None
    else:
        test_dataset = KTDataset(os.path.join(data_config["dpath"], data_config["test_file"]),
                                 data_config["input_type"], {-1})
        test_window_dataset = KTDataset(os.path.join(data_config["dpath"], data_config["test_window_file"]),
                                        data_config["input_type"], {-1})
        if "test_question_file" in data_config:
            test_question_dataset = KTDataset(os.path.join(data_config["dpath"], data_config["test_question_file"]),
                                              data_config["input_type"], {-1}, True)
            test_question_window_dataset = KTDataset(os.path.join(data_config["dpath"], data_config["test_question_window_file"]),
                                                     data_config["input_type"], {-1}, True)

    # Evaluation loaders: never shuffled so metric runs are reproducible.
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    test_window_loader = DataLoader(test_window_dataset, batch_size=batch_size, shuffle=False)
    if "test_question_file" in data_config:
        print(f"has test_question_file!")
        test_question_loader, test_question_window_loader = None, None
        if test_question_dataset is not None:
            test_question_loader = DataLoader(test_question_dataset, batch_size=batch_size, shuffle=False)
        if test_question_window_dataset is not None:
            test_question_window_loader = DataLoader(test_question_window_dataset, batch_size=batch_size, shuffle=False)
    return test_loader, test_window_loader, test_question_loader, test_question_window_loader
def update_gap(max_rgap, max_sgap, max_pcount, cur):
    """Fold one dataset's gap statistics into the running maxima.

    Args:
        max_rgap (int): current maximum repeated-time-gap index.
        max_sgap (int): current maximum sequence-time-gap index.
        max_pcount (int): current maximum past-trial-count index.
        cur: a dataset object exposing max_rgap, max_sgap and max_pcount.

    Returns:
        tuple: the element-wise maxima (max_rgap, max_sgap, max_pcount).
    """
    return (
        max(max_rgap, cur.max_rgap),
        max(max_sgap, cur.max_sgap),
        max(max_pcount, cur.max_pcount),
    )
def init_dataset4train(dataset_name, model_name, data_config, i, batch_size):
    """Build train/valid/test DataLoaders for fold i of a dataset.

    Args:
        dataset_name (str): key into the top-level data_config dict.
        model_name (str): selects the dataset class and extra preprocessing
            ("dkt_forget", "lpkt", "iekt", or the generic KTDataset fallback).
        data_config (dict): top-level config mapping dataset names to their
            per-dataset configs; the selected sub-config is mutated in place
            for dkt_forget (num_rgap/num_sgap/num_pcount) and lpkt
            (num_at/num_it).
        i (int): the fold used for validation; the remaining folds train.
        batch_size (int): batch size for every returned DataLoader.

    Returns:
        tuple: (train_loader, valid_loader, test_loader, test_window_loader);
        test_window_loader is always None (window evaluation is built
        separately via init_test_datasets).
    """
    data_config = data_config[dataset_name]
    all_folds = set(data_config["folds"])
    if model_name == "dkt_forget":
        # dkt_forget needs the maximum gap/count indices across every split
        # to size its embedding tables.
        max_rgap, max_sgap, max_pcount = 0, 0, 0
        curvalid = DktForgetDataset(os.path.join(data_config["dpath"], data_config["train_valid_file"]),
                                    data_config["input_type"], {i})
        curtrain = DktForgetDataset(os.path.join(data_config["dpath"], data_config["train_valid_file"]),
                                    data_config["input_type"], all_folds - {i})
        max_rgap, max_sgap, max_pcount = update_gap(max_rgap, max_sgap, max_pcount, curtrain)
        max_rgap, max_sgap, max_pcount = update_gap(max_rgap, max_sgap, max_pcount, curvalid)
    elif model_name == "lpkt":
        # lpkt discretizes answer/interval times into index vocabularies.
        at2idx, it2idx = generate_time2idx(data_config)
        curvalid = LPKTDataset(os.path.join(data_config["dpath"], data_config["train_valid_file"]),
                               at2idx, it2idx, data_config["input_type"], {i})
        curtrain = LPKTDataset(os.path.join(data_config["dpath"], data_config["train_valid_file"]),
                               at2idx, it2idx, data_config["input_type"], all_folds - {i})
    elif model_name == "iekt":
        curvalid = KTQueDataset(os.path.join(data_config["dpath"], data_config["train_valid_file_quelevel"]),
                                input_type=data_config["input_type"], folds={i},
                                concept_num=data_config['num_c'], max_concepts=data_config['max_concepts'])
        curtrain = KTQueDataset(os.path.join(data_config["dpath"], data_config["train_valid_file_quelevel"]),
                                input_type=data_config["input_type"], folds=all_folds - {i},
                                concept_num=data_config['num_c'], max_concepts=data_config['max_concepts'])
    else:
        curvalid = KTDataset(os.path.join(data_config["dpath"], data_config["train_valid_file"]),
                             data_config["input_type"], {i})
        curtrain = KTDataset(os.path.join(data_config["dpath"], data_config["train_valid_file"]),
                             data_config["input_type"], all_folds - {i})
    train_loader = DataLoader(curtrain, batch_size=batch_size)
    valid_loader = DataLoader(curvalid, batch_size=batch_size)

    if model_name == "dkt_forget":
        test_dataset = DktForgetDataset(os.path.join(data_config["dpath"], data_config["test_file"]),
                                        data_config["input_type"], {-1})
        # The test split can contain larger gaps than train/valid; include it
        # in the maxima before sizing the embeddings.
        max_rgap, max_sgap, max_pcount = update_gap(max_rgap, max_sgap, max_pcount, test_dataset)
    elif model_name == "lpkt":
        test_dataset = LPKTDataset(os.path.join(data_config["dpath"], data_config["test_file"]),
                                   at2idx, it2idx, data_config["input_type"], {-1})
    elif model_name == "iekt":
        # NOTE(review): iekt loads the *window* file as its test set here,
        # unlike the other branches — looks intentional, but confirm upstream.
        test_dataset = KTQueDataset(os.path.join(data_config["dpath"], data_config["test_window_file_quelevel"]),
                                    input_type=data_config["input_type"], folds=[-1],
                                    concept_num=data_config['num_c'], max_concepts=data_config['max_concepts'])
    else:
        test_dataset = KTDataset(os.path.join(data_config["dpath"], data_config["test_file"]),
                                 data_config["input_type"], {-1})

    if model_name == "dkt_forget":
        # +1 because the stored maxima are indices (sizes are max index + 1).
        data_config["num_rgap"] = max_rgap + 1
        data_config["num_sgap"] = max_sgap + 1
        data_config["num_pcount"] = max_pcount + 1
    if model_name == "lpkt":
        print(f"num_at:{len(at2idx)}")
        print(f"num_it:{len(it2idx)}")
        data_config["num_at"] = len(at2idx) + 1
        data_config["num_it"] = len(it2idx) + 1

    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    test_window_loader = None
    return train_loader, valid_loader, test_loader, test_window_loader