1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
// Copyright (C) 2019-2023 Aleo Systems Inc.
// This file is part of the snarkOS library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use snarkvm::{console::types::Field, ledger::narwhal::TransmissionID, prelude::Network};

use core::hash::Hash;
use parking_lot::RwLock;
use std::{
    collections::{BTreeMap, HashMap},
    net::{IpAddr, SocketAddr},
};
use time::OffsetDateTime;

/// Tracks recent inbound and outbound activity, keyed either by peer address or by
/// certificate / transmission ID.
///
/// Each `seen_*` timestamp map is keyed by UNIX timestamp in whole seconds (as recorded by the
/// insertion methods) and stores, for every key seen in that second, the number of hits.
/// Timestamps older than the interval passed to an insertion method are pruned during that
/// insertion.
#[derive(Debug)]
pub struct Cache<N: Network> {
    /// Per-second buckets of inbound connection hits, keyed by peer IP (port ignored).
    seen_inbound_connections: RwLock<BTreeMap<i64, HashMap<IpAddr, u32>>>,
    /// Per-second buckets of inbound event hits, keyed by peer socket address.
    seen_inbound_events: RwLock<BTreeMap<i64, HashMap<SocketAddr, u32>>>,
    /// Per-second buckets of inbound certificate hits, keyed by certificate ID.
    seen_inbound_certificates: RwLock<BTreeMap<i64, HashMap<Field<N>, u32>>>,
    /// Per-second buckets of inbound transmission hits, keyed by transmission ID.
    seen_inbound_transmissions: RwLock<BTreeMap<i64, HashMap<TransmissionID<N>, u32>>>,
    /// Per-second buckets of outbound event hits, keyed by peer socket address.
    seen_outbound_events: RwLock<BTreeMap<i64, HashMap<SocketAddr, u32>>>,
    /// Per-second buckets of outbound certificate-request hits, keyed by peer socket address.
    seen_outbound_certificates: RwLock<BTreeMap<i64, HashMap<SocketAddr, u32>>>,
    /// Per-second buckets of outbound transmission-request hits, keyed by peer socket address.
    seen_outbound_transmissions: RwLock<BTreeMap<i64, HashMap<SocketAddr, u32>>>,
    /// Count of validators requests per peer socket address, adjusted by the
    /// `increment_outbound_validators_requests` / `decrement_outbound_validators_requests`
    /// methods; increments saturate at `u16::MAX`.
    seen_outbound_validators_requests: RwLock<HashMap<SocketAddr, u16>>,
}

impl<N: Network> Default for Cache<N> {
    /// Initializes a new instance of the cache.
    fn default() -> Self {
        Self::new()
    }
}

impl<N: Network> Cache<N> {
    /// Initializes a new instance of the cache.
    pub fn new() -> Self {
        Self {
            seen_inbound_connections: Default::default(),
            seen_inbound_events: Default::default(),
            seen_inbound_certificates: Default::default(),
            seen_inbound_transmissions: Default::default(),
            seen_outbound_events: Default::default(),
            seen_outbound_certificates: Default::default(),
            seen_outbound_transmissions: Default::default(),
            seen_outbound_validators_requests: Default::default(),
        }
    }
}

impl<N: Network> Cache<N> {
    /// Records a connection request from `peer_ip` at the current second and returns how many
    /// connection requests from that IP fall within the last `interval_in_secs` seconds.
    pub fn insert_inbound_connection(&self, peer_ip: IpAddr, interval_in_secs: i64) -> usize {
        let seen = &self.seen_inbound_connections;
        Self::retain_and_insert(seen, peer_ip, interval_in_secs)
    }

    /// Records an inbound event from `peer_ip` at the current second and returns how many
    /// events from that peer fall within the last `interval_in_secs` seconds.
    pub fn insert_inbound_event(&self, peer_ip: SocketAddr, interval_in_secs: i64) -> usize {
        let seen = &self.seen_inbound_events;
        Self::retain_and_insert(seen, peer_ip, interval_in_secs)
    }

    /// Records a hit for the certificate ID `key` at the current second and returns how many
    /// hits for that ID fall within the last `interval_in_secs` seconds.
    pub fn insert_inbound_certificate(&self, key: Field<N>, interval_in_secs: i64) -> usize {
        let seen = &self.seen_inbound_certificates;
        Self::retain_and_insert(seen, key, interval_in_secs)
    }

    /// Records a hit for the transmission ID `key` at the current second and returns how many
    /// hits for that ID fall within the last `interval_in_secs` seconds.
    pub fn insert_inbound_transmission(&self, key: TransmissionID<N>, interval_in_secs: i64) -> usize {
        let seen = &self.seen_inbound_transmissions;
        Self::retain_and_insert(seen, key, interval_in_secs)
    }
}

