Model Training

Joint SFT and DPO training, single GPU

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.utils.data import Dataset, DataLoader

# Load the model and tokenizer
model_name = "your_model_name"  # replace with your model name
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# Many causal LM tokenizers ship without a pad token; fall back to EOS so padding works
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Define the SFT dataset
class SFTDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        inputs = tokenizer(item["input"], return_tensors="pt", padding="max_length", truncation=True, max_length=512)
        labels = tokenizer(item["output"], return_tensors="pt", padding="max_length", truncation=True, max_length=512)
        return {
            "input_ids": inputs["input_ids"].squeeze(),
            "attention_mask": inputs["attention_mask"].squeeze(),
            "labels": labels["input_ids"].squeeze(),
        }
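
# Note: this dataset uses the tokenized target text directly as labels. A more typical SFT
# setup concatenates prompt and response into a single sequence and sets the prompt (and
# padding) positions in the labels to -100 so they are ignored by the loss.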

# Define the DPO dataset
class DPODataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        prompt = tokenizer(item["prompt"], return_tensors="pt", padding="max_length", truncation=True, max_length=512)
        chosen = tokenizer(item["chosen"], return_tensors="pt", padding="max_length", truncation=True, max_length=512)
        rejected = tokenizer(item["rejected"], return_tensors="pt", padding="max_length", truncation=True, max_length=512)
        return {
            "prompt": prompt["input_ids"].squeeze(),
            "chosen": chosen["input_ids"].squeeze(),
            "rejected": rejected["input_ids"].squeeze(),
        }

# Load the SFT and DPO data
sft_data = [...]  # replace with your SFT data
dpo_data = [...]  # replace with your DPO data
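# Illustrative record layout (hypothetical example strings, replace with real data):
# sft_data = [{"input": "question ...", "output": "reference answer ..."}]
# dpo_data = [{"prompt": "question ...", "chosen": "preferred answer ...", "rejected": "worse answer ..."}]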

sft_dataset = SFTDataset(sft_data)
dpo_dataset = DPODataset(dpo_data)

sft_dataloader = DataLoader(sft_dataset, batch_size=8, shuffle=True)
dpo_dataloader = DataLoader(dpo_dataset, batch_size=1, shuffle=True)

# Training hyperparameters
num_epochs = 3
learning_rate = 1e-5
sft_weight = 0.5  # weight of the SFT loss
dpo_weight = 0.5  # weight of the DPO loss

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Custom training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    # Note: zip stops at the shorter of the two dataloaders, so any extra batches are skipped
    for sft_batch, dpo_batch in zip(sft_dataloader, dpo_dataloader):
        # SFT step
        sft_input_ids = sft_batch["input_ids"].to(device)
        sft_attention_mask = sft_batch["attention_mask"].to(device)
        sft_labels = sft_batch["labels"].to(device)

        sft_outputs = model(input_ids=sft_input_ids, attention_mask=sft_attention_mask, labels=sft_labels)
        sft_loss = sft_outputs.loss

        # DPO step (compute_dpo_loss is defined in the "DPO loss computation" section below)
        dpo_prompt = dpo_batch["prompt"].to(device)
        dpo_chosen = dpo_batch["chosen"].to(device)
        dpo_rejected = dpo_batch["rejected"].to(device)

        dpo_loss = compute_dpo_loss(model, dpo_prompt, dpo_chosen, dpo_rejected)

        # Weighted combination of the two losses
        loss = sft_weight * sft_loss + dpo_weight * dpo_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / len(sft_dataloader)}")

# Save the model
model.save_pretrained("./final_model")
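# Also save the tokenizer so the checkpoint directory can be loaded on its own
tokenizer.save_pretrained("./final_model")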

Multi-GPU

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, AutoModelForCausalLM

# Initialize the distributed environment
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

# Tear down the distributed environment
def cleanup():
    dist.destroy_process_group()

# Define the SFT dataset
class SFTDataset(torch.utils.data.Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        inputs = self.tokenizer(item["input"], return_tensors="pt", padding="max_length", truncation=True, max_length=512)
        labels = self.tokenizer(item["output"], return_tensors="pt", padding="max_length", truncation=True, max_length=512)
        return {
            "input_ids": inputs["input_ids"].squeeze(),
            "attention_mask": inputs["attention_mask"].squeeze(),
            "labels": labels["input_ids"].squeeze(),
        }

# Define the DPO dataset
class DPODataset(torch.utils.data.Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        prompt = self.tokenizer(item["prompt"], return_tensors="pt", padding="max_length", truncation=True, max_length=512)
        chosen = self.tokenizer(item["chosen"], return_tensors="pt", padding="max_length", truncation=True, max_length=512)
        rejected = self.tokenizer(item["rejected"], return_tensors="pt", padding="max_length", truncation=True, max_length=512)
        return {
            "prompt": prompt["input_ids"].squeeze(),
            "chosen": chosen["input_ids"].squeeze(),
            "rejected": rejected["input_ids"].squeeze(),
        }

# Training function
def train(rank, world_size):
    setup(rank, world_size)

    # Load the model and tokenizer
    model_name = "your_model_name"  # replace with your model name
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    # Fall back to EOS as the pad token if the tokenizer has none
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Move the model to this process's device
    device = torch.device(f"cuda:{rank}")
    model.to(device)

    # Wrap the model with DDP
    model = DDP(model, device_ids=[rank])

    # Load the SFT and DPO data
    sft_data = [...]  # replace with your SFT data
    dpo_data = [...]  # replace with your DPO data

    sft_dataset = SFTDataset(sft_data, tokenizer)
    dpo_dataset = DPODataset(dpo_data, tokenizer)

    # Use DistributedSampler so each process sees a distinct shard of the data
    sft_sampler = DistributedSampler(sft_dataset, num_replicas=world_size, rank=rank)
    dpo_sampler = DistributedSampler(dpo_dataset, num_replicas=world_size, rank=rank)

    sft_dataloader = DataLoader(sft_dataset, batch_size=8, sampler=sft_sampler)
    dpo_dataloader = DataLoader(dpo_dataset, batch_size=1, sampler=dpo_sampler)

    # Training hyperparameters
    num_epochs = 3
    learning_rate = 1e-5
    sft_weight = 0.5  # weight of the SFT loss
    dpo_weight = 0.5  # weight of the DPO loss

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    # Custom training loop
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        # Reshuffle each process's shard at the start of every epoch
        sft_sampler.set_epoch(epoch)
        dpo_sampler.set_epoch(epoch)
        # Note: zip stops at the shorter of the two dataloaders
        for sft_batch, dpo_batch in zip(sft_dataloader, dpo_dataloader):
            # SFT step
            sft_input_ids = sft_batch["input_ids"].to(device)
            sft_attention_mask = sft_batch["attention_mask"].to(device)
            sft_labels = sft_batch["labels"].to(device)

            sft_outputs = model(input_ids=sft_input_ids, attention_mask=sft_attention_mask, labels=sft_labels)
            sft_loss = sft_outputs.loss

            # DPO step (compute_dpo_loss is defined in the "DPO loss computation" section below)
            dpo_prompt = dpo_batch["prompt"].to(device)
            dpo_chosen = dpo_batch["chosen"].to(device)
            dpo_rejected = dpo_batch["rejected"].to(device)

            dpo_loss = compute_dpo_loss(model, dpo_prompt, dpo_chosen, dpo_rejected)

            # Weighted combination of the two losses
            loss = sft_weight * sft_loss + dpo_weight * dpo_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        if rank == 0:
            print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / len(sft_dataloader)}")

    # Save the model (only on the main process)
    if rank == 0:
        model.module.save_pretrained("./final_model")

    cleanup()

# Entry point
def main():
    world_size = torch.cuda.device_count()
    torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size, join=True)

if __name__ == "__main__":
    main()
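
Running this script directly starts one training process per visible GPU via torch.multiprocessing.spawn; no separate torchrun launcher is needed for this layout.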

DPO loss computation

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM

def compute_dpo_loss(model, prompt, chosen, rejected):
    # Concatenate the prompt with each response
    chosen_input = torch.cat((prompt, chosen), dim=1)
    rejected_input = torch.cat((prompt, rejected), dim=1)
    prompt_len = prompt.size(1)

    # Forward pass to get the logits
    chosen_logits = model(input_ids=chosen_input).logits
    rejected_logits = model(input_ids=rejected_input).logits

    # The logit at position t predicts the token at position t + 1,
    # so keep only the slice that scores the response tokens
    chosen_logits = chosen_logits[:, prompt_len - 1:-1, :]
    rejected_logits = rejected_logits[:, prompt_len - 1:-1, :]

    # Per-token log-probabilities of the actual response tokens
    chosen_logps = F.log_softmax(chosen_logits, dim=-1).gather(-1, chosen.unsqueeze(-1)).squeeze(-1)
    rejected_logps = F.log_softmax(rejected_logits, dim=-1).gather(-1, rejected.unsqueeze(-1)).squeeze(-1)

    # Sequence-level log-probabilities
    chosen_logp = chosen_logps.sum(dim=-1)
    rejected_logp = rejected_logps.sum(dim=-1)

    # Simplified preference loss: -log sigmoid(log p(chosen) - log p(rejected)),
    # i.e. the negative log-probability that the model prefers the chosen response.
    # (The full DPO objective additionally uses a frozen reference model and a beta scale.)
    loss = -F.logsigmoid(chosen_logp - rejected_logp).mean()

    return loss

# Example usage
model_name = "your_model_name"  # replace with your model name
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Example inputs (dummy token ids; replace with real tokenized text)
prompt = torch.tensor([[1, 2, 3]])    # replace with an actual prompt
chosen = torch.tensor([[4, 5, 6]])    # replace with an actual chosen response
rejected = torch.tensor([[7, 8, 9]])  # replace with an actual rejected response

# Compute the DPO loss
loss = compute_dpo_loss(model, prompt, chosen, rejected)
print(f"DPO Loss: {loss.item()}")

SFT loss computation

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM

def compute_sft_loss(model, input_ids, attention_mask, labels):
    # Forward pass to get the logits (the loss is computed manually below)
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs.logits

    # Shift so that the logit at position t is scored against the token at position t + 1
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Token-level cross entropy, ignoring positions labelled -100 (e.g. padding)
    loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
    )

    return loss

# Example usage
model_name = "your_model_name"  # replace with your model name
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Example inputs (dummy values; replace with real tokenized text)
input_ids = torch.tensor([[1, 2, 3, 4, 5]])       # replace with actual input_ids
attention_mask = torch.tensor([[1, 1, 1, 1, 1]])  # replace with the actual attention_mask
labels = torch.tensor([[6, 7, 8, 9, 10]])         # replace with actual labels

# Compute the SFT loss
loss = compute_sft_loss(model, input_ids, attention_mask, labels)
print(f"SFT Loss: {loss.item()}")

Reinforcement Learning

Reference: https://zhuanlan.zhihu.com/p/693582342