Skip to content

Commit

Permalink
feat: add CoT support for DeepSeek-R1
Browse files Browse the repository at this point in the history
  • Loading branch information
yuukibarns committed Jan 26, 2025
1 parent e06c018 commit ac26099
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 21 deletions.
43 changes: 38 additions & 5 deletions lua/gp/dispatcher.lua
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,19 @@ D.prepare_payload = function(messages, model, provider)
output.stream = false
end

if provider == "deepseek" and model.model == "deepseek-reasoner" then
for i = #messages, 1, -1 do
if messages[i].role == "system" then
table.remove(messages, i)
end
end
-- remove max_tokens, top_p, temperature for reason model
output.max_tokens = nil
output.temperature = nil
output.top_p = nil
output.stream = true
end

return output
end

Expand All @@ -198,6 +211,7 @@ end
---@param on_exit function | nil # optional on_exit handler
---@param callback function | nil # optional callback handler
local query = function(buf, provider, payload, handler, on_exit, callback)
local is_reasoning = payload.model == "deepseek-reasoner"
-- make sure handler is a function
if type(handler) ~= "function" then
logger.error(
Expand Down Expand Up @@ -238,9 +252,15 @@ local query = function(buf, provider, payload, handler, on_exit, callback)
qt.raw_response = qt.raw_response .. line .. "\n"
end
line = line:gsub("^data: ", "")

local content = ""
local reasoning_content = ""

if line:match("choices") and line:match("delta") and line:match("content") then
line = vim.json.decode(line)
if line.choices[1] and line.choices[1].delta and line.choices[1].delta.reasoning_content then
reasoning_content = line.choices[1].delta.reasoning_content
end
if line.choices[1] and line.choices[1].delta and line.choices[1].delta.content then
content = line.choices[1].delta.content
end
Expand All @@ -264,10 +284,16 @@ local query = function(buf, provider, payload, handler, on_exit, callback)
end
end


if content and type(content) == "string" then
if reasoning_content ~= "" and type(reasoning_content) == "string" then
handler(qid, reasoning_content, true)
elseif content ~= "" and type(content) == "string" then
if is_reasoning then
handler(qid, "\n", true)
handler(qid, "\n</details><!-- }}} -->\n\n", false)
is_reasoning = false
end
qt.response = qt.response .. content
handler(qid, content)
handler(qid, content, false)
end
end
end
Expand Down Expand Up @@ -393,7 +419,7 @@ local query = function(buf, provider, payload, handler, on_exit, callback)
end

local temp_file = D.query_dir ..
"/" .. logger.now() .. "." .. string.format("%x", math.random(0, 0xFFFFFF)) .. ".json"
"/" .. logger.now() .. "." .. string.format("%x", math.random(0, 0xFFFFFF)) .. ".json"
helpers.table_to_file(payload, temp_file)

local curl_params = vim.deepcopy(D.config.curl_params or {})
Expand Down Expand Up @@ -463,7 +489,7 @@ D.create_handler = function(buf, win, line, first_undojoin, prefix, cursor)
})

local response = ""
return vim.schedule_wrap(function(qid, chunk)
return vim.schedule_wrap(function(qid, chunk, is_reasoning)
local qt = tasker.get_query(qid)
if not qt then
return
Expand Down Expand Up @@ -503,6 +529,13 @@ D.create_handler = function(buf, win, line, first_undojoin, prefix, cursor)
lines[i] = prefix .. l
end

-- prepend prefix > to each line inside CoT
if is_reasoning then
for i, l in ipairs(lines) do
lines[i] = "> " .. l
end
end

local unfinished_lines = {}
for i = finished_lines + 1, #lines do
table.insert(unfinished_lines, lines[i])
Expand Down
58 changes: 42 additions & 16 deletions lua/gp/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1031,22 +1031,38 @@ M.chat_respond = function(params)
agent_suffix = M.render.template(agent_suffix, { ["{{agent}}"] = agent_name })

local old_default_user_prefix = "🗨:"
local in_cot_block = false -- Flag to track if we're inside a CoT block

for index = start_index, end_index do
local line = lines[index]
if line:sub(1, #M.config.chat_user_prefix) == M.config.chat_user_prefix then
table.insert(messages, { role = role, content = content })
role = "user"
content = line:sub(#M.config.chat_user_prefix + 1)
elseif line:sub(1, #old_default_user_prefix) == old_default_user_prefix then
table.insert(messages, { role = role, content = content })
role = "user"
content = line:sub(#old_default_user_prefix + 1)
elseif line:sub(1, #agent_prefix) == agent_prefix then
table.insert(messages, { role = role, content = content })
role = "assistant"
content = ""
elseif role ~= "" then
content = content .. "\n" .. line

-- Check if the line starts with <details> or </details>
if line:match("^%s*<details>") then
in_cot_block = true
end

-- Skip lines if we're inside a CoT block
if not in_cot_block then
-- Original logic for handling chat messages
if line:sub(1, #M.config.chat_user_prefix) == M.config.chat_user_prefix then
table.insert(messages, { role = role, content = content })
role = "user"
content = line:sub(#M.config.chat_user_prefix + 1)
elseif line:sub(1, #old_default_user_prefix) == old_default_user_prefix then
table.insert(messages, { role = role, content = content })
role = "user"
content = line:sub(#old_default_user_prefix + 1)
elseif line:sub(1, #agent_prefix) == agent_prefix then
table.insert(messages, { role = role, content = content })
role = "assistant"
content = ""
elseif role ~= "" then
content = content .. "\n" .. line
end
end

if line:match("^%s*</details>") then
in_cot_block = false
end
end
-- insert last message not handled in loop
Expand Down Expand Up @@ -1074,12 +1090,21 @@ M.chat_respond = function(params)
local last_content_line = M.helpers.last_content_line(buf)
vim.api.nvim_buf_set_lines(buf, last_content_line, last_content_line, false, { "", agent_prefix .. agent_suffix, "" })

local offset = 0
-- Add CoT for DeepSeek-Reasoner
if agent_name == "DeepSeekReasoner" then
vim.api.nvim_buf_set_lines(buf, last_content_line + 3, last_content_line + 3, false,
{ "<details>", "<summary>CoT</summary><!-- {{{ -->", "" })
offset = 1
end

-- call the model and write response
M.dispatcher.query(
buf,
headers.provider or agent.provider,
M.dispatcher.prepare_payload(messages, headers.model or agent.model, headers.provider or agent.provider),
M.dispatcher.create_handler(buf, win, M.helpers.last_content_line(buf), true, "", not M.config.chat_free_cursor),
M.dispatcher.create_handler(buf, win, M.helpers.last_content_line(buf) + offset, true, "",
not M.config.chat_free_cursor),
vim.schedule_wrap(function(qid)
local qt = M.tasker.get_query(qid)
if not qt then
Expand Down Expand Up @@ -1125,7 +1150,8 @@ M.chat_respond = function(params)
topic_handler,
vim.schedule_wrap(function()
-- get topic from invisible buffer
local topic = vim.api.nvim_buf_get_lines(topic_buf, 0, -1, false)[1]
-- instead of the first line, get the last two line can skip CoT
local topic = vim.api.nvim_buf_get_lines(topic_buf, -3, -1, false)[1]
-- close invisible buffer
vim.api.nvim_buf_delete(topic_buf, { force = true })
-- strip whitespace from ends of topic
Expand Down

0 comments on commit ac26099

Please sign in to comment.