Skip to main content

rust_tg_bot_ext/utils/
tracking_dict.rs

//! A mutable mapping that tracks which keys have been written to.
//!
//! Port of `telegram.ext._utils.trackingdict.TrackingDict`.
//! Read access is **not** tracked. Tracked mutations (`insert`, `remove`,
//! `clear`, `set_default`) mark the affected keys as dirty, while
//! `update_no_track` deliberately writes without marking anything.
6
7use std::collections::{HashMap, HashSet};
8use std::hash::Hash;
9
/// The value half of each item returned by [`TrackingDict::pop_accessed_write_items`].
///
/// A drained key is paired either with a clone of its current value or with
/// the [`EntryValue::Deleted`] sentinel when the key is no longer in the map.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum EntryValue<V> {
    /// The key is present and holds this value.
    Value(V),
    /// The key was marked dirty but is absent at drain time (it was removed,
    /// or it was marked via `mark_as_accessed` without being inserted).
    Deleted,
}

impl<V> EntryValue<V> {
    /// Returns `true` if this is the [`EntryValue::Deleted`] sentinel.
    #[must_use]
    pub fn is_deleted(&self) -> bool {
        matches!(self, Self::Deleted)
    }

    /// Borrows the held value, or returns `None` for [`EntryValue::Deleted`].
    #[must_use]
    pub fn as_value(&self) -> Option<&V> {
        match self {
            Self::Value(v) => Some(v),
            Self::Deleted => None,
        }
    }

