Skip to content

Commit

Permalink
hyperconnect single repr in pairformerstack
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 28, 2024
1 parent c5d1f7b commit b82e345
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 7 deletions.
25 changes: 20 additions & 5 deletions alphafold3_pytorch/alphafold3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1430,10 +1430,18 @@ def __init__(
num_register_tokens = 0,
checkpoint = False,
add_value_residual = False,
num_residual_streams = 1,
pairwise_block_kwargs: dict = dict(),
pair_bias_attn_kwargs: dict = dict()
):
super().__init__()

# residual / hyper connections

init_hyper_conn, self.expand_streams, self.reduce_streams = HyperConnections.get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)

# layers

layers = ModuleList([])

pair_bias_attn_kwargs = dict(
Expand Down Expand Up @@ -1463,8 +1471,8 @@ def __init__(

layers.append(ModuleList([
pairwise_block,
single_pre_ln(pair_bias_attn),
single_pre_ln(single_transition),
init_hyper_conn(dim = dim_single, branch = single_pre_ln(pair_bias_attn)),
init_hyper_conn(dim = dim_single, branch = single_pre_ln(single_transition)),
]))

self.layers = layers
Expand Down Expand Up @@ -1499,6 +1507,8 @@ def to_layers(

) -> Tuple[Float['b n ds'], Float['b n n dp']]:

single_repr = self.expand_streams(single_repr)

for _ in range(self.recurrent_depth):

value_residual = None
Expand All @@ -1522,6 +1532,8 @@ def to_layers(

single_repr = single_transition(single_repr) + single_repr

single_repr = self.reduce_streams(single_repr)

return single_repr, pairwise_repr

@typecheck
Expand Down Expand Up @@ -1550,8 +1562,7 @@ def pair_bias_attn_wrapper(layer):
@wraps(layer)
def inner(inputs, *args, **kwargs):
single_repr, pairwise_repr, mask, maybe_value_residual, maybe_pairwise_value_residuals = inputs
attn_out, attn_values = layer(single_repr, pairwise_repr = pairwise_repr, mask = mask, return_values = True, value_residual = maybe_value_residual)
single_repr = single_repr + attn_out
single_repr, attn_values = layer(single_repr, pairwise_repr = pairwise_repr, mask = mask, return_values = True, value_residual = maybe_value_residual)

if self.add_value_residual:
maybe_value_residual = default(maybe_value_residual, attn_values)
Expand All @@ -1563,7 +1574,7 @@ def single_transition_wrapper(layer):
@wraps(layer)
def inner(inputs, *args, **kwargs):
single_repr, pairwise_repr, mask, maybe_value_residual, maybe_pairwise_value_residuals = inputs
single_repr = layer(single_repr) + single_repr
single_repr = layer(single_repr)
return single_repr, pairwise_repr, mask, maybe_value_residual, maybe_pairwise_value_residuals
return inner

Expand All @@ -1579,6 +1590,8 @@ def inner(inputs, *args, **kwargs):
wrapped_layers.append(pair_bias_attn_wrapper(pair_bias_attn))
wrapped_layers.append(single_transition_wrapper(single_transition))

single_repr = self.expand_streams(single_repr)

for _ in range(self.recurrent_depth):
inputs = (single_repr, pairwise_repr, mask, None, None)

Expand All @@ -1587,6 +1600,8 @@ def inner(inputs, *args, **kwargs):

single_repr, pairwise_repr, *_ = inputs

single_repr = self.reduce_streams(single_repr)

return single_repr, pairwise_repr

@typecheck
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "alphafold3-pytorch"
version = "0.7.5"
version = "0.7.6"
description = "Alphafold 3 - Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" },
Expand Down
5 changes: 4 additions & 1 deletion tests/test_af3.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,11 +303,13 @@ def test_centre_random_augmentation():
@pytest.mark.parametrize('recurrent_depth', (1, 2))
@pytest.mark.parametrize('enable_attn_softclamp', (True, False))
@pytest.mark.parametrize('add_value_residual', (True, False))
@pytest.mark.parametrize('num_residual_streams', (1, 4))
def test_pairformer(
checkpoint,
recurrent_depth,
enable_attn_softclamp,
add_value_residual
add_value_residual,
num_residual_streams
):
single = torch.randn(2, 16, 384).requires_grad_()
pairwise = torch.randn(2, 16, 16, 128).requires_grad_()
Expand All @@ -319,6 +321,7 @@ def test_pairformer(
recurrent_depth = recurrent_depth,
checkpoint = checkpoint,
add_value_residual = add_value_residual,
num_residual_streams = num_residual_streams,
pair_bias_attn_kwargs = dict(
enable_attn_softclamp = enable_attn_softclamp
)
Expand Down

0 comments on commit b82e345

Please sign in to comment.