snarkos_node_bft/helpers/
cache.rs1use crate::events::BlockRequest;
17use snarkvm::{console::types::Field, ledger::narwhal::TransmissionID, prelude::Network};
18
19use core::hash::Hash;
20#[cfg(feature = "locktick")]
21use locktick::parking_lot::RwLock;
22#[cfg(not(feature = "locktick"))]
23use parking_lot::RwLock;
24use std::{
25 collections::{BTreeMap, HashMap, HashSet},
26 net::{IpAddr, SocketAddr},
27};
28use time::OffsetDateTime;
29
30#[derive(Debug)]
31pub struct Cache<N: Network> {
32 seen_inbound_connections: RwLock<BTreeMap<i64, HashMap<IpAddr, u32>>>,
34 seen_inbound_events: RwLock<BTreeMap<i64, HashMap<SocketAddr, u32>>>,
36 seen_inbound_certificates: RwLock<BTreeMap<i64, HashMap<Field<N>, u32>>>,
38 seen_inbound_transmissions: RwLock<BTreeMap<i64, HashMap<TransmissionID<N>, u32>>>,
40 seen_inbound_block_requests: RwLock<BTreeMap<i64, HashMap<SocketAddr, u32>>>,
42 seen_outbound_events: RwLock<BTreeMap<i64, HashMap<SocketAddr, u32>>>,
44 seen_outbound_certificates: RwLock<BTreeMap<i64, HashMap<SocketAddr, u32>>>,
46 seen_outbound_transmissions: RwLock<BTreeMap<i64, HashMap<SocketAddr, u32>>>,
48 seen_outbound_validators_requests: RwLock<HashMap<SocketAddr, u32>>,
50 seen_outbound_block_requests: RwLock<HashMap<SocketAddr, HashSet<BlockRequest>>>,
52}
53
54impl<N: Network> Default for Cache<N> {
55 fn default() -> Self {
57 Self::new()
58 }
59}
60
61impl<N: Network> Cache<N> {
62 pub fn new() -> Self {
64 Self {
65 seen_inbound_connections: Default::default(),
66 seen_inbound_events: Default::default(),
67 seen_inbound_certificates: Default::default(),
68 seen_inbound_transmissions: Default::default(),
69 seen_inbound_block_requests: Default::default(),
70 seen_outbound_events: Default::default(),
71 seen_outbound_certificates: Default::default(),
72 seen_outbound_transmissions: Default::default(),
73 seen_outbound_validators_requests: Default::default(),
74 seen_outbound_block_requests: Default::default(),
75 }
76 }
77}
78
79impl<N: Network> Cache<N> {
80 pub fn insert_inbound_connection(&self, peer_ip: IpAddr, interval_in_secs: i64) -> usize {
82 Self::retain_and_insert(&self.seen_inbound_connections, peer_ip, interval_in_secs)
83 }
84
85 pub fn insert_inbound_event(&self, peer_ip: SocketAddr, interval_in_secs: i64) -> usize {
87 Self::retain_and_insert(&self.seen_inbound_events, peer_ip, interval_in_secs)
88 }
89
90 pub fn insert_inbound_certificate(&self, key: Field<N>, interval_in_secs: i64) -> usize {
92 Self::retain_and_insert(&self.seen_inbound_certificates, key, interval_in_secs)
93 }
94
95 pub fn insert_inbound_transmission(&self, key: TransmissionID<N>, interval_in_secs: i64) -> usize {
97 Self::retain_and_insert(&self.seen_inbound_transmissions, key, interval_in_secs)
98 }
99
100 pub fn insert_inbound_block_request(&self, key: SocketAddr, interval_in_secs: i64) -> usize {
102 Self::retain_and_insert(&self.seen_inbound_block_requests, key, interval_in_secs)
103 }
104}
105
106impl<N: Network> Cache<N> {
107 pub fn insert_outbound_event(&self, peer_ip: SocketAddr, interval_in_secs: i64) -> usize {
109 Self::retain_and_insert(&self.seen_outbound_events, peer_ip, interval_in_secs)
110 }
111
112 pub fn insert_outbound_certificate(&self, peer_ip: SocketAddr, interval_in_secs: i64) -> usize {
114 Self::retain_and_insert(&self.seen_outbound_certificates, peer_ip, interval_in_secs)
115 }
116
117 pub fn insert_outbound_transmission(&self, peer_ip: SocketAddr, interval_in_secs: i64) -> usize {
119 Self::retain_and_insert(&self.seen_outbound_transmissions, peer_ip, interval_in_secs)
120 }
121}
122
123impl<N: Network> Cache<N> {
124 pub fn contains_outbound_validators_request(&self, peer_ip: SocketAddr) -> bool {
126 self.seen_outbound_validators_requests.read().get(&peer_ip).map(|r| *r > 0).unwrap_or(false)
127 }
128
129 pub fn increment_outbound_validators_requests(&self, peer_ip: SocketAddr) -> u32 {
131 Self::increment_counter(&self.seen_outbound_validators_requests, peer_ip)
132 }
133
134 pub fn decrement_outbound_validators_requests(&self, peer_ip: SocketAddr) -> u32 {
136 Self::decrement_counter(&self.seen_outbound_validators_requests, peer_ip)
137 }
138
139 pub fn clear_outbound_validators_requests(&self, peer_ip: SocketAddr) {
141 self.seen_outbound_validators_requests.write().remove(&peer_ip);
142 }
143
144 pub fn insert_outbound_block_request(&self, peer_ip: SocketAddr, request: BlockRequest) {
146 self.seen_outbound_block_requests.write().entry(peer_ip).or_default().insert(request);
147 }
148
149 pub fn remove_outbound_block_request(&self, peer_ip: SocketAddr, request: &BlockRequest) -> bool {
151 self.seen_outbound_block_requests
152 .write()
153 .get_mut(&peer_ip)
154 .map(|requests| requests.remove(request))
155 .unwrap_or(false)
156 }
157
158 pub fn clear_outbound_block_requests(&self, peer_ip: SocketAddr) {
160 self.seen_outbound_block_requests.write().remove(&peer_ip);
161 }
162}
163
164impl<N: Network> Cache<N> {
165 fn retain_and_insert<K: Copy + Clone + PartialEq + Eq + Hash>(
167 map: &RwLock<BTreeMap<i64, HashMap<K, u32>>>,
168 key: K,
169 interval_in_secs: i64,
170 ) -> usize {
171 let now = OffsetDateTime::now_utc().unix_timestamp();
173
174 let mut map_write = map.write();
176 *map_write.entry(now).or_default().entry(key).or_default() += 1;
178 let cutoff = now.saturating_sub(interval_in_secs);
180 let (oldest, _) = map_write.first_key_value().unwrap();
182 let mut cache_hits = 0;
184 if cutoff <= *oldest {
186 for cache_keys in map_write.values() {
187 cache_hits += *cache_keys.get(&key).unwrap_or(&0);
188 }
189 } else {
190 let retained = map_write.split_off(&cutoff);
192 map_write.clear();
194 for (time, cache_keys) in retained {
196 cache_hits += *cache_keys.get(&key).unwrap_or(&0);
197 map_write.insert(time, cache_keys);
198 }
199 }
200 cache_hits as usize
202 }
203
204 fn increment_counter<K: Hash + Eq>(map: &RwLock<HashMap<K, u32>>, key: K) -> u32 {
206 let mut map_write = map.write();
207 let entry = map_write.entry(key).or_default();
209 *entry = entry.saturating_add(1);
210 *entry
212 }
213
214 fn decrement_counter<K: Copy + Hash + Eq>(map: &RwLock<HashMap<K, u32>>, key: K) -> u32 {
216 let mut map_write = map.write();
217 let entry = map_write.entry(key).or_default();
219 let value = entry.saturating_sub(1);
220 if *entry == 0 {
222 map_write.remove(&key);
223 } else {
224 *entry = value;
225 }
226 value
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use snarkvm::prelude::MainnetV0;

    use std::{net::Ipv4Addr, thread, time::Duration};

    type CurrentNetwork = MainnetV0;

    /// Provides a fixed sample key for each key type tracked by the cache.
    trait Input {
        fn input() -> Self;
    }

    impl Input for IpAddr {
        fn input() -> Self {
            IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))
        }
    }

    impl Input for SocketAddr {
        fn input() -> Self {
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 1234)
        }
    }

    impl Input for Field<CurrentNetwork> {
        fn input() -> Self {
            Field::from_u8(1)
        }
    }

    impl Input for TransmissionID<CurrentNetwork> {
        fn input() -> Self {
            TransmissionID::Transaction(Default::default(), Default::default())
        }
    }

    const INTERVAL_IN_SECS: i64 = 3;

    /// Generates one windowed-counter test per `$name`, exercising `insert_$name` against
    /// the `seen_$names` map.
    macro_rules! test_cache_fields {
        ($($name:ident),*) => {
            $(
                paste::paste! {
                    #[test]
                    fn [<test_seen_ $name s>]() {
                        let cache = Cache::<CurrentNetwork>::default();
                        let input = Input::input();

                        // Check that the cache starts out empty.
                        assert!(cache.[<seen_ $name s>].read().is_empty());

                        // Insert once per second; each hit lands in its own bucket.
                        assert_eq!(cache.[<insert_ $name>](input, INTERVAL_IN_SECS), 1);
                        thread::sleep(Duration::from_secs(1));
                        assert_eq!(cache.[<insert_ $name>](input, INTERVAL_IN_SECS), 2);
                        thread::sleep(Duration::from_secs(1));
                        assert_eq!(cache.[<insert_ $name>](input, INTERVAL_IN_SECS), 3);

                        assert_eq!(cache.[<seen_ $name s>].read().len(), 3);

                        // A 1-second window evicts the oldest bucket.
                        cache.[<insert_ $name>](input, 1);
                        assert_eq!(cache.[<seen_ $name s>].read().len(), 2);

                        // A wider window evicts nothing further.
                        cache.[<insert_ $name>](input, 10);
                        assert_eq!(cache.[<seen_ $name s>].read().len(), 2);

                        // After the window elapses, only the fresh hit remains.
                        thread::sleep(Duration::from_secs(INTERVAL_IN_SECS as u64 + 1));

                        assert_eq!(cache.[<insert_ $name>](input, INTERVAL_IN_SECS), 1);

                        assert_eq!(cache.[<seen_ $name s>].read().len(), 1);

                        let counts: u32 = cache
                            .[<seen_ $name s>]
                            .read()
                            .values()
                            .map(|hash_map| hash_map.get(&input).unwrap_or(&0))
                            .cloned()
                            .sum();
                        assert_eq!(counts, 1);

                        assert_eq!(cache.[<seen_ $name s>].read().len(), 1);
                    }
                }
            )*
        }
    }

    test_cache_fields! {
        inbound_connection,
        inbound_event,
        inbound_certificate,
        inbound_transmission,
        inbound_block_request,
        outbound_event,
        outbound_certificate,
        outbound_transmission
    }

    #[test]
    fn test_seen_outbound_validators_requests() {
        let cache = Cache::<CurrentNetwork>::default();
        let input = Input::input();

        // Check that the cache starts out empty.
        assert!(!cache.contains_outbound_validators_request(input));

        // Each increment returns the running count.
        for expected in 1..=3u32 {
            assert_eq!(cache.increment_outbound_validators_requests(input), expected);
            assert!(cache.contains_outbound_validators_request(input));
        }

        // A single decrement leaves outstanding requests.
        assert_eq!(cache.decrement_outbound_validators_requests(input), 2);
        assert!(cache.contains_outbound_validators_request(input));

        // Clearing removes the counter entirely.
        cache.clear_outbound_validators_requests(input);
        assert!(!cache.contains_outbound_validators_request(input));

        // Decrementing back down to zero means nothing is outstanding.
        assert_eq!(cache.increment_outbound_validators_requests(input), 1);
        assert_eq!(cache.decrement_outbound_validators_requests(input), 0);
        assert!(!cache.contains_outbound_validators_request(input));

        // Decrementing an absent peer saturates at zero.
        assert_eq!(cache.decrement_outbound_validators_requests(input), 0);
        assert!(!cache.contains_outbound_validators_request(input));
    }
}