rust_tg_bot_ext/handlers/
callback_query.rs1use std::collections::HashMap;
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12
13use regex::Regex;
14use rust_tg_bot_raw::types::update::Update;
15
16use super::base::{Handler, HandlerCallback, HandlerResult, MatchResult};
17use crate::context::CallbackContext;
18
19#[derive(Clone)]
21#[non_exhaustive]
22pub enum CallbackPattern {
23 Data(Regex),
25 Game(Regex),
27 Both {
29 data: Regex,
31 game: Regex,
33 },
34 Predicate(Arc<dyn Fn(&str) -> bool + Send + Sync>),
40}
41
42pub struct CallbackQueryHandler {
76 callback: HandlerCallback,
77 pattern: Option<CallbackPattern>,
78 block: bool,
79}
80
81impl CallbackQueryHandler {
82 pub fn new(callback: HandlerCallback, pattern: Option<CallbackPattern>, block: bool) -> Self {
84 Self {
85 callback,
86 pattern,
87 block,
88 }
89 }
90
91 fn try_regex(re: &Regex, text: &str) -> Option<MatchResult> {
98 let caps = re.captures(text)?;
99
100 let positional: Vec<String> = caps
101 .iter()
102 .filter_map(|m| m.map(|m| m.as_str().to_owned()))
103 .collect();
104
105 let mut named: HashMap<String, String> = HashMap::new();
109 for name in re.capture_names().flatten() {
110 if let Some(m) = caps.name(name) {
111 named.insert(name.to_owned(), m.as_str().to_owned());
112 }
113 }
114
115 if named.is_empty() {
116 Some(MatchResult::RegexMatch(positional))
117 } else {
118 Some(MatchResult::RegexMatchWithNames { positional, named })
119 }
120 }
121}
122
123impl Handler for CallbackQueryHandler {
124 fn check_update(&self, update: &Update) -> Option<MatchResult> {
125 let cq = update.callback_query()?;
126
127 match &self.pattern {
128 None => {
129 Some(MatchResult::Empty)
131 }
132 Some(CallbackPattern::Data(re)) => {
133 let data = cq.data.as_ref()?;
134 Self::try_regex(re, data)
135 }
136 Some(CallbackPattern::Game(re)) => {
137 let game = cq.game_short_name.as_ref()?;
138 Self::try_regex(re, game)
139 }
140 Some(CallbackPattern::Both {
141 data: dre,
142 game: gre,
143 }) => {
144 if let Some(data) = cq.data.as_ref() {
146 if let Some(mr) = Self::try_regex(dre, data) {
147 return Some(mr);
148 }
149 }
150 if let Some(game) = cq.game_short_name.as_ref() {
151 return Self::try_regex(gre, game);
152 }
153 None
154 }
155 Some(CallbackPattern::Predicate(pred)) => {
156 let data = cq.data.as_ref()?;
157 if pred(data) {
158 Some(MatchResult::Empty)
159 } else {
160 None
161 }
162 }
163 }
164 }
165
166 fn handle_update(
167 &self,
168 update: Arc<Update>,
169 match_result: MatchResult,
170 ) -> Pin<Box<dyn Future<Output = HandlerResult> + Send>> {
171 (self.callback)(update, match_result)
172 }
173
174 fn block(&self) -> bool {
175 self.block
176 }
177
178 fn collect_additional_context(
184 &self,
185 context: &mut CallbackContext,
186 match_result: &MatchResult,
187 ) {
188 match match_result {
189 MatchResult::RegexMatch(groups) => {
190 context.matches = Some(groups.clone());
191 }
192 MatchResult::RegexMatchWithNames { positional, named } => {
193 context.matches = Some(positional.clone());
194 context.named_matches = Some(named.clone());
195 }
196 _ => {}
197 }
198 }
199}
200
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    /// Callback that ignores its arguments and always continues.
    fn noop_callback() -> HandlerCallback {
        Arc::new(|_update, _mr| Box::pin(async { HandlerResult::Continue }))
    }

    /// Builds a blocking handler with the no-op callback and the given pattern.
    fn handler_with(pattern: Option<CallbackPattern>) -> CallbackQueryHandler {
        CallbackQueryHandler::new(noop_callback(), pattern, true)
    }

    /// Callback-query update that carries only a `data` field.
    fn make_callback_query_update(data: &str) -> Update {
        serde_json::from_value(serde_json::json!({
            "update_id": 1,
            "callback_query": {
                "id": "42",
                "from": {"id": 1, "is_bot": false, "first_name": "Test"},
                "chat_instance": "ci",
                "data": data
            }
        }))
        .expect("valid callback query update")
    }

    /// Callback-query update that carries only a `game_short_name` field.
    fn make_game_query_update(game: &str) -> Update {
        serde_json::from_value(serde_json::json!({
            "update_id": 1,
            "callback_query": {
                "id": "42",
                "from": {"id": 1, "is_bot": false, "first_name": "Test"},
                "chat_instance": "ci",
                "game_short_name": game
            }
        }))
        .expect("valid game callback query update")
    }

    /// Fresh context backed by a mock request and empty data stores.
    fn make_context() -> CallbackContext {
        use crate::ext_bot::test_support::mock_request;
        use rust_tg_bot_raw::bot::Bot;

        let bot = Arc::new(crate::ext_bot::ExtBot::from_bot(Bot::new(
            "test",
            mock_request(),
        )));
        CallbackContext::new(
            bot,
            None,
            None,
            Arc::new(tokio::sync::RwLock::new(HashMap::new())),
            Arc::new(tokio::sync::RwLock::new(HashMap::new())),
            Arc::new(tokio::sync::RwLock::new(HashMap::new())),
        )
    }

    #[test]
    fn no_callback_query_returns_none() {
        let h = handler_with(None);
        let update: Update = serde_json::from_str(r#"{"update_id": 1}"#).unwrap();
        assert!(h.check_update(&update).is_none());
    }

    #[test]
    fn predicate_accepts_matching_data() {
        let h = handler_with(Some(CallbackPattern::Predicate(Arc::new(|data| {
            data.starts_with("btn_")
        }))));
        let update = make_callback_query_update("btn_42");
        assert!(h.check_update(&update).is_some());
    }

    #[test]
    fn predicate_rejects_non_matching_data() {
        let h = handler_with(Some(CallbackPattern::Predicate(Arc::new(|data| {
            data.starts_with("btn_")
        }))));
        let update = make_callback_query_update("action_42");
        assert!(h.check_update(&update).is_none());
    }

    #[test]
    fn predicate_requires_data_field() {
        // The predicate accepts everything, yet a query without `data` must
        // still be rejected because predicates are only evaluated on `data`.
        let h = handler_with(Some(CallbackPattern::Predicate(Arc::new(|_| true))));
        let update = make_game_query_update("mygame");
        assert!(h.check_update(&update).is_none());
    }

    #[test]
    fn no_pattern_accepts_any_callback_query() {
        let h = handler_with(None);
        let update = make_callback_query_update("anything");
        assert!(h.check_update(&update).is_some());
    }

    #[test]
    fn regex_data_pattern_matches() {
        let h = handler_with(Some(CallbackPattern::Data(
            Regex::new(r"^btn_(\d+)$").unwrap(),
        )));
        let update = make_callback_query_update("btn_123");
        match h.check_update(&update) {
            Some(MatchResult::RegexMatch(groups)) => {
                assert_eq!(groups[0], "btn_123");
                assert_eq!(groups[1], "123");
            }
            _ => panic!("expected RegexMatch"),
        }
    }

    #[test]
    fn named_group_pattern_returns_regex_match_with_names() {
        let re = Regex::new(r"^btn_(?P<id>\d+)$").unwrap();
        let h = handler_with(Some(CallbackPattern::Data(re)));
        let update = make_callback_query_update("btn_99");
        match h.check_update(&update) {
            Some(MatchResult::RegexMatchWithNames { positional, named }) => {
                assert_eq!(positional[0], "btn_99");
                assert_eq!(named.get("id").map(String::as_str), Some("99"));
            }
            other => panic!("expected RegexMatchWithNames, got {other:?}"),
        }
    }

    #[test]
    fn game_pattern_matches_game_short_name() {
        let h = handler_with(Some(CallbackPattern::Game(
            Regex::new(r"^my(game)$").unwrap(),
        )));
        let update = make_game_query_update("mygame");
        match h.check_update(&update) {
            Some(MatchResult::RegexMatch(groups)) => {
                assert_eq!(groups, vec!["mygame".to_owned(), "game".to_owned()]);
            }
            other => panic!("expected RegexMatch, got {other:?}"),
        }
    }

    #[test]
    fn game_pattern_requires_game_short_name() {
        let h = handler_with(Some(CallbackPattern::Game(Regex::new(r".*").unwrap())));
        let update = make_callback_query_update("mygame");
        assert!(h.check_update(&update).is_none());
    }

    #[test]
    fn both_pattern_matches_data_or_falls_back_to_game() {
        let h = handler_with(Some(CallbackPattern::Both {
            data: Regex::new(r"^btn_(\d+)$").unwrap(),
            game: Regex::new(r"^mygame$").unwrap(),
        }));

        match h.check_update(&make_callback_query_update("btn_7")) {
            Some(MatchResult::RegexMatch(groups)) => {
                assert_eq!(groups, vec!["btn_7".to_owned(), "7".to_owned()]);
            }
            other => panic!("expected RegexMatch from data, got {other:?}"),
        }

        match h.check_update(&make_game_query_update("mygame")) {
            Some(MatchResult::RegexMatch(groups)) => {
                assert_eq!(groups, vec!["mygame".to_owned()]);
            }
            other => panic!("expected RegexMatch from game, got {other:?}"),
        }

        // Data that fails the data regex, with no game name to fall back on.
        assert!(h
            .check_update(&make_callback_query_update("other"))
            .is_none());
    }

    #[test]
    fn collect_context_positional() {
        let mut ctx = make_context();
        let h = handler_with(None);
        let mr = MatchResult::RegexMatch(vec!["full".into(), "group1".into()]);
        h.collect_additional_context(&mut ctx, &mr);

        assert_eq!(ctx.matches, Some(vec!["full".into(), "group1".into()]));
        assert!(ctx.named_matches.is_none());
    }

    #[test]
    fn collect_context_named() {
        let mut ctx = make_context();
        let h = handler_with(None);
        let mut named = HashMap::new();
        named.insert("id".into(), "99".into());
        let mr = MatchResult::RegexMatchWithNames {
            positional: vec!["btn_99".into(), "99".into()],
            named,
        };
        h.collect_additional_context(&mut ctx, &mr);

        assert_eq!(ctx.matches, Some(vec!["btn_99".into(), "99".into()]));
        assert_eq!(
            ctx.named_matches
                .as_ref()
                .and_then(|m| m.get("id"))
                .map(String::as_str),
            Some("99")
        );
    }
}