summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2021-10-05 10:00:38 +0300
committerPaul Buetow <paul@buetow.org>2021-10-05 10:00:38 +0300
commitf70622f307629a2542ea5eb128dea8c1043d3a40 (patch)
tree82455dac0c870b28aea8c96a426050dc215a8818 /internal
parent599075bc6580ba77dc22ba1c1ec8aa908ef2462d (diff)
more on this
Diffstat (limited to 'internal')
-rw-r--r--internal/clients/baseclient.go8
-rw-r--r--internal/clients/connectors/serverconnection.go19
-rw-r--r--internal/clients/connectors/serverless.go22
-rw-r--r--internal/clients/handlers/healthhandler.go114
-rw-r--r--internal/clients/healthclient.go97
-rw-r--r--internal/io/dlog/dlog.go6
-rw-r--r--internal/server/handlers/basehandler.go283
-rw-r--r--internal/server/handlers/readcommand.go4
-rw-r--r--internal/server/handlers/serverhandler.go332
-rw-r--r--internal/source/source.go9
10 files changed, 417 insertions, 477 deletions
diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go
index fc01955..5ac298f 100644
--- a/internal/clients/baseclient.go
+++ b/internal/clients/baseclient.go
@@ -86,7 +86,7 @@ func (c *baseClient) Start(ctx context.Context, statsCh <-chan string) (status i
var mutex sync.Mutex
for i, conn := range c.connections {
go func(i int, conn connectors.Connector) {
- connStatus := c.start(ctx, active, i, conn)
+ connStatus := c.startConnection(ctx, active, i, conn)
// Update global status.
mutex.Lock()
@@ -97,11 +97,12 @@ func (c *baseClient) Start(ctx context.Context, statsCh <-chan string) (status i
}(i, conn)
}
+ time.Sleep(time.Second * 2)
c.waitUntilDone(ctx, active)
return
}
-func (c *baseClient) start(ctx context.Context, active chan struct{}, i int, conn connectors.Connector) (status int) {
+func (c *baseClient) startConnection(ctx context.Context, active chan struct{}, i int, conn connectors.Connector) (status int) {
// Increment connection count
active <- struct{}{}
// Derement connection count
@@ -146,12 +147,13 @@ func (c *baseClient) waitUntilDone(ctx context.Context, active chan struct{}) {
<-ctx.Done()
}
+ // TODO: Rewrite this to use a wait group.
for {
numActive := len(active)
if numActive == 0 {
return
}
dlog.Client.Debug("Active connections", numActive)
- time.Sleep(time.Second)
+ time.Sleep(time.Second * time.Millisecond * 100)
}
}
diff --git a/internal/clients/connectors/serverconnection.go b/internal/clients/connectors/serverconnection.go
index 5bc63ee..1666a79 100644
--- a/internal/clients/connectors/serverconnection.go
+++ b/internal/clients/connectors/serverconnection.go
@@ -23,7 +23,6 @@ type ServerConnection struct {
config *ssh.ClientConfig
handler handlers.Handler
commands []string
- isOneOff bool
hostKeyCallback client.HostKeyCallback
throttlingDone bool
}
@@ -49,24 +48,6 @@ func NewServerConnection(server string, userName string, authMethods []ssh.AuthM
return &c
}
-// NewOneOffServerConnection creates new one-off connection (only for sending a series of commands and then quit).
-func NewOneOffServerConnection(server string, userName string, authMethods []ssh.AuthMethod, handler handlers.Handler, commands []string) *ServerConnection {
- c := ServerConnection{
- server: server,
- handler: handler,
- commands: commands,
- config: &ssh.ClientConfig{
- User: userName,
- Auth: authMethods,
- HostKeyCallback: ssh.InsecureIgnoreHostKey(),
- },
- isOneOff: true,
- }
-
- c.initServerPort()
- return &c
-}
-
func (c *ServerConnection) Server() string {
return c.server
}
diff --git a/internal/clients/connectors/serverless.go b/internal/clients/connectors/serverless.go
index 7740aab..ae72c9b 100644
--- a/internal/clients/connectors/serverless.go
+++ b/internal/clients/connectors/serverless.go
@@ -20,14 +20,12 @@ type Serverless struct {
// NewServerConnection returns a new connection.
func NewServerless(userName string, handler handlers.Handler, commands []string) *Serverless {
- s := Serverless{
+ dlog.Client.Debug("Creating new serverless connector", handler, commands)
+ return &Serverless{
userName: userName,
handler: handler,
commands: commands,
}
-
- dlog.Client.Debug("Creating new serverless connector", handler, commands)
- return &s
}
func (s *Serverless) Server() string {
@@ -58,11 +56,17 @@ func (s *Serverless) handle(ctx context.Context, cancel context.CancelFunc) erro
return err
}
- serverHandler := serverHandlers.NewServerHandler(
- user,
- make(chan struct{}, config.Server.MaxConcurrentCats),
- make(chan struct{}, config.Server.MaxConcurrentTails),
- )
+ var serverHandler serverHandlers.Handler
+ switch s.userName {
+ case config.ControlUser:
+ serverHandler = serverHandlers.NewControlHandler(user)
+ default:
+ serverHandler = serverHandlers.NewServerHandler(
+ user,
+ make(chan struct{}, config.Server.MaxConcurrentCats),
+ make(chan struct{}, config.Server.MaxConcurrentTails),
+ )
+ }
terminate := func() {
serverHandler.Shutdown()
diff --git a/internal/clients/handlers/healthhandler.go b/internal/clients/handlers/healthhandler.go
index eca0348..4949985 100644
--- a/internal/clients/handlers/healthhandler.go
+++ b/internal/clients/handlers/healthhandler.go
@@ -1,90 +1,72 @@
package handlers
import (
- "bytes"
- "errors"
"fmt"
- "time"
+ "strings"
"github.com/mimecast/dtail/internal"
+ "github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/protocol"
)
-// HealthHandler implements the handler required for health checks.
+// HealthHandler is the handler used on the client side for running mapreduce aggregations.
type HealthHandler struct {
- done *internal.Done
- // Buffer of incoming data from server.
- receiveBuf bytes.Buffer
- // To send commands to the server.
- commands chan string
- // To receive messages from the server.
- receive chan<- string
- // The remote server address
- server string
- // The return status.
- status int
+ baseHandler
+ HealthStatusCh chan<- int
}
-// NewHealthHandler returns a new health check handler.
-func NewHealthHandler(server string, receive chan<- string) *HealthHandler {
- h := HealthHandler{
- server: server,
- receive: receive,
- commands: make(chan string),
- status: -1,
- done: internal.NewDone(),
+// NewHealthHandler returns a new health client handler.
+func NewHealthHandler(server string) *HealthHandler {
+ dlog.Client.Debug(server, "Creating new health handler")
+ return &HealthHandler{
+ baseHandler: baseHandler{
+ server: server,
+ shellStarted: false,
+ commands: make(chan string),
+ status: -1,
+ done: internal.NewDone(),
+ },
+ HealthStatusCh: make(chan int),
}
-
- return &h
-}
-
-// Server returns the remote server name.
-func (h *HealthHandler) Server() string {
- return h.server
-}
-
-// Status of the handler.
-func (h *HealthHandler) Status() int {
- return h.status
-}
-
-// Done returns done channel of the handler.
-func (h *HealthHandler) Done() <-chan struct{} {
- return h.done.Done()
}
-// Shutdown the handler.
-func (h *HealthHandler) Shutdown() {
- h.done.Shutdown()
-}
-
-// SendMessage sends a DTail command to the server.
-func (h *HealthHandler) SendMessage(command string) error {
- select {
- case h.commands <- fmt.Sprintf("%s;", command):
- case <-time.NewTimer(time.Second * 10).C:
- return errors.New("Timed out sending command " + command)
- case <-h.Done():
- }
-
- return nil
-}
-
-// Server writes byte stream to client.
+// Read data from the dtail server via Writer interface.
func (h *HealthHandler) Write(p []byte) (n int, err error) {
for _, b := range p {
- h.receiveBuf.WriteByte(b)
- if b == protocol.MessageDelimiter {
- h.receive <- h.receiveBuf.String()
- h.receiveBuf.Reset()
+ switch b {
+ case '\n':
+ continue
+ case protocol.MessageDelimiter:
+ message := h.baseHandler.receiveBuf.String()
+ dlog.Client.Debug(message)
+ h.handleHealthMessage(message)
+ h.baseHandler.receiveBuf.Reset()
+ default:
+ h.baseHandler.receiveBuf.WriteByte(b)
}
}
return len(p), nil
}
-// Server reads byte stream from client.
-func (h *HealthHandler) Read(p []byte) (n int, err error) {
- n = copy(p, []byte(<-h.commands))
- return
+func (h *HealthHandler) handleHealthMessage(message string) {
+ s := strings.Split(message, protocol.FieldDelimiter)
+ message = s[len(s)-1]
+ status := strings.Split(message, ":")
+ fmt.Println(status)
+ /*
+ switch status {
+ case "OK":
+ h.HealthStatusCh <- 0
+ case "WARNING":
+ h.HealthStatusCh <- 1
+ case "CRITICAL":
+ h.HealthStatusCh <- 2
+ case "UNKNOWN":
+ h.HealthStatusCh <- 3
+ default:
+ fmt.Println("CRITICAL: Unexpected server response: '%s'")
+ h.HealthStatusCh <- 2
+ }
+ */
}
diff --git a/internal/clients/healthclient.go b/internal/clients/healthclient.go
index 47007b6..df919ae 100644
--- a/internal/clients/healthclient.go
+++ b/internal/clients/healthclient.go
@@ -1,101 +1,44 @@
package clients
import (
- "context"
- "fmt"
"runtime"
- "strings"
- "time"
- "github.com/mimecast/dtail/internal/clients/connectors"
"github.com/mimecast/dtail/internal/clients/handlers"
"github.com/mimecast/dtail/internal/config"
"github.com/mimecast/dtail/internal/omode"
- "github.com/mimecast/dtail/internal/protocol"
gossh "golang.org/x/crypto/ssh"
)
-// HealthClient is used for health checking (e.g. via Nagios)
+// HealthClient is used to perform a basic server health check.
type HealthClient struct {
- // Client operating mode
- mode omode.Mode
- // The remote server address
- server string
- // SSH user name
- userName string
- // SSH auth methods to use to connect to the remote servers.
- sshAuthMethods []gossh.AuthMethod
+ baseClient
}
-// NewHealthClient returns a new healh client.
-func NewHealthClient(mode omode.Mode) (*HealthClient, error) {
+// NewHealthClient returns a new health client.
+func NewHealthClient(args config.Args) (*HealthClient, error) {
+ args.Mode = omode.HealthClient
+ args.UserName = config.ControlUser
c := HealthClient{
- mode: mode,
- server: fmt.Sprintf("%s:%d", config.Server.SSHBindAddress, config.Common.SSHPort),
- userName: config.ControlUser,
+ baseClient: baseClient{
+ Args: args,
+ throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()),
+ retry: false,
+ },
}
- c.initSSHAuthMethods()
+
+ c.init()
+ c.sshAuthMethods = append(c.sshAuthMethods, gossh.Password(config.ControlUser))
+ c.makeConnections(c)
return &c, nil
}
-// Start the health client.
-func (c *HealthClient) Start(ctx context.Context) (status int) {
- receive := make(chan string)
-
- throttleCh := make(chan struct{}, runtime.NumCPU())
- statsCh := make(chan struct{}, 1)
-
- conn := connectors.NewOneOffServerConnection(
- c.server,
- c.userName,
- c.sshAuthMethods,
- handlers.NewHealthHandler(c.server, receive),
- []string{c.mode.String()},
- )
-
- connCtx, cancel := context.WithCancel(ctx)
- go conn.Start(connCtx, cancel, throttleCh, statsCh)
-
- for {
- select {
- case data := <-receive:
- // Parse recieved data.
- s := strings.Split(data, protocol.FieldDelimiter)
- message := s[len(s)-1]
- if strings.HasPrefix(message, "done;") {
- return
- }
-
- // Set severity.
- s = strings.Split(message, ":")
- switch s[0] {
- case "OK":
- case "WARNING":
- if status < 1 {
- status = 1
- }
- case "CRITICAL":
- status = 2
- case "UNKNOWN":
- status = 3
- default:
- fmt.Printf("CRITICAL: Unexpected server response: '%s'\n", message)
- status = 2
- return
- }
- fmt.Print(message)
-
- case <-time.After(time.Second * 2):
- status = 2
- fmt.Println("CRITICAL: Could not communicate with DTail server")
- return
- }
- }
+func (c HealthClient) makeHandler(server string) handlers.Handler {
+ return handlers.NewHealthHandler(server)
}
-// Initialize SSH auth methods.
-func (c *HealthClient) initSSHAuthMethods() {
- c.sshAuthMethods = append(c.sshAuthMethods, gossh.Password(config.ControlUser))
+func (c HealthClient) makeCommands() (commands []string) {
+ commands = append(commands, "health")
+ return
}
diff --git a/internal/io/dlog/dlog.go b/internal/io/dlog/dlog.go
index 2beda75..db99307 100644
--- a/internal/io/dlog/dlog.go
+++ b/internal/io/dlog/dlog.go
@@ -57,6 +57,12 @@ func Start(ctx context.Context, wg *sync.WaitGroup, sourceProcess source.Source,
Client = New(source.Server, source.Client, level, impl, strategy)
Server = New(source.Server, source.Server, level, impl, strategy)
Common = Server
+ case source.HealthCheck:
+ // Health check isn't logging anything.
+ impl := loggers.STDOUT
+ Client = New(source.HealthCheck, source.Client, level, impl, strategy)
+ Server = New(source.HealthCheck, source.Server, level, impl, strategy)
+ Common = Client
}
var wg2 sync.WaitGroup
diff --git a/internal/server/handlers/basehandler.go b/internal/server/handlers/basehandler.go
new file mode 100644
index 0000000..12fb2b3
--- /dev/null
+++ b/internal/server/handlers/basehandler.go
@@ -0,0 +1,283 @@
+package handlers
+
+import (
+ "bytes"
+ "context"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "io"
+ "strconv"
+ "strings"
+ "sync/atomic"
+ "time"
+
+ "github.com/mimecast/dtail/internal"
+ "github.com/mimecast/dtail/internal/io/dlog"
+ "github.com/mimecast/dtail/internal/io/line"
+ "github.com/mimecast/dtail/internal/io/pool"
+ "github.com/mimecast/dtail/internal/mapr/server"
+ "github.com/mimecast/dtail/internal/protocol"
+ user "github.com/mimecast/dtail/internal/user/server"
+)
+
+type handleCommandCb func(context.Context, int, []string)
+
+type baseHandler struct {
+ done *internal.Done
+ handleCommandCb handleCommandCb
+ lines chan line.Line
+ aggregate *server.Aggregate
+ maprMessages chan string
+ serverMessages chan string
+ hostname string
+ user *user.User
+ ackCloseReceived chan struct{}
+ activeCommands int32
+ quiet bool
+ spartan bool
+ serverless bool
+ readBuf bytes.Buffer
+ writeBuf bytes.Buffer
+}
+
+// Shutdown the handler.
+func (h *baseHandler) Shutdown() {
+ h.done.Shutdown()
+}
+
+// Done channel of the handler.
+func (h *baseHandler) Done() <-chan struct{} {
+ return h.done.Done()
+}
+
+// Read is to send data to the dtail client via Reader interface.
+func (h *baseHandler) Read(p []byte) (n int, err error) {
+ defer h.readBuf.Reset()
+
+ select {
+ case message := <-h.serverMessages:
+ if message[0] == '.' {
+ // Handle hidden message (don't display to the user, interpreted by dtail client)
+ h.readBuf.WriteString(message)
+ h.readBuf.WriteByte(protocol.MessageDelimiter)
+ n = copy(p, h.readBuf.Bytes())
+ return
+ }
+
+ if h.serverless {
+ // In serverless mode we have logged the server message already via the
+ // dlog logger, no need to send the message again to the client part.
+ return
+ }
+
+ // Handle normal server message (display to the user)
+ h.readBuf.WriteString("SERVER")
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(h.hostname)
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(message)
+ h.readBuf.WriteByte(protocol.MessageDelimiter)
+ n = copy(p, h.readBuf.Bytes())
+
+ case message := <-h.maprMessages:
+ // Send mapreduce-aggregated data as a message.
+ h.readBuf.WriteString("AGGREGATE")
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(h.hostname)
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(message)
+ h.readBuf.WriteByte(protocol.MessageDelimiter)
+ n = copy(p, h.readBuf.Bytes())
+
+ case line := <-h.lines:
+ if !h.spartan {
+ h.readBuf.WriteString("REMOTE")
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(h.hostname)
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(fmt.Sprintf("%3d", line.TransmittedPerc))
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(fmt.Sprintf("%v", line.Count))
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(line.SourceID)
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ }
+ h.readBuf.WriteString(line.Content.String())
+ h.readBuf.WriteByte(protocol.MessageDelimiter)
+ n = copy(p, h.readBuf.Bytes())
+ pool.RecycleBytesBuffer(line.Content)
+
+ case <-time.After(time.Second):
+ // Once in a while check whether we are done.
+ select {
+ case <-h.done.Done():
+ err = io.EOF
+ return
+ default:
+ }
+ }
+ return
+}
+
+// Write is to receive data from the dtail client via Writer interface.
+func (h *baseHandler) Write(p []byte) (n int, err error) {
+ for _, b := range p {
+ switch b {
+ case ';':
+ h.handleCommand(string(h.writeBuf.Bytes()))
+ h.writeBuf.Reset()
+ default:
+ h.writeBuf.WriteByte(b)
+ }
+ }
+
+ n = len(p)
+ return
+}
+
+func (h *baseHandler) handleCommand(commandStr string) {
+ dlog.Server.Debug(h.user, commandStr)
+
+ args, argc, add, err := h.handleProtocolVersion(strings.Split(commandStr, " "))
+ if err != nil {
+ h.send(h.serverMessages, dlog.Server.Error(h.user, err)+add)
+ return
+ }
+
+ args, argc, err = h.handleBase64(args, argc)
+ if err != nil {
+ h.send(h.serverMessages, dlog.Server.Error(h.user, err))
+ return
+ }
+
+ ctx, cancel := context.WithCancel(context.Background())
+ go func() {
+ <-h.done.Done()
+ cancel()
+ }()
+
+ h.handleCommandCb(ctx, argc, args)
+}
+
+func (h *baseHandler) handleProtocolVersion(args []string) ([]string, int, string, error) {
+ argc := len(args)
+ var add string
+
+ if argc <= 2 || args[0] != "protocol" {
+ return args, argc, add, errors.New("unable to determine protocol version")
+ }
+
+ if args[1] != protocol.ProtocolCompat {
+ clientCompat, _ := strconv.Atoi(args[1])
+ serverCompat, _ := strconv.Atoi(protocol.ProtocolCompat)
+ if clientCompat <= 3 {
+ // Protocol version 3 or lower expect a newline as message separator
+ // One day (after 2 major versions) this exception may be removed!
+ add = "\n"
+ }
+
+ toUpdate := "client"
+ if clientCompat > serverCompat {
+ toUpdate = "server"
+ }
+
+ err := fmt.Errorf("DTail server protocol version '%s' does not match client protocol version '%s', please update DTail %s!",
+ protocol.ProtocolCompat, args[1], toUpdate)
+ return args, argc, add, err
+ }
+
+ return args[2:], argc - 2, add, nil
+}
+
+func (h *baseHandler) handleBase64(args []string, argc int) ([]string, int, error) {
+ err := errors.New("Unable to decode client message, DTail server and client versions may not be compatible")
+
+ if argc != 2 || args[0] != "base64" {
+ return args, argc, err
+ }
+
+ decoded, err := base64.StdEncoding.DecodeString(args[1])
+ if err != nil {
+ return args, argc, err
+ }
+ decodedStr := string(decoded)
+
+ args = strings.Split(decodedStr, " ")
+ argc = len(decodedStr)
+ dlog.Server.Trace(h.user, "Base64 decoded received command", decodedStr, argc, args)
+
+ return args, argc, nil
+}
+
+func (h *baseHandler) handleAckCommand(argc int, args []string) {
+ if argc < 3 {
+ if !h.quiet {
+ h.send(h.serverMessages, dlog.Server.Warn(h.user, "Unable to parse command", args, argc))
+ }
+ return
+ }
+ if args[1] == "close" && args[2] == "connection" {
+ select {
+ case <-h.ackCloseReceived:
+ default:
+ close(h.ackCloseReceived)
+ }
+ }
+}
+
+func (h *baseHandler) send(ch chan<- string, message string) {
+ select {
+ case ch <- message:
+ case <-h.done.Done():
+ }
+}
+
+func (h *baseHandler) flush() {
+ dlog.Server.Debug(h.user, "flush()")
+
+ numUnsentMessages := func() int {
+ return len(h.lines) + len(h.serverMessages) + len(h.maprMessages)
+ }
+
+ for i := 0; i < 3; i++ {
+ if numUnsentMessages() == 0 {
+ dlog.Server.Debug(h.user, "All lines sent")
+ return
+ }
+ dlog.Server.Debug(h.user, "Still lines to be sent")
+ time.Sleep(time.Second)
+ }
+
+ dlog.Server.Warn(h.user, "Some lines remain unsent", numUnsentMessages())
+}
+
+func (h *baseHandler) shutdown() {
+ dlog.Server.Debug(h.user, "shutdown()")
+ h.flush()
+
+ go func() {
+ select {
+ case h.serverMessages <- ".syn close connection":
+ case <-h.done.Done():
+ }
+ }()
+
+ select {
+ case <-h.ackCloseReceived:
+ case <-time.After(time.Second * 5):
+ dlog.Server.Debug(h.user, "Shutdown timeout reached, enforcing shutdown")
+ case <-h.done.Done():
+ }
+
+ h.done.Shutdown()
+}
+
+func (h *baseHandler) incrementActiveCommands() {
+ atomic.AddInt32(&h.activeCommands, 1)
+}
+
+func (h *baseHandler) decrementActiveCommands() int32 {
+ atomic.AddInt32(&h.activeCommands, -1)
+ return atomic.LoadInt32(&h.activeCommands)
+}
diff --git a/internal/server/handlers/readcommand.go b/internal/server/handlers/readcommand.go
index 6579018..abc44c7 100644
--- a/internal/server/handlers/readcommand.go
+++ b/internal/server/handlers/readcommand.go
@@ -32,13 +32,13 @@ func (r *readCommand) Start(ctx context.Context, argc int, args []string, retrie
if argc >= 4 {
deserializedRegex, err := regex.Deserialize(strings.Join(args[2:], " "))
if err != nil {
- r.server.send(r.server.serverMessages, dlog.Server.Error(r.server.user, commandParseWarning, err))
+ r.server.send(r.server.serverMessages, dlog.Server.Error(r.server.user, "Unable to parse command", err))
return
}
re = deserializedRegex
}
if argc < 3 {
- r.server.send(r.server.serverMessages, dlog.Server.Warn(r.server.user, commandParseWarning, args, argc))
+ r.server.send(r.server.serverMessages, dlog.Server.Warn(r.server.user, "Unable to parse command", args, argc))
return
}
r.readGlob(ctx, args[1], re, retries)
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go
index ace2626..2ec4fbf 100644
--- a/internal/server/handlers/serverhandler.go
+++ b/internal/server/handlers/serverhandler.go
@@ -1,69 +1,60 @@
package handlers
import (
- "bytes"
"context"
- "encoding/base64"
- "errors"
- "fmt"
- "io"
"os"
- "strconv"
"strings"
- "sync/atomic"
- "time"
"github.com/mimecast/dtail/internal"
"github.com/mimecast/dtail/internal/config"
"github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/io/line"
- "github.com/mimecast/dtail/internal/io/pool"
- "github.com/mimecast/dtail/internal/mapr/server"
"github.com/mimecast/dtail/internal/omode"
- "github.com/mimecast/dtail/internal/protocol"
user "github.com/mimecast/dtail/internal/user/server"
)
-const (
- commandParseWarning string = "Unable to parse command"
-)
-
// ServerHandler implements the Reader and Writer interfaces to handle
// the Bi-directional communication between SSH client and server.
// This handler implements the handler of the SSH server.
type ServerHandler struct {
- done *internal.Done
- lines chan line.Line
- regex string
- aggregate *server.Aggregate
- maprMessages chan string
- serverMessages chan string
- hostname string
- user *user.User
- catLimiter chan struct{}
- tailLimiter chan struct{}
- ackCloseReceived chan struct{}
- activeCommands int32
- quiet bool
- spartan bool
- serverless bool
- readBuf bytes.Buffer
- writeBuf bytes.Buffer
+ baseHandler
+ catLimiter chan struct{}
+ tailLimiter chan struct{}
+ regex string
+ /*
+ done *internal.Done
+ lines chan line.Line
+ aggregate *server.Aggregate
+ maprMessages chan string
+ serverMessages chan string
+ hostname string
+ user *user.User
+ ackCloseReceived chan struct{}
+ activeCommands int32
+ quiet bool
+ spartan bool
+ serverless bool
+ readBuf bytes.Buffer
+ writeBuf bytes.Buffer
+ */
}
// NewServerHandler returns the server handler.
func NewServerHandler(user *user.User, catLimiter, tailLimiter chan struct{}) *ServerHandler {
h := ServerHandler{
- done: internal.NewDone(),
- lines: make(chan line.Line, 100),
- serverMessages: make(chan string, 10),
- maprMessages: make(chan string, 10),
- ackCloseReceived: make(chan struct{}),
- catLimiter: catLimiter,
- tailLimiter: tailLimiter,
- regex: ".",
- user: user,
- }
+ baseHandler: baseHandler{
+ done: internal.NewDone(),
+ lines: make(chan line.Line, 100),
+ serverMessages: make(chan string, 10),
+ maprMessages: make(chan string, 10),
+ ackCloseReceived: make(chan struct{}),
+ user: user,
+ },
+ catLimiter: catLimiter,
+ tailLimiter: tailLimiter,
+ regex: ".",
+ }
+ h.handleCommandCb = h.handleUserCommand
fqdn, err := os.Hostname()
if err != nil {
@@ -76,192 +67,8 @@ func NewServerHandler(user *user.User, catLimiter, tailLimiter chan struct{}) *S
return &h
}
-// Shutdown the handler.
-func (h *ServerHandler) Shutdown() {
- h.done.Shutdown()
-}
-
-// Done channel of the handler.
-func (h *ServerHandler) Done() <-chan struct{} {
- return h.done.Done()
-}
-
-// Read is to send data to the dtail client via Reader interface.
-func (h *ServerHandler) Read(p []byte) (n int, err error) {
- defer h.readBuf.Reset()
-
- select {
- case message := <-h.serverMessages:
- if message[0] == '.' {
- // Handle hidden message (don't display to the user, interpreted by dtail client)
- h.readBuf.WriteString(message)
- h.readBuf.WriteByte(protocol.MessageDelimiter)
- n = copy(p, h.readBuf.Bytes())
- return
- }
-
- if h.serverless {
- // In serverless mode we have logged the server message already via the
- // dlog logger, no need to send the message again to the client part.
- return
- }
-
- // Handle normal server message (display to the user)
- h.readBuf.WriteString("SERVER")
- h.readBuf.WriteString(protocol.FieldDelimiter)
- h.readBuf.WriteString(h.hostname)
- h.readBuf.WriteString(protocol.FieldDelimiter)
- h.readBuf.WriteString(message)
- h.readBuf.WriteByte(protocol.MessageDelimiter)
- n = copy(p, h.readBuf.Bytes())
-
- case message := <-h.maprMessages:
- // Send mapreduce-aggregated data as a message.
- h.readBuf.WriteString("AGGREGATE")
- h.readBuf.WriteString(protocol.FieldDelimiter)
- h.readBuf.WriteString(h.hostname)
- h.readBuf.WriteString(protocol.FieldDelimiter)
- h.readBuf.WriteString(message)
- h.readBuf.WriteByte(protocol.MessageDelimiter)
- n = copy(p, h.readBuf.Bytes())
-
- case line := <-h.lines:
- if !h.spartan {
- h.readBuf.WriteString("REMOTE")
- h.readBuf.WriteString(protocol.FieldDelimiter)
- h.readBuf.WriteString(h.hostname)
- h.readBuf.WriteString(protocol.FieldDelimiter)
- h.readBuf.WriteString(fmt.Sprintf("%3d", line.TransmittedPerc))
- h.readBuf.WriteString(protocol.FieldDelimiter)
- h.readBuf.WriteString(fmt.Sprintf("%v", line.Count))
- h.readBuf.WriteString(protocol.FieldDelimiter)
- h.readBuf.WriteString(line.SourceID)
- h.readBuf.WriteString(protocol.FieldDelimiter)
- }
- h.readBuf.WriteString(line.Content.String())
- h.readBuf.WriteByte(protocol.MessageDelimiter)
- n = copy(p, h.readBuf.Bytes())
- pool.RecycleBytesBuffer(line.Content)
-
- case <-time.After(time.Second):
- // Once in a while check whether we are done.
- select {
- case <-h.done.Done():
- err = io.EOF
- return
- default:
- }
- }
- return
-}
-
-// Write is to receive data from the dtail client via Writer interface.
-func (h *ServerHandler) Write(p []byte) (n int, err error) {
- for _, b := range p {
- switch b {
- case ';':
- h.handleCommand(string(h.writeBuf.Bytes()))
- h.writeBuf.Reset()
- default:
- h.writeBuf.WriteByte(b)
- }
- }
-
- n = len(p)
- return
-}
-
-func (h *ServerHandler) handleCommand(commandStr string) {
- dlog.Server.Debug(h.user, commandStr)
- ctx := context.Background()
-
- args, argc, add, err := h.handleProtocolVersion(strings.Split(commandStr, " "))
- if err != nil {
- h.send(h.serverMessages, dlog.Server.Error(h.user, err)+add)
- return
- }
-
- args, argc, err = h.handleBase64(args, argc)
- if err != nil {
- h.send(h.serverMessages, dlog.Server.Error(h.user, err))
- return
- }
-
- if h.user.Name == config.ControlUser {
- h.handleControlCommand(argc, args)
- return
- }
-
- ctx, cancel := context.WithCancel(ctx)
- go func() {
- <-h.done.Done()
- cancel()
- }()
-
- h.handleUserCommand(ctx, argc, args)
-}
-
-func (h *ServerHandler) handleProtocolVersion(args []string) ([]string, int, string, error) {
- argc := len(args)
- var add string
-
- if argc <= 2 || args[0] != "protocol" {
- return args, argc, add, errors.New("unable to determine protocol version")
- }
-
- if args[1] != protocol.ProtocolCompat {
- clientCompat, _ := strconv.Atoi(args[1])
- serverCompat, _ := strconv.Atoi(protocol.ProtocolCompat)
- if clientCompat <= 3 {
- // Protocol version 3 or lower expect a newline as message separator
- // One day (after 2 major versions) this exception may be removed!
- add = "\n"
- }
-
- toUpdate := "client"
- if clientCompat > serverCompat {
- toUpdate = "server"
- }
-
- err := fmt.Errorf("DTail server protocol version '%s' does not match client protocol version '%s', please update DTail %s!",
- protocol.ProtocolCompat, args[1], toUpdate)
- return args, argc, add, err
- }
-
- return args[2:], argc - 2, add, nil
-}
-
-func (h *ServerHandler) handleBase64(args []string, argc int) ([]string, int, error) {
- err := errors.New("Unable to decode client message, DTail server and client versions may not be compatible")
-
- if argc != 2 || args[0] != "base64" {
- return args, argc, err
- }
-
- decoded, err := base64.StdEncoding.DecodeString(args[1])
- if err != nil {
- return args, argc, err
- }
- decodedStr := string(decoded)
-
- args = strings.Split(decodedStr, " ")
- argc = len(decodedStr)
- dlog.Server.Trace(h.user, "Base64 decoded received command", decodedStr, argc, args)
-
- return args, argc, nil
-}
-
-func (h *ServerHandler) handleControlCommand(argc int, args []string) {
- switch args[0] {
- case "debug":
- h.send(h.serverMessages, dlog.Server.Debug(h.user, "Receiving debug command", argc, args))
- default:
- dlog.Server.Warn(h.user, "Received unknown control command", argc, args)
- }
-}
-
func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []string) {
- dlog.Server.Debug(h.user, "handleUserCommand", argc, args)
+ dlog.Server.Debug(h.user, "Handling user command", argc, args)
h.incrementActiveCommands()
commandFinished := func() {
@@ -332,74 +139,3 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
commandFinished()
}
}
-
-func (h *ServerHandler) handleAckCommand(argc int, args []string) {
- if argc < 3 {
- if !h.quiet {
- h.send(h.serverMessages, dlog.Server.Warn(h.user, commandParseWarning, args, argc))
- }
- return
- }
- if args[1] == "close" && args[2] == "connection" {
- select {
- case <-h.ackCloseReceived:
- default:
- close(h.ackCloseReceived)
- }
- }
-}
-
-func (h *ServerHandler) send(ch chan<- string, message string) {
- select {
- case ch <- message:
- case <-h.done.Done():
- }
-}
-
-func (h *ServerHandler) flush() {
- dlog.Server.Debug(h.user, "flush()")
-
- unsentMessages := func() int {
- return len(h.lines) + len(h.serverMessages) + len(h.maprMessages)
- }
- for i := 0; i < 3; i++ {
- if unsentMessages() == 0 {
- dlog.Server.Debug(h.user, "All lines sent")
- return
- }
- dlog.Server.Debug(h.user, "Still lines to be sent")
- time.Sleep(time.Second)
- }
-
- dlog.Server.Warn(h.user, "Some lines remain unsent", unsentMessages())
-}
-
-func (h *ServerHandler) shutdown() {
- dlog.Server.Debug(h.user, "shutdown()")
- h.flush()
-
- go func() {
- select {
- case h.serverMessages <- ".syn close connection":
- case <-h.done.Done():
- }
- }()
-
- select {
- case <-h.ackCloseReceived:
- case <-time.After(time.Second * 5):
- dlog.Server.Debug(h.user, "Shutdown timeout reached, enforcing shutdown")
- case <-h.done.Done():
- }
-
- h.done.Shutdown()
-}
-
-func (h *ServerHandler) incrementActiveCommands() {
- atomic.AddInt32(&h.activeCommands, 1)
-}
-
-func (h *ServerHandler) decrementActiveCommands() int32 {
- atomic.AddInt32(&h.activeCommands, -1)
- return atomic.LoadInt32(&h.activeCommands)
-}
diff --git a/internal/source/source.go b/internal/source/source.go
index 73dccb2..be7aecd 100644
--- a/internal/source/source.go
+++ b/internal/source/source.go
@@ -3,8 +3,9 @@ package source
type Source int
const (
- Client Source = iota
- Server Source = iota
+ Client Source = iota
+ Server Source = iota
+ HealthCheck Source = iota
)
func (s Source) String() string {
@@ -13,7 +14,9 @@ func (s Source) String() string {
return "CLIENT"
case Server:
return "SERVER"
+ case HealthCheck:
+ return "HEALTHCHECK"
}
- panic("Unknown log source type")
+ panic("Unknown source type")
}