summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2026-03-16 03:37:28 +0200
committerPaul Buetow <paul@buetow.org>2026-03-16 03:37:28 +0200
commit2e9cabb1c8bf1f0246e513fe1f86a552e07eee94 (patch)
tree04620a25f1aa5c7d9d914ba136e591abab1c509d /internal
parentad988c34181b7234a54d279874f29e126607fad3 (diff)
Refactor oversized functions and split large test files
Split DefaultPrompts (201L), loadFromFile (83L), and Update (74L) into focused helper functions under 50 lines each. Split handlers_test.go (1650L) and config_test.go (1419L) into logical sub-files under 1000L. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'internal')
-rw-r--r--internal/appconfig/config_features_test.go588
-rw-r--r--internal/appconfig/config_load.go98
-rw-r--r--internal/appconfig/config_test.go578
-rw-r--r--internal/mcp/handlers_prompt_test.go (renamed from internal/mcp/handlers_test.go)699
-rw-r--r--internal/mcp/handlers_tool_test.go703
-rw-r--r--internal/promptstore/default_prompts.go219
-rw-r--r--internal/stats/stats.go85
7 files changed, 1517 insertions, 1453 deletions
diff --git a/internal/appconfig/config_features_test.go b/internal/appconfig/config_features_test.go
new file mode 100644
index 0000000..b3c12e9
--- /dev/null
+++ b/internal/appconfig/config_features_test.go
@@ -0,0 +1,588 @@
+// Summary: Tests for ignore config, tmux-edit config, and low-level parsing helpers
+// (temperature, model entries, surface entries, resolved model).
+package appconfig
+
+import (
+ "os"
+ "path/filepath"
+ "reflect"
+ "testing"
+)
+
+func TestIgnoreConfig_Defaults(t *testing.T) {
+ clearHexaiEnv(t)
+ cfg := Load(nil)
+ if cfg.IgnoreGitignore == nil || !*cfg.IgnoreGitignore {
+ t.Error("expected IgnoreGitignore default true")
+ }
+ if cfg.IgnoreLSPNotify == nil || !*cfg.IgnoreLSPNotify {
+ t.Error("expected IgnoreLSPNotify default true")
+ }
+ if len(cfg.IgnoreExtraPatterns) != 0 {
+ t.Errorf("expected empty IgnoreExtraPatterns, got %v", cfg.IgnoreExtraPatterns)
+ }
+}
+
+func TestIgnoreConfig_FromFile(t *testing.T) {
+ clearHexaiEnv(t)
+ dir := t.TempDir()
+ cfgPath := filepath.Join(dir, "config.toml")
+ writeFile(t, cfgPath, `
+[ignore]
+gitignore = false
+extra_patterns = ["*.min.js", "dist/**"]
+lsp_notify_ignored = false
+`)
+ cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: dir})
+ if cfg.IgnoreGitignore == nil || *cfg.IgnoreGitignore {
+ t.Error("expected IgnoreGitignore false from file")
+ }
+ if cfg.IgnoreLSPNotify == nil || *cfg.IgnoreLSPNotify {
+ t.Error("expected IgnoreLSPNotify false from file")
+ }
+ want := []string{"*.min.js", "dist/**"}
+ if !reflect.DeepEqual(cfg.IgnoreExtraPatterns, want) {
+ t.Errorf("IgnoreExtraPatterns = %v, want %v", cfg.IgnoreExtraPatterns, want)
+ }
+}
+
+func TestIgnoreConfig_EnvOverrides(t *testing.T) {
+ clearHexaiEnv(t)
+ dir := t.TempDir()
+ cfgPath := filepath.Join(dir, "config.toml")
+ writeFile(t, cfgPath, `
+[ignore]
+gitignore = true
+lsp_notify_ignored = true
+`)
+ withEnv(t, "HEXAI_IGNORE_GITIGNORE", "false")
+ withEnv(t, "HEXAI_IGNORE_LSP_NOTIFY", "0")
+ withEnv(t, "HEXAI_IGNORE_EXTRA_PATTERNS", "*.bak,*.tmp")
+ cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: dir})
+ if cfg.IgnoreGitignore == nil || *cfg.IgnoreGitignore {
+ t.Error("expected IgnoreGitignore false from env override")
+ }
+ if cfg.IgnoreLSPNotify == nil || *cfg.IgnoreLSPNotify {
+ t.Error("expected IgnoreLSPNotify false from env override")
+ }
+ want := []string{"*.bak", "*.tmp"}
+ if !reflect.DeepEqual(cfg.IgnoreExtraPatterns, want) {
+ t.Errorf("IgnoreExtraPatterns = %v, want %v", cfg.IgnoreExtraPatterns, want)
+ }
+}
+
+func TestIgnoreConfig_ProjectOverride(t *testing.T) {
+ clearHexaiEnv(t)
+ dir := t.TempDir()
+ cfgPath := filepath.Join(dir, "config.toml")
+ writeFile(t, cfgPath, `
+[ignore]
+gitignore = true
+`)
+ // Set up a fake git repo with project override
+ projectDir := t.TempDir()
+ if err := os.Mkdir(filepath.Join(projectDir, ".git"), 0o755); err != nil {
+ t.Fatalf("mkdir .git: %v", err)
+ }
+ projectCfg := filepath.Join(projectDir, ProjectConfigFilename)
+ writeFile(t, projectCfg, `
+[ignore]
+gitignore = false
+extra_patterns = ["build/**"]
+`)
+ cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: projectDir})
+ if cfg.IgnoreGitignore == nil || *cfg.IgnoreGitignore {
+ t.Error("expected project override to set IgnoreGitignore false")
+ }
+ want := []string{"build/**"}
+ if !reflect.DeepEqual(cfg.IgnoreExtraPatterns, want) {
+ t.Errorf("IgnoreExtraPatterns = %v, want %v", cfg.IgnoreExtraPatterns, want)
+ }
+}
+
+func TestIgnoreConfig_DisableGitignore(t *testing.T) {
+ clearHexaiEnv(t)
+ dir := t.TempDir()
+ cfgPath := filepath.Join(dir, "config.toml")
+ writeFile(t, cfgPath, `
+[ignore]
+gitignore = false
+`)
+ cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: dir})
+ if cfg.IgnoreGitignore == nil || *cfg.IgnoreGitignore {
+ t.Error("expected IgnoreGitignore false")
+ }
+ // LSP notify should still be true (default, not overridden)
+ if cfg.IgnoreLSPNotify == nil || !*cfg.IgnoreLSPNotify {
+ t.Error("expected IgnoreLSPNotify to remain true (default)")
+ }
+}
+
+func TestTmuxEditConfig_FromFile(t *testing.T) {
+ clearHexaiEnv(t)
+ dir := t.TempDir()
+ cfgPath := filepath.Join(dir, "config.toml")
+ writeFile(t, cfgPath, `
+[tmux_edit]
+popup_width = "90%"
+popup_height = "85%"
+default_agent = "claude"
+
+[[tmux_edit.agents]]
+name = "claude"
+display_name = "Claude Code"
+detect_pattern = "(?i)(claude|anthropic)"
+prompt_pattern = '(?s)>\s*(.+?)$'
+clear_first = true
+clear_keys = "C-u"
+newline_keys = "S-Enter"
+submit_keys = "Enter"
+
+[[tmux_edit.agents]]
+name = "cursor"
+display_name = "Cursor"
+detect_pattern = "(?i)cursor"
+prompt_pattern = '(?s)│\s*(.+?)$'
+strip_patterns = ["INSERT", "Add a follow-up"]
+clear_first = true
+clear_keys = "C-u"
+newline_keys = "S-Enter"
+submit_keys = "Enter"
+`)
+ cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: dir})
+ if cfg.TmuxEditPopupWidth != "90%" {
+ t.Errorf("PopupWidth = %q, want 90%%", cfg.TmuxEditPopupWidth)
+ }
+ if cfg.TmuxEditPopupHeight != "85%" {
+ t.Errorf("PopupHeight = %q, want 85%%", cfg.TmuxEditPopupHeight)
+ }
+ if cfg.TmuxEditDefaultAgent != "claude" {
+ t.Errorf("DefaultAgent = %q, want claude", cfg.TmuxEditDefaultAgent)
+ }
+ if len(cfg.TmuxEditAgents) != 2 {
+ t.Fatalf("got %d agents, want 2", len(cfg.TmuxEditAgents))
+ }
+ a := cfg.TmuxEditAgents[0]
+ if a.Name != "claude" || a.DisplayName != "Claude Code" {
+ t.Errorf("agent[0] = %q/%q, want claude/Claude Code", a.Name, a.DisplayName)
+ }
+ if a.ClearFirst == nil || !*a.ClearFirst {
+ t.Error("expected ClearFirst = true for claude agent")
+ }
+ b := cfg.TmuxEditAgents[1]
+ if b.Name != "cursor" {
+ t.Errorf("agent[1].Name = %q, want cursor", b.Name)
+ }
+ if len(b.StripPatterns) != 2 {
+ t.Errorf("agent[1].StripPatterns = %v, want 2 entries", b.StripPatterns)
+ }
+}
+
+func TestTmuxEditConfig_Merge(t *testing.T) {
+ clearHexaiEnv(t)
+ a := newDefaultConfig()
+ b := App{
+ TmuxEditPopupWidth: "70%",
+ TmuxEditDefaultAgent: "amp",
+ TmuxEditAgents: []TmuxEditAgentCfg{
+ {Name: "amp", DisplayName: "Amp"},
+ },
+ }
+ a.mergeWith(&b)
+ if a.TmuxEditPopupWidth != "70%" {
+ t.Errorf("PopupWidth = %q, want 70%%", a.TmuxEditPopupWidth)
+ }
+ if a.TmuxEditDefaultAgent != "amp" {
+ t.Errorf("DefaultAgent = %q, want amp", a.TmuxEditDefaultAgent)
+ }
+ if len(a.TmuxEditAgents) != 1 || a.TmuxEditAgents[0].Name != "amp" {
+ t.Errorf("Agents = %v, want single amp", a.TmuxEditAgents)
+ }
+}
+
+func TestTmuxEditConfig_SkipsEmptyName(t *testing.T) {
+ clearHexaiEnv(t)
+ dir := t.TempDir()
+ cfgPath := filepath.Join(dir, "config.toml")
+ writeFile(t, cfgPath, `
+[tmux_edit]
+[[tmux_edit.agents]]
+name = ""
+display_name = "Empty"
+`)
+ cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: dir})
+ if len(cfg.TmuxEditAgents) != 0 {
+ t.Errorf("got %d agents, want 0 (empty name should be skipped)", len(cfg.TmuxEditAgents))
+ }
+}
+
+// --- Config Parsing Tests ---
+
+func TestParseTemperatureValue(t *testing.T) {
+ tests := []struct {
+ name string
+ input any
+ wantValue *float64
+ wantOK bool
+ }{
+ {"float64 zero", float64(0.0), floatPtr(0.0), true},
+ {"float64 half", float64(0.5), floatPtr(0.5), true},
+ {"float64 one", float64(1.0), floatPtr(1.0), true},
+ {"float64 two", float64(2.0), floatPtr(2.0), true},
+ {"int64 zero", int64(0), floatPtr(0.0), true},
+ {"int64 one", int64(1), floatPtr(1.0), true},
+ {"int64 two", int64(2), floatPtr(2.0), true},
+ {"string zero", "0", floatPtr(0.0), true},
+ {"string one", "1", floatPtr(1.0), true},
+ {"string two", "2", floatPtr(2.0), true},
+ {"string float", "0.75", floatPtr(0.75), true},
+ {"string empty", "", nil, true},
+ {"string whitespace", " ", nil, true},
+ {"string invalid", "invalid", nil, false},
+ {"string negative", "-0.5", floatPtr(-0.5), true},
+ {"string very small", "0.0001", floatPtr(0.0001), true},
+ {"string high precision", "1.123456789", floatPtr(1.123456789), true},
+ {"nil value", nil, nil, false},
+ {"bool value", true, nil, false},
+ {"map value", map[string]any{}, nil, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, ok := parseTemperatureValue(tt.input, "test", newLogger())
+ if ok != tt.wantOK {
+ t.Errorf("parseTemperatureValue() ok = %v, want %v", ok, tt.wantOK)
+ }
+ if !ok {
+ return
+ }
+ if (got == nil) != (tt.wantValue == nil) {
+ t.Errorf("parseTemperatureValue() = %v, want %v", got, tt.wantValue)
+ return
+ }
+ if got != nil && tt.wantValue != nil && *got != *tt.wantValue {
+ t.Errorf("parseTemperatureValue() = %v, want %v", *got, *tt.wantValue)
+ }
+ })
+ }
+}
+
+func TestDecodeModelEntry(t *testing.T) {
+ tests := []struct {
+ name string
+ input any
+ wantCfg *SurfaceConfig
+ wantOK bool
+ }{
+ {
+ name: "simple string model",
+ input: "gpt-4",
+ wantCfg: &SurfaceConfig{Model: "gpt-4"},
+ wantOK: true,
+ },
+ {
+ name: "empty string",
+ input: "",
+ wantCfg: nil,
+ wantOK: false,
+ },
+ {
+ name: "whitespace string",
+ input: " ",
+ wantCfg: nil,
+ wantOK: false,
+ },
+ {
+ name: "object with all fields",
+ input: map[string]any{
+ "model": "claude-3",
+ "provider": "anthropic",
+ "temperature": float64(0.7),
+ },
+ wantCfg: &SurfaceConfig{
+ Model: "claude-3",
+ Provider: "anthropic",
+ Temperature: floatPtr(0.7),
+ },
+ wantOK: true,
+ },
+ {
+ name: "object with model only",
+ input: map[string]any{
+ "model": "gpt-4o",
+ },
+ wantCfg: &SurfaceConfig{Model: "gpt-4o"},
+ wantOK: true,
+ },
+ {
+ name: "object with provider only",
+ input: map[string]any{
+ "provider": "openai",
+ },
+ wantCfg: &SurfaceConfig{Provider: "openai"},
+ wantOK: true,
+ },
+ {
+ name: "object with temperature only",
+ input: map[string]any{
+ "temperature": float64(0.5),
+ },
+ wantCfg: &SurfaceConfig{Temperature: floatPtr(0.5)},
+ wantOK: true,
+ },
+ {
+ name: "object with empty fields",
+ input: map[string]any{
+ "model": "",
+ "provider": "",
+ },
+ wantCfg: nil,
+ wantOK: false,
+ },
+ {
+ name: "object with invalid model type",
+ input: map[string]any{
+ "model": 123,
+ },
+ wantCfg: nil,
+ wantOK: false,
+ },
+ {
+ name: "object with invalid provider type",
+ input: map[string]any{
+ "provider": 456,
+ },
+ wantCfg: nil,
+ wantOK: false,
+ },
+ {
+ name: "object with invalid temperature",
+ input: map[string]any{
+ "model": "gpt-4",
+ "temperature": "not a number",
+ },
+ wantCfg: nil,
+ wantOK: false,
+ },
+ {
+ name: "nil input",
+ input: nil,
+ wantCfg: nil,
+ wantOK: false,
+ },
+ {
+ name: "invalid type (int)",
+ input: 123,
+ wantCfg: nil,
+ wantOK: false,
+ },
+ {
+ name: "invalid type (slice)",
+ input: []string{"gpt-4"},
+ wantCfg: nil,
+ wantOK: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, ok := decodeModelEntry(tt.input, "test", newLogger())
+ if ok != tt.wantOK {
+ t.Errorf("decodeModelEntry() ok = %v, want %v", ok, tt.wantOK)
+ }
+ if !ok {
+ return
+ }
+ if (got == nil) != (tt.wantCfg == nil) {
+ t.Errorf("decodeModelEntry() = %v, want %v", got, tt.wantCfg)
+ return
+ }
+ if got == nil {
+ return
+ }
+ if got.Model != tt.wantCfg.Model {
+ t.Errorf("Model = %q, want %q", got.Model, tt.wantCfg.Model)
+ }
+ if got.Provider != tt.wantCfg.Provider {
+ t.Errorf("Provider = %q, want %q", got.Provider, tt.wantCfg.Provider)
+ }
+ if (got.Temperature == nil) != (tt.wantCfg.Temperature == nil) {
+ t.Errorf("Temperature nil mismatch: got %v, want %v", got.Temperature, tt.wantCfg.Temperature)
+ } else if got.Temperature != nil && *got.Temperature != *tt.wantCfg.Temperature {
+ t.Errorf("Temperature = %v, want %v", *got.Temperature, *tt.wantCfg.Temperature)
+ }
+ })
+ }
+}
+
+func TestResolvedModel(t *testing.T) {
+ tests := []struct {
+ name string
+ section sectionOpenAI
+ want string
+ }{
+ {
+ name: "explicit model no presets",
+ section: sectionOpenAI{Model: "gpt-4"},
+ want: "gpt-4",
+ },
+ {
+ name: "empty model",
+ section: sectionOpenAI{Model: ""},
+ want: "",
+ },
+ {
+ name: "whitespace model",
+ section: sectionOpenAI{Model: " "},
+ want: "",
+ },
+ {
+ name: "preset match exact case",
+ section: sectionOpenAI{
+ Model: "fast",
+ Presets: map[string]string{"fast": "gpt-3.5-turbo"},
+ },
+ want: "gpt-3.5-turbo",
+ },
+ {
+ name: "preset match case insensitive",
+ section: sectionOpenAI{
+ Model: "FAST",
+ Presets: map[string]string{"fast": "gpt-3.5-turbo"},
+ },
+ want: "gpt-3.5-turbo",
+ },
+ {
+ name: "no preset match returns original",
+ section: sectionOpenAI{
+ Model: "custom-model",
+ Presets: map[string]string{"fast": "gpt-3.5-turbo"},
+ },
+ want: "custom-model",
+ },
+ {
+ name: "preset empty value returns original",
+ section: sectionOpenAI{
+ Model: "fast",
+ Presets: map[string]string{"fast": ""},
+ },
+ want: "fast",
+ },
+ {
+ name: "preset whitespace value returns original",
+ section: sectionOpenAI{
+ Model: "fast",
+ Presets: map[string]string{"fast": " "},
+ },
+ want: "fast",
+ },
+ {
+ name: "multiple presets uses correct one",
+ section: sectionOpenAI{
+ Model: "smart",
+ Presets: map[string]string{
+ "fast": "gpt-3.5-turbo",
+ "smart": "gpt-4",
+ "mini": "gpt-3.5-mini",
+ },
+ },
+ want: "gpt-4",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := tt.section.resolvedModel()
+ if got != tt.want {
+ t.Errorf("resolvedModel() = %q, want %q", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestParseSurfaceEntries(t *testing.T) {
+ tests := []struct {
+ name string
+ input any
+ wantLen int
+ wantOK bool
+ }{
+ {
+ name: "nil input",
+ input: nil,
+ wantLen: 0,
+ wantOK: false,
+ },
+ {
+ name: "single string",
+ input: "gpt-4",
+ wantLen: 1,
+ wantOK: true,
+ },
+ {
+ name: "single map",
+ input: map[string]any{
+ "model": "claude-3",
+ "provider": "anthropic",
+ },
+ wantLen: 1,
+ wantOK: true,
+ },
+ {
+ name: "array of strings",
+ input: []any{
+ "gpt-4",
+ "claude-3",
+ },
+ wantLen: 2,
+ wantOK: true,
+ },
+ {
+ name: "array of maps",
+ input: []any{
+ map[string]any{"model": "gpt-4", "provider": "openai"},
+ map[string]any{"model": "claude-3", "provider": "anthropic"},
+ },
+ wantLen: 2,
+ wantOK: true,
+ },
+ {
+ name: "array with invalid entries",
+ input: []any{
+ "gpt-4",
+ 123,
+ "claude-3",
+ },
+ wantLen: 2,
+ wantOK: true,
+ },
+ {
+ name: "array with all invalid entries",
+ input: []any{
+ 123,
+ true,
+ nil,
+ },
+ wantLen: 0,
+ wantOK: false,
+ },
+ {
+ name: "empty array",
+ input: []any{},
+ wantLen: 0,
+ wantOK: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, ok := parseSurfaceEntries(tt.input, "test", newLogger())
+ if ok != tt.wantOK {
+ t.Errorf("parseSurfaceEntries() ok = %v, want %v", ok, tt.wantOK)
+ }
+ if len(got) != tt.wantLen {
+ t.Errorf("parseSurfaceEntries() len = %d, want %d", len(got), tt.wantLen)
+ }
+ })
+ }
+}
diff --git a/internal/appconfig/config_load.go b/internal/appconfig/config_load.go
index 261835b..dc917ff 100644
--- a/internal/appconfig/config_load.go
+++ b/internal/appconfig/config_load.go
@@ -139,6 +139,8 @@ func loadProjectConfig(logger *log.Logger, opts LoadOptions, cfg *App) {
}
}
+// loadFromFile reads a TOML config file, validates it, and returns the parsed App.
+// Returns (nil, err) on I/O or parse errors; returns (nil, nil) when the file does not exist.
func loadFromFile(path string, logger *log.Logger) (*App, error) {
b, err := os.ReadFile(path)
if err != nil {
@@ -148,25 +150,50 @@ func loadFromFile(path string, logger *log.Logger) (*App, error) {
return nil, err
}
+ tables, raw, err := decodeTOML(b, path, logger)
+ if err != nil {
+ return nil, err
+ }
+ if err := rejectLegacyKeys(raw); err != nil {
+ return nil, err
+ }
+ if logger != nil {
+ logger.Printf("loaded configuration from %s (TOML)", path)
+ }
+
+ tab := tables.toApp()
+ applyRawIntOverrides(raw, &tab)
+ if m := parseSurfaceModels(raw, logger); m != nil {
+ tab.mergeSurfaceModels(m)
+ }
+ return &tab, nil
+}
+
+// decodeTOML parses raw TOML bytes into both the typed fileConfig and a raw map
+// for validation and defensive integer handling.
+func decodeTOML(b []byte, path string, logger *log.Logger) (*fileConfig, map[string]any, error) {
var tables fileConfig
errTables := toml.NewDecoder(strings.NewReader(string(b))).Decode(&tables)
- // Raw map for validation/presence checks
var raw map[string]any
errRaw := toml.Unmarshal(b, &raw)
if errTables != nil {
if logger != nil {
logger.Printf("invalid TOML config file %s: %v", path, errTables)
}
- return nil, errTables
+ return nil, nil, errTables
}
if errRaw != nil {
if logger != nil {
logger.Printf("invalid TOML config file %s: %v", path, errRaw)
}
- return nil, errRaw
+ return nil, nil, errRaw
}
+ return &tables, raw, nil
+}
- // Reject legacy flat keys at top-level (sectioned-only config is allowed)
+// rejectLegacyKeys returns an error if the raw map contains flat keys from the
+// old unsectioned config format. Only sectioned table keys are allowed.
+func rejectLegacyKeys(raw map[string]any) error {
legacy := map[string]struct{}{
"max_tokens": {}, "context_mode": {}, "context_window_lines": {}, "max_context_tokens": {},
"log_preview_limit": {}, "completion_debounce_ms": {}, "completion_throttle_ms": {},
@@ -175,53 +202,46 @@ func loadFromFile(path string, logger *log.Logger) (*App, error) {
"openai_model": {}, "openai_base_url": {}, "openai_temperature": {},
"ollama_model": {}, "ollama_base_url": {}, "ollama_temperature": {},
}
+ knownTables := map[string]struct{}{
+ "general": {}, "logging": {}, "completion": {}, "triggers": {}, "inline": {},
+ "chat": {}, "provider": {}, "models": {}, "openai": {}, "ollama": {}, "prompts": {},
+ }
for k := range raw {
- if _, isTable := map[string]struct{}{
- "general": {}, "logging": {}, "completion": {}, "triggers": {}, "inline": {},
- "chat": {}, "provider": {}, "models": {}, "openai": {}, "ollama": {}, "prompts": {},
- }[k]; isTable {
+ if _, isTable := knownTables[k]; isTable {
continue
}
if _, isLegacy := legacy[k]; isLegacy {
- return nil, fmt.Errorf("unsupported flat key '%s' in config; use sectioned tables (see config.toml.example)", k)
+ return fmt.Errorf("unsupported flat key '%s' in config; use sectioned tables (see config.toml.example)", k)
}
}
+ return nil
+}
- if logger != nil {
- logger.Printf("loaded configuration from %s (TOML)", path)
- }
-
- // Build App from tables only
- tab := tables.toApp()
- // Ensure explicit values from raw map are respected (defensive for ints)
+// applyRawIntOverrides defensively re-applies integer values from the raw TOML map
+// that the typed decoder may have silently zeroed (e.g. int vs float mismatch).
+func applyRawIntOverrides(raw map[string]any, tab *App) {
if t, ok := raw["completion"].(map[string]any); ok {
- if v, present := t["manual_invoke_min_prefix"]; present {
- switch vv := v.(type) {
- case int64:
- tab.ManualInvokeMinPrefix = int(vv)
- case int:
- tab.ManualInvokeMinPrefix = vv
- case float64:
- tab.ManualInvokeMinPrefix = int(vv)
- }
- }
+ applyRawInt(&tab.ManualInvokeMinPrefix, t, "manual_invoke_min_prefix")
}
if t, ok := raw["logging"].(map[string]any); ok {
- if v, present := t["log_preview_limit"]; present {
- switch vv := v.(type) {
- case int64:
- tab.LogPreviewLimit = int(vv)
- case int:
- tab.LogPreviewLimit = vv
- case float64:
- tab.LogPreviewLimit = int(vv)
- }
- }
+ applyRawInt(&tab.LogPreviewLimit, t, "log_preview_limit")
}
- if m := parseSurfaceModels(raw, logger); m != nil {
- tab.mergeSurfaceModels(m)
+}
+
+// applyRawInt sets *dst from table[key] when the value is a numeric type.
+func applyRawInt(dst *int, table map[string]any, key string) {
+ v, present := table[key]
+ if !present {
+ return
+ }
+ switch vv := v.(type) {
+ case int64:
+ *dst = int(vv)
+ case int:
+ *dst = vv
+ case float64:
+ *dst = int(vv)
}
- return &tab, nil
}
func (fc *fileConfig) toApp() App {
diff --git a/internal/appconfig/config_test.go b/internal/appconfig/config_test.go
index c98d904..d4e7739 100644
--- a/internal/appconfig/config_test.go
+++ b/internal/appconfig/config_test.go
@@ -839,581 +839,3 @@ func TestProjectConfigPath(t *testing.T) {
t.Fatalf("ProjectConfigPath() = %q, want empty", path)
}
}
-
-func TestIgnoreConfig_Defaults(t *testing.T) {
- clearHexaiEnv(t)
- cfg := Load(nil)
- if cfg.IgnoreGitignore == nil || !*cfg.IgnoreGitignore {
- t.Error("expected IgnoreGitignore default true")
- }
- if cfg.IgnoreLSPNotify == nil || !*cfg.IgnoreLSPNotify {
- t.Error("expected IgnoreLSPNotify default true")
- }
- if len(cfg.IgnoreExtraPatterns) != 0 {
- t.Errorf("expected empty IgnoreExtraPatterns, got %v", cfg.IgnoreExtraPatterns)
- }
-}
-
-func TestIgnoreConfig_FromFile(t *testing.T) {
- clearHexaiEnv(t)
- dir := t.TempDir()
- cfgPath := filepath.Join(dir, "config.toml")
- writeFile(t, cfgPath, `
-[ignore]
-gitignore = false
-extra_patterns = ["*.min.js", "dist/**"]
-lsp_notify_ignored = false
-`)
- cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: dir})
- if cfg.IgnoreGitignore == nil || *cfg.IgnoreGitignore {
- t.Error("expected IgnoreGitignore false from file")
- }
- if cfg.IgnoreLSPNotify == nil || *cfg.IgnoreLSPNotify {
- t.Error("expected IgnoreLSPNotify false from file")
- }
- want := []string{"*.min.js", "dist/**"}
- if !reflect.DeepEqual(cfg.IgnoreExtraPatterns, want) {
- t.Errorf("IgnoreExtraPatterns = %v, want %v", cfg.IgnoreExtraPatterns, want)
- }
-}
-
-func TestIgnoreConfig_EnvOverrides(t *testing.T) {
- clearHexaiEnv(t)
- dir := t.TempDir()
- cfgPath := filepath.Join(dir, "config.toml")
- writeFile(t, cfgPath, `
-[ignore]
-gitignore = true
-lsp_notify_ignored = true
-`)
- withEnv(t, "HEXAI_IGNORE_GITIGNORE", "false")
- withEnv(t, "HEXAI_IGNORE_LSP_NOTIFY", "0")
- withEnv(t, "HEXAI_IGNORE_EXTRA_PATTERNS", "*.bak,*.tmp")
- cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: dir})
- if cfg.IgnoreGitignore == nil || *cfg.IgnoreGitignore {
- t.Error("expected IgnoreGitignore false from env override")
- }
- if cfg.IgnoreLSPNotify == nil || *cfg.IgnoreLSPNotify {
- t.Error("expected IgnoreLSPNotify false from env override")
- }
- want := []string{"*.bak", "*.tmp"}
- if !reflect.DeepEqual(cfg.IgnoreExtraPatterns, want) {
- t.Errorf("IgnoreExtraPatterns = %v, want %v", cfg.IgnoreExtraPatterns, want)
- }
-}
-
-func TestIgnoreConfig_ProjectOverride(t *testing.T) {
- clearHexaiEnv(t)
- dir := t.TempDir()
- cfgPath := filepath.Join(dir, "config.toml")
- writeFile(t, cfgPath, `
-[ignore]
-gitignore = true
-`)
- // Set up a fake git repo with project override
- projectDir := t.TempDir()
- if err := os.Mkdir(filepath.Join(projectDir, ".git"), 0o755); err != nil {
- t.Fatalf("mkdir .git: %v", err)
- }
- projectCfg := filepath.Join(projectDir, ProjectConfigFilename)
- writeFile(t, projectCfg, `
-[ignore]
-gitignore = false
-extra_patterns = ["build/**"]
-`)
- cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: projectDir})
- if cfg.IgnoreGitignore == nil || *cfg.IgnoreGitignore {
- t.Error("expected project override to set IgnoreGitignore false")
- }
- want := []string{"build/**"}
- if !reflect.DeepEqual(cfg.IgnoreExtraPatterns, want) {
- t.Errorf("IgnoreExtraPatterns = %v, want %v", cfg.IgnoreExtraPatterns, want)
- }
-}
-
-func TestIgnoreConfig_DisableGitignore(t *testing.T) {
- clearHexaiEnv(t)
- dir := t.TempDir()
- cfgPath := filepath.Join(dir, "config.toml")
- writeFile(t, cfgPath, `
-[ignore]
-gitignore = false
-`)
- cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: dir})
- if cfg.IgnoreGitignore == nil || *cfg.IgnoreGitignore {
- t.Error("expected IgnoreGitignore false")
- }
- // LSP notify should still be true (default, not overridden)
- if cfg.IgnoreLSPNotify == nil || !*cfg.IgnoreLSPNotify {
- t.Error("expected IgnoreLSPNotify to remain true (default)")
- }
-}
-
-func TestTmuxEditConfig_FromFile(t *testing.T) {
- clearHexaiEnv(t)
- dir := t.TempDir()
- cfgPath := filepath.Join(dir, "config.toml")
- writeFile(t, cfgPath, `
-[tmux_edit]
-popup_width = "90%"
-popup_height = "85%"
-default_agent = "claude"
-
-[[tmux_edit.agents]]
-name = "claude"
-display_name = "Claude Code"
-detect_pattern = "(?i)(claude|anthropic)"
-prompt_pattern = '(?s)>\s*(.+?)$'
-clear_first = true
-clear_keys = "C-u"
-newline_keys = "S-Enter"
-submit_keys = "Enter"
-
-[[tmux_edit.agents]]
-name = "cursor"
-display_name = "Cursor"
-detect_pattern = "(?i)cursor"
-prompt_pattern = '(?s)│\s*(.+?)$'
-strip_patterns = ["INSERT", "Add a follow-up"]
-clear_first = true
-clear_keys = "C-u"
-newline_keys = "S-Enter"
-submit_keys = "Enter"
-`)
- cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: dir})
- if cfg.TmuxEditPopupWidth != "90%" {
- t.Errorf("PopupWidth = %q, want 90%%", cfg.TmuxEditPopupWidth)
- }
- if cfg.TmuxEditPopupHeight != "85%" {
- t.Errorf("PopupHeight = %q, want 85%%", cfg.TmuxEditPopupHeight)
- }
- if cfg.TmuxEditDefaultAgent != "claude" {
- t.Errorf("DefaultAgent = %q, want claude", cfg.TmuxEditDefaultAgent)
- }
- if len(cfg.TmuxEditAgents) != 2 {
- t.Fatalf("got %d agents, want 2", len(cfg.TmuxEditAgents))
- }
- a := cfg.TmuxEditAgents[0]
- if a.Name != "claude" || a.DisplayName != "Claude Code" {
- t.Errorf("agent[0] = %q/%q, want claude/Claude Code", a.Name, a.DisplayName)
- }
- if a.ClearFirst == nil || !*a.ClearFirst {
- t.Error("expected ClearFirst = true for claude agent")
- }
- b := cfg.TmuxEditAgents[1]
- if b.Name != "cursor" {
- t.Errorf("agent[1].Name = %q, want cursor", b.Name)
- }
- if len(b.StripPatterns) != 2 {
- t.Errorf("agent[1].StripPatterns = %v, want 2 entries", b.StripPatterns)
- }
-}
-
-func TestTmuxEditConfig_Merge(t *testing.T) {
- clearHexaiEnv(t)
- a := newDefaultConfig()
- b := App{
- TmuxEditPopupWidth: "70%",
- TmuxEditDefaultAgent: "amp",
- TmuxEditAgents: []TmuxEditAgentCfg{
- {Name: "amp", DisplayName: "Amp"},
- },
- }
- a.mergeWith(&b)
- if a.TmuxEditPopupWidth != "70%" {
- t.Errorf("PopupWidth = %q, want 70%%", a.TmuxEditPopupWidth)
- }
- if a.TmuxEditDefaultAgent != "amp" {
- t.Errorf("DefaultAgent = %q, want amp", a.TmuxEditDefaultAgent)
- }
- if len(a.TmuxEditAgents) != 1 || a.TmuxEditAgents[0].Name != "amp" {
- t.Errorf("Agents = %v, want single amp", a.TmuxEditAgents)
- }
-}
-
-func TestTmuxEditConfig_SkipsEmptyName(t *testing.T) {
- clearHexaiEnv(t)
- dir := t.TempDir()
- cfgPath := filepath.Join(dir, "config.toml")
- writeFile(t, cfgPath, `
-[tmux_edit]
-[[tmux_edit.agents]]
-name = ""
-display_name = "Empty"
-`)
- cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: dir})
- if len(cfg.TmuxEditAgents) != 0 {
- t.Errorf("got %d agents, want 0 (empty name should be skipped)", len(cfg.TmuxEditAgents))
- }
-}
-
-// --- Phase 1: Config Parsing Tests ---
-
-func TestParseTemperatureValue(t *testing.T) {
- tests := []struct {
- name string
- input any
- wantValue *float64
- wantOK bool
- }{
- {"float64 zero", float64(0.0), floatPtr(0.0), true},
- {"float64 half", float64(0.5), floatPtr(0.5), true},
- {"float64 one", float64(1.0), floatPtr(1.0), true},
- {"float64 two", float64(2.0), floatPtr(2.0), true},
- {"int64 zero", int64(0), floatPtr(0.0), true},
- {"int64 one", int64(1), floatPtr(1.0), true},
- {"int64 two", int64(2), floatPtr(2.0), true},
- {"string zero", "0", floatPtr(0.0), true},
- {"string one", "1", floatPtr(1.0), true},
- {"string two", "2", floatPtr(2.0), true},
- {"string float", "0.75", floatPtr(0.75), true},
- {"string empty", "", nil, true},
- {"string whitespace", " ", nil, true},
- {"string invalid", "invalid", nil, false},
- {"string negative", "-0.5", floatPtr(-0.5), true},
- {"string very small", "0.0001", floatPtr(0.0001), true},
- {"string high precision", "1.123456789", floatPtr(1.123456789), true},
- {"nil value", nil, nil, false},
- {"bool value", true, nil, false},
- {"map value", map[string]any{}, nil, false},
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, ok := parseTemperatureValue(tt.input, "test", newLogger())
- if ok != tt.wantOK {
- t.Errorf("parseTemperatureValue() ok = %v, want %v", ok, tt.wantOK)
- }
- if !ok {
- return
- }
- if (got == nil) != (tt.wantValue == nil) {
- t.Errorf("parseTemperatureValue() = %v, want %v", got, tt.wantValue)
- return
- }
- if got != nil && tt.wantValue != nil && *got != *tt.wantValue {
- t.Errorf("parseTemperatureValue() = %v, want %v", *got, *tt.wantValue)
- }
- })
- }
-}
-
-func TestDecodeModelEntry(t *testing.T) {
- tests := []struct {
- name string
- input any
- wantCfg *SurfaceConfig
- wantOK bool
- }{
- {
- name: "simple string model",
- input: "gpt-4",
- wantCfg: &SurfaceConfig{Model: "gpt-4"},
- wantOK: true,
- },
- {
- name: "empty string",
- input: "",
- wantCfg: nil,
- wantOK: false,
- },
- {
- name: "whitespace string",
- input: " ",
- wantCfg: nil,
- wantOK: false,
- },
- {
- name: "object with all fields",
- input: map[string]any{
- "model": "claude-3",
- "provider": "anthropic",
- "temperature": float64(0.7),
- },
- wantCfg: &SurfaceConfig{
- Model: "claude-3",
- Provider: "anthropic",
- Temperature: floatPtr(0.7),
- },
- wantOK: true,
- },
- {
- name: "object with model only",
- input: map[string]any{
- "model": "gpt-4o",
- },
- wantCfg: &SurfaceConfig{Model: "gpt-4o"},
- wantOK: true,
- },
- {
- name: "object with provider only",
- input: map[string]any{
- "provider": "openai",
- },
- wantCfg: &SurfaceConfig{Provider: "openai"},
- wantOK: true,
- },
- {
- name: "object with temperature only",
- input: map[string]any{
- "temperature": float64(0.5),
- },
- wantCfg: &SurfaceConfig{Temperature: floatPtr(0.5)},
- wantOK: true,
- },
- {
- name: "object with empty fields",
- input: map[string]any{
- "model": "",
- "provider": "",
- },
- wantCfg: nil,
- wantOK: false,
- },
- {
- name: "object with invalid model type",
- input: map[string]any{
- "model": 123,
- },
- wantCfg: nil,
- wantOK: false,
- },
- {
- name: "object with invalid provider type",
- input: map[string]any{
- "provider": 456,
- },
- wantCfg: nil,
- wantOK: false,
- },
- {
- name: "object with invalid temperature",
- input: map[string]any{
- "model": "gpt-4",
- "temperature": "not a number",
- },
- wantCfg: nil,
- wantOK: false,
- },
- {
- name: "nil input",
- input: nil,
- wantCfg: nil,
- wantOK: false,
- },
- {
- name: "invalid type (int)",
- input: 123,
- wantCfg: nil,
- wantOK: false,
- },
- {
- name: "invalid type (slice)",
- input: []string{"gpt-4"},
- wantCfg: nil,
- wantOK: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, ok := decodeModelEntry(tt.input, "test", newLogger())
- if ok != tt.wantOK {
- t.Errorf("decodeModelEntry() ok = %v, want %v", ok, tt.wantOK)
- }
- if !ok {
- return
- }
- if (got == nil) != (tt.wantCfg == nil) {
- t.Errorf("decodeModelEntry() = %v, want %v", got, tt.wantCfg)
- return
- }
- if got == nil {
- return
- }
- if got.Model != tt.wantCfg.Model {
- t.Errorf("Model = %q, want %q", got.Model, tt.wantCfg.Model)
- }
- if got.Provider != tt.wantCfg.Provider {
- t.Errorf("Provider = %q, want %q", got.Provider, tt.wantCfg.Provider)
- }
- if (got.Temperature == nil) != (tt.wantCfg.Temperature == nil) {
- t.Errorf("Temperature nil mismatch: got %v, want %v", got.Temperature, tt.wantCfg.Temperature)
- } else if got.Temperature != nil && *got.Temperature != *tt.wantCfg.Temperature {
- t.Errorf("Temperature = %v, want %v", *got.Temperature, *tt.wantCfg.Temperature)
- }
- })
- }
-}
-
-func TestResolvedModel(t *testing.T) {
- tests := []struct {
- name string
- section sectionOpenAI
- want string
- }{
- {
- name: "explicit model no presets",
- section: sectionOpenAI{Model: "gpt-4"},
- want: "gpt-4",
- },
- {
- name: "empty model",
- section: sectionOpenAI{Model: ""},
- want: "",
- },
- {
- name: "whitespace model",
- section: sectionOpenAI{Model: " "},
- want: "",
- },
- {
- name: "preset match exact case",
- section: sectionOpenAI{
- Model: "fast",
- Presets: map[string]string{"fast": "gpt-3.5-turbo"},
- },
- want: "gpt-3.5-turbo",
- },
- {
- name: "preset match case insensitive",
- section: sectionOpenAI{
- Model: "FAST",
- Presets: map[string]string{"fast": "gpt-3.5-turbo"},
- },
- want: "gpt-3.5-turbo",
- },
- {
- name: "no preset match returns original",
- section: sectionOpenAI{
- Model: "custom-model",
- Presets: map[string]string{"fast": "gpt-3.5-turbo"},
- },
- want: "custom-model",
- },
- {
- name: "preset empty value returns original",
- section: sectionOpenAI{
- Model: "fast",
- Presets: map[string]string{"fast": ""},
- },
- want: "fast",
- },
- {
- name: "preset whitespace value returns original",
- section: sectionOpenAI{
- Model: "fast",
- Presets: map[string]string{"fast": " "},
- },
- want: "fast",
- },
- {
- name: "multiple presets uses correct one",
- section: sectionOpenAI{
- Model: "smart",
- Presets: map[string]string{
- "fast": "gpt-3.5-turbo",
- "smart": "gpt-4",
- "mini": "gpt-3.5-mini",
- },
- },
- want: "gpt-4",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got := tt.section.resolvedModel()
- if got != tt.want {
- t.Errorf("resolvedModel() = %q, want %q", got, tt.want)
- }
- })
- }
-}
-
-func TestParseSurfaceEntries(t *testing.T) {
- tests := []struct {
- name string
- input any
- wantLen int
- wantOK bool
- }{
- {
- name: "nil input",
- input: nil,
- wantLen: 0,
- wantOK: false,
- },
- {
- name: "single string",
- input: "gpt-4",
- wantLen: 1,
- wantOK: true,
- },
- {
- name: "single map",
- input: map[string]any{
- "model": "claude-3",
- "provider": "anthropic",
- },
- wantLen: 1,
- wantOK: true,
- },
- {
- name: "array of strings",
- input: []any{
- "gpt-4",
- "claude-3",
- },
- wantLen: 2,
- wantOK: true,
- },
- {
- name: "array of maps",
- input: []any{
- map[string]any{"model": "gpt-4", "provider": "openai"},
- map[string]any{"model": "claude-3", "provider": "anthropic"},
- },
- wantLen: 2,
- wantOK: true,
- },
- {
- name: "array with invalid entries",
- input: []any{
- "gpt-4",
- 123,
- "claude-3",
- },
- wantLen: 2,
- wantOK: true,
- },
- {
- name: "array with all invalid entries",
- input: []any{
- 123,
- true,
- nil,
- },
- wantLen: 0,
- wantOK: false,
- },
- {
- name: "empty array",
- input: []any{},
- wantLen: 0,
- wantOK: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, ok := parseSurfaceEntries(tt.input, "test", newLogger())
- if ok != tt.wantOK {
- t.Errorf("parseSurfaceEntries() ok = %v, want %v", ok, tt.wantOK)
- }
- if len(got) != tt.wantLen {
- t.Errorf("parseSurfaceEntries() len = %d, want %d", len(got), tt.wantLen)
- }
- })
- }
-}
diff --git a/internal/mcp/handlers_test.go b/internal/mcp/handlers_prompt_test.go
index 2a4f821..ad0d261 100644
--- a/internal/mcp/handlers_test.go
+++ b/internal/mcp/handlers_prompt_test.go
@@ -1,4 +1,4 @@
-// Summary: Tests for MCP prompt management handlers
+// Summary: Tests for MCP prompt management handlers (create, update, delete, get, list)
package mcp
import (
@@ -237,7 +237,7 @@ func TestServer_PromptsCreate_MissingName(t *testing.T) {
}
}
-// Update mockPromptStore to support Create, Update, Delete
+// mockPromptStore Create, Update, Delete methods used by prompt handler tests.
func (m *mockPromptStore) Create(prompt *promptstore.Prompt) error {
if _, exists := m.prompts[prompt.Name]; exists {
return fmt.Errorf("prompt already exists: %s", prompt.Name)
@@ -953,698 +953,3 @@ func TestServer_HandleInitialize_InvalidParams(t *testing.T) {
t.Fatal("Expected error for invalid params")
}
}
-
-// ==================== Tools Tests ====================
-
-func TestServer_ToolsList(t *testing.T) {
- store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
- server, _, outBuf := createTestServer(t, store)
-
- // Initialize server
- server.mu.Lock()
- server.initialized = true
- server.mu.Unlock()
-
- // Send tools/list request
- req := Request{
- JSONRPC: "2.0",
- ID: 40,
- Method: "tools/list",
- Params: json.RawMessage(`{}`),
- }
-
- server.handle(req)
-
- // Read response
- resp, err := readResponse(outBuf)
- if err != nil {
- t.Fatalf("readResponse() error = %v", err)
- }
-
- if resp.Error != nil {
- t.Fatalf("Error = %v, want nil", resp.Error)
- }
-
- // Parse result
- resultBytes, _ := json.Marshal(resp.Result)
- var result ListToolsResult
- if err := json.Unmarshal(resultBytes, &result); err != nil {
- t.Fatalf("Unmarshal result error = %v", err)
- }
-
- // Verify 3 tools returned
- if len(result.Tools) != 3 {
- t.Errorf("Tools count = %d, want 3", len(result.Tools))
- }
-
- // Verify tool names
- toolNames := make(map[string]bool)
- for _, tool := range result.Tools {
- toolNames[tool.Name] = true
- if tool.Description == "" {
- t.Errorf("Tool %s has empty description", tool.Name)
- }
- if tool.InputSchema == nil {
- t.Errorf("Tool %s has nil InputSchema", tool.Name)
- }
- }
-
- expectedTools := []string{"create_prompt", "update_prompt", "delete_prompt"}
- for _, name := range expectedTools {
- if !toolNames[name] {
- t.Errorf("Missing expected tool: %s", name)
- }
- }
-}
-
-func TestServer_ToolsCall_CreatePrompt(t *testing.T) {
- store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
- server, _, outBuf := createTestServer(t, store)
-
- // Initialize server
- server.mu.Lock()
- server.initialized = true
- server.mu.Unlock()
-
- // Send tools/call request
- params := CallToolRequest{
- Name: "create_prompt",
- Arguments: map[string]interface{}{
- "name": "tool_test",
- "title": "Tool Test Prompt",
- "messages": []interface{}{
- map[string]interface{}{
- "role": "user",
- "content": map[string]interface{}{
- "type": "text",
- "text": "Test message",
- },
- },
- },
- },
- }
- paramsBytes, _ := json.Marshal(params)
- req := Request{
- JSONRPC: "2.0",
- ID: 41,
- Method: "tools/call",
- Params: paramsBytes,
- }
-
- server.handle(req)
-
- // Read response
- resp, err := readResponse(outBuf)
- if err != nil {
- t.Fatalf("readResponse() error = %v", err)
- }
-
- if resp.Error != nil {
- t.Fatalf("Error = %v, want nil", resp.Error)
- }
-
- // Parse result
- resultBytes, _ := json.Marshal(resp.Result)
- var result CallToolResult
- if err := json.Unmarshal(resultBytes, &result); err != nil {
- t.Fatalf("Unmarshal result error = %v", err)
- }
-
- if result.IsError {
- t.Errorf("IsError = true, want false. Content: %v", result.Content)
- }
-
- if len(result.Content) == 0 {
- t.Fatal("Expected content in result")
- }
-
- // Verify prompt was created
- if _, exists := store.prompts["tool_test"]; !exists {
- t.Error("Prompt was not created in store")
- }
-}
-
-func TestServer_ToolsCall_UpdatePrompt(t *testing.T) {
- now := time.Now()
- store := &mockPromptStore{
- prompts: map[string]*promptstore.Prompt{
- "tool_update": {
- Name: "tool_update",
- Title: "Original Title",
- Created: now,
- Updated: now,
- Messages: []promptstore.PromptMessage{
- {
- Role: "user",
- Content: promptstore.MessageContent{
- Type: "text",
- Text: "Original",
- },
- },
- },
- },
- },
- }
-
- server, _, outBuf := createTestServer(t, store)
-
- // Initialize server
- server.mu.Lock()
- server.initialized = true
- server.mu.Unlock()
-
- // Send tools/call request
- params := CallToolRequest{
- Name: "update_prompt",
- Arguments: map[string]interface{}{
- "name": "tool_update",
- "title": "Updated Title",
- },
- }
- paramsBytes, _ := json.Marshal(params)
- req := Request{
- JSONRPC: "2.0",
- ID: 42,
- Method: "tools/call",
- Params: paramsBytes,
- }
-
- server.handle(req)
-
- // Read response
- resp, err := readResponse(outBuf)
- if err != nil {
- t.Fatalf("readResponse() error = %v", err)
- }
-
- if resp.Error != nil {
- t.Fatalf("Error = %v, want nil", resp.Error)
- }
-
- // Parse result
- resultBytes, _ := json.Marshal(resp.Result)
- var result CallToolResult
- if err := json.Unmarshal(resultBytes, &result); err != nil {
- t.Fatalf("Unmarshal result error = %v", err)
- }
-
- if result.IsError {
- t.Errorf("IsError = true, want false. Content: %v", result.Content)
- }
-
- // Verify prompt was updated
- if store.prompts["tool_update"].Title != "Updated Title" {
- t.Errorf("Title not updated, got %s", store.prompts["tool_update"].Title)
- }
-}
-
-func TestServer_ToolsCall_DeletePrompt(t *testing.T) {
- now := time.Now()
- store := &mockPromptStore{
- prompts: map[string]*promptstore.Prompt{
- "tool_delete": {
- Name: "tool_delete",
- Title: "To Delete",
- Created: now,
- Updated: now,
- Messages: []promptstore.PromptMessage{
- {
- Role: "user",
- Content: promptstore.MessageContent{
- Type: "text",
- Text: "Test",
- },
- },
- },
- },
- },
- }
-
- server, _, outBuf := createTestServer(t, store)
-
- // Initialize server
- server.mu.Lock()
- server.initialized = true
- server.mu.Unlock()
-
- // Send tools/call request
- params := CallToolRequest{
- Name: "delete_prompt",
- Arguments: map[string]interface{}{
- "name": "tool_delete",
- },
- }
- paramsBytes, _ := json.Marshal(params)
- req := Request{
- JSONRPC: "2.0",
- ID: 43,
- Method: "tools/call",
- Params: paramsBytes,
- }
-
- server.handle(req)
-
- // Read response
- resp, err := readResponse(outBuf)
- if err != nil {
- t.Fatalf("readResponse() error = %v", err)
- }
-
- if resp.Error != nil {
- t.Fatalf("Error = %v, want nil", resp.Error)
- }
-
- // Parse result
- resultBytes, _ := json.Marshal(resp.Result)
- var result CallToolResult
- if err := json.Unmarshal(resultBytes, &result); err != nil {
- t.Fatalf("Unmarshal result error = %v", err)
- }
-
- if result.IsError {
- t.Errorf("IsError = true, want false. Content: %v", result.Content)
- }
-
- // Verify prompt was deleted
- if _, exists := store.prompts["tool_delete"]; exists {
- t.Error("Prompt was not deleted from store")
- }
-}
-
-func TestServer_ToolsCall_UnknownTool(t *testing.T) {
- store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
- server, _, outBuf := createTestServer(t, store)
-
- // Initialize server
- server.mu.Lock()
- server.initialized = true
- server.mu.Unlock()
-
- // Send tools/call request with unknown tool
- params := CallToolRequest{
- Name: "nonexistent_tool",
- Arguments: map[string]interface{}{},
- }
- paramsBytes, _ := json.Marshal(params)
- req := Request{
- JSONRPC: "2.0",
- ID: 44,
- Method: "tools/call",
- Params: paramsBytes,
- }
-
- server.handle(req)
-
- // Read response
- resp, err := readResponse(outBuf)
- if err != nil {
- t.Fatalf("readResponse() error = %v", err)
- }
-
- // Should not be a protocol error
- if resp.Error != nil {
- t.Fatalf("Unexpected protocol error: %v", resp.Error)
- }
-
- // Parse result - should be a tool error
- resultBytes, _ := json.Marshal(resp.Result)
- var result CallToolResult
- if err := json.Unmarshal(resultBytes, &result); err != nil {
- t.Fatalf("Unmarshal result error = %v", err)
- }
-
- if !result.IsError {
- t.Error("Expected IsError = true for unknown tool")
- }
-}
-
-func TestServer_ToolsCall_InvalidArguments(t *testing.T) {
- store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
- server, _, outBuf := createTestServer(t, store)
-
- // Initialize server
- server.mu.Lock()
- server.initialized = true
- server.mu.Unlock()
-
- // Send tools/call with invalid arguments (missing required fields)
- params := CallToolRequest{
- Name: "create_prompt",
- Arguments: map[string]interface{}{
- "name": "test", // Missing title and messages
- },
- }
- paramsBytes, _ := json.Marshal(params)
- req := Request{
- JSONRPC: "2.0",
- ID: 45,
- Method: "tools/call",
- Params: paramsBytes,
- }
-
- server.handle(req)
-
- // Read response
- resp, err := readResponse(outBuf)
- if err != nil {
- t.Fatalf("readResponse() error = %v", err)
- }
-
- // Should not be a protocol error
- if resp.Error != nil {
- t.Fatalf("Unexpected protocol error: %v", resp.Error)
- }
-
- // Parse result - should be a tool error
- resultBytes, _ := json.Marshal(resp.Result)
- var result CallToolResult
- if err := json.Unmarshal(resultBytes, &result); err != nil {
- t.Fatalf("Unmarshal result error = %v", err)
- }
-
- if !result.IsError {
- t.Error("Expected IsError = true for invalid arguments")
- }
-}
-
-func TestServer_ToolsCall_NotInitialized(t *testing.T) {
- store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
- server, _, outBuf := createTestServer(t, store)
-
- // DO NOT initialize server
-
- // Send tools/call request
- params := CallToolRequest{
- Name: "create_prompt",
- Arguments: map[string]interface{}{},
- }
- paramsBytes, _ := json.Marshal(params)
- req := Request{
- JSONRPC: "2.0",
- ID: 46,
- Method: "tools/call",
- Params: paramsBytes,
- }
-
- server.handle(req)
-
- // Read response
- resp, err := readResponse(outBuf)
- if err != nil {
- t.Fatalf("readResponse() error = %v", err)
- }
-
- if resp.Error == nil {
- t.Fatal("Expected error for uninitialized server")
- }
-
- if resp.Error.Code != ErrCodeInvalidRequest {
- t.Errorf("Error code = %d, want %d", resp.Error.Code, ErrCodeInvalidRequest)
- }
-}
-
-func TestServer_ToolsCall_CreatePrompt_AlreadyExists(t *testing.T) {
- now := time.Now()
- store := &mockPromptStore{
- prompts: map[string]*promptstore.Prompt{
- "existing": {
- Name: "existing",
- Title: "Existing",
- Created: now,
- Updated: now,
- Messages: []promptstore.PromptMessage{},
- },
- },
- }
-
- server, _, outBuf := createTestServer(t, store)
-
- // Initialize server
- server.mu.Lock()
- server.initialized = true
- server.mu.Unlock()
-
- // Send tools/call to create duplicate
- params := CallToolRequest{
- Name: "create_prompt",
- Arguments: map[string]interface{}{
- "name": "existing",
- "title": "Duplicate",
- "messages": []interface{}{
- map[string]interface{}{
- "role": "user",
- "content": map[string]interface{}{
- "type": "text",
- "text": "Test",
- },
- },
- },
- },
- }
- paramsBytes, _ := json.Marshal(params)
- req := Request{
- JSONRPC: "2.0",
- ID: 47,
- Method: "tools/call",
- Params: paramsBytes,
- }
-
- server.handle(req)
-
- // Read response
- resp, err := readResponse(outBuf)
- if err != nil {
- t.Fatalf("readResponse() error = %v", err)
- }
-
- // Should not be a protocol error
- if resp.Error != nil {
- t.Fatalf("Unexpected protocol error: %v", resp.Error)
- }
-
- // Parse result - should be a tool error
- resultBytes, _ := json.Marshal(resp.Result)
- var result CallToolResult
- if err := json.Unmarshal(resultBytes, &result); err != nil {
- t.Fatalf("Unmarshal result error = %v", err)
- }
-
- if !result.IsError {
- t.Error("Expected IsError = true for duplicate prompt")
- }
-}
-
-func TestServer_Initialize_AdvertisesTools(t *testing.T) {
- store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
- server, _, outBuf := createTestServer(t, store)
-
- // Send initialize request
- params := InitializeRequest{
- ProtocolVersion: "2025-11-25",
- Capabilities: ClientCapabilities{},
- ClientInfo: ClientInfo{
- Name: "test-client",
- Version: "1.0",
- },
- }
- paramsBytes, _ := json.Marshal(params)
- req := Request{
- JSONRPC: "2.0",
- ID: 48,
- Method: "initialize",
- Params: paramsBytes,
- }
-
- server.handle(req)
-
- // Read response
- resp, err := readResponse(outBuf)
- if err != nil {
- t.Fatalf("readResponse() error = %v", err)
- }
-
- if resp.Error != nil {
- t.Fatalf("Error = %v, want nil", resp.Error)
- }
-
- // Parse result
- resultBytes, _ := json.Marshal(resp.Result)
- var result InitializeResult
- if err := json.Unmarshal(resultBytes, &result); err != nil {
- t.Fatalf("Unmarshal result error = %v", err)
- }
-
- // Verify Tools capability is advertised
- if result.Capabilities.Tools == nil {
- t.Fatal("Tools capability not advertised")
- }
-
- // Verify Prompts capability is also still advertised
- if result.Capabilities.Prompts == nil {
- t.Fatal("Prompts capability not advertised")
- }
-}
-
-func TestServer_ToolsList_NotInitialized(t *testing.T) {
- store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
- server, _, outBuf := createTestServer(t, store)
-
- // DO NOT initialize server
-
- // Send tools/list request
- req := Request{
- JSONRPC: "2.0",
- ID: 49,
- Method: "tools/list",
- Params: json.RawMessage(`{}`),
- }
-
- server.handle(req)
-
- // Read response
- resp, err := readResponse(outBuf)
- if err != nil {
- t.Fatalf("readResponse() error = %v", err)
- }
-
- if resp.Error == nil {
- t.Fatal("Expected error for uninitialized server")
- }
-
- if resp.Error.Code != ErrCodeInvalidRequest {
- t.Errorf("Error code = %d, want %d", resp.Error.Code, ErrCodeInvalidRequest)
- }
-}
-
-func TestServer_ToolsCall_UpdatePrompt_NotFound(t *testing.T) {
- store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
- server, _, outBuf := createTestServer(t, store)
-
- // Initialize server
- server.mu.Lock()
- server.initialized = true
- server.mu.Unlock()
-
- // Send tools/call to update non-existent prompt
- params := CallToolRequest{
- Name: "update_prompt",
- Arguments: map[string]interface{}{
- "name": "nonexistent",
- "title": "Updated",
- },
- }
- paramsBytes, _ := json.Marshal(params)
- req := Request{
- JSONRPC: "2.0",
- ID: 50,
- Method: "tools/call",
- Params: paramsBytes,
- }
-
- server.handle(req)
-
- // Read response
- resp, err := readResponse(outBuf)
- if err != nil {
- t.Fatalf("readResponse() error = %v", err)
- }
-
- // Should not be a protocol error
- if resp.Error != nil {
- t.Fatalf("Unexpected protocol error: %v", resp.Error)
- }
-
- // Parse result - should be a tool error
- resultBytes, _ := json.Marshal(resp.Result)
- var result CallToolResult
- if err := json.Unmarshal(resultBytes, &result); err != nil {
- t.Fatalf("Unmarshal result error = %v", err)
- }
-
- if !result.IsError {
- t.Error("Expected IsError = true for non-existent prompt")
- }
-}
-
-func TestServer_ToolsCall_DeletePrompt_NotFound(t *testing.T) {
- store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
- server, _, outBuf := createTestServer(t, store)
-
- // Initialize server
- server.mu.Lock()
- server.initialized = true
- server.mu.Unlock()
-
- // Send tools/call to delete non-existent prompt
- params := CallToolRequest{
- Name: "delete_prompt",
- Arguments: map[string]interface{}{
- "name": "nonexistent",
- },
- }
- paramsBytes, _ := json.Marshal(params)
- req := Request{
- JSONRPC: "2.0",
- ID: 51,
- Method: "tools/call",
- Params: paramsBytes,
- }
-
- server.handle(req)
-
- // Read response
- resp, err := readResponse(outBuf)
- if err != nil {
- t.Fatalf("readResponse() error = %v", err)
- }
-
- // Should not be a protocol error
- if resp.Error != nil {
- t.Fatalf("Unexpected protocol error: %v", resp.Error)
- }
-
- // Parse result - should be a tool error
- resultBytes, _ := json.Marshal(resp.Result)
- var result CallToolResult
- if err := json.Unmarshal(resultBytes, &result); err != nil {
- t.Fatalf("Unmarshal result error = %v", err)
- }
-
- if !result.IsError {
- t.Error("Expected IsError = true for non-existent prompt")
- }
-}
-
-func TestServer_ToolsCall_InvalidParams(t *testing.T) {
- store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
- server, _, outBuf := createTestServer(t, store)
-
- // Initialize server
- server.mu.Lock()
- server.initialized = true
- server.mu.Unlock()
-
- // Send tools/call with invalid JSON params
- req := Request{
- JSONRPC: "2.0",
- ID: 52,
- Method: "tools/call",
- Params: json.RawMessage(`{invalid}`),
- }
-
- server.handle(req)
-
- // Read response
- resp, err := readResponse(outBuf)
- if err != nil {
- t.Fatalf("readResponse() error = %v", err)
- }
-
- if resp.Error == nil {
- t.Fatal("Expected protocol error for invalid params")
- }
-
- if resp.Error.Code != ErrCodeInvalidParams {
- t.Errorf("Error code = %d, want %d", resp.Error.Code, ErrCodeInvalidParams)
- }
-}
diff --git a/internal/mcp/handlers_tool_test.go b/internal/mcp/handlers_tool_test.go
new file mode 100644
index 0000000..604938a
--- /dev/null
+++ b/internal/mcp/handlers_tool_test.go
@@ -0,0 +1,703 @@
+// Summary: Tests for MCP tool handlers (tools/list, tools/call for create/update/delete)
+package mcp
+
+import (
+ "encoding/json"
+ "testing"
+ "time"
+
+ "codeberg.org/snonux/hexai/internal/promptstore"
+)
+
+func TestServer_ToolsList(t *testing.T) {
+ store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
+ server, _, outBuf := createTestServer(t, store)
+
+ // Initialize server
+ server.mu.Lock()
+ server.initialized = true
+ server.mu.Unlock()
+
+ // Send tools/list request
+ req := Request{
+ JSONRPC: "2.0",
+ ID: 40,
+ Method: "tools/list",
+ Params: json.RawMessage(`{}`),
+ }
+
+ server.handle(req)
+
+ // Read response
+ resp, err := readResponse(outBuf)
+ if err != nil {
+ t.Fatalf("readResponse() error = %v", err)
+ }
+
+ if resp.Error != nil {
+ t.Fatalf("Error = %v, want nil", resp.Error)
+ }
+
+ // Parse result
+ resultBytes, _ := json.Marshal(resp.Result)
+ var result ListToolsResult
+ if err := json.Unmarshal(resultBytes, &result); err != nil {
+ t.Fatalf("Unmarshal result error = %v", err)
+ }
+
+ // Verify 3 tools returned
+ if len(result.Tools) != 3 {
+ t.Errorf("Tools count = %d, want 3", len(result.Tools))
+ }
+
+ // Verify tool names
+ toolNames := make(map[string]bool)
+ for _, tool := range result.Tools {
+ toolNames[tool.Name] = true
+ if tool.Description == "" {
+ t.Errorf("Tool %s has empty description", tool.Name)
+ }
+ if tool.InputSchema == nil {
+ t.Errorf("Tool %s has nil InputSchema", tool.Name)
+ }
+ }
+
+ expectedTools := []string{"create_prompt", "update_prompt", "delete_prompt"}
+ for _, name := range expectedTools {
+ if !toolNames[name] {
+ t.Errorf("Missing expected tool: %s", name)
+ }
+ }
+}
+
+func TestServer_ToolsCall_CreatePrompt(t *testing.T) {
+ store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
+ server, _, outBuf := createTestServer(t, store)
+
+ // Initialize server
+ server.mu.Lock()
+ server.initialized = true
+ server.mu.Unlock()
+
+ // Send tools/call request
+ params := CallToolRequest{
+ Name: "create_prompt",
+ Arguments: map[string]interface{}{
+ "name": "tool_test",
+ "title": "Tool Test Prompt",
+ "messages": []interface{}{
+ map[string]interface{}{
+ "role": "user",
+ "content": map[string]interface{}{
+ "type": "text",
+ "text": "Test message",
+ },
+ },
+ },
+ },
+ }
+ paramsBytes, _ := json.Marshal(params)
+ req := Request{
+ JSONRPC: "2.0",
+ ID: 41,
+ Method: "tools/call",
+ Params: paramsBytes,
+ }
+
+ server.handle(req)
+
+ // Read response
+ resp, err := readResponse(outBuf)
+ if err != nil {
+ t.Fatalf("readResponse() error = %v", err)
+ }
+
+ if resp.Error != nil {
+ t.Fatalf("Error = %v, want nil", resp.Error)
+ }
+
+ // Parse result
+ resultBytes, _ := json.Marshal(resp.Result)
+ var result CallToolResult
+ if err := json.Unmarshal(resultBytes, &result); err != nil {
+ t.Fatalf("Unmarshal result error = %v", err)
+ }
+
+ if result.IsError {
+ t.Errorf("IsError = true, want false. Content: %v", result.Content)
+ }
+
+ if len(result.Content) == 0 {
+ t.Fatal("Expected content in result")
+ }
+
+ // Verify prompt was created
+ if _, exists := store.prompts["tool_test"]; !exists {
+ t.Error("Prompt was not created in store")
+ }
+}
+
+func TestServer_ToolsCall_UpdatePrompt(t *testing.T) {
+ now := time.Now()
+ store := &mockPromptStore{
+ prompts: map[string]*promptstore.Prompt{
+ "tool_update": {
+ Name: "tool_update",
+ Title: "Original Title",
+ Created: now,
+ Updated: now,
+ Messages: []promptstore.PromptMessage{
+ {
+ Role: "user",
+ Content: promptstore.MessageContent{
+ Type: "text",
+ Text: "Original",
+ },
+ },
+ },
+ },
+ },
+ }
+
+ server, _, outBuf := createTestServer(t, store)
+
+ // Initialize server
+ server.mu.Lock()
+ server.initialized = true
+ server.mu.Unlock()
+
+ // Send tools/call request
+ params := CallToolRequest{
+ Name: "update_prompt",
+ Arguments: map[string]interface{}{
+ "name": "tool_update",
+ "title": "Updated Title",
+ },
+ }
+ paramsBytes, _ := json.Marshal(params)
+ req := Request{
+ JSONRPC: "2.0",
+ ID: 42,
+ Method: "tools/call",
+ Params: paramsBytes,
+ }
+
+ server.handle(req)
+
+ // Read response
+ resp, err := readResponse(outBuf)
+ if err != nil {
+ t.Fatalf("readResponse() error = %v", err)
+ }
+
+ if resp.Error != nil {
+ t.Fatalf("Error = %v, want nil", resp.Error)
+ }
+
+ // Parse result
+ resultBytes, _ := json.Marshal(resp.Result)
+ var result CallToolResult
+ if err := json.Unmarshal(resultBytes, &result); err != nil {
+ t.Fatalf("Unmarshal result error = %v", err)
+ }
+
+ if result.IsError {
+ t.Errorf("IsError = true, want false. Content: %v", result.Content)
+ }
+
+ // Verify prompt was updated
+ if store.prompts["tool_update"].Title != "Updated Title" {
+ t.Errorf("Title not updated, got %s", store.prompts["tool_update"].Title)
+ }
+}
+
+func TestServer_ToolsCall_DeletePrompt(t *testing.T) {
+ now := time.Now()
+ store := &mockPromptStore{
+ prompts: map[string]*promptstore.Prompt{
+ "tool_delete": {
+ Name: "tool_delete",
+ Title: "To Delete",
+ Created: now,
+ Updated: now,
+ Messages: []promptstore.PromptMessage{
+ {
+ Role: "user",
+ Content: promptstore.MessageContent{
+ Type: "text",
+ Text: "Test",
+ },
+ },
+ },
+ },
+ },
+ }
+
+ server, _, outBuf := createTestServer(t, store)
+
+ // Initialize server
+ server.mu.Lock()
+ server.initialized = true
+ server.mu.Unlock()
+
+ // Send tools/call request
+ params := CallToolRequest{
+ Name: "delete_prompt",
+ Arguments: map[string]interface{}{
+ "name": "tool_delete",
+ },
+ }
+ paramsBytes, _ := json.Marshal(params)
+ req := Request{
+ JSONRPC: "2.0",
+ ID: 43,
+ Method: "tools/call",
+ Params: paramsBytes,
+ }
+
+ server.handle(req)
+
+ // Read response
+ resp, err := readResponse(outBuf)
+ if err != nil {
+ t.Fatalf("readResponse() error = %v", err)
+ }
+
+ if resp.Error != nil {
+ t.Fatalf("Error = %v, want nil", resp.Error)
+ }
+
+ // Parse result
+ resultBytes, _ := json.Marshal(resp.Result)
+ var result CallToolResult
+ if err := json.Unmarshal(resultBytes, &result); err != nil {
+ t.Fatalf("Unmarshal result error = %v", err)
+ }
+
+ if result.IsError {
+ t.Errorf("IsError = true, want false. Content: %v", result.Content)
+ }
+
+ // Verify prompt was deleted
+ if _, exists := store.prompts["tool_delete"]; exists {
+ t.Error("Prompt was not deleted from store")
+ }
+}
+
+func TestServer_ToolsCall_UnknownTool(t *testing.T) {
+ store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
+ server, _, outBuf := createTestServer(t, store)
+
+ // Initialize server
+ server.mu.Lock()
+ server.initialized = true
+ server.mu.Unlock()
+
+ // Send tools/call request with unknown tool
+ params := CallToolRequest{
+ Name: "nonexistent_tool",
+ Arguments: map[string]interface{}{},
+ }
+ paramsBytes, _ := json.Marshal(params)
+ req := Request{
+ JSONRPC: "2.0",
+ ID: 44,
+ Method: "tools/call",
+ Params: paramsBytes,
+ }
+
+ server.handle(req)
+
+ // Read response
+ resp, err := readResponse(outBuf)
+ if err != nil {
+ t.Fatalf("readResponse() error = %v", err)
+ }
+
+ // Should not be a protocol error
+ if resp.Error != nil {
+ t.Fatalf("Unexpected protocol error: %v", resp.Error)
+ }
+
+ // Parse result - should be a tool error
+ resultBytes, _ := json.Marshal(resp.Result)
+ var result CallToolResult
+ if err := json.Unmarshal(resultBytes, &result); err != nil {
+ t.Fatalf("Unmarshal result error = %v", err)
+ }
+
+ if !result.IsError {
+ t.Error("Expected IsError = true for unknown tool")
+ }
+}
+
+func TestServer_ToolsCall_InvalidArguments(t *testing.T) {
+ store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
+ server, _, outBuf := createTestServer(t, store)
+
+ // Initialize server
+ server.mu.Lock()
+ server.initialized = true
+ server.mu.Unlock()
+
+ // Send tools/call with invalid arguments (missing required fields)
+ params := CallToolRequest{
+ Name: "create_prompt",
+ Arguments: map[string]interface{}{
+ "name": "test", // Missing title and messages
+ },
+ }
+ paramsBytes, _ := json.Marshal(params)
+ req := Request{
+ JSONRPC: "2.0",
+ ID: 45,
+ Method: "tools/call",
+ Params: paramsBytes,
+ }
+
+ server.handle(req)
+
+ // Read response
+ resp, err := readResponse(outBuf)
+ if err != nil {
+ t.Fatalf("readResponse() error = %v", err)
+ }
+
+ // Should not be a protocol error
+ if resp.Error != nil {
+ t.Fatalf("Unexpected protocol error: %v", resp.Error)
+ }
+
+ // Parse result - should be a tool error
+ resultBytes, _ := json.Marshal(resp.Result)
+ var result CallToolResult
+ if err := json.Unmarshal(resultBytes, &result); err != nil {
+ t.Fatalf("Unmarshal result error = %v", err)
+ }
+
+ if !result.IsError {
+ t.Error("Expected IsError = true for invalid arguments")
+ }
+}
+
+func TestServer_ToolsCall_NotInitialized(t *testing.T) {
+ store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
+ server, _, outBuf := createTestServer(t, store)
+
+ // DO NOT initialize server
+
+ // Send tools/call request
+ params := CallToolRequest{
+ Name: "create_prompt",
+ Arguments: map[string]interface{}{},
+ }
+ paramsBytes, _ := json.Marshal(params)
+ req := Request{
+ JSONRPC: "2.0",
+ ID: 46,
+ Method: "tools/call",
+ Params: paramsBytes,
+ }
+
+ server.handle(req)
+
+ // Read response
+ resp, err := readResponse(outBuf)
+ if err != nil {
+ t.Fatalf("readResponse() error = %v", err)
+ }
+
+ if resp.Error == nil {
+ t.Fatal("Expected error for uninitialized server")
+ }
+
+ if resp.Error.Code != ErrCodeInvalidRequest {
+ t.Errorf("Error code = %d, want %d", resp.Error.Code, ErrCodeInvalidRequest)
+ }
+}
+
+func TestServer_ToolsCall_CreatePrompt_AlreadyExists(t *testing.T) {
+ now := time.Now()
+ store := &mockPromptStore{
+ prompts: map[string]*promptstore.Prompt{
+ "existing": {
+ Name: "existing",
+ Title: "Existing",
+ Created: now,
+ Updated: now,
+ Messages: []promptstore.PromptMessage{},
+ },
+ },
+ }
+
+ server, _, outBuf := createTestServer(t, store)
+
+ // Initialize server
+ server.mu.Lock()
+ server.initialized = true
+ server.mu.Unlock()
+
+ // Send tools/call to create duplicate
+ params := CallToolRequest{
+ Name: "create_prompt",
+ Arguments: map[string]interface{}{
+ "name": "existing",
+ "title": "Duplicate",
+ "messages": []interface{}{
+ map[string]interface{}{
+ "role": "user",
+ "content": map[string]interface{}{
+ "type": "text",
+ "text": "Test",
+ },
+ },
+ },
+ },
+ }
+ paramsBytes, _ := json.Marshal(params)
+ req := Request{
+ JSONRPC: "2.0",
+ ID: 47,
+ Method: "tools/call",
+ Params: paramsBytes,
+ }
+
+ server.handle(req)
+
+ // Read response
+ resp, err := readResponse(outBuf)
+ if err != nil {
+ t.Fatalf("readResponse() error = %v", err)
+ }
+
+ // Should not be a protocol error
+ if resp.Error != nil {
+ t.Fatalf("Unexpected protocol error: %v", resp.Error)
+ }
+
+ // Parse result - should be a tool error
+ resultBytes, _ := json.Marshal(resp.Result)
+ var result CallToolResult
+ if err := json.Unmarshal(resultBytes, &result); err != nil {
+ t.Fatalf("Unmarshal result error = %v", err)
+ }
+
+ if !result.IsError {
+ t.Error("Expected IsError = true for duplicate prompt")
+ }
+}
+
+func TestServer_Initialize_AdvertisesTools(t *testing.T) {
+ store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
+ server, _, outBuf := createTestServer(t, store)
+
+ // Send initialize request
+ params := InitializeRequest{
+ ProtocolVersion: "2025-11-25",
+ Capabilities: ClientCapabilities{},
+ ClientInfo: ClientInfo{
+ Name: "test-client",
+ Version: "1.0",
+ },
+ }
+ paramsBytes, _ := json.Marshal(params)
+ req := Request{
+ JSONRPC: "2.0",
+ ID: 48,
+ Method: "initialize",
+ Params: paramsBytes,
+ }
+
+ server.handle(req)
+
+ // Read response
+ resp, err := readResponse(outBuf)
+ if err != nil {
+ t.Fatalf("readResponse() error = %v", err)
+ }
+
+ if resp.Error != nil {
+ t.Fatalf("Error = %v, want nil", resp.Error)
+ }
+
+ // Parse result
+ resultBytes, _ := json.Marshal(resp.Result)
+ var result InitializeResult
+ if err := json.Unmarshal(resultBytes, &result); err != nil {
+ t.Fatalf("Unmarshal result error = %v", err)
+ }
+
+ // Verify Tools capability is advertised
+ if result.Capabilities.Tools == nil {
+ t.Fatal("Tools capability not advertised")
+ }
+
+ // Verify Prompts capability is also still advertised
+ if result.Capabilities.Prompts == nil {
+ t.Fatal("Prompts capability not advertised")
+ }
+}
+
+func TestServer_ToolsList_NotInitialized(t *testing.T) {
+ store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
+ server, _, outBuf := createTestServer(t, store)
+
+ // DO NOT initialize server
+
+ // Send tools/list request
+ req := Request{
+ JSONRPC: "2.0",
+ ID: 49,
+ Method: "tools/list",
+ Params: json.RawMessage(`{}`),
+ }
+
+ server.handle(req)
+
+ // Read response
+ resp, err := readResponse(outBuf)
+ if err != nil {
+ t.Fatalf("readResponse() error = %v", err)
+ }
+
+ if resp.Error == nil {
+ t.Fatal("Expected error for uninitialized server")
+ }
+
+ if resp.Error.Code != ErrCodeInvalidRequest {
+ t.Errorf("Error code = %d, want %d", resp.Error.Code, ErrCodeInvalidRequest)
+ }
+}
+
+func TestServer_ToolsCall_UpdatePrompt_NotFound(t *testing.T) {
+ store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
+ server, _, outBuf := createTestServer(t, store)
+
+ // Initialize server
+ server.mu.Lock()
+ server.initialized = true
+ server.mu.Unlock()
+
+ // Send tools/call to update non-existent prompt
+ params := CallToolRequest{
+ Name: "update_prompt",
+ Arguments: map[string]interface{}{
+ "name": "nonexistent",
+ "title": "Updated",
+ },
+ }
+ paramsBytes, _ := json.Marshal(params)
+ req := Request{
+ JSONRPC: "2.0",
+ ID: 50,
+ Method: "tools/call",
+ Params: paramsBytes,
+ }
+
+ server.handle(req)
+
+ // Read response
+ resp, err := readResponse(outBuf)
+ if err != nil {
+ t.Fatalf("readResponse() error = %v", err)
+ }
+
+ // Should not be a protocol error
+ if resp.Error != nil {
+ t.Fatalf("Unexpected protocol error: %v", resp.Error)
+ }
+
+ // Parse result - should be a tool error
+ resultBytes, _ := json.Marshal(resp.Result)
+ var result CallToolResult
+ if err := json.Unmarshal(resultBytes, &result); err != nil {
+ t.Fatalf("Unmarshal result error = %v", err)
+ }
+
+ if !result.IsError {
+ t.Error("Expected IsError = true for non-existent prompt")
+ }
+}
+
+func TestServer_ToolsCall_DeletePrompt_NotFound(t *testing.T) {
+ store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
+ server, _, outBuf := createTestServer(t, store)
+
+ // Initialize server
+ server.mu.Lock()
+ server.initialized = true
+ server.mu.Unlock()
+
+ // Send tools/call to delete non-existent prompt
+ params := CallToolRequest{
+ Name: "delete_prompt",
+ Arguments: map[string]interface{}{
+ "name": "nonexistent",
+ },
+ }
+ paramsBytes, _ := json.Marshal(params)
+ req := Request{
+ JSONRPC: "2.0",
+ ID: 51,
+ Method: "tools/call",
+ Params: paramsBytes,
+ }
+
+ server.handle(req)
+
+ // Read response
+ resp, err := readResponse(outBuf)
+ if err != nil {
+ t.Fatalf("readResponse() error = %v", err)
+ }
+
+ // Should not be a protocol error
+ if resp.Error != nil {
+ t.Fatalf("Unexpected protocol error: %v", resp.Error)
+ }
+
+ // Parse result - should be a tool error
+ resultBytes, _ := json.Marshal(resp.Result)
+ var result CallToolResult
+ if err := json.Unmarshal(resultBytes, &result); err != nil {
+ t.Fatalf("Unmarshal result error = %v", err)
+ }
+
+ if !result.IsError {
+ t.Error("Expected IsError = true for non-existent prompt")
+ }
+}
+
+func TestServer_ToolsCall_InvalidParams(t *testing.T) {
+ store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
+ server, _, outBuf := createTestServer(t, store)
+
+ // Initialize server
+ server.mu.Lock()
+ server.initialized = true
+ server.mu.Unlock()
+
+ // Send tools/call with invalid JSON params
+ req := Request{
+ JSONRPC: "2.0",
+ ID: 52,
+ Method: "tools/call",
+ Params: json.RawMessage(`{invalid}`),
+ }
+
+ server.handle(req)
+
+ // Read response
+ resp, err := readResponse(outBuf)
+ if err != nil {
+ t.Fatalf("readResponse() error = %v", err)
+ }
+
+ if resp.Error == nil {
+ t.Fatal("Expected protocol error for invalid params")
+ }
+
+ if resp.Error.Code != ErrCodeInvalidParams {
+ t.Errorf("Error code = %d, want %d", resp.Error.Code, ErrCodeInvalidParams)
+ }
+}
diff --git a/internal/promptstore/default_prompts.go b/internal/promptstore/default_prompts.go
index f8edefc..434855a 100644
--- a/internal/promptstore/default_prompts.go
+++ b/internal/promptstore/default_prompts.go
@@ -6,34 +6,50 @@ import (
)
// DefaultPrompts returns the built-in meta-prompts for prompt management.
-// These prompts help users create, update, delete, and design prompts interactively
-// using Claude's access to conversation context.
+// It assembles prompts from category-specific helpers so each remains concise.
func DefaultPrompts() []Prompt {
now := time.Now()
-
return []Prompt{
- {
- Name: "save_prompt",
- Title: "Save Current Conversation as Prompt",
- Description: "Interactively create a new prompt template from the current conversation. Claude will analyze the conversation, ask clarifying questions about templating, show a preview, and wait for approval before saving.",
- Arguments: []PromptArgument{
- {
- Name: "prompt_name",
- Description: "Unique identifier for the new prompt (lowercase, underscores allowed)",
- Required: true,
- },
- {
- Name: "prompt_title",
- Description: "Human-readable display name for the new prompt",
- Required: true,
- },
- },
- Messages: []PromptMessage{
- {
- Role: "user",
- Content: MessageContent{
- Type: "text",
- Text: `I want to create a new prompt template named '{{prompt_name}}' with title '{{prompt_title}}'.
+ buildSavePrompt(now),
+ buildUpdatePrompt(now),
+ buildDeletePrompt(now),
+ buildDesignPrompt(now),
+ }
+}
+
+// formattingRules is the shared instruction block for clarifying-question formatting
+// used by the save and update meta-prompts.
+const formattingRules = `IMPORTANT FORMATTING RULES for clarifying questions:
+- Use numbered questions: 1), 2), 3)
+- ANY CHOICE MUST BE NUMBERED using combined format: 1a), 1b), 1c), 2a), 2b), etc.
+- NEVER use standalone letters like "a)" - always combine with question number
+- NEVER use dashes (-) or bullets (•) for options
+- Every option must be numbered for easy selection by the user`
+
+// buildSavePrompt creates the "save_prompt" meta-prompt that turns a conversation
+// into a reusable prompt template.
+func buildSavePrompt(now time.Time) Prompt {
+ return Prompt{
+ Name: "save_prompt",
+ Title: "Save Current Conversation as Prompt",
+ Description: "Interactively create a new prompt template from the current conversation. Claude will analyze the conversation, ask clarifying questions about templating, show a preview, and wait for approval before saving.",
+ Arguments: []PromptArgument{
+ {Name: "prompt_name", Description: "Unique identifier for the new prompt (lowercase, underscores allowed)", Required: true},
+ {Name: "prompt_title", Description: "Human-readable display name for the new prompt", Required: true},
+ },
+ Messages: []PromptMessage{{
+ Role: "user",
+ Content: MessageContent{Type: "text", Text: savePromptText()},
+ }},
+ Tags: []string{"meta", "prompt-management", "interactive"},
+ Created: now,
+ Updated: now,
+ }
+}
+
+// savePromptText returns the user message body for the save_prompt meta-prompt.
+func savePromptText() string {
+ return `I want to create a new prompt template named '{{prompt_name}}' with title '{{prompt_title}}'.
Please help me by:
1) Analyzing our current conversation to understand what should be templated
@@ -45,12 +61,7 @@ Please help me by:
3) Showing me a complete preview of the prompt structure in a code block
4) Only after I approve, use the create_prompt tool to save it
-IMPORTANT FORMATTING RULES for clarifying questions:
-- Use numbered questions: 1), 2), 3)
-- ANY CHOICE MUST BE NUMBERED using combined format: 1a), 1b), 1c), 2a), 2b), etc.
-- NEVER use standalone letters like "a)" - always combine with question number
-- NEVER use dashes (-) or bullets (•) for options
-- Every option must be numbered for easy selection by the user
+` + formattingRules + `
Examples:
1) Question Category
@@ -77,31 +88,32 @@ Examples:
4b) Sub-question two?
Answer options here
-Start by examining our conversation and asking your clarifying questions using this format.`,
- },
- },
- },
- Tags: []string{"meta", "prompt-management", "interactive"},
- Created: now,
- Updated: now,
+Start by examining our conversation and asking your clarifying questions using this format.`
+}
+
+// buildUpdatePrompt creates the "update_prompt" meta-prompt for modifying
+// an existing prompt interactively.
+func buildUpdatePrompt(now time.Time) Prompt {
+ return Prompt{
+ Name: "update_prompt",
+ Title: "Update Existing Prompt",
+ Description: "Interactively modify an existing prompt. Claude will fetch the current version, ask what changes you want, show a preview with changes highlighted, and wait for approval before updating.",
+ Arguments: []PromptArgument{
+ {Name: "prompt_name", Description: "Name of the existing prompt to update", Required: true},
},
- {
- Name: "update_prompt",
- Title: "Update Existing Prompt",
- Description: "Interactively modify an existing prompt. Claude will fetch the current version, ask what changes you want, show a preview with changes highlighted, and wait for approval before updating.",
- Arguments: []PromptArgument{
- {
- Name: "prompt_name",
- Description: "Name of the existing prompt to update",
- Required: true,
- },
- },
- Messages: []PromptMessage{
- {
- Role: "user",
- Content: MessageContent{
- Type: "text",
- Text: `I want to update the existing prompt '{{prompt_name}}'.
+ Messages: []PromptMessage{{
+ Role: "user",
+ Content: MessageContent{Type: "text", Text: updatePromptText()},
+ }},
+ Tags: []string{"meta", "prompt-management", "interactive"},
+ Created: now,
+ Updated: now,
+ }
+}
+
+// updatePromptText returns the user message body for the update_prompt meta-prompt.
+func updatePromptText() string {
+ return `I want to update the existing prompt '{{prompt_name}}'.
Please help me by:
1) Ask me what changes I want to make to the prompt '{{prompt_name}}' (title, description, arguments, messages, or tags)
@@ -109,12 +121,7 @@ Please help me by:
3) Show me a complete preview of the updated prompt with changes highlighted
4) Only after I approve, use the update_prompt tool to save the changes
-IMPORTANT FORMATTING RULES for clarifying questions:
-- Use numbered questions: 1), 2), 3)
-- ANY CHOICE MUST BE NUMBERED using combined format: 1a), 1b), 1c), 2a), 2b), etc.
-- NEVER use standalone letters like "a)" - always combine with question number
-- NEVER use dashes (-) or bullets (•) for options
-- Every option must be numbered for easy selection by the user
+` + formattingRules + `
Examples:
1) Question Category
@@ -136,31 +143,22 @@ Examples:
3b) Second aspect to evaluate
3c) Third aspect to evaluate
-Start by asking me what changes I want to make, using this format for any clarifying questions.`,
- },
- },
- },
- Tags: []string{"meta", "prompt-management", "interactive"},
- Created: now,
- Updated: now,
+Start by asking me what changes I want to make, using this format for any clarifying questions.`
+}
+
+// buildDeletePrompt creates the "delete_prompt" meta-prompt for safely removing
+// a custom prompt with explicit confirmation.
+func buildDeletePrompt(now time.Time) Prompt {
+ return Prompt{
+ Name: "delete_prompt",
+ Title: "Delete Custom Prompt",
+ Description: "Interactively delete an existing custom prompt with confirmation. Claude will show the current prompt, ask for confirmation, and only delete after explicit approval. Built-in prompts cannot be deleted.",
+ Arguments: []PromptArgument{
+ {Name: "prompt_name", Description: "Name of the existing prompt to delete", Required: true},
},
- {
- Name: "delete_prompt",
- Title: "Delete Custom Prompt",
- Description: "Interactively delete an existing custom prompt with confirmation. Claude will show the current prompt, ask for confirmation, and only delete after explicit approval. Built-in prompts cannot be deleted.",
- Arguments: []PromptArgument{
- {
- Name: "prompt_name",
- Description: "Name of the existing prompt to delete",
- Required: true,
- },
- },
- Messages: []PromptMessage{
- {
- Role: "user",
- Content: MessageContent{
- Type: "text",
- Text: `I want to delete the existing prompt '{{prompt_name}}'.
+ Messages: []PromptMessage{{
+ Role: "user",
+ Content: MessageContent{Type: "text", Text: `I want to delete the existing prompt '{{prompt_name}}'.
Please help me by:
1) Confirm with me that I want to delete the prompt named '{{prompt_name}}'
@@ -173,25 +171,25 @@ IMPORTANT NOTES:
- Only custom prompts stored in user.jsonl can be deleted
- Backups are automatically created before deletion
-Ask me to confirm the deletion of '{{prompt_name}}'.`,
- },
- },
- },
- Tags: []string{"meta", "prompt-management", "interactive"},
- Created: now,
- Updated: now,
- },
- {
- Name: "design_prompt",
- Title: "Design New Prompt from Scratch",
- Description: "Interactively design a brand new prompt template through guided questions. Claude will help you define the purpose, arguments, message flow, and metadata step by step, show a preview, and wait for approval before saving.",
- Arguments: []PromptArgument{},
- Messages: []PromptMessage{
- {
- Role: "user",
- Content: MessageContent{
- Type: "text",
- Text: `I want to design a brand new prompt template from scratch.
+Ask me to confirm the deletion of '{{prompt_name}}'.`},
+ }},
+ Tags: []string{"meta", "prompt-management", "interactive"},
+ Created: now,
+ Updated: now,
+ }
+}
+
+// buildDesignPrompt creates the "design_prompt" meta-prompt for building
+// a brand new prompt template from scratch through guided questions.
+func buildDesignPrompt(now time.Time) Prompt {
+ return Prompt{
+ Name: "design_prompt",
+ Title: "Design New Prompt from Scratch",
+ Description: "Interactively design a brand new prompt template through guided questions. Claude will help you define the purpose, arguments, message flow, and metadata step by step, show a preview, and wait for approval before saving.",
+ Arguments: []PromptArgument{},
+ Messages: []PromptMessage{{
+ Role: "user",
+ Content: MessageContent{Type: "text", Text: `I want to design a brand new prompt template from scratch.
Please ask me:
1) What should this prompt do? (describe the task/purpose in 1-2 sentences)
@@ -200,13 +198,10 @@ Please ask me:
Then show me a preview and save it after I approve.
-Keep questions brief and focused.`,
- },
- },
- },
- Tags: []string{"meta", "prompt-management", "interactive", "creation"},
- Created: now,
- Updated: now,
- },
+Keep questions brief and focused.`},
+ }},
+ Tags: []string{"meta", "prompt-management", "interactive", "creation"},
+ Created: now,
+ Updated: now,
}
}
diff --git a/internal/stats/stats.go b/internal/stats/stats.go
index 95981c5..4b05617 100644
--- a/internal/stats/stats.go
+++ b/internal/stats/stats.go
@@ -83,54 +83,85 @@ func Update(ctx context.Context, provider, model string, sentBytes, recvBytes in
if err := os.MkdirAll(dir, 0o755); err != nil {
return err
}
+ unlock, err := lockStatsFile(ctx, dir)
+ if err != nil {
+ return err
+ }
+ defer func() { _ = unlock() }()
+
+ path := filepath.Join(dir, fileName)
+ sf := readStatsFile(path)
+ now := time.Now()
+ win := Window()
+ sf.WindowSeconds = int(win.Seconds())
+ sf.Events = append(sf.Events, Event{
+ TS: now, Provider: provider, Model: model,
+ Sent: int64(sentBytes), Recv: int64(recvBytes),
+ })
+ pruneOldEvents(&sf, now.Add(-win))
+ sf.UpdatedAt = now
+ return writeStatsFileAtomic(dir, path, &sf)
+}
+
+// lockStatsFile acquires an advisory file lock on the stats lock file within dir.
+// Returns an unlock function on success.
+func lockStatsFile(ctx context.Context, dir string) (func() error, error) {
lockPath := filepath.Join(dir, lockFileName)
f, err := os.OpenFile(lockPath, os.O_CREATE|os.O_RDWR, 0o600)
if err != nil {
- return err
+ return nil, err
}
- defer func() { _ = f.Close() }()
unlock, err := acquireFileLock(ctx, f)
if err != nil {
- return err
+ _ = f.Close()
+ return nil, err
}
- defer func() { _ = unlock() }()
- // Read existing file (if any)
- path := filepath.Join(dir, fileName)
+ // Return a combined unlock+close function so the caller only needs one defer.
+ return func() error {
+ uErr := unlock()
+ cErr := f.Close()
+ if uErr != nil {
+ return uErr
+ }
+ return cErr
+ }, nil
+}
+
+// readStatsFile loads the on-disk stats file, returning a fresh File if it is
+// missing or has an incompatible version.
+func readStatsFile(path string) File {
var sf File
- if b, rerr := os.ReadFile(path); rerr == nil {
+ if b, err := os.ReadFile(path); err == nil {
_ = json.Unmarshal(b, &sf)
}
if sf.Version != fileVersion {
sf = File{Version: fileVersion}
}
- now := time.Now()
- win := Window()
- sf.WindowSeconds = int(win.Seconds())
- // Append event
- sf.Events = append(sf.Events, Event{TS: now, Provider: provider, Model: model, Sent: int64(sentBytes), Recv: int64(recvBytes)})
- // Prune old
- cutoff := now.Add(-win)
- if len(sf.Events) > 0 {
- // Find first >= cutoff
- i := 0
- for ; i < len(sf.Events); i++ {
- if !sf.Events[i].TS.Before(cutoff) {
- break
- }
- }
- if i > 0 {
- sf.Events = append([]Event(nil), sf.Events[i:]...)
+ return sf
+}
+
+// pruneOldEvents removes events older than cutoff in-place.
+func pruneOldEvents(sf *File, cutoff time.Time) {
+ i := 0
+ for ; i < len(sf.Events); i++ {
+ if !sf.Events[i].TS.Before(cutoff) {
+ break
}
}
- sf.UpdatedAt = now
- // Write atomically
+ if i > 0 {
+ sf.Events = append([]Event(nil), sf.Events[i:]...)
+ }
+}
+
+// writeStatsFileAtomic writes sf to path via a temp file + rename for crash safety.
+func writeStatsFileAtomic(dir, path string, sf *File) error {
tmp, err := os.CreateTemp(dir, fileName+".tmp.")
if err != nil {
return err
}
enc := json.NewEncoder(tmp)
enc.SetEscapeHTML(false)
- if err := enc.Encode(&sf); err != nil {
+ if err := enc.Encode(sf); err != nil {
_ = tmp.Close()
_ = os.Remove(tmp.Name())
return err