Skip to main content

rust_tg_bot_ext/
context.rs

1//! Callback context passed to handlers and error handlers.
2//!
3//! Ported from `python-telegram-bot/src/telegram/ext/_callbackcontext.py`.
4
5use std::collections::{HashMap, HashSet};
6use std::sync::Arc;
7
8use serde_json::Value;
9use tokio::sync::RwLock;
10
11use rust_tg_bot_raw::bot::MessageOrBool;
12use rust_tg_bot_raw::types::files::input_file::InputFile;
13use rust_tg_bot_raw::types::update::Update;
14
15use crate::context_types::DefaultData;
16use crate::ext_bot::ExtBot;
17#[cfg(feature = "job-queue")]
18use crate::job_queue::JobQueue;
19
20// ---------------------------------------------------------------------------
21// Typed data guard wrappers
22// ---------------------------------------------------------------------------
23
24/// A typed read guard over a [`DefaultData`] map.
25///
26/// Provides convenience accessors that eliminate manual `get().and_then(|v| v.as_*)` chains
27/// while still exposing the raw `HashMap` via [`raw()`](Self::raw).
28pub struct DataReadGuard<'a> {
29    inner: tokio::sync::RwLockReadGuard<'a, DefaultData>,
30}
31
32impl<'a> DataReadGuard<'a> {
33    /// Get a string value by key.
34    #[must_use]
35    pub fn get_str(&self, key: &str) -> Option<&str> {
36        self.inner.get(key).and_then(|v| v.as_str())
37    }
38
39    /// Get an `i64` value by key.
40    #[must_use]
41    pub fn get_i64(&self, key: &str) -> Option<i64> {
42        self.inner.get(key).and_then(|v| v.as_i64())
43    }
44
45    /// Get a `f64` value by key.
46    #[must_use]
47    pub fn get_f64(&self, key: &str) -> Option<f64> {
48        self.inner.get(key).and_then(|v| v.as_f64())
49    }
50
51    /// Get a `bool` value by key.
52    #[must_use]
53    pub fn get_bool(&self, key: &str) -> Option<bool> {
54        self.inner.get(key).and_then(|v| v.as_bool())
55    }
56
57    /// Get a raw [`Value`] by key.
58    #[must_use]
59    pub fn get(&self, key: &str) -> Option<&Value> {
60        self.inner.get(key)
61    }
62
63    /// Get a set of `i64` IDs stored as a JSON array under `key`.
64    ///
65    /// This is a common pattern for tracking `user_ids`, `chat_ids`, etc.
66    /// Returns an empty set if the key is missing or the value is not an array.
67    #[must_use]
68    pub fn get_id_set(&self, key: &str) -> HashSet<i64> {
69        self.inner
70            .get(key)
71            .and_then(|v| v.as_array())
72            .map(|arr| arr.iter().filter_map(|v| v.as_i64()).collect())
73            .unwrap_or_default()
74    }
75
76    /// Access the raw underlying `HashMap`.
77    #[must_use]
78    pub fn raw(&self) -> &DefaultData {
79        &self.inner
80    }
81
82    /// Returns `true` if the underlying map is empty.
83    #[must_use]
84    pub fn is_empty(&self) -> bool {
85        self.inner.is_empty()
86    }
87
88    /// Returns the number of entries in the underlying map.
89    #[must_use]
90    pub fn len(&self) -> usize {
91        self.inner.len()
92    }
93}
94
95impl std::ops::Deref for DataReadGuard<'_> {
96    type Target = DefaultData;
97
98    fn deref(&self) -> &DefaultData {
99        &self.inner
100    }
101}
102
103/// A typed write guard over a [`DefaultData`] map.
104///
105/// Provides typed setters alongside the raw `HashMap` accessors.
106pub struct DataWriteGuard<'a> {
107    inner: tokio::sync::RwLockWriteGuard<'a, DefaultData>,
108}
109
110impl<'a> DataWriteGuard<'a> {
111    // -- Typed getters (same as DataReadGuard) --------------------------------
112
113    /// Get a string value by key.
114    #[must_use]
115    pub fn get_str(&self, key: &str) -> Option<&str> {
116        self.inner.get(key).and_then(|v| v.as_str())
117    }
118
119    /// Get an `i64` value by key.
120    #[must_use]
121    pub fn get_i64(&self, key: &str) -> Option<i64> {
122        self.inner.get(key).and_then(|v| v.as_i64())
123    }
124
125    /// Get a `f64` value by key.
126    #[must_use]
127    pub fn get_f64(&self, key: &str) -> Option<f64> {
128        self.inner.get(key).and_then(|v| v.as_f64())
129    }
130
131    /// Get a `bool` value by key.
132    #[must_use]
133    pub fn get_bool(&self, key: &str) -> Option<bool> {
134        self.inner.get(key).and_then(|v| v.as_bool())
135    }
136
137    /// Get a raw [`Value`] by key.
138    #[must_use]
139    pub fn get(&self, key: &str) -> Option<&Value> {
140        self.inner.get(key)
141    }
142
143    /// Get a set of `i64` IDs stored as a JSON array under `key`.
144    #[must_use]
145    pub fn get_id_set(&self, key: &str) -> HashSet<i64> {
146        self.inner
147            .get(key)
148            .and_then(|v| v.as_array())
149            .map(|arr| arr.iter().filter_map(|v| v.as_i64()).collect())
150            .unwrap_or_default()
151    }
152
153    // -- Typed setters --------------------------------------------------------
154
155    /// Set a string value.
156    pub fn set_str(&mut self, key: impl Into<String>, value: impl Into<String>) {
157        self.inner.insert(key.into(), Value::String(value.into()));
158    }
159
160    /// Set an `i64` value.
161    pub fn set_i64(&mut self, key: impl Into<String>, value: i64) {
162        self.inner.insert(key.into(), Value::Number(value.into()));
163    }
164
165    /// Set a `bool` value.
166    pub fn set_bool(&mut self, key: impl Into<String>, value: bool) {
167        self.inner.insert(key.into(), Value::Bool(value));
168    }
169
170    /// Insert a raw [`Value`].
171    pub fn insert(&mut self, key: String, value: Value) -> Option<Value> {
172        self.inner.insert(key, value)
173    }
174
175    /// Add an `i64` to a set stored as a JSON array under `key`.
176    ///
177    /// Creates the array if the key does not exist. Deduplicates values.
178    pub fn add_to_id_set(&mut self, key: &str, id: i64) {
179        let entry = self
180            .inner
181            .entry(key.to_owned())
182            .or_insert_with(|| Value::Array(vec![]));
183        if let Some(arr) = entry.as_array_mut() {
184            let val = Value::Number(id.into());
185            if !arr.contains(&val) {
186                arr.push(val);
187            }
188        }
189    }
190
191    /// Remove an `i64` from a set stored as a JSON array under `key`.
192    pub fn remove_from_id_set(&mut self, key: &str, id: i64) {
193        if let Some(arr) = self.inner.get_mut(key).and_then(|v| v.as_array_mut()) {
194            arr.retain(|v| v.as_i64() != Some(id));
195        }
196    }
197
198    /// Access the raw underlying `HashMap`.
199    #[must_use]
200    pub fn raw(&self) -> &DefaultData {
201        &self.inner
202    }
203
204    /// Access the raw underlying `HashMap` mutably.
205    pub fn raw_mut(&mut self) -> &mut DefaultData {
206        &mut self.inner
207    }
208
209    /// Access the `Entry` API of the underlying `HashMap`.
210    pub fn entry(&mut self, key: String) -> std::collections::hash_map::Entry<'_, String, Value> {
211        self.inner.entry(key)
212    }
213
214    /// Get a mutable reference to a value by key.
215    pub fn get_mut(&mut self, key: &str) -> Option<&mut Value> {
216        self.inner.get_mut(key)
217    }
218
219    /// Returns `true` if the underlying map is empty.
220    #[must_use]
221    pub fn is_empty(&self) -> bool {
222        self.inner.is_empty()
223    }
224
225    /// Returns the number of entries in the underlying map.
226    #[must_use]
227    pub fn len(&self) -> usize {
228        self.inner.len()
229    }
230
231    /// Remove a key from the underlying map.
232    pub fn remove(&mut self, key: &str) -> Option<Value> {
233        self.inner.remove(key)
234    }
235}
236
237impl std::ops::Deref for DataWriteGuard<'_> {
238    type Target = DefaultData;
239
240    fn deref(&self) -> &DefaultData {
241        &self.inner
242    }
243}
244
245impl std::ops::DerefMut for DataWriteGuard<'_> {
246    fn deref_mut(&mut self) -> &mut DefaultData {
247        &mut self.inner
248    }
249}
250
251// ---------------------------------------------------------------------------
252// CallbackContext
253// ---------------------------------------------------------------------------
254
255/// A context object passed to handler callbacks.
256#[derive(Debug, Clone)]
257pub struct CallbackContext {
258    /// The bot associated with this context.
259    bot: Arc<ExtBot>,
260
261    /// The chat id associated with this context (used to look up `chat_data`).
262    chat_id: Option<i64>,
263
264    /// The user id associated with this context (used to look up `user_data`).
265    user_id: Option<i64>,
266
267    // -- Shared data references (populated by Application) --------------------
268    /// Reference into the application's per-user data store.
269    user_data_store: Arc<RwLock<HashMap<i64, DefaultData>>>,
270
271    /// Reference into the application's per-chat data store.
272    chat_data_store: Arc<RwLock<HashMap<i64, DefaultData>>>,
273
274    /// Reference to the application's bot-wide data.
275    bot_data: Arc<RwLock<DefaultData>>,
276
277    // -- Per-callback mutable state -------------------------------------------
278    /// Positional regex match results (populated by regex-based handlers).
279    pub matches: Option<Vec<String>>,
280
281    /// Named regex capture groups (populated by regex-based handlers when the
282    /// pattern contains at least one named group).
283    ///
284    /// Mirrors Python's `context.matches` which exposes the full `re.Match`
285    /// object including `match.groupdict()`.
286    pub named_matches: Option<HashMap<String, String>>,
287
288    /// Arguments to a command (populated by `CommandHandler`).
289    pub args: Option<Vec<String>>,
290
291    /// The error that was raised.  Only present in error handler contexts.
292    pub error: Option<Arc<dyn std::error::Error + Send + Sync>>,
293
294    /// Extra key-value pairs that handlers can attach for downstream handlers.
295    /// Lazy: `None` until first insertion to avoid clone overhead during dispatch.
296    extra: Option<HashMap<String, Value>>,
297
298    /// Optional reference to the application's job queue.
299    ///
300    /// Requires the `job-queue` feature.
301    #[cfg(feature = "job-queue")]
302    pub job_queue: Option<Arc<JobQueue>>,
303}
304
305impl CallbackContext {
306    /// Creates a new `CallbackContext`.
307    #[must_use]
308    pub fn new(
309        bot: Arc<ExtBot>,
310        chat_id: Option<i64>,
311        user_id: Option<i64>,
312        user_data_store: Arc<RwLock<HashMap<i64, DefaultData>>>,
313        chat_data_store: Arc<RwLock<HashMap<i64, DefaultData>>>,
314        bot_data: Arc<RwLock<DefaultData>>,
315    ) -> Self {
316        Self {
317            bot,
318            chat_id,
319            user_id,
320            user_data_store,
321            chat_data_store,
322            bot_data,
323            matches: None,
324            named_matches: None,
325            args: None,
326            error: None,
327            extra: None,
328            #[cfg(feature = "job-queue")]
329            job_queue: None,
330        }
331    }
332
333    // -- Factory methods (mirrors Python classmethod constructors) -------------
334
335    /// Constructs a context from a typed [`Update`].
336    #[must_use]
337    pub fn from_update(
338        update: &Update,
339        bot: Arc<ExtBot>,
340        user_data_store: Arc<RwLock<HashMap<i64, DefaultData>>>,
341        chat_data_store: Arc<RwLock<HashMap<i64, DefaultData>>>,
342        bot_data: Arc<RwLock<DefaultData>>,
343    ) -> Self {
344        let (chat_id, user_id) = extract_ids(update);
345        Self::new(
346            bot,
347            chat_id,
348            user_id,
349            user_data_store,
350            chat_data_store,
351            bot_data,
352        )
353    }
354
355    /// Constructs a context for an error handler.
356    #[must_use]
357    pub fn from_error(
358        update: Option<&Update>,
359        error: Arc<dyn std::error::Error + Send + Sync>,
360        bot: Arc<ExtBot>,
361        user_data_store: Arc<RwLock<HashMap<i64, DefaultData>>>,
362        chat_data_store: Arc<RwLock<HashMap<i64, DefaultData>>>,
363        bot_data: Arc<RwLock<DefaultData>>,
364    ) -> Self {
365        let (chat_id, user_id) = update.map_or((None, None), extract_ids);
366        let mut ctx = Self::new(
367            bot,
368            chat_id,
369            user_id,
370            user_data_store,
371            chat_data_store,
372            bot_data,
373        );
374        ctx.error = Some(error);
375        ctx
376    }
377
378    // -- Accessors ------------------------------------------------------------
379
380    // -- Accessors ------------------------------------------------------------
381
382    /// Returns a reference to the bot associated with this context.
383    #[must_use]
384    pub fn bot(&self) -> &Arc<ExtBot> {
385        &self.bot
386    }
387
388    /// Returns the chat ID extracted from the update, if available.
389    #[must_use]
390    pub fn chat_id(&self) -> Option<i64> {
391        self.chat_id
392    }
393
394    /// Returns the user ID extracted from the update, if available.
395    #[must_use]
396    pub fn user_id(&self) -> Option<i64> {
397        self.user_id
398    }
399
400    // -- Typed bot_data accessors ---------------------------------------------
401
402    /// Acquire a read lock on the bot-wide data store, returning a typed guard.
403    pub async fn bot_data(&self) -> DataReadGuard<'_> {
404        DataReadGuard {
405            inner: self.bot_data.read().await,
406        }
407    }
408
409    /// Acquire a write lock on the bot-wide data store, returning a typed guard.
410    pub async fn bot_data_mut(&self) -> DataWriteGuard<'_> {
411        DataWriteGuard {
412            inner: self.bot_data.write().await,
413        }
414    }
415
416    // -- user_data / chat_data (unchanged API, returns cloned snapshot) --------
417
418    /// Returns a cloned snapshot of the current user's data, if a user ID is set.
419    pub async fn user_data(&self) -> Option<DefaultData> {
420        let uid = self.user_id?;
421        let store = self.user_data_store.read().await;
422        store.get(&uid).cloned()
423    }
424
425    /// Returns a cloned snapshot of the current chat's data, if a chat ID is set.
426    pub async fn chat_data(&self) -> Option<DefaultData> {
427        let cid = self.chat_id?;
428        let store = self.chat_data_store.read().await;
429        store.get(&cid).cloned()
430    }
431
432    /// Insert a key-value pair into the current user's data store. Returns `false` if no user ID.
433    pub async fn set_user_data(&self, key: String, value: Value) -> bool {
434        let uid = match self.user_id {
435            Some(id) => id,
436            None => return false,
437        };
438        let mut store = self.user_data_store.write().await;
439        store
440            .entry(uid)
441            .or_insert_with(HashMap::new)
442            .insert(key, value);
443        true
444    }
445
446    /// Insert a key-value pair into the current chat's data store. Returns `false` if no chat ID.
447    pub async fn set_chat_data(&self, key: String, value: Value) -> bool {
448        let cid = match self.chat_id {
449            Some(id) => id,
450            None => return false,
451        };
452        let mut store = self.chat_data_store.write().await;
453        store
454            .entry(cid)
455            .or_insert_with(HashMap::new)
456            .insert(key, value);
457        true
458    }
459
460    /// Returns the first positional regex match, if available.
461    #[must_use]
462    pub fn match_result(&self) -> Option<&str> {
463        self.matches
464            .as_ref()
465            .and_then(|m| m.first().map(String::as_str))
466    }
467
468    /// Returns a reference to the extra data map, if any data has been set.
469    #[must_use]
470    pub fn extra(&self) -> Option<&HashMap<String, Value>> {
471        self.extra.as_ref()
472    }
473
474    /// Returns a mutable reference to the extra data map, creating it if needed.
475    pub fn extra_mut(&mut self) -> &mut HashMap<String, Value> {
476        self.extra.get_or_insert_with(HashMap::new)
477    }
478
479    /// Insert a key-value pair into the extra data map.
480    pub fn set_extra(&mut self, key: String, value: Value) {
481        self.extra
482            .get_or_insert_with(HashMap::new)
483            .insert(key, value);
484    }
485
486    /// Get a value from the extra data map by key.
487    #[must_use]
488    pub fn get_extra(&self, key: &str) -> Option<&Value> {
489        self.extra.as_ref().and_then(|m| m.get(key))
490    }
491
492    /// Drop the cached callback data for a given callback query ID.
493    pub async fn drop_callback_data(
494        &self,
495        callback_query_id: &str,
496    ) -> Result<(), crate::callback_data_cache::InvalidCallbackData> {
497        let cache = self.bot.callback_data_cache().ok_or(
498            crate::callback_data_cache::InvalidCallbackData {
499                callback_data: None,
500            },
501        )?;
502        let mut guard = cache.write().await;
503        guard.drop_data(callback_query_id)
504    }
505
506    /// Set the job queue reference on this context.
507    ///
508    /// Requires the `job-queue` feature.
509    #[cfg(feature = "job-queue")]
510    pub fn with_job_queue(mut self, jq: Arc<JobQueue>) -> Self {
511        self.job_queue = Some(jq);
512        self
513    }
514
515    // -- Convenience methods (mirrors python-telegram-bot patterns) -----------
516
517    /// Send a text reply to the chat associated with the given update.
518    ///
519    /// This is a convenience method that mirrors python-telegram-bot's
520    /// `update.message.reply_text(text)` / `context.bot.send_message(...)`.
521    ///
522    /// # Errors
523    ///
524    /// Returns `TelegramError` if the chat cannot be determined from the
525    /// update or if the Telegram API call fails.
526    pub async fn reply_text(
527        &self,
528        update: &Update,
529        text: &str,
530    ) -> Result<rust_tg_bot_raw::types::message::Message, rust_tg_bot_raw::error::TelegramError>
531    {
532        let chat_id = update.effective_chat().map(|c| c.id).ok_or_else(|| {
533            rust_tg_bot_raw::error::TelegramError::Network("No chat in update".into())
534        })?;
535        self.bot().send_message(chat_id, text).await
536    }
537
538    /// Send an HTML-formatted text reply to the chat associated with the given update.
539    ///
540    /// Equivalent to `reply_text` with `parse_mode("HTML")`.
541    ///
542    /// # Errors
543    ///
544    /// Returns `TelegramError` if the chat cannot be determined from the
545    /// update or if the Telegram API call fails.
546    pub async fn reply_html(
547        &self,
548        update: &Update,
549        text: &str,
550    ) -> Result<rust_tg_bot_raw::types::message::Message, rust_tg_bot_raw::error::TelegramError>
551    {
552        let chat_id = update.effective_chat().map(|c| c.id).ok_or_else(|| {
553            rust_tg_bot_raw::error::TelegramError::Network("No chat in update".into())
554        })?;
555        self.bot()
556            .send_message(chat_id, text)
557            .parse_mode("HTML")
558            .await
559    }
560
561    /// Send a MarkdownV2-formatted text reply to the chat associated with the given update.
562    ///
563    /// Equivalent to `reply_text` with `parse_mode("MarkdownV2")`.
564    ///
565    /// # Errors
566    ///
567    /// Returns `TelegramError` if the chat cannot be determined from the
568    /// update or if the Telegram API call fails.
569    pub async fn reply_markdown_v2(
570        &self,
571        update: &Update,
572        text: &str,
573    ) -> Result<rust_tg_bot_raw::types::message::Message, rust_tg_bot_raw::error::TelegramError>
574    {
575        let chat_id = update.effective_chat().map(|c| c.id).ok_or_else(|| {
576            rust_tg_bot_raw::error::TelegramError::Network("No chat in update".into())
577        })?;
578        self.bot()
579            .send_message(chat_id, text)
580            .parse_mode("MarkdownV2")
581            .await
582    }
583
584    /// Send a photo reply to the chat associated with the given update.
585    ///
586    /// # Errors
587    ///
588    /// Returns `TelegramError` if the chat cannot be determined from the
589    /// update or if the Telegram API call fails.
590    pub async fn reply_photo(
591        &self,
592        update: &Update,
593        photo: InputFile,
594    ) -> Result<rust_tg_bot_raw::types::message::Message, rust_tg_bot_raw::error::TelegramError>
595    {
596        let chat_id = update.effective_chat().map(|c| c.id).ok_or_else(|| {
597            rust_tg_bot_raw::error::TelegramError::Network("No chat in update".into())
598        })?;
599        self.bot().send_photo(chat_id, photo).await
600    }
601
602    /// Send a document reply to the chat associated with the given update.
603    ///
604    /// # Errors
605    ///
606    /// Returns `TelegramError` if the chat cannot be determined from the
607    /// update or if the Telegram API call fails.
608    pub async fn reply_document(
609        &self,
610        update: &Update,
611        document: InputFile,
612    ) -> Result<rust_tg_bot_raw::types::message::Message, rust_tg_bot_raw::error::TelegramError>
613    {
614        let chat_id = update.effective_chat().map(|c| c.id).ok_or_else(|| {
615            rust_tg_bot_raw::error::TelegramError::Network("No chat in update".into())
616        })?;
617        self.bot().send_document(chat_id, document).await
618    }
619
620    /// Send a sticker reply to the chat associated with the given update.
621    ///
622    /// # Errors
623    ///
624    /// Returns `TelegramError` if the chat cannot be determined from the
625    /// update or if the Telegram API call fails.
626    pub async fn reply_sticker(
627        &self,
628        update: &Update,
629        sticker: InputFile,
630    ) -> Result<rust_tg_bot_raw::types::message::Message, rust_tg_bot_raw::error::TelegramError>
631    {
632        let chat_id = update.effective_chat().map(|c| c.id).ok_or_else(|| {
633            rust_tg_bot_raw::error::TelegramError::Network("No chat in update".into())
634        })?;
635        self.bot().send_sticker(chat_id, sticker).await
636    }
637
638    /// Send a location reply to the chat associated with the given update.
639    ///
640    /// # Errors
641    ///
642    /// Returns `TelegramError` if the chat cannot be determined from the
643    /// update or if the Telegram API call fails.
644    pub async fn reply_location(
645        &self,
646        update: &Update,
647        latitude: f64,
648        longitude: f64,
649    ) -> Result<rust_tg_bot_raw::types::message::Message, rust_tg_bot_raw::error::TelegramError>
650    {
651        let chat_id = update.effective_chat().map(|c| c.id).ok_or_else(|| {
652            rust_tg_bot_raw::error::TelegramError::Network("No chat in update".into())
653        })?;
654        self.bot().send_location(chat_id, latitude, longitude).await
655    }
656
657    /// Answer the callback query from the given update.
658    ///
659    /// Automatically extracts the callback query ID from the update. This is
660    /// a convenience shortcut that eliminates the common boilerplate of
661    /// extracting `update.callback_query.id` manually.
662    ///
663    /// # Errors
664    ///
665    /// Returns `TelegramError` if the update does not contain a callback query
666    /// or if the Telegram API call fails.
667    pub async fn answer_callback_query(
668        &self,
669        update: &Update,
670    ) -> Result<bool, rust_tg_bot_raw::error::TelegramError> {
671        let cq = update.callback_query().ok_or_else(|| {
672            rust_tg_bot_raw::error::TelegramError::Network("No callback query in update".into())
673        })?;
674        self.bot().answer_callback_query(&cq.id).await
675    }
676
677    /// Edit the text of the message that originated the callback query.
678    ///
679    /// Automatically determines whether to use `chat_id + message_id` or
680    /// `inline_message_id` based on the callback query contents.
681    ///
682    /// # Errors
683    ///
684    /// Returns `TelegramError` if the update does not contain a callback query,
685    /// the callback query has no associated message, or the Telegram API call fails.
686    pub async fn edit_callback_message_text(
687        &self,
688        update: &Update,
689        text: &str,
690    ) -> Result<MessageOrBool, rust_tg_bot_raw::error::TelegramError> {
691        let cq = update.callback_query().ok_or_else(|| {
692            rust_tg_bot_raw::error::TelegramError::Network("No callback query in update".into())
693        })?;
694
695        if let Some(msg) = cq.message.as_deref() {
696            self.bot()
697                .edit_message_text(text)
698                .chat_id(msg.chat().id)
699                .message_id(msg.message_id())
700                .await
701        } else if let Some(ref iid) = cq.inline_message_id {
702            self.bot()
703                .edit_message_text(text)
704                .inline_message_id(iid)
705                .await
706        } else {
707            Err(rust_tg_bot_raw::error::TelegramError::Network(
708                "No message in callback query".into(),
709            ))
710        }
711    }
712}
713
714// ---------------------------------------------------------------------------
715// Helpers
716// ---------------------------------------------------------------------------
717
718/// Extract chat and user IDs from a typed [`Update`] using its computed
719/// properties. This is vastly cleaner than the previous Value-based approach.
720fn extract_ids(update: &Update) -> (Option<i64>, Option<i64>) {
721    let chat_id = update.effective_chat().map(|c| c.id);
722    let user_id = update.effective_user().map(|u| u.id);
723    (chat_id, user_id)
724}
725
726#[cfg(test)]
727mod tests {
728    use super::*;
729    use crate::ext_bot::test_support::mock_request;
730    use rust_tg_bot_raw::bot::Bot;
731
732    fn make_bot() -> Arc<ExtBot> {
733        let bot = Bot::new("test", mock_request());
734        Arc::new(ExtBot::from_bot(bot))
735    }
736
737    fn make_stores() -> (
738        Arc<RwLock<HashMap<i64, DefaultData>>>,
739        Arc<RwLock<HashMap<i64, DefaultData>>>,
740        Arc<RwLock<DefaultData>>,
741    ) {
742        (
743            Arc::new(RwLock::new(HashMap::new())),
744            Arc::new(RwLock::new(HashMap::new())),
745            Arc::new(RwLock::new(HashMap::new())),
746        )
747    }
748
749    fn make_update(json_val: serde_json::Value) -> Update {
750        serde_json::from_value(json_val).unwrap()
751    }
752
753    #[test]
754    fn context_basic_creation() {
755        let bot = make_bot();
756        let (ud, cd, bd) = make_stores();
757        let ctx = CallbackContext::new(bot.clone(), Some(42), Some(7), ud, cd, bd);
758        assert_eq!(ctx.chat_id(), Some(42));
759        assert_eq!(ctx.user_id(), Some(7));
760        assert!(ctx.error.is_none());
761        assert!(ctx.args.is_none());
762        assert!(ctx.matches.is_none());
763        assert!(ctx.named_matches.is_none());
764        #[cfg(feature = "job-queue")]
765        assert!(ctx.job_queue.is_none());
766    }
767
768    #[test]
769    fn extract_ids_from_message_update() {
770        let update = make_update(
771            serde_json::json!({"update_id": 1, "message": {"message_id": 1, "date": 0, "chat": {"id": 100, "type": "private"}, "from": {"id": 200, "is_bot": false, "first_name": "Test"}}}),
772        );
773        let (chat_id, user_id) = extract_ids(&update);
774        assert_eq!(chat_id, Some(100));
775        assert_eq!(user_id, Some(200));
776    }
777
778    #[test]
779    fn extract_ids_from_callback_query() {
780        let update = make_update(
781            serde_json::json!({"update_id": 2, "callback_query": {"id": "abc", "from": {"id": 300, "is_bot": false, "first_name": "U"}, "chat_instance": "ci", "message": {"message_id": 5, "date": 0, "chat": {"id": 400, "type": "group"}}}}),
782        );
783        let (chat_id, user_id) = extract_ids(&update);
784        assert_eq!(chat_id, Some(400));
785        assert_eq!(user_id, Some(300));
786    }
787
788    #[test]
789    fn extract_ids_returns_none_for_empty() {
790        let update = make_update(serde_json::json!({"update_id": 3}));
791        let (chat_id, user_id) = extract_ids(&update);
792        assert!(chat_id.is_none());
793        assert!(user_id.is_none());
794    }
795
796    #[test]
797    fn from_update_factory() {
798        let bot = make_bot();
799        let (ud, cd, bd) = make_stores();
800        let update = make_update(
801            serde_json::json!({"update_id": 1, "message": {"message_id": 1, "date": 0, "chat": {"id": 10, "type": "private"}, "from": {"id": 20, "is_bot": false, "first_name": "T"}}}),
802        );
803        let ctx = CallbackContext::from_update(&update, bot, ud, cd, bd);
804        assert_eq!(ctx.chat_id(), Some(10));
805        assert_eq!(ctx.user_id(), Some(20));
806    }
807
808    #[test]
809    fn from_error_factory() {
810        let bot = make_bot();
811        let (ud, cd, bd) = make_stores();
812        let err: Arc<dyn std::error::Error + Send + Sync> =
813            Arc::new(std::io::Error::new(std::io::ErrorKind::Other, "boom"));
814        let ctx = CallbackContext::from_error(None, err, bot, ud, cd, bd);
815        assert!(ctx.error.is_some());
816        assert!(ctx.chat_id().is_none());
817    }
818
819    #[tokio::test]
820    async fn bot_data_access() {
821        let bot = make_bot();
822        let (ud, cd, bd) = make_stores();
823        let ctx = CallbackContext::new(bot, None, None, ud, cd, bd);
824        {
825            let mut guard = ctx.bot_data_mut().await;
826            guard.insert("key".into(), Value::String("val".into()));
827        }
828        let guard = ctx.bot_data().await;
829        assert_eq!(guard.get("key"), Some(&Value::String("val".into())));
830    }
831
832    #[tokio::test]
833    async fn user_data_returns_none_without_user_id() {
834        let bot = make_bot();
835        let (ud, cd, bd) = make_stores();
836        let ctx = CallbackContext::new(bot, None, None, ud, cd, bd);
837        assert!(ctx.user_data().await.is_none());
838    }
839
840    #[tokio::test]
841    async fn chat_data_returns_none_without_chat_id() {
842        let bot = make_bot();
843        let (ud, cd, bd) = make_stores();
844        let ctx = CallbackContext::new(bot, None, None, ud, cd, bd);
845        assert!(ctx.chat_data().await.is_none());
846    }
847
848    #[tokio::test]
849    async fn set_user_data_works() {
850        let bot = make_bot();
851        let (ud, cd, bd) = make_stores();
852        let ctx = CallbackContext::new(bot, None, Some(42), ud.clone(), cd, bd);
853        assert!(
854            ctx.set_user_data("score".into(), Value::Number(100.into()))
855                .await
856        );
857        let store = ud.read().await;
858        assert_eq!(
859            store.get(&42).unwrap().get("score"),
860            Some(&Value::Number(100.into()))
861        );
862    }
863
864    #[tokio::test]
865    async fn set_chat_data_works() {
866        let bot = make_bot();
867        let (ud, cd, bd) = make_stores();
868        let ctx = CallbackContext::new(bot, Some(10), None, ud, cd.clone(), bd);
869        assert!(
870            ctx.set_chat_data("topic".into(), Value::String("rust".into()))
871                .await
872        );
873        let store = cd.read().await;
874        assert_eq!(
875            store.get(&10).unwrap().get("topic"),
876            Some(&Value::String("rust".into()))
877        );
878    }
879
880    #[tokio::test]
881    async fn set_user_data_returns_false_without_user_id() {
882        let bot = make_bot();
883        let (ud, cd, bd) = make_stores();
884        let ctx = CallbackContext::new(bot, None, None, ud, cd, bd);
885        assert!(!ctx.set_user_data("k".into(), Value::Null).await);
886    }
887
888    #[test]
889    fn match_result_shortcut() {
890        let bot = make_bot();
891        let (ud, cd, bd) = make_stores();
892        let mut ctx = CallbackContext::new(bot, None, None, ud, cd, bd);
893        assert!(ctx.match_result().is_none());
894        ctx.matches = Some(vec!["hello".into(), "world".into()]);
895        assert_eq!(ctx.match_result(), Some("hello"));
896    }
897
898    #[test]
899    fn extra_is_lazily_initialized() {
900        let bot = make_bot();
901        let (ud, cd, bd) = make_stores();
902        let mut ctx = CallbackContext::new(bot, None, None, ud, cd, bd);
903
904        assert!(ctx.extra().is_none());
905        assert!(ctx.get_extra("missing").is_none());
906
907        ctx.extra_mut()
908            .insert("count".into(), Value::Number(1.into()));
909        assert_eq!(ctx.get_extra("count"), Some(&Value::Number(1.into())));
910
911        ctx.set_extra("name".into(), Value::String("Alice".into()));
912        assert_eq!(
913            ctx.extra().and_then(|extra| extra.get("name")),
914            Some(&Value::String("Alice".into()))
915        );
916    }
917
918    #[cfg(feature = "job-queue")]
919    #[test]
920    fn with_job_queue() {
921        let bot = make_bot();
922        let (ud, cd, bd) = make_stores();
923        let ctx = CallbackContext::new(bot, None, None, ud, cd, bd);
924        let jq = Arc::new(JobQueue::new());
925        let ctx = ctx.with_job_queue(jq.clone());
926        assert!(ctx.job_queue.is_some());
927    }
928
929    // -- Typed guard tests ----------------------------------------------------
930
931    #[tokio::test]
932    async fn data_write_guard_typed_setters() {
933        let bot = make_bot();
934        let (ud, cd, bd) = make_stores();
935        let ctx = CallbackContext::new(bot, None, None, ud, cd, bd);
936
937        {
938            let mut guard = ctx.bot_data_mut().await;
939            guard.set_str("name", "Alice");
940            guard.set_i64("score", 42);
941            guard.set_bool("active", true);
942        }
943
944        let guard = ctx.bot_data().await;
945        assert_eq!(guard.get_str("name"), Some("Alice"));
946        assert_eq!(guard.get_i64("score"), Some(42));
947        assert_eq!(guard.get_bool("active"), Some(true));
948    }
949
950    #[tokio::test]
951    async fn data_write_guard_id_set_operations() {
952        let bot = make_bot();
953        let (ud, cd, bd) = make_stores();
954        let ctx = CallbackContext::new(bot, None, None, ud, cd, bd);
955
956        {
957            let mut guard = ctx.bot_data_mut().await;
958            guard.add_to_id_set("user_ids", 100);
959            guard.add_to_id_set("user_ids", 200);
960            guard.add_to_id_set("user_ids", 100); // duplicate -- should not add
961        }
962
963        let guard = ctx.bot_data().await;
964        let ids = guard.get_id_set("user_ids");
965        assert_eq!(ids.len(), 2);
966        assert!(ids.contains(&100));
967        assert!(ids.contains(&200));
968
969        drop(guard);
970
971        {
972            let mut guard = ctx.bot_data_mut().await;
973            guard.remove_from_id_set("user_ids", 100);
974        }
975
976        let guard = ctx.bot_data().await;
977        let ids = guard.get_id_set("user_ids");
978        assert_eq!(ids.len(), 1);
979        assert!(ids.contains(&200));
980    }
981
982    #[tokio::test]
983    async fn data_read_guard_empty_id_set() {
984        let bot = make_bot();
985        let (ud, cd, bd) = make_stores();
986        let ctx = CallbackContext::new(bot, None, None, ud, cd, bd);
987
988        let guard = ctx.bot_data().await;
989        let ids = guard.get_id_set("nonexistent");
990        assert!(ids.is_empty());
991    }
992
993    #[tokio::test]
994    async fn data_guard_deref_to_hashmap() {
995        let bot = make_bot();
996        let (ud, cd, bd) = make_stores();
997        let ctx = CallbackContext::new(bot, None, None, ud, cd, bd);
998
999        {
1000            let mut guard = ctx.bot_data_mut().await;
1001            guard.set_str("key", "val");
1002        }
1003
1004        let guard = ctx.bot_data().await;
1005        // Use Deref to access HashMap methods directly
1006        assert!(guard.contains_key("key"));
1007        assert_eq!(guard.len(), 1);
1008    }
1009}