diff options
| -rw-r--r-- | internal/appconfig/config_features_test.go | 588 | ||||
| -rw-r--r-- | internal/appconfig/config_load.go | 98 | ||||
| -rw-r--r-- | internal/appconfig/config_test.go | 578 | ||||
| -rw-r--r-- | internal/mcp/handlers_prompt_test.go (renamed from internal/mcp/handlers_test.go) | 699 | ||||
| -rw-r--r-- | internal/mcp/handlers_tool_test.go | 703 | ||||
| -rw-r--r-- | internal/promptstore/default_prompts.go | 219 | ||||
| -rw-r--r-- | internal/stats/stats.go | 85 |
7 files changed, 1517 insertions, 1453 deletions
diff --git a/internal/appconfig/config_features_test.go b/internal/appconfig/config_features_test.go new file mode 100644 index 0000000..b3c12e9 --- /dev/null +++ b/internal/appconfig/config_features_test.go @@ -0,0 +1,588 @@ +// Summary: Tests for ignore config, tmux-edit config, and low-level parsing helpers +// (temperature, model entries, surface entries, resolved model). +package appconfig + +import ( + "os" + "path/filepath" + "reflect" + "testing" +) + +func TestIgnoreConfig_Defaults(t *testing.T) { + clearHexaiEnv(t) + cfg := Load(nil) + if cfg.IgnoreGitignore == nil || !*cfg.IgnoreGitignore { + t.Error("expected IgnoreGitignore default true") + } + if cfg.IgnoreLSPNotify == nil || !*cfg.IgnoreLSPNotify { + t.Error("expected IgnoreLSPNotify default true") + } + if len(cfg.IgnoreExtraPatterns) != 0 { + t.Errorf("expected empty IgnoreExtraPatterns, got %v", cfg.IgnoreExtraPatterns) + } +} + +func TestIgnoreConfig_FromFile(t *testing.T) { + clearHexaiEnv(t) + dir := t.TempDir() + cfgPath := filepath.Join(dir, "config.toml") + writeFile(t, cfgPath, ` +[ignore] +gitignore = false +extra_patterns = ["*.min.js", "dist/**"] +lsp_notify_ignored = false +`) + cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: dir}) + if cfg.IgnoreGitignore == nil || *cfg.IgnoreGitignore { + t.Error("expected IgnoreGitignore false from file") + } + if cfg.IgnoreLSPNotify == nil || *cfg.IgnoreLSPNotify { + t.Error("expected IgnoreLSPNotify false from file") + } + want := []string{"*.min.js", "dist/**"} + if !reflect.DeepEqual(cfg.IgnoreExtraPatterns, want) { + t.Errorf("IgnoreExtraPatterns = %v, want %v", cfg.IgnoreExtraPatterns, want) + } +} + +func TestIgnoreConfig_EnvOverrides(t *testing.T) { + clearHexaiEnv(t) + dir := t.TempDir() + cfgPath := filepath.Join(dir, "config.toml") + writeFile(t, cfgPath, ` +[ignore] +gitignore = true +lsp_notify_ignored = true +`) + withEnv(t, "HEXAI_IGNORE_GITIGNORE", "false") + withEnv(t, "HEXAI_IGNORE_LSP_NOTIFY", "0") + withEnv(t, "HEXAI_IGNORE_EXTRA_PATTERNS", "*.bak,*.tmp") + cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: dir}) + if cfg.IgnoreGitignore == nil || *cfg.IgnoreGitignore { + t.Error("expected IgnoreGitignore false from env override") + } + if cfg.IgnoreLSPNotify == nil || *cfg.IgnoreLSPNotify { + t.Error("expected IgnoreLSPNotify false from env override") + } + want := []string{"*.bak", "*.tmp"} + if !reflect.DeepEqual(cfg.IgnoreExtraPatterns, want) { + t.Errorf("IgnoreExtraPatterns = %v, want %v", cfg.IgnoreExtraPatterns, want) + } +} + +func TestIgnoreConfig_ProjectOverride(t *testing.T) { + clearHexaiEnv(t) + dir := t.TempDir() + cfgPath := filepath.Join(dir, "config.toml") + writeFile(t, cfgPath, ` +[ignore] +gitignore = true +`) + // Set up a fake git repo with project override + projectDir := t.TempDir() + if err := os.Mkdir(filepath.Join(projectDir, ".git"), 0o755); err != nil { + t.Fatalf("mkdir .git: %v", err) + } + projectCfg := filepath.Join(projectDir, ProjectConfigFilename) + writeFile(t, projectCfg, ` +[ignore] +gitignore = false +extra_patterns = ["build/**"] +`) + cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: projectDir}) + if cfg.IgnoreGitignore == nil || *cfg.IgnoreGitignore { + t.Error("expected project override to set IgnoreGitignore false") + } + want := []string{"build/**"} + if !reflect.DeepEqual(cfg.IgnoreExtraPatterns, want) { + t.Errorf("IgnoreExtraPatterns = %v, want %v", cfg.IgnoreExtraPatterns, want) + } +} + +func TestIgnoreConfig_DisableGitignore(t *testing.T) { + clearHexaiEnv(t) + dir := t.TempDir() + cfgPath := filepath.Join(dir, "config.toml") + writeFile(t, cfgPath, ` +[ignore] +gitignore = false +`) + cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: dir}) + if cfg.IgnoreGitignore == nil || *cfg.IgnoreGitignore { + t.Error("expected IgnoreGitignore false") + } + // LSP notify should still be true (default, not overridden) + if cfg.IgnoreLSPNotify == nil || !*cfg.IgnoreLSPNotify { + t.Error("expected IgnoreLSPNotify to remain true (default)") + } +} + +func TestTmuxEditConfig_FromFile(t *testing.T) { + clearHexaiEnv(t) + dir := t.TempDir() + cfgPath := filepath.Join(dir, "config.toml") + writeFile(t, cfgPath, ` +[tmux_edit] +popup_width = "90%" +popup_height = "85%" +default_agent = "claude" + +[[tmux_edit.agents]] +name = "claude" +display_name = "Claude Code" +detect_pattern = "(?i)(claude|anthropic)" +prompt_pattern = '(?s)>\s*(.+?)$' +clear_first = true +clear_keys = "C-u" +newline_keys = "S-Enter" +submit_keys = "Enter" + +[[tmux_edit.agents]] +name = "cursor" +display_name = "Cursor" +detect_pattern = "(?i)cursor" +prompt_pattern = '(?s)│\s*(.+?)$' +strip_patterns = ["INSERT", "Add a follow-up"] +clear_first = true +clear_keys = "C-u" +newline_keys = "S-Enter" +submit_keys = "Enter" +`) + cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: dir}) + if cfg.TmuxEditPopupWidth != "90%" { + t.Errorf("PopupWidth = %q, want 90%%", cfg.TmuxEditPopupWidth) + } + if cfg.TmuxEditPopupHeight != "85%" { + t.Errorf("PopupHeight = %q, want 85%%", cfg.TmuxEditPopupHeight) + } + if cfg.TmuxEditDefaultAgent != "claude" { + t.Errorf("DefaultAgent = %q, want claude", cfg.TmuxEditDefaultAgent) + } + if len(cfg.TmuxEditAgents) != 2 { + t.Fatalf("got %d agents, want 2", len(cfg.TmuxEditAgents)) + } + a := cfg.TmuxEditAgents[0] + if a.Name != "claude" || a.DisplayName != "Claude Code" { + t.Errorf("agent[0] = %q/%q, want claude/Claude Code", a.Name, a.DisplayName) + } + if a.ClearFirst == nil || !*a.ClearFirst { + t.Error("expected ClearFirst = true for claude agent") + } + b := cfg.TmuxEditAgents[1] + if b.Name != "cursor" { + t.Errorf("agent[1].Name = %q, want cursor", b.Name) + } + if len(b.StripPatterns) != 2 { + t.Errorf("agent[1].StripPatterns = %v, want 2 entries", b.StripPatterns) + } +} + +func TestTmuxEditConfig_Merge(t *testing.T) { + clearHexaiEnv(t) + a := newDefaultConfig() + b := App{ + TmuxEditPopupWidth: "70%", + TmuxEditDefaultAgent: "amp", + TmuxEditAgents: []TmuxEditAgentCfg{ + {Name: "amp", DisplayName: "Amp"}, + }, + } + a.mergeWith(&b) + if a.TmuxEditPopupWidth != "70%" { + t.Errorf("PopupWidth = %q, want 70%%", a.TmuxEditPopupWidth) + } + if a.TmuxEditDefaultAgent != "amp" { + t.Errorf("DefaultAgent = %q, want amp", a.TmuxEditDefaultAgent) + } + if len(a.TmuxEditAgents) != 1 || a.TmuxEditAgents[0].Name != "amp" { + t.Errorf("Agents = %v, want single amp", a.TmuxEditAgents) + } +} + +func TestTmuxEditConfig_SkipsEmptyName(t *testing.T) { + clearHexaiEnv(t) + dir := t.TempDir() + cfgPath := filepath.Join(dir, "config.toml") + writeFile(t, cfgPath, ` +[tmux_edit] +[[tmux_edit.agents]] +name = "" +display_name = "Empty" +`) + cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: dir}) + if len(cfg.TmuxEditAgents) != 0 { + t.Errorf("got %d agents, want 0 (empty name should be skipped)", len(cfg.TmuxEditAgents)) + } +} + +// --- Config Parsing Tests --- + +func TestParseTemperatureValue(t *testing.T) { + tests := []struct { + name string + input any + wantValue *float64 + wantOK bool + }{ + {"float64 zero", float64(0.0), floatPtr(0.0), true}, + {"float64 half", float64(0.5), floatPtr(0.5), true}, + {"float64 one", float64(1.0), floatPtr(1.0), true}, + {"float64 two", float64(2.0), floatPtr(2.0), true}, + {"int64 zero", int64(0), floatPtr(0.0), true}, + {"int64 one", int64(1), floatPtr(1.0), true}, + {"int64 two", int64(2), floatPtr(2.0), true}, + {"string zero", "0", floatPtr(0.0), true}, + {"string one", "1", floatPtr(1.0), true}, + {"string two", "2", floatPtr(2.0), true}, + {"string float", "0.75", floatPtr(0.75), true}, + {"string empty", "", nil, true}, + {"string whitespace", " ", nil, true}, + {"string invalid", "invalid", nil, false}, + {"string negative", "-0.5", floatPtr(-0.5), true}, + {"string very small", "0.0001", floatPtr(0.0001), true}, + {"string high precision", "1.123456789", floatPtr(1.123456789), true}, + {"nil value", nil, nil, false}, + {"bool value", true, nil, false}, + {"map value", map[string]any{}, nil, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := parseTemperatureValue(tt.input, "test", newLogger()) + if ok != tt.wantOK { + t.Errorf("parseTemperatureValue() ok = %v, want %v", ok, tt.wantOK) + } + if !ok { + return + } + if (got == nil) != (tt.wantValue == nil) { + t.Errorf("parseTemperatureValue() = %v, want %v", got, tt.wantValue) + return + } + if got != nil && tt.wantValue != nil && *got != *tt.wantValue { + t.Errorf("parseTemperatureValue() = %v, want %v", *got, *tt.wantValue) + } + }) + } +} + +func TestDecodeModelEntry(t *testing.T) { + tests := []struct { + name string + input any + wantCfg *SurfaceConfig + wantOK bool + }{ + { + name: "simple string model", + input: "gpt-4", + wantCfg: &SurfaceConfig{Model: "gpt-4"}, + wantOK: true, + }, + { + name: "empty string", + input: "", + wantCfg: nil, + wantOK: false, + }, + { + name: "whitespace string", + input: " ", + wantCfg: nil, + wantOK: false, + }, + { + name: "object with all fields", + input: map[string]any{ + "model": "claude-3", + "provider": "anthropic", + "temperature": float64(0.7), + }, + wantCfg: &SurfaceConfig{ + Model: "claude-3", + Provider: "anthropic", + Temperature: floatPtr(0.7), + }, + wantOK: true, + }, + { + name: "object with model only", + input: map[string]any{ + "model": "gpt-4o", + }, + wantCfg: &SurfaceConfig{Model: "gpt-4o"}, + wantOK: true, + }, + { + name: "object with provider only", + input: map[string]any{ + "provider": "openai", + }, + wantCfg: &SurfaceConfig{Provider: "openai"}, + wantOK: true, + }, + { + name: "object with temperature only", + input: map[string]any{ + "temperature": float64(0.5), + }, + wantCfg: &SurfaceConfig{Temperature: floatPtr(0.5)}, + wantOK: true, + }, + { + name: "object with empty fields", + input: map[string]any{ + "model": "", + "provider": "", + }, + wantCfg: nil, + wantOK: false, + }, + { + name: "object with invalid model type", + input: map[string]any{ + "model": 123, + }, + wantCfg: nil, + wantOK: false, + }, + { + name: "object with invalid provider type", + input: map[string]any{ + "provider": 456, + }, + wantCfg: nil, + wantOK: false, + }, + { + name: "object with invalid temperature", + input: map[string]any{ + "model": "gpt-4", + "temperature": "not a number", + }, + wantCfg: nil, + wantOK: false, + }, + { + name: "nil input", + input: nil, + wantCfg: nil, + wantOK: false, + }, + { + name: "invalid type (int)", + input: 123, + wantCfg: nil, + wantOK: false, + }, + { + name: "invalid type (slice)", + input: []string{"gpt-4"}, + wantCfg: nil, + wantOK: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := decodeModelEntry(tt.input, "test", newLogger()) + if ok != tt.wantOK { + t.Errorf("decodeModelEntry() ok = %v, want %v", ok, tt.wantOK) + } + if !ok { + return + } + if (got == nil) != (tt.wantCfg == nil) { + t.Errorf("decodeModelEntry() = %v, want %v", got, tt.wantCfg) + return + } + if got == nil { + return + } + if got.Model != tt.wantCfg.Model { + t.Errorf("Model = %q, want %q", got.Model, tt.wantCfg.Model) + } + if got.Provider != tt.wantCfg.Provider { + t.Errorf("Provider = %q, want %q", got.Provider, tt.wantCfg.Provider) + } + if (got.Temperature == nil) != (tt.wantCfg.Temperature == nil) { + t.Errorf("Temperature nil mismatch: got %v, want %v", got.Temperature, tt.wantCfg.Temperature) + } else if got.Temperature != nil && *got.Temperature != *tt.wantCfg.Temperature { + t.Errorf("Temperature = %v, want %v", *got.Temperature, *tt.wantCfg.Temperature) + } + }) + } +} + +func TestResolvedModel(t *testing.T) { + tests := []struct { + name string + section sectionOpenAI + want string + }{ + { + name: "explicit model no presets", + section: sectionOpenAI{Model: "gpt-4"}, + want: "gpt-4", + }, + { + name: "empty model", + section: sectionOpenAI{Model: ""}, + want: "", + }, + { + name: "whitespace model", + section: sectionOpenAI{Model: " "}, + want: "", + }, + { + name: "preset match exact case", + section: sectionOpenAI{ + Model: "fast", + Presets: map[string]string{"fast": "gpt-3.5-turbo"}, + }, + want: "gpt-3.5-turbo", + }, + { + name: "preset match case insensitive", + section: sectionOpenAI{ + Model: "FAST", + Presets: map[string]string{"fast": "gpt-3.5-turbo"}, + }, + want: "gpt-3.5-turbo", + }, + { + name: "no preset match returns original", + section: sectionOpenAI{ + Model: "custom-model", + Presets: map[string]string{"fast": "gpt-3.5-turbo"}, + }, + want: "custom-model", + }, + { + name: "preset empty value returns original", + section: sectionOpenAI{ + Model: "fast", + Presets: map[string]string{"fast": ""}, + }, + want: "fast", + }, + { + name: "preset whitespace value returns original", + section: sectionOpenAI{ + Model: "fast", + Presets: map[string]string{"fast": " "}, + }, + want: "fast", + }, + { + name: "multiple presets uses correct one", + section: sectionOpenAI{ + Model: "smart", + Presets: map[string]string{ + "fast": "gpt-3.5-turbo", + "smart": "gpt-4", + "mini": "gpt-3.5-mini", + }, + }, + want: "gpt-4", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.section.resolvedModel() + if got != tt.want { + t.Errorf("resolvedModel() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestParseSurfaceEntries(t *testing.T) { + tests := []struct { + name string + input any + wantLen int + wantOK bool + }{ + { + name: "nil input", + input: nil, + wantLen: 0, + wantOK: false, + }, + { + name: "single string", + input: "gpt-4", + wantLen: 1, + wantOK: true, + }, + { + name: "single map", + input: map[string]any{ + "model": "claude-3", + "provider": "anthropic", + }, + wantLen: 1, + wantOK: true, + }, + { + name: "array of strings", + input: []any{ + "gpt-4", + "claude-3", + }, + wantLen: 2, + wantOK: true, + }, + { + name: "array of maps", + input: []any{ + map[string]any{"model": "gpt-4", "provider": "openai"}, + map[string]any{"model": "claude-3", "provider": "anthropic"}, + }, + wantLen: 2, + wantOK: true, + }, + { + name: "array with invalid entries", + input: []any{ + "gpt-4", + 123, + "claude-3", + }, + wantLen: 2, + wantOK: true, + }, + { + name: "array with all invalid entries", + input: []any{ + 123, + true, + nil, + }, + wantLen: 0, + wantOK: false, + }, + { + name: "empty array", + input: []any{}, + wantLen: 0, + wantOK: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := parseSurfaceEntries(tt.input, "test", newLogger()) + if ok != tt.wantOK { + t.Errorf("parseSurfaceEntries() ok = %v, want %v", ok, tt.wantOK) + } + if len(got) != tt.wantLen { + t.Errorf("parseSurfaceEntries() len = %d, want %d", len(got), tt.wantLen) + } + }) + } +} diff --git a/internal/appconfig/config_load.go b/internal/appconfig/config_load.go index 261835b..dc917ff 100644 --- a/internal/appconfig/config_load.go +++ b/internal/appconfig/config_load.go @@ -139,6 +139,8 @@ func loadProjectConfig(logger *log.Logger, opts LoadOptions, cfg *App) { } } +// loadFromFile reads a TOML config file, validates it, and returns the parsed App. +// Returns (nil, err) on I/O or parse errors; returns (nil, nil) when the file does not exist. func loadFromFile(path string, logger *log.Logger) (*App, error) { b, err := os.ReadFile(path) if err != nil { @@ -148,25 +150,50 @@ func loadFromFile(path string, logger *log.Logger) (*App, error) { return nil, err } + tables, raw, err := decodeTOML(b, path, logger) + if err != nil { + return nil, err + } + if err := rejectLegacyKeys(raw); err != nil { + return nil, err + } + if logger != nil { + logger.Printf("loaded configuration from %s (TOML)", path) + } + + tab := tables.toApp() + applyRawIntOverrides(raw, &tab) + if m := parseSurfaceModels(raw, logger); m != nil { + tab.mergeSurfaceModels(m) + } + return &tab, nil +} + +// decodeTOML parses raw TOML bytes into both the typed fileConfig and a raw map +// for validation and defensive integer handling. +func decodeTOML(b []byte, path string, logger *log.Logger) (*fileConfig, map[string]any, error) { var tables fileConfig errTables := toml.NewDecoder(strings.NewReader(string(b))).Decode(&tables) - // Raw map for validation/presence checks var raw map[string]any errRaw := toml.Unmarshal(b, &raw) if errTables != nil { if logger != nil { logger.Printf("invalid TOML config file %s: %v", path, errTables) } - return nil, errTables + return nil, nil, errTables } if errRaw != nil { if logger != nil { logger.Printf("invalid TOML config file %s: %v", path, errRaw) } - return nil, errRaw + return nil, nil, errRaw } + return &tables, raw, nil +} - // Reject legacy flat keys at top-level (sectioned-only config is allowed) +// rejectLegacyKeys returns an error if the raw map contains flat keys from the +// old unsectioned config format. Only sectioned table keys are allowed. +func rejectLegacyKeys(raw map[string]any) error { legacy := map[string]struct{}{ "max_tokens": {}, "context_mode": {}, "context_window_lines": {}, "max_context_tokens": {}, "log_preview_limit": {}, "completion_debounce_ms": {}, "completion_throttle_ms": {}, @@ -175,53 +202,46 @@ func loadFromFile(path string, logger *log.Logger) (*App, error) { "openai_model": {}, "openai_base_url": {}, "openai_temperature": {}, "ollama_model": {}, "ollama_base_url": {}, "ollama_temperature": {}, } + knownTables := map[string]struct{}{ + "general": {}, "logging": {}, "completion": {}, "triggers": {}, "inline": {}, + "chat": {}, "provider": {}, "models": {}, "openai": {}, "ollama": {}, "prompts": {}, + } for k := range raw { - if _, isTable := map[string]struct{}{ - "general": {}, "logging": {}, "completion": {}, "triggers": {}, "inline": {}, - "chat": {}, "provider": {}, "models": {}, "openai": {}, "ollama": {}, "prompts": {}, - }[k]; isTable { + if _, isTable := knownTables[k]; isTable { continue } if _, isLegacy := legacy[k]; isLegacy { - return nil, fmt.Errorf("unsupported flat key '%s' in config; use sectioned tables (see config.toml.example)", k) + return fmt.Errorf("unsupported flat key '%s' in config; use sectioned tables (see config.toml.example)", k) } } + return nil +} - if logger != nil { - logger.Printf("loaded configuration from %s (TOML)", path) - } - - // Build App from tables only - tab := tables.toApp() - // Ensure explicit values from raw map are respected (defensive for ints) +// applyRawIntOverrides defensively re-applies integer values from the raw TOML map +// that the typed decoder may have silently zeroed (e.g. int vs float mismatch). +func applyRawIntOverrides(raw map[string]any, tab *App) { if t, ok := raw["completion"].(map[string]any); ok { - if v, present := t["manual_invoke_min_prefix"]; present { - switch vv := v.(type) { - case int64: - tab.ManualInvokeMinPrefix = int(vv) - case int: - tab.ManualInvokeMinPrefix = vv - case float64: - tab.ManualInvokeMinPrefix = int(vv) - } - } + applyRawInt(&tab.ManualInvokeMinPrefix, t, "manual_invoke_min_prefix") } if t, ok := raw["logging"].(map[string]any); ok { - if v, present := t["log_preview_limit"]; present { - switch vv := v.(type) { - case int64: - tab.LogPreviewLimit = int(vv) - case int: - tab.LogPreviewLimit = vv - case float64: - tab.LogPreviewLimit = int(vv) - } - } + applyRawInt(&tab.LogPreviewLimit, t, "log_preview_limit") } - if m := parseSurfaceModels(raw, logger); m != nil { - tab.mergeSurfaceModels(m) +} + +// applyRawInt sets *dst from table[key] when the value is a numeric type. +func applyRawInt(dst *int, table map[string]any, key string) { + v, present := table[key] + if !present { + return + } + switch vv := v.(type) { + case int64: + *dst = int(vv) + case int: + *dst = vv + case float64: + *dst = int(vv) } - return &tab, nil } func (fc *fileConfig) toApp() App { diff --git a/internal/appconfig/config_test.go b/internal/appconfig/config_test.go index c98d904..d4e7739 100644 --- a/internal/appconfig/config_test.go +++ b/internal/appconfig/config_test.go @@ -839,581 +839,3 @@ func TestProjectConfigPath(t *testing.T) { t.Fatalf("ProjectConfigPath() = %q, want empty", path) } } - -func TestIgnoreConfig_Defaults(t *testing.T) { - clearHexaiEnv(t) - cfg := Load(nil) - if cfg.IgnoreGitignore == nil || !*cfg.IgnoreGitignore { - t.Error("expected IgnoreGitignore default true") - } - if cfg.IgnoreLSPNotify == nil || !*cfg.IgnoreLSPNotify { - t.Error("expected IgnoreLSPNotify default true") - } - if len(cfg.IgnoreExtraPatterns) != 0 { - t.Errorf("expected empty IgnoreExtraPatterns, got %v", cfg.IgnoreExtraPatterns) - } -} - -func TestIgnoreConfig_FromFile(t *testing.T) { - clearHexaiEnv(t) - dir := t.TempDir() - cfgPath := filepath.Join(dir, "config.toml") - writeFile(t, cfgPath, ` -[ignore] -gitignore = false -extra_patterns = ["*.min.js", "dist/**"] -lsp_notify_ignored = false -`) - cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: dir}) - if cfg.IgnoreGitignore == nil || *cfg.IgnoreGitignore { - t.Error("expected IgnoreGitignore false from file") - } - if cfg.IgnoreLSPNotify == nil || *cfg.IgnoreLSPNotify { - t.Error("expected IgnoreLSPNotify false from file") - } - want := []string{"*.min.js", "dist/**"} - if !reflect.DeepEqual(cfg.IgnoreExtraPatterns, want) { - t.Errorf("IgnoreExtraPatterns = %v, want %v", cfg.IgnoreExtraPatterns, want) - } -} - -func TestIgnoreConfig_EnvOverrides(t *testing.T) { - clearHexaiEnv(t) - dir := t.TempDir() - cfgPath := filepath.Join(dir, "config.toml") - writeFile(t, cfgPath, ` -[ignore] -gitignore = true -lsp_notify_ignored = true -`) - withEnv(t, "HEXAI_IGNORE_GITIGNORE", "false") - withEnv(t, "HEXAI_IGNORE_LSP_NOTIFY", "0") - withEnv(t, "HEXAI_IGNORE_EXTRA_PATTERNS", "*.bak,*.tmp") - cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: dir}) - if cfg.IgnoreGitignore == nil || *cfg.IgnoreGitignore { - t.Error("expected IgnoreGitignore false from env override") - } - if cfg.IgnoreLSPNotify == nil || *cfg.IgnoreLSPNotify { - t.Error("expected IgnoreLSPNotify false from env override") - } - want := []string{"*.bak", "*.tmp"} - if !reflect.DeepEqual(cfg.IgnoreExtraPatterns, want) { - t.Errorf("IgnoreExtraPatterns = %v, want %v", cfg.IgnoreExtraPatterns, want) - } -} - -func TestIgnoreConfig_ProjectOverride(t *testing.T) { - clearHexaiEnv(t) - dir := t.TempDir() - cfgPath := filepath.Join(dir, "config.toml") - writeFile(t, cfgPath, ` -[ignore] -gitignore = true -`) - // Set up a fake git repo with project override - projectDir := t.TempDir() - if err := os.Mkdir(filepath.Join(projectDir, ".git"), 0o755); err != nil { - t.Fatalf("mkdir .git: %v", err) - } - projectCfg := filepath.Join(projectDir, ProjectConfigFilename) - writeFile(t, projectCfg, ` -[ignore] -gitignore = false -extra_patterns = ["build/**"] -`) - cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: projectDir}) - if cfg.IgnoreGitignore == nil || *cfg.IgnoreGitignore { - t.Error("expected project override to set IgnoreGitignore false") - } - want := []string{"build/**"} - if !reflect.DeepEqual(cfg.IgnoreExtraPatterns, want) { - t.Errorf("IgnoreExtraPatterns = %v, want %v", cfg.IgnoreExtraPatterns, want) - } -} - -func TestIgnoreConfig_DisableGitignore(t *testing.T) { - clearHexaiEnv(t) - dir := t.TempDir() - cfgPath := filepath.Join(dir, "config.toml") - writeFile(t, cfgPath, ` -[ignore] -gitignore = false -`) - cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: dir}) - if cfg.IgnoreGitignore == nil || *cfg.IgnoreGitignore { - t.Error("expected IgnoreGitignore false") - } - // LSP notify should still be true (default, not overridden) - if cfg.IgnoreLSPNotify == nil || !*cfg.IgnoreLSPNotify { - t.Error("expected IgnoreLSPNotify to remain true (default)") - } -} - -func TestTmuxEditConfig_FromFile(t *testing.T) { - clearHexaiEnv(t) - dir := t.TempDir() - cfgPath := filepath.Join(dir, "config.toml") - writeFile(t, cfgPath, ` -[tmux_edit] -popup_width = "90%" -popup_height = "85%" -default_agent = "claude" - -[[tmux_edit.agents]] -name = "claude" -display_name = "Claude Code" -detect_pattern = "(?i)(claude|anthropic)" -prompt_pattern = '(?s)>\s*(.+?)$' -clear_first = true -clear_keys = "C-u" -newline_keys = "S-Enter" -submit_keys = "Enter" - -[[tmux_edit.agents]] -name = "cursor" -display_name = "Cursor" -detect_pattern = "(?i)cursor" -prompt_pattern = '(?s)│\s*(.+?)$' -strip_patterns = ["INSERT", "Add a follow-up"] -clear_first = true -clear_keys = "C-u" -newline_keys = "S-Enter" -submit_keys = "Enter" -`) - cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: dir}) - if cfg.TmuxEditPopupWidth != "90%" { - t.Errorf("PopupWidth = %q, want 90%%", cfg.TmuxEditPopupWidth) - } - if cfg.TmuxEditPopupHeight != "85%" { - t.Errorf("PopupHeight = %q, want 85%%", cfg.TmuxEditPopupHeight) - } - if cfg.TmuxEditDefaultAgent != "claude" { - t.Errorf("DefaultAgent = %q, want claude", cfg.TmuxEditDefaultAgent) - } - if len(cfg.TmuxEditAgents) != 2 { - t.Fatalf("got %d agents, want 2", len(cfg.TmuxEditAgents)) - } - a := cfg.TmuxEditAgents[0] - if a.Name != "claude" || a.DisplayName != "Claude Code" { - t.Errorf("agent[0] = %q/%q, want claude/Claude Code", a.Name, a.DisplayName) - } - if a.ClearFirst == nil || !*a.ClearFirst { - t.Error("expected ClearFirst = true for claude agent") - } - b := cfg.TmuxEditAgents[1] - if b.Name != "cursor" { - t.Errorf("agent[1].Name = %q, want cursor", b.Name) - } - if len(b.StripPatterns) != 2 { - t.Errorf("agent[1].StripPatterns = %v, want 2 entries", b.StripPatterns) - } -} - -func TestTmuxEditConfig_Merge(t *testing.T) { - clearHexaiEnv(t) - a := newDefaultConfig() - b := App{ - TmuxEditPopupWidth: "70%", - TmuxEditDefaultAgent: "amp", - TmuxEditAgents: []TmuxEditAgentCfg{ - {Name: "amp", DisplayName: "Amp"}, - }, - } - a.mergeWith(&b) - if a.TmuxEditPopupWidth != "70%" { - t.Errorf("PopupWidth = %q, want 70%%", a.TmuxEditPopupWidth) - } - if a.TmuxEditDefaultAgent != "amp" { - t.Errorf("DefaultAgent = %q, want amp", a.TmuxEditDefaultAgent) - } - if len(a.TmuxEditAgents) != 1 || a.TmuxEditAgents[0].Name != "amp" { - t.Errorf("Agents = %v, want single amp", a.TmuxEditAgents) - } -} - -func TestTmuxEditConfig_SkipsEmptyName(t *testing.T) { - clearHexaiEnv(t) - dir := t.TempDir() - cfgPath := filepath.Join(dir, "config.toml") - writeFile(t, cfgPath, ` -[tmux_edit] -[[tmux_edit.agents]] -name = "" -display_name = "Empty" -`) - cfg := LoadWithOptions(newLogger(), LoadOptions{ConfigPath: cfgPath, ProjectRoot: dir}) - if len(cfg.TmuxEditAgents) != 0 { - t.Errorf("got %d agents, want 0 (empty name should be skipped)", len(cfg.TmuxEditAgents)) - } -} - -// --- Phase 1: Config Parsing Tests --- - -func TestParseTemperatureValue(t *testing.T) { - tests := []struct { - name string - input any - wantValue *float64 - wantOK bool - }{ - {"float64 zero", float64(0.0), floatPtr(0.0), true}, - {"float64 half", float64(0.5), floatPtr(0.5), true}, - {"float64 one", float64(1.0), floatPtr(1.0), true}, - {"float64 two", float64(2.0), floatPtr(2.0), true}, - {"int64 zero", int64(0), floatPtr(0.0), true}, - {"int64 one", int64(1), floatPtr(1.0), true}, - {"int64 two", int64(2), floatPtr(2.0), true}, - {"string zero", "0", floatPtr(0.0), true}, - {"string one", "1", floatPtr(1.0), true}, - {"string two", "2", floatPtr(2.0), true}, - {"string float", "0.75", floatPtr(0.75), true}, - {"string empty", "", nil, true}, - {"string whitespace", " ", nil, true}, - {"string invalid", "invalid", nil, false}, - {"string negative", "-0.5", floatPtr(-0.5), true}, - {"string very small", "0.0001", floatPtr(0.0001), true}, - {"string high precision", "1.123456789", floatPtr(1.123456789), true}, - {"nil value", nil, nil, false}, - {"bool value", true, nil, false}, - {"map value", map[string]any{}, nil, false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, ok := parseTemperatureValue(tt.input, "test", newLogger()) - if ok != tt.wantOK { - t.Errorf("parseTemperatureValue() ok = %v, want %v", ok, tt.wantOK) - } - if !ok { - return - } - if (got == nil) != (tt.wantValue == nil) { - t.Errorf("parseTemperatureValue() = %v, want %v", got, tt.wantValue) - return - } - if got != nil && tt.wantValue != nil && *got != *tt.wantValue { - t.Errorf("parseTemperatureValue() = %v, want %v", *got, *tt.wantValue) - } - }) - } -} - -func TestDecodeModelEntry(t *testing.T) { - tests := []struct { - name string - input any - wantCfg *SurfaceConfig - wantOK bool - }{ - { - name: "simple string model", - input: "gpt-4", - wantCfg: &SurfaceConfig{Model: "gpt-4"}, - wantOK: true, - }, - { - name: "empty string", - input: "", - wantCfg: nil, - wantOK: false, - }, - { - name: "whitespace string", - input: " ", - wantCfg: nil, - wantOK: false, - }, - { - name: "object with all fields", - input: map[string]any{ - "model": "claude-3", - "provider": "anthropic", - "temperature": float64(0.7), - }, - wantCfg: &SurfaceConfig{ - Model: "claude-3", - Provider: "anthropic", - Temperature: floatPtr(0.7), - }, - wantOK: true, - }, - { - name: "object with model only", - input: map[string]any{ - "model": "gpt-4o", - }, - wantCfg: &SurfaceConfig{Model: "gpt-4o"}, - wantOK: true, - }, - { - name: "object with provider only", - input: map[string]any{ - "provider": "openai", - }, - wantCfg: &SurfaceConfig{Provider: "openai"}, - wantOK: true, - }, - { - name: "object with temperature only", - input: map[string]any{ - "temperature": float64(0.5), - }, - wantCfg: &SurfaceConfig{Temperature: floatPtr(0.5)}, - wantOK: true, - }, - { - name: "object with empty fields", - input: map[string]any{ - "model": "", - "provider": "", - }, - wantCfg: nil, - wantOK: false, - }, - { - name: "object with invalid model type", - input: map[string]any{ - "model": 123, - }, - wantCfg: nil, - wantOK: false, - }, - { - name: "object with invalid provider type", - input: map[string]any{ - "provider": 456, - }, - wantCfg: nil, - wantOK: false, - }, - { - name: "object with invalid temperature", - input: map[string]any{ - "model": "gpt-4", - "temperature": "not a number", - }, - wantCfg: nil, - wantOK: false, - }, - { - name: "nil input", - input: nil, - wantCfg: nil, - wantOK: false, - }, - { - name: "invalid type (int)", - input: 123, - wantCfg: nil, - wantOK: false, - }, - { - name: "invalid type (slice)", - input: []string{"gpt-4"}, - wantCfg: nil, - wantOK: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, ok := decodeModelEntry(tt.input, "test", newLogger()) - if ok != tt.wantOK { - t.Errorf("decodeModelEntry() ok = %v, want %v", ok, tt.wantOK) - } - if !ok { - return - } - if (got == nil) != (tt.wantCfg == nil) { - t.Errorf("decodeModelEntry() = %v, want %v", got, tt.wantCfg) - return - } - if got == nil { - return - } - if got.Model != tt.wantCfg.Model { - t.Errorf("Model = %q, want %q", got.Model, tt.wantCfg.Model) - } - if got.Provider != tt.wantCfg.Provider { - t.Errorf("Provider = %q, want %q", got.Provider, tt.wantCfg.Provider) - } - if (got.Temperature == nil) != (tt.wantCfg.Temperature == nil) { - t.Errorf("Temperature nil mismatch: got %v, want %v", got.Temperature, tt.wantCfg.Temperature) - } else if got.Temperature != nil && *got.Temperature != *tt.wantCfg.Temperature { - t.Errorf("Temperature = %v, want %v", *got.Temperature, *tt.wantCfg.Temperature) - } - }) - } -} - -func TestResolvedModel(t *testing.T) { - tests := []struct { - name string - section sectionOpenAI - want string - }{ - { - name: "explicit model no presets", - section: sectionOpenAI{Model: "gpt-4"}, - want: "gpt-4", - }, - { - name: "empty model", - section: sectionOpenAI{Model: ""}, - want: "", - }, - { - name: "whitespace model", - section: sectionOpenAI{Model: " "}, - want: "", - }, - { - name: "preset match exact case", - section: sectionOpenAI{ - Model: "fast", - Presets: map[string]string{"fast": "gpt-3.5-turbo"}, - }, - want: "gpt-3.5-turbo", - }, - { - name: "preset match case insensitive", - section: sectionOpenAI{ - Model: "FAST", - Presets: map[string]string{"fast": "gpt-3.5-turbo"}, - }, - want: "gpt-3.5-turbo", - }, - { - name: "no preset match returns original", - section: sectionOpenAI{ - Model: "custom-model", - Presets: map[string]string{"fast": "gpt-3.5-turbo"}, - }, - want: "custom-model", - }, - { - name: "preset empty value returns original", - section: sectionOpenAI{ - Model: "fast", - Presets: map[string]string{"fast": ""}, - }, - want: "fast", - }, - { - name: "preset whitespace value returns original", - section: sectionOpenAI{ - Model: "fast", - Presets: map[string]string{"fast": " "}, - }, - want: "fast", - }, - { - name: "multiple presets uses correct one", - section: sectionOpenAI{ - Model: "smart", - Presets: map[string]string{ - "fast": "gpt-3.5-turbo", - "smart": "gpt-4", - "mini": "gpt-3.5-mini", - }, - }, - want: "gpt-4", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := tt.section.resolvedModel() - if got != tt.want { - t.Errorf("resolvedModel() = %q, want %q", got, tt.want) - } - }) - } -} - -func TestParseSurfaceEntries(t *testing.T) { - tests := []struct { - name string - input any - wantLen int - wantOK bool - }{ - { - name: "nil input", - input: nil, - wantLen: 0, - wantOK: false, - }, - { - name: "single string", - input: "gpt-4", - wantLen: 1, - wantOK: true, - }, - { - name: "single map", - input: map[string]any{ - "model": "claude-3", - "provider": "anthropic", - }, - wantLen: 1, - wantOK: true, - }, - { - name: "array of strings", - input: []any{ - "gpt-4", - "claude-3", - }, - wantLen: 2, - wantOK: true, - }, - { - name: "array of maps", - input: []any{ - map[string]any{"model": "gpt-4", "provider": "openai"}, - map[string]any{"model": "claude-3", "provider": "anthropic"}, - }, - wantLen: 2, - wantOK: true, - }, - { - name: "array with invalid entries", - input: []any{ - "gpt-4", - 123, - "claude-3", - }, - wantLen: 2, - wantOK: true, - }, - { - name: "array with all invalid entries", - input: []any{ - 123, - true, - nil, - }, - wantLen: 0, - wantOK: false, - }, - { - name: "empty array", - input: []any{}, - wantLen: 0, - wantOK: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, ok := parseSurfaceEntries(tt.input, "test", newLogger()) - if ok != tt.wantOK { - t.Errorf("parseSurfaceEntries() ok = %v, want %v", ok, tt.wantOK) - } - if len(got) != tt.wantLen { - t.Errorf("parseSurfaceEntries() len = %d, want %d", len(got), tt.wantLen) - } - }) - } -} diff --git a/internal/mcp/handlers_test.go b/internal/mcp/handlers_prompt_test.go index 2a4f821..ad0d261 100644 --- a/internal/mcp/handlers_test.go +++ b/internal/mcp/handlers_prompt_test.go @@ -1,4 +1,4 @@ -// Summary: Tests for MCP prompt management handlers +// Summary: Tests for MCP prompt management handlers (create, update, delete, get, list) package mcp import ( @@ -237,7 +237,7 @@ func TestServer_PromptsCreate_MissingName(t *testing.T) { } } -// Update mockPromptStore to support Create, Update, Delete +// mockPromptStore Create, Update, Delete methods used by prompt handler tests. func (m *mockPromptStore) Create(prompt *promptstore.Prompt) error { if _, exists := m.prompts[prompt.Name]; exists { return fmt.Errorf("prompt already exists: %s", prompt.Name) @@ -953,698 +953,3 @@ func TestServer_HandleInitialize_InvalidParams(t *testing.T) { t.Fatal("Expected error for invalid params") } } - -// ==================== Tools Tests ==================== - -func TestServer_ToolsList(t *testing.T) { - store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} - server, _, outBuf := createTestServer(t, store) - - // Initialize server - server.mu.Lock() - server.initialized = true - server.mu.Unlock() - - // Send tools/list request - req := Request{ - JSONRPC: "2.0", - ID: 40, - Method: "tools/list", - Params: json.RawMessage(`{}`), - } - - server.handle(req) - - // Read response - resp, err := readResponse(outBuf) - if err != nil { - t.Fatalf("readResponse() error = %v", err) - } - - if resp.Error != nil { - t.Fatalf("Error = %v, want nil", resp.Error) - } - - // Parse result - resultBytes, _ := json.Marshal(resp.Result) - var result ListToolsResult - if err := json.Unmarshal(resultBytes, &result); err != nil { - t.Fatalf("Unmarshal result error = %v", err) - } - - // Verify 3 tools returned - if len(result.Tools) != 3 { - t.Errorf("Tools count = %d, want 3", len(result.Tools)) - } - - // Verify tool names - toolNames := make(map[string]bool) - for _, tool := range result.Tools { - toolNames[tool.Name] = true - if tool.Description == "" { - t.Errorf("Tool %s has empty description", tool.Name) - } - if tool.InputSchema == nil { - t.Errorf("Tool %s has nil InputSchema", tool.Name) - } - } - - expectedTools := []string{"create_prompt", "update_prompt", "delete_prompt"} - for _, name := range expectedTools { - if !toolNames[name] { - t.Errorf("Missing expected tool: %s", name) - } - } -} - -func TestServer_ToolsCall_CreatePrompt(t *testing.T) { - store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} - server, _, outBuf := createTestServer(t, store) - - // Initialize server - server.mu.Lock() - server.initialized = true - server.mu.Unlock() - - // Send tools/call request - params := CallToolRequest{ - Name: "create_prompt", - Arguments: map[string]interface{}{ - "name": "tool_test", - "title": "Tool Test Prompt", - "messages": []interface{}{ - map[string]interface{}{ - "role": "user", - "content": map[string]interface{}{ - "type": "text", - "text": "Test message", - }, - }, - }, - }, - } - paramsBytes, _ := json.Marshal(params) - req := Request{ - JSONRPC: "2.0", - ID: 41, - Method: "tools/call", - Params: paramsBytes, - } - - server.handle(req) - - // Read response - resp, err := readResponse(outBuf) - if err != nil { - t.Fatalf("readResponse() error = %v", err) - } - - if resp.Error != nil { - t.Fatalf("Error = %v, want nil", resp.Error) - } - - // Parse result - resultBytes, _ := json.Marshal(resp.Result) - var result CallToolResult - if err := json.Unmarshal(resultBytes, &result); err != nil { - t.Fatalf("Unmarshal result error = %v", err) - } - - if result.IsError { - t.Errorf("IsError = true, want false. Content: %v", result.Content) - } - - if len(result.Content) == 0 { - t.Fatal("Expected content in result") - } - - // Verify prompt was created - if _, exists := store.prompts["tool_test"]; !exists { - t.Error("Prompt was not created in store") - } -} - -func TestServer_ToolsCall_UpdatePrompt(t *testing.T) { - now := time.Now() - store := &mockPromptStore{ - prompts: map[string]*promptstore.Prompt{ - "tool_update": { - Name: "tool_update", - Title: "Original Title", - Created: now, - Updated: now, - Messages: []promptstore.PromptMessage{ - { - Role: "user", - Content: promptstore.MessageContent{ - Type: "text", - Text: "Original", - }, - }, - }, - }, - }, - } - - server, _, outBuf := createTestServer(t, store) - - // Initialize server - server.mu.Lock() - server.initialized = true - server.mu.Unlock() - - // Send tools/call request - params := CallToolRequest{ - Name: "update_prompt", - Arguments: map[string]interface{}{ - "name": "tool_update", - "title": "Updated Title", - }, - } - paramsBytes, _ := json.Marshal(params) - req := Request{ - JSONRPC: "2.0", - ID: 42, - Method: "tools/call", - Params: paramsBytes, - } - - server.handle(req) - - // Read response - resp, err := readResponse(outBuf) - if err != nil { - t.Fatalf("readResponse() error = %v", err) - } - - if resp.Error != nil { - t.Fatalf("Error = %v, want nil", resp.Error) - } - - // Parse result - resultBytes, _ := json.Marshal(resp.Result) - var result CallToolResult - if err := json.Unmarshal(resultBytes, &result); err != nil { - t.Fatalf("Unmarshal result error = %v", err) - } - - if result.IsError { - t.Errorf("IsError = true, want false. Content: %v", result.Content) - } - - // Verify prompt was updated - if store.prompts["tool_update"].Title != "Updated Title" { - t.Errorf("Title not updated, got %s", store.prompts["tool_update"].Title) - } -} - -func TestServer_ToolsCall_DeletePrompt(t *testing.T) { - now := time.Now() - store := &mockPromptStore{ - prompts: map[string]*promptstore.Prompt{ - "tool_delete": { - Name: "tool_delete", - Title: "To Delete", - Created: now, - Updated: now, - Messages: []promptstore.PromptMessage{ - { - Role: "user", - Content: promptstore.MessageContent{ - Type: "text", - Text: "Test", - }, - }, - }, - }, - }, - } - - server, _, outBuf := createTestServer(t, store) - - // Initialize server - server.mu.Lock() - server.initialized = true - server.mu.Unlock() - - // Send tools/call request - params := CallToolRequest{ - Name: "delete_prompt", - Arguments: map[string]interface{}{ - "name": "tool_delete", - }, - } - paramsBytes, _ := json.Marshal(params) - req := Request{ - JSONRPC: "2.0", - ID: 43, - Method: "tools/call", - Params: paramsBytes, - } - - server.handle(req) - - // Read response - resp, err := readResponse(outBuf) - if err != nil { - t.Fatalf("readResponse() error = %v", err) - } - - if resp.Error != nil { - t.Fatalf("Error = %v, want nil", resp.Error) - } - - // Parse result - resultBytes, _ := json.Marshal(resp.Result) - var result CallToolResult - if err := json.Unmarshal(resultBytes, &result); err != nil { - t.Fatalf("Unmarshal result error = %v", err) - } - - if result.IsError { - t.Errorf("IsError = true, want false. Content: %v", result.Content) - } - - // Verify prompt was deleted - if _, exists := store.prompts["tool_delete"]; exists { - t.Error("Prompt was not deleted from store") - } -} - -func TestServer_ToolsCall_UnknownTool(t *testing.T) { - store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} - server, _, outBuf := createTestServer(t, store) - - // Initialize server - server.mu.Lock() - server.initialized = true - server.mu.Unlock() - - // Send tools/call request with unknown tool - params := CallToolRequest{ - Name: "nonexistent_tool", - Arguments: map[string]interface{}{}, - } - paramsBytes, _ := json.Marshal(params) - req := Request{ - JSONRPC: "2.0", - ID: 44, - Method: "tools/call", - Params: paramsBytes, - } - - server.handle(req) - - // Read response - resp, err := readResponse(outBuf) - if err != nil { - t.Fatalf("readResponse() error = %v", err) - } - - // Should not be a protocol error - if resp.Error != nil { - t.Fatalf("Unexpected protocol error: %v", resp.Error) - } - - // Parse result - should be a tool error - resultBytes, _ := json.Marshal(resp.Result) - var result CallToolResult - if err := json.Unmarshal(resultBytes, &result); err != nil { - t.Fatalf("Unmarshal result error = %v", err) - } - - if !result.IsError { - t.Error("Expected IsError = true for unknown tool") - } -} - -func TestServer_ToolsCall_InvalidArguments(t *testing.T) { - store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} - server, _, outBuf := createTestServer(t, store) - - // Initialize server - server.mu.Lock() - server.initialized = true - server.mu.Unlock() - - // Send tools/call with invalid arguments (missing required fields) - params := CallToolRequest{ - Name: "create_prompt", - Arguments: map[string]interface{}{ - "name": "test", // Missing title and messages - }, - } - paramsBytes, _ := json.Marshal(params) - req := Request{ - JSONRPC: "2.0", - ID: 45, - Method: "tools/call", - Params: paramsBytes, - } - - server.handle(req) - - // Read response - resp, err := readResponse(outBuf) - if err != nil { - t.Fatalf("readResponse() error = %v", err) - } - - // Should not be a protocol error - if resp.Error != nil { - t.Fatalf("Unexpected protocol error: %v", resp.Error) - } - - // Parse result - should be a tool error - resultBytes, _ := json.Marshal(resp.Result) - var result CallToolResult - if err := json.Unmarshal(resultBytes, &result); err != nil { - t.Fatalf("Unmarshal result error = %v", err) - } - - if !result.IsError { - t.Error("Expected IsError = true for invalid arguments") - } -} - -func TestServer_ToolsCall_NotInitialized(t *testing.T) { - store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} - server, _, outBuf := createTestServer(t, store) - - // DO NOT initialize server - - // Send tools/call request - params := CallToolRequest{ - Name: "create_prompt", - Arguments: map[string]interface{}{}, - } - paramsBytes, _ := json.Marshal(params) - req := Request{ - JSONRPC: "2.0", - ID: 46, - Method: "tools/call", - Params: paramsBytes, - } - - server.handle(req) - - // Read response - resp, err := readResponse(outBuf) - if err != nil { - t.Fatalf("readResponse() error = %v", err) - } - - if resp.Error == nil { - t.Fatal("Expected error for uninitialized server") - } - - if resp.Error.Code != ErrCodeInvalidRequest { - t.Errorf("Error code = %d, want %d", resp.Error.Code, ErrCodeInvalidRequest) - } -} - -func TestServer_ToolsCall_CreatePrompt_AlreadyExists(t *testing.T) { - now := time.Now() - store := &mockPromptStore{ - prompts: map[string]*promptstore.Prompt{ - "existing": { - Name: "existing", - Title: "Existing", - Created: now, - Updated: now, - Messages: []promptstore.PromptMessage{}, - }, - }, - } - - server, _, outBuf := createTestServer(t, store) - - // Initialize server - server.mu.Lock() - server.initialized = true - server.mu.Unlock() - - // Send tools/call to create duplicate - params := CallToolRequest{ - Name: "create_prompt", - Arguments: map[string]interface{}{ - "name": "existing", - "title": "Duplicate", - "messages": []interface{}{ - map[string]interface{}{ - "role": "user", - "content": map[string]interface{}{ - "type": "text", - "text": "Test", - }, - }, - }, - }, - } - paramsBytes, _ := json.Marshal(params) - req := Request{ - JSONRPC: "2.0", - ID: 47, - Method: "tools/call", - Params: paramsBytes, - } - - server.handle(req) - - // Read response - resp, err := readResponse(outBuf) - if err != nil { - t.Fatalf("readResponse() error = %v", err) - } - - // Should not be a protocol error - if resp.Error != nil { - t.Fatalf("Unexpected protocol error: %v", resp.Error) - } - - // Parse result - should be a tool error - resultBytes, _ := json.Marshal(resp.Result) - var result CallToolResult - if err := json.Unmarshal(resultBytes, &result); err != nil { - t.Fatalf("Unmarshal result error = %v", err) - } - - if !result.IsError { - t.Error("Expected IsError = true for duplicate prompt") - } -} - -func TestServer_Initialize_AdvertisesTools(t *testing.T) { - store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} - server, _, outBuf := createTestServer(t, store) - - // Send initialize request - params := InitializeRequest{ - ProtocolVersion: "2025-11-25", - Capabilities: ClientCapabilities{}, - ClientInfo: ClientInfo{ - Name: "test-client", - Version: "1.0", - }, - } - paramsBytes, _ := json.Marshal(params) - req := Request{ - JSONRPC: "2.0", - ID: 48, - Method: "initialize", - Params: paramsBytes, - } - - server.handle(req) - - // Read response - resp, err := readResponse(outBuf) - if err != nil { - t.Fatalf("readResponse() error = %v", err) - } - - if resp.Error != nil { - t.Fatalf("Error = %v, want nil", resp.Error) - } - - // Parse result - resultBytes, _ := json.Marshal(resp.Result) - var result InitializeResult - if err := json.Unmarshal(resultBytes, &result); err != nil { - t.Fatalf("Unmarshal result error = %v", err) - } - - // Verify Tools capability is advertised - if result.Capabilities.Tools == nil { - t.Fatal("Tools capability not advertised") - } - - // Verify Prompts capability is also still advertised - if result.Capabilities.Prompts == nil { - t.Fatal("Prompts capability not advertised") - } -} - -func TestServer_ToolsList_NotInitialized(t *testing.T) { - store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} - server, _, outBuf := createTestServer(t, store) - - // DO NOT initialize server - - // Send tools/list request - req := Request{ - JSONRPC: "2.0", - ID: 49, - Method: "tools/list", - Params: json.RawMessage(`{}`), - } - - server.handle(req) - - // Read response - resp, err := readResponse(outBuf) - if err != nil { - t.Fatalf("readResponse() error = %v", err) - } - - if resp.Error == nil { - t.Fatal("Expected error for uninitialized server") - } - - if resp.Error.Code != ErrCodeInvalidRequest { - t.Errorf("Error code = %d, want %d", resp.Error.Code, ErrCodeInvalidRequest) - } -} - -func TestServer_ToolsCall_UpdatePrompt_NotFound(t *testing.T) { - store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} - server, _, outBuf := createTestServer(t, store) - - // Initialize server - server.mu.Lock() - server.initialized = true - server.mu.Unlock() - - // Send tools/call to update non-existent prompt - params := CallToolRequest{ - Name: "update_prompt", - Arguments: map[string]interface{}{ - "name": "nonexistent", - "title": "Updated", - }, - } - paramsBytes, _ := json.Marshal(params) - req := Request{ - JSONRPC: "2.0", - ID: 50, - Method: "tools/call", - Params: paramsBytes, - } - - server.handle(req) - - // Read response - resp, err := readResponse(outBuf) - if err != nil { - t.Fatalf("readResponse() error = %v", err) - } - - // Should not be a protocol error - if resp.Error != nil { - t.Fatalf("Unexpected protocol error: %v", resp.Error) - } - - // Parse result - should be a tool error - resultBytes, _ := json.Marshal(resp.Result) - var result CallToolResult - if err := json.Unmarshal(resultBytes, &result); err != nil { - t.Fatalf("Unmarshal result error = %v", err) - } - - if !result.IsError { - t.Error("Expected IsError = true for non-existent prompt") - } -} - -func TestServer_ToolsCall_DeletePrompt_NotFound(t *testing.T) { - store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} - server, _, outBuf := createTestServer(t, store) - - // Initialize server - server.mu.Lock() - server.initialized = true - server.mu.Unlock() - - // Send tools/call to delete non-existent prompt - params := CallToolRequest{ - Name: "delete_prompt", - Arguments: map[string]interface{}{ - "name": "nonexistent", - }, - } - paramsBytes, _ := json.Marshal(params) - req := Request{ - JSONRPC: "2.0", - ID: 51, - Method: "tools/call", - Params: paramsBytes, - } - - server.handle(req) - - // Read response - resp, err := readResponse(outBuf) - if err != nil { - t.Fatalf("readResponse() error = %v", err) - } - - // Should not be a protocol error - if resp.Error != nil { - t.Fatalf("Unexpected protocol error: %v", resp.Error) - } - - // Parse result - should be a tool error - resultBytes, _ := json.Marshal(resp.Result) - var result CallToolResult - if err := json.Unmarshal(resultBytes, &result); err != nil { - t.Fatalf("Unmarshal result error = %v", err) - } - - if !result.IsError { - t.Error("Expected IsError = true for non-existent prompt") - } -} - -func TestServer_ToolsCall_InvalidParams(t *testing.T) { - store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} - server, _, outBuf := createTestServer(t, store) - - // Initialize server - server.mu.Lock() - server.initialized = true - server.mu.Unlock() - - // Send tools/call with invalid JSON params - req := Request{ - JSONRPC: "2.0", - ID: 52, - Method: "tools/call", - Params: json.RawMessage(`{invalid}`), - } - - server.handle(req) - - // Read response - resp, err := readResponse(outBuf) - if err != nil { - t.Fatalf("readResponse() error = %v", err) - } - - if resp.Error == nil { - t.Fatal("Expected protocol error for invalid params") - } - - if resp.Error.Code != ErrCodeInvalidParams { - t.Errorf("Error code = %d, want %d", resp.Error.Code, ErrCodeInvalidParams) - } -} diff --git a/internal/mcp/handlers_tool_test.go b/internal/mcp/handlers_tool_test.go new file mode 100644 index 0000000..604938a --- /dev/null +++ b/internal/mcp/handlers_tool_test.go @@ -0,0 +1,703 @@ +// Summary: Tests for MCP tool handlers (tools/list, tools/call for create/update/delete) +package mcp + +import ( + "encoding/json" + "testing" + "time" + + "codeberg.org/snonux/hexai/internal/promptstore" +) + +func TestServer_ToolsList(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + server, _, outBuf := createTestServer(t, store) + + // Initialize server + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send tools/list request + req := Request{ + JSONRPC: "2.0", + ID: 40, + Method: "tools/list", + Params: json.RawMessage(`{}`), + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error != nil { + t.Fatalf("Error = %v, want nil", resp.Error) + } + + // Parse result + resultBytes, _ := json.Marshal(resp.Result) + var result ListToolsResult + if err := json.Unmarshal(resultBytes, &result); err != nil { + t.Fatalf("Unmarshal result error = %v", err) + } + + // Verify 3 tools returned + if len(result.Tools) != 3 { + t.Errorf("Tools count = %d, want 3", len(result.Tools)) + } + + // Verify tool names + toolNames := make(map[string]bool) + for _, tool := range result.Tools { + toolNames[tool.Name] = true + if tool.Description == "" { + t.Errorf("Tool %s has empty description", tool.Name) + } + if tool.InputSchema == nil { + t.Errorf("Tool %s has nil InputSchema", tool.Name) + } + } + + expectedTools := []string{"create_prompt", "update_prompt", "delete_prompt"} + for _, name := range expectedTools { + if !toolNames[name] { + t.Errorf("Missing expected tool: %s", name) + } + } +} + +func TestServer_ToolsCall_CreatePrompt(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + server, _, outBuf := createTestServer(t, store) + + // Initialize server + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send tools/call request + params := CallToolRequest{ + Name: "create_prompt", + Arguments: map[string]interface{}{ + "name": "tool_test", + "title": "Tool Test Prompt", + "messages": []interface{}{ + map[string]interface{}{ + "role": "user", + "content": map[string]interface{}{ + "type": "text", + "text": "Test message", + }, + }, + }, + }, + } + paramsBytes, _ := json.Marshal(params) + req := Request{ + JSONRPC: "2.0", + ID: 41, + Method: "tools/call", + Params: paramsBytes, + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error != nil { + t.Fatalf("Error = %v, want nil", resp.Error) + } + + // Parse result + resultBytes, _ := json.Marshal(resp.Result) + var result CallToolResult + if err := json.Unmarshal(resultBytes, &result); err != nil { + t.Fatalf("Unmarshal result error = %v", err) + } + + if result.IsError { + t.Errorf("IsError = true, want false. Content: %v", result.Content) + } + + if len(result.Content) == 0 { + t.Fatal("Expected content in result") + } + + // Verify prompt was created + if _, exists := store.prompts["tool_test"]; !exists { + t.Error("Prompt was not created in store") + } +} + +func TestServer_ToolsCall_UpdatePrompt(t *testing.T) { + now := time.Now() + store := &mockPromptStore{ + prompts: map[string]*promptstore.Prompt{ + "tool_update": { + Name: "tool_update", + Title: "Original Title", + Created: now, + Updated: now, + Messages: []promptstore.PromptMessage{ + { + Role: "user", + Content: promptstore.MessageContent{ + Type: "text", + Text: "Original", + }, + }, + }, + }, + }, + } + + server, _, outBuf := createTestServer(t, store) + + // Initialize server + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send tools/call request + params := CallToolRequest{ + Name: "update_prompt", + Arguments: map[string]interface{}{ + "name": "tool_update", + "title": "Updated Title", + }, + } + paramsBytes, _ := json.Marshal(params) + req := Request{ + JSONRPC: "2.0", + ID: 42, + Method: "tools/call", + Params: paramsBytes, + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error != nil { + t.Fatalf("Error = %v, want nil", resp.Error) + } + + // Parse result + resultBytes, _ := json.Marshal(resp.Result) + var result CallToolResult + if err := json.Unmarshal(resultBytes, &result); err != nil { + t.Fatalf("Unmarshal result error = %v", err) + } + + if result.IsError { + t.Errorf("IsError = true, want false. Content: %v", result.Content) + } + + // Verify prompt was updated + if store.prompts["tool_update"].Title != "Updated Title" { + t.Errorf("Title not updated, got %s", store.prompts["tool_update"].Title) + } +} + +func TestServer_ToolsCall_DeletePrompt(t *testing.T) { + now := time.Now() + store := &mockPromptStore{ + prompts: map[string]*promptstore.Prompt{ + "tool_delete": { + Name: "tool_delete", + Title: "To Delete", + Created: now, + Updated: now, + Messages: []promptstore.PromptMessage{ + { + Role: "user", + Content: promptstore.MessageContent{ + Type: "text", + Text: "Test", + }, + }, + }, + }, + }, + } + + server, _, outBuf := createTestServer(t, store) + + // Initialize server + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send tools/call request + params := CallToolRequest{ + Name: "delete_prompt", + Arguments: map[string]interface{}{ + "name": "tool_delete", + }, + } + paramsBytes, _ := json.Marshal(params) + req := Request{ + JSONRPC: "2.0", + ID: 43, + Method: "tools/call", + Params: paramsBytes, + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error != nil { + t.Fatalf("Error = %v, want nil", resp.Error) + } + + // Parse result + resultBytes, _ := json.Marshal(resp.Result) + var result CallToolResult + if err := json.Unmarshal(resultBytes, &result); err != nil { + t.Fatalf("Unmarshal result error = %v", err) + } + + if result.IsError { + t.Errorf("IsError = true, want false. Content: %v", result.Content) + } + + // Verify prompt was deleted + if _, exists := store.prompts["tool_delete"]; exists { + t.Error("Prompt was not deleted from store") + } +} + +func TestServer_ToolsCall_UnknownTool(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + server, _, outBuf := createTestServer(t, store) + + // Initialize server + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send tools/call request with unknown tool + params := CallToolRequest{ + Name: "nonexistent_tool", + Arguments: map[string]interface{}{}, + } + paramsBytes, _ := json.Marshal(params) + req := Request{ + JSONRPC: "2.0", + ID: 44, + Method: "tools/call", + Params: paramsBytes, + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + // Should not be a protocol error + if resp.Error != nil { + t.Fatalf("Unexpected protocol error: %v", resp.Error) + } + + // Parse result - should be a tool error + resultBytes, _ := json.Marshal(resp.Result) + var result CallToolResult + if err := json.Unmarshal(resultBytes, &result); err != nil { + t.Fatalf("Unmarshal result error = %v", err) + } + + if !result.IsError { + t.Error("Expected IsError = true for unknown tool") + } +} + +func TestServer_ToolsCall_InvalidArguments(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + server, _, outBuf := createTestServer(t, store) + + // Initialize server + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send tools/call with invalid arguments (missing required fields) + params := CallToolRequest{ + Name: "create_prompt", + Arguments: map[string]interface{}{ + "name": "test", // Missing title and messages + }, + } + paramsBytes, _ := json.Marshal(params) + req := Request{ + JSONRPC: "2.0", + ID: 45, + Method: "tools/call", + Params: paramsBytes, + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + // Should not be a protocol error + if resp.Error != nil { + t.Fatalf("Unexpected protocol error: %v", resp.Error) + } + + // Parse result - should be a tool error + resultBytes, _ := json.Marshal(resp.Result) + var result CallToolResult + if err := json.Unmarshal(resultBytes, &result); err != nil { + t.Fatalf("Unmarshal result error = %v", err) + } + + if !result.IsError { + t.Error("Expected IsError = true for invalid arguments") + } +} + +func TestServer_ToolsCall_NotInitialized(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + server, _, outBuf := createTestServer(t, store) + + // DO NOT initialize server + + // Send tools/call request + params := CallToolRequest{ + Name: "create_prompt", + Arguments: map[string]interface{}{}, + } + paramsBytes, _ := json.Marshal(params) + req := Request{ + JSONRPC: "2.0", + ID: 46, + Method: "tools/call", + Params: paramsBytes, + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error == nil { + t.Fatal("Expected error for uninitialized server") + } + + if resp.Error.Code != ErrCodeInvalidRequest { + t.Errorf("Error code = %d, want %d", resp.Error.Code, ErrCodeInvalidRequest) + } +} + +func TestServer_ToolsCall_CreatePrompt_AlreadyExists(t *testing.T) { + now := time.Now() + store := &mockPromptStore{ + prompts: map[string]*promptstore.Prompt{ + "existing": { + Name: "existing", + Title: "Existing", + Created: now, + Updated: now, + Messages: []promptstore.PromptMessage{}, + }, + }, + } + + server, _, outBuf := createTestServer(t, store) + + // Initialize server + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send tools/call to create duplicate + params := CallToolRequest{ + Name: "create_prompt", + Arguments: map[string]interface{}{ + "name": "existing", + "title": "Duplicate", + "messages": []interface{}{ + map[string]interface{}{ + "role": "user", + "content": map[string]interface{}{ + "type": "text", + "text": "Test", + }, + }, + }, + }, + } + paramsBytes, _ := json.Marshal(params) + req := Request{ + JSONRPC: "2.0", + ID: 47, + Method: "tools/call", + Params: paramsBytes, + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + // Should not be a protocol error + if resp.Error != nil { + t.Fatalf("Unexpected protocol error: %v", resp.Error) + } + + // Parse result - should be a tool error + resultBytes, _ := json.Marshal(resp.Result) + var result CallToolResult + if err := json.Unmarshal(resultBytes, &result); err != nil { + t.Fatalf("Unmarshal result error = %v", err) + } + + if !result.IsError { + t.Error("Expected IsError = true for duplicate prompt") + } +} + +func TestServer_Initialize_AdvertisesTools(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + server, _, outBuf := createTestServer(t, store) + + // Send initialize request + params := InitializeRequest{ + ProtocolVersion: "2025-11-25", + Capabilities: ClientCapabilities{}, + ClientInfo: ClientInfo{ + Name: "test-client", + Version: "1.0", + }, + } + paramsBytes, _ := json.Marshal(params) + req := Request{ + JSONRPC: "2.0", + ID: 48, + Method: "initialize", + Params: paramsBytes, + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error != nil { + t.Fatalf("Error = %v, want nil", resp.Error) + } + + // Parse result + resultBytes, _ := json.Marshal(resp.Result) + var result InitializeResult + if err := json.Unmarshal(resultBytes, &result); err != nil { + t.Fatalf("Unmarshal result error = %v", err) + } + + // Verify Tools capability is advertised + if result.Capabilities.Tools == nil { + t.Fatal("Tools capability not advertised") + } + + // Verify Prompts capability is also still advertised + if result.Capabilities.Prompts == nil { + t.Fatal("Prompts capability not advertised") + } +} + +func TestServer_ToolsList_NotInitialized(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + server, _, outBuf := createTestServer(t, store) + + // DO NOT initialize server + + // Send tools/list request + req := Request{ + JSONRPC: "2.0", + ID: 49, + Method: "tools/list", + Params: json.RawMessage(`{}`), + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error == nil { + t.Fatal("Expected error for uninitialized server") + } + + if resp.Error.Code != ErrCodeInvalidRequest { + t.Errorf("Error code = %d, want %d", resp.Error.Code, ErrCodeInvalidRequest) + } +} + +func TestServer_ToolsCall_UpdatePrompt_NotFound(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + server, _, outBuf := createTestServer(t, store) + + // Initialize server + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send tools/call to update non-existent prompt + params := CallToolRequest{ + Name: "update_prompt", + Arguments: map[string]interface{}{ + "name": "nonexistent", + "title": "Updated", + }, + } + paramsBytes, _ := json.Marshal(params) + req := Request{ + JSONRPC: "2.0", + ID: 50, + Method: "tools/call", + Params: paramsBytes, + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + // Should not be a protocol error + if resp.Error != nil { + t.Fatalf("Unexpected protocol error: %v", resp.Error) + } + + // Parse result - should be a tool error + resultBytes, _ := json.Marshal(resp.Result) + var result CallToolResult + if err := json.Unmarshal(resultBytes, &result); err != nil { + t.Fatalf("Unmarshal result error = %v", err) + } + + if !result.IsError { + t.Error("Expected IsError = true for non-existent prompt") + } +} + +func TestServer_ToolsCall_DeletePrompt_NotFound(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + server, _, outBuf := createTestServer(t, store) + + // Initialize server + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send tools/call to delete non-existent prompt + params := CallToolRequest{ + Name: "delete_prompt", + Arguments: map[string]interface{}{ + "name": "nonexistent", + }, + } + paramsBytes, _ := json.Marshal(params) + req := Request{ + JSONRPC: "2.0", + ID: 51, + Method: "tools/call", + Params: paramsBytes, + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + // Should not be a protocol error + if resp.Error != nil { + t.Fatalf("Unexpected protocol error: %v", resp.Error) + } + + // Parse result - should be a tool error + resultBytes, _ := json.Marshal(resp.Result) + var result CallToolResult + if err := json.Unmarshal(resultBytes, &result); err != nil { + t.Fatalf("Unmarshal result error = %v", err) + } + + if !result.IsError { + t.Error("Expected IsError = true for non-existent prompt") + } +} + +func TestServer_ToolsCall_InvalidParams(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + server, _, outBuf := createTestServer(t, store) + + // Initialize server + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send tools/call with invalid JSON params + req := Request{ + JSONRPC: "2.0", + ID: 52, + Method: "tools/call", + Params: json.RawMessage(`{invalid}`), + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error == nil { + t.Fatal("Expected protocol error for invalid params") + } + + if resp.Error.Code != ErrCodeInvalidParams { + t.Errorf("Error code = %d, want %d", resp.Error.Code, ErrCodeInvalidParams) + } +} diff --git a/internal/promptstore/default_prompts.go b/internal/promptstore/default_prompts.go index f8edefc..434855a 100644 --- a/internal/promptstore/default_prompts.go +++ b/internal/promptstore/default_prompts.go @@ -6,34 +6,50 @@ import ( ) // DefaultPrompts returns the built-in meta-prompts for prompt management. -// These prompts help users create, update, delete, and design prompts interactively -// using Claude's access to conversation context. +// It assembles prompts from category-specific helpers so each remains concise. func DefaultPrompts() []Prompt { now := time.Now() - return []Prompt{ - { - Name: "save_prompt", - Title: "Save Current Conversation as Prompt", - Description: "Interactively create a new prompt template from the current conversation. Claude will analyze the conversation, ask clarifying questions about templating, show a preview, and wait for approval before saving.", - Arguments: []PromptArgument{ - { - Name: "prompt_name", - Description: "Unique identifier for the new prompt (lowercase, underscores allowed)", - Required: true, - }, - { - Name: "prompt_title", - Description: "Human-readable display name for the new prompt", - Required: true, - }, - }, - Messages: []PromptMessage{ - { - Role: "user", - Content: MessageContent{ - Type: "text", - Text: `I want to create a new prompt template named '{{prompt_name}}' with title '{{prompt_title}}'. + buildSavePrompt(now), + buildUpdatePrompt(now), + buildDeletePrompt(now), + buildDesignPrompt(now), + } +} + +// formattingRules is the shared instruction block for clarifying-question formatting +// used by the save and update meta-prompts. +const formattingRules = `IMPORTANT FORMATTING RULES for clarifying questions: +- Use numbered questions: 1), 2), 3) +- ANY CHOICE MUST BE NUMBERED using combined format: 1a), 1b), 1c), 2a), 2b), etc. +- NEVER use standalone letters like "a)" - always combine with question number +- NEVER use dashes (-) or bullets (•) for options +- Every option must be numbered for easy selection by the user` + +// buildSavePrompt creates the "save_prompt" meta-prompt that turns a conversation +// into a reusable prompt template. +func buildSavePrompt(now time.Time) Prompt { + return Prompt{ + Name: "save_prompt", + Title: "Save Current Conversation as Prompt", + Description: "Interactively create a new prompt template from the current conversation. Claude will analyze the conversation, ask clarifying questions about templating, show a preview, and wait for approval before saving.", + Arguments: []PromptArgument{ + {Name: "prompt_name", Description: "Unique identifier for the new prompt (lowercase, underscores allowed)", Required: true}, + {Name: "prompt_title", Description: "Human-readable display name for the new prompt", Required: true}, + }, + Messages: []PromptMessage{{ + Role: "user", + Content: MessageContent{Type: "text", Text: savePromptText()}, + }}, + Tags: []string{"meta", "prompt-management", "interactive"}, + Created: now, + Updated: now, + } +} + +// savePromptText returns the user message body for the save_prompt meta-prompt. +func savePromptText() string { + return `I want to create a new prompt template named '{{prompt_name}}' with title '{{prompt_title}}'. Please help me by: 1) Analyzing our current conversation to understand what should be templated @@ -45,12 +61,7 @@ Please help me by: 3) Showing me a complete preview of the prompt structure in a code block 4) Only after I approve, use the create_prompt tool to save it -IMPORTANT FORMATTING RULES for clarifying questions: -- Use numbered questions: 1), 2), 3) -- ANY CHOICE MUST BE NUMBERED using combined format: 1a), 1b), 1c), 2a), 2b), etc. -- NEVER use standalone letters like "a)" - always combine with question number -- NEVER use dashes (-) or bullets (•) for options -- Every option must be numbered for easy selection by the user +` + formattingRules + ` Examples: 1) Question Category @@ -77,31 +88,32 @@ Examples: 4b) Sub-question two? Answer options here -Start by examining our conversation and asking your clarifying questions using this format.`, - }, - }, - }, - Tags: []string{"meta", "prompt-management", "interactive"}, - Created: now, - Updated: now, +Start by examining our conversation and asking your clarifying questions using this format.` +} + +// buildUpdatePrompt creates the "update_prompt" meta-prompt for modifying +// an existing prompt interactively. +func buildUpdatePrompt(now time.Time) Prompt { + return Prompt{ + Name: "update_prompt", + Title: "Update Existing Prompt", + Description: "Interactively modify an existing prompt. Claude will fetch the current version, ask what changes you want, show a preview with changes highlighted, and wait for approval before updating.", + Arguments: []PromptArgument{ + {Name: "prompt_name", Description: "Name of the existing prompt to update", Required: true}, }, - { - Name: "update_prompt", - Title: "Update Existing Prompt", - Description: "Interactively modify an existing prompt. Claude will fetch the current version, ask what changes you want, show a preview with changes highlighted, and wait for approval before updating.", - Arguments: []PromptArgument{ - { - Name: "prompt_name", - Description: "Name of the existing prompt to update", - Required: true, - }, - }, - Messages: []PromptMessage{ - { - Role: "user", - Content: MessageContent{ - Type: "text", - Text: `I want to update the existing prompt '{{prompt_name}}'. + Messages: []PromptMessage{{ + Role: "user", + Content: MessageContent{Type: "text", Text: updatePromptText()}, + }}, + Tags: []string{"meta", "prompt-management", "interactive"}, + Created: now, + Updated: now, + } +} + +// updatePromptText returns the user message body for the update_prompt meta-prompt. +func updatePromptText() string { + return `I want to update the existing prompt '{{prompt_name}}'. Please help me by: 1) Ask me what changes I want to make to the prompt '{{prompt_name}}' (title, description, arguments, messages, or tags) @@ -109,12 +121,7 @@ Please help me by: 3) Show me a complete preview of the updated prompt with changes highlighted 4) Only after I approve, use the update_prompt tool to save the changes -IMPORTANT FORMATTING RULES for clarifying questions: -- Use numbered questions: 1), 2), 3) -- ANY CHOICE MUST BE NUMBERED using combined format: 1a), 1b), 1c), 2a), 2b), etc. -- NEVER use standalone letters like "a)" - always combine with question number -- NEVER use dashes (-) or bullets (•) for options -- Every option must be numbered for easy selection by the user +` + formattingRules + ` Examples: 1) Question Category @@ -136,31 +143,22 @@ Examples: 3b) Second aspect to evaluate 3c) Third aspect to evaluate -Start by asking me what changes I want to make, using this format for any clarifying questions.`, - }, - }, - }, - Tags: []string{"meta", "prompt-management", "interactive"}, - Created: now, - Updated: now, +Start by asking me what changes I want to make, using this format for any clarifying questions.` +} + +// buildDeletePrompt creates the "delete_prompt" meta-prompt for safely removing +// a custom prompt with explicit confirmation. +func buildDeletePrompt(now time.Time) Prompt { + return Prompt{ + Name: "delete_prompt", + Title: "Delete Custom Prompt", + Description: "Interactively delete an existing custom prompt with confirmation. Claude will show the current prompt, ask for confirmation, and only delete after explicit approval. Built-in prompts cannot be deleted.", + Arguments: []PromptArgument{ + {Name: "prompt_name", Description: "Name of the existing prompt to delete", Required: true}, }, - { - Name: "delete_prompt", - Title: "Delete Custom Prompt", - Description: "Interactively delete an existing custom prompt with confirmation. Claude will show the current prompt, ask for confirmation, and only delete after explicit approval. Built-in prompts cannot be deleted.", - Arguments: []PromptArgument{ - { - Name: "prompt_name", - Description: "Name of the existing prompt to delete", - Required: true, - }, - }, - Messages: []PromptMessage{ - { - Role: "user", - Content: MessageContent{ - Type: "text", - Text: `I want to delete the existing prompt '{{prompt_name}}'. + Messages: []PromptMessage{{ + Role: "user", + Content: MessageContent{Type: "text", Text: `I want to delete the existing prompt '{{prompt_name}}'. Please help me by: 1) Confirm with me that I want to delete the prompt named '{{prompt_name}}' @@ -173,25 +171,25 @@ IMPORTANT NOTES: - Only custom prompts stored in user.jsonl can be deleted - Backups are automatically created before deletion -Ask me to confirm the deletion of '{{prompt_name}}'.`, - }, - }, - }, - Tags: []string{"meta", "prompt-management", "interactive"}, - Created: now, - Updated: now, - }, - { - Name: "design_prompt", - Title: "Design New Prompt from Scratch", - Description: "Interactively design a brand new prompt template through guided questions. Claude will help you define the purpose, arguments, message flow, and metadata step by step, show a preview, and wait for approval before saving.", - Arguments: []PromptArgument{}, - Messages: []PromptMessage{ - { - Role: "user", - Content: MessageContent{ - Type: "text", - Text: `I want to design a brand new prompt template from scratch. +Ask me to confirm the deletion of '{{prompt_name}}'.`}, + }}, + Tags: []string{"meta", "prompt-management", "interactive"}, + Created: now, + Updated: now, + } +} + +// buildDesignPrompt creates the "design_prompt" meta-prompt for building +// a brand new prompt template from scratch through guided questions. +func buildDesignPrompt(now time.Time) Prompt { + return Prompt{ + Name: "design_prompt", + Title: "Design New Prompt from Scratch", + Description: "Interactively design a brand new prompt template through guided questions. Claude will help you define the purpose, arguments, message flow, and metadata step by step, show a preview, and wait for approval before saving.", + Arguments: []PromptArgument{}, + Messages: []PromptMessage{{ + Role: "user", + Content: MessageContent{Type: "text", Text: `I want to design a brand new prompt template from scratch. Please ask me: 1) What should this prompt do? (describe the task/purpose in 1-2 sentences) @@ -200,13 +198,10 @@ Please ask me: Then show me a preview and save it after I approve. -Keep questions brief and focused.`, - }, - }, - }, - Tags: []string{"meta", "prompt-management", "interactive", "creation"}, - Created: now, - Updated: now, - }, +Keep questions brief and focused.`}, + }}, + Tags: []string{"meta", "prompt-management", "interactive", "creation"}, + Created: now, + Updated: now, } } diff --git a/internal/stats/stats.go b/internal/stats/stats.go index 95981c5..4b05617 100644 --- a/internal/stats/stats.go +++ b/internal/stats/stats.go @@ -83,54 +83,85 @@ func Update(ctx context.Context, provider, model string, sentBytes, recvBytes in if err := os.MkdirAll(dir, 0o755); err != nil { return err } + unlock, err := lockStatsFile(ctx, dir) + if err != nil { + return err + } + defer func() { _ = unlock() }() + + path := filepath.Join(dir, fileName) + sf := readStatsFile(path) + now := time.Now() + win := Window() + sf.WindowSeconds = int(win.Seconds()) + sf.Events = append(sf.Events, Event{ + TS: now, Provider: provider, Model: model, + Sent: int64(sentBytes), Recv: int64(recvBytes), + }) + pruneOldEvents(&sf, now.Add(-win)) + sf.UpdatedAt = now + return writeStatsFileAtomic(dir, path, &sf) +} + +// lockStatsFile acquires an advisory file lock on the stats lock file within dir. +// Returns an unlock function on success. +func lockStatsFile(ctx context.Context, dir string) (func() error, error) { lockPath := filepath.Join(dir, lockFileName) f, err := os.OpenFile(lockPath, os.O_CREATE|os.O_RDWR, 0o600) if err != nil { - return err + return nil, err } - defer func() { _ = f.Close() }() unlock, err := acquireFileLock(ctx, f) if err != nil { - return err + _ = f.Close() + return nil, err } - defer func() { _ = unlock() }() - // Read existing file (if any) - path := filepath.Join(dir, fileName) + // Return a combined unlock+close function so the caller only needs one defer. + return func() error { + uErr := unlock() + cErr := f.Close() + if uErr != nil { + return uErr + } + return cErr + }, nil +} + +// readStatsFile loads the on-disk stats file, returning a fresh File if it is +// missing or has an incompatible version. +func readStatsFile(path string) File { var sf File - if b, rerr := os.ReadFile(path); rerr == nil { + if b, err := os.ReadFile(path); err == nil { _ = json.Unmarshal(b, &sf) } if sf.Version != fileVersion { sf = File{Version: fileVersion} } - now := time.Now() - win := Window() - sf.WindowSeconds = int(win.Seconds()) - // Append event - sf.Events = append(sf.Events, Event{TS: now, Provider: provider, Model: model, Sent: int64(sentBytes), Recv: int64(recvBytes)}) - // Prune old - cutoff := now.Add(-win) - if len(sf.Events) > 0 { - // Find first >= cutoff - i := 0 - for ; i < len(sf.Events); i++ { - if !sf.Events[i].TS.Before(cutoff) { - break - } - } - if i > 0 { - sf.Events = append([]Event(nil), sf.Events[i:]...) + return sf +} + +// pruneOldEvents removes events older than cutoff in-place. +func pruneOldEvents(sf *File, cutoff time.Time) { + i := 0 + for ; i < len(sf.Events); i++ { + if !sf.Events[i].TS.Before(cutoff) { + break } } - sf.UpdatedAt = now - // Write atomically + if i > 0 { + sf.Events = append([]Event(nil), sf.Events[i:]...) + } +} + +// writeStatsFileAtomic writes sf to path via a temp file + rename for crash safety. +func writeStatsFileAtomic(dir, path string, sf *File) error { tmp, err := os.CreateTemp(dir, fileName+".tmp.") if err != nil { return err } enc := json.NewEncoder(tmp) enc.SetEscapeHTML(false) - if err := enc.Encode(&sf); err != nil { + if err := enc.Encode(sf); err != nil { _ = tmp.Close() _ = os.Remove(tmp.Name()) return err |
