Featured image of post Gated Delta Net

Gated Delta Net

结合论文与 Qwen3-Next 的代码实现,学习 Gated Delta Net

前言

最近在看 Qwen3-Next 的实现,发现它在一部分层里没有用 Full Attention,而是用了 Gated Delta Net,刚刚发布的 Qwen3.5 系列也沿用了这一设计。
更早一些的工作,例如 MiMo 的 混合 SWA,Jamba的混合 manba 模块,也是希望通过替换掉大部分 Full Attention 来提升长序列效率。

这个结构很有意思:

  • 一方面,它继承了线性 RNN / 线性 Attention 的线性时间和常数状态推理优势;
  • 另一方面,它把 “gating(遗忘)” 和 “delta rule(定向写入)” 合到了一起,改善了长上下文下的记忆管理和检索能力。

这篇笔记主要参考两份材料:

接下来,我会先讲解 Gated Delta Net 的核心算法(先理解推理,再训练并行化),然后分析它的计算复杂度,最后回顾它在各种线性 RNN 中的演化脉络。

Gated Delta Net 算法详解

一点废话

在正式开始之前,让我们把视角抬高一些,看看Transformer模型的结构。Transformer模型的结构,是一层attention,一层FFN如此交错,他们分别代表了2个不同的功能:

  • Attention:负责跨位置的信息交互,也就是“上下文”,让模型能够捕捉序列中不同位置之间的依赖关系。
  • FFN:负责位置内的信息变换,也就是“知识”,让模型能够对每个位置的特征进行非线性变换和增强。

FFN的部分在DeepSeek系列已经讲过了,通过MoE带来的稀疏化,在保持推理速度不变的约束下,大幅提升了模型的容量,提升模型的能力;包括近期出现的Deepseek ENGram,将知识做成外置,也是一个FFN扩容的探索。 Attention的部分,尤其是长序列的Attention,一直是效率的瓶颈。Gated Delta Net就是在这个背景下,基于古老RNN的一个设计,旨在通过一种更高效的方式来实现长序列的上下文交互,同时保持模型的表达能力。

这两部分都被替换之后,Transformer模型是否就变成了忒修斯之船?我觉得不完全是。虽然结构上发生了改变,但它的功能定位(上下文交互)依然存在。

Gated Delta Rule

先看核心公式,Gated Delta Rule:

$$ \mathbf{S}_t =\mathbf{S}_{t-1} \left( \alpha_t(\mathbf{I}-\beta_t \mathbf{k}_t \mathbf{k}_t^T) \right) + \beta_t \mathbf{v}_t \mathbf{k}_t^T $$

其中:

  • $\mathbf{S}_t \in \mathbb{R}^{d_v \times d_k}$:时刻 $t$ 的“快权重/记忆矩阵”
  • $\mathbf{k}_t \in \mathbb{R}^{d_k}$,$\mathbf{v}_t \in \mathbb{R}^{d_v}$,$\mathbf{q}_t \in \mathbb{R}^{d_k}$
  • $\beta_t \in (0,1)$:写入强度(学习率)
  • $\alpha_t \in (0,1)$:遗忘门控(衰减)
  • $\mathbf{I}$:单位矩阵

直觉上:

  • $\alpha_t$ 控制“整体保留多少旧记忆”;
  • $\beta_t$ 控制“针对当前 key 写入多少新 value”;
  • $\mathbf{v}_t \mathbf{k}_t^T$ 是当前输入的增量写入项;
    • 表示在 $\mathbf{k}_t$ 方向上写入 $\mathbf{v}_t$ 的信息;
  • $(\mathbf{I}-\beta_t \mathbf{k}_t\mathbf{k}_t^T)$ 是带有“定向擦写”含义的变换。
    • 定向:$\mathbf{k}_t\mathbf{k}_t^T$ 是张量积
    • 擦写:从$\mathbf{I}$中减去,让 $\mathbf{S}_{t-1}$ 在 $\mathbf{k}_t$ 方向上的分量被衰减
    • 衰减:乘以 $\alpha_t$ 进一步控制整体遗忘程度

输出一般写为:

$$ \mathbf{o}_t = \mathbf{S}_t \mathbf{q}_t \in \mathbb{R}^{d_v} $$

代码精讲:GDN前的准备

