1use crate::builtin_tools::BuiltinTool;
7use crate::types::{Layer3Result, ToolCategory};
8use async_trait::async_trait;
9use parking_lot::RwLock;
10use reqwest::Client;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::{Duration, Instant};
15
/// Search backend that a [`WebSearchTool`] sends its queries to.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SearchEngine {
    /// DuckDuckGo JSON API; queried without an API key. This is the default.
    #[default]
    DuckDuckGo,
    /// Google Custom Search; needs an API key and a search engine ID (`cx`).
    Google,
    /// Bing Web Search; needs an API key.
    Bing,
}
31
32#[derive(Debug, Clone)]
34pub struct SearchEngineConfig {
35 pub engine: SearchEngine,
37 pub api_key: Option<String>,
39 pub cx: Option<String>,
41 pub max_results: usize,
43 pub timeout_secs: u64,
45 pub enable_cache: bool,
47 pub cache_ttl_secs: u64,
49}
50
51impl Default for SearchEngineConfig {
52 fn default() -> Self {
53 Self {
54 engine: SearchEngine::DuckDuckGo,
55 api_key: None,
56 cx: None,
57 max_results: 10,
58 timeout_secs: 30,
59 enable_cache: true,
60 cache_ttl_secs: 3600,
61 }
62 }
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct SearchResult {
72 pub title: String,
74 pub url: String,
76 pub snippet: String,
78 pub engine: String,
80 pub position: usize,
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct SearchResponse {
87 pub query: String,
89 pub results: Vec<SearchResult>,
91 pub total: usize,
93 pub engine: String,
95 pub response_time_ms: u64,
97 pub from_cache: bool,
99}
100
101struct RateLimiter {
107 min_interval: Duration,
109 last_request: RwLock<Option<Instant>>,
111}
112
113impl RateLimiter {
114 fn new(min_interval: Duration) -> Self {
115 Self {
116 min_interval,
117 last_request: RwLock::new(None),
118 }
119 }
120
121 async fn acquire(&self) {
122 loop {
123 let now = Instant::now();
124 let should_wait = {
125 let last = self.last_request.read();
126 if let Some(last_time) = *last {
127 let elapsed = now.duration_since(last_time);
128 elapsed < self.min_interval
129 } else {
130 false
131 }
132 };
133
134 if should_wait {
135 tokio::time::sleep(Duration::from_millis(100)).await;
136 } else {
137 break;
138 }
139 }
140
141 *self.last_request.write() = Some(Instant::now());
143 }
144}
145
146struct CacheEntry {
152 response: SearchResponse,
153 created_at: Instant,
154 ttl: Duration,
155}
156
157impl CacheEntry {
158 fn is_expired(&self) -> bool {
159 Instant::now().duration_since(self.created_at) > self.ttl
160 }
161}
162
163struct SearchResultCache {
165 entries: RwLock<HashMap<String, CacheEntry>>,
166}
167
168impl SearchResultCache {
169 fn new() -> Self {
170 Self {
171 entries: RwLock::new(HashMap::new()),
172 }
173 }
174
175 fn get(&self, key: &str) -> Option<SearchResponse> {
176 let entries = self.entries.read();
177 entries.get(key).and_then(|entry| {
178 if entry.is_expired() {
179 None
180 } else {
181 Some(entry.response.clone())
182 }
183 })
184 }
185
186 fn put(&self, key: String, response: SearchResponse, ttl: Duration) {
187 let mut entries = self.entries.write();
188 entries.insert(
189 key,
190 CacheEntry {
191 response,
192 created_at: Instant::now(),
193 ttl,
194 },
195 );
196
197 let keys_to_remove: Vec<String> = entries
199 .iter()
200 .filter(|(_, e)| e.is_expired())
201 .map(|(k, _)| k.clone())
202 .collect();
203 for key in keys_to_remove {
204 entries.remove(&key);
205 }
206 }
207}
208
209pub struct WebSearchTool {
215 client: Client,
217 config: SearchEngineConfig,
219 rate_limiter: RateLimiter,
221 cache: Option<Arc<SearchResultCache>>,
223}
224
225impl WebSearchTool {
226 pub fn new() -> Self {
228 Self::with_config(SearchEngineConfig::default())
229 }
230
231 pub fn with_config(config: SearchEngineConfig) -> Self {
233 let client = Client::builder()
234 .timeout(Duration::from_secs(config.timeout_secs))
235 .user_agent("ContinuumSDK/1.0")
236 .build()
237 .unwrap_or_else(|_| Client::new());
238
239 let cache = if config.enable_cache {
240 Some(Arc::new(SearchResultCache::new()))
241 } else {
242 None
243 };
244
245 Self {
246 client,
247 config,
248 rate_limiter: RateLimiter::new(Duration::from_millis(500)),
249 cache,
250 }
251 }
252
253 pub fn with_api_key(engine: SearchEngine, api_key: String, cx: Option<String>) -> Self {
255 let mut config = SearchEngineConfig {
256 engine,
257 api_key: Some(api_key),
258 cx: cx.clone(),
259 ..Default::default()
260 };
261
262 if engine == SearchEngine::Google && cx.is_none() {
263 config.cx = Some("017576662512468239146:omuauf_lfve".to_string());
265 }
266
267 Self::with_config(config)
268 }
269
270 pub async fn search(&self, query: &str) -> Layer3Result<SearchResponse> {
272 if let Some(cache) = &self.cache {
274 if let Some(cached) = cache.get(query) {
275 return Ok(cached);
276 }
277 }
278
279 self.rate_limiter.acquire().await;
281
282 let start = Instant::now();
283 let results = match self.config.engine {
284 SearchEngine::DuckDuckGo => self.search_duckduckgo(query).await?,
285 SearchEngine::Google => self.search_google(query).await?,
286 SearchEngine::Bing => self.search_bing(query).await?,
287 };
288 let response_time_ms = start.elapsed().as_millis() as u64;
289
290 let response = SearchResponse {
291 query: query.to_string(),
292 results: results.clone(),
293 total: results.len(),
294 engine: format!("{:?}", self.config.engine),
295 response_time_ms,
296 from_cache: false,
297 };
298
299 if let Some(cache) = &self.cache {
301 cache.put(
302 query.to_string(),
303 response.clone(),
304 Duration::from_secs(self.config.cache_ttl_secs),
305 );
306 }
307
308 Ok(response)
309 }
310
311 async fn search_duckduckgo(&self, query: &str) -> Layer3Result<Vec<SearchResult>> {
313 let url = format!(
315 "https://api.duckduckgo.com/?q={}&format=json&no_html=1",
316 urlencoding::encode(query)
317 );
318
319 let response = self
320 .client
321 .get(&url)
322 .send()
323 .await
324 .map_err(|e| anyhow::anyhow!("DuckDuckGo API error: {}", e))?;
325
326 if !response.status().is_success() {
327 return Err(anyhow::anyhow!(
328 "DuckDuckGo API returned status: {}",
329 response.status()
330 ));
331 }
332
333 let json: serde_json::Value = response
334 .json()
335 .await
336 .map_err(|e| anyhow::anyhow!("Failed to parse DuckDuckGo response: {}", e))?;
337
338 Ok(self.parse_duckduckgo_results(&json))
339 }
340
341 fn parse_duckduckgo_results(&self, json: &serde_json::Value) -> Vec<SearchResult> {
342 let mut results = Vec::new();
343
344 if let Some(abstract_text) = json.get("AbstractText").and_then(|v| v.as_str()) {
346 if !abstract_text.is_empty() {
347 if let Some(abstract_url) = json.get("AbstractURL").and_then(|v| v.as_str()) {
348 if !abstract_url.is_empty() {
349 results.push(SearchResult {
350 title: json
351 .get("Heading")
352 .and_then(|v| v.as_str())
353 .unwrap_or("DuckDuckGo Result")
354 .to_string(),
355 url: abstract_url.to_string(),
356 snippet: abstract_text.to_string(),
357 engine: "DuckDuckGo".to_string(),
358 position: 1,
359 });
360 }
361 }
362 }
363 }
364
365 if let Some(topics) = json.get("RelatedTopics").and_then(|v| v.as_array()) {
367 for topic in topics.iter().take(self.config.max_results - results.len()) {
368 if let (Some(text), Some(url), Some(first_url)) = (
369 topic.get("Text").and_then(|v| v.as_str()),
370 topic.get("FirstURL").and_then(|v| v.as_str()),
371 topic.get("FirstURL").and_then(|v| v.as_str()),
372 ) {
373 if !text.is_empty() && !first_url.is_empty() {
374 results.push(SearchResult {
375 title: text.split(" - ").next().unwrap_or(text).to_string(),
376 url: url.to_string(),
377 snippet: text.to_string(),
378 engine: "DuckDuckGo".to_string(),
379 position: results.len() + 1,
380 });
381 }
382 }
383 }
384 }
385
386 results
387 }
388
389 async fn search_google(&self, query: &str) -> Layer3Result<Vec<SearchResult>> {
391 let api_key = self
392 .config
393 .api_key
394 .as_ref()
395 .ok_or_else(|| anyhow::anyhow!("Google Search requires an API key"))?;
396
397 let cx = self.config.cx.as_ref().ok_or_else(|| {
398 anyhow::anyhow!("Google Search requires a Custom Search Engine ID (cx)")
399 })?;
400
401 let url = format!(
402 "https://www.googleapis.com/customsearch/v1?key={}&cx={}&q={}&num={}",
403 api_key,
404 cx,
405 urlencoding::encode(query),
406 self.config.max_results
407 );
408
409 let response = self
410 .client
411 .get(&url)
412 .send()
413 .await
414 .map_err(|e| anyhow::anyhow!("Google API error: {}", e))?;
415
416 let status = response.status();
417 if !status.is_success() {
418 let error_body = response.text().await.unwrap_or_default();
419 return Err(anyhow::anyhow!(
420 "Google API returned status {}: {}",
421 status,
422 error_body
423 ));
424 }
425
426 let json: serde_json::Value = response
427 .json()
428 .await
429 .map_err(|e| anyhow::anyhow!("Failed to parse Google response: {}", e))?;
430
431 let mut results = Vec::new();
432
433 if let Some(items) = json.get("items").and_then(|v| v.as_array()) {
434 for (i, item) in items.iter().enumerate() {
435 results.push(SearchResult {
436 title: item
437 .get("title")
438 .and_then(|v| v.as_str())
439 .unwrap_or("")
440 .to_string(),
441 url: item
442 .get("link")
443 .and_then(|v| v.as_str())
444 .unwrap_or("")
445 .to_string(),
446 snippet: item
447 .get("snippet")
448 .and_then(|v| v.as_str())
449 .unwrap_or("")
450 .to_string(),
451 engine: "Google".to_string(),
452 position: i + 1,
453 });
454 }
455 }
456
457 Ok(results)
458 }
459
460 async fn search_bing(&self, query: &str) -> Layer3Result<Vec<SearchResult>> {
462 let api_key = self
463 .config
464 .api_key
465 .as_ref()
466 .ok_or_else(|| anyhow::anyhow!("Bing Search requires an API key"))?;
467
468 let url = format!(
469 "https://api.bing.microsoft.com/v7.0/search?q={}&count={}",
470 urlencoding::encode(query),
471 self.config.max_results
472 );
473
474 let response = self
475 .client
476 .get(&url)
477 .header("Ocp-Apim-Subscription-Key", api_key)
478 .send()
479 .await
480 .map_err(|e| anyhow::anyhow!("Bing API error: {}", e))?;
481
482 if !response.status().is_success() {
483 return Err(anyhow::anyhow!(
484 "Bing API returned status: {}",
485 response.status()
486 ));
487 }
488
489 let json: serde_json::Value = response
490 .json()
491 .await
492 .map_err(|e| anyhow::anyhow!("Failed to parse Bing response: {}", e))?;
493
494 let mut results = Vec::new();
495
496 if let Some(web_pages) = json.get("webPages").and_then(|v| v.get("value")) {
497 if let Some(items) = web_pages.as_array() {
498 for (i, item) in items.iter().enumerate() {
499 results.push(SearchResult {
500 title: item
501 .get("name")
502 .and_then(|v| v.as_str())
503 .unwrap_or("")
504 .to_string(),
505 url: item
506 .get("url")
507 .and_then(|v| v.as_str())
508 .unwrap_or("")
509 .to_string(),
510 snippet: item
511 .get("snippet")
512 .and_then(|v| v.as_str())
513 .unwrap_or("")
514 .to_string(),
515 engine: "Bing".to_string(),
516 position: i + 1,
517 });
518 }
519 }
520 }
521
522 Ok(results)
523 }
524}
525
526impl Default for WebSearchTool {
527 fn default() -> Self {
528 Self::new()
529 }
530}
531
532#[async_trait]
533impl BuiltinTool for WebSearchTool {
534 fn name(&self) -> &str {
535 "web_search"
536 }
537
538 fn description(&self) -> &str {
539 "Search the web for information using DuckDuckGo, Google, or Bing."
540 }
541
542 fn parameters_schema(&self) -> serde_json::Value {
543 serde_json::json!({
544 "type": "object",
545 "properties": {
546 "query": {
547 "type": "string",
548 "description": "The search query"
549 },
550 "engine": {
551 "type": "string",
552 "enum": ["duckduckgo", "google", "bing"],
553 "description": "Search engine to use (default: duckduckgo)"
554 },
555 "max_results": {
556 "type": "integer",
557 "description": "Maximum number of results to return (default: 10)"
558 }
559 },
560 "required": ["query"]
561 })
562 }
563
564 fn category(&self) -> ToolCategory {
565 ToolCategory::Search
566 }
567
568 async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
569 let query = args["query"]
570 .as_str()
571 .ok_or_else(|| anyhow::anyhow!("Missing query parameter"))?;
572
573 let engine_str = args["engine"].as_str().unwrap_or("duckduckgo");
575 let engine = match engine_str.to_lowercase().as_str() {
576 "google" => SearchEngine::Google,
577 "bing" => SearchEngine::Bing,
578 _ => SearchEngine::DuckDuckGo,
579 };
580
581 let tool = if engine != self.config.engine {
583 let mut config = self.config.clone();
584 config.engine = engine;
585 WebSearchTool::with_config(config)
586 } else {
587 return self.search(query).await.map(|r| {
589 serde_json::to_string_pretty(&r).unwrap_or_else(|_| {
590 r.results
591 .iter()
592 .map(|r| format!("{}: {}", r.title, r.url))
593 .collect::<Vec<_>>()
594 .join("\n")
595 })
596 });
597 };
598
599 tool.search(query).await.map(|r| {
600 serde_json::to_string_pretty(&r).unwrap_or_else(|_| {
601 r.results
602 .iter()
603 .map(|r| format!("{}: {}", r.title, r.url))
604 .collect::<Vec<_>>()
605 .join("\n")
606 })
607 })
608 }
609}
610
611mod urlencoding {
616 pub fn encode(s: &str) -> String {
617 url::form_urlencoded::byte_serialize(s.as_bytes()).collect()
618 }
619}
620
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built tool reports its fixed name and category.
    #[test]
    fn test_tool_creation() {
        let search_tool = WebSearchTool::new();

        assert_eq!(search_tool.name(), "web_search");
        assert_eq!(search_tool.category(), ToolCategory::Search);
    }

    /// The default configuration is key-less DuckDuckGo with ten results.
    #[test]
    fn test_config_default() {
        let defaults = SearchEngineConfig::default();

        assert_eq!(defaults.engine, SearchEngine::DuckDuckGo);
        assert!(defaults.api_key.is_none());
        assert_eq!(defaults.max_results, 10);
    }

    /// A response stored with a long TTL can be read back immediately.
    #[test]
    fn test_cache_basic() {
        let stored = SearchResponse {
            query: "test".to_string(),
            results: vec![],
            total: 0,
            engine: "DuckDuckGo".to_string(),
            response_time_ms: 100,
            from_cache: false,
        };
        let cache = SearchResultCache::new();

        cache.put("test".to_string(), stored, Duration::from_secs(60));

        assert!(cache.get("test").is_some());
    }

    /// The very first acquisition must not be delayed.
    #[tokio::test]
    async fn test_rate_limiter() {
        let limiter = RateLimiter::new(Duration::from_millis(100));

        let started = Instant::now();
        limiter.acquire().await;

        assert!(started.elapsed() < Duration::from_millis(50));
    }

    /// Serialized results contain their title and URL.
    #[test]
    fn test_search_result_serialization() {
        let hit = SearchResult {
            title: "Test".to_string(),
            url: "https://example.com".to_string(),
            snippet: "Test snippet".to_string(),
            engine: "DuckDuckGo".to_string(),
            position: 1,
        };

        let serialized = serde_json::to_string(&hit).unwrap();

        for needle in ["Test", "example.com"] {
            assert!(serialized.contains(needle));
        }
    }

    /// An answer with an empty abstract and no topics yields no results.
    #[test]
    fn test_duckduckgo_no_results_returns_empty_list() {
        let empty_answer = serde_json::json!({
            "AbstractText": "",
            "AbstractURL": "",
            "RelatedTopics": []
        });

        let parsed = WebSearchTool::new().parse_duckduckgo_results(&empty_answer);

        assert!(parsed.is_empty());
    }
}