1use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::Duration};
2
3use log::{error};
4use tokio::{
5 net::{
6 TcpStream,
7 tcp::{OwnedReadHalf, OwnedWriteHalf},
8 },
9 sync::{
10 Mutex,
11 mpsc::{self, Receiver},
12 oneshot,
13 },
14 time::{sleep, timeout},
15};
16
17use thiserror::Error;
18
19use crate::{
20 core::blockchain::BlockchainError, light_node::block_meta_store::BlockMetaStoreError, node::{
21 message::{Command, Message, MessageId},peer_behavior::SharedPeerBehavior
22 }
23};
24
25pub enum Outgoing {
27 Request(Message, oneshot::Sender<Message>),
28 OneWay(Message),
29}
30
31type Pending = Arc<Mutex<HashMap<MessageId, oneshot::Sender<Message>>>>;
32type KillSignal = String;
33
/// Time limit applied by `PeerHandle::request`; exceeding it kills the peer.
pub const PEER_TIMEOUT: Duration = Duration::from_secs(5);

/// Pause between two consecutive pings sent by the pinger task.
pub const PEER_PING_INTERVAL: Duration = Duration::from_secs(5);
39
40#[derive(Error, Debug)]
41pub enum PeerError {
42 #[error("IO error: {0}")]
43 Io(String),
44
45 #[error("Timeout waiting for peer response")]
46 Timeout,
47
48 #[error("Failed to send request to peer: {0}")]
49 SendError(String),
50
51 #[error("Failed to receive response from peer: {0}")]
52 ReceiveError(String),
53
54 #[error("Peer killed: {0}")]
55 Killed(String),
56
57 #[error("Message decoding error: {0}")]
58 MessageDecode(String),
59
60 #[error("Message encoding error: {0}")]
61 MessageEncode(String),
62
63 #[error("Peer disconnected unexpectedly")]
64 Disconnected,
65
66 #[error("Unknown error: {0}")]
67 Unknown(String),
68
69 #[error("Blockchain error: {0}")]
70 Blockchain(#[from] BlockchainError),
71
72 #[error("Meta Store error: {0}")]
73 BlockMetaStore(#[from] BlockMetaStoreError),
74
75 #[error("Sync error: {0}")]
76 SyncError(String),
77
78 #[error("Incorrect response received")]
79 IncorrectResponse,
80}
81
82#[derive(Clone, Debug)]
84pub struct PeerHandle {
85 pub address: SocketAddr,
86 pub is_client: bool,
87 send: mpsc::Sender<Outgoing>,
88 kill: Arc<Mutex<Option<oneshot::Sender<KillSignal>>>>,
89}
90
91impl PeerHandle {
92 pub async fn request(&self, request: Message) -> Result<Message, PeerError> {
94 let (callback_tx, callback_rx) = oneshot::channel::<Message>();
95
96 match timeout(
97 PEER_TIMEOUT,
98 self.send.send(Outgoing::Request(request, callback_tx)),
99 )
100 .await
101 {
102 Ok(res) => res.map_err(|e| PeerError::SendError(e.to_string()))?,
103 Err(_) => {
104 self.kill("Peer timed out".to_string()).await?;
105 return Err(PeerError::Timeout);
106 }
107 }
108
109 callback_rx
110 .await
111 .map_err(|e| PeerError::ReceiveError(e.to_string()))
112 }
113
114 pub async fn send(&self, message: Message) -> Result<(), PeerError> {
116 self.send
117 .send(Outgoing::OneWay(message))
118 .await
119 .map_err(|e| PeerError::SendError(e.to_string()))
120 }
121
122 pub async fn kill(&self, message: String) -> Result<(), PeerError> {
124 if let Some(kill) = self.kill.lock().await.take() {
125 kill.send(message.clone())
126 .map_err(|_| PeerError::Killed(message))?;
127 }
128 Ok(())
129 }
130}
131
132pub fn create_peer(
134 stream: TcpStream,
135 behavior: SharedPeerBehavior,
136 is_client: bool,
137) -> Result<PeerHandle, PeerError> {
138 let address = stream
139 .peer_addr()
140 .map_err(|e| PeerError::Io(format!("IO error: {e}")))?;
141
142 let (outgoing_tx, outgoing_rx) = mpsc::channel::<Outgoing>(64);
143 let (kill, should_kill) = oneshot::channel::<KillSignal>();
144
145 let handle = PeerHandle {
146 send: outgoing_tx,
147 kill: Arc::new(Mutex::new(Some(kill))),
148 is_client,
149 address,
150 };
151 let my_handle = handle.clone();
152
153 tokio::spawn(async move {
154 let behavior_on_kill = behavior.clone();
155 let my_handle_on_kill = my_handle.clone();
156 if let Err(e) = async move {
157 let (reader, writer) = stream.into_split();
158
159 let pending: Pending =
160 Arc::new(Mutex::new(HashMap::<MessageId, oneshot::Sender<Message>>::new()));
161
162 tokio::select! {
163 res = reader_task(reader, pending.clone(), my_handle.clone(), behavior.clone()) => res,
164 res = writer_task(writer, outgoing_rx, pending) => res,
165 res = pinger_task(my_handle, behavior.clone()) => res,
166 res = async move {
167 let message = should_kill
168 .await
169 .map_err(|_| PeerError::Killed("Kill channel closed".to_string()))?;
170 Err(PeerError::Killed(message))
171 } => res
172 }?;
173
174 Ok::<(), PeerError>(())
175 }
176 .await
177 {
178 tokio::spawn(async move {
179 behavior_on_kill.on_kill(&my_handle_on_kill).await;
180 error!("Peer error (disconnected): {e}");
181 });
182
183 }
184 });
185
186 Ok(handle)
187}
188
189async fn reader_task(
190 mut stream: OwnedReadHalf,
191 pending: Pending,
192 my_handle: PeerHandle,
193 behavior: SharedPeerBehavior
194) -> Result<(), PeerError> {
195 loop {
196 let message = Message::from_stream(&mut stream)
197 .await
198 .map_err(|e| PeerError::MessageDecode(e.to_string()))?;
199
200 if let Some(requester) = pending.lock().await.remove(&message.id) {
201 let _ = requester.send(message);
202 } else {
203 let response = behavior.on_message(message, &my_handle).await?;
204 my_handle.send(response).await?;
205 }
206 }
207}
208
209async fn writer_task(
210 mut stream: OwnedWriteHalf,
211 mut receiver: Receiver<Outgoing>,
212 pending: Pending,
213) -> Result<(), PeerError> {
214 while let Some(outgoing) = receiver.recv().await {
215 match outgoing {
216 Outgoing::Request(msg, responder) => {
217 pending.lock().await.insert(msg.id, responder);
218 msg.send(&mut stream)
219 .await
220 .map_err(|e| PeerError::MessageEncode(e.to_string()))?;
221 }
222 Outgoing::OneWay(msg) => {
223 msg.send(&mut stream)
224 .await
225 .map_err(|e| PeerError::MessageEncode(e.to_string()))?;
226 }
227 }
228 }
229 Err(PeerError::Disconnected)
230}
231
232async fn pinger_task(
233 my_handle: PeerHandle,
234 behavior: SharedPeerBehavior
235) -> Result<(), PeerError> {
236 loop {
237 sleep(PEER_PING_INTERVAL).await;
238 my_handle.request(
239 Message::new(Command::Ping {
240 height: behavior.get_height().await,
241 }),
242 )
243 .await?;
244 }
245}