diff --git a/src/layers.py b/src/layers.py index 436b5f2..b9d518b 100644 --- a/src/layers.py +++ b/src/layers.py @@ -41,7 +41,14 @@ def __init__( super().__init__() multiple_of = config.get("inner_size_multiple_of", 64) - self.act = F.silu + self.act_type = config.get("mlp_activation", "silu") + if self.act_type == "gelu": + self.act = F.gelu + elif self.act_type == "silu": + self.act = F.silu + else: + raise NotImplementedError + self.multiple_of = multiple_of * config.model_parallel_size inner_size = int(2 * config.hidden_size * 4 / 3)