teloxide/dispatching/handler_ext.rs

1use crate::{
2    dispatching::{
3        dialogue::{GetChatId, Storage},
4        DpHandlerDescription,
5    },
6    types::{Me, Message},
7    utils::command::BotCommands,
8};
9use dptree::Handler;
10
11use std::fmt::Debug;
12
13/// Extension methods for working with `dptree` handlers.
14pub trait HandlerExt<Output> {
15    /// Returns a handler that accepts a parsed command `C`.
16    ///
17    /// ## Dependency requirements
18    ///
19    ///  - [`crate::types::Message`]
20    ///  - [`crate::types::Me`]
21    #[must_use]
22    fn filter_command<C>(self) -> Self
23    where
24        C: BotCommands + Send + Sync + 'static;
25
26    /// Returns a handler that accepts a parsed command `C` if the command
27    /// contains a bot mention, for example `/start@my_bot`.
28    ///
29    /// ## Dependency requirements
30    ///
31    ///  - [`crate::types::Message`]
32    ///  - [`crate::types::Me`]
33    #[must_use]
34    fn filter_mention_command<C>(self) -> Self
35    where
36        C: BotCommands + Send + Sync + 'static;
37
38    /// Passes [`Dialogue<D, S>`] and `D` as handler dependencies.
39    ///
40    /// It does so by the following steps:
41    ///
42    ///  1. If an incoming update has no chat ID ([`GetChatId::chat_id`] returns
43    ///     `None`), the rest of the chain will not be executed. Otherwise,
44    ///     passes `Dialogue::new(storage, chat_id)` forwards.
45    ///  2. If [`Dialogue::get_or_default`] on the passed dialogue returns `Ok`,
46    ///     passes the dialogue state forwards. Otherwise, logs an error and the
47    ///     rest of the chain is not executed.
48    ///
49    /// If `TELOXIDE_DIALOGUE_BEHAVIOUR` environmental variable exists and is
50    /// equal to "default", this function will not panic if it can't get the
51    /// dialogue (if, for example, the state enum was updated). Setting the
52    /// value to "panic" will return the initial behaviour.
53    ///
54    /// ## Dependency requirements
55    ///
56    ///  - `Arc<S>`
57    ///  - `Upd`
58    ///
59    /// [`Dialogue<D, S>`]: super::dialogue::Dialogue
60    /// [`Dialogue::get_or_default`]: super::dialogue::Dialogue::get_or_default
61    #[must_use]
62    fn enter_dialogue<Upd, S, D>(self) -> Self
63    where
64        S: Storage<D> + ?Sized + Send + Sync + 'static,
65        <S as Storage<D>>::Error: Debug + Send,
66        D: Default + Clone + Send + Sync + 'static,
67        Upd: GetChatId + Clone + Send + Sync + 'static;
68}
69
70impl<Output> HandlerExt<Output> for Handler<'static, Output, DpHandlerDescription>
71where
72    Output: Send + Sync + 'static,
73{
74    fn filter_command<C>(self) -> Self
75    where
76        C: BotCommands + Send + Sync + 'static,
77    {
78        self.chain(filter_command::<C, Output>())
79    }
80
81    fn filter_mention_command<C>(self) -> Self
82    where
83        C: BotCommands + Send + Sync + 'static,
84    {
85        self.chain(filter_mention_command::<C, Output>())
86    }
87
88    fn enter_dialogue<Upd, S, D>(self) -> Self
89    where
90        S: Storage<D> + ?Sized + Send + Sync + 'static,
91        <S as Storage<D>>::Error: Debug + Send,
92        D: Default + Clone + Send + Sync + 'static,
93        Upd: GetChatId + Clone + Send + Sync + 'static,
94    {
95        self.chain(super::dialogue::enter::<Upd, S, D, Output>())
96    }
97}
98
99/// Returns a handler that accepts a parsed command `C`.
100///
101/// A call to this function is the same as `dptree::entry().filter_command()`.
102///
103/// See [`HandlerExt::filter_command`].
104///
105/// ## Dependency requirements
106///
107///  - [`crate::types::Message`]
108///  - [`crate::types::Me`]
109#[must_use]
110pub fn filter_command<C, Output>() -> Handler<'static, Output, DpHandlerDescription>
111where
112    C: BotCommands + Send + Sync + 'static,
113    Output: Send + Sync + 'static,
114{
115    dptree::filter_map(move |message: Message, me: Me| {
116        let bot_name = me.user.username.expect("Bots must have a username");
117        message.text().or_else(|| message.caption()).and_then(|text| C::parse(text, &bot_name).ok())
118    })
119}
120
121/// Returns a handler that accepts a parsed command `C` if the command
122/// contains a bot mention, for example `/start@my_bot`.
123///
124/// A call to this function is the same as
125/// `dptree::entry().filter_mention_command()`.
126///
127/// See [`HandlerExt::filter_mention_command`].
128///
129/// ## Dependency requirements
130///
131///  - [`crate::types::Message`]
132///  - [`crate::types::Me`]
133#[must_use]
134pub fn filter_mention_command<C, Output>() -> Handler<'static, Output, DpHandlerDescription>
135where
136    C: BotCommands + Send + Sync + 'static,
137    Output: Send + Sync + 'static,
138{
139    dptree::filter_map(move |message: Message, me: Me| {
140        let bot_name = me.user.username.expect("Bots must have a username");
141
142        let text_or_caption = message.text().or_else(|| message.caption());
143        let command = text_or_caption.and_then(|text| C::parse(text, &bot_name).ok());
144        // If the parsing succeeds with a bot_name,
145        // but fails without - there is a mention
146        let is_username_required =
147            text_or_caption.and_then(|text| C::parse(text, "").ok()).is_none();
148
149        if !is_username_required {
150            return None;
151        }
152        command
153    })
154}
155
#[cfg(test)]
#[cfg(feature = "macros")]
mod tests {
    use crate::{self as teloxide, dispatching::UpdateFilterExt, utils::command::BotCommands};
    use chrono::DateTime;
    use dptree::deps;
    use teloxide_core::types::{
        Chat, ChatId, ChatKind, ChatPrivate, LinkPreviewOptions, Me, MediaKind, MediaText, Message,
        MessageCommon, MessageId, MessageKind, Update, UpdateId, UpdateKind, User, UserId,
    };

