From 0583b360ceb606b8e58f12a17f588bd27feeb117 Mon Sep 17 00:00:00 2001 From: Paul Buetow Date: Fri, 26 Sep 2025 19:34:19 +0300 Subject: Add per-surface provider overrides and wiring --- internal/hexaiaction/prompts.go | 86 +++++++++++++++++++++++++------ internal/hexaiaction/prompts_more_test.go | 33 ++++++++++++ internal/hexaiaction/run.go | 9 +++- 3 files changed, 111 insertions(+), 17 deletions(-) (limited to 'internal/hexaiaction') diff --git a/internal/hexaiaction/prompts.go b/internal/hexaiaction/prompts.go index 207302e..47dadbf 100644 --- a/internal/hexaiaction/prompts.go +++ b/internal/hexaiaction/prompts.go @@ -25,6 +25,11 @@ type chatDoer interface { type providerNamer interface{ Name() string } +type requestArgs struct { + model string + options []llm.RequestOption +} + func providerOf(c any) string { if n, ok := c.(providerNamer); ok { return n.Name() @@ -32,6 +37,42 @@ func providerOf(c any) string { return "llm" } +func canonicalProvider(name string) string { + p := strings.ToLower(strings.TrimSpace(name)) + if p == "" { + return "openai" + } + return p +} + +func defaultModelForProvider(cfg appconfig.App, provider string) string { + switch provider { + case "ollama": + return cfg.OllamaModel + case "copilot": + return cfg.CopilotModel + default: + return cfg.OpenAIModel + } +} + +func selectActionTemperature(cfg appconfig.App, provider, model string) (float64, bool) { + if cfg.CodeActionTemperature != nil { + return *cfg.CodeActionTemperature, true + } + if cfg.CodingTemperature != nil { + temp := *cfg.CodingTemperature + if provider == "openai" && strings.HasPrefix(strings.ToLower(model), "gpt-5") && temp == 0.2 { + temp = 1.0 + } + return temp, true + } + if provider == "openai" && strings.HasPrefix(strings.ToLower(model), "gpt-5") { + return 1.0, true + } + return 0, false +} + func runRewrite(ctx context.Context, cfg appconfig.App, client chatDoer, instruction, selection string) (string, error) { sys := cfg.PromptCodeActionRewriteSystem user := Render(cfg.PromptCodeActionRewriteUser, map[string]string{"instruction": instruction, "selection": selection}) @@ -118,9 +159,9 @@ func runOnce(ctx context.Context, client chatDoer, sys, user string) (string, er return out, nil } -func runOnceWithOpts(ctx context.Context, client chatDoer, sys, user string, opts []llm.RequestOption) (string, error) { +func runOnceWithOpts(ctx context.Context, client chatDoer, sys, user string, req requestArgs) (string, error) { msgs := []llm.Message{{Role: "system", Content: sys}, {Role: "user", Content: user}} - txt, err := client.Chat(ctx, msgs, opts...) + txt, err := client.Chat(ctx, msgs, req.options...) if err != nil { return "", err } @@ -131,7 +172,11 @@ func runOnceWithOpts(ctx context.Context, client chatDoer, sys, user string, opt sent += len(m.Content) } recv := len(out) - _ = stats.Update(ctx, providerOf(client), client.DefaultModel(), sent, recv) + model := strings.TrimSpace(req.model) + if model == "" { + model = client.DefaultModel() + } + _ = stats.Update(ctx, providerOf(client), model, sent, recv) if snap, err := stats.TakeSnapshot(); err == nil { minsWin := snap.Window.Minutes() if minsWin <= 0 { @@ -139,30 +184,39 @@ func runOnceWithOpts(ctx context.Context, client chatDoer, sys, user string, opt } scopeReqs := int64(0) if pe, ok := snap.Providers[providerOf(client)]; ok { - if mc, ok2 := pe.Models[client.DefaultModel()]; ok2 { + if mc, ok2 := pe.Models[model]; ok2 { scopeReqs = mc.Reqs } } scopeRPM := float64(scopeReqs) / minsWin - _ = tmux.SetStatus(tmux.FormatGlobalStatusColored(snap.Global.Reqs, snap.RPM, snap.Global.Sent, snap.Global.Recv, providerOf(client), client.DefaultModel(), scopeRPM, scopeReqs, snap.Window)) + _ = tmux.SetStatus(tmux.FormatGlobalStatusColored(snap.Global.Reqs, snap.RPM, snap.Global.Sent, snap.Global.Recv, providerOf(client), model, scopeRPM, scopeReqs, snap.Window)) } return out, nil } // reqOptsFrom builds LLM request options similar to LSP behavior. -func reqOptsFrom(cfg appconfig.App) []llm.RequestOption { - opts := []llm.RequestOption{llm.WithMaxTokens(cfg.MaxTokens)} - // Apply temperature, with special-case for gpt-5 (default temp must be 1.0) - if cfg.CodingTemperature != nil { - temp := *cfg.CodingTemperature - prov := strings.ToLower(strings.TrimSpace(cfg.Provider)) - model := strings.ToLower(strings.TrimSpace(cfg.OpenAIModel)) - if prov == "openai" && strings.HasPrefix(model, "gpt-5") { - temp = 1.0 - } +func reqOptsFrom(cfg appconfig.App) requestArgs { + opts := make([]llm.RequestOption, 0, 3) + if cfg.MaxTokens > 0 { + opts = append(opts, llm.WithMaxTokens(cfg.MaxTokens)) + } + provider := canonicalProvider(cfg.Provider) + if strings.TrimSpace(cfg.CodeActionProvider) != "" { + provider = canonicalProvider(cfg.CodeActionProvider) + } + override := strings.TrimSpace(cfg.CodeActionModel) + fallback := strings.TrimSpace(defaultModelForProvider(cfg, provider)) + effective := override + if effective == "" { + effective = fallback + } + if override != "" { + opts = append(opts, llm.WithModel(override)) + } + if temp, ok := selectActionTemperature(cfg, provider, effective); ok { opts = append(opts, llm.WithTemperature(temp)) } - return opts + return requestArgs{model: effective, options: opts} } // Timeout helpers to mirror LSP behavior. diff --git a/internal/hexaiaction/prompts_more_test.go b/internal/hexaiaction/prompts_more_test.go index 9f5d6cb..97d3979 100644 --- a/internal/hexaiaction/prompts_more_test.go +++ b/internal/hexaiaction/prompts_more_test.go @@ -5,6 +5,7 @@ import ( "strings" "testing" + "codeberg.org/snonux/hexai/internal/appconfig" "codeberg.org/snonux/hexai/internal/llm" ) @@ -15,6 +16,11 @@ func (d simpleDoer) Chat(_ context.Context, _ []llm.Message, _ ...llm.RequestOpt } func (d simpleDoer) DefaultModel() string { return "m" } +func ptrFloat(v float64) *float64 { + x := v + return &x +} + func TestRunOnce_StripsFences(t *testing.T) { got, err := runOnce(context.Background(), simpleDoer{"```\nok\n```"}, "SYS", "USER") if err != nil { @@ -24,3 +30,30 @@ func TestRunOnce_StripsFences(t *testing.T) { t.Fatalf("got %q", got) } } + +func TestReqOptsFrom_Override(t *testing.T) { + cfg := appconfig.App{MaxTokens: 123, CodeActionModel: "override", CodeActionTemperature: ptrFloat(0.6), Provider: "openai", CodeActionProvider: "copilot", CopilotModel: "gpt-4o"} + req := reqOptsFrom(cfg) + if req.model != "override" { + t.Fatalf("expected override model, got %q", req.model) + } + var opts llm.Options + for _, o := range req.options { + o(&opts) + } + if opts.MaxTokens != 123 || opts.Model != "override" || opts.Temperature != 0.6 { + t.Fatalf("unexpected options: %+v", opts) + } +} + +func TestReqOptsFrom_Gpt5Temp(t *testing.T) { + cfg := appconfig.App{Provider: "openai", CodingTemperature: ptrFloat(0.2), OpenAIModel: "gpt-5.0"} + req := reqOptsFrom(cfg) + var opts llm.Options + for _, o := range req.options { + o(&opts) + } + if opts.Temperature != 1.0 { + t.Fatalf("expected gpt-5 temp adjustment to 1.0, got %v", opts.Temperature) + } +} diff --git a/internal/hexaiaction/run.go b/internal/hexaiaction/run.go index a48bf94..953da80 100644 --- a/internal/hexaiaction/run.go +++ b/internal/hexaiaction/run.go @@ -41,12 +41,19 @@ func Run(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer) error { if len(cfg.CustomActions) > 0 { chooseActionFn = func() (ActionKind, error) { return RunTUIWithCustom(cfg.CustomActions, cfg.TmuxCustomMenuHotkey) } } + if providerOverride := strings.TrimSpace(cfg.CodeActionProvider); providerOverride != "" { + cfg.Provider = providerOverride + } cli, err := newClientFromApp(cfg) if err != nil { fmt.Fprintf(stderr, logging.AnsiBase+"hexai-tmux-action: LLM disabled: %v"+logging.AnsiReset+"\n", err) return err } - _ = tmux.SetStatus(tmux.FormatLLMStartStatus(cli.Name(), cli.DefaultModel())) + primaryModel := strings.TrimSpace(reqOptsFrom(cfg).model) + if primaryModel == "" { + primaryModel = cli.DefaultModel() + } + _ = tmux.SetStatus(tmux.FormatLLMStartStatus(cli.Name(), primaryModel)) var client chatDoer = cli parts, err := ParseInput(stdin) if err != nil { -- cgit v1.2.3