1use std::time::Duration;
2
3use rand::Rng;
4
/// Retry limit used by [`ReconnectPolicy::default_policy`].
const DEFAULT_MAX_RECONNECT_RETRIES: u32 = 10;
6
/// Exponential-backoff schedule for reconnect attempts.
///
/// Each call to [`ReconnectPolicy::next_delay`] doubles the delay (starting from
/// `base_delay`, capped at `max_delay`), optionally adds random jitter, and stops
/// handing out delays after `max_retries` attempts (`None` means no limit).
///
/// Equality is structural: two policies compare equal only if their configuration
/// *and* their current attempt counter match.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReconnectPolicy {
    /// Delay used for the first attempt, before doubling and jitter.
    base_delay: Duration,
    /// Cap applied to the exponential delay before any jitter is added.
    max_delay: Duration,
    /// Maximum number of delays handed out; `None` means retry forever.
    max_retries: Option<u32>,
    /// Number of delays handed out since construction or the last reset.
    current_attempt: u32,
    /// Whether random jitter is added on top of each capped delay.
    jitter: bool,
}
19
20impl ReconnectPolicy {
21 pub fn new(base_delay: Duration, max_delay: Duration, max_retries: Option<u32>) -> Self {
27 Self {
28 base_delay,
29 max_delay,
30 max_retries,
31 current_attempt: 0,
32 jitter: false,
33 }
34 }
35
36 pub fn default_policy() -> Self {
41 Self::new(
42 Duration::from_secs(1),
43 Duration::from_secs(30),
44 Some(DEFAULT_MAX_RECONNECT_RETRIES),
45 )
46 .with_jitter()
47 }
48
49 pub fn with_jitter(mut self) -> Self {
53 self.jitter = true;
54 self
55 }
56
57 pub fn next_delay(&mut self) -> Option<Duration> {
61 if let Some(max) = self.max_retries
62 && self.current_attempt >= max
63 {
64 return None;
65 }
66
67 let multiplier = 1u32.checked_shl(self.current_attempt).unwrap_or(u32::MAX);
68 let delay = self.base_delay.saturating_mul(multiplier);
69 let mut delay = delay.min(self.max_delay);
70 if self.jitter {
71 delay = delay.saturating_add(random_jitter_below(delay));
72 }
73
74 self.current_attempt += 1;
75
76 Some(delay)
77 }
78
79 pub fn reset(&mut self) {
81 self.current_attempt = 0;
82 }
83
84 pub fn attempts(&self) -> u32 {
86 self.current_attempt
87 }
88}
89
90fn random_jitter_below(delay: Duration) -> Duration {
91 let max_millis = delay.as_millis();
92 if max_millis == 0 {
93 return Duration::ZERO;
94 }
95
96 let jitter_millis = if max_millis > u64::MAX as u128 {
97 rand::thread_rng().gen_range(0..u64::MAX)
98 } else {
99 rand::thread_rng().gen_range(0..max_millis as u64)
100 };
101 Duration::from_millis(jitter_millis)
102}
103
// Unit tests are compiled only under `cfg(test)` and are loaded from a separate
// `tests` module file.
#[cfg(test)]
mod tests;