Skip to main content

spider_agent/
config.rs

1//! Configuration types for spider_agent.
2
3use dashmap::DashMap;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::sync::atomic::{AtomicU64, Ordering};
7use std::time::Duration;
8
/// Usage limits for controlling agent resource consumption.
///
/// Every field is optional: `None` leaves that dimension uncapped, while
/// `Some(n)` caps it at `n`. The derived `Default` sets every field to `None`.
#[derive(Debug, Clone, Default)]
pub struct UsageLimits {
    /// Maximum total tokens (prompt + completion).
    pub max_total_tokens: Option<u64>,
    /// Maximum prompt tokens.
    pub max_prompt_tokens: Option<u64>,
    /// Maximum completion tokens.
    pub max_completion_tokens: Option<u64>,
    /// Maximum LLM API calls.
    pub max_llm_calls: Option<u64>,
    /// Maximum search API calls.
    pub max_search_calls: Option<u64>,
    /// Maximum HTTP fetch calls.
    pub max_fetch_calls: Option<u64>,
    /// Maximum web browser calls (Chrome/WebDriver combined).
    pub max_webbrowser_calls: Option<u64>,
    /// Maximum custom tool calls, summed across all custom tools.
    pub max_custom_tool_calls: Option<u64>,
    /// Maximum generic tool calls.
    pub max_tool_calls: Option<u64>,
}
31
32impl UsageLimits {
33    /// Create new usage limits with no restrictions.
34    pub fn new() -> Self {
35        Self::default()
36    }
37
38    /// Set maximum total tokens.
39    pub fn with_max_total_tokens(mut self, limit: u64) -> Self {
40        self.max_total_tokens = Some(limit);
41        self
42    }
43
44    /// Set maximum prompt tokens.
45    pub fn with_max_prompt_tokens(mut self, limit: u64) -> Self {
46        self.max_prompt_tokens = Some(limit);
47        self
48    }
49
50    /// Set maximum completion tokens.
51    pub fn with_max_completion_tokens(mut self, limit: u64) -> Self {
52        self.max_completion_tokens = Some(limit);
53        self
54    }
55
56    /// Set maximum LLM calls.
57    pub fn with_max_llm_calls(mut self, limit: u64) -> Self {
58        self.max_llm_calls = Some(limit);
59        self
60    }
61
62    /// Set maximum search calls.
63    pub fn with_max_search_calls(mut self, limit: u64) -> Self {
64        self.max_search_calls = Some(limit);
65        self
66    }
67
68    /// Set maximum fetch calls.
69    pub fn with_max_fetch_calls(mut self, limit: u64) -> Self {
70        self.max_fetch_calls = Some(limit);
71        self
72    }
73
74    /// Set maximum web browser calls.
75    pub fn with_max_webbrowser_calls(mut self, limit: u64) -> Self {
76        self.max_webbrowser_calls = Some(limit);
77        self
78    }
79
80    /// Set maximum custom tool calls.
81    pub fn with_max_custom_tool_calls(mut self, limit: u64) -> Self {
82        self.max_custom_tool_calls = Some(limit);
83        self
84    }
85
86    /// Set maximum tool calls.
87    pub fn with_max_tool_calls(mut self, limit: u64) -> Self {
88        self.max_tool_calls = Some(limit);
89        self
90    }
91}
92
/// Type of limit that was exceeded.
///
/// Every variant carries the same two values: `used`, the amount consumed at
/// the time of the check, and `limit`, the configured cap it was compared to.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum LimitType {
    /// Total tokens limit exceeded.
    TotalTokens {
        /// Tokens used so far.
        used: u64,
        /// The limit that was set.
        limit: u64,
    },
    /// Prompt tokens limit exceeded.
    PromptTokens {
        /// Tokens used so far.
        used: u64,
        /// The limit that was set.
        limit: u64,
    },
    /// Completion tokens limit exceeded.
    CompletionTokens {
        /// Tokens used so far.
        used: u64,
        /// The limit that was set.
        limit: u64,
    },
    /// LLM calls limit exceeded.
    LlmCalls {
        /// Calls made so far.
        used: u64,
        /// The limit that was set.
        limit: u64,
    },
    /// Search calls limit exceeded.
    SearchCalls {
        /// Calls made so far.
        used: u64,
        /// The limit that was set.
        limit: u64,
    },
    /// Fetch calls limit exceeded.
    FetchCalls {
        /// Calls made so far.
        used: u64,
        /// The limit that was set.
        limit: u64,
    },
    /// Web browser calls limit exceeded.
    WebbrowserCalls {
        /// Calls made so far.
        used: u64,
        /// The limit that was set.
        limit: u64,
    },
    /// Custom tool calls limit exceeded.
    CustomToolCalls {
        /// Calls made so far (summed across all custom tools).
        used: u64,
        /// The limit that was set.
        limit: u64,
    },
    /// Tool calls limit exceeded.
    ToolCalls {
        /// Calls made so far.
        used: u64,
        /// The limit that was set.
        limit: u64,
    },
}
160
161impl std::fmt::Display for LimitType {
162    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163        match self {
164            Self::TotalTokens { used, limit } => {
165                write!(f, "total tokens ({} used, {} limit)", used, limit)
166            }
167            Self::PromptTokens { used, limit } => {
168                write!(f, "prompt tokens ({} used, {} limit)", used, limit)
169            }
170            Self::CompletionTokens { used, limit } => {
171                write!(f, "completion tokens ({} used, {} limit)", used, limit)
172            }
173            Self::LlmCalls { used, limit } => {
174                write!(f, "LLM calls ({} used, {} limit)", used, limit)
175            }
176            Self::SearchCalls { used, limit } => {
177                write!(f, "search calls ({} used, {} limit)", used, limit)
178            }
179            Self::FetchCalls { used, limit } => {
180                write!(f, "fetch calls ({} used, {} limit)", used, limit)
181            }
182            Self::WebbrowserCalls { used, limit } => {
183                write!(f, "web browser calls ({} used, {} limit)", used, limit)
184            }
185            Self::CustomToolCalls { used, limit } => {
186                write!(f, "custom tool calls ({} used, {} limit)", used, limit)
187            }
188            Self::ToolCalls { used, limit } => {
189                write!(f, "tool calls ({} used, {} limit)", used, limit)
190            }
191        }
192    }
193}
194
/// Agent configuration.
///
/// All fields are public and can be assigned directly; the `with_*` builder
/// methods on this type offer a chainable alternative.
#[derive(Debug, Clone)]
pub struct AgentConfig {
    /// System prompt for LLM.
    pub system_prompt: Option<String>,

