From 0945da8dfefcbb723eecea0e5f4eafff63398253 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul=20B=C3=BCtow?= Date: Sun, 26 Jan 2020 11:26:53 +0000 Subject: Introduce drun command, refactor code to use context package --- Makefile | 16 +- cmd/dcat/main.go | 19 +- cmd/dexec/main.go | 80 ---- cmd/dgrep/main.go | 18 +- cmd/dmap/main.go | 19 +- cmd/drun/main.go | 82 ++++ cmd/dserver/main.go | 22 +- cmd/dtail/main.go | 19 +- doc/examples.md | 4 +- doc/installation.md | 5 +- doc/quickstart.md | 23 +- go.mod | 1 + go.sum | 9 + internal/clients/args.go | 3 +- internal/clients/baseclient.go | 130 +++--- internal/clients/catclient.go | 20 +- internal/clients/client.go | 5 +- internal/clients/connectionmaker.go | 12 - internal/clients/execclient.go | 48 --- internal/clients/grepclient.go | 20 +- internal/clients/handlers/basehandler.go | 84 ++-- internal/clients/handlers/clienthandler.go | 11 +- internal/clients/handlers/handler.go | 12 +- internal/clients/handlers/healthhandler.go | 21 +- internal/clients/handlers/maprhandler.go | 21 +- internal/clients/handlers/withcancel.go | 24 ++ internal/clients/healthclient.go | 7 +- internal/clients/maker.go | 8 + internal/clients/maprclient.go | 52 ++- internal/clients/remote/connection.go | 116 ++---- internal/clients/runclient.go | 40 ++ internal/clients/stats.go | 8 +- internal/clients/tailclient.go | 21 +- internal/discovery/comma.go | 2 +- internal/discovery/discovery.go | 21 +- internal/discovery/file.go | 2 +- internal/fs/catfile.go | 27 -- internal/fs/filereader.go | 9 - internal/fs/lineread.go | 28 -- internal/fs/permissions/permission.go | 14 - internal/fs/permissions/permission_linux.c | 395 ------------------- internal/fs/permissions/permission_linux.go | 33 -- internal/fs/permissions/permission_linux.h | 60 --- internal/fs/permissions/permission_test.go | 112 ------ internal/fs/readfile.go | 318 --------------- internal/fs/stats.go | 69 ---- internal/fs/tailfile.go | 27 -- internal/io/fs/catfile.go | 21 + internal/io/fs/filereader.go | 14 + internal/io/fs/permissions/permission.go | 14 + internal/io/fs/permissions/permission_linux.c | 395 +++++++++++++++++++ internal/io/fs/permissions/permission_linux.go | 33 ++ internal/io/fs/permissions/permission_linux.h | 60 +++ internal/io/fs/permissions/permission_test.go | 112 ++++++ internal/io/fs/readfile.go | 307 +++++++++++++++ internal/io/fs/stats.go | 69 ++++ internal/io/fs/tailfile.go | 21 + internal/io/line/line.go | 28 ++ internal/io/logger/logger.go | 445 +++++++++++++++++++++ internal/io/run/run.go | 104 +++++ internal/logger/logger.go | 457 ---------------------- internal/mapr/aggregateset.go | 5 +- internal/mapr/client/aggregate.go | 25 +- internal/mapr/groupset.go | 5 +- internal/mapr/logformat/parser.go | 2 +- internal/mapr/query.go | 2 +- internal/mapr/server/aggregate.go | 141 ++++--- internal/mapr/wherecondition.go | 2 +- internal/omode/mode.go | 6 +- internal/pprof/pprof.go | 3 +- internal/prompt/prompt.go | 2 +- internal/server/handlers/controlhandler.go | 42 +- internal/server/handlers/handler.go | 2 - internal/server/handlers/mapcommand.go | 35 ++ internal/server/handlers/readcommand.go | 158 ++++++++ internal/server/handlers/runcommand.go | 73 ++++ internal/server/handlers/serverhandler.go | 521 +++++++++---------------- internal/server/server.go | 70 ++-- internal/server/stats.go | 10 +- internal/ssh/client/authmethods.go | 2 +- internal/ssh/client/hostkeycallback.go | 10 +- internal/ssh/server/hostkey.go | 2 +- internal/ssh/server/publickeycallback.go | 2 +- internal/ssh/ssh.go | 2 +- internal/user/name.go | 15 +- internal/user/server/user.go | 44 ++- internal/version/version.go | 22 +- samples/dtail.json.sample | 12 +- 88 files changed, 2791 insertions(+), 2601 deletions(-) delete mode 100644 cmd/dexec/main.go create mode 100644 cmd/drun/main.go delete mode 100644 internal/clients/connectionmaker.go delete mode 100644 internal/clients/execclient.go create mode 100644 internal/clients/handlers/withcancel.go create mode 100644 internal/clients/maker.go create mode 100644 internal/clients/runclient.go delete mode 100644 internal/fs/catfile.go delete mode 100644 internal/fs/filereader.go delete mode 100644 internal/fs/lineread.go delete mode 100644 internal/fs/permissions/permission.go delete mode 100644 internal/fs/permissions/permission_linux.c delete mode 100644 internal/fs/permissions/permission_linux.go delete mode 100644 internal/fs/permissions/permission_linux.h delete mode 100644 internal/fs/permissions/permission_test.go delete mode 100644 internal/fs/readfile.go delete mode 100644 internal/fs/stats.go delete mode 100644 internal/fs/tailfile.go create mode 100644 internal/io/fs/catfile.go create mode 100644 internal/io/fs/filereader.go create mode 100644 internal/io/fs/permissions/permission.go create mode 100644 internal/io/fs/permissions/permission_linux.c create mode 100644 internal/io/fs/permissions/permission_linux.go create mode 100644 internal/io/fs/permissions/permission_linux.h create mode 100644 internal/io/fs/permissions/permission_test.go create mode 100644 internal/io/fs/readfile.go create mode 100644 internal/io/fs/stats.go create mode 100644 internal/io/fs/tailfile.go create mode 100644 internal/io/line/line.go create mode 100644 internal/io/logger/logger.go create mode 100644 internal/io/run/run.go delete mode 100644 internal/logger/logger.go create mode 100644 internal/server/handlers/mapcommand.go create mode 100644 internal/server/handlers/readcommand.go create mode 100644 internal/server/handlers/runcommand.go diff --git a/Makefile b/Makefile index 3480637..c358d8e 100644 --- a/Makefile +++ b/Makefile @@ -1,29 +1,31 @@ GO ?= go all: build build: + ${GO} build -o dserver ./cmd/dserver/main.go ${GO} build -o dcat ./cmd/dcat/main.go - ${GO} build -o dexec ./cmd/dexec/main.go ${GO} build -o dgrep ./cmd/dgrep/main.go ${GO} build -o dmap ./cmd/dmap/main.go - ${GO} build -o dserver ./cmd/dserver/main.go + ${GO} build -o drun ./cmd/drun/main.go ${GO} build -o dtail ./cmd/dtail/main.go clean: - rm -v dtail dgrep dcat dmap dserver dexec 2>/dev/null + ls ./cmd/ | while read cmd; do \ + test -f $$cmd && rm $$cmd; \ + done install: build + cp -pv dserver ${GOPATH}/bin/dserver cp -pv dcat ${GOPATH}/bin/dcat - cp -pv dexec ${GOPATH}/bin/dexec cp -pv dgrep ${GOPATH}/bin/dgrep cp -pv dmap ${GOPATH}/bin/dmap - cp -pv dserver ${GOPATH}/bin/dserver + cp -pv drun ${GOPATH}/bin/drun cp -pv dtail ${GOPATH}/bin/dtail vet: find . -type d | while read dir; do \ echo ${GO} vet $$dir; \ ${GO} vet $$dir; \ - done + done lint: ${GO} get golang.org/x/lint/golint find . -type d | while read dir; do \ echo ${GOPATH}/bin/golint $$dir; \ ${GOPATH}/bin/golint $$dir; \ - done + done diff --git a/cmd/dcat/main.go b/cmd/dcat/main.go index b02d369..1ec945d 100644 --- a/cmd/dcat/main.go +++ b/cmd/dcat/main.go @@ -1,12 +1,14 @@ package main import ( + "context" "flag" + "os" "github.com/mimecast/dtail/internal/clients" "github.com/mimecast/dtail/internal/color" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/pprof" "github.com/mimecast/dtail/internal/user" "github.com/mimecast/dtail/internal/version" @@ -27,7 +29,6 @@ func main() { var sshPort int var trustAllHosts bool - pingTimeoutS := 60 userName := user.Name() flag.BoolVar(&debugEnable, "debug", false, "Activate debug messages") @@ -37,7 +38,6 @@ func main() { flag.BoolVar(&silentEnable, "silent", false, "Reduce output") flag.BoolVar(&trustAllHosts, "trustAllHosts", false, "Auto trust all unknown host keys") flag.IntVar(&connectionsPerCPU, "cpc", 10, "How many connections established per CPU core concurrently") - flag.IntVar(&pingTimeoutS, "pingTimeout", 10, "The server ping timeout (0 means disable pings)") flag.IntVar(&sshPort, "port", 2222, "SSH server port") flag.StringVar(&cfgFile, "cfg", "", "Config file path") flag.StringVar(&discovery, "discovery", "", "Server discovery method") @@ -54,9 +54,10 @@ func main() { version.PrintAndExit() } + ctx := context.Background() serverEnable := false - logger.Start(serverEnable, debugEnable, silentEnable, silentEnable) - defer logger.Stop() + + logger.Start(ctx, serverEnable, debugEnable, silentEnable, silentEnable) if pprofEnable || config.Common.PProfEnable { pprof.Start() @@ -67,14 +68,16 @@ func main() { ServersStr: serversStr, Discovery: discovery, UserName: userName, - Files: files, + What: files, TrustAllHosts: trustAllHosts, - PingTimeout: pingTimeoutS, } client, err := clients.NewCatClient(args) if err != nil { panic(err) } - client.Start() + + status := client.Start(ctx) + logger.Flush() + os.Exit(status) } diff --git a/cmd/dexec/main.go b/cmd/dexec/main.go deleted file mode 100644 index 7a7ab1f..0000000 --- a/cmd/dexec/main.go +++ /dev/null @@ -1,80 +0,0 @@ -package main - -import ( - "flag" - - "github.com/mimecast/dtail/internal/clients" - "github.com/mimecast/dtail/internal/color" - "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" - "github.com/mimecast/dtail/internal/pprof" - "github.com/mimecast/dtail/internal/user" - "github.com/mimecast/dtail/internal/version" -) - -// The evil begins here. -func main() { - var cfgFile string - var connectionsPerCPU int - var debugEnable bool - var discovery string - var displayVersion bool - var command string - var noColor bool - var pprofEnable bool - var serversStr string - var silentEnable bool - var sshPort int - var trustAllHosts bool - - pingTimeoutS := 60 - userName := user.Name() - - flag.BoolVar(&debugEnable, "debug", false, "Activate debug messages") - flag.BoolVar(&displayVersion, "version", false, "Display version") - flag.BoolVar(&noColor, "noColor", false, "Disable ANSII terminal colors") - flag.BoolVar(&pprofEnable, "pprofEnable", false, "Enable pprof server") - flag.BoolVar(&silentEnable, "silent", false, "Reduce output") - flag.BoolVar(&trustAllHosts, "trustAllHosts", false, "Auto trust all unknown host keys") - flag.IntVar(&connectionsPerCPU, "cpc", 10, "How many connections established per CPU core concurrently") - flag.IntVar(&pingTimeoutS, "pingTimeout", 10, "The server ping timeout (0 means disable pings)") - flag.IntVar(&sshPort, "port", 2222, "SSH server port") - flag.StringVar(&cfgFile, "cfg", "", "Config file path") - flag.StringVar(&discovery, "discovery", "", "Server discovery method") - flag.StringVar(&command, "command", "", "Command to run") - flag.StringVar(&serversStr, "servers", "", "Remote servers to connect") - flag.StringVar(&userName, "user", userName, "Your system user name") - - flag.Parse() - - config.Read(cfgFile, sshPort) - color.Colored = !noColor - - if displayVersion { - version.PrintAndExit() - } - - serverEnable := false - logger.Start(serverEnable, debugEnable, silentEnable, silentEnable) - defer logger.Stop() - - if pprofEnable || config.Common.PProfEnable { - pprof.Start() - } - - args := clients.Args{ - ConnectionsPerCPU: connectionsPerCPU, - ServersStr: serversStr, - Discovery: discovery, - UserName: userName, - Files: files, - TrustAllHosts: trustAllHosts, - PingTimeout: pingTimeoutS, - } - - client, err := clients.NewExecClient(args) - if err != nil { - panic(err) - } - client.Start() -} diff --git a/cmd/dgrep/main.go b/cmd/dgrep/main.go index d1a7d52..74a501f 100644 --- a/cmd/dgrep/main.go +++ b/cmd/dgrep/main.go @@ -1,12 +1,14 @@ package main import ( + "context" "flag" + "os" "github.com/mimecast/dtail/internal/clients" "github.com/mimecast/dtail/internal/color" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/pprof" "github.com/mimecast/dtail/internal/user" "github.com/mimecast/dtail/internal/version" @@ -28,7 +30,6 @@ func main() { var sshPort int var trustAllHosts bool - pingTimeoutS := 60 userName := user.Name() flag.BoolVar(&debugEnable, "debug", false, "Activate debug messages") @@ -38,7 +39,6 @@ func main() { flag.BoolVar(&silentEnable, "silent", false, "Reduce output") flag.BoolVar(&trustAllHosts, "trustAllHosts", false, "Auto trust all unknown host keys") flag.IntVar(&connectionsPerCPU, "cpc", 10, "How many connections established per CPU core concurrently") - flag.IntVar(&pingTimeoutS, "pingTimeout", 10, "The server ping timeout (0 means disable pings)") flag.IntVar(&sshPort, "port", 2222, "SSH server port") flag.StringVar(&cfgFile, "cfg", "", "Config file path") flag.StringVar(&discovery, "discovery", "", "Server discovery method") @@ -56,9 +56,9 @@ func main() { version.PrintAndExit() } + ctx := context.Background() serverEnable := false - logger.Start(serverEnable, debugEnable, silentEnable, silentEnable) - defer logger.Stop() + logger.Start(ctx, serverEnable, debugEnable, silentEnable, silentEnable) if pprofEnable || config.Common.PProfEnable { pprof.Start() @@ -69,9 +69,8 @@ func main() { ServersStr: serversStr, Discovery: discovery, UserName: userName, - Files: files, + What: files, TrustAllHosts: trustAllHosts, - PingTimeout: pingTimeoutS, Regex: regex, } @@ -79,5 +78,8 @@ func main() { if err != nil { panic(err) } - client.Start() + + status := client.Start(ctx) + logger.Flush() + os.Exit(status) } diff --git a/cmd/dmap/main.go b/cmd/dmap/main.go index 83dad50..f3f706a 100644 --- a/cmd/dmap/main.go +++ b/cmd/dmap/main.go @@ -1,13 +1,15 @@ package main import ( + "context" "flag" + "os" - "github.com/mimecast/dtail/internal/omode" "github.com/mimecast/dtail/internal/clients" "github.com/mimecast/dtail/internal/color" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/omode" "github.com/mimecast/dtail/internal/pprof" "github.com/mimecast/dtail/internal/user" "github.com/mimecast/dtail/internal/version" @@ -29,7 +31,6 @@ func main() { var sshPort int var trustAllHosts bool - pingTimeoutS := 900 userName := user.Name() flag.BoolVar(&debugEnable, "debug", false, "Activate debug messages") @@ -39,7 +40,6 @@ func main() { flag.BoolVar(&silentEnable, "silent", false, "Reduce output") flag.BoolVar(&trustAllHosts, "trustAllHosts", false, "Auto trust all unknown host keys") flag.IntVar(&connectionsPerCPU, "cpc", 10, "How many connections established per CPU core concurrently") - flag.IntVar(&pingTimeoutS, "pingTimeout", 10, "The server ping timeout (0 means disable pings)") flag.IntVar(&sshPort, "port", 2222, "SSH server port") flag.StringVar(&cfgFile, "cfg", "", "Config file path") flag.StringVar(&discovery, "discovery", "", "Server discovery method") @@ -57,10 +57,10 @@ func main() { version.PrintAndExit() } + ctx := context.Background() serverEnable := false - logger.Start(serverEnable, debugEnable, silentEnable, silentEnable) - defer logger.Stop() + logger.Start(ctx, serverEnable, debugEnable, silentEnable, silentEnable) if pprofEnable || config.Common.PProfEnable { pprof.Start() } @@ -70,9 +70,8 @@ func main() { ServersStr: serversStr, Discovery: discovery, UserName: userName, - Files: files, + What: files, TrustAllHosts: trustAllHosts, - PingTimeout: pingTimeoutS, Mode: omode.MapClient, } @@ -81,5 +80,7 @@ func main() { panic(err) } - client.Start() + status := client.Start(ctx) + logger.Flush() + os.Exit(status) } diff --git a/cmd/drun/main.go b/cmd/drun/main.go new file mode 100644 index 0000000..b1936d4 --- /dev/null +++ b/cmd/drun/main.go @@ -0,0 +1,82 @@ +package main + +import ( + "context" + "flag" + "os" + + "github.com/mimecast/dtail/internal/clients" + "github.com/mimecast/dtail/internal/color" + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/pprof" + "github.com/mimecast/dtail/internal/user" + "github.com/mimecast/dtail/internal/version" +) + +// The evil begins here. +func main() { + var cfgFile string + var command string + var connectionsPerCPU int + var debugEnable bool + var discovery string + var displayVersion bool + var noColor bool + var pprofEnable bool + var serversStr string + var silentEnable bool + var sshPort int + var trustAllHosts bool + + userName := user.Name() + + flag.BoolVar(&debugEnable, "debug", false, "Activate debug messages") + flag.BoolVar(&displayVersion, "version", false, "Display version") + flag.BoolVar(&noColor, "noColor", false, "Disable ANSII terminal colors") + flag.BoolVar(&pprofEnable, "pprofEnable", false, "Enable pprof server") + flag.BoolVar(&silentEnable, "silent", false, "Reduce output") + flag.BoolVar(&trustAllHosts, "trustAllHosts", false, "Auto trust all unknown host keys") + flag.IntVar(&connectionsPerCPU, "cpc", 10, "How many connections established per CPU core concurrently") + flag.IntVar(&sshPort, "port", 2222, "SSH server port") + flag.StringVar(&cfgFile, "cfg", "", "Config file path") + flag.StringVar(&command, "command", "", "Command to run") + flag.StringVar(&discovery, "discovery", "", "Server discovery method") + flag.StringVar(&serversStr, "servers", "", "Remote servers to connect") + flag.StringVar(&userName, "user", userName, "Your system user name") + + flag.Parse() + + config.Read(cfgFile, sshPort) + color.Colored = !noColor + + if displayVersion { + version.PrintAndExit() + } + + ctx := context.Background() + serverEnable := false + + logger.Start(ctx, serverEnable, debugEnable, silentEnable, silentEnable) + if pprofEnable || config.Common.PProfEnable { + pprof.Start() + } + + args := clients.Args{ + ConnectionsPerCPU: connectionsPerCPU, + ServersStr: serversStr, + Discovery: discovery, + UserName: userName, + What: command, + TrustAllHosts: trustAllHosts, + } + + client, err := clients.NewRunClient(args) + if err != nil { + panic(err) + } + + status := client.Start(ctx) + logger.Flush() + os.Exit(status) +} diff --git a/cmd/dserver/main.go b/cmd/dserver/main.go index 489910b..aa209a8 100644 --- a/cmd/dserver/main.go +++ b/cmd/dserver/main.go @@ -1,13 +1,14 @@ package main import ( + "context" "flag" "os" "time" "github.com/mimecast/dtail/internal/color" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/pprof" "github.com/mimecast/dtail/internal/server" "github.com/mimecast/dtail/internal/user" @@ -24,7 +25,7 @@ func main() { var shutdownAfter int var sshPort int - userName := user.Name() + user.NoRootCheck() flag.BoolVar(&debugEnable, "debug", false, "Activate debug messages") flag.BoolVar(&displayVersion, "version", false, "Display version") @@ -43,19 +44,23 @@ func main() { version.PrintAndExit() } + ctx := context.Background() + serverEnable := true silentEnable := false nothingEnable := false - logger.Start(serverEnable, debugEnable, silentEnable, nothingEnable) - defer logger.Stop() + logger.Start(ctx, serverEnable, debugEnable, silentEnable, nothingEnable) if shutdownAfter > 0 { go func() { defer os.Exit(1) logger.Info("Enabling auto shutdown timer", shutdownAfter) - time.Sleep(time.Duration(shutdownAfter) * time.Second) - logger.Info("Auto shutdown timer reached, shutting down now") + select { + case <-time.After(time.Duration(shutdownAfter) * time.Second): + logger.Info("Auto shutdown timer reached, shutting down now") + case <-ctx.Done(): + } }() } @@ -63,7 +68,8 @@ func main() { pprof.Start() } - logger.Info("Launching server", version.String(), userName) sshServer := server.New() - sshServer.Start() + status := sshServer.Start(ctx) + logger.Flush() + os.Exit(status) } diff --git a/cmd/dtail/main.go b/cmd/dtail/main.go index 1bf77c7..76070ff 100644 --- a/cmd/dtail/main.go +++ b/cmd/dtail/main.go @@ -1,13 +1,14 @@ package main import ( + "context" "flag" "os" "github.com/mimecast/dtail/internal/clients" "github.com/mimecast/dtail/internal/color" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/omode" "github.com/mimecast/dtail/internal/pprof" "github.com/mimecast/dtail/internal/user" @@ -32,7 +33,6 @@ func main() { var sshPort int var trustAllHosts bool - pingTimeoutS := 5 userName := user.Name() flag.BoolVar(&checkHealth, "checkHealth", false, "Only check for server health") @@ -43,7 +43,6 @@ func main() { flag.BoolVar(&silentEnable, "silent", false, "Reduce output") flag.BoolVar(&trustAllHosts, "trustAllHosts", false, "Auto trust all unknown host keys") flag.IntVar(&connectionsPerCPU, "cpc", 10, "How many connections established per CPU core concurrently") - flag.IntVar(&pingTimeoutS, "pingTimeout", 10, "The server ping timeout (0 means disable pings)") flag.IntVar(&sshPort, "port", 2222, "SSH server port") flag.StringVar(&cfgFile, "cfg", "", "Config file path") flag.StringVar(&discovery, "discovery", "", "Server discovery method") @@ -62,17 +61,18 @@ func main() { version.PrintAndExit() } + ctx := context.Background() + if checkHealth { healthClient, _ := clients.NewHealthClient(omode.HealthClient) - os.Exit(healthClient.Start()) + os.Exit(healthClient.Start(ctx)) } serverEnable := false if checkHealth { silentEnable = true } - logger.Start(serverEnable, debugEnable, silentEnable, silentEnable) - defer logger.Stop() + logger.Start(ctx, serverEnable, debugEnable, silentEnable, silentEnable) if pprofEnable || config.Common.PProfEnable { pprof.Start() @@ -83,9 +83,8 @@ func main() { ServersStr: serversStr, Discovery: discovery, UserName: userName, - Files: files, + What: files, TrustAllHosts: trustAllHosts, - PingTimeout: pingTimeoutS, Regex: regex, Mode: omode.TailClient, } @@ -104,5 +103,7 @@ func main() { } } - client.Start() + status := client.Start(ctx) + logger.Flush() + os.Exit(status) } diff --git a/doc/examples.md b/doc/examples.md index 959105c..964660a 100644 --- a/doc/examples.md +++ b/doc/examples.md @@ -25,7 +25,7 @@ To run ad-hoc mapreduce aggregations on newly written log lines you also must ad --files '/var/log/service/*.log' ``` -In order for mapreduce queries to work you have to make sure that your log format is supported by DTail. You can either use the ones which are already defined in ``mapr/logformat`` or add an extension to support a custom log format. +In order for mapreduce queries to work you have to make sure that your log format is supported by DTail. You can either use the ones which are already defined in ``internal/mapr/logformat`` or add an extension to support a custom log format. ![dtail-map](dtail-map.gif "Tail mapreduce example") @@ -62,6 +62,6 @@ To run a mapreduce aggregation over logs written in the past the ``dmap`` comman --files "/var/log/service/*.log" ``` -Remember: In order for that to work you have to make sure that your log format is supported by DTail. You can either use the ones which are already defined in ``mapr/logformat`` or add an extension to support a custom log format. +Remember: In order for that to work you have to make sure that your log format is supported by DTail. You can either use the ones which are already defined in ``internal/mapr/logformat`` or add an extension to support a custom log format. ![dmap](dmap.gif "DMap example") diff --git a/doc/installation.md b/doc/installation.md index 305eae5..a15beb1 100644 --- a/doc/installation.md +++ b/doc/installation.md @@ -3,8 +3,6 @@ DTail Installation Guide The following installation guide has been tested successfully on CentOS 7. You may need to adjust accordingly depending on the distribution you use. -This guide also assumes that you know how to use ``systemd`` and how to configure a service there. If you are unsure please consult the documentation of your distribution. - # Compile it Please check the [Quick Starting Guide](quickstart.md) for instructions how to compile DTail. It is recommended to automate the build process via your build pipeline (e.g. produce a deployable RPM via Jenkins). You don't have to use ``go get...`` to compile and install the binaries. You can also clone the repository and use ``make`` instead. @@ -12,6 +10,7 @@ Please check the [Quick Starting Guide](quickstart.md) for instructions how to c # Install it It is recommended to automate all the installation process outlined here. You could use a configuration management system such as Puppet, Chef or Ansible. However, that relies heavily on how your infrastructure is managed and is out of scope of this documentation. + 1. The ``dserver`` binary has to be installed on all machines (server boxes) involved. A good location for the binary would be ``/usr/local/bin/dserver`` with permissions set as follows: ```console @@ -72,7 +71,7 @@ Now you should be able to use DTail client like outlined in the [Quick Starting # Monitor it -To verify that DTail server is up and running and functioning as expected you should configure the Nagios check [check_dserver.sh](../samples/check_dserver.sh.sample) in your monitoring system. The check has to be executed locally on the server (e.g. via NRPE). How to configure the monitoring system in detail is out of scope of this guide, as it depends on the monitoring infrastructure used. +To verify that DTail server is up and running and functioning as expected you should configure the Nagios check [check_dserver.sh](../samples/check_dserver.sh.sample) in your monitoring system. The check has to be executed locally on the server (e.g. via NRPE). How to configure the monitoring system in detail is out of scope of this guide. ```console % ./check_dserver.sh diff --git a/doc/quickstart.md b/doc/quickstart.md index 46f7fae..7b6fbf4 100644 --- a/doc/quickstart.md +++ b/doc/quickstart.md @@ -3,13 +3,11 @@ Quick Starting Guide This is the quick starting guide. For a more sustainable setup, involving how to create a background service via ``systemd``, recommendations about automation via Jenkins and/or Puppet and health monitoring via Nagios please also follow the [Installation Guide](installation.md). -This guide assumes that you know how to generate and configure a public/private SSH key pair for secure authorization and shell access. That is out of scope of this guide. For more information please have a look at the OpenSSH documentation of your distribution. +This guide assumes that you know how to generate and configure a public/private SSH key pair for secure authorization and shell access. For more information please have a look at the OpenSSH documentation of your distribution. -This guide also assumes that you know how to install and use a Go compiler and GNU make. +# Install it -# Compile it - -To install all DTail binaries from github run: +To compile and install all DTail binaries directly from GitHub run: ```console % go get github.com/mimecast/dtail/cmd/dtail @@ -48,7 +46,9 @@ SERVER|serv-001|INFO|Binding server|0.0.0.0:2222 Make sure that your public SSH key is listed in ``~/.ssh/authorized_keys`` on all server machines involved. The private SSH key counterpart should preferably stay on your Laptop or workstation in ``~/.ssh/id_rsa`` or ``~/.ssh/id_dsa``. -DTail utilises the SSH Agent for SSH authentication. This is to avoid entering the passphrase of the private SSH key over and over again when a new SSH session is initiated from the DTail client to a new DTail server. For this the private SSH key has to be registered at the SSH Agent: +DTail relies on SSH for secure authentication and communication. The clients (all client binaries such as ``dtail``, ``dgrep`` and so on...) communicate with an auth backend via the SSH auth socket. The SSH auth socket is configured via the environment variable ``SSH_AUTH_SOCK`` which usually points to ``~/.ssh/ssh_auth_socket`` or similar (depending on your configuration it may also point to other auth backends such as GPG Agent, in which case ``SSH_AUTH_SOCK`` would point to ``~/.gnupg/S.gpg-agent.ssh`` or similar). + +Usually you would use the SSH Auth Agent. For this the private SSH key has to be registered at the SSH Agent: ```console % ssh-add ~/.ssh/id_rsa @@ -56,16 +56,17 @@ Enter passphrase for ~/.ssh/id_rsa: ********** Identity added: ~/.ssh/id_rsa (~/.ssh/id_rsa) ``` -The DTail client communicates with the SSH Agent through ``~/.ssh/ssh_auth_socket`` whenever a new session to a DTail server is established. - To test whether SSH is setup correctly you should be able to SSH into the servers with the OpenSSH client and your private SSH key through the SSH Agent without entering the private keys passphrase. The following assumes to have an OpenSSH server running on ``serv-001.lan.example.org`` and an OpenSSH client installed on your laptop or workstation. Please notice that DTail does not require to have an OpenSSH infrastructure set up but DTail uses by default the same public/private key file paths as OpenSSH. OpenSSH can be of a great help to verify that the SSH keys are configured correctly: ```console -% ssh serv-001.lan.example.org -% -% exit +workstation01 ~ % ssh serv-001.lan.example.org +serv-001 ~ % +serv-001 ~ % exit +workstation01 ~ % ``` +Please consult the OpenSSH documentation of your distribution if the test above does not work for you. + ## Run DTail client Now it is time to connect to the DTail servers through the DTail client: diff --git a/go.mod b/go.mod index bba791a..9a50e44 100644 --- a/go.mod +++ b/go.mod @@ -5,4 +5,5 @@ go 1.13 require ( github.com/DataDog/zstd v1.4.4 golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876 + golang.org/x/lint v0.0.0-20200130185559-910be7a94367 // indirect ) diff --git a/go.sum b/go.sum index cbb61a3..39af4c9 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,19 @@ github.com/DataDog/zstd v1.4.4 h1:+IawcoXhCBylN7ccwdwf8LOH2jKq7NavGpEPanrlTzE= github.com/DataDog/zstd v1.4.4/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876 h1:sKJQZMuxjOAR/Uo2LBfU90onWEf1dF4C+0hPJCc9Mpc= golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/lint v0.0.0-20200130185559-910be7a94367 h1:0IiAsCRByjO2QjX7ZPkw5oU9x+n1YqRL802rjC0c3Aw= +golang.org/x/lint v0.0.0-20200130185559-910be7a94367/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7 h1:EBZoQjiKKPaLbPrbpssUfuHtwM6KV/vb4U85g/cigFY= +golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/clients/args.go b/internal/clients/args.go index 5fe0a72..dea5a9e 100644 --- a/internal/clients/args.go +++ b/internal/clients/args.go @@ -9,10 +9,9 @@ type Args struct { Mode omode.Mode ServersStr string UserName string - Files string + What string Regex string TrustAllHosts bool Discovery string ConnectionsPerCPU int - PingTimeout int } diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go index 574ae94..b1540ea 100644 --- a/internal/clients/baseclient.go +++ b/internal/clients/baseclient.go @@ -1,13 +1,14 @@ package clients import ( + "context" "regexp" "sync" "time" "github.com/mimecast/dtail/internal/clients/remote" "github.com/mimecast/dtail/internal/discovery" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/omode" "github.com/mimecast/dtail/internal/ssh/client" @@ -27,111 +28,110 @@ type baseClient struct { sshAuthMethods []gossh.AuthMethod // To deal with SSH host keys hostKeyCallback *client.HostKeyCallback - // To stop the client. - stop chan struct{} - // To indicate that the client has stopped. - stopped chan struct{} // Throttle how fast we initiate SSH connections concurrently throttleCh chan struct{} // Retry connection upon failure? retry bool - // Connection helper. - maker connectionMaker + // Connection maker helper. + maker maker } -func (c *baseClient) init(maker connectionMaker) { +func (c *baseClient) init(maker maker) { logger.Info("Initiating base client") c.maker = maker - //c.connections = make(map[string]*remote.Connection) c.sshAuthMethods, c.hostKeyCallback = client.InitSSHAuthMethods(c.TrustAllHosts, c.throttleCh) + discoveryService := discovery.New(c.Discovery, c.ServersStr, discovery.Shuffle) - // Retrieve a shuffled list of remote dtail servers. - shuffleServers := true - discoveryService := discovery.New(c.Discovery, c.ServersStr, shuffleServers) for _, server := range discoveryService.ServerList() { - c.connections = append(c.connections, c.maker.makeConnection(server, c.sshAuthMethods, c.hostKeyCallback)) + c.connections = append(c.connections, c.makeConnection(server, c.sshAuthMethods, c.hostKeyCallback)) } if _, err := regexp.Compile(c.Regex); err != nil { logger.FatalExit(c.Regex, "Can't test compile regex", err) } - // Periodically check for unknown hosts, and ask the user whether to trust them or not. - go c.hostKeyCallback.PromptAddHosts(c.stop) - - // Periodically print out connection stats to the client. c.stats = newTailStats(len(c.connections)) - go c.stats.periodicLogStats(c.throttleCh, c.stop) } -func (c *baseClient) Start() (status int) { +func (c *baseClient) Start(ctx context.Context) (status int) { + // Periodically check for unknown hosts, and ask the user whether to trust them or not. + go c.hostKeyCallback.PromptAddHosts(ctx) + // Periodically print out connection stats to the client. + go c.stats.periodicLogStats(ctx, c.throttleCh) + // Keep count of active connections active := make(chan struct{}, len(c.connections)) - var wg sync.WaitGroup - wg.Add(len(c.connections)) - + var mutex sync.Mutex for i, conn := range c.connections { go func(i int, conn *remote.Connection) { - active <- struct{}{} - defer func() { - logger.Debug(conn.Server, "Disconnected completely...") - <-active - }() - wg.Done() - - for { - conn.Start(c.throttleCh, c.stats.connectionsEstCh) - if !c.retry { - return - } - time.Sleep(time.Second * 2) - logger.Debug(conn.Server, "Reconencting") - conn = c.maker.makeConnection(conn.Server, c.sshAuthMethods, c.hostKeyCallback) - c.connections[i] = conn + connStatus := c.start(ctx, active, i, conn) + + // Update global status. + mutex.Lock() + defer mutex.Unlock() + if connStatus > status { + status = connStatus } }(i, conn) } - wg.Wait() - c.waitUntilDone(active) - + c.waitUntilDone(ctx, active) return } -func (c *baseClient) waitUntilDone(active chan struct{}) { - defer close(c.stopped) +func (c *baseClient) start(ctx context.Context, active chan struct{}, i int, conn *remote.Connection) (status int) { + // Increment connection count + active <- struct{}{} + // Derement connection count + defer func() { <-active }() - if c.Mode != omode.TailClient { - c.waitUntilZero(active) - logger.Info("All connections stopped") - return - } + for { + connCtx, cancel := conn.Handler.WithCancel(ctx) + defer cancel() - <-c.stop - logger.Info("Stopping client") - for _, conn := range c.connections { - conn.Stop() + conn.Start(connCtx, cancel, c.throttleCh, c.stats.connectionsEstCh) + // Retrieve status code from handler (dtail client will exit with that status) + status = conn.Handler.Status() + + if !c.retry { + return + } + + time.Sleep(time.Second * 2) + logger.Debug(conn.Server, "Reconnecting") + + conn = c.makeConnection(conn.Server, c.sshAuthMethods, c.hostKeyCallback) + c.connections[i] = conn } +} - c.waitUntilZero(active) +func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection { + conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) + conn.Handler = c.maker.makeHandler(server) + conn.Commands = c.maker.makeCommands() + + return conn } -func (c *baseClient) waitUntilZero(active chan struct{}) { +func (c *baseClient) waitUntilDone(ctx context.Context, active chan struct{}) { + defer logger.Info("Terminated connection") + + // We want to have at least one active connection + <-active + // Put it back on the channel + active <- struct{}{} + + if c.Mode == omode.TailClient { + <-ctx.Done() + } + for { - logger.Debug("Active connections", len(active)) - if len(active) == 0 { + numActive := len(active) + if numActive == 0 { return } + logger.Debug("Active connections", numActive) time.Sleep(time.Second) } } - -func (c *baseClient) Stop() { - close(c.stop) - <-c.WaitC() -} - -func (c *baseClient) WaitC() <-chan struct{} { - return c.stopped -} diff --git a/internal/clients/catclient.go b/internal/clients/catclient.go index 5ea701d..7fd6bdc 100644 --- a/internal/clients/catclient.go +++ b/internal/clients/catclient.go @@ -7,11 +7,7 @@ import ( "strings" "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/clients/remote" "github.com/mimecast/dtail/internal/omode" - "github.com/mimecast/dtail/internal/ssh/client" - - gossh "golang.org/x/crypto/ssh" ) // CatClient is a client for returning a whole file from the beginning to the end. @@ -31,8 +27,6 @@ func NewCatClient(args Args) (*CatClient, error) { c := CatClient{ baseClient: baseClient{ Args: args, - stop: make(chan struct{}), - stopped: make(chan struct{}), throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), retry: false, }, @@ -43,11 +37,13 @@ func NewCatClient(args Args) (*CatClient, error) { return &c, nil } -func (c CatClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection { - conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) - conn.Handler = handlers.NewClientHandler(server, c.PingTimeout) - for _, file := range strings.Split(c.Files, ",") { - conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex)) +func (c CatClient) makeHandler(server string) handlers.Handler { + return handlers.NewClientHandler(server) +} + +func (c CatClient) makeCommands() (commands []string) { + for _, file := range strings.Split(c.What, ",") { + commands = append(commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex)) } - return conn + return } diff --git a/internal/clients/client.go b/internal/clients/client.go index 85d1aae..1fc5e23 100644 --- a/internal/clients/client.go +++ b/internal/clients/client.go @@ -1,7 +1,8 @@ package clients +import "context" + // Client is the interface for the end user command line client. type Client interface { - Start() int - Stop() + Start(ctx context.Context) int } diff --git a/internal/clients/connectionmaker.go b/internal/clients/connectionmaker.go deleted file mode 100644 index 0617992..0000000 --- a/internal/clients/connectionmaker.go +++ /dev/null @@ -1,12 +0,0 @@ -package clients - -import ( - "github.com/mimecast/dtail/internal/clients/remote" - "github.com/mimecast/dtail/internal/ssh/client" - - gossh "golang.org/x/crypto/ssh" -) - -type connectionMaker interface { - makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection -} diff --git a/internal/clients/execclient.go b/internal/clients/execclient.go deleted file mode 100644 index 10bd081..0000000 --- a/internal/clients/execclient.go +++ /dev/null @@ -1,48 +0,0 @@ -package clients - -import ( - "fmt" - "runtime" - "strings" - - "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/clients/remote" - "github.com/mimecast/dtail/internal/omode" - "github.com/mimecast/dtail/internal/ssh/client" - - gossh "golang.org/x/crypto/ssh" -) - -// ExecClient is a client for execute various commands on the server. -type ExecClient struct { - baseClient -} - -// NewExecClient returns a new cat client. -func NewExecClient(args Args) (*ExecClient, error) { - args.Regex = "." - args.Mode = omode.ExecClient - - c := ExecClient{ - baseClient: baseClient{ - Args: args, - stop: make(chan struct{}), - stopped: make(chan struct{}), - throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), - retry: false, - }, - } - - c.init(c) - - return &c, nil -} - -func (c ExecClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection { - conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) - conn.Handler = handlers.NewClientHandler(server, c.PingTimeout) - for _, file := range strings.Split(c.Files, ";") { - conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s", c.Mode.String(), file)) - } - return conn -} diff --git a/internal/clients/grepclient.go b/internal/clients/grepclient.go index c568f63..8d11458 100644 --- a/internal/clients/grepclient.go +++ b/internal/clients/grepclient.go @@ -7,11 +7,7 @@ import ( "strings" "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/clients/remote" "github.com/mimecast/dtail/internal/omode" - "github.com/mimecast/dtail/internal/ssh/client" - - gossh "golang.org/x/crypto/ssh" ) // GrepClient searches a remote file for all lines matching a regular expression. Only the matching lines are displayed. @@ -29,8 +25,6 @@ func NewGrepClient(args Args) (*GrepClient, error) { c := GrepClient{ baseClient: baseClient{ Args: args, - stop: make(chan struct{}), - stopped: make(chan struct{}), throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), retry: false, }, @@ -41,13 +35,13 @@ func NewGrepClient(args Args) (*GrepClient, error) { return &c, nil } -func (c GrepClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection { - conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) - conn.Handler = handlers.NewClientHandler(server, c.PingTimeout) +func (c GrepClient) makeHandler(server string) handlers.Handler { + return handlers.NewClientHandler(server) +} - for _, file := range strings.Split(c.Files, ",") { - conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex)) +func (c GrepClient) makeCommands() (commands []string) { + for _, file := range strings.Split(c.What, ",") { + commands = append(commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex)) } - - return conn + return } diff --git a/internal/clients/handlers/basehandler.go b/internal/clients/handlers/basehandler.go index 19246f9..68b8ddc 100644 --- a/internal/clients/handlers/basehandler.go +++ b/internal/clients/handlers/basehandler.go @@ -1,60 +1,44 @@ package handlers import ( - "github.com/mimecast/dtail/internal/logger" - "errors" + "encoding/base64" "fmt" "io" + "strconv" "strings" "time" + + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/version" ) type baseHandler struct { + withCancel server string shellStarted bool commands chan string - pong chan struct{} receiveBuf []byte - stop chan struct{} - pingTimeout int + status int } func (h *baseHandler) Server() string { return h.server } -// Used to determine whether server is still responding to requests or not. -func (h *baseHandler) Ping() error { - if h.pingTimeout == 0 { - // Server ping disabled - return nil - } - - if err := h.SendCommand("ping"); err != nil { - return err - } - - select { - case <-h.pong: - return nil - case <-time.After(time.Duration(h.pingTimeout) * time.Second): - } - - return errors.New("Didn't receive any server pongs (ping replies)") +func (h *baseHandler) Status() int { + return h.status } -func (h *baseHandler) SendCommand(command string) error { - if command == "ping" { - logger.Trace("Sending command", h.server, command) - } else { - logger.Debug("Sending command", h.server, command) - } +// SendMessage to the server. +func (h *baseHandler) SendMessage(command string) error { + encoded := base64.StdEncoding.EncodeToString([]byte(command)) + logger.Debug("Sending command", h.server, command, encoded) select { - case h.commands <- fmt.Sprintf("%s;", command): + case h.commands <- fmt.Sprintf("protocol %s base64 %v;", version.ProtocolCompat, encoded): case <-time.After(time.Second * 5): - return errors.New("Timed out sending command " + command) - case <-h.stop: + return fmt.Errorf("Timed out sending command '%s' (base64: '%s')", command, encoded) + case <-h.ctx.Done(): } return nil @@ -81,7 +65,7 @@ func (h *baseHandler) Read(p []byte) (n int, err error) { select { case command := <-h.commands: n = copy(p, []byte(command)) - case <-h.stop: + case <-h.ctx.Done(): return 0, io.EOF } return @@ -92,6 +76,7 @@ func (h *baseHandler) handleMessageType(message string) { if len(h.receiveBuf) == 0 { return } + // Hidden server commands starti with a dot "." if h.receiveBuf[0] == '.' { h.handleHiddenMessage(message) @@ -108,6 +93,7 @@ func (h *baseHandler) handleMessageType(message string) { h.receiveBuf = h.receiveBuf[:0] return } + logger.Raw(message) h.receiveBuf = h.receiveBuf[:0] } @@ -116,19 +102,27 @@ func (h *baseHandler) handleMessageType(message string) { // to the end user. func (h *baseHandler) handleHiddenMessage(message string) { switch { - case strings.HasPrefix(message, ".pong"): - h.pong <- struct{}{} case strings.HasPrefix(message, ".syn close connection"): - h.SendCommand("ack close connection") - } -} + h.SendMessage(".ack close connection") + select { + case <-time.After(time.Second * 1): + logger.Debug("Shutting down client after timeout and sending ack to server") + h.withCancel.shutdown() + case <-h.ctx.Done(): + } -// Stop the handler. -func (h *baseHandler) Stop() { - select { - case <-h.stop: - default: - logger.Debug("Stopping base handler", h.server) - close(h.stop) + case strings.HasPrefix(message, ".run exitstatus"): + splitted := strings.Split(strings.TrimSuffix(message, "\n"), " ") + if len(splitted) != 3 { + logger.Error("Unable to retrieve exitstatus", message) + return + } + i, err := strconv.Atoi(splitted[2]) + if err != nil { + logger.Error("Unable to retrieve exitstatus", message, err) + return + } + h.status = i + logger.Debug("Retrieved exitstatus", h.status) } } diff --git a/internal/clients/handlers/clienthandler.go b/internal/clients/handlers/clienthandler.go index 4738cd3..fcd8052 100644 --- a/internal/clients/handlers/clienthandler.go +++ b/internal/clients/handlers/clienthandler.go @@ -1,7 +1,7 @@ package handlers import ( - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" ) // ClientHandler is the basic client handler interface. @@ -10,7 +10,7 @@ type ClientHandler struct { } // NewClientHandler creates a new client handler. -func NewClientHandler(server string, pingTimeout int) *ClientHandler { +func NewClientHandler(server string) *ClientHandler { logger.Debug(server, "Creating new client handler") return &ClientHandler{ @@ -18,9 +18,10 @@ func NewClientHandler(server string, pingTimeout int) *ClientHandler { server: server, shellStarted: false, commands: make(chan string), - pong: make(chan struct{}, 1), - stop: make(chan struct{}), - pingTimeout: pingTimeout, + status: -1, + withCancel: withCancel{ + done: make(chan struct{}), + }, }, } } diff --git a/internal/clients/handlers/handler.go b/internal/clients/handlers/handler.go index 2013be0..c53ca34 100644 --- a/internal/clients/handlers/handler.go +++ b/internal/clients/handlers/handler.go @@ -1,12 +1,16 @@ package handlers -import "io" +import ( + "context" + "io" +) // Handler provides all methods which can be run on any client handler. type Handler interface { io.ReadWriter - Ping() error - Stop() - SendCommand(command string) error + SendMessage(command string) error Server() string + Status() int + WithCancel(ctx context.Context) (context.Context, context.CancelFunc) + Done() <-chan struct{} } diff --git a/internal/clients/handlers/healthhandler.go b/internal/clients/handlers/healthhandler.go index 4051e2c..9051015 100644 --- a/internal/clients/handlers/healthhandler.go +++ b/internal/clients/handlers/healthhandler.go @@ -8,6 +8,7 @@ import ( // HealthHandler implements the handler required for health checks. type HealthHandler struct { + withCancel // Buffer of incoming data from server. receiveBuf []byte // To send commands to the server. @@ -16,6 +17,7 @@ type HealthHandler struct { receive chan<- string // The remote server address server string + status int } // NewHealthHandler returns a new health check handler. @@ -24,6 +26,10 @@ func NewHealthHandler(server string, receive chan<- string) *HealthHandler { server: server, receive: receive, commands: make(chan string), + status: -1, + withCancel: withCancel{ + done: make(chan struct{}), + }, } return &h @@ -34,18 +40,13 @@ func (h *HealthHandler) Server() string { return h.server } -// Stop is not of use for health check handler. -func (h *HealthHandler) Stop() { - // Nothing done here. +// Status of the handler. +func (h *HealthHandler) Status() int { + return h.status } -// Ping is not of use for health check handler. -func (h *HealthHandler) Ping() error { - return nil -} - -// SendCommand send a DTail command to the server. -func (h *HealthHandler) SendCommand(command string) error { +// SendMessage sends a DTail command to the server. +func (h *HealthHandler) SendMessage(command string) error { select { case h.commands <- fmt.Sprintf("%s;", command): case <-time.NewTimer(time.Second * 10).C: diff --git a/internal/clients/handlers/maprhandler.go b/internal/clients/handlers/maprhandler.go index d76cdfd..874bb7d 100644 --- a/internal/clients/handlers/maprhandler.go +++ b/internal/clients/handlers/maprhandler.go @@ -1,10 +1,11 @@ package handlers import ( - "github.com/mimecast/dtail/internal/logger" + "strings" + + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/mapr" "github.com/mimecast/dtail/internal/mapr/client" - "strings" ) // MaprHandler is the handler used on the client side for running mapreduce aggregations. @@ -16,15 +17,16 @@ type MaprHandler struct { } // NewMaprHandler returns a new mapreduce client handler. -func NewMaprHandler(server string, query *mapr.Query, globalGroup *mapr.GlobalGroupSet, pingTimeout int) *MaprHandler { +func NewMaprHandler(server string, query *mapr.Query, globalGroup *mapr.GlobalGroupSet) *MaprHandler { return &MaprHandler{ baseHandler: baseHandler{ server: server, shellStarted: false, commands: make(chan string), - pong: make(chan struct{}, 1), - stop: make(chan struct{}), - pingTimeout: pingTimeout, + status: -1, + withCancel: withCancel{ + done: make(chan struct{}), + }, }, query: query, aggregate: client.NewAggregate(server, query, globalGroup), @@ -65,10 +67,3 @@ func (h *MaprHandler) handleAggregateMessage(message string) { h.aggregate.Aggregate(parts[2:]) logger.Debug("Aggregated aggregate data", h.server, h.count) } - -// Stop stops the mapreduce client handler. -func (h *MaprHandler) Stop() { - logger.Debug("Stopping mapreduce handler", h.server) - h.aggregate.Stop() - h.baseHandler.Stop() -} diff --git a/internal/clients/handlers/withcancel.go b/internal/clients/handlers/withcancel.go new file mode 100644 index 0000000..7c9cf4e --- /dev/null +++ b/internal/clients/handlers/withcancel.go @@ -0,0 +1,24 @@ +package handlers + +import "context" + +type withCancel struct { + ctx context.Context + done chan struct{} +} + +// WithCancel sets and returns the context used. +func (w *withCancel) WithCancel(ctx context.Context) (context.Context, context.CancelFunc) { + cancelCtx, cancel := context.WithCancel(ctx) + w.ctx = cancelCtx + + return cancelCtx, cancel +} + +func (w *withCancel) Done() <-chan struct{} { + return w.done +} + +func (w *withCancel) shutdown() { + close(w.done) +} diff --git a/internal/clients/healthclient.go b/internal/clients/healthclient.go index ff13b83..7313583 100644 --- a/internal/clients/healthclient.go +++ b/internal/clients/healthclient.go @@ -1,6 +1,7 @@ package clients import ( + "context" "fmt" "runtime" "strings" @@ -39,7 +40,7 @@ func NewHealthClient(mode omode.Mode) (*HealthClient, error) { } // Start the health client. -func (c *HealthClient) Start() (status int) { +func (c *HealthClient) Start(ctx context.Context) (status int) { receive := make(chan string) throttleCh := make(chan struct{}, runtime.NumCPU()) @@ -49,8 +50,8 @@ func (c *HealthClient) Start() (status int) { conn.Handler = handlers.NewHealthHandler(c.server, receive) conn.Commands = []string{c.mode.String()} - go conn.Start(throttleCh, statsCh) - defer conn.Stop() + connCtx, cancel := conn.Handler.WithCancel(ctx) + go conn.Start(connCtx, cancel, throttleCh, statsCh) for { select { diff --git a/internal/clients/maker.go b/internal/clients/maker.go new file mode 100644 index 0000000..da9dfc9 --- /dev/null +++ b/internal/clients/maker.go @@ -0,0 +1,8 @@ +package clients + +import "github.com/mimecast/dtail/internal/clients/handlers" + +type maker interface { + makeHandler(server string) handlers.Handler + makeCommands() (commands []string) +} diff --git a/internal/clients/maprclient.go b/internal/clients/maprclient.go index 9070827..b581844 100644 --- a/internal/clients/maprclient.go +++ b/internal/clients/maprclient.go @@ -1,6 +1,7 @@ package clients import ( + "context" "errors" "fmt" "runtime" @@ -8,13 +9,9 @@ import ( "time" "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/clients/remote" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/mapr" "github.com/mimecast/dtail/internal/omode" - "github.com/mimecast/dtail/internal/ssh/client" - - gossh "golang.org/x/crypto/ssh" ) // MaprClient is used for running mapreduce aggregations on remote files. @@ -39,8 +36,6 @@ func NewMaprClient(args Args, queryStr string) (*MaprClient, error) { c := MaprClient{ baseClient: baseClient{ Args: args, - stop: make(chan struct{}), - stopped: make(chan struct{}), throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), retry: args.Mode == omode.TailClient, }, @@ -70,35 +65,36 @@ func NewMaprClient(args Args, queryStr string) (*MaprClient, error) { return &c, nil } -func (c MaprClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection { - conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) - conn.Handler = handlers.NewMaprHandler(conn.Server, c.query, c.globalGroup, c.PingTimeout) +// Start starts the mapreduce client. +func (c *MaprClient) Start(ctx context.Context) (status int) { + if c.query.Outfile == "" { + // Only print out periodic results if we don't write an outfile + go c.periodicPrintResults(ctx) + } - conn.Commands = append(conn.Commands, fmt.Sprintf("map %s", c.query.RawQuery)) - commandStr := "tail" + status = c.baseClient.Start(ctx) if c.additative { - commandStr = "cat" + c.recievedFinalResult() } - for _, file := range strings.Split(c.Files, ",") { - conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", commandStr, file, c.Regex)) - } + return +} - return conn +func (c MaprClient) makeHandler(server string) handlers.Handler { + return handlers.NewMaprHandler(server, c.query, c.globalGroup) } -// Start starts the mapreduce client. -func (c *MaprClient) Start() (status int) { - if c.query.Outfile == "" { - // Only print out periodic results if we don't write an outfile - go c.periodicPrintResults() - } +func (c MaprClient) makeCommands() (commands []string) { + commands = append(commands, fmt.Sprintf("map %s", c.query.RawQuery)) - status = c.baseClient.Start() + modeStr := "tail" if c.additative { - c.recievedFinalResult() + modeStr = "cat" + } + + for _, file := range strings.Split(c.What, ",") { + commands = append(commands, fmt.Sprintf("%s %s regex %s", modeStr, file, c.Regex)) } - c.baseClient.Stop() return } @@ -120,13 +116,13 @@ func (c *MaprClient) recievedFinalResult() { logger.Info(fmt.Sprintf("Wrote final mapreduce result to '%s'", c.query.Outfile)) } -func (c *MaprClient) periodicPrintResults() { +func (c *MaprClient) periodicPrintResults(ctx context.Context) { for { select { case <-time.After(c.query.Interval): logger.Info("Gathering interim mapreduce result") c.printResults() - case <-c.baseClient.stop: + case <-ctx.Done(): return } } diff --git a/internal/clients/remote/connection.go b/internal/clients/remote/connection.go index bfc7bc5..71639b1 100644 --- a/internal/clients/remote/connection.go +++ b/internal/clients/remote/connection.go @@ -1,16 +1,18 @@ package remote import ( - "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" - "github.com/mimecast/dtail/internal/ssh/client" + "context" "fmt" "io" "strconv" "strings" "time" + "github.com/mimecast/dtail/internal/clients/handlers" + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/ssh/client" + "golang.org/x/crypto/ssh" ) @@ -30,8 +32,6 @@ type Connection struct { Commands []string // Is it a persistent connection or a one-off? isOneOff bool - // Used to stop the connection - stop chan struct{} // To deal with SSH server host keys hostKeyCallback *client.HostKeyCallback } @@ -48,7 +48,6 @@ func NewConnection(server string, userName string, authMethods []ssh.AuthMethod, HostKeyCallback: hostKeyCallback.Wrap(), Timeout: time.Second * 3, }, - stop: make(chan struct{}), } c.initServerPort(server) @@ -64,7 +63,6 @@ func NewOneOffConnection(server string, userName string, authMethods []ssh.AuthM Auth: authMethods, HostKeyCallback: ssh.InsecureIgnoreHostKey(), }, - stop: make(chan struct{}), isOneOff: true, } @@ -90,39 +88,34 @@ func (c *Connection) initServerPort(server string) { } } -// Start the server connection. Build up SSH session and send some DTail commandc. -func (c *Connection) Start(throttleCh, statsCh chan struct{}) { +// Start the server connection. Build up SSH session and send some DTail commands. +func (c *Connection) Start(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) { + // Throttle how many connections can be established concurrently (based on ch length) select { - case <-c.stop: - logger.Info(c.Server, c.port, "Disconnecting client") + case throttleCh <- struct{}{}: + defer func() { <-throttleCh }() + case <-ctx.Done(): return - default: } - // Wait for SSH connection throttler - throttleCh <- struct{}{} - - // Wait until connection has been initiated or an error occured - // during initialization. - throttleStopCh := make(chan struct{}, 2) go func() { - <-throttleStopCh - <-throttleCh - }() + defer cancel() - if err := c.dial(c.Server, c.port, throttleStopCh, statsCh); err != nil { - logger.Warn(c.Server, c.port, err) - throttleStopCh <- struct{}{} + if err := c.dial(ctx, cancel, c.Server, c.port, statsCh); err != nil { + logger.Warn(c.Server, c.port, err) - if c.hostKeyCallback.Untrusted(fmt.Sprintf("%s:%d", c.Server, c.port)) { - logger.Debug("Not trusting host, not trying to re-connect", c.Server, c.port) - return + if c.hostKeyCallback.Untrusted(fmt.Sprintf("%s:%d", c.Server, c.port)) { + logger.Debug("Not trusting host", c.Server, c.port) + return + } } - } + }() + + <-ctx.Done() } // Dail into a new SSH connection. Close connection in case of an error. -func (c *Connection) dial(host string, port int, throttleStopCh, statsCh chan struct{}) error { +func (c *Connection) dial(ctx context.Context, cancel context.CancelFunc, host string, port int, statsCh chan struct{}) error { statsCh <- struct{}{} defer func() { <-statsCh }() @@ -135,11 +128,11 @@ func (c *Connection) dial(host string, port int, throttleStopCh, statsCh chan st } defer client.Close() - return c.session(client, throttleStopCh) + return c.session(ctx, cancel, client) } // Create the SSH session. Close the session in case of an error. -func (c *Connection) session(client *ssh.Client, throttleStopCh chan<- struct{}) error { +func (c *Connection) session(ctx context.Context, cancel context.CancelFunc, client *ssh.Client) error { logger.Debug(c.Server, "session") session, err := client.NewSession() @@ -148,14 +141,10 @@ func (c *Connection) session(client *ssh.Client, throttleStopCh chan<- struct{}) } defer session.Close() - return c.handle(session, throttleStopCh) + return c.handle(ctx, cancel, session) } -// Handle the SSH session. Also send periodic pings to the server in order -// to determine that session is still intact. -func (c *Connection) handle(session *ssh.Session, throttleStopCh chan<- struct{}) error { - defer c.Handler.Stop() - +func (c *Connection) handle(ctx context.Context, cancel context.CancelFunc, session *ssh.Session) error { logger.Debug(c.Server, "handle") stdinPipe, err := session.StdinPipe() @@ -172,59 +161,30 @@ func (c *Connection) handle(session *ssh.Session, throttleStopCh chan<- struct{} return err } - // Establish Bi-directional pipe between SSH session and client handler. - brokenStdinPipe := make(chan struct{}) go func() { - defer close(brokenStdinPipe) + defer cancel() io.Copy(stdinPipe, c.Handler) }() - brokenStdoutPipe := make(chan struct{}) go func() { - defer close(brokenStdoutPipe) + defer cancel() io.Copy(c.Handler, stdoutPipe) }() - // SSH session established, other goroutine can initiate session now. - throttleStopCh <- struct{}{} + go func() { + defer cancel() + select { + case <-c.Handler.Done(): + case <-ctx.Done(): + } + }() // Send all commands to client. for _, command := range c.Commands { logger.Debug(command) - c.Handler.SendCommand(command) + c.Handler.SendMessage(command) } - if !c.isOneOff { - return c.periodicAliveCheck(brokenStdinPipe, brokenStdoutPipe) - } - - <-c.stop - - // Normal shutdown, all fine + <-ctx.Done() return nil } - -// Periodically check whether connection is still alive or not. -func (c *Connection) periodicAliveCheck(brokenStdinPipe, brokenStdoutPipe <-chan struct{}) error { - for { - select { - case <-time.After(time.Second * 3): - if err := c.Handler.Ping(); err != nil { - return err - } - case <-brokenStdinPipe: - logger.Debug("Broken stdin pipe", c.Server, c.port) - return nil - case <-brokenStdoutPipe: - logger.Debug("Broken stdout pipe", c.Server, c.port) - return nil - case <-c.stop: - return nil - } - } -} - -// Stop the connection. -func (c *Connection) Stop() { - close(c.stop) -} diff --git a/internal/clients/runclient.go b/internal/clients/runclient.go new file mode 100644 index 0000000..7a62fcc --- /dev/null +++ b/internal/clients/runclient.go @@ -0,0 +1,40 @@ +package clients + +import ( + "fmt" + "runtime" + + "github.com/mimecast/dtail/internal/clients/handlers" + "github.com/mimecast/dtail/internal/omode" +) + +// RunClient is a client to run various commands on the server. +type RunClient struct { + baseClient +} + +// NewRunClient returns a new cat client. +func NewRunClient(args Args) (*RunClient, error) { + args.Mode = omode.RunClient + + c := RunClient{ + baseClient: baseClient{ + Args: args, + throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), + retry: false, + }, + } + + c.init(c) + return &c, nil +} + +func (c RunClient) makeHandler(server string) handlers.Handler { + return handlers.NewClientHandler(server) +} + +func (c RunClient) makeCommands() (commands []string) { + // Send "run COMMAND" to server! + commands = append(commands, fmt.Sprintf("%s %s", c.Mode.String(), c.What)) + return +} diff --git a/internal/clients/stats.go b/internal/clients/stats.go index d36cef6..ec6adfe 100644 --- a/internal/clients/stats.go +++ b/internal/clients/stats.go @@ -1,11 +1,13 @@ package clients import ( - "github.com/mimecast/dtail/internal/logger" + "context" "fmt" "runtime" "sync" "time" + + "github.com/mimecast/dtail/internal/io/logger" ) // Used to collect and display various client stats. @@ -28,14 +30,14 @@ func newTailStats(connectionsTotal int) *stats { } } -func (s *stats) periodicLogStats(throttleCh chan struct{}, stop <-chan struct{}) { +func (s *stats) periodicLogStats(ctx context.Context, throttleCh chan struct{}) { connectedLast := 0 statsInterval := 5 for { select { case <-time.After(time.Second * time.Duration(statsInterval)): - case <-stop: + case <-ctx.Done(): return } diff --git a/internal/clients/tailclient.go b/internal/clients/tailclient.go index 674ca36..4d81fd5 100644 --- a/internal/clients/tailclient.go +++ b/internal/clients/tailclient.go @@ -6,11 +6,7 @@ import ( "strings" "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/clients/remote" "github.com/mimecast/dtail/internal/omode" - "github.com/mimecast/dtail/internal/ssh/client" - - gossh "golang.org/x/crypto/ssh" ) // TailClient is used for tailing remote log files (opening, seeking to the end and returning only new incoming lines). @@ -25,25 +21,22 @@ func NewTailClient(args Args) (*TailClient, error) { c := TailClient{ baseClient: baseClient{ Args: args, - stop: make(chan struct{}), - stopped: make(chan struct{}), throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), retry: true, }, } c.init(c) - return &c, nil } -func (c TailClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection { - conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) - conn.Handler = handlers.NewClientHandler(server, c.PingTimeout) +func (c TailClient) makeHandler(server string) handlers.Handler { + return handlers.NewClientHandler(server) +} - for _, file := range strings.Split(c.Files, ",") { - conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex)) +func (c TailClient) makeCommands() (commands []string) { + for _, file := range strings.Split(c.What, ",") { + commands = append(commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex)) } - - return conn + return } diff --git a/internal/discovery/comma.go b/internal/discovery/comma.go index ad18be0..94276c7 100644 --- a/internal/discovery/comma.go +++ b/internal/discovery/comma.go @@ -1,7 +1,7 @@ package discovery import ( - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "strings" ) diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go index d76c1b2..1090ea9 100644 --- a/internal/discovery/discovery.go +++ b/internal/discovery/discovery.go @@ -1,7 +1,6 @@ package discovery import ( - "github.com/mimecast/dtail/internal/logger" "fmt" "math/rand" "os" @@ -9,6 +8,16 @@ import ( "regexp" "strings" "time" + + "github.com/mimecast/dtail/internal/io/logger" +) + +// ServerOrder to specify how to sort the server list. +type ServerOrder int + +const ( + // Shuffle the server list? + Shuffle ServerOrder = iota ) // Discovery method for discovering a list of available DTail servers. @@ -21,12 +30,12 @@ type Discovery struct { server string // To filter server list. regex *regexp.Regexp - // To shuffle resulting server list. - shuffle bool + // How to order the server list. + order ServerOrder } // New returns a new discovery method. -func New(method, server string, shuffle bool) *Discovery { +func New(method, server string, order ServerOrder) *Discovery { module := method options := "" @@ -43,7 +52,7 @@ func New(method, server string, shuffle bool) *Discovery { module: strings.ToUpper(module), options: options, server: server, - shuffle: shuffle, + order: order, } if strings.HasPrefix(server, "/") && strings.HasSuffix(server, "/") { @@ -84,7 +93,7 @@ func (d *Discovery) ServerList() []string { servers = d.dedupList(servers) - if d.shuffle { + if d.order == Shuffle { servers = d.shuffleList(servers) } diff --git a/internal/discovery/file.go b/internal/discovery/file.go index 2edc867..c04173e 100644 --- a/internal/discovery/file.go +++ b/internal/discovery/file.go @@ -2,7 +2,7 @@ package discovery import ( "bufio" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "os" ) diff --git a/internal/fs/catfile.go b/internal/fs/catfile.go deleted file mode 100644 index 99f521f..0000000 --- a/internal/fs/catfile.go +++ /dev/null @@ -1,27 +0,0 @@ -package fs - -import "sync" - -// CatFile is for reading a whole file. -type CatFile struct { - readFile -} - -// NewCatFile returns a new file catter. -func NewCatFile(filePath string, globID string, serverMessages chan<- string, limiter chan struct{}) CatFile { - var mutex sync.Mutex - - return CatFile{ - readFile: readFile{ - filePath: filePath, - stop: make(chan struct{}), - globID: globID, - serverMessages: serverMessages, - retry: false, - canSkipLines: false, - seekEOF: false, - limiter: limiter, - mutex: &mutex, - }, - } -} diff --git a/internal/fs/filereader.go b/internal/fs/filereader.go deleted file mode 100644 index 5a08e27..0000000 --- a/internal/fs/filereader.go +++ /dev/null @@ -1,9 +0,0 @@ -package fs - -// FileReader is the interface used on the dtail server to read/cat/grep/mapr... a file. -type FileReader interface { - Start(lines chan<- LineRead, regex string) error - FilePath() string - Retry() bool - Stop() -} diff --git a/internal/fs/lineread.go b/internal/fs/lineread.go deleted file mode 100644 index 7ee558e..0000000 --- a/internal/fs/lineread.go +++ /dev/null @@ -1,28 +0,0 @@ -package fs - -import ( - "fmt" -) - -// LineRead represents a read log line. -type LineRead struct { - // The content of the log line. - Content []byte - // Until now, how many log lines were processed? - Count uint64 - // Sometimes we produce too many log lines so that the client - // is too slow to process all of them. The server will drop log - // lines if that happens but it will signal to the client how - // many log lines in % could be transmitted to the client. - TransmittedPerc int - GlobID *string -} - -// Return a human readable representation of the followed line. -func (l LineRead) String() string { - return fmt.Sprintf("LineRead(Content:%s,TransmittedPerc:%v,Count:%v,GlobID:%s)", - string(l.Content), - l.TransmittedPerc, - l.Count, - *l.GlobID) -} diff --git a/internal/fs/permissions/permission.go b/internal/fs/permissions/permission.go deleted file mode 100644 index 6e83309..0000000 --- a/internal/fs/permissions/permission.go +++ /dev/null @@ -1,14 +0,0 @@ -// +build !linux - -package permissions - -import ( - "github.com/mimecast/dtail/internal/logger" -) - -// ToRead is to check whether user has read permissions to a given file. -func ToRead(user, filePath string) (bool, error) { - // Only implemented for Linux, always expect true - logger.Warn(user, filePath, "Not performing ACL check, not supported on this platform") - return true, nil -} diff --git a/internal/fs/permissions/permission_linux.c b/internal/fs/permissions/permission_linux.c deleted file mode 100644 index cd10525..0000000 --- a/internal/fs/permissions/permission_linux.c +++ /dev/null @@ -1,395 +0,0 @@ -#include "permission_linux.h" - -#ifdef DEBUG -void debug_print_checker(struct permission_checker *pc) { - fprintf(stderr, "DEBUG: user_name:%s (%d)\n", - pc->user_name, pc->uid); - - fprintf(stderr, "DEBUG: ngids:%d\n", pc->ngids); - int j; - for (j = 0; j < pc->ngids; j++) { - fprintf(stderr, "DEBUG: %d", pc->gids[j]); - struct group *gr = getgrgid(pc->gids[j]); - if (gr != NULL) - fprintf(stderr, " (%s)", gr->gr_name); - fprintf(stderr, "\n"); - } - - fprintf(stderr, "DEBUG: file_path:%s (%d:%d)\n", - pc->file_path, pc->file_stat.st_uid, pc->file_stat.st_gid); -} -#endif // DEBUG - -int stat_file(struct permission_checker *pc) { - if (stat(pc->file_path, &pc->file_stat) != 0) - return -1; - -#ifdef DEBUG - fprintf(stderr, "DEBUG: File'%s' is owned by '%d:%d'\n", - pc->file_path, pc->file_stat.st_uid, pc->file_stat.st_gid); -#endif - - return 0; -} - -int get_user_uid(struct permission_checker *pc) { - struct passwd *result = NULL; - - size_t bufsize = sysconf(_SC_GETPW_R_SIZE_MAX); - if (bufsize == -1) - bufsize = 16384; - - char *buf = malloc(bufsize); - if (buf == NULL) { -#ifdef DEBUG - fprintf(stderr, "DEBUG: Unabel to allocate bufer while retrieving user '%s'\n", pc->user_name); -#endif - return -1; - } - - int rc = getpwnam_r(pc->user_name, &pc->pw, buf, bufsize, &result); - - if (result == NULL) { -#ifdef DEBUG - if (rc == 0) { - fprintf(stderr, "DEBUG: No user '%s' found\n", pc->user_name); - } else { - fprintf(stderr, "DEBUG: Unknown error while retrieving user '%s'\n", pc->user_name); - } -#endif - - free(buf); - return -1; - } - - pc->uid = pc->pw.pw_uid; - - free(buf); - return 0; -} - -int get_user_groups(struct permission_checker *pc) { - // First assume we are in 10 groups max - pc->ngids = 10; - pc->gids = malloc(pc->ngids * sizeof(gid_t)); - - if (pc->gids == NULL) { -#ifdef DEBUG - fprintf(stderr, "DEBUG: Unable to allocate space for gids."); -#endif - return -1; - } - - // Try so many times to load group list until it fits into group array. - while (getgrouplist(pc->user_name, pc->pw.pw_gid, pc->gids, &pc->ngids) == -1) { - // Too many groups, enlarge group array and try again - int newngids = pc->ngids + 100; - size_t newsize = newngids * sizeof(gid_t); - - if (SIZE_MAX / newngids < sizeof(gid_t)) { - // Overflow -#ifdef DEBUG - fprintf(stderr, "DEBUG: Overflow detected."); -#endif - return -1; - } - - gid_t *newgids = realloc(pc->gids, newsize); - if (newgids == NULL) { -#ifdef DEBUG - fprintf(stderr, "DEBUG: Unable to allocate space for gids."); -#endif - free(pc->gids); - return -1; - } - - pc->gids = newgids; - pc->ngids = newngids; - } - - return 0; -} - -int is_member_of_group(struct permission_checker *pc, gid_t gid) { - int j; - for (j = 0; j < pc->ngids; j++) - if (pc->gids[j] == gid) - return 1; - return 0; -} - -int check_acl_uid_matches(uid_t uid, acl_entry_t entry) { - int ret = -1; - uid_t *acl_uid = acl_get_qualifier(entry); - if (acl_uid == NULL) { -#ifdef DEBUG - fprintf(stderr, "DEBUG: Unable to retrieve user uid from ACL entry"); -#endif - return -1; - } - - ret = *acl_uid == uid ? 0 : -1; -#ifdef DEBUG - fprintf(stderr, "DEBUG: ACL user match?: %d <=> %d: %d\n", *acl_uid, uid, ret); -#endif - acl_free(acl_uid); - return ret; -} - -int check_acl_gid_matches(gid_t *gids, int ngids, acl_entry_t entry) { - int ret = -1; - gid_t *acl_gid = acl_get_qualifier(entry); - if (acl_gid == NULL) { -#ifdef DEBUG - fprintf(stderr, "DEBUG: Unable to retrieve user uid from ACL entry"); -#endif - return -1; - } - - int j; - for (j = 0; j < ngids; j++) { - if (*acl_gid == gids[j]) { -#ifdef DEBUG - fprintf(stderr, "DEBUG: User is in group %d", *acl_gid); -#endif - ret = 0; - break; - } - } - -#ifdef DEBUG - fprintf(stderr, "DEBUG: ACL group match?: %d <=> ...: %d\n", *acl_gid, ret); -#endif - acl_free(acl_gid); - return ret; -} - -int check_acl(struct permission_checker *pc, const int flag) { - // By default user has no read perm. - int has_read_perm = 0; - - // By default mask tells that there are read perm. However in order to have - // read permissions both, has_read_perm and mask_allows_read_access must be 1! - int mask_allows_read_access = 1; - - acl_type_t type = ACL_TYPE_ACCESS; - acl_t acl = acl_get_file(pc->file_path, type); - - if (acl == NULL) - // Unable to retrieve ACL. - return -1; - - // Walk through each entry of this ACL. - int id; - for (id = ACL_FIRST_ENTRY; ; id = ACL_NEXT_ENTRY) { - acl_entry_t entry; - if (acl_get_entry(acl, id, &entry) != 1) - // No more ACL entries. - break; - - acl_tag_t tag; - if (acl_get_tag_type(entry, &tag) == -1) - // Unable to retrieve ACL tag. - return -1; - - switch (tag) { - case ACL_USER_OBJ: - if (flag == GROUP_CHECK) - continue; -#ifdef DEBUG - fprintf(stderr, "DEBUG: ACL_USER_OBJ\n"); -#endif - // Ignore this ACL entry if user is not owner of file. - if (pc->uid != pc->file_stat.st_uid) - continue; - break; - case ACL_USER: - if (flag == GROUP_CHECK) - continue; -#ifdef DEBUG - fprintf(stderr, "DEBUG: ACL_USER\n"); -#endif - // Ignore this ACL entry if uid does not match. - if (check_acl_uid_matches(pc->uid, entry) != 0) - continue; - break; - case ACL_GROUP_OBJ: - if (flag == USER_CHECK) - continue; -#ifdef DEBUG - fprintf(stderr, "DEBUG: ACL_GROUP_OBJ\n"); -#endif - // Ignore ACL entry if user is not in group of file. - if (!is_member_of_group(pc, pc->file_stat.st_gid)) - continue; - break; - case ACL_GROUP: - if (flag == USER_CHECK) - continue; -#ifdef DEBUG - fprintf(stderr, "DEBUG: ACL_GROUP\n"); -#endif - // Ignore ACL entry if user is not in group of entry. - if (check_acl_gid_matches(pc->gids, pc->ngids, entry) != 0) - continue; - break; - case ACL_OTHER: - if (flag == GROUP_CHECK) - continue; -#ifdef DEBUG - fprintf(stderr, "DEBUG: ACL_OTHER\n"); -#endif - break; - case ACL_MASK: -#ifdef DEBUG - fprintf(stderr, "DEBUG: ACL_MASK\n"); -#endif - break; - default: -#ifdef DEBUG - fprintf(stderr, "DEBUG: Unknown ACL tag\n"); -#endif - return -1; - } - -#ifdef DEBUG - fprintf(stderr, "DEBUG: Retrieving permset\n"); -#endif - acl_permset_t permset; - int permission; - if (acl_get_permset(entry, &permset) == -1) - // Unable to retrieve permset. - return -1; - - if ((permission = acl_get_perm(permset, ACL_READ)) == -1) - // Unable to retrieve permset value. - return -1; - - if (permission == 1 && tag != ACL_MASK) { -#ifdef DEBUG - fprintf(stderr, "DEBUG: ACL says user has permission to read file.\n"); -#endif - has_read_perm = 1; - } else if (permission == 0 && tag == ACL_MASK) { - // Mask says that there are no permissions to read. - mask_allows_read_access = 0; -#ifdef DEBUG - fprintf(stderr, "DEBUG: ACL mask says no permission to read file.\n"); -#endif - } - } - - if (has_read_perm && mask_allows_read_access) { -#ifdef DEBUG - fprintf(stderr, "DEBUG: ACL end result: User has permission to read file.\n"); -#endif - return 1; - } - -#ifdef DEBUG - fprintf(stderr, "DEBUG: ACL end result: User has no permission to read file.\n"); -#endif - return 0; -} - -int check_traditional(struct permission_checker *pc, const int flag) { - mode_t mode = pc->file_stat.st_mode; - uid_t uid = pc->file_stat.st_uid; - gid_t gid = pc->file_stat.st_gid; - - if (flag == USER_CHECK && (mode & S_IROTH)) { -#ifdef DEBUG - fprintf(stderr, "DEBUG: Others can read file '%s'\n", - pc->file_path); -#endif - return 1; - - } else if (flag == USER_CHECK && (mode & S_IRUSR) && uid == pc->uid) { -#ifdef DEBUG - fprintf(stderr, "DEBUG: User '%s' can read file '%s'\n", - pc->user_name, pc->file_path); -#endif - return 1; - - } else if (flag == GROUP_CHECK && (mode & S_IRGRP) && is_member_of_group(pc, gid)) { -#ifdef DEBUG - fprintf(stderr, "DEBUG: User's '%s' group can read file '%s'\n", - pc->user_name, pc->file_path); -#endif - return 1; - } - - return 0; -} - -int permission_to_read(char* user_name, char *file_path) { - int rc = -1; - -#ifdef DEBUG - fprintf(stderr, "DEBUG: User check '%s' for file '%s'\n", user_name, file_path); -#endif - struct permission_checker pc = { - .user_name = user_name, - .gids = NULL, - .ngids = 0, - .file_path = file_path, - }; - - // Gather user's UID. - if ((rc = get_user_uid(&pc)) == -1) - // Could not retrieve UID. - goto cleanup; - - // Gather file owner (user and group). - if ((rc = stat_file(&pc)) == -1) - // Could not stat file. - goto cleanup; - - // Check whether there is an ACL entry which would allow the user - // to read the file. Don't check for any groups yet. The issue with - // groups is that it can be very slow to retrieve the list of groups - // of a specific user when done via a remote LDAP server! - if ((rc = check_acl(&pc, USER_CHECK)) == 1) - // Yes, has permissions. - goto cleanup; - - // Check whether ACLs of file could be retrieved. - if (rc == -1) { - if (errno != ENOTSUP) - // Unknown error. - goto cleanup; - - // File system does not support ACLs. - // Fallback to traditional permissions. - if ((rc = check_traditional(&pc, USER_CHECK)) == 1) - // Yes, has traditional permissions. - goto cleanup; - - if ((rc = get_user_groups(&pc)) == -1) - // Can not retrieve user's groups. - goto cleanup; - - rc = check_traditional(&pc, GROUP_CHECK); - goto cleanup; - } - - if ((rc = get_user_groups(&pc)) == -1) - // Can not retrieve use'r groups. - goto cleanup; - - // Check whether there is an ACL entry which would allow any of the - // user's groups to read the file. - rc = check_acl(&pc, GROUP_CHECK); - -cleanup: -#ifdef DEBUG - debug_print_checker(&pc); -#endif - - if (pc.ngids) - free(pc.gids); - - return rc; -} - -// vim: set tabstop=8 softtabstop=0 expandtab shiftwidth=4 smarttab diff --git a/internal/fs/permissions/permission_linux.go b/internal/fs/permissions/permission_linux.go deleted file mode 100644 index feae729..0000000 --- a/internal/fs/permissions/permission_linux.go +++ /dev/null @@ -1,33 +0,0 @@ -package permissions - -/* -#include "permission_linux.h" -#cgo LDFLAGS: -L. -lacl -*/ -import "C" - -import ( - "errors" - "unsafe" -) - -// To check whether user has Linux file system permissions to read a given file. -func ToRead(user, filePath string) (bool, error) { - cUser := C.CString(user) - cFilePath := C.CString(filePath) - - defer C.free(unsafe.Pointer(cUser)) - defer C.free(unsafe.Pointer(cFilePath)) - - cOk, err := C.permission_to_read(cUser, cFilePath) - if cOk == 1 { - return true, nil - } - - if err != nil { - // err contains errno message - return false, err - } - - return false, errors.New("User without permission to read file") -} diff --git a/internal/fs/permissions/permission_linux.h b/internal/fs/permissions/permission_linux.h deleted file mode 100644 index a2c266e..0000000 --- a/internal/fs/permissions/permission_linux.h +++ /dev/null @@ -1,60 +0,0 @@ -#ifndef PERMISSION_LINUX_H -#define PERMISSION_LINUX_H - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -//#define DEBUG -#define USER_CHECK 0 -#define GROUP_CHECK 1 - -struct permission_checker { - char *user_name; - uid_t uid; - gid_t *gids; - int ngids; - char *file_path; - struct stat file_stat; - struct passwd pw; -}; - - -#ifdef DEBUG -// Print out permission_checker struct. -void debug_print_checker(struct permission_checker *pc); -#endif - -// Stat a given file to retrieve traditional UNIX permissions. -int stat_file(struct permission_checker *pc); - -// Retrieve UID of user. -int get_user_uid(struct permission_checker *pc); - -// Retrieve all groups of the user. -int get_user_groups(struct permission_checker *pc); - -// Check whether user is member of a group or not. -int is_member_of_group(struct permission_checker *pc, gid_t gid); - -// Check whether user can read file according Linux ACLs. -// As flag use either USER_CHECK or GROUP_CHECK. -int check_acl(struct permission_checker *pc, const int flag); - -// Check whether user has permissions to read file according traditional -// UNIX permissions. As flag use either USER_CHECK or GROUP_CHECK. -int check_traditional(struct permission_checker *pc, const int flag); - -// Returns 1 if user has permission to read file. -// Returns <0 on error and returns 0 if no permissions. -int permission_to_read(char* user, char *file_path); - -#endif // PERMISSION_LINUX_H diff --git a/internal/fs/permissions/permission_test.go b/internal/fs/permissions/permission_test.go deleted file mode 100644 index d415ac2..0000000 --- a/internal/fs/permissions/permission_test.go +++ /dev/null @@ -1,112 +0,0 @@ -// +build linux - -package permissions - -import ( - "os" - "os/exec" - "os/user" - "strings" - "testing" -) - -const ( - setfacl string = "/usr/bin/setfacl" - file string = "/tmp/acltest" -) - -func TestLinuxACL(t *testing.T) { - setfacl := "/usr/bin/setfacl" - file := "/tmp/acltest" - - // Delete file if it exists. - if _, err := os.Stat(file); err == nil { - os.Remove(file) - } - - f, err := os.Create(file) - if err != nil { - t.Errorf("%v", err) - } - defer func() { - f.Close() - //os.Remove(file) - }() - - user, err := user.Current() - if err != nil { - t.Errorf("Unable to retrieve current user: %v", err) - } - - // Test 1: Remove all permissions and perform a permission check - cmd := exec.Command(setfacl, "-b", "-m", "u::---,g::---,o::---", file) - if err := cmd.Run(); err != nil { - t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) - } - if ok, _ := ToRead(user.Username, file); ok { - t.Errorf("Didn't expect permissions to read file!") - } - - // Test 2: Add read permission to file owner - cmd = exec.Command(setfacl, "-b", "-m", "u::r--,g::---,o::---", file) - if err := cmd.Run(); err != nil { - t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) - } - if ok, err := ToRead(user.Username, file); !ok { - t.Errorf("Expected permissions to read file: %v", err) - } - - // Test 3: Add read permission to file group - cmd = exec.Command(setfacl, "-b", "-m", "u::---,g::r--,o::---", file) - if err := cmd.Run(); err != nil { - t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) - } - if ok, err := ToRead(user.Username, file); !ok { - t.Errorf("Expected permissions to read file: %v", err) - } - - // Test 4: Add read permission to others - cmd = exec.Command(setfacl, "-b", "-m", "u::---,g::---,o::r--", file) - if err := cmd.Run(); err != nil { - t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) - } - - if ok, err := ToRead(user.Username, file); !ok { - t.Errorf("Expected permissions to read file: %v", err) - } - - // Test 5: Remove read permission from mask - cmd = exec.Command(setfacl, "-m", "m::---", file) - if err := cmd.Run(); err != nil { - t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) - } - if ok, _ := ToRead(user.Username, file); ok { - t.Errorf("Didn't expect permissions to read file!") - } - cmd = exec.Command(setfacl, "-m", "m::r--", file) - if err := cmd.Run(); err != nil { - t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) - } - - // Test 6: Add read permission to specific group - cmd = exec.Command(setfacl, "-b", "-m", "u::---,g:"+user.Username+":r--,o::---", file) - if err := cmd.Run(); err != nil { - t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) - } - if ok, err := ToRead(user.Username, file); !ok { - t.Errorf("Expected permissions to read file for user %v: %v", user.Username, err) - } - - // Test 7: Remove all permissions but mask - cmd = exec.Command(setfacl, "-b", "-m", "u::---,g::---,o::---", file) - if err := cmd.Run(); err != nil { - t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) - } - cmd = exec.Command(setfacl, "-m", "m::r--", file) - if err := cmd.Run(); err != nil { - t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) - } - if ok, _ := ToRead(user.Username, file); ok { - t.Errorf("Didn't expect permissions to read file!") - } -} diff --git a/internal/fs/readfile.go b/internal/fs/readfile.go deleted file mode 100644 index 312447a..0000000 --- a/internal/fs/readfile.go +++ /dev/null @@ -1,318 +0,0 @@ -package fs - -import ( - "bufio" - "compress/gzip" - "github.com/mimecast/dtail/internal/logger" - "errors" - "io" - "os" - "regexp" - "strings" - "sync" - "time" - - "github.com/DataDog/zstd" -) - -// Used to tail and filter a local log file. -type readFile struct { - // Various statistics (e.g. regex hit percentage, transfer percentage). - stats - // Path of log file to tail. - filePath string - // Only consider all log lines matching this regular expression. - re *regexp.Regexp - // The glob identifier of the file. - globID string - // Channel to send a server message to the dtail client - serverMessages chan<- string - // Signals to stop tailing the log file. - stop chan struct{} - // Periodically retry reading file. - retry bool - // Can I skip messages when there are too many? - canSkipLines bool - // Seek to the EOF before processing file? - seekEOF bool - // Mutex to control the stopping of the file - mutex *sync.Mutex - limiter chan struct{} -} - -// FilePath returns the full file path. -func (f readFile) FilePath() string { - return f.filePath -} - -// Retry reading the file on error? -func (f readFile) Retry() bool { - return f.retry -} - -// Start tailing a log file. -func (f readFile) Start(lines chan<- LineRead, regex string) error { - defer func() { - select { - case <-f.limiter: - default: - } - }() - - select { - case f.limiter <- struct{}{}: - default: - select { - case f.serverMessages <- logger.Warn(f.filePath, f.globID, "Server limit reached. Queuing file..."): - case <-f.stop: - return nil - } - f.limiter <- struct{}{} - } - - fd, err := os.Open(f.filePath) - if err != nil { - return err - } - defer fd.Close() - - if f.seekEOF { - fd.Seek(0, io.SeekEnd) - } - - rawLines := make(chan []byte, 100) - truncate := make(chan struct{}) - - var wg sync.WaitGroup - wg.Add(1) - - go f.periodicTruncateCheck(truncate) - go f.filter(&wg, rawLines, lines, regex) - - err = f.read(fd, rawLines, truncate) - close(rawLines) - wg.Wait() - - return err -} - -func (f readFile) periodicTruncateCheck(truncate chan struct{}) { - for { - select { - case <-time.After(time.Second * 3): - select { - case truncate <- struct{}{}: - case <-f.stop: - } - case <-f.stop: - return - } - } -} - -// Stop reading file. -func (f readFile) Stop() { - f.mutex.Lock() - defer f.mutex.Unlock() - - select { - case <-f.stop: - return - default: - } - - close(f.stop) -} - -func (f readFile) makeReader(fd *os.File) (reader *bufio.Reader, err error) { - switch { - case strings.HasSuffix(f.FilePath(), ".gz"): - fallthrough - case strings.HasSuffix(f.FilePath(), ".gzip"): - logger.Info(f.FilePath(), "Detected gzip compression format") - var gzipReader *gzip.Reader - gzipReader, err = gzip.NewReader(fd) - if err != nil { - return - } - reader = bufio.NewReader(gzipReader) - case strings.HasSuffix(f.FilePath(), ".zst"): - logger.Info(f.FilePath(), "Detected zstd compression format") - reader = bufio.NewReader(zstd.NewReader(fd)) - default: - reader = bufio.NewReader(fd) - } - - return -} - -func (f readFile) read(fd *os.File, rawLines chan []byte, truncate <-chan struct{}) error { - reader, err := f.makeReader(fd) - if err != nil { - return err - } - rawLine := make([]byte, 0, 512) - var offset uint64 - - lineLengthThreshold := 1024 * 1024 // 1mb - longLineWarning := false - - for { - select { - case <-truncate: - if isTruncated, err := f.truncated(fd); isTruncated { - return err - } - logger.Info(f.filePath, "Current offset", offset) - - case <-f.stop: - return nil - default: - } - - // Read some bytes (max 4k at once as of go 1.12). isPrefix will - // be set if line does not fit into 4k buffer. - bytes, isPrefix, err := reader.ReadLine() - - if err != nil { - // If EOF, sleep a couple of ms and return with nil error. - // If other error, return with non-nil error. - if err != io.EOF { - return err - } - if !f.seekEOF { - logger.Debug(f.FilePath(), "End of file reached") - return nil - } - time.Sleep(time.Millisecond * 100) - continue - } - - rawLine = append(rawLine, bytes...) - offset += uint64(len(bytes)) - - if !isPrefix { - // last LineRead call returned contend until end of line. - rawLine = append(rawLine, '\n') - select { - case rawLines <- rawLine: - case <-f.stop: - return nil - } - rawLine = make([]byte, 0, 512) - if longLineWarning { - longLineWarning = false - } - continue - } - - // Last LineRead call could not read content until end of line, buffer - // was too small. Determine whether we exceed the max line length we - // want dtail to send to the client at once. Possibly split up log line - // into multiple log lines. - if len(rawLine) >= lineLengthThreshold { - if !longLineWarning { - f.serverMessages <- logger.Warn(f.filePath, "Long log line, splitting into multiple lines") - // Only print out one warning per long log line. - longLineWarning = true - } - rawLine = append(rawLine, '\n') - select { - case rawLines <- rawLine: - case <-f.stop: - return nil - } - rawLine = make([]byte, 0, 512) - } - } -} - -// Filter log lines matching a given regular expression. -func (f readFile) filter(wg *sync.WaitGroup, rawLines <-chan []byte, lines chan<- LineRead, regex string) { - defer wg.Done() - - if regex == "" { - regex = "." - } - - re, err := regexp.Compile(regex) - if err != nil { - logger.Error(regex, "Can't compile regex, using '.' instead", err) - re = regexp.MustCompile(".") - } - f.re = re - - for { - select { - case line, ok := <-rawLines: - f.updatePosition() - if !ok { - return - } - if filteredLine, ok := f.transmittable(line, len(lines), cap(lines)); ok { - select { - case lines <- filteredLine: - case <-f.stop: - return - } - } - } - } -} - -func (f readFile) transmittable(line []byte, length, capacity int) (LineRead, bool) { - var read LineRead - - if !f.re.Match(line) { - f.updateLineNotMatched() - f.updateLineNotTransmitted() - return read, false - } - f.updateLineMatched() - - // Can we actually send more messages, channel capacity reached? - if f.canSkipLines && length >= capacity { - f.updateLineNotTransmitted() - return read, false - } - f.updateLineTransmitted() - - read = LineRead{ - Content: line, - GlobID: &f.globID, - Count: f.totalLineCount(), - TransmittedPerc: f.transmittedPerc(), - } - - return read, true -} - -// Check wether log file is truncated. Returns nil if not. -func (f readFile) truncated(fd *os.File) (bool, error) { - logger.Debug(f.filePath, "File truncation check") - - // Can not seek currently open FD. - curPos, err := fd.Seek(0, os.SEEK_CUR) - if err != nil { - return true, err - } - - // Can not open file at original path. - pathFd, err := os.Open(f.filePath) - if err != nil { - return true, err - } - defer pathFd.Close() - - // Can not seek file at original path. - pathPos, err := pathFd.Seek(0, io.SeekEnd) - if err != nil { - return true, err - } - - if curPos > pathPos { - return true, errors.New("File got truncated") - } - - return false, nil -} diff --git a/internal/fs/stats.go b/internal/fs/stats.go deleted file mode 100644 index 4121ff7..0000000 --- a/internal/fs/stats.go +++ /dev/null @@ -1,69 +0,0 @@ -package fs - -// Used to calculate how many log lines matched the regular expression -// and how many log files could be transmitted from the server to the client. -// Hit and transmit percentage takes only the last 100 log lines into calculation. -type stats struct { - pos int - lineCount uint64 - matched [100]bool - matchCount uint64 - transmitted [100]bool - transmitCount int -} - -// Return the total line count. -func (f *stats) totalLineCount() uint64 { - return f.lineCount -} - -// Calculate the percentage of log lines transmitted to the client. -func (f *stats) transmittedPerc() int { - return int(percentOf(float64(f.matchCount), float64(f.transmitCount))) -} - -// Update bucket position. We only take into consideration the last 100 -// lines for stats. -func (f *stats) updatePosition() { - f.pos = (f.pos + 1) % 100 - f.lineCount++ -} - -// Increment match counter. -func (f *stats) updateLineMatched() { - if !f.matched[f.pos] { - f.matchCount++ - f.matched[f.pos] = true - } -} - -// Increment transmitted counter. -func (f *stats) updateLineTransmitted() { - if !f.transmitted[f.pos] { - f.transmitCount++ - f.transmitted[f.pos] = true - } -} - -// Decrement match counter. -func (f *stats) updateLineNotMatched() { - if f.matched[f.pos] { - f.matchCount-- - f.matched[f.pos] = false - } -} - -// Decrement transmitted counter. -func (f *stats) updateLineNotTransmitted() { - if f.transmitted[f.pos] { - f.transmitCount-- - f.transmitted[f.pos] = false - } -} - -func percentOf(total float64, value float64) float64 { - if total == 0 || total == value { - return 100 - } - return value / (total / 100.0) -} diff --git a/internal/fs/tailfile.go b/internal/fs/tailfile.go deleted file mode 100644 index a19d4e6..0000000 --- a/internal/fs/tailfile.go +++ /dev/null @@ -1,27 +0,0 @@ -package fs - -import "sync" - -// TailFile is to tail and filter a log file. -type TailFile struct { - readFile -} - -// NewTailFile returns a new file tailer. -func NewTailFile(filePath string, globID string, serverMessages chan<- string, limiter chan struct{}) TailFile { - var mutex sync.Mutex - - return TailFile{ - readFile: readFile{ - filePath: filePath, - stop: make(chan struct{}), - globID: globID, - serverMessages: serverMessages, - retry: true, - canSkipLines: true, - seekEOF: true, - limiter: limiter, - mutex: &mutex, - }, - } -} diff --git a/internal/io/fs/catfile.go b/internal/io/fs/catfile.go new file mode 100644 index 0000000..7f387bc --- /dev/null +++ b/internal/io/fs/catfile.go @@ -0,0 +1,21 @@ +package fs + +// CatFile is for reading a whole file. +type CatFile struct { + readFile +} + +// NewCatFile returns a new file catter. +func NewCatFile(filePath string, globID string, serverMessages chan<- string, limiter chan struct{}) CatFile { + return CatFile{ + readFile: readFile{ + filePath: filePath, + globID: globID, + serverMessages: serverMessages, + retry: false, + canSkipLines: false, + seekEOF: false, + limiter: limiter, + }, + } +} diff --git a/internal/io/fs/filereader.go b/internal/io/fs/filereader.go new file mode 100644 index 0000000..05e58a1 --- /dev/null +++ b/internal/io/fs/filereader.go @@ -0,0 +1,14 @@ +package fs + +import ( + "context" + + "github.com/mimecast/dtail/internal/io/line" +) + +// FileReader is the interface used on the dtail server to read/cat/grep/mapr... a file. +type FileReader interface { + Start(ctx context.Context, lines chan<- line.Line, regex string) error + FilePath() string + Retry() bool +} diff --git a/internal/io/fs/permissions/permission.go b/internal/io/fs/permissions/permission.go new file mode 100644 index 0000000..0ed4f17 --- /dev/null +++ b/internal/io/fs/permissions/permission.go @@ -0,0 +1,14 @@ +// +build !linux + +package permissions + +import ( + "github.com/mimecast/dtail/internal/io/logger" +) + +// ToRead is to check whether user has read permissions to a given file. +func ToRead(user, filePath string) (bool, error) { + // Only implemented for Linux, always expect true + logger.Warn(user, filePath, "Not performing ACL check, not supported on this platform") + return true, nil +} diff --git a/internal/io/fs/permissions/permission_linux.c b/internal/io/fs/permissions/permission_linux.c new file mode 100644 index 0000000..cd10525 --- /dev/null +++ b/internal/io/fs/permissions/permission_linux.c @@ -0,0 +1,395 @@ +#include "permission_linux.h" + +#ifdef DEBUG +void debug_print_checker(struct permission_checker *pc) { + fprintf(stderr, "DEBUG: user_name:%s (%d)\n", + pc->user_name, pc->uid); + + fprintf(stderr, "DEBUG: ngids:%d\n", pc->ngids); + int j; + for (j = 0; j < pc->ngids; j++) { + fprintf(stderr, "DEBUG: %d", pc->gids[j]); + struct group *gr = getgrgid(pc->gids[j]); + if (gr != NULL) + fprintf(stderr, " (%s)", gr->gr_name); + fprintf(stderr, "\n"); + } + + fprintf(stderr, "DEBUG: file_path:%s (%d:%d)\n", + pc->file_path, pc->file_stat.st_uid, pc->file_stat.st_gid); +} +#endif // DEBUG + +int stat_file(struct permission_checker *pc) { + if (stat(pc->file_path, &pc->file_stat) != 0) + return -1; + +#ifdef DEBUG + fprintf(stderr, "DEBUG: File'%s' is owned by '%d:%d'\n", + pc->file_path, pc->file_stat.st_uid, pc->file_stat.st_gid); +#endif + + return 0; +} + +int get_user_uid(struct permission_checker *pc) { + struct passwd *result = NULL; + + size_t bufsize = sysconf(_SC_GETPW_R_SIZE_MAX); + if (bufsize == -1) + bufsize = 16384; + + char *buf = malloc(bufsize); + if (buf == NULL) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: Unabel to allocate bufer while retrieving user '%s'\n", pc->user_name); +#endif + return -1; + } + + int rc = getpwnam_r(pc->user_name, &pc->pw, buf, bufsize, &result); + + if (result == NULL) { +#ifdef DEBUG + if (rc == 0) { + fprintf(stderr, "DEBUG: No user '%s' found\n", pc->user_name); + } else { + fprintf(stderr, "DEBUG: Unknown error while retrieving user '%s'\n", pc->user_name); + } +#endif + + free(buf); + return -1; + } + + pc->uid = pc->pw.pw_uid; + + free(buf); + return 0; +} + +int get_user_groups(struct permission_checker *pc) { + // First assume we are in 10 groups max + pc->ngids = 10; + pc->gids = malloc(pc->ngids * sizeof(gid_t)); + + if (pc->gids == NULL) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: Unable to allocate space for gids."); +#endif + return -1; + } + + // Try so many times to load group list until it fits into group array. + while (getgrouplist(pc->user_name, pc->pw.pw_gid, pc->gids, &pc->ngids) == -1) { + // Too many groups, enlarge group array and try again + int newngids = pc->ngids + 100; + size_t newsize = newngids * sizeof(gid_t); + + if (SIZE_MAX / newngids < sizeof(gid_t)) { + // Overflow +#ifdef DEBUG + fprintf(stderr, "DEBUG: Overflow detected."); +#endif + return -1; + } + + gid_t *newgids = realloc(pc->gids, newsize); + if (newgids == NULL) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: Unable to allocate space for gids."); +#endif + free(pc->gids); + return -1; + } + + pc->gids = newgids; + pc->ngids = newngids; + } + + return 0; +} + +int is_member_of_group(struct permission_checker *pc, gid_t gid) { + int j; + for (j = 0; j < pc->ngids; j++) + if (pc->gids[j] == gid) + return 1; + return 0; +} + +int check_acl_uid_matches(uid_t uid, acl_entry_t entry) { + int ret = -1; + uid_t *acl_uid = acl_get_qualifier(entry); + if (acl_uid == NULL) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: Unable to retrieve user uid from ACL entry"); +#endif + return -1; + } + + ret = *acl_uid == uid ? 0 : -1; +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL user match?: %d <=> %d: %d\n", *acl_uid, uid, ret); +#endif + acl_free(acl_uid); + return ret; +} + +int check_acl_gid_matches(gid_t *gids, int ngids, acl_entry_t entry) { + int ret = -1; + gid_t *acl_gid = acl_get_qualifier(entry); + if (acl_gid == NULL) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: Unable to retrieve user uid from ACL entry"); +#endif + return -1; + } + + int j; + for (j = 0; j < ngids; j++) { + if (*acl_gid == gids[j]) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: User is in group %d", *acl_gid); +#endif + ret = 0; + break; + } + } + +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL group match?: %d <=> ...: %d\n", *acl_gid, ret); +#endif + acl_free(acl_gid); + return ret; +} + +int check_acl(struct permission_checker *pc, const int flag) { + // By default user has no read perm. + int has_read_perm = 0; + + // By default mask tells that there are read perm. However in order to have + // read permissions both, has_read_perm and mask_allows_read_access must be 1! + int mask_allows_read_access = 1; + + acl_type_t type = ACL_TYPE_ACCESS; + acl_t acl = acl_get_file(pc->file_path, type); + + if (acl == NULL) + // Unable to retrieve ACL. + return -1; + + // Walk through each entry of this ACL. + int id; + for (id = ACL_FIRST_ENTRY; ; id = ACL_NEXT_ENTRY) { + acl_entry_t entry; + if (acl_get_entry(acl, id, &entry) != 1) + // No more ACL entries. + break; + + acl_tag_t tag; + if (acl_get_tag_type(entry, &tag) == -1) + // Unable to retrieve ACL tag. + return -1; + + switch (tag) { + case ACL_USER_OBJ: + if (flag == GROUP_CHECK) + continue; +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL_USER_OBJ\n"); +#endif + // Ignore this ACL entry if user is not owner of file. + if (pc->uid != pc->file_stat.st_uid) + continue; + break; + case ACL_USER: + if (flag == GROUP_CHECK) + continue; +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL_USER\n"); +#endif + // Ignore this ACL entry if uid does not match. + if (check_acl_uid_matches(pc->uid, entry) != 0) + continue; + break; + case ACL_GROUP_OBJ: + if (flag == USER_CHECK) + continue; +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL_GROUP_OBJ\n"); +#endif + // Ignore ACL entry if user is not in group of file. + if (!is_member_of_group(pc, pc->file_stat.st_gid)) + continue; + break; + case ACL_GROUP: + if (flag == USER_CHECK) + continue; +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL_GROUP\n"); +#endif + // Ignore ACL entry if user is not in group of entry. + if (check_acl_gid_matches(pc->gids, pc->ngids, entry) != 0) + continue; + break; + case ACL_OTHER: + if (flag == GROUP_CHECK) + continue; +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL_OTHER\n"); +#endif + break; + case ACL_MASK: +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL_MASK\n"); +#endif + break; + default: +#ifdef DEBUG + fprintf(stderr, "DEBUG: Unknown ACL tag\n"); +#endif + return -1; + } + +#ifdef DEBUG + fprintf(stderr, "DEBUG: Retrieving permset\n"); +#endif + acl_permset_t permset; + int permission; + if (acl_get_permset(entry, &permset) == -1) + // Unable to retrieve permset. + return -1; + + if ((permission = acl_get_perm(permset, ACL_READ)) == -1) + // Unable to retrieve permset value. + return -1; + + if (permission == 1 && tag != ACL_MASK) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL says user has permission to read file.\n"); +#endif + has_read_perm = 1; + } else if (permission == 0 && tag == ACL_MASK) { + // Mask says that there are no permissions to read. + mask_allows_read_access = 0; +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL mask says no permission to read file.\n"); +#endif + } + } + + if (has_read_perm && mask_allows_read_access) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL end result: User has permission to read file.\n"); +#endif + return 1; + } + +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL end result: User has no permission to read file.\n"); +#endif + return 0; +} + +int check_traditional(struct permission_checker *pc, const int flag) { + mode_t mode = pc->file_stat.st_mode; + uid_t uid = pc->file_stat.st_uid; + gid_t gid = pc->file_stat.st_gid; + + if (flag == USER_CHECK && (mode & S_IROTH)) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: Others can read file '%s'\n", + pc->file_path); +#endif + return 1; + + } else if (flag == USER_CHECK && (mode & S_IRUSR) && uid == pc->uid) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: User '%s' can read file '%s'\n", + pc->user_name, pc->file_path); +#endif + return 1; + + } else if (flag == GROUP_CHECK && (mode & S_IRGRP) && is_member_of_group(pc, gid)) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: User's '%s' group can read file '%s'\n", + pc->user_name, pc->file_path); +#endif + return 1; + } + + return 0; +} + +int permission_to_read(char* user_name, char *file_path) { + int rc = -1; + +#ifdef DEBUG + fprintf(stderr, "DEBUG: User check '%s' for file '%s'\n", user_name, file_path); +#endif + struct permission_checker pc = { + .user_name = user_name, + .gids = NULL, + .ngids = 0, + .file_path = file_path, + }; + + // Gather user's UID. + if ((rc = get_user_uid(&pc)) == -1) + // Could not retrieve UID. + goto cleanup; + + // Gather file owner (user and group). + if ((rc = stat_file(&pc)) == -1) + // Could not stat file. + goto cleanup; + + // Check whether there is an ACL entry which would allow the user + // to read the file. Don't check for any groups yet. The issue with + // groups is that it can be very slow to retrieve the list of groups + // of a specific user when done via a remote LDAP server! + if ((rc = check_acl(&pc, USER_CHECK)) == 1) + // Yes, has permissions. + goto cleanup; + + // Check whether ACLs of file could be retrieved. + if (rc == -1) { + if (errno != ENOTSUP) + // Unknown error. + goto cleanup; + + // File system does not support ACLs. + // Fallback to traditional permissions. + if ((rc = check_traditional(&pc, USER_CHECK)) == 1) + // Yes, has traditional permissions. + goto cleanup; + + if ((rc = get_user_groups(&pc)) == -1) + // Can not retrieve user's groups. + goto cleanup; + + rc = check_traditional(&pc, GROUP_CHECK); + goto cleanup; + } + + if ((rc = get_user_groups(&pc)) == -1) + // Can not retrieve use'r groups. + goto cleanup; + + // Check whether there is an ACL entry which would allow any of the + // user's groups to read the file. + rc = check_acl(&pc, GROUP_CHECK); + +cleanup: +#ifdef DEBUG + debug_print_checker(&pc); +#endif + + if (pc.ngids) + free(pc.gids); + + return rc; +} + +// vim: set tabstop=8 softtabstop=0 expandtab shiftwidth=4 smarttab diff --git a/internal/io/fs/permissions/permission_linux.go b/internal/io/fs/permissions/permission_linux.go new file mode 100644 index 0000000..feae729 --- /dev/null +++ b/internal/io/fs/permissions/permission_linux.go @@ -0,0 +1,33 @@ +package permissions + +/* +#include "permission_linux.h" +#cgo LDFLAGS: -L. -lacl +*/ +import "C" + +import ( + "errors" + "unsafe" +) + +// To check whether user has Linux file system permissions to read a given file. +func ToRead(user, filePath string) (bool, error) { + cUser := C.CString(user) + cFilePath := C.CString(filePath) + + defer C.free(unsafe.Pointer(cUser)) + defer C.free(unsafe.Pointer(cFilePath)) + + cOk, err := C.permission_to_read(cUser, cFilePath) + if cOk == 1 { + return true, nil + } + + if err != nil { + // err contains errno message + return false, err + } + + return false, errors.New("User without permission to read file") +} diff --git a/internal/io/fs/permissions/permission_linux.h b/internal/io/fs/permissions/permission_linux.h new file mode 100644 index 0000000..a2c266e --- /dev/null +++ b/internal/io/fs/permissions/permission_linux.h @@ -0,0 +1,60 @@ +#ifndef PERMISSION_LINUX_H +#define PERMISSION_LINUX_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +//#define DEBUG +#define USER_CHECK 0 +#define GROUP_CHECK 1 + +struct permission_checker { + char *user_name; + uid_t uid; + gid_t *gids; + int ngids; + char *file_path; + struct stat file_stat; + struct passwd pw; +}; + + +#ifdef DEBUG +// Print out permission_checker struct. +void debug_print_checker(struct permission_checker *pc); +#endif + +// Stat a given file to retrieve traditional UNIX permissions. +int stat_file(struct permission_checker *pc); + +// Retrieve UID of user. +int get_user_uid(struct permission_checker *pc); + +// Retrieve all groups of the user. +int get_user_groups(struct permission_checker *pc); + +// Check whether user is member of a group or not. +int is_member_of_group(struct permission_checker *pc, gid_t gid); + +// Check whether user can read file according Linux ACLs. +// As flag use either USER_CHECK or GROUP_CHECK. +int check_acl(struct permission_checker *pc, const int flag); + +// Check whether user has permissions to read file according traditional +// UNIX permissions. As flag use either USER_CHECK or GROUP_CHECK. +int check_traditional(struct permission_checker *pc, const int flag); + +// Returns 1 if user has permission to read file. +// Returns <0 on error and returns 0 if no permissions. +int permission_to_read(char* user, char *file_path); + +#endif // PERMISSION_LINUX_H diff --git a/internal/io/fs/permissions/permission_test.go b/internal/io/fs/permissions/permission_test.go new file mode 100644 index 0000000..d415ac2 --- /dev/null +++ b/internal/io/fs/permissions/permission_test.go @@ -0,0 +1,112 @@ +// +build linux + +package permissions + +import ( + "os" + "os/exec" + "os/user" + "strings" + "testing" +) + +const ( + setfacl string = "/usr/bin/setfacl" + file string = "/tmp/acltest" +) + +func TestLinuxACL(t *testing.T) { + setfacl := "/usr/bin/setfacl" + file := "/tmp/acltest" + + // Delete file if it exists. + if _, err := os.Stat(file); err == nil { + os.Remove(file) + } + + f, err := os.Create(file) + if err != nil { + t.Errorf("%v", err) + } + defer func() { + f.Close() + //os.Remove(file) + }() + + user, err := user.Current() + if err != nil { + t.Errorf("Unable to retrieve current user: %v", err) + } + + // Test 1: Remove all permissions and perform a permission check + cmd := exec.Command(setfacl, "-b", "-m", "u::---,g::---,o::---", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + if ok, _ := ToRead(user.Username, file); ok { + t.Errorf("Didn't expect permissions to read file!") + } + + // Test 2: Add read permission to file owner + cmd = exec.Command(setfacl, "-b", "-m", "u::r--,g::---,o::---", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + if ok, err := ToRead(user.Username, file); !ok { + t.Errorf("Expected permissions to read file: %v", err) + } + + // Test 3: Add read permission to file group + cmd = exec.Command(setfacl, "-b", "-m", "u::---,g::r--,o::---", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + if ok, err := ToRead(user.Username, file); !ok { + t.Errorf("Expected permissions to read file: %v", err) + } + + // Test 4: Add read permission to others + cmd = exec.Command(setfacl, "-b", "-m", "u::---,g::---,o::r--", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + + if ok, err := ToRead(user.Username, file); !ok { + t.Errorf("Expected permissions to read file: %v", err) + } + + // Test 5: Remove read permission from mask + cmd = exec.Command(setfacl, "-m", "m::---", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + if ok, _ := ToRead(user.Username, file); ok { + t.Errorf("Didn't expect permissions to read file!") + } + cmd = exec.Command(setfacl, "-m", "m::r--", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + + // Test 6: Add read permission to specific group + cmd = exec.Command(setfacl, "-b", "-m", "u::---,g:"+user.Username+":r--,o::---", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + if ok, err := ToRead(user.Username, file); !ok { + t.Errorf("Expected permissions to read file for user %v: %v", user.Username, err) + } + + // Test 7: Remove all permissions but mask + cmd = exec.Command(setfacl, "-b", "-m", "u::---,g::---,o::---", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + cmd = exec.Command(setfacl, "-m", "m::r--", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + if ok, _ := ToRead(user.Username, file); ok { + t.Errorf("Didn't expect permissions to read file!") + } +} diff --git a/internal/io/fs/readfile.go b/internal/io/fs/readfile.go new file mode 100644 index 0000000..321432e --- /dev/null +++ b/internal/io/fs/readfile.go @@ -0,0 +1,307 @@ +package fs + +import ( + "bufio" + "compress/gzip" + "context" + "errors" + "io" + "os" + "regexp" + "strings" + "sync" + "time" + + "github.com/mimecast/dtail/internal/io/line" + "github.com/mimecast/dtail/internal/io/logger" + + "github.com/DataDog/zstd" +) + +// Used to tail and filter a local log file. +type readFile struct { + // Various statistics (e.g. regex hit percentage, transfer percentage). + stats + // Path of log file to tail. + filePath string + // Only consider all log lines matching this regular expression. + re *regexp.Regexp + // The glob identifier of the file. + globID string + // Channel to send a server message to the dtail client + serverMessages chan<- string + // Periodically retry reading file. + retry bool + // Can I skip messages when there are too many? + canSkipLines bool + // Seek to the EOF before processing file? + seekEOF bool + limiter chan struct{} +} + +// FilePath returns the full file path. +func (f readFile) FilePath() string { + return f.filePath +} + +// Retry reading the file on error? +func (f readFile) Retry() bool { + return f.retry +} + +// Start tailing a log file. +func (f readFile) Start(ctx context.Context, lines chan<- line.Line, regex string) error { + defer func() { + select { + case <-f.limiter: + default: + } + }() + + select { + case f.limiter <- struct{}{}: + default: + select { + case f.serverMessages <- logger.Warn(f.filePath, f.globID, "Server limit reached. Queuing file..."): + case <-ctx.Done(): + return nil + } + f.limiter <- struct{}{} + } + + fd, err := os.Open(f.filePath) + if err != nil { + return err + } + defer fd.Close() + + if f.seekEOF { + fd.Seek(0, io.SeekEnd) + } + + rawLines := make(chan []byte, 100) + truncate := make(chan struct{}) + + var wg sync.WaitGroup + wg.Add(1) + + go f.periodicTruncateCheck(ctx, truncate) + go f.filter(ctx, &wg, rawLines, lines, regex) + + err = f.read(ctx, fd, rawLines, truncate) + close(rawLines) + wg.Wait() + + return err +} + +func (f readFile) periodicTruncateCheck(ctx context.Context, truncate chan struct{}) { + for { + select { + case <-time.After(time.Second * 3): + select { + case truncate <- struct{}{}: + case <-ctx.Done(): + } + case <-ctx.Done(): + return + } + } +} + +func (f readFile) makeReader(fd *os.File) (reader *bufio.Reader, err error) { + switch { + case strings.HasSuffix(f.FilePath(), ".gz"): + fallthrough + case strings.HasSuffix(f.FilePath(), ".gzip"): + logger.Info(f.FilePath(), "Detected gzip compression format") + var gzipReader *gzip.Reader + gzipReader, err = gzip.NewReader(fd) + if err != nil { + return + } + reader = bufio.NewReader(gzipReader) + case strings.HasSuffix(f.FilePath(), ".zst"): + logger.Info(f.FilePath(), "Detected zstd compression format") + reader = bufio.NewReader(zstd.NewReader(fd)) + default: + reader = bufio.NewReader(fd) + } + + return +} + +func (f readFile) read(ctx context.Context, fd *os.File, rawLines chan []byte, truncate <-chan struct{}) error { + var offset uint64 + + reader, err := f.makeReader(fd) + if err != nil { + return err + } + rawLine := make([]byte, 0, 512) + + lineLengthThreshold := 1024 * 1024 // 1mb + longLineWarning := false + + for { + select { + case <-ctx.Done(): + return nil + default: + } + + select { + case <-truncate: + if isTruncated, err := f.truncated(fd); isTruncated { + return err + } + logger.Info(f.filePath, "Current offset", offset) + default: + } + + // Read some bytes (max 4k at once as of go 1.12). isPrefix will + // be set if line does not fit into 4k buffer. + bytes, isPrefix, err := reader.ReadLine() + + if err != nil { + // If EOF, sleep a couple of ms and return with nil error. + // If other error, return with non-nil error. + if err != io.EOF { + return err + } + if !f.seekEOF { + logger.Debug(f.FilePath(), "End of file reached") + return nil + } + time.Sleep(time.Millisecond * 100) + continue + } + + rawLine = append(rawLine, bytes...) + offset += uint64(len(bytes)) + + if !isPrefix { + // last LineRead call returned contend until end of line. + rawLine = append(rawLine, '\n') + select { + case rawLines <- rawLine: + case <-ctx.Done(): + return nil + } + rawLine = make([]byte, 0, 512) + if longLineWarning { + longLineWarning = false + } + continue + } + + // Last LineRead call could not read content until end of line, buffer + // was too small. Determine whether we exceed the max line length we + // want dtail to send to the client at once. Possibly split up log line + // into multiple log lines. + if len(rawLine) >= lineLengthThreshold { + if !longLineWarning { + f.serverMessages <- logger.Warn(f.filePath, "Long log line, splitting into multiple lines") + // Only print out one warning per long log line. + longLineWarning = true + } + rawLine = append(rawLine, '\n') + select { + case rawLines <- rawLine: + case <-ctx.Done(): + return nil + } + rawLine = make([]byte, 0, 512) + } + } +} + +// Filter log lines matching a given regular expression. +func (f readFile) filter(ctx context.Context, wg *sync.WaitGroup, rawLines <-chan []byte, lines chan<- line.Line, regex string) { + defer wg.Done() + + if regex == "" { + regex = "." + } + + re, err := regexp.Compile(regex) + if err != nil { + logger.Error(regex, "Can't compile regex, using '.' instead", err) + re = regexp.MustCompile(".") + } + f.re = re + + for { + select { + case line, ok := <-rawLines: + f.updatePosition() + if !ok { + return + } + if filteredLine, ok := f.transmittable(line, len(lines), cap(lines)); ok { + select { + case lines <- filteredLine: + case <-ctx.Done(): + return + } + } + } + } +} + +func (f readFile) transmittable(lineBytes []byte, length, capacity int) (line.Line, bool) { + var read line.Line + + if !f.re.Match(lineBytes) { + f.updateLineNotMatched() + f.updateLineNotTransmitted() + return read, false + } + f.updateLineMatched() + + // Can we actually send more messages, channel capacity reached? + if f.canSkipLines && length >= capacity { + f.updateLineNotTransmitted() + return read, false + } + f.updateLineTransmitted() + + read = line.Line{ + Content: lineBytes, + SourceID: f.globID, + Count: f.totalLineCount(), + TransmittedPerc: f.transmittedPerc(), + } + + return read, true +} + +// Check wether log file is truncated. Returns nil if not. +func (f readFile) truncated(fd *os.File) (bool, error) { + logger.Debug(f.filePath, "File truncation check") + + // Can not seek currently open FD. + curPos, err := fd.Seek(0, os.SEEK_CUR) + if err != nil { + return true, err + } + + // Can not open file at original path. + pathFd, err := os.Open(f.filePath) + if err != nil { + return true, err + } + defer pathFd.Close() + + // Can not seek file at original path. + pathPos, err := pathFd.Seek(0, io.SeekEnd) + if err != nil { + return true, err + } + + if curPos > pathPos { + return true, errors.New("File got truncated") + } + + return false, nil +} diff --git a/internal/io/fs/stats.go b/internal/io/fs/stats.go new file mode 100644 index 0000000..4121ff7 --- /dev/null +++ b/internal/io/fs/stats.go @@ -0,0 +1,69 @@ +package fs + +// Used to calculate how many log lines matched the regular expression +// and how many log files could be transmitted from the server to the client. +// Hit and transmit percentage takes only the last 100 log lines into calculation. +type stats struct { + pos int + lineCount uint64 + matched [100]bool + matchCount uint64 + transmitted [100]bool + transmitCount int +} + +// Return the total line count. +func (f *stats) totalLineCount() uint64 { + return f.lineCount +} + +// Calculate the percentage of log lines transmitted to the client. +func (f *stats) transmittedPerc() int { + return int(percentOf(float64(f.matchCount), float64(f.transmitCount))) +} + +// Update bucket position. We only take into consideration the last 100 +// lines for stats. +func (f *stats) updatePosition() { + f.pos = (f.pos + 1) % 100 + f.lineCount++ +} + +// Increment match counter. +func (f *stats) updateLineMatched() { + if !f.matched[f.pos] { + f.matchCount++ + f.matched[f.pos] = true + } +} + +// Increment transmitted counter. +func (f *stats) updateLineTransmitted() { + if !f.transmitted[f.pos] { + f.transmitCount++ + f.transmitted[f.pos] = true + } +} + +// Decrement match counter. +func (f *stats) updateLineNotMatched() { + if f.matched[f.pos] { + f.matchCount-- + f.matched[f.pos] = false + } +} + +// Decrement transmitted counter. +func (f *stats) updateLineNotTransmitted() { + if f.transmitted[f.pos] { + f.transmitCount-- + f.transmitted[f.pos] = false + } +} + +func percentOf(total float64, value float64) float64 { + if total == 0 || total == value { + return 100 + } + return value / (total / 100.0) +} diff --git a/internal/io/fs/tailfile.go b/internal/io/fs/tailfile.go new file mode 100644 index 0000000..14994e5 --- /dev/null +++ b/internal/io/fs/tailfile.go @@ -0,0 +1,21 @@ +package fs + +// TailFile is to tail and filter a log file. +type TailFile struct { + readFile +} + +// NewTailFile returns a new file tailer. +func NewTailFile(filePath string, globID string, serverMessages chan<- string, limiter chan struct{}) TailFile { + return TailFile{ + readFile: readFile{ + filePath: filePath, + globID: globID, + serverMessages: serverMessages, + retry: true, + canSkipLines: true, + seekEOF: true, + limiter: limiter, + }, + } +} diff --git a/internal/io/line/line.go b/internal/io/line/line.go new file mode 100644 index 0000000..9db93c0 --- /dev/null +++ b/internal/io/line/line.go @@ -0,0 +1,28 @@ +package line + +import ( + "fmt" +) + +// Line represents a read log line. +type Line struct { + // The content of the log line. + Content []byte + // Until now, how many log lines were processed? + Count uint64 + // Sometimes we produce too many log lines so that the client + // is too slow to process all of them. The server will drop log + // lines if that happens but it will signal to the client how + // many log lines in % could be transmitted to the client. + TransmittedPerc int + SourceID string +} + +// Return a human readable representation of the followed line. +func (l Line) String() string { + return fmt.Sprintf("Line(Content:%s,TransmittedPerc:%v,Count:%v,SourceID:%s)", + string(l.Content), + l.TransmittedPerc, + l.Count, + l.SourceID) +} diff --git a/internal/io/logger/logger.go b/internal/io/logger/logger.go new file mode 100644 index 0000000..e30b907 --- /dev/null +++ b/internal/io/logger/logger.go @@ -0,0 +1,445 @@ +package logger + +import ( + "bufio" + "context" + "fmt" + "os" + "os/signal" + "runtime" + "strings" + "sync" + "syscall" + "time" + + "github.com/mimecast/dtail/internal/color" + "github.com/mimecast/dtail/internal/config" +) + +const ( + clientStr string = "CLIENT" + serverStr string = "SERVER" + infoStr string = "INFO" + warnStr string = "WARN" + errorStr string = "ERROR" + fatalStr string = "FATAL" + debugStr string = "DEBUG" + traceStr string = "TRACE" +) + +// Synchronise access to logging. +var mutex sync.Mutex + +// File descriptor of log file when logToFile enabled. +var fd *os.File + +// File write buffer of log file when logToFile enabled. +var writer *bufio.Writer + +// File write buffer of stdout when logToStdout enabled. +var stdoutWriter *bufio.Writer + +// Current hostname. +var hostname string + +// Used to detect change of day (create one log file per day0 +var lastDateStr string + +// True if log in server mode, false if log in client mode. +var serverEnable bool + +// Used to make logging non-blocking. +var fileLogBufCh chan buf +var stdoutBufCh chan string + +// Stdout channel, required to pause output +var pauseCh chan struct{} +var resumeCh chan struct{} + +// Tell the logger about logrotation +var rotateCh chan os.Signal + +// LogMode allows to specify the verbosity of logging. +type LogMode int + +// Possible log modes. +const ( + NormalMode LogMode = iota + DebugMode LogMode = iota + SilentMode LogMode = iota + TraceMode LogMode = iota + NothingMode LogMode = iota +) + +// Mode is the current log mode in use. +var Mode LogMode + +// LogStrategy allows to specify a log rotation strategy. +type LogStrategy int + +// Possible log strategies. +const ( + NormalStrategy LogStrategy = iota + DailyStrategy LogStrategy = iota + StdoutStrategy LogStrategy = iota +) + +// Strategy is the current log strattegy used. +var Strategy LogStrategy + +// Enables logging to stdout. +var logToStdout bool + +// Enables logging to file. +var logToFile bool + +// Helper type to make logging non-blocking. +type buf struct { + time time.Time + message string +} + +// Start logging. +func Start(ctx context.Context, myServerEnable, debugEnable, silentEnable, nothingEnable bool) { + serverEnable = myServerEnable + + mode := logMode(debugEnable, silentEnable, nothingEnable) + strategy := logStrategy() + + stdoutWriter = bufio.NewWriter(os.Stdout) + Mode = mode + Strategy = strategy + + if Mode == NothingMode { + return + } + + switch Strategy { + case DailyStrategy: + _, err := os.Stat(config.Common.LogDir) + logToFile = !os.IsNotExist(err) + logToStdout = !serverEnable || Mode == DebugMode || Mode == TraceMode + case StdoutStrategy: + fallthrough + default: + logToFile = !serverEnable + logToStdout = true + } + + fqdn, err := os.Hostname() + if err != nil { + panic(err) + } + s := strings.Split(fqdn, ".") + hostname = s[0] + + pauseCh = make(chan struct{}) + resumeCh = make(chan struct{}) + + // Setup logrotation + rotateCh = make(chan os.Signal, 1) + signal.Notify(rotateCh, syscall.SIGHUP) + + if logToStdout { + stdoutBufCh = make(chan string, runtime.NumCPU()*100) + go writeToStdout(ctx) + } + + if logToFile { + fileLogBufCh = make(chan buf, runtime.NumCPU()*100) + go writeToFile(ctx) + } +} + +func logMode(debugEnable, silentEnable, nothingEnable bool) LogMode { + switch { + case debugEnable: + return DebugMode + case nothingEnable: + return NothingMode + case config.Common.TraceEnable: + return TraceMode + case config.Common.DebugEnable: + return DebugMode + case silentEnable: + return SilentMode + default: + } + return NormalMode +} + +func logStrategy() LogStrategy { + switch config.Common.LogStrategy { + case "daily": + return DailyStrategy + default: + } + return StdoutStrategy +} + +// Info message logging. +func Info(args ...interface{}) string { + if serverEnable { + return log(serverStr, infoStr, args) + } + + return log(clientStr, infoStr, args) +} + +// Warn message logging. +func Warn(args ...interface{}) string { + if serverEnable { + return log(serverStr, warnStr, args) + } + + return log(clientStr, warnStr, args) +} + +// Error message logging. +func Error(args ...interface{}) string { + if serverEnable { + return log(serverStr, errorStr, args) + } + + return log(clientStr, errorStr, args) +} + +// FatalExit logs an error and exists the process. +func FatalExit(args ...interface{}) { + what := clientStr + if serverEnable { + what = serverStr + } + log(what, fatalStr, args) + + time.Sleep(time.Second) + mutex.Lock() + defer mutex.Unlock() + + closeWriter() + os.Exit(3) +} + +// Debug message logging. +func Debug(args ...interface{}) string { + if Mode == DebugMode || Mode == TraceMode { + if serverEnable { + return log(serverStr, debugStr, args) + } + return log(clientStr, debugStr, args) + } + + return "" +} + +// Trace message logging. +func Trace(args ...interface{}) string { + if Mode == TraceMode { + if serverEnable { + return log(serverStr, traceStr, args) + } + return log(clientStr, traceStr, args) + } + + return "" +} + +// Write log line to buffer and/or log file. +func write(what, severity, message string) { + if logToStdout && (Mode != SilentMode || severity != warnStr) { + line := fmt.Sprintf("%s|%s|%s|%s\n", what, hostname, severity, message) + + if color.Colored { + line = color.Colorfy(line) + } + + stdoutBufCh <- line + } + + if logToFile { + t := time.Now() + timeStr := t.Format("20060102-150405") + fileLogBufCh <- buf{ + time: t, + message: fmt.Sprintf("%s|%s|%s|%s\n", severity, timeStr, what, message), + } + } +} + +// Generig log message. +func log(what string, severity string, args []interface{}) string { + if Mode == NothingMode { + return "" + } + + var messages []string + + for _, arg := range args { + switch v := arg.(type) { + case string: + messages = append(messages, v) + case int: + messages = append(messages, fmt.Sprintf("%d", v)) + case error: + messages = append(messages, v.Error()) + default: + messages = append(messages, fmt.Sprintf("%v", v)) + } + } + + message := strings.Join(messages, "|") + write(what, severity, message) + + return fmt.Sprintf("%s|%s", severity, message) +} + +// Raw message logging. +func Raw(message string) { + if Mode == NothingMode { + return + } + + if logToFile { + fileLogBufCh <- buf{time.Now(), message} + } + + if logToStdout { + if color.Colored { + message = color.Colorfy(message) + } + stdoutBufCh <- message + } +} + +// Close log writer (e.g. on change of day). +func closeWriter() { + if writer != nil { + writer.Flush() + fd.Close() + } +} + +// Return the correct log file writer +func fileWriter(dateStr string) *bufio.Writer { + if dateStr != lastDateStr { + return updateFileWriter(dateStr) + } + + // Check for log rotation signal + select { + case <-rotateCh: + stdoutWriter.WriteString("Received signal for logrotation\n") + return updateFileWriter(dateStr) + default: + } + + return writer +} + +// Update log file writer +func updateFileWriter(dateStr string) *bufio.Writer { + // Detected change of day. Close current writer and create a new one. + mutex.Lock() + defer mutex.Unlock() + closeWriter() + + if _, err := os.Stat(config.Common.LogDir); os.IsNotExist(err) { + if err = os.MkdirAll(config.Common.LogDir, 0755); err != nil { + panic(err) + } + } + + logFile := fmt.Sprintf("%s/%s.log", config.Common.LogDir, dateStr) + newFd, err := os.OpenFile(logFile, os.O_CREATE|os.O_RDWR|os.O_APPEND, 0644) + if err != nil { + panic(err) + } + + fd = newFd + writer = bufio.NewWriterSize(fd, 1) + lastDateStr = dateStr + + return writer +} + +// Flush all outstanding lines. +func Flush() { + for { + select { + case message := <-stdoutBufCh: + stdoutWriter.WriteString(message) + default: + stdoutWriter.Flush() + return + } + } +} + +func writeToStdout(ctx context.Context) { + for { + select { + case message := <-stdoutBufCh: + stdoutWriter.WriteString(message) + case <-time.After(time.Millisecond * 100): + stdoutWriter.Flush() + case <-pauseCh: + PAUSE: + for { + select { + case <-stdoutBufCh: + case <-resumeCh: + break PAUSE + case <-ctx.Done(): + return + } + } + case <-ctx.Done(): + Flush() + return + } + } +} + +func writeToFile(ctx context.Context) { + for { + select { + case buf := <-fileLogBufCh: + dateStr := buf.time.Format("20060102") + w := fileWriter(dateStr) + w.WriteString(buf.message) + case <-pauseCh: + PAUSE: + for { + select { + case <-stdoutBufCh: + case <-resumeCh: + break PAUSE + case <-ctx.Done(): + return + } + } + case <-ctx.Done(): + return + } + } +} + +// Pause logging. +func Pause() { + if logToStdout { + pauseCh <- struct{}{} + } + if logToFile { + pauseCh <- struct{}{} + } +} + +// Resume logging (after pausing). +func Resume() { + if logToStdout { + resumeCh <- struct{}{} + } + if logToFile { + resumeCh <- struct{}{} + } +} diff --git a/internal/io/run/run.go b/internal/io/run/run.go new file mode 100644 index 0000000..b608639 --- /dev/null +++ b/internal/io/run/run.go @@ -0,0 +1,104 @@ +package run + +import ( + "bufio" + "context" + "io" + "os/exec" + "strings" + "sync" + "time" + + "github.com/mimecast/dtail/internal/io/line" + "github.com/mimecast/dtail/internal/io/logger" +) + +// Run is for execute a command. +type Run struct { + commandPath string + args []string + cmd *exec.Cmd +} + +// New returns a new command runner. +func New(commandPath string, args []string) Run { + return Run{ + commandPath: commandPath, + args: args, + } +} + +// Start running the command. +func (r Run) Start(ctx context.Context, lines chan<- line.Line) (pid int, ec int, err error) { + done := make(chan struct{}) + defer close(done) + + ec = -1 + pid = -1 + + if len(r.args) > 0 { + logger.Debug(r.commandPath, strings.Join(r.args, " ")) + r.cmd = exec.CommandContext(ctx, r.commandPath, strings.Join(r.args, " ")) + } else { + logger.Debug(r.commandPath) + r.cmd = exec.CommandContext(ctx, r.commandPath) + } + + stdoutPipe, myErr := r.cmd.StdoutPipe() + if err != nil { + err = myErr + return + } + + stderrPipe, myErr := r.cmd.StderrPipe() + if myErr != nil { + err = myErr + return + } + + if myErr := r.cmd.Start(); err != nil { + err = myErr + return + } + + pid = r.cmd.Process.Pid + ec = 0 + + var wg sync.WaitGroup + wg.Add(2) + + go r.pipeToLines(done, &wg, pid, stdoutPipe, "STDOUT", lines) + go r.pipeToLines(done, &wg, pid, stderrPipe, "STDERR", lines) + + if err = r.cmd.Wait(); err != nil { + if exitError, ok := err.(*exec.ExitError); ok { + ec = exitError.ExitCode() + } + } + + return +} + +func (r Run) pipeToLines(done chan struct{}, wg *sync.WaitGroup, pid int, reader io.Reader, what string, lines chan<- line.Line) { + defer wg.Done() + bufReader := bufio.NewReader(reader) + + for { + lineStr, err := bufReader.ReadString('\n') + for err == nil { + lines <- line.Line{ + Content: []byte(lineStr), + Count: uint64(pid), + TransmittedPerc: 100, + SourceID: what, + } + lineStr, err = bufReader.ReadString('\n') + } + select { + case <-done: + return + default: + } + time.Sleep(time.Millisecond * 10) + } +} diff --git a/internal/logger/logger.go b/internal/logger/logger.go deleted file mode 100644 index ca85e32..0000000 --- a/internal/logger/logger.go +++ /dev/null @@ -1,457 +0,0 @@ -package logger - -import ( - "bufio" - "fmt" - "os" - "os/signal" - "runtime" - "strings" - "sync" - "syscall" - "time" - - "github.com/mimecast/dtail/internal/color" - "github.com/mimecast/dtail/internal/config" -) - -const ( - clientStr string = "CLIENT" - serverStr string = "SERVER" - infoStr string = "INFO" - warnStr string = "WARN" - errorStr string = "ERROR" - fatalStr string = "FATAL" - debugStr string = "DEBUG" - traceStr string = "TRACE" -) - -// Synchronise access to logging. -var mutex sync.Mutex - -// File descriptor of log file when logToFile enabled. -var fd *os.File - -// File write buffer of log file when logToFile enabled. -var writer *bufio.Writer - -// File write buffer of stdout when logToStdout enabled. -var stdoutWriter *bufio.Writer - -// Current hostname. -var hostname string - -// Used to detect change of day (create one log file per day0 -var lastDateStr string - -// True if log in server mode, false if log in client mode. -var serverEnable bool - -// Used to make logging non-blocking. -var logBufCh chan buf -var stdoutBufCh chan string - -// Stdout channel, required to pause output -var pauseCh chan struct{} -var resumeCh chan struct{} - -// Tell the logger that we are done, program shuts down -var stop chan struct{} -var stdoutFlushed chan struct{} - -// Tell the logger about logrotation -var rotateCh chan os.Signal - -// LogMode allows to specify the verbosity of logging. -type LogMode int - -// Possible log modes. -const ( - NormalMode LogMode = iota - DebugMode LogMode = iota - SilentMode LogMode = iota - TraceMode LogMode = iota - NothingMode LogMode = iota -) - -// Mode is the current log mode in use. -var Mode LogMode - -// LogStrategy allows to specify a log rotation strategy. -type LogStrategy int - -// Possible log strategies. -const ( - NormalStrategy LogStrategy = iota - DailyStrategy LogStrategy = iota - StdoutStrategy LogStrategy = iota -) - -// Strategy is the current log strattegy used. -var Strategy LogStrategy - -// Enables logging to stdout. -var logToStdout bool - -// Enables logging to file. -var logToFile bool - -// Helper type to make logging non-blocking. -type buf struct { - time time.Time - message string -} - -// Start logging. -func Start(myServerEnable, debugEnable, silentEnable, nothingEnable bool) { - serverEnable = myServerEnable - - mode := logMode(debugEnable, silentEnable, nothingEnable) - strategy := logStrategy() - - stdoutWriter = bufio.NewWriter(os.Stdout) - Mode = mode - Strategy = strategy - - if Mode == NothingMode { - return - } - - switch Strategy { - case DailyStrategy: - _, err := os.Stat(config.Common.LogDir) - logToFile = !os.IsNotExist(err) - logToStdout = !serverEnable || Mode == DebugMode || Mode == TraceMode - case StdoutStrategy: - fallthrough - default: - logToFile = false - logToStdout = true - } - - fqdn, err := os.Hostname() - if err != nil { - panic(err) - } - s := strings.Split(fqdn, ".") - hostname = s[0] - - pauseCh = make(chan struct{}) - resumeCh = make(chan struct{}) - stop = make(chan struct{}) - stdoutFlushed = make(chan struct{}) - - // Setup logrotation - rotateCh = make(chan os.Signal, 1) - signal.Notify(rotateCh, syscall.SIGHUP) - - if logToStdout { - stdoutBufCh = make(chan string, runtime.NumCPU()*100) - go writeToStdout() - } - - if logToFile { - logBufCh = make(chan buf, runtime.NumCPU()*100) - go writeToFile() - } -} - -func logMode(debugEnable, silentEnable, nothingEnable bool) LogMode { - switch { - case debugEnable: - return DebugMode - case nothingEnable: - return NothingMode - case config.Common.TraceEnable: - return TraceMode - case config.Common.DebugEnable: - return DebugMode - case silentEnable: - return SilentMode - default: - } - return NormalMode -} - -func logStrategy() LogStrategy { - switch config.Common.LogStrategy { - case "daily": - return DailyStrategy - default: - } - return StdoutStrategy -} - -// Info message logging. -func Info(args ...interface{}) string { - if serverEnable { - return log(serverStr, infoStr, args) - } - - return log(clientStr, infoStr, args) -} - -// Warn message logging. -func Warn(args ...interface{}) string { - if serverEnable { - return log(serverStr, warnStr, args) - } - - return log(clientStr, warnStr, args) -} - -// Error message logging. -func Error(args ...interface{}) string { - if serverEnable { - return log(serverStr, errorStr, args) - } - - return log(clientStr, errorStr, args) -} - -// FatalExit logs an error and exists the process. -func FatalExit(args ...interface{}) { - what := clientStr - if serverEnable { - what = serverStr - } - log(what, fatalStr, args) - - time.Sleep(time.Second) - mutex.Lock() - defer mutex.Unlock() - - closeWriter() - os.Exit(3) -} - -// Debug message logging. -func Debug(args ...interface{}) string { - if Mode == DebugMode || Mode == TraceMode { - if serverEnable { - return log(serverStr, debugStr, args) - } - return log(clientStr, debugStr, args) - } - - return "" -} - -// Trace message logging. -func Trace(args ...interface{}) string { - if Mode == TraceMode { - if serverEnable { - return log(serverStr, traceStr, args) - } - return log(clientStr, traceStr, args) - } - - return "" -} - -// Write log line to buffer and/or log file. -func write(what, severity, message string) { - if logToStdout && (Mode != SilentMode || severity != warnStr) { - line := fmt.Sprintf("%s|%s|%s|%s\n", what, hostname, severity, message) - - if color.Colored { - line = color.Colorfy(line) - } - - stdoutBufCh <- line - } - - if logToFile { - t := time.Now() - timeStr := t.Format("20060102-150405") - logBufCh <- buf{ - time: t, - message: fmt.Sprintf("%s|%s|%s|%s\n", severity, timeStr, what, message), - } - } -} - -// Generig log message. -func log(what string, severity string, args []interface{}) string { - if Mode == NothingMode { - return "" - } - - var messages []string - - for _, arg := range args { - switch v := arg.(type) { - case string: - messages = append(messages, v) - case int: - messages = append(messages, fmt.Sprintf("%d", v)) - case error: - messages = append(messages, v.Error()) - default: - messages = append(messages, fmt.Sprintf("%v", v)) - } - } - - message := strings.Join(messages, "|") - write(what, severity, message) - - return fmt.Sprintf("%s|%s", severity, message) -} - -// Raw message logging. -func Raw(message string) { - if Mode == NothingMode { - return - } - - if logToStdout { - if color.Colored { - message = color.Colorfy(message) - } - stdoutBufCh <- message - } - - if logToFile { - logBufCh <- buf{time.Now(), message} - } -} - -// Close log writer (e.g. on change of day). -func closeWriter() { - if writer != nil { - writer.Flush() - fd.Close() - } -} - -// Return the correct log file writer -func fileWriter(dateStr string) *bufio.Writer { - if dateStr != lastDateStr { - return updateFileWriter(dateStr) - } - - // Check for log rotation signal - select { - case <-rotateCh: - stdoutWriter.WriteString("Received signal for logrotation\n") - return updateFileWriter(dateStr) - default: - } - - return writer -} - -// Update log file writer -func updateFileWriter(dateStr string) *bufio.Writer { - // Detected change of day. Close current writer and create a new one. - mutex.Lock() - defer mutex.Unlock() - closeWriter() - - if _, err := os.Stat(config.Common.LogDir); os.IsNotExist(err) { - if err = os.MkdirAll(config.Common.LogDir, 0755); err != nil { - panic(err) - } - } - - logFile := fmt.Sprintf("%s/%s.log", config.Common.LogDir, dateStr) - newFd, err := os.OpenFile(logFile, os.O_CREATE|os.O_RDWR|os.O_APPEND, 0644) - if err != nil { - panic(err) - } - - fd = newFd - writer = bufio.NewWriterSize(fd, 1) - lastDateStr = dateStr - - return writer -} - -func flushStdout() { - defer close(stdoutFlushed) - - for { - select { - case message := <-stdoutBufCh: - stdoutWriter.WriteString(message) - default: - stdoutWriter.Flush() - return - } - } -} - -func writeToStdout() { - for { - select { - case message := <-stdoutBufCh: - stdoutWriter.WriteString(message) - case <-time.After(time.Millisecond * 100): - stdoutWriter.Flush() - case <-pauseCh: - PAUSE: - for { - select { - case <-stdoutBufCh: - case <-resumeCh: - break PAUSE - case <-stop: - return - } - } - case <-stop: - flushStdout() - return - } - } -} - -func writeToFile() { - for { - select { - case buf := <-logBufCh: - dateStr := buf.time.Format("20060102") - w := fileWriter(dateStr) - w.WriteString(buf.message) - case <-pauseCh: - PAUSE: - for { - select { - case <-stdoutBufCh: - case <-resumeCh: - break PAUSE - case <-stop: - return - } - } - case <-stop: - return - } - } -} - -// Pause logging. -func Pause() { - if logToStdout { - pauseCh <- struct{}{} - } - if logToFile { - pauseCh <- struct{}{} - } -} - -// Resume logging (after pausing). -func Resume() { - if logToStdout { - resumeCh <- struct{}{} - } - if logToFile { - resumeCh <- struct{}{} - } -} - -// Stop logging. -func Stop() { - close(stop) - <-stdoutFlushed -} diff --git a/internal/mapr/aggregateset.go b/internal/mapr/aggregateset.go index 2096c3c..7fb4c17 100644 --- a/internal/mapr/aggregateset.go +++ b/internal/mapr/aggregateset.go @@ -1,6 +1,7 @@ package mapr import ( + "context" "fmt" "strconv" "strings" @@ -64,7 +65,7 @@ func (s *AggregateSet) Merge(query *Query, set *AggregateSet) error { } // Serialize the aggregate set so it can be sent over the wire. -func (s *AggregateSet) Serialize(groupKey string, ch chan<- string, stop chan struct{}) { +func (s *AggregateSet) Serialize(ctx context.Context, groupKey string, ch chan<- string) { //logger.Trace("Serialising mapr.AggregateSet", s) var sb strings.Builder @@ -87,7 +88,7 @@ func (s *AggregateSet) Serialize(groupKey string, ch chan<- string, stop chan st select { case ch <- sb.String(): - case <-stop: + case <-ctx.Done(): } } diff --git a/internal/mapr/client/aggregate.go b/internal/mapr/client/aggregate.go index 3f2b7a5..1272a19 100644 --- a/internal/mapr/client/aggregate.go +++ b/internal/mapr/client/aggregate.go @@ -1,10 +1,11 @@ package client import ( - "github.com/mimecast/dtail/internal/logger" - "github.com/mimecast/dtail/internal/mapr" "strconv" "strings" + + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/mapr" ) // Aggregate mapreduce data on the DTail client side. @@ -15,7 +16,6 @@ type Aggregate struct { group *mapr.GroupSet // This represents the merged aggregated data of all servers. globalGroup *mapr.GlobalGroupSet - stop chan struct{} // The server we aggregate the data for (logging and debugging purposes only) server string } @@ -26,20 +26,12 @@ func NewAggregate(server string, query *mapr.Query, globalGroup *mapr.GlobalGrou query: query, group: mapr.NewGroupSet(), globalGroup: globalGroup, - stop: make(chan struct{}), server: server, } } // Aggregate data from mapr log line into local (and global) group sets. func (a *Aggregate) Aggregate(parts []string) { - select { - case <-a.stop: - logger.Error("Client aggregator stopped for server, not processing new data", a.server) - return - default: - } - groupKey := parts[0] samples, err := strconv.Atoi(parts[1]) if err != nil { @@ -87,14 +79,3 @@ func (a *Aggregate) makeFields(parts []string) map[string]string { return fields } - -// Stop the client side mapreduce aggregator. -func (a *Aggregate) Stop() { - logger.Debug("Stopping client mapreduce aggregator") - close(a.stop) - - err := a.globalGroup.Merge(a.query, a.group) - if err != nil { - panic(err) - } -} diff --git a/internal/mapr/groupset.go b/internal/mapr/groupset.go index d8f9379..e9e0d37 100644 --- a/internal/mapr/groupset.go +++ b/internal/mapr/groupset.go @@ -1,6 +1,7 @@ package mapr import ( + "context" "errors" "fmt" "io/ioutil" @@ -46,9 +47,9 @@ func (g *GroupSet) GetSet(groupKey string) *AggregateSet { } // Serialize the group set (e.g. to send it over the wire). -func (g *GroupSet) Serialize(ch chan<- string, stop chan struct{}) { +func (g *GroupSet) Serialize(ctx context.Context, ch chan<- string) { for groupKey, set := range g.sets { - set.Serialize(groupKey, ch, stop) + set.Serialize(ctx, groupKey, ch) } } diff --git a/internal/mapr/logformat/parser.go b/internal/mapr/logformat/parser.go index 5730d29..09c706b 100644 --- a/internal/mapr/logformat/parser.go +++ b/internal/mapr/logformat/parser.go @@ -1,9 +1,9 @@ package logformat import ( - "github.com/mimecast/dtail/internal/logger" "errors" "fmt" + "github.com/mimecast/dtail/internal/io/logger" "os" "reflect" "strings" diff --git a/internal/mapr/query.go b/internal/mapr/query.go index 3805d15..0127be3 100644 --- a/internal/mapr/query.go +++ b/internal/mapr/query.go @@ -1,9 +1,9 @@ package mapr import ( - "github.com/mimecast/dtail/internal/logger" "errors" "fmt" + "github.com/mimecast/dtail/internal/io/logger" "strconv" "strings" "time" diff --git a/internal/mapr/server/aggregate.go b/internal/mapr/server/aggregate.go index 900756e..922dcbd 100644 --- a/internal/mapr/server/aggregate.go +++ b/internal/mapr/server/aggregate.go @@ -1,26 +1,28 @@ package server import ( - "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/fs" - "github.com/mimecast/dtail/internal/logger" - "github.com/mimecast/dtail/internal/mapr" - "github.com/mimecast/dtail/internal/mapr/logformat" + "context" "os" "strings" "time" + + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/line" + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/mapr" + "github.com/mimecast/dtail/internal/mapr/logformat" ) // Aggregate is for aggregating mapreduce data on the DTail server side. type Aggregate struct { // Log lines to process (parsing MAPREDUCE lines). - Lines chan fs.LineRead + Lines chan line.Line // Hostname of the current server (used to populate $hostname field). hostname string - // Signals to exit goroutine. - stop chan struct{} // Signals to serialize data. serialize chan struct{} + // Signals to flush data. + flush chan struct{} // The mapr query query *mapr.Query // The mapr log format parser @@ -28,7 +30,7 @@ type Aggregate struct { } // NewAggregate return a new server side aggregator. -func NewAggregate(maprLines chan<- string, queryStr string) (*Aggregate, error) { +func NewAggregate(queryStr string) (*Aggregate, error) { query, err := mapr.NewQuery(queryStr) if err != nil { return nil, err @@ -47,76 +49,98 @@ func NewAggregate(maprLines chan<- string, queryStr string) (*Aggregate, error) } a := Aggregate{ - Lines: make(chan fs.LineRead, 100), - stop: make(chan struct{}), + Lines: make(chan line.Line, 100), serialize: make(chan struct{}), + flush: make(chan struct{}), hostname: s[0], query: query, parser: logParser, } - go a.periodicAggregateTimer() - - fieldsCh := make(chan map[string]string) - go a.readFields(fieldsCh, maprLines) - go a.readLines(fieldsCh) - return &a, nil } -func (a *Aggregate) periodicAggregateTimer() { +// Start an aggregation run. +func (a *Aggregate) Start(ctx context.Context, maprLines chan<- string) { + fieldsCh := a.linesToFields(ctx) + go a.fieldsToMaprLines(ctx, fieldsCh, maprLines) + a.periodicAggregateTimer(ctx) +} + +func (a *Aggregate) periodicAggregateTimer(ctx context.Context) { for { select { case <-time.After(a.query.Interval): - a.Serialize() - case <-a.stop: + a.Serialize(ctx) + case <-ctx.Done(): return } } } -func (a *Aggregate) readFields(fieldsCh <-chan map[string]string, maprLines chan<- string) { - group := mapr.NewGroupSet() +func (a *Aggregate) linesToFields(ctx context.Context) <-chan map[string]string { + fieldsCh := make(chan map[string]string) - for { - select { - case fields := <-fieldsCh: - a.aggregate(group, fields) - case <-a.serialize: - logger.Info("Serializing mapreduce result") - group.Serialize(maprLines, a.stop) - logger.Info("Done serializing mapreduce result") - group = mapr.NewGroupSet() - case <-a.stop: - return + go func() { + defer close(fieldsCh) + + for { + select { + case line, ok := <-a.Lines: + if !ok { + return + } + + maprLine := strings.TrimSpace(string(line.Content)) + fields, err := a.parser.MakeFields(maprLine) + + if err != nil { + logger.Error(err) + continue + } + if !a.query.WhereClause(fields) { + continue + } + + select { + case fieldsCh <- fields: + case <-ctx.Done(): + } + case <-ctx.Done(): + return + } } - } + }() + + return fieldsCh } -func (a *Aggregate) readLines(fieldsCh chan<- map[string]string) { +func (a *Aggregate) fieldsToMaprLines(ctx context.Context, fieldsCh <-chan map[string]string, maprLines chan<- string) { + group := mapr.NewGroupSet() + for { select { - case line, ok := <-a.Lines: + case fields, ok := <-fieldsCh: if !ok { + logger.Info("Serializing mapreduce result (final)") + group.Serialize(ctx, maprLines) + group = mapr.NewGroupSet() + logger.Info("Done serializing mapreduce result (final)") return } - - maprLine := strings.TrimSpace(string(line.Content)) - fields, err := a.parser.MakeFields(maprLine) - - if err != nil { - logger.Error(err) - continue - } - if !a.query.WhereClause(fields) { - continue - } - - select { - case fieldsCh <- fields: - case <-a.stop: - } - case <-a.stop: + a.aggregate(group, fields) + case <-a.serialize: + logger.Info("Serializing mapreduce result") + group.Serialize(ctx, maprLines) + group = mapr.NewGroupSet() + logger.Info("Done serializing mapreduce result") + case <-a.flush: + logger.Info("Flushing mapreduce result") + group.Serialize(ctx, maprLines) + group = mapr.NewGroupSet() + a.flush <- struct{}{} + logger.Info("Done flushing mapreduce result") + case <-ctx.Done(): return } } @@ -157,14 +181,15 @@ func (a *Aggregate) aggregate(group *mapr.GroupSet, fields map[string]string) { } // Serialize all the aggregated data. -func (a *Aggregate) Serialize() { +func (a *Aggregate) Serialize(ctx context.Context) { select { case a.serialize <- struct{}{}: - case <-a.stop: + case <-ctx.Done(): } } -// Close the aggregator. -func (a *Aggregate) Close() { - close(a.stop) +// Flush all data. +func (a *Aggregate) Flush() { + a.flush <- struct{}{} + <-a.flush } diff --git a/internal/mapr/wherecondition.go b/internal/mapr/wherecondition.go index e1f4e5b..ab46bed 100644 --- a/internal/mapr/wherecondition.go +++ b/internal/mapr/wherecondition.go @@ -1,9 +1,9 @@ package mapr import ( - "github.com/mimecast/dtail/internal/logger" "errors" "fmt" + "github.com/mimecast/dtail/internal/io/logger" "strconv" "strings" ) diff --git a/internal/omode/mode.go b/internal/omode/mode.go index 57366d2..e29aacc 100644 --- a/internal/omode/mode.go +++ b/internal/omode/mode.go @@ -12,7 +12,7 @@ const ( GrepClient Mode = iota MapClient Mode = iota HealthClient Mode = iota - ExecClient Mode = iota + RunClient Mode = iota ) func (m Mode) String() string { @@ -29,8 +29,8 @@ func (m Mode) String() string { return "map" case HealthClient: return "health" - case ExecClient: - return "exec" + case RunClient: + return "run" default: return "unknown" } diff --git a/internal/pprof/pprof.go b/internal/pprof/pprof.go index f78bcf6..c6d11ca 100644 --- a/internal/pprof/pprof.go +++ b/internal/pprof/pprof.go @@ -7,9 +7,10 @@ import ( _ "net/http/pprof" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" ) +// Start the profiler HTTP server. func Start() { bindAddr := fmt.Sprintf("%s:%d", config.Common.PProfBindAddress, config.Common.PProfPort) logger.Info("Starting PProf server", bindAddr) diff --git a/internal/prompt/prompt.go b/internal/prompt/prompt.go index 76a2726..a438d33 100644 --- a/internal/prompt/prompt.go +++ b/internal/prompt/prompt.go @@ -2,8 +2,8 @@ package prompt import ( "bufio" - "github.com/mimecast/dtail/internal/logger" "fmt" + "github.com/mimecast/dtail/internal/io/logger" "os" "strings" ) diff --git a/internal/server/handlers/controlhandler.go b/internal/server/handlers/controlhandler.go index 482f759..a33a78b 100644 --- a/internal/server/handlers/controlhandler.go +++ b/internal/server/handlers/controlhandler.go @@ -1,33 +1,34 @@ package handlers import ( + "context" "fmt" "io" "os" "strings" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" user "github.com/mimecast/dtail/internal/user/server" ) // ControlHandler is used for control functions and health monitoring. type ControlHandler struct { - serverMessages chan string - pong chan struct{} - stop chan struct{} - payload []byte + ctx context.Context + done chan struct{} hostname string + payload []byte + serverMessages chan string user *user.User } // NewControlHandler returns a new control handler. -func NewControlHandler(user *user.User) *ControlHandler { +func NewControlHandler(ctx context.Context, user *user.User) (*ControlHandler, <-chan struct{}) { logger.Debug(user, "Creating control handler") h := ControlHandler{ + ctx: ctx, + done: make(chan struct{}), serverMessages: make(chan string, 10), - pong: make(chan struct{}, 10), - stop: make(chan struct{}), user: user, } @@ -38,7 +39,8 @@ func NewControlHandler(user *user.User) *ControlHandler { s := strings.Split(fqdn, ".") h.hostname = s[0] - return &h + + return &h, h.done } // Read is to send data to the client via the Reader interface. @@ -49,11 +51,7 @@ func (h *ControlHandler) Read(p []byte) (n int, err error) { wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message)) n = copy(p, wholePayload) return - case <-h.pong: - logger.Info(h.user, "Sending pong") - n = copy(p, []byte(".pong\n")) - return - case <-h.stop: + case <-h.ctx.Done(): return 0, io.EOF } } @@ -65,7 +63,7 @@ func (h *ControlHandler) Write(p []byte) (n int, err error) { switch c { case ';': wholePayload := strings.TrimSpace(string(h.payload)) - h.handleCommand(wholePayload) + h.handleCommand(h.ctx, wholePayload) h.payload = nil default: @@ -77,17 +75,7 @@ func (h *ControlHandler) Write(p []byte) (n int, err error) { return } -// Close the control handler. -func (h *ControlHandler) Close() { - close(h.stop) -} - -// Wait returns the handler stop channel. -func (h *ControlHandler) Wait() <-chan struct{} { - return h.stop -} - -func (h *ControlHandler) handleCommand(command string) { +func (h *ControlHandler) handleCommand(ctx context.Context, command string) { logger.Info(h.user, command) s := strings.Split(command, " ") logger.Debug(h.user, "Receiving command", command, s) @@ -96,8 +84,6 @@ func (h *ControlHandler) handleCommand(command string) { case "health": h.serverMessages <- "OK: DTail SSH Server seems fine" h.serverMessages <- "done;" - case "ping": - h.pong <- struct{}{} case "debug": h.serverMessages <- logger.Debug(h.user, "Receiving debug command", command, s) default: diff --git a/internal/server/handlers/handler.go b/internal/server/handlers/handler.go index 8b1f73e..c42ceb9 100644 --- a/internal/server/handlers/handler.go +++ b/internal/server/handlers/handler.go @@ -5,6 +5,4 @@ import "io" // Handler interface for server side functionality. type Handler interface { io.ReadWriter - Close() - Wait() <-chan struct{} } diff --git a/internal/server/handlers/mapcommand.go b/internal/server/handlers/mapcommand.go new file mode 100644 index 0000000..10372da --- /dev/null +++ b/internal/server/handlers/mapcommand.go @@ -0,0 +1,35 @@ +package handlers + +import ( + "context" + "strings" + + "github.com/mimecast/dtail/internal/mapr/server" +) + +// Map command implements the mapreduce command server side. +type mapCommand struct { + aggregate *server.Aggregate + server *ServerHandler +} + +// NewMapCommand returns a new server side mapreduce command. +func newMapCommand(serverHandler *ServerHandler, argc int, args []string) (mapCommand, *server.Aggregate, error) { + mapCommand := mapCommand{ + server: serverHandler, + } + + queryStr := strings.Join(args[1:], " ") + aggregate, err := server.NewAggregate(queryStr) + if err != nil { + return mapCommand, nil, err + } + + mapCommand.aggregate = aggregate + return mapCommand, aggregate, nil + +} + +func (m mapCommand) Start(ctx context.Context, aggregatedMessages chan<- string) { + m.aggregate.Start(ctx, aggregatedMessages) +} diff --git a/internal/server/handlers/readcommand.go b/internal/server/handlers/readcommand.go new file mode 100644 index 0000000..e4079e8 --- /dev/null +++ b/internal/server/handlers/readcommand.go @@ -0,0 +1,158 @@ +package handlers + +import ( + "context" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/mimecast/dtail/internal/io/fs" + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/omode" +) + +type readCommand struct { + server *ServerHandler + mode omode.Mode +} + +func newReadCommand(server *ServerHandler, mode omode.Mode) *readCommand { + return &readCommand{ + server: server, + mode: mode, + } +} + +func (r *readCommand) Start(ctx context.Context, argc int, args []string) { + regex := "." + if argc >= 4 { + regex = args[3] + } + if argc < 3 { + r.server.sendServerMessage(logger.Warn(r.server.user, commandParseWarning, args, argc)) + return + } + r.readGlob(ctx, args[1], regex) +} + +func (r *readCommand) readGlob(ctx context.Context, glob string, regex string) { + retryInterval := time.Second * 5 + glob = filepath.Clean(glob) + + maxRetries := 10 + for { + maxRetries-- + if maxRetries < 0 { + r.server.sendServerMessage(logger.Warn(r.server.user, "Giving up to read file(s)")) + return + } + + paths, err := filepath.Glob(glob) + if err != nil { + logger.Warn(r.server.user, glob, err) + time.Sleep(retryInterval) + continue + } + + if numPaths := len(paths); numPaths == 0 { + logger.Error(r.server.user, "No such file(s) to read", glob) + r.server.sendServerMessage(logger.Warn(r.server.user, "Unable to read file(s), check server logs")) + select { + case <-ctx.Done(): + return + default: + } + time.Sleep(retryInterval) + continue + } + + r.readFiles(ctx, paths, glob, regex, retryInterval) + break + } +} + +func (r *readCommand) readFiles(ctx context.Context, paths []string, glob string, regex string, retryInterval time.Duration) { + var wg sync.WaitGroup + wg.Add(len(paths)) + + for _, path := range paths { + go r.readFileIfPermissions(ctx, &wg, path, glob, regex) + } + + wg.Wait() +} + +func (r *readCommand) readFileIfPermissions(ctx context.Context, wg *sync.WaitGroup, path, glob, regex string) { + defer wg.Done() + globID := r.makeGlobID(path, glob) + + if !r.server.user.HasFilePermission(path, "readfiles") { + logger.Error(r.server.user, "No permission to read file", path, globID) + r.server.sendServerMessage(logger.Warn(r.server.user, "Unable to read file(s), check server logs")) + return + } + + r.readFile(ctx, path, globID, regex) +} + +func (r *readCommand) readFile(ctx context.Context, path, globID, regex string) { + logger.Info(r.server.user, "Start reading file", path, globID) + + var reader fs.FileReader + switch r.mode { + case omode.TailClient: + reader = fs.NewTailFile(path, globID, r.server.serverMessages, r.server.tailLimiter) + case omode.GrepClient, omode.CatClient: + reader = fs.NewCatFile(path, globID, r.server.serverMessages, r.server.catLimiter) + default: + reader = fs.NewTailFile(path, globID, r.server.serverMessages, r.server.tailLimiter) + } + + lines := r.server.lines + + // Plug in mappreduce engine + if r.server.aggregate != nil { + lines = r.server.aggregate.Lines + } + + for { + if err := reader.Start(ctx, lines, regex); err != nil { + logger.Error(r.server.user, path, globID, err) + } + + select { + case <-ctx.Done(): + return + default: + if !reader.Retry() { + return + } + } + + time.Sleep(time.Second * 2) + logger.Info(path, globID, "Reading file again") + } +} + +func (r *readCommand) makeGlobID(path, glob string) string { + var idParts []string + pathParts := strings.Split(path, "/") + + for i, globPart := range strings.Split(glob, "/") { + if strings.Contains(globPart, "*") { + idParts = append(idParts, pathParts[i]) + } + } + + if len(idParts) > 0 { + return strings.Join(idParts, "/") + } + + if len(pathParts) > 0 { + return pathParts[len(pathParts)-1] + } + + r.server.sendServerMessage(logger.Error("Empty file path given?", path, glob)) + return "" +} diff --git a/internal/server/handlers/runcommand.go b/internal/server/handlers/runcommand.go new file mode 100644 index 0000000..e260060 --- /dev/null +++ b/internal/server/handlers/runcommand.go @@ -0,0 +1,73 @@ +package handlers + +import ( + "context" + "fmt" + "os/exec" + "strings" + + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/io/run" +) + +type runCommand struct { + server *ServerHandler + run run.Run +} + +func newRunCommand(server *ServerHandler) runCommand { + return runCommand{ + server: server, + } +} + +func (r runCommand) Start(ctx context.Context, argc int, args []string) { + if argc < 2 { + r.server.sendServerMessage(logger.Warn(r.server.user, commandParseWarning, args, argc)) + return + } + commands := strings.Split(strings.Join(args[1:], " "), ";") + r.start(ctx, commands) +} + +func (r runCommand) start(ctx context.Context, commands []string) { + for _, command := range commands { + command = strings.TrimSpace(command) + if len(command) == 0 { + continue + } + splitted := strings.Split(command, " ") + path := splitted[0] + args := splitted[1:] + + qualifiedPath, err := exec.LookPath(path) + if err != nil { + logger.Error(r.server.user, err) + r.server.sendServerMessage(logger.Warn(r.server.user, "Unable to execute command(s), check server logs")) + r.server.sendServerMessage(fmt.Sprintf(".run exitstatus -%d", -1)) + return + } + + if !r.server.user.HasFilePermission(qualifiedPath, "runcommands") { + logger.Error(r.server.user, "No permission to execute path", qualifiedPath) + r.server.sendServerMessage(logger.Warn(r.server.user, "Unable to execute command(s), check server logs")) + r.server.sendServerMessage(fmt.Sprintf(".run exitstatus -%d", -1)) + return + } + + r.run = run.New(qualifiedPath, args) + pid, ec, err := r.run.Start(ctx, r.server.lines) + + if err != nil { + message := fmt.Sprintf("Unable to execute remote command '%s'", command) + logger.Error(r.server.user, message, ec, pid, err) + r.server.sendServerMessage(logger.Error(message, ec, pid, err)) + r.server.sendServerMessage(fmt.Sprintf(".run exitstatus -%d", ec)) + return + } + + message := fmt.Sprintf("Remote process '%d' exited with status '%d'", pid, ec) + r.server.sendServerMessage(fmt.Sprintf(".run exitstatus %d", ec)) + r.server.sendServerMessage(logger.Info("run", pid, ec, message)) + } +} diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go index bed8609..3f0d6ce 100644 --- a/internal/server/handlers/serverhandler.go +++ b/internal/server/handlers/serverhandler.go @@ -1,17 +1,19 @@ package handlers import ( + "context" + "encoding/base64" + "errors" "fmt" "io" "os" - "path/filepath" "strings" "sync" "time" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/fs" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/line" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/mapr/server" "github.com/mimecast/dtail/internal/omode" user "github.com/mimecast/dtail/internal/user/server" @@ -26,51 +28,33 @@ const ( // the Bi-directional communication between SSH client and server. // This handler implements the handler of the SSH server. type ServerHandler struct { - // Local log file readers - fileReaders []fs.FileReader - fileReadersMtx *sync.Mutex - // Channel for read lines. - lines chan fs.LineRead - // Only process log lines matching this regex. - regex string - // Server side mapr log aggregation. - aggregate *server.Aggregate - // Channel of aggregated log lines. + mutex *sync.Mutex + lines chan line.Line + regex string + aggregate *server.Aggregate aggregatedMessages chan string - // Channel for server messages to be sent to the client. - serverMessages chan string - // Channel for hidden messages to be sent to the client. - hiddenMessages chan string - // The current payload sent to the client. - payload []byte - // The current server hostname. - hostname string - // The user connecting to dtail. - user *user.User - // To limit the server wide max amount of concurrent cats - catLimiter chan struct{} - // To limit the server wide max amount of concurrent tails - tailLimiter chan struct{} - // Server can tell handler to stop the handler. - stop chan struct{} - // Indicate that client responded to server with "ack stop connection" - ackStopReceived chan struct{} - // Stop timeout. - stopTimeout chan struct{} + serverMessages chan string + payload []byte + hostname string + user *user.User + catLimiter chan struct{} + tailLimiter chan struct{} + ackCloseReceived chan struct{} + ctx context.Context + done chan struct{} + activeReaders int } // NewServerHandler returns the server handler. -func NewServerHandler(user *user.User, catLimiter chan struct{}, tailLimiter chan struct{}) *ServerHandler { - logger.Debug(user, "Creating tail handler") +func NewServerHandler(ctx context.Context, user *user.User, catLimiter chan struct{}, tailLimiter chan struct{}) (*ServerHandler, <-chan struct{}) { h := ServerHandler{ - fileReadersMtx: &sync.Mutex{}, - lines: make(chan fs.LineRead, 100), + ctx: ctx, + done: make(chan struct{}), + mutex: &sync.Mutex{}, + lines: make(chan line.Line, 100), serverMessages: make(chan string, 10), aggregatedMessages: make(chan string, 10), - hiddenMessages: make(chan string, 10), - ackStopReceived: make(chan struct{}), - stopTimeout: make(chan struct{}), - stop: make(chan struct{}), + ackCloseReceived: make(chan struct{}), catLimiter: catLimiter, tailLimiter: tailLimiter, regex: ".", @@ -85,37 +69,46 @@ func NewServerHandler(user *user.User, catLimiter chan struct{}, tailLimiter cha s := strings.Split(fqdn, ".") h.hostname = s[0] - return &h + return &h, h.done } // Read is to send data to the dtail client via Reader interface. func (h *ServerHandler) Read(p []byte) (n int, err error) { for { select { + case message := <-h.serverMessages: + if message[0] == '.' { + // Handle hidden message (don't display to the user, interpreted by dtail client) + wholePayload := []byte(fmt.Sprintf("%s\n", message)) + n = copy(p, wholePayload) + return + } + + // Handle normal server message (display to the user) wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message)) n = copy(p, wholePayload) return + case message := <-h.aggregatedMessages: + // Send mapreduce-aggregated data as a message. data := fmt.Sprintf("AGGREGATE|%s|%s\n", h.hostname, message) - //logger.Debug("Sending aggregation data", data) wholePayload := []byte(data) n = copy(p, wholePayload) return - case message := <-h.hiddenMessages: - //logger.Debug(h.user, "Sending hidden message", message) - wholePayload := []byte(fmt.Sprintf(".%s\n", message)) - n = copy(p, wholePayload) - return + case line := <-h.lines: + // Send normal file content data as a message. serverInfo := []byte(fmt.Sprintf("REMOTE|%s|%3d|%v|%s|", - h.hostname, line.TransmittedPerc, line.Count, *line.GlobID)) + h.hostname, line.TransmittedPerc, line.Count, line.SourceID)) wholePayload := append(serverInfo, line.Content[:]...) n = copy(p, wholePayload) return + case <-time.After(time.Second): + // Once in a while check whether we are done. select { - case <-h.stop: + case <-h.ctx.Done(): return 0, io.EOF default: } @@ -129,7 +122,7 @@ func (h *ServerHandler) Write(p []byte) (n int, err error) { switch c { case ';': commandStr := strings.TrimSpace(string(h.payload)) - h.handleCommand(commandStr) + h.handleCommand(h.ctx, commandStr) h.payload = nil default: h.payload = append(h.payload, c) @@ -140,210 +133,167 @@ func (h *ServerHandler) Write(p []byte) (n int, err error) { return } -// Close the server handler. -func (h *ServerHandler) Close() { - h.fileReadersMtx.Lock() - defer h.fileReadersMtx.Unlock() +func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) { + logger.Debug(h.user, commandStr) - for _, reader := range h.fileReaders { - reader.Stop() + args, argc, err := h.handleProtocolVersion(strings.Split(commandStr, " ")) + if err != nil { + h.send(h.serverMessages, logger.Error(h.user, err)) + return } - if h.aggregate != nil { - h.aggregate.Close() + + args, argc, err = h.handleBase64(args, argc) + if err != nil { + h.send(h.serverMessages, logger.Error(h.user, err)) + return } - close(h.stop) -} + if h.user.Name == config.ControlUser { + h.handleControlCommand(argc, args) + return + } -func (h *ServerHandler) makeGlobID(path, glob string) string { - var idParts []string - pathParts := strings.Split(path, "/") + h.handleUserCommand(ctx, argc, args) +} - for i, globPart := range strings.Split(glob, "/") { - if strings.Contains(globPart, "*") { - idParts = append(idParts, pathParts[i]) - } - } +func (h *ServerHandler) handleProtocolVersion(args []string) ([]string, int, error) { + argc := len(args) - if len(idParts) > 0 { - return strings.Join(idParts, "/") + if argc <= 2 || args[0] != "protocol" { + return args, argc, errors.New("unable to determine protocol version") } - if len(pathParts) > 0 { - return pathParts[len(pathParts)-1] + if args[1] != version.ProtocolCompat { + err := fmt.Errorf("server with protool version '%s' but client with '%s', please update DTail", version.ProtocolCompat, args[1]) + return args, argc, err } - h.send(h.serverMessages, logger.Error("Empty file path given?", path, glob)) - return "" + return args[2:], argc - 2, nil } -func (h *ServerHandler) processFileGlob(mode omode.Mode, glob string, regex string) { - retryInterval := time.Second * 5 - glob = filepath.Clean(glob) - - errors := make(chan struct{}) - stop := make(chan struct{}) - defer close(stop) +func (h *ServerHandler) handleBase64(args []string, argc int) ([]string, int, error) { + err := errors.New("Unable to decode client message") - go func() { - for { - select { - case <-errors: - h.send(h.serverMessages, logger.Warn(h.user, "Unable to read file(s), check server logs")) - case <-stop: - return - case <-h.stop: - return - } - } - }() + if argc != 2 || args[0] != "base64" { + return args, argc, err + } - maxRetries := 10 - for { - maxRetries-- - if maxRetries < 0 { - h.send(h.serverMessages, logger.Warn(h.user, "Giving up to read file(s)")) - h.internalClose() - return - } + decoded, err := base64.StdEncoding.DecodeString(args[1]) + if err != nil { + return args, argc, err + } + decodedStr := string(decoded) - paths, err := filepath.Glob(glob) - if err != nil { - logger.Warn(h.user, glob, err) - time.Sleep(retryInterval) - continue - } + args = strings.Split(decodedStr, " ") + argc = len(decodedStr) + logger.Trace(h.user, "Base64 decoded received command", decodedStr, argc, args) - if numPaths := len(paths); numPaths == 0 { - logger.Error(h.user, "No such file(s) to read", glob) - select { - case errors <- struct{}{}: - case <-h.stop: - return - default: - } - time.Sleep(retryInterval) - continue - } + return args, argc, nil +} - h.startReadingFiles(mode, paths, glob, regex, retryInterval, errors) - break +func (h *ServerHandler) handleControlCommand(argc int, args []string) { + switch args[0] { + case "debug": + h.send(h.serverMessages, logger.Debug(h.user, "Receiving debug command", argc, args)) + default: + logger.Warn(h.user, "Received unknown command", argc, args) } } -func (h *ServerHandler) startReadingFiles(mode omode.Mode, paths []string, glob string, regex string, retryInterval time.Duration, errors chan<- struct{}) { - var wg sync.WaitGroup - wg.Add(len(paths)) +func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []string) { + logger.Debug(h.user, "handleUserCommand", argc, args) - read := func(path string, wg *sync.WaitGroup) { - defer wg.Done() - globID := h.makeGlobID(path, glob) + switch args[0] { + case "grep", "cat": + command := newReadCommand(h, omode.CatClient) + h.incrementActiveReaders() + go func() { + command.Start(ctx, argc, args) + if h.decrementActiveReaders() == 0 { + h.shutdown() + } + }() - if !h.user.HasFilePermission(path) { - logger.Error(h.user, "No permission to read file", path, globID) - select { - case errors <- struct{}{}: - default: + case "tail": + command := newReadCommand(h, omode.TailClient) + h.incrementActiveReaders() + go func() { + command.Start(ctx, argc, args) + if h.decrementActiveReaders() == 0 { + h.shutdown() } + }() + + case "map": + command, aggregate, err := newMapCommand(h, argc, args) + if err != nil { + h.sendServerMessage(err.Error()) + logger.Error(h.user, err) return } - h.startReadingFile(mode, path, globID, regex) - } - - for _, path := range paths { - go read(path, &wg) - } + h.aggregate = aggregate + go func() { + command.Start(ctx, h.aggregatedMessages) + h.shutdown() + }() + + case "run": + command := newRunCommand(h) + h.incrementActiveReaders() + go func() { + command.Start(ctx, argc, args) + if h.decrementActiveReaders() == 0 { + h.shutdown() + } + }() - wg.Wait() -} + case "ack", ".ack": + h.handleAckCommand(argc, args) -func (h *ServerHandler) startReadingFile(mode omode.Mode, path, globID, regex string) { - defer h.stopReadingFile(path) - logger.Info(h.user, "Start reading file", path, globID) - - var reader fs.FileReader - switch mode { - case omode.TailClient: - reader = fs.NewTailFile(path, globID, h.serverMessages, h.tailLimiter) - case omode.GrepClient: - fallthrough - case omode.CatClient: - reader = fs.NewCatFile(path, globID, h.serverMessages, h.catLimiter) default: - reader = fs.NewTailFile(path, globID, h.serverMessages, h.tailLimiter) + h.sendServerMessage(logger.Error(h.user, "Received unknown command", argc, args)) } +} - h.fileReadersMtx.Lock() - h.fileReaders = append(h.fileReaders, reader) - h.fileReadersMtx.Unlock() - - lines := h.lines - // Plugin mappreduce engine - if h.aggregate != nil { - lines = h.aggregate.Lines +func (h *ServerHandler) handleAckCommand(argc int, args []string) { + if argc < 3 { + h.sendServerMessage(logger.Warn(h.user, commandParseWarning, args, argc)) + return } - - for { - if err := reader.Start(lines, regex); err != nil { - logger.Error(h.user, path, globID, err) - } - - select { - case <-h.stop: - return - default: - if !reader.Retry() { - return - } - } - - time.Sleep(time.Second * 2) - logger.Info(path, globID, "Reading file again") + if args[1] == "close" && args[2] == "connection" { + close(h.ackCloseReceived) } } -func (h *ServerHandler) stopReadingFile(path string) { - logger.Info(h.user, "Stop reading file", path) +func (h *ServerHandler) send(ch chan<- string, message string) { + select { + case ch <- message: + case <-h.ctx.Done(): + } +} - h.fileReadersMtx.Lock() - defer h.fileReadersMtx.Unlock() +func (h *ServerHandler) sendServerMessage(message string) { + h.send(h.serverMessageC(), message) +} - path = filepath.Clean(path) - var fileReaders []fs.FileReader +func (h *ServerHandler) serverMessageC() chan<- string { + return h.serverMessages +} - for _, reader := range h.fileReaders { - if reader.FilePath() == path { - reader.Stop() - continue - } - fileReaders = append(fileReaders, reader) - } +func (h *ServerHandler) flush() { + logger.Debug(h.user, "flush()") - if len(fileReaders) == len(h.fileReaders) { - logger.Warn(h.user, "Didn't read file path", path) - return + if h.aggregate != nil { + h.aggregate.Flush() } - h.fileReaders = fileReaders - - if len(fileReaders) == 0 { - if h.aggregate != nil { - h.aggregate.Serialize() - } - h.allLinesSent() + unsentMessages := func() int { + return len(h.lines) + len(h.serverMessages) + len(h.aggregatedMessages) } -} - -func (h *ServerHandler) numUnsentMessages() int { - return len(h.lines) + len(h.serverMessages) + len(h.hiddenMessages) + len(h.aggregatedMessages) -} - -func (h *ServerHandler) allLinesSent() { - defer h.internalClose() for i := 0; i < 3; i++ { - if h.numUnsentMessages() == 0 { + if unsentMessages() == 0 { logger.Debug(h.user, "All lines sent") return } @@ -351,142 +301,43 @@ func (h *ServerHandler) allLinesSent() { time.Sleep(time.Second) } - logger.Warn(h.user, "Some lines remain unsent", h.numUnsentMessages()) + logger.Warn(h.user, "Some lines remain unsent", unsentMessages()) } -// Handler decides to shutdown the connection, not the server itself. -func (h *ServerHandler) internalClose() { - select { - case h.hiddenMessages <- "syn close connection": - case <-time.After(time.Second * 5): - logger.Debug(h.user, "Not waiting for ack close connection") - close(h.stopTimeout) - return - } +func (h *ServerHandler) shutdown() { + logger.Debug(h.user, "shutdown()") + h.flush() + + go func() { + select { + case h.serverMessageC() <- ".syn close connection": + case <-h.ctx.Done(): + } + }() select { - case <-h.Wait(): + case <-h.ackCloseReceived: case <-time.After(time.Second * 5): - logger.Debug(h.user, "Not waiting for ack close connection") - close(h.stopTimeout) - } -} - -func (h *ServerHandler) handleCommand(commandStr string) { - logger.Info(h.user, commandStr) - - args := strings.Split(commandStr, " ") - argc := len(args) - - logger.Debug(h.user, "Received command", commandStr, argc, args) - - if h.user.Name == config.ControlUser { - h.handleControlCommand(argc, args) - return + logger.Debug(h.user, "Shutdown timeout reached, enforcing shutdown") + case <-h.ctx.Done(): } - h.handleUserCommand(argc, args) -} - -// Special (restricted) set of commands for anonymous ControlUser access. -func (h *ServerHandler) handleControlCommand(argc int, args []string) { - switch args[0] { - case "ping": - h.send(h.hiddenMessages, "pong") - case "debug": - h.send(h.serverMessages, logger.Debug(h.user, "Receiving debug command", argc, args)) - default: - logger.Warn(h.user, "Received unknown command", argc, args) - } -} - -// Commands for authed users. -func (h *ServerHandler) handleUserCommand(argc int, args []string) { - switch args[0] { - case "grep": - fallthrough - case "cat": - h.handleReadCommand(argc, args, omode.CatClient) - case "tail": - h.handleReadCommand(argc, args, omode.TailClient) - case "map": - h.handleMapCommand(argc, args) - case "ack": - h.handleAckCommand(argc, args) - case "ping": - h.send(h.hiddenMessages, "pong") - case "version": - h.send(h.serverMessages, fmt.Sprintf("Server version is "+version.String())) - case "debug": - h.send(h.serverMessages, logger.Debug(h.user, "Received debug command", argc, args)) + select { + case h.done <- struct{}{}: default: - h.send(h.serverMessages, logger.Warn(h.user, "Received unknown command", argc, args)) } } -func (h *ServerHandler) handleReadCommand(argc int, args []string, mode omode.Mode) { - regex := "." - if argc >= 4 { - regex = args[3] - } - if argc < 3 { - h.send(h.serverMessages, logger.Warn(h.user, commandParseWarning, args, argc)) - return - } - go h.processFileGlob(mode, args[1], regex) +func (h *ServerHandler) incrementActiveReaders() { + // TODO: Use atomic counter variable instead, so we can get rid of the mutex + h.mutex.Lock() + defer h.mutex.Unlock() + h.activeReaders++ } +func (h *ServerHandler) decrementActiveReaders() int { + h.mutex.Lock() + defer h.mutex.Unlock() + h.activeReaders-- -func (h *ServerHandler) handleMapCommand(argc int, args []string) { - if argc < 2 { - h.send(h.serverMessages, logger.Warn(h.user, commandParseWarning, args, argc)) - return - } - - queryStr := strings.Join(args[1:], " ") - logger.Info(h.user, "Creating new mapr aggregator", queryStr) - aggregate, err := server.NewAggregate(h.aggregatedMessages, queryStr) - - if err != nil { - h.send(h.serverMessages, logger.Error(h.user, err)) - return - } - - h.aggregate = aggregate -} - -func (h *ServerHandler) handleAckCommand(argc int, args []string) { - if argc < 3 { - h.send(h.serverMessages, logger.Warn(h.user, commandParseWarning, args, argc)) - return - } - if args[1] == "close" && args[2] == "connection" { - close(h.ackStopReceived) - } -} - -func (h *ServerHandler) send(ch chan<- string, message string) { - select { - case ch <- message: - case <-h.stop: - } -} - -// Wait (block) until server handler is closed or a timeout has exceeded. -func (h *ServerHandler) Wait() <-chan struct{} { - wait := make(chan struct{}) - - go func() { - select { - case <-h.ackStopReceived: - logger.Debug(h.user, "Closing wait channel due to ACK stop received") - close(wait) - case <-h.stopTimeout: - logger.Debug(h.user, "Closing wait channel due to wait timeout") - close(wait) - case <-h.stop: - logger.Debug(h.user, "Closing wait channel due to stop") - } - }() - - return wait + return h.activeReaders } diff --git a/internal/server/server.go b/internal/server/server.go index 27a98f5..42eb74c 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1,13 +1,14 @@ package server import ( + "context" "errors" "fmt" "io" "net" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/server/handlers" "github.com/mimecast/dtail/internal/ssh/server" user "github.com/mimecast/dtail/internal/user/server" @@ -26,8 +27,6 @@ type Server struct { catLimiterCh chan struct{} // To control the max amount of concurrent tails tailLimiterCh chan struct{} - // Ask to shutdown the server - stop chan struct{} } // New returns a new server. @@ -38,7 +37,6 @@ func New() *Server { sshServerConfig: &gossh.ServerConfig{}, catLimiterCh: make(chan struct{}, config.Server.MaxConcurrentCats), tailLimiterCh: make(chan struct{}, config.Server.MaxConcurrentTails), - stop: make(chan struct{}), } s.sshServerConfig.PasswordCallback = s.controlUserCallback @@ -54,7 +52,7 @@ func New() *Server { } // Start the server. -func (s *Server) Start() int { +func (s *Server) Start(ctx context.Context) int { logger.Info("Starting server") bindAt := fmt.Sprintf("%s:%d", config.Server.SSHBindAddress, config.Common.SSHPort) @@ -64,7 +62,7 @@ func (s *Server) Start() int { logger.FatalExit("Failed to open listening TCP socket", err) } - go s.stats.periodicLogServerStats(s.stop) + go s.stats.periodicLogServerStats(ctx) for { conn, err := listener.Accept() // Blocking @@ -79,11 +77,11 @@ func (s *Server) Start() int { continue } - go s.handleConnection(conn) + go s.handleConnection(ctx, conn) } } -func (s *Server) handleConnection(conn net.Conn) { +func (s *Server) handleConnection(ctx context.Context, conn net.Conn) { logger.Info("Handling connection") sshConn, chans, reqs, err := gossh.NewServerConn(conn, s.sshServerConfig) @@ -96,11 +94,11 @@ func (s *Server) handleConnection(conn net.Conn) { go gossh.DiscardRequests(reqs) for newChannel := range chans { - go s.handleChannel(sshConn, newChannel) + go s.handleChannel(ctx, sshConn, newChannel) } } -func (s *Server) handleChannel(sshConn gossh.Conn, newChannel gossh.NewChannel) { +func (s *Server) handleChannel(ctx context.Context, sshConn gossh.Conn, newChannel gossh.NewChannel) { user := user.New(sshConn.User(), sshConn.RemoteAddr().String()) logger.Info(user, "Invoking channel handler") @@ -117,13 +115,13 @@ func (s *Server) handleChannel(sshConn gossh.Conn, newChannel gossh.NewChannel) return } - if err := s.handleRequests(sshConn, requests, channel, user); err != nil { + if err := s.handleRequests(ctx, sshConn, requests, channel, user); err != nil { logger.Error(user, err) sshConn.Close() } } -func (s *Server) handleRequests(sshConn gossh.Conn, in <-chan *gossh.Request, channel gossh.Channel, user *user.User) error { +func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-chan *gossh.Request, channel gossh.Channel, user *user.User) error { logger.Info(user, "Invoking request handler") for req := range in { @@ -132,50 +130,50 @@ func (s *Server) handleRequests(sshConn gossh.Conn, in <-chan *gossh.Request, ch switch req.Type { case "shell": + handlerCtx, cancel := context.WithCancel(ctx) + var handler handlers.Handler + var done <-chan struct{} + switch user.Name { case config.ControlUser: - handler = handlers.NewControlHandler(user) + handler, done = handlers.NewControlHandler(handlerCtx, user) default: - handler = handlers.NewServerHandler(user, s.catLimiterCh, s.tailLimiterCh) + handler, done = handlers.NewServerHandler(handlerCtx, user, s.catLimiterCh, s.tailLimiterCh) } - // Bi-directionally connect SSH stream to SSH handler - brokenPipe1 := make(chan struct{}) go func() { - defer close(brokenPipe1) + // Handler finished work, cancel all remaining routines + defer cancel() + <-done + }() + + go func() { + // Broken pipe, cancel + defer cancel() + io.Copy(channel, handler) }() - brokenPipe2 := make(chan struct{}) go func() { - defer close(brokenPipe2) + // Broken pipe, cancel + defer cancel() + io.Copy(handler, channel) }() - // Ensure to close all fd's and stop all goroutines once ssh connection terminated go func() { - defer s.stats.decrementConnections() - defer handler.Close() + defer cancel() if err := sshConn.Wait(); err != nil && err != io.EOF { logger.Error(user, err) } + s.stats.decrementConnections() logger.Info(user, "Good bye Mister!") }() - // Close the underlying ssh socket when server shuts down go func() { - select { - case <-s.stop: - logger.Debug(user, "Server initiating shutdown on handler") - case <-handler.Wait(): - logger.Debug(user, "Handler initiating shutdown by its own") - case <-brokenPipe1: - logger.Debug(user, "Broken pipe1") - case <-brokenPipe2: - logger.Debug(user, "Broken pipe2") - } + <-handlerCtx.Done() sshConn.Close() logger.Info(user, "Closed SSH connection") }() @@ -204,9 +202,3 @@ func (*Server) controlUserCallback(c gossh.ConnMetadata, authPayload []byte) (*g return nil, fmt.Errorf("Not authorized") } - -// Stop the server. -func (s *Server) Stop() { - close(s.stop) - s.stats.waitForConnections() -} diff --git a/internal/server/stats.go b/internal/server/stats.go index beb1885..4d661f7 100644 --- a/internal/server/stats.go +++ b/internal/server/stats.go @@ -1,12 +1,14 @@ package server import ( - "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" + "context" "fmt" "runtime" "sync" "time" + + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/logger" ) // Used to collect and display various server stats. @@ -65,12 +67,12 @@ func (s *stats) serverLimitExceeded() error { return nil } -func (s *stats) periodicLogServerStats(stop <-chan struct{}) { +func (s *stats) periodicLogServerStats(ctx context.Context) { for { select { case <-time.NewTimer(time.Second * 10).C: s.logServerStats() - case <-stop: + case <-ctx.Done(): return } } diff --git a/internal/ssh/client/authmethods.go b/internal/ssh/client/authmethods.go index 3392eb1..967866f 100644 --- a/internal/ssh/client/authmethods.go +++ b/internal/ssh/client/authmethods.go @@ -2,7 +2,7 @@ package client import ( "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/ssh" "os" diff --git a/internal/ssh/client/hostkeycallback.go b/internal/ssh/client/hostkeycallback.go index 4023e59..7ae2396 100644 --- a/internal/ssh/client/hostkeycallback.go +++ b/internal/ssh/client/hostkeycallback.go @@ -2,8 +2,7 @@ package client import ( "bufio" - "github.com/mimecast/dtail/internal/logger" - "github.com/mimecast/dtail/internal/prompt" + "context" "fmt" "net" "os" @@ -11,6 +10,9 @@ import ( "sync" "time" + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/prompt" + "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/knownhosts" ) @@ -116,7 +118,7 @@ func (h *HostKeyCallback) Wrap() ssh.HostKeyCallback { // PromptAddHosts prompts a question to the user whether unknown hosts should // be added to the known hosts or not. -func (h *HostKeyCallback) PromptAddHosts(stop <-chan struct{}) { +func (h *HostKeyCallback) PromptAddHosts(ctx context.Context) { var hosts []unknownHost for { @@ -135,7 +137,7 @@ func (h *HostKeyCallback) PromptAddHosts(stop <-chan struct{}) { h.promptAddHosts(hosts) hosts = []unknownHost{} } - case <-stop: + case <-ctx.Done(): logger.Debug("Stopping goroutine prompting new hosts...") return } diff --git a/internal/ssh/server/hostkey.go b/internal/ssh/server/hostkey.go index 7baa4aa..07790ad 100644 --- a/internal/ssh/server/hostkey.go +++ b/internal/ssh/server/hostkey.go @@ -2,7 +2,7 @@ package server import ( "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/ssh" "io/ioutil" "os" diff --git a/internal/ssh/server/publickeycallback.go b/internal/ssh/server/publickeycallback.go index c6929d7..757def7 100644 --- a/internal/ssh/server/publickeycallback.go +++ b/internal/ssh/server/publickeycallback.go @@ -7,7 +7,7 @@ import ( osUser "os/user" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" user "github.com/mimecast/dtail/internal/user/server" gossh "golang.org/x/crypto/ssh" diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go index 77cc341..3a2e416 100644 --- a/internal/ssh/ssh.go +++ b/internal/ssh/ssh.go @@ -4,9 +4,9 @@ import ( "crypto/rand" "crypto/rsa" "crypto/x509" - "github.com/mimecast/dtail/internal/logger" "encoding/pem" "fmt" + "github.com/mimecast/dtail/internal/io/logger" "io/ioutil" "net" "os" diff --git a/internal/user/name.go b/internal/user/name.go index 5171ec7..28ab0a4 100644 --- a/internal/user/name.go +++ b/internal/user/name.go @@ -2,10 +2,10 @@ package user import ( "os/user" - ) +) - -func Name() string { +// NoRootCheck verifies that the DTail run user is not with UID or GID 0. +func NoRootCheck() { user, err := user.Current() if err != nil { panic(err) @@ -18,7 +18,14 @@ func Name() string { if user.Gid == "0" { panic("Not allowed to run as GID 0") } +} + +// Name of the current run user. +func Name() string { + user, err := user.Current() + if err != nil { + panic(err) + } return user.Username } - diff --git a/internal/user/server/user.go b/internal/user/server/user.go index fad38d8..271a4ac 100644 --- a/internal/user/server/user.go +++ b/internal/user/server/user.go @@ -1,14 +1,15 @@ package server import ( - "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/fs/permissions" - "github.com/mimecast/dtail/internal/logger" "fmt" "os" "path/filepath" "regexp" "strings" + + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/fs/permissions" + "github.com/mimecast/dtail/internal/io/logger" ) const maxLinkDepth int = 100 @@ -37,26 +38,28 @@ func (u *User) String() string { } // HasFilePermission is used to determine whether user is alowed to read a file. -func (u *User) HasFilePermission(filePath string) (hasPermission bool) { +func (u *User) HasFilePermission(filePath, permissionType string) (hasPermission bool) { + logger.Debug(u, filePath, permissionType, "Checking config permissions") + cleanPath, err := filepath.EvalSymlinks(filePath) if err != nil { - logger.Error(u, filePath, "Unable to evaluate symlinks", err) + logger.Error(u, filePath, permissionType, "Unable to evaluate symlinks", err) hasPermission = false return } cleanPath, err = filepath.Abs(cleanPath) if err != nil { - logger.Error(u, cleanPath, "Unable to make file path absolute", err) + logger.Error(u, cleanPath, permissionType, "Unable to make file path absolute", err) hasPermission = false return } if cleanPath != filePath { - logger.Info(u, filePath, cleanPath, "Calculated new clean path from original file path (possibly symlink)") + logger.Info(u, filePath, cleanPath, permissionType, "Calculated new clean path from original file path (possibly symlink)") } - hasPermission, err = u.hasFilePermission(cleanPath) + hasPermission, err = u.hasFilePermission(cleanPath, permissionType) if err != nil { logger.Warn(u, cleanPath, err) } @@ -64,12 +67,12 @@ func (u *User) HasFilePermission(filePath string) (hasPermission bool) { return } -func (u *User) hasFilePermission(cleanPath string) (bool, error) { +func (u *User) hasFilePermission(cleanPath, permissionType string) (bool, error) { // First check file system Linux/UNIX permission. if _, err := permissions.ToRead(u.Name, cleanPath); err != nil { - return false, fmt.Errorf("User without OS file system permissions to read file: '%v'", err) + return false, fmt.Errorf("User without OS file system permissions to read path: '%v'", err) } - logger.Info(u, cleanPath, "User has OS file system permissions to read file") + logger.Info(u, cleanPath, permissionType, "User with OS file system permissions to path") // If file system permission is given, also check permissions // as configured in DTail config file. @@ -84,7 +87,7 @@ func (u *User) hasFilePermission(cleanPath string) (bool, error) { var hasPermission bool var err error - if hasPermission, err = u.iteratePaths(cleanPath); err != nil { + if hasPermission, err = u.iteratePaths(cleanPath, permissionType); err != nil { return false, err } @@ -101,17 +104,28 @@ func (u *User) hasFilePermission(cleanPath string) (bool, error) { return hasPermission, nil } -func (u *User) iteratePaths(cleanPath string) (bool, error) { +func (u *User) iteratePaths(cleanPath, permissionType string) (bool, error) { for _, permission := range u.permissions { + typeStr := "readfiles" // Assume ReadFiles by default. + var regexStr string var negate bool + splitted := strings.Split(permission, ":") + if len(splitted) > 1 { + typeStr = splitted[0] + permission = strings.Join(splitted[1:], ":") + } + + if typeStr != permissionType { + continue + } + + regexStr = permission if strings.HasPrefix(permission, "!") { regexStr = permission[1:] negate = true } - regexStr = permission - negate = false re, err := regexp.Compile(regexStr) if err != nil { diff --git a/internal/version/version.go b/internal/version/version.go index 3a4a5dc..3c057df 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -7,18 +7,20 @@ import ( "github.com/mimecast/dtail/internal/color" ) -// Name of DTail. -const Name = "DTail" - -// Version of DTail. -const Version = "1.1.0" - -// Additional information. -const Additional = "" +const ( + // Name of DTail. + Name string = "DTail" + // Version of DTail. + Version string = "2.0.0" + // Additional information for DTail + Additional string = "" + // ProtocolCompat -ibility version. + ProtocolCompat string = "2" +) // String representation of the DTail version. func String() string { - return fmt.Sprintf("%s v%v %s", Name, Version, Additional) + return fmt.Sprintf("%s %v Protocol %s %s", Name, Version, ProtocolCompat, Additional) } // PaintedString is a prettier string representation of the DTail version. @@ -30,7 +32,7 @@ func PaintedString() string { version := color.Paint(color.Blue, Version) descr := color.Paint(color.Green, Additional) - return fmt.Sprintf("%s %v %s", name, version, descr) + return fmt.Sprintf("%s %v Protocol %s %s", name, version, ProtocolCompat, descr) } // PrintAndExit prints the program version and exists. diff --git a/samples/dtail.json.sample b/samples/dtail.json.sample index 99c0a73..6b713e8 100644 --- a/samples/dtail.json.sample +++ b/samples/dtail.json.sample @@ -10,16 +10,18 @@ "HostKeyBits" : 2048, "Permissions": { "Default": [ - "^/.*$" + "readfiles:^/.*$", + "runcommands:!^/.*$" ], "Users": { "pbuetow": [ - "^/.*$" + "readfiles:^/.*$" + "runcommands:^/.*$" ], "jblake": [ - "^/tmp/foo.log$", - "^/.*$", - "!^/tmp/bar.log$" + "readfiles:^/tmp/foo.log$", + "readfiles:^/.*$", + "readfiles:!^/tmp/bar.log$" ] } } -- cgit v1.2.3