输入与缓存状态判定

hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
# 注:把 padding 位的状态清理掉,避免线性递推状态被无效 token 污染。

batch_size, seq_len, _ = hidden_states.shape
# 注:后续 reshape / decode 分支判断都会用到 seq_len。

use_precomputed_states = (
    cache_params is not None
    and cache_params.has_previous_state
    and seq_len == 1
    and cache_position is not None
)
# 注:只有 decode 单步(seq_len==1)且已有历史状态时,才走 recurrent 路径;
#     否则走 chunk 路径(prefill/训练)。

if cache_params is not None:
    conv_state = cache_params.conv_states[self.layer_idx]
    recurrent_state = cache_params.recurrent_states[self.layer_idx]
# 注:从动态缓存读取卷积状态与递推状态,供增量推理复用。

双路线性投影 + 拆分(重点看 z / b / a)

projected_states_qkvz = self.in_proj_qkvz(hidden_states)
projected_states_ba = self.in_proj_ba(hidden_states)
# 注:两路投影:
#   - qkvz 路:产出 query/key/value/z
#   - ba 路:产出 b/a(后续分别变成 beta 和 g 的输入)

query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba)
# 注:在 fix_query_key_value_ordering 内部会先按 head 分组,再 split:
#   - qkvz 按 [q, k, v, z] 拆分
#   - ba   按 [b, a] 拆分

拆分和形状修复

def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba):
    """
    Derives `query`, `key` and `value` tensors from `mixed_qkvz` and `mixed_ba`.
    """
    new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + (
        self.num_k_heads,
        2 * self.head_k_dim + 2 * self.head_v_dim * self.num_v_heads // self.num_k_heads,
    )
    new_tensor_shape_ba = mixed_ba.size()[:-1] + (self.num_k_heads, 2 * self.num_v_heads // self.num_k_heads)

    mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz)
    #   mixed_qkvz
    #      -> view -> [B, L, num_k_heads, 2*head_k_dim + 2*head_v_dim*(num_v_heads/num_k_heads)]
    mixed_ba = mixed_ba.view(*new_tensor_shape_ba)
    #   mixed_ba: [B, L, 2*num_v_heads]
    #      -> view -> [B, L, num_k_heads, 2*(num_v_heads/num_k_heads)]
    
    split_arg_list_qkvz = [
        self.head_k_dim,
        self.head_k_dim,
        (self.num_v_heads // self.num_k_heads * self.head_v_dim),
        (self.num_v_heads // self.num_k_heads * self.head_v_dim),
    ]
    split_arg_list_ba = [self.num_v_heads // self.num_k_heads, self.num_v_heads // self.num_k_heads]
    query, key, value, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=3)
    # 在 dim=3 上 split 成 q/k/v/z
    #   query: [B, L, num_k_heads, head_k_dim]
    #   key:   [B, L, num_k_heads, head_k_dim]
    #   value: [B, L, num_k_heads, (num_v_heads/num_k_heads)*head_v_dim]
    #   z:     [B, L, num_k_heads, (num_v_heads/num_k_heads)*head_v_dim]
    b, a = torch.split(mixed_ba, split_arg_list_ba, dim=3)
    # 在 dim=3 上 split 成 b/a
    #   b: [B, L, num_k_heads, (num_v_heads/num_k_heads)]
    #   a: [B, L, num_k_heads, (num_v_heads/num_k_heads)]

    value = value.reshape(value.size(0), value.size(1), -1, self.head_v_dim)
    #   value: [B, L, num_k_heads, (num_v_heads/num_k_heads)*head_v_dim]
    #      -> reshape -> [B, L, num_v_heads, head_v_dim]
    z = z.reshape(z.size(0), z.size(1), -1, self.head_v_dim)
    #   z: [B, L, num_k_heads, (num_v_heads/num_k_heads)*head_v_dim]
    #      -> reshape -> [B, L, num_v_heads, head_v_dim]
    b = b.reshape(b.size(0), b.size(1), self.num_v_heads)
    #   b: [B, L, num_k_heads, (num_v_heads/num_k_heads)]
    #      -> reshape -> [B, L, num_v_heads]
    a = a.reshape(a.size(0), a.size(1), self.num_v_heads)
    #   a: [B, L, num_k_heads, (num_v_heads/num_k_heads)]
    #      -> reshape -> [B, L, num_v_heads]
    return query, key, value, z, b, a

QKV 因果卷积混合

query, key, value = (x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value))
# 注:先把头维摊平,便于拼接做一维深度卷积。

mixed_qkv = torch.cat((query, key, value), dim=-1)
mixed_qkv = mixed_qkv.transpose(1, 2)
# 注:卷积输入形状转成 [B, C, L],C=2*key_dim+value_dim。

if use_precomputed_states:
    mixed_qkv = self.causal_conv1d_update(
        mixed_qkv,
        conv_state,
        self.conv1d.weight.squeeze(1),
        self.conv1d.bias,
        self.activation,
    )
# 注:decode 增量卷积,只更新最后一步,避免重复算历史。

else:
    if cache_params is not None:
        conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
        ## 这部分的参数格式是F.pad(input, (pad_left, pad_right));
        ## 如果大于kernel_size,pad_left就会是负数,表示裁剪掉前面多余的部分;
        ## 如果小于kernel_size,就会在前面补零。
        cache_params.conv_states[self.layer_idx] = conv_state
    if self.causal_conv1d_fn is not None:
        mixed_qkv = self.causal_conv1d_fn(
            x=mixed_qkv,
            weight=self.conv1d.weight.squeeze(1),
            bias=self.conv1d.bias,
            activation=self.activation,
            seq_idx=None,
        )
    else:
        mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
# 注:prefill/训练路径:
#   - 有快实现就用 fused kernel;
#   - 否则回退到 torch conv1d + silu;
#   - 若开 cache,同步初始化 conv_state。

mixed_qkv = mixed_qkv.transpose(1, 2)
query, key, value = torch.split(mixed_qkv, [self.key_dim, self.key_dim, self.value_dim], dim=-1)
# 注:卷积后再拆回 Q/K/V 三路。

重排头维并构造 beta/g

query = query.reshape(query.shape[0], query.shape[1], -1, self.head_k_dim)
key = key.reshape(key.shape[0], key.shape[1], -1, self.head_k_dim)
value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim)
# 注:恢复为多头形状 [B, L, H, d]。

