Skip to content

Commit

Permalink
Add bedrock and sagemaker clients (#81)
Browse files Browse the repository at this point in the history
* Add aws.go

* Skip tests for now
  • Loading branch information
billytrend-cohere authored Jun 26, 2024
1 parent a867fbf commit faf1913
Show file tree
Hide file tree
Showing 6 changed files with 564 additions and 24 deletions.
3 changes: 2 additions & 1 deletion .fernignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ README.md
LICENSE
.github/workflows/e2e.yml
tests/
.github/workflows/ci.yml
.github/workflows/ci.yml
client/aws.go
258 changes: 258 additions & 0 deletions client/aws.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
package client

import (
"bytes"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
errors "errors"
"fmt"
"io"
"net/http"
"strings"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
option "github.com/cohere-ai/cohere-go/v2/option"
)

func newAwsRequestOptions(opts ...AwsRequestOption) *AwsRequestOptions {
options := &AwsRequestOptions{}
for _, opt := range opts {
opt.applyRequestOptions(options)
}
return options
}

type awsClient struct {
AwsRequestOptions
Service string
}

// RequestOption adapts the behavior of the client or an individual request.
type AwsRequestOption interface {
applyRequestOptions(*AwsRequestOptions)
}

type AwsRequestOptions struct {
environment string
awsRegion string
awsAccessKey string
awsSecretKey string
awsSessionToken string
}

// AwsEnvironment implements the RequestOption interface.
type AwsEnvironment struct {
environment string
}

func (h *AwsEnvironment) applyRequestOptions(opts *AwsRequestOptions) {
opts.environment = h.environment
}

type AwsRegion struct {
awsRegion string
}

func (h *AwsRegion) applyRequestOptions(opts *AwsRequestOptions) {
opts.awsRegion = h.awsRegion
}

func WithAwsRegion(region string) *AwsRegion {
return &AwsRegion{awsRegion: region}
}

type AwsAccessKey struct {
awsAccessKey string
}

func (h *AwsAccessKey) applyRequestOptions(opts *AwsRequestOptions) {
opts.awsAccessKey = h.awsAccessKey
}

func WithAwsAccessKey(accessKey string) *AwsAccessKey {
return &AwsAccessKey{awsAccessKey: accessKey}
}

type AwsSecretKey struct {
awsSecretKey string
}

func (h *AwsSecretKey) applyRequestOptions(opts *AwsRequestOptions) {
opts.awsSecretKey = h.awsSecretKey
}

func WithAwsSecretKey(secretKey string) *AwsSecretKey {
return &AwsSecretKey{awsSecretKey: secretKey}
}

type AwsSessionToken struct {
awsSessionToken string
}

func (h *AwsSessionToken) applyRequestOptions(opts *AwsRequestOptions) {
opts.awsSessionToken = h.awsSessionToken
}

func WithAwsSessionToken(sessionToken string) *AwsSessionToken {
return &AwsSessionToken{awsSessionToken: sessionToken}
}

func NewAwsClient(baseOpts []option.RequestOption, awsOpts []AwsRequestOption, service string) *Client {
options := newAwsRequestOptions(awsOpts...)

baseOpts = append(
baseOpts,
WithHTTPClient(&awsClient{Service: service, AwsRequestOptions: *options}),
WithToken("n/a"),
)

return NewClient(baseOpts...)
}

func (b *awsClient) Do(req *http.Request) (*http.Response, error) {
isStream, err := b.setModelParams(req)
if err != nil {
return nil, err
}

err = signRequest(b.awsAccessKey, b.awsSecretKey, b.awsSessionToken, b.Service, b.awsRegion, req)
if err != nil {
return nil, err
}

resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}

// map response to expected bedrock stream response
if isStream && b.Service == "bedrock" {
resBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}

// This is a naive implementation that assumes the responses are all in the same format
var mappedResponseEvents []string
events := strings.Split(string(resBody), "message")
for _, event := range events {

fmt.Println(event)
// look for the payload
idx := strings.Index(event, "\"bytes\":")
if idx == -1 {
continue
}
// get the payload and append it to the mapped response
eventBody := strings.Split(event[idx:], "\"")[3]

idx1 := strings.Index(eventBody, "stream-end")
if idx1 != -1 {
continue
}

// decode the payload
decoded, err := base64.StdEncoding.DecodeString(eventBody)
if err != nil {
return nil, err
}
mappedResponseEvents = append(mappedResponseEvents, string(decoded))
}

resp.Body = io.NopCloser(strings.NewReader(strings.Join(mappedResponseEvents, "\n") + "\n"))
}

