1use std::sync::Arc;
11use std::time::Duration;
12
13use tokio::sync::{mpsc, watch, Mutex};
14use tracing::{debug, error, warn};
15
16use rust_tg_bot_raw::error::TelegramError;
17
18use crate::utils::network_loop::{network_retry_loop, NetworkLoopConfig};
19
20#[cfg(feature = "webhooks")]
21use tokio::sync::Notify;
22
23#[cfg(feature = "webhooks")]
24use crate::utils::webhook_handler::WebhookServer;
25
26#[cfg(feature = "webhooks")]
27use rust_tg_bot_raw::types::update::Update;
28
29pub type GetUpdatesFn = Arc<
36 dyn Fn(
37 i64,
38 Duration,
39 Option<Vec<String>>,
40 ) -> std::pin::Pin<
41 Box<
42 dyn std::future::Future<Output = Result<Vec<serde_json::Value>, TelegramError>>
43 + Send,
44 >,
45 > + Send
46 + Sync,
47>;
48
49pub type DeleteWebhookFn = Arc<
51 dyn Fn(
52 bool,
53 )
54 -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), TelegramError>> + Send>>
55 + Send
56 + Sync,
57>;
58
59#[derive(Clone)]
65pub struct PollingConfig {
66 pub poll_interval: Duration,
68 pub timeout: Duration,
70 pub bootstrap_retries: i32,
72 pub allowed_updates: Option<Vec<String>>,
74 pub drop_pending_updates: bool,
76 pub get_updates: GetUpdatesFn,
78 pub delete_webhook: DeleteWebhookFn,
80}
81
/// Settings consumed by `Updater::start_webhook`.
///
/// Build one with `WebhookConfig::new` plus the builder methods, or with a struct
/// literal on top of `Default::default()`.
#[cfg(feature = "webhooks")]
#[derive(Clone)]
pub struct WebhookConfig {
    /// Address passed to the webhook server (logged as `listen:port` on start).
    pub listen: String,
    /// Port passed to the webhook server.
    pub port: u16,
    /// URL path passed to the webhook server.
    pub url_path: String,
    /// Public webhook URL.
    ///
    /// NOTE(review): `Updater::start_webhook` does not read this field, nor
    /// `bootstrap_retries`, `drop_pending_updates`, `allowed_updates` or
    /// `max_connections`; presumably callers use them when registering the
    /// webhook with Telegram — confirm.
    pub webhook_url: Option<String>,
    /// Secret token handed to the webhook server.
    /// Presumably compared with the token Telegram sends with each request — confirm.
    pub secret_token: Option<String>,
    /// Retry limit intended for webhook bootstrap (see the note on `webhook_url`).
    pub bootstrap_retries: i32,
    /// Whether pending updates should be dropped (see the note on `webhook_url`).
    pub drop_pending_updates: bool,
    /// Optional allow-list of update types (see the note on `webhook_url`).
    pub allowed_updates: Option<Vec<String>>,
    /// Connection limit (see the note on `webhook_url`).
    pub max_connections: u32,
    /// Path to a PEM certificate.
    /// Used only when `key_path` is also set and the `webhooks-tls` feature is enabled.
    pub cert_path: Option<String>,
    /// Path to a PEM private key.
    /// Used only when `cert_path` is also set and the `webhooks-tls` feature is enabled.
    pub key_path: Option<String>,
}
108
#[cfg(feature = "webhooks")]
impl Default for WebhookConfig {
    /// Returns the baseline webhook settings:
    /// - listen on `127.0.0.1`, port 80, with an empty URL path;
    /// - no webhook URL, no secret token and no TLS files;
    /// - `bootstrap_retries` of 0, `drop_pending_updates` off, no `allowed_updates`;
    /// - `max_connections` of 40.
    fn default() -> Self {
        let listen = String::from("127.0.0.1");
        Self {
            listen,
            port: 80,
            url_path: String::new(),
            webhook_url: None,
            secret_token: None,
            bootstrap_retries: 0,
            drop_pending_updates: false,
            allowed_updates: None,
            max_connections: 40,
            cert_path: None,
            key_path: None,
        }
    }
}
127
#[cfg(feature = "webhooks")]
impl WebhookConfig {
    /// Creates a configuration whose `webhook_url` is `url`.
    ///
    /// Every other field takes its `Default` value.
    pub fn new(url: impl Into<String>) -> Self {
        Self {
            webhook_url: Some(url.into()),
            ..Self::default()
        }
    }

