summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul Buetow <pbuetow@mimecast.com>2020-12-08 14:49:41 +0000
committerPaul Buetow <pbuetow@mimecast.com>2020-12-08 14:49:41 +0000
commit799b9b69ba08b898e13026b7ecab9f9f58580a82 (patch)
tree34bc0e5e539aed99dd1f13e7489e9d3111ba050f
parent6b2d8539a66f1b36ffd55c56723376b9b068a5dc (diff)
merge develop
-rw-r--r--Makefile6
-rw-r--r--cmd/dcat/main.go2
-rw-r--r--cmd/dgrep/main.go2
-rw-r--r--cmd/dmap/main.go2
-rw-r--r--cmd/drun/main.go88
-rw-r--r--cmd/dtail/main.go2
-rw-r--r--doc/examples.md11
-rw-r--r--doc/quickstart.md3
-rw-r--r--internal/clients/baseclient.go4
-rw-r--r--internal/clients/client.go2
-rw-r--r--internal/clients/handlers/basehandler.go23
-rw-r--r--internal/clients/handlers/clienthandler.go5
-rw-r--r--internal/clients/handlers/handler.go3
-rw-r--r--internal/clients/handlers/healthhandler.go19
-rw-r--r--internal/clients/handlers/maprhandler.go5
-rw-r--r--internal/clients/healthclient.go2
-rw-r--r--internal/clients/maprclient.go4
-rw-r--r--internal/clients/remote/connection.go7
-rw-r--r--internal/clients/stats.go42
-rw-r--r--internal/config/config.go3
-rw-r--r--internal/config/server.go1
-rw-r--r--internal/io/logger/logger.go2
-rw-r--r--internal/io/signal/signal.go29
-rw-r--r--internal/mapr/server/aggregate.go81
-rw-r--r--internal/regex/flag.go2
-rw-r--r--internal/regex/regex.go11
-rw-r--r--internal/server/continuous.go2
-rw-r--r--internal/server/handlers/controlhandler.go28
-rw-r--r--internal/server/handlers/handler.go2
-rw-r--r--internal/server/handlers/serverhandler.go182
-rw-r--r--internal/server/scheduler.go2
-rw-r--r--internal/server/server.go44
-rw-r--r--internal/version/version.go2
33 files changed, 248 insertions, 375 deletions
diff --git a/Makefile b/Makefile
index 75b9333..e97656c 100644
--- a/Makefile
+++ b/Makefile
@@ -5,7 +5,6 @@ build:
${GO} build -o dcat ./cmd/dcat/main.go
${GO} build -o dgrep ./cmd/dgrep/main.go
${GO} build -o dmap ./cmd/dmap/main.go
- ${GO} build -o drun ./cmd/drun/main.go
${GO} build -o dtail ./cmd/dtail/main.go
clean:
ls ./cmd/ | while read cmd; do \
@@ -16,7 +15,6 @@ install: build
cp -pv dcat ${GOPATH}/bin/dcat
cp -pv dgrep ${GOPATH}/bin/dgrep
cp -pv dmap ${GOPATH}/bin/dmap
- cp -pv drun ${GOPATH}/bin/drun
cp -pv dtail ${GOPATH}/bin/dtail
vet:
find . -type d | while read dir; do \
@@ -26,8 +24,8 @@ vet:
lint:
${GO} get golang.org/x/lint/golint
find . -type d | while read dir; do \
- echo ${GOPATH}/bin/golint $$dir; \
- ${GOPATH}/bin/golint $$dir; \
+ echo golint $$dir; \
+ golint $$dir; \
done
test:
${GO} test ./... -v
diff --git a/cmd/dcat/main.go b/cmd/dcat/main.go
index f0ea946..05e46ab 100644
--- a/cmd/dcat/main.go
+++ b/cmd/dcat/main.go
@@ -55,7 +55,7 @@ func main() {
panic(err)
}
- status := client.Start(ctx, signal.StatsCh(ctx))
+ status := client.Start(ctx, signal.InterruptCh(ctx))
logger.Flush()
os.Exit(status)
}
diff --git a/cmd/dgrep/main.go b/cmd/dgrep/main.go
index d1fdc21..133631f 100644
--- a/cmd/dgrep/main.go
+++ b/cmd/dgrep/main.go
@@ -63,7 +63,7 @@ func main() {
panic(err)
}
- status := client.Start(ctx, signal.StatsCh(ctx))
+ status := client.Start(ctx, signal.InterruptCh(ctx))
logger.Flush()
os.Exit(status)
}
diff --git a/cmd/dmap/main.go b/cmd/dmap/main.go
index 279b343..9f9ca9d 100644
--- a/cmd/dmap/main.go
+++ b/cmd/dmap/main.go
@@ -62,7 +62,7 @@ func main() {
panic(err)
}
- status := client.Start(ctx, signal.StatsCh(ctx))
+ status := client.Start(ctx, signal.InterruptCh(ctx))
logger.Flush()
os.Exit(status)
}
diff --git a/cmd/drun/main.go b/cmd/drun/main.go
deleted file mode 100644
index ffdf7bf..0000000
--- a/cmd/drun/main.go
+++ /dev/null
@@ -1,88 +0,0 @@
-package main
-
-import (
- "context"
- "flag"
- "io/ioutil"
- "os"
- "strings"
-
- "github.com/mimecast/dtail/internal/clients"
- "github.com/mimecast/dtail/internal/color"
- "github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/io/logger"
- "github.com/mimecast/dtail/internal/io/signal"
- "github.com/mimecast/dtail/internal/user"
- "github.com/mimecast/dtail/internal/version"
-)
-
-// The evil begins here.
-func main() {
- var args clients.Args
- var background string
- var cfgFile string
- var command string
- var debugEnable bool
- var displayVersion bool
- var jobName string
- var noColor bool
- var sshPort int
-
- userName := user.Name()
-
- flag.BoolVar(&args.TrustAllHosts, "trustAllHosts", false, "Auto trust all unknown host keys")
- flag.BoolVar(&debugEnable, "debug", false, "Activate debug messages")
- flag.BoolVar(&displayVersion, "version", false, "Display version")
- flag.BoolVar(&noColor, "noColor", false, "Disable ANSII terminal colors")
- flag.IntVar(&args.ConnectionsPerCPU, "cpc", 10, "How many connections established per CPU core concurrently")
- flag.IntVar(&args.Timeout, "timeout", 0, "Command execution timeout")
- flag.IntVar(&sshPort, "port", 2222, "SSH server port")
- flag.StringVar(&args.Discovery, "discovery", "", "Server discovery method")
- flag.StringVar(&args.PrivateKeyPathFile, "key", "", "Path to private key")
- flag.StringVar(&args.ServersStr, "servers", "", "Remote servers to connect")
- flag.StringVar(&args.UserName, "user", userName, "Your system user name")
- flag.StringVar(&background, "background", "", "Can be one of 'start', 'cancel', 'list' or empty")
- flag.StringVar(&cfgFile, "cfg", "", "Config file path")
- flag.StringVar(&command, "command", "", "Command to run")
- flag.StringVar(&jobName, "name", "", "The job name (if run in background)")
-
- flag.Parse()
-
- config.Read(cfgFile, sshPort)
- color.Colored = !noColor
-
- if displayVersion {
- version.PrintAndExit()
- }
-
- ctx := context.TODO()
- logger.Start(ctx, logger.Modes{Debug: debugEnable || config.Common.DebugEnable})
-
- args.What, args.Arguments = readCommand(command)
- client, err := clients.NewRunClient(args, background, jobName)
- if err != nil {
- panic(err)
- }
-
- status := client.Start(ctx, signal.StatsCh(ctx))
- logger.Flush()
- os.Exit(status)
-}
-
-func readCommand(command string) (string, []string) {
- splitted := strings.Split(command, " ")
-
- script := splitted[0]
- if _, err := os.Stat(script); os.IsNotExist(err) {
- var commandArgs []string
- return command, commandArgs
- }
- commandArgs := splitted[1:]
-
- bytes, err := ioutil.ReadFile(script)
- if err != nil {
- panic(err)
- }
-
- return string(bytes), commandArgs
-}
diff --git a/cmd/dtail/main.go b/cmd/dtail/main.go
index ff9028b..aefaa6a 100644
--- a/cmd/dtail/main.go
+++ b/cmd/dtail/main.go
@@ -106,7 +106,7 @@ func main() {
}
}
- status := client.Start(ctx, signal.StatsCh(ctx))
+ status := client.Start(ctx, signal.InterruptCh(ctx))
logger.Flush()
os.Exit(status)
}
diff --git a/doc/examples.md b/doc/examples.md
index 5a5d892..964660a 100644
--- a/doc/examples.md
+++ b/doc/examples.md
@@ -52,17 +52,6 @@ The following example demonstrates how to grep files (display only the lines whi
![dgrep](dgrep.gif "Grep example")
-# How to use ``drun``
-
-The following example demonstrates how to execute a command on multiple machines remotely:
-
-```shell
-% drun --servers <(head -n 30 serverlist.txt) \
- --command uptime
-```
-
-![dgrep](drun.gif "Run example")
-
# How to use ``dmap``
To run a mapreduce aggregation over logs written in the past the ``dmap`` command can be used. For example the following command aggregates all mapreduce fields of all the logs and calculates the average memory free grouped by day of the month, hour, minute and the server hostname. ``dmap`` will print interim results every few seconds. The final result however will be written to file ``mapreduce.csv``.
diff --git a/doc/quickstart.md b/doc/quickstart.md
index 733442f..6baedbb 100644
--- a/doc/quickstart.md
+++ b/doc/quickstart.md
@@ -16,7 +16,7 @@ On Linux you need to install the libacl development library for file system ACL
To compile and install all DTail binaries directly from GitHub run:
```console
-% for cmd in dcat dgrep dmap drun dtail dserver; do
+% for cmd in dcat dgrep dmap dtail dserver; do
go get github.com/mimecast/dtail/cmd/$cmd;
done
```
@@ -26,7 +26,6 @@ It produces the following executables in ``$GOPATH/bin``:
* ``dcat``: Client for displaying whole files remotely (distributed cat)
* ``dgrep``: Client for searching whole files files remotely using a regex (distributed grep)
* ``dmap``: Client for executing distributed mapreduce queries (may will consume a lot of RAM and CPU)
-* ``drun``: Client for executing commands on remote servers.
* ``dtail``: Client for tailing/following log files remotely (distributed tail)
* ``dserver``: The DTail server
diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go
index 008a01e..69055a3 100644
--- a/internal/clients/baseclient.go
+++ b/internal/clients/baseclient.go
@@ -66,7 +66,7 @@ func (c *baseClient) makeConnections(maker maker) {
c.stats = newTailStats(len(c.connections))
}
-func (c *baseClient) Start(ctx context.Context, statsCh <-chan struct{}) (status int) {
+func (c *baseClient) Start(ctx context.Context, statsCh <-chan string) (status int) {
// Periodically check for unknown hosts, and ask the user whether to trust them or not.
go c.hostKeyCallback.PromptAddHosts(ctx)
// Print client stats every time something on statsCh is recieved.
@@ -99,7 +99,7 @@ func (c *baseClient) start(ctx context.Context, active chan struct{}, i int, con
defer func() { <-active }()
for {
- connCtx, cancel := conn.Handler.WithCancel(ctx)
+ connCtx, cancel := context.WithCancel(ctx)
defer cancel()
conn.Start(connCtx, cancel, c.throttleCh, c.stats.connectionsEstCh)
diff --git a/internal/clients/client.go b/internal/clients/client.go
index eb8452d..4a547e8 100644
--- a/internal/clients/client.go
+++ b/internal/clients/client.go
@@ -4,5 +4,5 @@ import "context"
// Client is the interface for the end user command line client.
type Client interface {
- Start(ctx context.Context, statsCh <-chan struct{}) int
+ Start(ctx context.Context, statsCh <-chan string) int
}
diff --git a/internal/clients/handlers/basehandler.go b/internal/clients/handlers/basehandler.go
index 65bbfd7..b5045e2 100644
--- a/internal/clients/handlers/basehandler.go
+++ b/internal/clients/handlers/basehandler.go
@@ -8,12 +8,13 @@ import (
"strings"
"time"
+ "github.com/mimecast/dtail/internal"
"github.com/mimecast/dtail/internal/io/logger"
"github.com/mimecast/dtail/internal/version"
)
type baseHandler struct {
- withCancel
+ done *internal.Done
server string
shellStarted bool
commands chan string
@@ -29,6 +30,14 @@ func (h *baseHandler) Status() int {
return h.status
}
+func (h *baseHandler) Done() <-chan struct{} {
+ return h.done.Done()
+}
+
+func (h *baseHandler) Shutdown() {
+ h.done.Shutdown()
+}
+
// SendMessage to the server.
func (h *baseHandler) SendMessage(command string) error {
encoded := base64.StdEncoding.EncodeToString([]byte(command))
@@ -38,7 +47,8 @@ func (h *baseHandler) SendMessage(command string) error {
case h.commands <- fmt.Sprintf("protocol %s base64 %v;", version.ProtocolCompat, encoded):
case <-time.After(time.Second * 5):
return fmt.Errorf("Timed out sending command '%s' (base64: '%s')", command, encoded)
- case <-h.ctx.Done():
+ case <-h.Done():
+ return nil
}
return nil
@@ -65,7 +75,7 @@ func (h *baseHandler) Read(p []byte) (n int, err error) {
select {
case command := <-h.commands:
n = copy(p, []byte(command))
- case <-h.ctx.Done():
+ case <-h.Done():
return 0, io.EOF
}
return
@@ -95,10 +105,11 @@ func (h *baseHandler) handleHiddenMessage(message string) {
case strings.HasPrefix(message, ".syn close connection"):
h.SendMessage(".ack close connection")
select {
- case <-time.After(time.Second * 1):
+ case <-time.After(time.Second * 5):
logger.Debug("Shutting down client after timeout and sending ack to server")
- h.withCancel.shutdown()
- case <-h.ctx.Done():
+ h.Shutdown()
+ case <-h.Done():
+ return
}
case strings.HasPrefix(message, ".run exitstatus"):
diff --git a/internal/clients/handlers/clienthandler.go b/internal/clients/handlers/clienthandler.go
index fcd8052..2bcb038 100644
--- a/internal/clients/handlers/clienthandler.go
+++ b/internal/clients/handlers/clienthandler.go
@@ -1,6 +1,7 @@
package handlers
import (
+ "github.com/mimecast/dtail/internal"
"github.com/mimecast/dtail/internal/io/logger"
)
@@ -19,9 +20,7 @@ func NewClientHandler(server string) *ClientHandler {
shellStarted: false,
commands: make(chan string),
status: -1,
- withCancel: withCancel{
- done: make(chan struct{}),
- },
+ done: internal.NewDone(),
},
}
}
diff --git a/internal/clients/handlers/handler.go b/internal/clients/handlers/handler.go
index c53ca34..afa87e2 100644
--- a/internal/clients/handlers/handler.go
+++ b/internal/clients/handlers/handler.go
@@ -1,7 +1,6 @@
package handlers
import (
- "context"
"io"
)
@@ -11,6 +10,6 @@ type Handler interface {
SendMessage(command string) error
Server() string
Status() int
- WithCancel(ctx context.Context) (context.Context, context.CancelFunc)
+ Shutdown()
Done() <-chan struct{}
}
diff --git a/internal/clients/handlers/healthhandler.go b/internal/clients/handlers/healthhandler.go
index 9051015..08ed137 100644
--- a/internal/clients/handlers/healthhandler.go
+++ b/internal/clients/handlers/healthhandler.go
@@ -4,11 +4,13 @@ import (
"errors"
"fmt"
"time"
+
+ "github.com/mimecast/dtail/internal"
)
// HealthHandler implements the handler required for health checks.
type HealthHandler struct {
- withCancel
+ done *internal.Done
// Buffer of incoming data from server.
receiveBuf []byte
// To send commands to the server.
@@ -27,9 +29,7 @@ func NewHealthHandler(server string, receive chan<- string) *HealthHandler {
receive: receive,
commands: make(chan string),
status: -1,
- withCancel: withCancel{
- done: make(chan struct{}),
- },
+ done: internal.NewDone(),
}
return &h
@@ -45,12 +45,23 @@ func (h *HealthHandler) Status() int {
return h.status
}
+// Done returns done channel of the handler.
+func (h *HealthHandler) Done() <-chan struct{} {
+ return h.done.Done()
+}
+
+// Shutdown the handler.
+func (h *HealthHandler) Shutdown() {
+ h.done.Shutdown()
+}
+
// SendMessage sends a DTail command to the server.
func (h *HealthHandler) SendMessage(command string) error {
select {
case h.commands <- fmt.Sprintf("%s;", command):
case <-time.NewTimer(time.Second * 10).C:
return errors.New("Timed out sending command " + command)
+ case <-h.Done():
}
return nil
diff --git a/internal/clients/handlers/maprhandler.go b/internal/clients/handlers/maprhandler.go
index b908f3b..fb71c8f 100644
--- a/internal/clients/handlers/maprhandler.go
+++ b/internal/clients/handlers/maprhandler.go
@@ -3,6 +3,7 @@ package handlers
import (
"strings"
+ "github.com/mimecast/dtail/internal"
"github.com/mimecast/dtail/internal/io/logger"
"github.com/mimecast/dtail/internal/mapr"
"github.com/mimecast/dtail/internal/mapr/client"
@@ -24,9 +25,7 @@ func NewMaprHandler(server string, query *mapr.Query, globalGroup *mapr.GlobalGr
shellStarted: false,
commands: make(chan string),
status: -1,
- withCancel: withCancel{
- done: make(chan struct{}),
- },
+ done: internal.NewDone(),
},
query: query,
aggregate: client.NewAggregate(server, query, globalGroup),
diff --git a/internal/clients/healthclient.go b/internal/clients/healthclient.go
index 7313583..e93f6be 100644
--- a/internal/clients/healthclient.go
+++ b/internal/clients/healthclient.go
@@ -50,7 +50,7 @@ func (c *HealthClient) Start(ctx context.Context) (status int) {
conn.Handler = handlers.NewHealthHandler(c.server, receive)
conn.Commands = []string{c.mode.String()}
- connCtx, cancel := conn.Handler.WithCancel(ctx)
+ connCtx, cancel := context.WithCancel(ctx)
go conn.Start(connCtx, cancel, throttleCh, statsCh)
for {
diff --git a/internal/clients/maprclient.go b/internal/clients/maprclient.go
index 581db44..6aadd6b 100644
--- a/internal/clients/maprclient.go
+++ b/internal/clients/maprclient.go
@@ -94,7 +94,7 @@ func NewMaprClient(args Args, queryStr string, maprClientMode MaprClientMode) (*
}
// Start starts the mapreduce client.
-func (c *MaprClient) Start(ctx context.Context, statsCh <-chan struct{}) (status int) {
+func (c *MaprClient) Start(ctx context.Context, statsCh <-chan string) (status int) {
go c.periodicReportResults(ctx)
status = c.baseClient.Start(ctx, statsCh)
@@ -123,7 +123,7 @@ func (c MaprClient) makeCommands() (commands []string) {
commands = append(commands, fmt.Sprintf("timeout %d %s %s %s", c.Timeout, modeStr, file, c.Regex.Serialize()))
continue
}
- commands = append(commands, fmt.Sprintf("%s %s regex %s", modeStr, file, c.Regex.Serialize()))
+ commands = append(commands, fmt.Sprintf("%s %s %s", modeStr, file, c.Regex.Serialize()))
}
return
diff --git a/internal/clients/remote/connection.go b/internal/clients/remote/connection.go
index 2d97d14..b29ffed 100644
--- a/internal/clients/remote/connection.go
+++ b/internal/clients/remote/connection.go
@@ -177,21 +177,21 @@ func (c *Connection) handle(ctx context.Context, cancel context.CancelFunc, sess
}
go func() {
- defer cancel()
io.Copy(stdinPipe, c.Handler)
+ cancel()
}()
go func() {
- defer cancel()
io.Copy(c.Handler, stdoutPipe)
+ cancel()
}()
go func() {
- defer cancel()
select {
case <-c.Handler.Done():
case <-ctx.Done():
}
+ cancel()
}()
// Send all commands to client.
@@ -207,5 +207,6 @@ func (c *Connection) handle(ctx context.Context, cancel context.CancelFunc, sess
}
<-ctx.Done()
+ c.Handler.Shutdown()
return nil
}
diff --git a/internal/clients/stats.go b/internal/clients/stats.go
index a6ac0c5..17343b5 100644
--- a/internal/clients/stats.go
+++ b/internal/clients/stats.go
@@ -4,9 +4,11 @@ import (
"context"
"fmt"
"runtime"
+ "strings"
"sync"
"time"
+ "github.com/mimecast/dtail/internal/config"
"github.com/mimecast/dtail/internal/io/logger"
)
@@ -32,16 +34,18 @@ func newTailStats(connectionsTotal int) *stats {
// Start starts printing client connection stats every time a signal is recieved or
// connection count has changed.
-func (s *stats) Start(ctx context.Context, throttleCh, statsCh <-chan struct{}) {
+func (s *stats) Start(ctx context.Context, throttleCh <-chan struct{}, statsCh <-chan string) {
var connectedLast int
for {
var force bool
+ var messages []string
select {
- case <-statsCh:
+ case message := <-statsCh:
+ messages = append(messages, message)
force = true
- case <-time.After(time.Second * 2):
+ case <-time.After(time.Second * 10):
case <-ctx.Done():
return
}
@@ -54,7 +58,15 @@ func (s *stats) Start(ctx context.Context, throttleCh, statsCh <-chan struct{})
if connected == connectedLast && !force {
continue
}
- s.log(connected, newConnections, throttle)
+
+ stats := s.statsLine(connected, newConnections, throttle)
+ switch force {
+ case true:
+ messages = append(messages, fmt.Sprintf("Connection stats: %s", stats))
+ s.printStatsOnInterrupt(messages)
+ default:
+ logger.Info(stats)
+ }
connectedLast = connected
s.mutex.Lock()
@@ -63,15 +75,25 @@ func (s *stats) Start(ctx context.Context, throttleCh, statsCh <-chan struct{})
}
}
-func (s *stats) log(connected, newConnections int, throttle int) {
+func (s *stats) printStatsOnInterrupt(messages []string) {
+ logger.Pause()
+ for _, message := range messages {
+ fmt.Println(fmt.Sprintf(" %s", message))
+ }
+ time.Sleep(time.Second * time.Duration(config.InterruptTimeoutS))
+ logger.Resume()
+}
+
+func (s *stats) statsLine(connected, newConnections int, throttle int) string {
percConnected := percentOf(float64(s.connectionsTotal), float64(connected))
- connectedStr := fmt.Sprintf("connected=%d/%d(%d%%)", connected, s.connectionsTotal, int(percConnected))
- newConnStr := fmt.Sprintf("new=%d", newConnections)
- throttleStr := fmt.Sprintf("throttle=%d", throttle)
- cpusGoroutinesStr := fmt.Sprintf("cpus/goroutines=%d/%d", runtime.NumCPU(), runtime.NumGoroutine())
+ var stats []string
+ stats = append(stats, fmt.Sprintf("connected=%d/%d(%d%%)", connected, s.connectionsTotal, int(percConnected)))
+ stats = append(stats, fmt.Sprintf("new=%d", newConnections))
+ stats = append(stats, fmt.Sprintf("throttle=%d", throttle))
+ stats = append(stats, fmt.Sprintf("cpus/goroutines=%d/%d", runtime.NumCPU(), runtime.NumGoroutine()))
- logger.Info("stats", connectedStr, newConnStr, throttleStr, cpusGoroutinesStr)
+ return strings.Join(stats, "|")
}
func (s *stats) numConnected() int {
diff --git a/internal/config/config.go b/internal/config/config.go
index dc96d6b..276ddcf 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -15,6 +15,9 @@ const ScheduleUser string = "DTAIL-SCHEDULE"
// ContinuousUser is used for non-interactive continuous mapreduce queries.
const ContinuousUser string = "DTAIL-CONTINUOUS"
+// InterruptTimeoutS is used to terminate DTail when Ctrl+C was pressed twice within a given interval.
+const InterruptTimeoutS int = 3
+
// Client holds a DTail client configuration.
var Client *ClientConfig
diff --git a/internal/config/server.go b/internal/config/server.go
index db12cec..dc0d587 100644
--- a/internal/config/server.go
+++ b/internal/config/server.go
@@ -61,6 +61,7 @@ type ServerConfig struct {
Continuous []Continuous `json:",omitempty"`
}
+// ServerRelaxedAuthEnable should be used for development and testing purposes only.
var ServerRelaxedAuthEnable bool
// Create a new default server configuration.
diff --git a/internal/io/logger/logger.go b/internal/io/logger/logger.go
index d059cbb..b7db0a7 100644
--- a/internal/io/logger/logger.go
+++ b/internal/io/logger/logger.go
@@ -224,7 +224,7 @@ func log(what string, severity string, args []interface{}) string {
return ""
}
- messages := []string{severity}
+ messages := []string{}
for _, arg := range args {
switch v := arg.(type) {
diff --git a/internal/io/signal/signal.go b/internal/io/signal/signal.go
index bca7e6e..500c530 100644
--- a/internal/io/signal/signal.go
+++ b/internal/io/signal/signal.go
@@ -5,24 +5,37 @@ import (
"os"
gosignal "os/signal"
"syscall"
+ "time"
+
+ "github.com/mimecast/dtail/internal/config"
)
-// StatsCh returns a channel for "please print stats" signalling.
-func StatsCh(ctx context.Context) <-chan struct{} {
- sigCh := make(chan os.Signal)
- gosignal.Notify(sigCh, syscall.SIGINFO, syscall.SIGUSR1)
+// InterruptCh returns a channel for "please print stats" signalling.
+func InterruptCh(ctx context.Context) <-chan string {
+ sigIntCh := make(chan os.Signal)
+ gosignal.Notify(sigIntCh, os.Interrupt)
+
+ sigOtherCh := make(chan os.Signal)
+ gosignal.Notify(sigOtherCh, syscall.SIGHUP, syscall.SIGTERM, syscall.SIGQUIT)
- statsCh := make(chan struct{})
+ statsCh := make(chan string)
go func() {
for {
select {
- case <-sigCh:
+ case <-sigIntCh:
select {
- case statsCh <- struct{}{}:
+ case statsCh <- "Hint: Hit Ctrl+C again to exit":
+ select {
+ case <-sigIntCh:
+ os.Exit(0)
+ case <-time.After(time.Second * time.Duration(config.InterruptTimeoutS)):
+ }
default:
- // Stats currently already printed.
+ // Stats already printed.
}
+ case <-sigOtherCh:
+ os.Exit(0)
case <-ctx.Done():
return
}
diff --git a/internal/mapr/server/aggregate.go b/internal/mapr/server/aggregate.go
index 1028943..28bb074 100644
--- a/internal/mapr/server/aggregate.go
+++ b/internal/mapr/server/aggregate.go
@@ -6,6 +6,7 @@ import (
"strings"
"time"
+ "github.com/mimecast/dtail/internal"
"github.com/mimecast/dtail/internal/config"
"github.com/mimecast/dtail/internal/io/line"
"github.com/mimecast/dtail/internal/io/logger"
@@ -15,6 +16,7 @@ import (
// Aggregate is for aggregating mapreduce data on the DTail server side.
type Aggregate struct {
+ done *internal.Done
// Log lines to process (parsing MAPREDUCE lines).
Lines chan line.Line
// Hostname of the current server (used to populate $hostname field).
@@ -23,12 +25,12 @@ type Aggregate struct {
serialize chan struct{}
// Signals to flush data.
flush chan struct{}
+ // Signals that data has been flushed
+ flushed chan struct{}
// The mapr query
query *mapr.Query
// The mapr log format parser
parser *logformat.Parser
- cancel context.CancelFunc
- ctx context.Context
}
// NewAggregate return a new server side aggregator.
@@ -64,56 +66,64 @@ func NewAggregate(queryStr string) (*Aggregate, error) {
}
}
- ctx, cancel := context.WithCancel(context.Background())
-
a := Aggregate{
+ done: internal.NewDone(),
Lines: make(chan line.Line, 100),
serialize: make(chan struct{}),
flush: make(chan struct{}),
+ flushed: make(chan struct{}),
hostname: s[0],
query: query,
parser: logParser,
- ctx: ctx,
- cancel: cancel,
}
return &a, nil
}
+// Shutdown the aggregation engine.
+func (a *Aggregate) Shutdown() {
+ a.Flush()
+ a.done.Shutdown()
+}
+
// Start an aggregation.
func (a *Aggregate) Start(ctx context.Context, maprLines chan<- string) {
- defer a.cancel()
- fieldsCh := a.linesToFields(ctx)
+ myCtx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ go func() {
+ select {
+ case <-myCtx.Done():
+ a.done.Shutdown()
+ case <-a.done.Done():
+ cancel()
+ }
+ }()
+
+ fieldsCh := a.makeFields(myCtx)
// Add fields (e.g. via 'set' clause)
if len(a.query.Set) > 0 {
- fieldsCh = a.addMoreFields(ctx, fieldsCh)
+ fieldsCh = a.addFields(myCtx, fieldsCh)
}
- go a.fieldsToMaprLines(ctx, fieldsCh, maprLines)
- a.periodicAggregateTimer(ctx)
+ go a.aggregateTimer(myCtx)
+ a.makeMaprLines(myCtx, fieldsCh, maprLines)
}
-// Cancel the aggregation.
-func (a *Aggregate) Cancel() {
- a.cancel()
-}
-
-func (a *Aggregate) periodicAggregateTimer(ctx context.Context) {
+func (a *Aggregate) aggregateTimer(ctx context.Context) {
for {
select {
case <-time.After(a.query.Interval):
a.Serialize(ctx)
case <-ctx.Done():
return
- case <-a.ctx.Done():
- return
}
}
}
-func (a *Aggregate) linesToFields(ctx context.Context) <-chan map[string]string {
+func (a *Aggregate) makeFields(ctx context.Context) <-chan map[string]string {
ch := make(chan map[string]string)
go func() {
@@ -144,8 +154,6 @@ func (a *Aggregate) linesToFields(ctx context.Context) <-chan map[string]string
}
case <-ctx.Done():
return
- case <-a.ctx.Done():
- return
}
}
}()
@@ -153,14 +161,14 @@ func (a *Aggregate) linesToFields(ctx context.Context) <-chan map[string]string
return ch
}
-func (a *Aggregate) addMoreFields(ctx context.Context, fieldsCh <-chan map[string]string) <-chan map[string]string {
+func (a *Aggregate) addFields(ctx context.Context, fieldsCh <-chan map[string]string) <-chan map[string]string {
ch := make(chan map[string]string)
go func() {
defer close(ch)
for {
- // fieldsCh will be closed via 'linesToFields' if ctx is done
+ // fieldsCh will be closed via 'makeFields' if ctx is done
fields, ok := <-fieldsCh
if !ok {
return
@@ -179,7 +187,7 @@ func (a *Aggregate) addMoreFields(ctx context.Context, fieldsCh <-chan map[strin
return ch
}
-func (a *Aggregate) fieldsToMaprLines(ctx context.Context, fieldsCh <-chan map[string]string, maprLines chan<- string) {
+func (a *Aggregate) makeMaprLines(ctx context.Context, fieldsCh <-chan map[string]string, maprLines chan<- string) {
group := mapr.NewGroupSet()
serialize := func() {
@@ -200,18 +208,10 @@ func (a *Aggregate) fieldsToMaprLines(ctx context.Context, fieldsCh <-chan map[s
case <-a.serialize:
serialize()
case <-a.flush:
- logger.Info("Flushing mapreduce result")
serialize()
- a.flush <- struct{}{}
- logger.Info("Done flushing mapreduce result")
+ a.flushed <- struct{}{}
case <-ctx.Done():
return
- case <-a.ctx.Done():
- logger.Info("Flushing mapreduce result")
- serialize()
- a.flush <- struct{}{}
- logger.Info("Done flushing mapreduce result")
- return
}
}
}
@@ -254,6 +254,8 @@ func (a *Aggregate) aggregate(group *mapr.GroupSet, fields map[string]string) {
func (a *Aggregate) Serialize(ctx context.Context) {
select {
case a.serialize <- struct{}{}:
+ case <-time.After(time.Minute):
+ logger.Warn("Starting to serialize mapredice data takes over a minute")
case <-ctx.Done():
}
}
@@ -261,15 +263,20 @@ func (a *Aggregate) Serialize(ctx context.Context) {
// Flush all data.
func (a *Aggregate) Flush() {
select {
- case <-a.ctx.Done():
- return
case a.flush <- struct{}{}:
+ logger.Info("Flushing mapreduce data")
case <-time.After(time.Minute):
+ logger.Warn("Starting to flush mapreduce data takes over a minute")
+ return
+ case <-a.done.Done():
return
}
select {
- case <-a.flush:
+ case <-a.flushed:
+ logger.Info("Done flushing")
case <-time.After(time.Minute):
+ logger.Warn("Waiting for data to be flushed takes over a minute")
+ case <-a.done.Done():
}
}
diff --git a/internal/regex/flag.go b/internal/regex/flag.go
index d3ff712..396bda0 100644
--- a/internal/regex/flag.go
+++ b/internal/regex/flag.go
@@ -2,6 +2,7 @@ package regex
import "fmt"
+// Flag for regex.
type Flag int
const (
@@ -15,6 +16,7 @@ const (
Noop Flag = iota
)
+// NewFlag returns a new regex flag.
func NewFlag(str string) (Flag, error) {
switch str {
case "default":
diff --git a/internal/regex/regex.go b/internal/regex/regex.go
index 707cb48..2561659 100644
--- a/internal/regex/regex.go
+++ b/internal/regex/regex.go
@@ -8,6 +8,7 @@ import (
"github.com/mimecast/dtail/internal/io/logger"
)
+// Regex for filtering lines.
type Regex struct {
// The original regex string
regexStr string
@@ -24,6 +25,7 @@ func (r Regex) String() string {
r.regexStr, r.flags, r.initialized, r.re == nil)
}
+// NewNoop is a noop regex (doing nothing).
func NewNoop() Regex {
return Regex{
flags: []Flag{Noop},
@@ -31,6 +33,7 @@ func NewNoop() Regex {
}
}
+// New returns a new regex object.
func New(regexStr string, flag Flag) (Regex, error) {
if regexStr == "" || regexStr == "." || regexStr == ".*" {
return NewNoop(), nil
@@ -39,6 +42,10 @@ func New(regexStr string, flag Flag) (Regex, error) {
}
func new(regexStr string, flags []Flag) (Regex, error) {
+ if len(flags) == 0 {
+ flags = append(flags, Default)
+ }
+
r := Regex{
regexStr: regexStr,
flags: flags,
@@ -55,6 +62,7 @@ func new(regexStr string, flags []Flag) (Regex, error) {
return r, nil
}
+// Match a byte string.
func (r Regex) Match(bytes []byte) bool {
switch r.flags[0] {
case Default:
@@ -68,6 +76,7 @@ func (r Regex) Match(bytes []byte) bool {
}
}
+// MatchString matches a string.
func (r Regex) MatchString(str string) bool {
switch r.flags[0] {
case Default:
@@ -81,6 +90,7 @@ func (r Regex) MatchString(str string) bool {
}
}
+// Serialize the regex.
func (r Regex) Serialize() string {
var flags []string
for _, flag := range r.flags {
@@ -94,6 +104,7 @@ func (r Regex) Serialize() string {
return fmt.Sprintf("regex:%s %s", strings.Join(flags, ","), r.regexStr)
}
+// Deserialize the regex.
func Deserialize(str string) (Regex, error) {
// Get regex string
s := strings.SplitN(str, " ", 2)
diff --git a/internal/server/continuous.go b/internal/server/continuous.go
index 583d136..f75c732 100644
--- a/internal/server/continuous.go
+++ b/internal/server/continuous.go
@@ -92,7 +92,7 @@ func (c *continuous) runJob(ctx context.Context, job config.Continuous) {
}
logger.Info(fmt.Sprintf("Starting job %s", job.Name))
- status := client.Start(jobCtx, make(chan struct{}))
+ status := client.Start(jobCtx, make(chan string))
logMessage := fmt.Sprintf("Job exited with status %d", status)
if status != 0 {
diff --git a/internal/server/handlers/controlhandler.go b/internal/server/handlers/controlhandler.go
index daa9835..8cc5a40 100644
--- a/internal/server/handlers/controlhandler.go
+++ b/internal/server/handlers/controlhandler.go
@@ -1,20 +1,19 @@
package handlers
import (
- "context"
"fmt"
"io"
"os"
"strings"
+ "github.com/mimecast/dtail/internal"
"github.com/mimecast/dtail/internal/io/logger"
user "github.com/mimecast/dtail/internal/user/server"
)
// ControlHandler is used for control functions and health monitoring.
type ControlHandler struct {
- ctx context.Context
- done chan struct{}
+ done *internal.Done
hostname string
payload []byte
serverMessages chan string
@@ -22,12 +21,11 @@ type ControlHandler struct {
}
// NewControlHandler returns a new control handler.
-func NewControlHandler(ctx context.Context, user *user.User) (*ControlHandler, <-chan struct{}) {
+func NewControlHandler(user *user.User) *ControlHandler {
logger.Debug(user, "Creating control handler")
h := ControlHandler{
- ctx: ctx,
- done: make(chan struct{}),
+ done: internal.NewDone(),
serverMessages: make(chan string, 10),
user: user,
}
@@ -40,7 +38,17 @@ func NewControlHandler(ctx context.Context, user *user.User) (*ControlHandler, <
s := strings.Split(fqdn, ".")
h.hostname = s[0]
- return &h, h.done
+ return &h
+}
+
+// Shutdown the handler.
+func (h *ControlHandler) Shutdown() {
+ h.done.Shutdown()
+}
+
+// Done channel of the handler.
+func (h *ControlHandler) Done() <-chan struct{} {
+ return h.done.Done()
}
// Read is to send data to the client via the Reader interface.
@@ -51,7 +59,7 @@ func (h *ControlHandler) Read(p []byte) (n int, err error) {
wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message))
n = copy(p, wholePayload)
return
- case <-h.ctx.Done():
+ case <-h.done.Done():
return 0, io.EOF
}
}
@@ -63,7 +71,7 @@ func (h *ControlHandler) Write(p []byte) (n int, err error) {
switch c {
case ';':
wholePayload := strings.TrimSpace(string(h.payload))
- h.handleCommand(h.ctx, wholePayload)
+ h.handleCommand(wholePayload)
h.payload = nil
default:
@@ -75,7 +83,7 @@ func (h *ControlHandler) Write(p []byte) (n int, err error) {
return
}
-func (h *ControlHandler) handleCommand(ctx context.Context, command string) {
+func (h *ControlHandler) handleCommand(command string) {
logger.Info(h.user, command)
s := strings.Split(command, " ")
logger.Debug(h.user, "Receiving command", command, s)
diff --git a/internal/server/handlers/handler.go b/internal/server/handlers/handler.go
index c42ceb9..b04e854 100644
--- a/internal/server/handlers/handler.go
+++ b/internal/server/handlers/handler.go
@@ -5,4 +5,6 @@ import "io"
// Handler interface for server side functionality.
type Handler interface {
io.ReadWriter
+ Shutdown()
+ Done() <-chan struct{}
}
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go
index 7017f3e..5cf8041 100644
--- a/internal/server/handlers/serverhandler.go
+++ b/internal/server/handlers/serverhandler.go
@@ -7,18 +7,16 @@ import (
"fmt"
"io"
"os"
- "strconv"
"strings"
- "sync"
"sync/atomic"
"time"
+ "github.com/mimecast/dtail/internal"
"github.com/mimecast/dtail/internal/config"
"github.com/mimecast/dtail/internal/io/line"
"github.com/mimecast/dtail/internal/io/logger"
"github.com/mimecast/dtail/internal/mapr/server"
"github.com/mimecast/dtail/internal/omode"
- "github.com/mimecast/dtail/internal/server/background"
user "github.com/mimecast/dtail/internal/user/server"
"github.com/mimecast/dtail/internal/version"
)
@@ -31,33 +29,27 @@ const (
// the Bi-directional communication between SSH client and server.
// This handler implements the handler of the SSH server.
type ServerHandler struct {
- lines chan line.Line
- regex string
- aggregate *server.Aggregate
- aggregatedMessages chan string
- serverMessages chan string
- payload []byte
- hostname string
- user *user.User
- // TODO: Move all these channels into a separate struct for readability!
+ done *internal.Done
+ lines chan line.Line
+ regex string
+ aggregate *server.Aggregate
+ aggregatedMessages chan string
+ serverMessages chan string
+ payload []byte
+ hostname string
+ user *user.User
catLimiter chan struct{}
tailLimiter chan struct{}
globalServerWaitFor chan struct{}
ackCloseReceived chan struct{}
- serverCtx context.Context
- handlerCtx context.Context
- done chan struct{}
activeCommands int32
activeReaders int32
- background background.Background
}
// NewServerHandler returns the server handler.
-func NewServerHandler(handlerCtx, serverCtx context.Context, user *user.User, catLimiter, tailLimiter, globalServerWaitFor chan struct{}, background background.Background) (*ServerHandler, <-chan struct{}) {
+func NewServerHandler(user *user.User, catLimiter, tailLimiter, globalServerWaitFor chan struct{}) *ServerHandler {
h := ServerHandler{
- serverCtx: serverCtx,
- handlerCtx: handlerCtx,
- done: make(chan struct{}),
+ done: internal.NewDone(),
lines: make(chan line.Line, 100),
serverMessages: make(chan string, 10),
aggregatedMessages: make(chan string, 10),
@@ -67,7 +59,6 @@ func NewServerHandler(handlerCtx, serverCtx context.Context, user *user.User, ca
globalServerWaitFor: globalServerWaitFor,
regex: ".",
user: user,
- background: background,
}
fqdn, err := os.Hostname()
@@ -78,7 +69,17 @@ func NewServerHandler(handlerCtx, serverCtx context.Context, user *user.User, ca
s := strings.Split(fqdn, ".")
h.hostname = s[0]
- return &h, h.done
+ return &h
+}
+
+// Shutdown the handler.
+func (h *ServerHandler) Shutdown() {
+ h.done.Shutdown()
+}
+
+// Done channel of the handler.
+func (h *ServerHandler) Done() <-chan struct{} {
+ return h.done.Done()
}
// Read is to send data to the dtail client via Reader interface.
@@ -120,7 +121,7 @@ func (h *ServerHandler) Read(p []byte) (n int, err error) {
case <-time.After(time.Second):
// Once in a while check whether we are done.
select {
- case <-h.handlerCtx.Done():
+ case <-h.done.Done():
return 0, io.EOF
default:
}
@@ -134,7 +135,7 @@ func (h *ServerHandler) Write(p []byte) (n int, err error) {
switch c {
case ';':
commandStr := strings.TrimSpace(string(h.payload))
- h.handleCommand(h.handlerCtx, commandStr)
+ h.handleCommand(commandStr)
h.payload = nil
default:
h.payload = append(h.payload, c)
@@ -145,9 +146,9 @@ func (h *ServerHandler) Write(p []byte) (n int, err error) {
return
}
-func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) {
+func (h *ServerHandler) handleCommand(commandStr string) {
logger.Debug(h.user, commandStr)
- var timeout time.Duration
+ ctx := context.Background()
args, argc, err := h.handleProtocolVersion(strings.Split(commandStr, " "))
if err != nil {
@@ -161,30 +162,18 @@ func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) {
return
}
- args, argc, timeout, err = h.handleTimeout(args, argc)
- if err != nil {
- h.send(h.serverMessages, logger.Error(h.user, err))
- return
- }
-
if h.user.Name == config.ControlUser {
h.handleControlCommand(argc, args)
return
}
- if timeout > 0 {
- logger.Info(h.user, "Command with timeout context", argc, args, timeout)
- commandCtx, cancel := context.WithTimeout(ctx, timeout)
- go func() {
- <-commandCtx.Done()
- logger.Info(h.user, "Command timed out, canceling it", args, args, timeout)
- cancel()
- }()
- h.handleUserCommand(commandCtx, argc, args, timeout)
- return
- }
+ ctx, cancel := context.WithCancel(ctx)
+ go func() {
+ <-h.done.Done()
+ cancel()
+ }()
- h.handleUserCommand(ctx, argc, args, timeout)
+ h.handleUserCommand(ctx, argc, args)
}
func (h *ServerHandler) handleProtocolVersion(args []string) ([]string, int, error) {
@@ -222,16 +211,6 @@ func (h *ServerHandler) handleBase64(args []string, argc int) ([]string, int, er
return args, argc, nil
}
-func (h *ServerHandler) handleTimeout(args []string, argc int) ([]string, int, time.Duration, error) {
- if argc <= 2 || args[0] != "timeout" {
- // No timeout specified
- return args, argc, time.Duration(0) * time.Second, nil
- }
-
- timeout, err := strconv.Atoi(args[1])
- return args[2:], argc - 2, time.Duration(timeout) * time.Second, err
-}
-
func (h *ServerHandler) handleControlCommand(argc int, args []string) {
switch args[0] {
case "debug":
@@ -241,7 +220,7 @@ func (h *ServerHandler) handleControlCommand(argc int, args []string) {
}
}
-func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []string, timeout time.Duration) {
+func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []string) {
logger.Debug(h.user, "handleUserCommand", argc, args)
h.incrementActiveCommands()
@@ -255,7 +234,7 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
if h.aggregate == nil {
return
}
- h.aggregate.Cancel()
+ h.aggregate.Shutdown()
}
}
@@ -303,86 +282,6 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
commandFinished()
}()
- case "run":
- // TODO: Refactor this "run" case, move code to runcommand.go
- command := newRunCommand(h)
-
- jobName, _ := options["jobName"]
- logger.Debug(h.user, "run", options)
-
- if val, ok := options["background"]; ok && (val == "cancel" || val == "stop") {
- if err := h.background.Cancel(h.user.Name, jobName); err != nil {
- h.sendServerMessage(logger.Error(h.user, err, jobName, args))
- } else {
- h.sendServerMessage(logger.Info(h.user, "job cancelled", jobName))
- }
- commandFinished()
- return
- }
-
- if val, ok := options["background"]; ok && val == "list" {
- h.sendServerMessage("Listing jobs")
- count := 0
- for jobName := range h.background.ListJobsC(h.user.Name) {
- h.sendServerMessage(jobName)
- count++
- }
- h.sendServerMessage(fmt.Sprintf("Found %d jobs", count))
- commandFinished()
- return
- }
-
- str, _ := options["outerArgs"]
- outerArgs := strings.Split(str, " ")
-
- var background bool
- if val, ok := options["background"]; ok && val == "start" {
- background = true
- }
-
- var wg sync.WaitGroup
- wg.Add(1)
-
- if background {
- if timeout == 0 {
- // Set default background timeout.
- timeout = time.Hour * 1
- }
- // Use a new context based on the server context, so that background job does not get
- // terminated when handler/SSH connection terminates.
- commandCtx, cancel := context.WithTimeout(h.serverCtx, timeout)
-
- if err := h.background.Add(h.user.Name, jobName, cancel, &wg); err != nil {
- h.sendServerMessage(logger.Error(h.user, err, jobName, args))
- commandFinished()
- return
- }
- ctx = commandCtx
- }
-
- if err := command.StartBackground(ctx, &wg, argc, args, outerArgs); err != nil {
- h.sendServerMessage(logger.Error(h.user, "Unable to execute command", argc, args, err))
- commandFinished()
- return
- }
-
- // Make sure that server waits for all sub-processes to finish on shutdown
- go func() { h.globalServerWaitFor <- struct{}{} }()
- go func() {
- wg.Wait()
- <-h.globalServerWaitFor
- }()
-
- if background {
- h.sendServerMessage(logger.Info(h.user, jobName, "job started in background"))
- commandFinished()
- return
- }
-
- // Command run in foreground, wait for it to complete before finishing the connection.
- wg.Wait()
- commandFinished()
-
case "ack", ".ack":
h.handleAckCommand(argc, args)
commandFinished()
@@ -406,7 +305,7 @@ func (h *ServerHandler) handleAckCommand(argc int, args []string) {
func (h *ServerHandler) send(ch chan<- string, message string) {
select {
case ch <- message:
- case <-h.handlerCtx.Done():
+ case <-h.done.Done():
}
}
@@ -447,7 +346,7 @@ func (h *ServerHandler) shutdown() {
go func() {
select {
case h.serverMessageC() <- ".syn close connection":
- case <-h.handlerCtx.Done():
+ case <-h.done.Done():
}
}()
@@ -455,13 +354,10 @@ func (h *ServerHandler) shutdown() {
case <-h.ackCloseReceived:
case <-time.After(time.Second * 5):
logger.Debug(h.user, "Shutdown timeout reached, enforcing shutdown")
- case <-h.handlerCtx.Done():
+ case <-h.done.Done():
}
- select {
- case h.done <- struct{}{}:
- default:
- }
+ h.done.Shutdown()
}
func (h *ServerHandler) incrementActiveCommands() {
diff --git a/internal/server/scheduler.go b/internal/server/scheduler.go
index 9d76a3b..a1e9e36 100644
--- a/internal/server/scheduler.go
+++ b/internal/server/scheduler.go
@@ -93,7 +93,7 @@ func (s *scheduler) runJobs(ctx context.Context) {
defer cancel()
logger.Info(fmt.Sprintf("Starting job %s", job.Name))
- status := client.Start(jobCtx, make(chan struct{}))
+ status := client.Start(jobCtx, make(chan string))
logMessage := fmt.Sprintf("Job exited with status %d", status)
if status != 0 {
diff --git a/internal/server/server.go b/internal/server/server.go
index a446738..31fa85d 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -11,7 +11,6 @@ import (
"github.com/mimecast/dtail/internal/config"
"github.com/mimecast/dtail/internal/io/logger"
- "github.com/mimecast/dtail/internal/server/background"
"github.com/mimecast/dtail/internal/server/handlers"
"github.com/mimecast/dtail/internal/ssh/server"
user "github.com/mimecast/dtail/internal/user/server"
@@ -35,9 +34,8 @@ type Server struct {
// Mointor log files for pattern (if configured)
cont *continuous
// Wait counter, e.g. there might be still subprocesses (forked by drun) to be killed.
+ // TODO: Remove this counter.
shutdownWaitFor chan struct{}
- // Background jobs
- background background.Background
}
// New returns a new server.
@@ -51,7 +49,6 @@ func New() *Server {
shutdownWaitFor: make(chan struct{}, 1000),
sched: newScheduler(),
cont: newContinuous(),
- background: background.New(),
}
s.sshServerConfig.PasswordCallback = s.Callback
@@ -178,53 +175,46 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-ch
switch req.Type {
case "shell":
- handlerCtx, cancel := context.WithCancel(ctx)
-
var handler handlers.Handler
- var done <-chan struct{}
-
switch user.Name {
case config.ControlUser:
- handler, done = handlers.NewControlHandler(handlerCtx, user)
+ handler = handlers.NewControlHandler(user)
default:
- handler, done = handlers.NewServerHandler(handlerCtx, ctx, user, s.catLimiter, s.tailLimiter, s.shutdownWaitFor, s.background)
+ handler = handlers.NewServerHandler(user, s.catLimiter, s.tailLimiter, s.shutdownWaitFor)
}
- go func() {
- // Handler finished work, cancel all remaining routines
- defer cancel()
-
- <-done
- }()
+ terminate := func() {
+ handler.Shutdown()
+ sshConn.Close()
+ }
go func() {
// Broken pipe, cancel
- defer cancel()
-
io.Copy(channel, handler)
+ terminate()
}()
go func() {
// Broken pipe, cancel
- defer cancel()
-
io.Copy(handler, channel)
+ terminate()
}()
go func() {
- defer cancel()
+ select {
+ case <-ctx.Done():
+ case <-handler.Done():
+ }
+ terminate()
+ }()
+ go func() {
if err := sshConn.Wait(); err != nil && err != io.EOF {
logger.Error(user, err)
}
s.stats.decrementConnections()
logger.Info(user, "Good bye Mister!")
- }()
-
- go func() {
- <-handlerCtx.Done()
- sshConn.Close()
- logger.Info(user, "Closed SSH connection")
+ terminate()
}()
// Only serving shell type
diff --git a/internal/version/version.go b/internal/version/version.go
index 36ef62c..b513b40 100644
--- a/internal/version/version.go
+++ b/internal/version/version.go
@@ -13,7 +13,7 @@ const (
// Version of DTail.
Version string = "3.1.0"
// Additional information for DTail
- Additional string = "develop"
+ Additional string = ""
// ProtocolCompat -ibility version.
ProtocolCompat string = "3"
)