diff --git a/nbs/11_etl.ipynb b/nbs/11_etl.ipynb
deleted file mode 100644
index 711b499..0000000
--- a/nbs/11_etl.ipynb
+++ /dev/null
@@ -1,174 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# ETL Helper\n",
- "> A combined ETL tool sets"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "# default_exp etl"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## A list for each step"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 23,
- "metadata": {},
- "outputs": [],
- "source": [
- "# export\n",
- "\n",
- "from forgebox.imports import *\n",
- "from typing import Callable, Union\n",
- "import traceback as tb\n",
- "import logging\n",
- "\n",
- "class NewFileScanner:\n",
- " \"\"\"\n",
- " Keep scannning a directory for new file\n",
- " if found any, callback processing the file\n",
- " \n",
- " Designed for download and delete strategy\n",
- " \n",
- " # Example\n",
- " new_file_scanner = NewFileScanner(\".\", new_file_filter=lambda x:x[-4:]==\".txt\")\n",
- " new_file_scanner(lambda x:print(f\"new file:{x} found\"))s\n",
- " \"\"\"\n",
- " def __init__(\n",
- " self,\n",
- " directory,\n",
- " new_file_filter: Callable=None,\n",
- " done_list_getter: Callable=None,\n",
- " stop_callback: Callable=None,\n",
- " ):\n",
- " self.directory = Path(directory)\n",
- " \n",
- " if new_file_filter is not None:\n",
- " self.new_file_filter=new_file_filter\n",
- " else:\n",
- " self.new_file_filter=lambda x:True\n",
- " \n",
- " if done_list_getter is not None:\n",
- " self.done_list_getter = done_list_getter\n",
- " else:\n",
- " self.done_list_getter = lambda *args:[]\n",
- " \n",
- " if stop_callback is not None:\n",
- " self.stop_callback = stop_callback\n",
- " \n",
- " self.processed = []\n",
- " \n",
- " def __call__(self, process: Callable):\n",
- " while True:\n",
- " try:\n",
- " done_list = self.done_list_getter()\n",
- " for fname in self.directory.iterdir():\n",
- " fname = str(fname)\n",
- " if self.new_file_filter(fname) != True:\n",
- " continue\n",
- " file_path = str(self.directory/fname)\n",
- " if fname not in done_list:\n",
- " if fname not in self.processed:\n",
- " result = process(file_path)\n",
- " self.processed.append(fname)\n",
- " except KeyboardInterrupt as e:\n",
- " if hasattr(self, \"stop_callback\"):\n",
- " self.stop_callback()\n",
- " logging.error(\"manually stoped\")\n",
- " break\n",
- " except Exception as e:\n",
- " error = tb.format_exc()\n",
- " logging.error(error)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Test new file scanner"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "new file:untitled.txt found\n",
- "new file:bc.txt found\n",
- "new file:cde.txt found\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "ERROR:root:manually stoped\n"
- ]
- }
- ],
- "source": [
- "new_file_scanner = NewFileScanner(\".\", new_file_filter=lambda x:x[-4:]==\".txt\")\n",
- "new_file_scanner(lambda x:print(f\"new file:{x} found\"))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.7.4"
- },
- "toc": {
- "base_numbering": 1,
- "nav_menu": {},
- "number_sections": true,
- "sideBar": true,
- "skip_h1_title": false,
- "title_cell": "Table of Contents",
- "title_sidebar": "Contents",
- "toc_cell": false,
- "toc_position": {},
- "toc_section_display": true,
- "toc_window_display": false
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}
diff --git a/nbs/61_thunder_callbacks.ipynb b/nbs/61_thunder_callbacks.ipynb
deleted file mode 100644
index 848e47f..0000000
--- a/nbs/61_thunder_callbacks.ipynb
+++ /dev/null
@@ -1,151 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Lightning Callbacks\n",
- "> Thunder, the DIYed [pytorch-lightening callbacks](https://pytorch-lightning.readthedocs.io/en/latest/extensions/callbacks.html)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "# default_exp thunder.callbacks"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "# export\n",
- "import pandas as pd\n",
- "from ipywidgets import Output\n",
- "from typing import List, Dict\n",
- "import copy\n",
- "import pytorch_lightning as pl\n",
- "import torch\n",
- "from torch import nn"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [],
- "source": [
- "# export\n",
- "def unfreeze(self):\n",
- " \"\"\"unfreeze this module, and its sub modules\"\"\"\n",
- " for p in self.parameters():\n",
- " p.requires_grad = True\n",
- "\n",
- "\n",
- "def freeze(self):\n",
- " \"\"\"freeze this module, and its sub modules\"\"\"\n",
- " for p in self.parameters():\n",
- " p.requires_grad = False\n",
- "\n",
- "nn.Module.unfreeze = unfreeze\n",
- "nn.Module.freeze = freeze\n",
- "\n",
- "class DataFrameMetricsCallback(pl.Callback):\n",
- " \"\"\"\n",
- " A metrics callback keep showing pandas dataframe\n",
- " \"\"\"\n",
- "\n",
- " def __init__(self) -> None:\n",
- " \"\"\"\n",
- " In Trainer kwargs, passing this arguements along with other callbacks\n",
- " callbacks = [DataFrameMetricsCallback(),]\n",
- " \"\"\"\n",
- " self.metrics: List = []\n",
- "\n",
- " def on_fit_start(\n",
- " self, trainer: pl.Trainer,\n",
- " pl_module: pl.LightningModule\n",
- " ) -> None:\n",
- " pl_module.output = Output()\n",
- " display(pl_module.output)\n",
- "\n",
- " def on_validation_epoch_end(\n",
- " self, trainer: pl.Trainer,\n",
- " pl_module: pl.LightningModule\n",
- " ) -> None:\n",
- " metrics_dict = copy.copy(trainer.callback_metrics)\n",
- " self.metrics.append(dict((k, v.item())\n",
- " for k, v in metrics_dict.items()))\n",
- " pl_module.output.clear_output()\n",
- " with pl_module.output:\n",
- " display(pd.DataFrame(self.metrics).tail(10))\n",
- "\n",
- "\n",
- "def UnfreezeScheduler(frozen_epochs: int = 2):\n",
- " assert hasattr(pl_module, \"top_layers\"), \"Please define 'top_layers' attributes\"+\\\n",
- " \" for pl_module, which will return a list of nn.Module object(s)\"\n",
- " class UnfreezeSchedulerCallback(pl.callbacks.Callback):\n",
- " \"\"\"\n",
- " Train the top layer for [frozen_epochs] epochs\n",
- " then un freeze all\n",
- " \"\"\"\n",
- "\n",
- " def on_epoch_start(self, trainer, pl_module):\n",
- " epoch = trainer.current_epoch\n",
- "\n",
- " if epoch == 0:\n",
- " pl_module.freeze()\n",
- " for tl in pl_module.top_layers:\n",
- " tl.unfreeze()\n",
- " if epoch == frozen_epochs:\n",
- " pl_module.unfreeze()\n",
- " pl_module.base.embeddings.freeze()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.7.4"
- },
- "toc": {
- "base_numbering": 1,
- "nav_menu": {},
- "number_sections": true,
- "sideBar": true,
- "skip_h1_title": false,
- "title_cell": "Table of Contents",
- "title_sidebar": "Contents",
- "toc_cell": false,
- "toc_position": {},
- "toc_section_display": true,
- "toc_window_display": false
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}
diff --git a/nbs/70_hf_transformer_data.ipynb b/nbs/70_hf_transformer_data.ipynb
deleted file mode 100644
index ac7b569..0000000
--- a/nbs/70_hf_transformer_data.ipynb
+++ /dev/null
@@ -1,553 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Data parts for hf transformers"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "# default_exp hf.data"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "# export\n",
- "from forgebox.imports import *\n",
- "from forgebox.category import Category\n",
- "from typing import List, Dict, Callable, Any, Tuple"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Process IOBES files"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# export\n",
- "def convert_iob2_file_to_iobes(file_path, result_path):\n",
- " \"\"\"\n",
- " Convert IOB2 file to IOBES\n",
- " \"\"\"\n",
- " with open(file_path, 'r') as f:\n",
- " lines = f.readlines()\n",
- " with open(result_path, 'w') as f:\n",
- " for line in lines:\n",
- " line = line.strip()\n",
- " if line == '':\n",
- " f.write('\\n')\n",
- " continue\n",
- " line = line.split()\n",
- " if line[-1] == 'O':\n",
- " f.write(' '.join(line) + '\\n')\n",
- " else:\n",
- " f.write(' '.join(line[:-1]) + ' ' + line[-1] + '\\n')\n",
- "\n",
- "\n",
- "def conbine_iobes_file(\n",
- " file_paths: List[Path],\n",
- " new_file_path: Path\n",
- "):\n",
- " \"\"\"\n",
- " Conbine from multiple IOBES files\n",
- " into IOBES files\n",
- " \"\"\"\n",
- " with open(new_file_path, 'w') as new_file:\n",
- " for file_path in file_paths:\n",
- " with open(file_path, 'r') as file:\n",
- " for line in file:\n",
- " new_file.write(line)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Dataset"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# export\n",
- "class IOBES(Dataset):\n",
- " \"\"\"\n",
- " Load iobes file for NER training task\n",
- " \"\"\"\n",
- "\n",
- " def __init__(\n",
- " self,\n",
- " file_path,\n",
- " tokenizer,\n",
- " max_len=128,\n",
- " save_buffer: int = 15,\n",
- " category: Category = None,\n",
- " return_string: bool = False,\n",
- " use_frag: bool = False,\n",
- " ):\n",
- " \"\"\"\n",
- " file_path,\n",
- " tokenizer,\n",
- " max_len=128,\n",
- " save_buffer: int = 15,\n",
- " category: Category = None,\n",
- " label categories, if set to None, will be figured out\n",
- " automatically.\n",
- " You can set this to None for train dataset, but for valid\n",
- " dataset:\n",
- " valid_ds = IOBES(...,category=train_ds.cates)\n",
- " return_string: bool = False, do we return original string\n",
- " for tokenizer output, this option is good for debuging\n",
- " but the data won't pass into cuda if choose so\n",
- " use_frag: bool = False, do we use prepend like 'I-','B-'\n",
- " \"\"\"\n",
- " self.file_path = file_path\n",
- " self.max_len = max_len\n",
- " self.pairs = []\n",
- " self.list_of_words = []\n",
- " self.list_of_labels = []\n",
- " self.tokenizer = tokenizer\n",
- " self.cates = category\n",
- " self.return_string = return_string\n",
- " self.use_frag = use_frag\n",
- " self.load_data(save_buffer)\n",
- "\n",
- " def load_data(self, save_buffer: int = 15):\n",
- " \"\"\"\n",
- " Load file in to object structure\n",
- " \"\"\"\n",
- " with open(self.file_path, 'r') as f:\n",
- " for line in f:\n",
- " line = line.strip()\n",
- " if line:\n",
- " splited = line.split()\n",
- " if len(splited) != 2:\n",
- " continue\n",
- " word, label = splited\n",
- " # do we use 'I-', 'B-' etc\n",
- " if self.use_frag is False:\n",
- " if \"-\" in label:\n",
- " label = label.split('-')[1]\n",
- " self.pairs.append([word, label])\n",
- "\n",
- " self.pairs = np.array(self.pairs)\n",
- "\n",
- " if self.cates is None:\n",
- " labels_df = pd.DataFrame({\"label\": self.pairs[:, 1]})\n",
- " self.cates = Category(list(labels_df.vc(\"label\").index))\n",
- "\n",
- " self.batching_words(save_buffer)\n",
- "\n",
- " def batching_words(self, save_buffer: int = 15):\n",
- " \"\"\"\n",
- " batching self.words into self.list_of_words\n",
- " by self.max_len -15\n",
- " \"\"\"\n",
- " for i in range(0, len(self.pairs), self.max_len-save_buffer):\n",
- " chunk_slice = slice(i, i+self.max_len-save_buffer)\n",
- " self.list_of_words.append(self.pairs[chunk_slice, 0])\n",
- " self.list_of_labels.append(self.pairs[chunk_slice, 1])\n",
- "\n",
- " def __len__(self) -> int:\n",
- " return len(self.list_of_words)\n",
- "\n",
- " def __getitem__(self, idx: int) -> Tuple[List[str]]:\n",
- " return list(self.list_of_words[idx]), list(self.list_of_labels[idx])\n",
- "\n",
- " def __repr__(self):\n",
- " return f\"\"\"NER dataset using IOBES annotation\n",
- " {len(self)} sentences,\n",
- " Labels:\n",
- " {list(self.cates.i2c)}\n",
- " \"\"\"\n",
- "\n",
- " def collate_fn(self, data):\n",
- " \"\"\"\n",
- " data: list of tuple\n",
- " \"\"\"\n",
- " words, text_labels = zip(*data)\n",
- "\n",
- " inputs = self.tokenizer(\n",
- " list(words),\n",
- " return_tensors='pt',\n",
- " padding=True,\n",
- " truncation=True,\n",
- " max_length=self.max_len,\n",
- " is_split_into_words=True,\n",
- " return_offsets_mapping=True,\n",
- " add_special_tokens=False,\n",
- " )\n",
- " return self.align_offsets(inputs, text_labels, words)\n",
- "\n",
- " def align_offsets(\n",
- " self,\n",
- " inputs,\n",
- " text_labels: List[List[str]],\n",
- " words: List[List[str]]\n",
- " ):\n",
- " \"\"\"\n",
- " inputs: output if tokenizer\n",
- " text_labels: labels in form of list of list of strings\n",
- " words: words in form of list of list of strings\n",
- " \"\"\"\n",
- " labels = torch.zeros_like(inputs.input_ids).long()\n",
- " labels -= 100\n",
- " text_lables_array = np.empty(labels.shape, dtype=object)\n",
- " words_array = np.empty(labels.shape, dtype=object)\n",
- " max_len = inputs.input_ids.shape[1]\n",
- "\n",
- " for row_id, input_ids in enumerate(inputs.input_ids):\n",
- " word_pos = inputs.word_ids(row_id)\n",
- " for idx, pos in enumerate(word_pos):\n",
- " if pos is None:\n",
- " continue\n",
- " if pos <= max_len:\n",
- " labels[row_id, idx] = self.cates.c2i[text_labels[row_id][pos]]\n",
- " if self.return_string:\n",
- " text_lables_array[row_id,\n",
- " idx] = text_labels[row_id][pos]\n",
- " words_array[row_id, idx] = words[row_id][pos]\n",
- "\n",
- " inputs['labels'] = labels\n",
- " if self.return_string:\n",
- " inputs['text_labels'] = text_lables_array.tolist()\n",
- " inputs['word'] = words_array.tolist()\n",
- " return inputs\n",
- "\n",
- " def dataloader(self, batch_size: int = 32, shuffle: bool = True):\n",
- " \"\"\"\n",
- " Create dataloader\n",
- " \"\"\"\n",
- " return DataLoader(\n",
- " self,\n",
- " batch_size=batch_size,\n",
- " shuffle=shuffle,\n",
- " collate_fn=self.collate_fn,\n",
- " )\n",
- "\n",
- " def one_batch(self, batch_size: int = 32, shuffle: bool = True):\n",
- " return next(iter(self.dataloader(batch_size, shuffle)))\n",
- "\n",
- " def visualize_batch(self, batch, row_idx=0):\n",
- " return list(zip(self.tokenizer.convert_ids_to_tokens(batch.input_ids[row_idx]),\n",
- " batch.labels[row_idx].numpy(),\n",
- " batch.text_labels[row_idx],\n",
- " batch.word[row_idx],\n",
- " batch.offset_mapping[row_idx].numpy(),\n",
- " ))\n",
- "\n",
- " def set_hfconfig(self, config):\n",
- " \"\"\"\n",
- " set the category information to huggingface config\n",
- " \"\"\"\n",
- " config.num_labels = len(self.cates)\n",
- " config.id2label = {i: label for i, label in enumerate(self.cates.i2c)}\n",
- " config.label2id = {label: i for i, label in enumerate(self.cates.i2c)}"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [],
- "source": [
- "from transformers import AutoTokenizer"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [],
- "source": [
- "tokenizer = AutoTokenizer.from_pretrained(\"raynardj/roberta-pubmed\", add_prefix_space=True)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [],
- "source": [
- "dataset = IOBES(\"/Users/xiaochen.zhang/data/valid.iobes\", tokenizer)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "in-O\n",
- "blood-O\n",
- ";-O\n",
- "content-O\n",
- "of-O\n",
- "cAMP-O\n",
- "was-O\n",
- "also-O\n",
- "decreased-O\n",
- "in-O\n",
- "lymphocytes-O\n",
- "by-O\n",
- "33-O\n",
- "%-O\n",
- ".-O\n",
- "At-O\n",
- "the-O\n",
- "same-O\n",
- "time-O\n",
- ",-O\n",
- "total-O\n",
- "content-O\n",
- "of-O\n",
- "T-cell_type\n",
- "lymphocytes-cell_type\n",
- "was-O\n",
- "decreased-O\n",
- "1.5-fold-O\n",
- "in-O\n",
- "peripheric-O\n",
- "blood-O\n",
- ".-O\n",
- "Treatment-O\n",
- "with-O\n",
- "I-hydroxyvitamin-O\n",
- "D3-O\n",
- "(-O\n",
- "1-1.5-O\n",
- "mg-O\n",
- "daily-O\n",
- ",-O\n",
- "within-O\n",
- "4-O\n",
- "weeks-O\n",
- ")-O\n",
- "led-O\n",
- "to-O\n",
- "normalization-O\n",
- "of-O\n",
- "total-O\n",
- "and-O\n",
- "ionized-O\n",
- "form-O\n",
- "of-O\n",
- "Ca2+-O\n",
- "and-O\n",
- "of-O\n",
- "25-O\n",
- "(-O\n",
- "OH-O\n",
- ")-O\n",
- "D-O\n",
- ",-O\n",
- "but-O\n",
- "did-O\n",
- "not-O\n",
- "affect-O\n",
- "the-O\n",
- "PTH-O\n",
- "content-O\n",
- "in-O\n",
- "blood-O\n",
- ".-O\n",
- "Concentration-O\n",
- "of-O\n",
- "the-O\n",
- "receptors-protein\n",
- "to-O\n",
- "1.25-O\n",
- "(-O\n",
- "OH-O\n",
- ")-O\n",
- "2D3-O\n",
- "was-O\n",
- "elevated-O\n",
- "up-O\n",
- "to-O\n",
- "39.7-O\n",
- "fmole/mg-O\n",
- "after-O\n",
- "I-O\n",
- "week-O\n",
- "of-O\n",
- "the-O\n",
- "treatment-O\n",
- ",-O\n",
- "whereas-O\n",
- "it-O\n",
- "was-O\n",
- "decreased-O\n",
- "to-O\n",
- "the-O\n",
- "initial-O\n",
- "level-O\n",
- "24.8-O\n",
- "fmole/mg-O\n",
- "within-O\n",
- "4-O\n",
- "weeks-O\n",
- ";-O\n",
- "simultaneous-O\n",
- "alteration-O\n",
- "in-O\n"
- ]
- }
- ],
- "source": [
- "for w,l in zip(*dataset[2]):\n",
- " print(f\"{w}-{l}\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "{'input_ids': tensor([[ 19, 3741, 2603, ..., 1417, 2617, 11576],\n",
- " [ 4590, 2156, 255, ..., 405, 1182, 6608],\n",
- " [ 6214, 25683, 3809, ..., 11, 5, 8151],\n",
- " ...,\n",
- " [13998, 25326, 2413, ..., 5, 2199, 21],\n",
- " [11299, 705, 24811, ..., 134, 1589, 2032],\n",
- " [ 5804, 924, 14, ..., 366, 1168, 9]]), 'attention_mask': tensor([[1, 1, 1, ..., 1, 1, 1],\n",
- " [1, 1, 1, ..., 1, 1, 1],\n",
- " [1, 1, 1, ..., 1, 1, 1],\n",
- " ...,\n",
- " [1, 1, 1, ..., 1, 1, 1],\n",
- " [1, 1, 1, ..., 1, 1, 1],\n",
- " [1, 1, 1, ..., 1, 1, 1]]), 'offset_mapping': tensor([[[ 1, 4],\n",
- " [ 1, 2],\n",
- " [ 2, 5],\n",
- " ...,\n",
- " [ 3, 5],\n",
- " [ 5, 8],\n",
- " [ 1, 6]],\n",
- "\n",
- " [[ 1, 5],\n",
- " [ 1, 1],\n",
- " [ 1, 1],\n",
- " ...,\n",
- " [ 5, 7],\n",
- " [ 7, 9],\n",
- " [ 9, 14]],\n",
- "\n",
- " [[ 1, 5],\n",
- " [ 5, 8],\n",
- " [ 8, 10],\n",
- " ...,\n",
- " [ 1, 2],\n",
- " [ 1, 3],\n",
- " [ 1, 10]],\n",
- "\n",
- " ...,\n",
- "\n",
- " [[ 1, 5],\n",
- " [ 5, 8],\n",
- " [ 8, 10],\n",
- " ...,\n",
- " [ 1, 3],\n",
- " [ 1, 7],\n",
- " [ 1, 3]],\n",
- "\n",
- " [[ 1, 5],\n",
- " [ 5, 6],\n",
- " [ 6, 10],\n",
- " ...,\n",
- " [ 2, 3],\n",
- " [ 1, 1],\n",
- " [ 1, 2]],\n",
- "\n",
- " [[ 1, 7],\n",
- " [ 1, 5],\n",
- " [ 1, 4],\n",
- " ...,\n",
- " [ 3, 5],\n",
- " [ 5, 7],\n",
- " [ 1, 2]]]), 'labels': tensor([[0, 1, 1, ..., 0, 0, 0],\n",
- " [2, 0, 2, ..., 0, 0, 0],\n",
- " [0, 0, 0, ..., 0, 0, 0],\n",
- " ...,\n",
- " [1, 1, 1, ..., 0, 0, 0],\n",
- " [0, 0, 0, ..., 2, 0, 2],\n",
- " [0, 0, 0, ..., 0, 0, 0]])}"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "dataset.one_batch()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.7.4"
- },
- "toc": {
- "base_numbering": 1,
- "nav_menu": {},
- "number_sections": true,
- "sideBar": true,
- "skip_h1_title": false,
- "title_cell": "Table of Contents",
- "title_sidebar": "Contents",
- "toc_cell": false,
- "toc_position": {},
- "toc_section_display": true,
- "toc_window_display": true
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}
diff --git a/nbs/72_pl_training.ipynb b/nbs/72_pl_training.ipynb
deleted file mode 100644
index 3ce4854..0000000
--- a/nbs/72_pl_training.ipynb
+++ /dev/null
@@ -1,466 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Pytorch Lighting training\n",
- "> on huggingface transformers"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "# default_exp hf.train"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "metadata": {},
- "outputs": [],
- "source": [
- "# export\n",
- "from forgebox.hf.data import IOBES\n",
- "from forgebox.imports import *\n",
- "from forgebox.loop import chunkify\n",
- "import pytorch_lightning as pl\n",
- "from transformers import (\n",
- " AutoModelForTokenClassification,\n",
- " AutoTokenizer,\n",
- " pipeline\n",
- ")\n",
- "from tqdm.notebook import tqdm\n",
- "from typing import Callable, List\n",
- "from torch import device"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 290,
- "metadata": {},
- "outputs": [],
- "source": [
- "# export\n",
- "try:\n",
- " ishell = get_ipython()\n",
- " IS_JUPYTER = True\n",
- " from tqdm.notebook import tqdm\n",
- "except NameError:\n",
- " IS_JUPYTER = False\n",
- " from tqdm import tqdm"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [],
- "source": [
- "# !pip install transformers==4.9.1\n",
- "# !pip install pytorch-lightning==1.3.8\n",
- "# !pip install tensorflow==2.2.0"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Load model and tokenizer"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "metadata": {},
- "outputs": [],
- "source": [
- "# export\n",
- "\n",
- "# ner model and tokenizer\n",
- "def ner_model_from(\n",
- " name:str, dataset: IOBES\n",
- "):\n",
- " \"\"\"\n",
- " name: from_pretrain(name)\n",
- " \"\"\"\n",
- " model = AutoModelForTokenClassification.from_pretrained(\n",
- " name,\n",
- " num_labels=len(dataset.cates),\n",
- " )\n",
- " dataset.set_hfconfig(model.config)\n",
- " return model\n",
- "\n",
- "def ner_tokenizer_from(\n",
- " name: str\n",
- "):\n",
- " return AutoTokenizer.from_pretrained(\n",
- " name, add_prefix_space=True)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Lightning data module"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [],
- "source": [
- "# export\n",
- "\n",
- "# ner data module\n",
- "class NERDataModule(pl.LightningDataModule):\n",
- " def __init__(self, train_ds, val_ds, batch_size=32):\n",
- " super().__init__()\n",
- " self.train_ds = train_ds\n",
- " self.val_ds = val_ds\n",
- " self.batch_size = batch_size\n",
- "\n",
- " def train_dataloader(self):\n",
- " return self.train_ds.dataloader(batch_size=self.batch_size, shuffle=True)\n",
- "\n",
- " def val_dataloader(self):\n",
- " return self.val_ds.dataloader(batch_size=self.batch_size*2, shuffle=False)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {},
- "outputs": [],
- "source": [
- "# export\n",
- "\n",
- "# ner module\n",
- "class NERModule(pl.LightningModule):\n",
- " \"\"\"\n",
- " PyTorch lightning module for training ner model\n",
- " \"\"\"\n",
- " def __init__(\n",
- " self, model,\n",
- " ):\n",
- " \"\"\"\n",
- " model: huggingface transformer model for ner\n",
- " \"\"\"\n",
- " super().__init__()\n",
- " self.model = model\n",
- "\n",
- " def forward(self, batch):\n",
- " return self.model(\n",
- " input_ids=batch['input_ids'],\n",
- " attention_mask=batch['attention_mask'],\n",
- " labels=batch['labels'])\n",
- " \n",
- " def training_step(self, batch, batch_idx):\n",
- " outputs = self(batch)\n",
- " loss = outputs.loss\n",
- " self.log(\"loss\", loss)\n",
- " self.log(\"acc\", self.calcualte_acc(outputs, batch.labels))\n",
- " return loss\n",
- "\n",
- " def validation_step(self, batch, batch_idx):\n",
- " outputs = self(batch)\n",
- " loss = outputs.loss\n",
- " self.log(\"val_loss\", loss)\n",
- " self.log(\"val_acc\", self.calcualte_acc(outputs, batch.labels))\n",
- " return loss\n",
- " \n",
- " def calcualte_acc(self, outputs, labels):\n",
- " pred_idx = outputs.logits.argmax(-1)\n",
- " mask = torch.ones_like(pred_idx)\n",
- " mask[labels==-100]=False\n",
- " return (pred_idx[mask]==labels[mask]).float().mean()\n",
- " \n",
- " def configure_optimizers(self):\n",
- " # discriminative learning rate\n",
- " param_groups = [\n",
- " {'params': self.model.roberta.parameters(), 'lr': 5e-6},\n",
- " {'params': self.model.classifier.parameters(), 'lr': 1e-3},\n",
- " ]\n",
- " optimizer = torch.optim.Adam(param_groups, lr=1e-3)\n",
- " return optimizer"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Enhance pipeline"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 35,
- "metadata": {},
- "outputs": [],
- "source": [
- "# export\n",
- "def clean_ner_output(self, outputs):\n",
- " \"\"\"\n",
- " Cleaning output for NER task\n",
- " \"\"\"\n",
- " results = []\n",
- " current = []\n",
- " last_idx = 0\n",
- " # make to sub group by position\n",
- " for output in outputs:\n",
- " if output[\"start\"] in [last_idx, last_idx-1]:\n",
- " current.append(output)\n",
- " else:\n",
- " results.append(current)\n",
- " current = [output, ]\n",
- " last_idx = output[\"end\"]\n",
- " if len(current) > 0:\n",
- " results.append(current)\n",
- "\n",
- " # from tokens to string\n",
- " strings = []\n",
- " for c in results:\n",
- " tokens = []\n",
- " starts = []\n",
- " ends = []\n",
- " for o in c:\n",
- " tokens.append(o['word'])\n",
- " starts.append(o['start'])\n",
- " ends.append(o['end'])\n",
- "\n",
- " new_str = self.tokenizer.convert_tokens_to_string(tokens)\n",
- " if new_str != '':\n",
- " strings.append(dict(\n",
- " word=new_str,\n",
- " start=min(starts),\n",
- " end=max(ends),\n",
- " entity=c[0]['entity_group']\n",
- " ))\n",
- " return strings"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 281,
- "metadata": {},
- "outputs": [],
- "source": [
- "# export\n",
- "class NERInference:\n",
- " \"\"\"\n",
- " NER Inference pipeline\n",
- " ner = NERInference.from_pretrained('xxxx/xxxx')\n",
- " ner.predict(['text1','text2'])\n",
- " \"\"\"\n",
- "\n",
- " def __init__(self, model, tokenizer, name=None):\n",
- " super().__init__()\n",
- " self.model = model.eval()\n",
- " self.tokenizer = tokenizer\n",
- " self.name = name if name else \"NER model\"\n",
- "\n",
- " def __repr__(self):\n",
- " return f\"[NERInference on {self.name}]\"\n",
- "\n",
- " def to(self, device_str):\n",
- " self.model = self.model.to(device(device_str))\n",
- " return self\n",
- "\n",
- " @classmethod\n",
- " def from_pretrained(cls, tag):\n",
- " \"\"\"\n",
- " Load from pretrained model and tokenizer\n",
- " \"\"\"\n",
- " model = AutoModelForTokenClassification.from_pretrained(tag)\n",
- " tokenizer = AutoTokenizer.from_pretrained(tag)\n",
- " return cls(model=model, tokenizer=tokenizer, name=model.config._name_or_path)\n",
- " \n",
- " def __call__(self, data, batch_size=32, dev=device(\"cpu\")):\n",
- " if type(data) == str:\n",
- " return self.batch_predict([data,])\n",
- " else:\n",
- " return self.predict(data, dev=dev, batch_size=batch_size)\n",
- "\n",
- " def predict(\n",
- " self,\n",
- " texts: List[str],\n",
- " dev=device(\"cpu\"),\n",
- " batch_size: int = 32,\n",
- " progress_bar: bool = True\n",
- " ) -> pd.DataFrame:\n",
- " \"\"\"\n",
- " Predict a list of sentences/ paragraphs\n",
- " \"\"\"\n",
- " # place the model into device\n",
- " self.model = self.model.to(dev)\n",
- " iterator = list(enumerate(chunkify(texts, bs=batch_size)))\n",
- " if progress_bar:\n",
- " iterator = tqdm(iterator, leave=False)\n",
- "\n",
- " # run through iterator\n",
- " all_dfs = []\n",
- " for i, text_b in iterator:\n",
- " # by batch prediction\n",
- " batch_df = self.batch_predict(text_b)\n",
- " if len(batch_df) > 0:\n",
- " # calculate the row number\n",
- " batch_df['text_id'] = batch_df.apply(\n",
- " lambda row: i*batch_size+row.batch_row_sn, axis=1)\n",
- " all_dfs.append(batch_df)\n",
- "\n",
- " # place the model back to cpu\n",
- " self.model = self.model.to(\"cpu\")\n",
- " return pd.concat(all_dfs).reset_index(drop=True)\n",
- " \n",
- " def tokenizing(self, texts):\n",
- " inputs = self.tokenizer(\n",
- " texts,\n",
- " padding=\"max_length\",\n",
- " max_length=self.tokenizer.model_max_length,\n",
- " return_attention_mask=True,\n",
- " return_tensors='pt', truncation=True, return_offsets_mapping=True\n",
- " ).to(self.model.device)\n",
- " return inputs\n",
- "\n",
- "\n",
- " def batch_predict(self, texts:List[str])-> pd.DataFrame:\n",
- " \"\"\"\n",
- " Predict a single batch of sentences\n",
- " \"\"\"\n",
- " id2label = self.model.config.id2label\n",
- " inputs = self.tokenizing(texts)\n",
- "\n",
- " with torch.no_grad():\n",
- " outputs = self.model(input_ids=inputs.input_ids,\n",
- " attention_mask=inputs.attention_mask)\n",
- " inputs = inputs.to(device('cpu'))\n",
- "\n",
- " pred_idx = outputs.logits.argmax(-1).to(device(\"cpu\"))\n",
- " batch_size = pred_idx.size(0)\n",
- " offsets = inputs.offset_mapping\n",
- " results = []\n",
- " for bi in range(batch_size):\n",
- " text = texts[bi]\n",
- " input_ids = inputs.input_ids[bi]\n",
- " word_ids = inputs.word_ids(bi)\n",
- " pred_ids = pred_idx[bi]\n",
- " # initial values for the row\n",
- " last_pos = 0\n",
- " previous_has_positive = False\n",
- " current_start = 0\n",
- " current_index = 0\n",
- " current_id = 0\n",
- " line = []\n",
- " for ti in range(1, len(input_ids)):\n",
- " if input_ids[ti] == self.tokenizer.sep_token_id:\n",
- " break\n",
- " # is the current token an appending sub-word?\n",
- " if word_ids[ti] == last_pos:\n",
- " pass\n",
- " # is current token negative\n",
- " elif pred_ids[ti].item() == 0:\n",
- " # store the previous hanging prediction\n",
- " if previous_has_positive:\n",
- " start = current_start\n",
- " end = offsets[bi, ti, 0].item()\n",
- " line.append({\n",
- " \"start\": start, \"end\": end,\n",
- " \"entity\": id2label[current_id],\n",
- " \"word\": text[start:end],\n",
- " \"index\": current_index,\n",
- " })\n",
- "\n",
- " current_start = offsets[bi, ti, 0].item()\n",
- " previous_has_positive = False\n",
- " current_id = 0\n",
- " current_index = ti\n",
- " # has positive prediction index, other than zero\n",
- " else:\n",
- " if previous_has_positive:\n",
- " # different than the previous\n",
- " if current_id != pred_ids[ti].item():\n",
- " start = current_start\n",
- " end = offsets[bi, ti, 0].item()\n",
- " line.append({\n",
- " \"start\": start,\n",
- " \"end\": end,\n",
- " \"entity\": id2label[current_id],\n",
- " \"word\": text[start:end],\n",
- " \"index\": current_index,\n",
- " })\n",
- " current_start = offsets[bi, ti, 0].item()\n",
- " # this is the 1st postive predict for a while\n",
- " else:\n",
- " current_start = offsets[bi, ti, 0].item()\n",
- " previous_has_positive = True\n",
- " current_index = ti\n",
- " current_id = pred_ids[ti].item()\n",
- "\n",
- " last_pos = word_ids[ti]\n",
- " if previous_has_positive:\n",
- " start = current_start\n",
- " end = offsets[bi, ti, 1].item()\n",
- " line.append({\n",
- " \"start\": start,\n",
- " \"end\": end,\n",
- " \"entity\": id2label[current_id],\n",
- " \"word\": text[start:end],\n",
- " \"index\": current_index,\n",
- " })\n",
- "\n",
- " results.append(line)\n",
- " all_dfs = []\n",
- " for i, res in enumerate(results):\n",
- " sub_df = pd.DataFrame(res)\n",
- " sub_df[\"batch_row_sn\"] = i\n",
- " all_dfs.append(sub_df)\n",
- " return pd.concat(all_dfs)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.7.4"
- },
- "toc": {
- "base_numbering": 1,
- "nav_menu": {},
- "number_sections": true,
- "sideBar": true,
- "skip_h1_title": false,
- "title_cell": "Table of Contents",
- "title_sidebar": "Contents",
- "toc_cell": false,
- "toc_position": {},
- "toc_section_display": true,
- "toc_window_display": false
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}
diff --git a/nbs/cross_entropy_weighter.ipynb b/nbs/cross_entropy_weighter.ipynb
deleted file mode 100644
index 95018ab..0000000
--- a/nbs/cross_entropy_weighter.ipynb
+++ /dev/null
@@ -1,1298 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Create multi-task adjustment for cross-entropy\n",
- "> A weighter for multi-task learning with different softmax+cross-entropy"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Tools and imports"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# default_exp multitask_ce"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "# export\n",
- "import torch\n",
- "from torch import nn\n",
- "import pandas as pd\n",
- "import numpy as np\n",
- "from typing import Callable"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 120,
- "metadata": {},
- "outputs": [],
- "source": [
- "from matplotlib import pyplot as plt"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Enters softmax"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 102,
- "metadata": {},
- "outputs": [],
- "source": [
- "softmax = nn.Softmax(-1)\n",
- "crit = nn.CrossEntropyLoss()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Experience the problem"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 104,
- "metadata": {},
- "outputs": [],
- "source": [
- "def test_loss(model,iters:300,get_xy):\n",
- " losses=[]\n",
- " with torch.no_grad():\n",
- " for i in range(iters):\n",
- " x,y_true = get_xy(model.nb_output)\n",
- " y_vec = model(x)\n",
- " loss = model.crit(y_vec,y_true)\n",
- " losses.append(loss)\n",
- " return torch.stack(losses).mean()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 245,
- "metadata": {},
- "outputs": [],
- "source": [
- "def create_softmax_pipeline(\n",
- " nb_layers:int,\n",
- " nb_output:int,\n",
- " hs:int=500,\n",
- " crit:nn.Module=nn.CrossEntropyLoss(),\n",
- " )->nn.Module:\n",
- " modules = (nb_layers-1)*[nn.Linear(hs,hs),nn.Dropout(.3)]\n",
- " modules+=[nn.Linear(hs,nb_output),]\n",
- " model = nn.Sequential(*modules)\n",
- " model.hs = hs\n",
- " model.nb_output = nb_output\n",
- " model.__class__.crit = crit\n",
- " model.__class__.test_loss = test_loss\n",
- " return model"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 246,
- "metadata": {},
- "outputs": [],
- "source": [
- "def random_input(nb_output):\n",
- " return torch.rand(2,500),torch.randint(low=0,high=nb_output,size = (2,))\n",
- "\n",
- "def inbalanced_input(\n",
- " bias:float\n",
- " ) -> Callable:\n",
- " def inbalanced_input_(nb_output:int):\n",
- " return torch.rand(2,500),torch.randint(\n",
- " low=0,\n",
- " high=max(1,int(nb_output*bias)),\n",
- " size = (2,)\n",
- " )\n",
- " return inbalanced_input_"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "* For models with different branches of output, some output 2 category, some output more, like 200,500\n",
- "* Their loss will end up in different scale\n",
- "* That makes mixing them up fairly hard and unfair to lesser category tasks"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 247,
- "metadata": {},
- "outputs": [],
- "source": [
- "import random"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 266,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "fa780ff8fdab4d4cb932c4dbfe96c44f",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\n"
- ]
- }
- ],
- "source": [
- "result = []\n",
- "\n",
- "for i in tqdm(range(50)):\n",
- " c = random.randint(2,3000)\n",
- " b = 1-random.random()\n",
- " loss = create_softmax_pipeline(1,c).test_loss(300,inbalanced_input(b))\n",
- " result.append(dict(classes=c,loss=loss.item(),inbl_bias=b))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 267,
- "metadata": {},
- "outputs": [],
- "source": [
- "df = pd.DataFrame(result)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 268,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "
\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " classes | \n",
- " loss | \n",
- " inbl_bias | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 0 | \n",
- " 2244 | \n",
- " 7.784015 | \n",
- " 0.719190 | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " 2447 | \n",
- " 7.849158 | \n",
- " 0.002960 | \n",
- "
\n",
- " \n",
- " 2 | \n",
- " 706 | \n",
- " 6.625416 | \n",
- " 0.297956 | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " 272 | \n",
- " 5.739001 | \n",
- " 0.154829 | \n",
- "
\n",
- " \n",
- " 4 | \n",
- " 2860 | \n",
- " 7.994882 | \n",
- " 0.179631 | \n",
- "
\n",
- " \n",
- " 5 | \n",
- " 2376 | \n",
- " 7.823036 | \n",
- " 0.554090 | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " 474 | \n",
- " 6.230643 | \n",
- " 0.497769 | \n",
- "
\n",
- " \n",
- " 7 | \n",
- " 1476 | \n",
- " 7.317650 | \n",
- " 0.790722 | \n",
- "
\n",
- " \n",
- " 8 | \n",
- " 729 | \n",
- " 6.652312 | \n",
- " 0.671657 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " 1183 | \n",
- " 7.161001 | \n",
- " 0.422553 | \n",
- "
\n",
- " \n",
- " 10 | \n",
- " 1792 | \n",
- " 7.571034 | \n",
- " 0.375524 | \n",
- "
\n",
- " \n",
- " 11 | \n",
- " 2601 | \n",
- " 7.928347 | \n",
- " 0.937953 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " 1988 | \n",
- " 7.653401 | \n",
- " 0.448365 | \n",
- "
\n",
- " \n",
- " 13 | \n",
- " 617 | \n",
- " 6.471711 | \n",
- " 0.669720 | \n",
- "
\n",
- " \n",
- " 14 | \n",
- " 1257 | \n",
- " 7.196665 | \n",
- " 0.180724 | \n",
- "
\n",
- " \n",
- " 15 | \n",
- " 1208 | \n",
- " 7.142272 | \n",
- " 0.603882 | \n",
- "
\n",
- " \n",
- " 16 | \n",
- " 1571 | \n",
- " 7.431909 | \n",
- " 0.307233 | \n",
- "
\n",
- " \n",
- " 17 | \n",
- " 1815 | \n",
- " 7.565763 | \n",
- " 0.271957 | \n",
- "
\n",
- " \n",
- " 18 | \n",
- " 2370 | \n",
- " 7.840256 | \n",
- " 0.130841 | \n",
- "
\n",
- " \n",
- " 19 | \n",
- " 2421 | \n",
- " 7.835362 | \n",
- " 0.470347 | \n",
- "
\n",
- " \n",
- " 20 | \n",
- " 2608 | \n",
- " 7.933852 | \n",
- " 0.362477 | \n",
- "
\n",
- " \n",
- " 21 | \n",
- " 1833 | \n",
- " 7.566414 | \n",
- " 0.551423 | \n",
- "
\n",
- " \n",
- " 22 | \n",
- " 1769 | \n",
- " 7.536047 | \n",
- " 0.630381 | \n",
- "
\n",
- " \n",
- " 23 | \n",
- " 1950 | \n",
- " 7.615793 | \n",
- " 0.795910 | \n",
- "
\n",
- " \n",
- " 24 | \n",
- " 1910 | \n",
- " 7.585739 | \n",
- " 0.343632 | \n",
- "
\n",
- " \n",
- " 25 | \n",
- " 2301 | \n",
- " 7.798465 | \n",
- " 0.829628 | \n",
- "
\n",
- " \n",
- " 26 | \n",
- " 2377 | \n",
- " 7.869211 | \n",
- " 0.016988 | \n",
- "
\n",
- " \n",
- " 27 | \n",
- " 1496 | \n",
- " 7.393178 | \n",
- " 0.119690 | \n",
- "
\n",
- " \n",
- " 28 | \n",
- " 1179 | \n",
- " 7.147058 | \n",
- " 0.568997 | \n",
- "
\n",
- " \n",
- " 29 | \n",
- " 1475 | \n",
- " 7.356740 | \n",
- " 0.693435 | \n",
- "
\n",
- " \n",
- " 30 | \n",
- " 2409 | \n",
- " 7.835954 | \n",
- " 0.455461 | \n",
- "
\n",
- " \n",
- " 31 | \n",
- " 542 | \n",
- " 6.328403 | \n",
- " 0.733283 | \n",
- "
\n",
- " \n",
- " 32 | \n",
- " 699 | \n",
- " 6.609892 | \n",
- " 0.674927 | \n",
- "
\n",
- " \n",
- " 33 | \n",
- " 2966 | \n",
- " 8.062770 | \n",
- " 0.268944 | \n",
- "
\n",
- " \n",
- " 34 | \n",
- " 2831 | \n",
- " 8.006001 | \n",
- " 0.822981 | \n",
- "
\n",
- " \n",
- " 35 | \n",
- " 1657 | \n",
- " 7.459473 | \n",
- " 0.323996 | \n",
- "
\n",
- " \n",
- " 36 | \n",
- " 1469 | \n",
- " 7.316678 | \n",
- " 0.062805 | \n",
- "
\n",
- " \n",
- " 37 | \n",
- " 674 | \n",
- " 6.579655 | \n",
- " 0.812632 | \n",
- "
\n",
- " \n",
- " 38 | \n",
- " 654 | \n",
- " 6.550060 | \n",
- " 0.601652 | \n",
- "
\n",
- " \n",
- " 39 | \n",
- " 194 | \n",
- " 5.350555 | \n",
- " 0.238536 | \n",
- "
\n",
- " \n",
- " 40 | \n",
- " 250 | \n",
- " 5.567015 | \n",
- " 0.270642 | \n",
- "
\n",
- " \n",
- " 41 | \n",
- " 2368 | \n",
- " 7.819131 | \n",
- " 0.117133 | \n",
- "
\n",
- " \n",
- " 42 | \n",
- " 1573 | \n",
- " 7.429606 | \n",
- " 0.003537 | \n",
- "
\n",
- " \n",
- " 43 | \n",
- " 2994 | \n",
- " 8.053926 | \n",
- " 0.764925 | \n",
- "
\n",
- " \n",
- " 44 | \n",
- " 2155 | \n",
- " 7.741345 | \n",
- " 0.941843 | \n",
- "
\n",
- " \n",
- " 45 | \n",
- " 2642 | \n",
- " 7.919126 | \n",
- " 0.612718 | \n",
- "
\n",
- " \n",
- " 46 | \n",
- " 486 | \n",
- " 6.242527 | \n",
- " 0.701243 | \n",
- "
\n",
- " \n",
- " 47 | \n",
- " 1948 | \n",
- " 7.644366 | \n",
- " 0.970036 | \n",
- "
\n",
- " \n",
- " 48 | \n",
- " 2626 | \n",
- " 7.932363 | \n",
- " 0.657842 | \n",
- "
\n",
- " \n",
- " 49 | \n",
- " 1660 | \n",
- " 7.466812 | \n",
- " 0.855916 | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " classes loss inbl_bias\n",
- "0 2244 7.784015 0.719190\n",
- "1 2447 7.849158 0.002960\n",
- "2 706 6.625416 0.297956\n",
- "3 272 5.739001 0.154829\n",
- "4 2860 7.994882 0.179631\n",
- "5 2376 7.823036 0.554090\n",
- "6 474 6.230643 0.497769\n",
- "7 1476 7.317650 0.790722\n",
- "8 729 6.652312 0.671657\n",
- "9 1183 7.161001 0.422553\n",
- "10 1792 7.571034 0.375524\n",
- "11 2601 7.928347 0.937953\n",
- "12 1988 7.653401 0.448365\n",
- "13 617 6.471711 0.669720\n",
- "14 1257 7.196665 0.180724\n",
- "15 1208 7.142272 0.603882\n",
- "16 1571 7.431909 0.307233\n",
- "17 1815 7.565763 0.271957\n",
- "18 2370 7.840256 0.130841\n",
- "19 2421 7.835362 0.470347\n",
- "20 2608 7.933852 0.362477\n",
- "21 1833 7.566414 0.551423\n",
- "22 1769 7.536047 0.630381\n",
- "23 1950 7.615793 0.795910\n",
- "24 1910 7.585739 0.343632\n",
- "25 2301 7.798465 0.829628\n",
- "26 2377 7.869211 0.016988\n",
- "27 1496 7.393178 0.119690\n",
- "28 1179 7.147058 0.568997\n",
- "29 1475 7.356740 0.693435\n",
- "30 2409 7.835954 0.455461\n",
- "31 542 6.328403 0.733283\n",
- "32 699 6.609892 0.674927\n",
- "33 2966 8.062770 0.268944\n",
- "34 2831 8.006001 0.822981\n",
- "35 1657 7.459473 0.323996\n",
- "36 1469 7.316678 0.062805\n",
- "37 674 6.579655 0.812632\n",
- "38 654 6.550060 0.601652\n",
- "39 194 5.350555 0.238536\n",
- "40 250 5.567015 0.270642\n",
- "41 2368 7.819131 0.117133\n",
- "42 1573 7.429606 0.003537\n",
- "43 2994 8.053926 0.764925\n",
- "44 2155 7.741345 0.941843\n",
- "45 2642 7.919126 0.612718\n",
- "46 486 6.242527 0.701243\n",
- "47 1948 7.644366 0.970036\n",
- "48 2626 7.932363 0.657842\n",
- "49 1660 7.466812 0.855916"
- ]
- },
- "execution_count": 268,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "df"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### The pattern"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "There are $n$ tasks $T$ with a set of numbers of classes $C$ where $c_{i}$ is the number of the classes of task $T_{i}$\n",
- "\n",
- "The number of classes $c_{i}$ has certain correlation of the average loss $L_{i}$\n",
- "\n",
- "Where $log_{10}(c_{i})$ has a clear linear relationship with $L$\n",
- "\n",
- "$L_{i} = a.log_{10}(c_{i})$, where a is a fixed constant"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 269,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- ""
- ]
- },
- "execution_count": 269,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAVw0lEQVR4nO3dfZBd913f8ffXqzWs8rQmFiTaSBa0QW0Tjy13a8vNlAl1iMhDE8EorV1MwMOMcGBC3BYNces6ENwSum0hqYcIDxkaiOu6qMrWMMYy0xDIMCMxa0u24jhiRB4krQLIISvX9kLWq2//2Hudq7v34dzVfTz7fs3s6N5zftr9/sbWR0ff8zu/G5mJJGn0XTboAiRJ3WGgS1JJGOiSVBIGuiSVhIEuSSWxYVA/+Morr8xt27YN6sdL0kh67LHHnsnMTY3ODSzQt23bxtzc3KB+vCSNpIj4arNztlwkqSQMdEkqCQNdkkrCQJekkjDQJakkDHRJKomBLVuUpLKbPTrPzKETnF1YZPPkBPt2bWf3jqme/Tyv0CWpB2aPznPnwePMLyySwPzCInc8eIwdH36U2aPzPfmZBrok9cDMoRMsLi2vOv6NF5a48+DxnoS6gS5JPXB2YbHpucWlZWYOnej6z7SHLkldUN8vf9XEOAuLS03Htwr8tSp0hR4R/yoinoqIz0fEAxHx7XXnvy0iHoyIkxFxJCK2db1SSRpSjfrlrcIcYPPkRNfraBvoETEF/AwwnZlvBMaAm+uG/QTwjcz8u8CvAL/c7UIlaVg165c3MzE+xr5d27teR9Ee+gZgIiI2ABuBs3Xn3w18svL6AHBTRER3SpSk4daufTI5Mc7U5AQBTE1O8Es/fHVPli+27aFn5nxE/GfgFLAIPJqZj9YNmwJOV8a/GBHngVcDz9QOioi9wF6ArVu3Xnr1kjQENk9OMN8i1M8vLnHsQ2/teR1FWi5XsHIF/t3AZuBlEXFr/bAGvzVXHci8LzOnM3N606aG+7NL0sjZt2s7E+NjTc/3ol/eSJFVLm8BvpyZ5wAi4iDwj4FP1Yw5A2wBzlTaMq8C/rrLtUrSQNw1e5z7D5966Sr1ZZeP8R9+6Fttk+qvv/C7T/GNFy6+GdqrfnkjRXrop4CdEbGx0he/CXi6bsxDwI9VXu8BPpOZq67QJWnU3DV7nE/VhDnA899c5o4Hj3HX7PGXju3eMcXRu9/Kr/6La/vSL2+kSA/9SEQcAB4HXgSOAvdFxIeBucx8CPgE8NsRcZKVK/P6VTCSNJIeOHK66bn7D59i+qrvuCiwd++Y6luA1yv0YFFmfgj4UN3hu2vO/w3wni7WJUlDYblFsyFZWbI4qACv55OiktatIrshjkW0DPVePPG5Vu7lImldavR0Z6NNs265YUvL79OvFSxFGOiS1qVGT3cuLq3c7HzTRz7zUrDfs/tqbt3Z+LmZfq5gKcJAl7QutXoQqP5q/Z7dV/OVj7xjoCtYirCHLmldatcbr25xOywrWIow0CWtK9Uboa3CvGqYbngWYaBLKr1qiM8vLBI02JekiWG64VmEgS6p1KqrWao3QDt5hH2YbngW4U1RSaXW6V7lVW/6O98x1P3yRrxCl1Qq9Q8LtVrN0sytO7dyz+6re1Bdbxnokkqjvr3SSc98LIJbbtgykkFeZaBLKo1G7ZWEVaFefT/V5HH/UWWgSyqNZssMq+Hdas+WMjDQJZVGs5751OQEf/LBfzqAivrLVS6SSqPRR8EN234rveQVuqTSqLZR2m2JW1YGuqShV7sUcXLjOJlwfnGpYWAP+34rvWSgSxpq9UsRaz+EuborIrBuQ7yWPXRJQ63dk57VXRFloEsackV2PBy1XRF7xUCXNNSK7Hg4arsi9oqBLmlozR6d5/m/fbHlmPW0LLEdb4pKGqhme5VvHL+MpQvJ0nLznVjGIobuY+AGyUCXNDCt9ip/YelCy987MT5mmNcx0CUNzM8/9NSa9iov26Za3WKgSxqI2aPzLCwutR9YZ73sy7IW3hSVNBBrWTvuDdDW2gZ6RGyPiGM1X89GxB11Y94cEedrxtzdu5IllUGRteOXBVyxcZxg5crcnnlrbVsumXkCuBYgIsaAeeDTDYZ+LjPf2d3yJJVVu4+Hu2LjOB/6Z28wwDvQaQ/9JuDPM/OrvShGUvnUf8Zn9Wbmvl3bL1rhAq5cuVSdBvrNwANNzt0YEU8AZ4Gfzcyn6gdExF5gL8DWrVs7/NGSRs3s0Xn2HXjipbXk8wuL7DvwBOBWt70QmUU+PhUi4nJWwvoNmfmXdedeCVzIzOci4u3ARzPz9a2+3/T0dM7Nza2xbEmjYMeHH71od8SqKzaOc/Tutw6gotEXEY9l5nSjc51cob8NeLw+zAEy89ma1w9HxK9FxJWZ+Uzn5UoaRXfNHueBI6dZzmQsgltu2NIwzIGmx3VpOlm2eAtN2i0R8ZqIiMrr6yvf9+uXXp6kUXDX7HE+dfgUy5V/8S9n8qnDpwZc1fpT6Ao9IjYCPwD8ZM2x2wEycz+wB3hfRLwILAI3Z9FejqSR98CR0x2Nn5wY71El61uhQM/MF4BX1x3bX/P6XuDe7pYmaRg1WrWy3OL6bfyyYOlCXvT+59/1hn6Uuu74pKikwqqbac0vLJJ86yPgVhquq41FMPOea5ianHjp4aCZ91zjSpYecS8XSYU12kxrcWmZjeOXNdwd8ZYbtqzrD23uN6/QJRXSajOtxaUL3LpzK2OVS/WxCG7duZV7dl/dzxLXPa/QJRXSajOtzZMT3LP7agN8wLxCl1RIq8203AFxOBjokgqZGG8cF5ePhT3yIWGgSypk8cXGHwlXuyRRg2WgSyqk2VJzHyEcHga6pELGmiw2b3Zc/WegSyrklhu2dHRc/eeyRUmFVJck1u+o6FLF4VF4P/Rucz90Sepcq/3QbblIUknYcpHWgWaf66lyMdClkqvukFjdVKu6QyJgqJeMgS6VSKMr8ZlDJxrukDhz6ISBXjIGulQSd80e5/7Dp6guc6heideHeVWrvVk0mrwpKpXA7NH5i8K8qlmYw8oOiSoXA10qgZlDJ1aFeSsT42PukFhCtlykEijSPhmL4EKmq1xKzECXSmDz5ATzbUL9QiZf/sg7+lSRBsGWi1QC+3ZtZ2J8rOUYe+bl5xW6VALV9snMoRPMLywScFFP3Z75+mCgSyWxe8fUS8Huk6Hrk4EulVBtuGv9sIcuSSVhoEtSSbQN9IjYHhHHar6ejYg76sZERHwsIk5GxJMRcV3vSpYkNdK2h56ZJ4BrASJiDJgHPl037G3A6ytfNwAfr/wqSeqTTlsuNwF/nplfrTv+buC3csVhYDIiXtuVCiVJhXQa6DcDDzQ4PgWcrnl/pnLsIhGxNyLmImLu3LlzHf5oSVIrhQM9Ii4H3gX8TqPTDY6t2isoM+/LzOnMnN60aVPxKiVJbXVyhf424PHM/MsG584AW2revw44eymFSZI600mg30LjdgvAQ8B7K6tddgLnM/Nrl1ydJKmwQk+KRsRG4AeAn6w5djtAZu4HHgbeDpwEXgBu63qlkqSWCgV6Zr4AvLru2P6a1wn8dHdLkyR1widFJakkDHRJKgkDXZJKwkCXpJIw0CWpJAx0SSoJA12SSsJAl6SSMNAlqSQMdEkqCQNdkkrCQJekkjDQJakkDHRJKgkDXZJKwkCXpJIo9AEXUhnNHp1n5tAJzi4ssnlygn27trN7x9Sgy5LWzEDXujR7dJ47Dx5ncWkZgPmFRe48eBzAUNfIMtC1btRekV8WwXLmRecXl5aZOXTCQNfIMtC1Ltw1e5z7D5+iGuH1YV51dmGxf0VJXeZNUZXe7NH5i8K8lc2TEz2vR+oVr9BVSvXtlSJhPjE+xr5d23tem9QrBrpKp/6GZ7P2CsBYBBcyXeWiUjDQVRrVq/L5gn3wAP7LP7/GEFdpGOgqhfqr8nYC+JGdWw1zlYqBrlKYOXSibZjbXlHZGegqhXbLDSfGx/ilH77aEFepFVq2GBGTEXEgIr4YEU9HxI11598cEecj4ljl6+7elCs11mq54dTkhGGudaHoFfpHgUcyc09EXA5sbDDmc5n5zu6VJhW3b9f2VT10r8q13rQN9Ih4JfB9wI8DZOY3gW/2tiypsWYbalVD2822tJ4VuUL/HuAc8JsRcQ3wGPCBzHy+btyNEfEEcBb42cx8qv4bRcReYC/A1q1bL6lwrT/tNtSqDXZpPSrSQ98AXAd8PDN3AM8DH6wb8zhwVWZeA/w3YLbRN8rM+zJzOjOnN23adAllaz1qtJKluqGWpGKBfgY4k5lHKu8PsBLwL8nMZzPzucrrh4HxiLiyq5Vq3Wu2ksUNtaQVbQM9M/8COB0R1U0ubgK+UDsmIl4TEVF5fX3l+369y7VqnWu2ksUNtaQVRXdbfD9wf0Q8CVwL/MeIuD0ibq+c3wN8vtJD/xhwc2aLDTSkNdi3azsT42MXHXNDLelbYlC5Oz09nXNzcwP52Rpdfmyc1ruIeCwzpxud80lRjRRXskjN+QEXklQSBroklYQtFw2EvXCp+wx09V27Jz4lrY0tF/WdT3xKvWGgq+984lPqDQNdfecTn1JvGOjqO5/4lHrDm6LqO/cul3rDQNclqS4/nF9YZCyC5UymCgS0T3xK3Wega83qlx8uV/YFchmiNBj20LVmjZYfVrkMUeo/A11r1m6ZocsQpf4y0LVm7ZYZugxR6i8DXWvWaPlhlcsQpf7zpqjWrHb5YaerXCR1n4GuS+LyQ2l42HKRpJIw0CWpJAx0SSoJA12SSsJAl6SSMNAlqSQMdEkqCQNdkkrCQJekkigU6BExGREHIuKLEfF0RNxYdz4i4mMRcTIinoyI63pTriSpmaKP/n8UeCQz90TE5cDGuvNvA15f+boB+HjlV0lSn7S9Qo+IVwLfB3wCIDO/mZkLdcPeDfxWrjgMTEbEa7terSSpqSItl+8BzgG/GRFHI+I3IuJldWOmgNM1789Ujl0kIvZGxFxEzJ07d27NRUuSVisS6BuA64CPZ+YO4Hngg3VjosHvy1UHMu/LzOnMnN60aVPHxUqSmisS6GeAM5l5pPL+ACsBXz9mS8371wFnL708SVJRbQM9M/8COB0R1Y+fuQn4Qt2wh4D3Vla77ATOZ+bXuluqJKmVoqtc3g/cX1nh8iXgtoi4HSAz9wMPA28HTgIvALf1oFZJUguFAj0zjwHTdYf315xP4Ke7WJckqUM+KSpJJWGgS1JJGOiSVBIGuiSVhIEuSSVhoEtSSRjoklQSBroklYSBLkklYaBLUkkU3ctFBcwenWfm0AnOLiyyeXKCfbu2s3vHqm3hJaknDPQumT06z50Hj7O4tAzA/MIidx48DmCoS+oLWy5dMnPoxEthXrW4tMzMoRMDqkjSemOgd8nZhcWOjktStxnoXbJ5cqKj45LUbQZ6l+zbtZ2J8bGLjk2Mj7Fv1/Ymv0OSusubol1SvfHpKhdJg2Kgd9HuHVMGuKSBseUiSSVhoEtSSRjoklQSBroklYQ3RQtwjxZJo8BAb8M9WiSNClsubbhHi6RRYaC34R4tkkaFgd6Ge7RIGhWFAj0ivhIRxyPiWETMNTj/5og4Xzl/LCLu7n6pg+EeLZJGRSc3Rb8/M59pcf5zmfnOSy1o2LhHi6RR4SqXAtyjRdIoKNpDT+DRiHgsIvY2GXNjRDwREb8fEW9oNCAi9kbEXETMnTt3bk0FS5IaK3qF/qbMPBsR3wn8QUR8MTP/uOb848BVmflcRLwdmAVeX/9NMvM+4D6A6enpvMTaJUk1Cl2hZ+bZyq9/BXwauL7u/LOZ+Vzl9cPAeERc2eVaJUkttA30iHhZRLyi+hp4K/D5ujGviYiovL6+8n2/3v1yJUnNFGm5fBfw6UpebwD+R2Y+EhG3A2TmfmAP8L6IeBFYBG7OTFsqktRHbQM9M78EXNPg+P6a1/cC93a3NElSJ3xSVJJKwkCXpJIY2QeL3KNcki42koHuHuWStNpItlzco1ySVhvJQHePcklabSQD3T3KJWm1kQx09yiXpNVG8qaoe5RL0mojGejgHuWSVG8kWy6SpNUMdEkqCQNdkkrCQJekkjDQJakkDHRJKokY1AcLRcQ54KsD+eGX5krgmUEX0QVlmQeUZy5lmQeUZy7DOI+rMnNToxMDC/RRFRFzmTk96DouVVnmAeWZS1nmAeWZy6jNw5aLJJWEgS5JJWGgd+6+QRfQJWWZB5RnLmWZB5RnLiM1D3voklQSXqFLUkkY6JJUEgZ6AxGxJSL+MCKejoinIuIDLcb+o4hYjog9/ayxiKLziIg3R8Sxypg/6nedRRSZS0S8KiJ+NyKeqIy5bRC1thIR3x4Rf1pT4y80GPNtEfFgRJyMiCMRsa3/lbZXcC7/OiK+EBFPRsT/jYirBlFrK0XmUTN2T0RkRAznUsbM9KvuC3gtcF3l9SuAPwP+QYNxY8BngIeBPYOuey3zACaBLwBbK++/c9B1X8Jc/i3wy5XXm4C/Bi4fdO11NQbw8srrceAIsLNuzE8B+yuvbwYeHHTdlzCX7wc2Vl6/bxjnUmQeNf/f/TFwGJgedN2NvrxCbyAzv5aZj1de/z/gaaDRp2m8H/jfwF/1sbzCCs7jXwIHM/NUZdwozyWBV0REAC9nJdBf7GuhbeSK5ypvxytf9SsT3g18svL6AHBTZU5DpchcMvMPM/OFytvDwOv6WGIhBf+bAPwi8J+Av+lXbZ0y0Nuo/HN3Byt/a9cenwJ+CNjf/6o612wewPcCV0TEZyPisYh4b79r61SLudwL/H3gLHAc+EBmXuhrcQVExFhEHGPlQuAPMrN+HlPAaYDMfBE4D7y6v1UWU2AutX4C+P3+VNaZdvOIiB3Alsz8vYEUWJCB3kJEvJyVK/A7MvPZutO/CvxcZi73v7LOtJnHBuAfAu8AdgH/PiK+t88lFtZmLruAY8Bm4Frg3oh4ZZ9LbCszlzPzWlauVq+PiDfWDWl0NT6U64sLzAWAiLgVmAZm+llfUa3mERGXAb8C/JtB1VeUgd5ERIyzEhz3Z+bBBkOmgf8ZEV8B9gC/FhG7+1hiIQXmcQZ4JDOfz8xnWOkRXtPPGosqMJfbWGkfZWaeBL4M/L1+1tiJzFwAPgv8YN2pM8AWgIjYALyKlfbR0GoxFyLiLcC/A96VmX/b59I60mQerwDeCHy28ud9J/DQMN4YNdAbqPQrPwE8nZn/tdGYzPzuzNyWmdtY6XP+VGbO9rHMtorMA/g/wD+JiA0RsRG4gZX+9FApOJdTwE2V8d8FbAe+1J8Ki4mITRExWXk9AbwF+GLdsIeAH6u83gN8Jit35YZJkblUWhW/zkqYD+X9mXbzyMzzmXllzZ/3w6zMZ24gBbewYdAFDKk3AT8KHK/01WBlBcVWgMwcib45BeaRmU9HxCPAk8AF4Dcy8/MDqba1Iv9NfhH47xFxnJW2xc9V/tUxTF4LfDIixli5oPpfmfl7EfFhYC4zH2LlL67fjoiTrFyZ3zy4clsqMpcZVm5Q/07lvu6pzHzXwCpurMg8RoKP/ktSSdhykaSSMNAlqSQMdEkqCQNdkkrCQJekkjDQJakkDHRJKon/D8XLOB7FzJCYAAAAAElFTkSuQmCC\n",
- "text/plain": [
- ""
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "plt.scatter(np.log10(df.classes),df.loss,)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### The solution\n",
- "\n",
- "Assume we have certain function, will produce a weight that will bring each cross entropy loss to the same constant scale ${a}$.\n",
- "\n",
- "$L_{i}.f(c_{i})=a$\n",
- "\n",
- "$f(c_{i})a.log_{10}(c_{i})=a$\n",
- "\n",
- "Here we can get how to calculate $f(c_{i})$\n",
- "\n",
- "$f(c_{i})=\\frac{1}{log_{10}(c_{i})}$"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 270,
- "metadata": {},
- "outputs": [],
- "source": [
- "def adjust(nb_class):\n",
- " return 1/np.log10(nb_class)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 271,
- "metadata": {},
- "outputs": [],
- "source": [
- "df[\"lambda_weighted\"] = df.apply(\n",
- " lambda row:row['loss']*adjust(row['classes']),\n",
- " axis=1)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "For now it's a about the same scale of loss"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 272,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- ""
- ]
- },
- "execution_count": 272,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAa3ElEQVR4nO3df4ycx33f8fcn8lG+iHQpmZdUOpOhTCWinYoQmWuQlrJRq4Ho2EBIWUbotpDkKADbugWkQiZAK0FqVwgslrXgFi3CMlAQKyEcKxbJ0mVTmqCYsHIrxkfyTIo8MZJ/KBJJWGfJNMX66hypb//YOWu13r199m5/PDv3eQGH25uZ3ZvZ59nvPs/MPM8oIjAzs3z9VK8rYGZmneVAb2aWOQd6M7PMOdCbmWXOgd7MLHNv63UFai1ZsiSWL1/e62qYmfWVo0ePfi8ihurllS7QL1++nNHR0V5Xw8ysr0h6sVGeu27MzDLnQG9mljkHejOzzDnQm5llzoHezCxzDvRmZplzoDczy5wDvZlZ5poGeklLJR2SNC7plKT765RZL+mEpDFJo5JuS+kfSGnTP/9P0oZONMTMzOorcmXsZeDBiDgmaRFwVNKBiDhdVeYgsDciQtIq4AlgZUQcAm4FkHQd8ALw1fY2wczMZtL0iD4izkfEsfT4dWAcGK4pcyneXKrqGqDeslUfBf48In44tyqbmVkrWuqjl7QcWA0cqZN3p6TngH3AfXWe/jHgiw1ed1Pq8hmdmJhopUpmZtZE4UAvaSHwJPBARFyszY+I3RGxEtgAPFzz3OuBW4D99V47InZExEhEjAwN1b35mpmZzVKhQC9pgEqQ3xkRu2YqGxGHgRWSllQl/wawOyKmZl1TMzOblSKzbgQ8BoxHxKMNytyUyiFpDbAAeLWqyD+hQbeNmZl1VpFZN2uBu4GTksZS2kPAMoCI2A7cBdwjaQqYBDZOD86mfv2lwF+2teZmZlZI00AfEU8DalJmK7C1Qd53qJmlY2Zm3eMrY83MMudAb2aWOQd6M7PMOdCbmWXOgd7MLHMO9GZmmXOgNzPLnAO9mVnmHOjNzDLnQG9mljkHejOzzDnQm5llzoHezCxzDvRmZplzoDczy5wDvZlZ5hzozcwyV2TN2KWSDkkal3RK0v11yqyXdELSmKRRSbdV5S2T9NX0/NNpaUEzM+uSImvGXgYejIhjkhYBRyUdiIjTVWUOAnsjIiStAp4AVqa8x4Hfi4gDkhYCb7SzAWZmNrOmR/QRcT4ijqXHrwPj1KwBGxGXphcDB64BphcGfy/wtog4UFXuh22sv5mZNdFSH33qdlkNHKmTd6ek54B9wH0p+ReAC5J2STouaZukq+o8d1Pq8hmdmJhotQ1mZjaDwoE+dbs8CTwQERdr8yNid0SsBDYAD6fktwHvAz4J/H3g3cDH6zx3R0SMRMTI0NBQy40wM7PGCgV6SQNUgvzOiNg1U9mIOAyskLQEeBk4HhHfiojLwB5gzRzrbGZmLSgy60bAY8B4RDzaoMxNqRyS1gALgFeBrwPXSpo+TL8dOF3vNczMrDOKzLpZC9wNnJQ0ltIeApYBRMR24C7gHklTwCSwMQ3OXpH0SeBg+iI4CvxBm9tQyJ7jZ9m2/wznLkxyw+JBNq+7mQ2rh5s/0cysz+nNyTLlMDIyEqOjo219zT3Hz/KpXSeZnLry47TBgav47EducbA3syxIOhoRI/Xy5sWVsdv2n3lLkAeYnLrCtv1nelQjM7PumReB/tyFyZbSzcxyMi8C/Q2LB1tKNzPLybwI9JvX3czgwFuv0xocuIrN627uUY3MzLqnyKybvjc94OpZN2Y2H82LQA+VYO/Abmbz0bzoujEzm88c6M3MMudAb2aWOQd6M7PMOdCbmWXOgd7MLHMO9GZmmXOgNzPLnAO9mVnmHOjNzDLnQG9mlrkia8YulXRI0rikU5Lur1NmvaQTksYkjUq6rSrvSkofk7S33Q0wM7OZFbmp2WXgwYg4JmkRcFTSgYioXuT7ILA3IkLSKuAJYGXKm4yIW9tbbTMzK6rpEX1EnI+IY+nx68A4MFxT5lK8ufjsNUC5FqI1M5vHWuqjl7QcWA0cqZN3p6TngH3AfVVZb0/dOc9I2tDgdTelMqMTExOtVMnMzJooHOglLQSeBB6IiIu1+RGxOyJWAhuAh6uylqWVyf8p8HlJK+o8d0dEjETEyNDQUMuNMDOzxgoFekkDVIL8zojYNVPZiDgMrJC0JP19Lv3+FvAXVM4IzMysS4rMuhHwGDAeEY82KHNTKoekNcAC4FVJ10q6OqUvAdYCp+u9hpmZdUaRWTdrgbuBk5LGUtpDwDKAiNgO3AXcI2kKmAQ2phk47wH+q6Q3qHypPFIzW8fMzDqsaaCPiKcBNSmzFdhaJ/1/A7fMunZmZjZnvjLWzCxzDvRmZplzoDczy5wDvZlZ5hzozcwy50BvZpY5B3ozs8w50JuZZc6B3swscw70ZmaZc6A3M8ucA72ZWeYc6M3MMudAb2aWOQd6M7PMOdCbmWXOgd7MLHNF1oxdKumQpHFJpyTdX6fMekknJI1JGpV0W03+OySdlfSf21l5MzNrrsiasZeBByPimKRFwFFJB2rWfj0I7E3rxK4CngBWVuU/DPxl22ptZmaFNT2ij4jzEXEsPX4dGAeGa8pciohIf14DTD9G0i8BPwt8tV2VNjOz4lrqo5e0HFgNHKmTd6ek54B9wH0p7aeAzwGbm7zuptTlMzoxMdFKlczMrInCgV7SQuBJ4IGIuFibHxG7I2IlsIFKVw3AJ4D/EREvzfTaEbEjIkYiYmRoaKh47avsOX6WtY88xY1b9rH2kafYc/zsrF7HzCw3RfrokTRAJcjvjIhdM5WNiMOSVkhaAvwD4H2SPgEsBBZIuhQRW+Za8Wp7jp/lU7tOMjl1BYCzFyb51K6TAGxYPTzTU83Msldk1o2Ax4DxiHi0QZmbUjkkrQEWAK9GxD+LiGURsRz4JPB4u4M8wLb9Z34c5KdNTl1h2/4z7f5XZmZ9p8gR/VrgbuCkpLGU9hCwDCAitgN3AfdImgImgY1Vg7Mdd+7CZEvpZmbzSdNAHxFPA2pSZiuwtUmZPwL+qIW6FXbD4kHO1gnqNywe7MS/MzPrK1lcGbt53c0MDlz1lrTBgavYvO7mHtXIzKw8Cg3Glt30gOu2/Wc4d2GSGxYPsnndzR6INTMjk0APlWDvwG5m9pOy6LoxM7PGHOjNzDLnQG9mljkHejOzzDnQm5llzoHezCxzDvRmZplzoDczy5wDvZlZ5hzozcwy50BvZpY5B3ozs8xlc1MzM+sve46f9R1nu8SB3sy6zus8d1eRNWOXSjokaVzSKUn31ymzXtIJSWOSRiXdltJ/TtLRlH5K0r/oRCOsv+w5fpa1jzzFjVv2sfaRp9hz/Gyvq2Rd5nWeu6vIEf1l4MGIOCZpEXBU0oGIOF1V5iCwNyJC0irgCWAlcB74hxHxI0kLgWcl7Y2Ic+1uiPUHH8kZeJ3nbmt6RB8R5yPiWHr8OjAODNeUuVS1GPg1QKT0v42IH6X0q4v8P8ubj+QMGq/n7HWeO6OlwCtpObAaOFIn705JzwH7gPuq0pdKOgG8BGytdzQvaVPq8hmdmJhorQXWV3wkZ+B1nrut8GBs6np5EnggIi7W5kfEbmC3pPcDDwO/mtJfAlZJugHYI+nLEfHdmufuAHYAjIyMBH3EMwdac8PiQc7WCeq9OJLztusdr/PcXYUCvaQBKkF+Z0TsmqlsRByWtELSkoj4XlX6OUmngPcBX55LpcvC/c2t27zu5re8Z9CbIzlvu97zOs/dU2TWjYDHgPGIeLRBmZtSOSStARYAr0p6l6TBlH4tsBbIpjPW/c2t27B6mM9+5BaGFw8iYHjxIJ/9yC1d/8B729l8UuSIfi1wN3BS0lhKewhYBhAR24G7gHskTQGTwMY0A+c9wOckBSDgP0TEyXY3olfc3zw7ZTiS87az+aRpoI+Ip6kE6ZnKbAW21kk/AKyade1Krkz9zdaaMmw7jxFYt3i64xx45sDslOGCqV5vu+kxgrMXJgneHCPwxWPWCb4Fwhx0YuZA7kd5ZRkE7fWsj5nGCHLa3lYOevM6p3IYGRmJ0dHRXlejJ2qD4LTFgwN8+td/MYsAsPaRp+p2mQwvHuRrW27vQY1648Yt+6j3yRPw7Uc+3O3qWAYkHY2IkXp57ropkXpHeQAXJqf4N18aY3kG94bxIGiFrwy1bnKgL5GZgt300V+/9+X2U4Dr5FhCr8cIbH5xoC+RosGun+d790uA6/RgaVmuJ7D5wYOxJVLvqtFG+rWro9eDoEV1Y7C0DNcT2PzgQF8i0x/6z3zlFN//4dSMZcvY1VFUPwS4Rl+k9QaSzcrOXTdt0M6+3A2rhzn+u3fw+Y23MjxDMD97YbLvB2bLrNEXqaBU73kZrkmw8nOgn6NO9eVuWD3M17bczuc33srAVfUvTO73gdky27zu5rqXgweUZnykny+68hdUdznQz1Gnb461bf8Zpq40vtahnwdmy2zD6uG689yhPOMjZbkxW6tBu5+/oPqV++jnqNPzwou8TlkCT26GS3A/nJl0et8rcpX2bK507uZVwblfaV6Uj+jnqNPzwou8TlkCT7uU5bS+7FNBO7nvFT3qns1ZRbcumiv7mUM39/PsA32n38xOB4N6r9+p/1UG3fhwFt0nujnXfTb7aSf3vaIBfDZBu1sXzX3mK6dK0bVVT7e/hLLuuunGDbQ6PS+89vUX//QAEfCDyaksT0U7fVrf6j7Rjamgs91PO7nvFQ3gs7ndczdWGdtz/GzDKcpl6Ors9k3tsg703XozOx0M+mHeebt0+rS+jHeNnEudOrVvFA3gswna3bhobqaj9jJ0dXb7nk9ZB3rfQKv/dHpBkDLsE7UDhI0uwurlflo0gM82aHf64GWm964MXZ3dXvimaaCXtBR4HPi7wBvAjoj4jzVl1gMPp/zLwAMR8bSkW4HfB94BXAF+LyK+1N4mNFaGVYSsNZ0+re/1PlGvm0ZQdyrnbOvUjpkmrQTwMp5xNtrOiwcHSlHXbnRfVWt6P3pJ1wPXR8QxSYuAo8CGiDhdVWYh8H/TOrGrgCciYqWkXwAiIp6XdEN67nsi4kKj/9fO+9HXu7/74MBVvnlUyXVySlyv94lG9+OvDfazrVOv21dbl15Nbaz3Pky/x8MlGdtq9/sz0/3oi6wZex44nx6/LmkcGAZOV5W5VPWUa0j7bET8dVWZc5JeAYaAhoG+nfrhBlqe5/uTOnmE2Ot9olGXwnQAmmudyjIG0euVxKq3c+1ZUzvrMpfPbzfPhFpaYUrScuAw8Pci4mJN3p3AZ4GfAT4cEf+nJv+XgS8AvxgRb9TkbQI2ASxbtuyXXnzxxZYb0o/KdPRl3dHpFbbKsnJVmVYS61Rdyvb5bcsKU6l75kkq/e8Xa/MjYndErAQ2UOmvr37u9cAfA79ZG+TTc3dExEhEjAwNDRWtUt8ryyXsNjetzIFvNvd9rtd9lGVhlzIMejf7n3OtSz99fgsFekkDVIL8zojYNVPZiDgMrJC0JD33HcA+4Hci4pk51jcrZfow2Oy0euHLTBdhteMimrJczVuWL5yZ/udc69JPn9+mgV6SgMeA8Yh4tEGZm1I5JK0BFgCvSloA7AYej4g/a1+181CmD4PNzmyO6qbvTPrtRz7M17bc/pb+5LkeIZZl5aqyfOF0si799PktMo9+LXA3cFLSWEp7CFgGEBHbgbuAeyRNAZPAxjQD5zeA9wPvlPTx9NyPR8QY1vUpVmXU74PR7Tyqa9drlWG6Y5FB725t+04NwPfT57elwdhuaOf0yn7Q74FuLso2mDUb7RzoK9MAZqf9zp6T7Hzmb94ycFy26Y9FlOnzO6fpldZZZTj66pWyTAWciw+sHOJPnvmbuumt6qcjxLnYc/zsTwR56Mz0x05r1+e3018YDvTWM/00mNXIoecmWkqfSa/n+HfLtv1nGi7qMm1y6gqf+cqp7N8L6M41Bw701jO9vh1BO7T7y2o+nOEVfW++/8OpH9+Bsp+O8lvVjTPb7O9Hb+VVppkZs9VPMy/KYrbvTVnnqM9VN85sHeitZ8oyFXAucviy6rZmi+nMpJ+69YrqxsGCu246qEwj8mXV710VZehX77f9rN579oGVQ3zxyEtcaTILMMczpW4Mwnt6ZYfkMHXQyi+n/azRfXqm9Wu7imjHl7WnV/ZADlMHrfxy2s9mWoSlXXPry3r20+kzWwf6Dslh6qCVX077WaMujHYdxff61sm95MHYDvFsDOuGnPazTg/Od/puk3O982gn+Yi+Q+bLVY7WW7ntZ53swujk2U/ZzxZ8RN8hOUwdtPLzflZcJ89+yn5veh/Rd1C/Tx208mk0mOj9rLlOnv2UfazEgd6sT5S9e6DsOnnNQ9lv5+FAb9YncppK2SudOvsp+1iJA71Znyh798B8VoYrpGfiQG/WJ8rePTDflXmspMiasUslHZI0LumUpPvrlFkv6YSkMUmjkm6ryvufki5I+u/trrxZq8o817kZ30DNZqvIEf1l4MGIOCZpEXBU0oGIOF1V5iCwN60Tuwp4AliZ8rYBPw3883ZW3KxV/T6YWfbuASuvpoE+Is4D59Pj1yWNA8PA6aoyl6qecg1vrgpGRByU9I/aVWGz2cphMLPM3QNWXi1dMCVpObAaOFIn705JzwH7gPvaUTmzdmo0aNnoRlpmuSgc6CUtBJ4EHoiIi7X5EbE7IlYCG4CHW6mEpE2pb390YqL1tTbNimg0aCnoq756s1YVCvSSBqgE+Z0RsWumshFxGFghaUnRSkTEjogYiYiRoaGhok8za8nmdTejOukBpblU3awTisy6EfAYMB4RjzYoc1Mqh6Q1wALg1XZW1GyuNqwebriwheeiW86KzLpZC9wNnJQ0ltIeApYBRMR24C7gHklTwCSwMdLSVZL+F5UZOAslvQz8VkTsb28zzIoZ9lx0m4eKzLp5Guqe8VaX2QpsbZD3vtlVzaz9yn6pulkn+MpYm1c8F93mIwd6m3c8F93mGy88YmaWOQd6M7PMOdCbmWXOgd7MLHMO9GZmmXOgNzPLnAO9mVnmHOjNzDLnQG9mljkHejOzzDnQm5llzoHezCxzDvRmZplzoDczy5wDvZlZ5oqsGbtU0iFJ45JOSbq/Tpn1kk5IGpM0Kum2qrx7JT2ffu5tdwPMzGxmRRYeuQw8GBHHJC0Cjko6EBGnq8ocBPZGREhaBTwBrJR0HfBvgREg0nP3RsT329wOMzNroOkRfUScj4hj6fHrwDgwXFPm0vRi4MA1VII6wDrgQES8loL7AeCD7aq8mZk111IfvaTlwGrgSJ28OyU9B+wD7kvJw8BLVcVepuZLwszMOqtwoJe0EHgSeCAiLtbmR8TuiFgJbAAenn5anZeK2gRJm1Lf/ujExETRKpmZWQGFAr2kASpBfmdE7JqpbEQcBlZIWkLlCH5pVfa7gHN1nrMjIkYiYmRoaKhw5c3MrLkis24EPAaMR8SjDcrclMohaQ2wAHgV2A/cIelaSdcCd6Q0MzPrkiKzbtYCdwMnJY2ltIeAZQARsR24C7hH0hQwCWxMg7OvSXoY+Hp63r+LiNfa2QAzM5uZ3pwsUw4jIyMxOjra62qYmfUVSUcjYqRenq+MNTPLnAO9mVnmHOjNzDLnQG9mljkHejOzzDnQm5llzoHezCxzDvRmZplzoDczy5wDvZlZ5hzozcwy50BvZpY5B3ozs8w50JuZZc6B3swscw70ZmaZc6A3M8tc6VaYkjQBvNjreszREuB7va5EB+TaLsi3bW5X/5lt234uIobqZZQu0OdA0mijJb36Wa7tgnzb5nb1n060zV03ZmaZc6A3M8ucA31n7Oh1BTok13ZBvm1zu/pP29vmPnozs8z5iN7MLHMO9GZmmXOgnwVJ35F0UtKYpNGUdp2kA5KeT7+vTemS9J8kvSDphKQ1va39W0n6Q0mvSHq2Kq3ltki6N5V/XtK9vWhLtQbt+rSks2m7jUn6UFXep1K7zkhaV5X+wZT2gqQt3W5HLUlLJR2SNC7plKT7U3pfb7MZ2pXDNnu7pL+S9I3Uts+k9BslHUnv/5ckLUjpV6e/X0j5y6teq26bm4oI/7T4A3wHWFKT9u+BLenxFmBrevwh4M8BAb8CHOl1/Wvq/X5gDfDsbNsCXAd8K/2+Nj2+toTt+jTwyTpl3wt8A7gauBH4JnBV+vkm8G5gQSrz3h6363pgTXq8CPjrVP++3mYztCuHbSZgYXo8ABxJ2+IJ4GMpfTvwL9PjTwDb0+OPAV+aqc1F6uAj+vZZD3whPf4CsKEq/fGoeAZYLOn6XlSwnog4DLxWk9xqW9YBByLitYj4PnAA+GDna99Yg3Y1sh7404j4UUR8G3gB+OX080JEfCsi/hb401S2ZyLifEQcS49fB8aBYfp8m83Qrkb6aZtFRFxKfw6knwBuB76c0mu32fS2/DLwjyWJxm1uyoF+dgL4qqSjkjaltJ+NiPNQ2WmBn0npw8BLVc99mZl34DJotS391MZ/nbow/nC6e4M+bVc6pV9N5Qgxm21W0y7IYJtJukrSGPAKlS/VbwIXIuJyKlJdzx+3IeX/AHgnc2ibA/3srI2INcCvAf9K0vtnKKs6af06p7VRW/qljb8PrABuBc4Dn0vpfdcuSQuBJ4EHIuLiTEXrpJW2bXXalcU2i4grEXEr8C4qR+HvqVcs/W572xzoZyEizqXfrwC7qWy47053yaTfr6TiLwNLq57+LuBc92o7K622pS/aGBHfTR+4N4A/4M3T3r5ql6QBKsFwZ0TsSsl9v83qtSuXbTYtIi4Af0Glj36xpLelrOp6/rgNKf/vUOmGnHXbHOhbJOkaSYumHwN3AM8Ce4HpmQv3Av8tPd4L3JNmP/wK8IPpU+wSa7Ut+4E7JF2bTq3vSGmlUjM2cieV7QaVdn0szXa4Efh54K+ArwM/n2ZHLKAyMLa3m3WulfpqHwPGI+LRqqy+3maN2pXJNhuStDg9HgR+lcoYxCHgo6lY7Tab3pYfBZ6KymhsozY318vR6H78oTKa/430cwr47ZT+TuAg8Hz6fV28OeL+X6j0yZ0ERnrdhpr2fJHKKfEUlSOG35pNW4D7qAwOvQD8Zknb9cep3ifSh+b6qvK/ndp1Bvi1qvQPUZkB8s3pbd3jdt1G5XT9BDCWfj7U79tshnblsM1WAcdTG54Ffjelv5tKoH4B+DPg6pT+9vT3Cyn/3c3a3OzHt0AwM8ucu27MzDLnQG9mljkHejOzzDnQm5llzoHezCxzDvRmZplzoDczy9z/B4iZzGiJC+rEAAAAAElFTkSuQmCC\n",
- "text/plain": [
- ""
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "plt.scatter(df.classes,df.lambda_weighted)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Adjusted CrossEntropy"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 288,
- "metadata": {},
- "outputs": [],
- "source": [
- "# export\n",
- "class MultiTaskCELoss(nn.Module):\n",
- " \"\"\"\n",
- " A cross entropy loss function which will cancel out\n",
- " the effect of different class numbers\n",
- " \"\"\"\n",
- " def __init__(self,):\n",
- " super().__init__()\n",
- " self.celoss = nn.CrossEntropyLoss()\n",
- " \n",
- " def forward(\n",
- " self,\n",
- " y_pred: torch.FloatTensor,\n",
- " y_true: torch.LongTensor,\n",
- " )-> torch.FloatTensor:\n",
- " \"\"\"\n",
- " Input:\n",
- " - y_pred: torch.FloatTensor, Prediction tensor\n",
- " - y_true: torch.LongTensor, Label indices\n",
- " Return:\n",
- " - loss: torch.FloatTensor, scala adjusted\n",
- " \"\"\"\n",
- " nb_classes = y_pred.size(-1)\n",
- " lambda_ = 1/np.log10(nb_classes)\n",
- " loss = self.celoss(y_pred,y_true)\n",
- " return loss*lambda_"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Let's make this adjustment into the loss function, an upgraded version of CrossEntropy"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 289,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "e869eca73d324737b011b32aebdeab43",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\n"
- ]
- }
- ],
- "source": [
- "result = []\n",
- "\n",
- "for i in tqdm(range(50)):\n",
- " c = random.randint(2,3000)\n",
- " b = 1-random.random()\n",
- " # here we change the loss function to MultiTaskCELoss\n",
- " loss = create_softmax_pipeline(1,c,crit=MultiTaskCELoss())\\\n",
- " .test_loss(300,inbalanced_input(b))\n",
- " result.append(dict(classes=c,loss=loss.item(),inbl_bias=b))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 283,
- "metadata": {},
- "outputs": [],
- "source": [
- "df_adjusted = pd.DataFrame(result)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 284,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " classes | \n",
- " loss | \n",
- " inbl_bias | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 0 | \n",
- " 533 | \n",
- " 2.319248 | \n",
- " 0.208584 | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " 2264 | \n",
- " 2.326346 | \n",
- " 0.022896 | \n",
- "
\n",
- " \n",
- " 2 | \n",
- " 254 | \n",
- " 2.320735 | \n",
- " 0.922445 | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " 1027 | \n",
- " 2.328291 | \n",
- " 0.643444 | \n",
- "
\n",
- " \n",
- " 4 | \n",
- " 2428 | \n",
- " 2.326136 | \n",
- " 0.883795 | \n",
- "
\n",
- " \n",
- " 5 | \n",
- " 772 | \n",
- " 2.317071 | \n",
- " 0.062716 | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " 2166 | \n",
- " 2.315754 | \n",
- " 0.364122 | \n",
- "
\n",
- " \n",
- " 7 | \n",
- " 812 | \n",
- " 2.321972 | \n",
- " 0.852977 | \n",
- "
\n",
- " \n",
- " 8 | \n",
- " 2632 | \n",
- " 2.321708 | \n",
- " 0.639467 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " 2576 | \n",
- " 2.313269 | \n",
- " 0.968200 | \n",
- "
\n",
- " \n",
- " 10 | \n",
- " 1326 | \n",
- " 2.321515 | \n",
- " 0.508901 | \n",
- "
\n",
- " \n",
- " 11 | \n",
- " 169 | \n",
- " 2.321526 | \n",
- " 0.870789 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " 815 | \n",
- " 2.309163 | \n",
- " 0.262275 | \n",
- "
\n",
- " \n",
- " 13 | \n",
- " 792 | \n",
- " 2.321771 | \n",
- " 0.608608 | \n",
- "
\n",
- " \n",
- " 14 | \n",
- " 739 | \n",
- " 2.316215 | \n",
- " 0.198933 | \n",
- "
\n",
- " \n",
- " 15 | \n",
- " 2460 | \n",
- " 2.318523 | \n",
- " 0.880343 | \n",
- "
\n",
- " \n",
- " 16 | \n",
- " 2420 | \n",
- " 2.320517 | \n",
- " 0.781538 | \n",
- "
\n",
- " \n",
- " 17 | \n",
- " 1688 | \n",
- " 2.316657 | \n",
- " 0.954598 | \n",
- "
\n",
- " \n",
- " 18 | \n",
- " 2929 | \n",
- " 2.323009 | \n",
- " 0.750869 | \n",
- "
\n",
- " \n",
- " 19 | \n",
- " 78 | \n",
- " 2.192723 | \n",
- " 0.089868 | \n",
- "
\n",
- " \n",
- " 20 | \n",
- " 1670 | \n",
- " 2.325608 | \n",
- " 0.188055 | \n",
- "
\n",
- " \n",
- " 21 | \n",
- " 2868 | \n",
- " 2.324525 | \n",
- " 0.058037 | \n",
- "
\n",
- " \n",
- " 22 | \n",
- " 905 | \n",
- " 2.319283 | \n",
- " 0.126673 | \n",
- "
\n",
- " \n",
- " 23 | \n",
- " 1675 | \n",
- " 2.309565 | \n",
- " 0.307769 | \n",
- "
\n",
- " \n",
- " 24 | \n",
- " 1430 | \n",
- " 2.327099 | \n",
- " 0.451335 | \n",
- "
\n",
- " \n",
- " 25 | \n",
- " 424 | \n",
- " 2.315844 | \n",
- " 0.512538 | \n",
- "
\n",
- " \n",
- " 26 | \n",
- " 1511 | \n",
- " 2.316325 | \n",
- " 0.613963 | \n",
- "
\n",
- " \n",
- " 27 | \n",
- " 959 | \n",
- " 2.330317 | \n",
- " 0.987407 | \n",
- "
\n",
- " \n",
- " 28 | \n",
- " 147 | \n",
- " 2.325459 | \n",
- " 0.931882 | \n",
- "
\n",
- " \n",
- " 29 | \n",
- " 278 | \n",
- " 2.317273 | \n",
- " 0.401440 | \n",
- "
\n",
- " \n",
- " 30 | \n",
- " 885 | \n",
- " 2.325397 | \n",
- " 0.087088 | \n",
- "
\n",
- " \n",
- " 31 | \n",
- " 1459 | \n",
- " 2.307971 | \n",
- " 0.857782 | \n",
- "
\n",
- " \n",
- " 32 | \n",
- " 182 | \n",
- " 2.344716 | \n",
- " 0.443465 | \n",
- "
\n",
- " \n",
- " 33 | \n",
- " 2468 | \n",
- " 2.326825 | \n",
- " 0.572978 | \n",
- "
\n",
- " \n",
- " 34 | \n",
- " 954 | \n",
- " 2.322750 | \n",
- " 0.815261 | \n",
- "
\n",
- " \n",
- " 35 | \n",
- " 1297 | \n",
- " 2.319934 | \n",
- " 0.887149 | \n",
- "
\n",
- " \n",
- " 36 | \n",
- " 1787 | \n",
- " 2.326161 | \n",
- " 0.113759 | \n",
- "
\n",
- " \n",
- " 37 | \n",
- " 462 | \n",
- " 2.325276 | \n",
- " 0.969307 | \n",
- "
\n",
- " \n",
- " 38 | \n",
- " 686 | \n",
- " 2.324262 | \n",
- " 0.305557 | \n",
- "
\n",
- " \n",
- " 39 | \n",
- " 428 | \n",
- " 2.323913 | \n",
- " 0.342770 | \n",
- "
\n",
- " \n",
- " 40 | \n",
- " 2900 | \n",
- " 2.319349 | \n",
- " 0.609161 | \n",
- "
\n",
- " \n",
- " 41 | \n",
- " 2765 | \n",
- " 2.323373 | \n",
- " 0.296973 | \n",
- "
\n",
- " \n",
- " 42 | \n",
- " 752 | \n",
- " 2.320189 | \n",
- " 0.302571 | \n",
- "
\n",
- " \n",
- " 43 | \n",
- " 502 | \n",
- " 2.325897 | \n",
- " 0.419716 | \n",
- "
\n",
- " \n",
- " 44 | \n",
- " 2162 | \n",
- " 2.320449 | \n",
- " 0.202672 | \n",
- "
\n",
- " \n",
- " 45 | \n",
- " 2460 | \n",
- " 2.317962 | \n",
- " 0.430931 | \n",
- "
\n",
- " \n",
- " 46 | \n",
- " 2537 | \n",
- " 2.312116 | \n",
- " 0.693554 | \n",
- "
\n",
- " \n",
- " 47 | \n",
- " 806 | \n",
- " 2.318746 | \n",
- " 0.539312 | \n",
- "
\n",
- " \n",
- " 48 | \n",
- " 251 | \n",
- " 2.322845 | \n",
- " 0.181191 | \n",
- "
\n",
- " \n",
- " 49 | \n",
- " 2845 | \n",
- " 2.326007 | \n",
- " 0.667570 | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " classes loss inbl_bias\n",
- "0 533 2.319248 0.208584\n",
- "1 2264 2.326346 0.022896\n",
- "2 254 2.320735 0.922445\n",
- "3 1027 2.328291 0.643444\n",
- "4 2428 2.326136 0.883795\n",
- "5 772 2.317071 0.062716\n",
- "6 2166 2.315754 0.364122\n",
- "7 812 2.321972 0.852977\n",
- "8 2632 2.321708 0.639467\n",
- "9 2576 2.313269 0.968200\n",
- "10 1326 2.321515 0.508901\n",
- "11 169 2.321526 0.870789\n",
- "12 815 2.309163 0.262275\n",
- "13 792 2.321771 0.608608\n",
- "14 739 2.316215 0.198933\n",
- "15 2460 2.318523 0.880343\n",
- "16 2420 2.320517 0.781538\n",
- "17 1688 2.316657 0.954598\n",
- "18 2929 2.323009 0.750869\n",
- "19 78 2.192723 0.089868\n",
- "20 1670 2.325608 0.188055\n",
- "21 2868 2.324525 0.058037\n",
- "22 905 2.319283 0.126673\n",
- "23 1675 2.309565 0.307769\n",
- "24 1430 2.327099 0.451335\n",
- "25 424 2.315844 0.512538\n",
- "26 1511 2.316325 0.613963\n",
- "27 959 2.330317 0.987407\n",
- "28 147 2.325459 0.931882\n",
- "29 278 2.317273 0.401440\n",
- "30 885 2.325397 0.087088\n",
- "31 1459 2.307971 0.857782\n",
- "32 182 2.344716 0.443465\n",
- "33 2468 2.326825 0.572978\n",
- "34 954 2.322750 0.815261\n",
- "35 1297 2.319934 0.887149\n",
- "36 1787 2.326161 0.113759\n",
- "37 462 2.325276 0.969307\n",
- "38 686 2.324262 0.305557\n",
- "39 428 2.323913 0.342770\n",
- "40 2900 2.319349 0.609161\n",
- "41 2765 2.323373 0.296973\n",
- "42 752 2.320189 0.302571\n",
- "43 502 2.325897 0.419716\n",
- "44 2162 2.320449 0.202672\n",
- "45 2460 2.317962 0.430931\n",
- "46 2537 2.312116 0.693554\n",
- "47 806 2.318746 0.539312\n",
- "48 251 2.322845 0.181191\n",
- "49 2845 2.326007 0.667570"
- ]
- },
- "execution_count": 284,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "df_adjusted"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 287,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- ""
- ]
- },
- "execution_count": 287,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAD4CAYAAAAD6PrjAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAclklEQVR4nO3df5Ac5X3n8fcHaQULwpGAtSP044QxJcAVkLg9TJVcpsAXBFzFErHvjM8lSExKvhyuQy5ZhSBXmDuSChgbcqmK0cngBCeKAYP4UcVxigqUomyC4tUPJMRaIAyOkXSwgGTBobO10vf+mGea0Wh+9MzO7sysPq+qqe15+ume59nu6W/38zw9rYjAzMwM4Lh2F8DMzDqHg4KZmWUcFMzMLOOgYGZmGQcFMzPLTGx3ARpx2mmnxezZs9tdDDOzrrJx48a3I6IvT96uCgqzZ89mYGCg3cUwM+sqkn6RN6+bj8zMLOOgYGZmGQcFMzPLOCiYmVnGQcHMzDIOCmZmlnFQMDOzjIOCmZllHBTMzCzjoGBmZhkHBTMzyzgomJlZxkHBzMwyDgpmZpapGxQkzZS0XtKgpO2SbqiQZ6GkrZK2SBqQ9OmSeYdS+hZJT5SknyFpg6RXJD0oaVLrqmVmZs3Ic6UwDCyLiHOAi4DrJZ1bludp4PyImAt8Bbi3ZN6BiJibXp8rSb8DuDsizgL2Atc1XQszM2uJukEhIvZExKY0/R4wCEwvy/N+RER6exIQ1CBJwKXAwynpfmBRY0U3M7NWa6hPQdJsYB6wocK8qyT9DHiSwtVC0QmpSel5ScUD/6nAvogYTu/foCzQmJnZ2MsdFCRNBh4BlkbE/vL5EfFoRJxN4Yz/tpJZsyKiH/iPwF9IOhNQhY+oeHUhaUkKKgNDQ0N5i2tmZk3IFRQk9VAICKsjYk2tvBHxLHCmpNPS+93p78+Bf6RwpfE2MEVS8RnRM4DdVda3KiL6I6K/ry/Xc6fNzKxJeUYfCbgPGIyIu6rk+UTKh6QLgEnAO5KmSjo+pZ8GzAdeSv0P64EvpFVcCzw+0sqM1GObdzH/9mc4Y8WTzL/9GR7bvKvdRTIzG1MT62dhPrAY2CZpS0q7GZgFEBErgc8D10g6CBwAvhgRIekc4H9KOkwhAN0eES+lddwIPCDpT4HNFAJP2zy2eRc3rdnGgYOHANi17wA3rdkGwKJ57u4ws2ODPhw01Pn6+/tjYGBgVNY9//Zn2LXvwFHp06f08pMVl47KZ5qZjQVJG1Pfbl2+oznZXSEg1Eo3MxuPHBSS06f0NpRuZjYeOSgkyxfMobdnwhFpvT0TWL5gTptKZGY29vJ0NB8Tip3Jd67dwe59Bzh9Si/LF8xxJ7OZHVMcFEosmjfdQcDMjmluPjIzs4yDgpmZZRwUzMws46BgZmYZBwUzM8s4KJiZWcZBwczMMg4KZmaW8c1rNioe27zLd4ebdSEHBWs5P5vCrHu5+cha7s61O7KAUHTg4CHuXLujTSUys7x8pWAtV+/ZFG5aMutcDgrWcqdP6a34FLvTp/R2ZdOSg1h38nZrjoNCB+r2nXn5gjlHHPjhw2dT1Gpa6sQ6jlUQa2abd/t+UqoVdSldx2/19vB/fzPMwUOFxw136slHJ27DukFB0kzgB8BvA4eBVRHxP8ryLARuS/OHgaUR8WNJc4F7gI8Ah4A/i4gH0zJ/A1wM/Cqt5g8iYksrKtWITtgo421nrvVsiq8/WHkTd+pjT8ciiDUTeLrxiquaVtSlfB37Dhw8Kk+nnXx06jbMc6UwDCyLiE2STgY2SloXES+V5HkaeCIiQtJ5wEPA2cAHwDUR8Yqk09OyayNiX1pueUQ83ML6NKSdG6V4UN217wACIqWPxs7c6sCX5/9W7dkUtZqWOtFYPLu7mcDTbVdctbSiLpXWUUknnXxUq/eyh14A2hcY6o4+iog9EbEpTb8HDALTy/K8HxHF49pJpGNcRLwcEa+k6d3AW0Bf64o/Mu0aJVM8qBYPjlEnPzS/M5d+VvDhAfyxzbuaWh+M7P/WbY89HYtndzcTeMYiWI2VVtQlb96xPPl4bPMu5t/+DGeseJL5tz9z1HeuWpkPRRzxHa23nlZraEiqpNnAPGBDhXlXSfoZ8CTwlQrzLwQmAa+WJP+ZpK2S7pZ0fJXPXCJpQNLA0NBQI8Wtq9GdsVUbJ+9ZTalmd+aRHMCr1TfP/63asovmTefPf/93mD6lFwHTp/Ty57//Ox17djsWQayZwDMWwWqstKIuefKO5clHnpOxWmUufkdH46SuntxBQdJk4BEK/QX7y+dHxKMRcTawiEL/Qumy04C/Bf4wIg6n5JsoNDH9G+AU4MZKnxsRqyKiPyL6+/pae5FRbaMcJx31T2/lxmn0bG4kO3OzgW/2iif5+oNbKta33pe43v9q0bzp/GTFpbx2+7/jJysubUsfTt7gPhZBrJnA021XXLW0oi6V1tFznJh6Ys+It1szJ4N5TsYqlbnU7n0H2tKakWv0kaQeCgFhdUSsqZU3Ip6VdKak0yLibUkfoXD18F8j4vmSfHvS5K8l/TXwjeaq0LxKo2Tgw8s3OLLTtFVtuNXa1Yt6jhOTT5jIvg8OjrgPoF4bfml/w5QTe3j//w1z8HChQau8WatY31qji6Cz27ub6Uca7Wd31+qYb+UyYy1vX1ar6nJCz3HZdp3S28Otn/vkiP8fzfY75jkZKy6/7KEXOBRHNyKfPqW3Lc2EeUYfCbgPGIyIu6rk+QTwaupovoBCM9E7kiYBjwI/iIgflS0zLSL2pPUvAl4cYV0aVmujlB/Eqh3Eax3cq6l0UC12Nk+v8IUonqk084WpdQAv3+H3fnB0J3e53fsOHPEl3rXvABOkI85eOrG9u7Rjv1wnBKxmAs9oB6uRaPRgOpK6lH8WwK+HD9dYIr9mT3DyDqgorqPS8aD43aoWMEZLniuF+cBiYJuk4njCm4FZABGxEvg8cI2kg8AB4IspQPwH4DPAqZL+IC1bHHq6WlIfhfpvAf5Ti+rUkLzDJKttnAlSU58J+c6MRjpCqtZnzb/9mab7NirtzMWyTTmxp2KAaVd7d6WDRrlu7KDtZGN5tVhrFM/XH9wyoquoZk9w6l1Nlyo/ySodjVjpmDPazYR1g0JE/JjCgbtWnjuAOyqk/x3wd1WWuTRnGUddnqheaePUSq8n75lRK75c1T6r0QNhz3E6YmesVrbjJx5Hb8+EXF+IsZCnY78bO2g72VheLdYaxQOFk5XlP2pumGezQ6gbbRIrfkfn3/5Mxc+bIHE4YkyaCf2DeOTr6JpeYycYzWFio/nlavRAOPmEiUfsjNXK8KsDBztqhFG9/1UnddCO9fDD0TKWo6PyrPPg4eDWJ7Y3vO6RdII3M6Ci2r56OGLMBmb4Zy7IF9WrdUpD4006jdxMNpo3e11ydh+rn/+XXPdJAOwraxKqVbZOau+u1bFfqQ+nUa26ObBT73BtxvIFc1j+8AvZnfkAPRM0KsG31nezVKUbQ+sZ6w79Tri500EhqXcQK2/3K5e3SafRL34jbZONeGzzLh7ZuCt3QICjd8zRKlurVStnK65eWnkg7+RRW00p37maa2mtq/zA3eqPGcsTnE74Trn5qAHFy8FqHSx5mnQaHXc8WuPkG72BrtKO2S03oo1mOVs5jrwTR2016861O7KhzUUHD8eoja8vbaqZemJPxTzV0jtJJ3ynfKXQhJFc4jXzxR+NM5U8B5o8nVud1ExUy2iVs5UH8k5oOmiVdga4b/7eJys2XX3z9z5Zc7lO+HFMaP93ykGhCSO5xOuUL369G+ha1bwy3rVye3ZC00GrtHM/z9sPUOvGzW7uzxkpNx81YSSXeJ3y8wSVylFsFuvUZqBO1Mrt2QlNB63S7v283sif8p9i2fvBwaOau47VR8j6SqFJzV7idcrPE3RKObpdq/+P7W46aJVO37+68ae2x4qiyZuv2qG/vz8GBgbaXQyzUdEpbdrHgjNWPJlrlNL0Kb38ZEXH3GfbNEkbI6I/T15fKVhVPkiNnfF0j0I3qNenBt3bnzNS7lOwitrxO+7dYLTuOG7XA5+OVRV/anuCmNI78p/a7na+UrCKxt2NVC0wmmfz4+kehW7Q6X0e7eSgYBX5IHW00QyUnTJU+VgyXjr1W83NR1bReHrcY6uMZqBs9xBOsyIHBavIB6mjjWagHE/3KFh3c/ORVeQ216ON9h3Hbs6wTuCgYFX5IHUkB0o7FjgomDXAgdLGu7p9CpJmSlovaVDSdkk3VMizUNJWSVskDUj6dMm8ayW9kl7XlqT/a0nbJO2U9JdSEw87NjOzlsrT0TwMLIuIc4CLgOslnVuW52ng/IiYC3wFuBdA0inAN4FPARcC35Q0NS1zD7AEOCu9Lh9hXczMbITqBoWI2BMRm9L0e8AgML0sz/vx4Y8oncSHz1haAKyLiHcjYi+wDrhc0jTgIxHxT2m5HwCLWlIjMzNrWkNDUiXNBuYBGyrMu0rSz4AnKVwtQCF4/LIk2xspbXqaLk83M7M2yh0UJE0GHgGWRsT+8vkR8WhEnE3hjP+24mIVVhU10it97pLUTzEwNDSUt7hmZtaEXEFBUg+FgLA6ItbUyhsRzwJnSjqNwhXAzJLZM4DdKX1GhfRK61sVEf0R0d/X15enuGZm1qQ8o48E3AcMRsRdVfJ8ojh6SNIFwCTgHWAtcJmkqamD+TJgbUTsAd6TdFFa7hrg8ZbUyMzMmpbnPoX5wGJgm6QtKe1mYBZARKwEPg9cI+kgcAD4YupAflfSbcBP03L/PSLeTdN/DPwN0As8lV5mZtZGfvKamdk418iT1/yDeGZmlnFQMDOzjIOCmZllHBTMzCzjoGBmZhkHBTMzyzgomJlZxkHBzMwyDgpmZpZxUDAzs4yDgpmZZRwUzMws46BgZmYZBwUzM8s4KJiZWcZBwczMMg4KZmaWcVAwM7OMg4KZmWXqBgVJMyWtlzQoabukGyrk+bKkren1nKTzU/ocSVtKXvslLU3zbpW0q2Tela2vnpmZNWJijjzDwLKI2CTpZGCjpHUR8VJJnteAiyNir6QrgFXApyJiBzAXQNIEYBfwaMlyd0fEt1tSEzMzG7G6QSEi9gB70vR7kgaB6cBLJXmeK1nkeWBGhVV9Fng1In4xohKbmdmoaahPQdJsYB6woUa264CnKqRfDfywLO1rqcnp+5KmVvnMJZIGJA0MDQ01UlwzM2tQ7qAgaTLwCLA0IvZXyXMJhaBwY1n6JOBzwI9Kku8BzqTQvLQH+E6ldUbEqojoj4j+vr6+vMU1M7Mm5AoKknooBITVEbGmSp7zgHuBhRHxTtnsK4BNEfFmMSEi3oyIQxFxGPgecGEzFTAzs9bJM/pIwH3AYETcVSXPLGANsDgiXq6Q5UuUNR1Jmlby9irgxbyFNjOz0ZFn9NF8YDGwTdKWlHYzMAsgIlYCtwCnAt8txBCGI6IfQNKJwO8CXy1b77ckzQUCeL3CfDMzG2N5Rh/9GFCdPH8E/FGVeR9QCBjl6YtzltHMzMaI72g2M7OMg4KZmWUcFMzMLOOgYGZmGQcFMzPLOCiYmVnGQcHMzDIOCmZmlnFQMDOzjIOCmZllHBTMzCzjoGBmZhkHBTMzyzgomJlZxkHBzMwyDgpmZpZxUDAzs4yDgpmZZRwUzMwsUzcoSJopab2kQUnbJd1QIc+XJW1Nr+cknV8y73VJ2yRtkTRQkn6KpHWSXkl/p7auWmZm1ow8VwrDwLKIOAe4CLhe0rlleV4DLo6I84DbgFVl8y+JiLkR0V+StgJ4OiLOAp5O783MrI3qBoWI2BMRm9L0e8AgML0sz3MRsTe9fR6YkeOzFwL3p+n7gUV5C21mZqOjoT4FSbOBecCGGtmuA54qeR/AP0jaKGlJSfrHImIPFAIP8NEqn7lE0oCkgaGhoUaKa2ZmDZqYN6OkycAjwNKI2F8lzyUUgsKnS5LnR8RuSR8F1kn6WUQ8m/dzI2IVqTmqv78/8i5nZmaNy3WlIKmHQkBYHRFrquQ5D7gXWBgR7xTTI2J3+vsW8ChwYZr1pqRpadlpwFvNVsLMzFojz+gjAfcBgxFxV5U8s4A1wOKIeLkk/SRJJxengcuAF9PsJ4Br0/S1wOPNVsLMzFojT/PRfGAxsE3SlpR2MzALICJWArcApwLfLcQQhtNIo48Bj6a0icDfR8T/Tuu4HXhI0nXAvwD/viU1MjOzpimie5rp+/v7Y2BgoH5GMzPLSNpYdktAVb6j2czMMg4KZmaWcVAwM7OMg4KZmWUcFMzMLOOgYGZmGQcFMzPLOCiYmVnGQcHMzDIOCmZmlnFQMDOzjIOCmZllHBTMzCzjoGBmZhkHBTMzyzgomJlZxkHBzMwyDgpmZpapGxQkzZS0XtKgpO2SbqiQ58uStqbXc5LOr7espFsl7ZK0Jb2ubG3VzMysURNz5BkGlkXEJkknAxslrYuIl0ryvAZcHBF7JV0BrAI+lWPZuyPi2y2sj5mZjUDdK4WI2BMRm9L0e8AgML0sz3MRsTe9fR6YkXdZMzPrHA31KUiaDcwDNtTIdh3wVM5lv5aanL4vaWqVz1wiaUDSwNDQUCPFNTOzBuUOCpImA48ASyNif5U8l1AICjfmWPYe4ExgLrAH+E6ldUbEqojoj4j+vr6+vMU1M7Mm5AoKknooHNRXR8SaKnnOA+4FFkbEO/WWjYg3I+JQRBwGvgdc2Hw1zMysFfKMPhJwHzAYEXdVyTMLWAMsjoiX8ywraVrJ26uAFxsvvpmZtVKe0UfzgcXANklbUtrNwCyAiFgJ3AKcCny3EAcYjoj+astGxP8CviVpLhDA68BXW1IjMzNrmiKi3WXIrb+/PwYGBtpdDDOzriJpYzpRr8t3NJuZWcZBwczMMg4KZmaWcVAwM7OMg4KZmWUcFMzMLOOgYGZmGQcFMzPLOCiYmVnGQcHMzDIOCmZmlnFQMDOzjIOCmZllHBTMzCzjoGBmZhkHBTMzyzgomJlZxkHBzMwyDgpmZpapGxQkzZS0XtKgpO2SbqiQ58uStqbXc5LOL5l3uaQdknZKWlGSfoakDZJekfSgpEmtq5aZmTUjz5XCMLAsIs4BLgKul3RuWZ7XgIsj4jzgNmAVgKQJwF8BVwDnAl8qWfYO4O6IOAvYC1w30sqYmdnI1A0KEbEnIjal6feAQWB6WZ7nImJvevs8MCNNXwjsjIifR8RvgAeAhZIEXAo8nPLdDywaaWXMzGxkGupTkDQbmAdsqJHtOuCpND0d+GXJvDdS2qnAvogYLkuv9JlLJA1IGhgaGmqkuGZm1qDcQUHSZOARYGlE7K+S5xIKQeHGYlKFbFEj/ejEiFUR0R8R/X19fXmLa2ZmTcgVFCT1UAgIqyNiTZU85wH3Agsj4p2U/AYwsyTbDGA38DYwRdLEsnQzM2ujPKOPBNwHDEbEXVXyzALWAIsj4uWSWT8FzkojjSYBVwNPREQA64EvpHzXAo83Xw0zM2uFifWzMB9YDGyTtCWl3QzMAoiIlcAtFPoJvluIIQynJp9hSV8D1gITgO9HxPa0jhuBByT9KbCZQuAxM7M2UuGkvTv09/fHwMBAQ8s8tnkXd67dwe59Bzh9Si/LF8xh0byKfdpmZuOSpI0R0Z8nb54rha712OZd3LRmGwcOHgJg174D3LRmG4ADg5lZBeP6Zy7uXLsjCwhFBw4e4s61O9pUIjOzzjaug8LufQcaSjczO9aN66Bw+pTehtLNzI514zooLF8wh96eCUek9fZMYPmCOW0qkZlZZxvXHc3FzmSPPjIzy2dcBwUoBAYHATOzfMZ185GZmTXGQcHMzDIOCmZmlnFQMDOzjIOCmZllHBTMzCzjoGBmZhkHBTMzyzgomJlZxkHBzMwyDgpmZpapGxQkzZS0XtKgpO2SbqiQ52xJ/yTp15K+UZI+R9KWktd+SUvTvFsl7SqZd2Vrq2ZmZo3K84N4w8CyiNgk6WRgo6R1EfFSSZ53gf8CLCpdMCJ2AHMBJE0AdgGPlmS5OyK+PZIKmJlZ69S9UoiIPRGxKU2/BwwC08vyvBURPwUO1ljVZ4FXI+IXIyivmZmNoob6FCTNBuYBG5r4rKuBH5alfU3SVknflzS1ymcukTQgaWBoaKiJjzUzs7xyBwVJk4FHgKURsb+RD5E0Cfgc8KOS5HuAMyk0L+0BvlNp2YhYFRH9EdHf19fXyMeamVmDcgUFST0UAsLqiFjTxOdcAWyKiDeLCRHxZkQciojDwPeAC5tYr5mZtVCe0UcC7gMGI+KuJj/nS5Q1HUmaVvL2KuDFJtdtZmYtkmf00XxgMbBN0paUdjMwCyAiVkr6bWAA+AhwOA07PTci9ks6Efhd4Ktl6/2WpLlAAK9XmG9mZmOsblCIiB8DqpPn/wAzqsz7ADi1QvrinGU0M7MxoohodxlykzQEVBvSehrw9hgWZyyMxzrB+KyX69Q9xmO96tXpX0VErpE6XRUUapE0EBH97S5HK43HOsH4rJfr1D3GY71aWSf/9pGZmWUcFMzMLDOegsKqdhdgFIzHOsH4rJfr1D3GY71aVqdx06dgZmYjN56uFMzMbIQcFMzMLNP1QUHS5ZJ2SNopaUW7y9MoSa9L2pYeNDSQ0k6RtE7SK+nv1JQuSX+Z6rpV0gXtLX1B+pXbtyS9WJLWcB0kXZvyvyLp2nbUpaQslepU9cFQkm5KddohaUFJekftn9UemtXN26tGnbp2e0k6QdI/S3oh1em/pfQzJG1I//MH04+NIun49H5nmj+7ZF0V61pVRHTtC5gAvAp8HJgEvEDh5zXaXrYG6vA6cFpZ2reAFWl6BXBHmr4SeIrCHeYXARvaXf5Urs8AFwAvNlsH4BTg5+nv1DQ9tcPqdCvwjQp5z0373vHAGWmfnNCJ+ycwDbggTZ8MvJzK37Xbq0adunZ7pf/35DTdQ+FxBRcBDwFXp/SVwB+n6f8MrEzTVwMP1qprrc/u9iuFC4GdEfHziPgN8ACwsM1laoWFwP1p+n4+fKLdQuAHUfA8MEVH/rBgW0TEsxSevleq0TosANZFxLsRsRdYB1w++qWvrEqdqlkIPBARv46I14CdFPbNjts/o/pDs7p2e9WoUzUdv73S//v99LYnvQK4FHg4pZdvp+L2exj4rCRRva5VdXtQmA78suT9G9TeGTpRAP8gaaOkJSntYxGxBwo7PPDRlN5N9W20Dt1St0oPhurKOunIh2aNi+2lox8E1rXbS9IEFX6E9C0KQfdVYF9EDFcoX1b2NP9XFH5zruE6dXtQqPRDfd02xnZ+RFxA4ZkT10v6TI2846G+1erQDXWr9mCorquT8j80q2vqVqFOXb29ovC8mbkUfmz0QuCcStnS35bVqduDwhvAzJL3M4DdbSpLUyJid/r7FvAohY3/ZrFZKP19K2Xvpvo2WoeOr1tUfzBUV9VJlR+a1dXbq1Kdxsv2ioh9wD9S6FOYIqn469al5cvKnub/FoXmz4br1O1B4afAWalHfhKFDpYn2lym3CSdJOnk4jRwGYWHDT0BFEdzXAs8nqafAK5JI0IuAn5VvOTvQI3WYS1wmaSp6TL/spTWMVT9wVBPAFenESBnAGcB/0wH7p+pnbnSQ7O6dntVq1M3by9JfZKmpOle4N9S6CtZD3whZSvfTsXt9wXgmSj0NFera3Xt6Flv5YvC6IiXKbS3/Um7y9Ng2T9OYWTAC8D2YvkptAU+DbyS/p4SH45I+KtU121Af7vrkMr1QwqX5wcpnJlc10wdgK9Q6AjbCfxhB9bpb1OZt6Yv27SS/H+S6rQDuKJT90/g0xSaD7YCW9Lrym7eXjXq1LXbCzgP2JzK/iJwS0r/OIWD+k4Kz7w/PqWfkN7vTPM/Xq+u1V7+mQszM8t0e/ORmZm1kIOCmZllHBTMzCzjoGBmZhkHBTMzyzgomJlZxkHBzMwy/x/AvBzCXAI56QAAAABJRU5ErkJggg==\n",
- "text/plain": [
- "