summaryrefslogtreecommitdiff
path: root/internal/server
diff options
context:
space:
mode:
authorPaul Buetow <pbuetow@mimecast.com>2020-02-26 11:11:07 +0000
committerPaul Buetow <pbuetow@mimecast.com>2020-02-26 11:11:07 +0000
commit3cdc86e20cbd311fb9c85cef63876a2f39e5e74d (patch)
tree9cb50347900ff1ba4dc6a7b6e4766ebd951c2c58 /internal/server
parent6e176034306026b922c1df4231a1807f36cbe460 (diff)
can list remote jobs and can also pass outer args to scripts
Diffstat (limited to 'internal/server')
-rw-r--r--internal/server/background/background.go66
-rw-r--r--internal/server/handlers/runcommand.go78
-rw-r--r--internal/server/handlers/serverhandler.go105
3 files changed, 174 insertions, 75 deletions
diff --git a/internal/server/background/background.go b/internal/server/background/background.go
index 05a502f..51ef052 100644
--- a/internal/server/background/background.go
+++ b/internal/server/background/background.go
@@ -3,12 +3,16 @@ package background
import (
"context"
"errors"
+ "fmt"
+ "strings"
"sync"
+
+ "github.com/mimecast/dtail/internal/io/logger"
)
type job struct {
cancel context.CancelFunc
- done <-chan struct{}
+ wg *sync.WaitGroup
}
// Background specifies a job or command run in background on server side.
@@ -27,43 +31,81 @@ func New() Background {
}
// Add a background job.
-func (b Background) Add(name string, cancel context.CancelFunc, done <-chan struct{}) error {
+func (b Background) Add(userName, jobName string, cancel context.CancelFunc, wg *sync.WaitGroup) error {
+ key := b.key(userName, jobName)
+
b.mutex.Lock()
defer b.mutex.Unlock()
- if _, ok := b.jobs[name]; ok {
+ if _, ok := b.jobs[key]; ok {
return errors.New("job already exists")
}
- b.jobs[name] = job{cancel, done}
+ b.jobs[key] = job{cancel, wg}
+
+ // Clean up background job database.
+ go func() {
+ wg.Wait()
+ b.cancel(key)
+ }()
+
return nil
}
// Cancel a background job.
-func (b Background) Cancel(name string) error {
- job, ok := b.get(name)
+func (b Background) Cancel(userName, jobName string) error {
+ return b.cancel(b.key(userName, jobName))
+}
+
+func (b Background) cancel(key string) error {
+ job, ok := b.get(key)
if !ok {
return errors.New("no job to cancel")
}
job.cancel()
- <-job.done
- b.delete(name)
+ job.wg.Wait()
+ b.delete(key)
return nil
}
-func (b Background) get(name string) (job, bool) {
+// ListJobsC returns a channel listing all jobs of the given user.
+func (b Background) ListJobsC(userName string) <-chan string {
+ ch := make(chan string)
+
+ go func() {
+ defer close(ch)
+
+ b.mutex.Lock()
+ defer b.mutex.Unlock()
+
+ for k, _ := range b.jobs {
+ logger.Debug("ListJobsC", k, userName)
+ if strings.HasPrefix(k, fmt.Sprintf("%s.", userName)) {
+ ch <- k
+ }
+ }
+ }()
+
+ return ch
+}
+
+func (b Background) get(key string) (job, bool) {
b.mutex.Lock()
defer b.mutex.Unlock()
- job, ok := b.jobs[name]
+ job, ok := b.jobs[key]
return job, ok
}
-func (b Background) delete(name string) {
+func (b Background) delete(key string) {
b.mutex.Lock()
defer b.mutex.Unlock()
- delete(b.jobs, name)
+ delete(b.jobs, key)
+}
+
+func (Background) key(userName, jobName string) string {
+ return fmt.Sprintf("%s.%s", userName, jobName)
}
diff --git a/internal/server/handlers/runcommand.go b/internal/server/handlers/runcommand.go
index 95db52f..8e5895b 100644
--- a/internal/server/handlers/runcommand.go
+++ b/internal/server/handlers/runcommand.go
@@ -2,11 +2,13 @@ package handlers
import (
"context"
+ "errors"
"fmt"
"io/ioutil"
"os"
"os/exec"
"strings"
+ "sync"
"time"
"github.com/mimecast/dtail/internal/config"
@@ -25,26 +27,39 @@ func newRunCommand(server *ServerHandler) runCommand {
}
}
-func (r runCommand) Start(ctx context.Context, argc int, args []string) {
+func (r runCommand) StartBackground(ctx context.Context, wg *sync.WaitGroup, argc int, args, outerArgs []string) error {
if argc < 2 {
- r.server.sendServerMessage(logger.Warn(r.server.user, commandParseWarning, args, argc))
- return
+ return fmt.Errorf("%s: args:%v argc:%d", commandParseWarning, args, argc)
}
+ ec := make(chan int, 1)
+ var pid int
+ var err error
+
command := strings.Join(args[1:], " ")
if strings.Contains(command, ";") || strings.Contains(command, "\n") {
- r.startScript(ctx, command)
- return
+ if pid, err = r.startScript(ctx, wg, ec, command, outerArgs); err != nil {
+ r.server.sendServerMessage(".run exitstatus 255")
+ return err
+ }
+ return nil
+ }
+
+ if pid, err = r.start(ctx, wg, ec, strings.TrimSpace(command), outerArgs); err != nil {
+ r.server.sendServerMessage(".run exitstatus 255")
+ return err
}
- r.start(ctx, strings.TrimSpace(command))
+ exitCode := <-ec
+ r.server.sendServerMessage(fmt.Sprintf(".run exitstatus %d", exitCode))
+ r.server.sendServerMessage(logger.Info(fmt.Sprintf("Process %d exited with status %d", pid, exitCode)))
+
+ return nil
}
-func (r runCommand) startScript(ctx context.Context, script string) {
+func (r runCommand) startScript(ctx context.Context, wg *sync.WaitGroup, ec chan<- int, script string, outerArgs []string) (int, error) {
if _, err := os.Stat(config.Common.TmpDir); os.IsNotExist(err) {
- logger.Error(r.server.user, err)
- r.server.sendServerMessage(logger.Error(r.server.user, "Unable to execute command(s), check server logs"))
- return
+ return -1, err
}
timestamp := time.Now().UnixNano()
@@ -55,45 +70,42 @@ func (r runCommand) startScript(ctx context.Context, script string) {
script = fmt.Sprintf("#!/bin/sh\n%s", script)
if err := ioutil.WriteFile(scriptPath, []byte(script), 0700); err != nil {
- logger.Error(r.server.user, err)
- r.server.sendServerMessage(logger.Error(r.server.user, "Unable to execute command(s), check server logs"))
- return
+ return -1, err
}
- r.start(ctx, scriptPath)
- os.Remove(scriptPath)
+ pid, err := r.start(ctx, wg, ec, scriptPath, outerArgs)
+ go func() {
+ wg.Wait()
+ logger.Debug("Deleting script", scriptPath)
+ os.Remove(scriptPath)
+ }()
+
+ return pid, err
}
-func (r runCommand) start(ctx context.Context, command string) {
+func (r runCommand) start(ctx context.Context, wg *sync.WaitGroup, ec chan<- int, command string, outerArgs []string) (int, error) {
if len(command) == 0 {
- return
+ return -1, errors.New("Empty command provided")
}
+
splitted := strings.Split(command, " ")
path := splitted[0]
args := splitted[1:]
+ args = append(args, outerArgs...)
qualifiedPath, err := exec.LookPath(path)
if err != nil {
- logger.Error(r.server.user, err)
- r.server.sendServerMessage(logger.Error(r.server.user, "Unable to execute command(s), check server logs"))
- r.server.sendServerMessage(".run exitstatus 255")
- return
+ return -1, err
}
if !r.server.user.HasFilePermission(qualifiedPath, "runcommands") {
- logger.Error(r.server.user, "No permission to execute path", qualifiedPath)
- r.server.sendServerMessage(logger.Error(r.server.user, "Unable to execute command(s), check server logs"))
- r.server.sendServerMessage(".run exitstatus 255")
- return
+ return -1, fmt.Errorf("No permission to execute path: %s", qualifiedPath)
}
r.run = run.New(qualifiedPath, args)
- pid, ec, _ := r.run.Start(ctx, r.server.lines)
-
- r.server.sendServerMessage(fmt.Sprintf(".run exitstatus %d", ec))
- r.server.sendServerMessage(logger.Info(fmt.Sprintf("Process %d exited with status %d", pid, ec)))
-
- logger.Debug(r.server.user, "Waiting for Pgroup to be killed")
- <-r.run.PgroupKilled()
- logger.Debug(r.server.user, "Pgroup killed")
+ pid, err := r.run.StartBackground(ctx, wg, ec, r.server.lines)
+ if err != nil {
+ return pid, err
+ }
+ return pid, nil
}
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go
index 01e4054..819cddd 100644
--- a/internal/server/handlers/serverhandler.go
+++ b/internal/server/handlers/serverhandler.go
@@ -2,9 +2,7 @@ package handlers
import (
"context"
- "crypto/sha256"
"encoding/base64"
- "encoding/hex"
"errors"
"fmt"
"io"
@@ -251,7 +249,14 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
splitted := strings.Split(args[0], ":")
command := splitted[0]
- flags := splitted[1:]
+
+ // TODO: Refactor: Create an "options" clase, combine makeOptions and readOptions there.
+ options, err := readOptions(splitted[1:])
+ if err != nil {
+ h.sendServerMessage(logger.Error(h.user, err))
+ finished()
+ return
+ }
switch command {
case "grep", "cat":
@@ -287,10 +292,12 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
case "run":
// TODO: Refactor this "run" case, move code to runcommand.go
command := newRunCommand(h)
- jobName := fmt.Sprintf("%s%%%s", h.user.Name, hash(strings.Join(args[1:], " ")))
- if contains(flags, "background.cancel") {
- if err := h.background.Cancel(jobName); err != nil {
+ jobName, _ := options["jobName"]
+ logger.Debug(h.user, "run", options)
+
+ if val, ok := options["background"]; ok && val == "cancel" {
+ if err := h.background.Cancel(h.user.Name, jobName); err != nil {
h.sendServerMessage(logger.Error(h.user, err, jobName, args))
} else {
h.sendServerMessage(logger.Info(h.user, "job cancelled", jobName))
@@ -299,37 +306,64 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
return
}
- done := make(chan struct{})
+ if val, ok := options["background"]; ok && val == "list" {
+ h.sendServerMessage("Listing jobs")
+ count := 0
+ for jobName := range h.background.ListJobsC(h.user.Name) {
+ h.sendServerMessage(jobName)
+ count++
+ }
+ h.sendServerMessage(fmt.Sprintf("Found %d jobs", count))
+ finished()
+ return
+ }
+
+ str, _ := options["outerArgs"]
+ outerArgs := strings.Split(str, " ")
+
+ var background bool
+ if val, ok := options["background"]; ok && val == "start" {
+ background = true
+ }
+
+ var wg sync.WaitGroup
+ wg.Add(1)
- if contains(flags, "background.start") {
+ if background {
commandCtx, cancel := context.WithCancel(h.serverCtx)
+
// TODO: For background jobs dont attempt to send data to dtail client as there might be no SSH connection
- if err := h.background.Add(jobName, cancel, done); err != nil {
+ if err := h.background.Add(h.user.Name, jobName, cancel, &wg); err != nil {
h.sendServerMessage(logger.Error(h.user, err, jobName, args))
finished()
return
}
+ ctx = commandCtx
+ }
- go func() { h.globalServerWaitFor <- struct{}{} }()
- go func() {
- command.Start(commandCtx, argc, args)
- close(done)
- <-h.globalServerWaitFor
- }()
-
- h.sendServerMessage(logger.Info(h.user, jobName, "job started in background"))
+ if err := command.StartBackground(ctx, &wg, argc, args, outerArgs); err != nil {
+ h.sendServerMessage(logger.Error(h.user, "Unable to execute command", argc, args, err))
finished()
return
}
+ // Make sure that server waits for all sub-processes to finish on shutdown
go func() { h.globalServerWaitFor <- struct{}{} }()
go func() {
- command.Start(ctx, argc, args)
- close(done)
+ wg.Wait()
<-h.globalServerWaitFor
- finished()
}()
+ if background {
+ h.sendServerMessage(logger.Info(h.user, jobName, "job started in background"))
+ finished()
+ return
+ }
+
+ // Command run in foreground, wait for it to complete before finishing the connection.
+ wg.Wait()
+ finished()
+
case "ack", ".ack":
h.handleAckCommand(argc, args)
finished()
@@ -427,17 +461,28 @@ func (h *ServerHandler) decrementActiveCommands() int {
return h.activeCommands
}
-func contains(haystack []string, needle string) bool {
- for _, str := range haystack {
- if str == needle {
- return true
+func readOptions(opts []string) (map[string]string, error) {
+ options := make(map[string]string, len(opts))
+
+ for _, o := range opts {
+ kv := strings.SplitN(o, "=", 2)
+ if len(kv) != 2 {
+ return options, fmt.Errorf("Unable to parse options: %v", kv)
}
+ key := kv[0]
+ val := kv[1]
+
+ if strings.HasPrefix(val, "base64%") {
+ s := strings.SplitN(val, "%", 2)
+ decoded, err := base64.StdEncoding.DecodeString(s[1])
+ if err != nil {
+ return options, err
+ }
+ val = string(decoded)
+ }
+
+ options[key] = val
}
- return false
-}
-func hash(str string) string {
- h := sha256.New()
- h.Write([]byte(str))
- return hex.EncodeToString(h.Sum(nil))
+ return options, nil
}