diff options
Diffstat (limited to 'internal/hexaicli')
| -rw-r--r-- | internal/hexaicli/run.go | 97 | ||||
| -rw-r--r-- | internal/hexaicli/run_output_test.go | 2 | ||||
| -rw-r--r-- | internal/hexaicli/run_test.go | 2 | ||||
| -rw-r--r-- | internal/hexaicli/runner.go | 161 | ||||
| -rw-r--r-- | internal/hexaicli/runner_test.go | 61 |
5 files changed, 243 insertions, 80 deletions
diff --git a/internal/hexaicli/run.go b/internal/hexaicli/run.go index 9c7ba73..0da1a5f 100644 --- a/internal/hexaicli/run.go +++ b/internal/hexaicli/run.go @@ -7,20 +7,17 @@ import ( "context" "fmt" "io" - "log" "os" "strings" "sync" "time" "codeberg.org/snonux/hexai/internal/appconfig" - "codeberg.org/snonux/hexai/internal/editor" "codeberg.org/snonux/hexai/internal/llm" "codeberg.org/snonux/hexai/internal/llmutils" "codeberg.org/snonux/hexai/internal/logging" "codeberg.org/snonux/hexai/internal/stats" "codeberg.org/snonux/hexai/internal/termprint" - "codeberg.org/snonux/hexai/internal/tmux" ) type requestArgs struct { @@ -95,77 +92,13 @@ func cliTemperatureFromEntry(cfg appconfig.App, provider string, entry appconfig // Run executes the Hexai CLI behavior given arguments and I/O streams. // It assumes flags have already been parsed by the caller. func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.Writer) error { - if spec, ok, err := tpsSimulationFromContext(ctx); err != nil { - _, _ = fmt.Fprintln(stderr, logging.AnsiBase+err.Error()+logging.AnsiReset) - return err - } else if ok { - input, inputErr := readSimulationInput(stdin, args) - if inputErr != nil { - _, _ = fmt.Fprintln(stderr, logging.AnsiBase+inputErr.Error()+logging.AnsiReset) - return inputErr - } - return runTPSSimulation(ctx, spec, input, stdout) - } - - // Load configuration silently; config-load messages are noise in the CLI. - logger := log.New(io.Discard, "", 0) - configPath := configPathFromContext(ctx) - cfg := appconfig.LoadWithOptions(logger, appconfig.LoadOptions{ConfigPath: configPath}) - if cfg.StatsWindowMinutes > 0 { - stats.SetWindow(time.Duration(cfg.StatsWindowMinutes) * time.Minute) - } - jobs, err := buildCLIJobs(cfg) - if err != nil { - _, _ = fmt.Fprintf(stderr, logging.AnsiBase+"hexai: LLM disabled: %v"+logging.AnsiReset+"\n", err) - return err - } - if selected := selectionFromContext(ctx); len(selected) > 0 { - jobs, err = filterJobsBySelection(jobs, selected) - if err != nil { - _, _ = fmt.Fprintf(stderr, logging.AnsiBase+"hexai: %v"+logging.AnsiReset+"\n", err) - return err - } - } - if len(jobs) == 0 { - return fmt.Errorf("hexai: no CLI providers configured") - } - // Prefer piped stdin when present; only open the editor when there are no args - // and no stdin content available. - input, rerr := readInput(stdin, args) - if rerr != nil && len(args) == 0 { - if prompt, eerr := editor.OpenTempAndEdit(nil); eerr == nil && strings.TrimSpace(prompt) != "" { - args = []string{prompt} - input, rerr = readInput(stdin, args) - } - } - if rerr != nil { - _, _ = fmt.Fprintln(stderr, logging.AnsiBase+rerr.Error()+logging.AnsiReset) - return rerr - } - msgs := buildMessagesFromConfig(cfg, input) - if err := runCLIJobs(ctx, jobs, msgs, input, stdout, stderr); err != nil { - _, _ = fmt.Fprintf(stderr, logging.AnsiBase+"hexai: error: %v"+logging.AnsiReset+"\n", err) - return err - } - return nil + return NewRunner().Run(ctx, args, stdin, stdout, stderr) } // RunWithClient executes the CLI flow using an already-constructed client. // Useful for testing and embedding. func RunWithClient(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.Writer, client llm.Client) error { - input, err := readInput(stdin, args) - if err != nil { - _, _ = fmt.Fprintln(stderr, logging.AnsiBase+err.Error()+logging.AnsiReset) - return err - } - req := requestArgs{model: strings.TrimSpace(client.DefaultModel())} - printProviderInfo(stderr, client, req.model) - msgs := buildMessages(input) - if err := runChat(ctx, client, req, msgs, input, stdout, stderr); err != nil { - _, _ = fmt.Fprintf(stderr, logging.AnsiBase+"hexai: error: %v"+logging.AnsiReset+"\n", err) - return err - } - return nil + return NewRunner().RunWithClient(ctx, args, stdin, stdout, stderr, client) } type cliJobResult struct { @@ -184,9 +117,9 @@ type chatRunSummary struct { scopeRPM float64 } -func runCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout, stderr io.Writer) error { +func runCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout, stderr io.Writer, clientFactory cliClientFactory, statusSink cliStatusSink) error { streamSingle := len(jobs) == 1 - results, printer := executeCLIJobs(ctx, jobs, msgs, input, stdout, stderr, streamSingle) + results, printer := executeCLIJobs(ctx, jobs, msgs, input, stdout, stderr, streamSingle, clientFactory, statusSink) if printer == nil && !streamSingle { if err := writeCLIJobOutputs(stdout, results); err != nil { return err @@ -195,7 +128,7 @@ func runCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input st return writeCLIJobSummaries(stderr, results) } -func executeCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout io.Writer, stderr io.Writer, streamSingle bool) ([]*cliJobResult, *termprint.ColumnPrinter) { +func executeCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout io.Writer, stderr io.Writer, streamSingle bool, clientFactory cliClientFactory, statusSink cliStatusSink) ([]*cliJobResult, *termprint.ColumnPrinter) { results := make([]*cliJobResult, len(jobs)) printer := setupCLIPrinter(stdout, jobs) printCLIHeader(stderr, jobs, printer) @@ -205,7 +138,7 @@ func executeCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, inpu wg.Add(1) go func() { defer wg.Done() - results[job.index] = runSingleCLIJob(ctx, job, msgs, input, stdout, printer, streamSingle) + results[job.index] = runSingleCLIJob(ctx, job, msgs, input, stdout, printer, streamSingle, clientFactory, statusSink) }() } wg.Wait() @@ -219,12 +152,12 @@ func setupCLIPrinter(stdout io.Writer, jobs []cliJob) *termprint.ColumnPrinter { return newColumnPrinter(stdout, jobs) } -func runSingleCLIJob(ctx context.Context, job cliJob, msgs []llm.Message, input string, stdout io.Writer, printer *termprint.ColumnPrinter, streamOutput bool) *cliJobResult { +func runSingleCLIJob(ctx context.Context, job cliJob, msgs []llm.Message, input string, stdout io.Writer, printer *termprint.ColumnPrinter, streamOutput bool, clientFactory cliClientFactory, statusSink cliStatusSink) *cliJobResult { if res := cachedCLIJobResult(job, msgs, stdout, printer, streamOutput); res != nil { return res } - client, err := newClientFromApp(job.cfg) + client, err := clientFactory(job.cfg) if err != nil { return &cliJobResult{provider: job.provider, model: job.req.model, err: err} } @@ -239,7 +172,7 @@ func runSingleCLIJob(ctx context.Context, job cliJob, msgs []llm.Message, input } else if streamOutput { writer = io.MultiWriter(stdout, &outBuf) } - err = runChat(ctx, client, job.req, jobMsgs, input, writer, &errBuf) + err = runChatWithStatus(statusSink, ctx, client, job.req, jobMsgs, input, writer, &errBuf) if printer != nil { printer.Flush(job.index) } @@ -532,9 +465,15 @@ func buildMessagesFromConfig(cfg appconfig.App, input string) []llm.Message { // runChat executes the chat request, handling streaming and summary output. func runChat(ctx context.Context, client llm.Client, req requestArgs, msgs []llm.Message, input string, out io.Writer, errw io.Writer) error { + return runChatWithStatus(tmuxCLIStatusSink{}, ctx, client, req, msgs, input, out, errw) +} + +func runChatWithStatus(statusSink cliStatusSink, ctx context.Context, client llm.Client, req requestArgs, msgs []llm.Message, input string, out io.Writer, errw io.Writer) error { start := time.Now() model := effectiveModel(req, client) - _ = tmux.SetStatus(tmux.FormatLLMStartStatus(client.Name(), model)) + if statusSink != nil { + _ = statusSink.SetLLMStart(client.Name(), model) + } output, err := runChatRequest(ctx, client, req, msgs, out) if err != nil { @@ -547,7 +486,9 @@ func runChat(ctx context.Context, client llm.Client, req requestArgs, msgs []llm client.Name(), model, dur.Round(time.Millisecond), summary.sent, summary.recv, summary.snapshot.Global.Reqs, summary.snapshot.RPM); err != nil { return err } - _ = tmux.SetStatus(tmux.FormatGlobalStatusColored(summary.snapshot.Global.Reqs, summary.snapshot.RPM, summary.snapshot.Global.Sent, summary.snapshot.Global.Recv, client.Name(), model, summary.scopeRPM, summary.scopeReq, summary.snapshot.Window)) + if statusSink != nil { + _ = statusSink.SetGlobal(summary.snapshot, client.Name(), model, summary.scopeRPM, summary.scopeReq) + } return nil } diff --git a/internal/hexaicli/run_output_test.go b/internal/hexaicli/run_output_test.go index b4614da..e61e4b6 100644 --- a/internal/hexaicli/run_output_test.go +++ b/internal/hexaicli/run_output_test.go @@ -400,7 +400,7 @@ func TestRunCLIJobs_MultiJob_WritesOutputs(t *testing.T) { } stdout.Reset() stderr.Reset() - if err := runCLIJobs(context.Background(), singleJobs, msgs, "hello", &stdout, &stderr); err != nil { + if err := runCLIJobs(context.Background(), singleJobs, msgs, "hello", &stdout, &stderr, newClientFromApp, nil); err != nil { t.Fatalf("runCLIJobs single: %v", err) } if !strings.Contains(stdout.String(), "out-a") { diff --git a/internal/hexaicli/run_test.go b/internal/hexaicli/run_test.go index 9711399..69e5d98 100644 --- a/internal/hexaicli/run_test.go +++ b/internal/hexaicli/run_test.go @@ -229,7 +229,7 @@ func TestExecuteCLIJobs_MultiProviderHeaderUsesStderr(t *testing.T) { } var stdout, stderr bytes.Buffer - results, printer := executeCLIJobs(context.Background(), jobs, buildMessages("hello"), "hello", &stdout, &stderr, false) + results, printer := executeCLIJobs(context.Background(), jobs, buildMessages("hello"), "hello", &stdout, &stderr, false, newClientFromApp, nil) if printer == nil { t.Fatalf("expected column printer for multi-provider run") } diff --git a/internal/hexaicli/runner.go b/internal/hexaicli/runner.go new file mode 100644 index 0000000..f372021 --- /dev/null +++ b/internal/hexaicli/runner.go @@ -0,0 +1,161 @@ +package hexaicli + +import ( + "context" + "fmt" + "io" + "log" + "strings" + "time" + + "codeberg.org/snonux/hexai/internal/appconfig" + "codeberg.org/snonux/hexai/internal/editor" + "codeberg.org/snonux/hexai/internal/llm" + "codeberg.org/snonux/hexai/internal/logging" + "codeberg.org/snonux/hexai/internal/stats" + "codeberg.org/snonux/hexai/internal/tmux" +) + +type cliConfigLoader func(context.Context, *log.Logger) appconfig.App + +type cliEditorOpener func([]byte) (string, error) + +type cliClientFactory func(appconfig.App) (llm.Client, error) + +type cliStatusSink interface { + SetLLMStart(provider, model string) error + SetGlobal(snapshot stats.Snapshot, provider, model string, scopeRPM float64, scopeReq int64) error +} + +// Runner executes the CLI with injectable configuration, editor, client, and status dependencies. +type Runner struct { + loadConfig cliConfigLoader + openEditor cliEditorOpener + newClient cliClientFactory + statusSink cliStatusSink +} + +type tmuxCLIStatusSink struct{} + +func (tmuxCLIStatusSink) SetLLMStart(provider, model string) error { + return tmux.SetStatus(tmux.FormatLLMStartStatus(provider, model)) +} + +func (tmuxCLIStatusSink) SetGlobal(snapshot stats.Snapshot, provider, model string, scopeRPM float64, scopeReq int64) error { + return tmux.SetStatus(tmux.FormatGlobalStatusColored( + snapshot.Global.Reqs, + snapshot.RPM, + snapshot.Global.Sent, + snapshot.Global.Recv, + provider, + model, + scopeRPM, + scopeReq, + snapshot.Window, + )) +} + +// NewRunner builds a CLI runner with production dependencies. +func NewRunner() *Runner { + return &Runner{ + loadConfig: loadConfigFromContext, + openEditor: editor.OpenTempAndEdit, + newClient: newClientFromApp, + statusSink: tmuxCLIStatusSink{}, + } +} + +func (r *Runner) Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.Writer) error { + runner := normalizeRunner(r) + if spec, ok, err := tpsSimulationFromContext(ctx); err != nil { + _, _ = fmt.Fprintln(stderr, logging.AnsiBase+err.Error()+logging.AnsiReset) + return err + } else if ok { + input, inputErr := readSimulationInput(stdin, args) + if inputErr != nil { + _, _ = fmt.Fprintln(stderr, logging.AnsiBase+inputErr.Error()+logging.AnsiReset) + return inputErr + } + return runTPSSimulation(ctx, spec, input, stdout) + } + + logger := log.New(io.Discard, "", 0) + cfg := runner.loadConfig(ctx, logger) + if cfg.StatsWindowMinutes > 0 { + stats.SetWindow(time.Duration(cfg.StatsWindowMinutes) * time.Minute) + } + jobs, err := buildCLIJobs(cfg) + if err != nil { + _, _ = fmt.Fprintf(stderr, logging.AnsiBase+"hexai: LLM disabled: %v"+logging.AnsiReset+"\n", err) + return err + } + if selected := selectionFromContext(ctx); len(selected) > 0 { + jobs, err = filterJobsBySelection(jobs, selected) + if err != nil { + _, _ = fmt.Fprintf(stderr, logging.AnsiBase+"hexai: %v"+logging.AnsiReset+"\n", err) + return err + } + } + if len(jobs) == 0 { + return fmt.Errorf("hexai: no CLI providers configured") + } + + input, rerr := readInput(stdin, args) + if rerr != nil && len(args) == 0 { + if prompt, eerr := runner.openEditor(nil); eerr == nil && strings.TrimSpace(prompt) != "" { + args = []string{prompt} + input, rerr = readInput(stdin, args) + } + } + if rerr != nil { + _, _ = fmt.Fprintln(stderr, logging.AnsiBase+rerr.Error()+logging.AnsiReset) + return rerr + } + msgs := buildMessagesFromConfig(cfg, input) + if err := runCLIJobs(ctx, jobs, msgs, input, stdout, stderr, runner.newClient, runner.statusSink); err != nil { + _, _ = fmt.Fprintf(stderr, logging.AnsiBase+"hexai: error: %v"+logging.AnsiReset+"\n", err) + return err + } + return nil +} + +func (r *Runner) RunWithClient(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.Writer, client llm.Client) error { + runner := normalizeRunner(r) + input, err := readInput(stdin, args) + if err != nil { + _, _ = fmt.Fprintln(stderr, logging.AnsiBase+err.Error()+logging.AnsiReset) + return err + } + req := requestArgs{model: strings.TrimSpace(client.DefaultModel())} + printProviderInfo(stderr, client, req.model) + msgs := buildMessages(input) + if err := runChatWithStatus(runner.statusSink, ctx, client, req, msgs, input, stdout, stderr); err != nil { + _, _ = fmt.Fprintf(stderr, logging.AnsiBase+"hexai: error: %v"+logging.AnsiReset+"\n", err) + return err + } + return nil +} + +func normalizeRunner(r *Runner) Runner { + if r == nil { + return *NewRunner() + } + runner := *r + if runner.loadConfig == nil { + runner.loadConfig = loadConfigFromContext + } + if runner.openEditor == nil { + runner.openEditor = editor.OpenTempAndEdit + } + if runner.newClient == nil { + runner.newClient = newClientFromApp + } + if runner.statusSink == nil { + runner.statusSink = tmuxCLIStatusSink{} + } + return runner +} + +func loadConfigFromContext(ctx context.Context, logger *log.Logger) appconfig.App { + return appconfig.LoadWithOptions(logger, appconfig.LoadOptions{ConfigPath: configPathFromContext(ctx)}) +} diff --git a/internal/hexaicli/runner_test.go b/internal/hexaicli/runner_test.go new file mode 100644 index 0000000..1d438b0 --- /dev/null +++ b/internal/hexaicli/runner_test.go @@ -0,0 +1,61 @@ +package hexaicli + +import ( + "bytes" + "context" + "log" + "strings" + "testing" + + "codeberg.org/snonux/hexai/internal/appconfig" + "codeberg.org/snonux/hexai/internal/llm" + "codeberg.org/snonux/hexai/internal/stats" +) + +type recordingCLIStatusSink struct { + startProvider string + startModel string + globalCalls int +} + +func (s *recordingCLIStatusSink) SetLLMStart(provider, model string) error { + s.startProvider = provider + s.startModel = model + return nil +} + +func (s *recordingCLIStatusSink) SetGlobal(stats.Snapshot, string, string, float64, int64) error { + s.globalCalls++ + return nil +} + +func TestRunner_UsesInjectedDependencies(t *testing.T) { + sink := &recordingCLIStatusSink{} + runner := &Runner{ + loadConfig: func(context.Context, *log.Logger) appconfig.App { + return appconfig.App{ + CoreConfig: appconfig.CoreConfig{Provider: "openai"}, + PromptConfig: appconfig.PromptConfig{PromptCLIDefaultSystem: "SYS"}, + } + }, + openEditor: func([]byte) (string, error) { return "PROMPT", nil }, + newClient: func(appconfig.App) (client llm.Client, err error) { + return &fakeClient{name: "fake", model: "m", resp: "OUT"}, nil + }, + statusSink: sink, + } + + var stdout, stderr bytes.Buffer + if err := runner.Run(context.Background(), nil, strings.NewReader(""), &stdout, &stderr); err != nil { + t.Fatalf("Run: %v", err) + } + if stdout.String() != "OUT" { + t.Fatalf("stdout = %q, want OUT", stdout.String()) + } + if sink.startProvider != "fake" || sink.startModel == "" { + t.Fatalf("unexpected start status: provider=%q model=%q", sink.startProvider, sink.startModel) + } + if sink.globalCalls != 1 { + t.Fatalf("expected one global status update, got %d", sink.globalCalls) + } +} |
