diff options
| author | Paul Bütow <pbuetow@mimecast.com> | 2020-01-26 11:26:53 +0000 |
|---|---|---|
| committer | Paul Bütow <pbuetow@mimecast.com> | 2020-02-07 13:31:15 +0000 |
| commit | 0945da8dfefcbb723eecea0e5f4eafff63398253 (patch) | |
| tree | f06dab4d2bf21d25d176b23d5baeca588d27f5d7 | |
| parent | 2a8e5de265a0e0a31a5834909d6879f5c9941467 (diff) | |
Introduce drun command, refactor code to use context package
75 files changed, 1271 insertions, 1081 deletions
@@ -1,29 +1,31 @@ GO ?= go all: build build: + ${GO} build -o dserver ./cmd/dserver/main.go ${GO} build -o dcat ./cmd/dcat/main.go - ${GO} build -o dexec ./cmd/dexec/main.go ${GO} build -o dgrep ./cmd/dgrep/main.go ${GO} build -o dmap ./cmd/dmap/main.go - ${GO} build -o dserver ./cmd/dserver/main.go + ${GO} build -o drun ./cmd/drun/main.go ${GO} build -o dtail ./cmd/dtail/main.go clean: - rm -v dtail dgrep dcat dmap dserver dexec 2>/dev/null + ls ./cmd/ | while read cmd; do \ + test -f $$cmd && rm $$cmd; \ + done install: build + cp -pv dserver ${GOPATH}/bin/dserver cp -pv dcat ${GOPATH}/bin/dcat - cp -pv dexec ${GOPATH}/bin/dexec cp -pv dgrep ${GOPATH}/bin/dgrep cp -pv dmap ${GOPATH}/bin/dmap - cp -pv dserver ${GOPATH}/bin/dserver + cp -pv drun ${GOPATH}/bin/drun cp -pv dtail ${GOPATH}/bin/dtail vet: find . -type d | while read dir; do \ echo ${GO} vet $$dir; \ ${GO} vet $$dir; \ - done + done lint: ${GO} get golang.org/x/lint/golint find . -type d | while read dir; do \ echo ${GOPATH}/bin/golint $$dir; \ ${GOPATH}/bin/golint $$dir; \ - done + done diff --git a/cmd/dcat/main.go b/cmd/dcat/main.go index b02d369..1ec945d 100644 --- a/cmd/dcat/main.go +++ b/cmd/dcat/main.go @@ -1,12 +1,14 @@ package main import ( + "context" "flag" + "os" "github.com/mimecast/dtail/internal/clients" "github.com/mimecast/dtail/internal/color" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/pprof" "github.com/mimecast/dtail/internal/user" "github.com/mimecast/dtail/internal/version" @@ -27,7 +29,6 @@ func main() { var sshPort int var trustAllHosts bool - pingTimeoutS := 60 userName := user.Name() flag.BoolVar(&debugEnable, "debug", false, "Activate debug messages") @@ -37,7 +38,6 @@ func main() { flag.BoolVar(&silentEnable, "silent", false, "Reduce output") flag.BoolVar(&trustAllHosts, "trustAllHosts", false, "Auto trust all unknown host keys") flag.IntVar(&connectionsPerCPU, "cpc", 10, "How many connections established per CPU core concurrently") - flag.IntVar(&pingTimeoutS, "pingTimeout", 10, "The server ping timeout (0 means disable pings)") flag.IntVar(&sshPort, "port", 2222, "SSH server port") flag.StringVar(&cfgFile, "cfg", "", "Config file path") flag.StringVar(&discovery, "discovery", "", "Server discovery method") @@ -54,9 +54,10 @@ func main() { version.PrintAndExit() } + ctx := context.Background() serverEnable := false - logger.Start(serverEnable, debugEnable, silentEnable, silentEnable) - defer logger.Stop() + + logger.Start(ctx, serverEnable, debugEnable, silentEnable, silentEnable) if pprofEnable || config.Common.PProfEnable { pprof.Start() @@ -67,14 +68,16 @@ func main() { ServersStr: serversStr, Discovery: discovery, UserName: userName, - Files: files, + What: files, TrustAllHosts: trustAllHosts, - PingTimeout: pingTimeoutS, } client, err := clients.NewCatClient(args) if err != nil { panic(err) } - client.Start() + + status := client.Start(ctx) + logger.Flush() + os.Exit(status) } diff --git a/cmd/dgrep/main.go b/cmd/dgrep/main.go index d1a7d52..74a501f 100644 --- a/cmd/dgrep/main.go +++ b/cmd/dgrep/main.go @@ -1,12 +1,14 @@ package main import ( + "context" "flag" + "os" "github.com/mimecast/dtail/internal/clients" "github.com/mimecast/dtail/internal/color" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/pprof" "github.com/mimecast/dtail/internal/user" "github.com/mimecast/dtail/internal/version" @@ -28,7 +30,6 @@ func main() { var sshPort int var trustAllHosts bool - pingTimeoutS := 60 userName := user.Name() flag.BoolVar(&debugEnable, "debug", false, "Activate debug messages") @@ -38,7 +39,6 @@ func main() { flag.BoolVar(&silentEnable, "silent", false, "Reduce output") flag.BoolVar(&trustAllHosts, "trustAllHosts", false, "Auto trust all unknown host keys") flag.IntVar(&connectionsPerCPU, "cpc", 10, "How many connections established per CPU core concurrently") - flag.IntVar(&pingTimeoutS, "pingTimeout", 10, "The server ping timeout (0 means disable pings)") flag.IntVar(&sshPort, "port", 2222, "SSH server port") flag.StringVar(&cfgFile, "cfg", "", "Config file path") flag.StringVar(&discovery, "discovery", "", "Server discovery method") @@ -56,9 +56,9 @@ func main() { version.PrintAndExit() } + ctx := context.Background() serverEnable := false - logger.Start(serverEnable, debugEnable, silentEnable, silentEnable) - defer logger.Stop() + logger.Start(ctx, serverEnable, debugEnable, silentEnable, silentEnable) if pprofEnable || config.Common.PProfEnable { pprof.Start() @@ -69,9 +69,8 @@ func main() { ServersStr: serversStr, Discovery: discovery, UserName: userName, - Files: files, + What: files, TrustAllHosts: trustAllHosts, - PingTimeout: pingTimeoutS, Regex: regex, } @@ -79,5 +78,8 @@ func main() { if err != nil { panic(err) } - client.Start() + + status := client.Start(ctx) + logger.Flush() + os.Exit(status) } diff --git a/cmd/dmap/main.go b/cmd/dmap/main.go index 83dad50..f3f706a 100644 --- a/cmd/dmap/main.go +++ b/cmd/dmap/main.go @@ -1,13 +1,15 @@ package main import ( + "context" "flag" + "os" - "github.com/mimecast/dtail/internal/omode" "github.com/mimecast/dtail/internal/clients" "github.com/mimecast/dtail/internal/color" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/omode" "github.com/mimecast/dtail/internal/pprof" "github.com/mimecast/dtail/internal/user" "github.com/mimecast/dtail/internal/version" @@ -29,7 +31,6 @@ func main() { var sshPort int var trustAllHosts bool - pingTimeoutS := 900 userName := user.Name() flag.BoolVar(&debugEnable, "debug", false, "Activate debug messages") @@ -39,7 +40,6 @@ func main() { flag.BoolVar(&silentEnable, "silent", false, "Reduce output") flag.BoolVar(&trustAllHosts, "trustAllHosts", false, "Auto trust all unknown host keys") flag.IntVar(&connectionsPerCPU, "cpc", 10, "How many connections established per CPU core concurrently") - flag.IntVar(&pingTimeoutS, "pingTimeout", 10, "The server ping timeout (0 means disable pings)") flag.IntVar(&sshPort, "port", 2222, "SSH server port") flag.StringVar(&cfgFile, "cfg", "", "Config file path") flag.StringVar(&discovery, "discovery", "", "Server discovery method") @@ -57,10 +57,10 @@ func main() { version.PrintAndExit() } + ctx := context.Background() serverEnable := false - logger.Start(serverEnable, debugEnable, silentEnable, silentEnable) - defer logger.Stop() + logger.Start(ctx, serverEnable, debugEnable, silentEnable, silentEnable) if pprofEnable || config.Common.PProfEnable { pprof.Start() } @@ -70,9 +70,8 @@ func main() { ServersStr: serversStr, Discovery: discovery, UserName: userName, - Files: files, + What: files, TrustAllHosts: trustAllHosts, - PingTimeout: pingTimeoutS, Mode: omode.MapClient, } @@ -81,5 +80,7 @@ func main() { panic(err) } - client.Start() + status := client.Start(ctx) + logger.Flush() + os.Exit(status) } diff --git a/cmd/dexec/main.go b/cmd/drun/main.go index 7a7ab1f..b1936d4 100644 --- a/cmd/dexec/main.go +++ b/cmd/drun/main.go @@ -1,12 +1,14 @@ package main import ( + "context" "flag" + "os" "github.com/mimecast/dtail/internal/clients" "github.com/mimecast/dtail/internal/color" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/pprof" "github.com/mimecast/dtail/internal/user" "github.com/mimecast/dtail/internal/version" @@ -15,11 +17,11 @@ import ( // The evil begins here. func main() { var cfgFile string + var command string var connectionsPerCPU int var debugEnable bool var discovery string var displayVersion bool - var command string var noColor bool var pprofEnable bool var serversStr string @@ -27,7 +29,6 @@ func main() { var sshPort int var trustAllHosts bool - pingTimeoutS := 60 userName := user.Name() flag.BoolVar(&debugEnable, "debug", false, "Activate debug messages") @@ -37,11 +38,10 @@ func main() { flag.BoolVar(&silentEnable, "silent", false, "Reduce output") flag.BoolVar(&trustAllHosts, "trustAllHosts", false, "Auto trust all unknown host keys") flag.IntVar(&connectionsPerCPU, "cpc", 10, "How many connections established per CPU core concurrently") - flag.IntVar(&pingTimeoutS, "pingTimeout", 10, "The server ping timeout (0 means disable pings)") flag.IntVar(&sshPort, "port", 2222, "SSH server port") flag.StringVar(&cfgFile, "cfg", "", "Config file path") - flag.StringVar(&discovery, "discovery", "", "Server discovery method") flag.StringVar(&command, "command", "", "Command to run") + flag.StringVar(&discovery, "discovery", "", "Server discovery method") flag.StringVar(&serversStr, "servers", "", "Remote servers to connect") flag.StringVar(&userName, "user", userName, "Your system user name") @@ -54,10 +54,10 @@ func main() { version.PrintAndExit() } + ctx := context.Background() serverEnable := false - logger.Start(serverEnable, debugEnable, silentEnable, silentEnable) - defer logger.Stop() + logger.Start(ctx, serverEnable, debugEnable, silentEnable, silentEnable) if pprofEnable || config.Common.PProfEnable { pprof.Start() } @@ -67,14 +67,16 @@ func main() { ServersStr: serversStr, Discovery: discovery, UserName: userName, - Files: files, + What: command, TrustAllHosts: trustAllHosts, - PingTimeout: pingTimeoutS, } - client, err := clients.NewExecClient(args) + client, err := clients.NewRunClient(args) if err != nil { panic(err) } - client.Start() + + status := client.Start(ctx) + logger.Flush() + os.Exit(status) } diff --git a/cmd/dserver/main.go b/cmd/dserver/main.go index 489910b..aa209a8 100644 --- a/cmd/dserver/main.go +++ b/cmd/dserver/main.go @@ -1,13 +1,14 @@ package main import ( + "context" "flag" "os" "time" "github.com/mimecast/dtail/internal/color" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/pprof" "github.com/mimecast/dtail/internal/server" "github.com/mimecast/dtail/internal/user" @@ -24,7 +25,7 @@ func main() { var shutdownAfter int var sshPort int - userName := user.Name() + user.NoRootCheck() flag.BoolVar(&debugEnable, "debug", false, "Activate debug messages") flag.BoolVar(&displayVersion, "version", false, "Display version") @@ -43,19 +44,23 @@ func main() { version.PrintAndExit() } + ctx := context.Background() + serverEnable := true silentEnable := false nothingEnable := false - logger.Start(serverEnable, debugEnable, silentEnable, nothingEnable) - defer logger.Stop() + logger.Start(ctx, serverEnable, debugEnable, silentEnable, nothingEnable) if shutdownAfter > 0 { go func() { defer os.Exit(1) logger.Info("Enabling auto shutdown timer", shutdownAfter) - time.Sleep(time.Duration(shutdownAfter) * time.Second) - logger.Info("Auto shutdown timer reached, shutting down now") + select { + case <-time.After(time.Duration(shutdownAfter) * time.Second): + logger.Info("Auto shutdown timer reached, shutting down now") + case <-ctx.Done(): + } }() } @@ -63,7 +68,8 @@ func main() { pprof.Start() } - logger.Info("Launching server", version.String(), userName) sshServer := server.New() - sshServer.Start() + status := sshServer.Start(ctx) + logger.Flush() + os.Exit(status) } diff --git a/cmd/dtail/main.go b/cmd/dtail/main.go index 1bf77c7..76070ff 100644 --- a/cmd/dtail/main.go +++ b/cmd/dtail/main.go @@ -1,13 +1,14 @@ package main import ( + "context" "flag" "os" "github.com/mimecast/dtail/internal/clients" "github.com/mimecast/dtail/internal/color" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/omode" "github.com/mimecast/dtail/internal/pprof" "github.com/mimecast/dtail/internal/user" @@ -32,7 +33,6 @@ func main() { var sshPort int var trustAllHosts bool - pingTimeoutS := 5 userName := user.Name() flag.BoolVar(&checkHealth, "checkHealth", false, "Only check for server health") @@ -43,7 +43,6 @@ func main() { flag.BoolVar(&silentEnable, "silent", false, "Reduce output") flag.BoolVar(&trustAllHosts, "trustAllHosts", false, "Auto trust all unknown host keys") flag.IntVar(&connectionsPerCPU, "cpc", 10, "How many connections established per CPU core concurrently") - flag.IntVar(&pingTimeoutS, "pingTimeout", 10, "The server ping timeout (0 means disable pings)") flag.IntVar(&sshPort, "port", 2222, "SSH server port") flag.StringVar(&cfgFile, "cfg", "", "Config file path") flag.StringVar(&discovery, "discovery", "", "Server discovery method") @@ -62,17 +61,18 @@ func main() { version.PrintAndExit() } + ctx := context.Background() + if checkHealth { healthClient, _ := clients.NewHealthClient(omode.HealthClient) - os.Exit(healthClient.Start()) + os.Exit(healthClient.Start(ctx)) } serverEnable := false if checkHealth { silentEnable = true } - logger.Start(serverEnable, debugEnable, silentEnable, silentEnable) - defer logger.Stop() + logger.Start(ctx, serverEnable, debugEnable, silentEnable, silentEnable) if pprofEnable || config.Common.PProfEnable { pprof.Start() @@ -83,9 +83,8 @@ func main() { ServersStr: serversStr, Discovery: discovery, UserName: userName, - Files: files, + What: files, TrustAllHosts: trustAllHosts, - PingTimeout: pingTimeoutS, Regex: regex, Mode: omode.TailClient, } @@ -104,5 +103,7 @@ func main() { } } - client.Start() + status := client.Start(ctx) + logger.Flush() + os.Exit(status) } diff --git a/doc/examples.md b/doc/examples.md index 959105c..964660a 100644 --- a/doc/examples.md +++ b/doc/examples.md @@ -25,7 +25,7 @@ To run ad-hoc mapreduce aggregations on newly written log lines you also must ad --files '/var/log/service/*.log' ``` -In order for mapreduce queries to work you have to make sure that your log format is supported by DTail. You can either use the ones which are already defined in ``mapr/logformat`` or add an extension to support a custom log format. +In order for mapreduce queries to work you have to make sure that your log format is supported by DTail. You can either use the ones which are already defined in ``internal/mapr/logformat`` or add an extension to support a custom log format.  @@ -62,6 +62,6 @@ To run a mapreduce aggregation over logs written in the past the ``dmap`` comman --files "/var/log/service/*.log" ``` -Remember: In order for that to work you have to make sure that your log format is supported by DTail. You can either use the ones which are already defined in ``mapr/logformat`` or add an extension to support a custom log format. +Remember: In order for that to work you have to make sure that your log format is supported by DTail. You can either use the ones which are already defined in ``internal/mapr/logformat`` or add an extension to support a custom log format.  diff --git a/doc/installation.md b/doc/installation.md index 305eae5..a15beb1 100644 --- a/doc/installation.md +++ b/doc/installation.md @@ -3,8 +3,6 @@ DTail Installation Guide The following installation guide has been tested successfully on CentOS 7. You may need to adjust accordingly depending on the distribution you use. -This guide also assumes that you know how to use ``systemd`` and how to configure a service there. If you are unsure please consult the documentation of your distribution. - # Compile it Please check the [Quick Starting Guide](quickstart.md) for instructions how to compile DTail. It is recommended to automate the build process via your build pipeline (e.g. produce a deployable RPM via Jenkins). You don't have to use ``go get...`` to compile and install the binaries. You can also clone the repository and use ``make`` instead. @@ -12,6 +10,7 @@ Please check the [Quick Starting Guide](quickstart.md) for instructions how to c # Install it It is recommended to automate all the installation process outlined here. You could use a configuration management system such as Puppet, Chef or Ansible. However, that relies heavily on how your infrastructure is managed and is out of scope of this documentation. + 1. The ``dserver`` binary has to be installed on all machines (server boxes) involved. A good location for the binary would be ``/usr/local/bin/dserver`` with permissions set as follows: ```console @@ -72,7 +71,7 @@ Now you should be able to use DTail client like outlined in the [Quick Starting # Monitor it -To verify that DTail server is up and running and functioning as expected you should configure the Nagios check [check_dserver.sh](../samples/check_dserver.sh.sample) in your monitoring system. The check has to be executed locally on the server (e.g. via NRPE). How to configure the monitoring system in detail is out of scope of this guide, as it depends on the monitoring infrastructure used. +To verify that DTail server is up and running and functioning as expected you should configure the Nagios check [check_dserver.sh](../samples/check_dserver.sh.sample) in your monitoring system. The check has to be executed locally on the server (e.g. via NRPE). How to configure the monitoring system in detail is out of scope of this guide. ```console % ./check_dserver.sh diff --git a/doc/quickstart.md b/doc/quickstart.md index 46f7fae..7b6fbf4 100644 --- a/doc/quickstart.md +++ b/doc/quickstart.md @@ -3,13 +3,11 @@ Quick Starting Guide This is the quick starting guide. For a more sustainable setup, involving how to create a background service via ``systemd``, recommendations about automation via Jenkins and/or Puppet and health monitoring via Nagios please also follow the [Installation Guide](installation.md). -This guide assumes that you know how to generate and configure a public/private SSH key pair for secure authorization and shell access. That is out of scope of this guide. For more information please have a look at the OpenSSH documentation of your distribution. +This guide assumes that you know how to generate and configure a public/private SSH key pair for secure authorization and shell access. For more information please have a look at the OpenSSH documentation of your distribution. -This guide also assumes that you know how to install and use a Go compiler and GNU make. +# Install it -# Compile it - -To install all DTail binaries from github run: +To compile and install all DTail binaries directly from GitHub run: ```console % go get github.com/mimecast/dtail/cmd/dtail @@ -48,7 +46,9 @@ SERVER|serv-001|INFO|Binding server|0.0.0.0:2222 Make sure that your public SSH key is listed in ``~/.ssh/authorized_keys`` on all server machines involved. The private SSH key counterpart should preferably stay on your Laptop or workstation in ``~/.ssh/id_rsa`` or ``~/.ssh/id_dsa``. -DTail utilises the SSH Agent for SSH authentication. This is to avoid entering the passphrase of the private SSH key over and over again when a new SSH session is initiated from the DTail client to a new DTail server. For this the private SSH key has to be registered at the SSH Agent: +DTail relies on SSH for secure authentication and communication. The clients (all client binaries such as ``dtail``, ``dgrep`` and so on...) communicate with an auth backend via the SSH auth socket. The SSH auth socket is configured via the environment variable ``SSH_AUTH_SOCK`` which usually points to ``~/.ssh/ssh_auth_socket`` or similar (depending on your configuration it may also point to other auth backends such as GPG Agent, in which case ``SSH_AUTH_SOCK`` would point to ``~/.gnupg/S.gpg-agent.ssh`` or similar). + +Usually you would use the SSH Auth Agent. For this the private SSH key has to be registered at the SSH Agent: ```console % ssh-add ~/.ssh/id_rsa @@ -56,16 +56,17 @@ Enter passphrase for ~/.ssh/id_rsa: ********** Identity added: ~/.ssh/id_rsa (~/.ssh/id_rsa) ``` -The DTail client communicates with the SSH Agent through ``~/.ssh/ssh_auth_socket`` whenever a new session to a DTail server is established. - To test whether SSH is setup correctly you should be able to SSH into the servers with the OpenSSH client and your private SSH key through the SSH Agent without entering the private keys passphrase. The following assumes to have an OpenSSH server running on ``serv-001.lan.example.org`` and an OpenSSH client installed on your laptop or workstation. Please notice that DTail does not require to have an OpenSSH infrastructure set up but DTail uses by default the same public/private key file paths as OpenSSH. OpenSSH can be of a great help to verify that the SSH keys are configured correctly: ```console -% ssh serv-001.lan.example.org -% -% exit +workstation01 ~ % ssh serv-001.lan.example.org +serv-001 ~ % +serv-001 ~ % exit +workstation01 ~ % ``` +Please consult the OpenSSH documentation of your distribution if the test above does not work for you. + ## Run DTail client Now it is time to connect to the DTail servers through the DTail client: @@ -5,4 +5,5 @@ go 1.13 require ( github.com/DataDog/zstd v1.4.4 golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876 + golang.org/x/lint v0.0.0-20200130185559-910be7a94367 // indirect ) @@ -1,10 +1,19 @@ github.com/DataDog/zstd v1.4.4 h1:+IawcoXhCBylN7ccwdwf8LOH2jKq7NavGpEPanrlTzE= github.com/DataDog/zstd v1.4.4/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876 h1:sKJQZMuxjOAR/Uo2LBfU90onWEf1dF4C+0hPJCc9Mpc= golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/lint v0.0.0-20200130185559-910be7a94367 h1:0IiAsCRByjO2QjX7ZPkw5oU9x+n1YqRL802rjC0c3Aw= +golang.org/x/lint v0.0.0-20200130185559-910be7a94367/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7 h1:EBZoQjiKKPaLbPrbpssUfuHtwM6KV/vb4U85g/cigFY= +golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/clients/args.go b/internal/clients/args.go index 5fe0a72..dea5a9e 100644 --- a/internal/clients/args.go +++ b/internal/clients/args.go @@ -9,10 +9,9 @@ type Args struct { Mode omode.Mode ServersStr string UserName string - Files string + What string Regex string TrustAllHosts bool Discovery string ConnectionsPerCPU int - PingTimeout int } diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go index 574ae94..b1540ea 100644 --- a/internal/clients/baseclient.go +++ b/internal/clients/baseclient.go @@ -1,13 +1,14 @@ package clients import ( + "context" "regexp" "sync" "time" "github.com/mimecast/dtail/internal/clients/remote" "github.com/mimecast/dtail/internal/discovery" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/omode" "github.com/mimecast/dtail/internal/ssh/client" @@ -27,111 +28,110 @@ type baseClient struct { sshAuthMethods []gossh.AuthMethod // To deal with SSH host keys hostKeyCallback *client.HostKeyCallback - // To stop the client. - stop chan struct{} - // To indicate that the client has stopped. - stopped chan struct{} // Throttle how fast we initiate SSH connections concurrently throttleCh chan struct{} // Retry connection upon failure? retry bool - // Connection helper. - maker connectionMaker + // Connection maker helper. + maker maker } -func (c *baseClient) init(maker connectionMaker) { +func (c *baseClient) init(maker maker) { logger.Info("Initiating base client") c.maker = maker - //c.connections = make(map[string]*remote.Connection) c.sshAuthMethods, c.hostKeyCallback = client.InitSSHAuthMethods(c.TrustAllHosts, c.throttleCh) + discoveryService := discovery.New(c.Discovery, c.ServersStr, discovery.Shuffle) - // Retrieve a shuffled list of remote dtail servers. - shuffleServers := true - discoveryService := discovery.New(c.Discovery, c.ServersStr, shuffleServers) for _, server := range discoveryService.ServerList() { - c.connections = append(c.connections, c.maker.makeConnection(server, c.sshAuthMethods, c.hostKeyCallback)) + c.connections = append(c.connections, c.makeConnection(server, c.sshAuthMethods, c.hostKeyCallback)) } if _, err := regexp.Compile(c.Regex); err != nil { logger.FatalExit(c.Regex, "Can't test compile regex", err) } - // Periodically check for unknown hosts, and ask the user whether to trust them or not. - go c.hostKeyCallback.PromptAddHosts(c.stop) - - // Periodically print out connection stats to the client. c.stats = newTailStats(len(c.connections)) - go c.stats.periodicLogStats(c.throttleCh, c.stop) } -func (c *baseClient) Start() (status int) { +func (c *baseClient) Start(ctx context.Context) (status int) { + // Periodically check for unknown hosts, and ask the user whether to trust them or not. + go c.hostKeyCallback.PromptAddHosts(ctx) + // Periodically print out connection stats to the client. + go c.stats.periodicLogStats(ctx, c.throttleCh) + // Keep count of active connections active := make(chan struct{}, len(c.connections)) - var wg sync.WaitGroup - wg.Add(len(c.connections)) - + var mutex sync.Mutex for i, conn := range c.connections { go func(i int, conn *remote.Connection) { - active <- struct{}{} - defer func() { - logger.Debug(conn.Server, "Disconnected completely...") - <-active - }() - wg.Done() - - for { - conn.Start(c.throttleCh, c.stats.connectionsEstCh) - if !c.retry { - return - } - time.Sleep(time.Second * 2) - logger.Debug(conn.Server, "Reconencting") - conn = c.maker.makeConnection(conn.Server, c.sshAuthMethods, c.hostKeyCallback) - c.connections[i] = conn + connStatus := c.start(ctx, active, i, conn) + + // Update global status. + mutex.Lock() + defer mutex.Unlock() + if connStatus > status { + status = connStatus } }(i, conn) } - wg.Wait() - c.waitUntilDone(active) - + c.waitUntilDone(ctx, active) return } -func (c *baseClient) waitUntilDone(active chan struct{}) { - defer close(c.stopped) +func (c *baseClient) start(ctx context.Context, active chan struct{}, i int, conn *remote.Connection) (status int) { + // Increment connection count + active <- struct{}{} + // Derement connection count + defer func() { <-active }() - if c.Mode != omode.TailClient { - c.waitUntilZero(active) - logger.Info("All connections stopped") - return - } + for { + connCtx, cancel := conn.Handler.WithCancel(ctx) + defer cancel() - <-c.stop - logger.Info("Stopping client") - for _, conn := range c.connections { - conn.Stop() + conn.Start(connCtx, cancel, c.throttleCh, c.stats.connectionsEstCh) + // Retrieve status code from handler (dtail client will exit with that status) + status = conn.Handler.Status() + + if !c.retry { + return + } + + time.Sleep(time.Second * 2) + logger.Debug(conn.Server, "Reconnecting") + + conn = c.makeConnection(conn.Server, c.sshAuthMethods, c.hostKeyCallback) + c.connections[i] = conn } +} - c.waitUntilZero(active) +func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection { + conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) + conn.Handler = c.maker.makeHandler(server) + conn.Commands = c.maker.makeCommands() + + return conn } -func (c *baseClient) waitUntilZero(active chan struct{}) { +func (c *baseClient) waitUntilDone(ctx context.Context, active chan struct{}) { + defer logger.Info("Terminated connection") + + // We want to have at least one active connection + <-active + // Put it back on the channel + active <- struct{}{} + + if c.Mode == omode.TailClient { + <-ctx.Done() + } + for { - logger.Debug("Active connections", len(active)) - if len(active) == 0 { + numActive := len(active) + if numActive == 0 { return } + logger.Debug("Active connections", numActive) time.Sleep(time.Second) } } - -func (c *baseClient) Stop() { - close(c.stop) - <-c.WaitC() -} - -func (c *baseClient) WaitC() <-chan struct{} { - return c.stopped -} diff --git a/internal/clients/catclient.go b/internal/clients/catclient.go index 5ea701d..7fd6bdc 100644 --- a/internal/clients/catclient.go +++ b/internal/clients/catclient.go @@ -7,11 +7,7 @@ import ( "strings" "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/clients/remote" "github.com/mimecast/dtail/internal/omode" - "github.com/mimecast/dtail/internal/ssh/client" - - gossh "golang.org/x/crypto/ssh" ) // CatClient is a client for returning a whole file from the beginning to the end. @@ -31,8 +27,6 @@ func NewCatClient(args Args) (*CatClient, error) { c := CatClient{ baseClient: baseClient{ Args: args, - stop: make(chan struct{}), - stopped: make(chan struct{}), throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), retry: false, }, @@ -43,11 +37,13 @@ func NewCatClient(args Args) (*CatClient, error) { return &c, nil } -func (c CatClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection { - conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) - conn.Handler = handlers.NewClientHandler(server, c.PingTimeout) - for _, file := range strings.Split(c.Files, ",") { - conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex)) +func (c CatClient) makeHandler(server string) handlers.Handler { + return handlers.NewClientHandler(server) +} + +func (c CatClient) makeCommands() (commands []string) { + for _, file := range strings.Split(c.What, ",") { + commands = append(commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex)) } - return conn + return } diff --git a/internal/clients/client.go b/internal/clients/client.go index 85d1aae..1fc5e23 100644 --- a/internal/clients/client.go +++ b/internal/clients/client.go @@ -1,7 +1,8 @@ package clients +import "context" + // Client is the interface for the end user command line client. type Client interface { - Start() int - Stop() + Start(ctx context.Context) int } diff --git a/internal/clients/connectionmaker.go b/internal/clients/connectionmaker.go deleted file mode 100644 index 0617992..0000000 --- a/internal/clients/connectionmaker.go +++ /dev/null @@ -1,12 +0,0 @@ -package clients - -import ( - "github.com/mimecast/dtail/internal/clients/remote" - "github.com/mimecast/dtail/internal/ssh/client" - - gossh "golang.org/x/crypto/ssh" -) - -type connectionMaker interface { - makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection -} diff --git a/internal/clients/execclient.go b/internal/clients/execclient.go deleted file mode 100644 index 10bd081..0000000 --- a/internal/clients/execclient.go +++ /dev/null @@ -1,48 +0,0 @@ -package clients - -import ( - "fmt" - "runtime" - "strings" - - "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/clients/remote" - "github.com/mimecast/dtail/internal/omode" - "github.com/mimecast/dtail/internal/ssh/client" - - gossh "golang.org/x/crypto/ssh" -) - -// ExecClient is a client for execute various commands on the server. -type ExecClient struct { - baseClient -} - -// NewExecClient returns a new cat client. -func NewExecClient(args Args) (*ExecClient, error) { - args.Regex = "." - args.Mode = omode.ExecClient - - c := ExecClient{ - baseClient: baseClient{ - Args: args, - stop: make(chan struct{}), - stopped: make(chan struct{}), - throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), - retry: false, - }, - } - - c.init(c) - - return &c, nil -} - -func (c ExecClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection { - conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) - conn.Handler = handlers.NewClientHandler(server, c.PingTimeout) - for _, file := range strings.Split(c.Files, ";") { - conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s", c.Mode.String(), file)) - } - return conn -} diff --git a/internal/clients/grepclient.go b/internal/clients/grepclient.go index c568f63..8d11458 100644 --- a/internal/clients/grepclient.go +++ b/internal/clients/grepclient.go @@ -7,11 +7,7 @@ import ( "strings" "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/clients/remote" "github.com/mimecast/dtail/internal/omode" - "github.com/mimecast/dtail/internal/ssh/client" - - gossh "golang.org/x/crypto/ssh" ) // GrepClient searches a remote file for all lines matching a regular expression. Only the matching lines are displayed. @@ -29,8 +25,6 @@ func NewGrepClient(args Args) (*GrepClient, error) { c := GrepClient{ baseClient: baseClient{ Args: args, - stop: make(chan struct{}), - stopped: make(chan struct{}), throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), retry: false, }, @@ -41,13 +35,13 @@ func NewGrepClient(args Args) (*GrepClient, error) { return &c, nil } -func (c GrepClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection { - conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) - conn.Handler = handlers.NewClientHandler(server, c.PingTimeout) +func (c GrepClient) makeHandler(server string) handlers.Handler { + return handlers.NewClientHandler(server) +} - for _, file := range strings.Split(c.Files, ",") { - conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex)) +func (c GrepClient) makeCommands() (commands []string) { + for _, file := range strings.Split(c.What, ",") { + commands = append(commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex)) } - - return conn + return } diff --git a/internal/clients/handlers/basehandler.go b/internal/clients/handlers/basehandler.go index 19246f9..68b8ddc 100644 --- a/internal/clients/handlers/basehandler.go +++ b/internal/clients/handlers/basehandler.go @@ -1,60 +1,44 @@ package handlers import ( - "github.com/mimecast/dtail/internal/logger" - "errors" + "encoding/base64" "fmt" "io" + "strconv" "strings" "time" + + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/version" ) type baseHandler struct { + withCancel server string shellStarted bool commands chan string - pong chan struct{} receiveBuf []byte - stop chan struct{} - pingTimeout int + status int } func (h *baseHandler) Server() string { return h.server } -// Used to determine whether server is still responding to requests or not. -func (h *baseHandler) Ping() error { - if h.pingTimeout == 0 { - // Server ping disabled - return nil - } - - if err := h.SendCommand("ping"); err != nil { - return err - } - - select { - case <-h.pong: - return nil - case <-time.After(time.Duration(h.pingTimeout) * time.Second): - } - - return errors.New("Didn't receive any server pongs (ping replies)") +func (h *baseHandler) Status() int { + return h.status } -func (h *baseHandler) SendCommand(command string) error { - if command == "ping" { - logger.Trace("Sending command", h.server, command) - } else { - logger.Debug("Sending command", h.server, command) - } +// SendMessage to the server. +func (h *baseHandler) SendMessage(command string) error { + encoded := base64.StdEncoding.EncodeToString([]byte(command)) + logger.Debug("Sending command", h.server, command, encoded) select { - case h.commands <- fmt.Sprintf("%s;", command): + case h.commands <- fmt.Sprintf("protocol %s base64 %v;", version.ProtocolCompat, encoded): case <-time.After(time.Second * 5): - return errors.New("Timed out sending command " + command) - case <-h.stop: + return fmt.Errorf("Timed out sending command '%s' (base64: '%s')", command, encoded) + case <-h.ctx.Done(): } return nil @@ -81,7 +65,7 @@ func (h *baseHandler) Read(p []byte) (n int, err error) { select { case command := <-h.commands: n = copy(p, []byte(command)) - case <-h.stop: + case <-h.ctx.Done(): return 0, io.EOF } return @@ -92,6 +76,7 @@ func (h *baseHandler) handleMessageType(message string) { if len(h.receiveBuf) == 0 { return } + // Hidden server commands starti with a dot "." if h.receiveBuf[0] == '.' { h.handleHiddenMessage(message) @@ -108,6 +93,7 @@ func (h *baseHandler) handleMessageType(message string) { h.receiveBuf = h.receiveBuf[:0] return } + logger.Raw(message) h.receiveBuf = h.receiveBuf[:0] } @@ -116,19 +102,27 @@ func (h *baseHandler) handleMessageType(message string) { // to the end user. func (h *baseHandler) handleHiddenMessage(message string) { switch { - case strings.HasPrefix(message, ".pong"): - h.pong <- struct{}{} case strings.HasPrefix(message, ".syn close connection"): - h.SendCommand("ack close connection") - } -} + h.SendMessage(".ack close connection") + select { + case <-time.After(time.Second * 1): + logger.Debug("Shutting down client after timeout and sending ack to server") + h.withCancel.shutdown() + case <-h.ctx.Done(): + } -// Stop the handler. -func (h *baseHandler) Stop() { - select { - case <-h.stop: - default: - logger.Debug("Stopping base handler", h.server) - close(h.stop) + case strings.HasPrefix(message, ".run exitstatus"): + splitted := strings.Split(strings.TrimSuffix(message, "\n"), " ") + if len(splitted) != 3 { + logger.Error("Unable to retrieve exitstatus", message) + return + } + i, err := strconv.Atoi(splitted[2]) + if err != nil { + logger.Error("Unable to retrieve exitstatus", message, err) + return + } + h.status = i + logger.Debug("Retrieved exitstatus", h.status) } } diff --git a/internal/clients/handlers/clienthandler.go b/internal/clients/handlers/clienthandler.go index 4738cd3..fcd8052 100644 --- a/internal/clients/handlers/clienthandler.go +++ b/internal/clients/handlers/clienthandler.go @@ -1,7 +1,7 @@ package handlers import ( - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" ) // ClientHandler is the basic client handler interface. @@ -10,7 +10,7 @@ type ClientHandler struct { } // NewClientHandler creates a new client handler. -func NewClientHandler(server string, pingTimeout int) *ClientHandler { +func NewClientHandler(server string) *ClientHandler { logger.Debug(server, "Creating new client handler") return &ClientHandler{ @@ -18,9 +18,10 @@ func NewClientHandler(server string, pingTimeout int) *ClientHandler { server: server, shellStarted: false, commands: make(chan string), - pong: make(chan struct{}, 1), - stop: make(chan struct{}), - pingTimeout: pingTimeout, + status: -1, + withCancel: withCancel{ + done: make(chan struct{}), + }, }, } } diff --git a/internal/clients/handlers/handler.go b/internal/clients/handlers/handler.go index 2013be0..c53ca34 100644 --- a/internal/clients/handlers/handler.go +++ b/internal/clients/handlers/handler.go @@ -1,12 +1,16 @@ package handlers -import "io" +import ( + "context" + "io" +) // Handler provides all methods which can be run on any client handler. type Handler interface { io.ReadWriter - Ping() error - Stop() - SendCommand(command string) error + SendMessage(command string) error Server() string + Status() int + WithCancel(ctx context.Context) (context.Context, context.CancelFunc) + Done() <-chan struct{} } diff --git a/internal/clients/handlers/healthhandler.go b/internal/clients/handlers/healthhandler.go index 4051e2c..9051015 100644 --- a/internal/clients/handlers/healthhandler.go +++ b/internal/clients/handlers/healthhandler.go @@ -8,6 +8,7 @@ import ( // HealthHandler implements the handler required for health checks. type HealthHandler struct { + withCancel // Buffer of incoming data from server. receiveBuf []byte // To send commands to the server. @@ -16,6 +17,7 @@ type HealthHandler struct { receive chan<- string // The remote server address server string + status int } // NewHealthHandler returns a new health check handler. @@ -24,6 +26,10 @@ func NewHealthHandler(server string, receive chan<- string) *HealthHandler { server: server, receive: receive, commands: make(chan string), + status: -1, + withCancel: withCancel{ + done: make(chan struct{}), + }, } return &h @@ -34,18 +40,13 @@ func (h *HealthHandler) Server() string { return h.server } -// Stop is not of use for health check handler. -func (h *HealthHandler) Stop() { - // Nothing done here. +// Status of the handler. +func (h *HealthHandler) Status() int { + return h.status } -// Ping is not of use for health check handler. -func (h *HealthHandler) Ping() error { - return nil -} - -// SendCommand send a DTail command to the server. -func (h *HealthHandler) SendCommand(command string) error { +// SendMessage sends a DTail command to the server. +func (h *HealthHandler) SendMessage(command string) error { select { case h.commands <- fmt.Sprintf("%s;", command): case <-time.NewTimer(time.Second * 10).C: diff --git a/internal/clients/handlers/maprhandler.go b/internal/clients/handlers/maprhandler.go index d76cdfd..874bb7d 100644 --- a/internal/clients/handlers/maprhandler.go +++ b/internal/clients/handlers/maprhandler.go @@ -1,10 +1,11 @@ package handlers import ( - "github.com/mimecast/dtail/internal/logger" + "strings" + + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/mapr" "github.com/mimecast/dtail/internal/mapr/client" - "strings" ) // MaprHandler is the handler used on the client side for running mapreduce aggregations. @@ -16,15 +17,16 @@ type MaprHandler struct { } // NewMaprHandler returns a new mapreduce client handler. -func NewMaprHandler(server string, query *mapr.Query, globalGroup *mapr.GlobalGroupSet, pingTimeout int) *MaprHandler { +func NewMaprHandler(server string, query *mapr.Query, globalGroup *mapr.GlobalGroupSet) *MaprHandler { return &MaprHandler{ baseHandler: baseHandler{ server: server, shellStarted: false, commands: make(chan string), - pong: make(chan struct{}, 1), - stop: make(chan struct{}), - pingTimeout: pingTimeout, + status: -1, + withCancel: withCancel{ + done: make(chan struct{}), + }, }, query: query, aggregate: client.NewAggregate(server, query, globalGroup), @@ -65,10 +67,3 @@ func (h *MaprHandler) handleAggregateMessage(message string) { h.aggregate.Aggregate(parts[2:]) logger.Debug("Aggregated aggregate data", h.server, h.count) } - -// Stop stops the mapreduce client handler. -func (h *MaprHandler) Stop() { - logger.Debug("Stopping mapreduce handler", h.server) - h.aggregate.Stop() - h.baseHandler.Stop() -} diff --git a/internal/clients/handlers/withcancel.go b/internal/clients/handlers/withcancel.go new file mode 100644 index 0000000..7c9cf4e --- /dev/null +++ b/internal/clients/handlers/withcancel.go @@ -0,0 +1,24 @@ +package handlers + +import "context" + +type withCancel struct { + ctx context.Context + done chan struct{} +} + +// WithCancel sets and returns the context used. +func (w *withCancel) WithCancel(ctx context.Context) (context.Context, context.CancelFunc) { + cancelCtx, cancel := context.WithCancel(ctx) + w.ctx = cancelCtx + + return cancelCtx, cancel +} + +func (w *withCancel) Done() <-chan struct{} { + return w.done +} + +func (w *withCancel) shutdown() { + close(w.done) +} diff --git a/internal/clients/healthclient.go b/internal/clients/healthclient.go index ff13b83..7313583 100644 --- a/internal/clients/healthclient.go +++ b/internal/clients/healthclient.go @@ -1,6 +1,7 @@ package clients import ( + "context" "fmt" "runtime" "strings" @@ -39,7 +40,7 @@ func NewHealthClient(mode omode.Mode) (*HealthClient, error) { } // Start the health client. -func (c *HealthClient) Start() (status int) { +func (c *HealthClient) Start(ctx context.Context) (status int) { receive := make(chan string) throttleCh := make(chan struct{}, runtime.NumCPU()) @@ -49,8 +50,8 @@ func (c *HealthClient) Start() (status int) { conn.Handler = handlers.NewHealthHandler(c.server, receive) conn.Commands = []string{c.mode.String()} - go conn.Start(throttleCh, statsCh) - defer conn.Stop() + connCtx, cancel := conn.Handler.WithCancel(ctx) + go conn.Start(connCtx, cancel, throttleCh, statsCh) for { select { diff --git a/internal/clients/maker.go b/internal/clients/maker.go new file mode 100644 index 0000000..da9dfc9 --- /dev/null +++ b/internal/clients/maker.go @@ -0,0 +1,8 @@ +package clients + +import "github.com/mimecast/dtail/internal/clients/handlers" + +type maker interface { + makeHandler(server string) handlers.Handler + makeCommands() (commands []string) +} diff --git a/internal/clients/maprclient.go b/internal/clients/maprclient.go index 9070827..b581844 100644 --- a/internal/clients/maprclient.go +++ b/internal/clients/maprclient.go @@ -1,6 +1,7 @@ package clients import ( + "context" "errors" "fmt" "runtime" @@ -8,13 +9,9 @@ import ( "time" "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/clients/remote" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/mapr" "github.com/mimecast/dtail/internal/omode" - "github.com/mimecast/dtail/internal/ssh/client" - - gossh "golang.org/x/crypto/ssh" ) // MaprClient is used for running mapreduce aggregations on remote files. @@ -39,8 +36,6 @@ func NewMaprClient(args Args, queryStr string) (*MaprClient, error) { c := MaprClient{ baseClient: baseClient{ Args: args, - stop: make(chan struct{}), - stopped: make(chan struct{}), throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), retry: args.Mode == omode.TailClient, }, @@ -70,35 +65,36 @@ func NewMaprClient(args Args, queryStr string) (*MaprClient, error) { return &c, nil } -func (c MaprClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection { - conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) - conn.Handler = handlers.NewMaprHandler(conn.Server, c.query, c.globalGroup, c.PingTimeout) +// Start starts the mapreduce client. +func (c *MaprClient) Start(ctx context.Context) (status int) { + if c.query.Outfile == "" { + // Only print out periodic results if we don't write an outfile + go c.periodicPrintResults(ctx) + } - conn.Commands = append(conn.Commands, fmt.Sprintf("map %s", c.query.RawQuery)) - commandStr := "tail" + status = c.baseClient.Start(ctx) if c.additative { - commandStr = "cat" + c.recievedFinalResult() } - for _, file := range strings.Split(c.Files, ",") { - conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", commandStr, file, c.Regex)) - } + return +} - return conn +func (c MaprClient) makeHandler(server string) handlers.Handler { + return handlers.NewMaprHandler(server, c.query, c.globalGroup) } -// Start starts the mapreduce client. -func (c *MaprClient) Start() (status int) { - if c.query.Outfile == "" { - // Only print out periodic results if we don't write an outfile - go c.periodicPrintResults() - } +func (c MaprClient) makeCommands() (commands []string) { + commands = append(commands, fmt.Sprintf("map %s", c.query.RawQuery)) - status = c.baseClient.Start() + modeStr := "tail" if c.additative { - c.recievedFinalResult() + modeStr = "cat" + } + + for _, file := range strings.Split(c.What, ",") { + commands = append(commands, fmt.Sprintf("%s %s regex %s", modeStr, file, c.Regex)) } - c.baseClient.Stop() return } @@ -120,13 +116,13 @@ func (c *MaprClient) recievedFinalResult() { logger.Info(fmt.Sprintf("Wrote final mapreduce result to '%s'", c.query.Outfile)) } -func (c *MaprClient) periodicPrintResults() { +func (c *MaprClient) periodicPrintResults(ctx context.Context) { for { select { case <-time.After(c.query.Interval): logger.Info("Gathering interim mapreduce result") c.printResults() - case <-c.baseClient.stop: + case <-ctx.Done(): return } } diff --git a/internal/clients/remote/connection.go b/internal/clients/remote/connection.go index bfc7bc5..71639b1 100644 --- a/internal/clients/remote/connection.go +++ b/internal/clients/remote/connection.go @@ -1,16 +1,18 @@ package remote import ( - "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" - "github.com/mimecast/dtail/internal/ssh/client" + "context" "fmt" "io" "strconv" "strings" "time" + "github.com/mimecast/dtail/internal/clients/handlers" + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/ssh/client" + "golang.org/x/crypto/ssh" ) @@ -30,8 +32,6 @@ type Connection struct { Commands []string // Is it a persistent connection or a one-off? isOneOff bool - // Used to stop the connection - stop chan struct{} // To deal with SSH server host keys hostKeyCallback *client.HostKeyCallback } @@ -48,7 +48,6 @@ func NewConnection(server string, userName string, authMethods []ssh.AuthMethod, HostKeyCallback: hostKeyCallback.Wrap(), Timeout: time.Second * 3, }, - stop: make(chan struct{}), } c.initServerPort(server) @@ -64,7 +63,6 @@ func NewOneOffConnection(server string, userName string, authMethods []ssh.AuthM Auth: authMethods, HostKeyCallback: ssh.InsecureIgnoreHostKey(), }, - stop: make(chan struct{}), isOneOff: true, } @@ -90,39 +88,34 @@ func (c *Connection) initServerPort(server string) { } } -// Start the server connection. Build up SSH session and send some DTail commandc. -func (c *Connection) Start(throttleCh, statsCh chan struct{}) { +// Start the server connection. Build up SSH session and send some DTail commands. +func (c *Connection) Start(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) { + // Throttle how many connections can be established concurrently (based on ch length) select { - case <-c.stop: - logger.Info(c.Server, c.port, "Disconnecting client") + case throttleCh <- struct{}{}: + defer func() { <-throttleCh }() + case <-ctx.Done(): return - default: } - // Wait for SSH connection throttler - throttleCh <- struct{}{} - - // Wait until connection has been initiated or an error occured - // during initialization. - throttleStopCh := make(chan struct{}, 2) go func() { - <-throttleStopCh - <-throttleCh - }() + defer cancel() - if err := c.dial(c.Server, c.port, throttleStopCh, statsCh); err != nil { - logger.Warn(c.Server, c.port, err) - throttleStopCh <- struct{}{} + if err := c.dial(ctx, cancel, c.Server, c.port, statsCh); err != nil { + logger.Warn(c.Server, c.port, err) - if c.hostKeyCallback.Untrusted(fmt.Sprintf("%s:%d", c.Server, c.port)) { - logger.Debug("Not trusting host, not trying to re-connect", c.Server, c.port) - return + if c.hostKeyCallback.Untrusted(fmt.Sprintf("%s:%d", c.Server, c.port)) { + logger.Debug("Not trusting host", c.Server, c.port) + return + } } - } + }() + + <-ctx.Done() } // Dail into a new SSH connection. Close connection in case of an error. -func (c *Connection) dial(host string, port int, throttleStopCh, statsCh chan struct{}) error { +func (c *Connection) dial(ctx context.Context, cancel context.CancelFunc, host string, port int, statsCh chan struct{}) error { statsCh <- struct{}{} defer func() { <-statsCh }() @@ -135,11 +128,11 @@ func (c *Connection) dial(host string, port int, throttleStopCh, statsCh chan st } defer client.Close() - return c.session(client, throttleStopCh) + return c.session(ctx, cancel, client) } // Create the SSH session. Close the session in case of an error. -func (c *Connection) session(client *ssh.Client, throttleStopCh chan<- struct{}) error { +func (c *Connection) session(ctx context.Context, cancel context.CancelFunc, client *ssh.Client) error { logger.Debug(c.Server, "session") session, err := client.NewSession() @@ -148,14 +141,10 @@ func (c *Connection) session(client *ssh.Client, throttleStopCh chan<- struct{}) } defer session.Close() - return c.handle(session, throttleStopCh) + return c.handle(ctx, cancel, session) } -// Handle the SSH session. Also send periodic pings to the server in order -// to determine that session is still intact. -func (c *Connection) handle(session *ssh.Session, throttleStopCh chan<- struct{}) error { - defer c.Handler.Stop() - +func (c *Connection) handle(ctx context.Context, cancel context.CancelFunc, session *ssh.Session) error { logger.Debug(c.Server, "handle") stdinPipe, err := session.StdinPipe() @@ -172,59 +161,30 @@ func (c *Connection) handle(session *ssh.Session, throttleStopCh chan<- struct{} return err } - // Establish Bi-directional pipe between SSH session and client handler. - brokenStdinPipe := make(chan struct{}) go func() { - defer close(brokenStdinPipe) + defer cancel() io.Copy(stdinPipe, c.Handler) }() - brokenStdoutPipe := make(chan struct{}) go func() { - defer close(brokenStdoutPipe) + defer cancel() io.Copy(c.Handler, stdoutPipe) }() - // SSH session established, other goroutine can initiate session now. - throttleStopCh <- struct{}{} + go func() { + defer cancel() + select { + case <-c.Handler.Done(): + case <-ctx.Done(): + } + }() // Send all commands to client. for _, command := range c.Commands { logger.Debug(command) - c.Handler.SendCommand(command) + c.Handler.SendMessage(command) } - if !c.isOneOff { - return c.periodicAliveCheck(brokenStdinPipe, brokenStdoutPipe) - } - - <-c.stop - - // Normal shutdown, all fine + <-ctx.Done() return nil } - -// Periodically check whether connection is still alive or not. -func (c *Connection) periodicAliveCheck(brokenStdinPipe, brokenStdoutPipe <-chan struct{}) error { - for { - select { - case <-time.After(time.Second * 3): - if err := c.Handler.Ping(); err != nil { - return err - } - case <-brokenStdinPipe: - logger.Debug("Broken stdin pipe", c.Server, c.port) - return nil - case <-brokenStdoutPipe: - logger.Debug("Broken stdout pipe", c.Server, c.port) - return nil - case <-c.stop: - return nil - } - } -} - -// Stop the connection. -func (c *Connection) Stop() { - close(c.stop) -} diff --git a/internal/clients/runclient.go b/internal/clients/runclient.go new file mode 100644 index 0000000..7a62fcc --- /dev/null +++ b/internal/clients/runclient.go @@ -0,0 +1,40 @@ +package clients + +import ( + "fmt" + "runtime" + + "github.com/mimecast/dtail/internal/clients/handlers" + "github.com/mimecast/dtail/internal/omode" +) + +// RunClient is a client to run various commands on the server. +type RunClient struct { + baseClient +} + +// NewRunClient returns a new cat client. +func NewRunClient(args Args) (*RunClient, error) { + args.Mode = omode.RunClient + + c := RunClient{ + baseClient: baseClient{ + Args: args, + throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), + retry: false, + }, + } + + c.init(c) + return &c, nil +} + +func (c RunClient) makeHandler(server string) handlers.Handler { + return handlers.NewClientHandler(server) +} + +func (c RunClient) makeCommands() (commands []string) { + // Send "run COMMAND" to server! + commands = append(commands, fmt.Sprintf("%s %s", c.Mode.String(), c.What)) + return +} diff --git a/internal/clients/stats.go b/internal/clients/stats.go index d36cef6..ec6adfe 100644 --- a/internal/clients/stats.go +++ b/internal/clients/stats.go @@ -1,11 +1,13 @@ package clients import ( - "github.com/mimecast/dtail/internal/logger" + "context" "fmt" "runtime" "sync" "time" + + "github.com/mimecast/dtail/internal/io/logger" ) // Used to collect and display various client stats. @@ -28,14 +30,14 @@ func newTailStats(connectionsTotal int) *stats { } } -func (s *stats) periodicLogStats(throttleCh chan struct{}, stop <-chan struct{}) { +func (s *stats) periodicLogStats(ctx context.Context, throttleCh chan struct{}) { connectedLast := 0 statsInterval := 5 for { select { case <-time.After(time.Second * time.Duration(statsInterval)): - case <-stop: + case <-ctx.Done(): return } diff --git a/internal/clients/tailclient.go b/internal/clients/tailclient.go index 674ca36..4d81fd5 100644 --- a/internal/clients/tailclient.go +++ b/internal/clients/tailclient.go @@ -6,11 +6,7 @@ import ( "strings" "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/clients/remote" "github.com/mimecast/dtail/internal/omode" - "github.com/mimecast/dtail/internal/ssh/client" - - gossh "golang.org/x/crypto/ssh" ) // TailClient is used for tailing remote log files (opening, seeking to the end and returning only new incoming lines). @@ -25,25 +21,22 @@ func NewTailClient(args Args) (*TailClient, error) { c := TailClient{ baseClient: baseClient{ Args: args, - stop: make(chan struct{}), - stopped: make(chan struct{}), throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), retry: true, }, } c.init(c) - return &c, nil } -func (c TailClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection { - conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) - conn.Handler = handlers.NewClientHandler(server, c.PingTimeout) +func (c TailClient) makeHandler(server string) handlers.Handler { + return handlers.NewClientHandler(server) +} - for _, file := range strings.Split(c.Files, ",") { - conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex)) +func (c TailClient) makeCommands() (commands []string) { + for _, file := range strings.Split(c.What, ",") { + commands = append(commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex)) } - - return conn + return } diff --git a/internal/discovery/comma.go b/internal/discovery/comma.go index ad18be0..94276c7 100644 --- a/internal/discovery/comma.go +++ b/internal/discovery/comma.go @@ -1,7 +1,7 @@ package discovery import ( - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "strings" ) diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go index d76c1b2..1090ea9 100644 --- a/internal/discovery/discovery.go +++ b/internal/discovery/discovery.go @@ -1,7 +1,6 @@ package discovery import ( - "github.com/mimecast/dtail/internal/logger" "fmt" "math/rand" "os" @@ -9,6 +8,16 @@ import ( "regexp" "strings" "time" + + "github.com/mimecast/dtail/internal/io/logger" +) + +// ServerOrder to specify how to sort the server list. +type ServerOrder int + +const ( + // Shuffle the server list? + Shuffle ServerOrder = iota ) // Discovery method for discovering a list of available DTail servers. @@ -21,12 +30,12 @@ type Discovery struct { server string // To filter server list. regex *regexp.Regexp - // To shuffle resulting server list. - shuffle bool + // How to order the server list. + order ServerOrder } // New returns a new discovery method. -func New(method, server string, shuffle bool) *Discovery { +func New(method, server string, order ServerOrder) *Discovery { module := method options := "" @@ -43,7 +52,7 @@ func New(method, server string, shuffle bool) *Discovery { module: strings.ToUpper(module), options: options, server: server, - shuffle: shuffle, + order: order, } if strings.HasPrefix(server, "/") && strings.HasSuffix(server, "/") { @@ -84,7 +93,7 @@ func (d *Discovery) ServerList() []string { servers = d.dedupList(servers) - if d.shuffle { + if d.order == Shuffle { servers = d.shuffleList(servers) } diff --git a/internal/discovery/file.go b/internal/discovery/file.go index 2edc867..c04173e 100644 --- a/internal/discovery/file.go +++ b/internal/discovery/file.go @@ -2,7 +2,7 @@ package discovery import ( "bufio" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "os" ) diff --git a/internal/fs/catfile.go b/internal/io/fs/catfile.go index 99f521f..7f387bc 100644 --- a/internal/fs/catfile.go +++ b/internal/io/fs/catfile.go @@ -1,7 +1,5 @@ package fs -import "sync" - // CatFile is for reading a whole file. type CatFile struct { readFile @@ -9,19 +7,15 @@ type CatFile struct { // NewCatFile returns a new file catter. func NewCatFile(filePath string, globID string, serverMessages chan<- string, limiter chan struct{}) CatFile { - var mutex sync.Mutex - return CatFile{ readFile: readFile{ filePath: filePath, - stop: make(chan struct{}), globID: globID, serverMessages: serverMessages, retry: false, canSkipLines: false, seekEOF: false, limiter: limiter, - mutex: &mutex, }, } } diff --git a/internal/fs/filereader.go b/internal/io/fs/filereader.go index 5a08e27..05e58a1 100644 --- a/internal/fs/filereader.go +++ b/internal/io/fs/filereader.go @@ -1,9 +1,14 @@ package fs +import ( + "context" + + "github.com/mimecast/dtail/internal/io/line" +) + // FileReader is the interface used on the dtail server to read/cat/grep/mapr... a file. type FileReader interface { - Start(lines chan<- LineRead, regex string) error + Start(ctx context.Context, lines chan<- line.Line, regex string) error FilePath() string Retry() bool - Stop() } diff --git a/internal/fs/permissions/permission.go b/internal/io/fs/permissions/permission.go index 6e83309..0ed4f17 100644 --- a/internal/fs/permissions/permission.go +++ b/internal/io/fs/permissions/permission.go @@ -3,7 +3,7 @@ package permissions import ( - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" ) // ToRead is to check whether user has read permissions to a given file. diff --git a/internal/fs/permissions/permission_linux.c b/internal/io/fs/permissions/permission_linux.c index cd10525..cd10525 100644 --- a/internal/fs/permissions/permission_linux.c +++ b/internal/io/fs/permissions/permission_linux.c diff --git a/internal/fs/permissions/permission_linux.go b/internal/io/fs/permissions/permission_linux.go index feae729..feae729 100644 --- a/internal/fs/permissions/permission_linux.go +++ b/internal/io/fs/permissions/permission_linux.go diff --git a/internal/fs/permissions/permission_linux.h b/internal/io/fs/permissions/permission_linux.h index a2c266e..a2c266e 100644 --- a/internal/fs/permissions/permission_linux.h +++ b/internal/io/fs/permissions/permission_linux.h diff --git a/internal/fs/permissions/permission_test.go b/internal/io/fs/permissions/permission_test.go index d415ac2..d415ac2 100644 --- a/internal/fs/permissions/permission_test.go +++ b/internal/io/fs/permissions/permission_test.go diff --git a/internal/fs/readfile.go b/internal/io/fs/readfile.go index 312447a..321432e 100644 --- a/internal/fs/readfile.go +++ b/internal/io/fs/readfile.go @@ -3,7 +3,7 @@ package fs import ( "bufio" "compress/gzip" - "github.com/mimecast/dtail/internal/logger" + "context" "errors" "io" "os" @@ -12,6 +12,9 @@ import ( "sync" "time" + "github.com/mimecast/dtail/internal/io/line" + "github.com/mimecast/dtail/internal/io/logger" + "github.com/DataDog/zstd" ) @@ -27,16 +30,12 @@ type readFile struct { globID string // Channel to send a server message to the dtail client serverMessages chan<- string - // Signals to stop tailing the log file. - stop chan struct{} // Periodically retry reading file. retry bool // Can I skip messages when there are too many? canSkipLines bool // Seek to the EOF before processing file? seekEOF bool - // Mutex to control the stopping of the file - mutex *sync.Mutex limiter chan struct{} } @@ -51,7 +50,7 @@ func (f readFile) Retry() bool { } // Start tailing a log file. -func (f readFile) Start(lines chan<- LineRead, regex string) error { +func (f readFile) Start(ctx context.Context, lines chan<- line.Line, regex string) error { defer func() { select { case <-f.limiter: @@ -64,7 +63,7 @@ func (f readFile) Start(lines chan<- LineRead, regex string) error { default: select { case f.serverMessages <- logger.Warn(f.filePath, f.globID, "Server limit reached. Queuing file..."): - case <-f.stop: + case <-ctx.Done(): return nil } f.limiter <- struct{}{} @@ -86,44 +85,30 @@ func (f readFile) Start(lines chan<- LineRead, regex string) error { var wg sync.WaitGroup wg.Add(1) - go f.periodicTruncateCheck(truncate) - go f.filter(&wg, rawLines, lines, regex) + go f.periodicTruncateCheck(ctx, truncate) + go f.filter(ctx, &wg, rawLines, lines, regex) - err = f.read(fd, rawLines, truncate) + err = f.read(ctx, fd, rawLines, truncate) close(rawLines) wg.Wait() return err } -func (f readFile) periodicTruncateCheck(truncate chan struct{}) { +func (f readFile) periodicTruncateCheck(ctx context.Context, truncate chan struct{}) { for { select { case <-time.After(time.Second * 3): select { case truncate <- struct{}{}: - case <-f.stop: + case <-ctx.Done(): } - case <-f.stop: + case <-ctx.Done(): return } } } -// Stop reading file. -func (f readFile) Stop() { - f.mutex.Lock() - defer f.mutex.Unlock() - - select { - case <-f.stop: - return - default: - } - - close(f.stop) -} - func (f readFile) makeReader(fd *os.File) (reader *bufio.Reader, err error) { switch { case strings.HasSuffix(f.FilePath(), ".gz"): @@ -146,27 +131,31 @@ func (f readFile) makeReader(fd *os.File) (reader *bufio.Reader, err error) { return } -func (f readFile) read(fd *os.File, rawLines chan []byte, truncate <-chan struct{}) error { +func (f readFile) read(ctx context.Context, fd *os.File, rawLines chan []byte, truncate <-chan struct{}) error { + var offset uint64 + reader, err := f.makeReader(fd) if err != nil { return err } rawLine := make([]byte, 0, 512) - var offset uint64 lineLengthThreshold := 1024 * 1024 // 1mb longLineWarning := false for { select { + case <-ctx.Done(): + return nil + default: + } + + select { case <-truncate: if isTruncated, err := f.truncated(fd); isTruncated { return err } logger.Info(f.filePath, "Current offset", offset) - - case <-f.stop: - return nil default: } @@ -196,7 +185,7 @@ func (f readFile) read(fd *os.File, rawLines chan []byte, truncate <-chan struct rawLine = append(rawLine, '\n') select { case rawLines <- rawLine: - case <-f.stop: + case <-ctx.Done(): return nil } rawLine = make([]byte, 0, 512) @@ -219,7 +208,7 @@ func (f readFile) read(fd *os.File, rawLines chan []byte, truncate <-chan struct rawLine = append(rawLine, '\n') select { case rawLines <- rawLine: - case <-f.stop: + case <-ctx.Done(): return nil } rawLine = make([]byte, 0, 512) @@ -228,7 +217,7 @@ func (f readFile) read(fd *os.File, rawLines chan []byte, truncate <-chan struct } // Filter log lines matching a given regular expression. -func (f readFile) filter(wg *sync.WaitGroup, rawLines <-chan []byte, lines chan<- LineRead, regex string) { +func (f readFile) filter(ctx context.Context, wg *sync.WaitGroup, rawLines <-chan []byte, lines chan<- line.Line, regex string) { defer wg.Done() if regex == "" { @@ -252,7 +241,7 @@ func (f readFile) filter(wg *sync.WaitGroup, rawLines <-chan []byte, lines chan< if filteredLine, ok := f.transmittable(line, len(lines), cap(lines)); ok { select { case lines <- filteredLine: - case <-f.stop: + case <-ctx.Done(): return } } @@ -260,10 +249,10 @@ func (f readFile) filter(wg *sync.WaitGroup, rawLines <-chan []byte, lines chan< } } -func (f readFile) transmittable(line []byte, length, capacity int) (LineRead, bool) { - var read LineRead +func (f readFile) transmittable(lineBytes []byte, length, capacity int) (line.Line, bool) { + var read line.Line - if !f.re.Match(line) { + if !f.re.Match(lineBytes) { f.updateLineNotMatched() f.updateLineNotTransmitted() return read, false @@ -277,9 +266,9 @@ func (f readFile) transmittable(line []byte, length, capacity int) (LineRead, bo } f.updateLineTransmitted() - read = LineRead{ - Content: line, - GlobID: &f.globID, + read = line.Line{ + Content: lineBytes, + SourceID: f.globID, Count: f.totalLineCount(), TransmittedPerc: f.transmittedPerc(), } diff --git a/internal/fs/stats.go b/internal/io/fs/stats.go index 4121ff7..4121ff7 100644 --- a/internal/fs/stats.go +++ b/internal/io/fs/stats.go diff --git a/internal/fs/tailfile.go b/internal/io/fs/tailfile.go index a19d4e6..14994e5 100644 --- a/internal/fs/tailfile.go +++ b/internal/io/fs/tailfile.go @@ -1,7 +1,5 @@ package fs -import "sync" - // TailFile is to tail and filter a log file. type TailFile struct { readFile @@ -9,19 +7,15 @@ type TailFile struct { // NewTailFile returns a new file tailer. func NewTailFile(filePath string, globID string, serverMessages chan<- string, limiter chan struct{}) TailFile { - var mutex sync.Mutex - return TailFile{ readFile: readFile{ filePath: filePath, - stop: make(chan struct{}), globID: globID, serverMessages: serverMessages, retry: true, canSkipLines: true, seekEOF: true, limiter: limiter, - mutex: &mutex, }, } } diff --git a/internal/fs/lineread.go b/internal/io/line/line.go index 7ee558e..9db93c0 100644 --- a/internal/fs/lineread.go +++ b/internal/io/line/line.go @@ -1,11 +1,11 @@ -package fs +package line import ( "fmt" ) -// LineRead represents a read log line. -type LineRead struct { +// Line represents a read log line. +type Line struct { // The content of the log line. Content []byte // Until now, how many log lines were processed? @@ -15,14 +15,14 @@ type LineRead struct { // lines if that happens but it will signal to the client how // many log lines in % could be transmitted to the client. TransmittedPerc int - GlobID *string + SourceID string } // Return a human readable representation of the followed line. -func (l LineRead) String() string { - return fmt.Sprintf("LineRead(Content:%s,TransmittedPerc:%v,Count:%v,GlobID:%s)", +func (l Line) String() string { + return fmt.Sprintf("Line(Content:%s,TransmittedPerc:%v,Count:%v,SourceID:%s)", string(l.Content), l.TransmittedPerc, l.Count, - *l.GlobID) + l.SourceID) } diff --git a/internal/logger/logger.go b/internal/io/logger/logger.go index ca85e32..e30b907 100644 --- a/internal/logger/logger.go +++ b/internal/io/logger/logger.go @@ -2,6 +2,7 @@ package logger import ( "bufio" + "context" "fmt" "os" "os/signal" @@ -48,17 +49,13 @@ var lastDateStr string var serverEnable bool // Used to make logging non-blocking. -var logBufCh chan buf +var fileLogBufCh chan buf var stdoutBufCh chan string // Stdout channel, required to pause output var pauseCh chan struct{} var resumeCh chan struct{} -// Tell the logger that we are done, program shuts down -var stop chan struct{} -var stdoutFlushed chan struct{} - // Tell the logger about logrotation var rotateCh chan os.Signal @@ -103,7 +100,7 @@ type buf struct { } // Start logging. -func Start(myServerEnable, debugEnable, silentEnable, nothingEnable bool) { +func Start(ctx context.Context, myServerEnable, debugEnable, silentEnable, nothingEnable bool) { serverEnable = myServerEnable mode := logMode(debugEnable, silentEnable, nothingEnable) @@ -125,7 +122,7 @@ func Start(myServerEnable, debugEnable, silentEnable, nothingEnable bool) { case StdoutStrategy: fallthrough default: - logToFile = false + logToFile = !serverEnable logToStdout = true } @@ -138,8 +135,6 @@ func Start(myServerEnable, debugEnable, silentEnable, nothingEnable bool) { pauseCh = make(chan struct{}) resumeCh = make(chan struct{}) - stop = make(chan struct{}) - stdoutFlushed = make(chan struct{}) // Setup logrotation rotateCh = make(chan os.Signal, 1) @@ -147,12 +142,12 @@ func Start(myServerEnable, debugEnable, silentEnable, nothingEnable bool) { if logToStdout { stdoutBufCh = make(chan string, runtime.NumCPU()*100) - go writeToStdout() + go writeToStdout(ctx) } if logToFile { - logBufCh = make(chan buf, runtime.NumCPU()*100) - go writeToFile() + fileLogBufCh = make(chan buf, runtime.NumCPU()*100) + go writeToFile(ctx) } } @@ -264,7 +259,7 @@ func write(what, severity, message string) { if logToFile { t := time.Now() timeStr := t.Format("20060102-150405") - logBufCh <- buf{ + fileLogBufCh <- buf{ time: t, message: fmt.Sprintf("%s|%s|%s|%s\n", severity, timeStr, what, message), } @@ -304,16 +299,16 @@ func Raw(message string) { return } + if logToFile { + fileLogBufCh <- buf{time.Now(), message} + } + if logToStdout { if color.Colored { message = color.Colorfy(message) } stdoutBufCh <- message } - - if logToFile { - logBufCh <- buf{time.Now(), message} - } } // Close log writer (e.g. on change of day). @@ -367,9 +362,8 @@ func updateFileWriter(dateStr string) *bufio.Writer { return writer } -func flushStdout() { - defer close(stdoutFlushed) - +// Flush all outstanding lines. +func Flush() { for { select { case message := <-stdoutBufCh: @@ -381,7 +375,7 @@ func flushStdout() { } } -func writeToStdout() { +func writeToStdout(ctx context.Context) { for { select { case message := <-stdoutBufCh: @@ -395,21 +389,21 @@ func writeToStdout() { case <-stdoutBufCh: case <-resumeCh: break PAUSE - case <-stop: + case <-ctx.Done(): return } } - case <-stop: - flushStdout() + case <-ctx.Done(): + Flush() return } } } -func writeToFile() { +func writeToFile(ctx context.Context) { for { select { - case buf := <-logBufCh: + case buf := <-fileLogBufCh: dateStr := buf.time.Format("20060102") w := fileWriter(dateStr) w.WriteString(buf.message) @@ -420,11 +414,11 @@ func writeToFile() { case <-stdoutBufCh: case <-resumeCh: break PAUSE - case <-stop: + case <-ctx.Done(): return } } - case <-stop: + case <-ctx.Done(): return } } @@ -449,9 +443,3 @@ func Resume() { resumeCh <- struct{}{} } } - -// Stop logging. -func Stop() { - close(stop) - <-stdoutFlushed -} diff --git a/internal/io/run/run.go b/internal/io/run/run.go new file mode 100644 index 0000000..b608639 --- /dev/null +++ b/internal/io/run/run.go @@ -0,0 +1,104 @@ +package run + +import ( + "bufio" + "context" + "io" + "os/exec" + "strings" + "sync" + "time" + + "github.com/mimecast/dtail/internal/io/line" + "github.com/mimecast/dtail/internal/io/logger" +) + +// Run is for execute a command. +type Run struct { + commandPath string + args []string + cmd *exec.Cmd +} + +// New returns a new command runner. +func New(commandPath string, args []string) Run { + return Run{ + commandPath: commandPath, + args: args, + } +} + +// Start running the command. +func (r Run) Start(ctx context.Context, lines chan<- line.Line) (pid int, ec int, err error) { + done := make(chan struct{}) + defer close(done) + + ec = -1 + pid = -1 + + if len(r.args) > 0 { + logger.Debug(r.commandPath, strings.Join(r.args, " ")) + r.cmd = exec.CommandContext(ctx, r.commandPath, strings.Join(r.args, " ")) + } else { + logger.Debug(r.commandPath) + r.cmd = exec.CommandContext(ctx, r.commandPath) + } + + stdoutPipe, myErr := r.cmd.StdoutPipe() + if err != nil { + err = myErr + return + } + + stderrPipe, myErr := r.cmd.StderrPipe() + if myErr != nil { + err = myErr + return + } + + if myErr := r.cmd.Start(); err != nil { + err = myErr + return + } + + pid = r.cmd.Process.Pid + ec = 0 + + var wg sync.WaitGroup + wg.Add(2) + + go r.pipeToLines(done, &wg, pid, stdoutPipe, "STDOUT", lines) + go r.pipeToLines(done, &wg, pid, stderrPipe, "STDERR", lines) + + if err = r.cmd.Wait(); err != nil { + if exitError, ok := err.(*exec.ExitError); ok { + ec = exitError.ExitCode() + } + } + + return +} + +func (r Run) pipeToLines(done chan struct{}, wg *sync.WaitGroup, pid int, reader io.Reader, what string, lines chan<- line.Line) { + defer wg.Done() + bufReader := bufio.NewReader(reader) + + for { + lineStr, err := bufReader.ReadString('\n') + for err == nil { + lines <- line.Line{ + Content: []byte(lineStr), + Count: uint64(pid), + TransmittedPerc: 100, + SourceID: what, + } + lineStr, err = bufReader.ReadString('\n') + } + select { + case <-done: + return + default: + } + time.Sleep(time.Millisecond * 10) + } +} diff --git a/internal/mapr/aggregateset.go b/internal/mapr/aggregateset.go index 2096c3c..7fb4c17 100644 --- a/internal/mapr/aggregateset.go +++ b/internal/mapr/aggregateset.go @@ -1,6 +1,7 @@ package mapr import ( + "context" "fmt" "strconv" "strings" @@ -64,7 +65,7 @@ func (s *AggregateSet) Merge(query *Query, set *AggregateSet) error { } // Serialize the aggregate set so it can be sent over the wire. -func (s *AggregateSet) Serialize(groupKey string, ch chan<- string, stop chan struct{}) { +func (s *AggregateSet) Serialize(ctx context.Context, groupKey string, ch chan<- string) { //logger.Trace("Serialising mapr.AggregateSet", s) var sb strings.Builder @@ -87,7 +88,7 @@ func (s *AggregateSet) Serialize(groupKey string, ch chan<- string, stop chan st select { case ch <- sb.String(): - case <-stop: + case <-ctx.Done(): } } diff --git a/internal/mapr/client/aggregate.go b/internal/mapr/client/aggregate.go index 3f2b7a5..1272a19 100644 --- a/internal/mapr/client/aggregate.go +++ b/internal/mapr/client/aggregate.go @@ -1,10 +1,11 @@ package client import ( - "github.com/mimecast/dtail/internal/logger" - "github.com/mimecast/dtail/internal/mapr" "strconv" "strings" + + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/mapr" ) // Aggregate mapreduce data on the DTail client side. @@ -15,7 +16,6 @@ type Aggregate struct { group *mapr.GroupSet // This represents the merged aggregated data of all servers. globalGroup *mapr.GlobalGroupSet - stop chan struct{} // The server we aggregate the data for (logging and debugging purposes only) server string } @@ -26,20 +26,12 @@ func NewAggregate(server string, query *mapr.Query, globalGroup *mapr.GlobalGrou query: query, group: mapr.NewGroupSet(), globalGroup: globalGroup, - stop: make(chan struct{}), server: server, } } // Aggregate data from mapr log line into local (and global) group sets. func (a *Aggregate) Aggregate(parts []string) { - select { - case <-a.stop: - logger.Error("Client aggregator stopped for server, not processing new data", a.server) - return - default: - } - groupKey := parts[0] samples, err := strconv.Atoi(parts[1]) if err != nil { @@ -87,14 +79,3 @@ func (a *Aggregate) makeFields(parts []string) map[string]string { return fields } - -// Stop the client side mapreduce aggregator. -func (a *Aggregate) Stop() { - logger.Debug("Stopping client mapreduce aggregator") - close(a.stop) - - err := a.globalGroup.Merge(a.query, a.group) - if err != nil { - panic(err) - } -} diff --git a/internal/mapr/groupset.go b/internal/mapr/groupset.go index d8f9379..e9e0d37 100644 --- a/internal/mapr/groupset.go +++ b/internal/mapr/groupset.go @@ -1,6 +1,7 @@ package mapr import ( + "context" "errors" "fmt" "io/ioutil" @@ -46,9 +47,9 @@ func (g *GroupSet) GetSet(groupKey string) *AggregateSet { } // Serialize the group set (e.g. to send it over the wire). -func (g *GroupSet) Serialize(ch chan<- string, stop chan struct{}) { +func (g *GroupSet) Serialize(ctx context.Context, ch chan<- string) { for groupKey, set := range g.sets { - set.Serialize(groupKey, ch, stop) + set.Serialize(ctx, groupKey, ch) } } diff --git a/internal/mapr/logformat/parser.go b/internal/mapr/logformat/parser.go index 5730d29..09c706b 100644 --- a/internal/mapr/logformat/parser.go +++ b/internal/mapr/logformat/parser.go @@ -1,9 +1,9 @@ package logformat import ( - "github.com/mimecast/dtail/internal/logger" "errors" "fmt" + "github.com/mimecast/dtail/internal/io/logger" "os" "reflect" "strings" diff --git a/internal/mapr/query.go b/internal/mapr/query.go index 3805d15..0127be3 100644 --- a/internal/mapr/query.go +++ b/internal/mapr/query.go @@ -1,9 +1,9 @@ package mapr import ( - "github.com/mimecast/dtail/internal/logger" "errors" "fmt" + "github.com/mimecast/dtail/internal/io/logger" "strconv" "strings" "time" diff --git a/internal/mapr/server/aggregate.go b/internal/mapr/server/aggregate.go index 900756e..922dcbd 100644 --- a/internal/mapr/server/aggregate.go +++ b/internal/mapr/server/aggregate.go @@ -1,26 +1,28 @@ package server import ( - "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/fs" - "github.com/mimecast/dtail/internal/logger" - "github.com/mimecast/dtail/internal/mapr" - "github.com/mimecast/dtail/internal/mapr/logformat" + "context" "os" "strings" "time" + + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/line" + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/mapr" + "github.com/mimecast/dtail/internal/mapr/logformat" ) // Aggregate is for aggregating mapreduce data on the DTail server side. type Aggregate struct { // Log lines to process (parsing MAPREDUCE lines). - Lines chan fs.LineRead + Lines chan line.Line // Hostname of the current server (used to populate $hostname field). hostname string - // Signals to exit goroutine. - stop chan struct{} // Signals to serialize data. serialize chan struct{} + // Signals to flush data. + flush chan struct{} // The mapr query query *mapr.Query // The mapr log format parser @@ -28,7 +30,7 @@ type Aggregate struct { } // NewAggregate return a new server side aggregator. -func NewAggregate(maprLines chan<- string, queryStr string) (*Aggregate, error) { +func NewAggregate(queryStr string) (*Aggregate, error) { query, err := mapr.NewQuery(queryStr) if err != nil { return nil, err @@ -47,76 +49,98 @@ func NewAggregate(maprLines chan<- string, queryStr string) (*Aggregate, error) } a := Aggregate{ - Lines: make(chan fs.LineRead, 100), - stop: make(chan struct{}), + Lines: make(chan line.Line, 100), serialize: make(chan struct{}), + flush: make(chan struct{}), hostname: s[0], query: query, parser: logParser, } - go a.periodicAggregateTimer() - - fieldsCh := make(chan map[string]string) - go a.readFields(fieldsCh, maprLines) - go a.readLines(fieldsCh) - return &a, nil } -func (a *Aggregate) periodicAggregateTimer() { +// Start an aggregation run. +func (a *Aggregate) Start(ctx context.Context, maprLines chan<- string) { + fieldsCh := a.linesToFields(ctx) + go a.fieldsToMaprLines(ctx, fieldsCh, maprLines) + a.periodicAggregateTimer(ctx) +} + +func (a *Aggregate) periodicAggregateTimer(ctx context.Context) { for { select { case <-time.After(a.query.Interval): - a.Serialize() - case <-a.stop: + a.Serialize(ctx) + case <-ctx.Done(): return } } } -func (a *Aggregate) readFields(fieldsCh <-chan map[string]string, maprLines chan<- string) { - group := mapr.NewGroupSet() +func (a *Aggregate) linesToFields(ctx context.Context) <-chan map[string]string { + fieldsCh := make(chan map[string]string) - for { - select { - case fields := <-fieldsCh: - a.aggregate(group, fields) - case <-a.serialize: - logger.Info("Serializing mapreduce result") - group.Serialize(maprLines, a.stop) - logger.Info("Done serializing mapreduce result") - group = mapr.NewGroupSet() - case <-a.stop: - return + go func() { + defer close(fieldsCh) + + for { + select { + case line, ok := <-a.Lines: + if !ok { + return + } + + maprLine := strings.TrimSpace(string(line.Content)) + fields, err := a.parser.MakeFields(maprLine) + + if err != nil { + logger.Error(err) + continue + } + if !a.query.WhereClause(fields) { + continue + } + + select { + case fieldsCh <- fields: + case <-ctx.Done(): + } + case <-ctx.Done(): + return + } } - } + }() + + return fieldsCh } -func (a *Aggregate) readLines(fieldsCh chan<- map[string]string) { +func (a *Aggregate) fieldsToMaprLines(ctx context.Context, fieldsCh <-chan map[string]string, maprLines chan<- string) { + group := mapr.NewGroupSet() + for { select { - case line, ok := <-a.Lines: + case fields, ok := <-fieldsCh: if !ok { + logger.Info("Serializing mapreduce result (final)") + group.Serialize(ctx, maprLines) + group = mapr.NewGroupSet() + logger.Info("Done serializing mapreduce result (final)") return } - - maprLine := strings.TrimSpace(string(line.Content)) - fields, err := a.parser.MakeFields(maprLine) - - if err != nil { - logger.Error(err) - continue - } - if !a.query.WhereClause(fields) { - continue - } - - select { - case fieldsCh <- fields: - case <-a.stop: - } - case <-a.stop: + a.aggregate(group, fields) + case <-a.serialize: + logger.Info("Serializing mapreduce result") + group.Serialize(ctx, maprLines) + group = mapr.NewGroupSet() + logger.Info("Done serializing mapreduce result") + case <-a.flush: + logger.Info("Flushing mapreduce result") + group.Serialize(ctx, maprLines) + group = mapr.NewGroupSet() + a.flush <- struct{}{} + logger.Info("Done flushing mapreduce result") + case <-ctx.Done(): return } } @@ -157,14 +181,15 @@ func (a *Aggregate) aggregate(group *mapr.GroupSet, fields map[string]string) { } // Serialize all the aggregated data. -func (a *Aggregate) Serialize() { +func (a *Aggregate) Serialize(ctx context.Context) { select { case a.serialize <- struct{}{}: - case <-a.stop: + case <-ctx.Done(): } } -// Close the aggregator. -func (a *Aggregate) Close() { - close(a.stop) +// Flush all data. +func (a *Aggregate) Flush() { + a.flush <- struct{}{} + <-a.flush } diff --git a/internal/mapr/wherecondition.go b/internal/mapr/wherecondition.go index e1f4e5b..ab46bed 100644 --- a/internal/mapr/wherecondition.go +++ b/internal/mapr/wherecondition.go @@ -1,9 +1,9 @@ package mapr import ( - "github.com/mimecast/dtail/internal/logger" "errors" "fmt" + "github.com/mimecast/dtail/internal/io/logger" "strconv" "strings" ) diff --git a/internal/omode/mode.go b/internal/omode/mode.go index 57366d2..e29aacc 100644 --- a/internal/omode/mode.go +++ b/internal/omode/mode.go @@ -12,7 +12,7 @@ const ( GrepClient Mode = iota MapClient Mode = iota HealthClient Mode = iota - ExecClient Mode = iota + RunClient Mode = iota ) func (m Mode) String() string { @@ -29,8 +29,8 @@ func (m Mode) String() string { return "map" case HealthClient: return "health" - case ExecClient: - return "exec" + case RunClient: + return "run" default: return "unknown" } diff --git a/internal/pprof/pprof.go b/internal/pprof/pprof.go index f78bcf6..c6d11ca 100644 --- a/internal/pprof/pprof.go +++ b/internal/pprof/pprof.go @@ -7,9 +7,10 @@ import ( _ "net/http/pprof" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" ) +// Start the profiler HTTP server. func Start() { bindAddr := fmt.Sprintf("%s:%d", config.Common.PProfBindAddress, config.Common.PProfPort) logger.Info("Starting PProf server", bindAddr) diff --git a/internal/prompt/prompt.go b/internal/prompt/prompt.go index 76a2726..a438d33 100644 --- a/internal/prompt/prompt.go +++ b/internal/prompt/prompt.go @@ -2,8 +2,8 @@ package prompt import ( "bufio" - "github.com/mimecast/dtail/internal/logger" "fmt" + "github.com/mimecast/dtail/internal/io/logger" "os" "strings" ) diff --git a/internal/server/handlers/controlhandler.go b/internal/server/handlers/controlhandler.go index 482f759..a33a78b 100644 --- a/internal/server/handlers/controlhandler.go +++ b/internal/server/handlers/controlhandler.go @@ -1,33 +1,34 @@ package handlers import ( + "context" "fmt" "io" "os" "strings" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" user "github.com/mimecast/dtail/internal/user/server" ) // ControlHandler is used for control functions and health monitoring. type ControlHandler struct { - serverMessages chan string - pong chan struct{} - stop chan struct{} - payload []byte + ctx context.Context + done chan struct{} hostname string + payload []byte + serverMessages chan string user *user.User } // NewControlHandler returns a new control handler. -func NewControlHandler(user *user.User) *ControlHandler { +func NewControlHandler(ctx context.Context, user *user.User) (*ControlHandler, <-chan struct{}) { logger.Debug(user, "Creating control handler") h := ControlHandler{ + ctx: ctx, + done: make(chan struct{}), serverMessages: make(chan string, 10), - pong: make(chan struct{}, 10), - stop: make(chan struct{}), user: user, } @@ -38,7 +39,8 @@ func NewControlHandler(user *user.User) *ControlHandler { s := strings.Split(fqdn, ".") h.hostname = s[0] - return &h + + return &h, h.done } // Read is to send data to the client via the Reader interface. @@ -49,11 +51,7 @@ func (h *ControlHandler) Read(p []byte) (n int, err error) { wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message)) n = copy(p, wholePayload) return - case <-h.pong: - logger.Info(h.user, "Sending pong") - n = copy(p, []byte(".pong\n")) - return - case <-h.stop: + case <-h.ctx.Done(): return 0, io.EOF } } @@ -65,7 +63,7 @@ func (h *ControlHandler) Write(p []byte) (n int, err error) { switch c { case ';': wholePayload := strings.TrimSpace(string(h.payload)) - h.handleCommand(wholePayload) + h.handleCommand(h.ctx, wholePayload) h.payload = nil default: @@ -77,17 +75,7 @@ func (h *ControlHandler) Write(p []byte) (n int, err error) { return } -// Close the control handler. -func (h *ControlHandler) Close() { - close(h.stop) -} - -// Wait returns the handler stop channel. -func (h *ControlHandler) Wait() <-chan struct{} { - return h.stop -} - -func (h *ControlHandler) handleCommand(command string) { +func (h *ControlHandler) handleCommand(ctx context.Context, command string) { logger.Info(h.user, command) s := strings.Split(command, " ") logger.Debug(h.user, "Receiving command", command, s) @@ -96,8 +84,6 @@ func (h *ControlHandler) handleCommand(command string) { case "health": h.serverMessages <- "OK: DTail SSH Server seems fine" h.serverMessages <- "done;" - case "ping": - h.pong <- struct{}{} case "debug": h.serverMessages <- logger.Debug(h.user, "Receiving debug command", command, s) default: diff --git a/internal/server/handlers/handler.go b/internal/server/handlers/handler.go index 8b1f73e..c42ceb9 100644 --- a/internal/server/handlers/handler.go +++ b/internal/server/handlers/handler.go @@ -5,6 +5,4 @@ import "io" // Handler interface for server side functionality. type Handler interface { io.ReadWriter - Close() - Wait() <-chan struct{} } diff --git a/internal/server/handlers/mapcommand.go b/internal/server/handlers/mapcommand.go new file mode 100644 index 0000000..10372da --- /dev/null +++ b/internal/server/handlers/mapcommand.go @@ -0,0 +1,35 @@ +package handlers + +import ( + "context" + "strings" + + "github.com/mimecast/dtail/internal/mapr/server" +) + +// Map command implements the mapreduce command server side. +type mapCommand struct { + aggregate *server.Aggregate + server *ServerHandler +} + +// NewMapCommand returns a new server side mapreduce command. +func newMapCommand(serverHandler *ServerHandler, argc int, args []string) (mapCommand, *server.Aggregate, error) { + mapCommand := mapCommand{ + server: serverHandler, + } + + queryStr := strings.Join(args[1:], " ") + aggregate, err := server.NewAggregate(queryStr) + if err != nil { + return mapCommand, nil, err + } + + mapCommand.aggregate = aggregate + return mapCommand, aggregate, nil + +} + +func (m mapCommand) Start(ctx context.Context, aggregatedMessages chan<- string) { + m.aggregate.Start(ctx, aggregatedMessages) +} diff --git a/internal/server/handlers/readcommand.go b/internal/server/handlers/readcommand.go new file mode 100644 index 0000000..e4079e8 --- /dev/null +++ b/internal/server/handlers/readcommand.go @@ -0,0 +1,158 @@ +package handlers + +import ( + "context" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/mimecast/dtail/internal/io/fs" + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/omode" +) + +type readCommand struct { + server *ServerHandler + mode omode.Mode +} + +func newReadCommand(server *ServerHandler, mode omode.Mode) *readCommand { + return &readCommand{ + server: server, + mode: mode, + } +} + +func (r *readCommand) Start(ctx context.Context, argc int, args []string) { + regex := "." + if argc >= 4 { + regex = args[3] + } + if argc < 3 { + r.server.sendServerMessage(logger.Warn(r.server.user, commandParseWarning, args, argc)) + return + } + r.readGlob(ctx, args[1], regex) +} + +func (r *readCommand) readGlob(ctx context.Context, glob string, regex string) { + retryInterval := time.Second * 5 + glob = filepath.Clean(glob) + + maxRetries := 10 + for { + maxRetries-- + if maxRetries < 0 { + r.server.sendServerMessage(logger.Warn(r.server.user, "Giving up to read file(s)")) + return + } + + paths, err := filepath.Glob(glob) + if err != nil { + logger.Warn(r.server.user, glob, err) + time.Sleep(retryInterval) + continue + } + + if numPaths := len(paths); numPaths == 0 { + logger.Error(r.server.user, "No such file(s) to read", glob) + r.server.sendServerMessage(logger.Warn(r.server.user, "Unable to read file(s), check server logs")) + select { + case <-ctx.Done(): + return + default: + } + time.Sleep(retryInterval) + continue + } + + r.readFiles(ctx, paths, glob, regex, retryInterval) + break + } +} + +func (r *readCommand) readFiles(ctx context.Context, paths []string, glob string, regex string, retryInterval time.Duration) { + var wg sync.WaitGroup + wg.Add(len(paths)) + + for _, path := range paths { + go r.readFileIfPermissions(ctx, &wg, path, glob, regex) + } + + wg.Wait() +} + +func (r *readCommand) readFileIfPermissions(ctx context.Context, wg *sync.WaitGroup, path, glob, regex string) { + defer wg.Done() + globID := r.makeGlobID(path, glob) + + if !r.server.user.HasFilePermission(path, "readfiles") { + logger.Error(r.server.user, "No permission to read file", path, globID) + r.server.sendServerMessage(logger.Warn(r.server.user, "Unable to read file(s), check server logs")) + return + } + + r.readFile(ctx, path, globID, regex) +} + +func (r *readCommand) readFile(ctx context.Context, path, globID, regex string) { + logger.Info(r.server.user, "Start reading file", path, globID) + + var reader fs.FileReader + switch r.mode { + case omode.TailClient: + reader = fs.NewTailFile(path, globID, r.server.serverMessages, r.server.tailLimiter) + case omode.GrepClient, omode.CatClient: + reader = fs.NewCatFile(path, globID, r.server.serverMessages, r.server.catLimiter) + default: + reader = fs.NewTailFile(path, globID, r.server.serverMessages, r.server.tailLimiter) + } + + lines := r.server.lines + + // Plug in mappreduce engine + if r.server.aggregate != nil { + lines = r.server.aggregate.Lines + } + + for { + if err := reader.Start(ctx, lines, regex); err != nil { + logger.Error(r.server.user, path, globID, err) + } + + select { + case <-ctx.Done(): + return + default: + if !reader.Retry() { + return + } + } + + time.Sleep(time.Second * 2) + logger.Info(path, globID, "Reading file again") + } +} + +func (r *readCommand) makeGlobID(path, glob string) string { + var idParts []string + pathParts := strings.Split(path, "/") + + for i, globPart := range strings.Split(glob, "/") { + if strings.Contains(globPart, "*") { + idParts = append(idParts, pathParts[i]) + } + } + + if len(idParts) > 0 { + return strings.Join(idParts, "/") + } + + if len(pathParts) > 0 { + return pathParts[len(pathParts)-1] + } + + r.server.sendServerMessage(logger.Error("Empty file path given?", path, glob)) + return "" +} diff --git a/internal/server/handlers/runcommand.go b/internal/server/handlers/runcommand.go new file mode 100644 index 0000000..e260060 --- /dev/null +++ b/internal/server/handlers/runcommand.go @@ -0,0 +1,73 @@ +package handlers + +import ( + "context" + "fmt" + "os/exec" + "strings" + + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/io/run" +) + +type runCommand struct { + server *ServerHandler + run run.Run +} + +func newRunCommand(server *ServerHandler) runCommand { + return runCommand{ + server: server, + } +} + +func (r runCommand) Start(ctx context.Context, argc int, args []string) { + if argc < 2 { + r.server.sendServerMessage(logger.Warn(r.server.user, commandParseWarning, args, argc)) + return + } + commands := strings.Split(strings.Join(args[1:], " "), ";") + r.start(ctx, commands) +} + +func (r runCommand) start(ctx context.Context, commands []string) { + for _, command := range commands { + command = strings.TrimSpace(command) + if len(command) == 0 { + continue + } + splitted := strings.Split(command, " ") + path := splitted[0] + args := splitted[1:] + + qualifiedPath, err := exec.LookPath(path) + if err != nil { + logger.Error(r.server.user, err) + r.server.sendServerMessage(logger.Warn(r.server.user, "Unable to execute command(s), check server logs")) + r.server.sendServerMessage(fmt.Sprintf(".run exitstatus -%d", -1)) + return + } + + if !r.server.user.HasFilePermission(qualifiedPath, "runcommands") { + logger.Error(r.server.user, "No permission to execute path", qualifiedPath) + r.server.sendServerMessage(logger.Warn(r.server.user, "Unable to execute command(s), check server logs")) + r.server.sendServerMessage(fmt.Sprintf(".run exitstatus -%d", -1)) + return + } + + r.run = run.New(qualifiedPath, args) + pid, ec, err := r.run.Start(ctx, r.server.lines) + + if err != nil { + message := fmt.Sprintf("Unable to execute remote command '%s'", command) + logger.Error(r.server.user, message, ec, pid, err) + r.server.sendServerMessage(logger.Error(message, ec, pid, err)) + r.server.sendServerMessage(fmt.Sprintf(".run exitstatus -%d", ec)) + return + } + + message := fmt.Sprintf("Remote process '%d' exited with status '%d'", pid, ec) + r.server.sendServerMessage(fmt.Sprintf(".run exitstatus %d", ec)) + r.server.sendServerMessage(logger.Info("run", pid, ec, message)) + } +} diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go index bed8609..3f0d6ce 100644 --- a/internal/server/handlers/serverhandler.go +++ b/internal/server/handlers/serverhandler.go @@ -1,17 +1,19 @@ package handlers import ( + "context" + "encoding/base64" + "errors" "fmt" "io" "os" - "path/filepath" "strings" "sync" "time" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/fs" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/line" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/mapr/server" "github.com/mimecast/dtail/internal/omode" user "github.com/mimecast/dtail/internal/user/server" @@ -26,51 +28,33 @@ const ( // the Bi-directional communication between SSH client and server. // This handler implements the handler of the SSH server. type ServerHandler struct { - // Local log file readers - fileReaders []fs.FileReader - fileReadersMtx *sync.Mutex - // Channel for read lines. - lines chan fs.LineRead - // Only process log lines matching this regex. - regex string - // Server side mapr log aggregation. - aggregate *server.Aggregate - // Channel of aggregated log lines. + mutex *sync.Mutex + lines chan line.Line + regex string + aggregate *server.Aggregate aggregatedMessages chan string - // Channel for server messages to be sent to the client. - serverMessages chan string - // Channel for hidden messages to be sent to the client. - hiddenMessages chan string - // The current payload sent to the client. - payload []byte - // The current server hostname. - hostname string - // The user connecting to dtail. - user *user.User - // To limit the server wide max amount of concurrent cats - catLimiter chan struct{} - // To limit the server wide max amount of concurrent tails - tailLimiter chan struct{} - // Server can tell handler to stop the handler. - stop chan struct{} - // Indicate that client responded to server with "ack stop connection" - ackStopReceived chan struct{} - // Stop timeout. - stopTimeout chan struct{} + serverMessages chan string + payload []byte + hostname string + user *user.User + catLimiter chan struct{} + tailLimiter chan struct{} + ackCloseReceived chan struct{} + ctx context.Context + done chan struct{} + activeReaders int } // NewServerHandler returns the server handler. -func NewServerHandler(user *user.User, catLimiter chan struct{}, tailLimiter chan struct{}) *ServerHandler { - logger.Debug(user, "Creating tail handler") +func NewServerHandler(ctx context.Context, user *user.User, catLimiter chan struct{}, tailLimiter chan struct{}) (*ServerHandler, <-chan struct{}) { h := ServerHandler{ - fileReadersMtx: &sync.Mutex{}, - lines: make(chan fs.LineRead, 100), + ctx: ctx, + done: make(chan struct{}), + mutex: &sync.Mutex{}, + lines: make(chan line.Line, 100), serverMessages: make(chan string, 10), aggregatedMessages: make(chan string, 10), - hiddenMessages: make(chan string, 10), - ackStopReceived: make(chan struct{}), - stopTimeout: make(chan struct{}), - stop: make(chan struct{}), + ackCloseReceived: make(chan struct{}), catLimiter: catLimiter, tailLimiter: tailLimiter, regex: ".", @@ -85,37 +69,46 @@ func NewServerHandler(user *user.User, catLimiter chan struct{}, tailLimiter cha s := strings.Split(fqdn, ".") h.hostname = s[0] - return &h + return &h, h.done } // Read is to send data to the dtail client via Reader interface. func (h *ServerHandler) Read(p []byte) (n int, err error) { for { select { + case message := <-h.serverMessages: + if message[0] == '.' { + // Handle hidden message (don't display to the user, interpreted by dtail client) + wholePayload := []byte(fmt.Sprintf("%s\n", message)) + n = copy(p, wholePayload) + return + } + + // Handle normal server message (display to the user) wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message)) n = copy(p, wholePayload) return + case message := <-h.aggregatedMessages: + // Send mapreduce-aggregated data as a message. data := fmt.Sprintf("AGGREGATE|%s|%s\n", h.hostname, message) - //logger.Debug("Sending aggregation data", data) wholePayload := []byte(data) n = copy(p, wholePayload) return - case message := <-h.hiddenMessages: - //logger.Debug(h.user, "Sending hidden message", message) - wholePayload := []byte(fmt.Sprintf(".%s\n", message)) - n = copy(p, wholePayload) - return + case line := <-h.lines: + // Send normal file content data as a message. serverInfo := []byte(fmt.Sprintf("REMOTE|%s|%3d|%v|%s|", - h.hostname, line.TransmittedPerc, line.Count, *line.GlobID)) + h.hostname, line.TransmittedPerc, line.Count, line.SourceID)) wholePayload := append(serverInfo, line.Content[:]...) n = copy(p, wholePayload) return + case <-time.After(time.Second): + // Once in a while check whether we are done. select { - case <-h.stop: + case <-h.ctx.Done(): return 0, io.EOF default: } @@ -129,7 +122,7 @@ func (h *ServerHandler) Write(p []byte) (n int, err error) { switch c { case ';': commandStr := strings.TrimSpace(string(h.payload)) - h.handleCommand(commandStr) + h.handleCommand(h.ctx, commandStr) h.payload = nil default: h.payload = append(h.payload, c) @@ -140,210 +133,167 @@ func (h *ServerHandler) Write(p []byte) (n int, err error) { return } -// Close the server handler. -func (h *ServerHandler) Close() { - h.fileReadersMtx.Lock() - defer h.fileReadersMtx.Unlock() +func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) { + logger.Debug(h.user, commandStr) - for _, reader := range h.fileReaders { - reader.Stop() + args, argc, err := h.handleProtocolVersion(strings.Split(commandStr, " ")) + if err != nil { + h.send(h.serverMessages, logger.Error(h.user, err)) + return } - if h.aggregate != nil { - h.aggregate.Close() + + args, argc, err = h.handleBase64(args, argc) + if err != nil { + h.send(h.serverMessages, logger.Error(h.user, err)) + return } - close(h.stop) -} + if h.user.Name == config.ControlUser { + h.handleControlCommand(argc, args) + return + } -func (h *ServerHandler) makeGlobID(path, glob string) string { - var idParts []string - pathParts := strings.Split(path, "/") + h.handleUserCommand(ctx, argc, args) +} - for i, globPart := range strings.Split(glob, "/") { - if strings.Contains(globPart, "*") { - idParts = append(idParts, pathParts[i]) - } - } +func (h *ServerHandler) handleProtocolVersion(args []string) ([]string, int, error) { + argc := len(args) - if len(idParts) > 0 { - return strings.Join(idParts, "/") + if argc <= 2 || args[0] != "protocol" { + return args, argc, errors.New("unable to determine protocol version") } - if len(pathParts) > 0 { - return pathParts[len(pathParts)-1] + if args[1] != version.ProtocolCompat { + err := fmt.Errorf("server with protool version '%s' but client with '%s', please update DTail", version.ProtocolCompat, args[1]) + return args, argc, err } - h.send(h.serverMessages, logger.Error("Empty file path given?", path, glob)) - return "" + return args[2:], argc - 2, nil } -func (h *ServerHandler) processFileGlob(mode omode.Mode, glob string, regex string) { - retryInterval := time.Second * 5 - glob = filepath.Clean(glob) - - errors := make(chan struct{}) - stop := make(chan struct{}) - defer close(stop) +func (h *ServerHandler) handleBase64(args []string, argc int) ([]string, int, error) { + err := errors.New("Unable to decode client message") - go func() { - for { - select { - case <-errors: - h.send(h.serverMessages, logger.Warn(h.user, "Unable to read file(s), check server logs")) - case <-stop: - return - case <-h.stop: - return - } - } - }() + if argc != 2 || args[0] != "base64" { + return args, argc, err + } - maxRetries := 10 - for { - maxRetries-- - if maxRetries < 0 { - h.send(h.serverMessages, logger.Warn(h.user, "Giving up to read file(s)")) - h.internalClose() - return - } + decoded, err := base64.StdEncoding.DecodeString(args[1]) + if err != nil { + return args, argc, err + } + decodedStr := string(decoded) - paths, err := filepath.Glob(glob) - if err != nil { - logger.Warn(h.user, glob, err) - time.Sleep(retryInterval) - continue - } + args = strings.Split(decodedStr, " ") + argc = len(decodedStr) + logger.Trace(h.user, "Base64 decoded received command", decodedStr, argc, args) - if numPaths := len(paths); numPaths == 0 { - logger.Error(h.user, "No such file(s) to read", glob) - select { - case errors <- struct{}{}: - case <-h.stop: - return - default: - } - time.Sleep(retryInterval) - continue - } + return args, argc, nil +} - h.startReadingFiles(mode, paths, glob, regex, retryInterval, errors) - break +func (h *ServerHandler) handleControlCommand(argc int, args []string) { + switch args[0] { + case "debug": + h.send(h.serverMessages, logger.Debug(h.user, "Receiving debug command", argc, args)) + default: + logger.Warn(h.user, "Received unknown command", argc, args) } } -func (h *ServerHandler) startReadingFiles(mode omode.Mode, paths []string, glob string, regex string, retryInterval time.Duration, errors chan<- struct{}) { - var wg sync.WaitGroup - wg.Add(len(paths)) +func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []string) { + logger.Debug(h.user, "handleUserCommand", argc, args) - read := func(path string, wg *sync.WaitGroup) { - defer wg.Done() - globID := h.makeGlobID(path, glob) + switch args[0] { + case "grep", "cat": + command := newReadCommand(h, omode.CatClient) + h.incrementActiveReaders() + go func() { + command.Start(ctx, argc, args) + if h.decrementActiveReaders() == 0 { + h.shutdown() + } + }() - if !h.user.HasFilePermission(path) { - logger.Error(h.user, "No permission to read file", path, globID) - select { - case errors <- struct{}{}: - default: + case "tail": + command := newReadCommand(h, omode.TailClient) + h.incrementActiveReaders() + go func() { + command.Start(ctx, argc, args) + if h.decrementActiveReaders() == 0 { + h.shutdown() } + }() + + case "map": + command, aggregate, err := newMapCommand(h, argc, args) + if err != nil { + h.sendServerMessage(err.Error()) + logger.Error(h.user, err) return } - h.startReadingFile(mode, path, globID, regex) - } - - for _, path := range paths { - go read(path, &wg) - } + h.aggregate = aggregate + go func() { + command.Start(ctx, h.aggregatedMessages) + h.shutdown() + }() + + case "run": + command := newRunCommand(h) + h.incrementActiveReaders() + go func() { + command.Start(ctx, argc, args) + if h.decrementActiveReaders() == 0 { + h.shutdown() + } + }() - wg.Wait() -} + case "ack", ".ack": + h.handleAckCommand(argc, args) -func (h *ServerHandler) startReadingFile(mode omode.Mode, path, globID, regex string) { - defer h.stopReadingFile(path) - logger.Info(h.user, "Start reading file", path, globID) - - var reader fs.FileReader - switch mode { - case omode.TailClient: - reader = fs.NewTailFile(path, globID, h.serverMessages, h.tailLimiter) - case omode.GrepClient: - fallthrough - case omode.CatClient: - reader = fs.NewCatFile(path, globID, h.serverMessages, h.catLimiter) default: - reader = fs.NewTailFile(path, globID, h.serverMessages, h.tailLimiter) + h.sendServerMessage(logger.Error(h.user, "Received unknown command", argc, args)) } +} - h.fileReadersMtx.Lock() - h.fileReaders = append(h.fileReaders, reader) - h.fileReadersMtx.Unlock() - - lines := h.lines - // Plugin mappreduce engine - if h.aggregate != nil { - lines = h.aggregate.Lines +func (h *ServerHandler) handleAckCommand(argc int, args []string) { + if argc < 3 { + h.sendServerMessage(logger.Warn(h.user, commandParseWarning, args, argc)) + return } - - for { - if err := reader.Start(lines, regex); err != nil { - logger.Error(h.user, path, globID, err) - } - - select { - case <-h.stop: - return - default: - if !reader.Retry() { - return - } - } - - time.Sleep(time.Second * 2) - logger.Info(path, globID, "Reading file again") + if args[1] == "close" && args[2] == "connection" { + close(h.ackCloseReceived) } } -func (h *ServerHandler) stopReadingFile(path string) { - logger.Info(h.user, "Stop reading file", path) +func (h *ServerHandler) send(ch chan<- string, message string) { + select { + case ch <- message: + case <-h.ctx.Done(): + } +} - h.fileReadersMtx.Lock() - defer h.fileReadersMtx.Unlock() +func (h *ServerHandler) sendServerMessage(message string) { + h.send(h.serverMessageC(), message) +} - path = filepath.Clean(path) - var fileReaders []fs.FileReader +func (h *ServerHandler) serverMessageC() chan<- string { + return h.serverMessages +} - for _, reader := range h.fileReaders { - if reader.FilePath() == path { - reader.Stop() - continue - } - fileReaders = append(fileReaders, reader) - } +func (h *ServerHandler) flush() { + logger.Debug(h.user, "flush()") - if len(fileReaders) == len(h.fileReaders) { - logger.Warn(h.user, "Didn't read file path", path) - return + if h.aggregate != nil { + h.aggregate.Flush() } - h.fileReaders = fileReaders - - if len(fileReaders) == 0 { - if h.aggregate != nil { - h.aggregate.Serialize() - } - h.allLinesSent() + unsentMessages := func() int { + return len(h.lines) + len(h.serverMessages) + len(h.aggregatedMessages) } -} - -func (h *ServerHandler) numUnsentMessages() int { - return len(h.lines) + len(h.serverMessages) + len(h.hiddenMessages) + len(h.aggregatedMessages) -} - -func (h *ServerHandler) allLinesSent() { - defer h.internalClose() for i := 0; i < 3; i++ { - if h.numUnsentMessages() == 0 { + if unsentMessages() == 0 { logger.Debug(h.user, "All lines sent") return } @@ -351,142 +301,43 @@ func (h *ServerHandler) allLinesSent() { time.Sleep(time.Second) } - logger.Warn(h.user, "Some lines remain unsent", h.numUnsentMessages()) + logger.Warn(h.user, "Some lines remain unsent", unsentMessages()) } -// Handler decides to shutdown the connection, not the server itself. -func (h *ServerHandler) internalClose() { - select { - case h.hiddenMessages <- "syn close connection": - case <-time.After(time.Second * 5): - logger.Debug(h.user, "Not waiting for ack close connection") - close(h.stopTimeout) - return - } +func (h *ServerHandler) shutdown() { + logger.Debug(h.user, "shutdown()") + h.flush() + + go func() { + select { + case h.serverMessageC() <- ".syn close connection": + case <-h.ctx.Done(): + } + }() select { - case <-h.Wait(): + case <-h.ackCloseReceived: case <-time.After(time.Second * 5): - logger.Debug(h.user, "Not waiting for ack close connection") - close(h.stopTimeout) - } -} - -func (h *ServerHandler) handleCommand(commandStr string) { - logger.Info(h.user, commandStr) - - args := strings.Split(commandStr, " ") - argc := len(args) - - logger.Debug(h.user, "Received command", commandStr, argc, args) - - if h.user.Name == config.ControlUser { - h.handleControlCommand(argc, args) - return + logger.Debug(h.user, "Shutdown timeout reached, enforcing shutdown") + case <-h.ctx.Done(): } - h.handleUserCommand(argc, args) -} - -// Special (restricted) set of commands for anonymous ControlUser access. -func (h *ServerHandler) handleControlCommand(argc int, args []string) { - switch args[0] { - case "ping": - h.send(h.hiddenMessages, "pong") - case "debug": - h.send(h.serverMessages, logger.Debug(h.user, "Receiving debug command", argc, args)) - default: - logger.Warn(h.user, "Received unknown command", argc, args) - } -} - -// Commands for authed users. -func (h *ServerHandler) handleUserCommand(argc int, args []string) { - switch args[0] { - case "grep": - fallthrough - case "cat": - h.handleReadCommand(argc, args, omode.CatClient) - case "tail": - h.handleReadCommand(argc, args, omode.TailClient) - case "map": - h.handleMapCommand(argc, args) - case "ack": - h.handleAckCommand(argc, args) - case "ping": - h.send(h.hiddenMessages, "pong") - case "version": - h.send(h.serverMessages, fmt.Sprintf("Server version is "+version.String())) - case "debug": - h.send(h.serverMessages, logger.Debug(h.user, "Received debug command", argc, args)) + select { + case h.done <- struct{}{}: default: - h.send(h.serverMessages, logger.Warn(h.user, "Received unknown command", argc, args)) } } -func (h *ServerHandler) handleReadCommand(argc int, args []string, mode omode.Mode) { - regex := "." - if argc >= 4 { - regex = args[3] - } - if argc < 3 { - h.send(h.serverMessages, logger.Warn(h.user, commandParseWarning, args, argc)) - return - } - go h.processFileGlob(mode, args[1], regex) +func (h *ServerHandler) incrementActiveReaders() { + // TODO: Use atomic counter variable instead, so we can get rid of the mutex + h.mutex.Lock() + defer h.mutex.Unlock() + h.activeReaders++ } +func (h *ServerHandler) decrementActiveReaders() int { + h.mutex.Lock() + defer h.mutex.Unlock() + h.activeReaders-- -func (h *ServerHandler) handleMapCommand(argc int, args []string) { - if argc < 2 { - h.send(h.serverMessages, logger.Warn(h.user, commandParseWarning, args, argc)) - return - } - - queryStr := strings.Join(args[1:], " ") - logger.Info(h.user, "Creating new mapr aggregator", queryStr) - aggregate, err := server.NewAggregate(h.aggregatedMessages, queryStr) - - if err != nil { - h.send(h.serverMessages, logger.Error(h.user, err)) - return - } - - h.aggregate = aggregate -} - -func (h *ServerHandler) handleAckCommand(argc int, args []string) { - if argc < 3 { - h.send(h.serverMessages, logger.Warn(h.user, commandParseWarning, args, argc)) - return - } - if args[1] == "close" && args[2] == "connection" { - close(h.ackStopReceived) - } -} - -func (h *ServerHandler) send(ch chan<- string, message string) { - select { - case ch <- message: - case <-h.stop: - } -} - -// Wait (block) until server handler is closed or a timeout has exceeded. -func (h *ServerHandler) Wait() <-chan struct{} { - wait := make(chan struct{}) - - go func() { - select { - case <-h.ackStopReceived: - logger.Debug(h.user, "Closing wait channel due to ACK stop received") - close(wait) - case <-h.stopTimeout: - logger.Debug(h.user, "Closing wait channel due to wait timeout") - close(wait) - case <-h.stop: - logger.Debug(h.user, "Closing wait channel due to stop") - } - }() - - return wait + return h.activeReaders } diff --git a/internal/server/server.go b/internal/server/server.go index 27a98f5..42eb74c 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1,13 +1,14 @@ package server import ( + "context" "errors" "fmt" "io" "net" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/server/handlers" "github.com/mimecast/dtail/internal/ssh/server" user "github.com/mimecast/dtail/internal/user/server" @@ -26,8 +27,6 @@ type Server struct { catLimiterCh chan struct{} // To control the max amount of concurrent tails tailLimiterCh chan struct{} - // Ask to shutdown the server - stop chan struct{} } // New returns a new server. @@ -38,7 +37,6 @@ func New() *Server { sshServerConfig: &gossh.ServerConfig{}, catLimiterCh: make(chan struct{}, config.Server.MaxConcurrentCats), tailLimiterCh: make(chan struct{}, config.Server.MaxConcurrentTails), - stop: make(chan struct{}), } s.sshServerConfig.PasswordCallback = s.controlUserCallback @@ -54,7 +52,7 @@ func New() *Server { } // Start the server. -func (s *Server) Start() int { +func (s *Server) Start(ctx context.Context) int { logger.Info("Starting server") bindAt := fmt.Sprintf("%s:%d", config.Server.SSHBindAddress, config.Common.SSHPort) @@ -64,7 +62,7 @@ func (s *Server) Start() int { logger.FatalExit("Failed to open listening TCP socket", err) } - go s.stats.periodicLogServerStats(s.stop) + go s.stats.periodicLogServerStats(ctx) for { conn, err := listener.Accept() // Blocking @@ -79,11 +77,11 @@ func (s *Server) Start() int { continue } - go s.handleConnection(conn) + go s.handleConnection(ctx, conn) } } -func (s *Server) handleConnection(conn net.Conn) { +func (s *Server) handleConnection(ctx context.Context, conn net.Conn) { logger.Info("Handling connection") sshConn, chans, reqs, err := gossh.NewServerConn(conn, s.sshServerConfig) @@ -96,11 +94,11 @@ func (s *Server) handleConnection(conn net.Conn) { go gossh.DiscardRequests(reqs) for newChannel := range chans { - go s.handleChannel(sshConn, newChannel) + go s.handleChannel(ctx, sshConn, newChannel) } } -func (s *Server) handleChannel(sshConn gossh.Conn, newChannel gossh.NewChannel) { +func (s *Server) handleChannel(ctx context.Context, sshConn gossh.Conn, newChannel gossh.NewChannel) { user := user.New(sshConn.User(), sshConn.RemoteAddr().String()) logger.Info(user, "Invoking channel handler") @@ -117,13 +115,13 @@ func (s *Server) handleChannel(sshConn gossh.Conn, newChannel gossh.NewChannel) return } - if err := s.handleRequests(sshConn, requests, channel, user); err != nil { + if err := s.handleRequests(ctx, sshConn, requests, channel, user); err != nil { logger.Error(user, err) sshConn.Close() } } -func (s *Server) handleRequests(sshConn gossh.Conn, in <-chan *gossh.Request, channel gossh.Channel, user *user.User) error { +func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-chan *gossh.Request, channel gossh.Channel, user *user.User) error { logger.Info(user, "Invoking request handler") for req := range in { @@ -132,50 +130,50 @@ func (s *Server) handleRequests(sshConn gossh.Conn, in <-chan *gossh.Request, ch switch req.Type { case "shell": + handlerCtx, cancel := context.WithCancel(ctx) + var handler handlers.Handler + var done <-chan struct{} + switch user.Name { case config.ControlUser: - handler = handlers.NewControlHandler(user) + handler, done = handlers.NewControlHandler(handlerCtx, user) default: - handler = handlers.NewServerHandler(user, s.catLimiterCh, s.tailLimiterCh) + handler, done = handlers.NewServerHandler(handlerCtx, user, s.catLimiterCh, s.tailLimiterCh) } - // Bi-directionally connect SSH stream to SSH handler - brokenPipe1 := make(chan struct{}) go func() { - defer close(brokenPipe1) + // Handler finished work, cancel all remaining routines + defer cancel() + <-done + }() + + go func() { + // Broken pipe, cancel + defer cancel() + io.Copy(channel, handler) }() - brokenPipe2 := make(chan struct{}) go func() { - defer close(brokenPipe2) + // Broken pipe, cancel + defer cancel() + io.Copy(handler, channel) }() - // Ensure to close all fd's and stop all goroutines once ssh connection terminated go func() { - defer s.stats.decrementConnections() - defer handler.Close() + defer cancel() if err := sshConn.Wait(); err != nil && err != io.EOF { logger.Error(user, err) } + s.stats.decrementConnections() logger.Info(user, "Good bye Mister!") }() - // Close the underlying ssh socket when server shuts down go func() { - select { - case <-s.stop: - logger.Debug(user, "Server initiating shutdown on handler") - case <-handler.Wait(): - logger.Debug(user, "Handler initiating shutdown by its own") - case <-brokenPipe1: - logger.Debug(user, "Broken pipe1") - case <-brokenPipe2: - logger.Debug(user, "Broken pipe2") - } + <-handlerCtx.Done() sshConn.Close() logger.Info(user, "Closed SSH connection") }() @@ -204,9 +202,3 @@ func (*Server) controlUserCallback(c gossh.ConnMetadata, authPayload []byte) (*g return nil, fmt.Errorf("Not authorized") } - -// Stop the server. -func (s *Server) Stop() { - close(s.stop) - s.stats.waitForConnections() -} diff --git a/internal/server/stats.go b/internal/server/stats.go index beb1885..4d661f7 100644 --- a/internal/server/stats.go +++ b/internal/server/stats.go @@ -1,12 +1,14 @@ package server import ( - "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" + "context" "fmt" "runtime" "sync" "time" + + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/logger" ) // Used to collect and display various server stats. @@ -65,12 +67,12 @@ func (s *stats) serverLimitExceeded() error { return nil } -func (s *stats) periodicLogServerStats(stop <-chan struct{}) { +func (s *stats) periodicLogServerStats(ctx context.Context) { for { select { case <-time.NewTimer(time.Second * 10).C: s.logServerStats() - case <-stop: + case <-ctx.Done(): return } } diff --git a/internal/ssh/client/authmethods.go b/internal/ssh/client/authmethods.go index 3392eb1..967866f 100644 --- a/internal/ssh/client/authmethods.go +++ b/internal/ssh/client/authmethods.go @@ -2,7 +2,7 @@ package client import ( "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/ssh" "os" diff --git a/internal/ssh/client/hostkeycallback.go b/internal/ssh/client/hostkeycallback.go index 4023e59..7ae2396 100644 --- a/internal/ssh/client/hostkeycallback.go +++ b/internal/ssh/client/hostkeycallback.go @@ -2,8 +2,7 @@ package client import ( "bufio" - "github.com/mimecast/dtail/internal/logger" - "github.com/mimecast/dtail/internal/prompt" + "context" "fmt" "net" "os" @@ -11,6 +10,9 @@ import ( "sync" "time" + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/prompt" + "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/knownhosts" ) @@ -116,7 +118,7 @@ func (h *HostKeyCallback) Wrap() ssh.HostKeyCallback { // PromptAddHosts prompts a question to the user whether unknown hosts should // be added to the known hosts or not. -func (h *HostKeyCallback) PromptAddHosts(stop <-chan struct{}) { +func (h *HostKeyCallback) PromptAddHosts(ctx context.Context) { var hosts []unknownHost for { @@ -135,7 +137,7 @@ func (h *HostKeyCallback) PromptAddHosts(stop <-chan struct{}) { h.promptAddHosts(hosts) hosts = []unknownHost{} } - case <-stop: + case <-ctx.Done(): logger.Debug("Stopping goroutine prompting new hosts...") return } diff --git a/internal/ssh/server/hostkey.go b/internal/ssh/server/hostkey.go index 7baa4aa..07790ad 100644 --- a/internal/ssh/server/hostkey.go +++ b/internal/ssh/server/hostkey.go @@ -2,7 +2,7 @@ package server import ( "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/ssh" "io/ioutil" "os" diff --git a/internal/ssh/server/publickeycallback.go b/internal/ssh/server/publickeycallback.go index c6929d7..757def7 100644 --- a/internal/ssh/server/publickeycallback.go +++ b/internal/ssh/server/publickeycallback.go @@ -7,7 +7,7 @@ import ( osUser "os/user" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" user "github.com/mimecast/dtail/internal/user/server" gossh "golang.org/x/crypto/ssh" diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go index 77cc341..3a2e416 100644 --- a/internal/ssh/ssh.go +++ b/internal/ssh/ssh.go @@ -4,9 +4,9 @@ import ( "crypto/rand" "crypto/rsa" "crypto/x509" - "github.com/mimecast/dtail/internal/logger" "encoding/pem" "fmt" + "github.com/mimecast/dtail/internal/io/logger" "io/ioutil" "net" "os" diff --git a/internal/user/name.go b/internal/user/name.go index 5171ec7..28ab0a4 100644 --- a/internal/user/name.go +++ b/internal/user/name.go @@ -2,10 +2,10 @@ package user import ( "os/user" - ) +) - -func Name() string { +// NoRootCheck verifies that the DTail run user is not with UID or GID 0. +func NoRootCheck() { user, err := user.Current() if err != nil { panic(err) @@ -18,7 +18,14 @@ func Name() string { if user.Gid == "0" { panic("Not allowed to run as GID 0") } +} + +// Name of the current run user. +func Name() string { + user, err := user.Current() + if err != nil { + panic(err) + } return user.Username } - diff --git a/internal/user/server/user.go b/internal/user/server/user.go index fad38d8..271a4ac 100644 --- a/internal/user/server/user.go +++ b/internal/user/server/user.go @@ -1,14 +1,15 @@ package server import ( - "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/fs/permissions" - "github.com/mimecast/dtail/internal/logger" "fmt" "os" "path/filepath" "regexp" "strings" + + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/fs/permissions" + "github.com/mimecast/dtail/internal/io/logger" ) const maxLinkDepth int = 100 @@ -37,26 +38,28 @@ func (u *User) String() string { } // HasFilePermission is used to determine whether user is alowed to read a file. -func (u *User) HasFilePermission(filePath string) (hasPermission bool) { +func (u *User) HasFilePermission(filePath, permissionType string) (hasPermission bool) { + logger.Debug(u, filePath, permissionType, "Checking config permissions") + cleanPath, err := filepath.EvalSymlinks(filePath) if err != nil { - logger.Error(u, filePath, "Unable to evaluate symlinks", err) + logger.Error(u, filePath, permissionType, "Unable to evaluate symlinks", err) hasPermission = false return } cleanPath, err = filepath.Abs(cleanPath) if err != nil { - logger.Error(u, cleanPath, "Unable to make file path absolute", err) + logger.Error(u, cleanPath, permissionType, "Unable to make file path absolute", err) hasPermission = false return } if cleanPath != filePath { - logger.Info(u, filePath, cleanPath, "Calculated new clean path from original file path (possibly symlink)") + logger.Info(u, filePath, cleanPath, permissionType, "Calculated new clean path from original file path (possibly symlink)") } - hasPermission, err = u.hasFilePermission(cleanPath) + hasPermission, err = u.hasFilePermission(cleanPath, permissionType) if err != nil { logger.Warn(u, cleanPath, err) } @@ -64,12 +67,12 @@ func (u *User) HasFilePermission(filePath string) (hasPermission bool) { return } -func (u *User) hasFilePermission(cleanPath string) (bool, error) { +func (u *User) hasFilePermission(cleanPath, permissionType string) (bool, error) { // First check file system Linux/UNIX permission. if _, err := permissions.ToRead(u.Name, cleanPath); err != nil { - return false, fmt.Errorf("User without OS file system permissions to read file: '%v'", err) + return false, fmt.Errorf("User without OS file system permissions to read path: '%v'", err) } - logger.Info(u, cleanPath, "User has OS file system permissions to read file") + logger.Info(u, cleanPath, permissionType, "User with OS file system permissions to path") // If file system permission is given, also check permissions // as configured in DTail config file. @@ -84,7 +87,7 @@ func (u *User) hasFilePermission(cleanPath string) (bool, error) { var hasPermission bool var err error - if hasPermission, err = u.iteratePaths(cleanPath); err != nil { + if hasPermission, err = u.iteratePaths(cleanPath, permissionType); err != nil { return false, err } @@ -101,17 +104,28 @@ func (u *User) hasFilePermission(cleanPath string) (bool, error) { return hasPermission, nil } -func (u *User) iteratePaths(cleanPath string) (bool, error) { +func (u *User) iteratePaths(cleanPath, permissionType string) (bool, error) { for _, permission := range u.permissions { + typeStr := "readfiles" // Assume ReadFiles by default. + var regexStr string var negate bool + splitted := strings.Split(permission, ":") + if len(splitted) > 1 { + typeStr = splitted[0] + permission = strings.Join(splitted[1:], ":") + } + + if typeStr != permissionType { + continue + } + + regexStr = permission if strings.HasPrefix(permission, "!") { regexStr = permission[1:] negate = true } - regexStr = permission - negate = false re, err := regexp.Compile(regexStr) if err != nil { diff --git a/internal/version/version.go b/internal/version/version.go index 3a4a5dc..3c057df 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -7,18 +7,20 @@ import ( "github.com/mimecast/dtail/internal/color" ) -// Name of DTail. -const Name = "DTail" - -// Version of DTail. -const Version = "1.1.0" - -// Additional information. -const Additional = "" +const ( + // Name of DTail. + Name string = "DTail" + // Version of DTail. + Version string = "2.0.0" + // Additional information for DTail + Additional string = "" + // ProtocolCompat -ibility version. + ProtocolCompat string = "2" +) // String representation of the DTail version. func String() string { - return fmt.Sprintf("%s v%v %s", Name, Version, Additional) + return fmt.Sprintf("%s %v Protocol %s %s", Name, Version, ProtocolCompat, Additional) } // PaintedString is a prettier string representation of the DTail version. @@ -30,7 +32,7 @@ func PaintedString() string { version := color.Paint(color.Blue, Version) descr := color.Paint(color.Green, Additional) - return fmt.Sprintf("%s %v %s", name, version, descr) + return fmt.Sprintf("%s %v Protocol %s %s", name, version, ProtocolCompat, descr) } // PrintAndExit prints the program version and exists. diff --git a/samples/dtail.json.sample b/samples/dtail.json.sample index 99c0a73..6b713e8 100644 --- a/samples/dtail.json.sample +++ b/samples/dtail.json.sample @@ -10,16 +10,18 @@ "HostKeyBits" : 2048, "Permissions": { "Default": [ - "^/.*$" + "readfiles:^/.*$", + "runcommands:!^/.*$" ], "Users": { "pbuetow": [ - "^/.*$" + "readfiles:^/.*$" + "runcommands:^/.*$" ], "jblake": [ - "^/tmp/foo.log$", - "^/.*$", - "!^/tmp/bar.log$" + "readfiles:^/tmp/foo.log$", + "readfiles:^/.*$", + "readfiles:!^/tmp/bar.log$" ] } } |
