Skip to main content

rust_tg_bot_ext/
updater.rs

1//! Fetches updates via long polling or webhook and pushes them into a channel.
2//!
3//! Port of `telegram.ext._updater.Updater`.
4//!
5//! The `Updater` is the bridge between Telegram and the application: it either
6//! polls `getUpdates` or starts a webhook server, then forwards every
7//! `Update` into a `tokio::sync::mpsc` channel for the `Application` to
8//! consume.
9
10use std::sync::Arc;
11use std::time::Duration;
12
13use tokio::sync::{mpsc, watch, Mutex};
14use tracing::{debug, error, warn};
15
16use rust_tg_bot_raw::error::TelegramError;
17
18use crate::utils::network_loop::{network_retry_loop, NetworkLoopConfig};
19
20#[cfg(feature = "webhooks")]
21use tokio::sync::Notify;
22
23#[cfg(feature = "webhooks")]
24use crate::utils::webhook_handler::WebhookServer;
25
26#[cfg(feature = "webhooks")]
27use rust_tg_bot_raw::types::update::Update;
28
29// ---------------------------------------------------------------------------
30// Function types
31// ---------------------------------------------------------------------------
32
33/// A function that fetches updates from the Telegram API.
34/// Signature: `(offset, timeout, allowed_updates) -> Result<Vec<Value>>`.
35pub type GetUpdatesFn = Arc<
36    dyn Fn(
37            i64,
38            Duration,
39            Option<Vec<String>>,
40        ) -> std::pin::Pin<
41            Box<
42                dyn std::future::Future<Output = Result<Vec<serde_json::Value>, TelegramError>>
43                    + Send,
44            >,
45        > + Send
46        + Sync,
47>;
48
49/// A function that deletes the webhook. Signature: `(drop_pending) -> Result<()>`.
50pub type DeleteWebhookFn = Arc<
51    dyn Fn(
52            bool,
53        )
54            -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), TelegramError>> + Send>>
55        + Send
56        + Sync,
57>;
58
59// ---------------------------------------------------------------------------
60// Configuration types
61// ---------------------------------------------------------------------------
62
63/// Configuration for [`Updater::start_polling`].
64#[derive(Clone)]
65pub struct PollingConfig {
66    /// Interval between successive poll requests.
67    pub poll_interval: Duration,
68    /// Long-polling timeout sent to the Telegram API.
69    pub timeout: Duration,
70    /// Maximum number of retries during the bootstrap phase.
71    pub bootstrap_retries: i32,
72    /// List of update types the bot should receive, or `None` for all types.
73    pub allowed_updates: Option<Vec<String>>,
74    /// Whether to drop pending updates before starting the polling loop.
75    pub drop_pending_updates: bool,
76    /// The function used to call `getUpdates`.
77    pub get_updates: GetUpdatesFn,
78    /// The function used to delete the webhook during bootstrap.
79    pub delete_webhook: DeleteWebhookFn,
80}
81
/// Configuration for [`Updater::start_webhook`].
///
/// Build one with [`WebhookConfig::new`] plus the chained setters, or start
/// from [`Default`].
#[cfg(feature = "webhooks")]
#[derive(Clone)]
pub struct WebhookConfig {
    /// Local address the webhook server listens on (default `"127.0.0.1"`).
    pub listen: String,
    /// Local port the webhook server listens on (default `80`).
    pub port: u16,
    /// URL path the webhook server accepts updates on (default `""`).
    pub url_path: String,
    /// Public URL Telegram should deliver updates to.
    ///
    /// NOTE(review): `Updater::start_webhook` in this module does not read
    /// this field; presumably the caller registers it via `setWebhook`.
    pub webhook_url: Option<String>,
    /// Secret token handed to the webhook server for validating incoming
    /// requests (presumably the `X-Telegram-Bot-Api-Secret-Token` header).
    pub secret_token: Option<String>,
    /// Bootstrap retry limit (default `0`). Not read by
    /// `Updater::start_webhook` in this module.
    pub bootstrap_retries: i32,
    /// Whether to drop pending updates (default `false`). Not read by
    /// `Updater::start_webhook` in this module.
    pub drop_pending_updates: bool,
    /// Update types to receive, or `None` for all types. Not read by
    /// `Updater::start_webhook` in this module.
    pub allowed_updates: Option<Vec<String>>,
    /// Maximum number of simultaneous webhook connections (default `40`).
    /// Not read by `Updater::start_webhook` in this module.
    pub max_connections: u32,
    /// Path to a PEM-encoded TLS certificate file.
    ///
    /// When both `cert_path` and `key_path` are set the webhook server will
    /// serve over HTTPS using `tokio-rustls`. Requires the `webhooks-tls`
    /// feature.
    pub cert_path: Option<String>,
    /// Path to a PEM-encoded TLS private key file.
    ///
    /// When both `cert_path` and `key_path` are set the webhook server will
    /// serve over HTTPS using `tokio-rustls`. Requires the `webhooks-tls`
    /// feature.
    pub key_path: Option<String>,
}

