[ssl/wav2vec2] add more info (#2035)
* [ssl/wav2vec2]  add more info

* [ssl/wav2vec2]  fix lint
Mddct authored Oct 9, 2023
1 parent a1b0a29 commit 3790509
Showing 1 changed file with 44 additions and 28 deletions.
wenet/ssl/wav2vec2/wav2vec2_model.py: 44 additions & 28 deletions
@@ -12,41 +12,47 @@
 from wenet.utils.mask import make_non_pad_mask
 
 
-def _sample_negative_indices(features_shape: torch.Size,
+def _sample_negative_indices(features_shape: Tuple,
                              num_negatives: int,
-                             mask: Optional[torch.Tensor] = None):
+                             device: torch.device,
+                             mask_time_indices: Optional[torch.Tensor] = None):
     """
     Sample `num_negatives` vectors from feature vectors.
     """
-    batch_size, sequence_length, _ = features_shape
-    assert sequence_length > 1
+    batch_size, sequence_length = features_shape
+
+    sequence_length_range = torch.arange(sequence_length)
 
     # get `num_negatives` random vector indices from the same utterance
-    sampled_negative_indices = []
+    sampled_negative_indices = torch.zeros(
+        (batch_size, sequence_length, num_negatives),
+        dtype=sequence_length_range.dtype,
+        device=device)
+
+    mask_time_indices = (mask_time_indices.bool() if mask_time_indices
+                         is not None else torch.ones(features_shape,
+                                                     dtype=torch.bool))
+
     for batch_idx in range(batch_size):
-        high = mask[batch_idx].sum(
-        ) - 1 if mask is not None else sequence_length - 1
-        sampled_indices_slice = torch.randint(0,
-                                              high,
-                                              size=(num_negatives *
-                                                    sequence_length, ))
-        sampled_negative_indices.append(sampled_indices_slice)
-
-    sampled_negative_indices = torch.stack(sampled_negative_indices, dim=0).to(
-        torch.int32)  # [B, num_negatives * sequence_length]
-
-    # generate indices of the positive vectors themselves,
-    # repeat them `num_negatives` times
-    feature_indices = torch.arange(sequence_length)[:, None].repeat(
-        1, num_negatives).flatten()  # [B x num_negatives x sequence_length]
-
-    # avoid sampling the same positive vector, but keep the distribution uniform
-    sampled_negative_indices[sampled_negative_indices >= feature_indices] += 1
-
-    # correct for batch size
-    for batch_idx in range(1, batch_size):
+        high = mask_time_indices[batch_idx].sum() - 1
+        mapped_masked_indices = sequence_length_range[
+            mask_time_indices[batch_idx]]
+
+        feature_indices = torch.arange(high + 1).unsqueeze(1).expand(
+            high + 1, num_negatives)
+        sampled_indices = torch.randint(0,
+                                        high,
+                                        size=(high + 1, num_negatives))
+        sampled_indices[sampled_indices >= feature_indices] += 1
+
+        # remap to actual indices
+        sampled_negative_indices[batch_idx][mask_time_indices[
+            batch_idx]] = mapped_masked_indices[sampled_indices]
+
+        # correct for batch size
         sampled_negative_indices[batch_idx] += batch_idx * sequence_length
 
-    return sampled_negative_indices
+    return sampled_negative_indices.reshape(batch_size, -1)
 
 
 def _compute_contrastive_loss(quantized_features: torch.Tensor,
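
Note on the new sampler: each frame's `num_negatives` distractors are drawn
from the other masked frames of the same utterance, then offset by
`batch_idx * sequence_length` so they index the features flattened to
[batch_size * sequence_length, ...]. A minimal usage sketch (the toy sizes
and mask below are illustrative assumptions, not part of the commit):

import torch

from wenet.ssl.wav2vec2.wav2vec2_model import _sample_negative_indices

# toy setup: 2 utterances, 10 frames each, 4 negatives per masked frame
batch_size, seq_len, num_negatives = 2, 10, 4

# hypothetical time mask: frames 2..5 of every utterance were masked
mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
mask[:, 2:6] = True

neg_idx = _sample_negative_indices((batch_size, seq_len),
                                   num_negatives,
                                   device=torch.device('cpu'),
                                   mask_time_indices=mask)

# flat indices, one row per utterance; unmasked rows keep the zero
# default plus the batch offset
print(neg_idx.shape)  # torch.Size([2, 40]) -> [B, seq_len * num_negatives]

Restricting distractors to masked positions of the same utterance follows the
wav2vec 2.0 recipe, and the `sampled_indices[... >= feature_indices] += 1`
shift keeps the draw uniform while never selecting the positive frame itself.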
@@ -241,29 +247,39 @@ def forward(
             unmasked_xs, masks.squeeze(1), gumbel_temperature)
 
         sampled_negative_indices = _sample_negative_indices(
-            xs.size(), self.num_negatives, masks.squeeze(1))
+            xs.size()[:-1], self.num_negatives, masked_masks.device,
+            masked_masks)
 
         loss_contrastive = _compute_contrastive_loss(
             quantized_features, out, sampled_negative_indices, masked_masks,
             self.contrastive_logits_temp, self.num_negatives)
         loss = loss_contrastive
 
         # scale by sample size
+        # make sure that diversity loss is multiplied by `sample_size`
+        # since contrastive_loss is `sum`-reduced instead of averaged
         sample_size = masked_masks.sum()
+        # higher codevector_perplexity leads to lower diversity loss
         loss_diversity: Optional[torch.Tensor] = None
         if self.diversity_weight != 0.0:
             loss_diversity = (
                 self.num_codevector_groups * self.num_codevectors_per_group -
                 codevector_perplexity) / (self.num_codevectors_per_group *
                                           self.num_codevector_groups)
+            loss_diversity = loss_diversity * sample_size
             loss = loss + self.diversity_weight * loss_diversity
         loss = loss / sample_size
 
         features_pen: Optional[torch.Tensor] = None
         if self.features_regularization_weight != 0.0:
             features_pen = xs.pow(2).mean()
             loss = loss + self.features_regularization_weight * features_pen
 
         return {
             "code_ppl": codevector_perplexity.detach(),
             "features_l2": features_pen,
             "loss": loss,
             "losss_constrastive": loss_contrastive / sample_size,
             "loss_diversity": loss_diversity,
         }
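
Note on the scaling that the new comments document: the contrastive loss is
`sum`-reduced over masked frames, so the diversity term is first multiplied by
`sample_size` to put both terms on the same scale, and the total is divided by
`sample_size` exactly once. A sketch of the arithmetic (G, V and every value
below are made-up numbers for illustration):

import torch

G, V = 2, 320                             # codevector groups, codes per group
sample_size = torch.tensor(100.)          # number of masked frames in batch
loss_contrastive = torch.tensor(250.)     # `sum`-reduced over masked frames
codevector_perplexity = torch.tensor(480.)
diversity_weight = 0.1

# higher perplexity (codes used more uniformly) -> lower diversity loss
loss_diversity = (G * V - codevector_perplexity) / (G * V)  # 0.25

# multiply by sample_size so both terms are batch-level sums ...
loss = loss_contrastive + diversity_weight * loss_diversity * sample_size
# ... then average once over the masked frames
loss = loss / sample_size
print(loss)  # tensor(2.5250)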
