diff options
Diffstat (limited to 'internal/mcp')
| -rw-r--r-- | internal/mcp/handlers_test.go | 955 | ||||
| -rw-r--r-- | internal/mcp/server.go | 494 | ||||
| -rw-r--r-- | internal/mcp/server_test.go | 505 | ||||
| -rw-r--r-- | internal/mcp/transport.go | 69 | ||||
| -rw-r--r-- | internal/mcp/types.go | 187 |
5 files changed, 2210 insertions, 0 deletions
diff --git a/internal/mcp/handlers_test.go b/internal/mcp/handlers_test.go new file mode 100644 index 0000000..79c567e --- /dev/null +++ b/internal/mcp/handlers_test.go @@ -0,0 +1,955 @@ +// Summary: Tests for MCP prompt management handlers +package mcp + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "log" + "testing" + "time" + + "codeberg.org/snonux/hexai/internal/promptstore" +) + +func TestServer_PromptsCreate(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + server, _, outBuf := createTestServer(t, store) + + // Initialize server + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send prompts/create request + params := CreatePromptRequest{ + Name: "test_create", + Title: "Test Create Prompt", + Description: "A test prompt", + Arguments: []PromptArgument{ + {Name: "input", Description: "Test input", Required: true}, + }, + Messages: []PromptMessage{ + { + Role: "user", + Content: MessageContent{ + Type: "text", + Text: "Test: {{input}}", + }, + }, + }, + Tags: []string{"test"}, + } + paramsBytes, _ := json.Marshal(params) + req := Request{ + JSONRPC: "2.0", + ID: 10, + Method: "prompts/create", + Params: paramsBytes, + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error != nil { + t.Fatalf("Error = %v, want nil", resp.Error) + } + + // Parse result + resultBytes, _ := json.Marshal(resp.Result) + var result PromptOperationResult + if err := json.Unmarshal(resultBytes, &result); err != nil { + t.Fatalf("Unmarshal result error = %v", err) + } + + if !result.Success { + t.Errorf("Success = false, want true") + } +} + +func TestServer_PromptsUpdate(t *testing.T) { + now := time.Now() + store := &mockPromptStore{ + prompts: map[string]*promptstore.Prompt{ + "test_update": { + Name: "test_update", + Title: "Original Title", + Created: now, + Updated: now, + Messages: []promptstore.PromptMessage{ + { + Role: "user", + Content: promptstore.MessageContent{ + Type: "text", + Text: "Original text", + }, + }, + }, + }, + }, + } + + server, _, outBuf := createTestServer(t, store) + + // Initialize server + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send prompts/update request + params := UpdatePromptRequest{ + Name: "test_update", + Title: "Updated Title", + } + paramsBytes, _ := json.Marshal(params) + req := Request{ + JSONRPC: "2.0", + ID: 11, + Method: "prompts/update", + Params: paramsBytes, + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error != nil { + t.Fatalf("Error = %v, want nil", resp.Error) + } + + // Parse result + resultBytes, _ := json.Marshal(resp.Result) + var result PromptOperationResult + if err := json.Unmarshal(resultBytes, &result); err != nil { + t.Fatalf("Unmarshal result error = %v", err) + } + + if !result.Success { + t.Errorf("Success = false, want true") + } +} + +func TestServer_PromptsDelete(t *testing.T) { + now := time.Now() + store := &mockPromptStore{ + prompts: map[string]*promptstore.Prompt{ + "test_delete": { + Name: "test_delete", + Title: "To Be Deleted", + Created: now, + Updated: now, + Messages: []promptstore.PromptMessage{}, + }, + }, + } + + server, _, outBuf := createTestServer(t, store) + + // Initialize server + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send prompts/delete request + params := DeletePromptRequest{ + Name: "test_delete", + } + paramsBytes, _ := json.Marshal(params) + req := Request{ + JSONRPC: "2.0", + ID: 12, + Method: "prompts/delete", + Params: paramsBytes, + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error != nil { + t.Fatalf("Error = %v, want nil", resp.Error) + } + + // Parse result + resultBytes, _ := json.Marshal(resp.Result) + var result PromptOperationResult + if err := json.Unmarshal(resultBytes, &result); err != nil { + t.Fatalf("Unmarshal result error = %v", err) + } + + if !result.Success { + t.Errorf("Success = false, want true") + } +} + +func TestServer_PromptsCreate_MissingName(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + server, _, outBuf := createTestServer(t, store) + + // Initialize server + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send prompts/create request without name + params := CreatePromptRequest{ + Title: "No Name", + Messages: []PromptMessage{ + {Role: "user", Content: MessageContent{Type: "text", Text: "Test"}}, + }, + } + paramsBytes, _ := json.Marshal(params) + req := Request{ + JSONRPC: "2.0", + ID: 13, + Method: "prompts/create", + Params: paramsBytes, + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error == nil { + t.Fatal("Expected error for missing name") + } + + if resp.Error.Code != ErrCodeInvalidParams { + t.Errorf("Error code = %d, want %d", resp.Error.Code, ErrCodeInvalidParams) + } +} + +// Update mockPromptStore to support Create, Update, Delete +func (m *mockPromptStore) Create(prompt *promptstore.Prompt) error { + if _, exists := m.prompts[prompt.Name]; exists { + return fmt.Errorf("prompt already exists: %s", prompt.Name) + } + m.prompts[prompt.Name] = prompt + return nil +} + +func (m *mockPromptStore) Update(prompt *promptstore.Prompt) error { + if _, exists := m.prompts[prompt.Name]; !exists { + return fmt.Errorf("prompt not found: %s", prompt.Name) + } + m.prompts[prompt.Name] = prompt + return nil +} + +func (m *mockPromptStore) Delete(name string) error { + if _, exists := m.prompts[name]; !exists { + return fmt.Errorf("prompt not found: %s", name) + } + delete(m.prompts, name) + return nil +} + +func TestServer_PromptsUpdate_NotFound(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + server, _, outBuf := createTestServer(t, store) + + // Initialize server + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send prompts/update request for non-existent prompt + params := UpdatePromptRequest{ + Name: "nonexistent", + Title: "Updated Title", + } + paramsBytes, _ := json.Marshal(params) + req := Request{ + JSONRPC: "2.0", + ID: 20, + Method: "prompts/update", + Params: paramsBytes, + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error == nil { + t.Fatal("Expected error for non-existent prompt") + } +} + +func TestServer_PromptsUpdate_MissingName(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + server, _, outBuf := createTestServer(t, store) + + // Initialize server + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send prompts/update request without name + params := UpdatePromptRequest{ + Title: "Updated Title", + } + paramsBytes, _ := json.Marshal(params) + req := Request{ + JSONRPC: "2.0", + ID: 21, + Method: "prompts/update", + Params: paramsBytes, + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error == nil { + t.Fatal("Expected error for missing name") + } +} + +func TestServer_PromptsDelete_NotFound(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + server, _, outBuf := createTestServer(t, store) + + // Initialize server + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send prompts/delete request for non-existent prompt + params := DeletePromptRequest{ + Name: "nonexistent", + } + paramsBytes, _ := json.Marshal(params) + req := Request{ + JSONRPC: "2.0", + ID: 22, + Method: "prompts/delete", + Params: paramsBytes, + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error == nil { + t.Fatal("Expected error for non-existent prompt") + } +} + +func TestServer_PromptsDelete_MissingName(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + server, _, outBuf := createTestServer(t, store) + + // Initialize server + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send prompts/delete request without name + params := DeletePromptRequest{} + paramsBytes, _ := json.Marshal(params) + req := Request{ + JSONRPC: "2.0", + ID: 23, + Method: "prompts/delete", + Params: paramsBytes, + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error == nil { + t.Fatal("Expected error for missing name") + } +} + +func TestServer_PromptsCreate_AlreadyExists(t *testing.T) { + now := time.Now() + store := &mockPromptStore{ + prompts: map[string]*promptstore.Prompt{ + "existing": { + Name: "existing", + Title: "Existing Prompt", + Created: now, + Updated: now, + Messages: []promptstore.PromptMessage{}, + }, + }, + } + server, _, outBuf := createTestServer(t, store) + + // Initialize server + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send prompts/create request with existing name + params := CreatePromptRequest{ + Name: "existing", + Title: "Duplicate", + Messages: []PromptMessage{ + {Role: "user", Content: MessageContent{Type: "text", Text: "Test"}}, + }, + } + paramsBytes, _ := json.Marshal(params) + req := Request{ + JSONRPC: "2.0", + ID: 24, + Method: "prompts/create", + Params: paramsBytes, + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error == nil { + t.Fatal("Expected error for duplicate prompt") + } +} + +func TestServer_PromptsGet_NotFound(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + server, _, outBuf := createTestServer(t, store) + + // Initialize server + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send prompts/get request for non-existent prompt + params := GetPromptRequest{ + Name: "nonexistent", + Arguments: map[string]string{}, + } + paramsBytes, _ := json.Marshal(params) + req := Request{ + JSONRPC: "2.0", + ID: 25, + Method: "prompts/get", + Params: paramsBytes, + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error == nil { + t.Fatal("Expected error for non-existent prompt") + } +} + +func TestServer_PromptsList_WithError(t *testing.T) { + store := &mockPromptStore{ + listFn: func(cursor string, limit int) ([]promptstore.Prompt, string, error) { + return nil, "", fmt.Errorf("store error") + }, + } + server, _, outBuf := createTestServer(t, store) + + // Initialize server + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send prompts/list request + req := Request{ + JSONRPC: "2.0", + ID: 26, + Method: "prompts/list", + Params: json.RawMessage(`{}`), + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error == nil { + t.Fatal("Expected error from store") + } +} + +func TestServer_PromptsUpdate_WithMessages(t *testing.T) { + now := time.Now() + store := &mockPromptStore{ + prompts: map[string]*promptstore.Prompt{ + "test": { + Name: "test", + Title: "Original", + Created: now, + Updated: now, + Messages: []promptstore.PromptMessage{ + { + Role: "user", + Content: promptstore.MessageContent{ + Type: "text", + Text: "Original", + }, + }, + }, + }, + }, + } + + server, _, outBuf := createTestServer(t, store) + + // Initialize server + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send prompts/update request with new messages + params := UpdatePromptRequest{ + Name: "test", + Title: "Updated", + Messages: []PromptMessage{ + { + Role: "user", + Content: MessageContent{ + Type: "text", + Text: "Updated message", + }, + }, + }, + } + paramsBytes, _ := json.Marshal(params) + req := Request{ + JSONRPC: "2.0", + ID: 27, + Method: "prompts/update", + Params: paramsBytes, + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error != nil { + t.Fatalf("Error = %v, want nil", resp.Error) + } +} + +func TestServer_PromptsCreate_WithAllFields(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + server, _, outBuf := createTestServer(t, store) + + // Initialize server + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send prompts/create request with all fields + params := CreatePromptRequest{ + Name: "complete", + Title: "Complete Prompt", + Description: "Full description", + Arguments: []PromptArgument{ + {Name: "arg1", Description: "First arg", Required: true}, + {Name: "arg2", Description: "Second arg", Required: false}, + }, + Messages: []PromptMessage{ + { + Role: "user", + Content: MessageContent{ + Type: "text", + Text: "Test {{arg1}} and {{arg2}}", + }, + }, + { + Role: "assistant", + Content: MessageContent{ + Type: "text", + Text: "Response", + }, + }, + }, + Tags: []string{"tag1", "tag2"}, + } + paramsBytes, _ := json.Marshal(params) + req := Request{ + JSONRPC: "2.0", + ID: 28, + Method: "prompts/create", + Params: paramsBytes, + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error != nil { + t.Fatalf("Error = %v, want nil", resp.Error) + } + + // Parse result + resultBytes, _ := json.Marshal(resp.Result) + var result PromptOperationResult + if err := json.Unmarshal(resultBytes, &result); err != nil { + t.Fatalf("Unmarshal result error = %v", err) + } + + if !result.Success { + t.Errorf("Success = false, want true") + } +} + +func TestServer_PromptsList_WithCursorAndLimit(t *testing.T) { + now := time.Now() + store := &mockPromptStore{ + prompts: map[string]*promptstore.Prompt{ + "test1": { + Name: "test1", + Title: "Test 1", + Created: now, + Updated: now, + }, + "test2": { + Name: "test2", + Title: "Test 2", + Created: now, + Updated: now, + }, + }, + } + + server, _, outBuf := createTestServer(t, store) + + // Initialize server + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send prompts/list request with cursor and limit + params := map[string]interface{}{ + "cursor": "test1", + "limit": 10, + } + paramsBytes, _ := json.Marshal(params) + req := Request{ + JSONRPC: "2.0", + ID: 30, + Method: "prompts/list", + Params: paramsBytes, + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error != nil { + t.Fatalf("Error = %v, want nil", resp.Error) + } +} + +func TestServer_PromptsUpdate_WithDescription(t *testing.T) { + now := time.Now() + store := &mockPromptStore{ + prompts: map[string]*promptstore.Prompt{ + "test": { + Name: "test", + Title: "Original", + Created: now, + Updated: now, + Messages: []promptstore.PromptMessage{ + { + Role: "user", + Content: promptstore.MessageContent{ + Type: "text", + Text: "Original", + }, + }, + }, + }, + }, + } + + server, _, outBuf := createTestServer(t, store) + + // Initialize server + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send prompts/update request with description + params := UpdatePromptRequest{ + Name: "test", + Description: "Updated description", + } + paramsBytes, _ := json.Marshal(params) + req := Request{ + JSONRPC: "2.0", + ID: 31, + Method: "prompts/update", + Params: paramsBytes, + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error != nil { + t.Fatalf("Error = %v, want nil", resp.Error) + } +} + +func TestServer_PromptsUpdate_WithArguments(t *testing.T) { + now := time.Now() + store := &mockPromptStore{ + prompts: map[string]*promptstore.Prompt{ + "test": { + Name: "test", + Title: "Original", + Created: now, + Updated: now, + Messages: []promptstore.PromptMessage{ + { + Role: "user", + Content: promptstore.MessageContent{ + Type: "text", + Text: "Original", + }, + }, + }, + }, + }, + } + + server, _, outBuf := createTestServer(t, store) + + // Initialize server + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send prompts/update request with arguments + params := UpdatePromptRequest{ + Name: "test", + Arguments: []PromptArgument{ + {Name: "newarg", Description: "New argument", Required: true}, + }, + } + paramsBytes, _ := json.Marshal(params) + req := Request{ + JSONRPC: "2.0", + ID: 32, + Method: "prompts/update", + Params: paramsBytes, + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error != nil { + t.Fatalf("Error = %v, want nil", resp.Error) + } +} + +func TestServer_PromptsUpdate_WithTags(t *testing.T) { + now := time.Now() + store := &mockPromptStore{ + prompts: map[string]*promptstore.Prompt{ + "test": { + Name: "test", + Title: "Original", + Created: now, + Updated: now, + Messages: []promptstore.PromptMessage{ + { + Role: "user", + Content: promptstore.MessageContent{ + Type: "text", + Text: "Original", + }, + }, + }, + }, + }, + } + + server, _, outBuf := createTestServer(t, store) + + // Initialize server + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send prompts/update request with tags + params := UpdatePromptRequest{ + Name: "test", + Tags: []string{"newtag1", "newtag2"}, + } + paramsBytes, _ := json.Marshal(params) + req := Request{ + JSONRPC: "2.0", + ID: 33, + Method: "prompts/update", + Params: paramsBytes, + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error != nil { + t.Fatalf("Error = %v, want nil", resp.Error) + } +} + +func TestServer_Run_InvalidJSON(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + inBuf := &bytes.Buffer{} + outBuf := &bytes.Buffer{} + logger := log.New(io.Discard, "", 0) + server := NewServer(inBuf, outBuf, logger, store) + + // Write invalid JSON + msg := []byte(`{invalid json}`) + header := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(msg)) + inBuf.WriteString(header) + inBuf.Write(msg) + + // Run in background + done := make(chan error, 1) + go func() { + done <- server.Run() + }() + + // Give time for processing + time.Sleep(50 * time.Millisecond) + + // Should have written error response + if outBuf.Len() == 0 { + t.Error("Expected error response to be written") + } +} + +func TestServer_PromptsCreate_MissingMessages(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + server, _, outBuf := createTestServer(t, store) + + // Initialize server + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send prompts/create request without messages + params := CreatePromptRequest{ + Name: "test", + Title: "Test", + } + paramsBytes, _ := json.Marshal(params) + req := Request{ + JSONRPC: "2.0", + ID: 34, + Method: "prompts/create", + Params: paramsBytes, + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error == nil { + t.Fatal("Expected error for missing messages") + } +} + +func TestServer_HandleInitialize_InvalidParams(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + server, _, outBuf := createTestServer(t, store) + + // Send initialize request with invalid params + req := Request{ + JSONRPC: "2.0", + ID: 35, + Method: "initialize", + Params: json.RawMessage(`{invalid}`), + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error == nil { + t.Fatal("Expected error for invalid params") + } +} diff --git a/internal/mcp/server.go b/internal/mcp/server.go new file mode 100644 index 0000000..f6479d9 --- /dev/null +++ b/internal/mcp/server.go @@ -0,0 +1,494 @@ +// Summary: MCP server over stdio; manages prompt store, dispatches requests, and handles protocol. +package mcp + +import ( + "bufio" + "encoding/json" + "fmt" + "io" + "log" + "strings" + "sync" + "time" + + "codeberg.org/snonux/hexai/internal" + "codeberg.org/snonux/hexai/internal/promptstore" +) + +// Server implements an MCP server over stdio using JSON-RPC 2.0. +// Follows the same pattern as the LSP server with dispatch table and thread safety. +type Server struct { + in *bufio.Reader + out io.Writer + outMu sync.Mutex + logger *log.Logger + store promptstore.PromptStore + initialized bool + mu sync.RWMutex + + // Dispatch table for JSON-RPC methods + handlers map[string]func(Request) +} + +// NewServer creates a new MCP server with the given store and I/O streams. +// The store provides access to prompts; logger is used for debugging. +func NewServer(r io.Reader, w io.Writer, logger *log.Logger, store promptstore.PromptStore) *Server { + s := &Server{ + in: bufio.NewReader(r), + out: w, + logger: logger, + store: store, + } + + // Initialize dispatch table + s.handlers = map[string]func(Request){ + "initialize": s.handleInitialize, + "initialized": s.handleInitialized, + "prompts/list": s.handlePromptsList, + "prompts/get": s.handlePromptsGet, + "prompts/create": s.handlePromptsCreate, + "prompts/update": s.handlePromptsUpdate, + "prompts/delete": s.handlePromptsDelete, + "notifications/initialized": s.handleInitialized, + } + + return s +} + +// Run starts the server main loop, reading and dispatching requests. +// Returns on EOF or fatal error. +func (s *Server) Run() error { + for { + body, err := s.readMessage() + if err == io.EOF { + return nil + } + if err != nil { + return fmt.Errorf("read message: %w", err) + } + + var req Request + if err := json.Unmarshal(body, &req); err != nil { + s.logger.Printf("invalid JSON: %v", err) + s.sendError(nil, ErrCodeParseError, "Parse error") + continue + } + + if req.Method == "" { + // Response from client; ignore + continue + } + + // Dispatch request + go s.handle(req) + } +} + +// handle dispatches a request to the appropriate handler. +func (s *Server) handle(req Request) { + handler, ok := s.handlers[req.Method] + if !ok { + s.logger.Printf("method not found: %s", req.Method) + s.sendError(req.ID, ErrCodeMethodNotFound, fmt.Sprintf("Method not found: %s", req.Method)) + return + } + + handler(req) +} + +// handleInitialize processes the initialize request and returns server capabilities. +func (s *Server) handleInitialize(req Request) { + var params InitializeRequest + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + s.sendError(req.ID, ErrCodeInvalidParams, "Invalid initialize params") + return + } + + s.logger.Printf("initialize from client: %s %s (protocol: %s)", + params.ClientInfo.Name, params.ClientInfo.Version, params.ProtocolVersion) + + // Validate protocol version (accept both old and new versions for compatibility) + if params.ProtocolVersion != "2024-11-05" && params.ProtocolVersion != "2025-06-18" { + s.logger.Printf("warning: unsupported protocol version: %s", params.ProtocolVersion) + } + + result := InitializeResult{ + ProtocolVersion: "2025-06-18", + Capabilities: ServerCapabilities{ + Prompts: &PromptsCapability{ + ListChanged: false, + Mutable: true, // Advertise that we support create/update/delete + }, + }, + ServerInfo: ServerInfo{ + Name: "hexai-mcp-server", + Version: internal.Version, + }, + } + + s.mu.Lock() + s.initialized = true + s.mu.Unlock() + + s.sendResponse(req.ID, result) +} + +// handleInitialized processes the initialized notification. +// This is sent by the client after receiving initialize response. +func (s *Server) handleInitialized(_ Request) { + s.logger.Printf("client sent initialized notification") + // No response required for notifications +} + +// handlePromptsList processes the prompts/list request. +func (s *Server) handlePromptsList(req Request) { + s.mu.RLock() + if !s.initialized { + s.mu.RUnlock() + s.sendError(req.ID, ErrCodeInvalidRequest, "Server not initialized") + return + } + s.mu.RUnlock() + + var params ListPromptsRequest + if req.Params != nil && len(req.Params) > 0 { + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + s.sendError(req.ID, ErrCodeInvalidParams, "Invalid prompts/list params") + return + } + } + + // List prompts from store + prompts, nextCursor, err := s.store.List(params.Cursor, 100) + if err != nil { + s.logger.Printf("store list error: %v", err) + s.sendError(req.ID, ErrCodeInternalError, "Failed to list prompts") + return + } + + // Convert to PromptInfo + var infos []PromptInfo + for _, p := range prompts { + args := make([]PromptArgument, len(p.Arguments)) + for i, a := range p.Arguments { + args[i] = PromptArgument{ + Name: a.Name, + Description: a.Description, + Required: a.Required, + } + } + infos = append(infos, PromptInfo{ + Name: p.Name, + Title: p.Title, + Description: p.Description, + Arguments: args, + }) + } + + result := ListPromptsResult{ + Prompts: infos, + NextCursor: nextCursor, + } + + s.sendResponse(req.ID, result) +} + +// handlePromptsGet processes the prompts/get request. +func (s *Server) handlePromptsGet(req Request) { + s.mu.RLock() + if !s.initialized { + s.mu.RUnlock() + s.sendError(req.ID, ErrCodeInvalidRequest, "Server not initialized") + return + } + s.mu.RUnlock() + + var params GetPromptRequest + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + s.sendError(req.ID, ErrCodeInvalidParams, "Invalid prompts/get params") + return + } + + if params.Name == "" { + s.sendError(req.ID, ErrCodeInvalidParams, "Missing prompt name") + return + } + + // Get prompt from store + prompt, err := s.store.Get(params.Name) + if err != nil { + s.logger.Printf("store get error: %v", err) + s.sendError(req.ID, ErrCodeInvalidParams, fmt.Sprintf("Prompt not found: %s", params.Name)) + return + } + + // Render prompt with arguments + messages, err := s.renderPrompt(prompt, params.Arguments) + if err != nil { + s.logger.Printf("render error: %v", err) + s.sendError(req.ID, ErrCodeInvalidParams, err.Error()) + return + } + + result := GetPromptResult{ + Description: prompt.Description, + Messages: messages, + } + + s.sendResponse(req.ID, result) +} + +// renderPrompt substitutes template arguments in prompt messages. +// Returns error if required arguments are missing. +func (s *Server) renderPrompt(prompt *promptstore.Prompt, args map[string]string) ([]PromptMessage, error) { + // Validate required arguments + for _, arg := range prompt.Arguments { + if arg.Required { + if _, ok := args[arg.Name]; !ok { + return nil, fmt.Errorf("missing required argument: %s", arg.Name) + } + } + } + + // Render each message + var rendered []PromptMessage + for _, msg := range prompt.Messages { + text := msg.Content.Text + + // Simple template substitution: {{arg}} -> value + for key, val := range args { + placeholder := "{{" + key + "}}" + text = strings.ReplaceAll(text, placeholder, val) + } + + rendered = append(rendered, PromptMessage{ + Role: msg.Role, + Content: MessageContent{ + Type: msg.Content.Type, + Text: text, + }, + }) + } + + return rendered, nil +} + +// sendResponse sends a successful JSON-RPC response. +func (s *Server) sendResponse(id any, result any) { + resp := Response{ + JSONRPC: "2.0", + ID: id, + Result: result, + } + if err := s.writeMessage(resp); err != nil { + s.logger.Printf("write response error: %v", err) + } +} + +// sendError sends an error JSON-RPC response. +func (s *Server) sendError(id any, code int, message string) { + resp := Response{ + JSONRPC: "2.0", + ID: id, + Error: &RespError{ + Code: code, + Message: message, + }, + } + if err := s.writeMessage(resp); err != nil { + s.logger.Printf("write error response error: %v", err) + } +} + +// handlePromptsCreate processes the prompts/create request. +// Creates a new custom prompt in user.jsonl. +func (s *Server) handlePromptsCreate(req Request) { + s.mu.RLock() + if !s.initialized { + s.mu.RUnlock() + s.sendError(req.ID, ErrCodeInvalidRequest, "Server not initialized") + return + } + s.mu.RUnlock() + + var params CreatePromptRequest + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + s.sendError(req.ID, ErrCodeInvalidParams, "Invalid prompts/create params") + return + } + + // Validate required fields + if params.Name == "" { + s.sendError(req.ID, ErrCodeInvalidParams, "Prompt name is required") + return + } + if params.Title == "" { + s.sendError(req.ID, ErrCodeInvalidParams, "Prompt title is required") + return + } + if len(params.Messages) == 0 { + s.sendError(req.ID, ErrCodeInvalidParams, "At least one message is required") + return + } + + // Create prompt + prompt := &promptstore.Prompt{ + Name: params.Name, + Title: params.Title, + Description: params.Description, + Tags: params.Tags, + Created: time.Now(), + Updated: time.Now(), + } + + // Convert arguments + for _, arg := range params.Arguments { + prompt.Arguments = append(prompt.Arguments, promptstore.PromptArgument{ + Name: arg.Name, + Description: arg.Description, + Required: arg.Required, + }) + } + + // Convert messages + for _, msg := range params.Messages { + prompt.Messages = append(prompt.Messages, promptstore.PromptMessage{ + Role: msg.Role, + Content: promptstore.MessageContent{ + Type: msg.Content.Type, + Text: msg.Content.Text, + }, + }) + } + + if err := s.store.Create(prompt); err != nil { + s.logger.Printf("create prompt error: %v", err) + s.sendError(req.ID, ErrCodeInternalError, fmt.Sprintf("Failed to create prompt: %v", err)) + return + } + + result := PromptOperationResult{ + Success: true, + Message: fmt.Sprintf("Created prompt: %s", params.Name), + } + + s.logger.Printf("created prompt: %s", params.Name) + s.sendResponse(req.ID, result) +} + +// handlePromptsUpdate processes the prompts/update request. +// Updates an existing custom prompt in user.jsonl. +func (s *Server) handlePromptsUpdate(req Request) { + s.mu.RLock() + if !s.initialized { + s.mu.RUnlock() + s.sendError(req.ID, ErrCodeInvalidRequest, "Server not initialized") + return + } + s.mu.RUnlock() + + var params UpdatePromptRequest + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + s.sendError(req.ID, ErrCodeInvalidParams, "Invalid prompts/update params") + return + } + + if params.Name == "" { + s.sendError(req.ID, ErrCodeInvalidParams, "Prompt name is required") + return + } + + // Get existing prompt + existing, err := s.store.Get(params.Name) + if err != nil { + s.logger.Printf("get prompt error: %v", err) + s.sendError(req.ID, ErrCodeInvalidParams, fmt.Sprintf("Prompt not found: %s", params.Name)) + return + } + + // Update fields (only if provided) + if params.Title != "" { + existing.Title = params.Title + } + if params.Description != "" { + existing.Description = params.Description + } + if len(params.Arguments) > 0 { + existing.Arguments = nil + for _, arg := range params.Arguments { + existing.Arguments = append(existing.Arguments, promptstore.PromptArgument{ + Name: arg.Name, + Description: arg.Description, + Required: arg.Required, + }) + } + } + if len(params.Messages) > 0 { + existing.Messages = nil + for _, msg := range params.Messages { + existing.Messages = append(existing.Messages, promptstore.PromptMessage{ + Role: msg.Role, + Content: promptstore.MessageContent{ + Type: msg.Content.Type, + Text: msg.Content.Text, + }, + }) + } + } + if len(params.Tags) > 0 { + existing.Tags = params.Tags + } + + existing.Updated = time.Now() + + if err := s.store.Update(existing); err != nil { + s.logger.Printf("update prompt error: %v", err) + s.sendError(req.ID, ErrCodeInternalError, fmt.Sprintf("Failed to update prompt: %v", err)) + return + } + + result := PromptOperationResult{ + Success: true, + Message: fmt.Sprintf("Updated prompt: %s", params.Name), + } + + s.logger.Printf("updated prompt: %s", params.Name) + s.sendResponse(req.ID, result) +} + +// handlePromptsDelete processes the prompts/delete request. +// Deletes a custom prompt from user.jsonl. +func (s *Server) handlePromptsDelete(req Request) { + s.mu.RLock() + if !s.initialized { + s.mu.RUnlock() + s.sendError(req.ID, ErrCodeInvalidRequest, "Server not initialized") + return + } + s.mu.RUnlock() + + var params DeletePromptRequest + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + s.sendError(req.ID, ErrCodeInvalidParams, "Invalid prompts/delete params") + return + } + + if params.Name == "" { + s.sendError(req.ID, ErrCodeInvalidParams, "Prompt name is required") + return + } + + if err := s.store.Delete(params.Name); err != nil { + s.logger.Printf("delete prompt error: %v", err) + s.sendError(req.ID, ErrCodeInternalError, fmt.Sprintf("Failed to delete prompt: %v", err)) + return + } + + result := PromptOperationResult{ + Success: true, + Message: fmt.Sprintf("Deleted prompt: %s", params.Name), + } + + s.logger.Printf("deleted prompt: %s", params.Name) + s.sendResponse(req.ID, result) +} diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go new file mode 100644 index 0000000..4a14ffc --- /dev/null +++ b/internal/mcp/server_test.go @@ -0,0 +1,505 @@ +// Summary: Tests for MCP server operations +package mcp + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "log" + "strings" + "testing" + "time" + + "codeberg.org/snonux/hexai/internal/promptstore" +) + +// mockPromptStore implements PromptStore for testing +type mockPromptStore struct { + prompts map[string]*promptstore.Prompt + listFn func(string, int) ([]promptstore.Prompt, string, error) + getFn func(string) (*promptstore.Prompt, error) +} + +func (m *mockPromptStore) List(cursor string, limit int) ([]promptstore.Prompt, string, error) { + if m.listFn != nil { + return m.listFn(cursor, limit) + } + var all []promptstore.Prompt + for _, p := range m.prompts { + all = append(all, *p) + } + return all, "", nil +} + +func (m *mockPromptStore) Get(name string) (*promptstore.Prompt, error) { + if m.getFn != nil { + return m.getFn(name) + } + p, ok := m.prompts[name] + if !ok { + return nil, fmt.Errorf("prompt not found: %s", name) + } + return p, nil +} + +// Create, Update, Delete methods moved to handlers_test.go to avoid duplication + +func (m *mockPromptStore) SearchByTags(tags []string) ([]promptstore.Prompt, error) { + return nil, fmt.Errorf("not implemented") +} + +func createTestServer(t *testing.T, store promptstore.PromptStore) (*Server, *bytes.Buffer, *bytes.Buffer) { + t.Helper() + inBuf := &bytes.Buffer{} + outBuf := &bytes.Buffer{} + logger := log.New(io.Discard, "", 0) + return NewServer(inBuf, outBuf, logger, store), inBuf, outBuf +} + +func sendRequest(w io.Writer, req Request) error { + data, err := json.Marshal(req) + if err != nil { + return err + } + header := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(data)) + if _, err := io.WriteString(w, header); err != nil { + return err + } + if _, err := w.Write(data); err != nil { + return err + } + return nil +} + +func readResponse(r io.Reader) (*Response, error) { + // Simple read for testing (assumes one message in buffer) + data, err := io.ReadAll(r) + if err != nil { + return nil, err + } + + // Find Content-Length + lines := strings.Split(string(data), "\r\n") + var contentLength int + bodyStart := 0 + for i, line := range lines { + if strings.HasPrefix(line, "Content-Length:") { + fmt.Sscanf(line, "Content-Length: %d", &contentLength) + } + if line == "" { + bodyStart = i + 1 + break + } + } + + body := strings.Join(lines[bodyStart:], "\r\n") + var resp Response + if err := json.Unmarshal([]byte(body), &resp); err != nil { + return nil, fmt.Errorf("unmarshal response: %w, body: %s", err, body) + } + return &resp, nil +} + +func TestServer_Initialize(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + server, inBuf, outBuf := createTestServer(t, store) + + // Send initialize request + req := Request{ + JSONRPC: "2.0", + ID: 1, + Method: "initialize", + } + params := InitializeRequest{ + ProtocolVersion: "2025-06-18", + Capabilities: ClientCapabilities{}, + ClientInfo: ClientInfo{ + Name: "test-client", + Version: "1.0", + }, + } + req.Params, _ = json.Marshal(params) + + if err := sendRequest(inBuf, req); err != nil { + t.Fatalf("sendRequest() error = %v", err) + } + + // Handle request + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.JSONRPC != "2.0" { + t.Errorf("JSONRPC = %v, want 2.0", resp.JSONRPC) + } + // ID comparison: JSON unmarshaling may convert int to float64 + if fmt.Sprintf("%v", resp.ID) != "1" { + t.Errorf("ID = %v (type %T), want 1", resp.ID, resp.ID) + } + if resp.Error != nil { + t.Errorf("Error = %v, want nil", resp.Error) + } + + // Verify server is initialized + server.mu.RLock() + initialized := server.initialized + server.mu.RUnlock() + if !initialized { + t.Error("Server not initialized") + } +} + +func TestServer_PromptsList(t *testing.T) { + now := time.Now() + store := &mockPromptStore{ + prompts: map[string]*promptstore.Prompt{ + "test1": { + Name: "test1", + Title: "Test Prompt 1", + Description: "First test prompt", + Messages: []promptstore.PromptMessage{}, + Created: now, + Updated: now, + }, + "test2": { + Name: "test2", + Title: "Test Prompt 2", + Description: "Second test prompt", + Messages: []promptstore.PromptMessage{}, + Created: now, + Updated: now, + }, + }, + } + + server, _, outBuf := createTestServer(t, store) + + // Initialize server first + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send prompts/list request + req := Request{ + JSONRPC: "2.0", + ID: 2, + Method: "prompts/list", + Params: json.RawMessage(`{}`), + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error != nil { + t.Fatalf("Error = %v, want nil", resp.Error) + } + + // Parse result + resultBytes, _ := json.Marshal(resp.Result) + var result ListPromptsResult + if err := json.Unmarshal(resultBytes, &result); err != nil { + t.Fatalf("Unmarshal result error = %v", err) + } + + if len(result.Prompts) != 2 { + t.Errorf("Prompts count = %d, want 2", len(result.Prompts)) + } +} + +func TestServer_PromptsGet(t *testing.T) { + now := time.Now() + store := &mockPromptStore{ + prompts: map[string]*promptstore.Prompt{ + "test_prompt": { + Name: "test_prompt", + Title: "Test Prompt", + Description: "A test prompt", + Arguments: []promptstore.PromptArgument{ + {Name: "name", Description: "User name", Required: true}, + }, + Messages: []promptstore.PromptMessage{ + { + Role: "user", + Content: promptstore.MessageContent{ + Type: "text", + Text: "Hello {{name}}!", + }, + }, + }, + Created: now, + Updated: now, + }, + }, + } + + server, _, outBuf := createTestServer(t, store) + + // Initialize server + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send prompts/get request + params := GetPromptRequest{ + Name: "test_prompt", + Arguments: map[string]string{ + "name": "World", + }, + } + paramsBytes, _ := json.Marshal(params) + req := Request{ + JSONRPC: "2.0", + ID: 3, + Method: "prompts/get", + Params: paramsBytes, + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error != nil { + t.Fatalf("Error = %v, want nil", resp.Error) + } + + // Parse result + resultBytes, _ := json.Marshal(resp.Result) + var result GetPromptResult + if err := json.Unmarshal(resultBytes, &result); err != nil { + t.Fatalf("Unmarshal result error = %v", err) + } + + if len(result.Messages) != 1 { + t.Errorf("Messages count = %d, want 1", len(result.Messages)) + } + if result.Messages[0].Content.Text != "Hello World!" { + t.Errorf("Message text = %v, want 'Hello World!'", result.Messages[0].Content.Text) + } +} + +func TestServer_PromptsGet_MissingArg(t *testing.T) { + now := time.Now() + store := &mockPromptStore{ + prompts: map[string]*promptstore.Prompt{ + "test_prompt": { + Name: "test_prompt", + Title: "Test Prompt", + Description: "A test prompt", + Arguments: []promptstore.PromptArgument{ + {Name: "required_arg", Description: "Required", Required: true}, + }, + Messages: []promptstore.PromptMessage{ + { + Role: "user", + Content: promptstore.MessageContent{ + Type: "text", + Text: "Test: {{required_arg}}", + }, + }, + }, + Created: now, + Updated: now, + }, + }, + } + + server, _, outBuf := createTestServer(t, store) + + // Initialize server + server.mu.Lock() + server.initialized = true + server.mu.Unlock() + + // Send prompts/get request without required argument + params := GetPromptRequest{ + Name: "test_prompt", + Arguments: map[string]string{}, + } + paramsBytes, _ := json.Marshal(params) + req := Request{ + JSONRPC: "2.0", + ID: 4, + Method: "prompts/get", + Params: paramsBytes, + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error == nil { + t.Fatal("Expected error for missing required argument") + } + if !strings.Contains(resp.Error.Message, "required_arg") { + t.Errorf("Error message = %v, want to contain 'required_arg'", resp.Error.Message) + } +} + +func TestServer_MethodNotFound(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + server, _, outBuf := createTestServer(t, store) + + // Send request with unknown method + req := Request{ + JSONRPC: "2.0", + ID: 5, + Method: "unknown/method", + } + + server.handle(req) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error == nil { + t.Fatal("Expected error for unknown method") + } + if resp.Error.Code != ErrCodeMethodNotFound { + t.Errorf("Error code = %d, want %d", resp.Error.Code, ErrCodeMethodNotFound) + } +} + +func TestServer_Run(t *testing.T) { + t.Run("exits on EOF", func(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + // Empty reader will cause immediate EOF + inBuf := &bytes.Buffer{} + outBuf := &bytes.Buffer{} + logger := log.New(io.Discard, "", 0) + server := NewServer(inBuf, outBuf, logger, store) + + err := server.Run() + if err != nil { + t.Errorf("Run() error = %v, want nil on EOF", err) + } + }) + + t.Run("processes initialize request", func(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + inBuf := &bytes.Buffer{} + outBuf := &bytes.Buffer{} + logger := log.New(io.Discard, "", 0) + server := NewServer(inBuf, outBuf, logger, store) + + // Send initialize request + req := Request{ + JSONRPC: "2.0", + ID: 1, + Method: "initialize", + } + params := InitializeRequest{ + ProtocolVersion: "2025-06-18", + Capabilities: ClientCapabilities{}, + ClientInfo: ClientInfo{ + Name: "test-client", + Version: "1.0", + }, + } + req.Params, _ = json.Marshal(params) + + if err := sendRequest(inBuf, req); err != nil { + t.Fatalf("sendRequest() error = %v", err) + } + + // Run in background + done := make(chan error, 1) + go func() { + done <- server.Run() + }() + + // Give time for processing (server will block waiting for more input) + time.Sleep(50 * time.Millisecond) + + // Read response + resp, err := readResponse(outBuf) + if err != nil { + t.Fatalf("readResponse() error = %v", err) + } + + if resp.Error != nil { + t.Errorf("Error = %v, want nil", resp.Error) + } + + // Verify server is initialized + server.mu.RLock() + initialized := server.initialized + server.mu.RUnlock() + if !initialized { + t.Error("Server not initialized") + } + }) +} + +func TestServer_ReadMessage(t *testing.T) { + t.Run("reads valid message", func(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + inBuf := &bytes.Buffer{} + outBuf := &bytes.Buffer{} + logger := log.New(io.Discard, "", 0) + server := NewServer(inBuf, outBuf, logger, store) + + // Write a message with proper framing + msg := []byte(`{"jsonrpc":"2.0","id":1,"method":"test"}`) + header := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(msg)) + inBuf.WriteString(header) + inBuf.Write(msg) + + // Read it back + body, err := server.readMessage() + if err != nil { + t.Fatalf("readMessage() error = %v", err) + } + + if string(body) != string(msg) { + t.Errorf("readMessage() = %s, want %s", body, msg) + } + }) + + t.Run("returns EOF on empty stream", func(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + inBuf := &bytes.Buffer{} + outBuf := &bytes.Buffer{} + logger := log.New(io.Discard, "", 0) + server := NewServer(inBuf, outBuf, logger, store) + + _, err := server.readMessage() + if err != io.EOF { + t.Errorf("readMessage() error = %v, want EOF", err) + } + }) +} + +func TestServer_HandleInitialized(t *testing.T) { + store := &mockPromptStore{prompts: make(map[string]*promptstore.Prompt)} + server, _, _ := createTestServer(t, store) + + // Send initialized notification + req := Request{ + JSONRPC: "2.0", + Method: "initialized", + } + + // This should not error (it's a notification, no response expected) + server.handleInitialized(req) +} diff --git a/internal/mcp/transport.go b/internal/mcp/transport.go new file mode 100644 index 0000000..aba416a --- /dev/null +++ b/internal/mcp/transport.go @@ -0,0 +1,69 @@ +// Summary: MCP transport utilities for reading and writing JSON-RPC messages with Content-Length framing. +package mcp + +import ( + "encoding/json" + "fmt" + "io" + "net/textproto" + "strconv" + "strings" +) + +// readMessage reads a Content-Length framed JSON-RPC message from the input stream. +// Returns the raw JSON bytes. Follows LSP/JSON-RPC framing convention. +func (s *Server) readMessage() ([]byte, error) { + tp := textproto.NewReader(s.in) + var contentLength int + for { + line, err := tp.ReadLine() + if err != nil { + return nil, err + } + if line == "" { // end of headers + break + } + parts := strings.SplitN(line, ":", 2) + if len(parts) != 2 { + continue + } + key := strings.TrimSpace(strings.ToLower(parts[0])) + val := strings.TrimSpace(parts[1]) + switch key { + case "content-length": + n, err := strconv.Atoi(val) + if err != nil { + return nil, fmt.Errorf("invalid Content-Length: %v", err) + } + contentLength = n + } + } + if contentLength <= 0 { + return nil, fmt.Errorf("missing or invalid Content-Length") + } + buf := make([]byte, contentLength) + if _, err := io.ReadFull(s.in, buf); err != nil { + return nil, err + } + return buf, nil +} + +// writeMessage writes a JSON-RPC response with Content-Length framing. +// Thread-safe via mutex lock. +func (s *Server) writeMessage(v any) error { + s.outMu.Lock() + defer s.outMu.Unlock() + + data, err := json.Marshal(v) + if err != nil { + return fmt.Errorf("marshal error: %w", err) + } + header := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(data)) + if _, err := io.WriteString(s.out, header); err != nil { + return fmt.Errorf("write header error: %w", err) + } + if _, err := s.out.Write(data); err != nil { + return fmt.Errorf("write body error: %w", err) + } + return nil +} diff --git a/internal/mcp/types.go b/internal/mcp/types.go new file mode 100644 index 0000000..e165972 --- /dev/null +++ b/internal/mcp/types.go @@ -0,0 +1,187 @@ +// Summary: MCP protocol types for JSON-RPC 2.0 messages and MCP-specific structures. +package mcp + +import "encoding/json" + +// Request represents an MCP JSON-RPC 2.0 request from the client. +// All MCP communication follows JSON-RPC 2.0 specification. +type Request struct { + JSONRPC string `json:"jsonrpc"` // Always "2.0" + ID any `json:"id"` // Request ID (string or number) + Method string `json:"method"` // Method name + Params json.RawMessage `json:"params,omitempty"` +} + +// Response represents an MCP JSON-RPC 2.0 response to the client. +// Contains either a result or an error, never both. +type Response struct { + JSONRPC string `json:"jsonrpc"` // Always "2.0" + ID any `json:"id"` // Matching request ID + Result any `json:"result,omitempty"` + Error *RespError `json:"error,omitempty"` +} + +// RespError represents a JSON-RPC error with code and message. +// Follows JSON-RPC 2.0 error object specification. +type RespError struct { + Code int `json:"code"` + Message string `json:"message"` + Data any `json:"data,omitempty"` +} + +// JSON-RPC error codes +const ( + ErrCodeParseError = -32700 // Invalid JSON + ErrCodeInvalidRequest = -32600 // Invalid Request structure + ErrCodeMethodNotFound = -32601 // Method doesn't exist + ErrCodeInvalidParams = -32602 // Invalid parameters + ErrCodeInternalError = -32603 // Server internal error +) + +// InitializeRequest is the first message from client to establish connection. +// Client sends capabilities and protocol version; server responds with its capabilities. +type InitializeRequest struct { + ProtocolVersion string `json:"protocolVersion"` // "2025-06-18" + Capabilities ClientCapabilities `json:"capabilities"` + ClientInfo ClientInfo `json:"clientInfo"` + Meta map[string]interface{} `json:"_meta,omitempty"` +} + +// ClientCapabilities describes what features the client supports. +// Used during handshake to negotiate supported features. +type ClientCapabilities struct { + Sampling map[string]interface{} `json:"sampling,omitempty"` +} + +// ClientInfo identifies the connecting client. +type ClientInfo struct { + Name string `json:"name"` + Version string `json:"version"` +} + +// InitializeResult is the server's response to initialize. +// Contains server capabilities and identification. +type InitializeResult struct { + ProtocolVersion string `json:"protocolVersion"` + Capabilities ServerCapabilities `json:"capabilities"` + ServerInfo ServerInfo `json:"serverInfo"` +} + +// ServerCapabilities describes what features this server supports. +// Currently only prompts are supported; tools/resources reserved for future. +type ServerCapabilities struct { + Prompts *PromptsCapability `json:"prompts,omitempty"` + Resources *ResourcesCapability `json:"resources,omitempty"` + Tools *ToolsCapability `json:"tools,omitempty"` +} + +// PromptsCapability indicates server supports prompt listing and retrieval. +type PromptsCapability struct { + ListChanged bool `json:"listChanged,omitempty"` // Server notifies when prompts change + Mutable bool `json:"mutable,omitempty"` // Server supports create/update/delete +} + +// ResourcesCapability indicates server supports resource listing and retrieval. +// Reserved for future use (runbooks, documentation files). +type ResourcesCapability struct { + Subscribe bool `json:"subscribe,omitempty"` + ListChanged bool `json:"listChanged,omitempty"` +} + +// ToolsCapability indicates server supports tool execution. +// Reserved for future use (script execution, git operations). +type ToolsCapability struct { + ListChanged bool `json:"listChanged,omitempty"` +} + +// ServerInfo identifies this server. +type ServerInfo struct { + Name string `json:"name"` + Version string `json:"version"` +} + +// ListPromptsRequest contains parameters for prompts/list method. +// Supports pagination via cursor-based iteration. +type ListPromptsRequest struct { + Cursor string `json:"cursor,omitempty"` // Pagination cursor (empty for first page) +} + +// ListPromptsResult contains paginated list of available prompts. +// Includes nextCursor for fetching additional pages. +type ListPromptsResult struct { + Prompts []PromptInfo `json:"prompts"` + NextCursor string `json:"nextCursor,omitempty"` +} + +// PromptInfo describes a prompt in the list (metadata only, no messages). +// Client uses this to display available prompts before fetching full content. +type PromptInfo struct { + Name string `json:"name"` // Unique identifier + Title string `json:"title,omitempty"` // Display name + Description string `json:"description,omitempty"` // Human-readable + Arguments []PromptArgument `json:"arguments,omitempty"` // Template variables +} + +// PromptArgument describes a template variable in a prompt. +type PromptArgument struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Required bool `json:"required,omitempty"` +} + +// GetPromptRequest contains parameters for prompts/get method. +// Includes argument values to render the prompt template. +type GetPromptRequest struct { + Name string `json:"name"` // Prompt identifier + Arguments map[string]string `json:"arguments,omitempty"` // Argument values +} + +// GetPromptResult contains a fully rendered prompt ready for use. +// Template variables have been substituted with provided arguments. +type GetPromptResult struct { + Description string `json:"description,omitempty"` + Messages []PromptMessage `json:"messages"` +} + +// PromptMessage represents a message in the rendered prompt. +type PromptMessage struct { + Role string `json:"role"` // "user" or "assistant" + Content MessageContent `json:"content"` +} + +// MessageContent contains the message text or other content types. +type MessageContent struct { + Type string `json:"type"` // "text", "image", "resource" + Text string `json:"text,omitempty"` +} + +// CreatePromptRequest contains parameters for prompts/create method. +type CreatePromptRequest struct { + Name string `json:"name"` + Title string `json:"title"` + Description string `json:"description,omitempty"` + Arguments []PromptArgument `json:"arguments,omitempty"` + Messages []PromptMessage `json:"messages"` + Tags []string `json:"tags,omitempty"` +} + +// UpdatePromptRequest contains parameters for prompts/update method. +type UpdatePromptRequest struct { + Name string `json:"name"` + Title string `json:"title,omitempty"` + Description string `json:"description,omitempty"` + Arguments []PromptArgument `json:"arguments,omitempty"` + Messages []PromptMessage `json:"messages,omitempty"` + Tags []string `json:"tags,omitempty"` +} + +// DeletePromptRequest contains parameters for prompts/delete method. +type DeletePromptRequest struct { + Name string `json:"name"` +} + +// PromptOperationResult indicates success/failure of create/update/delete. +type PromptOperationResult struct { + Success bool `json:"success"` + Message string `json:"message,omitempty"` +} |
