Skip to content

Commit

Permalink
Merge pull request #11 from ajatkj/core_upgrade_for_better_dataclass_…
Browse files Browse the repository at this point in the history
…support

upgrade for better dataclass support
  • Loading branch information
ajatkj authored Mar 8, 2024
2 parents f53e12b + 138672e commit 649ea92
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 49 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ It leverages Python's type hints and dataclasses to provide a convenient way of
✓ Fully typed.<br />
✓ Use dataclasses to parse the configuration file.<br />
✓ Support for almost all python built-in data types - `int`, `float`, `str`, `list`, `tuple`, `dict` and complex data types using `Union` and `Optional`.<br />
✓ Supports almost all features of dataclasses, including the field-level `init` flag, the `__post_init__` method, `InitVar`s and more.<br />
✓ Built on top of `configparser`, hence retains all functionalities of `configparser`.<br />
✓ Support for optional values (optional values are automatically set to `None` if not provided).<br />
✓ Smarter defaults (see below).
Expand Down
68 changes: 67 additions & 1 deletion tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,16 @@ class NotADataclass:
with self.assertRaisesRegex(ParseError, f"ParseError in section '{_SECTION_}'"):
self.config_parser.parse_section(NotADataclass, _SECTION_) # type: ignore

# Verifies that parse_section rejects a dataclass declared with init=False.
def test_parse_section_dataclass_init_false(self) -> None:
@dataclasses.dataclass(init=False)
class TestDataclass:
option1: int

# The section is created but left empty; no option is set before parsing.
self.config_parser.add_section(_SECTION_)

# The expected failure is a TypeError naming the offending dataclass.
with self.assertRaisesRegex(TypeError, "init flag must be True for dataclass 'TestDataclass'"):
self.config_parser.parse_section(TestDataclass, _SECTION_)

def test_parse_section_extra_fields_allow(self) -> None:
@dataclasses.dataclass
class TestDataclass:
Expand Down Expand Up @@ -160,11 +170,13 @@ class TestDataclass:
option1: typing.Optional[str]
option2: str
option3: typing.Optional[int] = 0
option4: typing.Optional[float] = 10.2

self.config_parser.set("DEFAULT", "option1", "default_value1")
self.config_parser.set("DEFAULT", "option2", "default_value2")
self.config_parser.add_section(_SECTION_)
self.config_parser.set(_SECTION_, "option1", "test_value1")
self.config_parser.set(_SECTION_, "option4", "11.2")

result = self.config_parser.parse_section(TestDataclass, _SECTION_)
self.assertIsInstance(result, TestDataclass)
Expand All @@ -174,6 +186,8 @@ class TestDataclass:
self.assertIsInstance(result.option2, str)
self.assertEqual(result.option3, 0)
self.assertIsInstance(result.option3, int)
self.assertEqual(result.option4, 11.2)
self.assertIsInstance(result.option4, float)

def test_parse_section_arbitrary_types(self) -> None:
@dataclasses.dataclass
Expand Down Expand Up @@ -285,7 +299,7 @@ class TestDataclass:

with self.assertRaisesRegex(
ParseError,
f"ParseError in section '{_SECTION_}' for option 'option1': Cannot cast value 'foo' to 'union type'",
f"ParseError in section '{_SECTION_}' for option 'option1': Cannot cast value 'foo' to '\(int|float\)' type",
):
self.config_parser.parse_section(TestDataclass, _SECTION_)

Expand Down Expand Up @@ -358,6 +372,58 @@ class TestDataclass:
self.assertEqual(result.option1, ["foo", "bar", "baz"])
self.assertIsInstance(result.option1, typing.List)

# Verifies that a field declared with field(init=False) does not have to be
# present in the configuration section for parsing to succeed.
def test_parse_section_field_level_init_flag(self) -> None:
@dataclasses.dataclass
class TestDataclass:
option1: int
option2: float = dataclasses.field(init=False)

# Only option1 is supplied in the section.
self.config_parser.add_section(_SECTION_)
self.config_parser.set(_SECTION_, "option1", "10")
result = self.config_parser.parse_section(TestDataclass, _SECTION_)

self.assertIsInstance(result, TestDataclass)
self.assertEqual(result.option1, 10)
# option2 has no default and this dataclass defines no __post_init__,
# so the parsed instance is expected to lack the attribute entirely.
self.assertFalse(hasattr(result, "option2"))

# Verifies that the dataclass's __post_init__ runs when a section is parsed,
# by checking a field that only __post_init__ assigns.
def test_parse_section_post_init_method(self) -> None:
@dataclasses.dataclass
class TestDataclass:
option1: int
option2: float = dataclasses.field(init=False)

