Skip to content

Commit

Permalink
Fix split, uncomment tests (#98)
Browse files Browse the repository at this point in the history
  • Loading branch information
dfdx authored Oct 22, 2023
1 parent fb7e66d commit 7e8d640
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 19 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ONNX"
uuid = "d0dd6a25-fac6-55c0-abf7-829e0c774d20"
version = "0.2.5"
version = "0.2.6"

[deps]
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
Expand Down
8 changes: 4 additions & 4 deletions src/load.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,21 +222,21 @@ function load_node!(tape::Tape, ::OpConfig{:ONNX, :Slice}, args::VarVec, attrs::
return push_call!(tape, onnx_slice, args...)
end

function load_node!(tape::Tape, ::OpConfig{:ONNX, :Split}, inputs::VarVec, attrs::AttrDict)
function load_node!(tape::Tape, ::OpConfig{:ONNX, :Split}, args::VarVec, attrs::AttrDict)
axis = get(attrs, :axis, 0)
split = if haskey(attrs, :split) # Version 1, 2, 11
attrs[:split]
elseif length(args) == 2
inputs[2]
args[2]
else
# the results cannot be split in multiple outputs on the tape
# if the output size is not known during tracing.
error("Unhandled case where split is not provided")
end
out = push_call!(tape, onnx_split, first(inputs), split; axis)
out = push_call!(tape, onnx_split, first(args), split; axis)
return Tuple(
push_call!(tape, getfield, out, i)
for i in eachindex(split)
for i in eachindex(split.op.val)
)
end

Expand Down
27 changes: 13 additions & 14 deletions test/saveload.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,18 +186,17 @@ import ONNX: NodeProto, ValueInfoProto, AttributeProto, onnx_name
ort_test(ONNX.onnx_concat, [1 2 3; 1 2 3], [4 5; 4 5]; axis=0)
end

# TODO: Split is not implemented in ONNXRuntime.jl
# @testset "Split" begin
# x = rand(3, 20, 10); split = [5, 10, 5];
# args = (x, split)
# tape = Tape(ONNXCtx())
# inp = [push!(tape, Input(a)) for a in args]
# out = push_call!(tape, ONNX.onnx_split, inp...; axis=1)
# push_call!(tape, getfield, out, 1)
# push_call!(tape, getfield, out, 2)
# push_call!(tape, getfield, out, 3)
# tape.result = out

# ort_test(tape, args...)
# end
@testset "Split" begin
x = rand(3, 20, 10); split = [5, 10, 5];
args = (x, split)
tape = Tape(ONNXCtx())
inp = [push!(tape, Input(a)) for a in args]
out = push_call!(tape, ONNX.onnx_split, inp...; axis=1)
push_call!(tape, getfield, out, 1)
push_call!(tape, getfield, out, 2)
push_call!(tape, getfield, out, 3)
tape.result = out

ort_test(tape, args...)
end
end

0 comments on commit 7e8d640

Please sign in to comment.