Skip to main content

rust_tg_bot_ext/
update_processor.rs

//! Semaphore-based concurrent update processing.
//!
//! Ported from `python-telegram-bot/src/telegram/ext/_baseupdateprocessor.py`.
//!
//! Provides the [`UpdateProcessor`] async trait, the [`BaseUpdateProcessor`] wrapper that bounds
//! concurrency with a semaphore, and [`SimpleUpdateProcessor`] (the default implementation, which
//! immediately awaits each coroutine).
7
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::atomic::{AtomicUsize, Ordering};
11use std::sync::Arc;
12
13use tokio::sync::Semaphore;
14
15use rust_tg_bot_raw::types::update::Update;
16
17// ---------------------------------------------------------------------------
18// Error type
19// ---------------------------------------------------------------------------
20
21/// Errors that may occur during update processing.
22#[derive(Debug, thiserror::Error)]
23#[non_exhaustive]
24pub enum UpdateProcessorError {
25    /// `max_concurrent_updates` was not a positive integer.
26    #[error("`max_concurrent_updates` must be a positive integer")]
27    InvalidConcurrency,
28
29    /// An inner handler returned an error.
30    #[error("Handler error: {0}")]
31    Handler(Box<dyn std::error::Error + Send + Sync>),
32}
33
34// ---------------------------------------------------------------------------
35// Trait
36// ---------------------------------------------------------------------------
37
38/// An abstract base for update processors.
39///
40/// Implementations control *how* update coroutines are driven (e.g. immediately awaited,
41/// batched, prioritised, etc.).
42///
43/// The [`process_update`](BaseUpdateProcessor::process_update) method is *final* -- it
44/// acquires the internal semaphore and then delegates to
45/// [`do_process_update`](BaseUpdateProcessor::do_process_update).
46#[async_trait::async_trait]
47pub trait UpdateProcessor: Send + Sync {
48    /// Custom implementation of how to process an update.  Must be implemented by the
49    /// concrete type.
50    ///
51    /// **Warning**: This method is called by
52    /// [`process_update`](BaseUpdateProcessor::process_update).  It should *not* be called
53    /// manually.
54    async fn do_process_update(
55        &self,
56        update: Arc<Update>,
57        coroutine: Pin<Box<dyn Future<Output = ()> + Send>>,
58    );
59
60    /// Called once before the processor starts handling updates.
61    async fn initialize(&self) {}
62
63    /// Called once when the processor is shutting down.
64    async fn shutdown(&self) {}
65}
66
67// ---------------------------------------------------------------------------
68// BaseUpdateProcessor -- semaphore wrapper
69// ---------------------------------------------------------------------------
70
71/// Wraps any [`UpdateProcessor`] with a semaphore to bound concurrency.
72pub struct BaseUpdateProcessor {
73    inner: Box<dyn UpdateProcessor>,
74    semaphore: Arc<Semaphore>,
75    max_concurrent_updates: usize,
76    /// Tracks how many permits are currently held so we can report
77    /// `current_concurrent_updates`.
78    active: AtomicUsize,
79}
80
81impl std::fmt::Debug for BaseUpdateProcessor {
82    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83        f.debug_struct("BaseUpdateProcessor")
84            .field("max_concurrent_updates", &self.max_concurrent_updates)
85            .field("active", &self.active.load(Ordering::Relaxed))
86            .finish()
87    }
88}
89
90impl BaseUpdateProcessor {
91    /// Creates a new `BaseUpdateProcessor`.
92    ///
93    /// # Errors
94    ///
95    /// Returns [`UpdateProcessorError::InvalidConcurrency`] if
96    /// `max_concurrent_updates` is zero.
97    pub fn new(
98        inner: Box<dyn UpdateProcessor>,
99        max_concurrent_updates: usize,
100    ) -> Result<Self, UpdateProcessorError> {
101        if max_concurrent_updates == 0 {
102            return Err(UpdateProcessorError::InvalidConcurrency);
103        }
104        Ok(Self {
105            inner,
106            semaphore: Arc::new(Semaphore::new(max_concurrent_updates)),
107            max_concurrent_updates,
108            active: AtomicUsize::new(0),
109        })
110    }
111
112    /// The maximum number of updates that can be processed concurrently.
113    #[must_use]
114    pub fn max_concurrent_updates(&self) -> usize {
115        self.max_concurrent_updates
116    }
117
118    /// A snapshot of the number of updates currently being processed.
119    #[must_use]
120    pub fn current_concurrent_updates(&self) -> usize {
121        self.active.load(Ordering::Relaxed)
122    }
123
124    /// Acquires the semaphore and then delegates to [`UpdateProcessor::do_process_update`].
125    pub async fn process_update(
126        &self,
127        update: Arc<Update>,
128        coroutine: Pin<Box<dyn Future<Output = ()> + Send>>,
129    ) {
130        let _permit = self
131            .semaphore
132            .acquire()
133            .await
134            .expect("semaphore should not be closed");
135        self.active.fetch_add(1, Ordering::Relaxed);
136        self.inner.do_process_update(update, coroutine).await;
137        self.active.fetch_sub(1, Ordering::Relaxed);
138    }
139
140    /// Delegates to the inner processor's `initialize`.
141    pub async fn initialize(&self) {
142        self.inner.initialize().await;
143    }
144
145    /// Delegates to the inner processor's `shutdown`.
146    pub async fn shutdown(&self) {
147        self.inner.shutdown().await;
148    }
149}
150
151// ---------------------------------------------------------------------------
152// SimpleUpdateProcessor
153// ---------------------------------------------------------------------------
154
/// The stock [`UpdateProcessor`]: each coroutine it receives is awaited right away, in place.
///
/// It imposes no limit of its own. The semaphore in the wrapping [`BaseUpdateProcessor`]
/// (see [`simple_processor`]) does the bounding. Mirrors python-telegram-bot, where it is used
/// when `ApplicationBuilder.concurrent_updates` is set to an integer.
#[derive(Default, Debug)]
pub struct SimpleUpdateProcessor;
161
162#[async_trait::async_trait]
163impl UpdateProcessor for SimpleUpdateProcessor {
164    async fn do_process_update(
165        &self,
166        _update: Arc<Update>,
167        coroutine: Pin<Box<dyn Future<Output = ()> + Send>>,
168    ) {
169        coroutine.await;
170    }
171}
172
173/// Convenience constructor that builds a [`BaseUpdateProcessor`] wrapping a
174/// [`SimpleUpdateProcessor`] with the given concurrency limit.
175///
176/// # Errors
177///
178/// Returns [`UpdateProcessorError::InvalidConcurrency`] if `max_concurrent_updates` is zero.
179pub fn simple_processor(
180    max_concurrent_updates: usize,
181) -> Result<BaseUpdateProcessor, UpdateProcessorError> {
182    BaseUpdateProcessor::new(Box::new(SimpleUpdateProcessor), max_concurrent_updates)
183}
184
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal `Update` carrying only an `update_id`.
    fn dummy_update() -> Update {
        serde_json::from_value(serde_json::json!({"update_id": 0})).unwrap()
    }

    /// Boxes an async block into the future type `process_update` expects.
    fn boxed(
        fut: impl Future<Output = ()> + Send + 'static,
    ) -> Pin<Box<dyn Future<Output = ()> + Send>> {
        Box::pin(fut)
    }

    #[tokio::test]
    async fn simple_processor_runs_coroutine() {
        let processor = simple_processor(1).unwrap();
        processor.initialize().await;

        let flag = Arc::new(std::sync::atomic::AtomicBool::new(false));
        let flag_in_task = Arc::clone(&flag);
        let work = boxed(async move {
            flag_in_task.store(true, Ordering::Relaxed);
        });

        processor.process_update(Arc::new(dummy_update()), work).await;
        assert!(flag.load(Ordering::Relaxed));
        // The update has finished, so nothing should still be reported as in flight.
        assert_eq!(processor.current_concurrent_updates(), 0);

        processor.shutdown().await;
    }

    #[test]
    fn zero_concurrency_rejected() {
        assert!(matches!(
            simple_processor(0),
            Err(UpdateProcessorError::InvalidConcurrency)
        ));
    }

    #[test]
    fn concurrent_updates_tracking() {
        let processor = simple_processor(4).unwrap();
        assert_eq!(processor.max_concurrent_updates(), 4);
        assert_eq!(processor.current_concurrent_updates(), 0);
    }

    #[tokio::test]
    async fn active_count_reported_while_processing() {
        let processor = Arc::new(simple_processor(4).unwrap());
        let observed = Arc::new(AtomicUsize::new(usize::MAX));

        let (inner_processor, inner_observed) = (Arc::clone(&processor), Arc::clone(&observed));
        let work = boxed(async move {
            inner_observed.store(inner_processor.current_concurrent_updates(), Ordering::SeqCst);
        });
        processor.process_update(Arc::new(dummy_update()), work).await;

        // Exactly one update was in flight while the coroutine ran, and none afterwards.
        assert_eq!(observed.load(Ordering::SeqCst), 1);
        assert_eq!(processor.current_concurrent_updates(), 0);
    }

    #[tokio::test]
    async fn concurrent_processing_bounded() {
        const LIMIT: usize = 2;
        const TASKS: usize = 10;

        let processor = Arc::new(simple_processor(LIMIT).unwrap());
        let in_flight = Arc::new(AtomicUsize::new(0));
        let max_seen = Arc::new(AtomicUsize::new(0));
        let completed = Arc::new(AtomicUsize::new(0));

        let handles: Vec<_> = (0..TASKS)
            .map(|_| {
                let processor = Arc::clone(&processor);
                let in_flight = Arc::clone(&in_flight);
                let max_seen = Arc::clone(&max_seen);
                let completed = Arc::clone(&completed);
                tokio::spawn(async move {
                    let work = boxed(async move {
                        let current = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                        // Record the maximum concurrent count observed.
                        max_seen.fetch_max(current, Ordering::SeqCst);
                        // Yield so other tasks get a chance to enter concurrently.
                        tokio::task::yield_now().await;
                        in_flight.fetch_sub(1, Ordering::SeqCst);
                        completed.fetch_add(1, Ordering::SeqCst);
                    });
                    processor.process_update(Arc::new(dummy_update()), work).await;
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }

        // Every update ran, and the semaphore kept at most LIMIT of them in flight at once.
        assert_eq!(completed.load(Ordering::SeqCst), TASKS);
        let peak = max_seen.load(Ordering::SeqCst);
        assert!(
            (1..=LIMIT).contains(&peak),
            "peak concurrency {peak} outside 1..={LIMIT}"
        );
        assert_eq!(processor.current_concurrent_updates(), 0);
    }
}