diff --git a/.gitignore b/.gitignore index fa5b41b..ba07277 100644 --- a/.gitignore +++ b/.gitignore @@ -19,7 +19,7 @@ dist *.egg-info/ data/train.csv -data/valid.csv +data/test.csv models/* !models/README.md diff --git a/Dockerfile b/Dockerfile index 6092483..9e0a729 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,7 +5,7 @@ WORKDIR /workdir COPY config.yaml ./ COPY hyperparams.py ./ -COPY data/train.csv data/valid.csv data/ +COPY data/train.csv data/test.csv data/ RUN pip install --upgrade pip && \ pip install --no-cache-dir text-classification-baseline diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..0671823 --- /dev/null +++ b/Makefile @@ -0,0 +1,20 @@ +all: + python -m text_clf --path_to_config config.yaml +load_data: + python data/load_20newsgroups.py +coverage: + coverage run -m unittest discover && coverage report -m +docker_build: + docker image build -t text-classification-baseline . +docker_run: + docker container run -it text-classification-baseline +pypi_packages: + pip install --upgrade build twine +pypi_build: + python -m build +pypi_twine: + python -m twine upload --repository testpypi dist/* +pypi_clean: + rm -rf dist text_classification_baseline.egg-info +clean: + rm -rf models/model* diff --git a/README.md b/README.md index 9466719..f619e94 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ text-clf-train --path_to_config config.yaml ```python3 import text_clf -text_clf.train(path_to_config="config.yaml") +model, target_names_mapping = text_clf.train(path_to_config="config.yaml") ``` **NOTE**: more about config file [here](https://github.com/dayyass/text-classification-baseline/tree/main#config). @@ -55,7 +55,7 @@ text-clf-train --path_to_config config.yaml ```python3 import text_clf -text_clf.train(path_to_config="config.yaml") +model, target_names_mapping = text_clf.train(path_to_config="config.yaml") ``` Default **config.yaml**: @@ -66,7 +66,7 @@ path_to_save_folder: models # data data: train_data_path: data/train.csv - valid_data_path: data/valid.csv + test_data_path: data/test.csv sep: ',' text_column: text target_column: target_name_short diff --git a/config.yaml b/config.yaml index f657db3..56c5fc5 100644 --- a/config.yaml +++ b/config.yaml @@ -4,7 +4,7 @@ path_to_save_folder: models # data data: train_data_path: data/train.csv - valid_data_path: data/valid.csv + test_data_path: data/test.csv sep: ',' text_column: text target_column: target_name_short diff --git a/data/README.md b/data/README.md index e31cc0a..a60c80c 100644 --- a/data/README.md +++ b/data/README.md @@ -4,5 +4,5 @@ Folder for storing datasets. To download [**the 20 newsgroups text dataset**](https://scikit-learn.org/stable/datasets/real_world.html#newsgroups-dataset) run the following command: ``` -python fetch_20newsgroups.py +python load_20newsgroups.py ``` diff --git a/data/load_20newsgroups.py b/data/load_20newsgroups.py index 48dd016..36d2d63 100644 --- a/data/load_20newsgroups.py +++ b/data/load_20newsgroups.py @@ -6,12 +6,13 @@ def make_df_from_bunch(bunch: Bunch) -> pd.DataFrame: - """ - Make pd.DataFrame from 20newsgroups bunch. + """Make pd.DataFrame from 20newsgroups bunch. + + Args: + bunch (Bunch): 20newsgroups bunch. - :param Bunch bunch: 20newsgroups bunch. - :return: 20newsgroups DataFrame. - :rtype: pd.DataFrame + Returns: + pd.DataFrame: 20newsgroups DataFrame. """ df = pd.DataFrame( @@ -27,20 +28,18 @@ def make_df_from_bunch(bunch: Bunch) -> pd.DataFrame: def load_20newsgroups() -> None: - """ - Load 20newsgroups dataset. - """ + """Load 20newsgroups dataset.""" train_bunch = fetch_20newsgroups(subset="train") test_bunch = fetch_20newsgroups(subset="test") df_train = make_df_from_bunch(train_bunch) - df_valid = make_df_from_bunch(test_bunch) + df_test = make_df_from_bunch(test_bunch) os.makedirs("data", exist_ok=True) df_train.to_csv("data/train.csv", index=False) - df_valid.to_csv("data/valid.csv", index=False) + df_test.to_csv("data/test.csv", index=False) if __name__ == "__main__": diff --git a/setup.cfg b/setup.cfg index 13eec86..38bc129 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = text-classification-baseline -version = 0.1.3 +version = 0.1.4 author = Dani El-Ayyass author_email = dayyass@yandex.ru description = TF-IDF + LogReg baseline for text classification diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 1e4e58e..739e0a5 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -9,15 +9,11 @@ class TestUsage(unittest.TestCase): - """ - Class for testing pipeline. - """ + """Class for testing pipeline.""" @classmethod def setUpClass(cls) -> None: - """ - SetUp tests with config and data. - """ + """SetUp tests with config and data.""" path_to_config = "config.yaml" @@ -28,16 +24,12 @@ def setUpClass(cls) -> None: load_20newsgroups() def test_train(self) -> None: - """ - Testing train function. - """ + """Testing train function.""" train(path_to_config="config.yaml") def test_train_grid_search(self) -> None: - """ - Testing train function with grid_search. - """ + """Testing train function with grid_search.""" with open("config.yaml", mode="r") as fp: config = yaml.safe_load(fp) @@ -56,9 +48,7 @@ def test_train_grid_search(self) -> None: @classmethod def tearDownClass(cls) -> None: - """ - TearDown after tests. - """ + """TearDown after tests.""" os.remove("config_grid_search.yaml") diff --git a/text_clf/__init__.py b/text_clf/__init__.py index 8250572..8c13dbb 100644 --- a/text_clf/__init__.py +++ b/text_clf/__init__.py @@ -1,4 +1,4 @@ from .__main__ import train -__version__ = "0.1.3" +__version__ = "0.1.4" __all__ = ["train"] diff --git a/text_clf/__main__.py b/text_clf/__main__.py index 05b761b..3a10f27 100644 --- a/text_clf/__main__.py +++ b/text_clf/__main__.py @@ -1,15 +1,22 @@ import traceback +from typing import Dict, Tuple + +from sklearn.pipeline import Pipeline from .config import get_config from .train import _train from .utils import close_logger, get_argparse, get_logger -def train(path_to_config: str) -> None: - """ - Function to train baseline model with exception handler. +def train(path_to_config: str) -> Tuple[Pipeline, Dict[int, str]]: + """Function to train baseline model with exception handler. - :param str path_to_config: path to config. + Args: + path_to_config (str): Path to config. + + Returns: + Tuple[Pipeline, Dict[int, str]]: + Model pipeline (tf-idf + logreg) and target names mapping. Both None if any exception occurred. """ # load config @@ -19,22 +26,26 @@ def train(path_to_config: str) -> None: logger = get_logger(path_to_logfile=config["path_to_save_logfile"]) try: - _train( + pipe, target_names_mapping = _train( config=config, logger=logger, ) + except: # noqa close_logger(logger) print(traceback.format_exc()) + pipe, target_names_mapping = None, None # type: ignore + + return pipe, target_names_mapping + def main() -> int: - """ - Main function to train baseline model. + """Main function to train baseline model. - :return: exit code. - :rtype: int + Returns: + int: Exit code. """ # argument parser @@ -42,7 +53,7 @@ def main() -> int: args = parser.parse_args() # train - train(path_to_config=args.path_to_config) + _ = train(path_to_config=args.path_to_config) return 0 diff --git a/text_clf/config.py b/text_clf/config.py index f17f54c..c513651 100644 --- a/text_clf/config.py +++ b/text_clf/config.py @@ -10,12 +10,13 @@ def get_config(path_to_config: str) -> Dict[str, Any]: - """ - Get config. + """Get config. + + Args: + path_to_config (str): Path to config. - :param str path_to_config: path to config. - :return: config. - :rtype: Dict[str, Any] + Returns: + Dict[str, Any]: Config. """ with open(path_to_config, mode="r") as fp: @@ -55,11 +56,14 @@ def load_default_config( path_to_save_folder: str = ".", filename: str = "config.yaml", ) -> None: - """ - Function to load default config. + """Function to load default config. + + Args: + path_to_save_folder (str, optional): Path to save folder. Defaults to ".". + filename (str, optional): Filename. Defaults to "config.yaml". - :param str path_to_save_folder: path to save folder (default: '.'). - :param str filename: filename (default: 'config.yaml'). + Raises: + FileExistsError: Raise error if config file already exists. """ # get logger @@ -84,7 +88,7 @@ def load_default_config( "# data", "data:", " train_data_path: data/train.csv", - " valid_data_path: data/valid.csv", + " test_data_path: data/test.csv", " sep: ','", " text_column: text", " target_column: target_name_short", diff --git a/text_clf/data.py b/text_clf/data.py index f4e7029..c6bd31a 100644 --- a/text_clf/data.py +++ b/text_clf/data.py @@ -6,12 +6,13 @@ def load_data( config: Dict[str, Any] ) -> Tuple[pd.Series, pd.Series, pd.Series, pd.Series]: - """ - Load data. + """Load data. + + Args: + config (Dict[str, Any]): Config. - :param Dict[str, Any] config: config. - :return: X_train, X_valid, y_train, y_valid. - :rtype: Tuple[pd.Series, pd.Series, pd.Series, pd.Series] + Returns: + Tuple[pd.Series, pd.Series, pd.Series, pd.Series]: X_train, X_test, y_train, y_test. """ text_column = config["data"]["text_column"] @@ -26,15 +27,15 @@ def load_data( usecols=usecols, ) - df_valid = pd.read_csv( - config["data"]["valid_data_path"], + df_test = pd.read_csv( + config["data"]["test_data_path"], sep=sep, usecols=usecols, ) X_train = df_train[text_column] - X_valid = df_valid[text_column] + X_test = df_test[text_column] y_train = df_train[target_column] - y_valid = df_valid[target_column] + y_test = df_test[target_column] - return X_train, X_valid, y_train, y_valid + return X_train, X_test, y_train, y_test diff --git a/text_clf/save.py b/text_clf/save.py index 9d8288c..3e65045 100644 --- a/text_clf/save.py +++ b/text_clf/save.py @@ -11,16 +11,16 @@ def save_model( target_names_mapping: Dict[int, str], config: Dict[str, Any], ) -> None: - """ - Save: - - model pipeline (tf-idf + model) + """Save: + - model pipeline (tf-idf + logreg) - target names mapping - config - hyper-parameters grid (from config) - :param Pipeline pipe: model pipeline (tf-idf + model). - :param Dict[int, str] target_names_mapping: name for each class. - :param Dict[str, Any] config: config. + Args: + pipe (Pipeline): Model pipeline (tf-idf + logreg). + target_names_mapping (Dict[int, str]): Name for each class. + config (Dict[str, Any]): Config. """ # save pipe diff --git a/text_clf/train.py b/text_clf/train.py index ec57220..84c47d4 100644 --- a/text_clf/train.py +++ b/text_clf/train.py @@ -1,9 +1,10 @@ import logging -from typing import Any, Dict +from typing import Any, Dict, Tuple +import pandas as pd from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.linear_model import LogisticRegression -from sklearn.metrics import classification_report +from sklearn.metrics import classification_report, confusion_matrix from sklearn.model_selection import GridSearchCV from sklearn.pipeline import Pipeline from sklearn.preprocessing import LabelEncoder @@ -16,12 +17,15 @@ def _train( config: Dict[str, Any], logger: logging.Logger, -) -> None: - """ - Function to train baseline model. +) -> Tuple[Pipeline, Dict[int, str]]: + """Function to train baseline model. + + Args: + config (Dict[str, Any]): Config. + logger (logging.Logger): Logger. - :param Dict[str, Any] config: config. - :param logging.Logger logger: logger. + Returns: + Tuple[Pipeline, Dict[int, str]]: Model pipeline (tf-idf + logreg) and target names mapping. """ # log config @@ -34,15 +38,15 @@ def _train( # load data logger.info("Loading data...") - X_train, X_valid, y_train, y_valid = load_data(config) + X_train, X_test, y_train, y_test = load_data(config) logger.info(f"Train dataset size: {X_train.shape[0]}") - logger.info(f"Valid dataset size: {X_valid.shape[0]}") + logger.info(f"Test dataset size: {X_test.shape[0]}") # label encoder le = LabelEncoder() y_train = le.fit_transform(y_train) - y_valid = le.transform(y_valid) + y_test = le.transform(y_test) target_names = [str(cls) for cls in le.classes_.tolist()] target_names_mapping = {i: cls for i, cls in enumerate(target_names)} @@ -97,17 +101,35 @@ def _train( y_pred=y_pred_train, target_names=target_names, ) + conf_matrix_train = pd.DataFrame( + confusion_matrix( + y_true=y_train, + y_pred=y_pred_train, + ), + columns=target_names, + index=target_names, + ) logger.info(f"Train classification report:\n\n{classification_report_train}") + logger.info(f"Train confusion matrix:\n\n{conf_matrix_train}\n") - y_pred_valid = pipe.predict(X_valid) - classification_report_valid = classification_report( - y_true=y_valid, - y_pred=y_pred_valid, + y_pred_test = pipe.predict(X_test) + classification_report_test = classification_report( + y_true=y_test, + y_pred=y_pred_test, target_names=target_names, ) + conf_matrix_test = pd.DataFrame( + confusion_matrix( + y_true=y_test, + y_pred=y_pred_test, + ), + columns=target_names, + index=target_names, + ) - logger.info(f"Valid classification report:\n\n{classification_report_valid}") + logger.info(f"Test classification report:\n\n{classification_report_test}") + logger.info(f"Test confusion matrix:\n\n{conf_matrix_test}\n") # save model logger.info("Saving the model...") @@ -121,3 +143,5 @@ def _train( logger.info("Done!") close_logger(logger) + + return pipe, target_names_mapping diff --git a/text_clf/utils.py b/text_clf/utils.py index f4f77e1..0280752 100644 --- a/text_clf/utils.py +++ b/text_clf/utils.py @@ -9,11 +9,10 @@ def get_argparse() -> ArgumentParser: - """ - Get argument parser. + """Get argument parser. - :return: argument parser. - :rtype: ArgumentParser + Returns: + ArgumentParser: Argument parser. """ parser = ArgumentParser(prog="text-clf-train") @@ -28,12 +27,13 @@ def get_argparse() -> ArgumentParser: def get_logger(path_to_logfile: str) -> logging.Logger: - """ - Get logger. + """Get logger. + + Args: + path_to_logfile (str): Path to logfile. - :param str path_to_logfile: path to logfile. - :return: logger. - :rtype: logging.Logger + Returns: + logging.Logger: Logger. """ logger = logging.getLogger("text-clf-train") @@ -61,11 +61,11 @@ def get_logger(path_to_logfile: str) -> logging.Logger: def close_logger(logger: logging.Logger) -> None: - """ - Close logger. + """Close logger. Source: https://stackoverflow.com/questions/15435652/python-does-not-release-filehandles-to-logfile - :param logging.Logger logger: logger. + Args: + logger (logging.Logger): Logger. """ for handler in logger.handlers[:]: @@ -74,10 +74,10 @@ def close_logger(logger: logging.Logger) -> None: def set_seed(seed: int) -> None: - """ - Set seed for reproducibility. + """Set seed for reproducibility. - :param int seed: seed. + Args: + seed (int): Seed. """ random.seed(seed) @@ -85,12 +85,13 @@ def set_seed(seed: int) -> None: def get_grid_search_params(grid_search_params_path: str) -> Dict[str, Any]: - """ - Get grid_search_params from python file. + """Get grid_search_params from python file. + + Args: + grid_search_params_path (str): Python file with grid_search_params. - :param str grid_search_params_path: python file with grid_search_params. - :return: grid_search_params. - :rtype: Dict[str, Any] + Returns: + Dict[str, Any]: grid_search_params. """ spec = importlib.util.spec_from_file_location( # type: ignore @@ -106,12 +107,13 @@ def get_grid_search_params(grid_search_params_path: str) -> Dict[str, Any]: def prepare_dict_to_print(dict: Dict[str, Any]) -> str: - """ - Helper function to create pretty string to print dictionary. + """Helper function to create pretty string to print dictionary. + + Args: + dict (Dict[str, Any]): Arbitrary dictionary. - :param Dict[str, Any] dict: arbitrary dictionary. - :return: pretty string to print dictionary. - :rtype: str + Returns: + str: Pretty string to print dictionary. """ sorted_items = sorted(dict.items(), key=lambda x: x[0])