diff --git a/model/hippo.py b/model/hippo.py index f760f11..fd10eb1 100644 --- a/model/hippo.py +++ b/model/hippo.py @@ -12,21 +12,32 @@ from model import unroll from model.op import transition -# forward_aliases = ['euler', 'forward_euler', 'forward', 'forward_diff'] -# backward_aliases = ['backward', 'backward_diff', 'backward_euler'] -# bilinear_aliases = ['bilinear', 'tustin', 'trapezoidal', 'trapezoid'] -# zoh_aliases = ['zoh'] +""" +The HiPPO_LegT and HiPPO_LegS modules satisfy the HiPPO interface: + +The forward() method takes an input sequence f of length L to an output sequence c of shape (L, N) where N is the order of the HiPPO operator. +c[k] can be thought of as representing all of f[:k] via coefficients of a polynomial approximation. + +The reconstruct() method takes the coefficients and turns each coefficient into a reconstruction of the original input. +Note that each coefficient c[k] turns into an approximation of the entire input f, so this reconstruction has shape (L, L), +and the last element of this reconstruction (which has shape (L,)) is the most accurate reconstruction of the original input. + +Both of these two methods construct approximations according to different measures, defined in the HiPPO paper. +The first one is the "Translated Legendre" (which is up to scaling equal to the LMU matrix), +and the second one is the "Scaled Legendre". +Each method comprises an exact recurrence c_k = A_k c_{k-1} + B_k f_k, and an exact reconstruction formula based on the corresponding polynomial family. +""" class HiPPO_LegT(nn.Module): - def __init__(self, N, dt=1.0, measure='legt', discretization='bilinear'): + def __init__(self, N, dt=1.0, discretization='bilinear'): """ N: the order of the HiPPO projection dt: discretization step size - should be roughly inverse to the length of the sequence """ super().__init__() self.N = N - A, B = transition(measure, N) + A, B = transition('lmu', N) C = np.ones((1, N)) D = np.zeros((1,)) # dt, discretization options @@ -171,7 +182,7 @@ def plot(): f, _ = next(it) f = f.squeeze(0).squeeze(-1) - legt = HiPPO_LegT(N, 1./T, measure='lmu') + legt = HiPPO_LegT(N, 1./T) f_legt = legt.reconstruct(legt(f))[-1] legs = HiPPO_LegS(N, T) f_legs = legs.reconstruct(legs(f))[-1]