From c302cf3282161e81ebcf77a627b6d0e8bf34b069 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Wed, 4 Sep 2024 06:13:00 -0700 Subject: [PATCH] address https://github.com/lucidrains/vector-quantize-pytorch/issues/159 --- pyproject.toml | 2 +- vector_quantize_pytorch/lookup_free_quantization.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 039d932..e205a2f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "vector-quantize-pytorch" -version = "1.17.1" +version = "1.17.2" description = "Vector Quantization - Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } diff --git a/vector_quantize_pytorch/lookup_free_quantization.py b/vector_quantize_pytorch/lookup_free_quantization.py index 86e8d7e..a29f325 100644 --- a/vector_quantize_pytorch/lookup_free_quantization.py +++ b/vector_quantize_pytorch/lookup_free_quantization.py @@ -10,7 +10,9 @@ from functools import partial, cache from collections import namedtuple from contextlib import nullcontext + import torch.distributed as dist +from torch.distributed import nn as dist_nn import torch from torch import nn, einsum @@ -36,7 +38,7 @@ def maybe_distributed_mean(t): if not is_distributed(): return t - dist.nn.all_reduce(t) + dist_nn.all_reduce(t) t = t / dist.get_world_size() return t