summaryrefslogtreecommitdiff
path: root/internal/server
diff options
context:
space:
mode:
Diffstat (limited to 'internal/server')
-rw-r--r--internal/server/handlers/controlhandler.go42
-rw-r--r--internal/server/handlers/handler.go2
-rw-r--r--internal/server/handlers/mapcommand.go35
-rw-r--r--internal/server/handlers/readcommand.go158
-rw-r--r--internal/server/handlers/runcommand.go73
-rw-r--r--internal/server/handlers/serverhandler.go521
-rw-r--r--internal/server/server.go70
-rw-r--r--internal/server/stats.go10
8 files changed, 503 insertions, 408 deletions
diff --git a/internal/server/handlers/controlhandler.go b/internal/server/handlers/controlhandler.go
index 482f759..a33a78b 100644
--- a/internal/server/handlers/controlhandler.go
+++ b/internal/server/handlers/controlhandler.go
@@ -1,33 +1,34 @@
package handlers
import (
+ "context"
"fmt"
"io"
"os"
"strings"
- "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/io/logger"
user "github.com/mimecast/dtail/internal/user/server"
)
// ControlHandler is used for control functions and health monitoring.
type ControlHandler struct {
- serverMessages chan string
- pong chan struct{}
- stop chan struct{}
- payload []byte
+ ctx context.Context
+ done chan struct{}
hostname string
+ payload []byte
+ serverMessages chan string
user *user.User
}
// NewControlHandler returns a new control handler.
-func NewControlHandler(user *user.User) *ControlHandler {
+func NewControlHandler(ctx context.Context, user *user.User) (*ControlHandler, <-chan struct{}) {
logger.Debug(user, "Creating control handler")
h := ControlHandler{
+ ctx: ctx,
+ done: make(chan struct{}),
serverMessages: make(chan string, 10),
- pong: make(chan struct{}, 10),
- stop: make(chan struct{}),
user: user,
}
@@ -38,7 +39,8 @@ func NewControlHandler(user *user.User) *ControlHandler {
s := strings.Split(fqdn, ".")
h.hostname = s[0]
- return &h
+
+ return &h, h.done
}
// Read is to send data to the client via the Reader interface.
@@ -49,11 +51,7 @@ func (h *ControlHandler) Read(p []byte) (n int, err error) {
wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message))
n = copy(p, wholePayload)
return
- case <-h.pong:
- logger.Info(h.user, "Sending pong")
- n = copy(p, []byte(".pong\n"))
- return
- case <-h.stop:
+ case <-h.ctx.Done():
return 0, io.EOF
}
}
@@ -65,7 +63,7 @@ func (h *ControlHandler) Write(p []byte) (n int, err error) {
switch c {
case ';':
wholePayload := strings.TrimSpace(string(h.payload))
- h.handleCommand(wholePayload)
+ h.handleCommand(h.ctx, wholePayload)
h.payload = nil
default:
@@ -77,17 +75,7 @@ func (h *ControlHandler) Write(p []byte) (n int, err error) {
return
}
-// Close the control handler.
-func (h *ControlHandler) Close() {
- close(h.stop)
-}
-
-// Wait returns the handler stop channel.
-func (h *ControlHandler) Wait() <-chan struct{} {
- return h.stop
-}
-
-func (h *ControlHandler) handleCommand(command string) {
+func (h *ControlHandler) handleCommand(ctx context.Context, command string) {
logger.Info(h.user, command)
s := strings.Split(command, " ")
logger.Debug(h.user, "Receiving command", command, s)
@@ -96,8 +84,6 @@ func (h *ControlHandler) handleCommand(command string) {
case "health":
h.serverMessages <- "OK: DTail SSH Server seems fine"
h.serverMessages <- "done;"
- case "ping":
- h.pong <- struct{}{}
case "debug":
h.serverMessages <- logger.Debug(h.user, "Receiving debug command", command, s)
default:
diff --git a/internal/server/handlers/handler.go b/internal/server/handlers/handler.go
index 8b1f73e..c42ceb9 100644
--- a/internal/server/handlers/handler.go
+++ b/internal/server/handlers/handler.go
@@ -5,6 +5,4 @@ import "io"
// Handler interface for server side functionality.
type Handler interface {
io.ReadWriter
- Close()
- Wait() <-chan struct{}
}
diff --git a/internal/server/handlers/mapcommand.go b/internal/server/handlers/mapcommand.go
new file mode 100644
index 0000000..10372da
--- /dev/null
+++ b/internal/server/handlers/mapcommand.go
@@ -0,0 +1,35 @@
+package handlers
+
+import (
+ "context"
+ "strings"
+
+ "github.com/mimecast/dtail/internal/mapr/server"
+)
+
+// Map command implements the mapreduce command server side.
+type mapCommand struct {
+ aggregate *server.Aggregate
+ server *ServerHandler
+}
+
+// NewMapCommand returns a new server side mapreduce command.
+func newMapCommand(serverHandler *ServerHandler, argc int, args []string) (mapCommand, *server.Aggregate, error) {
+ mapCommand := mapCommand{
+ server: serverHandler,
+ }
+
+ queryStr := strings.Join(args[1:], " ")
+ aggregate, err := server.NewAggregate(queryStr)
+ if err != nil {
+ return mapCommand, nil, err
+ }
+
+ mapCommand.aggregate = aggregate
+ return mapCommand, aggregate, nil
+
+}
+
+func (m mapCommand) Start(ctx context.Context, aggregatedMessages chan<- string) {
+ m.aggregate.Start(ctx, aggregatedMessages)
+}
diff --git a/internal/server/handlers/readcommand.go b/internal/server/handlers/readcommand.go
new file mode 100644
index 0000000..e4079e8
--- /dev/null
+++ b/internal/server/handlers/readcommand.go
@@ -0,0 +1,158 @@
+package handlers
+
+import (
+ "context"
+ "path/filepath"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/mimecast/dtail/internal/io/fs"
+ "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/omode"
+)
+
+type readCommand struct {
+ server *ServerHandler
+ mode omode.Mode
+}
+
+func newReadCommand(server *ServerHandler, mode omode.Mode) *readCommand {
+ return &readCommand{
+ server: server,
+ mode: mode,
+ }
+}
+
+func (r *readCommand) Start(ctx context.Context, argc int, args []string) {
+ regex := "."
+ if argc >= 4 {
+ regex = args[3]
+ }
+ if argc < 3 {
+ r.server.sendServerMessage(logger.Warn(r.server.user, commandParseWarning, args, argc))
+ return
+ }
+ r.readGlob(ctx, args[1], regex)
+}
+
+func (r *readCommand) readGlob(ctx context.Context, glob string, regex string) {
+ retryInterval := time.Second * 5
+ glob = filepath.Clean(glob)
+
+ maxRetries := 10
+ for {
+ maxRetries--
+ if maxRetries < 0 {
+ r.server.sendServerMessage(logger.Warn(r.server.user, "Giving up to read file(s)"))
+ return
+ }
+
+ paths, err := filepath.Glob(glob)
+ if err != nil {
+ logger.Warn(r.server.user, glob, err)
+ time.Sleep(retryInterval)
+ continue
+ }
+
+ if numPaths := len(paths); numPaths == 0 {
+ logger.Error(r.server.user, "No such file(s) to read", glob)
+ r.server.sendServerMessage(logger.Warn(r.server.user, "Unable to read file(s), check server logs"))
+ select {
+ case <-ctx.Done():
+ return
+ default:
+ }
+ time.Sleep(retryInterval)
+ continue
+ }
+
+ r.readFiles(ctx, paths, glob, regex, retryInterval)
+ break
+ }
+}
+
+func (r *readCommand) readFiles(ctx context.Context, paths []string, glob string, regex string, retryInterval time.Duration) {
+ var wg sync.WaitGroup
+ wg.Add(len(paths))
+
+ for _, path := range paths {
+ go r.readFileIfPermissions(ctx, &wg, path, glob, regex)
+ }
+
+ wg.Wait()
+}
+
+func (r *readCommand) readFileIfPermissions(ctx context.Context, wg *sync.WaitGroup, path, glob, regex string) {
+ defer wg.Done()
+ globID := r.makeGlobID(path, glob)
+
+ if !r.server.user.HasFilePermission(path, "readfiles") {
+ logger.Error(r.server.user, "No permission to read file", path, globID)
+ r.server.sendServerMessage(logger.Warn(r.server.user, "Unable to read file(s), check server logs"))
+ return
+ }
+
+ r.readFile(ctx, path, globID, regex)
+}
+
+func (r *readCommand) readFile(ctx context.Context, path, globID, regex string) {
+ logger.Info(r.server.user, "Start reading file", path, globID)
+
+ var reader fs.FileReader
+ switch r.mode {
+ case omode.TailClient:
+ reader = fs.NewTailFile(path, globID, r.server.serverMessages, r.server.tailLimiter)
+ case omode.GrepClient, omode.CatClient:
+ reader = fs.NewCatFile(path, globID, r.server.serverMessages, r.server.catLimiter)
+ default:
+ reader = fs.NewTailFile(path, globID, r.server.serverMessages, r.server.tailLimiter)
+ }
+
+ lines := r.server.lines
+
+ // Plug in mappreduce engine
+ if r.server.aggregate != nil {
+ lines = r.server.aggregate.Lines
+ }
+
+ for {
+ if err := reader.Start(ctx, lines, regex); err != nil {
+ logger.Error(r.server.user, path, globID, err)
+ }
+
+ select {
+ case <-ctx.Done():
+ return
+ default:
+ if !reader.Retry() {
+ return
+ }
+ }
+
+ time.Sleep(time.Second * 2)
+ logger.Info(path, globID, "Reading file again")
+ }
+}
+
+func (r *readCommand) makeGlobID(path, glob string) string {
+ var idParts []string
+ pathParts := strings.Split(path, "/")
+
+ for i, globPart := range strings.Split(glob, "/") {
+ if strings.Contains(globPart, "*") {
+ idParts = append(idParts, pathParts[i])
+ }
+ }
+
+ if len(idParts) > 0 {
+ return strings.Join(idParts, "/")
+ }
+
+ if len(pathParts) > 0 {
+ return pathParts[len(pathParts)-1]
+ }
+
+ r.server.sendServerMessage(logger.Error("Empty file path given?", path, glob))
+ return ""
+}
diff --git a/internal/server/handlers/runcommand.go b/internal/server/handlers/runcommand.go
new file mode 100644
index 0000000..e260060
--- /dev/null
+++ b/internal/server/handlers/runcommand.go
@@ -0,0 +1,73 @@
+package handlers
+
+import (
+ "context"
+ "fmt"
+ "os/exec"
+ "strings"
+
+ "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/io/run"
+)
+
+type runCommand struct {
+ server *ServerHandler
+ run run.Run
+}
+
+func newRunCommand(server *ServerHandler) runCommand {
+ return runCommand{
+ server: server,
+ }
+}
+
+func (r runCommand) Start(ctx context.Context, argc int, args []string) {
+ if argc < 2 {
+ r.server.sendServerMessage(logger.Warn(r.server.user, commandParseWarning, args, argc))
+ return
+ }
+ commands := strings.Split(strings.Join(args[1:], " "), ";")
+ r.start(ctx, commands)
+}
+
+func (r runCommand) start(ctx context.Context, commands []string) {
+ for _, command := range commands {
+ command = strings.TrimSpace(command)
+ if len(command) == 0 {
+ continue
+ }
+ splitted := strings.Split(command, " ")
+ path := splitted[0]
+ args := splitted[1:]
+
+ qualifiedPath, err := exec.LookPath(path)
+ if err != nil {
+ logger.Error(r.server.user, err)
+ r.server.sendServerMessage(logger.Warn(r.server.user, "Unable to execute command(s), check server logs"))
+ r.server.sendServerMessage(fmt.Sprintf(".run exitstatus -%d", -1))
+ return
+ }
+
+ if !r.server.user.HasFilePermission(qualifiedPath, "runcommands") {
+ logger.Error(r.server.user, "No permission to execute path", qualifiedPath)
+ r.server.sendServerMessage(logger.Warn(r.server.user, "Unable to execute command(s), check server logs"))
+ r.server.sendServerMessage(fmt.Sprintf(".run exitstatus -%d", -1))
+ return
+ }
+
+ r.run = run.New(qualifiedPath, args)
+ pid, ec, err := r.run.Start(ctx, r.server.lines)
+
+ if err != nil {
+ message := fmt.Sprintf("Unable to execute remote command '%s'", command)
+ logger.Error(r.server.user, message, ec, pid, err)
+ r.server.sendServerMessage(logger.Error(message, ec, pid, err))
+ r.server.sendServerMessage(fmt.Sprintf(".run exitstatus -%d", ec))
+ return
+ }
+
+ message := fmt.Sprintf("Remote process '%d' exited with status '%d'", pid, ec)
+ r.server.sendServerMessage(fmt.Sprintf(".run exitstatus %d", ec))
+ r.server.sendServerMessage(logger.Info("run", pid, ec, message))
+ }
+}
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go
index bed8609..3f0d6ce 100644
--- a/internal/server/handlers/serverhandler.go
+++ b/internal/server/handlers/serverhandler.go
@@ -1,17 +1,19 @@
package handlers
import (
+ "context"
+ "encoding/base64"
+ "errors"
"fmt"
"io"
"os"
- "path/filepath"
"strings"
"sync"
"time"
"github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/fs"
- "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/io/line"
+ "github.com/mimecast/dtail/internal/io/logger"
"github.com/mimecast/dtail/internal/mapr/server"
"github.com/mimecast/dtail/internal/omode"
user "github.com/mimecast/dtail/internal/user/server"
@@ -26,51 +28,33 @@ const (
// the Bi-directional communication between SSH client and server.
// This handler implements the handler of the SSH server.
type ServerHandler struct {
- // Local log file readers
- fileReaders []fs.FileReader
- fileReadersMtx *sync.Mutex
- // Channel for read lines.
- lines chan fs.LineRead
- // Only process log lines matching this regex.
- regex string
- // Server side mapr log aggregation.
- aggregate *server.Aggregate
- // Channel of aggregated log lines.
+ mutex *sync.Mutex
+ lines chan line.Line
+ regex string
+ aggregate *server.Aggregate
aggregatedMessages chan string
- // Channel for server messages to be sent to the client.
- serverMessages chan string
- // Channel for hidden messages to be sent to the client.
- hiddenMessages chan string
- // The current payload sent to the client.
- payload []byte
- // The current server hostname.
- hostname string
- // The user connecting to dtail.
- user *user.User
- // To limit the server wide max amount of concurrent cats
- catLimiter chan struct{}
- // To limit the server wide max amount of concurrent tails
- tailLimiter chan struct{}
- // Server can tell handler to stop the handler.
- stop chan struct{}
- // Indicate that client responded to server with "ack stop connection"
- ackStopReceived chan struct{}
- // Stop timeout.
- stopTimeout chan struct{}
+ serverMessages chan string
+ payload []byte
+ hostname string
+ user *user.User
+ catLimiter chan struct{}
+ tailLimiter chan struct{}
+ ackCloseReceived chan struct{}
+ ctx context.Context
+ done chan struct{}
+ activeReaders int
}
// NewServerHandler returns the server handler.
-func NewServerHandler(user *user.User, catLimiter chan struct{}, tailLimiter chan struct{}) *ServerHandler {
- logger.Debug(user, "Creating tail handler")
+func NewServerHandler(ctx context.Context, user *user.User, catLimiter chan struct{}, tailLimiter chan struct{}) (*ServerHandler, <-chan struct{}) {
h := ServerHandler{
- fileReadersMtx: &sync.Mutex{},
- lines: make(chan fs.LineRead, 100),
+ ctx: ctx,
+ done: make(chan struct{}),
+ mutex: &sync.Mutex{},
+ lines: make(chan line.Line, 100),
serverMessages: make(chan string, 10),
aggregatedMessages: make(chan string, 10),
- hiddenMessages: make(chan string, 10),
- ackStopReceived: make(chan struct{}),
- stopTimeout: make(chan struct{}),
- stop: make(chan struct{}),
+ ackCloseReceived: make(chan struct{}),
catLimiter: catLimiter,
tailLimiter: tailLimiter,
regex: ".",
@@ -85,37 +69,46 @@ func NewServerHandler(user *user.User, catLimiter chan struct{}, tailLimiter cha
s := strings.Split(fqdn, ".")
h.hostname = s[0]
- return &h
+ return &h, h.done
}
// Read is to send data to the dtail client via Reader interface.
func (h *ServerHandler) Read(p []byte) (n int, err error) {
for {
select {
+
case message := <-h.serverMessages:
+ if message[0] == '.' {
+ // Handle hidden message (don't display to the user, interpreted by dtail client)
+ wholePayload := []byte(fmt.Sprintf("%s\n", message))
+ n = copy(p, wholePayload)
+ return
+ }
+
+ // Handle normal server message (display to the user)
wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message))
n = copy(p, wholePayload)
return
+
case message := <-h.aggregatedMessages:
+ // Send mapreduce-aggregated data as a message.
data := fmt.Sprintf("AGGREGATE|%s|%s\n", h.hostname, message)
- //logger.Debug("Sending aggregation data", data)
wholePayload := []byte(data)
n = copy(p, wholePayload)
return
- case message := <-h.hiddenMessages:
- //logger.Debug(h.user, "Sending hidden message", message)
- wholePayload := []byte(fmt.Sprintf(".%s\n", message))
- n = copy(p, wholePayload)
- return
+
case line := <-h.lines:
+ // Send normal file content data as a message.
serverInfo := []byte(fmt.Sprintf("REMOTE|%s|%3d|%v|%s|",
- h.hostname, line.TransmittedPerc, line.Count, *line.GlobID))
+ h.hostname, line.TransmittedPerc, line.Count, line.SourceID))
wholePayload := append(serverInfo, line.Content[:]...)
n = copy(p, wholePayload)
return
+
case <-time.After(time.Second):
+ // Once in a while check whether we are done.
select {
- case <-h.stop:
+ case <-h.ctx.Done():
return 0, io.EOF
default:
}
@@ -129,7 +122,7 @@ func (h *ServerHandler) Write(p []byte) (n int, err error) {
switch c {
case ';':
commandStr := strings.TrimSpace(string(h.payload))
- h.handleCommand(commandStr)
+ h.handleCommand(h.ctx, commandStr)
h.payload = nil
default:
h.payload = append(h.payload, c)
@@ -140,210 +133,167 @@ func (h *ServerHandler) Write(p []byte) (n int, err error) {
return
}
-// Close the server handler.
-func (h *ServerHandler) Close() {
- h.fileReadersMtx.Lock()
- defer h.fileReadersMtx.Unlock()
+func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) {
+ logger.Debug(h.user, commandStr)
- for _, reader := range h.fileReaders {
- reader.Stop()
+ args, argc, err := h.handleProtocolVersion(strings.Split(commandStr, " "))
+ if err != nil {
+ h.send(h.serverMessages, logger.Error(h.user, err))
+ return
}
- if h.aggregate != nil {
- h.aggregate.Close()
+
+ args, argc, err = h.handleBase64(args, argc)
+ if err != nil {
+ h.send(h.serverMessages, logger.Error(h.user, err))
+ return
}
- close(h.stop)
-}
+ if h.user.Name == config.ControlUser {
+ h.handleControlCommand(argc, args)
+ return
+ }
-func (h *ServerHandler) makeGlobID(path, glob string) string {
- var idParts []string
- pathParts := strings.Split(path, "/")
+ h.handleUserCommand(ctx, argc, args)
+}
- for i, globPart := range strings.Split(glob, "/") {
- if strings.Contains(globPart, "*") {
- idParts = append(idParts, pathParts[i])
- }
- }
+func (h *ServerHandler) handleProtocolVersion(args []string) ([]string, int, error) {
+ argc := len(args)
- if len(idParts) > 0 {
- return strings.Join(idParts, "/")
+ if argc <= 2 || args[0] != "protocol" {
+ return args, argc, errors.New("unable to determine protocol version")
}
- if len(pathParts) > 0 {
- return pathParts[len(pathParts)-1]
+ if args[1] != version.ProtocolCompat {
+ err := fmt.Errorf("server with protool version '%s' but client with '%s', please update DTail", version.ProtocolCompat, args[1])
+ return args, argc, err
}
- h.send(h.serverMessages, logger.Error("Empty file path given?", path, glob))
- return ""
+ return args[2:], argc - 2, nil
}
-func (h *ServerHandler) processFileGlob(mode omode.Mode, glob string, regex string) {
- retryInterval := time.Second * 5
- glob = filepath.Clean(glob)
-
- errors := make(chan struct{})
- stop := make(chan struct{})
- defer close(stop)
+func (h *ServerHandler) handleBase64(args []string, argc int) ([]string, int, error) {
+ err := errors.New("Unable to decode client message")
- go func() {
- for {
- select {
- case <-errors:
- h.send(h.serverMessages, logger.Warn(h.user, "Unable to read file(s), check server logs"))
- case <-stop:
- return
- case <-h.stop:
- return
- }
- }
- }()
+ if argc != 2 || args[0] != "base64" {
+ return args, argc, err
+ }
- maxRetries := 10
- for {
- maxRetries--
- if maxRetries < 0 {
- h.send(h.serverMessages, logger.Warn(h.user, "Giving up to read file(s)"))
- h.internalClose()
- return
- }
+ decoded, err := base64.StdEncoding.DecodeString(args[1])
+ if err != nil {
+ return args, argc, err
+ }
+ decodedStr := string(decoded)
- paths, err := filepath.Glob(glob)
- if err != nil {
- logger.Warn(h.user, glob, err)
- time.Sleep(retryInterval)
- continue
- }
+ args = strings.Split(decodedStr, " ")
+ argc = len(decodedStr)
+ logger.Trace(h.user, "Base64 decoded received command", decodedStr, argc, args)
- if numPaths := len(paths); numPaths == 0 {
- logger.Error(h.user, "No such file(s) to read", glob)
- select {
- case errors <- struct{}{}:
- case <-h.stop:
- return
- default:
- }
- time.Sleep(retryInterval)
- continue
- }
+ return args, argc, nil
+}
- h.startReadingFiles(mode, paths, glob, regex, retryInterval, errors)
- break
+func (h *ServerHandler) handleControlCommand(argc int, args []string) {
+ switch args[0] {
+ case "debug":
+ h.send(h.serverMessages, logger.Debug(h.user, "Receiving debug command", argc, args))
+ default:
+ logger.Warn(h.user, "Received unknown command", argc, args)
}
}
-func (h *ServerHandler) startReadingFiles(mode omode.Mode, paths []string, glob string, regex string, retryInterval time.Duration, errors chan<- struct{}) {
- var wg sync.WaitGroup
- wg.Add(len(paths))
+func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []string) {
+ logger.Debug(h.user, "handleUserCommand", argc, args)
- read := func(path string, wg *sync.WaitGroup) {
- defer wg.Done()
- globID := h.makeGlobID(path, glob)
+ switch args[0] {
+ case "grep", "cat":
+ command := newReadCommand(h, omode.CatClient)
+ h.incrementActiveReaders()
+ go func() {
+ command.Start(ctx, argc, args)
+ if h.decrementActiveReaders() == 0 {
+ h.shutdown()
+ }
+ }()
- if !h.user.HasFilePermission(path) {
- logger.Error(h.user, "No permission to read file", path, globID)
- select {
- case errors <- struct{}{}:
- default:
+ case "tail":
+ command := newReadCommand(h, omode.TailClient)
+ h.incrementActiveReaders()
+ go func() {
+ command.Start(ctx, argc, args)
+ if h.decrementActiveReaders() == 0 {
+ h.shutdown()
}
+ }()
+
+ case "map":
+ command, aggregate, err := newMapCommand(h, argc, args)
+ if err != nil {
+ h.sendServerMessage(err.Error())
+ logger.Error(h.user, err)
return
}
- h.startReadingFile(mode, path, globID, regex)
- }
-
- for _, path := range paths {
- go read(path, &wg)
- }
+ h.aggregate = aggregate
+ go func() {
+ command.Start(ctx, h.aggregatedMessages)
+ h.shutdown()
+ }()
+
+ case "run":
+ command := newRunCommand(h)
+ h.incrementActiveReaders()
+ go func() {
+ command.Start(ctx, argc, args)
+ if h.decrementActiveReaders() == 0 {
+ h.shutdown()
+ }
+ }()
- wg.Wait()
-}
+ case "ack", ".ack":
+ h.handleAckCommand(argc, args)
-func (h *ServerHandler) startReadingFile(mode omode.Mode, path, globID, regex string) {
- defer h.stopReadingFile(path)
- logger.Info(h.user, "Start reading file", path, globID)
-
- var reader fs.FileReader
- switch mode {
- case omode.TailClient:
- reader = fs.NewTailFile(path, globID, h.serverMessages, h.tailLimiter)
- case omode.GrepClient:
- fallthrough
- case omode.CatClient:
- reader = fs.NewCatFile(path, globID, h.serverMessages, h.catLimiter)
default:
- reader = fs.NewTailFile(path, globID, h.serverMessages, h.tailLimiter)
+ h.sendServerMessage(logger.Error(h.user, "Received unknown command", argc, args))
}
+}
- h.fileReadersMtx.Lock()
- h.fileReaders = append(h.fileReaders, reader)
- h.fileReadersMtx.Unlock()
-
- lines := h.lines
- // Plugin mappreduce engine
- if h.aggregate != nil {
- lines = h.aggregate.Lines
+func (h *ServerHandler) handleAckCommand(argc int, args []string) {
+ if argc < 3 {
+ h.sendServerMessage(logger.Warn(h.user, commandParseWarning, args, argc))
+ return
}
-
- for {
- if err := reader.Start(lines, regex); err != nil {
- logger.Error(h.user, path, globID, err)
- }
-
- select {
- case <-h.stop:
- return
- default:
- if !reader.Retry() {
- return
- }
- }
-
- time.Sleep(time.Second * 2)
- logger.Info(path, globID, "Reading file again")
+ if args[1] == "close" && args[2] == "connection" {
+ close(h.ackCloseReceived)
}
}
-func (h *ServerHandler) stopReadingFile(path string) {
- logger.Info(h.user, "Stop reading file", path)
+func (h *ServerHandler) send(ch chan<- string, message string) {
+ select {
+ case ch <- message:
+ case <-h.ctx.Done():
+ }
+}
- h.fileReadersMtx.Lock()
- defer h.fileReadersMtx.Unlock()
+func (h *ServerHandler) sendServerMessage(message string) {
+ h.send(h.serverMessageC(), message)
+}
- path = filepath.Clean(path)
- var fileReaders []fs.FileReader
+func (h *ServerHandler) serverMessageC() chan<- string {
+ return h.serverMessages
+}
- for _, reader := range h.fileReaders {
- if reader.FilePath() == path {
- reader.Stop()
- continue
- }
- fileReaders = append(fileReaders, reader)
- }
+func (h *ServerHandler) flush() {
+ logger.Debug(h.user, "flush()")
- if len(fileReaders) == len(h.fileReaders) {
- logger.Warn(h.user, "Didn't read file path", path)
- return
+ if h.aggregate != nil {
+ h.aggregate.Flush()
}
- h.fileReaders = fileReaders
-
- if len(fileReaders) == 0 {
- if h.aggregate != nil {
- h.aggregate.Serialize()
- }
- h.allLinesSent()
+ unsentMessages := func() int {
+ return len(h.lines) + len(h.serverMessages) + len(h.aggregatedMessages)
}
-}
-
-func (h *ServerHandler) numUnsentMessages() int {
- return len(h.lines) + len(h.serverMessages) + len(h.hiddenMessages) + len(h.aggregatedMessages)
-}
-
-func (h *ServerHandler) allLinesSent() {
- defer h.internalClose()
for i := 0; i < 3; i++ {
- if h.numUnsentMessages() == 0 {
+ if unsentMessages() == 0 {
logger.Debug(h.user, "All lines sent")
return
}
@@ -351,142 +301,43 @@ func (h *ServerHandler) allLinesSent() {
time.Sleep(time.Second)
}
- logger.Warn(h.user, "Some lines remain unsent", h.numUnsentMessages())
+ logger.Warn(h.user, "Some lines remain unsent", unsentMessages())
}
-// Handler decides to shutdown the connection, not the server itself.
-func (h *ServerHandler) internalClose() {
- select {
- case h.hiddenMessages <- "syn close connection":
- case <-time.After(time.Second * 5):
- logger.Debug(h.user, "Not waiting for ack close connection")
- close(h.stopTimeout)
- return
- }
+func (h *ServerHandler) shutdown() {
+ logger.Debug(h.user, "shutdown()")
+ h.flush()
+
+ go func() {
+ select {
+ case h.serverMessageC() <- ".syn close connection":
+ case <-h.ctx.Done():
+ }
+ }()
select {
- case <-h.Wait():
+ case <-h.ackCloseReceived:
case <-time.After(time.Second * 5):
- logger.Debug(h.user, "Not waiting for ack close connection")
- close(h.stopTimeout)
- }
-}
-
-func (h *ServerHandler) handleCommand(commandStr string) {
- logger.Info(h.user, commandStr)
-
- args := strings.Split(commandStr, " ")
- argc := len(args)
-
- logger.Debug(h.user, "Received command", commandStr, argc, args)
-
- if h.user.Name == config.ControlUser {
- h.handleControlCommand(argc, args)
- return
+ logger.Debug(h.user, "Shutdown timeout reached, enforcing shutdown")
+ case <-h.ctx.Done():
}
- h.handleUserCommand(argc, args)
-}
-
-// Special (restricted) set of commands for anonymous ControlUser access.
-func (h *ServerHandler) handleControlCommand(argc int, args []string) {
- switch args[0] {
- case "ping":
- h.send(h.hiddenMessages, "pong")
- case "debug":
- h.send(h.serverMessages, logger.Debug(h.user, "Receiving debug command", argc, args))
- default:
- logger.Warn(h.user, "Received unknown command", argc, args)
- }
-}
-
-// Commands for authed users.
-func (h *ServerHandler) handleUserCommand(argc int, args []string) {
- switch args[0] {
- case "grep":
- fallthrough
- case "cat":
- h.handleReadCommand(argc, args, omode.CatClient)
- case "tail":
- h.handleReadCommand(argc, args, omode.TailClient)
- case "map":
- h.handleMapCommand(argc, args)
- case "ack":
- h.handleAckCommand(argc, args)
- case "ping":
- h.send(h.hiddenMessages, "pong")
- case "version":
- h.send(h.serverMessages, fmt.Sprintf("Server version is "+version.String()))
- case "debug":
- h.send(h.serverMessages, logger.Debug(h.user, "Received debug command", argc, args))
+ select {
+ case h.done <- struct{}{}:
default:
- h.send(h.serverMessages, logger.Warn(h.user, "Received unknown command", argc, args))
}
}
-func (h *ServerHandler) handleReadCommand(argc int, args []string, mode omode.Mode) {
- regex := "."
- if argc >= 4 {
- regex = args[3]
- }
- if argc < 3 {
- h.send(h.serverMessages, logger.Warn(h.user, commandParseWarning, args, argc))
- return
- }
- go h.processFileGlob(mode, args[1], regex)
+func (h *ServerHandler) incrementActiveReaders() {
+ // TODO: Use atomic counter variable instead, so we can get rid of the mutex
+ h.mutex.Lock()
+ defer h.mutex.Unlock()
+ h.activeReaders++
}
+func (h *ServerHandler) decrementActiveReaders() int {
+ h.mutex.Lock()
+ defer h.mutex.Unlock()
+ h.activeReaders--
-func (h *ServerHandler) handleMapCommand(argc int, args []string) {
- if argc < 2 {
- h.send(h.serverMessages, logger.Warn(h.user, commandParseWarning, args, argc))
- return
- }
-
- queryStr := strings.Join(args[1:], " ")
- logger.Info(h.user, "Creating new mapr aggregator", queryStr)
- aggregate, err := server.NewAggregate(h.aggregatedMessages, queryStr)
-
- if err != nil {
- h.send(h.serverMessages, logger.Error(h.user, err))
- return
- }
-
- h.aggregate = aggregate
-}
-
-func (h *ServerHandler) handleAckCommand(argc int, args []string) {
- if argc < 3 {
- h.send(h.serverMessages, logger.Warn(h.user, commandParseWarning, args, argc))
- return
- }
- if args[1] == "close" && args[2] == "connection" {
- close(h.ackStopReceived)
- }
-}
-
-func (h *ServerHandler) send(ch chan<- string, message string) {
- select {
- case ch <- message:
- case <-h.stop:
- }
-}
-
-// Wait (block) until server handler is closed or a timeout has exceeded.
-func (h *ServerHandler) Wait() <-chan struct{} {
- wait := make(chan struct{})
-
- go func() {
- select {
- case <-h.ackStopReceived:
- logger.Debug(h.user, "Closing wait channel due to ACK stop received")
- close(wait)
- case <-h.stopTimeout:
- logger.Debug(h.user, "Closing wait channel due to wait timeout")
- close(wait)
- case <-h.stop:
- logger.Debug(h.user, "Closing wait channel due to stop")
- }
- }()
-
- return wait
+ return h.activeReaders
}
diff --git a/internal/server/server.go b/internal/server/server.go
index 27a98f5..42eb74c 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -1,13 +1,14 @@
package server
import (
+ "context"
"errors"
"fmt"
"io"
"net"
"github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/io/logger"
"github.com/mimecast/dtail/internal/server/handlers"
"github.com/mimecast/dtail/internal/ssh/server"
user "github.com/mimecast/dtail/internal/user/server"
@@ -26,8 +27,6 @@ type Server struct {
catLimiterCh chan struct{}
// To control the max amount of concurrent tails
tailLimiterCh chan struct{}
- // Ask to shutdown the server
- stop chan struct{}
}
// New returns a new server.
@@ -38,7 +37,6 @@ func New() *Server {
sshServerConfig: &gossh.ServerConfig{},
catLimiterCh: make(chan struct{}, config.Server.MaxConcurrentCats),
tailLimiterCh: make(chan struct{}, config.Server.MaxConcurrentTails),
- stop: make(chan struct{}),
}
s.sshServerConfig.PasswordCallback = s.controlUserCallback
@@ -54,7 +52,7 @@ func New() *Server {
}
// Start the server.
-func (s *Server) Start() int {
+func (s *Server) Start(ctx context.Context) int {
logger.Info("Starting server")
bindAt := fmt.Sprintf("%s:%d", config.Server.SSHBindAddress, config.Common.SSHPort)
@@ -64,7 +62,7 @@ func (s *Server) Start() int {
logger.FatalExit("Failed to open listening TCP socket", err)
}
- go s.stats.periodicLogServerStats(s.stop)
+ go s.stats.periodicLogServerStats(ctx)
for {
conn, err := listener.Accept() // Blocking
@@ -79,11 +77,11 @@ func (s *Server) Start() int {
continue
}
- go s.handleConnection(conn)
+ go s.handleConnection(ctx, conn)
}
}
-func (s *Server) handleConnection(conn net.Conn) {
+func (s *Server) handleConnection(ctx context.Context, conn net.Conn) {
logger.Info("Handling connection")
sshConn, chans, reqs, err := gossh.NewServerConn(conn, s.sshServerConfig)
@@ -96,11 +94,11 @@ func (s *Server) handleConnection(conn net.Conn) {
go gossh.DiscardRequests(reqs)
for newChannel := range chans {
- go s.handleChannel(sshConn, newChannel)
+ go s.handleChannel(ctx, sshConn, newChannel)
}
}
-func (s *Server) handleChannel(sshConn gossh.Conn, newChannel gossh.NewChannel) {
+func (s *Server) handleChannel(ctx context.Context, sshConn gossh.Conn, newChannel gossh.NewChannel) {
user := user.New(sshConn.User(), sshConn.RemoteAddr().String())
logger.Info(user, "Invoking channel handler")
@@ -117,13 +115,13 @@ func (s *Server) handleChannel(sshConn gossh.Conn, newChannel gossh.NewChannel)
return
}
- if err := s.handleRequests(sshConn, requests, channel, user); err != nil {
+ if err := s.handleRequests(ctx, sshConn, requests, channel, user); err != nil {
logger.Error(user, err)
sshConn.Close()
}
}
-func (s *Server) handleRequests(sshConn gossh.Conn, in <-chan *gossh.Request, channel gossh.Channel, user *user.User) error {
+func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-chan *gossh.Request, channel gossh.Channel, user *user.User) error {
logger.Info(user, "Invoking request handler")
for req := range in {
@@ -132,50 +130,50 @@ func (s *Server) handleRequests(sshConn gossh.Conn, in <-chan *gossh.Request, ch
switch req.Type {
case "shell":
+ handlerCtx, cancel := context.WithCancel(ctx)
+
var handler handlers.Handler
+ var done <-chan struct{}
+
switch user.Name {
case config.ControlUser:
- handler = handlers.NewControlHandler(user)
+ handler, done = handlers.NewControlHandler(handlerCtx, user)
default:
- handler = handlers.NewServerHandler(user, s.catLimiterCh, s.tailLimiterCh)
+ handler, done = handlers.NewServerHandler(handlerCtx, user, s.catLimiterCh, s.tailLimiterCh)
}
- // Bi-directionally connect SSH stream to SSH handler
- brokenPipe1 := make(chan struct{})
go func() {
- defer close(brokenPipe1)
+ // Handler finished work, cancel all remaining routines
+ defer cancel()
+ <-done
+ }()
+
+ go func() {
+ // Broken pipe, cancel
+ defer cancel()
+
io.Copy(channel, handler)
}()
- brokenPipe2 := make(chan struct{})
go func() {
- defer close(brokenPipe2)
+ // Broken pipe, cancel
+ defer cancel()
+
io.Copy(handler, channel)
}()
- // Ensure to close all fd's and stop all goroutines once ssh connection terminated
go func() {
- defer s.stats.decrementConnections()
- defer handler.Close()
+ defer cancel()
if err := sshConn.Wait(); err != nil && err != io.EOF {
logger.Error(user, err)
}
+ s.stats.decrementConnections()
logger.Info(user, "Good bye Mister!")
}()
- // Close the underlying ssh socket when server shuts down
go func() {
- select {
- case <-s.stop:
- logger.Debug(user, "Server initiating shutdown on handler")
- case <-handler.Wait():
- logger.Debug(user, "Handler initiating shutdown by its own")
- case <-brokenPipe1:
- logger.Debug(user, "Broken pipe1")
- case <-brokenPipe2:
- logger.Debug(user, "Broken pipe2")
- }
+ <-handlerCtx.Done()
sshConn.Close()
logger.Info(user, "Closed SSH connection")
}()
@@ -204,9 +202,3 @@ func (*Server) controlUserCallback(c gossh.ConnMetadata, authPayload []byte) (*g
return nil, fmt.Errorf("Not authorized")
}
-
-// Stop the server.
-func (s *Server) Stop() {
- close(s.stop)
- s.stats.waitForConnections()
-}
diff --git a/internal/server/stats.go b/internal/server/stats.go
index beb1885..4d661f7 100644
--- a/internal/server/stats.go
+++ b/internal/server/stats.go
@@ -1,12 +1,14 @@
package server
import (
- "github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/logger"
+ "context"
"fmt"
"runtime"
"sync"
"time"
+
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/io/logger"
)
// Used to collect and display various server stats.
@@ -65,12 +67,12 @@ func (s *stats) serverLimitExceeded() error {
return nil
}
-func (s *stats) periodicLogServerStats(stop <-chan struct{}) {
+func (s *stats) periodicLogServerStats(ctx context.Context) {
for {
select {
case <-time.NewTimer(time.Second * 10).C:
s.logServerStats()
- case <-stop:
+ case <-ctx.Done():
return
}
}