Skip to main content

sh_layer3/builtin_tools/
web_search.rs

//! # Web Search Tool
//!
//! Web search functionality supporting multiple search engines
//! (DuckDuckGo, Google Custom Search, and Bing).
//! Provides request rate limiting, in-memory result caching, and descriptive
//! error messages for failed requests.
5
6use crate::builtin_tools::BuiltinTool;
7use crate::types::{Layer3Result, ToolCategory};
8use async_trait::async_trait;
9use parking_lot::RwLock;
10use reqwest::Client;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::{Duration, Instant};
15
16// ============================================================================
17// Search Engine Configuration
18// ============================================================================
19
/// Search backends supported by [`WebSearchTool`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SearchEngine {
    /// DuckDuckGo Instant Answer API (free, no API key required)
    #[default]
    DuckDuckGo,
    /// Google Custom Search (requires API key and a search engine ID)
    Google,
    /// Bing Search (requires API key)
    Bing,
}

/// Search engine configuration.
///
/// `Debug` is implemented by hand so that the API key is never written out
/// when a configuration is logged or formatted.
#[derive(Clone)]
pub struct SearchEngineConfig {
    /// Engine type
    pub engine: SearchEngine,
    /// API key (required for Google/Bing)
    pub api_key: Option<String>,
    /// Custom search engine ID (required for Google)
    pub cx: Option<String>,
    /// Maximum results per search
    pub max_results: usize,
    /// Request timeout in seconds
    pub timeout_secs: u64,
    /// Enable result caching
    pub enable_cache: bool,
    /// Cache TTL in seconds
    pub cache_ttl_secs: u64,
}

impl std::fmt::Debug for SearchEngineConfig {
    /// Formats like a derived `Debug`, except that a present `api_key` is
    /// shown as `Some("<redacted>")` instead of the secret itself.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SearchEngineConfig")
            .field("engine", &self.engine)
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("cx", &self.cx)
            .field("max_results", &self.max_results)
            .field("timeout_secs", &self.timeout_secs)
            .field("enable_cache", &self.enable_cache)
            .field("cache_ttl_secs", &self.cache_ttl_secs)
            .finish()
    }
}

