1use async_trait::async_trait;
20use futures::stream::StreamExt;
21use serde_json::json;
22use std::time::Duration;
23use tokio::time::timeout;
24use tokio_tungstenite::tungstenite::Message;
25use tokio_tungstenite::tungstenite::client::IntoClientRequest;
26
27use crate::domain::error::{Result, ServiceError, StygianError};
28use crate::ports::stream_source::{StreamEvent, StreamSourcePort};
29use crate::ports::{ScrapingService, ServiceInput, ServiceOutput};
30
/// Connection settings for a WebSocket event source.
#[derive(Clone)]
pub struct WebSocketConfig {
    /// Text frame sent immediately after the handshake completes, if set.
    pub subscribe_message: Option<String>,
    /// Token sent as `Authorization: Bearer <token>` on the handshake request, if set.
    pub bearer_token: Option<String>,
    /// Seconds allowed for the connection handshake and, separately, for each wait on the
    /// next incoming frame.
    pub timeout_secs: u64,
    /// Number of retries made after a failed attempt when subscribing, so at most
    /// `max_reconnect_attempts + 1` connections are tried.
    pub max_reconnect_attempts: u32,
}

// Hand-written rather than derived so the bearer token never leaks into logs or panic
// messages through `{:?}`.
impl std::fmt::Debug for WebSocketConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("WebSocketConfig")
            .field("subscribe_message", &self.subscribe_message)
            .field(
                "bearer_token",
                &self.bearer_token.as_ref().map(|_| "<redacted>"),
            )
            .field("timeout_secs", &self.timeout_secs)
            .field("max_reconnect_attempts", &self.max_reconnect_attempts)
            .finish()
    }
}
58
59impl Default for WebSocketConfig {
60 fn default() -> Self {
61 Self {
62 subscribe_message: None,
63 bearer_token: None,
64 timeout_secs: 30,
65 max_reconnect_attempts: 3,
66 }
67 }
68}
69
/// WebSocket-backed event source.
///
/// Holds the base [`WebSocketConfig`] used when subscribing; the scraping-service path
/// layers per-request parameter overrides on top of it.
#[derive(Default)]
pub struct WebSocketSource {
    /// Base configuration for connections opened by this source.
    config: WebSocketConfig,
}
80
81impl WebSocketSource {
82 #[must_use]
84 pub const fn new(config: WebSocketConfig) -> Self {
85 Self { config }
86 }
87
88 fn config_from_params(&self, params: &serde_json::Value) -> WebSocketConfig {
90 let mut cfg = self.config.clone();
91 if let Some(msg) = params.get("subscribe_message").and_then(|v| v.as_str()) {
92 cfg.subscribe_message = Some(msg.to_string());
93 }
94 if let Some(token) = params.get("bearer_token").and_then(|v| v.as_str()) {
95 cfg.bearer_token = Some(token.to_string());
96 }
97 if let Some(t) = params
98 .get("timeout_secs")
99 .and_then(serde_json::Value::as_u64)
100 {
101 cfg.timeout_secs = t;
102 }
103 if let Some(r) = params
104 .get("max_reconnect_attempts")
105 .and_then(serde_json::Value::as_u64)
106 {
107 cfg.max_reconnect_attempts = u32::try_from(r).unwrap_or(u32::MAX);
108 }
109 cfg
110 }
111
112 async fn collect_events(
114 &self,
115 url: &str,
116 max_events: Option<usize>,
117 cfg: &WebSocketConfig,
118 ) -> Result<Vec<StreamEvent>> {
119 let mut request = url.into_client_request().map_err(|e| {
120 StygianError::Service(ServiceError::Unavailable(format!(
121 "invalid WebSocket URL: {e}"
122 )))
123 })?;
124
125 if let Some(token) = &cfg.bearer_token {
127 request.headers_mut().insert(
128 reqwest::header::AUTHORIZATION,
129 format!("Bearer {token}").parse().map_err(|e| {
130 StygianError::Service(ServiceError::Unavailable(format!(
131 "invalid auth header: {e}"
132 )))
133 })?,
134 );
135 }
136
137 let connect_timeout = Duration::from_secs(cfg.timeout_secs);
138 let (ws_stream, _) = timeout(connect_timeout, tokio_tungstenite::connect_async(request))
139 .await
140 .map_err(|_| {
141 StygianError::Service(ServiceError::Unavailable(
142 "WebSocket connection timed out".into(),
143 ))
144 })?
145 .map_err(|e| {
146 StygianError::Service(ServiceError::Unavailable(format!(
147 "WebSocket connection failed: {e}"
148 )))
149 })?;
150
151 let (mut write, mut read) = ws_stream.split();
152
153 if let Some(ref sub_msg) = cfg.subscribe_message {
155 use futures::SinkExt;
156 write
157 .send(Message::Text(sub_msg.clone().into()))
158 .await
159 .map_err(|e| {
160 StygianError::Service(ServiceError::Unavailable(format!(
161 "failed to send subscribe message: {e}"
162 )))
163 })?;
164 }
165
166 let mut events = Vec::new();
167 let mut frame_idx: u64 = 0;
168
169 while let Some(msg_result) = timeout(Duration::from_secs(cfg.timeout_secs), read.next())
170 .await
171 .ok()
172 .flatten()
173 {
174 match msg_result {
175 Ok(msg) => {
176 if let Some(event) = map_message_to_event(msg, frame_idx) {
177 events.push(event);
178 frame_idx += 1;
179
180 if let Some(max) = max_events
181 && events.len() >= max
182 {
183 break;
184 }
185 }
186 }
187 Err(e) => {
188 tracing::warn!("WebSocket receive error: {e}");
189 break;
190 }
191 }
192 }
193
194 Ok(events)
195 }
196}
197
198fn map_message_to_event(msg: Message, frame_idx: u64) -> Option<StreamEvent> {
202 match msg {
203 Message::Text(text) => Some(StreamEvent {
204 id: Some(frame_idx.to_string()),
205 event_type: Some("text".into()),
206 data: text.to_string(),
207 }),
208 Message::Binary(data) => {
209 use base64::Engine;
210 let encoded = base64::engine::general_purpose::STANDARD.encode(&data);
211 Some(StreamEvent {
212 id: Some(frame_idx.to_string()),
213 event_type: Some("binary".into()),
214 data: encoded,
215 })
216 }
217 Message::Ping(data) => Some(StreamEvent {
218 id: Some(frame_idx.to_string()),
219 event_type: Some("ping".into()),
220 data: String::from_utf8_lossy(&data).to_string(),
221 }),
222 Message::Pong(_) | Message::Close(_) | Message::Frame(_) => None,
224 }
225}
226
227#[async_trait]
230impl StreamSourcePort for WebSocketSource {
231 async fn subscribe(&self, url: &str, max_events: Option<usize>) -> Result<Vec<StreamEvent>> {
232 let cfg = self.config.clone();
233 let mut last_err = None;
234
235 for attempt in 0..=cfg.max_reconnect_attempts {
236 match self.collect_events(url, max_events, &cfg).await {
237 Ok(events) => return Ok(events),
238 Err(e) => {
239 tracing::warn!(
240 "WebSocket attempt {}/{} failed: {e}",
241 attempt + 1,
242 cfg.max_reconnect_attempts + 1
243 );
244 last_err = Some(e);
245
246 if attempt < cfg.max_reconnect_attempts {
247 let backoff = Duration::from_secs(1 << attempt);
249 tokio::time::sleep(backoff).await;
250 }
251 }
252 }
253 }
254
255 Err(last_err.unwrap_or_else(|| {
256 StygianError::Service(ServiceError::Unavailable(
257 "WebSocket connection failed after all retries".into(),
258 ))
259 }))
260 }
261
262 fn source_name(&self) -> &'static str {
263 "websocket"
264 }
265}
266
267#[async_trait]
270impl ScrapingService for WebSocketSource {
271 async fn execute(&self, input: ServiceInput) -> Result<ServiceOutput> {
280 let cfg = self.config_from_params(&input.params);
281 let max_events = input
282 .params
283 .get("max_events")
284 .and_then(serde_json::Value::as_u64)
285 .map(|n| usize::try_from(n).unwrap_or(usize::MAX));
286
287 let events = self.collect_events(&input.url, max_events, &cfg).await?;
288 let count = events.len();
289
290 let data = serde_json::to_string(&events).map_err(|e| {
291 StygianError::Service(ServiceError::InvalidResponse(format!(
292 "websocket serialization failed: {e}"
293 )))
294 })?;
295
296 Ok(ServiceOutput {
297 data,
298 metadata: json!({
299 "source": "websocket",
300 "event_count": count,
301 "source_url": input.url,
302 }),
303 })
304 }
305
306 fn name(&self) -> &'static str {
307 "websocket"
308 }
309}
310
#[cfg(test)]
mod tests {
    use base64::Engine;

