1use crate::{
2 dispatching::{
3 DefaultKey, DpHandlerDescription, ShutdownToken,
4 distribution::default_distribution_function,
5 },
6 error_handlers::{ErrorHandler, LoggingErrorHandler},
7 requests::{Request, Requester},
8 stop::StopToken,
9 types::{Update, UpdateKind},
10 update_listeners::{self, UpdateListener},
11};
12
13use dptree::di::DependencyMap;
14use either::Either;
15use futures::{
16 FutureExt as _, StreamExt as _,
17 future::{self, BoxFuture},
18 stream::FuturesUnordered,
19};
20use tokio_stream::wrappers::ReceiverStream;
21
22use std::{
23 collections::HashMap,
24 fmt::Debug,
25 future::Future,
26 hash::Hash,
27 ops::{ControlFlow, Deref},
28 pin::pin,
29 sync::{
30 Arc,
31 atomic::{AtomicBool, AtomicU32, Ordering},
32 },
33};
34
35pub struct DispatcherBuilder<R, Err, Key> {
40 bot: R,
41 dependencies: DependencyMap,
42 handler: Arc<UpdateHandler<Err>>,
43 default_handler: DefaultHandler,
44 error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
45 ctrlc_handler: bool,
46 distribution_f: fn(&Update) -> Option<Key>,
47 worker_queue_size: usize,
48}
49
50impl<R, Err, Key> DispatcherBuilder<R, Err, Key>
51where
52 R: Clone + Requester + Clone + Send + Sync + 'static,
53 Err: Debug + Send + Sync + 'static,
54{
55 #[must_use]
59 pub fn default_handler<H, Fut>(self, handler: H) -> Self
60 where
61 H: Fn(Arc<Update>) -> Fut + Send + Sync + 'static,
62 Fut: Future<Output = ()> + Send + 'static,
63 {
64 let handler = Arc::new(handler);
65
66 Self {
67 default_handler: Arc::new(move |upd| {
68 let handler = Arc::clone(&handler);
69 Box::pin(handler(upd))
70 }),
71 ..self
72 }
73 }
74
75 #[must_use]
79 pub fn error_handler(self, handler: Arc<dyn ErrorHandler<Err> + Send + Sync>) -> Self {
80 Self { error_handler: handler, ..self }
81 }
82
83 #[must_use]
87 pub fn dependencies(self, dependencies: DependencyMap) -> Self {
88 Self { dependencies, ..self }
89 }
90
91 #[cfg(feature = "ctrlc_handler")]
95 #[must_use]
96 pub fn enable_ctrlc_handler(self) -> Self {
97 Self { ctrlc_handler: true, ..self }
98 }
99
100 #[must_use]
104 pub fn worker_queue_size(self, size: usize) -> Self {
105 Self { worker_queue_size: size, ..self }
106 }
107
108 #[must_use]
112 #[deprecated(since = "0.15.0", note = "This method is a no-op; you can just remove it.")]
113 pub fn stack_size(self, _size: usize) -> Self {
114 self
115 }
116
117 #[must_use]
173 pub fn distribution_function<K>(
174 self,
175 f: fn(&Update) -> Option<K>,
176 ) -> DispatcherBuilder<R, Err, K>
177 where
178 K: Hash + Eq,
179 {
180 let Self {
181 bot,
182 dependencies,
183 handler,
184 default_handler,
185 error_handler,
186 ctrlc_handler,
187 distribution_f: _,
188 worker_queue_size,
189 } = self;
190
191 DispatcherBuilder {
192 bot,
193 dependencies,
194 handler,
195 default_handler,
196 error_handler,
197 ctrlc_handler,
198 distribution_f: f,
199 worker_queue_size,
200 }
201 }
202
203 #[must_use]
209 pub fn build(self) -> Dispatcher<R, Err, Key> {
210 let Self {
211 bot,
212 dependencies,
213 handler,
214 default_handler,
215 error_handler,
216 distribution_f,
217 worker_queue_size,
218 ctrlc_handler,
219 } = self;
220
221 dptree::type_check(
222 handler.sig(),
223 &dependencies,
224 &[
225 dptree::Type::of::<R>(),
226 dptree::Type::of::<teloxide_core_ng::types::Update>(),
227 dptree::Type::of::<teloxide_core_ng::types::Me>(),
228 ],
229 );
230
231 let _ = ctrlc_handler;
233
234 let dp = Dispatcher {
235 bot,
236 dependencies,
237 handler,
238 default_handler,
239 error_handler,
240 state: ShutdownToken::new(),
241 distribution_f,
242 worker_queue_size,
243 workers: HashMap::new(),
244 default_worker: None,
245 current_number_of_active_workers: Default::default(),
246 max_number_of_active_workers: Default::default(),
247 };
248
249 #[cfg(feature = "ctrlc_handler")]
250 {
251 if ctrlc_handler {
252 let mut dp = dp;
253 dp.setup_ctrlc_handler_inner();
254 return dp;
255 }
256 }
257
258 dp
259 }
260}
261
262pub struct Dispatcher<R, Err, Key> {
275 bot: R,
276 dependencies: DependencyMap,
277
278 handler: Arc<UpdateHandler<Err>>,
279 default_handler: DefaultHandler,
280
281 distribution_f: fn(&Update) -> Option<Key>,
282 worker_queue_size: usize,
283 current_number_of_active_workers: Arc<AtomicU32>,
284 max_number_of_active_workers: Arc<AtomicU32>,
285 workers: HashMap<Key, Worker>,
287 default_worker: Option<Worker>,
289
290 error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
291
292 state: ShutdownToken,
293}
294
295struct Worker {
296 tx: tokio::sync::mpsc::Sender<Update>,
297 handle: tokio::task::JoinHandle<()>,
298 is_waiting: Arc<AtomicBool>,
299}
300
301pub type UpdateHandler<Err> = dptree::Handler<'static, Result<(), Err>, DpHandlerDescription>;
306
307type DefaultHandler = Arc<dyn Fn(Arc<Update>) -> BoxFuture<'static, ()> + Send + Sync>;
308
309impl<R, Err> Dispatcher<R, Err, DefaultKey>
310where
311 R: Requester + Clone + Send + Sync + 'static,
312 Err: Send + Sync + 'static,
313{
314 #[must_use]
316 pub fn builder(bot: R, handler: UpdateHandler<Err>) -> DispatcherBuilder<R, Err, DefaultKey>
317 where
318 Err: Debug,
319 {
320 const DEFAULT_WORKER_QUEUE_SIZE: usize = 64;
321
322 DispatcherBuilder {
323 bot,
324 dependencies: DependencyMap::new(),
325 handler: Arc::new(handler),
326 default_handler: Arc::new(|upd| {
327 log::warn!("Unhandled update: {upd:?}");
328 Box::pin(async {})
329 }),
330 error_handler: LoggingErrorHandler::new(),
331 ctrlc_handler: false,
332 worker_queue_size: DEFAULT_WORKER_QUEUE_SIZE,
333 distribution_f: default_distribution_function,
334 }
335 }
336}
337
338impl<R, Err, Key> Dispatcher<R, Err, Key>
339where
340 R: Requester + Clone + Send + Sync + 'static,
341 Err: Send + Sync + 'static,
342 Key: Hash + Eq + Clone + Send,
343{
344 pub async fn dispatch(&mut self)
359 where
360 R: Requester + Clone,
361 <R as Requester>::GetUpdates: Send,
362 {
363 let listener = update_listeners::polling_default(self.bot.clone()).await;
364 let error_handler =
365 LoggingErrorHandler::with_custom_text("An error from the update listener");
366
367 self.dispatch_with_listener(listener, error_handler).await;
368 }
369
370 pub async fn dispatch_with_listener<'a, UListener, Eh>(
375 &'a mut self,
376 update_listener: UListener,
377 update_listener_error_handler: Arc<Eh>,
378 ) where
379 UListener: UpdateListener + Send + 'a,
380 Eh: ErrorHandler<UListener::Err> + Send + Sync + 'a,
381 UListener::Err: Debug,
382 {
383 self.try_dispatch_with_listener(update_listener, update_listener_error_handler)
384 .await
385 .expect("Couldn't prepare dispatching context")
386 }
387
388 pub async fn try_dispatch_with_listener<'a, UListener, Eh>(
396 &'a mut self,
397 mut update_listener: UListener,
398 update_listener_error_handler: Arc<Eh>,
399 ) -> Result<(), R::Err>
400 where
401 UListener: UpdateListener + Send + 'a,
402 Eh: ErrorHandler<UListener::Err> + Send + Sync + 'a,
403 UListener::Err: Debug,
404 {
405 let me = self.bot.get_me().send().await?;
407 self.dependencies.insert(me);
408 self.dependencies.insert(self.bot.clone());
409
410 let description = self.handler.description();
411 let allowed_updates = description.allowed_updates();
412 log::debug!("hinting allowed updates: {allowed_updates:?}");
413 update_listener.hint_allowed_updates(&mut allowed_updates.into_iter());
414
415 let stop_token = Some(update_listener.stop_token());
416 self.start_listening(update_listener, update_listener_error_handler, stop_token).await;
417
418 Ok(())
419 }
420
421 async fn start_listening<'a, UListener, Eh>(
422 &'a mut self,
423 mut update_listener: UListener,
424 update_listener_error_handler: Arc<Eh>,
425 mut stop_token: Option<StopToken>,
426 ) where
427 UListener: UpdateListener + 'a,
428 Eh: ErrorHandler<UListener::Err> + 'a,
429 UListener::Err: Debug,
430 {
431 self.state.start_dispatching();
432
433 let stream = update_listener.as_stream();
434 tokio::pin!(stream);
435
436 loop {
437 self.remove_inactive_workers_if_needed().await;
438
439 let res = future::select(stream.next(), pin!(self.state.wait_for_changes()))
440 .map(either)
441 .await
442 .map_either(|l| l.0, |r| r.0);
443
444 match res {
445 Either::Left(upd) => match upd {
446 Some(upd) => self.process_update(upd, &update_listener_error_handler).await,
447 None => break,
448 },
449 Either::Right(()) => {
450 if self.state.is_shutting_down() {
451 if let Some(token) = stop_token.take() {
452 log::debug!("Start shutting down dispatching...");
453 token.stop();
454 }
455 }
456 }
457 }
458 }
459
460 self.workers
461 .drain()
462 .map(|(_chat_id, worker)| worker.handle)
463 .chain(self.default_worker.take().map(|worker| worker.handle))
464 .collect::<FuturesUnordered<_>>()
465 .for_each(|res| async {
466 res.expect("Failed to wait for a worker.");
467 })
468 .await;
469
470 self.state.done();
471 }
472
473 async fn process_update<LErr, LErrHandler>(
474 &mut self,
475 update: Result<Update, LErr>,
476 err_handler: &Arc<LErrHandler>,
477 ) where
478 LErrHandler: ErrorHandler<LErr>,
479 {
480 match update {
481 Ok(upd) => {
482 if let UpdateKind::Error(err) = upd.kind {
483 log::error!(
484 "Cannot parse an update.\nError: {err:?}\n\
485 This is a bug in teloxide-core, please open an issue here: \
486 https://github.com/teloxide/teloxide/issues.",
487 );
488 return;
489 }
490
491 let worker = match (self.distribution_f)(&upd) {
492 Some(key) => self.workers.entry(key).or_insert_with(|| {
493 let deps = self.dependencies.clone();
494 let handler = Arc::clone(&self.handler);
495 let default_handler = Arc::clone(&self.default_handler);
496 let error_handler = Arc::clone(&self.error_handler);
497
498 spawn_worker(
499 deps,
500 handler,
501 default_handler,
502 error_handler,
503 Arc::clone(&self.current_number_of_active_workers),
504 Arc::clone(&self.max_number_of_active_workers),
505 self.worker_queue_size,
506 )
507 }),
508 None => self.default_worker.get_or_insert_with(|| {
509 let deps = self.dependencies.clone();
510 let handler = Arc::clone(&self.handler);
511 let default_handler = Arc::clone(&self.default_handler);
512 let error_handler = Arc::clone(&self.error_handler);
513
514 spawn_default_worker(
515 deps,
516 handler,
517 default_handler,
518 error_handler,
519 self.worker_queue_size,
520 )
521 }),
522 };
523
524 worker.tx.send(upd).await.expect("TX is dead");
525 }
526 Err(err) => err_handler.clone().handle_error(err).await,
527 }
528 }
529
530 async fn remove_inactive_workers_if_needed(&mut self) {
531 let workers = self.workers.len();
532 let max = self.max_number_of_active_workers.load(Ordering::Relaxed) as usize;
533
534 if workers <= max {
535 return;
536 }
537
538 self.remove_inactive_workers().await;
539 }
540
541 #[inline(never)] async fn remove_inactive_workers(&mut self) {
543 let handles = self
544 .workers
545 .iter()
546 .filter(|(_, worker)| {
547 worker.tx.capacity() == self.worker_queue_size
548 && worker.is_waiting.load(Ordering::Relaxed)
549 })
550 .map(|(k, _)| k)
551 .cloned()
552 .collect::<Vec<_>>()
553 .into_iter()
554 .map(|key| {
555 let Worker { tx, handle, .. } = self.workers.remove(&key).unwrap();
556
557 drop(tx);
560
561 handle
562 });
563
564 for handle in handles {
565 let _ = handle.await;
569 }
570 }
571
572 pub fn shutdown_token(&self) -> ShutdownToken {
575 self.state.clone()
576 }
577}
578
579impl<R, Err, Key> Dispatcher<R, Err, Key> {
580 #[cfg(feature = "ctrlc_handler")]
581 fn setup_ctrlc_handler_inner(&mut self) {
582 let token = self.state.clone();
583 tokio::spawn(async move {
584 loop {
585 tokio::signal::ctrl_c().await.expect("Failed to listen for ^C");
586
587 match token.shutdown() {
588 Ok(f) => {
589 log::info!("^C received, trying to shutdown the dispatcher...");
590 f.await;
591 log::info!("dispatcher is shutdown...");
592 }
593 Err(_) => {
594 log::info!("^C received, the dispatcher isn't running, ignoring the signal")
595 }
596 }
597 }
598 });
599 }
600}
601
602fn spawn_worker<Err>(
603 deps: DependencyMap,
604 handler: Arc<UpdateHandler<Err>>,
605 default_handler: DefaultHandler,
606 error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
607 current_number_of_active_workers: Arc<AtomicU32>,
608 max_number_of_active_workers: Arc<AtomicU32>,
609 queue_size: usize,
610) -> Worker
611where
612 Err: Send + Sync + 'static,
613{
614 let (tx, mut rx) = tokio::sync::mpsc::channel(queue_size);
615 let is_waiting = Arc::new(AtomicBool::new(true));
616 let is_waiting_local = Arc::clone(&is_waiting);
617
618 let deps = Arc::new(deps);
619
620 let handle = tokio::spawn(async move {
621 while let Some(update) = rx.recv().await {
622 is_waiting_local.store(false, Ordering::Relaxed);
623 {
624 let current = current_number_of_active_workers.fetch_add(1, Ordering::Relaxed) + 1;
625 max_number_of_active_workers.fetch_max(current, Ordering::Relaxed);
626 }
627
628 let deps = Arc::clone(&deps);
629 let handler = Arc::clone(&handler);
630 let default_handler = Arc::clone(&default_handler);
631 let error_handler = Arc::clone(&error_handler);
632
633 handle_update(update, deps, handler, default_handler, error_handler).await;
634
635 current_number_of_active_workers.fetch_sub(1, Ordering::Relaxed);
636 is_waiting_local.store(true, Ordering::Relaxed);
637 }
638 });
639
640 Worker { tx, handle, is_waiting }
641}
642
643fn spawn_default_worker<Err>(
644 deps: DependencyMap,
645 handler: Arc<UpdateHandler<Err>>,
646 default_handler: DefaultHandler,
647 error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
648 queue_size: usize,
649) -> Worker
650where
651 Err: Send + Sync + 'static,
652{
653 let (tx, rx) = tokio::sync::mpsc::channel(queue_size);
654
655 let deps = Arc::new(deps);
656
657 let handle = tokio::spawn(ReceiverStream::new(rx).for_each_concurrent(None, move |update| {
658 let deps = Arc::clone(&deps);
659 let handler = Arc::clone(&handler);
660 let default_handler = Arc::clone(&default_handler);
661 let error_handler = Arc::clone(&error_handler);
662
663 handle_update(update, deps, handler, default_handler, error_handler)
664 }));
665
666 Worker { tx, handle, is_waiting: Arc::new(AtomicBool::new(true)) }
667}
668
669async fn handle_update<Err>(
670 update: Update,
671 deps: Arc<DependencyMap>,
672 handler: Arc<UpdateHandler<Err>>,
673 default_handler: DefaultHandler,
674 error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
675) where
676 Err: Send + Sync + 'static,
677{
678 let mut deps = deps.deref().clone();
679 deps.insert(update);
680
681 match handler.dispatch(deps).await {
682 ControlFlow::Break(Ok(())) => {}
683 ControlFlow::Break(Err(err)) => error_handler.clone().handle_error(err).await,
684 ControlFlow::Continue(deps) => {
685 let update = deps.get();
686 (default_handler)(update).await;
687 }
688 }
689}
690
691fn either<L, R>(x: future::Either<L, R>) -> Either<L, R> {
692 match x {
693 future::Either::Left(l) => Either::Left(l),
694 future::Either::Right(r) => Either::Right(r),
695 }
696}
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use teloxide_core_ng::Bot;

    use super::*;

    /// `tokio::spawn` requires a `Send` future, so this test only compiles if
    /// dispatching can run inside a spawned task. The body is never executed.
    #[tokio::test]
    async fn test_tokio_spawn() {
        let task = tokio::spawn(async {
            if false {
                let mut dispatcher =
                    Dispatcher::<_, Infallible, _>::builder(Bot::new(""), dptree::entry()).build();
                dispatcher.dispatch().await;
            }
        });

        task.await.unwrap();
    }
}