rust_tg_bot_ext/handlers/
string_regex.rs1use std::collections::HashMap;
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12
13use regex::Regex;
14use rust_tg_bot_raw::types::update::Update;
15
16use super::base::{Handler, HandlerCallback, HandlerResult, MatchResult};
17use crate::context::CallbackContext;
18
19pub struct StringRegexHandler {
43 pattern: Regex,
44 callback: HandlerCallback,
45 block: bool,
46}
47
48impl StringRegexHandler {
49 pub fn new(pattern: Regex, callback: HandlerCallback, block: bool) -> Self {
51 Self {
52 pattern,
53 callback,
54 block,
55 }
56 }
57}
58
59impl Handler for StringRegexHandler {
60 fn check_update(&self, update: &Update) -> Option<MatchResult> {
61 let message = update.effective_message()?;
62 let text = message.text.as_ref()?;
63
64 let caps = self.pattern.captures(text)?;
65
66 let positional: Vec<String> = caps
67 .iter()
68 .filter_map(|m| m.map(|m| m.as_str().to_owned()))
69 .collect();
70
71 let mut named: HashMap<String, String> = HashMap::new();
73 for name in self.pattern.capture_names().flatten() {
74 if let Some(m) = caps.name(name) {
75 named.insert(name.to_owned(), m.as_str().to_owned());
76 }
77 }
78
79 if named.is_empty() {
80 Some(MatchResult::RegexMatch(positional))
81 } else {
82 Some(MatchResult::RegexMatchWithNames { positional, named })
83 }
84 }
85
86 fn handle_update(
87 &self,
88 update: Arc<Update>,
89 match_result: MatchResult,
90 ) -> Pin<Box<dyn Future<Output = HandlerResult> + Send>> {
91 (self.callback)(update, match_result)
92 }
93
94 fn block(&self) -> bool {
95 self.block
96 }
97
98 fn collect_additional_context(
101 &self,
102 context: &mut CallbackContext,
103 match_result: &MatchResult,
104 ) {
105 match match_result {
106 MatchResult::RegexMatch(groups) => {
107 context.matches = Some(groups.clone());
108 }
109 MatchResult::RegexMatchWithNames { positional, named } => {
110 context.matches = Some(positional.clone());
111 context.named_matches = Some(named.clone());
112 }
113 _ => {}
114 }
115 }
116}
117
118#[cfg(test)]
119mod tests {
120 use std::sync::Arc;
121
122 use super::*;
123
124 fn noop_callback() -> HandlerCallback {
125 Arc::new(|_update, _mr| Box::pin(async { HandlerResult::Continue }))
126 }
127
128 fn make_message_update(text: &str) -> Update {
129 serde_json::from_str(&format!(
130 r#"{{"update_id":1,"message":{{"message_id":1,"date":0,"chat":{{"id":1,"type":"private"}},"text":"{text}"}}}}"#
131 ))
132 .unwrap()
133 }
134
135 #[test]
136 fn matches_regex() {
137 let h =
138 StringRegexHandler::new(Regex::new(r"^hello (\w+)").unwrap(), noop_callback(), true);
139 let update: Update = serde_json::from_str(
140 r#"{"update_id":1,"message":{"message_id":1,"date":0,"chat":{"id":1,"type":"private"},"text":"hello world"}}"#,
141 ).unwrap();
142 let result = h.check_update(&update);
143 assert!(result.is_some());
144 if let Some(MatchResult::RegexMatch(groups)) = result {
145 assert_eq!(groups.len(), 2);
146 assert_eq!(groups[0], "hello world");
147 assert_eq!(groups[1], "world");
148 } else {
149 panic!("expected RegexMatch");
150 }
151 }
152
153 #[test]
154 fn no_match_returns_none() {
155 let h = StringRegexHandler::new(Regex::new(r"^goodbye").unwrap(), noop_callback(), true);
156 let update: Update = serde_json::from_str(
157 r#"{"update_id":1,"message":{"message_id":1,"date":0,"chat":{"id":1,"type":"private"},"text":"hello world"}}"#,
158 ).unwrap();
159 assert!(h.check_update(&update).is_none());
160 }
161
162 #[test]
163 fn named_group_returns_regex_match_with_names() {
164 let h = StringRegexHandler::new(
165 Regex::new(r"^hello (?P<name>\w+)").unwrap(),
166 noop_callback(),
167 true,
168 );
169 let update = make_message_update("hello alice");
170 match h.check_update(&update) {
171 Some(MatchResult::RegexMatchWithNames { positional, named }) => {
172 assert_eq!(positional[0], "hello alice");
173 assert_eq!(named.get("name").map(String::as_str), Some("alice"));
174 }
175 other => panic!("expected RegexMatchWithNames, got {other:?}"),
176 }
177 }
178
179 #[test]
180 fn collect_context_populates_matches() {
181 use crate::context::CallbackContext;
182 use crate::ext_bot::test_support::mock_request;
183 use rust_tg_bot_raw::bot::Bot;
184
185 let bot = Arc::new(crate::ext_bot::ExtBot::from_bot(Bot::new(
186 "test",
187 mock_request(),
188 )));
189 let stores = (
190 Arc::new(tokio::sync::RwLock::new(HashMap::new())),
191 Arc::new(tokio::sync::RwLock::new(HashMap::new())),
192 Arc::new(tokio::sync::RwLock::new(HashMap::new())),
193 );
194 let mut ctx = CallbackContext::new(bot, None, None, stores.0, stores.1, stores.2);
195
196 let h = StringRegexHandler::new(Regex::new(r"x").unwrap(), noop_callback(), true);
197 let mr = MatchResult::RegexMatch(vec!["hello".into()]);
198 h.collect_additional_context(&mut ctx, &mr);
199 assert_eq!(ctx.matches, Some(vec!["hello".into()]));
200 assert!(ctx.named_matches.is_none());
201 }
202
203 #[test]
204 fn collect_context_populates_named_matches() {
205 use crate::context::CallbackContext;
206 use crate::ext_bot::test_support::mock_request;
207 use rust_tg_bot_raw::bot::Bot;
208
209 let bot = Arc::new(crate::ext_bot::ExtBot::from_bot(Bot::new(
210 "test",
211 mock_request(),
212 )));
213 let stores = (
214 Arc::new(tokio::sync::RwLock::new(HashMap::new())),
215 Arc::new(tokio::sync::RwLock::new(HashMap::new())),
216 Arc::new(tokio::sync::RwLock::new(HashMap::new())),
217 );
218 let mut ctx = CallbackContext::new(bot, None, None, stores.0, stores.1, stores.2);
219
220 let h = StringRegexHandler::new(Regex::new(r"x").unwrap(), noop_callback(), true);
221 let mut named = HashMap::new();
222 named.insert("name".into(), "alice".into());
223 let mr = MatchResult::RegexMatchWithNames {
224 positional: vec!["hello alice".into(), "alice".into()],
225 named,
226 };
227 h.collect_additional_context(&mut ctx, &mr);
228 assert_eq!(
229 ctx.matches,
230 Some(vec!["hello alice".into(), "alice".into()])
231 );
232 assert_eq!(
233 ctx.named_matches
234 .as_ref()
235 .and_then(|m| m.get("name"))
236 .map(String::as_str),
237 Some("alice")
238 );
239 }
240}