Skip to main content

spider_browser/protocol/
cdp_session.rs

1//! Lock-free CDP JSON-RPC session over the Spider WebSocket transport.
2
3use crate::errors::{Result, SpiderError};
4use arc_swap::ArcSwap;
5use dashmap::DashMap;
6use serde_json::{json, Value};
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::sync::Arc;
9use tokio::sync::{mpsc, oneshot};
10use tracing::info;
11
/// Error codes for which a failed navigation is reported as retryable.
///
/// Each entry is matched as a substring of the `errorText` that
/// `Page.navigate` returns, so a prefixed form such as `net::ERR_ABORTED`
/// matches as well.
const RETRYABLE_NAV_ERRORS: &[&str] = &[
    // Navigation was aborted.
    "ERR_ABORTED",
    // Connection-level failures.
    "ERR_CONNECTION_RESET",
    "ERR_CONNECTION_CLOSED",
    "ERR_CONNECTION_REFUSED",
    "ERR_CONNECTION_TIMED_OUT",
    // Timeouts and empty or broken transports.
    "ERR_TIMED_OUT",
    "ERR_EMPTY_RESPONSE",
    "ERR_SOCKET_NOT_CONNECTED",
    "ERR_NETWORK_CHANGED",
    // Request was blocked on the client side.
    "ERR_BLOCKED_BY_CLIENT",
    // TLS negotiation failures.
    "ERR_SSL_PROTOCOL_ERROR",
    "ERR_SSL_VERSION_OR_CIPHER_MISMATCH",
];
26
27type EventHandler = Box<dyn Fn(Value) + Send + Sync>;
28
29/// Lock-free CDP session using DashMap for pending responses and events,
30/// ArcSwap for the session ID, and atomics for the ID counter.
31pub struct CDPSession {
32    next_id: AtomicU64,
33    pending: Arc<DashMap<u64, oneshot::Sender<Value>>>,
34    event_handlers: Arc<DashMap<String, Vec<Arc<dyn Fn(Value) + Send + Sync>>>>,
35    target_session_id: ArcSwap<Option<String>>,
36    timeout_ms: u64,
37    send_tx: mpsc::UnboundedSender<String>,
38}
39
40impl CDPSession {
41    pub fn new(send_tx: mpsc::UnboundedSender<String>, timeout_ms: u64) -> Self {
42        Self {
43            next_id: AtomicU64::new(1),
44            pending: Arc::new(DashMap::new()),
45            event_handlers: Arc::new(DashMap::new()),
46            target_session_id: ArcSwap::from_pointee(None),
47            timeout_ms,
48            send_tx,
49        }
50    }
51
52    /// Process an incoming message. Returns true if it was handled.
53    pub fn handle_message(&self, data: &str) -> bool {
54        let Ok(msg) = serde_json::from_str::<Value>(data) else {
55            return false;
56        };
57
58        // Response
59        if let Some(id) = msg.get("id").and_then(|v| v.as_u64()) {
60            if let Some((_, tx)) = self.pending.remove(&id) {
61                let _ = tx.send(msg);
62                return true;
63            }
64            return false;
65        }
66
67        // Event
68        if let Some(method) = msg.get("method").and_then(|v| v.as_str()) {
69            let params = msg.get("params").cloned().unwrap_or(json!({}));
70            if let Some(list) = self.event_handlers.get(method) {
71                // Clone handlers to release DashMap shard before calling.
72                let handlers = list.clone();
73                drop(list);
74                for h in &handlers {
75                    h(params.clone());
76                }
77            }
78            return true;
79        }
80
81        false
82    }
83
84    /// Send a browser-level CDP command.
85    pub async fn send(&self, method: &str, params: Value) -> Result<Value> {
86        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
87        let cmd = json!({"id": id, "method": method, "params": params});
88
89        let (tx, rx) = oneshot::channel();
90        self.pending.insert(id, tx);
91        self.send_tx
92            .send(cmd.to_string())
93            .map_err(|_| SpiderError::connection("WebSocket is not connected"))?;
94
95        tokio::time::timeout(
96            tokio::time::Duration::from_millis(self.timeout_ms),
97            rx,
98        )
99        .await
100        .map_err(|_| {
101            self.pending.remove(&id);
102            SpiderError::Timeout(format!("CDP command timeout: {method} ({}ms)", self.timeout_ms))
103        })?
104        .map_err(|_| SpiderError::connection("CDP response channel closed"))
105    }
106
107    /// Send a page-scoped CDP command.
108    pub async fn send_to_target(&self, method: &str, params: Value) -> Result<Value> {
109        let session_id = self.target_session_id.load();
110        let session_id = session_id.as_ref().as_ref()
111            .ok_or_else(|| SpiderError::Protocol("No target session — call attach_to_page() first".into()))?;
112
113        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
114        let cmd = json!({
115            "id": id,
116            "method": method,
117            "params": params,
118            "sessionId": session_id,
119        });
120
121        let (tx, rx) = oneshot::channel();
122        self.pending.insert(id, tx);
123        self.send_tx
124            .send(cmd.to_string())
125            .map_err(|_| SpiderError::connection("WebSocket is not connected"))?;
126
127        let resp = tokio::time::timeout(
128            tokio::time::Duration::from_millis(self.timeout_ms),
129            rx,
130        )
131        .await
132        .map_err(|_| {
133            self.pending.remove(&id);
134            SpiderError::Timeout(format!("CDP command timeout: {method} ({}ms)", self.timeout_ms))
135        })?
136        .map_err(|_| SpiderError::connection("CDP response channel closed"))?;
137
138        if let Some(err) = resp.get("error") {
139            let msg = err.get("message").and_then(|v| v.as_str()).unwrap_or("unknown");
140            return Err(SpiderError::Protocol(format!("CDP error: {msg}")));
141        }
142
143        Ok(resp)
144    }
145
146    pub fn on(&self, method: &str, handler: Arc<dyn Fn(Value) + Send + Sync>) {
147        self.event_handlers
148            .entry(method.to_string())
149            .or_default()
150            .push(handler);
151    }
152
153    /// Wait for a CDP event to fire, with timeout.
154    async fn wait_for_event(&self, method: &str, timeout_ms: u64) -> bool {
155        let fired = Arc::new(std::sync::atomic::AtomicBool::new(false));
156        let fired_clone = fired.clone();
157        let notify = Arc::new(tokio::sync::Notify::new());
158        let notify_clone = notify.clone();
159
160        self.on(method, Arc::new(move |_params| {
161            if !fired_clone.swap(true, Ordering::Relaxed) {
162                notify_clone.notify_one();
163            }
164        }));
165
166        let result = tokio::time::timeout(
167            tokio::time::Duration::from_millis(timeout_ms),
168            notify.notified(),
169        )
170        .await;
171
172        // Remove handler (best-effort cleanup)
173        self.event_handlers.remove(method);
174
175        result.is_ok()
176    }
177
178    /// Discover/create a page target, attach, enable domains.
179    pub async fn attach_to_page(&self) -> Result<String> {
180        self.send("Target.setDiscoverTargets", json!({"discover": true})).await?;
181
182        // Always create a fresh page target for session isolation.
183        let cr = self.send("Target.createTarget", json!({"url": "about:blank"})).await?;
184        let target_id = cr.get("result")
185            .and_then(|r| r.get("targetId"))
186            .and_then(|v| v.as_str())
187            .map(|s| s.to_string())
188            .ok_or_else(|| SpiderError::Protocol("Failed to create page target".into()))?;
189
190        let attach_resp = self.send("Target.attachToTarget", json!({"targetId": target_id, "flatten": true})).await?;
191        let sid = attach_resp.get("result")
192            .and_then(|r| r.get("sessionId"))
193            .and_then(|v| v.as_str())
194            .map(|s| s.to_string());
195
196        let session_id = if let Some(s) = sid {
197            s
198        } else {
199            // Wait for Target.attachedToTarget event
200            let fired = Arc::new(std::sync::atomic::AtomicBool::new(false));
201            let result = Arc::new(ArcSwap::from_pointee(String::new()));
202            let fired_clone = fired.clone();
203            let result_clone = result.clone();
204            let notify = Arc::new(tokio::sync::Notify::new());
205            let notify_clone = notify.clone();
206
207            self.on("Target.attachedToTarget", Arc::new(move |params| {
208                if let Some(s) = params.get("sessionId").and_then(|v| v.as_str()) {
209                    result_clone.store(Arc::new(s.to_string()));
210                    if !fired_clone.swap(true, Ordering::Relaxed) {
211                        notify_clone.notify_one();
212                    }
213                }
214            }));
215
216            tokio::time::timeout(
217                tokio::time::Duration::from_secs(5),
218                notify.notified(),
219            )
220            .await
221            .map_err(|_| SpiderError::Timeout("Timeout waiting for Target.attachedToTarget".into()))?;
222
223            self.event_handlers.remove("Target.attachedToTarget");
224            let s = result.load();
225            if s.is_empty() {
226                return Err(SpiderError::Protocol("No sessionId received".into()));
227            }
228            s.as_ref().clone()
229        };
230
231        self.target_session_id.store(Arc::new(Some(session_id.clone())));
232        info!("attached to page target target_id={} session_id={}", target_id, session_id);
233
234        self.send_to_target("Page.enable", json!({})).await?;
235        self.send_to_target("Runtime.enable", json!({})).await?;
236
237        Ok(session_id)
238    }
239
240    // ------------------------------------------------------------------
241    // High-level CDP commands
242    // ------------------------------------------------------------------
243
244    pub async fn navigate(&self, url: &str) -> Result<()> {
245        // Listen for load events BEFORE sending navigate
246        let load_fired = Arc::new(std::sync::atomic::AtomicBool::new(false));
247        let load_notify = Arc::new(tokio::sync::Notify::new());
248        let stop_fired = Arc::new(std::sync::atomic::AtomicBool::new(false));
249        let stop_notify = Arc::new(tokio::sync::Notify::new());
250
251        {
252            let f = load_fired.clone();
253            let n = load_notify.clone();
254            self.on("Page.loadEventFired", Arc::new(move |_| {
255                if !f.swap(true, Ordering::Relaxed) { n.notify_one(); }
256            }));
257        }
258        {
259            let f = stop_fired.clone();
260            let n = stop_notify.clone();
261            self.on("Page.frameStoppedLoading", Arc::new(move |_| {
262                if !f.swap(true, Ordering::Relaxed) { n.notify_one(); }
263            }));
264        }
265
266        let resp = self.send_to_target("Page.navigate", json!({"url": url})).await?;
267
268        if let Some(error_text) = resp.get("result").and_then(|r| r.get("errorText")).and_then(|v| v.as_str()) {
269            self.event_handlers.remove("Page.loadEventFired");
270            self.event_handlers.remove("Page.frameStoppedLoading");
271            if is_retryable_nav_error(error_text) {
272                return Err(SpiderError::Navigation(format!("Navigation failed: {error_text}")));
273            }
274            return Err(SpiderError::Protocol(format!("Navigation failed: {error_text}")));
275        }
276
277        // Wait for load event (8s); if timeout, fall back to frameStoppedLoading (10s)
278        let loaded = tokio::time::timeout(
279            tokio::time::Duration::from_millis(8_000),
280            load_notify.notified(),
281        ).await.is_ok();
282
283        if !loaded {
284            let _ = tokio::time::timeout(
285                tokio::time::Duration::from_millis(10_000),
286                stop_notify.notified(),
287            ).await;
288        }
289
290        self.event_handlers.remove("Page.loadEventFired");
291        self.event_handlers.remove("Page.frameStoppedLoading");
292        Ok(())
293    }
294
295    pub async fn navigate_fast(&self, url: &str) -> Result<()> {
296        let load_fired = Arc::new(std::sync::atomic::AtomicBool::new(false));
297        let load_notify = Arc::new(tokio::sync::Notify::new());
298        {
299            let f = load_fired.clone();
300            let n = load_notify.clone();
301            self.on("Page.loadEventFired", Arc::new(move |_| {
302                if !f.swap(true, Ordering::Relaxed) { n.notify_one(); }
303            }));
304        }
305
306        let resp = self.send_to_target("Page.navigate", json!({"url": url})).await?;
307        if let Some(error_text) = resp.get("result").and_then(|r| r.get("errorText")).and_then(|v| v.as_str()) {
308            self.event_handlers.remove("Page.loadEventFired");
309            if is_retryable_nav_error(error_text) {
310                return Err(SpiderError::Navigation(format!("Navigation failed: {error_text}")));
311            }
312            return Err(SpiderError::Protocol(format!("Navigation failed: {error_text}")));
313        }
314
315        let _ = tokio::time::timeout(
316            tokio::time::Duration::from_millis(5_000),
317            load_notify.notified(),
318        ).await;
319
320        self.event_handlers.remove("Page.loadEventFired");
321        Ok(())
322    }
323
324    pub async fn navigate_dom(&self, url: &str) -> Result<()> {
325        let dom_fired = Arc::new(std::sync::atomic::AtomicBool::new(false));
326        let dom_notify = Arc::new(tokio::sync::Notify::new());
327        {
328            let f = dom_fired.clone();
329            let n = dom_notify.clone();
330            self.on("Page.domContentEventFired", Arc::new(move |_| {
331                if !f.swap(true, Ordering::Relaxed) { n.notify_one(); }
332            }));
333        }
334
335        let resp = self.send_to_target("Page.navigate", json!({"url": url})).await?;
336        if let Some(error_text) = resp.get("result").and_then(|r| r.get("errorText")).and_then(|v| v.as_str()) {
337            self.event_handlers.remove("Page.domContentEventFired");
338            if is_retryable_nav_error(error_text) {
339                return Err(SpiderError::Navigation(format!("Navigation failed: {error_text}")));
340            }
341            return Err(SpiderError::Protocol(format!("Navigation failed: {error_text}")));
342        }
343
344        let _ = tokio::time::timeout(
345            tokio::time::Duration::from_millis(3_000),
346            dom_notify.notified(),
347        ).await;
348
349        self.event_handlers.remove("Page.domContentEventFired");
350        Ok(())
351    }
352
353    pub async fn capture_screenshot(&self) -> Result<String> {
354        let resp = self.send_to_target("Page.captureScreenshot", json!({"format": "png"})).await?;
355        resp.get("result")
356            .and_then(|r| r.get("data"))
357            .and_then(|v| v.as_str())
358            .map(|s| s.to_string())
359            .ok_or_else(|| SpiderError::Protocol("captureScreenshot: missing result.data".into()))
360    }
361
362    pub async fn get_html(&self) -> Result<String> {
363        let val = self.evaluate("document.documentElement.outerHTML").await?;
364        Ok(val.as_str().unwrap_or("").to_string())
365    }
366
367    pub async fn evaluate(&self, expression: &str) -> Result<Value> {
368        let resp = self.send_to_target("Runtime.evaluate", json!({
369            "expression": expression,
370            "returnByValue": true,
371        })).await?;
372        if let Some(err) = resp.get("result").and_then(|r| r.get("exceptionDetails")) {
373            let msg = err.get("text").and_then(|v| v.as_str()).unwrap_or("evaluation error");
374            return Err(SpiderError::Protocol(format!("CDP eval error: {msg}")));
375        }
376        Ok(resp
377            .get("result")
378            .and_then(|r| r.get("result"))
379            .and_then(|r| r.get("value"))
380            .cloned()
381            .unwrap_or(Value::Null))
382    }
383
384    pub async fn click_point(&self, x: f64, y: f64) -> Result<()> {
385        self.dispatch_mouse("mouseMoved", x, y, "none", 0).await?;
386        self.dispatch_mouse("mousePressed", x, y, "left", 1).await?;
387        self.dispatch_mouse("mouseReleased", x, y, "left", 1).await
388    }
389
390    pub async fn right_click_point(&self, x: f64, y: f64) -> Result<()> {
391        self.dispatch_mouse("mouseMoved", x, y, "none", 0).await?;
392        self.dispatch_mouse("mousePressed", x, y, "right", 1).await?;
393        self.dispatch_mouse("mouseReleased", x, y, "right", 1).await
394    }
395
396    pub async fn double_click_point(&self, x: f64, y: f64) -> Result<()> {
397        self.dispatch_mouse("mouseMoved", x, y, "none", 0).await?;
398        self.dispatch_mouse("mousePressed", x, y, "left", 1).await?;
399        self.dispatch_mouse("mouseReleased", x, y, "left", 1).await?;
400        self.dispatch_mouse("mousePressed", x, y, "left", 2).await?;
401        self.dispatch_mouse("mouseReleased", x, y, "left", 2).await
402    }
403
404    pub async fn click_hold_point(&self, x: f64, y: f64, hold_ms: u64) -> Result<()> {
405        self.dispatch_mouse("mouseMoved", x, y, "none", 0).await?;
406        self.dispatch_mouse("mousePressed", x, y, "left", 1).await?;
407        tokio::time::sleep(tokio::time::Duration::from_millis(hold_ms)).await;
408        self.dispatch_mouse("mouseReleased", x, y, "left", 1).await
409    }
410
411    pub async fn hover_point(&self, x: f64, y: f64) -> Result<()> {
412        self.dispatch_mouse("mouseMoved", x, y, "none", 0).await
413    }
414
415    pub async fn drag_point(&self, fx: f64, fy: f64, tx: f64, ty: f64) -> Result<()> {
416        let steps = 10;
417        self.dispatch_mouse("mouseMoved", fx, fy, "none", 0).await?;
418        self.dispatch_mouse("mousePressed", fx, fy, "left", 1).await?;
419        for i in 1..=steps {
420            let t = i as f64 / steps as f64;
421            self.dispatch_mouse("mouseMoved", fx + (tx - fx) * t, fy + (ty - fy) * t, "left", 0).await?;
422            tokio::time::sleep(tokio::time::Duration::from_millis(16)).await;
423        }
424        self.dispatch_mouse("mouseReleased", tx, ty, "left", 1).await
425    }
426
427    pub async fn insert_text(&self, text: &str) -> Result<()> {
428        self.send_to_target("Input.insertText", json!({"text": text})).await?;
429        Ok(())
430    }
431
432    pub async fn press_key(&self, key: &str, code: &str, key_code: u32) -> Result<()> {
433        self.send_to_target("Input.dispatchKeyEvent", json!({
434            "type": "keyDown", "key": key, "code": code,
435            "windowsVirtualKeyCode": key_code, "text": key,
436        })).await?;
437        self.send_to_target("Input.dispatchKeyEvent", json!({
438            "type": "keyUp", "key": key, "code": code,
439            "windowsVirtualKeyCode": key_code,
440        })).await?;
441        Ok(())
442    }
443
444    pub async fn key_down(&self, key: &str, code: &str, key_code: u32) -> Result<()> {
445        self.send_to_target("Input.dispatchKeyEvent", json!({
446            "type": "keyDown", "key": key, "code": code,
447            "windowsVirtualKeyCode": key_code, "text": key,
448        })).await?;
449        Ok(())
450    }
451
452    pub async fn key_up(&self, key: &str, code: &str, key_code: u32) -> Result<()> {
453        self.send_to_target("Input.dispatchKeyEvent", json!({
454            "type": "keyUp", "key": key, "code": code,
455            "windowsVirtualKeyCode": key_code,
456        })).await?;
457        Ok(())
458    }
459
460    pub async fn set_viewport(&self, w: u32, h: u32, dpr: f64, mobile: bool) -> Result<()> {
461        self.send_to_target("Emulation.setDeviceMetricsOverride", json!({
462            "width": w, "height": h, "deviceScaleFactor": dpr, "mobile": mobile,
463        })).await?;
464        Ok(())
465    }
466
467    pub fn destroy(&self) {
468        self.pending.clear();
469        self.event_handlers.clear();
470        self.target_session_id.store(Arc::new(None));
471    }
472
473    async fn dispatch_mouse(&self, typ: &str, x: f64, y: f64, button: &str, click_count: u32) -> Result<()> {
474        self.send_to_target("Input.dispatchMouseEvent", json!({
475            "type": typ, "x": x, "y": y, "button": button, "clickCount": click_count,
476        })).await?;
477        Ok(())
478    }
479}
480
481fn is_retryable_nav_error(error_text: &str) -> bool {
482    RETRYABLE_NAV_ERRORS.iter().any(|e| error_text.contains(e))
483}