Commit
invoke AWS Bedrock, not tested yet
Signed-off-by: cbh778899 <[email protected]>
cbh778899 committed Sep 25, 2024
1 parent 15b3ac5 commit 213e70a
Showing 3 changed files with 189 additions and 90 deletions.
2 changes: 2 additions & 0 deletions Makefile
@@ -7,6 +7,7 @@ ENV_FILE:=.env
APP_PORT:=8000
DATABASE_BIND_PATH:=./lancedb
REGION:=ap-southeast-2
MODEL_ID:=anthropic.claude-3-sonnet-20240229-v1:0

# build and run this service only
.PHONY: build
@@ -22,6 +23,7 @@ env:
@echo "APP_PORT=$(APP_PORT)"> $(ENV_FILE)
@echo "DATABASE_BIND_PATH=$(DATABASE_BIND_PATH)">> $(ENV_FILE)
@echo "REGION=$(REGION)">> $(ENV_FILE)
@echo "MODEL_ID=$(MODEL_ID)">> $(ENV_FILE)

# normal build & up
.PHONY: compose-build
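For reference, with the defaults shown in these hunks, running `make env` should regenerate the service's `.env` along the following lines (a sketch based only on the visible echo lines; variables defined outside this diff are not shown):

APP_PORT=8000
DATABASE_BIND_PATH=./lancedb
REGION=ap-southeast-2
MODEL_ID=anthropic.claude-3-sonnet-20240229-v1:0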
81 changes: 81 additions & 0 deletions actions/bedrock.js
@@ -0,0 +1,81 @@
import { BedrockRuntimeClient, ConverseCommand, ConverseStreamCommand } from "@aws-sdk/client-bedrock-runtime";

/**
* @type { BedrockRuntimeClient? }
*/
let client = null;

/**
* Initialize or re-initialize the Bedrock runtime client in the given region.
*/
export function rebuildBedrockClient() {
client = new BedrockRuntimeClient({
region: process.env.REGION || 'ap-southeast-2'
})
}

/**
* @callback InferenceCallback
* @param {String} text_piece A piece of the generated text
* @param {Boolean} finished Whether the response has finished
*/

/**
* @typedef MessageContent
* @property {String} text text of the content
*/

/**
* @typedef Message
* @property {"user"|"assistant"} role
* @property {MessageContent[]} content
*/

/**
* @typedef Settings
* @property {Boolean} stream Whether to stream the response
* @property {Number} max_tokens The maximum number of tokens in the response
* @property {Number} top_p Top P of the request
* @property {Number} temperature Temperature of the request
*/

/**
* Run inference with AWS Bedrock
* @param {Message[]} messages Messages to run inference on
* @param {Settings} settings Inference settings
* @param {InferenceCallback} cb Optional callback invoked with each piece of generated text
* @returns {Promise<String>} The full response text, whether streaming or not
*/
export async function inference(messages, settings, cb = null) {
if(!client) rebuildBedrockClient();

const { top_p, temperature, max_tokens } = settings;

const input = {
modelId: process.env.MODEL_ID || 'anthropic.claude-3-sonnet-20240229-v1:0',
messages,
inferenceConfig: {
maxTokens: max_tokens || 2048,
temperature: temperature || 0.7,
topP: top_p || 0.9
}
}

let command;
if(settings.stream) command = new ConverseStreamCommand(input);
else command = new ConverseCommand(input);

const response = await client.send(command);

let response_text = '';
if(settings.stream) {
for await (const resp of response.stream) {
if(resp.contentBlockDelta) {
const text_piece = resp.contentBlockDelta.delta.text;
response_text += text_piece;
cb && cb(text_piece, false);
}
}
} else {
// non-stream responses carry the full message in output.message.content
response_text = (response.output?.message?.content || [])
.map(block => block.text || '').join('');
}
cb && cb('', true)

return response_text;
}
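For context, a minimal usage sketch of the new helper (hypothetical caller code, not part of this commit; it assumes AWS credentials and access to the chosen Bedrock model are already configured, and that the caller sits at the repository root):

import { inference } from "./actions/bedrock.js";

// One user turn in the Converse message format expected by inference().
const messages = [
    { role: "user", content: [{ text: "Summarise AWS Bedrock in one sentence." }] }
];

// With stream: true the callback receives each text piece as it arrives;
// the resolved promise still carries the full response text.
const full_text = await inference(
    messages,
    { stream: true, max_tokens: 256, temperature: 0.7, top_p: 0.9 },
    (text_piece, finished) => {
        if (!finished) process.stdout.write(text_piece);
    }
);
console.log("\n---\n" + full_text);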
196 changes: 106 additions & 90 deletions actions/inference.js
@@ -13,12 +13,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.

// import { formatOpenAIContext } from "../tools/formatContext.js";
import { generateFingerprint } from "../tools/generator.js";
// import { post } from "../tools/request.js";
// import { searchByMessage } from "../database/rag-inference.js";
// import { userMessageHandler } from "../tools/plugin.js";
import { extractAPIKeyFromHeader, validateAPIKey } from "../tools/apiKey.js";
import { inference } from "./bedrock.js";

/**
* Generates a response content object for chat completion.
Expand Down Expand Up @@ -69,33 +70,33 @@ function generateResponseContent(
return resp;
}

/**
* Post to inference engine
* @param {Object} req_body The request body to be sent
* @param {Function} callback Callback function, takes one parameter contains parsed response json
* @param {Boolean} isStream To set the callback behaviour
*/
async function doInference(req_body, callback, isStream) {
if(isStream) {
const eng_resp = await post('completion', { body: req_body }, { getJSON: false });
const reader = eng_resp.body.pipeThrough(new TextDecoderStream()).getReader();
while(true) {
const { value, done } = await reader.read();
if(done) break;
const data = value.split("data: ").pop()
try {
callback(JSON.parse(data));
} catch(error) {
console.log(error)
callback({content: "", stop: true})
}
}
} else {
const eng_resp = await post('completion', { body: req_body });
if(eng_resp.http_error) return;
callback(eng_resp);
}
}
// /**
// * Post to inference engine
// * @param {Object} req_body The request body to be sent
// * @param {Function} callback Callback function, takes one parameter contains parsed response json
// * @param {Boolean} isStream To set the callback behaviour
// */
// async function doInference(req_body, callback, isStream) {
// if(isStream) {
// const eng_resp = await post('completion', { body: req_body }, { getJSON: false });
// const reader = eng_resp.body.pipeThrough(new TextDecoderStream()).getReader();
// while(true) {
// const { value, done } = await reader.read();
// if(done) break;
// const data = value.split("data: ").pop()
// try {
// callback(JSON.parse(data));
// } catch(error) {
// console.log(error)
// callback({content: "", stop: true})
// }
// }
// } else {
// const eng_resp = await post('completion', { body: req_body });
// if(eng_resp.http_error) return;
// callback(eng_resp);
// }
// }

function retrieveData(req_header, req_body) {
// retrieve api key
@@ -105,7 +106,7 @@ function retrieveData(req_header, req_body) {
}

// get attributes required special consideration
let { messages, max_tokens, ...request_body } = req_body;
let { messages, ...request_body } = req_body;

// validate messages
if(!messages || !messages.length) {
@@ -118,11 +119,11 @@
})

// apply n_predict value
// if(!max_tokens) max_tokens = 128;
// request_body.n_predict = max_tokens;

// apply stop value
// if(!req_body.stop) request_body.stop = [...default_stop_keywords];

// generated fields
const system_fingerprint = generateFingerprint();
@@ -133,7 +134,7 @@

}

// const default_stop_keywords = ["<|endoftext|>", "<|end|>", "<|user|>", "<|assistant|>"]

/**
* Handles a chat completion request, generating a response based on the input messages.
Expand Down Expand Up @@ -191,73 +192,88 @@ export async function chatCompletion(req, res) {
res.setHeader("X-Accel-Buffering", "no");
res.setHeader("Connection", "Keep-Alive");
}
// doInference(request_body, (data) => {
// const { content, stop } = data;
// if(isStream) {
// res.write(JSON.stringify(
// generateResponseContent(
// api_key, 'chat.completion.chunk', model, system_fingerprint, isStream, content, stop
// )
// )+'\n\n');
// if(stop) res.end();
// } else {
// res.send(generateResponseContent(
// api_key, 'chat.completion', model, system_fingerprint,
// isStream, content, true
// ))
// }
// }, isStream)
const response_text = await inference(messages, request_body, (text_piece, finished) => {
if(isStream) {
res.write(JSON.stringify(
generateResponseContent(
api_key, 'chat.completion.chunk', model, system_fingerprint, isStream, text_piece, finished
)
)+'\n\n');
if(finished) res.end();
}
})
if(!isStream) {
res.send(generateResponseContent(
api_key, 'chat.completion', model, system_fingerprint,
isStream, response_text, true
))
}
}

// /**
// * Handles a RAG-based (Retrieval-Augmented Generation) chat completion request.
// *
// * @async
// * @param {Request} req - The HTTP request object.
// * @param {Response} res - The HTTP response object.
// * @returns {Promise<void>} A promise that resolves when the response is sent.
// */
// export async function ragChatCompletion(req, res) {
// const {error, body, status, message} = retrieveData(req.headers, req.body);
// if(error) {
// res.status(status).send(message);
// return;
// }
// const { dataset_name, ...request_body } = body.request_body;
// if(!dataset_name) {
// res.status(422).send("Dataset name not specified.");
// }
// const { api_key, model, system_fingerprint, messages } = body

// const latest_message = messages.slice(-1)[0].content;
// const rag_result = await searchByMessage(dataset_name, latest_message);

// const context = [...messages];
// if(rag_result) context.push({
// role: "system",
// content: `This background information is useful for your next answer: "${rag_result.context}"`
// })
// request_body.prompt = formatOpenAIContext(context);

// const isStream = !!request_body.stream;
// if(isStream) {
// res.setHeader("Content-Type", "text/event-stream");
// res.setHeader("Cache-Control", "no-cache");
// res.setHeader("X-Accel-Buffering", "no");
// res.setHeader("Connection", "Keep-Alive");
// }
// doInference(request_body, (data) => {
// const { content, stop } = data;
// const openai_response = generateResponseContent(
// api_key, 'chat.completion.chunk', model, system_fingerprint, true, content, stop
// )
// const rag_response = stop ? { content: openai_response, rag_context: rag_result } : openai_response;

// if(isStream) {
// res.write(JSON.stringify(rag_response)+'\n\n');
// if(stop) res.end();
// } else {
// res.send(rag_response);
// }
// }, isStream)
// }
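As a reading aid, a sketch of the request body that chatCompletion now forwards to Bedrock (hypothetical values; field names follow the Settings and Message typedefs in actions/bedrock.js, and since parts of retrieveData are collapsed in this diff the exact wire format may differ):

// Hypothetical payload for the chat completion route (route path not shown in this diff).
// retrieveData splits off messages; the remaining fields travel through request_body
// and become the settings argument of inference().
const body = {
    messages: [
        { role: "user", content: [{ text: "Hello, who are you?" }] }
    ],
    stream: true,      // picks ConverseStreamCommand instead of ConverseCommand
    max_tokens: 512,   // mapped to inferenceConfig.maxTokens
    temperature: 0.7,  // mapped to inferenceConfig.temperature
    top_p: 0.9         // mapped to inferenceConfig.topP
};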
