Skip to content

Commit

Permalink
Add voyageai text embedding
Browse files Browse the repository at this point in the history
Signed-off-by: junjie.jiang <[email protected]>
  • Loading branch information
junjiejiangjjj committed Jan 7, 2025
1 parent 6100ccd commit 65275f9
Show file tree
Hide file tree
Showing 7 changed files with 651 additions and 1 deletion.
13 changes: 13 additions & 0 deletions internal/util/function/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,19 @@ const (
vertexServiceAccountJSONEnv string = "MILVUSAI_GOOGLE_APPLICATION_CREDENTIALS"
)

// voyageAI
const (
voyage3Large string = "voyage-3-large"
voyage3 string = "voyage-3"
voyage3Lite string = "voyage-3-lite"
voyageCode3 string = "voyage-code-3"
voyageFinance2 string = "voyage-finance-2"
voyageLaw2 string = "voyage-law-2"
voyageCode2 string = "voyage-code-2"

voyageAIAKEnvStr string = "MILVUSAI_VOYAGEAI_API_KEY"
)

func parseAndCheckFieldDim(dimStr string, fieldDim int64, fieldName string) (int64, error) {
dim, err := strconv.ParseInt(dimStr, 10, 64)
if err != nil {
Expand Down
27 changes: 27 additions & 0 deletions internal/util/function/mock_embedding_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/milvus-io/milvus/internal/util/function/models/ali"
"github.com/milvus-io/milvus/internal/util/function/models/openai"
"github.com/milvus-io/milvus/internal/util/function/models/vertexai"
"github.com/milvus-io/milvus/internal/util/function/models/voyageai"
)

func mockEmbedding(texts []string, dim int) [][]float32 {
Expand Down Expand Up @@ -99,6 +100,32 @@ func CreateAliEmbeddingServer() *httptest.Server {
return ts
}

func CreateVoyageAIEmbeddingServer() *httptest.Server {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req voyageai.EmbeddingRequest
body, _ := io.ReadAll(r.Body)
defer r.Body.Close()
json.Unmarshal(body, &req)
embs := mockEmbedding(req.Input, int(req.OutputDimension))
var res voyageai.EmbeddingResponse
for i := 0; i < len(req.Input); i++ {
res.Data = append(res.Data, voyageai.EmbeddingData{
Object: "list",
Embedding: embs[i],
Index: i,
})
}

res.Usage = voyageai.Usage{
TotalTokens: 100,
}
w.WriteHeader(http.StatusOK)
data, _ := json.Marshal(res)
w.Write(data)
}))
return ts
}

func CreateVertexAIEmbeddingServer() *httptest.Server {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req vertexai.EmbeddingRequest
Expand Down
152 changes: 152 additions & 0 deletions internal/util/function/models/voyageai/voyageai_text_embedding.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package voyageai

import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"sort"
"time"

"github.com/milvus-io/milvus/internal/util/function/models/utils"
)

type EmbeddingRequest struct {
// ID of the model to use.
Model string `json:"model"`

// Input text to embed, encoded as a string.
Input []string `json:"input"`

InputType string `json:"input_type,omitempty"`

Truncation bool `json:"truncation,omitempty"`

OutputDimension int64 `json:"output_dimension,omitempty"`

OutputDtype string `json:"output_dtype,omitempty"`

EncodingFormat string `json:"encoding_format,omitempty"`
}

type Usage struct {
// The total number of tokens used by the request.
TotalTokens int `json:"total_tokens"`
}

type EmbeddingData struct {
Object string `json:"object"`

Embedding []float32 `json:"embedding"`

Index int `json:"index"`
}

type EmbeddingResponse struct {
Object string `json:"object"`

Data []EmbeddingData `json:"data"`

Model string `json:"model"`

Usage Usage `json:"usage"`
}

type ByIndex struct {
resp *EmbeddingResponse
}

func (eb *ByIndex) Len() int { return len(eb.resp.Data) }
func (eb *ByIndex) Swap(i, j int) {
eb.resp.Data[i], eb.resp.Data[j] = eb.resp.Data[j], eb.resp.Data[i]
}

func (eb *ByIndex) Less(i, j int) bool {
return eb.resp.Data[i].Index < eb.resp.Data[j].Index
}

type ErrorInfo struct {
Code string `json:"code"`
Message string `json:"message"`
RequestID string `json:"request_id"`
}

type VoyageAIEmbedding struct {
apiKey string
url string
}

func NewVoyageAIEmbeddingClient(apiKey string, url string) *VoyageAIEmbedding {
return &VoyageAIEmbedding{
apiKey: apiKey,
url: url,
}
}

func (c *VoyageAIEmbedding) Check() error {
if c.apiKey == "" {
return fmt.Errorf("api key is empty")
}

if c.url == "" {
return fmt.Errorf("url is empty")
}
return nil
}

