Skip to main content

spider_browser/ai/
agent.rs

//! Autonomous multi-step agent and action executor.
//!
//! Ported from TypeScript `ai/agent.ts`.
//!
//! Loop: screenshot -> HTML -> LLM -> parse plan -> execute actions -> repeat.
6
7use crate::ai::llm_provider::{LLMContent, LLMMessage, LLMProvider, LLMRole};
8use crate::ai::prompts::{build_user_message, SYSTEM_PROMPT};
9use crate::errors::{Result, SpiderError};
10use crate::events::SpiderEventEmitter;
11use crate::protocol::protocol_adapter::ProtocolAdapter;
12use serde::{Deserialize, Serialize};
13use serde_json::{json, Value};
14use tokio::time::{sleep, Duration};
15use tracing::{info, warn};
16
17// -------------------------------------------------------------------
18// Action types (mirrors actions.rs AgentAction enum)
19// -------------------------------------------------------------------
20
21/// All possible agent actions.
22///
23/// Each variant is an externally-tagged enum so the JSON format is
24/// `{ "Click": "selector" }` or `{ "ClickPoint": { "x": 100, "y": 200 } }`.
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub enum AgentAction {
27    Click(String),
28    ClickAll(String),
29    ClickPoint { x: f64, y: f64 },
30    ClickHold { selector: String, hold_ms: u64 },
31    ClickHoldPoint { x: f64, y: f64, hold_ms: u64 },
32    DoubleClick(String),
33    DoubleClickPoint { x: f64, y: f64 },
34    RightClick(String),
35    RightClickPoint { x: f64, y: f64 },
36    WaitForAndClick(String),
37    ClickDrag {
38        from: String,
39        to: String,
40        #[serde(default)]
41        modifier: Option<u32>,
42    },
43    ClickDragPoint {
44        from_x: f64,
45        from_y: f64,
46        to_x: f64,
47        to_y: f64,
48        #[serde(default)]
49        modifier: Option<u32>,
50    },
51    Type { value: String },
52    Fill { selector: String, value: String },
53    Clear(String),
54    Press(String),
55    KeyDown(String),
56    KeyUp(String),
57    Select { selector: String, value: String },
58    Focus(String),
59    Blur(String),
60    Hover(String),
61    HoverPoint { x: f64, y: f64 },
62    ScrollY(f64),
63    ScrollX(f64),
64    ScrollTo { selector: String },
65    ScrollToPoint { x: f64, y: f64 },
66    InfiniteScroll(u32),
67    Wait(u64),
68    WaitFor(String),
69    WaitForWithTimeout { selector: String, timeout: u64 },
70    WaitForNavigation,
71    WaitForDom {
72        #[serde(default)]
73        selector: Option<String>,
74        timeout: u64,
75    },
76    Navigate(String),
77    GoBack,
78    GoForward,
79    Reload,
80    SetViewport {
81        width: u32,
82        height: u32,
83        #[serde(default)]
84        device_scale_factor: Option<f64>,
85        #[serde(default)]
86        mobile: Option<bool>,
87    },
88    Evaluate(String),
89    Screenshot,
90}
91
92/// Parsed LLM response -- mirrors `actions.rs AgentPlan`.
93#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct AgentPlan {
95    #[serde(default)]
96    pub label: String,
97    #[serde(default)]
98    pub done: bool,
99    #[serde(default)]
100    pub steps: Option<Vec<AgentAction>>,
101    #[serde(default)]
102    pub extracted: Option<Value>,
103    #[serde(default)]
104    pub memory_ops: Option<Vec<Value>>,
105}
106
/// Tunables for the autonomous agent loop.
#[derive(Debug, Clone)]
pub struct AgentOptions {
    /// Upper bound on automation rounds (default: 30).
    pub max_rounds: u32,
    /// Milliseconds to wait after a round's actions so the page can settle (default: 1500).
    pub step_delay_ms: u64,
    /// Extra context/instruction for each round.
    pub instruction: Option<String>,
}

