1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
7pub enum ThrottleLevel {
8 Normal,
10 Reduced,
12 Blocked,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct ThrottleConfig {
19 #[serde(default = "default_normal_limit")]
22 pub normal_limit: u32,
23 #[serde(default = "default_reduced_limit")]
26 pub reduced_limit: u32,
27}
28
/// Default `normal_limit`: the first three identical calls are `Normal`.
const DEFAULT_NORMAL_LIMIT: u32 = 3;

/// Default `reduced_limit`: calls four through eight are `Reduced`.
const DEFAULT_REDUCED_LIMIT: u32 = 8;

// Compile-time guard: the stock thresholds must leave a non-empty `Reduced`
// band. If they did not, counts would jump straight from `Normal` to `Blocked`.
const _: () = assert!(DEFAULT_NORMAL_LIMIT < DEFAULT_REDUCED_LIMIT);

/// Default for `ThrottleConfig::normal_limit` (serde field default and `Default` impl).
const fn default_normal_limit() -> u32 {
    DEFAULT_NORMAL_LIMIT
}

/// Default for `ThrottleConfig::reduced_limit` (serde field default and `Default` impl).
const fn default_reduced_limit() -> u32 {
    DEFAULT_REDUCED_LIMIT
}
36
37impl Default for ThrottleConfig {
38 fn default() -> Self {
39 Self {
40 normal_limit: default_normal_limit(),
41 reduced_limit: default_reduced_limit(),
42 }
43 }
44}
45
46type CallKey = (String, u64);
48
49#[derive(Debug, Clone)]
55pub struct ProgressiveThrottler {
56 config: ThrottleConfig,
57 counts: HashMap<CallKey, u32>,
58}
59
60impl ProgressiveThrottler {
61 pub fn new(config: ThrottleConfig) -> Self {
63 Self {
64 config,
65 counts: HashMap::new(),
66 }
67 }
68
69 pub fn record_call(&mut self, tool_name: &str, params_hash: u64) -> ThrottleLevel {
71 let key = (tool_name.to_string(), params_hash);
72 let count = self.counts.entry(key).or_insert(0);
73 *count += 1;
74 let c = *count;
75 self.level_for_count(c)
76 }
77
78 pub fn get_level(&self, tool_name: &str, params_hash: u64) -> ThrottleLevel {
81 let key = (tool_name.to_string(), params_hash);
82 let count = self.counts.get(&key).copied().unwrap_or(0);
83 self.level_for_count(count)
84 }
85
86 pub fn reset(&mut self) {
88 self.counts.clear();
89 }
90
91 pub fn reset_tool(&mut self, tool_name: &str) {
93 self.counts.retain(|(name, _), _| name != tool_name);
94 }
95
96 pub fn call_count(&self, tool_name: &str, params_hash: u64) -> u32 {
98 let key = (tool_name.to_string(), params_hash);
99 self.counts.get(&key).copied().unwrap_or(0)
100 }
101
102 fn level_for_count(&self, count: u32) -> ThrottleLevel {
105 if count <= self.config.normal_limit {
106 ThrottleLevel::Normal
107 } else if count <= self.config.reduced_limit {
108 ThrottleLevel::Reduced
109 } else {
110 ThrottleLevel::Blocked
111 }
112 }
113}
114
115
#[cfg(test)]
mod tests {
    use super::*;

    /// Throttler with the stock thresholds: `Normal` up to 3 calls, `Reduced` up to 8.
    fn default_throttler() -> ProgressiveThrottler {
        ProgressiveThrottler::new(ThrottleConfig::default())
    }

    #[test]
    fn normal_band_for_first_three_calls() {
        let mut t = default_throttler();
        assert_eq!(t.record_call("read_file", 42), ThrottleLevel::Normal);
        assert_eq!(t.record_call("read_file", 42), ThrottleLevel::Normal);
        assert_eq!(t.record_call("read_file", 42), ThrottleLevel::Normal);
    }

    #[test]
    fn reduced_band_for_calls_four_through_eight() {
        let mut t = default_throttler();
        for _ in 0..3 {
            t.record_call("read_file", 1);
        }
        for i in 4..=8 {
            let level = t.record_call("read_file", 1);
            assert_eq!(level, ThrottleLevel::Reduced, "call {i} should be Reduced");
        }
    }

    #[test]
    fn blocked_after_eight_calls() {
        let mut t = default_throttler();
        for _ in 0..8 {
            t.record_call("search", 99);
        }
        assert_eq!(t.record_call("search", 99), ThrottleLevel::Blocked);
        assert_eq!(t.record_call("search", 99), ThrottleLevel::Blocked);
    }

    #[test]
    fn blocked_calls_are_still_counted() {
        let mut t = default_throttler();
        for _ in 0..12 {
            t.record_call("search", 7);
        }
        assert_eq!(t.call_count("search", 7), 12);
        assert_eq!(t.get_level("search", 7), ThrottleLevel::Blocked);
    }

    #[test]
    fn different_params_hash_tracked_independently() {
        let mut t = default_throttler();
        for _ in 0..8 {
            t.record_call("read_file", 1);
        }
        assert_eq!(t.record_call("read_file", 2), ThrottleLevel::Normal);
    }

    #[test]
    fn different_tools_tracked_independently() {
        let mut t = default_throttler();
        for _ in 0..8 {
            t.record_call("read_file", 1);
        }
        assert_eq!(t.record_call("write_file", 1), ThrottleLevel::Normal);
    }

    #[test]
    fn get_level_without_recording() {
        let mut t = default_throttler();
        assert_eq!(t.get_level("read_file", 1), ThrottleLevel::Normal);

        for _ in 0..4 {
            t.record_call("read_file", 1);
        }
        assert_eq!(t.get_level("read_file", 1), ThrottleLevel::Reduced);
        // Querying the level must not itself count as a call.
        assert_eq!(t.call_count("read_file", 1), 4);
    }

    #[test]
    fn reset_clears_all_counters() {
        let mut t = default_throttler();
        for _ in 0..5 {
            t.record_call("read_file", 1);
            t.record_call("search", 2);
        }
        t.reset();
        assert_eq!(t.get_level("read_file", 1), ThrottleLevel::Normal);
        assert_eq!(t.get_level("search", 2), ThrottleLevel::Normal);
        assert_eq!(t.call_count("read_file", 1), 0);
    }

    #[test]
    fn reset_tool_clears_only_that_tool() {
        let mut t = default_throttler();
        for _ in 0..5 {
            t.record_call("read_file", 1);
            t.record_call("read_file", 99);
            t.record_call("search", 2);
        }
        t.reset_tool("read_file");
        assert_eq!(t.call_count("read_file", 1), 0);
        assert_eq!(t.call_count("read_file", 99), 0);
        assert_eq!(t.call_count("search", 2), 5);
        assert_eq!(t.get_level("search", 2), ThrottleLevel::Reduced);
    }

    #[test]
    fn reset_tool_for_unknown_tool_is_a_no_op() {
        let mut t = default_throttler();
        t.record_call("read_file", 1);
        t.reset_tool("never_called");
        assert_eq!(t.call_count("read_file", 1), 1);
    }

    #[test]
    fn custom_thresholds() {
        let config = ThrottleConfig {
            normal_limit: 1,
            reduced_limit: 2,
        };
        let mut t = ProgressiveThrottler::new(config);
        assert_eq!(t.record_call("t", 0), ThrottleLevel::Normal);
        assert_eq!(t.record_call("t", 0), ThrottleLevel::Reduced);
        assert_eq!(t.record_call("t", 0), ThrottleLevel::Blocked);
    }

    #[test]
    fn zero_count_is_normal() {
        let t = default_throttler();
        assert_eq!(t.get_level("nonexistent", 0), ThrottleLevel::Normal);
    }

    #[test]
    fn call_count_tracks_correctly() {
        let mut t = default_throttler();
        assert_eq!(t.call_count("x", 1), 0);
        t.record_call("x", 1);
        assert_eq!(t.call_count("x", 1), 1);
        t.record_call("x", 1);
        t.record_call("x", 1);
        assert_eq!(t.call_count("x", 1), 3);
    }

    mod prop_tests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            // Property: for arbitrary thresholds with a non-empty Reduced band,
            // call number k (1-based) is Normal while k <= normal_limit, Reduced
            // while k <= reduced_limit, and Blocked afterwards. A subsequent
            // `reset()` brings the key back to count 0 and level Normal.
            #[test]
            fn progressive_throttling_enforces_limits(
                normal_limit in 1u32..=20,
                gap in 1u32..=20,
                extra_calls in 1u32..=10,
                tool_name in "[a-z_]{1,12}",
                params_hash in any::<u64>(),
            ) {
                let reduced_limit = normal_limit + gap;
                let total_calls = reduced_limit + extra_calls;

                let config = ThrottleConfig {
                    normal_limit,
                    reduced_limit,
                };
                let mut throttler = ProgressiveThrottler::new(config);

                for call_num in 1..=total_calls {
                    let level = throttler.record_call(&tool_name, params_hash);

                    if call_num <= normal_limit {
                        prop_assert_eq!(
                            level,
                            ThrottleLevel::Normal,
                            "call {} should be Normal (normal_limit={})",
                            call_num,
                            normal_limit,
                        );
                    } else if call_num <= reduced_limit {
                        prop_assert_eq!(
                            level,
                            ThrottleLevel::Reduced,
                            "call {} should be Reduced (normal_limit={}, reduced_limit={})",
                            call_num,
                            normal_limit,
                            reduced_limit,
                        );
                    } else {
                        prop_assert_eq!(
                            level,
                            ThrottleLevel::Blocked,
                            "call {} should be Blocked (reduced_limit={})",
                            call_num,
                            reduced_limit,
                        );
                    }
                }

                throttler.reset();
                prop_assert_eq!(
                    throttler.get_level(&tool_name, params_hash),
                    ThrottleLevel::Normal,
                    "after reset(), level should be Normal",
                );
                prop_assert_eq!(
                    throttler.call_count(&tool_name, params_hash),
                    0,
                    "after reset(), call count should be 0",
                );
            }
        }
    }
}