From 06b252f325957c504426212a73705454bbfdc04e Mon Sep 17 00:00:00 2001 From: chrysle <96722107+chrysle@users.noreply.github.com> Date: Wed, 26 Jun 2024 22:00:03 +0200 Subject: [PATCH] Add compat modules for `tomllib` and `importlib.metadata` --- piptools/_compat/__init__.py | 2 ++ piptools/_compat/importlib_metadata.py | 17 +++++++++++++++++ piptools/_compat/tomllib.py | 11 +++++++++++ piptools/build.py | 23 ++--------------------- piptools/utils.py | 11 +++-------- 5 files changed, 35 insertions(+), 29 deletions(-) create mode 100644 piptools/_compat/importlib_metadata.py create mode 100644 piptools/_compat/tomllib.py diff --git a/piptools/_compat/__init__.py b/piptools/_compat/__init__.py index cded6776..6852dc36 100644 --- a/piptools/_compat/__init__.py +++ b/piptools/_compat/__init__.py @@ -1,5 +1,6 @@ from __future__ import annotations +from .importlib_metadata import PackageMetadata from .pip_compat import ( Distribution, create_wheel_cache, @@ -12,4 +13,5 @@ "parse_requirements", "create_wheel_cache", "get_dev_pkgs", + "PackageMetadata", ] diff --git a/piptools/_compat/importlib_metadata.py b/piptools/_compat/importlib_metadata.py new file mode 100644 index 00000000..31c41d9a --- /dev/null +++ b/piptools/_compat/importlib_metadata.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +import sys + +if sys.version_info >= (3, 10): + from importlib.metadata import PackageMetadata +else: + from typing import Any, Protocol, TypeVar, overload + + _T = TypeVar("_T") + + class PackageMetadata(Protocol): + @overload + def get_all(self, name: str, failobj: None = None) -> list[Any] | None: ... + + @overload + def get_all(self, name: str, failobj: _T) -> list[Any] | _T: ... diff --git a/piptools/_compat/tomllib.py b/piptools/_compat/tomllib.py new file mode 100644 index 00000000..66f73429 --- /dev/null +++ b/piptools/_compat/tomllib.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +import sys + +if sys.version_info >= (3, 11): + from tomllib import TOMLDecodeError, load, loads +else: + from tomli import TOMLDecodeError, load, loads + + +__all__ = ["loads", "load", "TOMLDecodeError"] diff --git a/piptools/build.py b/piptools/build.py index c19e206b..ccc89778 100644 --- a/piptools/build.py +++ b/piptools/build.py @@ -3,11 +3,10 @@ import collections import contextlib import pathlib -import sys import tempfile from dataclasses import dataclass from importlib import metadata as importlib_metadata -from typing import Any, Iterator, Protocol, TypeVar, overload +from typing import Iterator import build import build.env @@ -17,29 +16,11 @@ from pip._vendor.packaging.markers import Marker from pip._vendor.packaging.requirements import Requirement +from ._compat import PackageMetadata, tomllib from .utils import copy_install_requirement, install_req_from_line -if sys.version_info >= (3, 11): - import tomllib -else: - import tomli as tomllib - PYPROJECT_TOML = "pyproject.toml" -_T = TypeVar("_T") - - -if sys.version_info >= (3, 10): - from importlib.metadata import PackageMetadata -else: - - class PackageMetadata(Protocol): - @overload - def get_all(self, name: str, failobj: None = None) -> list[Any] | None: ... - - @overload - def get_all(self, name: str, failobj: _T) -> list[Any] | _T: ... - @dataclass class StaticProjectMetadata: diff --git a/piptools/utils.py b/piptools/utils.py index 8d04f7a7..00fb9fdc 100644 --- a/piptools/utils.py +++ b/piptools/utils.py @@ -8,19 +8,12 @@ import os import re import shlex -import sys from pathlib import Path from typing import Any, Callable, Iterable, Iterator, TypeVar, cast -from click.core import ParameterSource - -if sys.version_info >= (3, 11): - import tomllib -else: - import tomli as tomllib - import click import pip +from click.core import ParameterSource from click.utils import LazyFile from pip._internal.req import InstallRequirement from pip._internal.req.constructors import ( @@ -40,6 +33,8 @@ from piptools.locations import DEFAULT_CONFIG_FILE_NAMES from piptools.subprocess_utils import run_python_snippet +from ._compat import tomllib + _KT = TypeVar("_KT") _VT = TypeVar("_VT") _T = TypeVar("_T")