import os

# Restrict the visible GPU devices (added)
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"  # use only devices 0 and 1

# Use the Hugging Face mirror (must be set before importing transformers)
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'

# Force offline mode so the model is loaded from local files without network access
os.environ['TRANSFORMERS_OFFLINE'] = '1'
os.environ['HF_DATASETS_OFFLINE'] = '1'

# New: point WandB at the local server (optional but recommended)
os.environ['WANDB_BASE_URL'] = 'http://117.72.35.222:8081'

# Set the WandB API key (replace with your actual local WandB API key)
os.environ['WANDB_API_KEY'] = 'local-84f1b57b7ab8df20a8881450fcfe22acb87bf9e2'

import wandb
import gc
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_

from src.model import load_model, llama_model_path
from src.dataset import load_dataset
from src.utils.evaluate import eval_funcs
from src.utils.config import parse_args_llama
from src.utils.ckpt import _save_checkpoint, _reload_best_model
from src.utils.collate import collate_fn
from src.utils.seed import seed_everything
from src.utils.lr_schedule import adjust_learning_rate

# Initialize CUDA state (added)
if torch.cuda.is_available():
    torch.cuda.init()  # make sure CUDA is initialized correctly


def main(args):
    # Print the available devices (added)
    print(f"Number of available GPUs: {torch.cuda.device_count()}")
    print(f"Current device: {torch.cuda.current_device()}")

    # Step 1: Set up wandb
    seed = args.seed
    wandb.init(project=f"{args.project}",
               name=f"{args.dataset}_{args.model_name}_seed{seed}",
               config=args)

    seed_everything(seed=args.seed)
    print(args)

    dataset = load_dataset[args.dataset]()
    idx_split = dataset.get_idx_split()

    # Step 2: Build Dataset
    train_dataset = [dataset[i] for i in idx_split['train']]
    val_dataset = [dataset[i] for i in idx_split['val']]
    test_dataset = [dataset[i] for i in idx_split['test']]

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, drop_last=True,
                              pin_memory=True, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, drop_last=False,
                            pin_memory=True, shuffle=False, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, drop_last=False,
                             pin_memory=True, shuffle=False, collate_fn=collate_fn)

    # Use the local model path
    local_model_path = args.local_model_path
    print(f"Using local model path: {local_model_path}")
    print("Make sure this path contains the complete model files: config.json, pytorch_model.bin, etc.")

    # Point the LLM loader at the local path
    args.llm_model_path = local_model_path

    # Step 3: Build Model
    # Memory-optimization option (modified): let the loader place the model automatically
    model = load_model[args.model_name](
        graph_type=dataset.graph_type,
        args=args,
        init_prompt=dataset.prompt,
        device_map="auto"  # assign devices automatically
    )

    # Print the device the model was loaded on (added)
    print(f"Model loaded on device: {model.device}")

    # Step 4: Set Optimizer
    params = [p for _, p in model.named_parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(
        [{'params': params, 'lr': args.lr, 'weight_decay': args.wd}],
        betas=(0.9, 0.95)
    )
    trainable_params, all_param = model.print_trainable_params()
    print(f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}")

    # Step 5: Training
    num_training_steps = args.num_epochs * len(train_loader)
    progress_bar = tqdm(range(num_training_steps))
    best_val_loss = float('inf')
    best_epoch = 0  # initialized so the early-stopping check below is always defined

    for epoch in range(args.num_epochs):
        # Optionally clear the CUDA cache between epochs (added, currently disabled)
        # torch.cuda.empty_cache()

        model.train()
        epoch_loss, accum_loss = 0., 0.
        for step, batch in enumerate(train_loader):
            # Make sure batch tensors are on the same device as the model (added)
            if hasattr(model, 'device'):
                batch = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

            optimizer.zero_grad()
            loss = model(batch)
            loss.backward()

            clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1)

            if (step + 1) % args.grad_steps == 0:
                adjust_learning_rate(optimizer.param_groups[0], args.lr, step / len(train_loader) + epoch, args)

            optimizer.step()
            epoch_loss, accum_loss = epoch_loss + loss.item(), accum_loss + loss.item()

            if (step + 1) % args.grad_steps == 0:
                lr = optimizer.param_groups[0]["lr"]
                wandb.log({'Lr': lr})
                wandb.log({'Accum Loss': accum_loss / args.grad_steps})
                accum_loss = 0.

            progress_bar.update(1)

        print(f"Epoch: {epoch}|{args.num_epochs}: Train Loss (Epoch Mean): {epoch_loss / len(train_loader)}")
        wandb.log({'Train Loss (Epoch Mean)': epoch_loss / len(train_loader)})

        val_loss = 0.
        eval_output = []
        model.eval()
        with torch.no_grad():
            for step, batch in enumerate(val_loader):
                # Make sure batch tensors are on the same device as the model (added)
                if hasattr(model, 'device'):
                    batch = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

                loss = model(batch)
                val_loss += loss.item()
            val_loss = val_loss / len(val_loader)
            print(f"Epoch: {epoch}|{args.num_epochs}: Val Loss: {val_loss}")
            wandb.log({'Val Loss': val_loss})

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            _save_checkpoint(model, optimizer, epoch, args, is_best=True)
            best_epoch = epoch

        print(f'Epoch {epoch} Val Loss {val_loss} Best Val Loss {best_val_loss} Best Epoch {best_epoch}')

        if epoch - best_epoch >= args.patience:
            print(f'Early stop at epoch {epoch}')
            break

    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()

    # Step 6: Evaluating
    model = _reload_best_model(model, args)
    model.eval()
    eval_output = []
    progress_bar_test = tqdm(range(len(test_loader)))
    for step, batch in enumerate(test_loader):
        with torch.no_grad():
            # Make sure batch tensors are on the same device as the model (added)
            if hasattr(model, 'device'):
                batch = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

            output = model.inference(batch)
            eval_output.append(output)

        progress_bar_test.update(1)

    # Step 7: Post-processing & compute metrics
    os.makedirs(f'{args.output_dir}/{args.dataset}', exist_ok=True)
    path = f'{args.output_dir}/{args.dataset}/model_name_{args.model_name}_llm_model_name_{args.llm_model_name}_llm_frozen_{args.llm_frozen}_max_txt_len_{args.max_txt_len}_max_new_tokens_{args.max_new_tokens}_gnn_model_name_{args.gnn_model_name}_patience_{args.patience}_num_epochs_{args.num_epochs}_seed{seed}.csv'
    acc = eval_funcs[args.dataset](eval_output, path)
    print(f'Test Acc {acc}')
    wandb.log({'Test Acc': acc})


if __name__ == "__main__":
    # Make sure CUDA is initialized correctly (added)
    if torch.cuda.is_available():
        torch.cuda.init()

    args = parse_args_llama()

    # Provide a local model path if parse_args_llama did not define one
    if not hasattr(args, 'local_model_path'):
        args.local_model_path = "/data2/ycj/models/Llama-2-7b-chat-hf"  # default path

    main(args)

    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    gc.collect()
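# --- Example launch (a sketch, not part of the original script) ---
# The flag names below are assumptions: they mirror the attribute names read from
# `args` in main() (dataset, model_name, batch_size, ...), but the actual CLI is
# defined by parse_args_llama in src/utils/config.py, so check there before running.
# Replace the script name and placeholder values with your own.
#
#   python train.py --dataset <dataset_name> --model_name <model_name> \
#       --batch_size 4 --num_epochs 10 --patience 2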