diff --git a/pyproject.toml b/pyproject.toml index 11bb3d1..364efe5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "vector-quantize-pytorch" -version = "1.18.7" +version = "1.18.8" description = "Vector Quantization - Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } diff --git a/vector_quantize_pytorch/residual_fsq.py b/vector_quantize_pytorch/residual_fsq.py index d57b594..a837f5d 100644 --- a/vector_quantize_pytorch/residual_fsq.py +++ b/vector_quantize_pytorch/residual_fsq.py @@ -185,7 +185,7 @@ def forward( # check if seed is manually passed in if not exists(rand_quantize_dropout_fixed_seed): - rand_quantize_dropout_fixed_seed = get_maybe_sync_seed() + rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device) rand = random.Random(rand_quantize_dropout_fixed_seed) @@ -296,7 +296,7 @@ def forward( x, return_all_codes = False ): - shape, split_dim = x.shape, self.split_dim + shape, split_dim, device = x.shape, self.split_dim, x.device assert shape[split_dim] == self.dim # split the feature dimension into groups @@ -305,7 +305,7 @@ def forward( forward_kwargs = dict( return_all_codes = return_all_codes, - rand_quantize_dropout_fixed_seed = get_maybe_sync_seed() + rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device) ) # invoke residual vq on each group diff --git a/vector_quantize_pytorch/residual_lfq.py b/vector_quantize_pytorch/residual_lfq.py index 2ea6407..98bdc54 100644 --- a/vector_quantize_pytorch/residual_lfq.py +++ b/vector_quantize_pytorch/residual_lfq.py @@ -31,8 +31,8 @@ def round_up_multiple(num, mult): def is_distributed(): return dist.is_initialized() and dist.get_world_size() > 1 -def get_maybe_sync_seed(max_size = 10_000): - rand_int = torch.randint(0, max_size, ()) +def get_maybe_sync_seed(device, max_size = 10_000): + rand_int = torch.randint(0, max_size, (), device = device) if is_distributed(): dist.all_reduce(rand_int) @@ -162,7 +162,7 @@ def forward( # check if seed is manually passed in if not exists(rand_quantize_dropout_fixed_seed): - rand_quantize_dropout_fixed_seed = get_maybe_sync_seed() + rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device) rand = random.Random(rand_quantize_dropout_fixed_seed) @@ -262,7 +262,7 @@ def forward( mask = None, return_all_codes = False ): - shape, split_dim = x.shape, self.split_dim + shape, split_dim, device = x.shape, self.split_dim, x.device assert shape[split_dim] == self.dim # split the feature dimension into groups @@ -272,7 +272,7 @@ def forward( forward_kwargs = dict( mask = mask, return_all_codes = return_all_codes, - rand_quantize_dropout_fixed_seed = get_maybe_sync_seed() + rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device) ) # invoke residual vq on each group diff --git a/vector_quantize_pytorch/residual_vq.py b/vector_quantize_pytorch/residual_vq.py index 768ceab..3e70ae2 100644 --- a/vector_quantize_pytorch/residual_vq.py +++ b/vector_quantize_pytorch/residual_vq.py @@ -36,8 +36,8 @@ def round_up_multiple(num, mult): def is_distributed(): return dist.is_initialized() and dist.get_world_size() > 1 -def get_maybe_sync_seed(max_size = 10_000): - rand_int = torch.randint(0, max_size, ()) +def get_maybe_sync_seed(device, max_size = 10_000): + rand_int = torch.randint(0, max_size, (), device = device) if is_distributed(): dist.all_reduce(rand_int) @@ -296,7 +296,7 @@ def forward( # check if seed is manually passed in if not exists(rand_quantize_dropout_fixed_seed): - rand_quantize_dropout_fixed_seed = get_maybe_sync_seed() + rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device) rand = random.Random(rand_quantize_dropout_fixed_seed) @@ -452,7 +452,7 @@ def forward( freeze_codebook = False, mask = None, ): - shape, split_dim = x.shape, self.split_dim + shape, split_dim, device = x.shape, self.split_dim, x.device assert shape[split_dim] == self.dim # split the feature dimension into groups @@ -468,7 +468,7 @@ def forward( sample_codebook_temp = sample_codebook_temp, mask = mask, freeze_codebook = freeze_codebook, - rand_quantize_dropout_fixed_seed = get_maybe_sync_seed() + rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device) ) # invoke residual vq on each group