// parse response
return resp, err
}

// modify the request to point to the bedrock model and remove the model from the request body
// handle stream param
func (b *awsClient) setModelParams(req *http.Request) (bool, error) {
// try to parse the model from the request body
reqBody, err := io.ReadAll(req.Body)
if err != nil {
return false, err
}
jsonBody := map[string]interface{}{}
err = json.Unmarshal(reqBody, &jsonBody)
if err != nil {
return false, err
}
model, ok := jsonBody["model"].(string)
if !ok {
return false, errors.New("model not found in request body")
}
delete(jsonBody, "model")
stream, ok := jsonBody["stream"].(bool)
if !ok {
stream = false
} else if b.Service == "bedrock" {
delete(jsonBody, "stream")
}

reqBody, err = json.Marshal(jsonBody)
if err != nil {
return false, err
}

req.URL, err = req.URL.Parse(getUrl(b.Service, b.awsRegion, model, stream))
if err != nil {
return false, err
}

req.Body = io.NopCloser(bytes.NewReader(reqBody))
req.ContentLength = int64(len(reqBody))
return stream, nil
}

// see https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_aws-signing.html
func signRequest(accessID, secretKey, token, service, region string, req *http.Request) error {
signer := v4.NewSigner()

bodyBytes, err := io.ReadAll(req.Body)
if err != nil {
return err
}
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
sha := sha256.New()
_, err = sha.Write(bodyBytes)
if err != nil {
return err
}
payloadHash := hex.EncodeToString(sha.Sum(nil))

err = signer.SignHTTP(
req.Context(),
aws.Credentials{AccessKeyID: accessID, SecretAccessKey: secretKey, SessionToken: token},
req,
payloadHash,
service,
region,
time.Now(),
)
if err != nil {
return err
}
return nil
}

func getUrl(platform string, awsRegion string, model string, stream bool) string {
endpoint := map[string]map[bool]string{
"bedrock": {true: "invoke-with-response-stream", false: "invoke"},
"sagemaker": {true: "invocations-response-stream", false: "invocations"},
}
return fmt.Sprintf("https://%s-runtime.%s.amazonaws.com/model/%s/%s", platform, awsRegion, model, endpoint[platform][stream])
}

func NewBedrockClient(baseOpts []option.RequestOption, awsOpts []AwsRequestOption) *Client {
return NewAwsClient(baseOpts, awsOpts, "bedrock")
}

func NewSagemakerClient(baseOpts []option.RequestOption, awsOpts []AwsRequestOption) *Client {
return NewAwsClient(baseOpts, awsOpts, "sagemaker")
}
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ module github.com/cohere-ai/cohere-go/v2
go 1.18

require (
github.com/aws/aws-sdk-go-v2 v1.28.0
github.com/google/uuid v1.4.0
github.com/stretchr/testify v1.7.0
)

require (
github.com/aws/smithy-go v1.20.2 // indirect
github.com/davecgh/go-spew v1.1.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
github.com/aws/aws-sdk-go-v2 v1.28.0 h1:ne6ftNhY0lUvlazMUQF15FF6NH80wKmPRFG7g2q6TCw=
github.com/aws/aws-sdk-go-v2 v1.28.0/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM=
github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4=
Expand Down
Loading

0 comments on commit faf1913

Please sign in to comment.