Skip to main content

spider_browser/protocol/
transport.rs

1//! Lock-free WebSocket transport to Spider's browser fleet.
2
3use crate::errors::{Result, SpiderError};
4use crate::events::SpiderEventEmitter;
5use arc_swap::ArcSwap;
6use futures_util::{SinkExt, StreamExt};
7use serde_json::{json, Value};
8use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
9use std::sync::Arc;
10use tokio::sync::mpsc;
11use tokio_tungstenite::tungstenite::Message;
12use tracing::{debug, info, warn};
13
/// Bit pattern meaning "no `f64` stored" in the `AtomicU64` metering cells.
///
/// Every bit is set, which is a NaN encoding, so it can never collide with the
/// bits of a finite metering value.
const NO_F64: u64 = !0;
16
/// Options controlling how a [`Transport`] connects to Spider's browser fleet.
///
/// Most fields end up as query parameters on the `/v1/browser` WebSocket URL.
#[derive(Clone, Debug)]
pub struct TransportOptions {
    /// API key, sent as the `token` query parameter.
    pub api_key: String,
    /// Base WebSocket URL; trailing `/` characters are stripped before `/v1/browser` is appended.
    pub server_url: String,
    /// Browser type. `"auto"` is resolved to `"chrome-h"` when the transport is created.
    pub browser: String,
    /// Optional target URL, URL-encoded into the `url` query parameter.
    pub url: Option<String>,
    /// AI captcha mode sent as `ai_captcha`; `None` or `"off"` omits the parameter.
    pub captcha: Option<String>,
    /// Stealth level sent as `s` when non-zero.
    pub stealth_level: u32,
    /// Timeout for establishing the WebSocket connection, in milliseconds.
    pub connect_timeout_ms: u64,
    /// Command timeout in milliseconds, exposed via `Transport::command_timeout_ms`
    /// (presumably enforced by the session layer — not applied in the transport itself).
    pub command_timeout_ms: u64,
    /// When set, adds `hedge=true` to the connection URL.
    pub hedge: bool,
    /// When set, adds `record=true` to the connection URL.
    pub record: bool,
    /// Optional value for the `mode` query parameter.
    pub mode: Option<String>,
    /// Optional value for the `country` query parameter.
    pub country: Option<String>,
    /// Optional proxy URL, URL-encoded into the `proxy_url` query parameter.
    pub proxy_url: Option<String>,
}

