1use crate::helpers::{check_timestamp_for_liveness, fmt_id};
17use snarkos_node_bft_ledger_service::LedgerService;
18use snarkos_node_bft_storage_service::StorageService;
19use snarkvm::{
20 ledger::{
21 block::{Block, Transaction},
22 narwhal::{BatchCertificate, BatchHeader, Transmission, TransmissionID},
23 },
24 prelude::{Address, Field, Network, Result, anyhow, bail, ensure},
25 utilities::{cfg_into_iter, cfg_iter, cfg_sorted_by, flatten_error},
26};
27
28use anyhow::Context;
29use indexmap::{IndexMap, IndexSet, map::Entry};
30#[cfg(feature = "locktick")]
31use locktick::parking_lot::RwLock;
32use lru::LruCache;
33#[cfg(not(feature = "locktick"))]
34use parking_lot::RwLock;
35#[cfg(not(feature = "serial"))]
36use rayon::prelude::*;
37use std::{
38 collections::{HashMap, HashSet},
39 num::NonZeroUsize,
40 sync::{
41 Arc,
42 atomic::{AtomicU32, AtomicU64, Ordering},
43 },
44};
45
/// A shared handle to the BFT storage.
///
/// Cloning a `Storage` clones the inner `Arc`, so every clone refers to the same [`StorageInner`].
#[derive(Clone, Debug)]
pub struct Storage<N: Network>(Arc<StorageInner<N>>);
48
49impl<N: Network> std::ops::Deref for Storage<N> {
50 type Target = Arc<StorageInner<N>>;
51
52 fn deref(&self) -> &Self::Target {
53 &self.0
54 }
55}
56
/// The shared state behind a [`Storage`] handle.
#[derive(Debug)]
pub struct StorageInner<N: Network> {
    /// The ledger service, used for committee lookups and ledger membership checks.
    ledger: Arc<dyn LedgerService<N>>,
    /// The current block height; only moved forward by `sync_height_with_block`.
    current_height: AtomicU32,
    /// The current round.
    current_round: AtomicU64,
    /// The round up to which (inclusive) certificates have been garbage collected.
    gc_round: AtomicU64,
    /// How many rounds behind the current round are retained before garbage collection.
    max_gc_rounds: u64,
    /// A map of `round` to the set of `(certificate ID, author)` pairs stored for that round.
    rounds: RwLock<IndexMap<u64, IndexSet<(Field<N>, Address<N>)>>>,
    /// A bounded LRU cache of certificates received but not yet inserted into storage.
    unprocessed_certificates: RwLock<LruCache<Field<N>, BatchCertificate<N>>>,
    /// A map of `certificate ID` to the certificate.
    certificates: RwLock<IndexMap<Field<N>, BatchCertificate<N>>>,
    /// A map of ID to round. Inserted with the certificate ID as the key, but queried as a batch ID;
    /// NOTE(review): this presumes a certificate's ID equals its batch ID — confirm against `BatchCertificate::id`.
    batch_ids: RwLock<IndexMap<Field<N>, u64>>,
    /// The storage service holding the transmissions referenced by stored certificates.
    transmissions: Arc<dyn StorageService<N>>,
}
108
109impl<N: Network> Storage<N> {
110 pub fn new(
112 ledger: Arc<dyn LedgerService<N>>,
113 transmissions: Arc<dyn StorageService<N>>,
114 max_gc_rounds: u64,
115 ) -> Self {
116 let committee = ledger.current_committee().expect("Ledger is missing a committee.");
119 let current_round = committee.starting_round().max(1);
121 let unprocessed_cache_size = NonZeroUsize::new((N::LATEST_MAX_CERTIFICATES().unwrap() * 2) as usize).unwrap();
123
124 let storage = Self(Arc::new(StorageInner {
126 ledger,
127 current_height: Default::default(),
128 current_round: AtomicU64::new(current_round),
129 gc_round: Default::default(),
130 max_gc_rounds,
131 rounds: Default::default(),
132 unprocessed_certificates: RwLock::new(LruCache::new(unprocessed_cache_size)),
133 certificates: Default::default(),
134 batch_ids: Default::default(),
135 transmissions,
136 }));
137 storage.garbage_collect_certificates(current_round);
140 storage
142 }
143}
144
145impl<N: Network> Storage<N> {
146 pub fn current_height(&self) -> u32 {
148 self.current_height.load(Ordering::SeqCst)
150 }
151}
152
153impl<N: Network> Storage<N> {
154 pub fn current_round(&self) -> u64 {
156 self.current_round.load(Ordering::SeqCst)
158 }
159
160 pub fn gc_round(&self) -> u64 {
162 self.gc_round.load(Ordering::SeqCst)
164 }
165
166 pub fn max_gc_rounds(&self) -> u64 {
168 self.max_gc_rounds
169 }
170
171 pub fn increment_to_next_round(&self, current_round: u64) -> Result<u64> {
174 let next_round = current_round + 1;
176
177 {
179 let storage_round = self.current_round();
181 if next_round < storage_round {
183 return Ok(storage_round);
184 }
185
186 trace!("Incrementing storage from round {storage_round} to {next_round}");
187 }
188
189 let current_committee = self.ledger.current_committee()?;
191 let starting_round = current_committee.starting_round();
193 if next_round < starting_round {
195 let latest_block_round = self.ledger.latest_round();
197 info!(
199 "Syncing primary round ({next_round}) with the current committee's starting round ({starting_round}). Syncing with the latest block round {latest_block_round}..."
200 );
201 self.sync_round_with_block(latest_block_round);
203 return Ok(latest_block_round);
205 }
206
207 self.update_current_round(next_round);
209
210 #[cfg(feature = "metrics")]
211 metrics::gauge(metrics::bft::LAST_STORED_ROUND, next_round as f64);
212
213 let storage_round = self.current_round();
215 let gc_round = self.gc_round();
217 ensure!(next_round == storage_round, "The next round {next_round} does not match in storage ({storage_round})");
219 ensure!(next_round >= gc_round, "The next round {next_round} is behind the GC round {gc_round}");
221
222 info!("Starting round {next_round}...");
224 Ok(next_round)
225 }
226
227 fn update_current_round(&self, next_round: u64) {
229 self.current_round.store(next_round, Ordering::SeqCst);
231 }
232
233 pub(crate) fn garbage_collect_certificates(&self, next_round: u64) {
235 let current_gc_round = self.gc_round();
237 let next_gc_round = next_round.saturating_sub(self.max_gc_rounds);
239 if next_gc_round > current_gc_round {
241 for gc_round in current_gc_round..=next_gc_round {
243 for id in self.get_certificate_ids_for_round(gc_round).into_iter() {
245 trace!(
246 "Garbage collecting certificate {id} at round {gc_round} (cut-off is round {next_gc_round})"
247 );
248 self.remove_certificate(id);
249 }
250 }
251 self.gc_round.store(next_gc_round, Ordering::SeqCst);
253 }
254 }
255}
256
257impl<N: Network> Storage<N> {
258 pub fn contains_certificates_for_round(&self, round: u64) -> bool {
260 self.rounds.read().contains_key(&round)
262 }
263
264 pub fn contains_certificate(&self, certificate_id: Field<N>) -> bool {
266 self.certificates.read().contains_key(&certificate_id)
268 }
269
270 pub fn contains_certificate_in_round_from(&self, round: u64, author: Address<N>) -> bool {
272 self.rounds.read().get(&round).is_some_and(|set| set.iter().any(|(_, a)| a == &author))
273 }
274
275 pub fn contains_unprocessed_certificate(&self, certificate_id: Field<N>) -> bool {
277 self.unprocessed_certificates.read().contains(&certificate_id)
279 }
280
281 pub fn contains_batch(&self, batch_id: Field<N>) -> bool {
283 self.batch_ids.read().contains_key(&batch_id)
285 }
286
287 pub fn contains_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> bool {
289 self.transmissions.contains_transmission(transmission_id.into())
290 }
291
292 pub fn get_transmission(&self, transmission_id: impl Into<TransmissionID<N>>) -> Option<Transmission<N>> {
295 self.transmissions.get_transmission(transmission_id.into())
296 }
297
298 pub fn get_round_for_certificate(&self, certificate_id: Field<N>) -> Option<u64> {
301 self.certificates.read().get(&certificate_id).map(|certificate| certificate.round())
303 }
304
305 pub fn get_round_for_batch(&self, batch_id: Field<N>) -> Option<u64> {
308 self.batch_ids.read().get(&batch_id).copied()
310 }
311
312 pub fn get_certificate_round(&self, certificate_id: Field<N>) -> Option<u64> {
315 self.certificates.read().get(&certificate_id).map(|certificate| certificate.round())
317 }
318
319 pub fn get_certificate(&self, certificate_id: Field<N>) -> Option<BatchCertificate<N>> {
322 self.certificates.read().get(&certificate_id).cloned()
324 }
325
326 pub fn get_unprocessed_certificate(&self, certificate_id: Field<N>) -> Option<BatchCertificate<N>> {
329 self.unprocessed_certificates.read().peek(&certificate_id).cloned()
331 }
332
333 pub fn get_certificate_for_round_with_author(&self, round: u64, author: Address<N>) -> Option<BatchCertificate<N>> {
337 if let Some(entries) = self.rounds.read().get(&round) {
339 let certificates = self.certificates.read();
340 entries.iter().find_map(
341 |(certificate_id, a)| if a == &author { certificates.get(certificate_id).cloned() } else { None },
342 )
343 } else {
344 Default::default()
345 }
346 }
347
348 pub fn get_certificates_for_round(&self, round: u64) -> IndexSet<BatchCertificate<N>> {
351 if round == 0 {
353 return Default::default();
354 }
355 if let Some(entries) = self.rounds.read().get(&round) {
357 let certificates = self.certificates.read();
358 entries.iter().flat_map(|(certificate_id, _)| certificates.get(certificate_id).cloned()).collect()
359 } else {
360 Default::default()
361 }
362 }
363
364 pub fn get_certificate_ids_for_round(&self, round: u64) -> IndexSet<Field<N>> {
367 if round == 0 {
369 return Default::default();
370 }
371 if let Some(entries) = self.rounds.read().get(&round) {
373 entries.iter().map(|(certificate_id, _)| *certificate_id).collect()
374 } else {
375 Default::default()
376 }
377 }
378
379 pub fn get_certificate_authors_for_round(&self, round: u64) -> HashSet<Address<N>> {
382 if round == 0 {
384 return Default::default();
385 }
386 if let Some(entries) = self.rounds.read().get(&round) {
388 entries.iter().map(|(_, author)| *author).collect()
389 } else {
390 Default::default()
391 }
392 }
393
394 pub(crate) fn get_pending_certificates(&self) -> IndexSet<BatchCertificate<N>> {
397 let rounds = self.rounds.read();
399 let certificates = self.certificates.read();
400
401 cfg_sorted_by!(rounds.clone(), |a, _, b, _| a.cmp(b))
403 .flat_map(|(_, certificates_for_round)| {
404 cfg_into_iter!(certificates_for_round).filter_map(|(certificate_id, _)| {
406 if self.ledger.contains_certificate(&certificate_id).unwrap_or(false) {
408 None
409 } else {
410 certificates.get(&certificate_id).cloned()
412 }
413 })
414 })
415 .collect()
416 }
417
418 pub fn check_batch_header(
442 &self,
443 batch_header: &BatchHeader<N>,
444 transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
445 aborted_transmissions: HashSet<TransmissionID<N>>,
446 ) -> Result<Option<HashMap<TransmissionID<N>, Transmission<N>>>> {
447 let round = batch_header.round();
449 let gc_round = self.gc_round();
451 let gc_log = format!("(gc = {gc_round})");
453
454 if self.contains_batch(batch_header.batch_id()) {
456 debug!("Batch for round {round} already exists in storage {gc_log}");
457 return Ok(None);
458 }
459
460 let Ok(committee_lookback) = self.ledger.get_committee_lookback_for_round(round) else {
462 bail!("Storage failed to retrieve the committee lookback for round {round} {gc_log}")
463 };
464 if !committee_lookback.is_committee_member(batch_header.author()) {
466 bail!("Author {} is not in the committee for round {round} {gc_log}", batch_header.author())
467 }
468
469 check_timestamp_for_liveness(batch_header.timestamp())?;
471
472 let missing_transmissions = self
474 .transmissions
475 .find_missing_transmissions(batch_header, transmissions, aborted_transmissions)
476 .map_err(|e| anyhow!("{e} for round {round} {gc_log}"))?;
477
478 let previous_round = round.saturating_sub(1);
480 if previous_round > gc_round {
482 let Ok(previous_committee_lookback) = self.ledger.get_committee_lookback_for_round(previous_round) else {
484 bail!("Missing committee for the previous round {previous_round} in storage {gc_log}")
485 };
486 if !self.contains_certificates_for_round(previous_round) {
488 bail!("Missing certificates for the previous round {previous_round} in storage {gc_log}")
489 }
490 if batch_header.previous_certificate_ids().len() > previous_committee_lookback.num_members() {
492 bail!("Too many previous certificates for round {round} {gc_log}")
493 }
494 let mut previous_authors = HashSet::with_capacity(batch_header.previous_certificate_ids().len());
496 for previous_certificate_id in batch_header.previous_certificate_ids() {
498 let Some(previous_certificate) = self.get_certificate(*previous_certificate_id) else {
500 bail!(
501 "Missing previous certificate '{}' for certificate in round {round} {gc_log}",
502 fmt_id(previous_certificate_id)
503 )
504 };
505 if previous_certificate.round() != previous_round {
507 bail!("Round {round} certificate contains a round {previous_round} certificate {gc_log}")
508 }
509 if previous_authors.contains(&previous_certificate.author()) {
511 bail!("Round {round} certificate contains a duplicate author {gc_log}")
512 }
513 previous_authors.insert(previous_certificate.author());
515 }
516 if !previous_committee_lookback.is_quorum_threshold_reached(&previous_authors) {
518 bail!("Previous certificates for a batch in round {round} did not reach quorum threshold {gc_log}")
519 }
520 }
521
522 Ok(Some(missing_transmissions))
523 }
524
525 pub fn check_incoming_certificate(&self, certificate: &BatchCertificate<N>) -> Result<()> {
537 let certificate_author = certificate.author();
539 let certificate_round = certificate.round();
540
541 let committee_lookback = self.ledger.get_committee_lookback_for_round(certificate_round)?;
543
544 let mut signers: HashSet<Address<N>> =
547 certificate.signatures().map(|signature| signature.to_address()).collect();
548 signers.insert(certificate_author);
549 ensure!(
550 committee_lookback.is_quorum_threshold_reached(&signers),
551 "Certificate '{}' for round {certificate_round} does not meet quorum requirements",
552 certificate.id()
553 );
554
555 cfg_iter!(&signers).try_for_each(|signer| {
557 ensure!(
558 committee_lookback.is_committee_member(*signer),
559 "Signer '{signer}' of certificate '{}' for round {certificate_round} is not in the committee",
560 certificate.id()
561 );
562 Ok(())
563 })?;
564
565 Ok(())
566 }
567
568 pub fn check_certificate(
590 &self,
591 certificate: &BatchCertificate<N>,
592 transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
593 aborted_transmissions: HashSet<TransmissionID<N>>,
594 ) -> Result<HashMap<TransmissionID<N>, Transmission<N>>> {
595 let round = certificate.round();
597 let gc_round = self.gc_round();
599 let gc_log = format!("(gc = {gc_round})");
601
602 if self.contains_certificate(certificate.id()) {
604 bail!("Certificate for round {round} already exists in storage {gc_log}")
605 }
606
607 if self.contains_certificate_in_round_from(round, certificate.author()) {
609 bail!("Certificate with this author for round {round} already exists in storage {gc_log}")
610 }
611
612 let Some(missing_transmissions) =
614 self.check_batch_header(certificate.batch_header(), transmissions, aborted_transmissions)?
615 else {
616 bail!("Certificate for round {round} already exists in storage {gc_log}")
617 };
618
619 check_timestamp_for_liveness(certificate.timestamp())?;
621
622 let Ok(committee_lookback) = self.ledger.get_committee_lookback_for_round(round) else {
624 bail!("Storage failed to retrieve the committee for round {round} {gc_log}")
625 };
626
627 let mut signers = HashSet::with_capacity(certificate.signatures().len() + 1);
629 signers.insert(certificate.author());
631
632 for signature in certificate.signatures() {
634 let signer = signature.to_address();
636 if !committee_lookback.is_committee_member(signer) {
638 bail!("Signer {signer} is not in the committee for round {round} {gc_log}")
639 }
640 signers.insert(signer);
642 }
643
644 if !committee_lookback.is_quorum_threshold_reached(&signers) {
646 bail!("Signatures for a batch in round {round} did not reach quorum threshold {gc_log}")
647 }
648
649 Ok(missing_transmissions)
650 }
651
652 pub fn insert_certificate(
670 &self,
671 certificate: BatchCertificate<N>,
672 transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
673 aborted_transmissions: HashSet<TransmissionID<N>>,
674 ) -> Result<()> {
675 ensure!(certificate.round() > self.gc_round(), "Certificate round is at or below the GC round");
677 let missing_transmissions =
679 self.check_certificate(&certificate, transmissions, aborted_transmissions.clone())?;
680 self.insert_certificate_atomic(certificate, aborted_transmissions, missing_transmissions);
682 Ok(())
683 }
684
685 fn insert_certificate_atomic(
691 &self,
692 certificate: BatchCertificate<N>,
693 aborted_transmission_ids: HashSet<TransmissionID<N>>,
694 missing_transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
695 ) {
696 let round = certificate.round();
698 let certificate_id = certificate.id();
700 let author = certificate.author();
702
703 self.rounds.write().entry(round).or_default().insert((certificate_id, author));
705 let transmission_ids = certificate.transmission_ids().clone();
707 self.certificates.write().insert(certificate_id, certificate);
709 self.unprocessed_certificates.write().pop(&certificate_id);
711 self.batch_ids.write().insert(certificate_id, round);
713 self.transmissions.insert_transmissions(
715 certificate_id,
716 transmission_ids,
717 aborted_transmission_ids,
718 missing_transmissions,
719 );
720 }
721
722 pub fn insert_unprocessed_certificate(&self, certificate: BatchCertificate<N>) -> Result<()> {
726 ensure!(certificate.round() > self.gc_round(), "Certificate round is at or below the GC round");
728 self.unprocessed_certificates.write().put(certificate.id(), certificate);
730
731 Ok(())
732 }
733
734 fn remove_certificate(&self, certificate_id: Field<N>) -> bool {
741 let Some(certificate) = self.get_certificate(certificate_id) else {
743 warn!("Certificate {certificate_id} does not exist in storage");
744 return false;
745 };
746 let round = certificate.round();
748 let author = certificate.author();
750
751 match self.rounds.write().entry(round) {
757 Entry::Occupied(mut entry) => {
758 entry.get_mut().swap_remove(&(certificate_id, author));
760 if entry.get().is_empty() {
762 entry.swap_remove();
763 }
764 }
765 Entry::Vacant(_) => {}
766 }
767 self.certificates.write().swap_remove(&certificate_id);
769 self.unprocessed_certificates.write().pop(&certificate_id);
771 self.batch_ids.write().swap_remove(&certificate_id);
773 self.transmissions.remove_transmissions(&certificate_id, certificate.transmission_ids());
775 true
777 }
778}
779
780impl<N: Network> Storage<N> {
781 pub(crate) fn sync_height_with_block(&self, next_height: u32) {
783 if next_height > self.current_height() {
785 self.current_height.store(next_height, Ordering::SeqCst);
787 }
788 }
789
790 pub(crate) fn sync_round_with_block(&self, next_round: u64) {
792 let next_round = next_round.max(1);
794 if next_round > self.current_round() {
796 self.update_current_round(next_round);
798 info!("Synced to round {next_round}...");
800 } else {
801 trace!(
802 "Skipping sync to round {next_round} as it is less than the current round ({})",
803 self.current_round()
804 );
805 }
806 }
807
808 pub(crate) fn sync_certificate_with_block(
810 &self,
811 block: &Block<N>,
812 certificate: BatchCertificate<N>,
813 unconfirmed_transactions: &HashMap<N::TransactionID, Transaction<N>>,
814 ) {
815 let gc_round = self.gc_round();
817 if certificate.round() <= gc_round {
818 trace!("Got certificate for round {} below GC round ({gc_round}). Will not store it.", certificate.round());
819 return;
820 }
821
822 if self.contains_certificate(certificate.id()) {
824 trace!("Got certificate {} for round {} more than once.", certificate.id(), certificate.round());
825 return;
826 }
827 let mut missing_transmissions = HashMap::new();
829
830 let mut aborted_transmissions = HashSet::new();
832
833 let aborted_solutions: IndexSet<_> = block.aborted_solution_ids().iter().collect();
835 let aborted_transactions: IndexSet<_> = block.aborted_transaction_ids().iter().collect();
836
837 for transmission_id in certificate.transmission_ids() {
839 if missing_transmissions.contains_key(transmission_id) {
841 continue;
842 }
843 if self.contains_transmission(*transmission_id) {
845 continue;
846 }
847 match transmission_id {
849 TransmissionID::Ratification => (),
850 TransmissionID::Solution(solution_id, _) => {
851 match block.get_solution(solution_id) {
853 Some(solution) => missing_transmissions.insert(*transmission_id, (*solution).into()),
855 None => match self.ledger.get_solution(solution_id) {
857 Ok(solution) => missing_transmissions.insert(*transmission_id, solution.into()),
859 Err(_) => {
861 match aborted_solutions.contains(solution_id)
863 || self.ledger.contains_transmission(transmission_id).unwrap_or(false)
864 {
865 true => {
866 aborted_transmissions.insert(*transmission_id);
867 }
868 false => error!("Missing solution {solution_id} in block {}", block.height()),
869 }
870 continue;
871 }
872 },
873 };
874 }
875 TransmissionID::Transaction(transaction_id, _) => {
876 match unconfirmed_transactions.get(transaction_id) {
878 Some(transaction) => missing_transmissions.insert(*transmission_id, transaction.clone().into()),
880 None => match self.ledger.get_unconfirmed_transaction(*transaction_id) {
882 Ok(transaction) => missing_transmissions.insert(*transmission_id, transaction.into()),
884 Err(_) => {
886 match aborted_transactions.contains(transaction_id)
888 || self.ledger.contains_transmission(transmission_id).unwrap_or(false)
889 {
890 true => {
891 aborted_transmissions.insert(*transmission_id);
892 }
893 false => warn!("Missing transaction {transaction_id} in block {}", block.height()),
894 }
895 continue;
896 }
897 },
898 };
899 }
900 }
901 }
902 let certificate_id = fmt_id(certificate.id());
904 debug!(
905 "Syncing certificate '{certificate_id}' for round {} with {} transmissions",
906 certificate.round(),
907 certificate.transmission_ids().len()
908 );
909
910 if let Err(error) = self
911 .insert_certificate(certificate, missing_transmissions, aborted_transmissions)
912 .with_context(|| format!("Failed to insert certificate '{certificate_id}' from block {}", block.height()))
913 {
914 error!("{}", &flatten_error(&error));
915 }
916 }
917}
918
#[cfg(test)]
impl<N: Network> Storage<N> {
    /// Returns the ledger service backing this storage.
    pub fn ledger(&self) -> &Arc<dyn LedgerService<N>> {
        &self.ledger
    }

