diff --git a/model_utils/tracker.py b/model_utils/tracker.py index 0a5600c1..eab1d4af 100644 --- a/model_utils/tracker.py +++ b/model_utils/tracker.py @@ -2,7 +2,7 @@ from copy import deepcopy from functools import wraps -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, TypeVar, cast, overload from django.core.exceptions import FieldError from django.db import models @@ -10,10 +10,17 @@ from django.db.models.query_utils import DeferredAttribute if TYPE_CHECKING: + from collections.abc import Callable, Iterable, Mapping + from types import TracebackType + class _AugmentedModel(models.Model): + _instance_initialized: bool _deferred_fields: set[str] +T = TypeVar("T") + + class LightStateFieldFile(FieldFile): """ FieldFile subclass with the only aim to remove the instance from the state. @@ -24,32 +31,34 @@ class LightStateFieldFile(FieldFile): Django 3.1+ can make the app unusable, as CPU and memory usage gets easily multiplied by magnitudes. """ - def __getstate__(self): + def __getstate__(self) -> dict[str, Any]: """ We don't need to deepcopy the instance, so nullify if provided. """ - state = super().__getstate__() + # django-stubs 1.16.0 doesn't annotate __getstate__(), but it does exist + # in Django itself. + state = super().__getstate__() # type: ignore[misc] if 'instance' in state: state['instance'] = None return state -def lightweight_deepcopy(value): +def lightweight_deepcopy(value: T) -> T: """ Use our lightweight class to avoid copying the instance on a FieldFile deepcopy. """ if isinstance(value, FieldFile): - value = LightStateFieldFile( + value = cast(T, LightStateFieldFile( instance=value.instance, field=value.field, name=value.name, - ) + )) return deepcopy(value) class DescriptorMixin: field_name: str - tracker_instance: Any = None + tracker_instance: FieldInstanceTracker def __get__( self, @@ -75,12 +84,20 @@ def _get_field_name(self) -> str: class DescriptorWrapper: - def __init__(self, field_name, descriptor, tracker_attname): + def __init__(self, field_name: str, descriptor: models.Field, tracker_attname: str): self.field_name = field_name self.descriptor = descriptor self.tracker_attname = tracker_attname - def __get__(self, instance, owner): + @overload + def __get__(self, instance: None, owner: type[models.Model]) -> DescriptorWrapper: + ... + + @overload + def __get__(self, instance: models.Model, owner: type[models.Model]) -> models.Field: + ... + + def __get__(self, instance: models.Model | None, owner: type[models.Model]) -> DescriptorWrapper | models.Field: if instance is None: return self was_deferred = self.field_name in instance.get_deferred_fields() @@ -93,7 +110,7 @@ def __get__(self, instance, owner): tracker_instance.saved_data[self.field_name] = lightweight_deepcopy(value) return value - def __set__(self, instance, value): + def __set__(self, instance: models.Model, value: models.Field) -> None: initialized = hasattr(instance, '_instance_initialized') was_deferred = self.field_name in instance.get_deferred_fields() @@ -117,7 +134,7 @@ def __set__(self, instance, value): instance.__dict__[self.field_name] = value @staticmethod - def cls_for_descriptor(descriptor): + def cls_for_descriptor(descriptor: models.Field) -> type[DescriptorWrapper]: if hasattr(descriptor, '__delete__'): return FullDescriptorWrapper else: @@ -128,8 +145,8 @@ class FullDescriptorWrapper(DescriptorWrapper): """ Wrapper for descriptors with all three descriptor methods. """ - def __delete__(self, obj): - self.descriptor.__delete__(obj) + def __delete__(self, obj: models.Field) -> None: + self.descriptor.__delete__(obj) # type: ignore[attr-defined] class FieldsContext: @@ -153,7 +170,12 @@ class FieldsContext: """ - def __init__(self, tracker, *fields, state=None): + def __init__( + self, + tracker: FieldInstanceTracker, + *fields: str, + state: dict[str, int] | None = None + ): """ :param tracker: FieldInstanceTracker instance to be reset after context exit @@ -171,7 +193,7 @@ def __init__(self, tracker, *fields, state=None): self.fields = fields self.state = state - def __enter__(self): + def __enter__(self) -> FieldsContext: """ Increments tracked fields occurrences count in shared state. """ @@ -180,7 +202,12 @@ def __enter__(self): self.state[f] += 1 return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: """ Decrements tracked fields occurrences count in shared state. @@ -198,29 +225,34 @@ def __exit__(self, exc_type, exc_val, exc_tb): class FieldInstanceTracker: - def __init__(self, instance, fields, field_map): - self.instance = instance + def __init__(self, instance: models.Model, fields: Iterable[str], field_map: Mapping[str, str]): + self.instance = cast("_AugmentedModel", instance) self.fields = fields self.field_map = field_map self.context = FieldsContext(self, *self.fields) - def __enter__(self): + def __enter__(self) -> FieldsContext: return self.context.__enter__() - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: return self.context.__exit__(exc_type, exc_val, exc_tb) - def __call__(self, *fields): + def __call__(self, *fields: str) -> FieldsContext: return FieldsContext(self, *fields, state=self.context.state) @property - def deferred_fields(self): + def deferred_fields(self) -> set[str]: return self.instance.get_deferred_fields() - def get_field_value(self, field): + def get_field_value(self, field: str) -> Any: return getattr(self.instance, self.field_map[field]) - def set_saved_fields(self, fields=None): + def set_saved_fields(self, fields: Iterable[str] | None = None) -> None: if not self.instance.pk: self.saved_data = {} elif fields is None: @@ -232,7 +264,7 @@ def set_saved_fields(self, fields=None): for field, field_value in self.saved_data.items(): self.saved_data[field] = lightweight_deepcopy(field_value) - def current(self, fields=None): + def current(self, fields: Iterable[str] | None = None) -> dict[str, Any]: """Returns dict of current values for all tracked fields""" if fields is None: deferred_fields = self.deferred_fields @@ -246,7 +278,7 @@ def current(self, fields=None): return {f: self.get_field_value(f) for f in fields} - def has_changed(self, field): + def has_changed(self, field: str) -> bool: """Returns ``True`` if field has changed from currently saved value""" if field in self.fields: # deferred fields haven't changed @@ -256,7 +288,7 @@ def has_changed(self, field): else: raise FieldError('field "%s" not tracked' % field) - def previous(self, field): + def previous(self, field: str) -> Any: """Returns currently saved value of given field""" # handle deferred fields that have not yet been loaded from the database @@ -276,7 +308,7 @@ def previous(self, field): return self.saved_data.get(field) - def changed(self): + def changed(self) -> dict[str, Any]: """Returns dict of fields that changed since save (with old values)""" return { field: self.previous(field) @@ -284,7 +316,7 @@ def changed(self): if self.has_changed(field) } - def init_deferred_fields(self): + def init_deferred_fields(self) -> None: self.instance._deferred_fields = set() if hasattr(self.instance, '_deferred') and not self.instance._deferred: return @@ -295,31 +327,36 @@ class DeferredAttributeTracker(DescriptorMixin, DeferredAttribute): class FileDescriptorTracker(DescriptorMixin, FileDescriptor): tracker_instance = self - def _get_field_name(self): + def _get_field_name(self) -> str: return self.field.name self.instance._deferred_fields = self.instance.get_deferred_fields() for field in self.instance._deferred_fields: field_obj = self.instance.__class__.__dict__.get(field) if isinstance(field_obj, FileDescriptor): - field_tracker = FileDescriptorTracker(field_obj.field) - setattr(self.instance.__class__, field, field_tracker) + file_descriptor_tracker = FileDescriptorTracker(field_obj.field) + setattr(self.instance.__class__, field, file_descriptor_tracker) else: - field_tracker = DeferredAttributeTracker(field) - setattr(self.instance.__class__, field, field_tracker) + deferred_attribute_tracker = DeferredAttributeTracker(field) + setattr(self.instance.__class__, field, deferred_attribute_tracker) class FieldTracker: tracker_class = FieldInstanceTracker - def __init__(self, fields=None): - self.fields = fields + def __init__(self, fields: Iterable[str] | None = None): + # finalize_class() will replace None; pretend it is never None. + self.fields = cast("Iterable[str]", fields) - def __call__(self, func=None, fields=None): - def decorator(f): + def __call__( + self, + func: Callable | None = None, + fields: Iterable[str] | None = None + ) -> Any: + def decorator(f: Callable) -> Callable: @wraps(f) - def inner(obj, *args, **kwargs): + def inner(obj: models.Model, *args: object, **kwargs: object) -> object: tracker = getattr(obj, self.attname) field_list = tracker.fields if fields is None else fields with tracker(*field_list): @@ -330,7 +367,7 @@ def inner(obj, *args, **kwargs): return decorator return decorator(func) - def get_field_map(self, cls): + def get_field_map(self, cls: type[models.Model]) -> dict[str, str]: """Returns dict mapping fields names to model attribute names""" field_map = {field: field for field in self.fields} all_fields = {f.name: f.attname for f in cls._meta.fields} @@ -338,17 +375,17 @@ def get_field_map(self, cls): if k in field_map}) return field_map - def contribute_to_class(self, cls, name): + def contribute_to_class(self, cls: type[models.Model], name: str) -> None: self.name = name self.attname = '_%s' % name models.signals.class_prepared.connect(self.finalize_class, sender=cls) - def finalize_class(self, sender, **kwargs): + def finalize_class(self, sender: type[models.Model], **kwargs: object) -> None: if self.fields is None: self.fields = (field.attname for field in sender._meta.fields) self.fields = set(self.fields) for field_name in self.fields: - descriptor = getattr(sender, field_name) + descriptor: models.Field = getattr(sender, field_name) wrapper_cls = DescriptorWrapper.cls_for_descriptor(descriptor) wrapped_descriptor = wrapper_cls(field_name, descriptor, self.attname) setattr(sender, field_name, wrapped_descriptor) @@ -358,24 +395,29 @@ def finalize_class(self, sender, **kwargs): setattr(sender, self.name, self) self.patch_save(sender) - def initialize_tracker(self, sender, instance, **kwargs): + def initialize_tracker( + self, + sender: type[models.Model], + instance: models.Model, + **kwargs: object + ) -> None: if not isinstance(instance, self.model_class): return # Only init instances of given model (including children) tracker = self.tracker_class(instance, self.fields, self.field_map) setattr(instance, self.attname, tracker) tracker.set_saved_fields() - instance._instance_initialized = True + cast("_AugmentedModel", instance)._instance_initialized = True - def patch_save(self, model): + def patch_save(self, model: type[models.Model]) -> None: self._patch(model, 'save_base', 'update_fields') self._patch(model, 'refresh_from_db', 'fields') - def _patch(self, model, method, fields_kwarg): + def _patch(self, model: type[models.Model], method: str, fields_kwarg: str) -> None: original = getattr(model, method) @wraps(original) - def inner(instance, *args, **kwargs): - update_fields = kwargs.get(fields_kwarg) + def inner(instance: models.Model, *args: object, **kwargs: Any) -> object: + update_fields: Iterable[str] | None = kwargs.get(fields_kwarg) if update_fields is None: fields = self.fields else: @@ -389,7 +431,7 @@ def inner(instance, *args, **kwargs): setattr(model, method, inner) - def __get__(self, instance, owner): + def __get__(self, instance: models.Model | None, owner: type[models.Model]) -> FieldTracker: if instance is None: return self else: @@ -398,7 +440,7 @@ def __get__(self, instance, owner): class ModelInstanceTracker(FieldInstanceTracker): - def has_changed(self, field): + def has_changed(self, field: str) -> bool: """Returns ``True`` if field has changed from currently saved value""" if not self.instance.pk: return True @@ -407,7 +449,7 @@ def has_changed(self, field): else: raise FieldError('field "%s" not tracked' % field) - def changed(self): + def changed(self) -> dict[str, Any]: """Returns dict of fields that changed since save (with old values)""" if not self.instance.pk: return {} @@ -419,5 +461,5 @@ def changed(self): class ModelTracker(FieldTracker): tracker_class = ModelInstanceTracker - def get_field_map(self, cls): + def get_field_map(self, cls: type[models.Model]) -> dict[str, str]: return {field: field for field in self.fields}