diff options
| author | Paul Buetow <paul@buetow.org> | 2025-09-26 19:34:19 +0300 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2025-09-26 19:34:19 +0300 |
| commit | 0583b360ceb606b8e58f12a17f588bd27feeb117 (patch) | |
| tree | ae8ac0d7968a409a76d18d84e080d02da52ce775 /internal/hexaicli/run.go | |
| parent | 869c018a7a26285263cf7692f25f6aa44e2635c9 (diff) | |
Add per-surface provider overrides and wiring
Diffstat (limited to 'internal/hexaicli/run.go')
| -rw-r--r-- | internal/hexaicli/run.go | 119 |
1 files changed, 105 insertions, 14 deletions
diff --git a/internal/hexaicli/run.go b/internal/hexaicli/run.go index 11e8938..b965261 100644 --- a/internal/hexaicli/run.go +++ b/internal/hexaicli/run.go @@ -20,6 +20,84 @@ import ( "codeberg.org/snonux/hexai/internal/tmux" ) +type requestArgs struct { + model string + options []llm.RequestOption +} + +func buildCLIRequestArgs(cfg appconfig.App, client llm.Client) requestArgs { + provider := canonicalProvider(cfg.Provider) + if strings.TrimSpace(cfg.CLIProvider) != "" { + provider = canonicalProvider(cfg.CLIProvider) + } + if client != nil { + provider = strings.ToLower(strings.TrimSpace(client.Name())) + } + override := strings.TrimSpace(cfg.CLIModel) + fallback := strings.TrimSpace(defaultModelForProvider(cfg, provider)) + if client != nil { + if dm := strings.TrimSpace(client.DefaultModel()); dm != "" { + fallback = dm + } + } + effective := override + if effective == "" { + effective = fallback + } + opts := make([]llm.RequestOption, 0, 2) + if override != "" { + opts = append(opts, llm.WithModel(override)) + } + if temp, ok := cliTemperature(cfg, provider, effective); ok { + opts = append(opts, llm.WithTemperature(temp)) + } + return requestArgs{model: effective, options: opts} +} + +func defaultRequestArgs(cfg appconfig.App, client llm.Client) requestArgs { + model := strings.TrimSpace(cfg.CLIModel) + if model == "" && client != nil { + model = strings.TrimSpace(client.DefaultModel()) + } + return requestArgs{model: model} +} + +func cliTemperature(cfg appconfig.App, provider, model string) (float64, bool) { + if cfg.CLITemperature != nil { + return *cfg.CLITemperature, true + } + if cfg.CodingTemperature != nil { + temp := *cfg.CodingTemperature + if provider == "openai" && strings.HasPrefix(strings.ToLower(model), "gpt-5") && temp == 0.2 { + temp = 1.0 + } + return temp, true + } + if provider == "openai" && strings.HasPrefix(strings.ToLower(model), "gpt-5") { + return 1.0, true + } + return 0, false +} + +func canonicalProvider(name string) string { + p := strings.ToLower(strings.TrimSpace(name)) + if p == "" { + return "openai" + } + return p +} + +func defaultModelForProvider(cfg appconfig.App, provider string) string { + switch provider { + case "ollama": + return cfg.OllamaModel + case "copilot": + return cfg.CopilotModel + default: + return cfg.OpenAIModel + } +} + // Run executes the Hexai CLI behavior given arguments and I/O streams. // It assumes flags have already been parsed by the caller. func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.Writer) error { @@ -29,11 +107,16 @@ func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io. if cfg.StatsWindowMinutes > 0 { stats.SetWindow(time.Duration(cfg.StatsWindowMinutes) * time.Minute) } + providerOverride := strings.TrimSpace(cfg.CLIProvider) + if providerOverride != "" { + cfg.Provider = providerOverride + } client, err := newClientFromApp(cfg) if err != nil { fmt.Fprintf(stderr, logging.AnsiBase+"hexai: LLM disabled: %v"+logging.AnsiReset+"\n", err) return err } + req := buildCLIRequestArgs(cfg, client) // Prefer piped stdin when present; only open the editor when there are no args // and no stdin content available. input, rerr := readInput(stdin, args) @@ -47,9 +130,9 @@ func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io. fmt.Fprintln(stderr, logging.AnsiBase+rerr.Error()+logging.AnsiReset) return rerr } - printProviderInfo(stderr, client) + printProviderInfo(stderr, client, req.model) msgs := buildMessagesFromConfig(cfg, input) - if err := runChat(ctx, client, msgs, input, stdout, stderr); err != nil { + if err := runChat(ctx, client, req, msgs, input, stdout, stderr); err != nil { fmt.Fprintf(stderr, logging.AnsiBase+"hexai: error: %v"+logging.AnsiReset+"\n", err) return err } @@ -64,9 +147,10 @@ func RunWithClient(ctx context.Context, args []string, stdin io.Reader, stdout, fmt.Fprintln(stderr, logging.AnsiBase+err.Error()+logging.AnsiReset) return err } - printProviderInfo(stderr, client) + req := defaultRequestArgs(appconfig.App{}, client) + printProviderInfo(stderr, client, req.model) msgs := buildMessages(input) - if err := runChat(ctx, client, msgs, input, stdout, stderr); err != nil { + if err := runChat(ctx, client, req, msgs, input, stdout, stderr); err != nil { fmt.Fprintf(stderr, logging.AnsiBase+"hexai: error: %v"+logging.AnsiReset+"\n", err) return err } @@ -128,22 +212,26 @@ func buildMessagesFromConfig(cfg appconfig.App, input string) []llm.Message { } // runChat executes the chat request, handling streaming and summary output. -func runChat(ctx context.Context, client llm.Client, msgs []llm.Message, input string, out io.Writer, errw io.Writer) error { +func runChat(ctx context.Context, client llm.Client, req requestArgs, msgs []llm.Message, input string, out io.Writer, errw io.Writer) error { start := time.Now() // Best-effort tmux status update (colored start heartbeat) - _ = tmux.SetStatus(tmux.FormatLLMStartStatus(client.Name(), client.DefaultModel())) + model := strings.TrimSpace(req.model) + if model == "" { + model = client.DefaultModel() + } + _ = tmux.SetStatus(tmux.FormatLLMStartStatus(client.Name(), model)) var output string if s, ok := client.(llm.Streamer); ok { var b strings.Builder if err := s.ChatStream(ctx, msgs, func(chunk string) { b.WriteString(chunk) fmt.Fprint(out, chunk) - }); err != nil { + }, req.options...); err != nil { return err } output = b.String() } else { - txt, err := client.Chat(ctx, msgs) + txt, err := client.Chat(ctx, msgs, req.options...) if err != nil { return err } @@ -157,7 +245,7 @@ func runChat(ctx context.Context, client llm.Client, msgs []llm.Message, input s sent += len(m.Content) } recv := len(output) - _ = stats.Update(ctx, client.Name(), client.DefaultModel(), sent, recv) + _ = stats.Update(ctx, client.Name(), model, sent, recv) snap, _ := stats.TakeSnapshot() minsWin := snap.Window.Minutes() if minsWin <= 0 { @@ -165,20 +253,23 @@ func runChat(ctx context.Context, client llm.Client, msgs []llm.Message, input s } scopeReqs := int64(0) if pe, ok := snap.Providers[client.Name()]; ok { - if mc, ok2 := pe.Models[client.DefaultModel()]; ok2 { + if mc, ok2 := pe.Models[model]; ok2 { scopeReqs = mc.Reqs } } scopeRPM := float64(scopeReqs) / minsWin fmt.Fprintf(errw, "\n"+logging.AnsiBase+"done provider=%s model=%s time=%s in_bytes=%d out_bytes=%d | global Σ reqs=%d rpm=%.2f"+logging.AnsiReset+"\n", - client.Name(), client.DefaultModel(), dur.Round(time.Millisecond), sent, recv, snap.Global.Reqs, snap.RPM) - _ = tmux.SetStatus(tmux.FormatGlobalStatusColored(snap.Global.Reqs, snap.RPM, snap.Global.Sent, snap.Global.Recv, client.Name(), client.DefaultModel(), scopeRPM, scopeReqs, snap.Window)) + client.Name(), model, dur.Round(time.Millisecond), sent, recv, snap.Global.Reqs, snap.RPM) + _ = tmux.SetStatus(tmux.FormatGlobalStatusColored(snap.Global.Reqs, snap.RPM, snap.Global.Sent, snap.Global.Recv, client.Name(), model, scopeRPM, scopeReqs, snap.Window)) return nil } // printProviderInfo writes the provider/model line to stderr. -func printProviderInfo(errw io.Writer, client llm.Client) { - fmt.Fprintf(errw, logging.AnsiBase+"provider=%s model=%s"+logging.AnsiReset+"\n", client.Name(), client.DefaultModel()) +func printProviderInfo(errw io.Writer, client llm.Client, model string) { + if strings.TrimSpace(model) == "" { + model = client.DefaultModel() + } + fmt.Fprintf(errw, logging.AnsiBase+"provider=%s model=%s"+logging.AnsiReset+"\n", client.Name(), model) } // newClientFromConfig is kept for tests; delegates to llmutils. |
