1use std::collections::{HashMap, HashSet};
51use std::future::Future;
52use std::hash::Hash;
53use std::pin::Pin;
54use std::sync::Arc;
55use std::time::Duration;
56
57use tokio::sync::{watch, RwLock};
58use tracing::{debug, error, warn};
59
60use rust_tg_bot_raw::types::update::Update;
61
62use super::base::{Handler, HandlerResult, MatchResult};
63
/// Key identifying one conversation inside a [`ConversationHandler`].
///
/// `ConversationHandler::build_key` fills it, in this order and only for the scopes that are
/// enabled: the chat id (`per_chat`), the user id (`per_user`) and a message component
/// (`per_message`: the message id, or a hash of the inline message id).
pub type ConversationKey = Vec<i64>;
73
/// What a conversation step callback wants to happen to the conversation state.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum ConversationResult<S> {
    /// Move the conversation to the given state.
    ///
    /// If the handler has a `map_to_parent` table containing this state as a key, the
    /// conversation is removed instead of being stored in that state.
    NextState(S),
    /// Finish the conversation: its entry is removed from the conversation map.
    End,
    /// Keep whatever state the conversation had when the callback was invoked.
    Stay,
}
90
/// Shared, thread-safe callback executed for a matched conversation step.
///
/// It receives the update and the [`MatchResult`] produced by the step's handler and resolves
/// to a pair: the [`HandlerResult`] handed back to the caller of `handle_update`, and the
/// [`ConversationResult`] that drives the state transition.
pub type ConversationCallback<S> = Arc<
    dyn Fn(
            Arc<Update>,
            MatchResult,
        ) -> Pin<Box<dyn Future<Output = (HandlerResult, ConversationResult<S>)> + Send>>
        + Send
        + Sync,
