1use hashbrown::HashMap;
7use std::sync::Arc;
8
9use std::time::{Duration, Instant};
10use tokio::sync::RwLock;
11use tokio::time;
12
/// Category of work tracked by [`TimeoutDetector`].
///
/// The variant is the key under which a [`TimeoutConfig`] is stored, so every
/// operation of the same type shares one timeout/retry policy.
#[derive(Debug, Clone, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum OperationType {
    /// Registered with the `TimeoutConfig::api_call` preset by `TimeoutDetector::new`.
    ApiCall,
    /// Registered with the `TimeoutConfig::file_operation` preset by `TimeoutDetector::new`.
    FileOperation,
    /// Registered with the `TimeoutConfig::analysis` preset by `TimeoutDetector::new`.
    CodeAnalysis,
    /// Registered with `TimeoutConfig::default()` by `TimeoutDetector::new`.
    ToolExecution,
    /// Registered with the `TimeoutConfig::api_call` preset by `TimeoutDetector::new`.
    NetworkRequest,
    /// Registered with the `TimeoutConfig::analysis` preset by `TimeoutDetector::new`.
    Processing,
    /// Caller-defined category identified by name. No preset is registered for
    /// it by `TimeoutDetector::new`; `get_config` falls back to
    /// `TimeoutConfig::default()` unless `set_config` stored one.
    Custom(String),
}
31
/// Timeout and retry policy applied to one [`OperationType`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TimeoutConfig {
    /// Time limit applied to each individual attempt.
    pub timeout_duration: Duration,
    /// `should_retry` returns `false` once the attempt number reaches this value.
    pub max_retries: u32,
    /// Base delay of the exponential backoff (the delay for attempt `0`).
    pub initial_retry_delay: Duration,
    /// Cap applied to the backoff delay before any jitter is added.
    pub max_retry_delay: Duration,
    /// Factor raised to the power of the attempt number to scale the base delay.
    pub backoff_multiplier: f64,
    /// When `true`, `calculate_retry_delay` adds a clock-derived extra of less
    /// than 10% on top of the capped delay.
    pub use_jitter: bool,
    /// When `true`, errors whose message contains "timeout" or "timed out"
    /// (ASCII case-insensitive) are retried.
    pub retry_on_timeout: bool,
    /// Substrings that mark an error as retryable; matched against the error
    /// message ignoring ASCII case.
    pub retry_on_errors: Vec<String>,
}
52
/// Error-message substrings that `TimeoutConfig::default()` treats as retryable.
const DEFAULT_RETRY_ERRORS: [&str; 4] = [
    "timeout",
    "connection",
    "network",
    "server_error",
];
55
56impl Default for TimeoutConfig {
57 fn default() -> Self {
58 Self {
59 timeout_duration: Duration::from_secs(30),
60 max_retries: 3,
61 initial_retry_delay: Duration::from_millis(100),
62 max_retry_delay: Duration::from_secs(30),
63 backoff_multiplier: 2.0,
64 use_jitter: true,
65 retry_on_timeout: true,
66 retry_on_errors: DEFAULT_RETRY_ERRORS.iter().map(|s| (*s).into()).collect(),
67 }
68 }
69}
70
71impl TimeoutConfig {
72 pub fn api_call() -> Self {
74 Self {
75 timeout_duration: Duration::from_secs(60),
76 max_retries: 5,
77 initial_retry_delay: Duration::from_millis(200),
78 max_retry_delay: Duration::from_secs(10),
79 backoff_multiplier: 1.5,
80 ..Default::default()
81 }
82 }
83
84 pub fn file_operation() -> Self {
86 Self {
87 timeout_duration: Duration::from_secs(10),
88 max_retries: 2,
89 initial_retry_delay: Duration::from_millis(50),
90 max_retry_delay: Duration::from_secs(2),
91 backoff_multiplier: 2.0,
92 retry_on_timeout: false, ..Default::default()
94 }
95 }
96
97 pub fn analysis() -> Self {
99 Self {
100 timeout_duration: Duration::from_secs(120),
101 max_retries: 1,
102 initial_retry_delay: Duration::from_secs(5),
103 max_retry_delay: Duration::from_secs(10),
104 backoff_multiplier: 2.0,
105 ..Default::default()
106 }
107 }
108}
109
/// Bookkeeping record for one operation currently tracked by [`TimeoutDetector`].
#[derive(Debug, Clone)]
pub struct TimeoutEvent {
    /// Identifier the operation was registered under.
    pub operation_id: String,
    /// Category whose configuration supplied `timeout_duration`.
    pub operation_type: OperationType,
    /// Moment the operation was registered.
    pub start_time: Instant,
    /// Time limit copied from the configuration when the operation started.
    pub timeout_duration: Duration,
    /// Initialised to `0` by `start_operation`; nothing in this file updates it afterwards.
    pub retry_count: u32,
    /// Message stored by `record_timeout`; `None` until then.
    pub error_message: Option<String>,
}
120
/// Aggregate counters maintained by [`TimeoutDetector`].
#[derive(Debug, Clone, Default)]
pub struct TimeoutStats {
    /// Incremented by every `start_operation` call.
    pub total_operations: usize,
    /// Incremented by every `record_timeout` call.
    pub timed_out_operations: usize,
    /// Incremented by every `record_successful_retry` call.
    pub successful_retries: usize,
    /// Incremented by every `record_failed_retry` call.
    pub failed_retries: usize,
    /// Running mean maintained by `end_operation` from the elapsed time of each
    /// ended operation (despite the name, it is not derived from the
    /// configured timeout values).
    pub average_timeout_duration: Duration,
    /// Incremented by both `record_successful_retry` and `record_failed_retry`.
    pub total_retry_attempts: usize,
}
131
/// Tracks running operations, their timeout policies and retry statistics.
///
/// All state sits behind `Arc<RwLock<..>>`, so clones of a detector share the
/// same configuration, active-operation table and statistics.
pub struct TimeoutDetector {
    /// Per-operation-type timeout/retry policies.
    configs: Arc<RwLock<HashMap<OperationType, TimeoutConfig>>>,
    /// Operations that have been started and not yet ended, keyed by operation id.
    active_operations: Arc<RwLock<HashMap<String, TimeoutEvent>>>,
    /// Aggregate counters updated by the `record_*`, `start_operation` and
    /// `end_operation` methods.
    stats: Arc<RwLock<TimeoutStats>>,
}
138
139impl Default for TimeoutDetector {
140 fn default() -> Self {
141 Self::new()
142 }
143}
144
145impl TimeoutDetector {
146 pub fn new() -> Self {
147 let mut configs = HashMap::new();
148
149 configs.insert(OperationType::ApiCall, TimeoutConfig::api_call());
151 configs.insert(
152 OperationType::FileOperation,
153 TimeoutConfig::file_operation(),
154 );
155 configs.insert(OperationType::CodeAnalysis, TimeoutConfig::analysis());
156 configs.insert(OperationType::ToolExecution, TimeoutConfig::default());
157 configs.insert(OperationType::NetworkRequest, TimeoutConfig::api_call());
158 configs.insert(OperationType::Processing, TimeoutConfig::analysis());
159
160 Self {
161 configs: Arc::new(RwLock::new(configs)),
162 active_operations: Arc::new(RwLock::new(HashMap::new())),
163 stats: Arc::new(RwLock::new(TimeoutStats::default())),
164 }
165 }
166
167 pub async fn set_config(&self, operation_type: OperationType, config: TimeoutConfig) {
169 let mut configs = self.configs.write().await;
170 configs.insert(operation_type, config);
171 }
172
173 pub async fn get_config(&self, operation_type: &OperationType) -> TimeoutConfig {
175 let configs = self.configs.read().await;
176 configs.get(operation_type).cloned().unwrap_or_default()
177 }
178
179 pub async fn start_operation(
181 &self,
182 operation_id: String,
183 operation_type: OperationType,
184 ) -> TimeoutHandle {
185 let config = self.get_config(&operation_type).await;
186
187 let event = TimeoutEvent {
188 operation_id: operation_id.clone(),
189 operation_type,
190 start_time: Instant::now(),
191 timeout_duration: config.timeout_duration,
192 retry_count: 0,
193 error_message: None,
194 };
195
196 let mut active_ops = self.active_operations.write().await;
197 active_ops.insert(operation_id.clone(), event);
198
199 let mut stats = self.stats.write().await;
200 stats.total_operations += 1;
201
202 TimeoutHandle {
203 operation_id,
204 detector: Arc::new(self.clone()),
205 }
206 }
207
208 pub async fn check_timeout(&self, operation_id: &str) -> Option<TimeoutEvent> {
210 let active_ops = self.active_operations.read().await;
211 active_ops
212 .get(operation_id)
213 .filter(|event| event.start_time.elapsed() >= event.timeout_duration)
214 .cloned()
215 }
216
217 pub async fn record_timeout(&self, operation_id: &str, error_message: Option<String>) {
219 let mut active_ops = self.active_operations.write().await;
220 if let Some(event) = active_ops.get_mut(operation_id) {
221 event.error_message = error_message;
222 }
223
224 let mut stats = self.stats.write().await;
225 stats.timed_out_operations += 1;
226 }
227
228 pub async fn record_successful_retry(&self, _operation_id: &str) {
230 let mut stats = self.stats.write().await;
231 stats.successful_retries += 1;
232 stats.total_retry_attempts += 1;
233 }
234
235 pub async fn record_failed_retry(&self, _operation_id: &str) {
237 let mut stats = self.stats.write().await;
238 stats.failed_retries += 1;
239 stats.total_retry_attempts += 1;
240 }
241
242 pub async fn end_operation(&self, operation_id: &str) {
244 let mut active_ops = self.active_operations.write().await;
245 if let Some(event) = active_ops.remove(operation_id) {
246 let duration = event.start_time.elapsed();
247 let mut stats = self.stats.write().await;
248 if stats.total_operations > 0 {
250 let total_duration =
251 stats.average_timeout_duration * (stats.total_operations - 1) as u32;
252 stats.average_timeout_duration =
253 (total_duration + duration) / stats.total_operations as u32;
254 }
255 }
256 }
257
258 pub async fn get_stats(&self) -> TimeoutStats {
260 self.stats.read().await.clone()
261 }
262
263 pub async fn calculate_retry_delay(
265 &self,
266 operation_type: &OperationType,
267 attempt: u32,
268 ) -> Duration {
269 let config = self.get_config(operation_type).await;
270
271 let base_delay = config.initial_retry_delay.as_millis() as f64;
272 let multiplier = config.backoff_multiplier.powi(attempt as i32);
273 let delay_ms = (base_delay * multiplier) as u64;
274
275 let mut delay =
276 Duration::from_millis(delay_ms.min(config.max_retry_delay.as_millis() as u64));
277
278 if config.use_jitter {
280 use std::time::SystemTime;
281 let seed = SystemTime::now()
282 .duration_since(std::time::UNIX_EPOCH)
283 .unwrap_or_default()
284 .as_nanos() as u64;
285 let jitter_factor = (seed % 100) as f64 / 100.0; let jitter_ms = (delay.as_millis() as f64 * 0.1 * jitter_factor) as u64; delay += Duration::from_millis(jitter_ms);
288 }
289
290 delay
291 }
292
293 pub async fn should_retry(
296 &self,
297 operation_type: &OperationType,
298 error: &anyhow::Error,
299 attempt: u32,
300 ) -> bool {
301 let config = self.get_config(operation_type).await;
302
303 if attempt >= config.max_retries {
304 return false;
305 }
306
307 let error_str = error.to_string();
308
309 let contains_ci = |pattern: &str| {
311 error_str
312 .as_bytes()
313 .windows(pattern.len())
314 .any(|window| window.eq_ignore_ascii_case(pattern.as_bytes()))
315 };
316
317 for retry_error in &config.retry_on_errors {
319 if contains_ci(retry_error) {
320 return true;
321 }
322 }
323
324 if config.retry_on_timeout && (contains_ci("timeout") || contains_ci("timed out")) {
326 return true;
327 }
328
329 false
330 }
331
332 pub async fn execute_with_timeout_retry<F, Fut, T>(
334 &self,
335 operation_id: String,
336 operation_type: OperationType,
337 mut operation: F,
338 ) -> Result<T, anyhow::Error>
339 where
340 F: FnMut() -> Fut,
341 Fut: Future<Output = Result<T, anyhow::Error>>,
342 {
343 let config = self.get_config(&operation_type).await;
344 let mut attempt = 0;
345 let _last_error: Option<anyhow::Error> = None;
346
347 loop {
348 let handle = self
349 .start_operation(
350 format!("{}_{}", operation_id, attempt),
351 operation_type.clone(),
352 )
353 .await;
354
355 let result = match time::timeout(config.timeout_duration, operation()).await {
356 Ok(result) => result,
357 Err(_) => {
358 self.record_timeout(
359 &handle.operation_id,
360 Some("Operation timed out".to_owned()),
361 )
362 .await;
363 Err(anyhow::anyhow!(
364 "Operation '{}' timed out after {:?}",
365 operation_id,
366 config.timeout_duration
367 ))
368 }
369 };
370
371 handle.end().await;
372
373 match result {
374 Ok(value) => {
375 if attempt > 0 {
376 self.record_successful_retry(&format!("{}_{}", operation_id, attempt))
377 .await;
378 }
379 return Ok(value);
380 }
381 Err(error) => {
382 let should_retry_op = self.should_retry(&operation_type, &error, attempt).await;
383
384 if !should_retry_op {
385 if attempt > 0 {
386 self.record_failed_retry(&format!("{}_{}", operation_id, attempt))
387 .await;
388 }
389 return Err(error);
390 }
391
392 attempt += 1;
393 self.record_failed_retry(&format!("{}_{}", operation_id, attempt))
394 .await;
395
396 let delay = self.calculate_retry_delay(&operation_type, attempt).await;
397 tracing::warn!(
398 operation_id,
399 attempt,
400 max_retries = config.max_retries,
401 delay = ?delay,
402 "Operation failed and will be retried"
403 );
404 time::sleep(delay).await;
405 }
406 }
407 }
408 }
409}
410
411impl Clone for TimeoutDetector {
412 fn clone(&self) -> Self {
413 Self {
414 configs: Arc::clone(&self.configs),
415 active_operations: Arc::clone(&self.active_operations),
416 stats: Arc::clone(&self.stats),
417 }
418 }
419}
420
/// Handle to one active operation, returned by `TimeoutDetector::start_operation`.
///
/// Ending the handle, or dropping it, asks the detector to end the operation.
pub struct TimeoutHandle {
    /// Identifier the operation was registered under.
    operation_id: String,
    /// Detector that tracks the operation.
    detector: Arc<TimeoutDetector>,
}
426
427impl TimeoutHandle {
428 pub async fn end(self) {
430 self.detector.end_operation(&self.operation_id).await;
431 }
432
433 pub fn operation_id(&self) -> &str {
435 &self.operation_id
436 }
437}
438
439impl Drop for TimeoutHandle {
440 fn drop(&mut self) {
441 let operation_id = self.operation_id.clone();
443 let detector = Arc::clone(&self.detector);
444
445 tokio::spawn(async move {
446 detector.end_operation(&operation_id).await;
447 });
448 }
449}
450
451use once_cell::sync::Lazy;
/// Process-wide [`TimeoutDetector`], built with `TimeoutDetector::new` on first access.
pub static TIMEOUT_DETECTOR: Lazy<TimeoutDetector> = Lazy::new(TimeoutDetector::new);
454
455#[cfg(test)]
456mod tests {
457 use super::*;
458 use std::sync::atomic::{AtomicUsize, Ordering};
459 use std::time::Duration;
460 use tokio::time::sleep;
461
462 #[tokio::test]
463 async fn test_timeout_detection() {
464 let detector = TimeoutDetector::new();
465
466 let config = TimeoutConfig {
468 timeout_duration: Duration::from_millis(10),
469 max_retries: 0,
470 ..Default::default()
471 };
472
473 detector.set_config(OperationType::ApiCall, config).await;
474
475 let result = detector
476 .execute_with_timeout_retry(
477 "test_operation".to_owned(),
478 OperationType::ApiCall,
479 || async {
480 sleep(Duration::from_millis(20)).await;
481 Ok("success")
482 },
483 )
484 .await;
485
486 assert!(result.is_err());
487 assert!(result.unwrap_err().to_string().contains("timed out"));
488 }
489
490 #[tokio::test]
491 async fn test_successful_retry() {
492 let detector = TimeoutDetector::new();
493
494 let config = TimeoutConfig {
495 timeout_duration: Duration::from_millis(50),
496 max_retries: 2,
497 initial_retry_delay: Duration::from_millis(5),
498 retry_on_timeout: true,
499 ..Default::default()
500 };
501
502 detector.set_config(OperationType::ApiCall, config).await;
503
504 let call_count = Arc::new(AtomicUsize::new(0));
505 let call_count_clone = call_count.clone();
506 let result = detector
507 .execute_with_timeout_retry(
508 "test_retry".to_owned(),
509 OperationType::ApiCall,
510 move || {
511 let call_count = call_count_clone.clone();
512 async move {
513 let count = call_count.fetch_add(1, Ordering::SeqCst) + 1;
514 if count == 1 {
515 sleep(Duration::from_millis(60)).await;
517 Ok("should not reach here")
518 } else {
519 sleep(Duration::from_millis(10)).await;
521 Ok("success")
522 }
523 }
524 },
525 )
526 .await;
527
528 assert!(result.is_ok());
529 assert_eq!(result.unwrap(), "success");
530 assert_eq!(call_count.load(Ordering::SeqCst), 2);
531
532 let stats = detector.get_stats().await;
533 assert_eq!(stats.successful_retries, 1);
534 assert_eq!(stats.total_retry_attempts, 2);
535 }
536
537 #[tokio::test]
538 async fn test_calculate_retry_delay() {
539 let detector = TimeoutDetector::new();
540
541 let delay = detector
542 .calculate_retry_delay(&OperationType::ApiCall, 0)
543 .await;
544 assert!(delay >= Duration::from_millis(200)); let delay2 = detector
547 .calculate_retry_delay(&OperationType::ApiCall, 1)
548 .await;
549 assert!(delay2 > delay); }
551}