Skip to content

Commit

Permalink
[utils] use force_align of torchaudio (#2597)
Browse files Browse the repository at this point in the history
* [utils] use force_align of torchaudio

* rounded outputs of force_align

* remove empty line
  • Loading branch information
pengzhendong authored Aug 7, 2024
1 parent fcf26a4 commit 98eac6f
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 53 deletions.
6 changes: 3 additions & 3 deletions wenet/cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,9 @@ def _decode(self,
for i, x in enumerate(res.tokens):
tokens_info.append({
'token': self.char_dict[x],
'start': times[i][0],
'end': times[i][1],
'confidence': res.tokens_confidence[i]
'start': round(times[i][0], 3),
'end': round(times[i][1], 3),
'confidence': round(res.tokens_confidence[i], 2)
})
result['tokens'] = tokens_info
return result
Expand Down
6 changes: 3 additions & 3 deletions wenet/cli/paraformer_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ def transcribe(self, audio_file: str, tokens_info: bool = False) -> dict:
for i, x in enumerate(res.tokens):
tokens_info.append({
'token': self.tokenizer.char_dict[x],
'start': times[i][0],
'end': times[i][1],
'confidence': res.tokens_confidence[i]
'start': round(times[i][0], 3),
'end': round(times[i][1], 3),
'confidence': round(res.tokens_confidence[i], 2)
})
result['tokens'] = tokens_info

Expand Down
52 changes: 5 additions & 47 deletions wenet/utils/ctc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np

import torch
import torchaudio.functional as F


def remove_duplicates_and_blank(hyp: List[int],
Expand Down Expand Up @@ -112,53 +113,10 @@ def force_align(ctc_probs: torch.Tensor, y: torch.Tensor, blank_id=0) -> list:
Returns:
torch.Tensor: alignment result
"""
ctc_probs = ctc_probs.cpu()
y = y.cpu()
y_insert_blank = insert_blank(y, blank_id)

log_alpha = torch.zeros((ctc_probs.size(0), len(y_insert_blank)))
log_alpha = log_alpha - float('inf') # log of zero
state_path = torch.zeros((ctc_probs.size(0), len(y_insert_blank)),
dtype=torch.int16) - 1 # state path

# init start state
log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]]
log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]]

for t in range(1, ctc_probs.size(0)):
for s in range(len(y_insert_blank)):
if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[
s] == y_insert_blank[s - 2]:
candidates = torch.tensor(
[log_alpha[t - 1, s], log_alpha[t - 1, s - 1]])
prev_state = [s, s - 1]
else:
candidates = torch.tensor([
log_alpha[t - 1, s],
log_alpha[t - 1, s - 1],
log_alpha[t - 1, s - 2],
])
prev_state = [s, s - 1, s - 2]
log_alpha[
t, s] = torch.max(candidates) + ctc_probs[t][y_insert_blank[s]]
state_path[t, s] = prev_state[torch.argmax(candidates)]

state_seq = -1 * torch.ones((ctc_probs.size(0), 1), dtype=torch.int16)

candidates = torch.tensor([
log_alpha[-1, len(y_insert_blank) - 1],
log_alpha[-1, len(y_insert_blank) - 2]
])
final_state = [len(y_insert_blank) - 1, len(y_insert_blank) - 2]
state_seq[-1] = final_state[torch.argmax(candidates)]
for t in range(ctc_probs.size(0) - 2, -1, -1):
state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]]

output_alignment = []
for t in range(0, ctc_probs.size(0)):
output_alignment.append(y_insert_blank[state_seq[t, 0]])

return output_alignment
ctc_probs = ctc_probs[None].cpu()
y = y[None].cpu()
alignments, _ = F.forced_align(ctc_probs, y, blank=blank_id)
return alignments[0]


def get_blank_id(configs, symbol_table):
Expand Down

0 comments on commit 98eac6f

Please sign in to comment.