Skip to content

Commit

Permalink
fix(torch-frontend): Handle default case for torch.Size when no input…
Browse files Browse the repository at this point in the history
… shape is provided. Also add numel to the Size class
  • Loading branch information
hmahmood24 committed Jan 24, 2024
1 parent 4a2ef0b commit e1724a8
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion ivy/functional/frontends/torch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2298,6 +2298,7 @@ def _set(index):

class Size(tuple):
def __new__(cls, iterable=()):
iterable = ivy.Shape([]) if not iterable else iterable
new_iterable = []
for i, item in enumerate(iterable):
if isinstance(item, int):
Expand All @@ -2311,7 +2312,8 @@ def __new__(cls, iterable=()):
) from e
return super().__new__(cls, tuple(new_iterable))

def __init__(self, shape) -> None:
def __init__(self, shape=()) -> None:
shape = ivy.Shape([]) if not shape else shape
self._ivy_shape = shape if isinstance(shape, ivy.Shape) else ivy.shape(shape)

def __repr__(self):
Expand All @@ -2320,3 +2322,6 @@ def __repr__(self):
@property
def ivy_shape(self):
return self._ivy_shape

def numel(self):
return int(ivy.astype(ivy.prod(self), ivy.int64))

0 comments on commit e1724a8

Please sign in to comment.