summaryrefslogtreecommitdiff
path: root/internal/hexaicli
diff options
context:
space:
mode:
Diffstat (limited to 'internal/hexaicli')
-rw-r--r--internal/hexaicli/run.go64
-rw-r--r--internal/hexaicli/run_model_override_test.go2
-rw-r--r--internal/hexaicli/run_test.go64
3 files changed, 110 insertions, 20 deletions
diff --git a/internal/hexaicli/run.go b/internal/hexaicli/run.go
index bc0341d..b48bee0 100644
--- a/internal/hexaicli/run.go
+++ b/internal/hexaicli/run.go
@@ -205,8 +205,9 @@ type chatRunSummary struct {
}
func runCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout, stderr io.Writer) error {
- results, printer := executeCLIJobs(ctx, jobs, msgs, input, stdout, stderr)
- if printer == nil {
+ streamSingle := len(jobs) == 1
+ results, printer := executeCLIJobs(ctx, jobs, msgs, input, stdout, stderr, streamSingle)
+ if printer == nil && !streamSingle {
if err := writeCLIJobOutputs(stdout, results); err != nil {
return err
}
@@ -214,17 +215,17 @@ func runCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input st
return writeCLIJobSummaries(stderr, results)
}
-func executeCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout io.Writer, stderr io.Writer) ([]*cliJobResult, *termprint.ColumnPrinter) {
+func executeCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout io.Writer, stderr io.Writer, streamSingle bool) ([]*cliJobResult, *termprint.ColumnPrinter) {
results := make([]*cliJobResult, len(jobs))
printer := setupCLIPrinter(stdout, jobs)
+ printCLIHeader(stderr, jobs, printer)
var wg sync.WaitGroup
for _, job := range jobs {
job := job
wg.Add(1)
- printProviderInfo(stderr, job.client, job.req.model)
go func() {
defer wg.Done()
- results[job.index] = runSingleCLIJob(ctx, job, msgs, input, printer)
+ results[job.index] = runSingleCLIJob(ctx, job, msgs, input, stdout, printer, streamSingle)
}()
}
wg.Wait()
@@ -232,21 +233,21 @@ func executeCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, inpu
}
func setupCLIPrinter(stdout io.Writer, jobs []cliJob) *termprint.ColumnPrinter {
- if len(jobs) == 0 {
+ if len(jobs) < 2 {
return nil
}
- printer := newColumnPrinter(stdout, jobs)
- printer.PrintHeader()
- return printer
+ return newColumnPrinter(stdout, jobs)
}
-func runSingleCLIJob(ctx context.Context, job cliJob, msgs []llm.Message, input string, printer *termprint.ColumnPrinter) *cliJobResult {
+func runSingleCLIJob(ctx context.Context, job cliJob, msgs []llm.Message, input string, stdout io.Writer, printer *termprint.ColumnPrinter, streamOutput bool) *cliJobResult {
var errBuf bytes.Buffer
var outBuf bytes.Buffer
jobMsgs := append([]llm.Message(nil), msgs...)
writer := io.Writer(&outBuf)
if printer != nil {
writer = printer.Writer(job.index)
+ } else if streamOutput {
+ writer = io.MultiWriter(stdout, &outBuf)
}
err := runChat(ctx, job.client, job.req, jobMsgs, input, writer, &errBuf)
if printer != nil {
@@ -263,6 +264,7 @@ func runSingleCLIJob(ctx context.Context, job cliJob, msgs []llm.Message, input
func writeCLIJobOutputs(stdout io.Writer, results []*cliJobResult) error {
printed := false
+ showHeading := cliJobResultCount(results) > 1
for _, res := range results {
if res == nil {
continue
@@ -272,7 +274,7 @@ func writeCLIJobOutputs(stdout io.Writer, results []*cliJobResult) error {
return err
}
}
- if err := writeCLIJobOutput(stdout, res); err != nil {
+ if err := writeCLIJobOutput(stdout, res, showHeading); err != nil {
return err
}
printed = true
@@ -280,10 +282,22 @@ func writeCLIJobOutputs(stdout io.Writer, results []*cliJobResult) error {
return nil
}
-func writeCLIJobOutput(stdout io.Writer, res *cliJobResult) error {
- heading := fmt.Sprintf("=== %s:%s ===\n", res.provider, res.model)
- if _, err := io.WriteString(stdout, heading); err != nil {
- return err
+func cliJobResultCount(results []*cliJobResult) int {
+ count := 0
+ for _, res := range results {
+ if res != nil {
+ count++
+ }
+ }
+ return count
+}
+
+func writeCLIJobOutput(stdout io.Writer, res *cliJobResult, showHeading bool) error {
+ if showHeading {
+ heading := fmt.Sprintf("=== %s:%s ===\n", res.provider, res.model)
+ if _, err := io.WriteString(stdout, heading); err != nil {
+ return err
+ }
}
if res.output == "" {
return nil
@@ -338,6 +352,18 @@ func newColumnPrinter(stdout io.Writer, jobs []cliJob) *termprint.ColumnPrinter
return termprint.NewColumnPrinter(stdout, providers, models)
}
+func printCLIHeader(stderr io.Writer, jobs []cliJob, printer *termprint.ColumnPrinter) {
+ if len(jobs) == 0 {
+ return
+ }
+ if printer != nil {
+ printer.PrintHeaderTo(stderr)
+ return
+ }
+ job := jobs[0]
+ printProviderInfo(stderr, job.client, job.req.model)
+}
+
// WithCLISelection injects provider indices into the context so Run only executes those jobs.
func WithCLISelection(ctx context.Context, indices []int) context.Context {
if ctx == nil {
@@ -549,12 +575,16 @@ func summarizeChatRun(ctx context.Context, client llm.Client, model string, msgs
return summary
}
-// printProviderInfo writes the provider/model line to stderr.
+// printProviderInfo writes the provider:model header and divider to stderr.
func printProviderInfo(errw io.Writer, client llm.Client, model string) {
if strings.TrimSpace(model) == "" {
model = client.DefaultModel()
}
- _, _ = fmt.Fprintf(errw, logging.AnsiBase+"provider=%s model=%s"+logging.AnsiReset+"\n", client.Name(), model)
+ printer := termprint.NewColumnPrinter(errw, []string{client.Name()}, []string{model})
+ if printer == nil {
+ return
+ }
+ printer.PrintHeader()
}
// newClientFromConfig is kept for tests; delegates to llmutils.
diff --git a/internal/hexaicli/run_model_override_test.go b/internal/hexaicli/run_model_override_test.go
index b32b172..f669ede 100644
--- a/internal/hexaicli/run_model_override_test.go
+++ b/internal/hexaicli/run_model_override_test.go
@@ -39,7 +39,7 @@ func TestRun_ModelEnvOverride_FlowsIntoClient(t *testing.T) {
if seenModel != "gpt-5-codex" {
t.Fatalf("expected cfg.OpenAIModel=gpt-5-codex, got %q", seenModel)
}
- if !strings.Contains(errb.String(), "model=gpt-5-codex") {
+ if !strings.Contains(errb.String(), "openai:gpt-5-codex") {
t.Fatalf("stderr should print effective model, got: %s", errb.String())
}
}
diff --git a/internal/hexaicli/run_test.go b/internal/hexaicli/run_test.go
index 34a5c51..315d016 100644
--- a/internal/hexaicli/run_test.go
+++ b/internal/hexaicli/run_test.go
@@ -154,8 +154,68 @@ model = "gpt-x"
func TestPrintProviderInfo(t *testing.T) {
var b bytes.Buffer
printProviderInfo(&b, &fakeClient{name: "x", model: "y"}, "y")
- if !strings.Contains(b.String(), "provider=x model=y") {
- t.Fatalf("missing provider line: %q", b.String())
+ if !strings.Contains(b.String(), "x:y") || !strings.Contains(b.String(), "─") {
+ t.Fatalf("missing provider header: %q", b.String())
+ }
+ if strings.Contains(b.String(), "provider=") {
+ t.Fatalf("unexpected legacy provider line: %q", b.String())
+ }
+}
+
+func TestRun_SingleProviderHeaderUsesStderr(t *testing.T) {
+ oldNew := newClientFromApp
+ defer func() { newClientFromApp = oldNew }()
+ newClientFromApp = func(_ appconfig.App) (llm.Client, error) {
+ return &fakeClient{name: "openai", model: "gpt-4.1", resp: "OUT"}, nil
+ }
+
+ restore, f := setStdin(t, "hello")
+ defer restore()
+
+ var stdout, stderr bytes.Buffer
+ if err := Run(context.Background(), nil, f, &stdout, &stderr); err != nil {
+ t.Fatalf("Run: %v", err)
+ }
+ if got := stdout.String(); got != "OUT" {
+ t.Fatalf("stdout = %q, want %q", got, "OUT")
+ }
+ if !strings.Contains(stderr.String(), "openai:gpt-4.1") || !strings.Contains(stderr.String(), "─") {
+ t.Fatalf("stderr missing provider header: %q", stderr.String())
+ }
+ if strings.Contains(stdout.String(), "openai:gpt-4.1") || strings.Contains(stdout.String(), "─") {
+ t.Fatalf("stdout should not contain provider header: %q", stdout.String())
+ }
+}
+
+func TestExecuteCLIJobs_MultiProviderHeaderUsesStderr(t *testing.T) {
+ jobs := []cliJob{
+ {
+ index: 0,
+ provider: "openai",
+ client: &fakeClient{name: "openai", model: "gpt-4.1", resp: "LEFT"},
+ req: requestArgs{model: "gpt-4.1"},
+ },
+ {
+ index: 1,
+ provider: "anthropic",
+ client: &fakeClient{name: "anthropic", model: "claude", resp: "RIGHT"},
+ req: requestArgs{model: "claude"},
+ },
+ }
+
+ var stdout, stderr bytes.Buffer
+ results, printer := executeCLIJobs(context.Background(), jobs, buildMessages("hello"), "hello", &stdout, &stderr, false)
+ if printer == nil {
+ t.Fatalf("expected column printer for multi-provider run")
+ }
+ if len(results) != 2 || results[0] == nil || results[1] == nil {
+ t.Fatalf("expected results for both jobs, got %#v", results)
+ }
+ if !strings.Contains(stderr.String(), "openai:gpt-4.1") || !strings.Contains(stderr.String(), "anthropic:claude") {
+ t.Fatalf("stderr missing multi-provider header: %q", stderr.String())
+ }
+ if strings.Contains(stdout.String(), "openai:gpt-4.1") || strings.Contains(stdout.String(), "anthropic:claude") {
+ t.Fatalf("stdout should not contain provider header: %q", stdout.String())
}
}