diff --git a/nemo/collections/nlp/modules/common/hyena/hyena.py b/nemo/collections/nlp/modules/common/hyena/hyena.py
index f90ae680db311..4e4711e41eca2 100644
--- a/nemo/collections/nlp/modules/common/hyena/hyena.py
+++ b/nemo/collections/nlp/modules/common/hyena/hyena.py
@@ -345,13 +345,17 @@ def forward(self, u, *args, **kwargs):
 
         uc = self.short_filter(u)[..., :l_filter]
 
-        uc = rearrange(uc, 'b (ho v) (z l) -> b ho v z l',
-                       z=self.num_blocks,
-                       ho=self.num_heads,
-                       v=self.head_dim * (self.order + 1)
-                       )
+        # Workaround for shape error in fftconv, based on:
+        # https://github.com/HazyResearch/safari/issues/26#issuecomment-1589018138
 
-        *x, v = uc.split(self.d_model, dim=2)
+        # uc = rearrange(uc, 'b (ho v) (z l) -> b ho v z l',
+        #                z=self.num_blocks,
+        #                ho=self.num_heads,
+        #                v=self.head_dim * (self.order + 1)
+        #                )
+
+        # *x, v = uc.split(self.d_model, dim=2)
+        *x, v = uc.split(self.d_model, dim=1)
         k = self.filter_fn.filter(l_filter)
 
         # `c` is always 1 by default
@@ -370,7 +374,8 @@ def forward(self, u, *args, **kwargs):
                 v = self.dropout(v * x_i)
 
             # the bias term is broadcasted. Last dimension (l) is handled by fftconv
-            v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o, None, :, None])
+            # v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o, None, :, None])
+            v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o])
 
             if self.post_order_ffn:
                 w = self.ord_proj_w[o]
@@ -378,7 +383,8 @@ def forward(self, u, *args, **kwargs):
                     rearrange(w, 'h1 h2 -> 1 h1 h2 1 1 1'), rearrange(v, 'b h v z l -> b h 1 v z l')
                 )
 
-        y = self.activation(rearrange(v * x[0], 'b h v z l -> b (z l) (h v)', z=self.num_blocks, h=self.num_heads))
+        # y = self.activation(rearrange(v * x[0], 'b h v z l -> b (z l) (h v)', z=self.num_blocks, h=self.num_heads))
+        y = self.activation((v * x[0]).transpose(-2, -1))
         y = self.out_proj(y)
 
         # Convert back to sequence-first for MCore