1use async_trait::async_trait;
20use futures::stream::StreamExt;
21use serde_json::json;
22use std::time::Duration;
23use tokio::time::timeout;
24use tokio_tungstenite::tungstenite::Message;
25use tokio_tungstenite::tungstenite::client::IntoClientRequest;
26
27use crate::domain::error::{Result, ServiceError, StygianError};
28use crate::ports::stream_source::{StreamEvent, StreamSourcePort};
29use crate::ports::{ScrapingService, ServiceInput, ServiceOutput};
30
/// Connection settings used by the WebSocket source.
#[derive(Clone, Debug)]
pub struct WebSocketConfig {
    /// Text frame sent once immediately after the connection is established, if set.
    pub subscribe_message: Option<String>,
    /// Token sent as an `Authorization: Bearer <token>` header on the handshake, if set.
    pub bearer_token: Option<String>,
    /// Seconds allowed for the connection handshake; also the longest gap between
    /// incoming frames before collection stops.
    pub timeout_secs: u64,
    /// Additional connection attempts made after the first one fails.
    pub max_reconnect_attempts: u32,
}

impl Default for WebSocketConfig {
    /// 30-second timeout, 3 reconnect attempts, no subscribe message, no auth token.
    fn default() -> Self {
        Self {
            timeout_secs: 30,
            max_reconnect_attempts: 3,
            subscribe_message: None,
            bearer_token: None,
        }
    }
}
69
70#[derive(Default)]
77pub struct WebSocketSource {
78 config: WebSocketConfig,
79}
80
81impl WebSocketSource {
82 pub const fn new(config: WebSocketConfig) -> Self {
84 Self { config }
85 }
86
87 fn config_from_params(&self, params: &serde_json::Value) -> WebSocketConfig {
89 let mut cfg = self.config.clone();
90 if let Some(msg) = params.get("subscribe_message").and_then(|v| v.as_str()) {
91 cfg.subscribe_message = Some(msg.to_string());
92 }
93 if let Some(token) = params.get("bearer_token").and_then(|v| v.as_str()) {
94 cfg.bearer_token = Some(token.to_string());
95 }
96 if let Some(t) = params
97 .get("timeout_secs")
98 .and_then(serde_json::Value::as_u64)
99 {
100 cfg.timeout_secs = t;
101 }
102 if let Some(r) = params
103 .get("max_reconnect_attempts")
104 .and_then(serde_json::Value::as_u64)
105 {
106 cfg.max_reconnect_attempts = u32::try_from(r).unwrap_or(u32::MAX);
107 }
108 cfg
109 }
110
111 async fn collect_events(
113 &self,
114 url: &str,
115 max_events: Option<usize>,
116 cfg: &WebSocketConfig,
117 ) -> Result<Vec<StreamEvent>> {
118 let mut request = url.into_client_request().map_err(|e| {
119 StygianError::Service(ServiceError::Unavailable(format!(
120 "invalid WebSocket URL: {e}"
121 )))
122 })?;
123
124 if let Some(token) = &cfg.bearer_token {
126 request.headers_mut().insert(
127 reqwest::header::AUTHORIZATION,
128 format!("Bearer {token}").parse().map_err(|e| {
129 StygianError::Service(ServiceError::Unavailable(format!(
130 "invalid auth header: {e}"
131 )))
132 })?,
133 );
134 }
135
136 let connect_timeout = Duration::from_secs(cfg.timeout_secs);
137 let (ws_stream, _) = timeout(connect_timeout, tokio_tungstenite::connect_async(request))
138 .await
139 .map_err(|_| {
140 StygianError::Service(ServiceError::Unavailable(
141 "WebSocket connection timed out".into(),
142 ))
143 })?
144 .map_err(|e| {
145 StygianError::Service(ServiceError::Unavailable(format!(
146 "WebSocket connection failed: {e}"
147 )))
148 })?;
149
150 let (mut write, mut read) = ws_stream.split();
151
152 if let Some(ref sub_msg) = cfg.subscribe_message {
154 use futures::SinkExt;
155 write
156 .send(Message::Text(sub_msg.clone().into()))
157 .await
158 .map_err(|e| {
159 StygianError::Service(ServiceError::Unavailable(format!(
160 "failed to send subscribe message: {e}"
161 )))
162 })?;
163 }
164
165 let mut events = Vec::new();
166 let mut frame_idx: u64 = 0;
167
168 while let Some(msg_result) = timeout(Duration::from_secs(cfg.timeout_secs), read.next())
169 .await
170 .ok()
171 .flatten()
172 {
173 match msg_result {
174 Ok(msg) => {
175 if let Some(event) = map_message_to_event(msg, frame_idx) {
176 events.push(event);
177 frame_idx += 1;
178
179 if let Some(max) = max_events
180 && events.len() >= max
181 {
182 break;
183 }
184 }
185 }
186 Err(e) => {
187 tracing::warn!("WebSocket receive error: {e}");
188 break;
189 }
190 }
191 }
192
193 Ok(events)
194 }
195}
196
197fn map_message_to_event(msg: Message, frame_idx: u64) -> Option<StreamEvent> {
201 match msg {
202 Message::Text(text) => Some(StreamEvent {
203 id: Some(frame_idx.to_string()),
204 event_type: Some("text".into()),
205 data: text.to_string(),
206 }),
207 Message::Binary(data) => {
208 use base64::Engine;
209 let encoded = base64::engine::general_purpose::STANDARD.encode(&data);
210 Some(StreamEvent {
211 id: Some(frame_idx.to_string()),
212 event_type: Some("binary".into()),
213 data: encoded,
214 })
215 }
216 Message::Ping(data) => Some(StreamEvent {
217 id: Some(frame_idx.to_string()),
218 event_type: Some("ping".into()),
219 data: String::from_utf8_lossy(&data).to_string(),
220 }),
221 Message::Pong(_) | Message::Close(_) | Message::Frame(_) => None,
223 }
224}
225
226#[async_trait]
229impl StreamSourcePort for WebSocketSource {
230 async fn subscribe(&self, url: &str, max_events: Option<usize>) -> Result<Vec<StreamEvent>> {
231 let cfg = self.config.clone();
232 let mut last_err = None;
233
234 for attempt in 0..=cfg.max_reconnect_attempts {
235 match self.collect_events(url, max_events, &cfg).await {
236 Ok(events) => return Ok(events),
237 Err(e) => {
238 tracing::warn!(
239 "WebSocket attempt {}/{} failed: {e}",
240 attempt + 1,
241 cfg.max_reconnect_attempts + 1
242 );
243 last_err = Some(e);
244
245 if attempt < cfg.max_reconnect_attempts {
246 let backoff = Duration::from_secs(1 << attempt);
248 tokio::time::sleep(backoff).await;
249 }
250 }
251 }
252 }
253
254 Err(last_err.unwrap_or_else(|| {
255 StygianError::Service(ServiceError::Unavailable(
256 "WebSocket connection failed after all retries".into(),
257 ))
258 }))
259 }
260
261 fn source_name(&self) -> &'static str {
262 "websocket"
263 }
264}
265
266#[async_trait]
269impl ScrapingService for WebSocketSource {
270 async fn execute(&self, input: ServiceInput) -> Result<ServiceOutput> {
279 let cfg = self.config_from_params(&input.params);
280 let max_events = input
281 .params
282 .get("max_events")
283 .and_then(serde_json::Value::as_u64)
284 .map(|n| usize::try_from(n).unwrap_or(usize::MAX));
285
286 let events = self.collect_events(&input.url, max_events, &cfg).await?;
287 let count = events.len();
288
289 let data = serde_json::to_string(&events).map_err(|e| {
290 StygianError::Service(ServiceError::InvalidResponse(format!(
291 "websocket serialization failed: {e}"
292 )))
293 })?;
294
295 Ok(ServiceOutput {
296 data,
297 metadata: json!({
298 "source": "websocket",
299 "event_count": count,
300 "source_url": input.url,
301 }),
302 })
303 }
304
305 fn name(&self) -> &'static str {
306 "websocket"
307 }
308}
309
#[cfg(test)]
mod tests {
    use base64::Engine;