    /// Iterates over a snapshot of the `round -> {(certificate ID, author)}` map.
    pub fn rounds_iter(&self) -> impl Iterator<Item = (u64, IndexSet<(Field<N>, Address<N>)>)> {
        let snapshot = self.rounds.read().clone();
        snapshot.into_iter()
    }

    /// Iterates over a snapshot of the `certificate ID -> certificate` map.
    pub fn certificates_iter(&self) -> impl Iterator<Item = (Field<N>, BatchCertificate<N>)> {
        let snapshot = self.certificates.read().clone();
        snapshot.into_iter()
    }

    /// Iterates over a snapshot of the batch-ID-to-round map.
    pub fn batch_ids_iter(&self) -> impl Iterator<Item = (Field<N>, u64)> {
        let snapshot = self.batch_ids.read().clone();
        snapshot.into_iter()
    }

    /// Iterates over the transmission service's contents as returned by `as_hashmap`.
    pub fn transmissions_iter(
        &self,
    ) -> impl Iterator<Item = (TransmissionID<N>, (Transmission<N>, IndexSet<Field<N>>))> {
        let snapshot = self.transmissions.as_hashmap();
        snapshot.into_iter()
    }

    /// Inserts `certificate` without any validation, storing an empty transaction buffer as the
    /// payload of each of its transmission IDs.
    #[doc(hidden)]
    pub(crate) fn testing_only_insert_certificate_testing_only(&self, certificate: BatchCertificate<N>) {
        let (round, certificate_id, author) = (certificate.round(), certificate.id(), certificate.author());

        self.rounds.write().entry(round).or_default().insert((certificate_id, author));
        let transmission_ids = certificate.transmission_ids().clone();
        self.certificates.write().insert(certificate_id, certificate);
        self.batch_ids.write().insert(certificate_id, round);

        let placeholder = || Transmission::Transaction(snarkvm::ledger::narwhal::Data::Buffer(bytes::Bytes::new()));
        let missing_transmissions: HashMap<_, _> = transmission_ids.iter().map(|id| (*id, placeholder())).collect();
        self.transmissions.insert_transmissions(certificate_id, transmission_ids, Default::default(), missing_transmissions);
    }
}
984
985#[cfg(test)]
986pub(crate) mod tests {
987 use super::*;
988 use snarkos_node_bft_ledger_service::MockLedgerService;
989 use snarkos_node_bft_storage_service::BFTMemoryService;
990 use snarkvm::{
991 ledger::narwhal::{Data, batch_certificate::test_helpers::sample_batch_certificate_for_round_with_committee},
992 prelude::{Rng, TestRng},
993 };
994
995 use ::bytes::Bytes;
996 use indexmap::indexset;
997
998 type CurrentNetwork = snarkvm::prelude::MainnetV0;
999
1000 pub fn assert_storage<N: Network>(
1002 storage: &Storage<N>,
1003 rounds: &[(u64, IndexSet<(Field<N>, Address<N>)>)],
1004 certificates: &[(Field<N>, BatchCertificate<N>)],
1005 batch_ids: &[(Field<N>, u64)],
1006 transmissions: &HashMap<TransmissionID<N>, (Transmission<N>, IndexSet<Field<N>>)>,
1007 ) {
1008 assert_eq!(storage.rounds_iter().collect::<Vec<_>>(), *rounds);
1010 assert_eq!(storage.certificates_iter().collect::<Vec<_>>(), *certificates);
1012 assert_eq!(storage.batch_ids_iter().collect::<Vec<_>>(), *batch_ids);
1014 assert_eq!(storage.transmissions_iter().collect::<HashMap<_, _>>(), *transmissions);
1016 }
1017
1018 fn sample_transmission(rng: &mut TestRng) -> Transmission<CurrentNetwork> {
1020 let s = |rng: &mut TestRng| Data::Buffer(Bytes::from((0..512).map(|_| rng.r#gen::<u8>()).collect::<Vec<_>>()));
1022 let t = |rng: &mut TestRng| Data::Buffer(Bytes::from((0..2048).map(|_| rng.r#gen::<u8>()).collect::<Vec<_>>()));
1024 match rng.r#gen::<bool>() {
1026 true => Transmission::Solution(s(rng)),
1027 false => Transmission::Transaction(t(rng)),
1028 }
1029 }
1030
1031 pub(crate) fn sample_transmissions(
1033 certificate: &BatchCertificate<CurrentNetwork>,
1034 rng: &mut TestRng,
1035 ) -> (
1036 HashMap<TransmissionID<CurrentNetwork>, Transmission<CurrentNetwork>>,
1037 HashMap<TransmissionID<CurrentNetwork>, (Transmission<CurrentNetwork>, IndexSet<Field<CurrentNetwork>>)>,
1038 ) {
1039 let certificate_id = certificate.id();
1041
1042 let mut missing_transmissions = HashMap::new();
1043 let mut transmissions = HashMap::<_, (_, IndexSet<Field<CurrentNetwork>>)>::new();
1044 for transmission_id in certificate.transmission_ids() {
1045 let transmission = sample_transmission(rng);
1047 missing_transmissions.insert(*transmission_id, transmission.clone());
1049 transmissions
1051 .entry(*transmission_id)
1052 .or_insert((transmission, Default::default()))
1053 .1
1054 .insert(certificate_id);
1055 }
1056 (missing_transmissions, transmissions)
1057 }
1058
    /// Inserts a single certificate directly, checks every accessor and storage map, then removes it
    /// and checks that storage is empty again.
    #[test]
    fn test_certificate_insert_remove() {
        let rng = &mut TestRng::default();

        // Sample a committee and build storage on a mock ledger with `max_gc_rounds = 1`.
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        let ledger = Arc::new(MockLedgerService::new(committee));
        let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1);

        // Storage starts empty.
        assert_storage(&storage, &[], &[], &[], &Default::default());

        // Sample a certificate and record its identifying fields.
        let certificate = snarkvm::ledger::narwhal::batch_certificate::test_helpers::sample_batch_certificate(rng);
        let certificate_id = certificate.id();
        let round = certificate.round();
        let author = certificate.author();

        // Sample payloads for the certificate's transmissions (and the expected service contents).
        let (missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);

        // Insert without validation and check the accessors.
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions);
        assert!(storage.contains_certificate(certificate_id));
        assert_eq!(storage.get_certificates_for_round(round), indexset! { certificate.clone() });
        assert_eq!(storage.get_certificate_for_round_with_author(round, author), Some(certificate.clone()));

        // Check the full contents of every storage map.
        {
            let rounds = [(round, indexset! { (certificate_id, author) })];
            let certificates = [(certificate_id, certificate.clone())];
            let batch_ids = [(certificate_id, round)];
            assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);
        }

        // The stored certificate round-trips unchanged.
        let candidate_certificate = storage.get_certificate(certificate_id).unwrap();
        assert_eq!(certificate, candidate_certificate);

        // Remove the certificate; every map is empty again.
        assert!(storage.remove_certificate(certificate_id));
        assert!(!storage.contains_certificate(certificate_id));
        assert!(storage.get_certificates_for_round(round).is_empty());
        assert_eq!(storage.get_certificate_for_round_with_author(round, author), None);
        assert_storage(&storage, &[], &[], &[], &Default::default());
    }
1124
    /// Inserting the same certificate repeatedly (with or without payloads) leaves storage unchanged
    /// after the first insertion.
    #[test]
    fn test_certificate_duplicate() {
        let rng = &mut TestRng::default();

        // Sample a committee and build storage on a mock ledger with `max_gc_rounds = 1`.
        let committee = snarkvm::ledger::committee::test_helpers::sample_committee(rng);
        let ledger = Arc::new(MockLedgerService::new(committee));
        let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1);

        // Storage starts empty.
        assert_storage(&storage, &[], &[], &[], &Default::default());

        // Sample a certificate and record its identifying fields.
        let certificate = snarkvm::ledger::narwhal::batch_certificate::test_helpers::sample_batch_certificate(rng);
        let certificate_id = certificate.id();
        let round = certificate.round();
        let author = certificate.author();

        // The expected storage contents after the first insertion.
        let rounds = [(round, indexset! { (certificate_id, author) })];
        let certificates = [(certificate_id, certificate.clone())];
        let batch_ids = [(certificate_id, round)];
        let (missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);

        // First insertion populates storage.
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions.clone());
        assert!(storage.contains_certificate(certificate_id));
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);

        // Re-inserting with no payloads leaves storage unchanged.
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), Default::default());
        assert!(storage.contains_certificate(certificate_id));
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);

        // Re-inserting with the same payloads leaves storage unchanged.
        storage.insert_certificate_atomic(certificate, Default::default(), missing_transmissions);
        assert!(storage.contains_certificate(certificate_id));
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &transmissions);
    }