impl Default for AgentOptions {
    /// 30 rounds, a 1500 ms settle delay, and no extra instruction.
    fn default() -> Self {
        const MAX_ROUNDS: u32 = 30;
        const STEP_DELAY_MS: u64 = 1500;

        AgentOptions {
            instruction: None,
            step_delay_ms: STEP_DELAY_MS,
            max_rounds: MAX_ROUNDS,
        }
    }
}
127
128/// Result of an agent execution run.
129#[derive(Debug, Clone, Serialize, Deserialize)]
130pub struct AgentResult {
131    /// Whether the agent completed the task.
132    pub done: bool,
133    /// Number of rounds executed.
134    pub rounds: u32,
135    /// Extracted data (accumulated across rounds).
136    #[serde(default, skip_serializing_if = "Option::is_none")]
137    pub extracted: Option<Value>,
138    /// Final label from the agent.
139    #[serde(default)]
140    pub label: String,
141}
142
143/// Autonomous multi-step agent.
144pub struct Agent<'a> {
145    adapter: &'a ProtocolAdapter,
146    llm: &'a dyn LLMProvider,
147    emitter: &'a SpiderEventEmitter,
148    max_rounds: u32,
149    step_delay_ms: u64,
150}
151
152impl<'a> Agent<'a> {
153    pub fn new(
154        adapter: &'a ProtocolAdapter,
155        llm: &'a dyn LLMProvider,
156        emitter: &'a SpiderEventEmitter,
157        options: Option<AgentOptions>,
158    ) -> Self {
159        let opts = options.unwrap_or_default();
160        Self {
161            adapter,
162            llm,
163            emitter,
164            max_rounds: opts.max_rounds,
165            step_delay_ms: opts.step_delay_ms,
166        }
167    }
168
169    /// Execute the agent loop until the task is done or max rounds reached.
170    pub async fn execute(&self, instruction: &str) -> AgentResult {
171        let mut extracted: Option<Value> = None;
172        let mut last_label = String::new();
173
174        // Small initial delay for page to render
175        sleep(Duration::from_millis(500)).await;
176
177        for round in 0..self.max_rounds {
178            // 1. Capture screenshot
179            let screenshot = match self.adapter.capture_screenshot().await {
180                Ok(s) => s,
181                Err(err) => {
182                    warn!(round, error = %err, "agent: screenshot failed");
183                    break;
184                }
185            };
186
187            // 2. Get page HTML
188            let html = match self.adapter.get_html().await {
189                Ok(h) => h,
190                Err(err) => {
191                    warn!(round, error = %err, "agent: get HTML failed");
192                    break;
193                }
194            };
195
196            // 3. Get URL and title
197            let url = self
198                .adapter
199                .evaluate("window.location.href")
200                .await
201                .ok()
202                .and_then(|v| v.as_str().map(String::from))
203                .unwrap_or_else(|| "unknown".to_string());
204
205            let title = self
206                .adapter
207                .evaluate("document.title")
208                .await
209                .ok()
210                .and_then(|v| v.as_str().map(String::from))
211                .unwrap_or_default();
212
213            // 4. Call LLM
214            let context = format!(
215                "Round {}/{}. Task: {instruction}\nPAGE TITLE: {title}",
216                round + 1,
217                self.max_rounds
218            );
219
220            let messages = vec![
221                LLMMessage::system(SYSTEM_PROMPT),
222                LLMMessage {
223                    role: LLMRole::User,
224                    content: LLMContent::Parts(build_user_message(
225                        &url,
226                        &html,
227                        &screenshot,
228                        Some(&context),
229                    )),
230                },
231            ];
232
233            let plan: AgentPlan = match crate::ai::llm_provider::chat_json(self.llm, &messages).await {
234                Ok(p) => p,
235                Err(err) => {
236                    warn!(round, error = %err, "agent: LLM call failed");
237                    sleep(Duration::from_millis(2000)).await;
238                    continue;
239                }
240            };
241
242            last_label = plan.label.clone();
243            if plan.extracted.is_some() {
244                extracted = plan.extracted.clone();
245            }
246
247            let steps_count = plan.steps.as_ref().map(|s| s.len()).unwrap_or(0);
248
249            info!(
250                round = round + 1,
251                label = %plan.label,
252                done = plan.done,
253                steps = steps_count,
254                "agent: round"
255            );
256
257            self.emitter.emit(
258                "agent.step",
259                json!({
260                    "round": round + 1,
261                    "label": plan.label,
262                    "stepsCount": steps_count,
263                }),
264            );
265
266            // 5. Check if done
267            if plan.done {
268                self.emitter.emit(
269                    "agent.done",
270                    json!({
271                        "rounds": round + 1,
272                        "result": extracted,
273                    }),
274                );
275                return AgentResult {
276                    done: true,
277                    rounds: round + 1,
278                    extracted,
279                    label: last_label,
280                };
281            }
282
283            if steps_count == 0 {
284                info!("agent: no steps, retrying");
285                sleep(Duration::from_millis(self.step_delay_ms)).await;
286                continue;
287            }
288
289            // 6. Execute each step
290            if let Some(ref steps) = plan.steps {
291                for (i, action) in steps.iter().enumerate() {
292                    if let Err(err) = execute_action(self.adapter, action).await {
293                        warn!(
294                            round,
295                            step = i,
296                            error = %err,
297                            "agent: action failed"
298                        );
299                        break;
300                    }
301                    sleep(Duration::from_millis(200)).await;
302                }
303            }
304
305            // 7. Wait for page to settle
306            sleep(Duration::from_millis(self.step_delay_ms)).await;
307        }
308
309        // Max rounds exceeded
310        warn!("agent: max rounds exceeded");
311        self.emitter.emit(
312            "agent.error",
313            json!({
314                "error": "max rounds exceeded",
315                "round": self.max_rounds,
316            }),
317        );
318
319        AgentResult {
320            done: false,
321            rounds: self.max_rounds,
322            extracted,
323            label: last_label,
324        }
325    }
326}
327
328// -------------------------------------------------------------------
329// Action executor -- mirrors agent.rs execute_action()
330// -------------------------------------------------------------------
331
332/// Execute a single agent action via the protocol adapter.
333///
334/// Handles all action types from the [`AgentAction`] enum.
335pub async fn execute_action(adapter: &ProtocolAdapter, action: &AgentAction) -> Result<()> {
336    match action {
337        // ----- Click actions -----
338        AgentAction::Click(selector) => {
339            let (x, y) = get_element_center(adapter, selector).await?;
340            adapter.click_point(x, y).await
341        }
342        AgentAction::ClickAll(selector) => {
343            let js = format!(
344                r#"(function() {{
345                    var els = document.querySelectorAll({sel});
346                    return Array.from(els).map(function(el) {{
347                        var r = el.getBoundingClientRect();
348                        return {{ x: r.x + r.width / 2, y: r.y + r.height / 2 }};
349                    }});
350                }})()"#,
351                sel = serde_json::to_string(selector).unwrap_or_default()
352            );
353            let val = adapter.evaluate(&js).await?;
354            if let Some(points) = val.as_array() {
355                for pt in points {
356                    let x = pt.get("x").and_then(|v| v.as_f64()).unwrap_or(0.0);
357                    let y = pt.get("y").and_then(|v| v.as_f64()).unwrap_or(0.0);
358                    adapter.click_point(x, y).await?;
359                    sleep(Duration::from_millis(100)).await;
360                }
361            }
362            Ok(())
363        }
364        AgentAction::ClickPoint { x, y } => adapter.click_point(*x, *y).await,
365        AgentAction::ClickHold { selector, hold_ms } => {
366            let (x, y) = get_element_center(adapter, selector).await?;
367            adapter.click_hold_point(x, y, *hold_ms).await
368        }
369        AgentAction::ClickHoldPoint { x, y, hold_ms } => {
370            adapter.click_hold_point(*x, *y, *hold_ms).await
371        }
372        AgentAction::DoubleClick(selector) => {
373            let (x, y) = get_element_center(adapter, selector).await?;
374            adapter.double_click_point(x, y).await
375        }
376        AgentAction::DoubleClickPoint { x, y } => adapter.double_click_point(*x, *y).await,
377        AgentAction::RightClick(selector) => {
378            let (x, y) = get_element_center(adapter, selector).await?;
379            adapter.right_click_point(x, y).await
380        }
381        AgentAction::RightClickPoint { x, y } => adapter.right_click_point(*x, *y).await,
382        AgentAction::WaitForAndClick(selector) => {
383            wait_for_element(adapter, selector, 5000).await?;
384            let (x, y) = get_element_center(adapter, selector).await?;
385            adapter.click_point(x, y).await
386        }
387
388        // ----- Drag actions -----
389        AgentAction::ClickDrag { from, to, .. } => {
390            let (fx, fy) = get_element_center(adapter, from).await?;
391            let (tx, ty) = get_element_center(adapter, to).await?;
392            adapter.drag_point(fx, fy, tx, ty).await
393        }
394        AgentAction::ClickDragPoint {
395            from_x,
396            from_y,
397            to_x,
398            to_y,
399            ..
400        } => adapter.drag_point(*from_x, *from_y, *to_x, *to_y).await,
401
402        // ----- Input actions -----
403        AgentAction::Type { value } => adapter.insert_text(value).await,
404        AgentAction::Fill { selector, value } => {
405            // Clear via JS
406            let sel_json = serde_json::to_string(selector).unwrap_or_default();
407            let clear_js = format!(
408                r#"(function() {{
409                    var el = document.querySelector({sel_json});
410                    if (el) {{ el.focus(); el.value = ''; }}
411                }})()"#
412            );
413            adapter.evaluate(&clear_js).await?;
414
415            // Click for real focus
416            if let Ok((x, y)) = get_element_center(adapter, selector).await {
417                let _ = adapter.click_point(x, y).await;
418            }
419
420            // Insert text
421            adapter.insert_text(value).await?;
422
423            // Dispatch events
424            let event_js = format!(
425                r#"(function() {{
426                    var el = document.querySelector({sel_json});
427                    if (el) {{
428                        el.dispatchEvent(new Event('input', {{ bubbles: true }}));
429                        el.dispatchEvent(new Event('change', {{ bubbles: true }}));
430                    }}
431                }})()"#
432            );
433            adapter.evaluate(&event_js).await?;
434            Ok(())
435        }
436        AgentAction::Clear(selector) => {
437            let sel_json = serde_json::to_string(selector).unwrap_or_default();
438            let js = format!("document.querySelector({sel_json}).value = ''");
439            adapter.evaluate(&js).await?;
440            Ok(())
441        }
442        AgentAction::Press(key) => adapter.press_key(key).await,
443        AgentAction::KeyDown(key) => adapter.key_down(key).await,
444        AgentAction::KeyUp(key) => adapter.key_up(key).await,
445
446        // ----- Select & Focus -----
447        AgentAction::Select { selector, value } => {
448            let sel_json = serde_json::to_string(selector).unwrap_or_default();
449            let val_json = serde_json::to_string(value).unwrap_or_default();
450            let js = format!(
451                r#"(function() {{
452                    var el = document.querySelector({sel_json});
453                    if (el) {{
454                        el.value = {val_json};
455                        el.dispatchEvent(new Event('change', {{ bubbles: true }}));
456                    }}
457                }})()"#
458            );
459            adapter.evaluate(&js).await?;
460            Ok(())
461        }
462        AgentAction::Focus(selector) => {
463            let sel_json = serde_json::to_string(selector).unwrap_or_default();
464            adapter
465                .evaluate(&format!(
466                    "document.querySelector({sel_json})?.focus()"
467                ))
468                .await?;
469            Ok(())
470        }
471        AgentAction::Blur(selector) => {
472            let sel_json = serde_json::to_string(selector).unwrap_or_default();
473            adapter
474                .evaluate(&format!(
475                    "document.querySelector({sel_json})?.blur()"
476                ))
477                .await?;
478            Ok(())
479        }
480        AgentAction::Hover(selector) => {
481            let (x, y) = get_element_center(adapter, selector).await?;
482            adapter.hover_point(x, y).await
483        }
484        AgentAction::HoverPoint { x, y } => adapter.hover_point(*x, *y).await,
485
486        // ----- Scroll actions -----
487        AgentAction::ScrollY(delta) => {
488            adapter
489                .evaluate(&format!("window.scrollBy(0, {delta})"))
490                .await?;
491            Ok(())
492        }
493        AgentAction::ScrollX(delta) => {
494            adapter
495                .evaluate(&format!("window.scrollBy({delta}, 0)"))
496                .await?;
497            Ok(())
498        }
499        AgentAction::ScrollTo { selector } => {
500            let sel_json = serde_json::to_string(selector).unwrap_or_default();
501            adapter
502                .evaluate(&format!(
503                    "document.querySelector({sel_json})?.scrollIntoView({{ behavior: 'smooth', block: 'center' }})"
504                ))
505                .await?;
506            Ok(())
507        }
508        AgentAction::ScrollToPoint { x, y } => {
509            adapter
510                .evaluate(&format!("window.scrollTo({x}, {y})"))
511                .await?;
512            Ok(())
513        }
514        AgentAction::InfiniteScroll(max) => {
515            for _ in 0..*max {
516                adapter
517                    .evaluate("window.scrollTo(0, document.body.scrollHeight)")
518                    .await?;
519                sleep(Duration::from_millis(500)).await;
520            }
521            Ok(())
522        }
523
524        // ----- Wait actions -----
525        AgentAction::Wait(ms) => {
526            sleep(Duration::from_millis(*ms)).await;
527            Ok(())
528        }
529        AgentAction::WaitFor(selector) => wait_for_element(adapter, selector, 5000).await,
530        AgentAction::WaitForWithTimeout { selector, timeout } => {
531            wait_for_element(adapter, selector, *timeout).await
532        }
533        AgentAction::WaitForNavigation => {
534            sleep(Duration::from_millis(1000)).await;
535            Ok(())
536        }
537        AgentAction::WaitForDom { timeout, .. } => {
538            sleep(Duration::from_millis(*timeout)).await;
539            Ok(())
540        }
541
542        // ----- Navigation actions -----
543        AgentAction::Navigate(url) => adapter.navigate(url).await,
544        AgentAction::GoBack => {
545            adapter.evaluate("window.history.back()").await?;
546            Ok(())
547        }
548        AgentAction::GoForward => {
549            adapter.evaluate("window.history.forward()").await?;
550            Ok(())
551        }
552        AgentAction::Reload => {
553            adapter.evaluate("window.location.reload()").await?;
554            Ok(())
555        }
556
557        // ----- Viewport -----
558        AgentAction::SetViewport {
559            width,
560            height,
561            device_scale_factor,
562            mobile,
563        } => {
564            adapter
565                .set_viewport(
566                    *width,
567                    *height,
568                    device_scale_factor.unwrap_or(2.0),
569                    mobile.unwrap_or(false),
570                )
571                .await
572        }
573
574        // ----- JavaScript -----
575        AgentAction::Evaluate(code) => {
576            adapter.evaluate(code).await?;
577            Ok(())
578        }
579
580        // ----- Screenshot (no-op in client -- handled by agent loop) -----
581        AgentAction::Screenshot => Ok(()),
582    }
583}
584
585// -------------------------------------------------------------------
586// Helpers (mirrors agent.rs get_element_center / wait_for_element)
587// -------------------------------------------------------------------
588
589/// Get the center point of a DOM element, scrolling it into view first.
590async fn get_element_center(adapter: &ProtocolAdapter, selector: &str) -> Result<(f64, f64)> {
591    let sel_json = serde_json::to_string(selector).unwrap_or_default();
592    let js = format!(
593        r#"(function() {{
594            var el = document.querySelector({sel_json});
595            if (!el) return null;
596            el.scrollIntoView({{ block: 'center', behavior: 'instant' }});
597            var r = el.getBoundingClientRect();
598            return {{ x: r.x + r.width / 2, y: r.y + r.height / 2 }};
599        }})()"#
600    );
601
602    let result = adapter.evaluate(&js).await?;
603    if result.is_null() {
604        return Err(SpiderError::Other(format!(
605            "Element not found: {selector}"
606        )));
607    }
608
609    let x = result
610        .get("x")
611        .and_then(|v| v.as_f64())
612        .ok_or_else(|| SpiderError::Other(format!("Missing x for element: {selector}")))?;
613    let y = result
614        .get("y")
615        .and_then(|v| v.as_f64())
616        .ok_or_else(|| SpiderError::Other(format!("Missing y for element: {selector}")))?;
617
618    Ok((x, y))
619}
620
621/// Poll for a DOM element until it appears or timeout expires.
622async fn wait_for_element(adapter: &ProtocolAdapter, selector: &str, timeout_ms: u64) -> Result<()> {
623    let interval = 100u64;
624    let max_iter = (timeout_ms + interval - 1) / interval;
625    let sel_json = serde_json::to_string(selector).unwrap_or_default();
626    let check_js = format!("!!document.querySelector({sel_json})");
627
628    for _ in 0..max_iter {
629        let found = adapter.evaluate(&check_js).await?;
630        if found.as_bool().unwrap_or(false) {
631            return Ok(());
632        }
633        sleep(Duration::from_millis(interval)).await;
634    }
635
636    Err(SpiderError::Timeout(format!(
637        "Timeout waiting for element: {selector}"
638    )))
639}