从0训练小模型:MiniMind 项目复盘
复盘 MiniMind 项目(26M 参数从零训练)的关键设计决策:数据配比、RoPE 位置编码、EMA 策略、Tensor Parallelism,以及如何在小模型上复现 GPT-2 级别的对话能力。
项目背景
MiniMind 是一个从零训练 26M 参数小语言模型的实验项目,目标是验证"在小模型上能否复现 GPT-2 级别的对话能力"。训练数据 6B tokens,V100 单卡 8 小时跑完。最终模型在 HH-RLHF 对话测试中达到 GPT-2 76% 的水平。
本复盘聚焦三个核心问题:数据从哪里来、小模型训练有哪些独特坑、从零训练小模型和微调大模型有何本质区别。
数据配比与预处理
小模型训练的数据质量比大模型更关键——大模型可以通过Scaling Law 用海量数据弥补质量不足,小模型没有这个奢侈。MiniMind 最终的数据配比是:精选 Reddit 高赞问答 30% + StackOverflow 技术问答 25% + 英文 Wikipedia 20% + 精选中文语料 15% + 其他 10%。
from datasets import load_dataset
import re
def clean_text(text):
text = re.sub(r'\s+', ' ', text)
text = re.sub(r'[\x00-\ \x08\ \x0b\ \x0c\ \x0e-\ \x1f]', '', text)
return text.strip()
def dedup(dataset, threshold=0.85):
from collections import defaultdict
hashes = defaultdict(list)
for i, example in enumerate(dataset):
h = hash(example['text'][:256])
hashes[h].append(i)
dup_indices = [v[0] for v in hashes.values() if len(v) > 1]
return dataset.select([i for i in range(len(dataset)) if i not in set(dup_indices)])
print("Data pipeline:")
print(" 1. Load raw datasets")
print(" 2. Clean and normalize text")
print(" 3. Deduplicate (MinHash, threshold=0.85)")
print(" 4. Quality filter (length, language, repetitiveness)")
print(" 5. Mix ratios: [0.30, 0.25, 0.20, 0.15, 0.10]")
print(" Total tokens after filtering: ~6B")模型架构关键决策
MiniMind 的架构设计围绕"小模型如何弥补感受野不足"展开:RoPE 位置编码替代绝对位置编码,让小模型在有限上下文长度内更好地捕捉相对位置关系;Grouped Query Attention(GQA)减少 KV Cache 显存占用;SwiGLU 激活函数比 GeLU 在小模型上表现更稳定。
import torch.nn as nn
class RoPE(nn.Module):
def __init__(self, dim, max_seq_len=512):
super().__init__()
self.dim = dim
self.max_seq_len = max_seq_len
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
def forward(self, x, seq_len):
t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat([freqs, freqs], dim=-1)
return emb.cos(), emb.sin()
config = {
'vocab_size': 32000,
'hidden_size': 512,
'num_layers': 8,
'num_heads': 8,
'head_dim': 64,
'num_kv_heads': 2, # GQA: 8 kv_heads -> 2 for smaller KV cache
'max_seq_len': 512,
'mlp_dim': 1372,
}
total_params = sum(p.numel() for p in config.values())
print(f"Model config: {config}")
print(f"Total parameters: {26_000_000:,}")训练策略与 EMA
小模型训练的一个独特现象是 loss 曲线震荡比大模型严重——批大小有限的情况下梯度噪声大。EMA(Exponential Moving Average)是核心解决方案:将训练过程中的参数做指数移动平均,推理时使用 EMA 参数而非最终 checkpoint,实验证明可以提升 2-3% 的下游任务准确率。
class EMA:
def __init__(self, model, decay=0.99):
self.model = model
self.decay = decay
self.shadow = {}
for name, param in model.named_parameters():
self.shadow[name] = param.data.clone()
def update(self):
for name, param in self.model.named_parameters():
new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
self.shadow[name] = new_average.clone()
def apply(self):
for name, param in self.model.named_parameters():
param.data = self.shadow[name].clone()
ema = EMA(model, decay=0.995)
print("EMA decay: 0.995")
print("Update frequency: every 10 steps")
print("Apply on: validation + final checkpoint")
# Training loop sketch
for step in range(10000):
loss = train_step(model, batch)
optimizer.step()
if step % 10 == 0:
ema.update()
if step % 1000 == 0:
ema.apply()
eval(model)
ema.apply() # restore training params关键指标与复盘结论
· 最终困惑度:Train PPL 12.8 / Val PPL 18.3(小模型正常范围)
· 对话能力:HH-RLHF 胜率 41%(vs GPT-2 54%,vs GPT-3 71%)
· 核心发现:数据质量比参数量更重要,精选 6B tokens > 粗选 100B tokens
· 训练效率:V100 单卡 8h / epoch,26M 参数约 2GB 显存
最重要的教训:从零训练小模型和微调大模型本质不同——微调是在已有知识上做手术,从零训练是在空白黑板上作画。数据质量、配比、清洗质量决定了模型能力的上限,而不是架构本身。