snarkos_node_tcp/protocols/
reading.rs1#[cfg(doc)]
17use crate::{Config, protocols::Handshake};
18use crate::{
19 ConnectionSide,
20 P2P,
21 Tcp,
22 protocols::{ProtocolHandler, ReturnableConnection},
23};
24
25use async_trait::async_trait;
26use bytes::BytesMut;
27use futures_util::StreamExt;
28use std::{io, net::SocketAddr};
29use tokio::{
30 io::AsyncRead,
31 sync::{mpsc, oneshot},
32};
33use tokio_util::codec::{Decoder, FramedRead};
34use tracing::*;
35
36#[async_trait]
44pub trait Reading: P2P
45where
46 Self: Clone + Send + Sync + 'static,
47{
48 fn message_queue_depth(&self) -> usize {
54 1024
55 }
56
57 const INITIAL_BUFFER_SIZE: usize = 1024 * 1024;
62
63 type Message: Send;
65
66 type Codec: Decoder<Item = Self::Message, Error = io::Error> + Send;
68
69 async fn enable_reading(&self) {
71 let (conn_sender, mut conn_receiver) = mpsc::unbounded_channel();
72
73 let (tx_reading, rx_reading) = oneshot::channel();
75
76 let self_clone = self.clone();
78 let reading_task = tokio::spawn(async move {
79 trace!(parent: self_clone.tcp().span(), "spawned the Reading handler task");
80 tx_reading.send(()).unwrap(); while let Some(returnable_conn) = conn_receiver.recv().await {
84 self_clone.handle_new_connection(returnable_conn).await;
85 }
86 });
87 let _ = rx_reading.await;
88 self.tcp().tasks.lock().push(reading_task);
89
90 let hdl = Box::new(ProtocolHandler(conn_sender));
92 assert!(self.tcp().protocols.reading.set(hdl).is_ok(), "the Reading protocol was enabled more than once!");
93 }
94
95 fn codec(&self, addr: SocketAddr, side: ConnectionSide) -> Self::Codec;
98
99 async fn process_message(&self, source: SocketAddr, message: Self::Message) -> io::Result<()>;
101}
102
103#[async_trait]
105trait ReadingInternal: Reading {
106 async fn handle_new_connection(&self, (conn, conn_returner): ReturnableConnection);
108
109 fn map_codec<T: AsyncRead>(
111 &self,
112 framed: FramedRead<T, Self::Codec>,
113 addr: SocketAddr,
114 ) -> FramedRead<T, CountingCodec<Self::Codec>>;
115}
116
117#[async_trait]
118impl<R: Reading> ReadingInternal for R {
119 async fn handle_new_connection(&self, (mut conn, conn_returner): ReturnableConnection) {
120 let addr = conn.addr();
121 let codec = self.codec(addr, !conn.side());
122 let reader = conn.reader.take().expect("missing connection reader!");
123 let framed = FramedRead::new(reader, codec);
124 let mut framed = self.map_codec(framed, addr);
125
126 let (tx_conn_ready, rx_conn_ready) = oneshot::channel();
128 conn.readiness_notifier = Some(tx_conn_ready);
129
130 if Self::INITIAL_BUFFER_SIZE != 0 {
131 framed.read_buffer_mut().reserve(Self::INITIAL_BUFFER_SIZE);
132 }
133
134 let (inbound_message_sender, mut inbound_message_receiver) = mpsc::channel(self.message_queue_depth());
135
136 let (tx_processing, rx_processing) = oneshot::channel::<()>();
138
139 let self_clone = self.clone();
141 let inbound_processing_task = tokio::spawn(Box::pin(async move {
142 let node = self_clone.tcp();
143 trace!(parent: node.span(), "spawned a task for processing messages from {addr}");
144 tx_processing.send(()).unwrap(); while let Some(msg) = inbound_message_receiver.recv().await {
147 if let Err(e) = self_clone.process_message(addr, msg).await {
148 error!(parent: node.span(), "can't process a message from {addr}: {e}");
149 node.known_peers().register_failure(addr.ip());
150 }
151 #[cfg(feature = "metrics")]
152 metrics::decrement_gauge(metrics::tcp::TCP_TASKS, 1f64);
153 }
154 }));
155 let _ = rx_processing.await;
156 conn.tasks.push(inbound_processing_task);
157
158 let (tx_reader, rx_reader) = oneshot::channel::<()>();
160
161 let node = self.tcp().clone();
163 let reader_task = tokio::spawn(Box::pin(async move {
164 trace!(parent: node.span(), "spawned a task for reading messages from {addr}");
165 tx_reader.send(()).unwrap(); let _ = rx_conn_ready.await;
170
171 while let Some(bytes) = framed.next().await {
172 match bytes {
173 Ok(msg) => {
174 if let Err(e) = inbound_message_sender.try_send(msg) {
176 error!(parent: node.span(), "can't process a message from {addr}: {e}");
177 node.stats().register_failure();
178 if matches!(e, mpsc::error::TrySendError::Closed(_)) {
179 break;
180 }
181 }
182 #[cfg(feature = "metrics")]
183 metrics::increment_gauge(metrics::tcp::TCP_TASKS, 1f64);
184 }
185 Err(e) => {
186 error!(parent: node.span(), "can't read from {addr}: {e}");
187 node.known_peers().register_failure(addr.ip());
188 if node.config().fatal_io_errors.contains(&e.kind()) {
189 break;
190 }
191 }
192 }
193 }
194
195 let _ = node.disconnect(addr).await;
196 }));
197 let _ = rx_reader.await;
198 conn.tasks.push(reader_task);
199
200 if conn_returner.send(Ok(conn)).is_err() {
202 unreachable!("couldn't return a Connection to the Tcp");
203 }
204 }
205
206 fn map_codec<T: AsyncRead>(
207 &self,
208 framed: FramedRead<T, Self::Codec>,
209 addr: SocketAddr,
210 ) -> FramedRead<T, CountingCodec<Self::Codec>> {
211 framed.map_decoder(|codec| CountingCodec { codec, node: self.tcp().clone(), addr, acc: 0 })
212 }
213}
214
215struct CountingCodec<D: Decoder> {
217 codec: D,
218 node: Tcp,
219 addr: SocketAddr,
220 acc: usize,
221}
222
223impl<D: Decoder> Decoder for CountingCodec<D> {
224 type Error = D::Error;
225 type Item = D::Item;
226
227 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
228 let initial_buf_len = src.len();
229 let ret = self.codec.decode(src)?;
230 let final_buf_len = src.len();
231 let read_len = initial_buf_len - final_buf_len + self.acc;
232
233 if read_len != 0 {
234 trace!(parent: self.node.span(), "read {}B from {}", read_len, self.addr);
235
236 if ret.is_some() {
237 self.acc = 0;
238 self.node.known_peers().register_received_message(self.addr.ip(), read_len);
239 self.node.stats().register_received_message(read_len);
240 } else {
241 self.acc = read_len;
242 }
243 }
244
245 Ok(ret)
246 }
247}