Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(iterator): ensure HasNext does not advance state #47

Merged
merged 1 commit into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 54 additions & 28 deletions tree_iterator.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
package art

import "errors"

// state represents the iteration state during tree traversal.
type state struct {
items []*iteratorContext
}

// push adds a new iterator context to the state.
func (s *state) push(ctx *iteratorContext) {
s.items = append(s.items, ctx)
}

// current returns the current iterator context and a flag indicating if there is any.
func (s *state) current() (*iteratorContext, bool) {
if len(s.items) == 0 {
return nil, false
Expand All @@ -17,6 +21,7 @@ func (s *state) current() (*iteratorContext, bool) {
return s.items[len(s.items)-1], true
}

// discard removes the last iterator context from the state.
func (s *state) discard() {
if len(s.items) == 0 {
return
Expand Down Expand Up @@ -67,6 +72,7 @@ type iterator struct {
// assert that iterator implements the Iterator interface.
var _ Iterator = (*iterator)(nil)

// newTreeIterator creates a new tree iterator.
func newTreeIterator(tr *tree, opts traverseOpts) Iterator {
state := &state{}
state.push(newIteratorContext(tr.root, opts.hasReverse()))
Expand All @@ -83,10 +89,15 @@ func newTreeIterator(tr *tree, opts traverseOpts) Iterator {
return it
}

return &bufferedIterator{
bit := &bufferedIterator{
opts: opts,
it: it,
}

// peek the first node or leaf
bit.peek()

return bit
}

// hasConcurrentModification checks if the tree has been modified concurrently.
Expand Down Expand Up @@ -148,52 +159,67 @@ type bufferedIterator struct {
nextErr error
}

// HasNext returns true if there are more nodes to iterate.
func (bit *bufferedIterator) HasNext() bool {
for bit.hasNext() {
nxt, err := bit.peek()
if err != nil {
return true
}
return bit.nextNode != nil
}

// are we looking for a leaf node?
if bit.hasLeafIterator() && nxt.Kind() == Leaf {
return true
}
// Next returns the next node or leaf node and an error if any.
// ErrNoMoreNodes is returned if there are no more nodes to iterate.
// ErrConcurrentModification is returned if the tree has been modified concurrently.
func (bit *bufferedIterator) Next() (Node, error) {
current := bit.nextNode

// are we looking for a non-leaf node?
if bit.hasNodeIterator() && nxt.Kind() != Leaf {
return true
}
if !bit.HasNext() {
return nil, bit.nextErr
}

bit.resetNext()
bit.peek()

return false
}
// ErrConcurrentModification should be returned immediately.
// ErrNoMoreNodes will be return on the next call.
if errors.Is(bit.nextErr, ErrConcurrentModification) {
return nil, bit.nextErr
}

func (bit *bufferedIterator) Next() (Node, error) {
return bit.nextNode, bit.nextErr
return current, nil
}

// hasLeafIterator checks if the iterator is for leaf nodes.
func (bit *bufferedIterator) hasLeafIterator() bool {
return bit.opts&TraverseLeaf == TraverseLeaf
}

// hasNodeIterator checks if the iterator is for non-leaf nodes.
func (bit *bufferedIterator) hasNodeIterator() bool {
return bit.opts&TraverseNode == TraverseNode
}

func (bit *bufferedIterator) hasNext() bool {
return bit.it.HasNext()
// peek looks for the next node or leaf node to iterate.
func (bit *bufferedIterator) peek() {
for {
bit.nextNode, bit.nextErr = bit.it.Next()
if bit.nextErr != nil {
return
}

if bit.matchesFilter() {
return
}
}
}

func (bit *bufferedIterator) peek() (Node, error) {
bit.nextNode, bit.nextErr = bit.it.Next()
// matchesFilter checks if the next node matches the iterator filter.
func (bit *bufferedIterator) matchesFilter() bool {
// check if the iterator is looking for leaf nodes
if bit.hasLeafIterator() && bit.nextNode.Kind() == Leaf {
return true
}

return bit.nextNode, bit.nextErr
}
// check if the iterator is looking for non-leaf nodes
if bit.hasNodeIterator() && bit.nextNode.Kind() != Leaf {
return true
}

func (bit *bufferedIterator) resetNext() {
bit.nextNode = nil
bit.nextErr = nil
return false
}
59 changes: 59 additions & 0 deletions tree_traversal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -529,4 +529,63 @@ func TestTreeIterateWordsStats(t *testing.T) {

stats = collectStats(tree.Iterator(TraverseNode))
assert.Equal(t, treeStats{0, 113419, 10433, 403, 1}, stats)

// by default Iterator traverses only leaf nodes
stats = collectStats(tree.Iterator())
assert.Equal(t, treeStats{235886, 0, 0, 0, 0}, stats)
}

func TestIteratorHasNextDoesNotAdvanceState(t *testing.T) {
t.Parallel()

tree := newTree()
tree.Insert(Key("1"), []byte{1})
tree.Insert(Key("2"), []byte{2})

iter := tree.Iterator()

// HasNext should not advance the iterator state
assert.True(t, iter.HasNext())
assert.True(t, iter.HasNext())

// change the iterator state
n, err := iter.Next()
require.NoError(t, err)
assert.Equal(t, Key("1"), n.Key())

// HasNext remains idempotent
assert.True(t, iter.HasNext())
assert.True(t, iter.HasNext())

// advance to the second key
n, err = iter.Next()
require.NoError(t, err)
assert.Equal(t, Key("2"), n.Key())

// HasNext returns false at the end
assert.False(t, iter.HasNext())
assert.False(t, iter.HasNext())

// calling Next after the iterator is exhausted
for i := 0; i < 2; i++ {
n, err = iter.Next()
assert.Nil(t, n, "Next() should return nil after exhaustion")
assert.Equal(t, ErrNoMoreNodes, err, "Next() should return ErrNoMoreNodes after exhaustion")
}
}

func TestIteratorEmptyTreeBehavior(t *testing.T) {
t.Parallel()

tree := New()
iter := tree.Iterator()

// HasNext should return false for an empty tree
assert.False(t, iter.HasNext())
assert.False(t, iter.HasNext())

// Next should return nil and ErrNoMoreNodes for an empty tree
n, err := iter.Next()
assert.Nil(t, n)
assert.Equal(t, ErrNoMoreNodes, err)
}
Loading