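"""Train/eval driver for a graph-augmented LLM.

The `src.*` modules (model, dataset, utils) are project-local. The flow below:
build dataset splits, construct the model, train with per-epoch validation and
early stopping, then reload the best checkpoint for test-set evaluation.
"""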
import os

# Use the hf-mirror.com mirror for Hugging Face downloads
# (must be set before transformers is imported).
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'

# Force offline mode so models load from the local cache instead of the
# network (the mirror above only matters if offline mode is later disabled).
os.environ['TRANSFORMERS_OFFLINE'] = '1'
os.environ['HF_DATASETS_OFFLINE'] = '1'

# Point WandB at a self-hosted server (optional but recommended).
os.environ['WANDB_BASE_URL'] = 'http://117.72.35.222:8081'
import gc

import torch
import wandb
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.model import load_model, llama_model_path
from src.dataset import load_dataset
from src.utils.evaluate import eval_funcs
from src.utils.config import parse_args_llama
from src.utils.ckpt import _save_checkpoint, _reload_best_model
from src.utils.collate import collate_fn
from src.utils.seed import seed_everything
from src.utils.lr_schedule import adjust_learning_rate
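
# `load_model` and `load_dataset` are registries (name -> constructor) defined
# in src/, which is why they are indexed rather than called directly below.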


def main(args):
    # Step 1: Set up wandb and seeding.
    seed = args.seed
    wandb.init(project=f"{args.project}",
               name=f"{args.dataset}_{args.model_name}_seed{seed}",
               config=args)

    seed_everything(seed=args.seed)
    print(args)

    dataset = load_dataset[args.dataset]()
    idx_split = dataset.get_idx_split()

    # Step 2: Build dataset splits and loaders.
    train_dataset = [dataset[i] for i in idx_split['train']]
    val_dataset = [dataset[i] for i in idx_split['val']]
    test_dataset = [dataset[i] for i in idx_split['test']]

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, drop_last=True,
                              pin_memory=True, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, drop_last=False,
                            pin_memory=True, shuffle=False, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, drop_last=False,
                             pin_memory=True, shuffle=False, collate_fn=collate_fn)
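    # drop_last=True on the train loader keeps every step a full batch, which
    # also keeps the grad_steps-based logging/LR cadence below uniform.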

    # Step 3: Build the model.
    args.llm_model_path = llama_model_path[args.llm_model_name]
    model = load_model[args.model_name](graph_type=dataset.graph_type, args=args, init_prompt=dataset.prompt)

    # Step 4: Set up the optimizer (only parameters with requires_grad are updated).
    params = [p for _, p in model.named_parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(
        [{'params': params, 'lr': args.lr, 'weight_decay': args.wd}],
        betas=(0.9, 0.95),
    )

    trainable_params, all_param = model.print_trainable_params()
    print(f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}")

    # Step 5: Training.
    num_training_steps = args.num_epochs * len(train_loader)
    progress_bar = tqdm(range(num_training_steps))
    best_val_loss = float('inf')
    best_epoch = 0

    for epoch in range(args.num_epochs):
        model.train()
        epoch_loss, accum_loss = 0., 0.

        for step, batch in enumerate(train_loader):
            optimizer.zero_grad()
            loss = model(batch)
            loss.backward()
            clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1)
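            # The third argument below is the fractional epoch
            # (step / len(train_loader) + epoch); the schedule itself lives in
            # src.utils.lr_schedule and is applied once per grad_steps batches.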
            if (step + 1) % args.grad_steps == 0:
                adjust_learning_rate(optimizer.param_groups[0], args.lr, step / len(train_loader) + epoch, args)

            optimizer.step()
            epoch_loss, accum_loss = epoch_loss + loss.item(), accum_loss + loss.item()

            if (step + 1) % args.grad_steps == 0:
                lr = optimizer.param_groups[0]["lr"]
                wandb.log({'Lr': lr})
                wandb.log({'Accum Loss': accum_loss / args.grad_steps})
                accum_loss = 0.

            progress_bar.update(1)

        print(f"Epoch: {epoch}|{args.num_epochs}: Train Loss (Epoch Mean): {epoch_loss / len(train_loader)}")
        wandb.log({'Train Loss (Epoch Mean)': epoch_loss / len(train_loader)})
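
        # Per-epoch validation: the model's forward pass returns the loss, so
        # this tracks the training objective directly (no text generation).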
        val_loss = 0.
        model.eval()
        with torch.no_grad():
            for step, batch in enumerate(val_loader):
                loss = model(batch)
                val_loss += loss.item()
            val_loss = val_loss / len(val_loader)
        print(f"Epoch: {epoch}|{args.num_epochs}: Val Loss: {val_loss}")
        wandb.log({'Val Loss': val_loss})

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            _save_checkpoint(model, optimizer, epoch, args, is_best=True)
            best_epoch = epoch

        print(f'Epoch {epoch} Val Loss {val_loss} Best Val Loss {best_val_loss} Best Epoch {best_epoch}')

        if epoch - best_epoch >= args.patience:
            print(f'Early stop at epoch {epoch}')
            break

    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()

    # Step 6: Evaluate the best checkpoint on the test split.
    model = _reload_best_model(model, args)
    model.eval()

    eval_output = []
    progress_bar_test = tqdm(range(len(test_loader)))
    for step, batch in enumerate(test_loader):
        with torch.no_grad():
            output = model.inference(batch)
            eval_output.append(output)
        progress_bar_test.update(1)

    # Step 7: Post-process predictions and compute metrics.
    os.makedirs(f'{args.output_dir}/{args.dataset}', exist_ok=True)
    path = (f'{args.output_dir}/{args.dataset}/model_name_{args.model_name}_llm_model_name_{args.llm_model_name}'
            f'_llm_frozen_{args.llm_frozen}_max_txt_len_{args.max_txt_len}_max_new_tokens_{args.max_new_tokens}'
            f'_gnn_model_name_{args.gnn_model_name}_patience_{args.patience}_num_epochs_{args.num_epochs}_seed{seed}.csv')
    acc = eval_funcs[args.dataset](eval_output, path)

    print(f'Test Acc {acc}')
    wandb.log({'Test Acc': acc})


if __name__ == "__main__":
    args = parse_args_llama()
    main(args)

    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    gc.collect()
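
# Hypothetical invocation (the real flag names are defined in
# src.utils.config.parse_args_llama; these are inferred from the attribute
# usage above):
#   python train.py --dataset <dataset> --model_name <model> \
#       --llm_model_name <llm> --batch_size 4 --eval_batch_size 16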