Skip to content

Commit

Permalink
Hyena Operator (NVIDIA#9264)
Browse files Browse the repository at this point in the history
* Initial reference code commit, unchanged

Signed-off-by: Guy Jacob <[email protected]>

* Hyena code changes for NeMO compatibility

Signed-off-by: Guy Jacob <[email protected]>

* MCore spec override functionality + example config w. hyena

Signed-off-by: Guy Jacob <[email protected]>

* Additional changes - now working on char-level TinyShakespeare

* Add missing input LayerNorm to spec (in the default attention
  spec it's fused with the projection Linear layer, so not
  explicitly defined)
* Shape conversion at start and end of Hyena forward

Signed-off-by: Guy Jacob <[email protected]>

* Add fftconv cuda impl from safari

Signed-off-by: Guy Jacob <[email protected]>

* Workaround for shape error in fftconv

See: HazyResearch/safari#26 (comment)
Signed-off-by: Guy Jacob <[email protected]>

* Explicitly convert kernel to FP32

(torch.fft doesn't support bf16)

Signed-off-by: Guy Jacob <[email protected]>

* Working run configs

Signed-off-by: Guy Jacob <[email protected]>

* Remove sharded_state_dict from HyenaOperator

(made redundant by the default inmplementation in Megatron)

Signed-off-by: Guy Jacob <[email protected]>

* Update configs

Signed-off-by: Guy Jacob <[email protected]>

* Testing TE Linear classes in HyenaOperator

Signed-off-by: Guy Jacob <[email protected]>

* Revert to FusedDense for in/out projections after merging with 24.01.01

Signed-off-by: Guy Jacob <[email protected]>

* Fix bug (use fused LNorm+Linear), bring back TE layers

Signed-off-by: Guy Jacob <[email protected]>

* Configs rename + cleanup

Signed-off-by: Guy Jacob <[email protected]>

* FlashFFTConv, Multi-head, some cleanup

Signed-off-by: Guy Jacob <[email protected]>

* Bug fix - init FlashFFTConv with 2*seq_len

Signed-off-by: Guy Jacob <[email protected]>

* ModuleSpec + replace nn.Conv1d with causal_conv1d

Signed-off-by: Guy Jacob <[email protected]>

* Remove unneeded arguments

Signed-off-by: Guy Jacob <[email protected]>

* More cleanup, remove fftconv ref functions

Signed-off-by: Guy Jacob <[email protected]>

* Refactor HyenaFilter + more cleanup

* Refactor in spirit of implementation in MAD-Lab repo:
  https://github.com/athms/mad-lab/blob/main/mad/model/layers/hyena.py

Signed-off-by: Guy Jacob <[email protected]>

* Add missing attributions

Signed-off-by: Guy Jacob <[email protected]>

* Remove fftconv sources

Signed-off-by: Guy Jacob <[email protected]>

* Bug fixes

Signed-off-by: Guy Jacob <[email protected]>

* Remove d_model from external API, take from TransformerConfig

Signed-off-by: Guy Jacob <[email protected]>

* cleanup config

Signed-off-by: Guy Jacob <[email protected]>

* Remove spec override logic (possibly push separately)

Signed-off-by: Guy Jacob <[email protected]>

* Add tests

Signed-off-by: Guy Jacob <[email protected]>

* Keep only megatron_gpt_config_hyena (w. 153m parameters)

Signed-off-by: Guy Jacob <[email protected]>

* Black + isort formatting changes

Signed-off-by: Guy Jacob <[email protected]>

* Fixes following PR review

* Clearer names + more documentation for config params
* Clearer README
* Check seq len < 8K with safari-fftconv
* Avoid 0*bias op during forward

Signed-off-by: Guy Jacob <[email protected]>

* Fix tests following param name changes

Signed-off-by: Guy Jacob <[email protected]>

---------

Signed-off-by: Guy Jacob <[email protected]>
  • Loading branch information
guyjacob authored Jun 13, 2024
1 parent 937c8d4 commit e56edc9
Show file tree
Hide file tree
Showing 9 changed files with 1,217 additions and 1 deletion.
277 changes: 277 additions & 0 deletions examples/nlp/language_modeling/conf/megatron_gpt_config_hyena.yaml

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec
from nemo.collections.nlp.models.language_modeling.megatron.gpt_model import GPTModel
from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel
from nemo.collections.nlp.modules.common.hyena.hyena_spec import get_gpt_layer_with_te_and_hyena_spec
from nemo.collections.nlp.modules.common.megatron.build_model import build_model
from nemo.collections.nlp.modules.common.megatron.module import Float16Module
from nemo.collections.nlp.modules.common.megatron.utils import (
Expand Down Expand Up @@ -143,7 +144,7 @@ def mcore_supports_moe() -> bool:
return False


def get_specs(spec_name, num_experts=None, moe_grouped_gemm=False, use_te=True):
def get_specs(spec_name, num_experts=None, moe_grouped_gemm=False, use_te=True, hyena_cfg: Dict = None):
if num_experts is not None:
assert mcore_supports_moe(), "Megatron-core >= v0.5.0 is required for MoE"

Expand All @@ -155,6 +156,7 @@ def get_specs(spec_name, num_experts=None, moe_grouped_gemm=False, use_te=True):
"megatron_falcon_gpt": get_falcon_layer_spec(),
"megatron_gpt_full_te_layer_autocast": get_gpt_full_te_layer_autocast_spec(),
"modelopt": get_gpt_layer_modelopt_spec(),
"te_gpt_hyena": get_gpt_layer_with_te_and_hyena_spec(hyena_cfg),
}
if spec_name not in name_spec_dict:
raise ValueError(f"Spec name '{spec_name}' is not recognized.")
Expand Down Expand Up @@ -417,6 +419,7 @@ def model_provider_func(self, pre_process, post_process):
self.transformer_config.num_moe_experts,
self.transformer_config.moe_grouped_gemm,
self.transformer_engine,
self.cfg.get('hyena', None),
),
vocab_size=self.cfg.get('override_vocab_size', self.padded_vocab_size),
max_sequence_length=self.cfg.get('encoder_seq_length', 512),
Expand Down
26 changes: 26 additions & 0 deletions nemo/collections/nlp/modules/common/hyena/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
## Required Dependencies for Hyena

We depend on 3rd-party libraries for FFT convolutions implementation. Each library supports different use-cases:

| Library | Supported Sequence Length | Single/Multi-Head Support |
|:----------------:|:-------------------------:|:-------------------------:|
| Safari `fftconv` | Up to 8192 | 1 or 8 heads |
| FlashFFTConv | Up to 4M | Single-head only |

Note the overlapping support for single-head with sequence length up to 8192. By default, in this case we default to Safari `fftconv` as it is faster (and fallback to FlashFFTConv). The user may force the FFT convolution implementation used by setting the configuration key `model.hyena.fftconv_type` to either `safari` or `flash`.

### Installation

#### Safari `fftconv`

Install from the [Safari repository](https://github.com/HazyResearch/safari/tree/main/csrc/fftconv). Run the following in a terminal:

```bash
git clone https://github.com/HazyResearch/safari.git
cd safari/csrc/fftconv
pip install .
```

#### FlashFFTConv

Follow the [installation instructions](https://github.com/HazyResearch/flash-fft-conv?tab=readme-ov-file#installation) in the FlashFFTConv repository.
1 change: 1 addition & 0 deletions nemo/collections/nlp/modules/common/hyena/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from nemo.collections.nlp.modules.common.hyena.hyena import HyenaOperator
129 changes: 129 additions & 0 deletions nemo/collections/nlp/modules/common/hyena/fftconv_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import math

import torch
from einops import rearrange
from fftconv import fftconv_bwd, fftconv_fwd

# Code taken from:
# https://github.com/HazyResearch/safari/blob/main/src/ops/fftconv.py


class FFTConvFunc(torch.autograd.Function):

@staticmethod
def forward(
ctx,
u,
k,
D,
dropout_mask=None,
gelu=True,
force_fp16_output=False,
output_hbl_layout=False,
v=None,
head_dim=1,
q=None,
fftfp16=False,
k_rev=None,
):
seqlen = u.shape[-1]
fft_size = max(2 * 2 ** int(math.ceil(math.log2(seqlen))), 16)
k_f = torch.fft.rfft(k, n=fft_size)
if k_rev is not None:
k_f = k_f + torch.fft.rfft(k_rev, n=fft_size).conj()
if u.stride(-1) != 1:
u = u.contiguous()
k_f = k_f.contiguous()
D = D.contiguous()
if v is not None and v.stride(-1) != 1:
v = v.contiguous()
if q is not None and q.stride(-1) != 1:
q = q.contiguous()
if dropout_mask is not None:
dropout_mask = dropout_mask.contiguous()
ctx.save_for_backward(u, k_f, D, dropout_mask, v, q)
ctx.output_hbl_layout = output_hbl_layout
ctx.head_dim = head_dim
ctx.gelu = gelu
ctx.fftfp16 = fftfp16
ctx.has_k_rev = k_rev is not None
out = fftconv_fwd(
u,
k_f,
D,
v,
head_dim,
q,
dropout_mask,
gelu,
False,
False,
fft_size,
force_fp16_output,
output_hbl_layout,
fftfp16,
)
return out

@staticmethod
def backward(ctx, dout):
if ctx.output_hbl_layout:
dout = rearrange(rearrange(dout, 'b h l -> h b l').contiguous(), 'h b l -> b h l')
else:
dout = dout.contiguous()
u, k_f, D, dropout_mask, v, q = ctx.saved_tensors
seqlen = u.shape[-1]
fft_size = max(2 * 2 ** int(math.ceil(math.log2(seqlen))), 16)
du, dk_f, dD, dv, dq = fftconv_bwd(
dout,
u,
k_f,
D,
v,
ctx.head_dim,
q,
dropout_mask,
ctx.gelu,
False,
False,
fft_size,
ctx.output_hbl_layout,
ctx.fftfp16,
)
dk = torch.fft.irfft(dk_f, n=fft_size, norm='forward')[..., :seqlen]
dk_rev = None if not ctx.has_k_rev else torch.fft.irfft(dk_f.conj(), n=fft_size, norm='forward')[..., :seqlen]
if v is not None:
dv = dv.to(dtype=v.dtype) # We do atomicAdd in fp32 so might need to convert to fp16
return (
du,
dk,
dD,
None,
None,
None,
None,
dv,
None,
dq,
None,
dk_rev,
)


def fftconv_func(
u,
k,
D,
dropout_mask=None,
gelu=True,
force_fp16_output=False,
output_hbl_layout=False,
v=None,
head_dim=1,
q=None,
fftfp16=False,
k_rev=None,
):
return FFTConvFunc.apply(
u, k, D, dropout_mask, gelu, force_fp16_output, output_hbl_layout, v, head_dim, q, fftfp16, k_rev
)
Loading

0 comments on commit e56edc9

Please sign in to comment.