viewpoint_cdp/connection/
mod.rs

1//! CDP WebSocket connection management.
2
3use std::collections::HashMap;
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::sync::Arc;
6use std::time::Duration;
7
8use futures_util::{SinkExt, StreamExt};
9use serde::de::DeserializeOwned;
10use serde::Serialize;
11use serde_json::Value;
12use tokio::sync::{broadcast, mpsc, oneshot, Mutex};
13use tokio::time::timeout;
14use tokio_tungstenite::tungstenite::Message;
15use tracing::{debug, error, info, instrument, trace, warn};
16
17use crate::error::CdpError;
18use crate::transport::{CdpEvent, CdpMessage, CdpRequest, CdpResponse};
19
20/// Default timeout for CDP commands.
21const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
22
23/// Buffer size for the event broadcast channel.
24const EVENT_CHANNEL_SIZE: usize = 256;
25
26/// A CDP connection to a browser.
27#[derive(Debug)]
28pub struct CdpConnection {
29    /// Sender for outgoing messages.
30    tx: mpsc::Sender<CdpRequest>,
31    /// Receiver for incoming events.
32    event_rx: broadcast::Sender<CdpEvent>,
33    /// Pending responses waiting for completion.
34    pending: Arc<Mutex<HashMap<u64, oneshot::Sender<CdpResponse>>>>,
35    /// Atomic counter for message IDs.
36    message_id: AtomicU64,
37    /// Handle to the background read task.
38    _read_handle: tokio::task::JoinHandle<()>,
39    /// Handle to the background write task.
40    _write_handle: tokio::task::JoinHandle<()>,
41}
42
43impl CdpConnection {
44    /// Connect to a CDP WebSocket endpoint.
45    ///
46    /// # Errors
47    ///
48    /// Returns an error if the WebSocket connection fails.
49    #[instrument(level = "info", skip(ws_url), fields(ws_url = %ws_url))]
50    pub async fn connect(ws_url: &str) -> Result<Self, CdpError> {
51        info!("Connecting to CDP WebSocket endpoint");
52        let (ws_stream, response) = tokio_tungstenite::connect_async(ws_url).await?;
53        info!(status = %response.status(), "WebSocket connection established");
54        
55        let (write, read) = ws_stream.split();
56
57        // Channels for internal communication
58        let (tx, rx) = mpsc::channel::<CdpRequest>(64);
59        let (event_tx, _) = broadcast::channel::<CdpEvent>(EVENT_CHANNEL_SIZE);
60        let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<CdpResponse>>>> =
61            Arc::new(Mutex::new(HashMap::new()));
62
63        // Spawn the write task
64        let write_handle = tokio::spawn(Self::write_loop(rx, write));
65        debug!("Spawned CDP write loop");
66
67        // Spawn the read task
68        let read_pending = pending.clone();
69        let read_event_tx = event_tx.clone();
70        let read_handle = tokio::spawn(Self::read_loop(read, read_pending, read_event_tx));
71        debug!("Spawned CDP read loop");
72
73        info!("CDP connection ready");
74        Ok(Self {
75            tx,
76            event_rx: event_tx,
77            pending,
78            message_id: AtomicU64::new(1),
79            _read_handle: read_handle,
80            _write_handle: write_handle,
81        })
82    }
83
84    /// Background task that writes CDP requests to the WebSocket.
85    async fn write_loop<S>(mut rx: mpsc::Receiver<CdpRequest>, mut sink: S)
86    where
87        S: futures_util::Sink<Message, Error = tokio_tungstenite::tungstenite::Error> + Unpin,
88    {
89        debug!("CDP write loop started");
90        while let Some(request) = rx.recv().await {
91            let method = request.method.clone();
92            let id = request.id;
93            
94            let json = match serde_json::to_string(&request) {
95                Ok(j) => j,
96                Err(e) => {
97                    error!(error = %e, method = %method, "Failed to serialize CDP request");
98                    continue;
99                }
100            };
101
102            trace!(id = id, method = %method, json_len = json.len(), "Sending CDP request");
103
104            if sink.send(Message::Text(json.into())).await.is_err() {
105                warn!("WebSocket sink closed, ending write loop");
106                break;
107            }
108            
109            debug!(id = id, method = %method, "CDP request sent");
110        }
111        debug!("CDP write loop ended");
112    }
113
114    /// Background task that reads CDP messages from the WebSocket.
115    async fn read_loop<S>(
116        mut stream: S,
117        pending: Arc<Mutex<HashMap<u64, oneshot::Sender<CdpResponse>>>>,
118        event_tx: broadcast::Sender<CdpEvent>,
119    ) where
120        S: futures_util::Stream<Item = Result<Message, tokio_tungstenite::tungstenite::Error>>
121            + Unpin,
122    {
123        debug!("CDP read loop started");
124        while let Some(msg) = stream.next().await {
125            let msg = match msg {
126                Ok(Message::Text(text)) => text,
127                Ok(Message::Close(frame)) => {
128                    info!(?frame, "WebSocket closed by remote");
129                    break;
130                }
131                Err(e) => {
132                    warn!(error = %e, "WebSocket error, ending read loop");
133                    break;
134                }
135                Ok(_) => continue,
136            };
137
138            trace!(json_len = msg.len(), "Received CDP message");
139
140            // Parse the incoming message
141            let cdp_msg: CdpMessage = match serde_json::from_str(&msg) {
142                Ok(m) => m,
143                Err(e) => {
144                    error!(error = %e, "Failed to parse CDP message");
145                    continue;
146                }
147            };
148
149            match cdp_msg {
150                CdpMessage::Response(resp) => {
151                    let id = resp.id;
152                    let has_error = resp.error.is_some();
153                    debug!(id = id, has_error = has_error, "Received CDP response");
154                    
155                    let mut pending = pending.lock().await;
156                    if let Some(sender) = pending.remove(&id) {
157                        let _ = sender.send(resp);
158                    } else {
159                        warn!(id = id, "Received response for unknown request ID");
160                    }
161                }
162                CdpMessage::Event(ref event) => {
163                    trace!(method = %event.method, session_id = ?event.session_id, "Received CDP event");
164                    // Broadcast to all subscribers; ignore if no receivers.
165                    let _ = event_tx.send(event.clone());
166                }
167            }
168        }
169        debug!("CDP read loop ended");
170    }
171
172    /// Send a CDP command and wait for the response.
173    ///
174    /// # Errors
175    ///
176    /// Returns an error if:
177    /// - The command cannot be sent
178    /// - The response times out
179    /// - The browser returns a protocol error
180    pub async fn send_command<P, R>(
181        &self,
182        method: &str,
183        params: Option<P>,
184        session_id: Option<&str>,
185    ) -> Result<R, CdpError>
186    where
187        P: Serialize,
188        R: DeserializeOwned,
189    {
190        self.send_command_with_timeout(method, params, session_id, DEFAULT_TIMEOUT)
191            .await
192    }
193
194    /// Send a CDP command with a custom timeout.
195    ///
196    /// # Errors
197    ///
198    /// Returns an error if:
199    /// - The command cannot be sent
200    /// - The response times out
201    /// - The browser returns a protocol error
202    #[instrument(level = "debug", skip(self, params), fields(method = %method, session_id = ?session_id))]
203    pub async fn send_command_with_timeout<P, R>(
204        &self,
205        method: &str,
206        params: Option<P>,
207        session_id: Option<&str>,
208        timeout_duration: Duration,
209    ) -> Result<R, CdpError>
210    where
211        P: Serialize,
212        R: DeserializeOwned,
213    {
214        let id = self.message_id.fetch_add(1, Ordering::Relaxed);
215        debug!(id = id, timeout_ms = timeout_duration.as_millis(), "Preparing CDP command");
216
217        let params_value = params
218            .map(|p| serde_json::to_value(p))
219            .transpose()?;
220
221        let request = CdpRequest {
222            id,
223            method: method.to_string(),
224            params: params_value,
225            session_id: session_id.map(ToString::to_string),
226        };
227
228        // Create a oneshot channel for the response
229        let (resp_tx, resp_rx) = oneshot::channel();
230
231        // Register the pending response
232        {
233            let mut pending = self.pending.lock().await;
234            pending.insert(id, resp_tx);
235            trace!(id = id, pending_count = pending.len(), "Registered pending response");
236        }
237
238        // Send the request
239        self.tx
240            .send(request)
241            .await
242            .map_err(|_| CdpError::ConnectionLost)?;
243        
244        trace!(id = id, "Request queued for sending");
245
246        // Wait for the response with timeout
247        let response = timeout(timeout_duration, resp_rx)
248            .await
249            .map_err(|_| {
250                warn!(id = id, method = %method, "CDP command timed out");
251                CdpError::Timeout(timeout_duration)
252            })?
253            .map_err(|_| CdpError::ConnectionLost)?;
254
255        // Check for protocol errors
256        if let Some(ref error) = response.error {
257            warn!(id = id, method = %method, code = error.code, error_msg = %error.message, "CDP protocol error");
258            return Err(CdpError::Protocol {
259                code: error.code,
260                message: error.message.clone(),
261            });
262        }
263
264        debug!(id = id, "CDP command completed successfully");
265
266        // Parse the result
267        let result = response.result.unwrap_or(Value::Null);
268        serde_json::from_value(result).map_err(CdpError::from)
269    }
270
271    /// Subscribe to CDP events.
272    ///
273    /// Returns a receiver that will receive all CDP events from the browser.
274    pub fn subscribe_events(&self) -> broadcast::Receiver<CdpEvent> {
275        debug!("New CDP event subscription created");
276        self.event_rx.subscribe()
277    }
278}
279
280#[cfg(test)]
281mod tests;