diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/hexaiaction/run.go | 132 | ||||
| -rw-r--r-- | internal/lsp/handlers_codeaction.go | 311 |
2 files changed, 334 insertions, 109 deletions
diff --git a/internal/hexaiaction/run.go b/internal/hexaiaction/run.go index bf355b0..ffd31f1 100644 --- a/internal/hexaiaction/run.go +++ b/internal/hexaiaction/run.go @@ -29,6 +29,35 @@ type configPathKey struct{} // to the executor. Cleared after use. var selectedCustom *appconfig.CustomAction +type actionPlan struct { + fallback string + run func(context.Context) (string, error) +} + +// CodeActionHandler builds a plan for an action and resolves it. +type CodeActionHandler interface { + Build(parts InputParts, cfg appconfig.App, client chatDoer, stderr io.Writer) (actionPlan, bool) + Resolve(ctx context.Context, plan actionPlan) (string, error) +} + +type codeActionHandler struct { + build func(parts InputParts, cfg appconfig.App, client chatDoer, stderr io.Writer) (actionPlan, bool) +} + +func (h codeActionHandler) Build(parts InputParts, cfg appconfig.App, client chatDoer, stderr io.Writer) (actionPlan, bool) { + if h.build == nil { + return actionPlan{}, false + } + return h.build(parts, cfg, client, stderr) +} + +func (h codeActionHandler) Resolve(ctx context.Context, plan actionPlan) (string, error) { + if plan.run == nil { + return plan.fallback, nil + } + return plan.run(ctx) +} + func Run(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer) error { logger := log.New(stderr, "hexai-tmux-action ", log.LstdFlags|log.Lmsgprefix) cfg := appconfig.LoadWithOptions(logger, appconfig.LoadOptions{ConfigPath: configPathFromContext(ctx)}) @@ -98,26 +127,95 @@ func configPathFromContext(ctx context.Context) string { } func executeAction(ctx context.Context, kind ActionKind, parts InputParts, cfg appconfig.App, client chatDoer, stderr io.Writer) (string, error) { - switch kind { - case ActionSkip: + handler, ok := codeActionHandlers()[kind] + if !ok { return parts.Selection, nil - case ActionRewrite: - return handleRewriteAction(ctx, parts, cfg, client, stderr) - case ActionDiagnostics: - return handleDiagnosticsAction(ctx, parts, cfg, client) - case ActionDocument: - return handleDocumentAction(ctx, parts, cfg, client) - case ActionGoTest: - return handleGoTestAction(ctx, parts, cfg, client) - case ActionSimplify: - return handleSimplifyAction(ctx, parts, cfg, client) - case ActionCustom: - return handleCustomAction(ctx, parts, cfg, client) - case ActionCustomPrompt: - return handleCustomPromptAction(ctx, parts, cfg, client, stderr) - default: + } + plan, ok := handler.Build(parts, cfg, client, stderr) + if !ok { return parts.Selection, nil } + return handler.Resolve(ctx, plan) +} + +func codeActionHandlers() map[ActionKind]CodeActionHandler { + return map[ActionKind]CodeActionHandler{ + ActionSkip: codeActionHandler{build: buildSkipPlan}, + ActionRewrite: codeActionHandler{build: buildRewritePlan}, + ActionDiagnostics: codeActionHandler{build: buildDiagnosticsPlan}, + ActionDocument: codeActionHandler{build: buildDocumentPlan}, + ActionGoTest: codeActionHandler{build: buildGoTestPlan}, + ActionSimplify: codeActionHandler{build: buildSimplifyPlan}, + ActionCustom: codeActionHandler{build: buildCustomPlan}, + ActionCustomPrompt: codeActionHandler{build: buildCustomPromptPlan}, + } +} + +func buildSkipPlan(parts InputParts, _ appconfig.App, _ chatDoer, _ io.Writer) (actionPlan, bool) { + return actionPlan{fallback: parts.Selection}, true +} + +func buildRewritePlan(parts InputParts, cfg appconfig.App, client chatDoer, stderr io.Writer) (actionPlan, bool) { + return actionPlan{ + fallback: parts.Selection, + run: func(ctx context.Context) (string, error) { + return handleRewriteAction(ctx, parts, cfg, client, stderr) + }, + }, true +} + +func buildDiagnosticsPlan(parts InputParts, cfg appconfig.App, client chatDoer, _ io.Writer) (actionPlan, bool) { + return actionPlan{ + fallback: parts.Selection, + run: func(ctx context.Context) (string, error) { + return handleDiagnosticsAction(ctx, parts, cfg, client) + }, + }, true +} + +func buildDocumentPlan(parts InputParts, cfg appconfig.App, client chatDoer, _ io.Writer) (actionPlan, bool) { + return actionPlan{ + fallback: parts.Selection, + run: func(ctx context.Context) (string, error) { + return handleDocumentAction(ctx, parts, cfg, client) + }, + }, true +} + +func buildGoTestPlan(parts InputParts, cfg appconfig.App, client chatDoer, _ io.Writer) (actionPlan, bool) { + return actionPlan{ + fallback: parts.Selection, + run: func(ctx context.Context) (string, error) { + return handleGoTestAction(ctx, parts, cfg, client) + }, + }, true +} + +func buildSimplifyPlan(parts InputParts, cfg appconfig.App, client chatDoer, _ io.Writer) (actionPlan, bool) { + return actionPlan{ + fallback: parts.Selection, + run: func(ctx context.Context) (string, error) { + return handleSimplifyAction(ctx, parts, cfg, client) + }, + }, true +} + +func buildCustomPlan(parts InputParts, cfg appconfig.App, client chatDoer, _ io.Writer) (actionPlan, bool) { + return actionPlan{ + fallback: parts.Selection, + run: func(ctx context.Context) (string, error) { + return handleCustomAction(ctx, parts, cfg, client) + }, + }, true +} + +func buildCustomPromptPlan(parts InputParts, cfg appconfig.App, client chatDoer, stderr io.Writer) (actionPlan, bool) { + return actionPlan{ + fallback: parts.Selection, + run: func(ctx context.Context) (string, error) { + return handleCustomPromptAction(ctx, parts, cfg, client, stderr) + }, + }, true } func handleRewriteAction(ctx context.Context, parts InputParts, cfg appconfig.App, client chatDoer, stderr io.Writer) (string, error) { diff --git a/internal/lsp/handlers_codeaction.go b/internal/lsp/handlers_codeaction.go index 4562954..2a1d7fa 100644 --- a/internal/lsp/handlers_codeaction.go +++ b/internal/lsp/handlers_codeaction.go @@ -14,6 +14,43 @@ import ( "codeberg.org/snonux/hexai/internal/logging" ) +type codeActionPayload struct { + Type string `json:"type"` + ID string `json:"id"` + URI string `json:"uri"` + Range Range `json:"range"` + Instruction string `json:"instruction,omitempty"` + Selection string `json:"selection"` + Diagnostics []Diagnostic `json:"diagnostics,omitempty"` +} + +// CodeActionHandler builds and resolves code actions for a specific action type. +type CodeActionHandler interface { + Build(s *Server, p CodeActionParams, selection string) []CodeAction + Resolve(s *Server, action CodeAction, payload codeActionPayload) (CodeAction, bool) +} + +type codeActionHandler struct { + build func(s *Server, p CodeActionParams, selection string) []CodeAction + resolve func(s *Server, action CodeAction, payload codeActionPayload) (CodeAction, bool) +} + +func (h codeActionHandler) Build(s *Server, p CodeActionParams, selection string) []CodeAction { + if h.build == nil { + return nil + } + return h.build(s, p, selection) +} + +func (h codeActionHandler) Resolve(s *Server, action CodeAction, payload codeActionPayload) (CodeAction, bool) { + if h.resolve == nil { + return action, false + } + return h.resolve(s, action, payload) +} + +var codeActionBuildOrder = []string{"rewrite", "diagnostics", "document", "go_test", "simplify", "custom"} + func (s *Server) handleCodeAction(req Request) { var p CodeActionParams if err := json.Unmarshal(req.Params, &p); err != nil { @@ -39,29 +76,25 @@ func (s *Server) handleCodeAction(req Request) { } sel := extractRangeText(d, p.Range) - actions := make([]CodeAction, 0, 8) - if a := s.buildRewriteCodeAction(p, sel); a != nil { - actions = append(actions, *a) - } - if a := s.buildDiagnosticsCodeAction(p, sel); a != nil { - actions = append(actions, *a) - } - if a := s.buildDocumentCodeAction(p, sel); a != nil { - actions = append(actions, *a) - } - if a := s.buildGoUnitTestCodeAction(p); a != nil { - actions = append(actions, *a) - } - if a := s.buildSimplifyCodeAction(p, sel); a != nil { - actions = append(actions, *a) - } - // Custom actions from config - s.appendCustomActions(&actions, p, sel) + actions := s.buildCodeActions(p, sel) if len(req.ID) != 0 { s.reply(req.ID, actions, nil) } } +func (s *Server) buildCodeActions(p CodeActionParams, selection string) []CodeAction { + actions := make([]CodeAction, 0, 8) + handlers := s.codeActionHandlers() + for _, key := range codeActionBuildOrder { + handler, ok := handlers[key] + if !ok { + continue + } + actions = append(actions, handler.Build(s, p, selection)...) + } + return actions +} + // appendCustomActions adds user-defined actions depending on scope and availability. func (s *Server) appendCustomActions(actions *[]CodeAction, p CodeActionParams, sel string) { customs := s.customActions() @@ -115,6 +148,133 @@ func (s *Server) appendCustomActions(actions *[]CodeAction, p CodeActionParams, } } +func (s *Server) codeActionHandlers() map[string]CodeActionHandler { + return map[string]CodeActionHandler{ + "rewrite": codeActionHandler{build: buildRewriteActions, resolve: resolveRewriteCodeAction}, + "diagnostics": codeActionHandler{build: buildDiagnosticsActions, resolve: resolveDiagnosticsCodeAction}, + "document": codeActionHandler{build: buildDocumentActions, resolve: resolveDocumentCodeAction}, + "go_test": codeActionHandler{build: buildGoTestActions, resolve: resolveGoTestCodeAction}, + "simplify": codeActionHandler{build: buildSimplifyActions, resolve: resolveSimplifyCodeAction}, + "custom": codeActionHandler{build: buildCustomActions, resolve: resolveCustomCodeAction}, + } +} + +func buildRewriteActions(s *Server, p CodeActionParams, selection string) []CodeAction { + if action := s.buildRewriteCodeAction(p, selection); action != nil { + return []CodeAction{*action} + } + return nil +} + +func buildDiagnosticsActions(s *Server, p CodeActionParams, selection string) []CodeAction { + if action := s.buildDiagnosticsCodeAction(p, selection); action != nil { + return []CodeAction{*action} + } + return nil +} + +func buildDocumentActions(s *Server, p CodeActionParams, selection string) []CodeAction { + if action := s.buildDocumentCodeAction(p, selection); action != nil { + return []CodeAction{*action} + } + return nil +} + +func buildGoTestActions(s *Server, p CodeActionParams, _ string) []CodeAction { + if action := s.buildGoUnitTestCodeAction(p); action != nil { + return []CodeAction{*action} + } + return nil +} + +func buildSimplifyActions(s *Server, p CodeActionParams, selection string) []CodeAction { + if action := s.buildSimplifyCodeAction(p, selection); action != nil { + return []CodeAction{*action} + } + return nil +} + +func buildCustomActions(s *Server, p CodeActionParams, selection string) []CodeAction { + actions := make([]CodeAction, 0, len(s.customActions())) + s.appendCustomActions(&actions, p, selection) + return actions +} + +func resolveRewriteCodeAction(s *Server, action CodeAction, payload codeActionPayload) (CodeAction, bool) { + cfg := s.currentConfig() + sys := cfg.PromptCodeActionRewriteSystem + user := renderTemplate(cfg.PromptCodeActionRewriteUser, map[string]string{ + "instruction": payload.Instruction, + "selection": payload.Selection, + }) + return s.completeCodeAction(action, payload.URI, payload.Range, sys, user, 20*time.Second) +} + +func resolveDiagnosticsCodeAction(s *Server, action CodeAction, payload codeActionPayload) (CodeAction, bool) { + cfg := s.currentConfig() + sys := cfg.PromptCodeActionDiagnosticsSystem + user := renderTemplate(cfg.PromptCodeActionDiagnosticsUser, map[string]string{ + "diagnostics": formatDiagnostics(payload.Diagnostics), + "selection": payload.Selection, + }) + return s.completeCodeAction(action, payload.URI, payload.Range, sys, user, 22*time.Second) +} + +func resolveDocumentCodeAction(s *Server, action CodeAction, payload codeActionPayload) (CodeAction, bool) { + cfg := s.currentConfig() + sys := cfg.PromptCodeActionDocumentSystem + user := renderTemplate(cfg.PromptCodeActionDocumentUser, map[string]string{"selection": payload.Selection}) + return s.completeCodeAction(action, payload.URI, payload.Range, sys, user, 20*time.Second) +} + +func resolveGoTestCodeAction(s *Server, action CodeAction, payload codeActionPayload) (CodeAction, bool) { + edit, jumpURI, jumpRange, ok := s.resolveGoTest(payload.URI, payload.Range.Start) + if !ok { + return action, false + } + action.Edit = &edit + action.Command = &Command{ + Title: "Jump to generated test", + Command: "hexai.showDocument", + Arguments: []any{jumpURI, jumpRange}, + } + s.deferShowDocument(jumpURI, jumpRange) + return action, true +} + +func resolveSimplifyCodeAction(s *Server, action CodeAction, payload codeActionPayload) (CodeAction, bool) { + cfg := s.currentConfig() + sys := cfg.PromptCodeActionRewriteSystem + user := renderTemplate(cfg.PromptCodeActionRewriteUser, map[string]string{ + "instruction": "Simplify and improve the code while preserving behavior. Return only the improved code.", + "selection": payload.Selection, + }) + return s.completeCodeAction(action, payload.URI, payload.Range, sys, user, 20*time.Second) +} + +func resolveCustomCodeAction(s *Server, action CodeAction, payload codeActionPayload) (CodeAction, bool) { + custom := s.customActionByID(payload.ID) + if custom == nil { + return action, false + } + cfg := s.currentConfig() + sys := cfg.PromptCodeActionRewriteSystem + user := renderTemplate(cfg.PromptCodeActionRewriteUser, map[string]string{ + "instruction": payload.Instruction, + "selection": payload.Selection, + }) + if strings.TrimSpace(custom.User) != "" { + if strings.TrimSpace(custom.System) != "" { + sys = custom.System + } + user = renderTemplate(custom.User, map[string]string{ + "selection": payload.Selection, + "diagnostics": strings.TrimSpace(formatCustomDiagnostics(payload.Diagnostics)), + }) + } + return s.completeCodeAction(action, payload.URI, payload.Range, sys, user, 20*time.Second) +} + func (s *Server) buildSimplifyCodeAction(p CodeActionParams, sel string) *CodeAction { if strings.TrimSpace(sel) == "" { return nil @@ -167,86 +327,53 @@ func (s *Server) resolveCodeAction(ca CodeAction) (CodeAction, bool) { if s.currentLLMClient() == nil || len(ca.Data) == 0 { return ca, false } - var payload struct { - Type string `json:"type"` - ID string `json:"id"` - URI string `json:"uri"` - Range Range `json:"range"` - Instruction string `json:"instruction,omitempty"` - Selection string `json:"selection"` - Diagnostics []Diagnostic `json:"diagnostics,omitempty"` + payload, ok := decodeCodeActionPayload(ca.Data) + if !ok { + return ca, false } - if err := json.Unmarshal(ca.Data, &payload); err != nil { + handler, found := s.codeActionHandlers()[payload.Type] + if !found { return ca, false } - cfg := s.currentConfig() - switch payload.Type { - case "rewrite": - sys := cfg.PromptCodeActionRewriteSystem - user := renderTemplate(cfg.PromptCodeActionRewriteUser, map[string]string{"instruction": payload.Instruction, "selection": payload.Selection}) - return s.completeCodeAction(ca, payload.URI, payload.Range, sys, user, 20*time.Second) - case "diagnostics": - sys := cfg.PromptCodeActionDiagnosticsSystem - var b strings.Builder - for i, dgn := range payload.Diagnostics { - if dgn.Source != "" { - fmt.Fprintf(&b, "%d. [%s] %s\n", i+1, dgn.Source, dgn.Message) - } else { - fmt.Fprintf(&b, "%d. %s\n", i+1, dgn.Message) - } - } - diagList := b.String() - user := renderTemplate(cfg.PromptCodeActionDiagnosticsUser, map[string]string{"diagnostics": diagList, "selection": payload.Selection}) - return s.completeCodeAction(ca, payload.URI, payload.Range, sys, user, 22*time.Second) - case "document": - sys := cfg.PromptCodeActionDocumentSystem - user := renderTemplate(cfg.PromptCodeActionDocumentUser, map[string]string{"selection": payload.Selection}) - return s.completeCodeAction(ca, payload.URI, payload.Range, sys, user, 20*time.Second) - case "go_test": - if edit, jumpURI, jumpRange, ok := s.resolveGoTest(payload.URI, payload.Range.Start); ok { - ca.Edit = &edit - ca.Command = &Command{Title: "Jump to generated test", Command: "hexai.showDocument", Arguments: []any{jumpURI, jumpRange}} - s.deferShowDocument(jumpURI, jumpRange) - return ca, true - } - case "simplify": - sys := cfg.PromptCodeActionRewriteSystem - user := renderTemplate(cfg.PromptCodeActionRewriteUser, map[string]string{"instruction": "Simplify and improve the code while preserving behavior. Return only the improved code.", "selection": payload.Selection}) - return s.completeCodeAction(ca, payload.URI, payload.Range, sys, user, 20*time.Second) - case "custom": - var action *CustomAction - for _, caDef := range s.customActions() { - if caDef.ID == payload.ID { - action = &caDef - break - } - } - if action == nil { - return ca, false + return handler.Resolve(s, ca, payload) +} + +func decodeCodeActionPayload(raw json.RawMessage) (codeActionPayload, bool) { + var payload codeActionPayload + if err := json.Unmarshal(raw, &payload); err != nil { + return codeActionPayload{}, false + } + return payload, true +} + +func formatDiagnostics(diagnostics []Diagnostic) string { + var b strings.Builder + for i, dgn := range diagnostics { + if dgn.Source != "" { + fmt.Fprintf(&b, "%d. [%s] %s\n", i+1, dgn.Source, dgn.Message) + continue } - var sys, user string - if strings.TrimSpace(action.User) != "" { - if strings.TrimSpace(action.System) != "" { - sys = action.System - } else { - sys = cfg.PromptCodeActionRewriteSystem - } - var diagList string - if len(payload.Diagnostics) > 0 { - var b strings.Builder - for _, d := range payload.Diagnostics { - fmt.Fprintf(&b, "%s\n", d.Message) - } - diagList = b.String() - } - user = renderTemplate(action.User, map[string]string{"selection": payload.Selection, "diagnostics": strings.TrimSpace(diagList)}) - } else { - sys = cfg.PromptCodeActionRewriteSystem - user = renderTemplate(cfg.PromptCodeActionRewriteUser, map[string]string{"instruction": payload.Instruction, "selection": payload.Selection}) + fmt.Fprintf(&b, "%d. %s\n", i+1, dgn.Message) + } + return b.String() +} + +func formatCustomDiagnostics(diagnostics []Diagnostic) string { + var b strings.Builder + for _, d := range diagnostics { + fmt.Fprintf(&b, "%s\n", d.Message) + } + return b.String() +} + +func (s *Server) customActionByID(id string) *CustomAction { + for _, item := range s.customActions() { + if item.ID == id { + action := item + return &action } - return s.completeCodeAction(ca, payload.URI, payload.Range, sys, user, 20*time.Second) } - return ca, false + return nil } func (s *Server) completeCodeAction(ca CodeAction, uri string, rng Range, sys, user string, timeout time.Duration) (CodeAction, bool) { |