beta = b.sigmoid()
# 注:写入强度 beta∈(0,1)。

g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
# 注:得到 log-space 的衰减参数;后续在 rule 内部 exp(g)->alpha∈(0,1)。

if self.num_v_heads // self.num_k_heads > 1:
    query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
    key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
# 注:当 v-head 多于 k-head 时,把 K/Q 复制到对齐的头数,保证后续逐头计算可广播。

概括来说,这部分代码的主要功能是: 投影拆分 → 局部卷积混合 → 头维对齐 → 生成 beta/g。

变量名与公式符号对照(对应本小节准备阶段)

代码变量名公式符号典型形状(多头)含义
hidden_states输入序列(公式外)[B, L, D]模块输入隐藏状态
query$\mathbf{q}_t$[B, L, H, d_k]查询向量(每时刻/每头)
key$\mathbf{k}_t$[B, L, H, d_k]键向量(每时刻/每头)
value$\mathbf{v}_t$[B, L, H, d_v]值向量(每时刻/每头)
b$\beta_t$ 的参数化前体[B, L, H]经过 sigmoid 前的写入强度参数
beta = b.sigmoid()$\beta_t \in (0,1)$[B, L, H]写入强度(学习率)
a$\alpha_t$ 的参数化前体[B, L, H]经过变换前的遗忘门控参数
g$\log \alpha_t$(实现中)[B, L, H]衰减的 log-space 参数,通常为负
g.exp()$\alpha_t \in (0,1)$[B, L, H]真正用于递推的遗忘系数
recurrent_state / last_recurrent_state$\mathbf{S}_{t-1}$ / $\mathbf{S}_t$[B, H, d_k, d_v]快权重记忆矩阵(递推状态)
z输出门控(公式外)[B, L, H, d_v]用于后续 RMSNormGated 的门控分支

注:文中公式使用 $\alpha_t$;代码里常先算 g,在 rule 内部通过 exp(g) 得到 $\alpha_t$。 注2:代码实现通过计算 kv_mem = S_{t-1} k_tdelta = (v_t - kv_mem) * beta_t,巧妙地合并了公式中的矩阵乘法与减法,效率更高。


