[cli] support punc (#2650)
* [cli] paraformer support batch infer

* fix device

* fix ts

* fix lint

* [cli] support punc

* fix result

* disable jieba log

* fix load jieba once

* refactor call args
Mddct authored Nov 2, 2024
1 parent f92c4e9 commit 7dc43db
Showing 4 changed files with 140 additions and 9 deletions.
6 changes: 4 additions & 2 deletions wenet/cli/hub.py
@@ -13,12 +13,12 @@
 # limitations under the License.

 import os
-import requests
 import sys
 import tarfile
 from pathlib import Path
 from urllib.request import urlretrieve

+import requests
 import tqdm


@@ -77,7 +77,9 @@ class Hub(object):
         # gigaspeech
         "english": "gigaspeech_u2pp_conformer_libtorch.tar.gz",
         # paraformer
-        "paraformer": "paraformer.tar.gz"
+        "paraformer": "paraformer.tar.gz",
+        # punc
+        "punc": "punc.tar.gz"
     }

     def __init__(self) -> None:
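For context, a minimal sketch (not part of the diff) of how the new table entry is consumed. Hub.get_model_by_lang is the same helper punc_model.py calls below; it downloads and unpacks the named archive on first use and returns the local model directory:

from wenet.cli.hub import Hub

# Resolve the new 'punc' entry; fetches punc.tar.gz on the first call.
model_dir = Hub.get_model_by_lang('punc')
print(model_dir)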
115 changes: 115 additions & 0 deletions wenet/cli/punc_model.py
@@ -0,0 +1,115 @@
import os
from typing import List

import jieba
import torch
from wenet.cli.hub import Hub
from wenet.paraformer.search import _isAllAlpha
from wenet.text.char_tokenizer import CharTokenizer


class PuncModel:

def __init__(self, model_dir: str) -> None:
self.model_dir = model_dir
model_path = os.path.join(model_dir, 'final.zip')
units_path = os.path.join(model_dir, 'units.txt')

self.model = torch.jit.load(model_path)
self.tokenizer = CharTokenizer(units_path)
self.device = torch.device("cpu")
self.use_jieba = False

self.punc_table = ['<unk>', '', ',', '。', '?', '、']

def split_words(self, text: str):
if not self.use_jieba:
self.use_jieba = True
import logging

# Disable jieba's logger
logging.getLogger('jieba').disabled = True
jieba.load_userdict(os.path.join(self.model_dir, 'jieba_usr_dict'))

result_list = []
tokens = text.split()
current_language = None
buffer = []

for token in tokens:
is_english = token.isascii()
if is_english:
language = "English"
else:
language = "Chinese"

if current_language and language != current_language:
if current_language == "Chinese":
result_list.extend(jieba.cut(''.join(buffer), HMM=False))
else:
result_list.extend(buffer)
buffer = []

buffer.append(token)
current_language = language

if buffer:
if current_language == "Chinese":
result_list.extend(jieba.cut(''.join(buffer), HMM=False))
else:
result_list.extend(buffer)

return result_list

def add_punc_batch(self, texts: List[str]):
batch_text_words = []
batch_text_ids = []
batch_text_lens = []

for text in texts:
words = self.split_words(text)
ids = self.tokenizer.tokens2ids(words)
batch_text_words.append(words)
batch_text_ids.append(ids)
batch_text_lens.append(len(ids))

texts_tensor = torch.tensor(batch_text_ids,
device=self.device,
dtype=torch.int64)
texts_lens_tensor = torch.tensor(batch_text_lens,
device=self.device,
dtype=torch.int64)

log_probs, _ = self.model(texts_tensor, texts_lens_tensor)
result = []
outs = log_probs.argmax(-1).cpu().numpy()
for i, out in enumerate(outs):
punc_id = out[:batch_text_lens[i]]
sentence = ''
for j, word in enumerate(batch_text_words[i]):
if _isAllAlpha(word):
word = '▁' + word
word += self.punc_table[punc_id[j]]
sentence += word
result.append(sentence.replace('▁', ' '))
return result

def __call__(self, text: str):
if text != '':
r = self.add_punc_batch([text])[0]
return r
return ''


def load_model(model_dir: str = None,
gpu: int = -1,
device: str = "cpu") -> PuncModel:
if model_dir is None:
model_dir = Hub.get_model_by_lang('punc')
if gpu != -1:
# remain the original usage of gpu
device = "cuda"
punc = PuncModel(model_dir)
punc.device = torch.device(device)
punc.model.to(device)
return punc
17 changes: 16 additions & 1 deletion wenet/cli/transcribe.py
@@ -14,8 +14,9 @@

 import argparse

-from wenet.cli.paraformer_model import load_model as load_paraformer
 from wenet.cli.model import load_model
+from wenet.cli.paraformer_model import load_model as load_paraformer
+from wenet.cli.punc_model import load_model as load_punc_model


 def get_args():
@@ -64,6 +65,13 @@ def get_args():
                         type=float,
                         default=6.0,
                         help='context score')
+    parser.add_argument('--punc', action='store_true', help='add punctuation')
+
+    parser.add_argument('-pm',
+                        '--punc_model_dir',
+                        default=None,
+                        help='specify your own punc model dir')
+
     args = parser.parse_args()
     return args

@@ -76,10 +84,17 @@ def main():
     else:
         model = load_model(args.language, args.model_dir, args.gpu, args.beam,
                            args.context_path, args.context_score, args.device)
+    punc_model = None
+    if args.punc:
+        punc_model = load_punc_model(args.punc_model_dir, args.gpu,
+                                     args.device)
     if args.align:
         result = model.align(args.audio_file, args.label)
     else:
         result = model.transcribe(args.audio_file, args.show_tokens_info)
+    if args.punc:
+        assert punc_model is not None
+        result['text_with_punc'] = punc_model(result['text'])
     print(result)


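For reference, a standalone sketch of the punctuation step added to main() above (not part of the diff; the result dict is a stand-in for model.transcribe, and the argument defaults mirror get_args()). If the package installs transcribe.py as its console entry point, this is what running it with --punc does:

from wenet.cli.punc_model import load_model as load_punc_model

# Defaults mirror get_args(): no custom model dir, gpu == -1, CPU device.
punc_model = load_punc_model(None, -1, 'cpu')

result = {'text': '大家好 欢迎使用 wenet'}  # stand-in for model.transcribe(...)
result['text_with_punc'] = punc_model(result['text'])
print(result)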
11 changes: 5 additions & 6 deletions wenet/paraformer/layers.py
@@ -3,18 +3,17 @@

 import math
 from typing import Optional, Tuple
-import torch

+import torch
 import torch.utils.checkpoint as ckpt

 from wenet.paraformer.attention import (DummyMultiHeadSANM,
                                         MultiHeadAttentionCross,
                                         MultiHeadedAttentionSANM)
 from wenet.paraformer.embedding import ParaformerPositinoalEncoding
 from wenet.paraformer.subsampling import IdentitySubsampling
-from wenet.transformer.encoder import BaseEncoder
 from wenet.transformer.decoder import TransformerDecoder
 from wenet.transformer.decoder_layer import DecoderLayer
+from wenet.transformer.encoder import BaseEncoder
 from wenet.transformer.encoder_layer import TransformerEncoderLayer
 from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
 from wenet.utils.mask import make_non_pad_mask
@@ -190,7 +189,7 @@ def __init__(
         num_blocks: int = 6,
         dropout_rate: float = 0.1,
         positional_dropout_rate: float = 0.1,
-        attention_dropout_rate: float = 0,
+        attention_dropout_rate: float = 0.0,
         input_layer: str = "conv2d",
         pos_enc_layer_type: str = "abs_pos",
         normalize_before: bool = True,
@@ -389,8 +388,8 @@ def __init__(
         num_blocks: int = 6,
         dropout_rate: float = 0.1,
         positional_dropout_rate: float = 0.1,
-        self_attention_dropout_rate: float = 0,
-        src_attention_dropout_rate: float = 0,
+        self_attention_dropout_rate: float = 0.0,
+        src_attention_dropout_rate: float = 0.0,
         input_layer: str = "embed",
         use_output_layer: bool = True,
         normalize_before: bool = True,