func (c *VoyageAIEmbedding) Embedding(modelName string, texts []string, dim int, textType string, outputType string, timeoutSec int64) (*EmbeddingResponse, error) {
var r EmbeddingRequest
r.Model = modelName
r.Input = texts
r.InputType = textType
r.OutputDtype = outputType
if dim != 0 {
r.OutputDimension = int64(dim)
}
data, err := json.Marshal(r)
if err != nil {
return nil, err
}

if timeoutSec <= 0 {
timeoutSec = utils.DefaultTimeout
}

ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewBuffer(data))
if err != nil {
return nil, err
}

req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
body, err := utils.RetrySend(req, 3)
if err != nil {
return nil, err
}
var res EmbeddingResponse
err = json.Unmarshal(body, &res)
if err != nil {
return nil, err
}
sort.Sort(&ByIndex{&res})
return &res, err
}
127 changes: 127 additions & 0 deletions internal/util/function/models/voyageai/voyageai_text_embedding_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package voyageai

import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
)

func TestEmbeddingClientCheck(t *testing.T) {
{
c := NewVoyageAIEmbeddingClient("", "mock_uri")
err := c.Check()
assert.True(t, err != nil)
fmt.Println(err)
}

{
c := NewVoyageAIEmbeddingClient("mock_key", "")
err := c.Check()
assert.True(t, err != nil)
fmt.Println(err)
}

{
c := NewVoyageAIEmbeddingClient("mock_key", "mock_uri")
err := c.Check()
assert.True(t, err == nil)
}
}

func TestEmbeddingOK(t *testing.T) {
var res EmbeddingResponse
repStr := `{
"object": "list",
"data": [
{
"object": "embedding",
"embedding": [
0.0,
0.1
],
"index": 0
},
{
"object": "embedding",
"embedding": [
2.0,
2.1
],
"index": 2
},
{
"object": "embedding",
"embedding": [
1.0,
1.1
],
"index": 1
}
],
"model": "voyage-large-2",
"usage": {
"total_tokens": 10
}
}`
err := json.Unmarshal([]byte(repStr), &res)
assert.NoError(t, err)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
data, _ := json.Marshal(res)
w.Write(data)
}))

defer ts.Close()
url := ts.URL

{
c := NewVoyageAIEmbeddingClient("mock_key", url)
err := c.Check()
assert.True(t, err == nil)
ret, err := c.Embedding("voyage-3", []string{"sentence"}, 0, "query", "float", 0)
assert.True(t, err == nil)
assert.Equal(t, ret.Data[0].Index, 0)
assert.Equal(t, ret.Data[1].Index, 1)
assert.Equal(t, ret.Data[2].Index, 2)
assert.Equal(t, ret.Data[0].Embedding, []float32{0.0, 0.1})
assert.Equal(t, ret.Data[1].Embedding, []float32{1.0, 1.1})
assert.Equal(t, ret.Data[2].Embedding, []float32{2.0, 2.1})
}
}

func TestEmbeddingFailed(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
}))

defer ts.Close()
url := ts.URL

{
c := NewVoyageAIEmbeddingClient("mock_key", url)
err := c.Check()
assert.True(t, err == nil)
_, err = c.Embedding("voyage-3", []string{"sentence"}, 0, "query", "float", 0)
assert.True(t, err != nil)
}
}
12 changes: 11 additions & 1 deletion internal/util/function/text_embedding_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ const (
aliDashScopeProvider string = "dashscope"
bedrockProvider string = "bedrock"
vertexAIProvider string = "vertexai"
voyageAIProvider string = "voyageai"
)

// Text embedding for retrieval task
Expand Down Expand Up @@ -130,8 +131,17 @@ func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, functionSchema *s
FunctionBase: *base,
embProvider: embP,
}, nil
case voyageAIProvider:
embP, err := NewVoyageAIEmbeddingProvider(base.outputFields[0], functionSchema)
if err != nil {
return nil, err
}
return &TextEmbeddingFunction{
FunctionBase: *base,
embProvider: embP,
}, nil
default:
return nil, fmt.Errorf("Unsupported embedding service provider: [%s] , list of supported [%s, %s, %s, %s, %s]", provider, openAIProvider, azureOpenAIProvider, aliDashScopeProvider, bedrockProvider, vertexAIProvider)
return nil, fmt.Errorf("Unsupported text embedding service provider: [%s] , list of supported [%s, %s, %s, %s, %s, %s]", provider, openAIProvider, azureOpenAIProvider, aliDashScopeProvider, bedrockProvider, vertexAIProvider, voyageAIProvider)
}
}

Expand Down
Loading

0 comments on commit 65275f9

Please sign in to comment.