Skip to main content

snarkos_node_tcp/protocols/
writing.rs

1// Copyright (c) 2019-2026 Provable Inc.
2// This file is part of the snarkOS library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use std::{any::Any, collections::HashMap, io, net::SocketAddr, sync::Arc, time::Duration};
17
18use async_trait::async_trait;
19use futures_util::sink::SinkExt;
20#[cfg(feature = "locktick")]
21use locktick::parking_lot::RwLock;
22#[cfg(not(feature = "locktick"))]
23use parking_lot::RwLock;
24use tokio::{
25    io::AsyncWrite,
26    sync::{mpsc, oneshot},
27    time::timeout,
28};
29use tokio_util::codec::{Encoder, FramedWrite};
30use tracing::*;
31
32#[cfg(doc)]
33use crate::{Config, Tcp, protocols::Handshake};
34use crate::{
35    Connection,
36    ConnectionSide,
37    P2P,
38    connections::create_connection_span,
39    protocols::{Protocol, ProtocolHandler, ReturnableConnection},
40};
41
42type WritingSenders = Arc<RwLock<HashMap<SocketAddr, mpsc::Sender<WrappedMessage>>>>;
43
44/// Can be used to specify and enable writing, i.e. sending outbound messages. If the [`Handshake`]
45/// protocol is enabled too, it goes into force only after the handshake has been concluded.
46#[async_trait]
47pub trait Writing: P2P
48where
49    Self: Clone + Send + Sync + 'static,
50{
51    /// The depth of per-connection queues used to send outbound messages; the greater it is, the more outbound
52    /// messages the node can enqueue. Setting it to a large value is not recommended, as doing it might
53    /// obscure potential issues with your implementation (like slow serialization) or network.
54    ///
55    /// The default value is 1024.
56    fn message_queue_depth(&self) -> usize {
57        1024
58    }
59
60    /// The maximum time (in milliseconds) allowed for a single message write to flush
61    /// to the underlying stream before the connection is considered dead.
62    const TIMEOUT: Duration = Duration::from_secs(5);
63
64    /// The type of the outbound messages; unless their serialization is expensive and the message
65    /// is broadcasted (in which case it would get serialized multiple times), serialization should
66    /// be done in the implementation of [`Self::Codec`].
67    type Message: Send;
68
69    /// The user-supplied [`Encoder`] used to write outbound messages to the target stream.
70    type Codec: Encoder<Self::Message, Error = io::Error> + Send;
71
72    /// Prepares the node to send messages.
73    async fn enable_writing(&self) {
74        let (conn_sender, mut conn_receiver) = mpsc::channel(self.tcp().config().max_connections as usize);
75
76        // the conn_senders are used to send messages from the Tcp to individual connections
77        let conn_senders: WritingSenders = Default::default();
78        // procure a clone to create the WritingHandler with
79        let senders = conn_senders.clone();
80
81        // use a channel to know when the writing task is ready
82        let (tx_writing, rx_writing) = oneshot::channel();
83
84        // the task spawning tasks sending messages to all the streams
85        let self_clone = self.clone();
86        let writing_task = tokio::spawn(async move {
87            trace!(parent: self_clone.tcp().span(), "spawned the Writing handler task");
88            tx_writing.send(()).unwrap(); // safe; the channel was just opened
89
90            // these objects are sent from `Tcp::adapt_stream`
91            while let Some(returnable_conn) = conn_receiver.recv().await {
92                self_clone.handle_new_connection(returnable_conn, &conn_senders).await;
93            }
94        });
95        let _ = rx_writing.await;
96        self.tcp().tasks.lock().push(writing_task);
97
98        // register the WritingHandler with the Tcp
99        let hdl = Box::new(WritingHandler { handler: ProtocolHandler(conn_sender), senders });
100        assert!(self.tcp().protocols.writing.set(hdl).is_ok(), "the Writing protocol was enabled more than once!");
101    }
102
103    /// Creates an [`Encoder`] used to write the outbound messages to the target stream.
104    /// The `side` param indicates the connection side **from the node's perspective**.
105    fn codec(&self, addr: SocketAddr, side: ConnectionSide) -> Self::Codec;
106
107    /// Sends the provided message to the specified [`SocketAddr`]. Returns as soon as the message is queued to
108    /// be sent, without waiting for the actual delivery; instead, the caller is provided with a [`oneshot::Receiver`]
109    /// which can be used to determine when and whether the message has been delivered.
110    ///
111    /// # Errors
112    ///
113    /// The following errors can be returned:
114    /// - [`io::ErrorKind::NotConnected`] if the node is not connected to the provided address
115    /// - [`io::ErrorKind::Other`] if the outbound message queue for this address is full
116    /// - [`io::ErrorKind::Unsupported`] if [`Writing::enable_writing`] hadn't been called yet
117    fn unicast(&self, addr: SocketAddr, message: Self::Message) -> io::Result<oneshot::Receiver<io::Result<()>>> {
118        // access the protocol handler
119        if let Some(handler) = self.tcp().protocols.writing.get() {
120            // find the message sender for the given address
121            if let Some(sender) = handler.senders.read().get(&addr).cloned() {
122                let (msg, delivery) = WrappedMessage::new(Box::new(message));
123                sender
124                    .try_send(msg)
125                    .map_err(|e| {
126                        let conn_span = create_connection_span(addr, self.tcp().span());
127                        error!(parent: conn_span, "can't send a message: {e}");
128                        self.tcp().stats().register_failure();
129                        io::ErrorKind::Other.into()
130                    })
131                    .map(|_| delivery)
132            } else {
133                Err(io::ErrorKind::NotConnected.into())
134            }
135        } else {
136            Err(io::ErrorKind::Unsupported.into())
137        }
138    }
139
140    /// Broadcasts the provided message to all connected peers. Returns as soon as the message is queued to
141    /// be sent to all the peers, without waiting for the actual delivery. This method doesn't provide the
142    /// means to check when and if the messages actually get delivered; you can achieve that by calling
143    /// [`Writing::unicast`] for each address returned by [`Tcp::connected_addrs`].
144    ///
145    /// # Errors
146    ///
147    /// Returns [`io::ErrorKind::Unsupported`] if [`Writing::enable_writing`] hadn't been called yet.
148    fn broadcast(&self, message: Self::Message) -> io::Result<()>
149    where
150        Self::Message: Clone,
151    {
152        // access the protocol handler
153        if let Some(handler) = self.tcp().protocols.writing.get() {
154            let senders = handler.senders.read().clone();
155            for (addr, message_sender) in senders {
156                let (msg, _delivery) = WrappedMessage::new(Box::new(message.clone()));
157                let _ = message_sender.try_send(msg).map_err(|e| {
158                    let conn_span = create_connection_span(addr, self.tcp().span());
159                    error!(parent: conn_span, "can't send a message: {e}");
160                    self.tcp().stats().register_failure();
161                });
162            }
163
164            Ok(())
165        } else {
166            Err(io::ErrorKind::Unsupported.into())
167        }
168    }
169}
170
171/// This trait is used to restrict access to methods that would otherwise be public in [`Writing`].
172#[async_trait]
173trait WritingInternal: Writing {
174    /// Writes the given message to the network stream and returns the number of written bytes.
175    async fn write_to_stream<W: AsyncWrite + Unpin + Send>(
176        &self,
177        message: Self::Message,
178        writer: &mut FramedWrite<W, Self::Codec>,
179    ) -> Result<usize, <Self::Codec as Encoder<Self::Message>>::Error>;
180
181    /// Applies the [`Writing`] protocol to a single connection.
182    async fn handle_new_connection(&self, (conn, conn_returner): ReturnableConnection, conn_senders: &WritingSenders);
183}
184
185#[async_trait]
186impl<W: Writing> WritingInternal for W {
187    async fn write_to_stream<A: AsyncWrite + Unpin + Send>(
188        &self,
189        message: Self::Message,
190        writer: &mut FramedWrite<A, Self::Codec>,
191    ) -> Result<usize, <Self::Codec as Encoder<Self::Message>>::Error> {
192        writer.feed(message).await?;
193        let len = writer.write_buffer().len();
194        // Guard against write starvation
195        match timeout(W::TIMEOUT, writer.flush()).await {
196            Ok(Ok(())) => Ok(len),
197            Ok(Err(e)) => Err(e),
198            Err(_) => Err(io::Error::new(io::ErrorKind::TimedOut, "write timed out")),
199        }
200    }
201
202    async fn handle_new_connection(
203        &self,
204        (mut conn, conn_returner): ReturnableConnection,
205        conn_senders: &WritingSenders,
206    ) {
207        let addr = conn.addr();
208        let codec = self.codec(addr, !conn.side());
209        let writer = conn.writer.take().expect("missing connection writer!");
210        let mut framed = FramedWrite::new(writer, codec);
211
212        let (outbound_message_sender, mut outbound_message_receiver) = mpsc::channel(self.message_queue_depth());
213
214        // register the connection's message sender with the Writing protocol handler
215        conn_senders.write().insert(addr, outbound_message_sender);
216
217        // this will automatically drop the sender upon a disconnect
218        let auto_cleanup = SenderCleanup { addr, senders: Arc::clone(conn_senders) };
219
220        // use a channel to know when the writer task is ready
221        let (tx_writer, rx_writer) = oneshot::channel();
222
223        // the task for writing outbound messages
224        let self_clone = self.clone();
225        let conn_span = conn.span().clone();
226        let writer_task = tokio::spawn(Box::pin(async move {
227            let node = self_clone.tcp();
228            trace!(parent: &conn_span, "spawned a task for writing messages");
229            tx_writer.send(()).unwrap(); // safe; the channel was just opened
230
231            // move the cleanup into the task that gets aborted on disconnect
232            let _auto_cleanup = auto_cleanup;
233
234            while let Some(wrapped_msg) = outbound_message_receiver.recv().await {
235                let msg = wrapped_msg.msg.downcast().unwrap();
236
237                match self_clone.write_to_stream(*msg, &mut framed).await {
238                    Ok(len) => {
239                        let _ = wrapped_msg.delivery_notification.send(Ok(()));
240                        // node.known_peers().register_sent_message(addr.ip(), len);
241                        node.stats().register_sent_message(len);
242                        trace!(parent: &conn_span, "sent {len}B");
243                    }
244                    Err(e) => {
245                        node.known_peers().register_failure(addr.ip());
246                        error!(parent: &conn_span, "couldn't send a message: {e}");
247                        let is_fatal = node.config().fatal_io_errors.contains(&e.kind());
248                        let _ = wrapped_msg.delivery_notification.send(Err(e));
249                        if is_fatal {
250                            break;
251                        }
252                    }
253                }
254            }
255
256            node.disconnect(addr).await;
257        }));
258        let _ = rx_writer.await;
259        conn.tasks.push(writer_task);
260
261        // return the Connection to the Tcp, resuming Tcp::adapt_stream
262        if conn_returner.send(Ok(conn)).is_err() {
263            unreachable!("couldn't return a Connection to the Tcp");
264        }
265    }
266}
267
268/// Used to queue messages for delivery.
269struct WrappedMessage {
270    msg: Box<dyn Any + Send>,
271    delivery_notification: oneshot::Sender<io::Result<()>>,
272}
273
274impl WrappedMessage {
275    fn new(msg: Box<dyn Any + Send>) -> (Self, oneshot::Receiver<io::Result<()>>) {
276        let (tx, rx) = oneshot::channel();
277        let wrapped_msg = Self { msg, delivery_notification: tx };
278
279        (wrapped_msg, rx)
280    }
281}
282
283/// The handler object dedicated to the [`Writing`] protocol.
284pub(crate) struct WritingHandler {
285    handler: ProtocolHandler<Connection, io::Result<Connection>>,
286    senders: WritingSenders,
287}
288
289impl Protocol<Connection, io::Result<Connection>> for WritingHandler {
290    async fn trigger(&self, item: ReturnableConnection) {
291        self.handler.trigger(item).await;
292    }
293}
294
295struct SenderCleanup {
296    addr: SocketAddr,
297    senders: WritingSenders,
298}
299
300impl Drop for SenderCleanup {
301    fn drop(&mut self) {
302        self.senders.write().remove(&self.addr);
303    }
304}