URL
TL;DR
HiPPO 全称是 High-order Polynomial Projection Operators,是 SSM (State Space Model) 的开山之作,之后作者延续 SSM 思路,做出了可以和 Transformer 结构掰手腕的 Mamba
HiPPO 的目标是用一个有限维的向量来储存这一段 u(t) 的信息,实现方式是将 u(t) 通过 Legendre (勒让德)多项式展开,用有限维的向量存储勒让德多项式系数,且这些 向量的值通过求解勒让德多项式得出,不在训练过程中通过梯度下降更新
HiPPO 可以给 RNN 提供一种记忆表示方法,因此一个实际的用处是使用 HiPPO 作为 RNN 的记忆存储算子
Algorithm
- 勒让德多项式本身是定义在连续函数上的,但实际使用中需要记忆的内容是离散的,因此需要离散化过程
HiPPO 的记忆更新公式是 m(t+1)=Am(t)+Bu(t),其中 A 和 B 是 HiPPO 参数,m(t) 表示记忆向量,u(t) 表示更新向量,m(t+1) 表示更新后的记忆向量
- A∈RN×N
- B∈RN×1
N 表示记忆单元的参数尺度,类似于 Transformer 的 hidden size,越大记忆能力越强
HiPPO 有 LegT (Translated Legendre Measure) 和 LegS (Scaled Legendre Measure) 两种度量方法,二者都使用上述记忆更新公式,只是 A 和 B 参数不同
LegT 使用 翻译 任务的勒让德多项式,本质是一个滑动窗口,只记忆当前时刻前 θ 窗口内容,θ 为超参数
- Ank=θ1{(−1)n−k(2n+1)2n+1ififn≥kn<k
- Bn=θ1(2n+1)(−1)n
LegS 使用 缩放 的勒让德多项式,记忆全部时刻的序列内容
- Ank=θ1⎩⎪⎪⎨⎪⎪⎧(2n+1)21(2k+1)21n+10ifififn>kn=kn<k
- Bn=(2n+1)21
- 以
Permute Mnist 分类任务的例子讲解 HiPPO 如何作为 RNN 的单元参与计算,以及HiPPO 的记忆单元如何更新
Permute Mnist 任务是将 28x28 的 Mnist 图像的每一类按照同一种 pattern 进行 shuffle,训练并分类
- 下图为使用
HiPPO 作为记忆单元的 RNN 网络解决 Permute Mnist 任务的计算过程,input_t 是每次顺序输入图片的一个像素值,是一个时间步总长为 28 * 28 = 784 的 RNN 网络,最后一个 hidden state 输出映射到 class dim 上进行分类
graph TD
subgraph input;
input_t([input_t]);
h_t;
end;
subgraph fully_connect;
W_hxm;
W_gxm;
W_uxh;
end;
h_t([h_t])-->|512|Concat_1-->|513|W_uxh-->|1|u_t([u_t]);
input_t([input_t])-->|1|Concat_1[Concat];
subgraph update_memory;
A([A])-->|max_length, 512, 512|get_index_A[get_index]-->|512, 512|A_t;
timestep([timestep])-->get_index_A;
A_t([A_t])-->|512, 512|MatMul_A[MatMul];
m_t([m_t])-->|1, 512|MatMul_A-->|1, 512|Add;
m_t([m_t])-->|1, 512|Add;
timestep([timestep])-->get_index_B;
B([B])-->|max_length, 512|get_index_B[get_index]-->|512, 1|B_t;
B_t([B_t])-->|512, 1|MatMul_B[MatMul];
u_t([u_t])-->|1|MatMul_B-->|1, 512|Add-->|1, 512|m_t+1([m_t+1]);
end;
m_t+1-->|512|Concat_2;
input_t-->|1|Concat_2[Concat]-->|513|W_hxm-->|512|Tanh-->|512|hidden([hidden]);
Concat_2-->|513|W_gxm-->|512|gate([gate]);
h_t-->|512|Alpha_Blending-->|512|h_t+1([h_t+1])-->|512|until_last_h{until_last_h};
hidden-->|512|Alpha_Blending;
gate-->|512|Alpha_Blending;
subgraph output;
until_last_h-->|512|map_to_class_dim-->|10|classification_result([classification_result]);
end;
Thought
LegT 和 LegS 的参数计算过程需要较强的数学功底才能完全理解
- 如果只把
A 和 B 当做 万能的不需要梯度下降更新的神经网络记忆力更新参数,那么实际上并不复杂