snarkos_node_tcp/protocols/
writing.rs

1// Copyright (c) 2019-2025 Provable Inc.
2// This file is part of the snarkOS library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use std::{any::Any, collections::HashMap, io, net::SocketAddr, sync::Arc};
17
18use async_trait::async_trait;
19use futures_util::sink::SinkExt;
20#[cfg(feature = "locktick")]
21use locktick::parking_lot::RwLock;
22#[cfg(not(feature = "locktick"))]
23use parking_lot::RwLock;
24use tokio::{
25    io::AsyncWrite,
26    sync::{mpsc, oneshot},
27};
28use tokio_util::codec::{Encoder, FramedWrite};
29use tracing::*;
30
31#[cfg(doc)]
32use crate::{Config, Tcp, protocols::Handshake};
33use crate::{
34    Connection,
35    ConnectionSide,
36    P2P,
37    protocols::{Protocol, ProtocolHandler, ReturnableConnection},
38};
39
/// The per-connection outbound message senders, keyed by the peer's address. The same map is shared
/// by the [`WritingHandler`] (used by [`Writing::unicast`] and [`Writing::broadcast`]), the task that
/// registers new connections, and each connection's [`SenderCleanup`].
type WritingSenders = Arc<RwLock<HashMap<SocketAddr, mpsc::Sender<WrappedMessage>>>>;
41
42/// Can be used to specify and enable writing, i.e. sending outbound messages. If the [`Handshake`]
43/// protocol is enabled too, it goes into force only after the handshake has been concluded.
44#[async_trait]
45pub trait Writing: P2P
46where
47    Self: Clone + Send + Sync + 'static,
48{
49    /// The depth of per-connection queues used to send outbound messages; the greater it is, the more outbound
50    /// messages the node can enqueue. Setting it to a large value is not recommended, as doing it might
51    /// obscure potential issues with your implementation (like slow serialization) or network.
52    ///
53    /// The default value is 1024.
54    const MESSAGE_QUEUE_DEPTH: usize = 1024;
55
56    /// The type of the outbound messages; unless their serialization is expensive and the message
57    /// is broadcasted (in which case it would get serialized multiple times), serialization should
58    /// be done in the implementation of [`Self::Codec`].
59    type Message: Send;
60
61    /// The user-supplied [`Encoder`] used to write outbound messages to the target stream.
62    type Codec: Encoder<Self::Message, Error = io::Error> + Send;
63
64    /// Prepares the node to send messages.
65    async fn enable_writing(&self) {
66        let (conn_sender, mut conn_receiver) = mpsc::unbounded_channel();
67
68        // the conn_senders are used to send messages from the Tcp to individual connections
69        let conn_senders: WritingSenders = Default::default();
70        // procure a clone to create the WritingHandler with
71        let senders = conn_senders.clone();
72
73        // use a channel to know when the writing task is ready
74        let (tx_writing, rx_writing) = oneshot::channel();
75
76        // the task spawning tasks sending messages to all the streams
77        let self_clone = self.clone();
78        let writing_task = tokio::spawn(async move {
79            trace!(parent: self_clone.tcp().span(), "spawned the Writing handler task");
80            tx_writing.send(()).unwrap(); // safe; the channel was just opened
81
82            // these objects are sent from `Tcp::adapt_stream`
83            while let Some(returnable_conn) = conn_receiver.recv().await {
84                self_clone.handle_new_connection(returnable_conn, &conn_senders).await;
85            }
86        });
87        let _ = rx_writing.await;
88        self.tcp().tasks.lock().push(writing_task);
89
90        // register the WritingHandler with the Tcp
91        let hdl = Box::new(WritingHandler { handler: ProtocolHandler(conn_sender), senders });
92        assert!(self.tcp().protocols.writing.set(hdl).is_ok(), "the Writing protocol was enabled more than once!");
93    }
94
95    /// Creates an [`Encoder`] used to write the outbound messages to the target stream.
96    /// The `side` param indicates the connection side **from the node's perspective**.
97    fn codec(&self, addr: SocketAddr, side: ConnectionSide) -> Self::Codec;
98
99    /// Sends the provided message to the specified [`SocketAddr`]. Returns as soon as the message is queued to
100    /// be sent, without waiting for the actual delivery; instead, the caller is provided with a [`oneshot::Receiver`]
101    /// which can be used to determine when and whether the message has been delivered.
102    ///
103    /// # Errors
104    ///
105    /// The following errors can be returned:
106    /// - [`io::ErrorKind::NotConnected`] if the node is not connected to the provided address
107    /// - [`io::ErrorKind::Other`] if the outbound message queue for this address is full
108    /// - [`io::ErrorKind::Unsupported`] if [`Writing::enable_writing`] hadn't been called yet
109    fn unicast(&self, addr: SocketAddr, message: Self::Message) -> io::Result<oneshot::Receiver<io::Result<()>>> {
110        // access the protocol handler
111        if let Some(handler) = self.tcp().protocols.writing.get() {
112            // find the message sender for the given address
113            if let Some(sender) = handler.senders.read().get(&addr).cloned() {
114                let (msg, delivery) = WrappedMessage::new(Box::new(message));
115                sender
116                    .try_send(msg)
117                    .map_err(|e| {
118                        error!(parent: self.tcp().span(), "can't send a message to {}: {}", addr, e);
119                        self.tcp().stats().register_failure();
120                        io::ErrorKind::Other.into()
121                    })
122                    .map(|_| delivery)
123            } else {
124                Err(io::ErrorKind::NotConnected.into())
125            }
126        } else {
127            Err(io::ErrorKind::Unsupported.into())
128        }
129    }
130
131    /// Broadcasts the provided message to all connected peers. Returns as soon as the message is queued to
132    /// be sent to all the peers, without waiting for the actual delivery. This method doesn't provide the
133    /// means to check when and if the messages actually get delivered; you can achieve that by calling
134    /// [`Writing::unicast`] for each address returned by [`Tcp::connected_addrs`].
135    ///
136    /// # Errors
137    ///
138    /// Returns [`io::ErrorKind::Unsupported`] if [`Writing::enable_writing`] hadn't been called yet.
139    fn broadcast(&self, message: Self::Message) -> io::Result<()>
140    where
141        Self::Message: Clone,
142    {
143        // access the protocol handler
144        if let Some(handler) = self.tcp().protocols.writing.get() {
145            let senders = handler.senders.read().clone();
146            for (addr, message_sender) in senders {
147                let (msg, _delivery) = WrappedMessage::new(Box::new(message.clone()));
148                let _ = message_sender.try_send(msg).map_err(|e| {
149                    error!(parent: self.tcp().span(), "can't send a message to {}: {}", addr, e);
150                    self.tcp().stats().register_failure();
151                });
152            }
153
154            Ok(())
155        } else {
156            Err(io::ErrorKind::Unsupported.into())
157        }
158    }
159}
160
/// This trait is used to restrict access to methods that would otherwise be public in [`Writing`].
#[async_trait]
trait WritingInternal: Writing {
    /// Writes the given message to the network stream and returns the number of written bytes
    /// (the length of the encoded message in the writer's buffer before it is flushed).
    async fn write_to_stream<W: AsyncWrite + Unpin + Send>(
        &self,
        message: Self::Message,
        writer: &mut FramedWrite<W, Self::Codec>,
    ) -> Result<usize, <Self::Codec as Encoder<Self::Message>>::Error>;

