diff options
| author | Paul Buetow <paul@buetow.org> | 2025-09-26 19:34:19 +0300 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2025-09-26 19:34:19 +0300 |
| commit | 0583b360ceb606b8e58f12a17f588bd27feeb117 (patch) | |
| tree | ae8ac0d7968a409a76d18d84e080d02da52ce775 /internal/appconfig/config.go | |
| parent | 869c018a7a26285263cf7692f25f6aa44e2635c9 (diff) | |
Add per-surface provider overrides and wiring
Diffstat (limited to 'internal/appconfig/config.go')
| -rw-r--r-- | internal/appconfig/config.go | 264 |
1 files changed, 263 insertions, 1 deletions
diff --git a/internal/appconfig/config.go b/internal/appconfig/config.go index adf9b75..47abaaf 100644 --- a/internal/appconfig/config.go +++ b/internal/appconfig/config.go @@ -58,6 +58,20 @@ type App struct { // Default temperature for Copilot requests (nil means use provider default) CopilotTemperature *float64 `json:"copilot_temperature" toml:"copilot_temperature"` + // Per-surface model overrides (fall back to provider defaults when unset) + CompletionModel string `json:"completion_model" toml:"completion_model"` + CompletionTemperature *float64 `json:"completion_temperature" toml:"completion_temperature"` + CompletionProvider string `json:"completion_provider" toml:"completion_provider"` + CodeActionModel string `json:"code_action_model" toml:"code_action_model"` + CodeActionTemperature *float64 `json:"code_action_temperature" toml:"code_action_temperature"` + CodeActionProvider string `json:"code_action_provider" toml:"code_action_provider"` + ChatModel string `json:"chat_model" toml:"chat_model"` + ChatTemperature *float64 `json:"chat_temperature" toml:"chat_temperature"` + ChatProvider string `json:"chat_provider" toml:"chat_provider"` + CLIModel string `json:"cli_model" toml:"cli_model"` + CLITemperature *float64 `json:"cli_temperature" toml:"cli_temperature"` + CLIProvider string `json:"cli_provider" toml:"cli_provider"` + // Prompt templates (configured only via file; no env overrides) // Completion/chat/code action/CLI prompt strings. See config.toml.example for placeholders. // Completion @@ -589,7 +603,7 @@ func loadFromFile(path string, logger *log.Logger) (*App, error) { "copilot_model": {}, "copilot_base_url": {}, "copilot_temperature": {}, } for k := range raw { - if _, isTable := map[string]struct{}{"general": {}, "logging": {}, "completion": {}, "triggers": {}, "inline": {}, "chat": {}, "provider": {}, "openai": {}, "copilot": {}, "ollama": {}, "prompts": {}}[k]; isTable { + if _, isTable := map[string]struct{}{"general": {}, "logging": {}, "completion": {}, "triggers": {}, "inline": {}, "chat": {}, "provider": {}, "models": {}, "openai": {}, "copilot": {}, "ollama": {}, "prompts": {}}[k]; isTable { continue } if _, isLegacy := legacy[k]; isLegacy { @@ -629,12 +643,170 @@ func loadFromFile(path string, logger *log.Logger) (*App, error) { } } } + if m := parseSurfaceModels(raw, logger); m != nil { + tab.mergeSurfaceModels(m) + } return &tab, nil } +func parseSurfaceModels(raw map[string]any, logger *log.Logger) *App { + modelsRaw, ok := raw["models"] + if !ok { + return nil + } + table, ok := modelsRaw.(map[string]any) + if !ok { + if logger != nil { + logger.Printf("config: ignoring models section (expected table, got %T)", modelsRaw) + } + return nil + } + var out App + var any bool + if model, provider, temp, ok := decodeModelEntry(table["completion"], "models.completion", logger); ok { + if model != "" { + out.CompletionModel = model + } + if provider != "" { + out.CompletionProvider = provider + } + if temp != nil { + out.CompletionTemperature = temp + } + any = true + } + if model, provider, temp, ok := decodeModelEntry(table["code_action"], "models.code_action", logger); ok { + if model != "" { + out.CodeActionModel = model + } + if provider != "" { + out.CodeActionProvider = provider + } + if temp != nil { + out.CodeActionTemperature = temp + } + any = true + } + if model, provider, temp, ok := decodeModelEntry(table["chat"], "models.chat", logger); ok { + if model != "" { + out.ChatModel = model + } + if provider != "" { + out.ChatProvider = provider + } + if temp != nil { + out.ChatTemperature = temp + } + any = true + } + if model, provider, temp, ok := decodeModelEntry(table["cli"], "models.cli", logger); ok { + if model != "" { + out.CLIModel = model + } + if provider != "" { + out.CLIProvider = provider + } + if temp != nil { + out.CLITemperature = temp + } + any = true + } + if !any { + return nil + } + return &out +} + +func decodeModelEntry(raw any, path string, logger *log.Logger) (string, string, *float64, bool) { + if raw == nil { + return "", "", nil, false + } + switch v := raw.(type) { + case string: + model := strings.TrimSpace(v) + if model == "" { + return "", "", nil, false + } + return model, "", nil, true + case map[string]any: + model := "" + provider := "" + if m, ok := v["model"]; ok { + s, ok := m.(string) + if !ok { + if logger != nil { + logger.Printf("config: %s.model must be a string", path) + } + return "", "", nil, false + } + model = strings.TrimSpace(s) + } + if pRaw, ok := v["provider"]; ok { + ps, ok := pRaw.(string) + if !ok { + if logger != nil { + logger.Printf("config: %s.provider must be a string", path) + } + return "", "", nil, false + } + provider = strings.TrimSpace(ps) + } + var tempPtr *float64 + if tRaw, ok := v["temperature"]; ok { + parsed, ok := parseTemperatureValue(tRaw, path, logger) + if !ok { + return "", "", nil, false + } + tempPtr = parsed + } + if model == "" && tempPtr == nil && provider == "" { + return "", "", nil, false + } + return model, provider, tempPtr, true + default: + if logger != nil { + logger.Printf("config: %s must be a string or table, got %T", path, raw) + } + return "", "", nil, false + } +} + +func parseTemperatureValue(raw any, path string, logger *log.Logger) (*float64, bool) { + switch v := raw.(type) { + case float64: + return floatPtr(v), true + case int64: + return floatPtr(float64(v)), true + case string: + s := strings.TrimSpace(v) + if s == "" { + return nil, true + } + f, err := strconv.ParseFloat(s, 64) + if err != nil { + if logger != nil { + logger.Printf("config: %s.temperature invalid: %v", path, err) + } + return nil, false + } + return floatPtr(f), true + default: + if logger != nil { + logger.Printf("config: %s.temperature must be numeric or string, got %T", path, raw) + } + return nil, false + } +} + +func floatPtr(v float64) *float64 { + f := v + return &f +} + func (a *App) mergeWith(other *App) { a.mergeBasics(other) a.mergeProviderFields(other) + a.mergeSurfaceModels(other) a.mergePrompts(other) } @@ -687,6 +859,46 @@ func (a *App) mergeBasics(other *App) { } } +// mergeSurfaceModels copies per-surface model and temperature overrides. +func (a *App) mergeSurfaceModels(other *App) { + if s := strings.TrimSpace(other.CompletionModel); s != "" { + a.CompletionModel = s + } + if other.CompletionTemperature != nil { + a.CompletionTemperature = other.CompletionTemperature + } + if s := strings.TrimSpace(other.CompletionProvider); s != "" { + a.CompletionProvider = s + } + if s := strings.TrimSpace(other.CodeActionModel); s != "" { + a.CodeActionModel = s + } + if other.CodeActionTemperature != nil { + a.CodeActionTemperature = other.CodeActionTemperature + } + if s := strings.TrimSpace(other.CodeActionProvider); s != "" { + a.CodeActionProvider = s + } + if s := strings.TrimSpace(other.ChatModel); s != "" { + a.ChatModel = s + } + if other.ChatTemperature != nil { + a.ChatTemperature = other.ChatTemperature + } + if s := strings.TrimSpace(other.ChatProvider); s != "" { + a.ChatProvider = s + } + if s := strings.TrimSpace(other.CLIModel); s != "" { + a.CLIModel = s + } + if other.CLITemperature != nil { + a.CLITemperature = other.CLITemperature + } + if s := strings.TrimSpace(other.CLIProvider); s != "" { + a.CLIProvider = s + } +} + // mergePrompts copies non-empty prompt templates from other. func (a *App) mergePrompts(other *App) { // Completion @@ -1050,6 +1262,56 @@ func loadFromEnv(logger *log.Logger) *App { any = true } + // Per-surface overrides + if s := getenv("HEXAI_MODEL_COMPLETION"); s != "" { + out.CompletionModel = s + any = true + } + if f, ok := parseFloatPtr("HEXAI_TEMPERATURE_COMPLETION"); ok { + out.CompletionTemperature = f + any = true + } + if s := getenv("HEXAI_PROVIDER_COMPLETION"); s != "" { + out.CompletionProvider = s + any = true + } + if s := getenv("HEXAI_MODEL_CODE_ACTION"); s != "" { + out.CodeActionModel = s + any = true + } + if f, ok := parseFloatPtr("HEXAI_TEMPERATURE_CODE_ACTION"); ok { + out.CodeActionTemperature = f + any = true + } + if s := getenv("HEXAI_PROVIDER_CODE_ACTION"); s != "" { + out.CodeActionProvider = s + any = true + } + if s := getenv("HEXAI_MODEL_CHAT"); s != "" { + out.ChatModel = s + any = true + } + if f, ok := parseFloatPtr("HEXAI_TEMPERATURE_CHAT"); ok { + out.ChatTemperature = f + any = true + } + if s := getenv("HEXAI_PROVIDER_CHAT"); s != "" { + out.ChatProvider = s + any = true + } + if s := getenv("HEXAI_MODEL_CLI"); s != "" { + out.CLIModel = s + any = true + } + if f, ok := parseFloatPtr("HEXAI_TEMPERATURE_CLI"); ok { + out.CLITemperature = f + any = true + } + if s := getenv("HEXAI_PROVIDER_CLI"); s != "" { + out.CLIProvider = s + any = true + } + if !any { return nil } |
