rust_tg_bot_ext/
update_processor.rs1use std::future::Future;
9use std::pin::Pin;
10use std::sync::atomic::{AtomicUsize, Ordering};
11use std::sync::Arc;
12
13use tokio::sync::Semaphore;
14
15use rust_tg_bot_raw::types::update::Update;
16
17#[derive(Debug, thiserror::Error)]
23#[non_exhaustive]
24pub enum UpdateProcessorError {
25 #[error("`max_concurrent_updates` must be a positive integer")]
27 InvalidConcurrency,
28
29 #[error("Handler error: {0}")]
31 Handler(Box<dyn std::error::Error + Send + Sync>),
32}
33
34#[async_trait::async_trait]
47pub trait UpdateProcessor: Send + Sync {
48 async fn do_process_update(
55 &self,
56 update: Arc<Update>,
57 coroutine: Pin<Box<dyn Future<Output = ()> + Send>>,
58 );
59
60 async fn initialize(&self) {}
62
63 async fn shutdown(&self) {}
65}
66
67pub struct BaseUpdateProcessor {
73 inner: Box<dyn UpdateProcessor>,
74 semaphore: Arc<Semaphore>,
75 max_concurrent_updates: usize,
76 active: AtomicUsize,
79}
80
81impl std::fmt::Debug for BaseUpdateProcessor {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 f.debug_struct("BaseUpdateProcessor")
84 .field("max_concurrent_updates", &self.max_concurrent_updates)
85 .field("active", &self.active.load(Ordering::Relaxed))
86 .finish()
87 }
88}
89
90impl BaseUpdateProcessor {
91 pub fn new(
98 inner: Box<dyn UpdateProcessor>,
99 max_concurrent_updates: usize,
100 ) -> Result<Self, UpdateProcessorError> {
101 if max_concurrent_updates == 0 {
102 return Err(UpdateProcessorError::InvalidConcurrency);
103 }
104 Ok(Self {
105 inner,
106 semaphore: Arc::new(Semaphore::new(max_concurrent_updates)),
107 max_concurrent_updates,
108 active: AtomicUsize::new(0),
109 })
110 }
111
112 #[must_use]
114 pub fn max_concurrent_updates(&self) -> usize {
115 self.max_concurrent_updates
116 }
117
118 #[must_use]
120 pub fn current_concurrent_updates(&self) -> usize {
121 self.active.load(Ordering::Relaxed)
122 }
123
124 pub async fn process_update(
126 &self,
127 update: Arc<Update>,
128 coroutine: Pin<Box<dyn Future<Output = ()> + Send>>,
129 ) {
130 let _permit = self
131 .semaphore
132 .acquire()
133 .await
134 .expect("semaphore should not be closed");
135 self.active.fetch_add(1, Ordering::Relaxed);
136 self.inner.do_process_update(update, coroutine).await;
137 self.active.fetch_sub(1, Ordering::Relaxed);
138 }
139
140 pub async fn initialize(&self) {
142 self.inner.initialize().await;
143 }
144
145 pub async fn shutdown(&self) {
147 self.inner.shutdown().await;
148 }
149}
150
/// An [`UpdateProcessor`] that awaits each update's future in place, adding no
/// behavior of its own.
///
/// This is a stateless marker type, so it cheaply derives the common value
/// traits (`Clone`, `Copy`, `PartialEq`, `Eq`) in addition to `Debug` and
/// `Default`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct SimpleUpdateProcessor;
161
162#[async_trait::async_trait]
163impl UpdateProcessor for SimpleUpdateProcessor {
164 async fn do_process_update(
165 &self,
166 _update: Arc<Update>,
167 coroutine: Pin<Box<dyn Future<Output = ()> + Send>>,
168 ) {
169 coroutine.await;
170 }
171}
172
173pub fn simple_processor(
180 max_concurrent_updates: usize,
181) -> Result<BaseUpdateProcessor, UpdateProcessorError> {
182 BaseUpdateProcessor::new(Box::new(SimpleUpdateProcessor), max_concurrent_updates)
183}
184
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal update deserialized from JSON with only `update_id` set.
    fn dummy_update() -> Update {
        serde_json::from_value(serde_json::json!({"update_id": 0})).unwrap()
    }

    #[tokio::test]
    async fn simple_processor_runs_coroutine() {
        let proc = simple_processor(1).unwrap();
        proc.initialize().await;

        let flag = Arc::new(std::sync::atomic::AtomicBool::new(false));
        let flag2 = Arc::clone(&flag);

        let fut: Pin<Box<dyn Future<Output = ()> + Send>> = Box::pin(async move {
            flag2.store(true, Ordering::Relaxed);
        });

        proc.process_update(Arc::new(dummy_update()), fut).await;
        assert!(flag.load(Ordering::Relaxed));
        // Once processing has returned, the update must no longer count as active.
        assert_eq!(proc.current_concurrent_updates(), 0);

        proc.shutdown().await;
    }

    #[test]
    fn zero_concurrency_rejected() {
        assert!(matches!(
            simple_processor(0),
            Err(UpdateProcessorError::InvalidConcurrency)
        ));
    }

    // Nothing here awaits, so no async runtime is needed.
    #[test]
    fn concurrent_updates_tracking() {
        let proc = simple_processor(4).unwrap();
        assert_eq!(proc.max_concurrent_updates(), 4);
        assert_eq!(proc.current_concurrent_updates(), 0);
    }

    #[tokio::test]
    async fn concurrent_processing_bounded() {
        let proc = Arc::new(simple_processor(2).unwrap());
        let counter = Arc::new(AtomicUsize::new(0));
        let max_seen = Arc::new(AtomicUsize::new(0));

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let p = Arc::clone(&proc);
                let c = Arc::clone(&counter);
                let m = Arc::clone(&max_seen);
                tokio::spawn(async move {
                    let fut: Pin<Box<dyn Future<Output = ()> + Send>> = Box::pin(async move {
                        let current = c.fetch_add(1, Ordering::SeqCst) + 1;
                        m.fetch_max(current, Ordering::SeqCst);
                        // Yield so other tasks get a chance to enter concurrently.
                        tokio::task::yield_now().await;
                        c.fetch_sub(1, Ordering::SeqCst);
                    });
                    p.process_update(Arc::new(dummy_update()), fut).await;
                })
            })
            .collect();

        for h in handles {
            h.await.unwrap();
        }

        let peak = max_seen.load(Ordering::SeqCst);
        // Guard against a vacuous pass where no coroutine ever ran.
        assert!(peak >= 1, "no update was processed");
        assert!(peak <= 2, "concurrency limit exceeded: saw {peak} updates at once");
        assert_eq!(counter.load(Ordering::SeqCst), 0);
        assert_eq!(proc.current_concurrent_updates(), 0);
    }
}