viewpoint_core/network/websocket/
mod.rs1use std::collections::HashMap;
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12use std::sync::atomic::{AtomicBool, Ordering};
13
14use tokio::sync::{RwLock, broadcast};
15use tracing::{debug, trace};
16use viewpoint_cdp::CdpConnection;
17use viewpoint_cdp::protocol::{
18 WebSocketClosedEvent, WebSocketCreatedEvent, WebSocketFrame as CdpWebSocketFrame,
19 WebSocketFrameReceivedEvent, WebSocketFrameSentEvent,
20};
21
22#[derive(Clone)]
27pub struct WebSocket {
28 request_id: String,
30 url: String,
32 is_closed: Arc<AtomicBool>,
34 frame_sent_tx: broadcast::Sender<WebSocketFrame>,
36 frame_received_tx: broadcast::Sender<WebSocketFrame>,
38 close_tx: broadcast::Sender<()>,
40}
41
42impl std::fmt::Debug for WebSocket {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 f.debug_struct("WebSocket")
45 .field("request_id", &self.request_id)
46 .field("url", &self.url)
47 .field("is_closed", &self.is_closed.load(Ordering::SeqCst))
48 .finish()
49 }
50}
51
52impl WebSocket {
53 pub(crate) fn new(request_id: String, url: String) -> Self {
55 let (frame_sent_tx, _) = broadcast::channel(256);
56 let (frame_received_tx, _) = broadcast::channel(256);
57 let (close_tx, _) = broadcast::channel(16);
58
59 Self {
60 request_id,
61 url,
62 is_closed: Arc::new(AtomicBool::new(false)),
63 frame_sent_tx,
64 frame_received_tx,
65 close_tx,
66 }
67 }
68
69 pub fn url(&self) -> &str {
71 &self.url
72 }
73
74 pub fn is_closed(&self) -> bool {
76 self.is_closed.load(Ordering::SeqCst)
77 }
78
79 pub fn request_id(&self) -> &str {
81 &self.request_id
82 }
83
84 pub async fn on_framesent<F, Fut>(&self, handler: F)
100 where
101 F: Fn(WebSocketFrame) -> Fut + Send + Sync + 'static,
102 Fut: Future<Output = ()> + Send + 'static,
103 {
104 let mut rx = self.frame_sent_tx.subscribe();
105 tokio::spawn(async move {
106 while let Ok(frame) = rx.recv().await {
107 handler(frame).await;
108 }
109 });
110 }
111
112 pub async fn on_framereceived<F, Fut>(&self, handler: F)
128 where
129 F: Fn(WebSocketFrame) -> Fut + Send + Sync + 'static,
130 Fut: Future<Output = ()> + Send + 'static,
131 {
132 let mut rx = self.frame_received_tx.subscribe();
133 tokio::spawn(async move {
134 while let Ok(frame) = rx.recv().await {
135 handler(frame).await;
136 }
137 });
138 }
139
140 pub async fn on_close<F, Fut>(&self, handler: F)
156 where
157 F: Fn() -> Fut + Send + Sync + 'static,
158 Fut: Future<Output = ()> + Send + 'static,
159 {
160 let mut rx = self.close_tx.subscribe();
161 tokio::spawn(async move {
162 if rx.recv().await.is_ok() {
163 handler().await;
164 }
165 });
166 }
167
168 pub(crate) fn emit_frame_sent(&self, frame: WebSocketFrame) {
170 let _ = self.frame_sent_tx.send(frame);
171 }
172
173 pub(crate) fn emit_frame_received(&self, frame: WebSocketFrame) {
175 let _ = self.frame_received_tx.send(frame);
176 }
177
178 pub(crate) fn mark_closed(&self) {
180 self.is_closed.store(true, Ordering::SeqCst);
181 let _ = self.close_tx.send(());
182 }
183}
184
/// A single WebSocket frame as reported by CDP.
#[derive(Debug, Clone)]
pub struct WebSocketFrame {
    /// Frame opcode (RFC 6455 numbering: 1 = text, 2 = binary).
    opcode: u8,
    /// Frame payload exactly as delivered by CDP.
    payload_data: String,
}
193
194impl WebSocketFrame {
195 pub(crate) fn new(opcode: u8, payload_data: String) -> Self {
197 Self {
198 opcode,
199 payload_data,
200 }
201 }
202
203 pub(crate) fn from_cdp(cdp_frame: &CdpWebSocketFrame) -> Self {
205 Self {
206 opcode: cdp_frame.opcode as u8,
207 payload_data: cdp_frame.payload_data.clone(),
208 }
209 }
210
211 pub fn opcode(&self) -> u8 {
220 self.opcode
221 }
222
223 pub fn payload(&self) -> &str {
225 &self.payload_data
226 }
227
228 pub fn is_text(&self) -> bool {
230 self.opcode == 1
231 }
232
233 pub fn is_binary(&self) -> bool {
235 self.opcode == 2
236 }
237}
238
239pub type WebSocketEventHandler =
241 Box<dyn Fn(WebSocket) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync>;
242
243pub struct WebSocketManager {
245 connection: Arc<CdpConnection>,
247 session_id: String,
249 websockets: Arc<RwLock<HashMap<String, WebSocket>>>,
251 handler: Arc<RwLock<Option<WebSocketEventHandler>>>,
253 is_listening: AtomicBool,
255}
256
257impl WebSocketManager {
258 pub fn new(connection: Arc<CdpConnection>, session_id: String) -> Self {
260 Self {
261 connection,
262 session_id,
263 websockets: Arc::new(RwLock::new(HashMap::new())),
264 handler: Arc::new(RwLock::new(None)),
265 is_listening: AtomicBool::new(false),
266 }
267 }
268
269 pub async fn set_handler<F, Fut>(&self, handler: F)
271 where
272 F: Fn(WebSocket) -> Fut + Send + Sync + 'static,
273 Fut: Future<Output = ()> + Send + 'static,
274 {
275 let boxed_handler: WebSocketEventHandler = Box::new(move |ws| Box::pin(handler(ws)));
276 let mut h = self.handler.write().await;
277 *h = Some(boxed_handler);
278
279 self.start_listening().await;
281 }
282
283 pub async fn remove_handler(&self) {
285 let mut h = self.handler.write().await;
286 *h = None;
287 }
288
289 async fn start_listening(&self) {
291 if self.is_listening.swap(true, Ordering::SeqCst) {
292 return;
294 }
295
296 let mut events = self.connection.subscribe_events();
297 let session_id = self.session_id.clone();
298 let websockets = self.websockets.clone();
299 let handler = self.handler.clone();
300
301 tokio::spawn(async move {
302 debug!("WebSocket manager started listening for events");
303
304 while let Ok(event) = events.recv().await {
305 if event.session_id.as_deref() != Some(&session_id) {
307 continue;
308 }
309
310 match event.method.as_str() {
311 "Network.webSocketCreated" => {
312 if let Some(params) = &event.params {
313 if let Ok(created) =
314 serde_json::from_value::<WebSocketCreatedEvent>(params.clone())
315 {
316 trace!(
317 "WebSocket created: {} -> {}",
318 created.request_id, created.url
319 );
320
321 let ws = WebSocket::new(created.request_id.clone(), created.url);
322
323 {
325 let mut sockets = websockets.write().await;
326 sockets.insert(created.request_id, ws.clone());
327 }
328
329 let h = handler.read().await;
331 if let Some(ref handler_fn) = *h {
332 handler_fn(ws).await;
333 }
334 }
335 }
336 }
337 "Network.webSocketClosed" => {
338 if let Some(params) = &event.params {
339 if let Ok(closed) =
340 serde_json::from_value::<WebSocketClosedEvent>(params.clone())
341 {
342 trace!("WebSocket closed: {}", closed.request_id);
343
344 let sockets = websockets.read().await;
345 if let Some(ws) = sockets.get(&closed.request_id) {
346 ws.mark_closed();
347 }
348 }
349 }
350 }
351 "Network.webSocketFrameSent" => {
352 if let Some(params) = &event.params {
353 if let Ok(frame_event) =
354 serde_json::from_value::<WebSocketFrameSentEvent>(params.clone())
355 {
356 trace!("WebSocket frame sent: {}", frame_event.request_id);
357
358 let sockets = websockets.read().await;
359 if let Some(ws) = sockets.get(&frame_event.request_id) {
360 let frame = WebSocketFrame::from_cdp(&frame_event.response);
361 ws.emit_frame_sent(frame);
362 }
363 }
364 }
365 }
366 "Network.webSocketFrameReceived" => {
367 if let Some(params) = &event.params {
368 if let Ok(frame_event) = serde_json::from_value::<
369 WebSocketFrameReceivedEvent,
370 >(params.clone())
371 {
372 trace!("WebSocket frame received: {}", frame_event.request_id);
373
374 let sockets = websockets.read().await;
375 if let Some(ws) = sockets.get(&frame_event.request_id) {
376 let frame = WebSocketFrame::from_cdp(&frame_event.response);
377 ws.emit_frame_received(frame);
378 }
379 }
380 }
381 }
382 _ => {}
383 }
384 }
385
386 debug!("WebSocket manager stopped listening");
387 });
388 }
389
390 pub async fn get(&self, request_id: &str) -> Option<WebSocket> {
392 let sockets = self.websockets.read().await;
393 sockets.get(request_id).cloned()
394 }
395
396 pub async fn all(&self) -> Vec<WebSocket> {
398 let sockets = self.websockets.read().await;
399 sockets.values().cloned().collect()
400 }
401}
402
// Unit tests for this module are kept in a separate `tests` module file.
#[cfg(test)]
mod tests;