summaryrefslogtreecommitdiff
path: root/internal/hexaicli
diff options
context:
space:
mode:
Diffstat (limited to 'internal/hexaicli')
-rw-r--r--internal/hexaicli/run.go119
-rw-r--r--internal/hexaicli/run_more_test.go3
-rw-r--r--internal/hexaicli/run_test.go44
3 files changed, 147 insertions, 19 deletions
diff --git a/internal/hexaicli/run.go b/internal/hexaicli/run.go
index 11e8938..b965261 100644
--- a/internal/hexaicli/run.go
+++ b/internal/hexaicli/run.go
@@ -20,6 +20,84 @@ import (
"codeberg.org/snonux/hexai/internal/tmux"
)
+type requestArgs struct {
+ model string
+ options []llm.RequestOption
+}
+
+func buildCLIRequestArgs(cfg appconfig.App, client llm.Client) requestArgs {
+ provider := canonicalProvider(cfg.Provider)
+ if strings.TrimSpace(cfg.CLIProvider) != "" {
+ provider = canonicalProvider(cfg.CLIProvider)
+ }
+ if client != nil {
+ provider = strings.ToLower(strings.TrimSpace(client.Name()))
+ }
+ override := strings.TrimSpace(cfg.CLIModel)
+ fallback := strings.TrimSpace(defaultModelForProvider(cfg, provider))
+ if client != nil {
+ if dm := strings.TrimSpace(client.DefaultModel()); dm != "" {
+ fallback = dm
+ }
+ }
+ effective := override
+ if effective == "" {
+ effective = fallback
+ }
+ opts := make([]llm.RequestOption, 0, 2)
+ if override != "" {
+ opts = append(opts, llm.WithModel(override))
+ }
+ if temp, ok := cliTemperature(cfg, provider, effective); ok {
+ opts = append(opts, llm.WithTemperature(temp))
+ }
+ return requestArgs{model: effective, options: opts}
+}
+
+func defaultRequestArgs(cfg appconfig.App, client llm.Client) requestArgs {
+ model := strings.TrimSpace(cfg.CLIModel)
+ if model == "" && client != nil {
+ model = strings.TrimSpace(client.DefaultModel())
+ }
+ return requestArgs{model: model}
+}
+
+func cliTemperature(cfg appconfig.App, provider, model string) (float64, bool) {
+ if cfg.CLITemperature != nil {
+ return *cfg.CLITemperature, true
+ }
+ if cfg.CodingTemperature != nil {
+ temp := *cfg.CodingTemperature
+ if provider == "openai" && strings.HasPrefix(strings.ToLower(model), "gpt-5") && temp == 0.2 {
+ temp = 1.0
+ }
+ return temp, true
+ }
+ if provider == "openai" && strings.HasPrefix(strings.ToLower(model), "gpt-5") {
+ return 1.0, true
+ }
+ return 0, false
+}
+
+func canonicalProvider(name string) string {
+ p := strings.ToLower(strings.TrimSpace(name))
+ if p == "" {
+ return "openai"
+ }
+ return p
+}
+
+func defaultModelForProvider(cfg appconfig.App, provider string) string {
+ switch provider {
+ case "ollama":
+ return cfg.OllamaModel
+ case "copilot":
+ return cfg.CopilotModel
+ default:
+ return cfg.OpenAIModel
+ }
+}
+
// Run executes the Hexai CLI behavior given arguments and I/O streams.
// It assumes flags have already been parsed by the caller.
func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.Writer) error {
@@ -29,11 +107,16 @@ func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.
if cfg.StatsWindowMinutes > 0 {
stats.SetWindow(time.Duration(cfg.StatsWindowMinutes) * time.Minute)
}
+ providerOverride := strings.TrimSpace(cfg.CLIProvider)
+ if providerOverride != "" {
+ cfg.Provider = providerOverride
+ }
client, err := newClientFromApp(cfg)
if err != nil {
fmt.Fprintf(stderr, logging.AnsiBase+"hexai: LLM disabled: %v"+logging.AnsiReset+"\n", err)
return err
}
+ req := buildCLIRequestArgs(cfg, client)
// Prefer piped stdin when present; only open the editor when there are no args
// and no stdin content available.
input, rerr := readInput(stdin, args)
@@ -47,9 +130,9 @@ func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.
fmt.Fprintln(stderr, logging.AnsiBase+rerr.Error()+logging.AnsiReset)
return rerr
}
- printProviderInfo(stderr, client)
+ printProviderInfo(stderr, client, req.model)
msgs := buildMessagesFromConfig(cfg, input)
- if err := runChat(ctx, client, msgs, input, stdout, stderr); err != nil {
+ if err := runChat(ctx, client, req, msgs, input, stdout, stderr); err != nil {
fmt.Fprintf(stderr, logging.AnsiBase+"hexai: error: %v"+logging.AnsiReset+"\n", err)
return err
}
@@ -64,9 +147,10 @@ func RunWithClient(ctx context.Context, args []string, stdin io.Reader, stdout,
fmt.Fprintln(stderr, logging.AnsiBase+err.Error()+logging.AnsiReset)
return err
}
- printProviderInfo(stderr, client)
+ req := defaultRequestArgs(appconfig.App{}, client)
+ printProviderInfo(stderr, client, req.model)
msgs := buildMessages(input)
- if err := runChat(ctx, client, msgs, input, stdout, stderr); err != nil {
+ if err := runChat(ctx, client, req, msgs, input, stdout, stderr); err != nil {
fmt.Fprintf(stderr, logging.AnsiBase+"hexai: error: %v"+logging.AnsiReset+"\n", err)
return err
}
@@ -128,22 +212,26 @@ func buildMessagesFromConfig(cfg appconfig.App, input string) []llm.Message {
}
// runChat executes the chat request, handling streaming and summary output.
-func runChat(ctx context.Context, client llm.Client, msgs []llm.Message, input string, out io.Writer, errw io.Writer) error {
+func runChat(ctx context.Context, client llm.Client, req requestArgs, msgs []llm.Message, input string, out io.Writer, errw io.Writer) error {
start := time.Now()
// Best-effort tmux status update (colored start heartbeat)
- _ = tmux.SetStatus(tmux.FormatLLMStartStatus(client.Name(), client.DefaultModel()))
+ model := strings.TrimSpace(req.model)
+ if model == "" {
+ model = client.DefaultModel()
+ }
+ _ = tmux.SetStatus(tmux.FormatLLMStartStatus(client.Name(), model))
var output string
if s, ok := client.(llm.Streamer); ok {
var b strings.Builder
if err := s.ChatStream(ctx, msgs, func(chunk string) {
b.WriteString(chunk)
fmt.Fprint(out, chunk)
- }); err != nil {
+ }, req.options...); err != nil {
return err
}
output = b.String()
} else {
- txt, err := client.Chat(ctx, msgs)
+ txt, err := client.Chat(ctx, msgs, req.options...)
if err != nil {
return err
}
@@ -157,7 +245,7 @@ func runChat(ctx context.Context, client llm.Client, msgs []llm.Message, input s
sent += len(m.Content)
}
recv := len(output)
- _ = stats.Update(ctx, client.Name(), client.DefaultModel(), sent, recv)
+ _ = stats.Update(ctx, client.Name(), model, sent, recv)
snap, _ := stats.TakeSnapshot()
minsWin := snap.Window.Minutes()
if minsWin <= 0 {
@@ -165,20 +253,23 @@ func runChat(ctx context.Context, client llm.Client, msgs []llm.Message, input s
}
scopeReqs := int64(0)
if pe, ok := snap.Providers[client.Name()]; ok {
- if mc, ok2 := pe.Models[client.DefaultModel()]; ok2 {
+ if mc, ok2 := pe.Models[model]; ok2 {
scopeReqs = mc.Reqs
}
}
scopeRPM := float64(scopeReqs) / minsWin
fmt.Fprintf(errw, "\n"+logging.AnsiBase+"done provider=%s model=%s time=%s in_bytes=%d out_bytes=%d | global Σ reqs=%d rpm=%.2f"+logging.AnsiReset+"\n",
- client.Name(), client.DefaultModel(), dur.Round(time.Millisecond), sent, recv, snap.Global.Reqs, snap.RPM)
- _ = tmux.SetStatus(tmux.FormatGlobalStatusColored(snap.Global.Reqs, snap.RPM, snap.Global.Sent, snap.Global.Recv, client.Name(), client.DefaultModel(), scopeRPM, scopeReqs, snap.Window))
+ client.Name(), model, dur.Round(time.Millisecond), sent, recv, snap.Global.Reqs, snap.RPM)
+ _ = tmux.SetStatus(tmux.FormatGlobalStatusColored(snap.Global.Reqs, snap.RPM, snap.Global.Sent, snap.Global.Recv, client.Name(), model, scopeRPM, scopeReqs, snap.Window))
return nil
}
// printProviderInfo writes the provider/model line to stderr.
-func printProviderInfo(errw io.Writer, client llm.Client) {
- fmt.Fprintf(errw, logging.AnsiBase+"provider=%s model=%s"+logging.AnsiReset+"\n", client.Name(), client.DefaultModel())
+func printProviderInfo(errw io.Writer, client llm.Client, model string) {
+ if strings.TrimSpace(model) == "" {
+ model = client.DefaultModel()
+ }
+ fmt.Fprintf(errw, logging.AnsiBase+"provider=%s model=%s"+logging.AnsiReset+"\n", client.Name(), model)
}
// newClientFromConfig is kept for tests; delegates to llmutils.
diff --git a/internal/hexaicli/run_more_test.go b/internal/hexaicli/run_more_test.go
index bd88d56..469f0c0 100644
--- a/internal/hexaicli/run_more_test.go
+++ b/internal/hexaicli/run_more_test.go
@@ -26,7 +26,8 @@ func TestRunChat_Streaming(t *testing.T) {
var out, errw bytes.Buffer
input := "hello"
msgs := []llm.Message{{Role: "user", Content: input}}
- if err := runChat(context.Background(), streamClient{}, msgs, input, &out, &errw); err != nil {
+ req := requestArgs{model: "m"}
+ if err := runChat(context.Background(), streamClient{}, req, msgs, input, &out, &errw); err != nil {
t.Fatalf("runChat failed: %v", err)
}
if out.String() != "AB" {
diff --git a/internal/hexaicli/run_test.go b/internal/hexaicli/run_test.go
index a4184f6..4dcbbc5 100644
--- a/internal/hexaicli/run_test.go
+++ b/internal/hexaicli/run_test.go
@@ -16,6 +16,11 @@ type failingReader struct{ err error }
func (f failingReader) Read([]byte) (int, error) { return 0, f.err }
+func floatPtr(v float64) *float64 {
+ x := v
+ return &x
+}
+
func TestReadInput_Combinations(t *testing.T) {
// stdin + arg
restore, f := setStdin(t, "from-stdin")
@@ -72,7 +77,8 @@ func TestRunChat_StreamAndNonStream(t *testing.T) {
// stream path
fc := &fakeStreamer{fakeClient: fakeClient{name: "p", model: "m"}, chunks: []string{"H", "i", "!"}}
var out, errb bytes.Buffer
- if err := runChat(context.Background(), fc, buildMessages("hello"), "hello", &out, &errb); err != nil {
+ req := requestArgs{model: fc.DefaultModel()}
+ if err := runChat(context.Background(), fc, req, buildMessages("hello"), "hello", &out, &errb); err != nil {
t.Fatalf("stream: %v", err)
}
if out.String() != "Hi!" || !strings.Contains(errb.String(), "provider=p model=m") {
@@ -82,7 +88,7 @@ func TestRunChat_StreamAndNonStream(t *testing.T) {
fc2 := &fakeClient{name: "p2", model: "m2", resp: "Yo"}
out.Reset()
errb.Reset()
- if err := runChat(context.Background(), fc2, buildMessages("hello"), "hello", &out, &errb); err != nil {
+ if err := runChat(context.Background(), fc2, requestArgs{model: fc2.DefaultModel()}, buildMessages("hello"), "hello", &out, &errb); err != nil {
t.Fatalf("non-stream: %v", err)
}
if out.String() != "Yo" || !strings.Contains(errb.String(), "provider=p2 model=m2") {
@@ -101,7 +107,7 @@ func (c clientErr) DefaultModel() string { return c.model }
func TestRunChat_ErrorPaths(t *testing.T) {
ctx := context.Background()
out, errb := &bytes.Buffer{}, &bytes.Buffer{}
- if err := runChat(ctx, clientErr{"p", "m"}, buildMessages("hi"), "hi", out, errb); err == nil {
+ if err := runChat(ctx, clientErr{"p", "m"}, requestArgs{model: "m"}, buildMessages("hi"), "hi", out, errb); err == nil {
t.Fatalf("expected error from Chat")
}
}
@@ -139,12 +145,42 @@ func TestRun_OpenAI_NoKey_ShowsError(t *testing.T) {
func TestPrintProviderInfo(t *testing.T) {
var b bytes.Buffer
- printProviderInfo(&b, &fakeClient{name: "x", model: "y"})
+ printProviderInfo(&b, &fakeClient{name: "x", model: "y"}, "y")
if !strings.Contains(b.String(), "provider=x model=y") {
t.Fatalf("missing provider line: %q", b.String())
}
}
+func TestBuildCLIRequestArgs_Override(t *testing.T) {
+ cfg := appconfig.App{CLIModel: "override", CLITemperature: floatPtr(0.7), Provider: "openai", CLIProvider: "copilot", CopilotModel: "gpt-4o"}
+ req := buildCLIRequestArgs(cfg, &fakeClient{name: "copilot", model: "default"})
+ if req.model != "override" {
+ t.Fatalf("expected model override, got %q", req.model)
+ }
+ var opts llm.Options
+ for _, o := range req.options {
+ o(&opts)
+ }
+ if opts.Model != "override" || opts.Temperature != 0.7 {
+ t.Fatalf("unexpected options: %+v", opts)
+ }
+}
+
+func TestBuildCLIRequestArgs_Gpt5Temp(t *testing.T) {
+ cfg := appconfig.App{Provider: "openai", CodingTemperature: floatPtr(0.2)}
+ req := buildCLIRequestArgs(cfg, &fakeClient{name: "openai", model: "gpt-5.1"})
+ if req.model != "gpt-5.1" {
+ t.Fatalf("expected fallback model, got %q", req.model)
+ }
+ var opts llm.Options
+ for _, o := range req.options {
+ o(&opts)
+ }
+ if opts.Temperature != 1.0 {
+ t.Fatalf("expected temp 1.0, got %v", opts.Temperature)
+ }
+}
+
func TestNewClientFromConfig_Ollama(t *testing.T) {
cfg := appconfig.App{Provider: "ollama", OllamaBaseURL: "http://x", OllamaModel: "m"}
c, err := newClientFromConfig(cfg)