1use std::future::Future;
7use std::time::Duration;
8
9use tokio::sync::watch;
10use tracing::{debug, error};
11
12use rust_tg_bot_raw::error::TelegramError;
13
14pub struct NetworkLoopConfig<'a, A, E> {
20 pub action_cb: A,
22 pub on_err_cb: Option<E>,
24 pub description: &'a str,
26 pub interval: f64,
28 pub stop_rx: Option<watch::Receiver<bool>>,
32 pub is_running: Option<Box<dyn Fn() -> bool + Send + Sync + 'a>>,
35 pub max_retries: i32,
40 pub repeat_on_success: bool,
42}
43
44pub async fn network_retry_loop<'a, A, AF, E>(
56 config: NetworkLoopConfig<'a, A, E>,
57) -> Result<(), TelegramError>
58where
59 A: Fn() -> AF,
60 AF: Future<Output = Result<(), TelegramError>>,
61 E: Fn(&TelegramError),
62{
63 let NetworkLoopConfig {
64 action_cb,
65 on_err_cb,
66 description,
67 interval,
68 mut stop_rx,
69 is_running,
70 max_retries,
71 repeat_on_success,
72 } = config;
73
74 let log_prefix = format!("Network Retry Loop ({description}):");
75 let effective_is_running = is_running.unwrap_or_else(|| Box::new(|| true));
76
77 debug!("{log_prefix} Starting");
78
79 let mut cur_interval = interval;
80 let mut retries: i32 = 0;
81
82 while effective_is_running() {
83 let action_result = match stop_rx.as_mut() {
85 Some(rx) => {
86 tokio::select! {
87 biased;
88 _ = wait_for_stop(rx) => {
89 debug!("{log_prefix} Cancelled via stop signal");
90 return Ok(());
91 }
92 res = action_cb() => res,
93 }
94 }
95 None => action_cb().await,
96 };
97
98 match action_result {
99 Ok(()) => {
100 if !repeat_on_success {
101 debug!("{log_prefix} Action succeeded. Stopping loop.");
102 return Ok(());
103 }
104 cur_interval = interval;
105 }
106 Err(TelegramError::RetryAfter { retry_after }) => {
107 let slack = Duration::from_millis(500);
108 cur_interval = (retry_after + slack).as_secs_f64();
109 if check_max_retries(retries, max_retries, &log_prefix) {
110 return Err(TelegramError::RetryAfter { retry_after });
111 }
112 }
113 Err(TelegramError::TimedOut(_)) => {
114 cur_interval = 0.0;
115 if check_max_retries(retries, max_retries, &log_prefix) {
116 return Err(TelegramError::TimedOut("timed out".into()));
117 }
118 }
119 Err(TelegramError::InvalidToken(msg)) => {
120 error!("{log_prefix} Invalid token. Aborting retry loop.");
121 return Err(TelegramError::InvalidToken(msg));
122 }
123 Err(ref e) => {
124 if let Some(ref cb) = on_err_cb {
125 cb(e);
126 }
127 if check_max_retries(retries, max_retries, &log_prefix) {
128 return Err(action_result.unwrap_err());
130 }
131 cur_interval = if cur_interval == 0.0 {
133 1.0
134 } else {
135 (1.5 * cur_interval).min(30.0)
136 };
137 }
138 }
139
140 retries += 1;
141
142 if cur_interval > 0.0 {
143 tokio::time::sleep(Duration::from_secs_f64(cur_interval)).await;
144 }
145 }
146
147 Ok(())
148}
149
150async fn wait_for_stop(rx: &mut watch::Receiver<bool>) {
152 while !*rx.borrow_and_update() {
153 if rx.changed().await.is_err() {
154 return;
156 }
157 }
158}
159
160fn check_max_retries(current: i32, max: i32, prefix: &str) -> bool {
162 if max < 0 || current < max {
163 debug!("{prefix} Failed run {current} of {max}. Retrying.",);
164 false
165 } else {
166 error!("{prefix} Failed run {current} of {max}. Aborting.",);
167 true
168 }
169}
170
#[cfg(test)]
mod tests {
    //! Behavioural tests for `network_retry_loop`.
    //!
    //! Note: generic (non-timeout) errors trigger real back-off sleeps of 1 s, then
    //! 1.5 s, so tests that retry generic errors take a few seconds of wall time.

    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    #[tokio::test]
    async fn succeeds_on_first_try() {
        let counter = Arc::new(AtomicU32::new(0));
        let c = counter.clone();
        let result = network_retry_loop(NetworkLoopConfig {
            action_cb: move || {
                let c = c.clone();
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    Ok(())
                }
            },
            on_err_cb: None::<fn(&TelegramError)>,
            description: "test",
            interval: 0.0,
            stop_rx: None,
            is_running: None,
            max_retries: 0,
            repeat_on_success: false,
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn retries_and_succeeds() {
        let counter = Arc::new(AtomicU32::new(0));
        let c = counter.clone();
        let result = network_retry_loop(NetworkLoopConfig {
            action_cb: move || {
                let c = c.clone();
                async move {
                    let n = c.fetch_add(1, Ordering::SeqCst);
                    if n < 2 {
                        Err(TelegramError::Network("fail".into()))
                    } else {
                        Ok(())
                    }
                }
            },
            on_err_cb: None::<fn(&TelegramError)>,
            description: "retry-test",
            interval: 0.0,
            stop_rx: None,
            is_running: None,
            max_retries: -1,
            repeat_on_success: false,
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn aborts_after_max_retries() {
        let counter = Arc::new(AtomicU32::new(0));
        let c = counter.clone();
        let result = network_retry_loop(NetworkLoopConfig {
            action_cb: move || {
                let c = c.clone();
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    Err::<(), _>(TelegramError::Network("always fail".into()))
                }
            },
            on_err_cb: None::<fn(&TelegramError)>,
            description: "abort-test",
            interval: 0.0,
            stop_rx: None,
            is_running: None,
            max_retries: 2,
            repeat_on_success: false,
        })
        .await;
        // The original error variant must be propagated to the caller.
        assert!(matches!(result, Err(TelegramError::Network(_))));
        // Runs 0 and 1 are retried; run 2 exhausts the budget.
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn invalid_token_aborts_immediately() {
        let counter = Arc::new(AtomicU32::new(0));
        let c = counter.clone();
        let result = network_retry_loop(NetworkLoopConfig {
            action_cb: move || {
                let c = c.clone();
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    Err::<(), _>(TelegramError::InvalidToken("bad".into()))
                }
            },
            on_err_cb: None::<fn(&TelegramError)>,
            description: "token-test",
            interval: 0.0,
            stop_rx: None,
            is_running: None,
            max_retries: -1,
            repeat_on_success: false,
        })
        .await;
        assert!(matches!(result, Err(TelegramError::InvalidToken(_))));
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn on_err_cb_sees_generic_errors() {
        let seen = Arc::new(AtomicU32::new(0));
        let s = seen.clone();
        let result = network_retry_loop(NetworkLoopConfig {
            action_cb: || async { Err::<(), _>(TelegramError::Network("boom".into())) },
            on_err_cb: Some(move |_: &TelegramError| {
                s.fetch_add(1, Ordering::SeqCst);
            }),
            description: "callback-test",
            interval: 0.0,
            stop_rx: None,
            is_running: None,
            // No retry budget: the first failure aborts before any back-off sleep.
            max_retries: 0,
            repeat_on_success: false,
        })
        .await;
        assert!(matches!(result, Err(TelegramError::Network(_))));
        assert_eq!(seen.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn timed_out_respects_max_retries() {
        let counter = Arc::new(AtomicU32::new(0));
        let c = counter.clone();
        let result = network_retry_loop(NetworkLoopConfig {
            action_cb: move || {
                let c = c.clone();
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    Err::<(), _>(TelegramError::TimedOut("slow".into()))
                }
            },
            on_err_cb: None::<fn(&TelegramError)>,
            description: "timeout-test",
            interval: 0.0,
            stop_rx: None,
            is_running: None,
            max_retries: 1,
            repeat_on_success: false,
        })
        .await;
        assert!(matches!(result, Err(TelegramError::TimedOut(_))));
        // Timeouts retry immediately: run 0 is retried, run 1 exhausts the budget.
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn not_running_skips_action() {
        let counter = Arc::new(AtomicU32::new(0));
        let c = counter.clone();
        let result = network_retry_loop(NetworkLoopConfig {
            action_cb: move || {
                let c = c.clone();
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    Ok(())
                }
            },
            on_err_cb: None::<fn(&TelegramError)>,
            description: "not-running-test",
            interval: 0.0,
            stop_rx: None,
            is_running: Some(Box::new(|| false)),
            max_retries: -1,
            repeat_on_success: true,
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn stop_signal_cancels_loop() {
        let (tx, rx) = watch::channel(false);
        let counter = Arc::new(AtomicU32::new(0));
        let c = counter.clone();

        let handle = tokio::spawn(async move {
            network_retry_loop(NetworkLoopConfig {
                action_cb: move || {
                    let c = c.clone();
                    async move {
                        c.fetch_add(1, Ordering::SeqCst);
                        Ok(())
                    }
                },
                on_err_cb: None::<fn(&TelegramError)>,
                description: "stop-test",
                interval: 0.01,
                stop_rx: Some(rx),
                is_running: None,
                max_retries: -1,
                repeat_on_success: true,
            })
            .await
        });

        tokio::time::sleep(Duration::from_millis(80)).await;
        tx.send(true).unwrap();
        let result = handle.await.unwrap();
        assert!(result.is_ok());
        assert!(counter.load(Ordering::SeqCst) >= 1);
    }
}