    /// Max concurrent LLM calls.
    pub max_concurrent_llm_calls: usize,

    /// LLM temperature.
    ///
    /// The `with_temperature` builder clamps its argument to `0.0..=2.0`;
    /// direct assignment to this field is not validated.
    pub temperature: f32,

    /// Max tokens for LLM response.
    pub max_tokens: u16,

    /// Request timeout.
    pub timeout: Duration,

    /// Retry configuration.
    pub retry: RetryConfig,

    /// Max HTML bytes to send to LLM.
    pub html_max_bytes: usize,

    /// HTML cleaning mode.
    pub html_cleaning: HtmlCleaningMode,

    /// Whether to request JSON output from LLM.
    pub json_mode: bool,

    /// Usage limits for resource control.
    pub limits: UsageLimits,
}
228
229impl Default for AgentConfig {
230    fn default() -> Self {
231        Self {
232            system_prompt: None,
233            max_concurrent_llm_calls: 5,
234            temperature: 0.1,
235            max_tokens: 4096,
236            timeout: Duration::from_secs(60),
237            retry: RetryConfig::default(),
238            html_max_bytes: 24_000,
239            html_cleaning: HtmlCleaningMode::Default,
240            json_mode: true,
241            limits: UsageLimits::default(),
242        }
243    }
244}
245
246impl AgentConfig {
247    /// Create a new config with defaults.
248    pub fn new() -> Self {
249        Self::default()
250    }
251
252    /// Set system prompt.
253    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
254        self.system_prompt = Some(prompt.into());
255        self
256    }
257
258    /// Set max concurrent LLM calls.
259    pub fn with_max_concurrent_llm_calls(mut self, n: usize) -> Self {
260        self.max_concurrent_llm_calls = n;
261        self
262    }
263
264    /// Set LLM temperature.
265    pub fn with_temperature(mut self, temp: f32) -> Self {
266        self.temperature = temp.clamp(0.0, 2.0);
267        self
268    }
269
270    /// Set max tokens.
271    pub fn with_max_tokens(mut self, tokens: u16) -> Self {
272        self.max_tokens = tokens;
273        self
274    }
275
276    /// Set request timeout.
277    pub fn with_timeout(mut self, timeout: Duration) -> Self {
278        self.timeout = timeout;
279        self
280    }
281
282    /// Set retry config.
283    pub fn with_retry(mut self, retry: RetryConfig) -> Self {
284        self.retry = retry;
285        self
286    }
287
288    /// Set HTML max bytes.
289    pub fn with_html_max_bytes(mut self, bytes: usize) -> Self {
290        self.html_max_bytes = bytes;
291        self
292    }
293
294    /// Set HTML cleaning mode.
295    pub fn with_html_cleaning(mut self, mode: HtmlCleaningMode) -> Self {
296        self.html_cleaning = mode;
297        self
298    }
299
300    /// Enable or disable JSON mode.
301    pub fn with_json_mode(mut self, enabled: bool) -> Self {
302        self.json_mode = enabled;
303        self
304    }
305
306    /// Set usage limits.
307    pub fn with_limits(mut self, limits: UsageLimits) -> Self {
308        self.limits = limits;
309        self
310    }
311}
312
/// Retry configuration.
///
/// Plain data describing how often to retry, how long to wait in between, and
/// whether parse errors are retried.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Max retry attempts.
    // NOTE(review): whether this counts the initial attempt is not visible
    // in this file; confirm against the retry loop that reads it.
    pub max_attempts: usize,
    /// Backoff delay between attempts.
    pub backoff: Duration,
    /// Retry on parse errors.
    pub retry_on_parse_error: bool,
}
323
324impl Default for RetryConfig {
325    fn default() -> Self {
326        Self {
327            max_attempts: 3,
328            backoff: Duration::from_millis(500),
329            retry_on_parse_error: true,
330        }
331    }
332}
333
334impl RetryConfig {
335    /// Create a new retry config.
336    pub fn new() -> Self {
337        Self::default()
338    }
339
340    /// Set max attempts.
341    pub fn with_max_attempts(mut self, n: usize) -> Self {
342        self.max_attempts = n;
343        self
344    }
345
346    /// Set backoff delay.
347    pub fn with_backoff(mut self, backoff: Duration) -> Self {
348        self.backoff = backoff;
349        self
350    }
351}
352
/// HTML cleaning mode.
///
/// Selects how much markup is stripped from HTML. `Default` is the variant
/// returned by the derived `Default` implementation.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum HtmlCleaningMode {
    /// Standard cleaning - removes scripts, styles, comments.
    #[default]
    Default,
    /// Aggressive cleaning - also removes SVGs, images, etc.
    Aggressive,
    /// Minimal cleaning - only removes scripts.
    Minimal,
    /// No cleaning - raw HTML.
    Raw,
}
366
/// Search options for web search.
///
/// Every field is optional; the derived `Default` leaves all of them `None`.
// NOTE(review): how a `None` field is interpreted is up to the code that
// consumes these options, which is not visible in this file.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SearchOptions {
    /// Maximum number of results to return.
    pub limit: Option<usize>,
    /// Country/region code (e.g., "us", "uk").
    pub country: Option<String>,
    /// Language code (e.g., "en", "es").
    pub language: Option<String>,
    /// Filter to specific domains.
    pub site_filter: Option<Vec<String>>,
    /// Exclude specific domains.
    pub exclude_domains: Option<Vec<String>>,
    /// Time range filter.
    pub time_range: Option<TimeRange>,
}
383
384impl SearchOptions {
385    /// Create new search options with defaults.
386    pub fn new() -> Self {
387        Self::default()
388    }
389
390    /// Set maximum number of results.
391    pub fn with_limit(mut self, limit: usize) -> Self {
392        self.limit = Some(limit);
393        self
394    }
395
396    /// Set country/region code.
397    pub fn with_country(mut self, country: impl Into<String>) -> Self {
398        self.country = Some(country.into());
399        self
400    }
401
402    /// Set language code.
403    pub fn with_language(mut self, language: impl Into<String>) -> Self {
404        self.language = Some(language.into());
405        self
406    }
407
408    /// Filter results to specific domains.
409    pub fn with_site_filter(mut self, domains: Vec<String>) -> Self {
410        self.site_filter = Some(domains);
411        self
412    }
413
414    /// Exclude specific domains from results.
415    pub fn with_exclude_domains(mut self, domains: Vec<String>) -> Self {
416        self.exclude_domains = Some(domains);
417        self
418    }
419
420    /// Set time range filter.
421    pub fn with_time_range(mut self, range: TimeRange) -> Self {
422        self.time_range = Some(range);
423        self
424    }
425}
426
/// Time range for filtering search results.
///
/// Either one of the fixed look-back windows or a `Custom` range given as a
/// pair of provider-specific date strings.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum TimeRange {
    /// Results from the past day.
    Day,
    /// Results from the past week.
    Week,
    /// Results from the past month.
    Month,
    /// Results from the past year.
    Year,
    /// Custom date range.
    Custom {
        /// Start date (format depends on provider).
        start: String,
        /// End date (format depends on provider).
        end: String,
    },
}
446
447/// Research options for research tasks.
448#[derive(Debug, Clone, Default)]
449pub struct ResearchOptions {
450    /// Maximum pages to visit.
451    pub max_pages: usize,
452    /// Search options for the query.
453    pub search_options: Option<SearchOptions>,
454    /// Custom extraction prompt.
455    pub extraction_prompt: Option<String>,
456    /// Whether to synthesize findings into a summary.
457    pub synthesize: bool,
458}
459
460impl ResearchOptions {
461    /// Create new research options with defaults.
462    pub fn new() -> Self {
463        Self {
464            max_pages: 5,
465            search_options: None,
466            extraction_prompt: None,
467            synthesize: true,
468        }
469    }
470
471    /// Set max pages to visit.
472    pub fn with_max_pages(mut self, n: usize) -> Self {
473        self.max_pages = n;
474        self
475    }
476
477    /// Set search options.
478    pub fn with_search_options(mut self, options: SearchOptions) -> Self {
479        self.search_options = Some(options);
480        self
481    }
482
483    /// Set extraction prompt.
484    pub fn with_extraction_prompt(mut self, prompt: impl Into<String>) -> Self {
485        self.extraction_prompt = Some(prompt.into());
486        self
487    }
488
489    /// Enable or disable synthesis.
490    pub fn with_synthesize(mut self, enabled: bool) -> Self {
491        self.synthesize = enabled;
492        self
493    }
494}
495
/// Usage statistics for tracking agent operations.
///
/// The scalar counters are `AtomicU64`s, so they can be updated through a
/// shared reference without taking a lock. Per-tool counts are kept in a
/// `DashMap` keyed by tool name, with an `AtomicU64` per entry.
#[derive(Debug)]
pub struct UsageStats {
    /// Total LLM prompt tokens used.
    pub prompt_tokens: AtomicU64,
    /// Total LLM completion tokens used.
    pub completion_tokens: AtomicU64,
    /// Total LLM calls made.
    pub llm_calls: AtomicU64,
    /// Total search calls made.
    pub search_calls: AtomicU64,
    /// Total HTTP fetch calls made.
    pub fetch_calls: AtomicU64,
    /// Total web browser calls made (Chrome/WebDriver combined).
    pub webbrowser_calls: AtomicU64,
    /// Custom tool calls tracked by tool name.
    // NOTE(review): this was described as "lock-free via DashMap"; DashMap is
    // a sharded concurrent map that guards its shards with locks, so only the
    // per-entry counters are lock-free. Confirm against the dashmap version
    // in use.
    pub custom_tool_calls: DashMap<String, AtomicU64>,
    /// Total tool calls made.
    pub tool_calls: AtomicU64,
}
518
519impl Default for UsageStats {
520    fn default() -> Self {
521        Self {
522            prompt_tokens: AtomicU64::new(0),
523            completion_tokens: AtomicU64::new(0),
524            llm_calls: AtomicU64::new(0),
525            search_calls: AtomicU64::new(0),
526            fetch_calls: AtomicU64::new(0),
527            webbrowser_calls: AtomicU64::new(0),
528            custom_tool_calls: DashMap::new(),
529            tool_calls: AtomicU64::new(0),
530        }
531    }
532}
533
534impl UsageStats {
535    /// Create new usage stats.
536    pub fn new() -> Self {
537        Self::default()
538    }
539
540    /// Add tokens from an LLM response.
541    pub fn add_tokens(&self, prompt: u64, completion: u64) {
542        self.prompt_tokens.fetch_add(prompt, Ordering::Relaxed);
543        self.completion_tokens
544            .fetch_add(completion, Ordering::Relaxed);
545    }
546
547    /// Increment LLM call count.
548    pub fn increment_llm_calls(&self) {
549        self.llm_calls.fetch_add(1, Ordering::Relaxed);
550    }
551
552    /// Increment search call count.
553    pub fn increment_search_calls(&self) {
554        self.search_calls.fetch_add(1, Ordering::Relaxed);
555    }
556
557    /// Increment fetch call count.
558    pub fn increment_fetch_calls(&self) {
559        self.fetch_calls.fetch_add(1, Ordering::Relaxed);
560    }
561
562    /// Increment web browser call count (Chrome/WebDriver).
563    pub fn increment_webbrowser_calls(&self) {
564        self.webbrowser_calls.fetch_add(1, Ordering::Relaxed);
565    }
566
567    /// Increment custom tool call count for a specific tool.
568    pub fn increment_custom_tool_calls(&self, tool_name: &str) {
569        self.custom_tool_calls
570            .entry(tool_name.to_string())
571            .or_insert_with(|| AtomicU64::new(0))
572            .fetch_add(1, Ordering::Relaxed);
573    }
574
575    /// Get custom tool call count for a specific tool.
576    pub fn get_custom_tool_calls(&self, tool_name: &str) -> u64 {
577        self.custom_tool_calls
578            .get(tool_name)
579            .map(|v| v.load(Ordering::Relaxed))
580            .unwrap_or(0)
581    }
582
583    /// Get total custom tool calls across all tools.
584    pub fn total_custom_tool_calls(&self) -> u64 {
585        self.custom_tool_calls
586            .iter()
587            .map(|entry| entry.value().load(Ordering::Relaxed))
588            .sum()
589    }
590
591    /// Increment tool call count.
592    pub fn increment_tool_calls(&self) {
593        self.tool_calls.fetch_add(1, Ordering::Relaxed);
594    }
595
596    /// Get total tokens used.
597    pub fn total_tokens(&self) -> u64 {
598        self.prompt_tokens.load(Ordering::Relaxed) + self.completion_tokens.load(Ordering::Relaxed)
599    }
600
601    /// Get a snapshot of all stats.
602    pub fn snapshot(&self) -> UsageSnapshot {
603        let custom_tool_calls: HashMap<String, u64> = self
604            .custom_tool_calls
605            .iter()
606            .map(|entry| (entry.key().clone(), entry.value().load(Ordering::Relaxed)))
607            .collect();
608
609        UsageSnapshot {
610            prompt_tokens: self.prompt_tokens.load(Ordering::Relaxed),
611            completion_tokens: self.completion_tokens.load(Ordering::Relaxed),
612            llm_calls: self.llm_calls.load(Ordering::Relaxed),
613            search_calls: self.search_calls.load(Ordering::Relaxed),
614            fetch_calls: self.fetch_calls.load(Ordering::Relaxed),
615            webbrowser_calls: self.webbrowser_calls.load(Ordering::Relaxed),
616            custom_tool_calls,
617            tool_calls: self.tool_calls.load(Ordering::Relaxed),
618        }
619    }
620
621    /// Reset all counters.
622    pub fn reset(&self) {
623        self.prompt_tokens.store(0, Ordering::Relaxed);
624        self.completion_tokens.store(0, Ordering::Relaxed);
625        self.llm_calls.store(0, Ordering::Relaxed);
626        self.search_calls.store(0, Ordering::Relaxed);
627        self.fetch_calls.store(0, Ordering::Relaxed);
628        self.webbrowser_calls.store(0, Ordering::Relaxed);
629        self.custom_tool_calls.clear();
630        self.tool_calls.store(0, Ordering::Relaxed);
631    }
632
633    // ==================== Limit Checking Methods ====================
634
635    /// Check if LLM call limit would be exceeded.
636    pub fn check_llm_limit(&self, limits: &UsageLimits) -> Option<LimitType> {
637        if let Some(limit) = limits.max_llm_calls {
638            let used = self.llm_calls.load(Ordering::Relaxed);
639            if used >= limit {
640                return Some(LimitType::LlmCalls { used, limit });
641            }
642        }
643        None
644    }
645
646    /// Check if search call limit would be exceeded.
647    pub fn check_search_limit(&self, limits: &UsageLimits) -> Option<LimitType> {
648        if let Some(limit) = limits.max_search_calls {
649            let used = self.search_calls.load(Ordering::Relaxed);
650            if used >= limit {
651                return Some(LimitType::SearchCalls { used, limit });
652            }
653        }
654        None
655    }
656
657    /// Check if fetch call limit would be exceeded.
658    pub fn check_fetch_limit(&self, limits: &UsageLimits) -> Option<LimitType> {
659        if let Some(limit) = limits.max_fetch_calls {
660            let used = self.fetch_calls.load(Ordering::Relaxed);
661            if used >= limit {
662                return Some(LimitType::FetchCalls { used, limit });
663            }
664        }
665        None
666    }
667
668    /// Check if web browser call limit would be exceeded.
669    pub fn check_webbrowser_limit(&self, limits: &UsageLimits) -> Option<LimitType> {
670        if let Some(limit) = limits.max_webbrowser_calls {
671            let used = self.webbrowser_calls.load(Ordering::Relaxed);
672            if used >= limit {
673                return Some(LimitType::WebbrowserCalls { used, limit });
674            }
675        }
676        None
677    }
678
679    /// Check if custom tool call limit would be exceeded (total across all tools).
680    pub fn check_custom_tool_limit(&self, limits: &UsageLimits) -> Option<LimitType> {
681        if let Some(limit) = limits.max_custom_tool_calls {
682            let used = self.total_custom_tool_calls();
683            if used >= limit {
684                return Some(LimitType::CustomToolCalls { used, limit });
685            }
686        }
687        None
688    }
689
690    /// Check if tool call limit would be exceeded.
691    pub fn check_tool_limit(&self, limits: &UsageLimits) -> Option<LimitType> {
692        if let Some(limit) = limits.max_tool_calls {
693            let used = self.tool_calls.load(Ordering::Relaxed);
694            if used >= limit {
695                return Some(LimitType::ToolCalls { used, limit });
696            }
697        }
698        None
699    }
700
701    /// Check if token limits would be exceeded.
702    pub fn check_token_limits(&self, limits: &UsageLimits) -> Option<LimitType> {
703        let prompt = self.prompt_tokens.load(Ordering::Relaxed);
704        let completion = self.completion_tokens.load(Ordering::Relaxed);
705        let total = prompt + completion;
706
707        if let Some(limit) = limits.max_total_tokens {
708            if total >= limit {
709                return Some(LimitType::TotalTokens { used: total, limit });
710            }
711        }
712
713        if let Some(limit) = limits.max_prompt_tokens {
714            if prompt >= limit {
715                return Some(LimitType::PromptTokens {
716                    used: prompt,
717                    limit,
718                });
719            }
720        }
721
722        if let Some(limit) = limits.max_completion_tokens {
723            if completion >= limit {
724                return Some(LimitType::CompletionTokens {
725                    used: completion,
726                    limit,
727                });
728            }
729        }
730
731        None
732    }
733}
734
/// Snapshot of usage statistics.
///
/// A plain-value copy of the counters held by `UsageStats`, produced by
/// `UsageStats::snapshot`. It is cloneable and (de)serializable.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct UsageSnapshot {
    /// Total LLM prompt tokens.
    pub prompt_tokens: u64,
    /// Total LLM completion tokens.
    pub completion_tokens: u64,
    /// Total LLM calls.
    pub llm_calls: u64,
    /// Total search calls.
    pub search_calls: u64,
    /// Total HTTP fetch calls.
    pub fetch_calls: u64,
    /// Total web browser calls (Chrome/WebDriver combined).
    pub webbrowser_calls: u64,
    /// Custom tool calls by tool name.
    pub custom_tool_calls: HashMap<String, u64>,
    /// Total tool calls.
    pub tool_calls: u64,
}
755
756impl UsageSnapshot {
757    /// Get total tokens.
758    pub fn total_tokens(&self) -> u64 {
759        self.prompt_tokens + self.completion_tokens
760    }
761
762    /// Get total custom tool calls across all tools.
763    pub fn total_custom_tool_calls(&self) -> u64 {
764        self.custom_tool_calls.values().sum()
765    }
766
767    /// Get call count for a specific custom tool.
768    pub fn get_custom_tool_calls(&self, tool_name: &str) -> u64 {
769        self.custom_tool_calls.get(tool_name).copied().unwrap_or(0)
770    }
771}
772
#[cfg(test)]
mod tests {
    use super::*;

