rustenium_core/
transport.rs

1use std::{error::Error, future::Future};
2use std::fmt::{Display};
3use std::sync::Arc;
4use fastwebsockets::{handshake, Frame, OpCode, WebSocket, WebSocketError, WebSocketRead, WebSocketWrite};
5use hyper::{
6    body::Bytes,
7    header::{CONNECTION, UPGRADE},
8    upgrade::Upgraded,
9    Request,
10};
11use hyper_util::rt::TokioIo;
12use tokio::io::{ReadHalf, WriteHalf};
13use tokio::net::TcpStream;
14use tokio::sync::mpsc::UnboundedSender;
15use tokio::sync::Mutex;
16
/// URL scheme used when building the endpoint of a connection transport.
///
/// The `Display` implementation in this file renders each variant as its
/// lowercase scheme name (`http`, `https`, `ws`, `wss`).
///
/// The enum carries no data, so it is cheap to copy and can be compared and
/// hashed directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConnectionTransportProtocol {
    /// Rendered as `http`.
    Http,
    /// Rendered as `https`.
    Https,
    /// Rendered as `ws`.
    Ws,
    /// Rendered as `wss`.
    Wss,
}
24
25impl Display for ConnectionTransportProtocol {
26    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27        let str = match self {
28            ConnectionTransportProtocol::Http => "http",
29            ConnectionTransportProtocol::Https => "https",
30            ConnectionTransportProtocol::Ws => "ws",
31            ConnectionTransportProtocol::Wss => "wss",
32        };
33        write!(f, "{}", str)
34    }
35}
/// Selects how much of an endpoint URL should be rendered.
///
/// NOTE(review): nothing in the visible part of this file consumes this enum;
/// the per-variant descriptions below are inferred from the variant names and
/// the original inline comment — confirm against the callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UrlFormat {
    /// Presumably `host:port`.
    HostPort,
    /// Presumably `protocol://host:port`.
    ProtocolHostPort,
    /// `protocol://host:port/path`
    Full,
}
41
42#[derive(Debug, Clone)]
43pub struct ConnectionTransportConfig {
44    pub protocol: ConnectionTransportProtocol,
45    pub host: String,
46    pub port: u16,
47    pub path:  &'static str,
48}
49
50impl Default for ConnectionTransportConfig {
51    fn default() -> Self {
52        Self {
53            protocol: ConnectionTransportProtocol::Ws,
54            host: String::from("localhost"),
55            port: 0,
56            path: "session",
57        }
58    }
59}
60
61impl ConnectionTransportConfig {
62    pub fn full_endpoint(&self) -> String {
63        format!("{}://{}{}", self.protocol, self.host_port(), self.path())
64    }
65    
66    pub fn host_port(&self) -> String {
67        format!("{}:{}", self.host, self.port)
68    }
69
70    pub fn path(&self) -> String {
71        let path_str = self.path.trim_start_matches('/');
72        format!("/{}", path_str)
73    }
74
75}
76
77pub trait ConnectionTransport {
78    fn send(&mut self, message: String) -> impl Future<Output=()> + Send;
79    fn listen(&self, listener: UnboundedSender<String>) -> ();
80    fn close(&self) -> ();
81    fn on_close(&self) -> ();
82}
83
84pub struct WebsocketConnectionTransport {
85    client_tx: Arc<Mutex<WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>>>,
86    client_rx: Arc<Mutex<WebSocketRead<ReadHalf<TokioIo<Upgraded>>>>>,
87}
88
89impl ConnectionTransport for WebsocketConnectionTransport {
90    fn send(&mut self, message: String) -> impl Future<Output=()> + Send
91    {
92        async move {
93            let frame = Frame::text(fastwebsockets::Payload::from(message.as_bytes()));
94            self.client_tx.lock().await.write_frame(frame).await.unwrap();
95        }
96    }
97
98    fn listen(&self, listener: UnboundedSender<String>) -> () {
99        WebsocketConnectionTransport::listener_loop(self.client_rx.clone(), self.client_tx.clone(), listener).unwrap();
100    }
101
102    fn close(&self) -> () {
103        let client_tx = self.client_tx.clone();
104        tokio::spawn(async move {
105            let mut tx = client_tx.lock().await;
106            let _ = tx.write_frame(Frame::close(1000, b"")).await;
107        });
108    }
109
110    fn on_close(&self) -> () {
111        todo!()
112    }
113}
114
115impl WebsocketConnectionTransport {
116    pub async fn new(connection_config: &ConnectionTransportConfig) -> Result<Self, Box<dyn Error>> {
117        let addr_host = connection_config.host_port();
118
119        // Retry on connection refused (driver starting up)
120        let retry_delay_ms = 400;
121        let mut retries = 3;
122
123        let stream = loop {
124            match TcpStream::connect(&addr_host).await {
125                Ok(stream) => break stream,
126                Err(e) if e.kind() == std::io::ErrorKind::ConnectionRefused && retries > 0 => {
127                    tracing::warn!("Connection refused, retrying... ({} attempts remaining)", retries);
128                    retries -= 1;
129                    tokio::time::sleep(tokio::time::Duration::from_millis(retry_delay_ms)).await;
130                }
131                Err(e) => return Err(Box::new(e)),
132            }
133        };
134
135        let uri = connection_config.path();
136        let req = Request::builder()
137            .method("GET")
138            .uri(uri)
139            .header("Host", &addr_host)
140            .header(UPGRADE, "websocket")
141            .header(CONNECTION, "upgrade")
142            .header(
143                "Sec-WebSocket-Key",
144                fastwebsockets::handshake::generate_key(),
145            )
146            .header("Sec-WebSocket-Version", "13")
147            .body(http_body_util::Empty::<Bytes>::new()).unwrap();
148
149        let (mut ws, _) = handshake::client(&SpawnExecutor, req, stream).await.unwrap();
150        ws = Self::configure_client(ws);
151        let (rx, tx) = ws.split(tokio::io::split);
152        tracing::info!("Successfully connected to WebDriver");
153
154        Ok(Self {
155            client_rx: Arc::new(Mutex::new(rx)),
156            client_tx: Arc::new(Mutex::new(tx))
157        })
158    }
159
160    fn configure_client(mut ws: WebSocket<TokioIo<Upgraded>>) -> WebSocket<TokioIo<Upgraded>> {
161        ws.set_writev(true);
162        ws.set_auto_close(true);
163        ws.set_auto_pong(true);
164
165        ws
166    }
167    pub fn listener_loop(ws_rx: Arc<Mutex<WebSocketRead<ReadHalf<TokioIo<Upgraded>>>>>, ws_tx: Arc<Mutex<WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>>>, tx: UnboundedSender<String>) -> Result<(), WebSocketError>
168    {
169        tokio::spawn(async move {
170            loop {
171                let mut ws_rx_half = ws_rx.lock().await;
172                let frame = match ws_rx_half.read_frame(&mut |frame| async {
173                    // Handles obligated send
174                    let mut ws_write_half = ws_tx.lock().await;
175                    return ws_write_half.write_frame(frame).await;
176                }).await {
177                    Ok(frame) => frame,
178                    // Err(WebSocketError::IoError(e)) if e.kind() == std::io::ErrorKind::ConnectionAborted => {
179                    //     tracing::warn!("WebSocket connection aborted: {}. Exiting listener loop.", e);
180                    //     break;
181                    // }
182                    Err(WebSocketError::UnexpectedEOF) => {
183                        tracing::warn!("WebSocket connection closed (unexpected EOF). Exiting listener loop.");
184                        break;
185                    }
186                    Err(e) => {
187                        panic!("Unexpected WebSocket error: {:?}", e);
188                    }
189                };
190
191                match frame.opcode {
192                    OpCode::Close => break,
193                    OpCode::Text | OpCode::Binary => {
194                        let incoming = Frame::new(true, frame.opcode, None, frame.payload);
195                        assert!(incoming.fin);
196                        let string_payload = String::from_utf8(incoming.payload.to_owned());
197                        if let Ok(str_payload) = string_payload {
198                            tx.send(str_payload).unwrap()
199                        }
200                    }
201                    _ => {}
202                }
203            }
204        });
205        Ok(())
206    }
207} //
208
209struct SpawnExecutor;
210
211impl<Fut> hyper::rt::Executor<Fut> for SpawnExecutor
212where
213    Fut: Future + Send + 'static,
214    Fut::Output: Send + 'static,
215{
216    fn execute(&self, fut: Fut) {
217        tokio::task::spawn(fut);
218    }
219}