diff --git a/mapping/document.go b/mapping/document.go index 9f5aea581..2a702f323 100644 --- a/mapping/document.go +++ b/mapping/document.go @@ -50,7 +50,8 @@ type DocumentMapping struct { StructTagKey string `json:"struct_tag_key,omitempty"` } -func (dm *DocumentMapping) Validate(cache *registry.Cache) error { +func (dm *DocumentMapping) Validate(cache *registry.Cache, + parentName string, fieldAliasCtx map[string]*FieldMapping) error { var err error if dm.DefaultAnalyzer != "" { _, err := cache.AnalyzerNamed(dm.DefaultAnalyzer) @@ -58,8 +59,12 @@ func (dm *DocumentMapping) Validate(cache *registry.Cache) error { return err } } - for _, property := range dm.Properties { - err = property.Validate(cache) + for propertyName, property := range dm.Properties { + newParent := propertyName + if parentName != "" { + newParent = fmt.Sprintf("%s.%s", parentName, propertyName) + } + err = property.Validate(cache, newParent, fieldAliasCtx) if err != nil { return err } @@ -78,21 +83,24 @@ func (dm *DocumentMapping) Validate(cache *registry.Cache) error { } } - err := validateFieldType(field.Type) + err := validateFieldMapping(field, parentName, fieldAliasCtx) if err != nil { return err } - - if field.Type == "vector" { - err := validateVectorField(field) - if err != nil { - return err - } - } } return nil } +func validateFieldType(field *FieldMapping) error { + switch field.Type { + case "text", "datetime", "number", "boolean", "geopoint", "geoshape", "IP": + return nil + default: + return fmt.Errorf("field: '%s', unknown field type: '%s'", + field.Name, field.Type) + } +} + // analyzerNameForPath attempts to first find the field // described by this path, then returns the analyzer // configured for that field diff --git a/mapping/index.go b/mapping/index.go index 1c08bc589..171ee1a72 100644 --- a/mapping/index.go +++ b/mapping/index.go @@ -174,12 +174,14 @@ func (im *IndexMappingImpl) Validate() error { if err != nil { return err } - err = im.DefaultMapping.Validate(im.cache) + + fieldAliasCtx := make(map[string]*FieldMapping) + err = im.DefaultMapping.Validate(im.cache, "", fieldAliasCtx) if err != nil { return err } for _, docMapping := range im.TypeMapping { - err = docMapping.Validate(im.cache) + err = docMapping.Validate(im.cache, "", fieldAliasCtx) if err != nil { return err } diff --git a/mapping/mapping_no_vectors.go b/mapping/mapping_no_vectors.go index f4987596a..b5e033a62 100644 --- a/mapping/mapping_no_vectors.go +++ b/mapping/mapping_no_vectors.go @@ -17,8 +17,6 @@ package mapping -import "fmt" - func NewVectorFieldMapping() *FieldMapping { return nil } @@ -31,16 +29,7 @@ func (fm *FieldMapping) processVector(propertyMightBeVector interface{}, // ----------------------------------------------------------------------------- // document validation functions -func validateVectorField(fieldMapping *FieldMapping) error { - return nil -} - -func validateFieldType(fieldType string) error { - switch fieldType { - case "text", "datetime", "number", "boolean", "geopoint", "geoshape", "IP": - default: - return fmt.Errorf("unknown field type: '%s'", fieldType) - } - - return nil +func validateFieldMapping(field *FieldMapping, parentName string, + fieldAliasCtx map[string]*FieldMapping) error { + return validateFieldType(field) } diff --git a/mapping/mapping_vectors.go b/mapping/mapping_vectors.go index 3289b7162..1e9268261 100644 --- a/mapping/mapping_vectors.go +++ b/mapping/mapping_vectors.go @@ -26,6 +26,12 @@ import ( index "github.com/blevesearch/bleve_index_api" ) +// Min and Max allowed dimensions for a vector field +const ( + MinVectorDims = 1 + MaxVectorDims = 2048 +) + func NewVectorFieldMapping() *FieldMapping { return &FieldMapping{ Type: "vector", @@ -136,12 +142,22 @@ func (fm *FieldMapping) processVector(propertyMightBeVector interface{}, // ----------------------------------------------------------------------------- // document validation functions -func validateVectorField(field *FieldMapping) error { - if field.Dims <= 0 || field.Dims > 2048 { - return fmt.Errorf("invalid vector dimension,"+ - " value should be in range (%d, %d)", 0, 2048) +func validateFieldMapping(field *FieldMapping, parentName string, + fieldAliasCtx map[string]*FieldMapping) error { + switch field.Type { + case "vector": + return validateVectorFieldAlias(field, parentName, fieldAliasCtx) + default: // non-vector field + return validateFieldType(field) } +} +func validateVectorFieldAlias(field *FieldMapping, parentName string, + fieldAliasCtx map[string]*FieldMapping) error { + + if field.Name == "" { + field.Name = parentName + } if field.Similarity == "" { field.Similarity = index.DefaultSimilarityMetric } @@ -154,21 +170,40 @@ func validateVectorField(field *FieldMapping) error { field.DocValues = false field.SkipFreqNorm = true + // # If alias is present, validate the field options as per the alias + // note: reading from a nil map is safe + if fieldAlias, ok := fieldAliasCtx[field.Name]; ok { + if field.Dims != fieldAlias.Dims { + return fmt.Errorf("field: '%s', invalid alias "+ + "(different dimensions %d and %d)", fieldAlias.Name, field.Dims, + fieldAlias.Dims) + } + + if field.Similarity != fieldAlias.Similarity { + return fmt.Errorf("field: '%s', invalid alias "+ + "(different similarity values %s and %s)", fieldAlias.Name, + field.Similarity, fieldAlias.Similarity) + } + + return nil + } + + // # Validate field options + + if field.Dims < MinVectorDims || field.Dims > MaxVectorDims { + return fmt.Errorf("field: '%s', invalid vector dimension: %d,"+ + " value should be in range (%d, %d)", field.Name, field.Dims, + MinVectorDims, MaxVectorDims) + } + if _, ok := index.SupportedSimilarityMetrics[field.Similarity]; !ok { - return fmt.Errorf("invalid similarity metric: '%s', "+ - "valid metrics are: %+v", field.Similarity, + return fmt.Errorf("field: '%s', invalid similarity "+ + "metric: '%s', valid metrics are: %+v", field.Name, field.Similarity, reflect.ValueOf(index.SupportedSimilarityMetrics).MapKeys()) } - return nil -} - -func validateFieldType(fieldType string) error { - switch fieldType { - case "text", "datetime", "number", "boolean", "geopoint", "geoshape", - "IP", "vector": - default: - return fmt.Errorf("unknown field type: '%s'", fieldType) + if fieldAliasCtx != nil { // writing to a nil map is unsafe + fieldAliasCtx[field.Name] = field } return nil diff --git a/mapping/mapping_vectors_test.go b/mapping/mapping_vectors_test.go index de3f426f7..1decb13ee 100644 --- a/mapping/mapping_vectors_test.go +++ b/mapping/mapping_vectors_test.go @@ -17,9 +17,186 @@ package mapping -import ( - "testing" -) +import "testing" + +func TestVectorFieldAliasValidation(t *testing.T) { + tests := []struct { + // input + name string // name of the test + mappingStr string // index mapping json string + + // expected output + expValidity bool // validity of the mapping + errMsg string // error message, given expValidity is false + }{ + { + name: "test1", + mappingStr: ` + { + "default_mapping": { + "properties": { + "cityVec": { + "fields": [ + { + "type": "vector", + "dims": 3 + }, + { + "name": "cityVec", + "type": "vector", + "dims": 4 + } + ] + } + } + } + }`, + expValidity: false, + errMsg: `field: 'cityVec', invalid alias (different dimensions 4 and 3)`, + }, + { + name: "test2", + mappingStr: ` + { + "default_mapping": { + "properties": { + "cityVec": { + "fields": [ + { + "type": "vector", + "dims": 3, + "similarity": "l2_norm" + }, + { + "name": "cityVec", + "type": "vector", + "dims": 3, + "similarity": "dot_product" + } + ] + } + } + } + }`, + expValidity: false, + errMsg: `field: 'cityVec', invalid alias (different similarity values dot_product and l2_norm)`, + }, + { + name: "test3", + mappingStr: ` + { + "default_mapping": { + "properties": { + "cityVec": { + "fields": [ + { + "type": "vector", + "dims": 3 + }, + { + "name": "cityVec", + "type": "vector", + "dims": 3 + } + ] + } + } + } + }`, + expValidity: true, + errMsg: "", + }, + { + name: "test4", + mappingStr: ` + { + "default_mapping": { + "properties": { + "cityVec": { + "fields": [ + { + "name": "vecData", + "type": "vector", + "dims": 4 + } + ] + }, + "countryVec": { + "fields": [ + { + "name": "vecData", + "type": "vector", + "dims": 3 + } + ] + } + } + } + }`, + expValidity: false, + errMsg: `field: 'vecData', invalid alias (different dimensions 3 and 4)`, + }, + { + name: "test5", + mappingStr: ` + { + "default_mapping": { + "properties": { + "cityVec": { + "fields": [ + { + "name": "vecData", + "type": "vector", + "dims": 3 + } + ] + } + } + }, + "types": { + "type1": { + "properties": { + "cityVec": { + "fields": [ + { + "name": "vecData", + "type": "vector", + "dims": 4 + } + ] + } + } + } + } + }`, + expValidity: false, + errMsg: `field: 'vecData', invalid alias (different dimensions 4 and 3)`, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + im := NewIndexMapping() + err := im.UnmarshalJSON([]byte(test.mappingStr)) + if err != nil { + t.Fatalf("failed to unmarshal index mapping: %v", err) + } + + err = im.Validate() + isValid := err == nil + if test.expValidity != isValid { + t.Fatalf("validity mismatch, expected: %v, got: %v", + test.expValidity, isValid) + } + + if !isValid && err.Error() != test.errMsg { + t.Fatalf("invalid error message, expected: %v, got: %v", + test.errMsg, err.Error()) + } + }) + } +} + // A test case for processVector function type vectorTest struct {