>;
107
/// One step of a conversation: a matcher plus the callback to run when it matches.
pub struct ConversationStepHandler<S: Hash + Eq + Clone + Send + Sync + 'static> {
    /// Used for matching (`check_update`) and to decide whether the callback is awaited
    /// inline or run in the background (`block`).
    pub handler: Box<dyn Handler>,
    /// Callback invoked for a matching update; its [`ConversationResult`] determines the
    /// next state.
    pub conv_callback: ConversationCallback<S>,
}
120
/// State machine that routes updates to step handlers depending on the per-key conversation
/// state. Construct it through [`ConversationHandler::builder`].
pub struct ConversationHandler<S: Hash + Eq + Clone + Send + Sync + 'static> {
    /// Steps tried when the key has no conversation yet, or always when `allow_reentry` is set.
    entry_points: Vec<ConversationStepHandler<S>>,
    /// Steps available in each state.
    states: HashMap<S, Vec<ConversationStepHandler<S>>>,
    /// Steps tried after the entry points and the current state's steps found no match.
    fallbacks: Vec<ConversationStepHandler<S>>,

    /// Current state of every active conversation.
    conversations: Arc<RwLock<HashMap<ConversationKey, S>>>,

    /// Whether entry points are also considered while a conversation is already running.
    allow_reentry: bool,
    /// Include the chat id in the conversation key.
    per_chat: bool,
    /// Include the user id in the conversation key.
    per_user: bool,
    /// Include the callback query's message (or inline message) in the conversation key.
    per_message: bool,

    /// When set, a timer of this length is armed after an update has been handled; if it
    /// expires before being cancelled, the `timeout_handlers` callbacks run and the
    /// conversation is removed.
    conversation_timeout: Option<Duration>,

    /// States that end this conversation when a callback returns them via `NextState`.
    /// Only the keys are consulted in this file; the mapped values are not read here.
    map_to_parent: Option<HashMap<S, S>>,

    /// Steps whose callbacks are run (with `MatchResult::Empty`) when a conversation times out.
    timeout_handlers: Vec<ConversationStepHandler<S>>,

    /// Cancellation handle of the timer currently armed for each key.
    timeout_cancellers: Arc<RwLock<HashMap<ConversationKey, watch::Sender<bool>>>>,

    /// Flag reported by `is_persistent`; the builder refuses `true` without a `name`.
    persistent: bool,

    /// Optional handler name, reported by `name`.
    name: Option<String>,

    /// Keys that currently have a non-blocking callback in flight; `check_update` rejects
    /// updates for these keys until the callback has finished.
    pending_callbacks: Arc<RwLock<HashSet<ConversationKey>>>,
}
177
178impl<S: Hash + Eq + Clone + Send + Sync + 'static> ConversationHandler<S> {
179 pub fn builder() -> ConversationHandlerBuilder<S> {
181 ConversationHandlerBuilder::default()
182 }
183
184 fn build_key(&self, update: &Update) -> Option<ConversationKey> {
186 let mut key = Vec::new();
187
188 if self.per_chat {
189 let chat = update.effective_chat()?;
190 key.push(chat.id);
191 }
192
193 if self.per_user {
194 let user = update.effective_user()?;
195 key.push(user.id);
196 }
197
198 if self.per_message {
199 let cq = update.callback_query()?;
200 if let Some(ref inline_id) = cq.inline_message_id {
201 use std::hash::Hasher;
202 let mut hasher = std::collections::hash_map::DefaultHasher::new();
203 hasher.write(inline_id.as_bytes());
204 key.push(hasher.finish() as i64);
205 } else if let Some(ref msg) = cq.message {
206 key.push(msg.message_id());
207 } else {
208 return None;
209 }
210 }
211
212 if key.is_empty() {
213 return None;
214 }
215
216 Some(key)
217 }
218
219 fn find_matching(
221 handlers: &[ConversationStepHandler<S>],
222 update: &Update,
223 ) -> Option<(usize, MatchResult)> {
224 for (idx, step) in handlers.iter().enumerate() {
225 if let Some(mr) = step.handler.check_update(update) {
226 return Some((idx, mr));
227 }
228 }
229 None
230 }
231
232 pub async fn get_state(&self, key: &ConversationKey) -> Option<S> {
234 self.conversations.read().await.get(key).cloned()
235 }
236
237 pub async fn active_conversations(&self) -> HashMap<ConversationKey, S> {
239 self.conversations.read().await.clone()
240 }
241
242 pub async fn load_conversations(&self, data: HashMap<ConversationKey, S>) {
246 *self.conversations.write().await = data;
247 }
248
249 pub async fn save_conversations(&self) -> HashMap<ConversationKey, S> {
251 self.conversations.read().await.clone()
252 }
253
254 pub fn is_persistent(&self) -> bool {
256 self.persistent
257 }
258
259 pub fn name(&self) -> Option<&str> {
261 self.name.as_deref()
262 }
263
264 async fn apply_state_transition(
269 conversations: &RwLock<HashMap<ConversationKey, S>>,
270 pending_callbacks: &RwLock<HashSet<ConversationKey>>,
271 key: &ConversationKey,
272 conv_result: ConversationResult<S>,
273 current_state: &Option<S>,
274 map_to_parent: &Option<HashMap<S, S>>,
275 ) -> Option<S> {
276 match conv_result {
277 ConversationResult::End => {
278 conversations.write().await.remove(key);
279 pending_callbacks.write().await.remove(key);
280 None
281 }
282 ConversationResult::Stay => current_state.clone(),
283 ConversationResult::NextState(s) => {
284 if let Some(ref mtp) = map_to_parent {
286 if mtp.contains_key(&s) {
287 conversations.write().await.remove(key);
288 pending_callbacks.write().await.remove(key);
289 debug!(
290 "ConversationHandler: map_to_parent triggered for key {:?}",
291 key
292 );
293 return None;
294 }
295 }
296 Some(s)
297 }
298 }
299 }
300
301 fn spawn_timeout(
304 conversations: Arc<RwLock<HashMap<ConversationKey, S>>>,
305 pending_callbacks: Arc<RwLock<HashSet<ConversationKey>>>,
306 timeout_cancellers: Arc<RwLock<HashMap<ConversationKey, watch::Sender<bool>>>>,
307 key: ConversationKey,
308 update: Arc<Update>,
309 duration: Duration,
310 timeout_cbs: Vec<ConversationCallback<S>>,
311 ) -> watch::Sender<bool> {
312 let (cancel_tx, mut cancel_rx) = watch::channel(false);
313 let key2 = key.clone();
314
315 tokio::spawn(async move {
316 tokio::select! {
317 _ = tokio::time::sleep(duration) => {
318 for cb in &timeout_cbs {
319 let _ = cb(update.clone(), MatchResult::Empty).await;
320 }
321 conversations.write().await.remove(&key2);
322 pending_callbacks.write().await.remove(&key2);
323 timeout_cancellers.write().await.remove(&key2);
324 debug!("Conversation {:?} timed out", key2);
325 }
326 _ = cancel_rx.changed() => {
327 debug!("Timeout cancelled for {:?}", key2);
328 }
329 }
330 });
331
332 cancel_tx
333 }
334}
335
336impl<S: Hash + Eq + Clone + Send + Sync + 'static> Handler for ConversationHandler<S> {
337 fn check_update(&self, update: &Update) -> Option<MatchResult> {
338 if update.channel_post().is_some() || update.edited_channel_post().is_some() {
340 return None;
341 }
342
343 let key = self.build_key(update)?;
344
345 if let Ok(pending) = self.pending_callbacks.try_read() {
347 if pending.contains(&key) {
348 debug!(
349 "ConversationHandler: skipping update for {:?} (pending callback)",
350 key
351 );
352 return None;
353 }
354 }
355
356 let current_state = match self.conversations.try_read() {
358 Ok(guard) => guard.get(&key).cloned(),
359 Err(_) => {
360 debug!(
361 "ConversationHandler: conversations lock contended, skipping {:?}",
362 key
363 );
364 return None;
365 }
366 };
367
368 match current_state {
369 None => {
370 if Self::find_matching(&self.entry_points, update).is_some() {
371 return Some(MatchResult::Empty);
372 }
373 }
374 Some(ref state) => {
375 if self.allow_reentry && Self::find_matching(&self.entry_points, update).is_some() {
376 return Some(MatchResult::Empty);
377 }
378
379 if let Some(handlers) = self.states.get(state) {
380 if Self::find_matching(handlers, update).is_some() {
381 return Some(MatchResult::Empty);
382 }
383 }
384
385 if Self::find_matching(&self.fallbacks, update).is_some() {
386 return Some(MatchResult::Empty);
387 }
388 }
389 }
390
391 None
392 }
393
394 fn handle_update(
395 &self,
396 update: Arc<Update>,
397 _match_result: MatchResult,
398 ) -> Pin<Box<dyn Future<Output = HandlerResult> + Send>> {
399 let conversations = Arc::clone(&self.conversations);
400 let pending_callbacks = Arc::clone(&self.pending_callbacks);
401 let allow_reentry = self.allow_reentry;
402
403 #[derive(Debug, Clone, Copy)]
404 enum HandlerSource {
405 EntryPoint(usize),
406 State(usize),
407 Fallback(usize),
408 }
409
410 let key = self.build_key(&update);
411
412 let current_state = key.as_ref().and_then(|k| {
413 self.conversations
414 .try_read()
415 .ok()
416 .and_then(|g| g.get(k).cloned())
417 });
418
419 let mut source = None;
421 let mut match_result = MatchResult::Empty;
422
423 let check_entries = current_state.is_none() || allow_reentry;
424 if check_entries {
425 if let Some((idx, mr)) = Self::find_matching(&self.entry_points, &update) {
426 source = Some(HandlerSource::EntryPoint(idx));
427 match_result = mr;
428 }
429 }
430
431 if source.is_none() {
432 if let Some(ref state) = current_state {
433 if let Some(handlers) = self.states.get(state) {
434 if let Some((idx, mr)) = Self::find_matching(handlers, &update) {
435 source = Some(HandlerSource::State(idx));
436 match_result = mr;
437 }
438 }
439 }
440 }
441
442 if source.is_none() {
443 if let Some((idx, mr)) = Self::find_matching(&self.fallbacks, &update) {
444 source = Some(HandlerSource::Fallback(idx));
445 match_result = mr;
446 }
447 }
448
449 let conv_cb = match source {
451 Some(HandlerSource::EntryPoint(idx)) => {
452 Arc::clone(&self.entry_points[idx].conv_callback)
453 }
454 Some(HandlerSource::State(idx)) => {
455 let mut cb = None;
456 if let Some(ref state) = current_state {
457 if let Some(handlers) = self.states.get(state) {
458 if idx < handlers.len() {
459 cb = Some(Arc::clone(&handlers[idx].conv_callback));
460 }
461 }
462 }
463 cb.unwrap_or_else(|| {
464 Arc::new(|_u, _m| {
465 Box::pin(async { (HandlerResult::Continue, ConversationResult::Stay) })
466 })
467 })
468 }
469 Some(HandlerSource::Fallback(idx)) => Arc::clone(&self.fallbacks[idx].conv_callback),
470 None => {
471 return Box::pin(async { HandlerResult::Continue });
472 }
473 };
474
475 let is_entry = matches!(source, Some(HandlerSource::EntryPoint(_)));
476
477 let is_blocking = match source {
479 Some(HandlerSource::EntryPoint(idx)) => self.entry_points[idx].handler.block(),
480 Some(HandlerSource::State(idx)) => current_state
481 .as_ref()
482 .and_then(|s| self.states.get(s))
483 .and_then(|handlers| handlers.get(idx))
484 .map_or(true, |step| step.handler.block()),
485 Some(HandlerSource::Fallback(idx)) => self.fallbacks[idx].handler.block(),
486 None => true,
487 };
488
489 let map_to_parent = self.map_to_parent.clone();
490 let has_timeout = self.conversation_timeout.is_some();
491 let timeout_cancellers = Arc::clone(&self.timeout_cancellers);
492 let timeout_duration = self.conversation_timeout;
493 let timeout_cbs: Vec<_> = self
494 .timeout_handlers
495 .iter()
496 .map(|step| Arc::clone(&step.conv_callback))
497 .collect();
498
499 let is_persistent = self.persistent;
500 let _handler_name = self.name.clone();
501
502 Box::pin(async move {
503 let key = match key {
504 Some(k) => k,
505 None => return HandlerResult::Continue,
506 };
507
508 let current_state = conversations.read().await.get(&key).cloned();
509
510 if is_entry && current_state.is_some() && !allow_reentry {
511 debug!("ConversationHandler: ignoring re-entry for key {:?}", key);
512 return HandlerResult::Continue;
513 }
514
515 if has_timeout {
517 if let Some(tx) = timeout_cancellers.write().await.remove(&key) {
518 let _ = tx.send(true);
519 }
520 }
521
522 if !is_blocking {
529 pending_callbacks.write().await.insert(key.clone());
530
531 let conversations2 = Arc::clone(&conversations);
532 let pending2 = Arc::clone(&pending_callbacks);
533 let map_to_parent2 = map_to_parent.clone();
534 let key2 = key.clone();
535 let current_state2 = current_state.clone();
536 let update2 = update.clone();
537 let timeout_cancellers2 = Arc::clone(&timeout_cancellers);
538 let timeout_cbs2 = timeout_cbs;
539
540 tokio::spawn(async move {
541 let result = tokio::spawn(conv_cb(update2.clone(), match_result)).await;
543
544 match result {
545 Ok((_handler_result, conv_result)) => {
546 let new_state = Self::apply_state_transition(
547 &conversations2,
548 &pending2,
549 &key2,
550 conv_result,
551 ¤t_state2,
552 &map_to_parent2,
553 )
554 .await;
555
556 if let Some(new_s) = new_state {
557 conversations2.write().await.insert(key2.clone(), new_s);
558 }
559 }
560 Err(join_err) => {
561 error!(
564 "ConversationHandler: non-blocking callback failed for {:?}: {}. \
565 Reverting to previous state.",
566 key2, join_err
567 );
568 if let Some(ref prev) = current_state2 {
569 conversations2
570 .write()
571 .await
572 .insert(key2.clone(), prev.clone());
573 } else {
574 conversations2.write().await.remove(&key2);
575 }
576 }
577 }
578
579 pending2.write().await.remove(&key2);
581
582 if has_timeout {
584 if let Some(duration) = timeout_duration {
585 let cancel_tx = Self::spawn_timeout(
586 Arc::clone(&conversations2),
587 Arc::clone(&pending2),
588 Arc::clone(&timeout_cancellers2),
589 key2.clone(),
590 update2,
591 duration,
592 timeout_cbs2,
593 );
594 timeout_cancellers2.write().await.insert(key2, cancel_tx);
595 }
596 }
597 });
598
599 return HandlerResult::Continue;
600 }
601
602 let (handler_result, conv_result) = conv_cb(update.clone(), match_result).await;
604
605 let new_state = Self::apply_state_transition(
606 &conversations,
607 &pending_callbacks,
608 &key,
609 conv_result,
610 ¤t_state,
611 &map_to_parent,
612 )
613 .await;
614
615 if new_state.is_none() && !conversations.read().await.contains_key(&key) {
617 return handler_result;
618 }
619
620 if let Some(new_s) = new_state {
621 conversations.write().await.insert(key.clone(), new_s);
622 }
623
624 if has_timeout {
626 if let Some(duration) = timeout_duration {
627 let cancel_tx = Self::spawn_timeout(
628 Arc::clone(&conversations),
629 Arc::clone(&pending_callbacks),
630 Arc::clone(&timeout_cancellers),
631 key.clone(),
632 update,
633 duration,
634 timeout_cbs,
635 );
636 timeout_cancellers.write().await.insert(key, cancel_tx);
637 }
638 }
639
640 if is_persistent {
641 debug!("ConversationHandler: state changed (persistent handler)");
642 }
643
644 handler_result
645 })
646 }
647
648 fn block(&self) -> bool {
649 true
650 }
651}
652
/// Builder for [`ConversationHandler`]; obtain one via `ConversationHandler::builder()`.
///
/// Each field mirrors the handler field of the same name and is copied over by `build`.
pub struct ConversationHandlerBuilder<S: Hash + Eq + Clone + Send + Sync + 'static> {
    /// Steps that can start a conversation.
    entry_points: Vec<ConversationStepHandler<S>>,
    /// Steps available in each state.
    states: HashMap<S, Vec<ConversationStepHandler<S>>>,
    /// Steps tried when nothing else matched.
    fallbacks: Vec<ConversationStepHandler<S>>,
    /// Whether entry points stay active during a running conversation (default `false`).
    allow_reentry: bool,
    /// Include the chat id in the conversation key (default `true`).
    per_chat: bool,
    /// Include the user id in the conversation key (default `true`).
    per_user: bool,
    /// Include the callback query's message in the conversation key (default `false`).
    per_message: bool,
    /// Optional inactivity timeout (default `None`).
    conversation_timeout: Option<Duration>,
    /// Optional handler name; `build` requires it when `persistent` is set.
    name: Option<String>,
    /// Optional table of states that end this conversation.
    map_to_parent: Option<HashMap<S, S>>,
    /// Steps whose callbacks run when a conversation times out.
    timeout_handlers: Vec<ConversationStepHandler<S>>,
    /// Whether the handler is marked persistent (default `false`).
    persistent: bool,
}
672
673impl<S: Hash + Eq + Clone + Send + Sync + 'static> Default for ConversationHandlerBuilder<S> {
674 fn default() -> Self {
675 Self {
676 entry_points: Vec::new(),
677 states: HashMap::new(),
678 fallbacks: Vec::new(),
679 allow_reentry: false,
680 per_chat: true,
681 per_user: true,
682 per_message: false,
683 conversation_timeout: None,
684 name: None,
685 map_to_parent: None,
686 timeout_handlers: Vec::new(),
687 persistent: false,
688 }
689 }
690}
691
692impl<S: Hash + Eq + Clone + Send + Sync + 'static> ConversationHandlerBuilder<S> {
693 pub fn entry_point(mut self, handler: ConversationStepHandler<S>) -> Self {
695 self.entry_points.push(handler);
696 self
697 }
698
699 pub fn entry_points(mut self, handlers: Vec<ConversationStepHandler<S>>) -> Self {
701 self.entry_points.extend(handlers);
702 self
703 }
704
705 pub fn state(mut self, state: S, handlers: Vec<ConversationStepHandler<S>>) -> Self {
707 self.states.insert(state, handlers);
708 self
709 }
710
711 pub fn fallback(mut self, handler: ConversationStepHandler<S>) -> Self {
713 self.fallbacks.push(handler);
714 self
715 }
716
717 pub fn fallbacks(mut self, handlers: Vec<ConversationStepHandler<S>>) -> Self {
719 self.fallbacks.extend(handlers);
720 self
721 }
722
723 pub fn allow_reentry(mut self, allow: bool) -> Self {
725 self.allow_reentry = allow;
726 self
727 }
728
729 pub fn per_chat(mut self, enabled: bool) -> Self {
731 self.per_chat = enabled;
732 self
733 }
734
735 pub fn per_user(mut self, enabled: bool) -> Self {
737 self.per_user = enabled;
738 self
739 }
740
741 pub fn per_message(mut self, enabled: bool) -> Self {
743 self.per_message = enabled;
744 self
745 }
746
747 pub fn conversation_timeout(mut self, timeout: Duration) -> Self {
749 self.conversation_timeout = Some(timeout);
750 self
751 }
752
753 pub fn name(mut self, name: String) -> Self {
755 self.name = Some(name);
756 self
757 }
758
759 pub fn map_to_parent(mut self, mapping: HashMap<S, S>) -> Self {
761 self.map_to_parent = Some(mapping);
762 self
763 }
764
765 pub fn timeout_handlers(mut self, handlers: Vec<ConversationStepHandler<S>>) -> Self {
767 self.timeout_handlers = handlers;
768 self
769 }
770
771 pub fn timeout_handler(mut self, handler: ConversationStepHandler<S>) -> Self {
773 self.timeout_handlers.push(handler);
774 self
775 }
776
777 pub fn persistent(mut self, enabled: bool) -> Self {
779 self.persistent = enabled;
780 self
781 }
782
783 pub fn build(self) -> ConversationHandler<S> {
790 assert!(
791 self.per_chat || self.per_user || self.per_message,
792 "At least one of per_chat, per_user, per_message must be true"
793 );
794
795 if self.persistent && self.name.is_none() {
796 panic!("Conversations can't be persistent when handler is unnamed");
797 }
798
799 if self.per_message && !self.per_chat {
800 warn!(
801 "ConversationHandler: per_message=true without per_chat=true \
802 -- message IDs are not globally unique"
803 );
804 }
805
806 ConversationHandler {
807 entry_points: self.entry_points,
808 states: self.states,
809 fallbacks: self.fallbacks,
810 conversations: Arc::new(RwLock::new(HashMap::new())),
811 allow_reentry: self.allow_reentry,
812 per_chat: self.per_chat,
813 per_user: self.per_user,
814 per_message: self.per_message,
815 conversation_timeout: self.conversation_timeout,
816 map_to_parent: self.map_to_parent,
817 timeout_handlers: self.timeout_handlers,
818 timeout_cancellers: Arc::new(RwLock::new(HashMap::new())),
819 persistent: self.persistent,
820 name: self.name,
821 pending_callbacks: Arc::new(RwLock::new(HashSet::new())),
822 }
823 }
824}
825
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::Arc;

    /// States of the toy conversation used throughout the tests.
    #[derive(Debug, Clone, Hash, Eq, PartialEq)]
    enum TestState {
        AskName,
        AskAge,
    }

    /// Chat and user ids of the private-chat fixture.
    const CHAT_ID: i64 = 100;
    const USER_ID: i64 = 42;

    /// Matches every update that carries a message.
    struct MessageMatcher;

    impl Handler for MessageMatcher {
        fn check_update(&self, update: &Update) -> Option<MatchResult> {
            update.message().map(|_| MatchResult::Empty)
        }

        fn handle_update(
            &self,
            _update: Arc<Update>,
            _match_result: MatchResult,
        ) -> Pin<Box<dyn Future<Output = HandlerResult> + Send>> {
            Box::pin(async { HandlerResult::Continue })
        }
    }

    /// Rejects every update.
    struct NeverMatch;

    impl Handler for NeverMatch {
        fn check_update(&self, _update: &Update) -> Option<MatchResult> {
            None
        }

        fn handle_update(
            &self,
            _update: Arc<Update>,
            _match_result: MatchResult,
        ) -> Pin<Box<dyn Future<Output = HandlerResult> + Send>> {
            Box::pin(async { HandlerResult::Continue })
        }
    }

    fn always_match_handler() -> Box<dyn Handler> {
        Box::new(MessageMatcher)
    }

    fn never_match_handler() -> Box<dyn Handler> {
        Box::new(NeverMatch)
    }

    /// Wraps `handler` in a step whose callback always yields `result`.
    fn make_step<S: Hash + Eq + Clone + Send + Sync + 'static>(
        handler: Box<dyn Handler>,
        result: ConversationResult<S>,
    ) -> ConversationStepHandler<S> {
        let conv_callback: ConversationCallback<S> = Arc::new(move |_u, _m| {
            let outcome = result.clone();
            Box::pin(async move { (HandlerResult::Continue, outcome) })
        });
        ConversationStepHandler {
            handler,
            conv_callback,
        }
    }

    /// Private-chat message update from `user_id` in `chat_id`.
    fn make_update(chat_id: i64, user_id: i64) -> Update {
        let payload = json!({
            "update_id": 1,
            "message": {
                "message_id": 1,
                "date": 0,
                "chat": {"id": chat_id, "type": "private"},
                "from": {"id": user_id, "is_bot": false, "first_name": "Test"}
            }
        });
        serde_json::from_value(payload).expect("test update JSON must be valid")
    }

    /// Channel-post update, which the conversation handler must ignore.
    fn make_channel_post_update() -> Update {
        let payload = json!({
            "update_id": 1,
            "channel_post": {
                "message_id": 1,
                "date": 0,
                "chat": {"id": -100, "type": "channel", "title": "Test"}
            }
        });
        serde_json::from_value(payload).expect("test update JSON must be valid")
    }

    /// Conversation key and matching update for the default chat/user fixture.
    fn fixture() -> (ConversationKey, Arc<Update>) {
        let key = vec![CHAT_ID, USER_ID];
        let update = Arc::new(make_update(CHAT_ID, USER_ID));
        (key, update)
    }

    /// Builder pre-loaded with an always-matching entry point that leads to `AskName`.
    fn builder_with_entry() -> ConversationHandlerBuilder<TestState> {
        ConversationHandler::<TestState>::builder().entry_point(make_step(
            always_match_handler(),
            ConversationResult::NextState(TestState::AskName),
        ))
    }

    /// The single always-matching `AskName` step that leads to `AskAge`.
    fn ask_name_steps() -> Vec<ConversationStepHandler<TestState>> {
        vec![make_step(
            always_match_handler(),
            ConversationResult::NextState(TestState::AskAge),
        )]
    }

    #[tokio::test]
    async fn state_transition_entry_to_state1_to_state2_to_end() {
        let conv = builder_with_entry()
            .state(TestState::AskName, ask_name_steps())
            .state(
                TestState::AskAge,
                vec![make_step(always_match_handler(), ConversationResult::End)],
            )
            .build();

        let (key, update) = fixture();

        // Entry -> AskName -> AskAge -> ended.
        for expected in [Some(TestState::AskName), Some(TestState::AskAge), None] {
            assert!(conv.check_update(&update).is_some());
            conv.handle_update(update.clone(), MatchResult::Empty).await;
            assert_eq!(conv.get_state(&key).await, expected);
        }
    }

    #[tokio::test]
    async fn timeout_removes_conversation() {
        let conv = builder_with_entry()
            .state(TestState::AskName, ask_name_steps())
            .conversation_timeout(Duration::from_millis(50))
            .build();

        let (key, update) = fixture();

        conv.handle_update(update.clone(), MatchResult::Empty).await;
        assert_eq!(conv.get_state(&key).await, Some(TestState::AskName));

        // Wait comfortably past the 50 ms timeout.
        tokio::time::sleep(Duration::from_millis(120)).await;

        assert_eq!(conv.get_state(&key).await, None);
    }

    #[tokio::test]
    async fn fallback_triggers_on_unmatched_input() {
        let conv = builder_with_entry()
            .state(
                TestState::AskName,
                vec![make_step(
                    never_match_handler(),
                    ConversationResult::NextState(TestState::AskAge),
                )],
            )
            .fallback(make_step(always_match_handler(), ConversationResult::End))
            .build();

        let (key, update) = fixture();

        conv.handle_update(update.clone(), MatchResult::Empty).await;
        assert_eq!(conv.get_state(&key).await, Some(TestState::AskName));

        // The state step never matches, so the fallback runs and ends the conversation.
        assert!(conv.check_update(&update).is_some());
        conv.handle_update(update.clone(), MatchResult::Empty).await;
        assert_eq!(conv.get_state(&key).await, None);
    }

    #[test]
    fn channel_post_returns_none() {
        let conv = builder_with_entry().build();

        let channel_update = make_channel_post_update();
        assert!(
            conv.check_update(&channel_update).is_none(),
            "Channel posts must be rejected by ConversationHandler"
        );
    }

    #[tokio::test]
    async fn persistence_load_save_roundtrip() {
        let conv = builder_with_entry()
            .state(TestState::AskName, ask_name_steps())
            .name("test_conv".to_string())
            .persistent(true)
            .build();

        let first_key: ConversationKey = vec![1, 2];
        let second_key: ConversationKey = vec![3, 4];

        let mut data = HashMap::new();
        data.insert(first_key.clone(), TestState::AskAge);
        data.insert(second_key.clone(), TestState::AskName);
        conv.load_conversations(data).await;

        assert_eq!(conv.get_state(&first_key).await, Some(TestState::AskAge));
        assert_eq!(conv.get_state(&second_key).await, Some(TestState::AskName));

        let saved = conv.save_conversations().await;
        assert_eq!(saved.len(), 2);
        assert_eq!(saved.get(&first_key), Some(&TestState::AskAge));
    }

    #[test]
    fn builder_name_and_persistence() {
        let conv = builder_with_entry()
            .name("my_conv".to_string())
            .persistent(true)
            .build();

        assert!(conv.is_persistent());
        assert_eq!(conv.name(), Some("my_conv"));
    }

    #[test]
    #[should_panic(expected = "At least one of per_chat, per_user, per_message must be true")]
    fn builder_panics_without_key_components() {
        ConversationHandler::<TestState>::builder()
            .per_chat(false)
            .per_user(false)
            .per_message(false)
            .build();
    }

    #[test]
    #[should_panic(expected = "Conversations can't be persistent when handler is unnamed")]
    fn builder_panics_persistent_without_name() {
        ConversationHandler::<TestState>::builder()
            .persistent(true)
            .build();
    }
}