From 3c04835189a2f7fb4704c62e4f347e32bab2374e Mon Sep 17 00:00:00 2001 From: Jack Rae <jwrae@google.com> Date: Wed, 18 Nov 2015 22:05:31 +0000 Subject: [PATCH] Add isSetTo: simple check for shared storage. Returns true iff Tensor is set to argument Tensor. Specifically, this is true iff tensor shares same storage as argument tensor, with same storage offset and identical sizes and strides. --- doc/tensor.md | 19 +++++++++++++++++++ generic/Tensor.c | 9 +++++++++ lib/TH/generic/THTensor.c | 17 +++++++++++++++++ lib/TH/generic/THTensor.h | 1 + test/test.lua | 11 +++++++++++ 5 files changed, 57 insertions(+) diff --git a/doc/tensor.md b/doc/tensor.md index 14ec2951..7fbf442b 100644 --- a/doc/tensor.md +++ b/doc/tensor.md @@ -880,6 +880,25 @@ y:zero() [torch.DoubleTensor of dimension 2x5] ``` +<a name="torch.Tensor.isSetTo"></a> +### [boolean] isSetTo(tensor) ### + +Returns true iff the `Tensor` is set to the argument `Tensor`. Note: this is +only true if the tensors are the same size, have the same strides and share the +same storage and offset. + +```lua +x = torch.Tensor(2,5) +y = torch.Tensor() +> y:isSetTo(x) + false +> y:set(x) +> y:isSetTo(x) + true +> y:t():isSetTo(x) + false -- x and y have different strides +``` + <a name="torch.Tensor.set"></a> ### [self] set(storage, [storageOffset, sizes, [strides]]) ### diff --git a/generic/Tensor.c b/generic/Tensor.c index 958f279c..39176150 100644 --- a/generic/Tensor.c +++ b/generic/Tensor.c @@ -638,6 +638,14 @@ static int torch_Tensor_(isSameSizeAs)(lua_State *L) return 1; } +static int torch_Tensor_(isSetTo)(lua_State *L) +{ + THTensor *tensor1 = luaT_checkudata(L, 1, torch_Tensor); + THTensor *tensor2 = luaT_checkudata(L, 2, torch_Tensor); + lua_pushboolean(L, THTensor_(isSetTo)(tensor1, tensor2)); + return 1; +} + static int torch_Tensor_(nElement)(lua_State *L) { THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor); @@ -1288,6 +1296,7 @@ static const struct luaL_Reg torch_Tensor_(_) [] = { {"unfold", torch_Tensor_(unfold)}, {"isContiguous", torch_Tensor_(isContiguous)}, {"isSameSizeAs", torch_Tensor_(isSameSizeAs)}, + {"isSetTo", torch_Tensor_(isSetTo)}, {"isSize", torch_Tensor_(isSize)}, {"nElement", torch_Tensor_(nElement)}, {"copy", torch_Tensor_(copy)}, diff --git a/lib/TH/generic/THTensor.c b/lib/TH/generic/THTensor.c index cd67d067..a63da843 100644 --- a/lib/TH/generic/THTensor.c +++ b/lib/TH/generic/THTensor.c @@ -545,6 +545,23 @@ int THTensor_(isSameSizeAs)(const THTensor *self, const THTensor* src) return 1; } +int THTensor_(isSetTo)(const THTensor *self, const THTensor* src) +{ + if (self->storage == src->storage && + self->storageOffset == src->storageOffset && + self->nDimension == src->nDimension) + { + int d; + for (d = 0; d < self->nDimension; ++d) + { + if (self->size[d] != src->size[d] || self->stride[d] != src->stride[d]) + return 0; + } + return 1; + } + return 0; +} + long THTensor_(nElement)(const THTensor *self) { if(self->nDimension == 0) diff --git a/lib/TH/generic/THTensor.h b/lib/TH/generic/THTensor.h index a8dffdfc..7a3d5859 100644 --- a/lib/TH/generic/THTensor.h +++ b/lib/TH/generic/THTensor.h @@ -104,6 +104,7 @@ TH_API void THTensor_(squeeze1d)(THTensor *self, THTensor *src, int dimension_); TH_API int THTensor_(isContiguous)(const THTensor *self); TH_API int THTensor_(isSameSizeAs)(const THTensor *self, const THTensor *src); +TH_API int THTensor_(isSetTo)(const THTensor *self, const THTensor *src); TH_API int THTensor_(isSize)(const THTensor *self, const THLongStorage *dims); TH_API long THTensor_(nElement)(const THTensor *self); diff --git a/test/test.lua b/test/test.lua index c9e55ca1..62408cfe 100644 --- a/test/test.lua +++ b/test/test.lua @@ -2690,6 +2690,17 @@ function torchtest.isSameSizeAs() mytester:assert(t1:isSameSizeAs(t4) == true, "wrong answer ") end +function torchtest.isSetTo() + local t1 = torch.Tensor(3, 4, 9, 10) + local t2 = torch.Tensor(3, 4, 9, 10) + local t3 = torch.Tensor():set(t1) + local t4 = t3:reshape(12, 90) + mytester:assert(t1:isSetTo(t2) == false, "tensors do not share storage") + mytester:assert(t1:isSetTo(t3) == true, "tensor is set to other") + mytester:assert(t3:isSetTo(t1) == true, "isSetTo should be symmetric") + mytester:assert(t1:isSetTo(t4) == false, "tensors have different view") +end + function torchtest.isSize() local t1 = torch.Tensor(3, 4, 5) local s1 = torch.LongStorage({3, 4, 5})