package remote import ( "context" "fmt" "io" "strconv" "strings" "time" "github.com/mimecast/dtail/internal/clients/handlers" "github.com/mimecast/dtail/internal/config" "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/ssh/client" "golang.org/x/crypto/ssh" ) // Connection represents a client connection connection to a single server. type Connection struct { // The remote server's hostname connected to. Server string // The remote server's port connected to. port int // The SSH client configuration used. config *ssh.ClientConfig // The SSH client handler to use. Handler handlers.Handler // DTail commands sent from client to server. When client loses // connection to the server it re-connects automatically and sends the // same commands again. Commands []string // Is it a persistent connection or a one-off? isOneOff bool // To deal with SSH server host keys hostKeyCallback client.HostKeyCallback // To determine if connection throttling has finished or not throttlingDone bool } // NewConnection returns a new connection. func NewConnection(server string, userName string, authMethods []ssh.AuthMethod, hostKeyCallback client.HostKeyCallback) *Connection { logger.Debug(server, "Creating new connection") c := Connection{ hostKeyCallback: hostKeyCallback, config: &ssh.ClientConfig{ User: userName, Auth: authMethods, HostKeyCallback: hostKeyCallback.Wrap(), Timeout: time.Second * 3, }, } c.initServerPort(server) return &c } // NewOneOffConnection creates new one-off connection (only for sending a series of commands and then quit). func NewOneOffConnection(server string, userName string, authMethods []ssh.AuthMethod) *Connection { c := Connection{ config: &ssh.ClientConfig{ User: userName, Auth: authMethods, HostKeyCallback: ssh.InsecureIgnoreHostKey(), }, isOneOff: true, } c.initServerPort(server) return &c } // Attempt to parse the server port address from the provided server FQDN. func (c *Connection) initServerPort(server string) { c.Server = server c.port = config.Common.SSHPort parts := strings.Split(server, ":") if len(parts) == 2 { logger.Debug("Parsing port from hostname", parts) port, err := strconv.Atoi(parts[1]) if err != nil { logger.FatalExit("Unable to parse client port", server, parts, err) } c.Server = parts[0] c.port = port } } // Start the server connection. Build up SSH session and send some DTail commands. func (c *Connection) Start(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) { // Throttle how many connections can be established concurrently (based on ch length) logger.Debug(c.Server, "Throttling connection", len(throttleCh), cap(throttleCh)) select { case throttleCh <- struct{}{}: case <-ctx.Done(): logger.Debug(c.Server, "Not establishing connection as context is done", len(throttleCh), cap(throttleCh)) return } logger.Debug(c.Server, "Throttling says that the connection can be established", len(throttleCh), cap(throttleCh)) go func() { defer func() { if !c.throttlingDone { logger.Debug(c.Server, "Unthrottling connection (1)", len(throttleCh), cap(throttleCh)) c.throttlingDone = true <-throttleCh } cancel() }() if err := c.dial(ctx, cancel, throttleCh, statsCh); err != nil { logger.Warn(c.Server, c.port, err) if c.hostKeyCallback.Untrusted(fmt.Sprintf("%s:%d", c.Server, c.port)) { logger.Debug(c.Server, "Not trusting host") } } }() <-ctx.Done() } // Dail into a new SSH connection. Close connection in case of an error. func (c *Connection) dial(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) error { logger.Debug(c.Server, "Incrementing connection stats") statsCh <- struct{}{} defer func() { logger.Debug(c.Server, "Decrementing connection stats") <-statsCh }() logger.Debug(c.Server, "Dialing into the connection") address := fmt.Sprintf("%s:%d", c.Server, c.port) client, err := ssh.Dial("tcp", address, c.config) if err != nil { return err } defer client.Close() return c.session(ctx, cancel, client, throttleCh) } // Create the SSH session. Close the session in case of an error. func (c *Connection) session(ctx context.Context, cancel context.CancelFunc, client *ssh.Client, throttleCh chan struct{}) error { logger.Debug(c.Server, "session") session, err := client.NewSession() if err != nil { return err } defer session.Close() return c.handle(ctx, cancel, session, throttleCh) } func (c *Connection) handle(ctx context.Context, cancel context.CancelFunc, session *ssh.Session, throttleCh chan struct{}) error { logger.Debug(c.Server, "handle") stdinPipe, err := session.StdinPipe() if err != nil { return err } stdoutPipe, err := session.StdoutPipe() if err != nil { return err } if err := session.Shell(); err != nil { return err } go func() { io.Copy(stdinPipe, c.Handler) cancel() }() go func() { io.Copy(c.Handler, stdoutPipe) cancel() }() go func() { select { case <-c.Handler.Done(): case <-ctx.Done(): } cancel() }() // Send all commands to client. for _, command := range c.Commands { logger.Debug(command) c.Handler.SendMessage(command) } if !c.throttlingDone { logger.Debug(c.Server, "Unthrottling connection (2)", len(throttleCh), cap(throttleCh)) c.throttlingDone = true <-throttleCh } <-ctx.Done() c.Handler.Shutdown() return nil }