1use std::collections::{HashMap, VecDeque};
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::time::{SystemTime, UNIX_EPOCH};
12
13use serde_json::Value;
14
15use rust_tg_bot_raw::types::callback_query::CallbackQuery;
16use rust_tg_bot_raw::types::inline::inline_keyboard_button::InlineKeyboardButton;
17use rust_tg_bot_raw::types::inline::inline_keyboard_markup::InlineKeyboardMarkup;
18
/// Process-wide sequence number mixed into every generated identifier so that
/// two identifiers produced from the same clock reading still differ.
static COUNTER: AtomicU64 = AtomicU64::new(0);

/// Returns a fresh identifier made of lowercase hex digits.
///
/// The identifier is the current Unix time in nanoseconds followed by the next
/// value of [`COUNTER`]. Each part is zero-padded to at least 16 digits, giving
/// 32 characters for present-day clocks.
///
/// This is not an RFC 4122 UUID. It is meant to be fixed-width because
/// `CallbackDataCache::extract_uuids` splits concatenated identifiers at byte 32.
fn generate_uuid() -> String {
    let nanos: u128 = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos(),
        // A clock set before 1970 degrades to 0 instead of failing.
        Err(_) => 0,
    };
    let sequence = COUNTER.fetch_add(1, Ordering::Relaxed);
    format!("{nanos:016x}{sequence:016x}")
}
34
/// Error returned when callback data cannot be resolved by the cache.
///
/// This happens when the keyboard or button entry it refers to is no longer
/// stored (evicted, dropped or cleared), or when the data does not match any
/// stored keyboard/button pair (e.g. it was manipulated).
#[derive(Debug, Clone, thiserror::Error)]
#[error(
    "The object belonging to this callback_data was deleted or the callback_data was manipulated."
)]
pub struct InvalidCallbackData {
    /// The raw callback data that failed to resolve.
    ///
    /// `None` when the failing lookup was keyed by a callback query id rather
    /// than by callback data (as in `drop_data`).
    pub callback_data: Option<String>,
}
48
49#[derive(Debug, Clone)]
54struct KeyboardData {
55 keyboard_uuid: String,
56 access_time: f64,
57 button_data: HashMap<String, Value>,
59}
60
61impl KeyboardData {
62 fn new(keyboard_uuid: String) -> Self {
63 Self {
64 keyboard_uuid,
65 access_time: now_f64(),
66 button_data: HashMap::new(),
67 }
68 }
69
70 fn update_access_time(&mut self) {
71 self.access_time = now_f64();
72 }
73
74 fn to_tuple(&self) -> (String, f64, HashMap<String, Value>) {
75 (
76 self.keyboard_uuid.clone(),
77 self.access_time,
78 self.button_data.clone(),
79 )
80 }
81}
82
/// Current wall-clock time as fractional seconds since the Unix epoch.
///
/// Returns `0.0` if the system clock reports a time before the epoch.
fn now_f64() -> f64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs_f64(),
        Err(_) => 0.0,
    }
}
89
/// A minimal string-keyed map with least-recently-used eviction.
///
/// `map` holds the entries. `order` holds every key of `map` exactly once,
/// from least recently used (front) to most recently used (back).
///
/// - [`LruMap::get_mut`] and [`LruMap::insert`] move a key to the back.
/// - Inserting a new key while the map already holds `maxsize` entries evicts
///   the key at the front.
///
/// Recency updates scan `order` linearly.
#[derive(Debug, Clone)]
struct LruMap<V> {
    map: HashMap<String, V>,
    order: VecDeque<String>,
    maxsize: usize,
}

impl<V> LruMap<V> {
    /// Creates an empty map holding at most `maxsize` entries.
    ///
    /// Note: with `maxsize == 0` the most recently inserted entry is still
    /// kept, i.e. the map behaves as if its capacity were 1.
    fn new(maxsize: usize) -> Self {
        Self {
            map: HashMap::new(),
            order: VecDeque::new(),
            maxsize,
        }
    }

    /// Removes `key` from the recency list only (not from `map`).
    ///
    /// Stops at the first match, relying on the invariant that each key
    /// appears in `order` at most once.
    fn unlink(&mut self, key: &str) {
        if let Some(pos) = self.order.iter().position(|k| k == key) {
            self.order.remove(pos);
        }
    }

    /// Returns a mutable reference to the value for `key`.
    ///
    /// A hit marks `key` as most recently used. Returns `None`, with no
    /// reordering, if `key` is absent.
    fn get_mut(&mut self, key: &str) -> Option<&mut V> {
        if !self.map.contains_key(key) {
            return None;
        }
        self.unlink(key);
        self.order.push_back(key.to_owned());
        self.map.get_mut(key)
    }

