Skip to content

Commit

Permalink
add consistency checks in iterate
Browse files Browse the repository at this point in the history
* ensure that iterate generates `length(I)` values
* ensure that iterate always makes it to the end of all skips

addresses the most egregious bad behaviors in #7 and #31
  • Loading branch information
mbauman committed Dec 11, 2024
1 parent 2efe37d commit b337d91
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 14 deletions.
40 changes: 26 additions & 14 deletions src/InvertedIndices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,41 +97,53 @@ end
struct InvertedIndexIterator{T,S,P} <: AbstractVector{T}
skips::S
picks::P
length::Int
end
InvertedIndexIterator(skips, picks) = InvertedIndexIterator{eltype(picks), typeof(skips), typeof(picks)}(skips, picks)
Base.size(III::InvertedIndexIterator) = (length(III.picks) - length(III.skips),)
InvertedIndexIterator(skips, picks) = InvertedIndexIterator{eltype(picks), typeof(skips), typeof(picks)}(skips, picks, length(picks) - length(skips))
Base.size(III::InvertedIndexIterator) = (III.length,)

# Ensure iteration consumes all skips by the time it hits the end of the picks
assert_iteration_finished(I, n, ::Nothing) = (@assert n == I.length "InvertedIndexIterator iterated $n values but expected $(I.length)"; true)
assert_iteration_finished(I, _, (skipvalue, _)) = throw(ArgumentError("did not find index $skipvalue in axis $(I.picks), so could not skip it"))
# Ensure iteration does not generate more than I.length values
assert_iteration_not_finished(I, n, ::Nothing) = @assert n <= I.length "InvertedIndexIterator iterated more values than expected"
assert_iteration_not_finished(I, n, (skipvalue, _)) = n <= I.length || throw(ArgumentError("did not find index $skipvalue in axis $(I.picks), so could not skip it"))

@inline function Base.iterate(I::InvertedIndexIterator)
n = 0
skipitr = iterate(I.skips)
pickitr = iterate(I.picks)
pickitr === nothing && return nothing
pickitr === nothing && assert_iteration_finished(I, n, skipitr) && return nothing
while should_skip(skipitr, pickitr)
skipitr = iterate(I.skips, skipitr[2])
pickitr = iterate(I.picks, pickitr[2])
pickitr === nothing && return nothing
pickitr === nothing && assert_iteration_finished(I, n, skipitr) && return nothing
end
n += 1; assert_iteration_not_finished(I, n, skipitr)
# This is a little silly, but splitting the tuple here allows inference to normalize
# Tuple{Union{Nothing, Tuple}, Tuple} to Union{Tuple{Nothing, Tuple}, Tuple{Tuple, Tuple}}
return skipitr === nothing ?
(pickitr[1], (nothing, pickitr[2])) :
(pickitr[1], (skipitr, pickitr[2]))
(pickitr[1], (nothing, pickitr[2], n)) :
(pickitr[1], (skipitr, pickitr[2], n))
end
@inline function Base.iterate(I::InvertedIndexIterator, (_, pickstate)::Tuple{Nothing, Any})
@inline function Base.iterate(I::InvertedIndexIterator, (_, pickstate, n)::Tuple{Nothing, Any, Any})
pickitr = iterate(I.picks, pickstate)
pickitr === nothing && return nothing
return (pickitr[1], (nothing, pickitr[2]))
pickitr === nothing && assert_iteration_finished(I, n, nothing) && return nothing
n += 1; assert_iteration_not_finished(I, n, nothing)
return (pickitr[1], (nothing, pickitr[2], n))
end
@inline function Base.iterate(I::InvertedIndexIterator, (skipitr, pickstate)::Tuple)
@inline function Base.iterate(I::InvertedIndexIterator, (skipitr, pickstate, n)::Tuple)
pickitr = iterate(I.picks, pickstate)
pickitr === nothing && return nothing
pickitr === nothing && assert_iteration_finished(I, n, skipitr) && return nothing
while should_skip(skipitr, pickitr)
skipitr = iterate(I.skips, tail(skipitr)...)
pickitr = iterate(I.picks, tail(pickitr)...)
pickitr === nothing && return nothing
pickitr === nothing && assert_iteration_finished(I, n, skipitr) && return nothing
end
n += 1; assert_iteration_not_finished(I, n, skipitr)
return skipitr === nothing ?
(pickitr[1], (nothing, pickitr[2])) :
(pickitr[1], (skipitr, pickitr[2]))
(pickitr[1], (nothing, pickitr[2], n)) :
(pickitr[1], (skipitr, pickitr[2], n))
end
function Base.collect(III::InvertedIndexIterator{T}) where {T}
!isconcretetype(T) && return [i for i in III] # use widening if T is not concrete
Expand Down
35 changes: 35 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,38 @@ returns(val) = _->val
@test @inferred(LinearIndices(arr)[collect(I)]) == vec(filter(!iseven, arr))
end
end

struct NamedVector{T,A,B} <: AbstractArray{T,1}
data::A
names::B
end
function NamedVector(data, names)
@assert size(data) == size(names)
NamedVector{eltype(data), typeof(data), typeof(names)}(data, names)
end
Base.size(n::NamedVector) = size(n.data)
Base.getindex(n::NamedVector, i::Int) = n.data[i]
Base.to_index(n::NamedVector, name::Symbol) = findfirst(==(name), n.names)
Base.checkbounds(::Type{Bool}, n::NamedVector, names::AbstractArray{Symbol}) = all(name in n.names for name in names)

@testset "ensure skipped indices are skipped" begin
@test_throws "did not find" [1, 2, 3, 4][Not([1.5])]
@test_throws "did not find" [1, 2, 3, 4][Not(Not([1.5]))]
# Without error checking/checkbounds, this segfaults with a large enough array:
@test_throws "did not find" rand(100)[Not(begin+.5:end)]
@test_broken @test_throws "invalid index" [1, 2, 3, 4][Not(Integer[true, 2])]

n = NamedVector(1:4, [:a, :b, :c, :d]);
@test_broken n[Not([:a,:b])] == n[Not(1:2)] == [3, 4]
@test_broken n[Not([:c,:d])] == n[Not(3:4)] == [1, 2]
@test n[Not(:a)] == n[Not(1)] == [2,3,4]
@test n[Not(:b)] == n[Not(2)] == [1,3,4]

n = NamedVector(1:4, [:d, :b, :c, :a]);
@test_broken n[Not([:a,:b])] == n[Not([4,2])]== n[[:d,:c]] == [1, 3]
@test_broken n[Not([:c,:d])] == n[Not([3,1])] == n[[:b,:a]] == [2, 4]
@test n[Not(:a)] == n[Not(4)] == [1,2,3]
@test n[Not(:b)] == n[Not(2)] == [1,3,4]
@test n[Not(:c)] == n[Not(3)] == [1,2,4]
@test n[Not(:d)] == n[Not(1)] == [2,3,4]
end

0 comments on commit b337d91

Please sign in to comment.