From 027b933757c59fbd882947baa1b42b8dd3b1dc27 Mon Sep 17 00:00:00 2001 From: raynardj Date: Sat, 30 Apr 2022 18:18:18 +0800 Subject: [PATCH] nb cleaning --- nbs/11_etl.ipynb | 174 ---- nbs/61_thunder_callbacks.ipynb | 151 ---- nbs/70_hf_transformer_data.ipynb | 553 ------------- nbs/72_pl_training.ipynb | 466 ----------- nbs/cross_entropy_weighter.ipynb | 1298 ------------------------------ nbs/optimizers.ipynb | 373 --------- 6 files changed, 3015 deletions(-) delete mode 100644 nbs/11_etl.ipynb delete mode 100644 nbs/61_thunder_callbacks.ipynb delete mode 100644 nbs/70_hf_transformer_data.ipynb delete mode 100644 nbs/72_pl_training.ipynb delete mode 100644 nbs/cross_entropy_weighter.ipynb delete mode 100644 nbs/optimizers.ipynb diff --git a/nbs/11_etl.ipynb b/nbs/11_etl.ipynb deleted file mode 100644 index 711b499..0000000 --- a/nbs/11_etl.ipynb +++ /dev/null @@ -1,174 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# ETL Helper\n", - "> A combined ETL tool sets" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "# default_exp etl" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## A list for each step" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [], - "source": [ - "# export\n", - "\n", - "from forgebox.imports import *\n", - "from typing import Callable, Union\n", - "import traceback as tb\n", - "import logging\n", - "\n", - "class NewFileScanner:\n", - " \"\"\"\n", - " Keep scannning a directory for new file\n", - " if found any, callback processing the file\n", - " \n", - " Designed for download and delete strategy\n", - " \n", - " # Example\n", - " new_file_scanner = NewFileScanner(\".\", new_file_filter=lambda x:x[-4:]==\".txt\")\n", - " new_file_scanner(lambda x:print(f\"new file:{x} found\"))s\n", - " \"\"\"\n", - " def __init__(\n", - " self,\n", - " directory,\n", - " new_file_filter: Callable=None,\n", - " done_list_getter: Callable=None,\n", - " stop_callback: Callable=None,\n", - " ):\n", - " self.directory = Path(directory)\n", - " \n", - " if new_file_filter is not None:\n", - " self.new_file_filter=new_file_filter\n", - " else:\n", - " self.new_file_filter=lambda x:True\n", - " \n", - " if done_list_getter is not None:\n", - " self.done_list_getter = done_list_getter\n", - " else:\n", - " self.done_list_getter = lambda *args:[]\n", - " \n", - " if stop_callback is not None:\n", - " self.stop_callback = stop_callback\n", - " \n", - " self.processed = []\n", - " \n", - " def __call__(self, process: Callable):\n", - " while True:\n", - " try:\n", - " done_list = self.done_list_getter()\n", - " for fname in self.directory.iterdir():\n", - " fname = str(fname)\n", - " if self.new_file_filter(fname) != True:\n", - " continue\n", - " file_path = str(self.directory/fname)\n", - " if fname not in done_list:\n", - " if fname not in self.processed:\n", - " result = process(file_path)\n", - " self.processed.append(fname)\n", - " except KeyboardInterrupt as e:\n", - " if hasattr(self, \"stop_callback\"):\n", - " self.stop_callback()\n", - " logging.error(\"manually stoped\")\n", - " break\n", - " except Exception as e:\n", - " error = tb.format_exc()\n", - " logging.error(error)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Test new file scanner" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "new file:untitled.txt found\n", - "new file:bc.txt found\n", - "new file:cde.txt found\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "ERROR:root:manually stoped\n" - ] - } - ], - "source": [ - "new_file_scanner = NewFileScanner(\".\", new_file_filter=lambda x:x[-4:]==\".txt\")\n", - "new_file_scanner(lambda x:print(f\"new file:{x} found\"))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.4" - }, - "toc": { - "base_numbering": 1, - "nav_menu": {}, - "number_sections": true, - "sideBar": true, - "skip_h1_title": false, - "title_cell": "Table of Contents", - "title_sidebar": "Contents", - "toc_cell": false, - "toc_position": {}, - "toc_section_display": true, - "toc_window_display": false - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/nbs/61_thunder_callbacks.ipynb b/nbs/61_thunder_callbacks.ipynb deleted file mode 100644 index 848e47f..0000000 --- a/nbs/61_thunder_callbacks.ipynb +++ /dev/null @@ -1,151 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Lightning Callbacks\n", - "> Thunder, the DIYed [pytorch-lightening callbacks](https://pytorch-lightning.readthedocs.io/en/latest/extensions/callbacks.html)" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "# default_exp thunder.callbacks" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "# export\n", - "import pandas as pd\n", - "from ipywidgets import Output\n", - "from typing import List, Dict\n", - "import copy\n", - "import pytorch_lightning as pl\n", - "import torch\n", - "from torch import nn" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "# export\n", - "def unfreeze(self):\n", - " \"\"\"unfreeze this module, and its sub modules\"\"\"\n", - " for p in self.parameters():\n", - " p.requires_grad = True\n", - "\n", - "\n", - "def freeze(self):\n", - " \"\"\"freeze this module, and its sub modules\"\"\"\n", - " for p in self.parameters():\n", - " p.requires_grad = False\n", - "\n", - "nn.Module.unfreeze = unfreeze\n", - "nn.Module.freeze = freeze\n", - "\n", - "class DataFrameMetricsCallback(pl.Callback):\n", - " \"\"\"\n", - " A metrics callback keep showing pandas dataframe\n", - " \"\"\"\n", - "\n", - " def __init__(self) -> None:\n", - " \"\"\"\n", - " In Trainer kwargs, passing this arguements along with other callbacks\n", - " callbacks = [DataFrameMetricsCallback(),]\n", - " \"\"\"\n", - " self.metrics: List = []\n", - "\n", - " def on_fit_start(\n", - " self, trainer: pl.Trainer,\n", - " pl_module: pl.LightningModule\n", - " ) -> None:\n", - " pl_module.output = Output()\n", - " display(pl_module.output)\n", - "\n", - " def on_validation_epoch_end(\n", - " self, trainer: pl.Trainer,\n", - " pl_module: pl.LightningModule\n", - " ) -> None:\n", - " metrics_dict = copy.copy(trainer.callback_metrics)\n", - " self.metrics.append(dict((k, v.item())\n", - " for k, v in metrics_dict.items()))\n", - " pl_module.output.clear_output()\n", - " with pl_module.output:\n", - " display(pd.DataFrame(self.metrics).tail(10))\n", - "\n", - "\n", - "def UnfreezeScheduler(frozen_epochs: int = 2):\n", - " assert hasattr(pl_module, \"top_layers\"), \"Please define 'top_layers' attributes\"+\\\n", - " \" for pl_module, which will return a list of nn.Module object(s)\"\n", - " class UnfreezeSchedulerCallback(pl.callbacks.Callback):\n", - " \"\"\"\n", - " Train the top layer for [frozen_epochs] epochs\n", - " then un freeze all\n", - " \"\"\"\n", - "\n", - " def on_epoch_start(self, trainer, pl_module):\n", - " epoch = trainer.current_epoch\n", - "\n", - " if epoch == 0:\n", - " pl_module.freeze()\n", - " for tl in pl_module.top_layers:\n", - " tl.unfreeze()\n", - " if epoch == frozen_epochs:\n", - " pl_module.unfreeze()\n", - " pl_module.base.embeddings.freeze()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.4" - }, - "toc": { - "base_numbering": 1, - "nav_menu": {}, - "number_sections": true, - "sideBar": true, - "skip_h1_title": false, - "title_cell": "Table of Contents", - "title_sidebar": "Contents", - "toc_cell": false, - "toc_position": {}, - "toc_section_display": true, - "toc_window_display": false - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/nbs/70_hf_transformer_data.ipynb b/nbs/70_hf_transformer_data.ipynb deleted file mode 100644 index ac7b569..0000000 --- a/nbs/70_hf_transformer_data.ipynb +++ /dev/null @@ -1,553 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Data parts for hf transformers" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "# default_exp hf.data" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "# export\n", - "from forgebox.imports import *\n", - "from forgebox.category import Category\n", - "from typing import List, Dict, Callable, Any, Tuple" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Process IOBES files" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# export\n", - "def convert_iob2_file_to_iobes(file_path, result_path):\n", - " \"\"\"\n", - " Convert IOB2 file to IOBES\n", - " \"\"\"\n", - " with open(file_path, 'r') as f:\n", - " lines = f.readlines()\n", - " with open(result_path, 'w') as f:\n", - " for line in lines:\n", - " line = line.strip()\n", - " if line == '':\n", - " f.write('\\n')\n", - " continue\n", - " line = line.split()\n", - " if line[-1] == 'O':\n", - " f.write(' '.join(line) + '\\n')\n", - " else:\n", - " f.write(' '.join(line[:-1]) + ' ' + line[-1] + '\\n')\n", - "\n", - "\n", - "def conbine_iobes_file(\n", - " file_paths: List[Path],\n", - " new_file_path: Path\n", - "):\n", - " \"\"\"\n", - " Conbine from multiple IOBES files\n", - " into IOBES files\n", - " \"\"\"\n", - " with open(new_file_path, 'w') as new_file:\n", - " for file_path in file_paths:\n", - " with open(file_path, 'r') as file:\n", - " for line in file:\n", - " new_file.write(line)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# export\n", - "class IOBES(Dataset):\n", - " \"\"\"\n", - " Load iobes file for NER training task\n", - " \"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " file_path,\n", - " tokenizer,\n", - " max_len=128,\n", - " save_buffer: int = 15,\n", - " category: Category = None,\n", - " return_string: bool = False,\n", - " use_frag: bool = False,\n", - " ):\n", - " \"\"\"\n", - " file_path,\n", - " tokenizer,\n", - " max_len=128,\n", - " save_buffer: int = 15,\n", - " category: Category = None,\n", - " label categories, if set to None, will be figured out\n", - " automatically.\n", - " You can set this to None for train dataset, but for valid\n", - " dataset:\n", - " valid_ds = IOBES(...,category=train_ds.cates)\n", - " return_string: bool = False, do we return original string\n", - " for tokenizer output, this option is good for debuging\n", - " but the data won't pass into cuda if choose so\n", - " use_frag: bool = False, do we use prepend like 'I-','B-'\n", - " \"\"\"\n", - " self.file_path = file_path\n", - " self.max_len = max_len\n", - " self.pairs = []\n", - " self.list_of_words = []\n", - " self.list_of_labels = []\n", - " self.tokenizer = tokenizer\n", - " self.cates = category\n", - " self.return_string = return_string\n", - " self.use_frag = use_frag\n", - " self.load_data(save_buffer)\n", - "\n", - " def load_data(self, save_buffer: int = 15):\n", - " \"\"\"\n", - " Load file in to object structure\n", - " \"\"\"\n", - " with open(self.file_path, 'r') as f:\n", - " for line in f:\n", - " line = line.strip()\n", - " if line:\n", - " splited = line.split()\n", - " if len(splited) != 2:\n", - " continue\n", - " word, label = splited\n", - " # do we use 'I-', 'B-' etc\n", - " if self.use_frag is False:\n", - " if \"-\" in label:\n", - " label = label.split('-')[1]\n", - " self.pairs.append([word, label])\n", - "\n", - " self.pairs = np.array(self.pairs)\n", - "\n", - " if self.cates is None:\n", - " labels_df = pd.DataFrame({\"label\": self.pairs[:, 1]})\n", - " self.cates = Category(list(labels_df.vc(\"label\").index))\n", - "\n", - " self.batching_words(save_buffer)\n", - "\n", - " def batching_words(self, save_buffer: int = 15):\n", - " \"\"\"\n", - " batching self.words into self.list_of_words\n", - " by self.max_len -15\n", - " \"\"\"\n", - " for i in range(0, len(self.pairs), self.max_len-save_buffer):\n", - " chunk_slice = slice(i, i+self.max_len-save_buffer)\n", - " self.list_of_words.append(self.pairs[chunk_slice, 0])\n", - " self.list_of_labels.append(self.pairs[chunk_slice, 1])\n", - "\n", - " def __len__(self) -> int:\n", - " return len(self.list_of_words)\n", - "\n", - " def __getitem__(self, idx: int) -> Tuple[List[str]]:\n", - " return list(self.list_of_words[idx]), list(self.list_of_labels[idx])\n", - "\n", - " def __repr__(self):\n", - " return f\"\"\"NER dataset using IOBES annotation\n", - " {len(self)} sentences,\n", - " Labels:\n", - " {list(self.cates.i2c)}\n", - " \"\"\"\n", - "\n", - " def collate_fn(self, data):\n", - " \"\"\"\n", - " data: list of tuple\n", - " \"\"\"\n", - " words, text_labels = zip(*data)\n", - "\n", - " inputs = self.tokenizer(\n", - " list(words),\n", - " return_tensors='pt',\n", - " padding=True,\n", - " truncation=True,\n", - " max_length=self.max_len,\n", - " is_split_into_words=True,\n", - " return_offsets_mapping=True,\n", - " add_special_tokens=False,\n", - " )\n", - " return self.align_offsets(inputs, text_labels, words)\n", - "\n", - " def align_offsets(\n", - " self,\n", - " inputs,\n", - " text_labels: List[List[str]],\n", - " words: List[List[str]]\n", - " ):\n", - " \"\"\"\n", - " inputs: output if tokenizer\n", - " text_labels: labels in form of list of list of strings\n", - " words: words in form of list of list of strings\n", - " \"\"\"\n", - " labels = torch.zeros_like(inputs.input_ids).long()\n", - " labels -= 100\n", - " text_lables_array = np.empty(labels.shape, dtype=object)\n", - " words_array = np.empty(labels.shape, dtype=object)\n", - " max_len = inputs.input_ids.shape[1]\n", - "\n", - " for row_id, input_ids in enumerate(inputs.input_ids):\n", - " word_pos = inputs.word_ids(row_id)\n", - " for idx, pos in enumerate(word_pos):\n", - " if pos is None:\n", - " continue\n", - " if pos <= max_len:\n", - " labels[row_id, idx] = self.cates.c2i[text_labels[row_id][pos]]\n", - " if self.return_string:\n", - " text_lables_array[row_id,\n", - " idx] = text_labels[row_id][pos]\n", - " words_array[row_id, idx] = words[row_id][pos]\n", - "\n", - " inputs['labels'] = labels\n", - " if self.return_string:\n", - " inputs['text_labels'] = text_lables_array.tolist()\n", - " inputs['word'] = words_array.tolist()\n", - " return inputs\n", - "\n", - " def dataloader(self, batch_size: int = 32, shuffle: bool = True):\n", - " \"\"\"\n", - " Create dataloader\n", - " \"\"\"\n", - " return DataLoader(\n", - " self,\n", - " batch_size=batch_size,\n", - " shuffle=shuffle,\n", - " collate_fn=self.collate_fn,\n", - " )\n", - "\n", - " def one_batch(self, batch_size: int = 32, shuffle: bool = True):\n", - " return next(iter(self.dataloader(batch_size, shuffle)))\n", - "\n", - " def visualize_batch(self, batch, row_idx=0):\n", - " return list(zip(self.tokenizer.convert_ids_to_tokens(batch.input_ids[row_idx]),\n", - " batch.labels[row_idx].numpy(),\n", - " batch.text_labels[row_idx],\n", - " batch.word[row_idx],\n", - " batch.offset_mapping[row_idx].numpy(),\n", - " ))\n", - "\n", - " def set_hfconfig(self, config):\n", - " \"\"\"\n", - " set the category information to huggingface config\n", - " \"\"\"\n", - " config.num_labels = len(self.cates)\n", - " config.id2label = {i: label for i, label in enumerate(self.cates.i2c)}\n", - " config.label2id = {label: i for i, label in enumerate(self.cates.i2c)}" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "from transformers import AutoTokenizer" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "tokenizer = AutoTokenizer.from_pretrained(\"raynardj/roberta-pubmed\", add_prefix_space=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "dataset = IOBES(\"/Users/xiaochen.zhang/data/valid.iobes\", tokenizer)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "in-O\n", - "blood-O\n", - ";-O\n", - "content-O\n", - "of-O\n", - "cAMP-O\n", - "was-O\n", - "also-O\n", - "decreased-O\n", - "in-O\n", - "lymphocytes-O\n", - "by-O\n", - "33-O\n", - "%-O\n", - ".-O\n", - "At-O\n", - "the-O\n", - "same-O\n", - "time-O\n", - ",-O\n", - "total-O\n", - "content-O\n", - "of-O\n", - "T-cell_type\n", - "lymphocytes-cell_type\n", - "was-O\n", - "decreased-O\n", - "1.5-fold-O\n", - "in-O\n", - "peripheric-O\n", - "blood-O\n", - ".-O\n", - "Treatment-O\n", - "with-O\n", - "I-hydroxyvitamin-O\n", - "D3-O\n", - "(-O\n", - "1-1.5-O\n", - "mg-O\n", - "daily-O\n", - ",-O\n", - "within-O\n", - "4-O\n", - "weeks-O\n", - ")-O\n", - "led-O\n", - "to-O\n", - "normalization-O\n", - "of-O\n", - "total-O\n", - "and-O\n", - "ionized-O\n", - "form-O\n", - "of-O\n", - "Ca2+-O\n", - "and-O\n", - "of-O\n", - "25-O\n", - "(-O\n", - "OH-O\n", - ")-O\n", - "D-O\n", - ",-O\n", - "but-O\n", - "did-O\n", - "not-O\n", - "affect-O\n", - "the-O\n", - "PTH-O\n", - "content-O\n", - "in-O\n", - "blood-O\n", - ".-O\n", - "Concentration-O\n", - "of-O\n", - "the-O\n", - "receptors-protein\n", - "to-O\n", - "1.25-O\n", - "(-O\n", - "OH-O\n", - ")-O\n", - "2D3-O\n", - "was-O\n", - "elevated-O\n", - "up-O\n", - "to-O\n", - "39.7-O\n", - "fmole/mg-O\n", - "after-O\n", - "I-O\n", - "week-O\n", - "of-O\n", - "the-O\n", - "treatment-O\n", - ",-O\n", - "whereas-O\n", - "it-O\n", - "was-O\n", - "decreased-O\n", - "to-O\n", - "the-O\n", - "initial-O\n", - "level-O\n", - "24.8-O\n", - "fmole/mg-O\n", - "within-O\n", - "4-O\n", - "weeks-O\n", - ";-O\n", - "simultaneous-O\n", - "alteration-O\n", - "in-O\n" - ] - } - ], - "source": [ - "for w,l in zip(*dataset[2]):\n", - " print(f\"{w}-{l}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'input_ids': tensor([[ 19, 3741, 2603, ..., 1417, 2617, 11576],\n", - " [ 4590, 2156, 255, ..., 405, 1182, 6608],\n", - " [ 6214, 25683, 3809, ..., 11, 5, 8151],\n", - " ...,\n", - " [13998, 25326, 2413, ..., 5, 2199, 21],\n", - " [11299, 705, 24811, ..., 134, 1589, 2032],\n", - " [ 5804, 924, 14, ..., 366, 1168, 9]]), 'attention_mask': tensor([[1, 1, 1, ..., 1, 1, 1],\n", - " [1, 1, 1, ..., 1, 1, 1],\n", - " [1, 1, 1, ..., 1, 1, 1],\n", - " ...,\n", - " [1, 1, 1, ..., 1, 1, 1],\n", - " [1, 1, 1, ..., 1, 1, 1],\n", - " [1, 1, 1, ..., 1, 1, 1]]), 'offset_mapping': tensor([[[ 1, 4],\n", - " [ 1, 2],\n", - " [ 2, 5],\n", - " ...,\n", - " [ 3, 5],\n", - " [ 5, 8],\n", - " [ 1, 6]],\n", - "\n", - " [[ 1, 5],\n", - " [ 1, 1],\n", - " [ 1, 1],\n", - " ...,\n", - " [ 5, 7],\n", - " [ 7, 9],\n", - " [ 9, 14]],\n", - "\n", - " [[ 1, 5],\n", - " [ 5, 8],\n", - " [ 8, 10],\n", - " ...,\n", - " [ 1, 2],\n", - " [ 1, 3],\n", - " [ 1, 10]],\n", - "\n", - " ...,\n", - "\n", - " [[ 1, 5],\n", - " [ 5, 8],\n", - " [ 8, 10],\n", - " ...,\n", - " [ 1, 3],\n", - " [ 1, 7],\n", - " [ 1, 3]],\n", - "\n", - " [[ 1, 5],\n", - " [ 5, 6],\n", - " [ 6, 10],\n", - " ...,\n", - " [ 2, 3],\n", - " [ 1, 1],\n", - " [ 1, 2]],\n", - "\n", - " [[ 1, 7],\n", - " [ 1, 5],\n", - " [ 1, 4],\n", - " ...,\n", - " [ 3, 5],\n", - " [ 5, 7],\n", - " [ 1, 2]]]), 'labels': tensor([[0, 1, 1, ..., 0, 0, 0],\n", - " [2, 0, 2, ..., 0, 0, 0],\n", - " [0, 0, 0, ..., 0, 0, 0],\n", - " ...,\n", - " [1, 1, 1, ..., 0, 0, 0],\n", - " [0, 0, 0, ..., 2, 0, 2],\n", - " [0, 0, 0, ..., 0, 0, 0]])}" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dataset.one_batch()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.4" - }, - "toc": { - "base_numbering": 1, - "nav_menu": {}, - "number_sections": true, - "sideBar": true, - "skip_h1_title": false, - "title_cell": "Table of Contents", - "title_sidebar": "Contents", - "toc_cell": false, - "toc_position": {}, - "toc_section_display": true, - "toc_window_display": true - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/nbs/72_pl_training.ipynb b/nbs/72_pl_training.ipynb deleted file mode 100644 index 3ce4854..0000000 --- a/nbs/72_pl_training.ipynb +++ /dev/null @@ -1,466 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Pytorch Lighting training\n", - "> on huggingface transformers" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "# default_exp hf.train" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "# export\n", - "from forgebox.hf.data import IOBES\n", - "from forgebox.imports import *\n", - "from forgebox.loop import chunkify\n", - "import pytorch_lightning as pl\n", - "from transformers import (\n", - " AutoModelForTokenClassification,\n", - " AutoTokenizer,\n", - " pipeline\n", - ")\n", - "from tqdm.notebook import tqdm\n", - "from typing import Callable, List\n", - "from torch import device" - ] - }, - { - "cell_type": "code", - "execution_count": 290, - "metadata": {}, - "outputs": [], - "source": [ - "# export\n", - "try:\n", - " ishell = get_ipython()\n", - " IS_JUPYTER = True\n", - " from tqdm.notebook import tqdm\n", - "except NameError:\n", - " IS_JUPYTER = False\n", - " from tqdm import tqdm" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "# !pip install transformers==4.9.1\n", - "# !pip install pytorch-lightning==1.3.8\n", - "# !pip install tensorflow==2.2.0" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Load model and tokenizer" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [], - "source": [ - "# export\n", - "\n", - "# ner model and tokenizer\n", - "def ner_model_from(\n", - " name:str, dataset: IOBES\n", - "):\n", - " \"\"\"\n", - " name: from_pretrain(name)\n", - " \"\"\"\n", - " model = AutoModelForTokenClassification.from_pretrained(\n", - " name,\n", - " num_labels=len(dataset.cates),\n", - " )\n", - " dataset.set_hfconfig(model.config)\n", - " return model\n", - "\n", - "def ner_tokenizer_from(\n", - " name: str\n", - "):\n", - " return AutoTokenizer.from_pretrained(\n", - " name, add_prefix_space=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Lightning data module" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "# export\n", - "\n", - "# ner data module\n", - "class NERDataModule(pl.LightningDataModule):\n", - " def __init__(self, train_ds, val_ds, batch_size=32):\n", - " super().__init__()\n", - " self.train_ds = train_ds\n", - " self.val_ds = val_ds\n", - " self.batch_size = batch_size\n", - "\n", - " def train_dataloader(self):\n", - " return self.train_ds.dataloader(batch_size=self.batch_size, shuffle=True)\n", - "\n", - " def val_dataloader(self):\n", - " return self.val_ds.dataloader(batch_size=self.batch_size*2, shuffle=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "# export\n", - "\n", - "# ner module\n", - "class NERModule(pl.LightningModule):\n", - " \"\"\"\n", - " PyTorch lightning module for training ner model\n", - " \"\"\"\n", - " def __init__(\n", - " self, model,\n", - " ):\n", - " \"\"\"\n", - " model: huggingface transformer model for ner\n", - " \"\"\"\n", - " super().__init__()\n", - " self.model = model\n", - "\n", - " def forward(self, batch):\n", - " return self.model(\n", - " input_ids=batch['input_ids'],\n", - " attention_mask=batch['attention_mask'],\n", - " labels=batch['labels'])\n", - " \n", - " def training_step(self, batch, batch_idx):\n", - " outputs = self(batch)\n", - " loss = outputs.loss\n", - " self.log(\"loss\", loss)\n", - " self.log(\"acc\", self.calcualte_acc(outputs, batch.labels))\n", - " return loss\n", - "\n", - " def validation_step(self, batch, batch_idx):\n", - " outputs = self(batch)\n", - " loss = outputs.loss\n", - " self.log(\"val_loss\", loss)\n", - " self.log(\"val_acc\", self.calcualte_acc(outputs, batch.labels))\n", - " return loss\n", - " \n", - " def calcualte_acc(self, outputs, labels):\n", - " pred_idx = outputs.logits.argmax(-1)\n", - " mask = torch.ones_like(pred_idx)\n", - " mask[labels==-100]=False\n", - " return (pred_idx[mask]==labels[mask]).float().mean()\n", - " \n", - " def configure_optimizers(self):\n", - " # discriminative learning rate\n", - " param_groups = [\n", - " {'params': self.model.roberta.parameters(), 'lr': 5e-6},\n", - " {'params': self.model.classifier.parameters(), 'lr': 1e-3},\n", - " ]\n", - " optimizer = torch.optim.Adam(param_groups, lr=1e-3)\n", - " return optimizer" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Enhance pipeline" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "metadata": {}, - "outputs": [], - "source": [ - "# export\n", - "def clean_ner_output(self, outputs):\n", - " \"\"\"\n", - " Cleaning output for NER task\n", - " \"\"\"\n", - " results = []\n", - " current = []\n", - " last_idx = 0\n", - " # make to sub group by position\n", - " for output in outputs:\n", - " if output[\"start\"] in [last_idx, last_idx-1]:\n", - " current.append(output)\n", - " else:\n", - " results.append(current)\n", - " current = [output, ]\n", - " last_idx = output[\"end\"]\n", - " if len(current) > 0:\n", - " results.append(current)\n", - "\n", - " # from tokens to string\n", - " strings = []\n", - " for c in results:\n", - " tokens = []\n", - " starts = []\n", - " ends = []\n", - " for o in c:\n", - " tokens.append(o['word'])\n", - " starts.append(o['start'])\n", - " ends.append(o['end'])\n", - "\n", - " new_str = self.tokenizer.convert_tokens_to_string(tokens)\n", - " if new_str != '':\n", - " strings.append(dict(\n", - " word=new_str,\n", - " start=min(starts),\n", - " end=max(ends),\n", - " entity=c[0]['entity_group']\n", - " ))\n", - " return strings" - ] - }, - { - "cell_type": "code", - "execution_count": 281, - "metadata": {}, - "outputs": [], - "source": [ - "# export\n", - "class NERInference:\n", - " \"\"\"\n", - " NER Inference pipeline\n", - " ner = NERInference.from_pretrained('xxxx/xxxx')\n", - " ner.predict(['text1','text2'])\n", - " \"\"\"\n", - "\n", - " def __init__(self, model, tokenizer, name=None):\n", - " super().__init__()\n", - " self.model = model.eval()\n", - " self.tokenizer = tokenizer\n", - " self.name = name if name else \"NER model\"\n", - "\n", - " def __repr__(self):\n", - " return f\"[NERInference on {self.name}]\"\n", - "\n", - " def to(self, device_str):\n", - " self.model = self.model.to(device(device_str))\n", - " return self\n", - "\n", - " @classmethod\n", - " def from_pretrained(cls, tag):\n", - " \"\"\"\n", - " Load from pretrained model and tokenizer\n", - " \"\"\"\n", - " model = AutoModelForTokenClassification.from_pretrained(tag)\n", - " tokenizer = AutoTokenizer.from_pretrained(tag)\n", - " return cls(model=model, tokenizer=tokenizer, name=model.config._name_or_path)\n", - " \n", - " def __call__(self, data, batch_size=32, dev=device(\"cpu\")):\n", - " if type(data) == str:\n", - " return self.batch_predict([data,])\n", - " else:\n", - " return self.predict(data, dev=dev, batch_size=batch_size)\n", - "\n", - " def predict(\n", - " self,\n", - " texts: List[str],\n", - " dev=device(\"cpu\"),\n", - " batch_size: int = 32,\n", - " progress_bar: bool = True\n", - " ) -> pd.DataFrame:\n", - " \"\"\"\n", - " Predict a list of sentences/ paragraphs\n", - " \"\"\"\n", - " # place the model into device\n", - " self.model = self.model.to(dev)\n", - " iterator = list(enumerate(chunkify(texts, bs=batch_size)))\n", - " if progress_bar:\n", - " iterator = tqdm(iterator, leave=False)\n", - "\n", - " # run through iterator\n", - " all_dfs = []\n", - " for i, text_b in iterator:\n", - " # by batch prediction\n", - " batch_df = self.batch_predict(text_b)\n", - " if len(batch_df) > 0:\n", - " # calculate the row number\n", - " batch_df['text_id'] = batch_df.apply(\n", - " lambda row: i*batch_size+row.batch_row_sn, axis=1)\n", - " all_dfs.append(batch_df)\n", - "\n", - " # place the model back to cpu\n", - " self.model = self.model.to(\"cpu\")\n", - " return pd.concat(all_dfs).reset_index(drop=True)\n", - " \n", - " def tokenizing(self, texts):\n", - " inputs = self.tokenizer(\n", - " texts,\n", - " padding=\"max_length\",\n", - " max_length=self.tokenizer.model_max_length,\n", - " return_attention_mask=True,\n", - " return_tensors='pt', truncation=True, return_offsets_mapping=True\n", - " ).to(self.model.device)\n", - " return inputs\n", - "\n", - "\n", - " def batch_predict(self, texts:List[str])-> pd.DataFrame:\n", - " \"\"\"\n", - " Predict a single batch of sentences\n", - " \"\"\"\n", - " id2label = self.model.config.id2label\n", - " inputs = self.tokenizing(texts)\n", - "\n", - " with torch.no_grad():\n", - " outputs = self.model(input_ids=inputs.input_ids,\n", - " attention_mask=inputs.attention_mask)\n", - " inputs = inputs.to(device('cpu'))\n", - "\n", - " pred_idx = outputs.logits.argmax(-1).to(device(\"cpu\"))\n", - " batch_size = pred_idx.size(0)\n", - " offsets = inputs.offset_mapping\n", - " results = []\n", - " for bi in range(batch_size):\n", - " text = texts[bi]\n", - " input_ids = inputs.input_ids[bi]\n", - " word_ids = inputs.word_ids(bi)\n", - " pred_ids = pred_idx[bi]\n", - " # initial values for the row\n", - " last_pos = 0\n", - " previous_has_positive = False\n", - " current_start = 0\n", - " current_index = 0\n", - " current_id = 0\n", - " line = []\n", - " for ti in range(1, len(input_ids)):\n", - " if input_ids[ti] == self.tokenizer.sep_token_id:\n", - " break\n", - " # is the current token an appending sub-word?\n", - " if word_ids[ti] == last_pos:\n", - " pass\n", - " # is current token negative\n", - " elif pred_ids[ti].item() == 0:\n", - " # store the previous hanging prediction\n", - " if previous_has_positive:\n", - " start = current_start\n", - " end = offsets[bi, ti, 0].item()\n", - " line.append({\n", - " \"start\": start, \"end\": end,\n", - " \"entity\": id2label[current_id],\n", - " \"word\": text[start:end],\n", - " \"index\": current_index,\n", - " })\n", - "\n", - " current_start = offsets[bi, ti, 0].item()\n", - " previous_has_positive = False\n", - " current_id = 0\n", - " current_index = ti\n", - " # has positive prediction index, other than zero\n", - " else:\n", - " if previous_has_positive:\n", - " # different than the previous\n", - " if current_id != pred_ids[ti].item():\n", - " start = current_start\n", - " end = offsets[bi, ti, 0].item()\n", - " line.append({\n", - " \"start\": start,\n", - " \"end\": end,\n", - " \"entity\": id2label[current_id],\n", - " \"word\": text[start:end],\n", - " \"index\": current_index,\n", - " })\n", - " current_start = offsets[bi, ti, 0].item()\n", - " # this is the 1st postive predict for a while\n", - " else:\n", - " current_start = offsets[bi, ti, 0].item()\n", - " previous_has_positive = True\n", - " current_index = ti\n", - " current_id = pred_ids[ti].item()\n", - "\n", - " last_pos = word_ids[ti]\n", - " if previous_has_positive:\n", - " start = current_start\n", - " end = offsets[bi, ti, 1].item()\n", - " line.append({\n", - " \"start\": start,\n", - " \"end\": end,\n", - " \"entity\": id2label[current_id],\n", - " \"word\": text[start:end],\n", - " \"index\": current_index,\n", - " })\n", - "\n", - " results.append(line)\n", - " all_dfs = []\n", - " for i, res in enumerate(results):\n", - " sub_df = pd.DataFrame(res)\n", - " sub_df[\"batch_row_sn\"] = i\n", - " all_dfs.append(sub_df)\n", - " return pd.concat(all_dfs)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.4" - }, - "toc": { - "base_numbering": 1, - "nav_menu": {}, - "number_sections": true, - "sideBar": true, - "skip_h1_title": false, - "title_cell": "Table of Contents", - "title_sidebar": "Contents", - "toc_cell": false, - "toc_position": {}, - "toc_section_display": true, - "toc_window_display": false - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/nbs/cross_entropy_weighter.ipynb b/nbs/cross_entropy_weighter.ipynb deleted file mode 100644 index 95018ab..0000000 --- a/nbs/cross_entropy_weighter.ipynb +++ /dev/null @@ -1,1298 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Create multi-task adjustment for cross-entropy\n", - "> A weighter for multi-task learning with different softmax+cross-entropy" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Tools and imports" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# default_exp multitask_ce" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "# export\n", - "import torch\n", - "from torch import nn\n", - "import pandas as pd\n", - "import numpy as np\n", - "from typing import Callable" - ] - }, - { - "cell_type": "code", - "execution_count": 120, - "metadata": {}, - "outputs": [], - "source": [ - "from matplotlib import pyplot as plt" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Enters softmax" - ] - }, - { - "cell_type": "code", - "execution_count": 102, - "metadata": {}, - "outputs": [], - "source": [ - "softmax = nn.Softmax(-1)\n", - "crit = nn.CrossEntropyLoss()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Experience the problem" - ] - }, - { - "cell_type": "code", - "execution_count": 104, - "metadata": {}, - "outputs": [], - "source": [ - "def test_loss(model,iters:300,get_xy):\n", - " losses=[]\n", - " with torch.no_grad():\n", - " for i in range(iters):\n", - " x,y_true = get_xy(model.nb_output)\n", - " y_vec = model(x)\n", - " loss = model.crit(y_vec,y_true)\n", - " losses.append(loss)\n", - " return torch.stack(losses).mean()" - ] - }, - { - "cell_type": "code", - "execution_count": 245, - "metadata": {}, - "outputs": [], - "source": [ - "def create_softmax_pipeline(\n", - " nb_layers:int,\n", - " nb_output:int,\n", - " hs:int=500,\n", - " crit:nn.Module=nn.CrossEntropyLoss(),\n", - " )->nn.Module:\n", - " modules = (nb_layers-1)*[nn.Linear(hs,hs),nn.Dropout(.3)]\n", - " modules+=[nn.Linear(hs,nb_output),]\n", - " model = nn.Sequential(*modules)\n", - " model.hs = hs\n", - " model.nb_output = nb_output\n", - " model.__class__.crit = crit\n", - " model.__class__.test_loss = test_loss\n", - " return model" - ] - }, - { - "cell_type": "code", - "execution_count": 246, - "metadata": {}, - "outputs": [], - "source": [ - "def random_input(nb_output):\n", - " return torch.rand(2,500),torch.randint(low=0,high=nb_output,size = (2,))\n", - "\n", - "def inbalanced_input(\n", - " bias:float\n", - " ) -> Callable:\n", - " def inbalanced_input_(nb_output:int):\n", - " return torch.rand(2,500),torch.randint(\n", - " low=0,\n", - " high=max(1,int(nb_output*bias)),\n", - " size = (2,)\n", - " )\n", - " return inbalanced_input_" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "* For models with different branches of output, some output 2 category, some output more, like 200,500\n", - "* Their loss will end up in different scale\n", - "* That makes mixing them up fairly hard and unfair to lesser category tasks" - ] - }, - { - "cell_type": "code", - "execution_count": 247, - "metadata": {}, - "outputs": [], - "source": [ - "import random" - ] - }, - { - "cell_type": "code", - "execution_count": 266, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "fa780ff8fdab4d4cb932c4dbfe96c44f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], - "source": [ - "result = []\n", - "\n", - "for i in tqdm(range(50)):\n", - " c = random.randint(2,3000)\n", - " b = 1-random.random()\n", - " loss = create_softmax_pipeline(1,c).test_loss(300,inbalanced_input(b))\n", - " result.append(dict(classes=c,loss=loss.item(),inbl_bias=b))" - ] - }, - { - "cell_type": "code", - "execution_count": 267, - "metadata": {}, - "outputs": [], - "source": [ - "df = pd.DataFrame(result)" - ] - }, - { - "cell_type": "code", - "execution_count": 268, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
classeslossinbl_bias
022447.7840150.719190
124477.8491580.002960
27066.6254160.297956
32725.7390010.154829
428607.9948820.179631
523767.8230360.554090
64746.2306430.497769
714767.3176500.790722
87296.6523120.671657
911837.1610010.422553
1017927.5710340.375524
1126017.9283470.937953
1219887.6534010.448365
136176.4717110.669720
1412577.1966650.180724
1512087.1422720.603882
1615717.4319090.307233
1718157.5657630.271957
1823707.8402560.130841
1924217.8353620.470347
2026087.9338520.362477
2118337.5664140.551423
2217697.5360470.630381
2319507.6157930.795910
2419107.5857390.343632
2523017.7984650.829628
2623777.8692110.016988
2714967.3931780.119690
2811797.1470580.568997
2914757.3567400.693435
3024097.8359540.455461
315426.3284030.733283
326996.6098920.674927
3329668.0627700.268944
3428318.0060010.822981
3516577.4594730.323996
3614697.3166780.062805
376746.5796550.812632
386546.5500600.601652
391945.3505550.238536
402505.5670150.270642
4123687.8191310.117133
4215737.4296060.003537
4329948.0539260.764925
4421557.7413450.941843
4526427.9191260.612718
464866.2425270.701243
4719487.6443660.970036
4826267.9323630.657842
4916607.4668120.855916
\n", - "
" - ], - "text/plain": [ - " classes loss inbl_bias\n", - "0 2244 7.784015 0.719190\n", - "1 2447 7.849158 0.002960\n", - "2 706 6.625416 0.297956\n", - "3 272 5.739001 0.154829\n", - "4 2860 7.994882 0.179631\n", - "5 2376 7.823036 0.554090\n", - "6 474 6.230643 0.497769\n", - "7 1476 7.317650 0.790722\n", - "8 729 6.652312 0.671657\n", - "9 1183 7.161001 0.422553\n", - "10 1792 7.571034 0.375524\n", - "11 2601 7.928347 0.937953\n", - "12 1988 7.653401 0.448365\n", - "13 617 6.471711 0.669720\n", - "14 1257 7.196665 0.180724\n", - "15 1208 7.142272 0.603882\n", - "16 1571 7.431909 0.307233\n", - "17 1815 7.565763 0.271957\n", - "18 2370 7.840256 0.130841\n", - "19 2421 7.835362 0.470347\n", - "20 2608 7.933852 0.362477\n", - "21 1833 7.566414 0.551423\n", - "22 1769 7.536047 0.630381\n", - "23 1950 7.615793 0.795910\n", - "24 1910 7.585739 0.343632\n", - "25 2301 7.798465 0.829628\n", - "26 2377 7.869211 0.016988\n", - "27 1496 7.393178 0.119690\n", - "28 1179 7.147058 0.568997\n", - "29 1475 7.356740 0.693435\n", - "30 2409 7.835954 0.455461\n", - "31 542 6.328403 0.733283\n", - "32 699 6.609892 0.674927\n", - "33 2966 8.062770 0.268944\n", - "34 2831 8.006001 0.822981\n", - "35 1657 7.459473 0.323996\n", - "36 1469 7.316678 0.062805\n", - "37 674 6.579655 0.812632\n", - "38 654 6.550060 0.601652\n", - "39 194 5.350555 0.238536\n", - "40 250 5.567015 0.270642\n", - "41 2368 7.819131 0.117133\n", - "42 1573 7.429606 0.003537\n", - "43 2994 8.053926 0.764925\n", - "44 2155 7.741345 0.941843\n", - "45 2642 7.919126 0.612718\n", - "46 486 6.242527 0.701243\n", - "47 1948 7.644366 0.970036\n", - "48 2626 7.932363 0.657842\n", - "49 1660 7.466812 0.855916" - ] - }, - "execution_count": 268, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### The pattern" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "There are $n$ tasks $T$ with a set of numbers of classes $C$ where $c_{i}$ is the number of the classes of task $T_{i}$\n", - "\n", - "The number of classes $c_{i}$ has certain correlation of the average loss $L_{i}$\n", - "\n", - "Where $log_{10}(c_{i})$ has a clear linear relationship with $L$\n", - "\n", - "$L_{i} = a.log_{10}(c_{i})$, where a is a fixed constant" - ] - }, - { - "cell_type": "code", - "execution_count": 269, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 269, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAVw0lEQVR4nO3dfZBd913f8ffXqzWs8rQmFiTaSBa0QW0Tjy13a8vNlAl1iMhDE8EorV1MwMOMcGBC3BYNces6ENwSum0hqYcIDxkaiOu6qMrWMMYy0xDIMCMxa0u24jhiRB4krQLIISvX9kLWq2//2Hudq7v34dzVfTz7fs3s6N5zftr9/sbWR0ff8zu/G5mJJGn0XTboAiRJ3WGgS1JJGOiSVBIGuiSVhIEuSSWxYVA/+Morr8xt27YN6sdL0kh67LHHnsnMTY3ODSzQt23bxtzc3KB+vCSNpIj4arNztlwkqSQMdEkqCQNdkkrCQJekkjDQJakkDHRJKomBLVuUpLKbPTrPzKETnF1YZPPkBPt2bWf3jqme/Tyv0CWpB2aPznPnwePMLyySwPzCInc8eIwdH36U2aPzPfmZBrok9cDMoRMsLi2vOv6NF5a48+DxnoS6gS5JPXB2YbHpucWlZWYOnej6z7SHLkldUN8vf9XEOAuLS03Htwr8tSp0hR4R/yoinoqIz0fEAxHx7XXnvy0iHoyIkxFxJCK2db1SSRpSjfrlrcIcYPPkRNfraBvoETEF/AwwnZlvBMaAm+uG/QTwjcz8u8CvAL/c7UIlaVg165c3MzE+xr5d27teR9Ee+gZgIiI2ABuBs3Xn3w18svL6AHBTRER3SpSk4daufTI5Mc7U5AQBTE1O8Es/fHVPli+27aFn5nxE/GfgFLAIPJqZj9YNmwJOV8a/GBHngVcDz9QOioi9wF6ArVu3Xnr1kjQENk9OMN8i1M8vLnHsQ2/teR1FWi5XsHIF/t3AZuBlEXFr/bAGvzVXHci8LzOnM3N606aG+7NL0sjZt2s7E+NjTc/3ol/eSJFVLm8BvpyZ5wAi4iDwj4FP1Yw5A2wBzlTaMq8C/rrLtUrSQNw1e5z7D5966Sr1ZZeP8R9+6Fttk+qvv/C7T/GNFy6+GdqrfnkjRXrop4CdEbGx0he/CXi6bsxDwI9VXu8BPpOZq67QJWnU3DV7nE/VhDnA899c5o4Hj3HX7PGXju3eMcXRu9/Kr/6La/vSL2+kSA/9SEQcAB4HXgSOAvdFxIeBucx8CPgE8NsRcZKVK/P6VTCSNJIeOHK66bn7D59i+qrvuCiwd++Y6luA1yv0YFFmfgj4UN3hu2vO/w3wni7WJUlDYblFsyFZWbI4qACv55OiktatIrshjkW0DPVePPG5Vu7lImldavR0Z6NNs265YUvL79OvFSxFGOiS1qVGT3cuLq3c7HzTRz7zUrDfs/tqbt3Z+LmZfq5gKcJAl7QutXoQqP5q/Z7dV/OVj7xjoCtYirCHLmldatcbr25xOywrWIow0CWtK9Uboa3CvGqYbngWYaBLKr1qiM8vLBI02JekiWG64VmEgS6p1KqrWao3QDt5hH2YbngW4U1RSaXW6V7lVW/6O98x1P3yRrxCl1Qq9Q8LtVrN0sytO7dyz+6re1Bdbxnokkqjvr3SSc98LIJbbtgykkFeZaBLKo1G7ZWEVaFefT/V5HH/UWWgSyqNZssMq+Hdas+WMjDQJZVGs5751OQEf/LBfzqAivrLVS6SSqPRR8EN234rveQVuqTSqLZR2m2JW1YGuqShV7sUcXLjOJlwfnGpYWAP+34rvWSgSxpq9UsRaz+EuborIrBuQ7yWPXRJQ63dk57VXRFloEsackV2PBy1XRF7xUCXNNSK7Hg4arsi9oqBLmlozR6d5/m/fbHlmPW0LLEdb4pKGqhme5VvHL+MpQvJ0nLznVjGIobuY+AGyUCXNDCt9ip/YelCy987MT5mmNcx0CUNzM8/9NSa9iov26Za3WKgSxqI2aPzLCwutR9YZ73sy7IW3hSVNBBrWTvuDdDW2gZ6RGyPiGM1X89GxB11Y94cEedrxtzdu5IllUGRteOXBVyxcZxg5crcnnlrbVsumXkCuBYgIsaAeeDTDYZ+LjPf2d3yJJVVu4+Hu2LjOB/6Z28wwDvQaQ/9JuDPM/OrvShGUvnUf8Zn9Wbmvl3bL1rhAq5cuVSdBvrNwANNzt0YEU8AZ4Gfzcyn6gdExF5gL8DWrVs7/NGSRs3s0Xn2HXjipbXk8wuL7DvwBOBWt70QmUU+PhUi4nJWwvoNmfmXdedeCVzIzOci4u3ARzPz9a2+3/T0dM7Nza2xbEmjYMeHH71od8SqKzaOc/Tutw6gotEXEY9l5nSjc51cob8NeLw+zAEy89ma1w9HxK9FxJWZ+Uzn5UoaRXfNHueBI6dZzmQsgltu2NIwzIGmx3VpOlm2eAtN2i0R8ZqIiMrr6yvf9+uXXp6kUXDX7HE+dfgUy5V/8S9n8qnDpwZc1fpT6Ao9IjYCPwD8ZM2x2wEycz+wB3hfRLwILAI3Z9FejqSR98CR0x2Nn5wY71El61uhQM/MF4BX1x3bX/P6XuDe7pYmaRg1WrWy3OL6bfyyYOlCXvT+59/1hn6Uuu74pKikwqqbac0vLJJ86yPgVhquq41FMPOea5ianHjp4aCZ91zjSpYecS8XSYU12kxrcWmZjeOXNdwd8ZYbtqzrD23uN6/QJRXSajOtxaUL3LpzK2OVS/WxCG7duZV7dl/dzxLXPa/QJRXSajOtzZMT3LP7agN8wLxCl1RIq8203AFxOBjokgqZGG8cF5ePhT3yIWGgSypk8cXGHwlXuyRRg2WgSyqk2VJzHyEcHga6pELGmiw2b3Zc/WegSyrklhu2dHRc/eeyRUmFVJck1u+o6FLF4VF4P/Rucz90Sepcq/3QbblIUknYcpHWgWaf66lyMdClkqvukFjdVKu6QyJgqJeMgS6VSKMr8ZlDJxrukDhz6ISBXjIGulQSd80e5/7Dp6guc6heideHeVWrvVk0mrwpKpXA7NH5i8K8qlmYw8oOiSoXA10qgZlDJ1aFeSsT42PukFhCtlykEijSPhmL4EKmq1xKzECXSmDz5ATzbUL9QiZf/sg7+lSRBsGWi1QC+3ZtZ2J8rOUYe+bl5xW6VALV9snMoRPMLywScFFP3Z75+mCgSyWxe8fUS8Huk6Hrk4EulVBtuGv9sIcuSSVhoEtSSbQN9IjYHhHHar6ejYg76sZERHwsIk5GxJMRcV3vSpYkNdK2h56ZJ4BrASJiDJgHPl037G3A6ytfNwAfr/wqSeqTTlsuNwF/nplfrTv+buC3csVhYDIiXtuVCiVJhXQa6DcDDzQ4PgWcrnl/pnLsIhGxNyLmImLu3LlzHf5oSVIrhQM9Ii4H3gX8TqPTDY6t2isoM+/LzOnMnN60aVPxKiVJbXVyhf424PHM/MsG584AW2revw44eymFSZI600mg30LjdgvAQ8B7K6tddgLnM/Nrl1ydJKmwQk+KRsRG4AeAn6w5djtAZu4HHgbeDpwEXgBu63qlkqSWCgV6Zr4AvLru2P6a1wn8dHdLkyR1widFJakkDHRJKgkDXZJKwkCXpJIw0CWpJAx0SSoJA12SSsJAl6SSMNAlqSQMdEkqCQNdkkrCQJekkjDQJakkDHRJKgkDXZJKwkCXpJIo9AEXUhnNHp1n5tAJzi4ssnlygn27trN7x9Sgy5LWzEDXujR7dJ47Dx5ncWkZgPmFRe48eBzAUNfIMtC1btRekV8WwXLmRecXl5aZOXTCQNfIMtC1Ltw1e5z7D5+iGuH1YV51dmGxf0VJXeZNUZXe7NH5i8K8lc2TEz2vR+oVr9BVSvXtlSJhPjE+xr5d23tem9QrBrpKp/6GZ7P2CsBYBBcyXeWiUjDQVRrVq/L5gn3wAP7LP7/GEFdpGOgqhfqr8nYC+JGdWw1zlYqBrlKYOXSibZjbXlHZGegqhXbLDSfGx/ilH77aEFepFVq2GBGTEXEgIr4YEU9HxI11598cEecj4ljl6+7elCs11mq54dTkhGGudaHoFfpHgUcyc09EXA5sbDDmc5n5zu6VJhW3b9f2VT10r8q13rQN9Ih4JfB9wI8DZOY3gW/2tiypsWYbalVD2822tJ4VuUL/HuAc8JsRcQ3wGPCBzHy+btyNEfEEcBb42cx8qv4bRcReYC/A1q1bL6lwrT/tNtSqDXZpPSrSQ98AXAd8PDN3AM8DH6wb8zhwVWZeA/w3YLbRN8rM+zJzOjOnN23adAllaz1qtJKluqGWpGKBfgY4k5lHKu8PsBLwL8nMZzPzucrrh4HxiLiyq5Vq3Wu2ksUNtaQVbQM9M/8COB0R1U0ubgK+UDsmIl4TEVF5fX3l+369y7VqnWu2ksUNtaQVRXdbfD9wf0Q8CVwL/MeIuD0ibq+c3wN8vtJD/xhwc2aLDTSkNdi3azsT42MXHXNDLelbYlC5Oz09nXNzcwP52Rpdfmyc1ruIeCwzpxud80lRjRRXskjN+QEXklQSBroklYQtFw2EvXCp+wx09V27Jz4lrY0tF/WdT3xKvWGgq+984lPqDQNdfecTn1JvGOjqO5/4lHrDm6LqO/cul3rDQNclqS4/nF9YZCyC5UymCgS0T3xK3Wega83qlx8uV/YFchmiNBj20LVmjZYfVrkMUeo/A11r1m6ZocsQpf4y0LVm7ZYZugxR6i8DXWvWaPlhlcsQpf7zpqjWrHb5YaerXCR1n4GuS+LyQ2l42HKRpJIw0CWpJAx0SSoJA12SSsJAl6SSMNAlqSQMdEkqCQNdkkrCQJekkigU6BExGREHIuKLEfF0RNxYdz4i4mMRcTIinoyI63pTriSpmaKP/n8UeCQz90TE5cDGuvNvA15f+boB+HjlV0lSn7S9Qo+IVwLfB3wCIDO/mZkLdcPeDfxWrjgMTEbEa7terSSpqSItl+8BzgG/GRFHI+I3IuJldWOmgNM1789Ujl0kIvZGxFxEzJ07d27NRUuSVisS6BuA64CPZ+YO4Hngg3VjosHvy1UHMu/LzOnMnN60aVPHxUqSmisS6GeAM5l5pPL+ACsBXz9mS8371wFnL708SVJRbQM9M/8COB0R1Y+fuQn4Qt2wh4D3Vla77ATOZ+bXuluqJKmVoqtc3g/cX1nh8iXgtoi4HSAz9wMPA28HTgIvALf1oFZJUguFAj0zjwHTdYf315xP4Ke7WJckqUM+KSpJJWGgS1JJGOiSVBIGuiSVhIEuSSVhoEtSSRjoklQSBroklYSBLkklYaBLUkkU3ctFBcwenWfm0AnOLiyyeXKCfbu2s3vHqm3hJaknDPQumT06z50Hj7O4tAzA/MIidx48DmCoS+oLWy5dMnPoxEthXrW4tMzMoRMDqkjSemOgd8nZhcWOjktStxnoXbJ5cqKj45LUbQZ6l+zbtZ2J8bGLjk2Mj7Fv1/Ymv0OSusubol1SvfHpKhdJg2Kgd9HuHVMGuKSBseUiSSVhoEtSSRjoklQSBroklYQ3RQtwjxZJo8BAb8M9WiSNClsubbhHi6RRYaC34R4tkkaFgd6Ge7RIGhWFAj0ivhIRxyPiWETMNTj/5og4Xzl/LCLu7n6pg+EeLZJGRSc3Rb8/M59pcf5zmfnOSy1o2LhHi6RR4SqXAtyjRdIoKNpDT+DRiHgsIvY2GXNjRDwREb8fEW9oNCAi9kbEXETMnTt3bk0FS5IaK3qF/qbMPBsR3wn8QUR8MTP/uOb848BVmflcRLwdmAVeX/9NMvM+4D6A6enpvMTaJUk1Cl2hZ+bZyq9/BXwauL7u/LOZ+Vzl9cPAeERc2eVaJUkttA30iHhZRLyi+hp4K/D5ujGviYiovL6+8n2/3v1yJUnNFGm5fBfw6UpebwD+R2Y+EhG3A2TmfmAP8L6IeBFYBG7OTFsqktRHbQM9M78EXNPg+P6a1/cC93a3NElSJ3xSVJJKwkCXpJIY2QeL3KNcki42koHuHuWStNpItlzco1ySVhvJQHePcklabSQD3T3KJWm1kQx09yiXpNVG8qaoe5RL0mojGejgHuWSVG8kWy6SpNUMdEkqCQNdkkrCQJekkjDQJakkDHRJKokY1AcLRcQ54KsD+eGX5krgmUEX0QVlmQeUZy5lmQeUZy7DOI+rMnNToxMDC/RRFRFzmTk96DouVVnmAeWZS1nmAeWZy6jNw5aLJJWEgS5JJWGgd+6+QRfQJWWZB5RnLmWZB5RnLiM1D3voklQSXqFLUkkY6JJUEgZ6AxGxJSL+MCKejoinIuIDLcb+o4hYjog9/ayxiKLziIg3R8Sxypg/6nedRRSZS0S8KiJ+NyKeqIy5bRC1thIR3x4Rf1pT4y80GPNtEfFgRJyMiCMRsa3/lbZXcC7/OiK+EBFPRsT/jYirBlFrK0XmUTN2T0RkRAznUsbM9KvuC3gtcF3l9SuAPwP+QYNxY8BngIeBPYOuey3zACaBLwBbK++/c9B1X8Jc/i3wy5XXm4C/Bi4fdO11NQbw8srrceAIsLNuzE8B+yuvbwYeHHTdlzCX7wc2Vl6/bxjnUmQeNf/f/TFwGJgedN2NvrxCbyAzv5aZj1de/z/gaaDRp2m8H/jfwF/1sbzCCs7jXwIHM/NUZdwozyWBV0REAC9nJdBf7GuhbeSK5ypvxytf9SsT3g18svL6AHBTZU5DpchcMvMPM/OFytvDwOv6WGIhBf+bAPwi8J+Av+lXbZ0y0Nuo/HN3Byt/a9cenwJ+CNjf/6o612wewPcCV0TEZyPisYh4b79r61SLudwL/H3gLHAc+EBmXuhrcQVExFhEHGPlQuAPMrN+HlPAaYDMfBE4D7y6v1UWU2AutX4C+P3+VNaZdvOIiB3Alsz8vYEUWJCB3kJEvJyVK/A7MvPZutO/CvxcZi73v7LOtJnHBuAfAu8AdgH/PiK+t88lFtZmLruAY8Bm4Frg3oh4ZZ9LbCszlzPzWlauVq+PiDfWDWl0NT6U64sLzAWAiLgVmAZm+llfUa3mERGXAb8C/JtB1VeUgd5ERIyzEhz3Z+bBBkOmgf8ZEV8B9gC/FhG7+1hiIQXmcQZ4JDOfz8xnWOkRXtPPGosqMJfbWGkfZWaeBL4M/L1+1tiJzFwAPgv8YN2pM8AWgIjYALyKlfbR0GoxFyLiLcC/A96VmX/b59I60mQerwDeCHy28ud9J/DQMN4YNdAbqPQrPwE8nZn/tdGYzPzuzNyWmdtY6XP+VGbO9rHMtorMA/g/wD+JiA0RsRG4gZX+9FApOJdTwE2V8d8FbAe+1J8Ki4mITRExWXk9AbwF+GLdsIeAH6u83gN8Jit35YZJkblUWhW/zkqYD+X9mXbzyMzzmXllzZ/3w6zMZ24gBbewYdAFDKk3AT8KHK/01WBlBcVWgMwcib45BeaRmU9HxCPAk8AF4Dcy8/MDqba1Iv9NfhH47xFxnJW2xc9V/tUxTF4LfDIixli5oPpfmfl7EfFhYC4zH2LlL67fjoiTrFyZ3zy4clsqMpcZVm5Q/07lvu6pzHzXwCpurMg8RoKP/ktSSdhykaSSMNAlqSQMdEkqCQNdkkrCQJekkjDQJakkDHRJKon/D8XLOB7FzJCYAAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "plt.scatter(np.log10(df.classes),df.loss,)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### The solution\n", - "\n", - "Assume we have certain function, will produce a weight that will bring each cross entropy loss to the same constant scale ${a}$.\n", - "\n", - "$L_{i}.f(c_{i})=a$\n", - "\n", - "$f(c_{i})a.log_{10}(c_{i})=a$\n", - "\n", - "Here we can get how to calculate $f(c_{i})$\n", - "\n", - "$f(c_{i})=\\frac{1}{log_{10}(c_{i})}$" - ] - }, - { - "cell_type": "code", - "execution_count": 270, - "metadata": {}, - "outputs": [], - "source": [ - "def adjust(nb_class):\n", - " return 1/np.log10(nb_class)" - ] - }, - { - "cell_type": "code", - "execution_count": 271, - "metadata": {}, - "outputs": [], - "source": [ - "df[\"lambda_weighted\"] = df.apply(\n", - " lambda row:row['loss']*adjust(row['classes']),\n", - " axis=1)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "For now it's a about the same scale of loss" - ] - }, - { - "cell_type": "code", - "execution_count": 272, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 272, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAa3ElEQVR4nO3df4ycx33f8fcn8lG+iHQpmZdUOpOhTCWinYoQmWuQlrJRq4Ho2EBIWUbotpDkKADbugWkQiZAK0FqVwgslrXgFi3CMlAQKyEcKxbJ0mVTmqCYsHIrxkfyTIo8MZJ/KBJJWGfJNMX66hypb//YOWu13r199m5/PDv3eQGH25uZ3ZvZ59nvPs/MPM8oIjAzs3z9VK8rYGZmneVAb2aWOQd6M7PMOdCbmWXOgd7MLHNv63UFai1ZsiSWL1/e62qYmfWVo0ePfi8ihurllS7QL1++nNHR0V5Xw8ysr0h6sVGeu27MzDLnQG9mljkHejOzzDnQm5llzoHezCxzDvRmZplzoDczy5wDvZlZ5poGeklLJR2SNC7plKT765RZL+mEpDFJo5JuS+kfSGnTP/9P0oZONMTMzOorcmXsZeDBiDgmaRFwVNKBiDhdVeYgsDciQtIq4AlgZUQcAm4FkHQd8ALw1fY2wczMZtL0iD4izkfEsfT4dWAcGK4pcyneXKrqGqDeslUfBf48In44tyqbmVkrWuqjl7QcWA0cqZN3p6TngH3AfXWe/jHgiw1ed1Pq8hmdmJhopUpmZtZE4UAvaSHwJPBARFyszY+I3RGxEtgAPFzz3OuBW4D99V47InZExEhEjAwN1b35mpmZzVKhQC9pgEqQ3xkRu2YqGxGHgRWSllQl/wawOyKmZl1TMzOblSKzbgQ8BoxHxKMNytyUyiFpDbAAeLWqyD+hQbeNmZl1VpFZN2uBu4GTksZS2kPAMoCI2A7cBdwjaQqYBDZOD86mfv2lwF+2teZmZlZI00AfEU8DalJmK7C1Qd53qJmlY2Zm3eMrY83MMudAb2aWOQd6M7PMOdCbmWXOgd7MLHMO9GZmmXOgNzPLnAO9mVnmHOjNzDLnQG9mljkHejOzzDnQm5llzoHezCxzDvRmZplzoDczy5wDvZlZ5hzozcwyV2TN2KWSDkkal3RK0v11yqyXdELSmKRRSbdV5S2T9NX0/NNpaUEzM+uSImvGXgYejIhjkhYBRyUdiIjTVWUOAnsjIiStAp4AVqa8x4Hfi4gDkhYCb7SzAWZmNrOmR/QRcT4ijqXHrwPj1KwBGxGXphcDB64BphcGfy/wtog4UFXuh22sv5mZNdFSH33qdlkNHKmTd6ek54B9wH0p+ReAC5J2STouaZukq+o8d1Pq8hmdmJhotQ1mZjaDwoE+dbs8CTwQERdr8yNid0SsBDYAD6fktwHvAz4J/H3g3cDH6zx3R0SMRMTI0NBQy40wM7PGCgV6SQNUgvzOiNg1U9mIOAyskLQEeBk4HhHfiojLwB5gzRzrbGZmLSgy60bAY8B4RDzaoMxNqRyS1gALgFeBrwPXSpo+TL8dOF3vNczMrDOKzLpZC9wNnJQ0ltIeApYBRMR24C7gHklTwCSwMQ3OXpH0SeBg+iI4CvxBm9tQyJ7jZ9m2/wznLkxyw+JBNq+7mQ2rh5s/0cysz+nNyTLlMDIyEqOjo219zT3Hz/KpXSeZnLry47TBgav47EducbA3syxIOhoRI/Xy5sWVsdv2n3lLkAeYnLrCtv1nelQjM7PumReB/tyFyZbSzcxyMi8C/Q2LB1tKNzPLybwI9JvX3czgwFuv0xocuIrN627uUY3MzLqnyKybvjc94OpZN2Y2H82LQA+VYO/Abmbz0bzoujEzm88c6M3MMudAb2aWOQd6M7PMOdCbmWXOgd7MLHMO9GZmmXOgNzPLnAO9mVnmHOjNzDLnQG9mlrkia8YulXRI0rikU5Lur1NmvaQTksYkjUq6rSrvSkofk7S33Q0wM7OZFbmp2WXgwYg4JmkRcFTSgYioXuT7ILA3IkLSKuAJYGXKm4yIW9tbbTMzK6rpEX1EnI+IY+nx68A4MFxT5lK8ufjsNUC5FqI1M5vHWuqjl7QcWA0cqZN3p6TngH3AfVVZb0/dOc9I2tDgdTelMqMTExOtVMnMzJooHOglLQSeBB6IiIu1+RGxOyJWAhuAh6uylqWVyf8p8HlJK+o8d0dEjETEyNDQUMuNMDOzxgoFekkDVIL8zojYNVPZiDgMrJC0JP19Lv3+FvAXVM4IzMysS4rMuhHwGDAeEY82KHNTKoekNcAC4FVJ10q6OqUvAdYCp+u9hpmZdUaRWTdrgbuBk5LGUtpDwDKAiNgO3AXcI2kKmAQ2phk47wH+q6Q3qHypPFIzW8fMzDqsaaCPiKcBNSmzFdhaJ/1/A7fMunZmZjZnvjLWzCxzDvRmZplzoDczy5wDvZlZ5hzozcwy50BvZpY5B3ozs8w50JuZZc6B3swscw70ZmaZc6A3M8ucA72ZWeYc6M3MMudAb2aWOQd6M7PMOdCbmWXOgd7MLHNF1oxdKumQpHFJpyTdX6fMekknJI1JGpV0W03+OySdlfSf21l5MzNrrsiasZeBByPimKRFwFFJB2rWfj0I7E3rxK4CngBWVuU/DPxl22ptZmaFNT2ij4jzEXEsPX4dGAeGa8pciohIf14DTD9G0i8BPwt8tV2VNjOz4lrqo5e0HFgNHKmTd6ek54B9wH0p7aeAzwGbm7zuptTlMzoxMdFKlczMrInCgV7SQuBJ4IGIuFibHxG7I2IlsIFKVw3AJ4D/EREvzfTaEbEjIkYiYmRoaKh47avsOX6WtY88xY1b9rH2kafYc/zsrF7HzCw3RfrokTRAJcjvjIhdM5WNiMOSVkhaAvwD4H2SPgEsBBZIuhQRW+Za8Wp7jp/lU7tOMjl1BYCzFyb51K6TAGxYPTzTU83Msldk1o2Ax4DxiHi0QZmbUjkkrQEWAK9GxD+LiGURsRz4JPB4u4M8wLb9Z34c5KdNTl1h2/4z7f5XZmZ9p8gR/VrgbuCkpLGU9hCwDCAitgN3AfdImgImgY1Vg7Mdd+7CZEvpZmbzSdNAHxFPA2pSZiuwtUmZPwL+qIW6FXbD4kHO1gnqNywe7MS/MzPrK1lcGbt53c0MDlz1lrTBgavYvO7mHtXIzKw8Cg3Glt30gOu2/Wc4d2GSGxYPsnndzR6INTMjk0APlWDvwG5m9pOy6LoxM7PGHOjNzDLnQG9mljkHejOzzDnQm5llzoHezCxzDvRmZplzoDczy5wDvZlZ5hzozcwy50BvZpY5B3ozs8xlc1MzM+sve46f9R1nu8SB3sy6zus8d1eRNWOXSjokaVzSKUn31ymzXtIJSWOSRiXdltJ/TtLRlH5K0r/oRCOsv+w5fpa1jzzFjVv2sfaRp9hz/Gyvq2Rd5nWeu6vIEf1l4MGIOCZpEXBU0oGIOF1V5iCwNyJC0irgCWAlcB74hxHxI0kLgWcl7Y2Ic+1uiPUHH8kZeJ3nbmt6RB8R5yPiWHr8OjAODNeUuVS1GPg1QKT0v42IH6X0q4v8P8ubj+QMGq/n7HWeO6OlwCtpObAaOFIn705JzwH7gPuq0pdKOgG8BGytdzQvaVPq8hmdmJhorQXWV3wkZ+B1nrut8GBs6np5EnggIi7W5kfEbmC3pPcDDwO/mtJfAlZJugHYI+nLEfHdmufuAHYAjIyMBH3EMwdac8PiQc7WCeq9OJLztusdr/PcXYUCvaQBKkF+Z0TsmqlsRByWtELSkoj4XlX6OUmngPcBX55LpcvC/c2t27zu5re8Z9CbIzlvu97zOs/dU2TWjYDHgPGIeLRBmZtSOSStARYAr0p6l6TBlH4tsBbIpjPW/c2t27B6mM9+5BaGFw8iYHjxIJ/9yC1d/8B729l8UuSIfi1wN3BS0lhKewhYBhAR24G7gHskTQGTwMY0A+c9wOckBSDgP0TEyXY3olfc3zw7ZTiS87az+aRpoI+Ip6kE6ZnKbAW21kk/AKyade1Krkz9zdaaMmw7jxFYt3i64xx45sDslOGCqV5vu+kxgrMXJgneHCPwxWPWCb4Fwhx0YuZA7kd5ZRkE7fWsj5nGCHLa3lYOevM6p3IYGRmJ0dHRXlejJ2qD4LTFgwN8+td/MYsAsPaRp+p2mQwvHuRrW27vQY1648Yt+6j3yRPw7Uc+3O3qWAYkHY2IkXp57ropkXpHeQAXJqf4N18aY3kG94bxIGiFrwy1bnKgL5GZgt300V+/9+X2U4Dr5FhCr8cIbH5xoC+RosGun+d790uA6/RgaVmuJ7D5wYOxJVLvqtFG+rWro9eDoEV1Y7C0DNcT2PzgQF8i0x/6z3zlFN//4dSMZcvY1VFUPwS4Rl+k9QaSzcrOXTdt0M6+3A2rhzn+u3fw+Y23MjxDMD97YbLvB2bLrNEXqaBU73kZrkmw8nOgn6NO9eVuWD3M17bczuc33srAVfUvTO73gdky27zu5rqXgweUZnykny+68hdUdznQz1Gnb461bf8Zpq40vtahnwdmy2zD6uG689yhPOMjZbkxW6tBu5+/oPqV++jnqNPzwou8TlkCT26GS3A/nJl0et8rcpX2bK507uZVwblfaV6Uj+jnqNPzwou8TlkCT7uU5bS+7FNBO7nvFT3qns1ZRbcumiv7mUM39/PsA32n38xOB4N6r9+p/1UG3fhwFt0nujnXfTb7aSf3vaIBfDZBu1sXzX3mK6dK0bVVT7e/hLLuuunGDbQ6PS+89vUX//QAEfCDyaksT0U7fVrf6j7Rjamgs91PO7nvFQ3gs7ndczdWGdtz/GzDKcpl6Ors9k3tsg703XozOx0M+mHeebt0+rS+jHeNnEudOrVvFA3gswna3bhobqaj9jJ0dXb7nk9ZB3rfQKv/dHpBkDLsE7UDhI0uwurlflo0gM82aHf64GWm964MXZ3dXvimaaCXtBR4HPi7wBvAjoj4jzVl1gMPp/zLwAMR8bSkW4HfB94BXAF+LyK+1N4mNFaGVYSsNZ0+re/1PlGvm0ZQdyrnbOvUjpkmrQTwMp5xNtrOiwcHSlHXbnRfVWt6P3pJ1wPXR8QxSYuAo8CGiDhdVWYh8H/TOrGrgCciYqWkXwAiIp6XdEN67nsi4kKj/9fO+9HXu7/74MBVvnlUyXVySlyv94lG9+OvDfazrVOv21dbl15Nbaz3Pky/x8MlGdtq9/sz0/3oi6wZex44nx6/LmkcGAZOV5W5VPWUa0j7bET8dVWZc5JeAYaAhoG+nfrhBlqe5/uTOnmE2Ot9olGXwnQAmmudyjIG0euVxKq3c+1ZUzvrMpfPbzfPhFpaYUrScuAw8Pci4mJN3p3AZ4GfAT4cEf+nJv+XgS8AvxgRb9TkbQI2ASxbtuyXXnzxxZYb0o/KdPRl3dHpFbbKsnJVmVYS61Rdyvb5bcsKU6l75kkq/e8Xa/MjYndErAQ2UOmvr37u9cAfA79ZG+TTc3dExEhEjAwNDRWtUt8ryyXsNjetzIFvNvd9rtd9lGVhlzIMejf7n3OtSz99fgsFekkDVIL8zojYNVPZiDgMrJC0JD33HcA+4Hci4pk51jcrZfow2Oy0euHLTBdhteMimrJczVuWL5yZ/udc69JPn9+mgV6SgMeA8Yh4tEGZm1I5JK0BFgCvSloA7AYej4g/a1+181CmD4PNzmyO6qbvTPrtRz7M17bc/pb+5LkeIZZl5aqyfOF0si799PktMo9+LXA3cFLSWEp7CFgGEBHbgbuAeyRNAZPAxjQD5zeA9wPvlPTx9NyPR8QY1vUpVmXU74PR7Tyqa9drlWG6Y5FB725t+04NwPfT57elwdhuaOf0yn7Q74FuLso2mDUb7RzoK9MAZqf9zp6T7Hzmb94ycFy26Y9FlOnzO6fpldZZZTj66pWyTAWciw+sHOJPnvmbuumt6qcjxLnYc/zsTwR56Mz0x05r1+e3018YDvTWM/00mNXIoecmWkqfSa/n+HfLtv1nGi7qMm1y6gqf+cqp7N8L6M41Bw701jO9vh1BO7T7y2o+nOEVfW++/8OpH9+Bsp+O8lvVjTPb7O9Hb+VVppkZs9VPMy/KYrbvTVnnqM9VN85sHeitZ8oyFXAucviy6rZmi+nMpJ+69YrqxsGCu246qEwj8mXV710VZehX77f9rN579oGVQ3zxyEtcaTILMMczpW4Mwnt6ZYfkMHXQyi+n/azRfXqm9Wu7imjHl7WnV/ZADlMHrfxy2s9mWoSlXXPry3r20+kzWwf6Dslh6qCVX077WaMujHYdxff61sm95MHYDvFsDOuGnPazTg/Od/puk3O982gn+Yi+Q+bLVY7WW7ntZ53swujk2U/ZzxZ8RN8hOUwdtPLzflZcJ89+yn5veh/Rd1C/Tx208mk0mOj9rLlOnv2UfazEgd6sT5S9e6DsOnnNQ9lv5+FAb9YncppK2SudOvsp+1iJA71Znyh798B8VoYrpGfiQG/WJ8rePTDflXmspMiasUslHZI0LumUpPvrlFkv6YSkMUmjkm6ryvufki5I+u/trrxZq8o817kZ30DNZqvIEf1l4MGIOCZpEXBU0oGIOF1V5iCwN60Tuwp4AliZ8rYBPw3883ZW3KxV/T6YWfbuASuvpoE+Is4D59Pj1yWNA8PA6aoyl6qecg1vrgpGRByU9I/aVWGz2cphMLPM3QNWXi1dMCVpObAaOFIn705JzwH7gPvaUTmzdmo0aNnoRlpmuSgc6CUtBJ4EHoiIi7X5EbE7IlYCG4CHW6mEpE2pb390YqL1tTbNimg0aCnoq756s1YVCvSSBqgE+Z0RsWumshFxGFghaUnRSkTEjogYiYiRoaGhok8za8nmdTejOukBpblU3awTisy6EfAYMB4RjzYoc1Mqh6Q1wALg1XZW1GyuNqwebriwheeiW86KzLpZC9wNnJQ0ltIeApYBRMR24C7gHklTwCSwMdLSVZL+F5UZOAslvQz8VkTsb28zzIoZ9lx0m4eKzLp5Guqe8VaX2QpsbZD3vtlVzaz9yn6pulkn+MpYm1c8F93mIwd6m3c8F93mGy88YmaWOQd6M7PMOdCbmWXOgd7MLHMO9GZmmXOgNzPLnAO9mVnmHOjNzDLnQG9mljkHejOzzDnQm5llzoHezCxzDvRmZplzoDczy5wDvZlZ5oqsGbtU0iFJ45JOSbq/Tpn1kk5IGpM0Kum2qrx7JT2ffu5tdwPMzGxmRRYeuQw8GBHHJC0Cjko6EBGnq8ocBPZGREhaBTwBrJR0HfBvgREg0nP3RsT329wOMzNroOkRfUScj4hj6fHrwDgwXFPm0vRi4MA1VII6wDrgQES8loL7AeCD7aq8mZk111IfvaTlwGrgSJ28OyU9B+wD7kvJw8BLVcVepuZLwszMOqtwoJe0EHgSeCAiLtbmR8TuiFgJbAAenn5anZeK2gRJm1Lf/ujExETRKpmZWQGFAr2kASpBfmdE7JqpbEQcBlZIWkLlCH5pVfa7gHN1nrMjIkYiYmRoaKhw5c3MrLkis24EPAaMR8SjDcrclMohaQ2wAHgV2A/cIelaSdcCd6Q0MzPrkiKzbtYCdwMnJY2ltIeAZQARsR24C7hH0hQwCWxMg7OvSXoY+Hp63r+LiNfa2QAzM5uZ3pwsUw4jIyMxOjra62qYmfUVSUcjYqRenq+MNTPLnAO9mVnmHOjNzDLnQG9mljkHejOzzDnQm5llzoHezCxzDvRmZplzoDczy5wDvZlZ5hzozcwy50BvZpY5B3ozs8w50JuZZc6B3swscw70ZmaZc6A3M8tc6VaYkjQBvNjreszREuB7va5EB+TaLsi3bW5X/5lt234uIobqZZQu0OdA0mijJb36Wa7tgnzb5nb1n060zV03ZmaZc6A3M8ucA31n7Oh1BTok13ZBvm1zu/pP29vmPnozs8z5iN7MLHMO9GZmmXOgnwVJ35F0UtKYpNGUdp2kA5KeT7+vTemS9J8kvSDphKQ1va39W0n6Q0mvSHq2Kq3ltki6N5V/XtK9vWhLtQbt+rSks2m7jUn6UFXep1K7zkhaV5X+wZT2gqQt3W5HLUlLJR2SNC7plKT7U3pfb7MZ2pXDNnu7pL+S9I3Uts+k9BslHUnv/5ckLUjpV6e/X0j5y6teq26bm4oI/7T4A3wHWFKT9u+BLenxFmBrevwh4M8BAb8CHOl1/Wvq/X5gDfDsbNsCXAd8K/2+Nj2+toTt+jTwyTpl3wt8A7gauBH4JnBV+vkm8G5gQSrz3h6363pgTXq8CPjrVP++3mYztCuHbSZgYXo8ABxJ2+IJ4GMpfTvwL9PjTwDb0+OPAV+aqc1F6uAj+vZZD3whPf4CsKEq/fGoeAZYLOn6XlSwnog4DLxWk9xqW9YBByLitYj4PnAA+GDna99Yg3Y1sh7404j4UUR8G3gB+OX080JEfCsi/hb401S2ZyLifEQcS49fB8aBYfp8m83Qrkb6aZtFRFxKfw6knwBuB76c0mu32fS2/DLwjyWJxm1uyoF+dgL4qqSjkjaltJ+NiPNQ2WmBn0npw8BLVc99mZl34DJotS391MZ/nbow/nC6e4M+bVc6pV9N5Qgxm21W0y7IYJtJukrSGPAKlS/VbwIXIuJyKlJdzx+3IeX/AHgnc2ibA/3srI2INcCvAf9K0vtnKKs6af06p7VRW/qljb8PrABuBc4Dn0vpfdcuSQuBJ4EHIuLiTEXrpJW2bXXalcU2i4grEXEr8C4qR+HvqVcs/W572xzoZyEizqXfrwC7qWy47053yaTfr6TiLwNLq57+LuBc92o7K622pS/aGBHfTR+4N4A/4M3T3r5ql6QBKsFwZ0TsSsl9v83qtSuXbTYtIi4Af0Glj36xpLelrOp6/rgNKf/vUOmGnHXbHOhbJOkaSYumHwN3AM8Ce4HpmQv3Av8tPd4L3JNmP/wK8IPpU+wSa7Ut+4E7JF2bTq3vSGmlUjM2cieV7QaVdn0szXa4Efh54K+ArwM/n2ZHLKAyMLa3m3WulfpqHwPGI+LRqqy+3maN2pXJNhuStDg9HgR+lcoYxCHgo6lY7Tab3pYfBZ6KymhsozY318vR6H78oTKa/430cwr47ZT+TuAg8Hz6fV28OeL+X6j0yZ0ERnrdhpr2fJHKKfEUlSOG35pNW4D7qAwOvQD8Zknb9cep3ifSh+b6qvK/ndp1Bvi1qvQPUZkB8s3pbd3jdt1G5XT9BDCWfj7U79tshnblsM1WAcdTG54Ffjelv5tKoH4B+DPg6pT+9vT3Cyn/3c3a3OzHt0AwM8ucu27MzDLnQG9mljkHejOzzDnQm5llzoHezCxzDvRmZplzoDczy9z/B4iZzGiJC+rEAAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "plt.scatter(df.classes,df.lambda_weighted)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Adjusted CrossEntropy" - ] - }, - { - "cell_type": "code", - "execution_count": 288, - "metadata": {}, - "outputs": [], - "source": [ - "# export\n", - "class MultiTaskCELoss(nn.Module):\n", - " \"\"\"\n", - " A cross entropy loss function which will cancel out\n", - " the effect of different class numbers\n", - " \"\"\"\n", - " def __init__(self,):\n", - " super().__init__()\n", - " self.celoss = nn.CrossEntropyLoss()\n", - " \n", - " def forward(\n", - " self,\n", - " y_pred: torch.FloatTensor,\n", - " y_true: torch.LongTensor,\n", - " )-> torch.FloatTensor:\n", - " \"\"\"\n", - " Input:\n", - " - y_pred: torch.FloatTensor, Prediction tensor\n", - " - y_true: torch.LongTensor, Label indices\n", - " Return:\n", - " - loss: torch.FloatTensor, scala adjusted\n", - " \"\"\"\n", - " nb_classes = y_pred.size(-1)\n", - " lambda_ = 1/np.log10(nb_classes)\n", - " loss = self.celoss(y_pred,y_true)\n", - " return loss*lambda_" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's make this adjustment into the loss function, an upgraded version of CrossEntropy" - ] - }, - { - "cell_type": "code", - "execution_count": 289, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e869eca73d324737b011b32aebdeab43", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], - "source": [ - "result = []\n", - "\n", - "for i in tqdm(range(50)):\n", - " c = random.randint(2,3000)\n", - " b = 1-random.random()\n", - " # here we change the loss function to MultiTaskCELoss\n", - " loss = create_softmax_pipeline(1,c,crit=MultiTaskCELoss())\\\n", - " .test_loss(300,inbalanced_input(b))\n", - " result.append(dict(classes=c,loss=loss.item(),inbl_bias=b))" - ] - }, - { - "cell_type": "code", - "execution_count": 283, - "metadata": {}, - "outputs": [], - "source": [ - "df_adjusted = pd.DataFrame(result)" - ] - }, - { - "cell_type": "code", - "execution_count": 284, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
classeslossinbl_bias
05332.3192480.208584
122642.3263460.022896
22542.3207350.922445
310272.3282910.643444
424282.3261360.883795
57722.3170710.062716
621662.3157540.364122
78122.3219720.852977
826322.3217080.639467
925762.3132690.968200
1013262.3215150.508901
111692.3215260.870789
128152.3091630.262275
137922.3217710.608608
147392.3162150.198933
1524602.3185230.880343
1624202.3205170.781538
1716882.3166570.954598
1829292.3230090.750869
19782.1927230.089868
2016702.3256080.188055
2128682.3245250.058037
229052.3192830.126673
2316752.3095650.307769
2414302.3270990.451335
254242.3158440.512538
2615112.3163250.613963
279592.3303170.987407
281472.3254590.931882
292782.3172730.401440
308852.3253970.087088
3114592.3079710.857782
321822.3447160.443465
3324682.3268250.572978
349542.3227500.815261
3512972.3199340.887149
3617872.3261610.113759
374622.3252760.969307
386862.3242620.305557
394282.3239130.342770
4029002.3193490.609161
4127652.3233730.296973
427522.3201890.302571
435022.3258970.419716
4421622.3204490.202672
4524602.3179620.430931
4625372.3121160.693554
478062.3187460.539312
482512.3228450.181191
4928452.3260070.667570
\n", - "
" - ], - "text/plain": [ - " classes loss inbl_bias\n", - "0 533 2.319248 0.208584\n", - "1 2264 2.326346 0.022896\n", - "2 254 2.320735 0.922445\n", - "3 1027 2.328291 0.643444\n", - "4 2428 2.326136 0.883795\n", - "5 772 2.317071 0.062716\n", - "6 2166 2.315754 0.364122\n", - "7 812 2.321972 0.852977\n", - "8 2632 2.321708 0.639467\n", - "9 2576 2.313269 0.968200\n", - "10 1326 2.321515 0.508901\n", - "11 169 2.321526 0.870789\n", - "12 815 2.309163 0.262275\n", - "13 792 2.321771 0.608608\n", - "14 739 2.316215 0.198933\n", - "15 2460 2.318523 0.880343\n", - "16 2420 2.320517 0.781538\n", - "17 1688 2.316657 0.954598\n", - "18 2929 2.323009 0.750869\n", - "19 78 2.192723 0.089868\n", - "20 1670 2.325608 0.188055\n", - "21 2868 2.324525 0.058037\n", - "22 905 2.319283 0.126673\n", - "23 1675 2.309565 0.307769\n", - "24 1430 2.327099 0.451335\n", - "25 424 2.315844 0.512538\n", - "26 1511 2.316325 0.613963\n", - "27 959 2.330317 0.987407\n", - "28 147 2.325459 0.931882\n", - "29 278 2.317273 0.401440\n", - "30 885 2.325397 0.087088\n", - "31 1459 2.307971 0.857782\n", - "32 182 2.344716 0.443465\n", - "33 2468 2.326825 0.572978\n", - "34 954 2.322750 0.815261\n", - "35 1297 2.319934 0.887149\n", - "36 1787 2.326161 0.113759\n", - "37 462 2.325276 0.969307\n", - "38 686 2.324262 0.305557\n", - "39 428 2.323913 0.342770\n", - "40 2900 2.319349 0.609161\n", - "41 2765 2.323373 0.296973\n", - "42 752 2.320189 0.302571\n", - "43 502 2.325897 0.419716\n", - "44 2162 2.320449 0.202672\n", - "45 2460 2.317962 0.430931\n", - "46 2537 2.312116 0.693554\n", - "47 806 2.318746 0.539312\n", - "48 251 2.322845 0.181191\n", - "49 2845 2.326007 0.667570" - ] - }, - "execution_count": 284, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df_adjusted" - ] - }, - { - "cell_type": "code", - "execution_count": 287, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 287, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAD4CAYAAAAD6PrjAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAclklEQVR4nO3df5Ac5X3n8fcHaQULwpGAtSP044QxJcAVkLg9TJVcpsAXBFzFErHvjM8lSExKvhyuQy5ZhSBXmDuSChgbcqmK0cngBCeKAYP4UcVxigqUomyC4tUPJMRaIAyOkXSwgGTBobO10vf+mGea0Wh+9MzO7sysPq+qqe15+ume59nu6W/38zw9rYjAzMwM4Lh2F8DMzDqHg4KZmWUcFMzMLOOgYGZmGQcFMzPLTGx3ARpx2mmnxezZs9tdDDOzrrJx48a3I6IvT96uCgqzZ89mYGCg3cUwM+sqkn6RN6+bj8zMLOOgYGZmGQcFMzPLOCiYmVnGQcHMzDIOCmZmlnFQMDOzjIOCmZllHBTMzCzjoGBmZhkHBTMzyzgomJlZxkHBzMwyDgpmZpapGxQkzZS0XtKgpO2SbqiQZ6GkrZK2SBqQ9OmSeYdS+hZJT5SknyFpg6RXJD0oaVLrqmVmZs3Ic6UwDCyLiHOAi4DrJZ1bludp4PyImAt8Bbi3ZN6BiJibXp8rSb8DuDsizgL2Atc1XQszM2uJukEhIvZExKY0/R4wCEwvy/N+RER6exIQ1CBJwKXAwynpfmBRY0U3M7NWa6hPQdJsYB6wocK8qyT9DHiSwtVC0QmpSel5ScUD/6nAvogYTu/foCzQmJnZ2MsdFCRNBh4BlkbE/vL5EfFoRJxN4Yz/tpJZsyKiH/iPwF9IOhNQhY+oeHUhaUkKKgNDQ0N5i2tmZk3IFRQk9VAICKsjYk2tvBHxLHCmpNPS+93p78+Bf6RwpfE2MEVS8RnRM4DdVda3KiL6I6K/ry/Xc6fNzKxJeUYfCbgPGIyIu6rk+UTKh6QLgEnAO5KmSjo+pZ8GzAdeSv0P64EvpFVcCzw+0sqM1GObdzH/9mc4Y8WTzL/9GR7bvKvdRTIzG1MT62dhPrAY2CZpS0q7GZgFEBErgc8D10g6CBwAvhgRIekc4H9KOkwhAN0eES+lddwIPCDpT4HNFAJP2zy2eRc3rdnGgYOHANi17wA3rdkGwKJ57u4ws2ODPhw01Pn6+/tjYGBgVNY9//Zn2LXvwFHp06f08pMVl47KZ5qZjQVJG1Pfbl2+oznZXSEg1Eo3MxuPHBSS06f0NpRuZjYeOSgkyxfMobdnwhFpvT0TWL5gTptKZGY29vJ0NB8Tip3Jd67dwe59Bzh9Si/LF8xxJ7OZHVMcFEosmjfdQcDMjmluPjIzs4yDgpmZZRwUzMws46BgZmYZBwUzM8s4KJiZWcZBwczMMg4KZmaW8c1rNioe27zLd4ebdSEHBWs5P5vCrHu5+cha7s61O7KAUHTg4CHuXLujTSUys7x8pWAtV+/ZFG5aMutcDgrWcqdP6a34FLvTp/R2ZdOSg1h38nZrjoNCB+r2nXn5gjlHHPjhw2dT1Gpa6sQ6jlUQa2abd/t+UqoVdSldx2/19vB/fzPMwUOFxw136slHJ27DukFB0kzgB8BvA4eBVRHxP8ryLARuS/OHgaUR8WNJc4F7gI8Ah4A/i4gH0zJ/A1wM/Cqt5g8iYksrKtWITtgo421nrvVsiq8/WHkTd+pjT8ciiDUTeLrxiquaVtSlfB37Dhw8Kk+nnXx06jbMc6UwDCyLiE2STgY2SloXES+V5HkaeCIiQtJ5wEPA2cAHwDUR8Yqk09OyayNiX1pueUQ83ML6NKSdG6V4UN217wACIqWPxs7c6sCX5/9W7dkUtZqWOtFYPLu7mcDTbVdctbSiLpXWUUknnXxUq/eyh14A2hcY6o4+iog9EbEpTb8HDALTy/K8HxHF49pJpGNcRLwcEa+k6d3AW0Bf64o/Mu0aJVM8qBYPjlEnPzS/M5d+VvDhAfyxzbuaWh+M7P/WbY89HYtndzcTeMYiWI2VVtQlb96xPPl4bPMu5t/+DGeseJL5tz9z1HeuWpkPRRzxHa23nlZraEiqpNnAPGBDhXlXSfoZ8CTwlQrzLwQmAa+WJP+ZpK2S7pZ0fJXPXCJpQNLA0NBQI8Wtq9GdsVUbJ+9ZTalmd+aRHMCr1TfP/63asovmTefPf/93mD6lFwHTp/Ty57//Ox17djsWQayZwDMWwWqstKIuefKO5clHnpOxWmUufkdH46SuntxBQdJk4BEK/QX7y+dHxKMRcTawiEL/Qumy04C/Bf4wIg6n5JsoNDH9G+AU4MZKnxsRqyKiPyL6+/pae5FRbaMcJx31T2/lxmn0bG4kO3OzgW/2iif5+oNbKta33pe43v9q0bzp/GTFpbx2+7/jJysubUsfTt7gPhZBrJnA021XXLW0oi6V1tFznJh6Ys+It1szJ4N5TsYqlbnU7n0H2tKakWv0kaQeCgFhdUSsqZU3Ip6VdKak0yLibUkfoXD18F8j4vmSfHvS5K8l/TXwjeaq0LxKo2Tgw8s3OLLTtFVtuNXa1Yt6jhOTT5jIvg8OjrgPoF4bfml/w5QTe3j//w1z8HChQau8WatY31qji6Cz27ub6Uca7Wd31+qYb+UyYy1vX1ar6nJCz3HZdp3S28Otn/vkiP8fzfY75jkZKy6/7KEXOBRHNyKfPqW3Lc2EeUYfCbgPGIyIu6rk+QTwaupovoBCM9E7kiYBjwI/iIgflS0zLSL2pPUvAl4cYV0aVmujlB/Eqh3Eax3cq6l0UC12Nk+v8IUonqk084WpdQAv3+H3fnB0J3e53fsOHPEl3rXvABOkI85eOrG9u7Rjv1wnBKxmAs9oB6uRaPRgOpK6lH8WwK+HD9dYIr9mT3DyDqgorqPS8aD43aoWMEZLniuF+cBiYJuk4njCm4FZABGxEvg8cI2kg8AB4IspQPwH4DPAqZL+IC1bHHq6WlIfhfpvAf5Ti+rUkLzDJKttnAlSU58J+c6MRjpCqtZnzb/9mab7NirtzMWyTTmxp2KAaVd7d6WDRrlu7KDtZGN5tVhrFM/XH9wyoquoZk9w6l1Nlyo/ySodjVjpmDPazYR1g0JE/JjCgbtWnjuAOyqk/x3wd1WWuTRnGUddnqheaePUSq8n75lRK75c1T6r0QNhz3E6YmesVrbjJx5Hb8+EXF+IsZCnY78bO2g72VheLdYaxQOFk5XlP2pumGezQ6gbbRIrfkfn3/5Mxc+bIHE4YkyaCf2DeOTr6JpeYycYzWFio/nlavRAOPmEiUfsjNXK8KsDBztqhFG9/1UnddCO9fDD0TKWo6PyrPPg4eDWJ7Y3vO6RdII3M6Ci2r56OGLMBmb4Zy7IF9WrdUpD4006jdxMNpo3e11ydh+rn/+XXPdJAOwraxKqVbZOau+u1bFfqQ+nUa26ObBT73BtxvIFc1j+8AvZnfkAPRM0KsG31nezVKUbQ+sZ6w79Tri500EhqXcQK2/3K5e3SafRL34jbZONeGzzLh7ZuCt3QICjd8zRKlurVStnK65eWnkg7+RRW00p37maa2mtq/zA3eqPGcsTnE74Trn5qAHFy8FqHSx5mnQaHXc8WuPkG72BrtKO2S03oo1mOVs5jrwTR2016861O7KhzUUHD8eoja8vbaqZemJPxTzV0jtJJ3ynfKXQhJFc4jXzxR+NM5U8B5o8nVud1ExUy2iVs5UH8k5oOmiVdga4b/7eJys2XX3z9z5Zc7lO+HFMaP93ykGhCSO5xOuUL369G+ha1bwy3rVye3ZC00GrtHM/z9sPUOvGzW7uzxkpNx81YSSXeJ3y8wSVylFsFuvUZqBO1Mrt2QlNB63S7v283sif8p9i2fvBwaOau47VR8j6SqFJzV7idcrPE3RKObpdq/+P7W46aJVO37+68ae2x4qiyZuv2qG/vz8GBgbaXQyzUdEpbdrHgjNWPJlrlNL0Kb38ZEXH3GfbNEkbI6I/T15fKVhVPkiNnfF0j0I3qNenBt3bnzNS7lOwitrxO+7dYLTuOG7XA5+OVRV/anuCmNI78p/a7na+UrCKxt2NVC0wmmfz4+kehW7Q6X0e7eSgYBX5IHW00QyUnTJU+VgyXjr1W83NR1bReHrcY6uMZqBs9xBOsyIHBavIB6mjjWagHE/3KFh3c/ORVeQ216ON9h3Hbs6wTuCgYFX5IHUkB0o7FjgomDXAgdLGu7p9CpJmSlovaVDSdkk3VMizUNJWSVskDUj6dMm8ayW9kl7XlqT/a0nbJO2U9JdSEw87NjOzlsrT0TwMLIuIc4CLgOslnVuW52ng/IiYC3wFuBdA0inAN4FPARcC35Q0NS1zD7AEOCu9Lh9hXczMbITqBoWI2BMRm9L0e8AgML0sz/vx4Y8oncSHz1haAKyLiHcjYi+wDrhc0jTgIxHxT2m5HwCLWlIjMzNrWkNDUiXNBuYBGyrMu0rSz4AnKVwtQCF4/LIk2xspbXqaLk83M7M2yh0UJE0GHgGWRsT+8vkR8WhEnE3hjP+24mIVVhU10it97pLUTzEwNDSUt7hmZtaEXEFBUg+FgLA6ItbUyhsRzwJnSjqNwhXAzJLZM4DdKX1GhfRK61sVEf0R0d/X15enuGZm1qQ8o48E3AcMRsRdVfJ8ojh6SNIFwCTgHWAtcJmkqamD+TJgbUTsAd6TdFFa7hrg8ZbUyMzMmpbnPoX5wGJgm6QtKe1mYBZARKwEPg9cI+kgcAD4YupAflfSbcBP03L/PSLeTdN/DPwN0As8lV5mZtZGfvKamdk418iT1/yDeGZmlnFQMDOzjIOCmZllHBTMzCzjoGBmZhkHBTMzyzgomJlZxkHBzMwyDgpmZpZxUDAzs4yDgpmZZRwUzMws46BgZmYZBwUzM8s4KJiZWcZBwczMMg4KZmaWcVAwM7OMg4KZmWXqBgVJMyWtlzQoabukGyrk+bKkren1nKTzU/ocSVtKXvslLU3zbpW0q2Tela2vnpmZNWJijjzDwLKI2CTpZGCjpHUR8VJJnteAiyNir6QrgFXApyJiBzAXQNIEYBfwaMlyd0fEt1tSEzMzG7G6QSEi9gB70vR7kgaB6cBLJXmeK1nkeWBGhVV9Fng1In4xohKbmdmoaahPQdJsYB6woUa264CnKqRfDfywLO1rqcnp+5KmVvnMJZIGJA0MDQ01UlwzM2tQ7qAgaTLwCLA0IvZXyXMJhaBwY1n6JOBzwI9Kku8BzqTQvLQH+E6ldUbEqojoj4j+vr6+vMU1M7Mm5AoKknooBITVEbGmSp7zgHuBhRHxTtnsK4BNEfFmMSEi3oyIQxFxGPgecGEzFTAzs9bJM/pIwH3AYETcVSXPLGANsDgiXq6Q5UuUNR1Jmlby9irgxbyFNjOz0ZFn9NF8YDGwTdKWlHYzMAsgIlYCtwCnAt8txBCGI6IfQNKJwO8CXy1b77ckzQUCeL3CfDMzG2N5Rh/9GFCdPH8E/FGVeR9QCBjl6YtzltHMzMaI72g2M7OMg4KZmWUcFMzMLOOgYGZmGQcFMzPLOCiYmVnGQcHMzDIOCmZmlnFQMDOzjIOCmZllHBTMzCzjoGBmZhkHBTMzyzgomJlZxkHBzMwyDgpmZpZxUDAzs4yDgpmZZRwUzMwsUzcoSJopab2kQUnbJd1QIc+XJW1Nr+cknV8y73VJ2yRtkTRQkn6KpHWSXkl/p7auWmZm1ow8VwrDwLKIOAe4CLhe0rlleV4DLo6I84DbgFVl8y+JiLkR0V+StgJ4OiLOAp5O783MrI3qBoWI2BMRm9L0e8AgML0sz3MRsTe9fR6YkeOzFwL3p+n7gUV5C21mZqOjoT4FSbOBecCGGtmuA54qeR/AP0jaKGlJSfrHImIPFAIP8NEqn7lE0oCkgaGhoUaKa2ZmDZqYN6OkycAjwNKI2F8lzyUUgsKnS5LnR8RuSR8F1kn6WUQ8m/dzI2IVqTmqv78/8i5nZmaNy3WlIKmHQkBYHRFrquQ5D7gXWBgR7xTTI2J3+vsW8ChwYZr1pqRpadlpwFvNVsLMzFojz+gjAfcBgxFxV5U8s4A1wOKIeLkk/SRJJxengcuAF9PsJ4Br0/S1wOPNVsLMzFojT/PRfGAxsE3SlpR2MzALICJWArcApwLfLcQQhtNIo48Bj6a0icDfR8T/Tuu4HXhI0nXAvwD/viU1MjOzpimie5rp+/v7Y2BgoH5GMzPLSNpYdktAVb6j2czMMg4KZmaWcVAwM7OMg4KZmWUcFMzMLOOgYGZmGQcFMzPLOCiYmVnGQcHMzDIOCmZmlnFQMDOzjIOCmZllHBTMzCzjoGBmZhkHBTMzyzgomJlZxkHBzMwyDgpmZpapGxQkzZS0XtKgpO2SbqiQ58uStqbXc5LOr7espFsl7ZK0Jb2ubG3VzMysURNz5BkGlkXEJkknAxslrYuIl0ryvAZcHBF7JV0BrAI+lWPZuyPi2y2sj5mZjUDdK4WI2BMRm9L0e8AgML0sz3MRsTe9fR6YkXdZMzPrHA31KUiaDcwDNtTIdh3wVM5lv5aanL4vaWqVz1wiaUDSwNDQUCPFNTOzBuUOCpImA48ASyNif5U8l1AICjfmWPYe4ExgLrAH+E6ldUbEqojoj4j+vr6+vMU1M7Mm5AoKknooHNRXR8SaKnnOA+4FFkbEO/WWjYg3I+JQRBwGvgdc2Hw1zMysFfKMPhJwHzAYEXdVyTMLWAMsjoiX8ywraVrJ26uAFxsvvpmZtVKe0UfzgcXANklbUtrNwCyAiFgJ3AKcCny3EAcYjoj+astGxP8CviVpLhDA68BXW1IjMzNrmiKi3WXIrb+/PwYGBtpdDDOzriJpYzpRr8t3NJuZWcZBwczMMg4KZmaWcVAwM7OMg4KZmWUcFMzMLOOgYGZmGQcFMzPLOCiYmVnGQcHMzDIOCmZmlnFQMDOzjIOCmZllHBTMzCzjoGBmZhkHBTMzyzgomJlZxkHBzMwyDgpmZpapGxQkzZS0XtKgpO2SbqiQ58uStqbXc5LOL5l3uaQdknZKWlGSfoakDZJekfSgpEmtq5aZmTUjz5XCMLAsIs4BLgKul3RuWZ7XgIsj4jzgNmAVgKQJwF8BVwDnAl8qWfYO4O6IOAvYC1w30sqYmdnI1A0KEbEnIjal6feAQWB6WZ7nImJvevs8MCNNXwjsjIifR8RvgAeAhZIEXAo8nPLdDywaaWXMzGxkGupTkDQbmAdsqJHtOuCpND0d+GXJvDdS2qnAvogYLkuv9JlLJA1IGhgaGmqkuGZm1qDcQUHSZOARYGlE7K+S5xIKQeHGYlKFbFEj/ejEiFUR0R8R/X19fXmLa2ZmTcgVFCT1UAgIqyNiTZU85wH3Agsj4p2U/AYwsyTbDGA38DYwRdLEsnQzM2ujPKOPBNwHDEbEXVXyzALWAIsj4uWSWT8FzkojjSYBVwNPREQA64EvpHzXAo83Xw0zM2uFifWzMB9YDGyTtCWl3QzMAoiIlcAtFPoJvluIIQynJp9hSV8D1gITgO9HxPa0jhuBByT9KbCZQuAxM7M2UuGkvTv09/fHwMBAQ8s8tnkXd67dwe59Bzh9Si/LF8xh0byKfdpmZuOSpI0R0Z8nb54rha712OZd3LRmGwcOHgJg174D3LRmG4ADg5lZBeP6Zy7uXLsjCwhFBw4e4s61O9pUIjOzzjaug8LufQcaSjczO9aN66Bw+pTehtLNzI514zooLF8wh96eCUek9fZMYPmCOW0qkZlZZxvXHc3FzmSPPjIzy2dcBwUoBAYHATOzfMZ185GZmTXGQcHMzDIOCmZmlnFQMDOzjIOCmZllHBTMzCzjoGBmZhkHBTMzyzgomJlZxkHBzMwyDgpmZpapGxQkzZS0XtKgpO2SbqiQ52xJ/yTp15K+UZI+R9KWktd+SUvTvFsl7SqZd2Vrq2ZmZo3K84N4w8CyiNgk6WRgo6R1EfFSSZ53gf8CLCpdMCJ2AHMBJE0AdgGPlmS5OyK+PZIKmJlZ69S9UoiIPRGxKU2/BwwC08vyvBURPwUO1ljVZ4FXI+IXIyivmZmNoob6FCTNBuYBG5r4rKuBH5alfU3SVknflzS1ymcukTQgaWBoaKiJjzUzs7xyBwVJk4FHgKURsb+RD5E0Cfgc8KOS5HuAMyk0L+0BvlNp2YhYFRH9EdHf19fXyMeamVmDcgUFST0UAsLqiFjTxOdcAWyKiDeLCRHxZkQciojDwPeAC5tYr5mZtVCe0UcC7gMGI+KuJj/nS5Q1HUmaVvL2KuDFJtdtZmYtkmf00XxgMbBN0paUdjMwCyAiVkr6bWAA+AhwOA07PTci9ks6Efhd4Ktl6/2WpLlAAK9XmG9mZmOsblCIiB8DqpPn/wAzqsz7ADi1QvrinGU0M7MxoohodxlykzQEVBvSehrw9hgWZyyMxzrB+KyX69Q9xmO96tXpX0VErpE6XRUUapE0EBH97S5HK43HOsH4rJfr1D3GY71aWSf/9pGZmWUcFMzMLDOegsKqdhdgFIzHOsH4rJfr1D3GY71aVqdx06dgZmYjN56uFMzMbIQcFMzMLNP1QUHS5ZJ2SNopaUW7y9MoSa9L2pYeNDSQ0k6RtE7SK+nv1JQuSX+Z6rpV0gXtLX1B+pXbtyS9WJLWcB0kXZvyvyLp2nbUpaQslepU9cFQkm5KddohaUFJekftn9UemtXN26tGnbp2e0k6QdI/S3oh1em/pfQzJG1I//MH04+NIun49H5nmj+7ZF0V61pVRHTtC5gAvAp8HJgEvEDh5zXaXrYG6vA6cFpZ2reAFWl6BXBHmr4SeIrCHeYXARvaXf5Urs8AFwAvNlsH4BTg5+nv1DQ9tcPqdCvwjQp5z0373vHAGWmfnNCJ+ycwDbggTZ8MvJzK37Xbq0adunZ7pf/35DTdQ+FxBRcBDwFXp/SVwB+n6f8MrEzTVwMP1qprrc/u9iuFC4GdEfHziPgN8ACwsM1laoWFwP1p+n4+fKLdQuAHUfA8MEVH/rBgW0TEsxSevleq0TosANZFxLsRsRdYB1w++qWvrEqdqlkIPBARv46I14CdFPbNjts/o/pDs7p2e9WoUzUdv73S//v99LYnvQK4FHg4pZdvp+L2exj4rCRRva5VdXtQmA78suT9G9TeGTpRAP8gaaOkJSntYxGxBwo7PPDRlN5N9W20Dt1St0oPhurKOunIh2aNi+2lox8E1rXbS9IEFX6E9C0KQfdVYF9EDFcoX1b2NP9XFH5zruE6dXtQqPRDfd02xnZ+RFxA4ZkT10v6TI2846G+1erQDXWr9mCorquT8j80q2vqVqFOXb29ovC8mbkUfmz0QuCcStnS35bVqduDwhvAzJL3M4DdbSpLUyJid/r7FvAohY3/ZrFZKP19K2Xvpvo2WoeOr1tUfzBUV9VJlR+a1dXbq1Kdxsv2ioh9wD9S6FOYIqn469al5cvKnub/FoXmz4br1O1B4afAWalHfhKFDpYn2lym3CSdJOnk4jRwGYWHDT0BFEdzXAs8nqafAK5JI0IuAn5VvOTvQI3WYS1wmaSp6TL/spTWMVT9wVBPAFenESBnAGcB/0wH7p+pnbnSQ7O6dntVq1M3by9JfZKmpOle4N9S6CtZD3whZSvfTsXt9wXgmSj0NFera3Xt6Flv5YvC6IiXKbS3/Um7y9Ng2T9OYWTAC8D2YvkptAU+DbyS/p4SH45I+KtU121Af7vrkMr1QwqX5wcpnJlc10wdgK9Q6AjbCfxhB9bpb1OZt6Yv27SS/H+S6rQDuKJT90/g0xSaD7YCW9Lrym7eXjXq1LXbCzgP2JzK/iJwS0r/OIWD+k4Kz7w/PqWfkN7vTPM/Xq+u1V7+mQszM8t0e/ORmZm1kIOCmZllHBTMzCzjoGBmZhkHBTMzyzgomJlZxkHBzMwy/x/AvBzCXAI56QAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "plt.scatter(\n", - " df_adjusted.classes,\n", - " df_adjusted.loss,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.4" - }, - "toc": { - "base_numbering": 1, - "nav_menu": {}, - "number_sections": true, - "sideBar": true, - "skip_h1_title": false, - "title_cell": "Table of Contents", - "title_sidebar": "Contents", - "toc_cell": false, - "toc_position": {}, - "toc_section_display": true, - "toc_window_display": false - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/nbs/optimizers.ipynb b/nbs/optimizers.ipynb deleted file mode 100644 index 31c58da..0000000 --- a/nbs/optimizers.ipynb +++ /dev/null @@ -1,373 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Optimizer Mangement\n", - "> Handling PyTorch Cuda" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# default_exp ftorch.optimizer" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This handler is created to handle the multiple optimizer situations" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "# export\n", - "import torch\n", - "\n", - "class Opts(object):\n", - " def __init__(self, *args, **kwargs):\n", - " \"\"\"\n", - " opts = Opts(opt1 = opt1, opt2 = opt2)\n", - " opts.opt2.zero_grad()\n", - " opts[\"opt3\"] = opt3\n", - " print(len(opts))\n", - " \"\"\"\n", - " self.optlist = []\n", - " self.optnames = []\n", - " for i in range(len(args)):\n", - " oname = f\"optimizer_no{i + 1}\"\n", - " setattr(self, oname, args[i])\n", - " self.optlist.append(args[i])\n", - " self.optnames.append(oname)\n", - " for k, v in kwargs.items():\n", - " setattr(self, k, v)\n", - " self.optlist.append(v)\n", - " self.optnames.append(k)\n", - "\n", - " def __repr__(self):\n", - " return \"\\n\".join(list(\n", - " f\"{self.optnames[i]}\\n\\t{self.optlist[i].__class__}\\n\\t{self.read_opt(self.optlist[i])}\" for i in\n", - " range(len(self.optnames))))\n", - "\n", - " def get_pg(self, opt):\n", - " \"\"\"\n", - " Get paramgroups dictionary, informations about an optimizer\n", - " opt:torch.optim.optimizer\n", - " \"\"\"\n", - " return dict.copy(opt.param_groups[0])\n", - "\n", - " def read_opt(self, opt):\n", - " rt = self.get_pg(opt)\n", - " if \"params\" in rt:\n", - " del rt[\"params\"]\n", - " return rt\n", - "\n", - " def __len__(self):\n", - " \"\"\"\n", - " Total number of optimizers\n", - " \"\"\"\n", - " return len(self.optlist)\n", - "\n", - " def __contains__(self, item):\n", - " return item in self.optlist\n", - "\n", - " def __getitem__(self, item):\n", - " return getattr(self, item)\n", - "\n", - " def __setitem__(self, key, optimizer):\n", - " self.optlist.append(optimizer)\n", - " self.optnames.append(key)\n", - " setattr(self, key, optimizer)\n", - "\n", - " def zero_all(self):\n", - " \"\"\"\n", - " Zero gradient on all the optimizers\n", - " \"\"\"\n", - " for opt in self.optlist:\n", - " opt.zero_grad()\n", - "\n", - " def step_all(self):\n", - " \"\"\"\n", - " All the optimizers match a step\n", - " \"\"\"\n", - " for opt in self.optlist:\n", - " opt.step()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Experiment" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "from torch import nn" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "layer1 = nn.Linear(5,5)\n", - "layer2 = nn.Linear(5,5)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "op1 = torch.optim.Adam(layer1.parameters())\n", - "op2 = torch.optim.Adagrad(layer2.parameters())" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "opts = Opts(op1, op2)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Adam (\n", - "Parameter Group 0\n", - " amsgrad: False\n", - " betas: (0.9, 0.999)\n", - " eps: 1e-08\n", - " lr: 0.001\n", - " weight_decay: 0\n", - ")" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "opts.optimizer_no1" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Adagrad (\n", - "Parameter Group 0\n", - " eps: 1e-10\n", - " initial_accumulator_value: 0\n", - " lr: 0.01\n", - " lr_decay: 0\n", - " weight_decay: 0\n", - ")" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "opts.optimizer_no2" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "optimizer_no1\n", - "\t\n", - "\t{'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}\n", - "optimizer_no2\n", - "\t\n", - "\t{'lr': 0.01, 'lr_decay': 0, 'eps': 1e-10, 'weight_decay': 0, 'initial_accumulator_value': 0}" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "opts" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's create some gradients" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": {}, - "outputs": [], - "source": [ - "x = torch.rand(2,5)\n", - "y_ = -(layer2(torch.nn.functional.relu(layer1(x))).mean())" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": {}, - "outputs": [], - "source": [ - "y_.backward()" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(tensor([[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", - " [-5.7813e-02, -5.0724e-02, -6.1752e-02, -1.1506e-01, -6.0402e-02],\n", - " [-4.5386e-04, -2.3771e-04, -2.1409e-04, -6.2308e-04, -3.3177e-04],\n", - " [-1.3012e-01, -1.1416e-01, -1.3899e-01, -2.5898e-01, -1.3595e-01],\n", - " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]]),\n", - " tensor([ 0.0000, -0.1238, -0.0007, -0.2787, 0.0000]),\n", - " tensor([[ 0.0000, -0.1671, -0.0036, -0.1062, 0.0000],\n", - " [ 0.0000, -0.1671, -0.0036, -0.1062, 0.0000],\n", - " [ 0.0000, -0.1671, -0.0036, -0.1062, 0.0000],\n", - " [ 0.0000, -0.1671, -0.0036, -0.1062, 0.0000],\n", - " [ 0.0000, -0.1671, -0.0036, -0.1062, 0.0000]]),\n", - " tensor([-0.2000, -0.2000, -0.2000, -0.2000, -0.2000]))" - ] - }, - "execution_count": 30, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "layer1.weight.grad,layer1.bias.grad,layer2.weight.grad,layer2.bias.grad" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Take a step for all optimizers" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "metadata": {}, - "outputs": [], - "source": [ - "opts.step_all()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Zero gradient on all optimizers" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [], - "source": [ - "opts.zero_all()" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(tensor([[0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.]]),\n", - " tensor([0., 0., 0., 0., 0.]),\n", - " tensor([[0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0.]]),\n", - " tensor([0., 0., 0., 0., 0.]))" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "layer1.weight.grad,layer1.bias.grad,layer2.weight.grad,layer2.bias.grad" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.4" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -}