-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #40 from raynardj/new_feature
⛈ lightning callbacks
- Loading branch information
Showing
9 changed files
with
389 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
--- | ||
|
||
title: Lightning Callbacks | ||
|
||
|
||
keywords: fastai | ||
sidebar: home_sidebar | ||
|
||
summary: "Thunder, the DIYed <a href='https://pytorch-lightning.readthedocs.io/en/latest/extensions/callbacks.html'>pytorch-lightening callbacks</a>" | ||
description: "Thunder, the DIYed <a href='https://pytorch-lightning.readthedocs.io/en/latest/extensions/callbacks.html'>pytorch-lightening callbacks</a>" | ||
nb_path: "nbs/61_thunder_callbacks.ipynb" | ||
--- | ||
<!-- | ||
################################################# | ||
### THIS FILE WAS AUTOGENERATED! DO NOT EDIT! ### | ||
################################################# | ||
# file to edit: nbs/61_thunder_callbacks.ipynb | ||
# command to build the docs after a change: nbdev_build_docs | ||
--> | ||
|
||
<div class="container" id="notebook-container"> | ||
|
||
{% raw %} | ||
|
||
<div class="cell border-box-sizing code_cell rendered"> | ||
|
||
</div> | ||
{% endraw %} | ||
|
||
{% raw %} | ||
|
||
<div class="cell border-box-sizing code_cell rendered"> | ||
|
||
</div> | ||
{% endraw %} | ||
|
||
{% raw %} | ||
|
||
<div class="cell border-box-sizing code_cell rendered"> | ||
|
||
<div class="output_wrapper"> | ||
<div class="output"> | ||
|
||
<div class="output_area"> | ||
|
||
|
||
<div class="output_markdown rendered_html output_subarea "> | ||
<h4 id="unfreeze" class="doc_header"><code>unfreeze</code><a href="https://github.com/raynardj/forgebox/tree/master/forgebox/thunder/callbacks.py#L15" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>unfreeze</code>()</p> | ||
</blockquote> | ||
<p>unfreeze this module, and its sub modules</p> | ||
|
||
</div> | ||
|
||
</div> | ||
|
||
</div> | ||
</div> | ||
|
||
</div> | ||
{% endraw %} | ||
|
||
{% raw %} | ||
|
||
<div class="cell border-box-sizing code_cell rendered"> | ||
|
||
<div class="output_wrapper"> | ||
<div class="output"> | ||
|
||
<div class="output_area"> | ||
|
||
|
||
<div class="output_markdown rendered_html output_subarea "> | ||
<h4 id="freeze" class="doc_header"><code>freeze</code><a href="https://github.com/raynardj/forgebox/tree/master/forgebox/thunder/callbacks.py#L21" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>freeze</code>()</p> | ||
</blockquote> | ||
<p>freeze this module, and its sub modules</p> | ||
|
||
</div> | ||
|
||
</div> | ||
|
||
</div> | ||
</div> | ||
|
||
</div> | ||
{% endraw %} | ||
|
||
{% raw %} | ||
|
||
<div class="cell border-box-sizing code_cell rendered"> | ||
|
||
<div class="output_wrapper"> | ||
<div class="output"> | ||
|
||
<div class="output_area"> | ||
|
||
|
||
<div class="output_markdown rendered_html output_subarea "> | ||
<h2 id="DataFrameMetricsCallback" class="doc_header"><code>class</code> <code>DataFrameMetricsCallback</code><a href="https://github.com/raynardj/forgebox/tree/master/forgebox/thunder/callbacks.py#L29" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>DataFrameMetricsCallback</code>() :: <code>Callback</code></p> | ||
</blockquote> | ||
<p>A metrics callback keep showing pandas dataframe</p> | ||
|
||
</div> | ||
|
||
</div> | ||
|
||
</div> | ||
</div> | ||
|
||
</div> | ||
{% endraw %} | ||
|
||
{% raw %} | ||
|
||
<div class="cell border-box-sizing code_cell rendered"> | ||
|
||
<div class="output_wrapper"> | ||
<div class="output"> | ||
|
||
<div class="output_area"> | ||
|
||
|
||
<div class="output_markdown rendered_html output_subarea "> | ||
<h4 id="UnfreezeScheduler" class="doc_header"><code>UnfreezeScheduler</code><a href="https://github.com/raynardj/forgebox/tree/master/forgebox/thunder/callbacks.py#L60" class="source_link" style="float:right">[source]</a></h4><blockquote><p><code>UnfreezeScheduler</code>(<strong><code>frozen_epochs</code></strong>:<code>int</code>=<em><code>2</code></em>)</p> | ||
</blockquote> | ||
|
||
</div> | ||
|
||
</div> | ||
|
||
</div> | ||
</div> | ||
|
||
</div> | ||
{% endraw %} | ||
|
||
{% raw %} | ||
|
||
<div class="cell border-box-sizing code_cell rendered"> | ||
|
||
</div> | ||
{% endraw %} | ||
|
||
</div> | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
__version__ = "0.4.4" | ||
__version__ = "0.4.5" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/61_thunder_callbacks.ipynb (unless otherwise specified). | ||
|
||
__all__ = ['unfreeze', 'freeze', 'DataFrameMetricsCallback', 'UnfreezeScheduler'] | ||
|
||
# Cell | ||
import pandas as pd | ||
from ipywidgets import Output | ||
from typing import List, Dict | ||
import copy | ||
import pytorch_lightning as pl | ||
import torch | ||
from torch import nn | ||
|
||
# Cell | ||
def unfreeze(self): | ||
"""unfreeze this module, and its sub modules""" | ||
for p in self.parameters(): | ||
p.requires_grad = True | ||
|
||
|
||
def freeze(self): | ||
"""freeze this module, and its sub modules""" | ||
for p in self.parameters(): | ||
p.requires_grad = False | ||
|
||
nn.Module.unfreeze = unfreeze | ||
nn.Module.freeze = freeze | ||
|
||
class DataFrameMetricsCallback(pl.Callback): | ||
""" | ||
A metrics callback keep showing pandas dataframe | ||
""" | ||
|
||
def __init__(self) -> None: | ||
""" | ||
In Trainer kwargs, passing this arguements along with other callbacks | ||
callbacks = [DataFrameMetricsCallback(),] | ||
""" | ||
self.metrics: List = [] | ||
|
||
def on_fit_start( | ||
self, trainer: pl.Trainer, | ||
pl_module: pl.LightningModule | ||
) -> None: | ||
pl_module.output = Output() | ||
display(pl_module.output) | ||
|
||
def on_validation_epoch_end( | ||
self, trainer: pl.Trainer, | ||
pl_module: pl.LightningModule | ||
) -> None: | ||
metrics_dict = copy.copy(trainer.callback_metrics) | ||
self.metrics.append(dict((k, v.item()) | ||
for k, v in metrics_dict.items())) | ||
pl_module.output.clear_output() | ||
with pl_module.output: | ||
display(pd.DataFrame(self.metrics).tail(10)) | ||
|
||
|
||
def UnfreezeScheduler(frozen_epochs: int = 2): | ||
assert hasattr(pl_module, "top_layers"), "Please define 'top_layers' attributes"+\ | ||
" for pl_module, which will return a list of nn.Module object(s)" | ||
class UnfreezeSchedulerCallback(pl.callbacks.Callback): | ||
""" | ||
Train the top layer for [frozen_epochs] epochs | ||
then un freeze all | ||
""" | ||
|
||
def on_epoch_start(self, trainer, pl_module): | ||
epoch = trainer.current_epoch | ||
|
||
if epoch == 0: | ||
pl_module.freeze() | ||
for tl in pl_module.top_layers: | ||
tl.unfreeze() | ||
if epoch == frozen_epochs: | ||
pl_module.unfreeze() | ||
pl_module.base.embeddings.freeze() |
Oops, something went wrong.