inference.py

import os
import torch
import wandb
import gc
from tqdm import tqdm
from torch.utils.data import DataLoader

from src.utils.seed import seed_everything
from src.utils.config import parse_args_llama
from src.model import load_model, llama_model_path
from src.dataset import load_dataset
from src.utils.evaluate import eval_funcs
from src.utils.collate import collate_fn


def main(args):
    # Step 1: Set up wandb
    seed = args.seed
    wandb.init(project=f"{args.project}",
               name=f"{args.dataset}_{args.model_name}_seed{seed}",
               config=args)

    seed_everything(seed=seed)
    print(args)

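    # load_dataset maps dataset names to dataset constructors; the dataset
    # object supplies the index split used to pick out the test examples below.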
    dataset = load_dataset[args.dataset]()
    idx_split = dataset.get_idx_split()

    # Step 2: Build the test dataset and loader
    test_dataset = [dataset[i] for i in idx_split['test']]
    test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size,
                             drop_last=False, pin_memory=True, shuffle=False,
                             collate_fn=collate_fn)

    # Step 3: Build Model
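    # llama_model_path and load_model are lookup tables keyed by the LLM name
    # and model name from the parsed arguments; the selected constructor is
    # given the dataset's graph along with args.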
    args.llm_model_path = llama_model_path[args.llm_model_name]
    model = load_model[args.model_name](graph=dataset.graph,
                                        graph_type=dataset.graph_type,
                                        args=args)

    # Step 4: Run inference on the test set
    model.eval()
    eval_output = []
    progress_bar_test = tqdm(range(len(test_loader)))
    for batch in test_loader:
        with torch.no_grad():
            output = model.inference(batch)
            eval_output.append(output)
        progress_bar_test.update(1)

    # Step 5: Post-process and evaluate
    os.makedirs(f'{args.output_dir}/{args.dataset}', exist_ok=True)
    path = (f'{args.output_dir}/{args.dataset}/'
            f'model_name_{args.model_name}_llm_model_name_{args.llm_model_name}_'
            f'llm_frozen_{args.llm_frozen}_max_txt_len_{args.max_txt_len}_'
            f'max_new_tokens_{args.max_new_tokens}_gnn_model_name_{args.gnn_model_name}_'
            f'patience_{args.patience}_num_epochs_{args.num_epochs}_seed{seed}.csv')
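    # The dataset-specific evaluation function takes the collected outputs and
    # the results path (presumably writing predictions to that CSV) and
    # returns the test accuracy.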
    acc = eval_funcs[args.dataset](eval_output, path)
    print(f'Test Acc {acc}')
    wandb.log({'Test Acc': acc})


if __name__ == "__main__":
    args = parse_args_llama()
    main(args)

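    # Release cached GPU memory, reset CUDA peak-memory tracking, and trigger
    # Python garbage collection once inference has finished.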
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    gc.collect()