    /// Inserts or replaces `key`, making it the most recently used entry.
    ///
    /// Inserting a new key into a full map first evicts the least recently
    /// used entry.
    fn insert(&mut self, key: String, value: V) {
        if self.map.contains_key(&key) {
            self.unlink(&key);
        } else if self.map.len() >= self.maxsize {
            if let Some(evicted) = self.order.pop_front() {
                self.map.remove(&evicted);
            }
        }
        self.order.push_back(key.clone());
        self.map.insert(key, value);
    }

    /// Removes `key` and returns its value, or `None` if it was absent.
    fn remove(&mut self, key: &str) -> Option<V> {
        let value = self.map.remove(key)?;
        self.unlink(key);
        Some(value)
    }

    /// Removes every entry.
    fn clear(&mut self) {
        self.map.clear();
        self.order.clear();
    }

    /// Iterates over the values in arbitrary (not recency) order.
    fn values(&self) -> impl Iterator<Item = &V> {
        self.map.values()
    }

    /// Iterates over `(key, value)` pairs in arbitrary (not recency) order.
    fn iter(&self) -> impl Iterator<Item = (&String, &V)> {
        self.map.iter()
    }

    /// Keeps only the entries for which `f(key, value)` returns `true`.
    ///
    /// The relative recency order of the survivors is preserved.
    fn retain<F: FnMut(&String, &V) -> bool>(&mut self, mut f: F) {
        self.map.retain(|k, v| f(k, &*v));
        // Filter `order` by membership in the (already filtered) map. This is
        // O(1) per key, unlike scanning a list of removed keys.
        let map = &self.map;
        self.order.retain(|k| map.contains_key(k));
    }
}
172
/// Persistable snapshot of a [`CallbackDataCache`].
///
/// - First element: the stored keyboards as
///   `(keyboard_uuid, access_time, button_data)` tuples.
/// - Second element: a map from callback query id to the uuid of the keyboard
///   that query resolved to.
pub type CdcData = (
    Vec<(String, f64, HashMap<String, Value>)>,
    HashMap<String, String>,
);
186
/// Cache that swaps inline-keyboard `callback_data` for generated identifiers.
///
/// When callback queries arrive, the cache resolves those identifiers back to
/// the original data. Both internal maps are LRU maps created with the same
/// `maxsize`.
#[derive(Debug, Clone)]
pub struct CallbackDataCache {
    // keyboard uuid -> data stored for that keyboard
    keyboard_data: LruMap<KeyboardData>,
    // callback query id -> uuid of the keyboard the query resolved to
    callback_queries: LruMap<String>,
    maxsize: usize,
}
206
207impl CallbackDataCache {
208 #[must_use]
214 pub fn new(maxsize: usize) -> Self {
215 Self {
216 keyboard_data: LruMap::new(maxsize),
217 callback_queries: LruMap::new(maxsize),
218 maxsize,
219 }
220 }
221
222 pub fn load_persistence_data(&mut self, data: CdcData) {
224 let (keyboard_list, query_map) = data;
225 for (uuid, access_time, button_data) in keyboard_list {
226 self.keyboard_data.insert(
227 uuid.clone(),
228 KeyboardData {
229 keyboard_uuid: uuid,
230 access_time,
231 button_data,
232 },
233 );
234 }
235 for (qid, kbd_uuid) in query_map {
236 self.callback_queries.insert(qid, kbd_uuid);
237 }
238 }
239
240 #[must_use]
242 pub fn maxsize(&self) -> usize {
243 self.maxsize
244 }
245
246 #[must_use]
248 pub fn persistence_data(&self) -> CdcData {
249 let kbd_list: Vec<_> = self
250 .keyboard_data
251 .values()
252 .map(KeyboardData::to_tuple)
253 .collect();
254 let query_map: HashMap<String, String> = self
255 .callback_queries
256 .iter()
257 .map(|(k, v)| (k.clone(), v.clone()))
258 .collect();
259 (kbd_list, query_map)
260 }
261
262 pub fn process_keyboard(
268 &mut self,
269 reply_markup: &InlineKeyboardMarkup,
270 ) -> InlineKeyboardMarkup {
271 let keyboard_uuid = generate_uuid();
272 let mut kbd_data = KeyboardData::new(keyboard_uuid.clone());
273
274 let mut new_rows: Vec<Vec<InlineKeyboardButton>> = Vec::new();
275 let mut any_replaced = false;
276
277 for row in &reply_markup.inline_keyboard {
278 let mut new_row: Vec<InlineKeyboardButton> = Vec::new();
279 for btn in row {
280 if btn.callback_data.is_some() {
281 let mut btn_copy = btn.clone();
282 let btn_uuid = generate_uuid();
283 kbd_data.button_data.insert(
284 btn_uuid.clone(),
285 Value::String(btn.callback_data.clone().unwrap_or_default()),
286 );
287 btn_copy.callback_data = Some(format!("{keyboard_uuid}{btn_uuid}"));
288 new_row.push(btn_copy);
289 any_replaced = true;
290 } else {
291 new_row.push(btn.clone());
292 }
293 }
294 new_rows.push(new_row);
295 }
296
297 if !any_replaced {
298 return reply_markup.clone();
299 }
300
301 self.keyboard_data.insert(keyboard_uuid, kbd_data);
302
303 InlineKeyboardMarkup::new(new_rows)
304 }
305
306 #[must_use]
310 pub fn extract_uuids(callback_data: &str) -> (&str, &str) {
311 if callback_data.len() >= 32 {
312 (&callback_data[..32], &callback_data[32..])
313 } else {
314 (callback_data, "")
315 }
316 }
317
318 fn get_keyboard_uuid_and_button_data(
319 &mut self,
320 callback_data: &str,
321 ) -> Result<(String, Value), InvalidCallbackData> {
322 let (keyboard_uuid, button_uuid) = Self::extract_uuids(callback_data);
323
324 let kbd = self
325 .keyboard_data
326 .get_mut(keyboard_uuid)
327 .ok_or_else(|| InvalidCallbackData {
328 callback_data: Some(callback_data.to_owned()),
329 })?;
330
331 let btn_data =
332 kbd.button_data
333 .get(button_uuid)
334 .cloned()
335 .ok_or_else(|| InvalidCallbackData {
336 callback_data: Some(callback_data.to_owned()),
337 })?;
338
339 kbd.update_access_time();
340
341 Ok((keyboard_uuid.to_owned(), btn_data))
342 }
343
344 pub fn process_message_value(&mut self, message: &mut Value) -> Option<String> {
351 let rm = message.get_mut("reply_markup")?;
352 if rm.is_null() {
353 return None;
354 }
355
356 let mut markup: InlineKeyboardMarkup = serde_json::from_value(rm.clone()).ok()?;
358
359 let mut keyboard_uuid: Option<String> = None;
360
361 for row in &mut markup.inline_keyboard {
362 for button in row {
363 if let Some(ref raw_data) = button.callback_data.clone() {
364 match self.get_keyboard_uuid_and_button_data(raw_data) {
365 Ok((kbd_id, data)) => {
366 button.callback_data = Some(data.to_string());
367 if keyboard_uuid.is_none() {
368 keyboard_uuid = Some(kbd_id);
369 }
370 }
371 Err(_) => {
372 button.callback_data = None;
373 }
374 }
375 }
376 }
377 }
378
379 if let Ok(v) = serde_json::to_value(&markup) {
381 *rm = v;
382 }
383
384 keyboard_uuid
385 }
386
387 pub fn process_callback_query(&mut self, callback_query: &mut CallbackQuery) {
392 if let Some(ref raw_data) = callback_query.data.clone() {
393 match self.get_keyboard_uuid_and_button_data(raw_data) {
394 Ok((kbd_uuid, data)) => {
395 callback_query.data = Some(data.to_string());
396 self.callback_queries
397 .insert(callback_query.id.clone(), kbd_uuid);
398 }
399 Err(_) => {
400 callback_query.data = None;
401 }
402 }
403 }
404
405 if let Some(ref mut msg) = callback_query.message {
407 if let Ok(mut msg_val) = serde_json::to_value(&**msg) {
410 self.process_message_value(&mut msg_val);
411 if let Ok(processed_msg) = serde_json::from_value::<
412 rust_tg_bot_raw::types::message::MaybeInaccessibleMessage,
413 >(msg_val)
414 {
415 **msg = processed_msg;
416 }
417 }
418 }
419 }
420
421 pub fn drop_data(&mut self, callback_query_id: &str) -> Result<(), InvalidCallbackData> {
427 let kbd_uuid =
428 self.callback_queries
429 .remove(callback_query_id)
430 .ok_or(InvalidCallbackData {
431 callback_data: None,
432 })?;
433
434 let _ = self.keyboard_data.remove(&kbd_uuid);
436 Ok(())
437 }
438
439 pub fn clear_callback_data(&mut self, time_cutoff: Option<f64>) {
443 match time_cutoff {
444 None => self.keyboard_data.clear(),
445 Some(cutoff) => {
446 self.keyboard_data.retain(|_, v| v.access_time >= cutoff);
447 }
448 }
449 }
450
451 pub fn clear_callback_queries(&mut self) {
453 self.callback_queries.clear();
454 }
455}
456
#[cfg(test)]
mod tests {
    use super::*;

