summary refs log tree commit diff
path: root/internal
diff options
context:
space:
mode:
author Paul Buetow <paul@buetow.org> 2026-03-02 10:03:44 +0200
committer Paul Buetow <paul@buetow.org> 2026-03-02 10:03:44 +0200
commit d32b03d3bf562dcc6b3b83055c9aac6fb852fd14 (patch)
tree 9b18e1967019d79aa6ae104452438b56fe81812f /internal
parent 0b775532c53b7dfdb22037037e063dc00a418ef9 (diff)
clients: add jittered exponential reconnect backoff
Diffstat (limited to 'internal')
-rw-r--r-- internal/clients/baseclient.go 72
-rw-r--r-- internal/clients/baseclient_retry_test.go 58
2 files changed, 127 insertions, 3 deletions
diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go
index e7a13f5..95fa721 100644
--- a/internal/clients/baseclient.go
+++ b/internal/clients/baseclient.go
@@ -2,6 +2,7 @@ package clients
import (
"context"
+ "math/rand"
"sync"
"time"
@@ -15,6 +16,12 @@ import (
gossh "golang.org/x/crypto/ssh"
)
+const (
+ initialRetryDelay = 2 * time.Second // base delay (before jitter) for the first reconnect attempt; also the restart value in nextRetryDelay
+ maxRetryDelay = 60 * time.Second // ceiling that nextRetryDelay clamps the doubled delay to
+ retryJitterFactor = 0.2 // +/-20% jitter to avoid synchronized reconnect storms.
+)
+
// This is the main client data structure.
type baseClient struct {
config.Args
@@ -102,6 +109,9 @@ func (c *baseClient) Start(ctx context.Context, statsCh <-chan string) (status i
func (c *baseClient) startConnection(ctx context.Context, i int,
conn connectors.Connector) (status int) {
+ retryDelay := initialRetryDelay
+ retryRandom := newRetryRandom(i)
+
for {
connCtx, cancel := context.WithCancel(ctx)
defer cancel()
@@ -122,14 +132,70 @@ func (c *baseClient) startConnection(ctx context.Context, i int,
default:
}
- // Yes, we want to retry.
- time.Sleep(time.Second * 2)
- dlog.Client.Debug(conn.Server(), "Reconnecting")
+ // Yes, we want to retry with exponential backoff and jitter.
+ sleepDuration := jitterRetryDelay(retryDelay, retryRandom)
+ dlog.Client.Debug(conn.Server(), "Reconnecting", "backoff", sleepDuration)
+ if !sleepWithContext(ctx, sleepDuration) {
+ return
+ }
+
+ retryDelay = nextRetryDelay(retryDelay)
conn = c.makeConnection(conn.Server(), c.sshAuthMethods, c.hostKeyCallback)
c.connections[i] = conn
}
}
+// nextRetryDelay returns the backoff base to use after current has been spent.
+//
+// A non-positive current restarts the sequence at initialRetryDelay. Otherwise
+// the delay doubles, and is pinned to maxRetryDelay once doubling would exceed
+// that cap or wrap around the int64 range of time.Duration.
+func nextRetryDelay(current time.Duration) time.Duration {
+	switch doubled := 2 * current; {
+	case current <= 0:
+		return initialRetryDelay
+	case doubled < current, doubled > maxRetryDelay:
+		// doubled < current can only happen when the multiplication overflowed.
+		return maxRetryDelay
+	default:
+		return doubled
+	}
+}
+
+// jitterRetryDelay spreads base by a random offset of at most
+// retryJitterFactor in either direction, drawing randomness from random.
+//
+// base is returned untouched when it is non-positive, when random is nil, when
+// the jitter span truncates to zero, or when base plus the span would overflow.
+func jitterRetryDelay(base time.Duration, random *rand.Rand) time.Duration {
+	if random == nil || base <= 0 {
+		return base
+	}
+
+	spread := time.Duration(retryJitterFactor * float64(base))
+	lower, upper := base-spread, base+spread
+	if spread <= 0 || upper < lower {
+		return base
+	}
+
+	// Int63n yields [0, n), so a width of upper-lower+1 makes upper reachable.
+	width := int64(upper - lower + 1)
+	return lower + time.Duration(random.Int63n(width))
+}
+
+// sleepWithContext waits for delay to elapse unless ctx ends first.
+//
+// It returns true when the caller may carry on (the delay elapsed, or there
+// was nothing to wait for) and false when ctx was done. A ctx that is already
+// done on entry always yields false, including for a non-positive delay, so a
+// cancelled caller is never told to continue.
+func sleepWithContext(ctx context.Context, delay time.Duration) bool {
+	// Check cancellation up front: the delay <= 0 shortcut below would
+	// otherwise report success on a dead context.
+	if ctx.Err() != nil {
+		return false
+	}
+	if delay <= 0 {
+		return true
+	}
+
+	timer := time.NewTimer(delay)
+	defer timer.Stop() // release the timer when ctx wins the race
+
+	select {
+	case <-ctx.Done():
+		return false
+	case <-timer.C:
+		return true
+	}
+}
+
+// newRetryRandom builds a dedicated random generator for backoff jitter. The
+// seed is the current wall-clock time in nanoseconds plus seedOffset, so
+// generators created at the same instant with different offsets still start
+// from different seeds.
+func newRetryRandom(seedOffset int) *rand.Rand {
+	seed := int64(seedOffset) + time.Now().UnixNano()
+	source := rand.NewSource(seed)
+	return rand.New(source)
+}
+
func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod,
hostKeyCallback client.HostKeyCallback) connectors.Connector {
if c.Args.Serverless {
diff --git a/internal/clients/baseclient_retry_test.go b/internal/clients/baseclient_retry_test.go
new file mode 100644
index 0000000..323ceae
--- /dev/null
+++ b/internal/clients/baseclient_retry_test.go
@@ -0,0 +1,58 @@
+package clients
+
+import (
+ "context"
+ "math/rand"
+ "testing"
+ "time"
+)
+
+// TestNextRetryDelay checks the backoff progression: restart from a zero
+// delay, plain doubling, and clamping at maxRetryDelay.
+func TestNextRetryDelay(t *testing.T) {
+	type testCase struct {
+		name    string
+		current time.Duration
+		want    time.Duration
+	}
+
+	cases := []testCase{
+		{name: "zero uses initial", current: 0, want: initialRetryDelay},
+		{name: "doubles normally", current: 4 * time.Second, want: 8 * time.Second},
+		{name: "caps at max", current: 40 * time.Second, want: maxRetryDelay},
+		{name: "stays max at max", current: maxRetryDelay, want: maxRetryDelay},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			got := nextRetryDelay(tc.current)
+			if got == tc.want {
+				return
+			}
+			t.Fatalf("nextRetryDelay(%v) = %v, want %v", tc.current, got, tc.want)
+		})
+	}
+}
+
+// TestJitterRetryDelayWithinBounds draws 100 jittered delays for a 10s base
+// from a fixed-seed generator and requires each to fall within [8s, 12s].
+func TestJitterRetryDelayWithinBounds(t *testing.T) {
+	const (
+		base       = 10 * time.Second
+		lowerBound = 8 * time.Second
+		upperBound = 12 * time.Second
+		samples    = 100
+	)
+	rng := rand.New(rand.NewSource(1))
+
+	for n := 0; n < samples; n++ {
+		got := jitterRetryDelay(base, rng)
+		if withinBounds := got >= lowerBound && got <= upperBound; !withinBounds {
+			t.Fatalf("jitterRetryDelay() = %v, expected between %v and %v", got, lowerBound, upperBound)
+		}
+	}
+}
+
+// TestSleepWithContextCancellation verifies that sleepWithContext gives up on
+// an already-cancelled context: it must return false, and do so well before
+// the requested one-second delay.
+func TestSleepWithContextCancellation(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel() // cancel before sleeping so the wait has to abort straight away
+
+	began := time.Now()
+	completed := sleepWithContext(ctx, time.Second)
+	elapsed := time.Since(began)
+
+	if completed {
+		t.Fatalf("sleepWithContext should stop when context is canceled")
+	}
+	if elapsed > 100*time.Millisecond {
+		t.Fatalf("sleepWithContext took too long to exit on canceled context")
+	}
+}