tcp_console/
console.rs

1use crate::ensure_newline;
2use crate::subscription::BoxedSubscription;
3use bytes::Bytes;
4use futures_util::{SinkExt, StreamExt};
5use serde::de::DeserializeOwned;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::fmt::Debug;
9use std::hash::Hash;
10use std::sync::Arc;
11use thiserror::Error;
12use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
13use tokio::sync::Notify;
14use tokio_util::codec::{BytesCodec, Framed};
15use tracing::{debug, warn};
16
17/// A TCP console to process both strongly typed and free form messages.
18/// Free form messages are sent to all known subscriptions in random order until the _first_ success.
19///
20/// This console only allows message from localhost.
21pub struct Console<Services, A> {
22    inner: Arc<Inner<Services>>,
23    bind_address: Option<A>,
24    stop: Arc<Notify>,
25}
26
27struct Inner<Services> {
28    subscriptions: HashMap<Services, BoxedSubscription>,
29    welcome: String,
30    accept_only_localhost: bool,
31}
32
33impl<Services, A> Console<Services, A> {
34    pub(crate) fn new(
35        subscriptions: HashMap<Services, BoxedSubscription>,
36        bind_address: A,
37        welcome: String,
38        accept_only_localhost: bool,
39    ) -> Self {
40        Self {
41            inner: Arc::new(Inner {
42                subscriptions,
43                welcome,
44                accept_only_localhost,
45            }),
46            bind_address: Some(bind_address),
47            stop: Arc::new(Notify::new()),
48        }
49    }
50}
51impl<Services, A> Console<Services, A>
52where
53    Services: DeserializeOwned + Eq + Hash + Debug + Send + Sync + 'static,
54    A: ToSocketAddrs + 'static,
55{
56    /// Spawn the console by opening a TCP socket at the specified address.
57    pub async fn spawn(&mut self) -> Result<(), Error> {
58        let Some(bind_address) = self.bind_address.take() else {
59            warn!("Console has already started");
60            return Err(Error::AlreadyStarted);
61        };
62
63        let listener = TcpListener::bind(bind_address).await?;
64        let inner = self.inner.clone();
65        let stop = self.stop.clone();
66
67        tokio::spawn(async move {
68            debug!(
69                "Listening on {:?}",
70                listener.local_addr().expect("Local address must be known")
71            );
72
73            loop {
74                // Keep accepting console sessions,
75                // verify that they satisfy the requirements,
76                // if so, spawn a task to handle the session.
77
78                let stream = tokio::select! {
79                    _ = stop.notified() => {
80                        debug!("Stopping console");
81                        return;
82                    }
83                    Ok((stream, _)) = listener.accept() => {
84                        stream
85                    }
86                };
87
88                debug!("New console connection.");
89
90                let Ok(addr) = stream.peer_addr() else {
91                    warn!("Could not get peer address. Closing the connection.");
92                    continue;
93                };
94                if inner.accept_only_localhost && !addr.ip().is_loopback() {
95                    warn!("Only connection from the localhost are allowed. Connected peer address {addr}. Closing the connection.");
96                    continue;
97                }
98
99                tokio::spawn(Self::handle_console_session(
100                    stream,
101                    inner.clone(),
102                    stop.clone(),
103                ));
104            }
105        });
106
107        Ok(())
108    }
109
110    /// Stop the console and break all the current connections.
111    pub fn stop(&self) {
112        self.stop.notify_waiters();
113    }
114
115    /// Internal function handling a remote console session.
116    async fn handle_console_session(
117        stream: TcpStream,
118        inner: Arc<Inner<Services>>,
119        stop: Arc<Notify>,
120    ) {
121        let Ok(addr) = stream.peer_addr() else {
122            warn!("Could not get peer address. Closing the session.");
123            return;
124        };
125
126        debug!("Connected to {addr}");
127
128        let mut bytes_stream = Framed::new(stream, BytesCodec::new());
129
130        debug!("Welcoming {addr}");
131        let bytes: Bytes = inner.welcome.as_bytes().to_vec().into();
132        let _ = bytes_stream.send(bytes).await;
133        debug!("Finished welcoming {addr}");
134
135        loop {
136            let bytes = tokio::select! {
137                _ = stop.notified() => {
138                    debug!("Stopping session for {addr}");
139                    return;
140                }
141                result = bytes_stream.next() => match result {
142                    Some(Ok(bytes)) => {
143                        bytes.freeze()
144                    }
145                    Some(Err(err)) => {
146                        warn!("Error while receiving bytes: {err}. Received bytes will not be processed");
147                        continue;
148                    }
149                    None => {
150                        // Connection closed.
151                        debug!("Connection closed by {addr}");
152                        return;
153                    }
154                }
155            };
156
157            match bcs::from_bytes::<Message<Services>>(bytes.as_ref()) {
158                Ok(Message { service_id, bytes }) => {
159                    // Message is strongly typed.
160
161                    debug!("Received message for {service_id:?}");
162
163                    if let Some(subscription) = inner.subscriptions.get(&service_id) {
164                        debug!("Found subscription for service {service_id:?}");
165
166                        match subscription.handle(bytes).await {
167                            Ok(None) => {}
168                            Ok(Some(bytes)) => {
169                                let _ = bytes_stream.send(bytes).await;
170                            }
171                            Err(err) => warn!("Error handling message: {err}"),
172                        }
173                    } else {
174                        warn!("No subscription found for service {service_id:?}. Ignoring the message.");
175                    }
176                }
177                Err(_err) => {
178                    // Message is not strongly typed and probably came from netcat or a similar client.
179                    // Try all subscriptions to make sense of it until the FIRST success.
180
181                    let text = String::from_utf8_lossy(bytes.as_ref()).trim().to_string();
182                    debug!("Received message is not typed. Treating it as text: {text}");
183
184                    for (service_id, subscription) in &inner.subscriptions {
185                        debug!("[{service_id:?}] request to process text message: `{text}`");
186
187                        match subscription.weak_handle(&text).await {
188                            Ok(None) => {
189                                continue;
190                            }
191                            Ok(Some(message)) => {
192                                debug!("[{service_id:?}] Message processed");
193                                let vec: Bytes = ensure_newline(message).as_bytes().to_vec().into();
194                                let _ = bytes_stream.send(vec).await;
195                                break;
196                            }
197                            Err(err) => {
198                                warn!("Service {service_id:?} failed to handle message: {err}");
199                                continue;
200                            }
201                        }
202                    }
203                }
204            }
205        }
206    }
207}
208
209/// A wrapper struct to pass strongly-typed messages on [Console].
210#[derive(Serialize, Deserialize)]
211pub(crate) struct Message<Services> {
212    service_id: Services,
213    bytes: Bytes,
214}
215
216impl<Services> Message<Services> {
217    /// Creates a new [Message] with any serializable payload.
218    pub(crate) fn new(service_id: Services, message: &impl Serialize) -> Result<Self, Error> {
219        Ok(Self {
220            service_id,
221            bytes: Bytes::from(bcs::to_bytes(message)?),
222        })
223    }
224}
225
226#[derive(Debug, Error)]
227pub enum Error {
228    #[error("Subscription cannot be registered: service id `{0}` is already in use")]
229    ServiceIdUsed(String),
230    #[error("Console bind address is not specified")]
231    NoBindAddress,
232    #[error("Console had already started")]
233    AlreadyStarted,
234    #[error("IO error: {0}")]
235    Io(#[from] std::io::Error),
236    #[error("Serde error: {0}")]
237    Serde(#[from] bcs::Error),
238}