1178
    /// Certificates signed by the full 5-member committee pass `check_incoming_certificate` for
    /// 100 consecutive rounds, each round referencing the previous round's certificates.
    #[test]
    fn test_valid_incoming_certificate() {
        let rng = &mut TestRng::default();

        // Sample a 5-member committee (and its keys) and build storage on a mock ledger.
        let (committee, private_keys) =
            snarkvm::ledger::committee::test_helpers::sample_committee_and_keys_for_round(0, 5, rng);
        let ledger = Arc::new(MockLedgerService::new(committee));
        let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1);

        // Certificate IDs from the previous round, referenced by the next round's certificates.
        let mut previous_certs = IndexSet::default();

        for round in 1..=100 {
            let mut new_certs = IndexSet::default();

            // One certificate per member, signed by every other member.
            for private_key in private_keys.iter() {
                let other_keys: Vec<_> = private_keys.iter().cloned().filter(|k| k != private_key).collect();

                let certificate = sample_batch_certificate_for_round_with_committee(
                    round,
                    previous_certs.clone(),
                    private_key,
                    &other_keys,
                    rng,
                );
                storage.check_incoming_certificate(&certificate).expect("Valid certificate rejected");
                new_certs.insert(certificate.id());

                // Store it so later rounds can reference it.
                let (missing_transmissions, _transmissions) = sample_transmissions(&certificate, rng);
                storage.insert_certificate_atomic(certificate, Default::default(), missing_transmissions);
            }

            previous_certs = new_certs;
        }
    }
