forked from gaspardpetit/verbatim
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add diarization postprocessing
- Loading branch information
Showing
6 changed files
with
465 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# PS08_verbatim/verbatim/transcript/postprocessing/config.py | ||
from dataclasses import dataclass | ||
|
||
@dataclass
class Config:
    """Configuration for the diarization processor"""
    # Ollama model used for postprocessing; overridable via the CLI --model flag.
    MODEL_NAME: str = "phi4"
    # Placeholder key: the OpenAI client requires a non-empty key, Ollama ignores it.
    API_KEY: str = "ollama"
    # OpenAI-compatible endpoint exposed by a local Ollama server.
    OLLAMA_BASE_URL: str = "http://localhost:11434/v1"
    # German system prompt instructing the LLM to re-attribute <speaker:x> turns
    # while keeping the transcript content intact. Runtime string sent to the
    # model verbatim -- do not translate or reformat.
    SYSTEM_PROMPT: str = """Du bist ein Experte für die Verbesserung von Gesprächstranskripten. Deine Aufgabe ist es, Dialoge so zu strukturieren, dass die Sprecherwechsel natürlich und logisch erscheinen.
Wichtige Regeln:
- Platziere die <speaker:x> Markierungen nur am Anfang zusammenhängender Äußerungen
- Behalte den ursprünglichen Inhalt bei, optimiere nur die Sprecherzuweisung
- Gib immer nur einen Zeilenumbruch zwischen den Äußerungen aus
- Gib ausschließlich den optimierten Dialog aus, keine Einleitung oder Erklärungen
Beispiel:
Eingabe:
<speaker:1> Guten Tag, ich bin Dr. Schmidt. Können Sie mir sagen
<speaker:2> was Sie herführt? Ja, ich habe seit einigen Tagen
<speaker:1> Kopfschmerzen. Wie lange genau?
<speaker:2> Etwa eine Woche. -->
Ausgabe:
<speaker:1> Guten Tag, ich bin Dr. Schmidt. Können Sie mir sagen, was Sie herführt?
<speaker:2> Ja, ich habe seit einigen Tagen Kopfschmerzen.
<speaker:1> Wie lange genau?
<speaker:2> Etwa eine Woche."""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import argparse | ||
import json | ||
import logging | ||
from pathlib import Path | ||
|
||
from verbatim.transcript.postprocessing.config import Config | ||
from verbatim.transcript.postprocessing.processor import DiarizationProcessor | ||
from verbatim.eval.metrics import calculate_metrics, format_metrics, format_improvements | ||
|
||
logging.basicConfig(level=logging.INFO) | ||
LOG = logging.getLogger(__name__) | ||
|
||
def main():
    """CLI entry point: postprocess a diarized transcript JSON with an LLM.

    Reads the input JSON, runs DiarizationProcessor over it in chunks, writes
    the result next to the input (or to --output-json), and, when --ref-json
    is given, prints before/after diarization metrics against the reference.
    """
    parser = argparse.ArgumentParser(description="Process diarized transcripts with LLM")
    parser.add_argument("input_json", type=Path, help="Path to input JSON file")
    parser.add_argument("--output-json", type=Path, help="Path to output JSON file")
    parser.add_argument("--ref-json", type=Path, help="Path to reference JSON file for evaluation")
    parser.add_argument("--chunk-size", type=int, default=3, help="Number of utterances per chunk")
    parser.add_argument("--model", type=str, default=Config.MODEL_NAME, help="Name of the Ollama model to use")

    args = parser.parse_args()

    # Default output path: <input stem>.dlm.json alongside the input file.
    output_path = args.output_json or args.input_json.with_suffix(".dlm.json")

    # Initialize configuration and processor.
    config = Config()
    if args.model:
        config.MODEL_NAME = args.model

    processor = DiarizationProcessor(config)

    # Explicit UTF-8: the transcripts contain German umlauts and must not
    # depend on the platform default encoding (fails with cp125x on Windows).
    with open(args.input_json, encoding="utf-8") as f:
        input_data = json.load(f)

    # Process with LLM.
    print("\nProcessing JSON with LLM...")
    processed_data = processor.process_json(
        input_path=args.input_json,
        output_path=output_path,
        chunk_size=args.chunk_size,
    )

    # Evaluate only if a reference transcript was provided.
    if args.ref_json:
        with open(args.ref_json, encoding="utf-8") as f:
            ref_data = json.load(f)

        print("\nEvaluating results...")
        before_metrics = calculate_metrics(input_data, ref_data)
        after_metrics = calculate_metrics(processed_data, ref_data)

        print(format_metrics(before_metrics, prefix="Before LLM postprocessing"))
        print(format_metrics(after_metrics, prefix="After LLM postprocessing"))
        print(format_improvements(before_metrics, after_metrics))


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
import logging | ||
from typing import Tuple, List, Dict | ||
from pathlib import Path | ||
import json | ||
import re | ||
|
||
from openai import OpenAI | ||
|
||
from verbatim.transcript.postprocessing.config import Config | ||
|
||
LOG = logging.getLogger(__name__) | ||
|
||
class DiarizationProcessor:
    """Re-attributes speaker turns in a diarized transcript with an LLM.

    Sends chunks of <speaker:n>-tagged utterances to an OpenAI-compatible
    endpoint (a local Ollama server) and parses the rewritten dialogue back
    into parallel word/speaker-id sequences.
    """

    def __init__(self, config: Config):
        """Initialize the processor with configuration."""
        self.config = config
        self.client = OpenAI(base_url=config.OLLAMA_BASE_URL, api_key=config.API_KEY)

    def extract_text_and_spk(self, completion: str) -> Tuple[str, str]:
        """Extract text and speaker labels from completion string.

        Returns a pair of space-joined strings: the words of the completion
        with speaker tags removed, and (in parallel, one id per word) the
        speaker id active at each word. Malformed or out-of-range speaker
        tokens are logged and the previous speaker is kept.
        """
        spk = "1"
        previous_spk = "1"
        result_text = []
        result_spk = []

        for word in completion.split():
            if word.startswith("<speaker:"):
                # The model sometimes drops the closing ">"; repair it so the
                # id can be sliced out uniformly.
                if not word.endswith(">"):
                    word += ">"
                spk = word[len("<speaker:"):-len(">")]
                try:
                    # int("") / int("abc") raise ValueError, so no separate
                    # emptiness check is needed; ids outside 1..10 are also
                    # rejected by raising into the same handler.
                    if not 1 <= int(spk) <= 10:
                        raise ValueError(f"Unexpected speaker token: {word}")
                    previous_spk = spk
                except ValueError:
                    # Fall back to the previous speaker instead of failing
                    # the whole chunk on one bad token.
                    LOG.warning("Skipping meaningless speaker token: %s", word)
                    spk = previous_spk
            else:
                result_text.append(word)
                result_spk.append(spk)

        return " ".join(result_text), " ".join(result_spk)

    def process_chunk(self, text: str) -> Tuple[str, str]:
        """Process a single chunk of diarized text through the LLM.

        Returns the (text, speaker-ids) pair parsed from the completion.
        Raises whatever the OpenAI client raises on transport/API failure.
        """
        try:
            response = self.client.chat.completions.create(
                model=self.config.MODEL_NAME,
                messages=[
                    {"role": "system", "content": self.config.SYSTEM_PROMPT},
                    # The trailing "-->" mirrors the delimiter used in the
                    # system prompt's worked example.
                    {"role": "user", "content": f"{text} -->"},
                ],
                temperature=0.1,  # keep the rewrite near-deterministic
            )

            completion = response.choices[0].message.content

            # Show a before/after view of the chunk for manual inspection.
            print("\nProcessing chunk:")
            print("=" * 80)
            print("Before:")
            print(text)  # Original formatted text
            print("\nAfter:")
            print(completion)

            return self.extract_text_and_spk(completion)

        except Exception as e:
            LOG.error("Error processing chunk: %s", e)
            raise

    def clean_speaker_tag(self, tag: str) -> str:
        """Clean up speaker tags by keeping only the first speaker number."""
        match = re.search(r'<speaker:(\d+)', tag)
        if match:
            return f"<speaker:{match.group(1)}>"
        return tag

    def format_chunk(self, utterances: List[Dict]) -> str:
        """Format a chunk of utterances into newline-separated diarized text."""
        text_parts = []
        for utt in utterances:
            # Clean up any repeated speaker numbers in the input.
            speaker_tag = self.clean_speaker_tag(f"<speaker:{utt['hyp_spk']}>")
            text_parts.append(f"{speaker_tag} {utt['hyp_text']}")

        return "\n".join(text_parts)

    def process_json(self, input_path: Path, output_path: Path, chunk_size: int = 3) -> Dict:
        """Process entire JSON file and save results.

        Reads {"utterances": [...]} from input_path, runs the LLM over groups
        of chunk_size utterances, writes the merged result to output_path and
        returns it.
        """
        # Explicit UTF-8 on read to match the UTF-8 write below: German
        # transcripts must not depend on the platform default encoding.
        with open(input_path, encoding="utf-8") as f:
            data = json.load(f)

        output_utterances = []

        def flush(chunk: List[Dict]) -> None:
            # One merged output utterance per processed chunk.
            text, spk = self.process_chunk(self.format_chunk(chunk))
            output_utterances.append({
                "utterance_id": f"utt{len(output_utterances)}",
                "hyp_text": text,
                "hyp_spk": spk,
            })

        current_chunk = []
        for utterance in data["utterances"]:
            current_chunk.append(utterance)
            if len(current_chunk) >= chunk_size:
                flush(current_chunk)
                current_chunk = []

        # Process any trailing utterances that did not fill a whole chunk.
        if current_chunk:
            flush(current_chunk)

        output_data = {"utterances": output_utterances}

        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(output_data, f, ensure_ascii=False, indent=2)

        return output_data
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import difflib | ||
from typing import List, Dict | ||
|
||
import termcolor | ||
|
||
|
||
def color_diff_line(line: str) -> str:
    """Color diff lines based on their prefix.

    Lines beginning with '^' (difflib's structural-change marker) are
    rendered in yellow; everything else passes through untouched.
    """
    return termcolor.colored(line, 'yellow') if line.startswith('^') else line
|
||
|
||
def format_chunk_for_display(utterances: List[Dict]) -> str:
    """Format utterances as raw diarized text.

    Each utterance is rendered as "<speaker:ID> TEXT" from its 'hyp_spk'
    and 'hyp_text' fields; the pieces are joined with single spaces.
    """
    rendered = []
    for utterance in utterances:
        rendered.append(f"<speaker:{utterance['hyp_spk']}> {utterance['hyp_text']}")
    return " ".join(rendered)