viewpoint_core/network/websocket/
mod.rs

1//! WebSocket monitoring.
2//!
3//! This module provides functionality for monitoring WebSocket connections,
4//! including frame events for sent and received messages.
5
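// NOTE(review): this comment refers to allowing dead code for the websocket monitoring
// scaffolding (spec: network-events), but no `allow(dead_code)` attribute follows it —
// restore the attribute or drop this comment.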
7
8use std::collections::HashMap;
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12use std::sync::atomic::{AtomicBool, Ordering};
13
14use tokio::sync::{RwLock, broadcast};
15use tracing::{debug, trace};
16use viewpoint_cdp::CdpConnection;
17use viewpoint_cdp::protocol::{
18    WebSocketClosedEvent, WebSocketCreatedEvent, WebSocketFrame as CdpWebSocketFrame,
19    WebSocketFrameReceivedEvent, WebSocketFrameSentEvent,
20};
21
/// A WebSocket connection being monitored.
///
/// This struct represents an active WebSocket connection and provides
/// methods to register handlers for frame events.
///
/// Clones share the closed flag (an `Arc<AtomicBool>`) and the underlying
/// broadcast channels, so an event emitted through one clone reaches handlers
/// registered on any other clone. The `request_id`/`url` strings are copied.
#[derive(Clone)]
pub struct WebSocket {
    /// The request ID identifying this WebSocket.
    request_id: String,
    /// The WebSocket URL.
    url: String,
    /// Whether the WebSocket is closed (shared between clones).
    is_closed: Arc<AtomicBool>,
    /// Frame sent event broadcaster.
    frame_sent_tx: broadcast::Sender<WebSocketFrame>,
    /// Frame received event broadcaster.
    frame_received_tx: broadcast::Sender<WebSocketFrame>,
    /// Close event broadcaster.
    close_tx: broadcast::Sender<()>,
}
41
42impl std::fmt::Debug for WebSocket {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        f.debug_struct("WebSocket")
45            .field("request_id", &self.request_id)
46            .field("url", &self.url)
47            .field("is_closed", &self.is_closed.load(Ordering::SeqCst))
48            .finish()
49    }
50}
51
52impl WebSocket {
53    /// Create a new WebSocket instance.
54    pub(crate) fn new(request_id: String, url: String) -> Self {
55        let (frame_sent_tx, _) = broadcast::channel(256);
56        let (frame_received_tx, _) = broadcast::channel(256);
57        let (close_tx, _) = broadcast::channel(16);
58
59        Self {
60            request_id,
61            url,
62            is_closed: Arc::new(AtomicBool::new(false)),
63            frame_sent_tx,
64            frame_received_tx,
65            close_tx,
66        }
67    }
68
69    /// Get the WebSocket URL.
70    pub fn url(&self) -> &str {
71        &self.url
72    }
73
74    /// Check if the WebSocket is closed.
75    pub fn is_closed(&self) -> bool {
76        self.is_closed.load(Ordering::SeqCst)
77    }
78
79    /// Get the request ID for this WebSocket.
80    pub fn request_id(&self) -> &str {
81        &self.request_id
82    }
83
84    /// Register a handler for frame sent events.
85    ///
86    /// The handler will be called whenever a frame is sent over this WebSocket.
87    ///
88    /// # Example
89    ///
90    /// ```no_run
91    /// use viewpoint_core::WebSocket;
92    ///
93    /// # async fn example(websocket: WebSocket) -> Result<(), viewpoint_core::CoreError> {
94    /// websocket.on_framesent(|frame| async move {
95    ///     println!("Sent: {:?}", frame.payload());
96    /// }).await;
97    /// # Ok(())
98    /// # }
99    pub async fn on_framesent<F, Fut>(&self, handler: F)
100    where
101        F: Fn(WebSocketFrame) -> Fut + Send + Sync + 'static,
102        Fut: Future<Output = ()> + Send + 'static,
103    {
104        let mut rx = self.frame_sent_tx.subscribe();
105        tokio::spawn(async move {
106            while let Ok(frame) = rx.recv().await {
107                handler(frame).await;
108            }
109        });
110    }
111
112    /// Register a handler for frame received events.
113    ///
114    /// The handler will be called whenever a frame is received on this WebSocket.
115    ///
116    /// # Example
117    ///
118    /// ```no_run
119    /// use viewpoint_core::WebSocket;
120    ///
121    /// # async fn example(websocket: WebSocket) -> Result<(), viewpoint_core::CoreError> {
122    /// websocket.on_framereceived(|frame| async move {
123    ///     println!("Received: {:?}", frame.payload());
124    /// }).await;
125    /// # Ok(())
126    /// # }
127    pub async fn on_framereceived<F, Fut>(&self, handler: F)
128    where
129        F: Fn(WebSocketFrame) -> Fut + Send + Sync + 'static,
130        Fut: Future<Output = ()> + Send + 'static,
131    {
132        let mut rx = self.frame_received_tx.subscribe();
133        tokio::spawn(async move {
134            while let Ok(frame) = rx.recv().await {
135                handler(frame).await;
136            }
137        });
138    }
139
140    /// Register a handler for WebSocket close events.
141    ///
142    /// The handler will be called when this WebSocket connection is closed.
143    ///
144    /// # Example
145    ///
146    /// ```no_run
147    /// use viewpoint_core::WebSocket;
148    ///
149    /// # async fn example(websocket: WebSocket) -> Result<(), viewpoint_core::CoreError> {
150    /// websocket.on_close(|| async {
151    ///     println!("WebSocket closed");
152    /// }).await;
153    /// # Ok(())
154    /// # }
155    pub async fn on_close<F, Fut>(&self, handler: F)
156    where
157        F: Fn() -> Fut + Send + Sync + 'static,
158        Fut: Future<Output = ()> + Send + 'static,
159    {
160        let mut rx = self.close_tx.subscribe();
161        tokio::spawn(async move {
162            if rx.recv().await.is_ok() {
163                handler().await;
164            }
165        });
166    }
167
168    /// Emit a frame sent event (internal use).
169    pub(crate) fn emit_frame_sent(&self, frame: WebSocketFrame) {
170        let _ = self.frame_sent_tx.send(frame);
171    }
172
173    /// Emit a frame received event (internal use).
174    pub(crate) fn emit_frame_received(&self, frame: WebSocketFrame) {
175        let _ = self.frame_received_tx.send(frame);
176    }
177
178    /// Mark the WebSocket as closed and emit close event (internal use).
179    pub(crate) fn mark_closed(&self) {
180        self.is_closed.store(true, Ordering::SeqCst);
181        let _ = self.close_tx.send(());
182    }
183}
184
/// A WebSocket message frame.
///
/// Frames are created by the crate (directly or from a CDP frame) and delivered
/// to handlers registered on a [`WebSocket`].
#[derive(Debug, Clone)]
pub struct WebSocketFrame {
    /// The frame opcode (1 for text, 2 for binary).
    opcode: u8,
    /// The frame payload data, stored as the string it was created with.
    // NOTE(review): CDP may report binary payloads base64-encoded — confirm before decoding.
    payload_data: String,
}
193
194impl WebSocketFrame {
195    /// Create a new WebSocket frame.
196    pub(crate) fn new(opcode: u8, payload_data: String) -> Self {
197        Self {
198            opcode,
199            payload_data,
200        }
201    }
202
203    /// Create a WebSocket frame from CDP frame data.
204    pub(crate) fn from_cdp(cdp_frame: &CdpWebSocketFrame) -> Self {
205        Self {
206            opcode: cdp_frame.opcode as u8,
207            payload_data: cdp_frame.payload_data.clone(),
208        }
209    }
210
211    /// Get the frame opcode.
212    ///
213    /// Common opcodes:
214    /// - 1: Text frame
215    /// - 2: Binary frame
216    /// - 8: Close frame
217    /// - 9: Ping frame
218    /// - 10: Pong frame
219    pub fn opcode(&self) -> u8 {
220        self.opcode
221    }
222
223    /// Get the frame payload data.
224    pub fn payload(&self) -> &str {
225        &self.payload_data
226    }
227
228    /// Check if this is a text frame.
229    pub fn is_text(&self) -> bool {
230        self.opcode == 1
231    }
232
233    /// Check if this is a binary frame.
234    pub fn is_binary(&self) -> bool {
235        self.opcode == 2
236    }
237}
238
239/// Type alias for the WebSocket event handler function.
240pub type WebSocketEventHandler =
241    Box<dyn Fn(WebSocket) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync>;
242
/// Manager for WebSocket events on a page.
///
/// Tracks WebSocket connections reported by CDP `Network.webSocket*` events for a
/// single session. It can invoke an optional user handler whenever a new socket
/// is created.
pub struct WebSocketManager {
    /// CDP connection.
    connection: Arc<CdpConnection>,
    /// CDP session ID; events tagged with any other session are ignored.
    session_id: String,
    /// Tracked WebSocket connections indexed by CDP request ID.
    websockets: Arc<RwLock<HashMap<String, WebSocket>>>,
    /// WebSocket created event handler (`None` when no handler is registered).
    handler: Arc<RwLock<Option<WebSocketEventHandler>>>,
    /// Whether the background event listener has been started.
    is_listening: AtomicBool,
}
256
257impl WebSocketManager {
258    /// Create a new WebSocket manager for a page.
259    pub fn new(connection: Arc<CdpConnection>, session_id: String) -> Self {
260        Self {
261            connection,
262            session_id,
263            websockets: Arc::new(RwLock::new(HashMap::new())),
264            handler: Arc::new(RwLock::new(None)),
265            is_listening: AtomicBool::new(false),
266        }
267    }
268
269    /// Set a handler for WebSocket created events.
270    pub async fn set_handler<F, Fut>(&self, handler: F)
271    where
272        F: Fn(WebSocket) -> Fut + Send + Sync + 'static,
273        Fut: Future<Output = ()> + Send + 'static,
274    {
275        let boxed_handler: WebSocketEventHandler = Box::new(move |ws| Box::pin(handler(ws)));
276        let mut h = self.handler.write().await;
277        *h = Some(boxed_handler);
278
279        // Start listening for events if not already
280        self.start_listening().await;
281    }
282
283    /// Remove the WebSocket handler.
284    pub async fn remove_handler(&self) {
285        let mut h = self.handler.write().await;
286        *h = None;
287    }
288
289    /// Start listening for WebSocket CDP events.
290    async fn start_listening(&self) {
291        if self.is_listening.swap(true, Ordering::SeqCst) {
292            // Already listening
293            return;
294        }
295
296        let mut events = self.connection.subscribe_events();
297        let session_id = self.session_id.clone();
298        let websockets = self.websockets.clone();
299        let handler = self.handler.clone();
300
301        tokio::spawn(async move {
302            debug!("WebSocket manager started listening for events");
303
304            while let Ok(event) = events.recv().await {
305                // Filter events for this session
306                if event.session_id.as_deref() != Some(&session_id) {
307                    continue;
308                }
309
310                match event.method.as_str() {
311                    "Network.webSocketCreated" => {
312                        if let Some(params) = &event.params {
313                            if let Ok(created) =
314                                serde_json::from_value::<WebSocketCreatedEvent>(params.clone())
315                            {
316                                trace!(
317                                    "WebSocket created: {} -> {}",
318                                    created.request_id, created.url
319                                );
320
321                                let ws = WebSocket::new(created.request_id.clone(), created.url);
322
323                                // Store the WebSocket
324                                {
325                                    let mut sockets = websockets.write().await;
326                                    sockets.insert(created.request_id, ws.clone());
327                                }
328
329                                // Call the handler
330                                let h = handler.read().await;
331                                if let Some(ref handler_fn) = *h {
332                                    handler_fn(ws).await;
333                                }
334                            }
335                        }
336                    }
337                    "Network.webSocketClosed" => {
338                        if let Some(params) = &event.params {
339                            if let Ok(closed) =
340                                serde_json::from_value::<WebSocketClosedEvent>(params.clone())
341                            {
342                                trace!("WebSocket closed: {}", closed.request_id);
343
344                                let sockets = websockets.read().await;
345                                if let Some(ws) = sockets.get(&closed.request_id) {
346                                    ws.mark_closed();
347                                }
348                            }
349                        }
350                    }
351                    "Network.webSocketFrameSent" => {
352                        if let Some(params) = &event.params {
353                            if let Ok(frame_event) =
354                                serde_json::from_value::<WebSocketFrameSentEvent>(params.clone())
355                            {
356                                trace!("WebSocket frame sent: {}", frame_event.request_id);
357
358                                let sockets = websockets.read().await;
359                                if let Some(ws) = sockets.get(&frame_event.request_id) {
360                                    let frame = WebSocketFrame::from_cdp(&frame_event.response);
361                                    ws.emit_frame_sent(frame);
362                                }
363                            }
364                        }
365                    }
366                    "Network.webSocketFrameReceived" => {
367                        if let Some(params) = &event.params {
368                            if let Ok(frame_event) = serde_json::from_value::<
369                                WebSocketFrameReceivedEvent,
370                            >(params.clone())
371                            {
372                                trace!("WebSocket frame received: {}", frame_event.request_id);
373
374                                let sockets = websockets.read().await;
375                                if let Some(ws) = sockets.get(&frame_event.request_id) {
376                                    let frame = WebSocketFrame::from_cdp(&frame_event.response);
377                                    ws.emit_frame_received(frame);
378                                }
379                            }
380                        }
381                    }
382                    _ => {}
383                }
384            }
385
386            debug!("WebSocket manager stopped listening");
387        });
388    }
389
390    /// Get a WebSocket by request ID.
391    pub async fn get(&self, request_id: &str) -> Option<WebSocket> {
392        let sockets = self.websockets.read().await;
393        sockets.get(request_id).cloned()
394    }
395
396    /// Get all active `WebSockets`.
397    pub async fn all(&self) -> Vec<WebSocket> {
398        let sockets = self.websockets.read().await;
399        sockets.values().cloned().collect()
400    }
401}
402
// Unit tests for this module live in the separate `tests` submodule file.
#[cfg(test)]
mod tests;