Skip to content

sine2pi/collators

Folders and files

Name
Last commit message
Last commit date

Latest commit

 

History

10 Commits
 
 
 
 
 
 
 
 

Repository files navigation

On-the-Fly Spectrogram Parameter Adjustment:

Values for n_fft and hop_length are randomly selected from predefined ranges for each audio sample, providing varied spectrogram representations. Most ASR models use log-mel spectrograms, which are "picture" representations of sound (it's all data). If you augment the audio, you augment the spectrogram — you get the idea. It's probably best to hold off on this technique until you are starving for data (looking at you, low-resource-language people). I haven't tested this enough to say whether it works well, but the idea is sound :) You should be tokenizing and feature-extracting in your data collator anyway, so why not get a little free augmentation? The noise-injection part does increase overhead; I added it for perceived value.

Example Usage
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
    apply_augmentation=True,
    apply_noise_injection=True  # Enable adaptive noise injection
)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=data_collator)

for batch in dataloader:
    outputs = model(batch)


from dataclasses import dataclass
from typing import Any, Dict, List, Union
import torch
import torchaudio
import random


def add_adaptive_noise(audio, noise_type='white', base_intensity=0.005,
                       sample_rate=16000, noise_path='environmental_noise.wav'):
    """Add noise to ``audio``, scaled adaptively to the signal's loudness.

    The injected noise amplitude is ``base_intensity`` times the mean absolute
    amplitude of the input, so quiet clips get proportionally quieter noise.

    Args:
        audio: 1-D float tensor of raw waveform samples.
        noise_type: 'white' (default), 'pink', or 'environmental'. Any other
            value falls back to white noise, matching the original behavior.
        base_intensity: noise level relative to the signal's mean amplitude.
        sample_rate: sample rate in Hz used by the pink-noise highpass filter
            (was hard-coded to 16000).
        noise_path: file read for the 'environmental' profile
            (was hard-coded to 'environmental_noise.wav').

    Returns:
        ``audio + noise`` — a tensor with the same shape as ``audio``.
    """
    amplitude = audio.abs().mean()
    noise_intensity = base_intensity * amplitude  # louder signal -> louder noise

    if noise_type == 'environmental':
        # Use pre-recorded background noise instead of synthesizing it.
        # (The original generated white noise first and threw it away here.)
        # NOTE(review): assumes the file is mono and already at `sample_rate`;
        # a stereo file leaves a channel dim that may not broadcast — confirm.
        noise, _ = torchaudio.load(noise_path)
        noise = torch.nn.functional.interpolate(
            noise.unsqueeze(0), size=audio.size()
        ).squeeze() * noise_intensity
    else:
        noise = torch.randn_like(audio) * noise_intensity
        if noise_type == 'pink':
            # Cheap pink-ish profile: roll off the lowest frequencies.
            noise = torchaudio.functional.highpass_biquad(
                noise, sample_rate=sample_rate, cutoff_freq=200)
    return audio + noise

##### Pytorch version

def collate_fn(batch, apply_augmentation_flag=True, apply_noise_injection_flag=False):
    """Collate raw (audio, label) samples into padded, model-ready tensors.

    For each sample, optionally draws a random ``(n_fft, hop_length)`` pair so
    the same clip yields a slightly different spectrogram every epoch (cheap
    augmentation), and optionally injects adaptive noise before the transform.

    Args:
        batch: iterable of dicts with keys ``"audio"`` (flattenable float
            tensor of 16 kHz samples) and ``"label"`` (text string).
        apply_augmentation_flag: randomize STFT parameters per sample.
        apply_noise_injection_flag: also add adaptive noise; only honored when
            augmentation is on (matches the original control flow).

    Returns:
        dict with ``"input_features"`` (B, 80, T), ``"labels"`` (B, L) padded
        with -100 (conventionally ignored by the loss), and
        ``"dec_input_features"`` (B, L) padded with the tokenizer's pad id.

    NOTE(review): relies on a module-level ``tokenizer``; confirm it is bound
    before DataLoader workers invoke this.
    """
    n_fft_choices = [400, 800, 1024]
    hop_length_choices = [160, 320, 512]
    noise_profiles = ['white', 'pink', 'environmental']
    target_len = 16000 * 30  # Whisper-style fixed 30 s context at 16 kHz

    input_features, labels, dec_input_features = [], [], []

    for sample in batch:
        audio = sample["audio"].flatten()
        # Trim or zero-pad every clip to exactly 30 s. This replaces
        # `whisper.pad_or_trim`, which was referenced but never imported
        # (NameError at runtime).
        if audio.size(0) >= target_len:
            audio = audio[:target_len]
        else:
            audio = torch.nn.functional.pad(audio, (0, target_len - audio.size(0)))

        if apply_augmentation_flag:
            n_fft = random.choice(n_fft_choices)
            hop_length = random.choice(hop_length_choices)
            if apply_noise_injection_flag:
                audio = add_adaptive_noise(
                    audio, noise_type=random.choice(noise_profiles))
        else:
            n_fft = 1024
            hop_length = 512

        mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=16000,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=80,
        )(audio)
        # Log-compress with a small floor to avoid log(0).
        input_features.append(torch.log(mel + 1e-9))

        tokens = ([tokenizer.bos_token_id]
                  + tokenizer.encode(sample["label"])
                  + [tokenizer.eos_token_id])
        dec_input_features.append(tokens[:-1])  # decoder input: BOS .. last-1
        labels.append(tokens[1:])               # targets: shifted left by one

    # Different hop lengths yield different frame counts, so pad the time axis
    # to the batch max before stacking. (The original torch.stack crashed when
    # augmentation drew mixed hop lengths within one batch.)
    silence = float(torch.log(torch.tensor(1e-9)))  # same floor as above
    max_frames = max(feat.shape[-1] for feat in input_features)
    input_features = torch.stack([
        torch.nn.functional.pad(
            feat, (0, max_frames - feat.shape[-1]), value=silence)
        for feat in input_features
    ])

    max_len = max(
        max(len(l) for l in labels),
        max(len(d) for d in dec_input_features),
    )
    # Plain-list padding replaces np.pad / np.array: numpy was used but never
    # imported in this file.
    labels = torch.tensor(
        [l + [-100] * (max_len - len(l)) for l in labels],
        dtype=torch.long)
    dec_input_features = torch.tensor(
        [d + [tokenizer.pad_token_id] * (max_len - len(d))
         for d in dec_input_features],
        dtype=torch.long)

    return {
        "input_features": input_features,
        "labels": labels,
        "dec_input_features": dec_input_features,
    }

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published