
Commit

use softclamping to address numerical issues with laser instead
lucidrains committed Dec 3, 2024
1 parent 34f6c59 commit 0cee1cc
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions alphafold3_pytorch/attention.py
@@ -185,6 +185,7 @@ def __init__(
      window_size = None,
      num_memory_kv: int = 0,
      laser = False,
+     laser_softclamp_value = 15.,
      enable_attn_softclamp = False,
      attn_softclamp_value = 50.,
      softmax_full_precision = False
@@ -206,10 +207,11 @@ def __init__(
  dim_inner = dim_head * heads

  self.attend = Attend(
-     laser = laser,
      dropout = dropout,
      window_size = window_size,
      enable_attn_softclamp = enable_attn_softclamp,
+     laser = laser,
+     laser_softclamp_value = laser_softclamp_value,
      attn_softclamp_value = attn_softclamp_value,
      softmax_full_precision = softmax_full_precision
  )
@@ -305,6 +307,7 @@ def __init__(
  self,
  dropout = 0.,
  laser = False,
+ laser_softclamp_value = 15.,
  window_size = None,
  scale: float | None = None,
  enable_attn_softclamp = False,
@@ -336,6 +339,7 @@ def __init__(
  # laser attention

  self.laser = laser
+ self.laser_softclamp_value = laser_softclamp_value

  # softclamp attention logits
  # being adopted by a number of recent llms (gemma, grok)
@@ -460,8 +464,7 @@ def local_attn(
  # maybe laser

  if self.laser:
-     v_max = v.amax(dim = -2, keepdim = True)
-     v = (v - v_max).exp()
+     v = softclamp(v, self.laser_softclamp_value)

  # aggregate

@@ -470,7 +473,7 @@
  # maybe laser

  if self.laser:
-     out = log(out) + v_max
+     out = log(out)

  # un-window the output

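For context on what the change is doing: LASER-style attention exponentiates the value vectors, aggregates them with the attention weights, and takes a log of the result, so large raw values can make the exponentiation blow up. The removed lines handled this with the usual log-sum-exp trick (subtract the per-window max, add it back in log space after aggregation); this commit instead bounds the values themselves with the tanh-based softclamp that the file already uses for attention-logit capping (the Gemma/Grok style referenced in the comments). The sketch below is only an illustration of that numerical difference, not code taken from the repository; the tensor shape is arbitrary and only the default clamp value of 15. comes from the diff.

```python
import torch

def softclamp(t, value):
    # smooth clamp into (-value, value) via a scaled tanh,
    # the same style of soft-capping used for attention logits
    return (t / value).tanh() * value

v = torch.randn(2, 8, 32, 64) * 40.    # deliberately large values

# old approach: subtract the running max so exp() stays <= 1,
# then add the max back in log space after aggregation
v_max = v.amax(dim = -2, keepdim = True)
v_exp = (v - v_max).exp()              # bounded above by 1.

# new approach: softclamp the values first, so the exponentiated
# values can never exceed exp(15.), with no max bookkeeping needed
v_clamped = softclamp(v, 15.)
print(v_clamped.abs().max())           # strictly below 15.
print(v_clamped.exp().max())           # at most exp(15.) ~ 3.3e6
```

The trade-off, roughly: the max-subtraction trick is exact but requires carrying v_max through the windowed aggregation and adding it back after the log, while softclamping slightly distorts very large values in exchange for a simpler, bounded computation.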
