Skip to content

Commit

Permalink
v3.2
Browse files Browse the repository at this point in the history
  • Loading branch information
BBC-Esq authored Oct 4, 2024
1 parent 929dfcf commit 7b002f2
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 63 deletions.
64 changes: 47 additions & 17 deletions ct2_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ def __init__(self, cuda_available=False):
model, quantization, device = "base.en", "int8", "cpu"
self.supported_quantizations = {"cpu": [], "cuda": []}

self.recorder.update_model(model, quantization, device)
self.record_button = QPushButton("Record", self)
self.record_button.clicked.connect(self.recorder.start_recording)
layout.addWidget(self.record_button)

for text, callback in [("Record", self.recorder.start_recording),
("Stop and Copy to Clipboard", self.recorder.save_audio)]:
button = QPushButton(text, self)
button.clicked.connect(callback)
layout.addWidget(button)
self.stop_button = QPushButton("Stop and Transcribe", self)
self.stop_button.clicked.connect(self.recorder.save_audio)
layout.addWidget(self.stop_button)

settings_group = QGroupBox("Settings")
settings_layout = QVBoxLayout()
Expand All @@ -41,7 +41,10 @@ def __init__(self, cuda_available=False):
model_label = QLabel('Model')
h_layout.addWidget(model_label)
self.model_dropdown = QComboBox(self)
self.model_dropdown.addItems(["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large-v2"])
self.model_dropdown.addItems([
"tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large-v2",
"distil-whisper-small.en", "distil-whisper-medium.en", "distil-whisper-large-v2", "distil-whisper-large-v3"
])
h_layout.addWidget(self.model_dropdown)
self.model_dropdown.setCurrentText(model)

Expand All @@ -64,9 +67,9 @@ def __init__(self, cuda_available=False):

settings_layout.addLayout(h_layout)

update_model_btn = QPushButton("Update Settings", self)
update_model_btn.clicked.connect(self.update_model)
settings_layout.addWidget(update_model_btn)
self.update_model_btn = QPushButton("Update Settings", self)
self.update_model_btn.clicked.connect(self.update_model)
settings_layout.addWidget(self.update_model_btn)

settings_group.setLayout(settings_layout)
layout.addWidget(settings_group)
Expand All @@ -75,19 +78,46 @@ def __init__(self, cuda_available=False):
self.setWindowFlag(Qt.WindowStaysOnTopHint)

self.device_dropdown.currentTextChanged.connect(self.update_quantization_options)
self.update_quantization_options(quantization)
self.model_dropdown.currentTextChanged.connect(self.update_quantization_options)
self.update_quantization_options()

def update_quantization_options(self, current_quantization):
self.recorder.update_status_signal.connect(self.update_status)
self.recorder.enable_widgets_signal.connect(self.set_widgets_enabled)

def update_quantization_options(self):
model = self.model_dropdown.currentText()
device = self.device_dropdown.currentText()
self.quantization_dropdown.clear()
options = self.supported_quantizations.get(self.device_dropdown.currentText(), [])
options = self.get_quantization_options(model, device)
self.quantization_dropdown.addItems(options)
if current_quantization in options:
self.quantization_dropdown.setCurrentText(current_quantization)
if self.quantization_dropdown.currentText() not in options and options:
self.quantization_dropdown.setCurrentText(options[0])

def get_quantization_options(self, model, device):
distil_models = {
"distil-whisper-small.en": ["float16", "bfloat16", "float32"],
"distil-whisper-medium.en": ["float16", "bfloat16", "float32"],
"distil-whisper-large-v2": ["float16", "float32"],
"distil-whisper-large-v3": ["float16", "bfloat16", "float32"]
}
if model in distil_models:
return distil_models[model]
else:
self.quantization_dropdown.setCurrentText("")
return self.supported_quantizations.get(device, [])

def update_model(self):
self.recorder.update_model(self.model_dropdown.currentText(), self.quantization_dropdown.currentText(), self.device_dropdown.currentText())
model_name = self.model_dropdown.currentText()
quantization = self.quantization_dropdown.currentText()
device = self.device_dropdown.currentText()
self.recorder.update_model(model_name, quantization, device)

def update_status(self, text):
self.status_label.setText(text)

def set_widgets_enabled(self, enabled):
self.record_button.setEnabled(enabled)
self.stop_button.setEnabled(enabled)
self.model_dropdown.setEnabled(enabled)
self.quantization_dropdown.setEnabled(enabled)
self.device_dropdown.setEnabled(enabled)
self.update_model_btn.setEnabled(enabled)
98 changes: 66 additions & 32 deletions ct2_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,58 +7,90 @@
from faster_whisper import WhisperModel
import yaml
from PySide6.QtWidgets import QApplication
from PySide6.QtGui import QClipboard
from PySide6.QtCore import QObject, Signal, Slot

class VoiceRecorder(QObject):
update_status_signal = Signal(str)
enable_widgets_signal = Signal(bool)
copy_to_clipboard_signal = Signal(str)

class VoiceRecorder:
def __init__(self, window, samplerate=44100, channels=1, dtype='int16'):
super().__init__()
self.samplerate, self.channels, self.dtype = samplerate, channels, dtype
self.window = window
self.is_recording, self.frames = False, []
self.model = None
self.load_settings()

# Connect the signal to the slot
self.copy_to_clipboard_signal.connect(self.copy_to_clipboard)