    /// One-row, one-button keyboard whose button carries `data` as callback data.
    fn single_callback_keyboard(text: &str, data: &str) -> InlineKeyboardMarkup {
        InlineKeyboardMarkup::new(vec![vec![InlineKeyboardButton::callback(text, data)]])
    }

    /// The callback data of the first button of `markup`.
    fn first_callback_data(markup: &InlineKeyboardMarkup) -> String {
        markup.inline_keyboard[0][0]
            .callback_data
            .clone()
            .expect("first button should carry callback data")
    }

    /// Runs `data` through `process_callback_query` as query `query_id` and
    /// returns the resulting `data` field.
    fn resolve(cache: &mut CallbackDataCache, query_id: &str, data: String) -> Option<String> {
        let user = rust_tg_bot_raw::types::user::User::new(1, false, "Test");
        let mut cq = CallbackQuery::new(query_id, user, "inst");
        cq.data = Some(data);
        cache.process_callback_query(&mut cq);
        cq.data.take()
    }

    #[test]
    fn uuid_generation_is_unique() {
        let a = generate_uuid();
        let b = generate_uuid();
        assert_ne!(a, b);
        assert_eq!(a.len(), 32);
    }

    #[test]
    fn extract_uuids_splits_correctly() {
        let combined = format!("{}{}", "a".repeat(32), "b".repeat(32));
        let (kbd, btn) = CallbackDataCache::extract_uuids(&combined);
        assert_eq!(kbd, "a".repeat(32));
        assert_eq!(btn, "b".repeat(32));
    }

