diff --git a/src/InvertedIndices.jl b/src/InvertedIndices.jl index bfabd0d..c619746 100644 --- a/src/InvertedIndices.jl +++ b/src/InvertedIndices.jl @@ -97,41 +97,53 @@ end struct InvertedIndexIterator{T,S,P} <: AbstractVector{T} skips::S picks::P + length::Int end -InvertedIndexIterator(skips, picks) = InvertedIndexIterator{eltype(picks), typeof(skips), typeof(picks)}(skips, picks) -Base.size(III::InvertedIndexIterator) = (length(III.picks) - length(III.skips),) +InvertedIndexIterator(skips, picks) = InvertedIndexIterator{eltype(picks), typeof(skips), typeof(picks)}(skips, picks, length(picks) - length(skips)) +Base.size(III::InvertedIndexIterator) = (III.length,) + +# Ensure iteration consumes all skips by the time it hits the end of the picks +assert_iteration_finished(I, n, ::Nothing) = (@assert n == I.length "InvertedIndexIterator iterated $n values but expected $(I.length)"; true) +assert_iteration_finished(I, _, (skipvalue, _)) = throw(ArgumentError("did not find index $skipvalue in axis $(I.picks), so could not skip it")) +# Ensure iteration does not generate more than I.length values +assert_iteration_not_finished(I, n, ::Nothing) = @assert n <= I.length "InvertedIndexIterator iterated more values than expected" +assert_iteration_not_finished(I, n, (skipvalue, _)) = n <= I.length || throw(ArgumentError("did not find index $skipvalue in axis $(I.picks), so could not skip it")) @inline function Base.iterate(I::InvertedIndexIterator) + n = 0 skipitr = iterate(I.skips) pickitr = iterate(I.picks) - pickitr === nothing && return nothing + pickitr === nothing && assert_iteration_finished(I, n, skipitr) && return nothing while should_skip(skipitr, pickitr) skipitr = iterate(I.skips, skipitr[2]) pickitr = iterate(I.picks, pickitr[2]) - pickitr === nothing && return nothing + pickitr === nothing && assert_iteration_finished(I, n, skipitr) && return nothing end + n += 1; assert_iteration_not_finished(I, n, skipitr) # This is a little silly, but splitting the tuple here allows inference to normalize # Tuple{Union{Nothing, Tuple}, Tuple} to Union{Tuple{Nothing, Tuple}, Tuple{Tuple, Tuple}} return skipitr === nothing ? - (pickitr[1], (nothing, pickitr[2])) : - (pickitr[1], (skipitr, pickitr[2])) + (pickitr[1], (nothing, pickitr[2], n)) : + (pickitr[1], (skipitr, pickitr[2], n)) end -@inline function Base.iterate(I::InvertedIndexIterator, (_, pickstate)::Tuple{Nothing, Any}) +@inline function Base.iterate(I::InvertedIndexIterator, (_, pickstate, n)::Tuple{Nothing, Any, Any}) pickitr = iterate(I.picks, pickstate) - pickitr === nothing && return nothing - return (pickitr[1], (nothing, pickitr[2])) + pickitr === nothing && assert_iteration_finished(I, n, nothing) && return nothing + n += 1; assert_iteration_not_finished(I, n, nothing) + return (pickitr[1], (nothing, pickitr[2], n)) end -@inline function Base.iterate(I::InvertedIndexIterator, (skipitr, pickstate)::Tuple) +@inline function Base.iterate(I::InvertedIndexIterator, (skipitr, pickstate, n)::Tuple) pickitr = iterate(I.picks, pickstate) - pickitr === nothing && return nothing + pickitr === nothing && assert_iteration_finished(I, n, skipitr) && return nothing while should_skip(skipitr, pickitr) skipitr = iterate(I.skips, tail(skipitr)...) pickitr = iterate(I.picks, tail(pickitr)...) - pickitr === nothing && return nothing + pickitr === nothing && assert_iteration_finished(I, n, skipitr) && return nothing end + n += 1; assert_iteration_not_finished(I, n, skipitr) return skipitr === nothing ? - (pickitr[1], (nothing, pickitr[2])) : - (pickitr[1], (skipitr, pickitr[2])) + (pickitr[1], (nothing, pickitr[2], n)) : + (pickitr[1], (skipitr, pickitr[2], n)) end function Base.collect(III::InvertedIndexIterator{T}) where {T} !isconcretetype(T) && return [i for i in III] # use widening if T is not concrete diff --git a/test/runtests.jl b/test/runtests.jl index d9ab763..a69a7bc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -203,3 +203,38 @@ returns(val) = _->val @test @inferred(LinearIndices(arr)[collect(I)]) == vec(filter(!iseven, arr)) end end + +struct NamedVector{T,A,B} <: AbstractArray{T,1} + data::A + names::B +end +function NamedVector(data, names) + @assert size(data) == size(names) + NamedVector{eltype(data), typeof(data), typeof(names)}(data, names) +end +Base.size(n::NamedVector) = size(n.data) +Base.getindex(n::NamedVector, i::Int) = n.data[i] +Base.to_index(n::NamedVector, name::Symbol) = findfirst(==(name), n.names) +Base.checkbounds(::Type{Bool}, n::NamedVector, names::AbstractArray{Symbol}) = all(name in n.names for name in names) + +@testset "ensure skipped indices are skipped" begin + @test_throws "did not find" [1, 2, 3, 4][Not([1.5])] + @test_throws "did not find" [1, 2, 3, 4][Not(Not([1.5]))] + # Without error checking/checkbounds, this segfaults with a large enough array: + @test_throws "did not find" rand(100)[Not(begin+.5:end)] + @test_broken @test_throws "invalid index" [1, 2, 3, 4][Not(Integer[true, 2])] + + n = NamedVector(1:4, [:a, :b, :c, :d]); + @test_broken n[Not([:a,:b])] == n[Not(1:2)] == [3, 4] + @test_broken n[Not([:c,:d])] == n[Not(3:4)] == [1, 2] + @test n[Not(:a)] == n[Not(1)] == [2,3,4] + @test n[Not(:b)] == n[Not(2)] == [1,3,4] + + n = NamedVector(1:4, [:d, :b, :c, :a]); + @test_broken n[Not([:a,:b])] == n[Not([4,2])]== n[[:d,:c]] == [1, 3] + @test_broken n[Not([:c,:d])] == n[Not([3,1])] == n[[:b,:a]] == [2, 4] + @test n[Not(:a)] == n[Not(4)] == [1,2,3] + @test n[Not(:b)] == n[Not(2)] == [1,3,4] + @test n[Not(:c)] == n[Not(3)] == [1,2,4] + @test n[Not(:d)] == n[Not(1)] == [2,3,4] +end