/// `Debug` that never prints the secret token, so configs can be logged
/// without leaking credentials.
#[cfg(feature = "webhooks")]
impl std::fmt::Debug for WebhookConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("WebhookConfig")
            .field("listen", &self.listen)
            .field("port", &self.port)
            .field("url_path", &self.url_path)
            .field("webhook_url", &self.webhook_url)
            .field(
                "secret_token",
                &self.secret_token.as_ref().map(|_| "<redacted>"),
            )
            .field("bootstrap_retries", &self.bootstrap_retries)
            .field("drop_pending_updates", &self.drop_pending_updates)
            .field("allowed_updates", &self.allowed_updates)
            .field("max_connections", &self.max_connections)
            .field("cert_path", &self.cert_path)
            .field("key_path", &self.key_path)
            .finish()
    }
}
108
#[cfg(feature = "webhooks")]
impl Default for WebhookConfig {
    /// Local-only listener on `127.0.0.1:80` with an empty path, room for 40
    /// connections, and no webhook URL, secret token, update filter or TLS.
    fn default() -> Self {
        WebhookConfig {
            listen: String::from("127.0.0.1"),
            port: 80,
            url_path: String::new(),
            webhook_url: None,
            secret_token: None,
            cert_path: None,
            key_path: None,
            allowed_updates: None,
            drop_pending_updates: false,
            bootstrap_retries: 0,
            max_connections: 40,
        }
    }
}
127
#[cfg(feature = "webhooks")]
impl WebhookConfig {
    /// Create a new webhook config with the given URL.
    /// Defaults: listen 127.0.0.1:80, no secret token, no TLS.
    #[must_use]
    pub fn new(url: impl Into<String>) -> Self {
        Self {
            webhook_url: Some(url.into()),
            ..Default::default()
        }
    }

    // The setters below consume and return `self`. They are `#[must_use]`
    // because a statement like `config.port(8443);` would otherwise silently
    // discard the modified config.

    /// Set the listen address (default: "127.0.0.1").
    #[must_use]
    pub fn listen(mut self, addr: impl Into<String>) -> Self {
        self.listen = addr.into();
        self
    }

    /// Set the port (default: 80).
    #[must_use]
    pub fn port(mut self, port: u16) -> Self {
        self.port = port;
        self
    }

    /// Set the URL path the webhook listens on (default: "").
    #[must_use]
    pub fn url_path(mut self, path: impl Into<String>) -> Self {
        self.url_path = path.into();
        self
    }

    /// Set the secret token for webhook validation.
    #[must_use]
    pub fn secret_token(mut self, token: impl Into<String>) -> Self {
        self.secret_token = Some(token.into());
        self
    }

