Skip to content

Commit

Permalink
GODRIVER-3286 BSON Binary vector subtype support
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyang-hu committed Jan 10, 2025
1 parent c6d1369 commit 73018a5
Show file tree
Hide file tree
Showing 12 changed files with 746 additions and 42 deletions.
209 changes: 209 additions & 0 deletions bson/bson_binary_vector_spec_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
// Copyright (C) MongoDB, Inc. 2024-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson

import (
"encoding/hex"
"encoding/json"
"fmt"
"math"
"os"
"path"
"testing"

"go.mongodb.org/mongo-driver/v2/internal/require"
)

const bsonBinaryVectorDir = "../testdata/bson-binary-vector/"

type bsonBinaryVectorTests struct {
Description string `json:"description"`
TestKey string `json:"test_key"`
Tests []bsonBinaryVectorTestCase `json:"tests"`
}

type bsonBinaryVectorTestCase struct {
Description string `json:"description"`
Valid bool `json:"valid"`
Vector []interface{} `json:"vector"`
DtypeHex string `json:"dtype_hex"`
DtypeAlias string `json:"dtype_alias"`
Padding int `json:"padding"`
CanonicalBson string `json:"canonical_bson"`
}

func Test_BsonBinaryVector(t *testing.T) {
t.Parallel()

jsonFiles, err := findJSONFilesInDir(bsonBinaryVectorDir)
require.NoErrorf(t, err, "error finding JSON files in %s: %v", bsonBinaryVectorDir, err)

for _, file := range jsonFiles {
filepath := path.Join(bsonBinaryVectorDir, file)
content, err := os.ReadFile(filepath)
require.NoErrorf(t, err, "reading test file %s", filepath)

var tests bsonBinaryVectorTests
require.NoErrorf(t, json.Unmarshal(content, &tests), "parsing test file %s", filepath)

t.Run(tests.Description, func(t *testing.T) {
t.Parallel()

for _, test := range tests.Tests {
test := test
t.Run(test.Description, func(t *testing.T) {
t.Parallel()

runBsonBinaryVectorTest(t, tests.TestKey, test)
})
}
})
}

t.Run("Insufficient vector data FLOAT32", func(t *testing.T) {
t.Parallel()

val := Binary{Subtype: TypeBinaryVector}

for _, tc := range [][]byte{
{Float32Vector, 0, 42},
{Float32Vector, 0, 42, 42},
{Float32Vector, 0, 42, 42, 42},

{Float32Vector, 0, 42, 42, 42, 42, 42},
{Float32Vector, 0, 42, 42, 42, 42, 42, 42},
{Float32Vector, 0, 42, 42, 42, 42, 42, 42, 42},
} {
t.Run(fmt.Sprintf("marshaling %d bytes", len(tc)-2), func(t *testing.T) {
val.Data = tc
b, err := Marshal(D{{"vector", val}})
require.NoError(t, err, "marshaling test BSON")
var got struct {
Vector Vector[float32]
}
err = Unmarshal(b, &got)
require.ErrorContains(t, err, ErrInsufficientVectorData.Error())
})
}
})

t.Run("Padding specified with no vector data PACKED_BIT", func(t *testing.T) {
t.Parallel()

t.Run("Marshaling", func(t *testing.T) {
val := BitVector{Padding: 1}
_, err := Marshal(val)
require.EqualError(t, err, ErrNonZeroVectorPadding.Error())
})
t.Run("Unmarshaling", func(t *testing.T) {
val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{PackedBitVector, 1}}}}
b, err := Marshal(val)
require.NoError(t, err, "marshaling test BSON")
var got struct {
Vector Vector[float32]
}
err = Unmarshal(b, &got)
require.ErrorContains(t, err, ErrNonZeroVectorPadding.Error())
})
})

