summaryrefslogtreecommitdiff
path: root/internal/appconfig/config.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/appconfig/config.go')
-rw-r--r--internal/appconfig/config.go264
1 files changed, 263 insertions, 1 deletions
diff --git a/internal/appconfig/config.go b/internal/appconfig/config.go
index adf9b75..47abaaf 100644
--- a/internal/appconfig/config.go
+++ b/internal/appconfig/config.go
@@ -58,6 +58,20 @@ type App struct {
// Default temperature for Copilot requests (nil means use provider default)
CopilotTemperature *float64 `json:"copilot_temperature" toml:"copilot_temperature"`
+ // Per-surface model overrides (fall back to provider defaults when unset)
+ CompletionModel string `json:"completion_model" toml:"completion_model"`
+ CompletionTemperature *float64 `json:"completion_temperature" toml:"completion_temperature"`
+ CompletionProvider string `json:"completion_provider" toml:"completion_provider"`
+ CodeActionModel string `json:"code_action_model" toml:"code_action_model"`
+ CodeActionTemperature *float64 `json:"code_action_temperature" toml:"code_action_temperature"`
+ CodeActionProvider string `json:"code_action_provider" toml:"code_action_provider"`
+ ChatModel string `json:"chat_model" toml:"chat_model"`
+ ChatTemperature *float64 `json:"chat_temperature" toml:"chat_temperature"`
+ ChatProvider string `json:"chat_provider" toml:"chat_provider"`
+ CLIModel string `json:"cli_model" toml:"cli_model"`
+ CLITemperature *float64 `json:"cli_temperature" toml:"cli_temperature"`
+ CLIProvider string `json:"cli_provider" toml:"cli_provider"`
+
// Prompt templates (configured only via file; no env overrides)
// Completion/chat/code action/CLI prompt strings. See config.toml.example for placeholders.
// Completion
@@ -589,7 +603,7 @@ func loadFromFile(path string, logger *log.Logger) (*App, error) {
"copilot_model": {}, "copilot_base_url": {}, "copilot_temperature": {},
}
for k := range raw {
- if _, isTable := map[string]struct{}{"general": {}, "logging": {}, "completion": {}, "triggers": {}, "inline": {}, "chat": {}, "provider": {}, "openai": {}, "copilot": {}, "ollama": {}, "prompts": {}}[k]; isTable {
+ if _, isTable := map[string]struct{}{"general": {}, "logging": {}, "completion": {}, "triggers": {}, "inline": {}, "chat": {}, "provider": {}, "models": {}, "openai": {}, "copilot": {}, "ollama": {}, "prompts": {}}[k]; isTable {
continue
}
if _, isLegacy := legacy[k]; isLegacy {
@@ -629,12 +643,170 @@ func loadFromFile(path string, logger *log.Logger) (*App, error) {
}
}
}
+ if m := parseSurfaceModels(raw, logger); m != nil {
+ tab.mergeSurfaceModels(m)
+ }
return &tab, nil
}
+func parseSurfaceModels(raw map[string]any, logger *log.Logger) *App {
+ modelsRaw, ok := raw["models"]
+ if !ok {
+ return nil
+ }
+ table, ok := modelsRaw.(map[string]any)
+ if !ok {
+ if logger != nil {
+ logger.Printf("config: ignoring models section (expected table, got %T)", modelsRaw)
+ }
+ return nil
+ }
+ var out App
+ var any bool
+ if model, provider, temp, ok := decodeModelEntry(table["completion"], "models.completion", logger); ok {
+ if model != "" {
+ out.CompletionModel = model
+ }
+ if provider != "" {
+ out.CompletionProvider = provider
+ }
+ if temp != nil {
+ out.CompletionTemperature = temp
+ }
+ any = true
+ }
+ if model, provider, temp, ok := decodeModelEntry(table["code_action"], "models.code_action", logger); ok {
+ if model != "" {
+ out.CodeActionModel = model
+ }
+ if provider != "" {
+ out.CodeActionProvider = provider
+ }
+ if temp != nil {
+ out.CodeActionTemperature = temp
+ }
+ any = true
+ }
+ if model, provider, temp, ok := decodeModelEntry(table["chat"], "models.chat", logger); ok {
+ if model != "" {
+ out.ChatModel = model
+ }
+ if provider != "" {
+ out.ChatProvider = provider
+ }
+ if temp != nil {
+ out.ChatTemperature = temp
+ }
+ any = true
+ }
+ if model, provider, temp, ok := decodeModelEntry(table["cli"], "models.cli", logger); ok {
+ if model != "" {
+ out.CLIModel = model
+ }
+ if provider != "" {
+ out.CLIProvider = provider
+ }
+ if temp != nil {
+ out.CLITemperature = temp
+ }
+ any = true
+ }
+ if !any {
+ return nil
+ }
+ return &out
+}
+
+func decodeModelEntry(raw any, path string, logger *log.Logger) (string, string, *float64, bool) {
+ if raw == nil {
+ return "", "", nil, false
+ }
+ switch v := raw.(type) {
+ case string:
+ model := strings.TrimSpace(v)
+ if model == "" {
+ return "", "", nil, false
+ }
+ return model, "", nil, true
+ case map[string]any:
+ model := ""
+ provider := ""
+ if m, ok := v["model"]; ok {
+ s, ok := m.(string)
+ if !ok {
+ if logger != nil {
+ logger.Printf("config: %s.model must be a string", path)
+ }
+ return "", "", nil, false
+ }
+ model = strings.TrimSpace(s)
+ }
+ if pRaw, ok := v["provider"]; ok {
+ ps, ok := pRaw.(string)
+ if !ok {
+ if logger != nil {
+ logger.Printf("config: %s.provider must be a string", path)
+ }
+ return "", "", nil, false
+ }
+ provider = strings.TrimSpace(ps)
+ }
+ var tempPtr *float64
+ if tRaw, ok := v["temperature"]; ok {
+ parsed, ok := parseTemperatureValue(tRaw, path, logger)
+ if !ok {
+ return "", "", nil, false
+ }
+ tempPtr = parsed
+ }
+ if model == "" && tempPtr == nil && provider == "" {
+ return "", "", nil, false
+ }
+ return model, provider, tempPtr, true
+ default:
+ if logger != nil {
+ logger.Printf("config: %s must be a string or table, got %T", path, raw)
+ }
+ return "", "", nil, false
+ }
+}
+
+func parseTemperatureValue(raw any, path string, logger *log.Logger) (*float64, bool) {
+ switch v := raw.(type) {
+ case float64:
+ return floatPtr(v), true
+ case int64:
+ return floatPtr(float64(v)), true
+ case string:
+ s := strings.TrimSpace(v)
+ if s == "" {
+ return nil, true
+ }
+ f, err := strconv.ParseFloat(s, 64)
+ if err != nil {
+ if logger != nil {
+ logger.Printf("config: %s.temperature invalid: %v", path, err)
+ }
+ return nil, false
+ }
+ return floatPtr(f), true
+ default:
+ if logger != nil {
+ logger.Printf("config: %s.temperature must be numeric or string, got %T", path, raw)
+ }
+ return nil, false
+ }
+}
+
+func floatPtr(v float64) *float64 {
+ f := v
+ return &f
+}
+
func (a *App) mergeWith(other *App) {
a.mergeBasics(other)
a.mergeProviderFields(other)
+ a.mergeSurfaceModels(other)
a.mergePrompts(other)
}
@@ -687,6 +859,46 @@ func (a *App) mergeBasics(other *App) {
}
}
+// mergeSurfaceModels copies per-surface model and temperature overrides.
+func (a *App) mergeSurfaceModels(other *App) {
+ if s := strings.TrimSpace(other.CompletionModel); s != "" {
+ a.CompletionModel = s
+ }
+ if other.CompletionTemperature != nil {
+ a.CompletionTemperature = other.CompletionTemperature
+ }
+ if s := strings.TrimSpace(other.CompletionProvider); s != "" {
+ a.CompletionProvider = s
+ }
+ if s := strings.TrimSpace(other.CodeActionModel); s != "" {
+ a.CodeActionModel = s
+ }
+ if other.CodeActionTemperature != nil {
+ a.CodeActionTemperature = other.CodeActionTemperature
+ }
+ if s := strings.TrimSpace(other.CodeActionProvider); s != "" {
+ a.CodeActionProvider = s
+ }
+ if s := strings.TrimSpace(other.ChatModel); s != "" {
+ a.ChatModel = s
+ }
+ if other.ChatTemperature != nil {
+ a.ChatTemperature = other.ChatTemperature
+ }
+ if s := strings.TrimSpace(other.ChatProvider); s != "" {
+ a.ChatProvider = s
+ }
+ if s := strings.TrimSpace(other.CLIModel); s != "" {
+ a.CLIModel = s
+ }
+ if other.CLITemperature != nil {
+ a.CLITemperature = other.CLITemperature
+ }
+ if s := strings.TrimSpace(other.CLIProvider); s != "" {
+ a.CLIProvider = s
+ }
+}
+
// mergePrompts copies non-empty prompt templates from other.
func (a *App) mergePrompts(other *App) {
// Completion
@@ -1050,6 +1262,56 @@ func loadFromEnv(logger *log.Logger) *App {
any = true
}
+ // Per-surface overrides
+ if s := getenv("HEXAI_MODEL_COMPLETION"); s != "" {
+ out.CompletionModel = s
+ any = true
+ }
+ if f, ok := parseFloatPtr("HEXAI_TEMPERATURE_COMPLETION"); ok {
+ out.CompletionTemperature = f
+ any = true
+ }
+ if s := getenv("HEXAI_PROVIDER_COMPLETION"); s != "" {
+ out.CompletionProvider = s
+ any = true
+ }
+ if s := getenv("HEXAI_MODEL_CODE_ACTION"); s != "" {
+ out.CodeActionModel = s
+ any = true
+ }
+ if f, ok := parseFloatPtr("HEXAI_TEMPERATURE_CODE_ACTION"); ok {
+ out.CodeActionTemperature = f
+ any = true
+ }
+ if s := getenv("HEXAI_PROVIDER_CODE_ACTION"); s != "" {
+ out.CodeActionProvider = s
+ any = true
+ }
+ if s := getenv("HEXAI_MODEL_CHAT"); s != "" {
+ out.ChatModel = s
+ any = true
+ }
+ if f, ok := parseFloatPtr("HEXAI_TEMPERATURE_CHAT"); ok {
+ out.ChatTemperature = f
+ any = true
+ }
+ if s := getenv("HEXAI_PROVIDER_CHAT"); s != "" {
+ out.ChatProvider = s
+ any = true
+ }
+ if s := getenv("HEXAI_MODEL_CLI"); s != "" {
+ out.CLIModel = s
+ any = true
+ }
+ if f, ok := parseFloatPtr("HEXAI_TEMPERATURE_CLI"); ok {
+ out.CLITemperature = f
+ any = true
+ }
+ if s := getenv("HEXAI_PROVIDER_CLI"); s != "" {
+ out.CLIProvider = s
+ any = true
+ }
+
if !any {
return nil
}