syncable_ag_ui_client/
ws.rs1use std::pin::Pin;
20use std::task::{Context, Poll};
21
22use syncable_ag_ui_core::{Event, JsonValue};
23use futures::{SinkExt, Stream};
24use tokio_tungstenite::{
25 connect_async,
26 tungstenite::{self, Message},
27 MaybeTlsStream, WebSocketStream,
28};
29
30use crate::error::{ClientError, Result};
31
/// Options for establishing a WebSocket connection via [`WsClient::connect_with_config`].
///
/// Build one with [`WsConfig::new`] and the chained builder methods, or use
/// [`WsConfig::default`] for no extra headers and automatic pong replies.
#[derive(Clone, Debug)]
pub struct WsConfig {
    /// Extra `(name, value)` HTTP headers appended to the handshake request, in order.
    pub headers: Vec<(String, String)>,
    /// When `true`, the event stream answers incoming ping frames with a pong.
    pub auto_pong: bool,
}
40
41impl Default for WsConfig {
42 fn default() -> Self {
43 Self {
44 headers: Vec::new(),
45 auto_pong: true,
46 }
47 }
48}
49
50impl WsConfig {
51 pub fn new() -> Self {
53 Self::default()
54 }
55
56 pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
58 self.headers.push((name.into(), value.into()));
59 self
60 }
61
62 pub fn bearer_token(self, token: impl Into<String>) -> Self {
64 self.header("Authorization", format!("Bearer {}", token.into()))
65 }
66
67 pub fn disable_auto_pong(mut self) -> Self {
69 self.auto_pong = false;
70 self
71 }
72}
73
74pub struct WsClient {
79 socket: WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
80 auto_pong: bool,
81}
82
83impl WsClient {
84 pub async fn connect(url: &str) -> Result<Self> {
96 Self::connect_with_config(url, WsConfig::default()).await
97 }
98
99 pub async fn connect_with_config(url: &str, config: WsConfig) -> Result<Self> {
114 let mut request = tungstenite::http::Request::builder()
116 .uri(url)
117 .header("Host", extract_host(url)?)
118 .header("Connection", "Upgrade")
119 .header("Upgrade", "websocket")
120 .header("Sec-WebSocket-Version", "13")
121 .header(
122 "Sec-WebSocket-Key",
123 tungstenite::handshake::client::generate_key(),
124 );
125
126 for (name, value) in config.headers {
127 request = request.header(name, value);
128 }
129
130 let request = request
131 .body(())
132 .map_err(|e| ClientError::connection(e.to_string()))?;
133
134 let (socket, _response) = connect_async(request)
135 .await
136 .map_err(|e| ClientError::connection(e.to_string()))?;
137
138 Ok(Self {
139 socket,
140 auto_pong: config.auto_pong,
141 })
142 }
143
144 pub fn into_stream(self) -> WsEventStream {
148 WsEventStream {
149 socket: self.socket,
150 auto_pong: self.auto_pong,
151 }
152 }
153
154 pub async fn close(mut self) -> Result<()> {
156 self.socket
157 .close(None)
158 .await
159 .map_err(|e| ClientError::connection(e.to_string()))
160 }
161}
162
163pub struct WsEventStream {
167 socket: WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
168 auto_pong: bool,
169}
170
171impl Stream for WsEventStream {
172 type Item = Result<Event<JsonValue>>;
173
174 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
175 loop {
176 match Pin::new(&mut self.socket).poll_next(cx) {
177 Poll::Ready(Some(Ok(msg))) => {
178 match msg {
179 Message::Text(text) => {
180 match serde_json::from_str::<Event<JsonValue>>(&text) {
182 Ok(event) => return Poll::Ready(Some(Ok(event))),
183 Err(e) => {
184 return Poll::Ready(Some(Err(ClientError::parse(format!(
185 "failed to parse event: {}",
186 e
187 )))))
188 }
189 }
190 }
191 Message::Ping(data) => {
192 if self.auto_pong {
193 let mut socket = Pin::new(&mut self.socket);
195 let _ = socket.start_send_unpin(Message::Pong(data));
196 }
197 continue;
198 }
199 Message::Pong(_) => {
200 continue;
202 }
203 Message::Close(_) => {
204 return Poll::Ready(None);
205 }
206 Message::Binary(_) | Message::Frame(_) => {
207 continue;
209 }
210 }
211 }
212 Poll::Ready(Some(Err(e))) => {
213 return Poll::Ready(Some(Err(ClientError::WebSocket(e))))
214 }
215 Poll::Ready(None) => return Poll::Ready(None),
216 Poll::Pending => return Poll::Pending,
217 }
218 }
219 }
220}
221
222fn extract_host(url: &str) -> Result<String> {
224 let url = url::Url::parse(url).map_err(|e| ClientError::InvalidUrl(e.to_string()))?;
225
226 let host = url
227 .host_str()
228 .ok_or_else(|| ClientError::InvalidUrl("missing host".to_string()))?;
229
230 match url.port() {
231 Some(port) => Ok(format!("{}:{}", host, port)),
232 None => Ok(host.to_string()),
233 }
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ws_config_default() {
        let WsConfig { headers, auto_pong } = WsConfig::default();
        assert!(headers.is_empty());
        assert!(auto_pong);
    }

    #[test]
    fn test_ws_config_builder() {
        let config = WsConfig::new()
            .header("X-Custom", "value")
            .bearer_token("token123")
            .disable_auto_pong();

        let expected: Vec<(String, String)> = vec![
            ("X-Custom".to_string(), "value".to_string()),
            ("Authorization".to_string(), "Bearer token123".to_string()),
        ];
        assert_eq!(config.headers, expected);
        assert!(!config.auto_pong);
    }

    #[test]
    fn test_extract_host() {
        // (input URL, expected Host header value)
        let cases = [
            ("ws://localhost:3000/ws", "localhost:3000"),
            ("wss://example.com/events", "example.com"),
            ("ws://api.example.com:8080/stream", "api.example.com:8080"),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(extract_host(input).unwrap(), expected, "input: {}", input);
        }
    }

    #[test]
    fn test_extract_host_invalid() {
        let result = extract_host("not a url");
        assert!(result.is_err());
    }
}