diff --git a/internal/util/function/common.go b/internal/util/function/common.go index 2db39a21de7fe..64a317b459f87 100644 --- a/internal/util/function/common.go +++ b/internal/util/function/common.go @@ -88,6 +88,19 @@ const ( vertexServiceAccountJSONEnv string = "MILVUSAI_GOOGLE_APPLICATION_CREDENTIALS" ) +// voyageAI +const ( + voyage3Large string = "voyage-3-large" + voyage3 string = "voyage-3" + voyage3Lite string = "voyage-3-lite" + voyageCode3 string = "voyage-code-3" + voyageFinance2 string = "voyage-finance-2" + voyageLaw2 string = "voyage-law-2" + voyageCode2 string = "voyage-code-2" + + voyageAIAKEnvStr string = "MILVUSAI_VOYAGEAI_API_KEY" +) + func parseAndCheckFieldDim(dimStr string, fieldDim int64, fieldName string) (int64, error) { dim, err := strconv.ParseInt(dimStr, 10, 64) if err != nil { diff --git a/internal/util/function/mock_embedding_service.go b/internal/util/function/mock_embedding_service.go index 66eff7b62fa61..dfb667d576bb7 100644 --- a/internal/util/function/mock_embedding_service.go +++ b/internal/util/function/mock_embedding_service.go @@ -30,6 +30,7 @@ import ( "github.com/milvus-io/milvus/internal/util/function/models/ali" "github.com/milvus-io/milvus/internal/util/function/models/openai" "github.com/milvus-io/milvus/internal/util/function/models/vertexai" + "github.com/milvus-io/milvus/internal/util/function/models/voyageai" ) func mockEmbedding(texts []string, dim int) [][]float32 { @@ -99,6 +100,32 @@ func CreateAliEmbeddingServer() *httptest.Server { return ts } +func CreateVoyageAIEmbeddingServer() *httptest.Server { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req voyageai.EmbeddingRequest + body, _ := io.ReadAll(r.Body) + defer r.Body.Close() + json.Unmarshal(body, &req) + embs := mockEmbedding(req.Input, int(req.OutputDimension)) + var res voyageai.EmbeddingResponse + for i := 0; i < len(req.Input); i++ { + res.Data = append(res.Data, voyageai.EmbeddingData{ + Object: "list", + Embedding: embs[i], + Index: i, + }) + } + + res.Usage = voyageai.Usage{ + TotalTokens: 100, + } + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + return ts +} + func CreateVertexAIEmbeddingServer() *httptest.Server { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var req vertexai.EmbeddingRequest diff --git a/internal/util/function/models/voyageai/voyageai_text_embedding.go b/internal/util/function/models/voyageai/voyageai_text_embedding.go new file mode 100644 index 0000000000000..0150e2ec5bdaa --- /dev/null +++ b/internal/util/function/models/voyageai/voyageai_text_embedding.go @@ -0,0 +1,152 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package voyageai + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "sort" + "time" + + "github.com/milvus-io/milvus/internal/util/function/models/utils" +) + +type EmbeddingRequest struct { + // ID of the model to use. + Model string `json:"model"` + + // Input text to embed, encoded as a string. + Input []string `json:"input"` + + InputType string `json:"input_type,omitempty"` + + Truncation bool `json:"truncation,omitempty"` + + OutputDimension int64 `json:"output_dimension,omitempty"` + + OutputDtype string `json:"output_dtype,omitempty"` + + EncodingFormat string `json:"encoding_format,omitempty"` +} + +type Usage struct { + // The total number of tokens used by the request. + TotalTokens int `json:"total_tokens"` +} + +type EmbeddingData struct { + Object string `json:"object"` + + Embedding []float32 `json:"embedding"` + + Index int `json:"index"` +} + +type EmbeddingResponse struct { + Object string `json:"object"` + + Data []EmbeddingData `json:"data"` + + Model string `json:"model"` + + Usage Usage `json:"usage"` +} + +type ByIndex struct { + resp *EmbeddingResponse +} + +func (eb *ByIndex) Len() int { return len(eb.resp.Data) } +func (eb *ByIndex) Swap(i, j int) { + eb.resp.Data[i], eb.resp.Data[j] = eb.resp.Data[j], eb.resp.Data[i] +} + +func (eb *ByIndex) Less(i, j int) bool { + return eb.resp.Data[i].Index < eb.resp.Data[j].Index +} + +type ErrorInfo struct { + Code string `json:"code"` + Message string `json:"message"` + RequestID string `json:"request_id"` +} + +type VoyageAIEmbedding struct { + apiKey string + url string +} + +func NewVoyageAIEmbeddingClient(apiKey string, url string) *VoyageAIEmbedding { + return &VoyageAIEmbedding{ + apiKey: apiKey, + url: url, + } +} + +func (c *VoyageAIEmbedding) Check() error { + if c.apiKey == "" { + return fmt.Errorf("api key is empty") + } + + if c.url == "" { + return fmt.Errorf("url is empty") + } + return nil +} + +func (c *VoyageAIEmbedding) Embedding(modelName string, texts []string, dim int, textType string, outputType string, timeoutSec int64) (*EmbeddingResponse, error) { + var r EmbeddingRequest + r.Model = modelName + r.Input = texts + r.InputType = textType + r.OutputDtype = outputType + if dim != 0 { + r.OutputDimension = int64(dim) + } + data, err := json.Marshal(r) + if err != nil { + return nil, err + } + + if timeoutSec <= 0 { + timeoutSec = utils.DefaultTimeout + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second) + defer cancel() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey)) + body, err := utils.RetrySend(req, 3) + if err != nil { + return nil, err + } + var res EmbeddingResponse + err = json.Unmarshal(body, &res) + if err != nil { + return nil, err + } + sort.Sort(&ByIndex{&res}) + return &res, err +} diff --git a/internal/util/function/models/voyageai/voyageai_text_embedding_test.go b/internal/util/function/models/voyageai/voyageai_text_embedding_test.go new file mode 100644 index 0000000000000..79d390d2fc9b5 --- /dev/null +++ b/internal/util/function/models/voyageai/voyageai_text_embedding_test.go @@ -0,0 +1,127 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package voyageai + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestEmbeddingClientCheck(t *testing.T) { + { + c := NewVoyageAIEmbeddingClient("", "mock_uri") + err := c.Check() + assert.True(t, err != nil) + fmt.Println(err) + } + + { + c := NewVoyageAIEmbeddingClient("mock_key", "") + err := c.Check() + assert.True(t, err != nil) + fmt.Println(err) + } + + { + c := NewVoyageAIEmbeddingClient("mock_key", "mock_uri") + err := c.Check() + assert.True(t, err == nil) + } +} + +func TestEmbeddingOK(t *testing.T) { + var res EmbeddingResponse + repStr := `{ + "object": "list", + "data": [ + { + "object": "embedding", + "embedding": [ + 0.0, + 0.1 + ], + "index": 0 + }, + { + "object": "embedding", + "embedding": [ + 2.0, + 2.1 + ], + "index": 2 + }, + { + "object": "embedding", + "embedding": [ + 1.0, + 1.1 + ], + "index": 1 + } + ], + "model": "voyage-large-2", + "usage": { + "total_tokens": 10 + } +}` + err := json.Unmarshal([]byte(repStr), &res) + assert.NoError(t, err) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + + defer ts.Close() + url := ts.URL + + { + c := NewVoyageAIEmbeddingClient("mock_key", url) + err := c.Check() + assert.True(t, err == nil) + ret, err := c.Embedding("voyage-3", []string{"sentence"}, 0, "query", "float", 0) + assert.True(t, err == nil) + assert.Equal(t, ret.Data[0].Index, 0) + assert.Equal(t, ret.Data[1].Index, 1) + assert.Equal(t, ret.Data[2].Index, 2) + assert.Equal(t, ret.Data[0].Embedding, []float32{0.0, 0.1}) + assert.Equal(t, ret.Data[1].Embedding, []float32{1.0, 1.1}) + assert.Equal(t, ret.Data[2].Embedding, []float32{2.0, 2.1}) + } +} + +func TestEmbeddingFailed(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + + defer ts.Close() + url := ts.URL + + { + c := NewVoyageAIEmbeddingClient("mock_key", url) + err := c.Check() + assert.True(t, err == nil) + _, err = c.Embedding("voyage-3", []string{"sentence"}, 0, "query", "float", 0) + assert.True(t, err != nil) + } +} diff --git a/internal/util/function/text_embedding_function.go b/internal/util/function/text_embedding_function.go index 340aa60abe4d5..fab424e2d8768 100644 --- a/internal/util/function/text_embedding_function.go +++ b/internal/util/function/text_embedding_function.go @@ -38,6 +38,7 @@ const ( aliDashScopeProvider string = "dashscope" bedrockProvider string = "bedrock" vertexAIProvider string = "vertexai" + voyageAIProvider string = "voyageai" ) // Text embedding for retrieval task @@ -130,8 +131,17 @@ func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, functionSchema *s FunctionBase: *base, embProvider: embP, }, nil + case voyageAIProvider: + embP, err := NewVoyageAIEmbeddingProvider(base.outputFields[0], functionSchema) + if err != nil { + return nil, err + } + return &TextEmbeddingFunction{ + FunctionBase: *base, + embProvider: embP, + }, nil default: - return nil, fmt.Errorf("Unsupported embedding service provider: [%s] , list of supported [%s, %s, %s, %s, %s]", provider, openAIProvider, azureOpenAIProvider, aliDashScopeProvider, bedrockProvider, vertexAIProvider) + return nil, fmt.Errorf("Unsupported text embedding service provider: [%s] , list of supported [%s, %s, %s, %s, %s, %s]", provider, openAIProvider, azureOpenAIProvider, aliDashScopeProvider, bedrockProvider, vertexAIProvider, voyageAIProvider) } } diff --git a/internal/util/function/voyageai_embedding_provider.go b/internal/util/function/voyageai_embedding_provider.go new file mode 100644 index 0000000000000..f6a71a12b8ac4 --- /dev/null +++ b/internal/util/function/voyageai_embedding_provider.go @@ -0,0 +1,155 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + +package function + +import ( + "fmt" + "os" + "strings" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/util/function/models/voyageai" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type VoyageAIEmbeddingProvider struct { + fieldDim int64 + + client *voyageai.VoyageAIEmbedding + modelName string + embedDimParam int64 + + maxBatch int + timeoutSec int64 +} + +func createVoyageAIEmbeddingClient(apiKey string, url string) (*voyageai.VoyageAIEmbedding, error) { + if apiKey == "" { + apiKey = os.Getenv(voyageAIAKEnvStr) + } + if apiKey == "" { + return nil, fmt.Errorf("Missing credentials. Please pass `api_key`, or configure the %s environment variable in the Milvus service.", voyageAIAKEnvStr) + } + + if url == "" { + url = "https://api.voyageai.com/v1/embeddings" + } + + c := voyageai.NewVoyageAIEmbeddingClient(apiKey, url) + return c, nil +} + +func NewVoyageAIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *schemapb.FunctionSchema) (*VoyageAIEmbeddingProvider, error) { + fieldDim, err := typeutil.GetDim(fieldSchema) + if err != nil { + return nil, err + } + var apiKey, url, modelName string + dim := int64(0) + + for _, param := range functionSchema.Params { + switch strings.ToLower(param.Key) { + case modelNameParamKey: + modelName = param.Value + case dimParamKey: + // Only voyage-3-large and voyage-code-3 support dim param: 1024 (default), 256, 512, 2048 + dim, err = parseAndCheckFieldDim(param.Value, fieldDim, fieldSchema.Name) + if err != nil { + return nil, err + } + case apiKeyParamKey: + apiKey = param.Value + case embeddingURLParamKey: + url = param.Value + default: + } + } + + if modelName != voyage3Large && modelName != voyage3 && modelName != voyage3Lite && modelName != voyageCode3 && modelName != voyageFinance2 && modelName != voyageLaw2 && modelName != voyageCode2 { + return nil, fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s, %s, %s, %s, %s]", + modelName, voyage3Large, voyage3, voyage3Lite, voyageCode3, voyageFinance2, voyageLaw2, voyageCode2) + } + + if dim != 0 { + if modelName != voyage3Large && modelName != voyageCode3 { + return nil, fmt.Errorf("VoyageAI text embedding model: [%s] doesn't supports dim parameter, only [%s, %s] support it.", modelName, voyage3, voyageCode3) + } + if dim != 1024 && dim != 256 && dim != 512 && dim != 2048 { + return nil, fmt.Errorf("VoyageAI text embedding model's dim only supports 2048, 1024 (default), 512, and 256.") + } + } + c, err := createVoyageAIEmbeddingClient(apiKey, url) + if err != nil { + return nil, err + } + + provider := VoyageAIEmbeddingProvider{ + client: c, + fieldDim: fieldDim, + modelName: modelName, + embedDimParam: dim, + maxBatch: 128, + timeoutSec: 30, + } + return &provider, nil +} + +func (provider *VoyageAIEmbeddingProvider) MaxBatch() int { + return 5 * provider.maxBatch +} + +func (provider *VoyageAIEmbeddingProvider) FieldDim() int64 { + return provider.fieldDim +} + +func (provider *VoyageAIEmbeddingProvider) CallEmbedding(texts []string, batchLimit bool, mode TextEmbeddingMode) ([][]float32, error) { + numRows := len(texts) + if batchLimit && numRows > provider.MaxBatch() { + return nil, fmt.Errorf("VoyageAI embedding supports up to [%d] pieces of data at a time, got [%d]", provider.MaxBatch(), numRows) + } + var textType string + if mode == InsertMode { + textType = "document" + } else { + textType = "query" + } + + data := make([][]float32, 0, numRows) + for i := 0; i < numRows; i += provider.maxBatch { + end := i + provider.maxBatch + if end > numRows { + end = numRows + } + resp, err := provider.client.Embedding(provider.modelName, texts[i:end], int(provider.embedDimParam), textType, "float", provider.timeoutSec) + if err != nil { + return nil, err + } + if end-i != len(resp.Data) { + return nil, fmt.Errorf("Get embedding failed. The number of texts and embeddings does not match text:[%d], embedding:[%d]", end-i, len(resp.Data)) + } + for _, item := range resp.Data { + if len(item.Embedding) != int(provider.fieldDim) { + return nil, fmt.Errorf("The required embedding dim is [%d], but the embedding obtained from the model is [%d]", + provider.fieldDim, len(item.Embedding)) + } + data = append(data, item.Embedding) + } + } + return data, nil +} diff --git a/internal/util/function/voyageai_embedding_provider_test.go b/internal/util/function/voyageai_embedding_provider_test.go new file mode 100644 index 0000000000000..8f7bc4277df0e --- /dev/null +++ b/internal/util/function/voyageai_embedding_provider_test.go @@ -0,0 +1,166 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + +package function + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/util/function/models/ali" +) + +func TestVoyageAITextEmbeddingProvider(t *testing.T) { + suite.Run(t, new(VoyageAITextEmbeddingProviderSuite)) +} + +type VoyageAITextEmbeddingProviderSuite struct { + suite.Suite + schema *schemapb.CollectionSchema + providers []string +} + +func (s *VoyageAITextEmbeddingProviderSuite) SetupTest() { + s.schema = &schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }, + }, + }, + } + s.providers = []string{aliDashScopeProvider} +} + +func createVoyageAIProvider(url string, schema *schemapb.FieldSchema, providerName string) (textEmbeddingProvider, error) { + functionSchema := &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Unknown, + InputFieldNames: []string{"text"}, + OutputFieldNames: []string{"vector"}, + InputFieldIds: []int64{101}, + OutputFieldIds: []int64{102}, + Params: []*commonpb.KeyValuePair{ + {Key: modelNameParamKey, Value: voyage3}, + {Key: apiKeyParamKey, Value: "mock"}, + {Key: embeddingURLParamKey, Value: url}, + {Key: dimParamKey, Value: "4"}, + }, + } + switch providerName { + case aliDashScopeProvider: + return NewVoyageAIEmbeddingProvider(schema, functionSchema) + default: + return nil, fmt.Errorf("Unknow provider") + } +} + +func (s *VoyageAITextEmbeddingProviderSuite) TestEmbedding() { + ts := CreateVoyageAIEmbeddingServer() + + defer ts.Close() + for _, provderName := range s.providers { + provder, err := createVoyageAIProvider(ts.URL, s.schema.Fields[2], provderName) + s.NoError(err) + { + data := []string{"sentence"} + ret, err2 := provder.CallEmbedding(data, false, InsertMode) + s.NoError(err2) + s.Equal(1, len(ret)) + s.Equal(4, len(ret[0])) + s.Equal([]float32{0.0, 0.1, 0.2, 0.3}, ret[0]) + } + { + data := []string{"sentence 1", "sentence 2", "sentence 3"} + ret, _ := provder.CallEmbedding(data, false, SearchMode) + s.Equal([][]float32{{0.0, 0.1, 0.2, 0.3}, {1.0, 1.1, 1.2, 1.3}, {2.0, 2.1, 2.2, 2.3}}, ret) + } + } +} + +func (s *VoyageAITextEmbeddingProviderSuite) TestEmbeddingDimNotMatch() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var res ali.EmbeddingResponse + res.Output.Embeddings = append(res.Output.Embeddings, ali.Embeddings{ + Embedding: []float32{1.0, 1.0, 1.0, 1.0}, + TextIndex: 0, + }) + + res.Output.Embeddings = append(res.Output.Embeddings, ali.Embeddings{ + Embedding: []float32{1.0, 1.0}, + TextIndex: 1, + }) + res.Usage = ali.Usage{ + TotalTokens: 100, + } + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + + defer ts.Close() + for _, providerName := range s.providers { + provder, err := createVoyageAIProvider(ts.URL, s.schema.Fields[2], providerName) + s.NoError(err) + + // embedding dim not match + data := []string{"sentence", "sentence"} + _, err2 := provder.CallEmbedding(data, false, InsertMode) + s.Error(err2) + } +} + +func (s *VoyageAITextEmbeddingProviderSuite) TestEmbeddingNumberNotMatch() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var res ali.EmbeddingResponse + res.Output.Embeddings = append(res.Output.Embeddings, ali.Embeddings{ + Embedding: []float32{1.0, 1.0, 1.0, 1.0}, + TextIndex: 0, + }) + res.Usage = ali.Usage{ + TotalTokens: 100, + } + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + + defer ts.Close() + for _, provderName := range s.providers { + provder, err := createVoyageAIProvider(ts.URL, s.schema.Fields[2], provderName) + + s.NoError(err) + + // embedding dim not match + data := []string{"sentence", "sentence2"} + _, err2 := provder.CallEmbedding(data, false, InsertMode) + s.Error(err2) + } +}