forked from facebookresearch/seamless_communication
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Mutox Classifier Model (facebookresearch#332)
- Loading branch information
1 parent
586eab1
commit 0727989
Showing
14 changed files
with
818 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
name: mutox | ||
model_type: mutox_classifier | ||
model_arch: mutox | ||
checkpoint: "https://dl.fbaipublicfiles.com/seamless/models/mutox.pt" | ||
input_size: 1024 |
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# MuTox: MuTox: Universal MUltilingual Audio-based TOXicity Dataset and Zero-shot Detector | ||
|
||
MuTox, the first highly multilingual audio-based dataset with toxicity labels. | ||
The dataset consists of 20k audio utterances for English and Spanish, and 4k for | ||
the other 19 languages. To showcase the quality of this dataset, we train the | ||
MuTox audio-based toxicity classifier, which allows zero-shot toxicity detection | ||
across a broad range of languages. This classifier outperforms existing | ||
text-based trainable classifiers by more than 1% AUC, while increasing the | ||
language coverage from 8 to 100+ languages. When compared to a wordlist-based | ||
classifier that covers a similar number of languages, MuTox improves precision | ||
and recall by ∼2.5 times. | ||
|
||
## License | ||
|
||
The mutox code and model are licensed under the MIT license (see MIT_LICENSE | ||
file at the root of seamless_communication). The mutox model depends on SONAR | ||
encoders, most are under the MIT license but a few are under CC-BY-NC license. | ||
See the [SONAR repository](https://github.com/facebookresearch/SONAR) for | ||
details. | ||
|
||
## Dataset Languages. | ||
|
||
- English, | ||
- Spanish, | ||
- Arabic, | ||
- Bengali, | ||
- Mandarin Chinese, | ||
- Dutch, | ||
- French, | ||
- German, | ||
- Hindi, | ||
- Indonesian, | ||
- Italian, | ||
- Japanese, | ||
- Korean, | ||
- Portuguese, | ||
- Russian, | ||
- Swahili, | ||
- Tagalog, | ||
- Thai, | ||
- Turkish, | ||
- Urdu, | ||
- Vietnamese | ||
|
||
## Classifier details. | ||
|
||
We use multi-modal and multilingual | ||
[SONAR](https://github.com/facebookresearch/SONAR) encoders from (Duquenne et | ||
al., 2023). For the classifier, we use variable input sizes for the 3 | ||
feedforward layers (1024, 512, and 128). | ||
|
||
## Classifier Quick Start | ||
|
||
This introduces the MuTox speech toxicity model, this relies on computing the | ||
sonar embedding and then classifying it through the MuTox model. The | ||
`cli/mutox/mutox.py` provides an example of reading a TSV, computing the SONAR | ||
embedding and running the classifier on the results: | ||
|
||
```bash | ||
python -m seamless_communication.cli.toxicity.mutox.mutox_speech --lang fra --audio_column ref_tgt_audio /checkpoint/bokai/seamless/toxity_mitigation/exps_v5/joined_etox/fleurs/s2t/en-xx/fra.tsv /tmp/tesmortt.tsv | ||
``` | ||
|
||
You can also work with text: | ||
|
||
```bash | ||
python -m seamless_communication.cli.toxicity.mutox.mutox_text --lang fra_Latn sentences.txt | ||
``` | ||
|
||
You can also check the mutox example notebook in this directory. | ||
|
||
## Dataset | ||
|
||
The dataset is available in this [file](https://dl.fbaipublicfiles.com/seamless/datasets/mutox.csv). The dataset is licensed under the MIT license (see MIT_LICENSE | ||
file at the root of seamless_communication). | ||
|
||
## Citation | ||
|
||
```bitex | ||
@misc{costajussà2023mutox, | ||
title={MuTox: Universal MUltilingual Audio-based TOXicity Dataset and Zero-shot Detector}, | ||
author={ Marta R. Costa-jussà, Mariano Coria Meglioli, Pierre Andrews, David Dale, Prangthip Hansanti, Elahe Kalbassi, Alex Mourachko, Christophe Ropers, Carleigh Wood}, | ||
year={2023}, | ||
eprint={}, | ||
archivePrefix={arXiv}, | ||
primaryClass={cs.CL} | ||
} | ||
``` |
245 changes: 245 additions & 0 deletions
245
src/seamless_communication/cli/toxicity/mutox/mutox_example.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,245 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Copyright (c) Meta Platforms, Inc. and affiliates\n", | ||
"# All rights reserved.\n", | ||
"#\n", | ||
"# This source code is licensed under the license found in the\n", | ||
"# MIT_LICENSE file in the root directory of this source tree." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# MUTOX toxicity classification\n", | ||
"\n", | ||
"Mutox lets you score speech and text toxicity using a classifier that can score sonar embeddings. In this notebook, we provide an example of encoding speech and text and classifying that." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import torch\n", | ||
"from pathlib import Path\n", | ||
"\n", | ||
"if torch.cuda.is_available():\n", | ||
" device = torch.device(\"cuda:0\")\n", | ||
" dtype = torch.float16\n", | ||
"else:\n", | ||
" device = torch.device(\"cpu\")\n", | ||
" dtype = torch.float32" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Speech Scoring" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"1. download some demo audio segments\n", | ||
"2. create a tsv file to feed to the speech scoring pipeline\n", | ||
"3. load the model and build the pipeline\n", | ||
"4. go through the batches in the pipeline" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# get demo file\n", | ||
"import urllib.request\n", | ||
"import tempfile\n", | ||
"\n", | ||
"files = [\n", | ||
" (\"https://dl.fbaipublicfiles.com/seamless/tests/commonvoice_example_en_clocks.wav\", \"commonvoice_example_en_clocks.wav\"),\n", | ||
" (\"https://dl.fbaipublicfiles.com/seamlessM4T/LJ037-0171_sr16k.wav\", \"LJ037-0171_sr16k.wav\")\n", | ||
"]\n", | ||
"\n", | ||
"tmpdir = Path(tempfile.mkdtemp())\n", | ||
"tsv_file = (tmpdir / 'data.tsv')\n", | ||
"with tsv_file.open('w') as tsv_file_p:\n", | ||
" print('path', file=tsv_file_p)\n", | ||
" for (uri, name) in files:\n", | ||
" dl = tmpdir / name\n", | ||
" urllib.request.urlretrieve(uri, dl)\n", | ||
" print(dl, file=tsv_file_p)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from sonar.inference_pipelines.speech import SpeechInferenceParams\n", | ||
"from seamless_communication.toxicity.mutox.speech_pipeline import MutoxSpeechClassifierPipeline\n", | ||
"\n", | ||
"pipeline_builder = MutoxSpeechClassifierPipeline.load_model_from_name(\n", | ||
" mutox_classifier_name =\"mutox\",\n", | ||
" encoder_name=f\"sonar_speech_encoder_eng\",\n", | ||
" device=device,\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"pipeline = pipeline_builder.build_pipeline(SpeechInferenceParams(\n", | ||
" data_file=tsv_file,\n", | ||
" audio_root_dir=None,\n", | ||
" audio_path_index=0,\n", | ||
" target_lang=\"eng\",\n", | ||
" batch_size=4,\n", | ||
" pad_idx=0,\n", | ||
" device=device,\n", | ||
" fbank_dtype=torch.float32,\n", | ||
" n_parallel=4\n", | ||
"))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"/tmp/tmpqasvhgx6/commonvoice_example_en_clocks.wav\t-42.40079116821289\n", | ||
"/tmp/tmpqasvhgx6/LJ037-0171_sr16k.wav\t-47.90427780151367\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"for batch in pipeline:\n", | ||
" ex = batch['audio']\n", | ||
" for idx, path in enumerate(ex['path']):\n", | ||
" print(str(path), ex[\"data\"][idx].item(), sep=\"\\t\")\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# cleanup tmp dir\n", | ||
"import shutil\n", | ||
"shutil.rmtree(tmpdir)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Text Scoring\n", | ||
"\n", | ||
"1. load the sonar text encoder\n", | ||
"2. load the mutox classifier model\n", | ||
"3. compute embedding for a sentence\n", | ||
"4. score this embedding" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"Using the cached checkpoint of mutox. Set `force` to `True` to download again.\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from seamless_communication.toxicity.mutox.loader import load_mutox_model\n", | ||
"from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline\n", | ||
"\n", | ||
"t2vec_model = TextToEmbeddingModelPipeline(\n", | ||
" encoder=\"text_sonar_basic_encoder\",\n", | ||
" tokenizer=\"text_sonar_basic_encoder\",\n", | ||
")\n", | ||
"text_column='lang_txt'\n", | ||
"classifier = load_mutox_model(\n", | ||
" \"mutox\",\n", | ||
" device=device,\n", | ||
" dtype=dtype,\n", | ||
").eval()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 10, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"tensor([[-19.7812]], device='cuda:0', dtype=torch.float16)" | ||
] | ||
}, | ||
"execution_count": 10, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"with torch.inference_mode():\n", | ||
" emb = t2vec_model.predict([\"De peur que le pays ne se prostitue et ne se remplisse de crimes.\"], source_lang='fra_Latn')\n", | ||
" x = classifier(emb.to(device).half())\n", | ||
"\n", | ||
"x" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "sc_fr2", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.13" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.