diff options
Diffstat (limited to 'internal/appconfig')
| -rw-r--r-- | internal/appconfig/config.go | 264 | ||||
| -rw-r--r-- | internal/appconfig/config_env_model_test.go | 74 | ||||
| -rw-r--r-- | internal/appconfig/config_test.go | 84 |
3 files changed, 397 insertions, 25 deletions
diff --git a/internal/appconfig/config.go b/internal/appconfig/config.go index adf9b75..47abaaf 100644 --- a/internal/appconfig/config.go +++ b/internal/appconfig/config.go @@ -58,6 +58,20 @@ type App struct { // Default temperature for Copilot requests (nil means use provider default) CopilotTemperature *float64 `json:"copilot_temperature" toml:"copilot_temperature"` + // Per-surface model overrides (fall back to provider defaults when unset) + CompletionModel string `json:"completion_model" toml:"completion_model"` + CompletionTemperature *float64 `json:"completion_temperature" toml:"completion_temperature"` + CompletionProvider string `json:"completion_provider" toml:"completion_provider"` + CodeActionModel string `json:"code_action_model" toml:"code_action_model"` + CodeActionTemperature *float64 `json:"code_action_temperature" toml:"code_action_temperature"` + CodeActionProvider string `json:"code_action_provider" toml:"code_action_provider"` + ChatModel string `json:"chat_model" toml:"chat_model"` + ChatTemperature *float64 `json:"chat_temperature" toml:"chat_temperature"` + ChatProvider string `json:"chat_provider" toml:"chat_provider"` + CLIModel string `json:"cli_model" toml:"cli_model"` + CLITemperature *float64 `json:"cli_temperature" toml:"cli_temperature"` + CLIProvider string `json:"cli_provider" toml:"cli_provider"` + // Prompt templates (configured only via file; no env overrides) // Completion/chat/code action/CLI prompt strings. See config.toml.example for placeholders. // Completion @@ -589,7 +603,7 @@ func loadFromFile(path string, logger *log.Logger) (*App, error) { "copilot_model": {}, "copilot_base_url": {}, "copilot_temperature": {}, } for k := range raw { - if _, isTable := map[string]struct{}{"general": {}, "logging": {}, "completion": {}, "triggers": {}, "inline": {}, "chat": {}, "provider": {}, "openai": {}, "copilot": {}, "ollama": {}, "prompts": {}}[k]; isTable { + if _, isTable := map[string]struct{}{"general": {}, "logging": {}, "completion": {}, "triggers": {}, "inline": {}, "chat": {}, "provider": {}, "models": {}, "openai": {}, "copilot": {}, "ollama": {}, "prompts": {}}[k]; isTable { continue } if _, isLegacy := legacy[k]; isLegacy { @@ -629,12 +643,170 @@ func loadFromFile(path string, logger *log.Logger) (*App, error) { } } } + if m := parseSurfaceModels(raw, logger); m != nil { + tab.mergeSurfaceModels(m) + } return &tab, nil } +func parseSurfaceModels(raw map[string]any, logger *log.Logger) *App { + modelsRaw, ok := raw["models"] + if !ok { + return nil + } + table, ok := modelsRaw.(map[string]any) + if !ok { + if logger != nil { + logger.Printf("config: ignoring models section (expected table, got %T)", modelsRaw) + } + return nil + } + var out App + var any bool + if model, provider, temp, ok := decodeModelEntry(table["completion"], "models.completion", logger); ok { + if model != "" { + out.CompletionModel = model + } + if provider != "" { + out.CompletionProvider = provider + } + if temp != nil { + out.CompletionTemperature = temp + } + any = true + } + if model, provider, temp, ok := decodeModelEntry(table["code_action"], "models.code_action", logger); ok { + if model != "" { + out.CodeActionModel = model + } + if provider != "" { + out.CodeActionProvider = provider + } + if temp != nil { + out.CodeActionTemperature = temp + } + any = true + } + if model, provider, temp, ok := decodeModelEntry(table["chat"], "models.chat", logger); ok { + if model != "" { + out.ChatModel = model + } + if provider != "" { + out.ChatProvider = provider + } + if temp != nil { + out.ChatTemperature = temp + } + any = true + } + if model, provider, temp, ok := decodeModelEntry(table["cli"], "models.cli", logger); ok { + if model != "" { + out.CLIModel = model + } + if provider != "" { + out.CLIProvider = provider + } + if temp != nil { + out.CLITemperature = temp + } + any = true + } + if !any { + return nil + } + return &out +} + +func decodeModelEntry(raw any, path string, logger *log.Logger) (string, string, *float64, bool) { + if raw == nil { + return "", "", nil, false + } + switch v := raw.(type) { + case string: + model := strings.TrimSpace(v) + if model == "" { + return "", "", nil, false + } + return model, "", nil, true + case map[string]any: + model := "" + provider := "" + if m, ok := v["model"]; ok { + s, ok := m.(string) + if !ok { + if logger != nil { + logger.Printf("config: %s.model must be a string", path) + } + return "", "", nil, false + } + model = strings.TrimSpace(s) + } + if pRaw, ok := v["provider"]; ok { + ps, ok := pRaw.(string) + if !ok { + if logger != nil { + logger.Printf("config: %s.provider must be a string", path) + } + return "", "", nil, false + } + provider = strings.TrimSpace(ps) + } + var tempPtr *float64 + if tRaw, ok := v["temperature"]; ok { + parsed, ok := parseTemperatureValue(tRaw, path, logger) + if !ok { + return "", "", nil, false + } + tempPtr = parsed + } + if model == "" && tempPtr == nil && provider == "" { + return "", "", nil, false + } + return model, provider, tempPtr, true + default: + if logger != nil { + logger.Printf("config: %s must be a string or table, got %T", path, raw) + } + return "", "", nil, false + } +} + +func parseTemperatureValue(raw any, path string, logger *log.Logger) (*float64, bool) { + switch v := raw.(type) { + case float64: + return floatPtr(v), true + case int64: + return floatPtr(float64(v)), true + case string: + s := strings.TrimSpace(v) + if s == "" { + return nil, true + } + f, err := strconv.ParseFloat(s, 64) + if err != nil { + if logger != nil { + logger.Printf("config: %s.temperature invalid: %v", path, err) + } + return nil, false + } + return floatPtr(f), true + default: + if logger != nil { + logger.Printf("config: %s.temperature must be numeric or string, got %T", path, raw) + } + return nil, false + } +} + +func floatPtr(v float64) *float64 { + f := v + return &f +} + func (a *App) mergeWith(other *App) { a.mergeBasics(other) a.mergeProviderFields(other) + a.mergeSurfaceModels(other) a.mergePrompts(other) } @@ -687,6 +859,46 @@ func (a *App) mergeBasics(other *App) { } } +// mergeSurfaceModels copies per-surface model and temperature overrides. +func (a *App) mergeSurfaceModels(other *App) { + if s := strings.TrimSpace(other.CompletionModel); s != "" { + a.CompletionModel = s + } + if other.CompletionTemperature != nil { + a.CompletionTemperature = other.CompletionTemperature + } + if s := strings.TrimSpace(other.CompletionProvider); s != "" { + a.CompletionProvider = s + } + if s := strings.TrimSpace(other.CodeActionModel); s != "" { + a.CodeActionModel = s + } + if other.CodeActionTemperature != nil { + a.CodeActionTemperature = other.CodeActionTemperature + } + if s := strings.TrimSpace(other.CodeActionProvider); s != "" { + a.CodeActionProvider = s + } + if s := strings.TrimSpace(other.ChatModel); s != "" { + a.ChatModel = s + } + if other.ChatTemperature != nil { + a.ChatTemperature = other.ChatTemperature + } + if s := strings.TrimSpace(other.ChatProvider); s != "" { + a.ChatProvider = s + } + if s := strings.TrimSpace(other.CLIModel); s != "" { + a.CLIModel = s + } + if other.CLITemperature != nil { + a.CLITemperature = other.CLITemperature + } + if s := strings.TrimSpace(other.CLIProvider); s != "" { + a.CLIProvider = s + } +} + // mergePrompts copies non-empty prompt templates from other. func (a *App) mergePrompts(other *App) { // Completion @@ -1050,6 +1262,56 @@ func loadFromEnv(logger *log.Logger) *App { any = true } + // Per-surface overrides + if s := getenv("HEXAI_MODEL_COMPLETION"); s != "" { + out.CompletionModel = s + any = true + } + if f, ok := parseFloatPtr("HEXAI_TEMPERATURE_COMPLETION"); ok { + out.CompletionTemperature = f + any = true + } + if s := getenv("HEXAI_PROVIDER_COMPLETION"); s != "" { + out.CompletionProvider = s + any = true + } + if s := getenv("HEXAI_MODEL_CODE_ACTION"); s != "" { + out.CodeActionModel = s + any = true + } + if f, ok := parseFloatPtr("HEXAI_TEMPERATURE_CODE_ACTION"); ok { + out.CodeActionTemperature = f + any = true + } + if s := getenv("HEXAI_PROVIDER_CODE_ACTION"); s != "" { + out.CodeActionProvider = s + any = true + } + if s := getenv("HEXAI_MODEL_CHAT"); s != "" { + out.ChatModel = s + any = true + } + if f, ok := parseFloatPtr("HEXAI_TEMPERATURE_CHAT"); ok { + out.ChatTemperature = f + any = true + } + if s := getenv("HEXAI_PROVIDER_CHAT"); s != "" { + out.ChatProvider = s + any = true + } + if s := getenv("HEXAI_MODEL_CLI"); s != "" { + out.CLIModel = s + any = true + } + if f, ok := parseFloatPtr("HEXAI_TEMPERATURE_CLI"); ok { + out.CLITemperature = f + any = true + } + if s := getenv("HEXAI_PROVIDER_CLI"); s != "" { + out.CLIProvider = s + any = true + } + if !any { return nil } diff --git a/internal/appconfig/config_env_model_test.go b/internal/appconfig/config_env_model_test.go index 2db2bb5..f34416d 100644 --- a/internal/appconfig/config_env_model_test.go +++ b/internal/appconfig/config_env_model_test.go @@ -1,37 +1,65 @@ package appconfig import ( - "log" - "os" - "testing" + "log" + "os" + "testing" ) // Test that HEXAI_MODEL applies to provider model fields and that // provider-specific envs take precedence when both are set. func TestEnv_GenericModelOverrideAndPrecedence(t *testing.T) { - t.Setenv("HEXAI_MODEL", "gpt-5-codex") - t.Setenv("HEXAI_PROVIDER", "openai") - // No provider-specific env set yet: HEXAI_MODEL should flow into OpenAIModel - cfg := Load(log.New(os.Stderr, "test ", 0)) - if cfg.OpenAIModel != "gpt-5-codex" { - t.Fatalf("expected OpenAIModel=gpt-5-codex via HEXAI_MODEL, got %q", cfg.OpenAIModel) - } + t.Setenv("HEXAI_MODEL", "gpt-5-codex") + t.Setenv("HEXAI_PROVIDER", "openai") + // No provider-specific env set yet: HEXAI_MODEL should flow into OpenAIModel + cfg := Load(log.New(os.Stderr, "test ", 0)) + if cfg.OpenAIModel != "gpt-5-codex" { + t.Fatalf("expected OpenAIModel=gpt-5-codex via HEXAI_MODEL, got %q", cfg.OpenAIModel) + } - // Now set a provider-specific model; it should win over HEXAI_MODEL - t.Setenv("HEXAI_OPENAI_MODEL", "gpt-5-thinking") - cfg2 := Load(log.New(os.Stderr, "test ", 0)) - if cfg2.OpenAIModel != "gpt-5-thinking" { - t.Fatalf("expected OpenAIModel from HEXAI_OPENAI_MODEL to win, got %q", cfg2.OpenAIModel) - } + // Now set a provider-specific model; it should win over HEXAI_MODEL + t.Setenv("HEXAI_OPENAI_MODEL", "gpt-5-thinking") + cfg2 := Load(log.New(os.Stderr, "test ", 0)) + if cfg2.OpenAIModel != "gpt-5-thinking" { + t.Fatalf("expected OpenAIModel from HEXAI_OPENAI_MODEL to win, got %q", cfg2.OpenAIModel) + } } // Test that HEXAI_MODEL_FORCE overrides provider-specific envs (used by CLI --model). func TestEnv_ModelForce_OverridesProviderSpecific(t *testing.T) { - t.Setenv("HEXAI_OPENAI_MODEL", "gpt-5-main") - t.Setenv("HEXAI_MODEL_FORCE", "gpt-5-codex") - t.Setenv("HEXAI_PROVIDER", "openai") - cfg := Load(log.New(os.Stderr, "test ", 0)) - if cfg.OpenAIModel != "gpt-5-codex" { - t.Fatalf("expected OpenAIModel forced to gpt-5-codex, got %q", cfg.OpenAIModel) - } + t.Setenv("HEXAI_OPENAI_MODEL", "gpt-5-main") + t.Setenv("HEXAI_MODEL_FORCE", "gpt-5-codex") + t.Setenv("HEXAI_PROVIDER", "openai") + cfg := Load(log.New(os.Stderr, "test ", 0)) + if cfg.OpenAIModel != "gpt-5-codex" { + t.Fatalf("expected OpenAIModel forced to gpt-5-codex, got %q", cfg.OpenAIModel) + } +} + +func TestEnv_SurfaceModelOverrides(t *testing.T) { + t.Setenv("HEXAI_MODEL_COMPLETION", "gpt-c") + t.Setenv("HEXAI_TEMPERATURE_COMPLETION", "0.44") + t.Setenv("HEXAI_PROVIDER_COMPLETION", "copilot") + t.Setenv("HEXAI_MODEL_CLI", "gpt-cli") + t.Setenv("HEXAI_TEMPERATURE_CLI", "0.22") + t.Setenv("HEXAI_PROVIDER_CLI", "ollama") + cfg := Load(log.New(os.Stderr, "test ", 0)) + if cfg.CompletionModel != "gpt-c" { + t.Fatalf("expected completion model override, got %q", cfg.CompletionModel) + } + if cfg.CompletionTemperature == nil || *cfg.CompletionTemperature != 0.44 { + t.Fatalf("expected completion temperature override, got %v", cfg.CompletionTemperature) + } + if cfg.CompletionProvider != "copilot" { + t.Fatalf("expected completion provider override, got %q", cfg.CompletionProvider) + } + if cfg.CLIModel != "gpt-cli" { + t.Fatalf("expected cli model override, got %q", cfg.CLIModel) + } + if cfg.CLITemperature == nil || *cfg.CLITemperature != 0.22 { + t.Fatalf("expected cli temperature override, got %v", cfg.CLITemperature) + } + if cfg.CLIProvider != "ollama" { + t.Fatalf("expected cli provider override, got %q", cfg.CLIProvider) + } } diff --git a/internal/appconfig/config_test.go b/internal/appconfig/config_test.go index b03137e..ea68305 100644 --- a/internal/appconfig/config_test.go +++ b/internal/appconfig/config_test.go @@ -88,6 +88,24 @@ completion_throttle_ms = 300 [triggers] trigger_characters = [".", ":"] +[models.completion] +model = "gpt-file-complete" +provider = "openai" + +[models.code_action] +model = "gpt-file-action" +temperature = 0.45 +provider = "copilot" + +[models.chat] +model = "gpt-file-chat" +provider = "openai" + +[models.cli] +model = "gpt-file-cli" +temperature = 0.15 +provider = "ollama" + [provider] name = "openai" @@ -107,6 +125,10 @@ model = "ghost" temperature = 0.0 `) + if _, err := loadFromFile(cfgPath, newLogger()); err != nil { + t.Fatalf("loadFromFile: %v", err) + } + // Env overrides take precedence withEnv(t, "HEXAI_MAX_TOKENS", "321") withEnv(t, "HEXAI_CONTEXT_MODE", "always-full") @@ -128,6 +150,18 @@ temperature = 0.0 withEnv(t, "HEXAI_COPILOT_BASE_URL", "http://copilot-override") withEnv(t, "HEXAI_COPILOT_MODEL", "ghost-override") withEnv(t, "HEXAI_COPILOT_TEMPERATURE", "0.3") + withEnv(t, "HEXAI_MODEL_COMPLETION", "env-completion") + withEnv(t, "HEXAI_TEMPERATURE_COMPLETION", "0.33") + withEnv(t, "HEXAI_PROVIDER_COMPLETION", "copilot") + withEnv(t, "HEXAI_MODEL_CODE_ACTION", "env-action") + withEnv(t, "HEXAI_TEMPERATURE_CODE_ACTION", "0.55") + withEnv(t, "HEXAI_PROVIDER_CODE_ACTION", "openai") + withEnv(t, "HEXAI_MODEL_CHAT", "env-chat") + withEnv(t, "HEXAI_TEMPERATURE_CHAT", "0.66") + withEnv(t, "HEXAI_PROVIDER_CHAT", "copilot") + withEnv(t, "HEXAI_MODEL_CLI", "env-cli") + withEnv(t, "HEXAI_TEMPERATURE_CLI", "0.77") + withEnv(t, "HEXAI_PROVIDER_CLI", "ollama") logger := newLogger() cfg := Load(logger) @@ -158,11 +192,35 @@ temperature = 0.0 if cfg.CopilotBaseURL != "http://copilot-override" || cfg.CopilotModel != "ghost-override" || cfg.CopilotTemperature == nil || *cfg.CopilotTemperature != 0.3 { t.Fatalf("copilot overrides not applied: %+v", cfg) } + if cfg.CompletionModel != "env-completion" || cfg.CompletionTemperature == nil || *cfg.CompletionTemperature != 0.33 { + t.Fatalf("completion overrides not applied: model=%q temp=%v", cfg.CompletionModel, cfg.CompletionTemperature) + } + if cfg.CompletionProvider != "copilot" { + t.Fatalf("completion provider override not applied: %q", cfg.CompletionProvider) + } + if cfg.CodeActionModel != "env-action" || cfg.CodeActionTemperature == nil || *cfg.CodeActionTemperature != 0.55 { + t.Fatalf("code action overrides not applied: model=%q temp=%v", cfg.CodeActionModel, cfg.CodeActionTemperature) + } + if cfg.CodeActionProvider != "openai" { + t.Fatalf("code action provider override not applied: %q", cfg.CodeActionProvider) + } + if cfg.ChatModel != "env-chat" || cfg.ChatTemperature == nil || *cfg.ChatTemperature != 0.66 { + t.Fatalf("chat overrides not applied: model=%q temp=%v", cfg.ChatModel, cfg.ChatTemperature) + } + if cfg.ChatProvider != "copilot" { + t.Fatalf("chat provider override not applied: %q", cfg.ChatProvider) + } + if cfg.CLIModel != "env-cli" || cfg.CLITemperature == nil || *cfg.CLITemperature != 0.77 { + t.Fatalf("cli overrides not applied: model=%q temp=%v", cfg.CLIModel, cfg.CLITemperature) + } + if cfg.CLIProvider != "ollama" { + t.Fatalf("cli provider override not applied: %q", cfg.CLIProvider) + } // Ensure file values would have applied absent env // Spot-check: reset env and reload for _, k := range []string{ - "HEXAI_MAX_TOKENS", "HEXAI_CONTEXT_MODE", "HEXAI_CONTEXT_WINDOW_LINES", "HEXAI_MAX_CONTEXT_TOKENS", "HEXAI_LOG_PREVIEW_LIMIT", "HEXAI_CODING_TEMPERATURE", "HEXAI_MANUAL_INVOKE_MIN_PREFIX", "HEXAI_COMPLETION_DEBOUNCE_MS", "HEXAI_COMPLETION_THROTTLE_MS", "HEXAI_TRIGGER_CHARACTERS", "HEXAI_PROVIDER", "HEXAI_OPENAI_BASE_URL", "HEXAI_OPENAI_MODEL", "HEXAI_OPENAI_TEMPERATURE", "HEXAI_OLLAMA_BASE_URL", "HEXAI_OLLAMA_MODEL", "HEXAI_OLLAMA_TEMPERATURE", "HEXAI_COPILOT_BASE_URL", "HEXAI_COPILOT_MODEL", "HEXAI_COPILOT_TEMPERATURE", + "HEXAI_MAX_TOKENS", "HEXAI_CONTEXT_MODE", "HEXAI_CONTEXT_WINDOW_LINES", "HEXAI_MAX_CONTEXT_TOKENS", "HEXAI_LOG_PREVIEW_LIMIT", "HEXAI_CODING_TEMPERATURE", "HEXAI_MANUAL_INVOKE_MIN_PREFIX", "HEXAI_COMPLETION_DEBOUNCE_MS", "HEXAI_COMPLETION_THROTTLE_MS", "HEXAI_TRIGGER_CHARACTERS", "HEXAI_PROVIDER", "HEXAI_OPENAI_BASE_URL", "HEXAI_OPENAI_MODEL", "HEXAI_OPENAI_TEMPERATURE", "HEXAI_OLLAMA_BASE_URL", "HEXAI_OLLAMA_MODEL", "HEXAI_OLLAMA_TEMPERATURE", "HEXAI_COPILOT_BASE_URL", "HEXAI_COPILOT_MODEL", "HEXAI_COPILOT_TEMPERATURE", "HEXAI_MODEL_COMPLETION", "HEXAI_TEMPERATURE_COMPLETION", "HEXAI_MODEL_CODE_ACTION", "HEXAI_TEMPERATURE_CODE_ACTION", "HEXAI_MODEL_CHAT", "HEXAI_TEMPERATURE_CHAT", "HEXAI_MODEL_CLI", "HEXAI_TEMPERATURE_CLI", "HEXAI_PROVIDER_COMPLETION", "HEXAI_PROVIDER_CODE_ACTION", "HEXAI_PROVIDER_CHAT", "HEXAI_PROVIDER_CLI", } { t.Setenv(k, "") } @@ -176,6 +234,30 @@ temperature = 0.0 if cfg2.OpenAIBaseURL != "https://api.example" || cfg2.OpenAIModel != "gpt-x" || cfg2.OpenAITemperature == nil || *cfg2.OpenAITemperature != 0.0 { t.Fatalf("file merge (openai) not applied: %+v", cfg2) } + if cfg2.CompletionModel != "gpt-file-complete" || cfg2.CompletionTemperature != nil { + t.Fatalf("file merge (completion) not applied: %+v", cfg2) + } + if cfg2.CompletionProvider != "openai" { + t.Fatalf("file merge (completion provider) not applied: %q", cfg2.CompletionProvider) + } + if cfg2.CodeActionModel != "gpt-file-action" || cfg2.CodeActionTemperature == nil || *cfg2.CodeActionTemperature != 0.45 { + t.Fatalf("file merge (code action) not applied: %+v", cfg2) + } + if cfg2.CodeActionProvider != "copilot" { + t.Fatalf("file merge (code action provider) not applied: %q", cfg2.CodeActionProvider) + } + if cfg2.ChatModel != "gpt-file-chat" || cfg2.ChatTemperature != nil { + t.Fatalf("file merge (chat) not applied: %+v", cfg2) + } + if cfg2.ChatProvider != "openai" { + t.Fatalf("file merge (chat provider) not applied: %q", cfg2.ChatProvider) + } + if cfg2.CLIModel != "gpt-file-cli" || cfg2.CLITemperature == nil || *cfg2.CLITemperature != 0.15 { + t.Fatalf("file merge (cli) not applied: %+v", cfg2) + } + if cfg2.CLIProvider != "ollama" { + t.Fatalf("file merge (cli provider) not applied: %q", cfg2.CLIProvider) + } } func TestGetConfigPath_XDG(t *testing.T) { |
