Skip to content

Commit

Permalink
Merge pull request torch#467 from dm-jrae/master
Browse files Browse the repository at this point in the history
Add isSetTo: simple check for shared storage.
  • Loading branch information
soumith committed Nov 19, 2015
2 parents c2b91e6 + 3c04835 commit 840e731
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 0 deletions.
19 changes: 19 additions & 0 deletions doc/tensor.md
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,25 @@ y:zero()
[torch.DoubleTensor of dimension 2x5]
```

<a name="torch.Tensor.isSetTo"></a>
### [boolean] isSetTo(tensor) ###

Returns true iff the `Tensor` is set to the argument `Tensor`. Note: this is
only true if the tensors are the same size, have the same strides and share the
same storage and offset.

```lua
x = torch.Tensor(2,5)
y = torch.Tensor()
> y:isSetTo(x)
false
> y:set(x)
> y:isSetTo(x)
true
> y:t():isSetTo(x)
false -- x and y have different strides
```

<a name="torch.Tensor.set"></a>
### [self] set(storage, [storageOffset, sizes, [strides]]) ###

Expand Down
9 changes: 9 additions & 0 deletions generic/Tensor.c
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,14 @@ static int torch_Tensor_(isSameSizeAs)(lua_State *L)
return 1;
}

static int torch_Tensor_(isSetTo)(lua_State *L)
{
THTensor *tensor1 = luaT_checkudata(L, 1, torch_Tensor);
THTensor *tensor2 = luaT_checkudata(L, 2, torch_Tensor);
lua_pushboolean(L, THTensor_(isSetTo)(tensor1, tensor2));
return 1;
}

static int torch_Tensor_(nElement)(lua_State *L)
{
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
Expand Down Expand Up @@ -1288,6 +1296,7 @@ static const struct luaL_Reg torch_Tensor_(_) [] = {
{"unfold", torch_Tensor_(unfold)},
{"isContiguous", torch_Tensor_(isContiguous)},
{"isSameSizeAs", torch_Tensor_(isSameSizeAs)},
{"isSetTo", torch_Tensor_(isSetTo)},
{"isSize", torch_Tensor_(isSize)},
{"nElement", torch_Tensor_(nElement)},
{"copy", torch_Tensor_(copy)},
Expand Down
17 changes: 17 additions & 0 deletions lib/TH/generic/THTensor.c
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,23 @@ int THTensor_(isSameSizeAs)(const THTensor *self, const THTensor* src)
return 1;
}

int THTensor_(isSetTo)(const THTensor *self, const THTensor* src)
{
if (self->storage == src->storage &&
self->storageOffset == src->storageOffset &&
self->nDimension == src->nDimension)
{
int d;
for (d = 0; d < self->nDimension; ++d)
{
if (self->size[d] != src->size[d] || self->stride[d] != src->stride[d])
return 0;
}
return 1;
}
return 0;
}

long THTensor_(nElement)(const THTensor *self)
{
if(self->nDimension == 0)
Expand Down
1 change: 1 addition & 0 deletions lib/TH/generic/THTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ TH_API void THTensor_(squeeze1d)(THTensor *self, THTensor *src, int dimension_);

TH_API int THTensor_(isContiguous)(const THTensor *self);
TH_API int THTensor_(isSameSizeAs)(const THTensor *self, const THTensor *src);
TH_API int THTensor_(isSetTo)(const THTensor *self, const THTensor *src);
TH_API int THTensor_(isSize)(const THTensor *self, const THLongStorage *dims);
TH_API long THTensor_(nElement)(const THTensor *self);

Expand Down
11 changes: 11 additions & 0 deletions test/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2690,6 +2690,17 @@ function torchtest.isSameSizeAs()
mytester:assert(t1:isSameSizeAs(t4) == true, "wrong answer ")
end

function torchtest.isSetTo()
local t1 = torch.Tensor(3, 4, 9, 10)
local t2 = torch.Tensor(3, 4, 9, 10)
local t3 = torch.Tensor():set(t1)
local t4 = t3:reshape(12, 90)
mytester:assert(t1:isSetTo(t2) == false, "tensors do not share storage")
mytester:assert(t1:isSetTo(t3) == true, "tensor is set to other")
mytester:assert(t3:isSetTo(t1) == true, "isSetTo should be symmetric")
mytester:assert(t1:isSetTo(t4) == false, "tensors have different view")
end

function torchtest.isSize()
local t1 = torch.Tensor(3, 4, 5)
local s1 = torch.LongStorage({3, 4, 5})
Expand Down

0 comments on commit 840e731

Please sign in to comment.