diff options
Diffstat (limited to 'internal/hexaicli')
| -rw-r--r-- | internal/hexaicli/run.go | 64 | ||||
| -rw-r--r-- | internal/hexaicli/run_model_override_test.go | 2 | ||||
| -rw-r--r-- | internal/hexaicli/run_test.go | 64 |
3 files changed, 110 insertions, 20 deletions
diff --git a/internal/hexaicli/run.go b/internal/hexaicli/run.go index bc0341d..b48bee0 100644 --- a/internal/hexaicli/run.go +++ b/internal/hexaicli/run.go @@ -205,8 +205,9 @@ type chatRunSummary struct { } func runCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout, stderr io.Writer) error { - results, printer := executeCLIJobs(ctx, jobs, msgs, input, stdout, stderr) - if printer == nil { + streamSingle := len(jobs) == 1 + results, printer := executeCLIJobs(ctx, jobs, msgs, input, stdout, stderr, streamSingle) + if printer == nil && !streamSingle { if err := writeCLIJobOutputs(stdout, results); err != nil { return err } @@ -214,17 +215,17 @@ func runCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input st return writeCLIJobSummaries(stderr, results) } -func executeCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout io.Writer, stderr io.Writer) ([]*cliJobResult, *termprint.ColumnPrinter) { +func executeCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout io.Writer, stderr io.Writer, streamSingle bool) ([]*cliJobResult, *termprint.ColumnPrinter) { results := make([]*cliJobResult, len(jobs)) printer := setupCLIPrinter(stdout, jobs) + printCLIHeader(stderr, jobs, printer) var wg sync.WaitGroup for _, job := range jobs { job := job wg.Add(1) - printProviderInfo(stderr, job.client, job.req.model) go func() { defer wg.Done() - results[job.index] = runSingleCLIJob(ctx, job, msgs, input, printer) + results[job.index] = runSingleCLIJob(ctx, job, msgs, input, stdout, printer, streamSingle) }() } wg.Wait() @@ -232,21 +233,21 @@ func executeCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, inpu } func setupCLIPrinter(stdout io.Writer, jobs []cliJob) *termprint.ColumnPrinter { - if len(jobs) == 0 { + if len(jobs) < 2 { return nil } - printer := newColumnPrinter(stdout, jobs) - printer.PrintHeader() - return printer + return newColumnPrinter(stdout, jobs) } -func runSingleCLIJob(ctx context.Context, job cliJob, msgs []llm.Message, input string, printer *termprint.ColumnPrinter) *cliJobResult { +func runSingleCLIJob(ctx context.Context, job cliJob, msgs []llm.Message, input string, stdout io.Writer, printer *termprint.ColumnPrinter, streamOutput bool) *cliJobResult { var errBuf bytes.Buffer var outBuf bytes.Buffer jobMsgs := append([]llm.Message(nil), msgs...) writer := io.Writer(&outBuf) if printer != nil { writer = printer.Writer(job.index) + } else if streamOutput { + writer = io.MultiWriter(stdout, &outBuf) } err := runChat(ctx, job.client, job.req, jobMsgs, input, writer, &errBuf) if printer != nil { @@ -263,6 +264,7 @@ func runSingleCLIJob(ctx context.Context, job cliJob, msgs []llm.Message, input func writeCLIJobOutputs(stdout io.Writer, results []*cliJobResult) error { printed := false + showHeading := cliJobResultCount(results) > 1 for _, res := range results { if res == nil { continue @@ -272,7 +274,7 @@ func writeCLIJobOutputs(stdout io.Writer, results []*cliJobResult) error { return err } } - if err := writeCLIJobOutput(stdout, res); err != nil { + if err := writeCLIJobOutput(stdout, res, showHeading); err != nil { return err } printed = true @@ -280,10 +282,22 @@ func writeCLIJobOutputs(stdout io.Writer, results []*cliJobResult) error { return nil } -func writeCLIJobOutput(stdout io.Writer, res *cliJobResult) error { - heading := fmt.Sprintf("=== %s:%s ===\n", res.provider, res.model) - if _, err := io.WriteString(stdout, heading); err != nil { - return err +func cliJobResultCount(results []*cliJobResult) int { + count := 0 + for _, res := range results { + if res != nil { + count++ + } + } + return count +} + +func writeCLIJobOutput(stdout io.Writer, res *cliJobResult, showHeading bool) error { + if showHeading { + heading := fmt.Sprintf("=== %s:%s ===\n", res.provider, res.model) + if _, err := io.WriteString(stdout, heading); err != nil { + return err + } } if res.output == "" { return nil @@ -338,6 +352,18 @@ func newColumnPrinter(stdout io.Writer, jobs []cliJob) *termprint.ColumnPrinter return termprint.NewColumnPrinter(stdout, providers, models) } +func printCLIHeader(stderr io.Writer, jobs []cliJob, printer *termprint.ColumnPrinter) { + if len(jobs) == 0 { + return + } + if printer != nil { + printer.PrintHeaderTo(stderr) + return + } + job := jobs[0] + printProviderInfo(stderr, job.client, job.req.model) +} + // WithCLISelection injects provider indices into the context so Run only executes those jobs. func WithCLISelection(ctx context.Context, indices []int) context.Context { if ctx == nil { @@ -549,12 +575,16 @@ func summarizeChatRun(ctx context.Context, client llm.Client, model string, msgs return summary } -// printProviderInfo writes the provider/model line to stderr. +// printProviderInfo writes the provider:model header and divider to stderr. func printProviderInfo(errw io.Writer, client llm.Client, model string) { if strings.TrimSpace(model) == "" { model = client.DefaultModel() } - _, _ = fmt.Fprintf(errw, logging.AnsiBase+"provider=%s model=%s"+logging.AnsiReset+"\n", client.Name(), model) + printer := termprint.NewColumnPrinter(errw, []string{client.Name()}, []string{model}) + if printer == nil { + return + } + printer.PrintHeader() } // newClientFromConfig is kept for tests; delegates to llmutils. diff --git a/internal/hexaicli/run_model_override_test.go b/internal/hexaicli/run_model_override_test.go index b32b172..f669ede 100644 --- a/internal/hexaicli/run_model_override_test.go +++ b/internal/hexaicli/run_model_override_test.go @@ -39,7 +39,7 @@ func TestRun_ModelEnvOverride_FlowsIntoClient(t *testing.T) { if seenModel != "gpt-5-codex" { t.Fatalf("expected cfg.OpenAIModel=gpt-5-codex, got %q", seenModel) } - if !strings.Contains(errb.String(), "model=gpt-5-codex") { + if !strings.Contains(errb.String(), "openai:gpt-5-codex") { t.Fatalf("stderr should print effective model, got: %s", errb.String()) } } diff --git a/internal/hexaicli/run_test.go b/internal/hexaicli/run_test.go index 34a5c51..315d016 100644 --- a/internal/hexaicli/run_test.go +++ b/internal/hexaicli/run_test.go @@ -154,8 +154,68 @@ model = "gpt-x" func TestPrintProviderInfo(t *testing.T) { var b bytes.Buffer printProviderInfo(&b, &fakeClient{name: "x", model: "y"}, "y") - if !strings.Contains(b.String(), "provider=x model=y") { - t.Fatalf("missing provider line: %q", b.String()) + if !strings.Contains(b.String(), "x:y") || !strings.Contains(b.String(), "─") { + t.Fatalf("missing provider header: %q", b.String()) + } + if strings.Contains(b.String(), "provider=") { + t.Fatalf("unexpected legacy provider line: %q", b.String()) + } +} + +func TestRun_SingleProviderHeaderUsesStderr(t *testing.T) { + oldNew := newClientFromApp + defer func() { newClientFromApp = oldNew }() + newClientFromApp = func(_ appconfig.App) (llm.Client, error) { + return &fakeClient{name: "openai", model: "gpt-4.1", resp: "OUT"}, nil + } + + restore, f := setStdin(t, "hello") + defer restore() + + var stdout, stderr bytes.Buffer + if err := Run(context.Background(), nil, f, &stdout, &stderr); err != nil { + t.Fatalf("Run: %v", err) + } + if got := stdout.String(); got != "OUT" { + t.Fatalf("stdout = %q, want %q", got, "OUT") + } + if !strings.Contains(stderr.String(), "openai:gpt-4.1") || !strings.Contains(stderr.String(), "─") { + t.Fatalf("stderr missing provider header: %q", stderr.String()) + } + if strings.Contains(stdout.String(), "openai:gpt-4.1") || strings.Contains(stdout.String(), "─") { + t.Fatalf("stdout should not contain provider header: %q", stdout.String()) + } +} + +func TestExecuteCLIJobs_MultiProviderHeaderUsesStderr(t *testing.T) { + jobs := []cliJob{ + { + index: 0, + provider: "openai", + client: &fakeClient{name: "openai", model: "gpt-4.1", resp: "LEFT"}, + req: requestArgs{model: "gpt-4.1"}, + }, + { + index: 1, + provider: "anthropic", + client: &fakeClient{name: "anthropic", model: "claude", resp: "RIGHT"}, + req: requestArgs{model: "claude"}, + }, + } + + var stdout, stderr bytes.Buffer + results, printer := executeCLIJobs(context.Background(), jobs, buildMessages("hello"), "hello", &stdout, &stderr, false) + if printer == nil { + t.Fatalf("expected column printer for multi-provider run") + } + if len(results) != 2 || results[0] == nil || results[1] == nil { + t.Fatalf("expected results for both jobs, got %#v", results) + } + if !strings.Contains(stderr.String(), "openai:gpt-4.1") || !strings.Contains(stderr.String(), "anthropic:claude") { + t.Fatalf("stderr missing multi-provider header: %q", stderr.String()) + } + if strings.Contains(stdout.String(), "openai:gpt-4.1") || strings.Contains(stdout.String(), "anthropic:claude") { + t.Fatalf("stdout should not contain provider header: %q", stdout.String()) } } |