    /// Converts into an `Option`, mapping [`EntryValue::Deleted`] to `None`.
    #[must_use]
    pub fn into_option(self) -> Option<V> {
        match self {
            Self::Value(v) => Some(v),
            Self::Deleted => None,
        }
    }
}
20
/// Wraps a [`HashMap`] and remembers every key that has been written to.
///
/// Writes made through the tracked methods add the key to an internal
/// dirty set, which can later be drained.
#[derive(Clone, Debug)]
pub struct TrackingDict<K, V> {
    /// The stored key/value pairs.
    data: HashMap<K, V>,
    /// Keys written since the last drain.
    dirty: HashSet<K>,
}
27
28impl<K, V> Default for TrackingDict<K, V>
29where
30    K: Eq + Hash + Clone,
31{
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37impl<K, V> TrackingDict<K, V>
38where
39    K: Eq + Hash + Clone,
40{
41    /// Create an empty `TrackingDict`.
42    pub fn new() -> Self {
43        Self {
44            data: HashMap::new(),
45            dirty: HashSet::new(),
46        }
47    }
48
49    /// Create a `TrackingDict` pre-populated with `data`.
50    /// None of the initial keys are considered dirty.
51    pub fn from_map(data: HashMap<K, V>) -> Self {
52        Self {
53            data,
54            dirty: HashSet::new(),
55        }
56    }
57
58    // ------------------------------------------------------------------
59    // Read access (not tracked)
60    // ------------------------------------------------------------------
61
62    /// Returns a reference to the value for the given key, if present.
63    pub fn get(&self, key: &K) -> Option<&V> {
64        self.data.get(key)
65    }
66
67    /// Returns `true` if the map contains the key.
68    pub fn contains_key(&self, key: &K) -> bool {
69        self.data.contains_key(key)
70    }
71
72    /// Returns the number of entries.
73    pub fn len(&self) -> usize {
74        self.data.len()
75    }
76
77    /// Returns `true` if the map is empty.
78    pub fn is_empty(&self) -> bool {
79        self.data.is_empty()
80    }
81
82    /// Iterate over all `(key, value)` pairs. Not tracked.
83    pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> {
84        self.data.iter()
85    }
86
87    /// Returns a reference to the underlying `HashMap`. Not tracked.
88    pub fn inner(&self) -> &HashMap<K, V> {
89        &self.data
90    }
91
92    // ------------------------------------------------------------------
93    // Write access (tracked)
94    // ------------------------------------------------------------------
95
96    /// Insert a key-value pair. Marks `key` as dirty.
97    pub fn insert(&mut self, key: K, value: V) -> Option<V> {
98        self.dirty.insert(key.clone());
99        self.data.insert(key, value)
100    }
101
102    /// Remove a key. Marks `key` as dirty if it was present.
103    pub fn remove(&mut self, key: &K) -> Option<V> {
104        if self.data.contains_key(key) {
105            self.dirty.insert(key.clone());
106        }
107        self.data.remove(key)
108    }
109
110    /// Clear all entries. Marks every existing key as dirty.
111    pub fn clear(&mut self) {
112        for key in self.data.keys() {
113            self.dirty.insert(key.clone());
114        }
115        self.data.clear();
116    }
117
118    /// Like `HashMap::entry`, but marks the key as dirty on any insertion.
119    pub fn set_default(&mut self, key: K, default: V) -> &mut V
120    where
121        V: Clone,
122    {
123        if !self.data.contains_key(&key) {
124            self.dirty.insert(key.clone());
125            self.data.insert(key.clone(), default);
126        }
127        self.data.get_mut(&key).expect("just inserted")
128    }
129
130    // ------------------------------------------------------------------
131    // Bulk update without tracking
132    // ------------------------------------------------------------------
133
134    /// Merge entries from `other` without marking any key as dirty.
135    /// Equivalent to Python's `update_no_track`.
136    pub fn update_no_track(&mut self, other: HashMap<K, V>) {
137        for (k, v) in other {
138            self.data.insert(k, v);
139        }
140    }
141
142    // ------------------------------------------------------------------
143    // Dirty-key access
144    // ------------------------------------------------------------------
145
146    /// Manually mark a key as dirty so it appears in the next drain.
147    pub fn mark_as_accessed(&mut self, key: K) {
148        self.dirty.insert(key);
149    }
150
151    /// Drain and return all keys that have been written to since the last
152    /// call to this method (or since construction).
153    pub fn pop_accessed_keys(&mut self) -> HashSet<K> {
154        std::mem::take(&mut self.dirty)
155    }
156
157    /// Drain dirty keys together with their current values.
158    /// If a key was deleted, the value is [`EntryValue::Deleted`].
159    pub fn pop_accessed_write_items(&mut self) -> Vec<(K, EntryValue<V>)>
160    where
161        V: Clone,
162    {
163        let keys = self.pop_accessed_keys();
164        keys.into_iter()
165            .map(|k| {
166                let v = self
167                    .data
168                    .get(&k)
169                    .map_or(EntryValue::Deleted, |v| EntryValue::Value(v.clone()));
170                (k, v)
171            })
172            .collect()
173    }
174}
175
176// ---------------------------------------------------------------------------
177// Trait implementations
178// ---------------------------------------------------------------------------
179
180impl<K, V> FromIterator<(K, V)> for TrackingDict<K, V>
181where
182    K: Eq + Hash + Clone,
183{
184    fn from_iter<T: IntoIterator<Item = (K, V)>>(iter: T) -> Self {
185        Self::from_map(iter.into_iter().collect())
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn insert_tracks_key() {
        let mut td: TrackingDict<String, i32> = TrackingDict::new();
        td.insert("a".into(), 1);
        let keys = td.pop_accessed_keys();
        assert!(keys.contains("a"));
        // Second call should be empty.
        assert!(td.pop_accessed_keys().is_empty());
    }

    #[test]
    fn remove_tracks_key() {
        let mut td = TrackingDict::from_map(HashMap::from([("x".to_owned(), 42)]));
        td.remove(&"x".to_owned());
        let keys = td.pop_accessed_keys();
        assert!(keys.contains("x"));
    }

    #[test]
    fn remove_missing_key_is_not_tracked() {
        let mut td: TrackingDict<String, i32> = TrackingDict::new();
        assert_eq!(td.remove(&"nope".to_owned()), None);
        assert!(td.pop_accessed_keys().is_empty());
    }

    #[test]
    fn update_no_track_is_silent() {
        let mut td: TrackingDict<String, i32> = TrackingDict::new();
        td.update_no_track(HashMap::from([("b".into(), 2)]));
        assert!(td.pop_accessed_keys().is_empty());
        assert_eq!(td.get(&"b".into()), Some(&2));
    }

    #[test]
    fn clear_marks_all_dirty() {
        let mut td =
            TrackingDict::from_map(HashMap::from([("a".to_owned(), 1), ("b".to_owned(), 2)]));
        td.clear();
        let keys = td.pop_accessed_keys();
        assert!(keys.contains("a"));
        assert!(keys.contains("b"));
        assert!(td.is_empty());
    }

    #[test]
    fn pop_accessed_write_items_returns_deleted() {
        let mut td = TrackingDict::from_map(HashMap::from([("k".to_owned(), 10)]));
        td.remove(&"k".to_owned());
        let items = td.pop_accessed_write_items();
        assert_eq!(items.len(), 1);
        assert_eq!(items[0].1, EntryValue::Deleted);
    }

    #[test]
    fn pop_accessed_write_items_returns_current_values() {
        let mut td: TrackingDict<String, i32> = TrackingDict::new();
        td.insert("k".to_owned(), 1);
        td.insert("k".to_owned(), 2);
        let items = td.pop_accessed_write_items();
        assert_eq!(items, vec![("k".to_owned(), EntryValue::Value(2))]);
        // The drain also resets the dirty set.
        assert!(td.pop_accessed_keys().is_empty());
    }

    #[test]
    fn mark_as_accessed_on_absent_key_reports_deleted() {
        let mut td: TrackingDict<String, i32> = TrackingDict::new();
        td.mark_as_accessed("ghost".to_owned());
        let items = td.pop_accessed_write_items();
        assert_eq!(items, vec![("ghost".to_owned(), EntryValue::Deleted)]);
    }

    #[test]
    fn set_default_tracks_on_miss() {
        let mut td: TrackingDict<String, i32> = TrackingDict::new();
        td.set_default("new".to_owned(), 5);
        let keys = td.pop_accessed_keys();
        assert!(keys.contains("new"));
    }

    #[test]
    fn set_default_does_not_track_on_hit() {
        let mut td = TrackingDict::from_map(HashMap::from([("old".to_owned(), 1)]));
        assert_eq!(*td.set_default("old".to_owned(), 5), 1);
        assert!(td.pop_accessed_keys().is_empty());
    }

    #[test]
    fn from_iter_starts_clean() {
        let mut td: TrackingDict<String, i32> =
            vec![("a".to_owned(), 1), ("b".to_owned(), 2)].into_iter().collect();
        assert_eq!(td.len(), 2);
        assert!(td.pop_accessed_keys().is_empty());
    }
}