Commit

feat: add diarization postprocessing
linozen committed Jan 19, 2025
1 parent 1868126 commit 5e19c7c
Showing 6 changed files with 465 additions and 0 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -74,6 +74,7 @@ dependencies = [
    "pywhispercpp>=1.3.0",
    "mlx-whisper>=0.4.1 ; sys_platform == 'darwin'",
    "word-levenshtein>=0.0.3",
    "openai>=1.59.8",
]

[dependency-groups]
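The new openai dependency is used only as a client for Ollama's OpenAI-compatible endpoint. A minimal sketch of that pattern, mirroring the defaults in config.py below (the prompt content is a placeholder):

from openai import OpenAI

# Ollama exposes an OpenAI-compatible API; the api_key argument is required
# by the client library but not validated by Ollama.
client = OpenAI(base_url="http://localhost:11434/v1", api_key="ollama")
response = client.chat.completions.create(
    model="phi4",
    messages=[{"role": "user", "content": "Hallo"}],
)
print(response.choices[0].message.content)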
227 changes: 227 additions & 0 deletions uv.lock

Large diffs are not rendered by default.

30 changes: 30 additions & 0 deletions verbatim/transcript/postprocessing/config.py
@@ -0,0 +1,30 @@
# PS08_verbatim/verbatim/transcript/postprocessing/config.py
from dataclasses import dataclass

@dataclass
class Config:
    """Configuration for the diarization processor"""
    MODEL_NAME: str = "phi4"
    API_KEY: str = "ollama"
    OLLAMA_BASE_URL: str = "http://localhost:11434/v1"
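    # The system prompt below is German; in English it says roughly: "You are
    # an expert at improving conversation transcripts. Structure dialogues so
    # that speaker changes appear natural and logical. Rules: place <speaker:x>
    # markers only at the start of coherent utterances; keep the original
    # content and optimize only the speaker assignment; emit exactly one line
    # break between utterances; output only the optimized dialogue, with no
    # introduction or explanations." A worked doctor-patient example follows.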
    SYSTEM_PROMPT: str = """Du bist ein Experte für die Verbesserung von Gesprächstranskripten. Deine Aufgabe ist es, Dialoge so zu strukturieren, dass die Sprecherwechsel natürlich und logisch erscheinen.
Wichtige Regeln:
- Platziere die <speaker:x> Markierungen nur am Anfang zusammenhängender Äußerungen
- Behalte den ursprünglichen Inhalt bei, optimiere nur die Sprecherzuweisung
- Gib immer nur einen Zeilenumbruch zwischen den Äußerungen aus
- Gib ausschließlich den optimierten Dialog aus, keine Einleitung oder Erklärungen
Beispiel:
Eingabe:
<speaker:1> Guten Tag, ich bin Dr. Schmidt. Können Sie mir sagen
<speaker:2> was Sie herführt? Ja, ich habe seit einigen Tagen
<speaker:1> Kopfschmerzen. Wie lange genau?
<speaker:2> Etwa eine Woche. -->
Ausgabe:
<speaker:1> Guten Tag, ich bin Dr. Schmidt. Können Sie mir sagen, was Sie herführt?
<speaker:2> Ja, ich habe seit einigen Tagen Kopfschmerzen.
<speaker:1> Wie lange genau?
<speaker:2> Etwa eine Woche."""
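A minimal sketch of overriding these defaults at runtime (the model name here is hypothetical and must already be pulled in Ollama):

from verbatim.transcript.postprocessing.config import Config

config = Config()
config.MODEL_NAME = "llama3.1"  # hypothetical; any local Ollama model name
config.OLLAMA_BASE_URL = "http://localhost:11434/v1"  # the commit's default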
58 changes: 58 additions & 0 deletions verbatim/transcript/postprocessing/main.py
@@ -0,0 +1,58 @@
import argparse
import json
import logging
from pathlib import Path

from verbatim.transcript.postprocessing.config import Config
from verbatim.transcript.postprocessing.processor import DiarizationProcessor
from verbatim.eval.metrics import calculate_metrics, format_metrics, format_improvements

logging.basicConfig(level=logging.INFO)
LOG = logging.getLogger(__name__)

def main():
    parser = argparse.ArgumentParser(description="Process diarized transcripts with LLM")
    parser.add_argument("input_json", type=Path, help="Path to input JSON file")
    parser.add_argument("--output-json", type=Path, help="Path to output JSON file")
    parser.add_argument("--ref-json", type=Path, help="Path to reference JSON file for evaluation")
    parser.add_argument("--chunk-size", type=int, default=3, help="Number of utterances per chunk")
    parser.add_argument("--model", type=str, default=Config.MODEL_NAME, help="Name of the Ollama model to use")

    args = parser.parse_args()

    output_path = args.output_json or args.input_json.with_suffix('.dlm.json')

    # Initialize configuration and processor; --model defaults to
    # Config.MODEL_NAME, so the assignment is always safe
    config = Config()
    config.MODEL_NAME = args.model

    processor = DiarizationProcessor(config)

    # Load input data
    with open(args.input_json) as f:
        input_data = json.load(f)

    # Process with LLM
    print("\nProcessing JSON with LLM...")
    processed_data = processor.process_json(
        input_path=args.input_json,
        output_path=output_path,
        chunk_size=args.chunk_size
    )

    # Evaluate against a reference transcript if one was provided
    if args.ref_json:
        with open(args.ref_json) as f:
            ref_data = json.load(f)

        print("\nEvaluating results...")
        before_metrics = calculate_metrics(input_data, ref_data)
        after_metrics = calculate_metrics(processed_data, ref_data)

        print(format_metrics(before_metrics, prefix="Before LLM postprocessing"))
        print(format_metrics(after_metrics, prefix="After LLM postprocessing"))
        print(format_improvements(before_metrics, after_metrics))

if __name__ == "__main__":
    main()
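A hypothetical invocation, assuming the module is executed directly (the commit registers no console script; file names are placeholders):

python -m verbatim.transcript.postprocessing.main transcript.json --ref-json reference.json --chunk-size 3 --model phi4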
130 changes: 130 additions & 0 deletions verbatim/transcript/postprocessing/processor.py
@@ -0,0 +1,130 @@
import logging
from typing import Tuple, List, Dict
from pathlib import Path
import json
import re

from openai import OpenAI

from verbatim.transcript.postprocessing.config import Config

LOG = logging.getLogger(__name__)

class DiarizationProcessor:
    def __init__(self, config: Config):
        """Initialize the processor with configuration"""
        self.config = config
        self.client = OpenAI(base_url=config.OLLAMA_BASE_URL, api_key=config.API_KEY)

    def extract_text_and_spk(self, completion: str) -> Tuple[str, str]:
        """Extract text and speaker labels from completion string"""
        spk = "1"
        previous_spk = "1"
        result_text = []
        result_spk = []

        for word in completion.split():
            if word.startswith("<speaker:"):
                # Repair a truncated tag, then pull out the speaker id
                if not word.endswith(">"):
                    word += ">"
                spk = word[len("<speaker:"):-len(">")]
                try:
                    spk_int = int(spk)
                    if spk_int < 1 or spk_int > 10:
                        raise ValueError(f"Unexpected speaker token: {word}")
                    previous_spk = spk
                except ValueError:
                    # Malformed or out-of-range id: fall back to the last valid speaker
                    LOG.warning(f"Skipping meaningless speaker token: {word}")
                    spk = previous_spk
            else:
                result_text.append(word)
                result_spk.append(spk)

        return " ".join(result_text), " ".join(result_spk)

    def process_chunk(self, text: str) -> Tuple[str, str]:
        """Process a single chunk of text"""
        try:
            response = self.client.chat.completions.create(
                model=self.config.MODEL_NAME,
                messages=[
                    {"role": "system", "content": self.config.SYSTEM_PROMPT},
                    {"role": "user", "content": f"{text} -->"}
                ],
                temperature=0.1
            )

            completion = response.choices[0].message.content

            # Log the original and rewritten chunk for manual inspection
            print("\nProcessing chunk:")
            print("=" * 80)
            print("Before:")
            print(text)  # Original formatted text
            print("\nAfter:")
            print(completion)

            return self.extract_text_and_spk(completion)

        except Exception as e:
            LOG.error(f"Error processing chunk: {e}")
            raise

    def clean_speaker_tag(self, tag: str) -> str:
        """Normalize a speaker tag, keeping only its first speaker number"""
        # Extract the first number from the tag
        match = re.search(r'<speaker:(\d+)', tag)
        if match:
            number = match.group(1)
            return f"<speaker:{number}>"
        return tag
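    # For illustration, a hypothetical malformed tag with a duplicated number,
    # "<speaker:2<speaker:2>", normalizes to "<speaker:2>" because re.search
    # matches the first "<speaker:" occurrence.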

    def format_chunk(self, utterances: List[Dict]) -> str:
        """Format a chunk of utterances into diarized text"""
        # Join utterances with proper speaker tags
        text_parts = []
        for utt in utterances:
            # Clean up any repeated speaker numbers in the input
            speaker_tag = self.clean_speaker_tag(f"<speaker:{utt['hyp_spk']}>")
            text_parts.append(f"{speaker_tag} {utt['hyp_text']}")

        return "\n".join(text_parts)

    def process_json(self, input_path: Path, output_path: Path, chunk_size: int = 3) -> Dict:
        """Process entire JSON file and save results"""
        with open(input_path) as f:
            data = json.load(f)

        output_utterances = []
        current_chunk = []

        for utterance in data["utterances"]:
            current_chunk.append(utterance)

            if len(current_chunk) >= chunk_size:
                chunk_text = self.format_chunk(current_chunk)
                text, spk = self.process_chunk(chunk_text)

                output_utterances.append({
                    "utterance_id": f"utt{len(output_utterances)}",
                    "hyp_text": text,
                    "hyp_spk": spk
                })
                current_chunk = []

        # Process remaining utterances
        if current_chunk:
            chunk_text = self.format_chunk(current_chunk)
            text, spk = self.process_chunk(chunk_text)
            output_utterances.append({
                "utterance_id": f"utt{len(output_utterances)}",
                "hyp_text": text,
                "hyp_spk": spk
            })

        output_data = {"utterances": output_utterances}

        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(output_data, f, ensure_ascii=False, indent=2)

        return output_data
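A sketch of the JSON contract that process_json implies (file names are placeholders and a running local Ollama server is assumed). Input and output share the top-level {"utterances": [...]} shape, but each output entry covers a whole chunk and its "hyp_spk" becomes a space-separated label per word:

from pathlib import Path

from verbatim.transcript.postprocessing.config import Config
from verbatim.transcript.postprocessing.processor import DiarizationProcessor

# transcript.json: {"utterances": [{"hyp_spk": "1", "hyp_text": "Guten Tag"}, ...]}
processor = DiarizationProcessor(Config())
result = processor.process_json(Path("transcript.json"), Path("transcript.dlm.json"))
# result: {"utterances": [{"utterance_id": "utt0", "hyp_text": "...", "hyp_spk": "1 1 2 ..."}]}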
19 changes: 19 additions & 0 deletions verbatim/transcript/postprocessing/utils.py
@@ -0,0 +1,19 @@
from typing import List, Dict

import termcolor


def color_diff_line(line: str) -> str:
    """Color diff lines based on their prefix"""
    if line.startswith('^'):
        # Highlight structural-change markers in yellow (or any color you prefer)
        return termcolor.colored(line, 'yellow')
    return line


def format_chunk_for_display(utterances: List[Dict]) -> str:
    """Format utterances as raw diarized text"""
    return " ".join([
        f"<speaker:{utt['hyp_spk']}> {utt['hyp_text']}"
        for utt in utterances
    ])
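A quick usage sketch for format_chunk_for_display (data hypothetical):

>>> format_chunk_for_display([
...     {"hyp_spk": "1", "hyp_text": "Guten Tag"},
...     {"hyp_spk": "2", "hyp_text": "Hallo"},
... ])
'<speaker:1> Guten Tag <speaker:2> Hallo'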
