diff options
| author | Paul Buetow <paul@buetow.org> | 2026-02-15 08:28:43 +0200 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-02-15 08:29:45 +0200 |
| commit | bbbb7461d19e611e6fab3f24edd5f8e0d2d45b1e (patch) | |
| tree | bee4b9e07bafb2810f0e2cc2db4fb34e7154b2d4 /internal | |
| parent | d89b9e6760e2aadf9779faa6f23678f67c731e1e (diff) | |
refactor: implement context-aware network dialing
Modernize network dialing to use Go's context-aware patterns for better
cancellation support and connection reliability.
Changes:
- Update Go version from 1.24 to 1.25 in go.mod
- Replace ssh.Dial with net.Dialer.DialContext + ssh.NewClientConn
for SSH client connections in serverconnection.go
- Add TCP KeepAlive (30s) for SSH connection health monitoring
- Implement context-aware dialing for SSH agent connections in ssh.go
- Improve error messages to distinguish dial vs SSH handshake failures
- Update AGENTS.md with integration test requirements
Benefits:
- Context cancellation now properly affects connection establishment
- TCP KeepAlive prevents silent connection failures
- Better integration with Go's cancellation patterns
- Improved reliability for distributed systems
All integration tests pass with race detection enabled.
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/clients/connectors/serverconnection.go | 23 | ||||
| -rw-r--r-- | internal/ssh/ssh.go | 19 |
2 files changed, 36 insertions, 6 deletions
diff --git a/internal/clients/connectors/serverconnection.go b/internal/clients/connectors/serverconnection.go index 34d3997..d114d06 100644 --- a/internal/clients/connectors/serverconnection.go +++ b/internal/clients/connectors/serverconnection.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "net" "strconv" "strings" "time" @@ -135,10 +136,28 @@ func (c *ServerConnection) dial(ctx context.Context, cancel context.CancelFunc, address := fmt.Sprintf("%s:%d", c.hostname, c.port) dlog.Client.Debug(c.server, "Dialing into the connection", address) - client, err := ssh.Dial("tcp", address, c.config) + // Use context-aware dialing to enable proper cancellation during connection establishment. + // TCP KeepAlive (30s) prevents silent connection failures on long-lived connections. + dialer := &net.Dialer{ + Timeout: c.config.Timeout, // Use the SSH config timeout (2 seconds) + KeepAlive: 30 * time.Second, // Standard Go default for connection health monitoring + } + + // Establish TCP connection with context support for cancellation + conn, err := dialer.DialContext(ctx, "tcp", address) + if err != nil { + return fmt.Errorf("failed to dial TCP connection to %s: %w", address, err) + } + + // Perform SSH handshake over the established TCP connection + sshConn, chans, reqs, err := ssh.NewClientConn(conn, address, c.config) if err != nil { - return fmt.Errorf("failed to dial SSH connection to %s: %w", address, err) + conn.Close() + return fmt.Errorf("SSH handshake failed for %s: %w", address, err) } + + // Create SSH client from the connection components + client := ssh.NewClient(sshConn, chans, reqs) defer client.Close() return c.session(ctx, cancel, client, throttleCh) diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go index 41cce05..7088e89 100644 --- a/internal/ssh/ssh.go +++ b/internal/ssh/ssh.go @@ -1,6 +1,7 @@ package ssh import ( + "context" "crypto/rand" "crypto/rsa" "crypto/x509" @@ -9,6 +10,7 @@ import ( "net" "os" "syscall" + "time" "github.com/mimecast/dtail/internal/io/dlog" @@ -49,7 +51,16 @@ func Agent() (gossh.AuthMethod, error) { // AgentWithKeyIndex used for SSH auth with a specific key index from the agent. // If keyIndex is -1, all keys are used. Otherwise, only the specified key is used. func AgentWithKeyIndex(keyIndex int) (gossh.AuthMethod, error) { - sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")) + // Use context-aware dialing for SSH agent connection (local Unix socket). + // 2-second timeout is reasonable for local socket connections. + dialer := &net.Dialer{ + Timeout: 2 * time.Second, + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + sshAgent, err := dialer.DialContext(ctx, "unix", os.Getenv("SSH_AUTH_SOCK")) if err != nil { return nil, fmt.Errorf("failed to connect to SSH agent: %w", err) } @@ -61,17 +72,17 @@ func AgentWithKeyIndex(keyIndex int) (gossh.AuthMethod, error) { for i, key := range keys { dlog.Common.Debug("Public key", i, key) } - + // If no specific key index requested, use all keys (backwards compatible default) if keyIndex < 0 { return gossh.PublicKeysCallback(agentClient.Signers), nil } - + // Use only the specified key index (0-based) if keyIndex >= len(keys) { return nil, fmt.Errorf("key index %d out of range (agent has %d keys)", keyIndex, len(keys)) } - + dlog.Common.Debug("Using SSH agent key at index", keyIndex) return gossh.PublicKeysCallback(func() ([]gossh.Signer, error) { signers, err := agentClient.Signers() |
