Skip to content

Commit

Permalink
Add openai embedding client
Browse files Browse the repository at this point in the history
Signed-off-by: junjie.jiang <[email protected]>
  • Loading branch information
junjiejiangjjj committed Sep 20, 2024
1 parent 89397d1 commit fd701cd
Show file tree
Hide file tree
Showing 2 changed files with 377 additions and 0 deletions.
192 changes: 192 additions & 0 deletions pkg/models/openai_embedding.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package models

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
)

const (
TextEmbeddingAda002 string = "text-embedding-ada-002"
TextEmbedding3Small string = "text-embedding-3-small"
TextEmbedding3Large string = "text-embedding-3-large"
)


type EmbeddingRequest struct {
// ID of the model to use.
Model string `json:"model"`

// Input text to embed, encoded as a string.
Input []string `json:"input"`

// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
User string `json:"user,omitempty"`

// The format to return the embeddings in. Can be either float or base64.
EncodingFormat string `json:"encoding_format,omitempty"`

// The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
Dimensions int `json:"dimensions,omitempty"`
}

type Usage struct {
// The number of tokens used by the prompt.
PromptTokens int `json:"prompt_tokens"`

// The total number of tokens used by the request.
TotalTokens int `json:"total_tokens"`
}


type EmbeddingData struct {
// The object type, which is always "embedding".
Object string `json:"object"`

// The embedding vector, which is a list of floats.
Embedding []float32 `json:"embedding"`

// The index of the embedding in the list of embeddings.
Index int `json:"index"`
}


type EmbeddingResponse struct {
// The object type, which is always "list".
Object string `json:"object"`

// The list of embeddings generated by the model.
Data []EmbeddingData `json:"data"`

// The name of the model used to generate the embedding.
Model string `json:"model"`

// The usage information for the request.
Usage Usage `json:"usage"`
}

type ErrorInfo struct {
Code string `json:"code"`
Message string `json:"message"`
Param string `json:"param,omitempty"`
Type string `json:"type"`
}

type EmbedddingError struct {
Error ErrorInfo `json:"error"`
}

type OpenAIEmbeddingClient struct {
api_key string
uri string
model_name string
}

func (c *OpenAIEmbeddingClient) Check() error {
if c.model_name != TextEmbeddingAda002 && c.model_name != TextEmbedding3Small && c.model_name != TextEmbedding3Large {
return fmt.Errorf("Unsupported model: %s, only support [%s, %s, %s]",
c.model_name, TextEmbeddingAda002, TextEmbedding3Small, TextEmbedding3Large)
}

if c.api_key == "" {
return fmt.Errorf("OpenAI api key is empty")
}

if c.uri == "" {
return fmt.Errorf("OpenAI embedding uri is empty")
}
return nil
}


func (c *OpenAIEmbeddingClient) send(client *http.Client, req *http.Request, res *EmbeddingResponse) error {
// call openai
resp, err := client.Do(req)

if err != nil {
return err
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
return err
}

if resp.StatusCode != 200 {
return fmt.Errorf(string(body))
}

err = json.Unmarshal(body, &res)
if err != nil {
return err
}
return nil
}

func (c *OpenAIEmbeddingClient) sendWithRetry(client *http.Client, req *http.Request,res *EmbeddingResponse, max_retries int) error {
var err error
for i := 0; i < max_retries; i++ {
err = c.send(client, req, res)
if err == nil {
return nil
}
}
return err
}

