Skip to content

Commit

Permalink
feat: Support Int8Vector in go (milvus-io#38990)
Browse files Browse the repository at this point in the history
Issue: milvus-io#38666

Signed-off-by: Cai Yudong <[email protected]>
  • Loading branch information
cydrain authored and gifi-siby committed Jan 16, 2025
1 parent a4fd1bd commit 3db549f
Show file tree
Hide file tree
Showing 50 changed files with 1,331 additions and 150 deletions.
4 changes: 4 additions & 0 deletions client/entity/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ func (t FieldType) PbFieldType() (string, string) {
return "[]byte", ""
case FieldTypeBFloat16Vector:
return "[]byte", ""
case FieldTypeInt8Vector:
return "[]int8", ""
default:
return "undefined", ""
}
Expand Down Expand Up @@ -177,6 +179,8 @@ const (
FieldTypeBFloat16Vector FieldType = 103
// FieldTypeBinaryVector field type sparse vector
FieldTypeSparseVector FieldType = 104
// FieldTypeInt8Vector field type int8 vector
FieldTypeInt8Vector FieldType = 105
)

// Field represent field schema in milvus
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ require (
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
github.com/klauspost/compress v1.17.9
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20241211060635-410431d7865b
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20250102080446-c3ba3d26a90f
github.com/minio/minio-go/v7 v7.0.73
github.com/pingcap/log v1.1.1-0.20221015072633-39906604fb81
github.com/prometheus/client_golang v1.14.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -630,8 +630,8 @@ github.com/milvus-io/cgosymbolizer v0.0.0-20240722103217-b7dee0e50119 h1:9VXijWu
github.com/milvus-io/cgosymbolizer v0.0.0-20240722103217-b7dee0e50119/go.mod h1:DvXTE/K/RtHehxU8/GtDs4vFtfw64jJ3PaCnFri8CRg=
github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b h1:TfeY0NxYxZzUfIfYe5qYDBzt4ZYRqzUjTR6CvUzjat8=
github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b/go.mod h1:iwW+9cWfIzzDseEBCCeDSN5SD16Tidvy8cwQ7ZY8Qj4=
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20241211060635-410431d7865b h1:iPPhnFx+s7FF53UeWj7A4EYhPRMFPL6mHqyQw7qRjeQ=
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20241211060635-410431d7865b/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20250102080446-c3ba3d26a90f h1:So6RKU5wqP/8EaKogicJP8gZ2SrzzS/JprusBaE3RKc=
github.com/milvus-io/milvus-proto/go-api/v2 v2.5.0-beta.0.20250102080446-c3ba3d26a90f/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
github.com/milvus-io/pulsar-client-go v0.12.1 h1:O2JZp1tsYiO7C0MQ4hrUY/aJXnn2Gry6hpm7UodghmE=
github.com/milvus-io/pulsar-client-go v0.12.1/go.mod h1:dkutuH4oS2pXiGm+Ti7fQZ4MRjrMPZ8IJeEGAWMeckk=
github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs=
Expand Down
21 changes: 19 additions & 2 deletions internal/core/src/common/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ using distance_t = float;
using float16 = knowhere::fp16;
using bfloat16 = knowhere::bf16;
using bin1 = knowhere::bin1;
using int8 = knowhere::int8;

// See also: https://github.com/milvus-io/milvus-proto/blob/master/proto/schema.proto
enum class DataType {
Expand Down Expand Up @@ -85,6 +86,7 @@ enum class DataType {
VECTOR_FLOAT16 = 102,
VECTOR_BFLOAT16 = 103,
VECTOR_SPARSE_FLOAT = 104,
VECTOR_INT8 = 105,
};

using Timestamp = uint64_t; // TODO: use TiKV-like timestamp
Expand Down Expand Up @@ -322,6 +324,11 @@ IsSparseFloatVectorDataType(DataType data_type) {
return data_type == DataType::VECTOR_SPARSE_FLOAT;
}

inline bool
IsInt8VectorDataType(DataType data_type) {
return data_type == DataType::VECTOR_INT8;
}

inline bool
IsFloatVectorDataType(DataType data_type) {
return IsDenseFloatVectorDataType(data_type) ||
Expand All @@ -331,7 +338,7 @@ IsFloatVectorDataType(DataType data_type) {
inline bool
IsVectorDataType(DataType data_type) {
return IsBinaryVectorDataType(data_type) ||
IsFloatVectorDataType(data_type);
IsFloatVectorDataType(data_type) || IsInt8VectorDataType(data_type);
}

inline bool
Expand Down Expand Up @@ -418,7 +425,17 @@ IsFloatVectorMetricType(const MetricType& metric_type) {

inline bool
IsBinaryVectorMetricType(const MetricType& metric_type) {
return !IsFloatVectorMetricType(metric_type);
return metric_type == knowhere::metric::HAMMING ||
metric_type == knowhere::metric::JACCARD ||
metric_type == knowhere::metric::SUPERSTRUCTURE ||
metric_type == knowhere::metric::SUBSTRUCTURE;
}

inline bool
IsIntVectorMetricType(const MetricType& metric_type) {
return metric_type == knowhere::metric::L2 ||
metric_type == knowhere::metric::IP ||
metric_type == knowhere::metric::COSINE;
}

// Plus 1 because we can't use greater(>) symbol
Expand Down
4 changes: 4 additions & 0 deletions internal/core/src/index/IndexFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,10 @@ IndexFactory::CreateVectorIndex(
return std::make_unique<VectorMemIndex<bfloat16>>(
index_type, metric_type, version, file_manager_context);
}
case DataType::VECTOR_INT8: {
return std::make_unique<VectorMemIndex<int8>>(
index_type, metric_type, version, file_manager_context);
}
default:
PanicInfo(
DataTypeInvalid,
Expand Down
1 change: 1 addition & 0 deletions internal/core/src/index/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ BIN_List() {
return ret;
}

// TODO caiyd: should list supported list
std::vector<std::tuple<IndexType, MetricType>>
unsupported_index_combinations() {
static std::vector<std::tuple<IndexType, MetricType>> ret{
Expand Down
10 changes: 6 additions & 4 deletions internal/core/src/index/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,13 @@ CheckMetricTypeSupport(const MetricType& metric_type) {
if constexpr (std::is_same_v<T, bin1>) {
AssertInfo(
IsBinaryVectorMetricType(metric_type),
"binary vector does not float vector metric type: " + metric_type);
"binary vector does not support metric type: " + metric_type);
} else if constexpr (std::is_same_v<T, int8>) {
AssertInfo(IsIntVectorMetricType(metric_type),
"int vector does not support metric type: " + metric_type);
} else {
AssertInfo(
IsFloatVectorMetricType(metric_type),
"float vector does not binary vector metric type: " + metric_type);
AssertInfo(IsFloatVectorMetricType(metric_type),
"float vector does not support metric type: " + metric_type);
}
}

Expand Down
1 change: 1 addition & 0 deletions internal/core/src/index/VectorMemIndex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -632,5 +632,6 @@ template class VectorMemIndex<float>;
template class VectorMemIndex<bin1>;
template class VectorMemIndex<float16>;
template class VectorMemIndex<bfloat16>;
template class VectorMemIndex<int8>;

} // namespace milvus::index
1 change: 1 addition & 0 deletions internal/core/src/indexbuilder/IndexFactory.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class IndexFactory {
case DataType::VECTOR_BFLOAT16:
case DataType::VECTOR_BINARY:
case DataType::VECTOR_SPARSE_FLOAT:
case DataType::VECTOR_INT8:
return std::make_unique<VecIndexCreator>(type, config, context);
default:
PanicInfo(DataTypeInvalid,
Expand Down
23 changes: 23 additions & 0 deletions internal/core/src/indexbuilder/index_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,29 @@ BuildSparseFloatVecIndex(CIndex index,
return status;
}

CStatus
BuildInt8VecIndex(CIndex index, int64_t int8_value_num, const int8_t* vectors) {
auto status = CStatus();
try {
AssertInfo(index,
"failed to build int8 vector index, passed index was null");
auto real_index =
reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
auto cIndex =
dynamic_cast<milvus::indexbuilder::VecIndexCreator*>(real_index);
auto dim = cIndex->dim();
auto row_nums = int8_value_num / dim;
auto ds = knowhere::GenDataSet(row_nums, dim, vectors);
cIndex->Build(ds);
status.error_code = Success;
status.error_msg = "";
} catch (std::exception& e) {
status.error_code = UnexpectedError;
status.error_msg = strdup(e.what());
}
return status;
}

// field_data:
// 1, serialized proto::schema::BoolArray, if type is bool;
// 2, serialized proto::schema::StringArray, if type is string;
Expand Down
3 changes: 3 additions & 0 deletions internal/core/src/indexbuilder/index_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ BuildSparseFloatVecIndex(CIndex index,
int64_t dim,
const uint8_t* vectors);

CStatus
BuildInt8VecIndex(CIndex index, int64_t data_size, const int8_t* vectors);

// field_data:
// 1, serialized proto::schema::BoolArray, if type is bool;
// 2, serialized proto::schema::StringArray, if type is string;
Expand Down
6 changes: 6 additions & 0 deletions internal/core/src/segcore/vector_index_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ ValidateIndexParams(const char* index_type,
knowhere::Version::GetCurrentVersion().VersionNumber(),
json,
error_msg);
} else if (dataType == milvus::DataType::VECTOR_INT8) {
status = knowhere::IndexStaticFaced<knowhere::int8>::ConfigCheck(
index_type,
knowhere::Version::GetCurrentVersion().VersionNumber(),
json,
error_msg);
} else {
status = knowhere::Status::invalid_args;
}
Expand Down
15 changes: 13 additions & 2 deletions internal/proxy/task_index.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,11 @@ func (cit *createIndexTask) parseIndexParams(ctx context.Context) error {
for k, v := range Params.AutoIndexConfig.BinaryIndexParams.GetAsJSONMap() {
indexParamsMap[k] = v
}
} else if typeutil.IsIntVectorType(cit.fieldSchema.DataType) {
// override int vector index params by autoindex
for k, v := range Params.AutoIndexConfig.IndexParams.GetAsJSONMap() {
indexParamsMap[k] = v
}
}

if metricTypeExist {
Expand Down Expand Up @@ -320,6 +325,9 @@ func (cit *createIndexTask) parseIndexParams(ctx context.Context) error {
} else if typeutil.IsBinaryVectorType(cit.fieldSchema.DataType) {
// override binary vector index params by autoindex
config = Params.AutoIndexConfig.BinaryIndexParams.GetAsJSONMap()
} else if typeutil.IsIntVectorType(cit.fieldSchema.DataType) {
// override int vector index params by autoindex
config = Params.AutoIndexConfig.IndexParams.GetAsJSONMap()
}
if !exist {
if err := handle(0, config); err != nil {
Expand Down Expand Up @@ -364,17 +372,20 @@ func (cit *createIndexTask) parseIndexParams(ctx context.Context) error {
return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "float vector index does not support metric type: "+metricType)
}
} else if typeutil.IsSparseFloatVectorType(cit.fieldSchema.DataType) {
if metricType != metric.IP && metricType != metric.BM25 {
if !funcutil.SliceContain(indexparamcheck.SparseFloatVectorMetrics, metricType) {
return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "only IP&BM25 is the supported metric type for sparse index")
}

if metricType == metric.BM25 && cit.functionSchema.GetType() != schemapb.FunctionType_BM25 {
return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "only BM25 Function output field support BM25 metric type")
}
} else if typeutil.IsBinaryVectorType(cit.fieldSchema.DataType) {
if !funcutil.SliceContain(indexparamcheck.BinaryVectorMetrics, metricType) {
return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "binary vector index does not support metric type: "+metricType)
}
} else if typeutil.IsIntVectorType(cit.fieldSchema.DataType) {
if !funcutil.SliceContain(indexparamcheck.IntVectorMetrics, metricType) {
return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "int vector index does not support metric type: "+metricType)
}
}
}

Expand Down
36 changes: 36 additions & 0 deletions internal/proxy/validate_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ func (v *validateUtil) Validate(data []*schemapb.FieldData, helper *typeutil.Sch
if err := v.checkSparseFloatFieldData(field, fieldSchema); err != nil {
return err
}
case schemapb.DataType_Int8Vector:
if err := v.checkInt8VectorFieldData(field, fieldSchema); err != nil {
return err
}
case schemapb.DataType_VarChar:
if err := v.checkVarCharFieldData(field, fieldSchema); err != nil {
return err
Expand Down Expand Up @@ -246,6 +250,29 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil
return errNumRowsMismatch(field.GetFieldName(), n)
}

case schemapb.DataType_Int8Vector:
f, err := schema.GetFieldFromName(field.GetFieldName())
if err != nil {
return err
}

dim, err := typeutil.GetDim(f)
if err != nil {
return err
}

n, err := funcutil.GetNumRowsOfInt8VectorField(field.GetVectors().GetInt8Vector(), dim)
if err != nil {
return err
}
dataDim := field.GetVectors().Dim
if dataDim != dim {
return errDimMismatch(field.GetFieldName(), dataDim, dim)
}

if n != numRows {
return errNumRowsMismatch(field.GetFieldName(), n)
}
default:
// error won't happen here.
n, err := funcutil.GetNumRowOfFieldDataWithSchema(field, schema)
Expand Down Expand Up @@ -609,6 +636,15 @@ func (v *validateUtil) checkSparseFloatFieldData(field *schemapb.FieldData, fiel
return typeutil.ValidateSparseFloatRows(sparseRows...)
}

func (v *validateUtil) checkInt8VectorFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
int8VecArray := field.GetVectors().GetInt8Vector()
if int8VecArray == nil {
msg := fmt.Sprintf("int8 vector field '%v' is illegal, nil Vector_Int8 type", field.GetFieldName())
return merr.WrapErrParameterInvalid("need vector_int8 array", "got nil", msg)
}
return nil
}

func (v *validateUtil) checkVarCharFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
strArr := field.GetScalars().GetStringData().GetData()
if strArr == nil && fieldSchema.GetDefaultValue() == nil && !fieldSchema.GetNullable() {
Expand Down
Loading

0 comments on commit 3db549f

Please sign in to comment.