Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MB-59421: add validation for vector field aliases #1903

Merged
merged 7 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions mapping/document.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,21 @@ type DocumentMapping struct {
StructTagKey string `json:"struct_tag_key,omitempty"`
}

func (dm *DocumentMapping) Validate(cache *registry.Cache) error {
func (dm *DocumentMapping) Validate(cache *registry.Cache,
parentName string, fieldAliasCtx map[string]*FieldMapping) error {
var err error
if dm.DefaultAnalyzer != "" {
_, err := cache.AnalyzerNamed(dm.DefaultAnalyzer)
if err != nil {
return err
}
}
for _, property := range dm.Properties {
err = property.Validate(cache)
for propertyName, property := range dm.Properties {
newParent := propertyName
if parentName != "" {
newParent = fmt.Sprintf("%s.%s", parentName, propertyName)
}
err = property.Validate(cache, newParent, fieldAliasCtx)
if err != nil {
return err
}
Expand All @@ -78,21 +83,24 @@ func (dm *DocumentMapping) Validate(cache *registry.Cache) error {
}
}

err := validateFieldType(field.Type)
err := validateFieldMapping(field, parentName, fieldAliasCtx)
if err != nil {
return err
}

if field.Type == "vector" {
err := validateVectorField(field)
if err != nil {
return err
}
}
}
return nil
}

// validateFieldType checks that field.Type is one of the field types
// recognized by this function. It returns nil for a recognized type and
// otherwise an error naming both the field and the offending type.
func validateFieldType(field *FieldMapping) error {
	switch fieldType := field.Type; fieldType {
	case "text", "datetime", "number", "boolean", "geopoint", "geoshape", "IP":
		// recognized type, nothing more to check here
		return nil
	}

	return fmt.Errorf("field: '%s', unknown field type: '%s'",
		field.Name, field.Type)
}

// analyzerNameForPath attempts to first find the field
// described by this path, then returns the analyzer
// configured for that field
Expand Down
6 changes: 4 additions & 2 deletions mapping/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,14 @@ func (im *IndexMappingImpl) Validate() error {
if err != nil {
return err
}
err = im.DefaultMapping.Validate(im.cache)

fieldAliasCtx := make(map[string]*FieldMapping)
err = im.DefaultMapping.Validate(im.cache, "", fieldAliasCtx)
if err != nil {
return err
}
for _, docMapping := range im.TypeMapping {
err = docMapping.Validate(im.cache)
err = docMapping.Validate(im.cache, "", fieldAliasCtx)
if err != nil {
return err
}
Expand Down
17 changes: 3 additions & 14 deletions mapping/mapping_no_vectors.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package mapping

import "fmt"

// NewVectorFieldMapping always returns nil in this file
// (mapping/mapping_no_vectors.go); no vector field mapping is constructed.
// NOTE(review): presumably this is the variant compiled when vector support
// is disabled, so callers would need to tolerate a nil result — confirm.
func NewVectorFieldMapping() *FieldMapping {
	return nil
}
Expand All @@ -31,16 +29,7 @@ func (fm *FieldMapping) processVector(propertyMightBeVector interface{},
// -----------------------------------------------------------------------------
// document validation functions

// validateVectorField is the mapping_no_vectors.go variant of the vector
// field check: it inspects nothing on fieldMapping and always returns nil.
func validateVectorField(fieldMapping *FieldMapping) error {
	return nil
}

func validateFieldType(fieldType string) error {
switch fieldType {
case "text", "datetime", "number", "boolean", "geopoint", "geoshape", "IP":
default:
return fmt.Errorf("unknown field type: '%s'", fieldType)
}

return nil
// validateFieldMapping validates a single field mapping.
//
// In this variant (mapping/mapping_no_vectors.go) parentName and
// fieldAliasCtx are accepted but not used: the only check performed is
// validateFieldType on field, whose error (if any) is returned unchanged.
func validateFieldMapping(field *FieldMapping, parentName string,
	fieldAliasCtx map[string]*FieldMapping) error {
	return validateFieldType(field)
}
65 changes: 50 additions & 15 deletions mapping/mapping_vectors.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ import (
index "github.com/blevesearch/bleve_index_api"
)

// MinVectorDims and MaxVectorDims are the smallest and largest dimension
// counts accepted for a vector field. Both bounds are inclusive: vector
// field validation rejects Dims only when it is below MinVectorDims or
// above MaxVectorDims.
const (
	MinVectorDims = 1
	MaxVectorDims = 2048
)

func NewVectorFieldMapping() *FieldMapping {
return &FieldMapping{
Type: "vector",
Expand Down Expand Up @@ -136,12 +142,22 @@ func (fm *FieldMapping) processVector(propertyMightBeVector interface{},
// -----------------------------------------------------------------------------
// document validation functions

func validateVectorField(field *FieldMapping) error {
if field.Dims <= 0 || field.Dims > 2048 {
return fmt.Errorf("invalid vector dimension,"+
" value should be in range (%d, %d)", 0, 2048)
// validateFieldMapping validates a single field mapping, dispatching on its
// type: "vector" fields are checked by validateVectorFieldAlias (which also
// receives parentName and fieldAliasCtx), while every other field is checked
// by validateFieldType alone.
func validateFieldMapping(field *FieldMapping, parentName string,
	fieldAliasCtx map[string]*FieldMapping) error {
	if field.Type == "vector" {
		return validateVectorFieldAlias(field, parentName, fieldAliasCtx)
	}

	// non-vector field
	return validateFieldType(field)
}

func validateVectorFieldAlias(field *FieldMapping, parentName string,
fieldAliasCtx map[string]*FieldMapping) error {

if field.Name == "" {
field.Name = parentName
}
if field.Similarity == "" {
field.Similarity = index.DefaultSimilarityMetric
}
Expand All @@ -154,21 +170,40 @@ func validateVectorField(field *FieldMapping) error {
field.DocValues = false
moshaad7 marked this conversation as resolved.
Show resolved Hide resolved
field.SkipFreqNorm = true

// # If alias is present, validate the field options as per the alias
// note: reading from a nil map is safe
if fieldAlias, ok := fieldAliasCtx[field.Name]; ok {
if field.Dims != fieldAlias.Dims {
return fmt.Errorf("field: '%s', invalid alias "+
"(different dimensions %d and %d)", fieldAlias.Name, field.Dims,
fieldAlias.Dims)
}

if field.Similarity != fieldAlias.Similarity {
return fmt.Errorf("field: '%s', invalid alias "+
"(different similarity values %s and %s)", fieldAlias.Name,
field.Similarity, fieldAlias.Similarity)
}

return nil
}

// # Validate field options

if field.Dims < MinVectorDims || field.Dims > MaxVectorDims {
return fmt.Errorf("field: '%s', invalid vector dimension: %d,"+
" value should be in range (%d, %d)", field.Name, field.Dims,
MinVectorDims, MaxVectorDims)
}

if _, ok := index.SupportedSimilarityMetrics[field.Similarity]; !ok {
return fmt.Errorf("invalid similarity metric: '%s', "+
"valid metrics are: %+v", field.Similarity,
return fmt.Errorf("field: '%s', invalid similarity "+
"metric: '%s', valid metrics are: %+v", field.Name, field.Similarity,
reflect.ValueOf(index.SupportedSimilarityMetrics).MapKeys())
}

return nil
}

func validateFieldType(fieldType string) error {
switch fieldType {
case "text", "datetime", "number", "boolean", "geopoint", "geoshape",
"IP", "vector":
default:
return fmt.Errorf("unknown field type: '%s'", fieldType)
if fieldAliasCtx != nil { // writing to a nil map is unsafe
fieldAliasCtx[field.Name] = field
}

return nil
Expand Down
183 changes: 180 additions & 3 deletions mapping/mapping_vectors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,186 @@

package mapping

import (
"testing"
)
import "testing"

// TestVectorFieldAliasValidation unmarshals index mappings in which several
// vector fields share one field name (an "alias") and checks that
// Validate accepts them only when their dims and similarity agree.
//
// Each case lists every error message it accepts. A single message is not
// always enough: when the conflicting fields live under different properties,
// which one is validated first (and therefore which value is reported first
// in the message) follows the iteration order over the mapping's properties,
// which this test must not depend on.
func TestVectorFieldAliasValidation(t *testing.T) {
	tests := []struct {
		// input
		name       string // name of the test
		mappingStr string // index mapping json string

		// expected output
		expValidity bool     // validity of the mapping
		errMsgs     []string // acceptable error messages, given expValidity is false
	}{
		{
			// Unnamed field takes the property name "cityVec" and so aliases
			// the explicitly named field; dims differ.
			name: "test1",
			mappingStr: `
{
"default_mapping": {
"properties": {
"cityVec": {
"fields": [
{
"type": "vector",
"dims": 3
},
{
"name": "cityVec",
"type": "vector",
"dims": 4
}
]
}
}
}
}`,
			expValidity: false,
			errMsgs: []string{
				`field: 'cityVec', invalid alias (different dimensions 4 and 3)`,
			},
		},
		{
			// Same alias and dims, but different similarity metrics.
			name: "test2",
			mappingStr: `
{
"default_mapping": {
"properties": {
"cityVec": {
"fields": [
{
"type": "vector",
"dims": 3,
"similarity": "l2_norm"
},
{
"name": "cityVec",
"type": "vector",
"dims": 3,
"similarity": "dot_product"
}
]
}
}
}
}`,
			expValidity: false,
			errMsgs: []string{
				`field: 'cityVec', invalid alias (different similarity values dot_product and l2_norm)`,
			},
		},
		{
			// Same alias with matching options is valid.
			name: "test3",
			mappingStr: `
{
"default_mapping": {
"properties": {
"cityVec": {
"fields": [
{
"type": "vector",
"dims": 3
},
{
"name": "cityVec",
"type": "vector",
"dims": 3
}
]
}
}
}
}`,
			expValidity: true,
			errMsgs:     nil,
		},
		{
			// Alias shared across two sibling properties with different dims.
			// Either property may be validated first, so both orderings of
			// the reported dimensions are acceptable.
			name: "test4",
			mappingStr: `
{
"default_mapping": {
"properties": {
"cityVec": {
"fields": [
{
"name": "vecData",
"type": "vector",
"dims": 4
}
]
},
"countryVec": {
"fields": [
{
"name": "vecData",
"type": "vector",
"dims": 3
}
]
}
}
}
}`,
			expValidity: false,
			errMsgs: []string{
				`field: 'vecData', invalid alias (different dimensions 3 and 4)`,
				`field: 'vecData', invalid alias (different dimensions 4 and 3)`,
			},
		},
		{
			// Alias shared between the default mapping and a type mapping;
			// the default mapping is validated first, so the order is fixed.
			name: "test5",
			mappingStr: `
{
"default_mapping": {
"properties": {
"cityVec": {
"fields": [
{
"name": "vecData",
"type": "vector",
"dims": 3
}
]
}
}
},
"types": {
"type1": {
"properties": {
"cityVec": {
"fields": [
{
"name": "vecData",
"type": "vector",
"dims": 4
}
]
}
}
}
}
}`,
			expValidity: false,
			errMsgs: []string{
				`field: 'vecData', invalid alias (different dimensions 4 and 3)`,
			},
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			im := NewIndexMapping()
			err := im.UnmarshalJSON([]byte(test.mappingStr))
			if err != nil {
				t.Fatalf("failed to unmarshal index mapping: %v", err)
			}

			err = im.Validate()
			isValid := err == nil
			if test.expValidity != isValid {
				t.Fatalf("validity mismatch, expected: %v, got: %v (err: %v)",
					test.expValidity, isValid, err)
			}
			if isValid {
				return
			}

			// The mapping is invalid as expected; the message must be one
			// of the accepted ones.
			got := err.Error()
			for _, want := range test.errMsgs {
				if got == want {
					return
				}
			}
			t.Fatalf("invalid error message, expected one of: %q, got: %v",
				test.errMsgs, got)
		})
	}
}


// A test case for processVector function
type vectorTest struct {
Expand Down
Loading