def __post_init__(self) -> None:
# option2 is derived from option1 rather than read from the config.
self.option2 = self.option1 + 20.2

# Only option1 is supplied in the section.
self.config_parser.add_section(_SECTION_)
self.config_parser.set(_SECTION_, "option1", "10")
result = self.config_parser.parse_section(TestDataclass, _SECTION_)

self.assertIsInstance(result, TestDataclass)
self.assertIsInstance(result.option1, int)
self.assertEqual(result.option1, 10)
self.assertIsInstance(result.option2, float)
# 10 (parsed option1) + 20.2 from __post_init__.
self.assertEqual(result.option2, 30.2)

# Verifies that values passed through parse_section's init_vars argument
# reach the dataclass's __post_init__ as InitVar parameters.
def test_parse_section_initvars(self) -> None:
@dataclasses.dataclass
class TestDataclass:
option1: int
option2: dataclasses.InitVar[typing.Optional[str]] = None
option3: typing.Optional[str] = None
option4: dataclasses.InitVar[typing.Optional[str]] = None

def __post_init__(self, option2: typing.Optional[str], option4: typing.Optional[str]) -> None:
# Copy the InitVar into a regular field so the test can observe it.
self.option3 = option2

self.config_parser.add_section(_SECTION_)
self.config_parser.set(_SECTION_, "option1", "10")
# Only option2 is supplied via init_vars; option4 is deliberately omitted.
result = self.config_parser.parse_section(TestDataclass, _SECTION_, init_vars={"option2": "foo"})

self.assertIsInstance(result, TestDataclass)
self.assertEqual(result.option1, 10)
self.assertEqual(result.option3, "foo")


def start_test() -> None:
unittest.main()
Expand Down
157 changes: 109 additions & 48 deletions typed_configparser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,18 @@
import types
import typing

import typing_extensions

from typed_configparser.exceptions import ParseError

if typing.TYPE_CHECKING:
from _typeshed import DataclassInstance

T = typing.TypeVar("T", bound="DataclassInstance")

# This is a hack to make sure get_type_hints works correctly for InitVar
# when `from __future__ import annotations` is turned on
dataclasses.InitVar.__call__ = lambda *args: None # type: ignore[method-assign]

