summaryrefslogtreecommitdiff
path: root/internal/hexaiaction
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2025-09-26 19:34:19 +0300
committerPaul Buetow <paul@buetow.org>2025-09-26 19:34:19 +0300
commit0583b360ceb606b8e58f12a17f588bd27feeb117 (patch)
treeae8ac0d7968a409a76d18d84e080d02da52ce775 /internal/hexaiaction
parent869c018a7a26285263cf7692f25f6aa44e2635c9 (diff)
Add per-surface provider overrides and wiring
Diffstat (limited to 'internal/hexaiaction')
-rw-r--r--internal/hexaiaction/prompts.go86
-rw-r--r--internal/hexaiaction/prompts_more_test.go33
-rw-r--r--internal/hexaiaction/run.go9
3 files changed, 111 insertions, 17 deletions
diff --git a/internal/hexaiaction/prompts.go b/internal/hexaiaction/prompts.go
index 207302e..47dadbf 100644
--- a/internal/hexaiaction/prompts.go
+++ b/internal/hexaiaction/prompts.go
@@ -25,6 +25,11 @@ type chatDoer interface {
type providerNamer interface{ Name() string }
+type requestArgs struct {
+ model string
+ options []llm.RequestOption
+}
+
func providerOf(c any) string {
if n, ok := c.(providerNamer); ok {
return n.Name()
@@ -32,6 +37,42 @@ func providerOf(c any) string {
return "llm"
}
+func canonicalProvider(name string) string {
+ p := strings.ToLower(strings.TrimSpace(name))
+ if p == "" {
+ return "openai"
+ }
+ return p
+}
+
+func defaultModelForProvider(cfg appconfig.App, provider string) string {
+ switch provider {
+ case "ollama":
+ return cfg.OllamaModel
+ case "copilot":
+ return cfg.CopilotModel
+ default:
+ return cfg.OpenAIModel
+ }
+}
+
+func selectActionTemperature(cfg appconfig.App, provider, model string) (float64, bool) {
+ if cfg.CodeActionTemperature != nil {
+ return *cfg.CodeActionTemperature, true
+ }
+ if cfg.CodingTemperature != nil {
+ temp := *cfg.CodingTemperature
+ if provider == "openai" && strings.HasPrefix(strings.ToLower(model), "gpt-5") && temp == 0.2 {
+ temp = 1.0
+ }
+ return temp, true
+ }
+ if provider == "openai" && strings.HasPrefix(strings.ToLower(model), "gpt-5") {
+ return 1.0, true
+ }
+ return 0, false
+}
+
func runRewrite(ctx context.Context, cfg appconfig.App, client chatDoer, instruction, selection string) (string, error) {
sys := cfg.PromptCodeActionRewriteSystem
user := Render(cfg.PromptCodeActionRewriteUser, map[string]string{"instruction": instruction, "selection": selection})
@@ -118,9 +159,9 @@ func runOnce(ctx context.Context, client chatDoer, sys, user string) (string, er
return out, nil
}
-func runOnceWithOpts(ctx context.Context, client chatDoer, sys, user string, opts []llm.RequestOption) (string, error) {
+func runOnceWithOpts(ctx context.Context, client chatDoer, sys, user string, req requestArgs) (string, error) {
msgs := []llm.Message{{Role: "system", Content: sys}, {Role: "user", Content: user}}
- txt, err := client.Chat(ctx, msgs, opts...)
+ txt, err := client.Chat(ctx, msgs, req.options...)
if err != nil {
return "", err
}
@@ -131,7 +172,11 @@ func runOnceWithOpts(ctx context.Context, client chatDoer, sys, user string, opt
sent += len(m.Content)
}
recv := len(out)
- _ = stats.Update(ctx, providerOf(client), client.DefaultModel(), sent, recv)
+ model := strings.TrimSpace(req.model)
+ if model == "" {
+ model = client.DefaultModel()
+ }
+ _ = stats.Update(ctx, providerOf(client), model, sent, recv)
if snap, err := stats.TakeSnapshot(); err == nil {
minsWin := snap.Window.Minutes()
if minsWin <= 0 {
@@ -139,30 +184,39 @@ func runOnceWithOpts(ctx context.Context, client chatDoer, sys, user string, opt
}
scopeReqs := int64(0)
if pe, ok := snap.Providers[providerOf(client)]; ok {
- if mc, ok2 := pe.Models[client.DefaultModel()]; ok2 {
+ if mc, ok2 := pe.Models[model]; ok2 {
scopeReqs = mc.Reqs
}
}
scopeRPM := float64(scopeReqs) / minsWin
- _ = tmux.SetStatus(tmux.FormatGlobalStatusColored(snap.Global.Reqs, snap.RPM, snap.Global.Sent, snap.Global.Recv, providerOf(client), client.DefaultModel(), scopeRPM, scopeReqs, snap.Window))
+ _ = tmux.SetStatus(tmux.FormatGlobalStatusColored(snap.Global.Reqs, snap.RPM, snap.Global.Sent, snap.Global.Recv, providerOf(client), model, scopeRPM, scopeReqs, snap.Window))
}
return out, nil
}
// reqOptsFrom builds LLM request options similar to LSP behavior.
-func reqOptsFrom(cfg appconfig.App) []llm.RequestOption {
- opts := []llm.RequestOption{llm.WithMaxTokens(cfg.MaxTokens)}
- // Apply temperature, with special-case for gpt-5 (default temp must be 1.0)
- if cfg.CodingTemperature != nil {
- temp := *cfg.CodingTemperature
- prov := strings.ToLower(strings.TrimSpace(cfg.Provider))
- model := strings.ToLower(strings.TrimSpace(cfg.OpenAIModel))
- if prov == "openai" && strings.HasPrefix(model, "gpt-5") {
- temp = 1.0
- }
+func reqOptsFrom(cfg appconfig.App) requestArgs {
+ opts := make([]llm.RequestOption, 0, 3)
+ if cfg.MaxTokens > 0 {
+ opts = append(opts, llm.WithMaxTokens(cfg.MaxTokens))
+ }
+ provider := canonicalProvider(cfg.Provider)
+ if strings.TrimSpace(cfg.CodeActionProvider) != "" {
+ provider = canonicalProvider(cfg.CodeActionProvider)
+ }
+ override := strings.TrimSpace(cfg.CodeActionModel)
+ fallback := strings.TrimSpace(defaultModelForProvider(cfg, provider))
+ effective := override
+ if effective == "" {
+ effective = fallback
+ }
+ if override != "" {
+ opts = append(opts, llm.WithModel(override))
+ }
+ if temp, ok := selectActionTemperature(cfg, provider, effective); ok {
opts = append(opts, llm.WithTemperature(temp))
}
- return opts
+ return requestArgs{model: effective, options: opts}
}
// Timeout helpers to mirror LSP behavior.
diff --git a/internal/hexaiaction/prompts_more_test.go b/internal/hexaiaction/prompts_more_test.go
index 9f5d6cb..97d3979 100644
--- a/internal/hexaiaction/prompts_more_test.go
+++ b/internal/hexaiaction/prompts_more_test.go
@@ -5,6 +5,7 @@ import (
"strings"
"testing"
+ "codeberg.org/snonux/hexai/internal/appconfig"
"codeberg.org/snonux/hexai/internal/llm"
)
@@ -15,6 +16,11 @@ func (d simpleDoer) Chat(_ context.Context, _ []llm.Message, _ ...llm.RequestOpt
}
func (d simpleDoer) DefaultModel() string { return "m" }
+func ptrFloat(v float64) *float64 {
+ x := v
+ return &x
+}
+
func TestRunOnce_StripsFences(t *testing.T) {
got, err := runOnce(context.Background(), simpleDoer{"```\nok\n```"}, "SYS", "USER")
if err != nil {
@@ -24,3 +30,30 @@ func TestRunOnce_StripsFences(t *testing.T) {
t.Fatalf("got %q", got)
}
}
+
+func TestReqOptsFrom_Override(t *testing.T) {
+ cfg := appconfig.App{MaxTokens: 123, CodeActionModel: "override", CodeActionTemperature: ptrFloat(0.6), Provider: "openai", CodeActionProvider: "copilot", CopilotModel: "gpt-4o"}
+ req := reqOptsFrom(cfg)
+ if req.model != "override" {
+ t.Fatalf("expected override model, got %q", req.model)
+ }
+ var opts llm.Options
+ for _, o := range req.options {
+ o(&opts)
+ }
+ if opts.MaxTokens != 123 || opts.Model != "override" || opts.Temperature != 0.6 {
+ t.Fatalf("unexpected options: %+v", opts)
+ }
+}
+
+func TestReqOptsFrom_Gpt5Temp(t *testing.T) {
+ cfg := appconfig.App{Provider: "openai", CodingTemperature: ptrFloat(0.2), OpenAIModel: "gpt-5.0"}
+ req := reqOptsFrom(cfg)
+ var opts llm.Options
+ for _, o := range req.options {
+ o(&opts)
+ }
+ if opts.Temperature != 1.0 {
+ t.Fatalf("expected gpt-5 temp adjustment to 1.0, got %v", opts.Temperature)
+ }
+}
diff --git a/internal/hexaiaction/run.go b/internal/hexaiaction/run.go
index a48bf94..953da80 100644
--- a/internal/hexaiaction/run.go
+++ b/internal/hexaiaction/run.go
@@ -41,12 +41,19 @@ func Run(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer) error {
if len(cfg.CustomActions) > 0 {
chooseActionFn = func() (ActionKind, error) { return RunTUIWithCustom(cfg.CustomActions, cfg.TmuxCustomMenuHotkey) }
}
+ if providerOverride := strings.TrimSpace(cfg.CodeActionProvider); providerOverride != "" {
+ cfg.Provider = providerOverride
+ }
cli, err := newClientFromApp(cfg)
if err != nil {
fmt.Fprintf(stderr, logging.AnsiBase+"hexai-tmux-action: LLM disabled: %v"+logging.AnsiReset+"\n", err)
return err
}
- _ = tmux.SetStatus(tmux.FormatLLMStartStatus(cli.Name(), cli.DefaultModel()))
+ primaryModel := strings.TrimSpace(reqOptsFrom(cfg).model)
+ if primaryModel == "" {
+ primaryModel = cli.DefaultModel()
+ }
+ _ = tmux.SetStatus(tmux.FormatLLMStartStatus(cli.Name(), primaryModel))
var client chatDoer = cli
parts, err := ParseInput(stdin)
if err != nil {