import os

import torch
import torch.nn as nn
# transformers.AdamW is deprecated/removed in recent versions; use torch.optim.AdamW
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import BertTokenizer
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
import openpyxl
from openpyxl import Workbook
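# XuanDataset and BertWithLoRA are assumed to be defined earlier in the article;
# they are not part of this listing. The minimal sketches below are NOT the
# author's code -- they are hypothetical stand-ins that only match how the two
# names are used in this script: the dataset yields
# (input_ids, attention_mask, token_type_ids, label) tuples, and the model takes
# those three tensors and returns classification logits. The LoRA adapters the
# author's BertWithLoRA presumably adds are omitted here.
from torch.utils.data import Dataset
from transformers import BertModel


class XuanDataset(Dataset):
    """Hypothetical sketch: tokenizes texts with a BERT tokenizer."""
    def __init__(self, texts, labels, max_length=128):
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.encodings = tokenizer(texts, truncation=True, padding='max_length',
                                   max_length=max_length, return_tensors='pt')
        self.labels = torch.tensor(labels)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return (self.encodings['input_ids'][idx],
                self.encodings['attention_mask'][idx],
                self.encodings['token_type_ids'][idx],
                self.labels[idx])


class BertWithLoRA(nn.Module):
    """Hypothetical sketch: plain BERT plus a linear head (LoRA omitted)."""
    def __init__(self, num_classes):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
        return self.classifier(outputs.pooler_output)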
def evaluate_model(model, dataloader, device):
    """Run the model over a dataloader and compute weighted classification metrics."""
    model.eval()
    preds = []
    true_labels = []

    for batch in dataloader:
        batch = tuple(t.to(device) for t in batch)
        token_ids, attn_masks, token_type_ids, labels = batch

        with torch.no_grad():
            logits = model(token_ids, attn_masks, token_type_ids)

        preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
        true_labels.extend(labels.cpu().numpy())

    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels, preds, average='weighted')
    acc = accuracy_score(true_labels, preds)

    # Record the validation results
    return {
        'accuracy': acc,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'preds': preds,
    }
# Save the model weights and the tokenizer
def save_model_and_tokenizer(model, tokenizer, dir_path):
    os.makedirs(dir_path, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(dir_path, 'model.pth'))
    tokenizer.save_pretrained(dir_path)
# Save checkpoints, dump per-epoch predictions, and print the metrics
def save_metrics(epoch, avg_train_loss, metrics, best_f1, model, val_dataloader,
                 device, tokenizer):
    # Save a training checkpoint every five epochs
    if epoch % 5 == 0:
        save_model_and_tokenizer(model, tokenizer, f'./model_weights/epoch_{epoch}')

    # Keep the best model seen so far (by weighted F1)
    if metrics['f1'] > best_f1:
        best_f1 = metrics['f1']
        save_model_and_tokenizer(model, tokenizer, './model_weights/best_model')

    # Reuse the predictions already computed by evaluate_model instead of running
    # a second validation pass; val_texts and val_labels are module-level globals.
    predictions = metrics['preds']
    val_results = pd.DataFrame({'text': val_texts, 'label': val_labels,
                                'prediction': predictions})
    val_results.to_csv(f'./val_results/epoch_{epoch}.csv', index=False)

    # Print the metrics for this epoch
    print(f"Epoch: {epoch}, Loss: {avg_train_loss:.4f}, "
          f"Accuracy: {metrics['accuracy']:.4f}, Precision: {metrics['precision']:.4f}, "
          f"Recall: {metrics['recall']:.4f}, F1: {metrics['f1']:.4f}")

    # Return the (possibly updated) best F1; the original reassigned best_f1
    # locally, so the caller never saw the update.
    return best_f1
df = pd.read_csv('./data_news/train_set.csv', sep='\t')
train_texts, val_texts, train_labels, val_labels = train_test_split(
    df['text'].tolist(), df['label'].astype(int).tolist(),
    test_size=0.2, random_state=42)  # fixed seed for a reproducible split
train_dataset = XuanDataset(train_texts, train_labels, max_length=128)
val_dataset = XuanDataset(val_texts, val_labels, max_length=128)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=8)
model = BertWithLoRA(num_classes=16)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
os.makedirs('./model_weights', exist_ok=True)
os.makedirs('./val_results', exist_ok=True)
excel_file = 'metrics1.xlsx'
if not os.path.exists(excel_file):
    wb = Workbook()
    ws = wb.active
    ws.title = 'Metrics'
    ws.append(['Epoch', 'Loss', 'Accuracy', 'Precision', 'Recall', 'F1'])
    wb.save(excel_file)
optimizer = AdamW(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()  # construct the loss once instead of per step
epochs = 100  # number of training epochs
best_f1 = 0.0
for epoch in range(1, epochs + 1):
    model.train()
    total_loss = 0

    for batch in train_dataloader:
        batch = tuple(t.to(device) for t in batch)
        token_ids, attn_masks, token_type_ids, labels = batch

        optimizer.zero_grad()
        logits = model(token_ids, attn_masks, token_type_ids)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    avg_train_loss = total_loss / len(train_dataloader)

    # Evaluate on the validation set
    metrics = evaluate_model(model, val_dataloader, device)

    # Append this epoch's metrics to the Excel log
    wb = openpyxl.load_workbook(excel_file)
    ws = wb['Metrics']
    ws.append([epoch, avg_train_loss, metrics['accuracy'], metrics['precision'],
               metrics['recall'], metrics['f1']])
    wb.save(excel_file)
    # Save checkpoints and print the metrics; capture the updated best F1
    best_f1 = save_metrics(epoch, avg_train_loss, metrics, best_f1, model,
                           val_dataloader, device, tokenizer)
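# To reuse the best checkpoint later, the saved state dict can be loaded back
# into a freshly constructed model. This is a sketch, not part of the original
# script; it assumes the same BertWithLoRA definition used during training, and
# the path is the one written by save_model_and_tokenizer above.
#
#   best_model = BertWithLoRA(num_classes=16)
#   best_model.load_state_dict(torch.load('./model_weights/best_model/model.pth',
#                                         map_location=device))
#   best_model.to(device).eval()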