    use super::*;

    type TestResult = std::result::Result<(), Box<dyn std::error::Error>>;

    /// Maps `msg`, turning a skipped frame into a test error instead of a panic.
    fn expect_event(
        msg: Message,
        idx: u64,
    ) -> std::result::Result<StreamEvent, Box<dyn std::error::Error>> {
        map_message_to_event(msg, idx).ok_or_else(|| std::io::Error::other("should map").into())
    }

    #[test]
    fn map_text_frame() -> TestResult {
        let event = expect_event(Message::Text(r#"{"price": 42.5}"#.into()), 0)?;
        assert_eq!(event.id.as_deref(), Some("0"));
        assert_eq!(event.event_type.as_deref(), Some("text"));
        assert_eq!(event.data, r#"{"price": 42.5}"#);
        Ok(())
    }

    #[test]
    fn map_binary_frame_to_base64() -> TestResult {
        let raw = vec![0xDE, 0xAD, 0xBE, 0xEF];
        let event = expect_event(Message::Binary(raw.clone().into()), 1)?;
        assert_eq!(event.id.as_deref(), Some("1"));
        assert_eq!(event.event_type.as_deref(), Some("binary"));
        let decoded = base64::engine::general_purpose::STANDARD.decode(&event.data)?;
        assert_eq!(decoded, raw);
        Ok(())
    }

    #[test]
    fn map_ping_frame() -> TestResult {
        let event = expect_event(Message::Ping(vec![1, 2, 3].into()), 2)?;
        assert_eq!(event.id.as_deref(), Some("2"));
        assert_eq!(event.event_type.as_deref(), Some("ping"));
        assert_eq!(event.data, "\u{1}\u{2}\u{3}");
        Ok(())
    }

    #[test]
    fn pong_frame_is_skipped() {
        assert!(map_message_to_event(Message::Pong(vec![].into()), 0).is_none());
    }

    #[test]
    fn close_frame_is_skipped() {
        assert!(map_message_to_event(Message::Close(None), 0).is_none());
    }

    #[test]
    fn default_config() {
        let cfg = WebSocketConfig::default();
        assert_eq!(cfg.timeout_secs, 30);
        assert_eq!(cfg.max_reconnect_attempts, 3);
        assert!(cfg.subscribe_message.is_none());
        assert!(cfg.bearer_token.is_none());
    }

    #[test]
    fn config_from_params_overrides() {
        let source = WebSocketSource::default();
        let params = json!({
            "subscribe_message": "{\"action\":\"sub\"}",
            "bearer_token": "tok123",
            "timeout_secs": 60,
            "max_reconnect_attempts": 5
        });
        let cfg = source.config_from_params(&params);
        assert_eq!(
            cfg.subscribe_message.as_deref(),
            Some("{\"action\":\"sub\"}")
        );
        assert_eq!(cfg.bearer_token.as_deref(), Some("tok123"));
        assert_eq!(cfg.timeout_secs, 60);
        assert_eq!(cfg.max_reconnect_attempts, 5);
    }

    #[test]
    fn frame_index_increments() {
        let msgs = vec![
            Message::Text("a".into()),
            Message::Pong(vec![].into()),
            Message::Text("b".into()),
        ];

        // Mirrors the collection loop: only emitted events advance the index.
        let mut idx: u64 = 0;
        let mut events = Vec::new();
        for msg in msgs {
            if let Some(event) = map_message_to_event(msg, idx) {
                events.push(event);
                idx += 1;
            }
        }

        assert_eq!(events.len(), 2);
        assert_eq!(events.first().and_then(|e| e.id.as_deref()), Some("0"));
        assert_eq!(events.get(1).and_then(|e| e.id.as_deref()), Some("1"));
    }

    #[tokio::test]
    #[ignore = "requires WebSocket echo server"]
    async fn connect_to_echo_server() -> TestResult {
        let source = WebSocketSource::new(WebSocketConfig {
            subscribe_message: Some("hello".into()),
            timeout_secs: 5,
            ..WebSocketConfig::default()
        });
        let events = source
            .subscribe("ws://127.0.0.1:9001/echo", Some(1))
            .await?;
        assert!(!events.is_empty());
        Ok(())
    }
}