viewpoint_core/network/websocket/
mod.rs1use std::collections::HashMap;
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::sync::Arc;
13
14use tokio::sync::{broadcast, RwLock};
15use tracing::{debug, trace};
16use viewpoint_cdp::protocol::{
17 WebSocketClosedEvent, WebSocketCreatedEvent, WebSocketFrameReceivedEvent,
18 WebSocketFrameSentEvent, WebSocketFrame as CdpWebSocketFrame,
19};
20use viewpoint_cdp::CdpConnection;
21
22#[derive(Clone)]
27pub struct WebSocket {
28 request_id: String,
30 url: String,
32 is_closed: Arc<AtomicBool>,
34 frame_sent_tx: broadcast::Sender<WebSocketFrame>,
36 frame_received_tx: broadcast::Sender<WebSocketFrame>,
38 close_tx: broadcast::Sender<()>,
40}
41
42impl std::fmt::Debug for WebSocket {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 f.debug_struct("WebSocket")
45 .field("request_id", &self.request_id)
46 .field("url", &self.url)
47 .field("is_closed", &self.is_closed.load(Ordering::SeqCst))
48 .finish()
49 }
50}
51
52impl WebSocket {
53 pub(crate) fn new(request_id: String, url: String) -> Self {
55 let (frame_sent_tx, _) = broadcast::channel(256);
56 let (frame_received_tx, _) = broadcast::channel(256);
57 let (close_tx, _) = broadcast::channel(16);
58
59 Self {
60 request_id,
61 url,
62 is_closed: Arc::new(AtomicBool::new(false)),
63 frame_sent_tx,
64 frame_received_tx,
65 close_tx,
66 }
67 }
68
69 pub fn url(&self) -> &str {
71 &self.url
72 }
73
74 pub fn is_closed(&self) -> bool {
76 self.is_closed.load(Ordering::SeqCst)
77 }
78
79 pub fn request_id(&self) -> &str {
81 &self.request_id
82 }
83
84 pub async fn on_framesent<F, Fut>(&self, handler: F)
96 where
97 F: Fn(WebSocketFrame) -> Fut + Send + Sync + 'static,
98 Fut: Future<Output = ()> + Send + 'static,
99 {
100 let mut rx = self.frame_sent_tx.subscribe();
101 tokio::spawn(async move {
102 while let Ok(frame) = rx.recv().await {
103 handler(frame).await;
104 }
105 });
106 }
107
108 pub async fn on_framereceived<F, Fut>(&self, handler: F)
120 where
121 F: Fn(WebSocketFrame) -> Fut + Send + Sync + 'static,
122 Fut: Future<Output = ()> + Send + 'static,
123 {
124 let mut rx = self.frame_received_tx.subscribe();
125 tokio::spawn(async move {
126 while let Ok(frame) = rx.recv().await {
127 handler(frame).await;
128 }
129 });
130 }
131
132 pub async fn on_close<F, Fut>(&self, handler: F)
144 where
145 F: Fn() -> Fut + Send + Sync + 'static,
146 Fut: Future<Output = ()> + Send + 'static,
147 {
148 let mut rx = self.close_tx.subscribe();
149 tokio::spawn(async move {
150 if rx.recv().await.is_ok() {
151 handler().await;
152 }
153 });
154 }
155
156 pub(crate) fn emit_frame_sent(&self, frame: WebSocketFrame) {
158 let _ = self.frame_sent_tx.send(frame);
159 }
160
161 pub(crate) fn emit_frame_received(&self, frame: WebSocketFrame) {
163 let _ = self.frame_received_tx.send(frame);
164 }
165
166 pub(crate) fn mark_closed(&self) {
168 self.is_closed.store(true, Ordering::SeqCst);
169 let _ = self.close_tx.send(());
170 }
171}
172
/// A single WebSocket frame as reported by the browser.
#[derive(Clone, Debug)]
pub struct WebSocketFrame {
    /// WebSocket opcode (RFC 6455: 1 = text, 2 = binary).
    opcode: u8,
    /// Frame payload exactly as delivered by CDP.
    payload_data: String,
}
181
182impl WebSocketFrame {
183 pub(crate) fn new(opcode: u8, payload_data: String) -> Self {
185 Self {
186 opcode,
187 payload_data,
188 }
189 }
190
191 pub(crate) fn from_cdp(cdp_frame: &CdpWebSocketFrame) -> Self {
193 Self {
194 opcode: cdp_frame.opcode as u8,
195 payload_data: cdp_frame.payload_data.clone(),
196 }
197 }
198
199 pub fn opcode(&self) -> u8 {
208 self.opcode
209 }
210
211 pub fn payload(&self) -> &str {
213 &self.payload_data
214 }
215
216 pub fn is_text(&self) -> bool {
218 self.opcode == 1
219 }
220
221 pub fn is_binary(&self) -> bool {
223 self.opcode == 2
224 }
225}
226
227pub type WebSocketEventHandler = Box<
229 dyn Fn(WebSocket) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync,
230>;
231
232pub struct WebSocketManager {
234 connection: Arc<CdpConnection>,
236 session_id: String,
238 websockets: Arc<RwLock<HashMap<String, WebSocket>>>,
240 handler: Arc<RwLock<Option<WebSocketEventHandler>>>,
242 is_listening: AtomicBool,
244}
245
246impl WebSocketManager {
247 pub fn new(connection: Arc<CdpConnection>, session_id: String) -> Self {
249 Self {
250 connection,
251 session_id,
252 websockets: Arc::new(RwLock::new(HashMap::new())),
253 handler: Arc::new(RwLock::new(None)),
254 is_listening: AtomicBool::new(false),
255 }
256 }
257
258 pub async fn set_handler<F, Fut>(&self, handler: F)
260 where
261 F: Fn(WebSocket) -> Fut + Send + Sync + 'static,
262 Fut: Future<Output = ()> + Send + 'static,
263 {
264 let boxed_handler: WebSocketEventHandler = Box::new(move |ws| {
265 Box::pin(handler(ws))
266 });
267 let mut h = self.handler.write().await;
268 *h = Some(boxed_handler);
269
270 self.start_listening().await;
272 }
273
274 pub async fn remove_handler(&self) {
276 let mut h = self.handler.write().await;
277 *h = None;
278 }
279
280 async fn start_listening(&self) {
282 if self.is_listening.swap(true, Ordering::SeqCst) {
283 return;
285 }
286
287 let mut events = self.connection.subscribe_events();
288 let session_id = self.session_id.clone();
289 let websockets = self.websockets.clone();
290 let handler = self.handler.clone();
291
292 tokio::spawn(async move {
293 debug!("WebSocket manager started listening for events");
294
295 while let Ok(event) = events.recv().await {
296 if event.session_id.as_deref() != Some(&session_id) {
298 continue;
299 }
300
301 match event.method.as_str() {
302 "Network.webSocketCreated" => {
303 if let Some(params) = &event.params {
304 if let Ok(created) = serde_json::from_value::<WebSocketCreatedEvent>(params.clone()) {
305 trace!("WebSocket created: {} -> {}", created.request_id, created.url);
306
307 let ws = WebSocket::new(created.request_id.clone(), created.url);
308
309 {
311 let mut sockets = websockets.write().await;
312 sockets.insert(created.request_id, ws.clone());
313 }
314
315 let h = handler.read().await;
317 if let Some(ref handler_fn) = *h {
318 handler_fn(ws).await;
319 }
320 }
321 }
322 }
323 "Network.webSocketClosed" => {
324 if let Some(params) = &event.params {
325 if let Ok(closed) = serde_json::from_value::<WebSocketClosedEvent>(params.clone()) {
326 trace!("WebSocket closed: {}", closed.request_id);
327
328 let sockets = websockets.read().await;
329 if let Some(ws) = sockets.get(&closed.request_id) {
330 ws.mark_closed();
331 }
332 }
333 }
334 }
335 "Network.webSocketFrameSent" => {
336 if let Some(params) = &event.params {
337 if let Ok(frame_event) = serde_json::from_value::<WebSocketFrameSentEvent>(params.clone()) {
338 trace!("WebSocket frame sent: {}", frame_event.request_id);
339
340 let sockets = websockets.read().await;
341 if let Some(ws) = sockets.get(&frame_event.request_id) {
342 let frame = WebSocketFrame::from_cdp(&frame_event.response);
343 ws.emit_frame_sent(frame);
344 }
345 }
346 }
347 }
348 "Network.webSocketFrameReceived" => {
349 if let Some(params) = &event.params {
350 if let Ok(frame_event) = serde_json::from_value::<WebSocketFrameReceivedEvent>(params.clone()) {
351 trace!("WebSocket frame received: {}", frame_event.request_id);
352
353 let sockets = websockets.read().await;
354 if let Some(ws) = sockets.get(&frame_event.request_id) {
355 let frame = WebSocketFrame::from_cdp(&frame_event.response);
356 ws.emit_frame_received(frame);
357 }
358 }
359 }
360 }
361 _ => {}
362 }
363 }
364
365 debug!("WebSocket manager stopped listening");
366 });
367 }
368
369 pub async fn get(&self, request_id: &str) -> Option<WebSocket> {
371 let sockets = self.websockets.read().await;
372 sockets.get(request_id).cloned()
373 }
374
375 pub async fn all(&self) -> Vec<WebSocket> {
377 let sockets = self.websockets.read().await;
378 sockets.values().cloned().collect()
379 }
380}
381
// Unit tests for this module live in a separate file (`tests.rs` or
// `tests/mod.rs` next to this module) and are compiled only for test builds.
#[cfg(test)]
mod tests;