Skip to main content

rust_tg_bot_ext/handlers/
string_regex.rs

1//! [`StringRegexHandler`] -- handles messages whose text matches a regex.
2//!
//! Adapted from `python-telegram-bot`'s `StringRegexHandler`. The Python
//! version operates on raw strings rather than Telegram updates. This Rust
//! version deliberately operates on `Update` objects instead: it extracts the
//! effective message's text and matches it against a compiled regex.
7
8use std::collections::HashMap;
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12
13use regex::Regex;
14use rust_tg_bot_raw::types::update::Update;
15
16use super::base::{Handler, HandlerCallback, HandlerResult, MatchResult};
17use crate::context::CallbackContext;
18
19/// Handler that matches messages whose text matches a regex pattern.
20///
21/// Uses `regex::Regex::captures` (anchored at the start of the string,
22/// matching `re.match` from Python).
23///
24/// When the regex contains named capture groups, both `context.matches`
25/// (positional) and `context.named_matches` (named groups map) are
26/// populated.
27///
28/// # Example
29///
30/// ```rust,ignore
31/// use rust_tg_bot_ext::handlers::string_regex::StringRegexHandler;
32/// use rust_tg_bot_ext::handlers::base::*;
33/// use regex::Regex;
34/// use std::sync::Arc;
35///
36/// let handler = StringRegexHandler::new(
37///     Regex::new(r"^hello (\w+)").unwrap(),
38///     Arc::new(|update, mr| Box::pin(async move { HandlerResult::Continue })),
39///     true,
40/// );
41/// ```
42pub struct StringRegexHandler {
43    pattern: Regex,
44    callback: HandlerCallback,
45    block: bool,
46}
47
48impl StringRegexHandler {
49    /// Create a new `StringRegexHandler`.
50    pub fn new(pattern: Regex, callback: HandlerCallback, block: bool) -> Self {
51        Self {
52            pattern,
53            callback,
54            block,
55        }
56    }
57}
58
59impl Handler for StringRegexHandler {
60    fn check_update(&self, update: &Update) -> Option<MatchResult> {
61        let message = update.effective_message()?;
62        let text = message.text.as_ref()?;
63
64        let caps = self.pattern.captures(text)?;
65
66        let positional: Vec<String> = caps
67            .iter()
68            .filter_map(|m| m.map(|m| m.as_str().to_owned()))
69            .collect();
70
71        // Collect named groups (only those that matched).
72        let mut named: HashMap<String, String> = HashMap::new();
73        for name in self.pattern.capture_names().flatten() {
74            if let Some(m) = caps.name(name) {
75                named.insert(name.to_owned(), m.as_str().to_owned());
76            }
77        }
78
79        if named.is_empty() {
80            Some(MatchResult::RegexMatch(positional))
81        } else {
82            Some(MatchResult::RegexMatchWithNames { positional, named })
83        }
84    }
85
86    fn handle_update(
87        &self,
88        update: Arc<Update>,
89        match_result: MatchResult,
90    ) -> Pin<Box<dyn Future<Output = HandlerResult> + Send>> {
91        (self.callback)(update, match_result)
92    }
93
94    fn block(&self) -> bool {
95        self.block
96    }
97
98    /// Populate `context.matches` (positional) and `context.named_matches`
99    /// (named groups) from the regex match result.
100    fn collect_additional_context(
101        &self,
102        context: &mut CallbackContext,
103        match_result: &MatchResult,
104    ) {
105        match match_result {
106            MatchResult::RegexMatch(groups) => {
107                context.matches = Some(groups.clone());
108            }
109            MatchResult::RegexMatchWithNames { positional, named } => {
110                context.matches = Some(positional.clone());
111                context.named_matches = Some(named.clone());
112            }
113            _ => {}
114        }
115    }
116}
117
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    /// Callback that does nothing and lets dispatch continue.
    fn noop_callback() -> HandlerCallback {
        Arc::new(|_update, _mr| Box::pin(async { HandlerResult::Continue }))
    }

    /// Deserialize a private-chat message update whose text is `text`.
    ///
    /// `text` is spliced into the JSON verbatim, so it must not contain `"`
    /// or `\`.
    fn make_message_update(text: &str) -> Update {
        let json = format!(
            r#"{{"update_id":1,"message":{{"message_id":1,"date":0,"chat":{{"id":1,"type":"private"}},"text":"{text}"}}}}"#
        );
        serde_json::from_str(&json).unwrap()
    }

    /// Build a `CallbackContext` backed by a mocked bot and empty data stores.
    fn make_context() -> CallbackContext {
        use crate::ext_bot::test_support::mock_request;
        use rust_tg_bot_raw::bot::Bot;

        let bot = Arc::new(crate::ext_bot::ExtBot::from_bot(Bot::new(
            "test",
            mock_request(),
        )));
        // Built one by one so each store's key/value types are inferred
        // independently from `CallbackContext::new`'s parameters.
        let store_a = Arc::new(tokio::sync::RwLock::new(HashMap::new()));
        let store_b = Arc::new(tokio::sync::RwLock::new(HashMap::new()));
        let store_c = Arc::new(tokio::sync::RwLock::new(HashMap::new()));
        CallbackContext::new(bot, None, None, store_a, store_b, store_c)
    }

    #[test]
    fn matches_regex() {
        let handler =
            StringRegexHandler::new(Regex::new(r"^hello (\w+)").unwrap(), noop_callback(), true);
        let update = make_message_update("hello world");

        match handler.check_update(&update) {
            Some(MatchResult::RegexMatch(groups)) => {
                assert_eq!(groups, ["hello world", "world"]);
            }
            other => panic!("expected RegexMatch, got {other:?}"),
        }
    }

    #[test]
    fn no_match_returns_none() {
        let handler =
            StringRegexHandler::new(Regex::new(r"^goodbye").unwrap(), noop_callback(), true);
        assert!(handler.check_update(&make_message_update("hello world")).is_none());
    }

    #[test]
    fn named_group_returns_regex_match_with_names() {
        let handler = StringRegexHandler::new(
            Regex::new(r"^hello (?P<name>\w+)").unwrap(),
            noop_callback(),
            true,
        );
        let update = make_message_update("hello alice");

        match handler.check_update(&update) {
            Some(MatchResult::RegexMatchWithNames { positional, named }) => {
                assert_eq!(positional[0], "hello alice");
                assert_eq!(named.get("name").map(String::as_str), Some("alice"));
            }
            other => panic!("expected RegexMatchWithNames, got {other:?}"),
        }
    }

    #[test]
    fn collect_context_populates_matches() {
        let mut ctx = make_context();
        let handler = StringRegexHandler::new(Regex::new(r"x").unwrap(), noop_callback(), true);

        let mr = MatchResult::RegexMatch(vec!["hello".to_owned()]);
        handler.collect_additional_context(&mut ctx, &mr);

        assert_eq!(ctx.matches, Some(vec!["hello".to_owned()]));
        assert!(ctx.named_matches.is_none());
    }

    #[test]
    fn collect_context_populates_named_matches() {
        let mut ctx = make_context();
        let handler = StringRegexHandler::new(Regex::new(r"x").unwrap(), noop_callback(), true);

        let mut named = HashMap::new();
        named.insert("name".to_owned(), "alice".to_owned());
        let positional = vec!["hello alice".to_owned(), "alice".to_owned()];
        let mr = MatchResult::RegexMatchWithNames {
            positional: positional.clone(),
            named,
        };
        handler.collect_additional_context(&mut ctx, &mr);

        assert_eq!(ctx.matches, Some(positional));
        assert_eq!(
            ctx.named_matches
                .as_ref()
                .and_then(|m| m.get("name"))
                .map(String::as_str),
            Some("alice")
        );
    }
}