diff --git a/Makefile b/Makefile index f67c1515..4b91187a 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ isort = isort opt_einsum scripts/ black = black opt_einsum scripts/ autoflake = autoflake -ir --remove-all-unused-imports --ignore-init-module-imports --remove-unused-variables opt_einsum scripts/ -mypy = mypy --ignore-missing-imports codex opt_einsum scripts/ +mypy = mypy --ignore-missing-imports opt_einsum scripts/ .PHONY: install install: diff --git a/docs/api_reference.md b/docs/api_reference.md index e4ac7723..75695331 100644 --- a/docs/api_reference.md +++ b/docs/api_reference.md @@ -1,18 +1,22 @@ +--- +toc_depth: 1 +--- + # API Documentation ### `opt_einsum.contract` -::: opt_einsum.contract +::: opt_einsum.contract.contract ### `opt_einsum.contract_path` -::: opt_einsum.contract_path +::: opt_einsum.contract.contract_path ### `opt_einsum.contract_expression` -::: opt_einsum.contract_expression +::: opt_einsum.contract.contract_expression @@ -29,12 +33,12 @@ ### `opt_einsum.get_symbol` -::: opt_einsum.get_symbol +::: opt_einsum.parser.get_symbol ### `opt_einsum.shared_intermediates` -::: opt_einsum.shared_intermediates +::: opt_einsum.sharing.shared_intermediates ### `opt_einsum.paths.optimal` diff --git a/docs/getting_started/backends.md b/docs/getting_started/backends.md index 36c3869a..66c848bb 100644 --- a/docs/getting_started/backends.md +++ b/docs/getting_started/backends.md @@ -38,7 +38,7 @@ The following is a brief overview of libraries which have been tested with The automatic backend detection will be detected based on the first supplied array (default), this can be overridden by specifying the correct `backend` argument for the type of arrays supplied when calling -[`opt_einsum.contract`](../api_reference.md##opt_einsumcontract). For example, if you had a library installed +[`opt_einsum.contract`](../api_reference.md#opt_einsum.contract.contract). For example, if you had a library installed called `'foo'` which provided an `numpy.ndarray` like object with a `.shape` attribute as well as `foo.tensordot` and `foo.transpose` then you could contract them with something like: @@ -189,7 +189,7 @@ Currently `opt_einsum` can handle this automatically for: all of which offer GPU support. Since `tensorflow` and `theano` both require compiling the expression, this functionality is encapsulated in generating a -[`opt_einsum.ContractExpression`](../api_reference.md#opt_einsumcontractcontractexpression) using +[`opt_einsum.ContractExpression`](../api_reference.md#opt_einsum.contract.contract_expression) using [`opt_einsum.contract_expression`](../api_reference.md#opt_einsumcontract_expression), which can then be called using numpy arrays whilst specifying `backend='tensorflow'` etc. Additionally, if arrays are marked as `constant` @@ -259,7 +259,7 @@ tf.enable_eager_execution() After which `opt_einsum` will automatically detect eager mode if `backend='tensorflow'` is supplied to a -[`opt_einsum.ContractExpression`](../api_reference.md###opt_einsumcontractcontractexpression). +[`opt_einsum.ContractExpression`](../api_reference.md#opt_einsum.contract.contract_expression). ### Pytorch & Cupy diff --git a/docs/paths/custom_paths.md b/docs/paths/custom_paths.md index 40c67fae..bac22db9 100644 --- a/docs/paths/custom_paths.md +++ b/docs/paths/custom_paths.md @@ -1,7 +1,7 @@ # Custom Path Optimizers If you want to implement or just experiment with custom contaction paths then -you can easily by subclassing the [`opt_einsum.paths.PathOptimizer`](../api_reference.md#opt_einsumpathspathoptimizer) +you can easily by subclassing the [`opt_einsum.paths.PathOptimizer`](../api_reference.md#opt_einsum.paths.PathOptimizer) object. For example, imagine we want to test the path that just blindly contracts the first pair of tensors again and again. We would implement this as: @@ -49,7 +49,7 @@ machinery of the random-greedy approach. Namely: - Parallelization using a pool-executor This is done by subclassing the -[`opt_einsum.paths.RandomOptimizer`](../api_reference.md#opt_einsumpathsrandomoptimizer) +[`opt_einsum.paths.RandomOptimizer`](../api_reference.md#opt_einsum.path_random.RandomOptimizer) object and implementing a `setup` method. Here's an example where we just randomly select any path (again, although we get a considerable speedup over `einsum` this is diff --git a/docs/reference/api.rst b/docs/reference/api.rst deleted file mode 100644 index fa4fe558..00000000 --- a/docs/reference/api.rst +++ /dev/null @@ -1,22 +0,0 @@ -================== -Function Reference -================== - -.. autosummary:: - :toctree: autosummary - - opt_einsum.contract - opt_einsum.contract_path - opt_einsum.contract_expression - opt_einsum.contract.ContractExpression - opt_einsum.contract.PathInfo - opt_einsum.paths.optimal - opt_einsum.paths.greedy - opt_einsum.paths.branch - opt_einsum.parser.get_symbol - opt_einsum.sharing.shared_intermediates - opt_einsum.paths.PathOptimizer - opt_einsum.paths.BranchBound - opt_einsum.path_random.RandomOptimizer - opt_einsum.path_random.RandomGreedy - opt_einsum.paths.DynamicProgramming diff --git a/mkdocs.yml b/mkdocs.yml index 7565cdb0..365a36d8 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -3,6 +3,30 @@ repo_url: https://github.com/dgasmith/opt_einsum repo_name: dgasmith/opt_einsum theme: name: material + features: + - navigation.instant + palette: + + # Palette toggle for automatic mode + - media: "(prefers-color-scheme)" + toggle: + icon: material/brightness-auto + name: Switch to light mode + + # Palette toggle for light mode + - media: "(prefers-color-scheme: light)" + scheme: default + toggle: + icon: material/brightness-7 + name: Switch to dark mode + + # Palette toggle for dark mode + - media: "(prefers-color-scheme: dark)" + scheme: slate + toggle: + icon: material/brightness-4 + name: Switch to system preference + plugins: - search @@ -14,6 +38,25 @@ plugins: # paths: [opt_einsum] options: docstring_style: google + docstring_options: + ignore_init_summary: true + docstring_section_style: list + filters: ["!^_"] + heading_level: 1 + inherited_members: true + merge_init_into_class: true + parameter_headings: true + preload_modules: [mkdocstrings] + separate_signature: true + show_root_heading: true + show_root_full_path: false + show_signature_annotations: true + show_source: false + show_symbol_type_heading: true + show_symbol_type_toc: true + signature_crossrefs: true + summary: true + unwrap_annotated: true extra_javascript: - javascript/config.js @@ -36,6 +79,8 @@ markdown_extensions: - pymdownx.extra - pymdownx.arithmatex: generic: true + - toc: + toc_depth: 2 nav: - Overview: index.md diff --git a/opt_einsum/__init__.py b/opt_einsum/__init__.py index 3e8f6e50..828fc529 100644 --- a/opt_einsum/__init__.py +++ b/opt_einsum/__init__.py @@ -2,15 +2,15 @@ Main init function for opt_einsum. """ -from . import blas, helpers, path_random, paths -from .contract import contract, contract_expression, contract_path -from .parser import get_symbol -from .path_random import RandomGreedy -from .paths import BranchBound, DynamicProgramming -from .sharing import shared_intermediates +from opt_einsum import blas, helpers, path_random, paths +from opt_einsum.contract import contract, contract_expression, contract_path +from opt_einsum.parser import get_symbol +from opt_einsum.path_random import RandomGreedy +from opt_einsum.paths import BranchBound, DynamicProgramming +from opt_einsum.sharing import shared_intermediates # Handle versioneer -from ._version import get_versions # isort:skip +from opt_einsum._version import get_versions # isort:skip versions = get_versions() __version__ = versions["version"] diff --git a/opt_einsum/backends/object_arrays.py b/opt_einsum/backends/object_arrays.py index 308cb671..eae0e92f 100644 --- a/opt_einsum/backends/object_arrays.py +++ b/opt_einsum/backends/object_arrays.py @@ -7,8 +7,10 @@ import numpy as np +from opt_einsum.typing import ArrayType -def object_einsum(eq, *arrays): + +def object_einsum(eq: str, *arrays: ArrayType) -> ArrayType: """A ``einsum`` implementation for ``numpy`` arrays with object dtype. The loop is performed in python, meaning the objects themselves need only to implement ``__mul__`` and ``__add__`` for the contraction to be diff --git a/opt_einsum/backends/torch.py b/opt_einsum/backends/torch.py index ed92fd53..c3ae9b5e 100644 --- a/opt_einsum/backends/torch.py +++ b/opt_einsum/backends/torch.py @@ -41,7 +41,7 @@ def transpose(a, axes): return a.permute(*axes) -def einsum(equation, *operands): +def einsum(equation, *operands, **kwargs): """Variadic version of torch.einsum to match numpy api.""" # rename symbols to support PyTorch 0.4.1 and earlier, # which allow only symbols a-z. diff --git a/opt_einsum/contract.py b/opt_einsum/contract.py index 1b38c04c..e44d12da 100644 --- a/opt_einsum/contract.py +++ b/opt_einsum/contract.py @@ -5,10 +5,18 @@ from collections import namedtuple from decimal import Decimal from functools import lru_cache -from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple - -from . import backends, blas, helpers, parser, paths, sharing -from .typing import ArrayIndexType, ArrayType, ContractionListType, PathType +from typing import Any, Collection, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, Union, overload + +from opt_einsum import backends, blas, helpers, parser, paths, sharing +from opt_einsum.typing import ( + ArrayIndexType, + ArrayType, + BackendType, + ContractionListType, + OptimizeKind, + PathType, + TensorShapeType, +) __all__ = [ "contract_path", @@ -18,16 +26,16 @@ "shape_only", ] +## Common types -class PathInfo: - """A printable object to contain information about a contraction path. +_OrderKACF = Literal[None, "K", "A", "C", "F"] - **Attributes:** +_Casting = Literal["no", "equiv", "safe", "same_kind", "unsafe"] +_MemoryLimit = Union[None, int, Decimal, Literal["max_input"]] - - **naive_cost** - *(int)* The estimate FLOP cost of a naive einsum contraction. - - **opt_cost** - *(int)* The estimate FLOP cost of this optimized contraction path. - - **largest_intermediate** - *(int)* The number of elements in the largest intermediate array that will be produced during the contraction. - """ + +class PathInfo: + """A printable object to contain information about a contraction path.""" def __init__( self, @@ -96,10 +104,13 @@ def __repr__(self) -> str: return "".join(path_print) -def _choose_memory_arg(memory_limit: int, size_list: List[int]) -> Optional[int]: +def _choose_memory_arg(memory_limit: _MemoryLimit, size_list: List[int]) -> Optional[int]: if memory_limit == "max_input": return max(size_list) + if isinstance(memory_limit, str): + raise ValueError("memory_limit must be None, int, or the string Literal['max_input'].") + if memory_limit is None: return None @@ -112,154 +123,189 @@ def _choose_memory_arg(memory_limit: int, size_list: List[int]) -> Optional[int] return int(memory_limit) -_VALID_CONTRACT_KWARGS = { - "optimize", - "path", - "memory_limit", - "einsum_call", - "use_blas", - "shapes", -} - - -def contract_path(*operands_: Any, **kwargs: Any) -> Tuple[PathType, PathInfo]: +_EinsumDefaultKeys = Literal["order", "casting", "dtype", "out"] + + +def _filter_einsum_defaults(kwargs: Dict[_EinsumDefaultKeys, Any]) -> Dict[_EinsumDefaultKeys, Any]: + """Filters out default contract kwargs to pass to various backends.""" + kwargs = kwargs.copy() + ret: Dict[_EinsumDefaultKeys, Any] = {} + if (order := kwargs.pop("order", "K")) != "K": + ret["order"] = order + + if (casting := kwargs.pop("casting", "safe")) != "safe": + ret["casting"] = casting + + if (dtype := kwargs.pop("dtype", None)) is not None: + ret["dtype"] = dtype + + if (out := kwargs.pop("out", None)) is not None: + ret["out"] = out + + ret.update(kwargs) + return ret + + +# Overlaod for contract(einsum_string, *operands) +@overload +def contract_path( + subscripts: str, + *operands: ArrayType, + use_blas: bool = True, + optimize: OptimizeKind = True, + memory_limit: _MemoryLimit = None, + shapes: bool = False, + **kwargs: Any, +) -> Tuple[PathType, PathInfo]: ... + + +# Overlaod for contract(operand, indices, operand, indices, ....) +@overload +def contract_path( + subscripts: ArrayType, + *operands: Union[ArrayType, Collection[int]], + use_blas: bool = True, + optimize: OptimizeKind = True, + memory_limit: _MemoryLimit = None, + shapes: bool = False, + **kwargs: Any, +) -> Tuple[PathType, PathInfo]: ... + + +def contract_path( + subscripts: Any, + *operands: Any, + use_blas: bool = True, + optimize: OptimizeKind = True, + memory_limit: _MemoryLimit = None, + shapes: bool = False, + **kwargs: Any, +) -> Tuple[PathType, PathInfo]: """ - Find a contraction order `path`, without performing the contraction. - - **Parameters:** - - - **subscripts** - *(str)* Specifies the subscripts for summation. - - **\\*operands** - *(list of array_like)* these are the arrays for the operation. - - **use_blas** - *(bool)* Do you use BLAS for valid operations, may use extra memory for more intermediates. - - **optimize** - *(str, list or bool, optional (default: `auto`))* Choose the type of path. - - - if a list is given uses this as the path. - - `'optimal'` An algorithm that explores all possible ways of - contracting the listed tensors. Scales factorially with the number of - terms in the contraction. - - `'dp'` A faster (but essentially optimal) algorithm that uses - dynamic programming to exhaustively search all contraction paths - without outer-products. - - `'greedy'` An cheap algorithm that heuristically chooses the best - pairwise contraction at each step. Scales linearly in the number of - terms in the contraction. - - `'random-greedy'` Run a randomized version of the greedy algorithm - 32 times and pick the best path. - - `'random-greedy-128'` Run a randomized version of the greedy - algorithm 128 times and pick the best path. - - `'branch-all'` An algorithm like optimal but that restricts itself - to searching 'likely' paths. Still scales factorially. - - `'branch-2'` An even more restricted version of 'branch-all' that - only searches the best two options at each step. Scales exponentially - with the number of terms in the contraction. - - `'auto'` Choose the best of the above algorithms whilst aiming to - keep the path finding time below 1ms. - - `'auto-hq'` Aim for a high quality contraction, choosing the best - of the above algorithms whilst aiming to keep the path finding time - below 1sec. - - - **memory_limit** - *({None, int, 'max_input'} (default: `None`))* - Give the upper bound of the largest intermediate tensor contract will build. - - - None or -1 means there is no limit - - `max_input` means the limit is set as largest input tensor - - a positive integer is taken as an explicit limit on the number of elements - - The default is None. Note that imposing a limit can make contractions - exponentially slower to perform. - - - **shapes** - *(bool, optional)* Whether ``contract_path`` should assume arrays (the default) or array shapes have been supplied. - - **Returns:** - - - **path** - *(list of tuples)* The einsum path - - **PathInfo** - *(str)* A printable object containing various information about the path found. - - **Notes:** - - The resulting path indicates which terms of the input contraction should be - contracted first, the result of this contraction is then appended to the end of - the contraction list. - - **Examples:** - - We can begin with a chain dot example. In this case, it is optimal to - contract the b and c tensors represented by the first element of the path (1, - 2). The resulting tensor is added to the end of the contraction and the - remaining contraction, `(0, 1)`, is then executed. + Find a contraction order `path`, without performing the contraction. - ```python - a = np.random.rand(2, 2) - b = np.random.rand(2, 5) - c = np.random.rand(5, 2) - path_info = opt_einsum.contract_path('ij,jk,kl->il', a, b, c) - print(path_info[0]) - #> [(1, 2), (0, 1)] - print(path_info[1]) - #> Complete contraction: ij,jk,kl->il - #> Naive scaling: 4 - #> Optimized scaling: 3 - #> Naive FLOP count: 1.600e+02 - #> Optimized FLOP count: 5.600e+01 - #> Theoretical speedup: 2.857 - #> Largest intermediate: 4.000e+00 elements - #> ------------------------------------------------------------------------- - #> scaling current remaining - #> ------------------------------------------------------------------------- - #> 3 kl,jk->jl ij,jl->il - #> 3 jl,ij->il il->il - ``` - - A more complex index transformation example. - - ```python - I = np.random.rand(10, 10, 10, 10) - C = np.random.rand(10, 10) - path_info = oe.contract_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C) - - print(path_info[0]) - #> [(0, 2), (0, 3), (0, 2), (0, 1)] - print(path_info[1]) - #> Complete contraction: ea,fb,abcd,gc,hd->efgh - #> Naive scaling: 8 - #> Optimized scaling: 5 - #> Naive FLOP count: 8.000e+08 - #> Optimized FLOP count: 8.000e+05 - #> Theoretical speedup: 1000.000 - #> Largest intermediate: 1.000e+04 elements - #> -------------------------------------------------------------------------- - #> scaling current remaining - #> -------------------------------------------------------------------------- - #> 5 abcd,ea->bcde fb,gc,hd,bcde->efgh - #> 5 bcde,fb->cdef gc,hd,cdef->efgh - #> 5 cdef,gc->defg hd,defg->efgh - #> 5 defg,hd->efgh efgh->efgh - ``` + Parameters: + subscripts: Specifies the subscripts for summation. + *operands: These are the arrays for the operation. + use_blas: Do you use BLAS for valid operations, may use extra memory for more intermediates. + optimize: Choose the type of path the contraction will be optimized with. + - if a list is given uses this as the path. + - `'optimal'` An algorithm that explores all possible ways of + contracting the listed tensors. Scales factorially with the number of + terms in the contraction. + - `'dp'` A faster (but essentially optimal) algorithm that uses + dynamic programming to exhaustively search all contraction paths + without outer-products. + - `'greedy'` An cheap algorithm that heuristically chooses the best + pairwise contraction at each step. Scales linearly in the number of + terms in the contraction. + - `'random-greedy'` Run a randomized version of the greedy algorithm + 32 times and pick the best path. + - `'random-greedy-128'` Run a randomized version of the greedy + algorithm 128 times and pick the best path. + - `'branch-all'` An algorithm like optimal but that restricts itself + to searching 'likely' paths. Still scales factorially. + - `'branch-2'` An even more restricted version of 'branch-all' that + only searches the best two options at each step. Scales exponentially + with the number of terms in the contraction. + - `'auto'` Choose the best of the above algorithms whilst aiming to + keep the path finding time below 1ms. + - `'auto-hq'` Aim for a high quality contraction, choosing the best + of the above algorithms whilst aiming to keep the path finding time + below 1sec. + + memory_limit: Give the upper bound of the largest intermediate tensor contract will build. + - None or -1 means there is no limit + - `max_input` means the limit is set as largest input tensor + - a positive integer is taken as an explicit limit on the number of elements + + The default is None. Note that imposing a limit can make contractions + exponentially slower to perform. + + shapes: Whether ``contract_path`` should assume arrays (the default) or array shapes have been supplied. + + Returns: + path: The optimized einsum contraciton path + PathInfo: A printable object containing various information about the path found. + + Notes: + The resulting path indicates which terms of the input contraction should be + contracted first, the result of this contraction is then appended to the end of + the contraction list. + + Examples: + We can begin with a chain dot example. In this case, it is optimal to + contract the b and c tensors represented by the first element of the path (1, + 2). The resulting tensor is added to the end of the contraction and the + remaining contraction, `(0, 1)`, is then executed. + + ```python + a = np.random.rand(2, 2) + b = np.random.rand(2, 5) + c = np.random.rand(5, 2) + path_info = opt_einsum.contract_path('ij,jk,kl->il', a, b, c) + print(path_info[0]) + #> [(1, 2), (0, 1)] + print(path_info[1]) + #> Complete contraction: ij,jk,kl->il + #> Naive scaling: 4 + #> Optimized scaling: 3 + #> Naive FLOP count: 1.600e+02 + #> Optimized FLOP count: 5.600e+01 + #> Theoretical speedup: 2.857 + #> Largest intermediate: 4.000e+00 elements + #> ------------------------------------------------------------------------- + #> scaling current remaining + #> ------------------------------------------------------------------------- + #> 3 kl,jk->jl ij,jl->il + #> 3 jl,ij->il il->il + ``` + + A more complex index transformation example. + + ```python + I = np.random.rand(10, 10, 10, 10) + C = np.random.rand(10, 10) + path_info = oe.contract_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C) + + print(path_info[0]) + #> [(0, 2), (0, 3), (0, 2), (0, 1)] + print(path_info[1]) + #> Complete contraction: ea,fb,abcd,gc,hd->efgh + #> Naive scaling: 8 + #> Optimized scaling: 5 + #> Naive FLOP count: 8.000e+08 + #> Optimized FLOP count: 8.000e+05 + #> Theoretical speedup: 1000.000 + #> Largest intermediate: 1.000e+04 elements + #> -------------------------------------------------------------------------- + #> scaling current remaining + #> -------------------------------------------------------------------------- + #> 5 abcd,ea->bcde fb,gc,hd,bcde->efgh + #> 5 bcde,fb->cdef gc,hd,cdef->efgh + #> 5 cdef,gc->defg hd,defg->efgh + #> 5 defg,hd->efgh efgh->efgh + ``` """ - - # Make sure all keywords are valid - unknown_kwargs = set(kwargs) - _VALID_CONTRACT_KWARGS - if len(unknown_kwargs): - raise TypeError("einsum_path: Did not understand the following kwargs: {}".format(unknown_kwargs)) - - path_type = kwargs.pop("optimize", "auto") - - memory_limit = kwargs.pop("memory_limit", None) - shapes = kwargs.pop("shapes", False) + if optimize is True: + optimize = "auto" # Hidden option, only einsum should call this einsum_call_arg = kwargs.pop("einsum_call", False) - use_blas = kwargs.pop("use_blas", True) + if len(kwargs): + raise TypeError(f"Did not understand the following kwargs: {kwargs.keys()}") # Python side parsing - input_subscripts, output_subscript, operands = parser.parse_einsum_input(operands_, shapes=shapes) + operands_ = [subscripts] + list(operands) + input_subscripts, output_subscript, operands_prepped = parser.parse_einsum_input(operands_, shapes=shapes) # Build a few useful list and sets input_list = input_subscripts.split(",") input_sets = [frozenset(x) for x in input_list] if shapes: - input_shapes = operands + input_shapes = operands_prepped else: - input_shapes = [x.shape for x in operands] + input_shapes = [x.shape for x in operands_prepped] output_set = frozenset(output_subscript) indices = frozenset(input_subscripts.replace(",", "")) @@ -297,23 +343,22 @@ def contract_path(*operands_: Any, **kwargs: Any) -> Tuple[PathType, PathInfo]: # Compute naive cost # This is not quite right, need to look into exactly how einsum does this # indices_in_input = input_subscripts.replace(',', '') - inner_product = (sum(len(x) for x in input_sets) - len(indices)) > 0 naive_cost = helpers.flop_count(indices, inner_product, num_ops, size_dict) # Compute the path - if not isinstance(path_type, (str, paths.PathOptimizer)): + if not isinstance(optimize, (str, paths.PathOptimizer)): # Custom path supplied - path = path_type + path_tuple: PathType = optimize # type: ignore elif num_ops <= 2: # Nothing to be optimized - path = [tuple(range(num_ops))] - elif isinstance(path_type, paths.PathOptimizer): + path_tuple = [tuple(range(num_ops))] + elif isinstance(optimize, paths.PathOptimizer): # Custom path optimizer supplied - path = path_type(input_sets, output_set, size_dict, memory_arg) + path_tuple = optimize(input_sets, output_set, size_dict, memory_arg) else: - path_optimizer = paths.get_path_fn(path_type) - path = path_optimizer(input_sets, output_set, size_dict, memory_arg) + path_optimizer = paths.get_path_fn(optimize) + path_tuple = path_optimizer(input_sets, output_set, size_dict, memory_arg) cost_list = [] scale_list = [] @@ -321,7 +366,7 @@ def contract_path(*operands_: Any, **kwargs: Any) -> Tuple[PathType, PathInfo]: contraction_list = [] # Build contraction tuple (positions, gemm, einsum_str, remaining) - for cnum, contract_inds in enumerate(path): + for cnum, contract_inds in enumerate(path_tuple): # Make sure we remove inds from right to left contract_inds = tuple(sorted(list(contract_inds), reverse=True)) @@ -343,7 +388,7 @@ def contract_path(*operands_: Any, **kwargs: Any) -> Tuple[PathType, PathInfo]: do_blas = False # Last contraction - if (cnum - len(path)) == -1: + if (cnum - len(path_tuple)) == -1: idx_result = output_subscript else: # use tensordot order to minimize transpositions @@ -370,14 +415,14 @@ def contract_path(*operands_: Any, **kwargs: Any) -> Tuple[PathType, PathInfo]: opt_cost = sum(cost_list) if einsum_call_arg: - return operands, contraction_list # type: ignore + return operands_prepped, contraction_list # type: ignore path_print = PathInfo( contraction_list, input_subscripts, output_subscript, indices, - path, + path_tuple, scale_list, naive_cost, opt_cost, @@ -385,11 +430,11 @@ def contract_path(*operands_: Any, **kwargs: Any) -> Tuple[PathType, PathInfo]: size_dict, ) - return path, path_print + return path_tuple, path_print @sharing.einsum_cache_wrap -def _einsum(*operands, **kwargs): +def _einsum(*operands: Any, **kwargs: Any) -> ArrayType: """Base einsum, but with pre-parse for valid characters if a string is given.""" fn = backends.get_func("einsum", kwargs.pop("backend", "numpy")) @@ -407,6 +452,7 @@ def _einsum(*operands, **kwargs): einsum_str = parser.convert_to_valid_einsum_chars(einsum_str) + kwargs = _filter_einsum_defaults(kwargs) # type: ignore return fn(einsum_str, *operands, **kwargs) @@ -430,21 +476,67 @@ def _tensordot(x: ArrayType, y: ArrayType, axes: Tuple[int, ...], backend: str = # Rewrite einsum to handle different cases -def contract(*operands_: Any, **kwargs: Any) -> ArrayType: + + +@overload +def contract( + subscripts: str, + *operands: ArrayType, + out: ArrayType = ..., + dtype: Any = ..., + order: _OrderKACF = ..., + casting: _Casting = ..., + use_blas: bool = ..., + optimize: OptimizeKind = ..., + memory_limit: _MemoryLimit = ..., + backend: BackendType = ..., + **kwargs: Any, +) -> ArrayType: ... + + +@overload +def contract( + subscripts: ArrayType, + *operands: Union[ArrayType, Collection[int]], + out: ArrayType = ..., + dtype: Any = ..., + order: _OrderKACF = ..., + casting: _Casting = ..., + use_blas: bool = ..., + optimize: OptimizeKind = ..., + memory_limit: _MemoryLimit = ..., + backend: BackendType = ..., + **kwargs: Any, +) -> ArrayType: ... + + +def contract( + subscripts: Union[str, ArrayType], + *operands: Union[ArrayType, Collection[int]], + out: Optional[ArrayType] = None, + dtype: Optional[str] = None, + order: _OrderKACF = "K", + casting: _Casting = "safe", + use_blas: bool = True, + optimize: OptimizeKind = True, + memory_limit: _MemoryLimit = None, + backend: BackendType = "auto", + **kwargs: Any, +) -> ArrayType: """ Evaluates the Einstein summation convention on the operands. A drop in replacement for NumPy's einsum function that optimizes the order of contraction to reduce overall scaling at the cost of several intermediate arrays. Parameters: - subscripts: *(str)* Specifies the subscripts for summation. - \\*operands: *(list of array_like)* hese are the arrays for the operation. - out: *(array_like)* A output array in which set the sresulting output. - dtype - *(str)* The dtype of the given contraction, see np.einsum. - order - *(str)* The order of the resulting contraction, see np.einsum. - casting - *(str)* The casting procedure for operations of different dtype, see np.einsum. - use_blas - *(bool)* Do you use BLAS for valid operations, may use extra memory for more intermediates. - optimize - *(str, list or bool, optional (default: ``auto``))* Choose the type of path. + subscripts: Specifies the subscripts for summation. + *operands: These are the arrays for the operation. + out: A output array in which set the resulting output. + dtype: The dtype of the given contraction, see np.einsum. + order: The order of the resulting contraction, see np.einsum. + casting: The casting procedure for operations of different dtype, see np.einsum. + use_blas: Do you use BLAS for valid operations, may use extra memory for more intermediates. + optimize:- Choose the type of path the contraction will be optimized with - if a list is given uses this as the path. - `'optimal'` An algorithm that explores all possible ways of contracting the listed tensors. Scales factorially with the number of @@ -470,17 +562,17 @@ def contract(*operands_: Any, **kwargs: Any) -> ArrayType: of the above algorithms whilst aiming to keep the path finding time below 1sec. - memory_limit - *({None, int, 'max_input'} (default: `None`))* - Give the upper bound of the largest intermediate tensor contract will build. - - None or -1 means there is no limit - - `max_input` means the limit is set as largest input tensor - - a positive integer is taken as an explicit limit on the number of elements + memory_limit:- Give the upper bound of the largest intermediate tensor contract will build. + - None or -1 means there is no limit. + - `max_input` means the limit is set as largest input tensor. + - A positive integer is taken as an explicit limit on the number of elements. The default is None. Note that imposing a limit can make contractions exponentially slower to perform. - - backend - *(str, optional (default: ``auto``))* Which library to use to perform the required ``tensordot``, ``transpose`` + backend: Which library to use to perform the required ``tensordot``, ``transpose`` and ``einsum`` calls. Should match the types of arrays supplied, See - :func:`contract_expression` for generating expressions which convert + `contract_expression` for generating expressions which convert numpy arrays to and from the backend library automatically. Returns: @@ -499,44 +591,38 @@ def contract(*operands_: Any, **kwargs: Any) -> ArrayType: performed optimally. When NumPy is linked to a threaded BLAS, potential speedups are on the order of 20-100 for a six core machine. """ - optimize_arg = kwargs.pop("optimize", True) - if optimize_arg is True: - optimize_arg = "auto" + if optimize is True: + optimize = "auto" - valid_einsum_kwargs = ["out", "dtype", "order", "casting"] - einsum_kwargs = {k: v for (k, v) in kwargs.items() if k in valid_einsum_kwargs} + operands_list = [subscripts] + list(operands) + einsum_kwargs = {"out": out, "dtype": dtype, "order": order, "casting": casting} # If no optimization, run pure einsum - if optimize_arg is False: - return _einsum(*operands_, **einsum_kwargs) + if optimize is False: + return _einsum(*operands_list, **einsum_kwargs) # Grab non-einsum kwargs - use_blas = kwargs.pop("use_blas", True) - memory_limit = kwargs.pop("memory_limit", None) - backend = kwargs.pop("backend", "auto") gen_expression = kwargs.pop("_gen_expression", False) constants_dict = kwargs.pop("_constants_dict", {}) - - # Make sure remaining keywords are valid for einsum - unknown_kwargs = [k for (k, v) in kwargs.items() if k not in valid_einsum_kwargs] - if len(unknown_kwargs): - raise TypeError("Did not understand the following kwargs: {}".format(unknown_kwargs)) + if len(kwargs): + raise TypeError(f"Did not understand the following kwargs: {kwargs.keys()}") if gen_expression: - full_str = operands_[0] + full_str = operands_list[0] # Build the contraction list and operand - operands: Sequence[ArrayType] contraction_list: ContractionListType operands, contraction_list = contract_path( # type: ignore - *operands_, optimize=optimize_arg, memory_limit=memory_limit, einsum_call=True, use_blas=use_blas + *operands_list, optimize=optimize, memory_limit=memory_limit, einsum_call=True, use_blas=use_blas ) # check if performing contraction or just building expression if gen_expression: - return ContractExpression(full_str, contraction_list, constants_dict, **einsum_kwargs) + return ContractExpression(full_str, contraction_list, constants_dict, dtype=dtype, order=order, casting=casting) - return _core_contract(operands, contraction_list, backend=backend, **einsum_kwargs) + return _core_contract( + operands, contraction_list, backend=backend, out=out, dtype=dtype, order=order, casting=casting + ) @lru_cache(None) @@ -569,15 +655,17 @@ def _core_contract( contraction_list: ContractionListType, backend: Optional[str] = "auto", evaluate_constants: bool = False, - **einsum_kwargs: Any, + out: Optional[ArrayType] = None, + dtype: Optional[str] = None, + order: _OrderKACF = "K", + casting: _Casting = "safe", ) -> ArrayType: """Inner loop used to perform an actual contraction given the output from a ``contract_path(..., einsum_call=True)`` call. """ # Special handling if out is specified - out_array = einsum_kwargs.pop("out", None) - specified_out = out_array is not None + specified_out = out is not None operands = list(operands_) backend = parse_backend(operands, backend) @@ -631,23 +719,23 @@ def _core_contract( new_view = _transpose(new_view, axes=transpose, backend=backend) if handle_out: - out_array[:] = new_view + out[:] = new_view # type: ignore - # Call einsum else: - # If out was specified + # Call einsum + out_kwarg: Union[None, ArrayType] = None if handle_out: - einsum_kwargs["out"] = out_array - - # Do the contraction - new_view = _einsum(einsum_str, *tmp_operands, backend=backend, **einsum_kwargs) + out_kwarg = out + new_view = _einsum( + einsum_str, *tmp_operands, backend=backend, dtype=dtype, order=order, casting=casting, out=out_kwarg + ) # Append new items and dereference what we can operands.append(new_view) del tmp_operands, new_view if specified_out: - return out_array + return out else: return operands[0] @@ -688,10 +776,14 @@ def __init__( contraction: str, contraction_list: ContractionListType, constants_dict: Dict[int, ArrayType], - **einsum_kwargs: Any, + dtype: Optional[str] = None, + order: _OrderKACF = "K", + casting: _Casting = "safe", ): self.contraction_list = contraction_list - self.einsum_kwargs = einsum_kwargs + self.dtype = dtype + self.order = order + self.casting = casting self.contraction = format_const_einsum_str(contraction, constants_dict.keys()) # need to know _full_num_args to parse constants with, and num_args to call with @@ -760,7 +852,9 @@ def _contract( out=out, backend=backend, evaluate_constants=evaluate_constants, - **self.einsum_kwargs, + dtype=self.dtype, + order=self.order, + casting=self.casting, ) def _contract_with_conversion( @@ -790,29 +884,28 @@ def _contract_with_conversion( return result - def __call__(self, *arrays: ArrayType, **kwargs: Any) -> ArrayType: + def __call__( + self, + *arrays: ArrayType, + out: Union[None, ArrayType] = None, + backend: str = "auto", + evaluate_constants: bool = False, + ) -> ArrayType: """Evaluate this expression with a set of arrays. - Parameters - ---------- - arrays : seq of array - The arrays to supply as input to the expression. - out : array, optional (default: ``None``) - If specified, output the result into this array. - backend : str, optional (default: ``numpy``) - Perform the contraction with this backend library. If numpy arrays - are supplied then try to convert them to and from the correct - backend array type. + Parameters: + arrays: The arrays to supply as input to the expression. + out: If specified, output the result into this array. + backend: Perform the contraction with this backend library. If numpy arrays + are supplied then try to convert them to and from the correct + backend array type. + evaluate_constants: Pre-evaluates constants with the appropriate backend. + + Returns: + The contracted result. """ - out = kwargs.pop("out", None) - backend = parse_backend(arrays, kwargs.pop("backend", "auto")) - evaluate_constants = kwargs.pop("evaluate_constants", False) - if kwargs: - raise ValueError( - "The only valid keyword arguments to a `ContractExpression` " - "call are `out=` or `backend=`. Got: {}.".format(kwargs) - ) + backend = parse_backend(arrays, backend) correct_num_args = self._full_num_args if evaluate_constants else self.num_args @@ -835,7 +928,7 @@ def __call__(self, *arrays: ArrayType, **kwargs: Any) -> ArrayType: if backends.has_backend(backend) and all(infer_backend(x) == "numpy" for x in arrays): return self._contract_with_conversion(ops, out, backend, evaluate_constants=evaluate_constants) - return self._contract(ops, out, backend, evaluate_constants=evaluate_constants) + return self._contract(ops, out=out, backend=backend, evaluate_constants=evaluate_constants) except ValueError as err: original_msg = str(err.args) if err.args else "" @@ -859,55 +952,88 @@ def __str__(self) -> str: for i, c in enumerate(self.contraction_list): s.append("\n {}. ".format(i + 1)) s.append("'{}'".format(c[2]) + (" [{}]".format(c[-1]) if c[-1] else "")) - if self.einsum_kwargs: - s.append("\neinsum_kwargs={}".format(self.einsum_kwargs)) + kwargs = {"dtype": self.dtype, "order": self.order, "casting": self.casting} + s.append(f"\neinsum_kwargs={kwargs}") return "".join(s) Shaped = namedtuple("Shaped", ["shape"]) -def shape_only(shape: PathType) -> Shaped: +def shape_only(shape: TensorShapeType) -> Shaped: """Dummy ``numpy.ndarray`` which has a shape only - for generating contract expressions. """ return Shaped(shape) -def contract_expression(subscripts: str, *shapes: PathType, **kwargs: Any) -> Any: +# Overlaod for contract(einsum_string, *operands) +@overload +def contract_expression( + subscripts: str, + *operands: Union[ArrayType, TensorShapeType], + constants: Union[Collection[int], None] = ..., + use_blas: bool = ..., + optimize: OptimizeKind = ..., + memory_limit: _MemoryLimit = ..., + **kwargs: Any, +) -> ContractExpression: ... + + +# Overlaod for contract(operand, indices, operand, indices, ....) +@overload +def contract_expression( + subscripts: Union[ArrayType, TensorShapeType], + *operands: Union[ArrayType, TensorShapeType, Collection[int]], + constants: Union[Collection[int], None] = ..., + use_blas: bool = ..., + optimize: OptimizeKind = ..., + memory_limit: _MemoryLimit = ..., + **kwargs: Any, +) -> ContractExpression: ... + + +def contract_expression( + subscripts: Union[str, ArrayType, TensorShapeType], + *shapes: Union[ArrayType, TensorShapeType, Collection[int]], + constants: Union[Collection[int], None] = None, + use_blas: bool = True, + optimize: OptimizeKind = True, + memory_limit: _MemoryLimit = None, + **kwargs: Any, +) -> ContractExpression: """Generate a reusable expression for a given contraction with specific shapes, which can, for example, be cached. - **Parameters:** - - - **subscripts** - *(str)* Specifies the subscripts for summation. - - **shapes** - *(sequence of integer tuples)* Shapes of the arrays to optimize the contraction for. - - **constants** - *(sequence of int, optional)* The indices of any constant arguments in `shapes`, in which case the - actual array should be supplied at that position rather than just a - shape. If these are specified, then constant parts of the contraction - between calls will be reused. Additionally, if a GPU-enabled backend is - used for example, then the constant tensors will be kept on the GPU, - minimizing transfers. - - **kwargs** - Passed on to `contract_path` or `einsum`. See `contract`. + Parameters: - **Returns:** + subscripts: Specifies the subscripts for summation. + shapes: Shapes of the arrays to optimize the contraction for. + constants: The indices of any constant arguments in `shapes`, in which case the + actual array should be supplied at that position rather than just a + shape. If these are specified, then constant parts of the contraction + between calls will be reused. Additionally, if a GPU-enabled backend is + used for example, then the constant tensors will be kept on the GPU, + minimizing transfers. + kwargs: Passed on to `contract_path` or `einsum`. See `contract`. - - **expr** - *(ContractExpression)* Callable with signature `expr(*arrays, out=None, backend='numpy')` where the array's shapes should match `shapes`. + Returns: + Callable with signature `expr(*arrays, out=None, backend='numpy')` where the array's shapes should match `shapes`. - **Notes:** + Notes: - - The `out` keyword argument should be supplied to the generated expression - rather than this function. - - The `backend` keyword argument should also be supplied to the generated - expression. If numpy arrays are supplied, if possible they will be - converted to and back from the correct backend array type. - - The generated expression will work with any arrays which have - the same rank (number of dimensions) as the original shapes, however, if - the actual sizes are different, the expression may no longer be optimal. - - Constant operations will be computed upon the first call with a particular - backend, then subsequently reused. + The `out` keyword argument should be supplied to the generated expression + rather than this function. + The `backend` keyword argument should also be supplied to the generated + expression. If numpy arrays are supplied, if possible they will be + converted to and back from the correct backend array type. + The generated expression will work with any arrays which have + the same rank (number of dimensions) as the original shapes, however, if + the actual sizes are different, the expression may no longer be optimal. + Constant operations will be computed upon the first call with a particular + backend, then subsequently reused. - **Examples:** + Examples: Basic usage: @@ -932,7 +1058,7 @@ def contract_expression(subscripts: str, *shapes: PathType, **kwargs: Any) -> An ``` """ - if not kwargs.get("optimize", True): + if not optimize: raise ValueError("Can only generate expressions for optimized contractions.") for arg in ("out", "backend"): @@ -943,16 +1069,18 @@ def contract_expression(subscripts: str, *shapes: PathType, **kwargs: Any) -> An ) if not isinstance(subscripts, str): - subscripts, shapes = parser.convert_interleaved_input((subscripts,) + shapes) + subscripts, shapes = parser.convert_interleaved_input((subscripts,) + shapes) # type: ignore kwargs["_gen_expression"] = True # build dict of constant indices mapped to arrays - constants = kwargs.pop("constants", ()) + constants = constants or tuple() constants_dict = {i: shapes[i] for i in constants} kwargs["_constants_dict"] = constants_dict # apart from constant arguments, make dummy arrays - dummy_arrays = [s if i in constants else shape_only(s) for i, s in enumerate(shapes)] + dummy_arrays = [s if i in constants else shape_only(s) for i, s in enumerate(shapes)] # type: ignore - return contract(subscripts, *dummy_arrays, **kwargs) + return contract( + subscripts, *dummy_arrays, use_blas=use_blas, optimize=optimize, memory_limit=memory_limit, **kwargs + ) diff --git a/opt_einsum/helpers.py b/opt_einsum/helpers.py index 3a1c39a1..594212b3 100644 --- a/opt_einsum/helpers.py +++ b/opt_einsum/helpers.py @@ -2,12 +2,12 @@ Contains helper functions for opt_einsum testing scripts """ -from typing import Any, Collection, Dict, FrozenSet, Iterable, List, Optional, Tuple, Union, overload +from typing import Any, Collection, Dict, FrozenSet, Iterable, List, Literal, Optional, Tuple, Union, overload import numpy as np -from .parser import get_symbol -from .typing import ArrayIndexType, PathType +from opt_einsum.parser import get_symbol +from opt_einsum.typing import ArrayIndexType, PathType __all__ = ["build_views", "compute_size_by_dict", "find_contraction", "flop_count"] @@ -191,9 +191,36 @@ def flop_count( return overall_size * op_factor +@overload +def rand_equation( + n: int, + regularity: int, + n_out: int = ..., + d_min: int = ..., + d_max: int = ..., + seed: Optional[int] = ..., + global_dim: bool = ..., + *, + return_size_dict: Literal[True], +) -> Tuple[str, PathType, Dict[str, int]]: ... + + +@overload +def rand_equation( + n: int, + regularity: int, + n_out: int = ..., + d_min: int = ..., + d_max: int = ..., + seed: Optional[int] = ..., + global_dim: bool = ..., + return_size_dict: Literal[False] = ..., +) -> Tuple[str, PathType]: ... + + def rand_equation( n: int, - reg: int, + regularity: int, n_out: int = 0, d_min: int = 2, d_max: int = 9, @@ -203,60 +230,48 @@ def rand_equation( ) -> Union[Tuple[str, PathType, Dict[str, int]], Tuple[str, PathType]]: """Generate a random contraction and shapes. - Parameters - ---------- - n : int - Number of array arguments. - reg : int - 'Regularity' of the contraction graph. This essentially determines how - many indices each tensor shares with others on average. - n_out : int, optional - Number of output indices (i.e. the number of non-contracted indices). - Defaults to 0, i.e., a contraction resulting in a scalar. - d_min : int, optional - Minimum dimension size. - d_max : int, optional - Maximum dimension size. - seed: int, optional - If not None, seed numpy's random generator with this. - global_dim : bool, optional - Add a global, 'broadcast', dimension to every operand. - return_size_dict : bool, optional - Return the mapping of indices to sizes. - - Returns - ------- - eq : str - The equation string. - shapes : list[tuple[int]] - The array shapes. - size_dict : dict[str, int] - The dict of index sizes, only returned if ``return_size_dict=True``. - - Examples - -------- - >>> eq, shapes = rand_equation(n=10, reg=4, n_out=5, seed=42) - >>> eq - 'oyeqn,tmaq,skpo,vg,hxui,n,fwxmr,hitplcj,kudlgfv,rywjsb->cebda' - - >>> shapes - [(9, 5, 4, 5, 4), - (4, 4, 8, 5), - (9, 4, 6, 9), - (6, 6), - (6, 9, 7, 8), - (4,), - (9, 3, 9, 4, 9), - (6, 8, 4, 6, 8, 6, 3), - (4, 7, 8, 8, 6, 9, 6), - (9, 5, 3, 3, 9, 5)] + Parameters: + n: Number of array arguments. + regularity: 'Regularity' of the contraction graph. This essentially determines how + many indices each tensor shares with others on average. + n_out: Number of output indices (i.e. the number of non-contracted indices). + Defaults to 0, i.e., a contraction resulting in a scalar. + d_min: Minimum dimension size. + d_max: Maximum dimension size. + seed: If not None, seed numpy's random generator with this. + global_dim: Add a global, 'broadcast', dimension to every operand. + return_size_dict: Return the mapping of indices to sizes. + + Returns: + eq: The equation string. + shapes: The array shapes. + size_dict: The dict of index sizes, only returned if ``return_size_dict=True``. + + Examples: + ```python + >>> eq, shapes = rand_equation(n=10, regularity=4, n_out=5, seed=42) + >>> eq + 'oyeqn,tmaq,skpo,vg,hxui,n,fwxmr,hitplcj,kudlgfv,rywjsb->cebda' + + >>> shapes + [(9, 5, 4, 5, 4), + (4, 4, 8, 5), + (9, 4, 6, 9), + (6, 6), + (6, 9, 7, 8), + (4,), + (9, 3, 9, 4, 9), + (6, 8, 4, 6, 8, 6, 3), + (4, 7, 8, 8, 6, 9, 6), + (9, 5, 3, 3, 9, 5)] + ``` """ if seed is not None: np.random.seed(seed) # total number of indices - num_inds = n * reg // 2 + n_out + num_inds = n * regularity // 2 + n_out inputs = ["" for _ in range(n)] output = [] @@ -302,9 +317,11 @@ def gen(): # make the shapes shapes = [tuple(size_dict[ix] for ix in op) for op in inputs] - ret = (eq, shapes) - if return_size_dict: - return ret + (size_dict,) + return ( + eq, + shapes, + size_dict, + ) else: - return ret + return (eq, shapes) diff --git a/opt_einsum/parser.py b/opt_einsum/parser.py index e390ec57..47567ae5 100644 --- a/opt_einsum/parser.py +++ b/opt_einsum/parser.py @@ -224,7 +224,7 @@ def convert_subscripts(old_sub: List[Any], symbol_map: Dict[Any, Any]) -> str: return new_sub -def convert_interleaved_input(operands: Union[List[Any], Tuple[Any]]) -> Tuple[str, List[Any]]: +def convert_interleaved_input(operands: Union[List[Any], Tuple[Any]]) -> Tuple[str, Tuple[ArrayType, ...]]: """Convert 'interleaved' input to standard einsum input.""" tmp_operands = list(operands) operand_list = [] @@ -259,7 +259,7 @@ def convert_interleaved_input(operands: Union[List[Any], Tuple[Any]]) -> Tuple[s subscripts += "->" subscripts += convert_subscripts(output_list, symbol_map) - return subscripts, operands + return subscripts, tuple(operands) def parse_einsum_input(operands: Any, shapes: bool = False) -> Tuple[str, str, List[ArrayType]]: diff --git a/opt_einsum/path_random.py b/opt_einsum/path_random.py index a7ea6587..ae7eff5a 100644 --- a/opt_einsum/path_random.py +++ b/opt_einsum/path_random.py @@ -5,12 +5,12 @@ import functools import heapq import math -import numbers import time from collections import deque +from decimal import Decimal from random import choices as random_choices from random import seed as random_seed -from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple +from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union from . import helpers, paths from .typing import ArrayIndexType, ArrayType, PathType @@ -42,28 +42,26 @@ def trial_fn(r, *trial_args): random number generator. See `RandomGreedy` for an example. - **Parameters:** - - - **max_repeats** - *(int, optional)* The maximum number of repeat trials to have. - - **max_time** - *(float, optional)* The maximum amount of time to run the algorithm for. - - **minimize** - *({'flops', 'size'}, optional)* Whether to favour paths that minimize the total estimated flop-count or - the size of the largest intermediate created. - - **parallel** - *({bool, int, or executor-pool like}, optional)* Whether to parallelize the random trials, by default `False`. If - `True`, use a `concurrent.futures.ProcessPoolExecutor` with the same - number of processes as cores. If an integer is specified, use that many - processes instead. Finally, you can supply a custom executor-pool which - should have an API matching that of the python 3 standard library - module `concurrent.futures`. Namely, a `submit` method that returns - `Future` objects, themselves with `result` and `cancel` methods. - - **pre_dispatch** - *(int, optional)* If running in parallel, how many jobs to pre-dispatch so as to avoid - submitting all jobs at once. Should also be more than twice the number - of workers to avoid under-subscription. Default: 128. - - **Attributes:** - - - **path** - *(list[tuple[int]])* The best path found so far. - - **costs** - *(list[int])* The list of each trial's costs found so far. - - **sizes** - *(list[int])* The list of each trial's largest intermediate size so far. + Parameters: + max_repeats: The maximum number of repeat trials to have. + max_time: The maximum amount of time to run the algorithm for. + minimize: Whether to favour paths that minimize the total estimated flop-count or + the size of the largest intermediate created. + parallel: Whether to parallelize the random trials, by default `False`. If + `True`, use a `concurrent.futures.ProcessPoolExecutor` with the same + number of processes as cores. If an integer is specified, use that many + processes instead. Finally, you can supply a custom executor-pool which + should have an API matching that of the python 3 standard library + module `concurrent.futures`. Namely, a `submit` method that returns + `Future` objects, themselves with `result` and `cancel` methods. + pre_dispatch: If running in parallel, how many jobs to pre-dispatch so as to avoid + submitting all jobs at once. Should also be more than twice the number + of workers to avoid under-subscription. Default: 128. + + Attributes: + path: The best path found so far. + costs: The list of each trial's costs found so far. + sizes: The list of each trial's largest intermediate size so far. """ def __init__( @@ -71,7 +69,7 @@ def __init__( max_repeats: int = 32, max_time: Optional[float] = None, minimize: str = "flops", - parallel: bool = False, + parallel: Union[bool, Decimal, int] = False, pre_dispatch: int = 128, ): @@ -82,7 +80,7 @@ def __init__( self.max_time = max_time self.minimize = minimize self.better = paths.get_better_fn(minimize) - self._parallel = False + self._parallel: Union[bool, Decimal, int] = False self.parallel = parallel self.pre_dispatch = pre_dispatch @@ -100,11 +98,11 @@ def path(self) -> PathType: return paths.ssa_to_linear(self.best["ssa_path"]) @property - def parallel(self) -> bool: + def parallel(self) -> Union[bool, Decimal, int]: return self._parallel @parallel.setter - def parallel(self, parallel: bool) -> None: + def parallel(self, parallel: Union[bool, Decimal, int]) -> None: # shutdown any previous executor if we are managing it if getattr(self, "_managing_executor", False): self._executor.shutdown() @@ -123,10 +121,10 @@ def parallel(self, parallel: bool) -> None: self._managing_executor = True return - if isinstance(parallel, numbers.Number): + if isinstance(parallel, (int, Decimal)): from concurrent.futures import ProcessPoolExecutor - self._executor = ProcessPoolExecutor(parallel) + self._executor = ProcessPoolExecutor(int(parallel)) self._managing_executor = True return @@ -225,26 +223,24 @@ def thermal_chooser(queue, remaining, nbranch=8, temperature=1, rel_temperature= `abs(c_0)` to account for likely fluctuating cost magnitudes during the course of a contraction. - **Parameters:** - - - **queue** - *(list)* The heapified list of candidate contractions. - - **remaining** - *(dict[str, int])* Mapping of remaining inputs' indices to the ssa id. - - **temperature** - *(float, optional)* When choosing a possible contraction, its relative probability will be - proportional to `exp(-cost / temperature)`. Thus the larger - `temperature` is, the further random paths will stray from the normal - 'greedy' path. Conversely, if set to zero, only paths with exactly the - same cost as the best at each step will be explored. - - **rel_temperature** - *(bool, optional)* Whether to normalize the `temperature` at each step to the scale of - the best cost. This is generally beneficial as the magnitude of costs - can vary significantly throughout a contraction. - - **nbranch** - *(int, optional)* How many potential paths to calculate probability for and choose from at each step. - - **Returns:** - - - **cost** - - **k1** - - **k2** - - **k3** + Parameters: + queue: The heapified list of candidate contractions. + remaining: Mapping of remaining inputs' indices to the ssa id. + temperature: When choosing a possible contraction, its relative probability will be + proportional to `exp(-cost / temperature)`. Thus the larger + `temperature` is, the further random paths will stray from the normal + 'greedy' path. Conversely, if set to zero, only paths with exactly the + same cost as the best at each step will be explored. + rel_temperature: Whether to normalize the `temperature` at each step to the scale of + the best cost. This is generally beneficial as the magnitude of costs + can vary significantly throughout a contraction. + nbranch: How many potential paths to calculate probability for and choose from at each step. + + Returns: + cost + k1 + k2 + k3 """ n = 0 choices = [] @@ -332,27 +328,6 @@ def _trial_greedy_ssa_path_and_cost( class RandomGreedy(RandomOptimizer): - """ - - **Parameters:** - - - **cost_fn** - *(callable, optional)* A function that returns a heuristic 'cost' of a potential contraction - with which to sort candidates. Should have signature - `cost_fn(size12, size1, size2, k12, k1, k2)`. - - **temperature** - *(float, optional)* When choosing a possible contraction, its relative probability will be - proportional to `exp(-cost / temperature)`. Thus the larger - `temperature` is, the further random paths will stray from the normal - 'greedy' path. Conversely, if set to zero, only paths with exactly the - same cost as the best at each step will be explored. - - **rel_temperature** - *(bool, optional)* Whether to normalize the ``temperature`` at each step to the scale of - the best cost. This is generally beneficial as the magnitude of costs - can vary significantly throughout a contraction. If False, the - algorithm will end up branching when the absolute cost is low, but - stick to the 'greedy' path when the cost is high - this can also be - beneficial. - - **nbranch** - *(int, optional)* How many potential paths to calculate probability for and choose from at each step. - - **kwargs** - Supplied to RandomOptimizer. - """ def __init__( self, @@ -362,6 +337,25 @@ def __init__( nbranch: int = 8, **kwargs: Any, ): + """ + Parameters: + cost_fn: A function that returns a heuristic 'cost' of a potential contraction + with which to sort candidates. Should have signature + `cost_fn(size12, size1, size2, k12, k1, k2)`. + temperature: When choosing a possible contraction, its relative probability will be + proportional to `exp(-cost / temperature)`. Thus the larger + `temperature` is, the further random paths will stray from the normal + 'greedy' path. Conversely, if set to zero, only paths with exactly the + same cost as the best at each step will be explored. + rel_temperature: Whether to normalize the ``temperature`` at each step to the scale of + the best cost. This is generally beneficial as the magnitude of costs + can vary significantly throughout a contraction. If False, the + algorithm will end up branching when the absolute cost is low, but + stick to the 'greedy' path when the cost is high - this can also be + beneficial. + nbranch: How many potential paths to calculate probability for and choose from at each step. + kwargs: Supplied to RandomOptimizer. + """ self.cost_fn = cost_fn self.temperature = temperature self.rel_temperature = rel_temperature diff --git a/opt_einsum/paths.py b/opt_einsum/paths.py index 761f56af..902ace30 100644 --- a/opt_einsum/paths.py +++ b/opt_einsum/paths.py @@ -15,8 +15,8 @@ import numpy as np -from .helpers import compute_size_by_dict, flop_count -from .typing import ArrayIndexType, PathType +from opt_einsum.helpers import compute_size_by_dict, flop_count +from opt_einsum.typing import ArrayIndexType, PathSearchFunctionType, PathType, TensorShapeType __all__ = [ "optimal", @@ -41,16 +41,15 @@ class PathOptimizer: ```python def __call__(self, inputs, output, size_dict, memory_limit=None): \"\"\" - **Parameters:** - ---------- - inputs : list[set[str]] - The indices of each input array. - outputs : set[str] - The output indices - size_dict : dict[str, int] - The size of each index - memory_limit : int, optional - If given, the maximum allowed memory. + Parameters: + inputs : list[set[str]] + The indices of each input array. + outputs : set[str] + The output indices + size_dict : dict[str, int] + The size of each index + memory_limit : int, optional + If given, the maximum allowed memory. \"\"\" # ... compute path here ... return path @@ -91,12 +90,13 @@ def __call__( def ssa_to_linear(ssa_path: PathType) -> PathType: """ Convert a path with static single assignment ids to a path with recycled - linear ids. For example: + linear ids. - ```python - ssa_to_linear([(0, 3), (2, 4), (1, 5)]) - #> [(0, 3), (1, 2), (0, 1)] - ``` + Example: + ```python + ssa_to_linear([(0, 3), (2, 4), (1, 5)]) + #> [(0, 3), (1, 2), (0, 1)] + ``` """ ids = np.arange(1 + max(map(max, ssa_path)), dtype=np.int32) path = [] @@ -110,12 +110,13 @@ def ssa_to_linear(ssa_path: PathType) -> PathType: def linear_to_ssa(path: PathType) -> PathType: """ Convert a path with recycled linear ids to a path with static single - assignment ids. For example:: + assignment ids. - ```python - linear_to_ssa([(0, 3), (1, 2), (0, 1)]) - #> [(0, 3), (2, 4), (1, 5)] - ``` + Exmaple: + ```python + linear_to_ssa([(0, 3), (1, 2), (0, 1)]) + #> [(0, 3), (2, 4), (1, 5)] + ``` """ num_inputs = sum(map(len, path)) - len(path) + 1 linear_to_ssa = list(range(num_inputs)) @@ -141,21 +142,19 @@ def calc_k12_flops( Calculate the resulting indices and flops for a potential pairwise contraction - used in the recursive (optimal/branch) algorithms. - **Parameters:** - - - **inputs** - *(tuple[frozenset[str]])* The indices of each tensor in this contraction, note this includes - tensors unavailable to contract as static single assignment is used -> - contracted tensors are not removed from the list. - - **output** - *(frozenset[str])* The set of output indices for the whole contraction. - - **remaining** - *(frozenset[int])* The set of indices (corresponding to ``inputs``) of tensors still available to contract. - - **i** - *(int)* Index of potential tensor to contract. - - **j** - *(int)* Index of potential tensor to contract. - - **size_dict : dict[str, int] )* Size mapping of all the indices. - - **Returns:** - - - **k12** - *(frozenset)* The resulting indices of the potential tensor. - - **cost** - *(int)* Estimated flop count of operation. + Parameters: + inputs: The indices of each tensor in this contraction, note this includes + tensors unavailable to contract as static single assignment is used:> + contracted tensors are not removed from the list. + output: The set of output indices for the whole contraction. + remaining: *The set of indices (corresponding to ``inputs``) of tensors still available to contract. + i: Index of potential tensor to contract. + j: Index of potential tensor to contract. + size_dict: Size mapping of all the indices. + + Returns: + k12: The resulting indices of the potential tensor. + cost: Estimated flop count of operation. """ k1, k2 = inputs[i], inputs[j] either = k1 | k2 @@ -193,21 +192,17 @@ def optimal( """ Computes all possible pair contractions in a depth-first recursive manner, sieving results based on `memory_limit` and the best path found so far. - **Returns:** the lowest cost path. This algorithm scales factoriallly with - respect to the elements in the list `input_sets`. - - **Parameters:** - - **inputs** - *(list)* List of sets that represent the lhs side of the einsum subscript. - - **output** - *(set)* Set that represents the rhs side of the overall einsum subscript. - - **size_dict** - *(dictionary)* Dictionary of index sizes. - - **memory_limit** - *(int)* The maximum number of elements in a temporary array. + Parameters: + inputs: List of sets that represent the lhs side of the einsum subscript. + output: Set that represents the rhs side of the overall einsum subscript. + size_dict: Dictionary of index sizes. + memory_limit: The maximum number of elements in a temporary array. - **Returns:** + Returns: + path: The optimal contraction order within the memory limit constraint. - - **path** - *(list)* The optimal contraction order within the memory limit constraint. - - **Examples:** + Examples: ```python isets = [set('abd'), set('ac'), set('bdc')] @@ -321,48 +316,44 @@ def cost_memory_removed_jitter(size12: int, size1: int, size2: int, k12: int, k1 class BranchBound(PathOptimizer): - """ - Explores possible pair contractions in a depth-first recursive manner like - the `optimal` approach, but with extra heuristic early pruning of branches - as well sieving by `memory_limit` and the best path found so far. **Returns:** - the lowest cost path. This algorithm still scales factorially with respect - to the elements in the list `input_sets` if `nbranch` is not set, but it - scales exponentially like `nbranch**len(input_sets)` otherwise. - - **Parameters:** - - - **nbranch** - *(None or int, optional)* How many branches to explore at each contraction step. If None, explore - all possible branches. If an integer, branch into this many paths at - each step. Defaults to None. - - **cutoff_flops_factor** - *(float, optional)* If at any point, a path is doing this much worse than the best path - found so far was, terminate it. The larger this is made, the more paths - will be fully explored and the slower the algorithm. Defaults to 4. - - **minimize** - *({'flops', 'size'}, optional)* Whether to optimize the path with regard primarily to the total - estimated flop-count, or the size of the largest intermediate. The - option not chosen will still be used as a secondary criterion. - - **cost_fn** - *(callable, optional)* A function that returns a heuristic 'cost' of a potential contraction - with which to sort candidates. Should have signature - `cost_fn(size12, size1, size2, k12, k1, k2)`. - """ - def __init__( self, - nbranch=None, - cutoff_flops_factor=4, - minimize="flops", - cost_fn="memory-removed", + nbranch: Optional[int] = None, + cutoff_flops_factor: int = 4, + minimize: str = "flops", + cost_fn: str = "memory-removed", ): + """ + Explores possible pair contractions in a depth-first recursive manner like + the `optimal` approach, but with extra heuristic early pruning of branches + as well sieving by `memory_limit` and the best path found so far. + + + Parameters: + nbranch: How many branches to explore at each contraction step. If None, explore + all possible branches. If an integer, branch into this many paths at + each step. Defaults to None. + cutoff_flops_factor: If at any point, a path is doing this much worse than the best path + found so far was, terminate it. The larger this is made, the more paths + will be fully explored and the slower the algorithm. Defaults to 4. + minimize: Whether to optimize the path with regard primarily to the total + estimated flop-count, or the size of the largest intermediate. The + option not chosen will still be used as a secondary criterion. + cost_fn: A function that returns a heuristic 'cost' of a potential contraction + with which to sort candidates. Should have signature + `cost_fn(size12, size1, size2, k12, k1, k2)`. + """ if (nbranch is not None) and nbranch < 1: raise ValueError(f"The number of branches must be at least one, `nbranch={nbranch}`.") self.nbranch = nbranch self.cutoff_flops_factor = cutoff_flops_factor self.minimize = minimize - self.cost_fn = _COST_FNS.get(cost_fn, cost_fn) + self.cost_fn: Any = _COST_FNS.get(cost_fn, cost_fn) self.better = get_better_fn(minimize) - self.best = {"flops": float("inf"), "size": float("inf")} - self.best_progress = defaultdict(lambda: float("inf")) + self.best: Dict[str, Any] = {"flops": float("inf"), "size": float("inf")} + self.best_progress: Dict[int, float] = defaultdict(lambda: float("inf")) @property def path(self) -> PathType: @@ -377,18 +368,16 @@ def __call__( ) -> PathType: """ - **Parameters:** - - - **input_sets** - *(list)* List of sets that represent the lhs side of the einsum subscript - - **output_set** - *(set)* Set that represents the rhs side of the overall einsum subscript - - **idx_dict** - *(dictionary)* Dictionary of index sizes - - **memory_limit** - *(int)* The maximum number of elements in a temporary array - - **Returns:** + Parameters: + inputs_: List of sets that represent the lhs side of the einsum subscript + output_: Set that represents the rhs side of the overall einsum subscript + size_dict: Dictionary of index sizes + memory_limit: The maximum number of elements in a temporary array - - **path** - *(list)* The contraction order within the memory limit constraint. + Returns: + path: The contraction order within the memory limit constraint. - **Examples:** + Examples: ```python isets = [set('abd'), set('ac'), set('bdc')] @@ -505,7 +494,7 @@ def branch( memory_limit: Optional[int] = None, **optimizer_kwargs: Dict[str, Any], ) -> PathType: - optimizer = BranchBound(**optimizer_kwargs) + optimizer = BranchBound(**optimizer_kwargs) # type: ignore return optimizer(inputs, output, size_dict, memory_limit) @@ -631,7 +620,7 @@ def ssa_greedy_optimize( # Deduplicate shapes by eagerly computing Hadamard products. remaining: Dict[ArrayIndexType, int] = {} # key -> ssa_id ssa_ids = itertools.count(len(fs_inputs)) - ssa_path = [] + ssa_path: List[TensorShapeType] = [] for ssa_id, key in enumerate(fs_inputs): if key in remaining: ssa_path.append((remaining[key], ssa_id)) @@ -749,28 +738,26 @@ def greedy( This algorithm scales quadratically with respect to the maximum number of elements sharing a common dim. - **Parameters:** - - - **inputs** - *(list)* List of sets that represent the lhs side of the einsum subscript - - **output** - *(set)* Set that represents the rhs side of the overall einsum subscript - - **size_dict** - *(dictionary)* Dictionary of index sizes - - **memory_limit** - *(int)* The maximum number of elements in a temporary array - - **choose_fn** - *(callable, optional)* A function that chooses which contraction to perform from the queue - - **cost_fn** - *(callable, optional)* A function that assigns a potential contraction a cost. + Parameters: + inputs: List of sets that represent the lhs side of the einsum subscript + output: Set that represents the rhs side of the overall einsum subscript + size_dict: Dictionary of index sizes + memory_limit: The maximum number of elements in a temporary array + choose_fn: A function that chooses which contraction to perform from the queue + cost_fn: A function that assigns a potential contraction a cost. - **Returns:** + Returns: + path: The contraction order (a list of tuples of ints). - - **path** - *(list)* The contraction order (a list of tuples of ints). + Examples: - **Examples:** - - ```python - isets = [set('abd'), set('ac'), set('bdc')] - oset = set('') - idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4} - greedy(isets, oset, idx_sizes) - #> [(0, 2), (0, 1)] - ``` + ```python + isets = [set('abd'), set('ac'), set('bdc')] + oset = set('') + idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4} + greedy(isets, oset, idx_sizes) + #> [(0, 2), (0, 1)] + ``` """ if memory_limit not in _UNLIMITED_MEM: return branch(inputs, output, size_dict, memory_limit, nbranch=1, cost_fn=cost_fn) # type: ignore @@ -788,20 +775,17 @@ def _tree_to_sequence(tree: Tuple[Any, ...]) -> PathType: contractions are commutative, e.g. (j, k, l) = (k, l, j). Note that in general, solutions are not unique. - **Parameters:** + Parameters: + c: Contraction tree - - **c** - *(tuple or int)* Contraction tree + Returns: + path: Contraction path - **Returns:** - - - **path** - *(list[set[int]])* Contraction path - - **Examples:** - - ```python - _tree_to_sequence(((1,2),(0,(4,5,3)))) - #> [(1, 2), (1, 2, 3), (0, 2), (0, 1)] - ``` + Examples: + ```python + _tree_to_sequence(((1,2),(0,(4,5,3)))) + #> [(1, 2), (1, 2, 3), (0, 2), (0, 1)] + ``` """ # ((1,2),(0,(4,5,3))) --> [(1, 2), (1, 2, 3), (0, 2), (0, 1)] @@ -843,23 +827,21 @@ def _find_disconnected_subgraphs(inputs: List[FrozenSet[int]], output: FrozenSet connected if they share summation indices. Note: Disconnected subgraphs can be contracted independently before forming outer products. - **Parameters:** - - **inputs** - *(list[set])* List of sets that represent the lhs side of the einsum subscript - - **output** - *(set)* Set that represents the rhs side of the overall einsum subscript - - **Returns:** + Parameters: + inputs: List of sets that represent the lhs side of the einsum subscript + output: Set that represents the rhs side of the overall einsum subscript - - **subgraphs** - *(list[set[int]])* List containing sets of indices for each subgraph + Returns: + subgraphs: List containing sets of indices for each subgraph - **Examples:** - - ```python - _find_disconnected_subgraphs([set("ab"), set("c"), set("ad")], set("bd")) - #> [{0, 2}, {1}] + Examples: + ```python + _find_disconnected_subgraphs([set("ab"), set("c"), set("ad")], set("bd")) + #> [{0, 2}, {1}] - _find_disconnected_subgraphs([set("ab"), set("c"), set("ad")], set("abd")) - #> [{0}, {1}, {2}] - ``` + _find_disconnected_subgraphs([set("ab"), set("c"), set("ad")], set("abd")) + #> [{0}, {1}, {2}] + ``` """ subgraphs = [] @@ -1136,31 +1118,28 @@ class DynamicProgramming(PathOptimizer): linearly in the number of disconnected subgraphs and only exponentially with the number of inputs per subgraph. - **Parameters:** - - - **minimize** - *({'flops', 'size', 'write', 'combo', 'limit', callable}, optional)* What to minimize: - - - 'flops' - minimize the number of flops - - 'size' - minimize the size of the largest intermediate - - 'write' - minimize the size of all intermediate tensors - - 'combo' - minimize `flops + alpha * write` summed over intermediates, a default ratio of alpha=64 - is used, or it can be customized with `f'combo-{alpha}'` - - 'limit' - minimize `max(flops, alpha * write)` summed over intermediates, a default ratio of alpha=64 - is used, or it can be customized with `f'limit-{alpha}'` - - callable - a custom local cost function - - - **cost_cap** - *({True, False, int}, optional)* How to implement cost-capping: - - - True - iteratively increase the cost-cap - - False - implement no cost-cap at all - - int - use explicit cost cap - - - **search_outer** - *(bool, optional)* In rare circumstances the optimal contraction may involve an outer - product, this option allows searching such contractions but may well - slow down the path finding considerably on all but very small graphs. + Parameters: + minimize: What to minimize: + - 'flops' - minimize the number of flops + - 'size' - minimize the size of the largest intermediate + - 'write' - minimize the size of all intermediate tensors + - 'combo' - minimize `flops + alpha * write` summed over intermediates, a default ratio of alpha=64 + is used, or it can be customized with `f'combo-{alpha}'` + - 'limit' - minimize `max(flops, alpha * write)` summed over intermediates, a default ratio of alpha=64 + is used, or it can be customized with `f'limit-{alpha}'` + - callable - a custom local cost function + + cost_cap: How to implement cost-capping: + - True - iteratively increase the cost-cap + - False - implement no cost-cap at all + - int - use explicit cost cap + + search_outer: In rare circumstances the optimal contraction may involve an outer + product, this option allows searching such contractions but may well + slow down the path finding considerably on all but very small graphs. """ - def __init__(self, minimize: str = "flops", cost_cap: bool = True, search_outer: bool = False) -> None: + def __init__(self, minimize: str = "flops", cost_cap: Union[bool, int] = True, search_outer: bool = False) -> None: self.minimize = minimize self.search_outer = search_outer self.cost_cap = cost_cap @@ -1170,40 +1149,37 @@ def __call__( inputs_: List[ArrayIndexType], output_: ArrayIndexType, size_dict_: Dict[str, int], - memory_limit: Optional[int] = None, + memory_limit_: Optional[int] = None, ) -> PathType: """ - **Parameters:** - - - **inputs** - *(list)* List of sets that represent the lhs side of the einsum subscript - - **output** - *(set)* Set that represents the rhs side of the overall einsum subscript - - **size_dict** - *(dictionary)* Dictionary of index sizes - - **memory_limit** - *(int)* The maximum number of elements in a temporary array - - **Returns:** - - - **path** - *(list)* The contraction order (a list of tuples of ints). - - **Examples:** - - ```python - n_in = 3 # exponential scaling - n_out = 2 # linear scaling - s = dict() - i_all = [] - for _ in range(n_out): - i = [set() for _ in range(n_in)] - for j in range(n_in): - for k in range(j+1, n_in): - c = oe.get_symbol(len(s)) - i[j].add(c) - i[k].add(c) - s[c] = 2 - i_all.extend(i) - o = DynamicProgramming() - o(i_all, set(), s) - #> [(1, 2), (0, 4), (1, 2), (0, 2), (0, 1)] - ``` + Parameters: + inputs_: List of sets that represent the lhs side of the einsum subscript + output_: Set that represents the rhs side of the overall einsum subscript + size_dict_: Dictionary of index sizes + memory_limit_: The maximum number of elements in a temporary array + + Returns: + path: The contraction order (a list of tuples of ints). + + Examples: + ```python + n_in = 3 # exponential scaling + n_out = 2 # linear scaling + s = dict() + i_all = [] + for _ in range(n_out): + i = [set() for _ in range(n_in)] + for j in range(n_in): + for k in range(j+1, n_in): + c = oe.get_symbol(len(s)) + i[j].add(c) + i[k].add(c) + s[c] = 2 + i_all.extend(i) + o = DynamicProgramming() + o(i_all, set(), s) + #> [(1, 2), (0, 4), (1, 2), (0, 2), (0, 1)] + ``` """ _check_contraction, naive_scale = _parse_minimize(self.minimize) _check_outer = (lambda x: True) if self.search_outer else (lambda x: x) @@ -1302,7 +1278,7 @@ def __call__( all_tensors, inputs, i1_cut_i2_wo_output, - memory_limit, + memory_limit_, contract1, contract2, ) @@ -1391,7 +1367,6 @@ def auto_hq( return _AUTO_HQ_CHOICES.get(N, random_greedy_128)(inputs, output, size_dict, memory_limit) -PathSearchFunctionType = Callable[[List[ArrayIndexType], ArrayIndexType, Dict[str, int], Optional[int]], PathType] _PATH_OPTIONS: Dict[str, PathSearchFunctionType] = { "auto": auto, "auto-hq": auto_hq, diff --git a/opt_einsum/tests/test_backends.py b/opt_einsum/tests/test_backends.py index c6e02f6b..3481f86d 100644 --- a/opt_einsum/tests/test_backends.py +++ b/opt_einsum/tests/test_backends.py @@ -1,3 +1,5 @@ +from typing import Set + import numpy as np import pytest @@ -71,7 +73,7 @@ @pytest.mark.skipif(not found_tensorflow, reason="Tensorflow not installed.") @pytest.mark.parametrize("string", tests) -def test_tensorflow(string): +def test_tensorflow(string: str) -> None: views = helpers.build_views(string) ein = contract(string, *views, optimize=False, use_blas=False) opt = np.empty_like(ein) @@ -93,7 +95,7 @@ def test_tensorflow(string): @pytest.mark.skipif(not found_tensorflow, reason="Tensorflow not installed.") @pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}]) -def test_tensorflow_with_constants(constants): +def test_tensorflow_with_constants(constants: Set[int]) -> None: eq = "ij,jk,kl->li" shapes = (2, 3), (3, 4), (4, 5) (non_const,) = {0, 1, 2} - constants @@ -122,7 +124,7 @@ def test_tensorflow_with_constants(constants): @pytest.mark.skipif(not found_tensorflow, reason="Tensorflow not installed.") @pytest.mark.parametrize("string", tests) -def test_tensorflow_with_sharing(string): +def test_tensorflow_with_sharing(string: str) -> None: views = helpers.build_views(string) ein = contract(string, *views, optimize=False, use_blas=False) @@ -147,7 +149,7 @@ def test_tensorflow_with_sharing(string): @pytest.mark.skipif(not found_theano, reason="Theano not installed.") @pytest.mark.parametrize("string", tests) -def test_theano(string): +def test_theano(string: str) -> None: views = helpers.build_views(string) ein = contract(string, *views, optimize=False, use_blas=False) shps = [v.shape for v in views] @@ -165,7 +167,7 @@ def test_theano(string): @pytest.mark.skipif(not found_theano, reason="theano not installed.") @pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}]) -def test_theano_with_constants(constants): +def test_theano_with_constants(constants: Set[int]) -> None: eq = "ij,jk,kl->li" shapes = (2, 3), (3, 4), (4, 5) (non_const,) = {0, 1, 2} - constants @@ -191,7 +193,7 @@ def test_theano_with_constants(constants): @pytest.mark.skipif(not found_theano, reason="Theano not installed.") @pytest.mark.parametrize("string", tests) -def test_theano_with_sharing(string): +def test_theano_with_sharing(string: str) -> None: views = helpers.build_views(string) ein = contract(string, *views, optimize=False, use_blas=False) @@ -214,7 +216,7 @@ def test_theano_with_sharing(string): @pytest.mark.skipif(not found_cupy, reason="Cupy not installed.") @pytest.mark.parametrize("string", tests) -def test_cupy(string): # pragma: no cover +def test_cupy(string: str) -> None: # pragma: no cover views = helpers.build_views(string) ein = contract(string, *views, optimize=False, use_blas=False) shps = [v.shape for v in views] @@ -233,7 +235,7 @@ def test_cupy(string): # pragma: no cover @pytest.mark.skipif(not found_cupy, reason="Cupy not installed.") @pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}]) -def test_cupy_with_constants(constants): # pragma: no cover +def test_cupy_with_constants(constants: Set[int]) -> None: # pragma: no cover eq = "ij,jk,kl->li" shapes = (2, 3), (3, 4), (4, 5) (non_const,) = {0, 1, 2} - constants @@ -261,7 +263,7 @@ def test_cupy_with_constants(constants): # pragma: no cover @pytest.mark.skipif(not found_jax, reason="jax not installed.") @pytest.mark.parametrize("string", tests) -def test_jax(string): # pragma: no cover +def test_jax(string: str) -> None: # pragma: no cover views = helpers.build_views(string) ein = contract(string, *views, optimize=False, use_blas=False) shps = [v.shape for v in views] @@ -275,7 +277,7 @@ def test_jax(string): # pragma: no cover @pytest.mark.skipif(not found_jax, reason="jax not installed.") @pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}]) -def test_jax_with_constants(constants): # pragma: no cover +def test_jax_with_constants(constants: Set[int]) -> None: # pragma: no cover eq = "ij,jk,kl->li" shapes = (2, 3), (3, 4), (4, 5) (non_const,) = {0, 1, 2} - constants @@ -294,7 +296,7 @@ def test_jax_with_constants(constants): # pragma: no cover @pytest.mark.skipif(not found_jax, reason="jax not installed.") -def test_jax_jit_gradient(): +def test_jax_jit_gradient() -> None: eq = "ij,jk,kl->" shapes = (2, 3), (3, 4), (4, 2) views = [np.random.randn(*s) for s in shapes] @@ -317,7 +319,7 @@ def test_jax_jit_gradient(): @pytest.mark.skipif(not found_autograd, reason="autograd not installed.") -def test_autograd_gradient(): +def test_autograd_gradient() -> None: eq = "ij,jk,kl->" shapes = (2, 3), (3, 4), (4, 2) views = [np.random.randn(*s) for s in shapes] @@ -336,7 +338,7 @@ def test_autograd_gradient(): @pytest.mark.parametrize("string", tests) -def test_dask(string): +def test_dask(string: str) -> None: da = pytest.importorskip("dask.array") views = helpers.build_views(string) @@ -360,7 +362,7 @@ def test_dask(string): @pytest.mark.parametrize("string", tests) -def test_sparse(string): +def test_sparse(string: str) -> None: sparse = pytest.importorskip("sparse") views = helpers.build_views(string) @@ -396,7 +398,7 @@ def test_sparse(string): @pytest.mark.skipif(not found_torch, reason="Torch not installed.") @pytest.mark.parametrize("string", tests) -def test_torch(string): +def test_torch(string: str) -> None: views = helpers.build_views(string) ein = contract(string, *views, optimize=False, use_blas=False) @@ -416,7 +418,7 @@ def test_torch(string): @pytest.mark.skipif(not found_torch, reason="Torch not installed.") @pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}]) -def test_torch_with_constants(constants): +def test_torch_with_constants(constants: Set[int]) -> None: eq = "ij,jk,kl->li" shapes = (2, 3), (3, 4), (4, 5) (non_const,) = {0, 1, 2} - constants @@ -442,7 +444,7 @@ def test_torch_with_constants(constants): assert np.allclose(res_exp, res_got3) -def test_auto_backend_custom_array_no_tensordot(): +def test_auto_backend_custom_array_no_tensordot() -> None: x = Shaped((1, 2, 3)) # Shaped is an array-like object defined by opt_einsum - which has no TDOT assert infer_backend(x) == "opt_einsum" @@ -451,7 +453,7 @@ def test_auto_backend_custom_array_no_tensordot(): @pytest.mark.parametrize("string", tests) -def test_object_arrays_backend(string): +def test_object_arrays_backend(string: str) -> None: views = helpers.build_views(string) ein = contract(string, *views, optimize=False, use_blas=False) assert ein.dtype != object diff --git a/opt_einsum/tests/test_blas.py b/opt_einsum/tests/test_blas.py index abab77f5..e72b8a56 100644 --- a/opt_einsum/tests/test_blas.py +++ b/opt_einsum/tests/test_blas.py @@ -2,6 +2,8 @@ Tests the BLAS capability for the opt_einsum module. """ +from typing import Any + import numpy as np import pytest @@ -59,13 +61,13 @@ @pytest.mark.parametrize("inp,benchmark", blas_tests) -def test_can_blas(inp, benchmark): +def test_can_blas(inp: Any, benchmark: bool) -> None: result = blas.can_blas(*inp) assert result == benchmark @pytest.mark.parametrize("inp,benchmark", blas_tests) -def test_tensor_blas(inp, benchmark): +def test_tensor_blas(inp: Any, benchmark: bool) -> None: # Weed out non-blas cases if benchmark is False: @@ -83,17 +85,18 @@ def test_tensor_blas(inp, benchmark): einsum_result = np.einsum(einsum_str, view_left, view_right) blas_result = blas.tensor_blas(view_left, tensor_strs[0], view_right, tensor_strs[1], output, reduced_idx) - assert np.allclose(einsum_result, blas_result) + np.testing.assert_allclose(einsum_result, blas_result) -def test_blas_out(): +def test_blas_out() -> None: a = np.random.rand(4, 4) b = np.random.rand(4, 4) c = np.random.rand(4, 4) d = np.empty((4, 4)) contract("ij,jk->ik", a, b, out=d) + np.testing.assert_allclose(d, np.dot(a, b)) assert np.allclose(d, np.dot(a, b)) contract("ij,jk,kl->il", a, b, c, out=d) - assert np.allclose(d, np.dot(a, b).dot(c)) + np.testing.assert_allclose(d, np.dot(a, b).dot(c)) diff --git a/opt_einsum/tests/test_contract.py b/opt_einsum/tests/test_contract.py index ef3ee5de..8e7cec10 100644 --- a/opt_einsum/tests/test_contract.py +++ b/opt_einsum/tests/test_contract.py @@ -2,11 +2,14 @@ Tets a series of opt_einsum contraction paths to ensure the results are the same for different paths """ +from typing import Any, List + import numpy as np import pytest from opt_einsum import contract, contract_expression, contract_path, helpers from opt_einsum.paths import _PATH_OPTIONS, linear_to_ssa, ssa_to_linear +from opt_einsum.typing import OptimizeKind tests = [ # Test scalar-like operations @@ -95,7 +98,7 @@ @pytest.mark.parametrize("string", tests) @pytest.mark.parametrize("optimize", _PATH_OPTIONS) -def test_compare(optimize, string): +def test_compare(optimize: OptimizeKind, string: str) -> None: views = helpers.build_views(string) ein = contract(string, *views, optimize=False, use_blas=False) @@ -104,7 +107,7 @@ def test_compare(optimize, string): @pytest.mark.parametrize("string", tests) -def test_drop_in_replacement(string): +def test_drop_in_replacement(string: str) -> None: views = helpers.build_views(string) opt = contract(string, *views) assert np.allclose(opt, np.einsum(string, *views)) @@ -112,7 +115,7 @@ def test_drop_in_replacement(string): @pytest.mark.parametrize("string", tests) @pytest.mark.parametrize("optimize", _PATH_OPTIONS) -def test_compare_greek(optimize, string): +def test_compare_greek(optimize: OptimizeKind, string: str) -> None: views = helpers.build_views(string) ein = contract(string, *views, optimize=False, use_blas=False) @@ -126,7 +129,7 @@ def test_compare_greek(optimize, string): @pytest.mark.parametrize("string", tests) @pytest.mark.parametrize("optimize", _PATH_OPTIONS) -def test_compare_blas(optimize, string): +def test_compare_blas(optimize: OptimizeKind, string: str) -> None: views = helpers.build_views(string) ein = contract(string, *views, optimize=False) @@ -136,7 +139,7 @@ def test_compare_blas(optimize, string): @pytest.mark.parametrize("string", tests) @pytest.mark.parametrize("optimize", _PATH_OPTIONS) -def test_compare_blas_greek(optimize, string): +def test_compare_blas_greek(optimize: OptimizeKind, string: str) -> None: views = helpers.build_views(string) ein = contract(string, *views, optimize=False) @@ -148,7 +151,7 @@ def test_compare_blas_greek(optimize, string): assert np.allclose(ein, opt) -def test_some_non_alphabet_maintains_order(): +def test_some_non_alphabet_maintains_order() -> None: # 'c beta a' should automatically go to -> 'a c beta' string = "c" + chr(ord("b") + 848) + "a" # but beta will be temporarily replaced with 'b' for which 'cba->abc' @@ -169,7 +172,7 @@ def test_printing(): @pytest.mark.parametrize("optimize", _PATH_OPTIONS) @pytest.mark.parametrize("use_blas", [False, True]) @pytest.mark.parametrize("out_spec", [False, True]) -def test_contract_expressions(string, optimize, use_blas, out_spec): +def test_contract_expressions(string: str, optimize: OptimizeKind, use_blas: bool, out_spec: bool) -> None: views = helpers.build_views(string) shapes = [view.shape if hasattr(view, "shape") else tuple() for view in views] expected = contract(string, *views, optimize=False, use_blas=False) @@ -189,9 +192,9 @@ def test_contract_expressions(string, optimize, use_blas, out_spec): assert string in expr.__str__() -def test_contract_expression_interleaved_input(): +def test_contract_expression_interleaved_input() -> None: x, y, z = (np.random.randn(2, 2) for _ in "xyz") - expected = np.einsum(x, [0, 1], y, [1, 2], z, [2, 3], [3, 0]) + expected = np.einsum(x, [0, 1], y, [1, 2], z, [2, 3], [3, 0]) # type: ignore xshp, yshp, zshp = ((2, 2) for _ in "xyz") expr = contract_expression(xshp, [0, 1], yshp, [1, 2], zshp, [2, 3], [3, 0]) out = expr(x, y, z) @@ -210,13 +213,13 @@ def test_contract_expression_interleaved_input(): ("ab,bc,cd", [0, 1]), ], ) -def test_contract_expression_with_constants(string, constants): +def test_contract_expression_with_constants(string: str, constants: List[int]) -> None: views = helpers.build_views(string) expected = contract(string, *views, optimize=False, use_blas=False) shapes = [view.shape if hasattr(view, "shape") else tuple() for view in views] - expr_args = [] + expr_args: List[Any] = [] ctrc_args = [] for i, (shape, view) in enumerate(zip(shapes, views)): if i in constants: @@ -235,7 +238,7 @@ def test_contract_expression_with_constants(string, constants): @pytest.mark.parametrize("reg", [2, 3]) @pytest.mark.parametrize("n_out", [0, 2, 4]) @pytest.mark.parametrize("global_dim", [False, True]) -def test_rand_equation(optimize, n, reg, n_out, global_dim): +def test_rand_equation(optimize: OptimizeKind, n: int, reg: int, n_out: int, global_dim: bool) -> None: eq, _, size_dict = helpers.rand_equation(n, reg, n_out, d_min=2, d_max=5, seed=42, return_size_dict=True) views = helpers.build_views(eq, size_dict) @@ -246,7 +249,7 @@ def test_rand_equation(optimize, n, reg, n_out, global_dim): @pytest.mark.parametrize("equation", tests) -def test_linear_vs_ssa(equation): +def test_linear_vs_ssa(equation: str) -> None: views = helpers.build_views(equation) linear_path, _ = contract_path(equation, *views) ssa_path = linear_to_ssa(linear_path) @@ -254,7 +257,7 @@ def test_linear_vs_ssa(equation): assert linear_path2 == linear_path -def test_contract_path_supply_shapes(): +def test_contract_path_supply_shapes() -> None: eq = "ab,bc,cd" shps = [(2, 3), (3, 4), (4, 5)] contract_path(eq, *shps, shapes=True) diff --git a/opt_einsum/tests/test_edge_cases.py b/opt_einsum/tests/test_edge_cases.py index 7851b237..80942495 100644 --- a/opt_einsum/tests/test_edge_cases.py +++ b/opt_einsum/tests/test_edge_cases.py @@ -5,10 +5,11 @@ import numpy as np import pytest -from opt_einsum import contract, contract_path, contract_expression +from opt_einsum import contract, contract_expression, contract_path +from opt_einsum.typing import PathType -def test_contract_expression_checks(): +def test_contract_expression_checks() -> None: # check optimize needed with pytest.raises(ValueError): contract_expression("ab,bc->ac", (2, 3), (3, 4), optimize=False) @@ -47,12 +48,12 @@ def test_contract_expression_checks(): assert "Internal error while evaluating `ContractExpression`" in str(err.value) # should only be able to specify out - with pytest.raises(ValueError) as err: - expr(np.random.rand(2, 3), np.random.rand(3, 4), order="F") - assert "only valid keyword arguments to a `ContractExpression`" in str(err.value) + with pytest.raises(TypeError) as err_type: + expr(np.random.rand(2, 3), np.random.rand(3, 4), order="F") # type: ignore + assert "got an unexpected keyword" in str(err_type.value) -def test_broadcasting_contraction(): +def test_broadcasting_contraction() -> None: a = np.random.rand(1, 5, 4) b = np.random.rand(4, 6) c = np.random.rand(5, 6) @@ -71,7 +72,7 @@ def test_broadcasting_contraction(): assert np.allclose(opt, result) -def test_broadcasting_contraction2(): +def test_broadcasting_contraction2() -> None: a = np.random.rand(1, 1, 5, 4) b = np.random.rand(4, 6) c = np.random.rand(5, 6) @@ -90,7 +91,7 @@ def test_broadcasting_contraction2(): assert np.allclose(opt, result) -def test_broadcasting_contraction3(): +def test_broadcasting_contraction3() -> None: a = np.random.rand(1, 5, 4) b = np.random.rand(4, 1, 6) c = np.random.rand(5, 6) @@ -102,7 +103,7 @@ def test_broadcasting_contraction3(): assert np.allclose(ein, opt) -def test_broadcasting_contraction4(): +def test_broadcasting_contraction4() -> None: a = np.arange(64).reshape(2, 4, 8) ein = contract("obk,ijk->ioj", a, a, optimize=False) opt = contract("obk,ijk->ioj", a, a, optimize=True) @@ -110,7 +111,7 @@ def test_broadcasting_contraction4(): assert np.allclose(ein, opt) -def test_can_blas_on_healed_broadcast_dimensions(): +def test_can_blas_on_healed_broadcast_dimensions() -> None: expr = contract_expression("ab,bc,bd->acd", (5, 4), (1, 5), (4, 20)) # first contraction involves broadcasting assert expr.contraction_list[0][2] == "bc,ab->bca" @@ -120,10 +121,10 @@ def test_can_blas_on_healed_broadcast_dimensions(): assert expr.contraction_list[1][-1] == "GEMM" -def test_pathinfo_for_empty_contraction(): +def test_pathinfo_for_empty_contraction() -> None: eq = "->" arrays = (1.0,) - path = [] + path: PathType = [] _, info = contract_path(eq, *arrays, optimize=path) # some info is built lazily, so check repr assert repr(info) diff --git a/opt_einsum/tests/test_input.py b/opt_einsum/tests/test_input.py index 17dc0ac7..6f1ecc13 100644 --- a/opt_einsum/tests/test_input.py +++ b/opt_einsum/tests/test_input.py @@ -2,16 +2,19 @@ Tests the input parsing for opt_einsum. Duplicates the np.einsum input tests. """ +from typing import Any + import numpy as np import pytest from opt_einsum import contract, contract_path +from opt_einsum.typing import ArrayType -def build_views(string): +def build_views(string: str) -> list[ArrayType]: chars = "abcdefghij" - sizes = np.array([2, 3, 4, 5, 4, 3, 2, 6, 5, 4]) - sizes = {c: s for c, s in zip(chars, sizes)} + sizes_array = np.array([2, 3, 4, 5, 4, 3, 2, 6, 5, 4]) + sizes = {c: s for c, s in zip(chars, sizes_array)} views = [] @@ -24,7 +27,7 @@ def build_views(string): return views -def test_type_errors(): +def test_type_errors() -> None: # subscripts must be a string with pytest.raises(TypeError): contract(0, 0) @@ -36,11 +39,11 @@ def test_type_errors(): # order parameter must be a valid order # changed in Numpy 1.19, see https://github.com/numpy/numpy/commit/35b0a051c19265f5643f6011ee11e31d30c8bc4c with pytest.raises((TypeError, ValueError)): - contract("", 0, order="W") + contract("", 0, order="W") # type: ignore # casting parameter must be a valid casting with pytest.raises(ValueError): - contract("", 0, casting="blah") + contract("", 0, casting="blah") # type: ignore # dtype parameter must be a valid dtype with pytest.raises(TypeError): @@ -82,7 +85,7 @@ def test_type_errors(): @pytest.mark.parametrize("contract_fn", [contract, contract_path]) -def test_value_errors(contract_fn): +def test_value_errors(contract_fn: Any) -> None: with pytest.raises(ValueError): contract_fn("") @@ -181,7 +184,7 @@ def test_value_errors(contract_fn): "...a,...b", ], ) -def test_compare(string): +def test_compare(string: str) -> None: views = build_views(string) ein = contract(string, *views, optimize=False) @@ -192,7 +195,7 @@ def test_compare(string): assert np.allclose(ein, opt) -def test_ellipse_input1(): +def test_ellipse_input1() -> None: string = "...a->..." views = build_views(string) @@ -201,7 +204,7 @@ def test_ellipse_input1(): assert np.allclose(ein, opt) -def test_ellipse_input2(): +def test_ellipse_input2() -> None: string = "...a" views = build_views(string) @@ -210,7 +213,7 @@ def test_ellipse_input2(): assert np.allclose(ein, opt) -def test_ellipse_input3(): +def test_ellipse_input3() -> None: string = "...a->...a" views = build_views(string) @@ -219,7 +222,7 @@ def test_ellipse_input3(): assert np.allclose(ein, opt) -def test_ellipse_input4(): +def test_ellipse_input4() -> None: string = "...b,...a->..." views = build_views(string) @@ -228,7 +231,7 @@ def test_ellipse_input4(): assert np.allclose(ein, opt) -def test_singleton_dimension_broadcast(): +def test_singleton_dimension_broadcast() -> None: # singleton dimensions broadcast (gh-10343) p = np.ones((10, 2)) q = np.ones((1, 2)) @@ -248,7 +251,7 @@ def test_singleton_dimension_broadcast(): assert np.allclose(res2, np.full((1, 5), 5)) -def test_large_int_input_format(): +def test_large_int_input_format() -> None: string = "ab,bc,cd" x, y, z = build_views(string) string_output = contract(string, x, y, z) @@ -259,7 +262,7 @@ def test_large_int_input_format(): assert np.allclose(transpose_output, x.T) -def test_hashable_object_input_format(): +def test_hashable_object_input_format() -> None: string = "ab,bc,cd" x, y, z = build_views(string) string_output = contract(string, x, y, z) diff --git a/opt_einsum/tests/test_parser.py b/opt_einsum/tests/test_parser.py index 0f458367..d582ca4d 100644 --- a/opt_einsum/tests/test_parser.py +++ b/opt_einsum/tests/test_parser.py @@ -8,7 +8,7 @@ from opt_einsum.parser import get_symbol, parse_einsum_input, possibly_convert_to_numpy -def test_get_symbol(): +def test_get_symbol() -> None: assert get_symbol(2) == "c" assert get_symbol(200000) == "\U00031540" # Ensure we skip surrogates '[\uD800-\uDFFF]' @@ -17,7 +17,7 @@ def test_get_symbol(): assert get_symbol(57343) == "\ue7ff" -def test_parse_einsum_input(): +def test_parse_einsum_input() -> None: eq = "ab,bc,cd" ops = [np.random.rand(2, 3), np.random.rand(3, 4), np.random.rand(4, 5)] input_subscripts, output_subscript, operands = parse_einsum_input([eq, *ops]) @@ -26,7 +26,7 @@ def test_parse_einsum_input(): assert operands == ops -def test_parse_einsum_input_shapes_error(): +def test_parse_einsum_input_shapes_error() -> None: eq = "ab,bc,cd" ops = [np.random.rand(2, 3), np.random.rand(3, 4), np.random.rand(4, 5)] @@ -34,7 +34,7 @@ def test_parse_einsum_input_shapes_error(): _ = parse_einsum_input([eq, *ops], shapes=True) -def test_parse_einsum_input_shapes(): +def test_parse_einsum_input_shapes() -> None: eq = "ab,bc,cd" shps = [(2, 3), (3, 4), (4, 5)] input_subscripts, output_subscript, operands = parse_einsum_input([eq, *shps], shapes=True) diff --git a/opt_einsum/tests/test_paths.py b/opt_einsum/tests/test_paths.py index eb4b18bb..4566e8bd 100644 --- a/opt_einsum/tests/test_paths.py +++ b/opt_einsum/tests/test_paths.py @@ -5,11 +5,13 @@ import itertools import sys +from typing import Any, Dict, List, Optional import numpy as np import pytest import opt_einsum as oe +from opt_einsum.typing import ArrayIndexType, OptimizeKind, PathType, TensorShapeType explicit_path_tests = { "GEMM1": ( @@ -62,7 +64,7 @@ ] -def check_path(test_output, benchmark, bypass=False): +def check_path(test_output: PathType, benchmark: PathType, bypass: bool = False) -> bool: if not isinstance(test_output, list): return False @@ -72,17 +74,17 @@ def check_path(test_output, benchmark, bypass=False): ret = True for pos in range(len(test_output)): ret &= isinstance(test_output[pos], tuple) - ret &= test_output[pos] == benchmark[pos] + ret &= test_output[pos] == list(benchmark)[pos] return ret -def assert_contract_order(func, test_data, max_size, benchmark): +def assert_contract_order(func: Any, test_data: Any, max_size: int, benchmark: PathType) -> None: test_output = func(test_data[0], test_data[1], test_data[2], max_size) assert check_path(test_output, benchmark) -def test_size_by_dict(): +def test_size_by_dict() -> None: sizes_dict = {} for ind, val in zip("abcdez", [2, 5, 9, 11, 13, 0]): @@ -102,7 +104,7 @@ def test_size_by_dict(): assert 12870 == path_func("abcde", sizes_dict) -def test_flop_cost(): +def test_flop_cost() -> None: size_dict = {v: 10 for v in "abcdef"} @@ -124,17 +126,17 @@ def test_flop_cost(): assert 2000 == oe.helpers.flop_count("abc", True, 2, size_dict) -def test_bad_path_option(): +def test_bad_path_option() -> None: with pytest.raises(KeyError): - oe.contract("a,b,c", [1], [2], [3], optimize="optimall") + oe.contract("a,b,c", [1], [2], [3], optimize="optimall") # type: ignore -def test_explicit_path(): +def test_explicit_path() -> None: x = oe.contract("a,b,c", [1], [2], [3], optimize=[(1, 2), (0, 1)]) assert x.item() == 6 -def test_path_optimal(): +def test_path_optimal() -> None: test_func = oe.paths.optimal @@ -143,7 +145,7 @@ def test_path_optimal(): assert_contract_order(test_func, test_data, 0, [(0, 1, 2)]) -def test_path_greedy(): +def test_path_greedy() -> None: test_func = oe.paths.greedy @@ -152,7 +154,7 @@ def test_path_greedy(): assert_contract_order(test_func, test_data, 0, [(0, 1, 2)]) -def test_memory_paths(): +def test_memory_paths() -> None: expression = "abc,bdef,fghj,cem,mhk,ljk->adgl" @@ -174,7 +176,7 @@ def test_memory_paths(): @pytest.mark.parametrize("alg,expression,order", path_edge_tests) -def test_path_edge_cases(alg, expression, order): +def test_path_edge_cases(alg: OptimizeKind, expression: str, order: PathType) -> None: views = oe.helpers.build_views(expression) # Test tiny memory limit @@ -184,7 +186,7 @@ def test_path_edge_cases(alg, expression, order): @pytest.mark.parametrize("expression,order", path_scalar_tests) @pytest.mark.parametrize("alg", oe.paths._PATH_OPTIONS) -def test_path_scalar_cases(alg, expression, order): +def test_path_scalar_cases(alg: OptimizeKind, expression: str, order: PathType) -> None: views = oe.helpers.build_views(expression) # Test tiny memory limit @@ -193,7 +195,7 @@ def test_path_scalar_cases(alg, expression, order): assert len(path_ret[0]) == order -def test_optimal_edge_cases(): +def test_optimal_edge_cases() -> None: # Edge test5 expression = "a,ac,ab,ad,cd,bd,bc->" @@ -205,7 +207,7 @@ def test_optimal_edge_cases(): assert check_path(path, [(0, 1), (0, 1, 2, 3, 4, 5)]) -def test_greedy_edge_cases(): +def test_greedy_edge_cases() -> None: expression = "abc,cfd,dbe,efa" dim_dict = {k: 20 for k in expression.replace(",", "")} @@ -218,21 +220,21 @@ def test_greedy_edge_cases(): assert check_path(path, [(0, 1), (0, 2), (0, 1)]) -def test_dp_edge_cases_dimension_1(): +def test_dp_edge_cases_dimension_1() -> None: eq = "nlp,nlq,pl->n" shapes = [(1, 1, 1), (1, 1, 1), (1, 1)] info = oe.contract_path(eq, *shapes, shapes=True, optimize="dp")[1] assert max(info.scale_list) == 3 -def test_dp_edge_cases_all_singlet_indices(): +def test_dp_edge_cases_all_singlet_indices() -> None: eq = "a,bcd,efg->" shapes = [(2,), (2, 2, 2), (2, 2, 2)] info = oe.contract_path(eq, *shapes, shapes=True, optimize="dp")[1] assert max(info.scale_list) == 3 -def test_custom_dp_can_optimize_for_outer_products(): +def test_custom_dp_can_optimize_for_outer_products() -> None: eq = "a,b,abc->c" da, db, dc = 2, 2, 3 @@ -247,7 +249,7 @@ def test_custom_dp_can_optimize_for_outer_products(): assert info2.opt_cost < info1.opt_cost -def test_custom_dp_can_optimize_for_size(): +def test_custom_dp_can_optimize_for_size() -> None: eq, shapes = oe.helpers.rand_equation(10, 4, seed=43) opt1 = oe.DynamicProgramming(minimize="flops") @@ -260,7 +262,7 @@ def test_custom_dp_can_optimize_for_size(): assert info1.largest_intermediate > info2.largest_intermediate -def test_custom_dp_can_set_cost_cap(): +def test_custom_dp_can_set_cost_cap() -> None: eq, shapes = oe.helpers.rand_equation(5, 3, seed=42) opt1 = oe.DynamicProgramming(cost_cap=True) opt2 = oe.DynamicProgramming(cost_cap=False) @@ -283,7 +285,7 @@ def test_custom_dp_can_set_cost_cap(): ("limit-256", 983832, 2016, [(2, 7), (3, 4), (0, 4), (3, 6), (2, 5), (0, 4), (0, 3), (1, 2), (0, 1)]), ], ) -def test_custom_dp_can_set_minimize(minimize, cost, width, path): +def test_custom_dp_can_set_minimize(minimize: str, cost: int, width: int, path: PathType) -> None: eq, shapes = oe.helpers.rand_equation(10, 4, seed=43) opt = oe.DynamicProgramming(minimize=minimize) info = oe.contract_path(eq, *shapes, shapes=True, optimize=opt)[1] @@ -292,12 +294,12 @@ def test_custom_dp_can_set_minimize(minimize, cost, width, path): assert info.largest_intermediate == width -def test_dp_errors_when_no_contractions_found(): - eq, shapes, size_dict = oe.helpers.rand_equation(10, 3, seed=42, return_size_dict=True) +def test_dp_errors_when_no_contractions_found() -> None: + eq, shapes = oe.helpers.rand_equation(10, 3, seed=42) # first get the actual minimum cost opt = oe.DynamicProgramming(minimize="size") - path, info = oe.contract_path(eq, *shapes, shapes=True, optimize=opt) + _, info = oe.contract_path(eq, *shapes, shapes=True, optimize=opt) mincost = info.largest_intermediate # check we can still find it without minimizing size explicitly @@ -309,7 +311,7 @@ def test_dp_errors_when_no_contractions_found(): @pytest.mark.parametrize("optimize", ["greedy", "branch-2", "branch-all", "optimal", "dp"]) -def test_can_optimize_outer_products(optimize): +def test_can_optimize_outer_products(optimize: OptimizeKind) -> None: a, b, c = [np.random.randn(10, 10) for _ in range(3)] d = np.random.randn(10, 2) assert oe.contract_path("ab,cd,ef,fg", a, b, c, d, optimize=optimize)[0] == [ @@ -320,7 +322,7 @@ def test_can_optimize_outer_products(optimize): @pytest.mark.parametrize("num_symbols", [2, 3, 26, 26 + 26, 256 - 140, 300]) -def test_large_path(num_symbols): +def test_large_path(num_symbols: int) -> None: symbols = "".join(oe.get_symbol(i) for i in range(num_symbols)) dimension_dict = dict(zip(symbols, itertools.cycle([2, 3, 4]))) expression = ",".join(symbols[t : t + 2] for t in range(num_symbols - 1)) @@ -330,7 +332,7 @@ def test_large_path(num_symbols): oe.contract_path(expression, *tensors, optimize="greedy") -def test_custom_random_greedy(): +def test_custom_random_greedy() -> None: eq, shapes = oe.helpers.rand_equation(10, 4, seed=42) views = list(map(np.ones, shapes)) @@ -368,7 +370,7 @@ def test_custom_random_greedy(): path, path_info = oe.contract_path(eq, *views, optimize=optimizer) -def test_custom_branchbound(): +def test_custom_branchbound() -> None: eq, shapes = oe.helpers.rand_equation(8, 4, seed=42) views = list(map(np.ones, shapes)) optimizer = oe.BranchBound(nbranch=2, cutoff_flops_factor=10, minimize="size") @@ -395,13 +397,13 @@ def test_custom_branchbound(): path, path_info = oe.contract_path(eq, *views, optimize=optimizer) -def test_branchbound_validation(): +def test_branchbound_validation() -> None: with pytest.raises(ValueError): oe.BranchBound(nbranch=0) @pytest.mark.skipif(sys.version_info < (3, 2), reason="requires python3.2 or higher") -def test_parallel_random_greedy(): +def test_parallel_random_greedy() -> None: from concurrent.futures import ProcessPoolExecutor pool = ProcessPoolExecutor(2) @@ -445,9 +447,16 @@ def test_parallel_random_greedy(): assert all(are_done) -def test_custom_path_optimizer(): +def test_custom_path_optimizer() -> None: + class NaiveOptimizer(oe.paths.PathOptimizer): - def __call__(self, inputs, output, size_dict, memory_limit=None): + def __call__( + self, + inputs: List[ArrayIndexType], + output: ArrayIndexType, + size_dict: Dict[str, int], + memory_limit: Optional[int] = None, + ) -> PathType: self.was_used = True return [(0, 1)] * (len(inputs) - 1) @@ -462,13 +471,15 @@ def __call__(self, inputs, output, size_dict, memory_limit=None): assert optimizer.was_used -def test_custom_random_optimizer(): +def test_custom_random_optimizer() -> None: class NaiveRandomOptimizer(oe.path_random.RandomOptimizer): @staticmethod - def random_path(r, n, inputs, output, size_dict): + def random_path( + r: int, n: int, inputs: List[ArrayIndexType], output: ArrayIndexType, size_dict: Dict[str, int] + ) -> Any: """Picks a completely random contraction order.""" np.random.seed(r) - ssa_path = [] + ssa_path: List[TensorShapeType] = [] remaining = set(range(n)) while len(remaining) > 1: i, j = np.random.choice(list(remaining), size=2, replace=False) @@ -479,7 +490,7 @@ def random_path(r, n, inputs, output, size_dict): cost, size = oe.path_random.ssa_path_compute_cost(ssa_path, inputs, output, size_dict) return ssa_path, cost, size - def setup(self, inputs, output, size_dict): + def setup(self, inputs: Any, output: Any, size_dict: Any) -> Any: self.was_used = True n = len(inputs) trial_fn = self.random_path @@ -499,8 +510,10 @@ def setup(self, inputs, output, size_dict): assert len(optimizer.costs) == 16 -def test_optimizer_registration(): - def custom_optimizer(inputs, output, size_dict, memory_limit): +def test_optimizer_registration() -> None: + def custom_optimizer( + inputs: List[ArrayIndexType], output: ArrayIndexType, size_dict: Dict[str, int], memory_limit: Optional[int] + ) -> PathType: return [(0, 1)] * (len(inputs) - 1) with pytest.raises(KeyError): @@ -511,6 +524,6 @@ def custom_optimizer(inputs, output, size_dict, memory_limit): eq = "ab,bc,cd" shapes = [(2, 3), (3, 4), (4, 5)] - path, path_info = oe.contract_path(eq, *shapes, shapes=True, optimize="custom") + path, _ = oe.contract_path(eq, *shapes, shapes=True, optimize="custom") # type: ignore assert path == [(0, 1), (0, 1)] del oe.paths._PATH_OPTIONS["custom"] diff --git a/opt_einsum/tests/test_sharing.py b/opt_einsum/tests/test_sharing.py index 0df72765..dcf880c3 100644 --- a/opt_einsum/tests/test_sharing.py +++ b/opt_einsum/tests/test_sharing.py @@ -1,6 +1,7 @@ import itertools import weakref from collections import Counter +from typing import Any import numpy as np import pytest @@ -10,6 +11,7 @@ from opt_einsum.contract import _einsum from opt_einsum.parser import parse_einsum_input from opt_einsum.sharing import count_cached_ops, currently_sharing, get_sharing_cache +from opt_einsum.typing import BackendType try: import cupy # noqa @@ -46,7 +48,7 @@ @pytest.mark.parametrize("eq", equations) @pytest.mark.parametrize("backend", backends) -def test_sharing_value(eq, backend): +def test_sharing_value(eq: str, backend: BackendType) -> None: views = helpers.build_views(eq) shapes = [v.shape for v in views] expr = contract_expression(eq, *shapes) @@ -59,7 +61,7 @@ def test_sharing_value(eq, backend): @pytest.mark.parametrize("backend", backends) -def test_complete_sharing(backend): +def test_complete_sharing(backend: BackendType) -> None: eq = "ab,bc,cd->" views = helpers.build_views(eq) expr = contract_expression(eq, *(v.shape for v in views)) @@ -84,7 +86,7 @@ def test_complete_sharing(backend): @pytest.mark.parametrize("backend", backends) -def test_sharing_reused_cache(backend): +def test_sharing_reused_cache(backend: BackendType) -> None: eq = "ab,bc,cd->" views = helpers.build_views(eq) expr = contract_expression(eq, *(v.shape for v in views)) @@ -110,7 +112,7 @@ def test_sharing_reused_cache(backend): @pytest.mark.parametrize("backend", backends) -def test_no_sharing_separate_cache(backend): +def test_no_sharing_separate_cache(backend: BackendType) -> None: eq = "ab,bc,cd->" views = helpers.build_views(eq) expr = contract_expression(eq, *(v.shape for v in views)) @@ -138,11 +140,11 @@ def test_no_sharing_separate_cache(backend): @pytest.mark.parametrize("backend", backends) -def test_sharing_nesting(backend): +def test_sharing_nesting(backend: BackendType) -> None: eqs = ["ab,bc,cd->a", "ab,bc,cd->b", "ab,bc,cd->c", "ab,bc,cd->c"] views = helpers.build_views(eqs[0]) shapes = [v.shape for v in views] - refs = weakref.WeakValueDictionary() + refs: Any = weakref.WeakValueDictionary() def method1(views): with shared_intermediates(): @@ -178,11 +180,11 @@ def method2(views): @pytest.mark.parametrize("eq", equations) @pytest.mark.parametrize("backend", backends) -def test_sharing_modulo_commutativity(eq, backend): +def test_sharing_modulo_commutativity(eq: str, backend: BackendType) -> None: ops = helpers.build_views(eq) ops = [to_backend[backend](x) for x in ops] inputs, output, _ = parse_einsum_input([eq] + ops) - inputs = inputs.split(",") + inputs_list = inputs.split(",") print("-" * 40) print("Without sharing:") @@ -193,7 +195,7 @@ def test_sharing_modulo_commutativity(eq, backend): print("-" * 40) print("With sharing:") with shared_intermediates() as cache: - for permuted in itertools.permutations(zip(inputs, ops)): + for permuted in itertools.permutations(zip(inputs_list, ops)): permuted_inputs = [p[0] for p in permuted] permuted_ops = [p[1] for p in permuted] permuted_eq = "{}->{}".format(",".join(permuted_inputs), output) @@ -207,7 +209,7 @@ def test_sharing_modulo_commutativity(eq, backend): @pytest.mark.parametrize("backend", backends) -def test_partial_sharing(backend): +def test_partial_sharing(backend: BackendType) -> None: eq = "ab,bc,de->" x, y, z1 = helpers.build_views(eq) z2 = 2.0 * z1 - 1.0 @@ -215,7 +217,7 @@ def test_partial_sharing(backend): print("-" * 40) print("Without sharing:") - num_exprs_nosharing = Counter() + num_exprs_nosharing: Any = Counter() with shared_intermediates() as cache: expr(x, y, z1, backend=backend) num_exprs_nosharing.update(count_cached_ops(cache)) @@ -237,10 +239,10 @@ def test_partial_sharing(backend): @pytest.mark.parametrize("backend", backends) -def test_sharing_with_constants(backend): +def test_sharing_with_constants(backend: BackendType) -> None: inputs = "ij,jk,kl" outputs = "ijkl" - equations = ["{}->{}".format(inputs, output) for output in outputs] + equations = [f"{inputs}->{output}" for output in outputs] shapes = (2, 3), (3, 4), (4, 5) constants = {0, 2} ops = [np.random.rand(*shp) if i in constants else shp for i, shp in enumerate(shapes)] @@ -252,12 +254,12 @@ def test_sharing_with_constants(backend): actual = [contract_expression(eq, *ops, constants=constants)(var) for eq in equations] for dim, expected_dim, actual_dim in zip(outputs, expected, actual): - assert np.allclose(expected_dim, actual_dim), "error at {}".format(dim) + assert np.allclose(expected_dim, actual_dim), f"error at {dim}" @pytest.mark.parametrize("size", [3, 4, 5]) @pytest.mark.parametrize("backend", backends) -def test_chain(size, backend): +def test_chain(size: int, backend: BackendType) -> None: xs = [np.random.rand(2, 2) for _ in range(size)] shapes = [x.shape for x in xs] alphabet = "".join(get_symbol(i) for i in range(size + 1)) @@ -278,7 +280,7 @@ def test_chain(size, backend): @pytest.mark.parametrize("size", [3, 4, 5, 10]) @pytest.mark.parametrize("backend", backends) -def test_chain_2(size, backend): +def test_chain_2(size: int, backend: BackendType) -> None: xs = [np.random.rand(2, 2) for _ in range(size)] shapes = [x.shape for x in xs] alphabet = "".join(get_symbol(i) for i in range(size + 1)) @@ -303,7 +305,7 @@ def _compute_cost(cache): @pytest.mark.parametrize("backend", backends) -def test_chain_2_growth(backend): +def test_chain_2_growth(backend: BackendType) -> None: sizes = list(range(1, 21)) costs = [] for size in sizes: @@ -328,7 +330,7 @@ def test_chain_2_growth(backend): @pytest.mark.parametrize("size", [3, 4, 5]) @pytest.mark.parametrize("backend", backends) -def test_chain_sharing(size, backend): +def test_chain_sharing(size: int, backend: BackendType) -> None: xs = [np.random.rand(2, 2) for _ in range(size)] alphabet = "".join(get_symbol(i) for i in range(size + 1)) names = [alphabet[i : i + 2] for i in range(size)] @@ -339,7 +341,7 @@ def test_chain_sharing(size, backend): with shared_intermediates() as cache: target = alphabet[i] eq = "{}->{}".format(inputs, target) - expr = contract_expression(eq, *(x.shape for x in xs)) + expr = contract_expression(eq, *tuple(x.shape for x in xs)) expr(*xs, backend=backend) num_exprs_nosharing += _compute_cost(cache) @@ -350,7 +352,7 @@ def test_chain_sharing(size, backend): eq = "{}->{}".format(inputs, target) path_info = contract_path(eq, *xs) print(path_info[1]) - expr = contract_expression(eq, *(x.shape for x in xs)) + expr = contract_expression(eq, *list(x.shape for x in xs)) expr(*xs, backend=backend) num_exprs_sharing = _compute_cost(cache) @@ -360,7 +362,7 @@ def test_chain_sharing(size, backend): assert num_exprs_nosharing > num_exprs_sharing -def test_multithreaded_sharing(): +def test_multithreaded_sharing() -> None: from multiprocessing.pool import ThreadPool def fn(): diff --git a/opt_einsum/typing.py b/opt_einsum/typing.py index 3fb22ce3..175bb480 100644 --- a/opt_einsum/typing.py +++ b/opt_einsum/typing.py @@ -2,10 +2,25 @@ Types used in the opt_einsum package """ -from typing import Any, Collection, FrozenSet, List, Optional, Tuple, Union +from typing import Any, Callable, Collection, Dict, FrozenSet, List, Literal, Optional, Tuple, Union + +TensorShapeType = Tuple[int, ...] +PathType = Collection[TensorShapeType] -PathType = Collection[Tuple[int, ...]] ArrayType = Any # TODO ArrayIndexType = FrozenSet[str] -TensorShapeType = Tuple[int, ...] + ContractionListType = List[Tuple[Any, ArrayIndexType, str, Optional[Tuple[str, ...]], Union[str, bool]]] +PathSearchFunctionType = Callable[[List[ArrayIndexType], ArrayIndexType, Dict[str, int], Optional[int]], PathType] + +# Contract kwargs +OptimizeKind = Union[ + None, + bool, + Literal[ + "optimal", "dp", "greedy", "random-greedy", "random-greedy-128", "branch-all", "branch-2", "auto", "auto-hq" + ], + PathType, + PathSearchFunctionType, +] +BackendType = Literal["auto", "object", "autograd", "cupy", "dask", "jax", "theano", "tensorflow", "torch", "libjax"] diff --git a/scripts/compare_random_paths.py b/scripts/compare_random_paths.py index b6d4bf7a..b5122374 100644 --- a/scripts/compare_random_paths.py +++ b/scripts/compare_random_paths.py @@ -1,5 +1,6 @@ import resource import timeit +from typing import Literal import numpy as np import pandas as pd @@ -12,7 +13,7 @@ pd.set_option("display.width", 200) -opt_path = "optimal" +opt_path: Literal["optimal"] = "optimal" # Number of dimensions max_dims = 4 @@ -108,7 +109,7 @@ def random_contraction(): diff_flags = df["Flag"] is not True print("\nNumber of contract different than einsum: %d." % np.sum(diff_flags)) -if sum(diff_flags) > 0: +if diff_flags > 0: print("Terms different than einsum") print(df[df["Flag"] is not True])