1use std::{error::Error, future::Future};
2use std::fmt::{Display};
3use std::sync::Arc;
4use fastwebsockets::{handshake, Frame, OpCode, WebSocket, WebSocketError, WebSocketRead, WebSocketWrite};
5use hyper::{
6 body::Bytes,
7 header::{CONNECTION, UPGRADE},
8 upgrade::Upgraded,
9 Request,
10};
11use hyper_util::rt::TokioIo;
12use tokio::io::{ReadHalf, WriteHalf};
13use tokio::net::TcpStream;
14use tokio::sync::mpsc::UnboundedSender;
15use tokio::sync::Mutex;
16
/// URL scheme of a remote endpoint.
///
/// The [`Display`] impl renders the lowercase scheme string (`http`, `https`, `ws`, `wss`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConnectionTransportProtocol {
    /// The `http` scheme.
    Http,
    /// The `https` scheme.
    Https,
    /// The `ws` scheme.
    Ws,
    /// The `wss` scheme.
    Wss,
}

impl ConnectionTransportProtocol {
    /// Returns the lowercase URL scheme for this protocol, e.g. `"wss"`.
    pub fn as_str(self) -> &'static str {
        match self {
            ConnectionTransportProtocol::Http => "http",
            ConnectionTransportProtocol::Https => "https",
            ConnectionTransportProtocol::Ws => "ws",
            ConnectionTransportProtocol::Wss => "wss",
        }
    }

    /// Port implied by the scheme when a URL omits one: 80 for `http`/`ws`, 443 for
    /// `https`/`wss` (RFC 6455 §3, RFC 9110 §4.2).
    pub const fn default_port(self) -> u16 {
        match self {
            ConnectionTransportProtocol::Http | ConnectionTransportProtocol::Ws => 80,
            ConnectionTransportProtocol::Https | ConnectionTransportProtocol::Wss => 443,
        }
    }
}

impl Display for ConnectionTransportProtocol {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}

/// Selects how much of an endpoint URL to render.
///
/// NOTE(review): not referenced by the code visible here; the variants presumably
/// correspond to `host:port`, `scheme://host:port` and the full endpoint — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UrlFormat {
    HostPort,
    ProtocolHostPort,
    Full,
}

/// Where and how to reach a remote endpoint.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConnectionTransportConfig {
    /// URL scheme of the endpoint.
    pub protocol: ConnectionTransportProtocol,
    /// Host name or address. IPv6 literals parsed by [`Self::from_ws_url`] keep their brackets
    /// (e.g. `[::1]`) so that [`Self::host_port`] stays unambiguous.
    pub host: String,
    /// TCP port.
    pub port: u16,
    /// Request path. May be stored with or without a leading `/`; [`Self::path`] returns the
    /// normalised form.
    pub path: String,
}

impl Default for ConnectionTransportConfig {
    /// Returns the configuration for `ws://localhost:0/session`.
    ///
    /// NOTE(review): port 0 is not a connectable port; presumably callers always override it.
    fn default() -> Self {
        Self {
            protocol: ConnectionTransportProtocol::Ws,
            host: String::from("localhost"),
            port: 0,
            path: "session".to_string(),
        }
    }
}

impl ConnectionTransportConfig {
    /// Renders the endpoint as `scheme://host:port/path`, e.g. `ws://localhost:9000/session`.
    pub fn full_endpoint(&self) -> String {
        format!("{}://{}{}", self.protocol, self.host_port(), self.path())
    }

    /// Returns `host:port`.
    pub fn host_port(&self) -> String {
        format!("{}:{}", self.host, self.port)
    }

    /// Returns the path with exactly one leading `/` (any leading slashes stored in
    /// `self.path` are collapsed into one).
    pub fn path(&self) -> String {
        let path_str = self.path.trim_start_matches('/');
        format!("/{}", path_str)
    }

