summaryrefslogtreecommitdiff
path: root/internal/hexaicli/run.go
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2025-09-26 19:34:19 +0300
committerPaul Buetow <paul@buetow.org>2025-09-26 19:34:19 +0300
commit0583b360ceb606b8e58f12a17f588bd27feeb117 (patch)
treeae8ac0d7968a409a76d18d84e080d02da52ce775 /internal/hexaicli/run.go
parent869c018a7a26285263cf7692f25f6aa44e2635c9 (diff)
Add per-surface provider overrides and wiring
Diffstat (limited to 'internal/hexaicli/run.go')
-rw-r--r--internal/hexaicli/run.go119
1 files changed, 105 insertions, 14 deletions
diff --git a/internal/hexaicli/run.go b/internal/hexaicli/run.go
index 11e8938..b965261 100644
--- a/internal/hexaicli/run.go
+++ b/internal/hexaicli/run.go
@@ -20,6 +20,84 @@ import (
"codeberg.org/snonux/hexai/internal/tmux"
)
+type requestArgs struct {
+ model string
+ options []llm.RequestOption
+}
+
+func buildCLIRequestArgs(cfg appconfig.App, client llm.Client) requestArgs {
+ provider := canonicalProvider(cfg.Provider)
+ if strings.TrimSpace(cfg.CLIProvider) != "" {
+ provider = canonicalProvider(cfg.CLIProvider)
+ }
+ if client != nil {
+ provider = strings.ToLower(strings.TrimSpace(client.Name()))
+ }
+ override := strings.TrimSpace(cfg.CLIModel)
+ fallback := strings.TrimSpace(defaultModelForProvider(cfg, provider))
+ if client != nil {
+ if dm := strings.TrimSpace(client.DefaultModel()); dm != "" {
+ fallback = dm
+ }
+ }
+ effective := override
+ if effective == "" {
+ effective = fallback
+ }
+ opts := make([]llm.RequestOption, 0, 2)
+ if override != "" {
+ opts = append(opts, llm.WithModel(override))
+ }
+ if temp, ok := cliTemperature(cfg, provider, effective); ok {
+ opts = append(opts, llm.WithTemperature(temp))
+ }
+ return requestArgs{model: effective, options: opts}
+}
+
+func defaultRequestArgs(cfg appconfig.App, client llm.Client) requestArgs {
+ model := strings.TrimSpace(cfg.CLIModel)
+ if model == "" && client != nil {
+ model = strings.TrimSpace(client.DefaultModel())
+ }
+ return requestArgs{model: model}
+}
+
+func cliTemperature(cfg appconfig.App, provider, model string) (float64, bool) {
+ if cfg.CLITemperature != nil {
+ return *cfg.CLITemperature, true
+ }
+ if cfg.CodingTemperature != nil {
+ temp := *cfg.CodingTemperature
+ if provider == "openai" && strings.HasPrefix(strings.ToLower(model), "gpt-5") && temp == 0.2 {
+ temp = 1.0
+ }
+ return temp, true
+ }
+ if provider == "openai" && strings.HasPrefix(strings.ToLower(model), "gpt-5") {
+ return 1.0, true
+ }
+ return 0, false
+}
+
+func canonicalProvider(name string) string {
+ p := strings.ToLower(strings.TrimSpace(name))
+ if p == "" {
+ return "openai"
+ }
+ return p
+}
+
+func defaultModelForProvider(cfg appconfig.App, provider string) string {
+ switch provider {
+ case "ollama":
+ return cfg.OllamaModel
+ case "copilot":
+ return cfg.CopilotModel
+ default:
+ return cfg.OpenAIModel
+ }
+}
+
// Run executes the Hexai CLI behavior given arguments and I/O streams.
// It assumes flags have already been parsed by the caller.
func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.Writer) error {
@@ -29,11 +107,16 @@ func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.
if cfg.StatsWindowMinutes > 0 {
stats.SetWindow(time.Duration(cfg.StatsWindowMinutes) * time.Minute)
}
+ providerOverride := strings.TrimSpace(cfg.CLIProvider)
+ if providerOverride != "" {
+ cfg.Provider = providerOverride
+ }
client, err := newClientFromApp(cfg)
if err != nil {
fmt.Fprintf(stderr, logging.AnsiBase+"hexai: LLM disabled: %v"+logging.AnsiReset+"\n", err)
return err
}
+ req := buildCLIRequestArgs(cfg, client)
// Prefer piped stdin when present; only open the editor when there are no args
// and no stdin content available.
input, rerr := readInput(stdin, args)
@@ -47,9 +130,9 @@ func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.
fmt.Fprintln(stderr, logging.AnsiBase+rerr.Error()+logging.AnsiReset)
return rerr
}
- printProviderInfo(stderr, client)
+ printProviderInfo(stderr, client, req.model)
msgs := buildMessagesFromConfig(cfg, input)
- if err := runChat(ctx, client, msgs, input, stdout, stderr); err != nil {
+ if err := runChat(ctx, client, req, msgs, input, stdout, stderr); err != nil {
fmt.Fprintf(stderr, logging.AnsiBase+"hexai: error: %v"+logging.AnsiReset+"\n", err)
return err
}
@@ -64,9 +147,10 @@ func RunWithClient(ctx context.Context, args []string, stdin io.Reader, stdout,
fmt.Fprintln(stderr, logging.AnsiBase+err.Error()+logging.AnsiReset)
return err
}
- printProviderInfo(stderr, client)
+ req := defaultRequestArgs(appconfig.App{}, client)
+ printProviderInfo(stderr, client, req.model)
msgs := buildMessages(input)
- if err := runChat(ctx, client, msgs, input, stdout, stderr); err != nil {
+ if err := runChat(ctx, client, req, msgs, input, stdout, stderr); err != nil {
fmt.Fprintf(stderr, logging.AnsiBase+"hexai: error: %v"+logging.AnsiReset+"\n", err)
return err
}
@@ -128,22 +212,26 @@ func buildMessagesFromConfig(cfg appconfig.App, input string) []llm.Message {
}
// runChat executes the chat request, handling streaming and summary output.
-func runChat(ctx context.Context, client llm.Client, msgs []llm.Message, input string, out io.Writer, errw io.Writer) error {
+func runChat(ctx context.Context, client llm.Client, req requestArgs, msgs []llm.Message, input string, out io.Writer, errw io.Writer) error {
start := time.Now()
// Best-effort tmux status update (colored start heartbeat)
- _ = tmux.SetStatus(tmux.FormatLLMStartStatus(client.Name(), client.DefaultModel()))
+ model := strings.TrimSpace(req.model)
+ if model == "" {
+ model = client.DefaultModel()
+ }
+ _ = tmux.SetStatus(tmux.FormatLLMStartStatus(client.Name(), model))
var output string
if s, ok := client.(llm.Streamer); ok {
var b strings.Builder
if err := s.ChatStream(ctx, msgs, func(chunk string) {
b.WriteString(chunk)
fmt.Fprint(out, chunk)
- }); err != nil {
+ }, req.options...); err != nil {
return err
}
output = b.String()
} else {
- txt, err := client.Chat(ctx, msgs)
+ txt, err := client.Chat(ctx, msgs, req.options...)
if err != nil {
return err
}
@@ -157,7 +245,7 @@ func runChat(ctx context.Context, client llm.Client, msgs []llm.Message, input s
sent += len(m.Content)
}
recv := len(output)
- _ = stats.Update(ctx, client.Name(), client.DefaultModel(), sent, recv)
+ _ = stats.Update(ctx, client.Name(), model, sent, recv)
snap, _ := stats.TakeSnapshot()
minsWin := snap.Window.Minutes()
if minsWin <= 0 {
@@ -165,20 +253,23 @@ func runChat(ctx context.Context, client llm.Client, msgs []llm.Message, input s
}
scopeReqs := int64(0)
if pe, ok := snap.Providers[client.Name()]; ok {
- if mc, ok2 := pe.Models[client.DefaultModel()]; ok2 {
+ if mc, ok2 := pe.Models[model]; ok2 {
scopeReqs = mc.Reqs
}
}
scopeRPM := float64(scopeReqs) / minsWin
fmt.Fprintf(errw, "\n"+logging.AnsiBase+"done provider=%s model=%s time=%s in_bytes=%d out_bytes=%d | global Σ reqs=%d rpm=%.2f"+logging.AnsiReset+"\n",
- client.Name(), client.DefaultModel(), dur.Round(time.Millisecond), sent, recv, snap.Global.Reqs, snap.RPM)
- _ = tmux.SetStatus(tmux.FormatGlobalStatusColored(snap.Global.Reqs, snap.RPM, snap.Global.Sent, snap.Global.Recv, client.Name(), client.DefaultModel(), scopeRPM, scopeReqs, snap.Window))
+ client.Name(), model, dur.Round(time.Millisecond), sent, recv, snap.Global.Reqs, snap.RPM)
+ _ = tmux.SetStatus(tmux.FormatGlobalStatusColored(snap.Global.Reqs, snap.RPM, snap.Global.Sent, snap.Global.Recv, client.Name(), model, scopeRPM, scopeReqs, snap.Window))
return nil
}
// printProviderInfo writes the provider/model line to stderr.
-func printProviderInfo(errw io.Writer, client llm.Client) {
- fmt.Fprintf(errw, logging.AnsiBase+"provider=%s model=%s"+logging.AnsiReset+"\n", client.Name(), client.DefaultModel())
+func printProviderInfo(errw io.Writer, client llm.Client, model string) {
+ if strings.TrimSpace(model) == "" {
+ model = client.DefaultModel()
+ }
+ fmt.Fprintf(errw, logging.AnsiBase+"provider=%s model=%s"+logging.AnsiReset+"\n", client.Name(), model)
}
// newClientFromConfig is kept for tests; delegates to llmutils.