summaryrefslogtreecommitdiff
path: root/internal/server
diff options
context:
space:
mode:
authorPaul Bütow <pbuetow@mimecast.com>2020-01-20 18:41:05 +0000
committerPaul Bütow <pbuetow@mimecast.com>2020-01-21 14:35:23 +0000
commitc128865c4c7411c29a59fca9a3a2f95537686d7b (patch)
tree193bccc70d942c8b70cc93fae2670263701e43aa /internal/server
parent3755a9911ecb05886577095f2b8cc8b9e4066a3a (diff)
Move commands to cmd/ and move internal dependencies to internal/
Diffstat (limited to 'internal/server')
-rw-r--r--internal/server/handlers/controlhandler.go106
-rw-r--r--internal/server/handlers/handler.go10
-rw-r--r--internal/server/handlers/serverhandler.go492
-rw-r--r--internal/server/server.go212
-rw-r--r--internal/server/stats.go88
5 files changed, 908 insertions, 0 deletions
diff --git a/internal/server/handlers/controlhandler.go b/internal/server/handlers/controlhandler.go
new file mode 100644
index 0000000..482f759
--- /dev/null
+++ b/internal/server/handlers/controlhandler.go
@@ -0,0 +1,106 @@
+package handlers
+
+import (
+ "fmt"
+ "io"
+ "os"
+ "strings"
+
+ "github.com/mimecast/dtail/internal/logger"
+ user "github.com/mimecast/dtail/internal/user/server"
+)
+
+// ControlHandler is used for control functions and health monitoring.
+type ControlHandler struct {
+ serverMessages chan string
+ pong chan struct{}
+ stop chan struct{}
+ payload []byte
+ hostname string
+ user *user.User
+}
+
+// NewControlHandler returns a new control handler.
+func NewControlHandler(user *user.User) *ControlHandler {
+ logger.Debug(user, "Creating control handler")
+
+ h := ControlHandler{
+ serverMessages: make(chan string, 10),
+ pong: make(chan struct{}, 10),
+ stop: make(chan struct{}),
+ user: user,
+ }
+
+ fqdn, err := os.Hostname()
+ if err != nil {
+ logger.FatalExit(err)
+ }
+
+ s := strings.Split(fqdn, ".")
+ h.hostname = s[0]
+ return &h
+}
+
+// Read is to send data to the client via the Reader interface.
+func (h *ControlHandler) Read(p []byte) (n int, err error) {
+ for {
+ select {
+ case message := <-h.serverMessages:
+ wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message))
+ n = copy(p, wholePayload)
+ return
+ case <-h.pong:
+ logger.Info(h.user, "Sending pong")
+ n = copy(p, []byte(".pong\n"))
+ return
+ case <-h.stop:
+ return 0, io.EOF
+ }
+ }
+}
+
+// Write is to read data to the client via the Writer interface.
+func (h *ControlHandler) Write(p []byte) (n int, err error) {
+ for _, c := range p {
+ switch c {
+ case ';':
+ wholePayload := strings.TrimSpace(string(h.payload))
+ h.handleCommand(wholePayload)
+ h.payload = nil
+
+ default:
+ h.payload = append(h.payload, c)
+ }
+ }
+
+ n = len(p)
+ return
+}
+
+// Close the control handler.
+func (h *ControlHandler) Close() {
+ close(h.stop)
+}
+
+// Wait returns the handler stop channel.
+func (h *ControlHandler) Wait() <-chan struct{} {
+ return h.stop
+}
+
+func (h *ControlHandler) handleCommand(command string) {
+ logger.Info(h.user, command)
+ s := strings.Split(command, " ")
+ logger.Debug(h.user, "Receiving command", command, s)
+
+ switch s[0] {
+ case "health":
+ h.serverMessages <- "OK: DTail SSH Server seems fine"
+ h.serverMessages <- "done;"
+ case "ping":
+ h.pong <- struct{}{}
+ case "debug":
+ h.serverMessages <- logger.Debug(h.user, "Receiving debug command", command, s)
+ default:
+ h.serverMessages <- logger.Warn(h.user, "Received unknown command", command, s)
+ }
+}
diff --git a/internal/server/handlers/handler.go b/internal/server/handlers/handler.go
new file mode 100644
index 0000000..8b1f73e
--- /dev/null
+++ b/internal/server/handlers/handler.go
@@ -0,0 +1,10 @@
+package handlers
+
+import "io"
+
+// Handler interface for server side functionality.
+type Handler interface {
+ io.ReadWriter
+ Close()
+ Wait() <-chan struct{}
+}
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go
new file mode 100644
index 0000000..bed8609
--- /dev/null
+++ b/internal/server/handlers/serverhandler.go
@@ -0,0 +1,492 @@
+package handlers
+
+import (
+ "fmt"
+ "io"
+ "os"
+ "path/filepath"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/fs"
+ "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/mapr/server"
+ "github.com/mimecast/dtail/internal/omode"
+ user "github.com/mimecast/dtail/internal/user/server"
+ "github.com/mimecast/dtail/internal/version"
+)
+
+const (
+ commandParseWarning string = "Unable to parse command"
+)
+
+// ServerHandler implements the Reader and Writer interfaces to handle
+// the Bi-directional communication between SSH client and server.
+// This handler implements the handler of the SSH server.
+type ServerHandler struct {
+ // Local log file readers
+ fileReaders []fs.FileReader
+ fileReadersMtx *sync.Mutex
+ // Channel for read lines.
+ lines chan fs.LineRead
+ // Only process log lines matching this regex.
+ regex string
+ // Server side mapr log aggregation.
+ aggregate *server.Aggregate
+ // Channel of aggregated log lines.
+ aggregatedMessages chan string
+ // Channel for server messages to be sent to the client.
+ serverMessages chan string
+ // Channel for hidden messages to be sent to the client.
+ hiddenMessages chan string
+ // The current payload sent to the client.
+ payload []byte
+ // The current server hostname.
+ hostname string
+ // The user connecting to dtail.
+ user *user.User
+ // To limit the server wide max amount of concurrent cats
+ catLimiter chan struct{}
+ // To limit the server wide max amount of concurrent tails
+ tailLimiter chan struct{}
+ // Server can tell handler to stop the handler.
+ stop chan struct{}
+ // Indicate that client responded to server with "ack stop connection"
+ ackStopReceived chan struct{}
+ // Stop timeout.
+ stopTimeout chan struct{}
+}
+
+// NewServerHandler returns the server handler.
+func NewServerHandler(user *user.User, catLimiter chan struct{}, tailLimiter chan struct{}) *ServerHandler {
+ logger.Debug(user, "Creating tail handler")
+ h := ServerHandler{
+ fileReadersMtx: &sync.Mutex{},
+ lines: make(chan fs.LineRead, 100),
+ serverMessages: make(chan string, 10),
+ aggregatedMessages: make(chan string, 10),
+ hiddenMessages: make(chan string, 10),
+ ackStopReceived: make(chan struct{}),
+ stopTimeout: make(chan struct{}),
+ stop: make(chan struct{}),
+ catLimiter: catLimiter,
+ tailLimiter: tailLimiter,
+ regex: ".",
+ user: user,
+ }
+
+ fqdn, err := os.Hostname()
+ if err != nil {
+ logger.FatalExit(err)
+ }
+
+ s := strings.Split(fqdn, ".")
+ h.hostname = s[0]
+
+ return &h
+}
+
+// Read is to send data to the dtail client via Reader interface.
+func (h *ServerHandler) Read(p []byte) (n int, err error) {
+ for {
+ select {
+ case message := <-h.serverMessages:
+ wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message))
+ n = copy(p, wholePayload)
+ return
+ case message := <-h.aggregatedMessages:
+ data := fmt.Sprintf("AGGREGATE|%s|%s\n", h.hostname, message)
+ //logger.Debug("Sending aggregation data", data)
+ wholePayload := []byte(data)
+ n = copy(p, wholePayload)
+ return
+ case message := <-h.hiddenMessages:
+ //logger.Debug(h.user, "Sending hidden message", message)
+ wholePayload := []byte(fmt.Sprintf(".%s\n", message))
+ n = copy(p, wholePayload)
+ return
+ case line := <-h.lines:
+ serverInfo := []byte(fmt.Sprintf("REMOTE|%s|%3d|%v|%s|",
+ h.hostname, line.TransmittedPerc, line.Count, *line.GlobID))
+ wholePayload := append(serverInfo, line.Content[:]...)
+ n = copy(p, wholePayload)
+ return
+ case <-time.After(time.Second):
+ select {
+ case <-h.stop:
+ return 0, io.EOF
+ default:
+ }
+ }
+ }
+}
+
+// Write is to receive data from the dtail client via Writer interface.
+func (h *ServerHandler) Write(p []byte) (n int, err error) {
+ for _, c := range p {
+ switch c {
+ case ';':
+ commandStr := strings.TrimSpace(string(h.payload))
+ h.handleCommand(commandStr)
+ h.payload = nil
+ default:
+ h.payload = append(h.payload, c)
+ }
+ }
+
+ n = len(p)
+ return
+}
+
+// Close the server handler.
+func (h *ServerHandler) Close() {
+ h.fileReadersMtx.Lock()
+ defer h.fileReadersMtx.Unlock()
+
+ for _, reader := range h.fileReaders {
+ reader.Stop()
+ }
+ if h.aggregate != nil {
+ h.aggregate.Close()
+ }
+
+ close(h.stop)
+}
+
+func (h *ServerHandler) makeGlobID(path, glob string) string {
+ var idParts []string
+ pathParts := strings.Split(path, "/")
+
+ for i, globPart := range strings.Split(glob, "/") {
+ if strings.Contains(globPart, "*") {
+ idParts = append(idParts, pathParts[i])
+ }
+ }
+
+ if len(idParts) > 0 {
+ return strings.Join(idParts, "/")
+ }
+
+ if len(pathParts) > 0 {
+ return pathParts[len(pathParts)-1]
+ }
+
+ h.send(h.serverMessages, logger.Error("Empty file path given?", path, glob))
+ return ""
+}
+
+func (h *ServerHandler) processFileGlob(mode omode.Mode, glob string, regex string) {
+ retryInterval := time.Second * 5
+ glob = filepath.Clean(glob)
+
+ errors := make(chan struct{})
+ stop := make(chan struct{})
+ defer close(stop)
+
+ go func() {
+ for {
+ select {
+ case <-errors:
+ h.send(h.serverMessages, logger.Warn(h.user, "Unable to read file(s), check server logs"))
+ case <-stop:
+ return
+ case <-h.stop:
+ return
+ }
+ }
+ }()
+
+ maxRetries := 10
+ for {
+ maxRetries--
+ if maxRetries < 0 {
+ h.send(h.serverMessages, logger.Warn(h.user, "Giving up to read file(s)"))
+ h.internalClose()
+ return
+ }
+
+ paths, err := filepath.Glob(glob)
+ if err != nil {
+ logger.Warn(h.user, glob, err)
+ time.Sleep(retryInterval)
+ continue
+ }
+
+ if numPaths := len(paths); numPaths == 0 {
+ logger.Error(h.user, "No such file(s) to read", glob)
+ select {
+ case errors <- struct{}{}:
+ case <-h.stop:
+ return
+ default:
+ }
+ time.Sleep(retryInterval)
+ continue
+ }
+
+ h.startReadingFiles(mode, paths, glob, regex, retryInterval, errors)
+ break
+ }
+}
+
+func (h *ServerHandler) startReadingFiles(mode omode.Mode, paths []string, glob string, regex string, retryInterval time.Duration, errors chan<- struct{}) {
+ var wg sync.WaitGroup
+ wg.Add(len(paths))
+
+ read := func(path string, wg *sync.WaitGroup) {
+ defer wg.Done()
+ globID := h.makeGlobID(path, glob)
+
+ if !h.user.HasFilePermission(path) {
+ logger.Error(h.user, "No permission to read file", path, globID)
+ select {
+ case errors <- struct{}{}:
+ default:
+ }
+ return
+ }
+
+ h.startReadingFile(mode, path, globID, regex)
+ }
+
+ for _, path := range paths {
+ go read(path, &wg)
+ }
+
+ wg.Wait()
+}
+
+func (h *ServerHandler) startReadingFile(mode omode.Mode, path, globID, regex string) {
+ defer h.stopReadingFile(path)
+ logger.Info(h.user, "Start reading file", path, globID)
+
+ var reader fs.FileReader
+ switch mode {
+ case omode.TailClient:
+ reader = fs.NewTailFile(path, globID, h.serverMessages, h.tailLimiter)
+ case omode.GrepClient:
+ fallthrough
+ case omode.CatClient:
+ reader = fs.NewCatFile(path, globID, h.serverMessages, h.catLimiter)
+ default:
+ reader = fs.NewTailFile(path, globID, h.serverMessages, h.tailLimiter)
+ }
+
+ h.fileReadersMtx.Lock()
+ h.fileReaders = append(h.fileReaders, reader)
+ h.fileReadersMtx.Unlock()
+
+ lines := h.lines
+ // Plugin mappreduce engine
+ if h.aggregate != nil {
+ lines = h.aggregate.Lines
+ }
+
+ for {
+ if err := reader.Start(lines, regex); err != nil {
+ logger.Error(h.user, path, globID, err)
+ }
+
+ select {
+ case <-h.stop:
+ return
+ default:
+ if !reader.Retry() {
+ return
+ }
+ }
+
+ time.Sleep(time.Second * 2)
+ logger.Info(path, globID, "Reading file again")
+ }
+}
+
+func (h *ServerHandler) stopReadingFile(path string) {
+ logger.Info(h.user, "Stop reading file", path)
+
+ h.fileReadersMtx.Lock()
+ defer h.fileReadersMtx.Unlock()
+
+ path = filepath.Clean(path)
+ var fileReaders []fs.FileReader
+
+ for _, reader := range h.fileReaders {
+ if reader.FilePath() == path {
+ reader.Stop()
+ continue
+ }
+ fileReaders = append(fileReaders, reader)
+ }
+
+ if len(fileReaders) == len(h.fileReaders) {
+ logger.Warn(h.user, "Didn't read file path", path)
+ return
+ }
+
+ h.fileReaders = fileReaders
+
+ if len(fileReaders) == 0 {
+ if h.aggregate != nil {
+ h.aggregate.Serialize()
+ }
+ h.allLinesSent()
+ }
+}
+
+func (h *ServerHandler) numUnsentMessages() int {
+ return len(h.lines) + len(h.serverMessages) + len(h.hiddenMessages) + len(h.aggregatedMessages)
+}
+
+func (h *ServerHandler) allLinesSent() {
+ defer h.internalClose()
+
+ for i := 0; i < 3; i++ {
+ if h.numUnsentMessages() == 0 {
+ logger.Debug(h.user, "All lines sent")
+ return
+ }
+ logger.Debug(h.user, "Still lines to be sent")
+ time.Sleep(time.Second)
+ }
+
+ logger.Warn(h.user, "Some lines remain unsent", h.numUnsentMessages())
+}
+
+// Handler decides to shutdown the connection, not the server itself.
+func (h *ServerHandler) internalClose() {
+ select {
+ case h.hiddenMessages <- "syn close connection":
+ case <-time.After(time.Second * 5):
+ logger.Debug(h.user, "Not waiting for ack close connection")
+ close(h.stopTimeout)
+ return
+ }
+
+ select {
+ case <-h.Wait():
+ case <-time.After(time.Second * 5):
+ logger.Debug(h.user, "Not waiting for ack close connection")
+ close(h.stopTimeout)
+ }
+}
+
+func (h *ServerHandler) handleCommand(commandStr string) {
+ logger.Info(h.user, commandStr)
+
+ args := strings.Split(commandStr, " ")
+ argc := len(args)
+
+ logger.Debug(h.user, "Received command", commandStr, argc, args)
+
+ if h.user.Name == config.ControlUser {
+ h.handleControlCommand(argc, args)
+ return
+ }
+
+ h.handleUserCommand(argc, args)
+}
+
+// Special (restricted) set of commands for anonymous ControlUser access.
+func (h *ServerHandler) handleControlCommand(argc int, args []string) {
+ switch args[0] {
+ case "ping":
+ h.send(h.hiddenMessages, "pong")
+ case "debug":
+ h.send(h.serverMessages, logger.Debug(h.user, "Receiving debug command", argc, args))
+ default:
+ logger.Warn(h.user, "Received unknown command", argc, args)
+ }
+}
+
+// Commands for authed users.
+func (h *ServerHandler) handleUserCommand(argc int, args []string) {
+ switch args[0] {
+ case "grep":
+ fallthrough
+ case "cat":
+ h.handleReadCommand(argc, args, omode.CatClient)
+ case "tail":
+ h.handleReadCommand(argc, args, omode.TailClient)
+ case "map":
+ h.handleMapCommand(argc, args)
+ case "ack":
+ h.handleAckCommand(argc, args)
+ case "ping":
+ h.send(h.hiddenMessages, "pong")
+ case "version":
+ h.send(h.serverMessages, fmt.Sprintf("Server version is "+version.String()))
+ case "debug":
+ h.send(h.serverMessages, logger.Debug(h.user, "Received debug command", argc, args))
+ default:
+ h.send(h.serverMessages, logger.Warn(h.user, "Received unknown command", argc, args))
+ }
+}
+
+func (h *ServerHandler) handleReadCommand(argc int, args []string, mode omode.Mode) {
+ regex := "."
+ if argc >= 4 {
+ regex = args[3]
+ }
+ if argc < 3 {
+ h.send(h.serverMessages, logger.Warn(h.user, commandParseWarning, args, argc))
+ return
+ }
+ go h.processFileGlob(mode, args[1], regex)
+}
+
+func (h *ServerHandler) handleMapCommand(argc int, args []string) {
+ if argc < 2 {
+ h.send(h.serverMessages, logger.Warn(h.user, commandParseWarning, args, argc))
+ return
+ }
+
+ queryStr := strings.Join(args[1:], " ")
+ logger.Info(h.user, "Creating new mapr aggregator", queryStr)
+ aggregate, err := server.NewAggregate(h.aggregatedMessages, queryStr)
+
+ if err != nil {
+ h.send(h.serverMessages, logger.Error(h.user, err))
+ return
+ }
+
+ h.aggregate = aggregate
+}
+
+func (h *ServerHandler) handleAckCommand(argc int, args []string) {
+ if argc < 3 {
+ h.send(h.serverMessages, logger.Warn(h.user, commandParseWarning, args, argc))
+ return
+ }
+ if args[1] == "close" && args[2] == "connection" {
+ close(h.ackStopReceived)
+ }
+}
+
+func (h *ServerHandler) send(ch chan<- string, message string) {
+ select {
+ case ch <- message:
+ case <-h.stop:
+ }
+}
+
+// Wait (block) until server handler is closed or a timeout has exceeded.
+func (h *ServerHandler) Wait() <-chan struct{} {
+ wait := make(chan struct{})
+
+ go func() {
+ select {
+ case <-h.ackStopReceived:
+ logger.Debug(h.user, "Closing wait channel due to ACK stop received")
+ close(wait)
+ case <-h.stopTimeout:
+ logger.Debug(h.user, "Closing wait channel due to wait timeout")
+ close(wait)
+ case <-h.stop:
+ logger.Debug(h.user, "Closing wait channel due to stop")
+ }
+ }()
+
+ return wait
+}
diff --git a/internal/server/server.go b/internal/server/server.go
new file mode 100644
index 0000000..27a98f5
--- /dev/null
+++ b/internal/server/server.go
@@ -0,0 +1,212 @@
+package server
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net"
+
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/server/handlers"
+ "github.com/mimecast/dtail/internal/ssh/server"
+ user "github.com/mimecast/dtail/internal/user/server"
+ "github.com/mimecast/dtail/internal/version"
+
+ gossh "golang.org/x/crypto/ssh"
+)
+
+// Server is the main server data structure.
+type Server struct {
+ // Various server statistics counters.
+ stats stats
+ // SSH server configuration.
+ sshServerConfig *gossh.ServerConfig
+ // To control the max amount of concurrent cats (which can cause a lot of I/O on the server)
+ catLimiterCh chan struct{}
+ // To control the max amount of concurrent tails
+ tailLimiterCh chan struct{}
+ // Ask to shutdown the server
+ stop chan struct{}
+}
+
+// New returns a new server.
+func New() *Server {
+ logger.Info("Creating server", version.String())
+
+ s := Server{
+ sshServerConfig: &gossh.ServerConfig{},
+ catLimiterCh: make(chan struct{}, config.Server.MaxConcurrentCats),
+ tailLimiterCh: make(chan struct{}, config.Server.MaxConcurrentTails),
+ stop: make(chan struct{}),
+ }
+
+ s.sshServerConfig.PasswordCallback = s.controlUserCallback
+ s.sshServerConfig.PublicKeyCallback = server.PublicKeyCallback
+
+ private, err := gossh.ParsePrivateKey(server.PrivateHostKey())
+ if err != nil {
+ logger.FatalExit(err)
+ }
+ s.sshServerConfig.AddHostKey(private)
+
+ return &s
+}
+
+// Start the server.
+func (s *Server) Start() int {
+ logger.Info("Starting server")
+
+ bindAt := fmt.Sprintf("%s:%d", config.Server.SSHBindAddress, config.Common.SSHPort)
+ logger.Info("Binding server", bindAt)
+ listener, err := net.Listen("tcp", bindAt)
+ if err != nil {
+ logger.FatalExit("Failed to open listening TCP socket", err)
+ }
+
+ go s.stats.periodicLogServerStats(s.stop)
+
+ for {
+ conn, err := listener.Accept() // Blocking
+ if err != nil {
+ logger.Error("Failed to accept incoming connection", err)
+ continue
+ }
+
+ if err := s.stats.serverLimitExceeded(); err != nil {
+ logger.Error(err)
+ conn.Close()
+ continue
+ }
+
+ go s.handleConnection(conn)
+ }
+}
+
+func (s *Server) handleConnection(conn net.Conn) {
+ logger.Info("Handling connection")
+
+ sshConn, chans, reqs, err := gossh.NewServerConn(conn, s.sshServerConfig)
+ if err != nil {
+ logger.Error("Something just happened", err)
+ return
+ }
+
+ s.stats.incrementConnections()
+
+ go gossh.DiscardRequests(reqs)
+ for newChannel := range chans {
+ go s.handleChannel(sshConn, newChannel)
+ }
+}
+
+func (s *Server) handleChannel(sshConn gossh.Conn, newChannel gossh.NewChannel) {
+ user := user.New(sshConn.User(), sshConn.RemoteAddr().String())
+ logger.Info(user, "Invoking channel handler")
+
+ if newChannel.ChannelType() != "session" {
+ err := errors.New("Don'w allow other channel types than session")
+ logger.Error(user, err)
+ newChannel.Reject(gossh.Prohibited, err.Error())
+ return
+ }
+
+ channel, requests, err := newChannel.Accept()
+ if err != nil {
+ logger.Error(user, "Could not accept channel", err)
+ return
+ }
+
+ if err := s.handleRequests(sshConn, requests, channel, user); err != nil {
+ logger.Error(user, err)
+ sshConn.Close()
+ }
+}
+
+func (s *Server) handleRequests(sshConn gossh.Conn, in <-chan *gossh.Request, channel gossh.Channel, user *user.User) error {
+ logger.Info(user, "Invoking request handler")
+
+ for req := range in {
+ var payload = struct{ Value string }{}
+ gossh.Unmarshal(req.Payload, &payload)
+
+ switch req.Type {
+ case "shell":
+ var handler handlers.Handler
+ switch user.Name {
+ case config.ControlUser:
+ handler = handlers.NewControlHandler(user)
+ default:
+ handler = handlers.NewServerHandler(user, s.catLimiterCh, s.tailLimiterCh)
+ }
+
+ // Bi-directionally connect SSH stream to SSH handler
+ brokenPipe1 := make(chan struct{})
+ go func() {
+ defer close(brokenPipe1)
+ io.Copy(channel, handler)
+ }()
+
+ brokenPipe2 := make(chan struct{})
+ go func() {
+ defer close(brokenPipe2)
+ io.Copy(handler, channel)
+ }()
+
+ // Ensure to close all fd's and stop all goroutines once ssh connection terminated
+ go func() {
+ defer s.stats.decrementConnections()
+ defer handler.Close()
+
+ if err := sshConn.Wait(); err != nil && err != io.EOF {
+ logger.Error(user, err)
+ }
+ logger.Info(user, "Good bye Mister!")
+ }()
+
+ // Close the underlying ssh socket when server shuts down
+ go func() {
+ select {
+ case <-s.stop:
+ logger.Debug(user, "Server initiating shutdown on handler")
+ case <-handler.Wait():
+ logger.Debug(user, "Handler initiating shutdown by its own")
+ case <-brokenPipe1:
+ logger.Debug(user, "Broken pipe1")
+ case <-brokenPipe2:
+ logger.Debug(user, "Broken pipe2")
+ }
+ sshConn.Close()
+ logger.Info(user, "Closed SSH connection")
+ }()
+
+ // Only serving shell type
+ req.Reply(true, nil)
+
+ default:
+ req.Reply(false, nil)
+
+ return fmt.Errorf("Closing SSH connection as unknown request recieved|%s|%v",
+ req.Type, payload.Value)
+ }
+ }
+
+ return nil
+}
+
+func (*Server) controlUserCallback(c gossh.ConnMetadata, authPayload []byte) (*gossh.Permissions, error) {
+ user := user.New(c.User(), c.RemoteAddr().String())
+
+ if user.Name == config.ControlUser && string(authPayload) == config.ControlUser {
+ logger.Debug(user, "Initiating master control program")
+ return nil, nil
+ }
+
+ return nil, fmt.Errorf("Not authorized")
+}
+
+// Stop the server.
+func (s *Server) Stop() {
+ close(s.stop)
+ s.stats.waitForConnections()
+}
diff --git a/internal/server/stats.go b/internal/server/stats.go
new file mode 100644
index 0000000..beb1885
--- /dev/null
+++ b/internal/server/stats.go
@@ -0,0 +1,88 @@
+package server
+
+import (
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/logger"
+ "fmt"
+ "runtime"
+ "sync"
+ "time"
+)
+
+// Used to collect and display various server stats.
+type stats struct {
+ mutex sync.Mutex
+ currentConnections int
+ lifetimeConnections uint64
+}
+
+func (s *stats) incrementConnections() {
+ defer s.logServerStats()
+
+ s.mutex.Lock()
+ s.currentConnections++
+ s.lifetimeConnections++
+ s.mutex.Unlock()
+}
+
+func (s *stats) decrementConnections() {
+ defer s.logServerStats()
+
+ s.mutex.Lock()
+ s.currentConnections--
+ s.mutex.Unlock()
+}
+
+func (s *stats) hasConnections() bool {
+ s.mutex.Lock()
+ currentConnections := s.currentConnections
+ s.mutex.Unlock()
+
+ has := currentConnections > 0
+ logger.Info("stats", "Server with open connections?", has, currentConnections)
+
+ return has
+}
+
+func (s *stats) logServerStats() {
+ s.mutex.Lock()
+ defer s.mutex.Unlock()
+
+ currentConnections := fmt.Sprintf("currentConnections=%d", s.currentConnections)
+ lifetimeConnections := fmt.Sprintf("lifetimeConnections=%d", s.lifetimeConnections)
+ goroutines := fmt.Sprintf("goroutines=%d", runtime.NumGoroutine())
+ logger.Info("stats", currentConnections, lifetimeConnections, goroutines)
+}
+
+func (s *stats) serverLimitExceeded() error {
+ s.mutex.Lock()
+ defer s.mutex.Unlock()
+
+ if s.currentConnections >= config.Server.MaxConnections {
+ return fmt.Errorf("Exceeded max allowed concurrent connections of %d", config.Server.MaxConnections)
+ }
+
+ return nil
+}
+
+func (s *stats) periodicLogServerStats(stop <-chan struct{}) {
+ for {
+ select {
+ case <-time.NewTimer(time.Second * 10).C:
+ s.logServerStats()
+ case <-stop:
+ return
+ }
+ }
+}
+
+func (s *stats) waitForConnections() {
+ for {
+ select {
+ case <-time.NewTimer(time.Second).C:
+ if !s.hasConnections() {
+ return
+ }
+ }
+ }
+}