Skip to main content

snap_coin/full_node/
mempool.rs

1use std::{collections::BTreeMap, sync::Arc, time::Duration};
2
3use tokio::{sync::RwLock, time::sleep};
4
5use crate::{
6    core::transaction::{Transaction, TransactionId},
7    economics::EXPIRATION_TIME,
8};
9
10pub struct MemPool {
11    /// BTreeMap of expiry timestamp -> transactions
12    pending: Arc<RwLock<BTreeMap<u64, Vec<Transaction>>>>,
13}
14
15impl MemPool {
16    pub fn new() -> Self {
17        MemPool {
18            pending: Arc::new(RwLock::new(BTreeMap::new())),
19        }
20    }
21
22    /// Starts a background task that removes expired transactions
23    pub fn start_expiry_watchdog(&self, mut on_expiry: impl FnMut(TransactionId) + Send + Sync + 'static) {
24        let pending = self.pending.clone();
25        tokio::spawn(async move {
26            loop {
27                sleep(Duration::from_millis(500)).await;
28                let now = chrono::Utc::now().timestamp() as u64;
29
30                let mut write_guard = pending.write().await;
31
32                // Remove all expired transactions efficiently
33                let expired_keys: Vec<u64> = write_guard.range(..=now).map(|(&k, _)| k).collect();
34
35                for key in expired_keys {
36                    if let Some(txs) = write_guard.remove(&key) {
37                        for tx in txs {
38                            if let Some(tx_id) = tx.transaction_id {
39                                on_expiry(tx_id);
40                            }
41                        }
42                    }
43                }
44            }
45        });
46    }
47
48    /// Get a vector of all transactions in this mempool
49    pub async fn get_mempool(&self) -> Vec<Transaction> {
50        self.pending
51            .read()
52            .await
53            .values()
54            .flat_map(|v| v.iter().cloned())
55            .collect()
56    }
57
58    /// Add a transaction to the mempool
59    /// WARNING: Make sure this transaction is valid before
60    pub async fn add_transaction(&self, transaction: Transaction) {
61        let expiry = chrono::Utc::now().timestamp() as u64 + EXPIRATION_TIME;
62
63        let mut write_guard = self.pending.write().await;
64        write_guard.entry(expiry).or_default().push(transaction);
65    }
66
67    /// Returns true if a transaction is valid (check for double spending)
68    pub async fn validate_transaction(&self, transaction: &Transaction) -> bool {
69        let mempool = self.get_mempool().await;
70        for mempool_transaction in mempool {
71            if transaction.inputs.iter().any(|i| {
72                mempool_transaction.inputs.iter().any(|mi| {
73                    mi.output_index == i.output_index && mi.transaction_id == i.transaction_id
74                })
75            }) {
76                return false;
77            }
78        }
79        true
80    }
81
82    /// Remove transactions that have been spent
83    pub async fn spend_transactions(&self, transactions: Vec<TransactionId>) {
84        let mut pending = self.pending.write().await;
85
86        for txs in pending.values_mut() {
87            txs.retain(|mempool_tx| {
88                if let Some(id) = mempool_tx.transaction_id {
89                    !transactions.contains(&id)
90                } else {
91                    true
92                }
93            });
94        }
95
96        // Clean up empty expiry buckets
97        pending.retain(|_, txs| !txs.is_empty());
98    }
99
100    pub async fn mempool_size(&self) -> usize {
101        self.pending
102            .read()
103            .await
104            .values()
105            .fold(0, |acc, txs| acc + txs.len())
106    }
107}