Zhangzhe's Blog

The projection of my life.

0%

HiPPO: Recurrent Memory with Optimal Polynomial Projections

URL

TL;DR

  • HiPPO 全称是 High-order Polynomial Projection Operators,是 SSM (State Space Model) 的开山之作,作者延续 SSM 思路,后续做出了可以和 Transformer 结构掰手腕的 Mamba

  • HiPPO 的目标是用一个有限维的向量来储存这一段 u(t) 的信息,实现方式是将 u(t) 通过 Legendre (勒让德)多项式展开,用有限维的向量存储勒让德多项式系数,且这些 向量的值通过求解勒让德多项式得出,不在训练过程中通过梯度下降更新

  • HiPPO 可以给 RNN 提供一种记忆表示方法,因此一个实际的用处是使用 HiPPO 作为 RNN 的记忆存储算子

Algorithm

  • 勒让德多项式本身是定义在连续函数上的,但实际使用中需要记忆的内容是离散的,因此需要离散化过程

  • HiPPO 的记忆更新公式是 $m(t+1) = Am(t)+Bu(t)$,其中 ABHiPPO 参数,m(t) 表示记忆向量,u(t) 表示更新向量,m(t+1) 表示更新后的记忆向量

    • $A\in\mathbb R^{N\times N}$
    • $B\in\mathbb R^{N\times 1}$
    • N 表示记忆单元的参数尺度,类似于 Transformerhidden size,越大记忆能力越强
  • HiPPOLegT (Translated Legendre Measure)LegS (Scaled Legendre Measure) 两种度量方法,二者都使用上述记忆更新公式,只是 AB 参数不同

    • LegT 使用 翻译 任务的勒让德多项式,本质是一个滑动窗口,只记忆当前时刻前 $\theta$ 窗口内容,$\theta$ 为超参数

      • $A_{nk}=\frac{1}{\theta}
        \begin{cases}
        (-1)^{n-k}(2n+1) & if & n\ge k\\
        2n+1 & if & n < k
        \end{cases}$
      • $B_n=\frac{1}{\theta}(2n+1)(-1)^n$
    • LegS 使用 缩放 的勒让德多项式,记忆全部时刻的序列内容

      • $A_{nk}=\frac{1}{\theta}
        \begin{cases}
        (2n+1)^{\frac{1}{2}}(2k+1)^{\frac{1}{2}} & if & n\gt k\\
        n+1 & if & n = k\\
        0 & if & n < k
        \end{cases}$
      • $B_n=(2n+1)^{\frac{1}{2}}$
  • Permute Mnist 分类任务的例子讲解 HiPPO 如何作为 RNN 的单元参与计算,以及HiPPO 的记忆单元如何更新

  • Permute Mnist 任务是将 28x28Mnist 图像的每一类按照同一种 pattern 进行 shuffle,训练并分类

  • 下图为使用 HiPPO 作为记忆单元的 RNN 网络解决 Permute Mnist 任务的计算过程,input_t 是每次顺序输入图片的一个像素值,是一个时间步总长为 28 * 28 = 784RNN 网络,最后一个 hidden state 输出映射到 class dim 上进行分类

graph TD
    subgraph input;
    input_t([input_t]);
    h_t;
    end;
    subgraph fully_connect;
    W_hxm;
    W_gxm;
    W_uxh;
    end;
    input_t([input_t])-->|1|Concat_1[Concat];
    h_t([h_t])-->|512|Concat_1-->|513|W_uxh-->|1|u_t([u_t]);
    subgraph update_memory;
    A([A])-->|max_length, 512, 512|get_index_A[get_index]-->|512, 512|A_t;
    timestep([timestep])-->get_index_A;
    A_t([A_t])-->|512, 512|MatMul_A[MatMul];
    m_t([m_t])-->|1, 512|MatMul_A-->|1, 512|Add;
    m_t([m_t])-->|1, 512|Add;
    B([B])-->|max_length, 512|get_index_B[get_index]-->|512, 1|B_t;
    timestep([timestep])-->get_index_B;
    B_t([B_t])-->|512, 1|MatMul_B[MatMul];
    u_t([u_t])-->|1|MatMul_B-->|1, 512|Add-->|1, 512|m_t+1([m_t+1]);
    end;
    m_t+1-->|512|Concat_2;
    input_t-->|1|Concat_2[Concat]-->|513|W_hxm-->|512|Tanh-->|512|hidden([hidden]);
    Concat_2-->|513|W_gxm-->|512|gate([gate]);
    h_t-->|512|Alpha_Blending-->|512|h_t+1([h_t+1])-->|512|until_last_h{until_last_h};
    hidden-->|512|Alpha_Blending;
    gate-->|512|Alpha_Blending;
    subgraph output;
    until_last_h-->|512|map_to_class_dim-->|10|classification_result([classification_result]);
    end;

Thought

  • LegTLegS 的参数计算过程需要较强的数学功底才能完全理解
  • 如果只把 AB 当做 万能的不需要梯度下降更新的神经网络记忆力更新参数,那么实际上并不复杂