diff --git a/nbs/11_etl.ipynb b/nbs/11_etl.ipynb
deleted file mode 100644
index 711b499..0000000
--- a/nbs/11_etl.ipynb
+++ /dev/null
@@ -1,174 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# ETL Helper\n",
- "> A combined ETL tool set"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "# default_exp etl"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## A list for each step"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 23,
- "metadata": {},
- "outputs": [],
- "source": [
- "# export\n",
- "\n",
- "from forgebox.imports import *\n",
- "from typing import Callable, Union\n",
- "import traceback as tb\n",
- "import logging\n",
- "\n",
- "class NewFileScanner:\n",
- " \"\"\"\n",
- " Keep scannning a directory for new file\n",
- " if found any, callback processing the file\n",
- " \n",
- " Designed for download and delete strategy\n",
- " \n",
- " # Example\n",
- " new_file_scanner = NewFileScanner(\".\", new_file_filter=lambda x:x[-4:]==\".txt\")\n",
- " new_file_scanner(lambda x:print(f\"new file:{x} found\"))s\n",
- " \"\"\"\n",
- " def __init__(\n",
- " self,\n",
- " directory,\n",
- " new_file_filter: Callable=None,\n",
- " done_list_getter: Callable=None,\n",
- " stop_callback: Callable=None,\n",
- " ):\n",
- " self.directory = Path(directory)\n",
- " \n",
- " if new_file_filter is not None:\n",
- " self.new_file_filter=new_file_filter\n",
- " else:\n",
- " self.new_file_filter=lambda x:True\n",
- " \n",
- " if done_list_getter is not None:\n",
- " self.done_list_getter = done_list_getter\n",
- " else:\n",
- " self.done_list_getter = lambda *args:[]\n",
- " \n",
- " if stop_callback is not None:\n",
- " self.stop_callback = stop_callback\n",
- " \n",
- " self.processed = []\n",
- " \n",
- " def __call__(self, process: Callable):\n",
- " while True:\n",
- " try:\n",
- " done_list = self.done_list_getter()\n",
- " for fname in self.directory.iterdir():\n",
- " fname = str(fname)\n",
- " if self.new_file_filter(fname) != True:\n",
- " continue\n",
- " file_path = str(self.directory/fname)\n",
- " if fname not in done_list:\n",
- " if fname not in self.processed:\n",
- " result = process(file_path)\n",
- " self.processed.append(fname)\n",
- " except KeyboardInterrupt as e:\n",
- " if hasattr(self, \"stop_callback\"):\n",
- " self.stop_callback()\n",
- " logging.error(\"manually stoped\")\n",
- " break\n",
- " except Exception as e:\n",
- " error = tb.format_exc()\n",
- " logging.error(error)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Test new file scanner"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "new file:untitled.txt found\n",
- "new file:bc.txt found\n",
- "new file:cde.txt found\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "ERROR:root:manually stoped\n"
- ]
- }
- ],
- "source": [
- "new_file_scanner = NewFileScanner(\".\", new_file_filter=lambda x:x[-4:]==\".txt\")\n",
- "new_file_scanner(lambda x:print(f\"new file:{x} found\"))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.7.4"
- },
- "toc": {
- "base_numbering": 1,
- "nav_menu": {},
- "number_sections": true,
- "sideBar": true,
- "skip_h1_title": false,
- "title_cell": "Table of Contents",
- "title_sidebar": "Contents",
- "toc_cell": false,
- "toc_position": {},
- "toc_section_display": true,
- "toc_window_display": false
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}
diff --git a/nbs/61_thunder_callbacks.ipynb b/nbs/61_thunder_callbacks.ipynb
deleted file mode 100644
index 848e47f..0000000
--- a/nbs/61_thunder_callbacks.ipynb
+++ /dev/null
@@ -1,151 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Lightning Callbacks\n",
- "> Thunder, the DIYed [pytorch-lightning callbacks](https://pytorch-lightning.readthedocs.io/en/latest/extensions/callbacks.html)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "# default_exp thunder.callbacks"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "# export\n",
- "import pandas as pd\n",
- "from ipywidgets import Output\n",
- "from typing import List, Dict\n",
- "import copy\n",
- "import pytorch_lightning as pl\n",
- "import torch\n",
- "from torch import nn"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [],
- "source": [
- "# export\n",
- "def unfreeze(self):\n",
- " \"\"\"unfreeze this module, and its sub modules\"\"\"\n",
- " for p in self.parameters():\n",
- " p.requires_grad = True\n",
- "\n",
- "\n",
- "def freeze(self):\n",
- " \"\"\"freeze this module, and its sub modules\"\"\"\n",
- " for p in self.parameters():\n",
- " p.requires_grad = False\n",
- "\n",
- "nn.Module.unfreeze = unfreeze\n",
- "nn.Module.freeze = freeze\n",
- "\n",
- "class DataFrameMetricsCallback(pl.Callback):\n",
- " \"\"\"\n",
- " A metrics callback keep showing pandas dataframe\n",
- " \"\"\"\n",
- "\n",
- " def __init__(self) -> None:\n",
- " \"\"\"\n",
- " In Trainer kwargs, passing this arguements along with other callbacks\n",
- " callbacks = [DataFrameMetricsCallback(),]\n",
- " \"\"\"\n",
- " self.metrics: List = []\n",
- "\n",
- " def on_fit_start(\n",
- " self, trainer: pl.Trainer,\n",
- " pl_module: pl.LightningModule\n",
- " ) -> None:\n",
- " pl_module.output = Output()\n",
- " display(pl_module.output)\n",
- "\n",
- " def on_validation_epoch_end(\n",
- " self, trainer: pl.Trainer,\n",
- " pl_module: pl.LightningModule\n",
- " ) -> None:\n",
- " metrics_dict = copy.copy(trainer.callback_metrics)\n",
- " self.metrics.append(dict((k, v.item())\n",
- " for k, v in metrics_dict.items()))\n",
- " pl_module.output.clear_output()\n",
- " with pl_module.output:\n",
- " display(pd.DataFrame(self.metrics).tail(10))\n",
- "\n",
- "\n",
- "def UnfreezeScheduler(frozen_epochs: int = 2):\n",
- " assert hasattr(pl_module, \"top_layers\"), \"Please define 'top_layers' attributes\"+\\\n",
- " \" for pl_module, which will return a list of nn.Module object(s)\"\n",
- " class UnfreezeSchedulerCallback(pl.callbacks.Callback):\n",
- " \"\"\"\n",
- " Train the top layer for [frozen_epochs] epochs\n",
- " then un freeze all\n",
- " \"\"\"\n",
- "\n",
- " def on_epoch_start(self, trainer, pl_module):\n",
- " epoch = trainer.current_epoch\n",
- "\n",
- " if epoch == 0:\n",
- " pl_module.freeze()\n",
- " for tl in pl_module.top_layers:\n",
- " tl.unfreeze()\n",
- " if epoch == frozen_epochs:\n",
- " pl_module.unfreeze()\n",
- " pl_module.base.embeddings.freeze()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.7.4"
- },
- "toc": {
- "base_numbering": 1,
- "nav_menu": {},
- "number_sections": true,
- "sideBar": true,
- "skip_h1_title": false,
- "title_cell": "Table of Contents",
- "title_sidebar": "Contents",
- "toc_cell": false,
- "toc_position": {},
- "toc_section_display": true,
- "toc_window_display": false
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}
diff --git a/nbs/70_hf_transformer_data.ipynb b/nbs/70_hf_transformer_data.ipynb
deleted file mode 100644
index ac7b569..0000000
--- a/nbs/70_hf_transformer_data.ipynb
+++ /dev/null
@@ -1,553 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Data parts for hf transformers"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "# default_exp hf.data"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "# export\n",
- "from forgebox.imports import *\n",
- "from forgebox.category import Category\n",
- "from typing import List, Dict, Callable, Any, Tuple"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Process IOBES files"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# export\n",
- "def convert_iob2_file_to_iobes(file_path, result_path):\n",
- " \"\"\"\n",
- " Convert IOB2 file to IOBES\n",
- " \"\"\"\n",
- " with open(file_path, 'r') as f:\n",
- " lines = f.readlines()\n",
- " with open(result_path, 'w') as f:\n",
- " for line in lines:\n",
- " line = line.strip()\n",
- " if line == '':\n",
- " f.write('\\n')\n",
- " continue\n",
- " line = line.split()\n",
- " if line[-1] == 'O':\n",
- " f.write(' '.join(line) + '\\n')\n",
- " else:\n",
- " f.write(' '.join(line[:-1]) + ' ' + line[-1] + '\\n')\n",
- "\n",
- "\n",
- "def conbine_iobes_file(\n",
- " file_paths: List[Path],\n",
- " new_file_path: Path\n",
- "):\n",
- " \"\"\"\n",
- " Conbine from multiple IOBES files\n",
- " into IOBES files\n",
- " \"\"\"\n",
- " with open(new_file_path, 'w') as new_file:\n",
- " for file_path in file_paths:\n",
- " with open(file_path, 'r') as file:\n",
- " for line in file:\n",
- " new_file.write(line)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Dataset"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# export\n",
- "class IOBES(Dataset):\n",
- " \"\"\"\n",
- " Load iobes file for NER training task\n",
- " \"\"\"\n",
- "\n",
- " def __init__(\n",
- " self,\n",
- " file_path,\n",
- " tokenizer,\n",
- " max_len=128,\n",
- " save_buffer: int = 15,\n",
- " category: Category = None,\n",
- " return_string: bool = False,\n",
- " use_frag: bool = False,\n",
- " ):\n",
- " \"\"\"\n",
- " file_path,\n",
- " tokenizer,\n",
- " max_len=128,\n",
- " save_buffer: int = 15,\n",
- " category: Category = None,\n",
- " label categories, if set to None, will be figured out\n",
- " automatically.\n",
- " You can set this to None for train dataset, but for valid\n",
- " dataset:\n",
- " valid_ds = IOBES(...,category=train_ds.cates)\n",
- " return_string: bool = False, do we return original string\n",
- " for tokenizer output, this option is good for debuging\n",
- " but the data won't pass into cuda if choose so\n",
- " use_frag: bool = False, do we use prepend like 'I-','B-'\n",
- " \"\"\"\n",
- " self.file_path = file_path\n",
- " self.max_len = max_len\n",
- " self.pairs = []\n",
- " self.list_of_words = []\n",
- " self.list_of_labels = []\n",
- " self.tokenizer = tokenizer\n",
- " self.cates = category\n",
- " self.return_string = return_string\n",
- " self.use_frag = use_frag\n",
- " self.load_data(save_buffer)\n",
- "\n",
- " def load_data(self, save_buffer: int = 15):\n",
- " \"\"\"\n",
- " Load file in to object structure\n",
- " \"\"\"\n",
- " with open(self.file_path, 'r') as f:\n",
- " for line in f:\n",
- " line = line.strip()\n",
- " if line:\n",
- " splited = line.split()\n",
- " if len(splited) != 2:\n",
- " continue\n",
- " word, label = splited\n",
- " # do we use 'I-', 'B-' etc\n",
- " if self.use_frag is False:\n",
- " if \"-\" in label:\n",
- " label = label.split('-')[1]\n",
- " self.pairs.append([word, label])\n",
- "\n",
- " self.pairs = np.array(self.pairs)\n",
- "\n",
- " if self.cates is None:\n",
- " labels_df = pd.DataFrame({\"label\": self.pairs[:, 1]})\n",
- " self.cates = Category(list(labels_df.vc(\"label\").index))\n",
- "\n",
- " self.batching_words(save_buffer)\n",
- "\n",
- " def batching_words(self, save_buffer: int = 15):\n",
- " \"\"\"\n",
- " batching self.words into self.list_of_words\n",
- " by self.max_len -15\n",
- " \"\"\"\n",
- " for i in range(0, len(self.pairs), self.max_len-save_buffer):\n",
- " chunk_slice = slice(i, i+self.max_len-save_buffer)\n",
- " self.list_of_words.append(self.pairs[chunk_slice, 0])\n",
- " self.list_of_labels.append(self.pairs[chunk_slice, 1])\n",
- "\n",
- " def __len__(self) -> int:\n",
- " return len(self.list_of_words)\n",
- "\n",
- " def __getitem__(self, idx: int) -> Tuple[List[str]]:\n",
- " return list(self.list_of_words[idx]), list(self.list_of_labels[idx])\n",
- "\n",
- " def __repr__(self):\n",
- " return f\"\"\"NER dataset using IOBES annotation\n",
- " {len(self)} sentences,\n",
- " Labels:\n",
- " {list(self.cates.i2c)}\n",
- " \"\"\"\n",
- "\n",
- " def collate_fn(self, data):\n",
- " \"\"\"\n",
- " data: list of tuple\n",
- " \"\"\"\n",
- " words, text_labels = zip(*data)\n",
- "\n",
- " inputs = self.tokenizer(\n",
- " list(words),\n",
- " return_tensors='pt',\n",
- " padding=True,\n",
- " truncation=True,\n",
- " max_length=self.max_len,\n",
- " is_split_into_words=True,\n",
- " return_offsets_mapping=True,\n",
- " add_special_tokens=False,\n",
- " )\n",
- " return self.align_offsets(inputs, text_labels, words)\n",
- "\n",
- " def align_offsets(\n",
- " self,\n",
- " inputs,\n",
- " text_labels: List[List[str]],\n",
- " words: List[List[str]]\n",
- " ):\n",
- " \"\"\"\n",
- " inputs: output if tokenizer\n",
- " text_labels: labels in form of list of list of strings\n",
- " words: words in form of list of list of strings\n",
- " \"\"\"\n",
- " labels = torch.zeros_like(inputs.input_ids).long()\n",
- " labels -= 100\n",
- " text_lables_array = np.empty(labels.shape, dtype=object)\n",
- " words_array = np.empty(labels.shape, dtype=object)\n",
- " max_len = inputs.input_ids.shape[1]\n",
- "\n",
- " for row_id, input_ids in enumerate(inputs.input_ids):\n",
- " word_pos = inputs.word_ids(row_id)\n",
- " for idx, pos in enumerate(word_pos):\n",
- " if pos is None:\n",
- " continue\n",
- " if pos <= max_len:\n",
- " labels[row_id, idx] = self.cates.c2i[text_labels[row_id][pos]]\n",
- " if self.return_string:\n",
- " text_lables_array[row_id,\n",
- " idx] = text_labels[row_id][pos]\n",
- " words_array[row_id, idx] = words[row_id][pos]\n",
- "\n",
- " inputs['labels'] = labels\n",
- " if self.return_string:\n",
- " inputs['text_labels'] = text_lables_array.tolist()\n",
- " inputs['word'] = words_array.tolist()\n",
- " return inputs\n",
- "\n",
- " def dataloader(self, batch_size: int = 32, shuffle: bool = True):\n",
- " \"\"\"\n",
- " Create dataloader\n",
- " \"\"\"\n",
- " return DataLoader(\n",
- " self,\n",
- " batch_size=batch_size,\n",
- " shuffle=shuffle,\n",
- " collate_fn=self.collate_fn,\n",
- " )\n",
- "\n",
- " def one_batch(self, batch_size: int = 32, shuffle: bool = True):\n",
- " return next(iter(self.dataloader(batch_size, shuffle)))\n",
- "\n",
- " def visualize_batch(self, batch, row_idx=0):\n",
- " return list(zip(self.tokenizer.convert_ids_to_tokens(batch.input_ids[row_idx]),\n",
- " batch.labels[row_idx].numpy(),\n",
- " batch.text_labels[row_idx],\n",
- " batch.word[row_idx],\n",
- " batch.offset_mapping[row_idx].numpy(),\n",
- " ))\n",
- "\n",
- " def set_hfconfig(self, config):\n",
- " \"\"\"\n",
- " set the category information to huggingface config\n",
- " \"\"\"\n",
- " config.num_labels = len(self.cates)\n",
- " config.id2label = {i: label for i, label in enumerate(self.cates.i2c)}\n",
- " config.label2id = {label: i for i, label in enumerate(self.cates.i2c)}"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [],
- "source": [
- "from transformers import AutoTokenizer"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [],
- "source": [
- "tokenizer = AutoTokenizer.from_pretrained(\"raynardj/roberta-pubmed\", add_prefix_space=True)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [],
- "source": [
- "dataset = IOBES(\"/Users/xiaochen.zhang/data/valid.iobes\", tokenizer)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "in-O\n",
- "blood-O\n",
- ";-O\n",
- "content-O\n",
- "of-O\n",
- "cAMP-O\n",
- "was-O\n",
- "also-O\n",
- "decreased-O\n",
- "in-O\n",
- "lymphocytes-O\n",
- "by-O\n",
- "33-O\n",
- "%-O\n",
- ".-O\n",
- "At-O\n",
- "the-O\n",
- "same-O\n",
- "time-O\n",
- ",-O\n",
- "total-O\n",
- "content-O\n",
- "of-O\n",
- "T-cell_type\n",
- "lymphocytes-cell_type\n",
- "was-O\n",
- "decreased-O\n",
- "1.5-fold-O\n",
- "in-O\n",
- "peripheric-O\n",
- "blood-O\n",
- ".-O\n",
- "Treatment-O\n",
- "with-O\n",
- "I-hydroxyvitamin-O\n",
- "D3-O\n",
- "(-O\n",
- "1-1.5-O\n",
- "mg-O\n",
- "daily-O\n",
- ",-O\n",
- "within-O\n",
- "4-O\n",
- "weeks-O\n",
- ")-O\n",
- "led-O\n",
- "to-O\n",
- "normalization-O\n",
- "of-O\n",
- "total-O\n",
- "and-O\n",
- "ionized-O\n",
- "form-O\n",
- "of-O\n",
- "Ca2+-O\n",
- "and-O\n",
- "of-O\n",
- "25-O\n",
- "(-O\n",
- "OH-O\n",
- ")-O\n",
- "D-O\n",
- ",-O\n",
- "but-O\n",
- "did-O\n",
- "not-O\n",
- "affect-O\n",
- "the-O\n",
- "PTH-O\n",
- "content-O\n",
- "in-O\n",
- "blood-O\n",
- ".-O\n",
- "Concentration-O\n",
- "of-O\n",
- "the-O\n",
- "receptors-protein\n",
- "to-O\n",
- "1.25-O\n",
- "(-O\n",
- "OH-O\n",
- ")-O\n",
- "2D3-O\n",
- "was-O\n",
- "elevated-O\n",
- "up-O\n",
- "to-O\n",
- "39.7-O\n",
- "fmole/mg-O\n",
- "after-O\n",
- "I-O\n",
- "week-O\n",
- "of-O\n",
- "the-O\n",
- "treatment-O\n",
- ",-O\n",
- "whereas-O\n",
- "it-O\n",
- "was-O\n",
- "decreased-O\n",
- "to-O\n",
- "the-O\n",
- "initial-O\n",
- "level-O\n",
- "24.8-O\n",
- "fmole/mg-O\n",
- "within-O\n",
- "4-O\n",
- "weeks-O\n",
- ";-O\n",
- "simultaneous-O\n",
- "alteration-O\n",
- "in-O\n"
- ]
- }
- ],
- "source": [
- "for w,l in zip(*dataset[2]):\n",
- " print(f\"{w}-{l}\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "{'input_ids': tensor([[ 19, 3741, 2603, ..., 1417, 2617, 11576],\n",
- " [ 4590, 2156, 255, ..., 405, 1182, 6608],\n",
- " [ 6214, 25683, 3809, ..., 11, 5, 8151],\n",
- " ...,\n",
- " [13998, 25326, 2413, ..., 5, 2199, 21],\n",
- " [11299, 705, 24811, ..., 134, 1589, 2032],\n",
- " [ 5804, 924, 14, ..., 366, 1168, 9]]), 'attention_mask': tensor([[1, 1, 1, ..., 1, 1, 1],\n",
- " [1, 1, 1, ..., 1, 1, 1],\n",
- " [1, 1, 1, ..., 1, 1, 1],\n",
- " ...,\n",
- " [1, 1, 1, ..., 1, 1, 1],\n",
- " [1, 1, 1, ..., 1, 1, 1],\n",
- " [1, 1, 1, ..., 1, 1, 1]]), 'offset_mapping': tensor([[[ 1, 4],\n",
- " [ 1, 2],\n",
- " [ 2, 5],\n",
- " ...,\n",
- " [ 3, 5],\n",
- " [ 5, 8],\n",
- " [ 1, 6]],\n",
- "\n",
- " [[ 1, 5],\n",
- " [ 1, 1],\n",
- " [ 1, 1],\n",
- " ...,\n",
- " [ 5, 7],\n",
- " [ 7, 9],\n",
- " [ 9, 14]],\n",
- "\n",
- " [[ 1, 5],\n",
- " [ 5, 8],\n",
- " [ 8, 10],\n",
- " ...,\n",
- " [ 1, 2],\n",
- " [ 1, 3],\n",
- " [ 1, 10]],\n",
- "\n",
- " ...,\n",
- "\n",
- " [[ 1, 5],\n",
- " [ 5, 8],\n",
- " [ 8, 10],\n",
- " ...,\n",
- " [ 1, 3],\n",
- " [ 1, 7],\n",
- " [ 1, 3]],\n",
- "\n",
- " [[ 1, 5],\n",
- " [ 5, 6],\n",
- " [ 6, 10],\n",
- " ...,\n",
- " [ 2, 3],\n",
- " [ 1, 1],\n",
- " [ 1, 2]],\n",
- "\n",
- " [[ 1, 7],\n",
- " [ 1, 5],\n",
- " [ 1, 4],\n",
- " ...,\n",
- " [ 3, 5],\n",
- " [ 5, 7],\n",
- " [ 1, 2]]]), 'labels': tensor([[0, 1, 1, ..., 0, 0, 0],\n",
- " [2, 0, 2, ..., 0, 0, 0],\n",
- " [0, 0, 0, ..., 0, 0, 0],\n",
- " ...,\n",
- " [1, 1, 1, ..., 0, 0, 0],\n",
- " [0, 0, 0, ..., 2, 0, 2],\n",
- " [0, 0, 0, ..., 0, 0, 0]])}"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "dataset.one_batch()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.7.4"
- },
- "toc": {
- "base_numbering": 1,
- "nav_menu": {},
- "number_sections": true,
- "sideBar": true,
- "skip_h1_title": false,
- "title_cell": "Table of Contents",
- "title_sidebar": "Contents",
- "toc_cell": false,
- "toc_position": {},
- "toc_section_display": true,
- "toc_window_display": true
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}
diff --git a/nbs/72_pl_training.ipynb b/nbs/72_pl_training.ipynb
deleted file mode 100644
index 3ce4854..0000000
--- a/nbs/72_pl_training.ipynb
+++ /dev/null
@@ -1,466 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# PyTorch Lightning training\n",
- "> on huggingface transformers"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "# default_exp hf.train"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "metadata": {},
- "outputs": [],
- "source": [
- "# export\n",
- "from forgebox.hf.data import IOBES\n",
- "from forgebox.imports import *\n",
- "from forgebox.loop import chunkify\n",
- "import pytorch_lightning as pl\n",
- "from transformers import (\n",
- " AutoModelForTokenClassification,\n",
- " AutoTokenizer,\n",
- " pipeline\n",
- ")\n",
- "from tqdm.notebook import tqdm\n",
- "from typing import Callable, List\n",
- "from torch import device"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 290,
- "metadata": {},
- "outputs": [],
- "source": [
- "# export\n",
- "try:\n",
- " ishell = get_ipython()\n",
- " IS_JUPYTER = True\n",
- " from tqdm.notebook import tqdm\n",
- "except NameError:\n",
- " IS_JUPYTER = False\n",
- " from tqdm import tqdm"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [],
- "source": [
- "# !pip install transformers==4.9.1\n",
- "# !pip install pytorch-lightning==1.3.8\n",
- "# !pip install tensorflow==2.2.0"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Load model and tokenizer"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "metadata": {},
- "outputs": [],
- "source": [
- "# export\n",
- "\n",
- "# ner model and tokenizer\n",
- "def ner_model_from(\n",
- " name:str, dataset: IOBES\n",
- "):\n",
- " \"\"\"\n",
- " name: from_pretrain(name)\n",
- " \"\"\"\n",
- " model = AutoModelForTokenClassification.from_pretrained(\n",
- " name,\n",
- " num_labels=len(dataset.cates),\n",
- " )\n",
- " dataset.set_hfconfig(model.config)\n",
- " return model\n",
- "\n",
- "def ner_tokenizer_from(\n",
- " name: str\n",
- "):\n",
- " return AutoTokenizer.from_pretrained(\n",
- " name, add_prefix_space=True)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Lightning data module"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [],
- "source": [
- "# export\n",
- "\n",
- "# ner data module\n",
- "class NERDataModule(pl.LightningDataModule):\n",
- " def __init__(self, train_ds, val_ds, batch_size=32):\n",
- " super().__init__()\n",
- " self.train_ds = train_ds\n",
- " self.val_ds = val_ds\n",
- " self.batch_size = batch_size\n",
- "\n",
- " def train_dataloader(self):\n",
- " return self.train_ds.dataloader(batch_size=self.batch_size, shuffle=True)\n",
- "\n",
- " def val_dataloader(self):\n",
- " return self.val_ds.dataloader(batch_size=self.batch_size*2, shuffle=False)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {},
- "outputs": [],
- "source": [
- "# export\n",
- "\n",
- "# ner module\n",
- "class NERModule(pl.LightningModule):\n",
- " \"\"\"\n",
- " PyTorch lightning module for training ner model\n",
- " \"\"\"\n",
- " def __init__(\n",
- " self, model,\n",
- " ):\n",
- " \"\"\"\n",
- " model: huggingface transformer model for ner\n",
- " \"\"\"\n",
- " super().__init__()\n",
- " self.model = model\n",
- "\n",
- " def forward(self, batch):\n",
- " return self.model(\n",
- " input_ids=batch['input_ids'],\n",
- " attention_mask=batch['attention_mask'],\n",
- " labels=batch['labels'])\n",
- " \n",
- " def training_step(self, batch, batch_idx):\n",
- " outputs = self(batch)\n",
- " loss = outputs.loss\n",
- " self.log(\"loss\", loss)\n",
- " self.log(\"acc\", self.calcualte_acc(outputs, batch.labels))\n",
- " return loss\n",
- "\n",
- " def validation_step(self, batch, batch_idx):\n",
- " outputs = self(batch)\n",
- " loss = outputs.loss\n",
- " self.log(\"val_loss\", loss)\n",
- " self.log(\"val_acc\", self.calcualte_acc(outputs, batch.labels))\n",
- " return loss\n",
- " \n",
- " def calcualte_acc(self, outputs, labels):\n",
- " pred_idx = outputs.logits.argmax(-1)\n",
- " mask = torch.ones_like(pred_idx)\n",
- " mask[labels==-100]=False\n",
- " return (pred_idx[mask]==labels[mask]).float().mean()\n",
- " \n",
- " def configure_optimizers(self):\n",
- " # discriminative learning rate\n",
- " param_groups = [\n",
- " {'params': self.model.roberta.parameters(), 'lr': 5e-6},\n",
- " {'params': self.model.classifier.parameters(), 'lr': 1e-3},\n",
- " ]\n",
- " optimizer = torch.optim.Adam(param_groups, lr=1e-3)\n",
- " return optimizer"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Enhance pipeline"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 35,
- "metadata": {},
- "outputs": [],
- "source": [
- "# export\n",
- "def clean_ner_output(self, outputs):\n",
- " \"\"\"\n",
- " Cleaning output for NER task\n",
- " \"\"\"\n",
- " results = []\n",
- " current = []\n",
- " last_idx = 0\n",
- " # make to sub group by position\n",
- " for output in outputs:\n",
- " if output[\"start\"] in [last_idx, last_idx-1]:\n",
- " current.append(output)\n",
- " else:\n",
- " results.append(current)\n",
- " current = [output, ]\n",
- " last_idx = output[\"end\"]\n",
- " if len(current) > 0:\n",
- " results.append(current)\n",
- "\n",
- " # from tokens to string\n",
- " strings = []\n",
- " for c in results:\n",
- " tokens = []\n",
- " starts = []\n",
- " ends = []\n",
- " for o in c:\n",
- " tokens.append(o['word'])\n",
- " starts.append(o['start'])\n",
- " ends.append(o['end'])\n",
- "\n",
- " new_str = self.tokenizer.convert_tokens_to_string(tokens)\n",
- " if new_str != '':\n",
- " strings.append(dict(\n",
- " word=new_str,\n",
- " start=min(starts),\n",
- " end=max(ends),\n",
- " entity=c[0]['entity_group']\n",
- " ))\n",
- " return strings"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 281,
- "metadata": {},
- "outputs": [],
- "source": [
- "# export\n",
- "class NERInference:\n",
- " \"\"\"\n",
- " NER Inference pipeline\n",
- " ner = NERInference.from_pretrained('xxxx/xxxx')\n",
- " ner.predict(['text1','text2'])\n",
- " \"\"\"\n",
- "\n",
- " def __init__(self, model, tokenizer, name=None):\n",
- " super().__init__()\n",
- " self.model = model.eval()\n",
- " self.tokenizer = tokenizer\n",
- " self.name = name if name else \"NER model\"\n",
- "\n",
- " def __repr__(self):\n",
- " return f\"[NERInference on {self.name}]\"\n",
- "\n",
- " def to(self, device_str):\n",
- " self.model = self.model.to(device(device_str))\n",
- " return self\n",
- "\n",
- " @classmethod\n",
- " def from_pretrained(cls, tag):\n",
- " \"\"\"\n",
- " Load from pretrained model and tokenizer\n",
- " \"\"\"\n",
- " model = AutoModelForTokenClassification.from_pretrained(tag)\n",
- " tokenizer = AutoTokenizer.from_pretrained(tag)\n",
- " return cls(model=model, tokenizer=tokenizer, name=model.config._name_or_path)\n",
- " \n",
- " def __call__(self, data, batch_size=32, dev=device(\"cpu\")):\n",
- " if type(data) == str:\n",
- " return self.batch_predict([data,])\n",
- " else:\n",
- " return self.predict(data, dev=dev, batch_size=batch_size)\n",
- "\n",
- " def predict(\n",
- " self,\n",
- " texts: List[str],\n",
- " dev=device(\"cpu\"),\n",
- " batch_size: int = 32,\n",
- " progress_bar: bool = True\n",
- " ) -> pd.DataFrame:\n",
- " \"\"\"\n",
- " Predict a list of sentences/ paragraphs\n",
- " \"\"\"\n",
- " # place the model into device\n",
- " self.model = self.model.to(dev)\n",
- " iterator = list(enumerate(chunkify(texts, bs=batch_size)))\n",
- " if progress_bar:\n",
- " iterator = tqdm(iterator, leave=False)\n",
- "\n",
- " # run through iterator\n",
- " all_dfs = []\n",
- " for i, text_b in iterator:\n",
- " # by batch prediction\n",
- " batch_df = self.batch_predict(text_b)\n",
- " if len(batch_df) > 0:\n",
- " # calculate the row number\n",
- " batch_df['text_id'] = batch_df.apply(\n",
- " lambda row: i*batch_size+row.batch_row_sn, axis=1)\n",
- " all_dfs.append(batch_df)\n",
- "\n",
- " # place the model back to cpu\n",
- " self.model = self.model.to(\"cpu\")\n",
- " return pd.concat(all_dfs).reset_index(drop=True)\n",
- " \n",
- " def tokenizing(self, texts):\n",
- " inputs = self.tokenizer(\n",
- " texts,\n",
- " padding=\"max_length\",\n",
- " max_length=self.tokenizer.model_max_length,\n",
- " return_attention_mask=True,\n",
- " return_tensors='pt', truncation=True, return_offsets_mapping=True\n",
- " ).to(self.model.device)\n",
- " return inputs\n",
- "\n",
- "\n",
- " def batch_predict(self, texts:List[str])-> pd.DataFrame:\n",
- " \"\"\"\n",
- " Predict a single batch of sentences\n",
- " \"\"\"\n",
- " id2label = self.model.config.id2label\n",
- " inputs = self.tokenizing(texts)\n",
- "\n",
- " with torch.no_grad():\n",
- " outputs = self.model(input_ids=inputs.input_ids,\n",
- " attention_mask=inputs.attention_mask)\n",
- " inputs = inputs.to(device('cpu'))\n",
- "\n",
- " pred_idx = outputs.logits.argmax(-1).to(device(\"cpu\"))\n",
- " batch_size = pred_idx.size(0)\n",
- " offsets = inputs.offset_mapping\n",
- " results = []\n",
- " for bi in range(batch_size):\n",
- " text = texts[bi]\n",
- " input_ids = inputs.input_ids[bi]\n",
- " word_ids = inputs.word_ids(bi)\n",
- " pred_ids = pred_idx[bi]\n",
- " # initial values for the row\n",
- " last_pos = 0\n",
- " previous_has_positive = False\n",
- " current_start = 0\n",
- " current_index = 0\n",
- " current_id = 0\n",
- " line = []\n",
- " for ti in range(1, len(input_ids)):\n",
- " if input_ids[ti] == self.tokenizer.sep_token_id:\n",
- " break\n",
- " # is the current token an appending sub-word?\n",
- " if word_ids[ti] == last_pos:\n",
- " pass\n",
- " # is current token negative\n",
- " elif pred_ids[ti].item() == 0:\n",
- " # store the previous hanging prediction\n",
- " if previous_has_positive:\n",
- " start = current_start\n",
- " end = offsets[bi, ti, 0].item()\n",
- " line.append({\n",
- " \"start\": start, \"end\": end,\n",
- " \"entity\": id2label[current_id],\n",
- " \"word\": text[start:end],\n",
- " \"index\": current_index,\n",
- " })\n",
- "\n",
- " current_start = offsets[bi, ti, 0].item()\n",
- " previous_has_positive = False\n",
- " current_id = 0\n",
- " current_index = ti\n",
- " # has positive prediction index, other than zero\n",
- " else:\n",
- " if previous_has_positive:\n",
- " # different than the previous\n",
- " if current_id != pred_ids[ti].item():\n",
- " start = current_start\n",
- " end = offsets[bi, ti, 0].item()\n",
- " line.append({\n",
- " \"start\": start,\n",
- " \"end\": end,\n",
- " \"entity\": id2label[current_id],\n",
- " \"word\": text[start:end],\n",
- " \"index\": current_index,\n",
- " })\n",
- " current_start = offsets[bi, ti, 0].item()\n",
- " # this is the 1st postive predict for a while\n",
- " else:\n",
- " current_start = offsets[bi, ti, 0].item()\n",
- " previous_has_positive = True\n",
- " current_index = ti\n",
- " current_id = pred_ids[ti].item()\n",
- "\n",
- " last_pos = word_ids[ti]\n",
- " if previous_has_positive:\n",
- " start = current_start\n",
- " end = offsets[bi, ti, 1].item()\n",
- " line.append({\n",
- " \"start\": start,\n",
- " \"end\": end,\n",
- " \"entity\": id2label[current_id],\n",
- " \"word\": text[start:end],\n",
- " \"index\": current_index,\n",
- " })\n",
- "\n",
- " results.append(line)\n",
- " all_dfs = []\n",
- " for i, res in enumerate(results):\n",
- " sub_df = pd.DataFrame(res)\n",
- " sub_df[\"batch_row_sn\"] = i\n",
- " all_dfs.append(sub_df)\n",
- " return pd.concat(all_dfs)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.7.4"
- },
- "toc": {
- "base_numbering": 1,
- "nav_menu": {},
- "number_sections": true,
- "sideBar": true,
- "skip_h1_title": false,
- "title_cell": "Table of Contents",
- "title_sidebar": "Contents",
- "toc_cell": false,
- "toc_position": {},
- "toc_section_display": true,
- "toc_window_display": false
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}
diff --git a/nbs/cross_entropy_weighter.ipynb b/nbs/cross_entropy_weighter.ipynb
deleted file mode 100644
index 95018ab..0000000
--- a/nbs/cross_entropy_weighter.ipynb
+++ /dev/null
@@ -1,1298 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Create multi-task adjustment for cross-entropy\n",
- "> A weighter for multi-task learning with different softmax+cross-entropy"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Tools and imports"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# default_exp multitask_ce"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "# export\n",
- "import torch\n",
- "from torch import nn\n",
- "import pandas as pd\n",
- "import numpy as np\n",
- "from typing import Callable"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 120,
- "metadata": {},
- "outputs": [],
- "source": [
- "from matplotlib import pyplot as plt"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Enters softmax"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 102,
- "metadata": {},
- "outputs": [],
- "source": [
- "softmax = nn.Softmax(-1)\n",
- "crit = nn.CrossEntropyLoss()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Experience the problem"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 104,
- "metadata": {},
- "outputs": [],
- "source": [
- "def test_loss(model,iters:300,get_xy):\n",
- " losses=[]\n",
- " with torch.no_grad():\n",
- " for i in range(iters):\n",
- " x,y_true = get_xy(model.nb_output)\n",
- " y_vec = model(x)\n",
- " loss = model.crit(y_vec,y_true)\n",
- " losses.append(loss)\n",
- " return torch.stack(losses).mean()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 245,
- "metadata": {},
- "outputs": [],
- "source": [
- "def create_softmax_pipeline(\n",
- " nb_layers:int,\n",
- " nb_output:int,\n",
- " hs:int=500,\n",
- " crit:nn.Module=nn.CrossEntropyLoss(),\n",
- " )->nn.Module:\n",
- " modules = (nb_layers-1)*[nn.Linear(hs,hs),nn.Dropout(.3)]\n",
- " modules+=[nn.Linear(hs,nb_output),]\n",
- " model = nn.Sequential(*modules)\n",
- " model.hs = hs\n",
- " model.nb_output = nb_output\n",
- " model.__class__.crit = crit\n",
- " model.__class__.test_loss = test_loss\n",
- " return model"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 246,
- "metadata": {},
- "outputs": [],
- "source": [
- "def random_input(nb_output):\n",
- " return torch.rand(2,500),torch.randint(low=0,high=nb_output,size = (2,))\n",
- "\n",
- "def inbalanced_input(\n",
- " bias:float\n",
- " ) -> Callable:\n",
- " def inbalanced_input_(nb_output:int):\n",
- " return torch.rand(2,500),torch.randint(\n",
- " low=0,\n",
- " high=max(1,int(nb_output*bias)),\n",
- " size = (2,)\n",
- " )\n",
- " return inbalanced_input_"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "* For models with several output branches, some branches predict 2 categories while others predict many more, e.g. 200 or 500\n",
- "* Their losses end up on different scales\n",
- "* That makes combining them fairly hard, and unfair to the tasks with fewer categories"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 247,
- "metadata": {},
- "outputs": [],
- "source": [
- "import random"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 266,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "fa780ff8fdab4d4cb932c4dbfe96c44f",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\n"
- ]
- }
- ],
- "source": [
- "result = []\n",
- "\n",
- "for i in tqdm(range(50)):\n",
- " c = random.randint(2,3000)\n",
- " b = 1-random.random()\n",
- " loss = create_softmax_pipeline(1,c).test_loss(300,inbalanced_input(b))\n",
- " result.append(dict(classes=c,loss=loss.item(),inbl_bias=b))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 267,
- "metadata": {},
- "outputs": [],
- "source": [
- "df = pd.DataFrame(result)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 268,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "
\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " classes | \n",
- " loss | \n",
- " inbl_bias | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 0 | \n",
- " 2244 | \n",
- " 7.784015 | \n",
- " 0.719190 | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " 2447 | \n",
- " 7.849158 | \n",
- " 0.002960 | \n",
- "
\n",
- " \n",
- " 2 | \n",
- " 706 | \n",
- " 6.625416 | \n",
- " 0.297956 | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " 272 | \n",
- " 5.739001 | \n",
- " 0.154829 | \n",
- "
\n",
- " \n",
- " 4 | \n",
- " 2860 | \n",
- " 7.994882 | \n",
- " 0.179631 | \n",
- "
\n",
- " \n",
- " 5 | \n",
- " 2376 | \n",
- " 7.823036 | \n",
- " 0.554090 | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " 474 | \n",
- " 6.230643 | \n",
- " 0.497769 | \n",
- "
\n",
- " \n",
- " 7 | \n",
- " 1476 | \n",
- " 7.317650 | \n",
- " 0.790722 | \n",
- "
\n",
- " \n",
- " 8 | \n",
- " 729 | \n",
- " 6.652312 | \n",
- " 0.671657 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " 1183 | \n",
- " 7.161001 | \n",
- " 0.422553 | \n",
- "
\n",
- " \n",
- " 10 | \n",
- " 1792 | \n",
- " 7.571034 | \n",
- " 0.375524 | \n",
- "
\n",
- " \n",
- " 11 | \n",
- " 2601 | \n",
- " 7.928347 | \n",
- " 0.937953 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " 1988 | \n",
- " 7.653401 | \n",
- " 0.448365 | \n",
- "
\n",
- " \n",
- " 13 | \n",
- " 617 | \n",
- " 6.471711 | \n",
- " 0.669720 | \n",
- "
\n",
- " \n",
- " 14 | \n",
- " 1257 | \n",
- " 7.196665 | \n",
- " 0.180724 | \n",
- "
\n",
- " \n",
- " 15 | \n",
- " 1208 | \n",
- " 7.142272 | \n",
- " 0.603882 | \n",
- "
\n",
- " \n",
- " 16 | \n",
- " 1571 | \n",
- " 7.431909 | \n",
- " 0.307233 | \n",
- "
\n",
- " \n",
- " 17 | \n",
- " 1815 | \n",
- " 7.565763 | \n",
- " 0.271957 | \n",
- "
\n",
- " \n",
- " 18 | \n",
- " 2370 | \n",
- " 7.840256 | \n",
- " 0.130841 | \n",
- "
\n",
- " \n",
- " 19 | \n",
- " 2421 | \n",
- " 7.835362 | \n",
- " 0.470347 | \n",
- "
\n",
- " \n",
- " 20 | \n",
- " 2608 | \n",
- " 7.933852 | \n",
- " 0.362477 | \n",
- "
\n",
- " \n",
- " 21 | \n",
- " 1833 | \n",
- " 7.566414 | \n",
- " 0.551423 | \n",
- "
\n",
- " \n",
- " 22 | \n",
- " 1769 | \n",
- " 7.536047 | \n",
- " 0.630381 | \n",
- "
\n",
- " \n",
- " 23 | \n",
- " 1950 | \n",
- " 7.615793 | \n",
- " 0.795910 | \n",
- "
\n",
- " \n",
- " 24 | \n",
- " 1910 | \n",
- " 7.585739 | \n",
- " 0.343632 | \n",
- "
\n",
- " \n",
- " 25 | \n",
- " 2301 | \n",
- " 7.798465 | \n",
- " 0.829628 | \n",
- "
\n",
- " \n",
- " 26 | \n",
- " 2377 | \n",
- " 7.869211 | \n",
- " 0.016988 | \n",
- "
\n",
- " \n",
- " 27 | \n",
- " 1496 | \n",
- " 7.393178 | \n",
- " 0.119690 | \n",
- "
\n",
- " \n",
- " 28 | \n",
- " 1179 | \n",
- " 7.147058 | \n",
- " 0.568997 | \n",
- "
\n",
- " \n",
- " 29 | \n",
- " 1475 | \n",
- " 7.356740 | \n",
- " 0.693435 | \n",
- "
\n",
- " \n",
- " 30 | \n",
- " 2409 | \n",
- " 7.835954 | \n",
- " 0.455461 | \n",
- "
\n",
- " \n",
- " 31 | \n",
- " 542 | \n",
- " 6.328403 | \n",
- " 0.733283 | \n",
- "
\n",
- " \n",
- " 32 | \n",
- " 699 | \n",
- " 6.609892 | \n",
- " 0.674927 | \n",
- "
\n",
- " \n",
- " 33 | \n",
- " 2966 | \n",
- " 8.062770 | \n",
- " 0.268944 | \n",
- "
\n",
- " \n",
- " 34 | \n",
- " 2831 | \n",
- " 8.006001 | \n",
- " 0.822981 | \n",
- "
\n",
- " \n",
- " 35 | \n",
- " 1657 | \n",
- " 7.459473 | \n",
- " 0.323996 | \n",
- "
\n",
- " \n",
- " 36 | \n",
- " 1469 | \n",
- " 7.316678 | \n",
- " 0.062805 | \n",
- "
\n",
- " \n",
- " 37 | \n",
- " 674 | \n",
- " 6.579655 | \n",
- " 0.812632 | \n",
- "
\n",
- " \n",
- " 38 | \n",
- " 654 | \n",
- " 6.550060 | \n",
- " 0.601652 | \n",
- "
\n",
- " \n",
- " 39 | \n",
- " 194 | \n",
- " 5.350555 | \n",
- " 0.238536 | \n",
- "
\n",
- " \n",
- " 40 | \n",
- " 250 | \n",
- " 5.567015 | \n",
- " 0.270642 | \n",
- "
\n",
- " \n",
- " 41 | \n",
- " 2368 | \n",
- " 7.819131 | \n",
- " 0.117133 | \n",
- "
\n",
- " \n",
- " 42 | \n",
- " 1573 | \n",
- " 7.429606 | \n",
- " 0.003537 | \n",
- "
\n",
- " \n",
- " 43 | \n",
- " 2994 | \n",
- " 8.053926 | \n",
- " 0.764925 | \n",
- "
\n",
- " \n",
- " 44 | \n",
- " 2155 | \n",
- " 7.741345 | \n",
- " 0.941843 | \n",
- "
\n",
- " \n",
- " 45 | \n",
- " 2642 | \n",
- " 7.919126 | \n",
- " 0.612718 | \n",
- "
\n",
- " \n",
- " 46 | \n",
- " 486 | \n",
- " 6.242527 | \n",
- " 0.701243 | \n",
- "
\n",
- " \n",
- " 47 | \n",
- " 1948 | \n",
- " 7.644366 | \n",
- " 0.970036 | \n",
- "
\n",
- " \n",
- " 48 | \n",
- " 2626 | \n",
- " 7.932363 | \n",
- " 0.657842 | \n",
- "
\n",
- " \n",
- " 49 | \n",
- " 1660 | \n",
- " 7.466812 | \n",
- " 0.855916 | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " classes loss inbl_bias\n",
- "0 2244 7.784015 0.719190\n",
- "1 2447 7.849158 0.002960\n",
- "2 706 6.625416 0.297956\n",
- "3 272 5.739001 0.154829\n",
- "4 2860 7.994882 0.179631\n",
- "5 2376 7.823036 0.554090\n",
- "6 474 6.230643 0.497769\n",
- "7 1476 7.317650 0.790722\n",
- "8 729 6.652312 0.671657\n",
- "9 1183 7.161001 0.422553\n",
- "10 1792 7.571034 0.375524\n",
- "11 2601 7.928347 0.937953\n",
- "12 1988 7.653401 0.448365\n",
- "13 617 6.471711 0.669720\n",
- "14 1257 7.196665 0.180724\n",
- "15 1208 7.142272 0.603882\n",
- "16 1571 7.431909 0.307233\n",
- "17 1815 7.565763 0.271957\n",
- "18 2370 7.840256 0.130841\n",
- "19 2421 7.835362 0.470347\n",
- "20 2608 7.933852 0.362477\n",
- "21 1833 7.566414 0.551423\n",
- "22 1769 7.536047 0.630381\n",
- "23 1950 7.615793 0.795910\n",
- "24 1910 7.585739 0.343632\n",
- "25 2301 7.798465 0.829628\n",
- "26 2377 7.869211 0.016988\n",
- "27 1496 7.393178 0.119690\n",
- "28 1179 7.147058 0.568997\n",
- "29 1475 7.356740 0.693435\n",
- "30 2409 7.835954 0.455461\n",
- "31 542 6.328403 0.733283\n",
- "32 699 6.609892 0.674927\n",
- "33 2966 8.062770 0.268944\n",
- "34 2831 8.006001 0.822981\n",
- "35 1657 7.459473 0.323996\n",
- "36 1469 7.316678 0.062805\n",
- "37 674 6.579655 0.812632\n",
- "38 654 6.550060 0.601652\n",
- "39 194 5.350555 0.238536\n",
- "40 250 5.567015 0.270642\n",
- "41 2368 7.819131 0.117133\n",
- "42 1573 7.429606 0.003537\n",
- "43 2994 8.053926 0.764925\n",
- "44 2155 7.741345 0.941843\n",
- "45 2642 7.919126 0.612718\n",
- "46 486 6.242527 0.701243\n",
- "47 1948 7.644366 0.970036\n",
- "48 2626 7.932363 0.657842\n",
- "49 1660 7.466812 0.855916"
- ]
- },
- "execution_count": 268,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "df"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### The pattern"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "There are $n$ tasks $T$ with a set of class counts $C$, where $c_{i}$ is the number of classes of task $T_{i}$\n",
- "\n",
- "The number of classes $c_{i}$ is correlated with the average loss $L_{i}$\n",
- "\n",
- "Specifically, $\\log_{10}(c_{i})$ has a clear linear relationship with $L_{i}$\n",
- "\n",
- "$L_{i} = a \\cdot \\log_{10}(c_{i})$, where $a$ is a fixed constant (for an untrained model whose predictions are close to uniform, $L_{i} \\approx \\ln(c_{i})$, so $a \\approx \\ln 10 \\approx 2.3$)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 269,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- ""
- ]
- },
- "execution_count": 269,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAVw0lEQVR4nO3dfZBd913f8ffXqzWs8rQmFiTaSBa0QW0Tjy13a8vNlAl1iMhDE8EorV1MwMOMcGBC3BYNces6ENwSum0hqYcIDxkaiOu6qMrWMMYy0xDIMCMxa0u24jhiRB4krQLIISvX9kLWq2//2Hudq7v34dzVfTz7fs3s6N5zftr9/sbWR0ff8zu/G5mJJGn0XTboAiRJ3WGgS1JJGOiSVBIGuiSVhIEuSSWxYVA/+Morr8xt27YN6sdL0kh67LHHnsnMTY3ODSzQt23bxtzc3KB+vCSNpIj4arNztlwkqSQMdEkqCQNdkkrCQJekkjDQJakkDHRJKomBLVuUpLKbPTrPzKETnF1YZPPkBPt2bWf3jqme/Tyv0CWpB2aPznPnwePMLyySwPzCInc8eIwdH36U2aPzPfmZBrok9cDMoRMsLi2vOv6NF5a48+DxnoS6gS5JPXB2YbHpucWlZWYOnej6z7SHLkldUN8vf9XEOAuLS03Htwr8tSp0hR4R/yoinoqIz0fEAxHx7XXnvy0iHoyIkxFxJCK2db1SSRpSjfrlrcIcYPPkRNfraBvoETEF/AwwnZlvBMaAm+uG/QTwjcz8u8CvAL/c7UIlaVg165c3MzE+xr5d27teR9Ee+gZgIiI2ABuBs3Xn3w18svL6AHBTRER3SpSk4daufTI5Mc7U5AQBTE1O8Es/fHVPli+27aFn5nxE/GfgFLAIPJqZj9YNmwJOV8a/GBHngVcDz9QOioi9wF6ArVu3Xnr1kjQENk9OMN8i1M8vLnHsQ2/teR1FWi5XsHIF/t3AZuBlEXFr/bAGvzVXHci8LzOnM3N606aG+7NL0sjZt2s7E+NjTc/3ol/eSJFVLm8BvpyZ5wAi4iDwj4FP1Yw5A2wBzlTaMq8C/rrLtUrSQNw1e5z7D5966Sr1ZZeP8R9+6Fttk+qvv/C7T/GNFy6+GdqrfnkjRXrop4CdEbGx0he/CXi6bsxDwI9VXu8BPpOZq67QJWnU3DV7nE/VhDnA899c5o4Hj3HX7PGXju3eMcXRu9/Kr/6La/vSL2+kSA/9SEQcAB4HXgSOAvdFxIeBucx8CPgE8NsRcZKVK/P6VTCSNJIeOHK66bn7D59i+qrvuCiwd++Y6luA1yv0YFFmfgj4UN3hu2vO/w3wni7WJUlDYblFsyFZWbI4qACv55OiktatIrshjkW0DPVePPG5Vu7lImldavR0Z6NNs265YUvL79OvFSxFGOiS1qVGT3cuLq3c7HzTRz7zUrDfs/tqbt3Z+LmZfq5gKcJAl7QutXoQqP5q/Z7dV/OVj7xjoCtYirCHLmldatcbr25xOywrWIow0CWtK9Uboa3CvGqYbngWYaBLKr1qiM8vLBI02JekiWG64VmEgS6p1KqrWao3QDt5hH2YbngW4U1RSaXW6V7lVW/6O98x1P3yRrxCl1Qq9Q8LtVrN0sytO7dyz+6re1Bdbxnokkqjvr3SSc98LIJbbtgykkFeZaBLKo1G7ZWEVaFefT/V5HH/UWWgSyqNZssMq+Hdas+WMjDQJZVGs5751OQEf/LBfzqAivrLVS6SSqPRR8EN234rveQVuqTSqLZR2m2JW1YGuqShV7sUcXLjOJlwfnGpYWAP+34rvWSgSxpq9UsRaz+EuborIrBuQ7yWPXRJQ63dk57VXRFloEsackV2PBy1XRF7xUCXNNSK7Hg4arsi9oqBLmlozR6d5/m/fbHlmPW0LLEdb4pKGqhme5VvHL+MpQvJ0nLznVjGIobuY+AGyUCXNDCt9ip/YelCy987
MT5mmNcx0CUNzM8/9NSa9iov26Za3WKgSxqI2aPzLCwutR9YZ73sy7IW3hSVNBBrWTvuDdDW2gZ6RGyPiGM1X89GxB11Y94cEedrxtzdu5IllUGRteOXBVyxcZxg5crcnnlrbVsumXkCuBYgIsaAeeDTDYZ+LjPf2d3yJJVVu4+Hu2LjOB/6Z28wwDvQaQ/9JuDPM/OrvShGUvnUf8Zn9Wbmvl3bL1rhAq5cuVSdBvrNwANNzt0YEU8AZ4Gfzcyn6gdExF5gL8DWrVs7/NGSRs3s0Xn2HXjipbXk8wuL7DvwBOBWt70QmUU+PhUi4nJWwvoNmfmXdedeCVzIzOci4u3ARzPz9a2+3/T0dM7Nza2xbEmjYMeHH71od8SqKzaOc/Tutw6gotEXEY9l5nSjc51cob8NeLw+zAEy89ma1w9HxK9FxJWZ+Uzn5UoaRXfNHueBI6dZzmQsgltu2NIwzIGmx3VpOlm2eAtN2i0R8ZqIiMrr6yvf9+uXXp6kUXDX7HE+dfgUy5V/8S9n8qnDpwZc1fpT6Ao9IjYCPwD8ZM2x2wEycz+wB3hfRLwILAI3Z9FejqSR98CR0x2Nn5wY71El61uhQM/MF4BX1x3bX/P6XuDe7pYmaRg1WrWy3OL6bfyyYOlCXvT+59/1hn6Uuu74pKikwqqbac0vLJJ86yPgVhquq41FMPOea5ianHjp4aCZ91zjSpYecS8XSYU12kxrcWmZjeOXNdwd8ZYbtqzrD23uN6/QJRXSajOtxaUL3LpzK2OVS/WxCG7duZV7dl/dzxLXPa/QJRXSajOtzZMT3LP7agN8wLxCl1RIq8203AFxOBjokgqZGG8cF5ePhT3yIWGgSypk8cXGHwlXuyRRg2WgSyqk2VJzHyEcHga6pELGmiw2b3Zc/WegSyrklhu2dHRc/eeyRUmFVJck1u+o6FLF4VF4P/Rucz90Sepcq/3QbblIUknYcpHWgWaf66lyMdClkqvukFjdVKu6QyJgqJeMgS6VSKMr8ZlDJxrukDhz6ISBXjIGulQSd80e5/7Dp6guc6heideHeVWrvVk0mrwpKpXA7NH5i8K8qlmYw8oOiSoXA10qgZlDJ1aFeSsT42PukFhCtlykEijSPhmL4EKmq1xKzECXSmDz5ATzbUL9QiZf/sg7+lSRBsGWi1QC+3ZtZ2J8rOUYe+bl5xW6VALV9snMoRPMLywScFFP3Z75+mCgSyWxe8fUS8Huk6Hrk4EulVBtuGv9sIcuSSVhoEtSSbQN9IjYHhHHar6ejYg76sZERHwsIk5GxJMRcV3vSpYkNdK2h56ZJ4BrASJiDJgHPl037G3A6ytfNwAfr/wqSeqTTlsuNwF/nplfrTv+buC3csVhYDIiXtuVCiVJhXQa6DcDDzQ4PgWcrnl/pnLsIhGxNyLmImLu3LlzHf5oSVIrhQM9Ii4H3gX8TqPTDY6t2isoM+/LzOnMnN60aVPxKiVJbXVyhf424PHM/MsG584AW2revw44eymFSZI600mg30LjdgvAQ8B7K6tddgLnM/Nrl1ydJKmwQk+KRsRG4AeAn6w5djtAZu4HHgbeDpwEXgBu63qlkqSWCgV6Zr4AvLru2P6a1wn8dHdLkyR1widFJakkDHRJKgkDXZJKwkCXpJIw0CWpJAx0SSoJA12SSsJAl6SSMNAlqSQMdEkqCQNdkkrCQJekkjDQJakkDHRJKgkDXZJKwkCXpJIo9AEXUhnNHp1n5tAJzi4ssnlygn27trN7x9Sgy5LWzEDXujR7dJ47Dx5ncWkZgPmFRe48eBzAUNfIMtC1btRekV8WwXLmRecXl5aZOXTCQNfIMtC1Ltw1e5z7D5+iGuH1YV51dmGxf0VJXeZNUZXe7NH5i8K8lc2TEz2vR+oVr9BVSvXtlSJhPjE+xr5d23tem9QrBrpKp/6GZ7P2CsBYBBcyXeWiUjDQVRrVq/L5gn3wAP7LP7/GEFdpGOgqhfqr8nYC+JGdWw1zlYqB
rlKYOXSibZjbXlHZGegqhXbLDSfGx/ilH77aEFepFVq2GBGTEXEgIr4YEU9HxI11598cEecj4ljl6+7elCs11mq54dTkhGGudaHoFfpHgUcyc09EXA5sbDDmc5n5zu6VJhW3b9f2VT10r8q13rQN9Ih4JfB9wI8DZOY3gW/2tiypsWYbalVD2822tJ4VuUL/HuAc8JsRcQ3wGPCBzHy+btyNEfEEcBb42cx8qv4bRcReYC/A1q1bL6lwrT/tNtSqDXZpPSrSQ98AXAd8PDN3AM8DH6wb8zhwVWZeA/w3YLbRN8rM+zJzOjOnN23adAllaz1qtJKluqGWpGKBfgY4k5lHKu8PsBLwL8nMZzPzucrrh4HxiLiyq5Vq3Wu2ksUNtaQVbQM9M/8COB0R1U0ubgK+UDsmIl4TEVF5fX3l+369y7VqnWu2ksUNtaQVRXdbfD9wf0Q8CVwL/MeIuD0ibq+c3wN8vtJD/xhwc2aLDTSkNdi3azsT42MXHXNDLelbYlC5Oz09nXNzcwP52Rpdfmyc1ruIeCwzpxud80lRjRRXskjN+QEXklQSBroklYQtFw2EvXCp+wx09V27Jz4lrY0tF/WdT3xKvWGgq+984lPqDQNdfecTn1JvGOjqO5/4lHrDm6LqO/cul3rDQNclqS4/nF9YZCyC5UymCgS0T3xK3Wega83qlx8uV/YFchmiNBj20LVmjZYfVrkMUeo/A11r1m6ZocsQpf4y0LVm7ZYZugxR6i8DXWvWaPlhlcsQpf7zpqjWrHb5YaerXCR1n4GuS+LyQ2l42HKRpJIw0CWpJAx0SSoJA12SSsJAl6SSMNAlqSQMdEkqCQNdkkrCQJekkigU6BExGREHIuKLEfF0RNxYdz4i4mMRcTIinoyI63pTriSpmaKP/n8UeCQz90TE5cDGuvNvA15f+boB+HjlV0lSn7S9Qo+IVwLfB3wCIDO/mZkLdcPeDfxWrjgMTEbEa7terSSpqSItl+8BzgG/GRFHI+I3IuJldWOmgNM1789Ujl0kIvZGxFxEzJ07d27NRUuSVisS6BuA64CPZ+YO4Hngg3VjosHvy1UHMu/LzOnMnN60aVPHxUqSmisS6GeAM5l5pPL+ACsBXz9mS8371wFnL708SVJRbQM9M/8COB0R1Y+fuQn4Qt2wh4D3Vla77ATOZ+bXuluqJKmVoqtc3g/cX1nh8iXgtoi4HSAz9wMPA28HTgIvALf1oFZJUguFAj0zjwHTdYf315xP4Ke7WJckqUM+KSpJJWGgS1JJGOiSVBIGuiSVhIEuSSVhoEtSSRjoklQSBroklYSBLkklYaBLUkkU3ctFBcwenWfm0AnOLiyyeXKCfbu2s3vHqm3hJaknDPQumT06z50Hj7O4tAzA/MIidx48DmCoS+oLWy5dMnPoxEthXrW4tMzMoRMDqkjSemOgd8nZhcWOjktStxnoXbJ5cqKj45LUbQZ6l+zbtZ2J8bGLjk2Mj7Fv1/Ymv0OSusubol1SvfHpKhdJg2Kgd9HuHVMGuKSBseUiSSVhoEtSSRjoklQSBroklYQ3RQtwjxZJo8BAb8M9WiSNClsubbhHi6RRYaC34R4tkkaFgd6Ge7RIGhWFAj0ivhIRxyPiWETMNTj/5og4Xzl/LCLu7n6pg+EeLZJGRSc3Rb8/M59pcf5zmfnOSy1o2LhHi6RR4SqXAtyjRdIoKNpDT+DRiHgsIvY2GXNjRDwREb8fEW9oNCAi9kbEXETMnTt3bk0FS5IaK3qF/qbMPBsR3wn8QUR8MTP/uOb848BVmflcRLwdmAVeX/9NMvM+4D6A6enpvMTaJUk1Cl2hZ+bZyq9/BXwauL7u/LOZ+Vzl9cPAeERc2eVaJUkttA30iHhZRLyi+hp4K/D5ujGviYiovL6+8n2/3v1yJUnNFGm5fBfw6UpebwD+R2Y+EhG3A2TmfmAP8L6IeBFYBG7OTFsqktRHbQM9M78EXNPg+P6a1/cC93a3NElSJ3xS
VJJKwkCXpJIY2QeL3KNcki42koHuHuWStNpItlzco1ySVhvJQHePcklabSQD3T3KJWm1kQx09yiXpNVG8qaoe5RL0mojGejgHuWSVG8kWy6SpNUMdEkqCQNdkkrCQJekkjDQJakkDHRJKokY1AcLRcQ54KsD+eGX5krgmUEX0QVlmQeUZy5lmQeUZy7DOI+rMnNToxMDC/RRFRFzmTk96DouVVnmAeWZS1nmAeWZy6jNw5aLJJWEgS5JJWGgd+6+QRfQJWWZB5RnLmWZB5RnLiM1D3voklQSXqFLUkkY6JJUEgZ6AxGxJSL+MCKejoinIuIDLcb+o4hYjog9/ayxiKLziIg3R8Sxypg/6nedRRSZS0S8KiJ+NyKeqIy5bRC1thIR3x4Rf1pT4y80GPNtEfFgRJyMiCMRsa3/lbZXcC7/OiK+EBFPRsT/jYirBlFrK0XmUTN2T0RkRAznUsbM9KvuC3gtcF3l9SuAPwP+QYNxY8BngIeBPYOuey3zACaBLwBbK++/c9B1X8Jc/i3wy5XXm4C/Bi4fdO11NQbw8srrceAIsLNuzE8B+yuvbwYeHHTdlzCX7wc2Vl6/bxjnUmQeNf/f/TFwGJgedN2NvrxCbyAzv5aZj1de/z/gaaDRp2m8H/jfwF/1sbzCCs7jXwIHM/NUZdwozyWBV0REAC9nJdBf7GuhbeSK5ypvxytf9SsT3g18svL6AHBTZU5DpchcMvMPM/OFytvDwOv6WGIhBf+bAPwi8J+Av+lXbZ0y0Nuo/HN3Byt/a9cenwJ+CNjf/6o612wewPcCV0TEZyPisYh4b79r61SLudwL/H3gLHAc+EBmXuhrcQVExFhEHGPlQuAPMrN+HlPAaYDMfBE4D7y6v1UWU2AutX4C+P3+VNaZdvOIiB3Alsz8vYEUWJCB3kJEvJyVK/A7MvPZutO/CvxcZi73v7LOtJnHBuAfAu8AdgH/PiK+t88lFtZmLruAY8Bm4Frg3oh4ZZ9LbCszlzPzWlauVq+PiDfWDWl0NT6U64sLzAWAiLgVmAZm+llfUa3mERGXAb8C/JtB1VeUgd5ERIyzEhz3Z+bBBkOmgf8ZEV8B9gC/FhG7+1hiIQXmcQZ4JDOfz8xnWOkRXtPPGosqMJfbWGkfZWaeBL4M/L1+1tiJzFwAPgv8YN2pM8AWgIjYALyKlfbR0GoxFyLiLcC/A96VmX/b59I60mQerwDeCHy28ud9J/DQMN4YNdAbqPQrPwE8nZn/tdGYzPzuzNyWmdtY6XP+VGbO9rHMtorMA/g/wD+JiA0RsRG4gZX+9FApOJdTwE2V8d8FbAe+1J8Ki4mITRExWXk9AbwF+GLdsIeAH6u83gN8Jit35YZJkblUWhW/zkqYD+X9mXbzyMzzmXllzZ/3w6zMZ24gBbewYdAFDKk3AT8KHK/01WBlBcVWgMwcib45BeaRmU9HxCPAk8AF4Dcy8/MDqba1Iv9NfhH47xFxnJW2xc9V/tUxTF4LfDIixli5oPpfmfl7EfFhYC4zH2LlL67fjoiTrFyZ3zy4clsqMpcZVm5Q/07lvu6pzHzXwCpurMg8RoKP/ktSSdhykaSSMNAlqSQMdEkqCQNdkkrCQJekkjDQJakkDHRJKon/D8XLOB7FzJCYAAAAAElFTkSuQmCC\n",
- "text/plain": [
- ""
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "plt.scatter(np.log10(df.classes),df.loss,)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### The solution\n",
- "\n",
- "Assume there is a function $f$ that produces a weight bringing each cross-entropy loss to the same constant scale $a$.\n",
- "\n",
- "$L_{i} \\cdot f(c_{i}) = a$\n",
- "\n",
- "$f(c_{i}) \\cdot a \\cdot \\log_{10}(c_{i}) = a$\n",
- "\n",
- "From this we can solve for $f(c_{i})$:\n",
- "\n",
- "$f(c_{i}) = \\frac{1}{\\log_{10}(c_{i})}$"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 270,
- "metadata": {},
- "outputs": [],
- "source": [
- "def adjust(nb_class):\n",
- " return 1/np.log10(nb_class)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 271,
- "metadata": {},
- "outputs": [],
- "source": [
- "df[\"lambda_weighted\"] = df.apply(\n",
- " lambda row:row['loss']*adjust(row['classes']),\n",
- " axis=1)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Now the losses are on about the same scale"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 272,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- ""
- ]
- },
- "execution_count": 272,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "\n",
- "text/plain": [
- ""
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "plt.scatter(df.classes,df.lambda_weighted)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Adjusted CrossEntropy"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 288,
- "metadata": {},
- "outputs": [],
- "source": [
- "# export\n",
- "class MultiTaskCELoss(nn.Module):\n",
- " \"\"\"\n",
- " A cross entropy loss function which will cancel out\n",
- " the effect of different class numbers\n",
- " \"\"\"\n",
- " def __init__(self,):\n",
- " super().__init__()\n",
- " self.celoss = nn.CrossEntropyLoss()\n",
- " \n",
- " def forward(\n",
- " self,\n",
- " y_pred: torch.FloatTensor,\n",
- " y_true: torch.LongTensor,\n",
- " )-> torch.FloatTensor:\n",
- " \"\"\"\n",
- " Input:\n",
- " - y_pred: torch.FloatTensor, Prediction tensor\n",
- " - y_true: torch.LongTensor, Label indices\n",
- " Return:\n",
- " - loss: torch.FloatTensor, scala adjusted\n",
- " \"\"\"\n",
- " nb_classes = y_pred.size(-1)\n",
- " lambda_ = 1/np.log10(nb_classes)\n",
- " loss = self.celoss(y_pred,y_true)\n",
- " return loss*lambda_"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Let's build this adjustment into the loss function: an upgraded version of CrossEntropy"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 289,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "e869eca73d324737b011b32aebdeab43",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\n"
- ]
- }
- ],
- "source": [
- "result = []\n",
- "\n",
- "for i in tqdm(range(50)):\n",
- " c = random.randint(2,3000)\n",
- " b = 1-random.random()\n",
- " # here we change the loss function to MultiTaskCELoss\n",
- " loss = create_softmax_pipeline(1,c,crit=MultiTaskCELoss())\\\n",
- " .test_loss(300,inbalanced_input(b))\n",
- " result.append(dict(classes=c,loss=loss.item(),inbl_bias=b))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 283,
- "metadata": {},
- "outputs": [],
- "source": [
- "df_adjusted = pd.DataFrame(result)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 284,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " classes | \n",
- " loss | \n",
- " inbl_bias | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 0 | \n",
- " 533 | \n",
- " 2.319248 | \n",
- " 0.208584 | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " 2264 | \n",
- " 2.326346 | \n",
- " 0.022896 | \n",
- "
\n",
- " \n",
- " 2 | \n",
- " 254 | \n",
- " 2.320735 | \n",
- " 0.922445 | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " 1027 | \n",
- " 2.328291 | \n",
- " 0.643444 | \n",
- "
\n",
- " \n",
- " 4 | \n",
- " 2428 | \n",
- " 2.326136 | \n",
- " 0.883795 | \n",
- "
\n",
- " \n",
- " 5 | \n",
- " 772 | \n",
- " 2.317071 | \n",
- " 0.062716 | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " 2166 | \n",
- " 2.315754 | \n",
- " 0.364122 | \n",
- "
\n",
- " \n",
- " 7 | \n",
- " 812 | \n",
- " 2.321972 | \n",
- " 0.852977 | \n",
- "
\n",
- " \n",
- " 8 | \n",
- " 2632 | \n",
- " 2.321708 | \n",
- " 0.639467 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " 2576 | \n",
- " 2.313269 | \n",
- " 0.968200 | \n",
- "
\n",
- " \n",
- " 10 | \n",
- " 1326 | \n",
- " 2.321515 | \n",
- " 0.508901 | \n",
- "
\n",
- " \n",
- " 11 | \n",
- " 169 | \n",
- " 2.321526 | \n",
- " 0.870789 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " 815 | \n",
- " 2.309163 | \n",
- " 0.262275 | \n",
- "
\n",
- " \n",
- " 13 | \n",
- " 792 | \n",
- " 2.321771 | \n",
- " 0.608608 | \n",
- "
\n",
- " \n",
- " 14 | \n",
- " 739 | \n",
- " 2.316215 | \n",
- " 0.198933 | \n",
- "
\n",
- " \n",
- " 15 | \n",
- " 2460 | \n",
- " 2.318523 | \n",
- " 0.880343 | \n",
- "
\n",
- " \n",
- " 16 | \n",
- " 2420 | \n",
- " 2.320517 | \n",
- " 0.781538 | \n",
- "
\n",
- " \n",
- " 17 | \n",
- " 1688 | \n",
- " 2.316657 | \n",
- " 0.954598 | \n",
- "
\n",
- " \n",
- " 18 | \n",
- " 2929 | \n",
- " 2.323009 | \n",
- " 0.750869 | \n",
- "
\n",
- " \n",
- " 19 | \n",
- " 78 | \n",
- " 2.192723 | \n",
- " 0.089868 | \n",
- "
\n",
- " \n",
- " 20 | \n",
- " 1670 | \n",
- " 2.325608 | \n",
- " 0.188055 | \n",
- "
\n",
- " \n",
- " 21 | \n",
- " 2868 | \n",
- " 2.324525 | \n",
- " 0.058037 | \n",
- "
\n",
- " \n",
- " 22 | \n",
- " 905 | \n",
- " 2.319283 | \n",
- " 0.126673 | \n",
- "
\n",
- " \n",
- " 23 | \n",
- " 1675 | \n",
- " 2.309565 | \n",
- " 0.307769 | \n",
- "
\n",
- " \n",
- " 24 | \n",
- " 1430 | \n",
- " 2.327099 | \n",
- " 0.451335 | \n",
- "
\n",
- " \n",
- " 25 | \n",
- " 424 | \n",
- " 2.315844 | \n",
- " 0.512538 | \n",
- "
\n",
- " \n",
- " 26 | \n",
- " 1511 | \n",
- " 2.316325 | \n",
- " 0.613963 | \n",
- "
\n",
- " \n",
- " 27 | \n",
- " 959 | \n",
- " 2.330317 | \n",
- " 0.987407 | \n",
- "
\n",
- " \n",
- " 28 | \n",
- " 147 | \n",
- " 2.325459 | \n",
- " 0.931882 | \n",
- "
\n",
- " \n",
- " 29 | \n",
- " 278 | \n",
- " 2.317273 | \n",
- " 0.401440 | \n",
- "
\n",
- " \n",
- " 30 | \n",
- " 885 | \n",
- " 2.325397 | \n",
- " 0.087088 | \n",
- "
\n",
- " \n",
- " 31 | \n",
- " 1459 | \n",
- " 2.307971 | \n",
- " 0.857782 | \n",
- "
\n",
- " \n",
- " 32 | \n",
- " 182 | \n",
- " 2.344716 | \n",
- " 0.443465 | \n",
- "
\n",
- " \n",
- " 33 | \n",
- " 2468 | \n",
- " 2.326825 | \n",
- " 0.572978 | \n",
- "
\n",
- " \n",
- " 34 | \n",
- " 954 | \n",
- " 2.322750 | \n",
- " 0.815261 | \n",
- "
\n",
- " \n",
- " 35 | \n",
- " 1297 | \n",
- " 2.319934 | \n",
- " 0.887149 | \n",
- "
\n",
- " \n",
- " 36 | \n",
- " 1787 | \n",
- " 2.326161 | \n",
- " 0.113759 | \n",
- "
\n",
- " \n",
- " 37 | \n",
- " 462 | \n",
- " 2.325276 | \n",
- " 0.969307 | \n",
- "
\n",
- " \n",
- " 38 | \n",
- " 686 | \n",
- " 2.324262 | \n",
- " 0.305557 | \n",
- "
\n",
- " \n",
- " 39 | \n",
- " 428 | \n",
- " 2.323913 | \n",
- " 0.342770 | \n",
- "
\n",
- " \n",
- " 40 | \n",
- " 2900 | \n",
- " 2.319349 | \n",
- " 0.609161 | \n",
- "
\n",
- " \n",
- " 41 | \n",
- " 2765 | \n",
- " 2.323373 | \n",
- " 0.296973 | \n",
- "
\n",
- " \n",
- " 42 | \n",
- " 752 | \n",
- " 2.320189 | \n",
- " 0.302571 | \n",
- "
\n",
- " \n",
- " 43 | \n",
- " 502 | \n",
- " 2.325897 | \n",
- " 0.419716 | \n",
- "
\n",
- " \n",
- " 44 | \n",
- " 2162 | \n",
- " 2.320449 | \n",
- " 0.202672 | \n",
- "
\n",
- " \n",
- " 45 | \n",
- " 2460 | \n",
- " 2.317962 | \n",
- " 0.430931 | \n",
- "
\n",
- " \n",
- " 46 | \n",
- " 2537 | \n",
- " 2.312116 | \n",
- " 0.693554 | \n",
- "
\n",
- " \n",
- " 47 | \n",
- " 806 | \n",
- " 2.318746 | \n",
- " 0.539312 | \n",
- "
\n",
- " \n",
- " 48 | \n",
- " 251 | \n",
- " 2.322845 | \n",
- " 0.181191 | \n",
- "
\n",
- " \n",
- " 49 | \n",
- " 2845 | \n",
- " 2.326007 | \n",
- " 0.667570 | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " classes loss inbl_bias\n",
- "0 533 2.319248 0.208584\n",
- "1 2264 2.326346 0.022896\n",
- "2 254 2.320735 0.922445\n",
- "3 1027 2.328291 0.643444\n",
- "4 2428 2.326136 0.883795\n",
- "5 772 2.317071 0.062716\n",
- "6 2166 2.315754 0.364122\n",
- "7 812 2.321972 0.852977\n",
- "8 2632 2.321708 0.639467\n",
- "9 2576 2.313269 0.968200\n",
- "10 1326 2.321515 0.508901\n",
- "11 169 2.321526 0.870789\n",
- "12 815 2.309163 0.262275\n",
- "13 792 2.321771 0.608608\n",
- "14 739 2.316215 0.198933\n",
- "15 2460 2.318523 0.880343\n",
- "16 2420 2.320517 0.781538\n",
- "17 1688 2.316657 0.954598\n",
- "18 2929 2.323009 0.750869\n",
- "19 78 2.192723 0.089868\n",
- "20 1670 2.325608 0.188055\n",
- "21 2868 2.324525 0.058037\n",
- "22 905 2.319283 0.126673\n",
- "23 1675 2.309565 0.307769\n",
- "24 1430 2.327099 0.451335\n",
- "25 424 2.315844 0.512538\n",
- "26 1511 2.316325 0.613963\n",
- "27 959 2.330317 0.987407\n",
- "28 147 2.325459 0.931882\n",
- "29 278 2.317273 0.401440\n",
- "30 885 2.325397 0.087088\n",
- "31 1459 2.307971 0.857782\n",
- "32 182 2.344716 0.443465\n",
- "33 2468 2.326825 0.572978\n",
- "34 954 2.322750 0.815261\n",
- "35 1297 2.319934 0.887149\n",
- "36 1787 2.326161 0.113759\n",
- "37 462 2.325276 0.969307\n",
- "38 686 2.324262 0.305557\n",
- "39 428 2.323913 0.342770\n",
- "40 2900 2.319349 0.609161\n",
- "41 2765 2.323373 0.296973\n",
- "42 752 2.320189 0.302571\n",
- "43 502 2.325897 0.419716\n",
- "44 2162 2.320449 0.202672\n",
- "45 2460 2.317962 0.430931\n",
- "46 2537 2.312116 0.693554\n",
- "47 806 2.318746 0.539312\n",
- "48 251 2.322845 0.181191\n",
- "49 2845 2.326007 0.667570"
- ]
- },
- "execution_count": 284,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "df_adjusted"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 287,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- ""
- ]
- },
- "execution_count": 287,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "\n",
- "text/plain": [
- "