代码精讲:Gated Delta Rule 的两种实现

torch_recurrent_gated_delta_rule:单步递推

def torch_recurrent_gated_delta_rule(query, key, value, g, beta, initial_state, ...):

    initial_dtype = query.dtype # 备份数据类型
    if use_qk_l2norm_in_kernel:
        query = l2norm(query, dim=-1, eps=1e-6)
        key = l2norm(key, dim=-1, eps=1e-6)
    # QK 归一化(可选),有助于稳定训练和推理。
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
    ]
    # 这一步很重要,它把输入的形状从 [B, L, H, d] 转成 [B, H, L, d],并且转换为 float32 以提高数值稳定性。对于SSM等长序列模型,记忆状态会经历累加,这种精度是必要的。

    batch_size, num_heads, sequence_length, k_head_dim = key.shape
    v_head_dim = value.shape[-1]
    scale = 1 / (query.shape[-1] ** 0.5)
    query = query * scale

    # 声明(创建)空变量存储输出和递推状态
    core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to(value)
    last_recurrent_state = (
        torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
        if initial_state is None
        else initial_state.to(value)
    )

    # sequence_length 一般为1,但也可以是小批量的多步(例如 decode 时一次更新多个 token)。
    for i in range(sequence_length):
        q_t = query[:, :, i]   # q_t
        k_t = key[:, :, i]     # k_t
        v_t = value[:, :, i]   # v_t
        g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1)  # alpha_t
        beta_t = beta[:, :, i].unsqueeze(-1)                # beta_t

        # 对应公式: S_{t-1} * alpha_t
        last_recurrent_state = last_recurrent_state * g_t

        # 计算 kv_mem = S_{t-1} * k_t (公式中的矩阵乘法)
        kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)

        # 核心Delta操作: delta = (v_t - kv_mem) * beta_t
        # 这等效于先计算 beta_t * v_t,再减去 beta_t * (S_{t-1} * k_t)
        delta = (v_t - kv_mem) * beta_t

        # Rank-1 更新: S_t = S_{t-1} + k_t ⊗ delta
        # 这行代码等价于完成公式的最终合并:
        # S_t = alpha_t * S_{t-1} - alpha_t * beta_t * S_{t-1} * (k_t k_t^T) + beta_t * v_t k_t^T
        last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)

        # 输出: o_t = S_t * q_t
        core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)

    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
    # 最后把输出转回 [B, L, H, d_v] 的形状,并恢复原始数据类型。

torch_chunk_gated_delta_rule:分块并行训练

分块并行(chunk-wise parallel)的核心思想是:在块内用矩阵运算并行,块间用 recurrent 递推串行。 这样既能利用 GPU 的矩阵乘法加速,又把计算复杂度控制在 $O(LC)$ 而非 $O(L^2)$。

