1use std::collections::HashSet;
18use std::future::Future;
19use std::pin::Pin;
20use std::sync::Arc;
21
22use regex::Regex;
23use rust_tg_bot_raw::constants::MessageEntityType;
24use rust_tg_bot_raw::types::update::Update;
25
26use super::base::{ContextCallback, Handler, HandlerCallback, HandlerResult, MatchResult};
27use crate::context::CallbackContext;
28
/// Policy for how many whitespace-separated arguments a command must carry
/// for the handler to match.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[non_exhaustive]
pub enum HasArgs {
    /// Any number of arguments, including none. This is the default.
    #[default]
    Any,
    /// At least one argument.
    NonEmpty,
    /// No arguments at all.
    None,
    /// Exactly this many arguments.
    Exact(usize),
}
42
43pub type UpdateFilter = Arc<dyn Fn(&Update) -> bool + Send + Sync>;
48
49pub struct CommandHandler {
86 commands: HashSet<String>,
88 callback: HandlerCallback,
89 has_args: HasArgs,
90 block: bool,
91 bot_username: Option<String>,
93 filter_fn: Option<UpdateFilter>,
97 context_callback: Option<ContextCallback>,
99}
100
101fn validate_command(cmd: &str) -> bool {
103 lazy_static_regex().is_match(cmd)
104}
105
106fn lazy_static_regex() -> &'static Regex {
107 use std::sync::OnceLock;
108 static RE: OnceLock<Regex> = OnceLock::new();
109 RE.get_or_init(|| Regex::new(r"^[a-z0-9_]{1,32}$").expect("command regex is valid"))
110}
111
112fn default_update_filter(update: &Update) -> bool {
114 update.message().is_some()
115}
116
117impl CommandHandler {
118 pub fn new<Cb, Fut>(command: impl Into<String>, callback: Cb) -> Self
137 where
138 Cb: Fn(Arc<Update>, CallbackContext) -> Fut + Send + Sync + 'static,
139 Fut: Future<Output = Result<(), crate::application::HandlerError>> + Send + 'static,
140 {
141 let cmd = command.into();
142 let cb = Arc::new(callback);
143 let context_cb: ContextCallback = Arc::new(move |update, ctx| {
144 let fut = cb(update, ctx);
145 Box::pin(fut)
146 as Pin<
147 Box<dyn Future<Output = Result<(), crate::application::HandlerError>> + Send>,
148 >
149 });
150
151 let noop_callback: HandlerCallback =
153 Arc::new(|_update, _mr| Box::pin(async { HandlerResult::Continue }));
154
155 let commands: HashSet<String> = {
156 let lower = cmd.to_lowercase();
157 assert!(
158 validate_command(&lower),
159 "Command `{lower}` is not a valid bot command"
160 );
161 let mut set = HashSet::new();
162 set.insert(lower);
163 set
164 };
165
166 Self {
167 commands,
168 callback: noop_callback,
169 has_args: HasArgs::Any,
170 block: true,
171 bot_username: None,
172 filter_fn: None,
173 context_callback: Some(context_cb),
174 }
175 }
176
177 pub fn with_options(
183 commands: Vec<String>,
184 callback: HandlerCallback,
185 has_args: Option<HasArgs>,
186 block: bool,
187 ) -> Self {
188 let commands: HashSet<String> = commands.into_iter().map(|c| c.to_lowercase()).collect();
189 for cmd in &commands {
190 assert!(
191 validate_command(cmd),
192 "Command `{cmd}` is not a valid bot command"
193 );
194 }
195 Self {
196 commands,
197 callback,
198 has_args: has_args.unwrap_or(HasArgs::Any),
199 block,
200 bot_username: None,
201 filter_fn: None,
202 context_callback: None,
203 }
204 }
205
206 pub fn with_bot_username(mut self, username: impl Into<String>) -> Self {
215 self.bot_username = Some(username.into().to_lowercase());
216 self
217 }
218
219 pub fn with_filter(mut self, filter: UpdateFilter) -> Self {
227 self.filter_fn = Some(filter);
228 self
229 }
230
231 fn check_args(&self, args: &[String]) -> bool {
233 match &self.has_args {
234 HasArgs::Any => true,
235 HasArgs::NonEmpty => !args.is_empty(),
236 HasArgs::None => args.is_empty(),
237 HasArgs::Exact(n) => args.len() == *n,
238 }
239 }
240}
241
242impl Handler for CommandHandler {
243 fn check_update(&self, update: &Update) -> Option<MatchResult> {
244 let passes_filter = match &self.filter_fn {
246 Some(f) => f(update),
247 None => default_update_filter(update),
248 };
249 if !passes_filter {
250 return None;
251 }
252
253 let message = update.effective_message()?;
254 let text = message.text.as_ref()?;
255 let entities = message.entities.as_ref()?;
256
257 let first_entity = entities.first()?;
259 if first_entity.entity_type != MessageEntityType::BotCommand {
260 return None;
261 }
262 if first_entity.offset != 0 {
263 return None;
264 }
265 let length = first_entity.length as usize;
266
267 let raw_command = &text[1..length];
269 let command_parts: Vec<&str> = raw_command.splitn(2, '@').collect();
270 let command_name = command_parts[0];
271
272 if command_parts.len() > 1 {
274 let at_suffix = command_parts[1];
276 if let Some(ref expected) = self.bot_username {
277 if at_suffix.to_lowercase() != *expected {
278 return None;
279 }
280 }
281 }
284 if !self.commands.contains(&command_name.to_lowercase()) {
289 return None;
290 }
291
292 let args: Vec<String> = text.split_whitespace().skip(1).map(String::from).collect();
294
295 if !self.check_args(&args) {
296 return None;
297 }
298
299 Some(MatchResult::Args(args))
300 }
301
302 fn handle_update(
303 &self,
304 update: Arc<Update>,
305 match_result: MatchResult,
306 ) -> Pin<Box<dyn Future<Output = HandlerResult> + Send>> {
307 (self.callback)(update, match_result)
308 }
309
310 fn block(&self) -> bool {
311 self.block
312 }
313
314 fn collect_additional_context(
320 &self,
321 context: &mut CallbackContext,
322 match_result: &MatchResult,
323 ) {
324 if let MatchResult::Args(args) = match_result {
325 context.args = Some(args.clone());
326 }
327 }
328
329 fn handle_update_with_context(
330 &self,
331 update: Arc<Update>,
332 match_result: MatchResult,
333 context: CallbackContext,
334 ) -> Pin<Box<dyn Future<Output = HandlerResult> + Send>> {
335 if let Some(ref cb) = self.context_callback {
336 let fut = cb(update, context);
337 Box::pin(async move {
338 match fut.await {
339 Ok(()) => HandlerResult::Continue,
340 Err(crate::application::HandlerError::HandlerStop { .. }) => {
341 HandlerResult::Stop
342 }
343 Err(crate::application::HandlerError::Other(e)) => HandlerResult::Error(e),
344 }
345 })
346 } else {
347 (self.callback)(update, match_result)
348 }
349 }
350}
351
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::Arc;

    // ---------------------------------------------------------------- fixtures

    /// Raw callback that always continues; only matching is under test.
    fn noop_callback() -> HandlerCallback {
        Arc::new(|_update, _mr| Box::pin(async { HandlerResult::Continue }))
    }

    /// Builds an update whose `kind` field ("message", "edited_message", ...)
    /// holds a private-chat text message. The message has a single
    /// `bot_command` entity spanning the byte length of the first
    /// whitespace-delimited token of `text`.
    fn command_update_of_kind(kind: &str, text: &str) -> Update {
        let first_token = text.split_whitespace().next().unwrap_or(text);
        let entity_len = first_token.len();
        let mut raw = json!({ "update_id": 1 });
        raw[kind] = json!({
            "message_id": 1,
            "date": 0,
            "chat": {"id": 1, "type": "private"},
            "text": text,
            "entities": [{"type": "bot_command", "offset": 0, "length": entity_len}]
        });
        serde_json::from_value(raw).expect("test update JSON must be valid")
    }

    fn make_command_update(text: &str) -> Update {
        command_update_of_kind("message", text)
    }

    fn make_edited_command_update(text: &str) -> Update {
        command_update_of_kind("edited_message", text)
    }

    /// Handler registered for `/start` only, with default options.
    fn start_handler() -> CommandHandler {
        CommandHandler::with_options(vec!["start".into()], noop_callback(), None, true)
    }

    /// Fresh context backed by a mock bot and three empty data stores.
    fn make_context() -> CallbackContext {
        use crate::ext_bot::test_support::mock_request;
        use rust_tg_bot_raw::bot::Bot;
        use std::collections::HashMap;

        let bot = Arc::new(crate::ext_bot::ExtBot::from_bot(Bot::new(
            "test",
            mock_request(),
        )));
        CallbackContext::new(
            bot,
            None,
            None,
            Arc::new(tokio::sync::RwLock::new(HashMap::new())),
            Arc::new(tokio::sync::RwLock::new(HashMap::new())),
            Arc::new(tokio::sync::RwLock::new(HashMap::new())),
        )
    }

    /// Context callback used by the `CommandHandler::new` tests.
    async fn ok_callback(
        _update: Arc<Update>,
        _ctx: CallbackContext,
    ) -> Result<(), crate::application::HandlerError> {
        Ok(())
    }

    // ------------------------------------------------------------ construction

    #[test]
    fn valid_commands_accepted() {
        let h = CommandHandler::with_options(
            vec!["start".into(), "help".into()],
            noop_callback(),
            None,
            true,
        );
        for name in ["start", "help"] {
            assert!(h.commands.contains(name));
        }
    }

    #[test]
    #[should_panic(expected = "not a valid bot command")]
    fn invalid_command_panics() {
        CommandHandler::with_options(vec!["invalid command!".into()], noop_callback(), None, true);
    }

    #[test]
    fn check_args_variants() {
        let any_args =
            CommandHandler::with_options(vec!["test".into()], noop_callback(), None, true);
        assert!(any_args.check_args(&[]));
        assert!(any_args.check_args(&["a".into()]));

        let two_args = CommandHandler::with_options(
            vec!["test".into()],
            noop_callback(),
            Some(HasArgs::Exact(2)),
            true,
        );
        assert!(!two_args.check_args(&["a".into()]));
        assert!(two_args.check_args(&["a".into(), "b".into()]));
    }

    // ------------------------------------------------------ C1: bot usernames

    #[test]
    fn c1_bot_username_matching_accepted() {
        let h = start_handler().with_bot_username("MyBot");
        assert!(h.check_update(&make_command_update("/start@MyBot")).is_some());
    }

    #[test]
    fn c1_bot_username_case_insensitive() {
        let h = start_handler().with_bot_username("mybot");
        assert!(h.check_update(&make_command_update("/start@MYBOT")).is_some());
    }

    #[test]
    fn c1_wrong_bot_username_rejected() {
        let h = start_handler().with_bot_username("MyBot");
        assert!(h.check_update(&make_command_update("/start@OtherBot")).is_none());
    }

    #[test]
    fn c1_no_at_suffix_still_accepted() {
        let h = start_handler().with_bot_username("MyBot");
        assert!(h.check_update(&make_command_update("/start")).is_some());
    }

    #[test]
    fn c1_no_bot_username_configured_strips_suffix() {
        let h = start_handler();
        assert!(h.check_update(&make_command_update("/start@AnyBot")).is_some());
    }

    // --------------------------------------------------------- C2: filtering

    #[test]
    fn c2_default_filter_accepts_edited_message() {
        let update = make_edited_command_update("/start");
        assert!(start_handler().check_update(&update).is_some());
    }

    #[test]
    fn c2_default_filter_accepts_channel_post() {
        let update: Update = serde_json::from_value(json!({
            "update_id": 1,
            "channel_post": {
                "message_id": 1,
                "date": 0,
                "chat": {"id": -100, "type": "channel"},
                "text": "/start",
                "entities": [{"type": "bot_command", "offset": 0, "length": 6}]
            }
        }))
        .expect("valid");

        assert!(start_handler().check_update(&update).is_some());
    }

    #[test]
    fn c2_custom_filter_allows_edited() {
        let update = make_edited_command_update("/start");
        let h = start_handler().with_filter(Arc::new(|u| {
            u.message().is_some() || u.edited_message().is_some()
        }));
        assert!(h.check_update(&update).is_some());
    }

    #[test]
    fn c2_custom_filter_rejects() {
        let update = make_command_update("/start");
        let h = start_handler().with_filter(Arc::new(|_u| false));
        assert!(h.check_update(&update).is_none());
    }

    // ---------------------------------------------------------- context/args

    #[test]
    fn collect_context_populates_args() {
        let mut ctx = make_context();
        let mr = MatchResult::Args(vec!["foo".into(), "bar".into()]);
        start_handler().collect_additional_context(&mut ctx, &mr);

        assert_eq!(ctx.args, Some(vec!["foo".into(), "bar".into()]));
    }

    #[test]
    fn collect_context_no_op_for_empty() {
        let mut ctx = make_context();
        start_handler().collect_additional_context(&mut ctx, &MatchResult::Empty);

        assert!(ctx.args.is_none());
    }

    // ------------------------------------------------------ ergonomic `new`

    #[test]
    fn ergonomic_new_check_update_works() {
        let h = CommandHandler::new("start", ok_callback);
        assert!(h.check_update(&make_command_update("/start")).is_some());
    }

    #[test]
    fn ergonomic_new_rejects_wrong_command() {
        let h = CommandHandler::new("start", ok_callback);
        assert!(h.check_update(&make_command_update("/help")).is_none());
    }
}