Skip to content

Commit

Permalink
Support Int8Vector in go
Browse files Browse the repository at this point in the history
Signed-off-by: Cai Yudong <[email protected]>
  • Loading branch information
cydrain committed Jan 6, 2025
1 parent f0cddfd commit f617d0a
Show file tree
Hide file tree
Showing 65 changed files with 1,731 additions and 199 deletions.
21 changes: 21 additions & 0 deletions client/column/columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ func FieldDataColumn(fd *schemapb.FieldData, begin, end int) (Column, error) {
vector = append(vector, v)
}
return NewColumnBFloat16Vector(fd.GetFieldName(), dim, vector), nil

case schemapb.DataType_SparseFloatVector:
sparseVectors := fd.GetVectors().GetSparseFloatVector()
if sparseVectors == nil {
Expand All @@ -303,6 +304,26 @@ func FieldDataColumn(fd *schemapb.FieldData, begin, end int) (Column, error) {
vectors = append(vectors, vector)
}
return NewColumnSparseVectors(fd.GetFieldName(), vectors), nil

case schemapb.DataType_Int8Vector:
vectors := fd.GetVectors()
x, ok := vectors.GetData().(*schemapb.VectorField_Int8Vector)
if !ok {
return nil, errFieldDataTypeNotMatch
}
data := x.Int8Vector
dim := int(vectors.GetDim())
if end < 0 {
end = len(data) / dim
}
vector := make([][]int8, 0, end-begin) // shall not have remanunt
for i := begin; i < end; i++ {
v := make([]int8, dim)
copy(v, data[i*dim:(i+1)*dim])
vector = append(vector, v)
}
return NewColumnInt8Vector(fd.GetFieldName(), dim, vector), nil

default:
return nil, fmt.Errorf("unsupported data type %s", fd.GetType())
}
Expand Down
33 changes: 33 additions & 0 deletions client/column/vector.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,36 @@ func (c *ColumnBFloat16Vector) Slice(start, end int) Column {
vectorBase: c.vectorBase.slice(start, end),
}
}

/* int8 vector */

// ColumnInt8Vector is the column implementation for int8 vector fields.
// It embeds vectorBase instantiated with entity.Int8Vector as its element type.
type ColumnInt8Vector struct {
*vectorBase[entity.Int8Vector]
}

// NewColumnInt8Vector builds an int8 vector column named fieldName with the
// given dim. Every row of data is converted to entity.Int8Vector; the
// conversion is a plain type conversion, so rows are not copied.
func NewColumnInt8Vector(fieldName string, dim int, data [][]int8) *ColumnInt8Vector {
	toVector := func(row []int8, _ int) entity.Int8Vector {
		return entity.Int8Vector(row)
	}
	base := newVectorBase(fieldName, dim, lo.Map(data, toVector), entity.FieldTypeInt8Vector)
	return &ColumnInt8Vector{vectorBase: base}
}

// AppendValue appends one vector to the column values.
// It overrides the default type constraint so that a raw `[]int8` is accepted
// in addition to entity.Int8Vector; any other type yields an error.
func (c *ColumnInt8Vector) AppendValue(i interface{}) error {
	var row entity.Int8Vector
	switch v := i.(type) {
	case entity.Int8Vector:
		row = v
	case []int8:
		// []int8 is assignable to entity.Int8Vector (same underlying type).
		row = v
	default:
		return errors.Newf("unexpected append value type %T, field type %v", i, c.fieldType)
	}
	c.values = append(c.values, row)
	return nil
}

// Slice returns a new ColumnInt8Vector wrapping the result of slicing the
// underlying vectorBase with the given start and end.
func (c *ColumnInt8Vector) Slice(start, end int) Column {
	sliced := c.vectorBase.slice(start, end)
	return &ColumnInt8Vector{vectorBase: sliced}
}
54 changes: 54 additions & 0 deletions client/column/vector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,38 @@ func (s *VectorSuite) TestBasic() {
}
}
})

s.Run("int8_vector", func() {
name := fmt.Sprintf("field_%d", rand.Intn(1000))
n := 3
dim := rand.Intn(10) + 2
data := make([][]int8, 0, n)
for i := 0; i < n; i++ {
row := lo.RepeatBy(dim, func(i int) int8 {
return int8(rand.Intn(256) - 128)
})
data = append(data, row)
}
column := NewColumnInt8Vector(name, dim, data)
s.Equal(entity.FieldTypeInt8Vector, column.Type())
s.Equal(name, column.Name())
s.Equal(lo.Map(data, func(row []int8, _ int) entity.Int8Vector { return entity.Int8Vector(row) }), column.Data())
s.Equal(dim, column.Dim())

fd := column.FieldData()
s.Equal(name, fd.GetFieldName())
s.Equal(lo.Flatten(data), fd.GetVectors().GetInt8Vector())

result, err := FieldDataColumn(fd, 0, -1)
s.NoError(err)
parsed, ok := result.(*ColumnInt8Vector)
if s.True(ok) {
s.Equal(entity.FieldTypeInt8Vector, parsed.Type())
s.Equal(name, parsed.Name())
s.Equal(lo.Map(data, func(row []int8, _ int) entity.Int8Vector { return entity.Int8Vector(row) }), parsed.Data())
s.Equal(dim, parsed.Dim())
}
})
}

