Defer LayerScale initialization for compatibility with "meta" devices
baldassarreFe authored Oct 25, 2024
1 parent e1277af commit 3a3217b
Showing 1 changed file with 9 additions and 2 deletions.
dinov2/layers/layer_scale.py
@@ -5,7 +5,7 @@
 
 # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
 
-from typing import Union
+from typing import Optional, Union
 
 import torch
 from torch import Tensor
@@ -18,10 +18,17 @@ def __init__(
         dim: int,
         init_values: Union[float, Tensor] = 1e-5,
         inplace: bool = False,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
     ) -> None:
         super().__init__()
         self.inplace = inplace
-        self.gamma = nn.Parameter(init_values * torch.ones(dim))
+        self.init_values = init_values
+        self.gamma = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.constant_(self.gamma, self.init_values)
 
     def forward(self, x: Tensor) -> Tensor:
         return x.mul_(self.gamma) if self.inplace else x * self.gamma
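
With initialization moved into reset_parameters, LayerScale follows the standard PyTorch deferred-initialization pattern: the module can be constructed under torch.device("meta") (shapes and dtypes are tracked but no real storage is allocated), then materialized and re-initialized later. Below is a minimal sketch of that usage, assuming the patched dinov2/layers/layer_scale.py above; the dim value and target device are arbitrary choices for illustration, not part of this commit.

    import torch

    from dinov2.layers.layer_scale import LayerScale

    # Build on the meta device: no memory is allocated, and the
    # nn.init call inside reset_parameters is a no-op on meta tensors.
    with torch.device("meta"):
        ls = LayerScale(dim=1024, init_values=1e-5)  # dim is arbitrary here

    # Materialize: allocate (uninitialized) storage on a real device...
    ls = ls.to_empty(device="cpu")
    # ...then run the deferred initialization to fill gamma with init_values.
    ls.reset_parameters()

    x = torch.randn(2, 16, 1024)
    print(ls(x).shape)  # torch.Size([2, 16, 1024])

The old code initialized gamma eagerly in __init__ (init_values * torch.ones(dim)), so a module materialized from meta storage would have been left with garbage values and no hook to restore them; reset_parameters provides that hook, and is the method utilities that materialize meta-built models conventionally call after nn.Module.to_empty.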
