diff --git a/nbs/11_etl.ipynb b/nbs/11_etl.ipynb
deleted file mode 100644
index 711b499..0000000
--- a/nbs/11_etl.ipynb
+++ /dev/null
@@ -1,174 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "# ETL Helper\n",
-    "> A combined set of ETL tools"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# default_exp etl"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## A list for each step"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 23,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# export\n",
-    "\n",
-    "from forgebox.imports import *\n",
-    "from typing import Callable, Union\n",
-    "import traceback as tb\n",
-    "import logging\n",
-    "\n",
-    "class NewFileScanner:\n",
-    "    \"\"\"\n",
-    "    Keep scanning a directory for new files;\n",
-    "    when one is found, process it with the callback\n",
-    "    \n",
-    "    Designed for a download-and-delete strategy\n",
-    "    \n",
-    "    # Example\n",
-    "    new_file_scanner = NewFileScanner(\".\", new_file_filter=lambda x:x[-4:]==\".txt\")\n",
-    "    new_file_scanner(lambda x:print(f\"new file:{x} found\"))\n",
-    "    \"\"\"\n",
-    "    def __init__(\n",
-    "        self,\n",
-    "        directory,\n",
-    "        new_file_filter: Callable=None,\n",
-    "        done_list_getter: Callable=None,\n",
-    "        stop_callback: Callable=None,\n",
-    "    ):\n",
-    "        self.directory = Path(directory)\n",
-    "        \n",
-    "        if new_file_filter is not None:\n",
-    "            self.new_file_filter=new_file_filter\n",
-    "        else:\n",
-    "            self.new_file_filter=lambda x:True\n",
-    "        \n",
-    "        if done_list_getter is not None:\n",
-    "            self.done_list_getter = done_list_getter\n",
-    "        else:\n",
-    "            self.done_list_getter = lambda *args:[]\n",
-    "        \n",
-    "        if stop_callback is not None:\n",
-    "            self.stop_callback = stop_callback\n",
-    "        \n",
-    "        self.processed = []\n",
-    "    \n",
-    "    def __call__(self, process: Callable):\n",
-    "        while True:\n",
-    "            try:\n",
-    "                done_list = self.done_list_getter()\n",
-    "                for fname in self.directory.iterdir():\n",
-    "                    # iterdir already yields the joined path, no need to re-join\n",
-    "                    file_path = str(fname)\n",
-    "                    if not self.new_file_filter(file_path):\n",
-    "                        continue\n",
-    "                    if file_path not in done_list:\n",
-    "                        if file_path not in self.processed:\n",
-    "                            process(file_path)\n",
-    "                            self.processed.append(file_path)\n",
-    "            except KeyboardInterrupt:\n",
-    "                if hasattr(self, \"stop_callback\"):\n",
-    "                    self.stop_callback()\n",
-    "                logging.error(\"manually stopped\")\n",
-    "                break\n",
-    "            except Exception:\n",
-    "                error = tb.format_exc()\n",
-    "                logging.error(error)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Test new file scanner"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "new file:untitled.txt found\n",
-      "new file:bc.txt found\n",
-      "new file:cde.txt found\n"
-     ]
-    },
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "ERROR:root:manually stopped\n"
-     ]
-    }
-   ],
-   "source": [
-    "new_file_scanner = NewFileScanner(\".\", new_file_filter=lambda x:x[-4:]==\".txt\")\n",
-    "new_file_scanner(lambda x:print(f\"new file:{x} found\"))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  }
- ],
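A note on the `NewFileScanner` above: a minimal usage sketch of the download-and-delete loop it is designed for, assuming a local `./downloads` directory and a hypothetical `handle_csv` callback (neither appears in the notebook); the `forgebox.etl` import path is inferred from the `# default_exp etl` directive:

```python
from pathlib import Path
from forgebox.etl import NewFileScanner  # module path inferred from `# default_exp etl`

def handle_csv(file_path: str):
    """Hypothetical callback: load the file, then delete it."""
    print(f"processing {file_path}")
    Path(file_path).unlink()  # the "delete" half of download-and-delete

scanner = NewFileScanner(
    "./downloads",                                   # assumed directory
    new_file_filter=lambda f: f.endswith(".csv"),
    done_list_getter=lambda: [],                     # e.g. names already logged in a DB
    stop_callback=lambda: print("scanner stopped"),  # runs on KeyboardInterrupt
)
scanner(handle_csv)  # blocks, polling the directory until interrupted
```

The scanner itself never deletes anything; removing (or recording) handled files is the callback's job, otherwise the in-memory `self.processed` list is the only guard against reprocessing.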
"file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.4" - }, - "toc": { - "base_numbering": 1, - "nav_menu": {}, - "number_sections": true, - "sideBar": true, - "skip_h1_title": false, - "title_cell": "Table of Contents", - "title_sidebar": "Contents", - "toc_cell": false, - "toc_position": {}, - "toc_section_display": true, - "toc_window_display": false - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/nbs/61_thunder_callbacks.ipynb b/nbs/61_thunder_callbacks.ipynb deleted file mode 100644 index 848e47f..0000000 --- a/nbs/61_thunder_callbacks.ipynb +++ /dev/null @@ -1,151 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Lightning Callbacks\n", - "> Thunder, the DIYed [pytorch-lightening callbacks](https://pytorch-lightning.readthedocs.io/en/latest/extensions/callbacks.html)" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "# default_exp thunder.callbacks" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "# export\n", - "import pandas as pd\n", - "from ipywidgets import Output\n", - "from typing import List, Dict\n", - "import copy\n", - "import pytorch_lightning as pl\n", - "import torch\n", - "from torch import nn" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "# export\n", - "def unfreeze(self):\n", - " \"\"\"unfreeze this module, and its sub modules\"\"\"\n", - " for p in self.parameters():\n", - " p.requires_grad = True\n", - "\n", - "\n", - "def freeze(self):\n", - " \"\"\"freeze this module, and its sub modules\"\"\"\n", - " for p in self.parameters():\n", - " p.requires_grad = False\n", - "\n", - "nn.Module.unfreeze = unfreeze\n", - "nn.Module.freeze = freeze\n", - "\n", - "class DataFrameMetricsCallback(pl.Callback):\n", - " \"\"\"\n", - " A metrics callback keep showing pandas dataframe\n", - " \"\"\"\n", - "\n", - " def __init__(self) -> None:\n", - " \"\"\"\n", - " In Trainer kwargs, passing this arguements along with other callbacks\n", - " callbacks = [DataFrameMetricsCallback(),]\n", - " \"\"\"\n", - " self.metrics: List = []\n", - "\n", - " def on_fit_start(\n", - " self, trainer: pl.Trainer,\n", - " pl_module: pl.LightningModule\n", - " ) -> None:\n", - " pl_module.output = Output()\n", - " display(pl_module.output)\n", - "\n", - " def on_validation_epoch_end(\n", - " self, trainer: pl.Trainer,\n", - " pl_module: pl.LightningModule\n", - " ) -> None:\n", - " metrics_dict = copy.copy(trainer.callback_metrics)\n", - " self.metrics.append(dict((k, v.item())\n", - " for k, v in metrics_dict.items()))\n", - " pl_module.output.clear_output()\n", - " with pl_module.output:\n", - " display(pd.DataFrame(self.metrics).tail(10))\n", - "\n", - "\n", - "def UnfreezeScheduler(frozen_epochs: int = 2):\n", - " assert hasattr(pl_module, \"top_layers\"), \"Please define 'top_layers' attributes\"+\\\n", - " \" for pl_module, which will return a list of nn.Module object(s)\"\n", - " class UnfreezeSchedulerCallback(pl.callbacks.Callback):\n", - " \"\"\"\n", - " Train the top layer for [frozen_epochs] epochs\n", - " then un freeze all\n", - " \"\"\"\n", - "\n", - " def on_epoch_start(self, trainer, pl_module):\n", - " epoch = trainer.current_epoch\n", - "\n", - " if epoch == 0:\n", - " pl_module.freeze()\n", - " for tl in pl_module.top_layers:\n", 
- " tl.unfreeze()\n", - " if epoch == frozen_epochs:\n", - " pl_module.unfreeze()\n", - " pl_module.base.embeddings.freeze()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.4" - }, - "toc": { - "base_numbering": 1, - "nav_menu": {}, - "number_sections": true, - "sideBar": true, - "skip_h1_title": false, - "title_cell": "Table of Contents", - "title_sidebar": "Contents", - "toc_cell": false, - "toc_position": {}, - "toc_section_display": true, - "toc_window_display": false - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/nbs/70_hf_transformer_data.ipynb b/nbs/70_hf_transformer_data.ipynb deleted file mode 100644 index ac7b569..0000000 --- a/nbs/70_hf_transformer_data.ipynb +++ /dev/null @@ -1,553 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Data parts for hf transformers" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "# default_exp hf.data" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "# export\n", - "from forgebox.imports import *\n", - "from forgebox.category import Category\n", - "from typing import List, Dict, Callable, Any, Tuple" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Process IOBES files" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# export\n", - "def convert_iob2_file_to_iobes(file_path, result_path):\n", - " \"\"\"\n", - " Convert IOB2 file to IOBES\n", - " \"\"\"\n", - " with open(file_path, 'r') as f:\n", - " lines = f.readlines()\n", - " with open(result_path, 'w') as f:\n", - " for line in lines:\n", - " line = line.strip()\n", - " if line == '':\n", - " f.write('\\n')\n", - " continue\n", - " line = line.split()\n", - " if line[-1] == 'O':\n", - " f.write(' '.join(line) + '\\n')\n", - " else:\n", - " f.write(' '.join(line[:-1]) + ' ' + line[-1] + '\\n')\n", - "\n", - "\n", - "def conbine_iobes_file(\n", - " file_paths: List[Path],\n", - " new_file_path: Path\n", - "):\n", - " \"\"\"\n", - " Conbine from multiple IOBES files\n", - " into IOBES files\n", - " \"\"\"\n", - " with open(new_file_path, 'w') as new_file:\n", - " for file_path in file_paths:\n", - " with open(file_path, 'r') as file:\n", - " for line in file:\n", - " new_file.write(line)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# export\n", - "class IOBES(Dataset):\n", - " \"\"\"\n", - " Load iobes file for NER training task\n", - " \"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " file_path,\n", - " tokenizer,\n", - " max_len=128,\n", - " save_buffer: int = 15,\n", - " category: Category = None,\n", - " return_string: bool = False,\n", - " use_frag: bool = False,\n", - " ):\n", - " \"\"\"\n", - " file_path,\n", - " tokenizer,\n", - " max_len=128,\n", - " save_buffer: int = 15,\n", - " category: Category = None,\n", - " label 
categories, if set to None, will be figured out\n", - " automatically.\n", - " You can set this to None for train dataset, but for valid\n", - " dataset:\n", - " valid_ds = IOBES(...,category=train_ds.cates)\n", - " return_string: bool = False, do we return original string\n", - " for tokenizer output, this option is good for debuging\n", - " but the data won't pass into cuda if choose so\n", - " use_frag: bool = False, do we use prepend like 'I-','B-'\n", - " \"\"\"\n", - " self.file_path = file_path\n", - " self.max_len = max_len\n", - " self.pairs = []\n", - " self.list_of_words = []\n", - " self.list_of_labels = []\n", - " self.tokenizer = tokenizer\n", - " self.cates = category\n", - " self.return_string = return_string\n", - " self.use_frag = use_frag\n", - " self.load_data(save_buffer)\n", - "\n", - " def load_data(self, save_buffer: int = 15):\n", - " \"\"\"\n", - " Load file in to object structure\n", - " \"\"\"\n", - " with open(self.file_path, 'r') as f:\n", - " for line in f:\n", - " line = line.strip()\n", - " if line:\n", - " splited = line.split()\n", - " if len(splited) != 2:\n", - " continue\n", - " word, label = splited\n", - " # do we use 'I-', 'B-' etc\n", - " if self.use_frag is False:\n", - " if \"-\" in label:\n", - " label = label.split('-')[1]\n", - " self.pairs.append([word, label])\n", - "\n", - " self.pairs = np.array(self.pairs)\n", - "\n", - " if self.cates is None:\n", - " labels_df = pd.DataFrame({\"label\": self.pairs[:, 1]})\n", - " self.cates = Category(list(labels_df.vc(\"label\").index))\n", - "\n", - " self.batching_words(save_buffer)\n", - "\n", - " def batching_words(self, save_buffer: int = 15):\n", - " \"\"\"\n", - " batching self.words into self.list_of_words\n", - " by self.max_len -15\n", - " \"\"\"\n", - " for i in range(0, len(self.pairs), self.max_len-save_buffer):\n", - " chunk_slice = slice(i, i+self.max_len-save_buffer)\n", - " self.list_of_words.append(self.pairs[chunk_slice, 0])\n", - " self.list_of_labels.append(self.pairs[chunk_slice, 1])\n", - "\n", - " def __len__(self) -> int:\n", - " return len(self.list_of_words)\n", - "\n", - " def __getitem__(self, idx: int) -> Tuple[List[str]]:\n", - " return list(self.list_of_words[idx]), list(self.list_of_labels[idx])\n", - "\n", - " def __repr__(self):\n", - " return f\"\"\"NER dataset using IOBES annotation\n", - " {len(self)} sentences,\n", - " Labels:\n", - " {list(self.cates.i2c)}\n", - " \"\"\"\n", - "\n", - " def collate_fn(self, data):\n", - " \"\"\"\n", - " data: list of tuple\n", - " \"\"\"\n", - " words, text_labels = zip(*data)\n", - "\n", - " inputs = self.tokenizer(\n", - " list(words),\n", - " return_tensors='pt',\n", - " padding=True,\n", - " truncation=True,\n", - " max_length=self.max_len,\n", - " is_split_into_words=True,\n", - " return_offsets_mapping=True,\n", - " add_special_tokens=False,\n", - " )\n", - " return self.align_offsets(inputs, text_labels, words)\n", - "\n", - " def align_offsets(\n", - " self,\n", - " inputs,\n", - " text_labels: List[List[str]],\n", - " words: List[List[str]]\n", - " ):\n", - " \"\"\"\n", - " inputs: output if tokenizer\n", - " text_labels: labels in form of list of list of strings\n", - " words: words in form of list of list of strings\n", - " \"\"\"\n", - " labels = torch.zeros_like(inputs.input_ids).long()\n", - " labels -= 100\n", - " text_lables_array = np.empty(labels.shape, dtype=object)\n", - " words_array = np.empty(labels.shape, dtype=object)\n", - " max_len = inputs.input_ids.shape[1]\n", - "\n", - " for row_id, input_ids in 
enumerate(inputs.input_ids):\n", - " word_pos = inputs.word_ids(row_id)\n", - " for idx, pos in enumerate(word_pos):\n", - " if pos is None:\n", - " continue\n", - " if pos <= max_len:\n", - " labels[row_id, idx] = self.cates.c2i[text_labels[row_id][pos]]\n", - " if self.return_string:\n", - " text_lables_array[row_id,\n", - " idx] = text_labels[row_id][pos]\n", - " words_array[row_id, idx] = words[row_id][pos]\n", - "\n", - " inputs['labels'] = labels\n", - " if self.return_string:\n", - " inputs['text_labels'] = text_lables_array.tolist()\n", - " inputs['word'] = words_array.tolist()\n", - " return inputs\n", - "\n", - " def dataloader(self, batch_size: int = 32, shuffle: bool = True):\n", - " \"\"\"\n", - " Create dataloader\n", - " \"\"\"\n", - " return DataLoader(\n", - " self,\n", - " batch_size=batch_size,\n", - " shuffle=shuffle,\n", - " collate_fn=self.collate_fn,\n", - " )\n", - "\n", - " def one_batch(self, batch_size: int = 32, shuffle: bool = True):\n", - " return next(iter(self.dataloader(batch_size, shuffle)))\n", - "\n", - " def visualize_batch(self, batch, row_idx=0):\n", - " return list(zip(self.tokenizer.convert_ids_to_tokens(batch.input_ids[row_idx]),\n", - " batch.labels[row_idx].numpy(),\n", - " batch.text_labels[row_idx],\n", - " batch.word[row_idx],\n", - " batch.offset_mapping[row_idx].numpy(),\n", - " ))\n", - "\n", - " def set_hfconfig(self, config):\n", - " \"\"\"\n", - " set the category information to huggingface config\n", - " \"\"\"\n", - " config.num_labels = len(self.cates)\n", - " config.id2label = {i: label for i, label in enumerate(self.cates.i2c)}\n", - " config.label2id = {label: i for i, label in enumerate(self.cates.i2c)}" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "from transformers import AutoTokenizer" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "tokenizer = AutoTokenizer.from_pretrained(\"raynardj/roberta-pubmed\", add_prefix_space=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "dataset = IOBES(\"/Users/xiaochen.zhang/data/valid.iobes\", tokenizer)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "in-O\n", - "blood-O\n", - ";-O\n", - "content-O\n", - "of-O\n", - "cAMP-O\n", - "was-O\n", - "also-O\n", - "decreased-O\n", - "in-O\n", - "lymphocytes-O\n", - "by-O\n", - "33-O\n", - "%-O\n", - ".-O\n", - "At-O\n", - "the-O\n", - "same-O\n", - "time-O\n", - ",-O\n", - "total-O\n", - "content-O\n", - "of-O\n", - "T-cell_type\n", - "lymphocytes-cell_type\n", - "was-O\n", - "decreased-O\n", - "1.5-fold-O\n", - "in-O\n", - "peripheric-O\n", - "blood-O\n", - ".-O\n", - "Treatment-O\n", - "with-O\n", - "I-hydroxyvitamin-O\n", - "D3-O\n", - "(-O\n", - "1-1.5-O\n", - "mg-O\n", - "daily-O\n", - ",-O\n", - "within-O\n", - "4-O\n", - "weeks-O\n", - ")-O\n", - "led-O\n", - "to-O\n", - "normalization-O\n", - "of-O\n", - "total-O\n", - "and-O\n", - "ionized-O\n", - "form-O\n", - "of-O\n", - "Ca2+-O\n", - "and-O\n", - "of-O\n", - "25-O\n", - "(-O\n", - "OH-O\n", - ")-O\n", - "D-O\n", - ",-O\n", - "but-O\n", - "did-O\n", - "not-O\n", - "affect-O\n", - "the-O\n", - "PTH-O\n", - "content-O\n", - "in-O\n", - "blood-O\n", - ".-O\n", - "Concentration-O\n", - "of-O\n", - "the-O\n", - "receptors-protein\n", - "to-O\n", - "1.25-O\n", - "(-O\n", - "OH-O\n", - 
")-O\n", - "2D3-O\n", - "was-O\n", - "elevated-O\n", - "up-O\n", - "to-O\n", - "39.7-O\n", - "fmole/mg-O\n", - "after-O\n", - "I-O\n", - "week-O\n", - "of-O\n", - "the-O\n", - "treatment-O\n", - ",-O\n", - "whereas-O\n", - "it-O\n", - "was-O\n", - "decreased-O\n", - "to-O\n", - "the-O\n", - "initial-O\n", - "level-O\n", - "24.8-O\n", - "fmole/mg-O\n", - "within-O\n", - "4-O\n", - "weeks-O\n", - ";-O\n", - "simultaneous-O\n", - "alteration-O\n", - "in-O\n" - ] - } - ], - "source": [ - "for w,l in zip(*dataset[2]):\n", - " print(f\"{w}-{l}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'input_ids': tensor([[ 19, 3741, 2603, ..., 1417, 2617, 11576],\n", - " [ 4590, 2156, 255, ..., 405, 1182, 6608],\n", - " [ 6214, 25683, 3809, ..., 11, 5, 8151],\n", - " ...,\n", - " [13998, 25326, 2413, ..., 5, 2199, 21],\n", - " [11299, 705, 24811, ..., 134, 1589, 2032],\n", - " [ 5804, 924, 14, ..., 366, 1168, 9]]), 'attention_mask': tensor([[1, 1, 1, ..., 1, 1, 1],\n", - " [1, 1, 1, ..., 1, 1, 1],\n", - " [1, 1, 1, ..., 1, 1, 1],\n", - " ...,\n", - " [1, 1, 1, ..., 1, 1, 1],\n", - " [1, 1, 1, ..., 1, 1, 1],\n", - " [1, 1, 1, ..., 1, 1, 1]]), 'offset_mapping': tensor([[[ 1, 4],\n", - " [ 1, 2],\n", - " [ 2, 5],\n", - " ...,\n", - " [ 3, 5],\n", - " [ 5, 8],\n", - " [ 1, 6]],\n", - "\n", - " [[ 1, 5],\n", - " [ 1, 1],\n", - " [ 1, 1],\n", - " ...,\n", - " [ 5, 7],\n", - " [ 7, 9],\n", - " [ 9, 14]],\n", - "\n", - " [[ 1, 5],\n", - " [ 5, 8],\n", - " [ 8, 10],\n", - " ...,\n", - " [ 1, 2],\n", - " [ 1, 3],\n", - " [ 1, 10]],\n", - "\n", - " ...,\n", - "\n", - " [[ 1, 5],\n", - " [ 5, 8],\n", - " [ 8, 10],\n", - " ...,\n", - " [ 1, 3],\n", - " [ 1, 7],\n", - " [ 1, 3]],\n", - "\n", - " [[ 1, 5],\n", - " [ 5, 6],\n", - " [ 6, 10],\n", - " ...,\n", - " [ 2, 3],\n", - " [ 1, 1],\n", - " [ 1, 2]],\n", - "\n", - " [[ 1, 7],\n", - " [ 1, 5],\n", - " [ 1, 4],\n", - " ...,\n", - " [ 3, 5],\n", - " [ 5, 7],\n", - " [ 1, 2]]]), 'labels': tensor([[0, 1, 1, ..., 0, 0, 0],\n", - " [2, 0, 2, ..., 0, 0, 0],\n", - " [0, 0, 0, ..., 0, 0, 0],\n", - " ...,\n", - " [1, 1, 1, ..., 0, 0, 0],\n", - " [0, 0, 0, ..., 2, 0, 2],\n", - " [0, 0, 0, ..., 0, 0, 0]])}" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dataset.one_batch()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.4" - }, - "toc": { - "base_numbering": 1, - "nav_menu": {}, - "number_sections": true, - "sideBar": true, - "skip_h1_title": false, - "title_cell": "Table of Contents", - "title_sidebar": "Contents", - "toc_cell": false, - "toc_position": {}, - "toc_section_display": true, - "toc_window_display": true - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/nbs/72_pl_training.ipynb b/nbs/72_pl_training.ipynb deleted file mode 100644 index 3ce4854..0000000 --- a/nbs/72_pl_training.ipynb +++ /dev/null @@ -1,466 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Pytorch Lighting training\n", - "> on huggingface transformers" - ] - }, - { - 
"cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "# default_exp hf.train" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "# export\n", - "from forgebox.hf.data import IOBES\n", - "from forgebox.imports import *\n", - "from forgebox.loop import chunkify\n", - "import pytorch_lightning as pl\n", - "from transformers import (\n", - " AutoModelForTokenClassification,\n", - " AutoTokenizer,\n", - " pipeline\n", - ")\n", - "from tqdm.notebook import tqdm\n", - "from typing import Callable, List\n", - "from torch import device" - ] - }, - { - "cell_type": "code", - "execution_count": 290, - "metadata": {}, - "outputs": [], - "source": [ - "# export\n", - "try:\n", - " ishell = get_ipython()\n", - " IS_JUPYTER = True\n", - " from tqdm.notebook import tqdm\n", - "except NameError:\n", - " IS_JUPYTER = False\n", - " from tqdm import tqdm" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "# !pip install transformers==4.9.1\n", - "# !pip install pytorch-lightning==1.3.8\n", - "# !pip install tensorflow==2.2.0" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Load model and tokenizer" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [], - "source": [ - "# export\n", - "\n", - "# ner model and tokenizer\n", - "def ner_model_from(\n", - " name:str, dataset: IOBES\n", - "):\n", - " \"\"\"\n", - " name: from_pretrain(name)\n", - " \"\"\"\n", - " model = AutoModelForTokenClassification.from_pretrained(\n", - " name,\n", - " num_labels=len(dataset.cates),\n", - " )\n", - " dataset.set_hfconfig(model.config)\n", - " return model\n", - "\n", - "def ner_tokenizer_from(\n", - " name: str\n", - "):\n", - " return AutoTokenizer.from_pretrained(\n", - " name, add_prefix_space=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Lightning data module" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "# export\n", - "\n", - "# ner data module\n", - "class NERDataModule(pl.LightningDataModule):\n", - " def __init__(self, train_ds, val_ds, batch_size=32):\n", - " super().__init__()\n", - " self.train_ds = train_ds\n", - " self.val_ds = val_ds\n", - " self.batch_size = batch_size\n", - "\n", - " def train_dataloader(self):\n", - " return self.train_ds.dataloader(batch_size=self.batch_size, shuffle=True)\n", - "\n", - " def val_dataloader(self):\n", - " return self.val_ds.dataloader(batch_size=self.batch_size*2, shuffle=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "# export\n", - "\n", - "# ner module\n", - "class NERModule(pl.LightningModule):\n", - " \"\"\"\n", - " PyTorch lightning module for training ner model\n", - " \"\"\"\n", - " def __init__(\n", - " self, model,\n", - " ):\n", - " \"\"\"\n", - " model: huggingface transformer model for ner\n", - " \"\"\"\n", - " super().__init__()\n", - " self.model = model\n", - "\n", - " def forward(self, batch):\n", - " return self.model(\n", - " input_ids=batch['input_ids'],\n", - " attention_mask=batch['attention_mask'],\n", - " labels=batch['labels'])\n", - " \n", - " def training_step(self, batch, batch_idx):\n", - " outputs = self(batch)\n", - " loss = outputs.loss\n", - " self.log(\"loss\", loss)\n", - " self.log(\"acc\", self.calcualte_acc(outputs, 
batch.labels))\n", - " return loss\n", - "\n", - " def validation_step(self, batch, batch_idx):\n", - " outputs = self(batch)\n", - " loss = outputs.loss\n", - " self.log(\"val_loss\", loss)\n", - " self.log(\"val_acc\", self.calcualte_acc(outputs, batch.labels))\n", - " return loss\n", - " \n", - " def calcualte_acc(self, outputs, labels):\n", - " pred_idx = outputs.logits.argmax(-1)\n", - " mask = torch.ones_like(pred_idx)\n", - " mask[labels==-100]=False\n", - " return (pred_idx[mask]==labels[mask]).float().mean()\n", - " \n", - " def configure_optimizers(self):\n", - " # discriminative learning rate\n", - " param_groups = [\n", - " {'params': self.model.roberta.parameters(), 'lr': 5e-6},\n", - " {'params': self.model.classifier.parameters(), 'lr': 1e-3},\n", - " ]\n", - " optimizer = torch.optim.Adam(param_groups, lr=1e-3)\n", - " return optimizer" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Enhance pipeline" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "metadata": {}, - "outputs": [], - "source": [ - "# export\n", - "def clean_ner_output(self, outputs):\n", - " \"\"\"\n", - " Cleaning output for NER task\n", - " \"\"\"\n", - " results = []\n", - " current = []\n", - " last_idx = 0\n", - " # make to sub group by position\n", - " for output in outputs:\n", - " if output[\"start\"] in [last_idx, last_idx-1]:\n", - " current.append(output)\n", - " else:\n", - " results.append(current)\n", - " current = [output, ]\n", - " last_idx = output[\"end\"]\n", - " if len(current) > 0:\n", - " results.append(current)\n", - "\n", - " # from tokens to string\n", - " strings = []\n", - " for c in results:\n", - " tokens = []\n", - " starts = []\n", - " ends = []\n", - " for o in c:\n", - " tokens.append(o['word'])\n", - " starts.append(o['start'])\n", - " ends.append(o['end'])\n", - "\n", - " new_str = self.tokenizer.convert_tokens_to_string(tokens)\n", - " if new_str != '':\n", - " strings.append(dict(\n", - " word=new_str,\n", - " start=min(starts),\n", - " end=max(ends),\n", - " entity=c[0]['entity_group']\n", - " ))\n", - " return strings" - ] - }, - { - "cell_type": "code", - "execution_count": 281, - "metadata": {}, - "outputs": [], - "source": [ - "# export\n", - "class NERInference:\n", - " \"\"\"\n", - " NER Inference pipeline\n", - " ner = NERInference.from_pretrained('xxxx/xxxx')\n", - " ner.predict(['text1','text2'])\n", - " \"\"\"\n", - "\n", - " def __init__(self, model, tokenizer, name=None):\n", - " super().__init__()\n", - " self.model = model.eval()\n", - " self.tokenizer = tokenizer\n", - " self.name = name if name else \"NER model\"\n", - "\n", - " def __repr__(self):\n", - " return f\"[NERInference on {self.name}]\"\n", - "\n", - " def to(self, device_str):\n", - " self.model = self.model.to(device(device_str))\n", - " return self\n", - "\n", - " @classmethod\n", - " def from_pretrained(cls, tag):\n", - " \"\"\"\n", - " Load from pretrained model and tokenizer\n", - " \"\"\"\n", - " model = AutoModelForTokenClassification.from_pretrained(tag)\n", - " tokenizer = AutoTokenizer.from_pretrained(tag)\n", - " return cls(model=model, tokenizer=tokenizer, name=model.config._name_or_path)\n", - " \n", - " def __call__(self, data, batch_size=32, dev=device(\"cpu\")):\n", - " if type(data) == str:\n", - " return self.batch_predict([data,])\n", - " else:\n", - " return self.predict(data, dev=dev, batch_size=batch_size)\n", - "\n", - " def predict(\n", - " self,\n", - " texts: List[str],\n", - " dev=device(\"cpu\"),\n", - " batch_size: int = 
32,\n", - " progress_bar: bool = True\n", - " ) -> pd.DataFrame:\n", - " \"\"\"\n", - " Predict a list of sentences/ paragraphs\n", - " \"\"\"\n", - " # place the model into device\n", - " self.model = self.model.to(dev)\n", - " iterator = list(enumerate(chunkify(texts, bs=batch_size)))\n", - " if progress_bar:\n", - " iterator = tqdm(iterator, leave=False)\n", - "\n", - " # run through iterator\n", - " all_dfs = []\n", - " for i, text_b in iterator:\n", - " # by batch prediction\n", - " batch_df = self.batch_predict(text_b)\n", - " if len(batch_df) > 0:\n", - " # calculate the row number\n", - " batch_df['text_id'] = batch_df.apply(\n", - " lambda row: i*batch_size+row.batch_row_sn, axis=1)\n", - " all_dfs.append(batch_df)\n", - "\n", - " # place the model back to cpu\n", - " self.model = self.model.to(\"cpu\")\n", - " return pd.concat(all_dfs).reset_index(drop=True)\n", - " \n", - " def tokenizing(self, texts):\n", - " inputs = self.tokenizer(\n", - " texts,\n", - " padding=\"max_length\",\n", - " max_length=self.tokenizer.model_max_length,\n", - " return_attention_mask=True,\n", - " return_tensors='pt', truncation=True, return_offsets_mapping=True\n", - " ).to(self.model.device)\n", - " return inputs\n", - "\n", - "\n", - " def batch_predict(self, texts:List[str])-> pd.DataFrame:\n", - " \"\"\"\n", - " Predict a single batch of sentences\n", - " \"\"\"\n", - " id2label = self.model.config.id2label\n", - " inputs = self.tokenizing(texts)\n", - "\n", - " with torch.no_grad():\n", - " outputs = self.model(input_ids=inputs.input_ids,\n", - " attention_mask=inputs.attention_mask)\n", - " inputs = inputs.to(device('cpu'))\n", - "\n", - " pred_idx = outputs.logits.argmax(-1).to(device(\"cpu\"))\n", - " batch_size = pred_idx.size(0)\n", - " offsets = inputs.offset_mapping\n", - " results = []\n", - " for bi in range(batch_size):\n", - " text = texts[bi]\n", - " input_ids = inputs.input_ids[bi]\n", - " word_ids = inputs.word_ids(bi)\n", - " pred_ids = pred_idx[bi]\n", - " # initial values for the row\n", - " last_pos = 0\n", - " previous_has_positive = False\n", - " current_start = 0\n", - " current_index = 0\n", - " current_id = 0\n", - " line = []\n", - " for ti in range(1, len(input_ids)):\n", - " if input_ids[ti] == self.tokenizer.sep_token_id:\n", - " break\n", - " # is the current token an appending sub-word?\n", - " if word_ids[ti] == last_pos:\n", - " pass\n", - " # is current token negative\n", - " elif pred_ids[ti].item() == 0:\n", - " # store the previous hanging prediction\n", - " if previous_has_positive:\n", - " start = current_start\n", - " end = offsets[bi, ti, 0].item()\n", - " line.append({\n", - " \"start\": start, \"end\": end,\n", - " \"entity\": id2label[current_id],\n", - " \"word\": text[start:end],\n", - " \"index\": current_index,\n", - " })\n", - "\n", - " current_start = offsets[bi, ti, 0].item()\n", - " previous_has_positive = False\n", - " current_id = 0\n", - " current_index = ti\n", - " # has positive prediction index, other than zero\n", - " else:\n", - " if previous_has_positive:\n", - " # different than the previous\n", - " if current_id != pred_ids[ti].item():\n", - " start = current_start\n", - " end = offsets[bi, ti, 0].item()\n", - " line.append({\n", - " \"start\": start,\n", - " \"end\": end,\n", - " \"entity\": id2label[current_id],\n", - " \"word\": text[start:end],\n", - " \"index\": current_index,\n", - " })\n", - " current_start = offsets[bi, ti, 0].item()\n", - " # this is the 1st postive predict for a while\n", - " else:\n", - " current_start = 
offsets[bi, ti, 0].item()\n", - " previous_has_positive = True\n", - " current_index = ti\n", - " current_id = pred_ids[ti].item()\n", - "\n", - " last_pos = word_ids[ti]\n", - " if previous_has_positive:\n", - " start = current_start\n", - " end = offsets[bi, ti, 1].item()\n", - " line.append({\n", - " \"start\": start,\n", - " \"end\": end,\n", - " \"entity\": id2label[current_id],\n", - " \"word\": text[start:end],\n", - " \"index\": current_index,\n", - " })\n", - "\n", - " results.append(line)\n", - " all_dfs = []\n", - " for i, res in enumerate(results):\n", - " sub_df = pd.DataFrame(res)\n", - " sub_df[\"batch_row_sn\"] = i\n", - " all_dfs.append(sub_df)\n", - " return pd.concat(all_dfs)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.4" - }, - "toc": { - "base_numbering": 1, - "nav_menu": {}, - "number_sections": true, - "sideBar": true, - "skip_h1_title": false, - "title_cell": "Table of Contents", - "title_sidebar": "Contents", - "toc_cell": false, - "toc_position": {}, - "toc_section_display": true, - "toc_window_display": false - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/nbs/cross_entropy_weighter.ipynb b/nbs/cross_entropy_weighter.ipynb deleted file mode 100644 index 95018ab..0000000 --- a/nbs/cross_entropy_weighter.ipynb +++ /dev/null @@ -1,1298 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Create multi-task adjustment for cross-entropy\n", - "> A weighter for multi-task learning with different softmax+cross-entropy" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Tools and imports" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# default_exp multitask_ce" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "# export\n", - "import torch\n", - "from torch import nn\n", - "import pandas as pd\n", - "import numpy as np\n", - "from typing import Callable" - ] - }, - { - "cell_type": "code", - "execution_count": 120, - "metadata": {}, - "outputs": [], - "source": [ - "from matplotlib import pyplot as plt" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Enters softmax" - ] - }, - { - "cell_type": "code", - "execution_count": 102, - "metadata": {}, - "outputs": [], - "source": [ - "softmax = nn.Softmax(-1)\n", - "crit = nn.CrossEntropyLoss()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Experience the problem" - ] - }, - { - "cell_type": "code", - "execution_count": 104, - "metadata": {}, - "outputs": [], - "source": [ - "def test_loss(model,iters:300,get_xy):\n", - " losses=[]\n", - " with torch.no_grad():\n", - " for i in range(iters):\n", - " x,y_true = get_xy(model.nb_output)\n", - " y_vec = model(x)\n", - " loss = model.crit(y_vec,y_true)\n", - " losses.append(loss)\n", - " return torch.stack(losses).mean()" - ] - }, - { - "cell_type": "code", - "execution_count": 245, - "metadata": {}, - "outputs": [], - "source": [ - "def create_softmax_pipeline(\n", - " 
nb_layers:int,\n", - " nb_output:int,\n", - " hs:int=500,\n", - " crit:nn.Module=nn.CrossEntropyLoss(),\n", - " )->nn.Module:\n", - " modules = (nb_layers-1)*[nn.Linear(hs,hs),nn.Dropout(.3)]\n", - " modules+=[nn.Linear(hs,nb_output),]\n", - " model = nn.Sequential(*modules)\n", - " model.hs = hs\n", - " model.nb_output = nb_output\n", - " model.__class__.crit = crit\n", - " model.__class__.test_loss = test_loss\n", - " return model" - ] - }, - { - "cell_type": "code", - "execution_count": 246, - "metadata": {}, - "outputs": [], - "source": [ - "def random_input(nb_output):\n", - " return torch.rand(2,500),torch.randint(low=0,high=nb_output,size = (2,))\n", - "\n", - "def inbalanced_input(\n", - " bias:float\n", - " ) -> Callable:\n", - " def inbalanced_input_(nb_output:int):\n", - " return torch.rand(2,500),torch.randint(\n", - " low=0,\n", - " high=max(1,int(nb_output*bias)),\n", - " size = (2,)\n", - " )\n", - " return inbalanced_input_" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "* For models with different branches of output, some output 2 category, some output more, like 200,500\n", - "* Their loss will end up in different scale\n", - "* That makes mixing them up fairly hard and unfair to lesser category tasks" - ] - }, - { - "cell_type": "code", - "execution_count": 247, - "metadata": {}, - "outputs": [], - "source": [ - "import random" - ] - }, - { - "cell_type": "code", - "execution_count": 266, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "fa780ff8fdab4d4cb932c4dbfe96c44f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], - "source": [ - "result = []\n", - "\n", - "for i in tqdm(range(50)):\n", - " c = random.randint(2,3000)\n", - " b = 1-random.random()\n", - " loss = create_softmax_pipeline(1,c).test_loss(300,inbalanced_input(b))\n", - " result.append(dict(classes=c,loss=loss.item(),inbl_bias=b))" - ] - }, - { - "cell_type": "code", - "execution_count": 267, - "metadata": {}, - "outputs": [], - "source": [ - "df = pd.DataFrame(result)" - ] - }, - { - "cell_type": "code", - "execution_count": 268, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
classeslossinbl_bias
022447.7840150.719190
124477.8491580.002960
27066.6254160.297956
32725.7390010.154829
428607.9948820.179631
523767.8230360.554090
64746.2306430.497769
714767.3176500.790722
87296.6523120.671657
911837.1610010.422553
1017927.5710340.375524
1126017.9283470.937953
1219887.6534010.448365
136176.4717110.669720
1412577.1966650.180724
1512087.1422720.603882
1615717.4319090.307233
1718157.5657630.271957
1823707.8402560.130841
1924217.8353620.470347
2026087.9338520.362477
2118337.5664140.551423
2217697.5360470.630381
2319507.6157930.795910
2419107.5857390.343632
2523017.7984650.829628
2623777.8692110.016988
2714967.3931780.119690
2811797.1470580.568997
2914757.3567400.693435
3024097.8359540.455461
315426.3284030.733283
326996.6098920.674927
3329668.0627700.268944
3428318.0060010.822981
3516577.4594730.323996
3614697.3166780.062805
376746.5796550.812632
386546.5500600.601652
391945.3505550.238536
402505.5670150.270642
4123687.8191310.117133
4215737.4296060.003537
4329948.0539260.764925
4421557.7413450.941843
4526427.9191260.612718
464866.2425270.701243
4719487.6443660.970036
4826267.9323630.657842
4916607.4668120.855916
\n", - "
" - ], - "text/plain": [ - " classes loss inbl_bias\n", - "0 2244 7.784015 0.719190\n", - "1 2447 7.849158 0.002960\n", - "2 706 6.625416 0.297956\n", - "3 272 5.739001 0.154829\n", - "4 2860 7.994882 0.179631\n", - "5 2376 7.823036 0.554090\n", - "6 474 6.230643 0.497769\n", - "7 1476 7.317650 0.790722\n", - "8 729 6.652312 0.671657\n", - "9 1183 7.161001 0.422553\n", - "10 1792 7.571034 0.375524\n", - "11 2601 7.928347 0.937953\n", - "12 1988 7.653401 0.448365\n", - "13 617 6.471711 0.669720\n", - "14 1257 7.196665 0.180724\n", - "15 1208 7.142272 0.603882\n", - "16 1571 7.431909 0.307233\n", - "17 1815 7.565763 0.271957\n", - "18 2370 7.840256 0.130841\n", - "19 2421 7.835362 0.470347\n", - "20 2608 7.933852 0.362477\n", - "21 1833 7.566414 0.551423\n", - "22 1769 7.536047 0.630381\n", - "23 1950 7.615793 0.795910\n", - "24 1910 7.585739 0.343632\n", - "25 2301 7.798465 0.829628\n", - "26 2377 7.869211 0.016988\n", - "27 1496 7.393178 0.119690\n", - "28 1179 7.147058 0.568997\n", - "29 1475 7.356740 0.693435\n", - "30 2409 7.835954 0.455461\n", - "31 542 6.328403 0.733283\n", - "32 699 6.609892 0.674927\n", - "33 2966 8.062770 0.268944\n", - "34 2831 8.006001 0.822981\n", - "35 1657 7.459473 0.323996\n", - "36 1469 7.316678 0.062805\n", - "37 674 6.579655 0.812632\n", - "38 654 6.550060 0.601652\n", - "39 194 5.350555 0.238536\n", - "40 250 5.567015 0.270642\n", - "41 2368 7.819131 0.117133\n", - "42 1573 7.429606 0.003537\n", - "43 2994 8.053926 0.764925\n", - "44 2155 7.741345 0.941843\n", - "45 2642 7.919126 0.612718\n", - "46 486 6.242527 0.701243\n", - "47 1948 7.644366 0.970036\n", - "48 2626 7.932363 0.657842\n", - "49 1660 7.466812 0.855916" - ] - }, - "execution_count": 268, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### The pattern" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "There are $n$ tasks $T$ with a set of numbers of classes $C$ where $c_{i}$ is the number of the classes of task $T_{i}$\n", - "\n", - "The number of classes $c_{i}$ has certain correlation of the average loss $L_{i}$\n", - "\n", - "Where $log_{10}(c_{i})$ has a clear linear relationship with $L$\n", - "\n", - "$L_{i} = a.log_{10}(c_{i})$, where a is a fixed constant" - ] - }, - { - "cell_type": "code", - "execution_count": 269, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 269, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAVw0lEQVR4nO3dfZBd913f8ffXqzWs8rQmFiTaSBa0QW0Tjy13a8vNlAl1iMhDE8EorV1MwMOMcGBC3BYNces6ENwSum0hqYcIDxkaiOu6qMrWMMYy0xDIMCMxa0u24jhiRB4krQLIISvX9kLWq2//2Hudq7v34dzVfTz7fs3s6N5zftr9/sbWR0ff8zu/G5mJJGn0XTboAiRJ3WGgS1JJGOiSVBIGuiSVhIEuSSWxYVA/+Morr8xt27YN6sdL0kh67LHHnsnMTY3ODSzQt23bxtzc3KB+vCSNpIj4arNztlwkqSQMdEkqCQNdkkrCQJekkjDQJakkDHRJKomBLVuUpLKbPTrPzKETnF1YZPPkBPt2bWf3jqme/Tyv0CWpB2aPznPnwePMLyySwPzCInc8eIwdH36U2aPzPfmZBrok9cDMoRMsLi2vOv6NF5a48+DxnoS6gS5JPXB2YbHpucWlZWYOnej6z7SHLkldUN8vf9XEOAuLS03Htwr8tSp0hR4R/yoinoqIz0fEAxHx7XXnvy0iHoyIkxFxJCK2db1SSRpSjfrlrcIcYPPkRNfraBvoETEF/AwwnZlvBMaAm+uG/QTwjcz8u8CvAL/c7UIlaVg165c3MzE+xr5d27teR9Ee+gZgIiI2ABuBs3Xn3w18svL6AHBTRER3SpSk4daufTI5Mc7U5AQBTE1O8Es/fHVPli+27aFn5nxE/GfgFLAIPJqZj9YNmwJOV8a/GBHngVcDz9QOioi9wF6ArVu3Xnr1kjQENk9OMN8i1M8vLnHsQ2/teR1FWi5XsHIF/t3AZuBlEXFr/bAGvzVXHci8LzOnM3N606aG+7NL0sjZt2s7E+NjTc/3ol/eSJFVLm8BvpyZ5wAi4iDwj4FP1Yw5A2wBzlTaMq8C/rrLtUrSQNw1e5z7D5966Sr1ZZeP8R9+6Fttk+qvv/C7T/GNFy6+GdqrfnkjRXrop4CdEbGx0he/CXi6bsxDwI9VXu8BPpOZq67QJWnU3DV7nE/VhDnA899c5o4Hj3HX7PGXju3eMcXRu9/Kr/6La/vSL2+kSA/9SEQcAB4HXgSOAvdFxIeBucx8CPgE8NsRcZKVK/P6VTCSNJIeOHK66bn7D59i+qrvuCiwd++Y6luA1yv0YFFmfgj4UN3hu2vO/w3wni7WJUlDYblFsyFZWbI4qACv55OiktatIrshjkW0DPVePPG5Vu7lImldavR0Z6NNs265YUvL79OvFSxFGOiS1qVGT3cuLq3c7HzTRz7zUrDfs/tqbt3Z+LmZfq5gKcJAl7QutXoQqP5q/Z7dV/OVj7xjoCtYirCHLmldatcbr25xOywrWIow0CWtK9Uboa3CvGqYbngWYaBLKr1qiM8vLBI02JekiWG64VmEgS6p1KqrWao3QDt5hH2YbngW4U1RSaXW6V7lVW/6O98x1P3yRrxCl1Qq9Q8LtVrN0sytO7dyz+6re1Bdbxnokkqjvr3SSc98LIJbbtgykkFeZaBLKo1G7ZWEVaFefT/V5HH/UWWgSyqNZssMq+Hdas+WMjDQJZVGs5751OQEf/LBfzqAivrLVS6SSqPRR8EN234rveQVuqTSqLZR2m2JW1YGuqShV7sUcXLjOJlwfnGpYWAP+34rvWSgSxpq9UsRaz+EuborIrBuQ7yWPXRJQ63dk57VXRFloEsackV2PBy1XRF7xUCXNNSK7Hg4arsi9oqBLmlozR6d5/m/fbHlmPW0LLEdb4pKGqhme5VvHL+MpQvJ0nLznVjGIobuY+AGyUCXNDCt9ip/YelCy987MT5mmNcx0CUNzM8/9NSa9iov26Za3WKgSxqI2aPzLCwutR9YZ73sy7IW3hSVNBBrWTvuDdDW2gZ6RGyPiGM1X89GxB11Y94cEedrxtzdu5IllUGRteOXBVyxcZxg5crcnnlrbVsumXkCuBYgIsaAeeDTDYZ+LjPf2d3yJJVVu4+Hu2LjOB/6Z28wwDvQaQ/9JuDPM/OrvShGUvnUf8Zn9Wbmvl3bL1rhAq5cuVSdBvrNwANNzt0YEU8AZ4Gfzcyn6gdExF5gL8DWrVs7/NGSRs3s0Xn2HXjipbXk8wuL7DvwBOBWt70QmUU+PhUi4nJWwvoNmfmXdedeCVzIzOci4u3ARzPz9a2+3/T0dM7Nza2xbEmjYMeHH71od8SqKzaOc/Tutw6gotEXEY9l5nSjc51cob8NeLw+zAEy89ma1w9HxK9FxJWZ+Uzn5UoaRXfNHueBI6dZzmQsgltu2NIwzIGmx3VpOlm2eAtN2i0R8ZqIiMrr6yvf9+uXXp6kUXDX7HE+dfgUy5V/8S9n8qnDpwZc1fpT6Ao9IjYCPwD8ZM2x2wEycz+wB3hfRLwILAI3Z9FejqSR98CR0x2Nn5wY71El61uhQM/MF4BX1x3bX/P6XuDe7pYmaRg1WrWy3OL6bfyyYOlCXvT+59/1hn6Uuu74pKikwqqbac0vLJJ86yPgVhquq41FMPOea5ianHjp4aCZ91zjSpYecS8XSYU12kxrcWmZjeOXNdwd8ZYbtqzrD23uN6/QJRXSajOtxaUL3LpzK2OVS/WxCG7duZV7dl/dzxLXPa/QJRXSajOtzZMT3LP7agN8wLxCl1RIq8203AFxOBjokgqZGG8cF5ePhT3yIWGgSypk8cXGHwlXuyRRg2WgSyqk2VJzHyEcHga6pELGmiw2b3Zc/WegSyrklhu2dHRc/eeyRUmFVJck1u+o6FLF4VF4P/Rucz90Sepcq/3QbblIUknYcpHWgWaf66lyMdClkqvukFjdVKu6QyJgqJeMgS6VSKMr8ZlDJxrukDhz6ISBXjIGulQSd80e5/7Dp6guc6heideHeVWrvVk0mrwpKpXA7NH5i8K8qlmYw8oOiSoXA10qgZlDJ1aFeSsT42PukFhCtlykEijSPhmL4EKmq1xKzECXSmDz5ATzbUL9QiZf/sg7+lSRBsGWi1QC+3ZtZ2J8rOUYe+bl5xW6VALV9snMoRPMLywScFFP3Z75+mCgSyWxe8fUS8Huk6Hrk4EulVBtuGv9sIcuSSVhoEtSSbQN9IjYHhHHar6ejYg76sZERHwsIk5GxJMRcV3vSpYkNdK2h56ZJ4BrASJiDJgHPl037G3A6ytfNwAfr/wqSeqTTlsuNwF/nplfrTv+buC3csVhYDIiXtuVCiVJhXQa6DcDDzQ4PgWcrnl/pnLsIhGxNyLmImLu3LlzHf5oSVIrhQM9Ii4H3gX8TqPTDY6t2isoM+/LzOnMnN60aVPxKiVJbXVyhf424PHM/MsG584AW2revw44eymFSZI600mg30LjdgvAQ8B7K6tddgLnM/Nrl1ydJKmwQk+KRsRG4AeAn6w5dj
tAZu4HHgbeDpwEXgBu63qlkqSWCgV6Zr4AvLru2P6a1wn8dHdLkyR1widFJakkDHRJKgkDXZJKwkCXpJIw0CWpJAx0SSoJA12SSsJAl6SSMNAlqSQMdEkqCQNdkkrCQJekkjDQJakkDHRJKgkDXZJKwkCXpJIo9AEXUhnNHp1n5tAJzi4ssnlygn27trN7x9Sgy5LWzEDXujR7dJ47Dx5ncWkZgPmFRe48eBzAUNfIMtC1btRekV8WwXLmRecXl5aZOXTCQNfIMtC1Ltw1e5z7D5+iGuH1YV51dmGxf0VJXeZNUZXe7NH5i8K8lc2TEz2vR+oVr9BVSvXtlSJhPjE+xr5d23tem9QrBrpKp/6GZ7P2CsBYBBcyXeWiUjDQVRrVq/L5gn3wAP7LP7/GEFdpGOgqhfqr8nYC+JGdWw1zlYqBrlKYOXSibZjbXlHZGegqhXbLDSfGx/ilH77aEFepFVq2GBGTEXEgIr4YEU9HxI11598cEecj4ljl6+7elCs11mq54dTkhGGudaHoFfpHgUcyc09EXA5sbDDmc5n5zu6VJhW3b9f2VT10r8q13rQN9Ih4JfB9wI8DZOY3gW/2tiypsWYbalVD2822tJ4VuUL/HuAc8JsRcQ3wGPCBzHy+btyNEfEEcBb42cx8qv4bRcReYC/A1q1bL6lwrT/tNtSqDXZpPSrSQ98AXAd8PDN3AM8DH6wb8zhwVWZeA/w3YLbRN8rM+zJzOjOnN23adAllaz1qtJKluqGWpGKBfgY4k5lHKu8PsBLwL8nMZzPzucrrh4HxiLiyq5Vq3Wu2ksUNtaQVbQM9M/8COB0R1U0ubgK+UDsmIl4TEVF5fX3l+369y7VqnWu2ksUNtaQVRXdbfD9wf0Q8CVwL/MeIuD0ibq+c3wN8vtJD/xhwc2aLDTSkNdi3azsT42MXHXNDLelbYlC5Oz09nXNzcwP52Rpdfmyc1ruIeCwzpxud80lRjRRXskjN+QEXklQSBroklYQtFw2EvXCp+wx09V27Jz4lrY0tF/WdT3xKvWGgq+984lPqDQNdfecTn1JvGOjqO5/4lHrDm6LqO/cul3rDQNclqS4/nF9YZCyC5UymCgS0T3xK3Wega83qlx8uV/YFchmiNBj20LVmjZYfVrkMUeo/A11r1m6ZocsQpf4y0LVm7ZYZugxR6i8DXWvWaPlhlcsQpf7zpqjWrHb5YaerXCR1n4GuS+LyQ2l42HKRpJIw0CWpJAx0SSoJA12SSsJAl6SSMNAlqSQMdEkqCQNdkkrCQJekkigU6BExGREHIuKLEfF0RNxYdz4i4mMRcTIinoyI63pTriSpmaKP/n8UeCQz90TE5cDGuvNvA15f+boB+HjlV0lSn7S9Qo+IVwLfB3wCIDO/mZkLdcPeDfxWrjgMTEbEa7terSSpqSItl+8BzgG/GRFHI+I3IuJldWOmgNM1789Ujl0kIvZGxFxEzJ07d27NRUuSVisS6BuA64CPZ+YO4Hngg3VjosHvy1UHMu/LzOnMnN60aVPHxUqSmisS6GeAM5l5pPL+ACsBXz9mS8371wFnL708SVJRbQM9M/8COB0R1Y+fuQn4Qt2wh4D3Vla77ATOZ+bXuluqJKmVoqtc3g/cX1nh8iXgtoi4HSAz9wMPA28HTgIvALf1oFZJUguFAj0zjwHTdYf315xP4Ke7WJckqUM+KSpJJWGgS1JJGOiSVBIGuiSVhIEuSSVhoEtSSRjoklQSBroklYSBLkklYaBLUkkU3ctFBcwenWfm0AnOLiyyeXKCfbu2s3vHqm3hJaknDPQumT06z50Hj7O4tAzA/MIidx48DmCoS+oLWy5dMnPoxEthXrW4tMzMoRMDqkjSemOgd8nZhcWOjktStxnoXbJ5cqKj45LUbQZ6l+zbtZ2J8bGLjk2Mj7Fv1/Ymv0OSusubol1SvfHpKhdJg2Kgd9HuHVMGuKSBseUiSSVhoEtSSRjoklQSBroklYQ3RQtwjxZJo8BAb8M9WiSNClsubbhHi6RRYaC34R4tkkaFgd6Ge7RIGhWFAj0ivhIRxyPiWETMNTj/5og4Xzl/LCLu7n6pg+EeLZJGRSc3Rb8/M59pcf5zmfnOSy1o2LhHi6RR4SqXAtyjRdIoKNpDT+DRiHgsIvY2GXNjRDwREb8fEW9oNCAi9kbEXETMnTt3bk0FS5IaK3qF/qbMPBsR3wn8QUR8MTP/uOb848BVmflcRLwdmAVeX/9NMvM+4D6A6enpvMTaJUk1Cl2hZ+bZyq9/BXwauL7u/LOZ+Vzl9cPAeERc2eVaJUkttA30iHhZRLyi+hp4K/D5ujGviYiovL6+8n2/3v1yJUnNFGm5fBfw6UpebwD+R2Y+EhG3A2TmfmAP8L6IeBFYBG7OTFsqktRHbQM9M78EXNPg+P6a1/cC93a3NElSJ3xSVJJKwkCXpJIY2QeL3KNcki42koHuHuWStNpItlzco1ySVhvJQHePcklabSQD3T3KJWm1kQx09yiXpNVG8qaoe5RL0mojGejgHuWSVG8kWy6SpNUMdEkqCQNdkkrCQJekkjDQJakkDHRJKokY1AcLRcQ54KsD+eGX5krgmUEX0QVlmQeUZy5lmQeUZy7DOI+rMnNToxMDC/RRFRFzmTk96DouVVnmAeWZS1nmAeWZy6jNw5aLJJWEgS5JJWGgd+6+QRfQJWWZB5RnLmWZB5RnLiM1D3voklQSXqFLUkkY6JJUEgZ6AxGxJSL+MCKejoinIuIDLcb+o4hYjog9/ayxiKLziIg3R8Sxypg/6nedRRSZS0S8KiJ+NyKeqIy5bRC1thIR3x4Rf1pT4y80GPNtEfFgRJyMiCMRsa3/lbZXcC7/OiK+EBFPRsT/jYirBlFrK0XmUTN2T0RkRAznUsbM9KvuC3gtcF3l9SuAPwP+QYNxY8BngIeBPYOuey3zACaBLwBbK++/c9B1X8Jc/i3wy5XXm4C/Bi4fdO11NQbw8srrceAIsLNuzE8B+yuvbwYeHHTdlzCX7wc2Vl6/bxjnUmQeNf/f/TFwGJgedN2NvrxCbyAzv5aZj1de/z/gaaDRp2m8H/jfwF/1sbzCCs7jXwIHM/NUZdwozyWBV0REAC9nJdBf7GuhbeSK5ypvxytf9SsT3g18svL6AHBTZU5DpchcMvMPM/OFytvDwOv6WGIhBf+bAPwi8J+Av+lXbZ0y0Nuo/HN3Byt/a9cenwJ+CNjf/6o612wewPcCV0TEZyPisYh4b79r61SLudwL/H3gLHAc+EBmXuhrcQVExFhEHGPlQuAPMrN+HlPAaYDMfBE4D7y6v1UWU2AutX4C+P3+VNaZdvOIiB3Alsz8vYEUWJCB3kJEvJyVK/A7MvPZutO/CvxcZi73v7LOtJnHBuAfAu8AdgH/PiK+t88lFtZmLruAY8Bm4Frg3oh4ZZ9LbCszlzPzWlauVq+PiDfWDWl0NT6U64sLzAWAiLgVmAZm+llfUa3mERGXAb8C/JtB1VeUgd5ERIyzEhz3Z+bBBkOmgf8ZEV8B9gC/FhG7+1hiIQXmcQZ4JDOfz8xnW
OkRXtPPGosqMJfbWGkfZWaeBL4M/L1+1tiJzFwAPgv8YN2pM8AWgIjYALyKlfbR0GoxFyLiLcC/A96VmX/b59I60mQerwDeCHy28ud9J/DQMN4YNdAbqPQrPwE8nZn/tdGYzPzuzNyWmdtY6XP+VGbO9rHMtorMA/g/wD+JiA0RsRG4gZX+9FApOJdTwE2V8d8FbAe+1J8Ki4mITRExWXk9AbwF+GLdsIeAH6u83gN8Jit35YZJkblUWhW/zkqYD+X9mXbzyMzzmXllzZ/3w6zMZ24gBbewYdAFDKk3AT8KHK/01WBlBcVWgMwcib45BeaRmU9HxCPAk8AF4Dcy8/MDqba1Iv9NfhH47xFxnJW2xc9V/tUxTF4LfDIixli5oPpfmfl7EfFhYC4zH2LlL67fjoiTrFyZ3zy4clsqMpcZVm5Q/07lvu6pzHzXwCpurMg8RoKP/ktSSdhykaSSMNAlqSQMdEkqCQNdkkrCQJekkjDQJakkDHRJKon/D8XLOB7FzJCYAAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "plt.scatter(np.log10(df.classes),df.loss,)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### The solution\n", - "\n", - "Assume we have certain function, will produce a weight that will bring each cross entropy loss to the same constant scale ${a}$.\n", - "\n", - "$L_{i}.f(c_{i})=a$\n", - "\n", - "$f(c_{i})a.log_{10}(c_{i})=a$\n", - "\n", - "Here we can get how to calculate $f(c_{i})$\n", - "\n", - "$f(c_{i})=\\frac{1}{log_{10}(c_{i})}$" - ] - }, - { - "cell_type": "code", - "execution_count": 270, - "metadata": {}, - "outputs": [], - "source": [ - "def adjust(nb_class):\n", - " return 1/np.log10(nb_class)" - ] - }, - { - "cell_type": "code", - "execution_count": 271, - "metadata": {}, - "outputs": [], - "source": [ - "df[\"lambda_weighted\"] = df.apply(\n", - " lambda row:row['loss']*adjust(row['classes']),\n", - " axis=1)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "For now it's a about the same scale of loss" - ] - }, - { - "cell_type": "code", - "execution_count": 272, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 272, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAa3ElEQVR4nO3df4ycx33f8fcn8lG+iHQpmZdUOpOhTCWinYoQmWuQlrJRq4Ho2EBIWUbotpDkKADbugWkQiZAK0FqVwgslrXgFi3CMlAQKyEcKxbJ0mVTmqCYsHIrxkfyTIo8MZJ/KBJJWGfJNMX66hypb//YOWu13r199m5/PDv3eQGH25uZ3ZvZ59nvPs/MPM8oIjAzs3z9VK8rYGZmneVAb2aWOQd6M7PMOdCbmWXOgd7MLHNv63UFai1ZsiSWL1/e62qYmfWVo0ePfi8ihurllS7QL1++nNHR0V5Xw8ysr0h6sVGeu27MzDLnQG9mljkHejOzzDnQm5llzoHezCxzDvRmZplzoDczy5wDvZlZ5poGeklLJR2SNC7plKT765RZL+mEpDFJo5JuS+kfSGnTP/9P0oZONMTMzOorcmXsZeDBiDgmaRFwVNKBiDhdVeYgsDciQtIq4AlgZUQcAm4FkHQd8ALw1fY2wczMZtL0iD4izkfEsfT4dWAcGK4pcyneXKrqGqDeslUfBf48In44tyqbmVkrWuqjl7QcWA0cqZN3p6TngH3AfXWe/jHgiw1ed1Pq8hmdmJhopUpmZtZE4UAvaSHwJPBARFyszY+I3RGxEtgAPFzz3OuBW4D99V47InZExEhEjAwN1b35mpmZzVKhQC9pgEqQ3xkRu2YqGxGHgRWSllQl/wawOyKmZl1TMzOblSKzbgQ8BoxHxKMNytyUyiFpDbAAeLWqyD+hQbeNmZl1VpFZN2uBu4GTksZS2kPAMoCI2A7cBdwjaQqYBDZOD86mfv2lwF+2teZmZlZI00AfEU8DalJmK7C1Qd53qJmlY2Zm3eMrY83MMudAb2aWOQd6M7PMOdCbmWXOgd7MLHMO9GZmmXOgNzPLnAO9mVnmHOjNzDLnQG9mljkHejOzzDnQm5llzoHezCxzDvRmZplzoDczy5wDvZlZ5hzozcwyV2TN2KWSDkkal3RK0v11yqyXdELSmKRRSbdV5S2T9NX0/NNpaUEzM+uSImvGXgYejIhjkhYBRyUdiIjTVWUOAnsjIiStAp4AVqa8x4Hfi4gDkhYCb7SzAWZmNrOmR/QRcT4ijqXHrwPj1KwBGxGXphcDB64BphcGfy/wtog4UFXuh22sv5mZNdFSH33qdlkNHKmTd6ek54B9wH0p+ReAC5J2STouaZukq+o8d1Pq8hmdmJhotQ1mZjaDwoE+dbs8CTwQERdr8yNid0SsBDYAD6fktwHvAz4J/H3g3cDH6zx3R0SMRMTI0NBQy40wM7PGCgV6SQNUgvzOiNg1U9mIOAyskLQEeBk4HhHfiojLwB5gzRzrbGZmLSgy60bAY8B4RDzaoMxNqRyS1gALgFeBrwPXSpo+TL8dOF3vNczMrDOKzLpZC9wNnJQ0ltIeApYBRMR24C7gHklTwCSwMQ3OXpH0SeBg+iI4CvxBm9tQyJ7jZ9m2/wznLkxyw+JBNq+7mQ2rh5s/0cysz+nNyTLlMDIyEqOjo219zT3Hz/KpXSeZnLry47TBgav47EducbA3syxIOhoRI/Xy5sWVsdv2n3lLkAeYnLrCtv1nelQjM7PumReB/tyFyZbSzcxyMi8C/Q2LB1tKNzPLybwI9JvX3czgwFuv0xocuIrN627uUY3MzLqnyKybvjc94OpZN2Y2H82LQA+VYO/Abmbz0bzoujEzm88c6M3MMudAb2aWOQd6M7PMOdCbmWXOgd7MLHMO9GZmmXOgNzPLnAO9mVnmHOjNzDLnQG9mlrkia8YulXRI0rikU5Lur1NmvaQTksYkjUq6rSrvSkofk7S33Q0wM7OZFbmp2WXgwYg4JmkRcFTSgYioXuT7ILA3IkLSKuAJYGXKm4yIW9tbbTMzK6rpEX1EnI+IY+nx68A4MFxT5lK8ufjsNUC5FqI1M5vHWuqjl7QcWA0cqZN3p6TngH3AfVVZb0/dOc9I2tDgdTelMqMTExOtVMnMzJooHOglLQSeBB6IiIu1+R
M8ucu27MzDLnQG9mljkHejOzzDnQm5llzoHezCxzDvRmZplzoDczy9z/B4iZzGiJC+rEAAAAAElFTkSuQmCC\n", - "text/plain": [ - "
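Before the notebook formalizes the adjustment in the next section, it may help to see the raw factor it will apply. A minimal sketch (an editorial addition, not a cell from the deleted notebook; it assumes only numpy): the factor lambda(C) = 1/log10(C) shrinks slowly as the class count grows.

```python
# Editorial sketch: the adjustment factor used in the next section is
# lambda(C) = 1 / log10(C), where C is the number of classes.
import numpy as np

for c in (10, 100, 1000, 3000):
    print(f"classes={c:>4}  lambda={1 / np.log10(c):.4f}")

# classes=  10  lambda=1.0000
# classes= 100  lambda=0.5000
# classes=1000  lambda=0.3333
# classes=3000  lambda=0.2876
```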
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "plt.scatter(df.classes,df.lambda_weighted)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Adjusted CrossEntropy" - ] - }, - { - "cell_type": "code", - "execution_count": 288, - "metadata": {}, - "outputs": [], - "source": [ - "# export\n", - "class MultiTaskCELoss(nn.Module):\n", - " \"\"\"\n", - " A cross entropy loss function which will cancel out\n", - " the effect of different class numbers\n", - " \"\"\"\n", - " def __init__(self,):\n", - " super().__init__()\n", - " self.celoss = nn.CrossEntropyLoss()\n", - " \n", - " def forward(\n", - " self,\n", - " y_pred: torch.FloatTensor,\n", - " y_true: torch.LongTensor,\n", - " )-> torch.FloatTensor:\n", - " \"\"\"\n", - " Input:\n", - " - y_pred: torch.FloatTensor, Prediction tensor\n", - " - y_true: torch.LongTensor, Label indices\n", - " Return:\n", - " - loss: torch.FloatTensor, scala adjusted\n", - " \"\"\"\n", - " nb_classes = y_pred.size(-1)\n", - " lambda_ = 1/np.log10(nb_classes)\n", - " loss = self.celoss(y_pred,y_true)\n", - " return loss*lambda_" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's make this adjustment into the loss function, an upgraded version of CrossEntropy" - ] - }, - { - "cell_type": "code", - "execution_count": 289, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e869eca73d324737b011b32aebdeab43", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], - "source": [ - "result = []\n", - "\n", - "for i in tqdm(range(50)):\n", - " c = random.randint(2,3000)\n", - " b = 1-random.random()\n", - " # here we change the loss function to MultiTaskCELoss\n", - " loss = create_softmax_pipeline(1,c,crit=MultiTaskCELoss())\\\n", - " .test_loss(300,inbalanced_input(b))\n", - " result.append(dict(classes=c,loss=loss.item(),inbl_bias=b))" - ] - }, - { - "cell_type": "code", - "execution_count": 283, - "metadata": {}, - "outputs": [], - "source": [ - "df_adjusted = pd.DataFrame(result)" - ] - }, - { - "cell_type": "code", - "execution_count": 284, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
classeslossinbl_bias
05332.3192480.208584
122642.3263460.022896
22542.3207350.922445
310272.3282910.643444
424282.3261360.883795
57722.3170710.062716
621662.3157540.364122
78122.3219720.852977
826322.3217080.639467
925762.3132690.968200
1013262.3215150.508901
111692.3215260.870789
128152.3091630.262275
137922.3217710.608608
147392.3162150.198933
1524602.3185230.880343
1624202.3205170.781538
1716882.3166570.954598
1829292.3230090.750869
19782.1927230.089868
2016702.3256080.188055
2128682.3245250.058037
229052.3192830.126673
2316752.3095650.307769
2414302.3270990.451335
254242.3158440.512538
2615112.3163250.613963
279592.3303170.987407
281472.3254590.931882
292782.3172730.401440
308852.3253970.087088
3114592.3079710.857782
321822.3447160.443465
3324682.3268250.572978
349542.3227500.815261
3512972.3199340.887149
3617872.3261610.113759
374622.3252760.969307
386862.3242620.305557
394282.3239130.342770
4029002.3193490.609161
4127652.3233730.296973
427522.3201890.302571
435022.3258970.419716
4421622.3204490.202672
4524602.3179620.430931
4625372.3121160.693554
478062.3187460.539312
482512.3228450.181191
4928452.3260070.667570
\n", - "
" - ], - "text/plain": [ - " classes loss inbl_bias\n", - "0 533 2.319248 0.208584\n", - "1 2264 2.326346 0.022896\n", - "2 254 2.320735 0.922445\n", - "3 1027 2.328291 0.643444\n", - "4 2428 2.326136 0.883795\n", - "5 772 2.317071 0.062716\n", - "6 2166 2.315754 0.364122\n", - "7 812 2.321972 0.852977\n", - "8 2632 2.321708 0.639467\n", - "9 2576 2.313269 0.968200\n", - "10 1326 2.321515 0.508901\n", - "11 169 2.321526 0.870789\n", - "12 815 2.309163 0.262275\n", - "13 792 2.321771 0.608608\n", - "14 739 2.316215 0.198933\n", - "15 2460 2.318523 0.880343\n", - "16 2420 2.320517 0.781538\n", - "17 1688 2.316657 0.954598\n", - "18 2929 2.323009 0.750869\n", - "19 78 2.192723 0.089868\n", - "20 1670 2.325608 0.188055\n", - "21 2868 2.324525 0.058037\n", - "22 905 2.319283 0.126673\n", - "23 1675 2.309565 0.307769\n", - "24 1430 2.327099 0.451335\n", - "25 424 2.315844 0.512538\n", - "26 1511 2.316325 0.613963\n", - "27 959 2.330317 0.987407\n", - "28 147 2.325459 0.931882\n", - "29 278 2.317273 0.401440\n", - "30 885 2.325397 0.087088\n", - "31 1459 2.307971 0.857782\n", - "32 182 2.344716 0.443465\n", - "33 2468 2.326825 0.572978\n", - "34 954 2.322750 0.815261\n", - "35 1297 2.319934 0.887149\n", - "36 1787 2.326161 0.113759\n", - "37 462 2.325276 0.969307\n", - "38 686 2.324262 0.305557\n", - "39 428 2.323913 0.342770\n", - "40 2900 2.319349 0.609161\n", - "41 2765 2.323373 0.296973\n", - "42 752 2.320189 0.302571\n", - "43 502 2.325897 0.419716\n", - "44 2162 2.320449 0.202672\n", - "45 2460 2.317962 0.430931\n", - "46 2537 2.312116 0.693554\n", - "47 806 2.318746 0.539312\n", - "48 251 2.322845 0.181191\n", - "49 2845 2.326007 0.667570" - ] - }, - "execution_count": 284, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df_adjusted" - ] - }, - { - "cell_type": "code", - "execution_count": 287, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 287, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYUAAAD4CAYAAAAD6PrjAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAclklEQVR4nO3df5Ac5X3n8fcHaQULwpGAtSP044QxJcAVkLg9TJVcpsAXBFzFErHvjM8lSExKvhyuQy5ZhSBXmDuSChgbcqmK0cngBCeKAYP4UcVxigqUomyC4tUPJMRaIAyOkXSwgGTBobO10vf+mGea0Wh+9MzO7sysPq+qqe15+ume59nu6W/38zw9rYjAzMwM4Lh2F8DMzDqHg4KZmWUcFMzMLOOgYGZmGQcFMzPLTGx3ARpx2mmnxezZs9tdDDOzrrJx48a3I6IvT96uCgqzZ89mYGCg3cUwM+sqkn6RN6+bj8zMLOOgYGZmGQcFMzPLOCiYmVnGQcHMzDIOCmZmlnFQMDOzjIOCmZllHBTMzCzjoGBmZhkHBTMzyzgomJlZxkHBzMwyDgpmZpapGxQkzZS0XtKgpO2SbqiQZ6GkrZK2SBqQ9OmSeYdS+hZJT5SknyFpg6RXJD0oaVLrqmVmZs3Ic6UwDCyLiHOAi4DrJZ1bludp4PyImAt8Bbi3ZN6BiJibXp8rSb8DuDsizgL2Atc1XQszM2uJukEhIvZExKY0/R4wCEwvy/N+RER6exIQ1CBJwKXAwynpfmBRY0U3M7NWa6hPQdJsYB6wocK8qyT9DHiSwtVC0QmpSel5ScUD/6nAvogYTu/foCzQmJnZ2MsdFCRNBh4BlkbE/vL5EfFoRJxN4Yz/tpJZsyKiH/iPwF9IOhNQhY+oeHUhaUkKKgNDQ0N5i2tmZk3IFRQk9VAICKsjYk2tvBHxLHCmpNPS+93p78+Bf6RwpfE2MEVS8RnRM4DdVda3KiL6I6K/ry/Xc6fNzKxJeUYfCbgPGIyIu6rk+UTKh6QLgEnAO5KmSjo+pZ8GzAdeSv0P64EvpFVcCzw+0sqM1GObdzH/9mc4Y8WTzL/9GR7bvKvdRTIzG1MT62dhPrAY2CZpS0q7GZgFEBErgc8D10g6CBwAvhgRIekc4H9KOkwhAN0eES+lddwIPCDpT4HNFAJP2zy2eRc3rdnGgYOHANi17wA3rdkGwKJ57u4ws2ODPhw01Pn6+/tjYGBgVNY9//Zn2LXvwFHp06f08pMVl47KZ5qZjQVJG1Pfbl2+oznZXSEg1Eo3MxuPHBSS06f0NpRuZjYeOSgkyxfMobdnwhFpvT0TWL5gTptKZGY29vJ0NB8Tip3Jd67dwe59Bzh9Si/LF8xxJ7OZHVMcFEosmjfdQcDMjmluPjIzs4yDgpmZZRwUzMws46BgZmYZBwUzM8s4KJiZWcZBwczMMg4KZmaW8c1rNioe27zLd4ebdSEHBWs5P5vCrHu5+cha7s61O7KAUHTg4CHuXLujTSUys7x8pWAtV+/ZFG5aMutcDgrWcqdP6a34FLvTp/R2ZdOSg1h38nZrjoNCB+r2nXn5gjlHHPjhw2dT1Gpa6sQ6jlUQa2abd/t+UqoVdSldx2/19vB/fzPMwUOFxw136slHJ27DukFB0kzgB8BvA4eBVRHxP8ryLARuS/OHgaUR8WNJc4F7gI8Ah4A/i4gH0zJ/A1wM/Cqt5g8iYksrKtWITtgo421nrvVsiq8/WHkTd+pjT8ciiDUTeLrxiquaVtSlfB37Dhw8Kk+nnXx06jbMc6UwDCyLiE2STgY2SloXES+V5HkaeCIiQtJ5wEPA2cAHwDUR8Yqk09OyayNiX1pueUQ83ML6NKSdG6V4UN217wACIqWPxs7c6sCX5/9W7dkUtZqWOtFYPLu7mcDTbVdctbSiLpXWUUknnXxUq/eyh14A2hcY6o4+iog9EbEpTb8HDALTy/K8HxHF49pJpGNcRLwcEa+k6d3AW0Bf64o/Mu0aJVM8qBYPjlEnPzS/M5d+VvDhAfyxzbuaWh+M7P/WbY89HYtndzcTeMYiWI2VVtQlb96xPPl4bPMu5t/+DGeseJL5tz9z1HeuWpkPRRzxHa23nlZraEiqpNnAPGBDhXlXSfoZ8CTwlQrzLwQmAa+WJP+ZpK2S7pZ0fJXPXCJpQNLA0NBQI8Wtq9GdsVUbJ+9ZTalmd+aRHMCr1TfP/63asovmTefPf/93mD6lFwHTp/Ty57//Ox17djsWQayZwDMWwWqstKIuefKO5clHnpOxWmUufkdH46SuntxBQdJk4BEK/QX7y+dHxKMRcTawiEL/Qumy04C/Bf4wIg6n5JsoNDH9G+AU4MZKnxsRqyKiPyL6+/pae5FRbaMcJx31T2/lxmn0bG4kO3OzgW/2iif5+oNbKta33pe43v9q0bzp/GTFpbx2+7/jJysubUsfTt7gPhZBrJnA021XXLW0oi6V1tFznJh6Ys+It1szJ4N5TsYqlbnU7n0H2tKakWv0kaQeCgFhdUSsqZU3Ip6VdKak0yLibUkfoXD18F8j4vmSfHvS5K8l/TXwjeaq0LxKo2Tgw8s3OLLTtFVtuNXa1Yt6jhOTT5jIvg8OjrgPoF4bfml/w5QTe3j//w1z8HChQau8WatY31qji6Cz27ub6Uca7Wd31+qYb+UyYy1vX1ar6nJCz3HZdp3S28Otn/vkiP8fzfY75jkZKy6/7KEXOBRHNyKfPqW3Lc2EeUYfCbgPGIyIu6rk+QTwaupovoBCM9E7kiYBjwI/iIgflS0zLSL2pPUvAl4cYV0aVmujlB/Eqh3Eax3cq6l0UC12Nk+v8IUonqk084WpdQAv3+H3fnB0J3e53fsOHPEl3rXvABOkI85eOrG9u7Rjv1wnBKxmAs9oB6uRaPRgOpK6lH8WwK+HD9dYIr9mT3DyDqgorqPS8aD43aoWMEZLniuF+cBiYJuk4njCm4FZABGxEvg8cI2kg8AB4IspQPwH4DPAqZL+IC1bHHq6WlIfhfpvAf5Ti+rUkLzDJKttnAlSU58J+c6MRjpCqtZnzb/9mab7NirtzMWyTTmxp2KAaVd7d6WDRrlu7KDtZGN5tVhrFM/XH9wyoquoZk9w6l1Nlyo/ySodjVjpmDPazYR1g0JE/JjCgbtWnjuAOyqk/x3wd1WWuTRnGUddnqheaePUSq8n75lRK75c1T6r0QNhz3E6YmesVrbjJx5Hb8+EXF+IsZCnY78bO2g72VheLdYaxQOFk5XlP2pumGezQ6gbbRIrfkfn3/5Mxc+bIHE4YkyaCf2DeOTr6JpeYycYzWFio/nlavRAOPmEiUfsjNXK8KsDBztqhFG9/1UnddCO9fDD0TKWo6PyrPPg4eDWJ7Y3vO6RdII3M6Ci2r56OGLMBmb4Zy7IF9WrdUpD4006jdxMNpo3e11ydh+rn/+XXPdJAOwraxKqVbZOau+u1bFfqQ+nUa26ObBT73BtxvIFc1j+8AvZnfkAPRM0KsG31nezVKUbQ+sZ6w79Tri500EhqXcQK2/3K5e3Sa
fRL34jbZONeGzzLh7ZuCt3QICjd8zRKlurVStnK65eWnkg7+RRW00p37maa2mtq/zA3eqPGcsTnE74Trn5qAHFy8FqHSx5mnQaHXc8WuPkG72BrtKO2S03oo1mOVs5jrwTR2016861O7KhzUUHD8eoja8vbaqZemJPxTzV0jtJJ3ynfKXQhJFc4jXzxR+NM5U8B5o8nVud1ExUy2iVs5UH8k5oOmiVdga4b/7eJys2XX3z9z5Zc7lO+HFMaP93ykGhCSO5xOuUL369G+ha1bwy3rVye3ZC00GrtHM/z9sPUOvGzW7uzxkpNx81YSSXeJ3y8wSVylFsFuvUZqBO1Mrt2QlNB63S7v283sif8p9i2fvBwaOau47VR8j6SqFJzV7idcrPE3RKObpdq/+P7W46aJVO37+68ae2x4qiyZuv2qG/vz8GBgbaXQyzUdEpbdrHgjNWPJlrlNL0Kb38ZEXH3GfbNEkbI6I/T15fKVhVPkiNnfF0j0I3qNenBt3bnzNS7lOwitrxO+7dYLTuOG7XA5+OVRV/anuCmNI78p/a7na+UrCKxt2NVC0wmmfz4+kehW7Q6X0e7eSgYBX5IHW00QyUnTJU+VgyXjr1W83NR1bReHrcY6uMZqBs9xBOsyIHBavIB6mjjWagHE/3KFh3c/ORVeQ216ON9h3Hbs6wTuCgYFX5IHUkB0o7FjgomDXAgdLGu7p9CpJmSlovaVDSdkk3VMizUNJWSVskDUj6dMm8ayW9kl7XlqT/a0nbJO2U9JdSEw87NjOzlsrT0TwMLIuIc4CLgOslnVuW52ng/IiYC3wFuBdA0inAN4FPARcC35Q0NS1zD7AEOCu9Lh9hXczMbITqBoWI2BMRm9L0e8AgML0sz/vx4Y8oncSHz1haAKyLiHcjYi+wDrhc0jTgIxHxT2m5HwCLWlIjMzNrWkNDUiXNBuYBGyrMu0rSz4AnKVwtQCF4/LIk2xspbXqaLk83M7M2yh0UJE0GHgGWRsT+8vkR8WhEnE3hjP+24mIVVhU10it97pLUTzEwNDSUt7hmZtaEXEFBUg+FgLA6ItbUyhsRzwJnSjqNwhXAzJLZM4DdKX1GhfRK61sVEf0R0d/X15enuGZm1qQ8o48E3AcMRsRdVfJ8ojh6SNIFwCTgHWAtcJmkqamD+TJgbUTsAd6TdFFa7hrg8ZbUyMzMmpbnPoX5wGJgm6QtKe1mYBZARKwEPg9cI+kgcAD4YupAflfSbcBP03L/PSLeTdN/DPwN0As8lV5mZtZGfvKamdk418iT1/yDeGZmlnFQMDOzjIOCmZllHBTMzCzjoGBmZhkHBTMzyzgomJlZxkHBzMwyDgpmZpZxUDAzs4yDgpmZZRwUzMws46BgZmYZBwUzM8s4KJiZWcZBwczMMg4KZmaWcVAwM7OMg4KZmWXqBgVJMyWtlzQoabukGyrk+bKkren1nKTzU/ocSVtKXvslLU3zbpW0q2Tela2vnpmZNWJijjzDwLKI2CTpZGCjpHUR8VJJnteAiyNir6QrgFXApyJiBzAXQNIEYBfwaMlyd0fEt1tSEzMzG7G6QSEi9gB70vR7kgaB6cBLJXmeK1nkeWBGhVV9Fng1In4xohKbmdmoaahPQdJsYB6woUa264CnKqRfDfywLO1rqcnp+5KmVvnMJZIGJA0MDQ01UlwzM2tQ7qAgaTLwCLA0IvZXyXMJhaBwY1n6JOBzwI9Kku8BzqTQvLQH+E6ldUbEqojoj4j+vr6+vMU1M7Mm5AoKknooBITVEbGmSp7zgHuBhRHxTtnsK4BNEfFmMSEi3oyIQxFxGPgecGEzFTAzs9bJM/pIwH3AYETcVSXPLGANsDgiXq6Q5UuUNR1Jmlby9irgxbyFNjOz0ZFn9NF8YDGwTdKWlHYzMAsgIlYCtwCnAt8txBCGI6IfQNKJwO8CXy1b77ckzQUCeL3CfDMzG2N5Rh/9GFCdPH8E/FGVeR9QCBjl6YtzltHMzMaI72g2M7OMg4KZmWUcFMzMLOOgYGZmGQcFMzPLOCiYmVnGQcHMzDIOCmZmlnFQMDOzjIOCmZllHBTMzCzjoGBmZhkHBTMzyzgomJlZxkHBzMwyDgpmZpZxUDAzs4yDgpmZZRwUzMwsUzcoSJopab2kQUnbJd1QIc+XJW1Nr+cknV8y73VJ2yRtkTRQkn6KpHWSXkl/p7auWmZm1ow8VwrDwLKIOAe4CLhe0rlleV4DLo6I84DbgFVl8y+JiLkR0V+StgJ4OiLOAp5O783MrI3qBoWI2BMRm9L0e8AgML0sz3MRsTe9fR6YkeOzFwL3p+n7gUV5C21mZqOjoT4FSbOBecCGGtmuA54qeR/AP0jaKGlJSfrHImIPFAIP8NEqn7lE0oCkgaGhoUaKa2ZmDZqYN6OkycAjwNKI2F8lzyUUgsKnS5LnR8RuSR8F1kn6WUQ8m/dzI2IVqTmqv78/8i5nZmaNy3WlIKmHQkBYHRFrquQ5D7gXWBgR7xTTI2J3+vsW8ChwYZr1pqRpadlpwFvNVsLMzFojz+gjAfcBgxFxV5U8s4A1wOKIeLkk/SRJJxengcuAF9PsJ4Br0/S1wOPNVsLMzFojT/PRfGAxsE3SlpR2MzALICJWArcApwLfLcQQhtNIo48Bj6a0icDfR8T/Tuu4HXhI0nXAvwD/viU1MjOzpimie5rp+/v7Y2BgoH5GMzPLSNpYdktAVb6j2czMMg4KZmaWcVAwM7OMg4KZmWUcFMzMLOOgYGZmGQcFMzPLOCiYmVnGQcHMzDIOCmZmlnFQMDOzjIOCmZllHBTMzCzjoGBmZhkHBTMzyzgomJlZxkHBzMwyDgpmZpapGxQkzZS0XtKgpO2SbqiQ58uStqbXc5LOr7espFsl7ZK0Jb2ubG3VzMysURNz5BkGlkXEJkknAxslrYuIl0ryvAZcHBF7JV0BrAI+lWPZuyPi2y2sj5mZjUDdK4WI2BMRm9L0e8AgML0sz3MRsTe9fR6YkXdZMzPrHA31KUiaDcwDNtTIdh3wVM5lv5aanL4vaWqVz1wiaUDSwNDQUCPFNTOzBuUOCpImA48ASyNif5U8l1AICjfmWPYe4ExgLrAH+E6ldUbEqojoj4j+vr6+vMU1M7Mm5AoKknooHNRXR8SaKnnOA+4FFkbEO/WWjYg3I+JQRBwGvgdc2Hw1zMysFfKMPhJwHzAYEXdVyTMLWAMsjoiX8ywraVrJ26uAFxsvvpmZtVKe0UfzgcXANklbUtrNwCyAiFgJ3AKcCny3EAcYjoj+astGxP8CviVpLhDA68BXW1IjMzNrmiKi3WXIrb+/PwYGBtpdDDOzriJpYzpRr8t3NJuZWcZBwczMMg4KZmaWcVAwM7OMg4KZmWUcFMzMLOOgYGZmGQcFMzPLOCiYmVnGQcHMzDIOCmZmlnFQMDOzjIOCmZllHBTMzCzjoGBmZhkHBTMzyzgomJlZxkHBzMwyDgpmZpapGxQkzZS0XtKgpO2SbqiQ58uStqbXc5LOL5l3uaQdknZKWlGSfoakDZJekfSgpEmtq5aZmTUjz5XCMLAsI
s4BLgKul3RuWZ7XgIsj4jzgNmAVgKQJwF8BVwDnAl8qWfYO4O6IOAvYC1w30sqYmdnI1A0KEbEnIjal6feAQWB6WZ7nImJvevs8MCNNXwjsjIifR8RvgAeAhZIEXAo8nPLdDywaaWXMzGxkGupTkDQbmAdsqJHtOuCpND0d+GXJvDdS2qnAvogYLkuv9JlLJA1IGhgaGmqkuGZm1qDcQUHSZOARYGlE7K+S5xIKQeHGYlKFbFEj/ejEiFUR0R8R/X19fXmLa2ZmTcgVFCT1UAgIqyNiTZU85wH3Agsj4p2U/AYwsyTbDGA38DYwRdLEsnQzM2ujPKOPBNwHDEbEXVXyzALWAIsj4uWSWT8FzkojjSYBVwNPREQA64EvpHzXAo83Xw0zM2uFifWzMB9YDGyTtCWl3QzMAoiIlcAtFPoJvluIIQynJp9hSV8D1gITgO9HxPa0jhuBByT9KbCZQuAxM7M2UuGkvTv09/fHwMBAQ8s8tnkXd67dwe59Bzh9Si/LF8xh0byKfdpmZuOSpI0R0Z8nb54rha712OZd3LRmGwcOHgJg174D3LRmG4ADg5lZBeP6Zy7uXLsjCwhFBw4e4s61O9pUIjOzzjaug8LufQcaSjczO9aN66Bw+pTehtLNzI514zooLF8wh96eCUek9fZMYPmCOW0qkZlZZxvXHc3FzmSPPjIzy2dcBwUoBAYHATOzfMZ185GZmTXGQcHMzDIOCmZmlnFQMDOzjIOCmZllHBTMzCzjoGBmZhkHBTMzyzgomJlZxkHBzMwyDgpmZpapGxQkzZS0XtKgpO2SbqiQ52xJ/yTp15K+UZI+R9KWktd+SUvTvFsl7SqZd2Vrq2ZmZo3K84N4w8CyiNgk6WRgo6R1EfFSSZ53gf8CLCpdMCJ2AHMBJE0AdgGPlmS5OyK+PZIKmJlZ69S9UoiIPRGxKU2/BwwC08vyvBURPwUO1ljVZ4FXI+IXIyivmZmNoob6FCTNBuYBG5r4rKuBH5alfU3SVknflzS1ymcukTQgaWBoaKiJjzUzs7xyBwVJk4FHgKURsb+RD5E0Cfgc8KOS5HuAMyk0L+0BvlNp2YhYFRH9EdHf19fXyMeamVmDcgUFST0UAsLqiFjTxOdcAWyKiDeLCRHxZkQciojDwPeAC5tYr5mZtVCe0UcC7gMGI+KuJj/nS5Q1HUmaVvL2KuDFJtdtZmYtkmf00XxgMbBN0paUdjMwCyAiVkr6bWAA+AhwOA07PTci9ks6Efhd4Ktl6/2WpLlAAK9XmG9mZmOsblCIiB8DqpPn/wAzqsz7ADi1QvrinGU0M7MxoohodxlykzQEVBvSehrw9hgWZyyMxzrB+KyX69Q9xmO96tXpX0VErpE6XRUUapE0EBH97S5HK43HOsH4rJfr1D3GY71aWSf/9pGZmWUcFMzMLDOegsKqdhdgFIzHOsH4rJfr1D3GY71aVqdx06dgZmYjN56uFMzMbIQcFMzMLNP1QUHS5ZJ2SNopaUW7y9MoSa9L2pYeNDSQ0k6RtE7SK+nv1JQuSX+Z6rpV0gXtLX1B+pXbtyS9WJLWcB0kXZvyvyLp2nbUpaQslepU9cFQkm5KddohaUFJekftn9UemtXN26tGnbp2e0k6QdI/S3oh1em/pfQzJG1I//MH04+NIun49H5nmj+7ZF0V61pVRHTtC5gAvAp8HJgEvEDh5zXaXrYG6vA6cFpZ2reAFWl6BXBHmr4SeIrCHeYXARvaXf5Urs8AFwAvNlsH4BTg5+nv1DQ9tcPqdCvwjQp5z0373vHAGWmfnNCJ+ycwDbggTZ8MvJzK37Xbq0adunZ7pf/35DTdQ+FxBRcBDwFXp/SVwB+n6f8MrEzTVwMP1qprrc/u9iuFC4GdEfHziPgN8ACwsM1laoWFwP1p+n4+fKLdQuAHUfA8MEVH/rBgW0TEsxSevleq0TosANZFxLsRsRdYB1w++qWvrEqdqlkIPBARv46I14CdFPbNjts/o/pDs7p2e9WoUzUdv73S//v99LYnvQK4FHg4pZdvp+L2exj4rCRRva5VdXtQmA78suT9G9TeGTpRAP8gaaOkJSntYxGxBwo7PPDRlN5N9W20Dt1St0oPhurKOunIh2aNi+2lox8E1rXbS9IEFX6E9C0KQfdVYF9EDFcoX1b2NP9XFH5zruE6dXtQqPRDfd02xnZ+RFxA4ZkT10v6TI2846G+1erQDXWr9mCorquT8j80q2vqVqFOXb29ovC8mbkUfmz0QuCcStnS35bVqduDwhvAzJL3M4DdbSpLUyJid/r7FvAohY3/ZrFZKP19K2Xvpvo2WoeOr1tUfzBUV9VJlR+a1dXbq1Kdxsv2ioh9wD9S6FOYIqn469al5cvKnub/FoXmz4br1O1B4afAWalHfhKFDpYn2lym3CSdJOnk4jRwGYWHDT0BFEdzXAs8nqafAK5JI0IuAn5VvOTvQI3WYS1wmaSp6TL/spTWMVT9wVBPAFenESBnAGcB/0wH7p+pnbnSQ7O6dntVq1M3by9JfZKmpOle4N9S6CtZD3whZSvfTsXt9wXgmSj0NFera3Xt6Flv5YvC6IiXKbS3/Um7y9Ng2T9OYWTAC8D2YvkptAU+DbyS/p4SH45I+KtU121Af7vrkMr1QwqX5wcpnJlc10wdgK9Q6AjbCfxhB9bpb1OZt6Yv27SS/H+S6rQDuKJT90/g0xSaD7YCW9Lrym7eXjXq1LXbCzgP2JzK/iJwS0r/OIWD+k4Kz7w/PqWfkN7vTPM/Xq+u1V7+mQszM8t0e/ORmZm1kIOCmZllHBTMzCzjoGBmZhkHBTMzyzgomJlZxkHBzMwy/x/AvBzCXAI56QAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "plt.scatter(\n", - " df_adjusted.classes,\n", - " df_adjusted.loss,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.4" - }, - "toc": { - "base_numbering": 1, - "nav_menu": {}, - "number_sections": true, - "sideBar": true, - "skip_h1_title": false, - "title_cell": "Table of Contents", - "title_sidebar": "Contents", - "toc_cell": false, - "toc_position": {}, - "toc_section_display": true, - "toc_window_display": false - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/nbs/optimizers.ipynb b/nbs/optimizers.ipynb deleted file mode 100644 index 31c58da..0000000 --- a/nbs/optimizers.ipynb +++ /dev/null @@ -1,373 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Optimizer Mangement\n", - "> Handling PyTorch Cuda" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# default_exp ftorch.optimizer" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This handler is created to handle the multiple optimizer situations" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "# export\n", - "import torch\n", - "\n", - "class Opts(object):\n", - " def __init__(self, *args, **kwargs):\n", - " \"\"\"\n", - " opts = Opts(opt1 = opt1, opt2 = opt2)\n", - " opts.opt2.zero_grad()\n", - " opts[\"opt3\"] = opt3\n", - " print(len(opts))\n", - " \"\"\"\n", - " self.optlist = []\n", - " self.optnames = []\n", - " for i in range(len(args)):\n", - " oname = f\"optimizer_no{i + 1}\"\n", - " setattr(self, oname, args[i])\n", - " self.optlist.append(args[i])\n", - " self.optnames.append(oname)\n", - " for k, v in kwargs.items():\n", - " setattr(self, k, v)\n", - " self.optlist.append(v)\n", - " self.optnames.append(k)\n", - "\n", - " def __repr__(self):\n", - " return \"\\n\".join(list(\n", - " f\"{self.optnames[i]}\\n\\t{self.optlist[i].__class__}\\n\\t{self.read_opt(self.optlist[i])}\" for i in\n", - " range(len(self.optnames))))\n", - "\n", - " def get_pg(self, opt):\n", - " \"\"\"\n", - " Get paramgroups dictionary, informations about an optimizer\n", - " opt:torch.optim.optimizer\n", - " \"\"\"\n", - " return dict.copy(opt.param_groups[0])\n", - "\n", - " def read_opt(self, opt):\n", - " rt = self.get_pg(opt)\n", - " if \"params\" in rt:\n", - " del rt[\"params\"]\n", - " return rt\n", - "\n", - " def __len__(self):\n", - " \"\"\"\n", - " Total number of optimizers\n", - " \"\"\"\n", - " return len(self.optlist)\n", - "\n", - " def __contains__(self, item):\n", - " return item in self.optlist\n", - "\n", - " def __getitem__(self, item):\n", - " return getattr(self, item)\n", - "\n", - " def __setitem__(self, key, optimizer):\n", - " self.optlist.append(optimizer)\n", - " self.optnames.append(key)\n", - " setattr(self, key, optimizer)\n", - "\n", - " def zero_all(self):\n", - " \"\"\"\n", - " Zero gradient on all the optimizers\n", - " \"\"\"\n", - " for opt in 
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Experiment"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 4,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import torch\n",
-    "from torch import nn"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 5,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "layer1 = nn.Linear(5,5)\n",
-    "layer2 = nn.Linear(5,5)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 8,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "op1 = torch.optim.Adam(layer1.parameters())\n",
-    "op2 = torch.optim.Adagrad(layer2.parameters())"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 9,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "opts = Opts(op1, op2)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 10,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "Adam (\n",
-       "Parameter Group 0\n",
-       "    amsgrad: False\n",
-       "    betas: (0.9, 0.999)\n",
-       "    eps: 1e-08\n",
-       "    lr: 0.001\n",
-       "    weight_decay: 0\n",
-       ")"
-      ]
-     },
-     "execution_count": 10,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "opts.optimizer_no1"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 11,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "Adagrad (\n",
-       "Parameter Group 0\n",
-       "    eps: 1e-10\n",
-       "    initial_accumulator_value: 0\n",
-       "    lr: 0.01\n",
-       "    lr_decay: 0\n",
-       "    weight_decay: 0\n",
-       ")"
-      ]
-     },
-     "execution_count": 11,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "opts.optimizer_no2"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 12,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "optimizer_no1\n",
-       "\t<class 'torch.optim.adam.Adam'>\n",
-       "\t{'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}\n",
-       "optimizer_no2\n",
-       "\t<class 'torch.optim.adagrad.Adagrad'>\n",
-       "\t{'lr': 0.01, 'lr_decay': 0, 'eps': 1e-10, 'weight_decay': 0, 'initial_accumulator_value': 0}"
-      ]
-     },
-     "execution_count": 12,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "opts"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Let's create some gradients"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 28,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "x = torch.rand(2,5)\n",
-    "y_ = -(layer2(torch.nn.functional.relu(layer1(x))).mean())"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 29,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "y_.backward()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 30,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "(tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],\n",
-       "         [-5.7813e-02, -5.0724e-02, -6.1752e-02, -1.1506e-01, -6.0402e-02],\n",
-       "         [-4.5386e-04, -2.3771e-04, -2.1409e-04, -6.2308e-04, -3.3177e-04],\n",
-       "         [-1.3012e-01, -1.1416e-01, -1.3899e-01, -2.5898e-01, -1.3595e-01],\n",
-       "         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]]),\n",
-       " tensor([ 0.0000, -0.1238, -0.0007, -0.2787,  0.0000]),\n",
-       " tensor([[ 0.0000, -0.1671, -0.0036, -0.1062,  0.0000],\n",
-       "         [ 0.0000, -0.1671, -0.0036, -0.1062,  0.0000],\n",
-       "         [ 0.0000, -0.1671, -0.0036, -0.1062,  0.0000],\n",
-       "         [ 0.0000, -0.1671, -0.0036, -0.1062,  0.0000],\n",
-       "         [ 0.0000, -0.1671, -0.0036, -0.1062,  0.0000]]),\n",
-       " tensor([-0.2000, -0.2000, -0.2000, -0.2000, -0.2000]))"
-      ]
-     },
-     "execution_count": 30,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "layer1.weight.grad,layer1.bias.grad,layer2.weight.grad,layer2.bias.grad"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Take a step for all optimizers"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 31,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "opts.step_all()"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Zero gradient on all optimizers"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 26,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "opts.zero_all()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 27,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "(tensor([[0., 0., 0., 0., 0.],\n",
-       "         [0., 0., 0., 0., 0.],\n",
-       "         [0., 0., 0., 0., 0.],\n",
-       "         [0., 0., 0., 0., 0.],\n",
-       "         [0., 0., 0., 0., 0.]]),\n",
-       " tensor([0., 0., 0., 0., 0.]),\n",
-       " tensor([[0., 0., 0., 0., 0.],\n",
-       "         [0., 0., 0., 0., 0.],\n",
-       "         [0., 0., 0., 0., 0.],\n",
-       "         [0., 0., 0., 0., 0.],\n",
-       "         [0., 0., 0., 0., 0.]]),\n",
-       " tensor([0., 0., 0., 0., 0.]))"
-      ]
-     },
-     "execution_count": 27,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "layer1.weight.grad,layer1.bias.grad,layer2.weight.grad,layer2.bias.grad"
-   ]
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "Python 3",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.7.4"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
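To close (an editorial sketch, not part of the deleted notebook): `zero_all` and `step_all` slot into an ordinary training loop wherever separate optimizers drive separate sub-modules. All names here are illustrative.

```python
# Editorial sketch: Opts in a minimal training loop.
import torch
from torch import nn

layer1, layer2 = nn.Linear(5, 5), nn.Linear(5, 5)
opts = Opts(
    torch.optim.Adam(layer1.parameters()),
    torch.optim.Adagrad(layer2.parameters()),
)
crit = nn.MSELoss()

for step in range(100):
    x, y = torch.rand(8, 5), torch.rand(8, 5)
    opts.zero_all()                          # zero every optimizer's gradients
    y_hat = layer2(torch.relu(layer1(x)))
    loss = crit(y_hat, y)
    loss.backward()
    opts.step_all()                          # step every optimizer at once
```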