diff --git a/cli/helpers.py b/cli/helpers.py index b717ab295..cdd2986e9 100644 --- a/cli/helpers.py +++ b/cli/helpers.py @@ -167,6 +167,8 @@ def get_item_yaml_values( if values: if isinstance(values, list): values_set = set(values) + elif isinstance(values, dict): + values_set = values else: values_set.add(values) values_dict[key] = values_set diff --git a/cli/item_to_function.py b/cli/item_to_function.py index 87f622d47..a1bb4b168 100644 --- a/cli/item_to_function.py +++ b/cli/item_to_function.py @@ -55,17 +55,17 @@ help="If -b/--bump_version is enabled, increase the minor version in the item.yaml file", ) def item_to_function_cli( - item_path: str, output_path: Optional[str], code_output: bool, format_code: bool, bump_version: bool + item_path: str, output_path: Optional[str], code_output: bool, format_code: bool, bump_version: bool ): item_to_function(item_path, output_path, code_output, format_code, bump_version) def item_to_function( - item_path: str, - output_path: Optional[str] = None, - code_output: bool = False, - format_code: bool = True, - bump_version: bool = False, + item_path: str, + output_path: Optional[str] = None, + code_output: bool = False, + format_code: bool = True, + bump_version: bool = False, ): item_path = Path(item_path) if item_path.is_dir(): @@ -78,9 +78,9 @@ def item_to_function( # That means we need to search for items inside this direcotry else: for inner_dir in PathIterator( - root=item_path.parent, - rule=is_item_dir, - as_path=True, + root=item_path.parent, + rule=is_item_dir, + as_path=True, ): try: _output_path = output_path or (inner_dir / "function.yaml") @@ -119,11 +119,11 @@ def _get_item_yaml(item_path: Path) -> dict: def create_function_yaml( - item_path: Union[str, Path], - output_path: Optional[str] = None, - code_output: bool = False, - format_code: bool = True, - bump_version: bool = False, + item_path: Union[str, Path], + output_path: Optional[str] = None, + code_output: bool = False, + format_code: bool = True, + bump_version: bool = False, ): item_path = Path(item_path) if bump_version: @@ -161,7 +161,8 @@ def create_function_yaml( # remove build info from object function_object.spec.build.code_origin = '' function_object.spec.build.origin_filename = '' - function_object.spec.state_thresholds=None + if 'state_thresholds' not in spec: + function_object.spec.state_thresholds = None custom_fields = spec.get("customFields", {}) for key, value in custom_fields.items(): diff --git a/pii_recognizer/function.yaml b/pii_recognizer/function.yaml index 086bc3867..54b448d9c 100644 --- a/pii_recognizer/function.yaml +++ b/pii_recognizer/function.yaml @@ -2,8 +2,8 @@ kind: job metadata: name: pii-recognizer tag: '' - hash: 0972dbbfd83e86970a3655774ace0c074ea617ce - project: llm-workflow-gilads + hash: b09b7b9a4ffd55088d665a0191055411e9198a2f + project: '' labels: author: pgw categories: @@ -14,48 +14,67 @@ spec: args: [] image: '' build: - functionSourceCode: # Copyright 2019 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import warnings
import os
import logging
import mlrun
import pathlib
import tempfile
import nltk
from tqdm.auto import tqdm
from typing import List, Tuple, Set, Optional, Dict, Any, Union
from collections.abc import Iterable
import presidio_analyzer as pa
import presidio_anonymizer as pre_anoymizer
from presidio_anonymizer.entities import OperatorConfig
import annotated_text.util as at_util

try:
    import flair as fl
except ModuleNotFoundError:
    print("Flair is not installed")

# There is a conflict between Rust-based tokenizers' parallel processing
# and Python's fork operations during multiprocessing. To avoid this, we need
# the following two lines

os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")

logger = logging.getLogger("pii-recognizer")


# Add the constant classes of Models and Entities to govern the whole package
class Models:
    WHOLE = "whole"
    PATTERN = "pattern"
    SPACY = "spacy"
    FLAIR = "flair"


class Entities:
    CREDIT_CARD = "CREDIT_CARD"
    SSN = "SSN"
    PHONE = "PHONE"
    EMAIL = "EMAIL"
    LOCATION = "LOCATION"
    PERSON = "PERSON"
    NRP = "NRP"
    ORGANIZATION = "ORGANIZATION"
    DATE_TIME = "DATE_TIME"
    GPE = ("GPE",)
    MAC_ADDRESS = "MAC_ADDRESS"
    US_BANK_NUMBER = "US_BANK_NUMBER"
    IMEI = "IMEI"
    TITLE = "TITLE"
    LICENSE_PLATE = "LICENSE_PLATE"
    US_PASSPORT = "US_PASSPORT"
    CURRENCY = "CURRENCY"
    ROUTING_NUMBER = "ROUTING_NUMBER"
    US_ITIN = "US_ITIN"
    US_BANK_NUMBER = "US_BANK_NUMBER"
    US_DRIVER_LICENSE = "US_DRIVER_LICENSE"
    AGE = "AGE"
    PASSWORD = "PASSWORD"
    SWIFT_CODE = "SWIFT_CODE"


class PatternRecognizerFactory:
    """
    Factory for creating pattern recognizers, it can be extended in the future to
    add more regex pattern for different entities. For the pattern recognizer to work,
    we need construct a list of regex patterns for each entity.
    """

    RECOGNIZABLE_ENTITIES = {
        "CREDIT_CARD": [pa.Pattern("CREDIT_CARD", r"\b(?:\d[ -]*?){13,16}\b", 0.5)],
        "SSN": [pa.Pattern("SSN", r"\b\d{3}-?\d{2}-?\d{4}\b", 0.5)],
        "PHONE": [pa.Pattern("PHONE", r"\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}", 0.5)],
        "EMAIL": [pa.Pattern("EMAIL", r"\S+@\S+", 0.5)],
    }

    # create a list of pattern recognizers
    @classmethod
    def _create_pattern_recognizer(cls):
        """
        For each entity, create a list of patterns to recognize it

        :param cls: PatternRecognizerFactory class

        :returns: List of pattern recognizers
        """

        # Entities to recognize and their regex patterns

        return [
            pa.PatternRecognizer(supported_entity=entity, patterns=pattern)
            for entity, pattern in cls.RECOGNIZABLE_ENTITIES.items()
        ]


class CustomSpacyRecognizer(pa.LocalRecognizer):
    """
    Custom Spacy Recognizer from Presidio Analyzer trained on Privy data.
    The privy data is generated using this https://github.com/pixie-io/pixie/tree/main/src/datagen/pii/privy
    It can be used to recognize custom entities, Since we want to use Presidio's Registries to generate AnalyzerEngine,
    it inherits from Presidio Analyzer's LocalRecognizer class.
    """

    # Entities to recognize

    RECOGNIZABLE_ENTITIES = {
        "LOCATION",
        "PERSON",
        "NRP",
        "ORGANIZATION",
        "DATE_TIME",
    }

    # Default explanation for this recognizer

    _DEFAULT_EXPLANATION = (
        "Identified as {} by Spacy's Named Entity Recognition (Privy-trained)"
    )

    # Label groups to check

    _DEFAULT_CHECK_LABEL_GROUPS = [
        ({"LOCATION"}, {"LOC", "LOCATION", "STREET_ADDRESS", "COORDINATE"}),
        ({"PERSON"}, {"PER", "PERSON"}),
        ({"NRP"}, {"NORP", "NRP"}),
        ({"ORGANIZATION"}, {"ORG"}),
        ({"DATE_TIME"}, {"DATE_TIME"}),
    ]

    # pretrained model for this recognizer

    _DEFAULT_MODEL_LANGUAGES = {
        "en": "beki/en_spacy_pii_distilbert",
    }

    _DEFAULT_PRESIDIO_EQUIVALENCES = {
        "PER": "PERSON",
        "LOC": "LOCATION",
        "ORG": "ORGANIZATION",
        "NROP": "NRP",
        "DATE_TIME": "DATE_TIME",
    }

    def __init__(
        self,
        supported_language: str = "en",
        supported_entities: List[str] = None,
        check_label_groups: Tuple[Set, Set] = None,
        context: List[str] = None,
        ner_strength: float = 1,
    ):
        """
        Initialize Spacy Recognizer.

        :param supported_language: Language to use, default is English
        :param supported_entities: Entities to use for recognition
        :param check_label_groups: Label groups to check for the entities
        :param context:            Context to use if any
        :param ner_strength:       Default confidence for NER prediction

        :returns: SpacyRecognizer object
        """

        # Default confidence for NER prediction
        self.ner_strength = ner_strength

        self.check_label_groups = check_label_groups or self._DEFAULT_CHECK_LABEL_GROUPS
        supported_entities = supported_entities or self.RECOGNIZABLE_ENTITIES
        super().__init__(
            supported_entities=supported_entities,
            supported_language=supported_language,
        )

    # get the presidio explanation for the result

    def _build_spacy_explanation(
        self, original_score: float, explanation: str
    ) -> pa.AnalysisExplanation:
        """
        Create explanation for why this result was detected.

        :param original_score: Score given by this recognizer
        :param explanation:    Explanation string

        :returns: Presidio AnalysisExplanation object
        """
        explanation = pa.AnalysisExplanation(
            recognizer=self.__class__.__name__,
            original_score=original_score,
            textual_explanation=explanation,
        )
        return explanation

    # main method for the recognizer
    def analyze(self, text: str, entities: List[str], nlp_artifacts=None):  # noqa D102
        """
        Analyze text using Spacy.

        :param text:          Text to analyze
        :param entities:      Entities to analyze
        :param nlp_artifacts: NLP artifacts to use

        :returns: List of Presidio RecognizerResult objects
        """
        results = []
        if not nlp_artifacts:
            logger.warning("Skipping SpaCy, nlp artifacts not provided...")
            return results

        ner_entities = nlp_artifacts.entities

        # recognize the supported entities
        for entity in entities:
            if entity not in self.supported_entities:
                continue
            for ent in ner_entities:
                if not self.__check_label(entity, ent.label_, self.check_label_groups):
                    continue

                # string of the explanation saying the entity is recognized by spacy
                textual_explanation = self._DEFAULT_EXPLANATION.format(ent.label_)
                explanation = self._build_spacy_explanation(
                    self.ner_strength, textual_explanation
                )

                # create the standard result with the entity, start, end, score, and explanation
                spacy_result = pa.RecognizerResult(
                    entity_type=entity,
                    start=ent.start_char,
                    end=ent.end_char,
                    score=self.ner_strength,
                    analysis_explanation=explanation,
                    recognition_metadata={
                        pa.RecognizerResult.RECOGNIZER_NAME_KEY: self.name
                    },
                )
                results.append(spacy_result)

        return results

    @staticmethod
    def __check_label(
        entity: str, label: str, check_label_groups: Tuple[Set, Set]
    ) -> bool:
        """
        Check if the label is in the label group.

        :param entity:             Entity to check
        :param label:              Label to check
        :param check_label_groups: Label groups to check

        :returns: True if the label is in the label group, False otherwise
        """
        return any(
            entity in egrp and label in lgrp for egrp, lgrp in check_label_groups
        )


# Class to use Flair with Presidio as an external recognizer.
class FlairRecognizer(pa.EntityRecognizer):
    """
    Wrapper for a flair model, if needed to be used within Presidio Analyzer.
    This is to make sure the recognizer can be registered with Presidio registry.
    """

    RECOGNIZABLE_ENTITIES = {
        "LOCATION",
        "PERSON",
        "NRP",
        "GPE",
        "ORGANIZATION",
        "MAC_ADDRESS",
        "US_BANK_NUMBER",
        "IMEI",
        "TITLE",
        "LICENSE_PLATE",
        "US_PASSPORT",
        "CURRENCY",
        "ROUTING_NUMBER",
        "US_ITIN",
        "US_BANK_NUMBER",
        "US_DRIVER_LICENSE",
        "AGE",
        "PASSWORD",
        "SWIFT_CODE",
    }

    # This is used to construct the explanation for the result

    _DEFAULT_EXPLANATION = "Identified as {} by Flair's Named Entity Recognition"

    _DEFAULT_CHECK_LABEL_GROUPS = [
        ({"LOCATION"}, {"LOC", "LOCATION", "STREET_ADDRESS", "COORDINATE"}),
        ({"PERSON"}, {"PER", "PERSON"}),
        ({"NRP"}, {"NORP", "NRP"}),
        ({"GPE"}, {"GPE"}),
        ({"ORGANIZATION"}, {"ORG"}),
        ({"MAC_ADDRESS"}, {"MAC_ADDRESS"}),
        ({"US_BANK_NUMBER"}, {"US_BANK_NUMBER"}),
        ({"IMEI"}, {"IMEI"}),
        ({"TITLE"}, {"TITLE"}),
        ({"LICENSE_PLATE"}, {"LICENSE_PLATE"}),
        ({"US_PASSPORT"}, {"US_PASSPORT"}),
        ({"CURRENCY"}, {"CURRENCY"}),
        ({"ROUTING_NUMBER"}, {"ROUTING_NUMBER"}),
        ({"AGE"}, {"AGE"}),
        ({"CURRENCY"}, {"CURRENCY"}),
        ({"SWIFT_CODE"}, {"SWIFT_CODE"}),
        ({"US_ITIN"}, {"US_ITIN"}),
        ({"US_BANK_NUMBER"}, {"US_BANK_NUMBER"}),
        ({"US_DRIVER_LICENSE"}, {"US_DRIVER_LICENSE"}),
    ]

    _DEFAULT_MODEL_LANGUAGES = {
        "en": "beki/flair-pii-distilbert",
    }

    _DEFAULT_PRESIDIO_EQUIVALENCES = {
        "PER": "PERSON",
        "LOC": "LOCATION",
        "ORG": "ORGANIZATION",
        "NROP": "NRP",
        "URL": "URL",
        "US_ITIN": "US_ITIN",
        "US_PASSPORT": "US_PASSPORT",
        "IBAN_CODE": "IBAN_CODE",
        "IP_ADDRESS": "IP_ADDRESS",
        "EMAIL_ADDRESS": "EMAIL",
        "US_DRIVER_LICENSE": "US_DRIVER_LICENSE",
        "US_BANK_NUMBER": "US_BANK_NUMBER",
    }

    def __init__(
        self,
        supported_language: str = "en",
        supported_entities: List[str] = None,
        check_label_groups: Tuple[Set, Set] = None,
    ):
        """
        Initialize the FlairRecognizer.

        :param supported_language: Language to use
        :param supported_entities: Entities to use
        :param check_label_groups: Label groups to check

        :returns: FlairRecognizer object

        """
        self.check_label_groups = check_label_groups or self._DEFAULT_CHECK_LABEL_GROUPS

        supported_entities = supported_entities or self.RECOGNIZABLE_ENTITIES
        self.model = fl.models.SequenceTagger.load(
            self._DEFAULT_MODEL_LANGUAGES.get(supported_language)
        )

        super().__init__(
            supported_entities=supported_entities,
            supported_language=supported_language,
            name="Flair Analytics",
        )

    # main method for the recognizer
    def analyze(
        self,
        text: str,
        entities: List[str],
        nlp_artifacts: pa.nlp_engine.NlpArtifacts = None,
    ) -> List[pa.RecognizerResult]:
        """
        Analyze text and return the results.

        :param text:          The text for analysis.
        :param entities:      The list of entities to recognize.
        :param nlp_artifacts: Not used by this recognizer but needed for the interface.
        :param language:      Text language. Supported languages in MODEL_LANGUAGES

        :returns: The list of Presidio RecognizerResult constructed from the recognized Flair detections.
        """

        results = []

        sentences = fl.data.Sentence(text)
        self.model.predict(sentences)

        # If there are no specific list of entities, we will look for all of it.
        if not entities:
            entities = self.supported_entities

        # Go over the entities and check if they are in the supported entities list.
        for entity in entities:
            if entity not in self.supported_entities:
                continue

            # Go over the sentences and check if the entity is in the sentence.
            for ent in sentences.get_spans("ner"):
                if not self.__check_label(
                    entity, ent.labels[0].value, self.check_label_groups
                ):
                    continue

                # If the entity is in the sentence, we will add it to the results.
                textual_explanation = self._DEFAULT_EXPLANATION.format(
                    ent.labels[0].value
                )

                # Build the explanation for the result
                explanation = self._build_flair_explanation(
                    round(ent.score, 2), textual_explanation
                )

                flair_result = self._convert_to_recognizer_result(ent, explanation)

                results.append(flair_result)

        return results

    def _convert_to_recognizer_result(
        self, entity: fl.data.Span, explanation: str
    ) -> pa.RecognizerResult:
        """
        Convert Flair result to Presidio RecognizerResult.

        :param entity:      Flair entity of Span
        :param explanation: Presidio AnalysisExplanation

        :returns: Presidio RecognizerResult
        """

        # Convert the entity type to Presidio entity type
        entity_type = self._DEFAULT_PRESIDIO_EQUIVALENCES.get(entity.tag, entity.tag)

        # Convert the score to Presidio score
        flair_score = round(entity.score, 2)

        # Create the Presidio RecognizerResult from the Flair entity
        flair_results = pa.RecognizerResult(
            entity_type=entity_type,
            start=entity.start_position,
            end=entity.end_position,
            score=flair_score,
            analysis_explanation=explanation,
        )

        return flair_results

    def _build_flair_explanation(
        self, original_score: float, explanation: str
    ) -> pa.AnalysisExplanation:
        """
        Create explanation for why this result was detected.

        :param original_score: Score given by this recognizer
        :param explanation:    Explanation string

        :returns: Presidio AnalysisExplanation
        """

        # Create the Presidio AnalysisExplanation for the result
        explanation = pa.AnalysisExplanation(
            recognizer=self.__class__.__name__,
            original_score=original_score,
            textual_explanation=explanation,
        )
        return explanation

    # sanity check of the entity and label before recognition
    @staticmethod
    def __check_label(
        entity: str, label: str, check_label_groups: Tuple[Set, Set]
    ) -> bool:
        return any(
            entity in egrp and label in lgrp for egrp, lgrp in check_label_groups
        )


# get the analyzer engine based on the model
def _get_analyzer_engine(
    model: str = None, entities: List[str] = None
) -> pa.AnalyzerEngine:
    """
    Return pa.AnalyzerEngine.

    :param model: The model to use. Can be "spacy", "flair", "pattern" or "whole".
    :param entities: The list of entities to use.

    :returns: pa.AnalyzerEngine
    """
    # recognizer registry that can store multiple recognizers
    registry = pa.RecognizerRegistry()
    if model == Models.SPACY:
        # custom spacy recognizer
        spacy_recognizer = CustomSpacyRecognizer()
        # add the custom build spacy recognizer
        registry.add_recognizer(spacy_recognizer)
    elif model == Models.FLAIR:
        # pre-trained flair recognizer
        flair_recognizer = FlairRecognizer()
        # add the custom build flair recognizer
        registry.add_recognizer(flair_recognizer)
    elif model == Models.PATTERN:
        # add the pattern recognizer
        pattern_recognizer_factory = PatternRecognizerFactory()
        for recognizer in pattern_recognizer_factory._create_pattern_recognizer():
            registry.add_recognizer(recognizer)
    elif model == Models.WHOLE:
        spacy_recognizer = CustomSpacyRecognizer()
        flair_recognizer = FlairRecognizer()
        registry.add_recognizer(spacy_recognizer)
        registry.add_recognizer(flair_recognizer)
        # add the pattern recognizer
        pattern_recognizer_factory = PatternRecognizerFactory()
        for recognizer in pattern_recognizer_factory._create_pattern_recognizer():
            registry.add_recognizer(recognizer)
    elif not model and entities:
        if set(entities) & CustomSpacyRecognizer.RECOGNIZABLE_ENTITIES:
            spacy_recognizer = CustomSpacyRecognizer()
            registry.add_recognizer(spacy_recognizer)
        if set(entities) & FlairRecognizer.RECOGNIZABLE_ENTITIES:
            flair_recognizer = FlairRecognizer()
            registry.add_recognizer(flair_recognizer)
        # add the pattern recognizer
        if set(entities) & (set(PatternRecognizerFactory.RECOGNIZABLE_ENTITIES.keys())):
            pattern_recognizer_factory = PatternRecognizerFactory()
            for recognizer in pattern_recognizer_factory._create_pattern_recognizer():
                registry.add_recognizer(recognizer)
    else:
        raise ValueError(
            f"argument of model and entities can not be None at the same time"
        )
    analyzer = pa.AnalyzerEngine(
        registry=registry,
        supported_languages=["en"],
    )

    supported_entities = analyzer.get_supported_entities()

    if entities and not all(item in supported_entities for item in entities):
        not_supported_entities = [
            item for item in entities if item not in supported_entities
        ]
        raise ValueError(
            f"The current model {model} doesn't support the following entities: {not_supported_entities}. "
            f"Supported entities are: {supported_entities}"
        )
    return analyzer


def _get_anonymizer_engine() -> pre_anoymizer.AnonymizerEngine:
    """
    Return AnonymizerEngine.

    :returns: The AnonymizerEngine.
    """
    return pre_anoymizer.AnonymizerEngine()


def _anonymize(
    text: str,
    analyze_results: List[pa.RecognizerResult],
    entity_operator_map: dict = None,
    is_full_text: bool = True,
) -> str:
    """
    Anonymize identified input using Presidio Abonymizer.

    :param text:                The text for analysis.
    :param analyze_results:     The list of Presidio RecognizerResult constructed from
    :param entity_operator_map: The entity_operator_map is a dictionary that maps entity to operator name and operator params.
    :param is_full_text:        Whether the text is full text or not.

    :returns: The anonymized text.
    """
    if not text:
        return ""

    anonymizer_engine = _get_anonymizer_engine()
    if not entity_operator_map:
        operators = None
    else:
        # Create OperatorConfig based on the entity_operator_map
        operators = {
            entity: OperatorConfig(operator_name, operator_params)
            for entity, (operator_name, operator_params) in entity_operator_map.items()
        }

    if is_full_text:
        # Anonymize the entire text
        return anonymizer_engine.anonymize(
            text=text, analyzer_results=analyze_results, operators=operators
        ).text
    # Tokenize the text to sentences
    sentences = nltk.sent_tokenize(text)
    anonymized_sentences = []
    current_idx = 0

    # Find the sentence that has pii entity
    for sentence in sentences:
        start_idx = current_idx
        end_idx = start_idx + len(sentence)

        # Get the entities that are in the sentence, update hte start_idx and end_idx
        sentence_results = [
            pa.RecognizerResult(
                result.entity_type,
                start=result.start - start_idx,
                end=result.end - start_idx,
                score=result.score,
            )
            for result in analyze_results
            if result.start >= start_idx and result.end <= end_idx
        ]

        # If PII is detected
        if sentence_results:
            anonymized_sentence = anonymizer_engine.anonymize(
                text=sentence, analyzer_results=sentence_results, operators=operators
            ).text
            anonymized_sentences.append(anonymized_sentence)

        current_idx = end_idx

    return " ".join(anonymized_sentences)


def _get_tokens(
    text: str, analyze_results: List[pa.RecognizerResult], is_full: bool = True
) -> List[str]:
    """
    Get the full tokens or only contains the entities that can form a sentence.

    :param text:            The text for analysis.
    :param analyze_results: The list of Presidio RecognizerResult constructed from
    :param is_full:         Whether return full tokens or just the tokens that only contains the entities that can form a sentence.

    :returns: The tokens.
    """

    tokens = []
    # sort by start index
    results = sorted(analyze_results, key=lambda x: x.start)
    for i, res in enumerate(results):
        if i == 0:
            tokens.append(text[: res.start])

        # append entity text and entity type
        tokens.append((text[res.start : res.end], res.entity_type))

        # if another entity coming i.e. we're not at the last results element,
        # add text up to next entity
        if i != len(results) - 1:
            tokens.append(text[res.end : results[i + 1].start])
        # if no more entities coming, add all remaining text
        else:
            tokens.append(text[res.end :])

    # get the tokens that only contains the entities that can form a sentence
    part_annontated_tokens = []
    if not is_full:
        last_end_sentence = 0
        for i, token in enumerate(tokens):
            if any(item in token for item in [".", "!", "?"]) and any(
                type(item) is tuple for item in tokens[last_end_sentence:i]
            ):
                part_annontated_tokens.append(tokens[last_end_sentence:i])
                last_end_sentence = i
        return part_annontated_tokens
    return tokens


def _annotate(
    text: str, st_analyze_results: List[pa.RecognizerResult], is_full_html: bool = True
) -> List[str]:
    """
    Annotate identified input using Presidio Anonymizer.

    :param text:               The text for analysis.
    :param st_analyze_results: The list of Presidio RecognizerResult constructed from analysis.
    :param is_full_html:       Whether generate full html or not.

    :returns: The list of tokens with the identified entities.

    """
    return _get_tokens(text, st_analyze_results, is_full_html)


def _process(
    text: str,
    model: pa.AnalyzerEngine,
    score_threshold: float,
    entities: List[str] = None,
    entities_operator_map: dict = None,
    is_full_text: bool = True,
) -> Tuple[str, str, str]:
    """
    Process the text of str using the model.

    :param txt:                   Text to process
    :param model:                 Model to use for processing
    :param entities:              Entities to recognize
    :param entities_operator_map: The entity_operator_map is a dictionary that maps entity to operator name and operator params.
    :param score_threshold:       The score threshold to use for recognition
    :param is_full_text:          Whether to return the full text or just the annotated text

    :returns: A tuple of:

              * the anonymized text
              * the list of Presidio RecognizerResult constructed from analysis
    """

    # get the analyzer engine

    analyzer = model

    # analyze the text that can be used for anonymization
    results = analyzer.analyze(
        text=text,
        language="en",
        entities=entities,
        score_threshold=score_threshold,
        return_decision_process=True,
    )

    # anonymize the text, replace the pii entities with the labels
    anonymized_text = _anonymize(text, results, entities_operator_map, is_full_text)

    return anonymized_text, results


def _get_single_html(
    text: str, results: List[pa.RecognizerResult], is_full_html: bool = True
):
    """
    Generate the html for a single txt file.

    :param text:         The text for analysis.
    :param results:      The list of Presidio RecognizerResult constructed from analysis.
    :param is_full_html: Whether generate full html or not.

    :returns: The html string for a single txt file.
    """
    # convert the results to tokens to generate the html
    tokens = _annotate(text, results, is_full_html)
    html = at_util.get_annotated_html(*tokens)

    # avoid the error during rendering of the \n in the html
    backslash_char = "\\"

    html_str = f"<p>{html.replace('{backslash_char}n', '<br>')}</p>"

    return html_str


def _get_single_json(results: List[pa.RecognizerResult], is_full_report: bool = True):
    """
    Generate the json for a single txt file.

    :param results:        The list of Presidio RecognizerResult constructed from analysis.
    :param is_full_report: Whether generate full json or not.

    :returns: The json string for a single txt file.
    """
    # generate the stats report if needed
    if not is_full_report:
        stats = []
        # add the simplify stats logic here
        for item in results:
            item.analysis_explanation = None
            stats.append(item)
    else:
        stats = results

    return stats


def _get_all_html(
    txt_content: dict,
    res_dict: dict,
    is_full_html: bool = True,
):
    """
    Generate the html for all txt files.

    :param txt_content:  The dictionary of txt file name and content.
    :param res_dict:     The dictionary of txt file name and the list of Presidio RecognizerResult constructed from analysis.
    :param is_full_html: Whether generate full html or not.

    :returns: The html string for all txt files.

    """
    # These are placeholder for the html string
    html_index = "<html><head><title>Highlighted Pii Entities</title></head><body><h1>Highlighted Pii Entities</h1><ul>"
    html_content = ""
    for txt_file, results in res_dict.items():
        txt = txt_content[txt_file]
        html_index += f"<li><a href='#{txt_file}'>{txt_file}</a></li>"
        html_content += f"<li><h2>{txt_file}</h2><p>{_get_single_html(txt, results, is_full_html)}</p></li>"
    html_index += "</ul>"
    html_res = f"{html_index}{html_content}</body></html>"

    return html_res


def _get_all_rpt(res_dict: dict, is_full_report: bool = True):
    """
    Generate the stats report for all txt files.

    :param res_dict:       The dictionary of txt file name and the list of Presidio RecognizerResult constructed from analysis.
    :param is_full_report: Whether generate full report or not.

    :returns: The stats report for all txt files.
    """
    # These are placeholder for the json report
    stats_dict = {}
    for txt_file, results in res_dict.items():
        new_stats = []
        for item in _get_single_json(results, is_full_report):
            if is_full_report:
                item.analysis_explanation = item.analysis_explanation.to_dict()
                new_stats.append(item.to_dict())
            else:
                tmp_dict = item.to_dict()
                tmp_dict.pop("analysis_explanation")
                tmp_dict.pop("recognition_metadata")
                new_stats.append(tmp_dict)
        stats_dict[txt_file] = new_stats
    return stats_dict


def recognize_pii(
    context: mlrun.MLClientCtx,
    input_path: Union[str,pathlib.Path],
    output_path: str,
    output_suffix: str,
    html_key: str,
    score_threshold: float,
    entities: List[
        str
    ] = None,  # List of entities to recognize, default is recognizing all
    entity_operator_map: dict = None,
    model: str = None,
    generate_json: bool = True,
    generate_html: bool = True,
    is_full_text: bool = True,
    is_full_html: bool = True,
    is_full_report: bool = True,
) -> Tuple[pathlib.Path, dict, dict]:
    """
    Walk through the input path, recognize PII in text and store the anonymized text in the output path. Generate the html with different colors for each entity, json report of the explaination.

    :param context:              The MLRun context. this is needed for log the artifacts.
    :param input_path:           The input path of the text files needs to be analyzied.
    :param output_path:          The output path to store the anonymized text.
    :param output_suffix:        The surfix of output key for the anonymized text. for example if the input file is pii.txt, the output key is anoymized, the output file name will be pii_anonymized.txt.
    :param html_key:             The html key for the artifact.
    :param score_threshold:      The score threshold to mark the recognition as trusted.
    :param entities:             The list of entities to recognize.
    :param entity_operator_map:  The map of entity to operator (mask, redact, replace, keep, hash, and its params)
    :param model:                The model to use. Can be "spacy", "flair", "pattern" or "whole".
    :param generate_json:        Whether to generate the json report of the explaination.
    :param generate_html:        Whether to generate the html report of the explaination.
    :param is_full_text:         Whether to return the full text or only the masked text.
    :param is_full_html:         Whether to return the full html or just the annotated text
    :param is_full_report:       Whether to return the full report or just the score and start, end index

    :returns: A tuple of:

              * Path to the output directory
              * The json report of the explaination (if generate_json is True)
              * A dictionary of errors files that were not processed

    """

    # Set output directory
    if output_path is None:
        output_path = tempfile.mkdtemp()

    # Create the output directory:
    output_directory = pathlib.Path(output_path)
    if not output_directory.exists():
        output_directory.mkdir()

    txt_files_directory = pathlib.Path(input_path)
    errors = {}

    res_dict = {}
    txt_content = {}
    # Load the model:
    try:
        analyzer = _get_analyzer_engine(model, entities)
    except Exception as e:
        errors["model"] = str(e)
        logger.error(f"Error when get the model: {e}")

    logger.info("Model loaded")
    # Go over the text files in the input path, analyze and anonymize them:
    for i, txt_file in enumerate(
        tqdm(
            list(txt_files_directory.glob("*.txt")),
            desc="Processing files",
            unit="file",
        )
    ):
        try:
            # Load the str from the text file
            text = txt_file.read_text()
            txt_content[str(txt_file)] = text
            # Process the text to recoginze the pii entities in it
            anonymized_text, results = _process(
                text=text,
                model=analyzer,
                entities=entities,
                entities_operator_map=entity_operator_map,
                score_threshold=score_threshold,
                is_full_text=is_full_text,
            )
            res_dict[str(txt_file)] = results
            # Store the anonymized text in the output path
            output_file = (
                output_directory
                / f"{str(txt_file.relative_to(txt_files_directory)).split('.')[0]}.txt"
            )
            output_file.parent.mkdir(parents=True, exist_ok=True)
            with open(output_file, "w") as f:
                f.write(anonymized_text)

        except Exception as e:
            errors[str(txt_file)] = str(e)
            logger.error(f"Error processing {txt_file}: {e}")
    if generate_html:
        # Generate the html report
        html_res = _get_all_html(txt_content, res_dict, is_full_html)
        # Store the html report in the context
        arti_html = mlrun.artifacts.Artifact(body=html_res, format="html", key=html_key)
        context.log_artifact(arti_html)
    if generate_json:
        # Generate the json report
        json_res = _get_all_rpt(res_dict, is_full_report)
        return output_path, json_res, errors
    return output_path, errors
 + functionSourceCode: # Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import pathlib
import tempfile
import warnings
from typing import List, Set, Tuple, Union

import annotated_text.util as at_util
import mlrun
import nltk
import pandas as pd
import presidio_analyzer as pa
import presidio_anonymizer as pre_anoymizer
from presidio_anonymizer.entities import OperatorConfig
from tqdm import tqdm

try:
    import flair as fl
except ModuleNotFoundError:
    print("Flair is not installed")

# There is a conflict between Rust-based tokenizers' parallel processing
# and Python's fork operations during multiprocessing. To avoid this, we need
# the following two lines

os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")

logger = logging.getLogger("pii-recognizer")


# Add the constant classes of Models and Entities to govern the whole package
class Models:
    WHOLE = "whole"
    PATTERN = "pattern"
    SPACY = "spacy"
    FLAIR = "flair"


class Entities:
    CREDIT_CARD = "CREDIT_CARD"
    SSN = "SSN"
    PHONE = "PHONE"
    EMAIL = "EMAIL"
    LOCATION = "LOCATION"
    PERSON = "PERSON"
    NRP = "NRP"
    ORGANIZATION = "ORGANIZATION"
    DATE_TIME = "DATE_TIME"
    GPE = ("GPE",)
    MAC_ADDRESS = "MAC_ADDRESS"
    US_BANK_NUMBER = "US_BANK_NUMBER"
    IMEI = "IMEI"
    TITLE = "TITLE"
    LICENSE_PLATE = "LICENSE_PLATE"
    US_PASSPORT = "US_PASSPORT"
    CURRENCY = "CURRENCY"
    ROUTING_NUMBER = "ROUTING_NUMBER"
    US_ITIN = "US_ITIN"
    US_BANK_NUMBER = "US_BANK_NUMBER"
    US_DRIVER_LICENSE = "US_DRIVER_LICENSE"
    AGE = "AGE"
    PASSWORD = "PASSWORD"
    SWIFT_CODE = "SWIFT_CODE"


class PatternRecognizerFactory:
    """
    Factory for creating pattern recognizers, it can be extended in the future to
    add more regex pattern for different entities. For the pattern recognizer to work,
    we need construct a list of regex patterns for each entity.
    """

    RECOGNIZABLE_ENTITIES = {
        "CREDIT_CARD": [pa.Pattern("CREDIT_CARD", r"\b(?:\d[ -]*?){13,16}\b", 0.5)],
        "SSN": [pa.Pattern("SSN", r"\b\d{3}-?\d{2}-?\d{4}\b", 0.5)],
        "PHONE": [pa.Pattern("PHONE", r"\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}", 0.5)],
        "EMAIL": [pa.Pattern("EMAIL", r"\S+@\S+", 0.5)],
    }

    # create a list of pattern recognizers
    @classmethod
    def _create_pattern_recognizer(cls):
        """
        For each entity, create a list of patterns to recognize it

        :param cls: PatternRecognizerFactory class

        :returns: List of pattern recognizers
        """

        # Entities to recognize and their regex patterns

        return [
            pa.PatternRecognizer(supported_entity=entity, patterns=pattern)
            for entity, pattern in cls.RECOGNIZABLE_ENTITIES.items()
        ]


class CustomSpacyRecognizer(pa.LocalRecognizer):
    """
    Custom Spacy Recognizer from Presidio Analyzer trained on Privy data.
    The privy data is generated using this https://github.com/pixie-io/pixie/tree/main/src/datagen/pii/privy
    It can be used to recognize custom entities, Since we want to use Presidio's Registries to generate AnalyzerEngine,
    it inherits from Presidio Analyzer's LocalRecognizer class.
    """

    # Entities to recognize

    RECOGNIZABLE_ENTITIES = {
        "LOCATION",
        "PERSON",
        "NRP",
        "ORGANIZATION",
        "DATE_TIME",
    }

    # Default explanation for this recognizer

    _DEFAULT_EXPLANATION = (
        "Identified as {} by Spacy's Named Entity Recognition (Privy-trained)"
    )

    # Label groups to check

    _DEFAULT_CHECK_LABEL_GROUPS = [
        ({"LOCATION"}, {"LOC", "LOCATION", "STREET_ADDRESS", "COORDINATE"}),
        ({"PERSON"}, {"PER", "PERSON"}),
        ({"NRP"}, {"NORP", "NRP"}),
        ({"ORGANIZATION"}, {"ORG"}),
        ({"DATE_TIME"}, {"DATE_TIME"}),
    ]

    # pretrained model for this recognizer

    _DEFAULT_MODEL_LANGUAGES = {
        "en": "beki/en_spacy_pii_distilbert",
    }

    _DEFAULT_PRESIDIO_EQUIVALENCES = {
        "PER": "PERSON",
        "LOC": "LOCATION",
        "ORG": "ORGANIZATION",
        "NROP": "NRP",
        "DATE_TIME": "DATE_TIME",
    }

    def __init__(
        self,
        supported_language: str = "en",
        supported_entities: List[str] = None,
        check_label_groups: Tuple[Set, Set] = None,
        context: List[str] = None,
        ner_strength: float = 1,
    ):
        """
        Initialize Spacy Recognizer.

        :param supported_language: Language to use, default is English
        :param supported_entities: Entities to use for recognition
        :param check_label_groups: Label groups to check for the entities
        :param context:            Context to use if any
        :param ner_strength:       Default confidence for NER prediction

        :returns: SpacyRecognizer object
        """

        # Default confidence for NER prediction
        self.ner_strength = ner_strength

        self.check_label_groups = check_label_groups or self._DEFAULT_CHECK_LABEL_GROUPS
        supported_entities = supported_entities or self.RECOGNIZABLE_ENTITIES
        super().__init__(
            supported_entities=supported_entities,
            supported_language=supported_language,
        )

    # get the presidio explanation for the result

    def _build_spacy_explanation(
        self, original_score: float, explanation: str
    ) -> pa.AnalysisExplanation:
        """
        Create explanation for why this result was detected.

        :param original_score: Score given by this recognizer
        :param explanation:    Explanation string

        :returns: Presidio AnalysisExplanation object
        """
        explanation = pa.AnalysisExplanation(
            recognizer=self.__class__.__name__,
            original_score=original_score,
            textual_explanation=explanation,
        )
        return explanation

    # main method for the recognizer
    def analyze(self, text: str, entities: List[str], nlp_artifacts=None):  # noqa D102
        """
        Analyze text using Spacy.

        :param text:          Text to analyze
        :param entities:      Entities to analyze
        :param nlp_artifacts: NLP artifacts to use

        :returns: List of Presidio RecognizerResult objects
        """
        results = []
        if not nlp_artifacts:
            logger.warning("Skipping SpaCy, nlp artifacts not provided...")
            return results

        ner_entities = nlp_artifacts.entities

        # recognize the supported entities
        for entity in entities:
            if entity not in self.supported_entities:
                continue
            for ent in ner_entities:
                if not self.__check_label(entity, ent.label_, self.check_label_groups):
                    continue

                # string of the explanation saying the entity is recognized by spacy
                textual_explanation = self._DEFAULT_EXPLANATION.format(ent.label_)
                explanation = self._build_spacy_explanation(
                    self.ner_strength, textual_explanation
                )

                # create the standard result with the entity, start, end, score, and explanation
                spacy_result = pa.RecognizerResult(
                    entity_type=entity,
                    start=ent.start_char,
                    end=ent.end_char,
                    score=self.ner_strength,
                    analysis_explanation=explanation,
                    recognition_metadata={
                        pa.RecognizerResult.RECOGNIZER_NAME_KEY: self.name
                    },
                )
                results.append(spacy_result)

        return results

    @staticmethod
    def __check_label(
        entity: str, label: str, check_label_groups: Tuple[Set, Set]
    ) -> bool:
        """
        Check if the label is in the label group.

        :param entity:             Entity to check
        :param label:              Label to check
        :param check_label_groups: Label groups to check

        :returns: True if the label is in the label group, False otherwise
        """
        return any(
            entity in egrp and label in lgrp for egrp, lgrp in check_label_groups
        )


# Class to use Flair with Presidio as an external recognizer.
class FlairRecognizer(pa.EntityRecognizer):
    """
    Wrapper for a flair model, if needed to be used within Presidio Analyzer.
    This is to make sure the recognizer can be registered with Presidio registry.
    """

    RECOGNIZABLE_ENTITIES = {
        "LOCATION",
        "PERSON",
        "NRP",
        "GPE",
        "ORGANIZATION",
        "MAC_ADDRESS",
        "US_BANK_NUMBER",
        "IMEI",
        "TITLE",
        "LICENSE_PLATE",
        "US_PASSPORT",
        "CURRENCY",
        "ROUTING_NUMBER",
        "US_ITIN",
        "US_BANK_NUMBER",
        "US_DRIVER_LICENSE",
        "AGE",
        "PASSWORD",
        "SWIFT_CODE",
    }

    # This is used to construct the explanation for the result

    _DEFAULT_EXPLANATION = "Identified as {} by Flair's Named Entity Recognition"

    _DEFAULT_CHECK_LABEL_GROUPS = [
        ({"LOCATION"}, {"LOC", "LOCATION", "STREET_ADDRESS", "COORDINATE"}),
        ({"PERSON"}, {"PER", "PERSON"}),
        ({"NRP"}, {"NORP", "NRP"}),
        ({"GPE"}, {"GPE"}),
        ({"ORGANIZATION"}, {"ORG"}),
        ({"MAC_ADDRESS"}, {"MAC_ADDRESS"}),
        ({"US_BANK_NUMBER"}, {"US_BANK_NUMBER"}),
        ({"IMEI"}, {"IMEI"}),
        ({"TITLE"}, {"TITLE"}),
        ({"LICENSE_PLATE"}, {"LICENSE_PLATE"}),
        ({"US_PASSPORT"}, {"US_PASSPORT"}),
        ({"CURRENCY"}, {"CURRENCY"}),
        ({"ROUTING_NUMBER"}, {"ROUTING_NUMBER"}),
        ({"AGE"}, {"AGE"}),
        ({"CURRENCY"}, {"CURRENCY"}),
        ({"SWIFT_CODE"}, {"SWIFT_CODE"}),
        ({"US_ITIN"}, {"US_ITIN"}),
        ({"US_BANK_NUMBER"}, {"US_BANK_NUMBER"}),
        ({"US_DRIVER_LICENSE"}, {"US_DRIVER_LICENSE"}),
    ]

    _DEFAULT_MODEL_LANGUAGES = {
        "en": "beki/flair-pii-distilbert",
    }

    _DEFAULT_PRESIDIO_EQUIVALENCES = {
        "PER": "PERSON",
        "LOC": "LOCATION",
        "ORG": "ORGANIZATION",
        "NROP": "NRP",
        "URL": "URL",
        "US_ITIN": "US_ITIN",
        "US_PASSPORT": "US_PASSPORT",
        "IBAN_CODE": "IBAN_CODE",
        "IP_ADDRESS": "IP_ADDRESS",
        "EMAIL_ADDRESS": "EMAIL",
        "US_DRIVER_LICENSE": "US_DRIVER_LICENSE",
        "US_BANK_NUMBER": "US_BANK_NUMBER",
    }

    def __init__(
        self,
        supported_language: str = "en",
        supported_entities: List[str] = None,
        check_label_groups: Tuple[Set, Set] = None,
    ):
        """
        Initialize the FlairRecognizer.

        :param supported_language: Language to use
        :param supported_entities: Entities to use
        :param check_label_groups: Label groups to check

        :returns: FlairRecognizer object

        """
        self.check_label_groups = check_label_groups or self._DEFAULT_CHECK_LABEL_GROUPS

        supported_entities = supported_entities or self.RECOGNIZABLE_ENTITIES
        self.model = fl.models.SequenceTagger.load(
            self._DEFAULT_MODEL_LANGUAGES.get(supported_language)
        )

        super().__init__(
            supported_entities=supported_entities,
            supported_language=supported_language,
            name="Flair Analytics",
        )

    # main method for the recognizer
    def analyze(
        self,
        text: str,
        entities: List[str],
        nlp_artifacts: pa.nlp_engine.NlpArtifacts = None,
    ) -> List[pa.RecognizerResult]:
        """
        Analyze text and return the results.

        :param text:          The text for analysis.
        :param entities:      The list of entities to recognize.
        :param nlp_artifacts: Not used by this recognizer but needed for the interface.

        :returns: The list of Presidio RecognizerResult constructed from the recognized Flair detections.
        """

        results = []

        sentences = fl.data.Sentence(text)
        self.model.predict(sentences)

        # If there are no specific list of entities, we will look for all of it.
        if not entities:
            entities = self.supported_entities

        # Go over the entities and check if they are in the supported entities list.
        for entity in entities:
            if entity not in self.supported_entities:
                continue

            # Go over the sentences and check if the entity is in the sentence.
            for ent in sentences.get_spans("ner"):
                if not self.__check_label(
                    entity, ent.labels[0].value, self.check_label_groups
                ):
                    continue

                # If the entity is in the sentence, we will add it to the results.
                textual_explanation = self._DEFAULT_EXPLANATION.format(
                    ent.labels[0].value
                )

                # Build the explanation for the result
                explanation = self._build_flair_explanation(
                    round(ent.score, 2), textual_explanation
                )

                flair_result = self._convert_to_recognizer_result(ent, explanation)

                results.append(flair_result)

        return results

    def _convert_to_recognizer_result(
        self, entity: fl.data.Span, explanation: str
    ) -> pa.RecognizerResult:
        """
        Convert Flair result to Presidio RecognizerResult.

        :param entity:      Flair entity of Span
        :param explanation: Presidio AnalysisExplanation

        :returns: Presidio RecognizerResult
        """

        # Convert the entity type to Presidio entity type
        entity_type = self._DEFAULT_PRESIDIO_EQUIVALENCES.get(entity.tag, entity.tag)

        # Convert the score to Presidio score
        flair_score = round(entity.score, 2)

        # Create the Presidio RecognizerResult from the Flair entity
        flair_results = pa.RecognizerResult(
            entity_type=entity_type,
            start=entity.start_position,
            end=entity.end_position,
            score=flair_score,
            analysis_explanation=explanation,
        )

        return flair_results

    def _build_flair_explanation(
        self, original_score: float, explanation: str
    ) -> pa.AnalysisExplanation:
        """
        Create explanation for why this result was detected.

        :param original_score: Score given by this recognizer
        :param explanation:    Explanation string

        :returns: Presidio AnalysisExplanation
        """

        # Create the Presidio AnalysisExplanation for the result
        explanation = pa.AnalysisExplanation(
            recognizer=self.__class__.__name__,
            original_score=original_score,
            textual_explanation=explanation,
        )
        return explanation

    # sanity check of the entity and label before recognition
    @staticmethod
    def __check_label(
        entity: str, label: str, check_label_groups: Tuple[Set, Set]
    ) -> bool:
        return any(
            entity in egrp and label in lgrp for egrp, lgrp in check_label_groups
        )


# get the analyzer engine based on the model
def _get_analyzer_engine(
    model: str = None, entities: List[str] = None
) -> pa.AnalyzerEngine:
    """
    Return pa.AnalyzerEngine.

    :param model: The model to use. Can be "spacy", "flair", "pattern" or "whole".
    :param entities: The list of entities to use.

    :returns: pa.AnalyzerEngine
    """
    # recognizer registry that can store multiple recognizers
    registry = pa.RecognizerRegistry()
    if model == Models.SPACY:
        # custom spacy recognizer
        spacy_recognizer = CustomSpacyRecognizer()
        # add the custom build spacy recognizer
        registry.add_recognizer(spacy_recognizer)
    elif model == Models.FLAIR:
        # pre-trained flair recognizer
        flair_recognizer = FlairRecognizer()
        # add the custom build flair recognizer
        registry.add_recognizer(flair_recognizer)
    elif model == Models.PATTERN:
        # add the pattern recognizer
        pattern_recognizer_factory = PatternRecognizerFactory()
        for recognizer in pattern_recognizer_factory._create_pattern_recognizer():
            registry.add_recognizer(recognizer)
    elif model == Models.WHOLE:
        spacy_recognizer = CustomSpacyRecognizer()
        flair_recognizer = FlairRecognizer()
        registry.add_recognizer(spacy_recognizer)
        registry.add_recognizer(flair_recognizer)
        # add the pattern recognizer
        pattern_recognizer_factory = PatternRecognizerFactory()
        for recognizer in pattern_recognizer_factory._create_pattern_recognizer():
            registry.add_recognizer(recognizer)
    elif not model and entities:
        if set(entities) & CustomSpacyRecognizer.RECOGNIZABLE_ENTITIES:
            spacy_recognizer = CustomSpacyRecognizer()
            registry.add_recognizer(spacy_recognizer)
        if set(entities) & FlairRecognizer.RECOGNIZABLE_ENTITIES:
            flair_recognizer = FlairRecognizer()
            registry.add_recognizer(flair_recognizer)
        # add the pattern recognizer
        if set(entities) & (set(PatternRecognizerFactory.RECOGNIZABLE_ENTITIES.keys())):
            pattern_recognizer_factory = PatternRecognizerFactory()
            for recognizer in pattern_recognizer_factory._create_pattern_recognizer():
                registry.add_recognizer(recognizer)
    else:
        raise ValueError(
            f"argument of model and entities can not be None at the same time"
        )
    analyzer = pa.AnalyzerEngine(
        registry=registry,
        supported_languages=["en"],
    )

    supported_entities = analyzer.get_supported_entities()

    if entities and not all(item in supported_entities for item in entities):
        not_supported_entities = [
            item for item in entities if item not in supported_entities
        ]
        raise ValueError(
            f"The current model {model} doesn't support the following entities: {not_supported_entities}. "
            f"Supported entities are: {supported_entities}"
        )
    return analyzer


def _get_anonymizer_engine() -> pre_anoymizer.AnonymizerEngine:
    """
    Return AnonymizerEngine.

    :returns: The AnonymizerEngine.
    """
    return pre_anoymizer.AnonymizerEngine()


def _anonymize(
    text: str,
    analyze_results: List[pa.RecognizerResult],
    entity_operator_map: dict = None,
    is_full_text: bool = True,
) -> str:
    """
    Anonymize identified input using Presidio Abonymizer.

    :param text:                The text for analysis.
    :param analyze_results:     The list of Presidio RecognizerResult constructed from
    :param entity_operator_map: The entity_operator_map is a dictionary that maps entity to operator name and operator params.
    :param is_full_text:        Whether the text is full text or not.

    :returns: The anonymized text.
    """
    if not text:
        return ""

    anonymizer_engine = _get_anonymizer_engine()
    if not entity_operator_map:
        operators = None
    else:
        # Create OperatorConfig based on the entity_operator_map
        operators = {
            entity: OperatorConfig(operator_name, operator_params)
            for entity, (operator_name, operator_params) in entity_operator_map.items()
        }

    if is_full_text:
        # Anonymize the entire text
        return anonymizer_engine.anonymize(
            text=text, analyzer_results=analyze_results, operators=operators
        ).text
    # Tokenize the text to sentences
    sentences = nltk.sent_tokenize(text)
    anonymized_sentences = []
    current_idx = 0

    # Find the sentence that has pii entity
    for sentence in sentences:
        start_idx = current_idx
        end_idx = start_idx + len(sentence)

        # Get the entities that are in the sentence, update hte start_idx and end_idx
        sentence_results = [
            pa.RecognizerResult(
                result.entity_type,
                start=result.start - start_idx,
                end=result.end - start_idx,
                score=result.score,
            )
            for result in analyze_results
            if result.start >= start_idx and result.end <= end_idx
        ]

        # If PII is detected
        if sentence_results:
            anonymized_sentence = anonymizer_engine.anonymize(
                text=sentence, analyzer_results=sentence_results, operators=operators
            ).text
            anonymized_sentences.append(anonymized_sentence)

        current_idx = end_idx

    return " ".join(anonymized_sentences)


def _get_tokens(
    text: str, analyze_results: List[pa.RecognizerResult], is_full: bool = True
) -> List[str]:
    """
    Get the full tokens or only contains the entities that can form a sentence.

    :param text:            The text for analysis.
    :param analyze_results: The list of Presidio RecognizerResult constructed from
    :param is_full:         Whether return full tokens or just the tokens that only contains the entities that can form a sentence.

    :returns: The tokens.
    """

    tokens = []
    # sort by start index
    results = sorted(analyze_results, key=lambda x: x.start)
    for i, res in enumerate(results):
        if i == 0:
            tokens.append(text[: res.start])

        # append entity text and entity type
        tokens.append((text[res.start : res.end], res.entity_type))

        # if another entity coming i.e. we're not at the last results element,
        # add text up to next entity
        if i != len(results) - 1:
            tokens.append(text[res.end : results[i + 1].start])
        # if no more entities coming, add all remaining text
        else:
            tokens.append(text[res.end :])

    # get the tokens that only contains the entities that can form a sentence
    part_annontated_tokens = []
    if not is_full:
        last_end_sentence = 0
        for i, token in enumerate(tokens):
            if any(item in token for item in [".", "!", "?"]) and any(
                type(item) is tuple for item in tokens[last_end_sentence:i]
            ):
                part_annontated_tokens.append(tokens[last_end_sentence:i])
                last_end_sentence = i
        return part_annontated_tokens
    return tokens


def _annotate(
    text: str, st_analyze_results: List[pa.RecognizerResult], is_full_html: bool = True
) -> List[str]:
    """
    Annotate identified input using Presidio Anonymizer.

    :param text:               The text for analysis.
    :param st_analyze_results: The list of Presidio RecognizerResult constructed from analysis.
    :param is_full_html:       Whether generate full html or not.

    :returns: The list of tokens with the identified entities.

    """
    return _get_tokens(text, st_analyze_results, is_full_html)


def _process(
    text: str,
    model: pa.AnalyzerEngine,
    score_threshold: float,
    entities: List[str] = None,
    entities_operator_map: dict = None,
    is_full_text: bool = True,
) -> Tuple[str, list]:
    """
    Process the text of str using the model.

    :param text:                  Text to process
    :param model:                 Model to use for processing
    :param entities:              Entities to recognize
    :param entities_operator_map: The entity_operator_map is a dictionary that maps entity to operator name and operator params.
    :param score_threshold:       The score threshold to use for recognition
    :param is_full_text:          Whether to return the full text or just the annotated text

    :returns: A tuple of:

              * the anonymized text
              * the list of Presidio RecognizerResult constructed from analysis
    """

    # get the analyzer engine
    analyzer = model

    # analyze the text that can be used for anonymization
    results = analyzer.analyze(
        text=text,
        language="en",
        entities=entities,
        score_threshold=score_threshold,
        return_decision_process=True,
    )

    # anonymize the text, replace the pii entities with the labels
    anonymized_text = _anonymize(text, results, entities_operator_map, is_full_text)

    return anonymized_text, results


def _get_single_html(
    text: str, results: List[pa.RecognizerResult], is_full_html: bool = True
):
    """
    Generate the html for a single txt file.

    :param text:         The text for analysis.
    :param results:      The list of Presidio RecognizerResult constructed from analysis.
    :param is_full_html: Whether generate full html or not.

    :returns: The html string for a single txt file.
    """
    # convert the results to tokens to generate the html
    tokens = _annotate(text, results, is_full_html)
    html = at_util.get_annotated_html(*tokens)

    # avoid the error during rendering of the \n in the html
    backslash_char = "\\"

    html_str = f"<p>{html.replace('{backslash_char}n', '<br>')}</p>"

    return html_str


def _get_single_json(results: List[pa.RecognizerResult], is_full_report: bool = True):
    """
    Generate the json for a single txt file.

    :param results:        The list of Presidio RecognizerResult constructed from analysis.
    :param is_full_report: Whether generate full json or not.

    :returns: The json string for a single txt file.
    """
    # generate the stats report if needed
    if not is_full_report:
        stats = []
        # add the simplify stats logic here
        for item in results:
            item.analysis_explanation = None
            stats.append(item)
    else:
        stats = results

    return stats


def _get_all_html(
    txt_content: dict,
    res_dict: dict,
    is_full_html: bool = True,
):
    """
    Generate the html for all txt files.

    :param txt_content:  The dictionary of txt file name and content.
    :param res_dict:     The dictionary of txt file name and the list of Presidio RecognizerResult constructed from analysis.
    :param is_full_html: Whether generate full html or not.

    :returns: The html string for all txt files.

    """
    # These are placeholder for the html string
    html_index = "<html><head><title>Highlighted Pii Entities</title></head><body><h1>Highlighted Pii Entities</h1><ul>"
    html_content = ""
    for txt_file, results in res_dict.items():
        txt = txt_content[txt_file]
        html_index += f"<li><a href='#{txt_file}'>{txt_file}</a></li>"
        html_content += f"<li><h2>{txt_file}</h2><p>{_get_single_html(txt, results, is_full_html)}</p></li>"
    html_index += "</ul>"
    html_res = f"{html_index}{html_content}</body></html>"

    return html_res


def _get_all_rpt(res_dict: dict, is_full_report: bool = True):
    """
    Generate the stats report for all txt files.

    :param res_dict:       The dictionary of txt file name and the list of Presidio RecognizerResult constructed from analysis.
    :param is_full_report: Whether generate full report or not.

    :returns: The stats report for all txt files.
    """
    # These are placeholder for the json report
    stats_dict = {}
    for txt_file, results in res_dict.items():
        new_stats = []
        for item in _get_single_json(results, is_full_report):
            if is_full_report:
                item.analysis_explanation = item.analysis_explanation.to_dict()
                new_stats.append(item.to_dict())
            else:
                tmp_dict = item.to_dict()
                tmp_dict.pop("analysis_explanation")
                tmp_dict.pop("recognition_metadata")
                new_stats.append(tmp_dict)
        stats_dict[txt_file] = new_stats
    return stats_dict


def recognize_pii(
    context: mlrun.MLClientCtx,
    input_path: Union[str, pathlib.Path],
    html_key: str,
    score_threshold: float,
    output_directory: str = None,
    entities: List[
        str
    ] = None,  # List of entities to recognize, default is recognizing all
    entity_operator_map: dict = None,
    model: str = None,
    generate_json: bool = True,
    generate_html: bool = True,
    is_full_text: bool = True,
    is_full_html: bool = True,
    is_full_report: bool = True,
) -> Union[Tuple[str, pd.DataFrame, dict, dict], Tuple[str, pd.DataFrame, dict]]:
    """
    Walk through the input path, recognize PII in text and store the anonymized text in the output path.
    Generate the html with different colors for each entity, json report of the explanation.

    :param context:              The MLRun context. this is needed for log the artifacts.
    :param input_path:           The input path of the text files needs to be analyzed.
    :param html_key:             The html key for the artifact.
    :param score_threshold:      The score threshold to mark the recognition as trusted.
    :param output_directory:     The output directory path to store the anonymized text.
    :param entities:             The list of entities to recognize.
    :param entity_operator_map:  The map of entity to operator (mask, redact, replace, keep, hash, and its params)
    :param model:                The model to use. Can be "spacy", "flair", "pattern" or "whole".
    :param generate_json:        Whether to generate the json report of the explanation.
    :param generate_html:        Whether to generate the html report of the explanation.
    :param is_full_text:         Whether to return the full text or only the masked text.
    :param is_full_html:         Whether to return the full html or just the annotated text
    :param is_full_report:       Whether to return the full report or just the score and start, end index

    :returns: A tuple of:

              * Path to the output directory
              * The json report of the explanation (if generate_json is True)
              * A dictionary of errors files that were not processed

    """

    # Set output directory
    if output_directory is None:
        output_directory = tempfile.mkdtemp()

    # Create the output directory:
    output_directory = pathlib.Path(output_directory)
    if not output_directory.exists():
        output_directory.mkdir(parents=True, exist_ok=True)

    txt_files_directory = pathlib.Path(input_path)
    successes = []
    errors = {}

    res_dict = {}
    txt_content = {}
    # Load the model:
    analyzer = _get_analyzer_engine(model, entities)
    logger.info("Model loaded")
    # Go over the text files in the input path, analyze and anonymize them:
    for txt_file in tqdm(
        list(txt_files_directory.glob("*.txt")),
        desc="Processing files",
        unit="file",
    ):
        try:
            # Load the str from the text file
            text = txt_file.read_text()
            txt_content[str(txt_file)] = text
            # Process the text to recoginze the pii entities in it
            anonymized_text, results = _process(
                text=text,
                model=analyzer,
                entities=entities,
                entities_operator_map=entity_operator_map,
                score_threshold=score_threshold,
                is_full_text=is_full_text,
            )
            res_dict[str(txt_file)] = results
            # Store the anonymized text in the output path
            output_file = output_directory / f"{txt_file.stem}.txt"
            output_file.parent.mkdir(parents=True, exist_ok=True)
            with open(output_file, "w") as f:
                f.write(anonymized_text)
            successes.append([txt_file.name, output_file.name])
        except Exception as e:
            errors[str(txt_file)] = str(e)
            logger.error(f"Error processing {txt_file}: {e}")

    successes = pd.DataFrame(
        successes,
        columns=["original_file", "anonymized_file"],
    )

    if generate_html:
        # Generate the html report
        html_res = _get_all_html(txt_content, res_dict, is_full_html)
        # Store the html report in the context
        arti_html = mlrun.artifacts.Artifact(body=html_res, format="html", key=html_key)
        context.log_artifact(arti_html)
    if generate_json:
        # Generate the json report
        json_res = _get_all_rpt(res_dict, is_full_report)
        return str(output_directory), successes, errors, json_res
    return str(output_directory), successes, errors
 base_image: mlrun/mlrun - commands: - - python -m pip install nltk pandas presidio-anonymizer presidio-analyzer torch - flair@git+https://github.com/flairNLP/flair.git@d4ed67bf663e4066517f00397412510d90043653 - st-annotated-text https://huggingface.co/beki/en_spacy_pii_distilbert/resolve/main/en_spacy_pii_distilbert-any-py3-none-any.whl - code_origin: git@github.com-personal:pengwei715/functions.git#5468a7acb9b9fde12832e27daac2624f43746ee7:/Users/Peng_Wei/work/mlrun_related/functions/pii_recognizer/pii_recognizer.py - origin_filename: /Users/Peng_Wei/work/mlrun_related/functions/pii_recognizer/pii_recognizer.py - requirements: [] + commands: [] + code_origin: '' + origin_filename: '' + requirements: + - nltk + - pandas + - presidio-anonymizer + - presidio-analyzer + - torch + - flair@git+https://github.com/flairNLP/flair.git@d4ed67bf663e4066517f00397412510d90043653 + - st-annotated-text + - https://huggingface.co/beki/en_spacy_pii_distilbert/resolve/main/en_spacy_pii_distilbert-any-py3-none-any.whl entry_points: + analyze: + name: analyze + doc: Analyze text and return the results. + parameters: + - name: self + - name: text + type: str + doc: The text for analysis. + - name: entities + type: List[str] + doc: The list of entities to recognize. + - name: nlp_artifacts + type: pa.nlp_engine.NlpArtifacts + doc: Not used by this recognizer but needed for the interface. + default: null + outputs: + - doc: The list of Presidio RecognizerResult constructed from the recognized + Flair detections. + type: List[pa.RecognizerResult] + lineno: 381 + has_varargs: false + has_kwargs: false recognize_pii: name: recognize_pii - doc: Walk through the input path, recognize PII in text and store the anonymized - text in the output path. Generate the html with different colors for each - entity, json report of the explaination. + doc: 'Walk through the input path, recognize PII in text and store the anonymized + text in the output path. + + Generate the html with different colors for each entity, json report of the + explanation.' parameters: - name: context type: MLClientCtx doc: The MLRun context. this is needed for log the artifacts. - default: '' - name: input_path - type: str - doc: The input path of the text files needs to be analyzied. - default: '' - - name: output_path - type: str - doc: The output path to store the anonymized text. - default: '' - - name: output_suffix - type: str - doc: The surfix of output key for the anonymized text. for example if the - input file is pii.txt, the output key is anoymized, the output file name - will be pii_anonymized.txt. - default: '' + type: Union[str, Path] + doc: The input path of the text files needs to be analyzed. - name: html_key type: str doc: The html key for the artifact. - default: '' - name: score_threshold type: float doc: The score threshold to mark the recognition as trusted. - default: '' + - name: output_directory + type: str + doc: The output directory path to store the anonymized text. + default: null - name: entities type: List[str] doc: The list of entities to recognize. @@ -71,11 +90,11 @@ spec: default: null - name: generate_json type: bool - doc: Whether to generate the json report of the explaination. + doc: Whether to generate the json report of the explanation. default: true - name: generate_html type: bool - doc: Whether to generate the html report of the explaination. + doc: Whether to generate the html report of the explanation. default: true - name: is_full_text type: bool @@ -90,40 +109,20 @@ spec: doc: Whether to return the full report or just the score and start, end index default: true outputs: - - default: '' - doc: 'A tuple of:' - lineno: 850 + - doc: 'A tuple of:' + type: Union[Tuple[str, pd.DataFrame, dict, dict], Tuple[str, pd.DataFrame, + dict]] + lineno: 845 + has_varargs: false + has_kwargs: false description: This function is used to recognize PII in a directory of text files default_handler: recognize_pii disable_auto_mount: false clone_target_dir: '' env: [] - resources: - requests: - memory: 1Mi - cpu: 25m - limits: - memory: 20Gi - cpu: '2' priority_class_name: '' preemption_mode: prevent - affinity: - nodeAffinity: - requiredDuringSchedulingIgnoredDuringExecution: - nodeSelectorTerms: - - matchExpressions: - - key: app.iguazio.com/lifecycle - operator: NotIn - values: - - preemptible - - key: eks.amazonaws.com/capacityType - operator: NotIn - values: - - SPOT - - key: node-lifecycle - operator: NotIn - values: - - spot + affinity: null tolerations: null security_context: {} -verbose: false \ No newline at end of file +verbose: false diff --git a/pii_recognizer/item.yaml b/pii_recognizer/item.yaml index 5fa9f0ae4..2f618febc 100644 --- a/pii_recognizer/item.yaml +++ b/pii_recognizer/item.yaml @@ -30,5 +30,5 @@ spec: - st-annotated-text - https://huggingface.co/beki/en_spacy_pii_distilbert/resolve/main/en_spacy_pii_distilbert-any-py3-none-any.whl url: '' -version: 0.1.0 +version: 0.2.0 test_valid: False diff --git a/pii_recognizer/pii_recognizer.py b/pii_recognizer/pii_recognizer.py index 38c0e0ec3..0acc55dcb 100644 --- a/pii_recognizer/pii_recognizer.py +++ b/pii_recognizer/pii_recognizer.py @@ -1,35 +1,32 @@ -# Copyright 2019 Iguazio +# Copyright 2023 Iguazio # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# import logging import os import pathlib import tempfile import warnings -import pandas as pd -from collections.abc import Iterable -from multiprocessing import Pool, cpu_count -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import List, Set, Tuple, Union import annotated_text.util as at_util import mlrun import nltk +import pandas as pd import presidio_analyzer as pa import presidio_anonymizer as pre_anoymizer from presidio_anonymizer.entities import OperatorConfig -from tqdm.auto import tqdm +from tqdm import tqdm try: import flair as fl @@ -393,7 +390,6 @@ def analyze( :param text: The text for analysis. :param entities: The list of entities to recognize. :param nlp_artifacts: Not used by this recognizer but needed for the interface. - :param language: Text language. Supported languages in MODEL_LANGUAGES :returns: The list of Presidio RecognizerResult constructed from the recognized Flair detections. """ @@ -711,11 +707,11 @@ def _process( entities: List[str] = None, entities_operator_map: dict = None, is_full_text: bool = True, -) -> Tuple[str, str, str]: +) -> Tuple[str, list]: """ Process the text of str using the model. - :param txt: Text to process + :param text: Text to process :param model: Model to use for processing :param entities: Entities to recognize :param entities_operator_map: The entity_operator_map is a dictionary that maps entity to operator name and operator params. @@ -729,7 +725,6 @@ def _process( """ # get the analyzer engine - analyzer = model # analyze the text that can be used for anonymization @@ -850,9 +845,9 @@ def _get_all_rpt(res_dict: dict, is_full_report: bool = True): def recognize_pii( context: mlrun.MLClientCtx, input_path: Union[str, pathlib.Path], - output_path: str, html_key: str, score_threshold: float, + output_directory: str = None, entities: List[ str ] = None, # List of entities to recognize, default is recognizing all @@ -863,20 +858,21 @@ def recognize_pii( is_full_text: bool = True, is_full_html: bool = True, is_full_report: bool = True, -) -> Tuple[pathlib.Path, dict, dict]: +) -> Union[Tuple[str, pd.DataFrame, dict, dict], Tuple[str, pd.DataFrame, dict]]: """ - Walk through the input path, recognize PII in text and store the anonymized text in the output path. Generate the html with different colors for each entity, json report of the explaination. + Walk through the input path, recognize PII in text and store the anonymized text in the output path. + Generate the html with different colors for each entity, json report of the explanation. :param context: The MLRun context. this is needed for log the artifacts. - :param input_path: The input path of the text files needs to be analyzied. - :param output_path: The output path to store the anonymized text. + :param input_path: The input path of the text files needs to be analyzed. :param html_key: The html key for the artifact. :param score_threshold: The score threshold to mark the recognition as trusted. + :param output_directory: The output directory path to store the anonymized text. :param entities: The list of entities to recognize. :param entity_operator_map: The map of entity to operator (mask, redact, replace, keep, hash, and its params) :param model: The model to use. Can be "spacy", "flair", "pattern" or "whole". - :param generate_json: Whether to generate the json report of the explaination. - :param generate_html: Whether to generate the html report of the explaination. + :param generate_json: Whether to generate the json report of the explanation. + :param generate_html: Whether to generate the html report of the explanation. :param is_full_text: Whether to return the full text or only the masked text. :param is_full_html: Whether to return the full html or just the annotated text :param is_full_report: Whether to return the full report or just the score and start, end index @@ -884,46 +880,38 @@ def recognize_pii( :returns: A tuple of: * Path to the output directory - * The json report of the explaination (if generate_json is True) + * The json report of the explanation (if generate_json is True) * A dictionary of errors files that were not processed """ # Set output directory - if output_path is None: - output_path = tempfile.mkdtemp() + if output_directory is None: + output_directory = tempfile.mkdtemp() # Create the output directory: - output_directory = pathlib.Path(output_path) + output_directory = pathlib.Path(output_directory) if not output_directory.exists(): - output_directory.mkdir() + output_directory.mkdir(parents=True, exist_ok=True) txt_files_directory = pathlib.Path(input_path) + successes = [] errors = {} res_dict = {} txt_content = {} # Load the model: - try: - analyzer = _get_analyzer_engine(model, entities) - except Exception as e: - errors["model"] = str(e) - logger.error(f"Error when get the model: {e}") - + analyzer = _get_analyzer_engine(model, entities) logger.info("Model loaded") # Go over the text files in the input path, analyze and anonymize them: - for i, txt_file in enumerate( - tqdm( - list(txt_files_directory.glob("*.txt")), - desc="Processing files", - unit="file", - ) + for txt_file in tqdm( + list(txt_files_directory.glob("*.txt")), + desc="Processing files", + unit="file", ): try: # Load the str from the text file text = txt_file.read_text() - # TODO maybe the encoding issue if from this function call of tqdm.read_text() - # Need to fix it later txt_content[str(txt_file)] = text # Process the text to recoginze the pii entities in it anonymized_text, results = _process( @@ -936,158 +924,19 @@ def recognize_pii( ) res_dict[str(txt_file)] = results # Store the anonymized text in the output path - output_file = ( - output_directory - / f"{str(txt_file.relative_to(txt_files_directory)).split('.')[0]}.txt" - ) + output_file = output_directory / f"{txt_file.stem}.txt" output_file.parent.mkdir(parents=True, exist_ok=True) with open(output_file, "w") as f: f.write(anonymized_text) - + successes.append([txt_file.name, output_file.name]) except Exception as e: errors[str(txt_file)] = str(e) logger.error(f"Error processing {txt_file}: {e}") - if generate_html: - # Generate the html report - html_res = _get_all_html(txt_content, res_dict, is_full_html) - # Store the html report in the context - arti_html = mlrun.artifacts.Artifact(body=html_res, format="html", key=html_key) - context.log_artifact(arti_html) - if generate_json: - # Generate the json report - json_res = _get_all_rpt(res_dict, is_full_report) - return output_path, json_res, errors - return output_path, errors - - -def _recognize_pii_one_file( - input_file: str, - output_file: str, - score_threshold: float, - entities: List[ - str - ] = None, # List of entities to recognize, default is recognizing all - entity_operator_map: dict = None, - model: str = None, - is_full_text: bool = True, -) -> Tuple[dict, dict, dict]: - """ - Recognize PII in text and store the anonymized text in the output path. Generate the html with different colors for each entity, json report of the explaination. - :param input_file: The input path of the text files needs to be analyzied. - :param output_file: The output path to store the anonymized text. - :param score_threshold: The score threshold to mark the recognition as trusted. - :param entities: The list of entities to recognize. - :param entity_operator_map: The map of entity to operator (mask, redact, replace, keep, hash, and its params) - :param model: The model to use. Can be "spacy", "flair", "pattern" or "whole". - :param is_full_text: Whether to return the full text or only the masked text. - - :returns: A tuple of: - * A dictionary of the text content of the input file - * A dictionary of the results of the explaination - * A dictionary of errors files that were not processed - """ - errors = {} - res_dict = {} - txt_content = {} - # Load the model: - try: - analyzer = _get_analyzer_engine(model, entities) - except Exception as e: - errors["model"] = str(e) - logger.error(f"Error when get the model: {e}") - - logger.info("Model loaded") - try: - # Load the str from the text file - with open(input_file, "r", encoding="utf-8") as file: - text = file.read() - txt_content[str(input_file)] = text - # Process the text to recoginze the pii entities in it - anonymized_text, results = _process( - text=text, - model=analyzer, - entities=entities, - entities_operator_map=entity_operator_map, - score_threshold=score_threshold, - is_full_text=is_full_text, - ) - res_dict[str(input_file)] = results - with open(output_file, "w", encoding="utf-8") as f: - f.write(anonymized_text) - - except Exception as e: - errors[str(txt_file)] = str(e) - logger.error(f"Error processing {txt_file}: {e}") - - return res_dict, txt_content, errors - - -def recognize_pii_parallel( - context: mlrun.MLClientCtx, - config_input_output: str, - score_threshold: float, - html_key: str, - entities: List[str] = None, - entity_operator_map: Dict = None, - model: str = None, - generate_html: bool = True, - generate_json: bool = True, - is_full_html: bool = True, - is_full_text: bool = True, - is_full_report: bool = True, - num_processes: int = None, -) -> Tuple[dict, dict]: - """Doing a fan-in and fan-out pattern using mutiple processes for cpu node, Since our model is mixed with rule_based and NLP model based. Both Spacy and Flair do not support the cuda GPU natively. For now, we can use all the cores that a CPU offers. - :param context: The MLRun context. this is needed - :param config_input_output csv file which have the input file path and output file path - :param score_threshold: The threshold of the score to recognize the entities - :param html_key: The key of the html report in the context - :entities List of entities to recognize, default is recognizing all - :entity_operator_map The map of the entities and the operator to use. For example, {"PERSON": "replace", "LOCATION": "mask"} - :param model The model to use. Can be "spacy", "flair", "pattern" or "whole". - :param generate_html: Whether to generate the html report - :param generate_json: Whether to generate the json report - :param is_full_html: Whether to generate the full html report - :param is_full_text: Whether to generate the full text in the html report - :param is_full_report: Whether to generate the full json report - :param num_process The number of process to run in parallel - - :returns: A tuple of: - * A json report of the result explaination - * A dictionary of errors files that were not processed - - """ - if num_processes is None: - num_processes = cpu_count() - - # Read the CSV into a DataFrame - config_df = pd.read_csv(config_input_output) - - # Convert DataFrame rows into a list of tuples, each tuple is arguments for `_recognize_pii_one_file` - tasks = [ - ( - row["input_file"], - row["output_file"], - score_threshold, - entities, - entity_operator_map, - model, - is_full_text, - ) - for _, row in config_df.iterrows() - ] - # Create a pool of processes and distribute the tasks - with Pool(processes=num_processes) as pool: - res = pool.starmap(_recognize_pii_one_file, tasks) - # Get the results - res_dict = {} - txt_content = {} - errors = {} - for r in res: - res_dict.update(r[0]) - txt_content.update(r[1]) - errors.update(r[2]) + successes = pd.DataFrame( + successes, + columns=["original_file", "anonymized_file"], + ) if generate_html: # Generate the html report @@ -1098,5 +947,5 @@ def recognize_pii_parallel( if generate_json: # Generate the json report json_res = _get_all_rpt(res_dict, is_full_report) - return json_res, errors - return errors + return str(output_directory), successes, errors, json_res + return str(output_directory), successes, errors diff --git a/speech_diarization/assets/test_data.wav b/pyannote_audio/assets/test_data.wav similarity index 100% rename from speech_diarization/assets/test_data.wav rename to pyannote_audio/assets/test_data.wav diff --git a/pyannote_audio/function.yaml b/pyannote_audio/function.yaml new file mode 100644 index 000000000..1229e0f32 --- /dev/null +++ b/pyannote_audio/function.yaml @@ -0,0 +1,146 @@ +kind: job +metadata: + name: pyannote-audio + tag: '' + hash: 335752327ddd14b62222bd45faa3a88704505b66 + project: '' + labels: + author: guyl + categories: + - Deep Learning + - Huggingface + - Audio +spec: + command: '' + args: [] + image: '' + build: + functionSourceCode: # Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import heapq
import logging
import operator
import os
import pathlib
from functools import reduce, wraps
from typing import Any, Dict, List, Tuple, Union

import pandas as pd
import pyannote.audio
import pyannote.core
import torch
import torchaudio
from tqdm import tqdm

# Get the global logger:
_LOGGER = logging.getLogger()


def _check_mlrun_and_open_mpi() -> Tuple["mlrun.MLClientCtx", "mpi4py.MPI.Intracomm"]:
    is_mpi = False
    try:
        import mlrun

        context = mlrun.get_or_create_ctx(name="mlrun")
        is_mpi = context.labels.get("kind", "job") == "mpijob"

        if is_mpi:
            try:
                from mpi4py import MPI

                return context, MPI.COMM_WORLD
            except ModuleNotFoundError as mpi4py_not_found:
                context.logger.error(
                    "To distribute the function using MLRun's 'mpijob' you need to have `mpi4py` package in your "
                    "interpreter. Please run `pip install mpi4py` and make sure you have open-mpi."
                )
                raise mpi4py_not_found
        else:
            return context, None
    except ModuleNotFoundError as module_not_found:
        if is_mpi:
            raise module_not_found
    return None, None


def open_mpi_handler(
    worker_inputs: List[str], root_worker_inputs: Dict[str, Any] = None
):
    global _LOGGER

    # Check for MLRun and OpenMPI availability:
    context, comm = _check_mlrun_and_open_mpi()

    # Check if MLRun is available, set the global logger to MLRun's:
    if context:
        _LOGGER = context.logger

    def decorator(handler):
        if comm is None or comm.Get_size() == 1:
            return handler

        @wraps(handler)
        def wrapper(**kwargs):
            # Get the open mpi environment properties:
            size = comm.Get_size()
            rank = comm.Get_rank()

            # Give the correct chunk of the workers inputs:
            for worker_input in worker_inputs:
                input_argument = kwargs[worker_input]
                if input_argument is None:
                    continue
                if isinstance(input_argument, str):
                    input_argument = _get_audio_files(
                        data_path=pathlib.Path(input_argument).absolute()
                    )
                if len(input_argument) < size:
                    raise ValueError(
                        f"Cannot split the input '{worker_input}' of length {len(input_argument)} to {size} workers. "
                        f"Please reduce the amount of workers for this input."
                    )
                even_chunk_size = len(input_argument) // size
                chunk_start = rank * even_chunk_size
                chunk_end = (
                    (rank + 1) * even_chunk_size
                    if rank + 1 < size
                    else len(input_argument)
                )
                context.logger.info(
                    f"Rank #{rank}: Processing input chunk of '{worker_input}' "
                    f"from index {chunk_start} to {chunk_end}."
                )
                if isinstance(input_argument, list):
                    input_argument = input_argument[chunk_start:chunk_end]
                elif isinstance(input_argument, pd.DataFrame):
                    input_argument = input_argument.iloc[chunk_start:chunk_end:, :]
                kwargs[worker_input] = input_argument

            # Set the root worker only arguments:
            if rank == 0 and root_worker_inputs:
                kwargs.update(root_worker_inputs)

            # Run the worker:
            output = handler(**kwargs)

            # Send the output to the root rank (rank #0):
            output = comm.gather(output, root=0)
            if rank == 0:
                # Join the outputs:
                context.logger.info("Collecting data from workers to root worker.")
                diarization_dictionary = reduce(
                    operator.ior, [dia for dia, _ in output], {}
                )
                errors_dictionary = reduce(operator.ior, [err for _, err in output], {})
                return diarization_dictionary, errors_dictionary
            return None

        return wrapper

    return decorator


@open_mpi_handler(worker_inputs=["data_path"], root_worker_inputs={"verbose": True})
def diarize(
    data_path: Union[str, List[str]],
    model_name: str = "pyannote/speaker-diarization-3.0",
    access_token: str = None,
    device: str = None,
    speakers_labels: List[str] = None,
    speaker_prefix: str = "speaker_",
    separate_by_channels: bool = False,
    minimum_speakers: int = None,
    maximum_speakers: int = None,
    verbose: bool = False,
) -> Tuple[Dict[str, List[Tuple[float, float, str]]], Dict[str, str]]:
    """
    Perform speech diarization on given audio files using pyannote-audio (https://github.com/pyannote/pyannote-audio).
    The end result is a dictionary with the file names as keys and their diarization as value. A diarization is a list
    of tuples: (start, end, speaker_label).

    To use the `pyannote.audio` models you must pass a Huggingface token and get access to the required models. The
    token can be passed in one of the following options:

    * Use the parameter `access_token`.
    * Set an environment variable named "HUGGING_FACE_HUB_TOKEN".
    * If using MLRun, you can pass it as a secret named "HUGGING_FACE_HUB_TOKEN".

    To get access to the models on Huggingface, visit their page. For example, to use the default diarization model set
    in this function ("pyannote/speaker-diarization-3.0"), you need access for these two models:

    * https://huggingface.co/pyannote/segmentation-3.0
    * https://huggingface.co/pyannote/speaker-diarization-3.0

    Note: To control the recognized speakers in the diarization output you can choose one of the following methods:

    * For a known speakers amount, you may set speaker labels via the `speakers_labels` parameter that will be used in
      the order of speaking in the audio (first person speaking be the first label in the list). In addition, you can do
      diarization per channel (setting the parameter `separate_by_channels` to True). Each label will be assigned to a
      specific channel by order (first label to channel 0, second label to channel 1 and so on). Notice, this will
      increase runtime.
    * For unknown speakers amount, you can set the `speaker_prefix` parameter to add a prefix for each speaker number.
      You can also help the diarization by setting the speakers range via the `speakers_amount_range` parameter.

    :param data_path:            A directory of the audio files, a single file or a list of files to transcribe.
    :param model_name:           One of the official diarization model names (referred as diarization pipelines) of
                                 `pyannote.audio` Huggingface page. Default: "pyannote/speaker-diarization-3.0".
    :param access_token:         An access token to pass for using the `pyannote.audio` models. If not provided, it
                                 will be looking for the environment variable "HUGGING_FACE_HUB_TOKEN". If MLRun is
                                 available, it will look for a secret "HUGGING_FACE_HUB_TOKEN".
    :param device:               Device to load the model. Can be one of {"cuda", "cpu"}. Default will prefer "cuda" if
                                 available.
    :param speakers_labels:      Labels to use for the recognized speakers. Default: numeric labels (0, 1, ...).
    :param separate_by_channels: If each speaker is speaking in a separate channel, you can diarize each channel and
                                 combine the result into a single diarization. Each label set in the `speakers_labels`
                                 parameter will be assigned to a specific channel by order.
    :param speaker_prefix:       A prefix to add for the speakers labels. This parameter is ignored if
                                 `speakers_labels` is not None. Default: "speaker".
    :param minimum_speakers:     Set the minimum expected amount of speakers to be in the audio files. This parameter is
                                 ignored if `speakers_labels` is not None.
    :param maximum_speakers:     Set the maximum expected amount of speakers to be in the audio files. This parameter is
                                 ignored if `speakers_labels` is not None.
    :param verbose:              Whether to present logs of a progress bar and errors. Default: True.

    :returns: A tuple of:

              * Speech diarization dictionary.
              * A dictionary of errored files that were not transcribed.
    """
    global _LOGGER

    # Get the input audio files to diarize:
    if isinstance(data_path, str):
        data_path = pathlib.Path(data_path).absolute()
        audio_files = _get_audio_files(data_path=data_path)
    else:  # Should be a list of files.
        audio_files = data_path

    # Get the Huggingface access token:
    access_token = _get_access_token(parameter=access_token)
    if access_token is None:
        raise ValueError(
            "A Huggingface access token must be provided to use `pyannote.audio` models. Access token can be passed "
            "via one of the following options:\n"
            "* Use the parameter `access_token`.\n"
            "* Set an environment variable named 'HUGGING_FACE_HUB_TOKEN'.\n"
            "* If using MLRun, you can pass it as a secret named 'HUGGING_FACE_HUB_TOKEN'."
        )

    # Load the diarization pipeline:
    pipeline = pyannote.audio.Pipeline.from_pretrained(
        checkpoint_path=model_name, use_auth_token=access_token
    )

    # Set the device:
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    if device != "cpu":
        pipeline.to(torch.device(device))

    # Prepare the successes dataframe and errors dictionary to be returned:
    diarizations = {}
    errors = {}

    # Prepare the diarization keyword arguments:
    diarize_kwargs = {}
    if speakers_labels:
        diarize_kwargs["num_speakers"] = len(speakers_labels)
    else:
        if minimum_speakers:
            diarize_kwargs["min_speakers"] = minimum_speakers
        if maximum_speakers:
            diarize_kwargs["max_speakers"] = maximum_speakers

    # Go over the audio files and diarize:
    for audio_file in tqdm(
        audio_files, desc="Diarizing", unit="file", disable=not verbose
    ):
        try:
            # Load audio file:
            audio, sample_rate = torchaudio.load(uri=audio_file, channels_first=True)
            # Get the diarization (if provided):
            diarizations[audio_file.name] = _diarize(
                audio=audio,
                sample_rate=sample_rate,
                pipeline=pipeline,
                speakers_labels=speakers_labels,
                separate_by_channels=separate_by_channels,
                speaker_prefix=speaker_prefix,
                diarize_kwargs=diarize_kwargs,
            )
        except Exception as exception:
            # Note the exception as error in the dictionary:
            if verbose:
                _LOGGER.warning(f"Error in file: '{audio_file.name}'")
            errors[str(audio_file.name)] = str(exception)
            continue

    # Print the head of the produced dataframe and return:
    if verbose:
        _LOGGER.info(f"Done ({len(diarizations)}/{len(audio_files)})\n")
    return diarizations, errors


def _get_audio_files(
    data_path: pathlib.Path,
) -> List[pathlib.Path]:
    # Check if the path is of a directory or a file:
    if data_path.is_dir():
        # Get all files inside the directory:
        audio_files = list(data_path.glob("*.*"))
    elif data_path.is_file():
        audio_files = [data_path]
    else:
        raise ValueError(
            f"Unrecognized data path. The parameter `data_path` must be either a directory path or a file path. "
            f"Given: {str(data_path)} "
        )

    return audio_files


def _get_access_token(parameter: str) -> str:
    # If given as a parameter, return it:
    if parameter:
        return parameter

    # Otherwise, look at the environment variable:
    environment_variable = os.environ.get("HUGGING_FACE_HUB_TOKEN")
    if environment_variable:
        return environment_variable

    # Lastly, try look in the set secrets in MLRun:
    secret = None
    try:
        import mlrun

        context = mlrun.get_or_create_ctx(name="mlrun")
        secret = context.get_secret(key="HUGGING_FACE_HUB_TOKEN")
    except ModuleNotFoundError:
        pass

    return secret


def _diarize(
    audio: torch.Tensor,
    sample_rate: int,
    pipeline: pyannote.audio.Pipeline,
    speakers_labels: List[str],
    separate_by_channels: bool,
    speaker_prefix: str,
    diarize_kwargs: dict,
) -> List[Tuple[float, float, str]]:
    # If there is no need for separation by channels, we diarize and return:
    if not separate_by_channels:
        # Diarize:
        diarization: pyannote.core.Annotation = pipeline(
            file={"waveform": audio, "sample_rate": sample_rate}, **diarize_kwargs
        )
        # Verify speakers labels (should not fail here as we set `num_speakers=len(speakers_labels)` when inferring
        # through the pipeline):
        if speakers_labels:
            given_speakers = len(speakers_labels)
            found_speakers = len(set(diarization.labels()))
            if given_speakers < found_speakers:
                raise ValueError(
                    f"Not enough `speakers_labels` were given. Got {given_speakers} labels but the diarization "
                    f"recognized {found_speakers} speakers."
                )
        # Return as a diarization list - a sorted list of tuples of start time, end time and a label (the default label
        # returned is "SPEAKER_i" so we take only the index out of it):
        return [
            (
                segment.start,
                segment.end,
                speakers_labels[int(label.split("_")[1])]
                if speakers_labels
                else f"{speaker_prefix}{int(label.split('_')[1])}",
            )
            for segment, track, label in diarization.itertracks(yield_label=True)
        ]

    # Separate to channels and diarize (we expect only one speaker per channel):
    channel_diarizations = [
        _diarize(
            audio=audio[channel].unsqueeze(
                0
            ),  # Take channel and add a channel dimension to it.
            sample_rate=sample_rate,
            pipeline=pipeline,
            speakers_labels=[
                speakers_labels[channel]
            ],  # Take the channel's label only.
            separate_by_channels=False,
            speaker_prefix=speaker_prefix,
            diarize_kwargs={"num_speakers": 1},  # Set to one speaker.
        )
        for channel in range(audio.shape[0])
    ]

    # Merge the channel diarizations into a single sorted list:
    return list(heapq.merge(*channel_diarizations))
 + base_image: mlrun/mlrun-gpu + commands: [] + code_origin: '' + origin_filename: '' + requirements: + - pyannote.audio + - pyannote.core + - torchaudio + - tqdm + entry_points: + open_mpi_handler: + name: open_mpi_handler + doc: '' + parameters: + - name: worker_inputs + type: List[str] + - name: root_worker_inputs + type: Dict[str, Any] + default: null + outputs: + - default: '' + lineno: 61 + decorator: + name: decorator + doc: '' + parameters: + - name: handler + outputs: + - default: '' + lineno: 73 + wrapper: + name: wrapper + doc: '' + parameters: [] + outputs: + - default: '' + lineno: 78 + diarize: + name: diarize + doc: "Perform speech diarization on given audio files using pyannote-audio (https://github.com/pyannote/pyannote-audio).\n\ + The end result is a dictionary with the file names as keys and their diarization\ + \ as value. A diarization is a list\nof tuples: (start, end, speaker_label).\n\ + \nTo use the `pyannote.audio` models you must pass a Huggingface token and\ + \ get access to the required models. The\ntoken can be passed in one of the\ + \ following options:\n\n* Use the parameter `access_token`.\n* Set an environment\ + \ variable named \"HUGGING_FACE_HUB_TOKEN\".\n* If using MLRun, you can pass\ + \ it as a secret named \"HUGGING_FACE_HUB_TOKEN\".\n\nTo get access to the\ + \ models on Huggingface, visit their page. For example, to use the default\ + \ diarization model set\nin this function (\"pyannote/speaker-diarization-3.0\"\ + ), you need access for these two models:\n\n* https://huggingface.co/pyannote/segmentation-3.0\n\ + * https://huggingface.co/pyannote/speaker-diarization-3.0\n\nNote: To control\ + \ the recognized speakers in the diarization output you can choose one of\ + \ the following methods:\n\n* For a known speakers amount, you may set speaker\ + \ labels via the `speakers_labels` parameter that will be used in\n the order\ + \ of speaking in the audio (first person speaking be the first label in the\ + \ list). In addition, you can do\n diarization per channel (setting the parameter\ + \ `separate_by_channels` to True). Each label will be assigned to a\n specific\ + \ channel by order (first label to channel 0, second label to channel 1 and\ + \ so on). Notice, this will\n increase runtime.\n* For unknown speakers amount,\ + \ you can set the `speaker_prefix` parameter to add a prefix for each speaker\ + \ number.\n You can also help the diarization by setting the speakers range\ + \ via the `speakers_amount_range` parameter." + parameters: + - name: data_path + type: Union[str, List[str]] + doc: A directory of the audio files, a single file or a list of files to transcribe. + - name: model_name + type: str + doc: 'One of the official diarization model names (referred as diarization + pipelines) of `pyannote.audio` Huggingface page. Default: "pyannote/speaker-diarization-3.0".' + default: pyannote/speaker-diarization-3.0 + - name: access_token + type: str + doc: An access token to pass for using the `pyannote.audio` models. If not + provided, it will be looking for the environment variable "HUGGING_FACE_HUB_TOKEN". + If MLRun is available, it will look for a secret "HUGGING_FACE_HUB_TOKEN". + default: null + - name: device + type: str + doc: Device to load the model. Can be one of {"cuda", "cpu"}. Default will + prefer "cuda" if available. + default: null + - name: speakers_labels + type: List[str] + doc: 'Labels to use for the recognized speakers. Default: numeric labels (0, + 1, ...).' + default: null + - name: speaker_prefix + type: str + doc: 'A prefix to add for the speakers labels. This parameter is ignored if + `speakers_labels` is not None. Default: "speaker".' + default: speaker_ + - name: separate_by_channels + type: bool + doc: If each speaker is speaking in a separate channel, you can diarize each + channel and combine the result into a single diarization. Each label set + in the `speakers_labels` parameter will be assigned to a specific channel + by order. + default: false + - name: minimum_speakers + type: int + doc: Set the minimum expected amount of speakers to be in the audio files. + This parameter is ignored if `speakers_labels` is not None. + default: null + - name: maximum_speakers + type: int + doc: Set the maximum expected amount of speakers to be in the audio files. + This parameter is ignored if `speakers_labels` is not None. + default: null + - name: verbose + type: bool + doc: 'Whether to present logs of a progress bar and errors. Default: True.' + default: false + outputs: + - doc: 'A tuple of:' + default: '' + lineno: 139 + description: pyannote's speech diarization of audio files + default_handler: diarize + disable_auto_mount: false + clone_target_dir: '' + env: [] + priority_class_name: '' + preemption_mode: prevent + affinity: null + tolerations: null + security_context: {} +verbose: false diff --git a/speech_diarization/item.yaml b/pyannote_audio/item.yaml similarity index 71% rename from speech_diarization/item.yaml rename to pyannote_audio/item.yaml index f49dbc319..603c1a361 100644 --- a/speech_diarization/item.yaml +++ b/pyannote_audio/item.yaml @@ -3,9 +3,9 @@ categories: - Deep Learning - Huggingface - Audio -description: speech diarization of audio files +description: pyannote's speech diarization of audio files doc: '' -example: speech_diarization.ipynb +example: pyannote_audio.ipynb generationDate: 2023-12-03:14-30 hidden: false icon: '' @@ -14,10 +14,10 @@ labels: maintainers: [] marketplaceType: '' mlrunVersion: 1.5.2 -name: speech_diarization +name: pyannote-audio platformVersion: 3.5.3 spec: - filename: speech_diarization.py + filename: pyannote_audio.py handler: diarize image: mlrun/mlrun-gpu kind: job @@ -27,4 +27,4 @@ spec: - torchaudio - tqdm url: '' -version: 2.0.0 +version: 1.0.0 diff --git a/speech_diarization/speech_diarization.ipynb b/pyannote_audio/pyannote_audio.ipynb similarity index 100% rename from speech_diarization/speech_diarization.ipynb rename to pyannote_audio/pyannote_audio.ipynb diff --git a/speech_diarization/speech_diarization.py b/pyannote_audio/pyannote_audio.py similarity index 100% rename from speech_diarization/speech_diarization.py rename to pyannote_audio/pyannote_audio.py diff --git a/speech_diarization/test_speech_diarization.py b/pyannote_audio/test_pyannote_audio.py similarity index 79% rename from speech_diarization/test_speech_diarization.py rename to pyannote_audio/test_pyannote_audio.py index 71a95575a..93da50834 100644 --- a/speech_diarization/test_speech_diarization.py +++ b/pyannote_audio/test_pyannote_audio.py @@ -1,4 +1,5 @@ import os + import mlrun import pytest @@ -6,8 +7,9 @@ @pytest.mark.skipif("HUGGING_FACE_HUB_TOKEN" not in os.environ, reason="no token") def test_speech_diarization(): project = mlrun.new_project("diarization-test2") - speech_diarization = project.set_function(func="speech_diarization.py", name="speech_diarization", - image="mlrun/mlrun") + speech_diarization = project.set_function( + func="./function.yaml", name="speech_diarization", image="mlrun/mlrun" + ) diarize_run = speech_diarization.run( handler="diarize", diff --git a/question_answering/function.yaml b/question_answering/function.yaml index fad891ac9..a33614153 100644 --- a/question_answering/function.yaml +++ b/question_answering/function.yaml @@ -2,7 +2,7 @@ kind: job metadata: name: question-answering tag: '' - hash: 9f9635a21ce5ea490c939297c7cb60f5b21945ab + hash: 90e67d116b256a98da7d5819724e43df01d8b4eb project: '' labels: author: yonish @@ -13,13 +13,15 @@ spec: args: [] image: '' build: - functionSourceCode: # Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import logging
import operator
import pathlib
from collections import Counter
from functools import reduce, wraps
from typing import Any, Dict, List, Tuple, Union

import pandas as pd
import transformers
from tqdm import tqdm

# Get the global logger:
_LOGGER = logging.getLogger()


def _check_mlrun_and_open_mpi() -> Tuple["mlrun.MLClientCtx", "mpi4py.MPI.Intracomm"]:
    global _LOGGER

    is_mpi = False
    try:
        import mlrun

        context = mlrun.get_or_create_ctx(name="mlrun")
        _LOGGER = context.logger
        is_mpi = context.labels.get("kind", "job") == "mpijob"

        if is_mpi:
            try:
                from mpi4py import MPI

                return context, MPI.COMM_WORLD
            except ModuleNotFoundError as mpi4py_not_found:
                context.logger.error(
                    "To distribute the function using MLRun's 'mpijob' you need to have `mpi4py` package in your "
                    "interpreter. Please run `pip install mpi4py` and make sure you have open-mpi."
                )
                raise mpi4py_not_found
    except ModuleNotFoundError as module_not_found:
        if is_mpi:
            raise module_not_found
    return None, None


def open_mpi_handler(
    worker_inputs: List[str], root_worker_inputs: Dict[str, Any] = None
):
    global _LOGGER

    # Check for MLRun and OpenMPI availability:
    context, comm = _check_mlrun_and_open_mpi()

    def decorator(handler):
        if comm is None or comm.Get_size() == 1:
            return handler

        @wraps(handler)
        def wrapper(**kwargs):
            # Get the open mpi environment properties:
            size = comm.Get_size()
            rank = comm.Get_rank()

            # Give the correct chunk of the workers inputs:
            for worker_input in worker_inputs:
                input_argument = kwargs[worker_input]
                if input_argument is None:
                    continue
                if isinstance(input_argument, str):
                    input_argument = _get_text_files(
                        data_path=pathlib.Path(input_argument).absolute()
                    )
                if len(input_argument) < size:
                    raise ValueError(
                        f"Cannot split the input '{worker_input}' of length {len(input_argument)} to {size} workers. "
                        f"Please reduce the amount of workers for this input."
                    )
                even_chunk_size = len(input_argument) // size
                chunk_start = rank * even_chunk_size
                chunk_end = (
                    (rank + 1) * even_chunk_size
                    if rank + 1 < size
                    else len(input_argument)
                )
                context.logger.info(
                    f"Rank #{rank}: Processing input chunk of '{worker_input}' "
                    f"from index {chunk_start} to {chunk_end}."
                )
                if isinstance(input_argument, list):
                    input_argument = input_argument[chunk_start:chunk_end]
                elif isinstance(input_argument, pd.DataFrame):
                    input_argument = input_argument.iloc[chunk_start:chunk_end:, :]
                kwargs[worker_input] = input_argument

            # Set the root worker only arguments:
            if rank == 0 and root_worker_inputs:
                kwargs.update(root_worker_inputs)

            # Run the worker:
            output = handler(**kwargs)

            # Send the output to the root rank (rank #0):
            output = comm.gather(output, root=0)
            if rank == 0:
                # Join the outputs:
                context.logger.info("Collecting data from workers to root worker.")
                dataframe = pd.concat(objs=[df for df, _ in output], axis=0)
                errors_dictionary = reduce(operator.ior, [err for _, err in output], {})
                return dataframe, errors_dictionary
            return None

        return wrapper

    return decorator


@open_mpi_handler(worker_inputs=["data_path"], root_worker_inputs={"verbose": True})
def answer_questions(
    data_path: Union[str, List[str]],
    model_name: str,
    questions: Union[List[str], List[List[str]]],
    device_map: Union[str, dict] = None,
    model_kwargs: dict = None,
    auto_gptq_exllama_max_input_length: int = None,
    tokenizer_name: str = None,
    tokenizer_kwargs: dict = None,
    text_wrapper: Union[str, List[str]] = "",
    questions_wrapper: Union[str, List[str]] = "",
    generation_config: Union[Dict, List[Dict]] = None,
    questions_config: Union[Dict, List[Dict]] = None,
    batch_size: int = 1,
    questions_columns: List[str] = None,
    verbose: bool = False,
) -> Tuple[pd.DataFrame, dict]:
    """
    Answer questions with a context to the given text files contents by a pretrained LLM model. Each text file will have
    the following prompt built:

    start of `text_wrapper`
    <text file content>
    end of `text_wrapper`

    start of `questions_wrapper`
    1. <questions[0]>
    2. <questions[1]>
    ...
    n. <questions[n-1]>
    end of `questions_wrapper`

    :param data_path:                          A path to a directory of text files or a path to a text file to ask
                                               questions about.
    :param model_name:                         The pre-trained model name from the huggingface hub to use for asking
                                               questions.
    :param questions:                          The questions to ask.
                                               A list of lists of questions to ask per text file, and devided
                                               by question groups, the groups can be dtermained by size (in order to
                                               avoid large inputs to the llm) or by questioning method
                                               (regular or poll like questioning).
    :param device_map:                         A map to use for loading the model on multiple devices.
    :param model_kwargs:                       Keyword arguments to pass for loading the model using HuggingFace's
                                               `transformers.AutoModelForCausalLM.from_pretrained` function.
    :param auto_gptq_exllama_max_input_length: For AutoGPTQ models to set and extend the model's input buffer size.
    :param tokenizer_name:                     The tokenizer name from the huggingface hub to use. If not given, the
                                               model name will be used.
    :param tokenizer_kwargs:                   Keyword arguments to pass for loading the tokenizer using HuggingFace's
                                               `transformers.AutoTokenizer.from_pretrained` function.
    :param text_wrapper:                       A wrapper for the file's text. Will be added at the start of the prompt.
                                               Must have a placeholder ('{}') for the text of the file.
    :param questions_wrapper:                  A wrapper for the questions received. Will be added after the text
                                               wrapper in the prompt template. Must have a placeholder ('{}') for the
                                               questions.
    :param generation_config:                  HuggingFace's `GenerationConfig` keyword arguments to pass to the
                                               `generate` method.
    :param questions_config:                   A dictionary or list of dictionaries containing specific ways to answer
                                               questions (using a poll for example), each dictionary in the list is for
                                               corresponding question group and determines the question asking method
                                               for said group.
    :param batch_size:                         Batch size for inference.
    :param questions_columns:                  Columns to use for the dataframe returned.
    :param verbose:                            Whether to present logs of a progress bar and errors. Default: True.


    :returns: A tuple of:

              * A dataframe dataset of the questions answers.
              * A dictionary of errored files that were not inferred or were not answered properly.
    """
    global _LOGGER

    # Set configs to empty dict if not given:
    if generation_config is None:
        generation_config = {}
    if questions_config is None:
        questions_config = {}

    # Get the input text files to question:
    if verbose:
        _LOGGER.info("Collecting text files.")
    if isinstance(data_path, str):
        data_path = pathlib.Path(data_path).absolute()
        text_files = _get_text_files(data_path=data_path)
    else:
        text_files = data_path
    if verbose:
        _LOGGER.info(f"Collected {len(text_files)} text files.")

    # Get the prompt template:
    if verbose:
        _LOGGER.info("Creating prompt template.")

    # Organize questions as a list of list, and count number of sub-lists for future use
    number_of_question_groups = 1 if isinstance(questions[0], str) else len(questions)
    questions = _to_group_list(
        argument_value=questions,
        argument_name="questions",
        length=number_of_question_groups,
    )

    # Organize prompt parts at proper length
    text_wrapper = _to_group_list(
        argument_value=text_wrapper,
        argument_name="text_wrapper",
        length=number_of_question_groups,
    )
    questions_wrapper = _to_group_list(
        argument_value=questions_wrapper,
        argument_name="questions_wrapper",
        length=number_of_question_groups,
    )

    # Create a list of prompt according to given parts and questions
    prompt_template = []
    questions = questions if isinstance(questions[0], list) else [questions]

    # Build all prompts
    for i in range(number_of_question_groups):
        prompt_template.append(
            _get_prompt_template(
                text_wrapper=text_wrapper[i],
                questions_wrapper=questions_wrapper[i],
                questions=questions[i],
            )
        )
    if verbose:
        _LOGGER.info(f"Prompt template created:\n\n{prompt_template}\n")

    # Get the total amount of questions:
    questions_amount = sum([len(sublist) for sublist in questions])

    # Get the questions columns:
    questions_columns = questions_columns or [
        f"q{i}" for i in range(1, questions_amount + 1)
    ]

    # Check if we have the correct amount of questions columns:
    if len(questions_columns) != questions_amount:
        raise ValueError(
            f"The provided questions columns length ({len(questions_columns)}) "
            f"does not match the questions amount ({questions_amount})"
        )

    # Load the generation config:
    if verbose:
        _LOGGER.info("Loading generation configuration.")
    generation_config = [
        transformers.GenerationConfig(**(cfg or {}))
        for cfg in _to_group_list(
            argument_value=generation_config,
            argument_name="generation_config",
            length=number_of_question_groups,
        )
    ]
    if verbose:
        _LOGGER.info(f"Generation configuration loaded: {generation_config}")

    # Load the model and tokenizer into a pipeline object:
    if verbose:
        _LOGGER.info(f"Loading model '{model_name}'.")
    generation_pipeline = _get_generation_pipeline(
        model_name=model_name,
        device_map=device_map,
        tokenizer_name=tokenizer_name or model_name,
        model_kwargs=model_kwargs or {},
        tokenizer_kwargs=tokenizer_kwargs or {},
        auto_gptq_exllama_max_input_length=auto_gptq_exllama_max_input_length,
        batch_size=batch_size,
    )
    if verbose:
        _LOGGER.info("Model loaded.")

    # Prepare the successes dataframe and errors dictionary to be returned:
    successes = []
    errors = {}

    # Split the files into batches:
    file_batches = [
        text_files[i : i + batch_size]
        if i + batch_size < len(text_files)
        else text_files[i:]
        for i in range(0, len(text_files), batch_size)
    ]
    questions_config = _to_group_list(
        argument_value=questions_config,
        argument_name="questions_config",
        length=number_of_question_groups,
    )

    # Create a list of question handlers according to given configs
    handlers = []
    for cfg in questions_config:
        question_type = cfg.pop("type", "default")
        handlers.append(QUESTION_MAPPING.get(question_type)(**cfg))

    # Go over the batches of text files and question them:
    for file_batch in tqdm(
        file_batches,
        desc="Generating answers",
        unit=f"file (batch of {batch_size})",
        disable=not verbose,
    ):
        try:
            total_answers = [[] for _ in range(batch_size)]

            # Go over all question group per batch of documents
            for question_group in range(number_of_question_groups):
                current_questions_amount = len(questions[question_group])

                # Read batch (read the text from the text files):
                batched_input = _read_file_batch(
                    file_batch=file_batch,
                    prompt_template=prompt_template[question_group],
                )

                # Answer the questions with each question handler:
                batched_answers = handlers[question_group].answer(
                    questions_amount=current_questions_amount,
                    batched_input=batched_input,
                    generation_pipeline=generation_pipeline,
                    generation_config=generation_config[question_group],
                )

                # Put the answers in the correct place in the total answers list according to the place in the batch:
                for i in range(batch_size):
                    total_answers[i].extend(batched_answers[i])

            # Collect the answers and attach the file name:
            successes.extend(
                [
                    [file.name, *answers]
                    for file, answers in zip(file_batch, total_answers)
                ]
            )
        except Exception as exception:
            # Note the exception as error in the dictionary:
            batch_file_names = ", ".join([file.name for file in file_batch])
            if verbose:
                _LOGGER.warning(
                    f"Error in batch '{batch_file_names}': {str(exception)}"
                )
            errors[batch_file_names] = str(exception)
            continue

    # Construct the answers dataframe:
    columns = [
        "text_file",
        *questions_columns,
    ]

    # Create a data frame of answers by files
    successes = pd.DataFrame(
        successes,
        columns=columns,
    )

    # Print the head of the produced dataframe and return:
    if verbose:
        _LOGGER.info(
            f"Done ({successes.shape[0]}/{len(text_files)})\n"
            f"Answers summary:\n"
            f"{successes.head()}"
        )
    return successes, errors


def _get_text_files(
    data_path: pathlib.Path,
) -> List[pathlib.Path]:

    # Check if the path is of a directory or a file:
    if data_path.is_dir():

        # Get all files inside the directory:
        text_files = list(data_path.glob("*.*"))
    elif data_path.is_file():
        text_files = [data_path]
    else:
        raise ValueError(
            f"Unrecognized data path. The parameter `data_path` must be either a directory path or a file path. "
            f"Given: {str(data_path)} "
        )

    return text_files


def _get_prompt_template(
    text_wrapper: str,
    questions_wrapper: str,
    questions: List[str],
) -> str:

    # Validate and build the text wrapper:
    text_wrapper = text_wrapper or (
        "Given the following text:\n" "-----\n" "{}\n" "-----"
    )
    if text_wrapper.count("{}") != 1:
        raise ValueError(
            "The `text_wrapper` must include one placeholder '{}' for the text of the file to be asked about."
        )

    # Validate and build the question wrapper:
    questions_wrapper = questions_wrapper or "Answer the questions:\n" "{}"
    if questions_wrapper.count("{}") != 1:
        raise ValueError(
            "The `questions_wrapper` must include one placeholder '{}' for the list of questions."
        )

    # Validate and parse the questions:
    if len(questions) == 0:
        raise ValueError("Please include at least one question.")
    questions = "\n".join(
        [f"{i}. {question}" for i, question in enumerate(questions, 1)]
    )

    # Construct the template:
    return f"{text_wrapper}\n{questions_wrapper.format(questions)}\n"


def _get_generation_pipeline(
    model_name: str,
    device_map: Union[str, dict],
    tokenizer_name: str,
    model_kwargs: dict,
    tokenizer_kwargs: dict,
    auto_gptq_exllama_max_input_length: int = None,
    batch_size: int = 1,
):
    # Load the model:
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_name, device_map=device_map, **model_kwargs
    )

    # Set exllama max input length if provided:
    # This changes the model's context size.
    if auto_gptq_exllama_max_input_length:
        from auto_gptq import exllama_set_max_input_length

        model = exllama_set_max_input_length(
            model=model, max_input_length=auto_gptq_exllama_max_input_length
        )

    # Load the tokenizer:
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        tokenizer_name, **tokenizer_kwargs
    )

    # Initialize a generation pipline and return:
    pipe = transformers.pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        batch_size=batch_size,
    )
    pipe.tokenizer.pad_token_id = model.config.eos_token_id
    return pipe


def _read_file_batch(
    file_batch: List[pathlib.Path],
    prompt_template: str,
) -> List[str]:
    batch = []

    # Go over all files and read in usable format
    for file in file_batch:
        with open(file, "r", encoding="utf-8") as fp:
            batch.append(prompt_template.format(fp.read()))
    return batch


def _to_group_list(argument_value: list, argument_name: str, length: int):

    # Check if is list, turn to list if not
    argument_value = (
        argument_value if isinstance(argument_value, list) else [argument_value]
    )
    list_len = len(argument_value)

    # If not a list, or is a list of len 1 we duplicate for correct length
    # If list in wrong length throw an error
    if list_len != length:
        if list_len == 1:
            return argument_value * length
        raise ValueError(
            f"The argument value of '{argument_name}' is not equal to the length of the given questions - {length}"
        )
    return argument_value


class QuestionHandler:
    """
    A class for handling questions answering for a given question type.
    This class is used as a base class for all question types, and for default question type (regular question
    answering without any special handling).
    """

    class ConfigKeys:
        pass

    def __init__(self, **kwargs):
        pass

    @staticmethod
    def _get_answers(generated_text: str, questions_amount: int) -> List[str]:

        # Clear answer start (part before numbers):
        # TODO find better way to verify, for list of questions this is redundant for example
        if "1." not in generated_text:
            raise ValueError(
                f"Answer 1. is missing from the generated text: '{generated_text}'"
            )
        text = generated_text.split("1.", 1)[1]

        # Start extracting the answers:
        answers = []
        for i in range(1, questions_amount + 1):
            # If it's the last answer to look for, take the rest of the text:
            if i == questions_amount:
                answer_i = text
            # Verify there is a question number in the text:
            elif f"{i + 1}." not in text:
                raise ValueError(
                    f"Answer {i + 1}. is missing from the generated text: '{generated_text}'"
                )
            # Take i's answer:
            else:
                answer_i, text = text.split(f"{i + 1}.", 1)
            # Collect the answer removing redundant spaces:
            answers.append(answer_i.strip())

        return answers

    def _infer_questions(
        self,
        questions_amount: int,
        batched_input: List[str],
        generation_pipeline: transformers.Pipeline,
        generation_config: transformers.GenerationConfig,
    ) -> List[List[str]]:

        # Infer through the llm:
        batched_output = generation_pipeline(
            batched_input,
            generation_config=generation_config,
            eos_token_id=generation_pipeline.tokenizer.eos_token_id,
            return_full_text=False,
            num_return_sequences=1,
        )

        # Process the outputs to get the answers:
        batched_answers = []
        for output in batched_output:
            # Get the generated answers:
            answers = self._get_answers(
                generated_text=output[0]["generated_text"],
                questions_amount=questions_amount,
            )
            # Collect the processed answers:
            batched_answers.append(answers)
        return batched_answers

    def answer(
        self,
        questions_amount: int,
        batched_input: List[str],
        generation_pipeline: transformers.Pipeline,
        generation_config: transformers.GenerationConfig,
    ) -> List[List[str]]:
        """
        Answer questions with a context to the given text files contents by a pretrained LLM model in given pipeline.
        """
        return self._infer_questions(
            questions_amount=questions_amount,
            batched_input=batched_input,
            generation_pipeline=generation_pipeline,
            generation_config=generation_config,
        )


class PollQuestionHandler(QuestionHandler):
    """
    Static class to hold all the possible poll question configurations options keys
    """

    class ConfigKeys:
        """
        A class for handling questions answering for poll type questions.
        These type of question are answered by asking the same question multiple times
        and choosing the most common answer or the average answer.
        """

        #: The number of times to ask the same question.
        POLL_COUNT = "poll_count"

        #: The strategy to use for choosing the answer from the poll.
        POLL_STRATEGY = "poll_strategy"

    class Strategy(enum.Enum):
        #: The most common answer strategy.
        MOST_COMMON = "most_common"

        #: The average answer strategy.
        AVERAGE = "average"

        @staticmethod
        def most_common(answers):
            """
            Calculate the most common answer for a given list of answers.
            """
            count = Counter(answers)
            most_common = count.most_common(1)
            return most_common[0][0]

        @staticmethod
        def average(answers):
            """
            Calculate the average answer for a given list of answers.
            """
            if isinstance(answers[0], str):
                raise ValueError(
                    "Cannot perform poll with average answer strategy of non numeric values,"
                    " please change the question to give numeric data, or choose 'most_common' as strategy."
                )
            else:
                numeric_values = answers
            avg = sum(numeric_values) / len(numeric_values)

            # Round to the closest integer and return corresponding value
            return round(avg)

        def do(self, answers):
            """
            Perform the strategy.
            """
            return getattr(self, self.value)(answers)

    def __init__(
        self, poll_count: int = 5, poll_strategy: str = "most_common"):
        super().__init__()
        self.poll_count = poll_count
        self.poll_strategy = self.Strategy(poll_strategy)

    def answer(
        self,
        questions_amount: int,
        batched_input: List[str],
        generation_pipeline: transformers.Pipeline,
        generation_config: transformers.GenerationConfig,
    ) -> List[List[str]]:
        """
        Answer questions with a context to the given text files contents by a pretrained LLM model in given pipeline.
        """
        return self._answer_poll_questions(
            questions_amount=questions_amount,
            batched_input=batched_input,
            generation_pipeline=generation_pipeline,
            generation_config=generation_config,
        )

    def _answer_poll_questions(
        self,
        questions_amount: int,
        batched_input: List[str],
        generation_pipeline: transformers.Pipeline,
        generation_config: transformers.GenerationConfig,
    ) -> List[List[str]]:
        votes = []

        # Run the poll for each question
        for _ in range(self.poll_count):
            batched_answers = self._infer_questions(
                questions_amount=questions_amount,
                batched_input=batched_input,
                generation_pipeline=generation_pipeline,
                generation_config=generation_config,
            )
            votes.append(batched_answers)
        answers = []

        # Collect the answers according to the poll strategy
        # Average strategy works for numeric values only
        for batch in range(len(votes[0])):
            batched_answers = []
            for question in range(questions_amount):
                # Create a list of all answers to relevant question
                answer = [
                    votes[voter][batch][question] for voter in range(self.poll_count)
                ]
                answer = self.poll_strategy.do(answer)
                batched_answers.append(answer)
            answers.append(batched_answers)
        return answers


# Holds names of QuestionHandles
class QuestionTypes:
    DEFAULT = "default"
    POLL = "poll"


# Maps question types to their handlers
QUESTION_MAPPING = {
    QuestionTypes.DEFAULT: QuestionHandler,
    QuestionTypes.POLL: PollQuestionHandler,
}
 + functionSourceCode: # Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import logging
import operator
import pathlib
from collections import Counter
from functools import reduce, wraps
from typing import Any, Dict, List, Tuple, Union

import pandas as pd
import transformers
from tqdm import tqdm

# Get the global logger:
_LOGGER = logging.getLogger()


def _check_mlrun_and_open_mpi() -> Tuple["mlrun.MLClientCtx", "mpi4py.MPI.Intracomm"]:
    global _LOGGER

    is_mpi = False
    try:
        import mlrun

        context = mlrun.get_or_create_ctx(name="mlrun")
        _LOGGER = context.logger
        is_mpi = context.labels.get("kind", "job") == "mpijob"

        if is_mpi:
            try:
                from mpi4py import MPI

                return context, MPI.COMM_WORLD
            except ModuleNotFoundError as mpi4py_not_found:
                context.logger.error(
                    "To distribute the function using MLRun's 'mpijob' you need to have `mpi4py` package in your "
                    "interpreter. Please run `pip install mpi4py` and make sure you have open-mpi."
                )
                raise mpi4py_not_found
    except ModuleNotFoundError as module_not_found:
        if is_mpi:
            raise module_not_found
    return None, None


def open_mpi_handler(
    worker_inputs: List[str], root_worker_inputs: Dict[str, Any] = None
):
    global _LOGGER

    # Check for MLRun and OpenMPI availability:
    context, comm = _check_mlrun_and_open_mpi()

    def decorator(handler):
        if comm is None or comm.Get_size() == 1:
            return handler

        @wraps(handler)
        def wrapper(**kwargs):
            # Get the open mpi environment properties:
            size = comm.Get_size()
            rank = comm.Get_rank()

            # Give the correct chunk of the workers inputs:
            for worker_input in worker_inputs:
                input_argument = kwargs[worker_input]
                if input_argument is None:
                    continue
                if isinstance(input_argument, str):
                    input_argument = _get_text_files(
                        data_path=pathlib.Path(input_argument).absolute()
                    )
                if len(input_argument) < size:
                    raise ValueError(
                        f"Cannot split the input '{worker_input}' of length {len(input_argument)} to {size} workers. "
                        f"Please reduce the amount of workers for this input."
                    )
                even_chunk_size = len(input_argument) // size
                chunk_start = rank * even_chunk_size
                chunk_end = (
                    (rank + 1) * even_chunk_size
                    if rank + 1 < size
                    else len(input_argument)
                )
                context.logger.info(
                    f"Rank #{rank}: Processing input chunk of '{worker_input}' "
                    f"from index {chunk_start} to {chunk_end}."
                )
                if isinstance(input_argument, list):
                    input_argument = input_argument[chunk_start:chunk_end]
                elif isinstance(input_argument, pd.DataFrame):
                    input_argument = input_argument.iloc[chunk_start:chunk_end:, :]
                kwargs[worker_input] = input_argument

            # Set the root worker only arguments:
            if rank == 0 and root_worker_inputs:
                kwargs.update(root_worker_inputs)

            # Run the worker:
            output = handler(**kwargs)

            # Send the output to the root rank (rank #0):
            output = comm.gather(output, root=0)
            if rank == 0:
                # Join the outputs:
                context.logger.info("Collecting data from workers to root worker.")
                dataframe = pd.concat(objs=[df for df, _ in output], axis=0)
                errors_dictionary = reduce(operator.ior, [err for _, err in output], {})
                return dataframe, errors_dictionary
            return None

        return wrapper

    return decorator


@open_mpi_handler(worker_inputs=["data_path"], root_worker_inputs={"verbose": True})
def answer_questions(
    data_path: Union[str, List[str]],
    model_name: str,
    questions: Union[List[str], List[List[str]]],
    device_map: Union[str, dict] = None,
    model_kwargs: dict = None,
    auto_gptq_exllama_max_input_length: int = None,
    tokenizer_name: str = None,
    tokenizer_kwargs: dict = None,
    text_wrapper: Union[str, List[str]] = "",
    questions_wrapper: Union[str, List[str]] = "",
    generation_config: Union[Dict, List[Dict]] = None,
    questions_config: Union[Dict, List[Dict]] = None,
    batch_size: int = 1,
    questions_columns: List[str] = None,
    verbose: bool = False,
) -> Tuple[pd.DataFrame, dict]:
    """
    Answer questions with a context to the given text files contents by a pretrained LLM model. Each text file will have
    the following prompt built:

    start of `text_wrapper`
    <text file content>
    end of `text_wrapper`

    start of `questions_wrapper`
    1. <questions[0]>
    2. <questions[1]>
    ...
    n. <questions[n-1]>
    end of `questions_wrapper`

    :param data_path:                          A path to a directory of text files or a path to a text file to ask
                                               questions about.
    :param model_name:                         The pre-trained model name from the huggingface hub to use for asking
                                               questions.
    :param questions:                          The questions to ask.
                                               A list of lists of questions to ask per text file, and devided
                                               by question groups, the groups can be dtermained by size (in order to
                                               avoid large inputs to the llm) or by questioning method
                                               (regular or poll like questioning).
    :param device_map:                         A map to use for loading the model on multiple devices.
    :param model_kwargs:                       Keyword arguments to pass for loading the model using HuggingFace's
                                               `transformers.AutoModelForCausalLM.from_pretrained` function.
    :param auto_gptq_exllama_max_input_length: For AutoGPTQ models to set and extend the model's input buffer size.
    :param tokenizer_name:                     The tokenizer name from the huggingface hub to use. If not given, the
                                               model name will be used.
    :param tokenizer_kwargs:                   Keyword arguments to pass for loading the tokenizer using HuggingFace's
                                               `transformers.AutoTokenizer.from_pretrained` function.
    :param text_wrapper:                       A wrapper for the file's text. Will be added at the start of the prompt.
                                               Must have a placeholder ('{}') for the text of the file.
    :param questions_wrapper:                  A wrapper for the questions received. Will be added after the text
                                               wrapper in the prompt template. Must have a placeholder ('{}') for the
                                               questions.
    :param generation_config:                  HuggingFace's `GenerationConfig` keyword arguments to pass to the
                                               `generate` method.
    :param questions_config:                   A dictionary or list of dictionaries containing specific ways to answer
                                               questions (using a poll for example), each dictionary in the list is for
                                               corresponding question group and determines the question asking method
                                               for said group.
    :param batch_size:                         Batch size for inference.
    :param questions_columns:                  Columns to use for the dataframe returned.
    :param verbose:                            Whether to present logs of a progress bar and errors. Default: True.


    :returns: A tuple of:

              * A dataframe dataset of the questions answers.
              * A dictionary of errored files that were not inferred or were not answered properly.
    """
    global _LOGGER

    # Set configs to empty dict if not given:
    if generation_config is None:
        generation_config = {}
    if questions_config is None:
        questions_config = {}

    # Get the input text files to question:
    if verbose:
        _LOGGER.info("Collecting text files.")
    if isinstance(data_path, str):
        data_path = pathlib.Path(data_path).absolute()
        text_files = _get_text_files(data_path=data_path)
    else:
        text_files = data_path
    if verbose:
        _LOGGER.info(f"Collected {len(text_files)} text files.")

    # Get the prompt template:
    if verbose:
        _LOGGER.info("Creating prompt template.")

    # Organize questions as a list of list, and count number of sub-lists for future use
    number_of_question_groups = 1 if isinstance(questions[0], str) else len(questions)
    questions = _to_group_list(
        argument_value=questions,
        argument_name="questions",
        length=number_of_question_groups,
    )

    # Organize prompt parts at proper length
    text_wrapper = _to_group_list(
        argument_value=text_wrapper,
        argument_name="text_wrapper",
        length=number_of_question_groups,
    )
    questions_wrapper = _to_group_list(
        argument_value=questions_wrapper,
        argument_name="questions_wrapper",
        length=number_of_question_groups,
    )

    # Create a list of prompt according to given parts and questions
    prompt_template = []
    questions = questions if isinstance(questions[0], list) else [questions]

    # Build all prompts
    for i in range(number_of_question_groups):
        prompt_template.append(
            _get_prompt_template(
                text_wrapper=text_wrapper[i],
                questions_wrapper=questions_wrapper[i],
                questions=questions[i],
            )
        )
    if verbose:
        _LOGGER.info(f"Prompt template created:\n\n{prompt_template}\n")

    # Get the total amount of questions:
    questions_amount = sum([len(sublist) for sublist in questions])

    # Get the questions columns:
    questions_columns = questions_columns or [
        f"q{i}" for i in range(1, questions_amount + 1)
    ]

    # Check if we have the correct amount of questions columns:
    if len(questions_columns) != questions_amount:
        raise ValueError(
            f"The provided questions columns length ({len(questions_columns)}) "
            f"does not match the questions amount ({questions_amount})"
        )

    # Load the generation config:
    if verbose:
        _LOGGER.info("Loading generation configuration.")
    generation_config = [
        transformers.GenerationConfig(**(cfg or {}))
        for cfg in _to_group_list(
            argument_value=generation_config,
            argument_name="generation_config",
            length=number_of_question_groups,
        )
    ]
    if verbose:
        _LOGGER.info(f"Generation configuration loaded: {generation_config}")

    # Load the model and tokenizer into a pipeline object:
    if verbose:
        _LOGGER.info(f"Loading model '{model_name}'.")
    generation_pipeline = _get_generation_pipeline(
        model_name=model_name,
        device_map=device_map,
        tokenizer_name=tokenizer_name or model_name,
        model_kwargs=model_kwargs or {},
        tokenizer_kwargs=tokenizer_kwargs or {},
        auto_gptq_exllama_max_input_length=auto_gptq_exllama_max_input_length,
        batch_size=batch_size,
    )
    if verbose:
        _LOGGER.info("Model loaded.")

    # Prepare the successes dataframe and errors dictionary to be returned:
    successes = []
    errors = {}

    # Split the files into batches:
    file_batches = [
        text_files[i : i + batch_size]
        if i + batch_size < len(text_files)
        else text_files[i:]
        for i in range(0, len(text_files), batch_size)
    ]
    questions_config = _to_group_list(
        argument_value=questions_config,
        argument_name="questions_config",
        length=number_of_question_groups,
    )

    # Create a list of question handlers according to given configs
    handlers = []
    for cfg in questions_config:
        question_type = cfg.pop("type", "default")
        handlers.append(QUESTION_MAPPING.get(question_type)(**cfg))

    # Go over the batches of text files and question them:
    for file_batch in tqdm(
        file_batches,
        desc="Generating answers",
        unit=f"file (batch of {batch_size})",
        disable=not verbose,
    ):
        try:
            total_answers = [[] for _ in range(batch_size)]

            # Go over all question group per batch of documents
            for question_group in range(number_of_question_groups):
                current_questions_amount = len(questions[question_group])

                # Read batch (read the text from the text files):
                batched_input = _read_file_batch(
                    file_batch=file_batch,
                    prompt_template=prompt_template[question_group],
                )

                # Answer the questions with each question handler:
                batched_answers = handlers[question_group].answer(
                    questions_amount=current_questions_amount,
                    batched_input=batched_input,
                    generation_pipeline=generation_pipeline,
                    generation_config=generation_config[question_group],
                )

                # Put the answers in the correct place in the total answers list according to the place in the batch:
                for i in range(batch_size):
                    total_answers[i].extend(batched_answers[i])

            # Collect the answers and attach the file name:
            successes.extend(
                [
                    [file.name, *answers]
                    for file, answers in zip(file_batch, total_answers)
                ]
            )
        except Exception as exception:
            # Note the exception as error in the dictionary:
            batch_file_names = ", ".join([file.name for file in file_batch])
            if verbose:
                _LOGGER.warning(
                    f"Error in batch '{batch_file_names}': {str(exception)}"
                )
            errors[batch_file_names] = str(exception)
            continue

    # Construct the answers dataframe:
    columns = [
        "text_file",
        *questions_columns,
    ]

    # Create a data frame of answers by files
    successes = pd.DataFrame(
        successes,
        columns=columns,
    )

    # Print the head of the produced dataframe and return:
    if verbose:
        _LOGGER.info(
            f"Done ({successes.shape[0]}/{len(text_files)})\n"
            f"Answers summary:\n"
            f"{successes.head()}"
        )
    return successes, errors


def _get_text_files(
    data_path: pathlib.Path,
) -> List[pathlib.Path]:

    # Check if the path is of a directory or a file:
    if data_path.is_dir():

        # Get all files inside the directory:
        text_files = list(data_path.glob("*.*"))
    elif data_path.is_file():
        text_files = [data_path]
    else:
        raise ValueError(
            f"Unrecognized data path. The parameter `data_path` must be either a directory path or a file path. "
            f"Given: {str(data_path)} "
        )

    return text_files


def _get_prompt_template(
    text_wrapper: str,
    questions_wrapper: str,
    questions: List[str],
) -> str:

    # Validate and build the text wrapper:
    text_wrapper = text_wrapper or (
        "Given the following text:\n" "-----\n" "{}\n" "-----"
    )
    if text_wrapper.count("{}") != 1:
        raise ValueError(
            "The `text_wrapper` must include one placeholder '{}' for the text of the file to be asked about."
        )

    # Validate and build the question wrapper:
    questions_wrapper = questions_wrapper or "Answer the questions:\n" "{}"
    if questions_wrapper.count("{}") != 1:
        raise ValueError(
            "The `questions_wrapper` must include one placeholder '{}' for the list of questions."
        )

    # Validate and parse the questions:
    if len(questions) == 0:
        raise ValueError("Please include at least one question.")
    questions = "\n".join(
        [f"{i}. {question}" for i, question in enumerate(questions, 1)]
    )

    # Construct the template:
    return f"{text_wrapper}\n{questions_wrapper.format(questions)}\n"


def _get_generation_pipeline(
    model_name: str,
    device_map: Union[str, dict],
    tokenizer_name: str,
    model_kwargs: dict,
    tokenizer_kwargs: dict,
    auto_gptq_exllama_max_input_length: int = None,
    batch_size: int = 1,
):
    # Load the model:
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_name, device_map=device_map, **model_kwargs
    )

    # Set exllama max input length if provided:
    # This changes the model's context size.
    if auto_gptq_exllama_max_input_length:
        from auto_gptq import exllama_set_max_input_length

        model = exllama_set_max_input_length(
            model=model, max_input_length=auto_gptq_exllama_max_input_length
        )

    # Load the tokenizer:
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        tokenizer_name, **tokenizer_kwargs
    )

    # Initialize a generation pipline and return:
    pipe = transformers.pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        batch_size=batch_size,
    )
    pipe.tokenizer.pad_token_id = model.config.eos_token_id
    return pipe


def _read_file_batch(
    file_batch: List[pathlib.Path],
    prompt_template: str,
) -> List[str]:
    batch = []

    # Go over all files and read in usable format
    for file in file_batch:
        with open(file, "r", encoding="utf-8") as fp:
            batch.append(prompt_template.format(fp.read()))
    return batch


def _to_group_list(argument_value: list, argument_name: str, length: int):

    # Check if is list, turn to list if not
    argument_value = (
        argument_value if isinstance(argument_value, list) else [argument_value]
    )
    list_len = len(argument_value)

    # If not a list, or is a list of len 1 we duplicate for correct length
    # If list in wrong length throw an error
    if list_len != length:
        if list_len == 1:
            return argument_value * length
        raise ValueError(
            f"The argument value of '{argument_name}' is not equal to the length of the given questions - {length}"
        )
    return argument_value


class QuestionHandler:
    """
    A class for handling questions answering for a given question type.
    This class is used as a base class for all question types, and for default question type (regular question
    answering without any special handling).
    """

    class ConfigKeys:
        pass

    def __init__(self):
        pass

    @staticmethod
    def _get_answers(generated_text: str, questions_amount: int) -> List[str]:

        # Clear answer start (part before numbers):
        # TODO find better way to verify, for list of questions this is redundant for example
        if "1." not in generated_text:
            raise ValueError(
                f"Answer 1. is missing from the generated text: '{generated_text}'"
            )
        text = generated_text.split("1.", 1)[1]

        # Start extracting the answers:
        answers = []
        for i in range(1, questions_amount + 1):
            # If it's the last answer to look for, take the rest of the text:
            if i == questions_amount:
                answer_i = text
            # Verify there is a question number in the text:
            elif f"{i + 1}." not in text:
                raise ValueError(
                    f"Answer {i + 1}. is missing from the generated text: '{generated_text}'"
                )
            # Take i's answer:
            else:
                answer_i, text = text.split(f"{i + 1}.", 1)
            # Collect the answer removing redundant spaces:
            answers.append(answer_i.strip())

        return answers

    def _infer_questions(
        self,
        questions_amount: int,
        batched_input: List[str],
        generation_pipeline: transformers.Pipeline,
        generation_config: transformers.GenerationConfig,
    ) -> List[List[str]]:

        # Infer through the llm:
        batched_output = generation_pipeline(
            batched_input,
            generation_config=generation_config,
            eos_token_id=generation_pipeline.tokenizer.eos_token_id,
            return_full_text=False,
            num_return_sequences=1,
        )

        # Process the outputs to get the answers:
        batched_answers = []
        for output in batched_output:
            # Get the generated answers:
            answers = self._get_answers(
                generated_text=output[0]["generated_text"],
                questions_amount=questions_amount,
            )
            # Collect the processed answers:
            batched_answers.append(answers)
        return batched_answers

    def answer(
        self,
        questions_amount: int,
        batched_input: List[str],
        generation_pipeline: transformers.Pipeline,
        generation_config: transformers.GenerationConfig,
    ) -> List[List[str]]:
        """
        Answer questions with a context to the given text files contents by a pretrained LLM model in given pipeline.
        """
        return self._infer_questions(
            questions_amount=questions_amount,
            batched_input=batched_input,
            generation_pipeline=generation_pipeline,
            generation_config=generation_config,
        )


class PollQuestionHandler(QuestionHandler):
    """
    Static class to hold all the possible poll question configurations options keys
    """

    class ConfigKeys:
        """
        A class for handling questions answering for poll type questions.
        These type of question are answered by asking the same question multiple times
        and choosing the most common answer or the average answer.
        """

        #: The number of times to ask the same question.
        POLL_COUNT = "poll_count"

        #: The strategy to use for choosing the answer from the poll.
        POLL_STRATEGY = "poll_strategy"

    class Strategy(enum.Enum):
        #: The most common answer strategy.
        MOST_COMMON = "most_common"

        #: The average answer strategy.
        AVERAGE = "average"

        @staticmethod
        def most_common(answers):
            """
            Calculate the most common answer for a given list of answers.
            """
            count = Counter(answers)
            most_common = count.most_common(1)
            return most_common[0][0]

        @staticmethod
        def average(answers):
            """
            Calculate the average answer for a given list of answers.
            """
            if isinstance(answers[0], str):
                raise ValueError(
                    "Cannot perform poll with average answer strategy of non numeric values,"
                    " please change the question to give numeric data, or choose 'most_common' as strategy."
                )
            else:
                numeric_values = answers
            avg = sum(numeric_values) / len(numeric_values)

            # Round to the closest integer and return corresponding value
            return round(avg)

        def do(self, answers):
            """
            Perform the strategy.
            """
            return getattr(self, self.value)(answers)

    def __init__(
        self, poll_count: int = 5, poll_strategy: str = "most_common"):
        super().__init__()
        self.poll_count = poll_count
        self.poll_strategy = self.Strategy(poll_strategy)

    def answer(
        self,
        questions_amount: int,
        batched_input: List[str],
        generation_pipeline: transformers.Pipeline,
        generation_config: transformers.GenerationConfig,
    ) -> List[List[str]]:
        """
        Answer questions with a context to the given text files contents by a pretrained LLM model in given pipeline.
        """
        return self._answer_poll_questions(
            questions_amount=questions_amount,
            batched_input=batched_input,
            generation_pipeline=generation_pipeline,
            generation_config=generation_config,
        )

    def _answer_poll_questions(
        self,
        questions_amount: int,
        batched_input: List[str],
        generation_pipeline: transformers.Pipeline,
        generation_config: transformers.GenerationConfig,
    ) -> List[List[str]]:
        votes = []

        # Run the poll for each question
        for _ in range(self.poll_count):
            batched_answers = self._infer_questions(
                questions_amount=questions_amount,
                batched_input=batched_input,
                generation_pipeline=generation_pipeline,
                generation_config=generation_config,
            )
            votes.append(batched_answers)
        answers = []

        # Collect the answers according to the poll strategy
        # Average strategy works for numeric values only
        for batch in range(len(votes[0])):
            batched_answers = []
            for question in range(questions_amount):
                # Create a list of all answers to relevant question
                answer = [
                    votes[voter][batch][question] for voter in range(self.poll_count)
                ]
                answer = self.poll_strategy.do(answer)
                batched_answers.append(answer)
            answers.append(batched_answers)
        return answers


# Holds names of QuestionHandles
class QuestionTypes:
    DEFAULT = "default"
    POLL = "poll"


# Maps question types to their handlers
QUESTION_MAPPING = {
    QuestionTypes.DEFAULT: QuestionHandler,
    QuestionTypes.POLL: PollQuestionHandler,
}
 base_image: mlrun/mlrun commands: [] code_origin: '' origin_filename: '' requirements: - - transformers torch tqdm + - transformers + - torch + - tqdm entry_points: open_mpi_handler: name: open_mpi_handler @@ -27,29 +29,30 @@ spec: parameters: - name: worker_inputs type: List[str] - default: '' - name: root_worker_inputs type: Dict[str, Any] default: null - outputs: - - default: '' + outputs: [] lineno: 58 + has_varargs: false + has_kwargs: false decorator: name: decorator doc: '' parameters: - name: handler - default: '' - outputs: - - default: '' + outputs: [] lineno: 66 + has_varargs: false + has_kwargs: false wrapper: name: wrapper doc: '' parameters: [] - outputs: - - default: '' + outputs: [] lineno: 71 + has_varargs: false + has_kwargs: true answer_questions: name: answer_questions doc: 'Answer questions with a context to the given text files contents by a @@ -81,19 +84,16 @@ spec: type: Union[str, List[str]] doc: A path to a directory of text files or a path to a text file to ask questions about. - default: '' - name: model_name type: str doc: The pre-trained model name from the huggingface hub to use for asking questions. - default: '' - name: questions type: Union[List[str], List[List[str]]] doc: The questions to ask. A list of lists of questions to ask per text file, and devided by question groups, the groups can be dtermained by size (in order to avoid large inputs to the llm) or by questioning method (regular or poll like questioning). - default: '' - name: device_map type: Union[str, dict] doc: A map to use for loading the model on multiple devices. @@ -152,60 +152,58 @@ spec: doc: 'Whether to present logs of a progress bar and errors. Default: True.' default: false outputs: - - default: '' - doc: 'A tuple of:' + - doc: 'A tuple of:' + type: Tuple[pd.DataFrame, dict] lineno: 130 + has_varargs: false + has_kwargs: false answer: name: answer doc: Answer questions with a context to the given text files contents by a pretrained LLM model in given pipeline. parameters: - name: self - default: '' - name: questions_amount type: int - default: '' - name: batched_input type: List[str] - default: '' - name: generation_pipeline type: Pipeline - default: '' - name: generation_config type: GenerationConfig - default: '' outputs: - - default: '' + - type: List[List[str]] lineno: 674 + has_varargs: false + has_kwargs: false most_common: name: most_common doc: Calculate the most common answer for a given list of answers. parameters: - name: answers - default: '' - outputs: - - default: '' + outputs: [] lineno: 637 + has_varargs: false + has_kwargs: false average: name: average doc: Calculate the average answer for a given list of answers. parameters: - name: answers - default: '' - outputs: - - default: '' + outputs: [] lineno: 646 + has_varargs: false + has_kwargs: false do: name: do doc: Perform the strategy. parameters: - name: self - default: '' - name: answers - default: '' - outputs: - - default: '' + outputs: [] lineno: 662 + has_varargs: false + has_kwargs: false description: GenAI approach of question answering on a given data default_handler: answer_questions disable_auto_mount: false diff --git a/question_answering/item.yaml b/question_answering/item.yaml index 6daa1b564..58ab5cc36 100755 --- a/question_answering/item.yaml +++ b/question_answering/item.yaml @@ -20,8 +20,8 @@ spec: image: mlrun/mlrun kind: job requirements: - transformers - torch - tqdm + - transformers + - torch + - tqdm url: '' -version: 0.3.0 +version: 0.3.1 diff --git a/requirements.txt b/requirements.txt index be36c8c86..faa20126f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,9 @@ black~=22.0 isort~=5.7 sphinx==4.0.2 sphinx-book-theme==0.3.3 -sphinx-togglebutton==0.3.1 \ No newline at end of file +sphinx-togglebutton==0.3.1 +sphinxcontrib-applehelp<=1.0.7 +sphinxcontrib.devhelp<=1.0.5 +sphinxcontrib-htmlhelp<=2.0.4 +sphinxcontrib-serializinghtml<=1.1.9 +sphinxcontrib-qthelp<=1.0.6 \ No newline at end of file diff --git a/silero_vad/assets/test_data.wav b/silero_vad/assets/test_data.wav new file mode 100644 index 000000000..a3a993c20 Binary files /dev/null and b/silero_vad/assets/test_data.wav differ diff --git a/silero_vad/function.yaml b/silero_vad/function.yaml new file mode 100644 index 000000000..75d1ce0cc --- /dev/null +++ b/silero_vad/function.yaml @@ -0,0 +1,280 @@ +kind: job +metadata: + name: silero-vad + tag: '' + hash: bc0ad5572cc391fcdc93baaee48e1ef949a7984d + project: '' + labels: + author: guyl + categories: + - Deep Learning + - PyTorch + - Audio +spec: + command: '' + args: [] + image: '' + build: + functionSourceCode: # Copyright 2024 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from multiprocessing import Process, Queue
from pathlib import Path
from types import FunctionType
from typing import Dict, List, Tuple, Type, Union

import torch
import torchaudio
from tqdm import tqdm


class BaseTask:
    """
    A base class for a task to complete after VAD.
    """

    def __init__(self, audio_file: Path):
        """
        Initialize the base task.

        :param audio_file: The audio file assigned to the task.
        """
        # Store the audio file:
        self._audio_file = audio_file

        # Prepare the result:
        self._result = None

    @property
    def audio_file(self) -> Path:
        """
        Get the audio file of the task.

        :returns: The audio file of the task.
        """
        return self._audio_file

    def do_task(
        self, speech_timestamps: Union[List[Dict[str, int]], List[List[Dict[str, int]]]]
    ):
        """
        Do the task on the given speech timestamps. The base task will simply save the speech timestamps as the result.

        :param speech_timestamps: The speech timestamps to do the task on as outputted from the VAD.
        """
        self._result = speech_timestamps

    def get_result(self) -> Tuple[str, list]:
        """
        Get the result of the task. A tuple of the audio file name and the result.

        :returns: The result of the task.
        """
        return self._audio_file.name, self._result

    def to_tuple(self) -> Tuple[str, dict]:
        """
        Convert the task to a tuple to reconstruct it later (used for multiprocessing to pass in queue).

        :returns: The converted task.
        """
        return self.__class__.__name__, {"audio_file": self._audio_file}


class SpeechDiarizationTask(BaseTask):
    """
    A speech diarization task. The task will diarize the VAD speech timestamps into speakers.
    """

    def __init__(self, audio_file: Path, speaker_labels: List[str]):
        """
        Initialize the speech diarization task.

        :param audio_file:     The audio file assigned to the task.
        :param speaker_labels: The speaker labels to use for the diarization. If not given, the speakers will be named
                               "speaker_0", "speaker_1", etc.
        """
        super().__init__(audio_file=audio_file)
        self._speaker_labels = speaker_labels

    def do_task(self, speech_timestamps: List[List[Dict[str, int]]]):
        """
        Do the task on the given speech timestamps. The task will diarize the VAD speech timestamps into speakers.

        :param speech_timestamps: The speech timestamps per channel to do the task on as outputted from the VAD.
        """
        # Get the speaker labels (set default if not given):
        speaker_labels = self._speaker_labels or [
            f"speaker_{i}" for i in range(len(speech_timestamps))
        ]

        # Diarize - organize the speech timestamps into a single list of speakers and sort it by start time:
        speech_diarization = [
            (speech_timestamp["start"], speech_timestamp["end"], speaker_label)
            for speaker_label, channel_speech_timestamps in zip(
                speaker_labels, speech_timestamps
            )
            for speech_timestamp in channel_speech_timestamps
        ]
        speech_diarization.sort()
        self._result = speech_diarization

    def to_tuple(self) -> Tuple[str, dict]:
        """
        Convert the task to a tuple to reconstruct it later (used for multiprocessing to pass in queue).

        :returns: The converted task.
        """
        task_class, task_kwargs = super().to_tuple()
        return task_class, {**task_kwargs, "speaker_labels": self._speaker_labels}


class TaskCreator:
    """
    A task creator to create different tasks to run after the VAD.
    """

    #: A map from task class name to task class to use in `from_tuple`:
    _MAP = {
        BaseTask.__name__: BaseTask,
        SpeechDiarizationTask.__name__: SpeechDiarizationTask,
    }

    def __init__(self, task_type: Type[BaseTask], task_kwargs: dict = None):
        """
        Initialize the task creator.
        :param task_type: The task type - a `BaseTask` subclass.
        :param task_kwargs: Additional keyword arguments to pass to the to be created tasks.
        """
        self._task_type = task_type
        self._task_kwargs = task_kwargs or {}

    def create_task(self, audio_file: Path) -> BaseTask:
        """
        Create a task with the given audio file.

        :param audio_file: The audio file to assign to the task.

        :returns: The created task.
        """
        return self._task_type(audio_file=audio_file, **self._task_kwargs)

    @classmethod
    def from_tuple(cls, task_tuple: Tuple[str, dict]) -> BaseTask:
        """
        Create a task from a tuple of the audio file name and the task kwargs.

        :param task_tuple: The task tuple to create the task from.

        :returns: The created task.
        """
        task_class, task_kwargs = task_tuple
        return cls._MAP[task_class](**task_kwargs)


class VoiceActivityDetector:
    """
    A voice activity detection wrapper for the silero VAD model - https://github.com/snakers4/silero-vad.
    """

    def __init__(
        self,
        # Model loading kwargs:
        use_onnx: bool = True,
        force_onnx_cpu: bool = True,
        # Detection kwargs:
        threshold: float = 0.5,
        sampling_rate: int = 16_000,
        min_speech_duration_ms: int = 250,
        max_speech_duration_s: float = float("inf"),
        min_silence_duration_ms: int = 100,
        window_size_samples: int = 512,
        speech_pad_ms: int = 30,
        return_seconds: bool = False,
        per_channel: bool = False,
    ):
        """
        Initialize the voice activity detector.

        :param use_onnx:                Whether to use ONNX for inference. Default is True.
        :param force_onnx_cpu:          Whether to force ONNX to use CPU for inference. Default is True.
        :param threshold:               Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
                                        probabilities ABOVE this value are considered as SPEECH. It is better to tune
                                        this parameter for each dataset separately, but "lazy" 0.5 is pretty good for
                                        most datasets.
        :param sampling_rate:           Currently, silero VAD models support 8000 and 16000 sample rates.
        :param min_speech_duration_ms:  Final speech chunks shorter min_speech_duration_ms are thrown out.
        :param max_speech_duration_s:   Maximum duration of speech chunks in seconds. Chunks longer than
                                        `max_speech_duration_s` will be split at the timestamp of the last silence that
                                        lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise,
                                        they will be split aggressively just before max_speech_duration_s.
        :param min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms before
                                        separating it.
        :param window_size_samples:     Audio chunks of window_size_samples size are fed to the silero VAD model.
                                        WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000
                                        sample rate and 256, 512, 768 samples for 8000 sample rate. Values other than
                                        these may affect model performance!
        :param speech_pad_ms:           Final speech chunks are padded by speech_pad_ms each side.
        :param return_seconds:          Whether return timestamps in seconds. False means to return timestamps in
                                        samples (default - False).
        :param per_channel:             Whether to return timestamps per channel (default - False). This will run VAD
                                        on each channel separately and return a list of timestamps per channel.
        """
        # Store configurations:
        self._use_onnx = use_onnx
        self._force_onnx_cpu = force_onnx_cpu
        self._threshold = threshold
        self._sampling_rate = sampling_rate
        self._min_speech_duration_ms = min_speech_duration_ms
        self._max_speech_duration_s = max_speech_duration_s
        self._min_silence_duration_ms = min_silence_duration_ms
        self._window_size_samples = window_size_samples
        self._speech_pad_ms = speech_pad_ms
        self._return_seconds = return_seconds
        self._per_channel = per_channel

        # Prepare the model variables
        self._model: torch.Module = None
        self._get_speech_timestamps: FunctionType = None

    def load(self, force_reload: bool = True):
        """
        Load the VAD model.

        :param force_reload: Whether to force reload the model even if it was already loaded. Default is True.
        """
        model, utils = torch.hub.load(
            repo_or_dir="snakers4/silero-vad",
            model="silero_vad",
            force_reload=force_reload,
            onnx=self._use_onnx,
            force_onnx_cpu=self._force_onnx_cpu,
        )
        self._model = model
        (
            self._get_speech_timestamps,
            _,  # save_audio,
            _,  # read_audio,
            _,  # VADIterator,
            _,  # collect_chunks
        ) = utils

    def detect_voice(
        self,
        audio_file: Path,
    ) -> Union[List[Dict[str, int]], List[List[Dict[str, int]]]]:
        """
        Infer the audio through the VAD model and return the speech timestamps.

        :param audio_file: The audio file to infer.

        :returns: The speech timestamps in the audio. A list of timestamps where each timestamp is a dictionary with the
                 following keys:

                 * "start": The start sample index of the speech in the audio.
                 * "end":   The end sample index of the speech in the audio.

                 If `per_channel` is True, a list of timestamps per channel will be returned.
        """
        # Cast to a numpy array:
        audio = self._read_audio(audio_file)

        # Detect speech:
        if not self._per_channel:
            return self._get_speech_timestamps(
                audio,
                self._model,
                threshold=self._threshold,
                min_speech_duration_ms=self._min_speech_duration_ms,
                max_speech_duration_s=self._max_speech_duration_s,
                min_silence_duration_ms=self._min_silence_duration_ms,
                speech_pad_ms=self._speech_pad_ms,
                sampling_rate=self._sampling_rate,
                window_size_samples=self._window_size_samples,
                return_seconds=self._return_seconds,
            )

        # Per channel:
        speech_timestamps = []
        for channel in audio:
            speech_timestamps.append(
                self._get_speech_timestamps(
                    channel,
                    self._model,
                    threshold=self._threshold,
                    min_speech_duration_ms=self._min_speech_duration_ms,
                    max_speech_duration_s=self._max_speech_duration_s,
                    min_silence_duration_ms=self._min_silence_duration_ms,
                    speech_pad_ms=self._speech_pad_ms,
                    sampling_rate=self._sampling_rate,
                    window_size_samples=self._window_size_samples,
                    return_seconds=self._return_seconds,
                )
            )

        return speech_timestamps

    def _read_audio(
        self,
        path: Path,
    ) -> torch.Tensor:
        """
        Read the audio from the given path and return it as a tensor.

        :param path: The path to the audio file.

        :returns: The audio as a tensor.
        """
        # Read the audio:
        audio, sampling_rate = torchaudio.load(str(path))

        # Check if the audio is stereo and if so, convert it to mono (only if not per channel):
        if audio.size(0) > 1 and not self._per_channel:
            audio = audio.mean(dim=0, keepdim=True)

        # Resample the audio if needed:
        if sampling_rate != self._sampling_rate:
            transform = torchaudio.transforms.Resample(
                orig_freq=sampling_rate, new_freq=self._sampling_rate
            )
            audio = transform(audio)

        # Return the audio (squeeze if not per channel):
        return audio if self._per_channel else audio.squeeze(0)


#: The value to send into multiprocessing queues to stop the process:
_MULTIPROCESSING_STOP_MARK = "STOP"


def _multiprocessing_complete_tasks(
    vad_init_kwargs: dict, tasks_queue: Queue, results_queue: Queue
):
    """
    Complete the tasks in the given queue and put the results in the given results queue. The function will stop when
    the given tasks queue will receive the stop mark. It is aimed to be used with multiprocessing as a process.

    :param vad_init_kwargs: The VAD initialization kwargs.
    :param tasks_queue:     A queue to get the tasks from.
    :param results_queue:   A queue to put the results in.
    """
    # Initialize and load the VAD:
    vad = VoiceActivityDetector(**vad_init_kwargs)
    vad.load(force_reload=False)

    # Start listening to the tasks queue:
    while True:
        # Get the task:
        task: Tuple[str, dict] = tasks_queue.get()
        if task == _MULTIPROCESSING_STOP_MARK:
            break
        try:
            # Create the task:
            task = TaskCreator.from_tuple(task_tuple=task)
            # Run the file through the VAD:
            speech_timestamps = vad.detect_voice(audio_file=task.audio_file)
            # Complete the task:
            task.do_task(speech_timestamps=speech_timestamps)
            # Build the result:
            result = (False, task.get_result())
        except Exception as exception:
            # Build the error:
            result = (True, (task.audio_file.name, str(exception)))
        # Collect the result / error:
        results_queue.put(result)

    # Mark the end of the tasks:
    results_queue.put(_MULTIPROCESSING_STOP_MARK)


# Get the global logger:
try:
    import mlrun

    _LOGGER = mlrun.get_or_create_ctx("silero_vad").logger
except ModuleNotFoundError:
    _LOGGER = logging.getLogger()


def detect_voice(
    # Input kwargs:
    data_path: Union[str, Path, List[Union[str, Path]]],
    # Model loading kwargs:
    use_onnx: bool = True,
    force_onnx_cpu: bool = True,
    # Detection kwargs:
    threshold: float = 0.5,
    sampling_rate: int = 16_000,
    min_speech_duration_ms: int = 250,
    max_speech_duration_s: float = float("inf"),
    min_silence_duration_ms: int = 100,
    window_size_samples: int = 512,
    speech_pad_ms: int = 30,
    return_seconds: bool = False,
    per_channel: bool = False,
    # Other kwargs:
    use_multiprocessing: int = 0,
    verbose: bool = False,
):
    """
    Perform voice activity detection on given audio files using the silero VAD model -
    https://github.com/snakers4/silero-vad. The end result is a dictionary with the file names as keys and their
    VAD timestamps dictionaries as value.

    For example::

        {
            "file_1.wav": [
                {"start": 0, "end": 16000},
                {"start": 16000, "end": 32000},
                {"start": 32000, "end": 48000},
                ...
            ],
            "file_2.wav": [
                {"start": 0, "end": 16000},
                {"start": 16000, "end": 32000},
                {"start": 32000, "end": 48000},
                ...
            ],
            ...
        }


    :param data_path:               The path to the audio files to diarize. Can be a path to a single file, a path to a
                                    directory or a list of paths to files.
    :param use_onnx:                Whether to use ONNX for inference. Default is True.
    :param force_onnx_cpu:          Whether to force ONNX to use CPU for inference. Default is True.
    :param threshold:               Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
                                    probabilities ABOVE this value are considered as SPEECH. It is better to tune
                                    this parameter for each dataset separately, but "lazy" 0.5 is pretty good for
                                    most datasets.
    :param sampling_rate:           Currently, silero VAD models support 8000 and 16000 sample rates.
    :param min_speech_duration_ms:  Final speech chunks shorter min_speech_duration_ms are thrown out.
    :param max_speech_duration_s:   Maximum duration of speech chunks in seconds. Chunks longer than
                                    `max_speech_duration_s` will be split at the timestamp of the last silence that
                                    lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will
                                    be split aggressively just before max_speech_duration_s.
    :param min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms before separating
                                    it.
    :param window_size_samples:     Audio chunks of window_size_samples size are fed to the silero VAD model.

                                    WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000
                                    sample rate and 256, 512, 768 samples for 8000 sample rate. Values other than
                                    these may affect model performance!
    :param speech_pad_ms:           Final speech chunks are padded by speech_pad_ms each side.
    :param return_seconds:          Whether return timestamps in seconds. False means to return timestamps in samples
                                    (default - False).
    :param per_channel:             Whether to return timestamps per channel (default - False). This will run VAD on
                                    each channel separately and return a list of timestamps per channel.
    :param use_multiprocessing:     The number of workers to use for multiprocessing. If 0, no multiprocessing will
                                    be used. Default is 0.
    :param verbose:                 Verbosity.
    """
    global _LOGGER

    # Get the input audio files to transcribe:
    if verbose:
        _LOGGER.info("Collecting audio files.")
    audio_files = _get_audio_files(data_path=data_path)
    if verbose:
        _LOGGER.info(f"Collected {len(audio_files)} audio files.")

    # Initialize the transcription pipeline:
    vad_init_kwargs = {
        "use_onnx": use_onnx,
        "force_onnx_cpu": force_onnx_cpu,
        "threshold": threshold,
        "sampling_rate": sampling_rate,
        "min_speech_duration_ms": min_speech_duration_ms,
        "max_speech_duration_s": max_speech_duration_s,
        "min_silence_duration_ms": min_silence_duration_ms,
        "window_size_samples": window_size_samples,
        "speech_pad_ms": speech_pad_ms,
        "return_seconds": return_seconds,
        "per_channel": per_channel,
    }

    # Create the task creator:
    task_creator = TaskCreator(task_type=BaseTask)

    # Run the transcription:
    if use_multiprocessing:
        results = _parallel_run(
            n_workers=use_multiprocessing,
            audio_files=audio_files,
            description="Detecting voice",
            vad_init_kwargs=vad_init_kwargs,
            task_creator=task_creator,
            verbose=verbose,
        )
    else:
        results = _run(
            audio_files=audio_files,
            description="Detecting voice",
            vad_init_kwargs=vad_init_kwargs,
            task_creator=task_creator,
            verbose=verbose,
        )

    # Process the results:
    return _process_results(results=results, verbose=verbose)


def diarize(
    # Input / Output kwargs:
    data_path: Union[str, Path, List[Union[str, Path]]],
    # Model loading kwargs:
    use_onnx: bool = True,
    force_onnx_cpu: bool = True,
    # Detection kwargs:
    threshold: float = 0.5,
    sampling_rate: int = 16_000,
    min_speech_duration_ms: int = 250,
    max_speech_duration_s: float = float("inf"),
    min_silence_duration_ms: int = 100,
    window_size_samples: int = 512,
    speech_pad_ms: int = 30,
    # Diarization kwargs:
    speaker_labels: List[str] = None,
    # Other kwargs:
    use_multiprocessing: int = 0,
    verbose: bool = False,
):
    """
    Perform speech diarization on given audio files using the silero VAD model - https://github.com/snakers4/silero-vad.
    The speech diarization is performed per channel so that each channel in the audio belong to a different speaker. The
    end result is a dictionary with the file names as keys and their diarization as value. A diarization is a list
    of tuples: (start, end, speaker_label).

    For example::

        {
            "file_1.wav": [
                (0.0, 1.0, "speaker_0"),
                (1.0, 2.0, "speaker_1"),
                (2.0, 3.0, "speaker_0"),
                ...
            ],
            "file_2.wav": [
                (0.0, 1.0, "speaker_0"),
                (1.0, 2.0, "speaker_1"),
                (2.0, 3.0, "speaker_0"),
                ...
            ],
            ...
        }


    :param data_path:               The path to the audio files to diarize. Can be a path to a single file, a path to a
                                    directory or a list of paths to files.
    :param use_onnx:                Whether to use ONNX for inference. Default is True.
    :param force_onnx_cpu:          Whether to force ONNX to use CPU for inference. Default is True.
    :param threshold:               Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
                                    probabilities ABOVE this value are considered as SPEECH. It is better to tune
                                    this parameter for each dataset separately, but "lazy" 0.5 is pretty good for
                                    most datasets.
    :param sampling_rate:           Currently, silero VAD models support 8000 and 16000 sample rates.
    :param min_speech_duration_ms:  Final speech chunks shorter min_speech_duration_ms are thrown out.
    :param max_speech_duration_s:   Maximum duration of speech chunks in seconds. Chunks longer than
                                    `max_speech_duration_s` will be split at the timestamp of the last silence that
                                    lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will
                                    be split aggressively just before max_speech_duration_s.
    :param min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms before separating
                                    it.
    :param window_size_samples:     Audio chunks of window_size_samples size are fed to the silero VAD model.

                                    WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000
                                    sample rate and 256, 512, 768 samples for 8000 sample rate. Values other than
                                    these may affect model performance!
    :param speech_pad_ms:           Final speech chunks are padded by speech_pad_ms each side.
    :param speaker_labels:          The speaker labels to use for the diarization. If not given, the speakers will be
                                    named "speaker_0", "speaker_1", etc.
    :param use_multiprocessing:     The number of workers to use for multiprocessing. If 0, no multiprocessing will
                                    be used. Default is 0.
    :param verbose:                 Verbosity.
    """
    global _LOGGER

    # Get the input audio files to transcribe:
    if verbose:
        _LOGGER.info("Collecting audio files.")
    audio_files = _get_audio_files(data_path=data_path)
    if verbose:
        _LOGGER.info(f"Collected {len(audio_files)} audio files.")

    # Initialize the transcription pipeline:
    vad_init_kwargs = {
        "use_onnx": use_onnx,
        "force_onnx_cpu": force_onnx_cpu,
        "threshold": threshold,
        "sampling_rate": sampling_rate,
        "min_speech_duration_ms": min_speech_duration_ms,
        "max_speech_duration_s": max_speech_duration_s,
        "min_silence_duration_ms": min_silence_duration_ms,
        "window_size_samples": window_size_samples,
        "speech_pad_ms": speech_pad_ms,
        "return_seconds": True,
        "per_channel": True,
    }

    # Create the task creator:
    task_creator = TaskCreator(
        task_type=SpeechDiarizationTask, task_kwargs={"speaker_labels": speaker_labels}
    )

    # Run the transcription:
    if use_multiprocessing:
        results = _parallel_run(
            n_workers=use_multiprocessing,
            audio_files=audio_files,
            description="Diarizing",
            vad_init_kwargs=vad_init_kwargs,
            task_creator=task_creator,
            verbose=verbose,
        )
    else:
        results = _run(
            audio_files=audio_files,
            description="Diarizing",
            vad_init_kwargs=vad_init_kwargs,
            task_creator=task_creator,
            verbose=verbose,
        )

    # Process the results:
    return _process_results(results=results, verbose=verbose)


def _get_audio_files(
    data_path: Union[Path, str, list],
) -> List[Path]:
    """
    Get the audio files from the data path. If a path to a directory is given, all files in the directory will be
    collected.

    :param data_path: The data path to collect the audio files from.

    :returns: The audio files list.
    """
    # Check if given a list of paths:
    if isinstance(data_path, list):
        audio_files = []
        for path in data_path:
            audio_files.extend(_get_audio_files(data_path=path))
        return audio_files

    # Check if given a single string path to cast it to a `pathlib.Path`:
    if isinstance(data_path, str):
        data_path = Path(data_path).absolute()

    # Check if the path is of a directory or a file:
    if data_path.is_dir():
        # Get all files inside the directory:
        audio_files = list(data_path.glob("*.*"))
    elif data_path.is_file():
        audio_files = [data_path]
    else:
        raise ValueError(
            f"Unrecognized data path. The parameter `data_path` must be a valid path to either a directory path or a "
            f"file. Given: {str(data_path)} "
        )

    return audio_files


def _run(
    audio_files: List[Path],
    description: str,
    vad_init_kwargs: dict,
    task_creator: TaskCreator,
    verbose: bool,
) -> List[Tuple[bool, Tuple[str, list]]]:
    """
    Load a VAD and use it to complete the tasks that will be created on the provided files using the given task creator.

    :param audio_files:     The audio files to use.
    :param description:     The description to use for the progress bar.
    :param vad_init_kwargs: The VAD initialization keyword arguments.
    :param task_creator:    The task creator to use to create the tasks.
    :param verbose:         Verbosity.

    :returns: The collected results.
    """
    # Load the VAD:
    vad = VoiceActivityDetector(**vad_init_kwargs)
    if verbose:
        _LOGGER.info(f"Loading the VAD model.")
    vad.load()
    if verbose:
        _LOGGER.info("VAD model loaded.")

    # Run the VAD on the audio files and collect the results:
    results = []
    for audio_file in tqdm(
        audio_files,
        desc=description,
        unit="file",
        total=len(audio_files),
        disable=not verbose,
    ):
        try:
            # Create the task:
            task = task_creator.create_task(audio_file=audio_file)
            # Run the file through the VAD:
            speech_timestamps = vad.detect_voice(audio_file=audio_file)
            # Complete the task:
            task.do_task(speech_timestamps=speech_timestamps)
            # Collect the result:
            results.append((False, task.get_result()))
        except Exception as exception:
            # Collect the error:
            results.append((True, (audio_file.name, str(exception))))

    return results


def _parallel_run(
    n_workers: int,
    audio_files: List[Path],
    description: str,
    vad_init_kwargs: dict,
    task_creator: TaskCreator,
    verbose: bool,
) -> List[Tuple[bool, Tuple[str, list]]]:
    """
    Run multiple VAD workers with multiprocessing to complete the tasks that will be created on the provided files using
    the given task creator.

    :param n_workers:       The number of workers to use.
    :param audio_files:     The audio files to use.
    :param description:     The description to use for the progress bar.
    :param vad_init_kwargs: The VAD initialization keyword arguments.
    :param task_creator:    The task creator to use to create the tasks.
    :param verbose:         Verbosity.

    :returns: The collected results.
    """
    # Load the VAD (download once, and it will be loaded then per process later on):
    if verbose:
        _LOGGER.info(f"Loading the VAD model.")
    vad = VoiceActivityDetector(**vad_init_kwargs)
    vad.load()
    if verbose:
        _LOGGER.info("VAD model loaded.")

    # Check the number of workers:
    if n_workers > len(audio_files):
        _LOGGER.warning(
            f"The number of workers ({n_workers}) is larger than the number of audio files ({len(audio_files)}). "
            f"Setting the number of workers to {len(audio_files)}."
        )
        n_workers = len(audio_files)

    # Initialize the multiprocessing queues:
    tasks_queue = Queue()
    results_queue = Queue()

    # Initialize the multiprocessing processes:
    task_completion_processes = [
        Process(
            target=_multiprocessing_complete_tasks,
            kwargs={
                "vad_init_kwargs": vad_init_kwargs,
                "tasks_queue": tasks_queue,
                "results_queue": results_queue,
            },
        )
        for _ in range(n_workers)
    ]

    # Start the multiprocessing processes:
    for p in task_completion_processes:
        p.start()

    # Put the tasks in the queue:
    for audio_file in audio_files:
        tasks_queue.put(task_creator.create_task(audio_file=audio_file).to_tuple())

    # Put the stop marks in the queue:
    for _ in range(n_workers):
        tasks_queue.put(_MULTIPROCESSING_STOP_MARK)

    # Collect the results:
    results = []
    stop_marks_counter = 0
    with tqdm(
        desc=description,
        unit="file",
        total=len(audio_files),
        disable=not verbose,
    ) as progressbar:
        while True:
            # Get a result from the queue:
            result: Tuple[bool, Tuple[str, list]] = results_queue.get()
            if result == _MULTIPROCESSING_STOP_MARK:
                stop_marks_counter += 1
                if stop_marks_counter == n_workers:
                    break
            else:
                # Collect the result:
                results.append(result)
                progressbar.update(1)

    # Wait for the processes to finish:
    for p in task_completion_processes:
        p.join()

    return results


def _process_results(
    results: List[Tuple[bool, Tuple[str, list]]], verbose: bool
) -> Tuple[dict, dict]:
    """
    Process the results of the tasks.

    :param results: The results to process.
    :param verbose: Verbosity.

    :returns: The processed results as a tuple of successes and errors.
    """
    if verbose:
        _LOGGER.info("Summarizing the results.")
    successes = {}
    errors = {}
    for is_error, result in results:
        if is_error:
            errors[result[0]] = result[1]
        else:
            successes[result[0]] = result[1]
    if verbose:
        _LOGGER.info(f"Done ({len(successes)}/{len(successes) + len(errors)})\n")

    return successes, errors
 + base_image: mlrun/mlrun + commands: [] + code_origin: '' + origin_filename: '' + requirements: + - torch + - torchaudio + - tqdm + - onnxruntime + entry_points: + audio_file: + name: audio_file + doc: Get the audio file of the task. + parameters: + - name: self + outputs: + - doc: The audio file of the task. + type: Path + default: '' + lineno: 43 + do_task: + name: do_task + doc: Do the task on the given speech timestamps. The task will diarize the VAD + speech timestamps into speakers. + parameters: + - name: self + - name: speech_timestamps + type: List[List[Dict[str, int]]] + doc: The speech timestamps per channel to do the task on as outputted from + the VAD. + outputs: + - default: '' + lineno: 94 + get_result: + name: get_result + doc: Get the result of the task. A tuple of the audio file name and the result. + parameters: + - name: self + outputs: + - doc: The result of the task. + default: '' + lineno: 61 + to_tuple: + name: to_tuple + doc: Convert the task to a tuple to reconstruct it later (used for multiprocessing + to pass in queue). + parameters: + - name: self + outputs: + - doc: The converted task. + default: '' + lineno: 116 + create_task: + name: create_task + doc: Create a task with the given audio file. + parameters: + - name: self + - name: audio_file + type: Path + doc: The audio file to assign to the task. + outputs: + - doc: The created task. + type: BaseTask + default: '' + lineno: 146 + from_tuple: + name: from_tuple + doc: Create a task from a tuple of the audio file name and the task kwargs. + parameters: + - name: cls + - name: task_tuple + type: Tuple[str, dict] + doc: The task tuple to create the task from. + outputs: + - doc: The created task. + type: BaseTask + default: '' + lineno: 157 + load: + name: load + doc: Load the VAD model. + parameters: + - name: self + - name: force_reload + type: bool + doc: Whether to force reload the model even if it was already loaded. Default + is True. + default: true + outputs: + - default: '' + lineno: 234 + detect_voice: + name: detect_voice + doc: "Perform voice activity detection on given audio files using the silero\ + \ VAD model -\nhttps://github.com/snakers4/silero-vad. The end result is a\ + \ dictionary with the file names as keys and their\nVAD timestamps dictionaries\ + \ as value.\n\nFor example::\n\n {\n \"file_1.wav\": [\n \ + \ {\"start\": 0, \"end\": 16000},\n {\"start\": 16000, \"end\"\ + : 32000},\n {\"start\": 32000, \"end\": 48000},\n ...\n\ + \ ],\n \"file_2.wav\": [\n {\"start\": 0, \"end\"\ + : 16000},\n {\"start\": 16000, \"end\": 32000},\n {\"\ + start\": 32000, \"end\": 48000},\n ...\n ],\n ...\n\ + \ }" + parameters: + - name: data_path + type: Union[str, Path, List[Union[str, Path]]] + doc: The path to the audio files to diarize. Can be a path to a single file, + a path to a directory or a list of paths to files. + - name: use_onnx + type: bool + doc: Whether to use ONNX for inference. Default is True. + default: true + - name: force_onnx_cpu + type: bool + doc: Whether to force ONNX to use CPU for inference. Default is True. + default: true + - name: threshold + type: float + doc: Speech threshold. Silero VAD outputs speech probabilities for each audio + chunk, probabilities ABOVE this value are considered as SPEECH. It is better + to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty + good for most datasets. + default: 0.5 + - name: sampling_rate + type: int + doc: Currently, silero VAD models support 8000 and 16000 sample rates. + default: 16000 + - name: min_speech_duration_ms + type: int + doc: Final speech chunks shorter min_speech_duration_ms are thrown out. + default: 250 + - name: max_speech_duration_s + type: float + doc: Maximum duration of speech chunks in seconds. Chunks longer than `max_speech_duration_s` + will be split at the timestamp of the last silence that lasts more than + 100ms (if any), to prevent aggressive cutting. Otherwise, they will be split + aggressively just before max_speech_duration_s. + default: float('inf') + - name: min_silence_duration_ms + type: int + doc: In the end of each speech chunk wait for min_silence_duration_ms before + separating it. + default: 100 + - name: window_size_samples + type: int + doc: Audio chunks of window_size_samples size are fed to the silero VAD model. + default: 512 + - name: speech_pad_ms + type: int + doc: Final speech chunks are padded by speech_pad_ms each side. + default: 30 + - name: return_seconds + type: bool + doc: Whether return timestamps in seconds. False means to return timestamps + in samples (default - False). + default: false + - name: per_channel + type: bool + doc: Whether to return timestamps per channel (default - False). This will + run VAD on each channel separately and return a list of timestamps per channel. + default: false + - name: use_multiprocessing + type: int + doc: The number of workers to use for multiprocessing. If 0, no multiprocessing + will be used. Default is 0. + default: 0 + - name: verbose + type: bool + doc: Verbosity. + default: false + outputs: + - default: '' + lineno: 393 + diarize: + name: diarize + doc: "Perform speech diarization on given audio files using the silero VAD model\ + \ - https://github.com/snakers4/silero-vad.\nThe speech diarization is performed\ + \ per channel so that each channel in the audio belong to a different speaker.\ + \ The\nend result is a dictionary with the file names as keys and their diarization\ + \ as value. A diarization is a list\nof tuples: (start, end, speaker_label).\n\ + \nFor example::\n\n {\n \"file_1.wav\": [\n (0.0, 1.0,\ + \ \"speaker_0\"),\n (1.0, 2.0, \"speaker_1\"),\n (2.0,\ + \ 3.0, \"speaker_0\"),\n ...\n ],\n \"file_2.wav\"\ + : [\n (0.0, 1.0, \"speaker_0\"),\n (1.0, 2.0, \"speaker_1\"\ + ),\n (2.0, 3.0, \"speaker_0\"),\n ...\n ],\n\ + \ ...\n }" + parameters: + - name: data_path + type: Union[str, Path, List[Union[str, Path]]] + doc: The path to the audio files to diarize. Can be a path to a single file, + a path to a directory or a list of paths to files. + - name: use_onnx + type: bool + doc: Whether to use ONNX for inference. Default is True. + default: true + - name: force_onnx_cpu + type: bool + doc: Whether to force ONNX to use CPU for inference. Default is True. + default: true + - name: threshold + type: float + doc: Speech threshold. Silero VAD outputs speech probabilities for each audio + chunk, probabilities ABOVE this value are considered as SPEECH. It is better + to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty + good for most datasets. + default: 0.5 + - name: sampling_rate + type: int + doc: Currently, silero VAD models support 8000 and 16000 sample rates. + default: 16000 + - name: min_speech_duration_ms + type: int + doc: Final speech chunks shorter min_speech_duration_ms are thrown out. + default: 250 + - name: max_speech_duration_s + type: float + doc: Maximum duration of speech chunks in seconds. Chunks longer than `max_speech_duration_s` + will be split at the timestamp of the last silence that lasts more than + 100ms (if any), to prevent aggressive cutting. Otherwise, they will be split + aggressively just before max_speech_duration_s. + default: float('inf') + - name: min_silence_duration_ms + type: int + doc: In the end of each speech chunk wait for min_silence_duration_ms before + separating it. + default: 100 + - name: window_size_samples + type: int + doc: Audio chunks of window_size_samples size are fed to the silero VAD model. + default: 512 + - name: speech_pad_ms + type: int + doc: Final speech chunks are padded by speech_pad_ms each side. + default: 30 + - name: speaker_labels + type: List[str] + doc: The speaker labels to use for the diarization. If not given, the speakers + will be named "speaker_0", "speaker_1", etc. + default: null + - name: use_multiprocessing + type: int + doc: The number of workers to use for multiprocessing. If 0, no multiprocessing + will be used. Default is 0. + default: 0 + - name: verbose + type: bool + doc: Verbosity. + default: false + outputs: + - default: '' + lineno: 517 + description: Silero VAD (Voice Activity Detection) functions. + default_handler: detect_voice + disable_auto_mount: false + clone_target_dir: '' + env: [] + priority_class_name: '' + preemption_mode: prevent + affinity: null + tolerations: null + security_context: {} +verbose: false diff --git a/silero_vad/item.yaml b/silero_vad/item.yaml new file mode 100644 index 000000000..6f85a4c7d --- /dev/null +++ b/silero_vad/item.yaml @@ -0,0 +1,30 @@ +apiVersion: v1 +categories: + - Deep Learning + - PyTorch + - Audio +description: Silero VAD (Voice Activity Detection) functions. +doc: '' +example: silero_vad.ipynb +generationDate: 2023-12-03:14-30 +hidden: false +icon: '' +labels: + author: guyl +maintainers: [] +marketplaceType: '' +mlrunVersion: 1.5.2 +name: silero_vad +platformVersion: 3.5.3 +spec: + filename: silero_vad.py + handler: detect_voice + image: mlrun/mlrun + kind: job + requirements: + - torch + - torchaudio + - tqdm + - onnxruntime +url: '' +version: 1.1.0 diff --git a/silero_vad/silero_vad.ipynb b/silero_vad/silero_vad.ipynb new file mode 100644 index 000000000..29cd7437e --- /dev/null +++ b/silero_vad/silero_vad.ipynb @@ -0,0 +1,35 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "initial_id", + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/silero_vad/silero_vad.py b/silero_vad/silero_vad.py new file mode 100644 index 000000000..a477d4ecf --- /dev/null +++ b/silero_vad/silero_vad.py @@ -0,0 +1,847 @@ +# Copyright 2024 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from multiprocessing import Process, Queue +from pathlib import Path +from types import FunctionType +from typing import Dict, List, Tuple, Type, Union + +import torch +import torchaudio +from tqdm import tqdm + + +class BaseTask: + """ + A base class for a task to complete after VAD. + """ + + def __init__(self, audio_file: Path): + """ + Initialize the base task. + + :param audio_file: The audio file assigned to the task. + """ + # Store the audio file: + self._audio_file = audio_file + + # Prepare the result: + self._result = None + + @property + def audio_file(self) -> Path: + """ + Get the audio file of the task. + + :returns: The audio file of the task. + """ + return self._audio_file + + def do_task( + self, speech_timestamps: Union[List[Dict[str, int]], List[List[Dict[str, int]]]] + ): + """ + Do the task on the given speech timestamps. The base task will simply save the speech timestamps as the result. + + :param speech_timestamps: The speech timestamps to do the task on as outputted from the VAD. + """ + self._result = speech_timestamps + + def get_result(self) -> Tuple[str, list]: + """ + Get the result of the task. A tuple of the audio file name and the result. + + :returns: The result of the task. + """ + return self._audio_file.name, self._result + + def to_tuple(self) -> Tuple[str, dict]: + """ + Convert the task to a tuple to reconstruct it later (used for multiprocessing to pass in queue). + + :returns: The converted task. + """ + return self.__class__.__name__, {"audio_file": self._audio_file} + + +class SpeechDiarizationTask(BaseTask): + """ + A speech diarization task. The task will diarize the VAD speech timestamps into speakers. + """ + + def __init__(self, audio_file: Path, speaker_labels: List[str]): + """ + Initialize the speech diarization task. + + :param audio_file: The audio file assigned to the task. + :param speaker_labels: The speaker labels to use for the diarization. If not given, the speakers will be named + "speaker_0", "speaker_1", etc. + """ + super().__init__(audio_file=audio_file) + self._speaker_labels = speaker_labels + + def do_task(self, speech_timestamps: List[List[Dict[str, int]]]): + """ + Do the task on the given speech timestamps. The task will diarize the VAD speech timestamps into speakers. + + :param speech_timestamps: The speech timestamps per channel to do the task on as outputted from the VAD. + """ + # Get the speaker labels (set default if not given): + speaker_labels = self._speaker_labels or [ + f"speaker_{i}" for i in range(len(speech_timestamps)) + ] + + # Diarize - organize the speech timestamps into a single list of speakers and sort it by start time: + speech_diarization = [ + (speech_timestamp["start"], speech_timestamp["end"], speaker_label) + for speaker_label, channel_speech_timestamps in zip( + speaker_labels, speech_timestamps + ) + for speech_timestamp in channel_speech_timestamps + ] + speech_diarization.sort() + self._result = speech_diarization + + def to_tuple(self) -> Tuple[str, dict]: + """ + Convert the task to a tuple to reconstruct it later (used for multiprocessing to pass in queue). + + :returns: The converted task. + """ + task_class, task_kwargs = super().to_tuple() + return task_class, {**task_kwargs, "speaker_labels": self._speaker_labels} + + +class TaskCreator: + """ + A task creator to create different tasks to run after the VAD. + """ + + #: A map from task class name to task class to use in `from_tuple`: + _MAP = { + BaseTask.__name__: BaseTask, + SpeechDiarizationTask.__name__: SpeechDiarizationTask, + } + + def __init__(self, task_type: Type[BaseTask], task_kwargs: dict = None): + """ + Initialize the task creator. + :param task_type: The task type - a `BaseTask` subclass. + :param task_kwargs: Additional keyword arguments to pass to the to be created tasks. + """ + self._task_type = task_type + self._task_kwargs = task_kwargs or {} + + def create_task(self, audio_file: Path) -> BaseTask: + """ + Create a task with the given audio file. + + :param audio_file: The audio file to assign to the task. + + :returns: The created task. + """ + return self._task_type(audio_file=audio_file, **self._task_kwargs) + + @classmethod + def from_tuple(cls, task_tuple: Tuple[str, dict]) -> BaseTask: + """ + Create a task from a tuple of the audio file name and the task kwargs. + + :param task_tuple: The task tuple to create the task from. + + :returns: The created task. + """ + task_class, task_kwargs = task_tuple + return cls._MAP[task_class](**task_kwargs) + + +class VoiceActivityDetector: + """ + A voice activity detection wrapper for the silero VAD model - https://github.com/snakers4/silero-vad. + """ + + def __init__( + self, + # Model loading kwargs: + use_onnx: bool = True, + force_onnx_cpu: bool = True, + # Detection kwargs: + threshold: float = 0.5, + sampling_rate: int = 16_000, + min_speech_duration_ms: int = 250, + max_speech_duration_s: float = float("inf"), + min_silence_duration_ms: int = 100, + window_size_samples: int = 512, + speech_pad_ms: int = 30, + return_seconds: bool = False, + per_channel: bool = False, + ): + """ + Initialize the voice activity detector. + + :param use_onnx: Whether to use ONNX for inference. Default is True. + :param force_onnx_cpu: Whether to force ONNX to use CPU for inference. Default is True. + :param threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, + probabilities ABOVE this value are considered as SPEECH. It is better to tune + this parameter for each dataset separately, but "lazy" 0.5 is pretty good for + most datasets. + :param sampling_rate: Currently, silero VAD models support 8000 and 16000 sample rates. + :param min_speech_duration_ms: Final speech chunks shorter min_speech_duration_ms are thrown out. + :param max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer than + `max_speech_duration_s` will be split at the timestamp of the last silence that + lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, + they will be split aggressively just before max_speech_duration_s. + :param min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms before + separating it. + :param window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model. + WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 + sample rate and 256, 512, 768 samples for 8000 sample rate. Values other than + these may affect model performance! + :param speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side. + :param return_seconds: Whether return timestamps in seconds. False means to return timestamps in + samples (default - False). + :param per_channel: Whether to return timestamps per channel (default - False). This will run VAD + on each channel separately and return a list of timestamps per channel. + """ + # Store configurations: + self._use_onnx = use_onnx + self._force_onnx_cpu = force_onnx_cpu + self._threshold = threshold + self._sampling_rate = sampling_rate + self._min_speech_duration_ms = min_speech_duration_ms + self._max_speech_duration_s = max_speech_duration_s + self._min_silence_duration_ms = min_silence_duration_ms + self._window_size_samples = window_size_samples + self._speech_pad_ms = speech_pad_ms + self._return_seconds = return_seconds + self._per_channel = per_channel + + # Prepare the model variables + self._model: torch.Module = None + self._get_speech_timestamps: FunctionType = None + + def load(self, force_reload: bool = True): + """ + Load the VAD model. + + :param force_reload: Whether to force reload the model even if it was already loaded. Default is True. + """ + model, utils = torch.hub.load( + repo_or_dir="snakers4/silero-vad", + model="silero_vad", + force_reload=force_reload, + onnx=self._use_onnx, + force_onnx_cpu=self._force_onnx_cpu, + ) + self._model = model + ( + self._get_speech_timestamps, + _, # save_audio, + _, # read_audio, + _, # VADIterator, + _, # collect_chunks + ) = utils + + def detect_voice( + self, + audio_file: Path, + ) -> Union[List[Dict[str, int]], List[List[Dict[str, int]]]]: + """ + Infer the audio through the VAD model and return the speech timestamps. + + :param audio_file: The audio file to infer. + + :returns: The speech timestamps in the audio. A list of timestamps where each timestamp is a dictionary with the + following keys: + + * "start": The start sample index of the speech in the audio. + * "end": The end sample index of the speech in the audio. + + If `per_channel` is True, a list of timestamps per channel will be returned. + """ + # Cast to a numpy array: + audio = self._read_audio(audio_file) + + # Detect speech: + if not self._per_channel: + return self._get_speech_timestamps( + audio, + self._model, + threshold=self._threshold, + min_speech_duration_ms=self._min_speech_duration_ms, + max_speech_duration_s=self._max_speech_duration_s, + min_silence_duration_ms=self._min_silence_duration_ms, + speech_pad_ms=self._speech_pad_ms, + sampling_rate=self._sampling_rate, + window_size_samples=self._window_size_samples, + return_seconds=self._return_seconds, + ) + + # Per channel: + speech_timestamps = [] + for channel in audio: + speech_timestamps.append( + self._get_speech_timestamps( + channel, + self._model, + threshold=self._threshold, + min_speech_duration_ms=self._min_speech_duration_ms, + max_speech_duration_s=self._max_speech_duration_s, + min_silence_duration_ms=self._min_silence_duration_ms, + speech_pad_ms=self._speech_pad_ms, + sampling_rate=self._sampling_rate, + window_size_samples=self._window_size_samples, + return_seconds=self._return_seconds, + ) + ) + + return speech_timestamps + + def _read_audio( + self, + path: Path, + ) -> torch.Tensor: + """ + Read the audio from the given path and return it as a tensor. + + :param path: The path to the audio file. + + :returns: The audio as a tensor. + """ + # Read the audio: + audio, sampling_rate = torchaudio.load(str(path)) + + # Check if the audio is stereo and if so, convert it to mono (only if not per channel): + if audio.size(0) > 1 and not self._per_channel: + audio = audio.mean(dim=0, keepdim=True) + + # Resample the audio if needed: + if sampling_rate != self._sampling_rate: + transform = torchaudio.transforms.Resample( + orig_freq=sampling_rate, new_freq=self._sampling_rate + ) + audio = transform(audio) + + # Return the audio (squeeze if not per channel): + return audio if self._per_channel else audio.squeeze(0) + + +#: The value to send into multiprocessing queues to stop the process: +_MULTIPROCESSING_STOP_MARK = "STOP" + + +def _multiprocessing_complete_tasks( + vad_init_kwargs: dict, tasks_queue: Queue, results_queue: Queue +): + """ + Complete the tasks in the given queue and put the results in the given results queue. The function will stop when + the given tasks queue will receive the stop mark. It is aimed to be used with multiprocessing as a process. + + :param vad_init_kwargs: The VAD initialization kwargs. + :param tasks_queue: A queue to get the tasks from. + :param results_queue: A queue to put the results in. + """ + # Initialize and load the VAD: + vad = VoiceActivityDetector(**vad_init_kwargs) + vad.load(force_reload=False) + + # Start listening to the tasks queue: + while True: + # Get the task: + task: Tuple[str, dict] = tasks_queue.get() + if task == _MULTIPROCESSING_STOP_MARK: + break + try: + # Create the task: + task = TaskCreator.from_tuple(task_tuple=task) + # Run the file through the VAD: + speech_timestamps = vad.detect_voice(audio_file=task.audio_file) + # Complete the task: + task.do_task(speech_timestamps=speech_timestamps) + # Build the result: + result = (False, task.get_result()) + except Exception as exception: + # Build the error: + result = (True, (task.audio_file.name, str(exception))) + # Collect the result / error: + results_queue.put(result) + + # Mark the end of the tasks: + results_queue.put(_MULTIPROCESSING_STOP_MARK) + + +# Get the global logger: +try: + import mlrun + + _LOGGER = mlrun.get_or_create_ctx("silero_vad").logger +except ModuleNotFoundError: + _LOGGER = logging.getLogger() + + +def detect_voice( + # Input kwargs: + data_path: Union[str, Path, List[Union[str, Path]]], + # Model loading kwargs: + use_onnx: bool = True, + force_onnx_cpu: bool = True, + # Detection kwargs: + threshold: float = 0.5, + sampling_rate: int = 16_000, + min_speech_duration_ms: int = 250, + max_speech_duration_s: float = float("inf"), + min_silence_duration_ms: int = 100, + window_size_samples: int = 512, + speech_pad_ms: int = 30, + return_seconds: bool = False, + per_channel: bool = False, + # Other kwargs: + use_multiprocessing: int = 0, + verbose: bool = False, +): + """ + Perform voice activity detection on given audio files using the silero VAD model - + https://github.com/snakers4/silero-vad. The end result is a dictionary with the file names as keys and their + VAD timestamps dictionaries as value. + + For example:: + + { + "file_1.wav": [ + {"start": 0, "end": 16000}, + {"start": 16000, "end": 32000}, + {"start": 32000, "end": 48000}, + ... + ], + "file_2.wav": [ + {"start": 0, "end": 16000}, + {"start": 16000, "end": 32000}, + {"start": 32000, "end": 48000}, + ... + ], + ... + } + + + :param data_path: The path to the audio files to diarize. Can be a path to a single file, a path to a + directory or a list of paths to files. + :param use_onnx: Whether to use ONNX for inference. Default is True. + :param force_onnx_cpu: Whether to force ONNX to use CPU for inference. Default is True. + :param threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, + probabilities ABOVE this value are considered as SPEECH. It is better to tune + this parameter for each dataset separately, but "lazy" 0.5 is pretty good for + most datasets. + :param sampling_rate: Currently, silero VAD models support 8000 and 16000 sample rates. + :param min_speech_duration_ms: Final speech chunks shorter min_speech_duration_ms are thrown out. + :param max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer than + `max_speech_duration_s` will be split at the timestamp of the last silence that + lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will + be split aggressively just before max_speech_duration_s. + :param min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms before separating + it. + :param window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model. + + WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 + sample rate and 256, 512, 768 samples for 8000 sample rate. Values other than + these may affect model performance! + :param speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side. + :param return_seconds: Whether return timestamps in seconds. False means to return timestamps in samples + (default - False). + :param per_channel: Whether to return timestamps per channel (default - False). This will run VAD on + each channel separately and return a list of timestamps per channel. + :param use_multiprocessing: The number of workers to use for multiprocessing. If 0, no multiprocessing will + be used. Default is 0. + :param verbose: Verbosity. + """ + global _LOGGER + + # Get the input audio files to transcribe: + if verbose: + _LOGGER.info("Collecting audio files.") + audio_files = _get_audio_files(data_path=data_path) + if verbose: + _LOGGER.info(f"Collected {len(audio_files)} audio files.") + + # Initialize the transcription pipeline: + vad_init_kwargs = { + "use_onnx": use_onnx, + "force_onnx_cpu": force_onnx_cpu, + "threshold": threshold, + "sampling_rate": sampling_rate, + "min_speech_duration_ms": min_speech_duration_ms, + "max_speech_duration_s": max_speech_duration_s, + "min_silence_duration_ms": min_silence_duration_ms, + "window_size_samples": window_size_samples, + "speech_pad_ms": speech_pad_ms, + "return_seconds": return_seconds, + "per_channel": per_channel, + } + + # Create the task creator: + task_creator = TaskCreator(task_type=BaseTask) + + # Run the transcription: + if use_multiprocessing: + results = _parallel_run( + n_workers=use_multiprocessing, + audio_files=audio_files, + description="Detecting voice", + vad_init_kwargs=vad_init_kwargs, + task_creator=task_creator, + verbose=verbose, + ) + else: + results = _run( + audio_files=audio_files, + description="Detecting voice", + vad_init_kwargs=vad_init_kwargs, + task_creator=task_creator, + verbose=verbose, + ) + + # Process the results: + return _process_results(results=results, verbose=verbose) + + +def diarize( + # Input / Output kwargs: + data_path: Union[str, Path, List[Union[str, Path]]], + # Model loading kwargs: + use_onnx: bool = True, + force_onnx_cpu: bool = True, + # Detection kwargs: + threshold: float = 0.5, + sampling_rate: int = 16_000, + min_speech_duration_ms: int = 250, + max_speech_duration_s: float = float("inf"), + min_silence_duration_ms: int = 100, + window_size_samples: int = 512, + speech_pad_ms: int = 30, + # Diarization kwargs: + speaker_labels: List[str] = None, + # Other kwargs: + use_multiprocessing: int = 0, + verbose: bool = False, +): + """ + Perform speech diarization on given audio files using the silero VAD model - https://github.com/snakers4/silero-vad. + The speech diarization is performed per channel so that each channel in the audio belong to a different speaker. The + end result is a dictionary with the file names as keys and their diarization as value. A diarization is a list + of tuples: (start, end, speaker_label). + + For example:: + + { + "file_1.wav": [ + (0.0, 1.0, "speaker_0"), + (1.0, 2.0, "speaker_1"), + (2.0, 3.0, "speaker_0"), + ... + ], + "file_2.wav": [ + (0.0, 1.0, "speaker_0"), + (1.0, 2.0, "speaker_1"), + (2.0, 3.0, "speaker_0"), + ... + ], + ... + } + + + :param data_path: The path to the audio files to diarize. Can be a path to a single file, a path to a + directory or a list of paths to files. + :param use_onnx: Whether to use ONNX for inference. Default is True. + :param force_onnx_cpu: Whether to force ONNX to use CPU for inference. Default is True. + :param threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, + probabilities ABOVE this value are considered as SPEECH. It is better to tune + this parameter for each dataset separately, but "lazy" 0.5 is pretty good for + most datasets. + :param sampling_rate: Currently, silero VAD models support 8000 and 16000 sample rates. + :param min_speech_duration_ms: Final speech chunks shorter min_speech_duration_ms are thrown out. + :param max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer than + `max_speech_duration_s` will be split at the timestamp of the last silence that + lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will + be split aggressively just before max_speech_duration_s. + :param min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms before separating + it. + :param window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model. + + WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 + sample rate and 256, 512, 768 samples for 8000 sample rate. Values other than + these may affect model performance! + :param speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side. + :param speaker_labels: The speaker labels to use for the diarization. If not given, the speakers will be + named "speaker_0", "speaker_1", etc. + :param use_multiprocessing: The number of workers to use for multiprocessing. If 0, no multiprocessing will + be used. Default is 0. + :param verbose: Verbosity. + """ + global _LOGGER + + # Get the input audio files to transcribe: + if verbose: + _LOGGER.info("Collecting audio files.") + audio_files = _get_audio_files(data_path=data_path) + if verbose: + _LOGGER.info(f"Collected {len(audio_files)} audio files.") + + # Initialize the transcription pipeline: + vad_init_kwargs = { + "use_onnx": use_onnx, + "force_onnx_cpu": force_onnx_cpu, + "threshold": threshold, + "sampling_rate": sampling_rate, + "min_speech_duration_ms": min_speech_duration_ms, + "max_speech_duration_s": max_speech_duration_s, + "min_silence_duration_ms": min_silence_duration_ms, + "window_size_samples": window_size_samples, + "speech_pad_ms": speech_pad_ms, + "return_seconds": True, + "per_channel": True, + } + + # Create the task creator: + task_creator = TaskCreator( + task_type=SpeechDiarizationTask, task_kwargs={"speaker_labels": speaker_labels} + ) + + # Run the transcription: + if use_multiprocessing: + results = _parallel_run( + n_workers=use_multiprocessing, + audio_files=audio_files, + description="Diarizing", + vad_init_kwargs=vad_init_kwargs, + task_creator=task_creator, + verbose=verbose, + ) + else: + results = _run( + audio_files=audio_files, + description="Diarizing", + vad_init_kwargs=vad_init_kwargs, + task_creator=task_creator, + verbose=verbose, + ) + + # Process the results: + return _process_results(results=results, verbose=verbose) + + +def _get_audio_files( + data_path: Union[Path, str, list], +) -> List[Path]: + """ + Get the audio files from the data path. If a path to a directory is given, all files in the directory will be + collected. + + :param data_path: The data path to collect the audio files from. + + :returns: The audio files list. + """ + # Check if given a list of paths: + if isinstance(data_path, list): + audio_files = [] + for path in data_path: + audio_files.extend(_get_audio_files(data_path=path)) + return audio_files + + # Check if given a single string path to cast it to a `pathlib.Path`: + if isinstance(data_path, str): + data_path = Path(data_path).absolute() + + # Check if the path is of a directory or a file: + if data_path.is_dir(): + # Get all files inside the directory: + audio_files = list(data_path.glob("*.*")) + elif data_path.is_file(): + audio_files = [data_path] + else: + raise ValueError( + f"Unrecognized data path. The parameter `data_path` must be a valid path to either a directory path or a " + f"file. Given: {str(data_path)} " + ) + + return audio_files + + +def _run( + audio_files: List[Path], + description: str, + vad_init_kwargs: dict, + task_creator: TaskCreator, + verbose: bool, +) -> List[Tuple[bool, Tuple[str, list]]]: + """ + Load a VAD and use it to complete the tasks that will be created on the provided files using the given task creator. + + :param audio_files: The audio files to use. + :param description: The description to use for the progress bar. + :param vad_init_kwargs: The VAD initialization keyword arguments. + :param task_creator: The task creator to use to create the tasks. + :param verbose: Verbosity. + + :returns: The collected results. + """ + # Load the VAD: + vad = VoiceActivityDetector(**vad_init_kwargs) + if verbose: + _LOGGER.info(f"Loading the VAD model.") + vad.load() + if verbose: + _LOGGER.info("VAD model loaded.") + + # Run the VAD on the audio files and collect the results: + results = [] + for audio_file in tqdm( + audio_files, + desc=description, + unit="file", + total=len(audio_files), + disable=not verbose, + ): + try: + # Create the task: + task = task_creator.create_task(audio_file=audio_file) + # Run the file through the VAD: + speech_timestamps = vad.detect_voice(audio_file=audio_file) + # Complete the task: + task.do_task(speech_timestamps=speech_timestamps) + # Collect the result: + results.append((False, task.get_result())) + except Exception as exception: + # Collect the error: + results.append((True, (audio_file.name, str(exception)))) + + return results + + +def _parallel_run( + n_workers: int, + audio_files: List[Path], + description: str, + vad_init_kwargs: dict, + task_creator: TaskCreator, + verbose: bool, +) -> List[Tuple[bool, Tuple[str, list]]]: + """ + Run multiple VAD workers with multiprocessing to complete the tasks that will be created on the provided files using + the given task creator. + + :param n_workers: The number of workers to use. + :param audio_files: The audio files to use. + :param description: The description to use for the progress bar. + :param vad_init_kwargs: The VAD initialization keyword arguments. + :param task_creator: The task creator to use to create the tasks. + :param verbose: Verbosity. + + :returns: The collected results. + """ + # Load the VAD (download once, and it will be loaded then per process later on): + if verbose: + _LOGGER.info(f"Loading the VAD model.") + vad = VoiceActivityDetector(**vad_init_kwargs) + vad.load() + if verbose: + _LOGGER.info("VAD model loaded.") + + # Check the number of workers: + if n_workers > len(audio_files): + _LOGGER.warning( + f"The number of workers ({n_workers}) is larger than the number of audio files ({len(audio_files)}). " + f"Setting the number of workers to {len(audio_files)}." + ) + n_workers = len(audio_files) + + # Initialize the multiprocessing queues: + tasks_queue = Queue() + results_queue = Queue() + + # Initialize the multiprocessing processes: + task_completion_processes = [ + Process( + target=_multiprocessing_complete_tasks, + kwargs={ + "vad_init_kwargs": vad_init_kwargs, + "tasks_queue": tasks_queue, + "results_queue": results_queue, + }, + ) + for _ in range(n_workers) + ] + + # Start the multiprocessing processes: + for p in task_completion_processes: + p.start() + + # Put the tasks in the queue: + for audio_file in audio_files: + tasks_queue.put(task_creator.create_task(audio_file=audio_file).to_tuple()) + + # Put the stop marks in the queue: + for _ in range(n_workers): + tasks_queue.put(_MULTIPROCESSING_STOP_MARK) + + # Collect the results: + results = [] + stop_marks_counter = 0 + with tqdm( + desc=description, + unit="file", + total=len(audio_files), + disable=not verbose, + ) as progressbar: + while True: + # Get a result from the queue: + result: Tuple[bool, Tuple[str, list]] = results_queue.get() + if result == _MULTIPROCESSING_STOP_MARK: + stop_marks_counter += 1 + if stop_marks_counter == n_workers: + break + else: + # Collect the result: + results.append(result) + progressbar.update(1) + + # Wait for the processes to finish: + for p in task_completion_processes: + p.join() + + return results + + +def _process_results( + results: List[Tuple[bool, Tuple[str, list]]], verbose: bool +) -> Tuple[dict, dict]: + """ + Process the results of the tasks. + + :param results: The results to process. + :param verbose: Verbosity. + + :returns: The processed results as a tuple of successes and errors. + """ + if verbose: + _LOGGER.info("Summarizing the results.") + successes = {} + errors = {} + for is_error, result in results: + if is_error: + errors[result[0]] = result[1] + else: + successes[result[0]] = result[1] + if verbose: + _LOGGER.info(f"Done ({len(successes)}/{len(successes) + len(errors)})\n") + + return successes, errors diff --git a/silero_vad/test_silero_vad.py b/silero_vad/test_silero_vad.py new file mode 100644 index 000000000..d46471a57 --- /dev/null +++ b/silero_vad/test_silero_vad.py @@ -0,0 +1,44 @@ +import os +import tempfile + +import mlrun +import pytest + + +@pytest.fixture() +def setup_test(): + with tempfile.TemporaryDirectory() as artifact_path: + project = mlrun.get_or_create_project(name="default", context=artifact_path) + func = project.set_function( + func=os.path.abspath("./function.yaml"), + name="silero-vad", + image="mlrun/mlrun", + ) + yield func, artifact_path + + +def test_detect_voice(setup_test): + silero_vad_function, artifact_path = setup_test + run = silero_vad_function.run( + handler="detect_voice", + inputs={"data_path": "./assets"}, + returns=["vad_outputs: file", "errors: file"], + artifact_path=artifact_path, + local=True, + ) + assert run.outputs["vad_outputs"] + + +def test_diarize(setup_test): + silero_vad_function, artifact_path = setup_test + run = silero_vad_function.run( + handler="diarize", + inputs={"data_path": "./assets"}, + params={ + "speakers_labels": ["Agent", "Client"], + }, + returns=["speech_diarization: file", "errors: file"], + artifact_path=artifact_path, + local=True, + ) + assert run.outputs["speech_diarization"] diff --git a/speech_diarization/function.yaml b/speech_diarization/function.yaml deleted file mode 100644 index 03b0a78d5..000000000 --- a/speech_diarization/function.yaml +++ /dev/null @@ -1,143 +0,0 @@ -kind: job -metadata: - name: speech-diarization - tag: '' - hash: 2486500a2579a422fb586752aadc02a58427f60f - project: '' - labels: - author: guyl - categories: - - Utilities - - Machine Learning -spec: - command: '' - args: [] - image: mlrun/mlrun - build: - functionSourceCode: # Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import heapq
import logging
import operator
import os
import pathlib
from functools import reduce, wraps
from typing import Any, Dict, List, Tuple, Union

import pandas as pd
import pyannote.audio
import pyannote.core
import torch
import torchaudio
from tqdm import tqdm

# Get the global logger:
_LOGGER = logging.getLogger()


def _check_mlrun_and_open_mpi() -> Tuple["mlrun.MLClientCtx", "mpi4py.MPI.Intracomm"]:
    is_mpi = False
    try:
        import mlrun

        context = mlrun.get_or_create_ctx(name="mlrun")
        is_mpi = context.labels.get("kind", "job") == "mpijob"

        if is_mpi:
            try:
                from mpi4py import MPI

                return context, MPI.COMM_WORLD
            except ModuleNotFoundError as mpi4py_not_found:
                context.logger.error(
                    "To distribute the function using MLRun's 'mpijob' you need to have `mpi4py` package in your "
                    "interpreter. Please run `pip install mpi4py` and make sure you have open-mpi."
                )
                raise mpi4py_not_found
    except ModuleNotFoundError as module_not_found:
        if is_mpi:
            raise module_not_found
    return None, None


def open_mpi_handler(
    worker_inputs: List[str], root_worker_inputs: Dict[str, Any] = None
):
    global _LOGGER

    # Check for MLRun and OpenMPI availability:
    context, comm = _check_mlrun_and_open_mpi()

    # Check if MLRun is available, set the global logger to MLRun's:
    if context:
        _LOGGER = context.logger

    def decorator(handler):
        if comm is None or comm.Get_size() == 1:
            return handler

        @wraps(handler)
        def wrapper(**kwargs):
            # Get the open mpi environment properties:
            size = comm.Get_size()
            rank = comm.Get_rank()

            # Give the correct chunk of the workers inputs:
            for worker_input in worker_inputs:
                input_argument = kwargs[worker_input]
                if input_argument is None:
                    continue
                if isinstance(input_argument, str):
                    input_argument = _get_audio_files(
                        data_path=pathlib.Path(input_argument).absolute()
                    )
                if len(input_argument) < size:
                    raise ValueError(
                        f"Cannot split the input '{worker_input}' of length {len(input_argument)} to {size} workers. "
                        f"Please reduce the amount of workers for this input."
                    )
                even_chunk_size = len(input_argument) // size
                chunk_start = rank * even_chunk_size
                chunk_end = (
                    (rank + 1) * even_chunk_size
                    if rank + 1 < size
                    else len(input_argument)
                )
                context.logger.info(
                    f"Rank #{rank}: Processing input chunk of '{worker_input}' "
                    f"from index {chunk_start} to {chunk_end}."
                )
                if isinstance(input_argument, list):
                    input_argument = input_argument[chunk_start:chunk_end]
                elif isinstance(input_argument, pd.DataFrame):
                    input_argument = input_argument.iloc[chunk_start:chunk_end:, :]
                kwargs[worker_input] = input_argument

            # Set the root worker only arguments:
            if rank == 0 and root_worker_inputs:
                kwargs.update(root_worker_inputs)

            # Run the worker:
            output = handler(**kwargs)

            # Send the output to the root rank (rank #0):
            output = comm.gather(output, root=0)
            if rank == 0:
                # Join the outputs:
                context.logger.info("Collecting data from workers to root worker.")
                diarization_dictionary = reduce(
                    operator.ior, [dia for dia, _ in output], {}
                )
                errors_dictionary = reduce(operator.ior, [err for _, err in output], {})
                return diarization_dictionary, errors_dictionary
            return None

        return wrapper

    return decorator


@open_mpi_handler(worker_inputs=["data_path"], root_worker_inputs={"verbose": True})
def diarize(
    data_path: Union[str, List[str]],
    model_name: str = "pyannote/speaker-diarization-3.0",
    access_token: str = None,
    device: str = None,
    speakers_labels: List[str] = None,
    speaker_prefix: str = "speaker_",
    separate_by_channels: bool = False,
    minimum_speakers: int = None,
    maximum_speakers: int = None,
    verbose: bool = False,
) -> Tuple[Dict[str, List[Tuple[float, float, str]]], Dict[str, str]]:
    """
    Perform speech diarization on given audio files using pyannote-audio (https://github.com/pyannote/pyannote-audio).
    The end result is a dictionary with the file names as keys and their diarization as value. A diarization is a list
    of tuples: (start, end, speaker_label).

    To use the `pyannote.audio` models you must pass a Huggingface token and get access to the required models. The
    token can be passed in one of the following options:

    * Use the parameter `access_token`.
    * Set an environment variable named "HUGGING_FACE_HUB_TOKEN".
    * If using MLRun, you can pass it as a secret named "HUGGING_FACE_HUB_TOKEN".

    To get access to the models on Huggingface, visit their page. For example, to use the default diarization model set
    in this function ("pyannote/speaker-diarization-3.0"), you need access for these two models:

    * https://huggingface.co/pyannote/segmentation-3.0
    * https://huggingface.co/pyannote/speaker-diarization-3.0

    Note: To control the recognized speakers in the diarization output you can choose one of the following methods:

    * For a known speakers amount, you may set speaker labels via the `speakers_labels` parameter that will be used in
      the order of speaking in the audio (first person speaking be the first label in the list). In addition, you can do
      diarization per channel (setting the parameter `separate_by_channels` to True). Each label will be assigned to a
      specific channel by order (first label to channel 0, second label to channel 1 and so on). Notice, this will
      increase runtime.
    * For unknown speakers amount, you can set the `speaker_prefix` parameter to add a prefix for each speaker number.
      You can also help the diarization by setting the speakers range via the `speakers_amount_range` parameter.

    :param data_path:            A directory of the audio files, a single file or a list of files to transcribe.
    :param model_name:           One of the official diarization model names (referred as diarization pipelines) of
                                 `pyannote.audio` Huggingface page. Default: "pyannote/speaker-diarization-3.0".
    :param access_token:         An access token to pass for using the `pyannote.audio` models. If not provided, it
                                 will be looking for the environment variable "HUGGING_FACE_HUB_TOKEN". If MLRun is
                                 available, it will look for a secret "HUGGING_FACE_HUB_TOKEN".
    :param device:               Device to load the model. Can be one of {"cuda", "cpu"}. Default will prefer "cuda" if
                                 available.
    :param speakers_labels:      Labels to use for the recognized speakers. Default: numeric labels (0, 1, ...).
    :param separate_by_channels: If each speaker is speaking in a separate channel, you can diarize each channel and
                                 combine the result into a single diarization. Each label set in the `speakers_labels`
                                 parameter will be assigned to a specific channel by order.
    :param speaker_prefix:       A prefix to add for the speakers labels. This parameter is ignored if
                                 `speakers_labels` is not None. Default: "speaker".
    :param minimum_speakers:     Set the minimum expected amount of speakers to be in the audio files. This parameter is
                                 ignored if `speakers_labels` is not None.
    :param maximum_speakers:     Set the maximum expected amount of speakers to be in the audio files. This parameter is
                                 ignored if `speakers_labels` is not None.
    :param verbose:              Whether to present logs of a progress bar and errors. Default: True.

    :returns: A tuple of:

              * Speech diarization dictionary.
              * A dictionary of errored files that were not transcribed.
    """
    global _LOGGER

    # Get the input audio files to diarize:
    if isinstance(data_path, str):
        data_path = pathlib.Path(data_path).absolute()
        audio_files = _get_audio_files(data_path=data_path)
    else:  # Should be a list of files.
        audio_files = data_path

    # Get the Huggingface access token:
    access_token = _get_access_token(parameter=access_token)
    if access_token is None:
        raise ValueError(
            "A Huggingface access token must be provided to use `pyannote.audio` models. Access token can be passed "
            "via one of the following options:\n"
            "* Use the parameter `access_token`.\n"
            "* Set an environment variable named 'HUGGING_FACE_HUB_TOKEN'.\n"
            "* If using MLRun, you can pass it as a secret named 'HUGGING_FACE_HUB_TOKEN'."
        )

    # Load the diarization pipeline:
    pipeline = pyannote.audio.Pipeline.from_pretrained(
        checkpoint_path=model_name, use_auth_token=access_token
    )

    # Set the device:
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    if device != "cpu":
        pipeline.to(torch.device(device))

    # Prepare the successes dataframe and errors dictionary to be returned:
    diarizations = {}
    errors = {}

    # Prepare the diarization keyword arguments:
    diarize_kwargs = {}
    if speakers_labels:
        diarize_kwargs["num_speakers"] = len(speakers_labels)
    else:
        if minimum_speakers:
            diarize_kwargs["min_speakers"] = minimum_speakers
        if maximum_speakers:
            diarize_kwargs["max_speakers"] = maximum_speakers

    # Go over the audio files and diarize:
    for audio_file in tqdm(
        audio_files, desc="Diarizing", unit="file", disable=not verbose
    ):
        try:
            # Load audio file:
            audio, sample_rate = torchaudio.load(uri=audio_file, channels_first=True)
            # Get the diarization (if provided):
            diarizations[audio_file.name] = _diarize(
                audio=audio,
                sample_rate=sample_rate,
                pipeline=pipeline,
                speakers_labels=speakers_labels,
                separate_by_channels=separate_by_channels,
                speaker_prefix=speaker_prefix,
                diarize_kwargs=diarize_kwargs,
            )
        except Exception as exception:
            # Note the exception as error in the dictionary:
            if verbose:
                _LOGGER.warning(f"Error in file: '{audio_file.name}'")
            errors[str(audio_file.name)] = str(exception)
            continue

    # Print the head of the produced dataframe and return:
    if verbose:
        _LOGGER.info(f"Done ({len(diarizations)}/{len(audio_files)})\n")
    return diarizations, errors


def _get_audio_files(
    data_path: pathlib.Path,
) -> List[pathlib.Path]:
    # Check if the path is of a directory or a file:
    if data_path.is_dir():
        # Get all files inside the directory:
        audio_files = list(data_path.glob("*.*"))
    elif data_path.is_file():
        audio_files = [data_path]
    else:
        raise ValueError(
            f"Unrecognized data path. The parameter `data_path` must be either a directory path or a file path. "
            f"Given: {str(data_path)} "
        )

    return audio_files


def _get_access_token(parameter: str) -> str:
    # If given as a parameter, return it:
    if parameter:
        return parameter

    # Otherwise, look at the environment variable:
    environment_variable = os.environ.get("HUGGING_FACE_HUB_TOKEN")
    if environment_variable:
        return environment_variable

    # Lastly, try look in the set secrets in MLRun:
    secret = None
    try:
        import mlrun

        context = mlrun.get_or_create_ctx(name="mlrun")
        secret = context.get_secret(key="HUGGING_FACE_HUB_TOKEN")
    except ModuleNotFoundError:
        pass

    return secret


def _diarize(
    audio: torch.Tensor,
    sample_rate: int,
    pipeline: pyannote.audio.Pipeline,
    speakers_labels: List[str],
    separate_by_channels: bool,
    speaker_prefix: str,
    diarize_kwargs: dict,
) -> List[Tuple[float, float, str]]:
    # If there is no need for separation by channels, we diarize and return:
    if not separate_by_channels:
        # Diarize:
        diarization: pyannote.core.Annotation = pipeline(
            file={"waveform": audio, "sample_rate": sample_rate}, **diarize_kwargs
        )
        # Verify speakers labels (should not fail here as we set `num_speakers=len(speakers_labels)` when inferring
        # through the pipeline):
        if speakers_labels:
            given_speakers = len(speakers_labels)
            found_speakers = len(set(diarization.labels()))
            if given_speakers < found_speakers:
                raise ValueError(
                    f"Not enough `speakers_labels` were given. Got {given_speakers} labels but the diarization "
                    f"recognized {found_speakers} speakers."
                )
        # Return as a diarization list - a sorted list of tuples of start time, end time and a label (the default label
        # returned is "SPEAKER_i" so we take only the index out of it):
        return [
            (
                segment.start,
                segment.end,
                speakers_labels[int(label.split("_")[1])]
                if speakers_labels
                else f"{speaker_prefix}{int(label.split('_')[1])}",
            )
            for segment, track, label in diarization.itertracks(yield_label=True)
        ]

    # Separate to channels and diarize (we expect only one speaker per channel):
    channel_diarizations = [
        _diarize(
            audio=audio[channel].unsqueeze(
                0
            ),  # Take channel and add a channel dimension to it.
            sample_rate=sample_rate,
            pipeline=pipeline,
            speakers_labels=[
                speakers_labels[channel]
            ],  # Take the channel's label only.
            separate_by_channels=False,
            speaker_prefix=speaker_prefix,
            diarize_kwargs={"num_speakers": 1},  # Set to one speaker.
        )
        for channel in range(audio.shape[0])
    ]

    # Merge the channel diarizations into a single sorted list:
    return list(heapq.merge(*channel_diarizations))
 - commands: [] - code_origin: '' - origin_filename: '' - requirements: [] - entry_points: - open_mpi_handler: - name: open_mpi_handler - doc: '' - parameters: - - name: worker_inputs - type: List[str] - default: '' - - name: root_worker_inputs - type: Dict[str, Any] - default: null - outputs: - - default: '' - lineno: 59 - decorator: - name: decorator - doc: '' - parameters: - - name: handler - default: '' - outputs: - - default: '' - lineno: 71 - wrapper: - name: wrapper - doc: '' - parameters: [] - outputs: - - default: '' - lineno: 76 - diarize: - name: diarize - doc: "Perform speech diarization on given audio files using pyannote-audio (https://github.com/pyannote/pyannote-audio).\n\ - The end result is a dictionary with the file names as keys and their diarization\ - \ as value. A diarization is a list\nof tuples: (start, end, speaker_label).\n\ - \nTo use the `pyannote.audio` models you must pass a Huggingface token and\ - \ get access to the required models. The\ntoken can be passed in one of the\ - \ following options:\n\n* Use the parameter `access_token`.\n* Set an environment\ - \ variable named \"HUGGING_FACE_HUB_TOKEN\".\n* If using MLRun, you can pass\ - \ it as a secret named \"HUGGING_FACE_HUB_TOKEN\".\n\nTo get access to the\ - \ models on Huggingface, visit their page. For example, to use the default\ - \ diarization model set\nin this function (\"pyannote/speaker-diarization-3.0\"\ - ), you need access for these two models:\n\n* https://huggingface.co/pyannote/segmentation-3.0\n\ - * https://huggingface.co/pyannote/speaker-diarization-3.0\n\nNote: To control\ - \ the recognized speakers in the diarization output you can choose one of\ - \ the following methods:\n\n* For a known speakers amount, you may set speaker\ - \ labels via the `speakers_labels` parameter that will be used in\n the order\ - \ of speaking in the audio (first person speaking be the first label in the\ - \ list). In addition, you can do\n diarization per channel (setting the parameter\ - \ `separate_by_channels` to True). Each label will be assigned to a\n specific\ - \ channel by order (first label to channel 0, second label to channel 1 and\ - \ so on). Notice, this will\n increase runtime.\n* For unknown speakers amount,\ - \ you can set the `speaker_prefix` parameter to add a prefix for each speaker\ - \ number.\n You can also help the diarization by setting the speakers range\ - \ via the `speakers_amount_range` parameter." - parameters: - - name: data_path - type: Union[str, List[str]] - doc: A directory of the audio files, a single file or a list of files to transcribe. - default: '' - - name: model_name - type: str - doc: 'One of the official diarization model names (referred as diarization - pipelines) of `pyannote.audio` Huggingface page. Default: "pyannote/speaker-diarization-3.0".' - default: pyannote/speaker-diarization-3.0 - - name: access_token - type: str - doc: An access token to pass for using the `pyannote.audio` models. If not - provided, it will be looking for the environment variable "HUGGING_FACE_HUB_TOKEN". - If MLRun is available, it will look for a secret "HUGGING_FACE_HUB_TOKEN". - default: null - - name: device - type: str - doc: Device to load the model. Can be one of {"cuda", "cpu"}. Default will - prefer "cuda" if available. - default: null - - name: speakers_labels - type: List[str] - doc: 'Labels to use for the recognized speakers. Default: numeric labels (0, - 1, ...).' - default: null - - name: speaker_prefix - type: str - doc: 'A prefix to add for the speakers labels. This parameter is ignored if - `speakers_labels` is not None. Default: "speaker".' - default: speaker_ - - name: separate_by_channels - type: bool - doc: If each speaker is speaking in a separate channel, you can diarize each - channel and combine the result into a single diarization. Each label set - in the `speakers_labels` parameter will be assigned to a specific channel - by order. - default: false - - name: minimum_speakers - type: int - doc: Set the minimum expected amount of speakers to be in the audio files. - This parameter is ignored if `speakers_labels` is not None. - default: null - - name: maximum_speakers - type: int - doc: Set the maximum expected amount of speakers to be in the audio files. - This parameter is ignored if `speakers_labels` is not None. - default: null - - name: verbose - type: bool - doc: 'Whether to present logs of a progress bar and errors. Default: True.' - default: false - outputs: - - default: '' - doc: 'A tuple of:' - lineno: 137 - description: speech diarization of audio files - default_handler: diarize - disable_auto_mount: false - clone_target_dir: '' - env: [] - priority_class_name: '' - preemption_mode: prevent - affinity: null - tolerations: null - security_context: {} -verbose: false diff --git a/structured_data_generator/function.yaml b/structured_data_generator/function.yaml index f6c1ea5e0..82f48295e 100644 --- a/structured_data_generator/function.yaml +++ b/structured_data_generator/function.yaml @@ -2,7 +2,7 @@ kind: job metadata: name: structured-data-generator tag: '' - hash: 775c1a59adea52f5a1a4d26c96925c88474015f3 + hash: aa811f5c583d081b71d4da97088837546e29c4a1 project: '' labels: author: zeevr @@ -16,7 +16,7 @@ spec: args: [] image: '' build: - functionSourceCode: IyBDb3B5cmlnaHQgMjAyMyBJZ3VhemlvCiMKIyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKIyB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuCiMgWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0CiMKIyAgIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMAojCiMgVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZQojIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuICJBUyBJUyIgQkFTSVMsCiMgV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuCiMgU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZAojIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLgppbXBvcnQgYXN0CmltcG9ydCBvcwoKaW1wb3J0IHRxZG0KZnJvbSBsYW5nY2hhaW4uY2hhdF9tb2RlbHMgaW1wb3J0IENoYXRPcGVuQUkKCgpkZWYgX3NldF9vcGVuYWlfc2VjcmV0cygpIC0+IGJvb2w6CiAgICBrZXkgPSAiT1BFTkFJX0FQSV9LRVkiCiAgICBiYXNlID0gIk9QRU5BSV9BUElfQkFTRSIKICAgICMgQ2hlY2sgaWYgdGhlIGtleSBpcyBhbHJlYWR5IGluIHRoZSBlbnZpcm9ubWVudCB2YXJpYWJsZXM6CiAgICBpZiBrZXkgaW4gb3MuZW52aXJvbiBhbmQgYmFzZSBpbiBvcy5lbnZpcm9uOgogICAgICAgIHJldHVybiBUcnVlCiAgICAjIENoZWNrIGlmIG1scnVuIGlzIGluc3RhbGxlZDoKICAgIHRyeToKICAgICAgICBpbXBvcnQgbWxydW4KICAgIGV4Y2VwdCBNb2R1bGVOb3RGb3VuZEVycm9yOgogICAgICAgIHJhaXNlIEVudmlyb25tZW50RXJyb3IoCiAgICAgICAgICAgIGYiT25lIG9yIG1vcmUgb2YgdGhlIE9wZW5BSSByZXF1aXJlZCBlbnZpcm9ubWVudCB2YXJpYWJsZXMgKCd7a2V5fScsICd7YmFzZX0nKSBhcmUgbWlzc2luZy4iCiAgICAgICAgICAgIGYiUGxlYXNlIHNldCB0aGVtIGFzIGVudmlyb25tZW50IHZhcmlhYmxlcyBvciBpbnN0YWxsIG1scnVuIChgcGlwIGluc3RhbGwgbWxydW5gKSIKICAgICAgICAgICAgZiJhbmQgc2V0IHRoZW0gYXMgcHJvamVjdCBzZWNyZXRzIHVzaW5nIGBwcm9qZWN5LnNldF9zZWNyZXRzYC4iCiAgICAgICAgKQoKICAgICMgQ2hlY2sgaWYgdGhlIGtleSBpcyBpbiB0aGUgc2VjcmV0czoKICAgIGNvbnRleHQgPSBtbHJ1bi5nZXRfb3JfY3JlYXRlX2NvbnRleHQobmFtZT0iY29udGV4dCIpCiAgICBvcGVuYWlfa2V5ID0gY29udGV4dC5nZXRfc2VjcmV0KGtleSwgTm9uZSkKICAgIG9wZW5haV9iYXNlID0gY29udGV4dC5nZXRfc2VjcmV0KGJhc2UsIE5vbmUpCgogICAgIyBJZiB0aGUga2V5IGlzIG5vdCBpbiB0aGUgc2VjcmV0cywgcmV0dXJuIEZhbHNlOgogICAgaWYgbm90IG9wZW5haV9rZXk6CiAgICAgICAgcmFpc2UgRW52aXJvbm1lbnRFcnJvcigKICAgICAgICAgICAgZiJDb3VsZCBub3QgZmluZCBPcGVuQUkgQVBJIGtleSBpbiB0aGUgZW52aXJvbm1lbnQgdmFyaWFibGVzIG9yIHNlY3JldHMsIgogICAgICAgICAgICBmIiBwbGVhc2Ugc2V0IGl0IGFzOiB7a2V5fS4iCiAgICAgICAgKQogICAgaWYgbm90IG9wZW5haV9iYXNlOgogICAgICAgIHJhaXNlIEVudmlyb25tZW50RXJyb3IoCiAgICAgICAgICAgIGYiQ291bGQgbm90IGZpbmQgT3BlbkFJIEFQSSBiYXNlIGluIHRoZSBlbnZpcm9ubWVudCB2YXJpYWJsZXMgb3Igc2VjcmV0cywiCiAgICAgICAgICAgIGYiIHBsZWFzZSBzZXQgaXQgYXM6IHtiYXNlfS4iCiAgICAgICAgKQogICAgIyBJZiB0aGUga2V5IGlzIGluIHRoZSBzZWNyZXRzLCBzZXQgaXQgaW4gdGhlIGVudmlyb25tZW50IHZhcmlhYmxlcyBhbmQgcmV0dXJuIFRydWU6CiAgICBvcy5lbnZpcm9uW2tleV0gPSBvcGVuYWlfa2V5CiAgICBvcy5lbnZpcm9uW2Jhc2VdID0gb3BlbmFpX2Jhc2UKICAgIHJldHVybiBUcnVlCgoKZGVmIGdlbmVyYXRlX2RhdGEoCiAgICBmaWVsZHM6IGxpc3QsCiAgICBhbW91bnQ6IGludCA9IDEwLAogICAgbW9kZWxfbmFtZTogc3RyID0gImdwdC0zLjUtdHVyYm8iLAogICAgbGFuZ3VhZ2U6IHN0ciA9ICJlbiIsCiAgICBjaHVua19zaXplOiBpbnQgPSA1MCwKKSAtPiBsaXN0OgogICAgIiIiCiAgICBTdHJ1Y3R1cmVkIGRhdGEgb2YgZWxlbWVudHMgYWNjb3JkaW5nIHRvIHRoZSBnaXZlbiBwYXJhbWV0ZXJzLgogICAgVGhlIGRhdGEgY2FuIGJlIGxhdGVyIGxvZ2dlZCBhcyBhIHN0cnVjdHVyZWQgZmlsZSB3aXRoIE1MUnVuJ3MgYHJldHVybnNgIHBhcmFtZXRlci4KCiAgICA6cGFyYW0gZmllbGRzOiBBIGxpc3Qgb2YgZmllbGRzIHRvIHJhbmRvbWx5IGdlbmVyYXRlLgogICAgOnBhcmFtIGFtb3VudDogVGhlIG51bWJlciBvZiB2YXJpYW50cyB0byBnZW5lcmF0ZS4KICAgIDpwYXJhbSBtb2RlbF9uYW1lOiBUaGUgbmFtZSBvZiB0aGUgbW9kZWwgdG8gdXNlIGZvciBjb252ZXJzYXRpb24gZ2VuZXJhdGlvbi4KICAgICAgICAgICAgICAgICAgICAgICBZb3Ugc2hvdWxkIGNob29zZSBvbmUgb2YgR1BULTQgb3IgR1BULTMuNSBmcm9tIHRoZSBsaXN0IGhlcmU6IGh0dHBzOi8vcGxhdGZvcm0ub3BlbmFpLmNvbS9kb2NzL21vZGVscy4KICAgICAgICAgICAgICAgICAgICAgICBEZWZhdWx0OiAnZ3B0LTMuNS10dXJibycuCiAgICA6cGFyYW0gbGFuZ3VhZ2U6IFRoZSBsYW5ndWFnZSB0byB1c2UgZm9yIHRoZSBnZW5lcmF0ZWQgY29udmVyc2F0aW9uIHRleHQuCiAgICA6cGFyYW0gY2h1bmtfc2l6ZTogTnVtYmVyIG9mIHNhbXBsZXMgZ2VuZXJhdGVkIGF0IGVhY2ggR1BUIHF1ZXJ5LgogICAgIiIiCiAgICBpbnN0cnVjdGlvbnMgPSAiIgogICAgZm9yIGZpZWxkIGluIGZpZWxkczoKICAgICAgICAjIFNwbGl0IHRoZSBmaWVsZCB0byBrZXkgYW5kIGluc3RydWN0aW9uOgogICAgICAgIGlmICI6IiBpbiBmaWVsZDoKICAgICAgICAgICAga2V5LCBpbnN0cnVjdGlvbiA9IGZpZWxkLnNwbGl0KCI6IiwgMSkKICAgICAgICBlbHNlOgogICAgICAgICAgICBrZXksIGluc3RydWN0aW9uID0gZmllbGQsICJubyBzcGVjaWFsIGluc3RydWN0aW9uIgogICAgICAgICMgUmVwbGFjZSBzcGFjZXMgd2l0aCB1bmRlcnNjb3JlcyBmb3IgdGhlIGtleSB0byBiZSB1c2VkIGFzIGEganNvbiBrZXk6CiAgICAgICAga2V5ID0ga2V5LnJlcGxhY2UoIiAiLCAiXyIpCiAgICAgICAgaW5zdHJ1Y3Rpb25zICs9IGYiKiB7a2V5fToge2luc3RydWN0aW9ufVxuIgoKICAgICMgQ3JlYXRlIHRoZSBwcm9tcHQgc3RydWN0dXJlOgogICAgcHJvbXB0X3N0cnVjdHVyZSA9ICgKICAgICAgICBmImdlbmVyYXRlIHRoZSBmb2xsb3dpbmcgdmFsdWVzIHthbW91bnR9IHRpbWVzIHJhbmRvbWx5LCBpbiBhbiBvcmRlciB0aGF0IGNyZWF0ZXMgYSBqc29uIHRhYmxlLlxuIgogICAgICAgIGYiVXNlIHRoZSBmb2xsb3dpbmcga2V5cyBhbmQgaW5zdHJ1Y3Rpb25zIChleGFtcGxlOiAna2V5OiBpbnN0cnVjdGlvbiBvciBubyBzcGVjaWFsIGluc3RydWN0aW9uJyk6ICIKICAgICAgICBmIntpbnN0cnVjdGlvbnN9LlxuIgogICAgICAgIGYiUGxlYXNlIGdlbmVyYXRlIHRoZSB2YWx1ZXMgaW4ge2xhbmd1YWdlfSBsYW5ndWFnZS4gXG4iCiAgICAgICAgZiJNYWtlIHN1cmUgdGhlIG5hbWVzIG9mIHRoZSBrZXlzIGFyZSB0aGUgc2FtZSBhcyB0aGUgZ2l2ZW4gZmllbGQgbmFtZS5cbiIKICAgICAgICBmIlBsZWFzZSByZXR1cm4gb25seSB0aGUganNvbiBmb3JtYXQgd2l0aG91dCBhbnkgaW50cm9kdWN0aW9uIGFuZCBlbmRpbmciCiAgICApCgogICAgIyBTZXQgdGhlIE9wZW5BSSBzZWNyZXRzOgogICAgX3NldF9vcGVuYWlfc2VjcmV0cygpCgogICAgIyBMb2FkIHRoZSBPcGVuQUkgbW9kZWwgdXNpbmcgbGFuZ2NoYWluOgogICAgbGxtID0gQ2hhdE9wZW5BSShtb2RlbD1tb2RlbF9uYW1lKQoKICAgICMgU3RhcnQgZ2VuZXJhdGluZyBkYXRhOgogICAgZGF0YSA9IFtdCiAgICBmb3IgXyBpbiB0cWRtLnRxZG0ocmFuZ2UoKGFtb3VudCAvLyBjaHVua19zaXplKSArIDEpLCBkZXNjPSJHZW5lcmF0aW5nIik6CiAgICAgICAgIyBXZSB0cnkgdG8gZ2VuZXJhdGUgdGhlIGRhdGEgMyB0aW1lcywgaWYgd2UgZmFpbCB3ZSByYWlzZSBhbiBlcnJvcjoKICAgICAgICBmb3IgdHJ5b3V0IGluIHJhbmdlKDMpOgogICAgICAgICAgICAjIElmIHRoZSBhbW91bnQgd2FudGVkIGlzIGJpZ2dlciB0aGFuIHRoZSBjaHVuayBzaXplLCB3ZSBnZW5lcmF0ZSBhIGNodW5rIG9mIGRhdGEgaW4gdGhlIHNpemUgb2YgdGhlIGNodW5rCiAgICAgICAgICAgICMgYW5kIGRlY3JlYXNlIHRoZSBhbW91bnQgYnkgdGhlIGNodW5rIHNpemUuCiAgICAgICAgICAgICMgb3RoZXJ3aXNlIHdlIGdlbmVyYXRlIGEgY2h1bmsgb2YgZGF0YSBpbiB0aGUgc2l6ZSBvZiB0aGUgYW1vdW50OgogICAgICAgICAgICBpZiBhbW91bnQgPiBjaHVua19zaXplOgogICAgICAgICAgICAgICAgY3VycmVudF9jaHVua19zaXplID0gY2h1bmtfc2l6ZQogICAgICAgICAgICAgICAgYW1vdW50IC09IGNodW5rX3NpemUKICAgICAgICAgICAgZWxzZToKICAgICAgICAgICAgICAgIGN1cnJlbnRfY2h1bmtfc2l6ZSA9IGFtb3VudAoKICAgICAgICAgICAgIyBDcmVhdGUgdGhlIHByb21wdDoKICAgICAgICAgICAgcHJvbXB0ID0gcHJvbXB0X3N0cnVjdHVyZS5mb3JtYXQoCiAgICAgICAgICAgICAgICBhbW91bnQ9Y3VycmVudF9jaHVua19zaXplLAogICAgICAgICAgICApCgogICAgICAgICAgICAjIEdlbmVyYXRlIGEgY2h1bmsgb2YgZGF0YToKICAgICAgICAgICAgY2h1bmtfZGF0YSA9IGxsbS5wcmVkaWN0KHRleHQ9cHJvbXB0KQoKICAgICAgICAgICAgIyBWYWxpZGF0ZSB0aGUgcmVzcG9uc2UgZm9yIGNvcnJlY3QgcHl0aG9uIGBsaXN0YCBzdHJ1Y3R1cmUKICAgICAgICAgICAgY2h1bmtfZGF0YSA9IGNodW5rX2RhdGFbY2h1bmtfZGF0YS5maW5kKCJbIikgOiBjaHVua19kYXRhLnJmaW5kKCJdIikgKyAxXQogICAgICAgICAgICBpZiBjaHVua19kYXRhLmNvdW50KCJbIikgIT0gY2h1bmtfZGF0YS5jb3VudCgiXSIpOgogICAgICAgICAgICAgICAgcHJpbnQoCiAgICAgICAgICAgICAgICAgICAgIkZhaWxlZCB0byBnZXQgcHJvcGVyIGpzb24gZm9ybWF0IGZyb20gbW9kZWwsIG51bWJlciBvZiAnWycgZG9lc24ndCBtYXRjaCBudW1iZXIgb2YgJ10nLiIKICAgICAgICAgICAgICAgICkKICAgICAgICAgICAgICAgIGNvbnRpbnVlCiAgICAgICAgICAgIGNodW5rX2RhdGEgPSBhc3QubGl0ZXJhbF9ldmFsKGNodW5rX2RhdGEpCiAgICAgICAgICAgIGRhdGEgKz0gY2h1bmtfZGF0YQogICAgICAgICAgICBicmVhawogICAgICAgIGlmIHRyeW91dCA9PSAzOgogICAgICAgICAgICByYWlzZSBSdW50aW1lRXJyb3IoCiAgICAgICAgICAgICAgICBmIkNvdWxkIG5vdCBnZW5lcmF0ZSBhIHByb3BlciBqc29uIGZvcm1hdCBmb3IgdGhlIGdpdmVuIGZpZWxkcywgdXNpbmcgZ2l2ZW4gbW9kZWw6IHttb2RlbF9uYW1lfS4iCiAgICAgICAgICAgICAgICBmIiBIaW50OiBHcHQtNCB3b3JrcyBiZXN0IGZvciBtb3N0IHNjZW5hcmlvcy4iCiAgICAgICAgICAgICkKICAgIHJldHVybiBkYXRhCg== + functionSourceCode: IyBDb3B5cmlnaHQgMjAyMyBJZ3VhemlvCiMKIyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKIyB5b3UgbWF5IG5vdCB1c2UgdGhpcyBmaWxlIGV4Y2VwdCBpbiBjb21wbGlhbmNlIHdpdGggdGhlIExpY2Vuc2UuCiMgWW91IG1heSBvYnRhaW4gYSBjb3B5IG9mIHRoZSBMaWNlbnNlIGF0CiMKIyAgIGh0dHA6Ly93d3cuYXBhY2hlLm9yZy9saWNlbnNlcy9MSUNFTlNFLTIuMAojCiMgVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZQojIGRpc3RyaWJ1dGVkIHVuZGVyIHRoZSBMaWNlbnNlIGlzIGRpc3RyaWJ1dGVkIG9uIGFuICJBUyBJUyIgQkFTSVMsCiMgV0lUSE9VVCBXQVJSQU5USUVTIE9SIENPTkRJVElPTlMgT0YgQU5ZIEtJTkQsIGVpdGhlciBleHByZXNzIG9yIGltcGxpZWQuCiMgU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZAojIGxpbWl0YXRpb25zIHVuZGVyIHRoZSBMaWNlbnNlLgppbXBvcnQgYXN0CmltcG9ydCBvcwoKaW1wb3J0IHRxZG0KZnJvbSBsYW5nY2hhaW4uY2hhdF9tb2RlbHMgaW1wb3J0IENoYXRPcGVuQUkKCgpkZWYgX3NldF9vcGVuYWlfc2VjcmV0cygpIC0+IGJvb2w6CiAgICBrZXkgPSAiT1BFTkFJX0FQSV9LRVkiCiAgICBiYXNlID0gIk9QRU5BSV9BUElfQkFTRSIKICAgICMgQ2hlY2sgaWYgdGhlIGtleSBpcyBhbHJlYWR5IGluIHRoZSBlbnZpcm9ubWVudCB2YXJpYWJsZXM6CiAgICBpZiBrZXkgaW4gb3MuZW52aXJvbiBhbmQgYmFzZSBpbiBvcy5lbnZpcm9uOgogICAgICAgIHJldHVybiBUcnVlCiAgICAjIENoZWNrIGlmIG1scnVuIGlzIGluc3RhbGxlZDoKICAgIHRyeToKICAgICAgICBpbXBvcnQgbWxydW4KICAgIGV4Y2VwdCBNb2R1bGVOb3RGb3VuZEVycm9yOgogICAgICAgIHJhaXNlIEVudmlyb25tZW50RXJyb3IoCiAgICAgICAgICAgIGYiT25lIG9yIG1vcmUgb2YgdGhlIE9wZW5BSSByZXF1aXJlZCBlbnZpcm9ubWVudCB2YXJpYWJsZXMgKCd7a2V5fScsICd7YmFzZX0nKSBhcmUgbWlzc2luZy4iCiAgICAgICAgICAgIGYiUGxlYXNlIHNldCB0aGVtIGFzIGVudmlyb25tZW50IHZhcmlhYmxlcyBvciBpbnN0YWxsIG1scnVuIChgcGlwIGluc3RhbGwgbWxydW5gKSIKICAgICAgICAgICAgZiJhbmQgc2V0IHRoZW0gYXMgcHJvamVjdCBzZWNyZXRzIHVzaW5nIGBwcm9qZWN5LnNldF9zZWNyZXRzYC4iCiAgICAgICAgKQoKICAgICMgQ2hlY2sgaWYgdGhlIGtleSBpcyBpbiB0aGUgc2VjcmV0czoKICAgIGNvbnRleHQgPSBtbHJ1bi5nZXRfb3JfY3JlYXRlX2N0eChuYW1lPSJjb250ZXh0IikKICAgIG9wZW5haV9rZXkgPSBjb250ZXh0LmdldF9zZWNyZXQoa2V5KQogICAgb3BlbmFpX2Jhc2UgPSBjb250ZXh0LmdldF9zZWNyZXQoYmFzZSkKCiAgICAjIElmIHRoZSBrZXkgaXMgbm90IGluIHRoZSBzZWNyZXRzLCByZXR1cm4gRmFsc2U6CiAgICBpZiBub3Qgb3BlbmFpX2tleToKICAgICAgICByYWlzZSBFbnZpcm9ubWVudEVycm9yKAogICAgICAgICAgICBmIkNvdWxkIG5vdCBmaW5kIE9wZW5BSSBBUEkga2V5IGluIHRoZSBlbnZpcm9ubWVudCB2YXJpYWJsZXMgb3Igc2VjcmV0cywiCiAgICAgICAgICAgIGYiIHBsZWFzZSBzZXQgaXQgYXM6IHtrZXl9LiIKICAgICAgICApCiAgICBpZiBub3Qgb3BlbmFpX2Jhc2U6CiAgICAgICAgcmFpc2UgRW52aXJvbm1lbnRFcnJvcigKICAgICAgICAgICAgZiJDb3VsZCBub3QgZmluZCBPcGVuQUkgQVBJIGJhc2UgaW4gdGhlIGVudmlyb25tZW50IHZhcmlhYmxlcyBvciBzZWNyZXRzLCIKICAgICAgICAgICAgZiIgcGxlYXNlIHNldCBpdCBhczoge2Jhc2V9LiIKICAgICAgICApCiAgICAjIElmIHRoZSBrZXkgaXMgaW4gdGhlIHNlY3JldHMsIHNldCBpdCBpbiB0aGUgZW52aXJvbm1lbnQgdmFyaWFibGVzIGFuZCByZXR1cm4gVHJ1ZToKICAgIG9zLmVudmlyb25ba2V5XSA9IG9wZW5haV9rZXkKICAgIG9zLmVudmlyb25bYmFzZV0gPSBvcGVuYWlfYmFzZQogICAgcmV0dXJuIFRydWUKCgpkZWYgZ2VuZXJhdGVfZGF0YSgKICAgIGZpZWxkczogbGlzdCwKICAgIGFtb3VudDogaW50ID0gMTAsCiAgICBtb2RlbF9uYW1lOiBzdHIgPSAiZ3B0LTMuNS10dXJibyIsCiAgICBsYW5ndWFnZTogc3RyID0gImVuIiwKICAgIGNodW5rX3NpemU6IGludCA9IDUwLAopIC0+IGxpc3Q6CiAgICAiIiIKICAgIFN0cnVjdHVyZWQgZGF0YSBvZiBlbGVtZW50cyBhY2NvcmRpbmcgdG8gdGhlIGdpdmVuIHBhcmFtZXRlcnMuCiAgICBUaGUgZGF0YSBjYW4gYmUgbGF0ZXIgbG9nZ2VkIGFzIGEgc3RydWN0dXJlZCBmaWxlIHdpdGggTUxSdW4ncyBgcmV0dXJuc2AgcGFyYW1ldGVyLgoKICAgIDpwYXJhbSBmaWVsZHM6IEEgbGlzdCBvZiBmaWVsZHMgdG8gcmFuZG9tbHkgZ2VuZXJhdGUuCiAgICA6cGFyYW0gYW1vdW50OiBUaGUgbnVtYmVyIG9mIHZhcmlhbnRzIHRvIGdlbmVyYXRlLgogICAgOnBhcmFtIG1vZGVsX25hbWU6IFRoZSBuYW1lIG9mIHRoZSBtb2RlbCB0byB1c2UgZm9yIGNvbnZlcnNhdGlvbiBnZW5lcmF0aW9uLgogICAgICAgICAgICAgICAgICAgICAgIFlvdSBzaG91bGQgY2hvb3NlIG9uZSBvZiBHUFQtNCBvciBHUFQtMy41IGZyb20gdGhlIGxpc3QgaGVyZTogaHR0cHM6Ly9wbGF0Zm9ybS5vcGVuYWkuY29tL2RvY3MvbW9kZWxzLgogICAgICAgICAgICAgICAgICAgICAgIERlZmF1bHQ6ICdncHQtMy41LXR1cmJvJy4KICAgIDpwYXJhbSBsYW5ndWFnZTogVGhlIGxhbmd1YWdlIHRvIHVzZSBmb3IgdGhlIGdlbmVyYXRlZCBjb252ZXJzYXRpb24gdGV4dC4KICAgIDpwYXJhbSBjaHVua19zaXplOiBOdW1iZXIgb2Ygc2FtcGxlcyBnZW5lcmF0ZWQgYXQgZWFjaCBHUFQgcXVlcnkuCiAgICAiIiIKICAgIGluc3RydWN0aW9ucyA9ICIiCiAgICBmb3IgZmllbGQgaW4gZmllbGRzOgogICAgICAgICMgU3BsaXQgdGhlIGZpZWxkIHRvIGtleSBhbmQgaW5zdHJ1Y3Rpb246CiAgICAgICAgaWYgIjoiIGluIGZpZWxkOgogICAgICAgICAgICBrZXksIGluc3RydWN0aW9uID0gZmllbGQuc3BsaXQoIjoiLCAxKQogICAgICAgIGVsc2U6CiAgICAgICAgICAgIGtleSwgaW5zdHJ1Y3Rpb24gPSBmaWVsZCwgIm5vIHNwZWNpYWwgaW5zdHJ1Y3Rpb24iCiAgICAgICAgIyBSZXBsYWNlIHNwYWNlcyB3aXRoIHVuZGVyc2NvcmVzIGZvciB0aGUga2V5IHRvIGJlIHVzZWQgYXMgYSBqc29uIGtleToKICAgICAgICBrZXkgPSBrZXkuc3RyaXAoKS5yZXBsYWNlKCIgIiwgIl8iKQogICAgICAgIGluc3RydWN0aW9ucyArPSBmIioge2tleX06IHtpbnN0cnVjdGlvbn1cbiIKCiAgICAjIENyZWF0ZSB0aGUgcHJvbXB0IHN0cnVjdHVyZToKICAgIHByb21wdF9zdHJ1Y3R1cmUgPSAoCiAgICAgICAgZiJnZW5lcmF0ZSB0aGUgZm9sbG93aW5nIHZhbHVlcyB7YW1vdW50fSB0aW1lcyByYW5kb21seSwgaW4gYW4gb3JkZXIgdGhhdCBjcmVhdGVzIGEganNvbiB0YWJsZS5cbiIKICAgICAgICBmIlVzZSB0aGUgZm9sbG93aW5nIGtleXMgYW5kIGluc3RydWN0aW9ucyAoZXhhbXBsZTogJ2tleTogaW5zdHJ1Y3Rpb24gb3Igbm8gc3BlY2lhbCBpbnN0cnVjdGlvbicpOiAiCiAgICAgICAgZiJ7aW5zdHJ1Y3Rpb25zfS5cbiIKICAgICAgICBmIlBsZWFzZSBnZW5lcmF0ZSB0aGUgdmFsdWVzIGluIHtsYW5ndWFnZX0gbGFuZ3VhZ2UuIFxuIgogICAgICAgIGYiTWFrZSBzdXJlIHRoZSBuYW1lcyBvZiB0aGUga2V5cyBhcmUgdGhlIHNhbWUgYXMgdGhlIGdpdmVuIGZpZWxkIG5hbWUuXG4iCiAgICAgICAgZiJQbGVhc2UgcmV0dXJuIG9ubHkgdGhlIGpzb24gZm9ybWF0IHdpdGhvdXQgYW55IGludHJvZHVjdGlvbiBhbmQgZW5kaW5nIgogICAgKQoKICAgICMgU2V0IHRoZSBPcGVuQUkgc2VjcmV0czoKICAgIF9zZXRfb3BlbmFpX3NlY3JldHMoKQoKICAgICMgTG9hZCB0aGUgT3BlbkFJIG1vZGVsIHVzaW5nIGxhbmdjaGFpbjoKICAgIGxsbSA9IENoYXRPcGVuQUkobW9kZWw9bW9kZWxfbmFtZSkKCiAgICAjIFN0YXJ0IGdlbmVyYXRpbmcgZGF0YToKICAgIGRhdGEgPSBbXQogICAgZm9yIF8gaW4gdHFkbS50cWRtKHJhbmdlKChhbW91bnQgLy8gY2h1bmtfc2l6ZSkgKyAxKSwgZGVzYz0iR2VuZXJhdGluZyIpOgogICAgICAgICMgV2UgdHJ5IHRvIGdlbmVyYXRlIHRoZSBkYXRhIDMgdGltZXMsIGlmIHdlIGZhaWwgd2UgcmFpc2UgYW4gZXJyb3I6CiAgICAgICAgZm9yIHRyeW91dCBpbiByYW5nZSgzKToKICAgICAgICAgICAgIyBJZiB0aGUgYW1vdW50IHdhbnRlZCBpcyBiaWdnZXIgdGhhbiB0aGUgY2h1bmsgc2l6ZSwgd2UgZ2VuZXJhdGUgYSBjaHVuayBvZiBkYXRhIGluIHRoZSBzaXplIG9mIHRoZSBjaHVuawogICAgICAgICAgICAjIGFuZCBkZWNyZWFzZSB0aGUgYW1vdW50IGJ5IHRoZSBjaHVuayBzaXplLgogICAgICAgICAgICAjIG90aGVyd2lzZSB3ZSBnZW5lcmF0ZSBhIGNodW5rIG9mIGRhdGEgaW4gdGhlIHNpemUgb2YgdGhlIGFtb3VudDoKICAgICAgICAgICAgaWYgYW1vdW50ID4gY2h1bmtfc2l6ZToKICAgICAgICAgICAgICAgIGN1cnJlbnRfY2h1bmtfc2l6ZSA9IGNodW5rX3NpemUKICAgICAgICAgICAgICAgIGFtb3VudCAtPSBjaHVua19zaXplCiAgICAgICAgICAgIGVsc2U6CiAgICAgICAgICAgICAgICBjdXJyZW50X2NodW5rX3NpemUgPSBhbW91bnQKCiAgICAgICAgICAgICMgQ3JlYXRlIHRoZSBwcm9tcHQ6CiAgICAgICAgICAgIHByb21wdCA9IHByb21wdF9zdHJ1Y3R1cmUuZm9ybWF0KAogICAgICAgICAgICAgICAgYW1vdW50PWN1cnJlbnRfY2h1bmtfc2l6ZSwKICAgICAgICAgICAgKQoKICAgICAgICAgICAgIyBHZW5lcmF0ZSBhIGNodW5rIG9mIGRhdGE6CiAgICAgICAgICAgIGNodW5rX2RhdGEgPSBsbG0ucHJlZGljdCh0ZXh0PXByb21wdCkKCiAgICAgICAgICAgICMgVmFsaWRhdGUgdGhlIHJlc3BvbnNlIGZvciBjb3JyZWN0IHB5dGhvbiBgbGlzdGAgc3RydWN0dXJlCiAgICAgICAgICAgIGNodW5rX2RhdGEgPSBjaHVua19kYXRhW2NodW5rX2RhdGEuZmluZCgiWyIpIDogY2h1bmtfZGF0YS5yZmluZCgiXSIpICsgMV0KICAgICAgICAgICAgaWYgY2h1bmtfZGF0YS5jb3VudCgiWyIpICE9IGNodW5rX2RhdGEuY291bnQoIl0iKToKICAgICAgICAgICAgICAgIHByaW50KAogICAgICAgICAgICAgICAgICAgICJGYWlsZWQgdG8gZ2V0IHByb3BlciBqc29uIGZvcm1hdCBmcm9tIG1vZGVsLCBudW1iZXIgb2YgJ1snIGRvZXNuJ3QgbWF0Y2ggbnVtYmVyIG9mICddJy4iCiAgICAgICAgICAgICAgICApCiAgICAgICAgICAgICAgICBjb250aW51ZQogICAgICAgICAgICBjaHVua19kYXRhID0gYXN0LmxpdGVyYWxfZXZhbChjaHVua19kYXRhKQogICAgICAgICAgICBkYXRhICs9IGNodW5rX2RhdGEKICAgICAgICAgICAgYnJlYWsKICAgICAgICBpZiB0cnlvdXQgPT0gMzoKICAgICAgICAgICAgcmFpc2UgUnVudGltZUVycm9yKAogICAgICAgICAgICAgICAgZiJDb3VsZCBub3QgZ2VuZXJhdGUgYSBwcm9wZXIganNvbiBmb3JtYXQgZm9yIHRoZSBnaXZlbiBmaWVsZHMsIHVzaW5nIGdpdmVuIG1vZGVsOiB7bW9kZWxfbmFtZX0uIgogICAgICAgICAgICAgICAgZiIgSGludDogR3B0LTQgd29ya3MgYmVzdCBmb3IgbW9zdCBzY2VuYXJpb3MuIgogICAgICAgICAgICApCiAgICByZXR1cm4gZGF0YQo= base_image: mlrun/mlrun commands: [] code_origin: '' diff --git a/structured_data_generator/item.yaml b/structured_data_generator/item.yaml index b854f0834..8b3644fbd 100755 --- a/structured_data_generator/item.yaml +++ b/structured_data_generator/item.yaml @@ -26,4 +26,4 @@ spec: - langchain - tqdm url: '' -version: 1.1.0 +version: 1.3.0 diff --git a/structured_data_generator/structured_data_generator.py b/structured_data_generator/structured_data_generator.py index 2ace492c5..34fa36d49 100644 --- a/structured_data_generator/structured_data_generator.py +++ b/structured_data_generator/structured_data_generator.py @@ -35,9 +35,9 @@ def _set_openai_secrets() -> bool: ) # Check if the key is in the secrets: - context = mlrun.get_or_create_context(name="context") - openai_key = context.get_secret(key, None) - openai_base = context.get_secret(base, None) + context = mlrun.get_or_create_ctx(name="context") + openai_key = context.get_secret(key) + openai_base = context.get_secret(base) # If the key is not in the secrets, return False: if not openai_key: @@ -83,7 +83,7 @@ def generate_data( else: key, instruction = field, "no special instruction" # Replace spaces with underscores for the key to be used as a json key: - key = key.replace(" ", "_") + key = key.strip().replace(" ", "_") instructions += f"* {key}: {instruction}\n" # Create the prompt structure: diff --git a/text_to_audio_generator/function.yaml b/text_to_audio_generator/function.yaml index 25af4d575..df142d2ef 100644 --- a/text_to_audio_generator/function.yaml +++ b/text_to_audio_generator/function.yaml @@ -2,7 +2,7 @@ kind: job metadata: name: text-to-audio-generator tag: '' - hash: f36d56d620c6a69f414c9cb90e42ec012847a607 + hash: 534e34d316098dcb345860a786ea013102150e67 project: '' labels: author: yonatans @@ -14,7 +14,7 @@ spec: args: [] image: '' build: - functionSourceCode: # Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import pathlib
import random
from typing import Dict, List, Optional, Tuple, Union

import bark
import numpy as np
import pandas as pd
import torch
import torchaudio
import tqdm

# Get the global logger:
_LOGGER = logging.getLogger()


def generate_multi_speakers_audio(
    data_path: str,
    output_directory: str,
    speakers: Union[List[str], Dict[str, int]],
    available_voices: List[str],
    use_gpu: bool = True,
    use_small_models: bool = False,
    offload_cpu: bool = False,
    sample_rate: int = 16000,
    file_format: str = "wav",
    verbose: bool = True,
    bits_per_sample: Optional[int] = None,
) -> Tuple[str, pd.DataFrame, dict]:
    """

    :param data_path:           Path to the text file or directory containing the text files to generate audio from.
    :param output_directory:    Path to the directory to save the generated audio files to.
    :param speakers:            List / Dict of speakers to generate audio for.
                                If a list is given, the speakers will be assigned to channels in the order given.
                                If dictionary, the keys will be the speakers and the values will be the channels.
    :param available_voices:    List of available voices to use for the generation.
                        See here for the available voices:
                        https://suno-ai.notion.site/8b8e8749ed514b0cbf3f699013548683?v=bc67cff786b04b50b3ceb756fd05f68c
    :param use_gpu:             Whether to use the GPU for the generation.
    :param use_small_models:    Whether to use the small models for the generation.
    :param offload_cpu:         TODO: What does this do?
    :param sample_rate:         The sampling rate of the generated audio.
    :param file_format:         The format of the generated audio files.
    :param verbose:             Whether to print the progress of the generation.
    :param bits_per_sample:     Changes the bit depth for the supported formats.
                                Supported only in "wav" or "flac" formats.

    :returns:                   A tuple of:
                                - The output directory path.
                                - The generated audio files dataframe.
                                - The errors dictionary.
    """

    global _LOGGER
    _LOGGER = _get_logger()
    # Get the input text files to turn to audio:
    data_path = pathlib.Path(data_path).absolute()
    text_files = _get_text_files(data_path=data_path)

    # Load the bark models according to the given configurations:
    bark.preload_models(
        text_use_gpu=use_gpu,
        text_use_small=use_small_models,
        coarse_use_gpu=use_gpu,
        coarse_use_small=use_small_models,
        fine_use_gpu=use_gpu,
        fine_use_small=use_small_models,
        codec_use_gpu=use_gpu,
        force_reload=offload_cpu,
    )

    # Check for per channel generation:
    if isinstance(speakers, dict):
        speaker_per_channel = True
        # Sort the given speakers by channels:
        speakers = {
            speaker: channel
            for speaker, channel in sorted(speakers.items(), key=lambda item: item[1])
        }
    else:
        speaker_per_channel = False

    # Prepare the resampling module:
    resampler = torchaudio.transforms.Resample(
        orig_freq=bark.SAMPLE_RATE, new_freq=sample_rate, dtype=torch.float32
    )

    # Prepare the gap between each speaker:
    gap_between_speakers = np.zeros(int(0.5 * bark.SAMPLE_RATE))

    # Prepare the successes dataframe and errors dictionary to be returned:
    successes = []
    errors = {}

    # Create the output directory:
    output_directory = pathlib.Path(output_directory)
    output_directory.mkdir(exist_ok=True)

    # Start generating audio:
    # Go over the audio files and transcribe:
    for text_file in tqdm.tqdm(
        text_files, desc="Generating", unit="file", disable=not verbose
    ):

        try:
            # Randomize voices for each speaker:
            chosen_voices = {}
            available_voices_copy = available_voices.copy()
            for speaker in speakers:
                voice = random.choice(available_voices_copy)
                chosen_voices[speaker] = voice
                available_voices_copy.remove(voice)
            # Read text:
            with open(text_file, "r") as fp:
                text = fp.read()
            # Prepare a holder for all the generated pieces (if per channel each speaker will have its own):
            audio_pieces = (
                {speaker: [] for speaker in speakers}
                if speaker_per_channel
                else {"all": []}
            )

            # Generate audio per line:
            for line in text.splitlines():
                # Validate line is in correct speaker format:

                if ": " not in line:
                    if verbose:
                        _LOGGER.warning(f"Skipping line: {line}")
                    continue
                # Split line to speaker and his words:
                current_speaker, sentences = line.split(": ", 1)
                # Validate speaker is known:
                if current_speaker not in speakers:
                    raise ValueError(
                        f"Unknown speaker: {current_speaker}. Given speakers are: {speakers}"
                    )
                for sentence in _split_line(line=sentences):
                    # Generate words audio:
                    audio = bark.generate_audio(
                        sentence,
                        history_prompt=chosen_voices[current_speaker],
                        silent=True,
                    )
                    if speaker_per_channel:
                        silence = np.zeros_like(audio)
                        for speaker in audio_pieces.keys():
                            if speaker == current_speaker:
                                audio_pieces[speaker] += [audio, gap_between_speakers]
                            else:
                                audio_pieces[speaker] += [silence, gap_between_speakers]
                    else:
                        audio_pieces["all"] += [audio, gap_between_speakers]
            # Construct a single audio array from all the pieces and channels:

            audio = np.vstack(
                [np.concatenate(audio_pieces[speaker]) for speaker in speakers]
            ).astype(dtype=np.float32)
            # Resample:
            audio = torch.from_numpy(audio)
            audio = resampler(audio)
            # Save to audio file:
            audio_file = output_directory / f"{text_file.stem}.{file_format}"

            torchaudio.save(
                uri=str(audio_file),
                src=audio,
                sample_rate=sample_rate,
                format=file_format,
                bits_per_sample=bits_per_sample,
            )

            # Collect to the successes:
            successes.append([text_file.name, audio_file.name])
        except Exception as exception:
            # Note the exception as error in the dictionary:
            if verbose:
                _LOGGER.warning(f"Error in file: '{text_file.name}'")
            print(exception)
            errors[text_file.name] = str(exception)

    # Construct the translations dataframe:
    successes = pd.DataFrame(
        successes,
        columns=["text_file", "audio_file"],
    )

    # Print the head of the produced dataframe and return:
    if verbose:
        _LOGGER.info(
            f"Done ({successes.shape[0]}/{len(text_files)})\n"
            f"Translations summary:\n"
            f"{successes.head()}"
        )
    return str(output_directory), successes, errors


def _get_text_files(
    data_path: pathlib.Path,
) -> List[pathlib.Path]:
    # Check if the path is of a directory or a file:
    if data_path.is_dir():
        # Get all files inside the directory:
        text_files = list(data_path.glob("*.*"))
    elif data_path.is_file():
        text_files = [data_path]
    else:
        raise ValueError(
            f"Unrecognized data path. The parameter `data_path` must be either a directory path or a file path. "
            f"Given: {str(data_path)} "
        )

    return text_files


def _split_line(line: str, max_length: int = 250) -> List[str]:
    if len(line) < max_length:
        return [line]

    sentences = [
        f"{sentence.strip()}." for sentence in line.split(".") if sentence.strip()
    ]

    splits = []
    current_length = len(sentences[0])
    split = sentences[0]
    for sentence in sentences[1:]:
        if current_length + len(sentence) > max_length:
            splits.append(split)
            split = sentence
            current_length = len(sentence)
        else:
            current_length += len(sentence)
            split += " " + sentence
    if split:
        splits.append(split)

    return splits


def _get_logger():
    global _LOGGER
    try:
        import mlrun
        # Check if MLRun is available:
        context = mlrun.get_or_create_ctx(name="mlrun")
        return context.logger
    except ModuleNotFoundError:
        return _LOGGER
 + functionSourceCode: # Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import pathlib
import random
import tempfile
from typing import Dict, List, Optional, Tuple, Union

import bark
import numpy as np
import pandas as pd
import torch
import torchaudio
import tqdm

# Get the global logger:
_LOGGER = logging.getLogger()


def generate_multi_speakers_audio(
    data_path: str,
    speakers: Union[List[str], Dict[str, int]],
    available_voices: List[str],
    output_directory: str = None,
    use_gpu: bool = True,
    use_small_models: bool = False,
    offload_cpu: bool = False,
    sample_rate: int = 16000,
    file_format: str = "wav",
    verbose: bool = True,
    bits_per_sample: Optional[int] = None,
) -> Tuple[str, pd.DataFrame, dict]:
    """
    Generate audio files from text files.

    :param data_path:           Path to the text file or directory containing the text files to generate audio from.
    :param speakers:            List / Dict of speakers to generate audio for.
                                If a list is given, the speakers will be assigned to channels in the order given.
                                If dictionary, the keys will be the speakers and the values will be the channels.
    :param available_voices:    List of available voices to use for the generation.
                        See here for the available voices:
                        https://suno-ai.notion.site/8b8e8749ed514b0cbf3f699013548683?v=bc67cff786b04b50b3ceb756fd05f68c
    :param output_directory:    Path to the directory to save the generated audio files to.
    :param use_gpu:             Whether to use the GPU for the generation.
    :param use_small_models:    Whether to use the small models for the generation.
    :param offload_cpu:         To reduce the memory footprint, the models can be offloaded to the CPU after loading.
    :param sample_rate:         The sampling rate of the generated audio.
    :param file_format:         The format of the generated audio files.
    :param verbose:             Whether to print the progress of the generation.
    :param bits_per_sample:     Changes the bit depth for the supported formats.
                                Supported only in "wav" or "flac" formats.

    :returns:                   A tuple of:
                                - The output directory path.
                                - The generated audio files dataframe.
                                - The errors dictionary.
    """

    global _LOGGER
    _LOGGER = _get_logger()
    # Get the input text files to turn to audio:
    data_path = pathlib.Path(data_path).absolute()
    text_files = _get_text_files(data_path=data_path)

    # Load the bark models according to the given configurations:
    bark.preload_models(
        text_use_gpu=use_gpu,
        text_use_small=use_small_models,
        coarse_use_gpu=use_gpu,
        coarse_use_small=use_small_models,
        fine_use_gpu=use_gpu,
        fine_use_small=use_small_models,
        codec_use_gpu=use_gpu,
        force_reload=offload_cpu,
    )

    # Check for per channel generation:
    if isinstance(speakers, dict):
        speaker_per_channel = True
        # Sort the given speakers by channels:
        speakers = {
            speaker: channel
            for speaker, channel in sorted(speakers.items(), key=lambda item: item[1])
        }
    else:
        speaker_per_channel = False

    # Prepare the resampling module:
    resampler = torchaudio.transforms.Resample(
        orig_freq=bark.SAMPLE_RATE, new_freq=sample_rate, dtype=torch.float32
    )

    # Prepare the gap between each speaker:
    gap_between_speakers = np.zeros(int(0.5 * bark.SAMPLE_RATE))

    # Prepare the successes dataframe and errors dictionary to be returned:
    successes = []
    errors = {}

    # Create the output directory:
    if output_directory is None:
        output_directory = tempfile.mkdtemp()
    output_directory = pathlib.Path(output_directory)
    if not output_directory.exists():
        output_directory.mkdir(exist_ok=True, parents=True)

    # Start generating audio:
    # Go over the audio files and transcribe:
    for text_file in tqdm.tqdm(
        text_files, desc="Generating", unit="file", disable=not verbose
    ):

        try:
            # Randomize voices for each speaker:
            chosen_voices = {}
            available_voices_copy = available_voices.copy()
            for speaker in speakers:
                voice = random.choice(available_voices_copy)
                chosen_voices[speaker] = voice
                available_voices_copy.remove(voice)
            # Read text:
            with open(text_file, "r") as fp:
                text = fp.read()
            # Prepare a holder for all the generated pieces (if per channel each speaker will have its own):
            audio_pieces = (
                {speaker: [] for speaker in speakers}
                if speaker_per_channel
                else {"all": []}
            )

            # Generate audio per line:
            for line in text.splitlines():
                # Validate line is in correct speaker format:

                if ": " not in line:
                    if verbose:
                        _LOGGER.warning(f"Skipping line: {line}")
                    continue
                # Split line to speaker and his words:
                current_speaker, sentences = line.split(": ", 1)
                # Validate speaker is known:
                if current_speaker not in speakers:
                    raise ValueError(
                        f"Unknown speaker: {current_speaker}. Given speakers are: {speakers}"
                    )
                for sentence in _split_line(line=sentences):
                    # Generate words audio:
                    audio = bark.generate_audio(
                        sentence,
                        history_prompt=chosen_voices[current_speaker],
                        silent=True,
                    )
                    if speaker_per_channel:
                        silence = np.zeros_like(audio)
                        for speaker in audio_pieces.keys():
                            if speaker == current_speaker:
                                audio_pieces[speaker] += [audio, gap_between_speakers]
                            else:
                                audio_pieces[speaker] += [silence, gap_between_speakers]
                    else:
                        audio_pieces["all"] += [audio, gap_between_speakers]
            # Construct a single audio array from all the pieces and channels:

            audio = np.vstack(
                [np.concatenate(audio_pieces[speaker]) for speaker in speakers]
            ).astype(dtype=np.float32)
            # Resample:
            audio = torch.from_numpy(audio)
            audio = resampler(audio)
            # Save to audio file:
            audio_file = output_directory / f"{text_file.stem}.{file_format}"

            torchaudio.save(
                uri=str(audio_file),
                src=audio,
                sample_rate=sample_rate,
                format=file_format,
                bits_per_sample=bits_per_sample,
            )

            # Collect to the successes:
            successes.append([text_file.name, audio_file.name])
        except Exception as exception:
            # Note the exception as error in the dictionary:
            if verbose:
                _LOGGER.warning(f"Error in file: '{text_file.name}'")
            print(exception)
            errors[text_file.name] = str(exception)

    # Construct the translations dataframe:
    successes = pd.DataFrame(
        successes,
        columns=["text_file", "audio_file"],
    )

    # Print the head of the produced dataframe and return:
    if verbose:
        _LOGGER.info(
            f"Done ({successes.shape[0]}/{len(text_files)})\n"
            f"Translations summary:\n"
            f"{successes.head()}"
        )
    return str(output_directory), successes, errors


def _get_text_files(
    data_path: pathlib.Path,
) -> List[pathlib.Path]:
    # Check if the path is of a directory or a file:
    if data_path.is_dir():
        # Get all files inside the directory:
        text_files = list(data_path.glob("*.*"))
    elif data_path.is_file():
        text_files = [data_path]
    else:
        raise ValueError(
            f"Unrecognized data path. The parameter `data_path` must be either a directory path or a file path. "
            f"Given: {str(data_path)} "
        )

    return text_files


def _split_line(line: str, max_length: int = 250) -> List[str]:
    if len(line) < max_length:
        return [line]

    sentences = [
        f"{sentence.strip()}." for sentence in line.split(".") if sentence.strip()
    ]

    splits = []
    current_length = len(sentences[0])
    split = sentences[0]
    for sentence in sentences[1:]:
        if current_length + len(sentence) > max_length:
            splits.append(split)
            split = sentence
            current_length = len(sentence)
        else:
            current_length += len(sentence)
            split += " " + sentence
    if split:
        splits.append(split)

    return splits


def _get_logger():
    global _LOGGER
    try:
        import mlrun
        # Check if MLRun is available:
        context = mlrun.get_or_create_ctx(name="mlrun")
        return context.logger
    except ModuleNotFoundError:
        return _LOGGER
 base_image: mlrun/mlrun commands: [] code_origin: '' @@ -25,15 +25,12 @@ spec: entry_points: generate_multi_speakers_audio: name: generate_multi_speakers_audio - doc: '' + doc: Generate audio files from text files. parameters: - name: data_path type: str doc: Path to the text file or directory containing the text files to generate audio from. - - name: output_directory - type: str - doc: Path to the directory to save the generated audio files to. - name: speakers type: Union[List[str], Dict[str, int]] doc: List / Dict of speakers to generate audio for. If a list is given, the @@ -43,6 +40,10 @@ spec: type: List[str] doc: 'List of available voices to use for the generation. See here for the available voices: https://suno-ai.notion.site/8b8e8749ed514b0cbf3f699013548683?v=bc67cff786b04b50b3ceb756fd05f68c' + - name: output_directory + type: str + doc: Path to the directory to save the generated audio files to. + default: null - name: use_gpu type: bool doc: Whether to use the GPU for the generation. @@ -53,7 +54,8 @@ spec: default: false - name: offload_cpu type: bool - doc: 'TODO: What does this do?' + doc: To reduce the memory footprint, the models can be offloaded to the CPU + after loading. default: false - name: sample_rate type: int @@ -75,8 +77,10 @@ spec: outputs: - doc: 'A tuple of: - The output directory path. - The generated audio files dataframe. - The errors dictionary.' - default: '' - lineno: 30 + type: Tuple[str, pd.DataFrame, dict] + lineno: 31 + has_varargs: false + has_kwargs: false description: Generate audio file from text using different speakers default_handler: generate_multi_speakers_audio disable_auto_mount: false diff --git a/text_to_audio_generator/item.yaml b/text_to_audio_generator/item.yaml index dba7f1e0c..4784a80d2 100644 --- a/text_to_audio_generator/item.yaml +++ b/text_to_audio_generator/item.yaml @@ -24,5 +24,5 @@ spec: - bark - torchaudio url: '' -version: 1.0.0 +version: 1.1.0 test_valid: True diff --git a/text_to_audio_generator/text_to_audio_generator.py b/text_to_audio_generator/text_to_audio_generator.py index ad0e114e8..7602745ee 100644 --- a/text_to_audio_generator/text_to_audio_generator.py +++ b/text_to_audio_generator/text_to_audio_generator.py @@ -14,6 +14,7 @@ import logging import pathlib import random +import tempfile from typing import Dict, List, Optional, Tuple, Union import bark @@ -29,9 +30,9 @@ def generate_multi_speakers_audio( data_path: str, - output_directory: str, speakers: Union[List[str], Dict[str, int]], available_voices: List[str], + output_directory: str = None, use_gpu: bool = True, use_small_models: bool = False, offload_cpu: bool = False, @@ -44,13 +45,13 @@ def generate_multi_speakers_audio( Generate audio files from text files. :param data_path: Path to the text file or directory containing the text files to generate audio from. - :param output_directory: Path to the directory to save the generated audio files to. :param speakers: List / Dict of speakers to generate audio for. If a list is given, the speakers will be assigned to channels in the order given. If dictionary, the keys will be the speakers and the values will be the channels. :param available_voices: List of available voices to use for the generation. See here for the available voices: https://suno-ai.notion.site/8b8e8749ed514b0cbf3f699013548683?v=bc67cff786b04b50b3ceb756fd05f68c + :param output_directory: Path to the directory to save the generated audio files to. :param use_gpu: Whether to use the GPU for the generation. :param use_small_models: Whether to use the small models for the generation. :param offload_cpu: To reduce the memory footprint, the models can be offloaded to the CPU after loading. @@ -108,8 +109,11 @@ def generate_multi_speakers_audio( errors = {} # Create the output directory: + if output_directory is None: + output_directory = tempfile.mkdtemp() output_directory = pathlib.Path(output_directory) - output_directory.mkdir(exist_ok=True) + if not output_directory.exists(): + output_directory.mkdir(exist_ok=True, parents=True) # Start generating audio: # Go over the audio files and transcribe: diff --git a/transcribe/function.yaml b/transcribe/function.yaml index 471dd6f26..40dd2f0e6 100644 --- a/transcribe/function.yaml +++ b/transcribe/function.yaml @@ -2,7 +2,7 @@ kind: job metadata: name: transcribe tag: '' - hash: e7f85ec6e204a54069b4e264003cf59d0cb27bfe + hash: 5cd620de67a936ee8a87cfc1f0b97e19730d0a69 project: '' labels: author: yonatans @@ -14,124 +14,287 @@ spec: args: [] image: '' build: - functionSourceCode: # Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import operator
import pathlib
from functools import reduce, wraps
from typing import Any, Dict, List, Literal, NamedTuple, Tuple, Union

import faster_whisper
import pandas as pd
from tqdm import tqdm

# Get the global logger:
_LOGGER = logging.getLogger()


def open_mpi_handler(
    worker_inputs: List[str], root_worker_inputs: Dict[str, Any] = None
):
    global _LOGGER

    # Check for MLRun and OpenMPI availability:
    context, comm = _check_mlrun_and_open_mpi()

    # Check if MLRun is available, set the global logger to MLRun's:
    if context:
        _LOGGER = context.logger

    def decorator(handler):
        if comm is None or comm.Get_size() == 1:
            return handler

        @wraps(handler)
        def wrapper(**kwargs):
            # Get the open mpi environment properties:
            size = comm.Get_size()
            rank = comm.Get_rank()

            # Give the correct chunk of the workers inputs:
            for worker_input in worker_inputs:
                input_argument = kwargs[worker_input]
                if input_argument is None:
                    continue
                if isinstance(input_argument, str):
                    input_argument = _get_audio_files(
                        data_path=pathlib.Path(input_argument).absolute()
                    )
                if len(input_argument) < size:
                    raise ValueError(
                        f"Cannot split the input '{worker_input}' of length {len(input_argument)} to {size} workers. "
                        f"Please reduce the amount of workers for this input."
                    )
                even_chunk_size = len(input_argument) // size
                chunk_start = rank * even_chunk_size
                chunk_end = (
                    (rank + 1) * even_chunk_size
                    if rank + 1 < size
                    else len(input_argument)
                )
                context.logger.info(
                    f"Rank #{rank}: Processing input chunk of '{worker_input}' "
                    f"from index {chunk_start} to {chunk_end}."
                )
                if isinstance(input_argument, list):
                    input_argument = input_argument[chunk_start:chunk_end]
                elif isinstance(input_argument, pd.DataFrame):
                    input_argument = input_argument.iloc[chunk_start:chunk_end:, :]
                kwargs[worker_input] = input_argument

            # Set the root worker only arguments:
            if rank == 0 and root_worker_inputs:
                kwargs.update(root_worker_inputs)

            # Run the worker:
            output = handler(**kwargs)

            # Send the output to the root rank (rank #0):
            output = comm.gather(output, root=0)
            if rank == 0:
                # Join the outputs:
                context.logger.info("Collecting data from workers to root worker.")
                output_directory = output[0][0]
                dataframe = pd.concat(objs=[df for _, df, _ in output], axis=0)
                errors_dictionary = reduce(
                    operator.ior, [err for _, _, err in output], {}
                )
                return output_directory, dataframe, errors_dictionary
            return None

        return wrapper

    return decorator


def _check_mlrun_and_open_mpi() -> Tuple["mlrun.MLClientCtx", "mpi4py.MPI.Intracomm"]:
    is_mpi = False
    try:
        import mlrun

        context = mlrun.get_or_create_ctx(name="mlrun")
        is_mpi = context.labels.get("kind", "job") == "mpijob"

        if is_mpi:
            try:
                from mpi4py import MPI

                return context, MPI.COMM_WORLD
            except ModuleNotFoundError as mpi4py_not_found:
                context.logger.error(
                    "To distribute the function using MLRun's 'mpijob' you need to have `mpi4py` package in your "
                    "interpreter. Please run `pip install mpi4py` and make sure you have open-mpi."
                )
                raise mpi4py_not_found
        else:
            return context, None
    except ModuleNotFoundError as module_not_found:
        if is_mpi:
            raise module_not_found
    return None, None


@open_mpi_handler(worker_inputs=["data_path"], root_worker_inputs={"verbose": True})
def transcribe(
    data_path: Union[str, List[str]],
    output_directory: str,
    model_name: str = "base",
    device: Literal["cuda", "cpu", "auto"] = "auto",
    compute_type: str = "default",
    language: str = None,
    translate_to_english: bool = False,
    speech_diarization: Dict[str, List[Tuple[float, float, str]]] = None,
    audio_duration: bool = False,
    init_kwargs: dict = None,
    transcribe_kwargs: dict = None,
    verbose: bool = False,
) -> Tuple[str, pd.DataFrame, dict]:
    """
    Transcribe audio files into text files and collect additional data. The end result is a directory of transcribed
    text files and a dataframe containing the following columns:

    * audio_file - The audio file path.
    * transcription_file - The transcribed text file name in the output directory.
    * language - The detected language in the audio file.
    * language_probability - The detected language probability.
    * duration - The duration (in seconds) of the audio file (only if `audio_duration` is set to True).

    :param data_path:               A directory of audio files or a single file or a list of files to transcribe.
    :param output_directory:        Path to a directory to save all transcribed audio files.
    :param model_name:              One of the official model names of Whisper: {'tiny.en', 'tiny', 'base.en', 'base',
                                    'small.en', 'small', 'medium.en', 'medium', 'large-v1', 'large-v2', 'large'} or a
                                    full name of a fine-tuned whisper model from the huggingface hub.
    :param device:                  Device to load the model. Can be one of {"cuda", "cpu"}. Default will prefer "cuda"
                                    if available. To use a specific GPU or more than one GPU, pass the `device_index`
                                    argument via the `init_kwargs`.
    :param compute_type:            The data type to use for computation. For more information, check
                                    https://opennmt.net/CTranslate2/quantization.html. Default: "default" - will use the
                                    default type depending on the device used.
    :param language:                The spoken language to force Whisper the output language. If None, the Whisper model
                                    will automatically predict the output langauge. Default: None.
    :param translate_to_english:    Whether to translate the English post transcription. Default: False.
    :param speech_diarization:      A speech diarization dictionary with the file names to transcribe as keys and their
                                    diarization as value. The diarization is a list of tuples: (start, end, speaker).
                                    The transcription result will be in the following format:
                                    "{speaker}: text text text.". Files with missing diarizations will print a warning.
                                    Pay attention the diarization must be for the entire duration of the audio file (as
                                    long as Whisper is predicting words up until then).
    :param audio_duration:          Whether to include the audio files duration (in seconds). The estimated duration is
                                    from bitrate and may be inaccurate. Default: False.
    :param init_kwargs:             Additional `WhisperModel.__init__` keyword arguments to use.
    :param transcribe_kwargs:       Additional `WhisperModel.transcribe` keyword arguments to use.
    :param verbose:                 Whether to present logs of a progress bar and errors. Default: False.

    :returns: A tuple of:

              * Path to the output directory.
              * A dataframe dataset of the transcribed file names.
              * A dictionary of errored files that were not transcribed.
    """
    global _LOGGER

    # Get the input audio files to transcribe:
    if verbose:
        _LOGGER.info("Collecting audio files.")
    if isinstance(data_path, str):
        data_path = pathlib.Path(data_path).absolute()
        audio_files = _get_audio_files(data_path=data_path)
    else:
        audio_files = data_path
    if verbose:
        _LOGGER.info(f"Collected {len(audio_files)} audio files.")

    # Load the whisper model:
    if verbose:
        _LOGGER.info(f"Loading model '{model_name}' - using device '{device}'.")
    init_kwargs = init_kwargs or {}
    model = faster_whisper.WhisperModel(
        model_size_or_path=model_name,
        device=device,
        compute_type=compute_type,
        **init_kwargs,
    )
    if verbose:
        _LOGGER.info(f"Model loaded successfully.")

    # Prepare the successes dataframe and errors dictionary to be returned:
    successes = []
    errors = {}

    # Create the output directory:
    output_directory = pathlib.Path(output_directory)
    output_directory.mkdir(parents=True, exist_ok=True)

    # Prepare the transcribe keyword arguments:
    transcribe_kwargs = transcribe_kwargs or {}
    transcribe_kwargs["language"] = language
    transcribe_kwargs["task"] = "translate" if translate_to_english else "transcribe"

    # Go over the audio files and transcribe:
    for audio_file in tqdm(
        audio_files, desc="Transcribing", unit="file", disable=not verbose
    ):
        try:
            # Transcribe:
            transcription_and_info = _transcribe(
                audio_file=audio_file,
                model=model,
                transcribe_kwargs=transcribe_kwargs,
                speech_diarization=_get_diarization(  # Get the diarization (if provided).
                    speech_diarization=speech_diarization,
                    file_name=audio_file.name,
                    verbose=verbose,
                ),
                audio_duration=audio_duration,
            )
            # Write the transcription to file:
            transcription_file = _save_to_file(
                transcription=transcription_and_info[0],
                file_name=audio_file.stem,
                output_directory=output_directory,
            )
            # Note as a success in the list:
            successes.append(
                [
                    audio_file.name,
                    transcription_file.name,
                    *transcription_and_info[1:],
                ]
            )
        except Exception as exception:
            # Note the exception as error in the dictionary:
            if verbose:
                _LOGGER.warning(f"Error in file: '{audio_file.name}'")
            errors[str(audio_file.name)] = str(exception)
            continue

    # Construct the transcriptions dataframe:
    columns = [
        "audio_file",
        "transcription_file",
        "language",
        "language_probability",
    ]
    if audio_duration:
        columns.append("duration")
    successes = pd.DataFrame(
        successes,
        columns=columns,
    )

    # Print the head of the produced dataframe and return:
    if verbose:
        _LOGGER.info(
            f"Done ({successes.shape[0]}/{len(audio_files)})\n"
            f"Transcriptions summary:\n"
            f"{successes.head()}"
        )
    return str(output_directory), successes, errors


def _get_audio_files(
    data_path: pathlib.Path,
) -> List[pathlib.Path]:
    # Check if the path is of a directory or a file:
    if data_path.is_dir():
        # Get all files inside the directory:
        audio_files = list(data_path.glob("*.*"))
    elif data_path.is_file():
        audio_files = [data_path]
    else:
        raise ValueError(
            f"Unrecognized data path. The parameter `data_path` must be either a directory path or a file path. "
            f"Given: {str(data_path)} "
        )

    return audio_files


class _DiarizationSegment(NamedTuple):
    start: float
    end: float
    speaker: str


def _get_diarization(
    speech_diarization: Dict[str, List[Tuple[float, float, str]]],
    file_name: str,
    verbose: bool,
) -> Union[List[_DiarizationSegment], None]:
    diarization = None
    if speech_diarization is not None:
        diarization = speech_diarization.get(file_name)
        if diarization is None:
            if verbose:
                _LOGGER.warning(
                    f"Missing speech diarization for the audio file '{file_name}'. Continuing transcribing without "
                    f"diarization."
                )
        diarization = [_DiarizationSegment(*segment) for segment in diarization]
    return diarization


def _get_next_diarization_segment(
    word: faster_whisper.transcribe.Word,
    speech_diarization: List[_DiarizationSegment],
    last_chosen_index: int,
) -> int:
    # Get the last chosen diarization segment:
    last_chosen = speech_diarization[last_chosen_index]

    # If the last chosen segment is the last segment, return it:
    if last_chosen_index == len(speech_diarization) - 1:
        return last_chosen_index

    # If the word ends before the last chosen segment:
    if word.end <= last_chosen.start:
        # Then it is still the closest segment
        return last_chosen_index

    # We check if it ends inside the last chosen segment:
    if word.end < last_chosen.end:
        # Then it still is the closest segment
        return last_chosen_index

    # The word ends after the segment, we need to collect all next segments up until the word ends before them:
    possible_segments = [last_chosen_index]
    for i in range(last_chosen_index + 1, len(speech_diarization)):
        if word.end > speech_diarization[i].end:
            possible_segments.append(i)
            continue
        possible_segments.append(i)
        break

    # Check for the most overlapping option:
    best_overlap = 0
    overlapping_segment = None
    for i in possible_segments:
        overlap = 0
        # If the word starts before segment:
        if word.start <= speech_diarization[i].start:
            # If it ends before the segment, there is an overlap from the start of the segment to the end of the word:
            if word.end < speech_diarization[i].end:
                overlap = word.end - speech_diarization[i].start
            else:
                # The word is wrapping the segment, the overlap is the segment's length:
                overlap = speech_diarization[i].end - speech_diarization[i].start
        # The word starts in segment, check if the word ends in it:
        elif word.end < speech_diarization[i].end:
            # The overlap is the word's length:
            overlap = word.end - word.start
        # The word start in segment but ends after it, the overlap is from the word's start to the segment's end:
        else:
            overlap = speech_diarization[i].end - word.start
        # Check for new best overlap:
        if overlap > best_overlap:
            best_overlap = overlap
            overlapping_segment = i
    if overlapping_segment is not None:
        return overlapping_segment

    # If there is no overlapping segment, return the closest segment:
    best_distance = None
    closest_segment = None
    for i in possible_segments:
        distance = (
            word.start - speech_diarization[i].end
            if word.start > speech_diarization[i].end
            else speech_diarization[i].start - word.end
        )
        if best_distance is None or distance < best_distance:
            best_distance = distance
            closest_segment = i
    return closest_segment


def _construct_transcription(
    segments: List[faster_whisper.transcribe.Segment],
    speech_diarization: List[_DiarizationSegment],
) -> str:
    # If there is no diarization, concatenate all segments and return:
    if speech_diarization is None:
        return " ".join([segment.text for segment in segments])

    # There is a diarization, try to match the Whisper model predicted timestamps to the closest diarization segment
    # (closest diarization segment will be the most overlapping with the word, and if there is no overlap, the closest
    # segment to the word):
    diarization_index = 0
    speaker = speech_diarization[diarization_index].speaker
    text = f"{speaker}:"
    for segment in segments:
        for word in segment.words:
            # Get the next diarization segment:
            diarization_index = _get_next_diarization_segment(
                word=word,
                speech_diarization=speech_diarization,
                last_chosen_index=diarization_index,
            )
            # Check if the segment is of the same speaker:
            if speech_diarization[diarization_index].speaker == speaker:
                # Collect the word:
                text += word.word
            else:
                # Append a newline and update the new speaker:
                speaker = speech_diarization[diarization_index].speaker
                text += f"\n{speaker}:{word.word}"

    return text


def _transcribe(
    audio_file: pathlib.Path,
    model: faster_whisper.WhisperModel,
    transcribe_kwargs: dict,
    speech_diarization: List[_DiarizationSegment],
    audio_duration: bool,
) -> Union[Tuple[str, str, float], Tuple[str, str, float, float]]:
    # Transcribe (Segments is a generator, so we cast to list to begin transcription from start to end):
    segments, info = model.transcribe(
        audio=str(audio_file),
        **transcribe_kwargs,
        word_timestamps=speech_diarization is not None,
    )
    segments = list(segments)

    # Check if speech diarization was provided:
    if speech_diarization is None:
        text = "".join([segment.text for segment in segments])
    else:
        text = _construct_transcription(
            segments=segments,
            speech_diarization=speech_diarization,
        )
    text = text.strip()

    # Return the transcription text and the additional information:
    if audio_duration:
        return text.strip(), info.language, info.language_probability, info.duration
    return text.strip(), info.language, info.language_probability


def _save_to_file(
    transcription: str, file_name: str, output_directory: pathlib.Path
) -> pathlib.Path:
    # Prepare the file full path (checking for no duplications):
    transcription_file = output_directory / f"{file_name}.txt"
    i = 1
    while transcription_file.exists():
        i += 1
        transcription_file = output_directory / f"{file_name}_{i}.txt"

    # Make sure all directories are created:
    transcription_file.parent.mkdir(exist_ok=True, parents=True)

    # Write to file:
    with open(transcription_file, "w") as fp:
        fp.write(transcription)

    return transcription_file
 + functionSourceCode: # Copyright 2024 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import operator
import os
import tempfile
from functools import reduce, wraps
from multiprocessing import Process, Queue
from pathlib import Path
from typing import Any, Dict, Generator, List, Literal, NamedTuple, Tuple, Union

import pandas as pd
import torch
import torchaudio
from tqdm import tqdm
from transformers import (
    AutomaticSpeechRecognitionPipeline,
    AutoModelForCausalLM,
    pipeline,
)
from transformers.utils import is_flash_attn_2_available


class BaseTask:
    """
    A task to write the transcription to file.
    """

    def __init__(
        self, audio_file: Path, transcription_output: Union[dict, str], text_file: Path
    ):
        """
        Initialize the task.

        :param audio_file:           Path to the audio file that was transcribed.
        :param transcription_output: The transcription output from the pipeline. String means an exception was raised.
        :param text_file:            Path to the text file to write the transcription to.
        """
        # Store the parameters:
        self._audio_file = audio_file
        self._transcription_output = transcription_output
        self._text_file = text_file

        # Prepare the error variable:
        self._error: str = None

    def do_task(self):
        """
        Try to perform the task storing an error if occurred.
        """
        if isinstance(self._transcription_output, str):
            self._error = self._transcription_output
            return
        try:
            self._do_task()
        except Exception as exception:
            self._error = str(exception)

    def is_failed(self) -> bool:
        """
        Check if the task failed.

        :returns: Whether the task failed.
        """
        return self._error is not None

    def get_result(self) -> Tuple[str, str]:
        """
        Get the result of the task. If the task failed, the error will be returned, otherwise, the result will be the
        text file name.

        :returns: The task's result.
        """
        if self.is_failed():
            return self._audio_file.name, self._error
        return self._audio_file.name, self._text_file.name

    def to_tuple(self) -> Tuple[str, dict]:
        """
        Convert the task to a tuple to reconstruct it later (used for multiprocessing to pass in queue).

        :returns: The converted task.
        """
        return self.__class__.__name__, {
            "audio_file": self._audio_file,
            "transcription_output": self._transcription_output,
            "text_file": self._text_file,
        }

    def _do_task(self):
        """
        Perform the task - write the transcription to the stored file path.
        """
        # Checking for no duplications:
        i = 1
        while self._text_file.exists():
            i += 1
            self._text_file = (
                self._text_file.parent
                / f"{self._text_file.stem.rsplit('_', 1)[0]}_{i}{self._text_file.suffix}"
            )

        # Make sure all directories are created:
        self._text_file.parent.mkdir(exist_ok=True, parents=True)

        # Write to file:
        with open(self._text_file, "w") as fp:
            fp.write(self._transcription_output["text"])


class SpeechDiarizationTask(BaseTask):
    """
    A task to write the transcription to file with respect to a given speech diarization.
    """

    class _DiarizationSegment(NamedTuple):
        """
        A speech diarization segment.
        """

        start: float
        end: float
        speaker: str

    class _WordTimestamp(NamedTuple):
        """
        A word with its start and end timestamps.
        """

        start: float
        end: float
        text: str

    def __init__(
        self,
        audio_file: Path,
        transcription_output: dict,
        text_file: Path,
        speech_diarization: List[Tuple[float, float, str]],
    ):
        """
        Initialize the task.

        :param audio_file:           Path to the audio file that was transcribed.
        :param transcription_output: The transcription output from the pipeline.
        :param text_file:            Path to the text file to write the transcription to.
        :param speech_diarization:   A speech diarization as a list of tuples: (start, end, speaker).
        """
        super().__init__(
            audio_file=audio_file,
            transcription_output=transcription_output,
            text_file=text_file,
        )
        self._speech_diarization = speech_diarization
        self._segments: List[SpeechDiarizationTask._DiarizationSegment] = None
        self._last_chosen_index = 0

    def to_tuple(self) -> Tuple[str, dict]:
        """
        Convert the task to a tuple to reconstruct it later (used for multiprocessing to pass in queue).

        :returns: The converted task.
        """
        task_class, task_kwargs = super().to_tuple()
        return task_class, {
            **task_kwargs,
            "speech_diarization": self._speech_diarization,
        }

    def _do_task(self):
        """
        Perform the task - write the transcription to the stored file path with respect to the given speech diarization.
        """
        # Check if a speech diarization is given, if not, just write the transcription to file:
        if not self._speech_diarization:
            super()._do_task()
            return

        # Cast the chunks to word timestamps tuples:
        words = [
            SpeechDiarizationTask._WordTimestamp(
                start=chunk["timestamp"][0],
                end=chunk["timestamp"][1],
                text=chunk["text"],
            )
            for chunk in self._transcription_output["chunks"]
        ]

        # Cast speech diarization to segments tuples:
        self._segments = [
            SpeechDiarizationTask._DiarizationSegment(*segment)
            for segment in self._speech_diarization
        ]

        # Try to match the Whisper model predicted timestamps to the closest diarization segment (closest diarization
        # segment will be the most overlapping with the word, and if there is no overlap, the closest segment to the
        # word):
        speaker = self._segments[self._last_chosen_index].speaker
        text = f"{speaker}:"
        for word in words:
            # Get the next diarization segment:
            self._get_next_segment(word=word)
            # Check if the segment is of the same speaker:
            if self._segments[self._last_chosen_index].speaker == speaker:
                # Collect the word:
                text += word.text
            else:
                # Append a newline and update the new speaker:
                speaker = self._segments[self._last_chosen_index].speaker
                text += f"\n{speaker}:{word.text}"

        # Update the transcription output with the new text to write it to file:
        self._transcription_output["text"] = text
        super()._do_task()

    def _get_next_segment(
        self,
        word: _WordTimestamp,
    ):
        """
        Get the next diarization segment the given word falls into. The `self._last_chosen_index` will be updated
        accordingly.

        :param word: The word timestamp to match to the next segment.
        """
        # If the last chosen segment is the last segment, return it:
        if self._last_chosen_index == len(self._segments) - 1:
            return

        # Get the last chosen diarization segment:
        last_chosen = self._segments[self._last_chosen_index]

        # None value may appear if the word is the last word in the audio file, or it was split during inference. In
        # that case, we'll set the last segment:
        if word.end is None:
            self._last_chosen_index = len(self._segments) - 1
            return

        # If the word ends before the last chosen segment:
        if word.end <= last_chosen.start:
            # Then it is still the closest segment
            return

        # We check if it ends inside the last chosen segment:
        if word.end < last_chosen.end:
            # Then it still is the closest segment
            return

        # The word ends after the segment, we need to collect all next segments up until the word ends before them:
        possible_segments = [self._last_chosen_index]
        for i in range(self._last_chosen_index + 1, len(self._segments)):
            if word.end > self._segments[i].end:
                possible_segments.append(i)
                continue
            possible_segments.append(i)
            break

        # Check for the most overlapping option:
        best_overlap = 0
        most_overlapping_segment_index = None
        for i in possible_segments:
            # If the word starts before segment:
            if word.start <= self._segments[i].start:
                # If it ends before the segment, there is an overlap from the start of the segment to the end of the
                # word:
                if word.end < self._segments[i].end:
                    overlap = word.end - self._segments[i].start
                else:
                    # The word is wrapping the segment, the overlap is the segment's length:
                    overlap = self._segments[i].end - self._segments[i].start
            # The word starts in segment, check if the word ends in it:
            elif word.end < self._segments[i].end:
                # The overlap is the word's length:
                overlap = word.end - word.start
            # The word start in segment but ends after it, the overlap is from the word's start to the segment's end:
            else:
                overlap = self._segments[i].end - word.start
            # Check for new best overlap:
            if overlap > best_overlap:
                best_overlap = overlap
                most_overlapping_segment_index = i
        if most_overlapping_segment_index is not None:
            self._last_chosen_index = most_overlapping_segment_index
            return

        # If there is no overlapping segment, return the closest segment:
        best_distance = None
        closest_segment_index = None
        for i in possible_segments:
            distance = (
                word.start - self._segments[i].end
                if word.start > self._segments[i].end
                else self._segments[i].start - word.end
            )
            if best_distance is None or distance < best_distance:
                best_distance = distance
                closest_segment_index = i
        self._last_chosen_index = closest_segment_index


class SpeechDiarizationPerChannelTask(BaseTask):
    """
    A task to write the transcription to file with respect to a given speech diarization per channel.
    """

    class _WordTimestamp(NamedTuple):
        """
        A word with its start and end timestamps and speaker label (channel the word was taken from).
        """

        start: float
        end: float
        speaker: str
        text: str

    def __init__(self, audio_file: Path, text_file: Path):
        """
        Initialize the task.

        :param audio_file: Path to the audio file that was transcribed.
        :param text_file:  Path to the text file to write the transcription to.
        """
        super().__init__(
            audio_file=audio_file, transcription_output={}, text_file=text_file
        )
        self._transcription_output_channels: List[Tuple[str, dict]] = []

    @property
    def transcription_output_channels(self) -> List[Tuple[str, dict]]:
        """
        Get the transcription output channels.

        :returns: The transcription output channels.
        """
        return self._transcription_output_channels

    def do_task(self):
        """
        Try to perform the task storing an error if occurred.
        """
        for _, channel_output in self._transcription_output_channels:
            if isinstance(channel_output, str):
                self._error = self._transcription_output_channels
                return
        super().do_task()

    def to_tuple(self) -> Tuple[str, dict]:
        """
        Convert the task to a tuple to reconstruct it later (used for multiprocessing to pass in queue).

        :returns: The converted task.
        """
        task_class, task_kwargs = super().to_tuple()
        task_kwargs.pop("transcription_output")
        return task_class, task_kwargs

    def _do_task(self):
        """
        Perform the task - write the transcription to the stored file path with respect to the given speech diarization
        per channel.
        """
        # Cast the chunks to word timestamps tuples:
        words_per_channel = [
            [
                SpeechDiarizationPerChannelTask._WordTimestamp(
                    start=chunk["timestamp"][0],
                    end=chunk["timestamp"][1],
                    speaker=speaker,
                    text=chunk["text"],
                )
                for chunk in output["chunks"]
            ]
            for speaker, output in self._transcription_output_channels
        ]

        # Merge and sort the words per channel by their start time:
        words = operator.add(*words_per_channel)
        words.sort()

        # Write the transcription to file:
        current_speaker = words[0].speaker
        text = f"{current_speaker}:"
        for word in words:
            # Check if the word's speaker is different from the current one:
            if word.speaker != current_speaker:
                # Append a newline and update the new speaker:
                current_speaker = word.speaker
                text += f"\n{current_speaker}:"
            # Collect the word:
            text += word.text

        # Update the transcription output with the new text to write it to file:
        self._transcription_output["text"] = text
        super()._do_task()


class BatchProcessor:
    """
    A batch processor to process batches of transcriptions. The batch processor is creating tasks and is aimed to be
    working along the transcriber. It can be used with multiprocessing queue or run the tasks directly using the
    associated methods.
    """

    def __init__(self, audio_files: List[Path], output_directory: Path):
        """
        Initialize the batch processor.

        :param audio_files:      The list of all audio files to transcribe.
        :param output_directory: The output directory to write the transcriptions to.
        """
        # Store the parameters:
        self._audio_files = audio_files
        self._output_directory = output_directory

        # Prepare the batching variables:
        self._current_file_index = 0
        self._tasks: List[BaseTask] = []
        self._results: List[Tuple[bool, Tuple[str, str]]] = []

    def process_batch(self, batch: List[Union[dict, str]]):
        """
        Process a batch of transcriptions. Tasks related to the given batch will be created and stored in the batch
        processor.

        :param batch: The batch of transcriptions to process.
        """
        # Get the relevant files belongs to the given batch:
        current_files = self._get_current_files(batch_size=len(batch))

        # Build the diarization tasks:
        self._tasks.extend(
            [
                BaseTask(
                    audio_file=file,
                    transcription_output=batch[i],
                    text_file=self._output_directory / f"{file.stem}.txt",
                )
                for i, file in enumerate(current_files)
            ]
        )

    def get_tasks(self) -> List[BaseTask]:
        """
        Get the tasks to perform.

        :returns: The tasks to perform.
        """
        tasks = self._tasks
        self._tasks = []
        return tasks

    def do_tasks(self):
        """
        Perform the tasks. Should be used if no multiprocessing queue is given to a transcriber.
        """
        for task in self.get_tasks():
            task.do_task()
            self._results.append((task.is_failed(), task.get_result()))

    def get_results(self) -> List[Tuple[bool, Tuple[str, str]]]:
        """
        Get the results of the tasks. The stored results are then cleared.

        :returns: The results of the tasks.
        """
        results = self._results
        self._results = []
        return results

    def _get_current_files(self, batch_size: int) -> List[Path]:
        """
        Get the current files to process.

        :param batch_size: The batch size to progress the current file index.

        :returns: The current files to process.
        """
        end_index = (
            self._current_file_index + batch_size
            if self._current_file_index + batch_size < len(self._audio_files)
            else len(self._audio_files)
        )
        current_files = self._audio_files[self._current_file_index : end_index]
        self._current_file_index = end_index
        return current_files


class SpeechDiarizationBatchProcessor(BatchProcessor):
    """
    A batch processor to process batches of transcriptions with respect to a given speech diarization. The batch
    processor is creating tasks and is aimed to be working along the transcriber. It can be used with multiprocessing
    queue or run the tasks directly using the associated methods.
    """

    def __init__(
        self, audio_files: List[Path], output_directory: Path, speech_diarization: dict
    ):
        """
        Initialize the batch processor.

        :param audio_files:        The list of all audio files to transcribe.
        :param output_directory:   The output directory to write the transcriptions to.
        :param speech_diarization: A speech diarization dictionary to pass along with each processed batch.
        """
        super().__init__(audio_files=audio_files, output_directory=output_directory)
        self._speech_diarization = speech_diarization
        self._audio_files = audio_files

    def process_batch(self, batch: List[dict]):
        """
        Process a batch of transcriptions. Tasks related to the given batch will be created and stored in the batch
        processor.

        :param batch: The batch of transcriptions to process.
        """
        # Get the relevant files belongs to the given batch:
        current_files = self._get_current_files(batch_size=len(batch))

        # Build the diarization tasks:
        self._tasks.extend(
            [
                SpeechDiarizationTask(
                    audio_file=file,
                    transcription_output=batch[i],
                    text_file=self._output_directory / f"{file.stem}.txt",
                    speech_diarization=self._speech_diarization.get(file.name),
                )
                for i, file in enumerate(current_files)
            ]
        )


class PerChannelSpeechDiarizationBatchProcessor(BatchProcessor):
    """
    A batch processor to process batches of transcriptions per channel. The batch processor is creating tasks with the
    selected amount of channels given and is aimed to be working along the transcriber. It can be used with
    multiprocessing queue or run the tasks directly using the associated methods.
    """

    def __init__(
        self,
        audio_files: List[Path],
        output_directory: Path,
        n_channels: int,
        speakers: List[str],
    ):
        """
        Initialize the batch processor.

        :param audio_files:      The list of all audio files to transcribe.
        :param output_directory: The output directory to write the transcriptions to.
        :param n_channels:       The number of channels in each audio file to transcribe.
        :param speakers:         The speakers labels to use for each channel.
        """
        super().__init__(audio_files=audio_files, output_directory=output_directory)

        # Store the parameters:
        self._n_channels = n_channels
        self._speakers = speakers

        # Prepare a channel buffer to store the channels until the current task created is fully covered:
        self._task_in_process: SpeechDiarizationPerChannelTask = None

    def process_batch(self, batch: List[dict]):
        """
        Process a batch of transcriptions. Tasks related to the given batch will be created and stored in the batch
        processor.

        :param batch: The batch of transcriptions to process.
        """
        # Go over the batch and create the tasks:
        for output in batch:
            # Check if there is a task in process:
            if not self._task_in_process:
                # Create a new task:
                self._task_in_process = SpeechDiarizationPerChannelTask(
                    audio_file=self._audio_files[self._current_file_index],
                    text_file=self._output_directory
                    / f"{self._audio_files[self._current_file_index].stem}.txt",
                )
            # Get the channel's speaker:
            speaker = self._speakers[
                len(self._task_in_process.transcription_output_channels)
            ]
            # Collect the channel into the processed task:
            self._task_in_process.transcription_output_channels.append(
                (speaker, output)
            )
            # Check if the task is fully covered (all channels are collected):
            if (
                len(self._task_in_process.transcription_output_channels)
                == self._n_channels
            ):
                # Collect the task and reset the task in process:
                self._tasks.append(self._task_in_process)
                self._current_file_index += 1
                self._task_in_process = None


class Transcriber:
    """
    A transcription wrapper for the Huggingface's ASR pipeline -
    https://huggingface.co/transformers/main_classes/pipelines.html#transformers.AutomaticSpeechRecognitionPipeline to
    use with OpenAI's Whisper models - https://huggingface.co/openai.
    """

    def __init__(
        self,
        model_name: str,
        device: str = None,
        use_flash_attention_2: bool = None,
        use_better_transformers: bool = None,
        assistant_model: str = None,
        max_new_tokens: int = 128,
        chunk_length_s: int = 30,
        batch_size: int = 2,
        spoken_language: str = None,
        translate_to_english: bool = False,
        return_timestamps: Union[bool, Literal["word"]] = False,
        per_channel_transcription: int = 0,
    ):
        """
        Initialize the transcriber.

        :param model_name:                The model name to use. Should be a model from the OpenAI's Whisper models for
                                          best results (for example "tiny", "base", "large", etc.).
        :param device:                    The device to use for inference. If not given, will use GPU if available.
        :param use_flash_attention_2:     Whether to use the Flash Attention 2 implementation. It can be used only with
                                          one of the following GPUs: Nvidia H series and Nvidia A series. T4 support
                                          will be available soon.

                                          Note: If both `use_flash_attention_2` and
                                          `use_better_transformers` are `None`, the optimization will be chosen
                                          automatically according to the available resources.

        :param use_better_transformers:   Whether to use the Better Transformers library to further optimize the model.
                                          Should be used for all use cases that do not support flash attention 2.

                                          Note: If both `use_flash_attention_2` and `use_better_transformers` are
                                          `None`, the optimization will be chosen automatically according to the
                                          available resources.
       :param assistant_model:           The assistant model name to use for inference. Notice that the optimizations
                                          (flash attention 2 and better transformers) will be applied for the assistant
                                          as well. Should be a model from Huggingface's distil-whisper (see here for
                                          more information: https://github.com/huggingface/distil-whisper).
        :param max_new_tokens:            The maximum number of new tokens to generate. This is used to limit the
                                          generation length. Default is 128 tokens.
        :param chunk_length_s:            The audio chunk to split the audio to (in seconds). Default is 30 seconds.
        :param batch_size:                The batch size to use for inference. Default is 2.
        :param spoken_language:           Aim whisper to know what language is spoken. If None, it will try to detect it
                                          for each chunk.
        :param translate_to_english:      Whether to translate the transcriptions to English. Default is False.
        :param return_timestamps:         Whether to return the timestamps of the words. If "word", will return the
                                          timestamps of each word. If True will return the timestamps of each chunk.
                                          Default is False. Aimed to be used for speech diarization.
        :param per_channel_transcription: Whether to do per channel transcription. If needed to run per channel
                                          transcription, pass the number of channels expected for each audio file here.
                                          0 means regular transcription (merge channels).

                                          Note: If `per_channel_transcription` is not 0, `batch_size` must be treated to
                                          be the number of channels and not audio files. Aimed to be used for per
                                          channel speech diarization.
        """
        # Store loading parameters:
        self._model_name = model_name
        self._device = device
        self._use_flash_attention_2 = use_flash_attention_2
        self._use_better_transformers = use_better_transformers
        self._max_new_tokens = max_new_tokens
        self._chunk_length_s = chunk_length_s
        self._batch_size = batch_size
        self._return_timestamps = return_timestamps
        self._per_channel_transcription = per_channel_transcription

        # Store generation parameters:
        self._assistant_model = assistant_model
        self._spoken_language = spoken_language
        self._translate_to_english = translate_to_english

        # Prepare the transcription objects:
        self._transcription_pipeline: AutomaticSpeechRecognitionPipeline = None
        self._generate_kwargs: dict = None

    def load(self):
        """
        Load the transcriber. Must be called before transcribing.
        """
        # Set the device and data type to use (prefer GPU if available):
        device = torch.device(
            self._device or "cuda" if torch.cuda.is_available() else "cpu"
        )
        torch_dtype = torch.float16 if device.type == "cuda" else torch.float32

        # Choose the optimization to use (in case the user did not specify any):
        if (
            self._use_flash_attention_2 is None
            and self._use_better_transformers is None
        ):
            # Prefer to use flash attention 2 if available and cuda device is supported (see GPU names to architecture
            # here: https://en.wikipedia.org/wiki/List_of_Nvidia_graphics_processing_units#Tesla):
            if device.type == "cuda" and is_flash_attn_2_available():
                cuda_device_name = torch.cuda.get_device_properties(device).name
                if any(
                    cuda_device_name.startswith(gpu_name)
                    for gpu_name in [
                        "NVIDIA A",  # For Ampere architecture (e.g. A10, A30, A100)
                        "NVIDIA H",  # For Hopper architecture (e.g. H100)
                        "NVIDIA L",  # For Ada Lovelace architecture (e.g. L4, L40)
                        "NVIDIA RTX 30",  # For Ada Lovelace architecture (RTX 30 series)
                        "NVIDIA RTX 40",  # For Ada Lovelace architecture (RTX 40 series)
                        "NVIDIA RTX 50",  # For Ada Lovelace architecture (RTX 50 series)
                        # Will be supported soon according to FlashAttention GitHub repo:
                        # https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features
                        # "NVIDIA T4",  # For Turing architecture (only T4)
                        # "NVIDIA RTX 20",  # For Turing architecture (RTX 20 series)
                    ]
                ):
                    self._use_flash_attention_2 = True
                else:
                    self._use_better_transformers = True
            else:
                self._use_better_transformers = True

        # Build the optimizations kwargs:
        model_kwargs = {
            "low_cpu_mem_usage": True,
            "use_safetensors": True,
        }
        if self._use_flash_attention_2:
            if _LOGGER:
                _LOGGER.info(
                    "Using FlashAttention2 optimization - make sure the `flash-attn` package is installed via "
                    "`pip install -U flash-attn --no-build-isolation`"
                )
            model_kwargs["attn_implementation"] = "flash_attention_2"
        elif self._use_better_transformers:
            if _LOGGER:
                _LOGGER.info(
                    "Using BetterTransformers optimization - make sure the `optimum` package is installed via "
                    "`pip install -U optimum`"
                )
            model_kwargs["attn_implementation"] = "sdpa"

        # Initialize the speech recognition pipeline:
        self._transcription_pipeline = pipeline(
            task="automatic-speech-recognition",
            model=self._model_name,
            model_kwargs=model_kwargs.copy(),
            batch_size=self._batch_size,
            max_new_tokens=self._max_new_tokens,
            chunk_length_s=self._chunk_length_s,
            return_timestamps=self._return_timestamps,
            torch_dtype=torch_dtype,
            device=device,
        )

        # Prepare the generation kwargs:
        self._generate_kwargs = {
            "language": self._spoken_language,
            "task": "translate" if self._translate_to_english else "transcribe",
        }

        # Initialize the assistant model (if needed):
        if self._assistant_model:
            assistant_model = AutoModelForCausalLM.from_pretrained(
                self._assistant_model, torch_dtype=torch_dtype, **model_kwargs
            )
            assistant_model.to(device)
            self._generate_kwargs["assistant_model"] = assistant_model

    def transcribe(
        self,
        audio_files: List[Path],
        batch_processor: BatchProcessor = None,
        batches_queue: Queue = None,
        verbose: bool = False,
    ) -> Union[List[List[dict]], None]:
        """
        Transcribe the given audio files. The transcriptions will be sent to a queue or a batch processor for further
        processing like writing to text files. If no queue or batch processor is given, the transcriptions outputs from
        the pipeline will be returned. Otherwise, `None` is returned.

        :param audio_files:     The audio files to transcribe.
        :param batch_processor: A batch processor.
        :param batches_queue:   A multiprocessing queue to put the batches in.
        :param verbose:         Whether to show a progress bar. Default is False.

        :returns: The transcriptions outputs from the pipeline if no queue or batch processor is given, otherwise,
                  `None`.
        """
        # Wrap the audio files with a function to iterate over them via a generator (save memory and runtime with
        # Huggingface's pipelines as they preload each input while inference is running):
        def audio_iterator() -> Generator[Union[dict, str], None, None]:
            if self._per_channel_transcription:
                for audio_file in audio_files:
                    audio, sampling_rate = torchaudio.load(str(audio_file))
                    audio = audio.numpy()
                    for channel in audio:
                        yield {"raw": channel, "sampling_rate": sampling_rate}
            else:
                for audio_file in audio_files:
                    yield str(audio_file)

        # Create a batch iterator:
        def batch_iterator() -> Generator[List[Union[dict, str]], None, None]:
            batch = []
            for audio in audio_iterator():
                batch.append(audio)
                if len(batch) == self._batch_size:
                    yield batch
                    batch = []
            if batch:
                yield batch

        # Prepare the successes dataframe and errors dictionary to be returned:
        outputs = []

        # Infer through the pipeline:
        for input_batch in tqdm(
            batch_iterator() if self._batch_size > 1 else audio_iterator(),
            desc="Transcribing",
            unit="channel" if self._per_channel_transcription else "audio file",
            total=(
                (
                    (len(audio_files) // self._batch_size)
                    + (len(audio_files) % self._batch_size != 0)
                )
                * (self._per_channel_transcription or 1)
            ),
            disable=not verbose,
        ):
            # Infer:
            try:
                output_batch = self._transcription_pipeline(
                    input_batch,
                    generate_kwargs=self._generate_kwargs,
                )
            except Exception as exception:
                # Collect the exception:
                output_batch = str(exception)
                # Align to batch size:
                output_batch = (
                    [output_batch] * len(input_batch)
                    if isinstance(input_batch, list)
                    else [output_batch]
                )
            # To align with batching, if batch size is 1, wrap the output with a list:
            if isinstance(output_batch, dict):
                output_batch = [output_batch]
            # If a batch processor is given, process the batch:
            if batch_processor:
                # Process it directly:
                batch_processor.process_batch(batch=output_batch)
                batch_processor.do_tasks()
            elif batches_queue:
                # Otherwise, queue the batch:
                batches_queue.put(output_batch)
            else:
                # Otherwise, collect the output as is without processing:
                outputs.append(output_batch)

        # Check if given a multiprocessing queue or a batch processor:
        if batches_queue:
            batches_queue.put(_MULTIPROCESSING_STOP_MARK)

        return outputs if not batch_processor else None


#: The value to send into multiprocessing queues to stop the process:
_MULTIPROCESSING_STOP_MARK = "STOP"


def _multiprocessing_process_batches(
    batch_processor: BatchProcessor,
    batches_queue: Queue,
    tasks_queue: Queue,
    n_task_completers: int,
):
    """
    Process the batches in the given batches queue and put the tasks in the given tasks queue. The function will stop
    when the given batches queue will receive the stop mark. It is aimed to be used with multiprocessing as a process.

    :param batch_processor:   A batch processor to process the batches.
    :param batches_queue:     A queue to get the batches from.
    :param tasks_queue:       A queue to put the tasks in.
    :param n_task_completers: The number of task completers (processes that run the `_multiprocessing_complete_tasks`
                              function). A stop mark will be sent to the tasks queue for each task completer.
    """
    while True:
        # Get the batch:
        batch: List[dict] = batches_queue.get()
        if batch == _MULTIPROCESSING_STOP_MARK:
            break

        # Process the batch:
        batch_processor.process_batch(batch=batch)

        # Get the tasks:
        tasks = batch_processor.get_tasks()

        # Queue the tasks:
        for task in tasks:
            tasks_queue.put(task.to_tuple())

    # Mark the end of the batches:
    for _ in range(n_task_completers):
        tasks_queue.put(_MULTIPROCESSING_STOP_MARK)


def _multiprocessing_complete_tasks(tasks_queue: Queue, results_queue: Queue):
    """
    Complete the tasks in the given queue and put the results in the given results queue. The function will stop when
    the given tasks queue will receive the stop mark. It is aimed to be used with multiprocessing as a process.

    :param tasks_queue:   A queue to get the tasks from.
    :param results_queue: A queue to put the results in.
    """
    tasks_map = {
        BaseTask.__name__: BaseTask,
        SpeechDiarizationTask.__name__: SpeechDiarizationTask,
        SpeechDiarizationPerChannelTask.__name__: SpeechDiarizationPerChannelTask,
    }

    while True:
        # Get the task:
        task = tasks_queue.get()
        if task == _MULTIPROCESSING_STOP_MARK:
            break

        # Reconstruct the task:
        task_class, task_kwargs = task
        task = tasks_map[task_class](**task_kwargs)

        # Complete the task:
        task.do_task()
        results_queue.put((task.is_failed(), task.get_result()))

    # Mark the end of the tasks:
    results_queue.put(_MULTIPROCESSING_STOP_MARK)


# Get the global logger:
_LOGGER = logging.getLogger()


def open_mpi_handler(
    worker_inputs: List[str], root_worker_inputs: Dict[str, Any] = None
):
    global _LOGGER

    # Check for MLRun and OpenMPI availability:
    context, comm = _check_mlrun_and_open_mpi()

    # Check if MLRun is available, set the global logger to MLRun's:
    if context:
        _LOGGER = context.logger

    def decorator(handler):
        if comm is None or comm.Get_size() == 1:
            return handler

        @wraps(handler)
        def wrapper(**kwargs):
            # Get the open mpi environment properties:
            size = comm.Get_size()
            rank = comm.Get_rank()

            # Give the correct chunk of the workers inputs:
            for worker_input in worker_inputs:
                input_argument = kwargs[worker_input]
                if input_argument is None:
                    continue
                if isinstance(input_argument, str):
                    input_argument = _get_audio_files(
                        data_path=Path(input_argument).absolute()
                    )
                if len(input_argument) < size:
                    raise ValueError(
                        f"Cannot split the input '{worker_input}' of length {len(input_argument)} to {size} workers. "
                        f"Please reduce the amount of workers for this input."
                    )
                even_chunk_size = len(input_argument) // size
                chunk_start = rank * even_chunk_size
                chunk_end = (
                    (rank + 1) * even_chunk_size
                    if rank + 1 < size
                    else len(input_argument)
                )
                context.logger.info(
                    f"Rank #{rank}: Processing input chunk of '{worker_input}' "
                    f"from index {chunk_start} to {chunk_end}."
                )
                if isinstance(input_argument, list):
                    input_argument = input_argument[chunk_start:chunk_end]
                elif isinstance(input_argument, pd.DataFrame):
                    input_argument = input_argument.iloc[chunk_start:chunk_end:, :]
                kwargs[worker_input] = input_argument

            # Set the root worker only arguments:
            if rank == 0 and root_worker_inputs:
                kwargs.update(root_worker_inputs)

            # Run the worker:
            output = handler(**kwargs)

            # Save the output directory of this worker:
            output_directory = Path(output[0])

            # Send the output to the root rank (rank #0):
            output = comm.gather(output, root=0)

            # Join the data from all workers:
            if rank == 0:
                context.logger.info("Collecting data from workers to root worker.")

                # Check if there are different output directories:
                output_directories = set([Path(out_dir) for out_dir, _, _ in output])
                for r in range(1, size):
                    # True means the other workers should pass their files to the root worker (rank 0):
                    comm.send(len(output_directories) != 1, dest=r)

                # If there are different output directories, listen to the other workers:
                if len(output_directories) != 1:
                    # Collect the files from the other workers:
                    files = []
                    for r in range(1, size):
                        files.extend(comm.recv(source=r))
                    # Write the files to the root worker's output directory:
                    for file_name, file_content in files:
                        with open(output_directory / file_name, "w") as f:
                            f.write(file_content)

                # Concatenate the dataframes:
                dataframe = pd.concat(objs=[df for _, df, _ in output], axis=0)

                # Concatenate the errors dictionaries:
                errors_dictionary = reduce(
                    operator.ior, [err for _, _, err in output], {}
                )

                return str(output_directory), dataframe, errors_dictionary

            # Listen to rank 0 to see if there are different output directories and this rank need to send its files to
            # it:
            if comm.recv(source=0):
                files = []
                for file in os.listdir(output_directory):
                    with open(output_directory / file, "r") as f:
                        files.append((file, f.read()))
                comm.send(files, dest=0)
            return None

        return wrapper

    return decorator


def _check_mlrun_and_open_mpi() -> Tuple["mlrun.MLClientCtx", "mpi4py.MPI.Intracomm"]:
    is_mpi = False
    try:
        import mlrun

        context = mlrun.get_or_create_ctx(name="mlrun")
        is_mpi = context.labels.get("kind", "job") == "mpijob"

        if is_mpi:
            try:
                from mpi4py import MPI

                return context, MPI.COMM_WORLD
            except ModuleNotFoundError as mpi4py_not_found:
                context.logger.error(
                    "To distribute the function using MLRun's 'mpijob' you need to have `mpi4py` package in your "
                    "interpreter. Please run `pip install mpi4py` and make sure you have open-mpi."
                )
                raise mpi4py_not_found
        else:
            return context, None
    except ModuleNotFoundError as module_not_found:
        if is_mpi:
            raise module_not_found
    return None, None


@open_mpi_handler(worker_inputs=["data_path"], root_worker_inputs={"verbose": True})
def transcribe(
    # Input / Output kwargs:
    data_path: Union[str, Path, List[Union[str, Path]]],
    output_directory: str = None,
    # Model loading kwargs:
    model_name: str = "openai/whisper-tiny",
    device: str = None,
    use_flash_attention_2: bool = None,
    use_better_transformers: bool = None,
    # Generation kwargs:
    assistant_model: str = None,
    max_new_tokens: int = 128,
    chunk_length_s: int = 30,
    batch_size: int = 8,
    spoken_language: str = None,
    translate_to_english: bool = False,
    # Diarization kwargs:
    speech_diarization: Dict[str, List[Tuple[float, float, str]]] = None,
    speech_diarize_per_channel: int = None,
    speaker_labels: List[str] = None,
    # Other kwargs:
    use_multiprocessing: Union[bool, int] = False,
    verbose: bool = False,
):
    """
    Transcribe audio files into text files and collect additional data. The end result is a directory of transcribed
    text files and a dataframe containing the following columns:

    * audio_file - The audio file path.
    * transcription_file - The transcribed text file name in the output directory.

    The transcription is based on Huggingface's ASR pipeline -
    https://huggingface.co/transformers/main_classes/pipelines.html#transformers.AutomaticSpeechRecognitionPipeline and
    is tested with OpenAI's Whisper models - https://huggingface.co/openai.

    If one of the speaker diarization parameters are given (either `speech_diarization` or
    `speech_diarize_per_channel`), the transcription will be written in a conversation format, where each speaker will
    be written in a separate line::

        speaker_1: text
        speaker_2: text
        speaker_1: text
        ...

    :param data_path:                  A directory of audio files or a single file or a list of files to transcribe.
    :param output_directory:           Path to a directory to save all transcribed audio files. If not given, will save
                                       the transcribed files in a temporary directory.
    :param model_name:                 The model name to use. Should be a model from the OpenAI's Whisper models for
                                       best results (for example "tiny", "base", "large", etc.). See here for more
                                       information: https://huggingface.co/openai?search_models=whisper.
    :param device:                     The device to use for inference. If not given, will use GPU if available.
    :param use_flash_attention_2:      Whether to use the Flash Attention 2 implementation. It can be used only with
                                       one of the following GPUs: Nvidia H series and Nvidia A series. T4 support
                                       will be available soon.

                                       Note: If both `use_flash_attention_2` and
                                       `use_better_transformers` are `None`, the optimization will be chosen
                                       automatically according to the available resources.

    :param use_better_transformers:    Whether to use the Better Transformers library to further optimize the model.
                                       Should be used for all use cases that do not support flash attention 2.

                                       Note: If both `use_flash_attention_2` and `use_better_transformers` are
                                       `None`, the optimization will be chosen automatically according to the
                                       available resources.
    :param assistant_model:            The assistant model name to use for inference. Notice that the optimizations
                                       (flash attention 2 and better transformers) will be applied for the assistant as
                                       well. Should be a model from Huggingface's distil-whisper (see here for more
                                       information: https://github.com/huggingface/distil-whisper).

                                       Note: Currently an assistant model is only usable with batch size of 1.
    :param max_new_tokens:             The maximum number of new tokens to generate. This is used to limit the
                                       generation length. Default is 128 tokens.
    :param chunk_length_s:             The audio chunk to split the audio to (in seconds). Default is 30 seconds.
    :param batch_size:                 The batch size to use for inference. Default is 2.
    :param spoken_language:            Aim whisper to know what language is spoken. If None, it will try to detect
                                       it.
    :param translate_to_english:       Whether to translate the transcriptions to English.
    :param speech_diarization:         A speech diarization dictionary with the file names to transcribe as keys and
                                       their diarization as value. The diarization is a list of tuples:
                                       (start, end, speaker). An example
                                       for a diarization dictionary::

                                       {
                                           "audio_file_name": [
                                               {
                                                   "start": 0.0,
                                                   "end": 2.0,
                                                   "speaker": "Agent",
                                               },
                                               {
                                                   "start": 2.0,
                                                   "end": 4.0,
                                                   "speaker": "Client",
                                               },
                                               ...
                                           ],
                                           ...
                                       }

                                       Note: The diarization must be for the entire duration of the audio file (as long
                                       as Whisper is predicting words up until then.
    :param speech_diarize_per_channel: Perform speech diarization per channel. Each speaker is expected to belong to
                                       a separate channel in the audio. Notice: This will make the transcription
                                       slower as each channel wil be transcribed separatly. If a speech diarization
                                       is passed (via the `speech_diarization` parameter), this parameter is
                                       ignored.
    :param speaker_labels:             A list of speaker labels by channel order to use for writing the
                                       transcription with respect to per channel speech diarization. This won't be
                                       used together with a given speech diarization (via the `speech_diarization`
                                       parameter).
    :param use_multiprocessing:        Whether to use multiprocessing to transcribe the audio files. Can be either a
                                       boolean value or an integer. If `True`, will use the default amount of workers
                                       (3): 1 for transcription, 1 for batch processing and 1 for task completion (such
                                       as speech diarization and writing to files). To control the amount of tasks
                                       completion workers, an integer can be provided to specify the amount of workers.
                                       `False`, will use a single process. Default is `False`.
    :param verbose:                    Whether to print the progress of the transcription. Default is `False`.
    """
    global _LOGGER

    # Get the input audio files to transcribe:
    if verbose:
        _LOGGER.info("Collecting audio files.")
    audio_files = _get_audio_files(data_path=data_path)
    if verbose:
        _LOGGER.info(f"Collected {len(audio_files)} audio files.")

    # Get the output directory:
    if output_directory is None:
        if verbose:
            _LOGGER.info("No output directory given, using temporary directory.")
        output_directory = tempfile.mkdtemp()
    output_directory = Path(output_directory).absolute()
    output_directory.mkdir(exist_ok=True, parents=True)
    if verbose:
        _LOGGER.info(f"Transcriptions will be saved to: {output_directory}")

    # Initialize a batch processor according to user requirements (no speech diarization, given speech diarization,
    # speech diarization per channel):
    if speech_diarization:
        batch_processor = SpeechDiarizationBatchProcessor(
            audio_files=audio_files,
            output_directory=output_directory,
            speech_diarization=speech_diarization,
        )
    elif speech_diarize_per_channel:
        batch_processor = PerChannelSpeechDiarizationBatchProcessor(
            audio_files=audio_files,
            output_directory=output_directory,
            n_channels=speech_diarize_per_channel,
            speakers=speaker_labels,
        )
    else:
        batch_processor = BatchProcessor(
            audio_files=audio_files,
            output_directory=output_directory,
        )

    # Initialize the transcription pipeline:
    transcriber = Transcriber(
        device=device,
        use_flash_attention_2=use_flash_attention_2,
        use_better_transformers=use_better_transformers,
        assistant_model=assistant_model,
        model_name=model_name,
        max_new_tokens=max_new_tokens,
        chunk_length_s=chunk_length_s,
        batch_size=batch_size,
        return_timestamps=(
            "word"
            if speech_diarization is not None or speech_diarize_per_channel is not None
            else False
        ),
        per_channel_transcription=speech_diarize_per_channel or 0,
        spoken_language=spoken_language,
        translate_to_english=translate_to_english,
    )

    # Run the transcription:
    if use_multiprocessing:
        results = _parallel_run(
            n_workers=use_multiprocessing
            if isinstance(use_multiprocessing, int)
            else 1,
            audio_files=audio_files,
            batch_processor=batch_processor,
            transcriber=transcriber,
            verbose=verbose,
        )
    else:
        results = _run(
            audio_files=audio_files,
            batch_processor=batch_processor,
            transcriber=transcriber,
            verbose=verbose,
        )

    # Process the results:
    if verbose:
        _LOGGER.info("Summarizing the results.")
    successes = []
    errors = {}
    for is_error, result in results:
        if is_error:
            errors[result[0]] = result[1]
        else:
            successes.append(result)
    successes = pd.DataFrame(successes, columns=["audio_file", "transcription_file"])
    if verbose:
        _LOGGER.info(
            f"Done ({successes.shape[0]}/{len(audio_files)})\n"
            f"Transcriptions summary:\n"
            f"{successes.head()}"
        )

    return str(output_directory), successes, errors


def _get_audio_files(
    data_path: Union[Path, str, list],
) -> List[Path]:
    """
    Get the audio files to transcribe. If a path to a directory is given, all files in the directory will be collected.

    :param data_path: The data path to collect the audio files from.

    :returns: The audio files list.
    """
    # Check if given a list of paths:
    if isinstance(data_path, list):
        audio_files = []
        for path in data_path:
            audio_files.extend(_get_audio_files(data_path=path))
        return audio_files

    # Check if given a single string path to cast it to a `pathlib.Path`:
    if isinstance(data_path, str):
        data_path = Path(data_path).absolute()

    # Check if the path is of a directory or a file:
    if data_path.is_dir():
        # Get all files inside the directory:
        audio_files = list(data_path.glob("*.*"))
    elif data_path.is_file():
        audio_files = [data_path]
    else:
        raise ValueError(
            f"Unrecognized data path. The parameter `data_path` must be a valid path to either a directory path or a "
            f"file. Given: {str(data_path)} "
        )

    return audio_files


def _run(
    audio_files: List[Path],
    batch_processor: BatchProcessor,
    transcriber: Transcriber,
    verbose: bool,
) -> List[Tuple[bool, Tuple[str, str]]]:
    """
    Run the transcription without multiprocessing.

    :param audio_files:     The audio files to transcribe.
    :param batch_processor: The batch processor to use.
    :param transcriber:     The transcriber to use.
    :param verbose:         Verbosity.

    :returns: The collected results.
    """
    # Load the transcription pipeline:
    if verbose:
        _LOGGER.info(f"Loading the transcription pipeline.")
    transcriber.load()
    if verbose:
        _LOGGER.info("Transcription pipeline loaded.")

    # Transcribe the files:
    transcriber.transcribe(
        audio_files=audio_files,
        batch_processor=batch_processor,
        verbose=verbose,
    )

    # Return the results:
    return batch_processor.get_results()


def _parallel_run(
    n_workers: int,
    audio_files: List[Path],
    batch_processor: BatchProcessor,
    transcriber: Transcriber,
    verbose: bool,
):
    """
    Run the transcription with multiprocessing.

    :param n_workers:       The amount of workers to use as task completers.
    :param audio_files:     The audio files to transcribe.
    :param batch_processor: The batch processor to use.
    :param transcriber:     The transcriber to use.
    :param verbose:         Verbosity.

    :returns: The collected results.
    """
    # Initialize the multiprocessing queues:
    batches_queue = Queue()
    tasks_queue = Queue()
    results_queue = Queue()

    # Initialize the multiprocessing processes:
    batch_processing_process = Process(
        target=_multiprocessing_process_batches,
        kwargs={
            "batch_processor": batch_processor,
            "batches_queue": batches_queue,
            "tasks_queue": tasks_queue,
            "n_task_completers": n_workers,
        },
    )
    task_completion_processes = [
        Process(
            target=_multiprocessing_complete_tasks,
            kwargs={"tasks_queue": tasks_queue, "results_queue": results_queue},
        )
        for _ in range(n_workers)
    ]

    # Start the multiprocessing processes:
    batch_processing_process.start()
    for p in task_completion_processes:
        p.start()

    # Load the transcription pipeline:
    if verbose:
        _LOGGER.info(f"Loading the transcription pipeline.")
    transcriber.load()
    if verbose:
        _LOGGER.info("Transcription pipeline loaded.")

    # Transcribe the files:
    transcriber.transcribe(
        audio_files=audio_files, batches_queue=batches_queue, verbose=verbose
    )

    # Collect the results:
    results = []
    stop_marks_counter = 0
    while True:
        # Get a result from the queue:
        result: Tuple[bool, Tuple[str, str]] = results_queue.get()
        if result == _MULTIPROCESSING_STOP_MARK:
            stop_marks_counter += 1
            if stop_marks_counter == n_workers:
                break
        else:
            # Collect the result:
            results.append(result)

    # Wait for the processes to finish:
    results_queue.empty()
    batch_processing_process.join()
    for p in task_completion_processes:
        p.join()

    return results base_image: mlrun/mlrun commands: [] code_origin: '' origin_filename: '' requirements: - - openai-whisper + - transformers - tqdm + - torchaudio + - torch entry_points: - open_mpi_handler: - name: open_mpi_handler - doc: '' + do_task: + name: do_task + doc: Try to perform the task storing an error if occurred. parameters: - - name: worker_inputs - type: List[str] - - name: root_worker_inputs - type: Dict[str, Any] - default: null + - name: self + outputs: [] + lineno: 348 + has_varargs: false + has_kwargs: false + is_failed: + name: is_failed + doc: Check if the task failed. + parameters: + - name: self outputs: - - default: '' - lineno: 29 - decorator: - name: decorator - doc: '' + - doc: Whether the task failed. + type: bool + lineno: 70 + has_varargs: false + has_kwargs: false + get_result: + name: get_result + doc: 'Get the result of the task. If the task failed, the error will be returned, + otherwise, the result will be the + + text file name.' parameters: - - name: handler + - name: self outputs: - - default: '' - lineno: 41 - wrapper: - name: wrapper - doc: '' - parameters: [] + - doc: The task's result. + type: Tuple[str, str] + lineno: 78 + has_varargs: false + has_kwargs: false + to_tuple: + name: to_tuple + doc: Convert the task to a tuple to reconstruct it later (used for multiprocessing + to pass in queue). + parameters: + - name: self + outputs: + - doc: The converted task. + type: Tuple[str, dict] + lineno: 358 + has_varargs: false + has_kwargs: false + transcription_output_channels: + name: transcription_output_channels + doc: Get the transcription output channels. + parameters: + - name: self + outputs: + - doc: The transcription output channels. + type: List[Tuple[str, dict]] + lineno: 340 + has_varargs: false + has_kwargs: false + process_batch: + name: process_batch + doc: 'Process a batch of transcriptions. Tasks related to the given batch will + be created and stored in the batch + + processor.' + parameters: + - name: self + - name: batch + type: List[dict] + doc: The batch of transcriptions to process. + outputs: [] + lineno: 575 + has_varargs: false + has_kwargs: false + get_tasks: + name: get_tasks + doc: Get the tasks to perform. + parameters: + - name: self + outputs: + - doc: The tasks to perform. + type: List[BaseTask] + lineno: 453 + has_varargs: false + has_kwargs: false + do_tasks: + name: do_tasks + doc: Perform the tasks. Should be used if no multiprocessing queue is given + to a transcriber. + parameters: + - name: self + outputs: [] + lineno: 463 + has_varargs: false + has_kwargs: false + get_results: + name: get_results + doc: Get the results of the tasks. The stored results are then cleared. + parameters: + - name: self outputs: - - default: '' - lineno: 46 + - doc: The results of the tasks. + type: List[Tuple[bool, Tuple[str, str]]] + lineno: 471 + has_varargs: false + has_kwargs: false + load: + name: load + doc: Load the transcriber. Must be called before transcribing. + parameters: + - name: self + outputs: [] + lineno: 695 + has_varargs: false + has_kwargs: false transcribe: name: transcribe - doc: 'Transcribe audio files into text files and collect additional data. The - end result is a directory of transcribed - - text files and a dataframe containing the following columns: - - - * audio_file - The audio file path. - - * transcription_file - The transcribed text file name in the output directory. - - * language - The detected language in the audio file. - - * language_probability - The detected language probability. - - * duration - The duration (in seconds) of the audio file (only if `audio_duration` - is set to True).' + doc: "Transcribe audio files into text files and collect additional data. The\ + \ end result is a directory of transcribed\ntext files and a dataframe containing\ + \ the following columns:\n\n* audio_file - The audio file path.\n* transcription_file\ + \ - The transcribed text file name in the output directory.\n\nThe transcription\ + \ is based on Huggingface's ASR pipeline -\nhttps://huggingface.co/transformers/main_classes/pipelines.html#transformers.AutomaticSpeechRecognitionPipeline\ + \ and\nis tested with OpenAI's Whisper models - https://huggingface.co/openai.\n\ + \nIf one of the speaker diarization parameters are given (either `speech_diarization`\ + \ or\n`speech_diarize_per_channel`), the transcription will be written in\ + \ a conversation format, where each speaker will\nbe written in a separate\ + \ line::\n\n speaker_1: text\n speaker_2: text\n speaker_1: text\n\ + \ ..." parameters: - name: data_path - type: Union[str, List[str]] + type: Union[str, Path, List[Union[str, Path]]] doc: A directory of audio files or a single file or a list of files to transcribe. - name: output_directory type: str - doc: Path to a directory to save all transcribed audio files. + doc: Path to a directory to save all transcribed audio files. If not given, + will save the transcribed files in a temporary directory. + default: null - name: model_name type: str - doc: 'One of the official model names of Whisper: {''tiny.en'', ''tiny'', - ''base.en'', ''base'', ''small.en'', ''small'', ''medium.en'', ''medium'', - ''large-v1'', ''large-v2'', ''large''} or a full name of a fine-tuned whisper - model from the huggingface hub.' - default: base + doc: 'The model name to use. Should be a model from the OpenAI''s Whisper + models for best results (for example "tiny", "base", "large", etc.). See + here for more information: https://huggingface.co/openai?search_models=whisper.' + default: openai/whisper-tiny - name: device - type: Literal[, , ] - doc: Device to load the model. Can be one of {"cuda", "cpu"}. Default will - prefer "cuda" if available. To use a specific GPU or more than one GPU, - pass the `device_index` argument via the `init_kwargs`. - default: auto - - name: compute_type type: str - doc: 'The data type to use for computation. For more information, check https://opennmt.net/CTranslate2/quantization.html. - Default: "default" - will use the default type depending on the device used.' - default: default - - name: language + doc: The device to use for inference. If not given, will use GPU if available. + default: null + - name: use_flash_attention_2 + type: bool + doc: 'Whether to use the Flash Attention 2 implementation. It can be used + only with one of the following GPUs: Nvidia H series and Nvidia A series. + T4 support will be available soon.' + default: null + - name: use_better_transformers + type: bool + doc: Whether to use the Better Transformers library to further optimize the + model. Should be used for all use cases that do not support flash attention + 2. + default: null + - name: assistant_model + type: str + doc: 'The assistant model name to use for inference. Notice that the optimizations + (flash attention 2 and better transformers) will be applied for the assistant + as well. Should be a model from Huggingface''s distil-whisper (see here + for more information: https://github.com/huggingface/distil-whisper).' + default: null + - name: max_new_tokens + type: int + doc: The maximum number of new tokens to generate. This is used to limit the + generation length. Default is 128 tokens. + default: 128 + - name: chunk_length_s + type: int + doc: The audio chunk to split the audio to (in seconds). Default is 30 seconds. + default: 30 + - name: batch_size + type: int + doc: The batch size to use for inference. Default is 2. + default: 8 + - name: spoken_language type: str - doc: 'The spoken language to force Whisper the output language. If None, the - Whisper model will automatically predict the output langauge. Default: None.' + doc: Aim whisper to know what language is spoken. If None, it will try to + detect it. default: null - name: translate_to_english type: bool - doc: 'Whether to translate the English post transcription. Default: False.' + doc: Whether to translate the transcriptions to English. default: false - name: speech_diarization type: Dict[str, List[Tuple[float, float, str]]] doc: 'A speech diarization dictionary with the file names to transcribe as keys and their diarization as value. The diarization is a list of tuples: - (start, end, speaker). The transcription result will be in the following - format: "{speaker}: text text text.". Files with missing diarizations will - print a warning. Pay attention the diarization must be for the entire duration - of the audio file (as long as Whisper is predicting words up until then).' + (start, end, speaker). An example for a diarization dictionary::' default: null - - name: audio_duration - type: bool - doc: 'Whether to include the audio files duration (in seconds). The estimated - duration is from bitrate and may be inaccurate. Default: False.' - default: false - - name: init_kwargs - type: dict - doc: Additional `WhisperModel.__init__` keyword arguments to use. + - name: speech_diarize_per_channel + type: int + doc: 'Perform speech diarization per channel. Each speaker is expected to + belong to a separate channel in the audio. Notice: This will make the transcription + slower as each channel wil be transcribed separatly. If a speech diarization + is passed (via the `speech_diarization` parameter), this parameter is ignored.' default: null - - name: transcribe_kwargs - type: dict - doc: Additional `WhisperModel.transcribe` keyword arguments to use. + - name: speaker_labels + type: List[str] + doc: A list of speaker labels by channel order to use for writing the transcription + with respect to per channel speech diarization. This won't be used together + with a given speech diarization (via the `speech_diarization` parameter). default: null + - name: use_multiprocessing + type: Union[bool, int] + doc: 'Whether to use multiprocessing to transcribe the audio files. Can be + either a boolean value or an integer. If `True`, will use the default amount + of workers (3): 1 for transcription, 1 for batch processing and 1 for task + completion (such as speech diarization and writing to files). To control + the amount of tasks completion workers, an integer can be provided to specify + the amount of workers. `False`, will use a single process. Default is `False`.' + default: false - name: verbose type: bool - doc: 'Whether to present logs of a progress bar and errors. Default: False.' + doc: Whether to print the progress of the transcription. Default is `False`. default: false + outputs: [] + lineno: 1097 + has_varargs: false + has_kwargs: false + audio_iterator: + name: audio_iterator + doc: '' + parameters: [] + outputs: + - type: Generator[Union[dict, str], None, None] + lineno: 804 + has_varargs: false + has_kwargs: false + batch_iterator: + name: batch_iterator + doc: '' + parameters: [] outputs: - - doc: 'A tuple of:' - default: '' - lineno: 135 + - type: Generator[List[Union[dict, str]], None, None] + lineno: 816 + has_varargs: false + has_kwargs: false + open_mpi_handler: + name: open_mpi_handler + doc: '' + parameters: + - name: worker_inputs + type: List[str] + - name: root_worker_inputs + type: Dict[str, Any] + default: null + outputs: [] + lineno: 957 + has_varargs: false + has_kwargs: false + decorator: + name: decorator + doc: '' + parameters: + - name: handler + outputs: [] + lineno: 969 + has_varargs: false + has_kwargs: false + wrapper: + name: wrapper + doc: '' + parameters: [] + outputs: [] + lineno: 974 + has_varargs: false + has_kwargs: true description: Transcribe audio files into text files default_handler: transcribe disable_auto_mount: false diff --git a/transcribe/item.yaml b/transcribe/item.yaml index 28bc5a1c0..d53341ff2 100644 --- a/transcribe/item.yaml +++ b/transcribe/item.yaml @@ -21,8 +21,10 @@ spec: image: mlrun/mlrun kind: job requirements: - - openai-whisper + - transformers - tqdm + - torchaudio + - torch + - accelerate url: '' -version: 0.0.2 -test_valid: True +version: 1.0.0 \ No newline at end of file diff --git a/transcribe/requirements.txt b/transcribe/requirements.txt index 47af1e515..d16bfc9dd 100644 --- a/transcribe/requirements.txt +++ b/transcribe/requirements.txt @@ -1,3 +1,5 @@ -faster-whisper +transformers +torch +torchaudio tqdm -librosa \ No newline at end of file +accelerate \ No newline at end of file diff --git a/transcribe/test_transcribe.py b/transcribe/test_transcribe.py index 9f89cddbb..f70b3856d 100644 --- a/transcribe/test_transcribe.py +++ b/transcribe/test_transcribe.py @@ -14,13 +14,13 @@ # import os import pathlib -import sys import tempfile from difflib import SequenceMatcher import mlrun import pytest + expected_outputs = [ "This is a speech to text test.", "In the heart of the stadium, " @@ -29,24 +29,21 @@ "The crowd roars, a symphony of passion, " "as the game writes its unpredictable story on the field of destiny.", ] -whisper_models = [ - "tiny.en", - "tiny", - "base.en", - "base", +models = [ + + "openai/whisper-tiny", ] -@pytest.mark.skipif( - condition=sys.version_info[:2] < (3, 8), - reason="whisper requires python 3.8 and above" -) -@pytest.mark.parametrize("model_name", whisper_models) +@pytest.mark.skipif(os.system("which ffmpeg") != 0, reason="ffmpeg not installed") +@pytest.mark.parametrize("model_name", models) @pytest.mark.parametrize("audio_path", ["./data", "./data/speech_01.mp3"]) def test_transcribe(model_name: str, audio_path: str): # Setting variables and importing function: artifact_path = tempfile.mkdtemp() - transcribe_function = mlrun.import_function("function.yaml") + project = mlrun.get_or_create_project("test") + transcribe_function = project.set_function("transcribe.py", "transcribe", kind="job", image="mlrun/mlrun") + # transcribe_function = mlrun.import_function("function.yaml") temp_dir = tempfile.mkdtemp() # Running transcribe function: @@ -56,7 +53,6 @@ def test_transcribe(model_name: str, audio_path: str): "data_path": audio_path, "model_name": model_name, "device": "cpu", - "compute_type": "int8", "output_directory": temp_dir, }, local=True, diff --git a/transcribe/transcribe.py b/transcribe/transcribe.py index bcd37f5c5..9cabcb1e8 100644 --- a/transcribe/transcribe.py +++ b/transcribe/transcribe.py @@ -1,4 +1,4 @@ -# Copyright 2023 Iguazio +# Copyright 2024 Iguazio # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,16 +11,944 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging import operator -import pathlib +import os +import tempfile from functools import reduce, wraps -from typing import Any, Dict, List, Literal, NamedTuple, Tuple, Union +from multiprocessing import Process, Queue +from pathlib import Path +from typing import Any, Dict, Generator, List, Literal, NamedTuple, Tuple, Union -import faster_whisper import pandas as pd +import torch +import torchaudio from tqdm import tqdm +from transformers import ( + AutomaticSpeechRecognitionPipeline, + AutoModelForCausalLM, + pipeline, +) +from transformers.utils import is_flash_attn_2_available + + +class BaseTask: + """ + A task to write the transcription to file. + """ + + def __init__( + self, audio_file: Path, transcription_output: Union[dict, str], text_file: Path + ): + """ + Initialize the task. + + :param audio_file: Path to the audio file that was transcribed. + :param transcription_output: The transcription output from the pipeline. String means an exception was raised. + :param text_file: Path to the text file to write the transcription to. + """ + # Store the parameters: + self._audio_file = audio_file + self._transcription_output = transcription_output + self._text_file = text_file + + # Prepare the error variable: + self._error: str = None + + def do_task(self): + """ + Try to perform the task storing an error if occurred. + """ + if isinstance(self._transcription_output, str): + self._error = self._transcription_output + return + try: + self._do_task() + except Exception as exception: + self._error = str(exception) + + def is_failed(self) -> bool: + """ + Check if the task failed. + + :returns: Whether the task failed. + """ + return self._error is not None + + def get_result(self) -> Tuple[str, str]: + """ + Get the result of the task. If the task failed, the error will be returned, otherwise, the result will be the + text file name. + + :returns: The task's result. + """ + if self.is_failed(): + return self._audio_file.name, self._error + return self._audio_file.name, self._text_file.name + + def to_tuple(self) -> Tuple[str, dict]: + """ + Convert the task to a tuple to reconstruct it later (used for multiprocessing to pass in queue). + + :returns: The converted task. + """ + return self.__class__.__name__, { + "audio_file": self._audio_file, + "transcription_output": self._transcription_output, + "text_file": self._text_file, + } + + def _do_task(self): + """ + Perform the task - write the transcription to the stored file path. + """ + # Checking for no duplications: + i = 1 + while self._text_file.exists(): + i += 1 + self._text_file = ( + self._text_file.parent + / f"{self._text_file.stem.rsplit('_', 1)[0]}_{i}{self._text_file.suffix}" + ) + + # Make sure all directories are created: + self._text_file.parent.mkdir(exist_ok=True, parents=True) + + # Write to file: + with open(self._text_file, "w") as fp: + fp.write(self._transcription_output["text"]) + + +class SpeechDiarizationTask(BaseTask): + """ + A task to write the transcription to file with respect to a given speech diarization. + """ + + class _DiarizationSegment(NamedTuple): + """ + A speech diarization segment. + """ + + start: float + end: float + speaker: str + + class _WordTimestamp(NamedTuple): + """ + A word with its start and end timestamps. + """ + + start: float + end: float + text: str + + def __init__( + self, + audio_file: Path, + transcription_output: dict, + text_file: Path, + speech_diarization: List[Tuple[float, float, str]], + ): + """ + Initialize the task. + + :param audio_file: Path to the audio file that was transcribed. + :param transcription_output: The transcription output from the pipeline. + :param text_file: Path to the text file to write the transcription to. + :param speech_diarization: A speech diarization as a list of tuples: (start, end, speaker). + """ + super().__init__( + audio_file=audio_file, + transcription_output=transcription_output, + text_file=text_file, + ) + self._speech_diarization = speech_diarization + self._segments: List[SpeechDiarizationTask._DiarizationSegment] = None + self._last_chosen_index = 0 + + def to_tuple(self) -> Tuple[str, dict]: + """ + Convert the task to a tuple to reconstruct it later (used for multiprocessing to pass in queue). + + :returns: The converted task. + """ + task_class, task_kwargs = super().to_tuple() + return task_class, { + **task_kwargs, + "speech_diarization": self._speech_diarization, + } + + def _do_task(self): + """ + Perform the task - write the transcription to the stored file path with respect to the given speech diarization. + """ + # Check if a speech diarization is given, if not, just write the transcription to file: + if not self._speech_diarization: + super()._do_task() + return + + # Cast the chunks to word timestamps tuples: + words = [ + SpeechDiarizationTask._WordTimestamp( + start=chunk["timestamp"][0], + end=chunk["timestamp"][1], + text=chunk["text"], + ) + for chunk in self._transcription_output["chunks"] + ] + + # Cast speech diarization to segments tuples: + self._segments = [ + SpeechDiarizationTask._DiarizationSegment(*segment) + for segment in self._speech_diarization + ] + + # Try to match the Whisper model predicted timestamps to the closest diarization segment (closest diarization + # segment will be the most overlapping with the word, and if there is no overlap, the closest segment to the + # word): + speaker = self._segments[self._last_chosen_index].speaker + text = f"{speaker}:" + for word in words: + # Get the next diarization segment: + self._get_next_segment(word=word) + # Check if the segment is of the same speaker: + if self._segments[self._last_chosen_index].speaker == speaker: + # Collect the word: + text += word.text + else: + # Append a newline and update the new speaker: + speaker = self._segments[self._last_chosen_index].speaker + text += f"\n{speaker}:{word.text}" + + # Update the transcription output with the new text to write it to file: + self._transcription_output["text"] = text + super()._do_task() + + def _get_next_segment( + self, + word: _WordTimestamp, + ): + """ + Get the next diarization segment the given word falls into. The `self._last_chosen_index` will be updated + accordingly. + + :param word: The word timestamp to match to the next segment. + """ + # If the last chosen segment is the last segment, return it: + if self._last_chosen_index == len(self._segments) - 1: + return + + # Get the last chosen diarization segment: + last_chosen = self._segments[self._last_chosen_index] + + # None value may appear if the word is the last word in the audio file, or it was split during inference. In + # that case, we'll set the last segment: + if word.end is None: + self._last_chosen_index = len(self._segments) - 1 + return + + # If the word ends before the last chosen segment: + if word.end <= last_chosen.start: + # Then it is still the closest segment + return + + # We check if it ends inside the last chosen segment: + if word.end < last_chosen.end: + # Then it still is the closest segment + return + + # The word ends after the segment, we need to collect all next segments up until the word ends before them: + possible_segments = [self._last_chosen_index] + for i in range(self._last_chosen_index + 1, len(self._segments)): + if word.end > self._segments[i].end: + possible_segments.append(i) + continue + possible_segments.append(i) + break + + # Check for the most overlapping option: + best_overlap = 0 + most_overlapping_segment_index = None + for i in possible_segments: + # If the word starts before segment: + if word.start <= self._segments[i].start: + # If it ends before the segment, there is an overlap from the start of the segment to the end of the + # word: + if word.end < self._segments[i].end: + overlap = word.end - self._segments[i].start + else: + # The word is wrapping the segment, the overlap is the segment's length: + overlap = self._segments[i].end - self._segments[i].start + # The word starts in segment, check if the word ends in it: + elif word.end < self._segments[i].end: + # The overlap is the word's length: + overlap = word.end - word.start + # The word start in segment but ends after it, the overlap is from the word's start to the segment's end: + else: + overlap = self._segments[i].end - word.start + # Check for new best overlap: + if overlap > best_overlap: + best_overlap = overlap + most_overlapping_segment_index = i + if most_overlapping_segment_index is not None: + self._last_chosen_index = most_overlapping_segment_index + return + + # If there is no overlapping segment, return the closest segment: + best_distance = None + closest_segment_index = None + for i in possible_segments: + distance = ( + word.start - self._segments[i].end + if word.start > self._segments[i].end + else self._segments[i].start - word.end + ) + if best_distance is None or distance < best_distance: + best_distance = distance + closest_segment_index = i + self._last_chosen_index = closest_segment_index + + +class SpeechDiarizationPerChannelTask(BaseTask): + """ + A task to write the transcription to file with respect to a given speech diarization per channel. + """ + + class _WordTimestamp(NamedTuple): + """ + A word with its start and end timestamps and speaker label (channel the word was taken from). + """ + + start: float + end: float + speaker: str + text: str + + def __init__(self, audio_file: Path, text_file: Path): + """ + Initialize the task. + + :param audio_file: Path to the audio file that was transcribed. + :param text_file: Path to the text file to write the transcription to. + """ + super().__init__( + audio_file=audio_file, transcription_output={}, text_file=text_file + ) + self._transcription_output_channels: List[Tuple[str, dict]] = [] + + @property + def transcription_output_channels(self) -> List[Tuple[str, dict]]: + """ + Get the transcription output channels. + + :returns: The transcription output channels. + """ + return self._transcription_output_channels + + def do_task(self): + """ + Try to perform the task storing an error if occurred. + """ + for _, channel_output in self._transcription_output_channels: + if isinstance(channel_output, str): + self._error = self._transcription_output_channels + return + super().do_task() + + def to_tuple(self) -> Tuple[str, dict]: + """ + Convert the task to a tuple to reconstruct it later (used for multiprocessing to pass in queue). + + :returns: The converted task. + """ + task_class, task_kwargs = super().to_tuple() + task_kwargs.pop("transcription_output") + return task_class, task_kwargs + + def _do_task(self): + """ + Perform the task - write the transcription to the stored file path with respect to the given speech diarization + per channel. + """ + # Cast the chunks to word timestamps tuples: + words_per_channel = [ + [ + SpeechDiarizationPerChannelTask._WordTimestamp( + start=chunk["timestamp"][0], + end=chunk["timestamp"][1], + speaker=speaker, + text=chunk["text"], + ) + for chunk in output["chunks"] + ] + for speaker, output in self._transcription_output_channels + ] + + # Merge and sort the words per channel by their start time: + words = operator.add(*words_per_channel) + words.sort() + + # Write the transcription to file: + current_speaker = words[0].speaker + text = f"{current_speaker}:" + for word in words: + # Check if the word's speaker is different from the current one: + if word.speaker != current_speaker: + # Append a newline and update the new speaker: + current_speaker = word.speaker + text += f"\n{current_speaker}:" + # Collect the word: + text += word.text + + # Update the transcription output with the new text to write it to file: + self._transcription_output["text"] = text + super()._do_task() + + +class BatchProcessor: + """ + A batch processor to process batches of transcriptions. The batch processor is creating tasks and is aimed to be + working along the transcriber. It can be used with multiprocessing queue or run the tasks directly using the + associated methods. + """ + + def __init__(self, audio_files: List[Path], output_directory: Path): + """ + Initialize the batch processor. + + :param audio_files: The list of all audio files to transcribe. + :param output_directory: The output directory to write the transcriptions to. + """ + # Store the parameters: + self._audio_files = audio_files + self._output_directory = output_directory + + # Prepare the batching variables: + self._current_file_index = 0 + self._tasks: List[BaseTask] = [] + self._results: List[Tuple[bool, Tuple[str, str]]] = [] + + def process_batch(self, batch: List[Union[dict, str]]): + """ + Process a batch of transcriptions. Tasks related to the given batch will be created and stored in the batch + processor. + + :param batch: The batch of transcriptions to process. + """ + # Get the relevant files belongs to the given batch: + current_files = self._get_current_files(batch_size=len(batch)) + + # Build the diarization tasks: + self._tasks.extend( + [ + BaseTask( + audio_file=file, + transcription_output=batch[i], + text_file=self._output_directory / f"{file.stem}.txt", + ) + for i, file in enumerate(current_files) + ] + ) + + def get_tasks(self) -> List[BaseTask]: + """ + Get the tasks to perform. + + :returns: The tasks to perform. + """ + tasks = self._tasks + self._tasks = [] + return tasks + + def do_tasks(self): + """ + Perform the tasks. Should be used if no multiprocessing queue is given to a transcriber. + """ + for task in self.get_tasks(): + task.do_task() + self._results.append((task.is_failed(), task.get_result())) + + def get_results(self) -> List[Tuple[bool, Tuple[str, str]]]: + """ + Get the results of the tasks. The stored results are then cleared. + + :returns: The results of the tasks. + """ + results = self._results + self._results = [] + return results + + def _get_current_files(self, batch_size: int) -> List[Path]: + """ + Get the current files to process. + + :param batch_size: The batch size to progress the current file index. + + :returns: The current files to process. + """ + end_index = ( + self._current_file_index + batch_size + if self._current_file_index + batch_size < len(self._audio_files) + else len(self._audio_files) + ) + current_files = self._audio_files[self._current_file_index : end_index] + self._current_file_index = end_index + return current_files + + +class SpeechDiarizationBatchProcessor(BatchProcessor): + """ + A batch processor to process batches of transcriptions with respect to a given speech diarization. The batch + processor is creating tasks and is aimed to be working along the transcriber. It can be used with multiprocessing + queue or run the tasks directly using the associated methods. + """ + + def __init__( + self, audio_files: List[Path], output_directory: Path, speech_diarization: dict + ): + """ + Initialize the batch processor. + + :param audio_files: The list of all audio files to transcribe. + :param output_directory: The output directory to write the transcriptions to. + :param speech_diarization: A speech diarization dictionary to pass along with each processed batch. + """ + super().__init__(audio_files=audio_files, output_directory=output_directory) + self._speech_diarization = speech_diarization + self._audio_files = audio_files + + def process_batch(self, batch: List[dict]): + """ + Process a batch of transcriptions. Tasks related to the given batch will be created and stored in the batch + processor. + + :param batch: The batch of transcriptions to process. + """ + # Get the relevant files belongs to the given batch: + current_files = self._get_current_files(batch_size=len(batch)) + + # Build the diarization tasks: + self._tasks.extend( + [ + SpeechDiarizationTask( + audio_file=file, + transcription_output=batch[i], + text_file=self._output_directory / f"{file.stem}.txt", + speech_diarization=self._speech_diarization.get(file.name), + ) + for i, file in enumerate(current_files) + ] + ) + + +class PerChannelSpeechDiarizationBatchProcessor(BatchProcessor): + """ + A batch processor to process batches of transcriptions per channel. The batch processor is creating tasks with the + selected amount of channels given and is aimed to be working along the transcriber. It can be used with + multiprocessing queue or run the tasks directly using the associated methods. + """ + + def __init__( + self, + audio_files: List[Path], + output_directory: Path, + n_channels: int, + speakers: List[str], + ): + """ + Initialize the batch processor. + + :param audio_files: The list of all audio files to transcribe. + :param output_directory: The output directory to write the transcriptions to. + :param n_channels: The number of channels in each audio file to transcribe. + :param speakers: The speakers labels to use for each channel. + """ + super().__init__(audio_files=audio_files, output_directory=output_directory) + + # Store the parameters: + self._n_channels = n_channels + self._speakers = speakers + + # Prepare a channel buffer to store the channels until the current task created is fully covered: + self._task_in_process: SpeechDiarizationPerChannelTask = None + + def process_batch(self, batch: List[dict]): + """ + Process a batch of transcriptions. Tasks related to the given batch will be created and stored in the batch + processor. + + :param batch: The batch of transcriptions to process. + """ + # Go over the batch and create the tasks: + for output in batch: + # Check if there is a task in process: + if not self._task_in_process: + # Create a new task: + self._task_in_process = SpeechDiarizationPerChannelTask( + audio_file=self._audio_files[self._current_file_index], + text_file=self._output_directory + / f"{self._audio_files[self._current_file_index].stem}.txt", + ) + # Get the channel's speaker: + speaker = self._speakers[ + len(self._task_in_process.transcription_output_channels) + ] + # Collect the channel into the processed task: + self._task_in_process.transcription_output_channels.append( + (speaker, output) + ) + # Check if the task is fully covered (all channels are collected): + if ( + len(self._task_in_process.transcription_output_channels) + == self._n_channels + ): + # Collect the task and reset the task in process: + self._tasks.append(self._task_in_process) + self._current_file_index += 1 + self._task_in_process = None + + +class Transcriber: + """ + A transcription wrapper for the Huggingface's ASR pipeline - + https://huggingface.co/transformers/main_classes/pipelines.html#transformers.AutomaticSpeechRecognitionPipeline to + use with OpenAI's Whisper models - https://huggingface.co/openai. + """ + + def __init__( + self, + model_name: str, + device: str = None, + use_flash_attention_2: bool = None, + use_better_transformers: bool = None, + assistant_model: str = None, + max_new_tokens: int = 128, + chunk_length_s: int = 30, + batch_size: int = 2, + spoken_language: str = None, + translate_to_english: bool = False, + return_timestamps: Union[bool, Literal["word"]] = False, + per_channel_transcription: int = 0, + ): + """ + Initialize the transcriber. + + :param model_name: The model name to use. Should be a model from the OpenAI's Whisper models for + best results (for example "tiny", "base", "large", etc.). + :param device: The device to use for inference. If not given, will use GPU if available. + :param use_flash_attention_2: Whether to use the Flash Attention 2 implementation. It can be used only with + one of the following GPUs: Nvidia H series and Nvidia A series. T4 support + will be available soon. + + Note: If both `use_flash_attention_2` and + `use_better_transformers` are `None`, the optimization will be chosen + automatically according to the available resources. + + :param use_better_transformers: Whether to use the Better Transformers library to further optimize the model. + Should be used for all use cases that do not support flash attention 2. + + Note: If both `use_flash_attention_2` and `use_better_transformers` are + `None`, the optimization will be chosen automatically according to the + available resources. + :param assistant_model: The assistant model name to use for inference. Notice that the optimizations + (flash attention 2 and better transformers) will be applied for the assistant + as well. Should be a model from Huggingface's distil-whisper (see here for + more information: https://github.com/huggingface/distil-whisper). + :param max_new_tokens: The maximum number of new tokens to generate. This is used to limit the + generation length. Default is 128 tokens. + :param chunk_length_s: The audio chunk to split the audio to (in seconds). Default is 30 seconds. + :param batch_size: The batch size to use for inference. Default is 2. + :param spoken_language: Aim whisper to know what language is spoken. If None, it will try to detect it + for each chunk. + :param translate_to_english: Whether to translate the transcriptions to English. Default is False. + :param return_timestamps: Whether to return the timestamps of the words. If "word", will return the + timestamps of each word. If True will return the timestamps of each chunk. + Default is False. Aimed to be used for speech diarization. + :param per_channel_transcription: Whether to do per channel transcription. If needed to run per channel + transcription, pass the number of channels expected for each audio file here. + 0 means regular transcription (merge channels). + + Note: If `per_channel_transcription` is not 0, `batch_size` must be treated to + be the number of channels and not audio files. Aimed to be used for per + channel speech diarization. + """ + # Store loading parameters: + self._model_name = model_name + self._device = device + self._use_flash_attention_2 = use_flash_attention_2 + self._use_better_transformers = use_better_transformers + self._max_new_tokens = max_new_tokens + self._chunk_length_s = chunk_length_s + self._batch_size = batch_size + self._return_timestamps = return_timestamps + self._per_channel_transcription = per_channel_transcription + + # Store generation parameters: + self._assistant_model = assistant_model + self._spoken_language = spoken_language + self._translate_to_english = translate_to_english + + # Prepare the transcription objects: + self._transcription_pipeline: AutomaticSpeechRecognitionPipeline = None + self._generate_kwargs: dict = None + + def load(self): + """ + Load the transcriber. Must be called before transcribing. + """ + # Set the device and data type to use (prefer GPU if available): + device = torch.device( + self._device or "cuda" if torch.cuda.is_available() else "cpu" + ) + torch_dtype = torch.float16 if device.type == "cuda" else torch.float32 + + # Choose the optimization to use (in case the user did not specify any): + if ( + self._use_flash_attention_2 is None + and self._use_better_transformers is None + ): + # Prefer to use flash attention 2 if available and cuda device is supported (see GPU names to architecture + # here: https://en.wikipedia.org/wiki/List_of_Nvidia_graphics_processing_units#Tesla): + if device.type == "cuda" and is_flash_attn_2_available(): + cuda_device_name = torch.cuda.get_device_properties(device).name + if any( + cuda_device_name.startswith(gpu_name) + for gpu_name in [ + "NVIDIA A", # For Ampere architecture (e.g. A10, A30, A100) + "NVIDIA H", # For Hopper architecture (e.g. H100) + "NVIDIA L", # For Ada Lovelace architecture (e.g. L4, L40) + "NVIDIA RTX 30", # For Ada Lovelace architecture (RTX 30 series) + "NVIDIA RTX 40", # For Ada Lovelace architecture (RTX 40 series) + "NVIDIA RTX 50", # For Ada Lovelace architecture (RTX 50 series) + # Will be supported soon according to FlashAttention GitHub repo: + # https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features + # "NVIDIA T4", # For Turing architecture (only T4) + # "NVIDIA RTX 20", # For Turing architecture (RTX 20 series) + ] + ): + self._use_flash_attention_2 = True + else: + self._use_better_transformers = True + else: + self._use_better_transformers = True + + # Build the optimizations kwargs: + model_kwargs = { + "low_cpu_mem_usage": True, + "use_safetensors": True, + } + if self._use_flash_attention_2: + if _LOGGER: + _LOGGER.info( + "Using FlashAttention2 optimization - make sure the `flash-attn` package is installed via " + "`pip install -U flash-attn --no-build-isolation`" + ) + model_kwargs["attn_implementation"] = "flash_attention_2" + elif self._use_better_transformers: + if _LOGGER: + _LOGGER.info( + "Using BetterTransformers optimization - make sure the `optimum` package is installed via " + "`pip install -U optimum`" + ) + model_kwargs["attn_implementation"] = "sdpa" + + # Initialize the speech recognition pipeline: + self._transcription_pipeline = pipeline( + task="automatic-speech-recognition", + model=self._model_name, + model_kwargs=model_kwargs.copy(), + batch_size=self._batch_size, + max_new_tokens=self._max_new_tokens, + chunk_length_s=self._chunk_length_s, + return_timestamps=self._return_timestamps, + torch_dtype=torch_dtype, + device=device, + ) + + # Prepare the generation kwargs: + self._generate_kwargs = { + "language": self._spoken_language, + "task": "translate" if self._translate_to_english else "transcribe", + } + + # Initialize the assistant model (if needed): + if self._assistant_model: + assistant_model = AutoModelForCausalLM.from_pretrained( + self._assistant_model, torch_dtype=torch_dtype, **model_kwargs + ) + assistant_model.to(device) + self._generate_kwargs["assistant_model"] = assistant_model + + def transcribe( + self, + audio_files: List[Path], + batch_processor: BatchProcessor = None, + batches_queue: Queue = None, + verbose: bool = False, + ) -> Union[List[List[dict]], None]: + """ + Transcribe the given audio files. The transcriptions will be sent to a queue or a batch processor for further + processing like writing to text files. If no queue or batch processor is given, the transcriptions outputs from + the pipeline will be returned. Otherwise, `None` is returned. + + :param audio_files: The audio files to transcribe. + :param batch_processor: A batch processor. + :param batches_queue: A multiprocessing queue to put the batches in. + :param verbose: Whether to show a progress bar. Default is False. + + :returns: The transcriptions outputs from the pipeline if no queue or batch processor is given, otherwise, + `None`. + """ + # Wrap the audio files with a function to iterate over them via a generator (save memory and runtime with + # Huggingface's pipelines as they preload each input while inference is running): + def audio_iterator() -> Generator[Union[dict, str], None, None]: + if self._per_channel_transcription: + for audio_file in audio_files: + audio, sampling_rate = torchaudio.load(str(audio_file)) + audio = audio.numpy() + for channel in audio: + yield {"raw": channel, "sampling_rate": sampling_rate} + else: + for audio_file in audio_files: + yield str(audio_file) + + # Create a batch iterator: + def batch_iterator() -> Generator[List[Union[dict, str]], None, None]: + batch = [] + for audio in audio_iterator(): + batch.append(audio) + if len(batch) == self._batch_size: + yield batch + batch = [] + if batch: + yield batch + + # Prepare the successes dataframe and errors dictionary to be returned: + outputs = [] + + # Infer through the pipeline: + for input_batch in tqdm( + batch_iterator() if self._batch_size > 1 else audio_iterator(), + desc="Transcribing", + unit="channel" if self._per_channel_transcription else "audio file", + total=( + ( + (len(audio_files) // self._batch_size) + + (len(audio_files) % self._batch_size != 0) + ) + * (self._per_channel_transcription or 1) + ), + disable=not verbose, + ): + # Infer: + try: + output_batch = self._transcription_pipeline( + input_batch, + generate_kwargs=self._generate_kwargs, + ) + except Exception as exception: + # Collect the exception: + output_batch = str(exception) + # Align to batch size: + output_batch = ( + [output_batch] * len(input_batch) + if isinstance(input_batch, list) + else [output_batch] + ) + # To align with batching, if batch size is 1, wrap the output with a list: + if isinstance(output_batch, dict): + output_batch = [output_batch] + # If a batch processor is given, process the batch: + if batch_processor: + # Process it directly: + batch_processor.process_batch(batch=output_batch) + batch_processor.do_tasks() + elif batches_queue: + # Otherwise, queue the batch: + batches_queue.put(output_batch) + else: + # Otherwise, collect the output as is without processing: + outputs.append(output_batch) + + # Check if given a multiprocessing queue or a batch processor: + if batches_queue: + batches_queue.put(_MULTIPROCESSING_STOP_MARK) + + return outputs if not batch_processor else None + + +#: The value to send into multiprocessing queues to stop the process: +_MULTIPROCESSING_STOP_MARK = "STOP" + + +def _multiprocessing_process_batches( + batch_processor: BatchProcessor, + batches_queue: Queue, + tasks_queue: Queue, + n_task_completers: int, +): + """ + Process the batches in the given batches queue and put the tasks in the given tasks queue. The function will stop + when the given batches queue will receive the stop mark. It is aimed to be used with multiprocessing as a process. + + :param batch_processor: A batch processor to process the batches. + :param batches_queue: A queue to get the batches from. + :param tasks_queue: A queue to put the tasks in. + :param n_task_completers: The number of task completers (processes that run the `_multiprocessing_complete_tasks` + function). A stop mark will be sent to the tasks queue for each task completer. + """ + while True: + # Get the batch: + batch: List[dict] = batches_queue.get() + if batch == _MULTIPROCESSING_STOP_MARK: + break + + # Process the batch: + batch_processor.process_batch(batch=batch) + + # Get the tasks: + tasks = batch_processor.get_tasks() + + # Queue the tasks: + for task in tasks: + tasks_queue.put(task.to_tuple()) + + # Mark the end of the batches: + for _ in range(n_task_completers): + tasks_queue.put(_MULTIPROCESSING_STOP_MARK) + + +def _multiprocessing_complete_tasks(tasks_queue: Queue, results_queue: Queue): + """ + Complete the tasks in the given queue and put the results in the given results queue. The function will stop when + the given tasks queue will receive the stop mark. It is aimed to be used with multiprocessing as a process. + + :param tasks_queue: A queue to get the tasks from. + :param results_queue: A queue to put the results in. + """ + tasks_map = { + BaseTask.__name__: BaseTask, + SpeechDiarizationTask.__name__: SpeechDiarizationTask, + SpeechDiarizationPerChannelTask.__name__: SpeechDiarizationPerChannelTask, + } + + while True: + # Get the task: + task = tasks_queue.get() + if task == _MULTIPROCESSING_STOP_MARK: + break + + # Reconstruct the task: + task_class, task_kwargs = task + task = tasks_map[task_class](**task_kwargs) + + # Complete the task: + task.do_task() + results_queue.put((task.is_failed(), task.get_result())) + + # Mark the end of the tasks: + results_queue.put(_MULTIPROCESSING_STOP_MARK) + # Get the global logger: _LOGGER = logging.getLogger() @@ -55,7 +983,7 @@ def wrapper(**kwargs): continue if isinstance(input_argument, str): input_argument = _get_audio_files( - data_path=pathlib.Path(input_argument).absolute() + data_path=Path(input_argument).absolute() ) if len(input_argument) < size: raise ValueError( @@ -86,17 +1014,51 @@ def wrapper(**kwargs): # Run the worker: output = handler(**kwargs) + # Save the output directory of this worker: + output_directory = Path(output[0]) + # Send the output to the root rank (rank #0): output = comm.gather(output, root=0) + + # Join the data from all workers: if rank == 0: - # Join the outputs: context.logger.info("Collecting data from workers to root worker.") - output_directory = output[0][0] + + # Check if there are different output directories: + output_directories = set([Path(out_dir) for out_dir, _, _ in output]) + for r in range(1, size): + # True means the other workers should pass their files to the root worker (rank 0): + comm.send(len(output_directories) != 1, dest=r) + + # If there are different output directories, listen to the other workers: + if len(output_directories) != 1: + # Collect the files from the other workers: + files = [] + for r in range(1, size): + files.extend(comm.recv(source=r)) + # Write the files to the root worker's output directory: + for file_name, file_content in files: + with open(output_directory / file_name, "w") as f: + f.write(file_content) + + # Concatenate the dataframes: dataframe = pd.concat(objs=[df for _, df, _ in output], axis=0) + + # Concatenate the errors dictionaries: errors_dictionary = reduce( operator.ior, [err for _, _, err in output], {} ) - return output_directory, dataframe, errors_dictionary + + return str(output_directory), dataframe, errors_dictionary + + # Listen to rank 0 to see if there are different output directories and this rank need to send its files to + # it: + if comm.recv(source=0): + files = [] + for file in os.listdir(output_directory): + with open(output_directory / file, "r") as f: + files.append((file, f.read())) + comm.send(files, dest=0) return None return wrapper @@ -133,165 +1095,245 @@ def _check_mlrun_and_open_mpi() -> Tuple["mlrun.MLClientCtx", "mpi4py.MPI.Intrac @open_mpi_handler(worker_inputs=["data_path"], root_worker_inputs={"verbose": True}) def transcribe( - data_path: Union[str, List[str]], - output_directory: str, - model_name: str = "base", - device: Literal["cuda", "cpu", "auto"] = "auto", - compute_type: str = "default", - language: str = None, + # Input / Output kwargs: + data_path: Union[str, Path, List[Union[str, Path]]], + output_directory: str = None, + # Model loading kwargs: + model_name: str = "openai/whisper-tiny", + device: str = None, + use_flash_attention_2: bool = None, + use_better_transformers: bool = None, + # Generation kwargs: + assistant_model: str = None, + max_new_tokens: int = 128, + chunk_length_s: int = 30, + batch_size: int = 8, + spoken_language: str = None, translate_to_english: bool = False, + # Diarization kwargs: speech_diarization: Dict[str, List[Tuple[float, float, str]]] = None, - audio_duration: bool = False, - init_kwargs: dict = None, - transcribe_kwargs: dict = None, + speech_diarize_per_channel: int = None, + speaker_labels: List[str] = None, + # Other kwargs: + use_multiprocessing: Union[bool, int] = False, verbose: bool = False, -) -> Tuple[str, pd.DataFrame, dict]: +): """ Transcribe audio files into text files and collect additional data. The end result is a directory of transcribed text files and a dataframe containing the following columns: * audio_file - The audio file path. * transcription_file - The transcribed text file name in the output directory. - * language - The detected language in the audio file. - * language_probability - The detected language probability. - * duration - The duration (in seconds) of the audio file (only if `audio_duration` is set to True). - - :param data_path: A directory of audio files or a single file or a list of files to transcribe. - :param output_directory: Path to a directory to save all transcribed audio files. - :param model_name: One of the official model names of Whisper: {'tiny.en', 'tiny', 'base.en', 'base', - 'small.en', 'small', 'medium.en', 'medium', 'large-v1', 'large-v2', 'large'} or a - full name of a fine-tuned whisper model from the huggingface hub. - :param device: Device to load the model. Can be one of {"cuda", "cpu"}. Default will prefer "cuda" - if available. To use a specific GPU or more than one GPU, pass the `device_index` - argument via the `init_kwargs`. - :param compute_type: The data type to use for computation. For more information, check - https://opennmt.net/CTranslate2/quantization.html. Default: "default" - will use the - default type depending on the device used. - :param language: The spoken language to force Whisper the output language. If None, the Whisper model - will automatically predict the output langauge. Default: None. - :param translate_to_english: Whether to translate the English post transcription. Default: False. - :param speech_diarization: A speech diarization dictionary with the file names to transcribe as keys and their - diarization as value. The diarization is a list of tuples: (start, end, speaker). - The transcription result will be in the following format: - "{speaker}: text text text.". Files with missing diarizations will print a warning. - Pay attention the diarization must be for the entire duration of the audio file (as - long as Whisper is predicting words up until then). - :param audio_duration: Whether to include the audio files duration (in seconds). The estimated duration is - from bitrate and may be inaccurate. Default: False. - :param init_kwargs: Additional `WhisperModel.__init__` keyword arguments to use. - :param transcribe_kwargs: Additional `WhisperModel.transcribe` keyword arguments to use. - :param verbose: Whether to present logs of a progress bar and errors. Default: False. - - :returns: A tuple of: - - * Path to the output directory. - * A dataframe dataset of the transcribed file names. - * A dictionary of errored files that were not transcribed. + + The transcription is based on Huggingface's ASR pipeline - + https://huggingface.co/transformers/main_classes/pipelines.html#transformers.AutomaticSpeechRecognitionPipeline and + is tested with OpenAI's Whisper models - https://huggingface.co/openai. + + If one of the speaker diarization parameters are given (either `speech_diarization` or + `speech_diarize_per_channel`), the transcription will be written in a conversation format, where each speaker will + be written in a separate line:: + + speaker_1: text + speaker_2: text + speaker_1: text + ... + + :param data_path: A directory of audio files or a single file or a list of files to transcribe. + :param output_directory: Path to a directory to save all transcribed audio files. If not given, will save + the transcribed files in a temporary directory. + :param model_name: The model name to use. Should be a model from the OpenAI's Whisper models for + best results (for example "tiny", "base", "large", etc.). See here for more + information: https://huggingface.co/openai?search_models=whisper. + :param device: The device to use for inference. If not given, will use GPU if available. + :param use_flash_attention_2: Whether to use the Flash Attention 2 implementation. It can be used only with + one of the following GPUs: Nvidia H series and Nvidia A series. T4 support + will be available soon. + + Note: If both `use_flash_attention_2` and + `use_better_transformers` are `None`, the optimization will be chosen + automatically according to the available resources. + + :param use_better_transformers: Whether to use the Better Transformers library to further optimize the model. + Should be used for all use cases that do not support flash attention 2. + + Note: If both `use_flash_attention_2` and `use_better_transformers` are + `None`, the optimization will be chosen automatically according to the + available resources. + :param assistant_model: The assistant model name to use for inference. Notice that the optimizations + (flash attention 2 and better transformers) will be applied for the assistant as + well. Should be a model from Huggingface's distil-whisper (see here for more + information: https://github.com/huggingface/distil-whisper). + + Note: Currently an assistant model is only usable with batch size of 1. + :param max_new_tokens: The maximum number of new tokens to generate. This is used to limit the + generation length. Default is 128 tokens. + :param chunk_length_s: The audio chunk to split the audio to (in seconds). Default is 30 seconds. + :param batch_size: The batch size to use for inference. Default is 2. + :param spoken_language: Aim whisper to know what language is spoken. If None, it will try to detect + it. + :param translate_to_english: Whether to translate the transcriptions to English. + :param speech_diarization: A speech diarization dictionary with the file names to transcribe as keys and + their diarization as value. The diarization is a list of tuples: + (start, end, speaker). An example + for a diarization dictionary:: + + { + "audio_file_name": [ + { + "start": 0.0, + "end": 2.0, + "speaker": "Agent", + }, + { + "start": 2.0, + "end": 4.0, + "speaker": "Client", + }, + ... + ], + ... + } + + Note: The diarization must be for the entire duration of the audio file (as long + as Whisper is predicting words up until then. + :param speech_diarize_per_channel: Perform speech diarization per channel. Each speaker is expected to belong to + a separate channel in the audio. Notice: This will make the transcription + slower as each channel wil be transcribed separatly. If a speech diarization + is passed (via the `speech_diarization` parameter), this parameter is + ignored. + :param speaker_labels: A list of speaker labels by channel order to use for writing the + transcription with respect to per channel speech diarization. This won't be + used together with a given speech diarization (via the `speech_diarization` + parameter). + :param use_multiprocessing: Whether to use multiprocessing to transcribe the audio files. Can be either a + boolean value or an integer. If `True`, will use the default amount of workers + (3): 1 for transcription, 1 for batch processing and 1 for task completion (such + as speech diarization and writing to files). To control the amount of tasks + completion workers, an integer can be provided to specify the amount of workers. + `False`, will use a single process. Default is `False`. + :param verbose: Whether to print the progress of the transcription. Default is `False`. """ global _LOGGER # Get the input audio files to transcribe: if verbose: _LOGGER.info("Collecting audio files.") - if isinstance(data_path, str): - data_path = pathlib.Path(data_path).absolute() - audio_files = _get_audio_files(data_path=data_path) - else: - audio_files = data_path + audio_files = _get_audio_files(data_path=data_path) if verbose: _LOGGER.info(f"Collected {len(audio_files)} audio files.") - # Load the whisper model: + # Get the output directory: + if output_directory is None: + if verbose: + _LOGGER.info("No output directory given, using temporary directory.") + output_directory = tempfile.mkdtemp() + output_directory = Path(output_directory).absolute() + output_directory.mkdir(exist_ok=True, parents=True) if verbose: - _LOGGER.info(f"Loading model '{model_name}' - using device '{device}'.") - init_kwargs = init_kwargs or {} - model = faster_whisper.WhisperModel( - model_size_or_path=model_name, + _LOGGER.info(f"Transcriptions will be saved to: {output_directory}") + + # Initialize a batch processor according to user requirements (no speech diarization, given speech diarization, + # speech diarization per channel): + if speech_diarization: + batch_processor = SpeechDiarizationBatchProcessor( + audio_files=audio_files, + output_directory=output_directory, + speech_diarization=speech_diarization, + ) + elif speech_diarize_per_channel: + batch_processor = PerChannelSpeechDiarizationBatchProcessor( + audio_files=audio_files, + output_directory=output_directory, + n_channels=speech_diarize_per_channel, + speakers=speaker_labels, + ) + else: + batch_processor = BatchProcessor( + audio_files=audio_files, + output_directory=output_directory, + ) + + # Initialize the transcription pipeline: + transcriber = Transcriber( device=device, - compute_type=compute_type, - **init_kwargs, + use_flash_attention_2=use_flash_attention_2, + use_better_transformers=use_better_transformers, + assistant_model=assistant_model, + model_name=model_name, + max_new_tokens=max_new_tokens, + chunk_length_s=chunk_length_s, + batch_size=batch_size, + return_timestamps=( + "word" + if speech_diarization is not None or speech_diarize_per_channel is not None + else False + ), + per_channel_transcription=speech_diarize_per_channel or 0, + spoken_language=spoken_language, + translate_to_english=translate_to_english, ) - if verbose: - _LOGGER.info(f"Model loaded successfully.") - # Prepare the successes dataframe and errors dictionary to be returned: + # Run the transcription: + if use_multiprocessing: + results = _parallel_run( + n_workers=use_multiprocessing + if isinstance(use_multiprocessing, int) + else 1, + audio_files=audio_files, + batch_processor=batch_processor, + transcriber=transcriber, + verbose=verbose, + ) + else: + results = _run( + audio_files=audio_files, + batch_processor=batch_processor, + transcriber=transcriber, + verbose=verbose, + ) + + # Process the results: + if verbose: + _LOGGER.info("Summarizing the results.") successes = [] errors = {} - - # Create the output directory: - output_directory = pathlib.Path(output_directory) - output_directory.mkdir(parents=True, exist_ok=True) - - # Prepare the transcribe keyword arguments: - transcribe_kwargs = transcribe_kwargs or {} - transcribe_kwargs["language"] = language - transcribe_kwargs["task"] = "translate" if translate_to_english else "transcribe" - - # Go over the audio files and transcribe: - for audio_file in tqdm( - audio_files, desc="Transcribing", unit="file", disable=not verbose - ): - try: - # Transcribe: - transcription_and_info = _transcribe( - audio_file=audio_file, - model=model, - transcribe_kwargs=transcribe_kwargs, - speech_diarization=_get_diarization( # Get the diarization (if provided). - speech_diarization=speech_diarization, - file_name=audio_file.name, - verbose=verbose, - ), - audio_duration=audio_duration, - ) - # Write the transcription to file: - transcription_file = _save_to_file( - transcription=transcription_and_info[0], - file_name=audio_file.stem, - output_directory=output_directory, - ) - # Note as a success in the list: - successes.append( - [ - audio_file.name, - transcription_file.name, - *transcription_and_info[1:], - ] - ) - except Exception as exception: - # Note the exception as error in the dictionary: - if verbose: - _LOGGER.warning(f"Error in file: '{audio_file.name}'") - errors[str(audio_file.name)] = str(exception) - continue - - # Construct the transcriptions dataframe: - columns = [ - "audio_file", - "transcription_file", - "language", - "language_probability", - ] - if audio_duration: - columns.append("duration") - successes = pd.DataFrame( - successes, - columns=columns, - ) - - # Print the head of the produced dataframe and return: + for is_error, result in results: + if is_error: + errors[result[0]] = result[1] + else: + successes.append(result) + successes = pd.DataFrame(successes, columns=["audio_file", "transcription_file"]) if verbose: _LOGGER.info( f"Done ({successes.shape[0]}/{len(audio_files)})\n" f"Transcriptions summary:\n" f"{successes.head()}" ) + return str(output_directory), successes, errors def _get_audio_files( - data_path: pathlib.Path, -) -> List[pathlib.Path]: + data_path: Union[Path, str, list], +) -> List[Path]: + """ + Get the audio files to transcribe. If a path to a directory is given, all files in the directory will be collected. + + :param data_path: The data path to collect the audio files from. + + :returns: The audio files list. + """ + # Check if given a list of paths: + if isinstance(data_path, list): + audio_files = [] + for path in data_path: + audio_files.extend(_get_audio_files(data_path=path)) + return audio_files + + # Check if given a single string path to cast it to a `pathlib.Path`: + if isinstance(data_path, str): + data_path = Path(data_path).absolute() + # Check if the path is of a directory or a file: if data_path.is_dir(): # Get all files inside the directory: @@ -300,190 +1342,123 @@ def _get_audio_files( audio_files = [data_path] else: raise ValueError( - f"Unrecognized data path. The parameter `data_path` must be either a directory path or a file path. " - f"Given: {str(data_path)} " + f"Unrecognized data path. The parameter `data_path` must be a valid path to either a directory path or a " + f"file. Given: {str(data_path)} " ) return audio_files -class _DiarizationSegment(NamedTuple): - start: float - end: float - speaker: str +def _run( + audio_files: List[Path], + batch_processor: BatchProcessor, + transcriber: Transcriber, + verbose: bool, +) -> List[Tuple[bool, Tuple[str, str]]]: + """ + Run the transcription without multiprocessing. + :param audio_files: The audio files to transcribe. + :param batch_processor: The batch processor to use. + :param transcriber: The transcriber to use. + :param verbose: Verbosity. -def _get_diarization( - speech_diarization: Dict[str, List[Tuple[float, float, str]]], - file_name: str, - verbose: bool, -) -> Union[List[_DiarizationSegment], None]: - diarization = None - if speech_diarization is not None: - diarization = speech_diarization.get(file_name) - if diarization is None: - if verbose: - _LOGGER.warning( - f"Missing speech diarization for the audio file '{file_name}'. Continuing transcribing without " - f"diarization." - ) - diarization = [_DiarizationSegment(*segment) for segment in diarization] - return diarization - - -def _get_next_diarization_segment( - word: faster_whisper.transcribe.Word, - speech_diarization: List[_DiarizationSegment], - last_chosen_index: int, -) -> int: - # Get the last chosen diarization segment: - last_chosen = speech_diarization[last_chosen_index] - - # If the last chosen segment is the last segment, return it: - if last_chosen_index == len(speech_diarization) - 1: - return last_chosen_index - - # If the word ends before the last chosen segment: - if word.end <= last_chosen.start: - # Then it is still the closest segment - return last_chosen_index - - # We check if it ends inside the last chosen segment: - if word.end < last_chosen.end: - # Then it still is the closest segment - return last_chosen_index - - # The word ends after the segment, we need to collect all next segments up until the word ends before them: - possible_segments = [last_chosen_index] - for i in range(last_chosen_index + 1, len(speech_diarization)): - if word.end > speech_diarization[i].end: - possible_segments.append(i) - continue - possible_segments.append(i) - break - - # Check for the most overlapping option: - best_overlap = 0 - overlapping_segment = None - for i in possible_segments: - overlap = 0 - # If the word starts before segment: - if word.start <= speech_diarization[i].start: - # If it ends before the segment, there is an overlap from the start of the segment to the end of the word: - if word.end < speech_diarization[i].end: - overlap = word.end - speech_diarization[i].start - else: - # The word is wrapping the segment, the overlap is the segment's length: - overlap = speech_diarization[i].end - speech_diarization[i].start - # The word starts in segment, check if the word ends in it: - elif word.end < speech_diarization[i].end: - # The overlap is the word's length: - overlap = word.end - word.start - # The word start in segment but ends after it, the overlap is from the word's start to the segment's end: - else: - overlap = speech_diarization[i].end - word.start - # Check for new best overlap: - if overlap > best_overlap: - best_overlap = overlap - overlapping_segment = i - if overlapping_segment is not None: - return overlapping_segment - - # If there is no overlapping segment, return the closest segment: - best_distance = None - closest_segment = None - for i in possible_segments: - distance = ( - word.start - speech_diarization[i].end - if word.start > speech_diarization[i].end - else speech_diarization[i].start - word.end - ) - if best_distance is None or distance < best_distance: - best_distance = distance - closest_segment = i - return closest_segment - - -def _construct_transcription( - segments: List[faster_whisper.transcribe.Segment], - speech_diarization: List[_DiarizationSegment], -) -> str: - # If there is no diarization, concatenate all segments and return: - if speech_diarization is None: - return " ".join([segment.text for segment in segments]) - - # There is a diarization, try to match the Whisper model predicted timestamps to the closest diarization segment - # (closest diarization segment will be the most overlapping with the word, and if there is no overlap, the closest - # segment to the word): - diarization_index = 0 - speaker = speech_diarization[diarization_index].speaker - text = f"{speaker}:" - for segment in segments: - for word in segment.words: - # Get the next diarization segment: - diarization_index = _get_next_diarization_segment( - word=word, - speech_diarization=speech_diarization, - last_chosen_index=diarization_index, - ) - # Check if the segment is of the same speaker: - if speech_diarization[diarization_index].speaker == speaker: - # Collect the word: - text += word.word - else: - # Append a newline and update the new speaker: - speaker = speech_diarization[diarization_index].speaker - text += f"\n{speaker}:{word.word}" - - return text - - -def _transcribe( - audio_file: pathlib.Path, - model: faster_whisper.WhisperModel, - transcribe_kwargs: dict, - speech_diarization: List[_DiarizationSegment], - audio_duration: bool, -) -> Union[Tuple[str, str, float], Tuple[str, str, float, float]]: - # Transcribe (Segments is a generator, so we cast to list to begin transcription from start to end): - segments, info = model.transcribe( - audio=str(audio_file), - **transcribe_kwargs, - word_timestamps=speech_diarization is not None, + :returns: The collected results. + """ + # Load the transcription pipeline: + if verbose: + _LOGGER.info(f"Loading the transcription pipeline.") + transcriber.load() + if verbose: + _LOGGER.info("Transcription pipeline loaded.") + + # Transcribe the files: + transcriber.transcribe( + audio_files=audio_files, + batch_processor=batch_processor, + verbose=verbose, ) - segments = list(segments) - # Check if speech diarization was provided: - if speech_diarization is None: - text = "".join([segment.text for segment in segments]) - else: - text = _construct_transcription( - segments=segments, - speech_diarization=speech_diarization, + # Return the results: + return batch_processor.get_results() + + +def _parallel_run( + n_workers: int, + audio_files: List[Path], + batch_processor: BatchProcessor, + transcriber: Transcriber, + verbose: bool, +): + """ + Run the transcription with multiprocessing. + + :param n_workers: The amount of workers to use as task completers. + :param audio_files: The audio files to transcribe. + :param batch_processor: The batch processor to use. + :param transcriber: The transcriber to use. + :param verbose: Verbosity. + + :returns: The collected results. + """ + # Initialize the multiprocessing queues: + batches_queue = Queue() + tasks_queue = Queue() + results_queue = Queue() + + # Initialize the multiprocessing processes: + batch_processing_process = Process( + target=_multiprocessing_process_batches, + kwargs={ + "batch_processor": batch_processor, + "batches_queue": batches_queue, + "tasks_queue": tasks_queue, + "n_task_completers": n_workers, + }, + ) + task_completion_processes = [ + Process( + target=_multiprocessing_complete_tasks, + kwargs={"tasks_queue": tasks_queue, "results_queue": results_queue}, ) - text = text.strip() + for _ in range(n_workers) + ] - # Return the transcription text and the additional information: - if audio_duration: - return text.strip(), info.language, info.language_probability, info.duration - return text.strip(), info.language, info.language_probability + # Start the multiprocessing processes: + batch_processing_process.start() + for p in task_completion_processes: + p.start() + # Load the transcription pipeline: + if verbose: + _LOGGER.info(f"Loading the transcription pipeline.") + transcriber.load() + if verbose: + _LOGGER.info("Transcription pipeline loaded.") -def _save_to_file( - transcription: str, file_name: str, output_directory: pathlib.Path -) -> pathlib.Path: - # Prepare the file full path (checking for no duplications): - transcription_file = output_directory / f"{file_name}.txt" - i = 1 - while transcription_file.exists(): - i += 1 - transcription_file = output_directory / f"{file_name}_{i}.txt" + # Transcribe the files: + transcriber.transcribe( + audio_files=audio_files, batches_queue=batches_queue, verbose=verbose + ) - # Make sure all directories are created: - transcription_file.parent.mkdir(exist_ok=True, parents=True) + # Collect the results: + results = [] + stop_marks_counter = 0 + while True: + # Get a result from the queue: + result: Tuple[bool, Tuple[str, str]] = results_queue.get() + if result == _MULTIPROCESSING_STOP_MARK: + stop_marks_counter += 1 + if stop_marks_counter == n_workers: + break + else: + # Collect the result: + results.append(result) - # Write to file: - with open(transcription_file, "w") as fp: - fp.write(transcription) + # Wait for the processes to finish: + results_queue.empty() + batch_processing_process.join() + for p in task_completion_processes: + p.join() - return transcription_file + return results \ No newline at end of file