snap_coin/node/
peer.rs

1use bincode::error::EncodeError;
2use std::{
3    collections::{HashMap, VecDeque},
4    net::SocketAddr,
5    sync::Arc,
6    time::Duration,
7};
8use thiserror::Error;
9use tokio::{
10    net::TcpStream,
11    sync::{RwLock, oneshot},
12    task::JoinHandle,
13    time::{sleep, timeout},
14};
15
16use crate::{
17    core::{blockchain::BlockchainError, utxo::TransactionError},
18    node::{
19        message::{Command, Message, MessageError},
20        node::Node,
21        sync::sync_to_peer,
22    },
23};
24
/// Errors surfaced by peer connection handling.
///
/// Variants marked `#[from]` can be produced from the wrapped error type with `?`.
#[derive(Error, Debug)]
pub enum PeerError {
    /// A wrapped [`MessageError`]; displayed using the inner error's own message.
    #[error("{0}")]
    MessageError(#[from] MessageError),

    /// The connection is considered dead. In this file it is returned when the
    /// peer's shutdown flag is set, when a timeout elapses, or when a pending
    /// request's response channel is dropped.
    #[error("Disconnected")]
    Disconnected,

    /// A wrapped [`BlockchainError`].
    #[error("Blockchain error: {0}")]
    BlockchainError(#[from] BlockchainError),

    /// A wrapped [`TransactionError`].
    #[error("Transaction error: {0}")]
    TransactionError(#[from] TransactionError),

    /// Not constructed in this file; presumably raised by the sync code — verify.
    #[error("Sync peer returned an invalid response")]
    SyncResponseInvalid,

    /// Not constructed in this file; presumably raised by the sync code — verify.
    #[error("Could not find fork point with peer")]
    NoForkPoint,

    /// Not constructed in this file; presumably raised during block validation — verify.
    #[error("Block has invalid difficulty")]
    BadBlockDifficulty,

    /// Not constructed in this file; presumably raised during block validation — verify.
    #[error("Block has invalid block hash")]
    BadBlockHash,

    /// Not constructed in this file; presumably raised during block validation — verify.
    #[error("Block has no block hash attached")]
    NoBlockHash,

    /// A wrapped bincode [`EncodeError`].
    #[error("Encode error: {0}")]
    EncodeError(#[from] EncodeError),
}
57
/// Upper bound (15 seconds) applied in `Peer::connect` to handling a single
/// incoming message and to writing a single outgoing message; exceeding it
/// ends the connection with `PeerError::Disconnected`.
pub const TIMEOUT: Duration = Duration::from_secs(15);
59
/// State for a single peer connection. A peer can be either a client peer or
/// a connected peer (see `is_client`).
pub struct Peer {
    /// Socket address of the peer; also used to identify the peer when it is
    /// removed from the node's peer list.
    pub address: SocketAddr,

    /// Peers with this flag set are left out of `GetPeers` responses.
    pub is_client: bool,

    /// Outgoing messages waiting to be written to the stream.
    send_queue: VecDeque<Message>,

    /// Pending requests waiting for a response (message id -> oneshot sender).
    pending: HashMap<u16, oneshot::Sender<Message>>,

    /// When true, the reader, writer and pinger loops stop with
    /// `PeerError::Disconnected` on their next check.
    shutdown: bool,
}
74
75impl Peer {
76    /// Create a new peer
77    pub fn new(address: SocketAddr, is_client: bool) -> Self {
78        Self {
79            address,
80            is_client,
81            send_queue: VecDeque::new(),
82            pending: HashMap::new(),
83            shutdown: false,
84        }
85    }
86
87    async fn on_fail(peer: Arc<RwLock<Peer>>, node: Arc<RwLock<Node>>) {
88        let peer_address = peer.read().await.address;
89
90        let mut node_peers = node.write().await;
91
92        let mut new_peers = Vec::new();
93        for p in node_peers.peers.drain(..) {
94            let p_address = p.read().await.address;
95            if p_address != peer_address {
96                new_peers.push(p);
97            }
98        }
99
100        node_peers.peers = new_peers;
101    }
102
103    /// Main connection handler
104    pub async fn connect(
105        peer: Arc<RwLock<Peer>>,
106        node: Arc<RwLock<Node>>,
107        stream: TcpStream,
108    ) -> JoinHandle<Result<(), PeerError>> {
109        let (mut read_stream, mut write_stream) = stream.into_split();
110
111        // Spawn peer handler task
112        tokio::spawn(async move {
113            let peer_cloned = peer.clone();
114            let node_cloned = node.clone();
115            // Spawn ping / pong task
116            let pinger = {
117                let peer_outer = peer.clone();
118                let node_outer = node.clone();
119
120                Box::pin(async move {
121                    loop {
122                        sleep(Duration::from_secs(5)).await;
123                        if peer_outer.read().await.shutdown {
124                            return Err(PeerError::Disconnected);
125                        }
126
127                        let height = node_outer.read().await.blockchain.get_height();
128
129                        let response = Peer::request(
130                            peer_outer.clone(),
131                            Message::new(Command::Ping { height }),
132                        )
133                        .await?;
134
135                        if let Command::Pong { height } = response.command {
136                            let local_height = node_outer.read().await.blockchain.get_height();
137
138                            if local_height < height {
139                                let node_for_task = node_outer.clone();
140                                let peer_for_task = peer_outer.clone();
141
142                                tokio::spawn(async move {
143                                    if node_for_task.read().await.is_syncing {
144                                        return;
145                                    }
146
147                                    node_for_task.write().await.is_syncing = true;
148
149                                    let result = sync_to_peer(
150                                        node_for_task.clone(),
151                                        peer_for_task.clone(),
152                                        height,
153                                    )
154                                    .await;
155
156                                    if let Err(e) = result {
157                                        Node::log(format!(
158                                            "[SYNC] Failed: {}, disconnecting from {}",
159                                            e,
160                                            peer_for_task.read().await.address
161                                        ));
162                                        peer_for_task.write().await.shutdown = true;
163
164                                        let node_for_task = node_for_task.clone();
165                                        Peer::on_fail(peer_for_task, node_for_task).await;
166                                    } else {
167                                        Node::log("[SYNC] Completed".to_string());
168                                    }
169
170                                    node_for_task.write().await.is_syncing = false;
171                                });
172                            }
173                        }
174                    }
175                    #[allow(unreachable_code)]
176                    Ok::<(), PeerError>(())
177                })
178            };
179
180            // Spawn reader task
181            let reader = {
182                let peer = peer.clone();
183                let node = node.clone();
184                Box::pin(async move {
185                    loop {
186                        if peer.read().await.shutdown {
187                            return Err(PeerError::Disconnected);
188                        }
189                        let msg = Message::from_stream(&mut read_stream).await?;
190                        match timeout(
191                            TIMEOUT,
192                            Peer::handle_incoming(peer.clone(), node.clone(), msg),
193                        )
194                        .await
195                        {
196                            Ok(()) => {}
197                            Err(..) => return Err(PeerError::Disconnected),
198                        }
199                    }
200                    #[allow(unreachable_code)]
201                    Ok::<(), PeerError>(())
202                })
203            };
204
205            // Spawn writer task
206            let writer = {
207                let peer = peer.clone();
208                Box::pin(async move {
209                    loop {
210                        if peer.read().await.shutdown {
211                            return Err(PeerError::Disconnected);
212                        }
213                        let maybe_msg = {
214                            let mut p = peer.write().await;
215                            p.send_queue.pop_front()
216                        };
217
218                        if let Some(msg) = maybe_msg {
219                            match timeout(TIMEOUT, msg.send(&mut write_stream)).await {
220                                Ok(e) => e?,
221                                Err(..) => return Err(PeerError::Disconnected),
222                            }
223                        } else {
224                            sleep(Duration::from_millis(10)).await;
225                        }
226                    }
227                    #[allow(unreachable_code)]
228                    Ok::<(), PeerError>(())
229                })
230            };
231
232            // Join all tasks
233            let result = tokio::select! {
234              r = reader => r,
235              r = writer => r,
236              r = pinger => r,
237            };
238
239            if let Err(e) = result {
240                Node::log(format!(
241                    "Disconnected from peer: {}:{}. Error: {:?}",
242                    peer.read().await.address.ip(),
243                    peer.read().await.address.port(),
244                    e
245                ));
246
247                tokio::spawn(async move {
248                    Self::on_fail(peer_cloned, node_cloned).await;
249                });
250            }
251            Ok(())
252        })
253    }
254
255    /// Handle incoming message
256    async fn handle_incoming(peer: Arc<RwLock<Peer>>, node: Arc<RwLock<Node>>, message: Message) {
257        {
258            let mut p = peer.write().await;
259            if let Some(tx) = p.pending.remove(&message.id) {
260                let _ = tx.send(message);
261                return;
262            }
263        }
264
265        Peer::on_message(peer.clone(), node.clone(), message).await;
266    }
267
268    /// Handle incoming message
269    async fn on_message(peer: Arc<RwLock<Peer>>, node: Arc<RwLock<Node>>, message: Message) {
270        if let Err(err) = async {
271            match message.command {
272                Command::Connect => {
273                    Peer::send(peer, message.make_response(Command::AcknowledgeConnection)).await;
274                }
275                Command::AcknowledgeConnection => {
276                    Node::log(format!("Got unhandled AcknowledgeConnection"));
277                }
278                Command::Ping { height: _ } => {
279                    Peer::send(
280                        peer.clone(),
281                        message.make_response(Command::Pong {
282                            height: node.read().await.blockchain.get_height(),
283                        }),
284                    )
285                    .await;
286                }
287                Command::Pong { .. } => {
288                    Node::log(format!("Got unhandled Pong"));
289                }
290                Command::GetPeers => {
291                    let peers: Vec<String> = {
292                        let node_read = node.read().await;
293                        let mut peer_addrs = Vec::new();
294                        for p in &node_read.peers {
295                            if p.read().await.is_client {
296                                continue;
297                            }
298                            let p_addr = p.read().await.address.to_string();
299                            peer_addrs.push(p_addr);
300                        }
301                        peer_addrs
302                    };
303                    let response = message.make_response(Command::SendPeers { peers });
304                    Peer::send(peer, response).await;
305                }
306                Command::SendPeers { .. } => {
307                    Node::log(format!("Got unhandled SendPeers"));
308                }
309                Command::NewBlock { ref block } => {
310                    // Make sure block is not in the blockchain
311                    if Some(node.read().await.last_seen_block) != block.hash && !node.read().await.is_syncing {
312                        Node::submit_block(node.clone(), block.clone()).await?;
313                    }
314                }
315                Command::NewTransaction { ref transaction } => {
316                    // Check if transaction was already seen
317                    if !node
318                        .read()
319                        .await
320                        .mempool
321                        .validate_transaction(transaction)
322                        .await
323                    {
324                        return Ok(());
325                    }
326
327                    Node::submit_transaction(node, transaction.clone()).await?;
328                }
329                Command::GetBlock { block_hash } => {
330                    Peer::send(
331                        peer,
332                        message.make_response(Command::GetBlockResponse {
333                            block: node.read().await.blockchain.get_block_by_hash(&block_hash),
334                        }),
335                    )
336                    .await;
337                }
338                Command::GetBlockResponse { .. } => {
339                    Node::log(format!("Got unhandled SendBlock"));
340                }
341                Command::GetBlockHashes { start, end } => {
342                    let mut block_hashes = Vec::new();
343                    for i in start..end {
344                        if let Some(block_hash) =
345                            node.read().await.blockchain.get_block_hash_by_height(i)
346                        {
347                            block_hashes.push(*block_hash);
348                        }
349                    }
350                    Peer::send(
351                        peer,
352                        message.make_response(Command::GetBlockHashesResponse { block_hashes }),
353                    )
354                    .await;
355                }
356                Command::GetBlockHashesResponse { .. } => {
357                    Node::log(format!("Got unhandled SendBlockHashes"));
358                }
359            };
360            Ok::<(), PeerError>(())
361        }
362        .await
363        {
364            Node::log(format!("Error processing incoming message: {err}"));
365        }
366    }
367
368    /// Send a request and wait for the response
369    pub async fn request(peer: Arc<RwLock<Peer>>, message: Message) -> Result<Message, PeerError> {
370        let id = message.id;
371
372        let (tx, rx) = oneshot::channel();
373
374        {
375            let mut p = peer.write().await;
376            p.pending.insert(id, tx);
377            p.send_queue.push_back(message);
378        }
379
380        match timeout(Duration::from_secs(10), rx).await {
381            Ok(Ok(msg)) => Ok(msg),
382            Ok(Err(_)) => Err(PeerError::Disconnected),
383            Err(_) => Err(PeerError::Disconnected),
384        }
385    }
386
387    /// Send a message to this peer, without expecting a response
388    pub async fn send(peer: Arc<RwLock<Peer>>, message: Message) {
389        let mut p = peer.write().await;
390        p.send_queue.push_back(message);
391    }
392
393    /// Send this message to all peers but this one
394    pub async fn send_to_peers(node: Arc<RwLock<Node>>, message: Message) {
395        // clone the peer list while holding the lock, then drop the lock
396        let peers = {
397            let guard = node.read().await;
398            guard.peers.clone()
399        };
400
401        for peer in peers {
402            // now safe to await
403            Peer::send(peer, message.clone()).await;
404        }
405    }
406}