Hand-Coding Network Architectures in PyTorch

陈驰水

From gradient descent to MHA: the networks you are most often asked to hand-code in interviews.

Hand-coding gradient descent

A classic problem: use gradient descent to compute the square root of a number $y$. Since $$ \text{MSE} = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 $$ where $\hat{y}_i$ is the prediction for sample $i$ and $y_i$ is its ground truth, the task is simply to treat $x^2$ as the prediction that should fit the input $y$. Substituting into the MSE formula and differentiating with respect to $x$ gives $\nabla = 4x(x^2 - y)$.

y = 500
x = y / 2
lr = 0.001
epochs = 10000
for epoch in range(epochs):
    if abs(x * x - y) <= 1e-5: # stop once x^2 is close enough to y
        break
    d = 4 * x * (x * x - y)
    if abs(d) >= 1e3: # clip the gradient to prevent explosion
        d = 1e3 if d > 0 else -1e3
    x = x - lr * d # gradient descent update
    # print(x)
print(x)

# By the way, a hand-coded MSE is just np.mean((y_pred - y_true) ** 2)
22.0
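
The same fit can also be written with PyTorch autograd instead of the hand-derived gradient. A minimal sketch (hyperparameters mirror the loop above; the variable names are mine):

import torch

y = torch.tensor(500.0)
x = torch.tensor(y.item() / 2, requires_grad=True)  # same starting point as above
lr = 0.001

for epoch in range(10000):
    loss = (x * x - y) ** 2              # single-sample MSE
    if loss.item() <= 1e-10:             # same tolerance: |x^2 - y| <= 1e-5
        break
    loss.backward()                      # autograd gives d(loss)/dx = 4x(x^2 - y)
    with torch.no_grad():
        x.grad.clamp_(-1e3, 1e3)         # same gradient clipping as above
        x -= lr * x.grad
    x.grad.zero_()

print(x.item())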

Hand-coding linear regression

Compared with the problem above, the linear regression model is $\hat{y}_i = w \cdot x_i + b$, so there are two parameters to optimize. Taking gradients of the MSE gives

$$\nabla_b = \frac{\partial \text{MSE}}{\partial b} = \frac{1}{n} \sum_{i=1}^{n} -2(y_i - \hat{y}_i)$$

$$\nabla_w = \frac{\partial \text{MSE}}{\partial w} = \frac{1}{n} \sum_{i=1}^{n} -2x_i(y_i - \hat{y}_i)$$

# Construct the data: the target function is y = 2x + 1
x_data = [1.0, 2.0, 3.0, 4.0]
y_data = [3.0, 5.0, 7.0, 9.0]
n = len(x_data)
w = 0.0
b = 0.0
lr = 0.01
epochs = 2000
for epoch in range(epochs):
    w_grad = 0.0
    b_grad = 0.0
    for i in range(n):
        y_pred = w * x_data[i] + b
        error = y_data[i] - y_pred  # error = ground truth - prediction
        w_grad += -2 * error * x_data[i]
        b_grad += -2 * error * 1.0
    # average the gradients over the samples
    w_grad = w_grad / n
    b_grad = b_grad / n
    w = w - lr * w_grad
    b = b - lr * b_grad
    if epoch % 200 == 0:
        # evaluate the loss with MSE
        loss = sum([(y - (w*x + b))**2 for x, y in zip(x_data, y_data)]) / n
        print(f"Epoch {epoch}: Loss={loss:.4f}")

print(f"y = {w:.2f}x + {b:.2f}")
Epoch 0: Loss=28.4532
Epoch 200: Loss=0.0041
Epoch 400: Loss=0.0012
Epoch 600: Loss=0.0004
Epoch 800: Loss=0.0001
Epoch 1000: Loss=0.0000
Epoch 1200: Loss=0.0000
Epoch 1400: Loss=0.0000
Epoch 1600: Loss=0.0000
Epoch 1800: Loss=0.0000
y = 2.00x + 1.00

Cross-entropy

Binary cross-entropy: $$L = - \frac{1}{N} \sum_{i=1}^{N} \left[ y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i) \right]$$
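
The binary case can be hand-coded directly from this formula. A minimal NumPy sketch (the function and argument names are mine, not from the original post):

import numpy as np

def binary_cross_entropy(y_pred, y_true):
    # y_pred: shape (N,), predicted probabilities in (0, 1)
    # y_true: shape (N,), labels that are 0 or 1
    epsilon = 1e-12
    y_pred = np.clip(y_pred, epsilon, 1. - epsilon)  # avoid log(0)
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))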

Multi-class cross-entropy: $$L = - \frac{1}{N} \sum_{i=1}^{N} \sum_{c=1}^{C} y_{i,c} \log(\hat{y}_{i,c})$$

The inner part of the formula simply takes the log of the predicted probability at the position where the one-hot target is 1.

The outer part sums over the batch and divides by $N$. Note that the log of a value in $(0, 1)$ is negative, so the outer level needs a minus sign to make the loss positive.

"""
y_pred: shape (batch_size, num_classes) — softmax 概率
y_true: shape (batch_size, num_classes) — one-hot 编码标签
"""
import numpy as np
def cross_entropy(y_pred, y_true):
    epsilon = 1e-12 # 要对 log 做一个 clip 来防止过高或过低
    y_pred = np.clip(y_pred, epsilon, 1. - epsilon)
    losses = -np.sum(y_true * np.log(y_pred), axis=1)
    return np.mean(losses)

softmax

The original formula is $$\text{softmax}(z_i) = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}}$$

With a temperature term it becomes

$$P(y_i) = \frac{\exp(z_i / T)}{\sum_{j} \exp(z_j / T)}$$

T = 1 recovers the original formula; larger T makes the distribution smoother, smaller T makes it sharper (a temperature-scaled variant is sketched after the example below).

def softmax(logits):
    # logits: shape (batch_size, num_classes)
    logits_stable = logits - np.max(logits, axis=1, keepdims=True)  # subtract the row max to prevent numerical overflow
    exps = np.exp(logits_stable)
    return exps / np.sum(exps, axis=1, keepdims=True)
logits = np.array([
    [2.0, 1.0, 0.1],
    [1.0, 3.0, 0.1],
    [0.5, 0.2, 2.0]
])

y_true = np.array([
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1]
])

y_pred = softmax(logits)

loss = cross_entropy(y_pred, y_true)

print("Softmax 概率:\n", y_pred)
print("交叉熵损失:", loss)
Softmax 概率:
 [[0.65900114 0.24243297 0.09856589]
 [0.11369288 0.84008305 0.04622407]
 [0.16070692 0.11905462 0.72023846]]
Cross-entropy loss: 0.3064858227599003
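
Following the temperature formula above, only the logits need to be divided by T before the same numerically stable softmax. A minimal sketch (the function name and T argument are mine):

def softmax_with_temperature(logits, T=1.0):
    # T = 1 recovers the standard softmax; T > 1 flattens, T < 1 sharpens
    scaled = logits / T
    scaled = scaled - np.max(scaled, axis=1, keepdims=True)  # numerical stability
    exps = np.exp(scaled)
    return exps / np.sum(exps, axis=1, keepdims=True)

print(softmax_with_temperature(logits, T=5.0))  # smoother than y_pred above
print(softmax_with_temperature(logits, T=0.5))  # sharper than y_pred above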

MLP

The most basic multilayer perceptron:

import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, dim, expansion_ratio=4):
        super().__init__()
        hidden_dim = dim * expansion_ratio  
        self.linear1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.linear2 = nn.Linear(hidden_dim, dim)

    def forward(self, x):
        # x: (B, T, dim)
        x = self.linear1(x)
        x = self.act(x)
        return self.linear2(x)
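
A quick shape check (the dimensions are arbitrary, chosen only for the demo):

x = torch.rand(2, 16, 64)  # (B=2, T=16, dim=64)
mlp = MLP(dim=64)
print(mlp(x).shape)        # expected: torch.Size([2, 16, 64])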

MHA + decoder-only

The classic hand-coded multi-head attention:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left( \frac{Q K^\top}{\sqrt{d_k}} \right) V$$

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O$$

$$ \text{head}_i = \mathrm{Attention}(Q W_i^Q, K W_i^K, V W_i^V)$$

from math import sqrt
import torch
import torch.nn as nn

class MHA(nn.Module):
    def __init__(self, embed_dim, heads_nums):
        super().__init__()
        self.heads_nums = heads_nums
        self.head_dim = embed_dim // heads_nums
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.att_drop = nn.Dropout(0.1)

    def forward(self, x, mask=None):
        B, T, C = x.shape # batch_size, seq_len, input_dim
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.view(B, T, self.heads_nums, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.heads_nums, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.heads_nums, self.head_dim).transpose(1, 2)
        # Self-attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = torch.softmax(scores, dim=-1)
        attn = self.att_drop(attn)
        context = torch.matmul(attn, v)
        # merge the heads back into (B, T, C)
        context = context.transpose(1, 2)  # (B, T, num_heads, head_dim)
        context = context.reshape(B, T, C)
        return self.out_proj(context)

x = torch.rand(2, 3, 4)  # (batch=2, seq_len=3, input_dim=4)
attn = MHA(embed_dim=4, heads_nums=2)
y = attn(x)
print(y.shape) 
torch.Size([2, 3, 4])

class FFN(nn.Module):
    def __init__(self, embed_dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim)
        )

    def forward(self, x):
        return self.net(x)

class DecoderBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_hidden_dim):
        super().__init__()
        # pre-norm: each sub-layer gets its own LayerNorm
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.attn = MHA(embed_dim, num_heads)
        self.ff = FFN(embed_dim, ff_hidden_dim)

    def forward(self, x, mask=None):
        x = x + self.attn(self.ln1(x), mask)
        x = x + self.ff(self.ln2(x))
        return x

class DecoderTransformer(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_hidden_dim, num_layers, seq_len):
        super().__init__()
        # learnable positional encoding, one vector per position
        self.pos_emb = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        self.layers = nn.ModuleList([
            DecoderBlock(embed_dim, num_heads, ff_hidden_dim) for _ in range(num_layers)
        ])

    def forward(self, x):
        # x: (B, T, C)
        B, T, C = x.shape
        x = x + self.pos_emb[:, :T, :]  # add the positional encoding
        # causal mask: position t may only attend to positions <= t
        mask = torch.tril(torch.ones(T, T, device=x.device))
        for layer in self.layers:
            x = layer(x, mask)
        return x

x = torch.randn(2, 16, 64)  # batch_size=2, seq_len=16, embed_dim=64
model = DecoderTransformer(64, 4, 128, 2, 16)
softmax = nn.Softmax(dim=-1)
x = softmax(model(x))
print(x.shape)
print(x.sum(dim=-1))
print(x.argmax(dim=-1))
print(x)  # the full softmax-normalized output
torch.Size([2, 16, 64])
tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]],
       grad_fn=<SumBackward1>)
tensor([[18, 42, 23,  7,  4, 46, 55, 44, 50,  0, 21, 37, 10, 34, 58, 62],
        [29, 34, 25, 52, 33, 38, 26, 26,  4, 34, 35, 23,  7, 23, 42, 42]])
tensor([[[0.0111, 0.0078, 0.0045,  ..., 0.0258, 0.0063, 0.0017],
         [0.0158, 0.0293, 0.0179,  ..., 0.0096, 0.0078, 0.0041],
         [0.0093, 0.0107, 0.0251,  ..., 0.0029, 0.0207, 0.0067],
         ...,
         [0.0156, 0.0226, 0.0016,  ..., 0.0106, 0.0271, 0.0056],
         [0.0078, 0.0013, 0.0039,  ..., 0.0063, 0.0184, 0.0113],
         [0.0107, 0.0191, 0.0120,  ..., 0.0067, 0.1122, 0.0172]],

        [[0.0253, 0.0096, 0.0052,  ..., 0.0045, 0.0136, 0.0077],
         [0.0134, 0.0185, 0.0010,  ..., 0.0009, 0.0214, 0.0188],
         [0.0745, 0.0047, 0.0300,  ..., 0.0027, 0.0184, 0.0285],
         ...,
         [0.0304, 0.0221, 0.0018,  ..., 0.0050, 0.0666, 0.0119],
         [0.0214, 0.0179, 0.0366,  ..., 0.0189, 0.0061, 0.0135],
         [0.0031, 0.0099, 0.0039,  ..., 0.0037, 0.0037, 0.0322]]],
       grad_fn=<SoftmaxBackward0>)

MHA, MQA, GQA, MLA

MHA has already been written above (it is repeated below for comparison).

MQA splits only the Q projection into multiple heads while K and V keep a single shared head. GQA is the middle ground: the query heads are divided into groups, and each group shares one K/V head. MLA compresses K and V into a low-rank latent vector and then expands it back into heads, which shrinks the KV cache.

class MHA(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):
        B, T, C = x.shape # batch_size seq_len input_dim
        qkv = self.qkv(x)  # (B, T, 3 * C)
        q, k, v = qkv.chunk(3, dim=-1)  # each is (B, T, C)
        # split into heads: each becomes (B, num_heads, T, head_dim)
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        # Self-attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / sqrt(self.head_dim)
        if mask is not None:
            # mask shape: (B, 1, T, T) or (B, T, T) → broadcast to (B, H, T, T)
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn, v)
        # merge the heads back into (B, T, C)
        context = context.transpose(1, 2)  # (B, T, num_heads, head_dim)
        context = context.reshape(B, T, C)
        return self.out_proj(context)

class MQA(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv = nn.Linear(embed_dim, embed_dim + 2 * self.head_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):
        B, T, C = x.shape
        qkv = self.qkv(x)  # (B, T, embed_dim + 2 * head_dim)
        q, k, v = torch.split(qkv, [C, self.head_dim, self.head_dim], dim=-1)

        # Q becomes multi-head: (B, num_heads, T, head_dim)
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        # K, V stay single-head: (B, 1, T, head_dim)
        # the head dimension is kept at size 1 so it broadcasts against Q's num_heads
        k = k.view(B, T, 1, self.head_dim).transpose(1, 2)
        v = v.view(B, T, 1, self.head_dim).transpose(1, 2)

        # PyTorch automatically broadcasts the size-1 head dimension to H
        scores = torch.matmul(q, k.transpose(-2, -1)) / sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = torch.softmax(scores, dim=-1)

        # attn: (B, H, T, T)
        # v:    (B, 1, T, D) -> broadcast over heads
        context = torch.matmul(attn, v) # (B, H, T, D)
        context = context.transpose(1, 2).reshape(B, T, C)
        return self.out_proj(context)

class GQA(nn.Module):
    def __init__(self, embed_dim, num_heads, num_kv_heads):
        super().__init__()
        assert embed_dim % num_heads == 0
        assert num_heads % num_kv_heads == 0 # must divide evenly so every group has the same size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = embed_dim // num_heads
        self.group_size = num_heads // num_kv_heads

        # Q has num_heads heads; K/V have only num_kv_heads heads
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.kv_proj = nn.Linear(embed_dim, 2 * self.num_kv_heads * self.head_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):
        B, T, C = x.shape

        # compute Q, K, V
        q = self.q_proj(x) # (B, T, num_heads * head_dim)
        kv = self.kv_proj(x) # (B, T, 2 * num_kv_heads * head_dim)
        k, v = torch.split(kv, [self.num_kv_heads * self.head_dim, self.num_kv_heads * self.head_dim], dim=-1)

        # Q: (B, num_heads, T, head_dim)
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        # K, V: (B, num_kv_heads, T, head_dim)
        k = k.view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # GQA core: repeat K/V group_size times along dim=1 (the head dimension)
        # so their head count matches Q's: (B, num_kv_heads, ...) -> (B, num_heads, ...)
        k = k.repeat_interleave(self.group_size, dim=1)
        v = v.repeat_interleave(self.group_size, dim=1)
        scores = torch.matmul(q, k.transpose(-2, -1)) / sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = torch.softmax(scores, dim=-1)
        # context: (B, H, T, D)
        context = torch.matmul(attn, v)
        context = context.transpose(1, 2).reshape(B, T, C)
        return self.out_proj(context)

class MLA(nn.Module):
    def __init__(self, embed_dim, num_heads, kv_lora_rank):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.kv_lora_rank = kv_lora_rank # dimension of the compressed KV latent

        # Q keeps a standard projection (it could also be low-rank compressed; kept standard here for simplicity)
        self.q_proj = nn.Linear(embed_dim, embed_dim)

        # KV compression: down-project first, then up-project to produce the heads
        # 1. Down projection: compress into the latent c_KV
        self.kv_down_proj = nn.Linear(embed_dim, kv_lora_rank)
        self.kv_norm = nn.LayerNorm(kv_lora_rank) # MLA usually applies a norm in the latent space

        # 2. Up projection: generate multi-head K and V from the latent
        self.kv_up_proj = nn.Linear(kv_lora_rank, 2 * embed_dim) # produces num_heads * head_dim * 2

        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):
        B, T, C = x.shape

        # compute Q: (B, num_heads, T, head_dim)
        q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        # 1. Compress: produce the latent KV (B, T, kv_lora_rank)
        c_kv = self.kv_norm(self.kv_down_proj(x))

        # 2. Expand: produce multi-head K, V (B, T, 2 * embed_dim)
        # Note: at inference time the kv_up_proj matrix can be absorbed into the Q projection, so the
        # expansion never has to be materialized; this shows the training-time forward pass
        kv = self.kv_up_proj(c_kv)
        k, v = torch.split(kv, [C, C], dim=-1)

        # split K, V into heads: (B, num_heads, T, head_dim)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        # standard attention computation
        scores = torch.matmul(q, k.transpose(-2, -1)) / sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = torch.softmax(scores, dim=-1)

        context = torch.matmul(attn, v) # (B, H, T, D)
        context = context.transpose(1, 2).reshape(B, T, C)
        return self.out_proj(context)
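
A quick shape check for the three variants (the dimensions and rank are arbitrary, chosen only for the demo):

x = torch.rand(2, 16, 64)  # (B=2, T=16, C=64)
print(MQA(embed_dim=64, num_heads=8)(x).shape)                   # torch.Size([2, 16, 64])
print(GQA(embed_dim=64, num_heads=8, num_kv_heads=2)(x).shape)   # torch.Size([2, 16, 64])
print(MLA(embed_dim=64, num_heads=8, kv_lora_rank=16)(x).shape)  # torch.Size([2, 16, 64])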

Hand-coding LoRA

  1. Weight decomposition: $W_0$ is the frozen pretrained weight and $BA$ is the low-rank update $$W = W_0 + \Delta W = W_0 + B A \frac{\alpha}{r}$$

  2. Forward pass: original output + bypass output, where $x$ is the input, $r$ the rank, and $\alpha$ the scaling factor $$h = W_0 x + B (A x) \frac{\alpha}{r}$$

class LoRALinear(nn.Module):
    def __init__(self, in_features, out_features, rank=4, alpha=1.0):
        super().__init__()
        # 1. Frozen main path (pretrained weights)
        self.pretrained = nn.Linear(in_features, out_features)
        for param in self.pretrained.parameters():
            param.requires_grad = False  # freeze the pretrained weights
        # 2. LoRA bypass (adapter)
        # A: down-projection (in -> rank)
        self.lora_a = nn.Linear(in_features, rank, bias=False)
        # B: up-projection (rank -> out)
        self.lora_b = nn.Linear(rank, out_features, bias=False)

        # 3. Key initialization: B starts at zero, A is Gaussian
        nn.init.normal_(self.lora_a.weight, std=0.02)
        nn.init.zeros_(self.lora_b.weight)

        self.scaling = alpha / rank

    def forward(self, x):
        x_out = self.pretrained(x)
        lora_out = self.lora_b(self.lora_a(x))
        return x_out + lora_out * self.scaling
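
A quick check that only the adapter parameters remain trainable (the dimensions are arbitrary, chosen only for the demo):

layer = LoRALinear(in_features=64, out_features=64, rank=4, alpha=8.0)
x = torch.rand(2, 16, 64)
print(layer(x).shape)  # torch.Size([2, 16, 64]); with B initialized to zero this equals pretrained(x)
trainable = [name for name, p in layer.named_parameters() if p.requires_grad]
print(trainable)       # only lora_a.weight and lora_b.weight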
