summaryrefslogtreecommitdiff
path: root/internal/server
diff options
context:
space:
mode:
authorPaul Buetow <pbuetow@mimecast.com>2020-12-08 14:49:41 +0000
committerPaul Buetow <pbuetow@mimecast.com>2020-12-08 14:49:41 +0000
commit799b9b69ba08b898e13026b7ecab9f9f58580a82 (patch)
tree34bc0e5e539aed99dd1f13e7489e9d3111ba050f /internal/server
parent6b2d8539a66f1b36ffd55c56723376b9b068a5dc (diff)
merge develop
Diffstat (limited to 'internal/server')
-rw-r--r--internal/server/continuous.go2
-rw-r--r--internal/server/handlers/controlhandler.go28
-rw-r--r--internal/server/handlers/handler.go2
-rw-r--r--internal/server/handlers/serverhandler.go182
-rw-r--r--internal/server/scheduler.go2
-rw-r--r--internal/server/server.go44
6 files changed, 78 insertions, 182 deletions
diff --git a/internal/server/continuous.go b/internal/server/continuous.go
index 583d136..f75c732 100644
--- a/internal/server/continuous.go
+++ b/internal/server/continuous.go
@@ -92,7 +92,7 @@ func (c *continuous) runJob(ctx context.Context, job config.Continuous) {
}
logger.Info(fmt.Sprintf("Starting job %s", job.Name))
- status := client.Start(jobCtx, make(chan struct{}))
+ status := client.Start(jobCtx, make(chan string))
logMessage := fmt.Sprintf("Job exited with status %d", status)
if status != 0 {
diff --git a/internal/server/handlers/controlhandler.go b/internal/server/handlers/controlhandler.go
index daa9835..8cc5a40 100644
--- a/internal/server/handlers/controlhandler.go
+++ b/internal/server/handlers/controlhandler.go
@@ -1,20 +1,19 @@
package handlers
import (
- "context"
"fmt"
"io"
"os"
"strings"
+ "github.com/mimecast/dtail/internal"
"github.com/mimecast/dtail/internal/io/logger"
user "github.com/mimecast/dtail/internal/user/server"
)
// ControlHandler is used for control functions and health monitoring.
type ControlHandler struct {
- ctx context.Context
- done chan struct{}
+ done *internal.Done
hostname string
payload []byte
serverMessages chan string
@@ -22,12 +21,11 @@ type ControlHandler struct {
}
// NewControlHandler returns a new control handler.
-func NewControlHandler(ctx context.Context, user *user.User) (*ControlHandler, <-chan struct{}) {
+func NewControlHandler(user *user.User) *ControlHandler {
logger.Debug(user, "Creating control handler")
h := ControlHandler{
- ctx: ctx,
- done: make(chan struct{}),
+ done: internal.NewDone(),
serverMessages: make(chan string, 10),
user: user,
}
@@ -40,7 +38,17 @@ func NewControlHandler(ctx context.Context, user *user.User) (*ControlHandler, <
s := strings.Split(fqdn, ".")
h.hostname = s[0]
- return &h, h.done
+ return &h
+}
+
+// Shutdown the handler.
+func (h *ControlHandler) Shutdown() {
+ h.done.Shutdown()
+}
+
+// Done channel of the handler.
+func (h *ControlHandler) Done() <-chan struct{} {
+ return h.done.Done()
}
// Read is to send data to the client via the Reader interface.
@@ -51,7 +59,7 @@ func (h *ControlHandler) Read(p []byte) (n int, err error) {
wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message))
n = copy(p, wholePayload)
return
- case <-h.ctx.Done():
+ case <-h.done.Done():
return 0, io.EOF
}
}
@@ -63,7 +71,7 @@ func (h *ControlHandler) Write(p []byte) (n int, err error) {
switch c {
case ';':
wholePayload := strings.TrimSpace(string(h.payload))
- h.handleCommand(h.ctx, wholePayload)
+ h.handleCommand(wholePayload)
h.payload = nil
default:
@@ -75,7 +83,7 @@ func (h *ControlHandler) Write(p []byte) (n int, err error) {
return
}
-func (h *ControlHandler) handleCommand(ctx context.Context, command string) {
+func (h *ControlHandler) handleCommand(command string) {
logger.Info(h.user, command)
s := strings.Split(command, " ")
logger.Debug(h.user, "Receiving command", command, s)
diff --git a/internal/server/handlers/handler.go b/internal/server/handlers/handler.go
index c42ceb9..b04e854 100644
--- a/internal/server/handlers/handler.go
+++ b/internal/server/handlers/handler.go
@@ -5,4 +5,6 @@ import "io"
// Handler interface for server side functionality.
type Handler interface {
io.ReadWriter
+ Shutdown()
+ Done() <-chan struct{}
}
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go
index 7017f3e..5cf8041 100644
--- a/internal/server/handlers/serverhandler.go
+++ b/internal/server/handlers/serverhandler.go
@@ -7,18 +7,16 @@ import (
"fmt"
"io"
"os"
- "strconv"
"strings"
- "sync"
"sync/atomic"
"time"
+ "github.com/mimecast/dtail/internal"
"github.com/mimecast/dtail/internal/config"
"github.com/mimecast/dtail/internal/io/line"
"github.com/mimecast/dtail/internal/io/logger"
"github.com/mimecast/dtail/internal/mapr/server"
"github.com/mimecast/dtail/internal/omode"
- "github.com/mimecast/dtail/internal/server/background"
user "github.com/mimecast/dtail/internal/user/server"
"github.com/mimecast/dtail/internal/version"
)
@@ -31,33 +29,27 @@ const (
// the Bi-directional communication between SSH client and server.
// This handler implements the handler of the SSH server.
type ServerHandler struct {
- lines chan line.Line
- regex string
- aggregate *server.Aggregate
- aggregatedMessages chan string
- serverMessages chan string
- payload []byte
- hostname string
- user *user.User
- // TODO: Move all these channels into a separate struct for readability!
+ done *internal.Done
+ lines chan line.Line
+ regex string
+ aggregate *server.Aggregate
+ aggregatedMessages chan string
+ serverMessages chan string
+ payload []byte
+ hostname string
+ user *user.User
catLimiter chan struct{}
tailLimiter chan struct{}
globalServerWaitFor chan struct{}
ackCloseReceived chan struct{}
- serverCtx context.Context
- handlerCtx context.Context
- done chan struct{}
activeCommands int32
activeReaders int32
- background background.Background
}
// NewServerHandler returns the server handler.
-func NewServerHandler(handlerCtx, serverCtx context.Context, user *user.User, catLimiter, tailLimiter, globalServerWaitFor chan struct{}, background background.Background) (*ServerHandler, <-chan struct{}) {
+func NewServerHandler(user *user.User, catLimiter, tailLimiter, globalServerWaitFor chan struct{}) *ServerHandler {
h := ServerHandler{
- serverCtx: serverCtx,
- handlerCtx: handlerCtx,
- done: make(chan struct{}),
+ done: internal.NewDone(),
lines: make(chan line.Line, 100),
serverMessages: make(chan string, 10),
aggregatedMessages: make(chan string, 10),
@@ -67,7 +59,6 @@ func NewServerHandler(handlerCtx, serverCtx context.Context, user *user.User, ca
globalServerWaitFor: globalServerWaitFor,
regex: ".",
user: user,
- background: background,
}
fqdn, err := os.Hostname()
@@ -78,7 +69,17 @@ func NewServerHandler(handlerCtx, serverCtx context.Context, user *user.User, ca
s := strings.Split(fqdn, ".")
h.hostname = s[0]
- return &h, h.done
+ return &h
+}
+
+// Shutdown the handler.
+func (h *ServerHandler) Shutdown() {
+ h.done.Shutdown()
+}
+
+// Done channel of the handler.
+func (h *ServerHandler) Done() <-chan struct{} {
+ return h.done.Done()
}
// Read is to send data to the dtail client via Reader interface.
@@ -120,7 +121,7 @@ func (h *ServerHandler) Read(p []byte) (n int, err error) {
case <-time.After(time.Second):
// Once in a while check whether we are done.
select {
- case <-h.handlerCtx.Done():
+ case <-h.done.Done():
return 0, io.EOF
default:
}
@@ -134,7 +135,7 @@ func (h *ServerHandler) Write(p []byte) (n int, err error) {
switch c {
case ';':
commandStr := strings.TrimSpace(string(h.payload))
- h.handleCommand(h.handlerCtx, commandStr)
+ h.handleCommand(commandStr)
h.payload = nil
default:
h.payload = append(h.payload, c)
@@ -145,9 +146,9 @@ func (h *ServerHandler) Write(p []byte) (n int, err error) {
return
}
-func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) {
+func (h *ServerHandler) handleCommand(commandStr string) {
logger.Debug(h.user, commandStr)
- var timeout time.Duration
+ ctx := context.Background()
args, argc, err := h.handleProtocolVersion(strings.Split(commandStr, " "))
if err != nil {
@@ -161,30 +162,18 @@ func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) {
return
}
- args, argc, timeout, err = h.handleTimeout(args, argc)
- if err != nil {
- h.send(h.serverMessages, logger.Error(h.user, err))
- return
- }
-
if h.user.Name == config.ControlUser {
h.handleControlCommand(argc, args)
return
}
- if timeout > 0 {
- logger.Info(h.user, "Command with timeout context", argc, args, timeout)
- commandCtx, cancel := context.WithTimeout(ctx, timeout)
- go func() {
- <-commandCtx.Done()
- logger.Info(h.user, "Command timed out, canceling it", args, args, timeout)
- cancel()
- }()
- h.handleUserCommand(commandCtx, argc, args, timeout)
- return
- }
+ ctx, cancel := context.WithCancel(ctx)
+ go func() {
+ <-h.done.Done()
+ cancel()
+ }()
- h.handleUserCommand(ctx, argc, args, timeout)
+ h.handleUserCommand(ctx, argc, args)
}
func (h *ServerHandler) handleProtocolVersion(args []string) ([]string, int, error) {
@@ -222,16 +211,6 @@ func (h *ServerHandler) handleBase64(args []string, argc int) ([]string, int, er
return args, argc, nil
}
-func (h *ServerHandler) handleTimeout(args []string, argc int) ([]string, int, time.Duration, error) {
- if argc <= 2 || args[0] != "timeout" {
- // No timeout specified
- return args, argc, time.Duration(0) * time.Second, nil
- }
-
- timeout, err := strconv.Atoi(args[1])
- return args[2:], argc - 2, time.Duration(timeout) * time.Second, err
-}
-
func (h *ServerHandler) handleControlCommand(argc int, args []string) {
switch args[0] {
case "debug":
@@ -241,7 +220,7 @@ func (h *ServerHandler) handleControlCommand(argc int, args []string) {
}
}
-func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []string, timeout time.Duration) {
+func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []string) {
logger.Debug(h.user, "handleUserCommand", argc, args)
h.incrementActiveCommands()
@@ -255,7 +234,7 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
if h.aggregate == nil {
return
}
- h.aggregate.Cancel()
+ h.aggregate.Shutdown()
}
}
@@ -303,86 +282,6 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
commandFinished()
}()
- case "run":
- // TODO: Refactor this "run" case, move code to runcommand.go
- command := newRunCommand(h)
-
- jobName, _ := options["jobName"]
- logger.Debug(h.user, "run", options)
-
- if val, ok := options["background"]; ok && (val == "cancel" || val == "stop") {
- if err := h.background.Cancel(h.user.Name, jobName); err != nil {
- h.sendServerMessage(logger.Error(h.user, err, jobName, args))
- } else {
- h.sendServerMessage(logger.Info(h.user, "job cancelled", jobName))
- }
- commandFinished()
- return
- }
-
- if val, ok := options["background"]; ok && val == "list" {
- h.sendServerMessage("Listing jobs")
- count := 0
- for jobName := range h.background.ListJobsC(h.user.Name) {
- h.sendServerMessage(jobName)
- count++
- }
- h.sendServerMessage(fmt.Sprintf("Found %d jobs", count))
- commandFinished()
- return
- }
-
- str, _ := options["outerArgs"]
- outerArgs := strings.Split(str, " ")
-
- var background bool
- if val, ok := options["background"]; ok && val == "start" {
- background = true
- }
-
- var wg sync.WaitGroup
- wg.Add(1)
-
- if background {
- if timeout == 0 {
- // Set default background timeout.
- timeout = time.Hour * 1
- }
- // Use a new context based on the server context, so that background job does not get
- // terminated when handler/SSH connection terminates.
- commandCtx, cancel := context.WithTimeout(h.serverCtx, timeout)
-
- if err := h.background.Add(h.user.Name, jobName, cancel, &wg); err != nil {
- h.sendServerMessage(logger.Error(h.user, err, jobName, args))
- commandFinished()
- return
- }
- ctx = commandCtx
- }
-
- if err := command.StartBackground(ctx, &wg, argc, args, outerArgs); err != nil {
- h.sendServerMessage(logger.Error(h.user, "Unable to execute command", argc, args, err))
- commandFinished()
- return
- }
-
- // Make sure that server waits for all sub-processes to finish on shutdown
- go func() { h.globalServerWaitFor <- struct{}{} }()
- go func() {
- wg.Wait()
- <-h.globalServerWaitFor
- }()
-
- if background {
- h.sendServerMessage(logger.Info(h.user, jobName, "job started in background"))
- commandFinished()
- return
- }
-
- // Command run in foreground, wait for it to complete before finishing the connection.
- wg.Wait()
- commandFinished()
-
case "ack", ".ack":
h.handleAckCommand(argc, args)
commandFinished()
@@ -406,7 +305,7 @@ func (h *ServerHandler) handleAckCommand(argc int, args []string) {
func (h *ServerHandler) send(ch chan<- string, message string) {
select {
case ch <- message:
- case <-h.handlerCtx.Done():
+ case <-h.done.Done():
}
}
@@ -447,7 +346,7 @@ func (h *ServerHandler) shutdown() {
go func() {
select {
case h.serverMessageC() <- ".syn close connection":
- case <-h.handlerCtx.Done():
+ case <-h.done.Done():
}
}()
@@ -455,13 +354,10 @@ func (h *ServerHandler) shutdown() {
case <-h.ackCloseReceived:
case <-time.After(time.Second * 5):
logger.Debug(h.user, "Shutdown timeout reached, enforcing shutdown")
- case <-h.handlerCtx.Done():
+ case <-h.done.Done():
}
- select {
- case h.done <- struct{}{}:
- default:
- }
+ h.done.Shutdown()
}
func (h *ServerHandler) incrementActiveCommands() {
diff --git a/internal/server/scheduler.go b/internal/server/scheduler.go
index 9d76a3b..a1e9e36 100644
--- a/internal/server/scheduler.go
+++ b/internal/server/scheduler.go
@@ -93,7 +93,7 @@ func (s *scheduler) runJobs(ctx context.Context) {
defer cancel()
logger.Info(fmt.Sprintf("Starting job %s", job.Name))
- status := client.Start(jobCtx, make(chan struct{}))
+ status := client.Start(jobCtx, make(chan string))
logMessage := fmt.Sprintf("Job exited with status %d", status)
if status != 0 {
diff --git a/internal/server/server.go b/internal/server/server.go
index a446738..31fa85d 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -11,7 +11,6 @@ import (
"github.com/mimecast/dtail/internal/config"
"github.com/mimecast/dtail/internal/io/logger"
- "github.com/mimecast/dtail/internal/server/background"
"github.com/mimecast/dtail/internal/server/handlers"
"github.com/mimecast/dtail/internal/ssh/server"
user "github.com/mimecast/dtail/internal/user/server"
@@ -35,9 +34,8 @@ type Server struct {
// Mointor log files for pattern (if configured)
cont *continuous
// Wait counter, e.g. there might be still subprocesses (forked by drun) to be killed.
+ // TODO: Remove this counter.
shutdownWaitFor chan struct{}
- // Background jobs
- background background.Background
}
// New returns a new server.
@@ -51,7 +49,6 @@ func New() *Server {
shutdownWaitFor: make(chan struct{}, 1000),
sched: newScheduler(),
cont: newContinuous(),
- background: background.New(),
}
s.sshServerConfig.PasswordCallback = s.Callback
@@ -178,53 +175,46 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-ch
switch req.Type {
case "shell":
- handlerCtx, cancel := context.WithCancel(ctx)
-
var handler handlers.Handler
- var done <-chan struct{}
-
switch user.Name {
case config.ControlUser:
- handler, done = handlers.NewControlHandler(handlerCtx, user)
+ handler = handlers.NewControlHandler(user)
default:
- handler, done = handlers.NewServerHandler(handlerCtx, ctx, user, s.catLimiter, s.tailLimiter, s.shutdownWaitFor, s.background)
+ handler = handlers.NewServerHandler(user, s.catLimiter, s.tailLimiter, s.shutdownWaitFor)
}
- go func() {
- // Handler finished work, cancel all remaining routines
- defer cancel()
-
- <-done
- }()
+ terminate := func() {
+ handler.Shutdown()
+ sshConn.Close()
+ }
go func() {
// Broken pipe, cancel
- defer cancel()
-
io.Copy(channel, handler)
+ terminate()
}()
go func() {
// Broken pipe, cancel
- defer cancel()
-
io.Copy(handler, channel)
+ terminate()
}()
go func() {
- defer cancel()
+ select {
+ case <-ctx.Done():
+ case <-handler.Done():
+ }
+ terminate()
+ }()
+ go func() {
if err := sshConn.Wait(); err != nil && err != io.EOF {
logger.Error(user, err)
}
s.stats.decrementConnections()
logger.Info(user, "Good bye Mister!")
- }()
-
- go func() {
- <-handlerCtx.Done()
- sshConn.Close()
- logger.Info(user, "Closed SSH connection")
+ terminate()
}()
// Only serving shell type