def torch_chunk_gated_delta_rule(
    query, key, value, g, beta,
    chunk_size=64, initial_state=None,
    output_final_state=False, use_qk_l2norm_in_kernel=False,
):
    # ── 0. 预处理:QK 归一化 + 转置 + float32(与单步递推完全一致)──
    initial_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = l2norm(query, dim=-1, eps=1e-6)
        key = l2norm(key, dim=-1, eps=1e-6)
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
    ]
    # 形状由 [B, L, H, d] 转为 [B, H, L, d],便于后续按头做矩阵乘法。

    # ── 1. 对序列长度做右侧补零,使其整除 chunk_size ──
    batch_size, num_heads, sequence_length, k_head_dim = key.shape
    v_head_dim = value.shape[-1]
    pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
    # pad_size 保证 total_sequence_length = sequence_length + pad_size 能被 chunk_size 整除。
    # 若 sequence_length 已经整除,则 pad_size = 0,不做任何填充。
    query  = F.pad(query,  (0, 0, 0, pad_size))   # [B, H, L+pad, d_k]
    key    = F.pad(key,    (0, 0, 0, pad_size))
    value  = F.pad(value,  (0, 0, 0, pad_size))
    beta   = F.pad(beta,   (0, pad_size))           # [B, H, L+pad]
    g      = F.pad(g,      (0, pad_size))           # [B, H, L+pad]
    total_sequence_length = sequence_length + pad_size

    scale = 1 / (query.shape[-1] ** 0.5)
    query = query * scale                           # 缩放同单步递推

    # ── 2. 预计算带 beta 的 v 和 k,对应公式中 beta_t * v_t 和 beta_t * k_t ──
    v_beta = value * beta.unsqueeze(-1)             # [B, H, L+pad, d_v],β_t v_t
    k_beta = key   * beta.unsqueeze(-1)             # [B, H, L+pad, d_k],β_t k_t

    # ── 3. Reshape:将时间维切分为 (num_chunks, chunk_size) ──
    query, key, value, k_beta, v_beta = [
        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
        for x in (query, key, value, k_beta, v_beta)
    ]
    # 形状变为 [B, H, num_chunks, chunk_size, d],方便按 chunk 索引。
    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
    # g: [B, H, num_chunks, chunk_size]

    # ── 4. 块内衰减掩码 (decay_mask) ──
    # 对每个 chunk 内的 g 做累积求和,得到从 chunk 起点到各位置的累积 log-decay。
    g = g.cumsum(dim=-1)
    # g[..., t] 现在表示 sum_{s=0}^{t} log_alpha_s(块内前缀和)。

    # decay_mask[..., t, s] = exp(g[t] - g[s]),表示从位置 s 到位置 t 的衰减系数。
    # .tril() 保证只有 t >= s 的位置有值(因果性),其余置 0。
    decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
    # decay_mask: [B, H, num_chunks, chunk_size, chunk_size]
    # 对角线 (t==s) 处值为 exp(0)=1,即当前 token 对自身无衰减。

    # ── 5. 块内 Delta Rule 并行化:三角递推求解变换矩阵 attn ──
    # 目标:高效并行地在块内实现 delta rule 的"定向擦写"效果。
    # 
    # 原始 delta 操作写成矩阵形式,块内第 t 步的状态依赖之前所有步,
    # 天然是下三角结构。以下用一次矩阵乘法 + 一个 O(C^2) 的迭代修正来完成。
    #
    # 首先构造初始下三角注意力矩阵(严格下三角,diagonal=0 被 mask 掉):

    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)
    attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
    # k_beta @ key^T: [B, H, num_chunks, chunk_size, chunk_size]
    # 每个 [t, s] 元素 = -beta_t * (k_t · k_s) * decay(s->t),表示位置 s 对 t 的"抹除"贡献。
    # masked_fill(mask, 0):上三角(含对角线)置 0,保证因果性。

    # 迭代修正:处理链式依赖(S_t 依赖 S_{t-1} 依赖 S_{t-2} ...)
    # 每次迭代把"间接路径"(跨 2 步、3 步...的衰减擦写)加到矩阵里,
    # 直到所有路径都被统计到(最多 chunk_size-1 次迭代)。
    for i in range(1, chunk_size):
        row = attn[..., i, :i].clone()       # 第 i 行(i 对前 i 个 token 的直接影响)
        sub = attn[..., :i, :i].clone()      # 左上角子矩阵(前 i 个 token 之间的影响)
        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
        # row + row @ sub:把"i->j 的直接路径"与"j->k 的间接路径"合并,
        # 类似并行前缀求和,完成下三角系统的递归展开。

    # 加上单位矩阵:对角线代表每个 token 对自身的恒等变换(无擦写时直接写入)
    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
    # attn 此时是块内完整的"Delta变换矩阵",形状 [B, H, num_chunks, C, C]。

    # ── 6. 用变换矩阵更新 value 和 k_cumdecay ──
    value = attn @ v_beta
    # 将每个位置的 v_beta(beta_t * v_t)经过 delta 变换矩阵聚合,
    # 得到块内已完成 delta rule 修正的有效 value。
    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
    # k_cumdecay[..., t, :] 表示"t 时刻的 key 方向(含衰减与 delta 修正)",

    # 用于块间传播:告诉下一个 chunk,当前 chunk 末尾对历史状态 S 的影响方向。

    # ── 7. 初始化块间递推状态与输出缓冲 ──
    last_recurrent_state = (
        torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
        if initial_state is None
        else initial_state.to(value)
    )
    # last_recurrent_state: [B, H, d_k, d_v],跨 chunk 传递的 S 矩阵,与单步递推中完全相同。
    core_attn_out = torch.zeros_like(value)
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)
    # 重新定义 mask:此处为严格上三角(diagonal=1),用于块内 Q@K^T 的因果掩码。

    # ── 8. 块间循环:对每个 chunk 计算输出并更新跨块状态 ──
    for i in range(0, total_sequence_length // chunk_size):
        q_i = query[:, :, i]       # [B, H, C, d_k],当前 chunk 的 query
        k_i = key[:, :, i]         # [B, H, C, d_k]
        v_i = value[:, :, i]       # [B, H, C, d_v],已经过 delta 变换的 value

        # 块内注意力(上一 chunk 内的因果 Q@K^T,加上衰减掩码)
        attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
        # attn[..., t, s] = (q_t · k_s) * decay(s->t),因果下三角,[B, H, C, C]

        # 历史状态对当前 chunk 每个 token 的贡献("跨块 key 对历史 S 的读取")
        v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
        # v_prime[..., t, :] = k_cumdecay[t] @ S_{prev}
        # 表示:如果当前 token 的 key 方向在历史 S 中已有存储,则 v_prime 是需要"擦除"的部分。

        # 修正后的 value:从 delta 变换后的 v 中减去历史状态的贡献
        v_new = v_i - v_prime
        # v_new 是最终"净写入量",即在 delta rule 下,实际新增到 S 中的信息。

        # 历史状态直接经过衰减后对当前 Q 的贡献(跨块远程记忆检索)
        attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
        # g[:, :, i, :] 是当前 chunk 每个位置的累积 log-decay(从 chunk 起点到该位置);
        # exp() 后乘以 q_i,再与 S_{prev} 做矩阵乘,得到历史记忆对当前查询的响应。
        # 注意:这里 q 已经乘上了块内累积衰减,反映了"历史 S 传入本 chunk 后继续衰减"。

        # 最终输出 = 跨块远程记忆 + 块内局部注意力(作用于修正后的 value)
        core_attn_out[:, :, i] = attn_inter + attn @ v_new

        # 更新跨块状态 S(递推到下一个 chunk):
        # S_{next} = S_{prev} * exp(g_last)                        # 全局衰减(用 chunk 末尾的累积 decay)
        #           + K_i^T @ v_new(按各位置的"剩余衰减"加权)     # 块内净写入
        last_recurrent_state = (
            last_recurrent_state * g[:, :, i, -1, None, None].exp()
            # g[..., -1] 是当前 chunk 末尾的累积 log-decay,exp() 后即 alpha_{chunk_end},
            # 表示整个 chunk 对历史状态的全局遗忘。
            + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
            # (g_last - g_t) 是从位置 t 到 chunk 末的"剩余衰减";
            # k_i 乘以此系数后转置,再与 v_new 做外积累加,
            # 等效于把块内每个 token 的写入贡献按其到 chunk 末的衰减正确加权后汇总进 S。
        )

    # ── 9. 后处理:截断 pad、转回原始形状与数据类型 ──
    if not output_final_state:
        last_recurrent_state = None
    core_attn_out = core_attn_out.reshape(
        core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1]
    )
    core_attn_out = core_attn_out[:, :, :sequence_length]   # 去掉右侧补零
    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
    # 转回 [B, L, H, d_v],恢复原始精度。
    return core_attn_out, last_recurrent_state

