summaryrefslogtreecommitdiff
path: root/internal/mcp/server_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/mcp/server_test.go')
-rw-r--r--internal/mcp/server_test.go505
1 files changed, 505 insertions, 0 deletions
diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go
new file mode 100644
index 0000000..4a14ffc
--- /dev/null
+++ b/internal/mcp/server_test.go
@@ -0,0 +1,505 @@
+// Summary: Tests for MCP server operations
+package mcp
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "io"
+ "log"
+ "strings"
+ "testing"
+ "time"
+
+ "codeberg.org/snonux/hexai/internal/promptstore"
+)
+
+// mockPromptStore implements PromptStore for testing
+type mockPromptStore struct {
+ prompts map[string]*promptstore.Prompt
+ listFn func(string, int) ([]promptstore.Prompt, string, error)
+ getFn func(string) (*promptstore.Prompt, error)
+}
+
+func (m *mockPromptStore) List(cursor string, limit int) ([]promptstore.Prompt, string, error) {
+ if m.listFn != nil {
+ return m.listFn(cursor, limit)
+ }
+ var all []promptstore.Prompt
+ for _, p := range m.prompts {
+ all = append(all, *p)
+ }
+ return all, "", nil
+}
+
+func (m *mockPromptStore) Get(name string) (*promptstore.Prompt, error) {
+ if m.getFn != nil {
+ return m.getFn(name)
+ }
+ p, ok := m.prompts[name]
+ if !ok {
+ return nil, fmt.Errorf("prompt not found: %s", name)
+ }
+ return p, nil
+}
+
+// Create, Update, Delete methods moved to handlers_test.go to avoid duplication
+
+func (m *mockPromptStore) SearchByTags(tags []string) ([]promptstore.Prompt, error) {
+ return nil, fmt.Errorf("not implemented")
+}
+
+func createTestServer(t *testing.T, store promptstore.PromptStore) (*Server, *bytes.Buffer, *bytes.Buffer) {
+ t.Helper()
+ inBuf := &bytes.Buffer{}
+ outBuf := &bytes.Buffer{}
+ logger := log.New(io.Discard, "", 0)
+ return NewServer(inBuf, outBuf, logger, store), inBuf, outBuf
+}
+
+func sendRequest(w io.Writer, req Request) error {
+ data, err := json.Marshal(req)
+ if err != nil {
+ return err
+ }
+ header := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(data))
+ if _, err := io.WriteString(w, header); err != nil {
+ return err
+ }
+ if _, err := w.Write(data); err != nil {
+ return err
+ }
+ return nil
+}
+
+func readResponse(r io.Reader) (*Response, error) {
+ // Simple read for testing (assumes one message in buffer)
+ data, err := io.ReadAll(r)
+ if err != nil {
+ return nil, err
+ }
+
+ // Find Content-Length
+ lines := strings.Split(string(data), "\r\n")
+ var contentLength int
+ bodyStart := 0
+ for i, line := range lines {
+ if strings.HasPrefix(line, "Content-Length:") {
+ fmt.Sscanf(line, "Content-Length: %d", &contentLength)
+ }
+ if line == "" {
+ bodyStart = i + 1
+ break
+ }
+ }
+
+ body := strings.Join(lines[bodyStart:], "\r\n")
+ var resp Response
+ if err := json.Unmarshal([]byte(body), &resp); err != nil {
+ return nil, fmt.Errorf("unmarshal response: %w, body: %s", err, body)
+ }
+ return &resp, nil
+}
+
+func TestServer_Initialize(t *testing.T) {
+ store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
+ server, inBuf, outBuf := createTestServer(t, store)
+
+ // Send initialize request
+ req := Request{
+ JSONRPC: "2.0",
+ ID: 1,
+ Method: "initialize",
+ }
+ params := InitializeRequest{
+ ProtocolVersion: "2025-06-18",
+ Capabilities: ClientCapabilities{},
+ ClientInfo: ClientInfo{
+ Name: "test-client",
+ Version: "1.0",
+ },
+ }
+ req.Params, _ = json.Marshal(params)
+
+ if err := sendRequest(inBuf, req); err != nil {
+ t.Fatalf("sendRequest() error = %v", err)
+ }
+
+ // Handle request
+ server.handle(req)
+
+ // Read response
+ resp, err := readResponse(outBuf)
+ if err != nil {
+ t.Fatalf("readResponse() error = %v", err)
+ }
+
+ if resp.JSONRPC != "2.0" {
+ t.Errorf("JSONRPC = %v, want 2.0", resp.JSONRPC)
+ }
+ // ID comparison: JSON unmarshaling may convert int to float64
+ if fmt.Sprintf("%v", resp.ID) != "1" {
+ t.Errorf("ID = %v (type %T), want 1", resp.ID, resp.ID)
+ }
+ if resp.Error != nil {
+ t.Errorf("Error = %v, want nil", resp.Error)
+ }
+
+ // Verify server is initialized
+ server.mu.RLock()
+ initialized := server.initialized
+ server.mu.RUnlock()
+ if !initialized {
+ t.Error("Server not initialized")
+ }
+}
+
+func TestServer_PromptsList(t *testing.T) {
+ now := time.Now()
+ store := &mockPromptStore{
+ prompts: map[string]*promptstore.Prompt{
+ "test1": {
+ Name: "test1",
+ Title: "Test Prompt 1",
+ Description: "First test prompt",
+ Messages: []promptstore.PromptMessage{},
+ Created: now,
+ Updated: now,
+ },
+ "test2": {
+ Name: "test2",
+ Title: "Test Prompt 2",
+ Description: "Second test prompt",
+ Messages: []promptstore.PromptMessage{},
+ Created: now,
+ Updated: now,
+ },
+ },
+ }
+
+ server, _, outBuf := createTestServer(t, store)
+
+ // Initialize server first
+ server.mu.Lock()
+ server.initialized = true
+ server.mu.Unlock()
+
+ // Send prompts/list request
+ req := Request{
+ JSONRPC: "2.0",
+ ID: 2,
+ Method: "prompts/list",
+ Params: json.RawMessage(`{}`),
+ }
+
+ server.handle(req)
+
+ // Read response
+ resp, err := readResponse(outBuf)
+ if err != nil {
+ t.Fatalf("readResponse() error = %v", err)
+ }
+
+ if resp.Error != nil {
+ t.Fatalf("Error = %v, want nil", resp.Error)
+ }
+
+ // Parse result
+ resultBytes, _ := json.Marshal(resp.Result)
+ var result ListPromptsResult
+ if err := json.Unmarshal(resultBytes, &result); err != nil {
+ t.Fatalf("Unmarshal result error = %v", err)
+ }
+
+ if len(result.Prompts) != 2 {
+ t.Errorf("Prompts count = %d, want 2", len(result.Prompts))
+ }
+}
+
+func TestServer_PromptsGet(t *testing.T) {
+ now := time.Now()
+ store := &mockPromptStore{
+ prompts: map[string]*promptstore.Prompt{
+ "test_prompt": {
+ Name: "test_prompt",
+ Title: "Test Prompt",
+ Description: "A test prompt",
+ Arguments: []promptstore.PromptArgument{
+ {Name: "name", Description: "User name", Required: true},
+ },
+ Messages: []promptstore.PromptMessage{
+ {
+ Role: "user",
+ Content: promptstore.MessageContent{
+ Type: "text",
+ Text: "Hello {{name}}!",
+ },
+ },
+ },
+ Created: now,
+ Updated: now,
+ },
+ },
+ }
+
+ server, _, outBuf := createTestServer(t, store)
+
+ // Initialize server
+ server.mu.Lock()
+ server.initialized = true
+ server.mu.Unlock()
+
+ // Send prompts/get request
+ params := GetPromptRequest{
+ Name: "test_prompt",
+ Arguments: map[string]string{
+ "name": "World",
+ },
+ }
+ paramsBytes, _ := json.Marshal(params)
+ req := Request{
+ JSONRPC: "2.0",
+ ID: 3,
+ Method: "prompts/get",
+ Params: paramsBytes,
+ }
+
+ server.handle(req)
+
+ // Read response
+ resp, err := readResponse(outBuf)
+ if err != nil {
+ t.Fatalf("readResponse() error = %v", err)
+ }
+
+ if resp.Error != nil {
+ t.Fatalf("Error = %v, want nil", resp.Error)
+ }
+
+ // Parse result
+ resultBytes, _ := json.Marshal(resp.Result)
+ var result GetPromptResult
+ if err := json.Unmarshal(resultBytes, &result); err != nil {
+ t.Fatalf("Unmarshal result error = %v", err)
+ }
+
+ if len(result.Messages) != 1 {
+ t.Errorf("Messages count = %d, want 1", len(result.Messages))
+ }
+ if result.Messages[0].Content.Text != "Hello World!" {
+ t.Errorf("Message text = %v, want 'Hello World!'", result.Messages[0].Content.Text)
+ }
+}
+
+func TestServer_PromptsGet_MissingArg(t *testing.T) {
+ now := time.Now()
+ store := &mockPromptStore{
+ prompts: map[string]*promptstore.Prompt{
+ "test_prompt": {
+ Name: "test_prompt",
+ Title: "Test Prompt",
+ Description: "A test prompt",
+ Arguments: []promptstore.PromptArgument{
+ {Name: "required_arg", Description: "Required", Required: true},
+ },
+ Messages: []promptstore.PromptMessage{
+ {
+ Role: "user",
+ Content: promptstore.MessageContent{
+ Type: "text",
+ Text: "Test: {{required_arg}}",
+ },
+ },
+ },
+ Created: now,
+ Updated: now,
+ },
+ },
+ }
+
+ server, _, outBuf := createTestServer(t, store)
+
+ // Initialize server
+ server.mu.Lock()
+ server.initialized = true
+ server.mu.Unlock()
+
+ // Send prompts/get request without required argument
+ params := GetPromptRequest{
+ Name: "test_prompt",
+ Arguments: map[string]string{},
+ }
+ paramsBytes, _ := json.Marshal(params)
+ req := Request{
+ JSONRPC: "2.0",
+ ID: 4,
+ Method: "prompts/get",
+ Params: paramsBytes,
+ }
+
+ server.handle(req)
+
+ // Read response
+ resp, err := readResponse(outBuf)
+ if err != nil {
+ t.Fatalf("readResponse() error = %v", err)
+ }
+
+ if resp.Error == nil {
+ t.Fatal("Expected error for missing required argument")
+ }
+ if !strings.Contains(resp.Error.Message, "required_arg") {
+ t.Errorf("Error message = %v, want to contain 'required_arg'", resp.Error.Message)
+ }
+}
+
+func TestServer_MethodNotFound(t *testing.T) {
+ store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
+ server, _, outBuf := createTestServer(t, store)
+
+ // Send request with unknown method
+ req := Request{
+ JSONRPC: "2.0",
+ ID: 5,
+ Method: "unknown/method",
+ }
+
+ server.handle(req)
+
+ // Read response
+ resp, err := readResponse(outBuf)
+ if err != nil {
+ t.Fatalf("readResponse() error = %v", err)
+ }
+
+ if resp.Error == nil {
+ t.Fatal("Expected error for unknown method")
+ }
+ if resp.Error.Code != ErrCodeMethodNotFound {
+ t.Errorf("Error code = %d, want %d", resp.Error.Code, ErrCodeMethodNotFound)
+ }
+}
+
+func TestServer_Run(t *testing.T) {
+ t.Run("exits on EOF", func(t *testing.T) {
+ store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
+ // Empty reader will cause immediate EOF
+ inBuf := &bytes.Buffer{}
+ outBuf := &bytes.Buffer{}
+ logger := log.New(io.Discard, "", 0)
+ server := NewServer(inBuf, outBuf, logger, store)
+
+ err := server.Run()
+ if err != nil {
+ t.Errorf("Run() error = %v, want nil on EOF", err)
+ }
+ })
+
+ t.Run("processes initialize request", func(t *testing.T) {
+ store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
+ inBuf := &bytes.Buffer{}
+ outBuf := &bytes.Buffer{}
+ logger := log.New(io.Discard, "", 0)
+ server := NewServer(inBuf, outBuf, logger, store)
+
+ // Send initialize request
+ req := Request{
+ JSONRPC: "2.0",
+ ID: 1,
+ Method: "initialize",
+ }
+ params := InitializeRequest{
+ ProtocolVersion: "2025-06-18",
+ Capabilities: ClientCapabilities{},
+ ClientInfo: ClientInfo{
+ Name: "test-client",
+ Version: "1.0",
+ },
+ }
+ req.Params, _ = json.Marshal(params)
+
+ if err := sendRequest(inBuf, req); err != nil {
+ t.Fatalf("sendRequest() error = %v", err)
+ }
+
+ // Run in background
+ done := make(chan error, 1)
+ go func() {
+ done <- server.Run()
+ }()
+
+ // Give time for processing (server will block waiting for more input)
+ time.Sleep(50 * time.Millisecond)
+
+ // Read response
+ resp, err := readResponse(outBuf)
+ if err != nil {
+ t.Fatalf("readResponse() error = %v", err)
+ }
+
+ if resp.Error != nil {
+ t.Errorf("Error = %v, want nil", resp.Error)
+ }
+
+ // Verify server is initialized
+ server.mu.RLock()
+ initialized := server.initialized
+ server.mu.RUnlock()
+ if !initialized {
+ t.Error("Server not initialized")
+ }
+ })
+}
+
+func TestServer_ReadMessage(t *testing.T) {
+ t.Run("reads valid message", func(t *testing.T) {
+ store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
+ inBuf := &bytes.Buffer{}
+ outBuf := &bytes.Buffer{}
+ logger := log.New(io.Discard, "", 0)
+ server := NewServer(inBuf, outBuf, logger, store)
+
+ // Write a message with proper framing
+ msg := []byte(`{"jsonrpc":"2.0","id":1,"method":"test"}`)
+ header := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(msg))
+ inBuf.WriteString(header)
+ inBuf.Write(msg)
+
+ // Read it back
+ body, err := server.readMessage()
+ if err != nil {
+ t.Fatalf("readMessage() error = %v", err)
+ }
+
+ if string(body) != string(msg) {
+ t.Errorf("readMessage() = %s, want %s", body, msg)
+ }
+ })
+
+ t.Run("returns EOF on empty stream", func(t *testing.T) {
+ store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
+ inBuf := &bytes.Buffer{}
+ outBuf := &bytes.Buffer{}
+ logger := log.New(io.Discard, "", 0)
+ server := NewServer(inBuf, outBuf, logger, store)
+
+ _, err := server.readMessage()
+ if err != io.EOF {
+ t.Errorf("readMessage() error = %v, want EOF", err)
+ }
+ })
+}
+
+func TestServer_HandleInitialized(t *testing.T) {
+ store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)}
+ server, _, _ := createTestServer(t, store)
+
+ // Send initialized notification
+ req := Request{
+ JSONRPC: "2.0",
+ Method: "initialized",
+ }
+
+ // This should not error (it's a notification, no response expected)
+ server.handleInitialized(req)
+}