Skip to main content

spider_agent/
tools.rs

1//! Custom tool support for external API calls.
2
3use dashmap::DashMap;
4use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
5use serde::{Deserialize, Serialize};
6use std::sync::Arc;
7use std::time::Duration;
8
9use crate::error::{AgentError, AgentResult};
10
/// Default base URL for the Spider Cloud REST API.
const DEFAULT_SPIDER_CLOUD_API_URL: &str = "https://api.spider.cloud";
/// Default header that carries the Spider Cloud API key.
const DEFAULT_SPIDER_CLOUD_AUTH_HEADER: &str = "Authorization";
/// Default prefix for generated Spider Cloud tool names.
const DEFAULT_TOOL_PREFIX: &str = "spider_cloud";

/// Trim `value` and remove a leading, case-insensitive `Bearer ` scheme prefix.
///
/// Whitespace after the prefix is also removed, so `"  bearer   abc "` yields
/// `"abc"`. Values without the prefix are returned trimmed but otherwise
/// unchanged.
fn strip_bearer_prefix(value: &str) -> &str {
    const PREFIX: &[u8] = b"bearer ";
    let trimmed = value.trim();
    // Compare raw bytes instead of slicing `trimmed[..7]`: a `str` slice panics
    // when byte 7 is not a char boundary (e.g. a key containing non-ASCII text).
    match trimmed.as_bytes().get(..PREFIX.len()) {
        // The matched bytes are all ASCII, so `PREFIX.len()` is a char boundary
        // and this slice cannot panic.
        Some(head) if head.eq_ignore_ascii_case(PREFIX) => trimmed[PREFIX.len()..].trim_start(),
        _ => trimmed,
    }
}
23
24/// HTTP method for API calls.
25#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
26pub enum HttpMethod {
27    /// GET request.
28    Get,
29    /// POST request.
30    Post,
31    /// PUT request.
32    Put,
33    /// PATCH request.
34    Patch,
35    /// DELETE request.
36    Delete,
37}
38
39impl HttpMethod {
40    fn as_reqwest_method(&self) -> reqwest::Method {
41        match self {
42            HttpMethod::Get => reqwest::Method::GET,
43            HttpMethod::Post => reqwest::Method::POST,
44            HttpMethod::Put => reqwest::Method::PUT,
45            HttpMethod::Patch => reqwest::Method::PATCH,
46            HttpMethod::Delete => reqwest::Method::DELETE,
47        }
48    }
49}
50
/// Authentication configuration for custom tools.
///
/// `Debug` is implemented by hand so that secrets never end up in logs. Tokens,
/// API keys, passwords and custom header values print as `<redacted>`, while
/// header names and usernames are shown so configurations stay diagnosable.
#[derive(Clone)]
pub enum AuthConfig {
    /// No authentication.
    None,
    /// Bearer token authentication.
    Bearer(String),
    /// API key in header.
    ApiKey {
        /// Header name for the API key.
        header: String,
        /// API key value.
        key: String,
    },
    /// Basic authentication.
    Basic {
        /// Username.
        username: String,
        /// Password.
        password: String,
    },
    /// Custom header authentication.
    CustomHeader {
        /// Header name.
        name: String,
        /// Header value.
        value: String,
    },
}