    /// Applies the [`Writing`] protocol to a single connection.
    ///
    /// Registers an outbound message queue for the connection in `conn_senders`, spawns the task
    /// that writes the queued messages to the connection's stream, and then returns the connection
    /// to the Tcp via `conn_returner`.
    async fn handle_new_connection(&self, (conn, conn_returner): ReturnableConnection, conn_senders: &WritingSenders);
}
174
175#[async_trait]
176impl<W: Writing> WritingInternal for W {
177    async fn write_to_stream<A: AsyncWrite + Unpin + Send>(
178        &self,
179        message: Self::Message,
180        writer: &mut FramedWrite<A, Self::Codec>,
181    ) -> Result<usize, <Self::Codec as Encoder<Self::Message>>::Error> {
182        writer.feed(message).await?;
183        let len = writer.write_buffer().len();
184        writer.flush().await?;
185
186        Ok(len)
187    }
188
189    async fn handle_new_connection(
190        &self,
191        (mut conn, conn_returner): ReturnableConnection,
192        conn_senders: &WritingSenders,
193    ) {
194        let addr = conn.addr();
195        let codec = self.codec(addr, !conn.side());
196        let writer = conn.writer.take().expect("missing connection writer!");
197        let mut framed = FramedWrite::new(writer, codec);
198
199        let (outbound_message_sender, mut outbound_message_receiver) = mpsc::channel(Self::MESSAGE_QUEUE_DEPTH);
200
201        // register the connection's message sender with the Writing protocol handler
202        conn_senders.write().insert(addr, outbound_message_sender);
203
204        // this will automatically drop the sender upon a disconnect
205        let auto_cleanup = SenderCleanup { addr, senders: Arc::clone(conn_senders) };
206
207        // use a channel to know when the writer task is ready
208        let (tx_writer, rx_writer) = oneshot::channel();
209
210        // the task for writing outbound messages
211        let self_clone = self.clone();
212        let writer_task = tokio::spawn(async move {
213            let node = self_clone.tcp();
214            trace!(parent: node.span(), "spawned a task for writing messages to {}", addr);
215            tx_writer.send(()).unwrap(); // safe; the channel was just opened
216
217            // move the cleanup into the task that gets aborted on disconnect
218            let _auto_cleanup = auto_cleanup;
219
220            while let Some(wrapped_msg) = outbound_message_receiver.recv().await {
221                let msg = wrapped_msg.msg.downcast().unwrap();
222
223                match self_clone.write_to_stream(*msg, &mut framed).await {
224                    Ok(len) => {
225                        let _ = wrapped_msg.delivery_notification.send(Ok(()));
226                        node.known_peers().register_sent_message(addr.ip(), len);
227                        node.stats().register_sent_message(len);
228                        trace!(parent: node.span(), "sent {}B to {}", len, addr);
229                    }
230                    Err(e) => {
231                        node.known_peers().register_failure(addr.ip());
232                        error!(parent: node.span(), "couldn't send a message to {}: {}", addr, e);
233                        let is_fatal = node.config().fatal_io_errors.contains(&e.kind());
234                        let _ = wrapped_msg.delivery_notification.send(Err(e));
235                        if is_fatal {
236                            break;
237                        }
238                    }
239                }
240            }
241
242            node.disconnect(addr).await;
243        });
244        let _ = rx_writer.await;
245        conn.tasks.push(writer_task);
246
247        // return the Connection to the Tcp, resuming Tcp::adapt_stream
248        if conn_returner.send(Ok(conn)).is_err() {
249            unreachable!("couldn't return a Connection to the Tcp");
250        }
251    }
252}
253
254/// Used to queue messages for delivery.
255struct WrappedMessage {
256    msg: Box<dyn Any + Send>,
257    delivery_notification: oneshot::Sender<io::Result<()>>,
258}
259
260impl WrappedMessage {
261    fn new(msg: Box<dyn Any + Send>) -> (Self, oneshot::Receiver<io::Result<()>>) {
262        let (tx, rx) = oneshot::channel();
263        let wrapped_msg = Self { msg, delivery_notification: tx };
264
265        (wrapped_msg, rx)
266    }
267}
268
269/// The handler object dedicated to the [`Writing`] protocol.
270pub(crate) struct WritingHandler {
271    handler: ProtocolHandler<Connection, io::Result<Connection>>,
272    senders: WritingSenders,
273}
274
275impl Protocol<Connection, io::Result<Connection>> for WritingHandler {
276    fn trigger(&self, item: ReturnableConnection) {
277        self.handler.trigger(item);
278    }
279}
280
281struct SenderCleanup {
282    addr: SocketAddr,
283    senders: WritingSenders,
284}
285
286impl Drop for SenderCleanup {
287    fn drop(&mut self) {
288        self.senders.write().remove(&self.addr);
289    }
290}