    /// Set the number of bootstrap retries (default: 0).
    #[must_use]
    pub fn bootstrap_retries(mut self, n: i32) -> Self {
        self.bootstrap_retries = n;
        self
    }

    /// Drop pending updates before starting (default: false).
    #[must_use]
    pub fn drop_pending_updates(mut self, drop: bool) -> Self {
        self.drop_pending_updates = drop;
        self
    }

    /// Set allowed update types.
    #[must_use]
    pub fn allowed_updates(mut self, types: Vec<String>) -> Self {
        self.allowed_updates = Some(types);
        self
    }

    /// Set max webhook connections (default: 40).
    #[must_use]
    pub fn max_connections(mut self, n: u32) -> Self {
        self.max_connections = n;
        self
    }

    /// Configure TLS with certificate and private key PEM files.
    ///
    /// When set, the webhook server will serve over HTTPS. The certificate
    /// file may contain the full chain. Requires the `webhooks-tls` feature
    /// to be enabled at compile time.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// let config = WebhookConfig::new("https://mybot.example.com/telegram")
    ///     .port(8443)
    ///     .url_path("/telegram")
    ///     .secret_token("my-secret")
    ///     .tls("/path/to/cert.pem", "/path/to/key.pem");
    /// ```
    #[must_use]
    pub fn tls(mut self, cert: impl Into<String>, key: impl Into<String>) -> Self {
        self.cert_path = Some(cert.into());
        self.key_path = Some(key.into());
        self
    }

