diff options
Diffstat (limited to 'internal/hexaicli/run.go')
| -rw-r--r-- | internal/hexaicli/run.go | 181 |
1 files changed, 142 insertions, 39 deletions
diff --git a/internal/hexaicli/run.go b/internal/hexaicli/run.go index 5f6284c..06fcb83 100644 --- a/internal/hexaicli/run.go +++ b/internal/hexaicli/run.go @@ -3,12 +3,14 @@ package hexaicli import ( + "bytes" "context" "fmt" "io" "log" "os" "strings" + "sync" "time" "codeberg.org/snonux/hexai/internal/appconfig" @@ -25,48 +27,79 @@ type requestArgs struct { options []llm.RequestOption } -func buildCLIRequestArgs(cfg appconfig.App, client llm.Client) requestArgs { - provider := canonicalProvider(cfg.Provider) +type cliJob struct { + index int + provider string + entry appconfig.SurfaceConfig + client llm.Client + req requestArgs +} + +func buildCLIJobs(cfg appconfig.App) ([]cliJob, error) { entries := cfg.CLIConfigs if len(entries) == 0 { - entries = []appconfig.SurfaceConfig{{Provider: cfg.Provider, Model: strings.TrimSpace(defaultModelForProvider(cfg, provider))}} + entries = []appconfig.SurfaceConfig{{}} + } + jobs := make([]cliJob, 0, len(entries)) + for i, raw := range entries { + entry := appconfig.SurfaceConfig{Provider: strings.TrimSpace(raw.Provider), Model: strings.TrimSpace(raw.Model), Temperature: raw.Temperature} + provider := entry.Provider + if provider == "" { + provider = cfg.Provider + } + provider = canonicalProvider(provider) + derived := cfg + derived.Provider = provider + switch provider { + case "openai": + if entry.Model != "" { + derived.OpenAIModel = entry.Model + } + case "copilot": + if entry.Model != "" { + derived.CopilotModel = entry.Model + } + case "ollama": + if entry.Model != "" { + derived.OllamaModel = entry.Model + } + } + client, err := newClientFromApp(derived) + if err != nil { + return nil, err + } + req := buildCLIRequest(entry, provider, cfg, client) + if strings.TrimSpace(req.model) == "" { + req.model = strings.TrimSpace(client.DefaultModel()) + } + jobs = append(jobs, cliJob{index: i, provider: provider, entry: entry, client: client, req: req}) } - primary := entries[0] - if strings.TrimSpace(primary.Provider) != "" { - provider = canonicalProvider(primary.Provider) + return jobs, nil +} + +func buildCLIRequest(entry appconfig.SurfaceConfig, provider string, cfg appconfig.App, client llm.Client) requestArgs { + opts := make([]llm.RequestOption, 0, 2) + if cfg.MaxTokens > 0 { + opts = append(opts, llm.WithMaxTokens(cfg.MaxTokens)) } - model := strings.TrimSpace(primary.Model) - if client != nil { - provider = strings.ToLower(strings.TrimSpace(client.Name())) - if model == "" { + model := strings.TrimSpace(entry.Model) + if model == "" { + if client != nil { model = strings.TrimSpace(client.DefaultModel()) } + if model == "" { + model = strings.TrimSpace(defaultModelForProvider(cfg, provider)) + } } - if model == "" { - model = strings.TrimSpace(defaultModelForProvider(cfg, provider)) - } - opts := make([]llm.RequestOption, 0, 2) - if strings.TrimSpace(primary.Model) != "" { - opts = append(opts, llm.WithModel(strings.TrimSpace(primary.Model))) + if entry.Model != "" { + opts = append(opts, llm.WithModel(entry.Model)) } - if temp, ok := cliTemperatureFromEntry(cfg, provider, primary, model); ok { + if temp, ok := cliTemperatureFromEntry(cfg, provider, entry, model); ok { opts = append(opts, llm.WithTemperature(temp)) } return requestArgs{model: model, options: opts} } -func defaultRequestArgs(cfg appconfig.App, client llm.Client) requestArgs { - if len(cfg.CLIConfigs) > 0 { - if m := strings.TrimSpace(cfg.CLIConfigs[0].Model); m != "" { - return requestArgs{model: m} - } - } - if client != nil { - return requestArgs{model: strings.TrimSpace(client.DefaultModel())} - } - return requestArgs{} -} - func cliTemperatureFromEntry(cfg appconfig.App, provider string, entry appconfig.SurfaceConfig, model string) (float64, bool) { if entry.Temperature != nil { return *entry.Temperature, true @@ -112,17 +145,14 @@ func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io. if cfg.StatsWindowMinutes > 0 { stats.SetWindow(time.Duration(cfg.StatsWindowMinutes) * time.Minute) } - if len(cfg.CLIConfigs) > 0 { - if provider := strings.TrimSpace(cfg.CLIConfigs[0].Provider); provider != "" { - cfg.Provider = provider - } - } - client, err := newClientFromApp(cfg) + jobs, err := buildCLIJobs(cfg) if err != nil { fmt.Fprintf(stderr, logging.AnsiBase+"hexai: LLM disabled: %v"+logging.AnsiReset+"\n", err) return err } - req := buildCLIRequestArgs(cfg, client) + if len(jobs) == 0 { + return fmt.Errorf("hexai: no CLI providers configured") + } // Prefer piped stdin when present; only open the editor when there are no args // and no stdin content available. input, rerr := readInput(stdin, args) @@ -136,9 +166,8 @@ func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io. fmt.Fprintln(stderr, logging.AnsiBase+rerr.Error()+logging.AnsiReset) return rerr } - printProviderInfo(stderr, client, req.model) msgs := buildMessagesFromConfig(cfg, input) - if err := runChat(ctx, client, req, msgs, input, stdout, stderr); err != nil { + if err := runCLIJobs(ctx, jobs, msgs, input, stdout, stderr); err != nil { fmt.Fprintf(stderr, logging.AnsiBase+"hexai: error: %v"+logging.AnsiReset+"\n", err) return err } @@ -153,7 +182,7 @@ func RunWithClient(ctx context.Context, args []string, stdin io.Reader, stdout, fmt.Fprintln(stderr, logging.AnsiBase+err.Error()+logging.AnsiReset) return err } - req := defaultRequestArgs(appconfig.App{}, client) + req := requestArgs{model: strings.TrimSpace(client.DefaultModel())} printProviderInfo(stderr, client, req.model) msgs := buildMessages(input) if err := runChat(ctx, client, req, msgs, input, stdout, stderr); err != nil { @@ -163,6 +192,80 @@ func RunWithClient(ctx context.Context, args []string, stdin io.Reader, stdout, return nil } +type cliJobResult struct { + provider string + model string + output string + summary string + err error +} + +func runCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout, stderr io.Writer) error { + results := make([]*cliJobResult, len(jobs)) + var wg sync.WaitGroup + for _, job := range jobs { + job := job + wg.Add(1) + printProviderInfo(stderr, job.client, job.req.model) + go func() { + defer wg.Done() + var outBuf, errBuf bytes.Buffer + jobMsgs := make([]llm.Message, len(msgs)) + copy(jobMsgs, msgs) + err := runChat(ctx, job.client, job.req, jobMsgs, input, &outBuf, &errBuf) + results[job.index] = &cliJobResult{ + provider: job.client.Name(), + model: job.req.model, + output: outBuf.String(), + summary: errBuf.String(), + err: err, + } + }() + } + wg.Wait() + var firstErr error + printed := false + for _, res := range results { + if res == nil { + continue + } + if printed { + if _, err := io.WriteString(stdout, "\n"); err != nil { + return err + } + } + heading := fmt.Sprintf("=== %s:%s ===\n", res.provider, res.model) + if _, err := io.WriteString(stdout, heading); err != nil { + return err + } + if res.output != "" { + if _, err := io.WriteString(stdout, res.output); err != nil { + return err + } + if !strings.HasSuffix(res.output, "\n") { + if _, err := io.WriteString(stdout, "\n"); err != nil { + return err + } + } + } + printed = true + if res.summary != "" { + if _, err := io.WriteString(stderr, res.summary); err != nil { + return err + } + } + if res.err != nil { + if _, err := fmt.Fprintf(stderr, logging.AnsiBase+"hexai: provider=%s model=%s error: %v"+logging.AnsiReset+"\n", res.provider, res.model, res.err); err != nil { + return err + } + if firstErr == nil { + firstErr = res.err + } + } + } + return firstErr +} + // readInput reads from stdin and args, then combines them per CLI rules. func readInput(stdin io.Reader, args []string) (string, error) { var stdinData string |