    /// Chained builders populate exactly the requested limits.
    #[test]
    fn test_usage_limits_builder() {
        let limits = UsageLimits::new()
            .with_max_total_tokens(10000)
            .with_max_llm_calls(100)
            .with_max_search_calls(50)
            .with_max_fetch_calls(200)
            .with_max_webbrowser_calls(30)
            .with_max_custom_tool_calls(25)
            .with_max_tool_calls(500);

        let observed = [
            limits.max_total_tokens,
            limits.max_llm_calls,
            limits.max_search_calls,
            limits.max_fetch_calls,
            limits.max_webbrowser_calls,
            limits.max_custom_tool_calls,
            limits.max_tool_calls,
        ];
        let expected = [
            Some(10000),
            Some(100),
            Some(50),
            Some(200),
            Some(30),
            Some(25),
            Some(500),
        ];
        assert_eq!(observed, expected);
    }

    /// Counters accumulate and show up in the snapshot.
    #[test]
    fn test_usage_stats_tracking() {
        let stats = UsageStats::new();

        // Track various calls
        for _ in 0..2 {
            stats.increment_llm_calls();
        }
        stats.increment_search_calls();
        for _ in 0..3 {
            stats.increment_fetch_calls();
        }
        stats.increment_webbrowser_calls();
        for tool in ["my_api", "my_api", "other_api"].iter() {
            stats.increment_custom_tool_calls(tool);
        }
        stats.add_tokens(100, 50);

        let snap = stats.snapshot();
        assert_eq!(snap.llm_calls, 2);
        assert_eq!(snap.search_calls, 1);
        assert_eq!(snap.fetch_calls, 3);
        assert_eq!(snap.webbrowser_calls, 1);
        assert_eq!(snap.prompt_tokens, 100);
        assert_eq!(snap.completion_tokens, 50);
        assert_eq!(snap.total_tokens(), 150);

        // Check custom tool tracking
        let per_tool = [("my_api", 2), ("other_api", 1), ("unknown", 0)];
        for (tool, count) in per_tool.iter() {
            assert_eq!(snap.get_custom_tool_calls(tool), *count);
        }
        assert_eq!(snap.total_custom_tool_calls(), 3);
    }