impl Default for TransportOptions {
    /// Hosted Spider endpoint, automatic browser selection, captcha solving on,
    /// 30-second timeouts, and every optional feature switched off.
    fn default() -> Self {
        const THIRTY_SECONDS_MS: u64 = 30_000;

        TransportOptions {
            api_key: String::default(),
            server_url: String::from("wss://browser.spider.cloud"),
            browser: String::from("auto"),
            url: None,
            captcha: Some(String::from("solve")),
            stealth_level: 0,
            connect_timeout_ms: THIRTY_SECONDS_MS,
            command_timeout_ms: THIRTY_SECONDS_MS,
            hedge: false,
            record: false,
            mode: None,
            country: None,
            proxy_url: None,
        }
    }
}
54
55/// Lock-free WebSocket transport.
56///
57/// All state is stored in atomics or ArcSwap. The WS write half is owned
58/// by a dedicated writer task; callers send through a channel.
59pub struct Transport {
60    opts: TransportOptions,
61    current_browser: ArcSwap<String>,
62    stealth_level: AtomicU64,
63    emitter: SpiderEventEmitter,
64    /// Channel for outgoing WS messages — writer task drains this.
65    ws_send_tx: mpsc::UnboundedSender<String>,
66    ws_send_rx: tokio::sync::Mutex<Option<mpsc::UnboundedReceiver<String>>>,
67    /// Channel for incoming messages forwarded to CDPSession/BiDiSession.
68    message_tx: mpsc::UnboundedSender<String>,
69    message_rx: tokio::sync::Mutex<Option<mpsc::UnboundedReceiver<String>>>,
70    generation: AtomicU64,
71    connected: AtomicBool,
72    // Metering (stored as f64 bits in AtomicU64)
73    upgrade_credits: AtomicU64,
74    upgrade_stealth_tier: AtomicU64,
75    upgrade_proxy_tier: AtomicU64,
76    session_credits_used: AtomicU64,
77    /// Handle to abort recv+write tasks on close/reconnect.
78    task_handles: ArcSwap<Vec<tokio::task::JoinHandle<()>>>,
79}
80
81impl Transport {
82    pub fn new(opts: TransportOptions, emitter: SpiderEventEmitter) -> Arc<Self> {
83        let browser = if opts.browser == "auto" {
84            "chrome-h".to_string()
85        } else {
86            opts.browser.clone()
87        };
88        let stealth = opts.stealth_level;
89        let (ws_tx, ws_rx) = mpsc::unbounded_channel();
90        let (msg_tx, msg_rx) = mpsc::unbounded_channel();
91
92        Arc::new(Self {
93            opts,
94            current_browser: ArcSwap::from_pointee(browser),
95            stealth_level: AtomicU64::new(stealth as u64),
96            emitter,
97            ws_send_tx: ws_tx,
98            ws_send_rx: tokio::sync::Mutex::new(Some(ws_rx)),
99            message_tx: msg_tx,
100            message_rx: tokio::sync::Mutex::new(Some(msg_rx)),
101            generation: AtomicU64::new(0),
102            connected: AtomicBool::new(false),
103            upgrade_credits: AtomicU64::new(NO_F64),
104            upgrade_stealth_tier: AtomicU64::new(NO_F64),
105            upgrade_proxy_tier: AtomicU64::new(NO_F64),
106            session_credits_used: AtomicU64::new(NO_F64),
107            task_handles: ArcSwap::from_pointee(Vec::new()),
108        })
109    }
110
111    pub fn browser(&self) -> String {
112        self.current_browser.load().as_ref().clone()
113    }
114
115    pub fn is_connected(&self) -> bool {
116        self.connected.load(Ordering::Relaxed)
117    }
118
119    pub fn get_stealth_level(&self) -> u32 {
120        self.stealth_level.load(Ordering::Relaxed) as u32
121    }
122
123    pub fn set_stealth_level(&self, level: u32) {
124        self.stealth_level
125            .store(level.min(3) as u64, Ordering::Relaxed);
126    }
127
128    pub fn upgrade_credits(&self) -> Option<f64> {
129        let bits = self.upgrade_credits.load(Ordering::Relaxed);
130        if bits == NO_F64 {
131            None
132        } else {
133            Some(f64::from_bits(bits))
134        }
135    }
136
137    pub fn session_credits_used(&self) -> Option<f64> {
138        let bits = self.session_credits_used.load(Ordering::Relaxed);
139        if bits == NO_F64 {
140            None
141        } else {
142            Some(f64::from_bits(bits))
143        }
144    }
145
146    pub fn command_timeout_ms(&self) -> u64 {
147        self.opts.command_timeout_ms
148    }
149
150    /// Take the message receiver (can only be called once).
151    pub async fn take_message_rx(&self) -> Option<mpsc::UnboundedReceiver<String>> {
152        self.message_rx.lock().await.take()
153    }
154
155    /// Send a raw JSON string through the WebSocket channel.
156    pub fn send(&self, data: String) -> Result<()> {
157        self.ws_send_tx
158            .send(data)
159            .map_err(|_| SpiderError::connection("WebSocket is not connected"))
160    }
161
162    /// Connect to the Spider WebSocket with retry.
163    pub async fn connect(self: &Arc<Self>, max_attempts: u32) -> Result<()> {
164        if self.connected.load(Ordering::Relaxed) {
165            return Ok(());
166        }
167
168        let mut last_error = None;
169        for attempt in 1..=max_attempts {
170            match self.connect_internal().await {
171                Ok(()) => return Ok(()),
172                Err(e) => {
173                    if matches!(e, SpiderError::Auth(_)) {
174                        return Err(e);
175                    }
176                    last_error = Some(e);
177                    if attempt < max_attempts {
178                        let backoff = 500 * attempt as u64;
179                        warn!(
180                            "connect attempt {}/{} failed, retrying in {}ms",
181                            attempt, max_attempts, backoff
182                        );
183                        tokio::time::sleep(tokio::time::Duration::from_millis(backoff)).await;
184                    }
185                }
186            }
187        }
188        Err(last_error.unwrap())
189    }
190
191    /// Reconnect with a different browser type.
192    pub async fn reconnect(self: &Arc<Self>, browser: &str) -> Result<()> {
193        let prev = self.browser();
194        self.current_browser
195            .store(Arc::new(browser.to_string()));
196        self.close();
197        info!("switching browser: {} -> {}", prev, browser);
198        self.connect_internal().await
199    }
200
201    /// Close the WebSocket connection.
202    pub fn close(&self) {
203        let handles = self.task_handles.swap(Arc::new(Vec::new()));
204        for h in handles.iter() {
205            h.abort();
206        }
207        self.connected.store(false, Ordering::Relaxed);
208    }
209
210    fn build_url(&self, browser: &str, stealth: u32) -> String {
211        let base = self.opts.server_url.trim_end_matches('/');
212        let mut params = vec![format!("token={}", self.opts.api_key)];
213        if browser != "auto" {
214            params.push(format!("browser={browser}"));
215        }
216        if let Some(ref url) = self.opts.url {
217            params.push(format!("url={}", urlencoding::encode(url)));
218        }
219        if let Some(ref captcha) = self.opts.captcha {
220            if captcha != "off" {
221                params.push(format!("ai_captcha={captcha}"));
222            }
223        }
224        if stealth > 0 {
225            params.push(format!("s={stealth}"));
226        }
227        if self.opts.hedge {
228            params.push("hedge=true".into());
229        }
230        if self.opts.record {
231            params.push("record=true".into());
232        }
233        if let Some(ref mode) = self.opts.mode {
234            params.push(format!("mode={mode}"));
235        }
236        if let Some(ref country) = self.opts.country {
237            params.push(format!("country={country}"));
238        }
239        if let Some(ref proxy_url) = self.opts.proxy_url {
240            params.push(format!("proxy_url={}", urlencoding::encode(proxy_url)));
241        }
242        format!("{base}/v1/browser?{}", params.join("&"))
243    }
244
245    async fn connect_internal(self: &Arc<Self>) -> Result<()> {
246        let gen = self.generation.fetch_add(1, Ordering::Relaxed) + 1;
247        let browser = self.browser();
248        let stealth = self.get_stealth_level();
249
250        let url_str = self.build_url(&browser, stealth);
251        let safe_url = url_str.split("token=").next().unwrap_or(&url_str);
252        debug!("connecting to {}token=***", safe_url);
253
254        let timeout = tokio::time::Duration::from_millis(self.opts.connect_timeout_ms);
255
256        let ws_stream = tokio::time::timeout(timeout, async {
257            let (stream, _response) = tokio_tungstenite::connect_async(&url_str)
258                .await
259                .map_err(|e| {
260                    let msg = e.to_string();
261                    if msg.contains("429") {
262                        SpiderError::rate_limit("Server at capacity (429)")
263                    } else {
264                        SpiderError::connection(format!("WebSocket error: {e}"))
265                    }
266                })?;
267            Ok::<_, SpiderError>(stream)
268        })
269        .await
270        .map_err(|_| {
271            SpiderError::Timeout(format!(
272                "WebSocket connection timeout ({}ms)",
273                self.opts.connect_timeout_ms
274            ))
275        })??;
276
277        let (mut sink, mut stream) = ws_stream.split();
278        self.connected.store(true, Ordering::Relaxed);
279        self.emitter.emit("ws.open", json!({}));
280        info!("connected (browser={}, stealth={})", browser, stealth);
281
282        // Writer task — owns the sink, drains the ws_send channel.
283        // Take the receiver (only on first connect; subsequent connects create new channels).
284        let ws_rx = self.ws_send_rx.lock().await.take();
285        let write_handle = if let Some(mut rx) = ws_rx {
286            tokio::spawn(async move {
287                while let Some(data) = rx.recv().await {
288                    if sink.send(Message::Text(data.into())).await.is_err() {
289                        break;
290                    }
291                }
292                let _ = sink.close().await;
293            })
294        } else {
295            // Fallback: no receiver available (shouldn't happen in normal flow)
296            tokio::spawn(async move {
297                let _ = sink.close().await;
298            })
299        };
300
301        // Reader task — forwards messages to message_tx, intercepts Spider.* events.
302        let msg_tx = self.message_tx.clone();
303        let this = Arc::clone(self);
304
305        let read_handle = tokio::spawn(async move {
306            while let Some(msg) = stream.next().await {
307                if this.generation.load(Ordering::Relaxed) != gen {
308                    return;
309                }
310                match msg {
311                    Ok(Message::Text(text)) => {
312                        let text_str = text.to_string();
313                        if text_str.contains("\"Spider.") {
314                            if handle_spider_transport_event(
315                                &text_str,
316                                &this,
317                            ) {
318                                continue;
319                            }
320                        }
321                        let _ = msg_tx.send(text_str);
322                    }
323                    Ok(Message::Binary(data)) => {
324                        if let Ok(text) = String::from_utf8(data.to_vec()) {
325                            let _ = msg_tx.send(text);
326                        }
327                    }
328                    Ok(Message::Close(frame)) => {
329                        let (code, reason) = frame
330                            .map(|f| (f.code.into(), f.reason.to_string()))
331                            .unwrap_or((0u16, String::new()));
332                        if this.generation.load(Ordering::Relaxed) == gen {
333                            this.emitter.emit(
334                                "ws.close",
335                                json!({"code": code, "reason": reason}),
336                            );
337                        }
338                        break;
339                    }
340                    Err(e) => {
341                        if this.generation.load(Ordering::Relaxed) == gen {
342                            this.emitter.emit("ws.error", json!({"error": e.to_string()}));
343                        }
344                        break;
345                    }
346                    _ => {}
347                }
348            }
349            this.connected.store(false, Ordering::Relaxed);
350        });
351
352        self.task_handles
353            .store(Arc::new(vec![write_handle, read_handle]));
354        Ok(())
355    }
356}
357
358fn handle_spider_transport_event(
359    data: &str,
360    transport: &Transport,
361) -> bool {
362    let Ok(msg) = serde_json::from_str::<Value>(data) else {
363        return false;
364    };
365    let method = msg.get("method").and_then(|v| v.as_str()).unwrap_or("");
366    let params = msg.get("params").cloned().unwrap_or(json!({}));
367
368    match method {
369        "Spider.screencastFrame" => {
370            transport.emitter.emit("screencast.frame", params);
371            true
372        }
373        "Spider.interactionEvents" => {
374            transport.emitter.emit("screencast.interactionEvents", params);
375            true
376        }
377        "Spider.rrwebEvents" => {
378            transport.emitter.emit("screencast.rrwebEvents", params);
379            true
380        }
381        "Spider.recordingStarted" => {
382            transport.emitter.emit("recording.started", params);
383            true
384        }
385        "Spider.recordingCompleted" => {
386            transport.emitter.emit("recording.completed", params);
387            true
388        }
389        "Spider.metering" => {
390            if let Some(credits_used) = params.get("credits_used").and_then(|v| v.as_f64()) {
391                transport.session_credits_used.store(credits_used.to_bits(), Ordering::Relaxed);
392                let uc_bits = transport.upgrade_credits.load(Ordering::Relaxed);
393                let uc = if uc_bits == NO_F64 {
394                    0.0
395                } else {
396                    f64::from_bits(uc_bits)
397                };
398                let us_bits = transport.upgrade_stealth_tier.load(Ordering::Relaxed);
399                let us = if us_bits == NO_F64 {
400                    0u32
401                } else {
402                    f64::from_bits(us_bits) as u32
403                };
404                transport.emitter.emit(
405                    "metering",
406                    json!({
407                        "credits": uc,
408                        "rate": us,
409                        "session_credits_used": credits_used,
410                    }),
411                );
412                return true;
413            }
414            false
415        }
416        _ => false,
417    }
418}
419
420mod urlencoding {
421    pub fn encode(s: &str) -> String {
422        url::form_urlencoded::byte_serialize(s.as_bytes()).collect()
423    }
424}