1use std::sync::Arc;
2use std::time::Duration;
3
4use async_trait::async_trait;
5use synaptic_core::{ChatModel, ChatRequest, ChatResponse, ChatStream, SynapticError};
6
/// Settings that control how [`RetryChatModel`] retries a failed call.
#[derive(Clone, Debug)]
pub struct RetryPolicy {
    /// Upper bound on the number of calls made to the wrapped model for one
    /// request (the first call counts as an attempt).
    pub max_attempts: usize,
    /// Wait before the first retry; the wait is multiplied by two for each
    /// further retry.
    pub base_delay: Duration,
}
12
13impl Default for RetryPolicy {
14 fn default() -> Self {
15 Self {
16 max_attempts: 3,
17 base_delay: Duration::from_millis(500),
18 }
19 }
20}
21
22pub struct RetryChatModel {
23 inner: Arc<dyn ChatModel>,
24 policy: RetryPolicy,
25}
26
27impl RetryChatModel {
28 pub fn new(inner: Arc<dyn ChatModel>, policy: RetryPolicy) -> Self {
29 Self { inner, policy }
30 }
31}
32
33fn is_retryable(err: &SynapticError) -> bool {
34 matches!(err, SynapticError::RateLimit(_) | SynapticError::Timeout(_))
35}
36
37#[async_trait]
38impl ChatModel for RetryChatModel {
39 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapticError> {
40 let mut last_error = None;
41 for attempt in 0..self.policy.max_attempts {
42 match self.inner.chat(request.clone()).await {
43 Ok(resp) => return Ok(resp),
44 Err(e) if is_retryable(&e) && attempt + 1 < self.policy.max_attempts => {
45 let delay = self.policy.base_delay * 2u32.saturating_pow(attempt as u32);
46 tokio::time::sleep(delay).await;
47 last_error = Some(e);
48 }
49 Err(e) => return Err(e),
50 }
51 }
52 Err(last_error.unwrap_or_else(|| SynapticError::Model("retry exhausted".to_string())))
53 }
54
55 fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
56 let inner = self.inner.clone();
57 let policy = self.policy.clone();
58
59 Box::pin(async_stream::stream! {
60 let mut last_error = None;
61 for attempt in 0..policy.max_attempts {
62 let mut stream = inner.stream_chat(request.clone());
63
64 use futures::StreamExt;
65 let mut chunks = Vec::new();
66 let mut had_error = false;
67
68 while let Some(result) = stream.next().await {
69 match result {
70 Ok(chunk) => chunks.push(chunk),
71 Err(e) if is_retryable(&e) && attempt + 1 < policy.max_attempts => {
72 let delay = policy.base_delay * 2u32.saturating_pow(attempt as u32);
73 tokio::time::sleep(delay).await;
74 last_error = Some(e);
75 had_error = true;
76 break;
77 }
78 Err(e) => {
79 yield Err(e);
80 return;
81 }
82 }
83 }
84
85 if !had_error {
86 for chunk in chunks {
87 yield Ok(chunk);
88 }
89 return;
90 }
91 }
92 if let Some(e) = last_error {
93 yield Err(e);
94 }
95 })
96 }
97}