diff options
| author | Paul Buetow <pbuetow@mimecast.com> | 2021-10-21 21:28:49 +0300 |
|---|---|---|
| committer | Paul Buetow <pbuetow@mimecast.com> | 2021-10-21 21:28:49 +0300 |
| commit | f4207a55f71bfbcfdc532d5cdd3befaa3474a157 (patch) | |
| tree | ea5e4a2d2a67035f645bdee496ae55a52034178a /internal/server | |
| parent | d80d6070557e3a800e3a54967af9eced518f116b (diff) | |
| parent | 739205206d63bf42f4e843b39d04d4c8cd8207c3 (diff) | |
merge develop
Diffstat (limited to 'internal/server')
| -rw-r--r-- | internal/server/continuous.go | 36 | ||||
| -rw-r--r-- | internal/server/handlers/basehandler.go | 320 | ||||
| -rw-r--r-- | internal/server/handlers/controlhandler.go | 98 | ||||
| -rw-r--r-- | internal/server/handlers/healthhandler.go | 58 | ||||
| -rw-r--r-- | internal/server/handlers/mapcommand.go | 7 | ||||
| -rw-r--r-- | internal/server/handlers/readcommand.go | 83 | ||||
| -rw-r--r-- | internal/server/handlers/serverhandler.go | 412 | ||||
| -rw-r--r-- | internal/server/scheduler.go | 37 | ||||
| -rw-r--r-- | internal/server/server.go | 110 | ||||
| -rw-r--r-- | internal/server/stats.go | 21 |
10 files changed, 557 insertions, 625 deletions
diff --git a/internal/server/continuous.go b/internal/server/continuous.go index f75c732..93b3fcb 100644 --- a/internal/server/continuous.go +++ b/internal/server/continuous.go @@ -8,33 +8,29 @@ import ( "github.com/mimecast/dtail/internal/clients" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/omode" - gossh "golang.org/x/crypto/ssh" ) -type continuous struct { -} +type continuous struct{} func newContinuous() *continuous { return &continuous{} } func (c *continuous) start(ctx context.Context) { - logger.Info("Starting continuous job runner after 10s") + dlog.Server.Info("Starting continuous job runner after 10s") time.Sleep(time.Second * 10) - c.runJobs(ctx) } func (c *continuous) runJobs(ctx context.Context) { for _, job := range config.Server.Continuous { if !job.Enable { - logger.Debug(job.Name, "Not running job as not enabled") + dlog.Server.Debug(job.Name, "Not running job as not enabled") continue } - go func(job config.Continuous) { c.runJob(ctx, job) for { @@ -51,18 +47,17 @@ func (c *continuous) runJobs(ctx context.Context) { } func (c *continuous) runJob(ctx context.Context, job config.Continuous) { - logger.Debug(job.Name, "Processing job") + dlog.Server.Debug(job.Name, "Processing job") files := fillDates(job.Files) outfile := fillDates(job.Outfile) - servers := strings.Join(job.Servers, ",") if servers == "" { servers = config.Server.SSHBindAddress } - args := clients.Args{ - ConnectionsPerCPU: 10, + args := config.Args{ + ConnectionsPerCPU: config.DefaultConnectionsPerCPU, Discovery: job.Discovery, ServersStr: servers, What: files, @@ -71,35 +66,32 @@ func (c *continuous) runJob(ctx context.Context, job config.Continuous) { } args.SSHAuthMethods = append(args.SSHAuthMethods, gossh.Password(job.Name)) - - query := fmt.Sprintf("%s outfile %s", job.Query, outfile) - client, err := clients.NewMaprClient(args, query, clients.NonCumulativeMode) + args.QueryStr = fmt.Sprintf("%s outfile %s", job.Query, outfile) + client, err := clients.NewMaprClient(args, clients.NonCumulativeMode) if err != nil { - logger.Error(fmt.Sprintf("Unable to create job %s", job.Name), err) + dlog.Server.Error(fmt.Sprintf("Unable to create job %s", job.Name), err) return } jobCtx, cancel := context.WithCancel(ctx) defer cancel() - if job.RestartOnDayChange { go func() { if c.waitForDayChange(ctx) { - logger.Info(fmt.Sprintf("Canceling job %s due to day change", job.Name)) + dlog.Server.Info(fmt.Sprintf("Canceling job %s due to day change", job.Name)) cancel() } }() } - logger.Info(fmt.Sprintf("Starting job %s", job.Name)) + dlog.Server.Info(fmt.Sprintf("Starting job %s", job.Name)) status := client.Start(jobCtx, make(chan string)) logMessage := fmt.Sprintf("Job exited with status %d", status) - if status != 0 { - logger.Warn(logMessage) + dlog.Server.Warn(logMessage) return } - logger.Info(logMessage) + dlog.Server.Info(logMessage) } func (c *continuous) waitForDayChange(ctx context.Context) bool { diff --git a/internal/server/handlers/basehandler.go b/internal/server/handlers/basehandler.go new file mode 100644 index 0000000..6d10d17 --- /dev/null +++ b/internal/server/handlers/basehandler.go @@ -0,0 +1,320 @@ +package handlers + +import ( + "bytes" + "context" + "encoding/base64" + "errors" + "fmt" + "io" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/mimecast/dtail/internal" + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/dlog" + "github.com/mimecast/dtail/internal/io/line" + "github.com/mimecast/dtail/internal/io/pool" + "github.com/mimecast/dtail/internal/lcontext" + "github.com/mimecast/dtail/internal/mapr/server" + "github.com/mimecast/dtail/internal/protocol" + user "github.com/mimecast/dtail/internal/user/server" +) + +type handleCommandCb func(context.Context, lcontext.LContext, int, []string, string) + +type baseHandler struct { + done *internal.Done + handleCommandCb handleCommandCb + lines chan line.Line + aggregate *server.Aggregate + maprMessages chan string + serverMessages chan string + hostname string + user *user.User + ackCloseReceived chan struct{} + activeCommands int32 + readBuf bytes.Buffer + writeBuf bytes.Buffer + + // Some global options + sync primitives required. + once sync.Once + mutex sync.Mutex + quiet bool + spartan bool + serverless bool +} + +// Shutdown the handler. +func (h *baseHandler) Shutdown() { + h.done.Shutdown() +} + +// Done channel of the handler. +func (h *baseHandler) Done() <-chan struct{} { + return h.done.Done() +} + +// Read is to send data to the dtail client via Reader interface. +func (h *baseHandler) Read(p []byte) (n int, err error) { + defer h.readBuf.Reset() + + select { + case message := <-h.serverMessages: + if len(message) > 0 && message[0] == '.' { + // Handle hidden message (don't display to the user) + h.readBuf.WriteString(message) + h.readBuf.WriteByte(protocol.MessageDelimiter) + n = copy(p, h.readBuf.Bytes()) + return + } + + if h.serverless { + return + } + + // Handle normal server message (display to the user) + h.readBuf.WriteString("SERVER") + h.readBuf.WriteString(protocol.FieldDelimiter) + h.readBuf.WriteString(h.hostname) + h.readBuf.WriteString(protocol.FieldDelimiter) + h.readBuf.WriteString(message) + h.readBuf.WriteByte(protocol.MessageDelimiter) + n = copy(p, h.readBuf.Bytes()) + + case message := <-h.maprMessages: + // Send mapreduce-aggregated data as a message. + h.readBuf.WriteString("AGGREGATE") + h.readBuf.WriteString(protocol.FieldDelimiter) + h.readBuf.WriteString(h.hostname) + h.readBuf.WriteString(protocol.FieldDelimiter) + h.readBuf.WriteString(message) + h.readBuf.WriteByte(protocol.MessageDelimiter) + n = copy(p, h.readBuf.Bytes()) + + case line := <-h.lines: + if !h.spartan { + h.readBuf.WriteString("REMOTE") + h.readBuf.WriteString(protocol.FieldDelimiter) + h.readBuf.WriteString(h.hostname) + h.readBuf.WriteString(protocol.FieldDelimiter) + h.readBuf.WriteString(fmt.Sprintf("%3d", line.TransmittedPerc)) + h.readBuf.WriteString(protocol.FieldDelimiter) + h.readBuf.WriteString(fmt.Sprintf("%v", line.Count)) + h.readBuf.WriteString(protocol.FieldDelimiter) + h.readBuf.WriteString(line.SourceID) + h.readBuf.WriteString(protocol.FieldDelimiter) + } + h.readBuf.WriteString(line.Content.String()) + h.readBuf.WriteByte(protocol.MessageDelimiter) + n = copy(p, h.readBuf.Bytes()) + pool.RecycleBytesBuffer(line.Content) + + case <-time.After(time.Second): + // Once in a while check whether we are done. + select { + case <-h.done.Done(): + err = io.EOF + return + default: + } + } + return +} + +// Write is to receive data from the dtail client via Writer interface. +func (h *baseHandler) Write(p []byte) (n int, err error) { + for _, b := range p { + switch b { + case ';': + h.handleCommand(string(h.writeBuf.Bytes())) + h.writeBuf.Reset() + default: + h.writeBuf.WriteByte(b) + } + } + n = len(p) + return +} + +func (h *baseHandler) handleCommand(commandStr string) { + dlog.Server.Debug(h.user, commandStr) + + args, argc, add, err := h.handleProtocolVersion(strings.Split(commandStr, " ")) + if err != nil { + h.send(h.serverMessages, dlog.Server.Error(h.user, err)+add) + return + } + args, argc, err = h.handleBase64(args, argc) + if err != nil { + h.send(h.serverMessages, dlog.Server.Error(h.user, err)) + return + } + ctx, cancel := context.WithCancel(context.Background()) + go func() { + <-h.done.Done() + cancel() + }() + + parts := strings.Split(args[0], ":") + commandName := parts[0] + + // Either no options or empty options provided. + if len(parts) == 1 || len(parts[1]) == 0 { + h.handleCommandCb(ctx, lcontext.LContext{}, argc, args, commandName) + return + } + + options, ltx, err := config.DeserializeOptions(parts[1:]) + if err != nil { + h.send(h.serverMessages, dlog.Server.Error(h.user, err)) + return + } + h.handleOptions(options) + h.handleCommandCb(ctx, ltx, argc, args, commandName) +} + +func (h *baseHandler) handleProtocolVersion(args []string) ([]string, int, string, error) { + argc := len(args) + var add string + + if argc <= 2 || args[0] != "protocol" { + return args, argc, add, errors.New("unable to determine protocol version") + } + + if args[1] != protocol.ProtocolCompat { + clientCompat, _ := strconv.Atoi(args[1]) + serverCompat, _ := strconv.Atoi(protocol.ProtocolCompat) + if clientCompat <= 3 { + // Protocol version 3 or lower expect a newline as message separator + // One day (after 2 major versions) this exception may be removed! + add = "\n" + } + + toUpdate := "client" + if clientCompat > serverCompat { + toUpdate = "server" + } + err := fmt.Errorf("the DTail server protocol version '%s' does not match "+ + "client protocol version '%s', please update DTail %s", + protocol.ProtocolCompat, args[1], toUpdate) + return args, argc, add, err + } + + return args[2:], argc - 2, add, nil +} + +func (h *baseHandler) handleBase64(args []string, argc int) ([]string, int, error) { + err := errors.New("unable to decode client message, DTail server and client " + + "versions may not be compatible") + if argc != 2 || args[0] != "base64" { + return args, argc, err + } + + decoded, err := base64.StdEncoding.DecodeString(args[1]) + if err != nil { + return args, argc, err + } + decodedStr := string(decoded) + + args = strings.Split(decodedStr, " ") + argc = len(decodedStr) + dlog.Server.Trace(h.user, "Base64 decoded received command", + decodedStr, argc, args) + + return args, argc, nil +} + +func (h *baseHandler) handleAckCommand(argc int, args []string) { + if argc < 3 { + if !h.quiet { + h.send(h.serverMessages, dlog.Server.Warn(h.user, + "Unable to parse command", args, argc)) + } + return + } + if args[1] == "close" && args[2] == "connection" { + select { + case <-h.ackCloseReceived: + default: + close(h.ackCloseReceived) + } + } +} + +func (h *baseHandler) handleOptions(options map[string]string) { + // We have to make sure that this block is executed only once. + h.mutex.Lock() + defer h.mutex.Unlock() + // We can read the options only once, will cause a data race otherwise if + // changed multiple times for multiple incoming commands. + h.once.Do(func() { + if quiet, _ := options["quiet"]; quiet == "true" { + dlog.Server.Debug(h.user, "Enabling quiet mode") + h.quiet = true + } + if spartan, _ := options["spartan"]; spartan == "true" { + dlog.Server.Debug(h.user, "Enabling spartan mode") + h.spartan = true + } + if serverless, _ := options["serverless"]; serverless == "true" { + dlog.Server.Debug(h.user, "Enabling serverless mode") + h.serverless = true + } + }) +} + +func (h *baseHandler) send(ch chan<- string, message string) { + select { + case ch <- message: + case <-h.done.Done(): + } +} + +func (h *baseHandler) flush() { + dlog.Server.Trace(h.user, "flush()") + numUnsentMessages := func() int { + return len(h.lines) + len(h.serverMessages) + len(h.maprMessages) + } + for i := 0; i < 10; i++ { + if numUnsentMessages() == 0 { + dlog.Server.Debug(h.user, "ALL lines sent", fmt.Sprintf("%p", h)) + return + } + dlog.Server.Debug(h.user, "Still lines to be sent") + time.Sleep(time.Millisecond * 10) + } + dlog.Server.Warn(h.user, "Some lines remain unsent", numUnsentMessages()) +} + +func (h *baseHandler) shutdown() { + dlog.Server.Debug(h.user, "shutdown()") + h.flush() + + go func() { + select { + case h.serverMessages <- ".syn close connection": + case <-h.done.Done(): + } + }() + + select { + case <-h.ackCloseReceived: + case <-time.After(time.Second * 5): + dlog.Server.Debug(h.user, "Shutdown timeout reached, enforcing shutdown") + case <-h.done.Done(): + } + h.done.Shutdown() +} + +func (h *baseHandler) incrementActiveCommands() { + atomic.AddInt32(&h.activeCommands, 1) +} + +func (h *baseHandler) decrementActiveCommands() int32 { + atomic.AddInt32(&h.activeCommands, -1) + return atomic.LoadInt32(&h.activeCommands) +} diff --git a/internal/server/handlers/controlhandler.go b/internal/server/handlers/controlhandler.go deleted file mode 100644 index 1e17c78..0000000 --- a/internal/server/handlers/controlhandler.go +++ /dev/null @@ -1,98 +0,0 @@ -package handlers - -import ( - "fmt" - "io" - "os" - "strings" - - "github.com/mimecast/dtail/internal" - "github.com/mimecast/dtail/internal/io/logger" - user "github.com/mimecast/dtail/internal/user/server" -) - -// ControlHandler is used for control functions and health monitoring. -type ControlHandler struct { - done *internal.Done - hostname string - payload []byte - serverMessages chan string - user *user.User -} - -// NewControlHandler returns a new control handler. -func NewControlHandler(user *user.User) *ControlHandler { - logger.Debug(user, "Creating control handler") - - h := ControlHandler{ - done: internal.NewDone(), - serverMessages: make(chan string, 10), - user: user, - } - - fqdn, err := os.Hostname() - if err != nil { - logger.FatalExit(err) - } - - s := strings.Split(fqdn, ".") - h.hostname = s[0] - - return &h -} - -// Shutdown the handler. -func (h *ControlHandler) Shutdown() { - h.done.Shutdown() -} - -// Done channel of the handler. -func (h *ControlHandler) Done() <-chan struct{} { - return h.done.Done() -} - -// Read is to send data to the client via the Reader interface. -func (h *ControlHandler) Read(p []byte) (n int, err error) { - for { - select { - case message := <-h.serverMessages: - wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message)) - n = copy(p, wholePayload) - return - case <-h.done.Done(): - return 0, io.EOF - } - } -} - -// Write is to read data to the client via the Writer interface. -func (h *ControlHandler) Write(p []byte) (n int, err error) { - for _, c := range p { - switch c { - case ';': - wholePayload := strings.TrimSpace(string(h.payload)) - h.handleCommand(wholePayload) - h.payload = nil - - default: - h.payload = append(h.payload, c) - } - } - - n = len(p) - return -} - -func (h *ControlHandler) handleCommand(command string) { - logger.Info(h.user, command) - s := strings.Split(command, " ") - logger.Debug(h.user, "Receiving command", command, s) - - switch s[0] { - case "health": - h.serverMessages <- "OK: DTail SSH Server seems fine" - h.serverMessages <- "done;" - default: - h.serverMessages <- logger.Error(h.user, "Received unknown control command", command, s) - } -} diff --git a/internal/server/handlers/healthhandler.go b/internal/server/handlers/healthhandler.go new file mode 100644 index 0000000..6dd9872 --- /dev/null +++ b/internal/server/handlers/healthhandler.go @@ -0,0 +1,58 @@ +package handlers + +import ( + "context" + "os" + "strings" + + "github.com/mimecast/dtail/internal" + "github.com/mimecast/dtail/internal/io/dlog" + "github.com/mimecast/dtail/internal/io/line" + "github.com/mimecast/dtail/internal/lcontext" + user "github.com/mimecast/dtail/internal/user/server" +) + +// HealthHandler is for the remote health check. +type HealthHandler struct { + baseHandler +} + +// NewHealthHandler returns the server handler. +func NewHealthHandler(user *user.User) *HealthHandler { + dlog.Server.Debug(user, "Creating new server health handler") + h := HealthHandler{ + baseHandler: baseHandler{ + done: internal.NewDone(), + lines: make(chan line.Line, 100), + serverMessages: make(chan string, 10), + maprMessages: make(chan string, 10), + ackCloseReceived: make(chan struct{}), + user: user, + }, + } + h.handleCommandCb = h.handleHealthCommand + + fqdn, err := os.Hostname() + if err != nil { + dlog.Server.FatalPanic(err) + } + s := strings.Split(fqdn, ".") + h.hostname = s[0] + return &h +} + +func (h *HealthHandler) handleHealthCommand(ctx context.Context, + ltx lcontext.LContext, argc int, args []string, commandName string) { + + dlog.Server.Debug(h.user, "Handling health command", argc, args) + switch commandName { + case "health": + h.send(h.serverMessages, "OK") + case ".ack": + h.handleAckCommand(argc, args) + default: + h.send(h.serverMessages, dlog.Server.Error(h.user, + "Received unknown health command", commandName, argc, args)) + } + h.shutdown() +} diff --git a/internal/server/handlers/mapcommand.go b/internal/server/handlers/mapcommand.go index c3e600e..65e0ed8 100644 --- a/internal/server/handlers/mapcommand.go +++ b/internal/server/handlers/mapcommand.go @@ -14,18 +14,17 @@ type mapCommand struct { } // NewMapCommand returns a new server side mapreduce command. -func newMapCommand(serverHandler *ServerHandler, argc int, args []string) (mapCommand, *server.Aggregate, error) { - m := mapCommand{server: serverHandler} +func newMapCommand(serverHandler *ServerHandler, argc int, + args []string) (mapCommand, *server.Aggregate, error) { + m := mapCommand{server: serverHandler} queryStr := strings.Join(args[1:], " ") aggregate, err := server.NewAggregate(queryStr) if err != nil { return m, nil, err } - m.aggregate = aggregate return m, aggregate, nil - } func (m mapCommand) Start(ctx context.Context, aggregatedMessages chan<- string) { diff --git a/internal/server/handlers/readcommand.go b/internal/server/handlers/readcommand.go index b659c06..4728a55 100644 --- a/internal/server/handlers/readcommand.go +++ b/internal/server/handlers/readcommand.go @@ -7,8 +7,9 @@ import ( "sync" "time" + "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/io/fs" - "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/io/line" "github.com/mimecast/dtail/internal/lcontext" "github.com/mimecast/dtail/internal/omode" "github.com/mimecast/dtail/internal/regex" @@ -26,39 +27,45 @@ func newReadCommand(server *ServerHandler, mode omode.Mode) *readCommand { } } -func (r *readCommand) Start(ctx context.Context, lContext lcontext.LContext, argc int, args []string, retries int) { - re := regex.NewNoop() +func (r *readCommand) Start(ctx context.Context, ltx lcontext.LContext, + argc int, args []string, retries int) { + re := regex.NewNoop() if argc >= 4 { deserializedRegex, err := regex.Deserialize(strings.Join(args[2:], " ")) if err != nil { - r.server.sendServerMessage(logger.Error(r.server.user, commandParseWarning, err)) + r.server.send(r.server.serverMessages, dlog.Server.Error(r.server.user, + "Unable to parse command", err)) return } re = deserializedRegex } if argc < 3 { - r.server.sendServerWarnMessage(logger.Warn(r.server.user, commandParseWarning, args, argc)) + r.server.send(r.server.serverMessages, dlog.Server.Warn(r.server.user, + "Unable to parse command", args, argc)) return } - r.readGlob(ctx, lContext, args[1], re, retries) + r.readGlob(ctx, ltx, args[1], re, retries) } -func (r *readCommand) readGlob(ctx context.Context, lContext lcontext.LContext, glob string, re regex.Regex, retries int) { +func (r *readCommand) readGlob(ctx context.Context, ltx lcontext.LContext, + glob string, re regex.Regex, retries int) { + retryInterval := time.Second * 5 glob = filepath.Clean(glob) for retryCount := 0; retryCount < retries; retryCount++ { paths, err := filepath.Glob(glob) if err != nil { - logger.Warn(r.server.user, glob, err) + dlog.Server.Warn(r.server.user, glob, err) time.Sleep(retryInterval) continue } if numPaths := len(paths); numPaths == 0 { - logger.Error(r.server.user, "No such file(s) to read", glob) - r.server.sendServerWarnMessage(logger.Warn(r.server.user, "Unable to read file(s), check server logs")) + dlog.Server.Error(r.server.user, "No such file(s) to read", glob) + r.server.send(r.server.serverMessages, dlog.Server.Warn(r.server.user, + "Unable to read file(s), check server logs")) select { case <-ctx.Done(): return @@ -68,41 +75,44 @@ func (r *readCommand) readGlob(ctx context.Context, lContext lcontext.LContext, continue } - r.readFiles(ctx, lContext, paths, glob, re, retryInterval) + r.readFiles(ctx, ltx, paths, glob, re, retryInterval) return } - r.server.sendServerWarnMessage(logger.Warn(r.server.user, "Giving up to read file(s)")) + r.server.send(r.server.serverMessages, dlog.Server.Warn(r.server.user, + "Giving up to read file(s)")) return } -func (r *readCommand) readFiles(ctx context.Context, lContext lcontext.LContext, paths []string, glob string, re regex.Regex, retryInterval time.Duration) { +func (r *readCommand) readFiles(ctx context.Context, ltx lcontext.LContext, + paths []string, glob string, re regex.Regex, retryInterval time.Duration) { + var wg sync.WaitGroup wg.Add(len(paths)) - for _, path := range paths { - go r.readFileIfPermissions(ctx, lContext, &wg, path, glob, re) + go r.readFileIfPermissions(ctx, ltx, &wg, path, glob, re) } - wg.Wait() } -func (r *readCommand) readFileIfPermissions(ctx context.Context, lContext lcontext.LContext, wg *sync.WaitGroup, path, glob string, re regex.Regex) { +func (r *readCommand) readFileIfPermissions(ctx context.Context, ltx lcontext.LContext, + wg *sync.WaitGroup, path, glob string, re regex.Regex) { + defer wg.Done() globID := r.makeGlobID(path, glob) - if !r.server.user.HasFilePermission(path, "readfiles") { - logger.Error(r.server.user, "No permission to read file", path, globID) - r.server.sendServerWarnMessage(logger.Warn(r.server.user, "Unable to read file(s), check server logs")) + dlog.Server.Error(r.server.user, "No permission to read file", path, globID) + r.server.send(r.server.serverMessages, dlog.Server.Warn(r.server.user, + "Unable to read file(s), check server logs")) return } - - r.readFile(ctx, lContext, path, globID, re) + r.readFile(ctx, ltx, path, globID, re) } -func (r *readCommand) readFile(ctx context.Context, lContext lcontext.LContext, path, globID string, re regex.Regex) { - logger.Info(r.server.user, "Start reading file", path, globID) +func (r *readCommand) readFile(ctx context.Context, ltx lcontext.LContext, + path, globID string, re regex.Regex) { + dlog.Server.Info(r.server.user, "Start reading file", path, globID) var reader fs.FileReader switch r.mode { case omode.TailClient: @@ -114,15 +124,19 @@ func (r *readCommand) readFile(ctx context.Context, lContext lcontext.LContext, } lines := r.server.lines - - // Plug in mappreduce engine - if r.server.aggregate != nil { - lines = r.server.aggregate.Lines - } + aggregate := r.server.aggregate for { - if err := reader.Start(ctx, lContext, lines, re); err != nil { - logger.Error(r.server.user, path, globID, err) + if aggregate != nil { + lines = make(chan line.Line, 100) + aggregate.NextLinesCh <- lines + } + if err := reader.Start(ctx, ltx, lines, re); err != nil { + dlog.Server.Error(r.server.user, path, globID, err) + } + if aggregate != nil { + // Also makes aggregate to Flush + close(lines) } select { @@ -133,9 +147,8 @@ func (r *readCommand) readFile(ctx context.Context, lContext lcontext.LContext, return } } - time.Sleep(time.Second * 2) - logger.Info(path, globID, "Reading file again") + dlog.Server.Info(path, globID, "Reading file again") } } @@ -152,11 +165,11 @@ func (r *readCommand) makeGlobID(path, glob string) string { if len(idParts) > 0 { return strings.Join(idParts, "/") } - if len(pathParts) > 0 { return pathParts[len(pathParts)-1] } - r.server.sendServerWarnMessage(logger.Warn("Empty file path given?", path, glob)) + r.server.send(r.server.serverMessages, + dlog.Server.Warn("Empty file path given?", path, glob)) return "" } diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go index 39d5d5f..36574a9 100644 --- a/internal/server/handlers/serverhandler.go +++ b/internal/server/handlers/serverhandler.go @@ -2,69 +2,50 @@ package handlers import ( "context" - "encoding/base64" - "errors" - "fmt" - "io" "os" - "strconv" "strings" - "sync/atomic" - "time" "github.com/mimecast/dtail/internal" - "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/io/line" - "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/lcontext" - "github.com/mimecast/dtail/internal/mapr/server" "github.com/mimecast/dtail/internal/omode" user "github.com/mimecast/dtail/internal/user/server" - "github.com/mimecast/dtail/internal/version" -) - -const ( - commandParseWarning string = "Unable to parse command" ) // ServerHandler implements the Reader and Writer interfaces to handle // the Bi-directional communication between SSH client and server. // This handler implements the handler of the SSH server. type ServerHandler struct { - done *internal.Done - lines chan line.Line - regex string - aggregate *server.Aggregate - aggregatedMessages chan string - serverMessages chan string - payload []byte - hostname string - user *user.User - catLimiter chan struct{} - tailLimiter chan struct{} - ackCloseReceived chan struct{} - activeCommands int32 - activeReaders int32 - quiet bool + baseHandler + catLimiter chan struct{} + tailLimiter chan struct{} + regex string } // NewServerHandler returns the server handler. -func NewServerHandler(user *user.User, catLimiter, tailLimiter chan struct{}) *ServerHandler { +func NewServerHandler(user *user.User, catLimiter, + tailLimiter chan struct{}) *ServerHandler { + + dlog.Server.Debug(user, "Creating new server handler") h := ServerHandler{ - done: internal.NewDone(), - lines: make(chan line.Line, 100), - serverMessages: make(chan string, 10), - aggregatedMessages: make(chan string, 10), - ackCloseReceived: make(chan struct{}), - catLimiter: catLimiter, - tailLimiter: tailLimiter, - regex: ".", - user: user, - } + baseHandler: baseHandler{ + done: internal.NewDone(), + lines: make(chan line.Line, 100), + serverMessages: make(chan string, 10), + maprMessages: make(chan string, 10), + ackCloseReceived: make(chan struct{}), + user: user, + }, + catLimiter: catLimiter, + tailLimiter: tailLimiter, + regex: ".", + } + h.handleCommandCb = h.handleUserCommand fqdn, err := os.Hostname() if err != nil { - logger.FatalExit(err) + dlog.Server.FatalPanic(err) } s := strings.Split(fqdn, ".") @@ -73,374 +54,49 @@ func NewServerHandler(user *user.User, catLimiter, tailLimiter chan struct{}) *S return &h } -// Shutdown the handler. -func (h *ServerHandler) Shutdown() { - h.done.Shutdown() -} - -// Done channel of the handler. -func (h *ServerHandler) Done() <-chan struct{} { - return h.done.Done() -} - -// Read is to send data to the dtail client via Reader interface. -func (h *ServerHandler) Read(p []byte) (n int, err error) { - for { - select { - case message := <-h.serverMessages: - if len(message) == 0 { - logger.Warn(h.user, "Empty message received") - return - } - if message[0] == '.' { - // Handle hidden message (don't display to the user, interpreted by dtail client) - wholePayload := []byte(fmt.Sprintf("%s\n", message)) - n = copy(p, wholePayload) - return - } - - // Handle normal server message (display to the user) - wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message)) - n = copy(p, wholePayload) - return - - case message := <-h.aggregatedMessages: - // Send mapreduce-aggregated data as a message. - data := fmt.Sprintf("AGGREGATE➔%s➔%s\n", h.hostname, message) - wholePayload := []byte(data) - n = copy(p, wholePayload) - return - - case line := <-h.lines: - // Send normal file content data as a message. - serverInfo := []byte(fmt.Sprintf("REMOTE|%s|%3d|%v|%s|", - h.hostname, line.TransmittedPerc, line.Count, line.SourceID)) - wholePayload := append(serverInfo, line.Content[:]...) - n = copy(p, wholePayload) - return - - case <-time.After(time.Second): - // Once in a while check whether we are done. - select { - case <-h.done.Done(): - return 0, io.EOF - default: - } - } - } -} - -// Write is to receive data from the dtail client via Writer interface. -func (h *ServerHandler) Write(p []byte) (n int, err error) { - for _, c := range p { - switch c { - case ';': - commandStr := strings.TrimSpace(string(h.payload)) - h.handleCommand(commandStr) - h.payload = nil - default: - h.payload = append(h.payload, c) - } - } - - n = len(p) - return -} - -func (h *ServerHandler) handleCommand(commandStr string) { - logger.Debug(h.user, commandStr) - ctx := context.Background() - - args, argc, err := h.handleProtocolVersion(strings.Split(commandStr, " ")) - if err != nil { - h.send(h.serverMessages, logger.Error(h.user, err)) - return - } - - args, argc, err = h.handleBase64(args, argc) - if err != nil { - h.send(h.serverMessages, logger.Error(h.user, err)) - return - } - - if h.user.Name == config.ControlUser { - h.handleControlCommand(argc, args) - return - } - - ctx, cancel := context.WithCancel(ctx) - go func() { - <-h.done.Done() - cancel() - }() - - h.handleUserCommand(ctx, argc, args) -} - -func (h *ServerHandler) handleProtocolVersion(args []string) ([]string, int, error) { - argc := len(args) - - if argc <= 2 || args[0] != "protocol" { - return args, argc, errors.New("unable to determine protocol version") - } - - if args[1] != version.ProtocolCompat { - err := fmt.Errorf("server with protocol version '%s' but client with '%s', please update DTail", version.ProtocolCompat, args[1]) - return args, argc, err - } - - return args[2:], argc - 2, nil -} - -func (h *ServerHandler) handleBase64(args []string, argc int) ([]string, int, error) { - err := errors.New("Unable to decode client message, DTail server and client versions may not be compatible") - - if argc != 2 || args[0] != "base64" { - return args, argc, err - } - - decoded, err := base64.StdEncoding.DecodeString(args[1]) - if err != nil { - return args, argc, err - } - decodedStr := string(decoded) - - args = strings.Split(decodedStr, " ") - argc = len(decodedStr) - logger.Trace(h.user, "Base64 decoded received command", decodedStr, argc, args) - - return args, argc, nil -} - -func (h *ServerHandler) handleControlCommand(argc int, args []string) { - switch args[0] { - case "debug": - h.send(h.serverMessages, logger.Debug(h.user, "Receiving debug command", argc, args)) - default: - logger.Warn(h.user, "Received unknown control command", argc, args) - } -} - -func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []string) { - logger.Debug(h.user, "handleUserCommand", argc, args) +func (h *ServerHandler) handleUserCommand(ctx context.Context, ltx lcontext.LContext, + argc int, args []string, commandName string) { + dlog.Server.Debug(h.user, "Handling user command", argc, args) h.incrementActiveCommands() commandFinished := func() { if h.decrementActiveCommands() == 0 { h.shutdown() } } - readerFinished := func() { - if h.decrementActiveReaders() == 0 { - if h.aggregate == nil { - return - } - h.aggregate.Shutdown() - } - } - - splitted := strings.Split(args[0], ":") - commandName := splitted[0] - - options, lContext, err := readOptions(splitted[1:]) - if err != nil { - h.sendServerMessage(logger.Error(h.user, err)) - commandFinished() - return - } - if quiet, ok := options["quiet"]; ok { - if quiet == "true" { - logger.Debug(h.user, "Enabling quiet mode") - h.quiet = true - } - } switch commandName { case "grep", "cat": command := newReadCommand(h, omode.CatClient) go func() { - h.incrementActiveReaders() - command.Start(ctx, lContext, argc, args, 1) - readerFinished() + command.Start(ctx, ltx, argc, args, 1) commandFinished() }() - case "tail": command := newReadCommand(h, omode.TailClient) go func() { - h.incrementActiveReaders() - command.Start(ctx, lContext, argc, args, 10) - readerFinished() + command.Start(ctx, ltx, argc, args, 10) commandFinished() }() - case "map": command, aggregate, err := newMapCommand(h, argc, args) if err != nil { - h.sendServerMessage(err.Error()) - logger.Error(h.user, err) + h.send(h.serverMessages, err.Error()) + dlog.Server.Error(h.user, err) commandFinished() return } - h.aggregate = aggregate go func() { - command.Start(ctx, h.aggregatedMessages) + command.Start(ctx, h.maprMessages) commandFinished() }() - - case "ack", ".ack": + case ".ack": h.handleAckCommand(argc, args) commandFinished() - default: - h.sendServerMessage(logger.Error(h.user, "Received unknown user command", commandName, argc, args, options)) + h.send(h.serverMessages, dlog.Server.Error(h.user, + "Received unknown user command", commandName, argc, args)) commandFinished() } } - -func (h *ServerHandler) handleAckCommand(argc int, args []string) { - if argc < 3 { - h.sendServerWarnMessage(logger.Warn(h.user, commandParseWarning, args, argc)) - return - } - if args[1] == "close" && args[2] == "connection" { - close(h.ackCloseReceived) - } -} - -func (h *ServerHandler) send(ch chan<- string, message string) { - select { - case ch <- message: - case <-h.done.Done(): - } -} - -func (h *ServerHandler) sendServerMessage(message string) { - h.send(h.serverMessageC(), message) -} - -func (h *ServerHandler) sendServerWarnMessage(message string) { - if h.quiet { - return - } - h.send(h.serverMessageC(), message) -} - -func (h *ServerHandler) serverMessageC() chan<- string { - return h.serverMessages -} - -func (h *ServerHandler) flush() { - logger.Debug(h.user, "flush()") - - if h.aggregate != nil { - h.aggregate.Flush() - } - - unsentMessages := func() int { - return len(h.lines) + len(h.serverMessages) + len(h.aggregatedMessages) - } - for i := 0; i < 3; i++ { - if unsentMessages() == 0 { - logger.Debug(h.user, "All lines sent") - return - } - logger.Debug(h.user, "Still lines to be sent") - time.Sleep(time.Second) - } - - logger.Warn(h.user, "Some lines remain unsent", unsentMessages()) -} - -func (h *ServerHandler) shutdown() { - logger.Debug(h.user, "shutdown()") - h.flush() - - go func() { - select { - case h.serverMessageC() <- ".syn close connection": - case <-h.done.Done(): - } - }() - - select { - case <-h.ackCloseReceived: - case <-time.After(time.Second * 5): - logger.Debug(h.user, "Shutdown timeout reached, enforcing shutdown") - case <-h.done.Done(): - } - - h.done.Shutdown() -} - -func (h *ServerHandler) incrementActiveCommands() { - atomic.AddInt32(&h.activeCommands, 1) -} - -func (h *ServerHandler) decrementActiveCommands() int32 { - atomic.AddInt32(&h.activeCommands, -1) - return atomic.LoadInt32(&h.activeCommands) -} - -func (h *ServerHandler) incrementActiveReaders() { - atomic.AddInt32(&h.activeReaders, 1) -} - -func (h *ServerHandler) decrementActiveReaders() int32 { - atomic.AddInt32(&h.activeReaders, -1) - return atomic.LoadInt32(&h.activeReaders) -} - -// TODO: All options related code should be in its own package (client + server) -func readOptions(opts []string) (map[string]string, lcontext.LContext, error) { - options := make(map[string]string, len(opts)) - // Local search context - var lContext lcontext.LContext - - for _, o := range opts { - kv := strings.SplitN(o, "=", 2) - if len(kv) != 2 { - continue - } - key := kv[0] - val := kv[1] - - if strings.HasPrefix(val, "base64%") { - s := strings.SplitN(val, "%", 2) - decoded, err := base64.StdEncoding.DecodeString(s[1]) - if err != nil { - return options, lContext, err - } - val = string(decoded) - } - - switch key { - case "before": - iVal, err := strconv.Atoi(val) - if err != nil { - logger.Error(err) - continue - } - lContext.BeforeContext = iVal - case "after": - iVal, err := strconv.Atoi(val) - if err != nil { - logger.Error(err) - continue - } - lContext.AfterContext = iVal - case "max": - iVal, err := strconv.Atoi(val) - if err != nil { - logger.Error(err) - continue - } - lContext.MaxCount = iVal - default: - options[key] = val - } - } - - return options, lContext, nil -} diff --git a/internal/server/scheduler.go b/internal/server/scheduler.go index a1e9e36..0ba65f7 100644 --- a/internal/server/scheduler.go +++ b/internal/server/scheduler.go @@ -10,25 +10,23 @@ import ( "github.com/mimecast/dtail/internal/clients" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/omode" gossh "golang.org/x/crypto/ssh" ) -type scheduler struct { -} +type scheduler struct{} func newScheduler() *scheduler { return &scheduler{} } func (s *scheduler) start(ctx context.Context) { - logger.Info("Starting scheduled job runner after 10s") + dlog.Server.Info("Starting scheduled job runner after 10s") // First run after just 10s! time.Sleep(time.Second * 10) s.runJobs(ctx) - for { select { case <-time.After(time.Minute): @@ -42,27 +40,24 @@ func (s *scheduler) start(ctx context.Context) { func (s *scheduler) runJobs(ctx context.Context) { for _, job := range config.Server.Schedule { if !job.Enable { - logger.Debug(job.Name, "Not running job as not enabled") + dlog.Server.Debug(job.Name, "Not running job as not enabled") continue } - hour, err := strconv.Atoi(time.Now().Format("15")) if err != nil { - logger.Error(job.Name, "Unable to create job", err) + dlog.Server.Error(job.Name, "Unable to create job", err) continue } - if hour < job.TimeRange[0] || hour >= job.TimeRange[1] { - logger.Debug(job.Name, "Not running job out of time range") + dlog.Server.Debug(job.Name, "Not running job out of time range") continue } files := fillDates(job.Files) outfile := fillDates(job.Outfile) - _, err = os.Stat(outfile) if !os.IsNotExist(err) { - logger.Debug(job.Name, "Not running job as outfile already exists", outfile) + dlog.Server.Debug(job.Name, "Not running job as outfile already exists", outfile) continue } @@ -70,9 +65,8 @@ func (s *scheduler) runJobs(ctx context.Context) { if servers == "" { servers = config.Server.SSHBindAddress } - - args := clients.Args{ - ConnectionsPerCPU: 10, + args := config.Args{ + ConnectionsPerCPU: config.DefaultConnectionsPerCPU, Discovery: job.Discovery, ServersStr: servers, What: files, @@ -81,25 +75,24 @@ func (s *scheduler) runJobs(ctx context.Context) { } args.SSHAuthMethods = append(args.SSHAuthMethods, gossh.Password(job.Name)) - - query := fmt.Sprintf("%s outfile %s", job.Query, outfile) - client, err := clients.NewMaprClient(args, query, clients.CumulativeMode) + args.QueryStr = fmt.Sprintf("%s outfile %s", job.Query, outfile) + client, err := clients.NewMaprClient(args, clients.CumulativeMode) if err != nil { - logger.Error(fmt.Sprintf("Unable to create job %s", job.Name), err) + dlog.Server.Error(fmt.Sprintf("Unable to create job %s", job.Name), err) continue } jobCtx, cancel := context.WithCancel(ctx) defer cancel() - logger.Info(fmt.Sprintf("Starting job %s", job.Name)) + dlog.Server.Info(fmt.Sprintf("Starting job %s", job.Name)) status := client.Start(jobCtx, make(chan string)) logMessage := fmt.Sprintf("Job exited with status %d", status) if status != 0 { - logger.Warn(logMessage) + dlog.Server.Warn(logMessage) continue } - logger.Info(logMessage) + dlog.Server.Info(logMessage) } } diff --git a/internal/server/server.go b/internal/server/server.go index 3640208..0cb5e27 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -9,7 +9,7 @@ import ( "strings" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/server/handlers" "github.com/mimecast/dtail/internal/ssh/server" user "github.com/mimecast/dtail/internal/user/server" @@ -24,9 +24,9 @@ type Server struct { stats stats // SSH server configuration. sshServerConfig *gossh.ServerConfig - // To control the max amount of concurrent cats (which can cause a lot of I/O on the server) + // To control the max amount of concurrent cats. catLimiter chan struct{} - // To control the max amount of concurrent tails + // To control the max amount of concurrent tails. tailLimiter chan struct{} // To run scheduled tasks (if configured) sched *scheduler @@ -36,7 +36,7 @@ type Server struct { // New returns a new server. func New() *Server { - logger.Info("Creating server", version.String()) + dlog.Server.Info("Creating server", version.String()) s := Server{ sshServerConfig: &gossh.ServerConfig{}, @@ -51,7 +51,7 @@ func New() *Server { private, err := gossh.ParsePrivateKey(server.PrivateHostKey()) if err != nil { - logger.FatalExit(err) + dlog.Server.FatalPanic(err) } s.sshServerConfig.AddHostKey(private) @@ -60,14 +60,13 @@ func New() *Server { // Start the server. func (s *Server) Start(ctx context.Context) int { - logger.Info("Starting server") - + dlog.Server.Info("Starting server") bindAt := fmt.Sprintf("%s:%d", config.Server.SSHBindAddress, config.Common.SSHPort) - logger.Info("Binding server", bindAt) + dlog.Server.Info("Binding server", bindAt) listener, err := net.Listen("tcp", bindAt) if err != nil { - logger.FatalExit("Failed to open listening TCP socket", err) + dlog.Server.FatalPanic("Failed to open listening TCP socket", err) } go s.stats.start(ctx) @@ -76,14 +75,12 @@ func (s *Server) Start(ctx context.Context) int { go s.listenerLoop(ctx, listener) <-ctx.Done() - // For future use. return 0 } func (s *Server) listenerLoop(ctx context.Context, listener net.Listener) { - logger.Debug("Starting listener loop") - + dlog.Server.Debug("Starting listener loop") for { conn, err := listener.Accept() // Blocking if err != nil { @@ -92,63 +89,69 @@ func (s *Server) listenerLoop(ctx context.Context, listener net.Listener) { return default: } - logger.Error("Failed to accept incoming connection", err) + dlog.Server.Error("Failed to accept incoming connection", err) continue } if err := s.stats.serverLimitExceeded(); err != nil { - logger.Error(err) + dlog.Server.Error(err) conn.Close() continue } - go s.handleConnection(ctx, conn) } } func (s *Server) handleConnection(ctx context.Context, conn net.Conn) { - logger.Info("Handling connection") + dlog.Server.Info("Handling connection") sshConn, chans, reqs, err := gossh.NewServerConn(conn, s.sshServerConfig) if err != nil { - logger.Error("Something just happened", err) + dlog.Server.Error("Something just happened", err) return } s.stats.incrementConnections() - go gossh.DiscardRequests(reqs) for newChannel := range chans { go s.handleChannel(ctx, sshConn, newChannel) } } -func (s *Server) handleChannel(ctx context.Context, sshConn gossh.Conn, newChannel gossh.NewChannel) { - user := user.New(sshConn.User(), sshConn.RemoteAddr().String()) - logger.Info(user, "Invoking channel handler") +func (s *Server) handleChannel(ctx context.Context, sshConn gossh.Conn, + newChannel gossh.NewChannel) { + user, err := user.New(sshConn.User(), sshConn.RemoteAddr().String()) + if err != nil { + dlog.Server.Error(user, err) + newChannel.Reject(gossh.Prohibited, err.Error()) + return + } + + dlog.Server.Info(user, "Invoking channel handler") if newChannel.ChannelType() != "session" { err := errors.New("Don'w allow other channel types than session") - logger.Error(user, err) + dlog.Server.Error(user, err) newChannel.Reject(gossh.Prohibited, err.Error()) return } channel, requests, err := newChannel.Accept() if err != nil { - logger.Error(user, "Could not accept channel", err) + dlog.Server.Error(user, "Could not accept channel", err) return } if err := s.handleRequests(ctx, sshConn, requests, channel, user); err != nil { - logger.Error(user, "While handling request", err) + dlog.Server.Error(user, err) sshConn.Close() } } -func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-chan *gossh.Request, channel gossh.Channel, user *user.User) error { - logger.Info(user, "Invoking request handler") +func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, + in <-chan *gossh.Request, channel gossh.Channel, user *user.User) error { + dlog.Server.Info(user, "Invoking request handler") for req := range in { var payload = struct{ Value string }{} gossh.Unmarshal(req.Payload, &payload) @@ -157,12 +160,11 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-ch case "shell": var handler handlers.Handler switch user.Name { - case config.ControlUser: - handler = handlers.NewControlHandler(user) + case config.HealthUser: + handler = handlers.NewHealthHandler(user) default: handler = handlers.NewServerHandler(user, s.catLimiter, s.tailLimiter) } - terminate := func() { handler.Shutdown() sshConn.Close() @@ -173,13 +175,11 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-ch io.Copy(channel, handler) terminate() }() - go func() { // Broken pipe, cancel io.Copy(handler, channel) terminate() }() - go func() { select { case <-ctx.Done(): @@ -187,62 +187,61 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-ch } terminate() }() - go func() { if err := sshConn.Wait(); err != nil && err != io.EOF { - // Use of closed network connection. - logger.Debug(user, "While waiting for ssh connection", err) + dlog.Server.Error(user, err) } s.stats.decrementConnections() - logger.Info(user, "Good bye Mister!") + dlog.Server.Info(user, "Good bye Mister!") terminate() }() // Only serving shell type req.Reply(true, nil) - default: req.Reply(false, nil) - - return fmt.Errorf("Closing SSH connection as unknown request received|%s|%v", + return fmt.Errorf("Closing SSH connection as unknown request recieved|%s|%v", req.Type, payload.Value) } } - return nil } // Callback for SSH authentication. -func (s *Server) Callback(c gossh.ConnMetadata, authPayload []byte) (*gossh.Permissions, error) { - user := user.New(c.User(), c.RemoteAddr().String()) +func (s *Server) Callback(c gossh.ConnMetadata, + authPayload []byte) (*gossh.Permissions, error) { + + user, err := user.New(c.User(), c.RemoteAddr().String()) + if err != nil { + return nil, err + } if config.ServerRelaxedAuthEnable { - logger.Fatal(user, "Granting permissions via relaxed-auth") + dlog.Server.Fatal(user, "Granting permissions via relaxed-auth") return nil, nil } authInfo := string(authPayload) - splitted := strings.Split(c.RemoteAddr().String(), ":") remoteIP := splitted[0] switch user.Name { - case config.ControlUser: - if authInfo == config.ControlUser { - logger.Debug(user, "Granting permissions to control user") + case config.HealthUser: + if authInfo == config.HealthUser { + dlog.Server.Debug(user, "Granting permissions to health user") return nil, nil } case config.ScheduleUser: for _, job := range config.Server.Schedule { if s.backgroundCanSSH(user, authInfo, remoteIP, job.Name, job.AllowFrom) { - logger.Debug(user, "Granting SSH connection") + dlog.Server.Debug(user, "Granting SSH connection") return nil, nil } } case config.ContinuousUser: for _, job := range config.Server.Continuous { if s.backgroundCanSSH(user, authInfo, remoteIP, job.Name, job.AllowFrom) { - logger.Debug(user, "Granting SSH connection") + dlog.Server.Debug(user, "Granting SSH connection") return nil, nil } } @@ -252,23 +251,26 @@ func (s *Server) Callback(c gossh.ConnMetadata, authPayload []byte) (*gossh.Perm return nil, fmt.Errorf("user %s not authorized", user) } -func (s *Server) backgroundCanSSH(user *user.User, jobName, remoteIP, allowedJobName string, allowFrom []string) bool { - logger.Debug("backgroundCanSSH", user, jobName, remoteIP, allowedJobName, allowFrom) +func (s *Server) backgroundCanSSH(user *user.User, jobName, remoteIP, + allowedJobName string, allowFrom []string) bool { + dlog.Server.Debug("backgroundCanSSH", user, jobName, remoteIP, allowedJobName, allowFrom) if jobName != allowedJobName { - logger.Debug(user, jobName, "backgroundCanSSH", "Job name does not match, skipping to next one...", allowedJobName) + dlog.Server.Debug(user, jobName, "backgroundCanSSH", + "Job name does not match, skipping to next one...", allowedJobName) return false } for _, myAddr := range allowFrom { ips, err := net.LookupIP(myAddr) if err != nil { - logger.Debug(user, jobName, "backgroundCanSSH", "Unable to lookup IP address for allowed hosts lookup, skipping to next one...", myAddr, err) + dlog.Server.Debug(user, jobName, "backgroundCanSSH", "Unable to lookup IP "+ + "address for allowed hosts lookup, skipping to next one...", myAddr, err) continue } - for _, ip := range ips { - logger.Debug(user, jobName, "backgroundCanSSH", "Comparing IP addresses", remoteIP, ip.String()) + dlog.Server.Debug(user, jobName, "backgroundCanSSH", "Comparing IP addresses", + remoteIP, ip.String()) if remoteIP == ip.String() { return true } diff --git a/internal/server/stats.go b/internal/server/stats.go index ac579ad..99a644a 100644 --- a/internal/server/stats.go +++ b/internal/server/stats.go @@ -3,12 +3,11 @@ package server import ( "context" "fmt" - "runtime" "sync" "time" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/io/dlog" ) // Used to collect and display various server stats. @@ -20,7 +19,6 @@ type stats struct { func (s *stats) incrementConnections() { defer s.logServerStats() - s.mutex.Lock() s.currentConnections++ s.lifetimeConnections++ @@ -29,7 +27,6 @@ func (s *stats) incrementConnections() { func (s *stats) decrementConnections() { defer s.logServerStats() - s.mutex.Lock() s.currentConnections-- s.mutex.Unlock() @@ -41,8 +38,8 @@ func (s *stats) hasConnections() bool { s.mutex.Unlock() has := currentConnections > 0 - logger.Info("stats", "Server with open connections?", has, currentConnections) - + dlog.Server.Info("stats", "Server with open connections?", + has, currentConnections) return has } @@ -50,10 +47,10 @@ func (s *stats) logServerStats() { s.mutex.Lock() defer s.mutex.Unlock() - currentConnections := fmt.Sprintf("currentConnections=%d", s.currentConnections) - lifetimeConnections := fmt.Sprintf("lifetimeConnections=%d", s.lifetimeConnections) - goroutines := fmt.Sprintf("goroutines=%d", runtime.NumGoroutine()) - logger.Info("stats", currentConnections, lifetimeConnections, goroutines) + data := make(map[string]interface{}) + data["currentConnections"] = s.currentConnections + data["lifetimeConnections"] = s.lifetimeConnections + dlog.Server.Mapreduce("STATS", data) } func (s *stats) serverLimitExceeded() error { @@ -61,9 +58,9 @@ func (s *stats) serverLimitExceeded() error { defer s.mutex.Unlock() if s.currentConnections >= config.Server.MaxConnections { - return fmt.Errorf("Exceeded max allowed concurrent connections of %d", config.Server.MaxConnections) + return fmt.Errorf("Exceeded max allowed concurrent connections of %d", + config.Server.MaxConnections) } - return nil } |
