Skip to content

如果要使用 transformers 和 accelerate 进行微调，应该怎么处理数据？ #114

@yingli-Claire

Description

@yingli-Claire

使用 Hugging Face 官方给出的代码并修改后，训练脚本如下：

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch_npu
from torch.utils.data import Dataset, DataLoader
import time
from data_prepare import CPMDataset

# Fine-tuning script: loads the checkpoint under models/cpm-bee-2b onto an NPU,
# runs one pass over the first 100 samples and prints the loss and wall-clock
# time of every optimisation step.

# Turn off JIT compile mode and clear the NPU cache before the model is loaded.
torch_npu.npu.set_compile_mode(jit_compile=False)
torch.npu.empty_cache()

trainset = CPMDataset("basic_task_finetune/bee_data/eval.jsonl")
# Keep only the first 100 samples.
# NOTE(review): this relies on CPMDataset.__getitem__ accepting a slice --
# confirm against data_prepare.
trainset = trainset[:100]
# NOTE(review): no collate_fn is passed, so the DataLoader's default collation
# is applied to every batch of 2 samples -- confirm that the resulting batch
# still has the structure tokenizer.prepare_for_finetune expects.
train_loader = DataLoader(trainset, batch_size=2)

model_path = "models/cpm-bee-2b"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to('npu')

# NOTE(review): Adam runs with its default hyperparameters here; if the loss
# becomes NaN, an explicit (smaller) learning rate is worth trying first.
optimizer = torch.optim.Adam(model.parameters())

# The training mode never changes inside the loop, so enable it once up front.
model.train()

# `step` instead of `iter` so the builtin iter() is not shadowed.
for step, data in enumerate(train_loader):
    step_start = time.perf_counter()

    optimizer.zero_grad()
    # Convert the raw batch into model inputs and move them to the model's device.
    input_encoded = tokenizer.prepare_for_finetune(data, max_length=1024).to(model.device)
    outputs = model(**input_encoded)
    loss = outputs.loss
    loss.backward()
    optimizer.step()

    step_time = time.perf_counter() - step_start

    print(f"Step {step}, Loss: {loss.item():.4f}, Time per step: {step_time:.4f} s")

运行后输出的 loss 为 NaN。

数据处理代码如下：

import json
from torch.utils.data import Dataset

class CPMDataset(Dataset):
    """In-memory dataset backed by a JSON Lines file.

    Every non-blank line of the file is parsed as one JSON value and kept
    unchanged as one sample, in file order.
    """

    def __init__(self, jsonl_file):
        """Load all samples from ``jsonl_file``.

        Args:
            jsonl_file: Path to a UTF-8 encoded JSON Lines file.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        with open(jsonl_file, 'r', encoding='utf-8') as file:
            # Blank or whitespace-only lines (e.g. a trailing newline at the
            # end of the file) carry no sample and would make json.loads
            # raise, so they are skipped.
            self.data = [json.loads(line) for line in file if line.strip()]

    def __len__(self) -> int:
        """Return the number of loaded samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at ``idx``.

        ``idx`` is forwarded to the underlying list, so a slice returns a
        plain list of samples.
        """
        return self.data[idx]

if __name__ == "__main__":
    # Manual smoke test: load 'eval.jsonl' from the current working directory.
    dataset = CPMDataset('eval.jsonl')

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions