summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
author: Paul Buetow <paul@buetow.org> 2026-02-15 08:28:43 +0200
committer: Paul Buetow <paul@buetow.org> 2026-02-15 08:29:45 +0200
commit: bbbb7461d19e611e6fab3f24edd5f8e0d2d45b1e (patch)
tree: bee4b9e07bafb2810f0e2cc2db4fb34e7154b2d4 /internal
parent: d89b9e6760e2aadf9779faa6f23678f67c731e1e (diff)
refactor: implement context-aware network dialing
Modernize network dialing to use Go's context-aware patterns for better cancellation support and connection reliability.

Changes:
- Update Go version from 1.24 to 1.25 in go.mod
- Replace ssh.Dial with net.Dialer.DialContext + ssh.NewClientConn for SSH client connections in serverconnection.go
- Add TCP KeepAlive (30s) for SSH connection health monitoring
- Implement context-aware dialing for SSH agent connections in ssh.go
- Improve error messages to distinguish dial vs SSH handshake failures
- Update AGENTS.md with integration test requirements

Benefits:
- Context cancellation now properly affects connection establishment
- TCP KeepAlive prevents silent connection failures
- Better integration with Go's cancellation patterns
- Improved reliability for distributed systems

All integration tests pass with race detection enabled.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Diffstat (limited to 'internal')
-rw-r--r-- internal/clients/connectors/serverconnection.go 23
-rw-r--r-- internal/ssh/ssh.go 19
2 files changed, 36 insertions, 6 deletions
diff --git a/internal/clients/connectors/serverconnection.go b/internal/clients/connectors/serverconnection.go
index 34d3997..d114d06 100644
--- a/internal/clients/connectors/serverconnection.go
+++ b/internal/clients/connectors/serverconnection.go
@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"io"
+ "net"
"strconv"
"strings"
"time"
@@ -135,10 +136,28 @@ func (c *ServerConnection) dial(ctx context.Context, cancel context.CancelFunc,
address := fmt.Sprintf("%s:%d", c.hostname, c.port)
dlog.Client.Debug(c.server, "Dialing into the connection", address)
- client, err := ssh.Dial("tcp", address, c.config)
+ // Use context-aware dialing to enable proper cancellation during connection establishment.
+ // TCP keep-alive probes (every 30s) let the OS detect a dead peer on long-lived connections instead of the connection failing silently.
+ dialer := &net.Dialer{
+ Timeout: c.config.Timeout, // Use the SSH config timeout (2 seconds)
+ KeepAlive: 30 * time.Second, // Probe interval; matches net/http's DefaultTransport dialer (net.Dialer itself defaults to 15s)
+ }
+
+ // Establish TCP connection with context support for cancellation
+ conn, err := dialer.DialContext(ctx, "tcp", address)
+ if err != nil {
+ return fmt.Errorf("failed to dial TCP connection to %s: %w", address, err)
+ }
+
+ // Perform SSH handshake over the established TCP connection
+ sshConn, chans, reqs, err := ssh.NewClientConn(conn, address, c.config)
if err != nil {
- return fmt.Errorf("failed to dial SSH connection to %s: %w", address, err)
+ conn.Close()
+ return fmt.Errorf("SSH handshake failed for %s: %w", address, err)
}
+
+ // Create SSH client from the connection components
+ client := ssh.NewClient(sshConn, chans, reqs)
defer client.Close()
return c.session(ctx, cancel, client, throttleCh)
diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go
index 41cce05..7088e89 100644
--- a/internal/ssh/ssh.go
+++ b/internal/ssh/ssh.go
@@ -1,6 +1,7 @@
package ssh
import (
+ "context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
@@ -9,6 +10,7 @@ import (
"net"
"os"
"syscall"
+ "time"
"github.com/mimecast/dtail/internal/io/dlog"
@@ -49,7 +51,16 @@ func Agent() (gossh.AuthMethod, error) {
// AgentWithKeyIndex used for SSH auth with a specific key index from the agent.
// If keyIndex is -1, all keys are used. Otherwise, only the specified key is used.
func AgentWithKeyIndex(keyIndex int) (gossh.AuthMethod, error) {
- sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK"))
+ // Use context-aware dialing for SSH agent connection (local Unix socket).
+ // 2-second timeout is reasonable for local socket connections.
+ dialer := &net.Dialer{
+ Timeout: 2 * time.Second,
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+
+ sshAgent, err := dialer.DialContext(ctx, "unix", os.Getenv("SSH_AUTH_SOCK"))
if err != nil {
return nil, fmt.Errorf("failed to connect to SSH agent: %w", err)
}
@@ -61,17 +72,17 @@ func AgentWithKeyIndex(keyIndex int) (gossh.AuthMethod, error) {
for i, key := range keys {
dlog.Common.Debug("Public key", i, key)
}
-
+
// If no specific key index requested, use all keys (backwards compatible default)
if keyIndex < 0 {
return gossh.PublicKeysCallback(agentClient.Signers), nil
}
-
+
// Use only the specified key index (0-based)
if keyIndex >= len(keys) {
return nil, fmt.Errorf("key index %d out of range (agent has %d keys)", keyIndex, len(keys))
}
-
+
dlog.Common.Debug("Using SSH agent key at index", keyIndex)
return gossh.PublicKeysCallback(func() ([]gossh.Signer, error) {
signers, err := agentClient.Signers()