import gc
import os

import torch
import wandb
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.dataset import load_dataset
from src.model import llama_model_path, load_model
from src.utils.collate import collate_fn
from src.utils.config import parse_args_llama
from src.utils.evaluate import eval_funcs
from src.utils.seed import seed_everything


def _results_path(args, seed):
    """Return the CSV path that evaluation results for this run are passed to.

    The file name encodes the model / LLM / GNN settings and the seed, so runs
    with different settings map to different files under
    ``{args.output_dir}/{args.dataset}``.
    """
    return (
        f'{args.output_dir}/{args.dataset}/'
        f'model_name_{args.model_name}'
        f'_llm_model_name_{args.llm_model_name}'
        f'_llm_frozen_{args.llm_frozen}'
        f'_max_txt_len_{args.max_txt_len}'
        f'_max_new_tokens_{args.max_new_tokens}'
        f'_gnn_model_name_{args.gnn_model_name}'
        f'_patience_{args.patience}'
        f'_num_epochs_{args.num_epochs}'
        f'_seed{seed}.csv'
    )


def _run_inference(model, loader):
    """Call ``model.inference`` on every batch of ``loader``; return the outputs in batch order."""
    model.eval()
    # Gradients are never needed here; one no_grad context around the whole
    # loop avoids building autograd graphs for any batch.
    with torch.no_grad():
        # Iterating the tqdm-wrapped loader drives the progress bar and closes
        # it automatically when the loader is exhausted.
        return [model.inference(batch) for batch in tqdm(loader, total=len(loader))]


def _release_gpu_memory():
    """Collect Python garbage, then release cached CUDA memory and reset peak stats."""
    # Collect first so tensors kept alive only by reference cycles are freed
    # before the CUDA caching allocator is asked to give memory back.
    gc.collect()
    # The CUDA stats calls can raise on hosts without CUDA, so guard them.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        # Replaces the deprecated torch.cuda.reset_max_memory_allocated(),
        # which now only forwards to this function.
        torch.cuda.reset_peak_memory_stats()


def main(args):
    """Build ``args.model_name`` and evaluate it on the test split of ``args.dataset``.

    Starts a wandb run, calls ``seed_everything``, runs ``model.inference`` over
    the test split, and scores the outputs with ``eval_funcs[args.dataset]``.
    The returned accuracy is printed and logged to wandb as ``'Test Acc'``.

    Side effects: sets ``args.llm_model_path``, creates the directory
    ``{args.output_dir}/{args.dataset}`` if needed, and finishes the wandb run.
    """
    # Step 1: Set up wandb and seeding
    seed = args.seed
    wandb.init(project=f"{args.project}",
               name=f"{args.dataset}_{args.model_name}_seed{seed}",
               config=args)
    seed_everything(seed=seed)
    print(args)

    dataset = load_dataset[args.dataset]()
    idx_split = dataset.get_idx_split()

    # Step 2: Build the test DataLoader. Fixed order and no dropped batches,
    # so every test example is passed to the model exactly once.
    test_dataset = [dataset[i] for i in idx_split['test']]
    test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size,
                             drop_last=False, pin_memory=True, shuffle=False,
                             collate_fn=collate_fn)

    # Step 3: Build model
    # NOTE(review): load_model presumably reads args.llm_model_path set here; confirm.
    args.llm_model_path = llama_model_path[args.llm_model_name]
    model = load_model[args.model_name](graph=dataset.graph,
                                        graph_type=dataset.graph_type,
                                        args=args)

    # Step 4: Inference
    eval_output = _run_inference(model, test_loader)

    # Step 5: Post-processing & evaluation
    os.makedirs(f'{args.output_dir}/{args.dataset}', exist_ok=True)
    path = _results_path(args, seed)
    # NOTE(review): eval_funcs presumably writes eval_output to the CSV at
    # `path` as well as returning the accuracy; confirm in src.utils.evaluate.
    acc = eval_funcs[args.dataset](eval_output, path)
    print(f'Test Acc {acc}')
    wandb.log({'Test Acc': acc})
    # Close the run explicitly rather than relying on interpreter-exit hooks.
    wandb.finish()


if __name__ == "__main__":
    args = parse_args_llama()
    main(args)
    _release_gpu_memory()