summaryrefslogtreecommitdiff
path: root/internal/server
diff options
context:
space:
mode:
authorPaul Buetow <pbuetow@mimecast.com>2021-10-21 21:28:49 +0300
committerPaul Buetow <pbuetow@mimecast.com>2021-10-21 21:28:49 +0300
commitf4207a55f71bfbcfdc532d5cdd3befaa3474a157 (patch)
treeea5e4a2d2a67035f645bdee496ae55a52034178a /internal/server
parentd80d6070557e3a800e3a54967af9eced518f116b (diff)
parent739205206d63bf42f4e843b39d04d4c8cd8207c3 (diff)
merge develop
Diffstat (limited to 'internal/server')
-rw-r--r--internal/server/continuous.go36
-rw-r--r--internal/server/handlers/basehandler.go320
-rw-r--r--internal/server/handlers/controlhandler.go98
-rw-r--r--internal/server/handlers/healthhandler.go58
-rw-r--r--internal/server/handlers/mapcommand.go7
-rw-r--r--internal/server/handlers/readcommand.go83
-rw-r--r--internal/server/handlers/serverhandler.go412
-rw-r--r--internal/server/scheduler.go37
-rw-r--r--internal/server/server.go110
-rw-r--r--internal/server/stats.go21
10 files changed, 557 insertions, 625 deletions
diff --git a/internal/server/continuous.go b/internal/server/continuous.go
index f75c732..93b3fcb 100644
--- a/internal/server/continuous.go
+++ b/internal/server/continuous.go
@@ -8,33 +8,29 @@ import (
"github.com/mimecast/dtail/internal/clients"
"github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/omode"
-
gossh "golang.org/x/crypto/ssh"
)
-type continuous struct {
-}
+type continuous struct{}
func newContinuous() *continuous {
return &continuous{}
}
func (c *continuous) start(ctx context.Context) {
- logger.Info("Starting continuous job runner after 10s")
+ dlog.Server.Info("Starting continuous job runner after 10s")
time.Sleep(time.Second * 10)
-
c.runJobs(ctx)
}
func (c *continuous) runJobs(ctx context.Context) {
for _, job := range config.Server.Continuous {
if !job.Enable {
- logger.Debug(job.Name, "Not running job as not enabled")
+ dlog.Server.Debug(job.Name, "Not running job as not enabled")
continue
}
-
go func(job config.Continuous) {
c.runJob(ctx, job)
for {
@@ -51,18 +47,17 @@ func (c *continuous) runJobs(ctx context.Context) {
}
func (c *continuous) runJob(ctx context.Context, job config.Continuous) {
- logger.Debug(job.Name, "Processing job")
+ dlog.Server.Debug(job.Name, "Processing job")
files := fillDates(job.Files)
outfile := fillDates(job.Outfile)
-
servers := strings.Join(job.Servers, ",")
if servers == "" {
servers = config.Server.SSHBindAddress
}
- args := clients.Args{
- ConnectionsPerCPU: 10,
+ args := config.Args{
+ ConnectionsPerCPU: config.DefaultConnectionsPerCPU,
Discovery: job.Discovery,
ServersStr: servers,
What: files,
@@ -71,35 +66,32 @@ func (c *continuous) runJob(ctx context.Context, job config.Continuous) {
}
args.SSHAuthMethods = append(args.SSHAuthMethods, gossh.Password(job.Name))
-
- query := fmt.Sprintf("%s outfile %s", job.Query, outfile)
- client, err := clients.NewMaprClient(args, query, clients.NonCumulativeMode)
+ args.QueryStr = fmt.Sprintf("%s outfile %s", job.Query, outfile)
+ client, err := clients.NewMaprClient(args, clients.NonCumulativeMode)
if err != nil {
- logger.Error(fmt.Sprintf("Unable to create job %s", job.Name), err)
+ dlog.Server.Error(fmt.Sprintf("Unable to create job %s", job.Name), err)
return
}
jobCtx, cancel := context.WithCancel(ctx)
defer cancel()
-
if job.RestartOnDayChange {
go func() {
if c.waitForDayChange(ctx) {
- logger.Info(fmt.Sprintf("Canceling job %s due to day change", job.Name))
+ dlog.Server.Info(fmt.Sprintf("Canceling job %s due to day change", job.Name))
cancel()
}
}()
}
- logger.Info(fmt.Sprintf("Starting job %s", job.Name))
+ dlog.Server.Info(fmt.Sprintf("Starting job %s", job.Name))
status := client.Start(jobCtx, make(chan string))
logMessage := fmt.Sprintf("Job exited with status %d", status)
-
if status != 0 {
- logger.Warn(logMessage)
+ dlog.Server.Warn(logMessage)
return
}
- logger.Info(logMessage)
+ dlog.Server.Info(logMessage)
}
func (c *continuous) waitForDayChange(ctx context.Context) bool {
diff --git a/internal/server/handlers/basehandler.go b/internal/server/handlers/basehandler.go
new file mode 100644
index 0000000..6d10d17
--- /dev/null
+++ b/internal/server/handlers/basehandler.go
@@ -0,0 +1,320 @@
+package handlers
+
+import (
+ "bytes"
+ "context"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "io"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/mimecast/dtail/internal"
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/io/dlog"
+ "github.com/mimecast/dtail/internal/io/line"
+ "github.com/mimecast/dtail/internal/io/pool"
+ "github.com/mimecast/dtail/internal/lcontext"
+ "github.com/mimecast/dtail/internal/mapr/server"
+ "github.com/mimecast/dtail/internal/protocol"
+ user "github.com/mimecast/dtail/internal/user/server"
+)
+
+type handleCommandCb func(context.Context, lcontext.LContext, int, []string, string)
+
+type baseHandler struct {
+ done *internal.Done
+ handleCommandCb handleCommandCb
+ lines chan line.Line
+ aggregate *server.Aggregate
+ maprMessages chan string
+ serverMessages chan string
+ hostname string
+ user *user.User
+ ackCloseReceived chan struct{}
+ activeCommands int32
+ readBuf bytes.Buffer
+ writeBuf bytes.Buffer
+
+ // Some global options + sync primitives required.
+ once sync.Once
+ mutex sync.Mutex
+ quiet bool
+ spartan bool
+ serverless bool
+}
+
+// Shutdown the handler.
+func (h *baseHandler) Shutdown() {
+ h.done.Shutdown()
+}
+
+// Done channel of the handler.
+func (h *baseHandler) Done() <-chan struct{} {
+ return h.done.Done()
+}
+
+// Read is to send data to the dtail client via Reader interface.
+func (h *baseHandler) Read(p []byte) (n int, err error) {
+ defer h.readBuf.Reset()
+
+ select {
+ case message := <-h.serverMessages:
+ if len(message) > 0 && message[0] == '.' {
+ // Handle hidden message (don't display to the user)
+ h.readBuf.WriteString(message)
+ h.readBuf.WriteByte(protocol.MessageDelimiter)
+ n = copy(p, h.readBuf.Bytes())
+ return
+ }
+
+ if h.serverless {
+ return
+ }
+
+ // Handle normal server message (display to the user)
+ h.readBuf.WriteString("SERVER")
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(h.hostname)
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(message)
+ h.readBuf.WriteByte(protocol.MessageDelimiter)
+ n = copy(p, h.readBuf.Bytes())
+
+ case message := <-h.maprMessages:
+ // Send mapreduce-aggregated data as a message.
+ h.readBuf.WriteString("AGGREGATE")
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(h.hostname)
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(message)
+ h.readBuf.WriteByte(protocol.MessageDelimiter)
+ n = copy(p, h.readBuf.Bytes())
+
+ case line := <-h.lines:
+ if !h.spartan {
+ h.readBuf.WriteString("REMOTE")
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(h.hostname)
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(fmt.Sprintf("%3d", line.TransmittedPerc))
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(fmt.Sprintf("%v", line.Count))
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(line.SourceID)
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ }
+ h.readBuf.WriteString(line.Content.String())
+ h.readBuf.WriteByte(protocol.MessageDelimiter)
+ n = copy(p, h.readBuf.Bytes())
+ pool.RecycleBytesBuffer(line.Content)
+
+ case <-time.After(time.Second):
+ // Once in a while check whether we are done.
+ select {
+ case <-h.done.Done():
+ err = io.EOF
+ return
+ default:
+ }
+ }
+ return
+}
+
+// Write is to receive data from the dtail client via Writer interface.
+func (h *baseHandler) Write(p []byte) (n int, err error) {
+ for _, b := range p {
+ switch b {
+ case ';':
+ h.handleCommand(string(h.writeBuf.Bytes()))
+ h.writeBuf.Reset()
+ default:
+ h.writeBuf.WriteByte(b)
+ }
+ }
+ n = len(p)
+ return
+}
+
+func (h *baseHandler) handleCommand(commandStr string) {
+ dlog.Server.Debug(h.user, commandStr)
+
+ args, argc, add, err := h.handleProtocolVersion(strings.Split(commandStr, " "))
+ if err != nil {
+ h.send(h.serverMessages, dlog.Server.Error(h.user, err)+add)
+ return
+ }
+ args, argc, err = h.handleBase64(args, argc)
+ if err != nil {
+ h.send(h.serverMessages, dlog.Server.Error(h.user, err))
+ return
+ }
+ ctx, cancel := context.WithCancel(context.Background())
+ go func() {
+ <-h.done.Done()
+ cancel()
+ }()
+
+ parts := strings.Split(args[0], ":")
+ commandName := parts[0]
+
+ // Either no options or empty options provided.
+ if len(parts) == 1 || len(parts[1]) == 0 {
+ h.handleCommandCb(ctx, lcontext.LContext{}, argc, args, commandName)
+ return
+ }
+
+ options, ltx, err := config.DeserializeOptions(parts[1:])
+ if err != nil {
+ h.send(h.serverMessages, dlog.Server.Error(h.user, err))
+ return
+ }
+ h.handleOptions(options)
+ h.handleCommandCb(ctx, ltx, argc, args, commandName)
+}
+
+func (h *baseHandler) handleProtocolVersion(args []string) ([]string, int, string, error) {
+ argc := len(args)
+ var add string
+
+ if argc <= 2 || args[0] != "protocol" {
+ return args, argc, add, errors.New("unable to determine protocol version")
+ }
+
+ if args[1] != protocol.ProtocolCompat {
+ clientCompat, _ := strconv.Atoi(args[1])
+ serverCompat, _ := strconv.Atoi(protocol.ProtocolCompat)
+ if clientCompat <= 3 {
+ // Protocol version 3 or lower expect a newline as message separator
+ // One day (after 2 major versions) this exception may be removed!
+ add = "\n"
+ }
+
+ toUpdate := "client"
+ if clientCompat > serverCompat {
+ toUpdate = "server"
+ }
+ err := fmt.Errorf("the DTail server protocol version '%s' does not match "+
+ "client protocol version '%s', please update DTail %s",
+ protocol.ProtocolCompat, args[1], toUpdate)
+ return args, argc, add, err
+ }
+
+ return args[2:], argc - 2, add, nil
+}
+
+func (h *baseHandler) handleBase64(args []string, argc int) ([]string, int, error) {
+ err := errors.New("unable to decode client message, DTail server and client " +
+ "versions may not be compatible")
+ if argc != 2 || args[0] != "base64" {
+ return args, argc, err
+ }
+
+ decoded, err := base64.StdEncoding.DecodeString(args[1])
+ if err != nil {
+ return args, argc, err
+ }
+ decodedStr := string(decoded)
+
+ args = strings.Split(decodedStr, " ")
+ argc = len(decodedStr)
+ dlog.Server.Trace(h.user, "Base64 decoded received command",
+ decodedStr, argc, args)
+
+ return args, argc, nil
+}
+
+func (h *baseHandler) handleAckCommand(argc int, args []string) {
+ if argc < 3 {
+ if !h.quiet {
+ h.send(h.serverMessages, dlog.Server.Warn(h.user,
+ "Unable to parse command", args, argc))
+ }
+ return
+ }
+ if args[1] == "close" && args[2] == "connection" {
+ select {
+ case <-h.ackCloseReceived:
+ default:
+ close(h.ackCloseReceived)
+ }
+ }
+}
+
+func (h *baseHandler) handleOptions(options map[string]string) {
+ // We have to make sure that this block is executed only once.
+ h.mutex.Lock()
+ defer h.mutex.Unlock()
+ // We can read the options only once, will cause a data race otherwise if
+ // changed multiple times for multiple incoming commands.
+ h.once.Do(func() {
+ if quiet, _ := options["quiet"]; quiet == "true" {
+ dlog.Server.Debug(h.user, "Enabling quiet mode")
+ h.quiet = true
+ }
+ if spartan, _ := options["spartan"]; spartan == "true" {
+ dlog.Server.Debug(h.user, "Enabling spartan mode")
+ h.spartan = true
+ }
+ if serverless, _ := options["serverless"]; serverless == "true" {
+ dlog.Server.Debug(h.user, "Enabling serverless mode")
+ h.serverless = true
+ }
+ })
+}
+
+func (h *baseHandler) send(ch chan<- string, message string) {
+ select {
+ case ch <- message:
+ case <-h.done.Done():
+ }
+}
+
+func (h *baseHandler) flush() {
+ dlog.Server.Trace(h.user, "flush()")
+ numUnsentMessages := func() int {
+ return len(h.lines) + len(h.serverMessages) + len(h.maprMessages)
+ }
+ for i := 0; i < 10; i++ {
+ if numUnsentMessages() == 0 {
+ dlog.Server.Debug(h.user, "ALL lines sent", fmt.Sprintf("%p", h))
+ return
+ }
+ dlog.Server.Debug(h.user, "Still lines to be sent")
+ time.Sleep(time.Millisecond * 10)
+ }
+ dlog.Server.Warn(h.user, "Some lines remain unsent", numUnsentMessages())
+}
+
+func (h *baseHandler) shutdown() {
+ dlog.Server.Debug(h.user, "shutdown()")
+ h.flush()
+
+ go func() {
+ select {
+ case h.serverMessages <- ".syn close connection":
+ case <-h.done.Done():
+ }
+ }()
+
+ select {
+ case <-h.ackCloseReceived:
+ case <-time.After(time.Second * 5):
+ dlog.Server.Debug(h.user, "Shutdown timeout reached, enforcing shutdown")
+ case <-h.done.Done():
+ }
+ h.done.Shutdown()
+}
+
+func (h *baseHandler) incrementActiveCommands() {
+ atomic.AddInt32(&h.activeCommands, 1)
+}
+
+func (h *baseHandler) decrementActiveCommands() int32 {
+ atomic.AddInt32(&h.activeCommands, -1)
+ return atomic.LoadInt32(&h.activeCommands)
+}
diff --git a/internal/server/handlers/controlhandler.go b/internal/server/handlers/controlhandler.go
deleted file mode 100644
index 1e17c78..0000000
--- a/internal/server/handlers/controlhandler.go
+++ /dev/null
@@ -1,98 +0,0 @@
-package handlers
-
-import (
- "fmt"
- "io"
- "os"
- "strings"
-
- "github.com/mimecast/dtail/internal"
- "github.com/mimecast/dtail/internal/io/logger"
- user "github.com/mimecast/dtail/internal/user/server"
-)
-
-// ControlHandler is used for control functions and health monitoring.
-type ControlHandler struct {
- done *internal.Done
- hostname string
- payload []byte
- serverMessages chan string
- user *user.User
-}
-
-// NewControlHandler returns a new control handler.
-func NewControlHandler(user *user.User) *ControlHandler {
- logger.Debug(user, "Creating control handler")
-
- h := ControlHandler{
- done: internal.NewDone(),
- serverMessages: make(chan string, 10),
- user: user,
- }
-
- fqdn, err := os.Hostname()
- if err != nil {
- logger.FatalExit(err)
- }
-
- s := strings.Split(fqdn, ".")
- h.hostname = s[0]
-
- return &h
-}
-
-// Shutdown the handler.
-func (h *ControlHandler) Shutdown() {
- h.done.Shutdown()
-}
-
-// Done channel of the handler.
-func (h *ControlHandler) Done() <-chan struct{} {
- return h.done.Done()
-}
-
-// Read is to send data to the client via the Reader interface.
-func (h *ControlHandler) Read(p []byte) (n int, err error) {
- for {
- select {
- case message := <-h.serverMessages:
- wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message))
- n = copy(p, wholePayload)
- return
- case <-h.done.Done():
- return 0, io.EOF
- }
- }
-}
-
-// Write is to read data to the client via the Writer interface.
-func (h *ControlHandler) Write(p []byte) (n int, err error) {
- for _, c := range p {
- switch c {
- case ';':
- wholePayload := strings.TrimSpace(string(h.payload))
- h.handleCommand(wholePayload)
- h.payload = nil
-
- default:
- h.payload = append(h.payload, c)
- }
- }
-
- n = len(p)
- return
-}
-
-func (h *ControlHandler) handleCommand(command string) {
- logger.Info(h.user, command)
- s := strings.Split(command, " ")
- logger.Debug(h.user, "Receiving command", command, s)
-
- switch s[0] {
- case "health":
- h.serverMessages <- "OK: DTail SSH Server seems fine"
- h.serverMessages <- "done;"
- default:
- h.serverMessages <- logger.Error(h.user, "Received unknown control command", command, s)
- }
-}
diff --git a/internal/server/handlers/healthhandler.go b/internal/server/handlers/healthhandler.go
new file mode 100644
index 0000000..6dd9872
--- /dev/null
+++ b/internal/server/handlers/healthhandler.go
@@ -0,0 +1,58 @@
+package handlers
+
+import (
+ "context"
+ "os"
+ "strings"
+
+ "github.com/mimecast/dtail/internal"
+ "github.com/mimecast/dtail/internal/io/dlog"
+ "github.com/mimecast/dtail/internal/io/line"
+ "github.com/mimecast/dtail/internal/lcontext"
+ user "github.com/mimecast/dtail/internal/user/server"
+)
+
+// HealthHandler is for the remote health check.
+type HealthHandler struct {
+ baseHandler
+}
+
+// NewHealthHandler returns the server handler.
+func NewHealthHandler(user *user.User) *HealthHandler {
+ dlog.Server.Debug(user, "Creating new server health handler")
+ h := HealthHandler{
+ baseHandler: baseHandler{
+ done: internal.NewDone(),
+ lines: make(chan line.Line, 100),
+ serverMessages: make(chan string, 10),
+ maprMessages: make(chan string, 10),
+ ackCloseReceived: make(chan struct{}),
+ user: user,
+ },
+ }
+ h.handleCommandCb = h.handleHealthCommand
+
+ fqdn, err := os.Hostname()
+ if err != nil {
+ dlog.Server.FatalPanic(err)
+ }
+ s := strings.Split(fqdn, ".")
+ h.hostname = s[0]
+ return &h
+}
+
+func (h *HealthHandler) handleHealthCommand(ctx context.Context,
+ ltx lcontext.LContext, argc int, args []string, commandName string) {
+
+ dlog.Server.Debug(h.user, "Handling health command", argc, args)
+ switch commandName {
+ case "health":
+ h.send(h.serverMessages, "OK")
+ case ".ack":
+ h.handleAckCommand(argc, args)
+ default:
+ h.send(h.serverMessages, dlog.Server.Error(h.user,
+ "Received unknown health command", commandName, argc, args))
+ }
+ h.shutdown()
+}
diff --git a/internal/server/handlers/mapcommand.go b/internal/server/handlers/mapcommand.go
index c3e600e..65e0ed8 100644
--- a/internal/server/handlers/mapcommand.go
+++ b/internal/server/handlers/mapcommand.go
@@ -14,18 +14,17 @@ type mapCommand struct {
}
// NewMapCommand returns a new server side mapreduce command.
-func newMapCommand(serverHandler *ServerHandler, argc int, args []string) (mapCommand, *server.Aggregate, error) {
- m := mapCommand{server: serverHandler}
+func newMapCommand(serverHandler *ServerHandler, argc int,
+ args []string) (mapCommand, *server.Aggregate, error) {
+ m := mapCommand{server: serverHandler}
queryStr := strings.Join(args[1:], " ")
aggregate, err := server.NewAggregate(queryStr)
if err != nil {
return m, nil, err
}
-
m.aggregate = aggregate
return m, aggregate, nil
-
}
func (m mapCommand) Start(ctx context.Context, aggregatedMessages chan<- string) {
diff --git a/internal/server/handlers/readcommand.go b/internal/server/handlers/readcommand.go
index b659c06..4728a55 100644
--- a/internal/server/handlers/readcommand.go
+++ b/internal/server/handlers/readcommand.go
@@ -7,8 +7,9 @@ import (
"sync"
"time"
+ "github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/io/fs"
- "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/io/line"
"github.com/mimecast/dtail/internal/lcontext"
"github.com/mimecast/dtail/internal/omode"
"github.com/mimecast/dtail/internal/regex"
@@ -26,39 +27,45 @@ func newReadCommand(server *ServerHandler, mode omode.Mode) *readCommand {
}
}
-func (r *readCommand) Start(ctx context.Context, lContext lcontext.LContext, argc int, args []string, retries int) {
- re := regex.NewNoop()
+func (r *readCommand) Start(ctx context.Context, ltx lcontext.LContext,
+ argc int, args []string, retries int) {
+ re := regex.NewNoop()
if argc >= 4 {
deserializedRegex, err := regex.Deserialize(strings.Join(args[2:], " "))
if err != nil {
- r.server.sendServerMessage(logger.Error(r.server.user, commandParseWarning, err))
+ r.server.send(r.server.serverMessages, dlog.Server.Error(r.server.user,
+ "Unable to parse command", err))
return
}
re = deserializedRegex
}
if argc < 3 {
- r.server.sendServerWarnMessage(logger.Warn(r.server.user, commandParseWarning, args, argc))
+ r.server.send(r.server.serverMessages, dlog.Server.Warn(r.server.user,
+ "Unable to parse command", args, argc))
return
}
- r.readGlob(ctx, lContext, args[1], re, retries)
+ r.readGlob(ctx, ltx, args[1], re, retries)
}
-func (r *readCommand) readGlob(ctx context.Context, lContext lcontext.LContext, glob string, re regex.Regex, retries int) {
+func (r *readCommand) readGlob(ctx context.Context, ltx lcontext.LContext,
+ glob string, re regex.Regex, retries int) {
+
retryInterval := time.Second * 5
glob = filepath.Clean(glob)
for retryCount := 0; retryCount < retries; retryCount++ {
paths, err := filepath.Glob(glob)
if err != nil {
- logger.Warn(r.server.user, glob, err)
+ dlog.Server.Warn(r.server.user, glob, err)
time.Sleep(retryInterval)
continue
}
if numPaths := len(paths); numPaths == 0 {
- logger.Error(r.server.user, "No such file(s) to read", glob)
- r.server.sendServerWarnMessage(logger.Warn(r.server.user, "Unable to read file(s), check server logs"))
+ dlog.Server.Error(r.server.user, "No such file(s) to read", glob)
+ r.server.send(r.server.serverMessages, dlog.Server.Warn(r.server.user,
+ "Unable to read file(s), check server logs"))
select {
case <-ctx.Done():
return
@@ -68,41 +75,44 @@ func (r *readCommand) readGlob(ctx context.Context, lContext lcontext.LContext,
continue
}
- r.readFiles(ctx, lContext, paths, glob, re, retryInterval)
+ r.readFiles(ctx, ltx, paths, glob, re, retryInterval)
return
}
- r.server.sendServerWarnMessage(logger.Warn(r.server.user, "Giving up to read file(s)"))
+ r.server.send(r.server.serverMessages, dlog.Server.Warn(r.server.user,
+ "Giving up to read file(s)"))
return
}
-func (r *readCommand) readFiles(ctx context.Context, lContext lcontext.LContext, paths []string, glob string, re regex.Regex, retryInterval time.Duration) {
+func (r *readCommand) readFiles(ctx context.Context, ltx lcontext.LContext,
+ paths []string, glob string, re regex.Regex, retryInterval time.Duration) {
+
var wg sync.WaitGroup
wg.Add(len(paths))
-
for _, path := range paths {
- go r.readFileIfPermissions(ctx, lContext, &wg, path, glob, re)
+ go r.readFileIfPermissions(ctx, ltx, &wg, path, glob, re)
}
-
wg.Wait()
}
-func (r *readCommand) readFileIfPermissions(ctx context.Context, lContext lcontext.LContext, wg *sync.WaitGroup, path, glob string, re regex.Regex) {
+func (r *readCommand) readFileIfPermissions(ctx context.Context, ltx lcontext.LContext,
+ wg *sync.WaitGroup, path, glob string, re regex.Regex) {
+
defer wg.Done()
globID := r.makeGlobID(path, glob)
-
if !r.server.user.HasFilePermission(path, "readfiles") {
- logger.Error(r.server.user, "No permission to read file", path, globID)
- r.server.sendServerWarnMessage(logger.Warn(r.server.user, "Unable to read file(s), check server logs"))
+ dlog.Server.Error(r.server.user, "No permission to read file", path, globID)
+ r.server.send(r.server.serverMessages, dlog.Server.Warn(r.server.user,
+ "Unable to read file(s), check server logs"))
return
}
-
- r.readFile(ctx, lContext, path, globID, re)
+ r.readFile(ctx, ltx, path, globID, re)
}
-func (r *readCommand) readFile(ctx context.Context, lContext lcontext.LContext, path, globID string, re regex.Regex) {
- logger.Info(r.server.user, "Start reading file", path, globID)
+func (r *readCommand) readFile(ctx context.Context, ltx lcontext.LContext,
+ path, globID string, re regex.Regex) {
+ dlog.Server.Info(r.server.user, "Start reading file", path, globID)
var reader fs.FileReader
switch r.mode {
case omode.TailClient:
@@ -114,15 +124,19 @@ func (r *readCommand) readFile(ctx context.Context, lContext lcontext.LContext,
}
lines := r.server.lines
-
- // Plug in mappreduce engine
- if r.server.aggregate != nil {
- lines = r.server.aggregate.Lines
- }
+ aggregate := r.server.aggregate
for {
- if err := reader.Start(ctx, lContext, lines, re); err != nil {
- logger.Error(r.server.user, path, globID, err)
+ if aggregate != nil {
+ lines = make(chan line.Line, 100)
+ aggregate.NextLinesCh <- lines
+ }
+ if err := reader.Start(ctx, ltx, lines, re); err != nil {
+ dlog.Server.Error(r.server.user, path, globID, err)
+ }
+ if aggregate != nil {
+ // Also makes aggregate to Flush
+ close(lines)
}
select {
@@ -133,9 +147,8 @@ func (r *readCommand) readFile(ctx context.Context, lContext lcontext.LContext,
return
}
}
-
time.Sleep(time.Second * 2)
- logger.Info(path, globID, "Reading file again")
+ dlog.Server.Info(path, globID, "Reading file again")
}
}
@@ -152,11 +165,11 @@ func (r *readCommand) makeGlobID(path, glob string) string {
if len(idParts) > 0 {
return strings.Join(idParts, "/")
}
-
if len(pathParts) > 0 {
return pathParts[len(pathParts)-1]
}
- r.server.sendServerWarnMessage(logger.Warn("Empty file path given?", path, glob))
+ r.server.send(r.server.serverMessages,
+ dlog.Server.Warn("Empty file path given?", path, glob))
return ""
}
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go
index 39d5d5f..36574a9 100644
--- a/internal/server/handlers/serverhandler.go
+++ b/internal/server/handlers/serverhandler.go
@@ -2,69 +2,50 @@ package handlers
import (
"context"
- "encoding/base64"
- "errors"
- "fmt"
- "io"
"os"
- "strconv"
"strings"
- "sync/atomic"
- "time"
"github.com/mimecast/dtail/internal"
- "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/io/line"
- "github.com/mimecast/dtail/internal/io/logger"
"github.com/mimecast/dtail/internal/lcontext"
- "github.com/mimecast/dtail/internal/mapr/server"
"github.com/mimecast/dtail/internal/omode"
user "github.com/mimecast/dtail/internal/user/server"
- "github.com/mimecast/dtail/internal/version"
-)
-
-const (
- commandParseWarning string = "Unable to parse command"
)
// ServerHandler implements the Reader and Writer interfaces to handle
// the Bi-directional communication between SSH client and server.
// This handler implements the handler of the SSH server.
type ServerHandler struct {
- done *internal.Done
- lines chan line.Line
- regex string
- aggregate *server.Aggregate
- aggregatedMessages chan string
- serverMessages chan string
- payload []byte
- hostname string
- user *user.User
- catLimiter chan struct{}
- tailLimiter chan struct{}
- ackCloseReceived chan struct{}
- activeCommands int32
- activeReaders int32
- quiet bool
+ baseHandler
+ catLimiter chan struct{}
+ tailLimiter chan struct{}
+ regex string
}
// NewServerHandler returns the server handler.
-func NewServerHandler(user *user.User, catLimiter, tailLimiter chan struct{}) *ServerHandler {
+func NewServerHandler(user *user.User, catLimiter,
+ tailLimiter chan struct{}) *ServerHandler {
+
+ dlog.Server.Debug(user, "Creating new server handler")
h := ServerHandler{
- done: internal.NewDone(),
- lines: make(chan line.Line, 100),
- serverMessages: make(chan string, 10),
- aggregatedMessages: make(chan string, 10),
- ackCloseReceived: make(chan struct{}),
- catLimiter: catLimiter,
- tailLimiter: tailLimiter,
- regex: ".",
- user: user,
- }
+ baseHandler: baseHandler{
+ done: internal.NewDone(),
+ lines: make(chan line.Line, 100),
+ serverMessages: make(chan string, 10),
+ maprMessages: make(chan string, 10),
+ ackCloseReceived: make(chan struct{}),
+ user: user,
+ },
+ catLimiter: catLimiter,
+ tailLimiter: tailLimiter,
+ regex: ".",
+ }
+ h.handleCommandCb = h.handleUserCommand
fqdn, err := os.Hostname()
if err != nil {
- logger.FatalExit(err)
+ dlog.Server.FatalPanic(err)
}
s := strings.Split(fqdn, ".")
@@ -73,374 +54,49 @@ func NewServerHandler(user *user.User, catLimiter, tailLimiter chan struct{}) *S
return &h
}
-// Shutdown the handler.
-func (h *ServerHandler) Shutdown() {
- h.done.Shutdown()
-}
-
-// Done channel of the handler.
-func (h *ServerHandler) Done() <-chan struct{} {
- return h.done.Done()
-}
-
-// Read is to send data to the dtail client via Reader interface.
-func (h *ServerHandler) Read(p []byte) (n int, err error) {
- for {
- select {
- case message := <-h.serverMessages:
- if len(message) == 0 {
- logger.Warn(h.user, "Empty message received")
- return
- }
- if message[0] == '.' {
- // Handle hidden message (don't display to the user, interpreted by dtail client)
- wholePayload := []byte(fmt.Sprintf("%s\n", message))
- n = copy(p, wholePayload)
- return
- }
-
- // Handle normal server message (display to the user)
- wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message))
- n = copy(p, wholePayload)
- return
-
- case message := <-h.aggregatedMessages:
- // Send mapreduce-aggregated data as a message.
- data := fmt.Sprintf("AGGREGATE➔%s➔%s\n", h.hostname, message)
- wholePayload := []byte(data)
- n = copy(p, wholePayload)
- return
-
- case line := <-h.lines:
- // Send normal file content data as a message.
- serverInfo := []byte(fmt.Sprintf("REMOTE|%s|%3d|%v|%s|",
- h.hostname, line.TransmittedPerc, line.Count, line.SourceID))
- wholePayload := append(serverInfo, line.Content[:]...)
- n = copy(p, wholePayload)
- return
-
- case <-time.After(time.Second):
- // Once in a while check whether we are done.
- select {
- case <-h.done.Done():
- return 0, io.EOF
- default:
- }
- }
- }
-}
-
-// Write is to receive data from the dtail client via Writer interface.
-func (h *ServerHandler) Write(p []byte) (n int, err error) {
- for _, c := range p {
- switch c {
- case ';':
- commandStr := strings.TrimSpace(string(h.payload))
- h.handleCommand(commandStr)
- h.payload = nil
- default:
- h.payload = append(h.payload, c)
- }
- }
-
- n = len(p)
- return
-}
-
-func (h *ServerHandler) handleCommand(commandStr string) {
- logger.Debug(h.user, commandStr)
- ctx := context.Background()
-
- args, argc, err := h.handleProtocolVersion(strings.Split(commandStr, " "))
- if err != nil {
- h.send(h.serverMessages, logger.Error(h.user, err))
- return
- }
-
- args, argc, err = h.handleBase64(args, argc)
- if err != nil {
- h.send(h.serverMessages, logger.Error(h.user, err))
- return
- }
-
- if h.user.Name == config.ControlUser {
- h.handleControlCommand(argc, args)
- return
- }
-
- ctx, cancel := context.WithCancel(ctx)
- go func() {
- <-h.done.Done()
- cancel()
- }()
-
- h.handleUserCommand(ctx, argc, args)
-}
-
-func (h *ServerHandler) handleProtocolVersion(args []string) ([]string, int, error) {
- argc := len(args)
-
- if argc <= 2 || args[0] != "protocol" {
- return args, argc, errors.New("unable to determine protocol version")
- }
-
- if args[1] != version.ProtocolCompat {
- err := fmt.Errorf("server with protocol version '%s' but client with '%s', please update DTail", version.ProtocolCompat, args[1])
- return args, argc, err
- }
-
- return args[2:], argc - 2, nil
-}
-
-func (h *ServerHandler) handleBase64(args []string, argc int) ([]string, int, error) {
- err := errors.New("Unable to decode client message, DTail server and client versions may not be compatible")
-
- if argc != 2 || args[0] != "base64" {
- return args, argc, err
- }
-
- decoded, err := base64.StdEncoding.DecodeString(args[1])
- if err != nil {
- return args, argc, err
- }
- decodedStr := string(decoded)
-
- args = strings.Split(decodedStr, " ")
- argc = len(decodedStr)
- logger.Trace(h.user, "Base64 decoded received command", decodedStr, argc, args)
-
- return args, argc, nil
-}
-
-func (h *ServerHandler) handleControlCommand(argc int, args []string) {
- switch args[0] {
- case "debug":
- h.send(h.serverMessages, logger.Debug(h.user, "Receiving debug command", argc, args))
- default:
- logger.Warn(h.user, "Received unknown control command", argc, args)
- }
-}
-
-func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []string) {
- logger.Debug(h.user, "handleUserCommand", argc, args)
+func (h *ServerHandler) handleUserCommand(ctx context.Context, ltx lcontext.LContext,
+ argc int, args []string, commandName string) {
+ dlog.Server.Debug(h.user, "Handling user command", argc, args)
h.incrementActiveCommands()
commandFinished := func() {
if h.decrementActiveCommands() == 0 {
h.shutdown()
}
}
- readerFinished := func() {
- if h.decrementActiveReaders() == 0 {
- if h.aggregate == nil {
- return
- }
- h.aggregate.Shutdown()
- }
- }
-
- splitted := strings.Split(args[0], ":")
- commandName := splitted[0]
-
- options, lContext, err := readOptions(splitted[1:])
- if err != nil {
- h.sendServerMessage(logger.Error(h.user, err))
- commandFinished()
- return
- }
- if quiet, ok := options["quiet"]; ok {
- if quiet == "true" {
- logger.Debug(h.user, "Enabling quiet mode")
- h.quiet = true
- }
- }
switch commandName {
case "grep", "cat":
command := newReadCommand(h, omode.CatClient)
go func() {
- h.incrementActiveReaders()
- command.Start(ctx, lContext, argc, args, 1)
- readerFinished()
+ command.Start(ctx, ltx, argc, args, 1)
commandFinished()
}()
-
case "tail":
command := newReadCommand(h, omode.TailClient)
go func() {
- h.incrementActiveReaders()
- command.Start(ctx, lContext, argc, args, 10)
- readerFinished()
+ command.Start(ctx, ltx, argc, args, 10)
commandFinished()
}()
-
case "map":
command, aggregate, err := newMapCommand(h, argc, args)
if err != nil {
- h.sendServerMessage(err.Error())
- logger.Error(h.user, err)
+ h.send(h.serverMessages, err.Error())
+ dlog.Server.Error(h.user, err)
commandFinished()
return
}
-
h.aggregate = aggregate
go func() {
- command.Start(ctx, h.aggregatedMessages)
+ command.Start(ctx, h.maprMessages)
commandFinished()
}()
-
- case "ack", ".ack":
+ case ".ack":
h.handleAckCommand(argc, args)
commandFinished()
-
default:
- h.sendServerMessage(logger.Error(h.user, "Received unknown user command", commandName, argc, args, options))
+ h.send(h.serverMessages, dlog.Server.Error(h.user,
+ "Received unknown user command", commandName, argc, args))
commandFinished()
}
}
-
-func (h *ServerHandler) handleAckCommand(argc int, args []string) {
- if argc < 3 {
- h.sendServerWarnMessage(logger.Warn(h.user, commandParseWarning, args, argc))
- return
- }
- if args[1] == "close" && args[2] == "connection" {
- close(h.ackCloseReceived)
- }
-}
-
-func (h *ServerHandler) send(ch chan<- string, message string) {
- select {
- case ch <- message:
- case <-h.done.Done():
- }
-}
-
-func (h *ServerHandler) sendServerMessage(message string) {
- h.send(h.serverMessageC(), message)
-}
-
-func (h *ServerHandler) sendServerWarnMessage(message string) {
- if h.quiet {
- return
- }
- h.send(h.serverMessageC(), message)
-}
-
-func (h *ServerHandler) serverMessageC() chan<- string {
- return h.serverMessages
-}
-
-func (h *ServerHandler) flush() {
- logger.Debug(h.user, "flush()")
-
- if h.aggregate != nil {
- h.aggregate.Flush()
- }
-
- unsentMessages := func() int {
- return len(h.lines) + len(h.serverMessages) + len(h.aggregatedMessages)
- }
- for i := 0; i < 3; i++ {
- if unsentMessages() == 0 {
- logger.Debug(h.user, "All lines sent")
- return
- }
- logger.Debug(h.user, "Still lines to be sent")
- time.Sleep(time.Second)
- }
-
- logger.Warn(h.user, "Some lines remain unsent", unsentMessages())
-}
-
-func (h *ServerHandler) shutdown() {
- logger.Debug(h.user, "shutdown()")
- h.flush()
-
- go func() {
- select {
- case h.serverMessageC() <- ".syn close connection":
- case <-h.done.Done():
- }
- }()
-
- select {
- case <-h.ackCloseReceived:
- case <-time.After(time.Second * 5):
- logger.Debug(h.user, "Shutdown timeout reached, enforcing shutdown")
- case <-h.done.Done():
- }
-
- h.done.Shutdown()
-}
-
-func (h *ServerHandler) incrementActiveCommands() {
- atomic.AddInt32(&h.activeCommands, 1)
-}
-
-func (h *ServerHandler) decrementActiveCommands() int32 {
- atomic.AddInt32(&h.activeCommands, -1)
- return atomic.LoadInt32(&h.activeCommands)
-}
-
-func (h *ServerHandler) incrementActiveReaders() {
- atomic.AddInt32(&h.activeReaders, 1)
-}
-
-func (h *ServerHandler) decrementActiveReaders() int32 {
- atomic.AddInt32(&h.activeReaders, -1)
- return atomic.LoadInt32(&h.activeReaders)
-}
-
-// TODO: All options related code should be in its own package (client + server)
-func readOptions(opts []string) (map[string]string, lcontext.LContext, error) {
- options := make(map[string]string, len(opts))
- // Local search context
- var lContext lcontext.LContext
-
- for _, o := range opts {
- kv := strings.SplitN(o, "=", 2)
- if len(kv) != 2 {
- continue
- }
- key := kv[0]
- val := kv[1]
-
- if strings.HasPrefix(val, "base64%") {
- s := strings.SplitN(val, "%", 2)
- decoded, err := base64.StdEncoding.DecodeString(s[1])
- if err != nil {
- return options, lContext, err
- }
- val = string(decoded)
- }
-
- switch key {
- case "before":
- iVal, err := strconv.Atoi(val)
- if err != nil {
- logger.Error(err)
- continue
- }
- lContext.BeforeContext = iVal
- case "after":
- iVal, err := strconv.Atoi(val)
- if err != nil {
- logger.Error(err)
- continue
- }
- lContext.AfterContext = iVal
- case "max":
- iVal, err := strconv.Atoi(val)
- if err != nil {
- logger.Error(err)
- continue
- }
- lContext.MaxCount = iVal
- default:
- options[key] = val
- }
- }
-
- return options, lContext, nil
-}
diff --git a/internal/server/scheduler.go b/internal/server/scheduler.go
index a1e9e36..0ba65f7 100644
--- a/internal/server/scheduler.go
+++ b/internal/server/scheduler.go
@@ -10,25 +10,23 @@ import (
"github.com/mimecast/dtail/internal/clients"
"github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/omode"
gossh "golang.org/x/crypto/ssh"
)
-type scheduler struct {
-}
+type scheduler struct{}
func newScheduler() *scheduler {
return &scheduler{}
}
func (s *scheduler) start(ctx context.Context) {
- logger.Info("Starting scheduled job runner after 10s")
+ dlog.Server.Info("Starting scheduled job runner after 10s")
// First run after just 10s!
time.Sleep(time.Second * 10)
s.runJobs(ctx)
-
for {
select {
case <-time.After(time.Minute):
@@ -42,27 +40,24 @@ func (s *scheduler) start(ctx context.Context) {
func (s *scheduler) runJobs(ctx context.Context) {
for _, job := range config.Server.Schedule {
if !job.Enable {
- logger.Debug(job.Name, "Not running job as not enabled")
+ dlog.Server.Debug(job.Name, "Not running job as not enabled")
continue
}
-
hour, err := strconv.Atoi(time.Now().Format("15"))
if err != nil {
- logger.Error(job.Name, "Unable to create job", err)
+ dlog.Server.Error(job.Name, "Unable to create job", err)
continue
}
-
if hour < job.TimeRange[0] || hour >= job.TimeRange[1] {
- logger.Debug(job.Name, "Not running job out of time range")
+ dlog.Server.Debug(job.Name, "Not running job out of time range")
continue
}
files := fillDates(job.Files)
outfile := fillDates(job.Outfile)
-
_, err = os.Stat(outfile)
if !os.IsNotExist(err) {
- logger.Debug(job.Name, "Not running job as outfile already exists", outfile)
+ dlog.Server.Debug(job.Name, "Not running job as outfile already exists", outfile)
continue
}
@@ -70,9 +65,8 @@ func (s *scheduler) runJobs(ctx context.Context) {
if servers == "" {
servers = config.Server.SSHBindAddress
}
-
- args := clients.Args{
- ConnectionsPerCPU: 10,
+ args := config.Args{
+ ConnectionsPerCPU: config.DefaultConnectionsPerCPU,
Discovery: job.Discovery,
ServersStr: servers,
What: files,
@@ -81,25 +75,24 @@ func (s *scheduler) runJobs(ctx context.Context) {
}
args.SSHAuthMethods = append(args.SSHAuthMethods, gossh.Password(job.Name))
-
- query := fmt.Sprintf("%s outfile %s", job.Query, outfile)
- client, err := clients.NewMaprClient(args, query, clients.CumulativeMode)
+ args.QueryStr = fmt.Sprintf("%s outfile %s", job.Query, outfile)
+ client, err := clients.NewMaprClient(args, clients.CumulativeMode)
if err != nil {
- logger.Error(fmt.Sprintf("Unable to create job %s", job.Name), err)
+ dlog.Server.Error(fmt.Sprintf("Unable to create job %s", job.Name), err)
continue
}
jobCtx, cancel := context.WithCancel(ctx)
defer cancel()
- logger.Info(fmt.Sprintf("Starting job %s", job.Name))
+ dlog.Server.Info(fmt.Sprintf("Starting job %s", job.Name))
status := client.Start(jobCtx, make(chan string))
logMessage := fmt.Sprintf("Job exited with status %d", status)
if status != 0 {
- logger.Warn(logMessage)
+ dlog.Server.Warn(logMessage)
continue
}
- logger.Info(logMessage)
+ dlog.Server.Info(logMessage)
}
}
diff --git a/internal/server/server.go b/internal/server/server.go
index 3640208..0cb5e27 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -9,7 +9,7 @@ import (
"strings"
"github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/server/handlers"
"github.com/mimecast/dtail/internal/ssh/server"
user "github.com/mimecast/dtail/internal/user/server"
@@ -24,9 +24,9 @@ type Server struct {
stats stats
// SSH server configuration.
sshServerConfig *gossh.ServerConfig
- // To control the max amount of concurrent cats (which can cause a lot of I/O on the server)
+ // To control the max amount of concurrent cats.
catLimiter chan struct{}
- // To control the max amount of concurrent tails
+ // To control the max amount of concurrent tails.
tailLimiter chan struct{}
// To run scheduled tasks (if configured)
sched *scheduler
@@ -36,7 +36,7 @@ type Server struct {
// New returns a new server.
func New() *Server {
- logger.Info("Creating server", version.String())
+ dlog.Server.Info("Creating server", version.String())
s := Server{
sshServerConfig: &gossh.ServerConfig{},
@@ -51,7 +51,7 @@ func New() *Server {
private, err := gossh.ParsePrivateKey(server.PrivateHostKey())
if err != nil {
- logger.FatalExit(err)
+ dlog.Server.FatalPanic(err)
}
s.sshServerConfig.AddHostKey(private)
@@ -60,14 +60,13 @@ func New() *Server {
// Start the server.
func (s *Server) Start(ctx context.Context) int {
- logger.Info("Starting server")
-
+ dlog.Server.Info("Starting server")
bindAt := fmt.Sprintf("%s:%d", config.Server.SSHBindAddress, config.Common.SSHPort)
- logger.Info("Binding server", bindAt)
+ dlog.Server.Info("Binding server", bindAt)
listener, err := net.Listen("tcp", bindAt)
if err != nil {
- logger.FatalExit("Failed to open listening TCP socket", err)
+ dlog.Server.FatalPanic("Failed to open listening TCP socket", err)
}
go s.stats.start(ctx)
@@ -76,14 +75,12 @@ func (s *Server) Start(ctx context.Context) int {
go s.listenerLoop(ctx, listener)
<-ctx.Done()
-
// For future use.
return 0
}
func (s *Server) listenerLoop(ctx context.Context, listener net.Listener) {
- logger.Debug("Starting listener loop")
-
+ dlog.Server.Debug("Starting listener loop")
for {
conn, err := listener.Accept() // Blocking
if err != nil {
@@ -92,63 +89,69 @@ func (s *Server) listenerLoop(ctx context.Context, listener net.Listener) {
return
default:
}
- logger.Error("Failed to accept incoming connection", err)
+ dlog.Server.Error("Failed to accept incoming connection", err)
continue
}
if err := s.stats.serverLimitExceeded(); err != nil {
- logger.Error(err)
+ dlog.Server.Error(err)
conn.Close()
continue
}
-
go s.handleConnection(ctx, conn)
}
}
func (s *Server) handleConnection(ctx context.Context, conn net.Conn) {
- logger.Info("Handling connection")
+ dlog.Server.Info("Handling connection")
sshConn, chans, reqs, err := gossh.NewServerConn(conn, s.sshServerConfig)
if err != nil {
- logger.Error("Something just happened", err)
+ dlog.Server.Error("Something just happened", err)
return
}
s.stats.incrementConnections()
-
go gossh.DiscardRequests(reqs)
for newChannel := range chans {
go s.handleChannel(ctx, sshConn, newChannel)
}
}
-func (s *Server) handleChannel(ctx context.Context, sshConn gossh.Conn, newChannel gossh.NewChannel) {
- user := user.New(sshConn.User(), sshConn.RemoteAddr().String())
- logger.Info(user, "Invoking channel handler")
+func (s *Server) handleChannel(ctx context.Context, sshConn gossh.Conn,
+ newChannel gossh.NewChannel) {
+ user, err := user.New(sshConn.User(), sshConn.RemoteAddr().String())
+ if err != nil {
+ dlog.Server.Error(user, err)
+ newChannel.Reject(gossh.Prohibited, err.Error())
+ return
+ }
+
+ dlog.Server.Info(user, "Invoking channel handler")
if newChannel.ChannelType() != "session" {
err := errors.New("Don'w allow other channel types than session")
- logger.Error(user, err)
+ dlog.Server.Error(user, err)
newChannel.Reject(gossh.Prohibited, err.Error())
return
}
channel, requests, err := newChannel.Accept()
if err != nil {
- logger.Error(user, "Could not accept channel", err)
+ dlog.Server.Error(user, "Could not accept channel", err)
return
}
if err := s.handleRequests(ctx, sshConn, requests, channel, user); err != nil {
- logger.Error(user, "While handling request", err)
+ dlog.Server.Error(user, err)
sshConn.Close()
}
}
-func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-chan *gossh.Request, channel gossh.Channel, user *user.User) error {
- logger.Info(user, "Invoking request handler")
+func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn,
+ in <-chan *gossh.Request, channel gossh.Channel, user *user.User) error {
+ dlog.Server.Info(user, "Invoking request handler")
for req := range in {
var payload = struct{ Value string }{}
gossh.Unmarshal(req.Payload, &payload)
@@ -157,12 +160,11 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-ch
case "shell":
var handler handlers.Handler
switch user.Name {
- case config.ControlUser:
- handler = handlers.NewControlHandler(user)
+ case config.HealthUser:
+ handler = handlers.NewHealthHandler(user)
default:
handler = handlers.NewServerHandler(user, s.catLimiter, s.tailLimiter)
}
-
terminate := func() {
handler.Shutdown()
sshConn.Close()
@@ -173,13 +175,11 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-ch
io.Copy(channel, handler)
terminate()
}()
-
go func() {
// Broken pipe, cancel
io.Copy(handler, channel)
terminate()
}()
-
go func() {
select {
case <-ctx.Done():
@@ -187,62 +187,61 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-ch
}
terminate()
}()
-
go func() {
if err := sshConn.Wait(); err != nil && err != io.EOF {
- // Use of closed network connection.
- logger.Debug(user, "While waiting for ssh connection", err)
+ dlog.Server.Error(user, err)
}
s.stats.decrementConnections()
- logger.Info(user, "Good bye Mister!")
+ dlog.Server.Info(user, "Good bye Mister!")
terminate()
}()
// Only serving shell type
req.Reply(true, nil)
-
default:
req.Reply(false, nil)
-
- return fmt.Errorf("Closing SSH connection as unknown request received|%s|%v",
+ return fmt.Errorf("Closing SSH connection as unknown request recieved|%s|%v",
req.Type, payload.Value)
}
}
-
return nil
}
// Callback for SSH authentication.
-func (s *Server) Callback(c gossh.ConnMetadata, authPayload []byte) (*gossh.Permissions, error) {
- user := user.New(c.User(), c.RemoteAddr().String())
+func (s *Server) Callback(c gossh.ConnMetadata,
+ authPayload []byte) (*gossh.Permissions, error) {
+
+ user, err := user.New(c.User(), c.RemoteAddr().String())
+ if err != nil {
+ return nil, err
+ }
if config.ServerRelaxedAuthEnable {
- logger.Fatal(user, "Granting permissions via relaxed-auth")
+ dlog.Server.Fatal(user, "Granting permissions via relaxed-auth")
return nil, nil
}
authInfo := string(authPayload)
-
splitted := strings.Split(c.RemoteAddr().String(), ":")
remoteIP := splitted[0]
switch user.Name {
- case config.ControlUser:
- if authInfo == config.ControlUser {
- logger.Debug(user, "Granting permissions to control user")
+ case config.HealthUser:
+ if authInfo == config.HealthUser {
+ dlog.Server.Debug(user, "Granting permissions to health user")
return nil, nil
}
case config.ScheduleUser:
for _, job := range config.Server.Schedule {
if s.backgroundCanSSH(user, authInfo, remoteIP, job.Name, job.AllowFrom) {
- logger.Debug(user, "Granting SSH connection")
+ dlog.Server.Debug(user, "Granting SSH connection")
return nil, nil
}
}
case config.ContinuousUser:
for _, job := range config.Server.Continuous {
if s.backgroundCanSSH(user, authInfo, remoteIP, job.Name, job.AllowFrom) {
- logger.Debug(user, "Granting SSH connection")
+ dlog.Server.Debug(user, "Granting SSH connection")
return nil, nil
}
}
@@ -252,23 +251,26 @@ func (s *Server) Callback(c gossh.ConnMetadata, authPayload []byte) (*gossh.Perm
return nil, fmt.Errorf("user %s not authorized", user)
}
-func (s *Server) backgroundCanSSH(user *user.User, jobName, remoteIP, allowedJobName string, allowFrom []string) bool {
- logger.Debug("backgroundCanSSH", user, jobName, remoteIP, allowedJobName, allowFrom)
+func (s *Server) backgroundCanSSH(user *user.User, jobName, remoteIP,
+ allowedJobName string, allowFrom []string) bool {
+ dlog.Server.Debug("backgroundCanSSH", user, jobName, remoteIP, allowedJobName, allowFrom)
if jobName != allowedJobName {
- logger.Debug(user, jobName, "backgroundCanSSH", "Job name does not match, skipping to next one...", allowedJobName)
+ dlog.Server.Debug(user, jobName, "backgroundCanSSH",
+ "Job name does not match, skipping to next one...", allowedJobName)
return false
}
for _, myAddr := range allowFrom {
ips, err := net.LookupIP(myAddr)
if err != nil {
- logger.Debug(user, jobName, "backgroundCanSSH", "Unable to lookup IP address for allowed hosts lookup, skipping to next one...", myAddr, err)
+ dlog.Server.Debug(user, jobName, "backgroundCanSSH", "Unable to lookup IP "+
+ "address for allowed hosts lookup, skipping to next one...", myAddr, err)
continue
}
-
for _, ip := range ips {
- logger.Debug(user, jobName, "backgroundCanSSH", "Comparing IP addresses", remoteIP, ip.String())
+ dlog.Server.Debug(user, jobName, "backgroundCanSSH", "Comparing IP addresses",
+ remoteIP, ip.String())
if remoteIP == ip.String() {
return true
}
diff --git a/internal/server/stats.go b/internal/server/stats.go
index ac579ad..99a644a 100644
--- a/internal/server/stats.go
+++ b/internal/server/stats.go
@@ -3,12 +3,11 @@ package server
import (
"context"
"fmt"
- "runtime"
"sync"
"time"
"github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/io/dlog"
)
// Used to collect and display various server stats.
@@ -20,7 +19,6 @@ type stats struct {
func (s *stats) incrementConnections() {
defer s.logServerStats()
-
s.mutex.Lock()
s.currentConnections++
s.lifetimeConnections++
@@ -29,7 +27,6 @@ func (s *stats) incrementConnections() {
func (s *stats) decrementConnections() {
defer s.logServerStats()
-
s.mutex.Lock()
s.currentConnections--
s.mutex.Unlock()
@@ -41,8 +38,8 @@ func (s *stats) hasConnections() bool {
s.mutex.Unlock()
has := currentConnections > 0
- logger.Info("stats", "Server with open connections?", has, currentConnections)
-
+ dlog.Server.Info("stats", "Server with open connections?",
+ has, currentConnections)
return has
}
@@ -50,10 +47,10 @@ func (s *stats) logServerStats() {
s.mutex.Lock()
defer s.mutex.Unlock()
- currentConnections := fmt.Sprintf("currentConnections=%d", s.currentConnections)
- lifetimeConnections := fmt.Sprintf("lifetimeConnections=%d", s.lifetimeConnections)
- goroutines := fmt.Sprintf("goroutines=%d", runtime.NumGoroutine())
- logger.Info("stats", currentConnections, lifetimeConnections, goroutines)
+ data := make(map[string]interface{})
+ data["currentConnections"] = s.currentConnections
+ data["lifetimeConnections"] = s.lifetimeConnections
+ dlog.Server.Mapreduce("STATS", data)
}
func (s *stats) serverLimitExceeded() error {
@@ -61,9 +58,9 @@ func (s *stats) serverLimitExceeded() error {
defer s.mutex.Unlock()
if s.currentConnections >= config.Server.MaxConnections {
- return fmt.Errorf("Exceeded max allowed concurrent connections of %d", config.Server.MaxConnections)
+ return fmt.Errorf("Exceeded max allowed concurrent connections of %d",
+ config.Server.MaxConnections)
}
-
return nil
}