diff options
Diffstat (limited to 'internal/appconfig')
| -rw-r--r-- | internal/appconfig/config.go | 232 | ||||
| -rw-r--r-- | internal/appconfig/config_env_model_test.go | 32 | ||||
| -rw-r--r-- | internal/appconfig/config_test.go | 96 |
3 files changed, 171 insertions, 189 deletions
diff --git a/internal/appconfig/config.go b/internal/appconfig/config.go index 47abaaf..63d0437 100644 --- a/internal/appconfig/config.go +++ b/internal/appconfig/config.go @@ -13,6 +13,13 @@ import ( "github.com/pelletier/go-toml/v2" ) +// SurfaceConfig describes a provider/model pairing (with optional temperature). +type SurfaceConfig struct { + Provider string + Model string + Temperature *float64 +} + // App holds user-configurable settings read from ~/.config/hexai/config.toml. type App struct { MaxTokens int `json:"max_tokens" toml:"max_tokens"` @@ -58,19 +65,11 @@ type App struct { // Default temperature for Copilot requests (nil means use provider default) CopilotTemperature *float64 `json:"copilot_temperature" toml:"copilot_temperature"` - // Per-surface model overrides (fall back to provider defaults when unset) - CompletionModel string `json:"completion_model" toml:"completion_model"` - CompletionTemperature *float64 `json:"completion_temperature" toml:"completion_temperature"` - CompletionProvider string `json:"completion_provider" toml:"completion_provider"` - CodeActionModel string `json:"code_action_model" toml:"code_action_model"` - CodeActionTemperature *float64 `json:"code_action_temperature" toml:"code_action_temperature"` - CodeActionProvider string `json:"code_action_provider" toml:"code_action_provider"` - ChatModel string `json:"chat_model" toml:"chat_model"` - ChatTemperature *float64 `json:"chat_temperature" toml:"chat_temperature"` - ChatProvider string `json:"chat_provider" toml:"chat_provider"` - CLIModel string `json:"cli_model" toml:"cli_model"` - CLITemperature *float64 `json:"cli_temperature" toml:"cli_temperature"` - CLIProvider string `json:"cli_provider" toml:"cli_provider"` + // Per-surface provider/model configurations (ordered; first entry is primary) + CompletionConfigs []SurfaceConfig `json:"-" toml:"-"` + CodeActionConfigs []SurfaceConfig `json:"-" toml:"-"` + ChatConfigs []SurfaceConfig `json:"-" toml:"-"` + CLIConfigs []SurfaceConfig `json:"-" toml:"-"` // Prompt templates (configured only via file; no env overrides) // Completion/chat/code action/CLI prompt strings. See config.toml.example for placeholders. @@ -662,72 +661,66 @@ func parseSurfaceModels(raw map[string]any, logger *log.Logger) *App { return nil } var out App - var any bool - if model, provider, temp, ok := decodeModelEntry(table["completion"], "models.completion", logger); ok { - if model != "" { - out.CompletionModel = model - } - if provider != "" { - out.CompletionProvider = provider - } - if temp != nil { - out.CompletionTemperature = temp - } - any = true - } - if model, provider, temp, ok := decodeModelEntry(table["code_action"], "models.code_action", logger); ok { - if model != "" { - out.CodeActionModel = model - } - if provider != "" { - out.CodeActionProvider = provider - } - if temp != nil { - out.CodeActionTemperature = temp - } - any = true - } - if model, provider, temp, ok := decodeModelEntry(table["chat"], "models.chat", logger); ok { - if model != "" { - out.ChatModel = model - } - if provider != "" { - out.ChatProvider = provider - } - if temp != nil { - out.ChatTemperature = temp - } - any = true + appendEntries := func(dest *[]SurfaceConfig, key string, val any) bool { + entries, ok := parseSurfaceEntries(val, key, logger) + if !ok || len(entries) == 0 { + return false + } + *dest = append(*dest, entries...) + return true + } + any := appendEntries(&out.CompletionConfigs, "models.completion", table["completion"]) + any = appendEntries(&out.CodeActionConfigs, "models.code_action", table["code_action"]) || any + any = appendEntries(&out.ChatConfigs, "models.chat", table["chat"]) || any + any = appendEntries(&out.CLIConfigs, "models.cli", table["cli"]) || any + if !any { + return nil } - if model, provider, temp, ok := decodeModelEntry(table["cli"], "models.cli", logger); ok { - if model != "" { - out.CLIModel = model - } - if provider != "" { - out.CLIProvider = provider + return &out +} + +func parseSurfaceEntries(raw any, path string, logger *log.Logger) ([]SurfaceConfig, bool) { + switch v := raw.(type) { + case nil: + return nil, false + case []any: + var out []SurfaceConfig + for i, entry := range v { + cfg, ok := decodeModelEntry(entry, fmt.Sprintf("%s[%d]", path, i), logger) + if !ok || cfg == nil { + continue + } + out = append(out, *cfg) } - if temp != nil { - out.CLITemperature = temp + return out, len(out) > 0 + default: + if cfg, ok := decodeModelEntry(v, path, logger); ok && cfg != nil { + return []SurfaceConfig{*cfg}, true } - any = true + return nil, false } - if !any { +} + +func cloneSurfaceConfigs(src []SurfaceConfig) []SurfaceConfig { + if len(src) == 0 { return nil } - return &out + out := make([]SurfaceConfig, len(src)) + copy(out, src) + return out } -func decodeModelEntry(raw any, path string, logger *log.Logger) (string, string, *float64, bool) { +func decodeModelEntry(raw any, path string, logger *log.Logger) (*SurfaceConfig, bool) { if raw == nil { - return "", "", nil, false + return nil, false } switch v := raw.(type) { case string: model := strings.TrimSpace(v) if model == "" { - return "", "", nil, false + return nil, false } - return model, "", nil, true + return &SurfaceConfig{Model: model}, true case map[string]any: model := "" provider := "" @@ -737,7 +730,7 @@ func decodeModelEntry(raw any, path string, logger *log.Logger) (string, string, if logger != nil { logger.Printf("config: %s.model must be a string", path) } - return "", "", nil, false + return nil, false } model = strings.TrimSpace(s) } @@ -747,7 +740,7 @@ func decodeModelEntry(raw any, path string, logger *log.Logger) (string, string, if logger != nil { logger.Printf("config: %s.provider must be a string", path) } - return "", "", nil, false + return nil, false } provider = strings.TrimSpace(ps) } @@ -755,19 +748,19 @@ func decodeModelEntry(raw any, path string, logger *log.Logger) (string, string, if tRaw, ok := v["temperature"]; ok { parsed, ok := parseTemperatureValue(tRaw, path, logger) if !ok { - return "", "", nil, false + return nil, false } tempPtr = parsed } if model == "" && tempPtr == nil && provider == "" { - return "", "", nil, false + return nil, false } - return model, provider, tempPtr, true + return &SurfaceConfig{Provider: provider, Model: model, Temperature: tempPtr}, true default: if logger != nil { logger.Printf("config: %s must be a string or table, got %T", path, raw) } - return "", "", nil, false + return nil, false } } @@ -861,41 +854,17 @@ func (a *App) mergeBasics(other *App) { // mergeSurfaceModels copies per-surface model and temperature overrides. func (a *App) mergeSurfaceModels(other *App) { - if s := strings.TrimSpace(other.CompletionModel); s != "" { - a.CompletionModel = s - } - if other.CompletionTemperature != nil { - a.CompletionTemperature = other.CompletionTemperature - } - if s := strings.TrimSpace(other.CompletionProvider); s != "" { - a.CompletionProvider = s + if len(other.CompletionConfigs) > 0 { + a.CompletionConfigs = cloneSurfaceConfigs(other.CompletionConfigs) } - if s := strings.TrimSpace(other.CodeActionModel); s != "" { - a.CodeActionModel = s + if len(other.CodeActionConfigs) > 0 { + a.CodeActionConfigs = cloneSurfaceConfigs(other.CodeActionConfigs) } - if other.CodeActionTemperature != nil { - a.CodeActionTemperature = other.CodeActionTemperature + if len(other.ChatConfigs) > 0 { + a.ChatConfigs = cloneSurfaceConfigs(other.ChatConfigs) } - if s := strings.TrimSpace(other.CodeActionProvider); s != "" { - a.CodeActionProvider = s - } - if s := strings.TrimSpace(other.ChatModel); s != "" { - a.ChatModel = s - } - if other.ChatTemperature != nil { - a.ChatTemperature = other.ChatTemperature - } - if s := strings.TrimSpace(other.ChatProvider); s != "" { - a.ChatProvider = s - } - if s := strings.TrimSpace(other.CLIModel); s != "" { - a.CLIModel = s - } - if other.CLITemperature != nil { - a.CLITemperature = other.CLITemperature - } - if s := strings.TrimSpace(other.CLIProvider); s != "" { - a.CLIProvider = s + if len(other.CLIConfigs) > 0 { + a.CLIConfigs = cloneSurfaceConfigs(other.CLIConfigs) } } @@ -1263,52 +1232,33 @@ func loadFromEnv(logger *log.Logger) *App { } // Per-surface overrides - if s := getenv("HEXAI_MODEL_COMPLETION"); s != "" { - out.CompletionModel = s - any = true - } - if f, ok := parseFloatPtr("HEXAI_TEMPERATURE_COMPLETION"); ok { - out.CompletionTemperature = f - any = true - } - if s := getenv("HEXAI_PROVIDER_COMPLETION"); s != "" { - out.CompletionProvider = s - any = true - } - if s := getenv("HEXAI_MODEL_CODE_ACTION"); s != "" { - out.CodeActionModel = s - any = true - } - if f, ok := parseFloatPtr("HEXAI_TEMPERATURE_CODE_ACTION"); ok { - out.CodeActionTemperature = f - any = true - } - if s := getenv("HEXAI_PROVIDER_CODE_ACTION"); s != "" { - out.CodeActionProvider = s - any = true - } - if s := getenv("HEXAI_MODEL_CHAT"); s != "" { - out.ChatModel = s - any = true - } - if f, ok := parseFloatPtr("HEXAI_TEMPERATURE_CHAT"); ok { - out.ChatTemperature = f - any = true + buildEntry := func(modelKey, tempKey, providerKey string) ([]SurfaceConfig, bool) { + model := getenv(modelKey) + tempPtr, tempSet := parseFloatPtr(tempKey) + provider := getenv(providerKey) + if model == "" && provider == "" && !tempSet { + return nil, false + } + entry := SurfaceConfig{Provider: provider, Model: model} + if tempSet { + entry.Temperature = tempPtr + } + return []SurfaceConfig{entry}, true } - if s := getenv("HEXAI_PROVIDER_CHAT"); s != "" { - out.ChatProvider = s + if entries, ok := buildEntry("HEXAI_MODEL_COMPLETION", "HEXAI_TEMPERATURE_COMPLETION", "HEXAI_PROVIDER_COMPLETION"); ok { + out.CompletionConfigs = entries any = true } - if s := getenv("HEXAI_MODEL_CLI"); s != "" { - out.CLIModel = s + if entries, ok := buildEntry("HEXAI_MODEL_CODE_ACTION", "HEXAI_TEMPERATURE_CODE_ACTION", "HEXAI_PROVIDER_CODE_ACTION"); ok { + out.CodeActionConfigs = entries any = true } - if f, ok := parseFloatPtr("HEXAI_TEMPERATURE_CLI"); ok { - out.CLITemperature = f + if entries, ok := buildEntry("HEXAI_MODEL_CHAT", "HEXAI_TEMPERATURE_CHAT", "HEXAI_PROVIDER_CHAT"); ok { + out.ChatConfigs = entries any = true } - if s := getenv("HEXAI_PROVIDER_CLI"); s != "" { - out.CLIProvider = s + if entries, ok := buildEntry("HEXAI_MODEL_CLI", "HEXAI_TEMPERATURE_CLI", "HEXAI_PROVIDER_CLI"); ok { + out.CLIConfigs = entries any = true } diff --git a/internal/appconfig/config_env_model_test.go b/internal/appconfig/config_env_model_test.go index f34416d..7038819 100644 --- a/internal/appconfig/config_env_model_test.go +++ b/internal/appconfig/config_env_model_test.go @@ -44,22 +44,30 @@ func TestEnv_SurfaceModelOverrides(t *testing.T) { t.Setenv("HEXAI_TEMPERATURE_CLI", "0.22") t.Setenv("HEXAI_PROVIDER_CLI", "ollama") cfg := Load(log.New(os.Stderr, "test ", 0)) - if cfg.CompletionModel != "gpt-c" { - t.Fatalf("expected completion model override, got %q", cfg.CompletionModel) + if len(cfg.CompletionConfigs) != 1 { + t.Fatalf("expected single completion entry, got %+v", cfg.CompletionConfigs) } - if cfg.CompletionTemperature == nil || *cfg.CompletionTemperature != 0.44 { - t.Fatalf("expected completion temperature override, got %v", cfg.CompletionTemperature) + comp := cfg.CompletionConfigs[0] + if comp.Model != "gpt-c" { + t.Fatalf("expected completion model override, got %+v", comp) } - if cfg.CompletionProvider != "copilot" { - t.Fatalf("expected completion provider override, got %q", cfg.CompletionProvider) + if comp.Temperature == nil || *comp.Temperature != 0.44 { + t.Fatalf("expected completion temperature override, got %+v", comp) } - if cfg.CLIModel != "gpt-cli" { - t.Fatalf("expected cli model override, got %q", cfg.CLIModel) + if comp.Provider != "copilot" { + t.Fatalf("expected completion provider override, got %+v", comp) } - if cfg.CLITemperature == nil || *cfg.CLITemperature != 0.22 { - t.Fatalf("expected cli temperature override, got %v", cfg.CLITemperature) + if len(cfg.CLIConfigs) != 1 { + t.Fatalf("expected single CLI entry, got %+v", cfg.CLIConfigs) } - if cfg.CLIProvider != "ollama" { - t.Fatalf("expected cli provider override, got %q", cfg.CLIProvider) + cli := cfg.CLIConfigs[0] + if cli.Model != "gpt-cli" { + t.Fatalf("expected cli model override, got %+v", cli) + } + if cli.Temperature == nil || *cli.Temperature != 0.22 { + t.Fatalf("expected cli temperature override, got %+v", cli) + } + if cli.Provider != "ollama" { + t.Fatalf("expected cli provider override, got %+v", cli) } } diff --git a/internal/appconfig/config_test.go b/internal/appconfig/config_test.go index ea68305..e7f6059 100644 --- a/internal/appconfig/config_test.go +++ b/internal/appconfig/config_test.go @@ -88,20 +88,20 @@ completion_throttle_ms = 300 [triggers] trigger_characters = [".", ":"] -[models.completion] +[[models.completion]] model = "gpt-file-complete" provider = "openai" -[models.code_action] +[[models.code_action]] model = "gpt-file-action" temperature = 0.45 provider = "copilot" -[models.chat] +[[models.chat]] model = "gpt-file-chat" provider = "openai" -[models.cli] +[[models.cli]] model = "gpt-file-cli" temperature = 0.15 provider = "ollama" @@ -192,29 +192,41 @@ temperature = 0.0 if cfg.CopilotBaseURL != "http://copilot-override" || cfg.CopilotModel != "ghost-override" || cfg.CopilotTemperature == nil || *cfg.CopilotTemperature != 0.3 { t.Fatalf("copilot overrides not applied: %+v", cfg) } - if cfg.CompletionModel != "env-completion" || cfg.CompletionTemperature == nil || *cfg.CompletionTemperature != 0.33 { - t.Fatalf("completion overrides not applied: model=%q temp=%v", cfg.CompletionModel, cfg.CompletionTemperature) + if len(cfg.CompletionConfigs) != 1 || cfg.CompletionConfigs[0].Model != "env-completion" { + t.Fatalf("completion overrides not applied: %+v", cfg.CompletionConfigs) } - if cfg.CompletionProvider != "copilot" { - t.Fatalf("completion provider override not applied: %q", cfg.CompletionProvider) + if cfg.CompletionConfigs[0].Temperature == nil || *cfg.CompletionConfigs[0].Temperature != 0.33 { + t.Fatalf("completion temperature override missing: %+v", cfg.CompletionConfigs[0]) } - if cfg.CodeActionModel != "env-action" || cfg.CodeActionTemperature == nil || *cfg.CodeActionTemperature != 0.55 { - t.Fatalf("code action overrides not applied: model=%q temp=%v", cfg.CodeActionModel, cfg.CodeActionTemperature) + if cfg.CompletionConfigs[0].Provider != "copilot" { + t.Fatalf("completion provider override not applied: %+v", cfg.CompletionConfigs[0]) } - if cfg.CodeActionProvider != "openai" { - t.Fatalf("code action provider override not applied: %q", cfg.CodeActionProvider) + if len(cfg.CodeActionConfigs) != 1 || cfg.CodeActionConfigs[0].Model != "env-action" { + t.Fatalf("code action overrides not applied: %+v", cfg.CodeActionConfigs) } - if cfg.ChatModel != "env-chat" || cfg.ChatTemperature == nil || *cfg.ChatTemperature != 0.66 { - t.Fatalf("chat overrides not applied: model=%q temp=%v", cfg.ChatModel, cfg.ChatTemperature) + if cfg.CodeActionConfigs[0].Temperature == nil || *cfg.CodeActionConfigs[0].Temperature != 0.55 { + t.Fatalf("code action temp override missing: %+v", cfg.CodeActionConfigs[0]) } - if cfg.ChatProvider != "copilot" { - t.Fatalf("chat provider override not applied: %q", cfg.ChatProvider) + if cfg.CodeActionConfigs[0].Provider != "openai" { + t.Fatalf("code action provider override not applied: %+v", cfg.CodeActionConfigs[0]) } - if cfg.CLIModel != "env-cli" || cfg.CLITemperature == nil || *cfg.CLITemperature != 0.77 { - t.Fatalf("cli overrides not applied: model=%q temp=%v", cfg.CLIModel, cfg.CLITemperature) + if len(cfg.ChatConfigs) != 1 || cfg.ChatConfigs[0].Model != "env-chat" { + t.Fatalf("chat overrides not applied: %+v", cfg.ChatConfigs) } - if cfg.CLIProvider != "ollama" { - t.Fatalf("cli provider override not applied: %q", cfg.CLIProvider) + if cfg.ChatConfigs[0].Temperature == nil || *cfg.ChatConfigs[0].Temperature != 0.66 { + t.Fatalf("chat temp override missing: %+v", cfg.ChatConfigs[0]) + } + if cfg.ChatConfigs[0].Provider != "copilot" { + t.Fatalf("chat provider override not applied: %+v", cfg.ChatConfigs[0]) + } + if len(cfg.CLIConfigs) != 1 || cfg.CLIConfigs[0].Model != "env-cli" { + t.Fatalf("cli overrides not applied: %+v", cfg.CLIConfigs) + } + if cfg.CLIConfigs[0].Temperature == nil || *cfg.CLIConfigs[0].Temperature != 0.77 { + t.Fatalf("cli temp override missing: %+v", cfg.CLIConfigs[0]) + } + if cfg.CLIConfigs[0].Provider != "ollama" { + t.Fatalf("cli provider override not applied: %+v", cfg.CLIConfigs[0]) } // Ensure file values would have applied absent env @@ -234,29 +246,41 @@ temperature = 0.0 if cfg2.OpenAIBaseURL != "https://api.example" || cfg2.OpenAIModel != "gpt-x" || cfg2.OpenAITemperature == nil || *cfg2.OpenAITemperature != 0.0 { t.Fatalf("file merge (openai) not applied: %+v", cfg2) } - if cfg2.CompletionModel != "gpt-file-complete" || cfg2.CompletionTemperature != nil { - t.Fatalf("file merge (completion) not applied: %+v", cfg2) + if len(cfg2.CompletionConfigs) != 1 || cfg2.CompletionConfigs[0].Model != "gpt-file-complete" { + t.Fatalf("file merge (completion) not applied: %+v", cfg2.CompletionConfigs) + } + if cfg2.CompletionConfigs[0].Temperature != nil { + t.Fatalf("expected nil completion temperature, got %+v", cfg2.CompletionConfigs[0]) + } + if cfg2.CompletionConfigs[0].Provider != "openai" { + t.Fatalf("file merge (completion provider) not applied: %+v", cfg2.CompletionConfigs[0]) + } + if len(cfg2.CodeActionConfigs) != 1 || cfg2.CodeActionConfigs[0].Model != "gpt-file-action" { + t.Fatalf("file merge (code action) not applied: %+v", cfg2.CodeActionConfigs) + } + if cfg2.CodeActionConfigs[0].Temperature == nil || *cfg2.CodeActionConfigs[0].Temperature != 0.45 { + t.Fatalf("expected code action temp 0.45, got %+v", cfg2.CodeActionConfigs[0]) } - if cfg2.CompletionProvider != "openai" { - t.Fatalf("file merge (completion provider) not applied: %q", cfg2.CompletionProvider) + if cfg2.CodeActionConfigs[0].Provider != "copilot" { + t.Fatalf("file merge (code action provider) not applied: %+v", cfg2.CodeActionConfigs[0]) } - if cfg2.CodeActionModel != "gpt-file-action" || cfg2.CodeActionTemperature == nil || *cfg2.CodeActionTemperature != 0.45 { - t.Fatalf("file merge (code action) not applied: %+v", cfg2) + if len(cfg2.ChatConfigs) != 1 || cfg2.ChatConfigs[0].Model != "gpt-file-chat" { + t.Fatalf("file merge (chat) not applied: %+v", cfg2.ChatConfigs) } - if cfg2.CodeActionProvider != "copilot" { - t.Fatalf("file merge (code action provider) not applied: %q", cfg2.CodeActionProvider) + if cfg2.ChatConfigs[0].Temperature != nil { + t.Fatalf("expected nil chat temp, got %+v", cfg2.ChatConfigs[0]) } - if cfg2.ChatModel != "gpt-file-chat" || cfg2.ChatTemperature != nil { - t.Fatalf("file merge (chat) not applied: %+v", cfg2) + if cfg2.ChatConfigs[0].Provider != "openai" { + t.Fatalf("file merge (chat provider) not applied: %+v", cfg2.ChatConfigs[0]) } - if cfg2.ChatProvider != "openai" { - t.Fatalf("file merge (chat provider) not applied: %q", cfg2.ChatProvider) + if len(cfg2.CLIConfigs) != 1 || cfg2.CLIConfigs[0].Model != "gpt-file-cli" { + t.Fatalf("file merge (cli) not applied: %+v", cfg2.CLIConfigs) } - if cfg2.CLIModel != "gpt-file-cli" || cfg2.CLITemperature == nil || *cfg2.CLITemperature != 0.15 { - t.Fatalf("file merge (cli) not applied: %+v", cfg2) + if cfg2.CLIConfigs[0].Temperature == nil || *cfg2.CLIConfigs[0].Temperature != 0.15 { + t.Fatalf("expected CLI temp 0.15, got %+v", cfg2.CLIConfigs[0]) } - if cfg2.CLIProvider != "ollama" { - t.Fatalf("file merge (cli provider) not applied: %q", cfg2.CLIProvider) + if cfg2.CLIConfigs[0].Provider != "ollama" { + t.Fatalf("file merge (cli provider) not applied: %+v", cfg2.CLIConfigs[0]) } } |
