diff options
Diffstat (limited to 'internal/appconfig/config.go')
| -rw-r--r-- | internal/appconfig/config.go | 232 |
1 files changed, 91 insertions, 141 deletions
diff --git a/internal/appconfig/config.go b/internal/appconfig/config.go index 47abaaf..63d0437 100644 --- a/internal/appconfig/config.go +++ b/internal/appconfig/config.go @@ -13,6 +13,13 @@ import ( "github.com/pelletier/go-toml/v2" ) +// SurfaceConfig describes a provider/model pairing (with optional temperature). +type SurfaceConfig struct { + Provider string + Model string + Temperature *float64 +} + // App holds user-configurable settings read from ~/.config/hexai/config.toml. type App struct { MaxTokens int `json:"max_tokens" toml:"max_tokens"` @@ -58,19 +65,11 @@ type App struct { // Default temperature for Copilot requests (nil means use provider default) CopilotTemperature *float64 `json:"copilot_temperature" toml:"copilot_temperature"` - // Per-surface model overrides (fall back to provider defaults when unset) - CompletionModel string `json:"completion_model" toml:"completion_model"` - CompletionTemperature *float64 `json:"completion_temperature" toml:"completion_temperature"` - CompletionProvider string `json:"completion_provider" toml:"completion_provider"` - CodeActionModel string `json:"code_action_model" toml:"code_action_model"` - CodeActionTemperature *float64 `json:"code_action_temperature" toml:"code_action_temperature"` - CodeActionProvider string `json:"code_action_provider" toml:"code_action_provider"` - ChatModel string `json:"chat_model" toml:"chat_model"` - ChatTemperature *float64 `json:"chat_temperature" toml:"chat_temperature"` - ChatProvider string `json:"chat_provider" toml:"chat_provider"` - CLIModel string `json:"cli_model" toml:"cli_model"` - CLITemperature *float64 `json:"cli_temperature" toml:"cli_temperature"` - CLIProvider string `json:"cli_provider" toml:"cli_provider"` + // Per-surface provider/model configurations (ordered; first entry is primary) + CompletionConfigs []SurfaceConfig `json:"-" toml:"-"` + CodeActionConfigs []SurfaceConfig `json:"-" toml:"-"` + ChatConfigs []SurfaceConfig `json:"-" toml:"-"` + CLIConfigs []SurfaceConfig `json:"-" toml:"-"` // Prompt templates (configured only via file; no env overrides) // Completion/chat/code action/CLI prompt strings. See config.toml.example for placeholders. @@ -662,72 +661,66 @@ func parseSurfaceModels(raw map[string]any, logger *log.Logger) *App { return nil } var out App - var any bool - if model, provider, temp, ok := decodeModelEntry(table["completion"], "models.completion", logger); ok { - if model != "" { - out.CompletionModel = model - } - if provider != "" { - out.CompletionProvider = provider - } - if temp != nil { - out.CompletionTemperature = temp - } - any = true - } - if model, provider, temp, ok := decodeModelEntry(table["code_action"], "models.code_action", logger); ok { - if model != "" { - out.CodeActionModel = model - } - if provider != "" { - out.CodeActionProvider = provider - } - if temp != nil { - out.CodeActionTemperature = temp - } - any = true - } - if model, provider, temp, ok := decodeModelEntry(table["chat"], "models.chat", logger); ok { - if model != "" { - out.ChatModel = model - } - if provider != "" { - out.ChatProvider = provider - } - if temp != nil { - out.ChatTemperature = temp - } - any = true + appendEntries := func(dest *[]SurfaceConfig, key string, val any) bool { + entries, ok := parseSurfaceEntries(val, key, logger) + if !ok || len(entries) == 0 { + return false + } + *dest = append(*dest, entries...) + return true + } + any := appendEntries(&out.CompletionConfigs, "models.completion", table["completion"]) + any = appendEntries(&out.CodeActionConfigs, "models.code_action", table["code_action"]) || any + any = appendEntries(&out.ChatConfigs, "models.chat", table["chat"]) || any + any = appendEntries(&out.CLIConfigs, "models.cli", table["cli"]) || any + if !any { + return nil } - if model, provider, temp, ok := decodeModelEntry(table["cli"], "models.cli", logger); ok { - if model != "" { - out.CLIModel = model - } - if provider != "" { - out.CLIProvider = provider + return &out +} + +func parseSurfaceEntries(raw any, path string, logger *log.Logger) ([]SurfaceConfig, bool) { + switch v := raw.(type) { + case nil: + return nil, false + case []any: + var out []SurfaceConfig + for i, entry := range v { + cfg, ok := decodeModelEntry(entry, fmt.Sprintf("%s[%d]", path, i), logger) + if !ok || cfg == nil { + continue + } + out = append(out, *cfg) } - if temp != nil { - out.CLITemperature = temp + return out, len(out) > 0 + default: + if cfg, ok := decodeModelEntry(v, path, logger); ok && cfg != nil { + return []SurfaceConfig{*cfg}, true } - any = true + return nil, false } - if !any { +} + +func cloneSurfaceConfigs(src []SurfaceConfig) []SurfaceConfig { + if len(src) == 0 { return nil } - return &out + out := make([]SurfaceConfig, len(src)) + copy(out, src) + return out } -func decodeModelEntry(raw any, path string, logger *log.Logger) (string, string, *float64, bool) { +func decodeModelEntry(raw any, path string, logger *log.Logger) (*SurfaceConfig, bool) { if raw == nil { - return "", "", nil, false + return nil, false } switch v := raw.(type) { case string: model := strings.TrimSpace(v) if model == "" { - return "", "", nil, false + return nil, false } - return model, "", nil, true + return &SurfaceConfig{Model: model}, true case map[string]any: model := "" provider := "" @@ -737,7 +730,7 @@ func decodeModelEntry(raw any, path string, logger *log.Logger) (string, string, if logger != nil { logger.Printf("config: %s.model must be a string", path) } - return "", "", nil, false + return nil, false } model = strings.TrimSpace(s) } @@ -747,7 +740,7 @@ func decodeModelEntry(raw any, path string, logger *log.Logger) (string, string, if logger != nil { logger.Printf("config: %s.provider must be a string", path) } - return "", "", nil, false + return nil, false } provider = strings.TrimSpace(ps) } @@ -755,19 +748,19 @@ func decodeModelEntry(raw any, path string, logger *log.Logger) (string, string, if tRaw, ok := v["temperature"]; ok { parsed, ok := parseTemperatureValue(tRaw, path, logger) if !ok { - return "", "", nil, false + return nil, false } tempPtr = parsed } if model == "" && tempPtr == nil && provider == "" { - return "", "", nil, false + return nil, false } - return model, provider, tempPtr, true + return &SurfaceConfig{Provider: provider, Model: model, Temperature: tempPtr}, true default: if logger != nil { logger.Printf("config: %s must be a string or table, got %T", path, raw) } - return "", "", nil, false + return nil, false } } @@ -861,41 +854,17 @@ func (a *App) mergeBasics(other *App) { // mergeSurfaceModels copies per-surface model and temperature overrides. func (a *App) mergeSurfaceModels(other *App) { - if s := strings.TrimSpace(other.CompletionModel); s != "" { - a.CompletionModel = s - } - if other.CompletionTemperature != nil { - a.CompletionTemperature = other.CompletionTemperature - } - if s := strings.TrimSpace(other.CompletionProvider); s != "" { - a.CompletionProvider = s + if len(other.CompletionConfigs) > 0 { + a.CompletionConfigs = cloneSurfaceConfigs(other.CompletionConfigs) } - if s := strings.TrimSpace(other.CodeActionModel); s != "" { - a.CodeActionModel = s + if len(other.CodeActionConfigs) > 0 { + a.CodeActionConfigs = cloneSurfaceConfigs(other.CodeActionConfigs) } - if other.CodeActionTemperature != nil { - a.CodeActionTemperature = other.CodeActionTemperature + if len(other.ChatConfigs) > 0 { + a.ChatConfigs = cloneSurfaceConfigs(other.ChatConfigs) } - if s := strings.TrimSpace(other.CodeActionProvider); s != "" { - a.CodeActionProvider = s - } - if s := strings.TrimSpace(other.ChatModel); s != "" { - a.ChatModel = s - } - if other.ChatTemperature != nil { - a.ChatTemperature = other.ChatTemperature - } - if s := strings.TrimSpace(other.ChatProvider); s != "" { - a.ChatProvider = s - } - if s := strings.TrimSpace(other.CLIModel); s != "" { - a.CLIModel = s - } - if other.CLITemperature != nil { - a.CLITemperature = other.CLITemperature - } - if s := strings.TrimSpace(other.CLIProvider); s != "" { - a.CLIProvider = s + if len(other.CLIConfigs) > 0 { + a.CLIConfigs = cloneSurfaceConfigs(other.CLIConfigs) } } @@ -1263,52 +1232,33 @@ func loadFromEnv(logger *log.Logger) *App { } // Per-surface overrides - if s := getenv("HEXAI_MODEL_COMPLETION"); s != "" { - out.CompletionModel = s - any = true - } - if f, ok := parseFloatPtr("HEXAI_TEMPERATURE_COMPLETION"); ok { - out.CompletionTemperature = f - any = true - } - if s := getenv("HEXAI_PROVIDER_COMPLETION"); s != "" { - out.CompletionProvider = s - any = true - } - if s := getenv("HEXAI_MODEL_CODE_ACTION"); s != "" { - out.CodeActionModel = s - any = true - } - if f, ok := parseFloatPtr("HEXAI_TEMPERATURE_CODE_ACTION"); ok { - out.CodeActionTemperature = f - any = true - } - if s := getenv("HEXAI_PROVIDER_CODE_ACTION"); s != "" { - out.CodeActionProvider = s - any = true - } - if s := getenv("HEXAI_MODEL_CHAT"); s != "" { - out.ChatModel = s - any = true - } - if f, ok := parseFloatPtr("HEXAI_TEMPERATURE_CHAT"); ok { - out.ChatTemperature = f - any = true + buildEntry := func(modelKey, tempKey, providerKey string) ([]SurfaceConfig, bool) { + model := getenv(modelKey) + tempPtr, tempSet := parseFloatPtr(tempKey) + provider := getenv(providerKey) + if model == "" && provider == "" && !tempSet { + return nil, false + } + entry := SurfaceConfig{Provider: provider, Model: model} + if tempSet { + entry.Temperature = tempPtr + } + return []SurfaceConfig{entry}, true } - if s := getenv("HEXAI_PROVIDER_CHAT"); s != "" { - out.ChatProvider = s + if entries, ok := buildEntry("HEXAI_MODEL_COMPLETION", "HEXAI_TEMPERATURE_COMPLETION", "HEXAI_PROVIDER_COMPLETION"); ok { + out.CompletionConfigs = entries any = true } - if s := getenv("HEXAI_MODEL_CLI"); s != "" { - out.CLIModel = s + if entries, ok := buildEntry("HEXAI_MODEL_CODE_ACTION", "HEXAI_TEMPERATURE_CODE_ACTION", "HEXAI_PROVIDER_CODE_ACTION"); ok { + out.CodeActionConfigs = entries any = true } - if f, ok := parseFloatPtr("HEXAI_TEMPERATURE_CLI"); ok { - out.CLITemperature = f + if entries, ok := buildEntry("HEXAI_MODEL_CHAT", "HEXAI_TEMPERATURE_CHAT", "HEXAI_PROVIDER_CHAT"); ok { + out.ChatConfigs = entries any = true } - if s := getenv("HEXAI_PROVIDER_CLI"); s != "" { - out.CLIProvider = s + if entries, ok := buildEntry("HEXAI_MODEL_CLI", "HEXAI_TEMPERATURE_CLI", "HEXAI_PROVIDER_CLI"); ok { + out.CLIConfigs = entries any = true } |
