diff --git a/pkg/specter/unitloading.go b/pkg/specter/unitloading.go index 24bbc68..e4ad801 100644 --- a/pkg/specter/unitloading.go +++ b/pkg/specter/unitloading.go @@ -119,9 +119,25 @@ func (g UnitGroup) Merge(group UnitGroup) UnitGroup { type UnitMatcher func(u Unit) bool -func UnitWithKindMatcher(kind UnitKind) UnitMatcher { +func UnitWithKindsMatcher(kinds ...UnitKind) UnitMatcher { return func(u Unit) bool { - return u.Kind() == kind + for _, kind := range kinds { + if u.Kind() == kind { + return true + } + } + return false + } +} + +func UnitWithIDsMatcher(id ...UnitID) UnitMatcher { + return func(u Unit) bool { + for _, id := range id { + if u.ID() == id { + return true + } + } + return false } } @@ -144,40 +160,14 @@ func (g UnitGroup) Find(m UnitMatcher) (Unit, bool) { return u, true } } - return nil, false -} -func (g UnitGroup) SelectType(t UnitKind) UnitGroup { - return g.Select(func(u Unit) bool { - return u.Kind() == t - }) -} - -func (g UnitGroup) SelectName(t UnitID) Unit { - for _, u := range g { - if u.ID() == t { - return u - } - } - - return nil -} - -func (g UnitGroup) SelectNames(names ...UnitID) UnitGroup { - return g.Select(func(u Unit) bool { - for _, name := range names { - if u.ID() == name { - return true - } - } - return false - }) + return nil, false } -func (g UnitGroup) Exclude(p func(u Unit) bool) UnitGroup { +func (g UnitGroup) Exclude(m UnitMatcher) UnitGroup { r := UnitGroup{} for _, u := range g { - if !p(u) { + if !m(u) { r = append(r, u) } } @@ -185,23 +175,6 @@ func (g UnitGroup) Exclude(p func(u Unit) bool) UnitGroup { return r } -func (g UnitGroup) ExcludeType(t UnitKind) UnitGroup { - return g.Exclude(func(u Unit) bool { - return u.Kind() == t - }) -} - -func (g UnitGroup) ExcludeNames(names ...UnitID) UnitGroup { - return g.Exclude(func(u Unit) bool { - for _, name := range names { - if u.ID() == name { - return true - } - } - return false - }) -} - // MapUnitGroup performs a map operation on a UnitGroup func MapUnitGroup[T any](g UnitGroup, p func(u Unit) T) []T { var mapped []T diff --git a/pkg/specter/unitloading_test.go b/pkg/specter/unitloading_test.go index da9dc7b..74bce95 100644 --- a/pkg/specter/unitloading_test.go +++ b/pkg/specter/unitloading_test.go @@ -154,117 +154,6 @@ func TestUnitGroup_Select(t *testing.T) { } } -func TestUnitGroup_SelectType(t *testing.T) { - tests := []struct { - name string - given specter.UnitGroup - when specter.UnitKind - then specter.UnitGroup - }{ - { - name: "GIVEN no units matches, THEN return an empty group", - given: specter.UnitGroup{ - testutils.NewUnitStub("unit2name", "type", specter.Source{}), - }, - when: "not_found", - then: specter.UnitGroup{}, - }, - { - name: "GIVEN a unit matches, THEN return a group with matching unit", - given: specter.UnitGroup{ - testutils.NewUnitStub("unit1", "type1", specter.Source{}), - testutils.NewUnitStub("unit2", "type2", specter.Source{}), - }, - when: "type1", - then: specter.UnitGroup{ - testutils.NewUnitStub("unit1", "type1", specter.Source{}), - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := tt.given.SelectType(tt.when) - require.Equal(t, tt.then, got) - }) - } -} - -func TestUnitGroup_SelectName(t *testing.T) { - tests := []struct { - name string - given specter.UnitGroup - when specter.UnitID - then specter.Unit - }{ - { - name: "GIVEN a group with multiple units WHEN selecting an existing name THEN return the corresponding unit", - given: specter.NewUnitGroup( - testutils.NewUnitStub("unit1", "type", specter.Source{}), - testutils.NewUnitStub("unit2", "type", specter.Source{}), - ), - when: "unit2", - then: testutils.NewUnitStub("unit2", "type", specter.Source{}), - }, - { - name: "GIVEN a group with multiple units WHEN selecting a non-existent name THEN return nil", - given: specter.NewUnitGroup( - testutils.NewUnitStub("unit1", "type", specter.Source{}), - testutils.NewUnitStub("unit2", "type", specter.Source{}), - ), - when: "spec3", - then: nil, - }, - { - name: "GIVEN an empty group WHEN selecting a name THEN return nil", - given: specter.NewUnitGroup(), - when: "unit1", - then: nil, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := tt.given.SelectName(tt.when) - require.Equal(t, tt.then, got) - }) - } -} - -func TestUnitGroup_SelectNames(t *testing.T) { - tests := []struct { - name string - given specter.UnitGroup - when []specter.UnitID - then specter.UnitGroup - }{ - { - name: "GIVEN no units matches, THEN return a group with no values", - given: specter.UnitGroup{ - testutils.NewUnitStub("name", "type", specter.Source{}), - }, - when: []specter.UnitID{"not_found"}, - then: specter.UnitGroup{}, - }, - { - name: "GIVEN a unit matches, THEN return a group with matching unit", - given: specter.UnitGroup{ - testutils.NewUnitStub("unit1", "type", specter.Source{}), - testutils.NewUnitStub("unit2", "type", specter.Source{}), - }, - when: []specter.UnitID{"unit1"}, - then: specter.UnitGroup{ - testutils.NewUnitStub("unit1", "type", specter.Source{}), - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := tt.given.SelectNames(tt.when...) - require.Equal(t, tt.then, got) - }) - } -} - func TestUnitGroup_Exclude(t *testing.T) { tests := []struct { name string @@ -304,80 +193,6 @@ func TestUnitGroup_Exclude(t *testing.T) { } } -func TestUnitGroup_ExcludeType(t *testing.T) { - tests := []struct { - name string - given specter.UnitGroup - when specter.UnitKind - then specter.UnitGroup - }{ - { - name: "GIVEN no units matches, THEN return a group with the same values", - given: specter.UnitGroup{ - testutils.NewUnitStub("unit2name", "type", specter.Source{}), - }, - when: "not_found", - then: specter.UnitGroup{ - testutils.NewUnitStub("unit2name", "type", specter.Source{}), - }, - }, - { - name: "GIVEN a unit matches, THEN return a group without matching unit", - given: specter.UnitGroup{ - testutils.NewUnitStub("unit1", "type1", specter.Source{}), - testutils.NewUnitStub("unit2", "type2", specter.Source{}), - }, - when: "type1", - then: specter.UnitGroup{ - testutils.NewUnitStub("unit2", "type2", specter.Source{}), - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := tt.given.ExcludeType(tt.when) - require.Equal(t, tt.then, got) - }) - } -} - -func TestUnitGroup_ExcludeNames(t *testing.T) { - tests := []struct { - name string - given specter.UnitGroup - when []specter.UnitID - then specter.UnitGroup - }{ - { - name: "GIVEN no units matches, THEN return a group with the same values", - given: specter.UnitGroup{ - testutils.NewUnitStub("unit2name", "type", specter.Source{}), - }, - when: []specter.UnitID{"not_found"}, - then: specter.UnitGroup{ - testutils.NewUnitStub("unit2name", "type", specter.Source{}), - }, - }, - { - name: "GIVEN a unit matches, THEN return a group without matching unit", - given: specter.UnitGroup{ - testutils.NewUnitStub("unit1", "type", specter.Source{}), - testutils.NewUnitStub("unit2", "type", specter.Source{}), - }, - when: []specter.UnitID{"unit1"}, - then: specter.UnitGroup{ - testutils.NewUnitStub("unit2", "type", specter.Source{}), - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := tt.given.ExcludeNames(tt.when...) - require.Equal(t, tt.then, got) - }) - } -} - func TestMapUnitGroup(t *testing.T) { tests := []struct { name string @@ -520,7 +335,42 @@ func TestUnitWithKindMatcher(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := specter.UnitWithKindMatcher(tt.when.kind)(tt.when.unit) + got := specter.UnitWithKindsMatcher(tt.when.kind)(tt.when.unit) + assert.Equal(t, tt.then, got) + }) + } +} + +func TestUnitWithIDsMatcher(t *testing.T) { + type when struct { + id specter.UnitID + unit specter.Unit + } + tests := []struct { + name string + when when + then bool + }{ + { + name: "WHEN unit with id THEN return true", + when: when{ + id: "unit1", + unit: testutils.NewUnitStub("unit1", "kind", specter.Source{}), + }, + then: true, + }, + { + name: "WHEN unit not with id THEN return false", + when: when{ + id: "id", + unit: testutils.NewUnitStub("not_id", "kind", specter.Source{}), + }, + then: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := specter.UnitWithIDsMatcher(tt.when.id)(tt.when.unit) assert.Equal(t, tt.then, got) }) }