snarkos_node_bft/helpers/
storage.rs

1// Copyright 2024 Aleo Network Foundation
2// This file is part of the snarkOS library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use crate::helpers::{check_timestamp_for_liveness, fmt_id};
17use snarkos_node_bft_ledger_service::LedgerService;
18use snarkos_node_bft_storage_service::StorageService;
19use snarkvm::{
20    ledger::{
21        block::{Block, Transaction},
22        narwhal::{BatchCertificate, BatchHeader, Transmission, TransmissionID},
23    },
24    prelude::{Address, Field, Network, Result, anyhow, bail, ensure},
25};
26
27use indexmap::{IndexMap, IndexSet, map::Entry};
28use lru::LruCache;
29use parking_lot::RwLock;
30use std::{
31    collections::{HashMap, HashSet},
32    num::NonZeroUsize,
33    sync::{
34        Arc,
35        atomic::{AtomicU32, AtomicU64, Ordering},
36    },
37};
38
39#[derive(Clone, Debug)]
40pub struct Storage<N: Network>(Arc<StorageInner<N>>);
41
42impl<N: Network> std::ops::Deref for Storage<N> {
43    type Target = Arc<StorageInner<N>>;
44
45    fn deref(&self) -> &Self::Target {
46        &self.0
47    }
48}
49
50/// The storage for the memory pool.
51///
52/// The storage is used to store the following:
53/// - `current_height` tracker.
54/// - `current_round` tracker.
55/// - `round` to `(certificate ID, batch ID, author)` entries.
56/// - `certificate ID` to `certificate` entries.
57/// - `batch ID` to `round` entries.
58/// - `transmission ID` to `(transmission, certificate IDs)` entries.
59///
60/// The chain of events is as follows:
61/// 1. A `transmission` is received.
62/// 2. After a `batch` is ready to be stored:
63///   - The `certificate` is inserted, triggering updates to the
64///     `rounds`, `certificates`, `batch_ids`, and `transmissions` maps.
65///   - The missing `transmissions` from storage are inserted into the `transmissions` map.
66///   - The certificate ID is inserted into the `transmissions` map.
67/// 3. After a `round` reaches quorum threshold:
68///  - The next round is inserted into the `current_round`.
69#[derive(Debug)]
70pub struct StorageInner<N: Network> {
71    /// The ledger service.
72    ledger: Arc<dyn LedgerService<N>>,
73    /* Once per block */
74    /// The current height.
75    current_height: AtomicU32,
76    /* Once per round */
77    /// The current round.
78    current_round: AtomicU64,
79    /// The `round` for which garbage collection has occurred **up to** (inclusive).
80    gc_round: AtomicU64,
81    /// The maximum number of rounds to keep in storage.
82    max_gc_rounds: u64,
83    /* Once per batch */
84    /// The map of `round` to a list of `(certificate ID, batch ID, author)` entries.
85    rounds: RwLock<IndexMap<u64, IndexSet<(Field<N>, Field<N>, Address<N>)>>>,
86    /// A cache of `certificate ID` to unprocessed `certificate`.
87    unprocessed_certificates: RwLock<LruCache<Field<N>, BatchCertificate<N>>>,
88    /// The map of `certificate ID` to `certificate`.
89    certificates: RwLock<IndexMap<Field<N>, BatchCertificate<N>>>,
90    /// The map of `batch ID` to `round`.
91    batch_ids: RwLock<IndexMap<Field<N>, u64>>,
92    /// The map of `transmission ID` to `(transmission, certificate IDs)` entries.
93    transmissions: Arc<dyn StorageService<N>>,
94}
95
96impl<N: Network> Storage<N> {
97    /// Initializes a new instance of storage.
98    pub fn new(
99        ledger: Arc<dyn LedgerService<N>>,
100        transmissions: Arc<dyn StorageService<N>>,
101        max_gc_rounds: u64,
102    ) -> Self {
103        // Retrieve the current committee.
104        let committee = ledger.current_committee().expect("Ledger is missing a committee.");
105        // Retrieve the current round.
106        let current_round = committee.starting_round().max(1);
107        // Set the unprocessed certificates cache size.
108        let unprocessed_cache_size = NonZeroUsize::new((N::MAX_CERTIFICATES * 2) as usize).unwrap();
109
110        // Return the storage.
111        let storage = Self(Arc::new(StorageInner {
112            ledger,
113            current_height: Default::default(),
114            current_round: Default::default(),
115            gc_round: Default::default(),
116            max_gc_rounds,
117            rounds: Default::default(),
118            unprocessed_certificates: RwLock::new(LruCache::new(unprocessed_cache_size)),
119            certificates: Default::default(),
120            batch_ids: Default::default(),
121            transmissions,
122        }));
123        // Update the storage to the current round.
124        storage.update_current_round(current_round);
125        // Perform GC on the current round.
126        storage.garbage_collect_certificates(current_round);
127        // Return the storage.
128        storage
129    }
130}
131
132impl<N: Network> Storage<N> {
133    /// Returns the current height.
134    pub fn current_height(&self) -> u32 {
135        // Get the current height.
136        self.current_height.load(Ordering::SeqCst)
137    }
138}
139
140impl<N: Network> Storage<N> {
141    /// Returns the current round.
142    pub fn current_round(&self) -> u64 {
143        // Get the current round.
144        self.current_round.load(Ordering::SeqCst)
145    }
146
147    /// Returns the `round` that garbage collection has occurred **up to** (inclusive).
148    pub fn gc_round(&self) -> u64 {
149        // Get the GC round.
150        self.gc_round.load(Ordering::SeqCst)
151    }
152
153    /// Returns the maximum number of rounds to keep in storage.
154    pub fn max_gc_rounds(&self) -> u64 {
155        self.max_gc_rounds
156    }
157
158    /// Increments storage to the next round, updating the current round.
159    /// Note: This method is only called once per round, upon certification of the primary's batch.
160    pub fn increment_to_next_round(&self, current_round: u64) -> Result<u64> {
161        // Determine the next round.
162        let next_round = current_round + 1;
163
164        // Check if the next round is less than the current round in storage.
165        {
166            // Retrieve the storage round.
167            let storage_round = self.current_round();
168            // If the next round is less than the current round in storage, return early with the storage round.
169            if next_round < storage_round {
170                return Ok(storage_round);
171            }
172        }
173
174        // Retrieve the current committee.
175        let current_committee = self.ledger.current_committee()?;
176        // Retrieve the current committee's starting round.
177        let starting_round = current_committee.starting_round();
178        // If the primary is behind the current committee's starting round, sync with the latest block.
179        if next_round < starting_round {
180            // Retrieve the latest block round.
181            let latest_block_round = self.ledger.latest_round();
182            // Log the round sync.
183            info!(
184                "Syncing primary round ({next_round}) with the current committee's starting round ({starting_round}). Syncing with the latest block round {latest_block_round}..."
185            );
186            // Sync the round with the latest block.
187            self.sync_round_with_block(latest_block_round);
188            // Return the latest block round.
189            return Ok(latest_block_round);
190        }
191
192        // Update the storage to the next round.
193        self.update_current_round(next_round);
194
195        #[cfg(feature = "metrics")]
196        metrics::gauge(metrics::bft::LAST_STORED_ROUND, next_round as f64);
197
198        // Retrieve the storage round.
199        let storage_round = self.current_round();
200        // Retrieve the GC round.
201        let gc_round = self.gc_round();
202        // Ensure the next round matches in storage.
203        ensure!(next_round == storage_round, "The next round {next_round} does not match in storage ({storage_round})");
204        // Ensure the next round is greater than or equal to the GC round.
205        ensure!(next_round >= gc_round, "The next round {next_round} is behind the GC round {gc_round}");
206
207        // Log the updated round.
208        info!("Starting round {next_round}...");
209        Ok(next_round)
210    }
211
212    /// Updates the storage to the next round.
213    fn update_current_round(&self, next_round: u64) {
214        // Update the current round.
215        self.current_round.store(next_round, Ordering::SeqCst);
216    }
217
218    /// Update the storage by performing garbage collection based on the next round.
219    pub(crate) fn garbage_collect_certificates(&self, next_round: u64) {
220        // Fetch the current GC round.
221        let current_gc_round = self.gc_round();
222        // Compute the next GC round.
223        let next_gc_round = next_round.saturating_sub(self.max_gc_rounds);
224        // Check if storage needs to be garbage collected.
225        if next_gc_round > current_gc_round {
226            // Remove the GC round(s) from storage.
227            for gc_round in current_gc_round..=next_gc_round {
228                // Iterate over the certificates for the GC round.
229                for id in self.get_certificate_ids_for_round(gc_round).into_iter() {
230                    // Remove the certificate from storage.
231                    self.remove_certificate(id);
232                }
233            }
234            // Update the GC round.
235            self.gc_round.store(next_gc_round, Ordering::SeqCst);
236        }
237    }
238}
239
240impl<N: Network> Storage<N> {
241    /// Returns `true` if the storage contains the specified `round`.
242    pub fn contains_certificates_for_round(&self, round: u64) -> bool {
243        // Check if the round exists in storage.
244        self.rounds.read().contains_key(&round)
245    }
246
247    /// Returns `true` if the storage contains the specified `certificate ID`.
248    pub fn contains_certificate(&self, certificate_id: Field<N>) -> bool {
249        // Check if the certificate ID exists in storage.
250        self.certificates.read().contains_key(&certificate_id)
251    }
252
253    /// Returns `true` if the storage contains a certificate from the specified `author` in the given `round`.
254    pub fn contains_certificate_in_round_from(&self, round: u64, author: Address<N>) -> bool {
255        self.rounds.read().get(&round).map_or(false, |set| set.iter().any(|(_, _, a)| a == &author))
256    }
257
258    /// Returns `true` if the storage contains the specified `certificate ID`.
259    pub fn contains_unprocessed_certificate(&self, certificate_id: Field<N>) -> bool {
260        // Check if the certificate ID exists in storage.
261        self.unprocessed_certificates.read().contains(&certificate_id)
262    }
263
264    /// Returns `true` if the storage contains the specified `batch ID`.
265    pub fn contains_batch(&self, batch_id: Field<N>) -> bool {
266        // Check if the batch ID exists in storage.
267        self.batch_ids.read().contains_key(&batch_id)
268    }
269
270    /// Returns `true` if the storage contains the specified `transmission ID`.
271    pub fn contains_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> bool {
272        self.transmissions.contains_transmission(transmission_id.into())
273    }
274
275    /// Returns the transmission for the given `transmission ID`.
276    /// If the transmission ID does not exist in storage, `None` is returned.
277    pub fn get_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> Option<Transmission<N>> {
278        self.transmissions.get_transmission(transmission_id.into())
279    }
280
281    /// Returns the round for the given `certificate ID`.
282    /// If the certificate ID does not exist in storage, `None` is returned.
283    pub fn get_round_for_certificate(&self, certificate_id: Field<N>) -> Option<u64> {
284        // Get the round.
285        self.certificates.read().get(&certificate_id).map(|certificate| certificate.round())
286    }
287
288    /// Returns the round for the given `batch ID`.
289    /// If the batch ID does not exist in storage, `None` is returned.
290    pub fn get_round_for_batch(&self, batch_id: Field<N>) -> Option<u64> {
291        // Get the round.
292        self.batch_ids.read().get(&batch_id).copied()
293    }
294
295    /// Returns the certificate round for the given `certificate ID`.
296    /// If the certificate ID does not exist in storage, `None` is returned.
297    pub fn get_certificate_round(&self, certificate_id: Field<N>) -> Option<u64> {
298        // Get the batch certificate and return the round.
299        self.certificates.read().get(&certificate_id).map(|certificate| certificate.round())
300    }
301
302    /// Returns the certificate for the given `certificate ID`.
303    /// If the certificate ID does not exist in storage, `None` is returned.
304    pub fn get_certificate(&self, certificate_id: Field<N>) -> Option<BatchCertificate<N>> {
305        // Get the batch certificate.
306        self.certificates.read().get(&certificate_id).cloned()
307    }
308
309    /// Returns the unprocessed certificate for the given `certificate ID`.
310    /// If the certificate ID does not exist in storage, `None` is returned.
311    pub fn get_unprocessed_certificate(&self, certificate_id: Field<N>) -> Option<BatchCertificate<N>> {
312        // Get the unprocessed certificate.
313        self.unprocessed_certificates.read().peek(&certificate_id).cloned()
314    }
315
316    /// Returns the certificate for the given `round` and `author`.
317    /// If the round does not exist in storage, `None` is returned.
318    /// If the author for the round does not exist in storage, `None` is returned.
319    pub fn get_certificate_for_round_with_author(&self, round: u64, author: Address<N>) -> Option<BatchCertificate<N>> {
320        // Retrieve the certificates.
321        if let Some(entries) = self.rounds.read().get(&round) {
322            let certificates = self.certificates.read();
323            entries.iter().find_map(
324                |(certificate_id, _, a)| if a == &author { certificates.get(certificate_id).cloned() } else { None },
325            )
326        } else {
327            Default::default()
328        }
329    }
330
331    /// Returns the certificates for the given `round`.
332    /// If the round does not exist in storage, `None` is returned.
333    pub fn get_certificates_for_round(&self, round: u64) -> IndexSet<BatchCertificate<N>> {
334        // The genesis round does not have batch certificates.
335        if round == 0 {
336            return Default::default();
337        }
338        // Retrieve the certificates.
339        if let Some(entries) = self.rounds.read().get(&round) {
340            let certificates = self.certificates.read();
341            entries.iter().flat_map(|(certificate_id, _, _)| certificates.get(certificate_id).cloned()).collect()
342        } else {
343            Default::default()
344        }
345    }
346
347    /// Returns the certificate IDs for the given `round`.
348    /// If the round does not exist in storage, `None` is returned.
349    pub fn get_certificate_ids_for_round(&self, round: u64) -> IndexSet<Field<N>> {
350        // The genesis round does not have batch certificates.
351        if round == 0 {
352            return Default::default();
353        }
354        // Retrieve the certificates.
355        if let Some(entries) = self.rounds.read().get(&round) {
356            entries.iter().map(|(certificate_id, _, _)| *certificate_id).collect()
357        } else {
358            Default::default()
359        }
360    }
361
362    /// Returns the certificate authors for the given `round`.
363    /// If the round does not exist in storage, `None` is returned.
364    pub fn get_certificate_authors_for_round(&self, round: u64) -> HashSet<Address<N>> {
365        // The genesis round does not have batch certificates.
366        if round == 0 {
367            return Default::default();
368        }
369        // Retrieve the certificates.
370        if let Some(entries) = self.rounds.read().get(&round) {
371            entries.iter().map(|(_, _, author)| *author).collect()
372        } else {
373            Default::default()
374        }
375    }
376
377    /// Returns the certificates that have not yet been included in the ledger.
378    /// Note that the order of this set is by round and then insertion.
379    pub(crate) fn get_pending_certificates(&self) -> IndexSet<BatchCertificate<N>> {
380        let mut pending_certificates = IndexSet::new();
381
382        // Obtain the read locks.
383        let rounds = self.rounds.read();
384        let certificates = self.certificates.read();
385
386        // Iterate over the rounds.
387        for (_, certificates_for_round) in rounds.clone().sorted_by(|a, _, b, _| a.cmp(b)) {
388            // Iterate over the certificates for the round.
389            for (certificate_id, _, _) in certificates_for_round {
390                // Skip the certificate if it already exists in the ledger.
391                if self.ledger.contains_certificate(&certificate_id).unwrap_or(false) {
392                    continue;
393                }
394
395                // Add the certificate to the pending certificates.
396                match certificates.get(&certificate_id).cloned() {
397                    Some(certificate) => pending_certificates.insert(certificate),
398                    None => continue,
399                };
400            }
401        }
402
403        pending_certificates
404    }
405
406    /// Checks the given `batch_header` for validity, returning the missing transmissions from storage.
407    ///
408    /// This method ensures the following invariants:
409    /// - The batch ID does not already exist in storage.
410    /// - The author is a member of the committee for the batch round.
411    /// - The timestamp is within the allowed time range.
412    /// - None of the transmissions are from any past rounds (up to GC).
413    /// - All transmissions declared in the batch header are provided or exist in storage (up to GC).
414    /// - All previous certificates declared in the certificate exist in storage (up to GC).
415    /// - All previous certificates are for the previous round (i.e. round - 1).
416    /// - All previous certificates contain a unique author.
417    /// - The previous certificates reached the quorum threshold (N - f).
418    pub fn check_batch_header(
419        &self,
420        batch_header: &BatchHeader<N>,
421        transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
422        aborted_transmissions: HashSet<TransmissionID<N>>,
423    ) -> Result<HashMap<TransmissionID<N>, Transmission<N>>> {
424        // Retrieve the round.
425        let round = batch_header.round();
426        // Retrieve the GC round.
427        let gc_round = self.gc_round();
428        // Construct a GC log message.
429        let gc_log = format!("(gc = {gc_round})");
430
431        // Ensure the batch ID does not already exist in storage.
432        if self.contains_batch(batch_header.batch_id()) {
433            bail!("Batch for round {round} already exists in storage {gc_log}")
434        }
435
436        // Retrieve the committee lookback for the batch round.
437        let Ok(committee_lookback) = self.ledger.get_committee_lookback_for_round(round) else {
438            bail!("Storage failed to retrieve the committee lookback for round {round} {gc_log}")
439        };
440        // Ensure the author is in the committee.
441        if !committee_lookback.is_committee_member(batch_header.author()) {
442            bail!("Author {} is not in the committee for round {round} {gc_log}", batch_header.author())
443        }
444
445        // Check the timestamp for liveness.
446        check_timestamp_for_liveness(batch_header.timestamp())?;
447
448        // Retrieve the missing transmissions in storage from the given transmissions.
449        let missing_transmissions = self
450            .transmissions
451            .find_missing_transmissions(batch_header, transmissions, aborted_transmissions)
452            .map_err(|e| anyhow!("{e} for round {round} {gc_log}"))?;
453
454        // Compute the previous round.
455        let previous_round = round.saturating_sub(1);
456        // Check if the previous round is within range of the GC round.
457        if previous_round > gc_round {
458            // Retrieve the committee lookback for the previous round.
459            let Ok(previous_committee_lookback) = self.ledger.get_committee_lookback_for_round(previous_round) else {
460                bail!("Missing committee for the previous round {previous_round} in storage {gc_log}")
461            };
462            // Ensure the previous round certificates exists in storage.
463            if !self.contains_certificates_for_round(previous_round) {
464                bail!("Missing certificates for the previous round {previous_round} in storage {gc_log}")
465            }
466            // Ensure the number of previous certificate IDs is at or below the number of committee members.
467            if batch_header.previous_certificate_ids().len() > previous_committee_lookback.num_members() {
468                bail!("Too many previous certificates for round {round} {gc_log}")
469            }
470            // Initialize a set of the previous authors.
471            let mut previous_authors = HashSet::with_capacity(batch_header.previous_certificate_ids().len());
472            // Ensure storage contains all declared previous certificates (up to GC).
473            for previous_certificate_id in batch_header.previous_certificate_ids() {
474                // Retrieve the previous certificate.
475                let Some(previous_certificate) = self.get_certificate(*previous_certificate_id) else {
476                    bail!(
477                        "Missing previous certificate '{}' for certificate in round {round} {gc_log}",
478                        fmt_id(previous_certificate_id)
479                    )
480                };
481                // Ensure the previous certificate is for the previous round.
482                if previous_certificate.round() != previous_round {
483                    bail!("Round {round} certificate contains a round {previous_round} certificate {gc_log}")
484                }
485                // Ensure the previous author is new.
486                if previous_authors.contains(&previous_certificate.author()) {
487                    bail!("Round {round} certificate contains a duplicate author {gc_log}")
488                }
489                // Insert the author of the previous certificate.
490                previous_authors.insert(previous_certificate.author());
491            }
492            // Ensure the previous certificates have reached the quorum threshold.
493            if !previous_committee_lookback.is_quorum_threshold_reached(&previous_authors) {
494                bail!("Previous certificates for a batch in round {round} did not reach quorum threshold {gc_log}")
495            }
496        }
497        Ok(missing_transmissions)
498    }
499
500    /// Checks the given `certificate` for validity, returning the missing transmissions from storage.
501    ///
502    /// This method ensures the following invariants:
503    /// - The certificate ID does not already exist in storage.
504    /// - The batch ID does not already exist in storage.
505    /// - The author is a member of the committee for the batch round.
506    /// - The author has not already created a certificate for the batch round.
507    /// - The timestamp is within the allowed time range.
508    /// - None of the transmissions are from any past rounds (up to GC).
509    /// - All transmissions declared in the batch header are provided or exist in storage (up to GC).
510    /// - All previous certificates declared in the certificate exist in storage (up to GC).
511    /// - All previous certificates are for the previous round (i.e. round - 1).
512    /// - The previous certificates reached the quorum threshold (N - f).
513    /// - The timestamps from the signers are all within the allowed time range.
514    /// - The signers have reached the quorum threshold (N - f).
515    pub fn check_certificate(
516        &self,
517        certificate: &BatchCertificate<N>,
518        transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
519        aborted_transmissions: HashSet<TransmissionID<N>>,
520    ) -> Result<HashMap<TransmissionID<N>, Transmission<N>>> {
521        // Retrieve the round.
522        let round = certificate.round();
523        // Retrieve the GC round.
524        let gc_round = self.gc_round();
525        // Construct a GC log message.
526        let gc_log = format!("(gc = {gc_round})");
527
528        // Ensure the certificate ID does not already exist in storage.
529        if self.contains_certificate(certificate.id()) {
530            bail!("Certificate for round {round} already exists in storage {gc_log}")
531        }
532
533        // Ensure the storage does not already contain a certificate for this author in this round.
534        if self.contains_certificate_in_round_from(round, certificate.author()) {
535            bail!("Certificate with this author for round {round} already exists in storage {gc_log}")
536        }
537
538        // Ensure the batch header is well-formed.
539        let missing_transmissions =
540            self.check_batch_header(certificate.batch_header(), transmissions, aborted_transmissions)?;
541
542        // Check the timestamp for liveness.
543        check_timestamp_for_liveness(certificate.timestamp())?;
544
545        // Retrieve the committee lookback for the batch round.
546        let Ok(committee_lookback) = self.ledger.get_committee_lookback_for_round(round) else {
547            bail!("Storage failed to retrieve the committee for round {round} {gc_log}")
548        };
549
550        // Initialize a set of the signers.
551        let mut signers = HashSet::with_capacity(certificate.signatures().len() + 1);
552        // Append the batch author.
553        signers.insert(certificate.author());
554
555        // Iterate over the signatures.
556        for signature in certificate.signatures() {
557            // Retrieve the signer.
558            let signer = signature.to_address();
559            // Ensure the signer is in the committee.
560            if !committee_lookback.is_committee_member(signer) {
561                bail!("Signer {signer} is not in the committee for round {round} {gc_log}")
562            }
563            // Append the signer.
564            signers.insert(signer);
565        }
566
567        // Ensure the signatures have reached the quorum threshold.
568        if !committee_lookback.is_quorum_threshold_reached(&signers) {
569            bail!("Signatures for a batch in round {round} did not reach quorum threshold {gc_log}")
570        }
571        Ok(missing_transmissions)
572    }
573
574    /// Inserts the given `certificate` into storage.
575    ///
576    /// This method triggers updates to the `rounds`, `certificates`, `batch_ids`, and `transmissions` maps.
577    ///
578    /// This method ensures the following invariants:
579    /// - The certificate ID does not already exist in storage.
580    /// - The batch ID does not already exist in storage.
581    /// - All transmissions declared in the certificate are provided or exist in storage (up to GC).
582    /// - All previous certificates declared in the certificate exist in storage (up to GC).
583    /// - All previous certificates are for the previous round (i.e. round - 1).
584    /// - The previous certificates reached the quorum threshold (N - f).
585    pub fn insert_certificate(
586        &self,
587        certificate: BatchCertificate<N>,
588        transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
589        aborted_transmissions: HashSet<TransmissionID<N>>,
590    ) -> Result<()> {
591        // Ensure the certificate round is above the GC round.
592        ensure!(certificate.round() > self.gc_round(), "Certificate round is at or below the GC round");
593        // Ensure the certificate and its transmissions are valid.
594        let missing_transmissions =
595            self.check_certificate(&certificate, transmissions, aborted_transmissions.clone())?;
596        // Insert the certificate into storage.
597        self.insert_certificate_atomic(certificate, aborted_transmissions, missing_transmissions);
598        Ok(())
599    }
600
601    /// Inserts the given `certificate` into storage.
602    ///
603    /// This method assumes **all missing** transmissions are provided in the `missing_transmissions` map.
604    ///
605    /// This method triggers updates to the `rounds`, `certificates`, `batch_ids`, and `transmissions` maps.
606    fn insert_certificate_atomic(
607        &self,
608        certificate: BatchCertificate<N>,
609        aborted_transmission_ids: HashSet<TransmissionID<N>>,
610        missing_transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
611    ) {
612        // Retrieve the round.
613        let round = certificate.round();
614        // Retrieve the certificate ID.
615        let certificate_id = certificate.id();
616        // Retrieve the batch ID.
617        let batch_id = certificate.batch_id();
618        // Retrieve the author of the batch.
619        let author = certificate.author();
620
621        // Insert the round to certificate ID entry.
622        self.rounds.write().entry(round).or_default().insert((certificate_id, batch_id, author));
623        // Obtain the certificate's transmission ids.
624        let transmission_ids = certificate.transmission_ids().clone();
625        // Insert the certificate.
626        self.certificates.write().insert(certificate_id, certificate);
627        // Remove the unprocessed certificate.
628        self.unprocessed_certificates.write().pop(&certificate_id);
629        // Insert the batch ID.
630        self.batch_ids.write().insert(batch_id, round);
631        // Insert the certificate ID for each of the transmissions into storage.
632        self.transmissions.insert_transmissions(
633            certificate_id,
634            transmission_ids,
635            aborted_transmission_ids,
636            missing_transmissions,
637        );
638    }
639
640    /// Inserts the given unprocessed `certificate` into storage.
641    ///
642    /// This is a temporary storage, which is cleared again when calling `insert_certificate_atomic`.
643    pub fn insert_unprocessed_certificate(&self, certificate: BatchCertificate<N>) -> Result<()> {
644        // Ensure the certificate round is above the GC round.
645        ensure!(certificate.round() > self.gc_round(), "Certificate round is at or below the GC round");
646        // Insert the certificate.
647        self.unprocessed_certificates.write().put(certificate.id(), certificate);
648
649        Ok(())
650    }
651
652    /// Removes the given `certificate ID` from storage.
653    ///
654    /// This method triggers updates to the `rounds`, `certificates`, `batch_ids`, and `transmissions` maps.
655    ///
656    /// If the certificate was successfully removed, `true` is returned.
657    /// If the certificate did not exist in storage, `false` is returned.
658    fn remove_certificate(&self, certificate_id: Field<N>) -> bool {
659        // Retrieve the certificate.
660        let Some(certificate) = self.get_certificate(certificate_id) else {
661            warn!("Certificate {certificate_id} does not exist in storage");
662            return false;
663        };
664        // Retrieve the round.
665        let round = certificate.round();
666        // Retrieve the batch ID.
667        let batch_id = certificate.batch_id();
668        // Compute the author of the batch.
669        let author = certificate.author();
670
671        // TODO (howardwu): We may want to use `shift_remove` below, in order to align compatibility
672        //  with tests written to for `remove_certificate`. However, this will come with performance hits.
673        //  It will be better to write tests that compare the union of the sets.
674
675        // Update the round.
676        match self.rounds.write().entry(round) {
677            Entry::Occupied(mut entry) => {
678                // Remove the round to certificate ID entry.
679                entry.get_mut().swap_remove(&(certificate_id, batch_id, author));
680                // If the round is empty, remove it.
681                if entry.get().is_empty() {
682                    entry.swap_remove();
683                }
684            }
685            Entry::Vacant(_) => {}
686        }
687        // Remove the certificate.
688        self.certificates.write().swap_remove(&certificate_id);
689        // Remove the unprocessed certificate.
690        self.unprocessed_certificates.write().pop(&certificate_id);
691        // Remove the batch ID.
692        self.batch_ids.write().swap_remove(&batch_id);
693        // Remove the transmission entries in the certificate from storage.
694        self.transmissions.remove_transmissions(&certificate_id, certificate.transmission_ids());
695        // Return successfully.
696        true
697    }
698}
699
700impl<N: Network> Storage<N> {
701    /// Syncs the current height with the block.
702    pub(crate) fn sync_height_with_block(&self, next_height: u32) {
703        // If the block height is greater than the current height in storage, sync the height.
704        if next_height > self.current_height() {
705            // Update the current height in storage.
706            self.current_height.store(next_height, Ordering::SeqCst);
707        }
708    }
709
710    /// Syncs the current round with the block.
711    pub(crate) fn sync_round_with_block(&self, next_round: u64) {
712        // Retrieve the current round in the block.
713        let next_round = next_round.max(1);
714        // If the round in the block is greater than the current round in storage, sync the round.
715        if next_round > self.current_round() {
716            // Update the current round in storage.
717            self.update_current_round(next_round);
718            // Log the updated round.
719            info!("Synced to round {next_round}...");
720        }
721    }
722
723    /// Syncs the batch certificate with the block.
724    pub(crate) fn sync_certificate_with_block(
725        &self,
726        block: &Block<N>,
727        certificate: BatchCertificate<N>,
728        unconfirmed_transactions: &HashMap<N::TransactionID, Transaction<N>>,
729    ) {
730        // Skip if the certificate round is below the GC round.
731        if certificate.round() <= self.gc_round() {
732            return;
733        }
734        // If the certificate ID already exists in storage, skip it.
735        if self.contains_certificate(certificate.id()) {
736            return;
737        }
738        // Retrieve the transmissions for the certificate.
739        let mut missing_transmissions = HashMap::new();
740
741        // Retrieve the aborted transmissions for the certificate.
742        let mut aborted_transmissions = HashSet::new();
743
744        // Track the block's aborted solutions and transactions.
745        let aborted_solutions: IndexSet<_> = block.aborted_solution_ids().iter().collect();
746        let aborted_transactions: IndexSet<_> = block.aborted_transaction_ids().iter().collect();
747
748        // Iterate over the transmission IDs.
749        for transmission_id in certificate.transmission_ids() {
750            // If the transmission ID already exists in the map, skip it.
751            if missing_transmissions.contains_key(transmission_id) {
752                continue;
753            }
754            // If the transmission ID exists in storage, skip it.
755            if self.contains_transmission(*transmission_id) {
756                continue;
757            }
758            // Retrieve the transmission.
759            match transmission_id {
760                TransmissionID::Ratification => (),
761                TransmissionID::Solution(solution_id, _) => {
762                    // Retrieve the solution.
763                    match block.get_solution(solution_id) {
764                        // Insert the solution.
765                        Some(solution) => missing_transmissions.insert(*transmission_id, (*solution).into()),
766                        // Otherwise, try to load the solution from the ledger.
767                        None => match self.ledger.get_solution(solution_id) {
768                            // Insert the solution.
769                            Ok(solution) => missing_transmissions.insert(*transmission_id, solution.into()),
770                            // Check if the solution is in the aborted solutions.
771                            Err(_) => {
772                                // Insert the aborted solution if it exists in the block or ledger.
773                                match aborted_solutions.contains(solution_id)
774                                    || self.ledger.contains_transmission(transmission_id).unwrap_or(false)
775                                {
776                                    true => {
777                                        aborted_transmissions.insert(*transmission_id);
778                                    }
779                                    false => error!("Missing solution {solution_id} in block {}", block.height()),
780                                }
781                                continue;
782                            }
783                        },
784                    };
785                }
786                TransmissionID::Transaction(transaction_id, _) => {
787                    // Retrieve the transaction.
788                    match unconfirmed_transactions.get(transaction_id) {
789                        // Insert the transaction.
790                        Some(transaction) => missing_transmissions.insert(*transmission_id, transaction.clone().into()),
791                        // Otherwise, try to load the unconfirmed transaction from the ledger.
792                        None => match self.ledger.get_unconfirmed_transaction(*transaction_id) {
793                            // Insert the transaction.
794                            Ok(transaction) => missing_transmissions.insert(*transmission_id, transaction.into()),
795                            // Check if the transaction is in the aborted transactions.
796                            Err(_) => {
797                                // Insert the aborted transaction if it exists in the block or ledger.
798                                match aborted_transactions.contains(transaction_id)
799                                    || self.ledger.contains_transmission(transmission_id).unwrap_or(false)
800                                {
801                                    true => {
802                                        aborted_transmissions.insert(*transmission_id);
803                                    }
804                                    false => warn!("Missing transaction {transaction_id} in block {}", block.height()),
805                                }
806                                continue;
807                            }
808                        },
809                    };
810                }
811            }
812        }
813        // Insert the batch certificate into storage.
814        let certificate_id = fmt_id(certificate.id());
815        debug!(
816            "Syncing certificate '{certificate_id}' for round {} with {} transmissions",
817            certificate.round(),
818            certificate.transmission_ids().len()
819        );
820        if let Err(error) = self.insert_certificate(certificate, missing_transmissions, aborted_transmissions) {
821            error!("Failed to insert certificate '{certificate_id}' from block {} - {error}", block.height());
822        }
823    }
824}
825
#[cfg(test)]
impl<N: Network> Storage<N> {
    /// Returns a reference to the ledger service used by this storage.
    pub fn ledger(&self) -> &Arc<dyn LedgerService<N>> {
        &self.ledger
    }