1220
    /// With a 10-member committee, certificates signed by everyone pass for rounds 1-4, while round-5
    /// certificates signed only by (a subset of) the first four keys are rejected by
    /// `check_incoming_certificate`.
    #[test]
    fn test_invalid_incoming_certificate_missing_signature() {
        let rng = &mut TestRng::default();

        // Sample a 10-member committee (and its keys) and build storage on a mock ledger.
        let (committee, private_keys) =
            snarkvm::ledger::committee::test_helpers::sample_committee_and_keys_for_round(0, 10, rng);
        let ledger = Arc::new(MockLedgerService::new(committee));
        let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1);

        // Certificate IDs from the previous round, referenced by the next round's certificates.
        let mut previous_certs = IndexSet::default();

        for round in 1..=5 {
            let mut new_certs = IndexSet::default();

            for private_key in private_keys.iter() {
                if round < 5 {
                    // Fully signed certificate: every other member signs.
                    let other_keys: Vec<_> = private_keys.iter().cloned().filter(|k| k != private_key).collect();

                    let certificate = sample_batch_certificate_for_round_with_committee(
                        round,
                        previous_certs.clone(),
                        private_key,
                        &other_keys,
                        rng,
                    );
                    storage.check_incoming_certificate(&certificate).expect("Valid certificate rejected");
                    new_certs.insert(certificate.id());

                    let (missing_transmissions, _transmissions) = sample_transmissions(&certificate, rng);
                    storage.insert_certificate_atomic(certificate, Default::default(), missing_transmissions);
                } else {
                    // Under-signed certificate: only the first four keys (minus the author) sign.
                    let other_keys: Vec<_> = private_keys[0..=3].iter().cloned().filter(|k| k != private_key).collect();

                    let certificate = sample_batch_certificate_for_round_with_committee(
                        round,
                        previous_certs.clone(),
                        private_key,
                        &other_keys,
                        rng,
                    );
                    assert!(storage.check_incoming_certificate(&certificate).is_err());
                }
            }

            previous_certs = new_certs;
        }
    }
