snarkos_node_bft/helpers/cache.rs

// Copyright (c) 2019-2025 Provable Inc.
// This file is part of the snarkOS library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16use crate::events::BlockRequest;
17use snarkvm::{console::types::Field, ledger::narwhal::TransmissionID, prelude::Network};
18
19use core::hash::Hash;
20#[cfg(feature = "locktick")]
21use locktick::parking_lot::RwLock;
22#[cfg(not(feature = "locktick"))]
23use parking_lot::RwLock;
24use std::{
25    collections::{BTreeMap, HashMap, HashSet},
26    net::{IpAddr, SocketAddr},
27};
28use time::OffsetDateTime;
29
30#[derive(Debug)]
31pub struct Cache<N: Network> {
32    /// The ordered timestamp map of peer connections and cache hits.
33    seen_inbound_connections: RwLock<BTreeMap<i64, HashMap<IpAddr, u32>>>,
34    /// The ordered timestamp map of peer IPs and cache hits.
35    seen_inbound_events: RwLock<BTreeMap<i64, HashMap<SocketAddr, u32>>>,
36    /// The ordered timestamp map of certificate IDs and cache hits.
37    seen_inbound_certificates: RwLock<BTreeMap<i64, HashMap<Field<N>, u32>>>,
38    /// The ordered timestamp map of transmission IDs and cache hits.
39    seen_inbound_transmissions: RwLock<BTreeMap<i64, HashMap<TransmissionID<N>, u32>>>,
40    /// The ordered timestamp map of inbound block requests and cache hits.
41    seen_inbound_block_requests: RwLock<BTreeMap<i64, HashMap<SocketAddr, u32>>>,
42    /// The ordered timestamp map of peer IPs and their cache hits on outbound events.
43    seen_outbound_events: RwLock<BTreeMap<i64, HashMap<SocketAddr, u32>>>,
44    /// The ordered timestamp map of peer IPs and their cache hits on certificate requests.
45    seen_outbound_certificates: RwLock<BTreeMap<i64, HashMap<SocketAddr, u32>>>,
46    /// The ordered timestamp map of peer IPs and their cache hits on transmission requests.
47    seen_outbound_transmissions: RwLock<BTreeMap<i64, HashMap<SocketAddr, u32>>>,
48    /// The map of IPs to the number of validators requests.
49    seen_outbound_validators_requests: RwLock<HashMap<SocketAddr, u32>>,
50    /// The ordered timestamp map of outbound block requests and cache hits.
51    seen_outbound_block_requests: RwLock<HashMap<SocketAddr, HashSet<BlockRequest>>>,
52}
53
54impl<N: Network> Default for Cache<N> {
55    /// Initializes a new instance of the cache.
56    fn default() -> Self {
57        Self::new()
58    }
59}
60
61impl<N: Network> Cache<N> {
62    /// Initializes a new instance of the cache.
63    pub fn new() -> Self {
64        Self {
65            seen_inbound_connections: Default::default(),
66            seen_inbound_events: Default::default(),
67            seen_inbound_certificates: Default::default(),
68            seen_inbound_transmissions: Default::default(),
69            seen_inbound_block_requests: Default::default(),
70            seen_outbound_events: Default::default(),
71            seen_outbound_certificates: Default::default(),
72            seen_outbound_transmissions: Default::default(),
73            seen_outbound_validators_requests: Default::default(),
74            seen_outbound_block_requests: Default::default(),
75        }
76    }
77}
78
79impl<N: Network> Cache<N> {
80    /// Inserts a new timestamp for the given peer connection, returning the number of recent connection requests.
81    pub fn insert_inbound_connection(&self, peer_ip: IpAddr, interval_in_secs: i64) -> usize {
82        Self::retain_and_insert(&self.seen_inbound_connections, peer_ip, interval_in_secs)
83    }
84
85    /// Inserts a new timestamp for the given peer, returning the number of recent events.
86    pub fn insert_inbound_event(&self, peer_ip: SocketAddr, interval_in_secs: i64) -> usize {
87        Self::retain_and_insert(&self.seen_inbound_events, peer_ip, interval_in_secs)
88    }
89
90    /// Inserts a certificate ID into the cache, returning the number of recent events.
91    pub fn insert_inbound_certificate(&self, key: Field<N>, interval_in_secs: i64) -> usize {
92        Self::retain_and_insert(&self.seen_inbound_certificates, key, interval_in_secs)
93    }
94
95    /// Inserts a transmission ID into the cache, returning the number of recent events.
96    pub fn insert_inbound_transmission(&self, key: TransmissionID<N>, interval_in_secs: i64) -> usize {
97        Self::retain_and_insert(&self.seen_inbound_transmissions, key, interval_in_secs)
98    }
99
100    /// Inserts a block request into the cache, returning the number of recent events.
101    pub fn insert_inbound_block_request(&self, key: SocketAddr, interval_in_secs: i64) -> usize {
102        Self::retain_and_insert(&self.seen_inbound_block_requests, key, interval_in_secs)
103    }
104}
105
106impl<N: Network> Cache<N> {
107    /// Inserts a new timestamp for the given peer, returning the number of recent events.
108    pub fn insert_outbound_event(&self, peer_ip: SocketAddr, interval_in_secs: i64) -> usize {
109        Self::retain_and_insert(&self.seen_outbound_events, peer_ip, interval_in_secs)
110    }
111
112    /// Inserts a new timestamp for the given peer, returning the number of recent events.
113    pub fn insert_outbound_certificate(&self, peer_ip: SocketAddr, interval_in_secs: i64) -> usize {
114        Self::retain_and_insert(&self.seen_outbound_certificates, peer_ip, interval_in_secs)
115    }
116
117    /// Inserts a new timestamp for the given peer, returning the number of recent events.
118    pub fn insert_outbound_transmission(&self, peer_ip: SocketAddr, interval_in_secs: i64) -> usize {
119        Self::retain_and_insert(&self.seen_outbound_transmissions, peer_ip, interval_in_secs)
120    }
121}
122
123impl<N: Network> Cache<N> {
124    /// Returns `true` if the cache contains a validators request from the given IP.
125    pub fn contains_outbound_validators_request(&self, peer_ip: SocketAddr) -> bool {
126        self.seen_outbound_validators_requests.read().get(&peer_ip).map(|r| *r > 0).unwrap_or(false)
127    }
128
129    /// Increment the IP's number of validators requests, returning the updated number of validators requests.
130    pub fn increment_outbound_validators_requests(&self, peer_ip: SocketAddr) -> u32 {
131        Self::increment_counter(&self.seen_outbound_validators_requests, peer_ip)
132    }
133
134    /// Decrement the IP's number of validators requests, returning the updated number of validators requests.
135    pub fn decrement_outbound_validators_requests(&self, peer_ip: SocketAddr) -> u32 {
136        Self::decrement_counter(&self.seen_outbound_validators_requests, peer_ip)
137    }
138
139    /// Clears the the IP's number of validator requests.
140    pub fn clear_outbound_validators_requests(&self, peer_ip: SocketAddr) {
141        self.seen_outbound_validators_requests.write().remove(&peer_ip);
142    }
143
144    /// Inserts the block request for the given peer.
145    pub fn insert_outbound_block_request(&self, peer_ip: SocketAddr, request: BlockRequest) {
146        self.seen_outbound_block_requests.write().entry(peer_ip).or_default().insert(request);
147    }
148
149    /// Removes the block request for the given peer. Returns whether the request was present.
150    pub fn remove_outbound_block_request(&self, peer_ip: SocketAddr, request: &BlockRequest) -> bool {
151        self.seen_outbound_block_requests
152            .write()
153            .get_mut(&peer_ip)
154            .map(|requests| requests.remove(request))
155            .unwrap_or(false)
156    }
157
158    /// Clears the peer's number of outbound block requests.
159    pub fn clear_outbound_block_requests(&self, peer_ip: SocketAddr) {
160        self.seen_outbound_block_requests.write().remove(&peer_ip);
161    }
162}
163
164impl<N: Network> Cache<N> {
165    /// Insert a new timestamp for the given key, returning the number of recent entries.
166    fn retain_and_insert<K: Copy + Clone + PartialEq + Eq + Hash>(
167        map: &RwLock<BTreeMap<i64, HashMap<K, u32>>>,
168        key: K,
169        interval_in_secs: i64,
170    ) -> usize {
171        // Fetch the current timestamp.
172        let now = OffsetDateTime::now_utc().unix_timestamp();
173
174        // Get the write lock.
175        let mut map_write = map.write();
176        // Insert the new timestamp and increment the frequency for the key.
177        *map_write.entry(now).or_default().entry(key).or_default() += 1;
178        // Calculate the cutoff time for the entries to retain.
179        let cutoff = now.saturating_sub(interval_in_secs);
180        // Obtain the oldest timestamp from the map; it's guaranteed to exist at this point.
181        let (oldest, _) = map_write.first_key_value().unwrap();
182        // Track the number of cache hits of the key.
183        let mut cache_hits = 0;
184        // If the oldest timestamp is above the cutoff value, all the entries can be retained.
185        if cutoff <= *oldest {
186            for cache_keys in map_write.values() {
187                cache_hits += *cache_keys.get(&key).unwrap_or(&0);
188            }
189        } else {
190            // Extract the subtree after interval (i.e. non-expired entries)
191            let retained = map_write.split_off(&cutoff);
192            // Clear all the expired entries.
193            map_write.clear();
194            // Reinsert the entries into map and sum the frequency of recent requests for `key` while looping.
195            for (time, cache_keys) in retained {
196                cache_hits += *cache_keys.get(&key).unwrap_or(&0);
197                map_write.insert(time, cache_keys);
198            }
199        }
200        // Return the frequency.
201        cache_hits as usize
202    }
203
204    /// Increments the key's counter in the map, returning the updated counter.
205    fn increment_counter<K: Hash + Eq>(map: &RwLock<HashMap<K, u32>>, key: K) -> u32 {
206        let mut map_write = map.write();
207        // Load the entry for the key, and increment the counter.
208        let entry = map_write.entry(key).or_default();
209        *entry = entry.saturating_add(1);
210        // Return the updated counter.
211        *entry
212    }
213
214    /// Decrements the key's counter in the map, returning the updated counter.
215    fn decrement_counter<K: Copy + Hash + Eq>(map: &RwLock<HashMap<K, u32>>, key: K) -> u32 {
216        let mut map_write = map.write();
217        // Load the entry for the key, and decrement the counter.
218        let entry = map_write.entry(key).or_default();
219        let value = entry.saturating_sub(1);
220        // If the entry is 0, remove the entry.
221        if *entry == 0 {
222            map_write.remove(&key);
223        } else {
224            *entry = value;
225        }
226        // Return the updated counter.
227        value
228    }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use snarkvm::prelude::MainnetV0;

    use std::{net::Ipv4Addr, thread, time::Duration};

    type CurrentNetwork = MainnetV0;

    /// Provides a fixed sample value of a cache key type.
    trait Input {
        fn input() -> Self;
    }

    impl Input for IpAddr {
        fn input() -> Self {
            IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))
        }
    }

    impl Input for SocketAddr {
        fn input() -> Self {
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 1234)
        }
    }

    impl Input for Field<CurrentNetwork> {
        fn input() -> Self {
            Field::from_u8(1)
        }
    }

    impl Input for TransmissionID<CurrentNetwork> {
        fn input() -> Self {
            TransmissionID::Transaction(Default::default(), Default::default())
        }
    }

    /// The retention interval used by the timestamped cache tests.
    const INTERVAL_IN_SECS: i64 = 3;

    /// Generates one test per timestamped cache field, exercising insertion, counting and expiry.
    macro_rules! test_cache_fields {
        ($($name:ident),*) => {
            $(
                paste::paste! {
                    #[test]
                    fn [<test_seen_ $name s>]() {
                        let cache = Cache::<CurrentNetwork>::default();
                        let input = Input::input();

                        // Check that the cache is empty.
                        assert!(cache.[<seen_ $name s>].read().is_empty());

                        // Insert an input, recent events should be 1.
                        assert_eq!(cache.[<insert_ $name>](input, INTERVAL_IN_SECS), 1);
                        // Wait for 1s so that the next entry doesn't overwrite the first one.
                        thread::sleep(Duration::from_secs(1));
                        // Insert an input, recent events should be 2.
                        assert_eq!(cache.[<insert_ $name>](input, INTERVAL_IN_SECS), 2);
                        // Wait for 1s so that the next entry doesn't overwrite the first one.
                        thread::sleep(Duration::from_secs(1));
                        // Insert an input, recent events should be 3.
                        assert_eq!(cache.[<insert_ $name>](input, INTERVAL_IN_SECS), 3);

                        // Check that the cache contains the input for 3 entries.
                        assert_eq!(cache.[<seen_ $name s>].read().len(), 3);

                        // Insert the input again with a small interval, causing one entry to be removed.
                        cache.[<insert_ $name>](input, 1);
                        // Check that the cache contains the input for 2 entries.
                        assert_eq!(cache.[<seen_ $name s>].read().len(), 2);

                        // Insert the input again with a large interval, causing nothing to be removed.
                        cache.[<insert_ $name>](input, 10);
                        // Check that the cache contains the input for 2 entries.
                        assert_eq!(cache.[<seen_ $name s>].read().len(), 2);

                        // Wait for the input to expire.
                        thread::sleep(Duration::from_secs(INTERVAL_IN_SECS as u64 + 1));

                        // Insert an input again, recent events should be 1.
                        assert_eq!(cache.[<insert_ $name>](input, INTERVAL_IN_SECS), 1);

                        // Check that the cache holds exactly 1 timestamp entry.
                        assert_eq!(cache.[<seen_ $name s>].read().len(), 1);

                        // Check that the cache still contains the input, with a single hit.
                        let counts: u32 = cache.[<seen_ $name s>].read().values().map(|hash_map| hash_map.get(&input).unwrap_or(&0)).cloned().sum();
                        assert_eq!(counts, 1);
                    }
                }
            )*
        }
    }

    test_cache_fields! {
        inbound_connection,
        inbound_event,
        inbound_certificate,
        inbound_transmission,
        inbound_block_request,
        outbound_event,
        outbound_certificate,
        outbound_transmission
    }

    #[test]
    fn test_seen_outbound_validators_requests() {
        let cache = Cache::<CurrentNetwork>::default();
        let input = Input::input();

        // Check the map is empty.
        assert!(!cache.contains_outbound_validators_request(input));

        // Insert some requests, checking the returned counter each time.
        for expected in 1..=3 {
            assert_eq!(cache.increment_outbound_validators_requests(input), expected);
            assert!(cache.contains_outbound_validators_request(input));
        }

        // Remove a request.
        assert_eq!(cache.decrement_outbound_validators_requests(input), 2);
        assert!(cache.contains_outbound_validators_request(input));

        // Clear all requests.
        cache.clear_outbound_validators_requests(input);
        assert!(!cache.contains_outbound_validators_request(input));

        // Decrementing an absent counter saturates at zero.
        assert_eq!(cache.decrement_outbound_validators_requests(input), 0);
        assert!(!cache.contains_outbound_validators_request(input));

        // Decrementing a counter back down to zero leaves no recorded request.
        assert_eq!(cache.increment_outbound_validators_requests(input), 1);
        assert_eq!(cache.decrement_outbound_validators_requests(input), 0);
        assert!(!cache.contains_outbound_validators_request(input));
    }
}