    /// Returns an iterator over a cloned snapshot of the
    /// `(round, (certificate ID, batch ID, author))` entries.
    pub fn rounds_iter(&self) -> impl Iterator<Item = (u64, IndexSet<(Field<N>, Field<N>, Address<N>)>)> {
        let snapshot = self.rounds.read().clone();
        snapshot.into_iter()
    }

    /// Returns an iterator over a cloned snapshot of the `(certificate ID, certificate)` entries.
    pub fn certificates_iter(&self) -> impl Iterator<Item = (Field<N>, BatchCertificate<N>)> {
        let snapshot = self.certificates.read().clone();
        snapshot.into_iter()
    }

    /// Returns an iterator over a cloned snapshot of the `(batch ID, round)` entries.
    pub fn batch_ids_iter(&self) -> impl Iterator<Item = (Field<N>, u64)> {
        let snapshot = self.batch_ids.read().clone();
        snapshot.into_iter()
    }

    /// Returns an iterator over the `(transmission ID, (transmission, certificate IDs))` entries.
    pub fn transmissions_iter(
        &self,
    ) -> impl Iterator<Item = (TransmissionID<N>, (Transmission<N>, IndexSet<Field<N>>))> {
        let snapshot = self.transmissions.as_hashmap();
        snapshot.into_iter()
    }

