forked from nimashoghi/smart-quantization
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbase.py
106 lines (84 loc) · 2.98 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from abc import abstractmethod
from argparse import ArgumentParser, Namespace
from typing import Callable, Dict, List, Union
import torch
@torch.no_grad()
def _reduce_fx(tensors: Union[torch.Tensor, List[torch.Tensor]]):
if isinstance(tensors, list):
if len(tensors) == 0:
return 0
if torch.is_tensor(tensors[0]):
return torch.sum(torch.stack(tensors))
else:
return sum(tensors)
return torch.sum(tensors)
def _convert_to_floats(d: Dict[str, Union[float, torch.Tensor]]):
return {key: float(value) for key, value in d.items()}
class CompressionAlgorithmBase:
    """Base class for tensor compression algorithms with optional size logging.

    Subclasses override ``__call__`` to compress a tensor and may call
    ``log_ratio`` / ``log_size`` to record original size, compressed size and
    their ratio. Nothing is logged unless ``hparams.measure_compression_ratio``
    is truthy.
    """

    # Callable invoked as ``log(name, value, **kwargs)`` for each scalar.
    # It is None here and must be attached from outside before anything is
    # logged (presumably the owning LightningModule's ``log`` — TODO confirm).
    log = None
    # Optional callable receiving a whole ``{name: float}`` dict at once; when
    # set, it replaces ``log`` for tags starting with "optimizer_".
    log_custom = None

    @staticmethod
    def add_argparse_args(parent_parser: ArgumentParser):
        """Return a parser that extends *parent_parser* with this class's flags.

        Adds ``--measure_compression_ratio`` (``store_true``), stored as
        ``measure_compression_ratio``.
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument(
            "--measure_compression_ratio",
            action="store_true",
            dest="measure_compression_ratio",
        )
        return parser

    def __init__(self, hparams: Namespace):
        # Cooperative super() call; only has an effect when this class is
        # combined with other bases through multiple inheritance.
        super().__init__()
        self.hparams = hparams

    def update_hparams(self, hparams: Namespace):
        """Replace the hyper-parameter namespace used by this algorithm."""
        self.hparams = hparams

    def _log_scalars(self, scalars: Dict[str, float], custom=False):
        """Send *scalars* to the attached logging callable(s).

        If *custom* is true and ``log_custom`` is set, the whole dict is passed
        to ``log_custom`` in a single call. Otherwise each entry is sent through
        ``self.log(name, value, **kwargs)``. Entries whose name contains "size"
        get ``_reduce_fx`` as both ``reduce_fx`` and ``tbptt_reduce_fx``, so
        they are summed when aggregated. The other entries use the logger's
        default reduction (presumably a mean — confirm against the logger).

        Raises:
            RuntimeError: if there are entries to send through ``self.log``
                but no ``log`` callable has been attached.
        """
        if custom and self.log_custom is not None:
            self.log_custom(scalars)
            return
        if not scalars:
            return
        if self.log is None:
            raise RuntimeError(
                f"{type(self).__name__}.log is not set; attach a logging "
                "callable before logging compression statistics."
            )

        for name, value in scalars.items():
            kwargs = {}
            if "size" in name:
                kwargs["reduce_fx"] = _reduce_fx
                kwargs["tbptt_reduce_fx"] = _reduce_fx
            self.log(name, value, **kwargs)

    def log_ratio(
        self,
        tag: Union[str, None],
        size: int,
        orig_bitcount: float,
        new_bitcount: float,
        overhead=0,
    ):
        """Log sizes computed as ``size * bitcount`` before and after compression.

        *size* is multiplied by each bit count, so it is presumably an element
        count (TODO confirm with callers). *tag* and *overhead* are forwarded
        unchanged to ``log_size``.
        """
        return self.log_size(
            tag, size * orig_bitcount, size * new_bitcount, overhead=overhead
        )

    def log_size(
        self,
        tag: Union[str, None],
        orig_size: Union[float, Callable[[], float]],
        new_size: Union[float, Callable[[], float]],
        overhead=0,
    ):
        """Log the original size, the compressed size and their ratio.

        Does nothing unless ``hparams.measure_compression_ratio`` is truthy.
        The size arguments may be zero-argument callables; they are only
        evaluated when measurement is enabled.

        The keys logged are ``compression_ratio``, ``new_size`` and
        ``orig_size``. When *tag* is not None, each is also logged as
        ``<key>_<tag>``. Every value is converted with ``float()`` first.

        Args:
            tag: Suffix for the per-tag keys, or None to log only the
                aggregate keys. Tags starting with ``"optimizer_"`` are routed
                to ``log_custom`` when it is set.
            orig_size: Size before compression, or a callable returning it.
            new_size: Size after compression, or a callable returning it.
            overhead: Extra size added to *new_size* before the ratio is taken.

        Note:
            With plain numbers, a total new size of 0 raises ZeroDivisionError.
        """
        if not self.hparams.measure_compression_ratio:
            return

        orig_size = orig_size() if callable(orig_size) else orig_size
        new_size = new_size() if callable(new_size) else new_size
        # Out-of-place add: ``+=`` on a torch.Tensor would mutate the caller's
        # tensor in place.
        new_size = new_size + overhead
        compression_ratio = orig_size / new_size

        scalars = {}
        for key, value in (
            ("compression_ratio", compression_ratio),
            ("new_size", new_size),
            ("orig_size", orig_size),
        ):
            scalars[key] = value
            # Without a tag, only the aggregate keys are logged (no "_None" keys).
            if tag is not None:
                scalars[f"{key}_{tag}"] = value

        self._log_scalars(
            _convert_to_floats(scalars),
            custom=tag is not None and tag.startswith("optimizer_"),
        )

    @abstractmethod
    def __call__(self, tensor: torch.Tensor, tag: Union[str, None] = None, **_):
        """Compress *tensor*; subclasses must override this.

        The class does not use ``ABCMeta``, so ``@abstractmethod`` is not
        enforced at instantiation time. Calling this base implementation raises
        instead.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Not implemented")