    #[test]
    fn extract_uuids_keeps_short_input_whole() {
        assert_eq!(CallbackDataCache::extract_uuids("short"), ("short", ""));
        assert_eq!(CallbackDataCache::extract_uuids(""), ("", ""));
    }

    #[test]
    fn process_keyboard_replaces_callback_data() {
        let mut cache = CallbackDataCache::new(128);

        let markup = InlineKeyboardMarkup::new(vec![vec![InlineKeyboardButton::callback(
            "Click", "my_data",
        )]]);

        let new_markup = cache.process_keyboard(&markup);
        let new_data = new_markup.inline_keyboard[0][0]
            .callback_data
            .as_ref()
            .unwrap();

        assert_eq!(new_data.len(), 64);
        assert_ne!(new_data, "my_data");
    }

    #[test]
    fn process_keyboard_noop_without_callback_data() {
        let mut cache = CallbackDataCache::new(128);

        let markup = InlineKeyboardMarkup::new(vec![vec![InlineKeyboardButton::url(
            "URL",
            "https://example.com",
        )]]);

        let new_markup = cache.process_keyboard(&markup);
        assert_eq!(
            new_markup.inline_keyboard[0][0].url,
            markup.inline_keyboard[0][0].url
        );
        assert!(new_markup.inline_keyboard[0][0].callback_data.is_none());
        // Nothing to replace means nothing is stored.
        assert!(cache.keyboard_data.map.is_empty());
    }

    #[test]
    fn roundtrip_process_and_resolve() {
        let mut cache = CallbackDataCache::new(128);

        let markup = InlineKeyboardMarkup::new(vec![vec![InlineKeyboardButton::callback(
            "Click", "original",
        )]]);

        let new_markup = cache.process_keyboard(&markup);
        let uuid_data = new_markup.inline_keyboard[0][0]
            .callback_data
            .clone()
            .unwrap();

        let user = rust_tg_bot_raw::types::user::User::new(1, false, "Test");
        let mut cq = CallbackQuery::new("query_1", user, "inst");
        cq.data = Some(uuid_data);

        cache.process_callback_query(&mut cq);

        // Resolved data is the JSON encoding of the stored value.
        assert_eq!(cq.data.as_deref(), Some("\"original\""));
    }

    #[test]
    fn unknown_callback_data_is_cleared() {
        let mut cache = CallbackDataCache::new(128);
        assert_eq!(resolve(&mut cache, "q", "f".repeat(64)), None);
        // A query whose data did not resolve is not remembered.
        assert!(cache.drop_data("q").is_err());
    }