    /// Replaces the address passed to the webhook server.
    pub fn listen(self, addr: impl Into<String>) -> Self {
        Self {
            listen: addr.into(),
            ..self
        }
    }

    /// Replaces the port passed to the webhook server.
    pub fn port(self, port: u16) -> Self {
        Self { port, ..self }
    }

    /// Replaces the URL path passed to the webhook server.
    pub fn url_path(self, path: impl Into<String>) -> Self {
        Self {
            url_path: path.into(),
            ..self
        }
    }

    /// Sets the secret token handed to the webhook server.
    pub fn secret_token(self, token: impl Into<String>) -> Self {
        Self {
            secret_token: Some(token.into()),
            ..self
        }
    }

    /// Sets `bootstrap_retries`.
    pub fn bootstrap_retries(self, n: i32) -> Self {
        Self {
            bootstrap_retries: n,
            ..self
        }
    }

    /// Sets `drop_pending_updates`.
    pub fn drop_pending_updates(self, drop: bool) -> Self {
        Self {
            drop_pending_updates: drop,
            ..self
        }
    }

    /// Sets the allow-list of update types.
    pub fn allowed_updates(self, types: Vec<String>) -> Self {
        Self {
            allowed_updates: Some(types),
            ..self
        }
    }

    /// Sets `max_connections`.
    pub fn max_connections(self, n: u32) -> Self {
        Self {
            max_connections: n,
            ..self
        }
    }

    /// Sets both PEM paths (certificate, then private key) used for HTTPS.
    ///
    /// They are only loaded when the `webhooks-tls` feature is enabled.
    pub fn tls(self, cert: impl Into<String>, key: impl Into<String>) -> Self {
        Self {
            cert_path: Some(cert.into()),
            key_path: Some(key.into()),
            ..self
        }
    }

