diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/clients/args.go | 1 | ||||
| -rw-r--r-- | internal/clients/runclient.go | 5 | ||||
| -rw-r--r-- | internal/config/common.go | 14 | ||||
| -rw-r--r-- | internal/config/config.go | 2 | ||||
| -rw-r--r-- | internal/io/run/run.go | 35 | ||||
| -rw-r--r-- | internal/pprof/pprof.go | 18 | ||||
| -rw-r--r-- | internal/server/handlers/runcommand.go | 4 | ||||
| -rw-r--r-- | internal/server/handlers/serverhandler.go | 73 | ||||
| -rw-r--r-- | internal/server/server.go | 46 |
9 files changed, 140 insertions, 58 deletions
diff --git a/internal/clients/args.go b/internal/clients/args.go index f55ce90..4f3c86a 100644 --- a/internal/clients/args.go +++ b/internal/clients/args.go @@ -16,6 +16,7 @@ type Args struct { TrustAllHosts bool Discovery string ConnectionsPerCPU int + Timeout int SSHAuthMethods []gossh.AuthMethod SSHHostKeyCallback gossh.HostKeyCallback } diff --git a/internal/clients/runclient.go b/internal/clients/runclient.go index 7a62fcc..e3be616 100644 --- a/internal/clients/runclient.go +++ b/internal/clients/runclient.go @@ -35,6 +35,11 @@ func (c RunClient) makeHandler(server string) handlers.Handler { func (c RunClient) makeCommands() (commands []string) { // Send "run COMMAND" to server! + if c.Timeout > 0 { + commands = append(commands, fmt.Sprintf("timeout %d %s %s", c.Timeout, c.Mode.String(), c.What)) + return + } + commands = append(commands, fmt.Sprintf("%s %s", c.Mode.String(), c.What)) return } diff --git a/internal/config/common.go b/internal/config/common.go index f0f1a94..a09a3ad 100644 --- a/internal/config/common.go +++ b/internal/config/common.go @@ -9,13 +9,10 @@ type CommonConfig struct { // The log strategy to use, one of // stdout: only log to stdout (useful when used with systemd) // daily: create a log file for every day - LogStrategy string - LogDir string - CacheDir string - TmpDir string `json:",omitempty"` - PProfEnable bool `json:",omitempty"` - PProfPort int `json:",omitempty"` - PProfBindAddress string `json:",omitempty"` + LogStrategy string + LogDir string + CacheDir string + TmpDir string `json:",omitempty"` } // Create a new default configuration. @@ -28,8 +25,5 @@ func newDefaultCommonConfig() *CommonConfig { LogDir: "log", CacheDir: "cache", TmpDir: "/tmp", - PProfEnable: false, - PProfPort: 6060, - PProfBindAddress: "0.0.0.0", } } diff --git a/internal/config/config.go b/internal/config/config.go index 239bfd1..7241276 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -9,7 +9,7 @@ import ( // ControlUser is used for various DTail specific operations. const ControlUser string = "DTAIL-CONTROL-USER" -// ScheduledUser is used for scheduled queries. +// ScheduleUser is used for scheduled queries. const ScheduleUser string = "DTAIL-SCHEDULE-USER" // Client holds a DTail client configuration. diff --git a/internal/io/run/run.go b/internal/io/run/run.go index 5951cde..f9cd980 100644 --- a/internal/io/run/run.go +++ b/internal/io/run/run.go @@ -6,6 +6,7 @@ import ( "io" "os/exec" "sync" + "syscall" "time" "github.com/mimecast/dtail/internal/io/line" @@ -14,16 +15,18 @@ import ( // Run is for execute a command. type Run struct { - command string - args []string - cmd *exec.Cmd + command string + args []string + cmd *exec.Cmd + pgroupKilled chan struct{} } // New returns a new command runner. func New(command string, args []string) Run { return Run{ - command: command, - args: args, + command: command, + args: args, + pgroupKilled: make(chan struct{}), } } @@ -42,6 +45,8 @@ func (r Run) Start(ctx context.Context, lines chan<- line.Line) (pid int, ec int logger.Debug(r.command) r.cmd = exec.CommandContext(ctx, r.command) } + // Create a new process group, so that kill() will only kill this command + pgroup. + r.cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} stdoutPipe, myErr := r.cmd.StdoutPipe() if err != nil { @@ -64,6 +69,7 @@ func (r Run) Start(ctx context.Context, lines chan<- line.Line) (pid int, ec int pid = r.cmd.Process.Pid ec = 0 } + go r.killPgroup(ctx, pid) var wg sync.WaitGroup wg.Add(2) @@ -103,3 +109,22 @@ func (r Run) pipeToLines(done chan struct{}, wg *sync.WaitGroup, pid int, reader time.Sleep(time.Millisecond * 10) } } + +// PgroupKilled identifies whether all subprocesses are killed or not. +func (r Run) PgroupKilled() <-chan struct{} { + return r.pgroupKilled +} + +func (r Run) killPgroup(ctx context.Context, pid int) { + if pid == -1 { + close(r.pgroupKilled) + return + } + + if pgid, err := syscall.Getpgid(pid); err == nil { + // Kill process group when done + <-ctx.Done() + syscall.Kill(-pgid, syscall.SIGKILL) + close(r.pgroupKilled) + } +} diff --git a/internal/pprof/pprof.go b/internal/pprof/pprof.go deleted file mode 100644 index c6d11ca..0000000 --- a/internal/pprof/pprof.go +++ /dev/null @@ -1,18 +0,0 @@ -package pprof - -import ( - "fmt" - "net/http" - _ "net/http" - _ "net/http/pprof" - - "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/io/logger" -) - -// Start the profiler HTTP server. -func Start() { - bindAddr := fmt.Sprintf("%s:%d", config.Common.PProfBindAddress, config.Common.PProfPort) - logger.Info("Starting PProf server", bindAddr) - go http.ListenAndServe(bindAddr, nil) -} diff --git a/internal/server/handlers/runcommand.go b/internal/server/handlers/runcommand.go index 120c402..95db52f 100644 --- a/internal/server/handlers/runcommand.go +++ b/internal/server/handlers/runcommand.go @@ -92,4 +92,8 @@ func (r runCommand) start(ctx context.Context, command string) { r.server.sendServerMessage(fmt.Sprintf(".run exitstatus %d", ec)) r.server.sendServerMessage(logger.Info(fmt.Sprintf("Process %d exited with status %d", pid, ec))) + + logger.Debug(r.server.user, "Waiting for Pgroup to be killed") + <-r.run.PgroupKilled() + logger.Debug(r.server.user, "Pgroup killed") } diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go index b840c77..946ae83 100644 --- a/internal/server/handlers/serverhandler.go +++ b/internal/server/handlers/serverhandler.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "os" + "strconv" "strings" "sync" "time" @@ -37,28 +38,31 @@ type ServerHandler struct { payload []byte hostname string user *user.User - catLimiter chan struct{} - tailLimiter chan struct{} - ackCloseReceived chan struct{} - ctx context.Context - done chan struct{} - activeReaders int + // TODO: Move all these channels into a separate struct for readability! + catLimiter chan struct{} + tailLimiter chan struct{} + globalServerWaitFor chan struct{} + ackCloseReceived chan struct{} + ctx context.Context + done chan struct{} + activeReaders int } // NewServerHandler returns the server handler. -func NewServerHandler(ctx context.Context, user *user.User, catLimiter chan struct{}, tailLimiter chan struct{}) (*ServerHandler, <-chan struct{}) { +func NewServerHandler(ctx context.Context, user *user.User, catLimiter, tailLimiter, globalServerWaitFor chan struct{}) (*ServerHandler, <-chan struct{}) { h := ServerHandler{ - ctx: ctx, - done: make(chan struct{}), - mutex: &sync.Mutex{}, - lines: make(chan line.Line, 100), - serverMessages: make(chan string, 10), - aggregatedMessages: make(chan string, 10), - ackCloseReceived: make(chan struct{}), - catLimiter: catLimiter, - tailLimiter: tailLimiter, - regex: ".", - user: user, + ctx: ctx, + done: make(chan struct{}), + mutex: &sync.Mutex{}, + lines: make(chan line.Line, 100), + serverMessages: make(chan string, 10), + aggregatedMessages: make(chan string, 10), + ackCloseReceived: make(chan struct{}), + catLimiter: catLimiter, + tailLimiter: tailLimiter, + globalServerWaitFor: globalServerWaitFor, + regex: ".", + user: user, } fqdn, err := os.Hostname() @@ -135,6 +139,7 @@ func (h *ServerHandler) Write(p []byte) (n int, err error) { func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) { logger.Debug(h.user, commandStr) + var timeout time.Duration args, argc, err := h.handleProtocolVersion(strings.Split(commandStr, " ")) if err != nil { @@ -148,11 +153,28 @@ func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) { return } + args, argc, timeout, err = h.handleTimeout(args, argc) + if err != nil { + h.send(h.serverMessages, logger.Error(h.user, err)) + return + } + if h.user.Name == config.ControlUser { h.handleControlCommand(argc, args) return } + if timeout > 0 { + logger.Debug("Command with timeout context", argc, args, timeout) + commandCtx, cancel := context.WithTimeout(ctx, timeout) + go func() { + <-commandCtx.Done() + cancel() + }() + h.handleUserCommand(commandCtx, argc, args) + return + } + h.handleUserCommand(ctx, argc, args) } @@ -191,6 +213,16 @@ func (h *ServerHandler) handleBase64(args []string, argc int) ([]string, int, er return args, argc, nil } +func (h *ServerHandler) handleTimeout(args []string, argc int) ([]string, int, time.Duration, error) { + if argc <= 2 || args[0] != "timeout" { + // No timeout specified + return args, argc, time.Duration(0) * time.Second, nil + } + + timeout, err := strconv.Atoi(args[1]) + return args[2:], argc - 2, time.Duration(timeout) * time.Second, err +} + func (h *ServerHandler) handleControlCommand(argc int, args []string) { switch args[0] { case "debug": @@ -241,8 +273,10 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args [] case "run": command := newRunCommand(h) h.incrementActiveReaders() + go func() { h.globalServerWaitFor <- struct{}{} }() go func() { command.Start(ctx, argc, args) + <-h.globalServerWaitFor if h.decrementActiveReaders() == 0 { h.shutdown() } @@ -332,12 +366,13 @@ func (h *ServerHandler) incrementActiveReaders() { // TODO: Use atomic counter variable instead, so we can get rid of the mutex h.mutex.Lock() defer h.mutex.Unlock() + h.activeReaders++ } func (h *ServerHandler) decrementActiveReaders() int { h.mutex.Lock() defer h.mutex.Unlock() - h.activeReaders-- + h.activeReaders-- return h.activeReaders } diff --git a/internal/server/server.go b/internal/server/server.go index 9b9d8a0..9314468 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net" + "time" "github.com/mimecast/dtail/internal/config" "github.com/mimecast/dtail/internal/io/logger" @@ -24,11 +25,13 @@ type Server struct { // SSH server configuration. sshServerConfig *gossh.ServerConfig // To control the max amount of concurrent cats (which can cause a lot of I/O on the server) - catLimiterCh chan struct{} + catLimiter chan struct{} // To control the max amount of concurrent tails - tailLimiterCh chan struct{} + tailLimiter chan struct{} // To run scheduled tasks (if configured) sched *scheduler + // Wait counter, e.g. there might be still subprocesses (forked by drun) to be killed. + shutdownWaitFor chan struct{} } // New returns a new server. @@ -37,8 +40,9 @@ func New() *Server { s := Server{ sshServerConfig: &gossh.ServerConfig{}, - catLimiterCh: make(chan struct{}, config.Server.MaxConcurrentCats), - tailLimiterCh: make(chan struct{}, config.Server.MaxConcurrentTails), + catLimiter: make(chan struct{}, config.Server.MaxConcurrentCats), + tailLimiter: make(chan struct{}, config.Server.MaxConcurrentTails), + shutdownWaitFor: make(chan struct{}, 1000), sched: newScheduler(), } @@ -60,6 +64,7 @@ func (s *Server) Start(ctx context.Context) int { bindAt := fmt.Sprintf("%s:%d", config.Server.SSHBindAddress, config.Common.SSHPort) logger.Info("Binding server", bindAt) + listener, err := net.Listen("tcp", bindAt) if err != nil { logger.FatalExit("Failed to open listening TCP socket", err) @@ -67,10 +72,40 @@ func (s *Server) Start(ctx context.Context) int { go s.stats.start(ctx) go s.sched.start(ctx) + go s.listenerLoop(ctx, listener) + + select { + case <-ctx.Done(): + // Wait until all commands/jobs/children are no more! + s.wait() + } + + // For future use. + return 0 +} + +func (s *Server) wait() { + for { + num := len(s.shutdownWaitFor) + logger.Debug("Waiting for stuff to finish", num) + if num <= 0 { + return + } + time.Sleep(time.Second) + } +} + +func (s *Server) listenerLoop(ctx context.Context, listener net.Listener) { + logger.Debug("Starting listener loop") for { conn, err := listener.Accept() // Blocking if err != nil { + select { + case <-ctx.Done(): + return + default: + } logger.Error("Failed to accept incoming connection", err) continue } @@ -143,12 +178,13 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-ch case config.ControlUser: handler, done = handlers.NewControlHandler(handlerCtx, user) default: - handler, done = handlers.NewServerHandler(handlerCtx, user, s.catLimiterCh, s.tailLimiterCh) + handler, done = handlers.NewServerHandler(handlerCtx, user, s.catLimiter, s.tailLimiter, s.shutdownWaitFor) } go func() { // Handler finished work, cancel all remaining routines defer cancel() + <-done }() |