func (c *OpenAIEmbeddingClient) Embedding(texts []string, dim int, user string, timeout_sec time.Duration) (EmbeddingResponse, error) {
var r EmbeddingRequest
r.Model = c.model_name
r.Input = texts
r.EncodingFormat = "float"
if user != "" {
r.User = user
}
if dim != 0 {
r.Dimensions = dim
}

var res EmbeddingResponse
data, err := json.Marshal(r)
if err != nil {
return res, err
}

// call openai
if timeout_sec <= 0 {
timeout_sec = 30
}
client := &http.Client{
Timeout: timeout_sec * time.Second,
}
req, err := http.NewRequest("POST" , c.uri, bytes.NewBuffer(data))
if err != nil {
return res, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("api-key", c.api_key)

err = c.sendWithRetry(client, req, &res, 3)
return res, err

}
185 changes: 185 additions & 0 deletions pkg/models/openai_embedding_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package models

import (
// "bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestEmbeddingClientCheck(t *testing.T) {
{
c := OpenAIEmbeddingClient{"mock_key", "mock_uri", "unknow_model"}
err := c.Check();
assert.True(t, err != nil)
fmt.Println(err)
}

{
c := OpenAIEmbeddingClient{"", "mock_uri", TextEmbeddingAda002}
err := c.Check();
assert.True(t, err != nil)
fmt.Println(err)
}

{
c := OpenAIEmbeddingClient{"mock_key", "", TextEmbedding3Small}
err := c.Check();
assert.True(t, err != nil)
fmt.Println(err)
}

{
c := OpenAIEmbeddingClient{"mock_key", "mock_uri", TextEmbedding3Small}
err := c.Check();
assert.True(t, err == nil)
}
}


func TestEmbeddingOK(t *testing.T) {
var res EmbeddingResponse
res.Object = "list"
res.Model = TextEmbedding3Small
res.Data = []EmbeddingData{
{
Object: "embedding",
Embedding: []float32{1.1, 2.2, 3.3, 4.4},
Index: 0,
},
}
res.Usage = Usage{
PromptTokens: 1,
TotalTokens: 100,
}

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
data, _ := json.Marshal(res)
w.Write(data)
}))

defer ts.Close()
url := ts.URL

{
c := OpenAIEmbeddingClient{"mock_key", url, TextEmbedding3Small}
err := c.Check();
assert.True(t, err == nil)
ret, err := c.Embedding([]string{"sentence"}, 0, "", 0)
assert.True(t, err == nil)
assert.Equal(t, ret, res)
}
}


func TestEmbeddingRetry(t *testing.T) {
var res EmbeddingResponse
res.Object = "list"
res.Model = TextEmbedding3Small
res.Data = []EmbeddingData{
{
Object: "embedding",
Embedding: []float32{1.1, 2.2, 3.3, 4.4},
Index: 0,
},
}
res.Usage = Usage{
PromptTokens: 1,
TotalTokens: 100,
}

var count = 0

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if count < 2 {
count += 1
w.WriteHeader(http.StatusUnauthorized)
} else {
w.WriteHeader(http.StatusOK)
data, _ := json.Marshal(res)
w.Write(data)
}
}))

defer ts.Close()
url := ts.URL

{
c := OpenAIEmbeddingClient{"mock_key", url, TextEmbedding3Small}
err := c.Check();
assert.True(t, err == nil)
ret, err := c.Embedding([]string{"sentence"}, 0, "", 0)
assert.True(t, err == nil)
assert.Equal(t, ret, res)
assert.Equal(t, count, 2)
}
}


func TestEmbeddingFailed(t *testing.T) {
var count = 0

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
count += 1
w.WriteHeader(http.StatusUnauthorized)
}))

defer ts.Close()
url := ts.URL

{
c := OpenAIEmbeddingClient{"mock_key", url, TextEmbedding3Small}
err := c.Check();
assert.True(t, err == nil)
_, err = c.Embedding([]string{"sentence"}, 0, "", 0)
assert.True(t, err != nil)
assert.Equal(t, count, 3)
}
}

func TestTimeout(t *testing.T) {
var st = "Doing"

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(3 * time.Second)
st = "Done"
w.WriteHeader(http.StatusUnauthorized)

}))

defer ts.Close()
url := ts.URL

{
c := OpenAIEmbeddingClient{"mock_key", url, TextEmbedding3Small}
err := c.Check();
assert.True(t, err == nil)
_, err = c.Embedding([]string{"sentence"}, 0, "", 1)
assert.True(t, err != nil)
assert.Equal(t, st, "Doing")
time.Sleep(3 * time.Second)
assert.Equal(t, st, "Done")
}
}

0 comments on commit fd701cd

Please sign in to comment.