    /// Parses a `ws://` or `wss://` URL of the form `scheme://host[:port][/path][?query]`.
    ///
    /// - The scheme is matched case-insensitively (RFC 3986 §3.1).
    /// - A missing port defaults to 80 for `ws` and 443 for `wss` (RFC 6455 §3).
    /// - Bracketed IPv6 hosts such as `[::1]` or `[::1]:8080` are supported.
    /// - The stored path always starts with `/` and keeps any query string
    ///   (`ws://h:1?x=1` yields the path `/?x=1`).
    ///
    /// # Errors
    /// Returns a human-readable message when `://` is missing, the scheme is not
    /// `ws`/`wss`, the host is empty, or the port is not a valid `u16`.
    pub fn from_ws_url(url: &str) -> Result<Self, String> {
        let (scheme, rest) = url
            .split_once("://")
            .ok_or_else(|| format!("missing '://' in URL: {}", url))?;

        let protocol = match scheme.to_ascii_lowercase().as_str() {
            "ws" => ConnectionTransportProtocol::Ws,
            "wss" => ConnectionTransportProtocol::Wss,
            _ => return Err(format!("unsupported WebSocket protocol: {}", scheme)),
        };

        // The authority ends at the first '/' (path) or '?' (query with an empty path).
        let authority_end = rest
            .find(|c: char| c == '/' || c == '?')
            .unwrap_or(rest.len());
        let (authority, resource) = rest.split_at(authority_end);

        let (host, port) = match authority.rsplit_once(':') {
            // A ':' followed by ']' belongs to a bracketed IPv6 literal, not a port separator.
            Some((host, port_str)) if !port_str.contains(']') => {
                let port = port_str
                    .parse::<u16>()
                    .map_err(|e| format!("invalid port '{}': {}", port_str, e))?;
                (host, port)
            }
            _ => (authority, protocol.default_port()),
        };

        if host.is_empty() {
            return Err(format!("missing host in URL: {}", url));
        }

        let path = if resource.starts_with('/') {
            resource.to_string()
        } else {
            format!("/{}", resource)
        };

        Ok(Self {
            protocol,
            host: host.to_string(),
            port,
            path,
        })
    }
}
104pub trait ConnectionTransport {
105 fn send(&mut self, message: String) -> impl Future<Output=()> + Send;
106 fn listen(&self, listener: UnboundedSender<String>) -> ();
107 fn close(&self) -> impl Future<Output=()> + Send;
108 fn on_close(&self) -> ();
109}
110
111pub struct WebsocketConnectionTransport {
112 client_tx: Arc<Mutex<WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>>>,
113 client_rx: Arc<Mutex<WebSocketRead<ReadHalf<TokioIo<Upgraded>>>>>,
114}
115
116impl ConnectionTransport for WebsocketConnectionTransport {
117 fn send(&mut self, message: String) -> impl Future<Output=()> + Send
118 {
119 async move {
120 let frame = Frame::text(fastwebsockets::Payload::from(message.as_bytes()));
121 self.client_tx.lock().await.write_frame(frame).await.unwrap();
122 }
123 }
124
125 fn listen(&self, listener: UnboundedSender<String>) -> () {
126 WebsocketConnectionTransport::listener_loop(self.client_rx.clone(), self.client_tx.clone(), listener).unwrap();
127 }
128
129 fn close(&self) -> impl Future<Output=()> + Send {
130 let client_tx = self.client_tx.clone();
131 async move {
132 let mut tx = client_tx.lock().await;
133 let _ = tx.write_frame(Frame::close(1000, b"")).await;
134 }
135 }
136
137 fn on_close(&self) -> () {
138 todo!()
139 }
140}
141
142impl WebsocketConnectionTransport {
143 pub async fn new(connection_config: &ConnectionTransportConfig) -> Result<Self, Box<dyn Error>> {
144 let addr_host = connection_config.host_port();
145
146 let retry_delay_ms = 400;
148 let mut retries = 3;
149
150 tracing::debug!("[WebsocketConnectionTransport]: Connecting to websocket @ url: {}", connection_config.full_endpoint());
151 let stream = loop {
152 match TcpStream::connect(&addr_host).await {
153 Ok(stream) => break stream,
154 Err(e) if e.kind() == std::io::ErrorKind::ConnectionRefused && retries > 0 => {
155 tracing::warn!("Connection refused, retrying... ({} attempts remaining)", retries);
156 retries -= 1;
157 tokio::time::sleep(tokio::time::Duration::from_millis(retry_delay_ms)).await;
158 }
159 Err(e) => return Err(Box::new(e)),
160 }
161 };
162
163 let uri = connection_config.path();
164 let req = Request::builder()
165 .method("GET")
166 .uri(uri)
167 .header("Host", &addr_host)
168 .header(UPGRADE, "websocket")
169 .header(CONNECTION, "upgrade")
170 .header(
171 "Sec-WebSocket-Key",
172 fastwebsockets::handshake::generate_key(),
173 )
174 .header("Sec-WebSocket-Version", "13")
175 .body(http_body_util::Empty::<Bytes>::new()).unwrap();
176
177 let (mut ws, _) = handshake::client(&SpawnExecutor, req, stream).await.unwrap();
178 ws = Self::configure_client(ws);
179 let (rx, tx) = ws.split(tokio::io::split);
180
181 Ok(Self {
182 client_rx: Arc::new(Mutex::new(rx)),
183 client_tx: Arc::new(Mutex::new(tx))
184 })
185 }
186
187 fn configure_client(mut ws: WebSocket<TokioIo<Upgraded>>) -> WebSocket<TokioIo<Upgraded>> {
188 ws.set_writev(true);
189 ws.set_auto_close(true);
190 ws.set_auto_pong(true);
191
192 ws
193 }
194 pub fn listener_loop(ws_rx: Arc<Mutex<WebSocketRead<ReadHalf<TokioIo<Upgraded>>>>>, ws_tx: Arc<Mutex<WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>>>, tx: UnboundedSender<String>) -> Result<(), WebSocketError>
195 {
196 tokio::spawn(async move {
197 loop {
198 let mut ws_rx_half = ws_rx.lock().await;
199 let frame = match ws_rx_half.read_frame(&mut |frame| async {
200 let mut ws_write_half = ws_tx.lock().await;
202 return ws_write_half.write_frame(frame).await;
203 }).await {
204 Ok(frame) => frame,
205 Err(WebSocketError::UnexpectedEOF) => {
210 tracing::warn!("WebSocket connection closed (unexpected EOF). Exiting listener loop.");
211 break;
212 }
213 Err(e) => {
214 panic!("Unexpected WebSocket error: {:?}", e);
215 }
216 };
217
218 match frame.opcode {
219 OpCode::Close => break,
220 OpCode::Text | OpCode::Binary => {
221 let incoming = Frame::new(true, frame.opcode, None, frame.payload);
222 assert!(incoming.fin);
223 let string_payload = String::from_utf8(incoming.payload.to_owned());
224 if let Ok(str_payload) = string_payload {
225 tx.send(str_payload).unwrap()
226 }
227 }
228 _ => {}
229 }
230 }
231 });
232 Ok(())
233 }
234} struct SpawnExecutor;
237
238impl<Fut> hyper::rt::Executor<Fut> for SpawnExecutor
239where
240 Fut: Future + Send + 'static,
241 Fut::Output: Send + 'static,
242{
243 fn execute(&self, fut: Fut) {
244 tokio::task::spawn(fut);
245 }
246}