# Source code for pykt.utils.wandb_utils

import pandas as pd
import wandb
from wandb.apis.public import gql

def get_runs_result(runs, drop_duplicate=False):
    """Collect summary metrics and scalar config values of W&B runs into a DataFrame.

    Args:
        runs (iterable): run objects; each must expose ``summary._json_dict``,
            ``config`` (a mapping), ``name`` and ``path`` (indexable).
        drop_duplicate (bool, optional): if True, drop runs whose scalar config
            values are identical, keeping the first occurrence. Defaults to False.

    Returns:
        pd.DataFrame: one row per run containing the summary values, the scalar
        config values (keys starting with ``_`` and list/dict values are skipped;
        config values override summary values with the same key), plus ``Name``
        (the run name) and ``path_id`` (the last element of ``run.path``).
    """
    result_list = []
    # Ordered set of every scalar config key seen across *all* runs, so that
    # deduplication does not depend on whichever run happened to come last.
    config_keys = {}
    for run in runs:
        result = dict(run.summary._json_dict)
        model_config = {
            k: v
            for k, v in run.config.items()
            if not k.startswith('_') and not isinstance(v, (list, dict))
        }
        config_keys.update(dict.fromkeys(model_config))
        result.update(model_config)
        result['Name'] = run.name
        result['path_id'] = run.path[-1]
        result_list.append(result)
    runs_df = pd.DataFrame(result_list)
    # pandas fails on an empty ``subset``, so only deduplicate when there is
    # at least one config column to compare.
    if drop_duplicate and config_keys:
        # drop_duplicates returns a new frame; it must be reassigned to take effect.
        runs_df = runs_df.drop_duplicates(subset=list(config_keys))
    return runs_df
