snap_coin/node/
peer.rs

1use bincode::error::EncodeError;
2use std::{
3    collections::{HashMap, VecDeque},
4    net::SocketAddr,
5    sync::Arc,
6    time::Duration,
7};
8use thiserror::Error;
9use tokio::{
10    net::TcpStream,
11    sync::{RwLock, oneshot},
12    task::JoinHandle,
13    time::{sleep, timeout},
14};
15
16use crate::{
17    core::{blockchain::BlockchainError, utxo::TransactionError},
18    node::{
19        message::{Command, Message, MessageError},
20        node::Node,
21        sync::sync_to_peer,
22    },
23};
24
25#[derive(Error, Debug)]
26pub enum PeerError {
27    #[error("{0}")]
28    MessageError(#[from] MessageError),
29
30    #[error("Disconnected")]
31    Disconnected,
32
33    #[error("Blockchain error: {0}")]
34    BlockchainError(#[from] BlockchainError),
35
36    #[error("Transaction error: {0}")]
37    TransactionError(#[from] TransactionError),
38
39    #[error("Sync peer returned an invalid response")]
40    SyncResponseInvalid,
41
42    #[error("Could not find fork point with peer")]
43    NoForkPoint,
44
45    #[error("Block has invalid difficulty")]
46    BadBlockDifficulty,
47
48    #[error("Block has invalid block hash")]
49    BadBlockHash,
50
51    #[error("Block has no block hash attached")]
52    NoBlockHash,
53
54    #[error("Encode error: {0}")]
55    EncodeError(#[from] EncodeError),
56}
57
/// Upper bound applied in [`Peer::connect`] to handling one incoming message
/// and to writing one outgoing message; exceeding it drops the connection.
pub const TIMEOUT: Duration = Duration::from_secs(15);
59
60/// A struct representing one peer (peer connection. Can be both a client peer or a connected peer)
61pub struct Peer {
62    pub address: SocketAddr,
63
64    pub is_client: bool,
65
66    // Outgoing messages waiting to be written to stream
67    send_queue: VecDeque<Message>,
68
69    // Pending requests waiting for a response (id -> oneshot sender)
70    pending: HashMap<u16, oneshot::Sender<Message>>,
71}
72
73impl Peer {
74    /// Create a new peer
75    pub fn new(address: SocketAddr, is_client: bool) -> Self {
76        Self {
77            address,
78            is_client,
79            send_queue: VecDeque::new(),
80            pending: HashMap::new(),
81        }
82    }
83
84    async fn on_fail(peer: Arc<RwLock<Peer>>, node: Arc<RwLock<Node>>) {
85        let peer_address = peer.read().await.address;
86
87        let mut node_peers = node.write().await;
88
89        let mut new_peers = Vec::new();
90        for p in node_peers.peers.drain(..) {
91            let p_address = p.read().await.address;
92            if p_address != peer_address {
93                new_peers.push(p);
94            }
95        }
96
97        node_peers.peers = new_peers;
98    }
99
100    /// Main connection handler
101    pub async fn connect(
102        peer: Arc<RwLock<Peer>>,
103        node: Arc<RwLock<Node>>,
104        stream: TcpStream,
105    ) -> JoinHandle<Result<(), PeerError>> {
106        let (mut read_stream, mut write_stream) = stream.into_split();
107
108        // Spawn peer handler task
109        tokio::spawn(async move {
110            let peer_cloned = peer.clone();
111            let node_cloned = node.clone();
112
113            // Spawn ping / pong task
114            let pinger = {
115                let peer = peer.clone();
116                let node = node.clone();
117                Box::pin(async move {
118                    loop {
119                        sleep(Duration::from_secs(5)).await; // 5 second ping interval
120                        let height = node.read().await.blockchain.get_height();
121                        match Peer::request(
122                            // Send Ping and wait for Pong
123                            peer.clone(),
124                            Message::new(Command::Ping { height }),
125                        )
126                        .await?
127                        .command
128                        {
129                            Command::Pong { .. } => {}
130                            _ => {}
131                        }
132                    }
133                    #[allow(unreachable_code)]
134                    Ok::<(), PeerError>(())
135                })
136            };
137
138            // Spawn reader task
139            let reader = {
140                let peer = peer.clone();
141                let node = node.clone();
142                Box::pin(async move {
143                    loop {
144                        let msg = Message::from_stream(&mut read_stream).await?;
145                        match timeout(
146                            TIMEOUT,
147                            Peer::handle_incoming(peer.clone(), node.clone(), msg),
148                        )
149                        .await
150                        {
151                            Ok(()) => {}
152                            Err(..) => return Err(PeerError::Disconnected),
153                        }
154                    }
155                    #[allow(unreachable_code)]
156                    Ok::<(), PeerError>(())
157                })
158            };
159
160            // Spawn writer task
161            let writer = {
162                let peer = peer.clone();
163                Box::pin(async move {
164                    loop {
165                        let maybe_msg = {
166                            let mut p = peer.write().await;
167                            p.send_queue.pop_front()
168                        };
169
170                        if let Some(msg) = maybe_msg {
171                            match timeout(TIMEOUT, msg.send(&mut write_stream)).await {
172                                Ok(e) => e?,
173                                Err(..) => return Err(PeerError::Disconnected),
174                            }
175                        } else {
176                            sleep(Duration::from_millis(10)).await;
177                        }
178                    }
179                    #[allow(unreachable_code)]
180                    Ok::<(), PeerError>(())
181                })
182            };
183
184            // Join all tasks
185            let result = tokio::select! {
186              r = reader => r,
187              r = writer => r,
188              r = pinger => r,
189            };
190
191            if let Err(e) = result {
192                Node::log(format!(
193                    "Disconnected peer: {}:{}. Error: {:?}",
194                    peer.read().await.address.ip(),
195                    peer.read().await.address.port(),
196                    e
197                ));
198                let peer_cloned = peer_cloned.clone();
199                let node_cloned = node_cloned.clone();
200
201                tokio::spawn(async move {
202                    Self::on_fail(peer_cloned, node_cloned).await;
203                });
204            }
205            Ok(())
206        })
207    }
208
209    /// Handle incoming message
210    async fn handle_incoming(peer: Arc<RwLock<Peer>>, node: Arc<RwLock<Node>>, message: Message) {
211        {
212            let mut p = peer.write().await;
213            if let Some(tx) = p.pending.remove(&message.id) {
214                let _ = tx.send(message);
215                return;
216            }
217        }
218
219        Peer::on_message(peer.clone(), node.clone(), message).await;
220    }
221
222    /// Handle incoming message
223    async fn on_message(peer: Arc<RwLock<Peer>>, node: Arc<RwLock<Node>>, message: Message) {
224        if let Err(err) = async {
225            match message.command {
226                Command::Connect => {
227                    Peer::send(peer, message.make_response(Command::AcknowledgeConnection)).await;
228                }
229                Command::AcknowledgeConnection => {
230                    Node::log(format!("Got unhandled AcknowledgeConnection"));
231                }
232                Command::Ping { height } => {
233                    let local_height = node.read().await.blockchain.get_height();
234                    Peer::send(
235                        peer.clone(),
236                        message.make_response(Command::Pong {
237                            height: local_height,
238                        }),
239                    )
240                    .await;
241
242                    // Only spawn the sync task if we are allowed to sync
243                    let should_sync = {
244                        let mut node = node.write().await;
245                        if node.is_syncing {
246                            false
247                        } else {
248                            node.is_syncing = true;
249                            true
250                        }
251                    };
252
253                    if should_sync && local_height < height {
254                        tokio::spawn(async move {
255                            let result = sync_to_peer(node.clone(), peer.clone(), height).await;
256
257                            if let Err(e) = result {
258                                Node::log(format!("[SYNC] Failed: {}", e));
259                            } else {
260                                Node::log(format!("[SYNC] Completed"));
261                            }
262
263                            node.write().await.is_syncing = false;
264                        });
265                    }
266                }
267                Command::Pong { .. } => {
268                    Node::log(format!("Got unhandled Pong"));
269                }
270                Command::GetPeers => {
271                    let peers: Vec<String> = {
272                        let node_read = node.read().await;
273                        let mut peer_addrs = Vec::new();
274                        for p in &node_read.peers {
275                            if p.read().await.is_client {
276                                continue;
277                            }
278                            let p_addr = p.read().await.address.to_string();
279                            peer_addrs.push(p_addr);
280                        }
281                        peer_addrs
282                    };
283                    let response = message.make_response(Command::SendPeers { peers });
284                    Peer::send(peer, response).await;
285                }
286                Command::SendPeers { .. } => {
287                    Node::log(format!("Got unhandled SendPeers"));
288                }
289                Command::NewBlock { ref block } => {
290                    // Make sure block is not in the blockchain
291                    if Some(node.read().await.last_seen_block) != block.hash {
292                        Node::submit_block(node.clone(), block.clone()).await?;
293                    }
294                }
295                Command::NewTransaction { ref transaction } => {
296                    // Check if transaction was already seen
297                    if !node
298                        .read()
299                        .await
300                        .mempool
301                        .validate_transaction(transaction)
302                        .await
303                    {
304                        return Ok(());
305                    }
306
307                    Node::submit_transaction(node, transaction.clone()).await?;
308                }
309                Command::GetBlock { block_hash } => {
310                    Peer::send(
311                        peer,
312                        message.make_response(Command::GetBlockResponse {
313                            block: node.read().await.blockchain.get_block_by_hash(&block_hash),
314                        }),
315                    )
316                    .await;
317                }
318                Command::GetBlockResponse { .. } => {
319                    Node::log(format!("Got unhandled SendBlock"));
320                }
321                Command::GetBlockHashes { start, end } => {
322                    let mut block_hashes = Vec::new();
323                    for i in start..end {
324                        if let Some(block_hash) =
325                            node.read().await.blockchain.get_block_hash_by_height(i)
326                        {
327                            block_hashes.push(*block_hash);
328                        }
329                    }
330                    Peer::send(
331                        peer,
332                        message.make_response(Command::GetBlockHashesResponse { block_hashes }),
333                    )
334                    .await;
335                }
336                Command::GetBlockHashesResponse { .. } => {
337                    Node::log(format!("Got unhandled SendBlockHashes"));
338                }
339            };
340            Ok::<(), PeerError>(())
341        }
342        .await
343        {
344            Node::log(format!("Error processing incoming message: {err}"));
345        }
346    }
347
348    /// Send a request and wait for the response
349    pub async fn request(peer: Arc<RwLock<Peer>>, message: Message) -> Result<Message, PeerError> {
350        let id = message.id;
351
352        let (tx, rx) = oneshot::channel();
353
354        {
355            let mut p = peer.write().await;
356            p.pending.insert(id, tx);
357            p.send_queue.push_back(message);
358        }
359
360        match timeout(Duration::from_secs(10), rx).await {
361            Ok(Ok(msg)) => Ok(msg),
362            Ok(Err(_)) => Err(PeerError::Disconnected),
363            Err(_) => Err(PeerError::Disconnected),
364        }
365    }
366
367    /// Send a message to this peer, without expecting a response
368    pub async fn send(peer: Arc<RwLock<Peer>>, message: Message) {
369        let mut p = peer.write().await;
370        p.send_queue.push_back(message);
371    }
372
373    /// Send this message to all peers but this one
374    pub async fn send_to_peers(node: Arc<RwLock<Node>>, message: Message) {
375        // clone the peer list while holding the lock, then drop the lock
376        let peers = {
377            let guard = node.read().await;
378            guard.peers.clone()
379        };
380
381        for peer in peers {
382            // now safe to await
383            Peer::send(peer, message.clone()).await;
384        }
385    }
386}