
Commit

[ssl] fix bestrq l2norm (#2599)
* [ssl] align the l2 norm

* fix cv targets in dataset
Mddct authored Aug 7, 2024
1 parent 98eac6f commit 0560f70
Showing 3 changed files with 13 additions and 16 deletions.
wenet/dataset/processor.py (3 changes: 1 addition & 2 deletions)
@@ -126,14 +126,13 @@ def decode_wav(sample):
""" Parse key/wav/txt from json line
Args:
sample: str, str is a json line has key/wav/txt
sample: str, str is a json line has key/wav
Returns:
{key, wav, sample_rate, ...}
"""
assert 'key' in sample
assert 'wav' in sample
assert 'txt' in sample
wav_file = sample['wav']
if isinstance(wav_file, str):
with open(wav_file, 'rb') as f:
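The practical effect of the dropped assertion: a json line that carries only key/wav (no transcript, as in SSL pretraining data) is now accepted. A minimal sketch of that behaviour, not the repository's exact function, assuming torchaudio for decoding:

```python
import io

import torchaudio


def decode_wav_sketch(sample: dict) -> dict:
    """Illustrative stand-in for processor.decode_wav after this commit:
    only 'key' and 'wav' are required, so label-free SSL data passes."""
    assert 'key' in sample
    assert 'wav' in sample  # 'txt' is no longer asserted
    wav_file = sample['wav']
    if isinstance(wav_file, str):
        with open(wav_file, 'rb') as f:
            wav_file = f.read()
    waveform, sample_rate = torchaudio.load(io.BytesIO(wav_file))
    sample['wav'] = waveform
    sample['sample_rate'] = sample_rate
    return sample


# An unlabeled utterance no longer needs a 'txt' field:
# decode_wav_sketch({'key': 'utt1', 'wav': '/path/to/utt1.wav'})
```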
wenet/ssl/bestrq/bestrq_model.py (23 changes: 9 additions & 14 deletions)
@@ -81,13 +81,6 @@ def __init__(

         # encoder
         self.encoder = encoder
-        assert self.encoder.global_cmvn is not None
-        self.register_buffer('signal_mean', self.encoder.global_cmvn.mean)
-        self.register_buffer('signal_istd', self.encoder.global_cmvn.istd)
-        self.signal_norm_var = self.encoder.global_cmvn.norm_var
-        # NOTE(Mddct): disable encoder's global_cmvn
-        self.encoder.global_cmvn = None

         # n softmax
         self.encoder_top_n_out = torch.nn.parameter.Parameter(
             torch.empty(self.num_codebooks, self.encoder.output_size(),
@@ -122,6 +115,8 @@ def __init__(
             requires_grad=False,
         )
         torch.nn.init.normal_(self.embeddings)
+        self.embeddings /= (self.embeddings.norm(dim=-1, p=2, keepdim=True) +
+                            1e-8)

         # force reset encoder papameter
         self.reset_encoder_parameter()
@@ -169,10 +164,6 @@ def forward(
     ):
         xs = batch['feats'].to(device)
         xs_lens = batch['feats_lengths'].to(device)
-        # force global cmvn
-        xs = xs - self.signal_mean
-        if self.signal_norm_var:
-            xs = xs * self.signal_istd
         input = xs

         features_pen: Optional[torch.Tensor] = None
@@ -186,6 +177,8 @@ def forward(
         subsampling_masks = masked_masks.unfold(1,
                                                 size=self.stack_frames,
                                                 step=self.stride)
+        # NOTE(Mddct): you can try torch.max(subsampling_masks, 2) if
+        # subsampling rate == 2 or mask probs is smaller
         code_ids_mask, _ = torch.min(subsampling_masks, 2)

         # 2.0 stack fbank
@@ -267,10 +260,12 @@ def _compute_loss(self, input: torch.Tensor, target: torch.Tensor,
         return loss

     def _nearest_embedding_idx(self, xs: torch.Tensor) -> torch.Tensor:
-        xs = self.norm(xs)
+        if self.encoder.global_cmvn is None:
+            xs = self.norm(xs)
         xs = torch.matmul(xs, self.projection.to(xs.device))

+        xs = xs / (xs.norm(dim=-1, p=2, keepdim=True) + 1e-8)
+        codebooks = self.embeddings
         B, T, C = xs.size()
         xs_flatten = xs.view(B * T, C)
-        _, codes, _ = quantize_vector(xs_flatten, self.embeddings)
+        _, codes, _ = quantize_vector(xs_flatten, codebooks)
         return codes.reshape(B, T, -1)  # [B, T, num_codebooks]
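In effect, the commit aligns the random-projection quantizer with the usual BEST-RQ formulation: the frozen codebooks are l2-normalized once at init, the projected features are l2-normalized at lookup time, and CMVN is left to the encoder's own global_cmvn rather than being re-applied inside the model. Below is a self-contained sketch of such an l2-aligned lookup, with hypothetical names and shapes; it is not the repository's quantize_vector:

```python
import torch


def nearest_code_ids(features: torch.Tensor,
                     projection: torch.Tensor,
                     codebooks: torch.Tensor,
                     eps: float = 1e-8) -> torch.Tensor:
    """Sketch of an l2-aligned BEST-RQ codebook lookup.

    features:   [B, T, D] frames (CMVN assumed handled upstream)
    projection: [D, C]    frozen random projection
    codebooks:  [G, V, C] frozen codebooks, l2-normalized on the last dim
    returns:    [B, T, G] nearest codeword id per codebook group
    """
    xs = torch.matmul(features, projection)                # [B, T, C]
    xs = xs / (xs.norm(dim=-1, p=2, keepdim=True) + eps)   # unit-norm targets
    # With unit-norm vectors on both sides, argmin l2 distance is the same
    # as argmax dot product, so a similarity argmax suffices.
    sim = torch.einsum('btc,gvc->btgv', xs, codebooks)     # [B, T, G, V]
    return sim.argmax(dim=-1)                              # [B, T, G]


# Toy shapes: B=2, T=5, D=8, G=2 codebooks of V=16 codewords, dim C=4.
feats = torch.randn(2, 5, 8)
proj = torch.randn(8, 4)
books = torch.randn(2, 16, 4)
books = books / (books.norm(dim=-1, p=2, keepdim=True) + 1e-8)
print(nearest_code_ids(feats, proj, books).shape)  # torch.Size([2, 5, 2])
```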
wenet/ssl/init_dataset.py (3 changes: 3 additions & 0 deletions)
@@ -32,6 +32,9 @@ def padding(data):
"keys": sorted_keys,
"feats": padded_feats,
"feats_lengths": feats_lengths,
# NOTE(Mddct): cv need targets , refine later
"target": padded_feats,
"target_lengths": feats_lengths,
}
return batch
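For context, the added keys simply mirror the padded features so a cross-validation loop that expects target/target_lengths can run on SSL batches. A rough sketch of such a collate, with hypothetical names, assuming plain PyTorch padding:

```python
from typing import Dict, List

import torch


def padding_sketch(feats: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Rough sketch of an SSL collate after this change: the padded
    features double as 'target' so CV code that expects targets works."""
    feats_lengths = torch.tensor([f.size(0) for f in feats], dtype=torch.int32)
    padded_feats = torch.nn.utils.rnn.pad_sequence(feats,
                                                   batch_first=True,
                                                   padding_value=0.0)
    return {
        "feats": padded_feats,
        "feats_lengths": feats_lengths,
        # cv needs targets; reuse the features as placeholder targets
        "target": padded_feats,
        "target_lengths": feats_lengths,
    }
```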