impl std::fmt::Debug for AuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `format_args!` prints the placeholder without surrounding quotes.
        let redacted = format_args!("<redacted>");
        match self {
            AuthConfig::None => f.write_str("None"),
            AuthConfig::Bearer(_) => f.debug_tuple("Bearer").field(&redacted).finish(),
            AuthConfig::ApiKey { header, .. } => f
                .debug_struct("ApiKey")
                .field("header", header)
                .field("key", &redacted)
                .finish(),
            AuthConfig::Basic { username, .. } => f
                .debug_struct("Basic")
                .field("username", username)
                .field("password", &redacted)
                .finish(),
            AuthConfig::CustomHeader { name, .. } => f
                .debug_struct("CustomHeader")
                .field("name", name)
                .field("value", &redacted)
                .finish(),
        }
    }
}
80
/// Configuration for Spider Cloud tool registration.
///
/// By default this registers core routes:
/// - `/crawl`
/// - `/scrape`
/// - `/search`
/// - `/links`
/// - `/transform`
/// - `/unblocker`
///
/// AI routes are disabled by default and must be explicitly enabled with
/// `with_enable_ai_routes(true)` because they require a Spider Cloud AI plan.
///
/// Because of `#[serde(default)]`, any field missing from serialized input
/// takes its value from [`SpiderCloudToolConfig::default`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(default)]
pub struct SpiderCloudToolConfig {
    /// Spider Cloud API key.
    ///
    /// Surrounding whitespace is trimmed when the auth header is built. When the
    /// `Authorization` header is used, a leading `Bearer ` prefix is also
    /// removed, so it is never doubled.
    pub api_key: String,
    /// Spider Cloud API base URL. A trailing `/` is ignored when routes are
    /// joined onto it.
    pub api_url: String,
    /// Prefix used for registered tool names.
    ///
    /// Default: `spider_cloud`, resulting in names like `spider_cloud_scrape`.
    /// Set to empty string for unprefixed names (`scrape`, `search`, etc.).
    /// Surrounding whitespace and trailing `_` characters are trimmed.
    pub tool_name_prefix: String,
    /// Header used for API key auth. Defaults to `Authorization`.
    ///
    /// It is compared case-insensitively against `Authorization` to decide
    /// whether the bearer handling described on `use_bearer_auth` applies.
    pub auth_header: String,
    /// Whether to use `Bearer <key>` formatting for the Authorization header.
    ///
    /// Spider Cloud expects raw `Authorization: <key>` by default, so this is
    /// `false` unless explicitly enabled. Ignored when `auth_header` is not
    /// `Authorization`.
    pub use_bearer_auth: bool,
    /// Request timeout in seconds for each tool call (default: 60).
    ///
    /// The `with_timeout_secs` builder raises values below 1 to 1.
    pub timeout_secs: u64,
    /// Register `/crawl`.
    pub include_crawl: bool,
    /// Register `/scrape`.
    pub include_scrape: bool,
    /// Register `/search`.
    pub include_search: bool,
    /// Register `/links`.
    pub include_links: bool,
    /// Register `/transform`.
    pub include_transform: bool,
    /// Register `/unblocker`.
    pub include_unblocker: bool,
    /// Register `/ai/*` routes (`ai/crawl`, `ai/scrape`, `ai/search`,
    /// `ai/browser`, `ai/links`).
    ///
    /// These routes require a paid Spider Cloud AI subscription:
    /// <https://spider.cloud/ai/pricing>
    pub enable_ai_routes: bool,
}
132
133impl Default for SpiderCloudToolConfig {
134    fn default() -> Self {
135        Self {
136            api_key: String::new(),
137            api_url: DEFAULT_SPIDER_CLOUD_API_URL.to_string(),
138            tool_name_prefix: DEFAULT_TOOL_PREFIX.to_string(),
139            auth_header: DEFAULT_SPIDER_CLOUD_AUTH_HEADER.to_string(),
140            use_bearer_auth: false,
141            timeout_secs: 60,
142            include_crawl: true,
143            include_scrape: true,
144            include_search: true,
145            include_links: true,
146            include_transform: true,
147            include_unblocker: true,
148            enable_ai_routes: false,
149        }
150    }
151}
152
153impl SpiderCloudToolConfig {
154    /// Create a Spider Cloud config with core routes enabled.
155    pub fn new(api_key: impl Into<String>) -> Self {
156        Self {
157            api_key: api_key.into(),
158            ..Self::default()
159        }
160    }
161
162    /// Set Spider Cloud API base URL.
163    pub fn with_api_url(mut self, api_url: impl Into<String>) -> Self {
164        self.api_url = api_url.into();
165        self
166    }
167
168    /// Set the prefix for generated tool names.
169    ///
170    /// Example:
171    /// - prefix `spider_cloud` -> `spider_cloud_search`
172    /// - prefix `web_api` -> `web_api_search`
173    /// - empty prefix -> `search`
174    pub fn with_tool_name_prefix(mut self, prefix: impl Into<String>) -> Self {
175        self.tool_name_prefix = prefix.into();
176        self
177    }
178
179    /// Set auth header name. Use non-default header names for custom gateways.
180    pub fn with_auth_header(mut self, auth_header: impl Into<String>) -> Self {
181        self.auth_header = auth_header.into();
182        self
183    }
184
185    /// Enable/disable Bearer formatting for Authorization auth.
186    ///
187    /// When `true`, sends `Authorization: Bearer <key>`.
188    /// When `false` (default), sends `Authorization: <key>`.
189    pub fn with_bearer_auth(mut self, enabled: bool) -> Self {
190        self.use_bearer_auth = enabled;
191        self
192    }
193
194    /// Set timeout in seconds for each registered tool.
195    pub fn with_timeout_secs(mut self, timeout_secs: u64) -> Self {
196        self.timeout_secs = timeout_secs.max(1);
197        self
198    }
199
200    /// Enable or disable `/unblocker` route registration.
201    pub fn with_unblocker(mut self, enabled: bool) -> Self {
202        self.include_unblocker = enabled;
203        self
204    }
205
206    /// Enable or disable `/transform` route registration.
207    pub fn with_transform(mut self, enabled: bool) -> Self {
208        self.include_transform = enabled;
209        self
210    }
211
212    /// Enable or disable AI route registration.
213    ///
214    /// AI routes require a paid Spider Cloud AI plan:
215    /// https://spider.cloud/ai/pricing
216    pub fn with_enable_ai_routes(mut self, enabled: bool) -> Self {
217        self.enable_ai_routes = enabled;
218        self
219    }
220
221    fn endpoint(&self, route: &str) -> String {
222        format!(
223            "{}/{}",
224            self.api_url.trim_end_matches('/'),
225            route.trim_start_matches('/')
226        )
227    }
228
229    fn tool_name(&self, suffix: &str) -> String {
230        let prefix = self.tool_name_prefix.trim().trim_end_matches('_');
231        if prefix.is_empty() {
232            suffix.to_string()
233        } else {
234            format!("{}_{}", prefix, suffix)
235        }
236    }
237
238    fn auth_tool(&self, tool: CustomTool) -> CustomTool {
239        if self
240            .auth_header
241            .eq_ignore_ascii_case(DEFAULT_SPIDER_CLOUD_AUTH_HEADER)
242        {
243            // Accept env inputs like `SPIDER_CLOUD_API_KEY=...` and
244            // `SPIDER_CLOUD_API_KEY=Bearer ...` without double-prefixing.
245            let token = strip_bearer_prefix(&self.api_key).to_string();
246            if self.use_bearer_auth {
247                tool.with_bearer_auth(token)
248            } else {
249                tool.with_api_key(self.auth_header.clone(), token)
250            }
251        } else {
252            tool.with_api_key(self.auth_header.clone(), self.api_key.trim().to_string())
253        }
254    }
255
256    fn build_tool(&self, name: &str, route: &str, description: &str) -> CustomTool {
257        let tool = CustomTool::new(name, self.endpoint(route))
258            .with_description(description)
259            .with_method(HttpMethod::Post)
260            .with_content_type("application/json")
261            .with_timeout(Duration::from_secs(self.timeout_secs))
262            .with_header(
263                "User-Agent",
264                format!("spider_agent/{}", env!("CARGO_PKG_VERSION")),
265            );
266        self.auth_tool(tool)
267    }
268
269    /// Build Spider Cloud tools from this configuration.
270    pub fn to_custom_tools(&self) -> Vec<CustomTool> {
271        let mut tools = Vec::new();
272
273        if self.include_crawl {
274            tools.push(self.build_tool(
275                &self.tool_name("crawl"),
276                "crawl",
277                "Spider Cloud /crawl endpoint for crawling and extraction.",
278            ));
279        }
280        if self.include_scrape {
281            tools.push(self.build_tool(
282                &self.tool_name("scrape"),
283                "scrape",
284                "Spider Cloud /scrape endpoint for page scraping and extraction.",
285            ));
286        }
287        if self.include_search {
288            tools.push(self.build_tool(
289                &self.tool_name("search"),
290                "search",
291                "Spider Cloud /search endpoint for web search plus page retrieval.",
292            ));
293        }
294        if self.include_links {
295            tools.push(self.build_tool(
296                &self.tool_name("links"),
297                "links",
298                "Spider Cloud /links endpoint for link extraction only.",
299            ));
300        }
301        if self.include_transform {
302            tools.push(self.build_tool(
303                &self.tool_name("transform"),
304                "transform",
305                "Spider Cloud /transform endpoint for structured content transformation.",
306            ));
307        }
308        if self.include_unblocker {
309            tools.push(self.build_tool(
310                &self.tool_name("unblocker"),
311                "unblocker",
312                "Spider Cloud /unblocker endpoint for anti-bot bypass and hard-to-reach pages.",
313            ));
314        }
315
316        if self.enable_ai_routes {
317            tools.push(self.build_tool(
318                &self.tool_name("ai_crawl"),
319                "ai/crawl",
320                "Spider Cloud /ai/crawl endpoint for AI-guided crawling (AI subscription required).",
321            ));
322            tools.push(self.build_tool(
323                &self.tool_name("ai_scrape"),
324                "ai/scrape",
325                "Spider Cloud /ai/scrape endpoint for AI-guided scraping (AI subscription required).",
326            ));
327            tools.push(self.build_tool(
328                &self.tool_name("ai_search"),
329                "ai/search",
330                "Spider Cloud /ai/search endpoint for AI-enhanced search (AI subscription required).",
331            ));
332            tools.push(self.build_tool(
333                &self.tool_name("ai_browser"),
334                "ai/browser",
335                "Spider Cloud /ai/browser endpoint for AI browser automation (AI subscription required).",
336            ));
337            tools.push(self.build_tool(
338                &self.tool_name("ai_links"),
339                "ai/links",
340                "Spider Cloud /ai/links endpoint for AI link extraction (AI subscription required).",
341            ));
342        }
343
344        tools
345    }
346}
347
348// ─── Spider Browser Cloud ────────────────────────────────────────────────────
349
/// Default WebSocket endpoint for Spider Browser Cloud sessions.
const DEFAULT_SPIDER_BROWSER_WSS_URL: &str = "wss://browser.spider.cloud/v1/browser";
/// Default prefix for generated Spider Browser tool names.
const DEFAULT_BROWSER_TOOL_PREFIX: &str = "spider_browser";
352
/// Configuration for [Spider Browser Cloud](https://spider.cloud/docs/api#browser)
/// tool registration.
///
/// Registers tools that interact with a remote CDP browser session at
/// `wss://browser.spider.cloud/v1/browser?token=API_KEY`.
///
/// Tools registered by default (names assume the default `spider_browser`
/// prefix):
/// - `spider_browser_navigate` — navigate to a URL
/// - `spider_browser_html` — get page HTML
/// - `spider_browser_screenshot` — take a screenshot
/// - `spider_browser_evaluate` — execute JavaScript
/// - `spider_browser_click` — click a CSS selector
/// - `spider_browser_fill` — fill an input element
/// - `spider_browser_wait` — wait for a selector
///
/// Because of `#[serde(default)]`, missing fields take their values from
/// [`SpiderBrowserToolConfig::default`].
///
/// NOTE(review): the derived `Debug` prints `api_key` verbatim; avoid logging
/// this struct with `{:?}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(default)]
pub struct SpiderBrowserToolConfig {
    /// Spider Cloud API key, sent as the `token` query parameter of
    /// `connection_url`.
    pub api_key: String,
    /// WebSocket base URL for the browser endpoint. `connection_url` appends
    /// its query parameters with `&` when this already contains `?`.
    pub wss_url: String,
    /// Prefix used for registered tool names (default: `spider_browser`).
    pub tool_name_prefix: String,
    /// Enable stealth mode (anti-fingerprinting); adds `stealth=true` to the
    /// connection URL when set.
    pub stealth: bool,
    /// Browser type (e.g. `"chrome"`, `"firefox"`), sent as the `browser`
    /// query parameter when set.
    pub browser: Option<String>,
    /// Country code for geo-targeting (e.g. `"us"`, `"gb"`), sent as the
    /// `country` query parameter when set.
    pub country: Option<String>,
    /// Request timeout in seconds for each tool call (default: 60).
    ///
    /// The `with_timeout_secs` builder raises values below 1 to 1.
    pub timeout_secs: u64,
    /// Register navigate tool.
    pub include_navigate: bool,
    /// Register HTML extraction tool.
    pub include_html: bool,
    /// Register screenshot tool.
    pub include_screenshot: bool,
    /// Register JavaScript evaluation tool.
    pub include_evaluate: bool,
    /// Register click tool.
    pub include_click: bool,
    /// Register fill (type text) tool.
    pub include_fill: bool,
    /// Register wait-for-selector tool.
    pub include_wait: bool,
}
399
400impl Default for SpiderBrowserToolConfig {
401    fn default() -> Self {
402        Self {
403            api_key: String::new(),
404            wss_url: DEFAULT_SPIDER_BROWSER_WSS_URL.to_string(),
405            tool_name_prefix: DEFAULT_BROWSER_TOOL_PREFIX.to_string(),
406            stealth: false,
407            browser: None,
408            country: None,
409            timeout_secs: 60,
410            include_navigate: true,
411            include_html: true,
412            include_screenshot: true,
413            include_evaluate: true,
414            include_click: true,
415            include_fill: true,
416            include_wait: true,
417        }
418    }
419}
420
421impl SpiderBrowserToolConfig {
422    /// Create a config with the given API key.
423    pub fn new(api_key: impl Into<String>) -> Self {
424        Self {
425            api_key: api_key.into(),
426            ..Self::default()
427        }
428    }
429
430    /// Set a custom WSS base URL.
431    pub fn with_wss_url(mut self, url: impl Into<String>) -> Self {
432        self.wss_url = url.into();
433        self
434    }
435
436    /// Set the tool name prefix.
437    pub fn with_tool_name_prefix(mut self, prefix: impl Into<String>) -> Self {
438        self.tool_name_prefix = prefix.into();
439        self
440    }
441
442    /// Enable or disable stealth mode.
443    pub fn with_stealth(mut self, stealth: bool) -> Self {
444        self.stealth = stealth;
445        self
446    }
447
448    /// Set the browser type to request.
449    pub fn with_browser(mut self, browser: impl Into<String>) -> Self {
450        self.browser = Some(browser.into());
451        self
452    }
453
454    /// Set the country for geo-targeting.
455    pub fn with_country(mut self, country: impl Into<String>) -> Self {
456        self.country = Some(country.into());
457        self
458    }
459
460    /// Set timeout in seconds for each registered tool.
461    pub fn with_timeout_secs(mut self, timeout_secs: u64) -> Self {
462        self.timeout_secs = timeout_secs.max(1);
463        self
464    }
465
466    /// Build the full WSS connection URL with authentication and options.
467    pub fn connection_url(&self) -> String {
468        let mut url = self.wss_url.clone();
469        if url.contains('?') {
470            url.push('&');
471        } else {
472            url.push('?');
473        }
474        url.push_str("token=");
475        url.push_str(&self.api_key);
476
477        if self.stealth {
478            url.push_str("&stealth=true");
479        }
480        if let Some(ref browser) = self.browser {
481            url.push_str("&browser=");
482            url.push_str(browser);
483        }
484        if let Some(ref country) = self.country {
485            url.push_str("&country=");
486            url.push_str(country);
487        }
488        url
489    }
490
491    fn tool_name(&self, suffix: &str) -> String {
492        let prefix = self.tool_name_prefix.trim().trim_end_matches('_');
493        if prefix.is_empty() {
494            suffix.to_string()
495        } else {
496            format!("{}_{}", prefix, suffix)
497        }
498    }
499
500    /// Build the browser tools as custom tool definitions.
501    ///
502    /// These tools use the Spider Browser Cloud REST-like interface where
503    /// each tool POSTs a JSON action body to the browser endpoint.
504    pub fn to_custom_tools(&self) -> Vec<CustomTool> {
505        let mut tools = Vec::new();
506        let base = self.connection_url();
507
508        let build = |name: &str, desc: &str| -> CustomTool {
509            CustomTool::new(name, &base)
510                .with_description(desc)
511                .with_method(HttpMethod::Post)
512                .with_content_type("application/json")
513                .with_timeout(Duration::from_secs(self.timeout_secs))
514                .with_header(
515                    "User-Agent",
516                    format!("spider_agent/{}", env!("CARGO_PKG_VERSION")),
517                )
518        };
519
520        if self.include_navigate {
521            tools.push(build(
522                &self.tool_name("navigate"),
523                "Spider Browser Cloud: navigate to a URL. Body: {\"url\": \"...\"}",
524            ));
525        }
526        if self.include_html {
527            tools.push(build(
528                &self.tool_name("html"),
529                "Spider Browser Cloud: extract HTML from current page.",
530            ));
531        }
532        if self.include_screenshot {
533            tools.push(build(
534                &self.tool_name("screenshot"),
535                "Spider Browser Cloud: take a screenshot of the current page.",
536            ));
537        }
538        if self.include_evaluate {
539            tools.push(build(
540                &self.tool_name("evaluate"),
541                "Spider Browser Cloud: evaluate JavaScript on the page. Body: {\"script\": \"...\"}",
542            ));
543        }
544        if self.include_click {
545            tools.push(build(
546                &self.tool_name("click"),
547                "Spider Browser Cloud: click an element by CSS selector. Body: {\"selector\": \"...\"}",
548            ));
549        }
550        if self.include_fill {
551            tools.push(build(
552                &self.tool_name("fill"),
553                "Spider Browser Cloud: fill an input element. Body: {\"selector\": \"...\", \"value\": \"...\"}",
554            ));
555        }
556        if self.include_wait {
557            tools.push(build(
558                &self.tool_name("wait"),
559                "Spider Browser Cloud: wait for a CSS selector to appear. Body: {\"selector\": \"...\"}",
560            ));
561        }
562
563        tools
564    }
565}
566
/// Configuration for a custom tool (external API call).
///
/// Build one with [`CustomTool::new`] and the `with_*` builder methods, then
/// register it in a [`CustomToolRegistry`] and run it with
/// [`CustomToolRegistry::execute`].
#[derive(Debug, Clone)]
pub struct CustomTool {
    /// Unique name for this tool; it is the key used by [`CustomToolRegistry`].
    pub name: String,
    /// Description of what this tool does.
    pub description: String,
    /// Base URL for the API. `execute` optionally appends a path to it.
    pub base_url: String,
    /// Default HTTP method (`Get` for tools created with `new`).
    pub method: HttpMethod,
    /// Authentication configuration.
    pub auth: AuthConfig,
    /// Additional headers.
    ///
    /// These are inserted after the auth and content-type headers, so an
    /// entry with the same name replaces them.
    pub headers: Vec<(String, String)>,
    /// Request timeout (30 seconds for tools created with `new`).
    pub timeout: Duration,
    /// Content type for requests, sent as the `Content-Type` header when set.
    pub content_type: Option<String>,
}
587
588impl CustomTool {
589    /// Create a new custom tool with GET method.
590    pub fn new(name: impl Into<String>, base_url: impl Into<String>) -> Self {
591        Self {
592            name: name.into(),
593            description: String::new(),
594            base_url: base_url.into(),
595            method: HttpMethod::Get,
596            auth: AuthConfig::None,
597            headers: Vec::new(),
598            timeout: Duration::from_secs(30),
599            content_type: None,
600        }
601    }
602
603    /// Set description.
604    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
605        self.description = desc.into();
606        self
607    }
608
609    /// Set HTTP method.
610    pub fn with_method(mut self, method: HttpMethod) -> Self {
611        self.method = method;
612        self
613    }
614
615    /// Set bearer token authentication.
616    pub fn with_bearer_auth(mut self, token: impl Into<String>) -> Self {
617        self.auth = AuthConfig::Bearer(token.into());
618        self
619    }
620
621    /// Set API key authentication.
622    pub fn with_api_key(mut self, header: impl Into<String>, key: impl Into<String>) -> Self {
623        self.auth = AuthConfig::ApiKey {
624            header: header.into(),
625            key: key.into(),
626        };
627        self
628    }
629
630    /// Set basic authentication.
631    pub fn with_basic_auth(
632        mut self,
633        username: impl Into<String>,
634        password: impl Into<String>,
635    ) -> Self {
636        self.auth = AuthConfig::Basic {
637            username: username.into(),
638            password: password.into(),
639        };
640        self
641    }
642
643    /// Set custom header authentication.
644    pub fn with_custom_auth(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
645        self.auth = AuthConfig::CustomHeader {
646            name: name.into(),
647            value: value.into(),
648        };
649        self
650    }
651
652    /// Add a custom header.
653    pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
654        self.headers.push((name.into(), value.into()));
655        self
656    }
657
658    /// Set request timeout.
659    pub fn with_timeout(mut self, timeout: Duration) -> Self {
660        self.timeout = timeout;
661        self
662    }
663
664    /// Set content type.
665    pub fn with_content_type(mut self, content_type: impl Into<String>) -> Self {
666        self.content_type = Some(content_type.into());
667        self
668    }
669
670    /// Build the headers for a request.
671    fn build_headers(&self) -> AgentResult<HeaderMap> {
672        let mut headers = HeaderMap::new();
673
674        // Add authentication
675        match &self.auth {
676            AuthConfig::None => {}
677            AuthConfig::Bearer(token) => {
678                headers.insert(
679                    reqwest::header::AUTHORIZATION,
680                    HeaderValue::from_str(&format!("Bearer {}", token))
681                        .map_err(|e| AgentError::Tool(format!("Invalid bearer token: {}", e)))?,
682                );
683            }
684            AuthConfig::ApiKey { header, key } => {
685                let header_name = HeaderName::try_from(header.as_str())
686                    .map_err(|e| AgentError::Tool(format!("Invalid header name: {}", e)))?;
687                let header_value = HeaderValue::from_str(key)
688                    .map_err(|e| AgentError::Tool(format!("Invalid API key: {}", e)))?;
689                headers.insert(header_name, header_value);
690            }
691            AuthConfig::Basic { username, password } => {
692                let credentials = base64::Engine::encode(
693                    &base64::engine::general_purpose::STANDARD,
694                    format!("{}:{}", username, password),
695                );
696                headers.insert(
697                    reqwest::header::AUTHORIZATION,
698                    HeaderValue::from_str(&format!("Basic {}", credentials))
699                        .map_err(|e| AgentError::Tool(format!("Invalid basic auth: {}", e)))?,
700                );
701            }
702            AuthConfig::CustomHeader { name, value } => {
703                let header_name = HeaderName::try_from(name.as_str())
704                    .map_err(|e| AgentError::Tool(format!("Invalid header name: {}", e)))?;
705                let header_value = HeaderValue::from_str(value)
706                    .map_err(|e| AgentError::Tool(format!("Invalid header value: {}", e)))?;
707                headers.insert(header_name, header_value);
708            }
709        }
710
711        // Add content type if specified
712        if let Some(ref ct) = self.content_type {
713            headers.insert(
714                reqwest::header::CONTENT_TYPE,
715                HeaderValue::from_str(ct)
716                    .map_err(|e| AgentError::Tool(format!("Invalid content type: {}", e)))?,
717            );
718        }
719
720        // Add custom headers
721        for (name, value) in &self.headers {
722            let header_name = HeaderName::try_from(name.as_str())
723                .map_err(|e| AgentError::Tool(format!("Invalid header name '{}': {}", name, e)))?;
724            let header_value = HeaderValue::from_str(value).map_err(|e| {
725                AgentError::Tool(format!("Invalid header value for '{}': {}", name, e))
726            })?;
727            headers.insert(header_name, header_value);
728        }
729
730        Ok(headers)
731    }
732}
733
/// Result from executing a custom tool with [`CustomToolRegistry::execute`].
///
/// Non-2xx responses are returned here with `success == false`; they are not
/// converted into errors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CustomToolResult {
    /// The tool name that was executed.
    pub tool_name: String,
    /// HTTP status code.
    pub status: u16,
    /// Response body, decoded as text.
    pub body: String,
    /// Response headers as `(name, value)` pairs.
    pub headers: Vec<(String, String)>,
    /// Whether the request was successful (2xx status).
    pub success: bool,
}
748
/// Registry for custom tools.
///
/// Tools are stored in a `DashMap` keyed by [`CustomTool::name`], and all
/// methods take `&self`. Registering a tool whose name already exists replaces
/// the earlier one. Tools are held behind `Arc`, so lookups hand out cheap
/// shared handles.
#[derive(Debug, Default)]
pub struct CustomToolRegistry {
    tools: DashMap<String, Arc<CustomTool>>,
}
754
755impl CustomToolRegistry {
756    /// Create a new empty registry.
757    pub fn new() -> Self {
758        Self {
759            tools: DashMap::new(),
760        }
761    }
762
763    /// Register a custom tool.
764    pub fn register(&self, tool: CustomTool) {
765        self.tools.insert(tool.name.clone(), Arc::new(tool));
766    }
767
768    /// Get a tool by name.
769    pub fn get(&self, name: &str) -> Option<Arc<CustomTool>> {
770        self.tools.get(name).map(|r| r.clone())
771    }
772
773    /// Remove a tool.
774    pub fn remove(&self, name: &str) -> Option<Arc<CustomTool>> {
775        self.tools.remove(name).map(|(_, v)| v)
776    }
777
778    /// List all registered tools.
779    pub fn list(&self) -> Vec<String> {
780        self.tools.iter().map(|e| e.key().clone()).collect()
781    }
782
783    /// Check if a tool is registered.
784    pub fn contains(&self, name: &str) -> bool {
785        self.tools.contains_key(name)
786    }
787
788    /// Clear all tools.
789    pub fn clear(&self) {
790        self.tools.clear();
791    }
792
793    /// Register Spider Cloud tools from a shared config.
794    ///
795    /// Returns the number of tools registered.
796    pub fn register_spider_cloud(&self, config: &SpiderCloudToolConfig) -> usize {
797        let tools = config.to_custom_tools();
798        let count = tools.len();
799        for tool in tools {
800            self.register(tool);
801        }
802        count
803    }
804
805    /// Register Spider Browser Cloud tools from a shared config.
806    ///
807    /// Returns the number of tools registered.
808    pub fn register_spider_browser(&self, config: &SpiderBrowserToolConfig) -> usize {
809        let tools = config.to_custom_tools();
810        let count = tools.len();
811        for tool in tools {
812            self.register(tool);
813        }
814        count
815    }
816
817    /// Execute a custom tool.
818    pub async fn execute(
819        &self,
820        name: &str,
821        client: &reqwest::Client,
822        path: Option<&str>,
823        query: Option<&[(&str, &str)]>,
824        body: Option<&str>,
825    ) -> AgentResult<CustomToolResult> {
826        let tool = self
827            .get(name)
828            .ok_or_else(|| AgentError::Tool(format!("Custom tool '{}' not found", name)))?;
829
830        // Build URL
831        let mut url = tool.base_url.clone();
832        if let Some(p) = path {
833            if !url.ends_with('/') && !p.starts_with('/') {
834                url.push('/');
835            }
836            url.push_str(p);
837        }
838
839        // Build request
840        let mut request = client
841            .request(tool.method.as_reqwest_method(), &url)
842            .timeout(tool.timeout)
843            .headers(tool.build_headers()?);
844
845        // Add query parameters
846        if let Some(q) = query {
847            request = request.query(q);
848        }
849
850        // Add body
851        if let Some(b) = body {
852            request = request.body(b.to_string());
853        }
854
855        // Execute
856        let response = request.send().await?;
857
858        let status = response.status().as_u16();
859        let success = response.status().is_success();
860
861        let headers: Vec<(String, String)> = response
862            .headers()
863            .iter()
864            .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
865            .collect();
866
867        let body = response.text().await?;
868
869        Ok(CustomToolResult {
870            tool_name: name.to_string(),
871            status,
872            body,
873            headers,
874            success,
875        })
876    }
877}
878
#[cfg(test)]
mod tests {
    //! Unit tests for custom tool construction, the custom tool registry, and the
    //! Spider Cloud / Spider Browser tool generators.

