Skip to main content

vtcode_core/
retry.rs

//! Shared retry policy and classification helpers for VT Code.
2
3use std::future::Future;
4use std::result::Result as StdResult;
5use std::time::Duration;
6
7use crate::config::constants::tools;
8use crate::error::{ErrorCategory, VtCodeError};
9use crate::retry_after::retry_after_from_llm_metadata;
10use crate::tools::registry::ToolExecutionError;
11use crate::tools::tool_intent;
12use crate::tools::unified_error::UnifiedToolError;
13use vtcode_commons::llm::{LLMError, LLMErrorMetadata};
14
15/// Typed retry policy shared across runtime layers.
16#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
17pub struct RetryPolicy {
18    /// Maximum number of total attempts, including the initial call.
19    pub max_attempts: u32,
20    pub initial_delay: Duration,
21    pub max_delay: Duration,
22    pub multiplier: f64,
23    pub jitter: f64,
24}
25
26impl RetryPolicy {
27    pub fn new(
28        max_attempts: u32,
29        initial_delay: Duration,
30        max_delay: Duration,
31        multiplier: f64,
32    ) -> Self {
33        Self {
34            max_attempts: max_attempts.max(1),
35            initial_delay,
36            max_delay,
37            multiplier: multiplier.max(1.0),
38            jitter: 0.0,
39        }
40    }
41
42    pub fn from_retries(
43        max_retries: u32,
44        initial_delay: Duration,
45        max_delay: Duration,
46        multiplier: f64,
47    ) -> Self {
48        Self::new(
49            max_retries.saturating_add(1),
50            initial_delay,
51            max_delay,
52            multiplier,
53        )
54    }
55
56    pub fn delay_for_attempt(&self, attempt_index: u32) -> Duration {
57        let multiplier = self.multiplier.powi(attempt_index as i32);
58        let base_delay = Duration::try_from_secs_f64(self.initial_delay.as_secs_f64() * multiplier)
59            .unwrap_or(self.max_delay)
60            .min(self.max_delay);
61
62        if !self.jitter.is_finite() || self.jitter <= 0.0 {
63            return base_delay;
64        }
65
66        let max_jitter_ms = (base_delay.as_millis() as f64 * self.jitter)
67            .round()
68            .clamp(0.0, u64::MAX as f64) as u64;
69        if max_jitter_ms == 0 {
70            return base_delay;
71        }
72
73        let offset = (u64::from(attempt_index) * 31) % max_jitter_ms.saturating_add(1);
74        base_delay.saturating_add(Duration::from_millis(offset))
75    }
76
77    pub fn decision_for_category(
78        &self,
79        category: ErrorCategory,
80        attempt_index: u32,
81        retry_after: Option<Duration>,
82    ) -> RetryDecision {
83        let has_remaining_attempts = attempt_index.saturating_add(1) < self.max_attempts;
84        if !category.is_retryable() || !has_remaining_attempts {
85            return RetryDecision {
86                category,
87                retryable: false,
88                delay: None,
89                retry_after,
90            };
91        }
92
93        let delay = retry_after.unwrap_or_else(|| self.delay_for_attempt(attempt_index));
94        RetryDecision {
95            category,
96            retryable: true,
97            delay: Some(delay),
98            retry_after,
99        }
100    }
101
102    pub fn decision_for_vtcode_error(
103        &self,
104        error: &VtCodeError,
105        attempt_index: u32,
106        tool_name: Option<&str>,
107    ) -> RetryDecision {
108        self.decision_for_category_with_tool(
109            error.category,
110            attempt_index,
111            error.retry_after(),
112            tool_name,
113        )
114    }
115
116    pub fn decision_for_anyhow(
117        &self,
118        error: &anyhow::Error,
119        attempt_index: u32,
120        tool_name: Option<&str>,
121    ) -> RetryDecision {
122        if let Some(vtcode_error) = error.downcast_ref::<VtCodeError>() {
123            return self.decision_for_vtcode_error(vtcode_error, attempt_index, tool_name);
124        }
125        if let Some(llm_error) = error.downcast_ref::<LLMError>() {
126            return self.decision_for_llm_error(llm_error, attempt_index);
127        }
128        if let Some(tool_error) = error.downcast_ref::<UnifiedToolError>() {
129            let tool_name = tool_name.or_else(|| {
130                tool_error
131                    .debug_context
132                    .as_ref()
133                    .map(|ctx| ctx.tool_name.as_str())
134                    .filter(|tool_name| !tool_name.is_empty())
135            });
136            return self.decision_for_category_with_tool(
137                tool_error.category(),
138                attempt_index,
139                None,
140                tool_name,
141            );
142        }
143
144        let category = vtcode_commons::classify_anyhow_error(error);
145        self.decision_for_category_with_tool(category, attempt_index, None, tool_name)
146    }
147
148    pub fn decision_for_llm_error(&self, error: &LLMError, attempt_index: u32) -> RetryDecision {
149        let retry_after = llm_metadata(error).and_then(retry_after_from_llm_metadata);
150        self.decision_for_category_with_tool(
151            ErrorCategory::from(error),
152            attempt_index,
153            retry_after,
154            None,
155        )
156    }
157
158    pub fn decision_for_tool_error(
159        &self,
160        error: &UnifiedToolError,
161        attempt_index: u32,
162    ) -> RetryDecision {
163        let tool_name = error
164            .debug_context
165            .as_ref()
166            .map(|ctx| ctx.tool_name.as_str())
167            .filter(|tool_name| !tool_name.is_empty());
168        self.decision_for_category_with_tool(error.category(), attempt_index, None, tool_name)
169    }
170
171    pub fn decision_for_tool_execution_error(
172        &self,
173        error: &ToolExecutionError,
174        attempt_index: u32,
175    ) -> RetryDecision {
176        self.decision_for_category_with_tool(
177            error.category,
178            attempt_index,
179            error.retry_after(),
180            Some(error.tool_name.as_str()),
181        )
182    }
183
184    /// Classify a `VtCodeError` failure into a typed [`RetryStep`].
185    ///
186    /// This consolidates the "decision -> sleep or give-up" branching used by
187    /// agent-level retry loops, removing the brittle
188    /// `decision.delay.expect("retryable decisions need delay")` calls and
189    /// guaranteeing a delay is always available for retryable steps.
190    pub fn step_for_vtcode_error(
191        &self,
192        error: VtCodeError,
193        attempt_index: u32,
194        tool_name: Option<&str>,
195    ) -> RetryStep {
196        let decision = self.decision_for_vtcode_error(&error, attempt_index, tool_name);
197        if decision.retryable {
198            let delay = decision
199                .delay
200                .unwrap_or_else(|| self.delay_for_attempt(attempt_index));
201            RetryStep::Backoff {
202                delay,
203                decision,
204                error,
205            }
206        } else {
207            RetryStep::GiveUp { decision, error }
208        }
209    }
210
211    pub fn apply_to_tool_execution_error(
212        &self,
213        error: ToolExecutionError,
214        attempt_index: u32,
215        tool_name: Option<&str>,
216    ) -> ToolExecutionError {
217        let decision = self.decision_for_category_with_tool(
218            error.category,
219            attempt_index,
220            error.retry_after(),
221            tool_name.or(Some(error.tool_name.as_str())),
222        );
223        error.with_retry_decision(decision)
224    }
225
226    fn decision_for_category_with_tool(
227        &self,
228        category: ErrorCategory,
229        attempt_index: u32,
230        retry_after: Option<Duration>,
231        tool_name: Option<&str>,
232    ) -> RetryDecision {
233        if matches!(category, ErrorCategory::Timeout) && tool_name.is_some_and(is_command_tool) {
234            return RetryDecision {
235                category,
236                retryable: false,
237                delay: None,
238                retry_after,
239            };
240        }
241
242        self.decision_for_category(category, attempt_index, retry_after)
243    }
244}
245
246impl Default for RetryPolicy {
247    fn default() -> Self {
248        Self::from_retries(2, Duration::from_secs(1), Duration::from_secs(60), 2.0)
249    }
250}
251
252/// Result of classifying a failure for retry handling.
253#[derive(Debug, Clone, PartialEq, Eq)]
254pub struct RetryDecision {
255    pub category: ErrorCategory,
256    pub retryable: bool,
257    pub delay: Option<Duration>,
258    pub retry_after: Option<Duration>,
259}
260
261/// Typed step produced by [`RetryPolicy::step_for_vtcode_error`].
262///
263/// Callers match on this instead of re-deriving the
264/// `if decision.retryable { sleep(decision.delay.expect(...)) }` pattern.
265#[derive(Debug)]
266pub enum RetryStep {
267    /// Wait `delay` then retry; `error` is the failure being retried.
268    Backoff {
269        delay: Duration,
270        decision: RetryDecision,
271        error: VtCodeError,
272    },
273    /// Give up immediately and surface `error`.
274    GiveUp {
275        decision: RetryDecision,
276        error: VtCodeError,
277    },
278}
279
280/// Lifecycle event emitted by [`run_with_retry`] so callers can attach
281/// logging, metrics, or stats updates without re-implementing the loop.
282///
283/// `category_was_retryable` is provided alongside `GiveUp` and `Backoff`
284/// because call sites frequently need to distinguish "non-retryable
285/// category" from "retryable category, budget exhausted" — the distinction
286/// is otherwise invisible once `step_for_vtcode_error` collapses both
287/// into `GiveUp`.
288#[derive(Debug)]
289pub enum RetryEvent<'a> {
290    /// An attempt is about to start. `attempt` is 0-indexed.
291    AttemptStart { attempt: u32, max_attempts: u32 },
292    /// The operation succeeded on the given attempt.
293    Success { attempt: u32 },
294    /// The policy decided to give up and surface `error` immediately.
295    GiveUp {
296        attempt: u32,
297        error: &'a VtCodeError,
298        decision: &'a RetryDecision,
299        category_was_retryable: bool,
300    },
301    /// The policy decided to back off and retry. The driver will sleep
302    /// for `delay` before invoking the operation again.
303    Backoff {
304        attempt: u32,
305        error: &'a VtCodeError,
306        decision: &'a RetryDecision,
307        delay: Duration,
308        category_was_retryable: bool,
309    },
310    /// All attempts were exhausted. `last_error` is the most recent
311    /// error captured from a `Backoff` step.
312    Exhausted { last_error: Option<&'a VtCodeError> },
313}
314
315/// Drive a retry loop per `policy`, invoking `on_event` for each
316/// lifecycle event. Returns the first successful result, or the final
317/// error if the policy gives up or all attempts are exhausted.
318///
319/// `state` is reborrowed mutably and passed to each callback on every
320/// invocation, so callers can use it to thread a `&mut self` (or any
321/// other mutable context) through the loop without resorting to
322/// `RefCell` or split borrows. The `operation` callback must return a
323/// boxed future so the helper can `await` it without tying its lifetime
324/// to the closure's own borrow of `state`.
325///
326/// The returned future is `Send` so it can be passed to `tokio::spawn`.
327///
328/// `synthesize_exhausted_error` is invoked only in the (degenerate)
329/// case where the loop completes without ever recording a `Backoff`
330/// error. The closure receives the `&RetryPolicy` so it can attach
331/// site-specific context (e.g. `policy.max_attempts`) without having
332/// to capture the policy in its environment — which is exactly the
333/// pattern that conflicts with `&mut state`.
334#[allow(clippy::too_many_arguments)]
335pub async fn run_with_retry<T, E, S, F, OnEvent, Synthesize>(
336    policy: &RetryPolicy,
337    state: &mut S,
338    mut on_event: OnEvent,
339    mut operation: F,
340    synthesize_exhausted_error: Synthesize,
341) -> crate::error::Result<T>
342where
343    F: for<'a> FnMut(
344        &'a mut S,
345    ) -> std::pin::Pin<Box<dyn Future<Output = StdResult<T, E>> + Send + 'a>>,
346    E: Into<VtCodeError>,
347    OnEvent: FnMut(&mut S, RetryEvent<'_>),
348    Synthesize: FnOnce(&RetryPolicy) -> VtCodeError,
349{
350    use tokio::time::sleep;
351
352    let mut last_error: Option<VtCodeError> = None;
353    for attempt in 0..policy.max_attempts {
354        on_event(
355            state,
356            RetryEvent::AttemptStart {
357                attempt,
358                max_attempts: policy.max_attempts,
359            },
360        );
361        match operation(state).await {
362            Ok(value) => {
363                on_event(state, RetryEvent::Success { attempt });
364                return Ok(value);
365            }
366            Err(err) => {
367                let err: VtCodeError = err.into();
368                let category_was_retryable = err.category.is_retryable();
369                let step = policy.step_for_vtcode_error(err, attempt, None);
370                match step {
371                    RetryStep::GiveUp { decision, error } => {
372                        on_event(
373                            state,
374                            RetryEvent::GiveUp {
375                                attempt,
376                                error: &error,
377                                decision: &decision,
378                                category_was_retryable,
379                            },
380                        );
381                        return Err(error);
382                    }
383                    RetryStep::Backoff {
384                        delay,
385                        decision,
386                        error,
387                    } => {
388                        on_event(
389                            state,
390                            RetryEvent::Backoff {
391                                attempt,
392                                error: &error,
393                                decision: &decision,
394                                delay,
395                                category_was_retryable,
396                            },
397                        );
398                        last_error = Some(error);
399                        sleep(delay).await;
400                    }
401                }
402            }
403        }
404    }
405    let final_error = last_error.unwrap_or_else(|| synthesize_exhausted_error(policy));
406    on_event(
407        state,
408        RetryEvent::Exhausted {
409            last_error: Some(&final_error),
410        },
411    );
412    Err(final_error)
413}
414
415fn llm_metadata(error: &LLMError) -> Option<&LLMErrorMetadata> {
416    match error {
417        LLMError::Authentication { metadata, .. }
418        | LLMError::RateLimit { metadata }
419        | LLMError::InvalidRequest { metadata, .. }
420        | LLMError::Network { metadata, .. }
421        | LLMError::Provider { metadata, .. } => metadata.as_deref(),
422    }
423}
424
425pub fn decision_for_vtcode_error(
426    error: &VtCodeError,
427    attempt_index: u32,
428    tool_name: Option<&str>,
429    policy_override: Option<&RetryPolicy>,
430) -> RetryDecision {
431    let owned_policy;
432    let policy = if let Some(policy) = policy_override {
433        policy
434    } else {
435        owned_policy = RetryPolicy::default();
436        &owned_policy
437    };
438    policy.decision_for_vtcode_error(error, attempt_index, tool_name)
439}
440
441pub fn decision_for_anyhow_error(
442    error: &anyhow::Error,
443    attempt_index: u32,
444    tool_name: Option<&str>,
445    policy_override: Option<&RetryPolicy>,
446) -> RetryDecision {
447    let owned_policy;
448    let policy = if let Some(policy) = policy_override {
449        policy
450    } else {
451        owned_policy = RetryPolicy::default();
452        &owned_policy
453    };
454    policy.decision_for_anyhow(error, attempt_index, tool_name)
455}
456
457#[must_use]
458pub fn is_command_tool(tool_name: &str) -> bool {
459    tool_name == tools::CREATE_PTY_SESSION
460        || tool_name == tools::SEND_PTY_INPUT
461        || tool_intent::canonical_unified_exec_tool_name(tool_name).is_some()
462}
463
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::{ErrorCode, VtCodeError};

    /// Policy used by most tests: 1s initial delay, 8s ceiling, doubling.
    fn policy_with_retries(max_retries: u32) -> RetryPolicy {
        RetryPolicy::from_retries(
            max_retries,
            Duration::from_secs(1),
            Duration::from_secs(8),
            2.0,
        )
    }

    /// Policy for the async driver tests: no initial delay, 1ms ceiling.
    fn fast_policy(max_retries: u32) -> RetryPolicy {
        RetryPolicy::from_retries(
            max_retries,
            Duration::from_millis(0),
            Duration::from_millis(1),
            2.0,
        )
    }

    /// HTTP 429 rate-limit error with the given provider metadata.
    fn rate_limit_error(
        provider: &'static str,
        code: &str,
        retry_after: Option<&str>,
        message: &str,
    ) -> LLMError {
        LLMError::RateLimit {
            metadata: Some(LLMErrorMetadata::new(
                provider,
                Some(429),
                Some(code.to_string()),
                None,
                None,
                retry_after.map(str::to_string),
                Some(message.to_string()),
            )),
        }
    }

    /// Error handed to `run_with_retry` for the exhausted-without-error case.
    fn exhausted_error(_: &RetryPolicy) -> VtCodeError {
        VtCodeError::execution(ErrorCode::ToolExecutionFailed, "exhausted")
    }

    #[test]
    fn non_retryable_categories_stop_immediately() {
        let err = VtCodeError::security(ErrorCode::PermissionDenied, "blocked by policy");

        let decision = policy_with_retries(2).decision_for_vtcode_error(&err, 0, None);

        assert_eq!(decision.category, ErrorCategory::PolicyViolation);
        assert!(!decision.retryable);
        assert!(decision.delay.is_none());
    }

    #[test]
    fn retry_after_header_overrides_backoff_delay() {
        let err = rate_limit_error(
            "Anthropic",
            "rate_limit_error",
            Some("7"),
            "too many requests",
        );

        let decision = policy_with_retries(3).decision_for_llm_error(&err, 0);

        assert!(decision.retryable);
        assert_eq!(decision.retry_after, Some(Duration::from_secs(7)));
        assert_eq!(decision.delay, Some(Duration::from_secs(7)));
    }

    #[test]
    fn delay_for_attempt_clamps_overflowing_backoff_to_max_delay() {
        let ceiling = Duration::from_secs(8);
        let policy = RetryPolicy::from_retries(3, Duration::from_secs(1), ceiling, f64::MAX);

        assert_eq!(policy.delay_for_attempt(2), ceiling);
    }

    #[test]
    fn delay_for_attempt_ignores_non_finite_jitter() {
        let policy = RetryPolicy {
            jitter: f64::INFINITY,
            ..policy_with_retries(3)
        };

        assert_eq!(policy.delay_for_attempt(1), Duration::from_secs(2));
    }

    #[test]
    fn delay_for_attempt_handles_huge_finite_jitter() {
        let policy = RetryPolicy {
            jitter: f64::MAX,
            ..policy_with_retries(3)
        };

        assert!(policy.delay_for_attempt(1) >= Duration::from_secs(2));
    }

    #[test]
    fn quota_exhaustion_is_not_retryable() {
        let err = rate_limit_error("OpenAI", "insufficient_quota", None, "quota exceeded");

        let decision = policy_with_retries(3).decision_for_llm_error(&err, 0);

        assert_eq!(decision.category, ErrorCategory::ResourceExhausted);
        assert!(!decision.retryable);
    }

    #[test]
    fn anyhow_fallback_uses_shared_classifier() {
        let err = anyhow::anyhow!("HTTP 503 Service Unavailable");

        let decision = policy_with_retries(1).decision_for_anyhow(&err, 0, None);

        assert_eq!(decision.category, ErrorCategory::ServiceUnavailable);
        assert!(decision.retryable);
        assert_eq!(decision.delay, Some(Duration::from_secs(1)));
    }

    #[test]
    fn anyhow_prefers_typed_llm_errors() {
        let err = anyhow::Error::new(rate_limit_error(
            "Anthropic",
            "rate_limit_error",
            Some("9"),
            "too many requests",
        ));

        let decision = policy_with_retries(3).decision_for_anyhow(&err, 0, None);

        assert!(decision.retryable);
        assert_eq!(decision.retry_after, Some(Duration::from_secs(9)));
        assert_eq!(decision.delay, Some(Duration::from_secs(9)));
    }

    #[test]
    fn canonical_exec_aliases_are_command_tools() {
        let aliases = [
            tools::RUN_PTY_CMD,
            tools::EXEC_COMMAND,
            tools::WRITE_STDIN,
            tools::UNIFIED_EXEC,
            "shell",
            "bash",
            "container.exec",
        ];

        for alias in aliases {
            assert!(
                is_command_tool(alias),
                "expected {alias} to be a command tool"
            );
        }
    }

    #[test]
    fn typed_tool_timeout_for_command_tools_is_not_retryable() {
        let err = UnifiedToolError::new(
            crate::tools::unified_error::UnifiedErrorKind::Timeout,
            "timed out",
        )
        .with_tool_name(tools::RUN_PTY_CMD);

        let decision = policy_with_retries(2).decision_for_tool_error(&err, 0);

        assert_eq!(decision.category, ErrorCategory::Timeout);
        assert!(!decision.retryable);
    }

    #[test]
    fn anyhow_typed_tool_timeout_uses_fallback_tool_name() {
        let err = anyhow::Error::new(UnifiedToolError::new(
            crate::tools::unified_error::UnifiedErrorKind::Timeout,
            "timed out",
        ));

        let decision =
            policy_with_retries(2).decision_for_anyhow(&err, 0, Some(tools::RUN_PTY_CMD));

        assert_eq!(decision.category, ErrorCategory::Timeout);
        assert!(!decision.retryable);
    }

    #[test]
    fn command_timeouts_do_not_retry() {
        let err = VtCodeError::new(ErrorCategory::Timeout, ErrorCode::Timeout, "timed out");

        let decision =
            policy_with_retries(2).decision_for_vtcode_error(&err, 0, Some(tools::RUN_PTY_CMD));

        assert_eq!(decision.category, ErrorCategory::Timeout);
        assert!(!decision.retryable);
    }

    #[tokio::test]
    async fn run_with_retry_returns_first_success() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicU32, Ordering};

        let policy = fast_policy(3);
        let calls = Arc::new(AtomicU32::new(0));
        let calls_for_op = Arc::clone(&calls);

        // Fails once with a transient network error, then succeeds.
        let result: crate::error::Result<String> = run_with_retry(
            &policy,
            &mut (),
            |_: &mut (), _| {},
            |_| {
                let calls = calls_for_op.clone();
                Box::pin(async move {
                    let call_number = calls.fetch_add(1, Ordering::SeqCst) + 1;
                    if call_number < 2 {
                        Err(VtCodeError::network(
                            ErrorCode::ConnectionFailed,
                            "transient",
                        ))
                    } else {
                        Ok("ok".to_string())
                    }
                })
            },
            exhausted_error,
        )
        .await;

        assert_eq!(result.unwrap(), "ok");
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn run_with_retry_surfaces_give_up_immediately() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicU32, Ordering};

        let policy = fast_policy(5);
        let calls = Arc::new(AtomicU32::new(0));
        let calls_for_op = Arc::clone(&calls);

        // Always fails with a non-retryable input error.
        let result: crate::error::Result<String> = run_with_retry(
            &policy,
            &mut (),
            |_: &mut (), _| {},
            |_| {
                let calls = calls_for_op.clone();
                Box::pin(async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Err::<String, _>(VtCodeError::input(
                        ErrorCode::InvalidArgument,
                        "bad input",
                    ))
                })
            },
            exhausted_error,
        )
        .await;

        assert!(result.is_err());
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "GiveUp should short-circuit retries"
        );
    }
}