use std::collections::HashMap;
use std::sync::Arc;

use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tokio::time;

/// Categories of operations that can be tracked for timeouts.
#[derive(Debug, Clone, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum OperationType {
    /// Call to an external API.
    ApiCall,
    /// Local file system operation.
    FileOperation,
    /// Code analysis pass.
    CodeAnalysis,
    /// Execution of an external tool.
    ToolExecution,
    /// Generic network request.
    NetworkRequest,
    /// Long-running local processing.
    Processing,
    /// User-defined operation category.
    Custom(String),
}

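/// Timeout and retry behaviour for a single [`OperationType`].
///
/// A minimal sketch of overriding one of the presets below with struct update
/// syntax (the values shown are illustrative, not recommendations):
///
/// ```ignore
/// let config = TimeoutConfig {
///     timeout_duration: Duration::from_secs(5),
///     max_retries: 1,
///     ..TimeoutConfig::api_call()
/// };
/// ```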
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TimeoutConfig {
    /// How long an operation may run before it is considered timed out.
    pub timeout_duration: Duration,
    /// Maximum number of retries after the initial attempt.
    pub max_retries: u32,
    /// Delay before the first retry.
    pub initial_retry_delay: Duration,
    /// Upper bound on the delay between retries.
    pub max_retry_delay: Duration,
    /// Factor applied to the delay after each retry (exponential backoff).
    pub backoff_multiplier: f64,
    /// Add random jitter (up to 10% of the delay) to retry delays.
    pub use_jitter: bool,
    /// Whether timeouts themselves are retryable.
    pub retry_on_timeout: bool,
    /// Error substrings (matched against the lowercased error) that mark an error as retryable.
    pub retry_on_errors: Vec<String>,
}

impl Default for TimeoutConfig {
    fn default() -> Self {
        Self {
            timeout_duration: Duration::from_secs(30),
            max_retries: 3,
            initial_retry_delay: Duration::from_millis(100),
            max_retry_delay: Duration::from_secs(30),
            backoff_multiplier: 2.0,
            use_jitter: true,
            retry_on_timeout: true,
            retry_on_errors: vec![
                "timeout".to_string(),
                "connection".to_string(),
                "network".to_string(),
                "server_error".to_string(),
            ],
        }
    }
}

impl TimeoutConfig {
    /// Preset for external API calls: longer timeout, more retries.
    pub fn api_call() -> Self {
        Self {
            timeout_duration: Duration::from_secs(60),
            max_retries: 5,
            initial_retry_delay: Duration::from_millis(200),
            max_retry_delay: Duration::from_secs(10),
            backoff_multiplier: 1.5,
            ..Default::default()
        }
    }

    /// Preset for local file operations: short timeout, no retry on timeout.
    pub fn file_operation() -> Self {
        Self {
            timeout_duration: Duration::from_secs(10),
            max_retries: 2,
            initial_retry_delay: Duration::from_millis(50),
            max_retry_delay: Duration::from_secs(2),
            backoff_multiplier: 2.0,
            retry_on_timeout: false,
            ..Default::default()
        }
    }

    /// Preset for analysis and other long-running work: generous timeout, one retry.
    pub fn analysis() -> Self {
        Self {
            timeout_duration: Duration::from_secs(120),
            max_retries: 1,
            initial_retry_delay: Duration::from_secs(5),
            max_retry_delay: Duration::from_secs(10),
            backoff_multiplier: 2.0,
            ..Default::default()
        }
    }
}

/// A single tracked operation and its timeout bookkeeping.
#[derive(Debug, Clone)]
pub struct TimeoutEvent {
    pub operation_id: String,
    pub operation_type: OperationType,
    pub start_time: Instant,
    pub timeout_duration: Duration,
    pub retry_count: u32,
    pub error_message: Option<String>,
}

/// Aggregate statistics across all tracked operations.
#[derive(Debug, Clone, Default)]
pub struct TimeoutStats {
    pub total_operations: usize,
    pub timed_out_operations: usize,
    pub successful_retries: usize,
    pub failed_retries: usize,
    /// Running average duration of completed operations.
    pub average_timeout_duration: Duration,
    pub total_retry_attempts: usize,
}

/// Tracks in-flight operations, detects timeouts, and drives retries.
pub struct TimeoutDetector {
    configs: Arc<RwLock<HashMap<OperationType, TimeoutConfig>>>,
    active_operations: Arc<RwLock<HashMap<String, TimeoutEvent>>>,
    stats: Arc<RwLock<TimeoutStats>>,
}

impl Default for TimeoutDetector {
    fn default() -> Self {
        Self::new()
    }
}

impl TimeoutDetector {
    /// Creates a detector pre-populated with per-type default configurations.
    pub fn new() -> Self {
        let mut configs = HashMap::new();

        configs.insert(OperationType::ApiCall, TimeoutConfig::api_call());
        configs.insert(
            OperationType::FileOperation,
            TimeoutConfig::file_operation(),
        );
        configs.insert(OperationType::CodeAnalysis, TimeoutConfig::analysis());
        configs.insert(OperationType::ToolExecution, TimeoutConfig::default());
        configs.insert(OperationType::NetworkRequest, TimeoutConfig::api_call());
        configs.insert(OperationType::Processing, TimeoutConfig::analysis());

        Self {
            configs: Arc::new(RwLock::new(configs)),
            active_operations: Arc::new(RwLock::new(HashMap::new())),
            stats: Arc::new(RwLock::new(TimeoutStats::default())),
        }
    }

    /// Overrides the configuration for one operation type.
    pub async fn set_config(&self, operation_type: OperationType, config: TimeoutConfig) {
        let mut configs = self.configs.write().await;
        configs.insert(operation_type, config);
    }

    /// Returns the configuration for an operation type, or the default if none is set.
    pub async fn get_config(&self, operation_type: &OperationType) -> TimeoutConfig {
        let configs = self.configs.read().await;
        configs.get(operation_type).cloned().unwrap_or_default()
    }

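    /// Registers an operation as active and returns a handle that ends it.
    ///
    /// A minimal sketch of manual tracking (the operation id is illustrative;
    /// most callers go through `execute_with_timeout_retry` instead):
    ///
    /// ```ignore
    /// let handle = detector
    ///     .start_operation("fetch_user".to_string(), OperationType::ApiCall)
    ///     .await;
    /// // ... do the work ...
    /// handle.end().await;
    /// ```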
    pub async fn start_operation(
        &self,
        operation_id: String,
        operation_type: OperationType,
    ) -> TimeoutHandle {
        let config = self.get_config(&operation_type).await;

        let event = TimeoutEvent {
            operation_id: operation_id.clone(),
            operation_type,
            start_time: Instant::now(),
            timeout_duration: config.timeout_duration,
            retry_count: 0,
            error_message: None,
        };

        let mut active_ops = self.active_operations.write().await;
        active_ops.insert(operation_id.clone(), event);

        let mut stats = self.stats.write().await;
        stats.total_operations += 1;

        TimeoutHandle {
            operation_id,
            detector: Arc::new(self.clone()),
        }
    }

    /// Returns the event for `operation_id` if it has exceeded its timeout.
    pub async fn check_timeout(&self, operation_id: &str) -> Option<TimeoutEvent> {
        let active_ops = self.active_operations.read().await;
        if let Some(event) = active_ops.get(operation_id) {
            if event.start_time.elapsed() >= event.timeout_duration {
                return Some(event.clone());
            }
        }
        None
    }

    /// Records that an operation timed out, keeping the error message if given.
    pub async fn record_timeout(&self, operation_id: &str, error_message: Option<String>) {
        let mut active_ops = self.active_operations.write().await;
        if let Some(event) = active_ops.get_mut(operation_id) {
            event.error_message = error_message;
        }

        let mut stats = self.stats.write().await;
        stats.timed_out_operations += 1;
    }

    /// Records a retry that eventually succeeded.
    pub async fn record_successful_retry(&self, _operation_id: &str) {
        let mut stats = self.stats.write().await;
        stats.successful_retries += 1;
        stats.total_retry_attempts += 1;
    }

    /// Records a retry that did not succeed.
    pub async fn record_failed_retry(&self, _operation_id: &str) {
        let mut stats = self.stats.write().await;
        stats.failed_retries += 1;
        stats.total_retry_attempts += 1;
    }

    /// Removes an operation from the active set and folds its duration into the running average.
    pub async fn end_operation(&self, operation_id: &str) {
        let mut active_ops = self.active_operations.write().await;
        if let Some(event) = active_ops.remove(operation_id) {
            let duration = event.start_time.elapsed();
            let mut stats = self.stats.write().await;
            if stats.total_operations > 0 {
                let total_duration =
                    stats.average_timeout_duration * (stats.total_operations - 1) as u32;
                stats.average_timeout_duration =
                    (total_duration + duration) / stats.total_operations as u32;
            }
        }
    }

    /// Returns a snapshot of the collected statistics.
    pub async fn get_stats(&self) -> TimeoutStats {
        self.stats.read().await.clone()
    }

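    /// Computes the delay before retry number `attempt` using exponential
    /// backoff: `initial_retry_delay * backoff_multiplier^attempt`, capped at
    /// `max_retry_delay`, plus up to 10% pseudo-random jitter when enabled.
    ///
    /// For example, with the `api_call` preset (200 ms initial delay, 1.5
    /// multiplier), attempt 2 yields roughly 200 ms * 1.5^2 = 450 ms before jitter.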
    pub async fn calculate_retry_delay(
        &self,
        operation_type: &OperationType,
        attempt: u32,
    ) -> Duration {
        let config = self.get_config(operation_type).await;

        let base_delay = config.initial_retry_delay.as_millis() as f64;
        let multiplier = config.backoff_multiplier.powi(attempt as i32);
        let delay_ms = (base_delay * multiplier) as u64;

        let mut delay =
            Duration::from_millis(delay_ms.min(config.max_retry_delay.as_millis() as u64));

        if config.use_jitter {
            use std::time::SystemTime;
            let seed = SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos() as u64;
            let jitter_factor = (seed % 100) as f64 / 100.0;
            let jitter_ms = (delay.as_millis() as f64 * 0.1 * jitter_factor) as u64;
            delay += Duration::from_millis(jitter_ms);
        }

        delay
    }

    /// Decides whether `error` on retry attempt `attempt` should be retried.
    pub async fn should_retry(
        &self,
        operation_type: &OperationType,
        error: &anyhow::Error,
        attempt: u32,
    ) -> bool {
        let config = self.get_config(operation_type).await;

        if attempt >= config.max_retries {
            return false;
        }

        let error_str = error.to_string().to_lowercase();

        // Retry when the error message matches one of the configured substrings.
        for retry_error in &config.retry_on_errors {
            if error_str.contains(retry_error) {
                return true;
            }
        }

        // Timeout errors are reported as "timed out"; honour `retry_on_timeout` for both spellings.
        if config.retry_on_timeout
            && (error_str.contains("timeout") || error_str.contains("timed out"))
        {
            return true;
        }

        false
    }

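    /// Runs `operation` under the configured timeout, retrying retryable
    /// failures with exponential backoff until `max_retries` is exhausted.
    ///
    /// A minimal usage sketch (the operation id and closure body are illustrative):
    ///
    /// ```ignore
    /// let detector = TimeoutDetector::new();
    /// let result = detector
    ///     .execute_with_timeout_retry(
    ///         "fetch_profile".to_string(),
    ///         OperationType::ApiCall,
    ///         || async { Ok::<_, anyhow::Error>("payload") },
    ///     )
    ///     .await;
    /// ```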
    pub async fn execute_with_timeout_retry<F, Fut, T>(
        &self,
        operation_id: String,
        operation_type: OperationType,
        mut operation: F,
    ) -> Result<T, anyhow::Error>
    where
        F: FnMut() -> Fut,
        Fut: std::future::Future<Output = Result<T, anyhow::Error>>,
    {
        let config = self.get_config(&operation_type).await;
        let mut attempt = 0;

        loop {
            let handle = self
                .start_operation(
                    format!("{}_{}", operation_id, attempt),
                    operation_type.clone(),
                )
                .await;

            let result = match time::timeout(config.timeout_duration, operation()).await {
                Ok(result) => result,
                Err(_) => {
                    self.record_timeout(
                        &handle.operation_id,
                        Some("Operation timed out".to_string()),
                    )
                    .await;
                    Err(anyhow::anyhow!(
                        "Operation '{}' timed out after {:?}",
                        operation_id,
                        config.timeout_duration
                    ))
                }
            };

            handle.end().await;

            match result {
                Ok(value) => {
                    if attempt > 0 {
                        self.record_successful_retry(&format!("{}_{}", operation_id, attempt))
                            .await;
                    }
                    return Ok(value);
                }
                Err(error) => {
                    let should_retry_op = self.should_retry(&operation_type, &error, attempt).await;

                    if !should_retry_op {
                        if attempt > 0 {
                            self.record_failed_retry(&format!("{}_{}", operation_id, attempt))
                                .await;
                        }
                        return Err(error);
                    }

                    // Only count attempts that were themselves retries as failed retries;
                    // the initial attempt (attempt == 0) is not a retry.
                    if attempt > 0 {
                        self.record_failed_retry(&format!("{}_{}", operation_id, attempt))
                            .await;
                    }
                    attempt += 1;

                    let delay = self.calculate_retry_delay(&operation_type, attempt).await;
                    eprintln!(
                        "Operation '{}' failed (attempt {}/{}), retrying in {:?}",
                        operation_id, attempt, config.max_retries, delay
                    );
                    time::sleep(delay).await;
                }
            }
        }
    }
}

// Clones share the same underlying state (configs, active operations, stats) via `Arc`.
impl Clone for TimeoutDetector {
    fn clone(&self) -> Self {
        Self {
            configs: Arc::clone(&self.configs),
            active_operations: Arc::clone(&self.active_operations),
            stats: Arc::clone(&self.stats),
        }
    }
}

/// Guard for an active operation; ends the operation when consumed or dropped.
pub struct TimeoutHandle {
    operation_id: String,
    detector: Arc<TimeoutDetector>,
}

impl TimeoutHandle {
    /// Explicitly ends the tracked operation.
    pub async fn end(self) {
        self.detector.end_operation(&self.operation_id).await;
    }

    /// Returns the id of the tracked operation.
    pub fn operation_id(&self) -> &str {
        &self.operation_id
    }
}

impl Drop for TimeoutHandle {
    fn drop(&mut self) {
        // Best-effort cleanup if the handle is dropped without calling `end()`.
        // `end_operation` is a no-op for ids that are no longer active, so running
        // after an explicit `end()` is harmless. Requires an active Tokio runtime.
        let operation_id = self.operation_id.clone();
        let detector = Arc::clone(&self.detector);

        tokio::spawn(async move {
            detector.end_operation(&operation_id).await;
        });
    }
}

use once_cell::sync::Lazy;
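/// Global, lazily-initialised detector shared across the application.
///
/// A minimal sketch of driving an operation through the global instance
/// (the operation id and closure are illustrative):
///
/// ```ignore
/// let result = TIMEOUT_DETECTOR
///     .execute_with_timeout_retry(
///         "startup_check".to_string(),
///         OperationType::Processing,
///         || async { Ok::<_, anyhow::Error>(42) },
///     )
///     .await;
/// ```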
pub static TIMEOUT_DETECTOR: Lazy<TimeoutDetector> = Lazy::new(TimeoutDetector::new);

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;
    use tokio::time::sleep;

    #[tokio::test]
    async fn test_timeout_detection() {
        let detector = TimeoutDetector::new();

        let config = TimeoutConfig {
            timeout_duration: Duration::from_millis(10),
            max_retries: 0,
            ..Default::default()
        };

        detector.set_config(OperationType::ApiCall, config).await;

        let result = detector
            .execute_with_timeout_retry(
                "test_operation".to_string(),
                OperationType::ApiCall,
                || async {
                    sleep(Duration::from_millis(20)).await;
                    Ok("success")
                },
            )
            .await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("timed out"));
    }

    #[tokio::test]
    async fn test_successful_retry() {
        let detector = TimeoutDetector::new();

        let config = TimeoutConfig {
            timeout_duration: Duration::from_millis(50),
            max_retries: 2,
            initial_retry_delay: Duration::from_millis(5),
            retry_on_timeout: true,
            ..Default::default()
        };

        detector.set_config(OperationType::ApiCall, config).await;

        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = call_count.clone();
        let result = detector
            .execute_with_timeout_retry(
                "test_retry".to_string(),
                OperationType::ApiCall,
                move || {
                    let call_count = call_count_clone.clone();
                    async move {
                        let count = call_count.fetch_add(1, Ordering::SeqCst) + 1;
                        if count == 1 {
                            // First call exceeds the 50 ms timeout.
                            sleep(Duration::from_millis(60)).await;
                            Ok("should not reach here")
                        } else {
                            // Subsequent calls finish well within the timeout.
                            sleep(Duration::from_millis(10)).await;
                            Ok("success")
                        }
                    }
                },
            )
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(call_count.load(Ordering::SeqCst), 2);

        let stats = detector.get_stats().await;
        assert_eq!(stats.successful_retries, 1);
        assert_eq!(stats.total_retry_attempts, 1);
    }

    #[tokio::test]
    async fn test_calculate_retry_delay() {
        let detector = TimeoutDetector::new();

        // Attempt 0 uses the ApiCall base delay of 200 ms (plus up to 10% jitter).
        let delay = detector
            .calculate_retry_delay(&OperationType::ApiCall, 0)
            .await;
        assert!(delay >= Duration::from_millis(200));

        // Later attempts back off exponentially, so the delay grows.
        let delay2 = detector
            .calculate_retry_delay(&OperationType::ApiCall, 1)
            .await;
        assert!(delay2 > delay);
    }
}