Joint SFT and DPO Training (Single GPU)
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load the model and tokenizer
model_name = "your_model_name"  # replace with your model name
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# SFT dataset
class SFTDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        # Concatenate prompt and response, then mask the prompt and padding
        # positions with -100 so the loss covers only the response tokens
        enc = tokenizer(item["input"] + item["output"], return_tensors="pt",
                        padding="max_length", truncation=True, max_length=512)
        input_ids = enc["input_ids"].squeeze(0)
        labels = input_ids.clone()
        prompt_len = len(tokenizer(item["input"])["input_ids"])
        labels[:prompt_len] = -100
        labels[input_ids == tokenizer.pad_token_id] = -100
        return {
            "input_ids": input_ids,
            "attention_mask": enc["attention_mask"].squeeze(0),
            "labels": labels,
        }
# DPO dataset (unpadded: the DPO dataloader below uses batch_size=1, and the
# loss function concatenates prompt and response, so mid-sequence padding
# would corrupt the inputs)
class DPODataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        prompt = tokenizer(item["prompt"], return_tensors="pt", truncation=True, max_length=512)
        chosen = tokenizer(item["chosen"], return_tensors="pt", truncation=True, max_length=512, add_special_tokens=False)
        rejected = tokenizer(item["rejected"], return_tensors="pt", truncation=True, max_length=512, add_special_tokens=False)
        return {
            "prompt": prompt["input_ids"].squeeze(0),
            "chosen": chosen["input_ids"].squeeze(0),
            "rejected": rejected["input_ids"].squeeze(0),
        }
# Load the SFT and DPO data
sft_data = [...]  # replace with your SFT data
dpo_data = [...]  # replace with your DPO data
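# A minimal, hypothetical example of the schema the dataset classes above
# expect (keys taken from their __getitem__ methods):
#   sft_data = [{"input": "Translate to French: Hello", "output": " Bonjour"}]
#   dpo_data = [{"prompt": "Write a greeting.", "chosen": "Hello!", "rejected": "Go away."}]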
sft_dataset = SFTDataset(sft_data)
dpo_dataset = DPODataset(dpo_data)
sft_dataloader = DataLoader(sft_dataset, batch_size=8, shuffle=True)
dpo_dataloader = DataLoader(dpo_dataset, batch_size=1, shuffle=True)
# Training hyperparameters
num_epochs = 3
learning_rate = 1e-5
sft_weight = 0.5  # weight on the SFT loss
dpo_weight = 0.5  # weight on the DPO loss
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Custom training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    num_steps = 0
    # zip stops at the shorter of the two dataloaders, so each step pairs
    # one SFT batch with one DPO batch
    for sft_batch, dpo_batch in zip(sft_dataloader, dpo_dataloader):
        # SFT step
        sft_input_ids = sft_batch["input_ids"].to(device)
        sft_attention_mask = sft_batch["attention_mask"].to(device)
        sft_labels = sft_batch["labels"].to(device)
        sft_outputs = model(input_ids=sft_input_ids, attention_mask=sft_attention_mask, labels=sft_labels)
        sft_loss = sft_outputs.loss
        # DPO step (compute_dpo_loss is defined in the "DPO Loss Computation" section below)
        dpo_prompt = dpo_batch["prompt"].to(device)
        dpo_chosen = dpo_batch["chosen"].to(device)
        dpo_rejected = dpo_batch["rejected"].to(device)
        dpo_loss = compute_dpo_loss(model, dpo_prompt, dpo_chosen, dpo_rejected)
        # Weighted joint loss
        loss = sft_weight * sft_loss + dpo_weight * dpo_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_steps += 1
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / num_steps}")
# Save the final model and tokenizer
model.save_pretrained("./final_model")
tokenizer.save_pretrained("./final_model")
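For reference, the loop above minimizes the weighted joint objective

$$\mathcal{L} = w_{\text{sft}}\,\mathcal{L}_{\text{SFT}} + w_{\text{dpo}}\,\mathcal{L}_{\text{DPO}}, \qquad w_{\text{sft}} = w_{\text{dpo}} = 0.5,$$

where $\mathcal{L}_{\text{SFT}}$ is the token-level cross-entropy and $\mathcal{L}_{\text{DPO}}$ is the preference loss defined in the "DPO Loss Computation" section below.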
Multi-GPU (DistributedDataParallel)
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, AutoModelForCausalLM
# Initialize the distributed environment
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

# Tear down the distributed environment
def cleanup():
    dist.destroy_process_group()
# SFT dataset
class SFTDataset(torch.utils.data.Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        # Concatenate prompt and response; mask prompt and padding positions
        # with -100 so the loss covers only the response tokens
        enc = self.tokenizer(item["input"] + item["output"], return_tensors="pt",
                             padding="max_length", truncation=True, max_length=512)
        input_ids = enc["input_ids"].squeeze(0)
        labels = input_ids.clone()
        prompt_len = len(self.tokenizer(item["input"])["input_ids"])
        labels[:prompt_len] = -100
        labels[input_ids == self.tokenizer.pad_token_id] = -100
        return {
            "input_ids": input_ids,
            "attention_mask": enc["attention_mask"].squeeze(0),
            "labels": labels,
        }
# DPO dataset (unpadded; see the single-GPU version for why padding is omitted)
class DPODataset(torch.utils.data.Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        prompt = self.tokenizer(item["prompt"], return_tensors="pt", truncation=True, max_length=512)
        chosen = self.tokenizer(item["chosen"], return_tensors="pt", truncation=True, max_length=512, add_special_tokens=False)
        rejected = self.tokenizer(item["rejected"], return_tensors="pt", truncation=True, max_length=512, add_special_tokens=False)
        return {
            "prompt": prompt["input_ids"].squeeze(0),
            "chosen": chosen["input_ids"].squeeze(0),
            "rejected": rejected["input_ids"].squeeze(0),
        }
# Training function (runs in each spawned process)
def train(rank, world_size):
    setup(rank, world_size)
    # Load the model and tokenizer
    model_name = "your_model_name"  # replace with your model name
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    # Move the model to this process's GPU
    device = torch.device(f"cuda:{rank}")
    model.to(device)
    # Wrap in DDP
    model = DDP(model, device_ids=[rank])
    # Load the SFT and DPO data
    sft_data = [...]  # replace with your SFT data
    dpo_data = [...]  # replace with your DPO data
    sft_dataset = SFTDataset(sft_data, tokenizer)
    dpo_dataset = DPODataset(dpo_data, tokenizer)
    # Shard the data across processes with DistributedSampler
    sft_sampler = DistributedSampler(sft_dataset, num_replicas=world_size, rank=rank)
    dpo_sampler = DistributedSampler(dpo_dataset, num_replicas=world_size, rank=rank)
    sft_dataloader = DataLoader(sft_dataset, batch_size=8, sampler=sft_sampler)
    dpo_dataloader = DataLoader(dpo_dataset, batch_size=1, sampler=dpo_sampler)
    # Training hyperparameters
    num_epochs = 3
    learning_rate = 1e-5
    sft_weight = 0.5  # weight on the SFT loss
    dpo_weight = 0.5  # weight on the DPO loss
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    # Custom training loop
    for epoch in range(num_epochs):
        # Re-seed the samplers so each epoch uses a different shuffle
        sft_sampler.set_epoch(epoch)
        dpo_sampler.set_epoch(epoch)
        model.train()
        total_loss = 0.0
        num_steps = 0
        for sft_batch, dpo_batch in zip(sft_dataloader, dpo_dataloader):
            # SFT step
            sft_input_ids = sft_batch["input_ids"].to(device)
            sft_attention_mask = sft_batch["attention_mask"].to(device)
            sft_labels = sft_batch["labels"].to(device)
            sft_outputs = model(input_ids=sft_input_ids, attention_mask=sft_attention_mask, labels=sft_labels)
            sft_loss = sft_outputs.loss
            # DPO step (compute_dpo_loss is defined in the "DPO Loss Computation" section below)
            dpo_prompt = dpo_batch["prompt"].to(device)
            dpo_chosen = dpo_batch["chosen"].to(device)
            dpo_rejected = dpo_batch["rejected"].to(device)
            dpo_loss = compute_dpo_loss(model, dpo_prompt, dpo_chosen, dpo_rejected)
            # Weighted joint loss
            loss = sft_weight * sft_loss + dpo_weight * dpo_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_steps += 1
        if rank == 0:
            print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / num_steps}")
    # Save the model (main process only; unwrap the DDP module first)
    if rank == 0:
        model.module.save_pretrained("./final_model")
        tokenizer.save_pretrained("./final_model")
    cleanup()
# Entry point
def main():
    world_size = torch.cuda.device_count()
    # spawn() passes each process's rank as the first argument to train()
    torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size, join=True)

if __name__ == "__main__":
    main()
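The script launches itself: spawn() starts one process per visible GPU. To run it (the filename is a placeholder), or to restrict which GPUs are used:

    python train_multi_gpu.py
    CUDA_VISIBLE_DEVICES=0,1 python train_multi_gpu.py  # use only two GPUs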
DPO Loss Computation
import torch
import torch.nn.functional as F

def sequence_log_prob(model, prompt, response):
    # Sum of the log-probabilities the model assigns to the response tokens,
    # conditioned on the prompt
    input_ids = torch.cat((prompt, response), dim=1)
    logits = model(input_ids=input_ids).logits
    # The logits at position t predict the token at position t + 1, so the
    # predictions for the response tokens start at index prompt_len - 1
    prompt_len = prompt.size(1)
    response_logits = logits[:, prompt_len - 1:-1, :]
    log_probs = F.log_softmax(response_logits, dim=-1)
    token_log_probs = log_probs.gather(dim=-1, index=response.unsqueeze(-1)).squeeze(-1)
    return token_log_probs.sum(dim=-1)

def compute_dpo_loss(model, prompt, chosen, rejected):
    chosen_logp = sequence_log_prob(model, prompt, chosen)
    rejected_logp = sequence_log_prob(model, prompt, rejected)
    # Preference probability p_c / (p_c + p_r), computed in log space for
    # numerical stability: -log(p_c / (p_c + p_r)) = -logsigmoid(log p_c - log p_r)
    loss = -F.logsigmoid(chosen_logp - rejected_logp).mean()
    return loss
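Note that compute_dpo_loss above is a simplified, reference-free preference loss. The canonical DPO objective (Rafailov et al., 2023) compares the policy against a frozen reference model and scales the margin by a temperature beta. A minimal sketch reusing sequence_log_prob, assuming ref_model is a frozen copy of the pre-DPO policy:

def compute_dpo_loss_with_ref(model, ref_model, prompt, chosen, rejected, beta=0.1):
    # Policy log-probabilities
    chosen_logp = sequence_log_prob(model, prompt, chosen)
    rejected_logp = sequence_log_prob(model, prompt, rejected)
    # Frozen reference log-probabilities (no gradients)
    with torch.no_grad():
        ref_chosen_logp = sequence_log_prob(ref_model, prompt, chosen)
        ref_rejected_logp = sequence_log_prob(ref_model, prompt, rejected)
    # -log sigmoid(beta * [(log pi_w - log ref_w) - (log pi_l - log ref_l)])
    logits = beta * ((chosen_logp - ref_chosen_logp) - (rejected_logp - ref_rejected_logp))
    return -F.logsigmoid(logits).mean()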
# Example usage
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "your_model_name"  # replace with your model name
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# Toy inputs (token IDs)
prompt = torch.tensor([[1, 2, 3]])    # replace with a real tokenized prompt
chosen = torch.tensor([[4, 5, 6]])    # replace with a real chosen response
rejected = torch.tensor([[7, 8, 9]])  # replace with a real rejected response
# Compute the DPO loss
loss = compute_dpo_loss(model, prompt, chosen, rejected)
print(f"DPO Loss: {loss.item()}")
SFT Loss Computation
import torch
import torch.nn.functional as F

def compute_sft_loss(model, input_ids, attention_mask, labels):
    # Forward pass without labels; the loss is computed manually below
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs.logits
    # Causal LM: the logits at position t predict the token at position t + 1,
    # so shift logits and labels before taking the cross-entropy
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,  # -100 marks prompt/padding positions to skip
    )
    return loss
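For reference, Hugging Face causal-LM heads apply the same shift internally when labels is supplied, so an equivalent built-in computation (using the same tensors as the example below) is:

loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss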
# Example usage
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "your_model_name"  # replace with your model name
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# Toy inputs (token IDs); labels are typically a copy of input_ids with
# prompt/padding positions set to -100
input_ids = torch.tensor([[1, 2, 3, 4, 5]])       # replace with real input_ids
attention_mask = torch.tensor([[1, 1, 1, 1, 1]])  # replace with a real attention_mask
labels = torch.tensor([[6, 7, 8, 9, 10]])         # replace with real labels
# Compute the SFT loss
loss = compute_sft_loss(model, input_ids, attention_mask, labels)
print(f"SFT Loss: {loss.item()}")
Reinforcement Learning
url = 'https://zhuanlan.zhihu.com/p/693582342'