diff options
Diffstat (limited to 'internal/hexaicli')
| -rw-r--r-- | internal/hexaicli/run.go | 119 | ||||
| -rw-r--r-- | internal/hexaicli/run_more_test.go | 3 | ||||
| -rw-r--r-- | internal/hexaicli/run_test.go | 44 |
3 files changed, 147 insertions, 19 deletions
diff --git a/internal/hexaicli/run.go b/internal/hexaicli/run.go index 11e8938..b965261 100644 --- a/internal/hexaicli/run.go +++ b/internal/hexaicli/run.go @@ -20,6 +20,84 @@ import ( "codeberg.org/snonux/hexai/internal/tmux" ) +type requestArgs struct { + model string + options []llm.RequestOption +} + +func buildCLIRequestArgs(cfg appconfig.App, client llm.Client) requestArgs { + provider := canonicalProvider(cfg.Provider) + if strings.TrimSpace(cfg.CLIProvider) != "" { + provider = canonicalProvider(cfg.CLIProvider) + } + if client != nil { + provider = strings.ToLower(strings.TrimSpace(client.Name())) + } + override := strings.TrimSpace(cfg.CLIModel) + fallback := strings.TrimSpace(defaultModelForProvider(cfg, provider)) + if client != nil { + if dm := strings.TrimSpace(client.DefaultModel()); dm != "" { + fallback = dm + } + } + effective := override + if effective == "" { + effective = fallback + } + opts := make([]llm.RequestOption, 0, 2) + if override != "" { + opts = append(opts, llm.WithModel(override)) + } + if temp, ok := cliTemperature(cfg, provider, effective); ok { + opts = append(opts, llm.WithTemperature(temp)) + } + return requestArgs{model: effective, options: opts} +} + +func defaultRequestArgs(cfg appconfig.App, client llm.Client) requestArgs { + model := strings.TrimSpace(cfg.CLIModel) + if model == "" && client != nil { + model = strings.TrimSpace(client.DefaultModel()) + } + return requestArgs{model: model} +} + +func cliTemperature(cfg appconfig.App, provider, model string) (float64, bool) { + if cfg.CLITemperature != nil { + return *cfg.CLITemperature, true + } + if cfg.CodingTemperature != nil { + temp := *cfg.CodingTemperature + if provider == "openai" && strings.HasPrefix(strings.ToLower(model), "gpt-5") && temp == 0.2 { + temp = 1.0 + } + return temp, true + } + if provider == "openai" && strings.HasPrefix(strings.ToLower(model), "gpt-5") { + return 1.0, true + } + return 0, false +} + +func canonicalProvider(name string) string { + p := strings.ToLower(strings.TrimSpace(name)) + if p == "" { + return "openai" + } + return p +} + +func defaultModelForProvider(cfg appconfig.App, provider string) string { + switch provider { + case "ollama": + return cfg.OllamaModel + case "copilot": + return cfg.CopilotModel + default: + return cfg.OpenAIModel + } +} + // Run executes the Hexai CLI behavior given arguments and I/O streams. // It assumes flags have already been parsed by the caller. func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.Writer) error { @@ -29,11 +107,16 @@ func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io. if cfg.StatsWindowMinutes > 0 { stats.SetWindow(time.Duration(cfg.StatsWindowMinutes) * time.Minute) } + providerOverride := strings.TrimSpace(cfg.CLIProvider) + if providerOverride != "" { + cfg.Provider = providerOverride + } client, err := newClientFromApp(cfg) if err != nil { fmt.Fprintf(stderr, logging.AnsiBase+"hexai: LLM disabled: %v"+logging.AnsiReset+"\n", err) return err } + req := buildCLIRequestArgs(cfg, client) // Prefer piped stdin when present; only open the editor when there are no args // and no stdin content available. input, rerr := readInput(stdin, args) @@ -47,9 +130,9 @@ func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io. fmt.Fprintln(stderr, logging.AnsiBase+rerr.Error()+logging.AnsiReset) return rerr } - printProviderInfo(stderr, client) + printProviderInfo(stderr, client, req.model) msgs := buildMessagesFromConfig(cfg, input) - if err := runChat(ctx, client, msgs, input, stdout, stderr); err != nil { + if err := runChat(ctx, client, req, msgs, input, stdout, stderr); err != nil { fmt.Fprintf(stderr, logging.AnsiBase+"hexai: error: %v"+logging.AnsiReset+"\n", err) return err } @@ -64,9 +147,10 @@ func RunWithClient(ctx context.Context, args []string, stdin io.Reader, stdout, fmt.Fprintln(stderr, logging.AnsiBase+err.Error()+logging.AnsiReset) return err } - printProviderInfo(stderr, client) + req := defaultRequestArgs(appconfig.App{}, client) + printProviderInfo(stderr, client, req.model) msgs := buildMessages(input) - if err := runChat(ctx, client, msgs, input, stdout, stderr); err != nil { + if err := runChat(ctx, client, req, msgs, input, stdout, stderr); err != nil { fmt.Fprintf(stderr, logging.AnsiBase+"hexai: error: %v"+logging.AnsiReset+"\n", err) return err } @@ -128,22 +212,26 @@ func buildMessagesFromConfig(cfg appconfig.App, input string) []llm.Message { } // runChat executes the chat request, handling streaming and summary output. -func runChat(ctx context.Context, client llm.Client, msgs []llm.Message, input string, out io.Writer, errw io.Writer) error { +func runChat(ctx context.Context, client llm.Client, req requestArgs, msgs []llm.Message, input string, out io.Writer, errw io.Writer) error { start := time.Now() // Best-effort tmux status update (colored start heartbeat) - _ = tmux.SetStatus(tmux.FormatLLMStartStatus(client.Name(), client.DefaultModel())) + model := strings.TrimSpace(req.model) + if model == "" { + model = client.DefaultModel() + } + _ = tmux.SetStatus(tmux.FormatLLMStartStatus(client.Name(), model)) var output string if s, ok := client.(llm.Streamer); ok { var b strings.Builder if err := s.ChatStream(ctx, msgs, func(chunk string) { b.WriteString(chunk) fmt.Fprint(out, chunk) - }); err != nil { + }, req.options...); err != nil { return err } output = b.String() } else { - txt, err := client.Chat(ctx, msgs) + txt, err := client.Chat(ctx, msgs, req.options...) if err != nil { return err } @@ -157,7 +245,7 @@ func runChat(ctx context.Context, client llm.Client, msgs []llm.Message, input s sent += len(m.Content) } recv := len(output) - _ = stats.Update(ctx, client.Name(), client.DefaultModel(), sent, recv) + _ = stats.Update(ctx, client.Name(), model, sent, recv) snap, _ := stats.TakeSnapshot() minsWin := snap.Window.Minutes() if minsWin <= 0 { @@ -165,20 +253,23 @@ func runChat(ctx context.Context, client llm.Client, msgs []llm.Message, input s } scopeReqs := int64(0) if pe, ok := snap.Providers[client.Name()]; ok { - if mc, ok2 := pe.Models[client.DefaultModel()]; ok2 { + if mc, ok2 := pe.Models[model]; ok2 { scopeReqs = mc.Reqs } } scopeRPM := float64(scopeReqs) / minsWin fmt.Fprintf(errw, "\n"+logging.AnsiBase+"done provider=%s model=%s time=%s in_bytes=%d out_bytes=%d | global Σ reqs=%d rpm=%.2f"+logging.AnsiReset+"\n", - client.Name(), client.DefaultModel(), dur.Round(time.Millisecond), sent, recv, snap.Global.Reqs, snap.RPM) - _ = tmux.SetStatus(tmux.FormatGlobalStatusColored(snap.Global.Reqs, snap.RPM, snap.Global.Sent, snap.Global.Recv, client.Name(), client.DefaultModel(), scopeRPM, scopeReqs, snap.Window)) + client.Name(), model, dur.Round(time.Millisecond), sent, recv, snap.Global.Reqs, snap.RPM) + _ = tmux.SetStatus(tmux.FormatGlobalStatusColored(snap.Global.Reqs, snap.RPM, snap.Global.Sent, snap.Global.Recv, client.Name(), model, scopeRPM, scopeReqs, snap.Window)) return nil } // printProviderInfo writes the provider/model line to stderr. -func printProviderInfo(errw io.Writer, client llm.Client) { - fmt.Fprintf(errw, logging.AnsiBase+"provider=%s model=%s"+logging.AnsiReset+"\n", client.Name(), client.DefaultModel()) +func printProviderInfo(errw io.Writer, client llm.Client, model string) { + if strings.TrimSpace(model) == "" { + model = client.DefaultModel() + } + fmt.Fprintf(errw, logging.AnsiBase+"provider=%s model=%s"+logging.AnsiReset+"\n", client.Name(), model) } // newClientFromConfig is kept for tests; delegates to llmutils. diff --git a/internal/hexaicli/run_more_test.go b/internal/hexaicli/run_more_test.go index bd88d56..469f0c0 100644 --- a/internal/hexaicli/run_more_test.go +++ b/internal/hexaicli/run_more_test.go @@ -26,7 +26,8 @@ func TestRunChat_Streaming(t *testing.T) { var out, errw bytes.Buffer input := "hello" msgs := []llm.Message{{Role: "user", Content: input}} - if err := runChat(context.Background(), streamClient{}, msgs, input, &out, &errw); err != nil { + req := requestArgs{model: "m"} + if err := runChat(context.Background(), streamClient{}, req, msgs, input, &out, &errw); err != nil { t.Fatalf("runChat failed: %v", err) } if out.String() != "AB" { diff --git a/internal/hexaicli/run_test.go b/internal/hexaicli/run_test.go index a4184f6..4dcbbc5 100644 --- a/internal/hexaicli/run_test.go +++ b/internal/hexaicli/run_test.go @@ -16,6 +16,11 @@ type failingReader struct{ err error } func (f failingReader) Read([]byte) (int, error) { return 0, f.err } +func floatPtr(v float64) *float64 { + x := v + return &x +} + func TestReadInput_Combinations(t *testing.T) { // stdin + arg restore, f := setStdin(t, "from-stdin") @@ -72,7 +77,8 @@ func TestRunChat_StreamAndNonStream(t *testing.T) { // stream path fc := &fakeStreamer{fakeClient: fakeClient{name: "p", model: "m"}, chunks: []string{"H", "i", "!"}} var out, errb bytes.Buffer - if err := runChat(context.Background(), fc, buildMessages("hello"), "hello", &out, &errb); err != nil { + req := requestArgs{model: fc.DefaultModel()} + if err := runChat(context.Background(), fc, req, buildMessages("hello"), "hello", &out, &errb); err != nil { t.Fatalf("stream: %v", err) } if out.String() != "Hi!" || !strings.Contains(errb.String(), "provider=p model=m") { @@ -82,7 +88,7 @@ func TestRunChat_StreamAndNonStream(t *testing.T) { fc2 := &fakeClient{name: "p2", model: "m2", resp: "Yo"} out.Reset() errb.Reset() - if err := runChat(context.Background(), fc2, buildMessages("hello"), "hello", &out, &errb); err != nil { + if err := runChat(context.Background(), fc2, requestArgs{model: fc2.DefaultModel()}, buildMessages("hello"), "hello", &out, &errb); err != nil { t.Fatalf("non-stream: %v", err) } if out.String() != "Yo" || !strings.Contains(errb.String(), "provider=p2 model=m2") { @@ -101,7 +107,7 @@ func (c clientErr) DefaultModel() string { return c.model } func TestRunChat_ErrorPaths(t *testing.T) { ctx := context.Background() out, errb := &bytes.Buffer{}, &bytes.Buffer{} - if err := runChat(ctx, clientErr{"p", "m"}, buildMessages("hi"), "hi", out, errb); err == nil { + if err := runChat(ctx, clientErr{"p", "m"}, requestArgs{model: "m"}, buildMessages("hi"), "hi", out, errb); err == nil { t.Fatalf("expected error from Chat") } } @@ -139,12 +145,42 @@ func TestRun_OpenAI_NoKey_ShowsError(t *testing.T) { func TestPrintProviderInfo(t *testing.T) { var b bytes.Buffer - printProviderInfo(&b, &fakeClient{name: "x", model: "y"}) + printProviderInfo(&b, &fakeClient{name: "x", model: "y"}, "y") if !strings.Contains(b.String(), "provider=x model=y") { t.Fatalf("missing provider line: %q", b.String()) } } +func TestBuildCLIRequestArgs_Override(t *testing.T) { + cfg := appconfig.App{CLIModel: "override", CLITemperature: floatPtr(0.7), Provider: "openai", CLIProvider: "copilot", CopilotModel: "gpt-4o"} + req := buildCLIRequestArgs(cfg, &fakeClient{name: "copilot", model: "default"}) + if req.model != "override" { + t.Fatalf("expected model override, got %q", req.model) + } + var opts llm.Options + for _, o := range req.options { + o(&opts) + } + if opts.Model != "override" || opts.Temperature != 0.7 { + t.Fatalf("unexpected options: %+v", opts) + } +} + +func TestBuildCLIRequestArgs_Gpt5Temp(t *testing.T) { + cfg := appconfig.App{Provider: "openai", CodingTemperature: floatPtr(0.2)} + req := buildCLIRequestArgs(cfg, &fakeClient{name: "openai", model: "gpt-5.1"}) + if req.model != "gpt-5.1" { + t.Fatalf("expected fallback model, got %q", req.model) + } + var opts llm.Options + for _, o := range req.options { + o(&opts) + } + if opts.Temperature != 1.0 { + t.Fatalf("expected temp 1.0, got %v", opts.Temperature) + } +} + func TestNewClientFromConfig_Ollama(t *testing.T) { cfg := appconfig.App{Provider: "ollama", OllamaBaseURL: "http://x", OllamaModel: "m"} c, err := newClientFromConfig(cfg) |