    /// Returns `true` when both `cert_path` and `key_path` are configured.
    #[must_use]
    pub fn has_tls(&self) -> bool {
        self.cert_path.is_some() && self.key_path.is_some()
    }
}
214
215// ---------------------------------------------------------------------------
216// Updater
217// ---------------------------------------------------------------------------
218
219/// Fetches updates for the bot via long polling or webhooks and forwards
220/// them through [`take_update_rx`](Updater::take_update_rx).
221pub struct Updater {
222    update_tx: mpsc::Sender<serde_json::Value>,
223    update_rx: Mutex<Option<mpsc::Receiver<serde_json::Value>>>,
224    running: std::sync::atomic::AtomicBool,
225    initialized: std::sync::atomic::AtomicBool,
226    last_update_id: Mutex<i64>,
227    /// Sending `true` signals the polling loop to stop.
228    stop_tx: watch::Sender<bool>,
229    /// The webhook server, if one was started.
230    #[cfg(feature = "webhooks")]
231    httpd: Mutex<Option<Arc<WebhookServer>>>,
232}
233
234impl std::fmt::Debug for Updater {
235    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236        f.debug_struct("Updater")
237            .field("running", &self.is_running())
238            .field(
239                "initialized",
240                &self.initialized.load(std::sync::atomic::Ordering::Relaxed),
241            )
242            .finish()
243    }
244}
245
246impl Updater {
247    /// Create a new `Updater`.
248    ///
249    /// `channel_size` controls the bounded channel capacity.
250    pub fn new(channel_size: usize) -> Self {
251        let (update_tx, update_rx) = mpsc::channel(channel_size);
252        let (stop_tx, _stop_rx) = watch::channel(false);
253        Self {
254            update_tx,
255            update_rx: Mutex::new(Some(update_rx)),
256            running: false.into(),
257            initialized: false.into(),
258            last_update_id: Mutex::new(0),
259            stop_tx,
260            #[cfg(feature = "webhooks")]
261            httpd: Mutex::new(None),
262        }
263    }
264
265    /// Take ownership of the receiving end of the update channel. Can only be
266    /// called once; subsequent calls return `None`.
267    pub async fn take_update_rx(&self) -> Option<mpsc::Receiver<serde_json::Value>> {
268        self.update_rx.lock().await.take()
269    }
270    /// Returns `true` if the updater is currently running (polling or webhook).
271    pub fn is_running(&self) -> bool {
272        self.running.load(std::sync::atomic::Ordering::Relaxed)
273    }
274
275    // -----------------------------------------------------------------------
276    // Lifecycle
277    // -----------------------------------------------------------------------
278
279    /// Initialize the updater.
280    pub async fn initialize(&self) {
281        if self.initialized.load(std::sync::atomic::Ordering::Relaxed) {
282            debug!("Updater already initialized");
283            return;
284        }
285        self.initialized
286            .store(true, std::sync::atomic::Ordering::Relaxed);
287        debug!("Updater initialized");
288    }
289
290    /// Shut down the updater. Must not be called while still running.
291    pub async fn shutdown(&self) -> Result<(), UpdaterError> {
292        if self.is_running() {
293            return Err(UpdaterError::StillRunning);
294        }
295        if !self.initialized.load(std::sync::atomic::Ordering::Relaxed) {
296            debug!("Updater already shut down");
297            return Ok(());
298        }
299        self.initialized
300            .store(false, std::sync::atomic::Ordering::Relaxed);
301        debug!("Updater shut down");
302        Ok(())
303    }
304
305    // -----------------------------------------------------------------------
306    // Polling
307    // -----------------------------------------------------------------------
308
309    /// Start polling for updates.
310    ///
311    /// Returns immediately after the bootstrap phase completes. Updates are
312    /// sent through the channel returned by [`take_update_rx`](Self::take_update_rx).
313    pub async fn start_polling(
314        self: &Arc<Self>,
315        config: PollingConfig,
316    ) -> Result<(), UpdaterError> {
317        if self.is_running() {
318            return Err(UpdaterError::AlreadyRunning);
319        }
320        if !self.initialized.load(std::sync::atomic::Ordering::Relaxed) {
321            return Err(UpdaterError::NotInitialized);
322        }
323
324        self.running
325            .store(true, std::sync::atomic::Ordering::Relaxed);
326
327        // Reset the stop signal from any prior run.
328        let _ = self.stop_tx.send(false);
329
330        // Bootstrap: delete any existing webhook.
331        let delete_fn = config.delete_webhook.clone();
332        let drop_pending = config.drop_pending_updates;
333        let bootstrap_retries = config.bootstrap_retries;
334
335        if let Err(e) = self
336            .bootstrap_delete_webhook(delete_fn, drop_pending, bootstrap_retries)
337            .await
338        {
339            self.running
340                .store(false, std::sync::atomic::Ordering::Relaxed);
341            return Err(UpdaterError::Bootstrap(e.to_string()));
342        }
343
344        debug!("Bootstrap complete, starting polling loop");
345
346        let updater = Arc::clone(self);
347        let stop_rx = self.stop_tx.subscribe();
348
349        tokio::spawn(async move {
350            let tx = updater.update_tx.clone();
351            let timeout = config.timeout;
352            let poll_interval = config.poll_interval;
353            let allowed = config.allowed_updates.clone();
354            let get_updates_fn = config.get_updates.clone();
355
356            let result = network_retry_loop(NetworkLoopConfig {
357                action_cb: || {
358                    let tx = tx.clone();
359                    let updater_inner = updater.clone();
360                    let allowed_inner = allowed.clone();
361                    let get_fn = get_updates_fn.clone();
362                    async move {
363                        let last_id = { *updater_inner.last_update_id.lock().await };
364                        let updates: Vec<serde_json::Value> =
365                            get_fn(last_id, timeout, allowed_inner).await?;
366                        if !updates.is_empty() {
367                            if !updater_inner.is_running() {
368                                warn!(
369                                    "Updater stopped unexpectedly. Pulled updates will be \
370                                     ignored and pulled again on restart."
371                                );
372                                return Ok(());
373                            }
374                            for update in &updates {
375                                if let Err(e) = tx.send(update.clone()).await {
376                                    error!("Failed to enqueue update: {e}");
377                                }
378                            }
379                            if let Some(last) = updates.last() {
380                                if let Some(uid) = last.get("update_id").and_then(|v| v.as_i64()) {
381                                    *updater_inner.last_update_id.lock().await = uid + 1;
382                                }
383                            }
384                        }
385                        Ok(())
386                    }
387                },
388                on_err_cb: Some(|e: &TelegramError| {
389                    error!("Error while polling for updates: {e}");
390                }),
391                description: "Polling Updates",
392                interval: poll_interval.as_secs_f64(),
393                stop_rx: Some(stop_rx),
394                is_running: Some(Box::new({
395                    let u = updater.clone();
396                    move || u.is_running()
397                })),
398                max_retries: -1,
399                repeat_on_success: true,
400            })
401            .await;
402
403            if let Err(e) = result {
404                error!("Polling loop exited with error: {e}");
405            }
406        });
407
408        Ok(())
409    }
410
411    // -----------------------------------------------------------------------
412    // Webhook
413    // -----------------------------------------------------------------------
414
415    /// Start a webhook server to receive updates.
416    #[cfg(feature = "webhooks")]
417    pub async fn start_webhook(
418        self: &Arc<Self>,
419        config: WebhookConfig,
420    ) -> Result<(), UpdaterError> {
421        if self.is_running() {
422            return Err(UpdaterError::AlreadyRunning);
423        }
424        if !self.initialized.load(std::sync::atomic::Ordering::Relaxed) {
425            return Err(UpdaterError::NotInitialized);
426        }
427
428        self.running
429            .store(true, std::sync::atomic::Ordering::Relaxed);
430        let _ = self.stop_tx.send(false);
431
432        // WebhookServer expects Sender<Update> but the updater channel carries
433        // serde_json::Value. Bridge the two with an intermediate typed channel.
434        let (typed_tx, mut typed_rx) = mpsc::channel::<Update>(256);
435        let value_tx = self.update_tx.clone();
436        tokio::spawn(async move {
437            while let Some(update) = typed_rx.recv().await {
438                match serde_json::to_value(&update) {
439                    Ok(v) => {
440                        let _ = value_tx.send(v).await;
441                    }
442                    Err(e) => {
443                        error!("Failed to serialize Update to Value: {e}");
444                    }
445                }
446            }
447        });
448
449        // Build the TLS configuration if paths are provided.
450        #[cfg(feature = "webhooks-tls")]
451        let tls_config = if config.has_tls() {
452            let cert_path = config
453                .cert_path
454                .as_deref()
455                .expect("cert_path checked by has_tls");
456            let key_path = config
457                .key_path
458                .as_deref()
459                .expect("key_path checked by has_tls");
460            match crate::utils::webhook_handler::TlsConfig::from_pem_files(cert_path, key_path)
461                .await
462            {
463                Ok(tls) => Some(tls),
464                Err(e) => {
465                    self.running
466                        .store(false, std::sync::atomic::Ordering::Relaxed);
467                    return Err(UpdaterError::Bootstrap(format!(
468                        "TLS configuration failed: {e}"
469                    )));
470                }
471            }
472        } else {
473            None
474        };
475
476        // Warn at runtime if TLS paths were set but the feature is not enabled.
477        #[cfg(not(feature = "webhooks-tls"))]
478        if config.has_tls() {
479            warn!(
480                "TLS cert_path/key_path are set but the `webhooks-tls` feature is not enabled. \
481                 The server will start without TLS. Enable the `webhooks-tls` feature to use HTTPS."
482            );
483        }
484
485        let server = Arc::new(WebhookServer::new(
486            &config.listen,
487            config.port,
488            &config.url_path,
489            typed_tx,
490            config.secret_token,
491            #[cfg(feature = "webhooks-tls")]
492            tls_config,
493        ));
494
495        let ready = Arc::new(Notify::new());
496        let ready_clone = ready.clone();
497
498        let srv = server.clone();
499        tokio::spawn(async move {
500            if let Err(e) = srv.serve_forever(Some(ready_clone)).await {
501                error!("Webhook server error: {e}");
502            }
503        });
504
505        ready.notified().await;
506        debug!(
507            "Webhook server started on {}:{}",
508            config.listen, config.port
509        );
510
511        *self.httpd.lock().await = Some(server);
512
513        Ok(())
514    }
515
516    // -----------------------------------------------------------------------
517    // Stop
518    // -----------------------------------------------------------------------
519
520    /// Stop the updater (both polling and webhook).
521    pub async fn stop(&self) -> Result<(), UpdaterError> {
522        if !self.is_running() {
523            return Err(UpdaterError::NotRunning);
524        }
525        debug!("Stopping updater");
526        self.running
527            .store(false, std::sync::atomic::Ordering::Relaxed);
528
529        // Signal the polling loop to stop.
530        let _ = self.stop_tx.send(true);
531
532        // Shut down webhook server if present.
533        #[cfg(feature = "webhooks")]
534        {
535            let httpd = self.httpd.lock().await;
536            if let Some(ref server) = *httpd {
537                server.shutdown();
538            }
539        }
540
541        debug!("Updater stopped");
542        Ok(())
543    }
544
545    // -----------------------------------------------------------------------
546    // Bootstrap helpers
547    // -----------------------------------------------------------------------
548
549    async fn bootstrap_delete_webhook(
550        &self,
551        delete_fn: DeleteWebhookFn,
552        drop_pending: bool,
553        max_retries: i32,
554    ) -> Result<(), TelegramError> {
555        debug!("Deleting webhook (bootstrap)");
556        network_retry_loop(NetworkLoopConfig {
557            action_cb: || {
558                let f = delete_fn.clone();
559                async move { f(drop_pending).await }
560            },
561            on_err_cb: None::<fn(&TelegramError)>,
562            description: "Bootstrap delete webhook",
563            interval: 1.0,
564            stop_rx: None,
565            is_running: None,
566            max_retries,
567            repeat_on_success: false,
568        })
569        .await
570    }
571}
572
573// ---------------------------------------------------------------------------
574// Errors
575// ---------------------------------------------------------------------------
576
577#[derive(Debug, thiserror::Error)]
578/// Errors that can occur within the [`Updater`] lifecycle.
579#[non_exhaustive]
580pub enum UpdaterError {
581    /// The updater is already running and cannot be started again.
582    #[error("this Updater is already running")]
583    AlreadyRunning,
584
585    /// The updater is not currently running.
586    #[error("this Updater is not running")]
587    NotRunning,
588
589    /// The updater has not been initialized yet.
590    #[error("this Updater was not initialized")]
591    NotInitialized,
592
593    /// The updater is still running and cannot be shut down.
594    #[error("this Updater is still running")]
595    StillRunning,
596
597    /// The bootstrap phase (e.g. deleting webhooks) failed.
598    #[error("bootstrap failed: {0}")]
599    Bootstrap(String),
600}
601
#[cfg(test)]
mod tests {
    use super::*;

