Skip to main content

rust_tg_bot_ext/handlers/
command.rs

1//! [`CommandHandler`] -- handles Telegram bot commands (`/command`).
2//!
3//! Ported from `python-telegram-bot`'s `CommandHandler`. Matches messages
4//! whose first entity is `BOT_COMMAND` at offset 0, validates the command
5//! against a set of known commands, and extracts arguments.
6//!
7//! ## Key behaviours matching the Python implementation
8//!
9//! - **C1 -- `@botname` validation**: When the command contains `@username`,
10//!   the suffix is validated against `bot_username` (case-insensitive). If
11//!   `bot_username` is `None` the `@` part is silently stripped (backwards
12//!   compatible).
13//! - **C2 -- Filter integration**: An optional `filter_fn` runs *before*
14//!   command matching. The default filter accepts message-like updates via
15//!   [`Update::message`](rust_tg_bot_raw::types::update::Update::message).
16
17use std::collections::HashSet;
18use std::future::Future;
19use std::pin::Pin;
20use std::sync::Arc;
21
22use regex::Regex;
23use rust_tg_bot_raw::constants::MessageEntityType;
24use rust_tg_bot_raw::types::update::Update;
25
26use super::base::{ContextCallback, Handler, HandlerCallback, HandlerResult, MatchResult};
27use crate::context::CallbackContext;
28
/// Specifies how many arguments a command must have to match.
///
/// Arguments are the whitespace-separated words that follow the first token
/// of the message text. The default, [`HasArgs::Any`], places no constraint
/// on the count.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
#[non_exhaustive]
pub enum HasArgs {
    /// Accept any number of arguments (including zero).
    #[default]
    Any,
    /// Require at least one argument.
    NonEmpty,
    /// Require zero arguments.
    None,
    /// Require exactly this many arguments.
    Exact(usize),
}
42
43/// Type alias for the optional update filter closure.
44///
45/// Returns `true` if the update should be considered, `false` to reject it
46/// before command matching even runs.
47pub type UpdateFilter = Arc<dyn Fn(&Update) -> bool + Send + Sync>;
48
49/// Handler for Telegram bot commands (messages starting with `/`).
50///
51/// The handler will only trigger on messages where the first entity is a
52/// `bot_command` at offset 0. It validates the command text against the
53/// provided set of commands (case-insensitive) and optionally checks the
54/// argument count.
55///
56/// # Ergonomic constructor
57///
58/// ```rust,ignore
59/// use rust_tg_bot_ext::prelude::*;
60///
61/// async fn start(update: Update, context: Context) -> HandlerResult {
62///     context.reply_text(&update, "Hello!").await?;
63///     Ok(())
64/// }
65///
66/// CommandHandler::new("start", start);
67/// ```
68///
69/// # Full-control constructor
70///
71/// ```rust,ignore
72/// use rust_tg_bot_ext::handlers::command::CommandHandler;
73/// use rust_tg_bot_ext::handlers::base::*;
74/// use std::sync::Arc;
75///
76/// let handler = CommandHandler::with_options(
77///     vec!["start".into(), "help".into()],
78///     Arc::new(|update, match_result| Box::pin(async move {
79///         HandlerResult::Continue
80///     })),
81///     None, // has_args
82///     true, // block
83/// );
84/// ```
85pub struct CommandHandler {
86    /// Lowercased set of commands this handler responds to (without `/`).
87    commands: HashSet<String>,
88    callback: HandlerCallback,
89    has_args: HasArgs,
90    block: bool,
91    /// C1: Optional bot username for `@botname` validation (stored lowercased).
92    bot_username: Option<String>,
93    /// C2: Optional filter applied before command matching. When `None` the
94    /// default behaviour is to require that the update has a `message` or
95    /// `edited_message` field, matching Python's `UpdateType.MESSAGES`.
96    filter_fn: Option<UpdateFilter>,
97    /// Optional context-aware callback for the ergonomic API.
98    context_callback: Option<ContextCallback>,
99}
100
101/// Validation regex: commands must be 1-32 chars of `[a-z0-9_]`.
102fn validate_command(cmd: &str) -> bool {
103    lazy_static_regex().is_match(cmd)
104}
105
106fn lazy_static_regex() -> &'static Regex {
107    use std::sync::OnceLock;
108    static RE: OnceLock<Regex> = OnceLock::new();
109    RE.get_or_init(|| Regex::new(r"^[a-z0-9_]{1,32}$").expect("command regex is valid"))
110}
111
112/// The default update filter: accepts message-like updates.
113fn default_update_filter(update: &Update) -> bool {
114    update.message().is_some()
115}
116
117impl CommandHandler {
118    /// Ergonomic constructor matching python-telegram-bot's
119    /// `CommandHandler("cmd", callback)`.
120    ///
121    /// Accepts a single command name (string) and an async handler function
122    /// with signature `async fn(Update, Context) -> HandlerResult`.
123    ///
124    /// # Example
125    ///
126    /// ```rust,ignore
127    /// use rust_tg_bot_ext::prelude::*;
128    ///
129    /// async fn start(update: Update, context: Context) -> HandlerResult {
130    ///     context.reply_text(&update, "Hello!").await?;
131    ///     Ok(())
132    /// }
133    ///
134    /// CommandHandler::new("start", start);
135    /// ```
136    pub fn new<Cb, Fut>(command: impl Into<String>, callback: Cb) -> Self
137    where
138        Cb: Fn(Arc<Update>, CallbackContext) -> Fut + Send + Sync + 'static,
139        Fut: Future<Output = Result<(), crate::application::HandlerError>> + Send + 'static,
140    {
141        let cmd = command.into();
142        let cb = Arc::new(callback);
143        let context_cb: ContextCallback = Arc::new(move |update, ctx| {
144            let fut = cb(update, ctx);
145            Box::pin(fut)
146                as Pin<
147                    Box<dyn Future<Output = Result<(), crate::application::HandlerError>> + Send>,
148                >
149        });
150
151        // The raw callback is a no-op; handle_update_with_context is used instead.
152        let noop_callback: HandlerCallback =
153            Arc::new(|_update, _mr| Box::pin(async { HandlerResult::Continue }));
154
155        let commands: HashSet<String> = {
156            let lower = cmd.to_lowercase();
157            assert!(
158                validate_command(&lower),
159                "Command `{lower}` is not a valid bot command"
160            );
161            let mut set = HashSet::new();
162            set.insert(lower);
163            set
164        };
165
166        Self {
167            commands,
168            callback: noop_callback,
169            has_args: HasArgs::Any,
170            block: true,
171            bot_username: None,
172            filter_fn: None,
173            context_callback: Some(context_cb),
174        }
175    }
176
177    /// Full-control constructor for advanced use cases.
178    ///
179    /// # Panics
180    ///
181    /// Panics if any command string does not match `[a-z0-9_]{1,32}`.
182    pub fn with_options(
183        commands: Vec<String>,
184        callback: HandlerCallback,
185        has_args: Option<HasArgs>,
186        block: bool,
187    ) -> Self {
188        let commands: HashSet<String> = commands.into_iter().map(|c| c.to_lowercase()).collect();
189        for cmd in &commands {
190            assert!(
191                validate_command(cmd),
192                "Command `{cmd}` is not a valid bot command"
193            );
194        }
195        Self {
196            commands,
197            callback,
198            has_args: has_args.unwrap_or(HasArgs::Any),
199            block,
200            bot_username: None,
201            filter_fn: None,
202            context_callback: None,
203        }
204    }
205
206    /// Set the bot username for `@botname` validation (C1).
207    ///
208    /// When a command like `/start@MyBot` is received, the `@MyBot` suffix
209    /// will be compared case-insensitively against this value. If they do not
210    /// match the update is rejected.
211    ///
212    /// If no bot username is configured, the `@` suffix is silently ignored
213    /// (backwards compatible).
214    pub fn with_bot_username(mut self, username: impl Into<String>) -> Self {
215        self.bot_username = Some(username.into().to_lowercase());
216        self
217    }
218
219    /// Set a custom update filter (C2).
220    ///
221    /// The filter runs *before* any command matching. If it returns `false`
222    /// the update is immediately rejected.
223    ///
224    /// When no custom filter is supplied the default behaviour is to accept
225    /// message-like updates via [`Update::message`](Update::message).
226    pub fn with_filter(mut self, filter: UpdateFilter) -> Self {
227        self.filter_fn = Some(filter);
228        self
229    }
230
231    /// Check whether the argument count satisfies the `has_args` constraint.
232    fn check_args(&self, args: &[String]) -> bool {
233        match &self.has_args {
234            HasArgs::Any => true,
235            HasArgs::NonEmpty => !args.is_empty(),
236            HasArgs::None => args.is_empty(),
237            HasArgs::Exact(n) => args.len() == *n,
238        }
239    }
240}
241
242impl Handler for CommandHandler {
243    fn check_update(&self, update: &Update) -> Option<MatchResult> {
244        // -- C2: Apply filter first -----------------------------------------------
245        let passes_filter = match &self.filter_fn {
246            Some(f) => f(update),
247            None => default_update_filter(update),
248        };
249        if !passes_filter {
250            return None;
251        }
252
253        let message = update.effective_message()?;
254        let text = message.text.as_ref()?;
255        let entities = message.entities.as_ref()?;
256
257        // First entity must be a bot_command at offset 0.
258        let first_entity = entities.first()?;
259        if first_entity.entity_type != MessageEntityType::BotCommand {
260            return None;
261        }
262        if first_entity.offset != 0 {
263            return None;
264        }
265        let length = first_entity.length as usize;
266
267        // Extract command (strip leading `/`) and optional `@botname`.
268        let raw_command = &text[1..length];
269        let command_parts: Vec<&str> = raw_command.splitn(2, '@').collect();
270        let command_name = command_parts[0];
271
272        // -- C1: Validate @botname suffix -----------------------------------------
273        if command_parts.len() > 1 {
274            // Command has an `@suffix`.
275            let at_suffix = command_parts[1];
276            if let Some(ref expected) = self.bot_username {
277                if at_suffix.to_lowercase() != *expected {
278                    return None;
279                }
280            }
281            // When bot_username is None we silently strip the suffix
282            // (backwards compatible).
283        }
284        // When there is no @suffix the command is accepted regardless of
285        // bot_username (matches Python behaviour where the bot appends its
286        // own username and the comparison trivially passes).
287
288        if !self.commands.contains(&command_name.to_lowercase()) {
289            return None;
290        }
291
292        // Extract arguments: everything after the command, split on whitespace.
293        let args: Vec<String> = text.split_whitespace().skip(1).map(String::from).collect();
294
295        if !self.check_args(&args) {
296            return None;
297        }
298
299        Some(MatchResult::Args(args))
300    }
301
302    fn handle_update(
303        &self,
304        update: Arc<Update>,
305        match_result: MatchResult,
306    ) -> Pin<Box<dyn Future<Output = HandlerResult> + Send>> {
307        (self.callback)(update, match_result)
308    }
309
310    fn block(&self) -> bool {
311        self.block
312    }
313
314    /// Merge command arguments into `context.args`.
315    ///
316    /// Mirrors Python's `CommandHandler.collect_additional_context` which
317    /// populates `context.args` from the parsed argument list produced by
318    /// [`check_update`](Handler::check_update).
319    fn collect_additional_context(
320        &self,
321        context: &mut CallbackContext,
322        match_result: &MatchResult,
323    ) {
324        if let MatchResult::Args(args) = match_result {
325            context.args = Some(args.clone());
326        }
327    }
328
329    fn handle_update_with_context(
330        &self,
331        update: Arc<Update>,
332        match_result: MatchResult,
333        context: CallbackContext,
334    ) -> Pin<Box<dyn Future<Output = HandlerResult> + Send>> {
335        if let Some(ref cb) = self.context_callback {
336            let fut = cb(update, context);
337            Box::pin(async move {
338                match fut.await {
339                    Ok(()) => HandlerResult::Continue,
340                    Err(crate::application::HandlerError::HandlerStop { .. }) => {
341                        HandlerResult::Stop
342                    }
343                    Err(crate::application::HandlerError::Other(e)) => HandlerResult::Error(e),
344                }
345            })
346        } else {
347            (self.callback)(update, match_result)
348        }
349    }
350}
351
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::collections::HashMap;
    use std::sync::Arc;

    fn noop_callback() -> HandlerCallback {
        Arc::new(|_update, _mr| Box::pin(async { HandlerResult::Continue }))
    }

    /// A handler for the single command `start` with default options.
    fn start_handler() -> CommandHandler {
        CommandHandler::with_options(vec!["start".into()], noop_callback(), None, true)
    }

    /// Build an `Update` whose `field` (e.g. `"message"` or
    /// `"edited_message"`) holds a private-chat text message. The message has
    /// one `bot_command` entity covering the first whitespace-separated token.
    ///
    /// The entity length is counted in UTF-16 code units, as Telegram does.
    fn command_update_in(field: &str, text: &str) -> Update {
        let cmd_part = text.split_whitespace().next().unwrap_or(text);
        let entity_len = cmd_part.encode_utf16().count();
        let mut raw = json!({ "update_id": 1 });
        raw[field] = json!({
            "message_id": 1,
            "date": 0,
            "chat": {"id": 1, "type": "private"},
            "text": text,
            "entities": [{"type": "bot_command", "offset": 0, "length": entity_len}]
        });
        serde_json::from_value(raw).expect("test update JSON must be valid")
    }

    /// Build a minimal `Update` with a command `message`.
    fn make_command_update(text: &str) -> Update {
        command_update_in("message", text)
    }

    /// Build an `Update` with an `edited_message` containing a command.
    fn make_edited_command_update(text: &str) -> Update {
        command_update_in("edited_message", text)
    }

    /// Build a `CallbackContext` backed by a mocked bot and empty stores.
    fn make_context() -> CallbackContext {
        use crate::ext_bot::test_support::mock_request;
        use rust_tg_bot_raw::bot::Bot;

        let bot = Arc::new(crate::ext_bot::ExtBot::from_bot(Bot::new(
            "test",
            mock_request(),
        )));
        CallbackContext::new(
            bot,
            None,
            None,
            Arc::new(tokio::sync::RwLock::new(HashMap::new())),
            Arc::new(tokio::sync::RwLock::new(HashMap::new())),
            Arc::new(tokio::sync::RwLock::new(HashMap::new())),
        )
    }

    /// Context-style callback that does nothing, for the ergonomic API.
    async fn dummy_handler(
        _update: Arc<Update>,
        _ctx: CallbackContext,
    ) -> Result<(), crate::application::HandlerError> {
        Ok(())
    }

    #[test]
    fn valid_commands_accepted() {
        let h = CommandHandler::with_options(
            vec!["start".into(), "help".into()],
            noop_callback(),
            None,
            true,
        );
        assert!(h.commands.contains("start"));
        assert!(h.commands.contains("help"));
    }

    #[test]
    #[should_panic(expected = "not a valid bot command")]
    fn invalid_command_panics() {
        CommandHandler::with_options(vec!["invalid command!".into()], noop_callback(), None, true);
    }

    #[test]
    fn check_args_variants() {
        let h = CommandHandler::with_options(vec!["test".into()], noop_callback(), None, true);
        assert!(h.check_args(&[]));
        assert!(h.check_args(&["a".into()]));

        let h2 = CommandHandler::with_options(
            vec!["test".into()],
            noop_callback(),
            Some(HasArgs::Exact(2)),
            true,
        );
        assert!(!h2.check_args(&["a".into()]));
        assert!(h2.check_args(&["a".into(), "b".into()]));
    }

    #[test]
    fn has_args_none_rejects_arguments() {
        let h = CommandHandler::with_options(
            vec!["start".into()],
            noop_callback(),
            Some(HasArgs::None),
            true,
        );
        assert!(h.check_update(&make_command_update("/start")).is_some());
        assert!(h.check_update(&make_command_update("/start now")).is_none());
    }

    #[test]
    fn has_args_non_empty_requires_an_argument() {
        let h = CommandHandler::with_options(
            vec!["start".into()],
            noop_callback(),
            Some(HasArgs::NonEmpty),
            true,
        );
        assert!(h.check_update(&make_command_update("/start")).is_none());
        assert!(h.check_update(&make_command_update("/start now")).is_some());
    }

    // -- Command parsing tests --------------------------------------------

    #[test]
    fn command_lookup_is_case_insensitive() {
        let update = make_command_update("/START");
        assert!(start_handler().check_update(&update).is_some());
    }

    #[test]
    fn arguments_are_split_on_whitespace() {
        let update = make_command_update("/start foo  bar");
        let Some(MatchResult::Args(args)) = start_handler().check_update(&update) else {
            panic!("expected an Args match");
        };
        assert_eq!(args, vec!["foo".to_string(), "bar".to_string()]);
    }

    #[test]
    fn command_not_at_offset_zero_rejected() {
        let update: Update = serde_json::from_value(json!({
            "update_id": 1,
            "message": {
                "message_id": 1,
                "date": 0,
                "chat": {"id": 1, "type": "private"},
                "text": "hi /start",
                "entities": [{"type": "bot_command", "offset": 3, "length": 6}]
            }
        }))
        .expect("valid");
        assert!(start_handler().check_update(&update).is_none());
    }

    #[test]
    fn message_without_entities_rejected() {
        let update: Update = serde_json::from_value(json!({
            "update_id": 1,
            "message": {
                "message_id": 1,
                "date": 0,
                "chat": {"id": 1, "type": "private"},
                "text": "/start"
            }
        }))
        .expect("valid");
        assert!(start_handler().check_update(&update).is_none());
    }

    // -- C1 tests ---------------------------------------------------------

    #[test]
    fn c1_bot_username_matching_accepted() {
        let h = start_handler().with_bot_username("MyBot");
        let update = make_command_update("/start@MyBot");
        assert!(h.check_update(&update).is_some());
    }

    #[test]
    fn c1_bot_username_case_insensitive() {
        let h = start_handler().with_bot_username("mybot");
        let update = make_command_update("/start@MYBOT");
        assert!(h.check_update(&update).is_some());
    }

    #[test]
    fn c1_wrong_bot_username_rejected() {
        let h = start_handler().with_bot_username("MyBot");
        let update = make_command_update("/start@OtherBot");
        assert!(h.check_update(&update).is_none());
    }

    #[test]
    fn c1_no_at_suffix_still_accepted() {
        let h = start_handler().with_bot_username("MyBot");
        let update = make_command_update("/start");
        assert!(h.check_update(&update).is_some());
    }

    #[test]
    fn c1_no_bot_username_configured_strips_suffix() {
        let update = make_command_update("/start@AnyBot");
        assert!(start_handler().check_update(&update).is_some());
    }

    // -- C2 tests ---------------------------------------------------------

    #[test]
    fn c2_default_filter_accepts_edited_message() {
        let update = make_edited_command_update("/start");
        // Default filter now matches edited_message (like Python's UpdateType.MESSAGES).
        assert!(start_handler().check_update(&update).is_some());
    }

    #[test]
    fn c2_default_filter_accepts_channel_post() {
        let update: Update = serde_json::from_value(json!({
            "update_id": 1,
            "channel_post": {
                "message_id": 1,
                "date": 0,
                "chat": {"id": -100, "type": "channel"},
                "text": "/start",
                "entities": [{"type": "bot_command", "offset": 0, "length": 6}]
            }
        }))
        .expect("valid");

        assert!(start_handler().check_update(&update).is_some());
    }

    #[test]
    fn c2_custom_filter_allows_edited() {
        let update = make_edited_command_update("/start");

        let h = start_handler().with_filter(Arc::new(|u| {
            u.message().is_some() || u.edited_message().is_some()
        }));
        assert!(h.check_update(&update).is_some());
    }

    #[test]
    fn c2_custom_filter_rejects() {
        let update = make_command_update("/start");
        let h = start_handler().with_filter(Arc::new(|_u| false));
        assert!(h.check_update(&update).is_none());
    }

    // -- collect_additional_context tests ---------------------------------

    #[test]
    fn collect_context_populates_args() {
        let mut ctx = make_context();

        let mr = MatchResult::Args(vec!["foo".into(), "bar".into()]);
        start_handler().collect_additional_context(&mut ctx, &mr);

        assert_eq!(ctx.args, Some(vec!["foo".into(), "bar".into()]));
    }

    #[test]
    fn collect_context_no_op_for_empty() {
        let mut ctx = make_context();

        start_handler().collect_additional_context(&mut ctx, &MatchResult::Empty);

        assert!(ctx.args.is_none());
    }

    // -- Ergonomic constructor tests --------------------------------------

    #[test]
    fn ergonomic_new_check_update_works() {
        let h = CommandHandler::new("start", dummy_handler);
        let update = make_command_update("/start");
        assert!(h.check_update(&update).is_some());
    }

    #[test]
    fn ergonomic_new_rejects_wrong_command() {
        let h = CommandHandler::new("start", dummy_handler);
        let update = make_command_update("/help");
        assert!(h.check_update(&update).is_none());
    }
}