diff --git a/lua/gp/config.lua b/lua/gp/config.lua index d4ddcd8d..23a45651 100644 --- a/lua/gp/config.lua +++ b/lua/gp/config.lua @@ -122,7 +122,7 @@ local config = { chat = true, command = false, -- string with model name or table with model name and parameters - model = { model = "gpt-4", temperature = 1.1, top_p = 1 }, + model = { model = "gpt-4o", temperature = 1.1, top_p = 1 }, -- system prompt (use this to specify the persona/role of the AI) system_prompt = require("gp.defaults").chat_system_prompt, }, @@ -220,7 +220,7 @@ local config = { chat = false, command = true, -- string with the Copilot engine name or table with engine name and parameters if applicable - model = { model = "gpt-4", temperature = 0.8, top_p = 1, n = 1 }, + model = { model = "gpt-4o", temperature = 0.8, top_p = 1, n = 1 }, -- system prompt (use this to specify the persona/role of the AI) system_prompt = require("gp.defaults").code_system_prompt, }, diff --git a/lua/gp/dispatcher.lua b/lua/gp/dispatcher.lua index 932b9fe3..e3e41e5a 100644 --- a/lua/gp/dispatcher.lua +++ b/lua/gp/dispatcher.lua @@ -44,13 +44,9 @@ D.setup = function(opts) end end - + local callbacks = { copilot = vault.refresh_copilot_bearer } for name, provider in pairs(D.providers) do - if name == "copilot" then - vault.resolve_secret(name, provider.secret, vault.refresh_copilot_bearer) - else - vault.resolve_secret(name, provider.secret) - end + vault.resolve_secret(name, provider.secret, callbacks[name]) provider.secret = nil end @@ -152,6 +148,10 @@ D.prepare_payload = function(messages, model, provider) return payload end + if provider == "copilot" and model.model == "gpt-4o" then + model.model = "gpt-4o-2024-05-13" + end + return { model = model.model, stream = true, @@ -304,7 +304,9 @@ D.query = function(buf, provider, payload, handler, on_exit, callback) if provider == "copilot" then vault.refresh_copilot_bearer() bearer = vault.get_secret("copilot_bearer") - if not bearer then + local expires_at = vault.get_secret("copilot_bearer_expires_at") + if not bearer or not expires_at or expires_at < os.time() then + logger.warning("copilot bearer token is missing or expired, trying to refresh..") return end headers = { diff --git a/lua/gp/vault.lua b/lua/gp/vault.lua index 78d599a2..141bf7a5 100644 --- a/lua/gp/vault.lua +++ b/lua/gp/vault.lua @@ -136,6 +136,7 @@ V.refresh_copilot_bearer = function() if bearer.token and bearer.expires_at and bearer.expires_at > os.time() then logger.debug("vault refresh_copilot_bearer token still valid", true) secrets.copilot_bearer = bearer.token + secrets.copilot_bearer_expires_at = bearer.expires_at return end