diff options
| author | Paul Buetow <paul@buetow.org> | 2026-03-02 10:03:44 +0200 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-03-02 10:03:44 +0200 |
| commit | d32b03d3bf562dcc6b3b83055c9aac6fb852fd14 (patch) | |
| tree | 9b18e1967019d79aa6ae104452438b56fe81812f /internal | |
| parent | 0b775532c53b7dfdb22037037e063dc00a418ef9 (diff) | |
clients: add jittered exponential reconnect backoff
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/clients/baseclient.go | 72 | ||||
| -rw-r--r-- | internal/clients/baseclient_retry_test.go | 58 |
2 files changed, 127 insertions, 3 deletions
diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go index e7a13f5..95fa721 100644 --- a/internal/clients/baseclient.go +++ b/internal/clients/baseclient.go @@ -2,6 +2,7 @@ package clients import ( "context" + "math/rand" "sync" "time" @@ -15,6 +16,12 @@ import ( gossh "golang.org/x/crypto/ssh" ) +const ( + initialRetryDelay = 2 * time.Second + maxRetryDelay = 60 * time.Second + retryJitterFactor = 0.2 // +/-20% jitter to avoid synchronized reconnect storms. +) + // This is the main client data structure. type baseClient struct { config.Args @@ -102,6 +109,9 @@ func (c *baseClient) Start(ctx context.Context, statsCh <-chan string) (status i func (c *baseClient) startConnection(ctx context.Context, i int, conn connectors.Connector) (status int) { + retryDelay := initialRetryDelay + retryRandom := newRetryRandom(i) + for { connCtx, cancel := context.WithCancel(ctx) defer cancel() @@ -122,14 +132,70 @@ func (c *baseClient) startConnection(ctx context.Context, i int, default: } - // Yes, we want to retry. - time.Sleep(time.Second * 2) - dlog.Client.Debug(conn.Server(), "Reconnecting") + // Yes, we want to retry with exponential backoff and jitter. + sleepDuration := jitterRetryDelay(retryDelay, retryRandom) + dlog.Client.Debug(conn.Server(), "Reconnecting", "backoff", sleepDuration) + if !sleepWithContext(ctx, sleepDuration) { + return + } + + retryDelay = nextRetryDelay(retryDelay) conn = c.makeConnection(conn.Server(), c.sshAuthMethods, c.hostKeyCallback) c.connections[i] = conn } } +func nextRetryDelay(current time.Duration) time.Duration { + if current <= 0 { + return initialRetryDelay + } + + next := current * 2 + if next > maxRetryDelay || next < current { + return maxRetryDelay + } + return next +} + +func jitterRetryDelay(base time.Duration, random *rand.Rand) time.Duration { + if base <= 0 || random == nil { + return base + } + + jitter := time.Duration(float64(base) * retryJitterFactor) + if jitter <= 0 { + return base + } + + minDelay := base - jitter + maxDelay := base + jitter + if maxDelay < minDelay { + return base + } + + return minDelay + time.Duration(random.Int63n(int64(maxDelay-minDelay+1))) +} + +func sleepWithContext(ctx context.Context, delay time.Duration) bool { + if delay <= 0 { + return true + } + + timer := time.NewTimer(delay) + defer timer.Stop() + + select { + case <-ctx.Done(): + return false + case <-timer.C: + return true + } +} + +func newRetryRandom(seedOffset int) *rand.Rand { + return rand.New(rand.NewSource(time.Now().UnixNano() + int64(seedOffset))) +} + func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback client.HostKeyCallback) connectors.Connector { if c.Args.Serverless { diff --git a/internal/clients/baseclient_retry_test.go b/internal/clients/baseclient_retry_test.go new file mode 100644 index 0000000..323ceae --- /dev/null +++ b/internal/clients/baseclient_retry_test.go @@ -0,0 +1,58 @@ +package clients + +import ( + "context" + "math/rand" + "testing" + "time" +) + +func TestNextRetryDelay(t *testing.T) { + tests := []struct { + name string + current time.Duration + want time.Duration + }{ + {name: "zero uses initial", current: 0, want: initialRetryDelay}, + {name: "doubles normally", current: 4 * time.Second, want: 8 * time.Second}, + {name: "caps at max", current: 40 * time.Second, want: maxRetryDelay}, + {name: "stays max at max", current: maxRetryDelay, want: maxRetryDelay}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := nextRetryDelay(tt.current); got != tt.want { + t.Fatalf("nextRetryDelay(%v) = %v, want %v", tt.current, got, tt.want) + } + }) + } +} + +func TestJitterRetryDelayWithinBounds(t *testing.T) { + base := 10 * time.Second + random := rand.New(rand.NewSource(1)) + + min := 8 * time.Second + max := 12 * time.Second + + for i := 0; i < 100; i++ { + got := jitterRetryDelay(base, random) + if got < min || got > max { + t.Fatalf("jitterRetryDelay() = %v, expected between %v and %v", got, min, max) + } + } +} + +func TestSleepWithContextCancellation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + start := time.Now() + if sleepWithContext(ctx, time.Second) { + t.Fatalf("sleepWithContext should stop when context is canceled") + } + + if time.Since(start) > 100*time.Millisecond { + t.Fatalf("sleepWithContext took too long to exit on canceled context") + } +} |
