diff --git a/lua/gp/dispatcher.lua b/lua/gp/dispatcher.lua index cdfadbb3..19a23486 100644 --- a/lua/gp/dispatcher.lua +++ b/lua/gp/dispatcher.lua @@ -187,6 +187,19 @@ D.prepare_payload = function(messages, model, provider) output.stream = false end + if provider == "deepseek" and model.model == "deepseek-reasoner" then + for i = #messages, 1, -1 do + if messages[i].role == "system" then + table.remove(messages, i) + end + end + -- remove max_tokens, top_p, temperature for reason model + output.max_tokens = nil + output.temperature = nil + output.top_p = nil + output.stream = true + end + return output end @@ -198,6 +211,7 @@ end ---@param on_exit function | nil # optional on_exit handler ---@param callback function | nil # optional callback handler local query = function(buf, provider, payload, handler, on_exit, callback) + local is_reasoning = payload.model == "deepseek-reasoner" -- make sure handler is a function if type(handler) ~= "function" then logger.error( @@ -238,9 +252,15 @@ local query = function(buf, provider, payload, handler, on_exit, callback) qt.raw_response = qt.raw_response .. line .. "\n" end line = line:gsub("^data: ", "") + local content = "" + local reasoning_content = "" + if line:match("choices") and line:match("delta") and line:match("content") then line = vim.json.decode(line) + if line.choices[1] and line.choices[1].delta and line.choices[1].delta.reasoning_content then + reasoning_content = line.choices[1].delta.reasoning_content + end if line.choices[1] and line.choices[1].delta and line.choices[1].delta.content then content = line.choices[1].delta.content end @@ -264,10 +284,16 @@ local query = function(buf, provider, payload, handler, on_exit, callback) end end - - if content and type(content) == "string" then + if reasoning_content ~= "" and type(reasoning_content) == "string" then + handler(qid, reasoning_content, true) + elseif content ~= "" and type(content) == "string" then + if is_reasoning then + handler(qid, "\n", true) + handler(qid, "\n\n\n", false) + is_reasoning = false + end qt.response = qt.response .. content - handler(qid, content) + handler(qid, content, false) end end end @@ -393,7 +419,7 @@ local query = function(buf, provider, payload, handler, on_exit, callback) end local temp_file = D.query_dir .. - "/" .. logger.now() .. "." .. string.format("%x", math.random(0, 0xFFFFFF)) .. ".json" + "/" .. logger.now() .. "." .. string.format("%x", math.random(0, 0xFFFFFF)) .. ".json" helpers.table_to_file(payload, temp_file) local curl_params = vim.deepcopy(D.config.curl_params or {}) @@ -463,7 +489,7 @@ D.create_handler = function(buf, win, line, first_undojoin, prefix, cursor) }) local response = "" - return vim.schedule_wrap(function(qid, chunk) + return vim.schedule_wrap(function(qid, chunk, is_reasoning) local qt = tasker.get_query(qid) if not qt then return @@ -503,6 +529,13 @@ D.create_handler = function(buf, win, line, first_undojoin, prefix, cursor) lines[i] = prefix .. l end + -- prepend prefix > to each line inside CoT + if is_reasoning then + for i, l in ipairs(lines) do + lines[i] = "> " .. l + end + end + local unfinished_lines = {} for i = finished_lines + 1, #lines do table.insert(unfinished_lines, lines[i]) diff --git a/lua/gp/init.lua b/lua/gp/init.lua index 128914cb..f3eda11c 100644 --- a/lua/gp/init.lua +++ b/lua/gp/init.lua @@ -1031,22 +1031,38 @@ M.chat_respond = function(params) agent_suffix = M.render.template(agent_suffix, { ["{{agent}}"] = agent_name }) local old_default_user_prefix = "🗨:" + local in_cot_block = false -- Flag to track if we're inside a CoT block + for index = start_index, end_index do local line = lines[index] - if line:sub(1, #M.config.chat_user_prefix) == M.config.chat_user_prefix then - table.insert(messages, { role = role, content = content }) - role = "user" - content = line:sub(#M.config.chat_user_prefix + 1) - elseif line:sub(1, #old_default_user_prefix) == old_default_user_prefix then - table.insert(messages, { role = role, content = content }) - role = "user" - content = line:sub(#old_default_user_prefix + 1) - elseif line:sub(1, #agent_prefix) == agent_prefix then - table.insert(messages, { role = role, content = content }) - role = "assistant" - content = "" - elseif role ~= "" then - content = content .. "\n" .. line + + -- Check if the line starts with
or
+ if line:match("^%s*
") then + in_cot_block = true + end + + -- Skip lines if we're inside a CoT block + if not in_cot_block then + -- Original logic for handling chat messages + if line:sub(1, #M.config.chat_user_prefix) == M.config.chat_user_prefix then + table.insert(messages, { role = role, content = content }) + role = "user" + content = line:sub(#M.config.chat_user_prefix + 1) + elseif line:sub(1, #old_default_user_prefix) == old_default_user_prefix then + table.insert(messages, { role = role, content = content }) + role = "user" + content = line:sub(#old_default_user_prefix + 1) + elseif line:sub(1, #agent_prefix) == agent_prefix then + table.insert(messages, { role = role, content = content }) + role = "assistant" + content = "" + elseif role ~= "" then + content = content .. "\n" .. line + end + end + + if line:match("^%s*
") then + in_cot_block = false end end -- insert last message not handled in loop @@ -1074,12 +1090,21 @@ M.chat_respond = function(params) local last_content_line = M.helpers.last_content_line(buf) vim.api.nvim_buf_set_lines(buf, last_content_line, last_content_line, false, { "", agent_prefix .. agent_suffix, "" }) + local offset = 0 + -- Add CoT for DeepSeek-Reasoner + if agent_name == "DeepSeekReasoner" then + vim.api.nvim_buf_set_lines(buf, last_content_line + 3, last_content_line + 3, false, + { "
", "CoT", "" }) + offset = 1 + end + -- call the model and write response M.dispatcher.query( buf, headers.provider or agent.provider, M.dispatcher.prepare_payload(messages, headers.model or agent.model, headers.provider or agent.provider), - M.dispatcher.create_handler(buf, win, M.helpers.last_content_line(buf), true, "", not M.config.chat_free_cursor), + M.dispatcher.create_handler(buf, win, M.helpers.last_content_line(buf) + offset, true, "", + not M.config.chat_free_cursor), vim.schedule_wrap(function(qid) local qt = M.tasker.get_query(qid) if not qt then @@ -1125,7 +1150,8 @@ M.chat_respond = function(params) topic_handler, vim.schedule_wrap(function() -- get topic from invisible buffer - local topic = vim.api.nvim_buf_get_lines(topic_buf, 0, -1, false)[1] + -- instead of the first line, get the last two line can skip CoT + local topic = vim.api.nvim_buf_get_lines(topic_buf, -3, -1, false)[1] -- close invisible buffer vim.api.nvim_buf_delete(topic_buf, { force = true }) -- strip whitespace from ends of topic