diff options
Diffstat (limited to 'internal/mcp/server_test.go')
| -rw-r--r-- | internal/mcp/server_test.go | 505 |
1 files changed, 505 insertions, 0 deletions
diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go new file mode 100644 index 0000000..4a14ffc --- /dev/null +++ b/internal/mcp/server_test.go @@ -0,0 +1,505 @@ +// Summary: Tests for MCP server operations +package mcp + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "log" + "strings" + "testing" + "time" + + "codeberg.org/snonux/hexai/internal/promptstore" +) + +// mockPromptStore implements PromptStore for testing +type mockPromptStore struct { + prompts map[string]*promptstore.Prompt + listFn func(string, int) ([]promptstore.Prompt, string, error) + getFn func(string) (*promptstore.Prompt, error) +} + +func (m *mockPromptStore) List(cursor string, limit int) ([]promptstore.Prompt, string, error) { + if m.listFn != nil { + return m.listFn(cursor, limit) + } + var all []promptstore.Prompt + for _, p := range m.prompts { + all = append(all, *p) + } + return all, "", nil +} + +func (m *mockPromptStore) Get(name string) (*promptstore.Prompt, error) { + if m.getFn != nil { + return m.getFn(name) + } + p, ok := m.prompts[name] + if !ok { + return nil, fmt.Errorf("prompt not found: %s", name) + } + return p, nil +} + +// Create, Update, Delete methods moved to handlers_test.go to avoid duplication + +func (m *mockPromptStore) SearchByTags(tags []string) ([]promptstore.Prompt, error) { + return nil, fmt.Errorf("not implemented") +} + +func createTestServer(t *testing.T, store promptstore.PromptStore) (*Server, *bytes.Buffer, *bytes.Buffer) { + t.Helper() + inBuf := &bytes.Buffer{} + outBuf := &bytes.Buffer{} + logger := log.New(io.Discard, "", 0) + return NewServer(inBuf, outBuf, logger, store), inBuf, outBuf +} + +func sendRequest(w io.Writer, req Request) error { + data, err := json.Marshal(req) + if err != nil { + return err + } + header := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(data)) + if _, err := io.WriteString(w, header); err != nil { + return err + } + if _, err := w.Write(data); err != nil { + return err + } + return nil +} + +func readResponse(r io.Reader) (*Response, error) { + // Simple read for testing (assumes one message in buffer) + data, err := io.ReadAll(r) + if err != nil { + return nil, err + } + + // Find Content-Length + lines := strings.Split(string(data), "\r\n") + var contentLength int + bodyStart := 0 + for i, line := range lines { + if strings.HasPrefix(line, "Content-Length:") { + fmt.Sscanf(line, "Content-Length: %d", &contentLength) + } + if line == "" { + bodyStart = i + 1 + break + } + } + + body := strings.Join(lines[bodyStart:], "\r\n") + var resp Response + if err := json.Unmarshal([]byte(body), &resp); err != nil { + return nil, fmt.Errorf("unmarshal response: %w, body: %s", err, body) + } + return &resp, nil +} + +func TestServer_Initialize(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + server, inBuf, outBuf := createTestServer(t, store) + + // Send initialize request + req := Request{ + JSONRPC: "2.0", + ID: 1, + Method: "initialize", + } + params := InitializeRequest{ + ProtocolVersion: "2025-06-18", + Capabilities: ClientCapabilities{}, + ClientInfo: ClientInfo{ + Name: "test-client", + Version: "1.0", + }, + } + req.Params, _ = json.Marshal(params) + + if err := sendRequest(inBuf, req); err != nil { + t.Fatalf("sendRequest() error = %v", err) + } + + // Handle request + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.JSONRPC != "2.0" { + t.Errorf("JSONRPC = %v, want 2.0", resp.JSONRPC) + } + // ID comparison: JSON unmarshaling may convert int to float64 + if fmt.Sprintf("%v", resp.ID) != "1" { + t.Errorf("ID = %v (type %T), want 1", resp.ID, resp.ID) + } + if resp.Error != nil { + t.Errorf("Error = %v, want nil", resp.Error) + } + + // Verify server is initialized + server.mu.RLock() + initialized := server.initialized + server.mu.RUnlock() + if !initialized { + t.Error("Server not initialized") + } +} + +func TestServer_PromptsList(t *testing.T) { + now := time.Now() + store := &mockPromptStore{ + prompts: map[string]*promptstore.Prompt{ + "test1": { + Name: "test1", + Title: "Test Prompt 1", + Description: "First test prompt", + Messages: []promptstore.PromptMessage{}, + Created: now, + Updated: now, + }, + "test2": { + Name: "test2", + Title: "Test Prompt 2", + Description: "Second test prompt", + Messages: []promptstore.PromptMessage{}, + Created: now, + Updated: now, + }, + }, + } + + server, _, outBuf := createTestServer(t, store) + + // Initialize server first + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send prompts/list request + req := Request{ + JSONRPC: "2.0", + ID: 2, + Method: "prompts/list", + Params: json.RawMessage(`{}`), + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error != nil { + t.Fatalf("Error = %v, want nil", resp.Error) + } + + // Parse result + resultBytes, _ := json.Marshal(resp.Result) + var result ListPromptsResult + if err := json.Unmarshal(resultBytes, &result); err != nil { + t.Fatalf("Unmarshal result error = %v", err) + } + + if len(result.Prompts) != 2 { + t.Errorf("Prompts count = %d, want 2", len(result.Prompts)) + } +} + +func TestServer_PromptsGet(t *testing.T) { + now := time.Now() + store := &mockPromptStore{ + prompts: map[string]*promptstore.Prompt{ + "test_prompt": { + Name: "test_prompt", + Title: "Test Prompt", + Description: "A test prompt", + Arguments: []promptstore.PromptArgument{ + {Name: "name", Description: "User name", Required: true}, + }, + Messages: []promptstore.PromptMessage{ + { + Role: "user", + Content: promptstore.MessageContent{ + Type: "text", + Text: "Hello {{name}}!", + }, + }, + }, + Created: now, + Updated: now, + }, + }, + } + + server, _, outBuf := createTestServer(t, store) + + // Initialize server + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send prompts/get request + params := GetPromptRequest{ + Name: "test_prompt", + Arguments: map[string]string{ + "name": "World", + }, + } + paramsBytes, _ := json.Marshal(params) + req := Request{ + JSONRPC: "2.0", + ID: 3, + Method: "prompts/get", + Params: paramsBytes, + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error != nil { + t.Fatalf("Error = %v, want nil", resp.Error) + } + + // Parse result + resultBytes, _ := json.Marshal(resp.Result) + var result GetPromptResult + if err := json.Unmarshal(resultBytes, &result); err != nil { + t.Fatalf("Unmarshal result error = %v", err) + } + + if len(result.Messages) != 1 { + t.Errorf("Messages count = %d, want 1", len(result.Messages)) + } + if result.Messages[0].Content.Text != "Hello World!" { + t.Errorf("Message text = %v, want 'Hello World!'", result.Messages[0].Content.Text) + } +} + +func TestServer_PromptsGet_MissingArg(t *testing.T) { + now := time.Now() + store := &mockPromptStore{ + prompts: map[string]*promptstore.Prompt{ + "test_prompt": { + Name: "test_prompt", + Title: "Test Prompt", + Description: "A test prompt", + Arguments: []promptstore.PromptArgument{ + {Name: "required_arg", Description: "Required", Required: true}, + }, + Messages: []promptstore.PromptMessage{ + { + Role: "user", + Content: promptstore.MessageContent{ + Type: "text", + Text: "Test: {{required_arg}}", + }, + }, + }, + Created: now, + Updated: now, + }, + }, + } + + server, _, outBuf := createTestServer(t, store) + + // Initialize server + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send prompts/get request without required argument + params := GetPromptRequest{ + Name: "test_prompt", + Arguments: map[string]string{}, + } + paramsBytes, _ := json.Marshal(params) + req := Request{ + JSONRPC: "2.0", + ID: 4, + Method: "prompts/get", + Params: paramsBytes, + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error == nil { + t.Fatal("Expected error for missing required argument") + } + if !strings.Contains(resp.Error.Message, "required_arg") { + t.Errorf("Error message = %v, want to contain 'required_arg'", resp.Error.Message) + } +} + +func TestServer_MethodNotFound(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + server, _, outBuf := createTestServer(t, store) + + // Send request with unknown method + req := Request{ + JSONRPC: "2.0", + ID: 5, + Method: "unknown/method", + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error == nil { + t.Fatal("Expected error for unknown method") + } + if resp.Error.Code != ErrCodeMethodNotFound { + t.Errorf("Error code = %d, want %d", resp.Error.Code, ErrCodeMethodNotFound) + } +} + +func TestServer_Run(t *testing.T) { + t.Run("exits on EOF", func(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + // Empty reader will cause immediate EOF + inBuf := &bytes.Buffer{} + outBuf := &bytes.Buffer{} + logger := log.New(io.Discard, "", 0) + server := NewServer(inBuf, outBuf, logger, store) + + err := server.Run() + if err != nil { + t.Errorf("Run() error = %v, want nil on EOF", err) + } + }) + + t.Run("processes initialize request", func(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + inBuf := &bytes.Buffer{} + outBuf := &bytes.Buffer{} + logger := log.New(io.Discard, "", 0) + server := NewServer(inBuf, outBuf, logger, store) + + // Send initialize request + req := Request{ + JSONRPC: "2.0", + ID: 1, + Method: "initialize", + } + params := InitializeRequest{ + ProtocolVersion: "2025-06-18", + Capabilities: ClientCapabilities{}, + ClientInfo: ClientInfo{ + Name: "test-client", + Version: "1.0", + }, + } + req.Params, _ = json.Marshal(params) + + if err := sendRequest(inBuf, req); err != nil { + t.Fatalf("sendRequest() error = %v", err) + } + + // Run in background + done := make(chan error, 1) + go func() { + done <- server.Run() + }() + + // Give time for processing (server will block waiting for more input) + time.Sleep(50 * time.Millisecond) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error != nil { + t.Errorf("Error = %v, want nil", resp.Error) + } + + // Verify server is initialized + server.mu.RLock() + initialized := server.initialized + server.mu.RUnlock() + if !initialized { + t.Error("Server not initialized") + } + }) +} + +func TestServer_ReadMessage(t *testing.T) { + t.Run("reads valid message", func(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + inBuf := &bytes.Buffer{} + outBuf := &bytes.Buffer{} + logger := log.New(io.Discard, "", 0) + server := NewServer(inBuf, outBuf, logger, store) + + // Write a message with proper framing + msg := []byte(`{"jsonrpc":"2.0","id":1,"method":"test"}`) + header := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(msg)) + inBuf.WriteString(header) + inBuf.Write(msg) + + // Read it back + body, err := server.readMessage() + if err != nil { + t.Fatalf("readMessage() error = %v", err) + } + + if string(body) != string(msg) { + t.Errorf("readMessage() = %s, want %s", body, msg) + } + }) + + t.Run("returns EOF on empty stream", func(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + inBuf := &bytes.Buffer{} + outBuf := &bytes.Buffer{} + logger := log.New(io.Discard, "", 0) + server := NewServer(inBuf, outBuf, logger, store) + + _, err := server.readMessage() + if err != io.EOF { + t.Errorf("readMessage() error = %v, want EOF", err) + } + }) +} + +func TestServer_HandleInitialized(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + server, _, _ := createTestServer(t, store) + + // Send initialized notification + req := Request{ + JSONRPC: "2.0", + Method: "initialized", + } + + // This should not error (it's a notification, no response expected) + server.handleInitialized(req) +} |