    use super::*;

    /// AI routes that must only be generated when explicitly enabled.
    const SPIDER_CLOUD_AI_ROUTES: [&str; 5] = [
        "spider_cloud_ai_crawl",
        "spider_cloud_ai_scrape",
        "spider_cloud_ai_search",
        "spider_cloud_ai_browser",
        "spider_cloud_ai_links",
    ];

    // ─── Helpers ─────────────────────────────────────────────────────────

    /// Borrows the names of `tools` so membership checks compare plain `&str`s.
    fn tool_names(tools: &[CustomTool]) -> Vec<&str> {
        tools.iter().map(|t| t.name.as_str()).collect()
    }

    /// Returns the tool named `name`, panicking with the missing name if absent.
    fn find_tool<'a>(tools: &'a [CustomTool], name: &str) -> &'a CustomTool {
        tools
            .iter()
            .find(|t| t.name == name)
            .unwrap_or_else(|| panic!("expected a tool named `{}`", name))
    }

    /// Asserts `tool` sends `expected_key` verbatim in the `Authorization` header
    /// (an `AuthConfig::ApiKey`, not `AuthConfig::Bearer`).
    fn assert_raw_authorization(tool: &CustomTool, expected_key: &str) {
        assert!(
            matches!(
                tool.auth,
                AuthConfig::ApiKey {
                    ref header,
                    ref key
                } if header == "Authorization" && key == expected_key
            ),
            "tool `{}`: expected raw Authorization header auth, got {:?}",
            tool.name,
            tool.auth
        );
    }

    /// Asserts `tool` uses bearer auth carrying exactly `expected_token`.
    fn assert_bearer_token(tool: &CustomTool, expected_token: &str) {
        assert!(
            matches!(tool.auth, AuthConfig::Bearer(ref t) if t == expected_token),
            "tool `{}`: expected Bearer auth, got {:?}",
            tool.name,
            tool.auth
        );
    }

    // ─── Custom Tool Tests ───────────────────────────────────────────────

    #[test]
    fn test_custom_tool_builder() {
        let tool = CustomTool::new("my_api", "https://api.example.com")
            .with_description("My custom API")
            .with_method(HttpMethod::Post)
            .with_bearer_auth("secret_token")
            .with_header("X-Custom", "value")
            .with_timeout(Duration::from_secs(60))
            .with_content_type("application/json");

        assert_eq!(tool.name, "my_api");
        assert_eq!(tool.base_url, "https://api.example.com");
        assert_eq!(tool.description, "My custom API");
        assert_eq!(tool.method, HttpMethod::Post);
        assert_eq!(tool.timeout, Duration::from_secs(60));
        assert_eq!(tool.content_type, Some("application/json".to_string()));
        assert_eq!(tool.headers.len(), 1);
        assert_bearer_token(&tool, "secret_token");
    }

    #[test]
    fn test_custom_tool_registry() {
        let registry = CustomToolRegistry::new();

        registry.register(CustomTool::new("api_1", "https://api1.example.com"));
        registry.register(CustomTool::new("api_2", "https://api2.example.com"));

        // Membership checks.
        assert!(registry.contains("api_1"));
        assert!(registry.contains("api_2"));
        assert!(!registry.contains("api_3"));

        // Listing returns every registered name.
        let listed = registry.list();
        assert_eq!(listed.len(), 2);
        assert!(listed.contains(&"api_1".to_string()));
        assert!(listed.contains(&"api_2".to_string()));

        // Copy the URL out inside `map` so whatever `get` returns (possibly a
        // borrow/guard into the map) is released before `remove` below.
        let api_1_url = registry
            .get("api_1")
            .map(|tool| tool.base_url.clone())
            .expect("api_1 should be registered");
        assert_eq!(api_1_url, "https://api1.example.com");
        assert!(registry.get("api_3").is_none());

        // Removing one tool leaves the other untouched.
        assert!(registry.remove("api_1").is_some());
        assert!(!registry.contains("api_1"));
        assert!(registry.contains("api_2"));
        assert_eq!(registry.list().len(), 1);

        // Clearing empties the registry.
        registry.clear();
        assert!(registry.list().is_empty());
    }

    #[test]
    fn test_auth_config_variants() {
        let tool =
            CustomTool::new("test", "https://example.com").with_api_key("X-API-Key", "my_key");
        assert!(matches!(tool.auth, AuthConfig::ApiKey { .. }));

        let tool = CustomTool::new("test", "https://example.com").with_basic_auth("user", "pass");
        assert!(matches!(tool.auth, AuthConfig::Basic { .. }));

        let tool = CustomTool::new("test", "https://example.com")
            .with_custom_auth("X-Custom-Auth", "token123");
        assert!(matches!(tool.auth, AuthConfig::CustomHeader { .. }));
    }

    #[test]
    fn test_http_method_conversion() {
        assert_eq!(HttpMethod::Get.as_reqwest_method(), reqwest::Method::GET);
        assert_eq!(HttpMethod::Post.as_reqwest_method(), reqwest::Method::POST);
        assert_eq!(HttpMethod::Put.as_reqwest_method(), reqwest::Method::PUT);
        assert_eq!(
            HttpMethod::Patch.as_reqwest_method(),
            reqwest::Method::PATCH
        );
        assert_eq!(
            HttpMethod::Delete.as_reqwest_method(),
            reqwest::Method::DELETE
        );
    }

    #[test]
    fn test_custom_tool_result() {
        let result = CustomToolResult {
            tool_name: "my_api".to_string(),
            status: 200,
            body: r#"{"success": true}"#.to_string(),
            headers: vec![("content-type".to_string(), "application/json".to_string())],
            success: true,
        };

        assert_eq!(result.tool_name, "my_api");
        assert_eq!(result.status, 200);
        assert_eq!(result.body, r#"{"success": true}"#);
        assert_eq!(result.headers.len(), 1);
        assert_eq!(result.headers[0].0, "content-type");
        assert_eq!(result.headers[0].1, "application/json");
        assert!(result.success);
    }

    // ─── Spider Cloud Tests ──────────────────────────────────────────────

    #[test]
    fn test_spider_cloud_tools_default_routes_only() {
        let cfg = SpiderCloudToolConfig::new("sk_spider_cloud");
        let tools = cfg.to_custom_tools();
        let names = tool_names(&tools);

        assert_eq!(tools.len(), 6);
        for expected in [
            "spider_cloud_crawl",
            "spider_cloud_scrape",
            "spider_cloud_search",
            "spider_cloud_links",
            "spider_cloud_transform",
            "spider_cloud_unblocker",
        ] {
            assert!(
                names.contains(&expected),
                "missing default route `{}`",
                expected
            );
        }
        for ai_route in SPIDER_CLOUD_AI_ROUTES {
            assert!(
                !names.contains(&ai_route),
                "AI route `{}` should be opt-in",
                ai_route
            );
        }

        // Default auth should be raw Authorization header (not Bearer).
        assert_raw_authorization(find_tool(&tools, "spider_cloud_crawl"), "sk_spider_cloud");
    }

    #[test]
    fn test_spider_cloud_tools_with_ai_subscription_enabled() {
        let cfg = SpiderCloudToolConfig::new("sk_spider_cloud").with_enable_ai_routes(true);
        let tools = cfg.to_custom_tools();
        let names = tool_names(&tools);

        assert_eq!(tools.len(), 11);
        for ai_route in SPIDER_CLOUD_AI_ROUTES {
            assert!(
                names.contains(&ai_route),
                "enabled AI route `{}` is missing",
                ai_route
            );
        }
    }

    #[test]
    fn test_spider_cloud_registry_registration() {
        let registry = CustomToolRegistry::new();
        let cfg = SpiderCloudToolConfig::new("sk_spider_cloud")
            .with_unblocker(true)
            .with_transform(true)
            .with_enable_ai_routes(false);
        let count = registry.register_spider_cloud(&cfg);

        assert_eq!(count, 6);
        assert!(registry.contains("spider_cloud_crawl"));
        assert!(registry.contains("spider_cloud_transform"));
        assert!(registry.contains("spider_cloud_unblocker"));
        assert!(!registry.contains("spider_cloud_ai_scrape"));
    }

    #[test]
    fn test_spider_cloud_bearer_auth_opt_in() {
        let cfg = SpiderCloudToolConfig::new("sk_spider_cloud").with_bearer_auth(true);
        let tools = cfg.to_custom_tools();
        assert_bearer_token(find_tool(&tools, "spider_cloud_crawl"), "sk_spider_cloud");
    }

    #[test]
    fn test_spider_cloud_strips_bearer_prefix_in_default_mode() {
        let cfg = SpiderCloudToolConfig::new("Bearer sk_spider_cloud");
        let tools = cfg.to_custom_tools();
        assert_raw_authorization(find_tool(&tools, "spider_cloud_crawl"), "sk_spider_cloud");
    }

    #[test]
    fn test_spider_cloud_bearer_opt_in_avoids_double_prefix() {
        let cfg = SpiderCloudToolConfig::new("Bearer sk_spider_cloud").with_bearer_auth(true);
        let tools = cfg.to_custom_tools();
        assert_bearer_token(find_tool(&tools, "spider_cloud_crawl"), "sk_spider_cloud");
    }

    #[test]
    fn test_spider_cloud_custom_prefix_and_api_url() {
        let cfg = SpiderCloudToolConfig::new("sk_spider_cloud")
            .with_api_url("https://custom.provider.local/v1")
            .with_tool_name_prefix("web_api")
            .with_enable_ai_routes(false);
        let tools = cfg.to_custom_tools();

        let transform = find_tool(&tools, "web_api_transform");
        assert_eq!(
            transform.base_url,
            "https://custom.provider.local/v1/transform"
        );
    }

    #[test]
    fn test_spider_cloud_empty_prefix_uses_plain_names() {
        let cfg = SpiderCloudToolConfig::new("sk_spider_cloud").with_tool_name_prefix("");
        let tools = cfg.to_custom_tools();
        let names = tool_names(&tools);

        for expected in ["crawl", "search", "transform"] {
            assert!(
                names.contains(&expected),
                "missing unprefixed tool `{}`",
                expected
            );
        }
    }

    // ─── Spider Browser Cloud Tests ──────────────────────────────────────

    #[test]
    fn test_spider_browser_config_defaults() {
        let cfg = SpiderBrowserToolConfig::new("test-key");
        assert_eq!(cfg.api_key, "test-key");
        assert_eq!(cfg.wss_url, "wss://browser.spider.cloud/v1/browser");
        assert!(!cfg.stealth);
        assert!(cfg.browser.is_none());
        assert!(cfg.country.is_none());
        assert_eq!(cfg.timeout_secs, 60);
        assert!(cfg.include_navigate);
        assert!(cfg.include_html);
        assert!(cfg.include_screenshot);
        assert!(cfg.include_evaluate);
        assert!(cfg.include_click);
        assert!(cfg.include_fill);
        assert!(cfg.include_wait);
    }

    #[test]
    fn test_spider_browser_connection_url_basic() {
        let cfg = SpiderBrowserToolConfig::new("sk-abc");
        assert_eq!(
            cfg.connection_url(),
            "wss://browser.spider.cloud/v1/browser?token=sk-abc"
        );
    }

    #[test]
    fn test_spider_browser_connection_url_with_options() {
        let cfg = SpiderBrowserToolConfig::new("key")
            .with_stealth(true)
            .with_browser("chrome")
            .with_country("gb");
        assert_eq!(
            cfg.connection_url(),
            "wss://browser.spider.cloud/v1/browser?token=key&stealth=true&browser=chrome&country=gb"
        );
    }

    #[test]
    fn test_spider_browser_custom_wss_url() {
        let cfg =
            SpiderBrowserToolConfig::new("key").with_wss_url("wss://custom.example.com/browser");
        assert_eq!(
            cfg.connection_url(),
            "wss://custom.example.com/browser?token=key"
        );
    }

    #[test]
    fn test_spider_browser_to_custom_tools_count() {
        let cfg = SpiderBrowserToolConfig::new("key");
        let tools = cfg.to_custom_tools();
        assert_eq!(tools.len(), 7); // navigate, html, screenshot, evaluate, click, fill, wait
    }

    #[test]
    fn test_spider_browser_to_custom_tools_names() {
        let cfg = SpiderBrowserToolConfig::new("key");
        let tools = cfg.to_custom_tools();
        let names = tool_names(&tools);

        for expected in [
            "spider_browser_navigate",
            "spider_browser_html",
            "spider_browser_screenshot",
            "spider_browser_evaluate",
            "spider_browser_click",
            "spider_browser_fill",
            "spider_browser_wait",
        ] {
            assert!(
                names.contains(&expected),
                "missing browser tool `{}`",
                expected
            );
        }
    }

    #[test]
    fn test_spider_browser_custom_prefix() {
        let cfg = SpiderBrowserToolConfig::new("key").with_tool_name_prefix("remote_browser");
        let tools = cfg.to_custom_tools();
        let names = tool_names(&tools);

        assert!(names.contains(&"remote_browser_navigate"));
        assert!(names.contains(&"remote_browser_html"));
    }

    #[test]
    fn test_spider_browser_tools_use_wss_base_url() {
        let cfg = SpiderBrowserToolConfig::new("my-key").with_stealth(true);
        let tools = cfg.to_custom_tools();

        // Guard against a vacuous pass: the loop below checks nothing if empty.
        assert!(!tools.is_empty(), "expected browser tools to be generated");
        for tool in &tools {
            assert!(
                tool.base_url
                    .starts_with("wss://browser.spider.cloud/v1/browser?token=my-key"),
                "tool `{}` has unexpected base_url `{}`",
                tool.name,
                tool.base_url
            );
            assert!(
                tool.base_url.contains("stealth=true"),
                "tool `{}` base_url is missing stealth flag",
                tool.name
            );
            assert_eq!(tool.method, HttpMethod::Post);
        }
    }

    #[test]
    fn test_spider_browser_registry_register() {
        let registry = CustomToolRegistry::new();
        let cfg = SpiderBrowserToolConfig::new("key");
        let count = registry.register_spider_browser(&cfg);
        assert_eq!(count, 7);
        assert!(registry.contains("spider_browser_navigate"));
        assert!(registry.contains("spider_browser_html"));
    }
}