Skip to main content

rust_tg_bot_ext/
callback_data_cache.rs

//! Arbitrary callback data cache.
//!
//! Ported from `python-telegram-bot/src/telegram/ext/_callbackdatacache.py`.
//!
//! Stores the callback data of inline keyboard buttons in a local cache and replaces it
//! with generated identifiers, so that the data sent to Telegram stays short while the
//! cached payloads (held as `serde_json::Value`) are restored when a callback query
//! arrives.  Uses a simple LRU eviction strategy bounded by `maxsize`.
8
9use std::collections::{HashMap, VecDeque};
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::time::{SystemTime, UNIX_EPOCH};
12
13use serde_json::Value;
14
15use rust_tg_bot_raw::types::callback_query::CallbackQuery;
16use rust_tg_bot_raw::types::inline::inline_keyboard_button::InlineKeyboardButton;
17use rust_tg_bot_raw::types::inline::inline_keyboard_markup::InlineKeyboardMarkup;
18
19// ---------------------------------------------------------------------------
20// UUID generation (SystemTime + atomic counter -- no external crate)
21// ---------------------------------------------------------------------------
22
static COUNTER: AtomicU64 = AtomicU64::new(0);

/// Generates a unique identifier of exactly 32 lowercase hex characters.
///
/// The first 16 hex digits encode the low 64 bits of the current UNIX time in nanoseconds.
/// The last 16 encode a process-wide counter that is incremented atomically on every call.
/// Because the counter half differs for every call (until it wraps after 2^64 calls), two
/// ids generated by the same process are always distinct.
///
/// The length is fixed because the timestamp is deliberately truncated to 64 bits, so
/// `{:016x}` can never emit more than 16 digits. `CallbackDataCache::extract_uuids`
/// splits processed callback data at a fixed 32-character boundary and relies on this.
///
/// NOTE(review): these ids are predictable (time + counter), unlike Python's random
/// `uuid4`. They must not be treated as unguessable secrets.
fn generate_uuid() -> String {
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    // Intentional truncation: `as_nanos` yields a `u128`, which would need more than 16 hex
    // digits after roughly the year 2554. Keeping the low 64 bits pins the field width.
    let ts = nanos as u64;
    // Atomicity alone guarantees a distinct value per call; no ordering with other memory
    // operations is needed, so `Relaxed` suffices.
    let seq = COUNTER.fetch_add(1, Ordering::Relaxed);
    format!("{ts:016x}{seq:016x}")
}
34
35// ---------------------------------------------------------------------------
36// Error type
37// ---------------------------------------------------------------------------
38
39/// Raised when the received callback data has been tampered with or deleted from cache.
40#[derive(Debug, Clone, thiserror::Error)]
41#[error(
42    "The object belonging to this callback_data was deleted or the callback_data was manipulated."
43)]
44pub struct InvalidCallbackData {
45    /// The raw callback data string that could not be resolved.
46    pub callback_data: Option<String>,
47}
48
49// ---------------------------------------------------------------------------
50// Internal keyboard metadata
51// ---------------------------------------------------------------------------
52
53#[derive(Debug, Clone)]
54struct KeyboardData {
55    keyboard_uuid: String,
56    access_time: f64,
57    /// Maps button uuid -> arbitrary data stored as `Value`.
58    button_data: HashMap<String, Value>,
59}
60
61impl KeyboardData {
62    fn new(keyboard_uuid: String) -> Self {
63        Self {
64            keyboard_uuid,
65            access_time: now_f64(),
66            button_data: HashMap::new(),
67        }
68    }
69
70    fn update_access_time(&mut self) {
71        self.access_time = now_f64();
72    }
73
74    fn to_tuple(&self) -> (String, f64, HashMap<String, Value>) {
75        (
76            self.keyboard_uuid.clone(),
77            self.access_time,
78            self.button_data.clone(),
79        )
80    }
81}
82
/// Current UNIX time in seconds as an `f64`, or `0.0` if the system clock reads before
/// the epoch.
fn now_f64() -> f64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs_f64(),
        Err(_) => 0.0,
    }
}
89
90// ---------------------------------------------------------------------------
91// LRU map (simple insertion-order VecDeque + HashMap)
92// ---------------------------------------------------------------------------
93
/// A minimal bounded LRU map keyed by `String`.
///
/// `map` holds the entries. `order` holds the same keys, ordered from least- to
/// most-recently used. The two always contain exactly the same set of keys.
///
/// Touching an entry (`get_mut`, or `insert` of an existing key) costs O(n), because the
/// key has to be located in `order`. That is acceptable for the small, bounded sizes this
/// map is used with. Iteration (`values`, `iter`) yields entries in recency order, oldest
/// first. Replaying such a snapshot through `insert` therefore reproduces the same
/// eviction order.
///
/// NOTE(review): a `maxsize` of 0 behaves like 1, because the newest insert is always kept.
#[derive(Debug, Clone)]
struct LruMap<V> {
    map: HashMap<String, V>,
    order: VecDeque<String>,
    maxsize: usize,
}

