rustenium_core/
transport.rs1use std::{error::Error, future::Future};
2use std::fmt::{Display};
3use std::sync::Arc;
4use fastwebsockets::{handshake, Frame, OpCode, WebSocket, WebSocketError, WebSocketRead, WebSocketWrite};
5use hyper::{
6 body::Bytes,
7 header::{CONNECTION, UPGRADE},
8 upgrade::Upgraded,
9 Request,
10};
11use hyper_util::rt::TokioIo;
12use tokio::io::{ReadHalf, WriteHalf};
13use tokio::net::TcpStream;
14use tokio::sync::mpsc::UnboundedSender;
15use tokio::sync::Mutex;
16
/// URL scheme of a transport endpoint.
///
/// Formats (via [`Display`]) as the lowercase scheme name: `http`, `https`, `ws` or `wss`.
#[derive(Debug, Clone)]
pub enum ConnectionTransportProtocol {
    /// Formats as `http`.
    Http,
    /// Formats as `https`.
    Https,
    /// Formats as `ws`.
    Ws,
    /// Formats as `wss`.
    Wss,
}

impl Display for ConnectionTransportProtocol {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The scheme is written verbatim; width/fill flags are not applied.
        f.write_str(match self {
            Self::Http => "http",
            Self::Https => "https",
            Self::Ws => "ws",
            Self::Wss => "wss",
        })
    }
}
/// Which parts of a connection URL to render.
///
/// NOTE(review): nothing in this section consumes this enum; the per-variant notes below
/// are inferred from the variant names only — confirm against the callers.
pub enum UrlFormat {
    /// Presumably `host:port` only.
    HostPort,
    /// Presumably `scheme://host:port`.
    ProtocolHostPort,
    /// Presumably the full `scheme://host:port/path` URL.
    Full,
}
41
42#[derive(Debug, Clone)]
43pub struct ConnectionTransportConfig {
44 pub protocol: ConnectionTransportProtocol,
45 pub host: String,
46 pub port: u16,
47 pub path: &'static str,
48}
49
50impl Default for ConnectionTransportConfig {
51 fn default() -> Self {
52 Self {
53 protocol: ConnectionTransportProtocol::Ws,
54 host: String::from("localhost"),
55 port: 0,
56 path: "session",
57 }
58 }
59}
60
61impl ConnectionTransportConfig {
62 pub fn full_endpoint(&self) -> String {
63 format!("{}://{}{}", self.protocol, self.host_port(), self.path())
64 }
65
66 pub fn host_port(&self) -> String {
67 format!("{}:{}", self.host, self.port)
68 }
69
70 pub fn path(&self) -> String {
71 let path_str = self.path.trim_start_matches('/');
72 format!("/{}", path_str)
73 }
74
75}
76
77pub trait ConnectionTransport {
78 fn send(&mut self, message: String) -> impl Future<Output=()> + Send;
79 fn listen(&self, listener: UnboundedSender<String>) -> ();
80 fn close(&self) -> ();
81 fn on_close(&self) -> ();
82}
83
84pub struct WebsocketConnectionTransport {
85 client_tx: Arc<Mutex<WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>>>,
86 client_rx: Arc<Mutex<WebSocketRead<ReadHalf<TokioIo<Upgraded>>>>>,
87}
88
89impl ConnectionTransport for WebsocketConnectionTransport {
90 fn send(&mut self, message: String) -> impl Future<Output=()> + Send
91 {
92 async move {
93 let frame = Frame::text(fastwebsockets::Payload::from(message.as_bytes()));
94 self.client_tx.lock().await.write_frame(frame).await.unwrap();
95 }
96 }
97
98 fn listen(&self, listener: UnboundedSender<String>) -> () {
99 WebsocketConnectionTransport::listener_loop(self.client_rx.clone(), self.client_tx.clone(), listener).unwrap();
100 }
101
102 fn close(&self) -> () {
103 let client_tx = self.client_tx.clone();
104 tokio::spawn(async move {
105 let mut tx = client_tx.lock().await;
106 let _ = tx.write_frame(Frame::close(1000, b"")).await;
107 });
108 }
109
110 fn on_close(&self) -> () {
111 todo!()
112 }
113}
114
115impl WebsocketConnectionTransport {
116 pub async fn new(connection_config: &ConnectionTransportConfig) -> Result<Self, Box<dyn Error>> {
117 let addr_host = connection_config.host_port();
118
119 let retry_delay_ms = 400;
121 let mut retries = 3;
122
123 let stream = loop {
124 match TcpStream::connect(&addr_host).await {
125 Ok(stream) => break stream,
126 Err(e) if e.kind() == std::io::ErrorKind::ConnectionRefused && retries > 0 => {
127 tracing::warn!("Connection refused, retrying... ({} attempts remaining)", retries);
128 retries -= 1;
129 tokio::time::sleep(tokio::time::Duration::from_millis(retry_delay_ms)).await;
130 }
131 Err(e) => return Err(Box::new(e)),
132 }
133 };
134
135 let uri = connection_config.path();
136 let req = Request::builder()
137 .method("GET")
138 .uri(uri)
139 .header("Host", &addr_host)
140 .header(UPGRADE, "websocket")
141 .header(CONNECTION, "upgrade")
142 .header(
143 "Sec-WebSocket-Key",
144 fastwebsockets::handshake::generate_key(),
145 )
146 .header("Sec-WebSocket-Version", "13")
147 .body(http_body_util::Empty::<Bytes>::new()).unwrap();
148
149 let (mut ws, _) = handshake::client(&SpawnExecutor, req, stream).await.unwrap();
150 ws = Self::configure_client(ws);
151 let (rx, tx) = ws.split(tokio::io::split);
152 tracing::info!("Successfully connected to WebDriver");
153
154 Ok(Self {
155 client_rx: Arc::new(Mutex::new(rx)),
156 client_tx: Arc::new(Mutex::new(tx))
157 })
158 }
159
160 fn configure_client(mut ws: WebSocket<TokioIo<Upgraded>>) -> WebSocket<TokioIo<Upgraded>> {
161 ws.set_writev(true);
162 ws.set_auto_close(true);
163 ws.set_auto_pong(true);
164
165 ws
166 }
167 pub fn listener_loop(ws_rx: Arc<Mutex<WebSocketRead<ReadHalf<TokioIo<Upgraded>>>>>, ws_tx: Arc<Mutex<WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>>>, tx: UnboundedSender<String>) -> Result<(), WebSocketError>
168 {
169 tokio::spawn(async move {
170 loop {
171 let mut ws_rx_half = ws_rx.lock().await;
172 let frame = match ws_rx_half.read_frame(&mut |frame| async {
173 let mut ws_write_half = ws_tx.lock().await;
175 return ws_write_half.write_frame(frame).await;
176 }).await {
177 Ok(frame) => frame,
178 Err(WebSocketError::UnexpectedEOF) => {
183 tracing::warn!("WebSocket connection closed (unexpected EOF). Exiting listener loop.");
184 break;
185 }
186 Err(e) => {
187 panic!("Unexpected WebSocket error: {:?}", e);
188 }
189 };
190
191 match frame.opcode {
192 OpCode::Close => break,
193 OpCode::Text | OpCode::Binary => {
194 let incoming = Frame::new(true, frame.opcode, None, frame.payload);
195 assert!(incoming.fin);
196 let string_payload = String::from_utf8(incoming.payload.to_owned());
197 if let Ok(str_payload) = string_payload {
198 tx.send(str_payload).unwrap()
199 }
200 }
201 _ => {}
202 }
203 }
204 });
205 Ok(())
206 }
207} struct SpawnExecutor;
210
211impl<Fut> hyper::rt::Executor<Fut> for SpawnExecutor
212where
213 Fut: Future + Send + 'static,
214 Fut::Output: Send + 'static,
215{
216 fn execute(&self, fut: Fut) {
217 tokio::task::spawn(fut);
218 }
219}