impl<N: Network> Cache<N> {
    /// Records an outbound event to `peer_ip` at the current second and returns how many
    /// outbound events to that peer fall within the last `interval_in_secs` seconds.
    pub fn insert_outbound_event(&self, peer_ip: SocketAddr, interval_in_secs: i64) -> usize {
        let seen = &self.seen_outbound_events;
        Self::retain_and_insert(seen, peer_ip, interval_in_secs)
    }

    /// Records an outbound certificate request to `peer_ip` at the current second and returns
    /// how many such requests to that peer fall within the last `interval_in_secs` seconds.
    pub fn insert_outbound_certificate(&self, peer_ip: SocketAddr, interval_in_secs: i64) -> usize {
        let seen = &self.seen_outbound_certificates;
        Self::retain_and_insert(seen, peer_ip, interval_in_secs)
    }

    /// Records an outbound transmission request to `peer_ip` at the current second and returns
    /// how many such requests to that peer fall within the last `interval_in_secs` seconds.
    pub fn insert_outbound_transmission(&self, peer_ip: SocketAddr, interval_in_secs: i64) -> usize {
        let seen = &self.seen_outbound_transmissions;
        Self::retain_and_insert(seen, peer_ip, interval_in_secs)
    }
}

impl<N: Network> Cache<N> {
    /// Reports whether `peer_ip` currently has a positive validators-request count.
    pub fn contains_outbound_validators_request(&self, peer_ip: SocketAddr) -> bool {
        let requests = self.seen_outbound_validators_requests.read();
        matches!(requests.get(&peer_ip), Some(&count) if count > 0)
    }

    /// Bumps the validators-request count for `peer_ip` (saturating) and returns the new count.
    pub fn increment_outbound_validators_requests(&self, peer_ip: SocketAddr) -> u16 {
        let counters = &self.seen_outbound_validators_requests;
        Self::increment_counter(counters, peer_ip)
    }

    /// Lowers the validators-request count for `peer_ip` (never below zero) and returns the new
    /// count.
    pub fn decrement_outbound_validators_requests(&self, peer_ip: SocketAddr) -> u16 {
        let counters = &self.seen_outbound_validators_requests;
        Self::decrement_counter(counters, peer_ip)
    }
}

impl<N: Network> Cache<N> {
    /// Records a hit for `key` in the bucket for the current UNIX second and prunes expired
    /// buckets. Returns the number of hits for `key` within the last `interval_in_secs` seconds,
    /// including the one just recorded.
    ///
    /// Buckets strictly older than `now - interval_in_secs` are removed for *all* keys, not only
    /// `key`. A negative `interval_in_secs` puts the cutoff in the future, which empties the map
    /// and yields `0`.
    fn retain_and_insert<K: Copy + Eq + Hash>(
        map: &RwLock<BTreeMap<i64, HashMap<K, u32>>>,
        key: K,
        interval_in_secs: i64,
    ) -> usize {
        // Fetch the current timestamp and the oldest timestamp that still counts as recent.
        let now = OffsetDateTime::now_utc().unix_timestamp();
        let cutoff = now.saturating_sub(interval_in_secs);

        // Hold the write lock for the whole update so insert, prune, and count happen together.
        let mut map_write = map.write();

        // Record the hit; saturate instead of overflowing (which would panic in debug builds).
        let count = map_write.entry(now).or_default().entry(key).or_default();
        *count = count.saturating_add(1);

        // Drop every bucket older than the cutoff. `split_off` returns the buckets at or after
        // `cutoff`, which become the new map; the expired remainder is dropped. Checking the
        // oldest key first skips the split when nothing has expired.
        if map_write.first_key_value().map_or(false, |(oldest, _)| *oldest < cutoff) {
            let retained = map_write.split_off(&cutoff);
            *map_write = retained;
        }

        // Sum the key's hits across the remaining buckets, accumulating in `usize` rather than
        // the per-bucket `u32` so the total does not overflow.
        map_write.values().filter_map(|bucket| bucket.get(&key)).map(|&hits| hits as usize).sum()
    }

    /// Increments the key's counter in the map (saturating at the type's maximum), returning the
    /// updated counter. Absent keys start from zero.
    fn increment_counter<K: Hash + Eq>(map: &RwLock<HashMap<K, u16>>, key: K) -> u16 {
        let mut map_write = map.write();
        // Load the entry for the key, and increment the counter.
        let entry = map_write.entry(key).or_default();
        *entry = entry.saturating_add(1);
        // Return the updated counter.
        *entry
    }