impl<V> LruMap<V> {
    /// Creates an empty map holding at most `maxsize` entries.
    fn new(maxsize: usize) -> Self {
        // No pre-allocation: the HashMap and VecDeque grow lazily on first insert. This
        // avoids upfront heap usage for caches that stay small or empty.
        Self {
            map: HashMap::new(),
            order: VecDeque::new(),
            maxsize,
        }
    }

    /// Moves `key` to the most-recently-used end of `order` if it is present there.
    fn touch(&mut self, key: &str) {
        if let Some(pos) = self.order.iter().position(|k| k == key) {
            // Reuse the existing `String` instead of allocating a new one.
            if let Some(existing) = self.order.remove(pos) {
                self.order.push_back(existing);
            }
        }
    }

    /// Returns a mutable reference to the value for `key` and marks it most-recently used.
    fn get_mut(&mut self, key: &str) -> Option<&mut V> {
        if !self.map.contains_key(key) {
            return None;
        }
        self.touch(key);
        self.map.get_mut(key)
    }

    /// Inserts or replaces `key` and marks it most-recently used.
    ///
    /// Inserting a new key into a full map first evicts the least-recently-used entry.
    fn insert(&mut self, key: String, value: V) {
        if self.map.contains_key(&key) {
            self.touch(&key);
        } else {
            if self.map.len() >= self.maxsize {
                if let Some(evicted) = self.order.pop_front() {
                    self.map.remove(&evicted);
                }
            }
            self.order.push_back(key.clone());
        }
        self.map.insert(key, value);
    }

    /// Removes `key`, returning its value if it was present.
    fn remove(&mut self, key: &str) -> Option<V> {
        let value = self.map.remove(key)?;
        if let Some(pos) = self.order.iter().position(|k| k == key) {
            self.order.remove(pos);
        }
        Some(value)
    }

    /// Removes every entry.
    fn clear(&mut self) {
        self.map.clear();
        self.order.clear();
    }

    /// Iterates over the values from least- to most-recently used.
    fn values(&self) -> impl Iterator<Item = &V> {
        self.iter().map(|(_, v)| v)
    }

    /// Iterates over `(key, value)` pairs from least- to most-recently used.
    fn iter(&self) -> impl Iterator<Item = (&String, &V)> {
        let map = &self.map;
        // `order` and `map` share the same key set, so this never skips an entry.
        self.order.iter().filter_map(move |k| map.get_key_value(k))
    }