def load_settings(self):
try:
with open("config.yaml", "r") as f:
config = yaml.safe_load(f)
if "device_type" not in config:
config["device_type"] = "cpu"
if "model_name" not in config:
config["model_name"] = "base.en"
if "quantization_type" not in config:
config["quantization_type"] = "int8"
self.update_model(config["model_name"], config["quantization_type"], config["device_type"])
model_name = config.get("model_name", "base.en")
quantization_type = config.get("quantization_type", "int8")
device_type = config.get("device_type", "cpu")
self.update_model(model_name, quantization_type, device_type)
except FileNotFoundError:
self.update_model("base.en", "int8", "cpu")

def save_settings(self, model_name, quantization_type, device_type):
try:
with open("config.yaml", "r") as f:
config = yaml.safe_load(f)
except FileNotFoundError:
config = {}
config["model_name"] = model_name
config["quantization_type"] = quantization_type
config["device_type"] = device_type
config = {
"model_name": model_name,
"quantization_type": quantization_type,
"device_type": device_type
}
with open("config.yaml", "w") as f:
yaml.safe_dump(config, f)

def update_model(self, model_name, quantization_type, device_type):
model_str = f"ctranslate2-4you/whisper-{model_name}-ct2-{quantization_type}"
self.model = WhisperModel(model_str, device=device_type, compute_type=quantization_type, cpu_threads=26)
self.window.update_status(f"Model updated to {model_name} with {quantization_type} quantization on {device_type} device")
self.save_settings(model_name, quantization_type, device_type)

def transcribe_audio(self, audio_file):
segments, _ = self.model.transcribe(audio_file)
clipboard_text = "\n".join([segment.text for segment in segments])
self.enable_widgets_signal.emit(False)
self.update_status_signal.emit(f"Updating model to {model_name}...")

clipboard = QApplication.clipboard()
clipboard.setText(clipboard_text)
def update_model_thread():
try:
if model_name.startswith("distil-whisper"):
model_str = f"ctranslate2-4you/{model_name}-ct2-{quantization_type}"
else:
model_str = f"ctranslate2-4you/whisper-{model_name}-ct2-{quantization_type}"

self.model = WhisperModel(model_str, device=device_type, compute_type=quantization_type, cpu_threads=26)
self.save_settings(model_name, quantization_type, device_type)
self.update_status_signal.emit(f"Model updated to {model_name} on {device_type} device")
except Exception as e:
self.update_status_signal.emit(f"Error updating model: {str(e)}")
finally:
self.enable_widgets_signal.emit(True)

self.window.update_status("Audio saved and transcribed")
threading.Thread(target=update_model_thread).start()

def transcribe_audio(self, audio_file):
self.update_status_signal.emit("Transcribing audio...")
try:
segments, _ = self.model.transcribe(audio_file)
clipboard_text = "\n".join([segment.text for segment in segments])

# Emit signal to copy text to clipboard in main thread
self.copy_to_clipboard_signal.emit(clipboard_text)

self.update_status_signal.emit("Audio transcribed and copied to clipboard")
except Exception as e:
self.update_status_signal.emit(f"Transcription failed: {e}")
finally:
self.enable_widgets_signal.emit(True)
try:
os.remove(audio_file)
except OSError as e:
print(f"Error deleting temporary file: {e}")

@Slot(str)
def copy_to_clipboard(self, text):
QApplication.instance().clipboard().setText(text)

def record_audio(self):
self.window.update_status("Recording...")
self.update_status_signal.emit("Recording...")
def callback(indata, frames, time, status):
if status:
print(status)
Expand All @@ -69,18 +101,20 @@ def callback(indata, frames, time, status):

def save_audio(self):
self.is_recording = False
self.enable_widgets_signal.emit(False)
temp_filename = tempfile.mktemp(suffix=".wav")
data = np.concatenate(self.frames, axis=0)
with wave.open(temp_filename, "wb") as wf:
wf.setnchannels(self.channels)
wf.setsampwidth(2) # Always 2 for int16
wf.setframerate(self.samplerate)
wf.writeframes(data.tobytes())
self.transcribe_audio(temp_filename)
os.remove(temp_filename)

self.update_status_signal.emit("Audio saved, starting transcription...")
threading.Thread(target=self.transcribe_audio, args=(temp_filename,)).start()
self.frames.clear()

def start_recording(self):
if not self.is_recording:
self.is_recording = True
threading.Thread(target=self.record_audio).start()
threading.Thread(target=self.record_audio).start()
11 changes: 0 additions & 11 deletions ct2_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,10 @@ def set_cuda_paths():
paths_to_add = [str(cuda_path), str(cublas_path), str(cudnn_path)]
env_vars = ['CUDA_PATH', 'CUDA_PATH_V12_1', 'PATH']

# print(f"Virtual environment base: {venv_base}")
# print(f"NVIDIA base path: {nvidia_base_path}")
# print(f"CUDA path: {cuda_path}")
# print(f"cuBLAS path: {cublas_path}")
# print(f"cuDNN path: {cudnn_path}")

for env_var in env_vars:
current_value = os.environ.get(env_var, '')
new_value = os.pathsep.join(paths_to_add + [current_value] if current_value else paths_to_add)
os.environ[env_var] = new_value
# print(f"\n{env_var} updated:")
# print(f" Old value: {current_value}")
# print(f" New value: {new_value}")

# print("\nCUDA paths have been set or updated in the environment variables.")

set_cuda_paths()

Expand Down
4 changes: 1 addition & 3 deletions ct2_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,4 @@ def _update_supported_quantizations_in_config(self, device, quantizations):
config["supported_quantizations"][device] = quantizations

with open("config.yaml", "w") as f:
yaml.safe_dump(config, f, default_style="'")

# print(f"Updated {device} quantizations in config.yaml to: {quantizations}")
yaml.safe_dump(config, f, default_style="'")

0 comments on commit 7b002f2

Please sign in to comment.