class WandbUtils:
    """Helpers for querying the W&B sweeps of one project.

    Example:
        wandb_api = WandbUtils(user='tabchen', project_name='pykt_iekt_pred')

    After construction ``self.sweep_dict`` maps each sweep name to its sweep id
    (a single string), e.g. ``{'my_sweep': 'mx2tvwfy'}``. Sweep names shared by
    more than one sweep are left out because they cannot be resolved uniquely.
    """

    def __init__(self, user, project_name) -> None:
        """Connect to W&B and index the sweeps of ``user/project_name``.

        Args:
            user (str): the W&B entity (user or team) owning the project.
            project_name (str): the W&B project name.
        """
        self.user = user
        self.project_name = project_name
        self._init_wandb()

    def _init_wandb(self):
        """Create the W&B API client, fetch the project and build ``self.sweep_dict``."""
        self.api = wandb.Api()
        # Fetch the project under the same entity that every other call in this
        # class uses (f"{self.user}/{self.project_name}/..."), instead of the
        # API's default entity.
        self.project = self.api.project(name=self.project_name, entity=self.user)
        self.sweep_dict = self.get_sweep_dict()
        print(f"self.sweep_dict is {self.sweep_dict}")

    def get_sweep_dict(self):
        """Map every uniquely named sweep of the project to its sweep id.

        Returns:
            dict: ``{sweep_name: sweep_id}``. Names used by several sweeps are
            reported with a printed error and excluded.
        """
        grouped = {}
        for sweep in self.project.sweeps():
            grouped.setdefault(sweep.name, []).append(sweep.id)
        # Build a fresh dict instead of deleting from the one being iterated,
        # which would raise "dictionary changed size during iteration".
        sweep_dict = {}
        for name, ids in grouped.items():
            if len(ids) > 1:
                print(f"Error!! we can not process the same sweep name {name}, we will not return those sweeps:{ids}")
            else:
                sweep_dict[name] = ids[0]
        return sweep_dict

    def _get_sweep_id(self, id, input_type):
        """Resolve ``id`` to a sweep id.

        Args:
            id (str): a sweep name (when ``input_type == "sweep_name"``) or a sweep id.
            input_type (str): "sweep_name" to look ``id`` up in ``self.sweep_dict``;
                any other value returns ``id`` unchanged.

        Raises:
            KeyError: if ``id`` is an unknown (or ambiguous, hence excluded) sweep name.
        """
        if input_type != "sweep_name":
            return id
        try:
            return self.sweep_dict[id]
        except KeyError:
            raise KeyError(
                f"Unknown sweep name {id!r} in {self.user}/{self.project_name} "
                "(names shared by several sweeps are excluded)"
            ) from None

    def get_df(self, id, input_type="sweep_name"):
        """Get the results of one sweep.

        Args:
            id (str): the sweep name or sweep id.
            input_type (str, optional): the type of id. Defaults to "sweep_name".

        Returns:
            pd.DataFrame: one row per run (as built by ``get_runs_result``) plus an
            integer ``run_index`` column parsed from the text after the last "-"
            in the run name.
        """
        sweep_id = self._get_sweep_id(id, input_type)
        sweep = self.api.sweep(f"{self.user}/{self.project_name}/{sweep_id}")
        df = get_runs_result(sweep.runs)
        if "Name" not in df.columns:
            # A sweep without runs yields a frame with no columns; still expose
            # the run_index column instead of failing with a KeyError.
            df["run_index"] = pd.Series(dtype="int64")
            return df
        # Run names are created in order; the trailing "-<n>" counter gives that order.
        df["run_index"] = df["Name"].apply(lambda a: int(a.split("-")[-1]))
        return df

    def get_multi_df(self, id_list=None, input_type="sweep_name"):
        """Get the results of several sweeps.

        Args:
            id_list (list, optional): sweep names or sweep ids. Defaults to an empty list.
            input_type (str, optional): the type of the ids. Defaults to "sweep_name".

        Returns:
            list: one DataFrame per id (see ``get_df``), each with an extra column
            named after ``input_type`` holding the id it came from.
        """
        if id_list is None:
            id_list = []
        df_list = []
        for id in id_list:
            df = self.get_df(id, input_type=input_type)
            df[input_type] = id
            df_list.append(df)
        return df_list

    def get_sweep_status(self, id, input_type="sweep_name"):
        """Get the state of a sweep via the W&B GraphQL API.

        Args:
            id (str): the sweep name or sweep id.
            input_type (str, optional): the type of id. Defaults to "sweep_name".

        Returns:
            str: the ``state`` field of the sweep, e.g. 'RUNNING', 'CANCELED' or 'FINISHED'.
        """
        query = gql(
            """query Sweep($project: String, $entity: String, $name: String!) {
                project(name: $project, entityName: $entity) {
                    sweep(sweepName: $name) {
                        id
                        name
                        bestLoss
                        config
                        state
                    }
                },
            }
            """)
        sweep_id = self._get_sweep_id(id, input_type)
        variables = {
            "entity": self.user,
            "project": self.project_name,
            "name": sweep_id}
        status = self.project.client.execute(query, variable_values=variables)['project']['sweep']['state']
        return status

    def get_sweep_run_num(self, id, input_type="sweep_name"):
        """Get the number of runs of a sweep.

        Args:
            id (str): the sweep name or sweep id.
            input_type (str, optional): the type of id. Defaults to "sweep_name".

        Returns:
            int: the number of runs in the sweep.
        """
        sweep_id = self._get_sweep_id(id, input_type)
        sweep = self.api.sweep(f"{self.user}/{self.project_name}/{sweep_id}")
        return len(sweep.runs)

    def check_sweep_early_stop(self, id, input_type="sweep_name", metric="testauc", metric_type="max", min_run_num=300, patience=100):
        """Check whether a sweep should be stopped early.

        Args:
            id (str): the sweep name or sweep id.
            input_type (str, optional): the type of id. Defaults to "sweep_name".
            metric (str, optional): the result column to check. Defaults to "testauc".
            metric_type (str, optional): "max" if larger is better, otherwise
                smaller is better. Defaults to "max".
            min_run_num (int, optional): runs required before the check is done. Defaults to 300.
            patience (int, optional): allowed number of runs, counted from the first
                best run onward (that run included), before stopping. Defaults to 100.

        Returns:
            dict: a report with the keys
                - "state" (bool): True if the sweep is already CANCELED/FINISHED or
                  should be stopped, False otherwise.
                - "num_run" (int): number of runs, or -1 when the sweep is already
                  CANCELED/FINISHED (it is not queried further, to save time).
                - "stop_cmd" (str): the ``wandb sweep ... --stop`` command when the
                  sweep should be stopped, else "".
                - "df" (pd.DataFrame): the sweep results; present only when at
                  least ``min_run_num`` runs exist.
        """
        print(f"Start check {id}")
        sweep_id = self._get_sweep_id(id, input_type)
        sweep_status = self.get_sweep_status(sweep_id, input_type="sweep_id")
        report = {"stop_cmd": ""}
        if sweep_status in ['CANCELED', 'FINISHED']:
            report['state'] = True
            report['num_run'] = -1
        else:
            num_run = self.get_sweep_run_num(sweep_id, input_type="sweep_id")
            report['num_run'] = num_run
            if num_run < min_run_num:
                report['state'] = False
            else:
                df = self.get_df(sweep_id, input_type="sweep_id")
                report['df'] = df
                best_value = df[metric].max() if metric_type == "max" else df[metric].min()
                first_best_index = df[df[metric] == best_value]['run_index'].min()
                # Runs from the first best run onward (it included): if more than
                # `patience` of them brought no improvement, stop the sweep.
                not_improve_num = len(df[df['run_index'] >= first_best_index])
                if not_improve_num > patience:
                    stop_cmd = f"wandb sweep {self.user}/{self.project_name}/{sweep_id} --stop"
                    print(f" Run `{stop_cmd}` to stop the sweep.")
                    report['state'] = True
                    report['stop_cmd'] = stop_cmd
                else:
                    report['state'] = False
        print(f" details: {id} state is {report['state']},num of runs is {report['num_run']}")
        print("-"*60+'\n')
        return report

    def check_sweep_by_pattern(self, sweep_pattern, metric="testauc", metric_type="max", min_run_num=300, patience=100):
        """Run ``check_sweep_early_stop`` on every sweep whose name matches a pattern.

        Args:
            sweep_pattern (str): check sweeps whose names start with this prefix;
                the special value 'all' checks every sweep in ``self.sweep_dict``.
            metric (str, optional): the metric to check. Defaults to "testauc".
            metric_type (str, optional): "max" or "min". Defaults to "max".
            min_run_num (int, optional): runs required before checking. Defaults to 300.
            patience (int, optional): the patience to stop. Defaults to 100.

        Returns:
            list: one report per matching sweep, as returned by
            ``check_sweep_early_stop`` with an extra "sweep_name" key.
        """
        check_result_list = []
        for sweep_name in self.sweep_dict:
            if sweep_name.startswith(sweep_pattern) or sweep_pattern == 'all':
                check_result = self.check_sweep_early_stop(sweep_name, input_type='sweep_name',
                                                           metric=metric, metric_type=metric_type,
                                                           min_run_num=min_run_num, patience=patience)
                check_result['sweep_name'] = sweep_name
                check_result_list.append(check_result)
        return check_result_list