From 4c8cf18d02c9c053220831f1914137c7ecf3d763 Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Fri, 26 Jul 2024 22:18:08 +0800 Subject: [PATCH 01/34] [WIP] Completion menu in markdown files Finally got menus to at least show up. --- after/plugin/gp.lua | 7 ++ lua/gp/completion.lua | 145 ++++++++++++++++++++++++++++++++++++++++++ lua/gp/init.lua | 15 ++++- 3 files changed, 164 insertions(+), 3 deletions(-) create mode 100644 after/plugin/gp.lua create mode 100644 lua/gp/completion.lua diff --git a/after/plugin/gp.lua b/after/plugin/gp.lua new file mode 100644 index 00000000..33d64b99 --- /dev/null +++ b/after/plugin/gp.lua @@ -0,0 +1,7 @@ +print("in after/plugin/gp.lua") +local completion = require("gp.completion") + +print(vim.inspect(completion)) + +completion.register_cmd_source() +print("done after/plugin/gp.lua") diff --git a/lua/gp/completion.lua b/lua/gp/completion.lua new file mode 100644 index 00000000..7d3996ad --- /dev/null +++ b/lua/gp/completion.lua @@ -0,0 +1,145 @@ +print("top of gp.completion.lua") + +local source = {} + +source.src_name = "gp_completion" + +source.new = function() + print("source.new called") + return setmetatable({}, { __index = source }) +end + +source.get_trigger_characters = function() + print("in get_trigger_characters...") + return { "@", ":" } + -- return { "@" } +end + +source.get_keyword_pattern = function() + print("in get_keyword_pattern...") + -- return [[@code:[\w:]*]] + -- return [[@([\w-]+)(?::([\w-]+))?]] + -- return [[@file:]] + return [[@(code|file):?]] +end + +source.setup_for_buffer = function(bufnr) + print("in setup_for_buffer") + local config = require("cmp").get_config() + + print("cmp.get_config() returned:") + print(vim.inspect(config)) + + print("cmp_config.set_buffer: " .. config.set_buffer) + config.set_buffer({ + sources = { + { name = source.src_name }, + }, + }, bufnr) +end + +source.setup_autocmd_for_markdown = function() + print("setting up autocmd...") + vim.api.nvim_create_autocmd("BufEnter", { + pattern = { "*.md", "markdown" }, + callback = function() + print("attaching completion source for buffer: " .. vim.api.nvim_get_current_buf()) + + local cmp = require("cmp") + cmp.setup.buffer({ + sources = cmp.config.sources({ + { name = source.src_name }, + }), + }) + end, + }) +end + +source.register_cmd_source = function() + print("registering completion src") + local s = source.new() + print("new instance: ") + print(vim.inspect(s)) + require("cmp").register_source(source.src_name, s) +end + +local function get_project_files() + -- Assuming the cwd is the project root directory for now + local cwd = vim.fn.getcwd() + local handle = vim.loop.fs_scandir(cwd) + + local files = {} + + if handle then + while true do + local name, type = vim.loop.fs_scandir_next(handle) + if not name then + break + end + + if type == "file" then + table.insert(files, { + label = name, + kind = require("cmp").lsp.CompletionItemKind.File, + }) + end + end + end + + return files +end + +source.complete = function(self, request, callback) + print("[complete] Function called") + local input = string.sub(request.context.cursor_before_line, request.offset - 1) + print("[complete] input: '" .. input .. "'") + print("[complete] offset: " .. request.offset) + print("[complete] cursor_before_line: '" .. request.context.cursor_before_line .. "'") + + local items = {} + local isIncomplete = true + + if request.context.cursor_before_line:match("^@file:$") then + print("[complete] @file: case") + items = { + { label = "file1.lua", kind = require("cmp").lsp.CompletionItemKind.File }, + { label = "file2.lua", kind = require("cmp").lsp.CompletionItemKind.File }, + } + elseif input:match("^@code:") then + print("[complete] @code: case") + local parts = vim.split(input, ":", { plain = true }) + if #parts == 1 then + items = { + { label = "filename1.lua", kind = require("cmp").lsp.CompletionItemKind.File }, + { label = "filename2.lua", kind = require("cmp").lsp.CompletionItemKind.File }, + { label = "function1", kind = require("cmp").lsp.CompletionItemKind.Function }, + { label = "function2", kind = require("cmp").lsp.CompletionItemKind.Function }, + } + elseif #parts == 2 then + items = { + { label = "function1", kind = require("cmp").lsp.CompletionItemKind.Function }, + { label = "function2", kind = require("cmp").lsp.CompletionItemKind.Function }, + } + end + elseif input:match("^@") then + print("[complete] @ case") + items = { + { label = "code", kind = require("cmp").lsp.CompletionItemKind.Keyword }, + { label = "file", kind = require("cmp").lsp.CompletionItemKind.Keyword }, + } + else + print("[complete] default case") + isIncomplete = false + end + + local data = { items = items, isIncomplete = isIncomplete } + print("[complete] Callback data:") + print(vim.inspect(data)) + callback(data) + print("[complete] Callback called") +end + +source.setup_autocmd_for_markdown() + +print("end of gp.completion.lua") +return source diff --git a/lua/gp/init.lua b/lua/gp/init.lua index 25598afd..c77b9b2f 100644 --- a/lua/gp/init.lua +++ b/lua/gp/init.lua @@ -2060,6 +2060,7 @@ M.new_chat = function(params, toggle, system_prompt, agent) M.append_selection(params, cbuf, buf) end M._H.feedkeys("G", "xn") + return buf end @@ -2106,6 +2107,7 @@ end ---@param system_prompt string | nil ---@param agent table | nil # obtained from get_command_agent or get_chat_agent M.cmd.ChatToggle = function(params, system_prompt, agent) + print(">> ChatToggle") if M._toggle_close(M._toggle_kind.popup) then return end @@ -2120,18 +2122,25 @@ M.cmd.ChatToggle = function(params, system_prompt, agent) end -- if the range is 2, we want to create a new chat file with the selection + local buf if params.range ~= 2 then -- check if last.md chat file exists and open it local last = M.config.chat_dir .. "/last.md" if vim.fn.filereadable(last) == 1 then -- resolve symlink last = vim.fn.resolve(last) - M.open_buf(last, M.resolve_buf_target(params), M._toggle_kind.chat, true) - return + buf = M.open_buf(last, M.resolve_buf_target(params), M._toggle_kind.chat, true) end + else + buf = M.new_chat(params, true, system_prompt, agent) end - M.new_chat(params, true, system_prompt, agent) + -- Tell nvim-cmp to use our completion source for the new buffer + -- print("In ChatToggle, trying to setup per buffer completion source") + -- local completion = require("gp.completion") + -- completion.setup_for_buffer(buf) + + return buf end M.cmd.ChatPaste = function(params) From ca1d7ed0d07642781c5b150529ab6799215980d9 Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Sat, 27 Jul 2024 17:39:53 +0800 Subject: [PATCH 02/34] Now able to provide completion for `@file:path/to/file` --- lua/gp/completion.lua | 142 ++++++++++++++++++++++++++++++------------ 1 file changed, 102 insertions(+), 40 deletions(-) diff --git a/lua/gp/completion.lua b/lua/gp/completion.lua index 7d3996ad..895b3d22 100644 --- a/lua/gp/completion.lua +++ b/lua/gp/completion.lua @@ -11,16 +11,7 @@ end source.get_trigger_characters = function() print("in get_trigger_characters...") - return { "@", ":" } - -- return { "@" } -end - -source.get_keyword_pattern = function() - print("in get_keyword_pattern...") - -- return [[@code:[\w:]*]] - -- return [[@([\w-]+)(?::([\w-]+))?]] - -- return [[@file:]] - return [[@(code|file):?]] + return { "@", ":", "/" } end source.setup_for_buffer = function(bufnr) @@ -57,54 +48,124 @@ end source.register_cmd_source = function() print("registering completion src") - local s = source.new() - print("new instance: ") - print(vim.inspect(s)) - require("cmp").register_source(source.src_name, s) + require("cmp").register_source(source.src_name, source.new()) +end + +local function extract_cmd(request) + local target = request.context.cursor_before_line + local start = target:match(".*()@") + if start then + return string.sub(target, start, request.offset) + end +end + +local function cmd_split(cmd) + return vim.split(cmd, ":", { plain = true }) +end + +local function path_split(path) + return vim.split(path, "/") +end + +local function path_join(...) + local args = { ... } + local parts = {} + + for i, part in ipairs(args) do + if type(part) ~= "string" then + error("Argument #" .. i .. " is not a string", 2) + end + + -- Remove leading/trailing separators (both / and \) + part = part:gsub("^[/\\]+", ""):gsub("[/\\]+$", "") + + if #part > 0 then + table.insert(parts, part) + end + end + + local result = table.concat(parts, "/") + + if args[1]:match("^[/\\]") then + result = "/" .. result + end + + return result end -local function get_project_files() - -- Assuming the cwd is the project root directory for now +local function completion_items_for_path(path) + local cmp = require("cmp") + + -- The incoming path should either be + -- - A relative path that references a directory + -- - A relative path + partial filename as last component- + -- We need a bit of logic to figure out which directory content to return + + -------------------------------------------------------------------- + -- Figure out the full path of the directory we're trying to list -- + -------------------------------------------------------------------- + -- Split the path into component parts + local path_parts = path_split(path) + if path[#path] ~= "/" then + table.remove(path_parts) + end + + -- Assuming the cwd is the project root directory... local cwd = vim.fn.getcwd() - local handle = vim.loop.fs_scandir(cwd) + local target_dir = path_join(cwd, unpack(path_parts)) + -------------------------------------------- + -- List the items in the target directory -- + -------------------------------------------- + local handle = vim.loop.fs_scandir(target_dir) local files = {} - if handle then - while true do - local name, type = vim.loop.fs_scandir_next(handle) - if not name then - break - end - - if type == "file" then - table.insert(files, { - label = name, - kind = require("cmp").lsp.CompletionItemKind.File, - }) - end + if not handle then + return files + end + + while true do + local name, type = vim.loop.fs_scandir_next(handle) + if not name then + break end + + local item_name, item_kind + if type == "file" then + item_kind = cmp.lsp.CompletionItemKind.File + item_name = name + elseif type == "directory" then + item_kind = cmp.lsp.CompletionItemKind.Folder + item_name = name .. "/" + end + + table.insert(files, { + label = item_name, + kind = item_kind, + }) end return files end source.complete = function(self, request, callback) - print("[complete] Function called") local input = string.sub(request.context.cursor_before_line, request.offset - 1) - print("[complete] input: '" .. input .. "'") - print("[complete] offset: " .. request.offset) - print("[complete] cursor_before_line: '" .. request.context.cursor_before_line .. "'") + local cmd = extract_cmd(request) + local cmd_parts = cmd_split(cmd) local items = {} local isIncomplete = true - if request.context.cursor_before_line:match("^@file:$") then - print("[complete] @file: case") - items = { - { label = "file1.lua", kind = require("cmp").lsp.CompletionItemKind.File }, - { label = "file2.lua", kind = require("cmp").lsp.CompletionItemKind.File }, - } + if cmd_parts[1]:match("@file") then + -- What's the path we're trying to provide completion for? + local path = cmd_parts[2] + + -- List the items in the specified directory + items = completion_items_for_path(path) + + -- Say that the entire list has been provided + -- cmp won't call us again to provide an updated list + isIncomplete = false elseif input:match("^@code:") then print("[complete] @code: case") local parts = vim.split(input, ":", { plain = true }) @@ -127,6 +188,7 @@ source.complete = function(self, request, callback) { label = "code", kind = require("cmp").lsp.CompletionItemKind.Keyword }, { label = "file", kind = require("cmp").lsp.CompletionItemKind.Keyword }, } + isIncomplete = false else print("[complete] default case") isIncomplete = false From d97691f08804d153a66cfa63210c347fe0baa5ad Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Sat, 27 Jul 2024 21:34:37 +0800 Subject: [PATCH 03/34] Don't try to configure sources multiple times --- lua/gp/completion.lua | 35 +++++++++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/lua/gp/completion.lua b/lua/gp/completion.lua index 895b3d22..cd9a5316 100644 --- a/lua/gp/completion.lua +++ b/lua/gp/completion.lua @@ -1,5 +1,20 @@ print("top of gp.completion.lua") +-- Gets a buffer variable or returns the default +local function buf_get_var(buf, var_name, default) + local status, result = pcall(vim.api.nvim_buf_get_var, buf, var_name) + if status then + return result + else + return default + end +end + +-- This function is only here make the get/set call pair look consistent +local function buf_set_var(buf, var_name, value) + return vim.api.nvim_buf_set_var(buf, var_name, value) +end + local source = {} source.src_name = "gp_completion" @@ -10,7 +25,6 @@ source.new = function() end source.get_trigger_characters = function() - print("in get_trigger_characters...") return { "@", ":", "/" } end @@ -33,15 +47,22 @@ source.setup_autocmd_for_markdown = function() print("setting up autocmd...") vim.api.nvim_create_autocmd("BufEnter", { pattern = { "*.md", "markdown" }, - callback = function() - print("attaching completion source for buffer: " .. vim.api.nvim_get_current_buf()) - + callback = function(arg) + local attached_varname = "gp_source_attached" + local attached = buf_get_var(arg.buf, attached_varname, false) + if attached then + return + end + + print("attaching completion source for buffer: " .. arg.buf) local cmp = require("cmp") cmp.setup.buffer({ sources = cmp.config.sources({ { name = source.src_name }, }), }) + + buf_set_var(arg.buf, attached_varname, true) end, }) end @@ -150,7 +171,13 @@ end source.complete = function(self, request, callback) local input = string.sub(request.context.cursor_before_line, request.offset - 1) + print("[comp] input: '" .. input .. "'") local cmd = extract_cmd(request) + if not cmd then + return + end + + print("[comp] cmd: '" .. cmd .. "'") local cmd_parts = cmd_split(cmd) local items = {} From d384fc752e7281193c4d4e01344a2e5bc627c40a Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Sat, 27 Jul 2024 22:27:50 +0800 Subject: [PATCH 04/34] Parse out context request and insert into user message transparently --- lua/gp/completion.lua | 44 ++++------------------------ lua/gp/context.lua | 67 +++++++++++++++++++++++++++++++++++++++++++ lua/gp/init.lua | 5 ++++ lua/gp/utils.lua | 33 +++++++++++++++++++++ 4 files changed, 110 insertions(+), 39 deletions(-) create mode 100644 lua/gp/context.lua create mode 100644 lua/gp/utils.lua diff --git a/lua/gp/completion.lua b/lua/gp/completion.lua index cd9a5316..42296862 100644 --- a/lua/gp/completion.lua +++ b/lua/gp/completion.lua @@ -1,4 +1,5 @@ -print("top of gp.completion.lua") +local u = require("gp.utils") +local context = require("gp.context") -- Gets a buffer variable or returns the default local function buf_get_var(buf, var_name, default) @@ -80,40 +81,6 @@ local function extract_cmd(request) end end -local function cmd_split(cmd) - return vim.split(cmd, ":", { plain = true }) -end - -local function path_split(path) - return vim.split(path, "/") -end - -local function path_join(...) - local args = { ... } - local parts = {} - - for i, part in ipairs(args) do - if type(part) ~= "string" then - error("Argument #" .. i .. " is not a string", 2) - end - - -- Remove leading/trailing separators (both / and \) - part = part:gsub("^[/\\]+", ""):gsub("[/\\]+$", "") - - if #part > 0 then - table.insert(parts, part) - end - end - - local result = table.concat(parts, "/") - - if args[1]:match("^[/\\]") then - result = "/" .. result - end - - return result -end - local function completion_items_for_path(path) local cmp = require("cmp") @@ -126,14 +93,14 @@ local function completion_items_for_path(path) -- Figure out the full path of the directory we're trying to list -- -------------------------------------------------------------------- -- Split the path into component parts - local path_parts = path_split(path) + local path_parts = u.path_split(path) if path[#path] ~= "/" then table.remove(path_parts) end -- Assuming the cwd is the project root directory... local cwd = vim.fn.getcwd() - local target_dir = path_join(cwd, unpack(path_parts)) + local target_dir = u.path_join(cwd, unpack(path_parts)) -------------------------------------------- -- List the items in the target directory -- @@ -178,7 +145,7 @@ source.complete = function(self, request, callback) end print("[comp] cmd: '" .. cmd .. "'") - local cmd_parts = cmd_split(cmd) + local cmd_parts = context.cmd_split(cmd) local items = {} local isIncomplete = true @@ -230,5 +197,4 @@ end source.setup_autocmd_for_markdown() -print("end of gp.completion.lua") return source diff --git a/lua/gp/context.lua b/lua/gp/context.lua new file mode 100644 index 00000000..109c0fac --- /dev/null +++ b/lua/gp/context.lua @@ -0,0 +1,67 @@ +local u = require("gp.utils") +local gp = require("gp") +local M = {} + +-- Split a context insertion command into its component parts +function M.cmd_split(cmd) + return vim.split(cmd, ":", { plain = true }) +end + +local function read_file(filepath) + local file = io.open(filepath, "r") + if not file then + return nil + end + local content = file:read("*all") + file:close() + return content +end + +-- Given a single message, parse out all the context insertion +-- commands, then return a new message with all the requested +-- context inserted +function M.insert_contexts(msg) + local context_texts = {} + + -- Parse out all context insertion commands + local cmds = {} + for cmd in msg:gmatch("@file:[%w%p]+") do + table.insert(cmds, cmd) + end + + -- Process each command and turn it into a string be + -- inserted as additional context + for _, cmd in ipairs(cmds) do + local cmd_parts = M.cmd_split(cmd) + + if cmd_parts[1] == "@file" then + -- Read the reqested file and produce a msg snippet to be joined later + local filepath = cmd_parts[2] + + local cwd = vim.fn.getcwd() + local fullpath = u.path_join(cwd, filepath) + + local content = read_file(fullpath) + if content then + local result = gp._H.template_render("filepath\n```content```", { + filepath = filepath, + content = content, + }) + table.insert(context_texts, result) + end + end + end + + -- If no context insertions are requested, don't alter the original msg + if #context_texts == 0 then + return msg + else + -- Otherwise, build and return the final message + return gp._H.template_render("context\n\nmsg", { + context = table.concat(context_texts, "\n"), + msg = msg, + }) + end +end + +return M diff --git a/lua/gp/init.lua b/lua/gp/init.lua index c77b9b2f..df67b75f 100644 --- a/lua/gp/init.lua +++ b/lua/gp/init.lua @@ -2358,6 +2358,11 @@ M.chat_respond = function(params) local last_content_line = M._H.last_content_line(buf) vim.api.nvim_buf_set_lines(buf, last_content_line, last_content_line, false, { "", agent_prefix .. agent_suffix, "" }) + -- insert requested context in the message the user just entered + local context = require("gp.context") + messages[#messages].content = context.insert_contexts(messages[#messages].content) + print(vim.inspect(messages[#messages])) + -- call the model and write response M.query( buf, diff --git a/lua/gp/utils.lua b/lua/gp/utils.lua new file mode 100644 index 00000000..ce91207a --- /dev/null +++ b/lua/gp/utils.lua @@ -0,0 +1,33 @@ +local M = {} + +function M.path_split(path) + return vim.split(path, "/") +end + +function M.path_join(...) + local args = { ... } + local parts = {} + + for i, part in ipairs(args) do + if type(part) ~= "string" then + error("Argument #" .. i .. " is not a string", 2) + end + + -- Remove leading/trailing separators (both / and \) + part = part:gsub("^[/\\]+", ""):gsub("[/\\]+$", "") + + if #part > 0 then + table.insert(parts, part) + end + end + + local result = table.concat(parts, "/") + + if args[1]:match("^[/\\]") then + result = "/" .. result + end + + return result +end + +return M From 09d3037c3f513bb805d89571aa45470e7b97a023 Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Sat, 27 Jul 2024 23:10:37 +0800 Subject: [PATCH 05/34] Misc cleanup to minimize alteration in init.lua --- lua/gp/init.lua | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/lua/gp/init.lua b/lua/gp/init.lua index df67b75f..cbc54dc4 100644 --- a/lua/gp/init.lua +++ b/lua/gp/init.lua @@ -2060,7 +2060,6 @@ M.new_chat = function(params, toggle, system_prompt, agent) M.append_selection(params, cbuf, buf) end M._H.feedkeys("G", "xn") - return buf end @@ -2107,7 +2106,6 @@ end ---@param system_prompt string | nil ---@param agent table | nil # obtained from get_command_agent or get_chat_agent M.cmd.ChatToggle = function(params, system_prompt, agent) - print(">> ChatToggle") if M._toggle_close(M._toggle_kind.popup) then return end @@ -2135,11 +2133,6 @@ M.cmd.ChatToggle = function(params, system_prompt, agent) buf = M.new_chat(params, true, system_prompt, agent) end - -- Tell nvim-cmp to use our completion source for the new buffer - -- print("In ChatToggle, trying to setup per buffer completion source") - -- local completion = require("gp.completion") - -- completion.setup_for_buffer(buf) - return buf end @@ -2359,8 +2352,7 @@ M.chat_respond = function(params) vim.api.nvim_buf_set_lines(buf, last_content_line, last_content_line, false, { "", agent_prefix .. agent_suffix, "" }) -- insert requested context in the message the user just entered - local context = require("gp.context") - messages[#messages].content = context.insert_contexts(messages[#messages].content) + messages[#messages].content = require("gp.context").insert_contexts(messages[#messages].content) print(vim.inspect(messages[#messages])) -- call the model and write response From 9a1142bf3b9166863340385b13c0289b7dc9318c Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Mon, 29 Jul 2024 19:11:52 +0800 Subject: [PATCH 06/34] Extract function definitions from src using treesitter --- data/ts_queries/lua.scm | 9 +++++ lua/gp/context.lua | 87 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+) create mode 100644 data/ts_queries/lua.scm diff --git a/data/ts_queries/lua.scm b/data/ts_queries/lua.scm new file mode 100644 index 00000000..f2fe8496 --- /dev/null +++ b/data/ts_queries/lua.scm @@ -0,0 +1,9 @@ +(function_declaration + name: (identifier) @name) @body + +(function_declaration + name: (dot_index_expression + field: (identifier) @name)) @body + +(function_declaration + name: (dot_index_expression) @name) @body diff --git a/lua/gp/context.lua b/lua/gp/context.lua index 109c0fac..9e9cc75d 100644 --- a/lua/gp/context.lua +++ b/lua/gp/context.lua @@ -1,5 +1,6 @@ local u = require("gp.utils") local gp = require("gp") +local logger = require("gp.logger") local M = {} -- Split a context insertion command into its component parts @@ -7,6 +8,7 @@ function M.cmd_split(cmd) return vim.split(cmd, ":", { plain = true }) end +---@return string | nil local function read_file(filepath) local file = io.open(filepath, "r") if not file then @@ -64,4 +66,89 @@ function M.insert_contexts(msg) end end +function find_plugin_path(plugin_name) + local paths = vim.api.nvim_list_runtime_paths() + for path in paths do + local components = u.path_split(path) + if components[#components] == plugin_name then + return path + end + end +end + +function M.treesitter_query(src_filepath, query_filepath) + -- Read the source file content + ---WARNING: This is probably not a good idea for very large files + local src_content = read_file(src_filepath) + if not src_content then + logger.error("Unable to load src file: " .. src_filepath) + return nil + end + + -- Read the query file content + local query_content = read_file(query_filepath) + if not query_content then + logger.error("Unable to load query file: " .. query_filepath) + return nil + end + + -- Get the filetype of the source file + local filetype = vim.filetype.match({ filename = src_filepath }) + if not filetype then + logger.error("Unable to determine filetype for: " .. src_filepath) + return nil + end + + -- Check if the treesitter support for the language is available + local ok, err = pcall(vim.treesitter.language.add, filetype) + if not ok then + print("TreeSitter parser for " .. filetype .. " is not installed") + logger.error(err) + return nil + end + + -- Parse the source text + -- local parser = vim.treesitter.get_parser(0, filetype) + local parser = vim.treesitter.get_string_parser(src_content, filetype, {}) + local tree = parser:parse()[1] + local root = tree:root() + + -- Create and run the query + local query = vim.treesitter.query.parse(filetype, query_content) + + -- Grab all the captures + local captures = {} + for id, node, metadata in query:iter_captures(root, src_content, 0, -1) do + local name = query.captures[id] + local start_row, start_col, end_row, end_col = node:range() + table.insert(captures, { + name = name, + node = node, + range = { start_row, start_col, end_row, end_col }, + text = vim.treesitter.get_node_text(node, src_content), + metadata = metadata, + }) + end + + -- Reshape the captures into a structure we'd like to work with + local results = {} + for i = 1, #captures, 2 do + local fn_body = captures[i] + assert(fn_body.name == "body") + local fn_name = captures[i + 1] + assert(fn_name.name == "name") + + table.insert(results, { + file = src_filepath, + type = "function_definition", + name = fn_name.text, + start_line = fn_body.range[1], + end_line = fn_body.range[3], + body = fn_body.text, + }) + end + + return results +end + return M From 8ed36b16bdb4ff9dfa1ebb5dfc2b618ea88b4797 Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Mon, 29 Jul 2024 20:22:28 +0800 Subject: [PATCH 07/34] Implements ts function defintion extraction given source filepath --- data/ts_queries/lua.scm | 15 +++++---- lua/gp/context.lua | 71 +++++++++++++++++++++++++++++++++++++---- 2 files changed, 74 insertions(+), 12 deletions(-) diff --git a/data/ts_queries/lua.scm b/data/ts_queries/lua.scm index f2fe8496..c8858571 100644 --- a/data/ts_queries/lua.scm +++ b/data/ts_queries/lua.scm @@ -1,9 +1,12 @@ +;; Matches global and local function declarations (function_declaration name: (identifier) @name) @body -(function_declaration - name: (dot_index_expression - field: (identifier) @name)) @body - -(function_declaration - name: (dot_index_expression) @name) @body +;; Matches on: +;; M.some_fn = function() end +(assignment_statement + (variable_list + name: (dot_index_expression + field: (identifier) @name)) + (expression_list + value: (function_definition) @body)) diff --git a/lua/gp/context.lua b/lua/gp/context.lua index 9e9cc75d..e65e52f3 100644 --- a/lua/gp/context.lua +++ b/lua/gp/context.lua @@ -19,6 +19,16 @@ local function read_file(filepath) return content end +function file_exists(path) + local file = io.open(path, "r") + if file then + file:close() + return true + else + return false + end +end + -- Given a single message, parse out all the context insertion -- commands, then return a new message with all the requested -- context inserted @@ -66,9 +76,9 @@ function M.insert_contexts(msg) end end -function find_plugin_path(plugin_name) +function M.find_plugin_path(plugin_name) local paths = vim.api.nvim_list_runtime_paths() - for path in paths do + for _, path in ipairs(paths) do local components = u.path_split(path) if components[#components] == plugin_name then return path @@ -76,6 +86,9 @@ function find_plugin_path(plugin_name) end end +-- Runs the supplied query on the supplied source file. +-- Returns all the captures as is. It is up to the caller to +-- know what the expected output is and to reshape the data. function M.treesitter_query(src_filepath, query_filepath) -- Read the source file content ---WARNING: This is probably not a good idea for very large files @@ -130,13 +143,59 @@ function M.treesitter_query(src_filepath, query_filepath) }) end + return captures +end + +function M.treesitter_extract_function_definitions(src_filepath) + -- Make sure we can locate the source file + if not file_exists(src_filepath) then + logger.error("Unable to locate src file: " .. src_filepath) + return nil + end + + -- Get the filetype of the source file + local filetype = vim.filetype.match({ filename = src_filepath }) + if not filetype then + logger.error("Unable to determine filetype for: " .. src_filepath) + return nil + end + + -- We'll use the reported filetype as the name of the language + -- Try to locate a query file we can use to extract function definitions + local plugin_path = M.find_plugin_path("gp.nvim") + if not plugin_path then + logger.error("Unable to locate path for gp.nvim...") + return nil + end + + -- Find the query file that's approprite for the language + local query_filepath = u.path_join(plugin_path, "data/ts_queries/" .. filetype .. ".scm") + if not file_exists(query_filepath) then + logger.error("Unable to find function extraction ts query file: " .. query_filepath) + return nil + end + + -- Run the query + local captures = M.treesitter_query(src_filepath, query_filepath) + if not captures then + return nil + end + -- Reshape the captures into a structure we'd like to work with local results = {} for i = 1, #captures, 2 do - local fn_body = captures[i] - assert(fn_body.name == "body") - local fn_name = captures[i + 1] - assert(fn_name.name == "name") + -- The captures may arrive out of order. + -- We're only expecting the query to contain @name and @body returned + -- Sort out their ordering here. + local caps = { captures[i], captures[i + 1] } + local named_caps = {} + for _, item in ipairs(caps) do + named_caps[item.name] = item + end + local fn_name = named_caps.name + local fn_body = named_caps.body + assert(fn_name) + assert(fn_body) table.insert(results, { file = src_filepath, From 208478d6148fe12710d75c3bbf98a0268fe9c5f4 Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Wed, 31 Jul 2024 15:13:45 +0800 Subject: [PATCH 08/34] [WIP] Trying to get the fnlist into the databae --- data/ts_queries/lua.scm | 3 +- lua/gp/context.lua | 55 ++++++++-- lua/gp/db.lua | 228 ++++++++++++++++++++++++++++++++++++++++ lua/gp/utils.lua | 32 +++++- 4 files changed, 302 insertions(+), 16 deletions(-) create mode 100644 lua/gp/db.lua diff --git a/data/ts_queries/lua.scm b/data/ts_queries/lua.scm index c8858571..c930c2c0 100644 --- a/data/ts_queries/lua.scm +++ b/data/ts_queries/lua.scm @@ -6,7 +6,6 @@ ;; M.some_fn = function() end (assignment_statement (variable_list - name: (dot_index_expression - field: (identifier) @name)) + name: (dot_index_expression) @name) (expression_list value: (function_definition) @body)) diff --git a/lua/gp/context.lua b/lua/gp/context.lua index e65e52f3..384f7077 100644 --- a/lua/gp/context.lua +++ b/lua/gp/context.lua @@ -1,10 +1,14 @@ local u = require("gp.utils") local gp = require("gp") local logger = require("gp.logger") -local M = {} + +---@type Db +local Db = require("gp.db") + +local Context = {} -- Split a context insertion command into its component parts -function M.cmd_split(cmd) +function Context.cmd_split(cmd) return vim.split(cmd, ":", { plain = true }) end @@ -32,7 +36,7 @@ end -- Given a single message, parse out all the context insertion -- commands, then return a new message with all the requested -- context inserted -function M.insert_contexts(msg) +function Context.insert_contexts(msg) local context_texts = {} -- Parse out all context insertion commands @@ -44,7 +48,7 @@ function M.insert_contexts(msg) -- Process each command and turn it into a string be -- inserted as additional context for _, cmd in ipairs(cmds) do - local cmd_parts = M.cmd_split(cmd) + local cmd_parts = Context.cmd_split(cmd) if cmd_parts[1] == "@file" then -- Read the reqested file and produce a msg snippet to be joined later @@ -76,7 +80,7 @@ function M.insert_contexts(msg) end end -function M.find_plugin_path(plugin_name) +function Context.find_plugin_path(plugin_name) local paths = vim.api.nvim_list_runtime_paths() for _, path in ipairs(paths) do local components = u.path_split(path) @@ -89,7 +93,7 @@ end -- Runs the supplied query on the supplied source file. -- Returns all the captures as is. It is up to the caller to -- know what the expected output is and to reshape the data. -function M.treesitter_query(src_filepath, query_filepath) +function Context.treesitter_query(src_filepath, query_filepath) -- Read the source file content ---WARNING: This is probably not a good idea for very large files local src_content = read_file(src_filepath) @@ -146,7 +150,7 @@ function M.treesitter_query(src_filepath, query_filepath) return captures end -function M.treesitter_extract_function_definitions(src_filepath) +function Context.treesitter_extract_function_defs(src_filepath) -- Make sure we can locate the source file if not file_exists(src_filepath) then logger.error("Unable to locate src file: " .. src_filepath) @@ -162,7 +166,7 @@ function M.treesitter_extract_function_definitions(src_filepath) -- We'll use the reported filetype as the name of the language -- Try to locate a query file we can use to extract function definitions - local plugin_path = M.find_plugin_path("gp.nvim") + local plugin_path = Context.find_plugin_path("gp.nvim") if not plugin_path then logger.error("Unable to locate path for gp.nvim...") return nil @@ -176,7 +180,7 @@ function M.treesitter_extract_function_definitions(src_filepath) end -- Run the query - local captures = M.treesitter_query(src_filepath, query_filepath) + local captures = Context.treesitter_query(src_filepath, query_filepath) if not captures then return nil end @@ -210,4 +214,35 @@ function M.treesitter_extract_function_definitions(src_filepath) return results end -return M +---@param db Db +---@param src_filepath string +function Context.build_function_def_index_for_file(db, src_filepath) + print("building fn list") + -- try to retrieve function definitions from the file + local fnlist = Context.treesitter_extract_function_defs(src_filepath) + if not fnlist then + return false + end + print("done building fn list") + + -- Grab the src file meta data + local src_file_entry = db.collect_src_file_data(src_filepath) + assert(src_file_entry) + src_file_entry.last_scan_time = os.time() + print("collected file info") + + -- Update the src file entry and the function definitions in a single transaction + -- TODO: Remove stale entries? + local result = db:with_transaction(function() + print("updating src file entry") + local success = db:upsert_src_file(src_file_entry) + if not success then + return false + end + print("upserting fn list") + return db:upsert_fnlist(fnlist) + end) + return result +end + +return Context diff --git a/lua/gp/db.lua b/lua/gp/db.lua new file mode 100644 index 00000000..506e405c --- /dev/null +++ b/lua/gp/db.lua @@ -0,0 +1,228 @@ +local sqlite = require("sqlite.db") + +local sqlite_clib = require("sqlite.defs") +local gp = require("gp") +local u = require("gp.utils") +local logger = require("gp.logger") + +-- Describes files we've scanned previously to produce the list of function definitions +---@class SrcFileEntry +---@field id number: unique id +---@field filename string: path relative to the git/project root +---@field file_size number: -- zie of the file at last scan in bytes +---@field filetype string: filetype as reported by neovim at last scan +---@field mod_time number: last file modification time reported by the os at last scan +---@field last_scan_time number: unix time stamp indicating when the last scan of this file was made + +-- Describes where each of the functions are in the project +---@class FunctionDefEntry +---@field id number: unique id +---@field name string: Name of the function +---@field file string: In which file is the function defined? +---@field start_line number: Which line in the file does the definition start? +---@field end_line number: Which line in the file does the definition end? + +---@class Db +---@field db sqlite_db +local Db = {} + +Db._new = function(db) + return setmetatable({ db = db }, { __index = Db }) +end + +--- Opens and/or creates a SQLite database for storing function definitions. +-- @return Db|nil A new Db object if successful, nil if an error occurs +-- @side-effect Creates .gp directory and database file if they don't exist +-- @side-effect Logs errors if unable to locate project root or create directory +function Db.open() + local git_root = gp._H.find_git_root() + if git_root == "" then + logger.error("[db.open] Unable to locate project root") + return nil + end + + local db_file = u.path_join(git_root, ".gp/function_defs.sqlite") + if not u.ensure_parent_path_exists(db_file) then + logger.error("[db.open] Unable create directory for db file: " .. db_file) + end + + local db = sqlite({ + uri = db_file, + + -- The `src_files` table stores a list of known src files and the last time they were scanned + src_files = { + id = true, + filename = { type = "text", required = true }, -- relative to the git/project root + file_size = { type = "integer", required = true }, -- size of the file at last scan + filetype = { type = "text", required = true }, -- filetype as reported by neovim at last scan + mod_time = { type = "integer", required = true }, -- file mod time reported by the fs at last scan + last_scan_time = { type = "integer", required = true }, -- unix timestamp + }, + + -- The `function_defs` table stores all known functions were we able to extract out of the src files + function_defs = { + id = true, + + -- We're keeping this a text field for now to avoid having to deal with joins to grab the filename. + -- This will also cause the file definition entries to be removed when the cooresponding src_file + -- entry is removed. + file = { + type = "text", + reference = "src_files.filename", + on_delete = "cascade", + required = true, + }, + name = { type = "text", required = true }, -- name of the function + start_line = { type = "integer", required = true }, -- Where the fn def starts + end_line = { type = "integer", required = true }, -- Where the fn def ends + }, + opts = { keep_open = true }, + }) + + db:eval("CREATE UNIQUE INDEX IF NOT EXISTS idx_src_files_filename ON src_files (filename);") + + return Db._new(db) +end + +--- Gathers information on a file to populate most of a SrcFileEntry. +--- @return SrcFileEntry|nil +function Db.collect_src_file_data(relative_path) + local uv = vim.uv or vim.loop + + -- Construct the full path to the file + local proj_root = gp._H.find_git_root() + local fullpath = u.path_join(proj_root, relative_path) + + -- If the file doesn't exist, there is nothing to collect + local stat = uv.fs_stat(fullpath) + if not stat then + return nil + end + + local entry = {} + + entry.filename = relative_path + entry.file_size = stat.size + entry.filetype = vim.filetype.match({ filename = fullpath }) + entry.mod_time = stat.mtime.sec + + return entry +end + +-- Upserts a single src file entry into the database +--- @param file SrcFileEntry +function Db:upsert_src_file(file) + if not self.db then + logger.error("[db.upsert_src_file] Database not initialized") + return false + end + + local sql = [[ + INSERT INTO src_files (filename, file_size, filetype, mod_time, last_scan_time) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT(filename) DO UPDATE SET + file_size = excluded.file_size, + filetype = excluded.filetype, + mod_time = excluded.mod_time, + last_scan_time = excluded.last_scan_time + WHERE filename = ? + ]] + + local success = self.db:eval(sql, { + -- For the INSERT VALUES clause + file.filename, + file.file_size, + file.filetype, + file.mod_time, + file.last_scan_time, + + -- For the WHERE claue + file.filename, + }) + + if not success then + logger.error("[db.upsert_src_file] Failed to upsert file: " .. file.filename) + return false + end + + return true +end + +--- @param filelist SrcFileEntry[] +function Db:upsert_filelist(filelist) + for _, file in ipairs(filelist) do + local success = self:upsert_src_file(file) + if not success then + logger.error("[db.upsert_filelist] Failed to upsert file list") + return false + end + end + + return true +end + +-- Upserts a single function def entry into the database +--- @param def FunctionDefEntry +function Db:upsert_function_def(def) + print("[upsert fn def] 1") + if not self.db then + logger.error("[db.upsert_function_def] Database not initialized") + return false + end + + print("[upsert fn def] 2") + local sql = [[ + INSERT INTO function_defs (file, name, start_line, end_line) + VALUES (?, ?, ?, ?) + ON CONFLICT(file, name, start_line) DO UPDATE SET + start_line = excluded.start_line, + end_line = excluded.end_line + WHERE file = ? AND name = ? + ]] + + print("[upsert fn def] 3") + local success = self.db:eval(sql, { + -- For the INSERT VALUES clause + def.file, + def.name, + def.start_line, + def.end_line, + + -- For the WHERE clause + def.file, + def.name, + }) + + print("[upsert fn def] 4") + if not success then + logger.error("[db.upsert_function_def] Failed to upsert function: " .. def.name .. " for file: " .. def.file) + return false + end + + return true +end + +-- Wraps the given function in a sqlite transaction +---@param fn function() +function Db:with_transaction(fn) + return sqlite_clib.wrap_stmts(self.db.conn, fn) +end + +--- Updates the dastabase with the contents of the `fnlist` +--- Note that this function early terminates of any of the entry upsert fails. +--- This behavior is only suitable when run inside a transaction. +--- @param fnlist FunctionDefEntry[] +function Db:upsert_fnlist(fnlist) + for _, def in ipairs(fnlist) do + print(vim.inspect(def)) + local success = self:upsert_function_def(def) + if not success then + logger.error("[db.upsert_fnlist] Failed to upsert function def list") + return false + end + end + + return true +end + +return Db diff --git a/lua/gp/utils.lua b/lua/gp/utils.lua index ce91207a..fccbc3ff 100644 --- a/lua/gp/utils.lua +++ b/lua/gp/utils.lua @@ -1,10 +1,12 @@ -local M = {} +local uv = vim.uv or vim.loop -function M.path_split(path) +local Utils = {} + +function Utils.path_split(path) return vim.split(path, "/") end -function M.path_join(...) +function Utils.path_join(...) local args = { ... } local parts = {} @@ -30,4 +32,26 @@ function M.path_join(...) return result end -return M +function Utils.ensure_path_exists(path) + -- Check if the path exists + local stat = uv.fs_stat(path) + if stat and stat.type == "directory" then + -- The path exists and is a directory + return true + end + + -- Try to create the directory + return vim.fn.mkdir(path, "p") +end + +function Utils.ensure_parent_path_exists(path) + local components = Utils.path_split(path) + + -- Get the parent directory by removing the last component + table.remove(components) + local parent_path = table.concat(components, "/") + + return Utils.ensure_path_exists(parent_path) +end + +return Utils From 75f74962a0bd10080694ca7d996b4ab6952ca8e0 Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Wed, 31 Jul 2024 16:09:55 +0800 Subject: [PATCH 09/34] Able to insert all fn defs for a single src file --- lua/gp/db.lua | 35 +++++++++++++++++------------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/lua/gp/db.lua b/lua/gp/db.lua index 506e405c..31884029 100644 --- a/lua/gp/db.lua +++ b/lua/gp/db.lua @@ -59,26 +59,25 @@ function Db.open() last_scan_time = { type = "integer", required = true }, -- unix timestamp }, - -- The `function_defs` table stores all known functions were we able to extract out of the src files - function_defs = { - id = true, - - -- We're keeping this a text field for now to avoid having to deal with joins to grab the filename. - -- This will also cause the file definition entries to be removed when the cooresponding src_file - -- entry is removed. - file = { - type = "text", - reference = "src_files.filename", - on_delete = "cascade", - required = true, - }, - name = { type = "text", required = true }, -- name of the function - start_line = { type = "integer", required = true }, -- Where the fn def starts - end_line = { type = "integer", required = true }, -- Where the fn def ends - }, opts = { keep_open = true }, }) + db:eval("PRAGMA foreign_keys = ON;") + + -- sqlite.lua doesn't seem to support adding random table options + -- In this case, being able to perform an upsert in the function_defs table depends on + -- having UNIQUE file and fn name pair. + db:eval([[ + CREATE TABLE IF NOT EXISTS function_defs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + file TEXT NOT NULL REFERENCES src_files(filename) on DELETE CASCADE, + name TEXT NOT NULL, + start_line INTEGER NOT NULL, + end_line INTEGER NOT NULL, + UNIQUE (file, name) + ); + ]]) + db:eval("CREATE UNIQUE INDEX IF NOT EXISTS idx_src_files_filename ON src_files (filename);") return Db._new(db) @@ -174,7 +173,7 @@ function Db:upsert_function_def(def) local sql = [[ INSERT INTO function_defs (file, name, start_line, end_line) VALUES (?, ?, ?, ?) - ON CONFLICT(file, name, start_line) DO UPDATE SET + ON CONFLICT(file, name) DO UPDATE SET start_line = excluded.start_line, end_line = excluded.end_line WHERE file = ? AND name = ? From df84cbf2d2a1aa4746266bd8cc65a1ad3e9f835e Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Wed, 31 Jul 2024 20:17:22 +0800 Subject: [PATCH 10/34] Implements project-wide indexing --- lua/gp/context.lua | 61 +++++++++++++++++++++++++++++++++++++++++++++- lua/gp/db.lua | 22 ++++++++++++----- lua/gp/utils.lua | 12 +++++++++ 3 files changed, 88 insertions(+), 7 deletions(-) diff --git a/lua/gp/context.lua b/lua/gp/context.lua index 384f7077..c5a528bf 100644 --- a/lua/gp/context.lua +++ b/lua/gp/context.lua @@ -227,7 +227,10 @@ function Context.build_function_def_index_for_file(db, src_filepath) -- Grab the src file meta data local src_file_entry = db.collect_src_file_data(src_filepath) - assert(src_file_entry) + if not src_file_entry then + logger.error("Unable to collect src file data for:" .. src_filepath) + return false + end src_file_entry.last_scan_time = os.time() print("collected file info") @@ -245,4 +248,60 @@ function Context.build_function_def_index_for_file(db, src_filepath) return result end +function Context.build_function_def_index(db) + local git_root = gp._H.find_git_root() + if not git_root then + logger.error("[Context.build_function_def_index] Unable to locate project root") + return false + end + local git_root_len = #git_root + 2 + + local function scan_directory(dir) + local entries = vim.fn.readdir(dir) + + for _, entry in ipairs(entries) do + local full_path = u.path_join(dir, entry) + local rel_path = full_path:sub(git_root_len) + -- Ignore hidden files and directories + if u.string_starts_with(entry, ".") or u.string_ends_with(entry, ".txt") or u.string_ends_with(entry, ".md") then + goto continue + end + + if vim.fn.isdirectory(full_path) == 1 then + if entry == "node_modules" then + goto continue + end + scan_directory(full_path) + else + -- Only process files with recognized filetypes + if vim.filetype.match({ filename = full_path }) then + local success = Context.build_function_def_index_for_file(db, rel_path) + if not success then + logger.warning("Failed to build function def index for: " .. rel_path) + end + end + end + ::continue:: + end + end + + scan_directory(git_root) +end + +function Context.index_all() + local uv = vim.uv or vim.loop + local start_time = uv.hrtime() + + local db = Db.open() + if not db then + return + end + Context.build_function_def_index(db) + db:close() + + local end_time = uv.hrtime() + local elapsed_time_ms = (end_time - start_time) / 1e6 + logger.info(string.format("[Gp] Indexing took: %.2f ms", elapsed_time_ms)) +end + return Context diff --git a/lua/gp/db.lua b/lua/gp/db.lua index 31884029..4d8d4ccb 100644 --- a/lua/gp/db.lua +++ b/lua/gp/db.lua @@ -26,6 +26,7 @@ local logger = require("gp.logger") ---@field db sqlite_db local Db = {} +--- @return Db Db._new = function(db) return setmetatable({ db = db }, { __index = Db }) end @@ -44,6 +45,7 @@ function Db.open() local db_file = u.path_join(git_root, ".gp/function_defs.sqlite") if not u.ensure_parent_path_exists(db_file) then logger.error("[db.open] Unable create directory for db file: " .. db_file) + return nil end local db = sqlite({ @@ -95,6 +97,7 @@ function Db.collect_src_file_data(relative_path) -- If the file doesn't exist, there is nothing to collect local stat = uv.fs_stat(fullpath) if not stat then + logger.error("[Db.collection_src_file_data] failed: " .. relative_path) return nil end @@ -163,13 +166,11 @@ end -- Upserts a single function def entry into the database --- @param def FunctionDefEntry function Db:upsert_function_def(def) - print("[upsert fn def] 1") if not self.db then logger.error("[db.upsert_function_def] Database not initialized") return false end - print("[upsert fn def] 2") local sql = [[ INSERT INTO function_defs (file, name, start_line, end_line) VALUES (?, ?, ?, ?) @@ -179,7 +180,6 @@ function Db:upsert_function_def(def) WHERE file = ? AND name = ? ]] - print("[upsert fn def] 3") local success = self.db:eval(sql, { -- For the INSERT VALUES clause def.file, @@ -192,7 +192,6 @@ function Db:upsert_function_def(def) def.name, }) - print("[upsert fn def] 4") if not success then logger.error("[db.upsert_function_def] Failed to upsert function: " .. def.name .. " for file: " .. def.file) return false @@ -204,7 +203,15 @@ end -- Wraps the given function in a sqlite transaction ---@param fn function() function Db:with_transaction(fn) - return sqlite_clib.wrap_stmts(self.db.conn, fn) + self.db:execute("BEGIN") + local success, result = pcall(fn) + self.db:execute("END") + + if not success then + logger.error(result) + return false + end + return true end --- Updates the dastabase with the contents of the `fnlist` @@ -213,7 +220,6 @@ end --- @param fnlist FunctionDefEntry[] function Db:upsert_fnlist(fnlist) for _, def in ipairs(fnlist) do - print(vim.inspect(def)) local success = self:upsert_function_def(def) if not success then logger.error("[db.upsert_fnlist] Failed to upsert function def list") @@ -224,4 +230,8 @@ function Db:upsert_fnlist(fnlist) return true end +function Db:close() + self.db:close() +end + return Db diff --git a/lua/gp/utils.lua b/lua/gp/utils.lua index fccbc3ff..cc32298f 100644 --- a/lua/gp/utils.lua +++ b/lua/gp/utils.lua @@ -54,4 +54,16 @@ function Utils.ensure_parent_path_exists(path) return Utils.ensure_path_exists(parent_path) end +function Utils.string_starts_with(str, starting) + return string.sub(str, 1, string.len(starting)) == starting +end + +function Utils.string_ends_with(str, ending) + if #ending > #str then + return false + end + + return str:sub(-#ending) == ending +end + return Utils From 50855a637ac2c268fc574e906a77a5cca00be03b Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Wed, 31 Jul 2024 21:24:31 +0800 Subject: [PATCH 11/34] Set better log msg levels when building index Also applies minor fix to gp.logger to respect log levels when it comes to sending notifications. --- lua/gp/context.lua | 11 +++-------- lua/gp/db.lua | 2 +- lua/gp/logger.lua | 26 ++++++++++++-------------- 3 files changed, 16 insertions(+), 23 deletions(-) diff --git a/lua/gp/context.lua b/lua/gp/context.lua index c5a528bf..0d25c3f5 100644 --- a/lua/gp/context.lua +++ b/lua/gp/context.lua @@ -119,7 +119,7 @@ function Context.treesitter_query(src_filepath, query_filepath) -- Check if the treesitter support for the language is available local ok, err = pcall(vim.treesitter.language.add, filetype) if not ok then - print("TreeSitter parser for " .. filetype .. " is not installed") + logger.error("TreeSitter parser for " .. filetype .. " is not installed") logger.error(err) return nil end @@ -175,7 +175,7 @@ function Context.treesitter_extract_function_defs(src_filepath) -- Find the query file that's approprite for the language local query_filepath = u.path_join(plugin_path, "data/ts_queries/" .. filetype .. ".scm") if not file_exists(query_filepath) then - logger.error("Unable to find function extraction ts query file: " .. query_filepath) + logger.debug("Unable to find function extraction ts query file: " .. query_filepath) return nil end @@ -217,13 +217,11 @@ end ---@param db Db ---@param src_filepath string function Context.build_function_def_index_for_file(db, src_filepath) - print("building fn list") -- try to retrieve function definitions from the file local fnlist = Context.treesitter_extract_function_defs(src_filepath) if not fnlist then return false end - print("done building fn list") -- Grab the src file meta data local src_file_entry = db.collect_src_file_data(src_filepath) @@ -232,17 +230,14 @@ function Context.build_function_def_index_for_file(db, src_filepath) return false end src_file_entry.last_scan_time = os.time() - print("collected file info") -- Update the src file entry and the function definitions in a single transaction -- TODO: Remove stale entries? local result = db:with_transaction(function() - print("updating src file entry") local success = db:upsert_src_file(src_file_entry) if not success then return false end - print("upserting fn list") return db:upsert_fnlist(fnlist) end) return result @@ -277,7 +272,7 @@ function Context.build_function_def_index(db) if vim.filetype.match({ filename = full_path }) then local success = Context.build_function_def_index_for_file(db, rel_path) if not success then - logger.warning("Failed to build function def index for: " .. rel_path) + logger.debug("Failed to build function def index for: " .. rel_path) end end end diff --git a/lua/gp/db.lua b/lua/gp/db.lua index 4d8d4ccb..495edcbe 100644 --- a/lua/gp/db.lua +++ b/lua/gp/db.lua @@ -97,7 +97,7 @@ function Db.collect_src_file_data(relative_path) -- If the file doesn't exist, there is nothing to collect local stat = uv.fs_stat(fullpath) if not stat then - logger.error("[Db.collection_src_file_data] failed: " .. relative_path) + logger.debug("[Db.collection_src_file_data] failed: " .. relative_path) return nil end diff --git a/lua/gp/logger.lua b/lua/gp/logger.lua index 10b7e52f..4a56912d 100644 --- a/lua/gp/logger.lua +++ b/lua/gp/logger.lua @@ -4,6 +4,7 @@ local file = "/dev/null" local uuid = "" M._log_history = {} +M.level = vim.log.levels.INFO ---@param path string # path to log file M.set_log_file = function(path) @@ -26,9 +27,8 @@ end ---@param msg string # message to log ---@param level integer # log level ----@param slevel string # log level as string -local log = function(msg, level, slevel) - local raw = string.format("[%s] [%s] %s: %s", os.date("%Y-%m-%d %H:%M:%S"), uuid, slevel, msg) +local log = function(msg, level) + local raw = string.format("[%s] [%s] %s: %s", os.date("%Y-%m-%d %H:%M:%S"), uuid, vim.lsp.log_levels[level], msg) M._log_history[#M._log_history + 1] = raw if #M._log_history > 100 then @@ -41,38 +41,36 @@ local log = function(msg, level, slevel) log_file:close() end - if level <= vim.log.levels.DEBUG then - return + if level >= M.level then + vim.schedule(function() + vim.notify(msg, level, { title = "gp.nvim" }) + end) end - - vim.schedule(function() - vim.notify(msg, level, { title = "gp.nvim" }) - end) end ---@param msg string # error message M.error = function(msg) - log(msg, vim.log.levels.ERROR, "ERROR") + log(msg, vim.log.levels.ERROR) end ---@param msg string # warning message M.warning = function(msg) - log(msg, vim.log.levels.WARN, "WARNING") + log(msg, vim.log.levels.WARN) end ---@param msg string # plain message M.info = function(msg) - log(msg, vim.log.levels.INFO, "INFO") + log(msg, vim.log.levels.INFO) end ---@param msg string # debug message M.debug = function(msg) - log(msg, vim.log.levels.DEBUG, "DEBUG") + log(msg, vim.log.levels.DEBUG) end ---@param msg string # trace message M.trace = function(msg) - log(msg, vim.log.levels.TRACE, "TRACE") + log(msg, vim.log.levels.TRACE) end return M From 4862407a604c6ea37c295ad9612ad4f575ee2cd5 Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Wed, 31 Jul 2024 22:24:53 +0800 Subject: [PATCH 12/34] Verifies we can at least locate all functions for gp.nvim itself --- data/ts_queries/lua.scm | 45 +++++++++++++++++++++++++++++++++-------- lua/gp/context.lua | 2 +- 2 files changed, 38 insertions(+), 9 deletions(-) diff --git a/data/ts_queries/lua.scm b/data/ts_queries/lua.scm index c930c2c0..33cb8a3f 100644 --- a/data/ts_queries/lua.scm +++ b/data/ts_queries/lua.scm @@ -1,11 +1,40 @@ ;; Matches global and local function declarations -(function_declaration - name: (identifier) @name) @body +;; function a_fn_name() +;; +;; This will only match on top level functions. +;; Specificadlly, this ignores the local function declarations. +;; We're only doing this because we're requiring the +;; (file, function_name) pair to be unique in the database. +(chunk + (function_declaration + name: (identifier) @name) @body) + +;; Matches function declaration using the dot syntax +;; function a_table.a_fn_name() +(chunk + (function_declaration + name: (dot_index_expression) @name) @body) + +;; Matches function declaration using the member function syntax +;; function a_table:a_fn_name() +(chunk + (function_declaration + name: (method_index_expression) @name) @body) + +;; Matches on: +;; M.some_field = function() end +(chunk + (assignment_statement + (variable_list + name: (dot_index_expression) @name) + (expression_list + value: (function_definition) @body))) ;; Matches on: -;; M.some_fn = function() end -(assignment_statement - (variable_list - name: (dot_index_expression) @name) - (expression_list - value: (function_definition) @body)) +;; some_var = function() end +(chunk + (assignment_statement + (variable_list + name: (identifier) @name) + (expression_list + value: (function_definition) @body))) diff --git a/lua/gp/context.lua b/lua/gp/context.lua index 0d25c3f5..3aedba0e 100644 --- a/lua/gp/context.lua +++ b/lua/gp/context.lua @@ -23,7 +23,7 @@ local function read_file(filepath) return content end -function file_exists(path) +local function file_exists(path) local file = io.open(path, "r") if file then file:close() From 76243a2b293ea60c1f2f1d2affbab57a9b8cb2f9 Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Thu, 1 Aug 2024 15:53:28 +0800 Subject: [PATCH 13/34] Implements completion behavior for @code --- lua/gp/completion.lua | 112 ++++++++++++++++++++++++++++++++---------- lua/gp/db.lua | 22 ++++++++- 2 files changed, 106 insertions(+), 28 deletions(-) diff --git a/lua/gp/completion.lua b/lua/gp/completion.lua index 42296862..cc9352cd 100644 --- a/lua/gp/completion.lua +++ b/lua/gp/completion.lua @@ -1,5 +1,7 @@ local u = require("gp.utils") local context = require("gp.context") +local db = require("gp.db") +local cmp = require("cmp") -- Gets a buffer variable or returns the default local function buf_get_var(buf, var_name, default) @@ -16,20 +18,24 @@ local function buf_set_var(buf, var_name, value) return vim.api.nvim_buf_set_var(buf, var_name, value) end +---@class CompletionSource +---@field db Db local source = {} source.src_name = "gp_completion" -source.new = function() +---@return CompletionSource +function source.new() print("source.new called") - return setmetatable({}, { __index = source }) + local db_inst = db.open() + return setmetatable({ db = db_inst }, { __index = source }) end -source.get_trigger_characters = function() +function source.get_trigger_characters() return { "@", ":", "/" } end -source.setup_for_buffer = function(bufnr) +function source.setup_for_buffer(bufnr) print("in setup_for_buffer") local config = require("cmp").get_config() @@ -44,7 +50,7 @@ source.setup_for_buffer = function(bufnr) }, bufnr) end -source.setup_autocmd_for_markdown = function() +function source.setup_autocmd_for_markdown() print("setting up autocmd...") vim.api.nvim_create_autocmd("BufEnter", { pattern = { "*.md", "markdown" }, @@ -68,9 +74,9 @@ source.setup_autocmd_for_markdown = function() }) end -source.register_cmd_source = function() +function source.register_cmd_source() print("registering completion src") - require("cmp").register_source(source.src_name, source.new()) + cmp.register_source(source.src_name, source.new()) end local function extract_cmd(request) @@ -82,8 +88,6 @@ local function extract_cmd(request) end local function completion_items_for_path(path) - local cmp = require("cmp") - -- The incoming path should either be -- - A relative path that references a directory -- - A relative path + partial filename as last component- @@ -136,7 +140,33 @@ local function completion_items_for_path(path) return files end -source.complete = function(self, request, callback) +function source:completion_items_for_fn_name(partial_fn_name) + local result = self.db:find_fn_def_by_name(partial_fn_name) + + local items = {} + if not result then + return items + end + + for _, row in ipairs(result) do + table.insert(items, { + -- fields meant for nvim-cmp + label = row.name, + kind = cmp.lsp.CompletionItemKind.Function, + labelDetails = { + detail = row.file, + }, + + -- fields meant for internal use + row = row, + type = "@code", + }) + end + + return items +end + +function source.complete(self, request, callback) local input = string.sub(request.context.cursor_before_line, request.offset - 1) print("[comp] input: '" .. input .. "'") local cmd = extract_cmd(request) @@ -160,22 +190,21 @@ source.complete = function(self, request, callback) -- Say that the entire list has been provided -- cmp won't call us again to provide an updated list isIncomplete = false - elseif input:match("^@code:") then - print("[complete] @code: case") - local parts = vim.split(input, ":", { plain = true }) - if #parts == 1 then - items = { - { label = "filename1.lua", kind = require("cmp").lsp.CompletionItemKind.File }, - { label = "filename2.lua", kind = require("cmp").lsp.CompletionItemKind.File }, - { label = "function1", kind = require("cmp").lsp.CompletionItemKind.Function }, - { label = "function2", kind = require("cmp").lsp.CompletionItemKind.Function }, - } - elseif #parts == 2 then - items = { - { label = "function1", kind = require("cmp").lsp.CompletionItemKind.Function }, - { label = "function2", kind = require("cmp").lsp.CompletionItemKind.Function }, - } + elseif cmd_parts[1]:match("@code") then + local partial_fn_name = cmd_parts[2] + + -- When the user confirms completion of an item, we alter the + -- command to look like `@code:path/to/file:fn_name` to uniquely + -- identify a function. + -- + -- If the user were to hit backspace to delete through the text, + -- don't process the input until it no longer looks like a path. + if partial_fn_name:match("/") then + return end + + items = self:completion_items_for_fn_name(partial_fn_name) + isIncomplete = false elseif input:match("^@") then print("[complete] @ case") items = { @@ -189,12 +218,41 @@ source.complete = function(self, request, callback) end local data = { items = items, isIncomplete = isIncomplete } - print("[complete] Callback data:") - print(vim.inspect(data)) callback(data) print("[complete] Callback called") end +local function search_backwards(buf, pattern) + -- Use nvim_buf_call to execute a Vim command in the buffer context + return vim.api.nvim_buf_call(buf, function() + -- Search backwards for the pattern + local result = vim.fn.searchpos(pattern, "bn") + + if result[1] == 0 and result[2] == 0 then + return nil + end + return result + end) +end + +function source:execute(item, callback) + if item.type == "@code" then + -- Locate where @command starts and ends + local end_pos = vim.api.nvim_win_get_cursor(0) + local start_pos = search_backwards(0, "@code") + + -- Replace it with a custom piece of text and move the cursor to the end of the string + local text = string.format("@code:%s:%s", item.row.file, item.row.name) + vim.api.nvim_buf_set_text(0, start_pos[1] - 1, start_pos[2] - 1, end_pos[1] - 1, end_pos[2], { text }) + vim.api.nvim_win_set_cursor(0, { start_pos[1], start_pos[2] - 1 + #text }) + end + + -- After brief glance at the nvim-cmp source, it appears + -- we should call `callback` to continue the entry item selection + -- confirmation handling chain. + callback() +end + source.setup_autocmd_for_markdown() return source diff --git a/lua/gp/db.lua b/lua/gp/db.lua index 495edcbe..4a665de3 100644 --- a/lua/gp/db.lua +++ b/lua/gp/db.lua @@ -1,5 +1,4 @@ local sqlite = require("sqlite.db") - local sqlite_clib = require("sqlite.defs") local gp = require("gp") local u = require("gp.utils") @@ -234,4 +233,25 @@ function Db:close() self.db:close() end +function Db:find_fn_def_by_name(partial_fn_name) + local sql = [[ + SELECT * FROM function_defs WHERE name LIKE ? + ]] + + local wildcard_name = "%" .. partial_fn_name .. "%" + + local result = self.db:eval(sql, { + wildcard_name, + }) + + -- We're expecting the query to return a list of FunctionDefEntry. + -- If we get a boolean back instead, we consider the operation to have failed. + if type(result) == "boolean" then + return nil + end + + ---@cast result FunctionDefEntry + return result +end + return Db From 8173dea93f73195fc00ea814956bd10d68664a19 Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Thu, 1 Aug 2024 16:50:18 +0800 Subject: [PATCH 14/34] Implements index rebuilding for a single file --- lua/gp/context.lua | 18 ++++++++++++++++-- lua/gp/db.lua | 22 +++++++++++++++++++++- 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/lua/gp/context.lua b/lua/gp/context.lua index 3aedba0e..4c81bd33 100644 --- a/lua/gp/context.lua +++ b/lua/gp/context.lua @@ -232,8 +232,9 @@ function Context.build_function_def_index_for_file(db, src_filepath) src_file_entry.last_scan_time = os.time() -- Update the src file entry and the function definitions in a single transaction - -- TODO: Remove stale entries? local result = db:with_transaction(function() + db:remove_src_file_entry(src_filepath) + local success = db:upsert_src_file(src_file_entry) if not success then return false @@ -258,7 +259,11 @@ function Context.build_function_def_index(db) local full_path = u.path_join(dir, entry) local rel_path = full_path:sub(git_root_len) -- Ignore hidden files and directories - if u.string_starts_with(entry, ".") or u.string_ends_with(entry, ".txt") or u.string_ends_with(entry, ".md") then + if + u.string_starts_with(entry, ".") + or u.string_ends_with(entry, ".txt") + or u.string_ends_with(entry, ".md") + then goto continue end @@ -283,6 +288,15 @@ function Context.build_function_def_index(db) scan_directory(git_root) end +function Context.index_single_file(src_filepath) + local db = Db.open() + if not db then + return + end + Context.build_function_def_index_for_file(db, src_filepath) + db:close() +end + function Context.index_all() local uv = vim.uv or vim.loop local start_time = uv.hrtime() diff --git a/lua/gp/db.lua b/lua/gp/db.lua index 4a665de3..2fdeef98 100644 --- a/lua/gp/db.lua +++ b/lua/gp/db.lua @@ -70,7 +70,7 @@ function Db.open() -- having UNIQUE file and fn name pair. db:eval([[ CREATE TABLE IF NOT EXISTS function_defs ( - id INTEGER PRIMARY KEY AUTOINCREMENT, + id INTEGER NOT NULL PRIMARY KEY, file TEXT NOT NULL REFERENCES src_files(filename) on DELETE CASCADE, name TEXT NOT NULL, start_line INTEGER NOT NULL, @@ -254,4 +254,24 @@ function Db:find_fn_def_by_name(partial_fn_name) return result end +-- Removes a single entry from the src_files table given a relative file path +-- Note that related entries in the function_defs table will be removed via CASCADE. +---@param src_filepath string +function Db:remove_src_file_entry(src_filepath) + local sql = [[ + DELETE FROM src_files WHERE filename = ? + ]] + + local result = self.db:eval(sql, { + src_filepath, + }) + + return result +end + +function Db:clear() + self.db:eval("DELETE FROM function_defs") + self.db:eval("DELETE FROM src_files") +end + return Db From d6f7c7186783426b1bd6a9db50e4d4c43f5b9e72 Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Thu, 1 Aug 2024 17:47:26 +0800 Subject: [PATCH 15/34] Adds a metadata (KV store) table to the database --- lua/gp/db.lua | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/lua/gp/db.lua b/lua/gp/db.lua index 2fdeef98..13d9f9b6 100644 --- a/lua/gp/db.lua +++ b/lua/gp/db.lua @@ -50,6 +50,13 @@ function Db.open() local db = sqlite({ uri = db_file, + -- The `metadata` table is a simple KV store + metadata = { + id = true, + key = { type = "text", required = true, unique = true }, + value = { type = "luatable", required = true }, + }, + -- The `src_files` table stores a list of known src files and the last time they were scanned src_files = { id = true, @@ -274,4 +281,28 @@ function Db:clear() self.db:eval("DELETE FROM src_files") end +-- Gets the value of a key from the metadata table +---@param keyname string +---@return any +function Db:get_metadata(keyname) + local result = self.db.metadata:where({ key = keyname }) + if result then + return result.value + end +end + +-- Sets the value of a key in the metadata table +-- WARNING: value cannot be of a number type +---@param keyname string +---@param value any +function Db:set_metadata(keyname, value) + -- The sqlite.lua plugin doesn't seem to like having numbers stored in the a field + -- marked as the "luatable" or "json" type. + -- If we store a number into the value field, sqlite.lua will throw a parse error on get. + if type(value) == "number" then + error("database metadata table doesn't not support storing a number as a root value") + end + return self.db.metadata:update({ where = { key = keyname }, set = { value = value } }) +end + return Db From 3217e1c6414b1aa97b5bde2c5387d69140137f2b Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Thu, 1 Aug 2024 23:35:01 +0800 Subject: [PATCH 16/34] Refactors scan_directory into Utils.walk_directory --- lua/gp/context.lua | 47 ++++++++++++++++++----------------------- lua/gp/utils.lua | 52 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 27 deletions(-) diff --git a/lua/gp/context.lua b/lua/gp/context.lua index 4c81bd33..97a5a649 100644 --- a/lua/gp/context.lua +++ b/lua/gp/context.lua @@ -252,40 +252,33 @@ function Context.build_function_def_index(db) end local git_root_len = #git_root + 2 - local function scan_directory(dir) - local entries = vim.fn.readdir(dir) - - for _, entry in ipairs(entries) do - local full_path = u.path_join(dir, entry) - local rel_path = full_path:sub(git_root_len) - -- Ignore hidden files and directories - if - u.string_starts_with(entry, ".") - or u.string_ends_with(entry, ".txt") - or u.string_ends_with(entry, ".md") - then - goto continue + walk_directory(git_root, { + should_process = function(entry, rel_path, full_path, is_dir) + if u.string_starts_with(entry, ".") then + return false end - if vim.fn.isdirectory(full_path) == 1 then + if is_dir then if entry == "node_modules" then - goto continue + return false end - scan_directory(full_path) else - -- Only process files with recognized filetypes - if vim.filetype.match({ filename = full_path }) then - local success = Context.build_function_def_index_for_file(db, rel_path) - if not success then - logger.debug("Failed to build function def index for: " .. rel_path) - end + if u.string_ends_with(entry, ".txt") or u.string_ends_with(entry, ".md") then + return false end end - ::continue:: - end - end - - scan_directory(git_root) + return true + end, + + process_file = function(rel_path, full_path) + if vim.filetype.match({ filename = full_path }) then + local success = Context.build_function_def_index_for_file(db, full_path) + if not success then + logger.debug("Failed to build function def index for: " .. rel_path) + end + end + end, + }) end function Context.index_single_file(src_filepath) diff --git a/lua/gp/utils.lua b/lua/gp/utils.lua index cc32298f..8485f0ad 100644 --- a/lua/gp/utils.lua +++ b/lua/gp/utils.lua @@ -66,4 +66,56 @@ function Utils.string_ends_with(str, ending) return str:sub(-#ending) == ending end +---@class WalkDirectoryOptions +---@field should_process function Passed `entry`, `rel_path`, `full_path`, and `is_dir` +---@field process_file function +---@field on_error function +---@field recurse boolean +---@field max_depth number +--- +---@param dir string The directory to try to walk +---@param options WalkDirectoryOptions Describes how to walk the directory +--- +function Utils.walk_directory(dir, options) + options = options or {} + + local should_process = options.should_process or function() + return true + end + + local process_file = options.process_file or function(rel_path, full_path) + print(full_path) + end + local recurse = not options.recurse + + ---@type number + local max_depth = options.max_depth or math.huge + + local function walk(current_dir, current_depth) + if current_depth > max_depth then + return + end + + local entries = vim.fn.readdir(current_dir) + + for _, entry in ipairs(entries) do + local full_path = Utils.path_join(current_dir, entry) + local rel_path = full_path:sub(#dir + 2) + local is_dir = vim.fn.isdirectory(full_path) == 1 + + if should_process(entry, rel_path, full_path, is_dir) then + if is_dir then + if recurse then + walk(full_path, current_depth + 1) + end + else + pcall(process_file, rel_path, full_path) + end + end + end + end + + walk(dir, 1) +end + return Utils From 05f56afecd40463bcc10c26f3ce134524bd60a88 Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Fri, 2 Aug 2024 10:48:18 +0800 Subject: [PATCH 17/34] Builds initial fn def index on ChatNew or ChatToggle --- lua/gp/context.lua | 36 +++++++++++++++++++++++++++++++++--- lua/gp/db.lua | 4 ++-- lua/gp/init.lua | 13 +++++++++++-- lua/gp/utils.lua | 25 +++++++++++++++++++++++++ 4 files changed, 71 insertions(+), 7 deletions(-) diff --git a/lua/gp/context.lua b/lua/gp/context.lua index 97a5a649..567901cc 100644 --- a/lua/gp/context.lua +++ b/lua/gp/context.lua @@ -93,6 +93,8 @@ end -- Runs the supplied query on the supplied source file. -- Returns all the captures as is. It is up to the caller to -- know what the expected output is and to reshape the data. +---@param src_filepath string relative or full path to the src file to run the query on +---@param query_filepath string relative or full path to the query file to run function Context.treesitter_query(src_filepath, query_filepath) -- Read the source file content ---WARNING: This is probably not a good idea for very large files @@ -245,14 +247,14 @@ function Context.build_function_def_index_for_file(db, src_filepath) end function Context.build_function_def_index(db) - local git_root = gp._H.find_git_root() + local git_root = u.git_root_from_cwd() if not git_root then logger.error("[Context.build_function_def_index] Unable to locate project root") return false end local git_root_len = #git_root + 2 - walk_directory(git_root, { + u.walk_directory(git_root, { should_process = function(entry, rel_path, full_path, is_dir) if u.string_starts_with(entry, ".") then return false @@ -272,7 +274,7 @@ function Context.build_function_def_index(db) process_file = function(rel_path, full_path) if vim.filetype.match({ filename = full_path }) then - local success = Context.build_function_def_index_for_file(db, full_path) + local success = Context.build_function_def_index_for_file(db, rel_path) if not success then logger.debug("Failed to build function def index for: " .. rel_path) end @@ -306,4 +308,32 @@ function Context.index_all() logger.info(string.format("[Gp] Indexing took: %.2f ms", elapsed_time_ms)) end +function Context.build_initial_index() + local db = Db.open() + if not db then + return + end + + if db:get_metadata("done_initial_run") then + return + end + + Context.index_all() + db:set_metadata("done_initial_run", true) + db:close() +end + +-- Setup autocommand to update the function def index as the files are saved +function Context.setup_autocmd_update_index() + vim.api.nvim_create_autocmd("BufWritePost", { + pattern = { "*" }, + group = vim.api.nvim_create_augroup("GpFileIndexUpdate", { clear = true }), + callback = function(arg) + Context.index_single_file(arg.file) + end, + }) +end + +Context.setup_autocmd_update_index() + return Context diff --git a/lua/gp/db.lua b/lua/gp/db.lua index 13d9f9b6..4d747f94 100644 --- a/lua/gp/db.lua +++ b/lua/gp/db.lua @@ -35,7 +35,7 @@ end -- @side-effect Creates .gp directory and database file if they don't exist -- @side-effect Logs errors if unable to locate project root or create directory function Db.open() - local git_root = gp._H.find_git_root() + local git_root = u.git_root_from_cwd() if git_root == "" then logger.error("[db.open] Unable to locate project root") return nil @@ -97,7 +97,7 @@ function Db.collect_src_file_data(relative_path) local uv = vim.uv or vim.loop -- Construct the full path to the file - local proj_root = gp._H.find_git_root() + local proj_root = u.git_root_from_cwd() local fullpath = u.path_join(proj_root, relative_path) -- If the file doesn't exist, there is nothing to collect diff --git a/lua/gp/init.lua b/lua/gp/init.lua index cbc54dc4..e6032571 100644 --- a/lua/gp/init.lua +++ b/lua/gp/init.lua @@ -2090,16 +2090,23 @@ M.cmd.ChatNew = function(params, system_prompt, agent) return -1 end end + + local buf + -- if chat toggle is open, close it and start a new one if M._toggle_close(M._toggle_kind.chat) then params.args = params.args or "" if params.args == "" then params.args = M.config.toggle_target end - return M.new_chat(params, true, system_prompt, agent) + buf = M.new_chat(params, true, system_prompt, agent) + else + buf = M.new_chat(params, false, system_prompt, agent) end - return M.new_chat(params, false, system_prompt, agent) + require("gp.context").build_initial_index() + + return buf end ---@param params table @@ -2133,6 +2140,8 @@ M.cmd.ChatToggle = function(params, system_prompt, agent) buf = M.new_chat(params, true, system_prompt, agent) end + require("gp.context").build_initial_index() + return buf end diff --git a/lua/gp/utils.lua b/lua/gp/utils.lua index 8485f0ad..8b87df14 100644 --- a/lua/gp/utils.lua +++ b/lua/gp/utils.lua @@ -32,6 +32,13 @@ function Utils.path_join(...) return result end +function Utils.path_is_absolute(path) + if Utils.string_starts_with(path, "/") then + return true + end + return false +end + function Utils.ensure_path_exists(path) -- Check if the path exists local stat = uv.fs_stat(path) @@ -118,4 +125,22 @@ function Utils.walk_directory(dir, options) walk(dir, 1) end +--- Locates the git_root using the cwd +function Utils.git_root_from_cwd() + return require("gp")._H.find_git_root(vim.fn.getcwd()) +end + +-- If the given path is a relative path, turn it into a fullpath +-- based on the current git_root +---@param path string +function Utils.full_path_for_project_file(path) + if Utils.path_is_absolute(path) then + return path + end + + -- Construct the full path to the file + local proj_root = Utils.git_root_from_cwd() + return Utils.path_join(proj_root, path) +end + return Utils From 2bd9508e134c2863228b3b8d4992be14ef31a3c3 Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Fri, 2 Aug 2024 12:33:38 +0800 Subject: [PATCH 18/34] Implements context insertion for @code commands --- lua/gp/completion.lua | 3 ++ lua/gp/context.lua | 110 ++++++++++++++++++++++++++++++++++++++---- lua/gp/db.lua | 30 ++++++++++++ lua/gp/utils.lua | 15 ++++++ 4 files changed, 149 insertions(+), 9 deletions(-) diff --git a/lua/gp/completion.lua b/lua/gp/completion.lua index cc9352cd..e49515b8 100644 --- a/lua/gp/completion.lua +++ b/lua/gp/completion.lua @@ -192,6 +192,9 @@ function source.complete(self, request, callback) isIncomplete = false elseif cmd_parts[1]:match("@code") then local partial_fn_name = cmd_parts[2] + if not partial_fn_name then + partial_fn_name = "" + end -- When the user confirms completion of an item, we alter the -- command to look like `@code:path/to/file:fn_name` to uniquely diff --git a/lua/gp/context.lua b/lua/gp/context.lua index 567901cc..2070ed73 100644 --- a/lua/gp/context.lua +++ b/lua/gp/context.lua @@ -8,8 +8,35 @@ local Db = require("gp.db") local Context = {} -- Split a context insertion command into its component parts +-- This function will split the cmd by ":", at most into 3 parts. +-- It will grab the first 2 substrings that's split by ":", then +-- grab whatever is remaining as the 3rd string. +-- +-- Example: +-- cmd = "@code:/some/path/goes/here:class:fn_name" +-- => {"@code", "/some/path/goes/here", "class:fn_name"} +-- +-- This is can be used to split both @file and @code commands. function Context.cmd_split(cmd) - return vim.split(cmd, ":", { plain = true }) + local result = {} + local splits = u.string_find_all_substr(cmd, ":") + + local cursor = 0 + for i, split in ipairs(splits) do + if i > 2 then + break + end + local next_start = split[1] - 1 + local next_end = split[2] + table.insert(result, string.sub(cmd, cursor, next_start)) + cursor = next_end + 1 + end + + if cursor < #cmd then + table.insert(result, string.sub(cmd, cursor)) + end + + return result end ---@return string | nil @@ -33,9 +60,36 @@ local function file_exists(path) end end +local function get_file_lines(filepath, start_line, end_line) + local lines = {} + local current_line = 0 + + -- Open the file for reading + local file = io.open(filepath, "r") + if not file then + logger.info("[get_file_lines] Could not open file: " .. filepath) + return nil + end + + for line in file:lines() do + if current_line >= start_line then + table.insert(lines, line) + end + if current_line > end_line then + break + end + current_line = current_line + 1 + end + + file:close() + + return lines +end + -- Given a single message, parse out all the context insertion -- commands, then return a new message with all the requested -- context inserted +---@param msg string function Context.insert_contexts(msg) local context_texts = {} @@ -44,11 +98,18 @@ function Context.insert_contexts(msg) for cmd in msg:gmatch("@file:[%w%p]+") do table.insert(cmds, cmd) end + for cmd in msg:gmatch("@code:[%w%p]+[:%w_-]+") do + print("[insert_contexts] found @code cmd: ", cmd) + table.insert(cmds, cmd) + end + + local db = nil -- Process each command and turn it into a string be -- inserted as additional context for _, cmd in ipairs(cmds) do local cmd_parts = Context.cmd_split(cmd) + print("[insert_contexts] processing cmd: ", vim.inspect(cmd_parts)) if cmd_parts[1] == "@file" then -- Read the reqested file and produce a msg snippet to be joined later @@ -59,13 +120,47 @@ function Context.insert_contexts(msg) local content = read_file(fullpath) if content then - local result = gp._H.template_render("filepath\n```content```", { - filepath = filepath, - content = content, - }) + local result = string.format("%s\n```%s```", filepath, content) table.insert(context_texts, result) end + elseif cmd_parts[1] == "@code" then + local rel_path = cmd_parts[2] + local full_fn_name = cmd_parts[3] + print("[insert_contexts] rel_path: ", rel_path) + print("[insert_contexts] full_fn_name: ", full_fn_name) + if not rel_path or not full_fn_name then + print("[insert_contexts] skipping request") + goto continue + end + if db == nil then + db = Db.open() + end + + local fn_def = db:find_fn_def_by_file_n_name(rel_path, full_fn_name) + print("[insert_contexts] fn_def: ", vim.inspect(fn_def)) + if not fn_def then + logger.warning(string.format("Unable to locate function: '%s', '%s'", rel_path, full_fn_name)) + goto continue + end + + local fn_body = get_file_lines(fn_def.file, fn_def.start_line, fn_def.end_line) + print("[insert_contexts] content: ", vim.inspect(fn_body)) + if fn_body then + local result = string.format( + "In '%s', function '%s'\n```%s```", + fn_def.file, + fn_def.name, + table.concat(fn_body, "\n") + ) + table.insert(context_texts, result) + print("[insert_contexts] context_texts: ", vim.inspect(context_texts)) + end end + ::continue:: + end + + if db then + db:close() end -- If no context insertions are requested, don't alter the original msg @@ -73,10 +168,7 @@ function Context.insert_contexts(msg) return msg else -- Otherwise, build and return the final message - return gp._H.template_render("context\n\nmsg", { - context = table.concat(context_texts, "\n"), - msg = msg, - }) + return string.format("%s\n\n%s", table.concat(context_texts, "\n"), msg) end end diff --git a/lua/gp/db.lua b/lua/gp/db.lua index 4d747f94..6853a1ce 100644 --- a/lua/gp/db.lua +++ b/lua/gp/db.lua @@ -261,6 +261,36 @@ function Db:find_fn_def_by_name(partial_fn_name) return result end +function Db:find_fn_def_by_file_n_name(rel_path, full_fn_name) + local sql = [[ + SELECT * FROM function_defs WHERE file = ? AND name = ? + ]] + + local result = self.db:eval(sql, { + rel_path, + full_fn_name, + }) + + -- We're expecting the query to return a list of FunctionDefEntry. + -- If we get a boolean back instead, we consider the operation to have failed. + if type(result) == "boolean" then + return nil + end + + ---@cast result FunctionDefEntry[] + if #result > 1 then + logger.error( + string.format( + "[Db.find_fn_def_by_file_n_name] Found more than 1 result for: '%s', '%s'", + rel_path, + full_fn_name + ) + ) + end + + return result[1] +end + -- Removes a single entry from the src_files table given a relative file path -- Note that related entries in the function_defs table will be removed via CASCADE. ---@param src_filepath string diff --git a/lua/gp/utils.lua b/lua/gp/utils.lua index 8b87df14..6d3af977 100644 --- a/lua/gp/utils.lua +++ b/lua/gp/utils.lua @@ -143,4 +143,19 @@ function Utils.full_path_for_project_file(path) return Utils.path_join(proj_root, path) end +function Utils.string_find_all_substr(str, substr) + local result = {} + local first = 0 + local last = 0 + + while true do + first, last = str:find(substr, first + 1) + if not first then + break + end + table.insert(result, { first, last }) + end + return result +end + return Utils From 40dabee07e4288e67f6ccbaf7e408dde98459aa8 Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Fri, 2 Aug 2024 12:47:01 +0800 Subject: [PATCH 19/34] Renames the sqlite database file --- lua/gp/db.lua | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lua/gp/db.lua b/lua/gp/db.lua index 6853a1ce..0d04a9ee 100644 --- a/lua/gp/db.lua +++ b/lua/gp/db.lua @@ -41,7 +41,7 @@ function Db.open() return nil end - local db_file = u.path_join(git_root, ".gp/function_defs.sqlite") + local db_file = u.path_join(git_root, ".gp/index.sqlite") if not u.ensure_parent_path_exists(db_file) then logger.error("[db.open] Unable create directory for db file: " .. db_file) return nil From c247a476f595b758094373bd0f16487ff1e8b4f3 Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Fri, 2 Aug 2024 12:58:54 +0800 Subject: [PATCH 20/34] Cleans up debug prints in Context.insert_contexts --- lua/gp/context.lua | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/lua/gp/context.lua b/lua/gp/context.lua index 2070ed73..2cdfb7f2 100644 --- a/lua/gp/context.lua +++ b/lua/gp/context.lua @@ -99,7 +99,6 @@ function Context.insert_contexts(msg) table.insert(cmds, cmd) end for cmd in msg:gmatch("@code:[%w%p]+[:%w_-]+") do - print("[insert_contexts] found @code cmd: ", cmd) table.insert(cmds, cmd) end @@ -109,7 +108,6 @@ function Context.insert_contexts(msg) -- inserted as additional context for _, cmd in ipairs(cmds) do local cmd_parts = Context.cmd_split(cmd) - print("[insert_contexts] processing cmd: ", vim.inspect(cmd_parts)) if cmd_parts[1] == "@file" then -- Read the reqested file and produce a msg snippet to be joined later @@ -126,10 +124,7 @@ function Context.insert_contexts(msg) elseif cmd_parts[1] == "@code" then local rel_path = cmd_parts[2] local full_fn_name = cmd_parts[3] - print("[insert_contexts] rel_path: ", rel_path) - print("[insert_contexts] full_fn_name: ", full_fn_name) if not rel_path or not full_fn_name then - print("[insert_contexts] skipping request") goto continue end if db == nil then @@ -137,23 +132,16 @@ function Context.insert_contexts(msg) end local fn_def = db:find_fn_def_by_file_n_name(rel_path, full_fn_name) - print("[insert_contexts] fn_def: ", vim.inspect(fn_def)) if not fn_def then logger.warning(string.format("Unable to locate function: '%s', '%s'", rel_path, full_fn_name)) goto continue end local fn_body = get_file_lines(fn_def.file, fn_def.start_line, fn_def.end_line) - print("[insert_contexts] content: ", vim.inspect(fn_body)) if fn_body then - local result = string.format( - "In '%s', function '%s'\n```%s```", - fn_def.file, - fn_def.name, - table.concat(fn_body, "\n") - ) + local result = + string.format("In '%s', function '%s'\n```%s```", fn_def.file, fn_def.name, table.concat(fn_body, "\n")) table.insert(context_texts, result) - print("[insert_contexts] context_texts: ", vim.inspect(context_texts)) end end ::continue:: From 13ed5610e83609a56652d0fbd573e7d069d7f219 Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Fri, 2 Aug 2024 19:38:21 +0800 Subject: [PATCH 21/34] Adds python indexing support We're also now indexing symbols of different types: function, classes, and class methods --- data/ts_queries/lua.scm | 19 +++++---- data/ts_queries/python.scm | 20 ++++++++++ lua/gp/completion.lua | 15 +++++-- lua/gp/context.lua | 80 +++++++++++++++++++++++++++----------- lua/gp/db.lua | 8 +++- lua/gp/utils.lua | 52 +++++++++++++++++++++++++ 6 files changed, 159 insertions(+), 35 deletions(-) create mode 100644 data/ts_queries/python.scm diff --git a/data/ts_queries/lua.scm b/data/ts_queries/lua.scm index 33cb8a3f..a50326db 100644 --- a/data/ts_queries/lua.scm +++ b/data/ts_queries/lua.scm @@ -5,36 +5,41 @@ ;; Specificadlly, this ignores the local function declarations. ;; We're only doing this because we're requiring the ;; (file, function_name) pair to be unique in the database. -(chunk - (function_declaration - name: (identifier) @name) @body) +((chunk + (function_declaration + name: (identifier) @name) @body) + (#set! "type" "function")) ;; Matches function declaration using the dot syntax ;; function a_table.a_fn_name() -(chunk +((chunk (function_declaration name: (dot_index_expression) @name) @body) + (#set! "type" "function")) ;; Matches function declaration using the member function syntax ;; function a_table:a_fn_name() -(chunk +((chunk (function_declaration name: (method_index_expression) @name) @body) + (#set! "type" "function")) ;; Matches on: ;; M.some_field = function() end -(chunk +((chunk (assignment_statement (variable_list name: (dot_index_expression) @name) (expression_list value: (function_definition) @body))) + (#set! "type" "function")) ;; Matches on: ;; some_var = function() end -(chunk +((chunk (assignment_statement (variable_list name: (identifier) @name) (expression_list value: (function_definition) @body))) + (#set! "type" "function")) diff --git a/data/ts_queries/python.scm b/data/ts_queries/python.scm new file mode 100644 index 00000000..f7569d22 --- /dev/null +++ b/data/ts_queries/python.scm @@ -0,0 +1,20 @@ +;; Top level function definitions +((module + (function_definition + name: (identifier) @name ) @body + (#not-has-ancestor? @body class_definition)) + (#set! "type" "function")) + +;; Class member function definitions +((class_definition + name: (identifier) @classname + body: (block + (function_definition + name: (identifier) @name ) @body)) + (#set! "type" "class_method")) + + +;; Class definitions +((class_definition + name: (identifier) @name) @body + (#set! "type" "class")) diff --git a/lua/gp/completion.lua b/lua/gp/completion.lua index e49515b8..283f7211 100644 --- a/lua/gp/completion.lua +++ b/lua/gp/completion.lua @@ -149,10 +149,9 @@ function source:completion_items_for_fn_name(partial_fn_name) end for _, row in ipairs(result) do - table.insert(items, { + local item = { -- fields meant for nvim-cmp label = row.name, - kind = cmp.lsp.CompletionItemKind.Function, labelDetails = { detail = row.file, }, @@ -160,7 +159,17 @@ function source:completion_items_for_fn_name(partial_fn_name) -- fields meant for internal use row = row, type = "@code", - }) + } + + if row.type == "class" then + item.kind = cmp.lsp.CompletionItemKind.Class + elseif row.type == "class_method" then + item.kind = cmp.lsp.CompletionItemKind.Method + else + item.kind = cmp.lsp.CompletionItemKind.Function + end + + table.insert(items, item) end return items diff --git a/lua/gp/context.lua b/lua/gp/context.lua index 2cdfb7f2..41040699 100644 --- a/lua/gp/context.lua +++ b/lua/gp/context.lua @@ -139,8 +139,12 @@ function Context.insert_contexts(msg) local fn_body = get_file_lines(fn_def.file, fn_def.start_line, fn_def.end_line) if fn_body then - local result = - string.format("In '%s', function '%s'\n```%s```", fn_def.file, fn_def.name, table.concat(fn_body, "\n")) + local result = string.format( + "In '%s', function '%s'\n```%s```", + fn_def.file, + fn_def.name, + table.concat(fn_body, "\n") + ) table.insert(context_texts, result) end end @@ -267,32 +271,62 @@ function Context.treesitter_extract_function_defs(src_filepath) return nil end + -- The captures are usually returned as a flat list with no way to tell + -- which captures came from the same symbol. But, if the query has attached + -- a some metadata to the query, all captured elements will reference the same metadata + -- table. We can then use this to correctly gather those elements into the same groups. + local function get_meta(x) + return x.metadata + end + captures = u.sort_by(get_meta, captures) + local groups = u.partition_by(get_meta, captures) + -- Reshape the captures into a structure we'd like to work with local results = {} - for i = 1, #captures, 2 do - -- The captures may arrive out of order. - -- We're only expecting the query to contain @name and @body returned - -- Sort out their ordering here. - local caps = { captures[i], captures[i + 1] } - local named_caps = {} - for _, item in ipairs(caps) do - named_caps[item.name] = item + for _, group in ipairs(groups) do + local grp = {} + for _, item in ipairs(group) do + grp[item.name] = item + end + grp.metadata = group[1].metadata + + local type = grp.metadata.type + if type == "function" then + table.insert(results, { + file = src_filepath, + type = "function", + name = grp.name.text, + start_line = grp.body.range[1], + end_line = grp.body.range[3], + body = grp.body.text, + }) + elseif type == "class_method" then + table.insert(results, { + file = src_filepath, + type = "class_method", + name = string.format("%s.%s", grp.classname.text, grp.name.text), + start_line = grp.body.range[1], + end_line = grp.body.range[3], + body = grp.body.text, + }) + elseif type == "class" then + table.insert(results, { + file = src_filepath, + type = "class", + name = grp.name.text, + start_line = grp.body.range[1], + end_line = grp.body.range[3], + body = grp.body.text, + }) end - local fn_name = named_caps.name - local fn_body = named_caps.body - assert(fn_name) - assert(fn_body) - - table.insert(results, { - file = src_filepath, - type = "function_definition", - name = fn_name.text, - start_line = fn_body.range[1], - end_line = fn_body.range[3], - body = fn_body.text, - }) end + -- For debugging and manually checking the output + -- results = u.sort_by(function(x) + -- return x.start_line + -- end, results) + -- u.write_file("results.data.lua", vim.inspect(results)) + return results end diff --git a/lua/gp/db.lua b/lua/gp/db.lua index 0d04a9ee..b1523c61 100644 --- a/lua/gp/db.lua +++ b/lua/gp/db.lua @@ -80,6 +80,7 @@ function Db.open() id INTEGER NOT NULL PRIMARY KEY, file TEXT NOT NULL REFERENCES src_files(filename) on DELETE CASCADE, name TEXT NOT NULL, + type TEXT NOT NULL, start_line INTEGER NOT NULL, end_line INTEGER NOT NULL, UNIQUE (file, name) @@ -177,9 +178,11 @@ function Db:upsert_function_def(def) return false end + ---TODO: We're never actually upserting, but deleting and inserting + ---There is no reason to manually construct and upkeep queries like this. local sql = [[ - INSERT INTO function_defs (file, name, start_line, end_line) - VALUES (?, ?, ?, ?) + INSERT INTO function_defs (file, name, type, start_line, end_line) + VALUES (?, ?, ?, ?, ?) ON CONFLICT(file, name) DO UPDATE SET start_line = excluded.start_line, end_line = excluded.end_line @@ -190,6 +193,7 @@ function Db:upsert_function_def(def) -- For the INSERT VALUES clause def.file, def.name, + def.type, def.start_line, def.end_line, diff --git a/lua/gp/utils.lua b/lua/gp/utils.lua index 6d3af977..9e9f4a14 100644 --- a/lua/gp/utils.lua +++ b/lua/gp/utils.lua @@ -158,4 +158,56 @@ function Utils.string_find_all_substr(str, substr) return result end +function Utils.partition_by(pred, list) + local result = {} + local current_partition = {} + local last_key = nil + + for _, item in ipairs(list) do + local key = pred(item) + if last_key == nil or key ~= last_key then + if #current_partition > 0 then + table.insert(result, current_partition) + end + current_partition = {} + end + table.insert(current_partition, item) + last_key = key + end + + if #current_partition > 0 then + table.insert(result, current_partition) + end + + return result +end + +function Utils.write_file(filename, content, mode) + mode = mode or "w" -- Default mode is write + if not content then + return true + end + local file = io.open(filename, mode) + if file then + file:write(content) + file:close() + else + error("Unable to open file: " .. filename) + end + return true +end + +function Utils.sort_by(key_fn, tbl) + table.sort(tbl, function(a, b) + local ka, kb = key_fn(a), key_fn(b) + if type(ka) == "table" and type(kb) == "table" then + -- Use table identifiers as tie-breaker + return tostring(ka) < tostring(kb) + else + return ka < kb + end + end) + return tbl +end + return Utils From ec80e1936ecd5db37ee0db4cc549e3a60bfbe8c4 Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Fri, 2 Aug 2024 22:21:43 +0800 Subject: [PATCH 22/34] Renames functions_defs table to symbols table. The symbols table is also now defined through the sqlite.lua ORM syntax --- lua/gp/completion.lua | 2 +- lua/gp/context.lua | 31 +++++++------- lua/gp/db.lua | 94 ++++++++++++++++++++----------------------- 3 files changed, 60 insertions(+), 67 deletions(-) diff --git a/lua/gp/completion.lua b/lua/gp/completion.lua index 283f7211..4619ce3d 100644 --- a/lua/gp/completion.lua +++ b/lua/gp/completion.lua @@ -141,7 +141,7 @@ local function completion_items_for_path(path) end function source:completion_items_for_fn_name(partial_fn_name) - local result = self.db:find_fn_def_by_name(partial_fn_name) + local result = self.db:find_symbol_by_name(partial_fn_name) local items = {} if not result then diff --git a/lua/gp/context.lua b/lua/gp/context.lua index 41040699..7cbbd803 100644 --- a/lua/gp/context.lua +++ b/lua/gp/context.lua @@ -131,7 +131,7 @@ function Context.insert_contexts(msg) db = Db.open() end - local fn_def = db:find_fn_def_by_file_n_name(rel_path, full_fn_name) + local fn_def = db:find_symbol_by_file_n_name(rel_path, full_fn_name) if not fn_def then logger.warning(string.format("Unable to locate function: '%s', '%s'", rel_path, full_fn_name)) goto continue @@ -139,12 +139,8 @@ function Context.insert_contexts(msg) local fn_body = get_file_lines(fn_def.file, fn_def.start_line, fn_def.end_line) if fn_body then - local result = string.format( - "In '%s', function '%s'\n```%s```", - fn_def.file, - fn_def.name, - table.concat(fn_body, "\n") - ) + local result = + string.format("In '%s', function '%s'\n```%s```", fn_def.file, fn_def.name, table.concat(fn_body, "\n")) table.insert(context_texts, result) end end @@ -291,34 +287,38 @@ function Context.treesitter_extract_function_defs(src_filepath) grp.metadata = group[1].metadata local type = grp.metadata.type + local item if type == "function" then - table.insert(results, { + item = { file = src_filepath, type = "function", name = grp.name.text, start_line = grp.body.range[1], end_line = grp.body.range[3], - body = grp.body.text, - }) + body = grp.body.text, -- for diagnostics + } elseif type == "class_method" then - table.insert(results, { + item = { file = src_filepath, type = "class_method", name = string.format("%s.%s", grp.classname.text, grp.name.text), start_line = grp.body.range[1], end_line = grp.body.range[3], body = grp.body.text, - }) + } elseif type == "class" then - table.insert(results, { + item = { file = src_filepath, type = "class", name = grp.name.text, start_line = grp.body.range[1], end_line = grp.body.range[3], body = grp.body.text, - }) + } end + + item.body = nil -- Remove the diagnostics field to prep the entry for db insertion + table.insert(results, item) end -- For debugging and manually checking the output @@ -355,7 +355,8 @@ function Context.build_function_def_index_for_file(db, src_filepath) if not success then return false end - return db:upsert_fnlist(fnlist) + + return db:insert_symbol_list(fnlist) end) return result end diff --git a/lua/gp/db.lua b/lua/gp/db.lua index b1523c61..920c1bc0 100644 --- a/lua/gp/db.lua +++ b/lua/gp/db.lua @@ -4,7 +4,7 @@ local gp = require("gp") local u = require("gp.utils") local logger = require("gp.logger") --- Describes files we've scanned previously to produce the list of function definitions +-- Describes files we've scanned previously to produce the list of symbols ---@class SrcFileEntry ---@field id number: unique id ---@field filename string: path relative to the git/project root @@ -14,10 +14,11 @@ local logger = require("gp.logger") ---@field last_scan_time number: unix time stamp indicating when the last scan of this file was made -- Describes where each of the functions are in the project ----@class FunctionDefEntry +---@class SymbolDefEntry ---@field id number: unique id ----@field name string: Name of the function ----@field file string: In which file is the function defined? +---@field file string: Which file is the symbol defined? +---@field name string: Name of the symbol +---@field type string: type of the symbol ---@field start_line number: Which line in the file does the definition start? ---@field end_line number: Which line in the file does the definition end? @@ -30,7 +31,7 @@ Db._new = function(db) return setmetatable({ db = db }, { __index = Db }) end ---- Opens and/or creates a SQLite database for storing function definitions. +--- Opens and/or creates a SQLite database for storing symbol definitions. -- @return Db|nil A new Db object if successful, nil if an error occurs -- @side-effect Creates .gp directory and database file if they don't exist -- @side-effect Logs errors if unable to locate project root or create directory @@ -67,27 +68,20 @@ function Db.open() last_scan_time = { type = "integer", required = true }, -- unix timestamp }, + symbols = { + id = true, + file = { type = "text", require = true, reference = "src_files.filenamed", on_delete = "cascade" }, + name = { type = "text", required = true }, + type = { type = "text", required = true }, + start_line = { type = "integer", required = true }, + end_line = { type = "integer", required = true }, + }, + opts = { keep_open = true }, }) - db:eval("PRAGMA foreign_keys = ON;") - - -- sqlite.lua doesn't seem to support adding random table options - -- In this case, being able to perform an upsert in the function_defs table depends on - -- having UNIQUE file and fn name pair. - db:eval([[ - CREATE TABLE IF NOT EXISTS function_defs ( - id INTEGER NOT NULL PRIMARY KEY, - file TEXT NOT NULL REFERENCES src_files(filename) on DELETE CASCADE, - name TEXT NOT NULL, - type TEXT NOT NULL, - start_line INTEGER NOT NULL, - end_line INTEGER NOT NULL, - UNIQUE (file, name) - ); - ]]) - db:eval("CREATE UNIQUE INDEX IF NOT EXISTS idx_src_files_filename ON src_files (filename);") + db:eval("CREATE UNIQUE INDEX IF NOT EXISTS idx_symbol_file_n_name ON symbols (file, name);") return Db._new(db) end @@ -170,20 +164,22 @@ function Db:upsert_filelist(filelist) return true end --- Upserts a single function def entry into the database ---- @param def FunctionDefEntry -function Db:upsert_function_def(def) +-- Upserts a single symbol entry into the database +--- @param def SymbolDefEntry +function Db:upsert_symbol(def) if not self.db then - logger.error("[db.upsert_function_def] Database not initialized") + logger.error("[db.upsert_symbol] Database not initialized") return false end - ---TODO: We're never actually upserting, but deleting and inserting - ---There is no reason to manually construct and upkeep queries like this. + ---WARNING: Do not use ORM here. + -- This function can be called a lot during a full index rebuild. + -- Using the ORM here can cause a 100% slowdown. local sql = [[ - INSERT INTO function_defs (file, name, type, start_line, end_line) + INSERT INTO symbols (file, name, type, start_line, end_line) VALUES (?, ?, ?, ?, ?) ON CONFLICT(file, name) DO UPDATE SET + type = excluded.type, start_line = excluded.start_line, end_line = excluded.end_line WHERE file = ? AND name = ? @@ -203,7 +199,7 @@ function Db:upsert_function_def(def) }) if not success then - logger.error("[db.upsert_function_def] Failed to upsert function: " .. def.name .. " for file: " .. def.file) + logger.error("[db.upsert_symbol] Failed to upsert symbol: " .. def.name .. " for file: " .. def.file) return false end @@ -224,13 +220,13 @@ function Db:with_transaction(fn) return true end ---- Updates the dastabase with the contents of the `fnlist` ---- Note that this function early terminates of any of the entry upsert fails. +--- Updates the dastabase with the contents of the `symbols_list` +--- Note that this function early terminates if any of the entry upsert fails. --- This behavior is only suitable when run inside a transaction. ---- @param fnlist FunctionDefEntry[] -function Db:upsert_fnlist(fnlist) - for _, def in ipairs(fnlist) do - local success = self:upsert_function_def(def) +--- @param symbols_list SymbolDefEntry[] +function Db:insert_symbol_list(symbols_list) + for _, def in ipairs(symbols_list) do + local success = self:upsert_symbol(def) if not success then logger.error("[db.upsert_fnlist] Failed to upsert function def list") return false @@ -244,9 +240,9 @@ function Db:close() self.db:close() end -function Db:find_fn_def_by_name(partial_fn_name) +function Db:find_symbol_by_name(partial_fn_name) local sql = [[ - SELECT * FROM function_defs WHERE name LIKE ? + SELECT * FROM symbols WHERE name LIKE ? ]] local wildcard_name = "%" .. partial_fn_name .. "%" @@ -255,19 +251,19 @@ function Db:find_fn_def_by_name(partial_fn_name) wildcard_name, }) - -- We're expecting the query to return a list of FunctionDefEntry. + -- We're expecting the query to return a list of SymbolDefEntry. -- If we get a boolean back instead, we consider the operation to have failed. if type(result) == "boolean" then return nil end - ---@cast result FunctionDefEntry + ---@cast result SymbolDefEntry return result end -function Db:find_fn_def_by_file_n_name(rel_path, full_fn_name) +function Db:find_symbol_by_file_n_name(rel_path, full_fn_name) local sql = [[ - SELECT * FROM function_defs WHERE file = ? AND name = ? + SELECT * FROM symbols WHERE file = ? AND name = ? ]] local result = self.db:eval(sql, { @@ -275,20 +271,16 @@ function Db:find_fn_def_by_file_n_name(rel_path, full_fn_name) full_fn_name, }) - -- We're expecting the query to return a list of FunctionDefEntry. + -- We're expecting the query to return a list of SymbolDefEntry. -- If we get a boolean back instead, we consider the operation to have failed. if type(result) == "boolean" then return nil end - ---@cast result FunctionDefEntry[] + ---@cast result SymbolDefEntry[] if #result > 1 then logger.error( - string.format( - "[Db.find_fn_def_by_file_n_name] Found more than 1 result for: '%s', '%s'", - rel_path, - full_fn_name - ) + string.format("[Db.find_symbol_by_file_n_name] Found more than 1 result for: '%s', '%s'", rel_path, full_fn_name) ) end @@ -296,7 +288,7 @@ function Db:find_fn_def_by_file_n_name(rel_path, full_fn_name) end -- Removes a single entry from the src_files table given a relative file path --- Note that related entries in the function_defs table will be removed via CASCADE. +-- Note that related entries in the symbols table will be removed via CASCADE. ---@param src_filepath string function Db:remove_src_file_entry(src_filepath) local sql = [[ @@ -311,7 +303,7 @@ function Db:remove_src_file_entry(src_filepath) end function Db:clear() - self.db:eval("DELETE FROM function_defs") + self.db:eval("DELETE FROM symbols") self.db:eval("DELETE FROM src_files") end From 999d7abbc44d20ed31485ddfdd7c6fa377377e27 Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Sat, 3 Aug 2024 14:22:36 +0800 Subject: [PATCH 23/34] Remove stale symbols from the index --- lua/gp/context.lua | 36 +++++++++++++--------- lua/gp/db.lua | 75 ++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 88 insertions(+), 23 deletions(-) diff --git a/lua/gp/context.lua b/lua/gp/context.lua index 7cbbd803..c55002a2 100644 --- a/lua/gp/context.lua +++ b/lua/gp/context.lua @@ -139,8 +139,12 @@ function Context.insert_contexts(msg) local fn_body = get_file_lines(fn_def.file, fn_def.start_line, fn_def.end_line) if fn_body then - local result = - string.format("In '%s', function '%s'\n```%s```", fn_def.file, fn_def.name, table.concat(fn_body, "\n")) + local result = string.format( + "In '%s', function '%s'\n```%s```", + fn_def.file, + fn_def.name, + table.concat(fn_body, "\n") + ) table.insert(context_texts, result) end end @@ -332,10 +336,10 @@ end ---@param db Db ---@param src_filepath string -function Context.build_function_def_index_for_file(db, src_filepath) +function Context.build_symbol_index_for_file(db, src_filepath) -- try to retrieve function definitions from the file - local fnlist = Context.treesitter_extract_function_defs(src_filepath) - if not fnlist then + local symbols_list = Context.treesitter_extract_function_defs(src_filepath) + if not symbols_list then return false end @@ -349,25 +353,29 @@ function Context.build_function_def_index_for_file(db, src_filepath) -- Update the src file entry and the function definitions in a single transaction local result = db:with_transaction(function() - db:remove_src_file_entry(src_filepath) - local success = db:upsert_src_file(src_file_entry) if not success then + logger.error("Upserting src_file failed") + return false + end + + success = db:upsert_and_clean_symbol_list_for_file(src_file_entry.filename, symbols_list) + if not success then + logger.error("Upserting symbol list failed") return false end - return db:insert_symbol_list(fnlist) + return true end) return result end -function Context.build_function_def_index(db) +function Context.build_symbol_index(db) local git_root = u.git_root_from_cwd() if not git_root then - logger.error("[Context.build_function_def_index] Unable to locate project root") + logger.error("[Context.build_symbol_index] Unable to locate project root") return false end - local git_root_len = #git_root + 2 u.walk_directory(git_root, { should_process = function(entry, rel_path, full_path, is_dir) @@ -389,7 +397,7 @@ function Context.build_function_def_index(db) process_file = function(rel_path, full_path) if vim.filetype.match({ filename = full_path }) then - local success = Context.build_function_def_index_for_file(db, rel_path) + local success = Context.build_symbol_index_for_file(db, rel_path) if not success then logger.debug("Failed to build function def index for: " .. rel_path) end @@ -403,7 +411,7 @@ function Context.index_single_file(src_filepath) if not db then return end - Context.build_function_def_index_for_file(db, src_filepath) + Context.build_symbol_index_for_file(db, src_filepath) db:close() end @@ -415,7 +423,7 @@ function Context.index_all() if not db then return end - Context.build_function_def_index(db) + Context.build_symbol_index(db) db:close() local end_time = uv.hrtime() diff --git a/lua/gp/db.lua b/lua/gp/db.lua index 920c1bc0..7637c9c3 100644 --- a/lua/gp/db.lua +++ b/lua/gp/db.lua @@ -12,6 +12,7 @@ local logger = require("gp.logger") ---@field filetype string: filetype as reported by neovim at last scan ---@field mod_time number: last file modification time reported by the os at last scan ---@field last_scan_time number: unix time stamp indicating when the last scan of this file was made +---@field generation number: For internal use - garbage collection -- Describes where each of the functions are in the project ---@class SymbolDefEntry @@ -21,6 +22,7 @@ local logger = require("gp.logger") ---@field type string: type of the symbol ---@field start_line number: Which line in the file does the definition start? ---@field end_line number: Which line in the file does the definition end? +---@field generation number: For internal use - garbage collection ---@class Db ---@field db sqlite_db @@ -66,15 +68,17 @@ function Db.open() filetype = { type = "text", required = true }, -- filetype as reported by neovim at last scan mod_time = { type = "integer", required = true }, -- file mod time reported by the fs at last scan last_scan_time = { type = "integer", required = true }, -- unix timestamp + generation = { type = "integer" }, -- for garbage collection }, symbols = { id = true, - file = { type = "text", require = true, reference = "src_files.filenamed", on_delete = "cascade" }, + file = { type = "text", require = true, reference = "src_files.filename", on_delete = "cascade" }, name = { type = "text", required = true }, type = { type = "text", required = true }, start_line = { type = "integer", required = true }, end_line = { type = "integer", required = true }, + generation = { type = "integer" }, -- for garbage collection }, opts = { keep_open = true }, @@ -176,12 +180,13 @@ function Db:upsert_symbol(def) -- This function can be called a lot during a full index rebuild. -- Using the ORM here can cause a 100% slowdown. local sql = [[ - INSERT INTO symbols (file, name, type, start_line, end_line) - VALUES (?, ?, ?, ?, ?) + INSERT INTO symbols (file, name, type, start_line, end_line, generation) + VALUES (?, ?, ?, ?, ?, ?) ON CONFLICT(file, name) DO UPDATE SET type = excluded.type, start_line = excluded.start_line, - end_line = excluded.end_line + end_line = excluded.end_line, + generation = excluded.generation WHERE file = ? AND name = ? ]] @@ -192,6 +197,7 @@ function Db:upsert_symbol(def) def.type, def.start_line, def.end_line, + def.generation or -1, -- For the WHERE clause def.file, @@ -209,14 +215,61 @@ end -- Wraps the given function in a sqlite transaction ---@param fn function() function Db:with_transaction(fn) - self.db:execute("BEGIN") - local success, result = pcall(fn) - self.db:execute("END") + local success, result + success = self.db:execute("BEGIN") if not success then + logger.error("[db.with_transaction] Unable to start transaction") + return false + end + + success, result = pcall(fn) + if not success then + logger.error("[db.with_transaction] fn return false") logger.error(result) + + success = self.db:execute("ROLLBACK") + if not success then + logger.error("[db.with_transaction] Rollback failed") + end return false end + + success = self.db:execute("COMMIT") + if not success then + logger.error("[db.with_transaction] Unable to end transaction") + return false + end + + return true +end + +local function random_8byte_int() + return math.random(0, 0xFFFFFFFFFFFFFFFF) +end + +--- @param symbols_list SymbolDefEntry[] +function Db:upsert_and_clean_symbol_list_for_file(src_rel_path, symbols_list) + -- Generate a random generation ID for all tne newly updated/refreshed items + local generation = random_8byte_int() + for _, item in ipairs(symbols_list) do + item.generation = generation + end + + -- Upsert all entries + local success = self:upsert_symbol_list(symbols_list) + if not success then + return success + end + + -- Remove all symbols in the file that does not hav the new generation ID + -- Those symbols are not present in the newly generated list and should be removed. + success = self.db:eval([[DELETE from symbols WHERE file = ? and generation != ? ]], { src_rel_path, generation }) + if not success then + logger.error("[db.insert_and_clean_symbol_list_for_file] Unable to clean up garbage") + return success + end + return true end @@ -224,7 +277,7 @@ end --- Note that this function early terminates if any of the entry upsert fails. --- This behavior is only suitable when run inside a transaction. --- @param symbols_list SymbolDefEntry[] -function Db:insert_symbol_list(symbols_list) +function Db:upsert_symbol_list(symbols_list) for _, def in ipairs(symbols_list) do local success = self:upsert_symbol(def) if not success then @@ -280,7 +333,11 @@ function Db:find_symbol_by_file_n_name(rel_path, full_fn_name) ---@cast result SymbolDefEntry[] if #result > 1 then logger.error( - string.format("[Db.find_symbol_by_file_n_name] Found more than 1 result for: '%s', '%s'", rel_path, full_fn_name) + string.format( + "[Db.find_symbol_by_file_n_name] Found more than 1 result for: '%s', '%s'", + rel_path, + full_fn_name + ) ) end From 300bd0a4c1c958b6ab8f3c1ed1a78c09e6e37be4 Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Sat, 3 Aug 2024 15:21:35 +0800 Subject: [PATCH 24/34] Only attach completeion source to the chat buffer instead of *.md --- after/plugin/gp.lua | 8 +------- lua/gp/completion.lua | 41 ++++++++++------------------------------- lua/gp/init.lua | 2 ++ 3 files changed, 13 insertions(+), 38 deletions(-) diff --git a/after/plugin/gp.lua b/after/plugin/gp.lua index 33d64b99..535e6292 100644 --- a/after/plugin/gp.lua +++ b/after/plugin/gp.lua @@ -1,7 +1 @@ -print("in after/plugin/gp.lua") -local completion = require("gp.completion") - -print(vim.inspect(completion)) - -completion.register_cmd_source() -print("done after/plugin/gp.lua") +require("gp.completion").register_cmd_source() diff --git a/lua/gp/completion.lua b/lua/gp/completion.lua index 4619ce3d..c055d841 100644 --- a/lua/gp/completion.lua +++ b/lua/gp/completion.lua @@ -35,43 +35,24 @@ function source.get_trigger_characters() return { "@", ":", "/" } end +-- Attaches the completion source to the given `bufnr` function source.setup_for_buffer(bufnr) - print("in setup_for_buffer") - local config = require("cmp").get_config() - - print("cmp.get_config() returned:") - print(vim.inspect(config)) + -- Don't attach the completion source if it's already been done + local attached_varname = "gp_source_attached" + if vim.b[attached_varname] then + return + end - print("cmp_config.set_buffer: " .. config.set_buffer) + -- Attach the completion source + local config = require("cmp.config") config.set_buffer({ sources = { { name = source.src_name }, }, }, bufnr) -end -function source.setup_autocmd_for_markdown() - print("setting up autocmd...") - vim.api.nvim_create_autocmd("BufEnter", { - pattern = { "*.md", "markdown" }, - callback = function(arg) - local attached_varname = "gp_source_attached" - local attached = buf_get_var(arg.buf, attached_varname, false) - if attached then - return - end - - print("attaching completion source for buffer: " .. arg.buf) - local cmp = require("cmp") - cmp.setup.buffer({ - sources = cmp.config.sources({ - { name = source.src_name }, - }), - }) - - buf_set_var(arg.buf, attached_varname, true) - end, - }) + -- Set a flag so we don't try to set the source again + vim.b[attached_varname] = true end function source.register_cmd_source() @@ -265,6 +246,4 @@ function source:execute(item, callback) callback() end -source.setup_autocmd_for_markdown() - return source diff --git a/lua/gp/init.lua b/lua/gp/init.lua index e6032571..850a970c 100644 --- a/lua/gp/init.lua +++ b/lua/gp/init.lua @@ -2105,6 +2105,7 @@ M.cmd.ChatNew = function(params, system_prompt, agent) end require("gp.context").build_initial_index() + require("gp.completion").setup_for_buffer(buf) return buf end @@ -2141,6 +2142,7 @@ M.cmd.ChatToggle = function(params, system_prompt, agent) end require("gp.context").build_initial_index() + require("gp.completion").setup_for_buffer(buf) return buf end From 52e3ae5d772bc534df698942ff430d2da9110c2e Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Sat, 3 Aug 2024 15:28:10 +0800 Subject: [PATCH 25/34] Fixes broken @file when cmd_split was rewritten --- lua/gp/completion.lua | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/lua/gp/completion.lua b/lua/gp/completion.lua index c055d841..1baa6aa5 100644 --- a/lua/gp/completion.lua +++ b/lua/gp/completion.lua @@ -172,7 +172,7 @@ function source.complete(self, request, callback) if cmd_parts[1]:match("@file") then -- What's the path we're trying to provide completion for? - local path = cmd_parts[2] + local path = cmd_parts[2] or "" -- List the items in the specified directory items = completion_items_for_path(path) @@ -181,10 +181,7 @@ function source.complete(self, request, callback) -- cmp won't call us again to provide an updated list isIncomplete = false elseif cmd_parts[1]:match("@code") then - local partial_fn_name = cmd_parts[2] - if not partial_fn_name then - partial_fn_name = "" - end + local partial_fn_name = cmd_parts[2] or "" -- When the user confirms completion of an item, we alter the -- command to look like `@code:path/to/file:fn_name` to uniquely From 89d5d6ea096a7a95e1eacfdca5be2357747317e5 Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Sat, 3 Aug 2024 15:45:39 +0800 Subject: [PATCH 26/34] Adds cmd RpRebuildIndex --- lua/gp/init.lua | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lua/gp/init.lua b/lua/gp/init.lua index 850a970c..67aee0fc 100644 --- a/lua/gp/init.lua +++ b/lua/gp/init.lua @@ -923,6 +923,10 @@ M.setup = function(opts) end end + vim.api.nvim_create_user_command("GpRebuildIndex", function(_) + require("gp.context").index_all() + end, {}) + M.buf_handler() if vim.fn.executable("curl") == 0 then From e63c68eb487ce91a563ceec33119305242533a7e Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Sat, 3 Aug 2024 21:26:27 +0800 Subject: [PATCH 27/34] Can now re-index changed files only We're also now discarding old src_files entries that no longer exist on disk --- lua/gp/context.lua | 109 +++++++++++++++++++++++++++++++++++++++------ lua/gp/db.lua | 19 ++++---- lua/gp/utils.lua | 4 ++ 3 files changed, 108 insertions(+), 24 deletions(-) diff --git a/lua/gp/context.lua b/lua/gp/context.lua index c55002a2..d0d02f3d 100644 --- a/lua/gp/context.lua +++ b/lua/gp/context.lua @@ -336,7 +336,8 @@ end ---@param db Db ---@param src_filepath string -function Context.build_symbol_index_for_file(db, src_filepath) +---@param generation? number +function Context.build_symbol_index_for_file(db, src_filepath, generation) -- try to retrieve function definitions from the file local symbols_list = Context.treesitter_extract_function_defs(src_filepath) if not symbols_list then @@ -350,6 +351,7 @@ function Context.build_symbol_index_for_file(db, src_filepath) return false end src_file_entry.last_scan_time = os.time() + src_file_entry.generation = generation -- Update the src file entry and the function definitions in a single transaction local result = db:with_transaction(function() @@ -370,6 +372,23 @@ function Context.build_symbol_index_for_file(db, src_filepath) return result end +local function default_ignore_dirs_and_files(entry, rel_path, full_path, is_dir) + if u.string_starts_with(entry, ".") then + return false + end + + if is_dir then + if entry == "node_modules" then + return false + end + else + if u.string_ends_with(entry, ".txt") or u.string_ends_with(entry, ".md") then + return false + end + end + return true +end + function Context.build_symbol_index(db) local git_root = u.git_root_from_cwd() if not git_root then @@ -377,33 +396,79 @@ function Context.build_symbol_index(db) return false end + local generation = u.random_8byte_int() + u.walk_directory(git_root, { - should_process = function(entry, rel_path, full_path, is_dir) - if u.string_starts_with(entry, ".") then - return false - end + should_process = default_ignore_dirs_and_files, - if is_dir then - if entry == "node_modules" then - return false - end - else - if u.string_ends_with(entry, ".txt") or u.string_ends_with(entry, ".md") then - return false + process_file = function(rel_path, full_path) + if vim.filetype.match({ filename = full_path }) then + local success = Context.build_symbol_index_for_file(db, rel_path, generation) + if not success then + logger.debug("Failed to build function def index for: " .. rel_path) end end - return true end, + }) + + db.db:eval([[DELETE FROM src_files WHERE generation != ?]], { generation }) +end + +local ChangeResult = { + UNCHANGED = 0, + CHANGED = 1, + NOT_IN_LAST_SCAN = 2, +} + +-- Answers if the gien file seem to have changed since last scan +---@param db Db +---@param rel_path string +local function file_changed_since_last_scan(db, rel_path) + local cur = Db.collect_src_file_data(rel_path) + assert(cur) + + ---@type boolean|SrcFileEntry + local prev = db.db:eval([[SELECT * from src_files WHERE filename = ?]], { rel_path }) + if not prev then + return ChangeResult.NOT_IN_LAST_SCAN + end + + if cur.mod_time > prev.mod_time or cur.file_size ~= prev.file_size then + return ChangeResult.CHANGED + end + + return ChangeResult.UNCHANGED +end + +function Context.rebuild_symbol_index_for_changed_files(db) + local git_root = u.git_root_from_cwd() + if not git_root then + logger.error("[Context.build_symbol_index] Unable to locate project root") + return false + end + + local generation = u.random_8byte_int() + + u.walk_directory(git_root, { + should_process = default_ignore_dirs_and_files, process_file = function(rel_path, full_path) if vim.filetype.match({ filename = full_path }) then - local success = Context.build_symbol_index_for_file(db, rel_path) + local status = file_changed_since_last_scan(db, rel_path) + if status == ChangeResult.UNCHANGED then + -- Even if the file did not change, we still want to mark the entry with the current generation ID + db.db:eval([[UPDATE src_files SET generation = ? WHERE filename = ?]], { generation, rel_path }) + return + end + local success = Context.build_symbol_index_for_file(db, rel_path, generation) if not success then logger.debug("Failed to build function def index for: " .. rel_path) end end end, }) + + db.db:eval([[DELETE FROM src_files WHERE generation != ?]], { generation }) end function Context.index_single_file(src_filepath) @@ -415,6 +480,22 @@ function Context.index_single_file(src_filepath) db:close() end +function Context.index_stale() + local uv = vim.uv or vim.loop + local start_time = uv.hrtime() + + local db = Db.open() + if not db then + return + end + Context.rebuild_symbol_index_for_changed_files(db) + db:close() + + local end_time = uv.hrtime() + local elapsed_time_ms = (end_time - start_time) / 1e6 + logger.info(string.format("[Gp] Indexing took: %.2f ms", elapsed_time_ms)) +end + function Context.index_all() local uv = vim.uv or vim.loop local start_time = uv.hrtime() diff --git a/lua/gp/db.lua b/lua/gp/db.lua index 7637c9c3..77d6b978 100644 --- a/lua/gp/db.lua +++ b/lua/gp/db.lua @@ -12,7 +12,7 @@ local logger = require("gp.logger") ---@field filetype string: filetype as reported by neovim at last scan ---@field mod_time number: last file modification time reported by the os at last scan ---@field last_scan_time number: unix time stamp indicating when the last scan of this file was made ----@field generation number: For internal use - garbage collection +---@field generation? number: For internal use - garbage collection -- Describes where each of the functions are in the project ---@class SymbolDefEntry @@ -22,7 +22,7 @@ local logger = require("gp.logger") ---@field type string: type of the symbol ---@field start_line number: Which line in the file does the definition start? ---@field end_line number: Which line in the file does the definition end? ----@field generation number: For internal use - garbage collection +---@field generation? number: For internal use - garbage collection ---@class Db ---@field db sqlite_db @@ -50,6 +50,7 @@ function Db.open() return nil end + ---@type sqlite_db local db = sqlite({ uri = db_file, @@ -125,13 +126,14 @@ function Db:upsert_src_file(file) end local sql = [[ - INSERT INTO src_files (filename, file_size, filetype, mod_time, last_scan_time) - VALUES (?, ?, ?, ?, ?) + INSERT INTO src_files (filename, file_size, filetype, mod_time, last_scan_time, generation) + VALUES (?, ?, ?, ?, ?, ?) ON CONFLICT(filename) DO UPDATE SET file_size = excluded.file_size, filetype = excluded.filetype, mod_time = excluded.mod_time, - last_scan_time = excluded.last_scan_time + last_scan_time = excluded.last_scan_time, + generation = excluded.generation WHERE filename = ? ]] @@ -142,6 +144,7 @@ function Db:upsert_src_file(file) file.filetype, file.mod_time, file.last_scan_time, + file.generation or -1, -- For the WHERE claue file.filename, @@ -244,14 +247,10 @@ function Db:with_transaction(fn) return true end -local function random_8byte_int() - return math.random(0, 0xFFFFFFFFFFFFFFFF) -end - --- @param symbols_list SymbolDefEntry[] function Db:upsert_and_clean_symbol_list_for_file(src_rel_path, symbols_list) -- Generate a random generation ID for all tne newly updated/refreshed items - local generation = random_8byte_int() + local generation = u.random_8byte_int() for _, item in ipairs(symbols_list) do item.generation = generation end diff --git a/lua/gp/utils.lua b/lua/gp/utils.lua index 9e9f4a14..2bdfd21f 100644 --- a/lua/gp/utils.lua +++ b/lua/gp/utils.lua @@ -210,4 +210,8 @@ function Utils.sort_by(key_fn, tbl) return tbl end +function Utils.random_8byte_int() + return math.random(0, 0xFFFFFFFFFFFFFFFF) +end + return Utils From 15cc8030efb7a16a7c6c589d34d314be1ea2c0b6 Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Sat, 3 Aug 2024 22:53:20 +0800 Subject: [PATCH 28/34] Periodically rebuild symbols for stale files --- lua/gp/completion.lua | 17 +---------------- lua/gp/context.lua | 27 +++++++++++++++++++++++++-- lua/gp/init.lua | 6 ++---- lua/gp/utils.lua | 15 +++++++++++++++ 4 files changed, 43 insertions(+), 22 deletions(-) diff --git a/lua/gp/completion.lua b/lua/gp/completion.lua index 1baa6aa5..96180e7c 100644 --- a/lua/gp/completion.lua +++ b/lua/gp/completion.lua @@ -3,21 +3,6 @@ local context = require("gp.context") local db = require("gp.db") local cmp = require("cmp") --- Gets a buffer variable or returns the default -local function buf_get_var(buf, var_name, default) - local status, result = pcall(vim.api.nvim_buf_get_var, buf, var_name) - if status then - return result - else - return default - end -end - --- This function is only here make the get/set call pair look consistent -local function buf_set_var(buf, var_name, value) - return vim.api.nvim_buf_set_var(buf, var_name, value) -end - ---@class CompletionSource ---@field db Db local source = {} @@ -36,7 +21,7 @@ function source.get_trigger_characters() end -- Attaches the completion source to the given `bufnr` -function source.setup_for_buffer(bufnr) +function source.setup_for_chat_buffer(bufnr) -- Don't attach the completion source if it's already been done local attached_varname = "gp_source_attached" if vim.b[attached_varname] then diff --git a/lua/gp/context.lua b/lua/gp/context.lua index d0d02f3d..83b6777d 100644 --- a/lua/gp/context.lua +++ b/lua/gp/context.lua @@ -527,8 +527,25 @@ function Context.build_initial_index() db:close() end +function Context.setup_autocmd_update_index_periodically(bufnr) + local rebuild_time_var = "gp_next_rebuild_time" + local rebuild_period = 60 + u.buf_set_var(bufnr, rebuild_time_var, os.time() + rebuild_period) + + vim.api.nvim_create_autocmd("BufEnter", { + buffer = bufnr, + callback = function(arg) + local build_time = u.buf_get_var(arg.buf, rebuild_time_var) + if os.time() > build_time then + Context.index_stale() + u.buf_set_var(arg.buf, rebuild_time_var, os.time() + rebuild_period) + end + end, + }) +end + -- Setup autocommand to update the function def index as the files are saved -function Context.setup_autocmd_update_index() +function Context.setup_autocmd_update_index_on_file_save() vim.api.nvim_create_autocmd("BufWritePost", { pattern = { "*" }, group = vim.api.nvim_create_augroup("GpFileIndexUpdate", { clear = true }), @@ -538,6 +555,12 @@ function Context.setup_autocmd_update_index() }) end -Context.setup_autocmd_update_index() +function Context.setup_for_chat_buffer(buf) + Context.build_initial_index() + Context.setup_autocmd_update_index_periodically(buf) + require("gp.completion").setup_for_chat_buffer(buf) +end + +Context.setup_autocmd_update_index_on_file_save() return Context diff --git a/lua/gp/init.lua b/lua/gp/init.lua index 67aee0fc..5cfac3e0 100644 --- a/lua/gp/init.lua +++ b/lua/gp/init.lua @@ -2108,8 +2108,7 @@ M.cmd.ChatNew = function(params, system_prompt, agent) buf = M.new_chat(params, false, system_prompt, agent) end - require("gp.context").build_initial_index() - require("gp.completion").setup_for_buffer(buf) + require("gp.context").setup_for_chat_buffer(buf) return buf end @@ -2145,8 +2144,7 @@ M.cmd.ChatToggle = function(params, system_prompt, agent) buf = M.new_chat(params, true, system_prompt, agent) end - require("gp.context").build_initial_index() - require("gp.completion").setup_for_buffer(buf) + require("gp.context").setup_for_chat_buffer(buf) return buf end diff --git a/lua/gp/utils.lua b/lua/gp/utils.lua index 2bdfd21f..6b9a226e 100644 --- a/lua/gp/utils.lua +++ b/lua/gp/utils.lua @@ -214,4 +214,19 @@ function Utils.random_8byte_int() return math.random(0, 0xFFFFFFFFFFFFFFFF) end +-- Gets a buffer variable or returns the default +function Utils.buf_get_var(buf, var_name, default) + local status, result = pcall(vim.api.nvim_buf_get_var, buf, var_name) + if status then + return result + else + return default + end +end + +-- This function is only here make the get/set call pair look consistent +function Utils.buf_set_var(buf, var_name, value) + return vim.api.nvim_buf_set_var(buf, var_name, value) +end + return Utils From 1951cad3ae5073bbd306938a8f1003067c46ae62 Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Sun, 4 Aug 2024 14:22:33 +0800 Subject: [PATCH 29/34] Directory walk now respects .gitignore --- lua/gp/context.lua | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/lua/gp/context.lua b/lua/gp/context.lua index 83b6777d..c1688a11 100644 --- a/lua/gp/context.lua +++ b/lua/gp/context.lua @@ -372,21 +372,19 @@ function Context.build_symbol_index_for_file(db, src_filepath, generation) return result end -local function default_ignore_dirs_and_files(entry, rel_path, full_path, is_dir) - if u.string_starts_with(entry, ".") then - return false +local function make_gitignore_fn(git_root) + local base_paths = { git_root } + local allow = require("plenary.scandir").__make_gitignore(base_paths) + if not allow then + return nil end - if is_dir then - if entry == "node_modules" then - return false - end - else - if u.string_ends_with(entry, ".txt") or u.string_ends_with(entry, ".md") then + return function(entry, rel_path, full_path, is_dir) + if entry == ".git" or entry == ".github" then return false end + return allow(base_paths, full_path) end - return true end function Context.build_symbol_index(db) @@ -399,7 +397,7 @@ function Context.build_symbol_index(db) local generation = u.random_8byte_int() u.walk_directory(git_root, { - should_process = default_ignore_dirs_and_files, + should_process = make_gitignore_fn(git_root), process_file = function(rel_path, full_path) if vim.filetype.match({ filename = full_path }) then @@ -450,7 +448,7 @@ function Context.rebuild_symbol_index_for_changed_files(db) local generation = u.random_8byte_int() u.walk_directory(git_root, { - should_process = default_ignore_dirs_and_files, + should_process = make_gitignore_fn(git_root), process_file = function(rel_path, full_path) if vim.filetype.match({ filename = full_path }) then From 89e1545835f9d5d5adf404220c6eb9ca3b144dac Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Sun, 4 Aug 2024 15:28:04 +0800 Subject: [PATCH 30/34] Cleans up prints used for debugging --- lua/gp/completion.lua | 7 ------- lua/gp/context.lua | 4 ++-- lua/gp/init.lua | 2 +- 3 files changed, 3 insertions(+), 10 deletions(-) diff --git a/lua/gp/completion.lua b/lua/gp/completion.lua index 96180e7c..5c22c9a6 100644 --- a/lua/gp/completion.lua +++ b/lua/gp/completion.lua @@ -11,7 +11,6 @@ source.src_name = "gp_completion" ---@return CompletionSource function source.new() - print("source.new called") local db_inst = db.open() return setmetatable({ db = db_inst }, { __index = source }) end @@ -41,7 +40,6 @@ function source.setup_for_chat_buffer(bufnr) end function source.register_cmd_source() - print("registering completion src") cmp.register_source(source.src_name, source.new()) end @@ -143,13 +141,11 @@ end function source.complete(self, request, callback) local input = string.sub(request.context.cursor_before_line, request.offset - 1) - print("[comp] input: '" .. input .. "'") local cmd = extract_cmd(request) if not cmd then return end - print("[comp] cmd: '" .. cmd .. "'") local cmd_parts = context.cmd_split(cmd) local items = {} @@ -181,20 +177,17 @@ function source.complete(self, request, callback) items = self:completion_items_for_fn_name(partial_fn_name) isIncomplete = false elseif input:match("^@") then - print("[complete] @ case") items = { { label = "code", kind = require("cmp").lsp.CompletionItemKind.Keyword }, { label = "file", kind = require("cmp").lsp.CompletionItemKind.Keyword }, } isIncomplete = false else - print("[complete] default case") isIncomplete = false end local data = { items = items, isIncomplete = isIncomplete } callback(data) - print("[complete] Callback called") end local function search_backwards(buf, pattern) diff --git a/lua/gp/context.lua b/lua/gp/context.lua index c1688a11..187e4e6f 100644 --- a/lua/gp/context.lua +++ b/lua/gp/context.lua @@ -491,7 +491,7 @@ function Context.index_stale() local end_time = uv.hrtime() local elapsed_time_ms = (end_time - start_time) / 1e6 - logger.info(string.format("[Gp] Indexing took: %.2f ms", elapsed_time_ms)) + logger.info(string.format("Indexing took: %.2f ms", elapsed_time_ms)) end function Context.index_all() @@ -507,7 +507,7 @@ function Context.index_all() local end_time = uv.hrtime() local elapsed_time_ms = (end_time - start_time) / 1e6 - logger.info(string.format("[Gp] Indexing took: %.2f ms", elapsed_time_ms)) + logger.info(string.format("Indexing took: %.2f ms", elapsed_time_ms)) end function Context.build_initial_index() diff --git a/lua/gp/init.lua b/lua/gp/init.lua index 5cfac3e0..8cdd6fd7 100644 --- a/lua/gp/init.lua +++ b/lua/gp/init.lua @@ -2366,7 +2366,7 @@ M.chat_respond = function(params) -- insert requested context in the message the user just entered messages[#messages].content = require("gp.context").insert_contexts(messages[#messages].content) - print(vim.inspect(messages[#messages])) + -- print(vim.inspect(messages[#messages])) -- call the model and write response M.query( From d8a16519ed55d4edfa5d29ab03c417018eb62227 Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Sun, 4 Aug 2024 15:39:19 +0800 Subject: [PATCH 31/34] Adds @include command Behaves exactly like @file, but does not inject the file name or a backtick fence around the contents --- lua/gp/completion.lua | 6 ++++-- lua/gp/context.lua | 15 ++++++++++++--- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/lua/gp/completion.lua b/lua/gp/completion.lua index 5c22c9a6..12ab28a1 100644 --- a/lua/gp/completion.lua +++ b/lua/gp/completion.lua @@ -150,8 +150,9 @@ function source.complete(self, request, callback) local items = {} local isIncomplete = true + local cmd_type = cmd_parts[1] - if cmd_parts[1]:match("@file") then + if cmd_type:match("@file") or cmd_type:match("@include") then -- What's the path we're trying to provide completion for? local path = cmd_parts[2] or "" @@ -161,7 +162,7 @@ function source.complete(self, request, callback) -- Say that the entire list has been provided -- cmp won't call us again to provide an updated list isIncomplete = false - elseif cmd_parts[1]:match("@code") then + elseif cmd_type:match("@code") then local partial_fn_name = cmd_parts[2] or "" -- When the user confirms completion of an item, we alter the @@ -180,6 +181,7 @@ function source.complete(self, request, callback) items = { { label = "code", kind = require("cmp").lsp.CompletionItemKind.Keyword }, { label = "file", kind = require("cmp").lsp.CompletionItemKind.Keyword }, + { label = "include", kind = require("cmp").lsp.CompletionItemKind.Keyword }, } isIncomplete = false else diff --git a/lua/gp/context.lua b/lua/gp/context.lua index 187e4e6f..e797eda5 100644 --- a/lua/gp/context.lua +++ b/lua/gp/context.lua @@ -98,6 +98,9 @@ function Context.insert_contexts(msg) for cmd in msg:gmatch("@file:[%w%p]+") do table.insert(cmds, cmd) end + for cmd in msg:gmatch("@include:[%w%p]+") do + table.insert(cmds, cmd) + end for cmd in msg:gmatch("@code:[%w%p]+[:%w_-]+") do table.insert(cmds, cmd) end @@ -108,8 +111,9 @@ function Context.insert_contexts(msg) -- inserted as additional context for _, cmd in ipairs(cmds) do local cmd_parts = Context.cmd_split(cmd) + local cmd_type = cmd_parts[1] - if cmd_parts[1] == "@file" then + if cmd_type == "@file" or cmd_type == "@include" then -- Read the reqested file and produce a msg snippet to be joined later local filepath = cmd_parts[2] @@ -118,10 +122,15 @@ function Context.insert_contexts(msg) local content = read_file(fullpath) if content then - local result = string.format("%s\n```%s```", filepath, content) + local result + if cmd_type == "@file" then + result = string.format("%s\n```%s```", filepath, content) + else + result = content + end table.insert(context_texts, result) end - elseif cmd_parts[1] == "@code" then + elseif cmd_type == "@code" then local rel_path = cmd_parts[2] local full_fn_name = cmd_parts[3] if not rel_path or not full_fn_name then From dbc7df8721a92dd3ffcb40bd3f0b8a1abcb530dd Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Sun, 4 Aug 2024 16:35:01 +0800 Subject: [PATCH 32/34] Fixes bug where .git/* is indexed when .gitignore does not exist --- lua/gp/context.lua | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lua/gp/context.lua b/lua/gp/context.lua index e797eda5..f31d689c 100644 --- a/lua/gp/context.lua +++ b/lua/gp/context.lua @@ -384,15 +384,15 @@ end local function make_gitignore_fn(git_root) local base_paths = { git_root } local allow = require("plenary.scandir").__make_gitignore(base_paths) - if not allow then - return nil - end return function(entry, rel_path, full_path, is_dir) if entry == ".git" or entry == ".github" then return false end - return allow(base_paths, full_path) + if allow then + return allow(base_paths, full_path) + end + return true end end From bd30d7b8bbcf9cb0a99067f7ab3f82117bc4f451 Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Mon, 5 Aug 2024 17:06:46 +0800 Subject: [PATCH 33/34] Adds "GpReferenceFunction" to add a @code command for the function under the cursor This simplifies sending the function under the cursor as the chat context. --- lua/gp/context.lua | 33 +++++++++++++++++++++ lua/gp/init.lua | 72 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+) diff --git a/lua/gp/context.lua b/lua/gp/context.lua index f31d689c..000692ab 100644 --- a/lua/gp/context.lua +++ b/lua/gp/context.lua @@ -568,6 +568,39 @@ function Context.setup_for_chat_buffer(buf) require("gp.completion").setup_for_chat_buffer(buf) end +-- Inserts the reference to the function under the cursor into the chat buffer +function Context.reference_current_function() + local db = Db.open() + if not db then + return + end + + local buf = vim.api.nvim_get_current_buf() + local rel_path = vim.fn.bufname(buf) + local lineno = math.max(vim.api.nvim_win_get_cursor(0)[1] - 1, 0) + + ---@type boolean|SymbolDefEntry + local res = db.db:eval( + [[ SELECT * from symbols + WHERE + file = ? AND + start_line <= ? AND + end_line >= ? ]], + { rel_path, lineno, lineno } + ) + + db:close() + + if type(res) == "boolean" then + logger.error("[context.reference_current_function] Symbol lookup returned unexpected value: " .. res) + return + end + + local entry = res[1] + + require("gp").chat_paste(string.format("@code:%s:%s", entry.file, entry.name)) +end + Context.setup_autocmd_update_index_on_file_save() return Context diff --git a/lua/gp/init.lua b/lua/gp/init.lua index 8e8b0c89..6174fb2b 100644 --- a/lua/gp/init.lua +++ b/lua/gp/init.lua @@ -212,6 +212,10 @@ M.setup = function(opts) require("gp.context").index_all() end, {}) + vim.api.nvim_create_user_command("GpReferenceCurrentFunction", function(_) + require("gp.context").reference_current_function() + end, {}) + M.buf_handler() if vim.fn.executable("curl") == 0 then @@ -880,6 +884,74 @@ M.cmd.ChatToggle = function(params, system_prompt, agent) return buf end +local function win_for_buf(bufnr) + for _, w in ipairs(vim.api.nvim_list_wins()) do + if vim.api.nvim_win_get_buf(w) == bufnr then + return w + end + end +end + +local function create_buffer_with_file(file_path) + -- Create a new buffer + local bufnr = vim.api.nvim_create_buf(true, false) + + -- Set the buffer's name to the file path + vim.api.nvim_buf_set_name(bufnr, file_path) + + -- Load the file into the buffer + vim.api.nvim_buf_call(bufnr, function() + vim.api.nvim_command("edit " .. vim.fn.fnameescape(file_path)) + end) + + return bufnr +end + +-- Paste some content into the chat buffer +M.chat_paste = function(content) + -- locate the chat buffer + local chat_buf + local last = M._state.last_chat + + ------------------------------------------------ + -- Try to locate or setup a valid chat buffer -- + ------------------------------------------------ + -- If don't have a record of the last chat file that's been opened... + -- Just create a new chat + if not last or vim.fn.filereadable(last) ~= 1 then + chat_buf = M.cmd.ChatNew({}, nil, nil) + else + -- We have a record of the last chat file... + -- Can we locate a buffer with the file loaded? + last = vim.fn.resolve(last) + chat_buf = M.helpers.get_buffer(last) + + if not chat_buf then + chat_buf = create_buffer_with_file(last) + end + end + + -------------------------------------------- + -- Paste the content into the chat buffer -- + -------------------------------------------- + if chat_buf then + -- Paste the given `content` at the end of the buffer + vim.api.nvim_buf_set_lines(chat_buf, -1, -1, false, { content }) + + -- If we can locate a window for the buffer... + -- Set the cursor to the end of the file where we just pasted the content + local win = win_for_buf(chat_buf) + if win then + local line_count = vim.api.nvim_buf_line_count(chat_buf) + vim.api.nvim_win_set_cursor(win, { line_count, 0 }) + + vim.api.nvim_win_call(win, function() + vim.api.nvim_command("normal! zz") + end) + end + end +end + M.cmd.ChatPaste = function(params) -- if there is no selection, do nothing if params.range ~= 2 then From 24a2cae722c293181edce3957aaffc73d1b5032c Mon Sep 17 00:00:00 2001 From: Jonathan Shieh Date: Fri, 9 Aug 2024 11:57:47 +0800 Subject: [PATCH 34/34] Adds "GpReferenceCurrentFile" command to add a @file command the file/buffer the user is examining. --- lua/gp/context.lua | 6 ++++++ lua/gp/init.lua | 4 ++++ 2 files changed, 10 insertions(+) diff --git a/lua/gp/context.lua b/lua/gp/context.lua index 000692ab..53ea5cd5 100644 --- a/lua/gp/context.lua +++ b/lua/gp/context.lua @@ -601,6 +601,12 @@ function Context.reference_current_function() require("gp").chat_paste(string.format("@code:%s:%s", entry.file, entry.name)) end +function Context.reference_current_file() + local buf = vim.api.nvim_get_current_buf() + local rel_path = vim.fn.bufname(buf) + require("gp").chat_paste(string.format("@file:%s", rel_path)) +end + Context.setup_autocmd_update_index_on_file_save() return Context diff --git a/lua/gp/init.lua b/lua/gp/init.lua index 6174fb2b..f4d1e43d 100644 --- a/lua/gp/init.lua +++ b/lua/gp/init.lua @@ -216,6 +216,10 @@ M.setup = function(opts) require("gp.context").reference_current_function() end, {}) + vim.api.nvim_create_user_command("GpReferenceCurrentFile", function(_) + require("gp.context").reference_current_file() + end, {}) + M.buf_handler() if vim.fn.executable("curl") == 0 then