Skip to main content

snap_coin/full_node/
mempool.rs

1use std::{collections::BTreeMap, sync::Arc, time::Duration};
2
3use tokio::{sync::RwLock, time::sleep};
4
5use crate::{
6    core::transaction::{Transaction, TransactionId},
7    economics::EXPIRATION_TIME,
8};
9
10pub struct MemPool {
11    /// BTreeMap of expiry timestamp -> transactions
12    pending: Arc<RwLock<BTreeMap<u64, Vec<Transaction>>>>,
13}
14
15impl MemPool {
16    pub fn new() -> Self {
17        MemPool {
18            pending: Arc::new(RwLock::new(BTreeMap::new())),
19        }
20    }
21
22    /// Starts a background task that removes expired transactions
23    pub fn start_expiry_watchdog(
24        &self,
25        mut on_expiry: impl FnMut(TransactionId) + Send + Sync + 'static,
26    ) {
27        let pending = self.pending.clone();
28        tokio::spawn(async move {
29            loop {
30                sleep(Duration::from_millis(500)).await;
31                let now = chrono::Utc::now().timestamp() as u64;
32
33                let mut write_guard = pending.write().await;
34
35                // Remove all expired transactions efficiently
36                let expired_keys: Vec<u64> = write_guard.range(..=now).map(|(&k, _)| k).collect();
37
38                for key in expired_keys {
39                    if let Some(txs) = write_guard.remove(&key) {
40                        for tx in txs {
41                            if let Some(tx_id) = tx.transaction_id {
42                                on_expiry(tx_id);
43                            }
44                        }
45                    }
46                }
47            }
48        });
49    }
50
51    /// Get a vector of all transactions in this mempool
52    pub async fn get_mempool(&self) -> Vec<Transaction> {
53        self.pending
54            .read()
55            .await
56            .values()
57            .flat_map(|v| v.iter().cloned())
58            .collect()
59    }
60
61    /// Add a transaction to the mempool
62    /// WARNING: Make sure this transaction is valid before
63    pub async fn add_transaction(&self, transaction: Transaction) {
64        let expiry = chrono::Utc::now().timestamp() as u64 + EXPIRATION_TIME;
65
66        let mut write_guard = self.pending.write().await;
67        write_guard.entry(expiry).or_default().push(transaction);
68    }
69
70    /// Returns true if a transaction is valid (check for double spending)
71    pub async fn validate_transaction(&self, transaction: &Transaction) -> bool {
72        let mempool = self.get_mempool().await;
73        for mempool_transaction in mempool {
74            if transaction.inputs.iter().any(|i| {
75                mempool_transaction.inputs.iter().any(|mi| {
76                    mi.output_index == i.output_index && mi.transaction_id == i.transaction_id
77                })
78            }) {
79                return false;
80            }
81        }
82        true
83    }
84
85    /// Remove transactions that have been spent
86    pub async fn spend_transactions(&self, transactions: Vec<TransactionId>) {
87        let mut pending = self.pending.write().await;
88
89        for txs in pending.values_mut() {
90            txs.retain(|mempool_tx| {
91                if let Some(id) = mempool_tx.transaction_id {
92                    !transactions.contains(&id)
93                } else {
94                    true
95                }
96            });
97        }
98
99        // Clean up empty expiry buckets
100        pending.retain(|_, txs| !txs.is_empty());
101    }
102
103    pub async fn mempool_size(&self) -> usize {
104        self.pending
105            .read()
106            .await
107            .values()
108            .fold(0, |acc, txs| acc + txs.len())
109    }
110
111    pub async fn clear(&self) {
112        self.pending.write().await.clear();
113    }
114}