diff options
| author | Paul Bütow <pbuetow@mimecast.com> | 2020-01-26 11:26:53 +0000 |
|---|---|---|
| committer | Paul Bütow <pbuetow@mimecast.com> | 2020-02-07 13:31:15 +0000 |
| commit | 0945da8dfefcbb723eecea0e5f4eafff63398253 (patch) | |
| tree | f06dab4d2bf21d25d176b23d5baeca588d27f5d7 /internal/clients/handlers/basehandler.go | |
| parent | 2a8e5de265a0e0a31a5834909d6879f5c9941467 (diff) | |
Introduce drun command, refactor code to use context package
Diffstat (limited to 'internal/clients/handlers/basehandler.go')
| -rw-r--r-- | internal/clients/handlers/basehandler.go | 84 |
1 files changed, 39 insertions, 45 deletions
diff --git a/internal/clients/handlers/basehandler.go b/internal/clients/handlers/basehandler.go index 19246f9..68b8ddc 100644 --- a/internal/clients/handlers/basehandler.go +++ b/internal/clients/handlers/basehandler.go @@ -1,60 +1,44 @@ package handlers import ( - "github.com/mimecast/dtail/internal/logger" - "errors" + "encoding/base64" "fmt" "io" + "strconv" "strings" "time" + + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/version" ) type baseHandler struct { + withCancel server string shellStarted bool commands chan string - pong chan struct{} receiveBuf []byte - stop chan struct{} - pingTimeout int + status int } func (h *baseHandler) Server() string { return h.server } -// Used to determine whether server is still responding to requests or not. -func (h *baseHandler) Ping() error { - if h.pingTimeout == 0 { - // Server ping disabled - return nil - } - - if err := h.SendCommand("ping"); err != nil { - return err - } - - select { - case <-h.pong: - return nil - case <-time.After(time.Duration(h.pingTimeout) * time.Second): - } - - return errors.New("Didn't receive any server pongs (ping replies)") +func (h *baseHandler) Status() int { + return h.status } -func (h *baseHandler) SendCommand(command string) error { - if command == "ping" { - logger.Trace("Sending command", h.server, command) - } else { - logger.Debug("Sending command", h.server, command) - } +// SendMessage to the server. +func (h *baseHandler) SendMessage(command string) error { + encoded := base64.StdEncoding.EncodeToString([]byte(command)) + logger.Debug("Sending command", h.server, command, encoded) select { - case h.commands <- fmt.Sprintf("%s;", command): + case h.commands <- fmt.Sprintf("protocol %s base64 %v;", version.ProtocolCompat, encoded): case <-time.After(time.Second * 5): - return errors.New("Timed out sending command " + command) - case <-h.stop: + return fmt.Errorf("Timed out sending command '%s' (base64: '%s')", command, encoded) + case <-h.ctx.Done(): } return nil @@ -81,7 +65,7 @@ func (h *baseHandler) Read(p []byte) (n int, err error) { select { case command := <-h.commands: n = copy(p, []byte(command)) - case <-h.stop: + case <-h.ctx.Done(): return 0, io.EOF } return @@ -92,6 +76,7 @@ func (h *baseHandler) handleMessageType(message string) { if len(h.receiveBuf) == 0 { return } + // Hidden server commands starti with a dot "." if h.receiveBuf[0] == '.' { h.handleHiddenMessage(message) @@ -108,6 +93,7 @@ func (h *baseHandler) handleMessageType(message string) { h.receiveBuf = h.receiveBuf[:0] return } + logger.Raw(message) h.receiveBuf = h.receiveBuf[:0] } @@ -116,19 +102,27 @@ func (h *baseHandler) handleMessageType(message string) { // to the end user. func (h *baseHandler) handleHiddenMessage(message string) { switch { - case strings.HasPrefix(message, ".pong"): - h.pong <- struct{}{} case strings.HasPrefix(message, ".syn close connection"): - h.SendCommand("ack close connection") - } -} + h.SendMessage(".ack close connection") + select { + case <-time.After(time.Second * 1): + logger.Debug("Shutting down client after timeout and sending ack to server") + h.withCancel.shutdown() + case <-h.ctx.Done(): + } -// Stop the handler. -func (h *baseHandler) Stop() { - select { - case <-h.stop: - default: - logger.Debug("Stopping base handler", h.server) - close(h.stop) + case strings.HasPrefix(message, ".run exitstatus"): + splitted := strings.Split(strings.TrimSuffix(message, "\n"), " ") + if len(splitted) != 3 { + logger.Error("Unable to retrieve exitstatus", message) + return + } + i, err := strconv.Atoi(splitted[2]) + if err != nil { + logger.Error("Unable to retrieve exitstatus", message, err) + return + } + h.status = i + logger.Debug("Retrieved exitstatus", h.status) } } |
