diff --git a/alphafold3_pytorch/attention.py b/alphafold3_pytorch/attention.py index 2af76ca9..11cd52c3 100644 --- a/alphafold3_pytorch/attention.py +++ b/alphafold3_pytorch/attention.py @@ -185,6 +185,7 @@ def __init__( window_size = None, num_memory_kv: int = 0, laser = False, + laser_softclamp_value = 15., enable_attn_softclamp = False, attn_softclamp_value = 50., softmax_full_precision = False @@ -206,10 +207,11 @@ def __init__( dim_inner = dim_head * heads self.attend = Attend( - laser = laser, dropout = dropout, window_size = window_size, enable_attn_softclamp = enable_attn_softclamp, + laser = laser, + laser_softclamp_value = laser_softclamp_value, attn_softclamp_value = attn_softclamp_value, softmax_full_precision = softmax_full_precision ) @@ -305,6 +307,7 @@ def __init__( self, dropout = 0., laser = False, + laser_softclamp_value = 15., window_size = None, scale: float | None = None, enable_attn_softclamp = False, @@ -336,6 +339,7 @@ def __init__( # laser attention self.laser = laser + self.laser_softclamp_value = laser_softclamp_value # softclamp attention logits # being adopted by a number of recent llms (gemma, grok) @@ -460,8 +464,7 @@ def local_attn( # maybe laser if self.laser: - v_max = v.amax(dim = -2, keepdim = True) - v = (v - v_max).exp() + v = softclamp(v, self.laser_softclamp_value) # aggregate @@ -470,7 +473,7 @@ def local_attn( # maybe laser if self.laser: - out = log(out) + v_max + out = log(out) # un-window the output