
The WeightWatcher tool for predicting the accuracy of Deep Neural Networks


ehariri/WeightWatcher


Weight Watcher

the Stat Mech Edition

Current Version: 0.2.7

Weight Watcher analyzes the Fat Tails in the weight matrices of Deep Neural Networks (DNNs).

This tool can predict trends in the generalization accuracy of a series of DNNs, such as VGG11, VGG13, ..., or even the entire series of ResNet models, without needing a test set!

This relies on recent research into Heavy (Fat) Tailed Self-Regularization in DNNs.

The tool lets one compute an average capacity, or quality, metric for a series of DNNs trained on the same data but with different hyperparameters, or even different but related architectures. For example, it can predict that VGG19_BN generalizes better than VGG19, and better than VGG16_BN, VGG16, etc.

Types of Capacity Metrics:

There are 2 metrics available. The average log norm is much faster but less accurate. The average weighted alpha is more accurate but much slower, because it must both compute the SVD of each layer weight matrix and then fit the singular values/eigenvalues to a power law.

  • log Norm (default, fast, less accurate)
  • weighted alpha (slow, more accurate)
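To make the difference in cost concrete, here is a minimal NumPy sketch of both per-layer quantities. It is not WeightWatcher's implementation: it assumes the log norm is log10 of the Frobenius norm, takes the eigenvalues of WᵀW (the squared singular values) as the spectrum with no normalization, and swaps in a simple continuous power-law maximum-likelihood estimate with a fixed-quantile `xmin` in place of the tool's own power-law fit. All function names are hypothetical.

```python
import numpy as np

def log_norm(W: np.ndarray) -> float:
    """Fast metric: log10 of the Frobenius norm (no SVD needed)."""
    return float(np.log10(np.linalg.norm(W)))

def eigenvalues(W: np.ndarray) -> np.ndarray:
    """Slow step 1: eigenvalues of W^T W, i.e. the squared singular values of W."""
    return np.linalg.svd(W, compute_uv=False) ** 2

def fit_alpha(evals: np.ndarray, xmin: float) -> float:
    """Slow step 2: continuous power-law MLE on the tail evals >= xmin.

    alpha = 1 + n / sum(ln(x_i / xmin))
    """
    tail = evals[evals >= xmin]
    return 1.0 + tail.size / float(np.sum(np.log(tail / xmin)))

def weighted_alpha(W: np.ndarray, tail_fraction: float = 0.5) -> float:
    """alpha weighted by log10 of the largest eigenvalue (assumed definition)."""
    evals = eigenvalues(W)
    # Assumption: xmin is a fixed quantile, not chosen by goodness of fit.
    xmin = float(np.quantile(evals, 1.0 - tail_fraction))
    return fit_alpha(evals, xmin) * float(np.log10(evals.max()))

rng = np.random.default_rng(0)
W = rng.standard_normal((500, 200))
print(log_norm(W), weighted_alpha(W))
```

The log norm needs a single pass over the weights, while the weighted alpha needs a full SVD plus a fit, which is why it is the slow option.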

Here is an example of the Weighted Alpha capacity metric for all the current pretrained VGG models.

Notice: we did not peek at the ImageNet test data to build this plot.

Frameworks supported

  • Keras
  • PyTorch

Layers supported

  • Dense / Linear / Fully Connected (and Conv1D)
  • Conv2D

Installation

pip install weightwatcher

Usage

Weight Watcher works with both Keras and PyTorch models.

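import weightwatcher as ww
import torchvision.models as models

model = models.vgg16(pretrained=True)  # any trained Keras or PyTorch model
watcher = ww.WeightWatcher(model=model)
results = watcher.analyze()

watcher.get_summary()
watcher.print_results()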

Advanced Usage

The analyze function has several features, described below and in the Demo Notebook:

def analyze(self, model=None, layers=[], min_size=50, max_size=0,
                alphas=False, softranks=True, spectralnorms=True, 
                mp_fit=True,  plot=False):
...

weighted alpha (SLOW)

Power law fit, here with a PyTorch example:

import weightwatcher as ww
import torchvision.models as models

data = []  # collects one summary dict per model
model = models.vgg19_bn(pretrained=True)
watcher = ww.WeightWatcher(model=model)
results = watcher.analyze(alphas=True)
data.append({"name": "vgg19bntorch", "summary": watcher.get_summary()})


### data:
{'name': 'vgg19bntorch',
  'summary': {'lognorm': 0.81850576,
   'lognorm_compound': 0.9365272010550088,
   'alpha': 2.9646726379493287,
   'alpha_compound': 2.847975521455623,
   'alpha_weighted': 1.1588882728052485,
   'alpha_weighted_compound': 1.5002343912892515}},

Capacity Metrics (averages over all layers):

  • lognorm: average log norm, fast

  • alpha_weighted: average weighted alpha, slow

  • alpha: average alpha, not weighted (slow, not as useful)
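Once per-layer values exist, these summary numbers are plain means across layers. A minimal sketch, assuming each layer's results are held in a dict with the hypothetical keys `lognorm`, `alpha`, and `alpha_weighted` (the tool's internal per-layer layout may differ). The compound averages are not reproduced, since their definition is not given here.

```python
from statistics import mean

def summarize(layers: list[dict]) -> dict:
    """Average each capacity metric over all analyzed layers."""
    keys = ("lognorm", "alpha", "alpha_weighted")
    return {k: mean(layer[k] for layer in layers) for k in keys}

# Toy per-layer values, for illustration only (not real model results).
per_layer = [
    {"lognorm": 1.0, "alpha": 3.0, "alpha_weighted": 1.5},
    {"lognorm": 0.5, "alpha": 2.0, "alpha_weighted": 0.5},
]
print(summarize(per_layer))  # {'lognorm': 0.75, 'alpha': 2.5, 'alpha_weighted': 1.0}
```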

Compound averages:

Same as above, but the averages are computed slightly differently. This will be described in an upcoming paper.

Results are also provided for every layer; see the Demo Notebook.

Additional options

filter by layer types

results = watcher.analyze(layers=ww.LAYER_TYPE.CONV1D|ww.LAYER_TYPE.DENSE)
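The `|` in this call suggests the layer types behave as bit flags that can be OR-ed together into one filter. The sketch below illustrates that pattern with Python's `enum.IntFlag`; `LayerType` and `selected` are hypothetical stand-ins, not WeightWatcher's own definitions, and the flag values are arbitrary.

```python
from enum import IntFlag

class LayerType(IntFlag):
    """Hypothetical bit-flag enum in the style of ww.LAYER_TYPE."""
    DENSE = 1
    CONV1D = 2
    CONV2D = 4

def selected(layer_type: LayerType, mask: LayerType) -> bool:
    """True if this layer's type bit is set in the filter mask."""
    return bool(layer_type & mask)

mask = LayerType.CONV1D | LayerType.DENSE
print(selected(LayerType.DENSE, mask))   # True
print(selected(LayerType.CONV2D, mask))  # False
```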

filter by layer ids

results = watcher.analyze(layers=[20])

minimum, maximum size of weight matrix

Sets the minimum and maximum size of the weight matrices analyzed. Setting max_size is useful for quick debugging.

results = watcher.analyze(min_size=50, max_size=500)
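A minimal sketch of how such a size filter can work, assuming (not confirmed here) that "size" means the smaller dimension of the weight matrix and that the default `max_size=0` means no upper limit, as the `analyze` defaults suggest. `keep_matrix` is a hypothetical helper.

```python
def keep_matrix(shape: tuple[int, int], min_size: int = 50, max_size: int = 0) -> bool:
    """Decide whether a weight matrix of the given shape is analyzed.

    Assumptions: size is compared on min(rows, cols); max_size <= 0 disables the cap.
    """
    size = min(shape)
    if size < min_size:
        return False
    if max_size > 0 and size > max_size:
        return False
    return True

print(keep_matrix((4096, 1000)))                             # True
print(keep_matrix((4096, 1000), min_size=50, max_size=500))  # False
print(keep_matrix((64, 10)))                                 # False
```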

plots (for alphas=True)

Create log-log plots for each layer weight matrix to see how well the power law fits work.

results = watcher.analyze(alphas=True, plot=True)
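On log-log axes a power law shows up as a straight line, which is what these plots let you check by eye. As a rough numerical stand-in (not the tool's plotting code), this sketch draws samples from a Pareto distribution with density exponent `alpha`, computes the empirical complementary CDF, and fits a line in log-log space; for a good power-law fit the slope should be close to `-(alpha - 1)`. The helper name is hypothetical.

```python
import numpy as np

def loglog_ccdf_slope(samples: np.ndarray) -> float:
    """Slope of the empirical complementary CDF on log-log axes."""
    x = np.sort(samples)
    n = x.size
    ccdf = (n - np.arange(n)) / n  # P(X >= x_i), from 1 down to 1/n
    slope, _intercept = np.polyfit(np.log10(x), np.log10(ccdf), 1)
    return float(slope)

rng = np.random.default_rng(0)
alpha, xmin = 3.0, 1.0
# numpy's pareto() draws a Lomax variate; shifting by 1 and scaling by xmin gives
# a classical Pareto with density proportional to x^(-alpha) for x >= xmin.
samples = xmin * (1.0 + rng.pareto(alpha - 1.0, size=20_000))
print(loglog_ccdf_slope(samples))  # expected to be near -(alpha - 1) = -2
```

A visibly curved log-log plot, or a fitted slope far from what the estimated alpha implies, is a sign the power-law fit for that layer should not be trusted.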

Links

Demo Notebook

Calculation Consulting homepage

Calculated Content Blog


Implicit Self-Regularization in Deep Neural Networks: Evidence from Random Matrix Theory and Implications for Learning

Traditional and Heavy Tailed Self Regularization in Neural Network Models

Notebook for above 2 papers (https://github.com/CalculatedContent/ImplicitSelfRegularization)

Talk at NERSC Summer 2018


Heavy-Tailed Universality Predicts Trends in Test Accuracies for Very Large Pre-Trained Deep Neural Networks

Notebook for paper (https://github.com/CalculatedContent/PredictingTestAccuracies)

Talk at UC Berkeley/ICSI 12/13/2018

ICML 2019 Theoretical Physics Workshop Paper


KDD2019 Workshop

KDD 2019 Workshop: Statistical Mechanics Methods for Discovering Knowledge from Production-Scale Neural Networks

KDD 2019 Workshop: Slides


Selected Podcasts

Data Science at Home Podcast

Aggregate Intellect Podcast


Latest paper and results

Predicting trends in the quality of state-of-the-art neural networks without access to training or testing data

Repo for latest paper

Talk on latest results, Stanford ICME 2020

How to Release

Publishing to the PyPI repository:

# 1. Check in the latest code with the correct revision number (__version__ in __init__.py)
vi weightwatcher/__init__.py # Increase release number, remove -dev from revision number
git commit
# 2. Check out latest version from the repo in a fresh directory
cd ~/temp/
git clone https://github.com/CalculatedContent/WeightWatcher
cd WeightWatcher/
# 3. Use the latest version of the tools
python -m pip install --upgrade setuptools wheel twine
# 4. Create the package
python setup.py sdist bdist_wheel
# 5. Test the package
twine check dist/*
# 6. Upload the package to PyPI
twine upload dist/*
# 7. Tag/Release in github by creating a new release (https://github.com/CalculatedContent/WeightWatcher/releases/new)

License

Apache License 2.0

Slack Channel

We have a Slack channel for the tool if you need help. For an invite, please send an email to [email protected]

Contributors

Charles H Martin, PhD Calculation Consulting

Serena Peng
