1use crate::helpers::{check_timestamp_for_liveness, fmt_id};
17use snarkos_node_bft_ledger_service::LedgerService;
18use snarkos_node_bft_storage_service::StorageService;
19use snarkvm::{
20 ledger::{
21 block::{Block, Transaction},
22 narwhal::{BatchCertificate, BatchHeader, Transmission, TransmissionID},
23 },
24 prelude::{Address, Field, Network, Result, anyhow, bail, ensure},
25};
26
27use indexmap::{IndexMap, IndexSet, map::Entry};
28use lru::LruCache;
29use parking_lot::RwLock;
30use std::{
31 collections::{HashMap, HashSet},
32 num::NonZeroUsize,
33 sync::{
34 Arc,
35 atomic::{AtomicU32, AtomicU64, Ordering},
36 },
37};
38
39#[derive(Clone, Debug)]
40pub struct Storage<N: Network>(Arc<StorageInner<N>>);
41
42impl<N: Network> std::ops::Deref for Storage<N> {
43 type Target = Arc<StorageInner<N>>;
44
45 fn deref(&self) -> &Self::Target {
46 &self.0
47 }
48}
49
50#[derive(Debug)]
70pub struct StorageInner<N: Network> {
71 ledger: Arc<dyn LedgerService<N>>,
73 current_height: AtomicU32,
76 current_round: AtomicU64,
79 gc_round: AtomicU64,
81 max_gc_rounds: u64,
83 rounds: RwLock<IndexMap<u64, IndexSet<(Field<N>, Field<N>, Address<N>)>>>,
86 unprocessed_certificates: RwLock<LruCache<Field<N>, BatchCertificate<N>>>,
88 certificates: RwLock<IndexMap<Field<N>, BatchCertificate<N>>>,
90 batch_ids: RwLock<IndexMap<Field<N>, u64>>,
92 transmissions: Arc<dyn StorageService<N>>,
94}
95
96impl<N: Network> Storage<N> {
97 pub fn new(
99 ledger: Arc<dyn LedgerService<N>>,
100 transmissions: Arc<dyn StorageService<N>>,
101 max_gc_rounds: u64,
102 ) -> Self {
103 let committee = ledger.current_committee().expect("Ledger is missing a committee.");
105 let current_round = committee.starting_round().max(1);
107 let unprocessed_cache_size = NonZeroUsize::new((N::MAX_CERTIFICATES * 2) as usize).unwrap();
109
110 let storage = Self(Arc::new(StorageInner {
112 ledger,
113 current_height: Default::default(),
114 current_round: Default::default(),
115 gc_round: Default::default(),
116 max_gc_rounds,
117 rounds: Default::default(),
118 unprocessed_certificates: RwLock::new(LruCache::new(unprocessed_cache_size)),
119 certificates: Default::default(),
120 batch_ids: Default::default(),
121 transmissions,
122 }));
123 storage.update_current_round(current_round);
125 storage.garbage_collect_certificates(current_round);
127 storage
129 }
130}
131
132impl<N: Network> Storage<N> {
133 pub fn current_height(&self) -> u32 {
135 self.current_height.load(Ordering::SeqCst)
137 }
138}
139
140impl<N: Network> Storage<N> {
141 pub fn current_round(&self) -> u64 {
143 self.current_round.load(Ordering::SeqCst)
145 }
146
147 pub fn gc_round(&self) -> u64 {
149 self.gc_round.load(Ordering::SeqCst)
151 }
152
153 pub fn max_gc_rounds(&self) -> u64 {
155 self.max_gc_rounds
156 }
157
158 pub fn increment_to_next_round(&self, current_round: u64) -> Result<u64> {
161 let next_round = current_round + 1;
163
164 {
166 let storage_round = self.current_round();
168 if next_round < storage_round {
170 return Ok(storage_round);
171 }
172 }
173
174 let current_committee = self.ledger.current_committee()?;
176 let starting_round = current_committee.starting_round();
178 if next_round < starting_round {
180 let latest_block_round = self.ledger.latest_round();
182 info!(
184 "Syncing primary round ({next_round}) with the current committee's starting round ({starting_round}). Syncing with the latest block round {latest_block_round}..."
185 );
186 self.sync_round_with_block(latest_block_round);
188 return Ok(latest_block_round);
190 }
191
192 self.update_current_round(next_round);
194
195 #[cfg(feature = "metrics")]
196 metrics::gauge(metrics::bft::LAST_STORED_ROUND, next_round as f64);
197
198 let storage_round = self.current_round();
200 let gc_round = self.gc_round();
202 ensure!(next_round == storage_round, "The next round {next_round} does not match in storage ({storage_round})");
204 ensure!(next_round >= gc_round, "The next round {next_round} is behind the GC round {gc_round}");
206
207 info!("Starting round {next_round}...");
209 Ok(next_round)
210 }
211
212 fn update_current_round(&self, next_round: u64) {
214 self.current_round.store(next_round, Ordering::SeqCst);
216 }
217
218 pub(crate) fn garbage_collect_certificates(&self, next_round: u64) {
220 let current_gc_round = self.gc_round();
222 let next_gc_round = next_round.saturating_sub(self.max_gc_rounds);
224 if next_gc_round > current_gc_round {
226 for gc_round in current_gc_round..=next_gc_round {
228 for id in self.get_certificate_ids_for_round(gc_round).into_iter() {
230 self.remove_certificate(id);
232 }
233 }
234 self.gc_round.store(next_gc_round, Ordering::SeqCst);
236 }
237 }
238}
239
240impl<N: Network> Storage<N> {
241 pub fn contains_certificates_for_round(&self, round: u64) -> bool {
243 self.rounds.read().contains_key(&round)
245 }
246
247 pub fn contains_certificate(&self, certificate_id: Field<N>) -> bool {
249 self.certificates.read().contains_key(&certificate_id)
251 }
252
253 pub fn contains_certificate_in_round_from(&self, round: u64, author: Address<N>) -> bool {
255 self.rounds.read().get(&round).map_or(false, |set| set.iter().any(|(_, _, a)| a == &author))
256 }
257
258 pub fn contains_unprocessed_certificate(&self, certificate_id: Field<N>) -> bool {
260 self.unprocessed_certificates.read().contains(&certificate_id)
262 }
263
264 pub fn contains_batch(&self, batch_id: Field<N>) -> bool {
266 self.batch_ids.read().contains_key(&batch_id)
268 }
269
270 pub fn contains_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> bool {
272 self.transmissions.contains_transmission(transmission_id.into())
273 }
274
275 pub fn get_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> Option<Transmission<N>> {
278 self.transmissions.get_transmission(transmission_id.into())
279 }
280
281 pub fn get_round_for_certificate(&self, certificate_id: Field<N>) -> Option<u64> {
284 self.certificates.read().get(&certificate_id).map(|certificate| certificate.round())
286 }
287
288 pub fn get_round_for_batch(&self, batch_id: Field<N>) -> Option<u64> {
291 self.batch_ids.read().get(&batch_id).copied()
293 }
294
295 pub fn get_certificate_round(&self, certificate_id: Field<N>) -> Option<u64> {
298 self.certificates.read().get(&certificate_id).map(|certificate| certificate.round())
300 }
301
302 pub fn get_certificate(&self, certificate_id: Field<N>) -> Option<BatchCertificate<N>> {
305 self.certificates.read().get(&certificate_id).cloned()
307 }
308
309 pub fn get_unprocessed_certificate(&self, certificate_id: Field<N>) -> Option<BatchCertificate<N>> {
312 self.unprocessed_certificates.read().peek(&certificate_id).cloned()
314 }
315
316 pub fn get_certificate_for_round_with_author(&self, round: u64, author: Address<N>) -> Option<BatchCertificate<N>> {
320 if let Some(entries) = self.rounds.read().get(&round) {
322 let certificates = self.certificates.read();
323 entries.iter().find_map(
324 |(certificate_id, _, a)| if a == &author { certificates.get(certificate_id).cloned() } else { None },
325 )
326 } else {
327 Default::default()
328 }
329 }
330
331 pub fn get_certificates_for_round(&self, round: u64) -> IndexSet<BatchCertificate<N>> {
334 if round == 0 {
336 return Default::default();
337 }
338 if let Some(entries) = self.rounds.read().get(&round) {
340 let certificates = self.certificates.read();
341 entries.iter().flat_map(|(certificate_id, _, _)| certificates.get(certificate_id).cloned()).collect()
342 } else {
343 Default::default()
344 }
345 }
346
347 pub fn get_certificate_ids_for_round(&self, round: u64) -> IndexSet<Field<N>> {
350 if round == 0 {
352 return Default::default();
353 }
354 if let Some(entries) = self.rounds.read().get(&round) {
356 entries.iter().map(|(certificate_id, _, _)| *certificate_id).collect()
357 } else {
358 Default::default()
359 }
360 }
361
362 pub fn get_certificate_authors_for_round(&self, round: u64) -> HashSet<Address<N>> {
365 if round == 0 {
367 return Default::default();
368 }
369 if let Some(entries) = self.rounds.read().get(&round) {
371 entries.iter().map(|(_, _, author)| *author).collect()
372 } else {
373 Default::default()
374 }
375 }
376
377 pub(crate) fn get_pending_certificates(&self) -> IndexSet<BatchCertificate<N>> {
380 let mut pending_certificates = IndexSet::new();
381
382 let rounds = self.rounds.read();
384 let certificates = self.certificates.read();
385
386 for (_, certificates_for_round) in rounds.clone().sorted_by(|a, _, b, _| a.cmp(b)) {
388 for (certificate_id, _, _) in certificates_for_round {
390 if self.ledger.contains_certificate(&certificate_id).unwrap_or(false) {
392 continue;
393 }
394
395 match certificates.get(&certificate_id).cloned() {
397 Some(certificate) => pending_certificates.insert(certificate),
398 None => continue,
399 };
400 }
401 }
402
403 pending_certificates
404 }
405
406 pub fn check_batch_header(
419 &self,
420 batch_header: &BatchHeader<N>,
421 transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
422 aborted_transmissions: HashSet<TransmissionID<N>>,
423 ) -> Result<HashMap<TransmissionID<N>, Transmission<N>>> {
424 let round = batch_header.round();
426 let gc_round = self.gc_round();
428 let gc_log = format!("(gc = {gc_round})");
430
431 if self.contains_batch(batch_header.batch_id()) {
433 bail!("Batch for round {round} already exists in storage {gc_log}")
434 }
435
436 let Ok(committee_lookback) = self.ledger.get_committee_lookback_for_round(round) else {
438 bail!("Storage failed to retrieve the committee lookback for round {round} {gc_log}")
439 };
440 if !committee_lookback.is_committee_member(batch_header.author()) {
442 bail!("Author {} is not in the committee for round {round} {gc_log}", batch_header.author())
443 }
444
445 check_timestamp_for_liveness(batch_header.timestamp())?;
447
448 let missing_transmissions = self
450 .transmissions
451 .find_missing_transmissions(batch_header, transmissions, aborted_transmissions)
452 .map_err(|e| anyhow!("{e} for round {round} {gc_log}"))?;
453
454 let previous_round = round.saturating_sub(1);
456 if previous_round > gc_round {
458 let Ok(previous_committee_lookback) = self.ledger.get_committee_lookback_for_round(previous_round) else {
460 bail!("Missing committee for the previous round {previous_round} in storage {gc_log}")
461 };
462 if !self.contains_certificates_for_round(previous_round) {
464 bail!("Missing certificates for the previous round {previous_round} in storage {gc_log}")
465 }
466 if batch_header.previous_certificate_ids().len() > previous_committee_lookback.num_members() {
468 bail!("Too many previous certificates for round {round} {gc_log}")
469 }
470 let mut previous_authors = HashSet::with_capacity(batch_header.previous_certificate_ids().len());
472 for previous_certificate_id in batch_header.previous_certificate_ids() {
474 let Some(previous_certificate) = self.get_certificate(*previous_certificate_id) else {
476 bail!(
477 "Missing previous certificate '{}' for certificate in round {round} {gc_log}",
478 fmt_id(previous_certificate_id)
479 )
480 };
481 if previous_certificate.round() != previous_round {
483 bail!("Round {round} certificate contains a round {previous_round} certificate {gc_log}")
484 }
485 if previous_authors.contains(&previous_certificate.author()) {
487 bail!("Round {round} certificate contains a duplicate author {gc_log}")
488 }
489 previous_authors.insert(previous_certificate.author());
491 }
492 if !previous_committee_lookback.is_quorum_threshold_reached(&previous_authors) {
494 bail!("Previous certificates for a batch in round {round} did not reach quorum threshold {gc_log}")
495 }
496 }
497 Ok(missing_transmissions)
498 }
499
500 pub fn check_certificate(
516 &self,
517 certificate: &BatchCertificate<N>,
518 transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
519 aborted_transmissions: HashSet<TransmissionID<N>>,
520 ) -> Result<HashMap<TransmissionID<N>, Transmission<N>>> {
521 let round = certificate.round();
523 let gc_round = self.gc_round();
525 let gc_log = format!("(gc = {gc_round})");
527
528 if self.contains_certificate(certificate.id()) {
530 bail!("Certificate for round {round} already exists in storage {gc_log}")
531 }
532
533 if self.contains_certificate_in_round_from(round, certificate.author()) {
535 bail!("Certificate with this author for round {round} already exists in storage {gc_log}")
536 }
537
538 let missing_transmissions =
540 self.check_batch_header(certificate.batch_header(), transmissions, aborted_transmissions)?;
541
542 check_timestamp_for_liveness(certificate.timestamp())?;
544
545 let Ok(committee_lookback) = self.ledger.get_committee_lookback_for_round(round) else {
547 bail!("Storage failed to retrieve the committee for round {round} {gc_log}")
548 };
549
550 let mut signers = HashSet::with_capacity(certificate.signatures().len() + 1);
552 signers.insert(certificate.author());
554
555 for signature in certificate.signatures() {
557 let signer = signature.to_address();
559 if !committee_lookback.is_committee_member(signer) {
561 bail!("Signer {signer} is not in the committee for round {round} {gc_log}")
562 }
563 signers.insert(signer);
565 }
566
567 if !committee_lookback.is_quorum_threshold_reached(&signers) {
569 bail!("Signatures for a batch in round {round} did not reach quorum threshold {gc_log}")
570 }
571 Ok(missing_transmissions)
572 }
573
574 pub fn insert_certificate(
586 &self,
587 certificate: BatchCertificate<N>,
588 transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
589 aborted_transmissions: HashSet<TransmissionID<N>>,
590 ) -> Result<()> {
591 ensure!(certificate.round() > self.gc_round(), "Certificate round is at or below the GC round");
593 let missing_transmissions =
595 self.check_certificate(&certificate, transmissions, aborted_transmissions.clone())?;
596 self.insert_certificate_atomic(certificate, aborted_transmissions, missing_transmissions);
598 Ok(())
599 }
600
601 fn insert_certificate_atomic(
607 &self,
608 certificate: BatchCertificate<N>,
609 aborted_transmission_ids: HashSet<TransmissionID<N>>,
610 missing_transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
611 ) {
612 let round = certificate.round();
614 let certificate_id = certificate.id();
616 let batch_id = certificate.batch_id();
618 let author = certificate.author();
620
621 self.rounds.write().entry(round).or_default().insert((certificate_id, batch_id, author));
623 let transmission_ids = certificate.transmission_ids().clone();
625 self.certificates.write().insert(certificate_id, certificate);
627 self.unprocessed_certificates.write().pop(&certificate_id);
629 self.batch_ids.write().insert(batch_id, round);
631 self.transmissions.insert_transmissions(
633 certificate_id,
634 transmission_ids,
635 aborted_transmission_ids,
636 missing_transmissions,
637 );
638 }
639
640 pub fn insert_unprocessed_certificate(&self, certificate: BatchCertificate<N>) -> Result<()> {
644 ensure!(certificate.round() > self.gc_round(), "Certificate round is at or below the GC round");
646 self.unprocessed_certificates.write().put(certificate.id(), certificate);
648
649 Ok(())
650 }
651
652 fn remove_certificate(&self, certificate_id: Field<N>) -> bool {
659 let Some(certificate) = self.get_certificate(certificate_id) else {
661 warn!("Certificate {certificate_id} does not exist in storage");
662 return false;
663 };
664 let round = certificate.round();
666 let batch_id = certificate.batch_id();
668 let author = certificate.author();
670
671 match self.rounds.write().entry(round) {
677 Entry::Occupied(mut entry) => {
678 entry.get_mut().swap_remove(&(certificate_id, batch_id, author));
680 if entry.get().is_empty() {
682 entry.swap_remove();
683 }
684 }
685 Entry::Vacant(_) => {}
686 }
687 self.certificates.write().swap_remove(&certificate_id);
689 self.unprocessed_certificates.write().pop(&certificate_id);
691 self.batch_ids.write().swap_remove(&batch_id);
693 self.transmissions.remove_transmissions(&certificate_id, certificate.transmission_ids());
695 true
697 }
698}
699
700impl<N: Network> Storage<N> {
701 pub(crate) fn sync_height_with_block(&self, next_height: u32) {
703 if next_height > self.current_height() {
705 self.current_height.store(next_height, Ordering::SeqCst);
707 }
708 }
709
710 pub(crate) fn sync_round_with_block(&self, next_round: u64) {
712 let next_round = next_round.max(1);
714 if next_round > self.current_round() {
716 self.update_current_round(next_round);
718 info!("Synced to round {next_round}...");
720 }
721 }
722
723 pub(crate) fn sync_certificate_with_block(
725 &self,
726 block: &Block<N>,
727 certificate: BatchCertificate<N>,
728 unconfirmed_transactions: &HashMap<N::TransactionID, Transaction<N>>,
729 ) {
730 if certificate.round() <= self.gc_round() {
732 return;
733 }
734 if self.contains_certificate(certificate.id()) {
736 return;
737 }
738 let mut missing_transmissions = HashMap::new();
740
741 let mut aborted_transmissions = HashSet::new();
743
744 let aborted_solutions: IndexSet<_> = block.aborted_solution_ids().iter().collect();
746 let aborted_transactions: IndexSet<_> = block.aborted_transaction_ids().iter().collect();
747
748 for transmission_id in certificate.transmission_ids() {
750 if missing_transmissions.contains_key(transmission_id) {
752 continue;
753 }
754 if self.contains_transmission(*transmission_id) {
756 continue;
757 }
758 match transmission_id {
760 TransmissionID::Ratification => (),
761 TransmissionID::Solution(solution_id, _) => {
762 match block.get_solution(solution_id) {
764 Some(solution) => missing_transmissions.insert(*transmission_id, (*solution).into()),
766 None => match self.ledger.get_solution(solution_id) {
768 Ok(solution) => missing_transmissions.insert(*transmission_id, solution.into()),
770 Err(_) => {
772 match aborted_solutions.contains(solution_id)
774 || self.ledger.contains_transmission(transmission_id).unwrap_or(false)
775 {
776 true => {
777 aborted_transmissions.insert(*transmission_id);
778 }
779 false => error!("Missing solution {solution_id} in block {}", block.height()),
780 }
781 continue;
782 }
783 },
784 };
785 }
786 TransmissionID::Transaction(transaction_id, _) => {
787 match unconfirmed_transactions.get(transaction_id) {
789 Some(transaction) => missing_transmissions.insert(*transmission_id, transaction.clone().into()),
791 None => match self.ledger.get_unconfirmed_transaction(*transaction_id) {
793 Ok(transaction) => missing_transmissions.insert(*transmission_id, transaction.into()),
795 Err(_) => {
797 match aborted_transactions.contains(transaction_id)
799 || self.ledger.contains_transmission(transmission_id).unwrap_or(false)
800 {
801 true => {
802 aborted_transmissions.insert(*transmission_id);
803 }
804 false => warn!("Missing transaction {transaction_id} in block {}", block.height()),
805 }
806 continue;
807 }
808 },
809 };
810 }
811 }
812 }
813 let certificate_id = fmt_id(certificate.id());
815 debug!(
816 "Syncing certificate '{certificate_id}' for round {} with {} transmissions",
817 certificate.round(),
818 certificate.transmission_ids().len()
819 );
820 if let Err(error) = self.insert_certificate(certificate, missing_transmissions, aborted_transmissions) {
821 error!("Failed to insert certificate '{certificate_id}' from block {} - {error}", block.height());
822 }
823 }
824}
825
#[cfg(test)]
impl<N: Network> Storage<N> {
    /// Returns the ledger service backing this storage.
    pub fn ledger(&self) -> &Arc<dyn LedgerService<N>> {
        &self.ledger
    }

