Skip to content

Commit

Permalink
feat(ai-proxy): support OpenAI format /v1/models API; respect `mode…
Browse files Browse the repository at this point in the history
…l` field in request body in order (#6359)

* feat(ai-proxy): support OpenAI format `/v1/models` API; respect `model` field in request body in order

* check body at context filter

* remove useless code

* fix goimport
  • Loading branch information
sfwn authored Jun 6, 2024
1 parent e83827d commit 2ff6dfd
Show file tree
Hide file tree
Showing 10 changed files with 281 additions and 23 deletions.
1 change: 1 addition & 0 deletions cmd/ai-proxy/bootstrap.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ erda.app.ai-proxy:
- conf/routes/routes.yml
- conf/routes/assistant.yml
- conf/routes/file.yml
- conf/routes/openai_format.yml
log_level: ${LOG_LEVEL:info}
open_on_erda: ${OPEN_ON_ERDA:true} # 是否将 API 通过 Erda Openapi 暴露出来

Expand Down
5 changes: 5 additions & 0 deletions cmd/ai-proxy/conf/routes/openai_format.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
routes:
- path: /v1/models
method: GET
filters:
- name: openai-v1-models
1 change: 1 addition & 0 deletions internal/apps/ai-proxy/dependent_filters.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
_ "github.com/erda-project/erda/internal/apps/ai-proxy/filters/initialize"
_ "github.com/erda-project/erda/internal/apps/ai-proxy/filters/log-http"
_ "github.com/erda-project/erda/internal/apps/ai-proxy/filters/openai-director"
_ "github.com/erda-project/erda/internal/apps/ai-proxy/filters/openai-v1-models"
_ "github.com/erda-project/erda/internal/apps/ai-proxy/filters/prometheus-collector"
_ "github.com/erda-project/erda/internal/apps/ai-proxy/filters/rate-limit"
)
54 changes: 39 additions & 15 deletions internal/apps/ai-proxy/filters/context/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
sessionpb "github.com/erda-project/erda-proto-go/apps/aiproxy/session/pb"
"github.com/erda-project/erda/internal/apps/ai-proxy/common"
"github.com/erda-project/erda/internal/apps/ai-proxy/common/ctxhelper"
openai_v1_models "github.com/erda-project/erda/internal/apps/ai-proxy/filters/openai-v1-models"
"github.com/erda-project/erda/internal/apps/ai-proxy/models/client_token"
"github.com/erda-project/erda/internal/apps/ai-proxy/models/metadata"
"github.com/erda-project/erda/internal/apps/ai-proxy/providers/dao"
Expand Down Expand Up @@ -66,6 +67,15 @@ func (f *Context) OnRequest(ctx context.Context, w http.ResponseWriter, infor re
q = ctx.Value(vars.CtxKeyDAO{}).(dao.DAO)
m = ctx.Value(reverseproxy.CtxKeyMap{}).(*sync.Map)
)

// check body
body := infor.BodyBuffer()
if body == nil {
err = fmt.Errorf("missing body")
l.Error(err)
return reverseproxy.Intercept, err
}

