diff --git a/src/engine.py b/src/engine.py index 18775e1..809f3be 100644 --- a/src/engine.py +++ b/src/engine.py @@ -84,12 +84,15 @@ def parallel_fir( z_pre = fir_fn( u, weight, - bias, + bias=None, # do not pass the bias to fir_fn (that was a source of small error); it is added manually below stride=1, padding=fir_length - 1, groups=u.shape[1], )[..., :L] + # add the bias manually here instead of inside fir_fn, which was a source of small error + z_pre = z_pre + bias[None, :, None] + # handle padding post fir, the only place with biases if type(padding_mask) == torch.Tensor: z_pre = z_pre * padding_mask[:, None] @@ -289,9 +292,9 @@ def prefill_via_direct_recurrence( # x1v_: b, d, l, sdim, reim for i in range(L): state[..., 0] = poles[..., 0] * state[..., 0] - poles[..., 1] * state[..., 1] + x1v_[:, :, i, :, 0] - state[..., 1] = poles[..., 0] * state[..., 1] + poles[..., 1] * state[..., 0] + x1v_[:, :, i, :, 1] + state[..., 1] = poles[..., 0] * state[..., 1] + poles[..., 1] * state[..., 0] + x1v_[:, :, i, :, 1] output[:, :, i] = torch.sum(residues * state, dim=-2)[..., 0] # .real - + inference_params.state_dict[self.layer_idx] = torch.view_as_complex(state.to(dtype=torch.float32)) return output