Skip to content

Commit

Permalink
feat: add support for importing .pt speaker files
Browse files Browse the repository at this point in the history
  • Loading branch information
6drf21e committed Jun 16, 2024
1 parent 3bb64a5 commit 3815713
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 40 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
| 版本 | 地址 | 介绍 |
|----------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------|
| 在线Colab版 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/6drf21e/ChatTTS_colab/blob/main/chattts_webui_mix.ipynb) | 可以在 Google Colab 上一键运行,需要 Google账号,Colib 自带15GB的GPU |
| 在线Colab版 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/6drf21e/ChatTTS_colab/blob/main/chattts_webui_mix.ipynb) | 可以在 Google Colab 上一键运行,需要 Google账号,Colab 自带15GB的GPU |
| 离线整合版 | [百度网盘](https://pan.baidu.com/s/1-hGiPLs6ORM8sZv0xTdxFA?pwd=h3c5) 提取码: h3c5 | 下载本地运行,支持 GPU/CPU,适用 Windows 10 及以上 |
| 离线整合版 | [夸克网盘](https://pan.quark.cn/s/c963e147f204) | 下载本地运行,支持 GPU/CPU,适用 Windows 10 及以上 |

Expand Down
46 changes: 26 additions & 20 deletions tts_model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import ChatTTS
import torch
import numpy as np
import datetime
import json
import os
import re
import time

import numpy as np
import torch
from tqdm import tqdm
import datetime

import ChatTTS
from config import DEFAULT_TEMPERATURE, DEFAULT_TOP_P, DEFAULT_TOP_K
import json


def load_chat_tts_model(source='huggingface', force_redownload=False, local_path=None):
Expand Down Expand Up @@ -46,33 +48,37 @@ def deterministic(seed=0):
torch.backends.cudnn.benchmark = False


def generate_audio_for_seed(chat, seed, texts, batch_size, speed, refine_text_prompt,roleid=None, temperature=DEFAULT_TEMPERATURE,
def generate_audio_for_seed(chat, seed, texts, batch_size, speed, refine_text_prompt, roleid=None,
temperature=DEFAULT_TEMPERATURE,
top_P=DEFAULT_TOP_P, top_K=DEFAULT_TOP_K, cur_tqdm=None, skip_save=False,
skip_refine_text=False):
skip_refine_text=False, speaker_type="seed", pt_file=None):
from utils import combine_audio, save_audio, batch_split
# torch.manual_seed(seed)
# top_P = 0.7,
# top_K = 20,
# temperature = 0.3,
if seed in [None, -1, 0, "", "random"]:
seed = np.random.randint(0, 9999)

if not roleid:
print(f"speaker_type: {speaker_type}")
if speaker_type == "seed":
if seed in [None, -1, 0, "", "random"]:
seed = np.random.randint(0, 9999)
deterministic(seed)
rnd_spk_emb = chat.sample_random_speaker()
else:

elif speaker_type == "role":
# 从 JSON 文件中读取数据
with open('./slct_voice_240605.json', 'r', encoding='utf-8') as json_file:
slct_idx_loaded = json.load(json_file)

# 将包含 Tensor 数据的部分转换回 Tensor 对象
for key in slct_idx_loaded:
tensor_list = slct_idx_loaded[key]["tensor"]
slct_idx_loaded[key]["tensor"] = torch.tensor(tensor_list)

# 将音色 tensor 打包进params_infer_code,固定使用此音色发音,调低temperature
rnd_spk_emb = slct_idx_loaded[roleid]["tensor"]
# temperature = 0.001
elif speaker_type == "pt":
print(pt_file)
rnd_spk_emb = torch.load(pt_file)
print(rnd_spk_emb.shape)
if rnd_spk_emb.shape != (768,):
raise ValueError("维度应为 768。")
else:
raise ValueError(f"Invalid speaker_type: {speaker_type}. ")

params_infer_code = {
'spk_emb': rnd_spk_emb,
'prompt': f'[speed_{speed}]',
Expand Down Expand Up @@ -116,7 +122,7 @@ def generate_audio_for_seed(chat, seed, texts, batch_size, speed, refine_text_pr


def generate_refine_text(chat, seed, text, refine_text_prompt, temperature=DEFAULT_TEMPERATURE,
top_P=DEFAULT_TOP_P, top_K=DEFAULT_TOP_K):
top_P=DEFAULT_TOP_P, top_K=DEFAULT_TOP_K):
if seed in [None, -1, 0, "", "random"]:
seed = np.random.randint(0, 9999)

Expand Down
96 changes: 77 additions & 19 deletions webui_mix.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
import random
import gradio as gr
import json
from utils import combine_audio, save_audio, batch_split, normalize_zh
from tts_model import load_chat_tts_model, clear_cuda_cache, deterministic, generate_audio_for_seed
from utils import batch_split, normalize_zh
from tts_model import load_chat_tts_model, clear_cuda_cache, generate_audio_for_seed
from config import DEFAULT_BATCH_SIZE, DEFAULT_SPEED, DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P, DEFAULT_ORAL, \
DEFAULT_LAUGH, DEFAULT_BK, DEFAULT_SEG_LENGTH

Expand Down Expand Up @@ -142,7 +142,7 @@ def generate_seeds(num_seeds, texts, tq):
for _ in tq(range(num_seeds), desc=f"随机音色生成中..."):
seed = np.random.randint(0, 9999)

filename = generate_audio_for_seed(chat, seed, texts, 1, 5, "[oral_2][laugh_0][break_4]",None,0.3, 0.7, 20)
filename = generate_audio_for_seed(chat, seed, texts, 1, 5, "[oral_2][laugh_0][break_4]", None, 0.3, 0.7, 20)
seeds.append((filename, seed))
clear_cuda_cache()

Expand Down Expand Up @@ -261,7 +261,7 @@ def seed_change(evt: gr.SelectData):


def generate_tts_audio(text_file, num_seeds, seed, speed, oral, laugh, bk, min_length, batch_size, temperature, top_P,
top_K,roleid=None,refine_text=True, progress=gr.Progress()):
top_K, roleid=None, refine_text=True, speaker_type="seed", pt_file=None, progress=gr.Progress()):
from tts_model import generate_audio_for_seed
from utils import split_text, replace_tokens, restore_tokens
if seed in [0, -1, None]:
Expand All @@ -282,11 +282,26 @@ def generate_tts_audio(text_file, num_seeds, seed, speed, oral, laugh, bk, min_l

refine_text_prompt = f"[oral_{oral}][laugh_{laugh}][break_{bk}]"
try:
output_files = generate_audio_for_seed(chat, seed, texts, batch_size, speed, refine_text_prompt,roleid,temperature,
top_P, top_K, progress.tqdm, False, not refine_text)
output_files = generate_audio_for_seed(
chat=chat,
seed=seed,
texts=texts,
batch_size=batch_size,
speed=speed,
refine_text_prompt=refine_text_prompt,
roleid=roleid,
temperature=temperature,
top_P=top_P,
top_K=top_K,
cur_tqdm=progress.tqdm,
skip_save=False,
skip_refine_text=not refine_text,
speaker_type=speaker_type,
pt_file=pt_file,
)
return output_files
except Exception as e:
return str(e)
raise e


def generate_refine(text_file, oral, laugh, bk, temperature, top_P, top_K, progress=gr.Progress()):
Expand Down Expand Up @@ -439,10 +454,35 @@ def inser_token(text, btn):
with gr.Column():
gr.Markdown("### 配置参数")
with gr.Row():
num_seeds_input = gr.Number(label="生成音频的数量", value=1, precision=0, visible=False)
seed_input = gr.Number(label="指定种子", info="种子决定音色 0则随机", value=None, precision=0)
roleid = gr.Dropdown(label="定制音色", choices=["选择音色后会覆盖种子","1", "2", "3", "4","5","6","7","21","8","9"],info="音色编号:1发姐,2纯情男大学生,3阳光开朗大男孩,4知心小姐姐,5电视台女主持,6魅力大叔,7优雅甜美,21贴心男宝2,8正式打工人,9贴心男宝1")
generate_audio_seed = gr.Button("\U0001F3B2")
with gr.Column():
gr.Markdown("音色选择")
num_seeds_input = gr.Number(label="生成音频的数量", value=1, precision=0, visible=False)
speaker_stat = gr.State(value="seed")
tab_seed = gr.Tab(label="种子")
with tab_seed:
with gr.Row():
seed_input = gr.Number(label="指定种子", info="种子决定音色 0则随机", value=None,
precision=0)
generate_audio_seed = gr.Button("\U0001F3B2")
tab_roleid = gr.Tab(label="内置音色")
with tab_roleid:
roleid_input = gr.Dropdown(label="内置音色",
choices=[("发姐", "1"),
("纯情男大学生", "2"),
("阳光开朗大男孩", "3"),
("知心小姐姐", "4"),
("电视台女主持", "5"),
("魅力大叔", "6"),
("优雅甜美", "7"),
("贴心男宝2", "21"),
("正式打工人", "8"),
("贴心男宝1", "9")],
value="1",
info="选择音色后会覆盖种子。感谢 @QuantumDriver 提供音色")
tab_pt = gr.Tab(label="上传.PT文件")
with tab_pt:
pt_input = gr.File(label="上传音色文件", file_types=[".pt"], height=100)

with gr.Row():
style_select = gr.Radio(label="预设参数", info="语速部分可自行更改",
choices=["小说朗读", "闲聊", "默认"], interactive=True, )
Expand All @@ -458,14 +498,15 @@ def inser_token(text, btn):
bk_input = gr.Slider(label="停顿", minimum=0, maximum=7, value=DEFAULT_BK, step=1)
# gr.Markdown("### 文本参数")
with gr.Row():
min_length_input = gr.Number(label="文本分段长度", info="大于这个数值进行分段", value=DEFAULT_SEG_LENGTH,
precision=0)
batch_size_input = gr.Number(label="批大小", info="越高越快 太高爆显存 4G推荐3 其他酌情", value=DEFAULT_BATCH_SIZE,
precision=0)
min_length_input = gr.Number(label="文本分段长度", info="大于这个数值进行分段",
value=DEFAULT_SEG_LENGTH, precision=0)
batch_size_input = gr.Number(label="批大小", info="越高越快 太高爆显存 4G推荐3 其他酌情",
value=DEFAULT_BATCH_SIZE, precision=0)
with gr.Accordion("其他参数", open=False):
with gr.Row():
# 温度 top_P top_K
temperature_input = gr.Slider(label="温度", minimum=0.01, maximum=1.0, step=0.01, value=DEFAULT_TEMPERATURE)
temperature_input = gr.Slider(label="温度", minimum=0.01, maximum=1.0, step=0.01,
value=DEFAULT_TEMPERATURE)
top_P_input = gr.Slider(label="top_P", minimum=0.1, maximum=0.9, step=0.05, value=DEFAULT_TOP_P)
top_K_input = gr.Slider(label="top_K", minimum=1, maximum=20, step=1, value=DEFAULT_TOP_K)
# reset 按钮
Expand All @@ -482,6 +523,21 @@ def inser_token(text, btn):
outputs=seed_input)


def do_tab_change(evt: "gr.SelectData"):
    """Map the label of the selected speaker tab to an internal speaker_type token.

    Wired as the ``.select`` handler of the three speaker tabs (seed /
    built-in role / uploaded .pt file); the returned token is stored in the
    ``speaker_stat`` state and later dispatched on by the audio generator.

    Args:
        evt: Gradio ``SelectData`` event; ``evt.value`` is the selected tab's label.

    Returns:
        ``"seed"``, ``"role"`` or ``"pt"``; falls back to ``"seed"`` for any
        unrecognized label.
    """
    # NOTE: keys must stay byte-identical to the gr.Tab(label=...) strings.
    label_to_type = {
        "种子": "seed",
        "内置音色": "role",
        "上传.PT文件": "pt",
    }
    # Removed leftover debug print of evt.selected/index/value/target.
    return label_to_type.get(evt.value, "seed")


tab_seed.select(do_tab_change, outputs=speaker_stat)
tab_roleid.select(do_tab_change, outputs=speaker_stat)
tab_pt.select(do_tab_change, outputs=speaker_stat)


def do_style_select(x):
if x == "小说朗读":
return [4, 0, 0, 2]
Expand Down Expand Up @@ -526,8 +582,10 @@ def do_style_select(x):
temperature_input,
top_P_input,
top_K_input,
roleid,
roleid_input,
refine_text_input,
speaker_stat,
pt_input
],
outputs=[output_audio]
)
Expand Down Expand Up @@ -686,7 +744,7 @@ def batch(iterable, batch_size):
texts = [normalize_zh(line["txt"]) for line in batch_lines]
print(f"seed={seed} t={texts} c={character} s={speed} r={refine_text_prompt}")
wavs = generate_audio_for_seed(chat, int(seed), texts, DEFAULT_BATCH_SIZE, speed,
refine_text_prompt,None,DEFAULT_TEMPERATURE, DEFAULT_TOP_P,
refine_text_prompt, None, DEFAULT_TEMPERATURE, DEFAULT_TOP_P,
DEFAULT_TOP_K, skip_save=True) # 批量处理文本
batch_results[character].extend(wavs)

Expand Down Expand Up @@ -782,7 +840,7 @@ def batch(iterable, batch_size):
placeholder="请输入API Base URL",
value=r"https://api.openai.com/v1")
openai_api_key_input = gr.Textbox(label="OpenAI API Key", placeholder="请输入API Key",
value="sk-xxxxxxx",type="password")
value="sk-xxxxxxx", type="password")
# AI提示词
ai_text_input = gr.Textbox(label="剧情简介或者一段故事", placeholder="请输入文本...", lines=2,
value=ai_text_default)
Expand Down

0 comments on commit 3815713

Please sign in to comment.