Skip to main content

rust_tg_bot_ext/handlers/
callback_query.rs

1//! [`CallbackQueryHandler`] -- handles updates containing a callback query.
2//!
3//! Ported from `python-telegram-bot`'s `CallbackQueryHandler`. Supports
4//! optional regex pattern matching on `callback_query.data` and
5//! `callback_query.game_short_name`, as well as predicate functions for
6//! Rust-idiomatic `callable(data)` and type-check patterns.
7
8use std::collections::HashMap;
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12
13use regex::Regex;
14use rust_tg_bot_raw::types::update::Update;
15
16use super::base::{Handler, HandlerCallback, HandlerResult, MatchResult};
17use crate::context::CallbackContext;
18
19/// Pattern to match against callback query data or game short name.
20#[derive(Clone)]
21#[non_exhaustive]
22pub enum CallbackPattern {
23    /// Match `callback_query.data` against this compiled regex.
24    Data(Regex),
25    /// Match `callback_query.game_short_name` against this compiled regex.
26    Game(Regex),
27    /// Match both: data against the first regex, game against the second.
28    Both {
29        /// Regex to match against `callback_query.data`.
30        data: Regex,
31        /// Regex to match against `callback_query.game_short_name`.
32        game: Regex,
33    },
34    /// Match `callback_query.data` using an arbitrary predicate function.
35    ///
36    /// This covers Python's `callable(data)` and `isinstance(data, type)`
37    /// patterns in a Rust-idiomatic way. The predicate receives the data
38    /// string and returns `true` if the callback query should be handled.
39    Predicate(Arc<dyn Fn(&str) -> bool + Send + Sync>),
40}
41
42/// Handler for `Update.callback_query`.
43///
44/// When no pattern is set, any callback query matches. When a `Data` pattern
45/// is set, only queries with matching `.data` are accepted (queries with only
46/// `.game_short_name` are rejected, and vice versa).
47///
48/// Named capture groups in the regex pattern are exposed via
49/// `context.named_matches` (a `HashMap<String, String>`), while all captures
50/// (positional) are available as `context.matches`. This mirrors Python's
51/// behaviour of putting the full `re.Match` object into `context.matches`.
52///
53/// # Example
54///
55/// ```rust,ignore
56/// use rust_tg_bot_ext::handlers::callback_query::{CallbackQueryHandler, CallbackPattern};
57/// use rust_tg_bot_ext::handlers::base::*;
58/// use regex::Regex;
59/// use std::sync::Arc;
60///
61/// // Regex-based matching:
62/// let handler = CallbackQueryHandler::new(
63///     Arc::new(|update, mr| Box::pin(async move { HandlerResult::Continue })),
64///     Some(CallbackPattern::Data(Regex::new(r"^btn_(\d+)$").unwrap())),
65///     true,
66/// );
67///
68/// // Predicate-based matching (covers callable/type patterns):
69/// let handler2 = CallbackQueryHandler::new(
70///     Arc::new(|update, mr| Box::pin(async move { HandlerResult::Continue })),
71///     Some(CallbackPattern::Predicate(Arc::new(|data| data.starts_with("action_")))),
72///     true,
73/// );
74/// ```
75pub struct CallbackQueryHandler {
76    callback: HandlerCallback,
77    pattern: Option<CallbackPattern>,
78    block: bool,
79}
80
81impl CallbackQueryHandler {
82    /// Create a new `CallbackQueryHandler`.
83    pub fn new(callback: HandlerCallback, pattern: Option<CallbackPattern>, block: bool) -> Self {
84        Self {
85            callback,
86            pattern,
87            block,
88        }
89    }
90
91    /// Attempt regex match, returning captured groups as a `MatchResult`.
92    ///
93    /// When the regex contains at least one named capture group, returns
94    /// `MatchResult::RegexMatchWithNames` so that callers can access both
95    /// positional captures and the named-group map. Otherwise returns
96    /// `MatchResult::RegexMatch` (positional-only, backwards compatible).
97    fn try_regex(re: &Regex, text: &str) -> Option<MatchResult> {
98        let caps = re.captures(text)?;
99
100        let positional: Vec<String> = caps
101            .iter()
102            .filter_map(|m| m.map(|m| m.as_str().to_owned()))
103            .collect();
104
105        // Collect named groups. `capture_names()` yields `Option<&str>` for
106        // each capture slot (None for unnamed slots). We only include names
107        // that actually matched.
108        let mut named: HashMap<String, String> = HashMap::new();
109        for name in re.capture_names().flatten() {
110            if let Some(m) = caps.name(name) {
111                named.insert(name.to_owned(), m.as_str().to_owned());
112            }
113        }
114
115        if named.is_empty() {
116            Some(MatchResult::RegexMatch(positional))
117        } else {
118            Some(MatchResult::RegexMatchWithNames { positional, named })
119        }
120    }
121}
122
123impl Handler for CallbackQueryHandler {
124    fn check_update(&self, update: &Update) -> Option<MatchResult> {
125        let cq = update.callback_query()?;
126
127        match &self.pattern {
128            None => {
129                // No pattern: accept any callback query.
130                Some(MatchResult::Empty)
131            }
132            Some(CallbackPattern::Data(re)) => {
133                let data = cq.data.as_ref()?;
134                Self::try_regex(re, data)
135            }
136            Some(CallbackPattern::Game(re)) => {
137                let game = cq.game_short_name.as_ref()?;
138                Self::try_regex(re, game)
139            }
140            Some(CallbackPattern::Both {
141                data: dre,
142                game: gre,
143            }) => {
144                // Match whichever field is present.
145                if let Some(data) = cq.data.as_ref() {
146                    if let Some(mr) = Self::try_regex(dre, data) {
147                        return Some(mr);
148                    }
149                }
150                if let Some(game) = cq.game_short_name.as_ref() {
151                    return Self::try_regex(gre, game);
152                }
153                None
154            }
155            Some(CallbackPattern::Predicate(pred)) => {
156                let data = cq.data.as_ref()?;
157                if pred(data) {
158                    Some(MatchResult::Empty)
159                } else {
160                    None
161                }
162            }
163        }
164    }
165
166    fn handle_update(
167        &self,
168        update: Arc<Update>,
169        match_result: MatchResult,
170    ) -> Pin<Box<dyn Future<Output = HandlerResult> + Send>> {
171        (self.callback)(update, match_result)
172    }
173
174    fn block(&self) -> bool {
175        self.block
176    }
177
178    /// Populate `context.matches` (positional) and `context.named_matches`
179    /// (named groups) from the regex match result.
180    ///
181    /// Mirrors Python's `CallbackQueryHandler.collect_additional_context`
182    /// which injects the `re.Match` object into `context.matches`.
183    fn collect_additional_context(
184        &self,
185        context: &mut CallbackContext,
186        match_result: &MatchResult,
187    ) {
188        match match_result {
189            MatchResult::RegexMatch(groups) => {
190                context.matches = Some(groups.clone());
191            }
192            MatchResult::RegexMatchWithNames { positional, named } => {
193                context.matches = Some(positional.clone());
194                context.named_matches = Some(named.clone());
195            }
196            _ => {}
197        }
198    }
199}
200
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    fn noop_callback() -> HandlerCallback {
        Arc::new(|_update, _mr| Box::pin(async { HandlerResult::Continue }))
    }

    /// Build an update whose callback query carries `data`.
    fn make_callback_query_update(data: &str) -> Update {
        serde_json::from_value(serde_json::json!({
            "update_id": 1,
            "callback_query": {
                "id": "42",
                "from": {"id": 1, "is_bot": false, "first_name": "Test"},
                "chat_instance": "ci",
                "data": data
            }
        }))
        .expect("valid callback query update")
    }

    /// Build an update whose callback query carries only `game_short_name`.
    fn make_game_callback_query_update(game: &str) -> Update {
        serde_json::from_value(serde_json::json!({
            "update_id": 1,
            "callback_query": {
                "id": "42",
                "from": {"id": 1, "is_bot": false, "first_name": "Test"},
                "chat_instance": "ci",
                "game_short_name": game
            }
        }))
        .expect("valid game callback query update")
    }

    /// Build a fresh `CallbackContext` backed by a mock bot and empty stores.
    fn make_context() -> CallbackContext {
        use crate::ext_bot::test_support::mock_request;
        use rust_tg_bot_raw::bot::Bot;

        let bot = Arc::new(crate::ext_bot::ExtBot::from_bot(Bot::new(
            "test",
            mock_request(),
        )));
        CallbackContext::new(
            bot,
            None,
            None,
            Arc::new(tokio::sync::RwLock::new(HashMap::new())),
            Arc::new(tokio::sync::RwLock::new(HashMap::new())),
            Arc::new(tokio::sync::RwLock::new(HashMap::new())),
        )
    }

    fn btn_predicate() -> Option<CallbackPattern> {
        Some(CallbackPattern::Predicate(Arc::new(|data| {
            data.starts_with("btn_")
        })))
    }

    #[test]
    fn no_callback_query_returns_none() {
        let h = CallbackQueryHandler::new(noop_callback(), None, true);
        let update: Update = serde_json::from_str(r#"{"update_id": 1}"#).unwrap();
        assert!(h.check_update(&update).is_none());
    }

    #[test]
    fn predicate_accepts_matching_data() {
        let h = CallbackQueryHandler::new(noop_callback(), btn_predicate(), true);
        let update = make_callback_query_update("btn_42");
        assert!(h.check_update(&update).is_some());
    }

    #[test]
    fn predicate_rejects_non_matching_data() {
        let h = CallbackQueryHandler::new(noop_callback(), btn_predicate(), true);
        let update = make_callback_query_update("action_42");
        assert!(h.check_update(&update).is_none());
    }

    #[test]
    fn predicate_requires_data_field() {
        // Callback query without data should not match Predicate.
        let h = CallbackQueryHandler::new(
            noop_callback(),
            Some(CallbackPattern::Predicate(Arc::new(|_| true))),
            true,
        );
        let update = make_game_callback_query_update("mygame");
        assert!(h.check_update(&update).is_none());
    }

    #[test]
    fn no_pattern_accepts_any_callback_query() {
        let h = CallbackQueryHandler::new(noop_callback(), None, true);
        let update = make_callback_query_update("anything");
        assert!(h.check_update(&update).is_some());
    }

    #[test]
    fn no_pattern_accepts_game_callback_query() {
        let h = CallbackQueryHandler::new(noop_callback(), None, true);
        let update = make_game_callback_query_update("mygame");
        assert!(h.check_update(&update).is_some());
    }

    #[test]
    fn regex_data_pattern_matches() {
        let h = CallbackQueryHandler::new(
            noop_callback(),
            Some(CallbackPattern::Data(Regex::new(r"^btn_(\d+)$").unwrap())),
            true,
        );
        let update = make_callback_query_update("btn_123");
        let result = h.check_update(&update);
        assert!(result.is_some());
        if let Some(MatchResult::RegexMatch(groups)) = result {
            assert_eq!(groups[0], "btn_123");
            assert_eq!(groups[1], "123");
        } else {
            panic!("expected RegexMatch");
        }
    }

    #[test]
    fn regex_data_pattern_rejects_non_matching_data() {
        let h = CallbackQueryHandler::new(
            noop_callback(),
            Some(CallbackPattern::Data(Regex::new(r"^btn_(\d+)$").unwrap())),
            true,
        );
        let update = make_callback_query_update("btn_abc");
        assert!(h.check_update(&update).is_none());
    }

    #[test]
    fn data_pattern_rejects_game_only_query() {
        let h = CallbackQueryHandler::new(
            noop_callback(),
            Some(CallbackPattern::Data(Regex::new(r".*").unwrap())),
            true,
        );
        let update = make_game_callback_query_update("mygame");
        assert!(h.check_update(&update).is_none());
    }

    #[test]
    fn game_pattern_matches_game_short_name() {
        let h = CallbackQueryHandler::new(
            noop_callback(),
            Some(CallbackPattern::Game(Regex::new(r"^my(game)$").unwrap())),
            true,
        );
        let update = make_game_callback_query_update("mygame");
        match h.check_update(&update) {
            Some(MatchResult::RegexMatch(groups)) => {
                assert_eq!(groups[0], "mygame");
                assert_eq!(groups[1], "game");
            }
            other => panic!("expected RegexMatch, got {other:?}"),
        }
    }

    #[test]
    fn game_pattern_rejects_data_only_query() {
        let h = CallbackQueryHandler::new(
            noop_callback(),
            Some(CallbackPattern::Game(Regex::new(r".*").unwrap())),
            true,
        );
        let update = make_callback_query_update("btn_1");
        assert!(h.check_update(&update).is_none());
    }

    #[test]
    fn both_pattern_matches_data_or_game() {
        let both = || {
            Some(CallbackPattern::Both {
                data: Regex::new(r"^btn_").unwrap(),
                game: Regex::new(r"^mygame$").unwrap(),
            })
        };
        let h = CallbackQueryHandler::new(noop_callback(), both(), true);

        assert!(h.check_update(&make_callback_query_update("btn_1")).is_some());
        assert!(h
            .check_update(&make_game_callback_query_update("mygame"))
            .is_some());
        assert!(h.check_update(&make_callback_query_update("other")).is_none());
        assert!(h
            .check_update(&make_game_callback_query_update("othergame"))
            .is_none());
    }

    #[test]
    fn named_group_pattern_returns_regex_match_with_names() {
        let re = Regex::new(r"^btn_(?P<id>\d+)$").unwrap();
        let h = CallbackQueryHandler::new(noop_callback(), Some(CallbackPattern::Data(re)), true);
        let update = make_callback_query_update("btn_99");
        match h.check_update(&update) {
            Some(MatchResult::RegexMatchWithNames { positional, named }) => {
                assert_eq!(positional[0], "btn_99");
                assert_eq!(named.get("id").map(String::as_str), Some("99"));
            }
            other => panic!("expected RegexMatchWithNames, got {other:?}"),
        }
    }

    #[test]
    fn collect_context_positional() {
        let mut ctx = make_context();

        let h = CallbackQueryHandler::new(noop_callback(), None, true);
        let mr = MatchResult::RegexMatch(vec!["full".into(), "group1".into()]);
        h.collect_additional_context(&mut ctx, &mr);

        assert_eq!(ctx.matches, Some(vec!["full".into(), "group1".into()]));
        assert!(ctx.named_matches.is_none());
    }

    #[test]
    fn collect_context_named() {
        let mut ctx = make_context();

        let h = CallbackQueryHandler::new(noop_callback(), None, true);
        let mut named = HashMap::new();
        named.insert("id".into(), "99".into());
        let mr = MatchResult::RegexMatchWithNames {
            positional: vec!["btn_99".into(), "99".into()],
            named,
        };
        h.collect_additional_context(&mut ctx, &mr);

        assert_eq!(ctx.matches, Some(vec!["btn_99".into(), "99".into()]));
        assert_eq!(
            ctx.named_matches
                .as_ref()
                .and_then(|m| m.get("id"))
                .map(String::as_str),
            Some("99")
        );
    }

    #[test]
    fn collect_context_ignores_empty_match() {
        let mut ctx = make_context();

        let h = CallbackQueryHandler::new(noop_callback(), None, true);
        h.collect_additional_context(&mut ctx, &MatchResult::Empty);

        assert!(ctx.matches.is_none());
        assert!(ctx.named_matches.is_none());
    }
}
377}