spider_browser/protocol/
bidi_session.rs1use crate::errors::{Result, SpiderError};
4use arc_swap::ArcSwap;
5use dashmap::DashMap;
6use serde_json::{json, Value};
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::sync::Arc;
9use tokio::sync::{mpsc, oneshot};
10
11pub struct BiDiSession {
13 next_id: AtomicU64,
14 pending: Arc<DashMap<u64, oneshot::Sender<Value>>>,
15 event_handlers: Arc<DashMap<String, Vec<Arc<dyn Fn(Value) + Send + Sync>>>>,
16 browsing_context: ArcSwap<Option<String>>,
17 timeout_ms: u64,
18 send_tx: mpsc::UnboundedSender<String>,
19}
20
21impl BiDiSession {
22 pub fn new(send_tx: mpsc::UnboundedSender<String>, timeout_ms: u64) -> Self {
23 Self {
24 next_id: AtomicU64::new(1),
25 pending: Arc::new(DashMap::new()),
26 event_handlers: Arc::new(DashMap::new()),
27 browsing_context: ArcSwap::from_pointee(None),
28 timeout_ms,
29 send_tx,
30 }
31 }
32
33 pub fn context(&self) -> Option<String> {
34 self.browsing_context.load().as_ref().clone()
35 }
36
37 pub fn handle_message(&self, data: &str) -> bool {
39 let Ok(msg) = serde_json::from_str::<Value>(data) else {
40 return false;
41 };
42
43 if msg.get("id").and_then(|v| v.as_u64()).is_some()
45 && msg.get("type").and_then(|v| v.as_str()).is_some()
46 {
47 let id = msg["id"].as_u64().unwrap();
48 if let Some((_, tx)) = self.pending.remove(&id) {
49 let _ = tx.send(msg);
50 return true;
51 }
52 return false;
53 }
54
55 if msg.get("type").and_then(|v| v.as_str()) == Some("event") {
57 if let Some(method) = msg.get("method").and_then(|v| v.as_str()) {
58 let params = msg.get("params").cloned().unwrap_or(json!({}));
59 if let Some(list) = self.event_handlers.get(method) {
60 let handlers = list.clone();
61 drop(list);
62 for h in &handlers {
63 h(params.clone());
64 }
65 }
66 return true;
67 }
68 }
69
70 false
71 }
72
73 pub async fn send(&self, method: &str, params: Value) -> Result<Value> {
75 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
76 let cmd = json!({"id": id, "method": method, "params": params});
77
78 let (tx, rx) = oneshot::channel();
79 self.pending.insert(id, tx);
80 self.send_tx
81 .send(cmd.to_string())
82 .map_err(|_| SpiderError::connection("WebSocket is not connected"))?;
83
84 let resp = tokio::time::timeout(
85 tokio::time::Duration::from_millis(self.timeout_ms),
86 rx,
87 )
88 .await
89 .map_err(|_| {
90 self.pending.remove(&id);
91 SpiderError::Timeout(format!("BiDi command timeout: {method} ({}ms)", self.timeout_ms))
92 })?
93 .map_err(|_| SpiderError::connection("BiDi response channel closed"))?;
94
95 if resp.get("type").and_then(|v| v.as_str()) == Some("error") {
96 let msg = resp.get("message").or(resp.get("error"))
97 .and_then(|v| v.as_str()).unwrap_or("unknown");
98 return Err(SpiderError::Protocol(format!("BiDi error: {msg}")));
99 }
100
101 Ok(resp)
102 }
103
104 pub fn on(&self, method: &str, handler: Arc<dyn Fn(Value) + Send + Sync>) {
105 self.event_handlers
106 .entry(method.to_string())
107 .or_default()
108 .push(handler);
109 }
110
111 pub async fn get_or_create_context(&self) -> Result<String> {
113 if let Some(ctx) = self.context() {
114 return Ok(ctx);
115 }
116
117 if let Ok(resp) = tokio::time::timeout(
119 tokio::time::Duration::from_secs(5),
120 self.send("browsingContext.getTree", json!({})),
121 ).await {
122 if let Ok(resp) = resp {
123 if let Some(contexts) = resp.get("result")
124 .and_then(|r| r.get("contexts"))
125 .and_then(|v| v.as_array())
126 {
127 if let Some(first) = contexts.first() {
128 if let Some(ctx) = first.get("context").and_then(|v| v.as_str()) {
129 self.browsing_context.store(Arc::new(Some(ctx.to_string())));
130 return Ok(ctx.to_string());
131 }
132 }
133 }
134 }
135 }
136
137 if let Ok(resp) = tokio::time::timeout(
139 tokio::time::Duration::from_secs(5),
140 self.send("browsingContext.create", json!({"type": "tab"})),
141 ).await {
142 if let Ok(resp) = resp {
143 if let Some(ctx) = resp.get("result")
144 .and_then(|r| r.get("context"))
145 .and_then(|v| v.as_str())
146 {
147 self.browsing_context.store(Arc::new(Some(ctx.to_string())));
148 return Ok(ctx.to_string());
149 }
150 }
151 }
152
153 let placeholder = "__default__".to_string();
155 self.browsing_context.store(Arc::new(Some(placeholder.clone())));
156 Ok(placeholder)
157 }
158
159 pub fn set_context(&self, context_id: &str) {
160 self.browsing_context.store(Arc::new(Some(context_id.to_string())));
161 }
162
163 pub async fn navigate(&self, url: &str) -> Result<()> {
164 let ctx = self.get_or_create_context().await?;
165 let resp = self.send("browsingContext.navigate", json!({
166 "context": ctx, "url": url, "wait": "complete",
167 })).await?;
168 if ctx == "__default__" {
170 if let Some(real_ctx) = resp.get("params")
171 .and_then(|p| p.get("context"))
172 .and_then(|v| v.as_str())
173 {
174 self.set_context(real_ctx);
175 }
176 }
177 Ok(())
178 }
179
180 pub async fn capture_screenshot(&self) -> Result<String> {
181 let ctx = self.get_or_create_context().await?;
182 let resp = self.send("browsingContext.captureScreenshot", json!({"context": ctx})).await?;
183 resp.get("result")
184 .and_then(|r| r.get("data"))
185 .and_then(|v| v.as_str())
186 .map(|s| s.to_string())
187 .ok_or_else(|| SpiderError::Protocol("captureScreenshot: missing result.data".into()))
188 }
189
190 pub async fn evaluate(&self, expression: &str) -> Result<Value> {
191 let ctx = self.get_or_create_context().await?;
192 let resp = self.send("script.evaluate", json!({
193 "expression": expression,
194 "target": {"context": ctx},
195 "awaitPromise": false,
196 "resultOwnership": "none",
197 })).await?;
198
199 let result_obj = resp.get("result")
200 .and_then(|r| r.get("result"))
201 .or_else(|| resp.get("result"))
202 .cloned()
203 .unwrap_or(Value::Null);
204 Ok(extract_bidi_value(&result_obj))
205 }
206
207 pub async fn get_html(&self) -> Result<String> {
208 let val = self.evaluate("document.documentElement.outerHTML").await?;
209 Ok(val.as_str().unwrap_or("").to_string())
210 }
211
212 pub async fn perform_actions(&self, actions: Value) -> Result<()> {
213 let ctx = self.get_or_create_context().await?;
214 self.send("input.performActions", json!({
215 "context": ctx, "actions": actions,
216 })).await?;
217 Ok(())
218 }
219
220 pub async fn click_point(&self, x: f64, y: f64) -> Result<()> {
221 self.perform_actions(json!([{
222 "type": "pointer", "id": "mouse",
223 "actions": [
224 {"type": "pointerMove", "x": x.round() as i64, "y": y.round() as i64},
225 {"type": "pointerDown", "button": 0},
226 {"type": "pointerUp", "button": 0},
227 ]
228 }])).await
229 }
230
231 pub async fn insert_text(&self, text: &str) -> Result<()> {
232 let actions: Vec<Value> = text.chars().flat_map(|ch| {
233 let s = ch.to_string();
234 vec![
235 json!({"type": "keyDown", "value": s}),
236 json!({"type": "keyUp", "value": s}),
237 ]
238 }).collect();
239 self.perform_actions(json!([{"type": "key", "id": "keyboard", "actions": actions}])).await
240 }
241
242 pub fn destroy(&self) {
243 self.pending.clear();
244 self.event_handlers.clear();
245 self.browsing_context.store(Arc::new(None));
246 }
247}
248
249fn extract_bidi_value(remote: &Value) -> Value {
250 match remote.get("type").and_then(|v| v.as_str()) {
251 Some("undefined") | Some("null") => Value::Null,
252 Some("string") | Some("number") | Some("boolean") | Some("bigint") => {
253 remote.get("value").cloned().unwrap_or(Value::Null)
254 }
255 Some("array") => {
256 if let Some(arr) = remote.get("value").and_then(|v| v.as_array()) {
257 Value::Array(arr.iter().map(extract_bidi_value).collect())
258 } else {
259 remote.get("value").cloned().unwrap_or(Value::Null)
260 }
261 }
262 Some("object") => {
263 if let Some(pairs) = remote.get("value").and_then(|v| v.as_array()) {
264 let mut map = serde_json::Map::new();
265 for entry in pairs {
266 if let Some(pair) = entry.as_array() {
267 if pair.len() == 2 {
268 let key = pair[0].as_str()
269 .map(|s| s.to_string())
270 .or_else(|| pair[0].get("value").and_then(|v| v.as_str()).map(|s| s.to_string()))
271 .unwrap_or_default();
272 map.insert(key, extract_bidi_value(&pair[1]));
273 }
274 }
275 }
276 Value::Object(map)
277 } else {
278 remote.get("value").cloned().unwrap_or(Value::Null)
279 }
280 }
281 _ => remote.get("value").cloned().unwrap_or(remote.clone()),
282 }
283}