1use crate::ensure_newline;
2use crate::subscription::BoxedSubscription;
3use bytes::Bytes;
4use futures_util::{SinkExt, StreamExt};
5use serde::de::DeserializeOwned;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::fmt::Debug;
9use std::hash::Hash;
10use std::sync::Arc;
11use thiserror::Error;
12use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
13use tokio::sync::Notify;
14use tokio_util::codec::{BytesCodec, Framed};
15use tracing::{debug, warn};
16
17pub struct Console<Services, A> {
22 inner: Arc<Inner<Services>>,
23 bind_address: Option<A>,
24 stop: Arc<Notify>,
25}
26
27struct Inner<Services> {
28 subscriptions: HashMap<Services, BoxedSubscription>,
29 welcome: String,
30 accept_only_localhost: bool,
31}
32
33impl<Services, A> Console<Services, A> {
34 pub(crate) fn new(
35 subscriptions: HashMap<Services, BoxedSubscription>,
36 bind_address: A,
37 welcome: String,
38 accept_only_localhost: bool,
39 ) -> Self {
40 Self {
41 inner: Arc::new(Inner {
42 subscriptions,
43 welcome,
44 accept_only_localhost,
45 }),
46 bind_address: Some(bind_address),
47 stop: Arc::new(Notify::new()),
48 }
49 }
50}
51impl<Services, A> Console<Services, A>
52where
53 Services: DeserializeOwned + Eq + Hash + Debug + Send + Sync + 'static,
54 A: ToSocketAddrs + 'static,
55{
56 pub async fn spawn(&mut self) -> Result<(), Error> {
58 let Some(bind_address) = self.bind_address.take() else {
59 warn!("Console has already started");
60 return Err(Error::AlreadyStarted);
61 };
62
63 let listener = TcpListener::bind(bind_address).await?;
64 let inner = self.inner.clone();
65 let stop = self.stop.clone();
66
67 tokio::spawn(async move {
68 debug!(
69 "Listening on {:?}",
70 listener.local_addr().expect("Local address must be known")
71 );
72
73 loop {
74 let stream = tokio::select! {
79 _ = stop.notified() => {
80 debug!("Stopping console");
81 return;
82 }
83 Ok((stream, _)) = listener.accept() => {
84 stream
85 }
86 };
87
88 debug!("New console connection.");
89
90 let Ok(addr) = stream.peer_addr() else {
91 warn!("Could not get peer address. Closing the connection.");
92 continue;
93 };
94 if inner.accept_only_localhost && !addr.ip().is_loopback() {
95 warn!("Only connection from the localhost are allowed. Connected peer address {addr}. Closing the connection.");
96 continue;
97 }
98
99 tokio::spawn(Self::handle_console_session(
100 stream,
101 inner.clone(),
102 stop.clone(),
103 ));
104 }
105 });
106
107 Ok(())
108 }
109
110 pub fn stop(&self) {
112 self.stop.notify_waiters();
113 }
114
115 async fn handle_console_session(
117 stream: TcpStream,
118 inner: Arc<Inner<Services>>,
119 stop: Arc<Notify>,
120 ) {
121 let Ok(addr) = stream.peer_addr() else {
122 warn!("Could not get peer address. Closing the session.");
123 return;
124 };
125
126 debug!("Connected to {addr}");
127
128 let mut bytes_stream = Framed::new(stream, BytesCodec::new());
129
130 debug!("Welcoming {addr}");
131 let bytes: Bytes = inner.welcome.as_bytes().to_vec().into();
132 let _ = bytes_stream.send(bytes).await;
133 debug!("Finished welcoming {addr}");
134
135 loop {
136 let bytes = tokio::select! {
137 _ = stop.notified() => {
138 debug!("Stopping session for {addr}");
139 return;
140 }
141 result = bytes_stream.next() => match result {
142 Some(Ok(bytes)) => {
143 bytes.freeze()
144 }
145 Some(Err(err)) => {
146 warn!("Error while receiving bytes: {err}. Received bytes will not be processed");
147 continue;
148 }
149 None => {
150 debug!("Connection closed by {addr}");
152 return;
153 }
154 }
155 };
156
157 match bcs::from_bytes::<Message<Services>>(bytes.as_ref()) {
158 Ok(Message { service_id, bytes }) => {
159 debug!("Received message for {service_id:?}");
162
163 if let Some(subscription) = inner.subscriptions.get(&service_id) {
164 debug!("Found subscription for service {service_id:?}");
165
166 match subscription.handle(bytes).await {
167 Ok(None) => {}
168 Ok(Some(bytes)) => {
169 let _ = bytes_stream.send(bytes).await;
170 }
171 Err(err) => warn!("Error handling message: {err}"),
172 }
173 } else {
174 warn!("No subscription found for service {service_id:?}. Ignoring the message.");
175 }
176 }
177 Err(_err) => {
178 let text = String::from_utf8_lossy(bytes.as_ref()).trim().to_string();
182 debug!("Received message is not typed. Treating it as text: {text}");
183
184 for (service_id, subscription) in &inner.subscriptions {
185 debug!("[{service_id:?}] request to process text message: `{text}`");
186
187 match subscription.weak_handle(&text).await {
188 Ok(None) => {
189 continue;
190 }
191 Ok(Some(message)) => {
192 debug!("[{service_id:?}] Message processed");
193 let vec: Bytes = ensure_newline(message).as_bytes().to_vec().into();
194 let _ = bytes_stream.send(vec).await;
195 break;
196 }
197 Err(err) => {
198 warn!("Service {service_id:?} failed to handle message: {err}");
199 continue;
200 }
201 }
202 }
203 }
204 }
205 }
206 }
207}
208
/// A typed console message: the target service id plus an opaque payload.
///
/// Sessions decode incoming chunks as `Message<Services>` with `bcs::from_bytes`;
/// chunks that fail to decode are treated as plain text instead.
// Field order is part of the serialized encoding (derived serde serializes struct
// fields in declaration order), so do not reorder these fields.
#[derive(Serialize, Deserialize)]
pub(crate) struct Message<Services> {
    // Key used to look up the subscription that should handle `bytes`.
    service_id: Services,
    // Payload passed to the subscription's `handle`; `Message::new` fills it with the
    // BCS encoding of the caller's value.
    bytes: Bytes,
}
215
216impl<Services> Message<Services> {
217 pub(crate) fn new(service_id: Services, message: &impl Serialize) -> Result<Self, Error> {
219 Ok(Self {
220 service_id,
221 bytes: Bytes::from(bcs::to_bytes(message)?),
222 })
223 }
224}
225
226#[derive(Debug, Error)]
227pub enum Error {
228 #[error("Subscription cannot be registered: service id `{0}` is already in use")]
229 ServiceIdUsed(String),
230 #[error("Console bind address is not specified")]
231 NoBindAddress,
232 #[error("Console had already started")]
233 AlreadyStarted,
234 #[error("IO error: {0}")]
235 Io(#[from] std::io::Error),
236 #[error("Serde error: {0}")]
237 Serde(#[from] bcs::Error),
238}