BOOLEAN_STATES = {
"1": True,
Expand Down Expand Up @@ -42,10 +47,6 @@
NONE_TYPE = (type(None),)


def _CUSTOM_INIT_METHOD(self: T) -> None:
pass


def _CUSTOM_REPR_METHOD(self: T) -> str:
    """Render the instance as ``ClassName(field=value, ...)``.

    Every entry of ``__dataclass_fields__`` is included, in mapping order,
    with the value read from the instance via ``getattr``.
    """
    rendered = []
    for field in self.__dataclass_fields__.values():
        rendered.append(f"{field.name}={getattr(self, field.name)}")
    joined = ", ".join(rendered)
    return f"{self.__class__.__name__}({joined})"
Expand All @@ -67,8 +68,8 @@ def get_types(typ: typing.Type[typing.Any]) -> typing.Any:
Any: A dictionary containing information about the type, including its origin and arguments.
"""
origin = typing.get_origin(typ)
args = typing.get_args(typ)
origin = typing_extensions.get_origin(typ)
args = typing_extensions.get_args(typ)
if origin is None:
return typ
result = {_ORIGIN_KEY_: origin, _ARGS_KEY_: []}
Expand Down Expand Up @@ -121,6 +122,10 @@ def cast_str(section: str, option: str, value: str) -> typing.Any:
raise ParseError(f"Cannot cast value '{value}' to 'str'", section, option=option)


def get_name(args: typing.List[type]) -> str:
    """Build a ``|``-separated label from the given type arguments.

    Each argument contributes its ``__name__``; arguments without one
    contribute their ``repr`` instead.
    """
    labels = (getattr(candidate, "__name__", repr(candidate)) for candidate in args)
    return "|".join(labels)


def cast_value_wrapper(section: str, option: str, value: str, target_type: typing.Any) -> typing.Any:
def cast_value(value: str, target_type: typing.Any) -> typing.Any:
"""
Expand Down Expand Up @@ -150,7 +155,12 @@ def cast_value(value: str, target_type: typing.Any) -> typing.Any:
return cast_value(value, arg)
except Exception:
continue
raise ParseError(f"Cannot cast value '{value}' to 'union type'", section, option=option)

raise ParseError(
f"Cannot cast value '{value}' to '({get_name(args)})' type",
section,
option=option,
)
elif origin in LIST_TYPE:
if is_list(value):
values = re.split(_REGEX_, strip(value, "[", "]"))
Expand Down Expand Up @@ -183,7 +193,7 @@ def cast_value(value: str, target_type: typing.Any) -> typing.Any:
return cast_str(section, option, value)
elif target_type == bool:
return cast_bool(section, option, value)
elif target_type is None:
elif target_type in NONE_TYPE:
return cast_none(section, option, value)
else:
return cast_any(section, option, value, target_type)
Expand All @@ -200,6 +210,11 @@ def is_field_optional(typ: typing.Type[T]) -> bool:
return typs in NONE_TYPE


def is_field_default(field: typing.Any) -> bool:
    """Tell whether a dataclass field does not need a value from the caller.

    A field counts as "defaulted" when any of the following holds:

    * it declares a ``default``;
    * it declares a ``default_factory``;
    * it is excluded from ``__init__`` (``init=False``).

    Args:
        field: The ``dataclasses.Field`` to inspect.

    Returns:
        True if the field is defaulted as described above, otherwise False.

    Raises:
        AssertionError: If ``field`` is not a ``dataclasses.Field``.
    """
    assert isinstance(field, dataclasses.Field)
    # MISSING is a sentinel object, so it must be compared by identity.
    # Using ``!=`` would dispatch to the default value's own ``__ne__``,
    # which may be overridden (or return a non-bool, e.g. array-like
    # defaults) and misreport whether a default exists.
    return (
        field.default is not dataclasses.MISSING
        or field.default_factory is not dataclasses.MISSING
        or field.init is False
    )


def is_list(value: str) -> bool:
"""Check whether string value qualifies as a list"""
if value.startswith("[") and value.endswith("]"):
Expand Down Expand Up @@ -227,10 +242,12 @@ def strip(value: str, first: str, last: str) -> str:
return value # pragma: no cover


def generate_field(key: str) -> typing.Any:
def generate_field(key: str, default: typing.Optional[str] = None) -> typing.Any:
    """Create a bare dataclass field carrying a name and, optionally, a default.

    Args:
        key: Name to assign to the new field.
        default: Value to store as the field's default. ``None`` means
            "no default"; the field then keeps ``dataclasses.MISSING``.

    Returns:
        A new ``dataclasses.Field`` with ``name`` set to ``key``.
    """
    f = dataclasses.field()
    f.name = key
    # Test against None rather than truthiness: an empty-string value is a
    # legitimate default and must not be silently replaced by MISSING.
    if default is not None:
        f.default = default
    return f


Expand Down Expand Up @@ -274,13 +291,16 @@ def _get_type(self, section: str, option: str) -> typing.Any:
"""
config_class = self.__config_class_mapper__.get(section)
try:
typ = config_class.__annotations__[option]
return lambda val: cast_value_wrapper(section, option, val, get_types(typ))
except KeyError:
return str
except Exception: # pragma: no cover
raise
if config_class:
try:
typ = typing_extensions.get_type_hints(config_class)[option]
return lambda val: cast_value_wrapper(section, option, val, get_types(typ))
except KeyError:
return str
except Exception: # pragma: no cover
raise
else: # pragma: no cover
raise TypeError("Config class not found")

def _getitem(self, section: str, option: str) -> typing.Any:
"""
Expand All @@ -304,6 +324,7 @@ def parse_section(
using_dataclass: typing.Type[T],
section_name: typing.Union[str, None] = None,
extra: typing.Literal["allow", "ignore", "error"] = "allow",
init_vars: typing.Dict[str, typing.Any] = {},
) -> T:
"""
Parse a configuration section into a dataclass instance.
Expand All @@ -314,7 +335,9 @@ def parse_section(
If None, the name is derived from the dataclass name. Defaults to None.
extra (Literal["allow", "ignore", "error"], optional): How to handle extra fields
not present in the dataclass. "allow" allows extra fields, "ignore" ignores them,
and "error" raises an ExtraFieldsError. Defaults to "allow".
and "error" raises a ParseError. Defaults to "allow".
init_vars (Dict[str, Any]): For any InitVars on the dataclass, pass their values here as a dict;
they are sent to the dataclass's __init__ method and eventually to its __post_init__ method.
Returns:
T: An instance of the specified dataclass populated with values from the configuration section.
Expand All @@ -334,44 +357,82 @@ def parse_section(
if not is_dataclass(using_dataclass):
raise ParseError(f"{using_dataclass.__name__} is not a valid dataclass", section_name_)

self.__config_class_mapper__[section_name_] = using_dataclass

# Irrespective of whether init flag is set to False or True for the dataclass,
# we set it to an empty method to avoid any error
using_dataclass.__init__ = _CUSTOM_INIT_METHOD # type: ignore[assignment]
if params := getattr(using_dataclass, "__dataclass_params__", None):
if (init := getattr(params, "init", None)) is not None:
if init is False:
raise TypeError(f"init flag must be True for dataclass '{using_dataclass.__name__}'")

section = using_dataclass()
dataclass_fields = {f.name for f in dataclasses.fields(section)}
self.__config_class_mapper__[section_name_] = using_dataclass
# These are just "fields" and don't contain ClassVar or InitVar fields
dataclass_fields = {item.name: item for item in dataclasses.fields(using_dataclass)}
initvar_fields = {
item.name: item
for item in using_dataclass.__dataclass_fields__.values()
if item._field_type is dataclasses._FIELD_INITVAR # type: ignore [attr-defined]
}
options = []

# Extract all (typed) options from a section and assign to the dataclass instance
# Adding all keys to args initially to maintain the order of positional arguments
# to be sent to dataclass init method. It is not required for keyword arguments
args = {k: v for k, v in dataclass_fields.items() if not is_field_default(v)}
kwargs = {}
extra_fields = {}
seen = set()

# Iterate through config section to update args & kwargs
# for fields present in dataclass. Anything not found in
# dataclass is added to extra_fields
for key, _ in self.items(section_name_):
value = self._getitem(section_name_, key)
options.append(key)
if (key not in dataclass_fields and extra == "allow") or key in dataclass_fields:
setattr(section, key, value)
if key in dataclass_fields:
field_info = dataclass_fields[key]
if is_field_default(field_info):
kwargs[key] = value
else:
args[key] = value
seen.add(key)
else:
extra_fields[key] = generate_field(key, default=value)

# Now iterate through dataclass fields and update default value of
# any "Optional" fields to None.
# Any non-"Optional" fields present in dataclass but not found in
# config options are missing fields and should raise error
missing_fields = []
for field, field_info in dataclass_fields.items():
field_type = typing_extensions.get_type_hints(using_dataclass)[field]
if not is_field_default(field_info) and field not in seen:
if is_field_optional(field_type):
args[field] = None # type: ignore
elif field not in options:
missing_fields.append(field)

# Supply initvars as kwargs to the dataclass call
for field, field_info in initvar_fields.items():
if field in init_vars:
kwargs[field] = init_vars[field]
else:
kwargs[field] = None

if len(missing_fields) > 0:
raise ParseError(
"Unable to find value in section, default section or dataclass defaults",
section_name_,
", ".join(missing_fields),
)

extra_fields = {key: generate_field(key) for key in options if key not in dataclass_fields}
if len(extra_fields) > 0 and extra == "error":
raise ParseError("Extra fields are not allowed in configuration.", section_name_)

# Set optional fields to None which are not present in config options,
# and neither have a default set in the dataclass
for field_, field_info in using_dataclass.__dataclass_fields__.items():
if field_ not in options and field_ not in using_dataclass.__dict__:
if field_info.default_factory != dataclasses.MISSING:
setattr(section, field_, field_info.default_factory())
elif is_field_optional(field_info.type):
setattr(section, field_, None)
else:
raise ParseError(
"Unable to find value in section, default section or dataclass default",
section_name_,
option=field_,
)

using_dataclass.__dataclass_fields__.update(extra_fields)
setattr(using_dataclass, "__dataclass_extra_fields__", extra_fields)
using_dataclass.__repr__ = _CUSTOM_REPR_METHOD # type: ignore[assignment]
using_dataclass.__str__ = _CUSTOM_STR_METHOD # type: ignore[assignment]
section = using_dataclass(*args.values(), **kwargs)

if extra_fields and extra == "allow":
for k, f in extra_fields.items():
setattr(section, k, f.default)
setattr(using_dataclass, "__dataclass_extra_fields__", extra_fields)
using_dataclass.__dataclass_fields__.update(extra_fields)
# Since __repr__ and __str__ are created when dataclass is created using @dataclass
# decorator, we need to rewrite our own methods for extra fields
using_dataclass.__repr__ = _CUSTOM_REPR_METHOD # type: ignore[assignment]
using_dataclass.__str__ = _CUSTOM_STR_METHOD # type: ignore[assignment]
return section

0 comments on commit 649ea92

Please sign in to comment.