// spider_browser/protocol/bidi_session.rs

//! Lock-free WebDriver BiDi session over the Spider WebSocket transport.

use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;

use arc_swap::ArcSwap;
use dashmap::DashMap;
use serde_json::{json, Value};
use tokio::sync::{mpsc, oneshot};

use crate::errors::{Result, SpiderError};

11/// Lock-free BiDi session using DashMap for pending/events, ArcSwap for context.
12pub struct BiDiSession {
13    next_id: AtomicU64,
14    pending: Arc<DashMap<u64, oneshot::Sender<Value>>>,
15    event_handlers: Arc<DashMap<String, Vec<Arc<dyn Fn(Value) + Send + Sync>>>>,
16    browsing_context: ArcSwap<Option<String>>,
17    timeout_ms: u64,
18    send_tx: mpsc::UnboundedSender<String>,
19}
20
21impl BiDiSession {
22    pub fn new(send_tx: mpsc::UnboundedSender<String>, timeout_ms: u64) -> Self {
23        Self {
24            next_id: AtomicU64::new(1),
25            pending: Arc::new(DashMap::new()),
26            event_handlers: Arc::new(DashMap::new()),
27            browsing_context: ArcSwap::from_pointee(None),
28            timeout_ms,
29            send_tx,
30        }
31    }
32
33    pub fn context(&self) -> Option<String> {
34        self.browsing_context.load().as_ref().clone()
35    }
36
37    /// Process a raw message from the transport. Returns true if handled.
38    pub fn handle_message(&self, data: &str) -> bool {
39        let Ok(msg) = serde_json::from_str::<Value>(data) else {
40            return false;
41        };
42
43        // Response (has "id" and "type")
44        if msg.get("id").and_then(|v| v.as_u64()).is_some()
45            && msg.get("type").and_then(|v| v.as_str()).is_some()
46        {
47            let id = msg["id"].as_u64().unwrap();
48            if let Some((_, tx)) = self.pending.remove(&id) {
49                let _ = tx.send(msg);
50                return true;
51            }
52            return false;
53        }
54
55        // Event (has "type": "event" and "method")
56        if msg.get("type").and_then(|v| v.as_str()) == Some("event") {
57            if let Some(method) = msg.get("method").and_then(|v| v.as_str()) {
58                let params = msg.get("params").cloned().unwrap_or(json!({}));
59                if let Some(list) = self.event_handlers.get(method) {
60                    let handlers = list.clone();
61                    drop(list);
62                    for h in &handlers {
63                        h(params.clone());
64                    }
65                }
66                return true;
67            }
68        }
69
70        false
71    }
72
73    /// Send a BiDi command and wait for the response.
74    pub async fn send(&self, method: &str, params: Value) -> Result<Value> {
75        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
76        let cmd = json!({"id": id, "method": method, "params": params});
77
78        let (tx, rx) = oneshot::channel();
79        self.pending.insert(id, tx);
80        self.send_tx
81            .send(cmd.to_string())
82            .map_err(|_| SpiderError::connection("WebSocket is not connected"))?;
83
84        let resp = tokio::time::timeout(
85            tokio::time::Duration::from_millis(self.timeout_ms),
86            rx,
87        )
88        .await
89        .map_err(|_| {
90            self.pending.remove(&id);
91            SpiderError::Timeout(format!("BiDi command timeout: {method} ({}ms)", self.timeout_ms))
92        })?
93        .map_err(|_| SpiderError::connection("BiDi response channel closed"))?;
94
95        if resp.get("type").and_then(|v| v.as_str()) == Some("error") {
96            let msg = resp.get("message").or(resp.get("error"))
97                .and_then(|v| v.as_str()).unwrap_or("unknown");
98            return Err(SpiderError::Protocol(format!("BiDi error: {msg}")));
99        }
100
101        Ok(resp)
102    }
103
104    pub fn on(&self, method: &str, handler: Arc<dyn Fn(Value) + Send + Sync>) {
105        self.event_handlers
106            .entry(method.to_string())
107            .or_default()
108            .push(handler);
109    }
110
111    /// Get or create a browsing context.
112    pub async fn get_or_create_context(&self) -> Result<String> {
113        if let Some(ctx) = self.context() {
114            return Ok(ctx);
115        }
116
117        // Strategy 1: browsingContext.getTree
118        if let Ok(resp) = tokio::time::timeout(
119            tokio::time::Duration::from_secs(5),
120            self.send("browsingContext.getTree", json!({})),
121        ).await {
122            if let Ok(resp) = resp {
123                if let Some(contexts) = resp.get("result")
124                    .and_then(|r| r.get("contexts"))
125                    .and_then(|v| v.as_array())
126                {
127                    if let Some(first) = contexts.first() {
128                        if let Some(ctx) = first.get("context").and_then(|v| v.as_str()) {
129                            self.browsing_context.store(Arc::new(Some(ctx.to_string())));
130                            return Ok(ctx.to_string());
131                        }
132                    }
133                }
134            }
135        }
136
137        // Strategy 2: browsingContext.create
138        if let Ok(resp) = tokio::time::timeout(
139            tokio::time::Duration::from_secs(5),
140            self.send("browsingContext.create", json!({"type": "tab"})),
141        ).await {
142            if let Ok(resp) = resp {
143                if let Some(ctx) = resp.get("result")
144                    .and_then(|r| r.get("context"))
145                    .and_then(|v| v.as_str())
146                {
147                    self.browsing_context.store(Arc::new(Some(ctx.to_string())));
148                    return Ok(ctx.to_string());
149                }
150            }
151        }
152
153        // Strategy 3: placeholder
154        let placeholder = "__default__".to_string();
155        self.browsing_context.store(Arc::new(Some(placeholder.clone())));
156        Ok(placeholder)
157    }
158
159    pub fn set_context(&self, context_id: &str) {
160        self.browsing_context.store(Arc::new(Some(context_id.to_string())));
161    }
162
163    pub async fn navigate(&self, url: &str) -> Result<()> {
164        let ctx = self.get_or_create_context().await?;
165        let resp = self.send("browsingContext.navigate", json!({
166            "context": ctx, "url": url, "wait": "complete",
167        })).await?;
168        // Extract real context from response if placeholder
169        if ctx == "__default__" {
170            if let Some(real_ctx) = resp.get("params")
171                .and_then(|p| p.get("context"))
172                .and_then(|v| v.as_str())
173            {
174                self.set_context(real_ctx);
175            }
176        }
177        Ok(())
178    }
179
180    pub async fn capture_screenshot(&self) -> Result<String> {
181        let ctx = self.get_or_create_context().await?;
182        let resp = self.send("browsingContext.captureScreenshot", json!({"context": ctx})).await?;
183        resp.get("result")
184            .and_then(|r| r.get("data"))
185            .and_then(|v| v.as_str())
186            .map(|s| s.to_string())
187            .ok_or_else(|| SpiderError::Protocol("captureScreenshot: missing result.data".into()))
188    }
189
190    pub async fn evaluate(&self, expression: &str) -> Result<Value> {
191        let ctx = self.get_or_create_context().await?;
192        let resp = self.send("script.evaluate", json!({
193            "expression": expression,
194            "target": {"context": ctx},
195            "awaitPromise": false,
196            "resultOwnership": "none",
197        })).await?;
198
199        let result_obj = resp.get("result")
200            .and_then(|r| r.get("result"))
201            .or_else(|| resp.get("result"))
202            .cloned()
203            .unwrap_or(Value::Null);
204        Ok(extract_bidi_value(&result_obj))
205    }
206
207    pub async fn get_html(&self) -> Result<String> {
208        let val = self.evaluate("document.documentElement.outerHTML").await?;
209        Ok(val.as_str().unwrap_or("").to_string())
210    }
211
212    pub async fn perform_actions(&self, actions: Value) -> Result<()> {
213        let ctx = self.get_or_create_context().await?;
214        self.send("input.performActions", json!({
215            "context": ctx, "actions": actions,
216        })).await?;
217        Ok(())
218    }
219
220    pub async fn click_point(&self, x: f64, y: f64) -> Result<()> {
221        self.perform_actions(json!([{
222            "type": "pointer", "id": "mouse",
223            "actions": [
224                {"type": "pointerMove", "x": x.round() as i64, "y": y.round() as i64},
225                {"type": "pointerDown", "button": 0},
226                {"type": "pointerUp", "button": 0},
227            ]
228        }])).await
229    }
230
231    pub async fn insert_text(&self, text: &str) -> Result<()> {
232        let actions: Vec<Value> = text.chars().flat_map(|ch| {
233            let s = ch.to_string();
234            vec![
235                json!({"type": "keyDown", "value": s}),
236                json!({"type": "keyUp", "value": s}),
237            ]
238        }).collect();
239        self.perform_actions(json!([{"type": "key", "id": "keyboard", "actions": actions}])).await
240    }
241
242    pub fn destroy(&self) {
243        self.pending.clear();
244        self.event_handlers.clear();
245        self.browsing_context.store(Arc::new(None));
246    }
247}
248
249fn extract_bidi_value(remote: &Value) -> Value {
250    match remote.get("type").and_then(|v| v.as_str()) {
251        Some("undefined") | Some("null") => Value::Null,
252        Some("string") | Some("number") | Some("boolean") | Some("bigint") => {
253            remote.get("value").cloned().unwrap_or(Value::Null)
254        }
255        Some("array") => {
256            if let Some(arr) = remote.get("value").and_then(|v| v.as_array()) {
257                Value::Array(arr.iter().map(extract_bidi_value).collect())
258            } else {
259                remote.get("value").cloned().unwrap_or(Value::Null)
260            }
261        }
262        Some("object") => {
263            if let Some(pairs) = remote.get("value").and_then(|v| v.as_array()) {
264                let mut map = serde_json::Map::new();
265                for entry in pairs {
266                    if let Some(pair) = entry.as_array() {
267                        if pair.len() == 2 {
268                            let key = pair[0].as_str()
269                                .map(|s| s.to_string())
270                                .or_else(|| pair[0].get("value").and_then(|v| v.as_str()).map(|s| s.to_string()))
271                                .unwrap_or_default();
272                            map.insert(key, extract_bidi_value(&pair[1]));
273                        }
274                    }
275                }
276                Value::Object(map)
277            } else {
278                remote.get("value").cloned().unwrap_or(Value::Null)
279            }
280        }
281        _ => remote.get("value").cloned().unwrap_or(remote.clone()),
282    }
283}