Skip to content

Commit

Permalink
feat: copilot with gpt4-o
Browse files Browse the repository at this point in the history
  • Loading branch information
Robitx committed Aug 3, 2024
1 parent 02080d8 commit e1700d1
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 9 deletions.
4 changes: 2 additions & 2 deletions lua/gp/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ local config = {
chat = true,
command = false,
-- string with model name or table with model name and parameters
model = { model = "gpt-4", temperature = 1.1, top_p = 1 },
model = { model = "gpt-4o", temperature = 1.1, top_p = 1 },
-- system prompt (use this to specify the persona/role of the AI)
system_prompt = require("gp.defaults").chat_system_prompt,
},
Expand Down Expand Up @@ -220,7 +220,7 @@ local config = {
chat = false,
command = true,
-- string with the Copilot engine name or table with engine name and parameters if applicable
model = { model = "gpt-4", temperature = 0.8, top_p = 1, n = 1 },
model = { model = "gpt-4o", temperature = 0.8, top_p = 1, n = 1 },
-- system prompt (use this to specify the persona/role of the AI)
system_prompt = require("gp.defaults").code_system_prompt,
},
Expand Down
16 changes: 9 additions & 7 deletions lua/gp/dispatcher.lua
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,9 @@ D.setup = function(opts)
end
end


local callbacks = { copilot = vault.refresh_copilot_bearer }
for name, provider in pairs(D.providers) do
if name == "copilot" then
vault.resolve_secret(name, provider.secret, vault.refresh_copilot_bearer)
else
vault.resolve_secret(name, provider.secret)
end
vault.resolve_secret(name, provider.secret, callbacks[name])
provider.secret = nil
end

Expand Down Expand Up @@ -152,6 +148,10 @@ D.prepare_payload = function(messages, model, provider)
return payload
end

if provider == "copilot" and model.model == "gpt-4o" then
model.model = "gpt-4o-2024-05-13"
end

return {
model = model.model,
stream = true,
Expand Down Expand Up @@ -304,7 +304,9 @@ D.query = function(buf, provider, payload, handler, on_exit, callback)
if provider == "copilot" then
vault.refresh_copilot_bearer()
bearer = vault.get_secret("copilot_bearer")
if not bearer then
local expires_at = vault.get_secret("copilot_bearer_expires_at")
if not bearer or not expires_at or expires_at < os.time() then
logger.warning("copilot bearer token is missing or expired, trying to refresh..")
return
end
headers = {
Expand Down
1 change: 1 addition & 0 deletions lua/gp/vault.lua
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ V.refresh_copilot_bearer = function()
if bearer.token and bearer.expires_at and bearer.expires_at > os.time() then
logger.debug("vault refresh_copilot_bearer token still valid", true)
secrets.copilot_bearer = bearer.token
secrets.copilot_bearer_expires_at = bearer.expires_at
return
end

Expand Down

0 comments on commit e1700d1

Please sign in to comment.