snarkos_node_bft/helpers/
cache.rs

// Copyright 2024 Aleo Network Foundation
// This file is part of the snarkOS library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

16use crate::events::BlockRequest;
17use snarkvm::{console::types::Field, ledger::narwhal::TransmissionID, prelude::Network};
18
19use core::hash::Hash;
20use parking_lot::RwLock;
21use std::{
22    collections::{BTreeMap, HashMap, HashSet},
23    net::{IpAddr, SocketAddr},
24};
25use time::OffsetDateTime;
26
27#[derive(Debug)]
28pub struct Cache<N: Network> {
29    /// The ordered timestamp map of peer connections and cache hits.
30    seen_inbound_connections: RwLock<BTreeMap<i64, HashMap<IpAddr, u32>>>,
31    /// The ordered timestamp map of peer IPs and cache hits.
32    seen_inbound_events: RwLock<BTreeMap<i64, HashMap<SocketAddr, u32>>>,
33    /// The ordered timestamp map of certificate IDs and cache hits.
34    seen_inbound_certificates: RwLock<BTreeMap<i64, HashMap<Field<N>, u32>>>,
35    /// The ordered timestamp map of transmission IDs and cache hits.
36    seen_inbound_transmissions: RwLock<BTreeMap<i64, HashMap<TransmissionID<N>, u32>>>,
37    /// The ordered timestamp map of inbound block requests and cache hits.
38    seen_inbound_block_requests: RwLock<BTreeMap<i64, HashMap<SocketAddr, u32>>>,
39    /// The ordered timestamp map of peer IPs and their cache hits on outbound events.
40    seen_outbound_events: RwLock<BTreeMap<i64, HashMap<SocketAddr, u32>>>,
41    /// The ordered timestamp map of peer IPs and their cache hits on certificate requests.
42    seen_outbound_certificates: RwLock<BTreeMap<i64, HashMap<SocketAddr, u32>>>,
43    /// The ordered timestamp map of peer IPs and their cache hits on transmission requests.
44    seen_outbound_transmissions: RwLock<BTreeMap<i64, HashMap<SocketAddr, u32>>>,
45    /// The map of IPs to the number of validators requests.
46    seen_outbound_validators_requests: RwLock<HashMap<SocketAddr, u32>>,
47    /// The ordered timestamp map of outbound block requests and cache hits.
48    seen_outbound_block_requests: RwLock<HashMap<SocketAddr, HashSet<BlockRequest>>>,
49}
50
51impl<N: Network> Default for Cache<N> {
52    /// Initializes a new instance of the cache.
53    fn default() -> Self {
54        Self::new()
55    }
56}
57
58impl<N: Network> Cache<N> {
59    /// Initializes a new instance of the cache.
60    pub fn new() -> Self {
61        Self {
62            seen_inbound_connections: Default::default(),
63            seen_inbound_events: Default::default(),
64            seen_inbound_certificates: Default::default(),
65            seen_inbound_transmissions: Default::default(),
66            seen_inbound_block_requests: Default::default(),
67            seen_outbound_events: Default::default(),
68            seen_outbound_certificates: Default::default(),
69            seen_outbound_transmissions: Default::default(),
70            seen_outbound_validators_requests: Default::default(),
71            seen_outbound_block_requests: Default::default(),
72        }
73    }
74}
75
76impl<N: Network> Cache<N> {
77    /// Inserts a new timestamp for the given peer connection, returning the number of recent connection requests.
78    pub fn insert_inbound_connection(&self, peer_ip: IpAddr, interval_in_secs: i64) -> usize {
79        Self::retain_and_insert(&self.seen_inbound_connections, peer_ip, interval_in_secs)
80    }
81
82    /// Inserts a new timestamp for the given peer, returning the number of recent events.
83    pub fn insert_inbound_event(&self, peer_ip: SocketAddr, interval_in_secs: i64) -> usize {
84        Self::retain_and_insert(&self.seen_inbound_events, peer_ip, interval_in_secs)
85    }
86
87    /// Inserts a certificate ID into the cache, returning the number of recent events.
88    pub fn insert_inbound_certificate(&self, key: Field<N>, interval_in_secs: i64) -> usize {
89        Self::retain_and_insert(&self.seen_inbound_certificates, key, interval_in_secs)
90    }
91
92    /// Inserts a transmission ID into the cache, returning the number of recent events.
93    pub fn insert_inbound_transmission(&self, key: TransmissionID<N>, interval_in_secs: i64) -> usize {
94        Self::retain_and_insert(&self.seen_inbound_transmissions, key, interval_in_secs)
95    }
96
97    /// Inserts a block request into the cache, returning the number of recent events.
98    pub fn insert_inbound_block_request(&self, key: SocketAddr, interval_in_secs: i64) -> usize {
99        Self::retain_and_insert(&self.seen_inbound_block_requests, key, interval_in_secs)
100    }
101}
102
103impl<N: Network> Cache<N> {
104    /// Inserts a new timestamp for the given peer, returning the number of recent events.
105    pub fn insert_outbound_event(&self, peer_ip: SocketAddr, interval_in_secs: i64) -> usize {
106        Self::retain_and_insert(&self.seen_outbound_events, peer_ip, interval_in_secs)
107    }
108
109    /// Inserts a new timestamp for the given peer, returning the number of recent events.
110    pub fn insert_outbound_certificate(&self, peer_ip: SocketAddr, interval_in_secs: i64) -> usize {
111        Self::retain_and_insert(&self.seen_outbound_certificates, peer_ip, interval_in_secs)
112    }
113
114    /// Inserts a new timestamp for the given peer, returning the number of recent events.
115    pub fn insert_outbound_transmission(&self, peer_ip: SocketAddr, interval_in_secs: i64) -> usize {
116        Self::retain_and_insert(&self.seen_outbound_transmissions, peer_ip, interval_in_secs)
117    }
118}
119
120impl<N: Network> Cache<N> {
121    /// Returns `true` if the cache contains a validators request from the given IP.
122    pub fn contains_outbound_validators_request(&self, peer_ip: SocketAddr) -> bool {
123        self.seen_outbound_validators_requests.read().get(&peer_ip).map(|r| *r > 0).unwrap_or(false)
124    }
125
126    /// Increment the IP's number of validators requests, returning the updated number of validators requests.
127    pub fn increment_outbound_validators_requests(&self, peer_ip: SocketAddr) -> u32 {
128        Self::increment_counter(&self.seen_outbound_validators_requests, peer_ip)
129    }
130
131    /// Decrement the IP's number of validators requests, returning the updated number of validators requests.
132    pub fn decrement_outbound_validators_requests(&self, peer_ip: SocketAddr) -> u32 {
133        Self::decrement_counter(&self.seen_outbound_validators_requests, peer_ip)
134    }
135
136    /// Clears the the IP's number of validator requests.
137    pub fn clear_outbound_validators_requests(&self, peer_ip: SocketAddr) {
138        self.seen_outbound_validators_requests.write().remove(&peer_ip);
139    }
140
141    /// Inserts the block request for the given peer.
142    pub fn insert_outbound_block_request(&self, peer_ip: SocketAddr, request: BlockRequest) {
143        self.seen_outbound_block_requests.write().entry(peer_ip).or_default().insert(request);
144    }
145
146    /// Removes the block request for the given peer. Returns whether the request was present.
147    pub fn remove_outbound_block_request(&self, peer_ip: SocketAddr, request: &BlockRequest) -> bool {
148        self.seen_outbound_block_requests
149            .write()
150            .get_mut(&peer_ip)
151            .map(|requests| requests.remove(request))
152            .unwrap_or(false)
153    }
154
155    /// Clears the peer's number of outbound block requests.
156    pub fn clear_outbound_block_requests(&self, peer_ip: SocketAddr) {
157        self.seen_outbound_block_requests.write().remove(&peer_ip);
158    }
159}
160
161impl<N: Network> Cache<N> {
162    /// Insert a new timestamp for the given key, returning the number of recent entries.
163    fn retain_and_insert<K: Copy + Clone + PartialEq + Eq + Hash>(
164        map: &RwLock<BTreeMap<i64, HashMap<K, u32>>>,
165        key: K,
166        interval_in_secs: i64,
167    ) -> usize {
168        // Fetch the current timestamp.
169        let now = OffsetDateTime::now_utc().unix_timestamp();
170
171        // Get the write lock.
172        let mut map_write = map.write();
173        // Insert the new timestamp and increment the frequency for the key.
174        *map_write.entry(now).or_default().entry(key).or_default() += 1;
175        // Calculate the cutoff time for the entries to retain.
176        let cutoff = now.saturating_sub(interval_in_secs);
177        // Obtain the oldest timestamp from the map; it's guaranteed to exist at this point.
178        let (oldest, _) = map_write.first_key_value().unwrap();
179        // Track the number of cache hits of the key.
180        let mut cache_hits = 0;
181        // If the oldest timestamp is above the cutoff value, all the entries can be retained.
182        if cutoff <= *oldest {
183            for cache_keys in map_write.values() {
184                cache_hits += *cache_keys.get(&key).unwrap_or(&0);
185            }
186        } else {
187            // Extract the subtree after interval (i.e. non-expired entries)
188            let retained = map_write.split_off(&cutoff);
189            // Clear all the expired entries.
190            map_write.clear();
191            // Reinsert the entries into map and sum the frequency of recent requests for `key` while looping.
192            for (time, cache_keys) in retained {
193                cache_hits += *cache_keys.get(&key).unwrap_or(&0);
194                map_write.insert(time, cache_keys);
195            }
196        }
197        // Return the frequency.
198        cache_hits as usize
199    }
200
201    /// Increments the key's counter in the map, returning the updated counter.
202    fn increment_counter<K: Hash + Eq>(map: &RwLock<HashMap<K, u32>>, key: K) -> u32 {
203        let mut map_write = map.write();
204        // Load the entry for the key, and increment the counter.
205        let entry = map_write.entry(key).or_default();
206        *entry = entry.saturating_add(1);
207        // Return the updated counter.
208        *entry
209    }
210
211    /// Decrements the key's counter in the map, returning the updated counter.
212    fn decrement_counter<K: Copy + Hash + Eq>(map: &RwLock<HashMap<K, u32>>, key: K) -> u32 {
213        let mut map_write = map.write();
214        // Load the entry for the key, and decrement the counter.
215        let entry = map_write.entry(key).or_default();
216        let value = entry.saturating_sub(1);
217        // If the entry is 0, remove the entry.
218        if *entry == 0 {
219            map_write.remove(&key);
220        } else {
221            *entry = value;
222        }
223        // Return the updated counter.
224        value
225    }
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use snarkvm::prelude::MainnetV0;

    use std::{net::Ipv4Addr, thread, time::Duration};

    type CurrentNetwork = MainnetV0;

    /// Provides a fixed sample value for each key type used by the cache, so that the
    /// macro-generated tests below can stay generic over the key type.
    trait Input {
        fn input() -> Self;
    }

    impl Input for IpAddr {
        fn input() -> Self {
            IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))
        }
    }

    impl Input for SocketAddr {
        fn input() -> Self {
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 1234)
        }
    }

    impl Input for Field<CurrentNetwork> {
        fn input() -> Self {
            Field::from_u8(1)
        }
    }

    impl Input for TransmissionID<CurrentNetwork> {
        fn input() -> Self {
            TransmissionID::Transaction(Default::default(), Default::default())
        }
    }

    /// The retention interval (in seconds) used by the generated tests.
    const INTERVAL_IN_SECS: i64 = 3;

    /// Generates one test per name. For a name `x`, the test exercises the `insert_x` method
    /// and inspects the `seen_xs` field of the cache.
    macro_rules! test_cache_fields {
        ($($name:ident),*) => {
            $(
                paste::paste! {
                    #[test]
                    fn [<test_seen_ $name s>]() {
                        let cache = Cache::<CurrentNetwork>::default();
                        let input = Input::input();

                        // Check that the cache is empty.
                        assert!(cache.[<seen_ $name s>].read().is_empty());

                        // Insert an input, recent events should be 1.
                        assert_eq!(cache.[<insert_ $name>](input, INTERVAL_IN_SECS), 1);
                        // Wait for 1s so that the next entry lands in a new timestamp bucket.
                        thread::sleep(Duration::from_secs(1));
                        // Insert an input, recent events should be 2.
                        assert_eq!(cache.[<insert_ $name>](input, INTERVAL_IN_SECS), 2);
                        // Wait for 1s so that the next entry lands in a new timestamp bucket.
                        thread::sleep(Duration::from_secs(1));
                        // Insert an input, recent events should be 3.
                        assert_eq!(cache.[<insert_ $name>](input, INTERVAL_IN_SECS), 3);

                        // Check that the cache contains the input for 3 entries.
                        assert_eq!(cache.[<seen_ $name s>].read().len(), 3);

                        // Insert the input again with a small interval, causing one entry to be removed.
                        cache.[<insert_ $name>](input, 1);
                        // Check that the cache contains the input for 2 entries.
                        assert_eq!(cache.[<seen_ $name s>].read().len(), 2);

                        // Insert the input again with a large interval, causing nothing to be removed.
                        cache.[<insert_ $name>](input, 10);
                        // Check that the cache contains the input for 2 entries.
                        assert_eq!(cache.[<seen_ $name s>].read().len(), 2);

                        // Wait for the input to expire.
                        thread::sleep(Duration::from_secs(INTERVAL_IN_SECS as u64 + 1));

                        // Insert an input again, recent events should be 1.
                        assert_eq!(cache.[<insert_ $name>](input, INTERVAL_IN_SECS), 1);

                        // Check that the cache contains exactly 1 timestamp entry.
                        assert_eq!(cache.[<seen_ $name s>].read().len(), 1);

                        // Check that the cache holds exactly 1 hit for the input.
                        let counts: u32 = cache.[<seen_ $name s>].read().values().map(|hash_map| hash_map.get(&input).unwrap_or(&0)).cloned().sum();
                        assert_eq!(counts, 1);
                    }
                }
            )*
        }
    }

    test_cache_fields! {
       inbound_connection,
       inbound_event,
       inbound_certificate,
       inbound_transmission,
       inbound_block_request,
       outbound_event,
       outbound_certificate,
       outbound_transmission
    }

    #[test]
    fn test_seen_outbound_validators_requests() {
        let cache = Cache::<CurrentNetwork>::default();
        let input = Input::input();

        // Check the map is empty.
        assert!(!cache.contains_outbound_validators_request(input));

        // Insert some requests.
        for _ in 0..3 {
            cache.increment_outbound_validators_requests(input);
            assert!(cache.contains_outbound_validators_request(input));
        }

        // Remove a request.
        cache.decrement_outbound_validators_requests(input);
        assert!(cache.contains_outbound_validators_request(input));

        // Clear all requests.
        cache.clear_outbound_validators_requests(input);
        assert!(!cache.contains_outbound_validators_request(input));
    }
}