train_local.py

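"""Local training entry point.

Fine-tunes a graph + LLM model from src.model on a dataset from src.dataset,
loading the Llama-2 weights from a local path (Hugging Face offline mode) and
logging metrics to a locally hosted Weights & Biases server. The checkpoint with
the lowest validation loss is reloaded for the final test-set evaluation.
"""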
import os

# Restrict the visible GPU devices (added)
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"  # only use devices 0 and 1

# Use the Hugging Face mirror endpoint (must be set before importing transformers)
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'

# Force offline mode so the model is loaded from local files without network access
os.environ['TRANSFORMERS_OFFLINE'] = '1'
os.environ['HF_DATASETS_OFFLINE'] = '1'
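
# Note: with TRANSFORMERS_OFFLINE / HF_DATASETS_OFFLINE set, transformers and datasets
# make no network requests, so the mirror endpoint above is effectively unused at runtime;
# every model and dataset file must already exist in the local cache or at a local path,
# otherwise loading will raise an error.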

# Added: point WandB at a locally hosted server (optional but recommended)
os.environ['WANDB_BASE_URL'] = 'http://117.72.35.222:8081'
# Set the WandB API key (replace with your actual local WandB API key)
os.environ['WANDB_API_KEY'] = 'local-84f1b57b7ab8df20a8881450fcfe22acb87bf9e2'

import wandb
import gc
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_

from src.model import load_model, llama_model_path
from src.dataset import load_dataset
from src.utils.evaluate import eval_funcs
from src.utils.config import parse_args_llama
from src.utils.ckpt import _save_checkpoint, _reload_best_model
from src.utils.collate import collate_fn
from src.utils.seed import seed_everything
from src.utils.lr_schedule import adjust_learning_rate

# Initialize CUDA state up front (added)
if torch.cuda.is_available():
    torch.cuda.init()  # make sure CUDA is initialized correctly


def main(args):
    # Print information about the available devices (added)
    print(f"Number of visible GPU devices: {torch.cuda.device_count()}")
    print(f"Current device: {torch.cuda.current_device()}")

    # Step 1: Set up wandb
    seed = args.seed
    wandb.init(project=f"{args.project}",
               name=f"{args.dataset}_{args.model_name}_seed{seed}",
               config=args)

    seed_everything(seed=args.seed)
    print(args)

    dataset = load_dataset[args.dataset]()
    idx_split = dataset.get_idx_split()

    # Step 2: Build Dataset
    train_dataset = [dataset[i] for i in idx_split['train']]
    val_dataset = [dataset[i] for i in idx_split['val']]
    test_dataset = [dataset[i] for i in idx_split['test']]

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, drop_last=True, pin_memory=True,
                              shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, drop_last=False, pin_memory=True,
                            shuffle=False, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, drop_last=False, pin_memory=True,
                             shuffle=False, collate_fn=collate_fn)
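
    # Note: drop_last=True on the train loader means len(train_loader) counts only full
    # batches, which is what num_training_steps and the epoch-mean loss below rely on.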

    # Use the local model path instead of downloading from the Hub
    local_model_path = args.local_model_path
    print(f"Using local model path: {local_model_path}")
    print("Make sure this path contains the full set of model files: config.json, pytorch_model.bin, etc.")

    # Point the model loader at the local weights
    args.llm_model_path = local_model_path

    # Step 3: Build Model
    # Added device placement option to reduce per-GPU memory pressure
    model = load_model[args.model_name](
        graph_type=dataset.graph_type,
        args=args,
        init_prompt=dataset.prompt,
        device_map="auto"  # assign devices automatically
    )
    # Print which device the model is on (added)
    print(f"Model loaded on device: {model.device}")
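
    # Assumption: load_model forwards extra keyword arguments such as device_map to the
    # underlying Hugging Face from_pretrained call; if the loader in src.model does not
    # accept extra kwargs, drop the device_map argument above.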

    # Step 4. Set Optimizer
    params = [p for _, p in model.named_parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(
        [{'params': params, 'lr': args.lr, 'weight_decay': args.wd}, ],
        betas=(0.9, 0.95)
    )

    trainable_params, all_param = model.print_trainable_params()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}")

    # Step 5. Training
    num_training_steps = args.num_epochs * len(train_loader)
    progress_bar = tqdm(range(num_training_steps))
    best_val_loss = float('inf')
    best_epoch = 0

    for epoch in range(args.num_epochs):
        # Optionally clear the GPU memory cache at the start of each epoch (added, currently disabled)
        # torch.cuda.empty_cache()
        model.train()
        epoch_loss, accum_loss = 0., 0.

        for step, batch in enumerate(train_loader):
            # Make sure the batch tensors are on the model's device (added)
            if hasattr(model, 'device'):
                batch = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

            optimizer.zero_grad()
            loss = model(batch)
            loss.backward()

            clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1)

            if (step + 1) % args.grad_steps == 0:
                adjust_learning_rate(optimizer.param_groups[0], args.lr, step / len(train_loader) + epoch, args)

            optimizer.step()
            epoch_loss, accum_loss = epoch_loss + loss.item(), accum_loss + loss.item()

            if (step + 1) % args.grad_steps == 0:
                lr = optimizer.param_groups[0]["lr"]
                wandb.log({'Lr': lr})
                wandb.log({'Accum Loss': accum_loss / args.grad_steps})
                accum_loss = 0.

            progress_bar.update(1)
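
        # Note: optimizer.zero_grad() and optimizer.step() run on every iteration above, so
        # args.grad_steps only controls how often the learning rate is adjusted and the
        # accumulated loss is logged; it does not perform true gradient accumulation.
        # A minimal sketch of actual accumulation (an alternative, not what this script does):
        #   (loss / args.grad_steps).backward()
        #   if (step + 1) % args.grad_steps == 0:
        #       optimizer.step()
        #       optimizer.zero_grad()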

        print(f"Epoch: {epoch}|{args.num_epochs}: Train Loss (Epoch Mean): {epoch_loss / len(train_loader)}")
        wandb.log({'Train Loss (Epoch Mean)': epoch_loss / len(train_loader)})

        val_loss = 0.
        eval_output = []
        model.eval()
        with torch.no_grad():
            for step, batch in enumerate(val_loader):
                # Make sure the batch tensors are on the model's device (added)
                if hasattr(model, 'device'):
                    batch = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
                loss = model(batch)
                val_loss += loss.item()
            val_loss = val_loss / len(val_loader)
        print(f"Epoch: {epoch}|{args.num_epochs}: Val Loss: {val_loss}")
        wandb.log({'Val Loss': val_loss})

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            _save_checkpoint(model, optimizer, epoch, args, is_best=True)
            best_epoch = epoch

        print(f'Epoch {epoch} Val Loss {val_loss} Best Val Loss {best_val_loss} Best Epoch {best_epoch}')

        if epoch - best_epoch >= args.patience:
            print(f'Early stop at epoch {epoch}')
            break

    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()

    # Step 6. Evaluating
    model = _reload_best_model(model, args)
    model.eval()
    eval_output = []
    progress_bar_test = tqdm(range(len(test_loader)))
    for step, batch in enumerate(test_loader):
        with torch.no_grad():
            # Make sure the batch tensors are on the model's device (added)
            if hasattr(model, 'device'):
                batch = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
            output = model.inference(batch)
            eval_output.append(output)

        progress_bar_test.update(1)

    # Step 7. Post-processing & compute metrics
    os.makedirs(f'{args.output_dir}/{args.dataset}', exist_ok=True)
    path = f'{args.output_dir}/{args.dataset}/model_name_{args.model_name}_llm_model_name_{args.llm_model_name}_llm_frozen_{args.llm_frozen}_max_txt_len_{args.max_txt_len}_max_new_tokens_{args.max_new_tokens}_gnn_model_name_{args.gnn_model_name}_patience_{args.patience}_num_epochs_{args.num_epochs}_seed{seed}.csv'

    acc = eval_funcs[args.dataset](eval_output, path)
    print(f'Test Acc {acc}')
    wandb.log({'Test Acc': acc})


if __name__ == "__main__":
    # Make sure CUDA is initialized correctly (added)
    if torch.cuda.is_available():
        torch.cuda.init()

    args = parse_args_llama()

    # Add the local model path argument if the parser does not provide one
    if not hasattr(args, 'local_model_path'):
        args.local_model_path = "/data2/ycj/models/Llama-2-7b-chat-hf"  # default path

    main(args)

    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    gc.collect()
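
# Example invocation (a sketch; the exact flag names come from parse_args_llama, so the
# arguments below are assumptions rather than the script's documented CLI):
#   python train_local.py --dataset <dataset_name> --model_name <model_name> \
#       --batch_size 4 --eval_batch_size 16 --num_epochs 10 --patience 2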