summaryrefslogtreecommitdiff
path: root/internal/server
diff options
context:
space:
mode:
Diffstat (limited to 'internal/server')
-rw-r--r--internal/server/handlers/controlhandler.go4
-rw-r--r--internal/server/handlers/readcommand.go29
-rw-r--r--internal/server/handlers/runcommand.go111
-rw-r--r--internal/server/handlers/serverhandler.go70
-rw-r--r--internal/server/server.go24
5 files changed, 56 insertions, 182 deletions
diff --git a/internal/server/handlers/controlhandler.go b/internal/server/handlers/controlhandler.go
index 8cc5a40..1e17c78 100644
--- a/internal/server/handlers/controlhandler.go
+++ b/internal/server/handlers/controlhandler.go
@@ -92,9 +92,7 @@ func (h *ControlHandler) handleCommand(command string) {
case "health":
h.serverMessages <- "OK: DTail SSH Server seems fine"
h.serverMessages <- "done;"
- case "debug":
- h.serverMessages <- logger.Debug(h.user, "Receiving debug command", command, s)
default:
- h.serverMessages <- logger.Warn(h.user, "Received unknown control command", command, s)
+ h.serverMessages <- logger.Error(h.user, "Received unknown control command", command, s)
}
}
diff --git a/internal/server/handlers/readcommand.go b/internal/server/handlers/readcommand.go
index 0f9207d..5bab26f 100644
--- a/internal/server/handlers/readcommand.go
+++ b/internal/server/handlers/readcommand.go
@@ -25,37 +25,29 @@ func newReadCommand(server *ServerHandler, mode omode.Mode) *readCommand {
}
}
-func (r *readCommand) Start(ctx context.Context, argc int, args []string) {
+func (r *readCommand) Start(ctx context.Context, argc int, args []string, retries int) {
re := regex.NewNoop()
if argc >= 4 {
deserializedRegex, err := regex.Deserialize(strings.Join(args[2:], " "))
if err != nil {
- logger.Error(err)
r.server.sendServerMessage(logger.Error(r.server.user, commandParseWarning, err))
return
}
re = deserializedRegex
}
if argc < 3 {
- r.server.sendServerMessage(logger.Warn(r.server.user, commandParseWarning, args, argc))
+ r.server.sendServerWarnMessage(logger.Warn(r.server.user, commandParseWarning, args, argc))
return
}
- r.readGlob(ctx, args[1], re)
+ r.readGlob(ctx, args[1], re, retries)
}
-func (r *readCommand) readGlob(ctx context.Context, glob string, re regex.Regex) {
+func (r *readCommand) readGlob(ctx context.Context, glob string, re regex.Regex, retries int) {
retryInterval := time.Second * 5
glob = filepath.Clean(glob)
- maxRetries := 10
- for {
- maxRetries--
- if maxRetries < 0 {
- r.server.sendServerMessage(logger.Warn(r.server.user, "Giving up to read file(s)"))
- return
- }
-
+ for retryCount := 0; retryCount < retries; retryCount++ {
paths, err := filepath.Glob(glob)
if err != nil {
logger.Warn(r.server.user, glob, err)
@@ -65,7 +57,7 @@ func (r *readCommand) readGlob(ctx context.Context, glob string, re regex.Regex)
if numPaths := len(paths); numPaths == 0 {
logger.Error(r.server.user, "No such file(s) to read", glob)
- r.server.sendServerMessage(logger.Warn(r.server.user, "Unable to read file(s), check server logs"))
+ r.server.sendServerWarnMessage(logger.Warn(r.server.user, "Unable to read file(s), check server logs"))
select {
case <-ctx.Done():
return
@@ -76,8 +68,11 @@ func (r *readCommand) readGlob(ctx context.Context, glob string, re regex.Regex)
}
r.readFiles(ctx, paths, glob, re, retryInterval)
- break
+ return
}
+
+ r.server.sendServerWarnMessage(logger.Warn(r.server.user, "Giving up to read file(s)"))
+ return
}
func (r *readCommand) readFiles(ctx context.Context, paths []string, glob string, re regex.Regex, retryInterval time.Duration) {
@@ -97,7 +92,7 @@ func (r *readCommand) readFileIfPermissions(ctx context.Context, wg *sync.WaitGr
if !r.server.user.HasFilePermission(path, "readfiles") {
logger.Error(r.server.user, "No permission to read file", path, globID)
- r.server.sendServerMessage(logger.Warn(r.server.user, "Unable to read file(s), check server logs"))
+ r.server.sendServerWarnMessage(logger.Warn(r.server.user, "Unable to read file(s), check server logs"))
return
}
@@ -161,6 +156,6 @@ func (r *readCommand) makeGlobID(path, glob string) string {
return pathParts[len(pathParts)-1]
}
- r.server.sendServerMessage(logger.Error("Empty file path given?", path, glob))
+ r.server.sendServerWarnMessage(logger.Warn("Empty file path given?", path, glob))
return ""
}
diff --git a/internal/server/handlers/runcommand.go b/internal/server/handlers/runcommand.go
deleted file mode 100644
index 8e5895b..0000000
--- a/internal/server/handlers/runcommand.go
+++ /dev/null
@@ -1,111 +0,0 @@
-package handlers
-
-import (
- "context"
- "errors"
- "fmt"
- "io/ioutil"
- "os"
- "os/exec"
- "strings"
- "sync"
- "time"
-
- "github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/io/logger"
- "github.com/mimecast/dtail/internal/io/run"
-)
-
-type runCommand struct {
- server *ServerHandler
- run run.Run
-}
-
-func newRunCommand(server *ServerHandler) runCommand {
- return runCommand{
- server: server,
- }
-}
-
-func (r runCommand) StartBackground(ctx context.Context, wg *sync.WaitGroup, argc int, args, outerArgs []string) error {
- if argc < 2 {
- return fmt.Errorf("%s: args:%v argc:%d", commandParseWarning, args, argc)
- }
-
- ec := make(chan int, 1)
- var pid int
- var err error
-
- command := strings.Join(args[1:], " ")
- if strings.Contains(command, ";") || strings.Contains(command, "\n") {
- if pid, err = r.startScript(ctx, wg, ec, command, outerArgs); err != nil {
- r.server.sendServerMessage(".run exitstatus 255")
- return err
- }
- return nil
- }
-
- if pid, err = r.start(ctx, wg, ec, strings.TrimSpace(command), outerArgs); err != nil {
- r.server.sendServerMessage(".run exitstatus 255")
- return err
- }
-
- exitCode := <-ec
- r.server.sendServerMessage(fmt.Sprintf(".run exitstatus %d", exitCode))
- r.server.sendServerMessage(logger.Info(fmt.Sprintf("Process %d exited with status %d", pid, exitCode)))
-
- return nil
-}
-
-func (r runCommand) startScript(ctx context.Context, wg *sync.WaitGroup, ec chan<- int, script string, outerArgs []string) (int, error) {
- if _, err := os.Stat(config.Common.TmpDir); os.IsNotExist(err) {
- return -1, err
- }
-
- timestamp := time.Now().UnixNano()
- scriptPath := fmt.Sprintf("%s/%s_%v.sh", config.Common.TmpDir, r.server.user.Name, timestamp)
-
- // TODO: On dserver startup delete all previously written scripts (there might be left overs due to a crash or so)
- logger.Debug(r.server.user, "Writing temp script", scriptPath)
-
- script = fmt.Sprintf("#!/bin/sh\n%s", script)
- if err := ioutil.WriteFile(scriptPath, []byte(script), 0700); err != nil {
- return -1, err
- }
-
- pid, err := r.start(ctx, wg, ec, scriptPath, outerArgs)
- go func() {
- wg.Wait()
- logger.Debug("Deleting script", scriptPath)
- os.Remove(scriptPath)
- }()
-
- return pid, err
-}
-
-func (r runCommand) start(ctx context.Context, wg *sync.WaitGroup, ec chan<- int, command string, outerArgs []string) (int, error) {
- if len(command) == 0 {
- return -1, errors.New("Empty command provided")
- }
-
- splitted := strings.Split(command, " ")
- path := splitted[0]
- args := splitted[1:]
- args = append(args, outerArgs...)
-
- qualifiedPath, err := exec.LookPath(path)
- if err != nil {
- return -1, err
- }
-
- if !r.server.user.HasFilePermission(qualifiedPath, "runcommands") {
- return -1, fmt.Errorf("No permission to execute path: %s", qualifiedPath)
- }
-
- r.run = run.New(qualifiedPath, args)
- pid, err := r.run.StartBackground(ctx, wg, ec, r.server.lines)
- if err != nil {
- return pid, err
- }
- return pid, nil
-}
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go
index 5cf8041..185e7c2 100644
--- a/internal/server/handlers/serverhandler.go
+++ b/internal/server/handlers/serverhandler.go
@@ -29,36 +29,35 @@ const (
// the Bi-directional communication between SSH client and server.
// This handler implements the handler of the SSH server.
type ServerHandler struct {
- done *internal.Done
- lines chan line.Line
- regex string
- aggregate *server.Aggregate
- aggregatedMessages chan string
- serverMessages chan string
- payload []byte
- hostname string
- user *user.User
- catLimiter chan struct{}
- tailLimiter chan struct{}
- globalServerWaitFor chan struct{}
- ackCloseReceived chan struct{}
- activeCommands int32
- activeReaders int32
+ done *internal.Done
+ lines chan line.Line
+ regex string
+ aggregate *server.Aggregate
+ aggregatedMessages chan string
+ serverMessages chan string
+ payload []byte
+ hostname string
+ user *user.User
+ catLimiter chan struct{}
+ tailLimiter chan struct{}
+ ackCloseReceived chan struct{}
+ activeCommands int32
+ activeReaders int32
+ quiet bool
}
// NewServerHandler returns the server handler.
-func NewServerHandler(user *user.User, catLimiter, tailLimiter, globalServerWaitFor chan struct{}) *ServerHandler {
+func NewServerHandler(user *user.User, catLimiter, tailLimiter chan struct{}) *ServerHandler {
h := ServerHandler{
- done: internal.NewDone(),
- lines: make(chan line.Line, 100),
- serverMessages: make(chan string, 10),
- aggregatedMessages: make(chan string, 10),
- ackCloseReceived: make(chan struct{}),
- catLimiter: catLimiter,
- tailLimiter: tailLimiter,
- globalServerWaitFor: globalServerWaitFor,
- regex: ".",
- user: user,
+ done: internal.NewDone(),
+ lines: make(chan line.Line, 100),
+ serverMessages: make(chan string, 10),
+ aggregatedMessages: make(chan string, 10),
+ ackCloseReceived: make(chan struct{}),
+ catLimiter: catLimiter,
+ tailLimiter: tailLimiter,
+ regex: ".",
+ user: user,
}
fqdn, err := os.Hostname()
@@ -247,13 +246,19 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
commandFinished()
return
}
+ if quiet, ok := options["quiet"]; ok {
+ if quiet == "true" {
+ logger.Debug(h.user, "Enabling quiet mode")
+ h.quiet = true
+ }
+ }
switch commandName {
case "grep", "cat":
command := newReadCommand(h, omode.CatClient)
go func() {
h.incrementActiveReaders()
- command.Start(ctx, argc, args)
+ command.Start(ctx, argc, args, 1)
readerFinished()
commandFinished()
}()
@@ -262,7 +267,7 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
command := newReadCommand(h, omode.TailClient)
go func() {
h.incrementActiveReaders()
- command.Start(ctx, argc, args)
+ command.Start(ctx, argc, args, 10)
readerFinished()
commandFinished()
}()
@@ -294,7 +299,7 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
func (h *ServerHandler) handleAckCommand(argc int, args []string) {
if argc < 3 {
- h.sendServerMessage(logger.Warn(h.user, commandParseWarning, args, argc))
+ h.sendServerWarnMessage(logger.Warn(h.user, commandParseWarning, args, argc))
return
}
if args[1] == "close" && args[2] == "connection" {
@@ -313,6 +318,13 @@ func (h *ServerHandler) sendServerMessage(message string) {
h.send(h.serverMessageC(), message)
}
+func (h *ServerHandler) sendServerWarnMessage(message string) {
+ if h.quiet {
+ return
+ }
+ h.send(h.serverMessageC(), message)
+}
+
func (h *ServerHandler) serverMessageC() chan<- string {
return h.serverMessages
}
diff --git a/internal/server/server.go b/internal/server/server.go
index 31fa85d..a20737e 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -7,7 +7,6 @@ import (
"io"
"net"
"strings"
- "time"
"github.com/mimecast/dtail/internal/config"
"github.com/mimecast/dtail/internal/io/logger"
@@ -33,9 +32,6 @@ type Server struct {
sched *scheduler
// Mointor log files for pattern (if configured)
cont *continuous
- // Wait counter, e.g. there might be still subprocesses (forked by drun) to be killed.
- // TODO: Remove this counter.
- shutdownWaitFor chan struct{}
}
// New returns a new server.
@@ -46,7 +42,6 @@ func New() *Server {
sshServerConfig: &gossh.ServerConfig{},
catLimiter: make(chan struct{}, config.Server.MaxConcurrentCats),
tailLimiter: make(chan struct{}, config.Server.MaxConcurrentTails),
- shutdownWaitFor: make(chan struct{}, 1000),
sched: newScheduler(),
cont: newContinuous(),
}
@@ -80,27 +75,12 @@ func (s *Server) Start(ctx context.Context) int {
go s.cont.start(ctx)
go s.listenerLoop(ctx, listener)
- select {
- case <-ctx.Done():
- // Wait until all commands/jobs/children are no more!
- s.wait()
- }
+ <-ctx.Done()
// For future use.
return 0
}
-func (s *Server) wait() {
- for {
- num := len(s.shutdownWaitFor)
- logger.Debug("Waiting for stuff to finish", num)
- if num <= 0 {
- return
- }
- time.Sleep(time.Second)
- }
-}
-
func (s *Server) listenerLoop(ctx context.Context, listener net.Listener) {
logger.Debug("Starting listener loop")
@@ -180,7 +160,7 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-ch
case config.ControlUser:
handler = handlers.NewControlHandler(user)
default:
- handler = handlers.NewServerHandler(user, s.catLimiter, s.tailLimiter, s.shutdownWaitFor)
+ handler = handlers.NewServerHandler(user, s.catLimiter, s.tailLimiter)
}
terminate := func() {