diff --git a/pyproject.toml b/pyproject.toml
index 52b2341..039d932 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.16.3"
+version = "1.17.1"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/vector_quantize_pytorch/finite_scalar_quantization.py b/vector_quantize_pytorch/finite_scalar_quantization.py
index b59b697..c6fbeda 100644
--- a/vector_quantize_pytorch/finite_scalar_quantization.py
+++ b/vector_quantize_pytorch/finite_scalar_quantization.py
@@ -12,7 +12,7 @@ import torch.nn as nn
 from torch.nn import Module
 from torch import Tensor, int32
-from torch.cuda.amp import autocast
+from torch.amp import autocast
 
 from einops import rearrange, pack, unpack
 
@@ -159,7 +159,7 @@ def indices_to_codes(self, indices):
 
         return codes
 
-    @autocast(enabled = False)
+    @autocast('cuda', enabled = False)
     def forward(self, z):
         """
         einstein notation
@@ -187,7 +187,7 @@ def forward(self, z):
         # whether to force quantization step to be full precision or not
 
         force_f32 = self.force_quantization_f32
-        quantization_context = partial(autocast, enabled = False) if force_f32 else nullcontext
+        quantization_context = partial(autocast, 'cuda', enabled = False) if force_f32 else nullcontext
 
         with quantization_context():
             orig_dtype = z.dtype
diff --git a/vector_quantize_pytorch/lookup_free_quantization.py b/vector_quantize_pytorch/lookup_free_quantization.py
index 9a598ae..86e8d7e 100644
--- a/vector_quantize_pytorch/lookup_free_quantization.py
+++ b/vector_quantize_pytorch/lookup_free_quantization.py
@@ -16,7 +16,7 @@ from torch import nn, einsum
 import torch.nn.functional as F
 from torch.nn import Module
-from torch.cuda.amp import autocast
+from torch.amp import autocast
 
 from einops import rearrange, reduce, pack, unpack
 
@@ -293,7 +293,7 @@ def forward(
 
         force_f32 = self.force_quantization_f32
 
-        quantization_context = partial(autocast, enabled = False) if force_f32 else nullcontext
+        quantization_context = partial(autocast, 'cuda', enabled = False) if force_f32 else nullcontext
 
         with quantization_context():
 
diff --git a/vector_quantize_pytorch/residual_fsq.py b/vector_quantize_pytorch/residual_fsq.py
index 1b14c75..fc37cf5 100644
--- a/vector_quantize_pytorch/residual_fsq.py
+++ b/vector_quantize_pytorch/residual_fsq.py
@@ -8,7 +8,7 @@ from torch import nn
 from torch.nn import Module, ModuleList
 import torch.nn.functional as F
-from torch.cuda.amp import autocast
+from torch.amp import autocast
 
 from vector_quantize_pytorch.finite_scalar_quantization import FSQ
 
@@ -167,7 +167,7 @@ def forward(
 
         # go through the layers
 
-        with autocast(enabled = False):
+        with autocast('cuda', enabled = False):
             for quantizer_index, (layer, scale) in enumerate(zip(self.layers, self.scales)):
 
                 if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
diff --git a/vector_quantize_pytorch/residual_lfq.py b/vector_quantize_pytorch/residual_lfq.py
index 0bb855a..992c88d 100644
--- a/vector_quantize_pytorch/residual_lfq.py
+++ b/vector_quantize_pytorch/residual_lfq.py
@@ -6,7 +6,7 @@ from torch import nn
 from torch.nn import Module, ModuleList
 import torch.nn.functional as F
-from torch.cuda.amp import autocast
+from torch.amp import autocast
 
 from vector_quantize_pytorch.lookup_free_quantization import LFQ
 
@@ -156,7 +156,7 @@ def forward(
 
         # go through the layers
 
-        with autocast(enabled = False):
+        with autocast('cuda', enabled = False):
             for quantizer_index, layer in enumerate(self.layers):
 
                 if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
diff --git a/vector_quantize_pytorch/vector_quantize_pytorch.py b/vector_quantize_pytorch/vector_quantize_pytorch.py
index 1055c73..9f14e58 100644
--- a/vector_quantize_pytorch/vector_quantize_pytorch.py
+++ b/vector_quantize_pytorch/vector_quantize_pytorch.py
@@ -9,7 +9,7 @@ import torch.nn.functional as F
 import torch.distributed as distributed
 from torch.optim import Optimizer
-from torch.cuda.amp import autocast
+from torch.amp import autocast
 
 import einx
 from einops import rearrange, repeat, reduce, pack, unpack
 
@@ -458,7 +458,7 @@ def expire_codes_(self, batch_samples):
         batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d')
         self.replace(batch_samples, batch_mask = expired_codes)
 
-    @autocast(enabled = False)
+    @autocast('cuda', enabled = False)
     def forward(
         self,
         x,
@@ -671,7 +671,7 @@ def expire_codes_(self, batch_samples):
         batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d')
         self.replace(batch_samples, batch_mask = expired_codes)
 
-    @autocast(enabled = False)
+    @autocast('cuda', enabled = False)
     def forward(
         self,
         x,
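Context for the change (not part of the diff): torch.cuda.amp.autocast is deprecated in recent PyTorch releases in favor of the device-agnostic torch.amp.autocast, which takes the device type as an explicit first positional argument. Below is a minimal standalone sketch of the three usage patterns this diff touches (context manager, decorator, and partial-deferred context); variable and function names are illustrative, not taken from the library.

# minimal sketch of the torch.cuda.amp -> torch.amp migration (illustrative only)

from functools import partial

import torch
from torch.amp import autocast   # replaces: from torch.cuda.amp import autocast

x = torch.randn(2, 4)

# 1. context manager: device type is now passed explicitly
with autocast('cuda', enabled = False):
    y = x * 2.

# 2. decorator: the same signature change applies when wrapping a forward method
@autocast('cuda', enabled = False)
def forward(z):
    return z + 1.

# 3. deferred context via partial, mirroring the optional force_f32 path in the diff
quantization_context = partial(autocast, 'cuda', enabled = False)

with quantization_context():
    out = forward(x)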