t.Run("Exceeding maximum padding PACKED_BIT", func(t *testing.T) {
t.Parallel()

t.Run("Marshaling", func(t *testing.T) {
val := BitVector{Padding: 8}
_, err := Marshal(val)
require.EqualError(t, err, ErrVectorPaddingTooLarge.Error())
})
t.Run("Unmarshaling", func(t *testing.T) {
val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{PackedBitVector, 8}}}}
b, err := Marshal(val)
require.NoError(t, err, "marshaling test BSON")
var got struct {
Vector Vector[float32]
}
err = Unmarshal(b, &got)
require.ErrorContains(t, err, ErrVectorPaddingTooLarge.Error())
})
})
}

func convertSlice[T int8 | float32 | byte](s []interface{}) []T {
v := make([]T, len(s))
for i, e := range s {
f := math.NaN()
switch v := e.(type) {
case float64:
f = v
case string:
if v == "inf" {
f = math.Inf(0)
} else if v == "-inf" {
f = math.Inf(-1)
}
}
v[i] = T(f)
}
return v
}

func runBsonBinaryVectorTest(t *testing.T, testKey string, test bsonBinaryVectorTestCase) {
if !test.Valid {
t.Skipf("skip invalid case %s", test.Description)
}

var testVector interface{}
switch alias := test.DtypeHex; alias {
case "0x03":
testVector = map[string]Vector[int8]{
testKey: {convertSlice[int8](test.Vector)},
}
case "0x27":
testVector = map[string]Vector[float32]{
testKey: {convertSlice[float32](test.Vector)},
}
case "0x10":
testVector = map[string]BitVector{
testKey: {
Padding: uint8(test.Padding),
Data: convertSlice[byte](test.Vector),
},
}
default:
t.Fatalf("unsupported vector type: %s", alias)
}

testBSON, err := hex.DecodeString(test.CanonicalBson)
require.NoError(t, err, "decoding canonical BSON")

t.Run("Unmarshaling", func(t *testing.T) {
t.Parallel()

var got interface{}
switch alias := test.DtypeHex; alias {
case "0x03":
got = make(map[string]Vector[int8])
case "0x27":
got = make(map[string]Vector[float32])
case "0x10":
got = make(map[string]BitVector)
default:
t.Fatalf("unsupported type: %s", alias)
}
err := Unmarshal(testBSON, got)
require.NoError(t, err)
require.Equal(t, testVector, got)
})

t.Run("Marshaling", func(t *testing.T) {
t.Parallel()

got, err := Marshal(testVector)
require.NoError(t, err)
require.Equal(t, testBSON, got)
})
}
38 changes: 10 additions & 28 deletions bson/bson_corpus_spec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,15 +217,15 @@ func normalizeRelaxedDouble(t *testing.T, key string, rEJ string) string {
func bsonToNative(t *testing.T, b []byte, bType, testDesc string) D {
var doc D
err := Unmarshal(b, &doc)
expectNoError(t, err, fmt.Sprintf("%s: decoding %s BSON", testDesc, bType))
require.NoErrorf(t, err, "%s: decoding %s BSON", testDesc, bType)
return doc
}

// nativeToBSON encodes the native Document (doc) into canonical BSON and compares it to the expected
// canonical BSON (cB)
func nativeToBSON(t *testing.T, cB []byte, doc D, testDesc, bType, docSrcDesc string) {
actual, err := Marshal(doc)
expectNoError(t, err, fmt.Sprintf("%s: encoding %s BSON", testDesc, bType))
require.NoErrorf(t, err, "%s: encoding %s BSON", testDesc, bType)

if diff := cmp.Diff(cB, actual); diff != "" {
t.Errorf("%s: 'native_to_bson(%s) = cB' failed (-want, +got):\n-%v\n+%v\n",
Expand Down Expand Up @@ -261,7 +261,7 @@ func jsonToBytes(ej, ejType, testDesc string) ([]byte, error) {
// nativeToJSON encodes the native Document (doc) into an extended JSON string
func nativeToJSON(t *testing.T, ej string, doc D, testDesc, ejType, ejShortName, docSrcDesc string) {
actualEJ, err := MarshalExtJSON(doc, ejType != "relaxed", true)
expectNoError(t, err, fmt.Sprintf("%s: encoding %s extended JSON", testDesc, ejType))
require.NoErrorf(t, err, "%s: encoding %s extended JSON", testDesc, ejType)

if diff := cmp.Diff(ej, string(actualEJ)); diff != "" {
t.Errorf("%s: 'native_to_%s_extended_json(%s) = %s' failed (-want, +got):\n%s\n",
Expand All @@ -288,7 +288,7 @@ func runTest(t *testing.T, file string) {
t.Run(v.Description, func(t *testing.T) {
// get canonical BSON
cB, err := hex.DecodeString(v.CanonicalBson)
expectNoError(t, err, fmt.Sprintf("%s: reading canonical BSON", v.Description))
require.NoErrorf(t, err, "%s: reading canonical BSON", v.Description)

// get canonical extended JSON
var compactEJ bytes.Buffer
Expand Down Expand Up @@ -341,7 +341,7 @@ func runTest(t *testing.T, file string) {
/*** degenerate BSON round-trip tests (if exists) ***/
if v.DegenerateBSON != nil {
dB, err := hex.DecodeString(*v.DegenerateBSON)
expectNoError(t, err, fmt.Sprintf("%s: reading degenerate BSON", v.Description))
require.NoErrorf(t, err, "%s: reading degenerate BSON", v.Description)

doc = bsonToNative(t, dB, "degenerate", v.Description)

Expand Down Expand Up @@ -377,7 +377,7 @@ func runTest(t *testing.T, file string) {
for _, d := range test.DecodeErrors {
t.Run(d.Description, func(t *testing.T) {
b, err := hex.DecodeString(d.Bson)
expectNoError(t, err, d.Description)
require.NoError(t, err, d.Description)

var doc D
err = Unmarshal(b, &doc)
Expand All @@ -392,12 +392,12 @@ func runTest(t *testing.T, file string) {
invalidDBPtr := ok && !utf8.ValidString(dbPtr.DB)

if invalidString || invalidDBPtr {
expectNoError(t, err, d.Description)
require.NoError(t, err, d.Description)
return
}
}

expectError(t, err, fmt.Sprintf("%s: expected decode error", d.Description))
require.Errorf(t, err, "%s: expected decode error", d.Description)
})
}
})
Expand All @@ -418,7 +418,7 @@ func runTest(t *testing.T, file string) {
if strings.Contains(p.Description, "Null") {
_, err = Marshal(doc)
}
expectError(t, err, fmt.Sprintf("%s: expected parse error", p.Description))
require.Errorf(t, err, "%s: expected parse error", p.Description)
default:
t.Errorf("Update test to check for parse errors for type %s", test.BsonType)
t.Fail()
Expand All @@ -431,31 +431,13 @@ func runTest(t *testing.T, file string) {

func Test_BsonCorpus(t *testing.T) {
jsonFiles, err := findJSONFilesInDir(dataDir)
if err != nil {
t.Fatalf("error finding JSON files in %s: %v", dataDir, err)
}
require.NoErrorf(t, err, "error finding JSON files in %s: %v", dataDir, err)

for _, file := range jsonFiles {
runTest(t, file)
}
}

func expectNoError(t *testing.T, err error, desc string) {
if err != nil {
t.Helper()
t.Errorf("%s: Unepexted error: %v", desc, err)
t.FailNow()
}
}

func expectError(t *testing.T, err error, desc string) {
if err == nil {
t.Helper()
t.Errorf("%s: Expected error", desc)
t.FailNow()
}
}

func TestRelaxedUUIDValidation(t *testing.T) {
testCases := []struct {
description string
Expand Down
Loading

0 comments on commit 73018a5

Please sign in to comment.