URL
TL;DR
HiPPO
全称是High-order Polynomial Projection Operators
,是SSM (State Space Model)
的开山之作,作者延续SSM
思路,后续做出了可以和Transformer
结构掰手腕的Mamba
HiPPO
的目标是用一个有限维的向量来储存这一段u(t)
的信息,实现方式是将u(t)
通过Legendre
(勒让德)多项式展开,用有限维的向量存储勒让德多项式系数,且这些 向量的值通过求解勒让德多项式得出,不在训练过程中通过梯度下降更新HiPPO
可以给RNN
提供一种记忆表示方法,因此一个实际的用处是使用HiPPO
作为RNN
的记忆存储算子
Algorithm
勒让德多项式本身是定义在连续函数上的,但实际使用中需要记忆的内容是离散的,因此需要离散化过程
HiPPO
的记忆更新公式是 $m(t+1) = Am(t)+Bu(t)$,其中A
和B
是HiPPO
参数,m(t)
表示记忆向量,u(t)
表示更新向量,m(t+1)
表示更新后的记忆向量- $A\in\mathbb R^{N\times N}$
- $B\in\mathbb R^{N\times 1}$
N
表示记忆单元的参数尺度,类似于Transformer
的hidden size
,越大记忆能力越强
HiPPO
有LegT (Translated Legendre Measure)
和LegS (Scaled Legendre Measure)
两种度量方法,二者都使用上述记忆更新公式,只是A
和B
参数不同LegT
使用 翻译 任务的勒让德多项式,本质是一个滑动窗口,只记忆当前时刻前 $\theta$ 窗口内容,$\theta$ 为超参数- $A_{nk}=\frac{1}{\theta}
\begin{cases}
(-1)^{n-k}(2n+1) & if & n\ge k\\
2n+1 & if & n < k
\end{cases}$ - $B_n=\frac{1}{\theta}(2n+1)(-1)^n$
- $A_{nk}=\frac{1}{\theta}
LegS
使用 缩放 的勒让德多项式,记忆全部时刻的序列内容- $A_{nk}=\frac{1}{\theta}
\begin{cases}
(2n+1)^{\frac{1}{2}}(2k+1)^{\frac{1}{2}} & if & n\gt k\\
n+1 & if & n = k\\
0 & if & n < k
\end{cases}$ - $B_n=(2n+1)^{\frac{1}{2}}$
- $A_{nk}=\frac{1}{\theta}
以
Permute Mnist
分类任务的例子讲解HiPPO
如何作为RNN
的单元参与计算,以及HiPPO
的记忆单元如何更新Permute Mnist
任务是将28x28
的Mnist
图像的每一类按照同一种pattern
进行shuffle
,训练并分类下图为使用
HiPPO
作为记忆单元的RNN
网络解决Permute Mnist
任务的计算过程,input_t
是每次顺序输入图片的一个像素值,是一个时间步总长为28 * 28 = 784
的RNN
网络,最后一个hidden state
输出映射到class dim
上进行分类
graph TD subgraph input; input_t([input_t]); h_t; end; subgraph fully_connect; W_hxm; W_gxm; W_uxh; end; input_t([input_t])-->|1|Concat_1[Concat]; h_t([h_t])-->|512|Concat_1-->|513|W_uxh-->|1|u_t([u_t]); subgraph update_memory; A([A])-->|max_length, 512, 512|get_index_A[get_index]-->|512, 512|A_t; timestep([timestep])-->get_index_A; A_t([A_t])-->|512, 512|MatMul_A[MatMul]; m_t([m_t])-->|1, 512|MatMul_A-->|1, 512|Add; m_t([m_t])-->|1, 512|Add; B([B])-->|max_length, 512|get_index_B[get_index]-->|512, 1|B_t; timestep([timestep])-->get_index_B; B_t([B_t])-->|512, 1|MatMul_B[MatMul]; u_t([u_t])-->|1|MatMul_B-->|1, 512|Add-->|1, 512|m_t+1([m_t+1]); end; m_t+1-->|512|Concat_2; input_t-->|1|Concat_2[Concat]-->|513|W_hxm-->|512|Tanh-->|512|hidden([hidden]); Concat_2-->|513|W_gxm-->|512|gate([gate]); h_t-->|512|Alpha_Blending-->|512|h_t+1([h_t+1])-->|512|until_last_h{until_last_h}; hidden-->|512|Alpha_Blending; gate-->|512|Alpha_Blending; subgraph output; until_last_h-->|512|map_to_class_dim-->|10|classification_result([classification_result]); end;
Thought
LegT
和LegS
的参数计算过程需要较强的数学功底才能完全理解- 如果只把
A
和B
当做 万能的不需要梯度下降更新的神经网络记忆力更新参数,那么实际上并不复杂