    /// `reset` brings every counter back to zero.
    #[test]
    fn test_usage_stats_reset() {
        let stats = UsageStats::new();
        stats.increment_llm_calls();
        stats.increment_search_calls();
        stats.increment_custom_tool_calls("my_api");
        stats.add_tokens(100, 50);

        stats.reset();

        let snap = stats.snapshot();
        assert_eq!(snap.llm_calls, 0);
        assert_eq!(snap.search_calls, 0);
        assert_eq!(snap.prompt_tokens, 0);
        assert_eq!(snap.total_custom_tool_calls(), 0);
    }

    /// The LLM call check fires once the count reaches the limit.
    #[test]
    fn test_limit_checking_llm() {
        let stats = UsageStats::new();
        let limits = UsageLimits::new().with_max_llm_calls(3);

        // Under limit
        for _ in 0..2 {
            stats.increment_llm_calls();
        }
        assert_eq!(stats.check_llm_limit(&limits), None);

        // At limit
        stats.increment_llm_calls();
        assert_eq!(
            stats.check_llm_limit(&limits),
            Some(LimitType::LlmCalls { used: 3, limit: 3 })
        );
    }

    /// The prompt limit is reported when only it has been reached.
    #[test]
    fn test_limit_checking_tokens() {
        let stats = UsageStats::new();
        let limits = UsageLimits::new()
            .with_max_total_tokens(100)
            .with_max_prompt_tokens(60);

        // Under limit
        stats.add_tokens(30, 20);
        assert_eq!(stats.check_token_limits(&limits), None);

        // Prompt limit exceeded
        stats.add_tokens(40, 0);
        assert_eq!(
            stats.check_token_limits(&limits),
            Some(LimitType::PromptTokens {
                used: 70,
                limit: 60
            })
        );
    }

