Skip to main content

rust_tg_bot_ext/utils/
network_loop.rs

1//! A retry loop for network-oriented operations against the Telegram API.
2//!
3//! Port of `telegram.ext._utils.networkloop.network_retry_loop`.
4//! This is library-internal and not part of the public API stability guarantee.
5
6use std::future::Future;
7use std::time::Duration;
8
9use tokio::sync::watch;
10use tracing::{debug, error};
11
12use rust_tg_bot_raw::error::TelegramError;
13
14// ---------------------------------------------------------------------------
15// Configuration
16// ---------------------------------------------------------------------------
17
18/// Parameters for [`network_retry_loop`].
19pub struct NetworkLoopConfig<'a, A, E> {
20    /// The async action to attempt on each iteration.
21    pub action_cb: A,
22    /// Optional error callback invoked when a `TelegramError` is caught.
23    pub on_err_cb: Option<E>,
24    /// Human-readable label used in log messages.
25    pub description: &'a str,
26    /// Base interval between attempts (seconds).
27    pub interval: f64,
28    /// A watch receiver whose value, when `true`, signals the loop to stop.
29    /// Pass `None` if the loop should only be controlled by `is_running` and
30    /// `max_retries`.
31    pub stop_rx: Option<watch::Receiver<bool>>,
32    /// Predicate checked at the top of every iteration. Returning `false`
33    /// exits the loop.
34    pub is_running: Option<Box<dyn Fn() -> bool + Send + Sync + 'a>>,
35    /// Maximum retry count.
36    /// - negative: retry indefinitely.
37    /// - 0: no retries (single attempt).
38    /// - positive: up to N retries.
39    pub max_retries: i32,
40    /// If `true`, the action is repeated after a successful call.
41    pub repeat_on_success: bool,
42}
43
44// ---------------------------------------------------------------------------
45// Loop implementation
46// ---------------------------------------------------------------------------
47
48/// Run `action_cb` in a loop, retrying on `TelegramError` according to the
49/// back-off / retry policy described by `config`.
50///
51/// # Errors
52///
53/// Returns the last `TelegramError` if retries are exhausted, or propagates
54/// an `InvalidToken` error immediately.
55pub async fn network_retry_loop<'a, A, AF, E>(
56    config: NetworkLoopConfig<'a, A, E>,
57) -> Result<(), TelegramError>
58where
59    A: Fn() -> AF,
60    AF: Future<Output = Result<(), TelegramError>>,
61    E: Fn(&TelegramError),
62{
63    let NetworkLoopConfig {
64        action_cb,
65        on_err_cb,
66        description,
67        interval,
68        mut stop_rx,
69        is_running,
70        max_retries,
71        repeat_on_success,
72    } = config;
73
74    let log_prefix = format!("Network Retry Loop ({description}):");
75    let effective_is_running = is_running.unwrap_or_else(|| Box::new(|| true));
76
77    debug!("{log_prefix} Starting");
78
79    let mut cur_interval = interval;
80    let mut retries: i32 = 0;
81
82    while effective_is_running() {
83        // Execute the action, racing against the stop signal if one exists.
84        let action_result = match stop_rx.as_mut() {
85            Some(rx) => {
86                tokio::select! {
87                    biased;
88                    _ = wait_for_stop(rx) => {
89                        debug!("{log_prefix} Cancelled via stop signal");
90                        return Ok(());
91                    }
92                    res = action_cb() => res,
93                }
94            }
95            None => action_cb().await,
96        };
97
98        match action_result {
99            Ok(()) => {
100                if !repeat_on_success {
101                    debug!("{log_prefix} Action succeeded. Stopping loop.");
102                    return Ok(());
103                }
104                cur_interval = interval;
105            }
106            Err(TelegramError::RetryAfter { retry_after }) => {
107                let slack = Duration::from_millis(500);
108                cur_interval = (retry_after + slack).as_secs_f64();
109                if check_max_retries(retries, max_retries, &log_prefix) {
110                    return Err(TelegramError::RetryAfter { retry_after });
111                }
112            }
113            Err(TelegramError::TimedOut(_)) => {
114                cur_interval = 0.0;
115                if check_max_retries(retries, max_retries, &log_prefix) {
116                    return Err(TelegramError::TimedOut("timed out".into()));
117                }
118            }
119            Err(TelegramError::InvalidToken(msg)) => {
120                error!("{log_prefix} Invalid token. Aborting retry loop.");
121                return Err(TelegramError::InvalidToken(msg));
122            }
123            Err(ref e) => {
124                if let Some(ref cb) = on_err_cb {
125                    cb(e);
126                }
127                if check_max_retries(retries, max_retries, &log_prefix) {
128                    // Move out of the ref to return ownership.
129                    return Err(action_result.unwrap_err());
130                }
131                // Exponential back-off up to 30 seconds.
132                cur_interval = if cur_interval == 0.0 {
133                    1.0
134                } else {
135                    (1.5 * cur_interval).min(30.0)
136                };
137            }
138        }
139
140        retries += 1;
141
142        if cur_interval > 0.0 {
143            tokio::time::sleep(Duration::from_secs_f64(cur_interval)).await;
144        }
145    }
146
147    Ok(())
148}
149
150/// Wait until the watch channel yields `true`.
151async fn wait_for_stop(rx: &mut watch::Receiver<bool>) {
152    while !*rx.borrow_and_update() {
153        if rx.changed().await.is_err() {
154            // Sender dropped -- treat as stop.
155            return;
156        }
157    }
158}
159
160/// Returns `true` if we should abort (max retries reached).
161fn check_max_retries(current: i32, max: i32, prefix: &str) -> bool {
162    if max < 0 || current < max {
163        debug!("{prefix} Failed run {current} of {max}. Retrying.",);
164        false
165    } else {
166        error!("{prefix} Failed run {current} of {max}. Aborting.",);
167        true
168    }
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    // Note: generic (non-timeout, non-rate-limit) errors trigger the real
    // exponential back-off (1 s, 1.5 s, ...), so tests that retry such errors
    // spend a few seconds of wall-clock time sleeping.

    /// Baseline configuration shared by the tests. It has no error callback,
    /// no stop signal, a zero base interval, and allows a single attempt with
    /// no repetition on success. Individual tests override fields via
    /// struct-update syntax.
    fn base_config<A>(
        description: &str,
        action_cb: A,
    ) -> NetworkLoopConfig<'_, A, fn(&TelegramError)> {
        NetworkLoopConfig {
            action_cb,
            on_err_cb: None,
            description,
            interval: 0.0,
            stop_rx: None,
            is_running: None,
            max_retries: 0,
            repeat_on_success: false,
        }
    }

    #[tokio::test]
    async fn succeeds_on_first_try() {
        let result =
            network_retry_loop(base_config("test", || async { Ok::<(), TelegramError>(()) }))
                .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn retries_and_succeeds() {
        let counter = Arc::new(AtomicU32::new(0));
        let c = Arc::clone(&counter);
        let result = network_retry_loop(NetworkLoopConfig {
            max_retries: -1, // indefinite
            ..base_config("retry-test", move || {
                let c = Arc::clone(&c);
                async move {
                    if c.fetch_add(1, Ordering::SeqCst) < 2 {
                        Err(TelegramError::Network("fail".into()))
                    } else {
                        Ok(())
                    }
                }
            })
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn aborts_after_max_retries() {
        let calls = Arc::new(AtomicU32::new(0));
        let c = Arc::clone(&calls);
        let result = network_retry_loop(NetworkLoopConfig {
            max_retries: 2,
            ..base_config("abort-test", move || {
                let c = Arc::clone(&c);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    Err::<(), _>(TelegramError::Network("always fail".into()))
                }
            })
        })
        .await;
        // The action's own error is propagated once the budget is spent.
        assert!(matches!(result, Err(TelegramError::Network(_))));
        // One initial attempt plus two retries.
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn zero_max_retries_is_single_attempt_and_calls_on_err_cb() {
        let calls = Arc::new(AtomicU32::new(0));
        let seen = Arc::new(AtomicU32::new(0));
        let c = Arc::clone(&calls);
        let s = Arc::clone(&seen);
        let result = network_retry_loop(NetworkLoopConfig {
            action_cb: move || {
                let c = Arc::clone(&c);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    Err::<(), _>(TelegramError::Network("boom".into()))
                }
            },
            on_err_cb: Some(move |_: &TelegramError| {
                s.fetch_add(1, Ordering::SeqCst);
            }),
            description: "callback-test",
            interval: 0.0,
            stop_rx: None,
            is_running: None,
            max_retries: 0, // abort before any back-off sleep
            repeat_on_success: false,
        })
        .await;
        assert!(matches!(result, Err(TelegramError::Network(_))));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert_eq!(seen.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn timed_out_retries_immediately_without_callback() {
        let calls = Arc::new(AtomicU32::new(0));
        let seen = Arc::new(AtomicU32::new(0));
        let c = Arc::clone(&calls);
        let s = Arc::clone(&seen);
        let result = network_retry_loop(NetworkLoopConfig {
            action_cb: move || {
                let c = Arc::clone(&c);
                async move {
                    if c.fetch_add(1, Ordering::SeqCst) == 0 {
                        Err(TelegramError::TimedOut("slow".into()))
                    } else {
                        Ok(())
                    }
                }
            },
            on_err_cb: Some(move |_: &TelegramError| {
                s.fetch_add(1, Ordering::SeqCst);
            }),
            description: "timeout-test",
            interval: 0.0,
            stop_rx: None,
            is_running: None,
            max_retries: 1,
            repeat_on_success: false,
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 2);
        // Timeouts are not routed through the error callback.
        assert_eq!(seen.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn invalid_token_aborts_immediately() {
        let counter = Arc::new(AtomicU32::new(0));
        let c = Arc::clone(&counter);
        let result = network_retry_loop(NetworkLoopConfig {
            max_retries: -1,
            ..base_config("token-test", move || {
                let c = Arc::clone(&c);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    Err::<(), _>(TelegramError::InvalidToken("bad".into()))
                }
            })
        })
        .await;
        assert!(matches!(result, Err(TelegramError::InvalidToken(_))));
        // Should abort on the first attempt, no retries.
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn is_running_false_skips_action() {
        let calls = Arc::new(AtomicU32::new(0));
        let c = Arc::clone(&calls);
        let result = network_retry_loop(NetworkLoopConfig {
            is_running: Some(Box::new(|| false)),
            ..base_config("not-running", move || {
                let c = Arc::clone(&c);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    Ok::<(), TelegramError>(())
                }
            })
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn stop_signal_cancels_loop() {
        let (tx, rx) = watch::channel(false);
        let counter = Arc::new(AtomicU32::new(0));
        let c = Arc::clone(&counter);

        // Spawn the loop; it will repeat on success indefinitely.
        let handle = tokio::spawn(async move {
            network_retry_loop(NetworkLoopConfig {
                interval: 0.01,
                stop_rx: Some(rx),
                max_retries: -1,
                repeat_on_success: true,
                ..base_config("stop-test", move || {
                    let c = Arc::clone(&c);
                    async move {
                        c.fetch_add(1, Ordering::SeqCst);
                        Ok::<(), TelegramError>(())
                    }
                })
            })
            .await
        });

        // Let it run a few iterations then signal stop.
        tokio::time::sleep(Duration::from_millis(80)).await;
        tx.send(true).unwrap();
        let result = handle.await.unwrap();
        assert!(result.is_ok());
        // It should have run at least once.
        assert!(counter.load(Ordering::SeqCst) >= 1);
    }
}