Commit
[ssl] align masks and subsamplingmask (#2603)
* [ssl] align masks and subsamplingmask

* fix typo
Mddct authored Aug 9, 2024
1 parent 4df07e5 commit 4ba6d52
Showing 1 changed file with 29 additions and 14 deletions.
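In outline: before this change, forward() masked the raw frames first and then min-pooled that frame-level mask over stacked windows to get the code-ids mask (a window counted as masked only if every frame in it was masked), so the two masks could disagree at window boundaries. After the change, the span mask is sampled once at the subsampled (stacked-frame) resolution inside _apply_mask_signal and expanded back to frame level, so the signal mask and the code-ids mask agree by construction. A minimal sketch of that expansion, with made-up sizes (B, T, stack_frames, stride, and the 0.3 mask probability are illustrative, not values from the patch):

    import torch

    B, T, stack_frames, stride = 2, 16, 4, 4
    T_sub = (T - stack_frames) // stride + 1  # positions after frame stacking

    # sample a span mask at the subsampled resolution
    # (stands in for compute_mask_indices_v2)
    sub_mask = torch.rand(B, T_sub) < 0.3  # [B, T_sub], True = masked

    # expand each subsampled position over the `stride` raw frames it covers,
    # recovering a frame-level signal mask aligned with the subsampled one
    frame_mask = sub_mask.unsqueeze(-1).expand(B, T_sub, stride)
    frame_mask = frame_mask.reshape(B, T_sub * stride)

    # pad back to the original length T; trailing frames that never formed
    # a full window stay unmasked
    full_mask = torch.zeros(B, T, dtype=torch.bool)
    full_mask[:, :frame_mask.size(-1)] = frame_mask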
43 changes: 29 additions & 14 deletions wenet/ssl/bestrq/bestrq_model.py
@@ -170,16 +170,8 @@ def forward(
         if self.features_regularization_weight != 0.0:
             features_pen = input.pow(2).mean()
 
-        # 0 mask input
-        xs, masked_masks = self._apply_mask_signal(xs, xs_lens)
-
-        # 1 get subsampling mask
-        subsampling_masks = masked_masks.unfold(1,
-                                                size=self.stack_frames,
-                                                step=self.stride)
-        # NOTE(Mddct): you can try torch.max(subsampling_masks, 2) if
-        #   subsampling rate == 2 or mask probs is smaller
-        code_ids_mask, _ = torch.min(subsampling_masks, 2)
+        # 1 mask input
+        xs, code_ids_mask = self._apply_mask_signal(xs, xs_lens)
 
         # 2.0 stack fbank
         unmasked_xs = self._stack_features(input)
@@ -224,20 +216,43 @@ def forward(
     def _apply_mask_signal(
             self, input: torch.Tensor,
             input_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        device = input.device
         B, T, _ = input.size()
         padding_mask = make_pad_mask(input_lens)
-        masks = compute_mask_indices_v2(input.size()[:-1],
+
+        # calc subsampling masks
+        padding_mask_stride = padding_mask.unfold(
+            1,
+            size=self.stack_frames,
+            step=self.stride,
+        )
+        padding_mask, _ = torch.max(padding_mask_stride, dim=-1)
+        masks = compute_mask_indices_v2(padding_mask.size(),
                                         padding_mask,
                                         self.mask_prob,
                                         self.mask_length,
                                         min_masks=self.min_masks,
-                                        device=input.device)
-
+                                        device=device)
+        # calc signal mask
+        subsampling_mask = masks
+        bool_stride_mask = torch.ones_like(padding_mask_stride, device=device)
+        mask_stride = torch.where(masks.unsqueeze(-1), bool_stride_mask, False)
+        # recover original seq masks
+        masks = mask_stride[:, :, :self.stride].flatten(start_dim=1)
+        masks_padding = torch.zeros(
+            B,
+            T,
+            device=device,
+            dtype=padding_mask.dtype,
+        )
+        masks_padding[:, :masks.size(-1)] = masks
+        masks = masks_padding
         masks_expand = masks.unsqueeze(-1)  # [B, T, 1]
+        # NOTE(Mddct): you can use size (b,t,d) for torch.normal
         mask_emb = torch.normal(mean=0, std=0.1,
                                 size=(1, 1, input.size(2))).to(input.device)
         xs = torch.where(masks_expand, mask_emb, input)
-        return xs, masks
+        return xs, subsampling_mask

     def _stack_features(self, input: torch.Tensor) -> torch.Tensor:

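One detail worth noting in the new _apply_mask_signal: make_pad_mask conventionally returns True at padded positions, so taking torch.max over each unfolded window marks a stacked frame as padded whenever any raw frame inside it is padding, which is the conservative choice for the downsampled padding mask. A toy check of that behavior, assuming the True-means-padding convention:

    import torch

    # one utterance, 4 frames, last frame is padding (True = padded)
    pad = torch.tensor([[False, False, False, True]])
    windows = pad.unfold(1, 2, 2)            # [1, 2, 2]: two windows of 2 frames
    sub_pad, _ = torch.max(windows, dim=-1)  # a window is padded if any frame is
    print(sub_pad)                           # tensor([[False,  True]])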
