Skip to content

Commit

Permalink
Merge pull request #16 from togethercomputer/pypi-install
Browse files Browse the repository at this point in the history
chore: extend pypi packaging with setup.py
  • Loading branch information
Zymrael authored Feb 23, 2024
2 parents efd3ead + 436c998 commit 1746a0d
Show file tree
Hide file tree
Showing 18 changed files with 42 additions and 21 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
__pycache__
*egg*
10 changes: 5 additions & 5 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
import torch
import yaml

from src.generation import Generator
from src.model import StripedHyena
from src.sample import sample
from src.tokenizer import HFAutoTokenizer
from src.utils import dotdict, print_rank_0
from stripedhyena.generation import Generator
from stripedhyena.model import StripedHyena
from stripedhyena.sample import sample
from stripedhyena.tokenizer import HFAutoTokenizer
from stripedhyena.utils import dotdict, print_rank_0

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run StripedHyena Model")
Expand Down
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ build-backend = "setuptools.build_meta"

[project]
name = "stripedhyena"
version = "0.2.0"
version = "0.2.1"
description = "Model and inference code for beyond Transformer architectures"
readme = "README.md"
license = {file = "LICENSE"}
authors = [{ name = "Michael Poli"}]
dependencies = [
"transformers",
Expand All @@ -22,3 +23,6 @@ profile = "black"
line_length = 119
combine_as_imports = true
combine_star = true

[tool.setuptools]
packages = ['stripedhyena']
20 changes: 20 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Packaging script for the ``stripedhyena`` distribution.

Mirrors the metadata declared in ``pyproject.toml`` (name, version,
description, license, packages); keep the two files in sync when
bumping a release.
"""
from setuptools import setup

with open("README.md", encoding="utf-8") as f:
    readme = f.read()

# Build install_requires from requirements.txt, dropping blank lines
# (a trailing newline would otherwise become an empty requirement spec)
# and comment lines, both of which pip rejects in install_requires.
with open("requirements.txt", encoding="utf-8") as f:
    requirements = [
        line.strip() for line in f if line.strip() and not line.lstrip().startswith("#")
    ]

setup(
    name="stripedhyena",
    version="0.2.1",
    description="Model and inference code for beyond Transformer architectures",
    long_description=readme,
    long_description_content_type="text/markdown",
    author="Michael Poli",
    url="http://github.com/togethercomputer/stripedhyena",
    license="Apache-2.0",
    # NOTE: the previous find_packages(where="stripedhyena") searched *inside*
    # the package directory, so the top-level ``stripedhyena`` package itself
    # was never collected and the built wheel shipped no code. List the
    # package explicitly, matching [tool.setuptools] in pyproject.toml.
    packages=["stripedhyena"],
    install_requires=requirements,
)
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion src/engine.py → stripedhyena/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import conv1d_cpp
except:
pass
from src.utils import column_split
from stripedhyena.utils import column_split

IIR_PREFILL_MODES = [
"recurrence",
Expand Down
6 changes: 3 additions & 3 deletions src/generation.py → stripedhyena/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

import torch

from src.sample import sample
from src.tokenizer import CharLevelTokenizer
from src.utils import print_rank_0
from stripedhyena.sample import sample
from stripedhyena.tokenizer import CharLevelTokenizer
from stripedhyena.utils import print_rank_0


class Generator:
Expand Down
2 changes: 1 addition & 1 deletion src/layers.py → stripedhyena/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from einops import rearrange
from torch import Tensor

from src.utils import grab_first_if_tuple
from stripedhyena.utils import grab_first_if_tuple


class RMSNorm(torch.nn.Module):
Expand Down
10 changes: 5 additions & 5 deletions src/model.py → stripedhyena/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@
import torch.nn as nn
import torch.nn.functional as F

from src.cache import InferenceParams, RecurrentInferenceParams
from src.engine import HyenaInferenceEngine
from src.layers import ParallelGatedMLP, RMSNorm, VocabParallelEmbedding
from src.utils import column_split, print_rank_0
from stripedhyena.cache import InferenceParams, RecurrentInferenceParams
from stripedhyena.engine import HyenaInferenceEngine
from stripedhyena.layers import ParallelGatedMLP, RMSNorm, VocabParallelEmbedding
from stripedhyena.utils import column_split, print_rank_0

try:
from flash_attn.modules.mha import MHA
except ImportError:
"flash_attn not installed"

try:
from src.positional_embeddings import swap_mha_rope
from stripedhyena.positional_embeddings import swap_mha_rope
except ImportError:
"could not import swap_mha_rope from src.positional_embeddings"

Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
1 change: 0 additions & 1 deletion test/test_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import torch
import torch.nn as nn
import yaml

from src.layers import RMSNorm
from src.model import StripedHyena
from src.utils import dotdict
Expand Down
3 changes: 1 addition & 2 deletions test/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@
import torch
import torch.nn as nn
import yaml
from torch.autograd import grad

from src.layers import RMSNorm
from src.model import StripedHyena
from src.utils import dotdict
from torch.autograd import grad

try:
from flashfftconv import FlashFFTConv
Expand Down
1 change: 0 additions & 1 deletion test/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import torch
import torch.nn as nn
import yaml

from src.layers import RMSNorm
from src.utils import dotdict

Expand Down
1 change: 0 additions & 1 deletion test/test_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import torch
import torch.nn as nn
import yaml

from src.layers import RMSNorm
from src.model import StripedHyena
from src.utils import dotdict
Expand Down

0 comments on commit 1746a0d

Please sign in to comment.