diff --git a/lua/jumpy/init.lua b/lua/jumpy/init.lua index a24b9b5..bb8513d 100644 --- a/lua/jumpy/init.lua +++ b/lua/jumpy/init.lua @@ -24,6 +24,17 @@ M.config = { "- Do NOT wrap in markdown code fences", "- Do NOT explain", }, "\n"), + system_prompt_multi_file = table.concat({ + "When multiple files are provided, prefix the SEARCH marker with the file path:", + "<<<< SEARCH path/to/file.lua", + "exact existing lines from that file", + "====", + "replacement lines", + ">>>> REPLACE", + "", + "The path must exactly match the path shown in the --- FILE: ... --- header.", + "You may edit any subset of the provided files. Every SEARCH block MUST include a path.", + }, "\n"), keymaps = { prompt = "j", next_hunk = "]h", diff --git a/lua/jumpy/llm.lua b/lua/jumpy/llm.lua index a6bb789..2e50e5a 100644 --- a/lua/jumpy/llm.lua +++ b/lua/jumpy/llm.lua @@ -4,9 +4,32 @@ local function get_config() return require("jumpy").config end +local function build_file_block(path, contents) + return string.format("--- FILE: %s ---\n%s\n--- END FILE ---", path, contents) +end + local function build_messages(context) local config = get_config() + local tagged = context.tagged_files + if tagged and #tagged > 0 then + local parts = {} + for _, file in ipairs(tagged) do + table.insert(parts, build_file_block(file.path, table.concat(file.lines, "\n"))) + end + if context.symbols and context.symbols ~= "" then + table.insert(parts, context.symbols) + end + table.insert(parts, "") + table.insert(parts, "Instruction: " .. context.prompt) + local user_content = table.concat(parts, "\n") + local system = config.system_prompt .. "\n\n" .. config.system_prompt_multi_file + return { + { role = "system", content = system }, + { role = "user", content = user_content }, + } + end + local user_content = string.format( "File type: %s\n\n--- FILE CONTENTS ---\n%s\n--- END FILE ---%s\n\nInstruction: %s", context.filetype or "text", diff --git a/lua/jumpy/patch.lua b/lua/jumpy/patch.lua index 5c55ecf..6390796 100644 --- a/lua/jumpy/patch.lua +++ b/lua/jumpy/patch.lua @@ -35,13 +35,26 @@ local function find_lines(haystack, needle) return nil end +local function parse_search_marker(line) + local rest = line:match("^<<<< SEARCH%s*(.*)$") + if rest == nil then + return nil + end + rest = rest:match("^%s*(.-)%s*$") + if rest == "" then + return nil + end + return rest +end + function M.parse(text) local blocks = {} local lines = split_lines(text) local i = 1 while i <= #lines do - if lines[i]:match("^<<<< SEARCH%s*$") then + local path = parse_search_marker(lines[i]) + if path ~= nil or lines[i]:match("^<<<< SEARCH%s*$") then local search_lines = {} local replace_lines = {} i = i + 1 @@ -59,6 +72,7 @@ function M.parse(text) end table.insert(blocks, { + path = path, search = search_lines, replace = replace_lines, }) @@ -69,13 +83,7 @@ function M.parse(text) return blocks end -function M.apply(original_lines, response_text) - local blocks = M.parse(response_text) - - if #blocks == 0 then - return split_lines(response_text), 0 - end - +local function apply_blocks(original_lines, blocks) local lines = {} for _, l in ipairs(original_lines) do table.insert(lines, l) @@ -87,14 +95,14 @@ function M.apply(original_lines, response_text) local pos = find_lines(lines, block.search) if pos then local new = {} - for i = 1, pos - 1 do - table.insert(new, lines[i]) + for j = 1, pos - 1 do + table.insert(new, lines[j]) end for _, l in ipairs(block.replace) do table.insert(new, l) end - for i = pos + #block.search, #lines do - table.insert(new, lines[i]) + for j = pos + #block.search, #lines do + table.insert(new, lines[j]) end lines = new else @@ -105,4 +113,51 @@ function M.apply(original_lines, response_text) return lines, unmatched end +function M.apply(original_lines, response_text) + local blocks = M.parse(response_text) + + if #blocks == 0 then + return split_lines(response_text), 0 + end + + return apply_blocks(original_lines, blocks) +end + +function M.apply_by_file(files_by_path, response_text, primary_path) + assert(primary_path, "apply_by_file requires a primary_path") + + local blocks = M.parse(response_text) + + if #blocks == 0 then + if files_by_path[primary_path] then + local lines, unmatched = M.apply(files_by_path[primary_path], response_text) + return { [primary_path] = { lines = lines, unmatched = unmatched } }, unmatched + end + return {}, 0 + end + + local grouped = {} + for _, block in ipairs(blocks) do + local key = block.path or primary_path + grouped[key] = grouped[key] or {} + table.insert(grouped[key], block) + end + + local results = {} + local total_unmatched = 0 + + for path, file_blocks in pairs(grouped) do + local original = files_by_path[path] + if not original then + total_unmatched = total_unmatched + #file_blocks + else + local lines, unmatched = apply_blocks(original, file_blocks) + results[path] = { lines = lines, unmatched = unmatched } + total_unmatched = total_unmatched + unmatched + end + end + + return results, total_unmatched +end + return M diff --git a/lua/jumpy/prompt.lua b/lua/jumpy/prompt.lua index 12357ed..f181bc1 100644 --- a/lua/jumpy/prompt.lua +++ b/lua/jumpy/prompt.lua @@ -12,6 +12,21 @@ local state = { local mention_ns = vim.api.nvim_create_namespace("jumpy_mentions") +local function index_tagged_files(tagged_files) + local by_path = {} + for _, file in ipairs(tagged_files) do + by_path[file.path] = file + end + return by_path +end + +local function buffer_for_tagged_file(tags, file) + if file.bufnr and vim.api.nvim_buf_is_valid(file.bufnr) then + return file.bufnr + end + return tags.open_buffer(file.abs_path) +end + local function highlight_mentions(buf) if not vim.api.nvim_buf_is_valid(buf) then return @@ -181,12 +196,35 @@ function M._submit() end local source_buf = state.source_buf + local tags = require("jumpy.tags") local source_lines = state.visual_selection and vim.split(state.visual_selection.text, "\n", { plain = true }) or vim.api.nvim_buf_get_lines(source_buf, 0, -1, false) + local source_name = vim.api.nvim_buf_get_name(source_buf) + local source_rel = source_name ~= "" and tags.rel_path(source_name, tags.project_root()) or "current" + + local parsed = tags.parse(prompt_text, { + source = { + path = source_rel, + abs_path = source_name ~= "" and tags.normalize_abs(source_name) or nil, + lines = source_lines, + bufnr = source_buf, + }, + }) + + local cleaned_prompt = parsed.cleaned_prompt + local tagged_files = parsed.tagged + + if #parsed.errors > 0 then + for _, err in ipairs(parsed.errors) do + vim.notify("jumpy: " .. err, vim.log.levels.WARN) + end + end + local filetype = vim.bo[source_buf].filetype local reprompt_idx = state.reprompt_hunk_idx + local is_multi_file = #tagged_files > 1 local llm = require("jumpy.llm") @@ -213,7 +251,7 @@ function M._submit() local context = { original_lines = hunk.removed_lines, proposed_lines = hunk.added_lines, - prompt = prompt_text, + prompt = cleaned_prompt, symbols = symbols, filetype = filetype, } @@ -227,10 +265,70 @@ function M._submit() vim.notify("jumpy: hunk updated", vim.log.levels.INFO) end) end) + elseif is_multi_file then + local context = { + file_contents = table.concat(source_lines, "\n"), + tagged_files = tagged_files, + primary_path = source_rel, + prompt = cleaned_prompt, + symbols = symbols, + filetype = filetype, + } + + llm.request(context, function(response_text) + vim.schedule(function() + local diff = require("jumpy.diff") + local render = require("jumpy.render") + local patch = require("jumpy.patch") + + local files_by_path = {} + for _, file in ipairs(tagged_files) do + files_by_path[file.path] = file.lines + end + + local results, total_unmatched = patch.apply_by_file(files_by_path, response_text, source_rel) + + if total_unmatched > 0 then + vim.notify(string.format("jumpy: %d block(s) could not be matched", total_unmatched), vim.log.levels.WARN) + end + + local tagged_by_path = index_tagged_files(tagged_files) + local total_hunks = 0 + + for path, result in pairs(results) do + local file = tagged_by_path[path] + if file then + local bufnr = buffer_for_tagged_file(tags, file) + if bufnr then + local hunks = diff.compute(file.lines, result.lines) + if #hunks > 0 then + render.show(bufnr, hunks, file.lines, result.lines) + total_hunks = total_hunks + #hunks + end + else + vim.notify("jumpy: could not open " .. path .. ", skipping", vim.log.levels.WARN) + end + end + end + + if total_hunks == 0 then + vim.notify("jumpy: no changes proposed", vim.log.levels.INFO) + return + end + + vim.notify( + string.format("jumpy: %d hunk(s) proposed across %d file(s)", total_hunks, vim.tbl_count(results)), + vim.log.levels.INFO + ) + + local nav = require("jumpy.navigate") + nav.next_hunk() + end) + end) else local context = { file_contents = table.concat(source_lines, "\n"), - prompt = prompt_text, + prompt = cleaned_prompt, symbols = symbols, filetype = filetype, } @@ -278,7 +376,7 @@ function M._submit() end if prompt_text:find("@lsp") then - prompt_text = vim.trim(prompt_text:gsub("%f[%w@]@lsp%f[%W]", "")) + cleaned_prompt = vim.trim(cleaned_prompt:gsub("%f[%w@]@lsp%f[%W]", "")) context_tools.get_workspace_symbols(tonumber(source_buf) or 0, send_request) else diff --git a/lua/jumpy/tags.lua b/lua/jumpy/tags.lua new file mode 100644 index 0000000..7c666de --- /dev/null +++ b/lua/jumpy/tags.lua @@ -0,0 +1,286 @@ +local M = {} + +M.MAX_BYTES = 256 * 1024 +M.MAX_LINES = 2000 + +local RESERVED = { + lsp = true, +} + +local function word_boundary_before(text, pos) + if pos <= 1 then + return true + end + return not text:sub(pos - 1, pos - 1):match("[%w@]") +end + +local function word_boundary_after(text, pos) + if pos >= #text then + return true + end + return not text:sub(pos + 1, pos + 1):match("[%w]") +end + +local function normalize_mention_path(path) + path = path:gsub("/+$", "") + path = path:gsub("%.$", "") + return path +end + +local function mention_remove_len(raw) + local path = normalize_mention_path(raw) + if raw:sub(-1) == "." and #raw == #path + 1 then + return #path + end + return #raw +end + +function M.find_mentions(text) + local mentions = {} + local seen = {} + local search_from = 1 + + while search_from <= #text do + local at = text:find("@", search_from, true) + if not at then + break + end + + if word_boundary_before(text, at) then + local rest = text:sub(at + 1) + local raw = rest:match("^([%.%w%-_/]+)") + local path = raw and normalize_mention_path(raw) or nil + if path and path ~= "" and not RESERVED[path] and word_boundary_after(text, at + #raw) then + if not seen[path] then + seen[path] = true + table.insert(mentions, path) + end + search_from = at + #raw + 1 + else + search_from = at + 1 + end + else + search_from = at + 1 + end + end + + return mentions +end + +local function trim(text) + return (text:gsub("^%s+", ""):gsub("%s+$", "")) +end + +function M.strip_mentions(text) + local stripped = text + local search_from = 1 + + while search_from <= #stripped do + local at = stripped:find("@", search_from, true) + if not at then + break + end + + if word_boundary_before(stripped, at) then + local rest = stripped:sub(at + 1) + local raw = rest:match("^([%.%w%-_/]+)") + local path = raw and normalize_mention_path(raw) or nil + if path and path ~= "" and not RESERVED[path] and word_boundary_after(stripped, at + #raw) then + local remove_len = mention_remove_len(raw) + stripped = stripped:sub(1, at - 1) .. stripped:sub(at + remove_len + 1) + search_from = at + else + search_from = at + 1 + end + else + search_from = at + 1 + end + end + + return trim((stripped:gsub("%s+", " "))) +end + +function M.normalize_abs(path) + if vim and vim.fn and vim.fn.fnamemodify then + path = vim.fn.fnamemodify(path, ":p") + end + if path:sub(-1) == "/" then + path = path:sub(1, -2) + end + return path +end + +function M.resolve_path(raw_path, root) + root = M.normalize_abs(root or (vim and vim.fn and vim.fn.getcwd() or ".")) + if raw_path:sub(1, 1) == "/" then + return M.normalize_abs(raw_path) + end + return M.normalize_abs(root .. "/" .. raw_path) +end + +function M.rel_path(abs_path, root) + abs_path = M.normalize_abs(abs_path) + root = M.normalize_abs(root) + local prefix = root .. "/" + if abs_path:sub(1, #prefix) == prefix then + return abs_path:sub(#prefix + 1) + end + return abs_path +end + +local function slice_lines(lines, count) + local out = {} + for i = 1, math.min(count, #lines) do + out[i] = lines[i] + end + return out +end + +function M.truncate_lines(lines) + local truncated = false + if #lines > M.MAX_LINES then + lines = slice_lines(lines, M.MAX_LINES) + truncated = true + end + return lines, truncated +end + +function M.project_root() + local cwd = vim.fn.getcwd() + if vim.system then + local result = vim.system({ "git", "rev-parse", "--show-toplevel" }, { cwd = cwd }):wait() + if result.code == 0 then + local root = vim.trim(result.stdout or "") + if root ~= "" then + return M.normalize_abs(root) + end + end + end + return M.normalize_abs(cwd) +end + +function M.find_bufnr(abs_path) + abs_path = M.normalize_abs(abs_path) + for _, bufnr in ipairs(vim.api.nvim_list_bufs()) do + if vim.api.nvim_buf_is_loaded(bufnr) then + local name = vim.api.nvim_buf_get_name(bufnr) + if name ~= "" and M.normalize_abs(name) == abs_path then + return bufnr + end + end + end + return nil +end + +function M.open_buffer(abs_path) + abs_path = M.normalize_abs(abs_path) + + local bufnr = M.find_bufnr(abs_path) + if bufnr then + return bufnr + end + + if vim.fn.filereadable(abs_path) ~= 1 then + return nil + end + + bufnr = vim.fn.bufadd(abs_path) + vim.fn.bufload(bufnr) + return bufnr +end + +function M.read_lines(abs_path, opts) + opts = opts or {} + + if opts.read_file then + return opts.read_file(abs_path) + end + + local bufnr = M.find_bufnr(abs_path) + if bufnr then + local lines, truncated = M.truncate_lines(vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)) + local err = truncated and string.format("file exceeds %d line limit: %s", M.MAX_LINES, abs_path) or nil + return lines, err, bufnr + end + + local fd = vim.uv and vim.uv.fs_open(abs_path, "r", 438) or nil + if not fd then + return nil, "file not found: " .. abs_path + end + + local stat = vim.uv.fs_fstat(fd) + if stat and stat.size > M.MAX_BYTES then + vim.uv.fs_close(fd) + return nil, string.format("file exceeds %d byte limit: %s", M.MAX_BYTES, abs_path) + end + + local data = vim.uv.fs_read(fd, M.MAX_BYTES) + vim.uv.fs_close(fd) + + if not data then + return nil, "could not read file: " .. abs_path + end + + if data:sub(-1) == "\n" then + data = data:sub(1, -2) + end + + local lines = data == "" and {} or vim.split(data, "\n", { plain = true }) + local truncated + lines, truncated = M.truncate_lines(lines) + if truncated then + return lines, string.format("file exceeds %d line limit: %s", M.MAX_LINES, abs_path) + end + + return lines, nil, nil +end + +function M.parse(prompt_text, opts) + opts = opts or {} + local root = opts.root or M.project_root() + local mentions = M.find_mentions(prompt_text) + local tagged = {} + local errors = {} + local seen_abs = {} + + if opts.source then + local src = opts.source + local abs_path = M.normalize_abs(src.abs_path or M.resolve_path(src.path, root)) + seen_abs[abs_path] = true + table.insert(tagged, { + path = src.path or M.rel_path(abs_path, root), + abs_path = abs_path, + lines = src.lines, + bufnr = src.bufnr, + }) + end + + for _, raw_path in ipairs(mentions) do + local abs_path = M.resolve_path(raw_path, root) + if not seen_abs[abs_path] then + local lines, err, bufnr = M.read_lines(abs_path, opts) + + if not lines then + table.insert(errors, err or ("could not read: " .. raw_path)) + else + table.insert(tagged, { + path = M.rel_path(abs_path, root), + abs_path = abs_path, + lines = lines, + bufnr = bufnr, + }) + if err then + table.insert(errors, err) + end + end + end + end + + return { + tagged = tagged, + cleaned_prompt = M.strip_mentions(prompt_text), + errors = errors, + } +end + +return M diff --git a/tests/patch_spec.lua b/tests/patch_spec.lua index a3bb24c..c137525 100644 --- a/tests/patch_spec.lua +++ b/tests/patch_spec.lua @@ -14,10 +14,47 @@ describe("patch.parse", function() local blocks = patch.parse(text) assert.are.equal(1, #blocks) + assert.is_nil(blocks[1].path) assert.are.same({ "old line" }, blocks[1].search) assert.are.same({ "new line" }, blocks[1].replace) end) + it("parses a search/replace block with file path", function() + local text = table.concat({ + "<<<< SEARCH lua/jumpy/foo.lua", + "old line", + "====", + "new line", + ">>>> REPLACE", + }, "\n") + local blocks = patch.parse(text) + + assert.are.equal(1, #blocks) + assert.are.equal("lua/jumpy/foo.lua", blocks[1].path) + assert.are.same({ "old line" }, blocks[1].search) + assert.are.same({ "new line" }, blocks[1].replace) + end) + + it("parses mixed path and pathless blocks", function() + local text = table.concat({ + "<<<< SEARCH", + "aaa", + "====", + "bbb", + ">>>> REPLACE", + "<<<< SEARCH bar.lua", + "ccc", + "====", + "ddd", + ">>>> REPLACE", + }, "\n") + local blocks = patch.parse(text) + + assert.are.equal(2, #blocks) + assert.is_nil(blocks[1].path) + assert.are.equal("bar.lua", blocks[2].path) + end) + it("parses multiple blocks", function() local text = table.concat({ "<<<< SEARCH", @@ -203,3 +240,75 @@ describe("patch.apply", function() assert.are.same({ " if true then", " print('hello')", " end" }, result) end) end) + +describe("patch.apply_by_file", function() + it("routes blocks to files by path", function() + local files = { + ["lua/a.lua"] = { "a-old", "shared" }, + ["lua/b.lua"] = { "b-old", "shared" }, + } + local response = table.concat({ + "<<<< SEARCH lua/a.lua", + "a-old", + "====", + "a-new", + ">>>> REPLACE", + "<<<< SEARCH lua/b.lua", + "b-old", + "====", + "b-new", + ">>>> REPLACE", + }, "\n") + + local results, unmatched = patch.apply_by_file(files, response, "lua/a.lua") + assert.are.equal(0, unmatched) + assert.are.same({ "a-new", "shared" }, results["lua/a.lua"].lines) + assert.are.same({ "b-new", "shared" }, results["lua/b.lua"].lines) + end) + + it("applies pathless blocks to the primary file", function() + local files = { + ["main.lua"] = { "old", "keep" }, + ["other.lua"] = { "x" }, + } + local response = table.concat({ + "<<<< SEARCH", + "old", + "====", + "new", + ">>>> REPLACE", + }, "\n") + + local results, unmatched = patch.apply_by_file(files, response, "main.lua") + assert.are.equal(0, unmatched) + assert.are.same({ "new", "keep" }, results["main.lua"].lines) + assert.is_nil(results["other.lua"]) + end) + + it("counts blocks for unknown files as unmatched", function() + local files = { + ["known.lua"] = { "a" }, + } + local response = table.concat({ + "<<<< SEARCH missing.lua", + "a", + "====", + "b", + ">>>> REPLACE", + }, "\n") + + local results, unmatched = patch.apply_by_file(files, response, "known.lua") + assert.are.equal(1, unmatched) + assert.are.same({}, results) + end) + + it("falls back to full-file replace when no blocks are found", function() + local files = { + ["main.lua"] = { "old" }, + } + + local results, unmatched = patch.apply_by_file(files, "new\nlines", "main.lua") + assert.are.equal(0, unmatched) + assert.are.same({ "new", "lines" }, results["main.lua"].lines) + end) +end) diff --git a/tests/tags_spec.lua b/tests/tags_spec.lua new file mode 100644 index 0000000..d9de760 --- /dev/null +++ b/tests/tags_spec.lua @@ -0,0 +1,190 @@ +package.path = package.path .. ";lua/?.lua;lua/?/init.lua" + +_G.vim = _G.vim or {} +_G.vim.fn = _G.vim.fn or {} +_G.vim.fn.getcwd = _G.vim.fn.getcwd or function() + return "/project" +end +_G.vim.fn.fnamemodify = _G.vim.fn.fnamemodify + or function(path, mod) + if mod == ":p" then + if path:sub(1, 1) == "/" then + return path:sub(-1) == "/" and path:sub(1, -2) or path + end + local joined = "/project/" .. path + while joined:find("/%./") do + joined = joined:gsub("/%./", "/") + end + return joined:sub(-1) == "/" and joined:sub(1, -2) or joined + end + return path + end + +local tags = require("jumpy.tags") + +describe("tags.find_mentions", function() + it("finds a single file mention", function() + assert.are.same({ "lua/jumpy/foo.lua" }, tags.find_mentions("move helpers from @lua/jumpy/foo.lua")) + end) + + it("finds multiple unique mentions", function() + local mentions = tags.find_mentions("merge @lua/a.lua into @lua/b.lua and @lua/a.lua again") + assert.are.same({ "lua/a.lua", "lua/b.lua" }, mentions) + end) + + it("ignores reserved @lsp", function() + assert.are.same({}, tags.find_mentions("use @lsp to find symbols")) + end) + + it("ignores email addresses", function() + assert.are.same({}, tags.find_mentions("contact me at user@domain.com")) + end) + + it("finds simple filenames", function() + assert.are.same({ "foo.c" }, tags.find_mentions("take methods in @foo.c")) + end) + + it("ignores trailing sentence punctuation", function() + assert.are.same({ "foo.lua" }, tags.find_mentions("update @foo.lua.")) + end) + + it("ignores trailing slashes", function() + assert.are.same({ "lua/jumpy" }, tags.find_mentions("look at @lua/jumpy/")) + end) +end) + +describe("tags.strip_mentions", function() + it("removes file mentions and trims whitespace", function() + assert.are.equal( + "move helpers from into this file", + tags.strip_mentions("move helpers from @lua/jumpy/foo.lua into this file") + ) + end) + + it("removes mentions followed by sentence punctuation", function() + assert.are.equal("update .", tags.strip_mentions("update @foo.lua.")) + end) + + it("removes mentions followed by trailing slashes", function() + assert.are.equal("look at", tags.strip_mentions("look at @lua/jumpy/")) + end) +end) + +describe("tags.resolve_path", function() + it("resolves relative paths against root", function() + assert.are.equal("/project/lua/foo.lua", tags.resolve_path("lua/foo.lua", "/project")) + end) + + it("keeps absolute paths unchanged", function() + assert.are.equal("/tmp/foo.lua", tags.resolve_path("/tmp/foo.lua", "/project")) + end) +end) + +describe("tags.rel_path", function() + it("returns a path relative to root", function() + assert.are.equal("lua/foo.lua", tags.rel_path("/project/lua/foo.lua", "/project")) + end) + + it("returns absolute path when outside root", function() + assert.are.equal("/tmp/foo.lua", tags.rel_path("/tmp/foo.lua", "/project")) + end) +end) + +describe("tags.truncate_lines", function() + it("passes through small files", function() + local lines = { "a", "b" } + local out, truncated = tags.truncate_lines(lines) + assert.are.same(lines, out) + assert.is_false(truncated) + end) + + it("truncates at the line limit", function() + local lines = {} + for i = 1, tags.MAX_LINES + 10 do + lines[i] = tostring(i) + end + + local out, truncated = tags.truncate_lines(lines) + assert.are.equal(tags.MAX_LINES, #out) + assert.is_true(truncated) + end) +end) + +describe("tags.parse", function() + it("loads tagged files and strips mentions from the prompt", function() + local files = { + ["/project/lua/a.lua"] = { "a" }, + ["/project/lua/b.lua"] = { "b" }, + } + + local result = tags.parse("merge @lua/a.lua into @lua/b.lua", { + root = "/project", + read_file = function(abs_path) + return files[abs_path], nil, nil + end, + }) + + assert.are.equal("merge into", result.cleaned_prompt) + assert.are.equal(2, #result.tagged) + assert.are.equal("lua/a.lua", result.tagged[1].path) + assert.are.equal("lua/b.lua", result.tagged[2].path) + assert.are.same({ "a" }, result.tagged[1].lines) + assert.are.same({ "b" }, result.tagged[2].lines) + assert.are.equal(0, #result.errors) + end) + + it("records errors for missing files", function() + local result = tags.parse("fix @missing.lua", { + root = "/project", + read_file = function() + return nil, "file not found" + end, + }) + + assert.are.equal(0, #result.tagged) + assert.are.equal(1, #result.errors) + assert.are.equal("fix", result.cleaned_prompt) + end) + + it("includes the source buffer before @ mentions", function() + local result = tags.parse("merge into @lua/b.lua", { + root = "/project", + source = { + path = "lua/a.lua", + abs_path = "/project/lua/a.lua", + lines = { "source" }, + }, + read_file = function(abs_path) + if abs_path == "/project/lua/b.lua" then + return { "b" }, nil, nil + end + return nil, "file not found" + end, + }) + + assert.are.equal(2, #result.tagged) + assert.are.equal("lua/a.lua", result.tagged[1].path) + assert.are.equal("lua/b.lua", result.tagged[2].path) + end) + + it("does not duplicate the source when it is also @ mentioned", function() + local result = tags.parse("also update @lua/a.lua", { + root = "/project", + source = { + path = "lua/a.lua", + abs_path = "/project/lua/a.lua", + lines = { "source" }, + }, + read_file = function(abs_path) + if abs_path == "/project/lua/a.lua" then + return { "from disk" }, nil, nil + end + return nil, "file not found" + end, + }) + + assert.are.equal(1, #result.tagged) + assert.are.equal("lua/a.lua", result.tagged[1].path) + assert.are.same({ "source" }, result.tagged[1].lines) + end) +end)