snarkos_node_tcp/protocols/
reading.rs

1// Copyright (c) 2019-2025 Provable Inc.
2// This file is part of the snarkOS library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#[cfg(doc)]
17use crate::{Config, protocols::Handshake};
18use crate::{
19    ConnectionSide,
20    P2P,
21    Tcp,
22    protocols::{ProtocolHandler, ReturnableConnection},
23};
24
25use async_trait::async_trait;
26use bytes::BytesMut;
27use futures_util::StreamExt;
28use std::{io, net::SocketAddr};
29use tokio::{
30    io::AsyncRead,
31    sync::{mpsc, oneshot},
32};
33use tokio_util::codec::{Decoder, FramedRead};
34use tracing::*;
35
36/// Can be used to specify and enable reading, i.e. receiving inbound messages. If the [`Handshake`]
37/// protocol is enabled too, it goes into force only after the handshake has been concluded.
38///
39/// Each inbound message is isolated by the user-supplied [`Reading::Codec`], creating a [`Reading::Message`],
40/// which is immediately queued (with a [`Reading::MESSAGE_QUEUE_DEPTH`] limit) to be processed by
41/// [`Reading::process_message`]. The configured fatal IO errors result in an immediate disconnect
42/// (in order to e.g. avoid accidentally reading "borked" messages).
43#[async_trait]
44pub trait Reading: P2P
45where
46    Self: Clone + Send + Sync + 'static,
47{
48    /// The depth of per-connection queues used to process inbound messages; the greater it is, the more inbound
49    /// messages the node can enqueue, but setting it to a large value can make the node more susceptible to DoS
50    /// attacks.
51    ///
52    /// The default value is 1024.
53    fn message_queue_depth(&self) -> usize {
54        1024
55    }
56
57    /// The initial size of a per-connection buffer for reading inbound messages. Can be set to the maximum expected size
58    /// of the inbound message in order to only allocate it once.
59    ///
60    /// The default value is 1024KiB.
61    const INITIAL_BUFFER_SIZE: usize = 1024 * 1024;
62
63    /// The final (deserialized) type of inbound messages.
64    type Message: Send;
65
66    /// The user-supplied [`Decoder`] used to interpret inbound messages.
67    type Codec: Decoder<Item = Self::Message, Error = io::Error> + Send;
68
69    /// Prepares the node to receive messages.
70    async fn enable_reading(&self) {
71        let (conn_sender, mut conn_receiver) = mpsc::unbounded_channel();
72
73        // use a channel to know when the reading task is ready
74        let (tx_reading, rx_reading) = oneshot::channel();
75
76        // the main task spawning per-connection tasks reading messages from their streams
77        let self_clone = self.clone();
78        let reading_task = tokio::spawn(async move {
79            trace!(parent: self_clone.tcp().span(), "spawned the Reading handler task");
80            tx_reading.send(()).unwrap(); // safe; the channel was just opened
81
82            // these objects are sent from `Tcp::adapt_stream`
83            while let Some(returnable_conn) = conn_receiver.recv().await {
84                self_clone.handle_new_connection(returnable_conn).await;
85            }
86        });
87        let _ = rx_reading.await;
88        self.tcp().tasks.lock().push(reading_task);
89
90        // register the Reading handler with the Tcp
91        let hdl = Box::new(ProtocolHandler(conn_sender));
92        assert!(self.tcp().protocols.reading.set(hdl).is_ok(), "the Reading protocol was enabled more than once!");
93    }
94
95    /// Creates a [`Decoder`] used to interpret messages from the network.
96    /// The `side` param indicates the connection side **from the node's perspective**.
97    fn codec(&self, addr: SocketAddr, side: ConnectionSide) -> Self::Codec;
98
99    /// Processes an inbound message. Can be used to update state, send replies etc.
100    async fn process_message(&self, source: SocketAddr, message: Self::Message) -> io::Result<()>;
101}
102
103/// This trait is used to restrict access to methods that would otherwise be public in [`Reading`].
104#[async_trait]
105trait ReadingInternal: Reading {
106    /// Applies the [`Reading`] protocol to a single connection.
107    async fn handle_new_connection(&self, (conn, conn_returner): ReturnableConnection);
108
109    /// Wraps the user-supplied [`Decoder`] ([`Reading::Codec`]) in another one used for message accounting.
110    fn map_codec<T: AsyncRead>(
111        &self,
112        framed: FramedRead<T, Self::Codec>,
113        addr: SocketAddr,
114    ) -> FramedRead<T, CountingCodec<Self::Codec>>;
115}
116
117#[async_trait]
118impl<R: Reading> ReadingInternal for R {
119    async fn handle_new_connection(&self, (mut conn, conn_returner): ReturnableConnection) {
120        let addr = conn.addr();
121        let codec = self.codec(addr, !conn.side());
122        let reader = conn.reader.take().expect("missing connection reader!");
123        let framed = FramedRead::new(reader, codec);
124        let mut framed = self.map_codec(framed, addr);
125
126        // the connection will notify the reading task once it's fully ready
127        let (tx_conn_ready, rx_conn_ready) = oneshot::channel();
128        conn.readiness_notifier = Some(tx_conn_ready);
129
130        if Self::INITIAL_BUFFER_SIZE != 0 {
131            framed.read_buffer_mut().reserve(Self::INITIAL_BUFFER_SIZE);
132        }
133
134        let (inbound_message_sender, mut inbound_message_receiver) = mpsc::channel(self.message_queue_depth());
135
136        // use a channel to know when the processing task is ready
137        let (tx_processing, rx_processing) = oneshot::channel::<()>();
138
139        // the task for processing parsed messages
140        let self_clone = self.clone();
141        let inbound_processing_task = tokio::spawn(async move {
142            let node = self_clone.tcp();
143            trace!(parent: node.span(), "spawned a task for processing messages from {addr}");
144            tx_processing.send(()).unwrap(); // safe; the channel was just opened
145
146            while let Some(msg) = inbound_message_receiver.recv().await {
147                if let Err(e) = self_clone.process_message(addr, msg).await {
148                    error!(parent: node.span(), "can't process a message from {addr}: {e}");
149                    node.known_peers().register_failure(addr.ip());
150                }
151                #[cfg(feature = "metrics")]
152                metrics::decrement_gauge(metrics::tcp::TCP_TASKS, 1f64);
153            }
154        });
155        let _ = rx_processing.await;
156        conn.tasks.push(inbound_processing_task);
157
158        // use a channel to know when the reader task is ready
159        let (tx_reader, rx_reader) = oneshot::channel::<()>();
160
161        // the task for reading messages from a stream
162        let node = self.tcp().clone();
163        let reader_task = tokio::spawn(async move {
164            trace!(parent: node.span(), "spawned a task for reading messages from {addr}");
165            tx_reader.send(()).unwrap(); // safe; the channel was just opened
166
167            // postpone reads until the connection is fully established; if the process fails,
168            // this task gets aborted, so there is no need for a dedicated timeout
169            let _ = rx_conn_ready.await;
170
171            while let Some(bytes) = framed.next().await {
172                match bytes {
173                    Ok(msg) => {
174                        // send the message for further processing
175                        if let Err(e) = inbound_message_sender.try_send(msg) {
176                            error!(parent: node.span(), "can't process a message from {addr}: {e}");
177                            node.stats().register_failure();
178                        }
179                        #[cfg(feature = "metrics")]
180                        metrics::increment_gauge(metrics::tcp::TCP_TASKS, 1f64);
181                    }
182                    Err(e) => {
183                        error!(parent: node.span(), "can't read from {addr}: {e}");
184                        node.known_peers().register_failure(addr.ip());
185                        if node.config().fatal_io_errors.contains(&e.kind()) {
186                            break;
187                        }
188                    }
189                }
190            }
191
192            let _ = node.disconnect(addr).await;
193        });
194        let _ = rx_reader.await;
195        conn.tasks.push(reader_task);
196
197        // return the Connection to the Tcp, resuming Tcp::adapt_stream
198        if conn_returner.send(Ok(conn)).is_err() {
199            unreachable!("couldn't return a Connection to the Tcp");
200        }
201    }
202
203    fn map_codec<T: AsyncRead>(
204        &self,
205        framed: FramedRead<T, Self::Codec>,
206        addr: SocketAddr,
207    ) -> FramedRead<T, CountingCodec<Self::Codec>> {
208        framed.map_decoder(|codec| CountingCodec { codec, node: self.tcp().clone(), addr, acc: 0 })
209    }
210}
211
212/// A wrapper [`Decoder`] that also counts the inbound messages.
213struct CountingCodec<D: Decoder> {
214    codec: D,
215    node: Tcp,
216    addr: SocketAddr,
217    acc: usize,
218}
219
220impl<D: Decoder> Decoder for CountingCodec<D> {
221    type Error = D::Error;
222    type Item = D::Item;
223
224    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
225        let initial_buf_len = src.len();
226        let ret = self.codec.decode(src)?;
227        let final_buf_len = src.len();
228        let read_len = initial_buf_len - final_buf_len + self.acc;
229
230        if read_len != 0 {
231            trace!(parent: self.node.span(), "read {}B from {}", read_len, self.addr);
232
233            if ret.is_some() {
234                self.acc = 0;
235                self.node.known_peers().register_received_message(self.addr.ip(), read_len);
236                self.node.stats().register_received_message(read_len);
237            } else {
238                self.acc = read_len;
239            }
240        }
241
242        Ok(ret)
243    }
244}