Skip to main content

rust_tg_bot_ext/handlers/
conversation.rs

1//! [`ConversationHandler`] -- stateful multi-step conversation handler.
2//!
3//! Ported from `python-telegram-bot`'s `ConversationHandler`. This is the
4//! most complex handler in the system. It manages a state machine per
5//! conversation, routing updates through entry points, state-specific
6//! handlers, and fallback handlers.
7//!
8//! # Design
9//!
//! The handler is generic over a state type `S` that must be `Hash + Eq +
//! Clone + Send + Sync + 'static`. State is tracked per conversation key,
//! a `Vec<i64>` that holds `[chat_id, user_id]` by default (see
//! [`ConversationKey`]).
13//!
14//! Callbacks return [`ConversationResult<S>`] to control state transitions.
15//!
16//! # Fixes implemented
17//!
18//! - **C3**: `check_update` is now state-aware via `RwLock::try_read()`.
19//!   Only the relevant handler list is checked based on current state.
20//! - **C4**: `map_to_parent` support for nested conversations.
21//! - **C5**: Timeout scheduling via `tokio::spawn` + `tokio::time::sleep`
22//!   with cancellation via `tokio::sync::watch`.
23//! - **C6**: Persistence integration with `load_conversations` / `save_conversations`.
24//! - **M12**: Channel posts and edited channel posts are rejected.
25//! - **M13**: WAITING state -- pending callbacks tracked, updates skipped while busy.
26//! - **PendingState**: Non-blocking callbacks are spawned via `tokio::spawn`
27//!   with result capture. On error the conversation reverts to the previous
28//!   state instead of leaving the conversation in limbo.
29//!
30//! # Example
31//!
32//! ```rust,ignore
33//! use rust_tg_bot_ext::handlers::conversation::*;
34//! use rust_tg_bot_ext::handlers::base::*;
35//! use rust_tg_bot_ext::handlers::command::CommandHandler;
36//! use std::sync::Arc;
37//! use std::collections::HashMap;
38//!
39//! #[derive(Clone, Hash, Eq, PartialEq)]
40//! enum State { AskName, AskAge }
41//!
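//! // Every step pairs a matching `Handler` with a conversation callback.
//! fn step(
//!     handler: Box<dyn Handler>,
//!     conv_callback: ConversationCallback<State>,
//! ) -> ConversationStepHandler<State> {
//!     ConversationStepHandler { handler, conv_callback }
//! }
//!
//! let conv = ConversationHandler::builder()
//!     .entry_point(step(Box::new(start_handler), start_cb))
//!     .state(State::AskName, vec![step(Box::new(name_handler), name_cb)])
//!     .state(State::AskAge, vec![step(Box::new(age_handler), age_cb)])
//!     .fallback(step(Box::new(cancel_handler), cancel_cb))
//!     .build();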
48//! ```
49
50use std::collections::{HashMap, HashSet};
51use std::future::Future;
52use std::hash::Hash;
53use std::pin::Pin;
54use std::sync::Arc;
55use std::time::Duration;
56
57use tokio::sync::{watch, RwLock};
58use tracing::{debug, error, warn};
59
60use rust_tg_bot_raw::types::update::Update;
61
62use super::base::{Handler, HandlerResult, MatchResult};
63
64// ---------------------------------------------------------------------------
65// Conversation key
66// ---------------------------------------------------------------------------
67
/// The key that identifies a unique conversation.
///
/// It is a list of IDs taken from the update, appended in this order and
/// only when the matching builder flag is set: the effective chat ID
/// (`per_chat`), the effective user ID (`per_user`), and the callback-query
/// message ID (`per_message`; inline message IDs are hashed into an `i64`).
/// With the builder defaults the key is `[chat_id, user_id]`.
pub type ConversationKey = Vec<i64>;
73
74// ---------------------------------------------------------------------------
75// Conversation result
76// ---------------------------------------------------------------------------
77
/// Outcome of a conversation step callback.
///
/// Tells the [`ConversationHandler`] how the per-conversation state machine
/// should move after the callback has run.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub enum ConversationResult<S> {
    /// Move the conversation into the contained state.
    NextState(S),
    /// Finish the conversation; its key is dropped from the state map.
    End,
    /// Leave the conversation in whatever state it already occupies.
    Stay,
}
90
91// ---------------------------------------------------------------------------
92// Conversation callback
93// ---------------------------------------------------------------------------
94
/// A conversation step callback.
///
/// Unlike the base `HandlerCallback`, this returns a `ConversationResult<S>`
/// alongside the `HandlerResult`. It receives the update and the
/// [`MatchResult`] produced by the step's handler. It resolves to a pair:
///
/// - the [`HandlerResult`], which is returned to the caller of
///   `handle_update` for blocking steps and discarded for non-blocking ones;
/// - the [`ConversationResult<S>`], which drives the state transition.
///
/// Timeout callbacks use the same signature and are invoked with
/// `MatchResult::Empty`; their results are ignored.
pub type ConversationCallback<S> = Arc<
    dyn Fn(
            Arc<Update>,
            MatchResult,
        ) -> Pin<Box<dyn Future<Output = (HandlerResult, ConversationResult<S>)> + Send>>
        + Send
        + Sync,
>;
107
108// ---------------------------------------------------------------------------
109// Conversation step handler
110// ---------------------------------------------------------------------------
111
/// A handler that participates in a conversation.
///
/// It pairs a base [`Handler`] with a conversation-aware callback. The
/// [`ConversationHandler`] uses the base handler only for matching
/// (`check_update`) and for its `block()` setting. The callback runs when
/// the step is selected and decides the next state.
pub struct ConversationStepHandler<S: Hash + Eq + Clone + Send + Sync + 'static> {
    /// The underlying handler used for `check_update`. Its `block()` value
    /// decides whether the step callback is awaited inline or spawned.
    pub handler: Box<dyn Handler>,
    /// The callback that produces a state transition.
    pub conv_callback: ConversationCallback<S>,
}
120
121// ---------------------------------------------------------------------------
122// Conversation handler
123// ---------------------------------------------------------------------------
124
/// Stateful multi-step conversation handler, generic over state type `S`.
///
/// Manages a state machine per conversation key, dispatching updates through
/// entry points, state handlers, and fallbacks. Construct it with
/// [`ConversationHandler::builder`].
pub struct ConversationHandler<S: Hash + Eq + Clone + Send + Sync + 'static> {
    /// Handlers that can initiate a new conversation.
    entry_points: Vec<ConversationStepHandler<S>>,
    /// Per-state handler lists.
    states: HashMap<S, Vec<ConversationStepHandler<S>>>,
    /// Fallback handlers tried when no state handler matches.
    fallbacks: Vec<ConversationStepHandler<S>>,

    /// Current conversation states, keyed by conversation key. Shared via
    /// `Arc` with spawned callback and timeout tasks.
    conversations: Arc<RwLock<HashMap<ConversationKey, S>>>,

    /// Whether a user already in a conversation can restart via entry points.
    allow_reentry: bool,
    /// Whether to include chat ID in the conversation key.
    per_chat: bool,
    /// Whether to include user ID in the conversation key.
    per_user: bool,
    /// Whether to include message ID (from callback query) in the key.
    per_message: bool,

    /// Optional conversation timeout. After this duration of inactivity the
    /// conversation is ended. `None` disables timeouts entirely.
    conversation_timeout: Option<Duration>,

    /// C4: Mapping from child state to parent state for nested conversations.
    /// When a callback returns a state that is a key of this map, the
    /// conversation is ended.
    /// NOTE(review): the mapped parent state itself is not propagated
    /// anywhere by the code in this file. Confirm how the parent handler
    /// learns it.
    map_to_parent: Option<HashMap<S, S>>,

    /// C5: Handlers for the TIMEOUT state. Their callbacks run when a
    /// conversation times out, before the conversation is removed.
    timeout_handlers: Vec<ConversationStepHandler<S>>,

    /// C5: Per-conversation timeout cancellation senders. Sending a value
    /// through the watch channel (or dropping the sender) cancels the
    /// pending timeout task.
    timeout_cancellers: Arc<RwLock<HashMap<ConversationKey, watch::Sender<bool>>>>,

    /// C6: Whether this conversation's state should be persisted.
    persistent: bool,

    /// C6: Optional name for persistence. Required when `persistent` is true
    /// (enforced by the builder).
    name: Option<String>,

    /// M13: Set of conversation keys that have a non-blocking callback in
    /// progress (WAITING state). Updates for these keys are skipped until
    /// the callback completes.
    pending_callbacks: Arc<RwLock<HashSet<ConversationKey>>>,
}
177
178impl<S: Hash + Eq + Clone + Send + Sync + 'static> ConversationHandler<S> {
179    /// Create a builder for constructing a `ConversationHandler`.
180    pub fn builder() -> ConversationHandlerBuilder<S> {
181        ConversationHandlerBuilder::default()
182    }
183
184    /// Build the conversation key for a given update.
185    fn build_key(&self, update: &Update) -> Option<ConversationKey> {
186        let mut key = Vec::new();
187
188        if self.per_chat {
189            let chat = update.effective_chat()?;
190            key.push(chat.id);
191        }
192
193        if self.per_user {
194            let user = update.effective_user()?;
195            key.push(user.id);
196        }
197
198        if self.per_message {
199            let cq = update.callback_query()?;
200            if let Some(ref inline_id) = cq.inline_message_id {
201                use std::hash::Hasher;
202                let mut hasher = std::collections::hash_map::DefaultHasher::new();
203                hasher.write(inline_id.as_bytes());
204                key.push(hasher.finish() as i64);
205            } else if let Some(ref msg) = cq.message {
206                key.push(msg.message_id());
207            } else {
208                return None;
209            }
210        }
211
212        if key.is_empty() {
213            return None;
214        }
215
216        Some(key)
217    }
218
219    /// Try to find a matching handler in the given list.
220    fn find_matching(
221        handlers: &[ConversationStepHandler<S>],
222        update: &Update,
223    ) -> Option<(usize, MatchResult)> {
224        for (idx, step) in handlers.iter().enumerate() {
225            if let Some(mr) = step.handler.check_update(update) {
226                return Some((idx, mr));
227            }
228        }
229        None
230    }
231
232    /// Get the current state for a conversation key.
233    pub async fn get_state(&self, key: &ConversationKey) -> Option<S> {
234        self.conversations.read().await.get(key).cloned()
235    }
236
237    /// Get a read-only snapshot of all active conversations.
238    pub async fn active_conversations(&self) -> HashMap<ConversationKey, S> {
239        self.conversations.read().await.clone()
240    }
241
242    // -- C6: Persistence --------------------------------------------------
243
244    /// Load previously-persisted conversations into this handler.
245    pub async fn load_conversations(&self, data: HashMap<ConversationKey, S>) {
246        *self.conversations.write().await = data;
247    }
248
249    /// Export the current conversation state for persistence.
250    pub async fn save_conversations(&self) -> HashMap<ConversationKey, S> {
251        self.conversations.read().await.clone()
252    }
253
254    /// Whether this handler is configured for persistence.
255    pub fn is_persistent(&self) -> bool {
256        self.persistent
257    }
258
259    /// The handler's name (required for persistence).
260    pub fn name(&self) -> Option<&str> {
261        self.name.as_deref()
262    }
263
264    /// Apply a conversation state transition, handling End and map_to_parent.
265    ///
266    /// Returns `Some(new_state)` if the conversation continues, or `None` if
267    /// the conversation was ended (either explicitly or via map_to_parent).
268    async fn apply_state_transition(
269        conversations: &RwLock<HashMap<ConversationKey, S>>,
270        pending_callbacks: &RwLock<HashSet<ConversationKey>>,
271        key: &ConversationKey,
272        conv_result: ConversationResult<S>,
273        current_state: &Option<S>,
274        map_to_parent: &Option<HashMap<S, S>>,
275    ) -> Option<S> {
276        match conv_result {
277            ConversationResult::End => {
278                conversations.write().await.remove(key);
279                pending_callbacks.write().await.remove(key);
280                None
281            }
282            ConversationResult::Stay => current_state.clone(),
283            ConversationResult::NextState(s) => {
284                // C4: Check map_to_parent.
285                if let Some(ref mtp) = map_to_parent {
286                    if mtp.contains_key(&s) {
287                        conversations.write().await.remove(key);
288                        pending_callbacks.write().await.remove(key);
289                        debug!(
290                            "ConversationHandler: map_to_parent triggered for key {:?}",
291                            key
292                        );
293                        return None;
294                    }
295                }
296                Some(s)
297            }
298        }
299    }
300
301    /// Spawn a timeout task for the given conversation key after a callback
302    /// completes. Shared between the blocking and non-blocking paths.
303    fn spawn_timeout(
304        conversations: Arc<RwLock<HashMap<ConversationKey, S>>>,
305        pending_callbacks: Arc<RwLock<HashSet<ConversationKey>>>,
306        timeout_cancellers: Arc<RwLock<HashMap<ConversationKey, watch::Sender<bool>>>>,
307        key: ConversationKey,
308        update: Arc<Update>,
309        duration: Duration,
310        timeout_cbs: Vec<ConversationCallback<S>>,
311    ) -> watch::Sender<bool> {
312        let (cancel_tx, mut cancel_rx) = watch::channel(false);
313        let key2 = key.clone();
314
315        tokio::spawn(async move {
316            tokio::select! {
317                _ = tokio::time::sleep(duration) => {
318                    for cb in &timeout_cbs {
319                        let _ = cb(update.clone(), MatchResult::Empty).await;
320                    }
321                    conversations.write().await.remove(&key2);
322                    pending_callbacks.write().await.remove(&key2);
323                    timeout_cancellers.write().await.remove(&key2);
324                    debug!("Conversation {:?} timed out", key2);
325                }
326                _ = cancel_rx.changed() => {
327                    debug!("Timeout cancelled for {:?}", key2);
328                }
329            }
330        });
331
332        cancel_tx
333    }
334}
335
336impl<S: Hash + Eq + Clone + Send + Sync + 'static> Handler for ConversationHandler<S> {
337    fn check_update(&self, update: &Update) -> Option<MatchResult> {
338        // ── M12: Reject channel posts and edited channel posts ───────────
339        if update.channel_post().is_some() || update.edited_channel_post().is_some() {
340            return None;
341        }
342
343        let key = self.build_key(update)?;
344
345        // ── M13: Skip if a pending callback is in progress for this key ──
346        if let Ok(pending) = self.pending_callbacks.try_read() {
347            if pending.contains(&key) {
348                debug!(
349                    "ConversationHandler: skipping update for {:?} (pending callback)",
350                    key
351                );
352                return None;
353            }
354        }
355
356        // ── C3: State-aware handler selection via try_read() ─────────────
357        let current_state = match self.conversations.try_read() {
358            Ok(guard) => guard.get(&key).cloned(),
359            Err(_) => {
360                debug!(
361                    "ConversationHandler: conversations lock contended, skipping {:?}",
362                    key
363                );
364                return None;
365            }
366        };
367
368        match current_state {
369            None => {
370                if Self::find_matching(&self.entry_points, update).is_some() {
371                    return Some(MatchResult::Empty);
372                }
373            }
374            Some(ref state) => {
375                if self.allow_reentry && Self::find_matching(&self.entry_points, update).is_some() {
376                    return Some(MatchResult::Empty);
377                }
378
379                if let Some(handlers) = self.states.get(state) {
380                    if Self::find_matching(handlers, update).is_some() {
381                        return Some(MatchResult::Empty);
382                    }
383                }
384
385                if Self::find_matching(&self.fallbacks, update).is_some() {
386                    return Some(MatchResult::Empty);
387                }
388            }
389        }
390
391        None
392    }
393
394    fn handle_update(
395        &self,
396        update: Arc<Update>,
397        _match_result: MatchResult,
398    ) -> Pin<Box<dyn Future<Output = HandlerResult> + Send>> {
399        let conversations = Arc::clone(&self.conversations);
400        let pending_callbacks = Arc::clone(&self.pending_callbacks);
401        let allow_reentry = self.allow_reentry;
402
403        #[derive(Debug, Clone, Copy)]
404        enum HandlerSource {
405            EntryPoint(usize),
406            State(usize),
407            Fallback(usize),
408        }
409
410        let key = self.build_key(&update);
411
412        let current_state = key.as_ref().and_then(|k| {
413            self.conversations
414                .try_read()
415                .ok()
416                .and_then(|g| g.get(k).cloned())
417        });
418
419        // ── State-aware handler matching ─────────────────────────────────
420        let mut source = None;
421        let mut match_result = MatchResult::Empty;
422
423        let check_entries = current_state.is_none() || allow_reentry;
424        if check_entries {
425            if let Some((idx, mr)) = Self::find_matching(&self.entry_points, &update) {
426                source = Some(HandlerSource::EntryPoint(idx));
427                match_result = mr;
428            }
429        }
430
431        if source.is_none() {
432            if let Some(ref state) = current_state {
433                if let Some(handlers) = self.states.get(state) {
434                    if let Some((idx, mr)) = Self::find_matching(handlers, &update) {
435                        source = Some(HandlerSource::State(idx));
436                        match_result = mr;
437                    }
438                }
439            }
440        }
441
442        if source.is_none() {
443            if let Some((idx, mr)) = Self::find_matching(&self.fallbacks, &update) {
444                source = Some(HandlerSource::Fallback(idx));
445                match_result = mr;
446            }
447        }
448
449        // Resolve the callback Arc.
450        let conv_cb = match source {
451            Some(HandlerSource::EntryPoint(idx)) => {
452                Arc::clone(&self.entry_points[idx].conv_callback)
453            }
454            Some(HandlerSource::State(idx)) => {
455                let mut cb = None;
456                if let Some(ref state) = current_state {
457                    if let Some(handlers) = self.states.get(state) {
458                        if idx < handlers.len() {
459                            cb = Some(Arc::clone(&handlers[idx].conv_callback));
460                        }
461                    }
462                }
463                cb.unwrap_or_else(|| {
464                    Arc::new(|_u, _m| {
465                        Box::pin(async { (HandlerResult::Continue, ConversationResult::Stay) })
466                    })
467                })
468            }
469            Some(HandlerSource::Fallback(idx)) => Arc::clone(&self.fallbacks[idx].conv_callback),
470            None => {
471                return Box::pin(async { HandlerResult::Continue });
472            }
473        };
474
475        let is_entry = matches!(source, Some(HandlerSource::EntryPoint(_)));
476
477        // Determine if the matched step handler is non-blocking.
478        let is_blocking = match source {
479            Some(HandlerSource::EntryPoint(idx)) => self.entry_points[idx].handler.block(),
480            Some(HandlerSource::State(idx)) => current_state
481                .as_ref()
482                .and_then(|s| self.states.get(s))
483                .and_then(|handlers| handlers.get(idx))
484                .map_or(true, |step| step.handler.block()),
485            Some(HandlerSource::Fallback(idx)) => self.fallbacks[idx].handler.block(),
486            None => true,
487        };
488
489        let map_to_parent = self.map_to_parent.clone();
490        let has_timeout = self.conversation_timeout.is_some();
491        let timeout_cancellers = Arc::clone(&self.timeout_cancellers);
492        let timeout_duration = self.conversation_timeout;
493        let timeout_cbs: Vec<_> = self
494            .timeout_handlers
495            .iter()
496            .map(|step| Arc::clone(&step.conv_callback))
497            .collect();
498
499        let is_persistent = self.persistent;
500        let _handler_name = self.name.clone();
501
502        Box::pin(async move {
503            let key = match key {
504                Some(k) => k,
505                None => return HandlerResult::Continue,
506            };
507
508            let current_state = conversations.read().await.get(&key).cloned();
509
510            if is_entry && current_state.is_some() && !allow_reentry {
511                debug!("ConversationHandler: ignoring re-entry for key {:?}", key);
512                return HandlerResult::Continue;
513            }
514
515            // C5: Cancel any existing timeout before running the callback.
516            if has_timeout {
517                if let Some(tx) = timeout_cancellers.write().await.remove(&key) {
518                    let _ = tx.send(true);
519                }
520            }
521
522            // ── PendingState: Non-blocking callback resolution ───────────
523            //
524            // When the step handler is non-blocking, spawn the callback via
525            // `tokio::spawn`. On success, apply the state transition normally.
526            // On error (task panic / cancellation), revert to the previous
527            // state instead of leaving the conversation in limbo.
528            if !is_blocking {
529                pending_callbacks.write().await.insert(key.clone());
530
531                let conversations2 = Arc::clone(&conversations);
532                let pending2 = Arc::clone(&pending_callbacks);
533                let map_to_parent2 = map_to_parent.clone();
534                let key2 = key.clone();
535                let current_state2 = current_state.clone();
536                let update2 = update.clone();
537                let timeout_cancellers2 = Arc::clone(&timeout_cancellers);
538                let timeout_cbs2 = timeout_cbs;
539
540                tokio::spawn(async move {
541                    // Spawn the callback itself so we can catch panics via JoinError.
542                    let result = tokio::spawn(conv_cb(update2.clone(), match_result)).await;
543
544                    match result {
545                        Ok((_handler_result, conv_result)) => {
546                            let new_state = Self::apply_state_transition(
547                                &conversations2,
548                                &pending2,
549                                &key2,
550                                conv_result,
551                                &current_state2,
552                                &map_to_parent2,
553                            )
554                            .await;
555
556                            if let Some(new_s) = new_state {
557                                conversations2.write().await.insert(key2.clone(), new_s);
558                            }
559                        }
560                        Err(join_err) => {
561                            // Callback panicked or was cancelled -- revert to
562                            // the previous state.
563                            error!(
564                                "ConversationHandler: non-blocking callback failed for {:?}: {}. \
565                                 Reverting to previous state.",
566                                key2, join_err
567                            );
568                            if let Some(ref prev) = current_state2 {
569                                conversations2
570                                    .write()
571                                    .await
572                                    .insert(key2.clone(), prev.clone());
573                            } else {
574                                conversations2.write().await.remove(&key2);
575                            }
576                        }
577                    }
578
579                    // Remove from pending set.
580                    pending2.write().await.remove(&key2);
581
582                    // Reschedule timeout if configured.
583                    if has_timeout {
584                        if let Some(duration) = timeout_duration {
585                            let cancel_tx = Self::spawn_timeout(
586                                Arc::clone(&conversations2),
587                                Arc::clone(&pending2),
588                                Arc::clone(&timeout_cancellers2),
589                                key2.clone(),
590                                update2,
591                                duration,
592                                timeout_cbs2,
593                            );
594                            timeout_cancellers2.write().await.insert(key2, cancel_tx);
595                        }
596                    }
597                });
598
599                return HandlerResult::Continue;
600            }
601
602            // ── Blocking callback path ───────────────────────────────────
603            let (handler_result, conv_result) = conv_cb(update.clone(), match_result).await;
604
605            let new_state = Self::apply_state_transition(
606                &conversations,
607                &pending_callbacks,
608                &key,
609                conv_result,
610                &current_state,
611                &map_to_parent,
612            )
613            .await;
614
615            // For End and map_to_parent, apply_state_transition already cleaned up.
616            if new_state.is_none() && !conversations.read().await.contains_key(&key) {
617                return handler_result;
618            }
619
620            if let Some(new_s) = new_state {
621                conversations.write().await.insert(key.clone(), new_s);
622            }
623
624            // C5: Reschedule timeout after successful state transition.
625            if has_timeout {
626                if let Some(duration) = timeout_duration {
627                    let cancel_tx = Self::spawn_timeout(
628                        Arc::clone(&conversations),
629                        Arc::clone(&pending_callbacks),
630                        Arc::clone(&timeout_cancellers),
631                        key.clone(),
632                        update,
633                        duration,
634                        timeout_cbs,
635                    );
636                    timeout_cancellers.write().await.insert(key, cancel_tx);
637                }
638            }
639
640            if is_persistent {
641                debug!("ConversationHandler: state changed (persistent handler)");
642            }
643
644            handler_result
645        })
646    }
647
648    fn block(&self) -> bool {
649        true
650    }
651}
652
653// ---------------------------------------------------------------------------
654// Builder
655// ---------------------------------------------------------------------------
656
/// Builder for [`ConversationHandler`].
///
/// Obtained via [`ConversationHandler::builder`]; see its `Default` impl
/// for the initial flag values.
pub struct ConversationHandlerBuilder<S: Hash + Eq + Clone + Send + Sync + 'static> {
    // Step handlers, moved unchanged into the built handler.
    entry_points: Vec<ConversationStepHandler<S>>,
    states: HashMap<S, Vec<ConversationStepHandler<S>>>,
    fallbacks: Vec<ConversationStepHandler<S>>,
    // Re-entry and key-composition flags.
    allow_reentry: bool,
    per_chat: bool,
    per_user: bool,
    per_message: bool,
    // Optional features: timeout (C5), naming/persistence (C6), nesting (C4).
    conversation_timeout: Option<Duration>,
    name: Option<String>,
    map_to_parent: Option<HashMap<S, S>>,
    timeout_handlers: Vec<ConversationStepHandler<S>>,
    persistent: bool,
}
672
673impl<S: Hash + Eq + Clone + Send + Sync + 'static> Default for ConversationHandlerBuilder<S> {
674    fn default() -> Self {
675        Self {
676            entry_points: Vec::new(),
677            states: HashMap::new(),
678            fallbacks: Vec::new(),
679            allow_reentry: false,
680            per_chat: true,
681            per_user: true,
682            per_message: false,
683            conversation_timeout: None,
684            name: None,
685            map_to_parent: None,
686            timeout_handlers: Vec::new(),
687            persistent: false,
688        }
689    }
690}
691
692impl<S: Hash + Eq + Clone + Send + Sync + 'static> ConversationHandlerBuilder<S> {
693    /// Add an entry point handler.
694    pub fn entry_point(mut self, handler: ConversationStepHandler<S>) -> Self {
695        self.entry_points.push(handler);
696        self
697    }
698
699    /// Add multiple entry point handlers.
700    pub fn entry_points(mut self, handlers: Vec<ConversationStepHandler<S>>) -> Self {
701        self.entry_points.extend(handlers);
702        self
703    }
704
705    /// Add handlers for a specific conversation state.
706    pub fn state(mut self, state: S, handlers: Vec<ConversationStepHandler<S>>) -> Self {
707        self.states.insert(state, handlers);
708        self
709    }
710
711    /// Add a fallback handler.
712    pub fn fallback(mut self, handler: ConversationStepHandler<S>) -> Self {
713        self.fallbacks.push(handler);
714        self
715    }
716
717    /// Add multiple fallback handlers.
718    pub fn fallbacks(mut self, handlers: Vec<ConversationStepHandler<S>>) -> Self {
719        self.fallbacks.extend(handlers);
720        self
721    }
722
723    /// Set whether re-entry via entry points is allowed.
724    pub fn allow_reentry(mut self, allow: bool) -> Self {
725        self.allow_reentry = allow;
726        self
727    }
728
729    /// Set whether the conversation key includes the chat ID.
730    pub fn per_chat(mut self, enabled: bool) -> Self {
731        self.per_chat = enabled;
732        self
733    }
734
735    /// Set whether the conversation key includes the user ID.
736    pub fn per_user(mut self, enabled: bool) -> Self {
737        self.per_user = enabled;
738        self
739    }
740
741    /// Set whether the conversation key includes the message ID.
742    pub fn per_message(mut self, enabled: bool) -> Self {
743        self.per_message = enabled;
744        self
745    }
746
747    /// Set the conversation timeout.
748    pub fn conversation_timeout(mut self, timeout: Duration) -> Self {
749        self.conversation_timeout = Some(timeout);
750        self
751    }
752
753    /// Set an optional name (required for persistence).
754    pub fn name(mut self, name: String) -> Self {
755        self.name = Some(name);
756        self
757    }
758
759    /// C4: Set the map-to-parent state mapping for nested conversations.
760    pub fn map_to_parent(mut self, mapping: HashMap<S, S>) -> Self {
761        self.map_to_parent = Some(mapping);
762        self
763    }
764
765    /// C5: Add handlers for the TIMEOUT state.
766    pub fn timeout_handlers(mut self, handlers: Vec<ConversationStepHandler<S>>) -> Self {
767        self.timeout_handlers = handlers;
768        self
769    }
770
771    /// C5: Add a single timeout handler.
772    pub fn timeout_handler(mut self, handler: ConversationStepHandler<S>) -> Self {
773        self.timeout_handlers.push(handler);
774        self
775    }
776
777    /// C6: Enable persistence for this conversation handler.
778    pub fn persistent(mut self, enabled: bool) -> Self {
779        self.persistent = enabled;
780        self
781    }
782
783    /// Build the `ConversationHandler`.
784    ///
785    /// # Panics
786    ///
787    /// Panics if `per_chat`, `per_user`, and `per_message` are all `false`.
788    /// Panics if `persistent` is `true` but `name` is `None`.
789    pub fn build(self) -> ConversationHandler<S> {
790        assert!(
791            self.per_chat || self.per_user || self.per_message,
792            "At least one of per_chat, per_user, per_message must be true"
793        );
794
795        if self.persistent && self.name.is_none() {
796            panic!("Conversations can't be persistent when handler is unnamed");
797        }
798
799        if self.per_message && !self.per_chat {
800            warn!(
801                "ConversationHandler: per_message=true without per_chat=true \
802                 -- message IDs are not globally unique"
803            );
804        }
805
806        ConversationHandler {
807            entry_points: self.entry_points,
808            states: self.states,
809            fallbacks: self.fallbacks,
810            conversations: Arc::new(RwLock::new(HashMap::new())),
811            allow_reentry: self.allow_reentry,
812            per_chat: self.per_chat,
813            per_user: self.per_user,
814            per_message: self.per_message,
815            conversation_timeout: self.conversation_timeout,
816            map_to_parent: self.map_to_parent,
817            timeout_handlers: self.timeout_handlers,
818            timeout_cancellers: Arc::new(RwLock::new(HashMap::new())),
819            persistent: self.persistent,
820            name: self.name,
821            pending_callbacks: Arc::new(RwLock::new(HashSet::new())),
822        }
823    }
824}
825
826// ---------------------------------------------------------------------------
827// Tests
828// ---------------------------------------------------------------------------
829
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::Arc;

    // -- Test state type ---------------------------------------------------

    #[derive(Debug, Clone, Hash, Eq, PartialEq)]
    enum TestState {
        AskName,
        AskAge,
    }

    // -- Helpers -----------------------------------------------------------

    /// Create a simple handler that always matches updates with a message.
    fn always_match_handler() -> Box<dyn Handler> {
        struct AlwaysMatch;
        impl Handler for AlwaysMatch {
            fn check_update(&self, update: &Update) -> Option<MatchResult> {
                if update.message().is_some() {
                    Some(MatchResult::Empty)
                } else {
                    None
                }
            }
            fn handle_update(
                &self,
                _update: Arc<Update>,
                _match_result: MatchResult,
            ) -> Pin<Box<dyn Future<Output = HandlerResult> + Send>> {
                Box::pin(async { HandlerResult::Continue })
            }
        }
        Box::new(AlwaysMatch)
    }

    /// Create a handler that never matches.
    fn never_match_handler() -> Box<dyn Handler> {
        struct NeverMatch;
        impl Handler for NeverMatch {
            fn check_update(&self, _update: &Update) -> Option<MatchResult> {
                None
            }
            fn handle_update(
                &self,
                _update: Arc<Update>,
                _match_result: MatchResult,
            ) -> Pin<Box<dyn Future<Output = HandlerResult> + Send>> {
                Box::pin(async { HandlerResult::Continue })
            }
        }
        Box::new(NeverMatch)
    }

    /// Wrap `handler` in a step whose conversation callback ignores its
    /// inputs and always yields `(HandlerResult::Continue, result)`.
    fn make_step<S: Hash + Eq + Clone + Send + Sync + 'static>(
        handler: Box<dyn Handler>,
        result: ConversationResult<S>,
    ) -> ConversationStepHandler<S> {
        ConversationStepHandler {
            handler,
            conv_callback: Arc::new(move |_u, _m| {
                let r = result.clone();
                Box::pin(async move { (HandlerResult::Continue, r) })
            }),
        }
    }

    /// Build a private-chat message update from `user_id` in `chat_id`.
    fn make_update(chat_id: i64, user_id: i64) -> Update {
        serde_json::from_value(json!({
            "update_id": 1,
            "message": {
                "message_id": 1,
                "date": 0,
                "chat": {"id": chat_id, "type": "private"},
                "from": {"id": user_id, "is_bot": false, "first_name": "Test"}
            }
        }))
        .expect("test update JSON must be valid")
    }

    /// Build a channel-post update (no `from` user).
    fn make_channel_post_update() -> Update {
        serde_json::from_value(json!({
            "update_id": 1,
            "channel_post": {
                "message_id": 1,
                "date": 0,
                "chat": {"id": -100, "type": "channel", "title": "Test"}
            }
        }))
        .expect("test update JSON must be valid")
    }

    // -- Tests -------------------------------------------------------------

    /// Walk the full happy path: entry point -> AskName -> AskAge -> End.
    #[tokio::test]
    async fn state_transition_entry_to_state1_to_state2_to_end() {
        let conv = ConversationHandler::builder()
            .entry_point(make_step(
                always_match_handler(),
                ConversationResult::NextState(TestState::AskName),
            ))
            .state(
                TestState::AskName,
                vec![make_step(
                    always_match_handler(),
                    ConversationResult::NextState(TestState::AskAge),
                )],
            )
            .state(
                TestState::AskAge,
                vec![make_step(always_match_handler(), ConversationResult::End)],
            )
            .build();

        let key = vec![100i64, 42i64];
        let update = Arc::new(make_update(100, 42));

        // Step 1: Entry point should match (no current state)
        assert!(conv.check_update(&update).is_some());

        // Execute the entry point -> transitions to AskName
        conv.handle_update(update.clone(), MatchResult::Empty).await;
        assert_eq!(conv.get_state(&key).await, Some(TestState::AskName));

        // Step 2: AskName handler should match
        assert!(conv.check_update(&update).is_some());
        conv.handle_update(update.clone(), MatchResult::Empty).await;
        assert_eq!(conv.get_state(&key).await, Some(TestState::AskAge));

        // Step 3: AskAge handler should match and end the conversation
        assert!(conv.check_update(&update).is_some());
        conv.handle_update(update.clone(), MatchResult::Empty).await;
        assert_eq!(conv.get_state(&key).await, None);
    }

    /// A conversation left idle past `conversation_timeout` is removed.
    ///
    /// Uses real wall-clock sleeps (50 ms timeout, 120 ms wait), so a heavily
    /// loaded machine could make this flaky.
    #[tokio::test]
    async fn timeout_removes_conversation() {
        let conv = ConversationHandler::builder()
            .entry_point(make_step(
                always_match_handler(),
                ConversationResult::NextState(TestState::AskName),
            ))
            .state(
                TestState::AskName,
                vec![make_step(
                    always_match_handler(),
                    ConversationResult::NextState(TestState::AskAge),
                )],
            )
            .conversation_timeout(Duration::from_millis(50))
            .build();

        let key = vec![100i64, 42i64];
        let update = Arc::new(make_update(100, 42));

        // Enter the conversation
        conv.handle_update(update.clone(), MatchResult::Empty).await;
        assert_eq!(conv.get_state(&key).await, Some(TestState::AskName));

        // Wait for timeout to fire
        tokio::time::sleep(Duration::from_millis(120)).await;

        // Conversation should be removed
        assert_eq!(conv.get_state(&key).await, None);
    }

    /// When no state handler matches, a matching fallback is used instead.
    #[tokio::test]
    async fn fallback_triggers_on_unmatched_input() {
        let conv = ConversationHandler::builder()
            .entry_point(make_step(
                always_match_handler(),
                ConversationResult::NextState(TestState::AskName),
            ))
            .state(
                TestState::AskName,
                vec![make_step(
                    never_match_handler(), // state handler won't match
                    ConversationResult::NextState(TestState::AskAge),
                )],
            )
            .fallback(make_step(always_match_handler(), ConversationResult::End))
            .build();

        let key = vec![100i64, 42i64];
        let update = Arc::new(make_update(100, 42));

        // Enter conversation
        conv.handle_update(update.clone(), MatchResult::Empty).await;
        assert_eq!(conv.get_state(&key).await, Some(TestState::AskName));

        // In AskName state, the state handler won't match (NeverMatch),
        // so the fallback should match and end the conversation
        assert!(conv.check_update(&update).is_some());
        conv.handle_update(update.clone(), MatchResult::Empty).await;
        assert_eq!(conv.get_state(&key).await, None);
    }

    /// M12: channel posts are rejected even when an entry point exists.
    #[test]
    fn channel_post_returns_none() {
        let conv = ConversationHandler::<TestState>::builder()
            .entry_point(make_step(
                always_match_handler(),
                ConversationResult::NextState(TestState::AskName),
            ))
            .build();

        let channel_update = make_channel_post_update();
        assert!(
            conv.check_update(&channel_update).is_none(),
            "Channel posts must be rejected by ConversationHandler"
        );
    }

    /// C6: states loaded via `load_conversations` come back unchanged from
    /// `save_conversations`.
    #[tokio::test]
    async fn persistence_load_save_roundtrip() {
        let conv = ConversationHandler::<TestState>::builder()
            .entry_point(make_step(
                always_match_handler(),
                ConversationResult::NextState(TestState::AskName),
            ))
            .state(
                TestState::AskName,
                vec![make_step(
                    always_match_handler(),
                    ConversationResult::NextState(TestState::AskAge),
                )],
            )
            .name("test_conv".to_string())
            .persistent(true)
            .build();

        // Load pre-existing conversation data
        let mut data = HashMap::new();
        data.insert(vec![1i64, 2i64], TestState::AskAge);
        data.insert(vec![3i64, 4i64], TestState::AskName);
        conv.load_conversations(data).await;

        // Verify loaded state
        assert_eq!(
            conv.get_state(&vec![1i64, 2i64]).await,
            Some(TestState::AskAge)
        );
        assert_eq!(
            conv.get_state(&vec![3i64, 4i64]).await,
            Some(TestState::AskName)
        );

        // Save and verify round-trip: every loaded entry must survive intact,
        // not just one of them.
        let saved = conv.save_conversations().await;
        assert_eq!(saved.len(), 2);
        assert_eq!(saved.get(&vec![1i64, 2i64]), Some(&TestState::AskAge));
        assert_eq!(saved.get(&vec![3i64, 4i64]), Some(&TestState::AskName));
    }

    /// Builder-set name and persistence flag are exposed by the handler.
    #[test]
    fn builder_name_and_persistence() {
        let conv = ConversationHandler::<TestState>::builder()
            .entry_point(make_step(
                always_match_handler(),
                ConversationResult::NextState(TestState::AskName),
            ))
            .name("my_conv".to_string())
            .persistent(true)
            .build();

        assert!(conv.is_persistent());
        assert_eq!(conv.name(), Some("my_conv"));
    }

    /// `build` refuses a configuration with no conversation-key component.
    #[test]
    #[should_panic(expected = "At least one of per_chat, per_user, per_message must be true")]
    fn builder_panics_without_key_components() {
        ConversationHandler::<TestState>::builder()
            .per_chat(false)
            .per_user(false)
            .per_message(false)
            .build();
    }

    /// `build` refuses a persistent handler that has no name.
    #[test]
    #[should_panic(expected = "Conversations can't be persistent when handler is unnamed")]
    fn builder_panics_persistent_without_name() {
        ConversationHandler::<TestState>::builder()
            .persistent(true)
            .build();
    }
}