    /// The custom tool limit applies to the sum over all tools.
    #[test]
    fn test_limit_checking_custom_tools() {
        let stats = UsageStats::new();
        let limits = UsageLimits::new().with_max_custom_tool_calls(5);

        for tool in ["api_a", "api_b", "api_a", "api_c"].iter() {
            stats.increment_custom_tool_calls(tool);
        }
        assert_eq!(stats.check_custom_tool_limit(&limits), None);

        stats.increment_custom_tool_calls("api_a");
        assert_eq!(
            stats.check_custom_tool_limit(&limits),
            Some(LimitType::CustomToolCalls { used: 5, limit: 5 })
        );
    }

    /// Limits attached to a config are stored unchanged.
    #[test]
    fn test_agent_config_with_limits() {
        let limits = UsageLimits::new()
            .with_max_llm_calls(100)
            .with_max_search_calls(50);

        let config = AgentConfig::new().with_limits(limits);

        assert_eq!(config.limits.max_llm_calls, Some(100));
        assert_eq!(config.limits.max_search_calls, Some(50));
    }

    /// `Display` renders the label followed by usage and limit.
    #[test]
    fn test_limit_type_display() {
        let llm = LimitType::LlmCalls { used: 10, limit: 5 };
        assert_eq!(llm.to_string(), "LLM calls (10 used, 5 limit)");

        let custom = LimitType::CustomToolCalls {
            used: 25,
            limit: 20,
        };
        assert_eq!(custom.to_string(), "custom tool calls (25 used, 20 limit)");

        let total = LimitType::TotalTokens {
            used: 1000,
            limit: 500,
        };
        assert_eq!(total.to_string(), "total tokens (1000 used, 500 limit)");
    }
}