    #[test]
    fn drop_data_removes_entry() {
        let mut cache = CallbackDataCache::new(128);

        let markup = InlineKeyboardMarkup::new(vec![vec![InlineKeyboardButton::callback(
            "Click", "payload",
        )]]);

        let new_markup = cache.process_keyboard(&markup);
        let uuid_data = new_markup.inline_keyboard[0][0]
            .callback_data
            .clone()
            .unwrap();

        let user = rust_tg_bot_raw::types::user::User::new(1, false, "T");
        let mut cq = CallbackQuery::new("q2", user, "i");
        cq.data = Some(uuid_data.clone());

        cache.process_callback_query(&mut cq);
        assert!(cache.drop_data("q2").is_ok());
        assert!(cache.drop_data("q2").is_err());
        // The keyboard itself is gone, so its data no longer resolves.
        assert_eq!(resolve(&mut cache, "q3", uuid_data), None);
    }

    #[test]
    fn drop_data_unknown_query_is_err() {
        let mut cache = CallbackDataCache::new(128);
        let err = cache.drop_data("missing").unwrap_err();
        assert!(err.callback_data.is_none());
    }

    #[test]
    fn lru_eviction() {
        let mut cache = CallbackDataCache::new(2);

        for i in 0..3 {
            let markup = InlineKeyboardMarkup::new(vec![vec![InlineKeyboardButton::callback(
                format!("btn_{i}"),
                format!("data_{i}"),
            )]]);
            cache.process_keyboard(&markup);
        }

        assert_eq!(cache.keyboard_data.map.len(), 2);
    }

    #[test]
    fn lru_eviction_evicts_least_recently_used() {
        let mut cache = CallbackDataCache::new(2);

        let a = first_callback_data(&cache.process_keyboard(&single_callback_keyboard("A", "a")));
        let b = first_callback_data(&cache.process_keyboard(&single_callback_keyboard("B", "b")));

        // Resolving `a` makes its keyboard the most recently used one.
        assert_eq!(resolve(&mut cache, "qa", a.clone()).as_deref(), Some("\"a\""));

        // A third keyboard must therefore evict `b`, not `a`.
        cache.process_keyboard(&single_callback_keyboard("C", "c"));
        assert_eq!(resolve(&mut cache, "qb", b), None);
        assert_eq!(resolve(&mut cache, "qa2", a).as_deref(), Some("\"a\""));
    }

    #[test]
    fn persistence_roundtrip() {
        let mut cache = CallbackDataCache::new(128);

        let markup = InlineKeyboardMarkup::new(vec![vec![InlineKeyboardButton::callback(
            "Click",
            "persist_me",
        )]]);

        let data = first_callback_data(&cache.process_keyboard(&markup));
        let persisted = cache.persistence_data();

        let mut cache2 = CallbackDataCache::new(128);
        cache2.load_persistence_data(persisted);

        assert_eq!(cache2.keyboard_data.map.len(), 1);
        // The restored cache resolves data produced by the original one.
        assert_eq!(
            resolve(&mut cache2, "q", data).as_deref(),
            Some("\"persist_me\"")
        );
    }

    #[test]
    fn clear_with_cutoff() {
        let mut cache = CallbackDataCache::new(128);

        let markup = InlineKeyboardMarkup::new(vec![vec![InlineKeyboardButton::callback(
            "Old", "old_data",
        )]]);

        cache.process_keyboard(&markup);

        // Entries accessed after the cutoff survive.
        cache.clear_callback_data(Some(0.0));
        assert_eq!(cache.keyboard_data.map.len(), 1);

        cache.clear_callback_data(Some(f64::MAX));
        assert_eq!(cache.keyboard_data.map.len(), 0);
        assert!(cache.keyboard_data.order.is_empty());
    }

    #[test]
    fn clear_without_cutoff_removes_everything() {
        let mut cache = CallbackDataCache::new(128);
        cache.process_keyboard(&single_callback_keyboard("A", "a"));
        cache.process_keyboard(&single_callback_keyboard("B", "b"));

        cache.clear_callback_data(None);
        assert!(cache.keyboard_data.map.is_empty());
        assert!(cache.keyboard_data.order.is_empty());
    }

    #[test]
    fn clear_callback_queries_forgets_queries() {
        let mut cache = CallbackDataCache::new(128);
        let data = first_callback_data(&cache.process_keyboard(&single_callback_keyboard("A", "x")));
        assert!(resolve(&mut cache, "q", data).is_some());

        cache.clear_callback_queries();
        assert!(cache.drop_data("q").is_err());
    }
}