    /// Returns an iterator over a snapshot of the `round -> {(certificate_id, batch_id, author)}` index.
    pub fn rounds_iter(&self) -> impl Iterator<Item = (u64, IndexSet<(Field<N>, Field<N>, Address<N>)>)> {
        self.rounds.read().clone().into_iter()
    }

    /// Returns an iterator over a snapshot of the `certificate_id -> certificate` map.
    pub fn certificates_iter(&self) -> impl Iterator<Item = (Field<N>, BatchCertificate<N>)> {
        self.certificates.read().clone().into_iter()
    }

    /// Returns an iterator over a snapshot of the `batch_id -> round` map.
    pub fn batch_ids_iter(&self) -> impl Iterator<Item = (Field<N>, u64)> {
        self.batch_ids.read().clone().into_iter()
    }

    /// Returns an iterator over the transmission service's `id -> (transmission, certificate_ids)` map.
    pub fn transmissions_iter(
        &self,
    ) -> impl Iterator<Item = (TransmissionID<N>, (Transmission<N>, IndexSet<Field<N>>))> {
        self.transmissions.as_hashmap().into_iter()
    }

    /// Inserts `certificate` directly into storage, bypassing all validation.
    ///
    /// Every transmission ID of the certificate is stored with an empty transaction buffer as its
    /// payload. Unlike `insert_certificate_atomic`, the unprocessed cache is not touched.
    // The enclosing impl is already `#[cfg(test)]`, so the item needs no extra `cfg` of its own.
    #[doc(hidden)]
    pub(crate) fn testing_only_insert_certificate_testing_only(&self, certificate: BatchCertificate<N>) {
        let round = certificate.round();
        let certificate_id = certificate.id();
        let batch_id = certificate.batch_id();
        let author = certificate.author();

        self.rounds.write().entry(round).or_default().insert((certificate_id, batch_id, author));
        let transmission_ids = certificate.transmission_ids().clone();
        self.certificates.write().insert(certificate_id, certificate);
        self.batch_ids.write().insert(batch_id, round);

        // Placeholder payloads: an empty transaction buffer for every transmission ID.
        let missing_transmissions = transmission_ids
            .iter()
            .map(|id| (*id, Transmission::Transaction(snarkvm::ledger::narwhal::Data::Buffer(bytes::Bytes::new()))))
            .collect::<HashMap<_, _>>();
        self.transmissions.insert_transmissions(
            certificate_id,
            transmission_ids,
            Default::default(),
            missing_transmissions,
        );
    }
}
893
894#[cfg(test)]
895pub(crate) mod tests {
896 use super::*;
897 use snarkos_node_bft_ledger_service::MockLedgerService;
898 use snarkos_node_bft_storage_service::BFTMemoryService;
899 use snarkvm::{
900 ledger::narwhal::Data,
901 prelude::{Rng, TestRng},
902 };
903
904 use ::bytes::Bytes;
905 use indexmap::indexset;
906
907 type CurrentNetwork = snarkvm::prelude::MainnetV0;
908
909 pub fn assert_storage<N: Network>(
911 storage: &Storage<N>,
912 rounds: &[(u64, IndexSet<(Field<N>, Field<N>, Address<N>)>)],
913 certificates: &[(Field<N>, BatchCertificate<N>)],
914 batch_ids: &[(Field<N>, u64)],
915 transmissions: &HashMap<TransmissionID<N>, (Transmission<N>, IndexSet<Field<N>>)>,
916 ) {
917 assert_eq!(storage.rounds_iter().collect::<Vec<_>>(), *rounds);
919 assert_eq!(storage.certificates_iter().collect::<Vec<_>>(), *certificates);
921 assert_eq!(storage.batch_ids_iter().collect::<Vec<_>>(), *batch_ids);
923 assert_eq!(storage.transmissions_iter().collect::<HashMap<_, _>>(), *transmissions);
925 }
926
927 fn sample_transmission(rng: &mut TestRng) -> Transmission<CurrentNetwork> {
929 let s = |rng: &mut TestRng| Data::Buffer(Bytes::from((0..512).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
931 let t = |rng: &mut TestRng| Data::Buffer(Bytes::from((0..2048).map(|_| rng.gen::<u8>()).collect::<Vec<_>>()));
933 match rng.gen::<bool>() {
935 true => Transmission::Solution(s(rng)),
936 false => Transmission::Transaction(t(rng)),
937 }
938 }
939
940 pub(crate) fn sample_transmissions(
942 certificate: &BatchCertificate<CurrentNetwork>,
943 rng: &mut TestRng,
944 ) -> (
945 HashMap<TransmissionID<CurrentNetwork>, Transmission<CurrentNetwork>>,
946 HashMap<TransmissionID<CurrentNetwork>, (Transmission<CurrentNetwork>, IndexSet<Field<CurrentNetwork>>)>,
947 ) {
948 let certificate_id = certificate.id();
950
951 let mut missing_transmissions = HashMap::new();
952 let mut transmissions = HashMap::<_, (_, IndexSet<Field<CurrentNetwork>>)>::new();
953 for transmission_id in certificate.transmission_ids() {
954 let transmission = sample_transmission(rng);
956 missing_transmissions.insert(*transmission_id, transmission.clone());
958 transmissions
960 .entry(*transmission_id)
961 .or_insert((transmission, Default::default()))
962 .1
963 .insert(certificate_id);
964 }
965 (missing_transmissions, transmissions)
966 }
967
968 #[test]
971 fn test_certificate_insert_remove() {
972 let rng = &mut TestRng::default();
973
974 let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
976 let ledger = Arc::new(MockLedgerService::new(committee));
978 let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1);
980
981 assert_storage(&storage, &[], &[], &[], &Default::default());
983
984 let certificate = snarkvm::ledger::narwhal::batch_certificate::test_helpers::sample_batch_certificate(rng);
986 let certificate_id = certificate.id();
988 let round = certificate.round();
990 let batch_id = certificate.batch_id();
992 let author = certificate.author();
994
995 let (missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);
997
998 storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions);
1000 assert!(storage.contains_certificate(certificate_id));
1002 assert_eq!(storage.get_certificates_for_round(round), indexset! { certificate.clone() });
1004 assert_eq!(storage.get_certificate_for_round_with_author(round, author), Some(certificate.clone()));
1006
1007 {
1009 let rounds = [(round, indexset! { (certificate_id, batch_id, author) })];
1011 let certificates = [(certificate_id, certificate.clone())];
1013 let batch_ids = [(batch_id, round)];
1015 assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);
1017 }
1018
1019 let candidate_certificate = storage.get_certificate(certificate_id).unwrap();
1021 assert_eq!(certificate, candidate_certificate);
1023
1024 assert!(storage.remove_certificate(certificate_id));
1026 assert!(!storage.contains_certificate(certificate_id));
1028 assert!(storage.get_certificates_for_round(round).is_empty());
1030 assert_eq!(storage.get_certificate_for_round_with_author(round, author), None);
1032 assert_storage(&storage, &[], &[], &[], &Default::default());
1034 }
1035
1036 #[test]
1037 fn test_certificate_duplicate() {
1038 let rng = &mut TestRng::default();
1039
1040 let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
1042 let ledger = Arc::new(MockLedgerService::new(committee));
1044 let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1);
1046
1047 assert_storage(&storage, &[], &[], &[], &Default::default());
1049
1050 let certificate = snarkvm::ledger::narwhal::batch_certificate::test_helpers::sample_batch_certificate(rng);
1052 let certificate_id = certificate.id();
1054 let round = certificate.round();
1056 let batch_id = certificate.batch_id();
1058 let author = certificate.author();
1060
1061 let rounds = [(round, indexset! { (certificate_id, batch_id, author) })];
1063 let certificates = [(certificate_id, certificate.clone())];
1065 let batch_ids = [(batch_id, round)];
1067 let (missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);
1069
1070 storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions.clone());
1072 assert!(storage.contains_certificate(certificate_id));
1074 assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);
1076
1077 storage.insert_certificate_atomic(certificate.clone(), Default::default(), Default::default());
1079 assert!(storage.contains_certificate(certificate_id));
1081 assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);
1083
1084 storage.insert_certificate_atomic(certificate, Default::default(), missing_transmissions);
1086 assert!(storage.contains_certificate(certificate_id));
1088 assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);
1090 }
1091}
1092
1093#[cfg(test)]
1094pub mod prop_tests {
1095 use super::*;
1096 use crate::helpers::{now, storage::tests::assert_storage};
1097 use snarkos_node_bft_ledger_service::MockLedgerService;
1098 use snarkos_node_bft_storage_service::BFTMemoryService;
1099 use snarkvm::{
1100 ledger::{
1101 committee::prop_tests::{CommitteeContext, ValidatorSet},
1102 narwhal::{BatchHeader, Data},
1103 puzzle::SolutionID,
1104 },
1105 prelude::{Signature, Uniform},
1106 };
1107
1108 use ::bytes::Bytes;
1109 use indexmap::indexset;
1110 use proptest::{
1111 collection,
1112 prelude::{Arbitrary, BoxedStrategy, Just, Strategy, any},
1113 prop_oneof,
1114 sample::{Selector, size_range},
1115 test_runner::TestRng,
1116 };
1117 use rand::{CryptoRng, Error, Rng, RngCore};
1118 use std::fmt::Debug;
1119 use test_strategy::proptest;
1120
1121 type CurrentNetwork = snarkvm::prelude::MainnetV0;
1122
1123 impl Arbitrary for Storage<CurrentNetwork> {
1124 type Parameters = CommitteeContext;
1125 type Strategy = BoxedStrategy<Storage<CurrentNetwork>>;
1126
1127 fn arbitrary() -> Self::Strategy {
1128 (any::<CommitteeContext>(), 0..BatchHeader::<CurrentNetwork>::MAX_GC_ROUNDS as u64)
1129 .prop_map(|(CommitteeContext(committee, _), gc_rounds)| {
1130 let ledger = Arc::new(MockLedgerService::new(committee));
1131 Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), gc_rounds)
1132 })
1133 .boxed()
1134 }
1135
1136 fn arbitrary_with(context: Self::Parameters) -> Self::Strategy {
1137 (Just(context), 0..BatchHeader::<CurrentNetwork>::MAX_GC_ROUNDS as u64)
1138 .prop_map(|(CommitteeContext(committee, _), gc_rounds)| {
1139 let ledger = Arc::new(MockLedgerService::new(committee));
1140 Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), gc_rounds)
1141 })
1142 .boxed()
1143 }
1144 }
1145
1146 #[derive(Debug)]
1148 pub struct CryptoTestRng(TestRng);
1149
1150 impl Arbitrary for CryptoTestRng {
1151 type Parameters = ();
1152 type Strategy = BoxedStrategy<CryptoTestRng>;
1153
1154 fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
1155 Just(0).prop_perturb(|_, rng| CryptoTestRng(rng)).boxed()
1156 }
1157 }
1158 impl RngCore for CryptoTestRng {
1159 fn next_u32(&mut self) -> u32 {
1160 self.0.next_u32()
1161 }
1162
1163 fn next_u64(&mut self) -> u64 {
1164 self.0.next_u64()
1165 }
1166
1167 fn fill_bytes(&mut self, dest: &mut [u8]) {
1168 self.0.fill_bytes(dest);
1169 }
1170
1171 fn try_fill_bytes(&mut self, dest: &mut [u8]) -> std::result::Result<(), Error> {
1172 self.0.try_fill_bytes(dest)
1173 }
1174 }
1175
1176 impl CryptoRng for CryptoTestRng {}
1177
1178 #[derive(Debug, Clone)]
1179 pub struct AnyTransmission(pub Transmission<CurrentNetwork>);
1180
1181 impl Arbitrary for AnyTransmission {
1182 type Parameters = ();
1183 type Strategy = BoxedStrategy<AnyTransmission>;
1184
1185 fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
1186 any_transmission().prop_map(AnyTransmission).boxed()
1187 }
1188 }
1189
1190 #[derive(Debug, Clone)]
1191 pub struct AnyTransmissionID(pub TransmissionID<CurrentNetwork>);
1192
1193 impl Arbitrary for AnyTransmissionID {
1194 type Parameters = ();
1195 type Strategy = BoxedStrategy<AnyTransmissionID>;
1196
1197 fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
1198 any_transmission_id().prop_map(AnyTransmissionID).boxed()
1199 }
1200 }
1201
1202 fn any_transmission() -> BoxedStrategy<Transmission<CurrentNetwork>> {
1203 prop_oneof![
1204 (collection::vec(any::<u8>(), 512..=512))
1205 .prop_map(|bytes| Transmission::Solution(Data::Buffer(Bytes::from(bytes)))),
1206 (collection::vec(any::<u8>(), 2048..=2048))
1207 .prop_map(|bytes| Transmission::Transaction(Data::Buffer(Bytes::from(bytes)))),
1208 ]
1209 .boxed()
1210 }
1211
1212 pub fn any_solution_id() -> BoxedStrategy<SolutionID<CurrentNetwork>> {
1213 Just(0).prop_perturb(|_, rng| CryptoTestRng(rng).gen::<u64>().into()).boxed()
1214 }
1215
1216 pub fn any_transaction_id() -> BoxedStrategy<<CurrentNetwork as Network>::TransactionID> {
1217 Just(0)
1218 .prop_perturb(|_, rng| {
1219 <CurrentNetwork as Network>::TransactionID::from(Field::rand(&mut CryptoTestRng(rng)))
1220 })
1221 .boxed()
1222 }
1223
1224 pub fn any_transmission_id() -> BoxedStrategy<TransmissionID<CurrentNetwork>> {
1225 prop_oneof![
1226 any_transaction_id().prop_perturb(|id, mut rng| TransmissionID::Transaction(
1227 id,
1228 rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>()
1229 )),
1230 any_solution_id().prop_perturb(|id, mut rng| TransmissionID::Solution(
1231 id,
1232 rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>()
1233 )),
1234 ]
1235 .boxed()
1236 }
1237
1238 pub fn sign_batch_header<R: Rng + CryptoRng>(
1239 validator_set: &ValidatorSet,
1240 batch_header: &BatchHeader<CurrentNetwork>,
1241 rng: &mut R,
1242 ) -> IndexSet<Signature<CurrentNetwork>> {
1243 let mut signatures = IndexSet::with_capacity(validator_set.0.len());
1244 for validator in validator_set.0.iter() {
1245 let private_key = validator.private_key;
1246 signatures.insert(private_key.sign(&[batch_header.batch_id()], rng).unwrap());
1247 }
1248 signatures
1249 }
1250
1251 #[proptest]
1252 fn test_certificate_duplicate(
1253 context: CommitteeContext,
1254 #[any(size_range(1..16).lift())] transmissions: Vec<(AnyTransmissionID, AnyTransmission)>,
1255 mut rng: CryptoTestRng,
1256 selector: Selector,
1257 ) {
1258 let CommitteeContext(committee, ValidatorSet(validators)) = context;
1259 let committee_id = committee.id();
1260
1261 let ledger = Arc::new(MockLedgerService::new(committee));
1263 let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1);
1264
1265 assert_storage(&storage, &[], &[], &[], &Default::default());
1267
1268 let signer = selector.select(&validators);
1270
1271 let mut transmission_map = IndexMap::new();
1272
1273 for (AnyTransmissionID(id), AnyTransmission(t)) in transmissions.iter() {
1274 transmission_map.insert(*id, t.clone());
1275 }
1276
1277 let batch_header = BatchHeader::new(
1278 &signer.private_key,
1279 0,
1280 now(),
1281 committee_id,
1282 transmission_map.keys().cloned().collect(),
1283 Default::default(),
1284 &mut rng,
1285 )
1286 .unwrap();
1287
1288 let mut validators = validators.clone();
1291 validators.remove(signer);
1292
1293 let certificate = BatchCertificate::from(
1294 batch_header.clone(),
1295 sign_batch_header(&ValidatorSet(validators), &batch_header, &mut rng),
1296 )
1297 .unwrap();
1298
1299 let certificate_id = certificate.id();
1301 let mut internal_transmissions = HashMap::<_, (_, IndexSet<Field<CurrentNetwork>>)>::new();
1302 for (AnyTransmissionID(id), AnyTransmission(t)) in transmissions.iter().cloned() {
1303 internal_transmissions.entry(id).or_insert((t, Default::default())).1.insert(certificate_id);
1304 }
1305
1306 let round = certificate.round();
1308 let batch_id = certificate.batch_id();
1310 let author = certificate.author();
1312
1313 let rounds = [(round, indexset! { (certificate_id, batch_id, author) })];
1315 let certificates = [(certificate_id, certificate.clone())];
1317 let batch_ids = [(batch_id, round)];
1319
1320 let missing_transmissions: HashMap<TransmissionID<CurrentNetwork>, Transmission<CurrentNetwork>> =
1322 transmission_map.into_iter().collect();
1323 storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions.clone());
1324 assert!(storage.contains_certificate(certificate_id));
1326 assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);
1328
1329 storage.insert_certificate_atomic(certificate.clone(), Default::default(), Default::default());
1331 assert!(storage.contains_certificate(certificate_id));
1333 assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);
1335
1336 storage.insert_certificate_atomic(certificate, Default::default(), missing_transmissions);
1338 assert!(storage.contains_certificate(certificate_id));
1340 assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);
1342 }
1343}