URL
TL;DR
HiPPO
全称是 High-order Polynomial Projection Operators
,是 SSM (State Space Model)
的开山之作,作者延续 SSM
思路,后续做出了可以和 Transformer
结构掰手腕的 Mamba
HiPPO
的目标是用一个有限维的向量来储存这一段 u(t)
的信息,实现方式是将 u(t)
通过 Legendre
(勒让德)多项式展开,用有限维的向量存储勒让德多项式系数,且这些 向量的值通过求解勒让德多项式得出,不在训练过程中通过梯度下降更新
HiPPO
可以给 RNN
提供一种记忆表示方法,因此一个实际的用处是使用 HiPPO
作为 RNN
的记忆存储算子
Algorithm
- 勒让德多项式本身是定义在连续函数上的,但实际使用中需要记忆的内容是离散的,因此需要离散化过程
HiPPO
的记忆更新公式是 m(t+1)=Am(t)+Bu(t),其中 A
和 B
是 HiPPO
参数,m(t)
表示记忆向量,u(t)
表示更新向量,m(t+1)
表示更新后的记忆向量
- A∈RN×N
- B∈RN×1
N
表示记忆单元的参数尺度,类似于 Transformer
的 hidden size
,越大记忆能力越强
HiPPO
有 LegT (Translated Legendre Measure)
和 LegS (Scaled Legendre Measure)
两种度量方法,二者都使用上述记忆更新公式,只是 A
和 B
参数不同
LegT
使用 翻译 任务的勒让德多项式,本质是一个滑动窗口,只记忆当前时刻前 θ 窗口内容,θ 为超参数
- Ank=θ1{(−1)n−k(2n+1)2n+1ififn≥kn<k
- Bn=θ1(2n+1)(−1)n
LegS
使用 缩放 的勒让德多项式,记忆全部时刻的序列内容
- Ank=θ1⎩⎪⎪⎨⎪⎪⎧(2n+1)21(2k+1)21n+10ifififn>kn=kn<k
- Bn=(2n+1)21
- 以
Permute Mnist
分类任务的例子讲解 HiPPO
如何作为 RNN
的单元参与计算,以及HiPPO
的记忆单元如何更新
Permute Mnist
任务是将 28x28
的 Mnist
图像的每一类按照同一种 pattern
进行 shuffle
,训练并分类
- 下图为使用
HiPPO
作为记忆单元的 RNN
网络解决 Permute Mnist
任务的计算过程,input_t
是每次顺序输入图片的一个像素值,是一个时间步总长为 28 * 28 = 784
的 RNN
网络,最后一个 hidden state
输出映射到 class dim
上进行分类
graph TD
subgraph input;
input_t([input_t]);
h_t;
end;
subgraph fully_connect;
W_hxm;
W_gxm;
W_uxh;
end;
input_t([input_t])-->|1|Concat_1[Concat];
h_t([h_t])-->|512|Concat_1-->|513|W_uxh-->|1|u_t([u_t]);
subgraph update_memory;
A([A])-->|max_length, 512, 512|get_index_A[get_index]-->|512, 512|A_t;
timestep([timestep])-->get_index_A;
A_t([A_t])-->|512, 512|MatMul_A[MatMul];
m_t([m_t])-->|1, 512|MatMul_A-->|1, 512|Add;
m_t([m_t])-->|1, 512|Add;
B([B])-->|max_length, 512|get_index_B[get_index]-->|512, 1|B_t;
timestep([timestep])-->get_index_B;
B_t([B_t])-->|512, 1|MatMul_B[MatMul];
u_t([u_t])-->|1|MatMul_B-->|1, 512|Add-->|1, 512|m_t+1([m_t+1]);
end;
m_t+1-->|512|Concat_2;
input_t-->|1|Concat_2[Concat]-->|513|W_hxm-->|512|Tanh-->|512|hidden([hidden]);
Concat_2-->|513|W_gxm-->|512|gate([gate]);
h_t-->|512|Alpha_Blending-->|512|h_t+1([h_t+1])-->|512|until_last_h{until_last_h};
hidden-->|512|Alpha_Blending;
gate-->|512|Alpha_Blending;
subgraph output;
until_last_h-->|512|map_to_class_dim-->|10|classification_result([classification_result]);
end;
Thought
LegT
和 LegS
的参数计算过程需要较强的数学功底才能完全理解
- 如果只把
A
和 B
当做 万能的不需要梯度下降更新的神经网络记忆力更新参数,那么实际上并不复杂