diff --git a/lua/opencode/config.lua b/lua/opencode/config.lua index 295b20de..bc010aaf 100644 --- a/lua/opencode/config.lua +++ b/lua/opencode/config.lua @@ -142,6 +142,11 @@ M.defaults = { rendering = { markdown_debounce_ms = 250, on_data_rendered = nil, + markdown_on_idle = false, + -- If set to a number, markdown rendering will be deferred while + -- `state.user_message_count[session_id]` is greater than this value. + -- If `nil`, the existing behavior is used (defer while > 0). + markdown_on_idle_threshold = nil, event_throttle_ms = 40, event_collapsing = true, }, @@ -255,6 +260,8 @@ M.defaults = { enabled = false, capture_streamed_events = false, show_ids = true, + highlight_changed_lines = true, + highlight_changed_lines_timeout_ms = 120, quick_chat = { keep_session = false, set_active_session = false, diff --git a/lua/opencode/core.lua b/lua/opencode/core.lua index 7ceb3a3f..d4f30b19 100644 --- a/lua/opencode/core.lua +++ b/lua/opencode/core.lua @@ -537,6 +537,8 @@ M.initialize_current_model = Promise.async(function() end) M._on_user_message_count_change = Promise.async(function(_, new, old) + require('opencode.ui.renderer.flush').flush_pending_on_data_rendered() + if config.hooks and config.hooks.on_done_thinking then local all_sessions = session.get_all_workspace_sessions():await() local done_sessions = vim.tbl_filter(function(s) diff --git a/lua/opencode/types.lua b/lua/opencode/types.lua index 3129acaf..a2b6d545 100644 --- a/lua/opencode/types.lua +++ b/lua/opencode/types.lua @@ -172,6 +172,7 @@ ---@class OpencodeUIOutputRenderingConfig ---@field markdown_debounce_ms number ---@field on_data_rendered (fun(buf: integer, win: integer)|boolean)|nil +---@field markdown_on_idle boolean ---@field event_throttle_ms number ---@field event_collapsing boolean @@ -207,6 +208,8 @@ ---@field enabled boolean ---@field capture_streamed_events boolean ---@field show_ids boolean +---@field highlight_changed_lines boolean +---@field highlight_changed_lines_timeout_ms number ---@field quick_chat {keep_session: boolean, set_active_session: boolean} ---@class OpencodeHooks diff --git a/lua/opencode/ui/dialog.lua b/lua/opencode/ui/dialog.lua index 2a0893db..0ab064ed 100644 --- a/lua/opencode/ui/dialog.lua +++ b/lua/opencode/ui/dialog.lua @@ -261,7 +261,11 @@ function Dialog:format_dialog(output, config) local end_line = output:get_line_count() if config.border_hl then - formatter.add_vertical_border(output, start_line + 1, end_line, config.border_hl, -2) + local border_end = end_line + if config.extend_border_to_trailing_blank then + border_end = border_end + 1 + end + formatter.add_vertical_border(output, start_line + 1, border_end, config.border_hl, -2) end output:add_line('') diff --git a/lua/opencode/ui/formatter.lua b/lua/opencode/ui/formatter.lua index dc025960..4d72f812 100644 --- a/lua/opencode/ui/formatter.lua +++ b/lua/opencode/ui/formatter.lua @@ -61,7 +61,6 @@ function M._format_revert_message(session_data, start_idx) local message_text = stats.messages == 1 and 'message' or 'messages' local tool_text = stats.tool_calls == 1 and 'tool call' or 'tool calls' - output:add_lines(M.separator) output:add_line( string.format('> %d %s reverted, %d %s reverted', stats.messages, message_text, stats.tool_calls, tool_text) ) @@ -95,12 +94,14 @@ function M._format_revert_message(session_data, start_idx) end end end + + output:add_empty_line() return output end local function add_action(output, text, action_type, args, key, line) -- actions use api-indexing (e.g. 0 indexed) - line = (line or output:get_line_count()) - 1 + line = (line or output:get_line_count()) - 2 output:add_action({ text = text, type = action_type, @@ -154,6 +155,11 @@ end function M.format_message_header(message) local output = Output.new() + if message.info and message.info.id == '__opencode_revert_message__' then + output:add_lines(M.separator) + return output + end + output:add_lines(M.separator) local role = message.info.role or 'unknown' local icon = message.info.role == 'user' and icons.get('header_user') or icons.get('header_assistant') @@ -167,18 +173,13 @@ function M.format_message_header(message) local display_name if role == 'assistant' then local mode = message.info.mode - if mode and mode ~= '' then - display_name = mode:upper() - else - -- For the most recent assistant message, show current_mode if mode is missing - -- This handles new messages that haven't been stamped yet - local is_last_message = #state.messages == 0 or message.info.id == state.messages[#state.messages].info.id - if is_last_message and state.current_mode and state.current_mode ~= '' then + if mode and mode ~= '' then + display_name = mode:upper() + elseif state.current_mode and state.current_mode ~= '' then display_name = state.current_mode:upper() else display_name = 'ASSISTANT' end - end else display_name = role:upper() end @@ -291,11 +292,32 @@ end ---@param output Output Output object to write to ---@param part OpencodeMessagePart function M._format_selection_context(output, part) + local part_message = part._message_context local json = context_module.decode_json_context(part.text or '', 'selection') if not json then return end - local start_line = output:get_line_count() + local start_line = output:get_line_count() + 1 + + if part_message and part_message.parts then + for i, message_part in ipairs(part_message.parts) do + if message_part.id == part.id then + local previous_part = part_message.parts[i - 1] + if previous_part and previous_part.type == 'text' and previous_part.synthetic then + local has_selection = context_module.decode_json_context(previous_part.text or '', 'selection') ~= nil + local has_cursor = context_module.decode_json_context(previous_part.text or '', 'cursor-data') ~= nil + local diagnostics = context_module.decode_json_context(previous_part.text or '', 'diagnostics') + local has_diagnostics = diagnostics and diagnostics.content and type(diagnostics.content) == 'table' and #diagnostics.content > 0 + + if has_selection or has_cursor or has_diagnostics then + start_line = output:get_line_count() + end + end + break + end + end + end + output:add_lines(vim.split(json.content or '', '\n')) output:add_empty_line() @@ -359,6 +381,75 @@ function M._format_diagnostics_context(output, part) M.add_vertical_border(output, start_line, end_line, 'OpencodeMessageRoleUser', -3) end +local function get_visible_user_part_kind(part) + if not part then + return nil + end + + if part.type == 'file' and part.filename and part.filename ~= '' then + return 'file' + end + + if part.type ~= 'text' or not part.text or part.text == '' then + return nil + end + + if not part.synthetic then + return 'text' + end + + if context_module.decode_json_context(part.text, 'selection') then + return 'selection' + end + + if context_module.decode_json_context(part.text, 'cursor-data') then + return 'cursor-data' + end + + local diagnostics = context_module.decode_json_context(part.text, 'diagnostics') + if diagnostics and diagnostics.content and type(diagnostics.content) == 'table' and #diagnostics.content > 0 then + return 'diagnostics' + end + + return nil +end + +local function get_user_part_neighbors(message, part) + if not message or not message.parts or not part or not part.id then + return nil, nil + end + + local current_index = nil + for i, message_part in ipairs(message.parts) do + if message_part.id == part.id then + current_index = i + break + end + end + + if not current_index then + return nil, nil + end + + local previous_kind = nil + for i = current_index - 1, 1, -1 do + previous_kind = get_visible_user_part_kind(message.parts[i]) + if previous_kind then + break + end + end + + local next_kind = nil + for i = current_index + 1, #message.parts do + next_kind = get_visible_user_part_kind(message.parts[i]) + if next_kind then + break + end + end + + return previous_kind, next_kind +end + ---Format and display the file path in the context ---@param output Output Output object to write to ---@param path string|nil File path @@ -450,19 +541,14 @@ end ---@param win_col number ---@param text_hl_group? string Optional highlight group for the background/foreground of text lines function M.add_vertical_border(output, start_line, end_line, hl_group, win_col, text_hl_group) + local extmark_opts = { + virt_text = { { require('opencode.ui.icons').get('border'), hl_group } }, + virt_text_pos = 'overlay', + virt_text_win_col = win_col, + virt_text_repeat_linebreak = true, + line_hl_group = text_hl_group or nil, + } for line = start_line, end_line do - local extmark_opts = { - virt_text = { { require('opencode.ui.icons').get('border'), hl_group } }, - virt_text_pos = 'overlay', - virt_text_win_col = win_col, - virt_text_repeat_linebreak = true, - } - - -- Add line highlight if text_hl_group is provided - if text_hl_group then - extmark_opts.line_hl_group = text_hl_group - end - output:add_extmark(line - 1, extmark_opts --[[@as OutputExtmark]]) end end @@ -486,9 +572,11 @@ function M.format_part(part, message, is_last_part, get_child_parts) if role == 'user' then if part.type == 'text' and part.text then if part.synthetic == true then + part._message_context = message M._format_selection_context(output, part) M._format_cursor_data_context(output, part) M._format_diagnostics_context(output, part) + part._message_context = nil else M._format_user_prompt(output, vim.trim(part.text), message) content_added = true @@ -496,7 +584,18 @@ function M.format_part(part, message, is_last_part, get_child_parts) elseif part.type == 'file' then local file_line = M._format_context_file(output, part.filename) if file_line then - M.add_vertical_border(output, file_line - 1, file_line, 'OpencodeMessageRoleUser', -3) + local previous_kind, next_kind = get_user_part_neighbors(message, part) + local previous_is_context = previous_kind == 'selection' + or previous_kind == 'cursor-data' + or previous_kind == 'diagnostics' + + if next_kind == 'text' or (previous_is_context and not next_kind) then + M.add_vertical_border(output, file_line - 1, file_line, 'OpencodeMessageRoleUser', -3) + elseif next_kind == 'file' then + M.add_vertical_border(output, file_line, file_line + 1, 'OpencodeMessageRoleUser', -3) + else + M.add_vertical_border(output, file_line, file_line, 'OpencodeMessageRoleUser', -3) + end content_added = true end end @@ -522,6 +621,12 @@ function M.format_part(part, message, is_last_part, get_child_parts) local question_window = require('opencode.ui.question_window') question_window.format_display(output) content_added = true + elseif part.type == 'revert-display' then + local revert_index = part.state and part.state.revert_index + if revert_index then + output = M._format_revert_message(state.messages or {}, revert_index) + content_added = output:get_line_count() > 0 + end end end diff --git a/lua/opencode/ui/highlight.lua b/lua/opencode/ui/highlight.lua index 84e26c9a..6762792d 100644 --- a/lua/opencode/ui/highlight.lua +++ b/lua/opencode/ui/highlight.lua @@ -47,6 +47,7 @@ function M.setup() vim.api.nvim_set_hl(0, 'OpencodeQuestionOption', { link = 'Normal', default = true }) vim.api.nvim_set_hl(0, 'OpencodeQuestionBorder', { fg = '#E3F2FD', default = true }) vim.api.nvim_set_hl(0, 'OpencodeQuestionTitle', { link = '@label', bold = true, default = true }) + vim.api.nvim_set_hl(0, 'OpencodeChangedLines', { bg = '#FFF3BF', default = true }) else vim.api.nvim_set_hl(0, 'OpencodeBorder', { fg = '#616161', default = true }) vim.api.nvim_set_hl(0, 'OpencodeBackground', { link = 'Normal', default = true }) @@ -90,6 +91,7 @@ function M.setup() vim.api.nvim_set_hl(0, 'OpencodeQuestionOption', { link = 'Normal', default = true }) vim.api.nvim_set_hl(0, 'OpencodeQuestionBorder', { fg = '#2B3A5A', default = true }) vim.api.nvim_set_hl(0, 'OpencodeQuestionTitle', { link = '@label', bold = true, default = true }) + vim.api.nvim_set_hl(0, 'OpencodeChangedLines', { bg = '#3D3520', default = true }) end end diff --git a/lua/opencode/ui/output_window.lua b/lua/opencode/ui/output_window.lua index b703c748..6a5541a0 100644 --- a/lua/opencode/ui/output_window.lua +++ b/lua/opencode/ui/output_window.lua @@ -4,6 +4,10 @@ local window_options = require('opencode.ui.window_options') local M = {} M.namespace = vim.api.nvim_create_namespace('opencode_output') +M.debug_namespace = vim.api.nvim_create_namespace('opencode_output_debug') +M.markdown_namespace = vim.api.nvim_create_namespace('opencode_output_markdown') +M._last_visible_bottom_by_win = {} +M._viewport_cursor_tracking_by_win = {} local _update_depth = 0 local _update_buf = nil @@ -106,8 +110,75 @@ function M.is_at_bottom(win) return cursor[1] >= line_count end +---@param win? integer +---@return integer|nil +function M.get_visible_bottom_line(win) + win = win or (state.windows and state.windows.output_win) + if not win or not vim.api.nvim_win_is_valid(win) then + return nil + end + local ok, line = pcall(vim.fn.line, 'w$', win) + return (ok and line and line > 0) and line or nil +end + +---@param win? integer +function M.reset_scroll_tracking(win) + if win then + M._last_visible_bottom_by_win[win] = nil + M._viewport_cursor_tracking_by_win[win] = nil + return + end + + M._last_visible_bottom_by_win = {} + M._viewport_cursor_tracking_by_win = {} +end + +---@param win? integer +function M.sync_cursor_with_viewport(win) + win = win or (state.windows and state.windows.output_win) + if not win or not vim.api.nvim_win_is_valid(win) then + return + end + + local windows = state.windows + local buf = windows and windows.output_buf + if not buf or not vim.api.nvim_buf_is_valid(buf) or vim.api.nvim_win_get_buf(win) ~= buf then + M.reset_scroll_tracking(win) + return + end + + local ok_cursor, cursor = pcall(vim.api.nvim_win_get_cursor, win) + local ok_count, line_count = pcall(vim.api.nvim_buf_line_count, buf) + local visible_bottom = M.get_visible_bottom_line(win) + if not ok_cursor or not cursor or not ok_count or not line_count or line_count == 0 or not visible_bottom then + return + end + + local last_visible_bottom = M._last_visible_bottom_by_win[win] + local tracking = M._viewport_cursor_tracking_by_win[win] == true + local anchored_to_viewport_bottom = tracking and last_visible_bottom and cursor[1] == last_visible_bottom + + if cursor[1] > visible_bottom or (anchored_to_viewport_bottom and cursor[1] ~= visible_bottom) then + M._viewport_cursor_tracking_by_win[win] = true + pcall(vim.api.nvim_win_set_cursor, win, { math.min(visible_bottom, line_count), 0 }) + local pos = state.ui.get_window_cursor(win) + if pos then + state.ui.set_cursor_position('output', pos) + end + elseif not anchored_to_viewport_bottom then + M._viewport_cursor_tracking_by_win[win] = false + end + + M._last_visible_bottom_by_win[win] = visible_bottom +end + function M.setup(windows) - window_options.set_window_option('winhighlight', config.ui.window_highlight, windows.output_win, { save_original = true }) + window_options.set_window_option( + 'winhighlight', + config.ui.window_highlight, + windows.output_win, + { save_original = true } + ) window_options.set_window_option('wrap', true, windows.output_win, { save_original = true }) window_options.set_window_option('linebreak', true, windows.output_win, { save_original = true }) window_options.set_window_option('number', false, windows.output_win, { save_original = true }) @@ -117,6 +188,8 @@ function M.setup(windows) window_options.set_buffer_option('bufhidden', 'hide', windows.output_buf) window_options.set_buffer_option('buflisted', false, windows.output_buf) window_options.set_buffer_option('swapfile', false, windows.output_buf) + window_options.set_buffer_option('undofile', false, windows.output_buf) + window_options.set_buffer_option('undolevels', -1, windows.output_buf) if config.ui.position ~= 'current' then window_options.set_window_option('winfixbuf', true, windows.output_win, { save_original = true }) @@ -128,6 +201,8 @@ function M.setup(windows) window_options.set_window_option('statuscolumn', '', windows.output_win, { save_original = true }) M.update_dimensions(windows) + M.reset_scroll_tracking(windows.output_win) + M._last_visible_bottom_by_win[windows.output_win] = M.get_visible_bottom_line(windows.output_win) M.setup_keymaps(windows) end @@ -188,6 +263,26 @@ function M.set_lines(lines, start_line, end_line) start_line = start_line or 0 end_line = end_line or -1 + -- Skip identical content outside of batch mode to avoid unnecessary writes + -- that cause flicker (e.g. when a markdown plugin re-renders an unchanged part). + -- Inside begin_update/end_update the caller controls exactly what is written, + -- so the check would be redundant and expensive. + if _update_depth == 0 then + local ok, existing = pcall(vim.api.nvim_buf_get_lines, buf, start_line, end_line, false) + if ok and existing and #existing == #lines then + local same = true + for i = 1, #lines do + if existing[i] ~= lines[i] then + same = false + break + end + end + if same then + return + end + end + end + if _update_depth == 0 then vim.api.nvim_set_option_value('modifiable', true, { buf = buf }) vim.api.nvim_buf_set_lines(buf, start_line, end_line, false, lines) @@ -229,23 +324,72 @@ function M.set_extmarks(extmarks, line_offset) local output_buf = windows.output_buf - for line_idx, marks in pairs(extmarks) do + local line_indices = vim.tbl_keys(extmarks) + table.sort(line_indices) + + for _, line_idx in ipairs(line_indices) do + local marks = extmarks[line_idx] + table.sort(marks, function(a, b) + local ma = type(a) == 'function' and a() or a + local mb = type(b) == 'function' and b() or b + return (ma.priority or 0) > (mb.priority or 0) + end) + for _, mark in ipairs(marks) do - local actual_mark = type(mark) == 'function' and mark() or mark + local m = type(mark) == 'function' and mark() or mark local target_line = line_offset + line_idx --[[@as integer]] - if actual_mark.end_row then - actual_mark.end_row = actual_mark.end_row + line_offset - end - local start_col = actual_mark.start_col - if actual_mark.start_col then - actual_mark.start_col = nil + local start_col = m.start_col + -- Only deepcopy when we need to mutate: start_col must be removed from the + -- opts table, and end_row must be offset when line_offset is non-zero. + -- The vast majority of extmarks (border virt_text) have neither field, so + -- we avoid 100k+ deepcopy calls during a full session render. + if start_col ~= nil or (m.end_row ~= nil and line_offset ~= 0) then + m = vim.deepcopy(m) + m.start_col = nil + if m.end_row then + m.end_row = m.end_row + line_offset + end end - ---@cast actual_mark vim.api.keyset.set_extmark - pcall(vim.api.nvim_buf_set_extmark, output_buf, M.namespace, target_line, start_col or 0, actual_mark) + ---@cast m vim.api.keyset.set_extmark + pcall(vim.api.nvim_buf_set_extmark, output_buf, M.namespace, target_line, start_col or 0, m) end end end +---@param start_line integer +---@param end_line integer +function M.highlight_changed_lines(start_line, end_line) + local windows = state.windows + if not windows or not windows.output_buf or not vim.api.nvim_buf_is_valid(windows.output_buf) then + return + end + if not config.debug.highlight_changed_lines then + return + end + + local buf = windows.output_buf + local first = math.max(0, start_line) + if end_line < start_line then + return + end + local last = math.max(first, end_line) + + vim.api.nvim_buf_clear_namespace(buf, M.debug_namespace, first, last + 1) + for line = first, last do + vim.api.nvim_buf_set_extmark(buf, M.debug_namespace, line, 0, { + line_hl_group = 'OpencodeChangedLines', + hl_eol = true, + priority = 250, + }) + end + + vim.defer_fn(function() + if vim.api.nvim_buf_is_valid(buf) then + vim.api.nvim_buf_clear_namespace(buf, M.debug_namespace, first, last + 1) + end + end, config.debug.highlight_changed_lines_timeout_ms or 120) +end + function M.focus_output(should_stop_insert) if not M.mounted() then return @@ -264,6 +408,7 @@ function M.close() end ---@cast state.windows { output_win: integer, output_buf: integer } + M.reset_scroll_tracking(state.windows.output_win) pcall(vim.api.nvim_win_close, state.windows.output_win, true) pcall(vim.api.nvim_buf_delete, state.windows.output_buf, { force = true }) end @@ -313,36 +458,7 @@ function M.setup_autocmds(windows, group) group = group, buffer = windows.output_buf, callback = function() - if not windows.output_win or not vim.api.nvim_win_is_valid(windows.output_win) then - return - end - - local ok, cursor = pcall(vim.api.nvim_win_get_cursor, windows.output_win) - if not ok then - return - end - - local ok2, line_count = pcall(vim.api.nvim_buf_line_count, windows.output_buf) - if not ok2 or line_count == 0 then - return - end - - if cursor[1] >= line_count then - local ok3, view = pcall(vim.api.nvim_win_call, windows.output_win, vim.fn.winsaveview) - if ok3 and type(view) == 'table' then - local topline = view.topline or 1 - local win_height = vim.api.nvim_win_get_height(windows.output_win) - local visible_bottom = math.min(topline + win_height - 1, line_count) - - if visible_bottom < line_count then - pcall(vim.api.nvim_win_set_cursor, windows.output_win, { visible_bottom, 0 }) - local pos = state.ui.get_window_cursor(windows.output_win) - if pos then - state.ui.set_cursor_position('output', pos) - end - end - end - end + M.sync_cursor_with_viewport(windows.output_win) end, }) end diff --git a/lua/opencode/ui/renderer.lua b/lua/opencode/ui/renderer.lua index fd3fd77d..6154b3f3 100644 --- a/lua/opencode/ui/renderer.lua +++ b/lua/opencode/ui/renderer.lua @@ -1,12 +1,12 @@ local state = require('opencode.state') local config = require('opencode.config') -local formatter = require('opencode.ui.formatter') local output_window = require('opencode.ui.output_window') local permission_window = require('opencode.ui.permission_window') local Promise = require('opencode.promise') local ctx = require('opencode.ui.renderer.ctx') -local buf = require('opencode.ui.renderer.buffer') local events = require('opencode.ui.renderer.events') +local flush = require('opencode.ui.renderer.flush') +local scroll = require('opencode.ui.renderer.scroll') local M = {} @@ -14,23 +14,6 @@ local M = {} -- can be stubbed cleanly (e.g. stub(renderer, '_render_full_session_data')) M.on_session_updated = events.on_session_updated -local trigger_on_data_rendered = require('opencode.util').debounce(function() - local cb_type = type(config.ui.output.rendering.on_data_rendered) - if cb_type == 'boolean' then - return - end - if not state.windows or not state.windows.output_buf or not state.windows.output_win then - return - end - if cb_type == 'function' then - pcall(config.ui.output.rendering.on_data_rendered, state.windows.output_buf, state.windows.output_win) - elseif vim.fn.exists(':RenderMarkdown') > 0 then - vim.cmd(':RenderMarkdown') - elseif vim.fn.exists(':Markview') > 0 then - vim.cmd(':Markview render ' .. state.windows.output_buf) - end -end, config.ui.output.rendering.markdown_debounce_ms or 250) - ---Reset all renderer state and clear the output buffer function M.reset() ctx:reset() @@ -45,7 +28,7 @@ function M.reset() permission_window.clear_all() state.renderer.reset() - trigger_on_data_rendered() + flush.trigger_on_data_rendered() end ---Unsubscribe from all events and reset @@ -141,6 +124,8 @@ function M._render_full_session_data(session_data) local revert_index = nil local set_mode_from_messages = not state.current_model + flush.begin_bulk_mode() + for i, msg in ipairs(session_data) do if state.active_session.revert and state.active_session.revert.messageID == msg.info.id then revert_index = i @@ -152,9 +137,33 @@ function M._render_full_session_data(session_data) end if revert_index then - buf.write_formatted_data(formatter._format_revert_message(state.messages, revert_index)) + local revert_message = { + info = { + id = '__opencode_revert_message__', + sessionID = state.active_session.id, + role = 'system', + }, + parts = { + { + id = '__opencode_revert_part__', + messageID = '__opencode_revert_message__', + sessionID = state.active_session.id, + type = 'revert-display', + state = { + revert_index = revert_index, + }, + }, + }, + } + + table.insert(state.messages, revert_message) + events.on_message_updated(revert_message) + events.on_part_updated({ part = revert_message.parts[1] }) end + flush.flush() + flush.end_bulk_mode() + if set_mode_from_messages then set_model_and_mode_from_messages() end @@ -192,6 +201,7 @@ function M.render_output(output_data) output_window.set_lines(output_data.lines or {}) output_window.clear_extmarks() output_window.set_extmarks(output_data.extmarks) + flush.trigger_on_data_rendered() M.scroll_to_bottom() end @@ -210,30 +220,8 @@ function M.scroll_to_bottom(force) return end - local ok, line_count = pcall(vim.api.nvim_buf_line_count, output_buf) - if not ok or line_count == 0 then - return - end - - local prev_line_count = ctx.prev_line_count - ctx.prev_line_count = line_count - - trigger_on_data_rendered() - - local should_scroll = force - or prev_line_count == 0 - or config.ui.output.always_scroll_to_bottom - or (function() - local ok_cursor, cursor = pcall(vim.api.nvim_win_get_cursor, output_win) - return ok_cursor and cursor and (cursor[1] >= prev_line_count or cursor[1] >= line_count) - end)() - - if should_scroll then - local last_line = vim.api.nvim_buf_get_lines(output_buf, line_count - 1, line_count, false)[1] or '' - vim.api.nvim_win_set_cursor(output_win, { line_count, #last_line }) - vim.api.nvim_win_call(output_win, function() - vim.cmd('normal! zb') - end) + if force or config.ui.output.always_scroll_to_bottom or output_window.is_at_bottom(output_win) then + scroll.scroll_win_to_bottom(output_win, output_buf) end end @@ -242,8 +230,8 @@ function M.on_focus_changed() if not permission_window.get_all_permissions()[1] then return end - buf.rerender_part('permission-display-part') - trigger_on_data_rendered() + flush.mark_part_dirty('permission-display-part', 'permission-display-message') + flush.flush() end ---Re-render when the active session changes diff --git a/lua/opencode/ui/renderer/append.lua b/lua/opencode/ui/renderer/append.lua new file mode 100644 index 00000000..52fbd045 --- /dev/null +++ b/lua/opencode/ui/renderer/append.lua @@ -0,0 +1,43 @@ +local M = {} + +---@param old_lines string[] +---@param new_lines string[] +---@return boolean +function M.is_append_only(old_lines, new_lines) + local old_count = #old_lines + if #new_lines <= old_count then + return false + end + + for i = old_count, 1, -1 do + if old_lines[i] ~= new_lines[i] then + return false + end + end + + return true +end + +---@param old_lines string[] +---@param new_lines string[] +---@return string[] +function M.tail_lines(old_lines, new_lines) + return vim.list_slice(new_lines, #old_lines + 1, #new_lines) +end + +---@param row_offset integer +---@param extmarks table|nil +---@return table +function M.tail_extmarks(row_offset, extmarks) + local tail = {} + + for line_idx, marks in pairs(extmarks or {}) do + if line_idx >= row_offset then + tail[line_idx - row_offset] = vim.deepcopy(marks) + end + end + + return tail +end + +return M diff --git a/lua/opencode/ui/renderer/buffer.lua b/lua/opencode/ui/renderer/buffer.lua index 6d9b6e7d..da069832 100644 --- a/lua/opencode/ui/renderer/buffer.lua +++ b/lua/opencode/ui/renderer/buffer.lua @@ -1,6 +1,5 @@ local ctx = require('opencode.ui.renderer.ctx') local state = require('opencode.state') -local formatter = require('opencode.ui.formatter') local output_window = require('opencode.ui.output_window') local M = {} @@ -9,99 +8,184 @@ local function has_extmarks(extmarks) return type(extmarks) == 'table' and next(extmarks) ~= nil end +local function accumulate_bulk_extmarks(extmarks, line_start) + for line_idx, marks in pairs(extmarks) do + local actual_line = line_start + line_idx + local bucket = ctx.bulk_extmarks_by_line[actual_line] + if not bucket then + bucket = {} + ctx.bulk_extmarks_by_line[actual_line] = bucket + end + for _, mark in ipairs(marks) do + local copy = vim.deepcopy(mark) + if copy.end_row then + copy.end_row = line_start + copy.end_row + end + bucket[#bucket + 1] = copy + end + end +end + local function has_actions(actions) return type(actions) == 'table' and #actions > 0 end ----@param old_lines string[] ----@param new_lines string[] ----@return integer, integer -local function get_shared_prefix_suffix(old_lines, new_lines) - local old_count = #old_lines - local new_count = #new_lines - local prefix = 0 +local function unchanged_prefix_len(previous_formatted, formatted_data) + local previous_lines = previous_formatted and previous_formatted.lines or {} + local next_lines = formatted_data and formatted_data.lines or {} + local prefix_len = 0 - while prefix < old_count and prefix < new_count do - if old_lines[prefix + 1] ~= new_lines[prefix + 1] then + for i = 1, math.min(#previous_lines, #next_lines) do + if previous_lines[i] ~= next_lines[i] then break end - prefix = prefix + 1 + prefix_len = i end - local suffix = 0 - while suffix < (old_count - prefix) and suffix < (new_count - prefix) do - if old_lines[old_count - suffix] ~= new_lines[new_count - suffix] then - break + return prefix_len +end + +local function slice_lines(lines, start_idx) + local slice = {} + for i = start_idx, #(lines or {}) do + slice[#slice + 1] = lines[i] + end + return slice +end + +local function slice_extmarks(extmarks, start_line) + local slice = {} + for line_idx, marks in pairs(extmarks or {}) do + if line_idx >= start_line then + slice[line_idx - start_line] = vim.deepcopy(marks) end - suffix = suffix + 1 end + return slice +end - return prefix, suffix +local function resolve_mark(mark) + return type(mark) == 'function' and mark() or mark end ----Find the last renderable part ID in a message (skips step-start/finish) ----@param message OpencodeMessage ----@return string? -function M.get_last_part_for_message(message) - if not message or not message.parts or #message.parts == 0 then - return nil +local function marks_equal(a, b) + a = a or {} + b = b or {} + + if #a ~= #b then + return false end - for i = #message.parts, 1, -1 do - local part = message.parts[i] - if part.type ~= 'step-start' and part.type ~= 'step-finish' and part.id then - return part.id + + for i = 1, #a do + if not vim.deep_equal(resolve_mark(a[i]), resolve_mark(b[i])) then + return false end end - return nil + + return true end ----Find the first non-synthetic text part ID in a message ----@param message OpencodeMessage ----@return string? -function M.find_text_part_for_message(message) - if not message or not message.parts then - return nil - end - for _, part in ipairs(message.parts) do - if part.type == 'text' and not part.synthetic then - return part.id +local function unchanged_extmark_prefix_len(previous_formatted, formatted_data) + local previous_lines = previous_formatted and previous_formatted.lines or {} + local next_lines = formatted_data and formatted_data.lines or {} + local max_lines = math.max(#previous_lines, #next_lines) + local prefix_len = 0 + + for line_idx = 0, math.max(max_lines - 1, 0) do + local previous_marks = previous_formatted and previous_formatted.extmarks and previous_formatted.extmarks[line_idx] or nil + local next_marks = formatted_data and formatted_data.extmarks and formatted_data.extmarks[line_idx] or nil + + if not marks_equal(previous_marks, next_marks) then + break end + + prefix_len = line_idx + 1 end - return nil + + return prefix_len end ----Find part ID by call ID and message ID ----@param call_id string ----@param message_id string ----@return string? -function M.find_part_by_call_id(call_id, message_id) - return ctx.render_state:get_part_by_call_id(call_id, message_id) +local function highlight_written_lines(start_line, lines) + if #lines == 0 then + return + end + output_window.highlight_changed_lines(start_line, start_line + #lines - 1) +end + +local function apply_extmarks(previous_formatted, formatted_data, line_start, old_line_end, new_line_end) + local prefix_len = math.min( + unchanged_prefix_len(previous_formatted, formatted_data), + unchanged_extmark_prefix_len(previous_formatted, formatted_data) + ) + local clear_start = line_start + prefix_len + local clear_end = math.max(old_line_end, new_line_end) + 1 + + output_window.clear_extmarks(clear_start, clear_end) + + local extmarks = slice_extmarks(formatted_data.extmarks, prefix_len) + if has_extmarks(extmarks) then + output_window.set_extmarks(extmarks, clear_start) + end end ----Determine where to insert an out-of-order part (after the last rendered ----sibling, or right after the message header if no siblings are rendered yet) ----@param part_id string ----@param message_id string ----@return integer? -local function get_insertion_point_for_part(part_id, message_id) +local function get_message_insert_line(message_id) local rendered_message = ctx.render_state:get_message(message_id) - if not rendered_message or not rendered_message.message then - return nil + if rendered_message and rendered_message.line_start then + return rendered_message.line_start + end + + local line_count = output_window.get_buf_line_count() + local append_at = math.max(line_count - 1, 0) + if line_count == 1 then + local windows = state.windows + local output_buf = windows and windows.output_buf + if output_buf and vim.api.nvim_buf_is_valid(output_buf) then + local lines = vim.api.nvim_buf_get_lines(output_buf, 0, 1, false) + if lines[1] == '' then + return 0 + end + end end - local message = rendered_message.message - local insertion_line = rendered_message.line_end and (rendered_message.line_end + 1) - if not insertion_line then + local messages = state.messages or {} + local message_index = nil + for i, message in ipairs(messages) do + if message.info and message.info.id == message_id then + message_index = i + break + end + end + + if not message_index then + return append_at + end + + for i = message_index + 1, #messages do + local next_message = messages[i] + if next_message and next_message.info and next_message.info.id then + local next_rendered = ctx.render_state:get_message(next_message.info.id) + if next_rendered and next_rendered.line_start then + return next_rendered.line_start + end + end + end + + return append_at +end + +local function get_part_insertion_line(part_id, message_id) + local rendered_message = ctx.render_state:get_message(message_id) + if not rendered_message or not rendered_message.message or not rendered_message.line_end then return nil end + local message = rendered_message.message + local insertion_line = rendered_message.line_end + 1 local current_part_index = nil - if message.parts then - for i, part in ipairs(message.parts) do - if part.id == part_id then - current_part_index = i - break - end + + for i, part in ipairs(message.parts or {}) do + if part.id == part_id then + current_part_index = i + break end end @@ -109,13 +193,12 @@ local function get_insertion_point_for_part(part_id, message_id) return insertion_line end - -- Walk backwards through earlier siblings to find the last rendered one for i = current_part_index - 1, 1, -1 do - local prev_part = message.parts[i] - if prev_part and prev_part.id then - local prev_rendered = ctx.render_state:get_part(prev_part.id) - if prev_rendered and prev_rendered.line_end then - return prev_rendered.line_end + 1 + local previous = message.parts[i] + if previous and previous.id then + local previous_rendered = ctx.render_state:get_part(previous.id) + if previous_rendered and previous_rendered.line_end then + return previous_rendered.line_end + 1 end end end @@ -123,279 +206,247 @@ local function get_insertion_point_for_part(part_id, message_id) return insertion_line end ----Append formatted data to the end of the buffer, or insert at start_line. ----Returns the range of lines written, or nil if nothing was written. ----@param formatted_data Output ----@param part_id? string When provided, actions are registered for this part ----@param start_line? integer When provided, content is inserted here (shifts down) ----@return {line_start: integer, line_end: integer}? -function M.write_formatted_data(formatted_data, part_id, start_line) - if not state.windows or not state.windows.output_buf then - return nil - end +local function write_at(lines, start_line, end_line) + output_window.set_lines(lines, start_line, end_line) + highlight_written_lines(start_line, lines) + return { + line_start = start_line, + line_end = start_line + #lines - 1, + } +end - local new_lines = formatted_data.lines - if #new_lines == 0 then - return nil +local function apply_part_actions(part_id, formatted_data, line_start) + if has_actions(formatted_data.actions) then + ctx.render_state:clear_actions(part_id) + ctx.render_state:add_actions(part_id, vim.deepcopy(formatted_data.actions), line_start + 1) + else + ctx.render_state:clear_actions(part_id) end - local is_insertion = start_line ~= nil - local target_line = start_line or output_window.get_buf_line_count() + local part_data = ctx.render_state:get_part(part_id) + if part_data then + part_data.has_extmarks = has_extmarks(formatted_data.extmarks) + end +end - if is_insertion then - output_window.set_lines(new_lines, target_line, target_line) - else - -- Append: overlap the last buffer line with our lines - target_line = target_line - 1 - local append_lines = table.move(new_lines, 1, #new_lines, 1, {}) - append_lines[#append_lines + 1] = '' - output_window.set_lines(append_lines, target_line) +local function set_part_extmark_state(part_id, formatted_data) + local part_data = ctx.render_state:get_part(part_id) + if part_data then + part_data.has_extmarks = has_extmarks(formatted_data.extmarks) end +end - if part_id and formatted_data.actions then - ctx.render_state:add_actions(part_id, formatted_data.actions, target_line) +function M.get_last_part_for_message(message) + if not message or not message.parts or #message.parts == 0 then + return nil + end + for i = #message.parts, 1, -1 do + local part = message.parts[i] + if part.type ~= 'step-start' and part.type ~= 'step-finish' and part.id then + return part.id + end end + return nil +end - if has_extmarks(formatted_data.extmarks) then - output_window.set_extmarks(formatted_data.extmarks, target_line) - local part_data = ctx.render_state:get_part(part_id) - if part_data then - part_data.has_extmarks = true +function M.find_text_part_for_message(message) + if not message or not message.parts then + return nil + end + for _, part in ipairs(message.parts) do + if part.type == 'text' and not part.synthetic then + return part.id end end + return nil +end - return { line_start = target_line, line_end = target_line + #new_lines - 1 } +function M.find_part_by_call_id(call_id, message_id) + return ctx.render_state:get_part_by_call_id(call_id, message_id) end ----Insert a new part into the buffer. ----Appends if the part belongs to the current message; inserts in-order otherwise. ----@param part_id string ----@param formatted_data Output ----@return boolean -function M.insert_part(part_id, formatted_data) - local cached = ctx.render_state:get_part(part_id) - if not cached then - return false - end +function M.upsert_message_now(message_id, formatted_data, previous_formatted) + if ctx.bulk_mode then + local line_start = #ctx.bulk_buffer_lines + local line_end = line_start + #formatted_data.lines - 1 + + for _, line in ipairs(formatted_data.lines) do + ctx.bulk_buffer_lines[#ctx.bulk_buffer_lines + 1] = line + end + if has_extmarks(formatted_data.extmarks) then + accumulate_bulk_extmarks(formatted_data.extmarks, line_start) + end + + local message_data = ctx.render_state:get_message(message_id) + if message_data then + ctx.render_state:set_message(message_data.message, line_start, line_end) + end - if #formatted_data.lines == 0 then return true end - local is_current_message = state.current_message - and state.current_message.info - and state.current_message.info.id == cached.message_id - - if is_current_message then - local range = M.write_formatted_data(formatted_data, part_id) - if not range then - return false + local cached = ctx.render_state:get_message(message_id) + if cached and cached.line_start and cached.line_end then + local old_line_end = cached.line_end + local prefix_len = unchanged_prefix_len(previous_formatted, formatted_data) + local write_start = cached.line_start + prefix_len + local lines_to_write = slice_lines(formatted_data.lines, prefix_len + 1) + + output_window.set_lines(lines_to_write, write_start, cached.line_end + 1) + highlight_written_lines(write_start, lines_to_write) + + local new_line_end = cached.line_start + #formatted_data.lines - 1 + apply_extmarks(previous_formatted, formatted_data, cached.line_start, old_line_end, new_line_end) + ctx.render_state:set_message(cached.message, cached.line_start, new_line_end) + + local delta = new_line_end - old_line_end + if delta ~= 0 then + ctx.render_state:shift_all(old_line_end + 1, delta) end - ctx.render_state:set_part(cached.part, range.line_start, range.line_end) - ctx.last_part_formatted = { part_id = part_id, formatted_data = formatted_data } return true end - -- Out-of-order part: find the correct insertion point - local insertion_line = get_insertion_point_for_part(part_id, cached.message_id) - if not insertion_line then - return false - end + local insert_at = get_message_insert_line(message_id) + local message_data = ctx.render_state:get_message(message_id) + if message_data and message_data.message then + local range = write_at(formatted_data.lines, insert_at, insert_at) + if has_extmarks(formatted_data.extmarks) then + output_window.set_extmarks(formatted_data.extmarks, insert_at) + end - local range = M.write_formatted_data(formatted_data, part_id, insertion_line) - if not range then - return false + ctx.render_state:shift_all(insert_at, #formatted_data.lines) + ctx.render_state:set_message(message_data.message, range.line_start, range.line_end) + return true end - ctx.render_state:shift_all(insertion_line, #formatted_data.lines) - ctx.render_state:set_part(cached.part, range.line_start, range.line_end) - return true + return false end ----Replace an existing part in the buffer. ----Only writes lines that differ from the previous render (diff optimisation). ----@param part_id string ----@param formatted_data Output ----@return boolean -function M.replace_part(part_id, formatted_data) - local cached = ctx.render_state:get_part(part_id) - if not cached or not cached.line_start or not cached.line_end then - return false - end +function M.upsert_part_now(part_id, message_id, formatted_data, previous_formatted) + if ctx.bulk_mode then + local line_start = #ctx.bulk_buffer_lines + local line_end = line_start + #formatted_data.lines - 1 - local new_lines = formatted_data.lines - local new_line_count = #new_lines - local next_has_extmarks = has_extmarks(formatted_data.extmarks) - local had_extmarks = cached.has_extmarks == true - local next_has_actions = has_actions(formatted_data.actions) - local had_actions = cached.actions and #cached.actions > 0 - local old_buf_line_count = output_window.get_buf_line_count() - local was_tail_part = cached.line_end == old_buf_line_count - 1 - - -- Diff optimisation: skip lines that haven't changed since the last render - local old = ctx.last_part_formatted - local lines_to_write = new_lines - local write_start = cached.line_start - local write_end = cached.line_end + 1 - local prefix = 0 - local suffix = 0 - - if old and old.part_id == part_id and old.formatted_data and old.formatted_data.lines then - local old_lines = old.formatted_data.lines - prefix, suffix = get_shared_prefix_suffix(old_lines, new_lines) - - if prefix == #old_lines and prefix == new_line_count then - if not had_extmarks and not next_has_extmarks and not had_actions and not next_has_actions then - ctx.last_part_formatted = { part_id = part_id, formatted_data = formatted_data } - return true - end + for _, line in ipairs(formatted_data.lines) do + ctx.bulk_buffer_lines[#ctx.bulk_buffer_lines + 1] = line + end + if has_extmarks(formatted_data.extmarks) then + accumulate_bulk_extmarks(formatted_data.extmarks, line_start) end - local replace_from = prefix + 1 - local replace_to = new_line_count - suffix - lines_to_write = replace_from <= replace_to and vim.list_slice(new_lines, replace_from, replace_to) or {} - write_start = cached.line_start + prefix - write_end = cached.line_end + 1 - suffix - end + local part_data = ctx.render_state:get_part(part_id) + if part_data then + ctx.render_state:set_part(part_data.part, line_start, line_end) + apply_part_actions(part_id, formatted_data, line_start) + end - if had_actions or next_has_actions then - ctx.render_state:clear_actions(part_id) + return true end - output_window.begin_update() - if had_extmarks or next_has_extmarks then - output_window.clear_extmarks(cached.line_start - 1, cached.line_end + 1) - end - output_window.set_lines(lines_to_write, write_start, write_end) + local cached = ctx.render_state:get_part(part_id) + if cached and cached.line_start and cached.line_end then + local old_line_end = cached.line_end + local prefix_len = unchanged_prefix_len(previous_formatted, formatted_data) + local write_start = cached.line_start + prefix_len + local lines_to_write = slice_lines(formatted_data.lines, prefix_len + 1) - local new_line_end = cached.line_start + new_line_count - 1 - if next_has_extmarks then - output_window.set_extmarks(formatted_data.extmarks, cached.line_start) - end - output_window.end_update() - cached.has_extmarks = next_has_extmarks + output_window.set_lines(lines_to_write, write_start, cached.line_end + 1) + highlight_written_lines(write_start, lines_to_write) - if next_has_actions then - ctx.render_state:add_actions(part_id, formatted_data.actions, cached.line_start + 1) - end + local new_line_end = cached.line_start + #formatted_data.lines - 1 + apply_part_actions(part_id, formatted_data, cached.line_start) - if new_line_end ~= cached.line_end then - if was_tail_part then - ctx.render_state:set_part(cached.part, cached.line_start, new_line_end) - else + if new_line_end ~= cached.line_end then ctx.render_state:update_part_lines(part_id, cached.line_start, new_line_end) end + apply_extmarks(previous_formatted, formatted_data, cached.line_start, old_line_end, new_line_end) + set_part_extmark_state(part_id, formatted_data) + return true end - ctx.last_part_formatted = { part_id = part_id, formatted_data = formatted_data } - return true -end - ----Remove a part and its extmarks from the buffer ----@param part_id string -function M.remove_part(part_id) - local cached = ctx.render_state:get_part(part_id) - if not cached or not cached.line_start or not cached.line_end then - return + local insert_at = get_part_insertion_line(part_id, message_id) + if not insert_at then + return false end - output_window.begin_update() - output_window.clear_extmarks(cached.line_start - 1, cached.line_end + 1) - output_window.set_lines({}, cached.line_start, cached.line_end + 1) - output_window.end_update() - ctx.render_state:remove_part(part_id) -end ----Write a message header into the buffer ----@param message OpencodeMessage -function M.add_message(message) - local header_data = formatter.format_message_header(message) - local range = M.write_formatted_data(header_data) - if range then - ctx.render_state:set_message(message, range.line_start, range.line_end) + local part_data = ctx.render_state:get_part(part_id) + if part_data and part_data.part then + local range = write_at(formatted_data.lines, insert_at, insert_at) + ctx.render_state:shift_all(insert_at, #formatted_data.lines) + ctx.render_state:set_part(part_data.part, range.line_start, range.line_end) + apply_part_actions(part_id, formatted_data, range.line_start) + if has_extmarks(formatted_data.extmarks) then + output_window.set_extmarks(formatted_data.extmarks, range.line_start) + end + set_part_extmark_state(part_id, formatted_data) + return true end + + return false end ----Replace an existing message header in the buffer ----@param message_id string ----@param formatted_data Output ----@return boolean -function M.replace_message(message_id, formatted_data) - local cached = ctx.render_state:get_message(message_id) - if not cached or not cached.line_start or not cached.line_end then +function M.append_part_now(part_id, extra_lines, extra_extmarks, previous_formatted) + local cached = ctx.render_state:get_part(part_id) + if not cached or not cached.line_start or not cached.line_end or #extra_lines == 0 then return false end - local new_lines = formatted_data.lines - local new_line_count = #new_lines - - output_window.begin_update() - output_window.clear_extmarks(cached.line_start, cached.line_end + 1) - output_window.set_lines(new_lines, cached.line_start, cached.line_end + 1) - output_window.set_extmarks(formatted_data.extmarks, cached.line_start) - output_window.end_update() - + local insert_at = cached.line_end + 1 local old_line_end = cached.line_end - local new_line_end = cached.line_start + new_line_count - 1 + output_window.set_lines(extra_lines, insert_at, insert_at) + highlight_written_lines(insert_at, extra_lines) - ctx.render_state:set_message(cached.message, cached.line_start, new_line_end) + local new_line_end = cached.line_end + #extra_lines + ctx.render_state:update_part_lines(part_id, cached.line_start, new_line_end) - local delta = new_line_end - old_line_end - if delta ~= 0 then - ctx.render_state:shift_all(old_line_end + 1, delta) + local formatted_data = ctx.formatted_parts[part_id] + if formatted_data then + apply_part_actions(part_id, formatted_data, cached.line_start) + apply_extmarks(previous_formatted, formatted_data, cached.line_start, old_line_end, new_line_end) + set_part_extmark_state(part_id, formatted_data) + elseif has_extmarks(extra_extmarks) then + output_window.set_extmarks(extra_extmarks, insert_at) end return true end ----Remove a message header and its extmarks from the buffer ----@param message_id string -function M.remove_message(message_id) - local cached = ctx.render_state:get_message(message_id) - if not cached or not cached.line_start or not cached.line_end then +function M.remove_part_now(part_id) + if ctx.bulk_mode then + -- In bulk mode, we don't actually remove from buffer since we're building fresh + -- Just track that this part should be excluded + ctx.render_state:remove_part(part_id) return end - if not state.windows or not state.windows.output_buf then - return - end - if cached.line_start == 0 and cached.line_end == 0 then + + local cached = ctx.render_state:get_part(part_id) + if not cached or not cached.line_start or not cached.line_end then + ctx.render_state:remove_part(part_id) return end - output_window.begin_update() + output_window.clear_extmarks(cached.line_start - 1, cached.line_end + 1) output_window.set_lines({}, cached.line_start, cached.line_end + 1) - output_window.end_update() - ctx.render_state:remove_message(message_id) + ctx.render_state:remove_part(part_id) end ----Re-render an existing part using its current data from render_state ----@param part_id string -function M.rerender_part(part_id) - local cached = ctx.render_state:get_part(part_id) - if not cached or not cached.part then +function M.remove_message_now(message_id) + if ctx.bulk_mode then + -- In bulk mode, we don't actually remove from buffer since we're building fresh + -- Just track that this message should be excluded + ctx.render_state:remove_message(message_id) return end - local rendered_message = ctx.render_state:get_message(cached.message_id) - if not rendered_message or not rendered_message.message then + local cached = ctx.render_state:get_message(message_id) + if not cached or not cached.line_start or not cached.line_end then + ctx.render_state:remove_message(message_id) return end - local message = rendered_message.message - local is_last_part = (M.get_last_part_for_message(message) == part_id) - local formatted = formatter.format_part(cached.part, message, is_last_part, function(session_id) - return ctx.render_state:get_child_session_parts(session_id) - end) - - M.replace_part(part_id, formatted) -end - ----Re-render the task-tool part that owns the given child session ----@param child_session_id string -function M.rerender_task_tool_for_child_session(child_session_id) - local part_id = ctx.render_state:get_task_part_by_child_session(child_session_id) - if part_id then - M.rerender_part(part_id) - end + output_window.clear_extmarks(cached.line_start, cached.line_end + 1) + output_window.set_lines({}, cached.line_start, cached.line_end + 1) + ctx.render_state:remove_message(message_id) end return M diff --git a/lua/opencode/ui/renderer/ctx.lua b/lua/opencode/ui/renderer/ctx.lua index 7cbeddee..19ea6a15 100644 --- a/lua/opencode/ui/renderer/ctx.lua +++ b/lua/opencode/ui/renderer/ctx.lua @@ -6,16 +6,55 @@ local RenderState = require('opencode.ui.render_state') local ctx = { ---@type RenderState render_state = RenderState.new(), - ---@type integer - prev_line_count = 0, ---@type { part_id: string|nil, formatted_data: Output|nil } last_part_formatted = { part_id = nil, formatted_data = nil }, + ---@type table + formatted_parts = {}, + ---@type table + formatted_messages = {}, + pending = { + dirty_message_order = {}, + dirty_messages = {}, + dirty_part_by_message = {}, + dirty_part_order = {}, + dirty_parts = {}, + removed_part_order = {}, + removed_parts = {}, + removed_message_order = {}, + removed_messages = {}, + }, + flush_scheduled = false, + markdown_render_scheduled = false, + bulk_mode = false, + bulk_buffer_lines = {}, + bulk_extmarks_by_line = {}, } function ctx:reset() self.render_state:reset() - self.prev_line_count = 0 self.last_part_formatted = { part_id = nil, formatted_data = nil } + self.formatted_parts = {} + self.formatted_messages = {} + self.pending = { + dirty_message_order = {}, + dirty_messages = {}, + dirty_part_by_message = {}, + dirty_part_order = {}, + dirty_parts = {}, + removed_part_order = {}, + removed_parts = {}, + removed_message_order = {}, + removed_messages = {}, + } + self.flush_scheduled = false + self.markdown_render_scheduled = false + self:bulk_reset() +end + +function ctx:bulk_reset() + self.bulk_mode = false + self.bulk_buffer_lines = {} + self.bulk_extmarks_by_line = {} end return ctx diff --git a/lua/opencode/ui/renderer/events.lua b/lua/opencode/ui/renderer/events.lua index e46231b1..cae58015 100644 --- a/lua/opencode/ui/renderer/events.lua +++ b/lua/opencode/ui/renderer/events.lua @@ -1,9 +1,33 @@ local state = require('opencode.state') local config = require('opencode.config') -local formatter = require('opencode.ui.formatter') local ctx = require('opencode.ui.renderer.ctx') -local buf = require('opencode.ui.renderer.buffer') local permission_window = require('opencode.ui.permission_window') +local flush = require('opencode.ui.renderer.flush') + +local function get_last_part_for_message(message) + if not message or not message.parts or #message.parts == 0 then + return nil + end + for i = #message.parts, 1, -1 do + local part = message.parts[i] + if part.type ~= 'step-start' and part.type ~= 'step-finish' and part.id then + return part.id + end + end + return nil +end + +local function find_text_part_for_message(message) + if not message or not message.parts then + return nil + end + for _, part in ipairs(message.parts) do + if part.type == 'text' and not part.synthetic then + return part.id + end + end + return nil +end -- Lazy require to avoid circular dependency: renderer.lua <-> events.lua local function scroll(force) @@ -33,8 +57,8 @@ end function M.render_permissions_display() local permissions = permission_window.get_all_permissions() if not permissions or #permissions == 0 then - buf.remove_part('permission-display-part') - buf.remove_message('permission-display-message') + flush.queue_part_removal('permission-display-part') + flush.queue_message_removal('permission-display-message') return end @@ -55,7 +79,6 @@ function M.render_permissions_display() type = 'permissions-display', } M.on_part_updated({ part = fake_part }) - scroll(true) end ---Render the current question as a synthetic part at the end of the buffer @@ -69,8 +92,8 @@ function M.render_question_display() local current_question = question_window._current_question if not question_window.has_question() or not current_question or not current_question.id then - buf.remove_part('question-display-part') - buf.remove_message('question-display-message') + flush.queue_part_removal('question-display-part') + flush.queue_message_removal('question-display-message') return end @@ -101,8 +124,8 @@ function M.clear_question_display() question_window.clear_question() if not use_vim_ui then - buf.remove_part('question-display-part') - buf.remove_message('question-display-message') + flush.queue_part_removal('question-display-part') + flush.queue_message_removal('question-display-message') end end @@ -142,17 +165,17 @@ function M.on_message_updated(message, revert_index) -- Re-render the last part (or the header if there are no parts) so the -- error appears in the right place. if error_changed then - local last_part_id = buf.get_last_part_for_message(found_msg) + local last_part_id = get_last_part_for_message(found_msg) if last_part_id then - buf.rerender_part(last_part_id) + flush.mark_part_dirty(last_part_id, msg.info.id) else - local header_data = formatter.format_message_header(found_msg) - buf.replace_message(msg.info.id, header_data) + flush.mark_message_dirty(msg.info.id) end end else table.insert(state.messages, msg) - buf.add_message(msg) + ctx.render_state:set_message(msg) + flush.mark_message_dirty(msg.info.id) state.renderer.set_current_message(msg) if message.info.role == 'user' then state.renderer.set_last_user_message(msg) @@ -182,11 +205,11 @@ function M.on_message_removed(properties) for _, part in ipairs(rendered_message.message.parts or {}) do if part.id then - buf.remove_part(part.id) + flush.queue_part_removal(part.id) end end - buf.remove_message(message_id) + flush.queue_message_removal(message_id) for i, msg in ipairs(state.messages or {}) do if msg.info.id == message_id then @@ -213,7 +236,10 @@ function M.on_part_updated(properties, revert_index) if state.active_session.id ~= part.sessionID then if part.tool or part.type == 'tool' then ctx.render_state:upsert_child_session_part(part.sessionID, part) - buf.rerender_task_tool_for_child_session(part.sessionID) + local task_part_id = ctx.render_state:get_task_part_by_child_session(part.sessionID) + if task_part_id then + flush.mark_part_dirty(task_part_id) + end end return end @@ -230,8 +256,7 @@ function M.on_part_updated(properties, revert_index) local part_data = ctx.render_state:get_part(part.id) local is_new_part = not part_data - local prev_last_part_id = buf.get_last_part_for_message(message) - local is_last_part = is_new_part or (prev_last_part_id == part.id) + local prev_last_part_id = get_last_part_for_message(message) -- Update the part reference in the message if is_new_part then @@ -277,32 +302,27 @@ function M.on_part_updated(properties, revert_index) return end - local formatted = formatter.format_part(part, message, is_last_part, function(session_id) - return ctx.render_state:get_child_session_parts(session_id) - end) - if is_new_part then - buf.insert_part(part.id, formatted) + flush.mark_part_dirty(part.id, part.messageID) -- If there's already an error on this message, adjust adjacent parts so -- the error only appears after the last part. if message.info.error then if not prev_last_part_id then - local header_data = formatter.format_message_header(message) - buf.replace_message(part.messageID, header_data) + flush.mark_message_dirty(part.messageID) elseif prev_last_part_id ~= part.id then - buf.rerender_part(prev_last_part_id) + flush.mark_part_dirty(prev_last_part_id, part.messageID) end end else - buf.replace_part(part.id, formatted) + flush.mark_part_dirty(part.id, part.messageID) end -- File / agent mentions: re-render the text part to highlight them if (part.type == 'file' or part.type == 'agent') and part.source then - local text_part_id = buf.find_text_part_for_message(message) + local text_part_id = find_text_part_for_message(message) if text_part_id then - buf.rerender_part(text_part_id) + flush.mark_part_dirty(text_part_id, part.messageID) end end end @@ -321,8 +341,9 @@ function M.on_part_removed(properties) -- Remove the part from the in-memory message too local cached = ctx.render_state:get_part(part_id) - if cached and cached.message_id then - local rendered_message = ctx.render_state:get_message(cached.message_id) + local message_id = cached and cached.message_id + if message_id then + local rendered_message = ctx.render_state:get_message(message_id) if rendered_message and rendered_message.message and rendered_message.message.parts then for i, part in ipairs(rendered_message.message.parts) do if part.id == part_id then @@ -333,7 +354,12 @@ function M.on_part_removed(properties) end end - buf.remove_part(part_id) + flush.queue_part_removal(part_id) + + -- Mark message dirty so header (timestamp, etc.) gets re-rendered + if message_id then + flush.mark_message_dirty(message_id) + end end ---Handle session.updated — re-render the full session if the revert state changed @@ -429,8 +455,8 @@ function M.on_permission_replied(properties) state.renderer.set_pending_permissions(vim.deepcopy(permission_window.get_all_permissions())) if #state.pending_permissions == 0 then - buf.remove_part('permission-display-part') - buf.remove_message('permission-display-message') + flush.queue_part_removal('permission-display-part') + flush.queue_message_removal('permission-display-message') else M.render_permissions_display() end diff --git a/lua/opencode/ui/renderer/flush.lua b/lua/opencode/ui/renderer/flush.lua new file mode 100644 index 00000000..2ba19555 --- /dev/null +++ b/lua/opencode/ui/renderer/flush.lua @@ -0,0 +1,426 @@ +local state = require('opencode.state') +local config = require('opencode.config') +local formatter = require('opencode.ui.formatter') +local output_window = require('opencode.ui.output_window') +local ctx = require('opencode.ui.renderer.ctx') +local scroll = require('opencode.ui.renderer.scroll') +local buffer = require('opencode.ui.renderer.buffer') +local append = require('opencode.ui.renderer.append') + +local M = {} + +local function lines_equal(a, b) + a = a or {} + b = b or {} + if #a ~= #b then + return false + end + for i = 1, #a do + if a[i] ~= b[i] then + return false + end + end + return true +end + +local function resolve_mark(m) + return type(m) == 'function' and m() or m +end + +local function extmarks_equal(a, b) + a = a or {} + b = b or {} + for k, va in pairs(a) do + local vb = b[k] + if not vb or #va ~= #vb then + return false + end + for i = 1, #va do + if not vim.deep_equal(resolve_mark(va[i]), resolve_mark(vb[i])) then + return false + end + end + end + for k in pairs(b) do + if not a[k] then + return false + end + end + return true +end + +local function is_markdown_render_deferred() + if not config.ui.output.rendering.markdown_on_idle then + return false + end + + local active_session = state.active_session + local session_id = active_session and active_session.id + if not session_id then + return false + end + + local pending = state.user_message_count or {} + local threshold = config.ui.output.rendering.markdown_on_idle_threshold + if type(threshold) == 'number' then + return (pending[session_id] or 0) > threshold + end + return (pending[session_id] or 0) > 0 +end + +local function enqueue_once(order, lookup, id) + if lookup[id] then + return + end + order[#order + 1] = id +end + +local function track_message_for_part(message_id, part_id) + if not message_id or not part_id then + return + end + + local part_ids = ctx.pending.dirty_part_by_message[message_id] + if not part_ids then + part_ids = {} + ctx.pending.dirty_part_by_message[message_id] = part_ids + end + part_ids[part_id] = true +end + +local function untrack_message_for_part(message_id, part_id) + local part_ids = message_id and ctx.pending.dirty_part_by_message[message_id] + if not part_ids then + return + end + part_ids[part_id] = nil + if next(part_ids) == nil then + ctx.pending.dirty_part_by_message[message_id] = nil + end +end + +function M.mark_message_dirty(message_id) + if not message_id then + return + end + ctx.pending.removed_messages[message_id] = nil + enqueue_once(ctx.pending.dirty_message_order, ctx.pending.dirty_messages, message_id) + ctx.pending.dirty_messages[message_id] = true + -- Clear cached formatted data so the message gets fully re-rendered + ctx.formatted_messages[message_id] = nil + M.schedule() +end + +function M.mark_part_dirty(part_id, message_id) + if not part_id then + return + end + + local rendered_part = ctx.render_state:get_part(part_id) + message_id = message_id or (rendered_part and rendered_part.message_id) + if not message_id then + return + end + + ctx.pending.removed_parts[part_id] = nil + enqueue_once(ctx.pending.dirty_part_order, ctx.pending.dirty_parts, part_id) + ctx.pending.dirty_parts[part_id] = message_id + track_message_for_part(message_id, part_id) + M.schedule() +end + +function M.queue_part_removal(part_id) + if not part_id then + return + end + + local rendered_part = ctx.render_state:get_part(part_id) + if rendered_part and rendered_part.message_id then + untrack_message_for_part(rendered_part.message_id, part_id) + end + + ctx.pending.dirty_parts[part_id] = nil + enqueue_once(ctx.pending.removed_part_order, ctx.pending.removed_parts, part_id) + ctx.pending.removed_parts[part_id] = true + ctx.formatted_parts[part_id] = nil + M.schedule() +end + +function M.queue_message_removal(message_id) + if not message_id then + return + end + + ctx.pending.dirty_messages[message_id] = nil + ctx.pending.dirty_part_by_message[message_id] = nil + enqueue_once(ctx.pending.removed_message_order, ctx.pending.removed_messages, message_id) + ctx.pending.removed_messages[message_id] = true + ctx.formatted_messages[message_id] = nil + M.schedule() +end + +function M.schedule() + if ctx.flush_scheduled then + return + end + + ctx.flush_scheduled = true + vim.schedule(function() + ctx.flush_scheduled = false + M.flush() + end) +end + +local function snapshot_pending() + local pending = ctx.pending + ctx.pending = { + dirty_message_order = {}, + dirty_messages = {}, + dirty_part_by_message = {}, + dirty_part_order = {}, + dirty_parts = {}, + removed_part_order = {}, + removed_parts = {}, + removed_message_order = {}, + removed_messages = {}, + } + return pending +end + +local function format_message(message_id) + local rendered_message = ctx.render_state:get_message(message_id) + local message = rendered_message and rendered_message.message + if not message then + return nil + end + + local prev = ctx.formatted_messages[message_id] + local formatted = formatter.format_message_header(message) + + if prev and lines_equal(prev.lines, formatted.lines) and extmarks_equal(prev.extmarks, formatted.extmarks) then + -- no visible change + return nil + end + + ctx.formatted_messages[message_id] = formatted + return formatted +end + +local function format_part(part_id) + local rendered_part = ctx.render_state:get_part(part_id) + if not rendered_part or not rendered_part.part then + return nil + end + + local rendered_message = ctx.render_state:get_message(rendered_part.message_id) + local message = rendered_message and rendered_message.message + if not message then + return nil + end + + local is_last_part = (buffer.get_last_part_for_message(message) == part_id) + local formatted = formatter.format_part(rendered_part.part, message, is_last_part, function(session_id) + return ctx.render_state:get_child_session_parts(session_id) + end) + + return formatted, rendered_part.message_id +end + +local function apply_message(message_id) + local previous = ctx.formatted_messages[message_id] + local formatted = format_message(message_id) + if not formatted then + return + end + buffer.upsert_message_now(message_id, formatted, previous) +end + +local function apply_part(part_id, message_id) + local previous = ctx.formatted_parts[part_id] + local formatted = nil + formatted, message_id = format_part(part_id) + if not formatted or not message_id then + return + end + + local cached = ctx.render_state:get_part(part_id) + local can_append = previous + and cached + and cached.line_start + and cached.line_end + and append.is_append_only(previous.lines or {}, formatted.lines or {}) + + ctx.formatted_parts[part_id] = formatted + ctx.last_part_formatted = { part_id = part_id, formatted_data = formatted } + + if can_append then + buffer.append_part_now( + part_id, + append.tail_lines(previous.lines or {}, formatted.lines or {}), + append.tail_extmarks(#(previous.lines or {}), formatted.extmarks), + previous + ) + return + end + + buffer.upsert_part_now(part_id, message_id, formatted, previous) +end + +local function apply_pending(pending) + local buf = state.windows and state.windows.output_buf + if not buf or not vim.api.nvim_buf_is_valid(buf) then + return false + end + + local has_updates = #pending.removed_part_order > 0 + or #pending.removed_message_order > 0 + or #pending.dirty_message_order > 0 + or #pending.dirty_part_order > 0 + + if not has_updates then + return false + end + + local scroll_snapshot = scroll.pre_flush(buf) + local saved_eventignore = vim.o.eventignore + vim.o.eventignore = 'all' + output_window.begin_update() + + for _, part_id in ipairs(pending.removed_part_order) do + if pending.removed_parts[part_id] then + buffer.remove_part_now(part_id) + end + end + + for _, message_id in ipairs(pending.removed_message_order) do + if pending.removed_messages[message_id] then + buffer.remove_message_now(message_id) + end + end + + for _, message_id in ipairs(pending.dirty_message_order) do + if pending.dirty_messages[message_id] then + apply_message(message_id) + end + + local dirty_parts = pending.dirty_part_by_message[message_id] + if dirty_parts then + local message = ctx.render_state:get_message(message_id) + local parts = message and message.message and message.message.parts or {} + for _, part in ipairs(parts or {}) do + if part.id and dirty_parts[part.id] then + apply_part(part.id, message_id) + dirty_parts[part.id] = nil + pending.dirty_parts[part.id] = nil + end + end + end + end + + for _, part_id in ipairs(pending.dirty_part_order) do + local message_id = pending.dirty_parts[part_id] + if message_id then + apply_part(part_id, message_id) + end + end + + output_window.end_update() + vim.o.eventignore = saved_eventignore + scroll.post_flush(scroll_snapshot, buf) + return true +end + +local function do_trigger_on_data_rendered() + local cb_type = type(config.ui.output.rendering.on_data_rendered) + if cb_type == 'boolean' then + return + end + if not state.windows or not state.windows.output_buf or not state.windows.output_win then + return + end + vim.b[state.windows.output_buf].opencode_markdown_namespace = output_window.markdown_namespace + if cb_type == 'function' then + pcall(config.ui.output.rendering.on_data_rendered, state.windows.output_buf, state.windows.output_win) + elseif vim.fn.exists(':RenderMarkdown') > 0 then + vim.cmd(':RenderMarkdown') + elseif vim.fn.exists(':Markview') > 0 then + vim.cmd(':Markview render ' .. state.windows.output_buf) + end +end + +M.trigger_on_data_rendered = require('opencode.util').debounce(do_trigger_on_data_rendered, config.ui.output.rendering.markdown_debounce_ms or 250) + +function M.request_on_data_rendered(force) + if force or not is_markdown_render_deferred() then + ctx.markdown_render_scheduled = false + M.trigger_on_data_rendered() + return + end + + ctx.markdown_render_scheduled = true +end + +function M.flush_pending_on_data_rendered() + if not ctx.markdown_render_scheduled or is_markdown_render_deferred() then + return + end + + ctx.markdown_render_scheduled = false + M.trigger_on_data_rendered() +end + +function M.begin_bulk_mode() + ctx:bulk_reset() + ctx.bulk_mode = true +end + +function M.end_bulk_mode() + if not ctx.bulk_mode then + return + end + ctx.bulk_mode = false + local lines = ctx.bulk_buffer_lines + if #lines == 0 then + ctx:bulk_reset() + return + end + + -- Add trailing empty line to match non-bulk behavior + table.insert(lines, '') + + local buf = state.windows and state.windows.output_buf + if not buf or not vim.api.nvim_buf_is_valid(buf) then + ctx:bulk_reset() + return + end + + -- Write all lines at once. Suppress autocmds so render-markdown and similar + -- plugins don't fire mid-write; we trigger them explicitly via vim.schedule + -- below. begin_update/end_update handles the modifiable toggle. + local saved_eventignore = vim.o.eventignore + vim.o.eventignore = 'all' + output_window.begin_update() + output_window.set_lines(lines, 0, -1) + output_window.end_update() + vim.o.eventignore = saved_eventignore + + if next(ctx.bulk_extmarks_by_line) then + output_window.set_extmarks(ctx.bulk_extmarks_by_line, 0) + end + + ctx:bulk_reset() + + vim.schedule(function() + M.request_on_data_rendered(true) + end) +end + +function M.flush() + local pending = snapshot_pending() + local applied = apply_pending(pending) + if applied and not ctx.bulk_mode then + M.request_on_data_rendered() + end +end + +return M diff --git a/lua/opencode/ui/renderer/scroll.lua b/lua/opencode/ui/renderer/scroll.lua new file mode 100644 index 00000000..ed0bb8f2 --- /dev/null +++ b/lua/opencode/ui/renderer/scroll.lua @@ -0,0 +1,62 @@ +local config = require('opencode.config') +local state = require('opencode.state') +local output_window = require('opencode.ui.output_window') + +local M = {} + +---@return integer|nil +function M.get_output_win() + local windows = state.windows + local win = windows and windows.output_win + if not win or not vim.api.nvim_win_is_valid(win) then + return nil + end + return win +end + +---Move the cursor in `win` to the last line of `buf` and scroll so it's visible. +---@param win integer +---@param buf integer +function M.scroll_win_to_bottom(win, buf) + local line_count = vim.api.nvim_buf_line_count(buf) + if line_count == 0 then + return + end + local last_line = vim.api.nvim_buf_get_lines(buf, line_count - 1, line_count, false)[1] or '' + vim.api.nvim_win_set_cursor(win, { line_count, #last_line }) + vim.api.nvim_win_call(win, function() + vim.cmd('normal! zb') + end) +end + +---@param buf integer|nil +---@return { win: integer, follow: boolean }|nil +function M.pre_flush(buf) + if not buf or not vim.api.nvim_buf_is_valid(buf) then + return nil + end + + local win = M.get_output_win() + if not win or vim.api.nvim_win_get_buf(win) ~= buf then + return nil + end + + return { + win = win, + follow = output_window.is_at_bottom(win), + } +end + +---@param snapshot { win: integer, follow: boolean }|nil +---@param buf integer|nil +function M.post_flush(snapshot, buf) + if not snapshot or not snapshot.follow or not buf or not vim.api.nvim_buf_is_valid(buf) then + return + end + if not vim.api.nvim_win_is_valid(snapshot.win) or vim.api.nvim_win_get_buf(snapshot.win) ~= buf then + return + end + M.scroll_win_to_bottom(snapshot.win, buf) +end + +return M diff --git a/lua/opencode/ui/ui.lua b/lua/opencode/ui/ui.lua index b99d41e6..6ba7089f 100644 --- a/lua/opencode/ui/ui.lua +++ b/lua/opencode/ui/ui.lua @@ -335,8 +335,14 @@ end function M.create_windows() if config.ui.enable_treesitter_markdown then - vim.treesitter.language.register('markdown', 'opencode_output') - vim.treesitter.language.register('markdown', 'opencode') + local ok, treesitter = pcall(function() + return vim.treesitter + end) + + if ok and treesitter and treesitter.language and treesitter.language.register then + treesitter.language.register('markdown', 'opencode_output') + treesitter.language.register('markdown', 'opencode') + end end local autocmds = require('opencode.ui.autocmds') diff --git a/lua/opencode/util.lua b/lua/opencode/util.lua index cde1b4ee..96e3cc03 100644 --- a/lua/opencode/util.lua +++ b/lua/opencode/util.lua @@ -499,6 +499,17 @@ function M.check_prompt_allowed(guard_callback, mentioned_files) return result, nil end +local _filetype_overrides = { + javascriptreact = 'jsx', + typescriptreact = 'tsx', + typescript = 'ts', + javascipt = 'js', + sh = 'bash', + yaml = 'yml', + text = 'txt', -- nvim 0.12-nightly returns text as the type which breaks our unit tests +} +local _filetype_cache = {} + --- Get the markdown type to use based on the filename. First gets the neovim type --- for the file. Then apply any specific overrides. Falls back to using the file --- extension if nothing else matches @@ -509,27 +520,18 @@ function M.get_markdown_filetype(filename) return '' end - local file_type_overrides = { - javascriptreact = 'jsx', - typescriptreact = 'tsx', - typescript = 'ts', - javascipt = 'js', - sh = 'bash', - yaml = 'yml', - text = 'txt', -- nvim 0.12-nightly returns text as the type which breaks our unit tests - } - - local file_type = vim.filetype.match({ filename = filename }) or '' - - if file_type_overrides[file_type] then - return file_type_overrides[file_type] + local cached = _filetype_cache[filename] + if cached ~= nil then + return cached end - if file_type and file_type ~= '' then - return file_type - end + local file_type = vim.filetype.match({ filename = filename }) or '' + local result = _filetype_overrides[file_type] + or (file_type ~= '' and file_type) + or vim.fn.fnamemodify(filename, ':e') - return vim.fn.fnamemodify(filename, ':e') + _filetype_cache[filename] = result + return result end function M.strdisplaywidth(str) diff --git a/tests/helpers.lua b/tests/helpers.lua index 1ec1a9c6..2bb729ce 100644 --- a/tests/helpers.lua +++ b/tests/helpers.lua @@ -11,12 +11,28 @@ function M.replay_setup() local state = require('opencode.state') local ui = require('opencode.ui.ui') local renderer = require('opencode.ui.renderer') + local permission_window = require('opencode.ui.permission_window') + local question_window = require('opencode.ui.question_window') + local reference_picker = require('opencode.ui.reference_picker') local empty_promise = require('opencode.promise').new():resolve(nil) config_file.config_promise = empty_promise config_file.project_promise = empty_promise config_file.providers_promise = empty_promise + if state.windows then + ui.close_windows(state.windows) + end + + renderer.reset() + permission_window.clear_all() + question_window._clear_dialog() + question_window._current_question = nil + question_window._current_question_index = 1 + question_window._collected_answers = {} + question_window._answering = false + reference_picker.clear_all() + ---@diagnostic disable-next-line: duplicate-set-field require('opencode.session').project_id = function() return nil @@ -323,9 +339,28 @@ function M.normalize_namespace_ids(extmarks) end function M.capture_output(output_buf, namespace) + local extmarks = vim.api.nvim_buf_get_extmarks(output_buf, namespace, 0, -1, { details = true }) or {} + table.sort(extmarks, function(a, b) + if a[2] ~= b[2] then + return a[2] < b[2] + end + + if a[3] ~= b[3] then + return a[3] < b[3] + end + + local a_priority = a[4] and a[4].priority or 0 + local b_priority = b[4] and b[4].priority or 0 + if a_priority ~= b_priority then + return a_priority > b_priority + end + + return a[1] < b[1] + end) + return { lines = vim.api.nvim_buf_get_lines(output_buf, 0, -1, false) or {}, - extmarks = vim.api.nvim_buf_get_extmarks(output_buf, namespace, 0, -1, { details = true }) or {}, + extmarks = extmarks, actions = vim.deepcopy(require('opencode.ui.renderer.ctx').render_state:get_all_actions()), } end diff --git a/tests/replay/renderer_spec.lua b/tests/replay/renderer_spec.lua index ca2fd2e9..fe6db602 100644 --- a/tests/replay/renderer_spec.lua +++ b/tests/replay/renderer_spec.lua @@ -277,20 +277,30 @@ describe('renderer functional tests', function() ) end - if not vim.tbl_contains(skip_full_session, name) then - it('replays ' .. name .. ' correctly (session)', function() - local renderer = require('opencode.ui.renderer') - local events = helpers.load_test_data(filepath) - state.session.set_active(helpers.get_session_from_events(events, true)) - local expected = helpers.load_test_data(expected_path) - - local session_data = helpers.load_session_from_events(events) - renderer._render_full_session_data(session_data) - - local actual = helpers.capture_output(state.windows and state.windows.output_buf, output_window.namespace) - assert_output_matches(expected, actual, name) - end) - end + if not vim.tbl_contains(skip_full_session, name) then + it('replays ' .. name .. ' correctly (session)', function() + local renderer = require('opencode.ui.renderer') + local flush = require('opencode.ui.renderer.flush') + local ctx = require('opencode.ui.renderer.ctx') + local events = helpers.load_test_data(filepath) + state.session.set_active(helpers.get_session_from_events(events, true)) + local expected = helpers.load_test_data(expected_path) + + local session_data = helpers.load_session_from_events(events) + renderer._render_full_session_data(session_data) + + -- If bulk mode is active (async writing), wait for it to complete + -- by forcing synchronous completion + if ctx.bulk_mode then + -- Force synchronous completion by calling end_bulk_mode directly + -- This ensures all content is written before we check + flush.end_bulk_mode() + end + + local actual = helpers.capture_output(state.windows and state.windows.output_buf, output_window.namespace) + assert_output_matches(expected, actual, name) + end) + end end end end diff --git a/tests/unit/core_spec.lua b/tests/unit/core_spec.lua index fbda0982..1bb4e463 100644 --- a/tests/unit/core_spec.lua +++ b/tests/unit/core_spec.lua @@ -7,6 +7,7 @@ local session = require('opencode.session') local Promise = require('opencode.promise') local stub = require('luassert.stub') local assert = require('luassert') +local flush = require('opencode.ui.renderer.flush') -- Provide a mock api_client for tests that need it local function mock_api_client() @@ -431,6 +432,60 @@ describe('opencode.core', function() end) end) + describe('_on_user_message_count_change', function() + it('flushes deferred markdown render when thinking completes', function() + local flush_stub = stub(flush, 'flush_pending_on_data_rendered') + + core._on_user_message_count_change(nil, { sess1 = 0 }, { sess1 = 1 }):wait() + + assert.stub(flush_stub).was_called() + flush_stub:revert() + end) + end) + + describe('markdown rendering metadata', function() + it('stores the markdown namespace on the output buffer before rendering', function() + local output_window = require('opencode.ui.output_window') + local buf = vim.api.nvim_create_buf(false, true) + local win = vim.api.nvim_open_win(buf, false, { + relative = 'editor', + width = 20, + height = 5, + row = 0, + col = 0, + style = 'minimal', + }) + + state.ui.set_windows({ output_buf = buf, output_win = win }) + vim.api.nvim_buf_set_var(buf, 'opencode_markdown_namespace', 0) + + local defer_stub = stub(vim, 'defer_fn').invokes(function(cb) + cb() + return 1 + end) + local original_exists = vim.fn.exists + vim.fn.exists = function(name) + if name == ':RenderMarkdown' then + return 2 + end + return original_exists(name) + end + local cmd_stub = stub(vim, 'cmd') + + flush.trigger_on_data_rendered() + + assert.equals(output_window.markdown_namespace, vim.b[buf].opencode_markdown_namespace) + assert.stub(cmd_stub).was_called_with(':RenderMarkdown') + + cmd_stub:revert() + defer_stub:revert() + vim.fn.exists = original_exists + state.ui.set_windows(nil) + pcall(vim.api.nvim_win_close, win, true) + pcall(vim.api.nvim_buf_delete, buf, { force = true }) + end) + end) + describe('cancel', function() it('aborts running session even when ui is not visible', function() state.ui.set_windows(nil) diff --git a/tests/unit/cursor_tracking_spec.lua b/tests/unit/cursor_tracking_spec.lua index 5c276145..5cbd1131 100644 --- a/tests/unit/cursor_tracking_spec.lua +++ b/tests/unit/cursor_tracking_spec.lua @@ -284,6 +284,76 @@ describe('output_window.is_at_bottom', function() -- This is the key behavioral difference from viewport-based check assert.is_true(output_window.is_at_bottom(win)) end) + + it('reports the actual visible bottom line in wrapped windows', function() + local long_line = string.rep('x', 180) + + vim.api.nvim_win_set_width(win, 20) + vim.api.nvim_set_option_value('wrap', true, { win = win, scope = 'local' }) + vim.api.nvim_buf_set_lines(buf, 0, -1, false, { 'line 1', 'line 2', long_line, 'line 4', 'line 5' }) + vim.api.nvim_win_set_cursor(win, { 5, 0 }) + pcall(vim.api.nvim_win_call, win, function() + vim.fn.winrestview({ topline = 1 }) + end) + + local visible_bottom = output_window.get_visible_bottom_line(win) + assert.equals(3, visible_bottom) + end) +end) + +describe('output_window.sync_cursor_with_viewport', function() + local output_window = require('opencode.ui.output_window') + local buf, win + + before_each(function() + config.setup({}) + buf = vim.api.nvim_create_buf(false, true) + vim.api.nvim_buf_set_lines(buf, 0, -1, false, { + 'line 1', + 'line 2', + string.rep('x', 180), + 'line 4', + 'line 5', + }) + + win = vim.api.nvim_open_win(buf, true, { + relative = 'editor', + width = 20, + height = 5, + row = 0, + col = 0, + }) + + vim.api.nvim_set_option_value('wrap', true, { win = win, scope = 'local' }) + state.ui.set_windows({ output_win = win, output_buf = buf }) + output_window.reset_scroll_tracking(win) + end) + + after_each(function() + output_window.reset_scroll_tracking(win) + pcall(vim.api.nvim_win_close, win, true) + pcall(vim.api.nvim_buf_delete, buf, { force = true }) + state.ui.set_windows(nil) + end) + + it('keeps the cursor aligned with the actual viewport bottom while scrolling', function() + vim.api.nvim_win_set_cursor(win, { 5, 0 }) + pcall(vim.api.nvim_win_call, win, function() + vim.fn.winrestview({ topline = 1 }) + end) + output_window.sync_cursor_with_viewport(win) + + local cursor = vim.api.nvim_win_get_cursor(win) + assert.equals(3, cursor[1]) + end) + + it('does not move the cursor when the user is already reading earlier content', function() + vim.api.nvim_win_set_cursor(win, { 2, 0 }) + output_window.sync_cursor_with_viewport(win) + + local cursor = vim.api.nvim_win_get_cursor(win) + assert.equals(2, cursor[1]) + end) end) describe('renderer.scroll_to_bottom', function() diff --git a/tests/unit/output_window_spec.lua b/tests/unit/output_window_spec.lua index d4a464d9..e99a0cc0 100644 --- a/tests/unit/output_window_spec.lua +++ b/tests/unit/output_window_spec.lua @@ -1,5 +1,7 @@ local config = require('opencode.config') +local state = require('opencode.state') local output_window = require('opencode.ui.output_window') +local stub = require('luassert.stub') describe('output_window.create_buf', function() local original_config @@ -40,3 +42,74 @@ describe('output_window.create_buf', function() pcall(vim.api.nvim_buf_delete, buf, { force = true }) end) end) + +describe('output_window.highlight_changed_lines', function() + local original_config + local buf + local defer_stub + local scheduled_cb + +before_each(function() + original_config = vim.deepcopy(config.values) + config.values = vim.deepcopy(config.defaults) + buf = vim.api.nvim_create_buf(false, true) + state.ui.set_windows({ output_buf = buf }) + vim.api.nvim_buf_clear_namespace(buf, output_window.debug_namespace, 0, -1) + scheduled_cb = nil + defer_stub = stub(vim, 'defer_fn').invokes(function(cb) + scheduled_cb = cb + return 1 + end) + end) + + after_each(function() + if defer_stub then + defer_stub:revert() + end + state.ui.set_windows(nil) + pcall(vim.api.nvim_buf_delete, buf, { force = true }) + config.values = original_config + end) + + it('adds and clears debug line highlights when enabled', function() + config.setup({ + debug = { + highlight_changed_lines = true, + highlight_changed_lines_timeout_ms = 500, + }, + }) + + output_window.highlight_changed_lines(0, 1) + + local marks = vim.api.nvim_buf_get_extmarks(buf, output_window.debug_namespace, 0, -1, { details = true }) + assert.equals(2, #marks) + assert.equals('OpencodeChangedLines', marks[1][4].line_hl_group) + assert.is_function(scheduled_cb) + + scheduled_cb() + + local cleared = vim.api.nvim_buf_get_extmarks(buf, output_window.debug_namespace, 0, -1, {}) + assert.equals(0, #cleared) + end) + + it('does nothing when debug highlights are disabled', function() + config.setup({ + debug = { + highlight_changed_lines = false, + }, + }) + + output_window.highlight_changed_lines(0, 1) + + local marks = vim.api.nvim_buf_get_extmarks(buf, output_window.debug_namespace, 0, -1, {}) + assert.equals(0, #marks) + end) +end) + +describe('output_window namespaces', function() + it('exposes a dedicated markdown namespace', function() + assert.is_number(output_window.markdown_namespace) + assert.is_not.equals(output_window.namespace, output_window.markdown_namespace) + assert.is_not.equals(output_window.debug_namespace, output_window.markdown_namespace) + end) +end) diff --git a/tests/unit/renderer_buffer_spec.lua b/tests/unit/renderer_buffer_spec.lua new file mode 100644 index 00000000..77729929 --- /dev/null +++ b/tests/unit/renderer_buffer_spec.lua @@ -0,0 +1,52 @@ +local buffer = require('opencode.ui.renderer.buffer') +local ctx = require('opencode.ui.renderer.ctx') +local output_window = require('opencode.ui.output_window') +local stub = require('luassert.stub') + +describe('renderer.buffer extmarks', function() + local set_lines_stub + local clear_extmarks_stub + local set_extmarks_stub + local highlight_changed_lines_stub + + before_each(function() + ctx:reset() + set_lines_stub = stub(output_window, 'set_lines') + clear_extmarks_stub = stub(output_window, 'clear_extmarks') + set_extmarks_stub = stub(output_window, 'set_extmarks') + highlight_changed_lines_stub = stub(output_window, 'highlight_changed_lines') + end) + + after_each(function() + set_lines_stub:revert() + clear_extmarks_stub:revert() + set_extmarks_stub:revert() + highlight_changed_lines_stub:revert() + ctx:reset() + end) + + it('reapplies extmarks on the first changed line when updating a part', function() + ctx.render_state:set_part({ id = 'part_1', messageID = 'msg_1', type = 'text' }, 10, 11) + + buffer.upsert_part_now('part_1', 'msg_1', { + lines = { 'alpha', 'gamma' }, + extmarks = { + [1] = { + { line_hl_group = 'OpencodeReasoningText' }, + }, + }, + actions = {}, + }, { + lines = { 'alpha', 'beta' }, + extmarks = {}, + actions = {}, + }) + + assert.stub(clear_extmarks_stub).was_called_with(11, 12) + assert.stub(set_extmarks_stub).was_called_with({ + [0] = { + { line_hl_group = 'OpencodeReasoningText' }, + }, + }, 11) + end) +end)