前言
最近在看 Qwen3-Next 的实现,发现它在一部分层里没有用 Full Attention,而是用了 Gated Delta Net,刚刚发布的 Qwen3.5 系列也沿用了这一设计。
更早一些的工作,例如 MiMo 的 混合 SWA,Jamba的混合 manba 模块,也是希望通过替换掉大部分 Full Attention 来提升长序列效率。
这个结构很有意思:
- 一方面,它继承了线性 RNN / 线性 Attention 的线性时间和常数状态推理优势;
- 另一方面,它把 “gating(遗忘)” 和 “delta rule(定向写入)” 合到了一起,改善了长上下文下的记忆管理和检索能力。
这篇笔记主要参考两份材料:
- 论文 Gated Delta Networks: Improving Mamba2 with Delta Rule(核心公式与训练算法)
- Transformers 里的
models/qwen3_next/modular_qwen3_next.py(Qwen3-Next 工程实现)
接下来,我会先讲解 Gated Delta Net 的核心算法(先理解推理,再训练并行化),然后分析它的计算复杂度,最后回顾它在各种线性 RNN 中的演化脉络。
Gated Delta Net 算法详解
一点废话
在正式开始之前,让我们把视角抬高一些,看看Transformer模型的结构。Transformer模型的结构,是一层attention,一层FFN如此交错,他们分别代表了2个不同的功能:
- Attention:负责跨位置的信息交互,也就是“上下文”,让模型能够捕捉序列中不同位置之间的依赖关系。
- FFN:负责位置内的信息变换,也就是“知识”,让模型能够对每个位置的特征进行非线性变换和增强。
FFN的部分在DeepSeek系列已经讲过了,通过MoE带来的稀疏化,在保持推理速度不变的约束下,大幅提升了模型的容量,提升模型的能力;包括近期出现的Deepseek ENGram,将知识做成外置,也是一个FFN扩容的探索。 Attention的部分,尤其是长序列的Attention,一直是效率的瓶颈。Gated Delta Net就是在这个背景下,基于古老RNN的一个设计,旨在通过一种更高效的方式来实现长序列的上下文交互,同时保持模型的表达能力。
这两部分都被替换之后,Transformer模型是否就变成了忒修斯之船?我觉得不完全是。虽然结构上发生了改变,但它的功能定位(上下文交互)依然存在。
Gated Delta Rule
先看核心公式,Gated Delta Rule:
$$ \mathbf{S}_t =\mathbf{S}_{t-1} \left( \alpha_t(\mathbf{I}-\beta_t \mathbf{k}_t \mathbf{k}_t^T) \right) + \beta_t \mathbf{v}_t \mathbf{k}_t^T $$其中:
- $\mathbf{S}_t \in \mathbb{R}^{d_v \times d_k}$:时刻 $t$ 的“快权重/记忆矩阵”
- $\mathbf{k}_t \in \mathbb{R}^{d_k}$,$\mathbf{v}_t \in \mathbb{R}^{d_v}$,$\mathbf{q}_t \in \mathbb{R}^{d_k}$
- $\beta_t \in (0,1)$:写入强度(学习率)
- $\alpha_t \in (0,1)$:遗忘门控(衰减)
- $\mathbf{I}$:单位矩阵
直觉上:
- $\alpha_t$ 控制“整体保留多少旧记忆”;
- $\beta_t$ 控制“针对当前 key 写入多少新 value”;
- $\mathbf{v}_t \mathbf{k}_t^T$ 是当前输入的增量写入项;
- 表示在 $\mathbf{k}_t$ 方向上写入 $\mathbf{v}_t$ 的信息;
- $(\mathbf{I}-\beta_t \mathbf{k}_t\mathbf{k}_t^T)$ 是带有“定向擦写”含义的变换。
- 定向:$\mathbf{k}_t\mathbf{k}_t^T$ 是张量积
- 擦写:从$\mathbf{I}$中减去,让 $\mathbf{S}_{t-1}$ 在 $\mathbf{k}_t$ 方向上的分量被衰减
- 衰减:乘以 $\alpha_t$ 进一步控制整体遗忘程度
输出一般写为:
$$ \mathbf{o}_t = \mathbf{S}_t \mathbf{q}_t \in \mathbb{R}^{d_v} $$代码精讲:GDN前的准备
输入与缓存状态判定
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
# 注:把 padding 位的状态清理掉,避免线性递推状态被无效 token 污染。
batch_size, seq_len, _ = hidden_states.shape
# 注:后续 reshape / decode 分支判断都会用到 seq_len。
use_precomputed_states = (
cache_params is not None
and cache_params.has_previous_state
and seq_len == 1
and cache_position is not None
)
# 注:只有 decode 单步(seq_len==1)且已有历史状态时,才走 recurrent 路径;
# 否则走 chunk 路径(prefill/训练)。
if cache_params is not None:
conv_state = cache_params.conv_states[self.layer_idx]
recurrent_state = cache_params.recurrent_states[self.layer_idx]
# 注:从动态缓存读取卷积状态与递推状态,供增量推理复用。
双路线性投影 + 拆分(重点看 z / b / a)
projected_states_qkvz = self.in_proj_qkvz(hidden_states)
projected_states_ba = self.in_proj_ba(hidden_states)
# 注:两路投影:
# - qkvz 路:产出 query/key/value/z
# - ba 路:产出 b/a(后续分别变成 beta 和 g 的输入)
query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba)
# 注:在 fix_query_key_value_ordering 内部会先按 head 分组,再 split:
# - qkvz 按 [q, k, v, z] 拆分
# - ba 按 [b, a] 拆分
拆分和形状修复
def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba):
"""
Derives `query`, `key` and `value` tensors from `mixed_qkvz` and `mixed_ba`.
"""
new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + (
self.num_k_heads,
2 * self.head_k_dim + 2 * self.head_v_dim * self.num_v_heads // self.num_k_heads,
)
new_tensor_shape_ba = mixed_ba.size()[:-1] + (self.num_k_heads, 2 * self.num_v_heads // self.num_k_heads)
mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz)
# mixed_qkvz
# -> view -> [B, L, num_k_heads, 2*head_k_dim + 2*head_v_dim*(num_v_heads/num_k_heads)]
mixed_ba = mixed_ba.view(*new_tensor_shape_ba)
# mixed_ba: [B, L, 2*num_v_heads]
# -> view -> [B, L, num_k_heads, 2*(num_v_heads/num_k_heads)]
split_arg_list_qkvz = [
self.head_k_dim,
self.head_k_dim,
(self.num_v_heads // self.num_k_heads * self.head_v_dim),
(self.num_v_heads // self.num_k_heads * self.head_v_dim),
]
split_arg_list_ba = [self.num_v_heads // self.num_k_heads, self.num_v_heads // self.num_k_heads]
query, key, value, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=3)
# 在 dim=3 上 split 成 q/k/v/z
# query: [B, L, num_k_heads, head_k_dim]
# key: [B, L, num_k_heads, head_k_dim]
# value: [B, L, num_k_heads, (num_v_heads/num_k_heads)*head_v_dim]
# z: [B, L, num_k_heads, (num_v_heads/num_k_heads)*head_v_dim]
b, a = torch.split(mixed_ba, split_arg_list_ba, dim=3)
# 在 dim=3 上 split 成 b/a
# b: [B, L, num_k_heads, (num_v_heads/num_k_heads)]
# a: [B, L, num_k_heads, (num_v_heads/num_k_heads)]
value = value.reshape(value.size(0), value.size(1), -1, self.head_v_dim)
# value: [B, L, num_k_heads, (num_v_heads/num_k_heads)*head_v_dim]
# -> reshape -> [B, L, num_v_heads, head_v_dim]
z = z.reshape(z.size(0), z.size(1), -1, self.head_v_dim)
# z: [B, L, num_k_heads, (num_v_heads/num_k_heads)*head_v_dim]
# -> reshape -> [B, L, num_v_heads, head_v_dim]
b = b.reshape(b.size(0), b.size(1), self.num_v_heads)
# b: [B, L, num_k_heads, (num_v_heads/num_k_heads)]
# -> reshape -> [B, L, num_v_heads]
a = a.reshape(a.size(0), a.size(1), self.num_v_heads)
# a: [B, L, num_k_heads, (num_v_heads/num_k_heads)]
# -> reshape -> [B, L, num_v_heads]
return query, key, value, z, b, a
QKV 因果卷积混合
query, key, value = (x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value))
# 注:先把头维摊平,便于拼接做一维深度卷积。
mixed_qkv = torch.cat((query, key, value), dim=-1)
mixed_qkv = mixed_qkv.transpose(1, 2)
# 注:卷积输入形状转成 [B, C, L],C=2*key_dim+value_dim。
if use_precomputed_states:
mixed_qkv = self.causal_conv1d_update(
mixed_qkv,
conv_state,
self.conv1d.weight.squeeze(1),
self.conv1d.bias,
self.activation,
)
# 注:decode 增量卷积,只更新最后一步,避免重复算历史。
else:
if cache_params is not None:
conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
## 这部分的参数格式是F.pad(input, (pad_left, pad_right));
## 如果大于kernel_size,pad_left就会是负数,表示裁剪掉前面多余的部分;
## 如果小于kernel_size,就会在前面补零。
cache_params.conv_states[self.layer_idx] = conv_state
if self.causal_conv1d_fn is not None:
mixed_qkv = self.causal_conv1d_fn(
x=mixed_qkv,
weight=self.conv1d.weight.squeeze(1),
bias=self.conv1d.bias,
activation=self.activation,
seq_idx=None,
)
else:
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
# 注:prefill/训练路径:
# - 有快实现就用 fused kernel;
# - 否则回退到 torch conv1d + silu;
# - 若开 cache,同步初始化 conv_state。
mixed_qkv = mixed_qkv.transpose(1, 2)
query, key, value = torch.split(mixed_qkv, [self.key_dim, self.key_dim, self.value_dim], dim=-1)
# 注:卷积后再拆回 Q/K/V 三路。
重排头维并构造 beta/g
query = query.reshape(query.shape[0], query.shape[1], -1, self.head_k_dim)
key = key.reshape(key.shape[0], key.shape[1], -1, self.head_k_dim)
value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim)
# 注:恢复为多头形状 [B, L, H, d]。
beta = b.sigmoid()
# 注:写入强度 beta∈(0,1)。
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
# 注:得到 log-space 的衰减参数;后续在 rule 内部 exp(g)->alpha∈(0,1)。
if self.num_v_heads // self.num_k_heads > 1:
query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
# 注:当 v-head 多于 k-head 时,把 K/Q 复制到对齐的头数,保证后续逐头计算可广播。
概括来说,这部分代码的主要功能是: 投影拆分 → 局部卷积混合 → 头维对齐 → 生成 beta/g。
变量名与公式符号对照(对应本小节准备阶段)
| 代码变量名 | 公式符号 | 典型形状(多头) | 含义 |
|---|---|---|---|
hidden_states | 输入序列(公式外) | [B, L, D] | 模块输入隐藏状态 |
query | $\mathbf{q}_t$ | [B, L, H, d_k] | 查询向量(每时刻/每头) |
key | $\mathbf{k}_t$ | [B, L, H, d_k] | 键向量(每时刻/每头) |
value | $\mathbf{v}_t$ | [B, L, H, d_v] | 值向量(每时刻/每头) |
b | $\beta_t$ 的参数化前体 | [B, L, H] | 经过 sigmoid 前的写入强度参数 |
beta = b.sigmoid() | $\beta_t \in (0,1)$ | [B, L, H] | 写入强度(学习率) |
a | $\alpha_t$ 的参数化前体 | [B, L, H] | 经过变换前的遗忘门控参数 |
g | $\log \alpha_t$(实现中) | [B, L, H] | 衰减的 log-space 参数,通常为负 |
g.exp() | $\alpha_t \in (0,1)$ | [B, L, H] | 真正用于递推的遗忘系数 |
recurrent_state / last_recurrent_state | $\mathbf{S}_{t-1}$ / $\mathbf{S}_t$ | [B, H, d_k, d_v] | 快权重记忆矩阵(递推状态) |
z | 输出门控(公式外) | [B, L, H, d_v] | 用于后续 RMSNormGated 的门控分支 |
注:文中公式使用 $\alpha_t$;代码里常先算
g,在 rule 内部通过exp(g)得到 $\alpha_t$。 注2:代码实现通过计算kv_mem = S_{t-1} k_t和delta = (v_t - kv_mem) * beta_t,巧妙地合并了公式中的矩阵乘法与减法,效率更高。
代码精讲:Gated Delta Rule 的两种实现
torch_recurrent_gated_delta_rule:单步递推
def torch_recurrent_gated_delta_rule(query, key, value, g, beta, initial_state, ...):
initial_dtype = query.dtype # 备份数据类型
if use_qk_l2norm_in_kernel:
query = l2norm(query, dim=-1, eps=1e-6)
key = l2norm(key, dim=-1, eps=1e-6)
# QK 归一化(可选),有助于稳定训练和推理。
query, key, value, beta, g = [
x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
]
# 这一步很重要,它把输入的形状从 [B, L, H, d] 转成 [B, H, L, d],并且转换为 float32 以提高数值稳定性。对于SSM等长序列模型,记忆状态会经历累加,这种精度是必要的。
batch_size, num_heads, sequence_length, k_head_dim = key.shape
v_head_dim = value.shape[-1]
scale = 1 / (query.shape[-1] ** 0.5)
query = query * scale
# 声明(创建)空变量存储输出和递推状态
core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to(value)
last_recurrent_state = (
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
if initial_state is None
else initial_state.to(value)
)
# sequence_length 一般为1,但也可以是小批量的多步(例如 decode 时一次更新多个 token)。
for i in range(sequence_length):
q_t = query[:, :, i] # q_t
k_t = key[:, :, i] # k_t
v_t = value[:, :, i] # v_t
g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1) # alpha_t
beta_t = beta[:, :, i].unsqueeze(-1) # beta_t
# 对应公式: S_{t-1} * alpha_t
last_recurrent_state = last_recurrent_state * g_t
# 计算 kv_mem = S_{t-1} * k_t (公式中的矩阵乘法)
kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
# 核心Delta操作: delta = (v_t - kv_mem) * beta_t
# 这等效于先计算 beta_t * v_t,再减去 beta_t * (S_{t-1} * k_t)
delta = (v_t - kv_mem) * beta_t
# Rank-1 更新: S_t = S_{t-1} + k_t ⊗ delta
# 这行代码等价于完成公式的最终合并:
# S_t = alpha_t * S_{t-1} - alpha_t * beta_t * S_{t-1} * (k_t k_t^T) + beta_t * v_t k_t^T
last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
# 输出: o_t = S_t * q_t
core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
# 最后把输出转回 [B, L, H, d_v] 的形状,并恢复原始数据类型。
torch_chunk_gated_delta_rule:分块并行训练
分块并行(chunk-wise parallel)的核心思想是:在块内用矩阵运算并行,块间用 recurrent 递推串行。 这样既能利用 GPU 的矩阵乘法加速,又把计算复杂度控制在 $O(LC)$ 而非 $O(L^2)$。
def torch_chunk_gated_delta_rule(
query, key, value, g, beta,
chunk_size=64, initial_state=None,
output_final_state=False, use_qk_l2norm_in_kernel=False,
):
# ── 0. 预处理:QK 归一化 + 转置 + float32(与单步递推完全一致)──
initial_dtype = query.dtype
if use_qk_l2norm_in_kernel:
query = l2norm(query, dim=-1, eps=1e-6)
key = l2norm(key, dim=-1, eps=1e-6)
query, key, value, beta, g = [
x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
]
# 形状由 [B, L, H, d] 转为 [B, H, L, d],便于后续按头做矩阵乘法。
# ── 1. 对序列长度做右侧补零,使其整除 chunk_size ──
batch_size, num_heads, sequence_length, k_head_dim = key.shape
v_head_dim = value.shape[-1]
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
# pad_size 保证 total_sequence_length = sequence_length + pad_size 能被 chunk_size 整除。
# 若 sequence_length 已经整除,则 pad_size = 0,不做任何填充。
query = F.pad(query, (0, 0, 0, pad_size)) # [B, H, L+pad, d_k]
key = F.pad(key, (0, 0, 0, pad_size))
value = F.pad(value, (0, 0, 0, pad_size))
beta = F.pad(beta, (0, pad_size)) # [B, H, L+pad]
g = F.pad(g, (0, pad_size)) # [B, H, L+pad]
total_sequence_length = sequence_length + pad_size
scale = 1 / (query.shape[-1] ** 0.5)
query = query * scale # 缩放同单步递推
# ── 2. 预计算带 beta 的 v 和 k,对应公式中 beta_t * v_t 和 beta_t * k_t ──
v_beta = value * beta.unsqueeze(-1) # [B, H, L+pad, d_v],β_t v_t
k_beta = key * beta.unsqueeze(-1) # [B, H, L+pad, d_k],β_t k_t
# ── 3. Reshape:将时间维切分为 (num_chunks, chunk_size) ──
query, key, value, k_beta, v_beta = [
x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
for x in (query, key, value, k_beta, v_beta)
]
# 形状变为 [B, H, num_chunks, chunk_size, d],方便按 chunk 索引。
g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
# g: [B, H, num_chunks, chunk_size]
# ── 4. 块内衰减掩码 (decay_mask) ──
# 对每个 chunk 内的 g 做累积求和,得到从 chunk 起点到各位置的累积 log-decay。
g = g.cumsum(dim=-1)
# g[..., t] 现在表示 sum_{s=0}^{t} log_alpha_s(块内前缀和)。
# decay_mask[..., t, s] = exp(g[t] - g[s]),表示从位置 s 到位置 t 的衰减系数。
# .tril() 保证只有 t >= s 的位置有值(因果性),其余置 0。
decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
# decay_mask: [B, H, num_chunks, chunk_size, chunk_size]
# 对角线 (t==s) 处值为 exp(0)=1,即当前 token 对自身无衰减。
# ── 5. 块内 Delta Rule 并行化:三角递推求解变换矩阵 attn ──
# 目标:高效并行地在块内实现 delta rule 的"定向擦写"效果。
#
# 原始 delta 操作写成矩阵形式,块内第 t 步的状态依赖之前所有步,
# 天然是下三角结构。以下用一次矩阵乘法 + 一个 O(C^2) 的迭代修正来完成。
#
# 首先构造初始下三角注意力矩阵(严格下三角,diagonal=0 被 mask 掉):
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)
attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
# k_beta @ key^T: [B, H, num_chunks, chunk_size, chunk_size]
# 每个 [t, s] 元素 = -beta_t * (k_t · k_s) * decay(s->t),表示位置 s 对 t 的"抹除"贡献。
# masked_fill(mask, 0):上三角(含对角线)置 0,保证因果性。
# 迭代修正:处理链式依赖(S_t 依赖 S_{t-1} 依赖 S_{t-2} ...)
# 每次迭代把"间接路径"(跨 2 步、3 步...的衰减擦写)加到矩阵里,
# 直到所有路径都被统计到(最多 chunk_size-1 次迭代)。
for i in range(1, chunk_size):
row = attn[..., i, :i].clone() # 第 i 行(i 对前 i 个 token 的直接影响)
sub = attn[..., :i, :i].clone() # 左上角子矩阵(前 i 个 token 之间的影响)
attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
# row + row @ sub:把"i->j 的直接路径"与"j->k 的间接路径"合并,
# 类似并行前缀求和,完成下三角系统的递归展开。
# 加上单位矩阵:对角线代表每个 token 对自身的恒等变换(无擦写时直接写入)
attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
# attn 此时是块内完整的"Delta变换矩阵",形状 [B, H, num_chunks, C, C]。
# ── 6. 用变换矩阵更新 value 和 k_cumdecay ──
value = attn @ v_beta
# 将每个位置的 v_beta(beta_t * v_t)经过 delta 变换矩阵聚合,
# 得到块内已完成 delta rule 修正的有效 value。
k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
# k_cumdecay[..., t, :] 表示"t 时刻的 key 方向(含衰减与 delta 修正)",
# 用于块间传播:告诉下一个 chunk,当前 chunk 末尾对历史状态 S 的影响方向。
# ── 7. 初始化块间递推状态与输出缓冲 ──
last_recurrent_state = (
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
if initial_state is None
else initial_state.to(value)
)
# last_recurrent_state: [B, H, d_k, d_v],跨 chunk 传递的 S 矩阵,与单步递推中完全相同。
core_attn_out = torch.zeros_like(value)
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)
# 重新定义 mask:此处为严格上三角(diagonal=1),用于块内 Q@K^T 的因果掩码。
# ── 8. 块间循环:对每个 chunk 计算输出并更新跨块状态 ──
for i in range(0, total_sequence_length // chunk_size):
q_i = query[:, :, i] # [B, H, C, d_k],当前 chunk 的 query
k_i = key[:, :, i] # [B, H, C, d_k]
v_i = value[:, :, i] # [B, H, C, d_v],已经过 delta 变换的 value
# 块内注意力(上一 chunk 内的因果 Q@K^T,加上衰减掩码)
attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
# attn[..., t, s] = (q_t · k_s) * decay(s->t),因果下三角,[B, H, C, C]
# 历史状态对当前 chunk 每个 token 的贡献("跨块 key 对历史 S 的读取")
v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
# v_prime[..., t, :] = k_cumdecay[t] @ S_{prev}
# 表示:如果当前 token 的 key 方向在历史 S 中已有存储,则 v_prime 是需要"擦除"的部分。
# 修正后的 value:从 delta 变换后的 v 中减去历史状态的贡献
v_new = v_i - v_prime
# v_new 是最终"净写入量",即在 delta rule 下,实际新增到 S 中的信息。
# 历史状态直接经过衰减后对当前 Q 的贡献(跨块远程记忆检索)
attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
# g[:, :, i, :] 是当前 chunk 每个位置的累积 log-decay(从 chunk 起点到该位置);
# exp() 后乘以 q_i,再与 S_{prev} 做矩阵乘,得到历史记忆对当前查询的响应。
# 注意:这里 q 已经乘上了块内累积衰减,反映了"历史 S 传入本 chunk 后继续衰减"。
# 最终输出 = 跨块远程记忆 + 块内局部注意力(作用于修正后的 value)
core_attn_out[:, :, i] = attn_inter + attn @ v_new
# 更新跨块状态 S(递推到下一个 chunk):
# S_{next} = S_{prev} * exp(g_last) # 全局衰减(用 chunk 末尾的累积 decay)
# + K_i^T @ v_new(按各位置的"剩余衰减"加权) # 块内净写入
last_recurrent_state = (
last_recurrent_state * g[:, :, i, -1, None, None].exp()
# g[..., -1] 是当前 chunk 末尾的累积 log-decay,exp() 后即 alpha_{chunk_end},
# 表示整个 chunk 对历史状态的全局遗忘。
+ (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
# (g_last - g_t) 是从位置 t 到 chunk 末的"剩余衰减";
# k_i 乘以此系数后转置,再与 v_new 做外积累加,
# 等效于把块内每个 token 的写入贡献按其到 chunk 末的衰减正确加权后汇总进 S。
)
# ── 9. 后处理:截断 pad、转回原始形状与数据类型 ──
if not output_final_state:
last_recurrent_state = None
core_attn_out = core_attn_out.reshape(
core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1]
)
core_attn_out = core_attn_out[:, :, :sequence_length] # 去掉右侧补零
core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
# 转回 [B, L, H, d_v],恢复原始精度。
return core_attn_out, last_recurrent_state
代码精讲:gated_delta_rule 后的输出处理
# 更新缓存中的递推状态,供下一个 decode 步骤复用
if cache_params is not None:
cache_params.recurrent_states[self.layer_idx] = last_recurrent_state
# 注:只在有 cache 时写回;训练/prefill 无 cache 则跳过。
z_shape_og = z.shape
# 注:保存 z 的原始形状 [B, L, H, d_v],后续 reshape 回来时需要。
# 将 core_attn_out 和 z 都压平到二维,满足 RMSNormGated 的输入要求
core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
z = z.reshape(-1, z.shape[-1])
# 注:RMSNormGated 期望输入为 [B*L*H, d_v],故先把前三维合并。
core_attn_out = self.norm(core_attn_out, z)
# 注:self.norm 是 RMSNormGated,公式为:
# output = RMSNorm(core_attn_out) * sigmoid(z)
# 其中:
# - RMSNorm(core_attn_out) 对 attention 输出做归一化,稳定数值范围;
# - sigmoid(z) 作为输出门控,z 是之前从 qkvz 投影中拆出的门控分支;
# - 两者逐元素相乘,实现对输出的动态缩放(类似 GLU 结构)。
core_attn_out = core_attn_out.reshape(z_shape_og)
# 注:恢复为 [B, L, H, d_v] 的多头形状。
core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1)
# 注:把头维摊平,变为 [B, L, H*d_v],准备送入线性投影层。
output = self.out_proj(core_attn_out)
# 注:out_proj 是标准线性层,将 [B, L, H*d_v] 映射回模型维度 [B, L, D],
# 与标准 Attention 的输出投影完全对称。
变量名与公式符号对照(对应本小节输出处理阶段)
| 代码变量名 | 公式符号 | 典型形状 | 含义 |
|---|---|---|---|
last_recurrent_state | $\mathbf{S}_t$ | [B, H, d_k, d_v] | 当前步最终记忆矩阵,写回 cache |
core_attn_out | $\mathbf{o}_t` | [B, L, H, d_v] → [B, L, D] | GDN 核心输出,门控归一化前后 |
z | 输出门控(公式外) | [B, L, H, d_v] | RMSNormGated 的门控分支 |
self.norm(·, z) | $\text{RMSNorm}(\mathbf{o}_t)\odot\sigma(\mathbf{z}_t)$ | [B*L*H, d_v] | 归一化 + 门控融合 |
self.out_proj(·) | $\mathbf{W}_o$ | [H*d_v, D] | 输出线性投影,映射回模型维度 |
注:
z门控在"前言"部分的fix_query_key_value_ordering中与query/key/value一同从qkvz投影中拆出,贯穿整个前向流程,最终在此处参与输出门控计算。
Gated Delta Net 的计算量分析
下面给出分析中的变量定义。设:
- 序列长度 $L$
- chunk 大小 $C$
- chunk 数 $N=L/C$
- 头数 $H$
- 每头维度 $d_k, d_v$
- 模型维度 $D$
单步推理复杂度
结合代码实现,状态矩阵 $\mathbf{S}$ 的实际形状为 [B, H_v, d_k, d_v]——每个 value head 维护一个 $d_k \times d_v$ 的记忆矩阵;而 Q/K 在 repeat_interleave 后也对齐到 $H_v$ 个头。因此复杂度应以 $H_v$ 为单位计算。
对每个 value head,单步递推的操作为:
S <- alpha * S:对 $d_k \times d_v$ 矩阵逐元素乘标量,$O(d_k d_v)$kv_mem = (S * k_t).sum(-2):按 $d_k$ 维收缩,$O(d_k d_v)$delta = (v_t - kv_mem) * beta_t:逐元素操作,$O(d_v)$- rank-1 更新
S += k_t ⊗ delta:外积写入,$O(d_k d_v)$ - 输出
o_t = (S * q_t).sum(-2):按 $d_k$ 维收缩,$O(d_k d_v)$
合计每头 $O(4d_k d_v + d_v)$,全 $H_v$ 个头后每 token 的核心递推复杂度:
$$ O(4H_vd_k d_v + H_v d_v) \approx O(H_v d_k d_v) $$输入/输出投影层(in_proj_qkvz、in_proj_ba、out_proj),三项分别为:
假设 $H_v \approx H_k$,$H_v d_v \approx D$,投影层复杂度主项约为 $O(D^2)$
$$ O(LD^2) $$这就是它之所以被称为"线性 Attention"的原因:核心递推部分对 $L$ 的复杂度是线性的,而不是标准 Attention 的 $O(L^2)$。
我们刚刚的计算漏掉了一个细节。如果加上因果卷积(causal_conv1d,kernel size $K$)。卷积作用在拼接后的 Q/K/V 通道上(共 $2H_k d_k + H_v d_v$ 个通道),每个通道每个位置做 $K$ 次乘加,单 token 复杂度为:
实现中 $K$ 为固定小常数($K=4$),故卷积开销相对投影层可忽略不计。
单步推理时,投影层的计算量通常远大于核心递推部分,这在 decode 阶段(seq_len=1)尤为明显。
并行训练复杂度
将序列分块后,计算通常分成:
- 块内构造相关矩阵(如 $\mathbf{K}\mathbf{K}^T$)与三角求解/变换;
- 块内输出计算(类似 $(\mathbf{Q}\mathbf{K}^T \odot \mathbf{M})\cdot(\cdots)$);
- 块间状态传播。
按每块估算主项(每头)可写为:
$$ O(C^2 d_k) + O(C^2 d_v) + O(C d_k d_v) $$全序列($N=L/C$ 块):
$$ O\!\left(\frac{L}{C}(C^2(d_k+d_v)+C d_k d_v)\right) =O\!\left(LC(d_k+d_v)+L d_k d_v\right) $$乘上头数 $H$:
$$ O\!\left(HLC(d_k+d_v)+HL d_k d_v\right) $$当 $C$ 取固定工程常数(例如 64)时,整体对 $L$ 仍线性增长。
并且由于块内是大矩阵运算,实际硬件利用率通常优于纯逐 token 串行训练。
Gated Delta Net 的演化
LA(Linear Attention)
状态更新公式:
$$ \mathbf{S}_t=\mathbf{S}_{t-1}+\mathbf{v}_t\mathbf{k}_t^T $$特征:
- 最朴素外积记忆累加
- 没有显式遗忘,长序列易“记忆叠加/碰撞”
Mamba2(加入门控衰减)
状态更新公式:
$$ \mathbf{S}_t=\alpha_t\mathbf{S}_{t-1}+\mathbf{v}_t\mathbf{k}_t^T $$特征:
- 有了全局遗忘(通过 $\alpha_t$)
- 能主动清空历史,但擦除是“比较均匀”的,不够定向
Longhorn(在线回归闭式更新)
状态更新公式:
$$ \mathbf{S}_t=\mathbf{S}_{t-1}(\mathbf{I}-\epsilon_t\mathbf{k}_t\mathbf{k}_t^T)+\epsilon_t\mathbf{v}_t\mathbf{k}_t^T,\quad \epsilon_t=\frac{\beta_t}{1+\beta_t\mathbf{k}_t^T\mathbf{k}_t} $$特征:
- 来自更“回归化”的在线学习目标
- 与 delta 形式非常接近,但系数来自 implicit online learning 闭式推导
DeltaNet(定向替换写入)
状态更新公式:
$$ \mathbf{S}_t=\mathbf{S}_{t-1}(\mathbf{I}-\beta_t\mathbf{k}_t\mathbf{k}_t^T)+\beta_t\mathbf{v}_t\mathbf{k}_t^T $$特征:
- 对当前 key 对应方向做定向擦写 + 写入
- 关联记忆能力强(检索类任务常更好)
- 但缺少全局快速遗忘机制
Gated DeltaNet(两者融合)
状态更新公式:
$$ \mathbf{S}_t=\mathbf{S}_{t-1}\left(\alpha_t(\mathbf{I}-\beta_t\mathbf{k}_t\mathbf{k}_t^T)\right)+\beta_t\mathbf{v}_t\mathbf{k}_t^T $$它结合了两种能力:
- $\alpha_t$:衰减陈旧信息
- $\beta_t$ + delta 项:针对 key 做精细写入与替换
这也是论文实验里它在需要“记忆 + 过滤”同时成立的场景下表现稳定的核心原因。
Qwen3-Next 架构里的 Gated Delta Net(简述)
从 modular_qwen3_next.py 看,Qwen3-Next 是混合层设计:
Qwen3NextDecoderLayer里,layer_type可为linear_attention(即Qwen3NextGatedDeltaNet)full_attention(标准注意力分支)
- FFN 侧可配 MoE(
Qwen3NextSparseMoeBlock)或普通 MLP - 因此整体是一个 Attention / Linear-Attention 混合 + MoE 可选 的解码器结构
可以把它理解成:
- 用一部分 Gated Delta Net 层换取长序列效率与记忆管理能力;
- 用全注意力层保留强建模上限;
- 再通过 MoE 扩展容量与性价比。
参考链接
- Transformers 实现:modular_qwen3_next.py
- LA(Linear Attention):Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention(Katharopoulos et al., ICML 2020)
- Mamba2:Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality(Dao & Gu, ICML 2024)
- Longhorn:Longhorn: State Space Models are Amortized Online Learners(Bo Liu et al., 2024)
- DeltaNet(Fast Weight / delta rule 视角):Linear Transformers Are Secretly Fast Weight Programmers(Schlag et al., ICML 2021)
- DeltaNet(并行化训练算法):Parallelizing Linear Transformers with the Delta Rule over Sequence Length(Yang et al., NeurIPS 2024)
- Gated DeltaNet:Gated Delta Networks: Improving Mamba2 with Delta Rule(Yang et al., ICLR 2025)
