// teloxide_ng/dispatching/dispatcher.rs

1use crate::{
2    dispatching::{
3        DefaultKey, DpHandlerDescription, ShutdownToken,
4        distribution::default_distribution_function,
5    },
6    error_handlers::{ErrorHandler, LoggingErrorHandler},
7    requests::{Request, Requester},
8    stop::StopToken,
9    types::{Update, UpdateKind},
10    update_listeners::{self, UpdateListener},
11};
12
13use dptree::di::DependencyMap;
14use either::Either;
15use futures::{
16    FutureExt as _, StreamExt as _,
17    future::{self, BoxFuture},
18    stream::FuturesUnordered,
19};
20use tokio_stream::wrappers::ReceiverStream;
21
22use std::{
23    collections::HashMap,
24    fmt::Debug,
25    future::Future,
26    hash::Hash,
27    ops::{ControlFlow, Deref},
28    pin::pin,
29    sync::{
30        Arc,
31        atomic::{AtomicBool, AtomicU32, Ordering},
32    },
33};
34
35/// The builder for [`Dispatcher`].
36///
37/// See also: ["Dispatching or
38/// REPLs?"](../dispatching/index.html#dispatching-or-repls)
39pub struct DispatcherBuilder<R, Err, Key> {
40    bot: R,
41    dependencies: DependencyMap,
42    handler: Arc<UpdateHandler<Err>>,
43    default_handler: DefaultHandler,
44    error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
45    ctrlc_handler: bool,
46    distribution_f: fn(&Update) -> Option<Key>,
47    worker_queue_size: usize,
48}
49
50impl<R, Err, Key> DispatcherBuilder<R, Err, Key>
51where
52    R: Clone + Requester + Clone + Send + Sync + 'static,
53    Err: Debug + Send + Sync + 'static,
54{
55    /// Specifies a handler that will be called for an unhandled update.
56    ///
57    /// By default, it is a mere [`log::warn`].
58    #[must_use]
59    pub fn default_handler<H, Fut>(self, handler: H) -> Self
60    where
61        H: Fn(Arc<Update>) -> Fut + Send + Sync + 'static,
62        Fut: Future<Output = ()> + Send + 'static,
63    {
64        let handler = Arc::new(handler);
65
66        Self {
67            default_handler: Arc::new(move |upd| {
68                let handler = Arc::clone(&handler);
69                Box::pin(handler(upd))
70            }),
71            ..self
72        }
73    }
74
75    /// Specifies a handler that will be called on a handler error.
76    ///
77    /// By default, it is [`LoggingErrorHandler`].
78    #[must_use]
79    pub fn error_handler(self, handler: Arc<dyn ErrorHandler<Err> + Send + Sync>) -> Self {
80        Self { error_handler: handler, ..self }
81    }
82
83    /// Specifies dependencies that can be used inside of handlers.
84    ///
85    /// By default, there is no dependencies.
86    #[must_use]
87    pub fn dependencies(self, dependencies: DependencyMap) -> Self {
88        Self { dependencies, ..self }
89    }
90
91    /// Enables the `^C` handler that [`shutdown`]s dispatching.
92    ///
93    /// [`shutdown`]: ShutdownToken::shutdown
94    #[cfg(feature = "ctrlc_handler")]
95    #[must_use]
96    pub fn enable_ctrlc_handler(self) -> Self {
97        Self { ctrlc_handler: true, ..self }
98    }
99
100    /// Specifies size of the queue for workers.
101    ///
102    /// By default it's 64.
103    #[must_use]
104    pub fn worker_queue_size(self, size: usize) -> Self {
105        Self { worker_queue_size: size, ..self }
106    }
107
108    /// Specifies the stack size available to the dispatcher.
109    ///
110    /// By default, it's 8 * 1024 * 1024 bytes (8 MiB).
111    #[must_use]
112    #[deprecated(since = "0.15.0", note = "This method is a no-op; you can just remove it.")]
113    pub fn stack_size(self, _size: usize) -> Self {
114        self
115    }
116
117    /// Specifies the distribution function that decides how updates are grouped
118    /// before execution.
119    ///
120    /// ## Update grouping
121    ///
122    /// When [`Dispatcher`] receives updates, it runs dispatching tree
123    /// (handlers) concurrently. This means that multiple updates can be
124    /// processed at the same time.
125    ///
126    /// However, this is not always convenient. For example, if you have global
127    /// state, then you may want to process some updates sequentially, to
128    /// prevent state inconsistencies.
129    ///
130    /// This is why `teloxide` allows grouping updates. Updates for which the
131    /// distribution function `f` returns the same "distribution key" `K` will
132    /// be run in sequence (while still being processed concurrently with the
133    /// updates with different distribution keys).
134    ///
135    /// Updates for which `f` returns `None` will always be processed in
136    /// parallel.
137    ///
138    /// ## Default distribution function
139    ///
140    /// By default the distribution function is equivalent to `|upd|
141    /// upd.chat().map(|chat| chat.id)`, so updates from the same chat will be
142    /// processed sequentially.
143    ///
144    /// This pair nicely with dialogue system, which has state attached to
145    /// chats.
146    ///
147    /// ## Examples
148    ///
149    /// Grouping updates by user who caused this update to happen:
150    ///
151    /// ```
152    /// use teloxide_ng::{Bot, dispatching::Dispatcher, dptree};
153    ///
154    /// let bot = Bot::new("TOKEN");
155    /// let handler = dptree::entry() /* ... */;
156    /// let dp = Dispatcher::builder(bot, handler)
157    ///     .distribution_function(|upd| upd.from().map(|user| user.id))
158    ///     .build();
159    /// # let _: Dispatcher<_, (), _> = dp;
160    /// ```
161    ///
162    /// Not grouping updates at all, always processing updates concurrently:
163    ///
164    /// ```
165    /// use teloxide_ng::{Bot, dispatching::Dispatcher, dptree};
166    ///
167    /// let bot = Bot::new("TOKEN");
168    /// let handler = dptree::entry() /* ... */;
169    /// let dp = Dispatcher::builder(bot, handler).distribution_function(|_| None::<()>).build();
170    /// # let _: Dispatcher<_, (), _> = dp;
171    /// ```
172    #[must_use]
173    pub fn distribution_function<K>(
174        self,
175        f: fn(&Update) -> Option<K>,
176    ) -> DispatcherBuilder<R, Err, K>
177    where
178        K: Hash + Eq,
179    {
180        let Self {
181            bot,
182            dependencies,
183            handler,
184            default_handler,
185            error_handler,
186            ctrlc_handler,
187            distribution_f: _,
188            worker_queue_size,
189        } = self;
190
191        DispatcherBuilder {
192            bot,
193            dependencies,
194            handler,
195            default_handler,
196            error_handler,
197            ctrlc_handler,
198            distribution_f: f,
199            worker_queue_size,
200        }
201    }
202
203    /// Constructs [`Dispatcher`].
204    ///
205    /// ## Panics
206    /// This function will panic at run-time if [`dptree`] fails to type-check
207    /// the provided handler. An appropriate error message will be emitted.
208    #[must_use]
209    pub fn build(self) -> Dispatcher<R, Err, Key> {
210        let Self {
211            bot,
212            dependencies,
213            handler,
214            default_handler,
215            error_handler,
216            distribution_f,
217            worker_queue_size,
218            ctrlc_handler,
219        } = self;
220
221        dptree::type_check(
222            handler.sig(),
223            &dependencies,
224            &[
225                dptree::Type::of::<R>(),
226                dptree::Type::of::<teloxide_core_ng::types::Update>(),
227                dptree::Type::of::<teloxide_core_ng::types::Me>(),
228            ],
229        );
230
231        // If the `ctrlc_handler` feature is not enabled, don't emit a warning.
232        let _ = ctrlc_handler;
233
234        let dp = Dispatcher {
235            bot,
236            dependencies,
237            handler,
238            default_handler,
239            error_handler,
240            state: ShutdownToken::new(),
241            distribution_f,
242            worker_queue_size,
243            workers: HashMap::new(),
244            default_worker: None,
245            current_number_of_active_workers: Default::default(),
246            max_number_of_active_workers: Default::default(),
247        };
248
249        #[cfg(feature = "ctrlc_handler")]
250        {
251            if ctrlc_handler {
252                let mut dp = dp;
253                dp.setup_ctrlc_handler_inner();
254                return dp;
255            }
256        }
257
258        dp
259    }
260}
261
262/// The base for update dispatching.
263///
264/// ## Update grouping
265///
266/// `Dispatcher` generally processes updates concurrently. However, by default,
267/// updates from the same chat are processed sequentially. Learn more about
268/// [update grouping].
269///
270/// See also: ["Dispatching or
271/// REPLs?"](../dispatching/index.html#dispatching-or-repls)
272///
273/// [update grouping]: DispatcherBuilder#update-grouping
274pub struct Dispatcher<R, Err, Key> {
275    bot: R,
276    dependencies: DependencyMap,
277
278    handler: Arc<UpdateHandler<Err>>,
279    default_handler: DefaultHandler,
280
281    distribution_f: fn(&Update) -> Option<Key>,
282    worker_queue_size: usize,
283    current_number_of_active_workers: Arc<AtomicU32>,
284    max_number_of_active_workers: Arc<AtomicU32>,
285    // Tokio TX channel parts associated with chat IDs that consume updates sequentially.
286    workers: HashMap<Key, Worker>,
287    // The default TX part that consume updates concurrently.
288    default_worker: Option<Worker>,
289
290    error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
291
292    state: ShutdownToken,
293}
294
295struct Worker {
296    tx: tokio::sync::mpsc::Sender<Update>,
297    handle: tokio::task::JoinHandle<()>,
298    is_waiting: Arc<AtomicBool>,
299}
300
// TODO: Telegram allows answering a webhook request by returning a Bot API
// method call in the response body, so we could support this too. See:
// https://core.telegram.org/bots/api#making-requests-when-getting-updates
303
304/// A handler that processes updates from Telegram.
305pub type UpdateHandler<Err> = dptree::Handler<'static, Result<(), Err>, DpHandlerDescription>;
306
307type DefaultHandler = Arc<dyn Fn(Arc<Update>) -> BoxFuture<'static, ()> + Send + Sync>;
308
309impl<R, Err> Dispatcher<R, Err, DefaultKey>
310where
311    R: Requester + Clone + Send + Sync + 'static,
312    Err: Send + Sync + 'static,
313{
314    /// Constructs a new [`DispatcherBuilder`] with `bot` and `handler`.
315    #[must_use]
316    pub fn builder(bot: R, handler: UpdateHandler<Err>) -> DispatcherBuilder<R, Err, DefaultKey>
317    where
318        Err: Debug,
319    {
320        const DEFAULT_WORKER_QUEUE_SIZE: usize = 64;
321
322        DispatcherBuilder {
323            bot,
324            dependencies: DependencyMap::new(),
325            handler: Arc::new(handler),
326            default_handler: Arc::new(|upd| {
327                log::warn!("Unhandled update: {upd:?}");
328                Box::pin(async {})
329            }),
330            error_handler: LoggingErrorHandler::new(),
331            ctrlc_handler: false,
332            worker_queue_size: DEFAULT_WORKER_QUEUE_SIZE,
333            distribution_f: default_distribution_function,
334        }
335    }
336}
337
338impl<R, Err, Key> Dispatcher<R, Err, Key>
339where
340    R: Requester + Clone + Send + Sync + 'static,
341    Err: Send + Sync + 'static,
342    Key: Hash + Eq + Clone + Send,
343{
344    /// Starts your bot with the default parameters.
345    ///
346    /// The default parameters are a long polling update listener and log all
347    /// errors produced by this listener.
348    ///
349    /// Each time a handler is invoked, [`Dispatcher`] adds the following
350    /// dependencies (in addition to those passed to
351    /// [`DispatcherBuilder::dependencies`]):
352    ///
353    ///  - Your bot passed to [`Dispatcher::builder`];
354    ///  - An update from Telegram;
355    ///  - [`crate::types::Me`] (can be used in [`HandlerExt::filter_command`]).
356    ///
357    /// [`HandlerExt::filter_command`]: crate::dispatching::HandlerExt::filter_command
358    pub async fn dispatch(&mut self)
359    where
360        R: Requester + Clone,
361        <R as Requester>::GetUpdates: Send,
362    {
363        let listener = update_listeners::polling_default(self.bot.clone()).await;
364        let error_handler =
365            LoggingErrorHandler::with_custom_text("An error from the update listener");
366
367        self.dispatch_with_listener(listener, error_handler).await;
368    }
369
370    /// Starts your bot with custom `update_listener` and
371    /// `update_listener_error_handler`.
372    ///
373    /// This method adds the same dependencies as [`Dispatcher::dispatch`].
374    pub async fn dispatch_with_listener<'a, UListener, Eh>(
375        &'a mut self,
376        update_listener: UListener,
377        update_listener_error_handler: Arc<Eh>,
378    ) where
379        UListener: UpdateListener + Send + 'a,
380        Eh: ErrorHandler<UListener::Err> + Send + Sync + 'a,
381        UListener::Err: Debug,
382    {
383        self.try_dispatch_with_listener(update_listener, update_listener_error_handler)
384            .await
385            .expect("Couldn't prepare dispatching context")
386    }
387
388    /// Same as `dispatch_with_listener` but returns a `Err(_)` instead of
389    /// panicking when the initial telegram api call (`get_me`) fails.
390    ///
391    /// Starts your bot with custom `update_listener` and
392    /// `update_listener_error_handler`.
393    ///
394    /// This method adds the same dependencies as [`Dispatcher::dispatch`].
395    pub async fn try_dispatch_with_listener<'a, UListener, Eh>(
396        &'a mut self,
397        mut update_listener: UListener,
398        update_listener_error_handler: Arc<Eh>,
399    ) -> Result<(), R::Err>
400    where
401        UListener: UpdateListener + Send + 'a,
402        Eh: ErrorHandler<UListener::Err> + Send + Sync + 'a,
403        UListener::Err: Debug,
404    {
405        // FIXME: there should be a way to check if dependency is already inserted
406        let me = self.bot.get_me().send().await?;
407        self.dependencies.insert(me);
408        self.dependencies.insert(self.bot.clone());
409
410        let description = self.handler.description();
411        let allowed_updates = description.allowed_updates();
412        log::debug!("hinting allowed updates: {allowed_updates:?}");
413        update_listener.hint_allowed_updates(&mut allowed_updates.into_iter());
414
415        let stop_token = Some(update_listener.stop_token());
416        self.start_listening(update_listener, update_listener_error_handler, stop_token).await;
417
418        Ok(())
419    }
420
421    async fn start_listening<'a, UListener, Eh>(
422        &'a mut self,
423        mut update_listener: UListener,
424        update_listener_error_handler: Arc<Eh>,
425        mut stop_token: Option<StopToken>,
426    ) where
427        UListener: UpdateListener + 'a,
428        Eh: ErrorHandler<UListener::Err> + 'a,
429        UListener::Err: Debug,
430    {
431        self.state.start_dispatching();
432
433        let stream = update_listener.as_stream();
434        tokio::pin!(stream);
435
436        loop {
437            self.remove_inactive_workers_if_needed().await;
438
439            let res = future::select(stream.next(), pin!(self.state.wait_for_changes()))
440                .map(either)
441                .await
442                .map_either(|l| l.0, |r| r.0);
443
444            match res {
445                Either::Left(upd) => match upd {
446                    Some(upd) => self.process_update(upd, &update_listener_error_handler).await,
447                    None => break,
448                },
449                Either::Right(()) => {
450                    if self.state.is_shutting_down() {
451                        if let Some(token) = stop_token.take() {
452                            log::debug!("Start shutting down dispatching...");
453                            token.stop();
454                        }
455                    }
456                }
457            }
458        }
459
460        self.workers
461            .drain()
462            .map(|(_chat_id, worker)| worker.handle)
463            .chain(self.default_worker.take().map(|worker| worker.handle))
464            .collect::<FuturesUnordered<_>>()
465            .for_each(|res| async {
466                res.expect("Failed to wait for a worker.");
467            })
468            .await;
469
470        self.state.done();
471    }
472
473    async fn process_update<LErr, LErrHandler>(
474        &mut self,
475        update: Result<Update, LErr>,
476        err_handler: &Arc<LErrHandler>,
477    ) where
478        LErrHandler: ErrorHandler<LErr>,
479    {
480        match update {
481            Ok(upd) => {
482                if let UpdateKind::Error(err) = upd.kind {
483                    log::error!(
484                        "Cannot parse an update.\nError: {err:?}\n\
485                            This is a bug in teloxide-core, please open an issue here: \
486                            https://github.com/teloxide/teloxide/issues.",
487                    );
488                    return;
489                }
490
491                let worker = match (self.distribution_f)(&upd) {
492                    Some(key) => self.workers.entry(key).or_insert_with(|| {
493                        let deps = self.dependencies.clone();
494                        let handler = Arc::clone(&self.handler);
495                        let default_handler = Arc::clone(&self.default_handler);
496                        let error_handler = Arc::clone(&self.error_handler);
497
498                        spawn_worker(
499                            deps,
500                            handler,
501                            default_handler,
502                            error_handler,
503                            Arc::clone(&self.current_number_of_active_workers),
504                            Arc::clone(&self.max_number_of_active_workers),
505                            self.worker_queue_size,
506                        )
507                    }),
508                    None => self.default_worker.get_or_insert_with(|| {
509                        let deps = self.dependencies.clone();
510                        let handler = Arc::clone(&self.handler);
511                        let default_handler = Arc::clone(&self.default_handler);
512                        let error_handler = Arc::clone(&self.error_handler);
513
514                        spawn_default_worker(
515                            deps,
516                            handler,
517                            default_handler,
518                            error_handler,
519                            self.worker_queue_size,
520                        )
521                    }),
522                };
523
524                worker.tx.send(upd).await.expect("TX is dead");
525            }
526            Err(err) => err_handler.clone().handle_error(err).await,
527        }
528    }
529
530    async fn remove_inactive_workers_if_needed(&mut self) {
531        let workers = self.workers.len();
532        let max = self.max_number_of_active_workers.load(Ordering::Relaxed) as usize;
533
534        if workers <= max {
535            return;
536        }
537
538        self.remove_inactive_workers().await;
539    }
540
541    #[inline(never)] // Cold function.
542    async fn remove_inactive_workers(&mut self) {
543        let handles = self
544            .workers
545            .iter()
546            .filter(|(_, worker)| {
547                worker.tx.capacity() == self.worker_queue_size
548                    && worker.is_waiting.load(Ordering::Relaxed)
549            })
550            .map(|(k, _)| k)
551            .cloned()
552            .collect::<Vec<_>>()
553            .into_iter()
554            .map(|key| {
555                let Worker { tx, handle, .. } = self.workers.remove(&key).unwrap();
556
557                // Close channel, worker should stop almost immediately
558                // (it's been supposedly waiting on the channel)
559                drop(tx);
560
561                handle
562            });
563
564        for handle in handles {
565            // We must wait for worker to stop anyway, even though it should stop
566            // immediately. This helps in case if we've checked that the worker
567            // is waiting in between it received the update and set the flag.
568            let _ = handle.await;
569        }
570    }
571
572    /// Returns a shutdown token, which can later be used to
573    /// [`ShutdownToken::shutdown`].
574    pub fn shutdown_token(&self) -> ShutdownToken {
575        self.state.clone()
576    }
577}
578
579impl<R, Err, Key> Dispatcher<R, Err, Key> {
580    #[cfg(feature = "ctrlc_handler")]
581    fn setup_ctrlc_handler_inner(&mut self) {
582        let token = self.state.clone();
583        tokio::spawn(async move {
584            loop {
585                tokio::signal::ctrl_c().await.expect("Failed to listen for ^C");
586
587                match token.shutdown() {
588                    Ok(f) => {
589                        log::info!("^C received, trying to shutdown the dispatcher...");
590                        f.await;
591                        log::info!("dispatcher is shutdown...");
592                    }
593                    Err(_) => {
594                        log::info!("^C received, the dispatcher isn't running, ignoring the signal")
595                    }
596                }
597            }
598        });
599    }
600}
601
602fn spawn_worker<Err>(
603    deps: DependencyMap,
604    handler: Arc<UpdateHandler<Err>>,
605    default_handler: DefaultHandler,
606    error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
607    current_number_of_active_workers: Arc<AtomicU32>,
608    max_number_of_active_workers: Arc<AtomicU32>,
609    queue_size: usize,
610) -> Worker
611where
612    Err: Send + Sync + 'static,
613{
614    let (tx, mut rx) = tokio::sync::mpsc::channel(queue_size);
615    let is_waiting = Arc::new(AtomicBool::new(true));
616    let is_waiting_local = Arc::clone(&is_waiting);
617
618    let deps = Arc::new(deps);
619
620    let handle = tokio::spawn(async move {
621        while let Some(update) = rx.recv().await {
622            is_waiting_local.store(false, Ordering::Relaxed);
623            {
624                let current = current_number_of_active_workers.fetch_add(1, Ordering::Relaxed) + 1;
625                max_number_of_active_workers.fetch_max(current, Ordering::Relaxed);
626            }
627
628            let deps = Arc::clone(&deps);
629            let handler = Arc::clone(&handler);
630            let default_handler = Arc::clone(&default_handler);
631            let error_handler = Arc::clone(&error_handler);
632
633            handle_update(update, deps, handler, default_handler, error_handler).await;
634
635            current_number_of_active_workers.fetch_sub(1, Ordering::Relaxed);
636            is_waiting_local.store(true, Ordering::Relaxed);
637        }
638    });
639
640    Worker { tx, handle, is_waiting }
641}
642
643fn spawn_default_worker<Err>(
644    deps: DependencyMap,
645    handler: Arc<UpdateHandler<Err>>,
646    default_handler: DefaultHandler,
647    error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
648    queue_size: usize,
649) -> Worker
650where
651    Err: Send + Sync + 'static,
652{
653    let (tx, rx) = tokio::sync::mpsc::channel(queue_size);
654
655    let deps = Arc::new(deps);
656
657    let handle = tokio::spawn(ReceiverStream::new(rx).for_each_concurrent(None, move |update| {
658        let deps = Arc::clone(&deps);
659        let handler = Arc::clone(&handler);
660        let default_handler = Arc::clone(&default_handler);
661        let error_handler = Arc::clone(&error_handler);
662
663        handle_update(update, deps, handler, default_handler, error_handler)
664    }));
665
666    Worker { tx, handle, is_waiting: Arc::new(AtomicBool::new(true)) }
667}
668
669async fn handle_update<Err>(
670    update: Update,
671    deps: Arc<DependencyMap>,
672    handler: Arc<UpdateHandler<Err>>,
673    default_handler: DefaultHandler,
674    error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
675) where
676    Err: Send + Sync + 'static,
677{
678    let mut deps = deps.deref().clone();
679    deps.insert(update);
680
681    match handler.dispatch(deps).await {
682        ControlFlow::Break(Ok(())) => {}
683        ControlFlow::Break(Err(err)) => error_handler.clone().handle_error(err).await,
684        ControlFlow::Continue(deps) => {
685            let update = deps.get();
686            (default_handler)(update).await;
687        }
688    }
689}
690
691fn either<L, R>(x: future::Either<L, R>) -> Either<L, R> {
692    match x {
693        future::Either::Left(l) => Either::Left(l),
694        future::Either::Right(r) => Either::Right(r),
695    }
696}
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use teloxide_core_ng::Bot;

    use super::*;

    /// Checks at compile time that the future returned by `Dispatcher::dispatch`
    /// can be passed to `tokio::spawn`.
    #[tokio::test]
    async fn test_tokio_spawn() {
        let join_handle = tokio::spawn(async {
            // Never executed; it only needs to type-check.
            if false {
                let mut dispatcher =
                    Dispatcher::<_, Infallible, _>::builder(Bot::new(""), dptree::entry()).build();
                dispatcher.dispatch().await;
            }
        });

        join_handle.await.unwrap();
    }
}