-
Notifications
You must be signed in to change notification settings - Fork 53
Expand file tree
/
Copy pathvariant_picker.lua
More file actions
107 lines (92 loc) · 3.04 KB
/
variant_picker.lua
File metadata and controls
107 lines (92 loc) · 3.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
local M = {}
local base_picker = require('opencode.ui.base_picker')
local state = require('opencode.state')
local config = require('opencode.config')
local config_file = require('opencode.config_file')
local model_state = require('opencode.model_state')
local util = require('opencode.util')
---Collect the variants declared for the currently selected model.
---Reads `state.current_model` (split as "provider/model" on the first slash),
---looks up that model's metadata via `config_file.get_model_info`, and turns
---each entry of its `variants` table into a `{ name = ..., config = ... }` item.
---Items are ordered with `util.sort_by_priority`, ranking low < medium < high.
---@return table[] variants Array of variant items; empty when no model is set,
---the identifier cannot be split, or the model declares no variants
local function get_current_model_variants()
  local current = state.current_model
  if not current then
    return {}
  end

  local provider_id, model_id = current:match('^(.-)/(.+)$')
  if not (provider_id and model_id) then
    return {}
  end

  local info = config_file.get_model_info(provider_id, model_id)
  local declared = info and info.variants
  if not declared then
    return {}
  end

  local items = {}
  for name, cfg in pairs(declared) do
    items[#items + 1] = { name = name, config = cfg }
  end

  local function by_name(entry)
    return entry.name
  end
  local rank = { low = 1, medium = 2, high = 3 }
  util.sort_by_priority(items, by_name, rank)

  return items
end
---Split `state.current_model` into its provider and model parts.
---Splits on the first "/" (pattern '^(.-)/(.+)$').
---@return string? provider nil when no model is set or the identifier has no "/"
---@return string? model
local function split_current_model()
  local current = state.current_model
  if not current then
    return nil, nil
  end
  return current:match('^(.-)/(.+)$')
end

---Report whether `name` matches the `name` field of any item in `variants`.
---@param variants table[] Items as returned by get_current_model_variants
---@param name string Variant name to look for
---@return boolean
local function has_variant(variants, name)
  for _, item in ipairs(variants) do
    if item.name == name then
      return true
    end
  end
  return false
end

---Show a picker listing the variants of the current model.
---
---If the model declares no variants, a warning is shown and `callback(nil)` is
---invoked. Otherwise, when no variant is active yet, the variant persisted in
---`model_state` for this model is restored first, but only if the model still
---declares it. The picker then marks the active variant with "*". Choosing an
---entry applies it via `state.model.set_variant` and persists it with
---`model_state.set_variant`.
---@param callback fun(selection: table?)? Receives the selected item, or nil
---  (nil when there are no variants; presumably also when the picker is dismissed)
function M.select(callback)
  local variants = get_current_model_variants()
  if #variants == 0 then
    vim.notify('Current model does not support variants', vim.log.levels.WARN)
    if callback then
      callback(nil)
    end
    return
  end

  -- Restore the saved variant when none is active. Skip it if the model no
  -- longer declares that variant (e.g. its config changed), so a stale name
  -- is never applied to state merely by opening the picker.
  if not state.current_variant then
    local provider, model = split_current_model()
    if provider and model then
      local saved_variant = model_state.get_variant(provider, model)
      if saved_variant and has_variant(variants, saved_variant) then
        state.model.set_variant(saved_variant)
      end
    end
  end

  base_picker.pick({
    title = 'Select variant',
    items = variants,
    layout_opts = config.ui.picker,
    format_fn = function(item, width)
      -- Fall back to the current window width when the picker gives none.
      local item_width = width or vim.api.nvim_win_get_width(0)
      local is_current = state.current_variant == item.name
      local current_indicator = is_current and '*' or ' '
      -- Reserve the indicator's display width so the name fits the row.
      local name_width = item_width - vim.api.nvim_strwidth(current_indicator)
      return base_picker.create_picker_item({
        {
          text = current_indicator,
          highlight = is_current and 'OpencodeContextSwitchOn' or 'OpencodeHint',
        },
        {
          text = base_picker.align(item.name, name_width, { truncate = true }),
          highlight = is_current and 'OpencodeContextSwitchOn' or nil,
        },
      })
    end,
    actions = {},
    callback = function(selection)
      if selection and state.current_model then
        state.model.set_variant(selection.name)
        -- Persist the choice per provider/model pair.
        local provider, model = split_current_model()
        if provider and model then
          model_state.set_variant(provider, model, selection.name)
        end
      end
      if callback then
        callback(selection)
      end
    end,
  })
end
return M