diff options
| author | Paul Buetow <pbuetow@mimecast.com> | 2020-02-26 11:11:07 +0000 |
|---|---|---|
| committer | Paul Buetow <pbuetow@mimecast.com> | 2020-02-26 11:11:07 +0000 |
| commit | 3cdc86e20cbd311fb9c85cef63876a2f39e5e74d (patch) | |
| tree | 9cb50347900ff1ba4dc6a7b6e4766ebd951c2c58 /internal/server | |
| parent | 6e176034306026b922c1df4231a1807f36cbe460 (diff) | |
can list remote jobs and can also pass outer args to scripts
Diffstat (limited to 'internal/server')
| -rw-r--r-- | internal/server/background/background.go | 66 | ||||
| -rw-r--r-- | internal/server/handlers/runcommand.go | 78 | ||||
| -rw-r--r-- | internal/server/handlers/serverhandler.go | 105 |
3 files changed, 174 insertions, 75 deletions
diff --git a/internal/server/background/background.go b/internal/server/background/background.go index 05a502f..51ef052 100644 --- a/internal/server/background/background.go +++ b/internal/server/background/background.go @@ -3,12 +3,16 @@ package background import ( "context" "errors" + "fmt" + "strings" "sync" + + "github.com/mimecast/dtail/internal/io/logger" ) type job struct { cancel context.CancelFunc - done <-chan struct{} + wg *sync.WaitGroup } // Background specifies a job or command run in background on server side. @@ -27,43 +31,81 @@ func New() Background { } // Add a background job. -func (b Background) Add(name string, cancel context.CancelFunc, done <-chan struct{}) error { +func (b Background) Add(userName, jobName string, cancel context.CancelFunc, wg *sync.WaitGroup) error { + key := b.key(userName, jobName) + b.mutex.Lock() defer b.mutex.Unlock() - if _, ok := b.jobs[name]; ok { + if _, ok := b.jobs[key]; ok { return errors.New("job already exists") } - b.jobs[name] = job{cancel, done} + b.jobs[key] = job{cancel, wg} + + // Clean up background job database. + go func() { + wg.Wait() + b.cancel(key) + }() + return nil } // Cancel a background job. -func (b Background) Cancel(name string) error { - job, ok := b.get(name) +func (b Background) Cancel(userName, jobName string) error { + return b.cancel(b.key(userName, jobName)) +} + +func (b Background) cancel(key string) error { + job, ok := b.get(key) if !ok { return errors.New("no job to cancel") } job.cancel() - <-job.done - b.delete(name) + job.wg.Wait() + b.delete(key) return nil } -func (b Background) get(name string) (job, bool) { +// ListJobsC returns a channel listing all jobs of the given user. +func (b Background) ListJobsC(userName string) <-chan string { + ch := make(chan string) + + go func() { + defer close(ch) + + b.mutex.Lock() + defer b.mutex.Unlock() + + for k, _ := range b.jobs { + logger.Debug("ListJobsC", k, userName) + if strings.HasPrefix(k, fmt.Sprintf("%s.", userName)) { + ch <- k + } + } + }() + + return ch +} + +func (b Background) get(key string) (job, bool) { b.mutex.Lock() defer b.mutex.Unlock() - job, ok := b.jobs[name] + job, ok := b.jobs[key] return job, ok } -func (b Background) delete(name string) { +func (b Background) delete(key string) { b.mutex.Lock() defer b.mutex.Unlock() - delete(b.jobs, name) + delete(b.jobs, key) +} + +func (Background) key(userName, jobName string) string { + return fmt.Sprintf("%s.%s", userName, jobName) } diff --git a/internal/server/handlers/runcommand.go b/internal/server/handlers/runcommand.go index 95db52f..8e5895b 100644 --- a/internal/server/handlers/runcommand.go +++ b/internal/server/handlers/runcommand.go @@ -2,11 +2,13 @@ package handlers import ( "context" + "errors" "fmt" "io/ioutil" "os" "os/exec" "strings" + "sync" "time" "github.com/mimecast/dtail/internal/config" @@ -25,26 +27,39 @@ func newRunCommand(server *ServerHandler) runCommand { } } -func (r runCommand) Start(ctx context.Context, argc int, args []string) { +func (r runCommand) StartBackground(ctx context.Context, wg *sync.WaitGroup, argc int, args, outerArgs []string) error { if argc < 2 { - r.server.sendServerMessage(logger.Warn(r.server.user, commandParseWarning, args, argc)) - return + return fmt.Errorf("%s: args:%v argc:%d", commandParseWarning, args, argc) } + ec := make(chan int, 1) + var pid int + var err error + command := strings.Join(args[1:], " ") if strings.Contains(command, ";") || strings.Contains(command, "\n") { - r.startScript(ctx, command) - return + if pid, err = r.startScript(ctx, wg, ec, command, outerArgs); err != nil { + r.server.sendServerMessage(".run exitstatus 255") + return err + } + return nil + } + + if pid, err = r.start(ctx, wg, ec, strings.TrimSpace(command), outerArgs); err != nil { + r.server.sendServerMessage(".run exitstatus 255") + return err } - r.start(ctx, strings.TrimSpace(command)) + exitCode := <-ec + r.server.sendServerMessage(fmt.Sprintf(".run exitstatus %d", exitCode)) + r.server.sendServerMessage(logger.Info(fmt.Sprintf("Process %d exited with status %d", pid, exitCode))) + + return nil } -func (r runCommand) startScript(ctx context.Context, script string) { +func (r runCommand) startScript(ctx context.Context, wg *sync.WaitGroup, ec chan<- int, script string, outerArgs []string) (int, error) { if _, err := os.Stat(config.Common.TmpDir); os.IsNotExist(err) { - logger.Error(r.server.user, err) - r.server.sendServerMessage(logger.Error(r.server.user, "Unable to execute command(s), check server logs")) - return + return -1, err } timestamp := time.Now().UnixNano() @@ -55,45 +70,42 @@ func (r runCommand) startScript(ctx context.Context, script string) { script = fmt.Sprintf("#!/bin/sh\n%s", script) if err := ioutil.WriteFile(scriptPath, []byte(script), 0700); err != nil { - logger.Error(r.server.user, err) - r.server.sendServerMessage(logger.Error(r.server.user, "Unable to execute command(s), check server logs")) - return + return -1, err } - r.start(ctx, scriptPath) - os.Remove(scriptPath) + pid, err := r.start(ctx, wg, ec, scriptPath, outerArgs) + go func() { + wg.Wait() + logger.Debug("Deleting script", scriptPath) + os.Remove(scriptPath) + }() + + return pid, err } -func (r runCommand) start(ctx context.Context, command string) { +func (r runCommand) start(ctx context.Context, wg *sync.WaitGroup, ec chan<- int, command string, outerArgs []string) (int, error) { if len(command) == 0 { - return + return -1, errors.New("Empty command provided") } + splitted := strings.Split(command, " ") path := splitted[0] args := splitted[1:] + args = append(args, outerArgs...) qualifiedPath, err := exec.LookPath(path) if err != nil { - logger.Error(r.server.user, err) - r.server.sendServerMessage(logger.Error(r.server.user, "Unable to execute command(s), check server logs")) - r.server.sendServerMessage(".run exitstatus 255") - return + return -1, err } if !r.server.user.HasFilePermission(qualifiedPath, "runcommands") { - logger.Error(r.server.user, "No permission to execute path", qualifiedPath) - r.server.sendServerMessage(logger.Error(r.server.user, "Unable to execute command(s), check server logs")) - r.server.sendServerMessage(".run exitstatus 255") - return + return -1, fmt.Errorf("No permission to execute path: %s", qualifiedPath) } r.run = run.New(qualifiedPath, args) - pid, ec, _ := r.run.Start(ctx, r.server.lines) - - r.server.sendServerMessage(fmt.Sprintf(".run exitstatus %d", ec)) - r.server.sendServerMessage(logger.Info(fmt.Sprintf("Process %d exited with status %d", pid, ec))) - - logger.Debug(r.server.user, "Waiting for Pgroup to be killed") - <-r.run.PgroupKilled() - logger.Debug(r.server.user, "Pgroup killed") + pid, err := r.run.StartBackground(ctx, wg, ec, r.server.lines) + if err != nil { + return pid, err + } + return pid, nil } diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go index 01e4054..819cddd 100644 --- a/internal/server/handlers/serverhandler.go +++ b/internal/server/handlers/serverhandler.go @@ -2,9 +2,7 @@ package handlers import ( "context" - "crypto/sha256" "encoding/base64" - "encoding/hex" "errors" "fmt" "io" @@ -251,7 +249,14 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args [] splitted := strings.Split(args[0], ":") command := splitted[0] - flags := splitted[1:] + + // TODO: Refactor: Create an "options" clase, combine makeOptions and readOptions there. + options, err := readOptions(splitted[1:]) + if err != nil { + h.sendServerMessage(logger.Error(h.user, err)) + finished() + return + } switch command { case "grep", "cat": @@ -287,10 +292,12 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args [] case "run": // TODO: Refactor this "run" case, move code to runcommand.go command := newRunCommand(h) - jobName := fmt.Sprintf("%s%%%s", h.user.Name, hash(strings.Join(args[1:], " "))) - if contains(flags, "background.cancel") { - if err := h.background.Cancel(jobName); err != nil { + jobName, _ := options["jobName"] + logger.Debug(h.user, "run", options) + + if val, ok := options["background"]; ok && val == "cancel" { + if err := h.background.Cancel(h.user.Name, jobName); err != nil { h.sendServerMessage(logger.Error(h.user, err, jobName, args)) } else { h.sendServerMessage(logger.Info(h.user, "job cancelled", jobName)) @@ -299,37 +306,64 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args [] return } - done := make(chan struct{}) + if val, ok := options["background"]; ok && val == "list" { + h.sendServerMessage("Listing jobs") + count := 0 + for jobName := range h.background.ListJobsC(h.user.Name) { + h.sendServerMessage(jobName) + count++ + } + h.sendServerMessage(fmt.Sprintf("Found %d jobs", count)) + finished() + return + } + + str, _ := options["outerArgs"] + outerArgs := strings.Split(str, " ") + + var background bool + if val, ok := options["background"]; ok && val == "start" { + background = true + } + + var wg sync.WaitGroup + wg.Add(1) - if contains(flags, "background.start") { + if background { commandCtx, cancel := context.WithCancel(h.serverCtx) + // TODO: For background jobs dont attempt to send data to dtail client as there might be no SSH connection - if err := h.background.Add(jobName, cancel, done); err != nil { + if err := h.background.Add(h.user.Name, jobName, cancel, &wg); err != nil { h.sendServerMessage(logger.Error(h.user, err, jobName, args)) finished() return } + ctx = commandCtx + } - go func() { h.globalServerWaitFor <- struct{}{} }() - go func() { - command.Start(commandCtx, argc, args) - close(done) - <-h.globalServerWaitFor - }() - - h.sendServerMessage(logger.Info(h.user, jobName, "job started in background")) + if err := command.StartBackground(ctx, &wg, argc, args, outerArgs); err != nil { + h.sendServerMessage(logger.Error(h.user, "Unable to execute command", argc, args, err)) finished() return } + // Make sure that server waits for all sub-processes to finish on shutdown go func() { h.globalServerWaitFor <- struct{}{} }() go func() { - command.Start(ctx, argc, args) - close(done) + wg.Wait() <-h.globalServerWaitFor - finished() }() + if background { + h.sendServerMessage(logger.Info(h.user, jobName, "job started in background")) + finished() + return + } + + // Command run in foreground, wait for it to complete before finishing the connection. + wg.Wait() + finished() + case "ack", ".ack": h.handleAckCommand(argc, args) finished() @@ -427,17 +461,28 @@ func (h *ServerHandler) decrementActiveCommands() int { return h.activeCommands } -func contains(haystack []string, needle string) bool { - for _, str := range haystack { - if str == needle { - return true +func readOptions(opts []string) (map[string]string, error) { + options := make(map[string]string, len(opts)) + + for _, o := range opts { + kv := strings.SplitN(o, "=", 2) + if len(kv) != 2 { + return options, fmt.Errorf("Unable to parse options: %v", kv) } + key := kv[0] + val := kv[1] + + if strings.HasPrefix(val, "base64%") { + s := strings.SplitN(val, "%", 2) + decoded, err := base64.StdEncoding.DecodeString(s[1]) + if err != nil { + return options, err + } + val = string(decoded) + } + + options[key] = val } - return false -} -func hash(str string) string { - h := sha256.New() - h.Write([]byte(str)) - return hex.EncodeToString(h.Sum(nil)) + return options, nil } |
