summaryrefslogtreecommitdiff
path: root/internal/hexaicli/run.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/hexaicli/run.go')
-rw-r--r--internal/hexaicli/run.go181
1 files changed, 142 insertions, 39 deletions
diff --git a/internal/hexaicli/run.go b/internal/hexaicli/run.go
index 5f6284c..06fcb83 100644
--- a/internal/hexaicli/run.go
+++ b/internal/hexaicli/run.go
@@ -3,12 +3,14 @@
package hexaicli
import (
+ "bytes"
"context"
"fmt"
"io"
"log"
"os"
"strings"
+ "sync"
"time"
"codeberg.org/snonux/hexai/internal/appconfig"
@@ -25,48 +27,79 @@ type requestArgs struct {
options []llm.RequestOption
}
-func buildCLIRequestArgs(cfg appconfig.App, client llm.Client) requestArgs {
- provider := canonicalProvider(cfg.Provider)
+type cliJob struct {
+ index int
+ provider string
+ entry appconfig.SurfaceConfig
+ client llm.Client
+ req requestArgs
+}
+
+func buildCLIJobs(cfg appconfig.App) ([]cliJob, error) {
entries := cfg.CLIConfigs
if len(entries) == 0 {
- entries = []appconfig.SurfaceConfig{{Provider: cfg.Provider, Model: strings.TrimSpace(defaultModelForProvider(cfg, provider))}}
+ entries = []appconfig.SurfaceConfig{{}}
+ }
+ jobs := make([]cliJob, 0, len(entries))
+ for i, raw := range entries {
+ entry := appconfig.SurfaceConfig{Provider: strings.TrimSpace(raw.Provider), Model: strings.TrimSpace(raw.Model), Temperature: raw.Temperature}
+ provider := entry.Provider
+ if provider == "" {
+ provider = cfg.Provider
+ }
+ provider = canonicalProvider(provider)
+ derived := cfg
+ derived.Provider = provider
+ switch provider {
+ case "openai":
+ if entry.Model != "" {
+ derived.OpenAIModel = entry.Model
+ }
+ case "copilot":
+ if entry.Model != "" {
+ derived.CopilotModel = entry.Model
+ }
+ case "ollama":
+ if entry.Model != "" {
+ derived.OllamaModel = entry.Model
+ }
+ }
+ client, err := newClientFromApp(derived)
+ if err != nil {
+ return nil, err
+ }
+ req := buildCLIRequest(entry, provider, cfg, client)
+ if strings.TrimSpace(req.model) == "" {
+ req.model = strings.TrimSpace(client.DefaultModel())
+ }
+ jobs = append(jobs, cliJob{index: i, provider: provider, entry: entry, client: client, req: req})
}
- primary := entries[0]
- if strings.TrimSpace(primary.Provider) != "" {
- provider = canonicalProvider(primary.Provider)
+ return jobs, nil
+}
+
+func buildCLIRequest(entry appconfig.SurfaceConfig, provider string, cfg appconfig.App, client llm.Client) requestArgs {
+ opts := make([]llm.RequestOption, 0, 2)
+ if cfg.MaxTokens > 0 {
+ opts = append(opts, llm.WithMaxTokens(cfg.MaxTokens))
}
- model := strings.TrimSpace(primary.Model)
- if client != nil {
- provider = strings.ToLower(strings.TrimSpace(client.Name()))
- if model == "" {
+ model := strings.TrimSpace(entry.Model)
+ if model == "" {
+ if client != nil {
model = strings.TrimSpace(client.DefaultModel())
}
+ if model == "" {
+ model = strings.TrimSpace(defaultModelForProvider(cfg, provider))
+ }
}
- if model == "" {
- model = strings.TrimSpace(defaultModelForProvider(cfg, provider))
- }
- opts := make([]llm.RequestOption, 0, 2)
- if strings.TrimSpace(primary.Model) != "" {
- opts = append(opts, llm.WithModel(strings.TrimSpace(primary.Model)))
+ if entry.Model != "" {
+ opts = append(opts, llm.WithModel(entry.Model))
}
- if temp, ok := cliTemperatureFromEntry(cfg, provider, primary, model); ok {
+ if temp, ok := cliTemperatureFromEntry(cfg, provider, entry, model); ok {
opts = append(opts, llm.WithTemperature(temp))
}
return requestArgs{model: model, options: opts}
}
-func defaultRequestArgs(cfg appconfig.App, client llm.Client) requestArgs {
- if len(cfg.CLIConfigs) > 0 {
- if m := strings.TrimSpace(cfg.CLIConfigs[0].Model); m != "" {
- return requestArgs{model: m}
- }
- }
- if client != nil {
- return requestArgs{model: strings.TrimSpace(client.DefaultModel())}
- }
- return requestArgs{}
-}
-
func cliTemperatureFromEntry(cfg appconfig.App, provider string, entry appconfig.SurfaceConfig, model string) (float64, bool) {
if entry.Temperature != nil {
return *entry.Temperature, true
@@ -112,17 +145,14 @@ func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.
if cfg.StatsWindowMinutes > 0 {
stats.SetWindow(time.Duration(cfg.StatsWindowMinutes) * time.Minute)
}
- if len(cfg.CLIConfigs) > 0 {
- if provider := strings.TrimSpace(cfg.CLIConfigs[0].Provider); provider != "" {
- cfg.Provider = provider
- }
- }
- client, err := newClientFromApp(cfg)
+ jobs, err := buildCLIJobs(cfg)
if err != nil {
fmt.Fprintf(stderr, logging.AnsiBase+"hexai: LLM disabled: %v"+logging.AnsiReset+"\n", err)
return err
}
- req := buildCLIRequestArgs(cfg, client)
+ if len(jobs) == 0 {
+ return fmt.Errorf("hexai: no CLI providers configured")
+ }
// Prefer piped stdin when present; only open the editor when there are no args
// and no stdin content available.
input, rerr := readInput(stdin, args)
@@ -136,9 +166,8 @@ func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.
fmt.Fprintln(stderr, logging.AnsiBase+rerr.Error()+logging.AnsiReset)
return rerr
}
- printProviderInfo(stderr, client, req.model)
msgs := buildMessagesFromConfig(cfg, input)
- if err := runChat(ctx, client, req, msgs, input, stdout, stderr); err != nil {
+ if err := runCLIJobs(ctx, jobs, msgs, input, stdout, stderr); err != nil {
fmt.Fprintf(stderr, logging.AnsiBase+"hexai: error: %v"+logging.AnsiReset+"\n", err)
return err
}
@@ -153,7 +182,7 @@ func RunWithClient(ctx context.Context, args []string, stdin io.Reader, stdout,
fmt.Fprintln(stderr, logging.AnsiBase+err.Error()+logging.AnsiReset)
return err
}
- req := defaultRequestArgs(appconfig.App{}, client)
+ req := requestArgs{model: strings.TrimSpace(client.DefaultModel())}
printProviderInfo(stderr, client, req.model)
msgs := buildMessages(input)
if err := runChat(ctx, client, req, msgs, input, stdout, stderr); err != nil {
@@ -163,6 +192,80 @@ func RunWithClient(ctx context.Context, args []string, stdin io.Reader, stdout,
return nil
}
+type cliJobResult struct {
+ provider string
+ model string
+ output string
+ summary string
+ err error
+}
+
+func runCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout, stderr io.Writer) error {
+ results := make([]*cliJobResult, len(jobs))
+ var wg sync.WaitGroup
+ for _, job := range jobs {
+ job := job
+ wg.Add(1)
+ printProviderInfo(stderr, job.client, job.req.model)
+ go func() {
+ defer wg.Done()
+ var outBuf, errBuf bytes.Buffer
+ jobMsgs := make([]llm.Message, len(msgs))
+ copy(jobMsgs, msgs)
+ err := runChat(ctx, job.client, job.req, jobMsgs, input, &outBuf, &errBuf)
+ results[job.index] = &cliJobResult{
+ provider: job.client.Name(),
+ model: job.req.model,
+ output: outBuf.String(),
+ summary: errBuf.String(),
+ err: err,
+ }
+ }()
+ }
+ wg.Wait()
+ var firstErr error
+ printed := false
+ for _, res := range results {
+ if res == nil {
+ continue
+ }
+ if printed {
+ if _, err := io.WriteString(stdout, "\n"); err != nil {
+ return err
+ }
+ }
+ heading := fmt.Sprintf("=== %s:%s ===\n", res.provider, res.model)
+ if _, err := io.WriteString(stdout, heading); err != nil {
+ return err
+ }
+ if res.output != "" {
+ if _, err := io.WriteString(stdout, res.output); err != nil {
+ return err
+ }
+ if !strings.HasSuffix(res.output, "\n") {
+ if _, err := io.WriteString(stdout, "\n"); err != nil {
+ return err
+ }
+ }
+ }
+ printed = true
+ if res.summary != "" {
+ if _, err := io.WriteString(stderr, res.summary); err != nil {
+ return err
+ }
+ }
+ if res.err != nil {
+ if _, err := fmt.Fprintf(stderr, logging.AnsiBase+"hexai: provider=%s model=%s error: %v"+logging.AnsiReset+"\n", res.provider, res.model, res.err); err != nil {
+ return err
+ }
+ if firstErr == nil {
+ firstErr = res.err
+ }
+ }
+ }
+ return firstErr
+}
+
// readInput reads from stdin and args, then combines them per CLI rules.
func readInput(stdin io.Reader, args []string) (string, error) {
var stdinData string