    /// `get_updates` stub that always returns no updates.
    fn noop_get_updates() -> GetUpdatesFn {
        Arc::new(|_offset, _timeout, _allowed| Box::pin(async { Ok(Vec::new()) }))
    }

    /// `delete_webhook` stub that always succeeds.
    fn noop_delete_webhook() -> DeleteWebhookFn {
        Arc::new(|_drop_pending| Box::pin(async { Ok(()) }))
    }

    fn default_config() -> PollingConfig {
        PollingConfig {
            poll_interval: Duration::ZERO,
            timeout: Duration::from_secs(1),
            bootstrap_retries: 0,
            allowed_updates: None,
            drop_pending_updates: false,
            get_updates: noop_get_updates(),
            delete_webhook: noop_delete_webhook(),
        }
    }

    #[tokio::test]
    async fn lifecycle() {
        let updater = Arc::new(Updater::new(16));
        assert!(!updater.is_running());

        updater.initialize().await;

        // Can't stop before starting.
        assert!(updater.stop().await.is_err());

        updater.shutdown().await.unwrap();
    }

    #[tokio::test]
    async fn initialize_and_shutdown_are_idempotent() {
        let updater = Arc::new(Updater::new(16));
        updater.initialize().await;
        updater.initialize().await;
        updater.shutdown().await.unwrap();
        updater.shutdown().await.unwrap();

        // After shutdown the updater must be initialized again before starting.
        let result = updater.start_polling(default_config()).await;
        assert!(matches!(result, Err(UpdaterError::NotInitialized)));
    }

