// viewpoint_core/network/websocket/mod.rs

1//! WebSocket monitoring.
2//!
3//! This module provides functionality for monitoring WebSocket connections,
4//! including frame events for sent and received messages.
5
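// NOTE(review): intended to allow dead code for the websocket monitoring
// scaffolding (spec: network-events), but no `#[allow(dead_code)]` attribute
// follows this comment — add the attribute or remove this note.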
7
8use std::collections::HashMap;
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::sync::Arc;
13
14use tokio::sync::{broadcast, RwLock};
15use tracing::{debug, trace};
16use viewpoint_cdp::protocol::{
17    WebSocketClosedEvent, WebSocketCreatedEvent, WebSocketFrameReceivedEvent,
18    WebSocketFrameSentEvent, WebSocketFrame as CdpWebSocketFrame,
19};
20use viewpoint_cdp::CdpConnection;
21
22/// A WebSocket connection being monitored.
23///
24/// This struct represents an active WebSocket connection and provides
25/// methods to register handlers for frame events.
26#[derive(Clone)]
27pub struct WebSocket {
28    /// The request ID identifying this WebSocket.
29    request_id: String,
30    /// The WebSocket URL.
31    url: String,
32    /// Whether the WebSocket is closed.
33    is_closed: Arc<AtomicBool>,
34    /// Frame sent event broadcaster.
35    frame_sent_tx: broadcast::Sender<WebSocketFrame>,
36    /// Frame received event broadcaster.
37    frame_received_tx: broadcast::Sender<WebSocketFrame>,
38    /// Close event broadcaster.
39    close_tx: broadcast::Sender<()>,
40}
41
42impl std::fmt::Debug for WebSocket {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        f.debug_struct("WebSocket")
45            .field("request_id", &self.request_id)
46            .field("url", &self.url)
47            .field("is_closed", &self.is_closed.load(Ordering::SeqCst))
48            .finish()
49    }
50}
51
52impl WebSocket {
53    /// Create a new WebSocket instance.
54    pub(crate) fn new(request_id: String, url: String) -> Self {
55        let (frame_sent_tx, _) = broadcast::channel(256);
56        let (frame_received_tx, _) = broadcast::channel(256);
57        let (close_tx, _) = broadcast::channel(16);
58
59        Self {
60            request_id,
61            url,
62            is_closed: Arc::new(AtomicBool::new(false)),
63            frame_sent_tx,
64            frame_received_tx,
65            close_tx,
66        }
67    }
68
69    /// Get the WebSocket URL.
70    pub fn url(&self) -> &str {
71        &self.url
72    }
73
74    /// Check if the WebSocket is closed.
75    pub fn is_closed(&self) -> bool {
76        self.is_closed.load(Ordering::SeqCst)
77    }
78
79    /// Get the request ID for this WebSocket.
80    pub fn request_id(&self) -> &str {
81        &self.request_id
82    }
83
84    /// Register a handler for frame sent events.
85    ///
86    /// The handler will be called whenever a frame is sent over this WebSocket.
87    ///
88    /// # Example
89    ///
90    /// ```no_run
91    /// use viewpoint_core::WebSocket;
92    ///
93    /// # async fn example(websocket: WebSocket) -> Result<(), viewpoint_core::CoreError> {
94    /// websocket.on_framesent(|frame| async move {
95    ///     println!("Sent: {:?}", frame.payload());
96    /// }).await;
97    /// # Ok(())
98    /// # }
99    pub async fn on_framesent<F, Fut>(&self, handler: F)
100    where
101        F: Fn(WebSocketFrame) -> Fut + Send + Sync + 'static,
102        Fut: Future<Output = ()> + Send + 'static,
103    {
104        let mut rx = self.frame_sent_tx.subscribe();
105        tokio::spawn(async move {
106            while let Ok(frame) = rx.recv().await {
107                handler(frame).await;
108            }
109        });
110    }
111
112    /// Register a handler for frame received events.
113    ///
114    /// The handler will be called whenever a frame is received on this WebSocket.
115    ///
116    /// # Example
117    ///
118    /// ```no_run
119    /// use viewpoint_core::WebSocket;
120    ///
121    /// # async fn example(websocket: WebSocket) -> Result<(), viewpoint_core::CoreError> {
122    /// websocket.on_framereceived(|frame| async move {
123    ///     println!("Received: {:?}", frame.payload());
124    /// }).await;
125    /// # Ok(())
126    /// # }
127    pub async fn on_framereceived<F, Fut>(&self, handler: F)
128    where
129        F: Fn(WebSocketFrame) -> Fut + Send + Sync + 'static,
130        Fut: Future<Output = ()> + Send + 'static,
131    {
132        let mut rx = self.frame_received_tx.subscribe();
133        tokio::spawn(async move {
134            while let Ok(frame) = rx.recv().await {
135                handler(frame).await;
136            }
137        });
138    }
139
140    /// Register a handler for WebSocket close events.
141    ///
142    /// The handler will be called when this WebSocket connection is closed.
143    ///
144    /// # Example
145    ///
146    /// ```no_run
147    /// use viewpoint_core::WebSocket;
148    ///
149    /// # async fn example(websocket: WebSocket) -> Result<(), viewpoint_core::CoreError> {
150    /// websocket.on_close(|| async {
151    ///     println!("WebSocket closed");
152    /// }).await;
153    /// # Ok(())
154    /// # }
155    pub async fn on_close<F, Fut>(&self, handler: F)
156    where
157        F: Fn() -> Fut + Send + Sync + 'static,
158        Fut: Future<Output = ()> + Send + 'static,
159    {
160        let mut rx = self.close_tx.subscribe();
161        tokio::spawn(async move {
162            if rx.recv().await.is_ok() {
163                handler().await;
164            }
165        });
166    }
167
168    /// Emit a frame sent event (internal use).
169    pub(crate) fn emit_frame_sent(&self, frame: WebSocketFrame) {
170        let _ = self.frame_sent_tx.send(frame);
171    }
172
173    /// Emit a frame received event (internal use).
174    pub(crate) fn emit_frame_received(&self, frame: WebSocketFrame) {
175        let _ = self.frame_received_tx.send(frame);
176    }
177
178    /// Mark the WebSocket as closed and emit close event (internal use).
179    pub(crate) fn mark_closed(&self) {
180        self.is_closed.store(true, Ordering::SeqCst);
181        let _ = self.close_tx.send(());
182    }
183}
184
/// One WebSocket message frame observed through CDP.
///
/// Stores the opcode and payload exactly as reported; the accessor methods
/// interpret them.
#[derive(Debug, Clone)]
pub struct WebSocketFrame {
    /// Frame opcode (e.g. 1 = text, 2 = binary).
    opcode: u8,
    /// Frame payload as a string.
    payload_data: String,
}
193
194impl WebSocketFrame {
195    /// Create a new WebSocket frame.
196    pub(crate) fn new(opcode: u8, payload_data: String) -> Self {
197        Self {
198            opcode,
199            payload_data,
200        }
201    }
202
203    /// Create a WebSocket frame from CDP frame data.
204    pub(crate) fn from_cdp(cdp_frame: &CdpWebSocketFrame) -> Self {
205        Self {
206            opcode: cdp_frame.opcode as u8,
207            payload_data: cdp_frame.payload_data.clone(),
208        }
209    }
210
211    /// Get the frame opcode.
212    ///
213    /// Common opcodes:
214    /// - 1: Text frame
215    /// - 2: Binary frame
216    /// - 8: Close frame
217    /// - 9: Ping frame
218    /// - 10: Pong frame
219    pub fn opcode(&self) -> u8 {
220        self.opcode
221    }
222
223    /// Get the frame payload data.
224    pub fn payload(&self) -> &str {
225        &self.payload_data
226    }
227
228    /// Check if this is a text frame.
229    pub fn is_text(&self) -> bool {
230        self.opcode == 1
231    }
232
233    /// Check if this is a binary frame.
234    pub fn is_binary(&self) -> bool {
235        self.opcode == 2
236    }
237}
238
239/// Type alias for the WebSocket event handler function.
240pub type WebSocketEventHandler = Box<
241    dyn Fn(WebSocket) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync,
242>;
243
244/// Manager for WebSocket events on a page.
245pub struct WebSocketManager {
246    /// CDP connection.
247    connection: Arc<CdpConnection>,
248    /// Session ID.
249    session_id: String,
250    /// Active WebSocket connections indexed by request ID.
251    websockets: Arc<RwLock<HashMap<String, WebSocket>>>,
252    /// WebSocket created event handler.
253    handler: Arc<RwLock<Option<WebSocketEventHandler>>>,
254    /// Whether the manager is listening for events.
255    is_listening: AtomicBool,
256}
257
258impl WebSocketManager {
259    /// Create a new WebSocket manager for a page.
260    pub fn new(connection: Arc<CdpConnection>, session_id: String) -> Self {
261        Self {
262            connection,
263            session_id,
264            websockets: Arc::new(RwLock::new(HashMap::new())),
265            handler: Arc::new(RwLock::new(None)),
266            is_listening: AtomicBool::new(false),
267        }
268    }
269
270    /// Set a handler for WebSocket created events.
271    pub async fn set_handler<F, Fut>(&self, handler: F)
272    where
273        F: Fn(WebSocket) -> Fut + Send + Sync + 'static,
274        Fut: Future<Output = ()> + Send + 'static,
275    {
276        let boxed_handler: WebSocketEventHandler = Box::new(move |ws| {
277            Box::pin(handler(ws))
278        });
279        let mut h = self.handler.write().await;
280        *h = Some(boxed_handler);
281
282        // Start listening for events if not already
283        self.start_listening().await;
284    }
285
286    /// Remove the WebSocket handler.
287    pub async fn remove_handler(&self) {
288        let mut h = self.handler.write().await;
289        *h = None;
290    }
291
292    /// Start listening for WebSocket CDP events.
293    async fn start_listening(&self) {
294        if self.is_listening.swap(true, Ordering::SeqCst) {
295            // Already listening
296            return;
297        }
298
299        let mut events = self.connection.subscribe_events();
300        let session_id = self.session_id.clone();
301        let websockets = self.websockets.clone();
302        let handler = self.handler.clone();
303
304        tokio::spawn(async move {
305            debug!("WebSocket manager started listening for events");
306            
307            while let Ok(event) = events.recv().await {
308                // Filter events for this session
309                if event.session_id.as_deref() != Some(&session_id) {
310                    continue;
311                }
312
313                match event.method.as_str() {
314                    "Network.webSocketCreated" => {
315                        if let Some(params) = &event.params {
316                            if let Ok(created) = serde_json::from_value::<WebSocketCreatedEvent>(params.clone()) {
317                                trace!("WebSocket created: {} -> {}", created.request_id, created.url);
318                                
319                                let ws = WebSocket::new(created.request_id.clone(), created.url);
320                                
321                                // Store the WebSocket
322                                {
323                                    let mut sockets = websockets.write().await;
324                                    sockets.insert(created.request_id, ws.clone());
325                                }
326                                
327                                // Call the handler
328                                let h = handler.read().await;
329                                if let Some(ref handler_fn) = *h {
330                                    handler_fn(ws).await;
331                                }
332                            }
333                        }
334                    }
335                    "Network.webSocketClosed" => {
336                        if let Some(params) = &event.params {
337                            if let Ok(closed) = serde_json::from_value::<WebSocketClosedEvent>(params.clone()) {
338                                trace!("WebSocket closed: {}", closed.request_id);
339                                
340                                let sockets = websockets.read().await;
341                                if let Some(ws) = sockets.get(&closed.request_id) {
342                                    ws.mark_closed();
343                                }
344                            }
345                        }
346                    }
347                    "Network.webSocketFrameSent" => {
348                        if let Some(params) = &event.params {
349                            if let Ok(frame_event) = serde_json::from_value::<WebSocketFrameSentEvent>(params.clone()) {
350                                trace!("WebSocket frame sent: {}", frame_event.request_id);
351                                
352                                let sockets = websockets.read().await;
353                                if let Some(ws) = sockets.get(&frame_event.request_id) {
354                                    let frame = WebSocketFrame::from_cdp(&frame_event.response);
355                                    ws.emit_frame_sent(frame);
356                                }
357                            }
358                        }
359                    }
360                    "Network.webSocketFrameReceived" => {
361                        if let Some(params) = &event.params {
362                            if let Ok(frame_event) = serde_json::from_value::<WebSocketFrameReceivedEvent>(params.clone()) {
363                                trace!("WebSocket frame received: {}", frame_event.request_id);
364                                
365                                let sockets = websockets.read().await;
366                                if let Some(ws) = sockets.get(&frame_event.request_id) {
367                                    let frame = WebSocketFrame::from_cdp(&frame_event.response);
368                                    ws.emit_frame_received(frame);
369                                }
370                            }
371                        }
372                    }
373                    _ => {}
374                }
375            }
376            
377            debug!("WebSocket manager stopped listening");
378        });
379    }
380
381    /// Get a WebSocket by request ID.
382    pub async fn get(&self, request_id: &str) -> Option<WebSocket> {
383        let sockets = self.websockets.read().await;
384        sockets.get(request_id).cloned()
385    }
386
387    /// Get all active `WebSockets`.
388    pub async fn all(&self) -> Vec<WebSocket> {
389        let sockets = self.websockets.read().await;
390        sockets.values().cloned().collect()
391    }
392}
393
// Unit tests live in a separate `tests` module file next to this one.
#[cfg(test)]
mod tests;