    /// Returns `true` only when both a certificate path and a key path are set.
    pub fn has_tls(&self) -> bool {
        matches!((&self.cert_path, &self.key_path), (Some(_), Some(_)))
    }
}
214
215pub struct Updater {
222 update_tx: mpsc::Sender<serde_json::Value>,
223 update_rx: Mutex<Option<mpsc::Receiver<serde_json::Value>>>,
224 running: std::sync::atomic::AtomicBool,
225 initialized: std::sync::atomic::AtomicBool,
226 last_update_id: Mutex<i64>,
227 stop_tx: watch::Sender<bool>,
229 #[cfg(feature = "webhooks")]
231 httpd: Mutex<Option<Arc<WebhookServer>>>,
232}
233
234impl std::fmt::Debug for Updater {
235 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236 f.debug_struct("Updater")
237 .field("running", &self.is_running())
238 .field(
239 "initialized",
240 &self.initialized.load(std::sync::atomic::Ordering::Relaxed),
241 )
242 .finish()
243 }
244}
245
246impl Updater {
247 pub fn new(channel_size: usize) -> Self {
251 let (update_tx, update_rx) = mpsc::channel(channel_size);
252 let (stop_tx, _stop_rx) = watch::channel(false);
253 Self {
254 update_tx,
255 update_rx: Mutex::new(Some(update_rx)),
256 running: false.into(),
257 initialized: false.into(),
258 last_update_id: Mutex::new(0),
259 stop_tx,
260 #[cfg(feature = "webhooks")]
261 httpd: Mutex::new(None),
262 }
263 }
264
265 pub async fn take_update_rx(&self) -> Option<mpsc::Receiver<serde_json::Value>> {
268 self.update_rx.lock().await.take()
269 }
270 pub fn is_running(&self) -> bool {
272 self.running.load(std::sync::atomic::Ordering::Relaxed)
273 }
274
275 pub async fn initialize(&self) {
281 if self.initialized.load(std::sync::atomic::Ordering::Relaxed) {
282 debug!("Updater already initialized");
283 return;
284 }
285 self.initialized
286 .store(true, std::sync::atomic::Ordering::Relaxed);
287 debug!("Updater initialized");
288 }
289
290 pub async fn shutdown(&self) -> Result<(), UpdaterError> {
292 if self.is_running() {
293 return Err(UpdaterError::StillRunning);
294 }
295 if !self.initialized.load(std::sync::atomic::Ordering::Relaxed) {
296 debug!("Updater already shut down");
297 return Ok(());
298 }
299 self.initialized
300 .store(false, std::sync::atomic::Ordering::Relaxed);
301 debug!("Updater shut down");
302 Ok(())
303 }
304
305 pub async fn start_polling(
314 self: &Arc<Self>,
315 config: PollingConfig,
316 ) -> Result<(), UpdaterError> {
317 if self.is_running() {
318 return Err(UpdaterError::AlreadyRunning);
319 }
320 if !self.initialized.load(std::sync::atomic::Ordering::Relaxed) {
321 return Err(UpdaterError::NotInitialized);
322 }
323
324 self.running
325 .store(true, std::sync::atomic::Ordering::Relaxed);
326
327 let _ = self.stop_tx.send(false);
329
330 let delete_fn = config.delete_webhook.clone();
332 let drop_pending = config.drop_pending_updates;
333 let bootstrap_retries = config.bootstrap_retries;
334
335 if let Err(e) = self
336 .bootstrap_delete_webhook(delete_fn, drop_pending, bootstrap_retries)
337 .await
338 {
339 self.running
340 .store(false, std::sync::atomic::Ordering::Relaxed);
341 return Err(UpdaterError::Bootstrap(e.to_string()));
342 }
343
344 debug!("Bootstrap complete, starting polling loop");
345
346 let updater = Arc::clone(self);
347 let stop_rx = self.stop_tx.subscribe();
348
349 tokio::spawn(async move {
350 let tx = updater.update_tx.clone();
351 let timeout = config.timeout;
352 let poll_interval = config.poll_interval;
353 let allowed = config.allowed_updates.clone();
354 let get_updates_fn = config.get_updates.clone();
355
356 let result = network_retry_loop(NetworkLoopConfig {
357 action_cb: || {
358 let tx = tx.clone();
359 let updater_inner = updater.clone();
360 let allowed_inner = allowed.clone();
361 let get_fn = get_updates_fn.clone();
362 async move {
363 let last_id = { *updater_inner.last_update_id.lock().await };
364 let updates: Vec<serde_json::Value> =
365 get_fn(last_id, timeout, allowed_inner).await?;
366 if !updates.is_empty() {
367 if !updater_inner.is_running() {
368 warn!(
369 "Updater stopped unexpectedly. Pulled updates will be \
370 ignored and pulled again on restart."
371 );
372 return Ok(());
373 }
374 for update in &updates {
375 if let Err(e) = tx.send(update.clone()).await {
376 error!("Failed to enqueue update: {e}");
377 }
378 }
379 if let Some(last) = updates.last() {
380 if let Some(uid) = last.get("update_id").and_then(|v| v.as_i64()) {
381 *updater_inner.last_update_id.lock().await = uid + 1;
382 }
383 }
384 }
385 Ok(())
386 }
387 },
388 on_err_cb: Some(|e: &TelegramError| {
389 error!("Error while polling for updates: {e}");
390 }),
391 description: "Polling Updates",
392 interval: poll_interval.as_secs_f64(),
393 stop_rx: Some(stop_rx),
394 is_running: Some(Box::new({
395 let u = updater.clone();
396 move || u.is_running()
397 })),
398 max_retries: -1,
399 repeat_on_success: true,
400 })
401 .await;
402
403 if let Err(e) = result {
404 error!("Polling loop exited with error: {e}");
405 }
406 });
407
408 Ok(())
409 }
410
411 #[cfg(feature = "webhooks")]
417 pub async fn start_webhook(
418 self: &Arc<Self>,
419 config: WebhookConfig,
420 ) -> Result<(), UpdaterError> {
421 if self.is_running() {
422 return Err(UpdaterError::AlreadyRunning);
423 }
424 if !self.initialized.load(std::sync::atomic::Ordering::Relaxed) {
425 return Err(UpdaterError::NotInitialized);
426 }
427
428 self.running
429 .store(true, std::sync::atomic::Ordering::Relaxed);
430 let _ = self.stop_tx.send(false);
431
432 let (typed_tx, mut typed_rx) = mpsc::channel::<Update>(256);
435 let value_tx = self.update_tx.clone();
436 tokio::spawn(async move {
437 while let Some(update) = typed_rx.recv().await {
438 match serde_json::to_value(&update) {
439 Ok(v) => {
440 let _ = value_tx.send(v).await;
441 }
442 Err(e) => {
443 error!("Failed to serialize Update to Value: {e}");
444 }
445 }
446 }
447 });
448
449 #[cfg(feature = "webhooks-tls")]
451 let tls_config = if config.has_tls() {
452 let cert_path = config
453 .cert_path
454 .as_deref()
455 .expect("cert_path checked by has_tls");
456 let key_path = config
457 .key_path
458 .as_deref()
459 .expect("key_path checked by has_tls");
460 match crate::utils::webhook_handler::TlsConfig::from_pem_files(cert_path, key_path)
461 .await
462 {
463 Ok(tls) => Some(tls),
464 Err(e) => {
465 self.running
466 .store(false, std::sync::atomic::Ordering::Relaxed);
467 return Err(UpdaterError::Bootstrap(format!(
468 "TLS configuration failed: {e}"
469 )));
470 }
471 }
472 } else {
473 None
474 };
475
476 #[cfg(not(feature = "webhooks-tls"))]
478 if config.has_tls() {
479 warn!(
480 "TLS cert_path/key_path are set but the `webhooks-tls` feature is not enabled. \
481 The server will start without TLS. Enable the `webhooks-tls` feature to use HTTPS."
482 );
483 }
484
485 let server = Arc::new(WebhookServer::new(
486 &config.listen,
487 config.port,
488 &config.url_path,
489 typed_tx,
490 config.secret_token,
491 #[cfg(feature = "webhooks-tls")]
492 tls_config,
493 ));
494
495 let ready = Arc::new(Notify::new());
496 let ready_clone = ready.clone();
497
498 let srv = server.clone();
499 tokio::spawn(async move {
500 if let Err(e) = srv.serve_forever(Some(ready_clone)).await {
501 error!("Webhook server error: {e}");
502 }
503 });
504
505 ready.notified().await;
506 debug!(
507 "Webhook server started on {}:{}",
508 config.listen, config.port
509 );
510
511 *self.httpd.lock().await = Some(server);
512
513 Ok(())
514 }
515
516 pub async fn stop(&self) -> Result<(), UpdaterError> {
522 if !self.is_running() {
523 return Err(UpdaterError::NotRunning);
524 }
525 debug!("Stopping updater");
526 self.running
527 .store(false, std::sync::atomic::Ordering::Relaxed);
528
529 let _ = self.stop_tx.send(true);
531
532 #[cfg(feature = "webhooks")]
534 {
535 let httpd = self.httpd.lock().await;
536 if let Some(ref server) = *httpd {
537 server.shutdown();
538 }
539 }
540
541 debug!("Updater stopped");
542 Ok(())
543 }
544
545 async fn bootstrap_delete_webhook(
550 &self,
551 delete_fn: DeleteWebhookFn,
552 drop_pending: bool,
553 max_retries: i32,
554 ) -> Result<(), TelegramError> {
555 debug!("Deleting webhook (bootstrap)");
556 network_retry_loop(NetworkLoopConfig {
557 action_cb: || {
558 let f = delete_fn.clone();
559 async move { f(drop_pending).await }
560 },
561 on_err_cb: None::<fn(&TelegramError)>,
562 description: "Bootstrap delete webhook",
563 interval: 1.0,
564 stop_rx: None,
565 is_running: None,
566 max_retries,
567 repeat_on_success: false,
568 })
569 .await
570 }
571}
572
573#[derive(Debug, thiserror::Error)]
578#[non_exhaustive]
580pub enum UpdaterError {
581 #[error("this Updater is already running")]
583 AlreadyRunning,
584
585 #[error("this Updater is not running")]
587 NotRunning,
588
589 #[error("this Updater was not initialized")]
591 NotInitialized,
592
593 #[error("this Updater is still running")]
595 StillRunning,
596
597 #[error("bootstrap failed: {0}")]
599 Bootstrap(String),
600}
601
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// `get_updates` stub that always yields an empty batch.
    fn noop_get_updates() -> GetUpdatesFn {
        Arc::new(|_offset, _timeout, _allowed| Box::pin(async { Ok(Vec::new()) }))
    }

    /// `delete_webhook` stub that always succeeds.
    fn noop_delete_webhook() -> DeleteWebhookFn {
        Arc::new(|_drop_pending| Box::pin(async { Ok(()) }))
    }

    /// Polling settings built from the no-op stubs, with no delay between polls.
    fn default_config() -> PollingConfig {
        PollingConfig {
            get_updates: noop_get_updates(),
            delete_webhook: noop_delete_webhook(),
            poll_interval: Duration::ZERO,
            timeout: Duration::from_secs(1),
            bootstrap_retries: 0,
            allowed_updates: None,
            drop_pending_updates: false,
        }
    }

    /// A freshly initialized updater with a 16-slot queue.
    async fn initialized_updater() -> Arc<Updater> {
        let updater = Arc::new(Updater::new(16));
        updater.initialize().await;
        updater
    }

    #[tokio::test]
    async fn lifecycle() {
        let updater = Arc::new(Updater::new(16));
        assert!(!updater.is_running());

        updater.initialize().await;

        // Stopping an updater that never started must fail.
        assert!(updater.stop().await.is_err());

        updater.shutdown().await.unwrap();
    }

    #[tokio::test]
    async fn start_polling_requires_init() {
        let updater = Arc::new(Updater::new(16));
        let outcome = updater.start_polling(default_config()).await;
        assert!(matches!(outcome, Err(UpdaterError::NotInitialized)));
    }

    #[tokio::test]
    async fn start_and_stop_polling() {
        let updater = initialized_updater().await;
        updater.start_polling(default_config()).await.unwrap();
        assert!(updater.is_running());

        // A second start while running is rejected.
        let second = updater.start_polling(default_config()).await;
        assert!(matches!(second, Err(UpdaterError::AlreadyRunning)));

        updater.stop().await.unwrap();
        assert!(!updater.is_running());
    }

    #[tokio::test]
    async fn take_update_rx_once() {
        let updater = Arc::new(Updater::new(16));
        assert!(updater.take_update_rx().await.is_some());
        assert!(updater.take_update_rx().await.is_none());
    }

    #[tokio::test]
    async fn polling_delivers_updates() {
        let updater = initialized_updater().await;
        let mut rx = updater.take_update_rx().await.unwrap();

        // Only the first request returns an update; later polls come back empty.
        let calls = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&calls);
        let get_fn: GetUpdatesFn = Arc::new(move |_offset, _timeout, _allowed| {
            let counter = Arc::clone(&counter);
            Box::pin(async move {
                let previous = counter.fetch_add(1, Ordering::Relaxed);
                if previous == 0 {
                    Ok(vec![serde_json::json!({"update_id": 100, "message": {}})])
                } else {
                    Ok(Vec::new())
                }
            })
        });

        let config = PollingConfig {
            poll_interval: Duration::from_millis(10),
            get_updates: get_fn,
            ..default_config()
        };
        updater.start_polling(config).await.unwrap();

        let received = tokio::time::timeout(Duration::from_secs(2), rx.recv())
            .await
            .expect("timeout waiting for update")
            .expect("channel closed");
        assert_eq!(received["update_id"], 100);

        updater.stop().await.unwrap();
    }
}