From 5ed7b869df3cad24b35b7c1afe9aa4ae9a2d1bf1 Mon Sep 17 00:00:00 2001
From: Brian Hie
Date: Sun, 18 Feb 2024 16:26:18 +0000
Subject: [PATCH 1/2] modification that introduces small source of error compared to savanna

---
 src/engine.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/engine.py b/src/engine.py
index 18775e1..72df9f9 100644
--- a/src/engine.py
+++ b/src/engine.py
@@ -84,12 +84,15 @@ def parallel_fir(
         z_pre = fir_fn(
             u,
             weight,
-            bias,
+            bias=None,  # don't pass it here, add manually instead! source of small error
             stride=1,
             padding=fir_length - 1,
             groups=u.shape[1],
         )[..., :L]
 
+        # add manually instead! source of small error
+        z_pre = z_pre + bias[None, :, None]
+
         # handle padding post fir, the only place with biases
         if type(padding_mask) == torch.Tensor:
             z_pre = z_pre * padding_mask[:, None]

From 30a95f89c1a2b42e2608c3c848f4491f0091bc6d Mon Sep 17 00:00:00 2001
From: Brian Hie
Date: Sun, 18 Feb 2024 21:38:33 +0000
Subject: [PATCH 2/2] black

---
 src/engine.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/engine.py b/src/engine.py
index 72df9f9..809f3be 100644
--- a/src/engine.py
+++ b/src/engine.py
@@ -292,9 +292,9 @@ def prefill_via_direct_recurrence(
         # x1v_: b, d, l, sdim, reim
         for i in range(L):
             state[..., 0] = poles[..., 0] * state[..., 0] - poles[..., 1] * state[..., 1] + x1v_[:, :, i, :, 0]
-            state[..., 1] = poles[..., 0] * state[..., 1] + poles[..., 1] * state[..., 0] + x1v_[:, :, i, :, 1]
+            state[..., 1] = poles[..., 0] * state[..., 1] + poles[..., 1] * state[..., 0] + x1v_[:, :, i, :, 1]
             output[:, :, i] = torch.sum(residues * state, dim=-2)[..., 0]  # .real
-
+
         inference_params.state_dict[self.layer_idx] = torch.view_as_complex(state.to(dtype=torch.float32))
 
         return output
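
Illustration (not part of the patch series): a minimal sketch of why moving the bias out of the FIR call can change results by a few ULPs even though the math is identical. It assumes fir_fn behaves like a depthwise torch.nn.functional.conv1d with made-up shapes; the actual fir_fn may be a fused kernel, where the accumulation order and precision differ more.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, d, L, fir_length = 2, 8, 32, 3          # toy batch, width, length, filter size (assumption)
u = torch.randn(b, d, L)
weight = torch.randn(d, 1, fir_length)      # depthwise FIR taps
bias = torch.randn(d)

# Original path: bias added inside the convolution call.
z_fused = F.conv1d(u, weight, bias=bias, stride=1, padding=fir_length - 1, groups=d)[..., :L]

# Patched path: convolve with bias=None, then add the bias as a separate elementwise op.
z_manual = F.conv1d(u, weight, bias=None, stride=1, padding=fir_length - 1, groups=d)[..., :L]
z_manual = z_manual + bias[None, :, None]

# Mathematically identical, but the summation order (and, for fused kernels, the
# accumulation precision) differs, so the outputs need not be bit-identical.
print(torch.allclose(z_fused, z_manual, atol=1e-6))   # True
print((z_fused - z_manual).abs().max())               # tiny, but possibly nonzero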
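
Illustration (not part of the patch series): a self-contained sketch of the recurrence touched by the second hunk, with made-up shapes; poles, residues, and x1v_ stand in for the layer's actual tensors. The trailing axis of size 2 holds the real and imaginary parts, so the two assignments spell out a complex multiply-accumulate per step.

import torch

b, d, L, sdim = 2, 4, 6, 8                  # toy batch, width, seq length, state dim (assumption)
poles = 0.1 * torch.randn(1, d, sdim, 2)    # per-channel poles, (real, imag) in the last axis
residues = torch.randn(1, d, sdim, 2)       # readout weights
x1v_ = torch.randn(b, d, L, sdim, 2)        # per-step input projected onto the state
state = torch.zeros(b, d, sdim, 2)
output = torch.zeros(b, d, L)

for i in range(L):
    # Intended update: state <- poles * state + x_i, expanded via
    # (a + bi)(c + di) = (ac - bd) + (ad + bc)i; note state[..., 0] is
    # overwritten in place before the second line reads it, exactly as in the hunk.
    state[..., 0] = poles[..., 0] * state[..., 0] - poles[..., 1] * state[..., 1] + x1v_[:, :, i, :, 0]
    state[..., 1] = poles[..., 0] * state[..., 1] + poles[..., 1] * state[..., 0] + x1v_[:, :, i, :, 1]
    # Readout: real part of the residue-weighted sum over the state dimension.
    output[:, :, i] = torch.sum(residues * state, dim=-2)[..., 0]

# torch.view_as_complex needs a float32/float64 tensor whose last dim has size 2,
# which is why the cached state is cast before being packed into complex form.
cached = torch.view_as_complex(state.to(dtype=torch.float32))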