viewpoint_core/network/websocket/
mod.rs

1//! WebSocket monitoring.
2//!
3//! This module provides functionality for monitoring WebSocket connections,
4//! including frame events for sent and received messages.
5
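// Scaffolding for WebSocket monitoring (spec: network-events). Dead code is meant to be
// allowed here, but no `#[allow(dead_code)]` attribute is present — add one if unused
// items in this module should not warn.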
7
8use std::collections::HashMap;
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::sync::Arc;
13
14use tokio::sync::{broadcast, RwLock};
15use tracing::{debug, trace};
16use viewpoint_cdp::protocol::{
17    WebSocketClosedEvent, WebSocketCreatedEvent, WebSocketFrameReceivedEvent,
18    WebSocketFrameSentEvent, WebSocketFrame as CdpWebSocketFrame,
19};
20use viewpoint_cdp::CdpConnection;
21
22/// A WebSocket connection being monitored.
23///
24/// This struct represents an active WebSocket connection and provides
25/// methods to register handlers for frame events.
26#[derive(Clone)]
27pub struct WebSocket {
28    /// The request ID identifying this WebSocket.
29    request_id: String,
30    /// The WebSocket URL.
31    url: String,
32    /// Whether the WebSocket is closed.
33    is_closed: Arc<AtomicBool>,
34    /// Frame sent event broadcaster.
35    frame_sent_tx: broadcast::Sender<WebSocketFrame>,
36    /// Frame received event broadcaster.
37    frame_received_tx: broadcast::Sender<WebSocketFrame>,
38    /// Close event broadcaster.
39    close_tx: broadcast::Sender<()>,
40}
41
42impl std::fmt::Debug for WebSocket {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        f.debug_struct("WebSocket")
45            .field("request_id", &self.request_id)
46            .field("url", &self.url)
47            .field("is_closed", &self.is_closed.load(Ordering::SeqCst))
48            .finish()
49    }
50}
51
52impl WebSocket {
53    /// Create a new WebSocket instance.
54    pub(crate) fn new(request_id: String, url: String) -> Self {
55        let (frame_sent_tx, _) = broadcast::channel(256);
56        let (frame_received_tx, _) = broadcast::channel(256);
57        let (close_tx, _) = broadcast::channel(16);
58
59        Self {
60            request_id,
61            url,
62            is_closed: Arc::new(AtomicBool::new(false)),
63            frame_sent_tx,
64            frame_received_tx,
65            close_tx,
66        }
67    }
68
69    /// Get the WebSocket URL.
70    pub fn url(&self) -> &str {
71        &self.url
72    }
73
74    /// Check if the WebSocket is closed.
75    pub fn is_closed(&self) -> bool {
76        self.is_closed.load(Ordering::SeqCst)
77    }
78
79    /// Get the request ID for this WebSocket.
80    pub fn request_id(&self) -> &str {
81        &self.request_id
82    }
83
84    /// Register a handler for frame sent events.
85    ///
86    /// The handler will be called whenever a frame is sent over this WebSocket.
87    ///
88    /// # Example
89    ///
90    /// ```ignore
91    /// websocket.on_framesent(|frame| async move {
92    ///     println!("Sent: {:?}", frame.payload());
93    /// }).await;
94    /// ```
95    pub async fn on_framesent<F, Fut>(&self, handler: F)
96    where
97        F: Fn(WebSocketFrame) -> Fut + Send + Sync + 'static,
98        Fut: Future<Output = ()> + Send + 'static,
99    {
100        let mut rx = self.frame_sent_tx.subscribe();
101        tokio::spawn(async move {
102            while let Ok(frame) = rx.recv().await {
103                handler(frame).await;
104            }
105        });
106    }
107
108    /// Register a handler for frame received events.
109    ///
110    /// The handler will be called whenever a frame is received on this WebSocket.
111    ///
112    /// # Example
113    ///
114    /// ```ignore
115    /// websocket.on_framereceived(|frame| async move {
116    ///     println!("Received: {:?}", frame.payload());
117    /// }).await;
118    /// ```
119    pub async fn on_framereceived<F, Fut>(&self, handler: F)
120    where
121        F: Fn(WebSocketFrame) -> Fut + Send + Sync + 'static,
122        Fut: Future<Output = ()> + Send + 'static,
123    {
124        let mut rx = self.frame_received_tx.subscribe();
125        tokio::spawn(async move {
126            while let Ok(frame) = rx.recv().await {
127                handler(frame).await;
128            }
129        });
130    }
131
132    /// Register a handler for WebSocket close events.
133    ///
134    /// The handler will be called when this WebSocket connection is closed.
135    ///
136    /// # Example
137    ///
138    /// ```ignore
139    /// websocket.on_close(|| async {
140    ///     println!("WebSocket closed");
141    /// }).await;
142    /// ```
143    pub async fn on_close<F, Fut>(&self, handler: F)
144    where
145        F: Fn() -> Fut + Send + Sync + 'static,
146        Fut: Future<Output = ()> + Send + 'static,
147    {
148        let mut rx = self.close_tx.subscribe();
149        tokio::spawn(async move {
150            if rx.recv().await.is_ok() {
151                handler().await;
152            }
153        });
154    }
155
156    /// Emit a frame sent event (internal use).
157    pub(crate) fn emit_frame_sent(&self, frame: WebSocketFrame) {
158        let _ = self.frame_sent_tx.send(frame);
159    }
160
161    /// Emit a frame received event (internal use).
162    pub(crate) fn emit_frame_received(&self, frame: WebSocketFrame) {
163        let _ = self.frame_received_tx.send(frame);
164    }
165
166    /// Mark the WebSocket as closed and emit close event (internal use).
167    pub(crate) fn mark_closed(&self) {
168        self.is_closed.store(true, Ordering::SeqCst);
169        let _ = self.close_tx.send(());
170    }
171}
172
/// A single WebSocket message frame.
#[derive(Debug, Clone)]
pub struct WebSocketFrame {
    /// Frame opcode, e.g. 1 for a text frame or 2 for a binary frame.
    opcode: u8,
    /// Raw payload of the frame.
    payload_data: String,
}
181
182impl WebSocketFrame {
183    /// Create a new WebSocket frame.
184    pub(crate) fn new(opcode: u8, payload_data: String) -> Self {
185        Self {
186            opcode,
187            payload_data,
188        }
189    }
190
191    /// Create a WebSocket frame from CDP frame data.
192    pub(crate) fn from_cdp(cdp_frame: &CdpWebSocketFrame) -> Self {
193        Self {
194            opcode: cdp_frame.opcode as u8,
195            payload_data: cdp_frame.payload_data.clone(),
196        }
197    }
198
199    /// Get the frame opcode.
200    ///
201    /// Common opcodes:
202    /// - 1: Text frame
203    /// - 2: Binary frame
204    /// - 8: Close frame
205    /// - 9: Ping frame
206    /// - 10: Pong frame
207    pub fn opcode(&self) -> u8 {
208        self.opcode
209    }
210
211    /// Get the frame payload data.
212    pub fn payload(&self) -> &str {
213        &self.payload_data
214    }
215
216    /// Check if this is a text frame.
217    pub fn is_text(&self) -> bool {
218        self.opcode == 1
219    }
220
221    /// Check if this is a binary frame.
222    pub fn is_binary(&self) -> bool {
223        self.opcode == 2
224    }
225}
226
227/// Type alias for the WebSocket event handler function.
228pub type WebSocketEventHandler = Box<
229    dyn Fn(WebSocket) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync,
230>;
231
232/// Manager for WebSocket events on a page.
233pub struct WebSocketManager {
234    /// CDP connection.
235    connection: Arc<CdpConnection>,
236    /// Session ID.
237    session_id: String,
238    /// Active WebSocket connections indexed by request ID.
239    websockets: Arc<RwLock<HashMap<String, WebSocket>>>,
240    /// WebSocket created event handler.
241    handler: Arc<RwLock<Option<WebSocketEventHandler>>>,
242    /// Whether the manager is listening for events.
243    is_listening: AtomicBool,
244}
245
246impl WebSocketManager {
247    /// Create a new WebSocket manager for a page.
248    pub fn new(connection: Arc<CdpConnection>, session_id: String) -> Self {
249        Self {
250            connection,
251            session_id,
252            websockets: Arc::new(RwLock::new(HashMap::new())),
253            handler: Arc::new(RwLock::new(None)),
254            is_listening: AtomicBool::new(false),
255        }
256    }
257
258    /// Set a handler for WebSocket created events.
259    pub async fn set_handler<F, Fut>(&self, handler: F)
260    where
261        F: Fn(WebSocket) -> Fut + Send + Sync + 'static,
262        Fut: Future<Output = ()> + Send + 'static,
263    {
264        let boxed_handler: WebSocketEventHandler = Box::new(move |ws| {
265            Box::pin(handler(ws))
266        });
267        let mut h = self.handler.write().await;
268        *h = Some(boxed_handler);
269
270        // Start listening for events if not already
271        self.start_listening().await;
272    }
273
274    /// Remove the WebSocket handler.
275    pub async fn remove_handler(&self) {
276        let mut h = self.handler.write().await;
277        *h = None;
278    }
279
280    /// Start listening for WebSocket CDP events.
281    async fn start_listening(&self) {
282        if self.is_listening.swap(true, Ordering::SeqCst) {
283            // Already listening
284            return;
285        }
286
287        let mut events = self.connection.subscribe_events();
288        let session_id = self.session_id.clone();
289        let websockets = self.websockets.clone();
290        let handler = self.handler.clone();
291
292        tokio::spawn(async move {
293            debug!("WebSocket manager started listening for events");
294            
295            while let Ok(event) = events.recv().await {
296                // Filter events for this session
297                if event.session_id.as_deref() != Some(&session_id) {
298                    continue;
299                }
300
301                match event.method.as_str() {
302                    "Network.webSocketCreated" => {
303                        if let Some(params) = &event.params {
304                            if let Ok(created) = serde_json::from_value::<WebSocketCreatedEvent>(params.clone()) {
305                                trace!("WebSocket created: {} -> {}", created.request_id, created.url);
306                                
307                                let ws = WebSocket::new(created.request_id.clone(), created.url);
308                                
309                                // Store the WebSocket
310                                {
311                                    let mut sockets = websockets.write().await;
312                                    sockets.insert(created.request_id, ws.clone());
313                                }
314                                
315                                // Call the handler
316                                let h = handler.read().await;
317                                if let Some(ref handler_fn) = *h {
318                                    handler_fn(ws).await;
319                                }
320                            }
321                        }
322                    }
323                    "Network.webSocketClosed" => {
324                        if let Some(params) = &event.params {
325                            if let Ok(closed) = serde_json::from_value::<WebSocketClosedEvent>(params.clone()) {
326                                trace!("WebSocket closed: {}", closed.request_id);
327                                
328                                let sockets = websockets.read().await;
329                                if let Some(ws) = sockets.get(&closed.request_id) {
330                                    ws.mark_closed();
331                                }
332                            }
333                        }
334                    }
335                    "Network.webSocketFrameSent" => {
336                        if let Some(params) = &event.params {
337                            if let Ok(frame_event) = serde_json::from_value::<WebSocketFrameSentEvent>(params.clone()) {
338                                trace!("WebSocket frame sent: {}", frame_event.request_id);
339                                
340                                let sockets = websockets.read().await;
341                                if let Some(ws) = sockets.get(&frame_event.request_id) {
342                                    let frame = WebSocketFrame::from_cdp(&frame_event.response);
343                                    ws.emit_frame_sent(frame);
344                                }
345                            }
346                        }
347                    }
348                    "Network.webSocketFrameReceived" => {
349                        if let Some(params) = &event.params {
350                            if let Ok(frame_event) = serde_json::from_value::<WebSocketFrameReceivedEvent>(params.clone()) {
351                                trace!("WebSocket frame received: {}", frame_event.request_id);
352                                
353                                let sockets = websockets.read().await;
354                                if let Some(ws) = sockets.get(&frame_event.request_id) {
355                                    let frame = WebSocketFrame::from_cdp(&frame_event.response);
356                                    ws.emit_frame_received(frame);
357                                }
358                            }
359                        }
360                    }
361                    _ => {}
362                }
363            }
364            
365            debug!("WebSocket manager stopped listening");
366        });
367    }
368
369    /// Get a WebSocket by request ID.
370    pub async fn get(&self, request_id: &str) -> Option<WebSocket> {
371        let sockets = self.websockets.read().await;
372        sockets.get(request_id).cloned()
373    }
374
375    /// Get all active `WebSockets`.
376    pub async fn all(&self) -> Vec<WebSocket> {
377        let sockets = self.websockets.read().await;
378        sockets.values().cloned().collect()
379    }
380}
381
382#[cfg(test)]
383mod tests;