1276
    /// Rounds 1-5 insert valid certificates via `insert_certificate`. For round 6, only the round-5
    /// IDs after skipping the first six are referenced (4 of 10), and insertion is expected to fail.
    #[test]
    fn test_invalid_certificate_insufficient_previous_certs() {
        let rng = &mut TestRng::default();

        // Sample a 10-member committee (and its keys) and build storage on a mock ledger.
        let (committee, private_keys) =
            snarkvm::ledger::committee::test_helpers::sample_committee_and_keys_for_round(0, 10, rng);
        let ledger = Arc::new(MockLedgerService::new(committee));
        let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1);

        // Certificate IDs from the previous round, referenced by the next round's certificates.
        let mut previous_certs = IndexSet::default();

        for round in 1..=6 {
            let mut new_certs = IndexSet::default();

            for private_key in private_keys.iter() {
                let other_keys: Vec<_> = private_keys.iter().cloned().filter(|k| k != private_key).collect();

                let certificate = sample_batch_certificate_for_round_with_committee(
                    round,
                    previous_certs.clone(),
                    private_key,
                    &other_keys,
                    rng,
                );

                // Supply every transmission payload so only the previous-certificate checks can fail.
                let (_missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);
                let transmissions = transmissions.into_iter().map(|(k, (t, _))| (k, t)).collect();

                if round <= 5 {
                    new_certs.insert(certificate.id());
                    storage
                        .insert_certificate(certificate, transmissions, Default::default())
                        .expect("Valid certificate rejected");
                } else {
                    assert!(storage.insert_certificate(certificate, transmissions, Default::default()).is_err());
                }
            }

            if round < 5 {
                previous_certs = new_certs;
            } else {
                // Keep only the IDs after the first six, so round 6 references too few previous certificates.
                previous_certs = new_certs.into_iter().skip(6).collect();
            }
        }
    }
