From 4b2b9c3b5b56890c39de16f9adf85b0306eeac8d Mon Sep 17 00:00:00 2001 From: jmacd Date: Fri, 22 Nov 2019 22:48:50 -0800 Subject: [PATCH 1/3] Add a Reset method --- varopt.go | 11 +++++++++++ varopt_test.go | 25 +++++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/varopt.go b/varopt.go index 44b5843..37598cc 100644 --- a/varopt.go +++ b/varopt.go @@ -54,6 +54,17 @@ func New(capacity int, rnd *rand.Rand) *Varopt { } } +// Reset returns the sampler to its initial state, maintaining its +// capacity and random number source. +func (s *Varopt) Reset() { + s.L = s.L[:0] + s.T = s.T[:0] + s.X = s.X[:0] + s.tau = 0 + s.totalCount = 0 + s.totalWeight = 0 +} + // Add considers a new observation for the sample with given weight. // // An error will be returned if the weight is either negative or NaN. diff --git a/varopt_test.go b/varopt_test.go index c7f0b68..2e1a9cd 100644 --- a/varopt_test.go +++ b/varopt_test.go @@ -161,3 +161,28 @@ func TestInvalidWeight(t *testing.T) { err = v.Add(nil, 0) require.Equal(t, err, varopt.ErrInvalidWeight) } + +func TestReset(t *testing.T) { + const capacity = 10 + const insert = 100 + rnd := rand.New(rand.NewSource(98887)) + v := varopt.New(capacity, rnd) + + sum := 0. + for i := 1.; i <= insert; i++ { + v.Add(nil, i) + sum += i + } + + require.Equal(t, capacity, v.Size()) + require.Equal(t, insert, v.TotalCount()) + require.Equal(t, sum, v.TotalWeight()) + require.Less(t, 0., v.Tau()) + + v.Reset() + + require.Equal(t, 0, v.Size()) + require.Equal(t, 0, v.TotalCount()) + require.Equal(t, 0., v.TotalWeight()) + require.Equal(t, 0., v.Tau()) +} From d2bfbc820e6da2b0fcda49c0485fdb2abb718349 Mon Sep 17 00:00:00 2001 From: jmacd Date: Fri, 22 Nov 2019 23:16:41 -0800 Subject: [PATCH 2/3] Return the ejected sample from Add --- varopt.go | 13 +++++++---- varopt_test.go | 58 +++++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 64 insertions(+), 7 deletions(-) diff --git a/varopt.go b/varopt.go index 37598cc..0b8a570 100644 --- a/varopt.go +++ b/varopt.go @@ -66,16 +66,18 @@ func (s *Varopt) Reset() { } // Add considers a new observation for the sample with given weight. +// If there is an item ejected from the same as a result, the item is +// returned. // // An error will be returned if the weight is either negative or NaN. -func (s *Varopt) Add(sample Sample, weight float64) error { +func (s *Varopt) Add(sample Sample, weight float64) (Sample, error) { individual := internal.Vsample{ Sample: sample, Weight: weight, } if weight <= 0 || math.IsNaN(weight) { - return ErrInvalidWeight + return nil, ErrInvalidWeight } s.totalCount++ @@ -83,7 +85,7 @@ func (s *Varopt) Add(sample Sample, weight float64) error { if s.Size() < s.capacity { s.L.Push(individual) - return nil + return nil, nil } // the X <- {} step from the paper is not done here, @@ -113,19 +115,22 @@ func (s *Varopt) Add(sample Sample, weight float64) error { r -= (1 - wxd/s.tau) d++ } + var eject Sample if r < 0 { if d < len(s.X) { s.X[d], s.X[len(s.X)-1] = s.X[len(s.X)-1], s.X[d] } + eject = s.X[len(s.X)-1].Sample s.X = s.X[:len(s.X)-1] } else { ti := s.rnd.Intn(len(s.T)) s.T[ti], s.T[len(s.T)-1] = s.T[len(s.T)-1], s.T[ti] + eject = s.T[len(s.T)-1].Sample s.T = s.T[:len(s.T)-1] } s.T = append(s.T, s.X...) s.X = s.X[:0] - return nil + return eject, nil } func (s *Varopt) uniform() float64 { diff --git a/varopt_test.go b/varopt_test.go index 2e1a9cd..484cf47 100644 --- a/varopt_test.go +++ b/varopt_test.go @@ -152,13 +152,13 @@ func TestInvalidWeight(t *testing.T) { rnd := rand.New(rand.NewSource(98887)) v := varopt.New(1, rnd) - err := v.Add(nil, math.NaN()) + _, err := v.Add(nil, math.NaN()) require.Equal(t, err, varopt.ErrInvalidWeight) - err = v.Add(nil, -1) + _, err = v.Add(nil, -1) require.Equal(t, err, varopt.ErrInvalidWeight) - err = v.Add(nil, 0) + _, err = v.Add(nil, 0) require.Equal(t, err, varopt.ErrInvalidWeight) } @@ -186,3 +186,55 @@ func TestReset(t *testing.T) { require.Equal(t, 0., v.TotalWeight()) require.Equal(t, 0., v.Tau()) } + +func TestEject(t *testing.T) { + const capacity = 100 + const rounds = 10000 + const maxvalue = 10000 + + entries := make([]int, capacity+1) + freelist := make([]*int, capacity+1) + + for i := range entries { + freelist[i] = &entries[i] + } + + // Make two deterministically equal samplers + rnd1 := rand.New(rand.NewSource(98887)) + rnd2 := rand.New(rand.NewSource(98887)) + vsrc := rand.New(rand.NewSource(98887)) + + expected := varopt.New(capacity, rnd1) + ejector := varopt.New(capacity, rnd2) + + for i := 0; i < rounds; i++ { + value := vsrc.Intn(maxvalue) + weight := vsrc.ExpFloat64() + + _, _ = expected.Add(&value, weight) + + lastitem := len(freelist) - 1 + item := freelist[lastitem] + freelist = freelist[:lastitem] + + *item = value + eject, _ := ejector.Add(item, weight) + + if eject != nil { + freelist = append(freelist, eject.(*int)) + } + } + + require.Equal(t, expected.Size(), ejector.Size()) + require.Equal(t, expected.TotalCount(), ejector.TotalCount()) + require.Equal(t, expected.TotalWeight(), ejector.TotalWeight()) + require.Equal(t, expected.Tau(), ejector.Tau()) + + for i := 0; i < capacity; i++ { + expectItem, expectWeight := expected.Get(i) + ejectItem, ejectWeight := expected.Get(i) + + require.Equal(t, *expectItem.(*int), *ejectItem.(*int)) + require.Equal(t, expectWeight, ejectWeight) + } +} From be3c37ae3ccdac4cb2b999e6fb494c6546665bbc Mon Sep 17 00:00:00 2001 From: jmacd Date: Fri, 22 Nov 2019 23:18:39 -0800 Subject: [PATCH 3/3] Typo --- varopt.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/varopt.go b/varopt.go index 0b8a570..99fd01c 100644 --- a/varopt.go +++ b/varopt.go @@ -66,8 +66,8 @@ func (s *Varopt) Reset() { } // Add considers a new observation for the sample with given weight. -// If there is an item ejected from the same as a result, the item is -// returned. +// If there is an item ejected from the sample as a result, the item +// is returned to allow re-use of memory. // // An error will be returned if the weight is either negative or NaN. func (s *Varopt) Add(sample Sample, weight float64) (Sample, error) {