Skip to content

Commit

Permalink
Merge pull request #56 from dayyass/develop
Browse files Browse the repository at this point in the history
release v0.1.2
  • Loading branch information
dayyass authored Aug 19, 2021
2 parents 3b90e19 + 72c761c commit f8eb4cf
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 20 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = text-classification-baseline
version = 0.1.1
version = 0.1.2
author = Dani El-Ayyass
author_email = [email protected]
description = TF-IDF + LogReg baseline for text classification
Expand Down
2 changes: 1 addition & 1 deletion tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import unittest

from data.load_20newsgroups import load_20newsgroups
from text_clf.__main__ import train
from text_clf.config import load_default_config
from text_clf.train import train


class TestUsage(unittest.TestCase):
Expand Down
4 changes: 2 additions & 2 deletions text_clf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .train import train
from .__main__ import train

__version__ = "0.1.1"
__version__ = "0.1.2"
__all__ = ["train"]
31 changes: 29 additions & 2 deletions text_clf/__main__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,32 @@
from .train import train
from .utils import get_argparse
import traceback

from .config import get_config
from .train import _train
from .utils import close_logger, get_argparse, get_logger


def train(path_to_config: str = "config.yaml") -> None:
"""
Function to train baseline model with exception handler.
:param str path_to_config: path to config.
"""

# load config
config = get_config(path_to_config=path_to_config)

# get logger
logger = get_logger(path_to_logfile=config["path_to_save_logfile"])

try:
_train(
config=config,
logger=logger,
)
except: # noqa
close_logger(logger)

print(traceback.format_exc())


def main() -> int:
Expand Down
3 changes: 3 additions & 0 deletions text_clf/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ def get_config(path_to_config: str) -> Dict[str, Any]:
/ f"model_{datetime.datetime.now().strftime('%Y-%m-%d %H-%M-%S')}"
)

# mkdir if not exists
config["path_to_save_folder"].absolute().mkdir(parents=True, exist_ok=True)

config["path_to_config"] = path_to_config
config["path_to_save_model"] = config["path_to_save_folder"] / "model.joblib"
config["path_to_save_logfile"] = config["path_to_save_folder"] / "logging.txt"
Expand Down
26 changes: 12 additions & 14 deletions text_clf/train.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,31 @@
import logging
import time
from typing import Any, Dict

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelEncoder

from .config import get_config
from .data import load_data
from .save import save_model
from .utils import get_logger, set_seed
from .utils import close_logger, set_seed


def train(path_to_config: str = "config.yaml") -> None:
def _train(
config: Dict[str, Any],
logger: logging.Logger,
) -> None:
"""
Function to train baseline model.
:param str path_to_config: path to config.
:param Dict[str, Any] config: config.
:param logging.Logger logger: logger.
"""

# load config
config = get_config(path_to_config)

# mkdir if not exists
config["path_to_save_folder"].absolute().mkdir(parents=True, exist_ok=True)

# get logger
logger = get_logger(config["path_to_save_logfile"])

# log config
with open(path_to_config, mode="r") as fp:
with open(config["path_to_config"], mode="r") as fp:
logger.info(f"Config:\n\n{fp.read()}")

# reproducibility
Expand Down Expand Up @@ -109,3 +105,5 @@ def train(path_to_config: str = "config.yaml") -> None:
)

logger.info("Done!")

close_logger(logger)
13 changes: 13 additions & 0 deletions text_clf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,19 @@ def get_logger(path_to_logfile: str) -> logging.Logger:
return logger


def close_logger(logger: logging.Logger) -> None:
"""
Close logger.
Source: https://stackoverflow.com/questions/15435652/python-does-not-release-filehandles-to-logfile
:param logging.Logger logger: logger.
"""

for handler in logger.handlers[:]:
handler.close()
logger.removeHandler(handler)


def set_seed(seed: int) -> None:
"""
Set seed for reproducibility.
Expand Down

0 comments on commit f8eb4cf

Please sign in to comment.