    use super::*;

    #[test]
    fn map_text_frame() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let msg = Message::Text(r#"{"price": 42.5}"#.into());
        let event =
            map_message_to_event(msg, 0).ok_or_else(|| std::io::Error::other("should map"))?;
        assert_eq!(event.id.as_deref(), Some("0"));
        assert_eq!(event.event_type.as_deref(), Some("text"));
        assert_eq!(event.data, r#"{"price": 42.5}"#);
        Ok(())
    }

    #[test]
    fn map_binary_frame_to_base64() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let data = vec![0xDE, 0xAD, 0xBE, 0xEF];
        let msg = Message::Binary(data.into());
        let event =
            map_message_to_event(msg, 1).ok_or_else(|| std::io::Error::other("should map"))?;
        assert_eq!(event.id.as_deref(), Some("1"));
        assert_eq!(event.event_type.as_deref(), Some("binary"));
        let decoded = base64::engine::general_purpose::STANDARD.decode(&event.data)?;
        assert_eq!(decoded, vec![0xDE, 0xAD, 0xBE, 0xEF]);
        Ok(())
    }

    #[test]
    fn map_ping_frame() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let msg = Message::Ping(vec![1, 2, 3].into());
        let event =
            map_message_to_event(msg, 2).ok_or_else(|| std::io::Error::other("should map"))?;
        assert_eq!(event.id.as_deref(), Some("2"));
        assert_eq!(event.event_type.as_deref(), Some("ping"));
        assert_eq!(event.data, "\u{1}\u{2}\u{3}");
        Ok(())
    }

    #[test]
    fn pong_frame_is_skipped() {
        let msg = Message::Pong(vec![].into());
        assert!(map_message_to_event(msg, 0).is_none());
    }

    #[test]
    fn close_frame_is_skipped() {
        let msg = Message::Close(None);
        assert!(map_message_to_event(msg, 0).is_none());
    }

    #[test]
    fn default_config() {
        let cfg = WebSocketConfig::default();
        assert_eq!(cfg.timeout_secs, 30);
        assert_eq!(cfg.max_reconnect_attempts, 3);
        assert!(cfg.subscribe_message.is_none());
        assert!(cfg.bearer_token.is_none());
    }

    #[test]
    fn config_from_params_overrides() {
        let source = WebSocketSource::default();
        let params = json!({
            "subscribe_message": "{\"action\":\"sub\"}",
            "bearer_token": "tok123",
            "timeout_secs": 60,
            "max_reconnect_attempts": 5
        });
        let cfg = source.config_from_params(&params);
        assert_eq!(
            cfg.subscribe_message.as_deref(),
            Some("{\"action\":\"sub\"}")
        );
        assert_eq!(cfg.bearer_token.as_deref(), Some("tok123"));
        assert_eq!(cfg.timeout_secs, 60);
        assert_eq!(cfg.max_reconnect_attempts, 5);
    }

    #[test]
    fn config_from_params_ignores_mistyped_values() {
        let source = WebSocketSource::default();
        let params = json!({
            "subscribe_message": 42,
            "bearer_token": null,
            "timeout_secs": "60",
            "max_reconnect_attempts": -1
        });
        let cfg = source.config_from_params(&params);
        assert!(cfg.subscribe_message.is_none());
        assert!(cfg.bearer_token.is_none());
        assert_eq!(cfg.timeout_secs, 30);
        assert_eq!(cfg.max_reconnect_attempts, 3);
    }

    #[test]
    fn config_from_params_saturates_reconnect_attempts() {
        let source = WebSocketSource::default();
        let params = json!({ "max_reconnect_attempts": u64::MAX });
        let cfg = source.config_from_params(&params);
        assert_eq!(cfg.max_reconnect_attempts, u32::MAX);
    }

    #[test]
    fn frame_index_increments() {
        let msgs = vec![
            Message::Text("a".into()),
            // Pong frames produce no event and must not consume an index.
            Message::Pong(vec![].into()),
            Message::Text("b".into()),
        ];

        let mut idx: u64 = 0;
        let mut events = Vec::new();
        for msg in msgs {
            if let Some(event) = map_message_to_event(msg, idx) {
                events.push(event);
                idx += 1;
            }
        }

        assert_eq!(events.len(), 2);
        assert_eq!(events.first().and_then(|e| e.id.as_deref()), Some("0"));
        assert_eq!(events.get(1).and_then(|e| e.id.as_deref()), Some("1"));
    }

    #[tokio::test]
    #[ignore = "requires WebSocket echo server"]
    async fn connect_to_echo_server() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let source = WebSocketSource::new(WebSocketConfig {
            subscribe_message: Some("hello".into()),
            timeout_secs: 5,
            ..WebSocketConfig::default()
        });
        let events = source
            .subscribe("ws://127.0.0.1:9001/echo", Some(1))
            .await?;
        assert!(!events.is_empty());
        Ok(())
    }
}