viewpoint_core/network/websocket/
mod.rs1use std::collections::HashMap;
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::sync::Arc;
13
14use tokio::sync::{broadcast, RwLock};
15use tracing::{debug, trace};
16use viewpoint_cdp::protocol::{
17 WebSocketClosedEvent, WebSocketCreatedEvent, WebSocketFrameReceivedEvent,
18 WebSocketFrameSentEvent, WebSocketFrame as CdpWebSocketFrame,
19};
20use viewpoint_cdp::CdpConnection;
21
22#[derive(Clone)]
27pub struct WebSocket {
28 request_id: String,
30 url: String,
32 is_closed: Arc<AtomicBool>,
34 frame_sent_tx: broadcast::Sender<WebSocketFrame>,
36 frame_received_tx: broadcast::Sender<WebSocketFrame>,
38 close_tx: broadcast::Sender<()>,
40}
41
42impl std::fmt::Debug for WebSocket {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 f.debug_struct("WebSocket")
45 .field("request_id", &self.request_id)
46 .field("url", &self.url)
47 .field("is_closed", &self.is_closed.load(Ordering::SeqCst))
48 .finish()
49 }
50}
51
52impl WebSocket {
53 pub(crate) fn new(request_id: String, url: String) -> Self {
55 let (frame_sent_tx, _) = broadcast::channel(256);
56 let (frame_received_tx, _) = broadcast::channel(256);
57 let (close_tx, _) = broadcast::channel(16);
58
59 Self {
60 request_id,
61 url,
62 is_closed: Arc::new(AtomicBool::new(false)),
63 frame_sent_tx,
64 frame_received_tx,
65 close_tx,
66 }
67 }
68
69 pub fn url(&self) -> &str {
71 &self.url
72 }
73
74 pub fn is_closed(&self) -> bool {
76 self.is_closed.load(Ordering::SeqCst)
77 }
78
79 pub fn request_id(&self) -> &str {
81 &self.request_id
82 }
83
84 pub async fn on_framesent<F, Fut>(&self, handler: F)
100 where
101 F: Fn(WebSocketFrame) -> Fut + Send + Sync + 'static,
102 Fut: Future<Output = ()> + Send + 'static,
103 {
104 let mut rx = self.frame_sent_tx.subscribe();
105 tokio::spawn(async move {
106 while let Ok(frame) = rx.recv().await {
107 handler(frame).await;
108 }
109 });
110 }
111
112 pub async fn on_framereceived<F, Fut>(&self, handler: F)
128 where
129 F: Fn(WebSocketFrame) -> Fut + Send + Sync + 'static,
130 Fut: Future<Output = ()> + Send + 'static,
131 {
132 let mut rx = self.frame_received_tx.subscribe();
133 tokio::spawn(async move {
134 while let Ok(frame) = rx.recv().await {
135 handler(frame).await;
136 }
137 });
138 }
139
140 pub async fn on_close<F, Fut>(&self, handler: F)
156 where
157 F: Fn() -> Fut + Send + Sync + 'static,
158 Fut: Future<Output = ()> + Send + 'static,
159 {
160 let mut rx = self.close_tx.subscribe();
161 tokio::spawn(async move {
162 if rx.recv().await.is_ok() {
163 handler().await;
164 }
165 });
166 }
167
168 pub(crate) fn emit_frame_sent(&self, frame: WebSocketFrame) {
170 let _ = self.frame_sent_tx.send(frame);
171 }
172
173 pub(crate) fn emit_frame_received(&self, frame: WebSocketFrame) {
175 let _ = self.frame_received_tx.send(frame);
176 }
177
178 pub(crate) fn mark_closed(&self) {
180 self.is_closed.store(true, Ordering::SeqCst);
181 let _ = self.close_tx.send(());
182 }
183}
184
/// A single WebSocket frame: its opcode plus the payload as a string.
#[derive(Debug, Clone)]
pub struct WebSocketFrame {
    /// Frame opcode; `1` is treated as text and `2` as binary by the accessors.
    opcode: u8,
    /// Payload exactly as it was supplied when the frame was built.
    payload_data: String,
}
193
194impl WebSocketFrame {
195 pub(crate) fn new(opcode: u8, payload_data: String) -> Self {
197 Self {
198 opcode,
199 payload_data,
200 }
201 }
202
203 pub(crate) fn from_cdp(cdp_frame: &CdpWebSocketFrame) -> Self {
205 Self {
206 opcode: cdp_frame.opcode as u8,
207 payload_data: cdp_frame.payload_data.clone(),
208 }
209 }
210
211 pub fn opcode(&self) -> u8 {
220 self.opcode
221 }
222
223 pub fn payload(&self) -> &str {
225 &self.payload_data
226 }
227
228 pub fn is_text(&self) -> bool {
230 self.opcode == 1
231 }
232
233 pub fn is_binary(&self) -> bool {
235 self.opcode == 2
236 }
237}
238
239pub type WebSocketEventHandler = Box<
241 dyn Fn(WebSocket) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync,
242>;
243
244pub struct WebSocketManager {
246 connection: Arc<CdpConnection>,
248 session_id: String,
250 websockets: Arc<RwLock<HashMap<String, WebSocket>>>,
252 handler: Arc<RwLock<Option<WebSocketEventHandler>>>,
254 is_listening: AtomicBool,
256}
257
258impl WebSocketManager {
259 pub fn new(connection: Arc<CdpConnection>, session_id: String) -> Self {
261 Self {
262 connection,
263 session_id,
264 websockets: Arc::new(RwLock::new(HashMap::new())),
265 handler: Arc::new(RwLock::new(None)),
266 is_listening: AtomicBool::new(false),
267 }
268 }
269
270 pub async fn set_handler<F, Fut>(&self, handler: F)
272 where
273 F: Fn(WebSocket) -> Fut + Send + Sync + 'static,
274 Fut: Future<Output = ()> + Send + 'static,
275 {
276 let boxed_handler: WebSocketEventHandler = Box::new(move |ws| {
277 Box::pin(handler(ws))
278 });
279 let mut h = self.handler.write().await;
280 *h = Some(boxed_handler);
281
282 self.start_listening().await;
284 }
285
286 pub async fn remove_handler(&self) {
288 let mut h = self.handler.write().await;
289 *h = None;
290 }
291
292 async fn start_listening(&self) {
294 if self.is_listening.swap(true, Ordering::SeqCst) {
295 return;
297 }
298
299 let mut events = self.connection.subscribe_events();
300 let session_id = self.session_id.clone();
301 let websockets = self.websockets.clone();
302 let handler = self.handler.clone();
303
304 tokio::spawn(async move {
305 debug!("WebSocket manager started listening for events");
306
307 while let Ok(event) = events.recv().await {
308 if event.session_id.as_deref() != Some(&session_id) {
310 continue;
311 }
312
313 match event.method.as_str() {
314 "Network.webSocketCreated" => {
315 if let Some(params) = &event.params {
316 if let Ok(created) = serde_json::from_value::<WebSocketCreatedEvent>(params.clone()) {
317 trace!("WebSocket created: {} -> {}", created.request_id, created.url);
318
319 let ws = WebSocket::new(created.request_id.clone(), created.url);
320
321 {
323 let mut sockets = websockets.write().await;
324 sockets.insert(created.request_id, ws.clone());
325 }
326
327 let h = handler.read().await;
329 if let Some(ref handler_fn) = *h {
330 handler_fn(ws).await;
331 }
332 }
333 }
334 }
335 "Network.webSocketClosed" => {
336 if let Some(params) = &event.params {
337 if let Ok(closed) = serde_json::from_value::<WebSocketClosedEvent>(params.clone()) {
338 trace!("WebSocket closed: {}", closed.request_id);
339
340 let sockets = websockets.read().await;
341 if let Some(ws) = sockets.get(&closed.request_id) {
342 ws.mark_closed();
343 }
344 }
345 }
346 }
347 "Network.webSocketFrameSent" => {
348 if let Some(params) = &event.params {
349 if let Ok(frame_event) = serde_json::from_value::<WebSocketFrameSentEvent>(params.clone()) {
350 trace!("WebSocket frame sent: {}", frame_event.request_id);
351
352 let sockets = websockets.read().await;
353 if let Some(ws) = sockets.get(&frame_event.request_id) {
354 let frame = WebSocketFrame::from_cdp(&frame_event.response);
355 ws.emit_frame_sent(frame);
356 }
357 }
358 }
359 }
360 "Network.webSocketFrameReceived" => {
361 if let Some(params) = &event.params {
362 if let Ok(frame_event) = serde_json::from_value::<WebSocketFrameReceivedEvent>(params.clone()) {
363 trace!("WebSocket frame received: {}", frame_event.request_id);
364
365 let sockets = websockets.read().await;
366 if let Some(ws) = sockets.get(&frame_event.request_id) {
367 let frame = WebSocketFrame::from_cdp(&frame_event.response);
368 ws.emit_frame_received(frame);
369 }
370 }
371 }
372 }
373 _ => {}
374 }
375 }
376
377 debug!("WebSocket manager stopped listening");
378 });
379 }
380
381 pub async fn get(&self, request_id: &str) -> Option<WebSocket> {
383 let sockets = self.websockets.read().await;
384 sockets.get(request_id).cloned()
385 }
386
387 pub async fn all(&self) -> Vec<WebSocket> {
389 let sockets = self.websockets.read().await;
390 sockets.values().cloned().collect()
391 }
392}
393
394#[cfg(test)]
395mod tests;