    use super::HandlerExt;

    #[derive(BotCommands, Clone)]
    #[command(rename_rule = "lowercase")]
    enum Cmd {
        Test,
    }

    /// Builds a private-chat text-message update carrying `text`.
    fn make_update(text: String) -> Update {
        let timestamp = 1_569_518_829;
        let date = DateTime::from_timestamp(timestamp, 0).unwrap();
        Update {
            id: UpdateId(326_170_274),
            kind: UpdateKind::Message(Message {
                via_bot: None,
                id: MessageId(5042),
                thread_id: None,
                from: Some(User {
                    id: UserId(109_998_024),
                    is_bot: false,
                    first_name: String::from("Laster"),
                    last_name: None,
                    username: Some(String::from("laster_alex")),
                    language_code: Some(String::from("en")),
                    is_premium: false,
                    added_to_attachment_menu: false,
                }),
                sender_chat: None,
                is_topic_message: false,
                sender_business_bot: None,
                date,
                chat: Chat {
                    id: ChatId(109_998_024),
                    kind: ChatKind::Private(ChatPrivate {
                        username: Some(String::from("Laster")),
                        first_name: Some(String::from("laster_alex")),
                        last_name: None,
                    }),
                },
                kind: MessageKind::Common(MessageCommon {
                    reply_to_message: None,
                    forward_origin: None,
                    external_reply: None,
                    quote: None,
                    edit_date: None,
                    media_kind: MediaKind::Text(MediaText {
                        text,
                        entities: vec![],
                        link_preview_options: Some(LinkPreviewOptions {
                            is_disabled: true,
                            url: None,
                            prefer_small_media: false,
                            prefer_large_media: false,
                            show_above_text: false,
                        }),
                    }),
                    reply_markup: None,
                    author_signature: None,
                    paid_star_count: None,
                    effect_id: None,
                    is_automatic_forward: false,
                    has_protected_content: false,
                    reply_to_story: None,
                    sender_boost_count: None,
                    is_from_offline: false,
                    business_connection_id: None,
                }),
            }),
        }
    }

    /// Builds the `Me` of a bot whose username is `SomethingSomethingBot`.
    fn make_me() -> Me {
        Me {
            user: User {
                id: UserId(42),
                is_bot: true,
                first_name: "First".to_owned(),
                last_name: None,
                username: Some("SomethingSomethingBot".to_owned()),
                language_code: None,
                is_premium: false,
                added_to_attachment_menu: false,
            },
            can_join_groups: false,
            can_read_all_group_messages: false,
            supports_inline_queries: false,
            can_connect_to_business: false,
            has_main_web_app: false,
        }
    }

    #[tokio::test]
    async fn test_filter_command() {
        let h = dptree::entry()
            .branch(Update::filter_message().filter_command::<Cmd>().endpoint(|| async {}));
        let me = make_me();

        // (message text, whether the endpoint should be reached)
        let cases = [
            (format!("/test@{}", me.username()), true),
            ("/test@SomeOtherBot".to_owned(), false),
            ("/test".to_owned(), true),
            ("/unknown".to_owned(), false),
            ("test".to_owned(), false),
        ];

        for (text, should_match) in cases {
            let update = make_update(text.clone());
            let result = h.dispatch(deps![update, me.clone()]).await;
            assert_eq!(result.is_break(), should_match, "unexpected result for {text:?}");
        }
    }

    #[tokio::test]
    async fn test_filter_mention_command() {
        let h = dptree::entry()
            .branch(Update::filter_message().filter_mention_command::<Cmd>().endpoint(|| async {}));
        let me = make_me();

        // (message text, whether the endpoint should be reached)
        let cases = [
            (format!("/test@{}", me.username()), true),
            ("/test@SomeOtherBot".to_owned(), false),
            ("/test".to_owned(), false),
            (format!("/unknown@{}", me.username()), false),
        ];

        for (text, should_match) in cases {
            let update = make_update(text.clone());
            let result = h.dispatch(deps![update, me.clone()]).await;
            assert_eq!(result.is_break(), should_match, "unexpected result for {text:?}");
        }
    }
}