Skip to content

Commit

Permalink
Fix (export): export from non-CPU devices
Browse files Browse the repository at this point in the history
Signed-off-by: Alessandro Pappalardo <[email protected]>
  • Loading branch information
volcacius committed Jun 4, 2021
1 parent cfce308 commit b21c539
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 16 deletions.
10 changes: 5 additions & 5 deletions src/brevitas/export/onnx/finn/function/acc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
class QuantAvgPool2dPlaceholderFunction(Function):

@staticmethod
def symbolic(g, input, out_shape, kernel, stride, signed, ibits, obits, scale, qnt_type):
def symbolic(g, x, out_shape, kernel, stride, signed, ibits, obits, scale, qnt_type):
if scale is not None:
input = g.op('Div', input, scale, activation_qnt_s=qnt_type)
x = g.op('Div', x, scale, activation_qnt_s=qnt_type)
ret = g.op(
'QuantAvgPool2d', input,
'QuantAvgPool2d', x,
domain_s="finn.custom_op.general",
kernel_i=kernel,
stride_i=stride,
Expand All @@ -21,5 +21,5 @@ def symbolic(g, input, out_shape, kernel, stride, signed, ibits, obits, scale, q
return ret

@staticmethod
def forward(ctx, x, out_shape, kernel, stride, signed, ibits, obits, scale, qnt_type):
    """Shape-only placeholder for QuantAvgPool2d export tracing.

    Returns an uninitialized tensor of ``out_shape`` allocated on the same
    device as the input ``x`` (so tracing also works when exporting from a
    non-CPU device). The tensor's values are never used — the real operator
    semantics are emitted by the paired ``symbolic`` method.
    """
    # device=x.device is the whole point: torch.empty defaults to CPU,
    # which breaks export when the model lives on another device.
    return torch.empty(out_shape, dtype=torch.float, device=x.device)
5 changes: 2 additions & 3 deletions src/brevitas/export/onnx/finn/function/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ def symbolic(g, x, Wt, w_qnt_scale, b_qnt_scale, w_qnt_type, b_qnt_type, out_sha

@staticmethod
def forward(ctx, x, Wt, w_qnt_scale, b_qnt_scale, w_qnt_type, b_qnt_type, out_shape, bias):
    """Shape-only placeholder for a quantized linear layer during export.

    Returns an uninitialized tensor of ``out_shape`` on ``x``'s device so
    export works from non-CPU devices. The stale CPU-only return that
    preceded this one (diff residue) is removed: it ran first and made the
    device fix unreachable.
    """
    return torch.empty(out_shape, dtype=torch.float, device=x.device)

class QuantizedConvNdPlaceholderFunction(Function):

Expand Down Expand Up @@ -52,4 +51,4 @@ def symbolic(
@staticmethod
def forward(
        ctx, x, W, w_qnt_scale, b_qnt_scale, w_qnt_type, b_qnt_type, out_shape, pads, strides, bias, kernel_shape, groups, dilations):
    """Shape-only placeholder for a quantized ConvNd layer during export.

    Returns an uninitialized tensor of ``out_shape`` on ``x``'s device
    (values unused by tracing). The duplicate pre-fix return without
    ``device=`` — which executed first and defeated the fix — is removed.
    """
    return torch.empty(out_shape, dtype=torch.float, device=x.device)
4 changes: 2 additions & 2 deletions src/brevitas/export/onnx/standard/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def forward(
stride,
groups,
dilation):
return torch.empty(out_shape).type(output_zero_point.dtype)
return torch.empty(out_shape, dtype=output_zero_point.dtype, device=int_x.device)


class QLinearMatMulFunction(Function):
Expand Down Expand Up @@ -172,4 +172,4 @@ def forward(
output_scale,
output_zero_point,
out_shape):
return torch.empty(out_shape).type(output_zero_point.dtype)
return torch.empty(out_shape, dtype=output_zero_point.dtype, device=int_x.device)
8 changes: 4 additions & 4 deletions src/brevitas/export/onnx/vitis_ai/pyxir/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def forward(
input_scale,
output_bit_width,
output_scale):
return torch.empty(out_shape, dtype=torch.float)
return torch.empty(out_shape, dtype=torch.float, device=x.device)


class DPUQuantEltwiseAddPlaceholderFunction(Function):
Expand Down Expand Up @@ -146,7 +146,7 @@ def forward(
input_scale,
output_bit_width,
output_scale):
return torch.empty(out_shape, dtype=torch.float)
return torch.empty(out_shape, dtype=torch.float, device=x.device)


class DPUQuantConv2dPlaceholderFunction(Function):
Expand Down Expand Up @@ -192,7 +192,7 @@ def forward(
stride,
groups,
dilation):
return torch.empty(out_shape, dtype=torch.float)
return torch.empty(out_shape, dtype=torch.float, device=x.device)


class DPUQuantLinearPlaceholderFunction(Function):
Expand Down Expand Up @@ -263,4 +263,4 @@ def forward(
weight_scale,
int_bias_bit_width,
int_bias_scale):
return torch.empty(out_shape, dtype=torch.float)
return torch.empty(out_shape, dtype=torch.float, device=x.device)
4 changes: 2 additions & 2 deletions src/brevitas/export/onnx/vitis_ai/xir/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def symbolic(
def forward(
        ctx, x, weight, bias, is_depthwise, kernel_size,
        padding, padding_type, stride, dilation, output_shape):
    """Shape-only placeholder for the XIR conv export path.

    Returns an uninitialized ``output_shape`` tensor matching ``x``'s dtype
    and device, so tracing works with non-CPU and non-float32 inputs. The
    stale ``torch.empty(output_shape)`` line (diff residue, which allocated
    on CPU with the default dtype and ran first) is removed.
    """
    return torch.empty(output_shape, dtype=x.dtype, device=x.device)


class XIRConvTranpose2dPlaceholderFunction(Function):
Expand Down Expand Up @@ -138,4 +138,4 @@ def symbolic(
def forward(
        ctx, x, weight, bias, is_depthwise, kernel_size,
        padding, padding_type, stride, dilation, output_shape):
    """Shape-only placeholder for the XIR transposed-conv export path.

    Returns an uninitialized ``output_shape`` tensor on ``x``'s device with
    ``x``'s dtype (values unused by tracing). The superseded
    ``torch.empty(output_shape)`` return above it (diff residue) is
    dropped: it executed first and ignored device/dtype entirely.
    """
    return torch.empty(output_shape, dtype=x.dtype, device=x.device)

0 comments on commit b21c539

Please sign in to comment.