1330
1331 #[test]
1333 fn test_invalid_certificate_wrong_round_number() {
1334 let rng = &mut TestRng::default();
1335
1336 let (committee, private_keys) =
1338 snarkvm::ledger::committee::test_helpers::sample_committee_and_keys_for_round(0, 10, rng);
1339 let ledger = Arc::new(MockLedgerService::new(committee));
1341 let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1);
1343
1344 let mut previous_certs = IndexSet::default();
1346
1347 for round in 1..=6 {
1348 let mut new_certs = IndexSet::default();
1349
1350 for private_key in private_keys.iter() {
1352 let cert_round = round.min(5); let other_keys: Vec<_> = private_keys.iter().cloned().filter(|k| k != private_key).collect();
1354
1355 let certificate = sample_batch_certificate_for_round_with_committee(
1356 cert_round,
1357 previous_certs.clone(),
1358 private_key,
1359 &other_keys,
1360 rng,
1361 );
1362
1363 let (_missing_transmissions, transmissions) = sample_transmissions(&certificate, rng);
1365 let transmissions = transmissions.into_iter().map(|(k, (t, _))| (k, t)).collect();
1366
1367 if round <= 5 {
1368 new_certs.insert(certificate.id());
1369 storage
1370 .insert_certificate(certificate, transmissions, Default::default())
1371 .expect("Valid certificate rejected");
1372 } else {
1373 assert!(storage.insert_certificate(certificate, transmissions, Default::default()).is_err());
1374 }
1375 }
1376
1377 if round < 5 {
1378 previous_certs = new_certs;
1379 } else {
1380 previous_certs = new_certs.into_iter().skip(6).collect();
1382 }
1383 }
1384 }
1385}
1386
#[cfg(test)]
pub mod prop_tests {
    use super::*;
    use crate::helpers::{now, storage::tests::assert_storage};
    use snarkos_node_bft_ledger_service::MockLedgerService;
    use snarkos_node_bft_storage_service::BFTMemoryService;
    use snarkvm::{
        ledger::{
            committee::prop_tests::{CommitteeContext, ValidatorSet},
            narwhal::{BatchHeader, Data},
            puzzle::SolutionID,
        },
        prelude::{Signature, Uniform},
    };