impl Default for SearchEngineConfig {
    /// DuckDuckGo, no credentials, 10 results, 30 s timeout, caching enabled
    /// with a one-hour TTL.
    fn default() -> Self {
        Self {
            engine: SearchEngine::DuckDuckGo,
            api_key: None,
            cx: None,
            max_results: 10,
            timeout_secs: 30,
            enable_cache: true,
            cache_ttl_secs: 3600,
        }
    }
}
64
65// ============================================================================
66// Search Result Types
67// ============================================================================
68
69/// A single search result
70#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct SearchResult {
72    /// Result title
73    pub title: String,
74    /// Result URL
75    pub url: String,
76    /// Result snippet/description
77    pub snippet: String,
78    /// Source engine
79    pub engine: String,
80    /// Position in results
81    pub position: usize,
82}
83
84/// Complete search response
85#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct SearchResponse {
87    /// Search query
88    pub query: String,
89    /// Search results
90    pub results: Vec<SearchResult>,
91    /// Total results available
92    pub total: usize,
93    /// Search engine used
94    pub engine: String,
95    /// Response time in ms
96    pub response_time_ms: u64,
97    /// Whether results came from cache
98    pub from_cache: bool,
99}
100
101// ============================================================================
102// Rate Limiter
103// ============================================================================
104
105/// Simple rate limiter for API calls
106struct RateLimiter {
107    /// Minimum interval between requests
108    min_interval: Duration,
109    /// Last request time
110    last_request: RwLock<Option<Instant>>,
111}
112
113impl RateLimiter {
114    fn new(min_interval: Duration) -> Self {
115        Self {
116            min_interval,
117            last_request: RwLock::new(None),
118        }
119    }
120
121    async fn acquire(&self) {
122        loop {
123            let now = Instant::now();
124            let should_wait = {
125                let last = self.last_request.read();
126                if let Some(last_time) = *last {
127                    let elapsed = now.duration_since(last_time);
128                    elapsed < self.min_interval
129                } else {
130                    false
131                }
132            };
133
134            if should_wait {
135                tokio::time::sleep(Duration::from_millis(100)).await;
136            } else {
137                break;
138            }
139        }
140
141        // Update last request time
142        *self.last_request.write() = Some(Instant::now());
143    }
144}
145
146// ============================================================================
147// Result Cache
148// ============================================================================
149
150/// Cache entry with expiration
151struct CacheEntry {
152    response: SearchResponse,
153    created_at: Instant,
154    ttl: Duration,
155}
156
157impl CacheEntry {
158    fn is_expired(&self) -> bool {
159        Instant::now().duration_since(self.created_at) > self.ttl
160    }
161}
162
163/// Search result cache
164struct SearchResultCache {
165    entries: RwLock<HashMap<String, CacheEntry>>,
166}
167
168impl SearchResultCache {
169    fn new() -> Self {
170        Self {
171            entries: RwLock::new(HashMap::new()),
172        }
173    }
174
175    fn get(&self, key: &str) -> Option<SearchResponse> {
176        let entries = self.entries.read();
177        entries.get(key).and_then(|entry| {
178            if entry.is_expired() {
179                None
180            } else {
181                Some(entry.response.clone())
182            }
183        })
184    }
185
186    fn put(&self, key: String, response: SearchResponse, ttl: Duration) {
187        let mut entries = self.entries.write();
188        entries.insert(
189            key,
190            CacheEntry {
191                response,
192                created_at: Instant::now(),
193                ttl,
194            },
195        );
196
197        // Cleanup expired entries
198        let keys_to_remove: Vec<String> = entries
199            .iter()
200            .filter(|(_, e)| e.is_expired())
201            .map(|(k, _)| k.clone())
202            .collect();
203        for key in keys_to_remove {
204            entries.remove(&key);
205        }
206    }
207}
208
209// ============================================================================
210// Web Search Tool
211// ============================================================================
212
213/// Web Search Tool implementation
214pub struct WebSearchTool {
215    /// HTTP client
216    client: Client,
217    /// Configuration
218    config: SearchEngineConfig,
219    /// Rate limiter
220    rate_limiter: RateLimiter,
221    /// Result cache
222    cache: Option<Arc<SearchResultCache>>,
223}
224
225impl WebSearchTool {
226    /// Create a new web search tool with default configuration
227    pub fn new() -> Self {
228        Self::with_config(SearchEngineConfig::default())
229    }
230
231    /// Create with custom configuration
232    pub fn with_config(config: SearchEngineConfig) -> Self {
233        let client = Client::builder()
234            .timeout(Duration::from_secs(config.timeout_secs))
235            .user_agent("ContinuumSDK/1.0")
236            .build()
237            .unwrap_or_else(|_| Client::new());
238
239        let cache = if config.enable_cache {
240            Some(Arc::new(SearchResultCache::new()))
241        } else {
242            None
243        };
244
245        Self {
246            client,
247            config,
248            rate_limiter: RateLimiter::new(Duration::from_millis(500)),
249            cache,
250        }
251    }
252
253    /// Create with API key for Google/Bing
254    pub fn with_api_key(engine: SearchEngine, api_key: String, cx: Option<String>) -> Self {
255        let mut config = SearchEngineConfig {
256            engine,
257            api_key: Some(api_key),
258            cx: cx.clone(),
259            ..Default::default()
260        };
261
262        if engine == SearchEngine::Google && cx.is_none() {
263            // Use a default Google Custom Search Engine ID if not provided
264            config.cx = Some("017576662512468239146:omuauf_lfve".to_string());
265        }
266
267        Self::with_config(config)
268    }
269
270    /// Execute search
271    pub async fn search(&self, query: &str) -> Layer3Result<SearchResponse> {
272        // Check cache first
273        if let Some(cache) = &self.cache {
274            if let Some(cached) = cache.get(query) {
275                return Ok(cached);
276            }
277        }
278
279        // Rate limit
280        self.rate_limiter.acquire().await;
281
282        let start = Instant::now();
283        let results = match self.config.engine {
284            SearchEngine::DuckDuckGo => self.search_duckduckgo(query).await?,
285            SearchEngine::Google => self.search_google(query).await?,
286            SearchEngine::Bing => self.search_bing(query).await?,
287        };
288        let response_time_ms = start.elapsed().as_millis() as u64;
289
290        let response = SearchResponse {
291            query: query.to_string(),
292            results: results.clone(),
293            total: results.len(),
294            engine: format!("{:?}", self.config.engine),
295            response_time_ms,
296            from_cache: false,
297        };
298
299        // Cache the result
300        if let Some(cache) = &self.cache {
301            cache.put(
302                query.to_string(),
303                response.clone(),
304                Duration::from_secs(self.config.cache_ttl_secs),
305            );
306        }
307
308        Ok(response)
309    }
310
311    /// Search using DuckDuckGo (free, no API key required)
312    async fn search_duckduckgo(&self, query: &str) -> Layer3Result<Vec<SearchResult>> {
313        // DuckDuckGo Instant Answer API
314        let url = format!(
315            "https://api.duckduckgo.com/?q={}&format=json&no_html=1",
316            urlencoding::encode(query)
317        );
318
319        let response = self
320            .client
321            .get(&url)
322            .send()
323            .await
324            .map_err(|e| anyhow::anyhow!("DuckDuckGo API error: {}", e))?;
325
326        if !response.status().is_success() {
327            return Err(anyhow::anyhow!(
328                "DuckDuckGo API returned status: {}",
329                response.status()
330            ));
331        }
332
333        let json: serde_json::Value = response
334            .json()
335            .await
336            .map_err(|e| anyhow::anyhow!("Failed to parse DuckDuckGo response: {}", e))?;
337
338        Ok(self.parse_duckduckgo_results(&json))
339    }
340
341    fn parse_duckduckgo_results(&self, json: &serde_json::Value) -> Vec<SearchResult> {
342        let mut results = Vec::new();
343
344        // Parse instant answer
345        if let Some(abstract_text) = json.get("AbstractText").and_then(|v| v.as_str()) {
346            if !abstract_text.is_empty() {
347                if let Some(abstract_url) = json.get("AbstractURL").and_then(|v| v.as_str()) {
348                    if !abstract_url.is_empty() {
349                        results.push(SearchResult {
350                            title: json
351                                .get("Heading")
352                                .and_then(|v| v.as_str())
353                                .unwrap_or("DuckDuckGo Result")
354                                .to_string(),
355                            url: abstract_url.to_string(),
356                            snippet: abstract_text.to_string(),
357                            engine: "DuckDuckGo".to_string(),
358                            position: 1,
359                        });
360                    }
361                }
362            }
363        }
364
365        // Parse related topics
366        if let Some(topics) = json.get("RelatedTopics").and_then(|v| v.as_array()) {
367            for topic in topics.iter().take(self.config.max_results - results.len()) {
368                if let (Some(text), Some(url), Some(first_url)) = (
369                    topic.get("Text").and_then(|v| v.as_str()),
370                    topic.get("FirstURL").and_then(|v| v.as_str()),
371                    topic.get("FirstURL").and_then(|v| v.as_str()),
372                ) {
373                    if !text.is_empty() && !first_url.is_empty() {
374                        results.push(SearchResult {
375                            title: text.split(" - ").next().unwrap_or(text).to_string(),
376                            url: url.to_string(),
377                            snippet: text.to_string(),
378                            engine: "DuckDuckGo".to_string(),
379                            position: results.len() + 1,
380                        });
381                    }
382                }
383            }
384        }
385
386        results
387    }
388
389    /// Search using Google Custom Search API
390    async fn search_google(&self, query: &str) -> Layer3Result<Vec<SearchResult>> {
391        let api_key = self
392            .config
393            .api_key
394            .as_ref()
395            .ok_or_else(|| anyhow::anyhow!("Google Search requires an API key"))?;
396
397        let cx = self.config.cx.as_ref().ok_or_else(|| {
398            anyhow::anyhow!("Google Search requires a Custom Search Engine ID (cx)")
399        })?;
400
401        let url = format!(
402            "https://www.googleapis.com/customsearch/v1?key={}&cx={}&q={}&num={}",
403            api_key,
404            cx,
405            urlencoding::encode(query),
406            self.config.max_results
407        );
408
409        let response = self
410            .client
411            .get(&url)
412            .send()
413            .await
414            .map_err(|e| anyhow::anyhow!("Google API error: {}", e))?;
415
416        let status = response.status();
417        if !status.is_success() {
418            let error_body = response.text().await.unwrap_or_default();
419            return Err(anyhow::anyhow!(
420                "Google API returned status {}: {}",
421                status,
422                error_body
423            ));
424        }
425
426        let json: serde_json::Value = response
427            .json()
428            .await
429            .map_err(|e| anyhow::anyhow!("Failed to parse Google response: {}", e))?;
430
431        let mut results = Vec::new();
432
433        if let Some(items) = json.get("items").and_then(|v| v.as_array()) {
434            for (i, item) in items.iter().enumerate() {
435                results.push(SearchResult {
436                    title: item
437                        .get("title")
438                        .and_then(|v| v.as_str())
439                        .unwrap_or("")
440                        .to_string(),
441                    url: item
442                        .get("link")
443                        .and_then(|v| v.as_str())
444                        .unwrap_or("")
445                        .to_string(),
446                    snippet: item
447                        .get("snippet")
448                        .and_then(|v| v.as_str())
449                        .unwrap_or("")
450                        .to_string(),
451                    engine: "Google".to_string(),
452                    position: i + 1,
453                });
454            }
455        }
456
457        Ok(results)
458    }
459
460    /// Search using Bing Search API
461    async fn search_bing(&self, query: &str) -> Layer3Result<Vec<SearchResult>> {
462        let api_key = self
463            .config
464            .api_key
465            .as_ref()
466            .ok_or_else(|| anyhow::anyhow!("Bing Search requires an API key"))?;
467
468        let url = format!(
469            "https://api.bing.microsoft.com/v7.0/search?q={}&count={}",
470            urlencoding::encode(query),
471            self.config.max_results
472        );
473
474        let response = self
475            .client
476            .get(&url)
477            .header("Ocp-Apim-Subscription-Key", api_key)
478            .send()
479            .await
480            .map_err(|e| anyhow::anyhow!("Bing API error: {}", e))?;
481
482        if !response.status().is_success() {
483            return Err(anyhow::anyhow!(
484                "Bing API returned status: {}",
485                response.status()
486            ));
487        }
488
489        let json: serde_json::Value = response
490            .json()
491            .await
492            .map_err(|e| anyhow::anyhow!("Failed to parse Bing response: {}", e))?;
493
494        let mut results = Vec::new();
495
496        if let Some(web_pages) = json.get("webPages").and_then(|v| v.get("value")) {
497            if let Some(items) = web_pages.as_array() {
498                for (i, item) in items.iter().enumerate() {
499                    results.push(SearchResult {
500                        title: item
501                            .get("name")
502                            .and_then(|v| v.as_str())
503                            .unwrap_or("")
504                            .to_string(),
505                        url: item
506                            .get("url")
507                            .and_then(|v| v.as_str())
508                            .unwrap_or("")
509                            .to_string(),
510                        snippet: item
511                            .get("snippet")
512                            .and_then(|v| v.as_str())
513                            .unwrap_or("")
514                            .to_string(),
515                        engine: "Bing".to_string(),
516                        position: i + 1,
517                    });
518                }
519            }
520        }
521
522        Ok(results)
523    }
524}
525
526impl Default for WebSearchTool {
527    fn default() -> Self {
528        Self::new()
529    }
530}
531
532#[async_trait]
533impl BuiltinTool for WebSearchTool {
534    fn name(&self) -> &str {
535        "web_search"
536    }
537
538    fn description(&self) -> &str {
539        "Search the web for information using DuckDuckGo, Google, or Bing."
540    }
541
542    fn parameters_schema(&self) -> serde_json::Value {
543        serde_json::json!({
544            "type": "object",
545            "properties": {
546                "query": {
547                    "type": "string",
548                    "description": "The search query"
549                },
550                "engine": {
551                    "type": "string",
552                    "enum": ["duckduckgo", "google", "bing"],
553                    "description": "Search engine to use (default: duckduckgo)"
554                },
555                "max_results": {
556                    "type": "integer",
557                    "description": "Maximum number of results to return (default: 10)"
558                }
559            },
560            "required": ["query"]
561        })
562    }
563
564    fn category(&self) -> ToolCategory {
565        ToolCategory::Search
566    }
567
568    async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
569        let query = args["query"]
570            .as_str()
571            .ok_or_else(|| anyhow::anyhow!("Missing query parameter"))?;
572
573        // Parse engine if specified
574        let engine_str = args["engine"].as_str().unwrap_or("duckduckgo");
575        let engine = match engine_str.to_lowercase().as_str() {
576            "google" => SearchEngine::Google,
577            "bing" => SearchEngine::Bing,
578            _ => SearchEngine::DuckDuckGo,
579        };
580
581        // Create a temporary tool with the requested engine
582        let tool = if engine != self.config.engine {
583            let mut config = self.config.clone();
584            config.engine = engine;
585            WebSearchTool::with_config(config)
586        } else {
587            // Can't clone self, so just use self
588            return self.search(query).await.map(|r| {
589                serde_json::to_string_pretty(&r).unwrap_or_else(|_| {
590                    r.results
591                        .iter()
592                        .map(|r| format!("{}: {}", r.title, r.url))
593                        .collect::<Vec<_>>()
594                        .join("\n")
595                })
596            });
597        };
598
599        tool.search(query).await.map(|r| {
600            serde_json::to_string_pretty(&r).unwrap_or_else(|_| {
601                r.results
602                    .iter()
603                    .map(|r| format!("{}: {}", r.title, r.url))
604                    .collect::<Vec<_>>()
605                    .join("\n")
606            })
607        })
608    }
609}
610
611// ============================================================================
612// URL Encoding Helper
613// ============================================================================
614
615mod urlencoding {
616    pub fn encode(s: &str) -> String {
617        url::form_urlencoded::byte_serialize(s.as_bytes()).collect()
618    }
619}
620
621// ============================================================================
622// Tests
623// ============================================================================
624
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an empty DuckDuckGo response for cache tests.
    fn empty_response(query: &str) -> SearchResponse {
        SearchResponse {
            query: query.to_string(),
            results: vec![],
            total: 0,
            engine: "DuckDuckGo".to_string(),
            response_time_ms: 100,
            from_cache: false,
        }
    }

    #[test]
    fn test_tool_creation() {
        let tool = WebSearchTool::new();
        assert_eq!(tool.name(), "web_search");
        assert_eq!(tool.category(), ToolCategory::Search);
    }

    #[test]
    fn test_config_default() {
        let config = SearchEngineConfig::default();
        assert_eq!(config.engine, SearchEngine::DuckDuckGo);
        assert!(config.api_key.is_none());
        assert_eq!(config.max_results, 10);
    }

    #[test]
    fn test_cache_basic() {
        let cache = SearchResultCache::new();
        cache.put(
            "test".to_string(),
            empty_response("test"),
            Duration::from_secs(60),
        );
        let cached = cache.get("test");
        assert!(cached.is_some());
    }

    #[test]
    fn test_cache_expired_entry_is_not_returned() {
        let cache = SearchResultCache::new();
        cache.put(
            "stale".to_string(),
            empty_response("stale"),
            Duration::from_millis(1),
        );
        std::thread::sleep(Duration::from_millis(5));
        assert!(cache.get("stale").is_none());
    }

    #[tokio::test]
    async fn test_rate_limiter() {
        let limiter = RateLimiter::new(Duration::from_millis(100));

        // First call should proceed immediately
        let start = Instant::now();
        limiter.acquire().await;
        let elapsed = start.elapsed();
        assert!(elapsed < Duration::from_millis(50));
    }

    #[tokio::test]
    async fn test_rate_limiter_spaces_consecutive_calls() {
        let limiter = RateLimiter::new(Duration::from_millis(50));

        let start = Instant::now();
        limiter.acquire().await;
        limiter.acquire().await;

        // The second call must wait (roughly) one interval after the first.
        assert!(start.elapsed() >= Duration::from_millis(45));
    }

    #[test]
    fn test_search_result_serialization() {
        let result = SearchResult {
            title: "Test".to_string(),
            url: "https://example.com".to_string(),
            snippet: "Test snippet".to_string(),
            engine: "DuckDuckGo".to_string(),
            position: 1,
        };

        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("Test"));
        assert!(json.contains("example.com"));
    }

    #[test]
    fn test_duckduckgo_no_results_returns_empty_list() {
        let tool = WebSearchTool::new();
        let json = serde_json::json!({
            "AbstractText": "",
            "AbstractURL": "",
            "RelatedTopics": []
        });

        let results = tool.parse_duckduckgo_results(&json);

        assert!(results.is_empty());
    }

    #[test]
    fn test_duckduckgo_parses_abstract_and_related_topics() {
        let tool = WebSearchTool::new();
        let json = serde_json::json!({
            "Heading": "Rust",
            "AbstractText": "Rust is a systems programming language.",
            "AbstractURL": "https://www.rust-lang.org",
            "RelatedTopics": [
                { "Text": "The Book - Official guide", "FirstURL": "https://doc.rust-lang.org/book" },
                { "Text": "No URL here" },
                { "Text": "Cargo - Package manager", "FirstURL": "https://doc.rust-lang.org/cargo" }
            ]
        });

        let results = tool.parse_duckduckgo_results(&json);

        let titles: Vec<&str> = results.iter().map(|r| r.title.as_str()).collect();
        assert_eq!(titles, vec!["Rust", "The Book", "Cargo"]);
        let positions: Vec<usize> = results.iter().map(|r| r.position).collect();
        assert_eq!(positions, vec![1, 2, 3]);
        assert_eq!(results[1].url, "https://doc.rust-lang.org/book");
        assert!(results.iter().all(|r| r.engine == "DuckDuckGo"));
    }

    #[test]
    fn test_duckduckgo_respects_max_results() {
        let tool = WebSearchTool::with_config(SearchEngineConfig {
            max_results: 2,
            ..Default::default()
        });
        let json = serde_json::json!({
            "Heading": "Rust",
            "AbstractText": "Rust is a systems programming language.",
            "AbstractURL": "https://www.rust-lang.org",
            "RelatedTopics": [
                { "Text": "A - first", "FirstURL": "https://example.com/a" },
                { "Text": "B - second", "FirstURL": "https://example.com/b" },
                { "Text": "C - third", "FirstURL": "https://example.com/c" }
            ]
        });

        let results = tool.parse_duckduckgo_results(&json);

        assert_eq!(results.len(), 2);
        assert_eq!(results[1].title, "A");
    }

    #[tokio::test]
    async fn test_execute_without_query_fails() {
        let tool = WebSearchTool::new();
        let outcome = tool.execute(serde_json::json!({ "engine": "bing" })).await;
        assert!(outcome.is_err());
    }
}