From 6639eab0f1bd0c076f6f5fea321765e7bf69ee10 Mon Sep 17 00:00:00 2001 From: Karl Gaissmaier Date: Mon, 6 May 2024 15:13:53 +0200 Subject: [PATCH 1/2] simplify --- table.go | 252 +++++++++++++++++++++++++------------------------------ 1 file changed, 112 insertions(+), 140 deletions(-) diff --git a/table.go b/table.go index 60a7b69..f4cd59f 100644 --- a/table.go +++ b/table.go @@ -72,18 +72,11 @@ func (t *Table[V]) Insert(pfx netip.Prefix, val V) { lastOctetBits := bits - (lastOctetIdx * strideLen) // find the proper trie node to insert prefix - for i, octet := range octets { - if i == lastOctetIdx { - // insert prefix into node - n.insertPrefix(lastOctet, lastOctetBits, val) - return - } - + for _, octet := range octets[:lastOctetIdx] { // descend down to next trie level c := n.getChild(octet) - - // create and insert missing intermediate child if c == nil { + // create and insert missing intermediate child c = newNode[V]() n.insertChild(octet, c) } @@ -91,17 +84,22 @@ func (t *Table[V]) Insert(pfx netip.Prefix, val V) { // proceed with next level n = c } + + // insert prefix into node + n.insertPrefix(lastOctet, lastOctetBits, val) } -// Delete removes pfx from the tree, pfx does not have to be present. -func (t *Table[V]) Delete(pfx netip.Prefix) { +// Update or set the value at pfx with a callback function. +// The callback function is called with (value, ok) and returns a new value.. +// +// If the pfx does not already exist, it is set with the new value. +func (t *Table[V]) Update(pfx netip.Prefix, cb func(val V, ok bool) V) V { + t.init() + // some values derived from pfx _, ip, bits, is4 := pfxToValues(pfx) n := t.rootNodeByVersion(is4) - if n == nil { - return - } // do not allocate octets := make([]byte, 0, 16) @@ -112,64 +110,34 @@ func (t *Table[V]) Delete(pfx netip.Prefix) { lastOctet := octets[lastOctetIdx] lastOctetBits := bits - (lastOctetIdx * strideLen) - // record path to deleted node - // purge dangling nodes after deletion - stack := [maxTreeDepth]*node[V]{} - - // run variables, needed after for loop - var i int - var octet byte - - // find the trie node - for i, octet = range octets { - // push current node on stack for path recording - stack[i] = n - - if i == lastOctetIdx { - if !n.deletePrefix(lastOctet, lastOctetBits) { - // prefix not in tree, nothing deleted - return - } - - // escape, but purge dangling path if needed, see below - break - } - - // descend down to next level - if c := n.getChild(octet); c != nil { - n = c - continue + // find the proper trie node to update prefix + for _, octet := range octets[:lastOctetIdx] { + // descend down to next trie level + c := n.getChild(octet) + if c == nil { + // create and insert missing intermediate child + c = newNode[V]() + n.insertChild(octet, c) } - // no child found, nothing to delete - return + // proceed with next level + n = c } - // purge dangling paths - for i > 0 { - if n.isEmpty() { - // purge empty node from parents children - parent := stack[i-1] - parent.deleteChild(octets[i-1]) - } - - // unwind the stack - i-- - n = stack[i] - } + // update/insert prefix into node + return n.updatePrefix(lastOctet, lastOctetBits, cb) } -// Update or set the value at pfx with a callback function. -// The callback function is called with (value, ok) and returns a new value.. -// -// If the pfx does not already exist, it is set with the new value. -func (t *Table[V]) Update(pfx netip.Prefix, cb func(val V, ok bool) V) V { - t.init() - +// Get returns the associated payload for prefix and true, or false if +// prefix is not set in the routing table. +func (t *Table[V]) Get(pfx netip.Prefix) (val V, ok bool) { // some values derived from pfx _, ip, bits, is4 := pfxToValues(pfx) n := t.rootNodeByVersion(is4) + if n == nil { + return + } // do not allocate octets := make([]byte, 0, 16) @@ -180,32 +148,21 @@ func (t *Table[V]) Update(pfx netip.Prefix, cb func(val V, ok bool) V) V { lastOctet := octets[lastOctetIdx] lastOctetBits := bits - (lastOctetIdx * strideLen) - // find the proper trie node to update prefix - for i, octet := range octets { - if i == lastOctetIdx { - // update/insert prefix into node - return n.updatePrefix(lastOctet, lastOctetBits, cb) - } - - // descend down to next trie level + // find the proper trie node + for _, octet := range octets[:lastOctetIdx] { c := n.getChild(octet) - - // create and insert missing intermediate child if c == nil { - c = newNode[V]() - n.insertChild(octet, c) + // not found + return } - - // proceed with next level n = c + continue } - - panic("unreachable") + return n.getValByPrefix(lastOctet, lastOctetBits) } -// Get returns the associated payload for prefix and true, or false if -// prefix is not set in the routing table. -func (t *Table[V]) Get(pfx netip.Prefix) (val V, ok bool) { +// Delete removes pfx from the tree, pfx does not have to be present. +func (t *Table[V]) Delete(pfx netip.Prefix) { // some values derived from pfx _, ip, bits, is4 := pfxToValues(pfx) @@ -223,22 +180,50 @@ func (t *Table[V]) Get(pfx netip.Prefix) (val V, ok bool) { lastOctet := octets[lastOctetIdx] lastOctetBits := bits - (lastOctetIdx * strideLen) - // find the proper trie node to update prefix - for i, octet := range octets { + // record path to deleted node + // purge dangling nodes after deletion + stack := [maxTreeDepth]*node[V]{} + + // run variables, needed after for loop + var i int + + // find the trie node + for i = range octets { + // push current node on stack for path recording + stack[i] = n + if i == lastOctetIdx { - return n.getValByPrefix(lastOctet, lastOctetBits) + if !n.deletePrefix(lastOctet, lastOctetBits) { + // prefix not in tree, nothing deleted + return + } + + // escape, but purge dangling path if needed, see below + break } // descend down to next level - if c := n.getChild(octet); c != nil { + if c := n.getChild(octets[i]); c != nil { n = c continue } + // no child found, nothing to delete return } - panic("unreachable") + // purge dangling paths + for i > 0 { + if n.isEmpty() { + // purge empty node from parents children + parent := stack[i-1] + parent.deleteChild(octets[i-1]) + } + + // unwind the stack + i-- + n = stack[i] + } } // Lookup does a route lookup (longest prefix match) for IP and @@ -255,26 +240,19 @@ func (t *Table[V]) Lookup(ip netip.Addr) (val V, ok bool) { octets := make([]byte, 0, 16) octets = ipToOctets(octets, ip, is4) - lastOctetIdx := len(octets) - 1 - // stack of the traversed nodes for fast backtracking, if needed stack := [maxTreeDepth]*node[V]{} - // run variables, used after for loop + // run variable, used after for loop var i int - var octet byte // find leaf node - for i, octet = range octets { + for i = range octets { // push current node on stack for fast backtracking stack[i] = n - if i == lastOctetIdx { - break - } - // go down in tight loop to leaf node - if c := n.getChild(octet); c != nil { + if c := n.getChild(octets[i]); c != nil { n = c continue } @@ -284,8 +262,8 @@ func (t *Table[V]) Lookup(ip netip.Addr) (val V, ok bool) { // start backtracking at leaf node in tight loop for { - // longest prefix match? - if _, val, ok := n.lpmByOctet(octet); ok { + // longest prefix match, make inlining possible + if _, val, ok := n.lpmByIndex(octetToBaseIndex(octets[i])); ok { return val, true } @@ -296,7 +274,6 @@ func (t *Table[V]) Lookup(ip netip.Addr) (val V, ok bool) { // unwind the stack i-- - octet = octets[i] n = stack[i] } } @@ -350,35 +327,35 @@ func (t *Table[V]) lpmByPrefix(pfx netip.Prefix) (depth int, baseIdx uint, val V lastOctetIdx := (bits - 1) / strideLen lastOctetBits := bits - (lastOctetIdx * strideLen) - // default, only the lastOctet has a different prefix len - pfxLen := strideLen - - var octet byte + var pfxLen int // find the trie node - for depth, octet = range octets { + for depth = range octets { // push current node on stack stack[depth] = n - // last significant octet reached + // only the lastOctet has a different prefix len (prefix route) if depth == lastOctetIdx { - // only the lastOctet has a different prefix len (prefix route) pfxLen = lastOctetBits break } // go down in tight loop to leaf node - if c := n.getChild(octet); c != nil { + if c := n.getChild(octets[depth]); c != nil { n = c continue } + pfxLen = strideLen break } // start backtracking with last node and octet for { - if baseIdx, val, ok := n.lpmByPrefix(octet, pfxLen); ok { + pfxIdx := prefixToBaseIndex(octets[depth], pfxLen) + + // longest prefix match + if baseIdx, val, ok := n.lpmByIndex(pfxIdx); ok { return depth, baseIdx, val, true } @@ -387,10 +364,11 @@ func (t *Table[V]) lpmByPrefix(pfx netip.Prefix) (depth int, baseIdx uint, val V return } + // for all upper levels + pfxLen = strideLen + // unwind the stack depth-- - pfxLen = strideLen - octet = octets[depth] n = stack[depth] } } @@ -413,25 +391,25 @@ func (t *Table[V]) Subnets(pfx netip.Prefix) []netip.Prefix { lastOctet := octets[lastOctetIdx] lastOctetBits := bits - (lastOctetIdx * strideLen) + parentIdx := prefixToBaseIndex(lastOctet, lastOctetBits) + // find the trie node for i, octet := range octets { if i == lastOctetIdx { - result := n.subnets(octets[:i], prefixToBaseIndex(lastOctet, lastOctetBits), is4) - + result := n.subnets(octets[:i], parentIdx, is4) slices.SortFunc(result, cmpPrefix) return result } - // descend down to next level - if c := n.getChild(octet); c != nil { - n = c - continue + c := n.getChild(octet) + if c == nil { + break } - return nil + n = c + continue } - - panic("unreachable") + return nil } // Supernets, return all matching routes for pfx, @@ -458,22 +436,23 @@ func (t *Table[V]) Supernets(pfx netip.Prefix) []netip.Prefix { for i, octet := range octets { if i == lastOctetIdx { // make an all-prefix-match at last level - return append(result, n.apmByPrefix(lastOctet, lastOctetBits, i, ip)...) + result = append(result, n.apmByPrefix(lastOctet, lastOctetBits, i, ip)...) + break } // make an all-prefix-match at intermediate level for octet result = append(result, n.apmByOctet(octet, i, ip)...) // descend down to next trie level - if c := n.getChild(octet); c != nil { - n = c - continue + c := n.getChild(octet) + if c == nil { + break } - - return result + n = c + continue } - panic("unreachable") + return result } // OverlapsPrefix reports whether any IP in pfx matches a route in the table. @@ -495,27 +474,20 @@ func (t *Table[V]) OverlapsPrefix(pfx netip.Prefix) bool { lastOctet := octets[lastOctetIdx] lastOctetBits := bits - (lastOctetIdx * strideLen) - for i, octet := range octets { - if i == lastOctetIdx { - return n.overlapsPrefix(lastOctet, lastOctetBits) - } - - // still in the middle of prefix chunks - // test if any route overlaps prefix´ + for _, octet := range octets[:lastOctetIdx] { + // test if any route overlaps prefix´ so far if _, _, ok := n.lpmByOctet(octet); ok { return true } // no overlap so far, go down to next c - if c := n.getChild(octet); c != nil { - n = c - continue + c := n.getChild(octet) + if c == nil { + return false } - - return false + n = c } - - panic("unreachable") + return n.overlapsPrefix(lastOctet, lastOctetBits) } // Overlaps reports whether any IP in the table matches a route in the From 28f865d02d761aa73480af0a9129a06d49bf3982 Mon Sep 17 00:00:00 2001 From: Karl Gaissmaier Date: Mon, 6 May 2024 15:25:52 +0200 Subject: [PATCH 2/2] more tests --- slow_table_test.go | 26 ++++++++++ table_test.go | 127 +++++++++++++++++++++++++-------------------- 2 files changed, 98 insertions(+), 55 deletions(-) diff --git a/slow_table_test.go b/slow_table_test.go index 18e8890..161822a 100644 --- a/slow_table_test.go +++ b/slow_table_test.go @@ -33,6 +33,32 @@ func (s *slowRT[V]) insert(pfx netip.Prefix, val V) { s.entries = append(s.entries, slowRTEntry[V]{pfx, val}) } +func (s *slowRT[V]) get(pfx netip.Prefix) (val V, ok bool) { + pfx = pfx.Masked() + for _, ent := range s.entries { + if ent.pfx == pfx { + return ent.val, true + } + } + return val, false +} + +func (s *slowRT[V]) update(pfx netip.Prefix, cb func(V, bool) V) (val V) { + pfx = pfx.Masked() + for i, ent := range s.entries { + if ent.pfx == pfx { + // update val + s.entries[i].val = cb(ent.val, true) + return + } + } + // new val + val = cb(val, false) + + s.entries = append(s.entries, slowRTEntry[V]{pfx, val}) + return val +} + func (s *slowRT[T]) union(o *slowRT[T]) { for _, op := range o.entries { var match bool diff --git a/table_test.go b/table_test.go index dd7fb3a..c70d397 100644 --- a/table_test.go +++ b/table_test.go @@ -729,15 +729,6 @@ func TestInsertShuffled(t *testing.T) { t.Parallel() pfxs := randomPrefixes(1000) - /* uncomment for failure debugging - var pfxs2 []slowRTEntry[int] - defer func() { - if t.Failed() { - t.Logf("pre-shuffle: %#v", pfxs) - t.Logf("post-shuffle: %#v", pfxs2) - } - }() - */ for i := 0; i < 10; i++ { pfxs2 := append([]slowRTEntry[int](nil), pfxs...) @@ -792,21 +783,6 @@ func TestDeleteCompare(t *testing.T) { toDelete := append([]slowRTEntry[int](nil), all4[deleteCut:]...) toDelete = append(toDelete, all6[deleteCut:]...) - /* uncomment for failure debugging - defer func() { - if t.Failed() { - for _, pfx := range pfxs { - fmt.Printf("%q, ", pfx.pfx) - } - fmt.Println("") - for _, pfx := range toDelete { - fmt.Printf("%q, ", pfx.pfx) - } - fmt.Println("") - } - }() - */ - slow := slowRT[int]{pfxs} fast := Table[int]{} @@ -946,6 +922,15 @@ func TestDeleteIsReverseOfInsert(t *testing.T) { func TestGet(t *testing.T) { t.Parallel() + rt := new(Table[int]) + t.Run("empty table", func(t *testing.T) { + _, ok := rt.Get(randomPrefix4()) + + if ok { + t.Errorf("empty table: ok=%v, expected: %v", ok, false) + } + }) + tests := []struct { name string pfx netip.Prefix @@ -973,23 +958,7 @@ func TestGet(t *testing.T) { }, } - rt := new(Table[int]) - - t.Run("empty table", func(t *testing.T) { - _, ok := rt.Get(randomPrefix4()) - - if ok { - t.Errorf("empty table: ok=%v, expected: %v", ok, false) - } - }) - - t.Run("empty table", func(t *testing.T) { - _, ok := rt.Get(randomPrefix6()) - - if ok { - t.Errorf("empty table: ok=%v, expected: %v", ok, false) - } - }) + rt = new(Table[int]) for _, tt := range tests { rt.Insert(tt.pfx, tt.val) @@ -1008,15 +977,65 @@ func TestGet(t *testing.T) { } }) } +} - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - rt.Delete(tt.pfx) +func TestGetCompare(t *testing.T) { + t.Parallel() - if _, ok := rt.Get(tt.pfx); ok { - t.Errorf("%s: ok=%v, expected: %v", tt.name, ok, false) - } - }) + pfxs := randomPrefixes(10_000) + slow := slowRT[int]{pfxs} + fast := Table[int]{} + + for _, pfx := range pfxs { + fast.Insert(pfx.pfx, pfx.val) + } + + for _, pfx := range pfxs { + slowVal, slowOK := slow.get(pfx.pfx) + fastVal, fastOK := fast.Get(pfx.pfx) + + if !getsEqual(slowVal, slowOK, fastVal, fastOK) { + t.Fatalf("get(%q) = (%v, %v), want (%v, %v)", pfx.pfx, fastVal, fastOK, slowVal, slowOK) + } + } +} + +func TestUpdateCompare(t *testing.T) { + t.Parallel() + + pfxs := randomPrefixes(10_000) + slow := slowRT[int]{pfxs} + fast := Table[int]{} + + // Update as insert + for _, pfx := range pfxs { + fast.Update(pfx.pfx, func(int, bool) int { return pfx.val }) + } + + for _, pfx := range pfxs { + slowVal, slowOK := slow.get(pfx.pfx) + fastVal, fastOK := fast.Get(pfx.pfx) + + if !getsEqual(slowVal, slowOK, fastVal, fastOK) { + t.Fatalf("get(%q) = (%v, %v), want (%v, %v)", pfx.pfx, fastVal, fastOK, slowVal, slowOK) + } + } + + cb := func(val int, _ bool) int { return val + 1 } + + // Update as update + for _, pfx := range pfxs[:len(pfxs)/2] { + slow.update(pfx.pfx, cb) + fast.Update(pfx.pfx, cb) + } + + for _, pfx := range pfxs { + slowVal, slowOK := slow.get(pfx.pfx) + fastVal, fastOK := fast.Get(pfx.pfx) + + if !getsEqual(slowVal, slowOK, fastVal, fastOK) { + t.Fatalf("get(%q) = (%v, %v), want (%v, %v)", pfx.pfx, fastVal, fastOK, slowVal, slowOK) + } } } @@ -1055,8 +1074,9 @@ func TestUpdate(t *testing.T) { return 0 } + // update as insert for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { + t.Run(fmt.Sprintf("insert: %s", tt.name), func(t *testing.T) { val := rt.Update(tt.pfx, cb) got, ok := rt.Get(tt.pfx) @@ -1070,8 +1090,9 @@ func TestUpdate(t *testing.T) { }) } + // update as update for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { + t.Run(fmt.Sprintf("update: %s", tt.name), func(t *testing.T) { val := rt.Update(tt.pfx, cb) got, ok := rt.Get(tt.pfx) @@ -1113,14 +1134,12 @@ func TestOverlapsCompare(t *testing.T) { gotFast := fast.Overlaps(&fastInter) if gotSlow != gotFast { - t.Fatalf("Overlaps(...) = %v, want %v\nTable1:\n%s\nTable2:\n%v", + t.Fatalf("Overlaps(...) = %v, want %v\nTable1:\n%s\nTable:\n%v", gotFast, gotSlow, fast.String(), fastInter.String()) } seen[gotFast]++ } - - t.Log(seen) } func TestOverlapsPrefixCompare(t *testing.T) { @@ -1478,7 +1497,6 @@ func TestSubnetsEdgeCases(t *testing.T) { if !reflect.DeepEqual(got, tt.want) { t.Fatalf("%s: got:\n%v\nwant:\n%v", tt.name, got, tt.want) } - }) } } @@ -1578,7 +1596,6 @@ func TestSupernetsEdgeCases(t *testing.T) { if !reflect.DeepEqual(got, tt.want) { t.Fatalf("%s: got:\n%v\nwant:\n%v", tt.name, got, tt.want) } - }) } }