    use ::bytes::Bytes;
    use indexmap::indexset;
    use proptest::{
        collection,
        prelude::{Arbitrary, BoxedStrategy, Just, Strategy, any},
        prop_oneof,
        sample::{Selector, size_range},
        test_runner::TestRng,
    };
    use rand::{CryptoRng, Error, Rng, RngCore};
    use std::fmt::Debug;
    use test_strategy::proptest;

    type CurrentNetwork = snarkvm::prelude::MainnetV0;

    /// Builds an in-memory [`Storage`] over a mock ledger for the committee in `context`,
    /// keeping up to `gc_rounds` rounds before garbage collection.
    fn storage_for_committee(CommitteeContext(committee, _): CommitteeContext, gc_rounds: u64) -> Storage<CurrentNetwork> {
        let ledger = Arc::new(MockLedgerService::new(committee));
        Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), gc_rounds)
    }

    impl Arbitrary for Storage<CurrentNetwork> {
        type Parameters = CommitteeContext;
        type Strategy = BoxedStrategy<Storage<CurrentNetwork>>;

        /// Samples both the committee and the GC depth (in `0..MAX_GC_ROUNDS`).
        fn arbitrary() -> Self::Strategy {
            (any::<CommitteeContext>(), 0..BatchHeader::<CurrentNetwork>::MAX_GC_ROUNDS as u64)
                .prop_map(|(context, gc_rounds)| storage_for_committee(context, gc_rounds))
                .boxed()
        }

        /// Uses the supplied committee and samples only the GC depth (in `0..MAX_GC_ROUNDS`).
        fn arbitrary_with(context: Self::Parameters) -> Self::Strategy {
            (Just(context), 0..BatchHeader::<CurrentNetwork>::MAX_GC_ROUNDS as u64)
                .prop_map(|(context, gc_rounds)| storage_for_committee(context, gc_rounds))
                .boxed()
        }
    }

    /// Wraps proptest's [`TestRng`] so it can be passed where a `CryptoRng` is required.
    /// The `CryptoRng` marker below is asserted for test convenience only.
    #[derive(Debug)]
    pub struct CryptoTestRng(TestRng);

    impl Arbitrary for CryptoTestRng {
        type Parameters = ();
        type Strategy = BoxedStrategy<CryptoTestRng>;

        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            // `prop_perturb` hands out a fresh RNG derived from the runner's seed.
            Just(0).prop_perturb(|_, rng| CryptoTestRng(rng)).boxed()
        }
    }

    impl RngCore for CryptoTestRng {
        fn next_u32(&mut self) -> u32 {
            self.0.next_u32()
        }

        fn next_u64(&mut self) -> u64 {
            self.0.next_u64()
        }

        fn fill_bytes(&mut self, dest: &mut [u8]) {
            self.0.fill_bytes(dest);
        }

        fn try_fill_bytes(&mut self, dest: &mut [u8]) -> std::result::Result<(), Error> {
            self.0.try_fill_bytes(dest)
        }
    }

    impl CryptoRng for CryptoTestRng {}

    /// An arbitrary [`Transmission`] (see [`any_transmission`]).
    #[derive(Debug, Clone)]
    pub struct AnyTransmission(pub Transmission<CurrentNetwork>);

    impl Arbitrary for AnyTransmission {
        type Parameters = ();
        type Strategy = BoxedStrategy<AnyTransmission>;

        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            any_transmission().prop_map(AnyTransmission).boxed()
        }
    }

    /// An arbitrary [`TransmissionID`] (see [`any_transmission_id`]).
    #[derive(Debug, Clone)]
    pub struct AnyTransmissionID(pub TransmissionID<CurrentNetwork>);

    impl Arbitrary for AnyTransmissionID {
        type Parameters = ();
        type Strategy = BoxedStrategy<AnyTransmissionID>;

        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            any_transmission_id().prop_map(AnyTransmissionID).boxed()
        }
    }

    /// Either a solution backed by 512 random bytes or a transaction backed by 2048 random bytes.
    fn any_transmission() -> BoxedStrategy<Transmission<CurrentNetwork>> {
        prop_oneof![
            (collection::vec(any::<u8>(), 512..=512))
                .prop_map(|bytes| Transmission::Solution(Data::Buffer(Bytes::from(bytes)))),
            (collection::vec(any::<u8>(), 2048..=2048))
                .prop_map(|bytes| Transmission::Transaction(Data::Buffer(Bytes::from(bytes)))),
        ]
        .boxed()
    }

    /// A solution ID converted from a random `u64`.
    pub fn any_solution_id() -> BoxedStrategy<SolutionID<CurrentNetwork>> {
        Just(0).prop_perturb(|_, rng| CryptoTestRng(rng).r#gen::<u64>().into()).boxed()
    }

    /// A transaction ID converted from a random field element.
    pub fn any_transaction_id() -> BoxedStrategy<<CurrentNetwork as Network>::TransactionID> {
        Just(0)
            .prop_perturb(|_, rng| {
                <CurrentNetwork as Network>::TransactionID::from(Field::rand(&mut CryptoTestRng(rng)))
            })
            .boxed()
    }

    /// A transaction or solution transmission ID paired with a random checksum.
    pub fn any_transmission_id() -> BoxedStrategy<TransmissionID<CurrentNetwork>> {
        prop_oneof![
            any_transaction_id().prop_perturb(|id, mut rng| TransmissionID::Transaction(
                id,
                rng.r#gen::<<CurrentNetwork as Network>::TransmissionChecksum>()
            )),
            any_solution_id().prop_perturb(|id, mut rng| TransmissionID::Solution(
                id,
                rng.r#gen::<<CurrentNetwork as Network>::TransmissionChecksum>()
            )),
        ]
        .boxed()
    }

    /// Signs the batch ID of `batch_header` with every validator in `validator_set`.
    ///
    /// # Panics
    /// Panics if any signing operation fails.
    pub fn sign_batch_header<R: Rng + CryptoRng>(
        validator_set: &ValidatorSet,
        batch_header: &BatchHeader<CurrentNetwork>,
        rng: &mut R,
    ) -> IndexSet<Signature<CurrentNetwork>> {
        let batch_id = batch_header.batch_id();
        validator_set.0.iter().map(|validator| validator.private_key.sign(&[batch_id], rng).unwrap()).collect()
    }

    /// Inserting the same certificate repeatedly must leave the storage unchanged after the first insert.
    #[proptest]
    fn test_certificate_duplicate(
        context: CommitteeContext,
        #[any(size_range(1..16).lift())] transmissions: Vec<(AnyTransmissionID, AnyTransmission)>,
        mut rng: CryptoTestRng,
        selector: Selector,
    ) {
        let CommitteeContext(committee, ValidatorSet(validators)) = context;
        let committee_id = committee.id();

        // Initialize the storage.
        let ledger = Arc::new(MockLedgerService::new(committee));
        let storage = Storage::<CurrentNetwork>::new(ledger, Arc::new(BFTMemoryService::new()), 1);

        // Ensure the storage starts out empty.
        assert_storage(&storage, &[], &[], &[], &Default::default());

        // Pick the batch author.
        let signer = selector.select(&validators);

        // Deduplicate the sampled transmissions by ID (on a collision, the last one wins).
        let mut transmission_map = IndexMap::new();
        for (AnyTransmissionID(id), AnyTransmission(t)) in transmissions.iter() {
            transmission_map.insert(*id, t.clone());
        }

        // Create a round-0 batch header containing every deduplicated transmission ID.
        let batch_header = BatchHeader::new(
            &signer.private_key,
            0,
            now(),
            committee_id,
            transmission_map.keys().cloned().collect(),
            Default::default(),
            &mut rng,
        )
        .unwrap();

        // Sign with every validator except the author. `signer` borrows the original set,
        // so the author is removed from a clone.
        let mut validators = validators.clone();
        validators.remove(signer);

        let certificate = BatchCertificate::from(
            batch_header.clone(),
            sign_batch_header(&ValidatorSet(validators), &batch_header, &mut rng),
        )
        .unwrap();

        let certificate_id = certificate.id();
        let round = certificate.round();
        let author = certificate.author();

        // The expected transmissions come from the same deduplicated map that is inserted below,
        // so expectation and inserted data agree even if a transmission ID was sampled twice.
        let internal_transmissions: HashMap<_, (_, IndexSet<Field<CurrentNetwork>>)> = transmission_map
            .iter()
            .map(|(id, t)| (*id, (t.clone(), indexset! { certificate_id })))
            .collect();
        let rounds = [(round, indexset! { (certificate_id, author) })];
        let certificates = [(certificate_id, certificate.clone())];
        let batch_ids = [(certificate_id, round)];

        let missing_transmissions: HashMap<TransmissionID<CurrentNetwork>, Transmission<CurrentNetwork>> =
            transmission_map.into_iter().collect();

        // First insertion populates the storage.
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), missing_transmissions.clone());
        assert!(storage.contains_certificate(certificate_id));
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);

        // Re-inserting without transmissions must not change anything.
        storage.insert_certificate_atomic(certificate.clone(), Default::default(), Default::default());
        assert!(storage.contains_certificate(certificate_id));
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);

        // Re-inserting with the same transmissions must not change anything either.
        storage.insert_certificate_atomic(certificate, Default::default(), missing_transmissions);
        assert!(storage.contains_certificate(certificate_id));
        assert_storage(&storage, &rounds, &certificates, &batch_ids, &internal_transmissions);
    }
}