1use dashmap::DashMap;
4use std::time::Instant;
5
6pub struct RateLimiter {
7 buckets: DashMap<String, TokenBucket>,
9 default_rate: u64,
11 agent_rates: DashMap<String, u64>,
13}
14
/// Mutable state of a single token bucket.
struct TokenBucket {
    /// Tokens currently available; fractional because refill is continuous.
    tokens: f64,

    /// Upper bound that `tokens` is clamped to on refill.
    max_tokens: f64,

    /// Tokens added per second of elapsed time.
    refill_rate: f64,

    /// Instant at which `tokens` was last brought up to date.
    last_refill: Instant,
}
21
22impl Default for RateLimiter {
23 fn default() -> Self {
24 Self::new()
25 }
26}
27
28impl RateLimiter {
29 pub fn new() -> Self {
30 Self {
31 buckets: DashMap::new(),
32 default_rate: 60, agent_rates: DashMap::new(),
34 }
35 }
36
37 #[inline]
39 pub fn check(&self, agent_id: &str, tool_name: &str) -> bool {
40 let key = bucket_key(agent_id, tool_name);
41 let rate = self
42 .agent_rates
43 .get(agent_id)
44 .map(|r| *r)
45 .unwrap_or(self.default_rate);
46 let mut bucket = self.buckets.entry(key).or_insert_with(|| TokenBucket {
47 tokens: rate as f64,
48 max_tokens: rate as f64,
49 refill_rate: rate as f64 / 60.0,
50 last_refill: Instant::now(),
51 });
52
53 let now = Instant::now();
55 let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
56 bucket.tokens = (bucket.tokens + elapsed * bucket.refill_rate).min(bucket.max_tokens);
57 bucket.last_refill = now;
58
59 if bucket.tokens >= 1.0 {
61 bucket.tokens -= 1.0;
62 true
63 } else {
64 false
65 }
66 }
67
68 pub fn set_rate(&self, agent_id: &str, calls_per_minute: u64) {
73 let new_max = calls_per_minute as f64;
74 self.agent_rates
76 .insert(agent_id.to_string(), calls_per_minute);
77 let prefix = format!("{agent_id}\x1f");
79 for mut entry in self.buckets.iter_mut() {
80 if entry.key().starts_with(&prefix) {
81 let bucket = entry.value_mut();
82 bucket.max_tokens = new_max;
83 bucket.refill_rate = new_max / 60.0;
84 bucket.tokens = bucket.tokens.min(new_max);
86 }
87 }
88 }
89}
90
/// Builds the map key for an (agent, tool) pair: the agent id, the ASCII
/// unit separator `\x1f`, then the tool name.
///
/// The buffer is sized up front to the exact length of the result.
#[inline]
fn bucket_key(agent_id: &str, tool_name: &str) -> String {
    let total_len = agent_id.len() + 1 + tool_name.len();
    let mut joined = String::with_capacity(total_len);
    joined.push_str(agent_id);
    joined.push('\x1f');
    joined.push_str(tool_name);
    joined
}
99
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// A fresh bucket at the default rate allows exactly 60 calls in a burst.
    #[test]
    fn basic_rate_limit() {
        let limiter = RateLimiter::new();
        for _ in 0..60 {
            assert!(limiter.check("agent", "tool"));
        }
        assert!(!limiter.check("agent", "tool"));
    }

    /// Draining one agent's bucket leaves another agent's bucket untouched.
    #[test]
    fn different_agents_separate_buckets() {
        let limiter = RateLimiter::new();
        for _ in 0..60 {
            limiter.check("agent-a", "tool");
        }
        assert!(limiter.check("agent-b", "tool"));
    }

    /// Each tool of the same agent has its own bucket.
    #[test]
    fn different_tools_separate_buckets() {
        let limiter = RateLimiter::new();
        for _ in 0..60 {
            limiter.check("agent", "tool_a");
        }
        assert!(!limiter.check("agent", "tool_a"));
        assert!(limiter.check("agent", "tool_b"));
    }

    /// Lowering the rate clamps an existing bucket to the new capacity.
    #[test]
    fn set_rate_lowers_limit() {
        let limiter = RateLimiter::new();
        assert!(limiter.check("agent", "tool"));
        limiter.set_rate("agent", 10);
        let mut allowed = 0;
        for _ in 0..20 {
            if limiter.check("agent", "tool") {
                allowed += 1;
            } else {
                break;
            }
        }
        assert_eq!(allowed, 10);
    }

    /// Changing one agent's rate must not touch another agent's buckets.
    #[test]
    fn set_rate_does_not_affect_other_agents() {
        let limiter = RateLimiter::new();
        assert!(limiter.check("agent-a", "tool"));
        assert!(limiter.check("agent-b", "tool"));

        limiter.set_rate("agent-a", 5);

        let mut count = 0;
        for _ in 0..59 {
            if limiter.check("agent-b", "tool") {
                count += 1;
            }
        }
        assert_eq!(count, 59);
    }

    /// An override set before the first call applies to the bucket created later.
    #[test]
    fn set_rate_before_any_check() {
        let limiter = RateLimiter::new();
        limiter.set_rate("nobody", 5);
        let mut count = 0;
        for _ in 0..10 {
            if limiter.check("nobody", "tool") {
                count += 1;
            }
        }
        assert_eq!(count, 5);
    }

    /// A drained bucket admits calls again once time has passed.
    #[test]
    fn token_refill_over_time() {
        let limiter = RateLimiter::new();
        for _ in 0..60 {
            limiter.check("agent", "tool");
        }
        assert!(!limiter.check("agent", "tool"));

        let key = bucket_key("agent", "tool");
        assert!(limiter.buckets.contains_key(&key));

        // Simulate two seconds passing by moving the bucket's refill
        // timestamp into the past instead of sleeping. The guard is scoped so
        // it is released before `check` needs the same entry again.
        {
            let mut bucket = limiter
                .buckets
                .get_mut(&key)
                .expect("bucket was created by the calls above");
            bucket.last_refill = Instant::now()
                .checked_sub(Duration::from_secs(2))
                .expect("clock should allow rewinding by two seconds");
        }

        // At 60 calls per minute, two seconds refill about two tokens.
        assert!(limiter.check("agent", "tool"));
    }
}