diff --git a/aten/src/ATen/native/mps/operations/Activation.mm b/aten/src/ATen/native/mps/operations/Activation.mm
index 64c4bc6a0afee7..abdb3cc52d0ea2 100644
--- a/aten/src/ATen/native/mps/operations/Activation.mm
+++ b/aten/src/ATen/native/mps/operations/Activation.mm
@@ -1653,6 +1653,11 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) {
 
   MPSStream* stream = getCurrentMPSStream();
 
+  bool executeGatherOp =
+      !(self.is_contiguous(MemoryFormat::Contiguous) || self.is_contiguous(MemoryFormat::ChannelsLast) ||
+        self.is_contiguous(MemoryFormat::ChannelsLast3d));
+  Tensor result_ = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve);
+
   @autoreleasepool {
     string key = "silu_out_mps:" + getTensorsStringKey({self});
 
@@ -1673,12 +1678,16 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) {
       newCachedGraph->outputTensor_ = outputTensor;
     });
 
-    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
-    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);
+    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp);
+    Placeholder outputPlaceholder =
+        Placeholder(cachedGraph->outputTensor_, executeGatherOp ? result_ : result, nil, false);
 
     auto feeds = dictionaryFromPlaceholders(selfPlaceholder);
     runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
   }
+  if (executeGatherOp) {
+    result.copy_(result_);
+  }
 }
 
 TORCH_IMPL_FUNC(silu_backward_out_mps)
diff --git a/test/test_mps.py b/test/test_mps.py
index 7315338fc0701e..4bb3fcc6a33d4d 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -6792,9 +6792,18 @@ def helper(shape, beta, threshold, dtype):
 
     # Test silu
    def test_silu(self):
-        def helper(shape):
-            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
-            x = cpu_x.detach().clone().to('mps').requires_grad_()
+        def helper(shape, contiguous=True):
+            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
+            x = cpu_x.detach().clone().to('mps')
+
+            if not contiguous and (0 not in shape and len(shape) >= 2):
+                # Transposing will make the tensor non-contiguous
+                cpu_x = cpu_x.transpose(0, 1)
+                x = x.transpose(0, 1)
+                assert not x.is_contiguous()
+
+            cpu_x.requires_grad_()
+            x.requires_grad_()
 
             silu_result = torch.nn.SiLU()(x)
             silu_result_cpu = torch.nn.SiLU()(cpu_x)
@@ -6810,7 +6819,8 @@ def helper(shape):
 
         # Test empty shape too
         for shape in [[], (2, 3), (2, 8, 4, 5)]:
-            helper(shape)
+            for contiguous in [True, False]:
+                helper(shape, contiguous)
 
     def test_cast_mps_to_cpu(self):
         def helper(src_dtype, dst_dtype):
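
For reference, below is a minimal standalone sketch (not part of the patch) of the behavior the change and test exercise: applying SiLU to a non-contiguous (transposed) tensor on MPS and comparing against the CPU result, which is the case the executeGatherOp path above handles. It assumes a PyTorch build with MPS support and an available MPS device; the shape is an arbitrary example.

# Repro sketch, assuming MPS is available; shape chosen for illustration only.
import torch

cpu_x = torch.randn(2, 8, 4, 5, dtype=torch.float)
mps_x = cpu_x.to('mps')

# Transposing produces a non-contiguous view on both devices.
cpu_x = cpu_x.transpose(0, 1)
mps_x = mps_x.transpose(0, 1)
assert not mps_x.is_contiguous()

cpu_y = torch.nn.SiLU()(cpu_x)
mps_y = torch.nn.SiLU()(mps_x)

# Compare on CPU; assert_close tolerates small floating-point differences.
torch.testing.assert_close(mps_y.cpu(), cpu_y)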