Skip to main content

snap_coin/full_node/
node_state.rs

1use serde::{Deserialize, Serialize};
2use std::{
3    collections::{HashMap, HashSet, VecDeque},
4    net::{IpAddr, SocketAddr},
5    sync::Arc,
6};
7use tokio::sync::{
8    Mutex, RwLock, broadcast,
9    watch::{self, Ref},
10};
11
12use crate::{
13    core::{
14        block::Block,
15        difficulty::calculate_live_transaction_difficulty,
16        transaction::{Transaction, TransactionId},
17    },
18    crypto::Hash,
19    full_node::{
20        mempool::MemPool,
21        p2p_server::{BAN_SCORE_THRESHOLD, ClientHealthScores, PUNISHMENT},
22    },
23    node::peer::PeerHandle,
24};
25
26pub type SharedNodeState = Arc<NodeState>;
27
28pub struct NodeState {
29    pub connected_peers: RwLock<HashMap<SocketAddr, PeerHandle>>,
30    pub mempool: MemPool,
31    pub is_syncing: RwLock<bool>,
32    pub chain_events: broadcast::Sender<ChainEvent>,
33    pub processing: Mutex<()>,
34    pub client_health_scores: ClientHealthScores,
35    last_seen_block_reader: watch::Receiver<Hash>,
36    last_seen_block_writer: watch::Sender<Hash>,
37    last_seen_transactions_reader: watch::Receiver<VecDeque<TransactionId>>,
38    last_seen_transactions_writer: watch::Sender<VecDeque<TransactionId>>,
39}
40
41impl NodeState {
42    pub fn new_empty() -> SharedNodeState {
43        let (last_seen_block_writer, last_seen_block_reader) =
44            watch::channel(Hash::new_from_buf([0u8; 32]));
45        let (last_seen_transactions_writer, last_seen_transactions_reader) =
46            watch::channel(VecDeque::new());
47
48        Arc::new(NodeState {
49            connected_peers: RwLock::new(HashMap::new()),
50            mempool: MemPool::new(),
51            is_syncing: RwLock::new(false),
52            chain_events: broadcast::channel(64).0,
53            processing: Mutex::new(()),
54            last_seen_block_reader,
55            last_seen_block_writer,
56            last_seen_transactions_reader,
57            last_seen_transactions_writer,
58            client_health_scores: ClientHealthScores::new(HashMap::new()),
59        })
60    }
61
62    /// Get the latest seen block
63    pub fn last_seen_block(&self) -> Hash {
64        self.last_seen_block_reader.borrow().clone()
65    }
66
67    /// Set a new last seen block
68    pub fn set_last_seen_block(&self, hash: Hash) {
69        let _ = self.last_seen_block_writer.send(hash);
70    }
71
72    /// Get the latest seen transactions
73    pub fn last_seen_transactions(&self) -> Ref<'_, VecDeque<TransactionId>> {
74        self.last_seen_transactions_reader.borrow()
75    }
76
77    /// Add a new last seen transaction, removing the oldest if >500
78    pub fn add_last_seen_transaction(&self, tx_id: TransactionId) {
79        let mut transactions: VecDeque<TransactionId> =
80            self.last_seen_transactions_reader.borrow().clone();
81
82        // Avoid duplicates
83        if !transactions.contains(&tx_id) {
84            transactions.push_back(tx_id);
85
86            // Keep only the latest 500 transactions
87            if transactions.len() > 500 {
88                transactions.pop_front(); // remove oldest
89            }
90
91            let _ = self.last_seen_transactions_writer.send(transactions);
92        }
93    }
94
95    pub async fn get_live_transaction_difficulty(
96        &self,
97        transaction_difficulty: [u8; 32],
98    ) -> [u8; 32] {
99        calculate_live_transaction_difficulty(
100            &transaction_difficulty,
101            self.mempool.mempool_size().await,
102        )
103    }
104
105    /// Punish a IP address
106    pub async fn punish_ip(&self, ip: IpAddr) {
107        *self
108            .client_health_scores
109            .write()
110            .await
111            .entry(ip)
112            .or_insert(0) += PUNISHMENT;
113    }
114
115    /// "Forgive" everyone by 1 pt
116    pub async fn decrement_punishments(&self) {
117        let mut scores = self.client_health_scores.write().await;
118
119        let mut to_remove = Vec::new();
120
121        for (ip, score) in scores.iter_mut() {
122            *score = score.saturating_sub(PUNISHMENT);
123            if *score == 0 {
124                to_remove.push(ip.clone());
125            }
126        }
127
128        for ip in to_remove {
129            scores.remove(&ip);
130        }
131    }
132
133    /// Get a list of banned ips
134    pub async fn get_banned_ips(&self) -> HashSet<IpAddr> {
135        self.client_health_scores
136            .read()
137            .await
138            .iter()
139            .filter(|(_ip, score)| score > &&BAN_SCORE_THRESHOLD)
140            .map(|(ip, _score)| *ip)
141            .collect()
142    }
143}
144
145#[derive(Serialize, Deserialize, Clone)]
146pub enum ChainEvent {
147    Block { block: Block },
148    Transaction { transaction: Transaction },
149    TransactionExpiration { transaction: TransactionId },
150}