summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/clients/args.go1
-rw-r--r--internal/clients/runclient.go5
-rw-r--r--internal/config/common.go14
-rw-r--r--internal/config/config.go2
-rw-r--r--internal/io/run/run.go35
-rw-r--r--internal/pprof/pprof.go18
-rw-r--r--internal/server/handlers/runcommand.go4
-rw-r--r--internal/server/handlers/serverhandler.go73
-rw-r--r--internal/server/server.go46
9 files changed, 140 insertions, 58 deletions
diff --git a/internal/clients/args.go b/internal/clients/args.go
index f55ce90..4f3c86a 100644
--- a/internal/clients/args.go
+++ b/internal/clients/args.go
@@ -16,6 +16,7 @@ type Args struct {
TrustAllHosts bool
Discovery string
ConnectionsPerCPU int
+ Timeout int
SSHAuthMethods []gossh.AuthMethod
SSHHostKeyCallback gossh.HostKeyCallback
}
diff --git a/internal/clients/runclient.go b/internal/clients/runclient.go
index 7a62fcc..e3be616 100644
--- a/internal/clients/runclient.go
+++ b/internal/clients/runclient.go
@@ -35,6 +35,11 @@ func (c RunClient) makeHandler(server string) handlers.Handler {
func (c RunClient) makeCommands() (commands []string) {
// Send "run COMMAND" to server!
+ if c.Timeout > 0 {
+ commands = append(commands, fmt.Sprintf("timeout %d %s %s", c.Timeout, c.Mode.String(), c.What))
+ return
+ }
+
commands = append(commands, fmt.Sprintf("%s %s", c.Mode.String(), c.What))
return
}
diff --git a/internal/config/common.go b/internal/config/common.go
index f0f1a94..a09a3ad 100644
--- a/internal/config/common.go
+++ b/internal/config/common.go
@@ -9,13 +9,10 @@ type CommonConfig struct {
// The log strategy to use, one of
// stdout: only log to stdout (useful when used with systemd)
// daily: create a log file for every day
- LogStrategy string
- LogDir string
- CacheDir string
- TmpDir string `json:",omitempty"`
- PProfEnable bool `json:",omitempty"`
- PProfPort int `json:",omitempty"`
- PProfBindAddress string `json:",omitempty"`
+ LogStrategy string
+ LogDir string
+ CacheDir string
+ TmpDir string `json:",omitempty"`
}
// Create a new default configuration.
@@ -28,8 +25,5 @@ func newDefaultCommonConfig() *CommonConfig {
LogDir: "log",
CacheDir: "cache",
TmpDir: "/tmp",
- PProfEnable: false,
- PProfPort: 6060,
- PProfBindAddress: "0.0.0.0",
}
}
diff --git a/internal/config/config.go b/internal/config/config.go
index 239bfd1..7241276 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -9,7 +9,7 @@ import (
// ControlUser is used for various DTail specific operations.
const ControlUser string = "DTAIL-CONTROL-USER"
-// ScheduledUser is used for scheduled queries.
+// ScheduleUser is used for scheduled queries.
const ScheduleUser string = "DTAIL-SCHEDULE-USER"
// Client holds a DTail client configuration.
diff --git a/internal/io/run/run.go b/internal/io/run/run.go
index 5951cde..f9cd980 100644
--- a/internal/io/run/run.go
+++ b/internal/io/run/run.go
@@ -6,6 +6,7 @@ import (
"io"
"os/exec"
"sync"
+ "syscall"
"time"
"github.com/mimecast/dtail/internal/io/line"
@@ -14,16 +15,18 @@ import (
// Run is for execute a command.
type Run struct {
- command string
- args []string
- cmd *exec.Cmd
+ command string
+ args []string
+ cmd *exec.Cmd
+ pgroupKilled chan struct{}
}
// New returns a new command runner.
func New(command string, args []string) Run {
return Run{
- command: command,
- args: args,
+ command: command,
+ args: args,
+ pgroupKilled: make(chan struct{}),
}
}
@@ -42,6 +45,8 @@ func (r Run) Start(ctx context.Context, lines chan<- line.Line) (pid int, ec int
logger.Debug(r.command)
r.cmd = exec.CommandContext(ctx, r.command)
}
+ // Create a new process group, so that kill() will only kill this command + pgroup.
+ r.cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
stdoutPipe, myErr := r.cmd.StdoutPipe()
if err != nil {
@@ -64,6 +69,7 @@ func (r Run) Start(ctx context.Context, lines chan<- line.Line) (pid int, ec int
pid = r.cmd.Process.Pid
ec = 0
}
+ go r.killPgroup(ctx, pid)
var wg sync.WaitGroup
wg.Add(2)
@@ -103,3 +109,22 @@ func (r Run) pipeToLines(done chan struct{}, wg *sync.WaitGroup, pid int, reader
time.Sleep(time.Millisecond * 10)
}
}
+
+// PgroupKilled identifies whether all subprocesses are killed or not.
+func (r Run) PgroupKilled() <-chan struct{} {
+ return r.pgroupKilled
+}
+
+func (r Run) killPgroup(ctx context.Context, pid int) {
+ if pid == -1 {
+ close(r.pgroupKilled)
+ return
+ }
+
+ if pgid, err := syscall.Getpgid(pid); err == nil {
+ // Kill process group when done
+ <-ctx.Done()
+ syscall.Kill(-pgid, syscall.SIGKILL)
+ close(r.pgroupKilled)
+ }
+}
diff --git a/internal/pprof/pprof.go b/internal/pprof/pprof.go
deleted file mode 100644
index c6d11ca..0000000
--- a/internal/pprof/pprof.go
+++ /dev/null
@@ -1,18 +0,0 @@
-package pprof
-
-import (
- "fmt"
- "net/http"
- _ "net/http"
- _ "net/http/pprof"
-
- "github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/io/logger"
-)
-
-// Start the profiler HTTP server.
-func Start() {
- bindAddr := fmt.Sprintf("%s:%d", config.Common.PProfBindAddress, config.Common.PProfPort)
- logger.Info("Starting PProf server", bindAddr)
- go http.ListenAndServe(bindAddr, nil)
-}
diff --git a/internal/server/handlers/runcommand.go b/internal/server/handlers/runcommand.go
index 120c402..95db52f 100644
--- a/internal/server/handlers/runcommand.go
+++ b/internal/server/handlers/runcommand.go
@@ -92,4 +92,8 @@ func (r runCommand) start(ctx context.Context, command string) {
r.server.sendServerMessage(fmt.Sprintf(".run exitstatus %d", ec))
r.server.sendServerMessage(logger.Info(fmt.Sprintf("Process %d exited with status %d", pid, ec)))
+
+ logger.Debug(r.server.user, "Waiting for Pgroup to be killed")
+ <-r.run.PgroupKilled()
+ logger.Debug(r.server.user, "Pgroup killed")
}
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go
index b840c77..946ae83 100644
--- a/internal/server/handlers/serverhandler.go
+++ b/internal/server/handlers/serverhandler.go
@@ -7,6 +7,7 @@ import (
"fmt"
"io"
"os"
+ "strconv"
"strings"
"sync"
"time"
@@ -37,28 +38,31 @@ type ServerHandler struct {
payload []byte
hostname string
user *user.User
- catLimiter chan struct{}
- tailLimiter chan struct{}
- ackCloseReceived chan struct{}
- ctx context.Context
- done chan struct{}
- activeReaders int
+ // TODO: Move all these channels into a separate struct for readability!
+ catLimiter chan struct{}
+ tailLimiter chan struct{}
+ globalServerWaitFor chan struct{}
+ ackCloseReceived chan struct{}
+ ctx context.Context
+ done chan struct{}
+ activeReaders int
}
// NewServerHandler returns the server handler.
-func NewServerHandler(ctx context.Context, user *user.User, catLimiter chan struct{}, tailLimiter chan struct{}) (*ServerHandler, <-chan struct{}) {
+func NewServerHandler(ctx context.Context, user *user.User, catLimiter, tailLimiter, globalServerWaitFor chan struct{}) (*ServerHandler, <-chan struct{}) {
h := ServerHandler{
- ctx: ctx,
- done: make(chan struct{}),
- mutex: &sync.Mutex{},
- lines: make(chan line.Line, 100),
- serverMessages: make(chan string, 10),
- aggregatedMessages: make(chan string, 10),
- ackCloseReceived: make(chan struct{}),
- catLimiter: catLimiter,
- tailLimiter: tailLimiter,
- regex: ".",
- user: user,
+ ctx: ctx,
+ done: make(chan struct{}),
+ mutex: &sync.Mutex{},
+ lines: make(chan line.Line, 100),
+ serverMessages: make(chan string, 10),
+ aggregatedMessages: make(chan string, 10),
+ ackCloseReceived: make(chan struct{}),
+ catLimiter: catLimiter,
+ tailLimiter: tailLimiter,
+ globalServerWaitFor: globalServerWaitFor,
+ regex: ".",
+ user: user,
}
fqdn, err := os.Hostname()
@@ -135,6 +139,7 @@ func (h *ServerHandler) Write(p []byte) (n int, err error) {
func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) {
logger.Debug(h.user, commandStr)
+ var timeout time.Duration
args, argc, err := h.handleProtocolVersion(strings.Split(commandStr, " "))
if err != nil {
@@ -148,11 +153,28 @@ func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) {
return
}
+ args, argc, timeout, err = h.handleTimeout(args, argc)
+ if err != nil {
+ h.send(h.serverMessages, logger.Error(h.user, err))
+ return
+ }
+
if h.user.Name == config.ControlUser {
h.handleControlCommand(argc, args)
return
}
+ if timeout > 0 {
+ logger.Debug("Command with timeout context", argc, args, timeout)
+ commandCtx, cancel := context.WithTimeout(ctx, timeout)
+ go func() {
+ <-commandCtx.Done()
+ cancel()
+ }()
+ h.handleUserCommand(commandCtx, argc, args)
+ return
+ }
+
h.handleUserCommand(ctx, argc, args)
}
@@ -191,6 +213,16 @@ func (h *ServerHandler) handleBase64(args []string, argc int) ([]string, int, er
return args, argc, nil
}
+func (h *ServerHandler) handleTimeout(args []string, argc int) ([]string, int, time.Duration, error) {
+ if argc <= 2 || args[0] != "timeout" {
+ // No timeout specified
+ return args, argc, time.Duration(0) * time.Second, nil
+ }
+
+ timeout, err := strconv.Atoi(args[1])
+ return args[2:], argc - 2, time.Duration(timeout) * time.Second, err
+}
+
func (h *ServerHandler) handleControlCommand(argc int, args []string) {
switch args[0] {
case "debug":
@@ -241,8 +273,10 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
case "run":
command := newRunCommand(h)
h.incrementActiveReaders()
+ go func() { h.globalServerWaitFor <- struct{}{} }()
go func() {
command.Start(ctx, argc, args)
+ <-h.globalServerWaitFor
if h.decrementActiveReaders() == 0 {
h.shutdown()
}
@@ -332,12 +366,13 @@ func (h *ServerHandler) incrementActiveReaders() {
// TODO: Use atomic counter variable instead, so we can get rid of the mutex
h.mutex.Lock()
defer h.mutex.Unlock()
+
h.activeReaders++
}
func (h *ServerHandler) decrementActiveReaders() int {
h.mutex.Lock()
defer h.mutex.Unlock()
- h.activeReaders--
+ h.activeReaders--
return h.activeReaders
}
diff --git a/internal/server/server.go b/internal/server/server.go
index 9b9d8a0..9314468 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net"
+ "time"
"github.com/mimecast/dtail/internal/config"
"github.com/mimecast/dtail/internal/io/logger"
@@ -24,11 +25,13 @@ type Server struct {
// SSH server configuration.
sshServerConfig *gossh.ServerConfig
// To control the max amount of concurrent cats (which can cause a lot of I/O on the server)
- catLimiterCh chan struct{}
+ catLimiter chan struct{}
// To control the max amount of concurrent tails
- tailLimiterCh chan struct{}
+ tailLimiter chan struct{}
// To run scheduled tasks (if configured)
sched *scheduler
+ // Wait counter, e.g. there might be still subprocesses (forked by drun) to be killed.
+ shutdownWaitFor chan struct{}
}
// New returns a new server.
@@ -37,8 +40,9 @@ func New() *Server {
s := Server{
sshServerConfig: &gossh.ServerConfig{},
- catLimiterCh: make(chan struct{}, config.Server.MaxConcurrentCats),
- tailLimiterCh: make(chan struct{}, config.Server.MaxConcurrentTails),
+ catLimiter: make(chan struct{}, config.Server.MaxConcurrentCats),
+ tailLimiter: make(chan struct{}, config.Server.MaxConcurrentTails),
+ shutdownWaitFor: make(chan struct{}, 1000),
sched: newScheduler(),
}
@@ -60,6 +64,7 @@ func (s *Server) Start(ctx context.Context) int {
bindAt := fmt.Sprintf("%s:%d", config.Server.SSHBindAddress, config.Common.SSHPort)
logger.Info("Binding server", bindAt)
+
listener, err := net.Listen("tcp", bindAt)
if err != nil {
logger.FatalExit("Failed to open listening TCP socket", err)
@@ -67,10 +72,40 @@ func (s *Server) Start(ctx context.Context) int {
go s.stats.start(ctx)
go s.sched.start(ctx)
+ go s.listenerLoop(ctx, listener)
+
+ select {
+ case <-ctx.Done():
+ // Wait until all commands/jobs/children are no more!
+ s.wait()
+ }
+
+ // For future use.
+ return 0
+}
+
+func (s *Server) wait() {
+ for {
+ num := len(s.shutdownWaitFor)
+ logger.Debug("Waiting for stuff to finish", num)
+ if num <= 0 {
+ return
+ }
+ time.Sleep(time.Second)
+ }
+}
+
+func (s *Server) listenerLoop(ctx context.Context, listener net.Listener) {
+ logger.Debug("Starting listener loop")
for {
conn, err := listener.Accept() // Blocking
if err != nil {
+ select {
+ case <-ctx.Done():
+ return
+ default:
+ }
logger.Error("Failed to accept incoming connection", err)
continue
}
@@ -143,12 +178,13 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-ch
case config.ControlUser:
handler, done = handlers.NewControlHandler(handlerCtx, user)
default:
- handler, done = handlers.NewServerHandler(handlerCtx, user, s.catLimiterCh, s.tailLimiterCh)
+ handler, done = handlers.NewServerHandler(handlerCtx, user, s.catLimiter, s.tailLimiter, s.shutdownWaitFor)
}
go func() {
// Handler finished work, cancel all remaining routines
defer cancel()
+
<-done
}()