-
Notifications
You must be signed in to change notification settings - Fork 205
/
Copy path__init__.py
62 lines (59 loc) · 1.83 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
# Let's define a few top-level things here
from torchao.float8.config import (
CastConfig,
DelayedScalingConfig,
Float8GemmConfig,
Float8LinearConfig,
ScalingType,
)
from torchao.float8.float8_linear_utils import (
convert_to_float8_training,
linear_requires_sync,
sync_float8_amax_and_scale_history,
)
from torchao.float8.float8_tensor import (
Float8Tensor,
GemmInputRole,
LinearMMConfig,
ScaledMMConfig,
)
from torchao.float8.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
from torchao.float8.inductor_utils import (
_prototype_register_float8_delayed_scaling_inductor_passes,
)
from torchao.float8.inference import Float8MMConfig
from torchao.float8.stateful_float8_linear import WeightWithDelayedFloat8CastTensor
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
# Import-time registration: when the installed torch is new enough, mark the
# float8 types imported above as safe globals for torch serialization.
if TORCH_VERSION_AT_LEAST_2_5:
    # Needed to load Float8Tensor with weights_only = True
    # NOTE(review): the import sits under the version guard, presumably because
    # older torch releases do not provide `add_safe_globals` -- confirm.
    from torch.serialization import add_safe_globals

    # Every type passed here is imported at the top of this module.
    add_safe_globals(
        [
            Float8Tensor,
            ScaledMMConfig,
            GemmInputRole,
            LinearMMConfig,
            Float8MMConfig,
            WeightWithDelayedFloat8CastTensor,
        ]
    )
# Names exported by a star-import of this package. Several names imported
# above (for example Float8Tensor and Float8MMConfig) are not listed here.
__all__ = [
    # configuration
    "DelayedScalingConfig",
    "ScalingType",
    "Float8GemmConfig",
    "Float8LinearConfig",
    "CastConfig",
    # top level UX
    "convert_to_float8_training",
    "linear_requires_sync",
    "sync_float8_amax_and_scale_history",
    "precompute_float8_dynamic_scale_for_fsdp",
    # NOTE(review): underscore-prefixed yet exported explicitly -- presumably a
    # prototype API that callers are still expected to reach; confirm.
    "_prototype_register_float8_delayed_scaling_inductor_passes",
    # note: Float8Tensor and Float8Linear are not public APIs, so they are
    # deliberately not listed here
]