Skip to main content

rustenium_core/
transport.rs

1use std::{error::Error, future::Future};
2use std::fmt::{Display};
3use std::sync::Arc;
4use fastwebsockets::{handshake, Frame, OpCode, WebSocket, WebSocketError, WebSocketRead, WebSocketWrite};
5use hyper::{
6    body::Bytes,
7    header::{CONNECTION, UPGRADE},
8    upgrade::Upgraded,
9    Request,
10};
11use hyper_util::rt::TokioIo;
12use tokio::io::{ReadHalf, WriteHalf};
13use tokio::net::TcpStream;
14use tokio::sync::mpsc::UnboundedSender;
15use tokio::sync::Mutex;
16
/// Protocol (URL scheme) used to reach a transport endpoint.
///
/// The [`Display`] implementation renders the lowercase scheme name exactly as it
/// appears before `://` in a URL: `http`, `https`, `ws` or `wss`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ConnectionTransportProtocol {
    Http,
    Https,
    Ws,
    Wss,
}

impl Display for ConnectionTransportProtocol {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let scheme = match self {
            ConnectionTransportProtocol::Http => "http",
            ConnectionTransportProtocol::Https => "https",
            ConnectionTransportProtocol::Ws => "ws",
            ConnectionTransportProtocol::Wss => "wss",
        };
        f.write_str(scheme)
    }
}
/// Shape in which an endpoint URL can be rendered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UrlFormat {
    /// Presumably `host:port` only.
    HostPort,
    /// Presumably `protocol://host:port`, without a path.
    ProtocolHostPort,
    /// `protocol://host:port/path`.
    Full,
}
41
42#[derive(Debug, Clone)]
43pub struct ConnectionTransportConfig {
44    pub protocol: ConnectionTransportProtocol,
45    pub host: String,
46    pub port: u16,
47    pub path: String,
48}
49
50impl Default for ConnectionTransportConfig {
51    fn default() -> Self {
52        Self {
53            protocol: ConnectionTransportProtocol::Ws,
54            host: String::from("localhost"),
55            port: 0,
56            path: "session".to_string(),
57        }
58    }
59}
60
61impl ConnectionTransportConfig {
62    pub fn full_endpoint(&self) -> String {
63        format!("{}://{}{}", self.protocol, self.host_port(), self.path())
64    }
65
66    pub fn host_port(&self) -> String {
67        format!("{}:{}", self.host, self.port)
68    }
69
70    pub fn path(&self) -> String {
71        let path_str = self.path.trim_start_matches('/');
72        format!("/{}", path_str)
73    }
74
75    /// Parse a WebSocket URL (`ws://` or `wss://`) into a `ConnectionTransportConfig`.
76    pub fn from_ws_url(url: &str) -> Result<Self, String> {
77        let (protocol_str, rest) = url
78            .split_once("://")
79            .ok_or_else(|| format!("missing '://' in URL: {}", url))?;
80
81        let protocol = match protocol_str {
82            "ws" => ConnectionTransportProtocol::Ws,
83            "wss" => ConnectionTransportProtocol::Wss,
84            p => return Err(format!("unsupported WebSocket protocol: {}", p)),
85        };
86
87        let (host_port, path_tail) = rest.split_once('/').unwrap_or((rest, ""));
88        let (host, port_str) = host_port
89            .rsplit_once(':')
90            .ok_or_else(|| format!("missing port in URL: {}", url))?;
91        let port = port_str
92            .parse::<u16>()
93            .map_err(|e| format!("invalid port '{}': {}", port_str, e))?;
94
95        Ok(Self {
96            protocol,
97            host: host.to_string(),
98            port,
99            path: format!("/{}", path_tail),
100        })
101    }
102}
103
104pub trait ConnectionTransport {
105    fn send(&mut self, message: String) -> impl Future<Output=()> + Send;
106    fn listen(&self, listener: UnboundedSender<String>) -> ();
107    fn close(&self) -> impl Future<Output=()> + Send;
108    fn on_close(&self) -> ();
109}
110
111pub struct WebsocketConnectionTransport {
112    client_tx: Arc<Mutex<WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>>>,
113    client_rx: Arc<Mutex<WebSocketRead<ReadHalf<TokioIo<Upgraded>>>>>,
114}
115
116impl ConnectionTransport for WebsocketConnectionTransport {
117    fn send(&mut self, message: String) -> impl Future<Output=()> + Send
118    {
119        async move {
120            let frame = Frame::text(fastwebsockets::Payload::from(message.as_bytes()));
121            self.client_tx.lock().await.write_frame(frame).await.unwrap();
122        }
123    }
124
125    fn listen(&self, listener: UnboundedSender<String>) -> () {
126        WebsocketConnectionTransport::listener_loop(self.client_rx.clone(), self.client_tx.clone(), listener).unwrap();
127    }
128
129    fn close(&self) -> impl Future<Output=()> + Send {
130        let client_tx = self.client_tx.clone();
131        async move {
132            let mut tx = client_tx.lock().await;
133            let _ = tx.write_frame(Frame::close(1000, b"")).await;
134        }
135    }
136
137    fn on_close(&self) -> () {
138        todo!()
139    }
140}
141
142impl WebsocketConnectionTransport {
143    pub async fn new(connection_config: &ConnectionTransportConfig) -> Result<Self, Box<dyn Error>> {
144        let addr_host = connection_config.host_port();
145
146        // Retry on connection refused (driver starting up)
147        let retry_delay_ms = 400;
148        let mut retries = 3;
149
150        tracing::debug!("[WebsocketConnectionTransport]: Connecting to websocket @ url: {}", connection_config.full_endpoint());
151        let stream = loop {
152            match TcpStream::connect(&addr_host).await {
153                Ok(stream) => break stream,
154                Err(e) if e.kind() == std::io::ErrorKind::ConnectionRefused && retries > 0 => {
155                    tracing::warn!("Connection refused, retrying... ({} attempts remaining)", retries);
156                    retries -= 1;
157                    tokio::time::sleep(tokio::time::Duration::from_millis(retry_delay_ms)).await;
158                }
159                Err(e) => return Err(Box::new(e)),
160            }
161        };
162
163        let uri = connection_config.path();
164        let req = Request::builder()
165            .method("GET")
166            .uri(uri)
167            .header("Host", &addr_host)
168            .header(UPGRADE, "websocket")
169            .header(CONNECTION, "upgrade")
170            .header(
171                "Sec-WebSocket-Key",
172                fastwebsockets::handshake::generate_key(),
173            )
174            .header("Sec-WebSocket-Version", "13")
175            .body(http_body_util::Empty::<Bytes>::new()).unwrap();
176
177        let (mut ws, _) = handshake::client(&SpawnExecutor, req, stream).await.unwrap();
178        ws = Self::configure_client(ws);
179        let (rx, tx) = ws.split(tokio::io::split);
180
181        Ok(Self {
182            client_rx: Arc::new(Mutex::new(rx)),
183            client_tx: Arc::new(Mutex::new(tx))
184        })
185    }
186
187    fn configure_client(mut ws: WebSocket<TokioIo<Upgraded>>) -> WebSocket<TokioIo<Upgraded>> {
188        ws.set_writev(true);
189        ws.set_auto_close(true);
190        ws.set_auto_pong(true);
191
192        ws
193    }
194    pub fn listener_loop(ws_rx: Arc<Mutex<WebSocketRead<ReadHalf<TokioIo<Upgraded>>>>>, ws_tx: Arc<Mutex<WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>>>, tx: UnboundedSender<String>) -> Result<(), WebSocketError>
195    {
196        tokio::spawn(async move {
197            loop {
198                let mut ws_rx_half = ws_rx.lock().await;
199                let frame = match ws_rx_half.read_frame(&mut |frame| async {
200                    // Handles obligated send
201                    let mut ws_write_half = ws_tx.lock().await;
202                    return ws_write_half.write_frame(frame).await;
203                }).await {
204                    Ok(frame) => frame,
205                    // Err(WebSocketError::IoError(e)) if e.kind() == std::io::ErrorKind::ConnectionAborted => {
206                    //     tracing::warn!("WebSocket connection aborted: {}. Exiting listener loop.", e);
207                    //     break;
208                    // }
209                    Err(WebSocketError::UnexpectedEOF) => {
210                        tracing::warn!("WebSocket connection closed (unexpected EOF). Exiting listener loop.");
211                        break;
212                    }
213                    Err(e) => {
214                        panic!("Unexpected WebSocket error: {:?}", e);
215                    }
216                };
217
218                match frame.opcode {
219                    OpCode::Close => break,
220                    OpCode::Text | OpCode::Binary => {
221                        let incoming = Frame::new(true, frame.opcode, None, frame.payload);
222                        assert!(incoming.fin);
223                        let string_payload = String::from_utf8(incoming.payload.to_owned());
224                        if let Ok(str_payload) = string_payload {
225                            tx.send(str_payload).unwrap()
226                        }
227                    }
228                    _ => {}
229                }
230            }
231        });
232        Ok(())
233    }
234} //
235
236struct SpawnExecutor;
237
238impl<Fut> hyper::rt::Executor<Fut> for SpawnExecutor
239where
240    Fut: Future + Send + 'static,
241    Fut::Output: Send + 'static,
242{
243    fn execute(&self, fut: Fut) {
244        tokio::task::spawn(fut);
245    }
246}