    /// Inserts the given `certificate` into storage, fabricating an empty placeholder
    /// transaction for each of its transmission IDs.
    ///
    /// Note: Do NOT use this in production. This is for **testing only**.
    #[cfg(test)]
    #[doc(hidden)]
    pub(crate) fn testing_only_insert_certificate_testing_only(&self, certificate: BatchCertificate<N>) {
        // Pull out everything that is needed before the certificate is moved into storage.
        let certificate_id = certificate.id();
        let batch_id = certificate.batch_id();
        let round = certificate.round();
        let author = certificate.author();
        let transmission_ids = certificate.transmission_ids().clone();

        // Construct the dummy missing transmissions (for testing purposes).
        let mut missing_transmissions = HashMap::with_capacity(transmission_ids.len());
        for id in &transmission_ids {
            let placeholder = Transmission::Transaction(snarkvm::ledger::narwhal::Data::Buffer(bytes::Bytes::new()));
            missing_transmissions.insert(*id, placeholder);
        }

        // Insert the round to certificate ID entry.
        self.rounds.write().entry(round).or_default().insert((certificate_id, batch_id, author));
        // Insert the certificate.
        self.certificates.write().insert(certificate_id, certificate);
        // Insert the batch ID.
        self.batch_ids.write().insert(batch_id, round);
        // Insert the certificate ID for each of the transmissions into storage.
        self.transmissions.insert_transmissions(
            certificate_id,
            transmission_ids,
            Default::default(),
            missing_transmissions,
        );
    }
}
893
894#[cfg(test)]
895pub(crate) mod tests {
896    use super::*;
897    use snarkos_node_bft_ledger_service::MockLedgerService;
898    use snarkos_node_bft_storage_service::BFTMemoryService;
899    use snarkvm::{
900        ledger::narwhal::Data,
901        prelude::{Rng, TestRng},
902    };
903
904    use ::bytes::Bytes;
905    use indexmap::indexset;
906
907    type CurrentNetwork = snarkvm::prelude::MainnetV0;
908
909    /// Asserts that the storage matches the expected layout.
910    pub fn assert_storage<N: Network>(
911        storage: &Storage<N>,
912        rounds: &[(u64, IndexSet<(Field<N>, Field<N>, Address<N>)>)],
913        certificates: &[(Field<N>, BatchCertificate<N>)],
914        batch_ids: &[(Field<N>, u64)],
915        transmissions: &HashMap<TransmissionID<N>, (Transmission<N>, IndexSet<Field<N>>)>,
916    ) {
917        // Ensure the rounds are well-formed.
918        assert_eq!(storage.rounds_iter().collect::<Vec<_>>(), *rounds);
919        // Ensure the certificates are well-formed.
920        assert_eq!(storage.certificates_iter().collect::<Vec<_>>(), *certificates);
921        // Ensure the batch IDs are well-formed.
922        assert_eq!(storage.batch_ids_iter().collect::<Vec<_>>(), *batch_ids);
923        // Ensure the transmissions are well-formed.
924        assert_eq!(storage.transmissions_iter().collect::<HashMap<_, _>>(), *transmissions);
925    }
926
927    /// Samples a random transmission.
928    fn sample_transmission(rng: &mut TestRng) -> Transmission<CurrentNetwork> {
929        // Sample random fake solution bytes.
930        let s = |rng: &mut TestRng| Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
931        // Sample random fake transaction bytes.
932        let t = |rng: &mut TestRng| Data::Buffer(Bytes::from((0..2048).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
933        // Sample a random transmission.
934        match rng.gen::<bool>() {
935            true => Transmission::Solution(s(rng)),
936            false => Transmission::Transaction(t(rng)),
937        }
938    }
939
940    /// Samples the random transmissions, returning the missing transmissions and the transmissions.
941    pub(crate) fn sample_transmissions(
942        certificate: &BatchCertificate<CurrentNetwork>,
943        rng: &mut TestRng,
944    ) -> (
945        HashMap<TransmissionID<CurrentNetwork>, Transmission<CurrentNetwork>>,
946        HashMap<TransmissionID<CurrentNetwork>, (Transmission<CurrentNetwork>, IndexSet<Field<CurrentNetwork>>)>,
947    ) {
948        // Retrieve the certificate ID.
949        let certificate_id = certificate.id();
950
951        let mut missing_transmissions = HashMap::new();
952        let mut transmissions = HashMap::<_, (_, IndexSet<Field<CurrentNetwork>>)>::new();
953        for transmission_id in certificate.transmission_ids() {
954            // Initialize the transmission.
955            let transmission = sample_transmission(rng);
956            // Update the missing transmissions.
957            missing_transmissions.insert(*transmission_id, transmission.clone());
958            // Update the transmissions map.
959            transmissions
960                .entry(*transmission_id)
961                .or_insert((transmission, Default::default()))
962                .1
963                .insert(certificate_id);
964        }
965        (missing_transmissions, transmissions)
966    }
967
968    // TODO (howardwu): Testing with 'max_gc_rounds' set to '0' should ensure everything is cleared after insertion.
969
970    #[test]
971    fn test_certificate_insert_remove() {
972        let rng = &mut TestRng::default();
973
974        // Sample a committee.
975        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
976        // Initialize the ledger.
977        let ledger = Arc::new(MockLedgerService::new(committee));
978        // Initialize the storage.
979        let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1);
980
981        // Ensure the storage is empty.
982        assert_storage(&storage, &[], &[], &[], &Default::default());
983
984        // Create a new certificate.
985        let certificate = snarkvm::ledger::narwhal::batch_certificate::test_helpers::sample_batch_certificate(rng);
986        // Retrieve the certificate ID.
987        let certificate_id = certificate.id();
988        // Retrieve the round.
989        let round = certificate.round();
990        // Retrieve the batch ID.
991        let batch_id = certificate.batch_id();
992        // Retrieve the author of the batch.
993        let author = certificate.author();
994
995        // Construct the sample 'transmissions'.
996        let (missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);
997
998        // Insert the certificate.
999        storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions);
1000        // Ensure the certificate exists in storage.
1001        assert!(storage.contains_certificate(certificate_id));
1002        // Ensure the certificate is stored in the correct round.
1003        assert_eq!(storage.get_certificates_for_round(round), indexset! { certificate.clone() });
1004        // Ensure the certificate is stored for the correct round and author.
1005        assert_eq!(storage.get_certificate_for_round_with_author(round, author), Some(certificate.clone()));
1006
1007        // Check that the underlying storage representation is correct.
1008        {
1009            // Construct the expected layout for 'rounds'.
1010            let rounds = [(round, indexset! { (certificate_id, batch_id, author) })];
1011            // Construct the expected layout for 'certificates'.
1012            let certificates = [(certificate_id, certificate.clone())];
1013            // Construct the expected layout for 'batch_ids'.
1014            let batch_ids = [(batch_id, round)];
1015            // Assert the storage is well-formed.
1016            assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);
1017        }
1018
1019        // Retrieve the certificate.
1020        let candidate_certificate = storage.get_certificate(certificate_id).unwrap();
1021        // Ensure the retrieved certificate is the same as the inserted certificate.
1022        assert_eq!(certificate, candidate_certificate);
1023
1024        // Remove the certificate.
1025        assert!(storage.remove_certificate(certificate_id));
1026        // Ensure the certificate does not exist in storage.
1027        assert!(!storage.contains_certificate(certificate_id));
1028        // Ensure the certificate is no longer stored in the round.
1029        assert!(storage.get_certificates_for_round(round).is_empty());
1030        // Ensure the certificate is no longer stored for the round and author.
1031        assert_eq!(storage.get_certificate_for_round_with_author(round, author), None);
1032        // Ensure the storage is empty.
1033        assert_storage(&storage, &[], &[], &[], &Default::default());
1034    }
1035
1036    #[test]
1037    fn test_certificate_duplicate() {
1038        let rng = &mut TestRng::default();
1039
1040        // Sample a committee.
1041        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
1042        // Initialize the ledger.
1043        let ledger = Arc::new(MockLedgerService::new(committee));
1044        // Initialize the storage.
1045        let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1);
1046
1047        // Ensure the storage is empty.
1048        assert_storage(&storage, &[], &[], &[], &Default::default());
1049
1050        // Create a new certificate.
1051        let certificate = snarkvm::ledger::narwhal::batch_certificate::test_helpers::sample_batch_certificate(rng);
1052        // Retrieve the certificate ID.
1053        let certificate_id = certificate.id();
1054        // Retrieve the round.
1055        let round = certificate.round();
1056        // Retrieve the batch ID.
1057        let batch_id = certificate.batch_id();
1058        // Retrieve the author of the batch.
1059        let author = certificate.author();
1060
1061        // Construct the expected layout for 'rounds'.
1062        let rounds = [(round, indexset! { (certificate_id, batch_id, author) })];
1063        // Construct the expected layout for 'certificates'.
1064        let certificates = [(certificate_id, certificate.clone())];
1065        // Construct the expected layout for 'batch_ids'.
1066        let batch_ids = [(batch_id, round)];
1067        // Construct the sample 'transmissions'.
1068        let (missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);
1069
1070        // Insert the certificate.
1071        storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions.clone());
1072        // Ensure the certificate exists in storage.
1073        assert!(storage.contains_certificate(certificate_id));
1074        // Check that the underlying storage representation is correct.
1075        assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);
1076
1077        // Insert the certificate again - without any missing transmissions.
1078        storage.insert_certificate_atomic(certificate.clone(), Default::default(), Default::default());
1079        // Ensure the certificate exists in storage.
1080        assert!(storage.contains_certificate(certificate_id));
1081        // Check that the underlying storage representation remains unchanged.
1082        assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);
1083
1084        // Insert the certificate again - with all of the original missing transmissions.
1085        storage.insert_certificate_atomic(certificate, Default::default(), missing_transmissions);
1086        // Ensure the certificate exists in storage.
1087        assert!(storage.contains_certificate(certificate_id));
1088        // Check that the underlying storage representation remains unchanged.
1089        assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);
1090    }
1091}
1092
1093#[cfg(test)]
1094pub mod prop_tests {
1095    use super::*;
1096    use crate::helpers::{now, storage::tests::assert_storage};
1097    use snarkos_node_bft_ledger_service::MockLedgerService;
1098    use snarkos_node_bft_storage_service::BFTMemoryService;
1099    use snarkvm::{
1100        ledger::{
1101            committee::prop_tests::{CommitteeContext, ValidatorSet},
1102            narwhal::{BatchHeader, Data},
1103            puzzle::SolutionID,
1104        },
1105        prelude::{Signature, Uniform},
1106    };
1107
1108    use ::bytes::Bytes;
1109    use indexmap::indexset;
1110    use proptest::{
1111        collection,
1112        prelude::{Arbitrary, BoxedStrategy, Just, Strategy, any},
1113        prop_oneof,
1114        sample::{Selector, size_range},
1115        test_runner::TestRng,
1116    };
1117    use rand::{CryptoRng, Error, Rng, RngCore};
1118    use std::fmt::Debug;
1119    use test_strategy::proptest;
1120
1121    type CurrentNetwork = snarkvm::prelude::MainnetV0;
1122
1123    impl Arbitrary for Storage<CurrentNetwork> {
1124        type Parameters = CommitteeContext;
1125        type Strategy = BoxedStrategy<Storage<CurrentNetwork>>;
1126
1127        fn arbitrary() -> Self::Strategy {
1128            (any::<CommitteeContext>(), 0..BatchHeader::<CurrentNetwork>::MAX_GC_ROUNDS as u64)
1129                .prop_map(|(CommitteeContext(committee, _), gc_rounds)| {
1130                    let ledger = Arc::new(MockLedgerService::new(committee));
1131                    Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), gc_rounds)
1132                })
1133                .boxed()
1134        }
1135
1136        fn arbitrary_with(context: Self::Parameters) -> Self::Strategy {
1137            (Just(context), 0..BatchHeader::<CurrentNetwork>::MAX_GC_ROUNDS as u64)
1138                .prop_map(|(CommitteeContext(committee, _), gc_rounds)| {
1139                    let ledger = Arc::new(MockLedgerService::new(committee));
1140                    Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), gc_rounds)
1141                })
1142                .boxed()
1143        }
1144    }
1145
1146    // The `proptest::TestRng` doesn't implement `rand_core::CryptoRng` trait which is required in snarkVM, so we use a wrapper
1147    #[derive(Debug)]
1148    pub struct CryptoTestRng(TestRng);
1149
1150    impl Arbitrary for CryptoTestRng {
1151        type Parameters = ();
1152        type Strategy = BoxedStrategy<CryptoTestRng>;
1153
1154        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
1155            Just(0).prop_perturb(|_, rng| CryptoTestRng(rng)).boxed()
1156        }
1157    }
1158    impl RngCore for CryptoTestRng {
1159        fn next_u32(&mut self) -> u32 {
1160            self.0.next_u32()
1161        }
1162
1163        fn next_u64(&mut self) -> u64 {
1164            self.0.next_u64()
1165        }
1166
1167        fn fill_bytes(&mut self, dest: &mut [u8]) {
1168            self.0.fill_bytes(dest);
1169        }
1170
1171        fn try_fill_bytes(&mut self, dest: &mut [u8]) -> std::result::Result<(), Error> {
1172            self.0.try_fill_bytes(dest)
1173        }
1174    }
1175
1176    impl CryptoRng for CryptoTestRng {}
1177
1178    #[derive(Debug, Clone)]
1179    pub struct AnyTransmission(pub Transmission<CurrentNetwork>);
1180
1181    impl Arbitrary for AnyTransmission {
1182        type Parameters = ();
1183        type Strategy = BoxedStrategy<AnyTransmission>;
1184
1185        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
1186            any_transmission().prop_map(AnyTransmission).boxed()
1187        }
1188    }
1189
1190    #[derive(Debug, Clone)]
1191    pub struct AnyTransmissionID(pub TransmissionID<CurrentNetwork>);
1192
1193    impl Arbitrary for AnyTransmissionID {
1194        type Parameters = ();
1195        type Strategy = BoxedStrategy<AnyTransmissionID>;
1196
1197        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
1198            any_transmission_id().prop_map(AnyTransmissionID).boxed()
1199        }
1200    }
1201
1202    fn any_transmission() -> BoxedStrategy<Transmission<CurrentNetwork>> {
1203        prop_oneof![
1204            (collection::vec(any::<u8>(), 512..=512))
1205                .prop_map(|bytes| Transmission::Solution(Data::Buffer(Bytes::from(bytes)))),
1206            (collection::vec(any::<u8>(), 2048..=2048))
1207                .prop_map(|bytes| Transmission::Transaction(Data::Buffer(Bytes::from(bytes)))),
1208        ]
1209        .boxed()
1210    }
1211
1212    pub fn any_solution_id() -> BoxedStrategy<SolutionID<CurrentNetwork>> {
1213        Just(0).prop_perturb(|_, rng| CryptoTestRng(rng).gen::<u64>().into()).boxed()
1214    }
1215
1216    pub fn any_transaction_id() -> BoxedStrategy<<CurrentNetwork as Network>::TransactionID> {
1217        Just(0)
1218            .prop_perturb(|_, rng| {
1219                <CurrentNetwork as Network>::TransactionID::from(Field::rand(&mut CryptoTestRng(rng)))
1220            })
1221            .boxed()
1222    }
1223
1224    pub fn any_transmission_id() -> BoxedStrategy<TransmissionID<CurrentNetwork>> {
1225        prop_oneof![
1226            any_transaction_id().prop_perturb(|id, mut rng| TransmissionID::Transaction(
1227                id,
1228                rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>()
1229            )),
1230            any_solution_id().prop_perturb(|id, mut rng| TransmissionID::Solution(
1231                id,
1232                rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>()
1233            )),
1234        ]
1235        .boxed()
1236    }
1237
1238    pub fn sign_batch_header<R: Rng + CryptoRng>(
1239        validator_set: &ValidatorSet,
1240        batch_header: &BatchHeader<CurrentNetwork>,
1241        rng: &mut R,
1242    ) -> IndexSet<Signature<CurrentNetwork>> {
1243        let mut signatures = IndexSet::with_capacity(validator_set.0.len());
1244        for validator in validator_set.0.iter() {
1245            let private_key = validator.private_key;
1246            signatures.insert(private_key.sign(&[batch_header.batch_id()], rng).unwrap());
1247        }
1248        signatures
1249    }
1250
1251    #[proptest]
1252    fn test_certificate_duplicate(
1253        context: CommitteeContext,
1254        #[any(size_range(1..16).lift())] transmissions: Vec<(AnyTransmissionID, AnyTransmission)>,
1255        mut rng: CryptoTestRng,
1256        selector: Selector,
1257    ) {
1258        let CommitteeContext(committee, ValidatorSet(validators)) = context;
1259        let committee_id = committee.id();
1260
1261        // Initialize the storage.
1262        let ledger = Arc::new(MockLedgerService::new(committee));
1263        let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1);
1264
1265        // Ensure the storage is empty.
1266        assert_storage(&storage, &[], &[], &[], &Default::default());
1267
1268        // Create a new certificate.
1269        let signer = selector.select(&validators);
1270
1271        let mut transmission_map = IndexMap::new();
1272
1273        for (AnyTransmissionID(id), AnyTransmission(t)) in transmissions.iter() {
1274            transmission_map.insert(*id, t.clone());
1275        }
1276
1277        let batch_header = BatchHeader::new(
1278            &signer.private_key,
1279            0,
1280            now(),
1281            committee_id,
1282            transmission_map.keys().cloned().collect(),
1283            Default::default(),
1284            &mut rng,
1285        )
1286        .unwrap();
1287
1288        // Remove the author from the validator set passed to create the batch
1289        // certificate, the author should not sign their own batch.
1290        let mut validators = validators.clone();
1291        validators.remove(signer);
1292
1293        let certificate = BatchCertificate::from(
1294            batch_header.clone(),
1295            sign_batch_header(&ValidatorSet(validators), &batch_header, &mut rng),
1296        )
1297        .unwrap();
1298
1299        // Retrieve the certificate ID.
1300        let certificate_id = certificate.id();
1301        let mut internal_transmissions = HashMap::<_, (_, IndexSet<Field<CurrentNetwork>>)>::new();
1302        for (AnyTransmissionID(id), AnyTransmission(t)) in transmissions.iter().cloned() {
1303            internal_transmissions.entry(id).or_insert((t, Default::default())).1.insert(certificate_id);
1304        }
1305
1306        // Retrieve the round.
1307        let round = certificate.round();
1308        // Retrieve the batch ID.
1309        let batch_id = certificate.batch_id();
1310        // Retrieve the author of the batch.
1311        let author = certificate.author();
1312
1313        // Construct the expected layout for 'rounds'.
1314        let rounds = [(round, indexset! { (certificate_id, batch_id, author) })];
1315        // Construct the expected layout for 'certificates'.
1316        let certificates = [(certificate_id, certificate.clone())];
1317        // Construct the expected layout for 'batch_ids'.
1318        let batch_ids = [(batch_id, round)];
1319
1320        // Insert the certificate.
1321        let missing_transmissions: HashMap<TransmissionID<CurrentNetwork>, Transmission<CurrentNetwork>> =
1322            transmission_map.into_iter().collect();
1323        storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions.clone());
1324        // Ensure the certificate exists in storage.
1325        assert!(storage.contains_certificate(certificate_id));
1326        // Check that the underlying storage representation is correct.
1327        assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);
1328
1329        // Insert the certificate again - without any missing transmissions.
1330        storage.insert_certificate_atomic(certificate.clone(), Default::default(), Default::default());
1331        // Ensure the certificate exists in storage.
1332        assert!(storage.contains_certificate(certificate_id));
1333        // Check that the underlying storage representation remains unchanged.
1334        assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);
1335
1336        // Insert the certificate again - with all of the original missing transmissions.
1337        storage.insert_certificate_atomic(certificate, Default::default(), missing_transmissions);
1338        // Ensure the certificate exists in storage.
1339        assert!(storage.contains_certificate(certificate_id));
1340        // Check that the underlying storage representation remains unchanged.
1341        assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);
1342    }
1343}