    #[tokio::test]
    async fn start_polling_requires_init() {
        let updater = Arc::new(Updater::new(16));
        let result = updater.start_polling(default_config()).await;
        assert!(matches!(result, Err(UpdaterError::NotInitialized)));
    }

    #[tokio::test]
    async fn start_and_stop_polling() {
        let updater = Arc::new(Updater::new(16));
        updater.initialize().await;
        updater.start_polling(default_config()).await.unwrap();
        assert!(updater.is_running());

        // Can't start twice.
        let result = updater.start_polling(default_config()).await;
        assert!(matches!(result, Err(UpdaterError::AlreadyRunning)));

        updater.stop().await.unwrap();
        assert!(!updater.is_running());

        // Can't stop twice.
        assert!(matches!(
            updater.stop().await,
            Err(UpdaterError::NotRunning)
        ));
    }

    #[tokio::test]
    async fn shutdown_while_running_fails() {
        let updater = Arc::new(Updater::new(16));
        updater.initialize().await;
        updater.start_polling(default_config()).await.unwrap();

        let result = updater.shutdown().await;
        assert!(matches!(result, Err(UpdaterError::StillRunning)));

        updater.stop().await.unwrap();
        updater.shutdown().await.unwrap();
    }

    #[tokio::test]
    async fn can_restart_polling_after_stop() {
        let updater = Arc::new(Updater::new(16));
        updater.initialize().await;

        updater.start_polling(default_config()).await.unwrap();
        updater.stop().await.unwrap();

        updater.start_polling(default_config()).await.unwrap();
        assert!(updater.is_running());

        updater.stop().await.unwrap();
        assert!(!updater.is_running());
    }

