summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2026-03-16 03:37:28 +0200
committerPaul Buetow <paul@buetow.org>2026-03-16 03:37:28 +0200
commit2e9cabb1c8bf1f0246e513fe1f86a552e07eee94 (patch)
tree04620a25f1aa5c7d9d914ba136e591abab1c509d
parentad988c34181b7234a54d279874f29e126607fad3 (diff)
Refactor oversized functions and split large test files
Split DefaultPrompts (201L), loadFromFile (83L), and Update (74L) into focused helper functions under 50 lines each. Split handlers_test.go (1650L) and config_test.go (1419L) into logical sub-files under 1000L. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
-rw-r--r--internal/appconfig/config_features_test.go588
-rw-r--r--internal/appconfig/config_load.go98
-rw-r--r--internal/appconfig/config_test.go578
-rw-r--r--internal/mcp/handlers_prompt_test.go (renamed from internal/mcp/handlers_test.go)699
-rw-r--r--internal/mcp/handlers_tool_test.go703
-rw-r--r--internal/promptstore/default_prompts.go219
-rw-r--r--internal/stats/stats.go85
7 files changed, 1517 insertions, 1453 deletions
diff --git a/internal/appconfig/config_features_test.go b/internal/appconfig/config_features_test.go
new file mode 100644
index 0000000..b3c12e9
--- /dev/null
+++ b/internal/appconfig/config_features_test.go
@@ -0,0 +1,588 @@
+// Summary: Tests for ignore config, tmux-edit config, and low-level parsing helpers
+// (temperature, model entries, surface entries, resolved model).
+package appconfig
+
+import (
+ "os"
+ "path/filepath"
+ "reflect"
+ "testing"
+)
+
+func TestIgnoreConfig_Defaults(t *testing.T) {
+ clearHexaiEnv(t)
+ cfg := Load(nil)
+ if cfg.IgnoreGitignore == nil || !*cfg.IgnoreGitignore {
+ t.Error("expected IgnoreGitignore default true")
+ }
+ if cfg.IgnoreLSPNotify == nil || !*cfg.IgnoreLSPNotify {
+ t.Error("expected IgnoreLSPNotify default true")
+ }
+ if len(cfg.IgnoreExtraPatterns) != 0 {
+ t.Errorf("expected empty IgnoreExtraPatterns, got %v", cfg.IgnoreExtraPatterns)
+ }
+}
+
+func TestIgnoreConfig_FromFile(t *testing.T) {
+ clearHexaiEnv(t)
+ dir := t.TempDir()
+ cfgPath := filepath.Join(dir, "config.toml")
+ writeFile(t, cfgPath, `
+[ignore]
+gitignore = false
+extra_patterns = ["*.min.js", "dist/**"]
+lsp_notify_ignored = false
+`)
+ cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: dir})
+ if cfg.IgnoreGitignore == nil || *cfg.IgnoreGitignore {
+ t.Error("expected IgnoreGitignore false from file")
+ }
+ if cfg.IgnoreLSPNotify == nil || *cfg.IgnoreLSPNotify {
+ t.Error("expected IgnoreLSPNotify false from file")
+ }
+ want := []string{"*.min.js", "dist/**"}
+ if !reflect.DeepEqual(cfg.IgnoreExtraPatterns, want) {
+ t.Errorf("IgnoreExtraPatterns = %v, want %v", cfg.IgnoreExtraPatterns, want)
+ }
+}
+
+func TestIgnoreConfig_EnvOverrides(t *testing.T) {
+ clearHexaiEnv(t)
+ dir := t.TempDir()
+ cfgPath := filepath.Join(dir, "config.toml")
+ writeFile(t, cfgPath, `
+[ignore]
+gitignore = true
+lsp_notify_ignored = true
+`)
+ withEnv(t, "HEXAI_IGNORE_GITIGNORE", "false")
+ withEnv(t, "HEXAI_IGNORE_LSP_NOTIFY", "0")
+ withEnv(t, "HEXAI_IGNORE_EXTRA_PATTERNS", "*.bak,*.tmp")
+ cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: dir})
+ if cfg.IgnoreGitignore == nil || *cfg.IgnoreGitignore {
+ t.Error("expected IgnoreGitignore false from env override")
+ }
+ if cfg.IgnoreLSPNotify == nil || *cfg.IgnoreLSPNotify {
+ t.Error("expected IgnoreLSPNotify false from env override")
+ }
+ want := []string{"*.bak", "*.tmp"}
+ if !reflect.DeepEqual(cfg.IgnoreExtraPatterns, want) {
+ t.Errorf("IgnoreExtraPatterns = %v, want %v", cfg.IgnoreExtraPatterns, want)
+ }
+}
+
+func TestIgnoreConfig_ProjectOverride(t *testing.T) {
+ clearHexaiEnv(t)
+ dir := t.TempDir()
+ cfgPath := filepath.Join(dir, "config.toml")
+ writeFile(t, cfgPath, `
+[ignore]
+gitignore = true
+`)
+ // Set up a fake git repo with project override
+ projectDir := t.TempDir()
+ if err := os.Mkdir(filepath.Join(projectDir, ".git"), 0o755); err != nil {
+ t.Fatalf("mkdir .git: %v", err)
+ }
+ projectCfg := filepath.Join(projectDir, ProjectConfigFilename)
+ writeFile(t, projectCfg, `
+[ignore]
+gitignore = false
+extra_patterns = ["build/**"]
+`)
+ cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: projectDir})
+ if cfg.IgnoreGitignore == nil || *cfg.IgnoreGitignore {
+ t.Error("expected project override to set IgnoreGitignore false")
+ }
+ want := []string{"build/**"}
+ if !reflect.DeepEqual(cfg.IgnoreExtraPatterns, want) {
+ t.Errorf("IgnoreExtraPatterns = %v, want %v", cfg.IgnoreExtraPatterns, want)
+ }
+}
+
+func TestIgnoreConfig_DisableGitignore(t *testing.T) {
+ clearHexaiEnv(t)
+ dir := t.TempDir()
+ cfgPath := filepath.Join(dir, "config.toml")
+ writeFile(t, cfgPath, `
+[ignore]
+gitignore = false
+`)
+ cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: dir})
+ if cfg.IgnoreGitignore == nil || *cfg.IgnoreGitignore {
+ t.Error("expected IgnoreGitignore false")
+ }
+ // LSP notify should still be true (default, not overridden)
+ if cfg.IgnoreLSPNotify == nil || !*cfg.IgnoreLSPNotify {
+ t.Error("expected IgnoreLSPNotify to remain true (default)")
+ }
+}
+
+func TestTmuxEditConfig_FromFile(t *testing.T) {
+ clearHexaiEnv(t)
+ dir := t.TempDir()
+ cfgPath := filepath.Join(dir, "config.toml")
+ writeFile(t, cfgPath, `
+[tmux_edit]
+popup_width = "90%"
+popup_height = "85%"
+default_agent = "claude"
+
+[[tmux_edit.agents]]
+name = "claude"
+display_name = "Claude Code"
+detect_pattern = "(?i)(claude|anthropic)"
+prompt_pattern = '(?s)>\s*(.+?)$'
+clear_first = true
+clear_keys = "C-u"
+newline_keys = "S-Enter"
+submit_keys = "Enter"
+
+[[tmux_edit.agents]]
+name = "cursor"
+display_name = "Cursor"
+detect_pattern = "(?i)cursor"
+prompt_pattern = '(?s)│\s*(.+?)$'
+strip_patterns = ["INSERT", "Add a follow-up"]
+clear_first = true
+clear_keys = "C-u"
+newline_keys = "S-Enter"
+submit_keys = "Enter"
+`)
+ cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: dir})
+ if cfg.TmuxEditPopupWidth != "90%" {
+ t.Errorf("PopupWidth = %q, want 90%%", cfg.TmuxEditPopupWidth)
+ }
+ if cfg.TmuxEditPopupHeight != "85%" {
+ t.Errorf("PopupHeight = %q, want 85%%", cfg.TmuxEditPopupHeight)
+ }
+ if cfg.TmuxEditDefaultAgent != "claude" {
+ t.Errorf("DefaultAgent = %q, want claude", cfg.TmuxEditDefaultAgent)
+ }
+ if len(cfg.TmuxEditAgents) != 2 {
+ t.Fatalf("got %d agents, want 2", len(cfg.TmuxEditAgents))
+ }
+ a := cfg.TmuxEditAgents[0]
+ if a.Name != "claude" || a.DisplayName != "Claude Code" {
+ t.Errorf("agent[0] = %q/%q, want claude/Claude Code", a.Name, a.DisplayName)
+ }
+ if a.ClearFirst == nil || !*a.ClearFirst {
+ t.Error("expected ClearFirst = true for claude agent")
+ }
+ b := cfg.TmuxEditAgents[1]
+ if b.Name != "cursor" {
+ t.Errorf("agent[1].Name = %q, want cursor", b.Name)
+ }
+ if len(b.StripPatterns) != 2 {
+ t.Errorf("agent[1].StripPatterns = %v, want 2 entries", b.StripPatterns)
+ }
+}
+
+func TestTmuxEditConfig_Merge(t *testing.T) {
+ clearHexaiEnv(t)
+ a := newDefaultConfig()
+ b := App{
+ TmuxEditPopupWidth: "70%",
+ TmuxEditDefaultAgent: "amp",
+ TmuxEditAgents: []TmuxEditAgentCfg{
+ {Name: "amp", DisplayName: "Amp"},
+ },
+ }
+ a.mergeWith(&b)
+ if a.TmuxEditPopupWidth != "70%" {
+ t.Errorf("PopupWidth = %q, want 70%%", a.TmuxEditPopupWidth)
+ }
+ if a.TmuxEditDefaultAgent != "amp" {
+ t.Errorf("DefaultAgent = %q, want amp", a.TmuxEditDefaultAgent)
+ }
+ if len(a.TmuxEditAgents) != 1 || a.TmuxEditAgents[0].Name != "amp" {
+ t.Errorf("Agents = %v, want single amp", a.TmuxEditAgents)
+ }
+}
+
+func TestTmuxEditConfig_SkipsEmptyName(t *testing.T) {
+ clearHexaiEnv(t)
+ dir := t.TempDir()
+ cfgPath := filepath.Join(dir, "config.toml")
+ writeFile(t, cfgPath, `
+[tmux_edit]
+[[tmux_edit.agents]]
+name = ""
+display_name = "Empty"
+`)
+ cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: dir})
+ if len(cfg.TmuxEditAgents) != 0 {
+ t.Errorf("got %d agents, want 0 (empty name should be skipped)", len(cfg.TmuxEditAgents))
+ }
+}
+
+// --- Config Parsing Tests ---
+
+func TestParseTemperatureValue(t *testing.T) {
+ tests := []struct {
+ name string
+ input any
+ wantValue *float64
+ wantOK bool
+ }{
+ {"float64 zero", float64(0.0), floatPtr(0.0), true},
+ {"float64 half", float64(0.5), floatPtr(0.5), true},
+ {"float64 one", float64(1.0), floatPtr(1.0), true},
+ {"float64 two", float64(2.0), floatPtr(2.0), true},
+ {"int64 zero", int64(0), floatPtr(0.0), true},
+ {"int64 one", int64(1), floatPtr(1.0), true},
+ {"int64 two", int64(2), floatPtr(2.0), true},
+ {"string zero", "0", floatPtr(0.0), true},
+ {"string one", "1", floatPtr(1.0), true},
+ {"string two", "2", floatPtr(2.0), true},
+ {"string float", "0.75", floatPtr(0.75), true},
+ {"string empty", "", nil, true},
+ {"string whitespace", " ", nil, true},
+ {"string invalid", "invalid", nil, false},
+ {"string negative", "-0.5", floatPtr(-0.5), true},
+ {"string very small", "0.0001", floatPtr(0.0001), true},
+ {"string high precision", "1.123456789", floatPtr(1.123456789), true},
+ {"nil value", nil, nil, false},
+ {"bool value", true, nil, false},
+ {"map value", map[string]any{}, nil, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, ok := parseTemperatureValue(tt.input, "test", newLogger())
+ if ok != tt.wantOK {
+ t.Errorf("parseTemperatureValue() ok = %v, want %v", ok, tt.wantOK)
+ }
+ if !ok {
+ return
+ }
+ if (got == nil) != (tt.wantValue == nil) {
+ t.Errorf("parseTemperatureValue() = %v, want %v", got, tt.wantValue)
+ return
+ }
+ if got != nil && tt.wantValue != nil && *got != *tt.wantValue {
+ t.Errorf("parseTemperatureValue() = %v, want %v", *got, *tt.wantValue)
+ }
+ })
+ }
+}
+
+func TestDecodeModelEntry(t *testing.T) {
+ tests := []struct {
+ name string
+ input any
+ wantCfg *SurfaceConfig
+ wantOK bool
+ }{
+ {
+ name: "simple string model",
+ input: "gpt-4",
+ wantCfg: &SurfaceConfig{Model: "gpt-4"},
+ wantOK: true,
+ },
+ {
+ name: "empty string",
+ input: "",
+ wantCfg: nil,
+ wantOK: false,
+ },
+ {
+ name: "whitespace string",
+ input: " ",
+ wantCfg: nil,
+ wantOK: false,
+ },
+ {
+ name: "object with all fields",
+ input: map[string]any{
+ "model": "claude-3",
+ "provider": "anthropic",
+ "temperature": float64(0.7),
+ },
+ wantCfg: &SurfaceConfig{
+ Model: "claude-3",
+ Provider: "anthropic",
+ Temperature: floatPtr(0.7),
+ },
+ wantOK: true,
+ },
+ {
+ name: "object with model only",
+ input: map[string]any{
+ "model": "gpt-4o",
+ },
+ wantCfg: &SurfaceConfig{Model: "gpt-4o"},
+ wantOK: true,
+ },
+ {
+ name: "object with provider only",
+ input: map[string]any{
+ "provider": "openai",
+ },
+ wantCfg: &SurfaceConfig{Provider: "openai"},
+ wantOK: true,
+ },
+ {
+ name: "object with temperature only",
+ input: map[string]any{
+ "temperature": float64(0.5),
+ },
+ wantCfg: &SurfaceConfig{Temperature: floatPtr(0.5)},
+ wantOK: true,
+ },
+ {
+ name: "object with empty fields",
+ input: map[string]any{
+ "model": "",
+ "provider": "",
+ },
+ wantCfg: nil,
+ wantOK: false,
+ },
+ {
+ name: "object with invalid model type",
+ input: map[string]any{
+ "model": 123,
+ },
+ wantCfg: nil,
+ wantOK: false,
+ },
+ {
+ name: "object with invalid provider type",
+ input: map[string]any{
+ "provider": 456,
+ },
+ wantCfg: nil,
+ wantOK: false,
+ },
+ {
+ name: "object with invalid temperature",
+ input: map[string]any{
+ "model": "gpt-4",
+ "temperature": "not a number",
+ },
+ wantCfg: nil,
+ wantOK: false,
+ },
+ {
+ name: "nil input",
+ input: nil,
+ wantCfg: nil,
+ wantOK: false,
+ },
+ {
+ name: "invalid type (int)",
+ input: 123,
+ wantCfg: nil,
+ wantOK: false,
+ },
+ {
+ name: "invalid type (slice)",
+ input: []string{"gpt-4"},
+ wantCfg: nil,
+ wantOK: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, ok := decodeModelEntry(tt.input, "test", newLogger())
+ if ok != tt.wantOK {
+ t.Errorf("decodeModelEntry() ok = %v, want %v", ok, tt.wantOK)
+ }
+ if !ok {
+ return
+ }
+ if (got == nil) != (tt.wantCfg == nil) {
+ t.Errorf("decodeModelEntry() = %v, want %v", got, tt.wantCfg)
+ return
+ }
+ if got == nil {
+ return
+ }
+ if got.Model != tt.wantCfg.Model {
+ t.Errorf("Model = %q, want %q", got.Model, tt.wantCfg.Model)
+ }
+ if got.Provider != tt.wantCfg.Provider {
+ t.Errorf("Provider = %q, want %q", got.Provider, tt.wantCfg.Provider)
+ }
+ if (got.Temperature == nil) != (tt.wantCfg.Temperature == nil) {
+ t.Errorf("Temperature nil mismatch: got %v, want %v", got.Temperature, tt.wantCfg.Temperature)
+ } else if got.Temperature != nil && *got.Temperature != *tt.wantCfg.Temperature {
+ t.Errorf("Temperature = %v, want %v", *got.Temperature, *tt.wantCfg.Temperature)
+ }
+ })
+ }
+}
+
+func TestResolvedModel(t *testing.T) {
+ tests := []struct {
+ name string
+ section sectionOpenAI
+ want string
+ }{
+ {
+ name: "explicit model no presets",
+ section: sectionOpenAI{Model: "gpt-4"},
+ want: "gpt-4",
+ },
+ {
+ name: "empty model",
+ section: sectionOpenAI{Model: ""},
+ want: "",
+ },
+ {
+ name: "whitespace model",
+ section: sectionOpenAI{Model: " "},
+ want: "",
+ },
+ {
+ name: "preset match exact case",
+ section: sectionOpenAI{
+ Model: "fast",
+ Presets: map[string]string{"fast": "gpt-3.5-turbo"},
+ },
+ want: "gpt-3.5-turbo",
+ },
+ {
+ name: "preset match case insensitive",
+ section: sectionOpenAI{
+ Model: "FAST",
+ Presets: map[string]string{"fast": "gpt-3.5-turbo"},
+ },
+ want: "gpt-3.5-turbo",
+ },
+ {
+ name: "no preset match returns original",
+ section: sectionOpenAI{
+ Model: "custom-model",
+ Presets: map[string]string{"fast": "gpt-3.5-turbo"},
+ },
+ want: "custom-model",
+ },
+ {
+ name: "preset empty value returns original",
+ section: sectionOpenAI{
+ Model: "fast",
+ Presets: map[string]string{"fast": ""},
+ },
+ want: "fast",
+ },
+ {
+ name: "preset whitespace value returns original",
+ section: sectionOpenAI{
+ Model: "fast",
+ Presets: map[string]string{"fast": " "},
+ },
+ want: "fast",
+ },
+ {
+ name: "multiple presets uses correct one",
+ section: sectionOpenAI{
+ Model: "smart",
+ Presets: map[string]string{
+ "fast": "gpt-3.5-turbo",
+ "smart": "gpt-4",
+ "mini": "gpt-3.5-mini",
+ },
+ },
+ want: "gpt-4",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := tt.section.resolvedModel()
+ if got != tt.want {
+ t.Errorf("resolvedModel() = %q, want %q", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestParseSurfaceEntries(t *testing.T) {
+ tests := []struct {
+ name string
+ input any
+ wantLen int
+ wantOK bool
+ }{
+ {
+ name: "nil input",
+ input: nil,
+ wantLen: 0,
+ wantOK: false,
+ },
+ {
+ name: "single string",
+ input: "gpt-4",
+ wantLen: 1,
+ wantOK: true,
+ },
+ {
+ name: "single map",
+ input: map[string]any{
+ "model": "claude-3",
+ "provider": "anthropic",
+ },
+ wantLen: 1,
+ wantOK: true,
+ },
+ {
+ name: "array of strings",
+ input: []any{
+ "gpt-4",
+ "claude-3",
+ },
+ wantLen: 2,
+ wantOK: true,
+ },
+ {
+ name: "array of maps",
+ input: []any{
+ map[string]any{"model": "gpt-4", "provider": "openai"},
+ map[string]any{"model": "claude-3", "provider": "anthropic"},
+ },
+ wantLen: 2,
+ wantOK: true,
+ },
+ {
+ name: "array with invalid entries",
+ input: []any{
+ "gpt-4",
+ 123,
+ "claude-3",
+ },
+ wantLen: 2,
+ wantOK: true,
+ },
+ {
+ name: "array with all invalid entries",
+ input: []any{
+ 123,
+ true,
+ nil,
+ },
+ wantLen: 0,
+ wantOK: false,
+ },
+ {
+ name: "empty array",
+ input: []any{},
+ wantLen: 0,
+ wantOK: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, ok := parseSurfaceEntries(tt.input, "test", newLogger())
+ if ok != tt.wantOK {
+ t.Errorf("parseSurfaceEntries() ok = %v, want %v", ok, tt.wantOK)
+ }
+ if len(got) != tt.wantLen {
+ t.Errorf("parseSurfaceEntries() len = %d, want %d", len(got), tt.wantLen)
+ }
+ })
+ }
+}
diff --git a/internal/appconfig/config_load.go b/internal/appconfig/config_load.go
index 261835b..dc917ff 100644
--- a/internal/appconfig/config_load.go
+++ b/internal/appconfig/config_load.go
@@ -139,6 +139,8 @@ func loadProjectConfig(logger *log.Logger, opts LoadOptions, cfg *App) {
}
}
+// loadFromFile reads a TOML config file, validates it, and returns the parsed App.
+// Returns (nil, err) on I/O or parse errors; returns (nil, nil) when the file does not exist.
func loadFromFile(path string, logger *log.Logger) (*App, error) {
b, err := os.ReadFile(path)
if err != nil {
@@ -148,25 +150,50 @@ func loadFromFile(path string, logger *log.Logger) (*App, error) {
return nil, err
}
+ tables, raw, err := decodeTOML(b, path, logger)
+ if err != nil {
+ return nil, err
+ }
+ if err := rejectLegacyKeys(raw); err != nil {
+ return nil, err
+ }
+ if logger != nil {
+ logger.Printf("loaded configuration from %s (TOML)", path)
+ }
+
+ tab := tables.toApp()
+ applyRawIntOverrides(raw, &tab)
+ if m := parseSurfaceModels(raw, logger); m != nil {
+ tab.mergeSurfaceModels(m)
+ }
+ return &tab, nil
+}
+
+// decodeTOML parses raw TOML bytes into both the typed fileConfig and a raw map
+// for validation and defensive integer handling.
+func decodeTOML(b []byte, path string, logger *log.Logger) (*fileConfig, map[string]any, error) {
var tables fileConfig
errTables := toml.NewDecoder(strings.NewReader(string(b))).Decode(&tables)
- // Raw map for validation/presence checks
var raw map[string]any
errRaw := toml.Unmarshal(b, &raw)
if errTables != nil {
if logger != nil {
logger.Printf("invalid TOML config file %s: %v", path, errTables)
}
- return nil, errTables
+ return nil, nil, errTables
}
if errRaw != nil {
if logger != nil {
logger.Printf("invalid TOML config file %s: %v", path, errRaw)
}
- return nil, errRaw
+ return nil, nil, errRaw
}
+ return &tables, raw, nil
+}
- // Reject legacy flat keys at top-level (sectioned-only config is allowed)
+// rejectLegacyKeys returns an error if the raw map contains flat keys from the
+// old unsectioned config format. Only sectioned table keys are allowed.
+func rejectLegacyKeys(raw map[string]any) error {
legacy := map[string]struct{}{
"max_tokens": {}, "context_mode": {}, "context_window_lines": {}, "max_context_tokens": {},
"log_preview_limit": {}, "completion_debounce_ms": {}, "completion_throttle_ms": {},
@@ -175,53 +202,46 @@ func loadFromFile(path string, logger *log.Logger) (*App, error) {
"openai_model": {}, "openai_base_url": {}, "openai_temperature": {},
"ollama_model": {}, "ollama_base_url": {}, "ollama_temperature": {},
}
+ knownTables := map[string]struct{}{
+ "general": {}, "logging": {}, "completion": {}, "triggers": {}, "inline": {},
+ "chat": {}, "provider": {}, "models": {}, "openai": {}, "ollama": {}, "prompts": {},
+ }
for k := range raw {
- if _, isTable := map[string]struct{}{
- "general": {}, "logging": {}, "completion": {}, "triggers": {}, "inline": {},
- "chat": {}, "provider": {}, "models": {}, "openai": {}, "ollama": {}, "prompts": {},
- }[k]; isTable {
+ if _, isTable := knownTables[k]; isTable {
continue
}
if _, isLegacy := legacy[k]; isLegacy {
- return nil, fmt.Errorf("unsupported flat key '%s' in config; use sectioned tables (see config.toml.example)", k)
+ return fmt.Errorf("unsupported flat key '%s' in config; use sectioned tables (see config.toml.example)", k)
}
}
+ return nil
+}
- if logger != nil {
- logger.Printf("loaded configuration from %s (TOML)", path)
- }
-
- // Build App from tables only
- tab := tables.toApp()
- // Ensure explicit values from raw map are respected (defensive for ints)
+// applyRawIntOverrides defensively re-applies integer values from the raw TOML map
+// that the typed decoder may have silently zeroed (e.g. int vs float mismatch).
+func applyRawIntOverrides(raw map[string]any, tab *App) {
if t, ok := raw["completion"].(map[string]any); ok {
- if v, present := t["manual_invoke_min_prefix"]; present {
- switch vv := v.(type) {
- case int64:
- tab.ManualInvokeMinPrefix = int(vv)
- case int:
- tab.ManualInvokeMinPrefix = vv
- case float64:
- tab.ManualInvokeMinPrefix = int(vv)
- }
- }
+ applyRawInt(&tab.ManualInvokeMinPrefix, t, "manual_invoke_min_prefix")
}
if t, ok := raw["logging"].(map[string]any); ok {
- if v, present := t["log_preview_limit"]; present {
- switch vv := v.(type) {
- case int64:
- tab.LogPreviewLimit = int(vv)
- case int:
- tab.LogPreviewLimit = vv
- case float64:
- tab.LogPreviewLimit = int(vv)
- }
- }
+ applyRawInt(&tab.LogPreviewLimit, t, "log_preview_limit")
}
- if m := parseSurfaceModels(raw, logger); m != nil {
- tab.mergeSurfaceModels(m)
+}
+
+// applyRawInt sets *dst from table[key] when the value is a numeric type.
+func applyRawInt(dst *int, table map[string]any, key string) {
+ v, present := table[key]
+ if !present {
+ return
+ }
+ switch vv := v.(type) {
+ case int64:
+ *dst = int(vv)
+ case int:
+ *dst = vv
+ case float64:
+ *dst = int(vv)
}
- return &tab, nil
}
func (fc *fileConfig) toApp() App {
diff --git a/internal/appconfig/config_test.go b/internal/appconfig/config_test.go
index c98d904..d4e7739 100644
--- a/internal/appconfig/config_test.go
+++ b/internal/appconfig/config_test.go
@@ -839,581 +839,3 @@ func TestProjectConfigPath(t *testing.T) {
t.Fatalf("ProjectConfigPath() = %q, want empty", path)
}
}
-
-func TestIgnoreConfig_Defaults(t *testing.T) {
- clearHexaiEnv(t)
- cfg := Load(nil)
- if cfg.IgnoreGitignore == nil || !*cfg.IgnoreGitignore {
- t.Error("expected IgnoreGitignore default true")
- }
- if cfg.IgnoreLSPNotify == nil || !*cfg.IgnoreLSPNotify {
- t.Error("expected IgnoreLSPNotify default true")
- }
- if len(cfg.IgnoreExtraPatterns) != 0 {
- t.Errorf("expected empty IgnoreExtraPatterns, got %v", cfg.IgnoreExtraPatterns)
- }
-}
-
-func TestIgnoreConfig_FromFile(t *testing.T) {
- clearHexaiEnv(t)
- dir := t.TempDir()
- cfgPath := filepath.Join(dir, "config.toml")
- writeFile(t, cfgPath, `
-[ignore]
-gitignore = false
-extra_patterns = ["*.min.js", "dist/**"]
-lsp_notify_ignored = false
-`)
- cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: dir})
- if cfg.IgnoreGitignore == nil || *cfg.IgnoreGitignore {
- t.Error("expected IgnoreGitignore false from file")
- }
- if cfg.IgnoreLSPNotify == nil || *cfg.IgnoreLSPNotify {
- t.Error("expected IgnoreLSPNotify false from file")
- }
- want := []string{"*.min.js", "dist/**"}
- if !reflect.DeepEqual(cfg.IgnoreExtraPatterns, want) {
- t.Errorf("IgnoreExtraPatterns = %v, want %v", cfg.IgnoreExtraPatterns, want)
- }
-}
-
-func TestIgnoreConfig_EnvOverrides(t *testing.T) {
- clearHexaiEnv(t)
- dir := t.TempDir()
- cfgPath := filepath.Join(dir, "config.toml")
- writeFile(t, cfgPath, `
-[ignore]
-gitignore = true
-lsp_notify_ignored = true
-`)
- withEnv(t, "HEXAI_IGNORE_GITIGNORE", "false")
- withEnv(t, "HEXAI_IGNORE_LSP_NOTIFY", "0")
- withEnv(t, "HEXAI_IGNORE_EXTRA_PATTERNS", "*.bak,*.tmp")
- cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: dir})
- if cfg.IgnoreGitignore == nil || *cfg.IgnoreGitignore {
- t.Error("expected IgnoreGitignore false from env override")
- }
- if cfg.IgnoreLSPNotify == nil || *cfg.IgnoreLSPNotify {
- t.Error("expected IgnoreLSPNotify false from env override")
- }
- want := []string{"*.bak", "*.tmp"}
- if !reflect.DeepEqual(cfg.IgnoreExtraPatterns, want) {
- t.Errorf("IgnoreExtraPatterns = %v, want %v", cfg.IgnoreExtraPatterns, want)
- }
-}
-
-func TestIgnoreConfig_ProjectOverride(t *testing.T) {
- clearHexaiEnv(t)
- dir := t.TempDir()
- cfgPath := filepath.Join(dir, "config.toml")
- writeFile(t, cfgPath, `
-[ignore]
-gitignore = true
-`)
- // Set up a fake git repo with project override
- projectDir := t.TempDir()
- if err := os.Mkdir(filepath.Join(projectDir, ".git"), 0o755); err != nil {
- t.Fatalf("mkdir .git: %v", err)
- }
- projectCfg := filepath.Join(projectDir, ProjectConfigFilename)
- writeFile(t, projectCfg, `
-[ignore]
-gitignore = false
-extra_patterns = ["build/**"]
-`)
- cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: projectDir})
- if cfg.IgnoreGitignore == nil || *cfg.IgnoreGitignore {
- t.Error("expected project override to set IgnoreGitignore false")
- }
- want := []string{"build/**"}
- if !reflect.DeepEqual(cfg.IgnoreExtraPatterns, want) {
- t.Errorf("IgnoreExtraPatterns = %v, want %v", cfg.IgnoreExtraPatterns, want)
- }
-}
-
-func TestIgnoreConfig_DisableGitignore(t *testing.T) {
- clearHexaiEnv(t)
- dir := t.TempDir()
- cfgPath := filepath.Join(dir, "config.toml")
- writeFile(t, cfgPath, `
-[ignore]
-gitignore = false
-`)
- cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: dir})
- if cfg.IgnoreGitignore == nil || *cfg.IgnoreGitignore {
- t.Error("expected IgnoreGitignore false")
- }
- // LSP notify should still be true (default, not overridden)
- if cfg.IgnoreLSPNotify == nil || !*cfg.IgnoreLSPNotify {
- t.Error("expected IgnoreLSPNotify to remain true (default)")
- }
-}
-
-func TestTmuxEditConfig_FromFile(t *testing.T) {
- clearHexaiEnv(t)
- dir := t.TempDir()
- cfgPath := filepath.Join(dir, "config.toml")
- writeFile(t, cfgPath, `
-[tmux_edit]
-popup_width = "90%"
-popup_height = "85%"
-default_agent = "claude"
-
-[[tmux_edit.agents]]
-name = "claude"
-display_name = "Claude Code"
-detect_pattern = "(?i)(claude|anthropic)"
-prompt_pattern = '(?s)>\s*(.+?)$'
-clear_first = true
-clear_keys = "C-u"
-newline_keys = "S-Enter"
-submit_keys = "Enter"
-
-[[tmux_edit.agents]]
-name = "cursor"
-display_name = "Cursor"
-detect_pattern = "(?i)cursor"
-prompt_pattern = '(?s)│\s*(.+?)$'
-strip_patterns = ["INSERT", "Add a follow-up"]
-clear_first = true
-clear_keys = "C-u"
-newline_keys = "S-Enter"
-submit_keys = "Enter"
-`)
- cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: dir})
- if cfg.TmuxEditPopupWidth != "90%" {
- t.Errorf("PopupWidth = %q, want 90%%", cfg.TmuxEditPopupWidth)
- }
- if cfg.TmuxEditPopupHeight != "85%" {
- t.Errorf("PopupHeight = %q, want 85%%", cfg.TmuxEditPopupHeight)
- }
- if cfg.TmuxEditDefaultAgent != "claude" {
- t.Errorf("DefaultAgent = %q, want claude", cfg.TmuxEditDefaultAgent)
- }
- if len(cfg.TmuxEditAgents) != 2 {
- t.Fatalf("got %d agents, want 2", len(cfg.TmuxEditAgents))
- }
- a := cfg.TmuxEditAgents[0]
- if a.Name != "claude" || a.DisplayName != "Claude Code" {
- t.Errorf("agent[0] = %q/%q, want claude/Claude Code", a.Name, a.DisplayName)
- }
- if a.ClearFirst == nil || !*a.ClearFirst {
- t.Error("expected ClearFirst = true for claude agent")
- }
- b := cfg.TmuxEditAgents[1]
- if b.Name != "cursor" {
- t.Errorf("agent[1].Name = %q, want cursor", b.Name)
- }
- if len(b.StripPatterns) != 2 {
- t.Errorf("agent[1].StripPatterns = %v, want 2 entries", b.StripPatterns)
- }
-}
-
-func TestTmuxEditConfig_Merge(t *testing.T) {
- clearHexaiEnv(t)
- a := newDefaultConfig()
- b := App{
- TmuxEditPopupWidth: "70%",
- TmuxEditDefaultAgent: "amp",
- TmuxEditAgents: []TmuxEditAgentCfg{
- {Name: "amp", DisplayName: "Amp"},
- },
- }
- a.mergeWith(&b)
- if a.TmuxEditPopupWidth != "70%" {
- t.Errorf("PopupWidth = %q, want 70%%", a.TmuxEditPopupWidth)
- }
- if a.TmuxEditDefaultAgent != "amp" {
- t.Errorf("DefaultAgent = %q, want amp", a.TmuxEditDefaultAgent)
- }
- if len(a.TmuxEditAgents) != 1 || a.TmuxEditAgents[0].Name != "amp" {
- t.Errorf("Agents = %v, want single amp", a.TmuxEditAgents)
- }
-}
-
-func TestTmuxEditConfig_SkipsEmptyName(t *testing.T) {
- clearHexaiEnv(t)
- dir := t.TempDir()
- cfgPath := filepath.Join(dir, "config.toml")
- writeFile(t, cfgPath, `
-[tmux_edit]
-[[tmux_edit.agents]]
-name = ""
-display_name = "Empty"
-`)
- cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: dir})
- if len(cfg.TmuxEditAgents) != 0 {
- t.Errorf("got %d agents, want 0 (empty name should be skipped)", len(cfg.TmuxEditAgents))
- }
-}
-
-// --- Phase 1: Config Parsing Tests ---
-
-func TestParseTemperatureValue(t *testing.T) {
- tests := []struct {
- name string
- input any
- wantValue *float64
- wantOK bool
- }{
- {"float64 zero", float64(0.0), floatPtr(0.0), true},
- {"float64 half", float64(0.5), floatPtr(0.5), true},
- {"float64 one", float64(1.0), floatPtr(1.0), true},
- {"float64 two", float64(2.0), floatPtr(2.0), true},
- {"int64 zero", int64(0), floatPtr(0.0), true},
- {"int64 one", int64(1), floatPtr(1.0), true},
- {"int64 two", int64(2), floatPtr(2.0), true},
- {"string zero", "0", floatPtr(0.0), true},
- {"string one", "1", floatPtr(1.0), true},
- {"string two", "2", floatPtr(2.0), true},
- {"string float", "0.75", floatPtr(0.75), true},
- {"string empty", "", nil, true},
- {"string whitespace", " ", nil, true},
- {"string invalid", "invalid", nil, false},
- {"string negative", "-0.5", floatPtr(-0.5), true},
- {"string very small", "0.0001", floatPtr(0.0001), true},
- {"string high precision", "1.123456789", floatPtr(1.123456789), true},
- {"nil value", nil, nil, false},
- {"bool value", true, nil, false},
- {"map value", map[string]any{}, nil, false},
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, ok := parseTemperatureValue(tt.input, "test", newLogger())
- if ok != tt.wantOK {
- t.Errorf("parseTemperatureValue() ok = %v, want %v", ok, tt.wantOK)
- }
- if !ok {
- return
- }
- if (got == nil) != (tt.wantValue == nil) {
- t.Errorf("parseTemperatureValue() = %v, want %v", got, tt.wantValue)
- return
- }
- if got != nil && tt.wantValue != nil && *got != *tt.wantValue {
- t.Errorf("parseTemperatureValue() = %v, want %v", *got, *tt.wantValue)
- }
- })
- }
-}
-
-func TestDecodeModelEntry(t *testing.T) {
- tests := []struct {
- name string
- input any
- wantCfg *SurfaceConfig
- wantOK bool
- }{
- {
- name: "simple string model",
- input: "gpt-4",
- wantCfg: &SurfaceConfig{Model: "gpt-4"},
- wantOK: true,
- },
- {
- name: "empty string",
- input: "",
- wantCfg: nil,
- wantOK: false,
- },
- {
- name: "whitespace string",
- input: " ",
- wantCfg: nil,
- wantOK: false,
- },
- {
- name: "object with all fields",
- input: map[string]any{
- "model": "claude-3",
- "provider": "anthropic",
- "temperature": float64(0.7),
- },
- wantCfg: &SurfaceConfig{
- Model: "claude-3",
- Provider: "anthropic",
- Temperature: floatPtr(0.7),
- },
- wantOK: true,
- },
- {
- name: "object with model only",
- input: map[string]any{
- "model": "gpt-4o",
- },
- wantCfg: &SurfaceConfig{Model: "gpt-4o"},
- wantOK: true,
- },
- {
- name: "object with provider only",
- input: map[string]any{
- "provider": "openai",
- },
- wantCfg: &SurfaceConfig{Provider: "openai"},
- wantOK: true,
- },
- {
- name: "object with temperature only",
- input: map[string]any{
- "temperature": float64(0.5),
- },
- wantCfg: &SurfaceConfig{Temperature: floatPtr(0.5)},
- wantOK: true,
- },
- {
- name: "object with empty fields",
- input: map[string]any{
- "model": "",
- "provider": "",
- },
- wantCfg: nil,
- wantOK: false,
- },
- {
- name: "object with invalid model type",
- input: map[string]any{
- "model": 123,
- },
- wantCfg: nil,
- wantOK: false,
- },
- {
- name: "object with invalid provider type",
- input: map[string]any{
- "provider": 456,
- },
- wantCfg: nil,
- wantOK: false,
- },
- {
- name: "object with invalid temperature",
- input: map[string]any{
- "model": "gpt-4",
- "temperature": "not a number",
- },
- wantCfg: nil,
- wantOK: false,
- },
- {
- name: "nil input",
- input: nil,
- wantCfg: nil,
- wantOK: false,
- },
- {
- name: "invalid type (int)",
- input: 123,
- wantCfg: nil,
- wantOK: false,
- },
- {
- name: "invalid type (slice)",
- input: []string{"gpt-4"},
- wantCfg: nil,
- wantOK: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, ok := decodeModelEntry(tt.input, "test", newLogger())
- if ok != tt.wantOK {
- t.Errorf("decodeModelEntry() ok = %v, want %v", ok, tt.wantOK)
- }
- if !ok {
- return
- }
- if (got == nil) != (tt.wantCfg == nil) {
- t.Errorf("decodeModelEntry() = %v, want %v", got, tt.wantCfg)
- return
- }
- if got == nil {
- return
- }
- if got.Model != tt.wantCfg.Model {
- t.Errorf("Model = %q, want %q", got.Model, tt.wantCfg.Model)
- }
- if got.Provider != tt.wantCfg.Provider {
- t.Errorf("Provider = %q, want %q", got.Provider, tt.wantCfg.Provider)
- }
- if (got.Temperature == nil) != (tt.wantCfg.Temperature == nil) {
- t.Errorf("Temperature nil mismatch: got %v, want %v", got.Temperature, tt.wantCfg.Temperature)
- } else if got.Temperature != nil && *got.Temperature != *tt.wantCfg.Temperature {
- t.Errorf("Temperature = %v, want %v", *got.Temperature, *tt.wantCfg.Temperature)
- }
- })
- }
-}
-
-func TestResolvedModel(t *testing.T) {
- tests := []struct {
- name string
- section sectionOpenAI
- want string
- }{
- {
- name: "explicit model no presets",
- section: sectionOpenAI{Model: "gpt-4"},
- want: "gpt-4",
- },
- {
- name: "empty model",
- section: sectionOpenAI{Model: ""},
- want: "",
- },
- {
- name: "whitespace model",
- section: sectionOpenAI{Model: " "},
- want: "",
- },
- {
- name: "preset match exact case",
- section: sectionOpenAI{
- Model: "fast",
- Presets: map[string]string{"fast": "gpt-3.5-turbo"},
- },
- want: "gpt-3.5-turbo",
- },
- {
- name: "preset match case insensitive",
- section: sectionOpenAI{
- Model: "FAST",
- Presets: map[string]string{"fast": "gpt-3.5-turbo"},
- },
- want: "gpt-3.5-turbo",
- },
- {
- name: "no preset match returns original",
- section: sectionOpenAI{
- Model: "custom-model",
- Presets: map[string]string{"fast": "gpt-3.5-turbo"},
- },
- want: "custom-model",
- },
- {
- name: "preset empty value returns original",
- section: sectionOpenAI{
- Model: "fast",
- Presets: map[string]string{"fast": ""},
- },
- want: "fast",
- },
- {
- name: "preset whitespace value returns original",
- section: sectionOpenAI{
- Model: "fast",
- Presets: map[string]string{"fast": " "},
- },
- want: "fast",
- },
- {
- name: "multiple presets uses correct one",
- section: sectionOpenAI{
- Model: "smart",
- Presets: map[string]string{
- "fast": "gpt-3.5-turbo",
- "smart": "gpt-4",
- "mini": "gpt-3.5-mini",
- },
- },
- want: "gpt-4",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got := tt.section.resolvedModel()
- if got != tt.want {
- t.Errorf("resolvedModel() = %q, want %q", got, tt.want)
- }
- })
- }
-}
-
-func TestParseSurfaceEntries(t *testing.T) {
- tests := []struct {
- name string
- input any
- wantLen int
- wantOK bool
- }{
- {
- name: "nil input",
- input: nil,
- wantLen: 0,
- wantOK: false,
- },
- {
- name: "single string",
- input: "gpt-4",
- wantLen: 1,
- wantOK: true,
- },
- {
- name: "single map",
- input: map[string]any{
- "model": "claude-3",
- "provider": "anthropic",
- },
- wantLen: 1,
- wantOK: true,
- },
- {
- name: "array of strings",
- input: []any{
- "gpt-4",
- "claude-3",
- },
- wantLen: 2,
- wantOK: true,
- },
- {
- name: "array of maps",
- input: []any{
- map[string]any{"model": "gpt-4", "provider": "openai"},
- map[string]any{"model": "claude-3", "provider": "anthropic"},
- },
- wantLen: 2,
- wantOK: true,
- },
- {
- name: "array with invalid entries",
- input: []any{
- "gpt-4",
- 123,
- "claude-3",
- },
- wantLen: 2,
- wantOK: true,
- },
- {
- name: "array with all invalid entries",
- input: []any{
- 123,
- true,
- nil,
- },
- wantLen: 0,
- wantOK: false,
- },
- {
- name: "empty array",
- input: []any{},
- wantLen: 0,
- wantOK: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, ok := parseSurfaceEntries(tt.input, "test", newLogger())
- if ok != tt.wantOK {
- t.Errorf("parseSurfaceEntries() ok = %v, want %v", ok, tt.wantOK)
- }
- if len(got) != tt.wantLen {
- t.Errorf("parseSurfaceEntries() len = %d, want %d", len(got), tt.wantLen)
- }
- })
- }
-}
diff --git a/internal/mcp/handlers_test.go b/internal/mcp/handlers_prompt_test.go
index 2a4f821..ad0d261 100644
--- a/internal/mcp/handlers_test.go
+++ b/internal/mcp/handlers_prompt_test.go
@@ -1,4 +1,4 @@
-// Summary: Tests for MCP prompt management handlers
+// Summary: Tests for MCP prompt management handlers (create, update, delete, get, list)
package mcp
import (
@@ -237,7 +237,7 @@ func TestServer_PromptsCreate_MissingName(t *testing.T) {
}
}
-// Update mockPromptStore to support Create, Update, Delete
+// mockPromptStore Create, Update, Delete methods used by prompt handler tests.
func (m *mockPromptStore) Create(prompt *promptstore.Prompt) error {
if _, exists := m.prompts[prompt.Name]; exists {
return fmt.Errorf("prompt already exists: %s", prompt.Name)
@@ -953,698 +953,3 @@ func TestServer_HandleInitialize_InvalidParams(t *testing.T) {
t.Fatal("Expected error for invalid params")
}
}
-
-// ==================== Tools Tests ====================
-
-func TestServer_ToolsList(t *testing.T) {
- store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
- server, _, outBuf := createTestServer(t, store)
-
- // Initialize server
- server.mu.Lock()
- server.initialized = true
- server.mu.Unlock()
-
- // Send tools/list request
- req := Request{
- JSONRPC: "2.0",
- ID: 40,
- Method: "tools/list",
- Params: json.RawMessage(`{}`),
- }
-
- server.handle(req)
-
- // Read response
- resp, err := readResponse(outBuf)
- if err != nil {
- t.Fatalf("readResponse() error = %v", err)
- }
-
- if resp.Error != nil {
- t.Fatalf("Error = %v, want nil", resp.Error)
- }
-
- // Parse result
- resultBytes, _ := json.Marshal(resp.Result)
- var result ListToolsResult
- if err := json.Unmarshal(resultBytes, &result); err != nil {
- t.Fatalf("Unmarshal result error = %v", err)
- }
-
- // Verify 3 tools returned
- if len(result.Tools) != 3 {
- t.Errorf("Tools count = %d, want 3", len(result.Tools))
- }
-
- // Verify tool names
- toolNames := make(map[string]bool)
- for _, tool := range result.Tools {
- toolNames[tool.Name] = true
- if tool.Description == "" {
- t.Errorf("Tool %s has empty description", tool.Name)
- }
- if tool.InputSchema == nil {
- t.Errorf("Tool %s has nil InputSchema", tool.Name)
- }
- }
-
- expectedTools := []string{"create_prompt", "update_prompt", "delete_prompt"}
- for _, name := range expectedTools {
- if !toolNames[name] {
- t.Errorf("Missing expected tool: %s", name)
- }
- }
-}
-
-func TestServer_ToolsCall_CreatePrompt(t *testing.T) {
- store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
- server, _, outBuf := createTestServer(t, store)
-
- // Initialize server
- server.mu.Lock()
- server.initialized = true
- server.mu.Unlock()
-
- // Send tools/call request
- params := CallToolRequest{
- Name: "create_prompt",
- Arguments: map[string]interface{}{
- "name": "tool_test",
- "title": "Tool Test Prompt",
- "messages": []interface{}{
- map[string]interface{}{
- "role": "user",
- "content": map[string]interface{}{
- "type": "text",
- "text": "Test message",
- },
- },
- },
- },
- }
- paramsBytes, _ := json.Marshal(params)
- req := Request{
- JSONRPC: "2.0",
- ID: 41,
- Method: "tools/call",
- Params: paramsBytes,
- }
-
- server.handle(req)
-
- // Read response
- resp, err := readResponse(outBuf)
- if err != nil {
- t.Fatalf("readResponse() error = %v", err)
- }
-
- if resp.Error != nil {
- t.Fatalf("Error = %v, want nil", resp.Error)
- }
-
- // Parse result
- resultBytes, _ := json.Marshal(resp.Result)
- var result CallToolResult
- if err := json.Unmarshal(resultBytes, &result); err != nil {
- t.Fatalf("Unmarshal result error = %v", err)
- }
-
- if result.IsError {
- t.Errorf("IsError = true, want false. Content: %v", result.Content)
- }
-
- if len(result.Content) == 0 {
- t.Fatal("Expected content in result")
- }
-
- // Verify prompt was created
- if _, exists := store.prompts["tool_test"]; !exists {
- t.Error("Prompt was not created in store")
- }
-}
-
-func TestServer_ToolsCall_UpdatePrompt(t *testing.T) {
- now := time.Now()
- store := &mockPromptStore{
- prompts: map[string]*promptstore.Prompt{
- "tool_update": {
- Name: "tool_update",
- Title: "Original Title",
- Created: now,
- Updated: now,
- Messages: []promptstore.PromptMessage{
- {
- Role: "user",
- Content: promptstore.MessageContent{
- Type: "text",
- Text: "Original",
- },
- },
- },
- },
- },
- }
-
- server, _, outBuf := createTestServer(t, store)
-
- // Initialize server
- server.mu.Lock()
- server.initialized = true
- server.mu.Unlock()
-
- // Send tools/call request
- params := CallToolRequest{
- Name: "update_prompt",
- Arguments: map[string]interface{}{
- "name": "tool_update",
- "title": "Updated Title",
- },
- }
- paramsBytes, _ := json.Marshal(params)
- req := Request{
- JSONRPC: "2.0",
- ID: 42,
- Method: "tools/call",
- Params: paramsBytes,
- }
-
- server.handle(req)
-
- // Read response
- resp, err := readResponse(outBuf)
- if err != nil {
- t.Fatalf("readResponse() error = %v", err)
- }
-
- if resp.Error != nil {
- t.Fatalf("Error = %v, want nil", resp.Error)
- }
-
- // Parse result
- resultBytes, _ := json.Marshal(resp.Result)
- var result CallToolResult
- if err := json.Unmarshal(resultBytes, &result); err != nil {
- t.Fatalf("Unmarshal result error = %v", err)
- }
-
- if result.IsError {
- t.Errorf("IsError = true, want false. Content: %v", result.Content)
- }
-
- // Verify prompt was updated
- if store.prompts["tool_update"].Title != "Updated Title" {
- t.Errorf("Title not updated, got %s", store.prompts["tool_update"].Title)
- }
-}
-
-func TestServer_ToolsCall_DeletePrompt(t *testing.T) {
- now := time.Now()
- store := &mockPromptStore{
- prompts: map[string]*promptstore.Prompt{
- "tool_delete": {
- Name: "tool_delete",
- Title: "To Delete",
- Created: now,
- Updated: now,
- Messages: []promptstore.PromptMessage{
- {
- Role: "user",
- Content: promptstore.MessageContent{
- Type: "text",
- Text: "Test",
- },
- },
- },
- },
- },
- }
-
- server, _, outBuf := createTestServer(t, store)
-
- // Initialize server
- server.mu.Lock()
- server.initialized = true
- server.mu.Unlock()
-
- // Send tools/call request
- params := CallToolRequest{
- Name: "delete_prompt",
- Arguments: map[string]interface{}{
- "name": "tool_delete",
- },
- }
- paramsBytes, _ := json.Marshal(params)
- req := Request{
- JSONRPC: "2.0",
- ID: 43,
- Method: "tools/call",
- Params: paramsBytes,
- }
-
- server.handle(req)
-
- // Read response
- resp, err := readResponse(outBuf)
- if err != nil {
- t.Fatalf("readResponse() error = %v", err)
- }
-
- if resp.Error != nil {
- t.Fatalf("Error = %v, want nil", resp.Error)
- }
-
- // Parse result
- resultBytes, _ := json.Marshal(resp.Result)
- var result CallToolResult
- if err := json.Unmarshal(resultBytes, &result); err != nil {
- t.Fatalf("Unmarshal result error = %v", err)
- }
-
- if result.IsError {
- t.Errorf("IsError = true, want false. Content: %v", result.Content)
- }
-
- // Verify prompt was deleted
- if _, exists := store.prompts["tool_delete"]; exists {
- t.Error("Prompt was not deleted from store")
- }
-}
-
-func TestServer_ToolsCall_UnknownTool(t *testing.T) {
- store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
- server, _, outBuf := createTestServer(t, store)
-
- // Initialize server
- server.mu.Lock()
- server.initialized = true
- server.mu.Unlock()
-
- // Send tools/call request with unknown tool
- params := CallToolRequest{
- Name: "nonexistent_tool",
- Arguments: map[string]interface{}{},
- }
- paramsBytes, _ := json.Marshal(params)
- req := Request{
- JSONRPC: "2.0",
- ID: 44,
- Method: "tools/call",
- Params: paramsBytes,
- }
-
- server.handle(req)
-
- // Read response
- resp, err := readResponse(outBuf)
- if err != nil {
- t.Fatalf("readResponse() error = %v", err)
- }
-
- // Should not be a protocol error
- if resp.Error != nil {
- t.Fatalf("Unexpected protocol error: %v", resp.Error)
- }
-
- // Parse result - should be a tool error
- resultBytes, _ := json.Marshal(resp.Result)
- var result CallToolResult
- if err := json.Unmarshal(resultBytes, &result); err != nil {
- t.Fatalf("Unmarshal result error = %v", err)
- }
-
- if !result.IsError {
- t.Error("Expected IsError = true for unknown tool")
- }
-}
-
-func TestServer_ToolsCall_InvalidArguments(t *testing.T) {
- store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
- server, _, outBuf := createTestServer(t, store)
-
- // Initialize server
- server.mu.Lock()
- server.initialized = true
- server.mu.Unlock()
-
- // Send tools/call with invalid arguments (missing required fields)
- params := CallToolRequest{
- Name: "create_prompt",
- Arguments: map[string]interface{}{
- "name": "test", // Missing title and messages
- },
- }
- paramsBytes, _ := json.Marshal(params)
- req := Request{
- JSONRPC: "2.0",
- ID: 45,
- Method: "tools/call",
- Params: paramsBytes,
- }
-
- server.handle(req)
-
- // Read response
- resp, err := readResponse(outBuf)
- if err != nil {
- t.Fatalf("readResponse() error = %v", err)
- }
-
- // Should not be a protocol error
- if resp.Error != nil {
- t.Fatalf("Unexpected protocol error: %v", resp.Error)
- }
-
- // Parse result - should be a tool error
- resultBytes, _ := json.Marshal(resp.Result)
- var result CallToolResult
- if err := json.Unmarshal(resultBytes, &result); err != nil {
- t.Fatalf("Unmarshal result error = %v", err)
- }
-
- if !result.IsError {
- t.Error("Expected IsError = true for invalid arguments")
- }
-}
-
-func TestServer_ToolsCall_NotInitialized(t *testing.T) {
- store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
- server, _, outBuf := createTestServer(t, store)
-
- // DO NOT initialize server
-
- // Send tools/call request
- params := CallToolRequest{
- Name: "create_prompt",
- Arguments: map[string]interface{}{},
- }
- paramsBytes, _ := json.Marshal(params)
- req := Request{
- JSONRPC: "2.0",
- ID: 46,
- Method: "tools/call",
- Params: paramsBytes,
- }
-
- server.handle(req)
-
- // Read response
- resp, err := readResponse(outBuf)
- if err != nil {
- t.Fatalf("readResponse() error = %v", err)
- }
-
- if resp.Error == nil {
- t.Fatal("Expected error for uninitialized server")
- }
-
- if resp.Error.Code != ErrCodeInvalidRequest {
- t.Errorf("Error code = %d, want %d", resp.Error.Code, ErrCodeInvalidRequest)
- }
-}
-
-func TestServer_ToolsCall_CreatePrompt_AlreadyExists(t *testing.T) {
- now := time.Now()
- store := &mockPromptStore{
- prompts: map[string]*promptstore.Prompt{
- "existing": {
- Name: "existing",
- Title: "Existing",
- Created: now,
- Updated: now,
- Messages: []promptstore.PromptMessage{},
- },
- },
- }
-
- server, _, outBuf := createTestServer(t, store)
-
- // Initialize server
- server.mu.Lock()
- server.initialized = true
- server.mu.Unlock()
-
- // Send tools/call to create duplicate
- params := CallToolRequest{
- Name: "create_prompt",
- Arguments: map[string]interface{}{
- "name": "existing",
- "title": "Duplicate",
- "messages": []interface{}{
- map[string]interface{}{
- "role": "user",
- "content": map[string]interface{}{
- "type": "text",
- "text": "Test",
- },
- },
- },
- },
- }
- paramsBytes, _ := json.Marshal(params)
- req := Request{
- JSONRPC: "2.0",
- ID: 47,
- Method: "tools/call",
- Params: paramsBytes,
- }
-
- server.handle(req)
-
- // Read response
- resp, err := readResponse(outBuf)
- if err != nil {
- t.Fatalf("readResponse() error = %v", err)
- }
-
- // Should not be a protocol error
- if resp.Error != nil {
- t.Fatalf("Unexpected protocol error: %v", resp.Error)
- }
-
- // Parse result - should be a tool error
- resultBytes, _ := json.Marshal(resp.Result)
- var result CallToolResult
- if err := json.Unmarshal(resultBytes, &result); err != nil {
- t.Fatalf("Unmarshal result error = %v", err)
- }
-
- if !result.IsError {
- t.Error("Expected IsError = true for duplicate prompt")
- }
-}
-
-func TestServer_Initialize_AdvertisesTools(t *testing.T) {
- store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
- server, _, outBuf := createTestServer(t, store)
-
- // Send initialize request
- params := InitializeRequest{
- ProtocolVersion: "2025-11-25",
- Capabilities: ClientCapabilities{},
- ClientInfo: ClientInfo{
- Name: "test-client",
- Version: "1.0",
- },
- }
- paramsBytes, _ := json.Marshal(params)
- req := Request{
- JSONRPC: "2.0",
- ID: 48,
- Method: "initialize",
- Params: paramsBytes,
- }
-
- server.handle(req)
-
- // Read response
- resp, err := readResponse(outBuf)
- if err != nil {
- t.Fatalf("readResponse() error = %v", err)
- }
-
- if resp.Error != nil {
- t.Fatalf("Error = %v, want nil", resp.Error)
- }
-
- // Parse result
- resultBytes, _ := json.Marshal(resp.Result)
- var result InitializeResult
- if err := json.Unmarshal(resultBytes, &result); err != nil {
- t.Fatalf("Unmarshal result error = %v", err)
- }
-
- // Verify Tools capability is advertised
- if result.Capabilities.Tools == nil {
- t.Fatal("Tools capability not advertised")
- }
-
- // Verify Prompts capability is also still advertised
- if result.Capabilities.Prompts == nil {
- t.Fatal("Prompts capability not advertised")
- }
-}
-
-func TestServer_ToolsList_NotInitialized(t *testing.T) {
- store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
- server, _, outBuf := createTestServer(t, store)
-
- // DO NOT initialize server
-
- // Send tools/list request
- req := Request{
- JSONRPC: "2.0",
- ID: 49,
- Method: "tools/list",
- Params: json.RawMessage(`{}`),
- }
-
- server.handle(req)
-
- // Read response
- resp, err := readResponse(outBuf)
- if err != nil {
- t.Fatalf("readResponse() error = %v", err)
- }
-
- if resp.Error == nil {
- t.Fatal("Expected error for uninitialized server")
- }
-
- if resp.Error.Code != ErrCodeInvalidRequest {
- t.Errorf("Error code = %d, want %d", resp.Error.Code, ErrCodeInvalidRequest)
- }
-}
-
-func TestServer_ToolsCall_UpdatePrompt_NotFound(t *testing.T) {
- store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
- server, _, outBuf := createTestServer(t, store)
-
- // Initialize server
- server.mu.Lock()
- server.initialized = true
- server.mu.Unlock()
-
- // Send tools/call to update non-existent prompt
- params := CallToolRequest{
- Name: "update_prompt",
- Arguments: map[string]interface{}{
- "name": "nonexistent",
- "title": "Updated",
- },
- }
- paramsBytes, _ := json.Marshal(params)
- req := Request{
- JSONRPC: "2.0",
- ID: 50,
- Method: "tools/call",
- Params: paramsBytes,
- }
-
- server.handle(req)
-
- // Read response
- resp, err := readResponse(outBuf)
- if err != nil {
- t.Fatalf("readResponse() error = %v", err)
- }
-
- // Should not be a protocol error
- if resp.Error != nil {
- t.Fatalf("Unexpected protocol error: %v", resp.Error)
- }
-
- // Parse result - should be a tool error
- resultBytes, _ := json.Marshal(resp.Result)
- var result CallToolResult
- if err := json.Unmarshal(resultBytes, &result); err != nil {
- t.Fatalf("Unmarshal result error = %v", err)
- }
-
- if !result.IsError {
- t.Error("Expected IsError = true for non-existent prompt")
- }
-}
-
-func TestServer_ToolsCall_DeletePrompt_NotFound(t *testing.T) {
- store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
- server, _, outBuf := createTestServer(t, store)
-
- // Initialize server
- server.mu.Lock()
- server.initialized = true
- server.mu.Unlock()
-
- // Send tools/call to delete non-existent prompt
- params := CallToolRequest{
- Name: "delete_prompt",
- Arguments: map[string]interface{}{
- "name": "nonexistent",
- },
- }
- paramsBytes, _ := json.Marshal(params)
- req := Request{
- JSONRPC: "2.0",
- ID: 51,
- Method: "tools/call",
- Params: paramsBytes,
- }
-
- server.handle(req)
-
- // Read response
- resp, err := readResponse(outBuf)
- if err != nil {
- t.Fatalf("readResponse() error = %v", err)
- }
-
- // Should not be a protocol error
- if resp.Error != nil {
- t.Fatalf("Unexpected protocol error: %v", resp.Error)
- }
-
- // Parse result - should be a tool error
- resultBytes, _ := json.Marshal(resp.Result)
- var result CallToolResult
- if err := json.Unmarshal(resultBytes, &result); err != nil {
- t.Fatalf("Unmarshal result error = %v", err)
- }
-
- if !result.IsError {
- t.Error("Expected IsError = true for non-existent prompt")
- }
-}
-
-func TestServer_ToolsCall_InvalidParams(t *testing.T) {
- store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
- server, _, outBuf := createTestServer(t, store)
-
- // Initialize server
- server.mu.Lock()
- server.initialized = true
- server.mu.Unlock()
-
- // Send tools/call with invalid JSON params
- req := Request{
- JSONRPC: "2.0",
- ID: 52,
- Method: "tools/call",
- Params: json.RawMessage(`{invalid}`),
- }
-
- server.handle(req)
-
- // Read response
- resp, err := readResponse(outBuf)
- if err != nil {
- t.Fatalf("readResponse() error = %v", err)
- }
-
- if resp.Error == nil {
- t.Fatal("Expected protocol error for invalid params")
- }
-
- if resp.Error.Code != ErrCodeInvalidParams {
- t.Errorf("Error code = %d, want %d", resp.Error.Code, ErrCodeInvalidParams)
- }
-}
diff --git a/internal/mcp/handlers_tool_test.go b/internal/mcp/handlers_tool_test.go
new file mode 100644
index 0000000..604938a
--- /dev/null
+++ b/internal/mcp/handlers_tool_test.go
@@ -0,0 +1,703 @@
+// Summary: Tests for MCP tool handlers (tools/list, tools/call for create/update/delete)
+package mcp
+
+import (
+ "encoding/json"
+ "testing"
+ "time"
+
+ "codeberg.org/snonux/hexai/internal/promptstore"
+)
+
+func TestServer_ToolsList(t *testing.T) {
+ store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
+ server, _, outBuf := createTestServer(t, store)
+
+ // Initialize server
+ server.mu.Lock()
+ server.initialized = true
+ server.mu.Unlock()
+
+ // Send tools/list request
+ req := Request{
+ JSONRPC: "2.0",
+ ID: 40,
+ Method: "tools/list",
+ Params: json.RawMessage(`{}`),
+ }
+
+ server.handle(req)
+
+ // Read response
+ resp, err := readResponse(outBuf)
+ if err != nil {
+ t.Fatalf("readResponse() error = %v", err)
+ }
+
+ if resp.Error != nil {
+ t.Fatalf("Error = %v, want nil", resp.Error)
+ }
+
+ // Parse result
+ resultBytes, _ := json.Marshal(resp.Result)
+ var result ListToolsResult
+ if err := json.Unmarshal(resultBytes, &result); err != nil {
+ t.Fatalf("Unmarshal result error = %v", err)
+ }
+
+ // Verify 3 tools returned
+ if len(result.Tools) != 3 {
+ t.Errorf("Tools count = %d, want 3", len(result.Tools))
+ }
+
+ // Verify tool names
+ toolNames := make(map[string]bool)
+ for _, tool := range result.Tools {
+ toolNames[tool.Name] = true
+ if tool.Description == "" {
+ t.Errorf("Tool %s has empty description", tool.Name)
+ }
+ if tool.InputSchema == nil {
+ t.Errorf("Tool %s has nil InputSchema", tool.Name)
+ }
+ }
+
+ expectedTools := []string{"create_prompt", "update_prompt", "delete_prompt"}
+ for _, name := range expectedTools {
+ if !toolNames[name] {
+ t.Errorf("Missing expected tool: %s", name)
+ }
+ }
+}
+
+func TestServer_ToolsCall_CreatePrompt(t *testing.T) {
+ store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
+ server, _, outBuf := createTestServer(t, store)
+
+ // Initialize server
+ server.mu.Lock()
+ server.initialized = true
+ server.mu.Unlock()
+
+ // Send tools/call request
+ params := CallToolRequest{
+ Name: "create_prompt",
+ Arguments: map[string]interface{}{
+ "name": "tool_test",
+ "title": "Tool Test Prompt",
+ "messages": []interface{}{
+ map[string]interface{}{
+ "role": "user",
+ "content": map[string]interface{}{
+ "type": "text",
+ "text": "Test message",
+ },
+ },
+ },
+ },
+ }
+ paramsBytes, _ := json.Marshal(params)
+ req := Request{
+ JSONRPC: "2.0",
+ ID: 41,
+ Method: "tools/call",
+ Params: paramsBytes,
+ }
+
+ server.handle(req)
+
+ // Read response
+ resp, err := readResponse(outBuf)
+ if err != nil {
+ t.Fatalf("readResponse() error = %v", err)
+ }
+
+ if resp.Error != nil {
+ t.Fatalf("Error = %v, want nil", resp.Error)
+ }
+
+ // Parse result
+ resultBytes, _ := json.Marshal(resp.Result)
+ var result CallToolResult
+ if err := json.Unmarshal(resultBytes, &result); err != nil {
+ t.Fatalf("Unmarshal result error = %v", err)
+ }
+
+ if result.IsError {
+ t.Errorf("IsError = true, want false. Content: %v", result.Content)
+ }
+
+ if len(result.Content) == 0 {
+ t.Fatal("Expected content in result")
+ }
+
+ // Verify prompt was created
+ if _, exists := store.prompts["tool_test"]; !exists {
+ t.Error("Prompt was not created in store")
+ }
+}
+
+func TestServer_ToolsCall_UpdatePrompt(t *testing.T) {
+ now := time.Now()
+ store := &mockPromptStore{
+ prompts: map[string]*promptstore.Prompt{
+ "tool_update": {
+ Name: "tool_update",
+ Title: "Original Title",
+ Created: now,
+ Updated: now,
+ Messages: []promptstore.PromptMessage{
+ {
+ Role: "user",
+ Content: promptstore.MessageContent{
+ Type: "text",
+ Text: "Original",
+ },
+ },
+ },
+ },
+ },
+ }
+
+ server, _, outBuf := createTestServer(t, store)
+
+ // Initialize server
+ server.mu.Lock()
+ server.initialized = true
+ server.mu.Unlock()
+
+ // Send tools/call request
+ params := CallToolRequest{
+ Name: "update_prompt",
+ Arguments: map[string]interface{}{
+ "name": "tool_update",
+ "title": "Updated Title",
+ },
+ }
+ paramsBytes, _ := json.Marshal(params)
+ req := Request{
+ JSONRPC: "2.0",
+ ID: 42,
+ Method: "tools/call",
+ Params: paramsBytes,
+ }
+
+ server.handle(req)
+
+ // Read response
+ resp, err := readResponse(outBuf)
+ if err != nil {
+ t.Fatalf("readResponse() error = %v", err)
+ }
+
+ if resp.Error != nil {
+ t.Fatalf("Error = %v, want nil", resp.Error)
+ }
+
+ // Parse result
+ resultBytes, _ := json.Marshal(resp.Result)
+ var result CallToolResult
+ if err := json.Unmarshal(resultBytes, &result); err != nil {
+ t.Fatalf("Unmarshal result error = %v", err)
+ }
+
+ if result.IsError {
+ t.Errorf("IsError = true, want false. Content: %v", result.Content)
+ }
+
+ // Verify prompt was updated
+ if store.prompts["tool_update"].Title != "Updated Title" {
+ t.Errorf("Title not updated, got %s", store.prompts["tool_update"].Title)
+ }
+}
+
+func TestServer_ToolsCall_DeletePrompt(t *testing.T) {
+ now := time.Now()
+ store := &mockPromptStore{
+ prompts: map[string]*promptstore.Prompt{
+ "tool_delete": {
+ Name: "tool_delete",
+ Title: "To Delete",
+ Created: now,
+ Updated: now,
+ Messages: []promptstore.PromptMessage{
+ {
+ Role: "user",
+ Content: promptstore.MessageContent{
+ Type: "text",
+ Text: "Test",
+ },
+ },
+ },
+ },
+ },
+ }
+
+ server, _, outBuf := createTestServer(t, store)
+
+ // Initialize server
+ server.mu.Lock()
+ server.initialized = true
+ server.mu.Unlock()
+
+ // Send tools/call request
+ params := CallToolRequest{
+ Name: "delete_prompt",
+ Arguments: map[string]interface{}{
+ "name": "tool_delete",
+ },
+ }
+ paramsBytes, _ := json.Marshal(params)
+ req := Request{
+ JSONRPC: "2.0",
+ ID: 43,
+ Method: "tools/call",
+ Params: paramsBytes,
+ }
+
+ server.handle(req)
+
+ // Read response
+ resp, err := readResponse(outBuf)
+ if err != nil {
+ t.Fatalf("readResponse() error = %v", err)
+ }
+
+ if resp.Error != nil {
+ t.Fatalf("Error = %v, want nil", resp.Error)
+ }
+
+ // Parse result
+ resultBytes, _ := json.Marshal(resp.Result)
+ var result CallToolResult
+ if err := json.Unmarshal(resultBytes, &result); err != nil {
+ t.Fatalf("Unmarshal result error = %v", err)
+ }
+
+ if result.IsError {
+ t.Errorf("IsError = true, want false. Content: %v", result.Content)
+ }
+
+ // Verify prompt was deleted
+ if _, exists := store.prompts["tool_delete"]; exists {
+ t.Error("Prompt was not deleted from store")
+ }
+}
+
+func TestServer_ToolsCall_UnknownTool(t *testing.T) {
+ store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
+ server, _, outBuf := createTestServer(t, store)
+
+ // Initialize server
+ server.mu.Lock()
+ server.initialized = true
+ server.mu.Unlock()
+
+ // Send tools/call request with unknown tool
+ params := CallToolRequest{
+ Name: "nonexistent_tool",
+ Arguments: map[string]interface{}{},
+ }
+ paramsBytes, _ := json.Marshal(params)
+ req := Request{
+ JSONRPC: "2.0",
+ ID: 44,
+ Method: "tools/call",
+ Params: paramsBytes,
+ }
+
+ server.handle(req)
+
+ // Read response
+ resp, err := readResponse(outBuf)
+ if err != nil {
+ t.Fatalf("readResponse() error = %v", err)
+ }
+
+ // Should not be a protocol error
+ if resp.Error != nil {
+ t.Fatalf("Unexpected protocol error: %v", resp.Error)
+ }
+
+ // Parse result - should be a tool error
+ resultBytes, _ := json.Marshal(resp.Result)
+ var result CallToolResult
+ if err := json.Unmarshal(resultBytes, &result); err != nil {
+ t.Fatalf("Unmarshal result error = %v", err)
+ }
+
+ if !result.IsError {
+ t.Error("Expected IsError = true for unknown tool")
+ }
+}
+
+func TestServer_ToolsCall_InvalidArguments(t *testing.T) {
+ store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
+ server, _, outBuf := createTestServer(t, store)
+
+ // Initialize server
+ server.mu.Lock()
+ server.initialized = true
+ server.mu.Unlock()
+
+ // Send tools/call with invalid arguments (missing required fields)
+ params := CallToolRequest{
+ Name: "create_prompt",
+ Arguments: map[string]interface{}{
+ "name": "test", // Missing title and messages
+ },
+ }
+ paramsBytes, _ := json.Marshal(params)
+ req := Request{
+ JSONRPC: "2.0",
+ ID: 45,
+ Method: "tools/call",
+ Params: paramsBytes,
+ }
+
+ server.handle(req)
+
+ // Read response
+ resp, err := readResponse(outBuf)
+ if err != nil {
+ t.Fatalf("readResponse() error = %v", err)
+ }
+
+ // Should not be a protocol error
+ if resp.Error != nil {
+ t.Fatalf("Unexpected protocol error: %v", resp.Error)
+ }
+
+ // Parse result - should be a tool error
+ resultBytes, _ := json.Marshal(resp.Result)
+ var result CallToolResult
+ if err := json.Unmarshal(resultBytes, &result); err != nil {
+ t.Fatalf("Unmarshal result error = %v", err)
+ }
+
+ if !result.IsError {
+ t.Error("Expected IsError = true for invalid arguments")
+ }
+}
+
+func TestServer_ToolsCall_NotInitialized(t *testing.T) {
+ store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
+ server, _, outBuf := createTestServer(t, store)
+
+ // DO NOT initialize server
+
+ // Send tools/call request
+ params := CallToolRequest{
+ Name: "create_prompt",
+ Arguments: map[string]interface{}{},
+ }
+ paramsBytes, _ := json.Marshal(params)
+ req := Request{
+ JSONRPC: "2.0",
+ ID: 46,
+ Method: "tools/call",
+ Params: paramsBytes,
+ }
+
+ server.handle(req)
+
+ // Read response
+ resp, err := readResponse(outBuf)
+ if err != nil {
+ t.Fatalf("readResponse() error = %v", err)
+ }
+
+ if resp.Error == nil {
+ t.Fatal("Expected error for uninitialized server")
+ }
+
+ if resp.Error.Code != ErrCodeInvalidRequest {
+ t.Errorf("Error code = %d, want %d", resp.Error.Code, ErrCodeInvalidRequest)
+ }
+}
+
+func TestServer_ToolsCall_CreatePrompt_AlreadyExists(t *testing.T) {
+ now := time.Now()
+ store := &mockPromptStore{
+ prompts: map[string]*promptstore.Prompt{
+ "existing": {
+ Name: "existing",
+ Title: "Existing",
+ Created: now,
+ Updated: now,
+ Messages: []promptstore.PromptMessage{},
+ },
+ },
+ }
+
+ server, _, outBuf := createTestServer(t, store)
+
+ // Initialize server
+ server.mu.Lock()
+ server.initialized = true
+ server.mu.Unlock()
+
+ // Send tools/call to create duplicate
+ params := CallToolRequest{
+ Name: "create_prompt",
+ Arguments: map[string]interface{}{
+ "name": "existing",
+ "title": "Duplicate",
+ "messages": []interface{}{
+ map[string]interface{}{
+ "role": "user",
+ "content": map[string]interface{}{
+ "type": "text",
+ "text": "Test",
+ },
+ },
+ },
+ },
+ }
+ paramsBytes, _ := json.Marshal(params)
+ req := Request{
+ JSONRPC: "2.0",
+ ID: 47,
+ Method: "tools/call",
+ Params: paramsBytes,
+ }
+
+ server.handle(req)
+
+ // Read response
+ resp, err := readResponse(outBuf)
+ if err != nil {
+ t.Fatalf("readResponse() error = %v", err)
+ }
+
+ // Should not be a protocol error
+ if resp.Error != nil {
+ t.Fatalf("Unexpected protocol error: %v", resp.Error)
+ }
+
+ // Parse result - should be a tool error
+ resultBytes, _ := json.Marshal(resp.Result)
+ var result CallToolResult
+ if err := json.Unmarshal(resultBytes, &result); err != nil {
+ t.Fatalf("Unmarshal result error = %v", err)
+ }
+
+ if !result.IsError {
+ t.Error("Expected IsError = true for duplicate prompt")
+ }
+}
+
+func TestServer_Initialize_AdvertisesTools(t *testing.T) {
+ store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
+ server, _, outBuf := createTestServer(t, store)
+
+ // Send initialize request
+ params := InitializeRequest{
+ ProtocolVersion: "2025-11-25",
+ Capabilities: ClientCapabilities{},
+ ClientInfo: ClientInfo{
+ Name: "test-client",
+ Version: "1.0",
+ },
+ }
+ paramsBytes, _ := json.Marshal(params)
+ req := Request{
+ JSONRPC: "2.0",
+ ID: 48,
+ Method: "initialize",
+ Params: paramsBytes,
+ }
+
+ server.handle(req)
+
+ // Read response
+ resp, err := readResponse(outBuf)
+ if err != nil {
+ t.Fatalf("readResponse() error = %v", err)
+ }
+
+ if resp.Error != nil {
+ t.Fatalf("Error = %v, want nil", resp.Error)
+ }
+
+ // Parse result
+ resultBytes, _ := json.Marshal(resp.Result)
+ var result InitializeResult
+ if err := json.Unmarshal(resultBytes, &result); err != nil {
+ t.Fatalf("Unmarshal result error = %v", err)
+ }
+
+ // Verify Tools capability is advertised
+ if result.Capabilities.Tools == nil {
+ t.Fatal("Tools capability not advertised")
+ }
+
+ // Verify Prompts capability is also still advertised
+ if result.Capabilities.Prompts == nil {
+ t.Fatal("Prompts capability not advertised")
+ }
+}
+
+func TestServer_ToolsList_NotInitialized(t *testing.T) {
+ store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
+ server, _, outBuf := createTestServer(t, store)
+
+ // DO NOT initialize server
+
+ // Send tools/list request
+ req := Request{
+ JSONRPC: "2.0",
+ ID: 49,
+ Method: "tools/list",
+ Params: json.RawMessage(`{}`),
+ }
+
+ server.handle(req)
+
+ // Read response
+ resp, err := readResponse(outBuf)
+ if err != nil {
+ t.Fatalf("readResponse() error = %v", err)
+ }
+
+ if resp.Error == nil {
+ t.Fatal("Expected error for uninitialized server")
+ }
+
+ if resp.Error.Code != ErrCodeInvalidRequest {
+ t.Errorf("Error code = %d, want %d", resp.Error.Code, ErrCodeInvalidRequest)
+ }
+}
+
+func TestServer_ToolsCall_UpdatePrompt_NotFound(t *testing.T) {
+ store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
+ server, _, outBuf := createTestServer(t, store)
+
+ // Initialize server
+ server.mu.Lock()
+ server.initialized = true
+ server.mu.Unlock()
+
+ // Send tools/call to update non-existent prompt
+ params := CallToolRequest{
+ Name: "update_prompt",
+ Arguments: map[string]interface{}{
+ "name": "nonexistent",
+ "title": "Updated",
+ },
+ }
+ paramsBytes, _ := json.Marshal(params)
+ req := Request{
+ JSONRPC: "2.0",
+ ID: 50,
+ Method: "tools/call",
+ Params: paramsBytes,
+ }
+
+ server.handle(req)
+
+ // Read response
+ resp, err := readResponse(outBuf)
+ if err != nil {
+ t.Fatalf("readResponse() error = %v", err)
+ }
+
+ // Should not be a protocol error
+ if resp.Error != nil {
+ t.Fatalf("Unexpected protocol error: %v", resp.Error)
+ }
+
+ // Parse result - should be a tool error
+ resultBytes, _ := json.Marshal(resp.Result)
+ var result CallToolResult
+ if err := json.Unmarshal(resultBytes, &result); err != nil {
+ t.Fatalf("Unmarshal result error = %v", err)
+ }
+
+ if !result.IsError {
+ t.Error("Expected IsError = true for non-existent prompt")
+ }
+}
+
+func TestServer_ToolsCall_DeletePrompt_NotFound(t *testing.T) {
+ store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
+ server, _, outBuf := createTestServer(t, store)
+
+ // Initialize server
+ server.mu.Lock()
+ server.initialized = true
+ server.mu.Unlock()
+
+ // Send tools/call to delete non-existent prompt
+ params := CallToolRequest{
+ Name: "delete_prompt",
+ Arguments: map[string]interface{}{
+ "name": "nonexistent",
+ },
+ }
+ paramsBytes, _ := json.Marshal(params)
+ req := Request{
+ JSONRPC: "2.0",
+ ID: 51,
+ Method: "tools/call",
+ Params: paramsBytes,
+ }
+
+ server.handle(req)
+
+ // Read response
+ resp, err := readResponse(outBuf)
+ if err != nil {
+ t.Fatalf("readResponse() error = %v", err)
+ }
+
+ // Should not be a protocol error
+ if resp.Error != nil {
+ t.Fatalf("Unexpected protocol error: %v", resp.Error)
+ }
+
+ // Parse result - should be a tool error
+ resultBytes, _ := json.Marshal(resp.Result)
+ var result CallToolResult
+ if err := json.Unmarshal(resultBytes, &result); err != nil {
+ t.Fatalf("Unmarshal result error = %v", err)
+ }
+
+ if !result.IsError {
+ t.Error("Expected IsError = true for non-existent prompt")
+ }
+}
+
+func TestServer_ToolsCall_InvalidParams(t *testing.T) {
+ store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
+ server, _, outBuf := createTestServer(t, store)
+
+ // Initialize server
+ server.mu.Lock()
+ server.initialized = true
+ server.mu.Unlock()
+
+ // Send tools/call with invalid JSON params
+ req := Request{
+ JSONRPC: "2.0",
+ ID: 52,
+ Method: "tools/call",
+ Params: json.RawMessage(`{invalid}`),
+ }
+
+ server.handle(req)
+
+ // Read response
+ resp, err := readResponse(outBuf)
+ if err != nil {
+ t.Fatalf("readResponse() error = %v", err)
+ }
+
+ if resp.Error == nil {
+ t.Fatal("Expected protocol error for invalid params")
+ }
+
+ if resp.Error.Code != ErrCodeInvalidParams {
+ t.Errorf("Error code = %d, want %d", resp.Error.Code, ErrCodeInvalidParams)
+ }
+}
diff --git a/internal/promptstore/default_prompts.go b/internal/promptstore/default_prompts.go
index f8edefc..434855a 100644
--- a/internal/promptstore/default_prompts.go
+++ b/internal/promptstore/default_prompts.go
@@ -6,34 +6,50 @@ import (
)
// DefaultPrompts returns the built-in meta-prompts for prompt management.
-// These prompts help users create, update, delete, and design prompts interactively
-// using Claude's access to conversation context.
+// It assembles prompts from category-specific helpers so each remains concise.
func DefaultPrompts() []Prompt {
now := time.Now()
-
return []Prompt{
- {
- Name: "save_prompt",
- Title: "Save Current Conversation as Prompt",
- Description: "Interactively create a new prompt template from the current conversation. Claude will analyze the conversation, ask clarifying questions about templating, show a preview, and wait for approval before saving.",
- Arguments: []PromptArgument{
- {
- Name: "prompt_name",
- Description: "Unique identifier for the new prompt (lowercase, underscores allowed)",
- Required: true,
- },
- {
- Name: "prompt_title",
- Description: "Human-readable display name for the new prompt",
- Required: true,
- },
- },
- Messages: []PromptMessage{
- {
- Role: "user",
- Content: MessageContent{
- Type: "text",
- Text: `I want to create a new prompt template named '{{prompt_name}}' with title '{{prompt_title}}'.
+ buildSavePrompt(now),
+ buildUpdatePrompt(now),
+ buildDeletePrompt(now),
+ buildDesignPrompt(now),
+ }
+}
+
+// formattingRules is the shared instruction block for clarifying-question formatting
+// used by the save and update meta-prompts.
+const formattingRules = `IMPORTANT FORMATTING RULES for clarifying questions:
+- Use numbered questions: 1), 2), 3)
+- ANY CHOICE MUST BE NUMBERED using combined format: 1a), 1b), 1c), 2a), 2b), etc.
+- NEVER use standalone letters like "a)" - always combine with question number
+- NEVER use dashes (-) or bullets (•) for options
+- Every option must be numbered for easy selection by the user`
+
+// buildSavePrompt creates the "save_prompt" meta-prompt that turns a conversation
+// into a reusable prompt template.
+func buildSavePrompt(now time.Time) Prompt {
+ return Prompt{
+ Name: "save_prompt",
+ Title: "Save Current Conversation as Prompt",
+ Description: "Interactively create a new prompt template from the current conversation. Claude will analyze the conversation, ask clarifying questions about templating, show a preview, and wait for approval before saving.",
+ Arguments: []PromptArgument{
+ {Name: "prompt_name", Description: "Unique identifier for the new prompt (lowercase, underscores allowed)", Required: true},
+ {Name: "prompt_title", Description: "Human-readable display name for the new prompt", Required: true},
+ },
+ Messages: []PromptMessage{{
+ Role: "user",
+ Content: MessageContent{Type: "text", Text: savePromptText()},
+ }},
+ Tags: []string{"meta", "prompt-management", "interactive"},
+ Created: now,
+ Updated: now,
+ }
+}
+
+// savePromptText returns the user message body for the save_prompt meta-prompt.
+func savePromptText() string {
+ return `I want to create a new prompt template named '{{prompt_name}}' with title '{{prompt_title}}'.
Please help me by:
1) Analyzing our current conversation to understand what should be templated
@@ -45,12 +61,7 @@ Please help me by:
3) Showing me a complete preview of the prompt structure in a code block
4) Only after I approve, use the create_prompt tool to save it
-IMPORTANT FORMATTING RULES for clarifying questions:
-- Use numbered questions: 1), 2), 3)
-- ANY CHOICE MUST BE NUMBERED using combined format: 1a), 1b), 1c), 2a), 2b), etc.
-- NEVER use standalone letters like "a)" - always combine with question number
-- NEVER use dashes (-) or bullets (•) for options
-- Every option must be numbered for easy selection by the user
+` + formattingRules + `
Examples:
1) Question Category
@@ -77,31 +88,32 @@ Examples:
4b) Sub-question two?
Answer options here
-Start by examining our conversation and asking your clarifying questions using this format.`,
- },
- },
- },
- Tags: []string{"meta", "prompt-management", "interactive"},
- Created: now,
- Updated: now,
+Start by examining our conversation and asking your clarifying questions using this format.`
+}
+
+// buildUpdatePrompt creates the "update_prompt" meta-prompt for modifying
+// an existing prompt interactively.
+func buildUpdatePrompt(now time.Time) Prompt {
+ return Prompt{
+ Name: "update_prompt",
+ Title: "Update Existing Prompt",
+ Description: "Interactively modify an existing prompt. Claude will fetch the current version, ask what changes you want, show a preview with changes highlighted, and wait for approval before updating.",
+ Arguments: []PromptArgument{
+ {Name: "prompt_name", Description: "Name of the existing prompt to update", Required: true},
},
- {
- Name: "update_prompt",
- Title: "Update Existing Prompt",
- Description: "Interactively modify an existing prompt. Claude will fetch the current version, ask what changes you want, show a preview with changes highlighted, and wait for approval before updating.",
- Arguments: []PromptArgument{
- {
- Name: "prompt_name",
- Description: "Name of the existing prompt to update",
- Required: true,
- },
- },
- Messages: []PromptMessage{
- {
- Role: "user",
- Content: MessageContent{
- Type: "text",
- Text: `I want to update the existing prompt '{{prompt_name}}'.
+ Messages: []PromptMessage{{
+ Role: "user",
+ Content: MessageContent{Type: "text", Text: updatePromptText()},
+ }},
+ Tags: []string{"meta", "prompt-management", "interactive"},
+ Created: now,
+ Updated: now,
+ }
+}
+
+// updatePromptText returns the user message body for the update_prompt meta-prompt.
+func updatePromptText() string {
+ return `I want to update the existing prompt '{{prompt_name}}'.
Please help me by:
1) Ask me what changes I want to make to the prompt '{{prompt_name}}' (title, description, arguments, messages, or tags)
@@ -109,12 +121,7 @@ Please help me by:
3) Show me a complete preview of the updated prompt with changes highlighted
4) Only after I approve, use the update_prompt tool to save the changes
-IMPORTANT FORMATTING RULES for clarifying questions:
-- Use numbered questions: 1), 2), 3)
-- ANY CHOICE MUST BE NUMBERED using combined format: 1a), 1b), 1c), 2a), 2b), etc.
-- NEVER use standalone letters like "a)" - always combine with question number
-- NEVER use dashes (-) or bullets (•) for options
-- Every option must be numbered for easy selection by the user
+` + formattingRules + `
Examples:
1) Question Category
@@ -136,31 +143,22 @@ Examples:
3b) Second aspect to evaluate
3c) Third aspect to evaluate
-Start by asking me what changes I want to make, using this format for any clarifying questions.`,
- },
- },
- },
- Tags: []string{"meta", "prompt-management", "interactive"},
- Created: now,
- Updated: now,
+Start by asking me what changes I want to make, using this format for any clarifying questions.`
+}
+
+// buildDeletePrompt creates the "delete_prompt" meta-prompt for safely removing
+// a custom prompt with explicit confirmation.
+func buildDeletePrompt(now time.Time) Prompt {
+ return Prompt{
+ Name: "delete_prompt",
+ Title: "Delete Custom Prompt",
+ Description: "Interactively delete an existing custom prompt with confirmation. Claude will show the current prompt, ask for confirmation, and only delete after explicit approval. Built-in prompts cannot be deleted.",
+ Arguments: []PromptArgument{
+ {Name: "prompt_name", Description: "Name of the existing prompt to delete", Required: true},
},
- {
- Name: "delete_prompt",
- Title: "Delete Custom Prompt",
- Description: "Interactively delete an existing custom prompt with confirmation. Claude will show the current prompt, ask for confirmation, and only delete after explicit approval. Built-in prompts cannot be deleted.",
- Arguments: []PromptArgument{
- {
- Name: "prompt_name",
- Description: "Name of the existing prompt to delete",
- Required: true,
- },
- },
- Messages: []PromptMessage{
- {
- Role: "user",
- Content: MessageContent{
- Type: "text",
- Text: `I want to delete the existing prompt '{{prompt_name}}'.
+ Messages: []PromptMessage{{
+ Role: "user",
+ Content: MessageContent{Type: "text", Text: `I want to delete the existing prompt '{{prompt_name}}'.
Please help me by:
1) Confirm with me that I want to delete the prompt named '{{prompt_name}}'
@@ -173,25 +171,25 @@ IMPORTANT NOTES:
- Only custom prompts stored in user.jsonl can be deleted
- Backups are automatically created before deletion
-Ask me to confirm the deletion of '{{prompt_name}}'.`,
- },
- },
- },
- Tags: []string{"meta", "prompt-management", "interactive"},
- Created: now,
- Updated: now,
- },
- {
- Name: "design_prompt",
- Title: "Design New Prompt from Scratch",
- Description: "Interactively design a brand new prompt template through guided questions. Claude will help you define the purpose, arguments, message flow, and metadata step by step, show a preview, and wait for approval before saving.",
- Arguments: []PromptArgument{},
- Messages: []PromptMessage{
- {
- Role: "user",
- Content: MessageContent{
- Type: "text",
- Text: `I want to design a brand new prompt template from scratch.
+Ask me to confirm the deletion of '{{prompt_name}}'.`},
+ }},
+ Tags: []string{"meta", "prompt-management", "interactive"},
+ Created: now,
+ Updated: now,
+ }
+}
+
+// buildDesignPrompt creates the "design_prompt" meta-prompt for building
+// a brand new prompt template from scratch through guided questions.
+func buildDesignPrompt(now time.Time) Prompt {
+ return Prompt{
+ Name: "design_prompt",
+ Title: "Design New Prompt from Scratch",
+ Description: "Interactively design a brand new prompt template through guided questions. Claude will help you define the purpose, arguments, message flow, and metadata step by step, show a preview, and wait for approval before saving.",
+ Arguments: []PromptArgument{},
+ Messages: []PromptMessage{{
+ Role: "user",
+ Content: MessageContent{Type: "text", Text: `I want to design a brand new prompt template from scratch.
Please ask me:
1) What should this prompt do? (describe the task/purpose in 1-2 sentences)
@@ -200,13 +198,10 @@ Please ask me:
Then show me a preview and save it after I approve.
-Keep questions brief and focused.`,
- },
- },
- },
- Tags: []string{"meta", "prompt-management", "interactive", "creation"},
- Created: now,
- Updated: now,
- },
+Keep questions brief and focused.`},
+ }},
+ Tags: []string{"meta", "prompt-management", "interactive", "creation"},
+ Created: now,
+ Updated: now,
}
}
diff --git a/internal/stats/stats.go b/internal/stats/stats.go
index 95981c5..4b05617 100644
--- a/internal/stats/stats.go
+++ b/internal/stats/stats.go
@@ -83,54 +83,85 @@ func Update(ctx context.Context, provider, model string, sentBytes, recvBytes in
if err := os.MkdirAll(dir, 0o755); err != nil {
return err
}
+ unlock, err := lockStatsFile(ctx, dir)
+ if err != nil {
+ return err
+ }
+ defer func() { _ = unlock() }()
+
+ path := filepath.Join(dir, fileName)
+ sf := readStatsFile(path)
+ now := time.Now()
+ win := Window()
+ sf.WindowSeconds = int(win.Seconds())
+ sf.Events = append(sf.Events, Event{
+ TS: now, Provider: provider, Model: model,
+ Sent: int64(sentBytes), Recv: int64(recvBytes),
+ })
+ pruneOldEvents(&sf, now.Add(-win))
+ sf.UpdatedAt = now
+ return writeStatsFileAtomic(dir, path, &sf)
+}
+
+// lockStatsFile acquires an advisory file lock on the stats lock file within dir.
+// Returns an unlock function on success.
+func lockStatsFile(ctx context.Context, dir string) (func() error, error) {
lockPath := filepath.Join(dir, lockFileName)
f, err := os.OpenFile(lockPath, os.O_CREATE|os.O_RDWR, 0o600)
if err != nil {
- return err
+ return nil, err
}
- defer func() { _ = f.Close() }()
unlock, err := acquireFileLock(ctx, f)
if err != nil {
- return err
+ _ = f.Close()
+ return nil, err
}
- defer func() { _ = unlock() }()
- // Read existing file (if any)
- path := filepath.Join(dir, fileName)
+ // Return a combined unlock+close function so the caller only needs one defer.
+ return func() error {
+ uErr := unlock()
+ cErr := f.Close()
+ if uErr != nil {
+ return uErr
+ }
+ return cErr
+ }, nil
+}
+
+// readStatsFile loads the on-disk stats file, returning a fresh File if it is
+// missing or has an incompatible version.
+func readStatsFile(path string) File {
var sf File
- if b, rerr := os.ReadFile(path); rerr == nil {
+ if b, err := os.ReadFile(path); err == nil {
_ = json.Unmarshal(b, &sf)
}
if sf.Version != fileVersion {
sf = File{Version: fileVersion}
}
- now := time.Now()
- win := Window()
- sf.WindowSeconds = int(win.Seconds())
- // Append event
- sf.Events = append(sf.Events, Event{TS: now, Provider: provider, Model: model, Sent: int64(sentBytes), Recv: int64(recvBytes)})
- // Prune old
- cutoff := now.Add(-win)
- if len(sf.Events) > 0 {
- // Find first >= cutoff
- i := 0
- for ; i < len(sf.Events); i++ {
- if !sf.Events[i].TS.Before(cutoff) {
- break
- }
- }
- if i > 0 {
- sf.Events = append([]Event(nil), sf.Events[i:]...)
+ return sf
+}
+
+// pruneOldEvents removes events older than cutoff in-place.
+func pruneOldEvents(sf *File, cutoff time.Time) {
+ i := 0
+ for ; i < len(sf.Events); i++ {
+ if !sf.Events[i].TS.Before(cutoff) {
+ break
}
}
- sf.UpdatedAt = now
- // Write atomically
+ if i > 0 {
+ sf.Events = append([]Event(nil), sf.Events[i:]...)
+ }
+}
+
+// writeStatsFileAtomic writes sf to path via a temp file + rename for crash safety.
+func writeStatsFileAtomic(dir, path string, sf *File) error {
tmp, err := os.CreateTemp(dir, fileName+".tmp.")
if err != nil {
return err
}
enc := json.NewEncoder(tmp)
enc.SetEscapeHTML(false)
- if err := enc.Encode(&sf); err != nil {
+ if err := enc.Encode(sf); err != nil {
_ = tmp.Close()
_ = os.Remove(tmp.Name())
return err