diff --git a/.github/workflows/base.yml.template b/.github/workflows/base.yml.template index 27c53d186..68e89aa4e 100644 --- a/.github/workflows/base.yml.template +++ b/.github/workflows/base.yml.template @@ -57,6 +57,11 @@ jobs: mkdir $CONDA_PKGS_DIRS mkdir $PYTORCH_PKG_CACHE_PATH + - name: Install libsndfile on Ubuntu + shell: bash + run: sudo apt-get install -y libsndfile-dev + if: startsWith(runner.os, 'Linux') == true + - name: Setup Conda uses: s-weigand/setup-conda@v1 diff --git a/.github/workflows/develop_install.yml b/.github/workflows/develop_install.yml index 39801e0dd..be7a4a110 100644 --- a/.github/workflows/develop_install.yml +++ b/.github/workflows/develop_install.yml @@ -17,21 +17,28 @@ jobs: pytorch_version: ['1.1.0', '1.2.0', '1.3.0', '1.3.1', '1.4.0'] platform: ['windows-latest', 'ubuntu-latest', 'macos-latest'] + exclude: - platform: 'windows-latest' conda_python_version: '3.6' + - platform: 'macos-latest' pytorch_version: '1.1.0' + - pytorch_version: '1.1.0' conda_python_version: '3.8' + - pytorch_version: '1.2.0' conda_python_version: '3.8' + - pytorch_version: '1.3.0' conda_python_version: '3.8' + - pytorch_version: '1.3.1' conda_python_version: '3.8' + steps: - name: Checkout repo @@ -72,6 +79,11 @@ jobs: mkdir $CONDA_PKGS_DIRS mkdir $PYTORCH_PKG_CACHE_PATH + - name: Install libsndfile on Ubuntu + shell: bash + run: sudo apt-get install -y libsndfile-dev + if: startsWith(runner.os, 'Linux') == true + - name: Setup Conda uses: s-weigand/setup-conda@v1 @@ -115,9 +127,14 @@ jobs: rm -R ${{ steps.cache-dirs.outputs.pytorch_pkg_cache_path }}/pytorch* if: steps.pytorch-cache.outputs.cache-hit == 'true' - - name: Run Nox session for testing develop install and imports + - name: Run Nox session for testing brevitas develop install and imports shell: bash - run: nox --verbose --session tests_install_develop-${{ matrix.conda_python_version }}\(\pytorch_${{ matrix.pytorch_version }}\) + run: nox -v -s tests_brevitas_install_dev-${{ matrix.conda_python_version }}\(\pytorch_${{ matrix.pytorch_version }}\) + + - name: Run Nox session for testing brevitas_examples develop install and imports + shell: bash + run: nox -v -s tests_brevitas_examples_install_dev-${{ matrix.conda_python_version }}\(\pytorch_${{ matrix.pytorch_version }}\) + - name: Remove tarballs before caching Pytorch deps pkgs diff --git a/.github/workflows/gen_github_actions.py b/.github/workflows/gen_github_actions.py index 58ff73fa9..40cb54b07 100644 --- a/.github/workflows/gen_github_actions.py +++ b/.github/workflows/gen_github_actions.py @@ -42,15 +42,23 @@ PYTEST_MATRIX_EXTRA = od([('jit_status', list(JIT_STATUSES))]) -PYTEST_STEP_LIST = [od([ - ('name', 'Run Nox session for pytest'), - ('shell', 'bash'), - ('run', 'nox --verbose --session tests_cpu-${{ matrix.conda_python_version }}\(${{ matrix.jit_status }}\,\ pytorch_${{ matrix.pytorch_version }}\)')])] - -TEST_INSTALL_DEVELOP_STEP_LIST = [od([ - ('name', 'Run Nox session for testing develop install and imports'), - ('shell', 'bash'), - ('run', 'nox --verbose --session tests_install_develop-${{ matrix.conda_python_version }}\(\pytorch_${{ matrix.pytorch_version }}\)')])] +PYTEST_STEP_LIST = [ + od([ + ('name', 'Run Nox session for pytest'), + ('shell', 'bash'), + ('run', 'nox -v -s tests_brevitas_cpu-${{ matrix.conda_python_version }}\(${{ matrix.jit_status }}\,\ pytorch_${{ matrix.pytorch_version }}\)') + ])] + +TEST_INSTALL_DEV_STEP_LIST = [ + od([ + ('name', 'Run Nox session for testing brevitas develop install and imports'), + ('shell', 'bash'), + 
('run', 'nox -v -s tests_brevitas_install_dev-${{ matrix.conda_python_version }}\(\pytorch_${{ matrix.pytorch_version }}\)')]), + od([ + ('name', 'Run Nox session for testing brevitas_examples develop install and imports'), + ('shell', 'bash'), + ('run', 'nox -v -s tests_brevitas_examples_install_dev-${{ matrix.conda_python_version }}\(\pytorch_${{ matrix.pytorch_version }}\)') + ])] # whitespaces to indent generated portions of output yaml @@ -90,6 +98,7 @@ def dict_str(d, quote_val, indent_first): repr += f"{name}: {val}\n" if indent_first: repr = indent(repr, RELATIVE_INDENT*' ', predicate=lambda line: not first_line_prefix in line) + repr += '\n' return repr def gen_yaml(self, output_path): @@ -122,7 +131,7 @@ def gen_test_develop_install_yml(): 'Test develop install', EXCLUDE_LIST, MATRIX, - TEST_INSTALL_DEVELOP_STEP_LIST) + TEST_INSTALL_DEV_STEP_LIST) test_develop_install.gen_yaml(DEVELOP_INSTALL_YML) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index ffc5875cd..6cdfdf8f5 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -18,23 +18,31 @@ jobs: platform: ['windows-latest', 'ubuntu-latest', 'macos-latest'] jit_status: ['jit_enabled', 'jit_disabled'] + exclude: - platform: 'windows-latest' conda_python_version: '3.6' + - platform: 'macos-latest' pytorch_version: '1.1.0' + - pytorch_version: '1.1.0' conda_python_version: '3.8' + - pytorch_version: '1.2.0' conda_python_version: '3.8' + - pytorch_version: '1.3.0' conda_python_version: '3.8' + - pytorch_version: '1.3.1' conda_python_version: '3.8' + - pytorch_version: '1.1.0' jit_status: 'jit_disabled' + steps: - name: Checkout repo @@ -75,6 +83,11 @@ jobs: mkdir $CONDA_PKGS_DIRS mkdir $PYTORCH_PKG_CACHE_PATH + - name: Install libsndfile on Ubuntu + shell: bash + run: sudo apt-get install -y libsndfile-dev + if: startsWith(runner.os, 'Linux') == true + - name: Setup Conda uses: s-weigand/setup-conda@v1 @@ -120,7 +133,8 @@ jobs: - name: Run Nox session for pytest shell: bash - run: nox --verbose --session tests_cpu-${{ matrix.conda_python_version }}\(${{ matrix.jit_status }}\,\ pytorch_${{ matrix.pytorch_version }}\) + run: nox -v -s tests_brevitas_cpu-${{ matrix.conda_python_version }}\(${{ matrix.jit_status }}\,\ pytorch_${{ matrix.pytorch_version }}\) + - name: Remove tarballs before caching Pytorch deps pkgs diff --git a/brevitas_examples/__init__.py b/brevitas_examples/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/imagenet_classification/README.md b/brevitas_examples/imagenet_classification/README.md similarity index 94% rename from examples/imagenet_classification/README.md rename to brevitas_examples/imagenet_classification/README.md index 46d93de06..a63367265 100644 --- a/examples/imagenet_classification/README.md +++ b/brevitas_examples/imagenet_classification/README.md @@ -16,12 +16,12 @@ Below in the table is a list of example pretrained models made available for ref To evaluate a pretrained quantized model on ImageNet: - Make sure you have Brevitas installed and the ImageNet dataset in a Pytorch friendly format (following this [script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh)). - - Pass the corresponding cfg .ini file as an input to the evaluation script. The required checkpoint will be downloaded automatically. + - Pass the name of the model as an input to the evaluation script. The required checkpoint will be downloaded automatically. 
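The factories added below in `brevitas_examples/imagenet_classification/models/__init__.py` also make these pretrained models loadable straight from Python; a minimal sketch, assuming Brevitas and this package are installed:

```python
# Minimal sketch (assumes a develop install of this repo): load the
# pretrained 4-bit quantized MobileNet V1 via the new hub-style factory.
# pretrained=True downloads the checkpoint URL named in the bundled .ini.
from brevitas_examples.imagenet_classification import quant_mobilenet_v1_4b

model = quant_mobilenet_v1_4b(pretrained=True)
model.eval()
```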
For example, for *quant_mobilenet_v1_4b* evaluated on GPU 0: ``` -python imagenet_val.py --imagenet-dir /path/to/imagenet --model-cfg ./cfg/quant_mobilenet_v1_4b.ini --gpu 0 +brevitas_imagenet_val --imagenet-dir /path/to/imagenet --model quant_mobilenet_v1_4b --gpu 0 --pretrained ``` ## MobileNet V1 diff --git a/brevitas_examples/imagenet_classification/__init__.py b/brevitas_examples/imagenet_classification/__init__.py new file mode 100644 index 000000000..cf4f59d6c --- /dev/null +++ b/brevitas_examples/imagenet_classification/__init__.py @@ -0,0 +1 @@ +from .models import * \ No newline at end of file diff --git a/brevitas_examples/imagenet_classification/cfg/__init__.py b/brevitas_examples/imagenet_classification/cfg/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/imagenet_classification/cfg/quant_mobilenet_v1_4b.ini b/brevitas_examples/imagenet_classification/cfg/quant_mobilenet_v1_4b.ini similarity index 100% rename from examples/imagenet_classification/cfg/quant_mobilenet_v1_4b.ini rename to brevitas_examples/imagenet_classification/cfg/quant_mobilenet_v1_4b.ini diff --git a/examples/imagenet_classification/cfg/quant_proxylessnas_mobile14_4b.ini b/brevitas_examples/imagenet_classification/cfg/quant_proxylessnas_mobile14_4b.ini similarity index 100% rename from examples/imagenet_classification/cfg/quant_proxylessnas_mobile14_4b.ini rename to brevitas_examples/imagenet_classification/cfg/quant_proxylessnas_mobile14_4b.ini diff --git a/examples/imagenet_classification/cfg/quant_proxylessnas_mobile14_4b5b.ini b/brevitas_examples/imagenet_classification/cfg/quant_proxylessnas_mobile14_4b5b.ini similarity index 100% rename from examples/imagenet_classification/cfg/quant_proxylessnas_mobile14_4b5b.ini rename to brevitas_examples/imagenet_classification/cfg/quant_proxylessnas_mobile14_4b5b.ini diff --git a/examples/imagenet_classification/cfg/quant_proxylessnas_mobile14_hadamard_4b.ini b/brevitas_examples/imagenet_classification/cfg/quant_proxylessnas_mobile14_hadamard_4b.ini similarity index 100% rename from examples/imagenet_classification/cfg/quant_proxylessnas_mobile14_hadamard_4b.ini rename to brevitas_examples/imagenet_classification/cfg/quant_proxylessnas_mobile14_hadamard_4b.ini diff --git a/examples/imagenet_classification/imagenet_val.py b/brevitas_examples/imagenet_classification/imagenet_val.py similarity index 81% rename from examples/imagenet_classification/imagenet_val.py rename to brevitas_examples/imagenet_classification/imagenet_val.py index a256aacd0..17cde7772 100644 --- a/examples/imagenet_classification/imagenet_val.py +++ b/brevitas_examples/imagenet_classification/imagenet_val.py @@ -12,17 +12,15 @@ import torchvision.transforms as transforms import torchvision.datasets as datasets -from models import * +from .models import model_with_cfg SEED = 123456 -models = {'quant_mobilenet_v1': quant_mobilenet_v1, - 'quant_proxylessnas_mobile14': quant_proxylessnas_mobile14} - parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation') parser.add_argument('--imagenet-dir', help='path to folder containing Imagenet val folder') -parser.add_argument('--model-cfg', type=str, help='Path to pretrained model .ini configuration file') +parser.add_argument('--model', type=str, default='quant_mobilenet_v1_4b', help='Name of the model') +parser.add_argument('--pretrained', action='store_true', help='Load pretrained checkpoint') parser.add_argument('--workers', default=4, type=int, help='number of data loading workers') 
parser.add_argument('--batch-size', default=256, type=int, help='Minibatch size') parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.') @@ -33,28 +31,13 @@ def main(): random.seed(SEED) torch.manual_seed(SEED) - assert os.path.exists(args.model_cfg) - cfg = configparser.ConfigParser() - cfg.read(args.model_cfg) - arch = cfg.get('MODEL', 'ARCH') - - model = models[arch](cfg) + model, cfg = model_with_cfg(args.model, args.pretrained) if args.gpu is not None: torch.cuda.set_device(args.gpu) model = model.cuda(args.gpu) cudnn.benchmark = True - pretrained_url = cfg.get('MODEL', 'PRETRAINED_URL') - print("=> Loading checkpoint from:'{}'".format(pretrained_url)) - if args.gpu is None: - checkpoint = torch.hub.load_state_dict_from_url(pretrained_url) - else: - # Map model to be loaded to specified single gpu. - loc = 'cuda:{}'.format(args.gpu) - checkpoint = torch.hub.load_state_dict_from_url(pretrained_url, map_location=loc) - model.load_state_dict(checkpoint, strict=True) - valdir = os.path.join(args.imagenet_dir, 'val') mean = [float(cfg.get('PREPROCESS', 'MEAN_0')), float(cfg.get('PREPROCESS', 'MEAN_1')), float(cfg.get('PREPROCESS', 'MEAN_2'))] diff --git a/brevitas_examples/imagenet_classification/models/__init__.py b/brevitas_examples/imagenet_classification/models/__init__.py new file mode 100644 index 000000000..2319f2aff --- /dev/null +++ b/brevitas_examples/imagenet_classification/models/__init__.py @@ -0,0 +1,48 @@ +import os +from configparser import ConfigParser + +from torch import hub + +from .mobilenetv1 import * +from .vgg import * +from .proxylessnas import * + +model_impl = { + 'quant_mobilenet_v1': quant_mobilenet_v1, + 'quant_proxylessnas_mobile14': quant_proxylessnas_mobile14 +} + + +def model_with_cfg(name, pretrained): + cfg = ConfigParser() + current_dir = os.path.dirname(os.path.abspath(__file__)) + config_path = os.path.join(current_dir, '..', 'cfg', name + '.ini') + assert os.path.exists(config_path) + cfg.read(config_path) + arch = cfg.get('MODEL', 'ARCH') + model = model_impl[arch](cfg) + if pretrained: + checkpoint = cfg.get('MODEL', 'PRETRAINED_URL') + state_dict = hub.load_state_dict_from_url(checkpoint, progress=True, map_location='cpu') + model.load_state_dict(state_dict, strict=True) + return model, cfg + + +def quant_mobilenet_v1_4b(pretrained=True): + model, _ = model_with_cfg('quant_mobilenet_v1_4b', pretrained) + return model + + +def quant_proxylessnas_mobile14_4b(pretrained=True): + model, _ = model_with_cfg('quant_proxylessnas_mobile14_4b', pretrained) + return model + + +def quant_proxylessnas_mobile14_4b5b(pretrained=True): + model, _ = model_with_cfg('quant_proxylessnas_mobile14_4b5b', pretrained) + return model + + +def quant_proxylessnas_mobile14_hadamard_4b(pretrained=True): + model, _ = model_with_cfg('quant_proxylessnas_mobile14_hadamard_4b', pretrained) + return model \ No newline at end of file diff --git a/examples/imagenet_classification/models/common.py b/brevitas_examples/imagenet_classification/models/common.py similarity index 99% rename from examples/imagenet_classification/models/common.py rename to brevitas_examples/imagenet_classification/models/common.py index d2e56bb17..2acb53ca1 100644 --- a/examples/imagenet_classification/models/common.py +++ b/brevitas_examples/imagenet_classification/models/common.py @@ -4,6 +4,7 @@ from brevitas.core.scaling import ScalingImplType from brevitas.core.stats import StatsOp + QUANT_TYPE = QuantType.INT SCALING_MIN_VAL = 2e-16 @@ -151,3 +152,6 @@ def 
make_hadamard_classifier(in_channels, return qnn.HadamardClassifier(in_channels=in_channels, out_channels=out_channels, fixed_scale=fixed_scale) + + + diff --git a/examples/imagenet_classification/models/mobilenetv1.py b/brevitas_examples/imagenet_classification/models/mobilenetv1.py similarity index 98% rename from examples/imagenet_classification/models/mobilenetv1.py rename to brevitas_examples/imagenet_classification/models/mobilenetv1.py index 2b4c7a2b0..3f189ee8a 100644 --- a/examples/imagenet_classification/models/mobilenetv1.py +++ b/brevitas_examples/imagenet_classification/models/mobilenetv1.py @@ -33,7 +33,7 @@ from brevitas.quant_tensor import pack_quant_tensor -from .common import make_quant_conv2d, make_quant_linear, make_quant_relu, make_quant_avg_pool +from .common import * FIRST_LAYER_BIT_WIDTH = 8 @@ -170,3 +170,6 @@ def quant_mobilenet_v1(cfg): first_stage_stride=first_stage_stride, bit_width=bit_width) return net + + + diff --git a/examples/imagenet_classification/models/proxylessnas.py b/brevitas_examples/imagenet_classification/models/proxylessnas.py similarity index 100% rename from examples/imagenet_classification/models/proxylessnas.py rename to brevitas_examples/imagenet_classification/models/proxylessnas.py diff --git a/examples/imagenet_classification/models/vgg.py b/brevitas_examples/imagenet_classification/models/vgg.py similarity index 100% rename from examples/imagenet_classification/models/vgg.py rename to brevitas_examples/imagenet_classification/models/vgg.py diff --git a/examples/speech_to_text/README.md b/brevitas_examples/speech_to_text/README.md similarity index 80% rename from examples/speech_to_text/README.md rename to brevitas_examples/speech_to_text/README.md index 4308553d8..c02c1de09 100644 --- a/examples/speech_to_text/README.md +++ b/brevitas_examples/speech_to_text/README.md @@ -15,19 +15,20 @@ It is highly recommended to setup a virtual environment. Download and pre-process the LibriSpeech dataset with the following command: ``` -python utilities/get_librispeech_data.py --data_root=/path/to/validation/folder --data_set=DEV_OTHER +brevitas_quartznet_preprocess --data_root=/path/to/validation/folder --data_set=DEV_OTHER ``` To evaluate a pretrained quantized model on LibriSpeech: - Install pytorch from the [Pytorch Website](https://pytorch.org/), and Cython with the following command: `pip install --upgrade cython` - - Install the Quartznet requirements with `pip install requirements.txt` - - Make sure you have Brevitas installed - - Pass the corresponding cfg .ini file as an input to the evaluation script. The required checkpoint will be downloaded automatically. + - Install SoX (this [guide](https://at.projects.genivi.org/wiki/display/PROJ/Installation+of+SoX+on+different+Platforms) + may be helpful) + - After cloning the repository, install Brevitas and QuartzNet requirements with `pip install .[stt]` + - Pass the name of the model as an input to the evaluation script. The required checkpoint will be downloaded automatically.
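The QuartzNet factories introduced further below in `quartznet/__init__.py` expose the same pretrained models from Python; a minimal sketch, assuming `pip install .[stt]` has been run:

```python
# Minimal sketch (assumes the stt extra is installed): load the pretrained
# 8-bit per-tensor scaled QuartzNet; pretrained=True fetches the encoder
# and decoder checkpoints named in the bundled cfg .ini file.
from brevitas_examples.speech_to_text import quant_quartznet_pertensorscaling_8b

model = quant_quartznet_pertensorscaling_8b(pretrained=True)
model.eval()
```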
For example, for the evaluation on GPU 0: ``` -python quartznet_val.py --input-folder /path/to/validation/folder --model-cfg cfg/quant_quartznet_pertensorscaling_8b.ini --gpu 0 +brevitas_quartznet_val --input-folder /path/to/validation/folder --model quant_quartznet_pertensorscaling_8b --gpu 0 --pretrained ``` diff --git a/brevitas_examples/speech_to_text/__init__.py b/brevitas_examples/speech_to_text/__init__.py new file mode 100644 index 000000000..8d6fb3970 --- /dev/null +++ b/brevitas_examples/speech_to_text/__init__.py @@ -0,0 +1 @@ +from .quartznet import * \ No newline at end of file diff --git a/brevitas_examples/speech_to_text/cfg/__init__.py b/brevitas_examples/speech_to_text/cfg/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/speech_to_text/cfg/quant_quartznet_perchannelscaling_4b.ini b/brevitas_examples/speech_to_text/cfg/quant_quartznet_perchannelscaling_4b.ini similarity index 92% rename from examples/speech_to_text/cfg/quant_quartznet_perchannelscaling_4b.ini rename to brevitas_examples/speech_to_text/cfg/quant_quartznet_perchannelscaling_4b.ini index 5f2f82e03..62ea54a6c 100644 --- a/examples/speech_to_text/cfg/quant_quartznet_perchannelscaling_4b.ini +++ b/brevitas_examples/speech_to_text/cfg/quant_quartznet_perchannelscaling_4b.ini @@ -1,6 +1,6 @@ [MODEL] ARCH: quartznet -TOPOLOGY_FILE: cfg/topology/quartznet15x5.yaml +TOPOLOGY_FILE: quartznet15x5.yaml PRETRAINED_ENCODER_URL: https://github.com/Xilinx/brevitas/releases/download/quant_quartznet_4b-r0/quant_quartznet_encoder_4b-0a46a232.pth PRETRAINED_DECODER_URL: https://github.com/Xilinx/brevitas/releases/download/quant_quartznet_4b-r0/quant_quartznet_decoder_4b-bcbf8c7b.pth diff --git a/examples/speech_to_text/cfg/quant_quartznet_perchannelscaling_8b.ini b/brevitas_examples/speech_to_text/cfg/quant_quartznet_perchannelscaling_8b.ini similarity index 92% rename from examples/speech_to_text/cfg/quant_quartznet_perchannelscaling_8b.ini rename to brevitas_examples/speech_to_text/cfg/quant_quartznet_perchannelscaling_8b.ini index 58005ebbb..2aa28a511 100644 --- a/examples/speech_to_text/cfg/quant_quartznet_perchannelscaling_8b.ini +++ b/brevitas_examples/speech_to_text/cfg/quant_quartznet_perchannelscaling_8b.ini @@ -1,6 +1,6 @@ [MODEL] ARCH: quartznet -TOPOLOGY_FILE: cfg/topology/quartznet15x5.yaml +TOPOLOGY_FILE: quartznet15x5.yaml PRETRAINED_ENCODER_URL: https://github.com/Xilinx/brevitas/releases/download/quant_quartznet_8b-r0/quant_quartznet_encoder_8b-50f12b4b.pth PRETRAINED_DECODER_URL: https://github.com/Xilinx/brevitas/releases/download/quant_quartznet_8b-r0/quant_quartznet_decoder_8b-af09651c.pth diff --git a/examples/speech_to_text/cfg/quant_quartznet_pertensorscaling_8b.ini b/brevitas_examples/speech_to_text/cfg/quant_quartznet_pertensorscaling_8b.ini similarity index 92% rename from examples/speech_to_text/cfg/quant_quartznet_pertensorscaling_8b.ini rename to brevitas_examples/speech_to_text/cfg/quant_quartznet_pertensorscaling_8b.ini index 8d025e9c7..c7f50261b 100644 --- a/examples/speech_to_text/cfg/quant_quartznet_pertensorscaling_8b.ini +++ b/brevitas_examples/speech_to_text/cfg/quant_quartznet_pertensorscaling_8b.ini @@ -1,6 +1,6 @@ [MODEL] ARCH: quartznet -TOPOLOGY_FILE: cfg/topology/quartznet15x5.yaml +TOPOLOGY_FILE: quartznet15x5.yaml PRETRAINED_ENCODER_URL: https://github.com/Xilinx/brevitas/releases/download/quant_quartznet_8b-r0/quant_quartznet_encoder_8b-50f12b4b.pth PRETRAINED_DECODER_URL: 
https://github.com/Xilinx/brevitas/releases/download/quant_quartznet_8b-r0/quant_quartznet_decoder_8b-af09651c.pth diff --git a/brevitas_examples/speech_to_text/cfg/topology/__init__.py b/brevitas_examples/speech_to_text/cfg/topology/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/speech_to_text/cfg/topology/quartznet15x5.yaml b/brevitas_examples/speech_to_text/cfg/topology/quartznet15x5.yaml similarity index 100% rename from examples/speech_to_text/cfg/topology/quartznet15x5.yaml rename to brevitas_examples/speech_to_text/cfg/topology/quartznet15x5.yaml diff --git a/brevitas_examples/speech_to_text/get_librispeech_data.py b/brevitas_examples/speech_to_text/get_librispeech_data.py new file mode 100644 index 000000000..495bb1665 --- /dev/null +++ b/brevitas_examples/speech_to_text/get_librispeech_data.py @@ -0,0 +1,148 @@ +# Adapted from https://github.com/NVIDIA/NeMo/tree/r0.9 +# Copyright (C) 2020 Xilinx (Giuseppe Franco) +# Copyright (C) 2019 NVIDIA CORPORATION. +# +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import fnmatch +import json +import os +import subprocess +import tarfile +import urllib.request + +from sox import Transformer +from tqdm import tqdm + +parser = argparse.ArgumentParser(description='LibriSpeech Data download') +parser.add_argument("--data_root", required=True, default=None, type=str) +parser.add_argument("--data_sets", default="dev_clean", type=str) +args = parser.parse_args() + +URLS = { + 'TRAIN_CLEAN_100': ("http://www.openslr.org/resources/12/train-clean-100.tar.gz"), + 'TRAIN_CLEAN_360': ("http://www.openslr.org/resources/12/train-clean-360.tar.gz"), + 'TRAIN_OTHER_500': ("http://www.openslr.org/resources/12/train-other-500.tar.gz"), + 'DEV_CLEAN': "http://www.openslr.org/resources/12/dev-clean.tar.gz", + 'DEV_OTHER': "http://www.openslr.org/resources/12/dev-other.tar.gz", + 'TEST_CLEAN': "http://www.openslr.org/resources/12/test-clean.tar.gz", + 'TEST_OTHER': "http://www.openslr.org/resources/12/test-other.tar.gz", +} + + +def __maybe_download_file(destination: str, source: str): + """ + Downloads source to destination if it doesn't exist. + If exists, skips download + Args: + destination: local filepath + source: url of resource + + Returns: + + """ + source = URLS[source] + if not os.path.exists(destination): + print("{0} does not exist. Downloading ...".format(destination)) + urllib.request.urlretrieve(source, filename=destination + '.tmp') + os.rename(destination + '.tmp', destination) + print("Downloaded {0}.".format(destination)) + else: + print("Destination {0} exists. Skipping.".format(destination)) + return destination + + +def __extract_file(filepath: str, data_dir: str): + try: + tar = tarfile.open(filepath) + tar.extractall(data_dir) + tar.close() + except Exception: + print('Not extracting. 
Maybe already there?') + + +def __process_data(data_folder: str, dst_folder: str, manifest_file: str): + """ + Converts flac to wav and build manifests's json + Args: + data_folder: source with flac files + dst_folder: where wav files will be stored + manifest_file: where to store manifest + + Returns: + + """ + + if not os.path.exists(dst_folder): + os.makedirs(dst_folder) + + files = [] + entries = [] + + for root, dirnames, filenames in os.walk(data_folder): + for filename in fnmatch.filter(filenames, '*.trans.txt'): + files.append((os.path.join(root, filename), root)) + + for transcripts_file, root in tqdm(files): + with open(transcripts_file, encoding="utf-8") as fin: + for line in fin: + id, text = line[: line.index(" ")], line[line.index(" ") + 1 :] + transcript_text = text.lower().strip() + + # Convert FLAC file to WAV + flac_file = os.path.join(root, id + ".flac") + wav_file = os.path.join(dst_folder, id + ".wav") + if not os.path.exists(wav_file): + Transformer().build(flac_file, wav_file) + # check duration + duration = subprocess.check_output("soxi -D {0}".format(wav_file), shell=True) + + entry = dict() + entry['audio_filepath'] = os.path.abspath(wav_file) + entry['duration'] = float(duration) + entry['text'] = transcript_text + entries.append(entry) + + with open(manifest_file, 'w') as fout: + for m in entries: + fout.write(json.dumps(m) + '\n') + + +def main(): + data_root = args.data_root + data_sets = args.data_sets + + if data_sets == "ALL": + data_sets = "dev_clean,dev_other,train_clean_100,train_clean_360,train_other_500,test_clean,test_other" + + for data_set in data_sets.split(','): + print("\n\nWorking on: {0}".format(data_set)) + filepath = os.path.join(data_root, data_set + ".tar.gz") + print("Getting {0}".format(data_set)) + __maybe_download_file(filepath, data_set.upper()) + print("Extracting {0}".format(data_set)) + __extract_file(filepath, data_root) + print("Processing {0}".format(data_set)) + __process_data( + os.path.join(os.path.join(data_root, "LibriSpeech"), data_set.replace("_", "-"),), + os.path.join(os.path.join(data_root, "LibriSpeech"), data_set.replace("_", "-"),) + "-processed", + os.path.join(data_root, data_set + ".json"), + ) + print('Done!') + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/brevitas_examples/speech_to_text/quartznet/__init__.py b/brevitas_examples/speech_to_text/quartznet/__init__.py new file mode 100644 index 000000000..6c894e610 --- /dev/null +++ b/brevitas_examples/speech_to_text/quartznet/__init__.py @@ -0,0 +1,80 @@ +# Adapted from https://github.com/NVIDIA/NeMo/blob/r0.9/collections/nemo_asr/ +# Copyright (C) 2020 Xilinx (Giuseppe Franco) +# Copyright (C) 2019 NVIDIA CORPORATION. +# +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from .data_layer import ( + AudioToTextDataLayer) +from .greedy_ctc_decoder import GreedyCTCDecoder +from .quartznet import quartznet +from .losses import CTCLossNM +from .helpers import * + +import os +from configparser import ConfigParser +from ruamel.yaml import YAML +from torch import hub + +__all__ = ['AudioToTextDataLayer', + 'quartznet', + 'quant_quartznet_perchannelscaling_4b', + 'quant_quartznet_perchannelscaling_8b', + 'quant_quartznet_pertensorscaling_8b'] + +name = "quarznet_release" +model_impl = { + 'quartznet': quartznet, +} + + +def model_with_cfg(name, pretrained): + cfg = ConfigParser() + current_dir = os.path.dirname(os.path.abspath(__file__)) + config_path = os.path.join(current_dir, '..', 'cfg', name + '.ini') + assert os.path.exists(config_path) + cfg.read(config_path) + arch = cfg.get('MODEL', 'ARCH') + topology_file = cfg.get('MODEL', 'TOPOLOGY_FILE') + topology_path = os.path.join(current_dir, '..', 'cfg', 'topology', topology_file) + yaml = YAML(typ="safe") + with open(topology_path) as f: + quartnzet_params = yaml.load(f) + model = model_impl[arch](cfg, quartnzet_params) + if pretrained: + pretrained_encoder_url = cfg.get('MODEL', 'PRETRAINED_ENCODER_URL') + pretrained_decoder_url = cfg.get('MODEL', 'PRETRAINED_DECODER_URL') + print("=> Loading encoder checkpoint from:'{}'".format(pretrained_encoder_url)) + print("=> Loading decoder checkpoint from:'{}'".format(pretrained_decoder_url)) + checkpoint_enc = hub.load_state_dict_from_url(pretrained_encoder_url, progress=True, map_location='cpu') + checkpoint_dec = hub.load_state_dict_from_url(pretrained_decoder_url, progress=True, map_location='cpu') + model.restore_checkpoints(checkpoint_enc, checkpoint_dec) + return model, cfg + + +def quant_quartznet_perchannelscaling_4b(pretrained=True): + model, _ = model_with_cfg('quant_quartznet_perchannelscaling_4b', pretrained) + return model + + +def quant_quartznet_perchannelscaling_8b(pretrained=True): + model, _ = model_with_cfg('quant_quartznet_perchannelscaling_8b', pretrained) + return model + + +def quant_quartznet_pertensorscaling_8b(pretrained=True): + model, _ = model_with_cfg('quant_quartznet_pertensorscaling_8b', pretrained) + return model diff --git a/examples/speech_to_text/quartznet/audio_preprocessing.py b/brevitas_examples/speech_to_text/quartznet/audio_preprocessing.py similarity index 100% rename from examples/speech_to_text/quartznet/audio_preprocessing.py rename to brevitas_examples/speech_to_text/quartznet/audio_preprocessing.py diff --git a/examples/speech_to_text/quartznet/data_layer.py b/brevitas_examples/speech_to_text/quartznet/data_layer.py similarity index 95% rename from examples/speech_to_text/quartznet/data_layer.py rename to brevitas_examples/speech_to_text/quartznet/data_layer.py index 4594cb1d5..21bb852b1 100644 --- a/examples/speech_to_text/quartznet/data_layer.py +++ b/brevitas_examples/speech_to_text/quartznet/data_layer.py @@ -25,10 +25,7 @@ import torch import torch.nn as nn -# from nemo.backends.pytorch import DataLayerNM -# from nemo.core import DeviceType -# from nemo.core.neural_types import * -from .parts.dataset import (AudioDataset, seq_collate_fn) # , KaldiFeatureDataset, TranscriptDataset) +from .parts.dataset import (AudioDataset, seq_collate_fn) from .parts.features import WaveformFeaturizer def pad_to(x, k=8): @@ -117,7 +114,7 @@ def __init__( load_audio=True, drop_last=False, shuffle=True, - num_workers=0, + num_workers=4, placement='cpu', # perturb_config=None, **kwargs diff --git 
a/examples/speech_to_text/quartznet/greedy_ctc_decoder.py b/brevitas_examples/speech_to_text/quartznet/greedy_ctc_decoder.py similarity index 100% rename from examples/speech_to_text/quartznet/greedy_ctc_decoder.py rename to brevitas_examples/speech_to_text/quartznet/greedy_ctc_decoder.py diff --git a/examples/speech_to_text/quartznet/helpers.py b/brevitas_examples/speech_to_text/quartznet/helpers.py similarity index 100% rename from examples/speech_to_text/quartznet/helpers.py rename to brevitas_examples/speech_to_text/quartznet/helpers.py diff --git a/examples/speech_to_text/quartznet/losses.py b/brevitas_examples/speech_to_text/quartznet/losses.py similarity index 100% rename from examples/speech_to_text/quartznet/losses.py rename to brevitas_examples/speech_to_text/quartznet/losses.py diff --git a/examples/speech_to_text/quartznet/metrics.py b/brevitas_examples/speech_to_text/quartznet/metrics.py similarity index 100% rename from examples/speech_to_text/quartznet/metrics.py rename to brevitas_examples/speech_to_text/quartznet/metrics.py diff --git a/examples/speech_to_text/quartznet/parts/__init__.py b/brevitas_examples/speech_to_text/quartznet/parts/__init__.py similarity index 100% rename from examples/speech_to_text/quartznet/parts/__init__.py rename to brevitas_examples/speech_to_text/quartznet/parts/__init__.py diff --git a/examples/speech_to_text/quartznet/parts/cleaners.py b/brevitas_examples/speech_to_text/quartznet/parts/cleaners.py similarity index 100% rename from examples/speech_to_text/quartznet/parts/cleaners.py rename to brevitas_examples/speech_to_text/quartznet/parts/cleaners.py diff --git a/examples/speech_to_text/quartznet/parts/common.py b/brevitas_examples/speech_to_text/quartznet/parts/common.py similarity index 100% rename from examples/speech_to_text/quartznet/parts/common.py rename to brevitas_examples/speech_to_text/quartznet/parts/common.py diff --git a/examples/speech_to_text/quartznet/parts/dataset.py b/brevitas_examples/speech_to_text/quartznet/parts/dataset.py similarity index 99% rename from examples/speech_to_text/quartznet/parts/dataset.py rename to brevitas_examples/speech_to_text/quartznet/parts/dataset.py index 283770fb8..a6b60a119 100644 --- a/examples/speech_to_text/quartznet/parts/dataset.py +++ b/brevitas_examples/speech_to_text/quartznet/parts/dataset.py @@ -23,9 +23,6 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
-import os -import pandas as pd -import string import torch from torch.utils.data import Dataset diff --git a/examples/speech_to_text/quartznet/parts/features.py b/brevitas_examples/speech_to_text/quartznet/parts/features.py similarity index 100% rename from examples/speech_to_text/quartznet/parts/features.py rename to brevitas_examples/speech_to_text/quartznet/parts/features.py diff --git a/examples/speech_to_text/quartznet/parts/manifest.py b/brevitas_examples/speech_to_text/quartznet/parts/manifest.py similarity index 100% rename from examples/speech_to_text/quartznet/parts/manifest.py rename to brevitas_examples/speech_to_text/quartznet/parts/manifest.py diff --git a/examples/speech_to_text/quartznet/parts/perturb.py b/brevitas_examples/speech_to_text/quartznet/parts/perturb.py similarity index 100% rename from examples/speech_to_text/quartznet/parts/perturb.py rename to brevitas_examples/speech_to_text/quartznet/parts/perturb.py diff --git a/examples/speech_to_text/quartznet/parts/quartznet.py b/brevitas_examples/speech_to_text/quartznet/parts/quartznet.py similarity index 100% rename from examples/speech_to_text/quartznet/parts/quartznet.py rename to brevitas_examples/speech_to_text/quartznet/parts/quartznet.py diff --git a/examples/speech_to_text/quartznet/parts/segment.py b/brevitas_examples/speech_to_text/quartznet/parts/segment.py similarity index 100% rename from examples/speech_to_text/quartznet/parts/segment.py rename to brevitas_examples/speech_to_text/quartznet/parts/segment.py diff --git a/examples/speech_to_text/quartznet/parts/spectr_augment.py b/brevitas_examples/speech_to_text/quartznet/parts/spectr_augment.py similarity index 100% rename from examples/speech_to_text/quartznet/parts/spectr_augment.py rename to brevitas_examples/speech_to_text/quartznet/parts/spectr_augment.py diff --git a/examples/speech_to_text/quartznet/quartznet.py b/brevitas_examples/speech_to_text/quartznet/quartznet.py similarity index 100% rename from examples/speech_to_text/quartznet/quartznet.py rename to brevitas_examples/speech_to_text/quartznet/quartznet.py diff --git a/examples/speech_to_text/quartznet_val.py b/brevitas_examples/speech_to_text/quartznet_val.py similarity index 69% rename from examples/speech_to_text/quartznet_val.py rename to brevitas_examples/speech_to_text/quartznet_val.py index 99a8edb69..77a142f20 100644 --- a/examples/speech_to_text/quartznet_val.py +++ b/brevitas_examples/speech_to_text/quartznet_val.py @@ -7,22 +7,24 @@ import random import torch -from quartznet import AudioToTextDataLayer, quartznet -from quartznet.helpers import word_error_rate, post_process_predictions, \ +from .quartznet import AudioToTextDataLayer +from .quartznet.helpers import word_error_rate, post_process_predictions, \ post_process_transcripts import torch.backends.cudnn as cudnn import brevitas.config +from .quartznet import model_with_cfg + brevitas.config.IGNORE_MISSING_KEYS = False SEED = 123456 BATCH_SIZE = 64 -models = {'quartznet': quartznet} parser = argparse.ArgumentParser(description='Quartznet') -parser.add_argument("--model-cfg", type=str, required=True) parser.add_argument("--input-folder", type=str, required=False) parser.add_argument("--gpu", type=int) +parser.add_argument('--pretrained', action='store_true', default=True, help='Load pretrained checkpoint') +parser.add_argument('--model', type=str, default='quant_quartznet_perchannelscaling_4b', help='Name of the model') def main(): @@ -31,19 +33,15 @@ def main(): args = parser.parse_args() - assert 
os.path.exists(args.model_cfg) - cfg = configparser.ConfigParser() - cfg.read(args.model_cfg) + model, cfg = model_with_cfg(args.model, args.pretrained) topology_file = cfg.get('MODEL', 'TOPOLOGY_FILE') - + current_dir = os.path.dirname(os.path.abspath(__file__)) + topology_path = os.path.join(current_dir, 'cfg', 'topology', topology_file) yaml = YAML(typ="safe") - with open(topology_file) as f: + with open(topology_path) as f: quartnzet_params = yaml.load(f) - arch = cfg.get('MODEL', 'ARCH') - - model = models[arch](cfg, quartnzet_params) vocab = quartnzet_params['labels'] sample_rate = quartnzet_params['sample_rate'] @@ -79,28 +77,14 @@ def main(): print('================================') - if args.gpu is not None: torch.cuda.set_device(args.gpu) cudnn.benchmark = True - - pretrained_encoder_url = cfg.get('MODEL', 'PRETRAINED_ENCODER_URL') - pretrained_decoder_url = cfg.get('MODEL', 'PRETRAINED_DECODER_URL') - print("=> Loading encoder checkpoint from:'{}'".format(pretrained_encoder_url)) - print("=> Loading decoder checkpoint from:'{}'".format(pretrained_decoder_url)) - if args.gpu is None: - loc = 'cpu' - checkpoint_enc = torch.hub.load_state_dict_from_url(pretrained_encoder_url) - checkpoint_dec = torch.hub.load_state_dict_from_url(pretrained_decoder_url) - else: - # Map model to be loaded to specified single gpu. loc = 'cuda:{}'.format(args.gpu) - checkpoint_enc = torch.hub.load_state_dict_from_url(pretrained_encoder_url, map_location=loc) - checkpoint_dec = torch.hub.load_state_dict_from_url(pretrained_decoder_url, map_location=loc) - - model.restore_checkpoints(checkpoint_enc, checkpoint_dec) + else: + loc = 'cpu' model.to(loc) - + predictions = [] transcripts = [] transcripts_len = [] diff --git a/examples/text_to_speech/README.md b/brevitas_examples/text_to_speech/README.md similarity index 78% rename from examples/text_to_speech/README.md rename to brevitas_examples/text_to_speech/README.md index df07c5240..aacf086f0 100644 --- a/examples/text_to_speech/README.md +++ b/brevitas_examples/text_to_speech/README.md @@ -19,13 +19,12 @@ find /path/to/LJSpeech/LJSpeech-1.1/wavs -type f | head -10 | xargs cp -t /path/ To evaluate a pretrained quantized model on LJSpeech1.1: - - Install the MelGAN requirements with `pip install requirements.txt` - - Make sure you have Brevitas installed - - Preprocess the dataset with `python preprocess_datasets --model-cfg cfg/quant_melgan_8b --data-path /path/to/validation/folder` - - Pass the corresponding cfg .ini file as an input to the evaluation script. The required checkpoint will be downloaded automatically. + - After cloning the repository, install Brevitas and MelGAN requirements with `pip install .[tts]` + - Preprocess the dataset with `brevitas_melgan_preprocess --name quant_melgan_8b --data-path /path/to/validation/folder` + - Pass the name of the model as an input to the evaluation script. The required checkpoint will be downloaded automatically. 
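As with the other examples, the quantized MelGAN generator can also be instantiated from Python through the factory added in `melgan/__init__.py`; a minimal sketch, assuming `pip install .[tts]` has been run:

```python
# Minimal sketch (assumes the tts extra is installed): load the pretrained
# 8-bit quantized MelGAN generator from the bundled quant_melgan_8b config.
from brevitas_examples.text_to_speech import quant_melgan_8b

model = quant_melgan_8b(pretrained=True)
model.eval(inference=True)  # same eval call melgan_val.py makes before inference
```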
For example, for the evaluation on GPU 0: ``` -python melgan_val.py --input-folder /path/to/validation/folder --model-cfg cfg/quant_melgan_8b --gpu 0 +brevitas_melgan_val --input-folder /path/to/validation/folder --model quant_melgan_8b --gpu 0 --pretrained ``` diff --git a/examples/text_to_speech/MelGAN/__init__.py b/brevitas_examples/text_to_speech/__init__.py similarity index 100% rename from examples/text_to_speech/MelGAN/__init__.py rename to brevitas_examples/text_to_speech/__init__.py diff --git a/brevitas_examples/text_to_speech/cfg/__init__.py b/brevitas_examples/text_to_speech/cfg/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/text_to_speech/cfg/quant_melgan_8b.ini b/brevitas_examples/text_to_speech/cfg/quant_melgan_8b.ini similarity index 100% rename from examples/text_to_speech/cfg/quant_melgan_8b.ini rename to brevitas_examples/text_to_speech/cfg/quant_melgan_8b.ini diff --git a/brevitas_examples/text_to_speech/melgan/__init__.py b/brevitas_examples/text_to_speech/melgan/__init__.py new file mode 100644 index 000000000..6bb4c5fff --- /dev/null +++ b/brevitas_examples/text_to_speech/melgan/__init__.py @@ -0,0 +1,30 @@ +import os +from configparser import ConfigParser + +from torch import hub + +from .melgan import * + +model_impl = { + 'melgan': melgan, +} + + +def model_with_cfg(name, pretrained): + cfg = ConfigParser() + current_dir = os.path.dirname(os.path.abspath(__file__)) + config_path = os.path.join(current_dir, '..', 'cfg', name + '.ini') + assert os.path.exists(config_path) + cfg.read(config_path) + arch = cfg.get('MODEL', 'ARCH') + model = model_impl[arch](cfg) + if pretrained: + checkpoint = cfg.get('MODEL', 'PRETRAINED_URL') + state_dict = hub.load_state_dict_from_url(checkpoint, progress=True, map_location='cpu') + model.load_state_dict(state_dict, strict=True) + return model, cfg + + +def quant_melgan_8b(pretrained=True): + model, _ = model_with_cfg('quant_melgan_8b', pretrained) + return model diff --git a/examples/text_to_speech/MelGAN/common.py b/brevitas_examples/text_to_speech/melgan/common.py similarity index 100% rename from examples/text_to_speech/MelGAN/common.py rename to brevitas_examples/text_to_speech/melgan/common.py diff --git a/examples/text_to_speech/MelGAN/generator_brevitas.py b/brevitas_examples/text_to_speech/melgan/generator_brevitas.py similarity index 100% rename from examples/text_to_speech/MelGAN/generator_brevitas.py rename to brevitas_examples/text_to_speech/melgan/generator_brevitas.py diff --git a/examples/text_to_speech/MelGAN/melgan.py b/brevitas_examples/text_to_speech/melgan/melgan.py similarity index 100% rename from examples/text_to_speech/MelGAN/melgan.py rename to brevitas_examples/text_to_speech/melgan/melgan.py diff --git a/examples/text_to_speech/MelGAN/res_stack_brevitas.py b/brevitas_examples/text_to_speech/melgan/res_stack_brevitas.py similarity index 100% rename from examples/text_to_speech/MelGAN/res_stack_brevitas.py rename to brevitas_examples/text_to_speech/melgan/res_stack_brevitas.py diff --git a/examples/text_to_speech/melgan_val.py b/brevitas_examples/text_to_speech/melgan_val.py similarity index 66% rename from examples/text_to_speech/melgan_val.py rename to brevitas_examples/text_to_speech/melgan_val.py index c55b71156..02d1154f7 100644 --- a/examples/text_to_speech/melgan_val.py +++ b/brevitas_examples/text_to_speech/melgan_val.py @@ -4,25 +4,24 @@ import argparse from scipy.io.wavfile import write -from MelGAN import * import torch.backends.cudnn as cudnn import 
brevitas.config +from .melgan import model_with_cfg brevitas.config.IGNORE_MISSING_KEYS = False MAX_WAV_VALUE = 32768.0 -import configparser import random import os SEED = 123456 -models = {'melgan': melgan} parser = argparse.ArgumentParser() parser.add_argument('--input-folder', help='path to folder containing the val folder') -parser.add_argument('--model-cfg', type=str, help='Path to pretrained model .ini configuration file') parser.add_argument('--workers', default=32, type=int, help='number of data loading workers') parser.add_argument('--batch-size', default=16, type=int, help='Minibatch size') parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.') +parser.add_argument('--pretrained', action='store_true', default=True, help='Load pretrained checkpoint') +parser.add_argument('--model', type=str, default='quant_melgan_8b', help='Name of the model') def main(): @@ -30,30 +29,19 @@ def main(): random.seed(SEED) torch.manual_seed(SEED) - assert os.path.exists(args.model_cfg) - cfg = configparser.ConfigParser() - cfg.read(args.model_cfg) - sampling_rate = cfg.getint('AUDIO', 'sampling_rate') - - arch = cfg.get('MODEL', 'ARCH') + model, cfg = model_with_cfg(args.model, args.pretrained) - model = models[arch](cfg) + sampling_rate = cfg.getint('AUDIO', 'sampling_rate') if args.gpu is not None: + loc = 'cuda:{}'.format(args.gpu) torch.cuda.set_device(args.gpu) model = model.cuda(args.gpu) cudnn.benchmark = True - pretrained_url = cfg.get('MODEL', 'PRETRAINED_URL') - print("=> Loading checkpoint from:'{}'".format(pretrained_url)) - if args.gpu is None: - loc = 'cpu' - checkpoint = torch.hub.load_state_dict_from_url(pretrained_url) else: - # Map model to be loaded to specified single gpu. - loc = 'cuda:{}'.format(args.gpu) - checkpoint = torch.hub.load_state_dict_from_url(pretrained_url, map_location=loc) + loc = 'cpu' - model.load_state_dict(checkpoint, strict=True) + model.to(loc) model.eval(inference=True) with torch.no_grad(): diff --git a/examples/text_to_speech/preprocess_dataset.py b/brevitas_examples/text_to_speech/preprocess_dataset.py similarity index 76% rename from examples/text_to_speech/preprocess_dataset.py rename to brevitas_examples/text_to_speech/preprocess_dataset.py index df76caa18..e2204efe5 100644 --- a/examples/text_to_speech/preprocess_dataset.py +++ b/brevitas_examples/text_to_speech/preprocess_dataset.py @@ -4,13 +4,13 @@ import torch import argparse import numpy as np -import configparser +from configparser import ConfigParser -from utilities.stft import TacotronSTFT -from utilities.audio_processing import read_wav_np +from .utilities.stft import TacotronSTFT +from .utilities.audio_processing import read_wav_np -def main(cfg, args): +def preprocess(cfg, args): filter_length = cfg.getint('AUDIO', 'filter_length') hop_length = cfg.getint('AUDIO', 'hop_length') win_length = cfg.getint('AUDIO', 'win_length') @@ -48,16 +48,21 @@ def main(cfg, args): torch.save(mel, melpath) -if __name__ == '__main__': +def main(): parser = argparse.ArgumentParser() - parser.add_argument('-c', '--config', type=str, required=True, - help="yaml file for config.") + parser.add_argument('-n', '--name', type=str, required=True, + help="name of the model") parser.add_argument('-d', '--data-path', type=str, required=True, help="root directory of wav files") args = parser.parse_args() + cfg = ConfigParser() + current_dir = os.path.dirname(os.path.abspath(__file__)) + config_path = os.path.join(current_dir, 'cfg', args.name + '.ini') + assert os.path.exists(config_path) + 
cfg.read(config_path) + + preprocess(cfg, args) - assert os.path.exists(args.config) - cfg = configparser.ConfigParser() - cfg.read(args.config) - main(cfg, args) +if __name__ == '__main__': + main() diff --git a/examples/text_to_speech/utilities/__init__.py b/brevitas_examples/text_to_speech/utilities/__init__.py similarity index 100% rename from examples/text_to_speech/utilities/__init__.py rename to brevitas_examples/text_to_speech/utilities/__init__.py diff --git a/examples/text_to_speech/utilities/audio_processing.py b/brevitas_examples/text_to_speech/utilities/audio_processing.py similarity index 100% rename from examples/text_to_speech/utilities/audio_processing.py rename to brevitas_examples/text_to_speech/utilities/audio_processing.py diff --git a/examples/text_to_speech/utilities/stft.py b/brevitas_examples/text_to_speech/utilities/stft.py similarity index 100% rename from examples/text_to_speech/utilities/stft.py rename to brevitas_examples/text_to_speech/utilities/stft.py diff --git a/examples/imagenet_classification/models/__init__.py b/examples/imagenet_classification/models/__init__.py deleted file mode 100644 index 38de5032b..000000000 --- a/examples/imagenet_classification/models/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .mobilenetv1 import * -from .vgg import * -from .proxylessnas import * \ No newline at end of file diff --git a/examples/speech_to_text/quartznet/__init__.py b/examples/speech_to_text/quartznet/__init__.py deleted file mode 100644 index 27976ead0..000000000 --- a/examples/speech_to_text/quartznet/__init__.py +++ /dev/null @@ -1,31 +0,0 @@ -# Adapted from https://github.com/NVIDIA/NeMo/blob/r0.9/collections/nemo_asr/ -# Copyright (C) 2020 Xilinx (Giuseppe Franco) -# Copyright (C) 2019 NVIDIA CORPORATION. -# -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -# from .audio_preprocessing import AudioToMelSpectrogramPreprocessor -from .data_layer import ( - AudioToTextDataLayer) -from .greedy_ctc_decoder import GreedyCTCDecoder -from .quartznet import quartznet -from .losses import CTCLossNM - -__all__ = ['AudioToTextDataLayer', - 'quartznet'] - - -name = "quarznet_release" diff --git a/examples/speech_to_text/requirements.txt b/examples/speech_to_text/requirements.txt deleted file mode 100644 index b798cee7e..000000000 --- a/examples/speech_to_text/requirements.txt +++ /dev/null @@ -1,34 +0,0 @@ -pillow>=4.3.0 -ipython[all] -tqdm -sox -ruamel.yaml -jupyterlab -tqdm -boto3 -requests -six -ipdb -h5py -html2text -nltk -progressbar -matplotlib -wget -tensorboardX -pandas -onnx -wget -num2words -librosa -inflect -kaldi-io -marshmallow -unidecode -sentencepiece -boto3 -matplotlib -h5py -youtokentome -docrep -torch-stft diff --git a/noxfile.py b/noxfile.py index db67bed70..6b5c88d88 100644 --- a/noxfile.py +++ b/noxfile.py @@ -10,61 +10,79 @@ PYTORCH_CPU_VIRTUAL_PKG = '1.2.0' +NOX_WIN_NUMPY_VERSION = '1.17.4' # avoid errors from more recent Numpy called through Nox on Windows CONDA_PYTHON_IDS = tuple([f'conda_python_{i}' for i in CONDA_PYTHON_VERSIONS]) PYTORCH_IDS = tuple([f'pytorch_{i}' for i in PYTORCH_VERSIONS]) JIT_IDS = tuple([f'{i}'.lower() for i in JIT_STATUSES]) +def install_pytorch(pytorch, session): + is_win = system() == 'Windows' + is_cpu_virtual = version.parse(pytorch) >= version.parse(PYTORCH_CPU_VIRTUAL_PKG) + if is_cpu_virtual and is_win: + session.conda_install('-c', 'pytorch', f'pytorch=={pytorch}', f'numpy=={NOX_WIN_NUMPY_VERSION}', 'cpuonly') + elif is_cpu_virtual and not is_win: + session.conda_install('-c', 'pytorch', f'pytorch=={pytorch}', 'cpuonly') + elif not is_cpu_virtual and is_win: + session.conda_install('-c', 'pytorch', f'pytorch-cpu=={pytorch}', f'numpy=={NOX_WIN_NUMPY_VERSION}') + else: + session.conda_install('-c', 'pytorch', f'pytorch-cpu=={pytorch}') + + +def dry_run_install_pytorch_deps(python, pytorch, session, deps_only): + deps = '--only-deps' if deps_only else '--no-deps' + is_win = system() == 'Windows' + is_cpu_virtual = version.parse(pytorch) >= version.parse(PYTORCH_CPU_VIRTUAL_PKG) + if is_cpu_virtual and is_win: + session.run('conda', 'create', '-n', 'dry_run', deps, '-d', '-c', 'pytorch', f'pytorch=={pytorch}', + f'numpy=={NOX_WIN_NUMPY_VERSION}', 'cpuonly', f'python={python}') + elif is_cpu_virtual and not is_win: + session.run('conda', 'create', '-n', 'dry_run', deps, '-d', '-c', 'pytorch', f'pytorch=={pytorch}', + 'cpuonly', f'python={python}') + elif not is_cpu_virtual and not is_win: + session.run('conda', 'create', '-n', 'dry_run', deps, '-d', '-c', 'pytorch', f'pytorch-cpu=={pytorch}', + 'cpuonly', f'python={python}') + else: + session.run('conda', 'create', '-n', 'dry_run', deps, '-d', '-c', 'pytorch', f'pytorch-cpu=={pytorch}', + f'numpy=={NOX_WIN_NUMPY_VERSION}', f'python={python}') + + @nox.session(venv_backend="conda", python=CONDA_PYTHON_VERSIONS) @nox.parametrize("pytorch", PYTORCH_VERSIONS, ids=PYTORCH_IDS) @nox.parametrize("jit_status", JIT_STATUSES, ids=JIT_IDS) -def tests_cpu(session, pytorch, jit_status): +def tests_brevitas_cpu(session, pytorch, jit_status): session.env['PYTORCH_JIT'] = '{}'.format(int(jit_status == 'jit_enabled')) - if version.parse(pytorch) >= version.parse(PYTORCH_CPU_VIRTUAL_PKG): - session.conda_install('-c', 'pytorch', f'pytorch=={pytorch}', 'cpuonly') - else: - session.conda_install('-c', 'pytorch', f'pytorch-cpu=={pytorch}') + install_pytorch(pytorch, 
session) session.install('.[test]') - session.run('pytest', '-v') + session.run('pytest', 'test/brevitas', '-v') @nox.session(venv_backend="conda", python=CONDA_PYTHON_VERSIONS) @nox.parametrize("pytorch", PYTORCH_VERSIONS, ids=PYTORCH_IDS) -def tests_install_develop(session, pytorch): - if version.parse(pytorch) >= version.parse(PYTORCH_CPU_VIRTUAL_PKG): - session.conda_install('-c', 'pytorch', f'pytorch=={pytorch}', 'cpuonly') - else: - session.conda_install('-c', 'pytorch', f'pytorch-cpu=={pytorch}') - session.install('-e', '.') - if system() == 'Windows': - env = session.env - nox.command.run(['python', '-c', 'import brevitas'], env=env, path=os.path.dirname(session.bin)) - nox.command.run(['python', '-c', 'import brevitas.function.ops_ste'], env=env, - path=os.path.dirname(session.bin)) - else: - session.run('python', '-c', 'import brevitas') - session.run('python', '-c', 'import brevitas.function.ops_ste') +def tests_brevitas_install_dev(session, pytorch): + install_pytorch(pytorch, session) + session.install('-e', '.[test]') + session.run('pytest', '-v', 'test/brevitas/test_import.py') + + +@nox.session(venv_backend="conda", python=CONDA_PYTHON_VERSIONS) +@nox.parametrize("pytorch", PYTORCH_VERSIONS, ids=PYTORCH_IDS) +def tests_brevitas_examples_install_dev(session, pytorch): + install_pytorch(pytorch, session) + session.conda_install('scipy') # For Hadamard example + session.install('-e', '.[test, tts, stt]') + session.run('pytest', '-v', 'test/brevitas_examples/test_import.py') @nox.session(python=False) @nox.parametrize("pytorch", PYTORCH_VERSIONS, ids=PYTORCH_IDS) @nox.parametrize("python", CONDA_PYTHON_VERSIONS, ids=CONDA_PYTHON_IDS) def dry_run_pytorch_only_deps(session, pytorch, python): - if version.parse(pytorch) >= version.parse(PYTORCH_CPU_VIRTUAL_PKG): - session.run('conda', 'create', '-n', 'dry_run', '--only-deps', '-d', '-c', 'pytorch', f'pytorch=={pytorch}', - 'cpuonly', f'python={python}') - else: - session.run('conda', 'create', '-n', 'dry_run', '--only-deps', '-d', '-c', 'pytorch', f'pytorch-cpu=={pytorch}', - f'python={python}') + dry_run_install_pytorch_deps(python, pytorch, session, deps_only=True) @nox.session(python=False) @nox.parametrize("pytorch", PYTORCH_VERSIONS, ids=PYTORCH_IDS) @nox.parametrize("python", CONDA_PYTHON_VERSIONS, ids=CONDA_PYTHON_IDS) def dry_run_pytorch_no_deps(session, pytorch, python): - if version.parse(pytorch) >= version.parse(PYTORCH_CPU_VIRTUAL_PKG): - session.run('conda', 'create', '-n', 'dry_run', '--no-deps', '-d', '-c', 'pytorch', f'pytorch=={pytorch}', - 'cpuonly', f'python={python}') - else: - session.run('conda', 'create', '-n', 'dry_run', '--no-deps', '-d', '-c', 'pytorch', f'pytorch-cpu=={pytorch}', - f'python={python}') + dry_run_install_pytorch_deps(python, pytorch, session, deps_only=False) diff --git a/requirements/requirements-stt.txt b/requirements/requirements-stt.txt new file mode 100644 index 000000000..c820e73c5 --- /dev/null +++ b/requirements/requirements-stt.txt @@ -0,0 +1,10 @@ +pillow>=4.3.0 +ruamel.yaml +requests +librosa +inflect +unidecode +torch-stft +sox +tqdm +soundfile diff --git a/examples/text_to_speech/requirements.txt b/requirements/requirements-tts.txt similarity index 55% rename from examples/text_to_speech/requirements.txt rename to requirements/requirements-tts.txt index f14b66145..075e3e6a0 100644 --- a/examples/text_to_speech/requirements.txt +++ b/requirements/requirements-tts.txt @@ -1,11 +1,7 @@ +soundfile librosa -matplotlib numpy scipy -tensorboardX -torch tqdm pillow pyyaml -soundfile 
-packaging diff --git a/setup.py b/setup.py index 797e53dc6..567c0877e 100644 --- a/setup.py +++ b/setup.py @@ -167,15 +167,29 @@ def run(self): install_requires=read_requirements('requirements.txt'), extras_require={ "Hadamard": read_requirements('requirements-hadamard.txt'), - "test": read_requirements('requirements-test.txt') + "test": read_requirements('requirements-test.txt'), + "tts": read_requirements('requirements-tts.txt'), + "stt": read_requirements('requirements-stt.txt') }, packages=find_packages(), zip_safe=False, ext_modules=get_jittable_extension(), cmdclass={ - 'build_py': BuildPy, - 'build_ext': BuildJittableExtension.with_options(no_python_abi_suffix=True), - 'develop': DevelopInstall, - } - ) + 'build_py': BuildPy, + 'build_ext': BuildJittableExtension.with_options(no_python_abi_suffix=True), + 'develop': DevelopInstall, + }, + package_data={ + 'brevitas_examples': ['*.ini', '*.yaml'], + }, + entry_points={ + 'console_scripts': [ + 'brevitas_imagenet_val = brevitas_examples.imagenet_classification.imagenet_val:main', + 'brevitas_quartznet_val = brevitas_examples.speech_to_text.quartznet_val:main', + 'brevitas_melgan_val = brevitas_examples.text_to_speech.melgan_val:main', + 'brevitas_quartznet_preprocess = brevitas_examples.speech_to_text.get_librispeech_data:main', + 'brevitas_melgan_preprocess = brevitas_examples.text_to_speech.preprocess_dataset:main' + ], + }) + diff --git a/test/common.py b/test/brevitas/common.py similarity index 100% rename from test/common.py rename to test/brevitas/common.py diff --git a/test/generate_quant_input.py b/test/brevitas/generate_quant_input.py similarity index 100% rename from test/generate_quant_input.py rename to test/brevitas/generate_quant_input.py diff --git a/test/test_act_scaling.py b/test/brevitas/test_act_scaling.py similarity index 100% rename from test/test_act_scaling.py rename to test/brevitas/test_act_scaling.py diff --git a/test/test_conv1d.py b/test/brevitas/test_conv1d.py similarity index 100% rename from test/test_conv1d.py rename to test/brevitas/test_conv1d.py diff --git a/test/brevitas/test_import.py b/test/brevitas/test_import.py new file mode 100644 index 000000000..30a571e87 --- /dev/null +++ b/test/brevitas/test_import.py @@ -0,0 +1,44 @@ +# Copyright (c) 2019- Xilinx, Inc (Alessandro Pappalardo) +# Copyright (c) 2016- Facebook, Inc (Adam Paszke) +# Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +# Copyright (c) 2011-2013 NYU (Clement Farabet) +# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +# Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. + +# 3. 
diff --git a/test/common.py b/test/brevitas/common.py
similarity index 100%
rename from test/common.py
rename to test/brevitas/common.py
diff --git a/test/generate_quant_input.py b/test/brevitas/generate_quant_input.py
similarity index 100%
rename from test/generate_quant_input.py
rename to test/brevitas/generate_quant_input.py
diff --git a/test/test_act_scaling.py b/test/brevitas/test_act_scaling.py
similarity index 100%
rename from test/test_act_scaling.py
rename to test/brevitas/test_act_scaling.py
diff --git a/test/test_conv1d.py b/test/brevitas/test_conv1d.py
similarity index 100%
rename from test/test_conv1d.py
rename to test/brevitas/test_conv1d.py
diff --git a/test/brevitas/test_import.py b/test/brevitas/test_import.py
new file mode 100644
index 000000000..30a571e87
--- /dev/null
+++ b/test/brevitas/test_import.py
@@ -0,0 +1,44 @@
+# Copyright (c) 2019- Xilinx, Inc (Alessandro Pappalardo)
+# Copyright (c) 2016- Facebook, Inc (Adam Paszke)
+# Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
+# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
+# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
+# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
+# Copyright (c) 2011-2013 NYU (Clement Farabet)
+# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
+# Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
+# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
+
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# 1. Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+
+# 2. Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+
+# 3. Neither the names of Xilinx, Facebook, Deepmind Technologies, NYU,
+#    NEC Laboratories America and IDIAP Research Institute nor the names
+#    of its contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+
+
+def test_import_brevitas():
+    import brevitas
+    from brevitas.function import ops_ste
diff --git a/test/test_ops.py b/test/brevitas/test_ops.py
similarity index 100%
rename from test/test_ops.py
rename to test/brevitas/test_ops.py
diff --git a/test/test_quant.py b/test/brevitas/test_quant.py
similarity index 100%
rename from test/test_quant.py
rename to test/brevitas/test_quant.py
diff --git a/test/test_transposed_conv1d.py b/test/brevitas/test_transposed_conv1d.py
similarity index 100%
rename from test/test_transposed_conv1d.py
rename to test/brevitas/test_transposed_conv1d.py
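The new test/brevitas/test_import.py replaces the import checks that previously ran inline in the noxfile. Importing brevitas.function.ops_ste goes beyond the bare package import and, given the BuildJittableExtension machinery in setup.py, likely exercises the compiled ops as well. If more modules ever need the same coverage, a parametrized variant is the idiomatic pytest shape; a sketch only, not part of this diff:

    import importlib

    import pytest

    # Hypothetical list; extend as new subpackages need import coverage.
    IMPORT_TARGETS = ['brevitas', 'brevitas.function.ops_ste']

    @pytest.mark.parametrize('module_name', IMPORT_TARGETS)
    def test_import(module_name):
        importlib.import_module(module_name)  # the test fails if the import raises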
diff --git a/test/brevitas_examples/test_import.py b/test/brevitas_examples/test_import.py
new file mode 100644
index 000000000..2b61a7b45
--- /dev/null
+++ b/test/brevitas_examples/test_import.py
@@ -0,0 +1,59 @@
+# Copyright (c) 2019- Xilinx, Inc (Alessandro Pappalardo)
+# Copyright (c) 2016- Facebook, Inc (Adam Paszke)
+# Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
+# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
+# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
+# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
+# Copyright (c) 2011-2013 NYU (Clement Farabet)
+# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
+# Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
+# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
+
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# 1. Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+
+# 2. Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+
+# 3. Neither the names of Xilinx, Facebook, Deepmind Technologies, NYU,
+#    NEC Laboratories America and IDIAP Research Institute nor the names
+#    of its contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+
+
+def test_import_image_classification():
+    from brevitas_examples.imagenet_classification import (
+        quant_mobilenet_v1_4b,
+        quant_proxylessnas_mobile14_hadamard_4b,
+        quant_proxylessnas_mobile14_4b5b,
+        quant_proxylessnas_mobile14_4b)
+
+
+def test_import_tts():
+    from brevitas_examples.text_to_speech import quant_melgan_8b
+
+
+def test_import_stt():
+    from brevitas_examples.speech_to_text import (
+        quant_quartznet_pertensorscaling_8b,
+        quant_quartznet_perchannelscaling_8b,
+        quant_quartznet_perchannelscaling_4b
+    )
\ No newline at end of file
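These example import tests assume the tts and stt extras are installed, as they are in the tests_brevitas_examples_install_dev session. Should they ever need to run in an environment without the audio stack, pytest.importorskip is the usual escape hatch; a sketch under that assumption, not something this diff adds:

    import pytest

    def test_import_stt_quartznet():
        # Skip, rather than fail, when the stt extra's audio deps are absent.
        pytest.importorskip('librosa')
        pytest.importorskip('soundfile')
        from brevitas_examples.speech_to_text import quant_quartznet_perchannelscaling_8b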