Skip to content

Commit

Permalink
fix(torch-frontend): Fixes frontend implementation to correctly use a…
Browse files Browse the repository at this point in the history
…ll the arguments of ivy.inteprolate
  • Loading branch information
AnnaTz committed Oct 23, 2023
1 parent 3417ed2 commit b5422a7
Showing 1 changed file with 7 additions and 22 deletions.
29 changes: 7 additions & 22 deletions ivy/functional/frontends/torch/nn/functional/vision_functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# global
import math

# local
import ivy
Expand Down Expand Up @@ -377,9 +376,6 @@ def interpolate(
" linear | bilinear | bicubic | trilinear"
),
)
else:
if not ivy.exists(align_corners):
align_corners = False

dim = ivy.get_num_dims(input) - 2 # Number of spatial dimensions.

Expand All @@ -389,8 +385,6 @@ def interpolate(
)

elif ivy.exists(size) and not ivy.exists(scale_factor):
scale_factors = None

if isinstance(size, (list, tuple)):
ivy.utils.assertions.check_equal(
len(size),
Expand All @@ -406,13 +400,7 @@ def interpolate(
),
as_array=False,
)
output_size = size
else:
output_size = [size for _ in range(dim)]

elif ivy.exists(scale_factor) and not ivy.exists(size):
output_size = None

if isinstance(scale_factor, (list, tuple)):
ivy.utils.assertions.check_equal(
len(scale_factor),
Expand All @@ -428,10 +416,6 @@ def interpolate(
),
as_array=False,
)
scale_factors = scale_factor
else:
scale_factors = [scale_factor for _ in range(dim)]

else:
ivy.utils.assertions.check_any(
[ivy.exists(size), ivy.exists(scale_factor)],
Expand All @@ -448,11 +432,6 @@ def interpolate(
"recompute_scale_factor is not meaningful with an explicit size."
)

if ivy.exists(scale_factors):
output_size = [
math.floor(ivy.shape(input)[i + 2] * scale_factors[i]) for i in range(dim)
]

if (
bool(antialias)
and (mode not in ["bilinear", "bicubic"])
Expand Down Expand Up @@ -494,7 +473,13 @@ def interpolate(
)

return ivy.interpolate(
input, output_size, mode=mode, align_corners=align_corners, antialias=antialias
input,
size,
mode=mode,
scale_factor=scale_factor,
recompute_scale_factor=recompute_scale_factor,
align_corners=align_corners,
antialias=antialias,
)


Expand Down

0 comments on commit b5422a7

Please sign in to comment.