URL
TL;DR
HiPPO
全称是High-order Polynomial Projection Operators
,是SSM (State Space Model)
的开山之作,之后作者延续SSM
思路,做出了可以和Transformer
结构掰手腕的Mamba
HiPPO
的目标是用一个有限维的向量来储存这一段u(t)
的信息,实现方式是将u(t)
通过Legendre
(勒让德)多项式展开,用有限维的向量存储勒让德多项式系数,且这些 向量的值通过求解勒让德多项式得出,不在训练过程中通过梯度下降更新HiPPO
可以给RNN
提供一种记忆表示方法,因此一个实际的用处是使用HiPPO
作为RNN
的记忆存储算子
Algorithm
- 勒让德多项式本身是定义在连续函数上的,但实际使用中需要记忆的内容是离散的,因此需要离散化过程
HiPPO
的记忆更新公式是 ,其中A
和B
是HiPPO
参数,m(t)
表示记忆向量,u(t)
表示更新向量,m(t+1)
表示更新后的记忆向量N
表示记忆单元的参数尺度,类似于Transformer
的hidden size
,越大记忆能力越强
HiPPO
有LegT (Translated Legendre Measure)
和LegS (Scaled Legendre Measure)
两种度量方法,二者都使用上述记忆更新公式,只是A
和B
参数不同LegT
使用 翻译 任务的勒让德多项式,本质是一个滑动窗口,只记忆当前时刻前 窗口内容, 为超参数LegS
使用 缩放 的勒让德多项式,记忆全部时刻的序列内容
- 以
Permute Mnist
分类任务的例子讲解HiPPO
如何作为RNN
的单元参与计算,以及HiPPO
的记忆单元如何更新 Permute Mnist
任务是将28x28
的Mnist
图像的每一类按照同一种pattern
进行shuffle
,训练并分类- 下图为使用
HiPPO
作为记忆单元的RNN
网络解决Permute Mnist
任务的计算过程,input_t
是每次顺序输入图片的一个像素值,是一个时间步总长为28 * 28 = 784
的RNN
网络,最后一个hidden state
输出映射到class dim
上进行分类
graph TD subgraph input; input_t([input_t]); h_t; end; subgraph fully_connect; W_hxm; W_gxm; W_uxh; end; h_t([h_t])-->|512|Concat_1-->|513|W_uxh-->|1|u_t([u_t]); input_t([input_t])-->|1|Concat_1[Concat]; subgraph update_memory; A([A])-->|max_length, 512, 512|get_index_A[get_index]-->|512, 512|A_t; timestep([timestep])-->get_index_A; A_t([A_t])-->|512, 512|MatMul_A[MatMul]; m_t([m_t])-->|1, 512|MatMul_A-->|1, 512|Add; m_t([m_t])-->|1, 512|Add; timestep([timestep])-->get_index_B; B([B])-->|max_length, 512|get_index_B[get_index]-->|512, 1|B_t; B_t([B_t])-->|512, 1|MatMul_B[MatMul]; u_t([u_t])-->|1|MatMul_B-->|1, 512|Add-->|1, 512|m_t+1([m_t+1]); end; m_t+1-->|512|Concat_2; input_t-->|1|Concat_2[Concat]-->|513|W_hxm-->|512|Tanh-->|512|hidden([hidden]); Concat_2-->|513|W_gxm-->|512|gate([gate]); h_t-->|512|Alpha_Blending-->|512|h_t+1([h_t+1])-->|512|until_last_h{until_last_h}; hidden-->|512|Alpha_Blending; gate-->|512|Alpha_Blending; subgraph output; until_last_h-->|512|map_to_class_dim-->|10|classification_result([classification_result]); end;
Thought
LegT
和LegS
的参数计算过程需要较强的数学功底才能完全理解- 如果只把
A
和B
当做 万能的不需要梯度下降更新的神经网络记忆力更新参数,那么实际上并不复杂