Skip to main content

snarkos_node_tcp/protocols/
reading.rs

1// Copyright (c) 2019-2026 Provable Inc.
2// This file is part of the snarkOS library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#[cfg(doc)]
17use crate::{Config, protocols::Handshake};
18use crate::{
19    Connection,
20    ConnectionSide,
21    P2P,
22    Tcp,
23    protocols::{ProtocolHandler, ReturnableConnection},
24};
25
26use async_trait::async_trait;
27use bytes::BytesMut;
28use futures_util::StreamExt;
29use std::{
30    io,
31    net::SocketAddr,
32    time::{Duration, Instant},
33};
34use tokio::{
35    io::AsyncRead,
36    sync::{mpsc, oneshot},
37    time::timeout,
38};
39use tokio_util::codec::{Decoder, FramedRead};
40use tracing::*;
41
42/// Can be used to specify and enable reading, i.e. receiving inbound messages. If the [`Handshake`]
43/// protocol is enabled too, it goes into force only after the handshake has been concluded.
44///
45/// Each inbound message is isolated by the user-supplied [`Reading::Codec`], creating a [`Reading::Message`],
46/// which is immediately queued (with a [`Reading::MESSAGE_QUEUE_DEPTH`] limit) to be processed by
47/// [`Reading::process_message`]. The configured fatal IO errors result in an immediate disconnect
48/// (in order to e.g. avoid accidentally reading "borked" messages).
49#[async_trait]
50pub trait Reading: P2P
51where
52    Self: Clone + Send + Sync + 'static,
53{
54    /// The depth of per-connection queues used to process inbound messages; the greater it is, the more inbound
55    /// messages the node can enqueue, but setting it to a large value can make the node more susceptible to DoS
56    /// attacks.
57    ///
58    /// The default value is 1024.
59    fn message_queue_depth(&self) -> usize {
60        1024
61    }
62
63    /// The initial size of a per-connection buffer for reading inbound messages. Can be set to the maximum expected size
64    /// of the inbound message in order to only allocate it once.
65    ///
66    /// The default value is 1024KiB.
67    const INITIAL_BUFFER_SIZE: usize = 1024 * 1024;
68
69    /// The maximum time the node will wait for a new message before considering the connection dead.
70    const IDLE_TIMEOUT: Duration = Duration::from_secs(150);
71
72    /// The final (deserialized) type of inbound messages.
73    type Message: Send;
74
75    /// The user-supplied [`Decoder`] used to interpret inbound messages.
76    type Codec: Decoder<Item = Self::Message, Error = io::Error> + Send;
77
78    /// Prepares the node to receive messages.
79    async fn enable_reading(&self) {
80        let (conn_sender, mut conn_receiver) = mpsc::channel(self.tcp().config().max_connections as usize);
81
82        // use a channel to know when the reading task is ready
83        let (tx_reading, rx_reading) = oneshot::channel();
84
85        // the main task spawning per-connection tasks reading messages from their streams
86        let self_clone = self.clone();
87        let reading_task = tokio::spawn(async move {
88            trace!(parent: self_clone.tcp().span(), "spawned the Reading handler task");
89            tx_reading.send(()).unwrap(); // safe; the channel was just opened
90
91            // these objects are sent from `Tcp::adapt_stream`
92            while let Some(returnable_conn) = conn_receiver.recv().await {
93                self_clone.handle_new_connection(returnable_conn).await;
94            }
95        });
96        let _ = rx_reading.await;
97        self.tcp().tasks.lock().push(reading_task);
98
99        // register the Reading handler with the Tcp
100        let hdl = Box::new(ProtocolHandler(conn_sender));
101        assert!(self.tcp().protocols.reading.set(hdl).is_ok(), "the Reading protocol was enabled more than once!");
102    }
103
104    /// Creates a [`Decoder`] used to interpret messages from the network.
105    /// The `side` param indicates the connection side **from the node's perspective**.
106    fn codec(&self, addr: SocketAddr, side: ConnectionSide) -> Self::Codec;
107
108    /// Processes an inbound message. Can be used to update state, send replies etc.
109    async fn process_message(&self, source: SocketAddr, message: Self::Message) -> io::Result<()>;
110}
111
112/// This trait is used to restrict access to methods that would otherwise be public in [`Reading`].
113#[async_trait]
114trait ReadingInternal: Reading {
115    /// Applies the [`Reading`] protocol to a single connection.
116    async fn handle_new_connection(&self, (conn, conn_returner): ReturnableConnection);
117
118    /// Wraps the user-supplied [`Decoder`] ([`Reading::Codec`]) in another one used for message accounting.
119    fn map_codec<T: AsyncRead>(
120        &self,
121        framed: FramedRead<T, Self::Codec>,
122        conn: &Connection,
123    ) -> FramedRead<T, CountingCodec<Self::Codec>>;
124}
125
126#[async_trait]
127impl<R: Reading> ReadingInternal for R {
128    async fn handle_new_connection(&self, (mut conn, conn_returner): ReturnableConnection) {
129        let addr = conn.addr();
130        let codec = self.codec(addr, !conn.side());
131        let reader = conn.reader.take().expect("missing connection reader!");
132        let framed = FramedRead::new(reader, codec);
133        let mut framed = self.map_codec(framed, &conn);
134
135        // the connection will notify the reading task once it's fully ready
136        let (tx_conn_ready, rx_conn_ready) = oneshot::channel();
137        conn.readiness_notifier = Some(tx_conn_ready);
138
139        if Self::INITIAL_BUFFER_SIZE != 0 {
140            framed.read_buffer_mut().reserve(Self::INITIAL_BUFFER_SIZE);
141        }
142
143        let (inbound_message_sender, mut inbound_message_receiver) =
144            mpsc::channel::<(R::Message, QueuedMessageGuard)>(self.message_queue_depth());
145
146        // use a channel to know when the processing task is ready
147        let (tx_processing, rx_processing) = oneshot::channel::<()>();
148
149        // the task for processing parsed messages
150        let self_clone = self.clone();
151        let conn_span = conn.span().clone();
152        let inbound_processing_task = tokio::spawn(Box::pin(async move {
153            let node = self_clone.tcp();
154            trace!(parent: &conn_span, "spawned a task for processing messages");
155            tx_processing.send(()).unwrap(); // safe; the channel was just opened
156
157            while let Some((msg, _guard)) = inbound_message_receiver.recv().await {
158                if let Err(e) = self_clone.process_message(addr, msg).await {
159                    error!(parent: &conn_span, "can't process a message: {e}");
160                    node.known_peers().register_failure(addr.ip());
161                }
162                // _guard drops here, after process_message completes
163            }
164        }));
165        let _ = rx_processing.await;
166        conn.tasks.push(inbound_processing_task);
167
168        // use a channel to know when the reader task is ready
169        let (tx_reader, rx_reader) = oneshot::channel::<()>();
170
171        // the task for reading messages from a stream
172        let node = self.tcp().clone();
173        let conn_span = conn.span().clone();
174        let reader_task = tokio::spawn(Box::pin(async move {
175            trace!(parent: &conn_span, "spawned a task for reading messages");
176            tx_reader.send(()).unwrap(); // safe; the channel was just opened
177
178            // postpone reads until the connection is fully established; if the process fails,
179            // this task gets aborted, so there is no need for a dedicated timeout
180            let _ = rx_conn_ready.await;
181
182            // dropped message log suppression helpers
183            let mut dropped_count: usize = 0;
184            let mut last_drop_log = Instant::now();
185
186            loop {
187                let next_frame_future = framed.next();
188                let read_result = match timeout(Self::IDLE_TIMEOUT, next_frame_future).await {
189                    Ok(res) => res, // IO completed (success or error)
190                    Err(_) => {
191                        debug!(parent: &conn_span, "connection timed out due to inactivity");
192                        break;
193                    }
194                };
195                match read_result {
196                    Some(Ok(msg)) => {
197                        // send the message for further processing
198                        if let Err(e) = inbound_message_sender.try_send((msg, QueuedMessageGuard::new())) {
199                            node.stats().register_failure();
200                            match e {
201                                mpsc::error::TrySendError::Full(_) => {
202                                    // avoid log flooding
203                                    dropped_count += 1;
204                                    if last_drop_log.elapsed() >= Duration::from_secs(1) {
205                                        warn_about_dropped_messages(&conn_span, &mut dropped_count, &mut last_drop_log);
206                                    }
207                                }
208                                mpsc::error::TrySendError::Closed(_) => {
209                                    error!(parent: &conn_span, "inbound channel closed");
210                                    break;
211                                }
212                            }
213                        } else if dropped_count != 0 {
214                            warn_about_dropped_messages(&conn_span, &mut dropped_count, &mut last_drop_log);
215                            debug!(parent: &conn_span, "the inbound queue is no longer saturated");
216                        }
217                        #[cfg(feature = "metrics")]
218                        metrics::increment_gauge(metrics::tcp::TCP_TASKS, 1f64);
219                    }
220                    Some(Err(e)) => {
221                        error!(parent: &conn_span, "can't read: {e}");
222                        node.known_peers().register_failure(addr.ip());
223                        if node.config().fatal_io_errors.contains(&e.kind()) {
224                            break;
225                        }
226                    }
227                    None => break, // end of stream
228                }
229            }
230
231            let _ = node.disconnect(addr).await;
232        }));
233        let _ = rx_reader.await;
234        conn.tasks.push(reader_task);
235
236        // return the Connection to the Tcp, resuming Tcp::adapt_stream
237        if conn_returner.send(Ok(conn)).is_err() {
238            unreachable!("couldn't return a Connection to the Tcp");
239        }
240    }
241
242    fn map_codec<T: AsyncRead>(
243        &self,
244        framed: FramedRead<T, Self::Codec>,
245        conn: &Connection,
246    ) -> FramedRead<T, CountingCodec<Self::Codec>> {
247        framed.map_decoder(|codec| CountingCodec { codec, node: self.tcp().clone(), acc: 0, span: conn.span().clone() })
248    }
249}
250
251/// A wrapper [`Decoder`] that also counts the inbound messages.
252struct CountingCodec<D: Decoder> {
253    codec: D,
254    node: Tcp,
255    acc: usize,
256    span: Span,
257}
258
259impl<D: Decoder> Decoder for CountingCodec<D> {
260    type Error = D::Error;
261    type Item = D::Item;
262
263    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
264        let initial_buf_len = src.len();
265        let ret = self.codec.decode(src)?;
266        let final_buf_len = src.len();
267        let read_len = initial_buf_len - final_buf_len + self.acc;
268
269        if read_len != 0 {
270            trace!(parent: &self.span, "read {read_len}B");
271
272            if ret.is_some() {
273                self.acc = 0;
274                // self.node.known_peers().register_received_message(self.addr.ip(), read_len);
275                self.node.stats().register_received_message(read_len);
276            } else {
277                self.acc = read_len;
278            }
279        }
280
281        Ok(ret)
282    }
283}
284
/// RAII guard that ties the `TCP_TASKS` gauge to one queued inbound message; it only has an effect
/// when the `metrics` feature is enabled.
///
/// Creating the guard increments the gauge and dropping it decrements it. This keeps the gauge balanced
/// whether the message is processed normally or discarded together with the inbound channel
/// (e.g. on connection abort). Keep the guard alive until processing is complete; dropping it
/// earlier decrements the gauge prematurely.
struct QueuedMessageGuard;

impl QueuedMessageGuard {
    /// Bumps the `TCP_TASKS` gauge and returns the guard that undoes the bump when dropped.
    fn new() -> Self {
        #[cfg(feature = "metrics")]
        {
            metrics::increment_gauge(metrics::tcp::TCP_TASKS, 1f64);
        }
        QueuedMessageGuard
    }
}

impl Drop for QueuedMessageGuard {
    /// Reverts the increment performed by [`QueuedMessageGuard::new`].
    fn drop(&mut self) {
        #[cfg(feature = "metrics")]
        {
            metrics::decrement_gauge(metrics::tcp::TCP_TASKS, 1f64);
        }
    }
}
305
306/// Warns that some messages were dropped and resets the related counters.
307fn warn_about_dropped_messages(span: &Span, dropped_count: &mut usize, last_drop_log: &mut Instant) {
308    warn!(
309        parent: span,
310        "dropped {dropped_count} messages due\
311        to inbound queue saturation",
312    );
313    // reset counters
314    *dropped_count = 0;
315    *last_drop_log = Instant::now();
316}