    /// Decrements the key's counter in the map (never below zero), returning the updated counter.
    ///
    /// A counter that reaches zero is removed from the map, so no zero entries are kept.
    /// Decrementing an absent key leaves the map unchanged and returns `0`.
    fn decrement_counter<K: Copy + Hash + Eq>(map: &RwLock<HashMap<K, u16>>, key: K) -> u16 {
        let mut map_write = map.write();
        // Compute the decremented value without inserting a placeholder for absent keys.
        let value = map_write.get(&key).map_or(0, |count| count.saturating_sub(1));
        if value == 0 {
            // Remove exhausted (or absent) counters instead of storing a zero.
            map_write.remove(&key);
        } else {
            map_write.insert(key, value);
        }
        // Return the updated counter.
        value
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use snarkvm::prelude::Testnet3;

    use std::{net::Ipv4Addr, thread, time::Duration};

    type CurrentNetwork = Testnet3;

    /// Produces a fixed, deterministic key of the implementing type for the cache tests.
    trait Input {
        fn input() -> Self;
    }

    impl Input for IpAddr {
        fn input() -> Self {
            IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))
        }
    }

    impl Input for SocketAddr {
        fn input() -> Self {
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 1234)
        }
    }

    impl Input for Field<CurrentNetwork> {
        fn input() -> Self {
            Field::from_u8(1)
        }
    }

    impl Input for TransmissionID<CurrentNetwork> {
        fn input() -> Self {
            TransmissionID::Transaction(Default::default())
        }
    }

    /// The interval (in seconds) passed to the insertion methods in the timestamp-cache tests.
    const INTERVAL_IN_SECS: i64 = 3;

    /// Generates a `test_seen_<name>s` test for each timestamp cache. The test exercises
    /// `insert_<name>` and inspects the `seen_<name>s` field directly.
    macro_rules! test_cache_fields {
        ($($name:ident),*) => {
            $(
                paste::paste! {
                    #[test]
                    fn [<test_seen_ $name s>]() {
                        let cache = Cache::<CurrentNetwork>::default();
                        let input = Input::input();

                        // Check that the cache is empty.
                        assert!(cache.[<seen_ $name s>].read().is_empty());

                        // Insert an input, recent events should be 1.
                        assert_eq!(cache.[<insert_ $name>](input, INTERVAL_IN_SECS), 1);
                        // Wait for 1s so that the next entry doesn't overwrite the first one.
                        thread::sleep(Duration::from_secs(1));
                        // Insert an input, recent events should be 2.
                        assert_eq!(cache.[<insert_ $name>](input, INTERVAL_IN_SECS), 2);
                        // Wait for 1s so that the next entry doesn't overwrite the first one.
                        thread::sleep(Duration::from_secs(1));
                        // Insert an input, recent events should be 3.
                        assert_eq!(cache.[<insert_ $name>](input, INTERVAL_IN_SECS), 3);

                        // Check that the cache contains the input for 3 entries.
                        assert_eq!(cache.[<seen_ $name s>].read().len(), 3);

                        // Insert the input again with a small interval, causing one entry to be removed.
                        cache.[<insert_ $name>](input, 1);
                        // Check that the cache contains the input for 2 entries.
                        assert_eq!(cache.[<seen_ $name s>].read().len(), 2);

                        // Insert the input again with a large interval, causing nothing to be removed.
                        cache.[<insert_ $name>](input, 10);
                        // Check that the cache contains the input for 2 entries.
                        assert_eq!(cache.[<seen_ $name s>].read().len(), 2);

                        // Wait for the input to expire.
                        thread::sleep(Duration::from_secs(INTERVAL_IN_SECS as u64 + 1));

                        // Insert an input again, recent events should be 1.
                        assert_eq!(cache.[<insert_ $name>](input, INTERVAL_IN_SECS), 1);

                        // Check that the cache holds exactly 1 timestamp entry.
                        assert_eq!(cache.[<seen_ $name s>].read().len(), 1);

                        // Check that the remaining entry records exactly one hit for the input.
                        let counts: u32 = cache.[<seen_ $name s>].read().values().filter_map(|hits| hits.get(&input)).sum();
                        assert_eq!(counts, 1);
                    }
                }
            )*
        }
    }

    test_cache_fields! {
        inbound_connection,
        inbound_event,
        inbound_certificate,
        inbound_transmission,
        outbound_event,
        outbound_certificate,
        outbound_transmission
    }

    #[test]
    fn test_seen_outbound_validators_requests() {
        let cache = Cache::<CurrentNetwork>::default();
        let input = SocketAddr::input();

        // Check that the cache is empty.
        assert!(!cache.contains_outbound_validators_request(input));
        assert!(cache.seen_outbound_validators_requests.read().is_empty());

        // Increment twice; the counter should track each increment.
        assert_eq!(cache.increment_outbound_validators_requests(input), 1);
        assert!(cache.contains_outbound_validators_request(input));
        assert_eq!(cache.increment_outbound_validators_requests(input), 2);

        // Decrement back down to zero.
        assert_eq!(cache.decrement_outbound_validators_requests(input), 1);
        assert!(cache.contains_outbound_validators_request(input));
        assert_eq!(cache.decrement_outbound_validators_requests(input), 0);
        assert!(!cache.contains_outbound_validators_request(input));

        // Decrementing an exhausted counter stays at zero and leaves no entry behind.
        assert_eq!(cache.decrement_outbound_validators_requests(input), 0);
        assert!(!cache.contains_outbound_validators_request(input));
        assert!(cache.seen_outbound_validators_requests.read().is_empty());
    }
}