1use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::Duration};
2
3use log::{error};
4use tokio::{
5 net::{
6 TcpStream,
7 tcp::{OwnedReadHalf, OwnedWriteHalf},
8 },
9 sync::{
10 Mutex,
11 mpsc::{self, Receiver},
12 oneshot,
13 },
14 time::{sleep, timeout},
15};
16
17use thiserror::Error;
18
19use crate::{
20 core::blockchain::BlockchainError,
21 node::{
22 message::{Command, Message, MessageId},peer_behavior::SharedPeerBehavior
23 },
24};
25
26pub enum Outgoing {
28 Request(Message, oneshot::Sender<Message>),
29 OneWay(Message),
30}
31
32type Pending = Arc<Mutex<HashMap<MessageId, oneshot::Sender<Message>>>>;
33type KillSignal = String;
34
/// Time limit applied by `PeerHandle::request`; when it is exceeded the peer is
/// killed and `PeerError::Timeout` is returned.
pub const PEER_TIMEOUT: Duration = Duration::from_secs(5);

/// Delay between two consecutive pings sent by the pinger task.
pub const PEER_PING_INTERVAL: Duration = Duration::from_secs(5);
40
41#[derive(Error, Debug)]
42pub enum PeerError {
43 #[error("IO error: {0}")]
44 Io(String),
45
46 #[error("Timeout waiting for peer response")]
47 Timeout,
48
49 #[error("Failed to send request to peer: {0}")]
50 SendError(String),
51
52 #[error("Failed to receive response from peer: {0}")]
53 ReceiveError(String),
54
55 #[error("Peer killed: {0}")]
56 Killed(String),
57
58 #[error("Message decoding error: {0}")]
59 MessageDecode(String),
60
61 #[error("Message encoding error: {0}")]
62 MessageEncode(String),
63
64 #[error("Peer disconnected unexpectedly")]
65 Disconnected,
66
67 #[error("Unknown error: {0}")]
68 Unknown(String),
69
70 #[error("Blockchain error: {0}")]
71 Blockchain(#[from] BlockchainError),
72
73 #[error("Sync error: {0}")]
74 SyncError(String),
75}
76
77#[derive(Clone, Debug)]
79pub struct PeerHandle {
80 pub address: SocketAddr,
81 pub is_client: bool,
82 send: mpsc::Sender<Outgoing>,
83 kill: Arc<Mutex<Option<oneshot::Sender<KillSignal>>>>,
84}
85
86impl PeerHandle {
87 pub async fn request(&self, request: Message) -> Result<Message, PeerError> {
89 let (callback_tx, callback_rx) = oneshot::channel::<Message>();
90
91 match timeout(
92 PEER_TIMEOUT,
93 self.send.send(Outgoing::Request(request, callback_tx)),
94 )
95 .await
96 {
97 Ok(res) => res.map_err(|e| PeerError::SendError(e.to_string()))?,
98 Err(_) => {
99 self.kill("Peer timed out".to_string()).await?;
100 return Err(PeerError::Timeout);
101 }
102 }
103
104 callback_rx
105 .await
106 .map_err(|e| PeerError::ReceiveError(e.to_string()))
107 }
108
109 pub async fn send(&self, message: Message) -> Result<(), PeerError> {
111 self.send
112 .send(Outgoing::OneWay(message))
113 .await
114 .map_err(|e| PeerError::SendError(e.to_string()))
115 }
116
117 pub async fn kill(&self, message: String) -> Result<(), PeerError> {
119 if let Some(kill) = self.kill.lock().await.take() {
120 kill.send(message.clone())
121 .map_err(|_| PeerError::Killed(message))?;
122 }
123 Ok(())
124 }
125}
126
127pub fn create_peer(
129 stream: TcpStream,
130 behavior: SharedPeerBehavior,
131 is_client: bool,
132) -> Result<PeerHandle, PeerError> {
133 let address = stream
134 .peer_addr()
135 .map_err(|e| PeerError::Io(format!("IO error: {e}")))?;
136
137 let (outgoing_tx, outgoing_rx) = mpsc::channel::<Outgoing>(64);
138 let (kill, should_kill) = oneshot::channel::<KillSignal>();
139
140 let handle = PeerHandle {
141 send: outgoing_tx,
142 kill: Arc::new(Mutex::new(Some(kill))),
143 is_client,
144 address,
145 };
146 let my_handle = handle.clone();
147
148 tokio::spawn(async move {
149 let behavior_on_kill = behavior.clone();
150 let my_handle_on_kill = my_handle.clone();
151 if let Err(e) = async move {
152 let (reader, writer) = stream.into_split();
153
154 let pending: Pending =
155 Arc::new(Mutex::new(HashMap::<MessageId, oneshot::Sender<Message>>::new()));
156
157 tokio::select! {
158 res = reader_task(reader, pending.clone(), my_handle.clone(), behavior.clone()) => res,
159 res = writer_task(writer, outgoing_rx, pending) => res,
160 res = pinger_task(my_handle, behavior.clone()) => res,
161 res = async move {
162 let message = should_kill
163 .await
164 .map_err(|_| PeerError::Killed("Kill channel closed".to_string()))?;
165 Err(PeerError::Killed(message))
166 } => res
167 }?;
168
169 Ok::<(), PeerError>(())
170 }
171 .await
172 {
173 tokio::spawn(async move {
174 behavior_on_kill.on_kill(&my_handle_on_kill).await;
175 error!("Peer error (disconnected): {e}");
176 });
177
178 }
179 });
180
181 Ok(handle)
182}
183
184async fn reader_task(
185 mut stream: OwnedReadHalf,
186 pending: Pending,
187 my_handle: PeerHandle,
188 behavior: SharedPeerBehavior
189) -> Result<(), PeerError> {
190 loop {
191 let message = Message::from_stream(&mut stream)
192 .await
193 .map_err(|e| PeerError::MessageDecode(e.to_string()))?;
194
195 if let Some(requester) = pending.lock().await.remove(&message.id) {
196 let _ = requester.send(message);
197 } else {
198 let response = behavior.on_message(message, &my_handle).await?;
199 my_handle.send(response).await?;
200 }
201 }
202}
203
204async fn writer_task(
205 mut stream: OwnedWriteHalf,
206 mut receiver: Receiver<Outgoing>,
207 pending: Pending,
208) -> Result<(), PeerError> {
209 while let Some(outgoing) = receiver.recv().await {
210 match outgoing {
211 Outgoing::Request(msg, responder) => {
212 pending.lock().await.insert(msg.id, responder);
213 msg.send(&mut stream)
214 .await
215 .map_err(|e| PeerError::MessageEncode(e.to_string()))?;
216 }
217 Outgoing::OneWay(msg) => {
218 msg.send(&mut stream)
219 .await
220 .map_err(|e| PeerError::MessageEncode(e.to_string()))?;
221 }
222 }
223 }
224 Err(PeerError::Disconnected)
225}
226
227async fn pinger_task(
228 my_handle: PeerHandle,
229 behavior: SharedPeerBehavior
230) -> Result<(), PeerError> {
231 loop {
232 sleep(PEER_PING_INTERVAL).await;
233 my_handle.request(
234 Message::new(Command::Ping {
235 height: behavior.get_height().await,
236 }),
237 )
238 .await?;
239 }
240}