Skip to main content

stygian_graph/adapters/
websocket.rs

//! WebSocket stream source adapter.
//!
//! Implements [`StreamSourcePort`](crate::ports::stream_source::StreamSourcePort) and
//! [`ScrapingService`](crate::ports::ScrapingService) for consuming WebSocket feeds.
//! Uses `tokio-tungstenite` for the underlying connection.
//!
//! # Example
//!
//! ```no_run
//! use stygian_graph::adapters::websocket::WebSocketSource;
//! use stygian_graph::ports::stream_source::StreamSourcePort;
//!
//! # async fn example() {
//! let source = WebSocketSource::default();
//! let events = source.subscribe("wss://api.example.com/ws", Some(10)).await.unwrap();
//! println!("received {} events", events.len());
//! # }
//! ```
18
19use async_trait::async_trait;
20use futures::stream::StreamExt;
21use serde_json::json;
22use std::time::Duration;
23use tokio::time::timeout;
24use tokio_tungstenite::tungstenite::Message;
25use tokio_tungstenite::tungstenite::client::IntoClientRequest;
26
27use crate::domain::error::{Result, ServiceError, StygianError};
28use crate::ports::stream_source::{StreamEvent, StreamSourcePort};
29use crate::ports::{ScrapingService, ServiceInput, ServiceOutput};
30
31// ─── Configuration ────────────────────────────────────────────────────────────
32
/// Configuration for a WebSocket connection.
///
/// Implements `PartialEq`/`Eq` so two configurations can be compared directly,
/// for example in tests.
///
/// # Example
///
/// ```
/// use stygian_graph::adapters::websocket::WebSocketConfig;
///
/// let config = WebSocketConfig {
///     subscribe_message: Some(r#"{"type":"subscribe","channel":"prices"}"#.into()),
///     bearer_token: None,
///     timeout_secs: 30,
///     max_reconnect_attempts: 3,
/// };
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WebSocketConfig {
    /// Optional text message sent immediately after connecting (e.g. a
    /// subscribe request).
    pub subscribe_message: Option<String>,
    /// Optional Bearer token; when set it is sent as the `Authorization`
    /// header on the upgrade request.
    pub bearer_token: Option<String>,
    /// Timeout in seconds. Bounds the connection handshake and, separately,
    /// the wait for each subsequent message.
    pub timeout_secs: u64,
    /// Maximum number of retries after a failed attempt, so up to
    /// `max_reconnect_attempts + 1` attempts are made in total.
    pub max_reconnect_attempts: u32,
}
58
59impl Default for WebSocketConfig {
60    fn default() -> Self {
61        Self {
62            subscribe_message: None,
63            bearer_token: None,
64            timeout_secs: 30,
65            max_reconnect_attempts: 3,
66        }
67    }
68}
69
70// ─── Adapter ──────────────────────────────────────────────────────────────────
71
72/// WebSocket stream source adapter.
73///
74/// Connects to a WebSocket endpoint and collects messages until `max_events`
75/// is reached, the stream closes, or a connection timeout occurs.
76#[derive(Default)]
77pub struct WebSocketSource {
78    config: WebSocketConfig,
79}
80
81impl WebSocketSource {
82    /// Create a new WebSocket source with custom configuration.
83    pub fn new(config: WebSocketConfig) -> Self {
84        Self { config }
85    }
86
87    /// Extract configuration from `ServiceInput.params` overrides.
88    fn config_from_params(&self, params: &serde_json::Value) -> WebSocketConfig {
89        let mut cfg = self.config.clone();
90        if let Some(msg) = params.get("subscribe_message").and_then(|v| v.as_str()) {
91            cfg.subscribe_message = Some(msg.to_string());
92        }
93        if let Some(token) = params.get("bearer_token").and_then(|v| v.as_str()) {
94            cfg.bearer_token = Some(token.to_string());
95        }
96        if let Some(t) = params.get("timeout_secs").and_then(|v| v.as_u64()) {
97            cfg.timeout_secs = t;
98        }
99        if let Some(r) = params
100            .get("max_reconnect_attempts")
101            .and_then(|v| v.as_u64())
102        {
103            cfg.max_reconnect_attempts = r as u32;
104        }
105        cfg
106    }
107
108    /// Connect and collect events from a WebSocket endpoint.
109    async fn collect_events(
110        &self,
111        url: &str,
112        max_events: Option<usize>,
113        cfg: &WebSocketConfig,
114    ) -> Result<Vec<StreamEvent>> {
115        let mut request = url.into_client_request().map_err(|e| {
116            StygianError::Service(ServiceError::Unavailable(format!(
117                "invalid WebSocket URL: {e}"
118            )))
119        })?;
120
121        // Inject auth header if configured
122        if let Some(token) = &cfg.bearer_token {
123            request.headers_mut().insert(
124                reqwest::header::AUTHORIZATION,
125                format!("Bearer {token}").parse().map_err(|e| {
126                    StygianError::Service(ServiceError::Unavailable(format!(
127                        "invalid auth header: {e}"
128                    )))
129                })?,
130            );
131        }
132
133        let connect_timeout = Duration::from_secs(cfg.timeout_secs);
134        let (ws_stream, _) = timeout(connect_timeout, tokio_tungstenite::connect_async(request))
135            .await
136            .map_err(|_| {
137                StygianError::Service(ServiceError::Unavailable(
138                    "WebSocket connection timed out".into(),
139                ))
140            })?
141            .map_err(|e| {
142                StygianError::Service(ServiceError::Unavailable(format!(
143                    "WebSocket connection failed: {e}"
144                )))
145            })?;
146
147        let (mut write, mut read) = ws_stream.split();
148
149        // Send subscribe message if configured
150        if let Some(ref sub_msg) = cfg.subscribe_message {
151            use futures::SinkExt;
152            write
153                .send(Message::Text(sub_msg.clone().into()))
154                .await
155                .map_err(|e| {
156                    StygianError::Service(ServiceError::Unavailable(format!(
157                        "failed to send subscribe message: {e}"
158                    )))
159                })?;
160        }
161
162        let mut events = Vec::new();
163        let mut frame_idx: u64 = 0;
164
165        while let Some(msg_result) = timeout(Duration::from_secs(cfg.timeout_secs), read.next())
166            .await
167            .ok()
168            .flatten()
169        {
170            match msg_result {
171                Ok(msg) => {
172                    if let Some(event) = map_message_to_event(msg, frame_idx) {
173                        events.push(event);
174                        frame_idx += 1;
175
176                        if let Some(max) = max_events
177                            && events.len() >= max
178                        {
179                            break;
180                        }
181                    }
182                }
183                Err(e) => {
184                    tracing::warn!("WebSocket receive error: {e}");
185                    break;
186                }
187            }
188        }
189
190        Ok(events)
191    }
192}
193
194/// Map a WebSocket message to a [`StreamEvent`].
195///
196/// Returns `None` for internal frames (Pong, Close, Frame).
197fn map_message_to_event(msg: Message, frame_idx: u64) -> Option<StreamEvent> {
198    match msg {
199        Message::Text(text) => Some(StreamEvent {
200            id: Some(frame_idx.to_string()),
201            event_type: Some("text".into()),
202            data: text.to_string(),
203        }),
204        Message::Binary(data) => {
205            use base64::Engine;
206            let encoded = base64::engine::general_purpose::STANDARD.encode(&data);
207            Some(StreamEvent {
208                id: Some(frame_idx.to_string()),
209                event_type: Some("binary".into()),
210                data: encoded,
211            })
212        }
213        Message::Ping(data) => Some(StreamEvent {
214            id: Some(frame_idx.to_string()),
215            event_type: Some("ping".into()),
216            data: String::from_utf8_lossy(&data).to_string(),
217        }),
218        // Pong, Close, and Frame are internal — skip
219        Message::Pong(_) | Message::Close(_) | Message::Frame(_) => None,
220    }
221}
222
223// ─── StreamSourcePort ─────────────────────────────────────────────────────────
224
225#[async_trait]
226impl StreamSourcePort for WebSocketSource {
227    async fn subscribe(&self, url: &str, max_events: Option<usize>) -> Result<Vec<StreamEvent>> {
228        let cfg = self.config.clone();
229        let mut last_err = None;
230
231        for attempt in 0..=cfg.max_reconnect_attempts {
232            match self.collect_events(url, max_events, &cfg).await {
233                Ok(events) => return Ok(events),
234                Err(e) => {
235                    tracing::warn!(
236                        "WebSocket attempt {}/{} failed: {e}",
237                        attempt + 1,
238                        cfg.max_reconnect_attempts + 1
239                    );
240                    last_err = Some(e);
241
242                    if attempt < cfg.max_reconnect_attempts {
243                        // Exponential backoff: 1s, 2s, 4s ...
244                        let backoff = Duration::from_secs(1 << attempt);
245                        tokio::time::sleep(backoff).await;
246                    }
247                }
248            }
249        }
250
251        Err(last_err.unwrap_or_else(|| {
252            StygianError::Service(ServiceError::Unavailable(
253                "WebSocket connection failed after all retries".into(),
254            ))
255        }))
256    }
257
258    fn source_name(&self) -> &str {
259        "websocket"
260    }
261}
262
263// ─── ScrapingService ──────────────────────────────────────────────────────────
264
265#[async_trait]
266impl ScrapingService for WebSocketSource {
267    /// Collect messages from a WebSocket and return as JSON array.
268    ///
269    /// # Params (optional)
270    ///
271    /// * `max_events` — integer; maximum messages to collect.
272    /// * `subscribe_message` — string; message to send on connect.
273    /// * `bearer_token` — string; Bearer token for auth header.
274    /// * `timeout_secs` — integer; connection/read timeout.
275    async fn execute(&self, input: ServiceInput) -> Result<ServiceOutput> {
276        let cfg = self.config_from_params(&input.params);
277        let max_events = input
278            .params
279            .get("max_events")
280            .and_then(|v| v.as_u64())
281            .map(|n| n as usize);
282
283        let events = self.collect_events(&input.url, max_events, &cfg).await?;
284        let count = events.len();
285
286        let data = serde_json::to_string(&events).map_err(|e| {
287            StygianError::Service(ServiceError::InvalidResponse(format!(
288                "websocket serialization failed: {e}"
289            )))
290        })?;
291
292        Ok(ServiceOutput {
293            data,
294            metadata: json!({
295                "source": "websocket",
296                "event_count": count,
297                "source_url": input.url,
298            }),
299        })
300    }
301
302    fn name(&self) -> &'static str {
303        "websocket"
304    }
305}
306
307// ─── Tests ────────────────────────────────────────────────────────────────────
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Text frames keep their payload verbatim and are tagged `"text"`.
    #[test]
    fn map_text_frame() {
        let payload = r#"{"price": 42.5}"#;
        let event = map_message_to_event(Message::Text(payload.into()), 0).expect("should map");

        assert_eq!(event.id.as_deref(), Some("0"));
        assert_eq!(event.event_type.as_deref(), Some("text"));
        assert_eq!(event.data, payload);
    }

    /// Binary frames are tagged `"binary"` and round-trip through base64.
    #[test]
    fn map_binary_frame_to_base64() {
        use base64::Engine;

        let raw: Vec<u8> = vec![0xDE, 0xAD, 0xBE, 0xEF];
        let event =
            map_message_to_event(Message::Binary(raw.clone().into()), 1).expect("should map");
        assert_eq!(event.event_type.as_deref(), Some("binary"));

        // Verify it's valid base64
        let round_trip = base64::engine::general_purpose::STANDARD
            .decode(&event.data)
            .expect("valid base64");
        assert_eq!(round_trip, raw);
    }

    /// Ping frames are surfaced as `"ping"` events.
    #[test]
    fn map_ping_frame() {
        let ping = Message::Ping(vec![1, 2, 3].into());
        let event = map_message_to_event(ping, 2).expect("should map");
        assert_eq!(event.event_type.as_deref(), Some("ping"));
    }

    /// Pong frames produce no event.
    #[test]
    fn pong_frame_is_skipped() {
        let mapped = map_message_to_event(Message::Pong(vec![].into()), 0);
        assert!(mapped.is_none());
    }

    /// Close frames produce no event.
    #[test]
    fn close_frame_is_skipped() {
        let mapped = map_message_to_event(Message::Close(None), 0);
        assert!(mapped.is_none());
    }

    /// The default configuration has a 30s timeout, 3 retries and no extras.
    #[test]
    fn default_config() {
        let WebSocketConfig {
            subscribe_message,
            bearer_token,
            timeout_secs,
            max_reconnect_attempts,
        } = WebSocketConfig::default();

        assert_eq!(timeout_secs, 30);
        assert_eq!(max_reconnect_attempts, 3);
        assert!(subscribe_message.is_none());
        assert!(bearer_token.is_none());
    }

    /// Every recognised key in `params` overrides the base configuration.
    #[test]
    fn config_from_params_overrides() {
        let overrides = json!({
            "subscribe_message": "{\"action\":\"sub\"}",
            "bearer_token": "tok123",
            "timeout_secs": 60,
            "max_reconnect_attempts": 5
        });

        let effective = WebSocketSource::default().config_from_params(&overrides);

        assert_eq!(
            effective.subscribe_message.as_deref(),
            Some("{\"action\":\"sub\"}")
        );
        assert_eq!(effective.bearer_token.as_deref(), Some("tok123"));
        assert_eq!(effective.timeout_secs, 60);
        assert_eq!(effective.max_reconnect_attempts, 5);
    }

    /// Skipped frames do not advance the event index.
    #[test]
    fn frame_index_increments() {
        let frames = vec![
            Message::Text("a".into()),
            Message::Pong(vec![].into()), // skipped
            Message::Text("b".into()),
        ];

        let mut next_idx: u64 = 0;
        let mapped: Vec<_> = frames
            .into_iter()
            .filter_map(|frame| {
                let event = map_message_to_event(frame, next_idx)?;
                next_idx += 1;
                Some(event)
            })
            .collect();

        assert_eq!(mapped.len(), 2);
        assert_eq!(mapped[0].id.as_deref(), Some("0"));
        assert_eq!(mapped[1].id.as_deref(), Some("1"));
    }

    // Integration tests require a running WebSocket server — marked #[ignore]
    #[tokio::test]
    #[ignore = "requires WebSocket echo server"]
    async fn connect_to_echo_server() {
        let echo_config = WebSocketConfig {
            subscribe_message: Some("hello".into()),
            timeout_secs: 5,
            ..WebSocketConfig::default()
        };
        let source = WebSocketSource::new(echo_config);

        let received = source
            .subscribe("ws://127.0.0.1:9001/echo", Some(1))
            .await
            .expect("connect");
        assert!(!received.is_empty());
    }
}