Skip to content

Commit

Permalink
Fix tests on Julia 1.9 (#92)
Browse files Browse the repository at this point in the history
* Fix tests on Julia 1.9
* Add Julia 1.9 to the test matrix

---------

Co-authored-by: Andrei Zhabinski <[email protected]>
  • Loading branch information
dfdx authored May 30, 2023
1 parent d3deea5 commit cc0ddfd
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 8 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CompatHelper.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name: CompatHelper
on:
schedule:
- cron: 0 0 * * *
- cron: 10 7 * * *
workflow_dispatch:
jobs:
CompatHelper:
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/Test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ jobs:
matrix:
version:
- '1.6'
- '1.9'
os:
- ubuntu-latest
arch:
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ONNX"
uuid = "d0dd6a25-fac6-55c0-abf7-829e0c774d20"
version = "0.2.3"
version = "0.2.4"

[deps]
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
Expand Down
30 changes: 24 additions & 6 deletions src/save.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,29 @@ add!(gp::GraphProto, tp::TensorProto) = push!(gp.initializer, tp)
# Utils #
##############################################################################

# can we make it more robust?
iskwfunc(f) = endswith(string(f), "##kw")
if VERSION < v"1.9"
# can we make it more robust?
iskwfunc(f) = endswith(string(f), "##kw")
else
iskwfunc(f) = (f === Core.kwcall)
end

function kwargs2dict(op::Umlaut.Call)
kw = iskwfunc(op.fn) ? op.args[1] : (;)
return Dict(zip(keys(kw), values(kw)))
end

macro opconfig_kw(backend, fn)
return quote
$OpConfig{$backend, <:Union{typeof($fn), typeof(Core.kwfunc($fn))}}
if VERSION < v"1.9.0"
macro opconfig_kw(backend, fn)
return quote
$OpConfig{$backend, <:Union{typeof($fn), typeof(Core.kwfunc($fn))}}
end
end
else
macro opconfig_kw(backend, fn)
return quote
$OpConfig{$backend, <:Union{typeof($fn)}}
end
end
end

Expand Down Expand Up @@ -83,7 +95,13 @@ onnx_name(op::Umlaut.AbstractOp) = "x$(op.id)"
Serialize a single operation from a tape to graph.
"""
function save_node!(g::GraphProto, op::Umlaut.Call)
save_node!(g, OpConfig{:ONNX, typeof(op.fn)}(), op)
if VERSION >= v"1.9" && op.fn == Core.kwcall
v_fn = op.args[2]
fn = v_fn isa V ? op.tape[v_fn].val : v_fn
save_node!(g, OpConfig{:ONNX, typeof(fn)}(), op)
else
save_node!(g, OpConfig{:ONNX, typeof(op.fn)}(), op)
end
end


Expand Down

0 comments on commit cc0ddfd

Please sign in to comment.