代码精讲:gated_delta_rule 后的输出处理

# 更新缓存中的递推状态,供下一个 decode 步骤复用
if cache_params is not None:
    cache_params.recurrent_states[self.layer_idx] = last_recurrent_state
# 注:只在有 cache 时写回;训练/prefill 无 cache 则跳过。

z_shape_og = z.shape
# 注:保存 z 的原始形状 [B, L, H, d_v],后续 reshape 回来时需要。

# 将 core_attn_out 和 z 都压平到二维,满足 RMSNormGated 的输入要求
core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
z = z.reshape(-1, z.shape[-1])
# 注:RMSNormGated 期望输入为 [B*L*H, d_v],故先把前三维合并。

core_attn_out = self.norm(core_attn_out, z)
# 注:self.norm 是 RMSNormGated,公式为:
#   output = RMSNorm(core_attn_out) * sigmoid(z)
# 其中:
#   - RMSNorm(core_attn_out) 对 attention 输出做归一化,稳定数值范围;
#   - sigmoid(z) 作为输出门控,z 是之前从 qkvz 投影中拆出的门控分支;
#   - 两者逐元素相乘,实现对输出的动态缩放(类似 GLU 结构)。

core_attn_out = core_attn_out.reshape(z_shape_og)
# 注:恢复为 [B, L, H, d_v] 的多头形状。

