summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/clients/connectors/serverconnection.go23
-rw-r--r--internal/ssh/ssh.go19
2 files changed, 36 insertions, 6 deletions
diff --git a/internal/clients/connectors/serverconnection.go b/internal/clients/connectors/serverconnection.go
index 34d3997..d114d06 100644
--- a/internal/clients/connectors/serverconnection.go
+++ b/internal/clients/connectors/serverconnection.go
@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"io"
+ "net"
"strconv"
"strings"
"time"
@@ -135,10 +136,28 @@ func (c *ServerConnection) dial(ctx context.Context, cancel context.CancelFunc,
address := fmt.Sprintf("%s:%d", c.hostname, c.port)
dlog.Client.Debug(c.server, "Dialing into the connection", address)
- client, err := ssh.Dial("tcp", address, c.config)
+ // Use context-aware dialing to enable proper cancellation during connection establishment.
+ // TCP KeepAlive (30s) prevents silent connection failures on long-lived connections.
+ dialer := &net.Dialer{
+ Timeout: c.config.Timeout, // Use the SSH config timeout (2 seconds)
+ KeepAlive: 30 * time.Second, // Standard Go default for connection health monitoring
+ }
+
+ // Establish TCP connection with context support for cancellation
+ conn, err := dialer.DialContext(ctx, "tcp", address)
+ if err != nil {
+ return fmt.Errorf("failed to dial TCP connection to %s: %w", address, err)
+ }
+
+ // Perform SSH handshake over the established TCP connection
+ sshConn, chans, reqs, err := ssh.NewClientConn(conn, address, c.config)
if err != nil {
- return fmt.Errorf("failed to dial SSH connection to %s: %w", address, err)
+ conn.Close()
+ return fmt.Errorf("SSH handshake failed for %s: %w", address, err)
}
+
+ // Create SSH client from the connection components
+ client := ssh.NewClient(sshConn, chans, reqs)
defer client.Close()
return c.session(ctx, cancel, client, throttleCh)
diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go
index 41cce05..7088e89 100644
--- a/internal/ssh/ssh.go
+++ b/internal/ssh/ssh.go
@@ -1,6 +1,7 @@
package ssh
import (
+ "context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
@@ -9,6 +10,7 @@ import (
"net"
"os"
"syscall"
+ "time"
"github.com/mimecast/dtail/internal/io/dlog"
@@ -49,7 +51,16 @@ func Agent() (gossh.AuthMethod, error) {
// AgentWithKeyIndex used for SSH auth with a specific key index from the agent.
// If keyIndex is -1, all keys are used. Otherwise, only the specified key is used.
func AgentWithKeyIndex(keyIndex int) (gossh.AuthMethod, error) {
- sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK"))
+ // Use context-aware dialing for SSH agent connection (local Unix socket).
+ // 2-second timeout is reasonable for local socket connections.
+ dialer := &net.Dialer{
+ Timeout: 2 * time.Second,
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+
+ sshAgent, err := dialer.DialContext(ctx, "unix", os.Getenv("SSH_AUTH_SOCK"))
if err != nil {
return nil, fmt.Errorf("failed to connect to SSH agent: %w", err)
}
@@ -61,17 +72,17 @@ func AgentWithKeyIndex(keyIndex int) (gossh.AuthMethod, error) {
for i, key := range keys {
dlog.Common.Debug("Public key", i, key)
}
-
+
// If no specific key index requested, use all keys (backwards compatible default)
if keyIndex < 0 {
return gossh.PublicKeysCallback(agentClient.Signers), nil
}
-
+
// Use only the specified key index (0-based)
if keyIndex >= len(keys) {
return nil, fmt.Errorf("key index %d out of range (agent has %d keys)", keyIndex, len(keys))
}
-
+
dlog.Common.Debug("Using SSH agent key at index", keyIndex)
return gossh.PublicKeysCallback(func() ([]gossh.Signer, error) {
signers, err := agentClient.Signers()