Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add CoT support for DeepSeek-R1 (only for reference) #228

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion lua/gp/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ local config = {
-- secret : "sk-...",
-- secret = os.getenv("env_name.."),
openai = {
disable = false,
disable = true,
endpoint = "https://api.openai.com/v1/chat/completions",
-- secret = os.getenv("OPENAI_API_KEY"),
},
Expand Down Expand Up @@ -103,6 +103,7 @@ local config = {
disable = true,
},
{
provider = "openai",
name = "ChatGPT4o",
chat = true,
command = false,
Expand Down
60 changes: 49 additions & 11 deletions lua/gp/dispatcher.lua
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,19 @@ D.prepare_payload = function(messages, model, provider)
output.stream = false
end

if provider == "deepseek" and model.model == "deepseek-reasoner" then
for i = #messages, 1, -1 do
if messages[i].role == "system" then
table.remove(messages, i)
end
end
-- remove max_tokens, temperature and top_p for the reasoning model, and force streaming
output.max_tokens = nil
output.temperature = nil
output.top_p = nil
output.stream = true
end

return output
end

Expand All @@ -198,6 +211,7 @@ end
---@param on_exit function | nil # optional on_exit handler
---@param callback function | nil # optional callback handler
local query = function(buf, provider, payload, handler, on_exit, callback)
local is_reasoning = payload.model == "deepseek-reasoner"
-- make sure handler is a function
if type(handler) ~= "function" then
logger.error(
Expand Down Expand Up @@ -238,9 +252,15 @@ local query = function(buf, provider, payload, handler, on_exit, callback)
qt.raw_response = qt.raw_response .. line .. "\n"
end
line = line:gsub("^data: ", "")

local content = ""
local reasoning_content = ""

if line:match("choices") and line:match("delta") and line:match("content") then
line = vim.json.decode(line)
if line.choices[1] and line.choices[1].delta and line.choices[1].delta.reasoning_content then
reasoning_content = line.choices[1].delta.reasoning_content
end
if line.choices[1] and line.choices[1].delta and line.choices[1].delta.content then
content = line.choices[1].delta.content
end
Expand All @@ -264,10 +284,16 @@ local query = function(buf, provider, payload, handler, on_exit, callback)
end
end


if content and type(content) == "string" then
if reasoning_content ~= "" and type(reasoning_content) == "string" then
handler(qid, reasoning_content, true)
elseif content ~= "" and type(content) == "string" then
if is_reasoning then
handler(qid, "\n", true)
handler(qid, "\n</details>\n\n", false)
is_reasoning = false
end
qt.response = qt.response .. content
handler(qid, content)
handler(qid, content, false)
end
end
end
Expand Down Expand Up @@ -311,11 +337,16 @@ local query = function(buf, provider, payload, handler, on_exit, callback)
end
end


if qt.response == "" then
logger.error(qt.provider .. " response is empty: \n" .. vim.inspect(qt.raw_response))
if is_reasoning then
handler(qid, "\n", true)
handler(qid, "\n</details>\n\n", false)
is_reasoning = false
end

-- if qt.response == "" then
-- logger.error(qt.provider .. " response is empty: \n" .. vim.inspect(qt.raw_response))
-- end

-- optional on_exit handler
if type(on_exit) == "function" then
on_exit(qid)
Expand Down Expand Up @@ -393,7 +424,7 @@ local query = function(buf, provider, payload, handler, on_exit, callback)
end

local temp_file = D.query_dir ..
"/" .. logger.now() .. "." .. string.format("%x", math.random(0, 0xFFFFFF)) .. ".json"
"/" .. logger.now() .. "." .. string.format("%x", math.random(0, 0xFFFFFF)) .. ".json"
helpers.table_to_file(payload, temp_file)

local curl_params = vim.deepcopy(D.config.curl_params or {})
Expand Down Expand Up @@ -463,7 +494,7 @@ D.create_handler = function(buf, win, line, first_undojoin, prefix, cursor)
})

local response = ""
return vim.schedule_wrap(function(qid, chunk)
return vim.schedule_wrap(function(qid, chunk, is_reasoning)
local qt = tasker.get_query(qid)
if not qt then
return
Expand Down Expand Up @@ -503,6 +534,13 @@ D.create_handler = function(buf, win, line, first_undojoin, prefix, cursor)
lines[i] = prefix .. l
end

-- prepend the "> " quote prefix to each line inside the CoT
if is_reasoning then
for i, l in ipairs(lines) do
lines[i] = "> " .. l
end
end

local unfinished_lines = {}
for i = finished_lines + 1, #lines do
table.insert(unfinished_lines, lines[i])
Expand All @@ -511,9 +549,9 @@ D.create_handler = function(buf, win, line, first_undojoin, prefix, cursor)
vim.api.nvim_buf_set_lines(buf, first_line + finished_lines, first_line + finished_lines, false, unfinished_lines)

local new_finished_lines = math.max(0, #lines - 1)
for i = finished_lines, new_finished_lines do
vim.api.nvim_buf_add_highlight(buf, qt.ns_id, hl_handler_group, first_line + i, 0, -1)
end
-- for i = finished_lines, new_finished_lines do
-- vim.api.nvim_buf_add_highlight(buf, qt.ns_id, hl_handler_group, first_line + i, 0, -1)
-- end
finished_lines = new_finished_lines

local end_line = first_line + #vim.split(response, "\n")
Expand Down
58 changes: 42 additions & 16 deletions lua/gp/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1031,22 +1031,38 @@ M.chat_respond = function(params)
agent_suffix = M.render.template(agent_suffix, { ["{{agent}}"] = agent_name })

local old_default_user_prefix = "🗨:"
local in_cot_block = false -- Flag to track if we're inside a CoT block

for index = start_index, end_index do
local line = lines[index]
if line:sub(1, #M.config.chat_user_prefix) == M.config.chat_user_prefix then
table.insert(messages, { role = role, content = content })
role = "user"
content = line:sub(#M.config.chat_user_prefix + 1)
elseif line:sub(1, #old_default_user_prefix) == old_default_user_prefix then
table.insert(messages, { role = role, content = content })
role = "user"
content = line:sub(#old_default_user_prefix + 1)
elseif line:sub(1, #agent_prefix) == agent_prefix then
table.insert(messages, { role = role, content = content })
role = "assistant"
content = ""
elseif role ~= "" then
content = content .. "\n" .. line

-- Enter a CoT block when the line begins (after optional whitespace) with <details>; the closing </details> is checked after the line has been processed
if line:match("^%s*<details>") then
in_cot_block = true
end

-- Skip lines if we're inside a CoT block
if not in_cot_block then
-- Original logic for handling chat messages
if line:sub(1, #M.config.chat_user_prefix) == M.config.chat_user_prefix then
table.insert(messages, { role = role, content = content })
role = "user"
content = line:sub(#M.config.chat_user_prefix + 1)
elseif line:sub(1, #old_default_user_prefix) == old_default_user_prefix then
table.insert(messages, { role = role, content = content })
role = "user"
content = line:sub(#old_default_user_prefix + 1)
elseif line:sub(1, #agent_prefix) == agent_prefix then
table.insert(messages, { role = role, content = content })
role = "assistant"
content = ""
elseif role ~= "" then
content = content .. "\n" .. line
end
end

if line:match("^%s*</details>") then
in_cot_block = false
end
end
-- insert last message not handled in loop
Expand Down Expand Up @@ -1074,12 +1090,21 @@ M.chat_respond = function(params)
local last_content_line = M.helpers.last_content_line(buf)
vim.api.nvim_buf_set_lines(buf, last_content_line, last_content_line, false, { "", agent_prefix .. agent_suffix, "" })

local offset = 0
-- Add CoT for DeepSeek-Reasoner
if agent_name == "DeepSeekReasoner" then
vim.api.nvim_buf_set_lines(buf, last_content_line + 3, last_content_line + 3, false,
{ "<details>", "<summary>CoT</summary>", "" })
offset = 1
end

-- call the model and write response
M.dispatcher.query(
buf,
headers.provider or agent.provider,
M.dispatcher.prepare_payload(messages, headers.model or agent.model, headers.provider or agent.provider),
M.dispatcher.create_handler(buf, win, M.helpers.last_content_line(buf), true, "", not M.config.chat_free_cursor),
M.dispatcher.create_handler(buf, win, M.helpers.last_content_line(buf) + offset, true, "",
not M.config.chat_free_cursor),
vim.schedule_wrap(function(qid)
local qt = M.tasker.get_query(qid)
if not qt then
Expand Down Expand Up @@ -1125,7 +1150,8 @@ M.chat_respond = function(params)
topic_handler,
vim.schedule_wrap(function()
-- get topic from invisible buffer
local topic = vim.api.nvim_buf_get_lines(topic_buf, 0, -1, false)[1]
-- read the last two lines (using the first of them) instead of the first line, so that the CoT is skipped
local topic = vim.api.nvim_buf_get_lines(topic_buf, -3, -1, false)[1]
-- close invisible buffer
vim.api.nvim_buf_delete(topic_buf, { force = true })
-- strip whitespace from ends of topic
Expand Down