core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1)
# 注:把头维摊平,变为 [B, L, H*d_v],准备送入线性投影层。

output = self.out_proj(core_attn_out)
# 注:out_proj 是标准线性层,将 [B, L, H*d_v] 映射回模型维度 [B, L, D],
# 与标准 Attention 的输出投影完全对称。

变量名与公式符号对照(对应本小节输出处理阶段)

代码变量名公式符号典型形状含义
last_recurrent_state$\mathbf{S}_t$[B, H, d_k, d_v]当前步最终记忆矩阵,写回 cache
core_attn_out$\mathbf{o}_t`[B, L, H, d_v][B, L, D]GDN 核心输出,门控归一化前后
z输出门控(公式外)[B, L, H, d_v]RMSNormGated 的门控分支
self.norm(·, z)$\text{RMSNorm}(\mathbf{o}_t)\odot\sigma(\mathbf{z}_t)$[B*L*H, d_v]归一化 + 门控融合
self.out_proj(·)$\mathbf{W}_o$[H*d_v, D]输出线性投影,映射回模型维度

注:z 门控在"前言"部分的 fix_query_key_value_ordering 中与 query/key/value 一同从 qkvz 投影中拆出,贯穿整个前向流程,最终在此处参与输出门控计算。

Gated Delta Net 的计算量分析

下面给出分析中的变量定义。设:

  • 序列长度 $L$
  • chunk 大小 $C$
  • chunk 数 $N=L/C$
  • 头数 $H$
  • 每头维度 $d_k, d_v$
  • 模型维度 $D$

单步推理复杂度

结合代码实现,状态矩阵 $\mathbf{S}$ 的实际形状为 [B, H_v, d_k, d_v]——每个 value head 维护一个 $d_k \times d_v$ 的记忆矩阵;而 Q/K 在 repeat_interleave 后也对齐到 $H_v$ 个头。因此复杂度应以 $H_v$ 为单位计算。

对每个 value head,单步递推的操作为:

  1. S <- alpha * S:对 $d_k \times d_v$ 矩阵逐元素乘标量,$O(d_k d_v)$
  2. kv_mem = (S * k_t).sum(-2):按 $d_k$ 维收缩,$O(d_k d_v)$
  3. delta = (v_t - kv_mem) * beta_t:逐元素操作,$O(d_v)$
  4. rank-1 更新 S += k_t ⊗ delta:外积写入,$O(d_k d_v)$
  5. 输出 o_t = (S * q_t).sum(-2):按 $d_k$ 维收缩,$O(d_k d_v)$

合计每头 $O(4d_k d_v + d_v)$,全 $H_v$ 个头后每 token 的核心递推复杂度:

$$ O(4H_vd_k d_v + H_v d_v) \approx O(H_v d_k d_v) $$

输入/输出投影层(in_proj_qkvzin_proj_baout_proj),三项分别为:

$$ O\!\left(D \cdot (2H_k d_k + 2H_v d_v)\right) + O(D \cdot 2H_v) + O(H_v d_v \cdot D) $$

假设 $H_v \approx H_k$,$H_v d_v \approx D$,投影层复杂度主项约为 $O(D^2)$

$$ O(LD^2) $$

这就是它之所以被称为"线性 Attention"的原因:核心递推部分对 $L$ 的复杂度是线性的,而不是标准 Attention 的 $O(L^2)$。

我们刚刚的计算漏掉了一个细节。如果加上因果卷积(causal_conv1d,kernel size $K$)。卷积作用在拼接后的 Q/K/V 通道上(共 $2H_k d_k + H_v d_v$ 个通道),每个通道每个位置做 $K$ 次乘加,单 token 复杂度为:

$$ O\!\left(K \cdot (2H_k d_k + H_v d_v)\right) \approx O(KD) $$

实现中 $K$ 为固定小常数($K=4$),故卷积开销相对投影层可忽略不计。

单步推理时,投影层的计算量通常远大于核心递推部分,这在 decode 阶段(seq_len=1)尤为明显。

并行训练复杂度

将序列分块后,计算通常分成:

  • 块内构造相关矩阵(如 $\mathbf{K}\mathbf{K}^T$)与三角求解/变换;
  • 块内输出计算(类似 $(\mathbf{Q}\mathbf{K}^T \odot \mathbf{M})\cdot(\cdots)$);
  • 块间状态传播。

按每块估算主项(每头)可写为:

$$ O(C^2 d_k) + O(C^2 d_v) + O(C d_k d_v) $$

全序列($N=L/C$ 块):

$$ O\!\left(\frac{L}{C}(C^2(d_k+d_v)+C d_k d_v)\right) =O\!\left(LC(d_k+d_v)+L d_k d_v\right) $$

乘上头数 $H$:

$$ O\!\left(HLC(d_k+d_v)+HL d_k d_v\right) $$

当 $C$ 取固定工程常数(例如 64)时,整体对 $L$ 仍线性增长。
并且由于块内是大矩阵运算,实际硬件利用率通常优于纯逐 token 串行训练。

Gated Delta Net 的演化

LA(Linear Attention)

状态更新公式:

$$ \mathbf{S}_t=\mathbf{S}_{t-1}+\mathbf{v}_t\mathbf{k}_t^T $$

特征:

  • 最朴素外积记忆累加
  • 没有显式遗忘,长序列易“记忆叠加/碰撞”

Mamba2(加入门控衰减)

状态更新公式:

$$ \mathbf{S}_t=\alpha_t\mathbf{S}_{t-1}+\mathbf{v}_t\mathbf{k}_t^T $$

特征:

  • 有了全局遗忘(通过 $\alpha_t$)
  • 能主动清空历史,但擦除是“比较均匀”的,不够定向

Longhorn(在线回归闭式更新)

状态更新公式:

$$ \mathbf{S}_t=\mathbf{S}_{t-1}(\mathbf{I}-\epsilon_t\mathbf{k}_t\mathbf{k}_t^T)+\epsilon_t\mathbf{v}_t\mathbf{k}_t^T,\quad \epsilon_t=\frac{\beta_t}{1+\beta_t\mathbf{k}_t^T\mathbf{k}_t} $$

特征:

  • 来自更“回归化”的在线学习目标
  • 与 delta 形式非常接近,但系数来自 implicit online learning 闭式推导

DeltaNet(定向替换写入)

状态更新公式:

$$ \mathbf{S}_t=\mathbf{S}_{t-1}(\mathbf{I}-\beta_t\mathbf{k}_t\mathbf{k}_t^T)+\beta_t\mathbf{v}_t\mathbf{k}_t^T $$

特征:

  • 对当前 key 对应方向做定向擦写 + 写入
  • 关联记忆能力强(检索类任务常更好)
  • 但缺少全局快速遗忘机制

Gated DeltaNet(两者融合)

状态更新公式:

$$ \mathbf{S}_t=\mathbf{S}_{t-1}\left(\alpha_t(\mathbf{I}-\beta_t\mathbf{k}_t\mathbf{k}_t^T)\right)+\beta_t\mathbf{v}_t\mathbf{k}_t^T $$

它结合了两种能力:

  • $\alpha_t$:衰减陈旧信息
  • $\beta_t$ + delta 项:针对 key 做精细写入与替换

这也是论文实验里它在需要“记忆 + 过滤”同时成立的场景下表现稳定的核心原因。

Qwen3-Next 架构里的 Gated Delta Net(简述)

modular_qwen3_next.py 看,Qwen3-Next 是混合层设计:

  • Qwen3NextDecoderLayer 里,layer_type 可为
    • linear_attention(即 Qwen3NextGatedDeltaNet
    • full_attention(标准注意力分支)
  • FFN 侧可配 MoE(Qwen3NextSparseMoeBlock)或普通 MLP
  • 因此整体是一个 Attention / Linear-Attention 混合 + MoE 可选 的解码器结构

可以把它理解成:

  • 用一部分 Gated Delta Net 层换取长序列效率与记忆管理能力;
  • 用全注意力层保留强建模上限;
  • 再通过 MoE 扩展容量与性价比。

参考链接