// find client
var client *clientpb.Client
ak := vars.TrimBearer(infor.Header().Get(httputil.HeaderKeyAuthorization))
Expand Down Expand Up @@ -114,6 +124,8 @@ func (f *Context) OnRequest(ctx context.Context, w http.ResponseWriter, infor re
// get from session if exists
headerSessionId := infor.Header().Get(vars.XAIProxySessionId)
headerModelId := infor.Header().Get(vars.XAIProxyModelId)
var allModelIDsByPriority []string
allModelIDsByPriority = append(allModelIDsByPriority, headerModelId)
if headerSessionId != "" && headerSessionId != vars.UIValueUndefined {
_session, err := q.SessionClient().Get(ctx, &sessionpb.SessionGetRequest{Id: headerSessionId})
if err != nil {
Expand All @@ -123,27 +135,39 @@ func (f *Context) OnRequest(ctx context.Context, w http.ResponseWriter, infor re
}
session = _session
if session.ModelId != "" {
sessionModel, err := q.ModelClient().Get(ctx, &modelpb.ModelGetRequest{Id: session.ModelId})
if err != nil {
l.Errorf("failed to get model, id: %s, err: %v", session.ModelId, err)
http.Error(w, "ModelId is invalid", http.StatusBadRequest)
return reverseproxy.Intercept, err
allModelIDsByPriority = append(allModelIDsByPriority, session.ModelId)
}
}
allModelIDsByPriority = strutil.DedupSlice(allModelIDsByPriority, true)
// if no model id found, respect 'model' field in request body
if len(allModelIDsByPriority) == 0 {
type Model struct {
ModelID string `json:"model"`
}
var m Model
if err := json.NewDecoder(body).Decode(&m); err == nil {
if m.ModelID != "" {
// parse the real model uuid out of the display name, which is generated for the `/v1/models` api, see: internal/apps/ai-proxy/filters/openai-v1-models/model_name.go#GenerateModelDisplayName
uuid := openai_v1_models.ParseModelUUIDFromDisplayName(m.ModelID)
if uuid != "" {
allModelIDsByPriority = append(allModelIDsByPriority, uuid)
}
}
model = sessionModel
}
} else if headerModelId != "" {
// get from model header
if headerModelId == "" {
http.Error(w, fmt.Sprintf("header %s is required", vars.XAIProxyModelId), http.StatusBadRequest)
return reverseproxy.Intercept, nil
}
for _, modelID := range allModelIDsByPriority {
if modelID == "" {
continue
}
headerModel, err := q.ModelClient().Get(ctx, &modelpb.ModelGetRequest{Id: headerModelId})
_model, err := q.ModelClient().Get(ctx, &modelpb.ModelGetRequest{Id: modelID})
// do not skip error, because modelId must be valid or be empty
if err != nil {
l.Errorf("failed to get model, id: %s, err: %v", headerModelId, err)
http.Error(w, "ModelId is invalid", http.StatusBadRequest)
l.Errorf("failed to get model, id: %s, err: %v", modelID, err)
http.Error(w, fmt.Sprintf("ModelId %s is invalid", modelID), http.StatusBadRequest)
return reverseproxy.Intercept, err
}
model = headerModel
model = _model
break
}
if model == nil {
// get client default model
Expand Down
91 changes: 91 additions & 0 deletions internal/apps/ai-proxy/filters/openai-v1-models/filter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Copyright (c) 2021 Terminus, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package openai_v1_models

import (
"context"
"encoding/json"
"fmt"
"net/http"

"github.com/sashabaranov/go-openai"

richclientpb "github.com/erda-project/erda-proto-go/apps/aiproxy/client/rich_client/pb"
"github.com/erda-project/erda/internal/apps/ai-proxy/common/ctxhelper"
"github.com/erda-project/erda/internal/apps/ai-proxy/handlers/common/akutil"
"github.com/erda-project/erda/internal/apps/ai-proxy/handlers/handler_rich_client"
"github.com/erda-project/erda/internal/apps/ai-proxy/vars"
"github.com/erda-project/erda/pkg/reverseproxy"
)

const (
// Name is the name under which this filter's creator is registered with
// reverseproxy (see init); it is the value route configs use in `filters[].name`.
Name = "openai-v1-models"
)

var (
// Compile-time check that *Filter satisfies reverseproxy.RequestFilter.
_ reverseproxy.RequestFilter = (*Filter)(nil)
)

// init registers New with reverseproxy as the creator for the filter named Name.
func init() {
reverseproxy.RegisterFilterCreator(Name, New)
}

// Filter is the request filter registered under Name. It carries no state;
// everything it needs is read from the request context in OnRequest.
type Filter struct {
}

// New creates a Filter. The raw filter configuration is ignored, and the
// returned error is always nil.
func New(_ json.RawMessage) (reverseproxy.Filter, error) {
return &Filter{}, nil
}

// OnRequest writes the list of models available to the calling client in the
// OpenAI `/v1/models` response format.
//
// The caller is resolved from the request's access key or token, its rich
// client is loaded by access key id, and every model of that rich client is
// converted to an openai.Model whose ID is the display name produced by
// GenerateModelDisplayName.
//
// Every failure is reported to the client with http.Error and answered with
// reverseproxy.Intercept and a nil error. On success the JSON body is written
// and reverseproxy.Continue is returned.
func (f *Filter) OnRequest(ctx context.Context, w http.ResponseWriter, infor reverseproxy.HttpInfor) (signal reverseproxy.Signal, err error) {
	richClientHandler := ctx.Value(vars.CtxKeyRichClientHandler{}).(*handler_rich_client.ClientHandler)

	// resolve the client from the access key / token carried by the request
	client, err := akutil.CheckAkOrToken(ctx, infor.Request(), ctxhelper.MustGetDBClient(ctx))
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to get client, err: %v", err), http.StatusInternalServerError)
		return reverseproxy.Intercept, nil
	}
	if client == nil {
		http.Error(w, "Client not found", http.StatusUnauthorized)
		return reverseproxy.Intercept, nil
	}
	ctx = context.WithValue(ctx, vars.CtxKeyClient{}, client)
	ctx = context.WithValue(ctx, vars.CtxKeyClientId{}, client.Id)

	richClient, err := richClientHandler.GetByAccessKeyId(ctx, &richclientpb.GetByClientAccessKeyIdRequest{AccessKeyId: client.AccessKeyId})
	if err != nil {
		http.Error(w, "Failed to get rich client", http.StatusInternalServerError)
		return reverseproxy.Intercept, nil
	}

	// convert to openai /v1/models response, see: https://platform.openai.com/docs/api-reference/models/list
	var oaiFormatModels openai.ModelsList
	// Start from a non-nil, pre-sized slice: a nil slice is encoded as JSON
	// `null`, while a client without models must still receive an empty array.
	oaiFormatModels.Models = make([]openai.Model, 0, len(richClient.Models))
	for _, m := range richClient.Models {
		oaiFormatModels.Models = append(oaiFormatModels.Models, openai.Model{
			ID:        GenerateModelDisplayName(m),
			CreatedAt: m.Model.CreatedAt.Seconds, // seconds
			Object:    "model",                   // always "model"
			OwnedBy:   m.Provider.Name,
		})
	}

	// write response
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(oaiFormatModels); err != nil {
		// NOTE(review): if part of the body was already written, the status
		// code set here can no longer take effect.
		http.Error(w, "Failed to write response", http.StatusInternalServerError)
		return reverseproxy.Intercept, nil
	}
	return reverseproxy.Continue, nil
}
52 changes: 52 additions & 0 deletions internal/apps/ai-proxy/filters/openai-v1-models/model_name.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Copyright (c) 2021 Terminus, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package openai_v1_models

import (
"regexp"

richclientpb "github.com/erda-project/erda-proto-go/apps/aiproxy/client/rich_client/pb"
)

// GenerateModelDisplayName builds the display name of a model: the model name,
// a single space, then a run of bracketed attributes in this order:
//
//	[T:<provider type>]  always present
//	[L:<location>]       only when the provider's public metadata has a non-empty "location"
//	[ID:<model id>]      always present, always last
//
// ParseModelUUIDFromDisplayName extracts the model id back out of the result.
func GenerateModelDisplayName(model *richclientpb.RichModel) string {
	name := model.Model.Name
	provider := model.Provider

	// provider type
	tags := "[T:" + provider.Type.String() + "]"

	// provider location
	if meta := provider.Metadata; meta != nil && meta.Public != nil {
		if location := meta.Public["location"]; location != "" {
			tags += "[L:" + location + "]"
		}
	}

	// model id at last
	tags += "[ID:" + model.Model.Id + "]"

	return name + " " + tags
}

// modelIDAttrRegex matches an "[ID:<id>]" attribute and captures <id>.
// It is compiled once at package scope instead of on every call.
var modelIDAttrRegex = regexp.MustCompile(`\[ID:([^]]*)]`)

// ParseModelUUIDFromDisplayName returns the model id carried by the first
// "[ID:<id>]" attribute in s, i.e. the id that GenerateModelDisplayName
// appends to a display name. It returns "" when s has no such attribute;
// an empty attribute ("[ID:]") also yields "".
func ParseModelUUIDFromDisplayName(s string) string {
	match := modelIDAttrRegex.FindStringSubmatch(s)
	if match == nil {
		return ""
	}
	return match[1]
}
57 changes: 57 additions & 0 deletions internal/apps/ai-proxy/filters/openai-v1-models/model_name_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Copyright (c) 2021 Terminus, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package openai_v1_models

import "testing"

// TestParseModelUUIDFromDisplayName covers a name without an ID attribute,
// a name with exactly one, and a name with two (the first one must win).
func TestParseModelUUIDFromDisplayName(t *testing.T) {
	cases := []struct {
		name  string
		input string
		want  string
	}{
		{name: "test1", input: "test", want: ""},
		{name: "test2", input: "test [ID:123]", want: "123"},
		{name: "test3", input: "test [ID:123][ID:456]", want: "123"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := ParseModelUUIDFromDisplayName(tc.input)
			if got != tc.want {
				t.Errorf("ParseModelUUIDFromDisplayName() = %v, want %v", got, tc.want)
			}
		})
	}
}
29 changes: 27 additions & 2 deletions internal/apps/ai-proxy/handlers/common/akutil/ak.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,39 @@ func (util *AKUtil) AkToClient(ak string) (*clientpb.Client, error) {
return pagingResp.List[0], nil
}

func (util *AKUtil) GetAkFromHeader(ctx context.Context) (string, bool) {
// GetAkFromHeader looks up the access key for a call, trying the HTTP request
// first (see GetAkFromHTTPHeader) and then the gRPC header carried by ctx
// (see GetAkFromGRPCHeader). The boolean reports whether a non-empty key was
// found by either lookup.
func (util *AKUtil) GetAkFromHeader(ctx context.Context, req any) (string, bool) {
	// HTTP takes precedence
	if ak, found := util.GetAkFromHTTPHeader(req); found {
		return ak, true
	}
	// fall back to GRPC
	if ak, found := util.GetAkFromGRPCHeader(ctx); found {
		return ak, true
	}
	return "", false
}

// GetAkFromGRPCHeader reads the Authorization header from ctx via
// apis.GetHeader and passes it through vars.TrimBearer. The boolean reports
// whether the resulting access key is non-empty.
func (util *AKUtil) GetAkFromGRPCHeader(ctx context.Context) (string, bool) {
	ak := vars.TrimBearer(apis.GetHeader(ctx, httputil.HeaderKeyAuthorization))
	return ak, ak != ""
}

// GetAkFromHTTPHeader reads the Authorization header from req when req is an
// *http.Request and passes it through vars.TrimBearer. The boolean reports
// whether the resulting access key is non-empty.
//
// Any other value of req, including nil and a nil *http.Request, is treated as
// "no header present".
func (util *AKUtil) GetAkFromHTTPHeader(req any) (string, bool) {
	var v string
	// The pointer itself must be checked too: a nil *http.Request stored in an
	// `any` is a non-nil interface value, and reading r.Header on it would panic.
	if r, ok := req.(*http.Request); ok && r != nil {
		v = r.Header.Get(httputil.HeaderKeyAuthorization)
	}
	v = vars.TrimBearer(v)
	return v, v != ""
}

func CheckAkOrToken(ctx context.Context, req any, dao dao.DAO) (*clientpb.Client, error) {
ak, ok := New(dao).GetAkFromHeader(ctx)
ak, ok := New(dao).GetAkFromHeader(ctx, req)
if !ok {
return nil, handlers.ErrAkNotFound
}
Expand Down
6 changes: 5 additions & 1 deletion internal/apps/ai-proxy/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ type provider struct {
Dao dao.DAO `autowired:"erda.apps.ai-proxy.dao"`
DynamicOpenapi dynamic.DynamicOpenapiRegisterServer `autowired:"erda.core.openapi.dynamic_register.DynamicOpenapiRegister"`
ErdaOpenapis map[string]*url.URL

richClientHandler *handler_rich_client.ClientHandler
}

func (p *provider) Init(ctx servicehub.Context) error {
Expand All @@ -133,7 +135,8 @@ func (p *provider) Init(ctx servicehub.Context) error {
promptpb.RegisterPromptServiceImp(p, &handler_prompt.PromptHandler{DAO: p.Dao}, apis.Options(), encoderOpts, trySetAuth(p.Dao), permission.CheckPromptPerm)
sessionpb.RegisterSessionServiceImp(p, &handler_session.SessionHandler{DAO: p.Dao}, apis.Options(), encoderOpts, trySetAuth(p.Dao), permission.CheckSessionPerm)
clienttokenpb.RegisterClientTokenServiceImp(p, &handler_client_token.ClientTokenHandler{DAO: p.Dao}, apis.Options(), encoderOpts, trySetAuth(p.Dao), permission.CheckClientTokenPerm)
richclientpb.RegisterRichClientServiceImp(p, &handler_rich_client.ClientHandler{DAO: p.Dao}, apis.Options(), encoderOpts, trySetAuth(p.Dao), permission.CheckRichClientPerm)
p.richClientHandler = &handler_rich_client.ClientHandler{DAO: p.Dao}
richclientpb.RegisterRichClientServiceImp(p, p.richClientHandler, apis.Options(), encoderOpts, trySetAuth(p.Dao), permission.CheckRichClientPerm)

// ai-proxy prometheus metrics
p.HTTP.Handle("/metrics", http.MethodGet, promhttp.Handler())
Expand Down Expand Up @@ -166,6 +169,7 @@ func (p *provider) ServeAIProxy() {
reverseproxy.CtxKeyMap{}, new(sync.Map),
vars.CtxKeyDAO{}, p.Dao,
vars.CtxKeyErdaOpenapi{}, p.ErdaOpenapis,
vars.CtxKeyRichClientHandler{}, p.richClientHandler,
).ServeHTTP(w, r)
}
p.HTTP.HandlePrefix("/", "*", f, mux.SetXRequestId, mux.CORS)
Expand Down
Loading

0 comments on commit 2ff6dfd

Please sign in to comment.