    /// Keeps only the entries for which `f(key, value)` returns `true`.
    fn retain<F: FnMut(&String, &V) -> bool>(&mut self, mut f: F) {
        self.map.retain(|k, v| f(k, &*v));
        // Prune `order` against the surviving keys. This is O(n), whereas checking each
        // key against a list of removed keys would be O(n*m).
        let map = &self.map;
        self.order.retain(|k| map.contains_key(k));
    }
}
172
173// ---------------------------------------------------------------------------
174// Persistence data type alias
175// ---------------------------------------------------------------------------
176
177/// Persistent representation of the cache state.
178///
179/// Tuple of:
180/// - list of `(keyboard_uuid, access_time, button_data)` tuples
181/// - map of `callback_query_id -> keyboard_uuid`
182pub type CdcData = (
183    Vec<(String, f64, HashMap<String, Value>)>,
184    HashMap<String, String>,
185);
186
187// ---------------------------------------------------------------------------
188// CallbackDataCache
189// ---------------------------------------------------------------------------
190
191/// A custom cache for storing the callback data of an [`ExtBot`](super::ext_bot::ExtBot).
192///
193/// Internally, it keeps two mappings with fixed maximum size:
194///
195/// * One for mapping the data received in callback queries to the cached objects.
196/// * One for mapping the IDs of received callback queries to the cached objects.
197///
198/// The second mapping allows manually dropping data cached for keyboards of messages sent via
199/// inline mode.  If necessary, the least recently used items are evicted.
200#[derive(Debug, Clone)]
201pub struct CallbackDataCache {
202    keyboard_data: LruMap<KeyboardData>,
203    callback_queries: LruMap<String>,
204    maxsize: usize,
205}
206
207impl CallbackDataCache {
208    /// Creates a new `CallbackDataCache`.
209    ///
210    /// # Arguments
211    ///
212    /// * `maxsize` - Maximum number of items in each of the internal mappings.
213    #[must_use]
214    pub fn new(maxsize: usize) -> Self {
215        Self {
216            keyboard_data: LruMap::new(maxsize),
217            callback_queries: LruMap::new(maxsize),
218            maxsize,
219        }
220    }
221
222    /// Loads persisted data into the cache.
223    pub fn load_persistence_data(&mut self, data: CdcData) {
224        let (keyboard_list, query_map) = data;
225        for (uuid, access_time, button_data) in keyboard_list {
226            self.keyboard_data.insert(
227                uuid.clone(),
228                KeyboardData {
229                    keyboard_uuid: uuid,
230                    access_time,
231                    button_data,
232                },
233            );
234        }
235        for (qid, kbd_uuid) in query_map {
236            self.callback_queries.insert(qid, kbd_uuid);
237        }
238    }
239
240    /// The maximum size of the cache.
241    #[must_use]
242    pub fn maxsize(&self) -> usize {
243        self.maxsize
244    }
245
246    /// Returns the data that needs to be persisted.
247    #[must_use]
248    pub fn persistence_data(&self) -> CdcData {
249        let kbd_list: Vec<_> = self
250            .keyboard_data
251            .values()
252            .map(KeyboardData::to_tuple)
253            .collect();
254        let query_map: HashMap<String, String> = self
255            .callback_queries
256            .iter()
257            .map(|(k, v)| (k.clone(), v.clone()))
258            .collect();
259        (kbd_list, query_map)
260    }
261
262    /// Registers the reply markup in the cache.
263    ///
264    /// If any of the buttons have `callback_data`, stores that data and builds a new keyboard
265    /// with the correspondingly replaced buttons.  Otherwise, returns the original reply markup
266    /// unchanged.
267    pub fn process_keyboard(
268        &mut self,
269        reply_markup: &InlineKeyboardMarkup,
270    ) -> InlineKeyboardMarkup {
271        let keyboard_uuid = generate_uuid();
272        let mut kbd_data = KeyboardData::new(keyboard_uuid.clone());
273
274        let mut new_rows: Vec<Vec<InlineKeyboardButton>> = Vec::new();
275        let mut any_replaced = false;
276
277        for row in &reply_markup.inline_keyboard {
278            let mut new_row: Vec<InlineKeyboardButton> = Vec::new();
279            for btn in row {
280                if btn.callback_data.is_some() {
281                    let mut btn_copy = btn.clone();
282                    let btn_uuid = generate_uuid();
283                    kbd_data.button_data.insert(
284                        btn_uuid.clone(),
285                        Value::String(btn.callback_data.clone().unwrap_or_default()),
286                    );
287                    btn_copy.callback_data = Some(format!("{keyboard_uuid}{btn_uuid}"));
288                    new_row.push(btn_copy);
289                    any_replaced = true;
290                } else {
291                    new_row.push(btn.clone());
292                }
293            }
294            new_rows.push(new_row);
295        }
296
297        if !any_replaced {
298            return reply_markup.clone();
299        }
300
301        self.keyboard_data.insert(keyboard_uuid, kbd_data);
302
303        InlineKeyboardMarkup::new(new_rows)
304    }
305
306    /// Extracts keyboard uuid and button uuid from a raw callback data string.
307    ///
308    /// The first 32 characters are the keyboard uuid, the rest is the button uuid.
309    #[must_use]
310    pub fn extract_uuids(callback_data: &str) -> (&str, &str) {
311        if callback_data.len() >= 32 {
312            (&callback_data[..32], &callback_data[32..])
313        } else {
314            (callback_data, "")
315        }
316    }
317
318    fn get_keyboard_uuid_and_button_data(
319        &mut self,
320        callback_data: &str,
321    ) -> Result<(String, Value), InvalidCallbackData> {
322        let (keyboard_uuid, button_uuid) = Self::extract_uuids(callback_data);
323
324        let kbd = self
325            .keyboard_data
326            .get_mut(keyboard_uuid)
327            .ok_or_else(|| InvalidCallbackData {
328                callback_data: Some(callback_data.to_owned()),
329            })?;
330
331        let btn_data =
332            kbd.button_data
333                .get(button_uuid)
334                .cloned()
335                .ok_or_else(|| InvalidCallbackData {
336                    callback_data: Some(callback_data.to_owned()),
337                })?;
338
339        kbd.update_access_time();
340
341        Ok((keyboard_uuid.to_owned(), btn_data))
342    }
343
344    /// Replaces the data in the inline keyboard attached to a raw JSON message value.
345    ///
346    /// Works with `Message.reply_markup` being `Option<Value>` (the raw type from the
347    /// `rust-tg-bot-raw` crate).
348    ///
349    /// Returns the keyboard UUID if resolution succeeded.
350    pub fn process_message_value(&mut self, message: &mut Value) -> Option<String> {
351        let rm = message.get_mut("reply_markup")?;
352        if rm.is_null() {
353            return None;
354        }
355
356        // Try to deserialize as InlineKeyboardMarkup
357        let mut markup: InlineKeyboardMarkup = serde_json::from_value(rm.clone()).ok()?;
358
359        let mut keyboard_uuid: Option<String> = None;
360
361        for row in &mut markup.inline_keyboard {
362            for button in row {
363                if let Some(ref raw_data) = button.callback_data.clone() {
364                    match self.get_keyboard_uuid_and_button_data(raw_data) {
365                        Ok((kbd_id, data)) => {
366                            button.callback_data = Some(data.to_string());
367                            if keyboard_uuid.is_none() {
368                                keyboard_uuid = Some(kbd_id);
369                            }
370                        }
371                        Err(_) => {
372                            button.callback_data = None;
373                        }
374                    }
375                }
376            }
377        }
378
379        // Write back the modified markup
380        if let Ok(v) = serde_json::to_value(&markup) {
381            *rm = v;
382        }
383
384        keyboard_uuid
385    }
386
387    /// Replaces the data in the callback query (and attached message keyboard) with cached
388    /// objects.
389    ///
390    /// **In place** -- modifies the passed `CallbackQuery`.
391    pub fn process_callback_query(&mut self, callback_query: &mut CallbackQuery) {
392        if let Some(ref raw_data) = callback_query.data.clone() {
393            match self.get_keyboard_uuid_and_button_data(raw_data) {
394                Ok((kbd_uuid, data)) => {
395                    callback_query.data = Some(data.to_string());
396                    self.callback_queries
397                        .insert(callback_query.id.clone(), kbd_uuid);
398                }
399                Err(_) => {
400                    callback_query.data = None;
401                }
402            }
403        }
404
405        // Process the attached message (as raw Value via the `message` field).
406        if let Some(ref mut msg) = callback_query.message {
407            // The message is Box<MaybeInaccessibleMessage> which contains reply_markup: Option<Value>.
408            // We need to convert to a Value, process it, and write it back.
409            if let Ok(mut msg_val) = serde_json::to_value(&**msg) {
410                self.process_message_value(&mut msg_val);
411                if let Ok(processed_msg) = serde_json::from_value::<
412                    rust_tg_bot_raw::types::message::MaybeInaccessibleMessage,
413                >(msg_val)
414                {
415                    **msg = processed_msg;
416                }
417            }
418        }
419    }
420
421    /// Deletes the data for the specified callback query.
422    ///
423    /// # Errors
424    ///
425    /// Returns `Err` if the callback query is not found in the cache.
426    pub fn drop_data(&mut self, callback_query_id: &str) -> Result<(), InvalidCallbackData> {
427        let kbd_uuid =
428            self.callback_queries
429                .remove(callback_query_id)
430                .ok_or(InvalidCallbackData {
431                    callback_data: None,
432                })?;
433
434        // Silently ignore if the keyboard itself is already gone.
435        let _ = self.keyboard_data.remove(&kbd_uuid);
436        Ok(())
437    }
438
439    /// Clears the stored callback data.
440    ///
441    /// If `time_cutoff` is provided, only entries older than that UNIX timestamp are cleared.
442    pub fn clear_callback_data(&mut self, time_cutoff: Option<f64>) {
443        match time_cutoff {
444            None => self.keyboard_data.clear(),
445            Some(cutoff) => {
446                self.keyboard_data.retain(|_, v| v.access_time >= cutoff);
447            }
448        }
449    }
450
451    /// Clears all stored callback query IDs.
452    pub fn clear_callback_queries(&mut self) {
453        self.callback_queries.clear();
454    }
455}
456
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a one-row, one-button keyboard whose button carries `data` as callback data.
    fn single_callback_keyboard(
        text: impl Into<String>,
        data: impl Into<String>,
    ) -> InlineKeyboardMarkup {
        let text: String = text.into();
        let data: String = data.into();
        InlineKeyboardMarkup::new(vec![vec![InlineKeyboardButton::callback(text, data)]])
    }

    /// Runs `markup` through the cache and returns the rewritten data of its first button.
    fn first_button_data(cache: &mut CallbackDataCache, markup: &InlineKeyboardMarkup) -> String {
        let processed = cache.process_keyboard(markup);
        processed.inline_keyboard[0][0]
            .callback_data
            .clone()
            .expect("processed button should carry callback data")
    }

    /// Simulates Telegram delivering a callback query whose data is `data`.
    fn incoming_query(
        id: &'static str,
        first_name: &'static str,
        chat_instance: &'static str,
        data: String,
    ) -> CallbackQuery {
        let user = rust_tg_bot_raw::types::user::User::new(1, false, first_name);
        let mut query = CallbackQuery::new(id, user, chat_instance);
        query.data = Some(data);
        query
    }

    #[test]
    fn uuid_generation_is_unique() {
        let (first, second) = (generate_uuid(), generate_uuid());
        assert_ne!(first, second);
        assert_eq!(first.len(), 32);
    }

    #[test]
    fn extract_uuids_splits_correctly() {
        let keyboard = "a".repeat(32);
        let button = "b".repeat(32);
        let combined = keyboard.clone() + &button;

        let (kbd, btn) = CallbackDataCache::extract_uuids(&combined);

        assert_eq!(kbd, keyboard);
        assert_eq!(btn, button);
    }

    #[test]
    fn process_keyboard_replaces_callback_data() {
        let mut cache = CallbackDataCache::new(128);
        let markup = single_callback_keyboard("Click", "my_data");

        let replaced = first_button_data(&mut cache, &markup);

        // keyboard uuid (32 chars) followed by button uuid (32 chars)
        assert_eq!(replaced.len(), 64);
        assert_ne!(replaced, "my_data");
    }

    #[test]
    fn process_keyboard_noop_without_callback_data() {
        let mut cache = CallbackDataCache::new(128);
        let markup = InlineKeyboardMarkup::new(vec![vec![InlineKeyboardButton::url(
            "URL",
            "https://example.com",
        )]]);

        let result = cache.process_keyboard(&markup);

        let before = &markup.inline_keyboard[0][0];
        let after = &result.inline_keyboard[0][0];
        assert_eq!(after.url, before.url);
    }

    #[test]
    fn roundtrip_process_and_resolve() {
        let mut cache = CallbackDataCache::new(128);
        let encoded = first_button_data(&mut cache, &single_callback_keyboard("Click", "original"));

        let mut query = incoming_query("query_1", "Test", "inst", encoded);
        cache.process_callback_query(&mut query);

        // The resolved data is the JSON encoding of the cached string.
        assert_eq!(query.data.as_deref(), Some("\"original\""));
    }

    #[test]
    fn drop_data_removes_entry() {
        let mut cache = CallbackDataCache::new(128);
        let encoded = first_button_data(&mut cache, &single_callback_keyboard("Click", "payload"));

        let mut query = incoming_query("q2", "T", "i", encoded);
        cache.process_callback_query(&mut query);

        assert!(cache.drop_data("q2").is_ok());
        assert!(cache.drop_data("q2").is_err());
    }

    #[test]
    fn lru_eviction() {
        let mut cache = CallbackDataCache::new(2);

        for i in 0..3 {
            let markup = single_callback_keyboard(format!("btn_{i}"), format!("data_{i}"));
            cache.process_keyboard(&markup);
        }

        // Capacity is 2, so the oldest of the three keyboards was evicted.
        assert_eq!(cache.keyboard_data.map.len(), 2);
    }

    #[test]
    fn persistence_roundtrip() {
        let mut source = CallbackDataCache::new(128);
        source.process_keyboard(&single_callback_keyboard("Click", "persist_me"));

        let mut restored = CallbackDataCache::new(128);
        restored.load_persistence_data(source.persistence_data());

        assert_eq!(restored.keyboard_data.map.len(), 1);
    }

    #[test]
    fn clear_with_cutoff() {
        let mut cache = CallbackDataCache::new(128);
        cache.process_keyboard(&single_callback_keyboard("Old", "old_data"));

        // Every entry is older than a cutoff of f64::MAX, so all of them are dropped.
        cache.clear_callback_data(Some(f64::MAX));

        assert_eq!(cache.keyboard_data.map.len(), 0);
    }
}