    #[tokio::test]
    async fn take_update_rx_once() {
        let updater = Arc::new(Updater::new(16));
        let rx = updater.take_update_rx().await;
        assert!(rx.is_some());
        let rx2 = updater.take_update_rx().await;
        assert!(rx2.is_none());
    }

    #[tokio::test]
    async fn polling_delivers_updates() {
        let updater = Arc::new(Updater::new(16));
        updater.initialize().await;

        let mut rx = updater.take_update_rx().await.unwrap();

        // A get_updates that returns one update then empty.
        let call_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
        let cc = call_count.clone();
        let get_fn: GetUpdatesFn = Arc::new(move |_offset, _timeout, _allowed| {
            let cc = cc.clone();
            Box::pin(async move {
                let n = cc.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                if n == 0 {
                    Ok(vec![serde_json::json!({"update_id": 100, "message": {}})])
                } else {
                    Ok(Vec::new())
                }
            })
        });

        let config = PollingConfig {
            poll_interval: Duration::from_millis(10),
            timeout: Duration::from_secs(1),
            bootstrap_retries: 0,
            allowed_updates: None,
            drop_pending_updates: false,
            get_updates: get_fn,
            delete_webhook: noop_delete_webhook(),
        };

        updater.start_polling(config).await.unwrap();

        // Should receive the update within a reasonable time.
        let update = tokio::time::timeout(Duration::from_secs(2), rx.recv())
            .await
            .expect("timeout waiting for update")
            .expect("channel closed");

        assert_eq!(update["update_id"], 100);

        updater.stop().await.unwrap();
    }
}