func (s *VectorSuite) TestSlice() {
Expand Down Expand Up @@ -277,6 +309,28 @@ func (s *VectorSuite) TestSlice() {
s.Equal(lo.Map(data[:l], func(row []byte, _ int) entity.BFloat16Vector { return entity.BFloat16Vector(row) }), slicedColumn.Data())
}
})

s.Run("int8_vector", func() {
name := fmt.Sprintf("field_%d", rand.Intn(1000))
n := 100
dim := rand.Intn(10) + 2
data := make([][]int8, 0, n)
for i := 0; i < n; i++ {
row := lo.RepeatBy(dim, func(i int) int8 {
return int8(rand.Intn(256) - 128)
})
data = append(data, row)
}
column := NewColumnInt8Vector(name, dim, data)

l := rand.Intn(n)
sliced := column.Slice(0, l)
slicedColumn, ok := sliced.(*ColumnInt8Vector)
if s.True(ok) {
s.Equal(dim, slicedColumn.Dim())
s.Equal(lo.Map(data[:l], func(row []int8, _ int) entity.Int8Vector { return entity.Int8Vector(row) }), slicedColumn.Data())
}
})
}

func TestVectors(t *testing.T) {
Expand Down
8 changes: 8 additions & 0 deletions client/entity/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ func (t FieldType) Name() string {
return "Float16Vector"
case FieldTypeBFloat16Vector:
return "BFloat16Vector"
case FieldTypeInt8Vector:
return "Int8Vector"
default:
return "undefined"
}
Expand Down Expand Up @@ -100,6 +102,8 @@ func (t FieldType) String() string {
return "[]byte"
case FieldTypeBFloat16Vector:
return "[]byte"
case FieldTypeInt8Vector:
return "[]int8"
default:
return "undefined"
}
Expand Down Expand Up @@ -136,6 +140,8 @@ func (t FieldType) PbFieldType() (string, string) {
return "[]byte", ""
case FieldTypeBFloat16Vector:
return "[]byte", ""
case FieldTypeInt8Vector:
return "[]int8", ""
default:
return "undefined", ""
}
Expand Down Expand Up @@ -177,6 +183,8 @@ const (
FieldTypeBFloat16Vector FieldType = 103
// FieldTypeSparseVector field type sparse vector
FieldTypeSparseVector FieldType = 104
// FieldTypeInt8Vector field type int8 vector
FieldTypeInt8Vector FieldType = 105
)

// Field represent field schema in milvus
Expand Down
22 changes: 20 additions & 2 deletions client/entity/vectors.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (fv FloatVector) ToBFloat16Vector() BFloat16Vector {
return typeutil.Float32ArrayToBFloat16Bytes(fv)
}

// FloatVector float32 vector wrapper.
// Float16Vector float16 vector wrapper.
type Float16Vector []byte

// Dim returns vector dimension.
Expand All @@ -77,7 +77,7 @@ func (fv Float16Vector) ToFloat32Vector() FloatVector {
return typeutil.Float16BytesToFloat32Vector(fv)
}

// FloatVector float32 vector wrapper.
// BFloat16Vector bfloat16 vector wrapper.
type BFloat16Vector []byte

// Dim returns vector dimension.
Expand Down Expand Up @@ -131,3 +131,21 @@ func (t Text) FieldType() FieldType {
func (t Text) Serialize() []byte {
return []byte(t)
}

// Int8Vector is a vector wrapper whose underlying type is []int8.
type Int8Vector []int8

// Dim returns the vector dimension, i.e. the number of int8 elements.
func (iv Int8Vector) Dim() int {
return len(iv)
}

// Serialize returns the vector as bytes, as produced by
// typeutil.Int8ArrayToBytes.
func (iv Int8Vector) Serialize() []byte {
return typeutil.Int8ArrayToBytes(iv)
}

// FieldType returns the corresponding field type, FieldTypeInt8Vector.
func (iv Int8Vector) FieldType() FieldType {
return FieldTypeInt8Vector
}
11 changes: 11 additions & 0 deletions client/entity/vectors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,4 +92,15 @@ func TestVectors(t *testing.T) {
assert.Equal(t, dim*8, bv.Dim())
assert.ElementsMatch(t, raw, bv.Serialize())
})

t.Run("test int8 vector", func(t *testing.T) {
raw := make([]int8, dim)
for i := 0; i < dim; i++ {
raw[i] = int8(rand.Intn(256) - 128)
}

iv := Int8Vector(raw)
assert.Equal(t, dim, iv.Dim())
assert.Equal(t, dim, len(iv.Serialize()))
})
}
4 changes: 2 additions & 2 deletions client/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@ require (
github.com/blang/semver/v4 v4.0.0
github.com/cockroachdb/errors v1.9.1
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20241211060635-410431d7865b
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20250102080446-c3ba3d26a90f
github.com/milvus-io/milvus/pkg v0.0.2-0.20241126032235-cb6542339e84
github.com/quasilyte/go-ruleguard/dsl v0.3.22
github.com/samber/lo v1.27.0
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.17.1
go.uber.org/atomic v1.10.0
golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2
google.golang.org/grpc v1.65.0
google.golang.org/protobuf v1.34.2
)
Expand Down Expand Up @@ -99,6 +98,7 @@ require (
go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap v1.27.0 // indirect
golang.org/x/crypto v0.31.0 // indirect
golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2 // indirect
golang.org/x/net v0.33.0 // indirect
golang.org/x/sync v0.10.0 // indirect
golang.org/x/sys v0.28.0 // indirect
Expand Down
4 changes: 2 additions & 2 deletions client/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,8 @@ github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfr
github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i4n7wVopoX3x7Bv8=
github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc=
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20241211060635-410431d7865b h1:iPPhnFx+s7FF53UeWj7A4EYhPRMFPL6mHqyQw7qRjeQ=
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20241211060635-410431d7865b/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20250102080446-c3ba3d26a90f h1:So6RKU5wqP/8EaKogicJP8gZ2SrzzS/JprusBaE3RKc=
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20250102080446-c3ba3d26a90f/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
github.com/milvus-io/milvus/pkg v0.0.2-0.20241126032235-cb6542339e84 h1:EAFxmxUVp5yYFDCrX1MQoSxkTO+ycy8NXEqEDEB3cRM=
github.com/milvus-io/milvus/pkg v0.0.2-0.20241126032235-cb6542339e84/go.mod h1:RATa0GS4jhkPpsYOvQ/QvcNz8rd+TlRPDiSyXQnMMxs=
github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc=
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ require (
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
github.com/klauspost/compress v1.17.9
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20241211060635-410431d7865b
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20250102080446-c3ba3d26a90f
github.com/minio/minio-go/v7 v7.0.73
github.com/pingcap/log v1.1.1-0.20221015072633-39906604fb81
github.com/prometheus/client_golang v1.14.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -630,8 +630,8 @@ github.com/milvus-io/cgosymbolizer v0.0.0-20240722103217-b7dee0e50119 h1:9VXijWu
github.com/milvus-io/cgosymbolizer v0.0.0-20240722103217-b7dee0e50119/go.mod h1:DvXTE/K/RtHehxU8/GtDs4vFtfw64jJ3PaCnFri8CRg=
github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b h1:TfeY0NxYxZzUfIfYe5qYDBzt4ZYRqzUjTR6CvUzjat8=
github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b/go.mod h1:iwW+9cWfIzzDseEBCCeDSN5SD16Tidvy8cwQ7ZY8Qj4=
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20241211060635-410431d7865b h1:iPPhnFx+s7FF53UeWj7A4EYhPRMFPL6mHqyQw7qRjeQ=
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20241211060635-410431d7865b/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20250102080446-c3ba3d26a90f h1:So6RKU5wqP/8EaKogicJP8gZ2SrzzS/JprusBaE3RKc=
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20250102080446-c3ba3d26a90f/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
github.com/milvus-io/pulsar-client-go v0.12.1 h1:O2JZp1tsYiO7C0MQ4hrUY/aJXnn2Gry6hpm7UodghmE=
github.com/milvus-io/pulsar-client-go v0.12.1/go.mod h1:dkutuH4oS2pXiGm+Ti7fQZ4MRjrMPZ8IJeEGAWMeckk=
github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs=
Expand Down
21 changes: 19 additions & 2 deletions internal/core/src/common/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ using distance_t = float;
using float16 = knowhere::fp16;
using bfloat16 = knowhere::bf16;
using bin1 = knowhere::bin1;
using int8 = knowhere::int8;

// See also: https://github.com/milvus-io/milvus-proto/blob/master/proto/schema.proto
enum class DataType {
Expand Down Expand Up @@ -85,6 +86,7 @@ enum class DataType {
VECTOR_FLOAT16 = 102,
VECTOR_BFLOAT16 = 103,
VECTOR_SPARSE_FLOAT = 104,
VECTOR_INT8 = 105,
};

using Timestamp = uint64_t; // TODO: use TiKV-like timestamp
Expand Down Expand Up @@ -322,6 +324,11 @@ IsSparseFloatVectorDataType(DataType data_type) {
return data_type == DataType::VECTOR_SPARSE_FLOAT;
}

// Returns true when data_type is DataType::VECTOR_INT8.
inline bool
IsInt8VectorDataType(DataType data_type) {
return data_type == DataType::VECTOR_INT8;
}

inline bool
IsFloatVectorDataType(DataType data_type) {
return IsDenseFloatVectorDataType(data_type) ||
Expand All @@ -331,7 +338,7 @@ IsFloatVectorDataType(DataType data_type) {
inline bool
IsVectorDataType(DataType data_type) {
return IsBinaryVectorDataType(data_type) ||
IsFloatVectorDataType(data_type);
IsFloatVectorDataType(data_type) || IsInt8VectorDataType(data_type);
}

inline bool
Expand Down Expand Up @@ -418,7 +425,17 @@ IsFloatVectorMetricType(const MetricType& metric_type) {

inline bool
IsBinaryVectorMetricType(const MetricType& metric_type) {
return !IsFloatVectorMetricType(metric_type);
return metric_type == knowhere::metric::HAMMING ||
metric_type == knowhere::metric::JACCARD ||
metric_type == knowhere::metric::SUPERSTRUCTURE ||
metric_type == knowhere::metric::SUBSTRUCTURE;
}

inline bool
IsIntVectorMetricType(const MetricType& metric_type) {
return metric_type == knowhere::metric::L2 ||
metric_type == knowhere::metric::IP ||
metric_type == knowhere::metric::COSINE;

Check warning on line 438 in internal/core/src/common/Types.h

View check run for this annotation

Codecov / codecov/patch

internal/core/src/common/Types.h#L435-L438

Added lines #L435 - L438 were not covered by tests
}

// Plus 1 because we can't use greater(>) symbol
Expand Down
4 changes: 4 additions & 0 deletions internal/core/src/index/IndexFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,10 @@ IndexFactory::CreateVectorIndex(
return std::make_unique<VectorMemIndex<bfloat16>>(
index_type, metric_type, version, file_manager_context);
}
case DataType::VECTOR_INT8: {
return std::make_unique<VectorMemIndex<int8>>(
index_type, metric_type, version, file_manager_context);

Check warning on line 448 in internal/core/src/index/IndexFactory.cpp

View check run for this annotation

Codecov / codecov/patch

internal/core/src/index/IndexFactory.cpp#L446-L448

Added lines #L446 - L448 were not covered by tests
}
default:
PanicInfo(
DataTypeInvalid,
Expand Down
1 change: 1 addition & 0 deletions internal/core/src/index/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ BIN_List() {
return ret;
}

// TODO caiyd: should list supported list
std::vector<std::tuple<IndexType, MetricType>>
unsupported_index_combinations() {
static std::vector<std::tuple<IndexType, MetricType>> ret{
Expand Down
10 changes: 6 additions & 4 deletions internal/core/src/index/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,13 @@ CheckMetricTypeSupport(const MetricType& metric_type) {
if constexpr (std::is_same_v<T, bin1>) {
AssertInfo(
IsBinaryVectorMetricType(metric_type),
"binary vector does not float vector metric type: " + metric_type);
"binary vector does not support metric type: " + metric_type);
} else if constexpr (std::is_same_v<T, int8>) {
AssertInfo(IsIntVectorMetricType(metric_type),

Check warning on line 110 in internal/core/src/index/Utils.h

View check run for this annotation

Codecov / codecov/patch

internal/core/src/index/Utils.h#L110

Added line #L110 was not covered by tests
"int vector does not support metric type: " + metric_type);
} else {
AssertInfo(
IsFloatVectorMetricType(metric_type),
"float vector does not binary vector metric type: " + metric_type);
AssertInfo(IsFloatVectorMetricType(metric_type),
"float vector does not support metric type: " + metric_type);
}
}

Expand Down
1 change: 1 addition & 0 deletions internal/core/src/index/VectorMemIndex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -636,5 +636,6 @@ template class VectorMemIndex<float>;
template class VectorMemIndex<bin1>;
template class VectorMemIndex<float16>;
template class VectorMemIndex<bfloat16>;
template class VectorMemIndex<int8>;

} // namespace milvus::index
1 change: 1 addition & 0 deletions internal/core/src/indexbuilder/IndexFactory.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class IndexFactory {
case DataType::VECTOR_BFLOAT16:
case DataType::VECTOR_BINARY:
case DataType::VECTOR_SPARSE_FLOAT:
case DataType::VECTOR_INT8:
return std::make_unique<VecIndexCreator>(type, config, context);
default:
PanicInfo(DataTypeInvalid,
Expand Down
Loading

0 comments on commit f617d0a

Please sign in to comment.