1use std::time::Duration;
2
3use reqwest::{Client, StatusCode};
4use serde::de::DeserializeOwned;
5use serde::{Deserialize, Serialize};
6use tokio::time::sleep;
7
8use crate::error::{SyncError, SyncResult};
9
/// Exponential-backoff settings applied to retried file transfers.
#[derive(Clone, Copy, Debug)]
pub struct RetryConfig {
    /// Upper bound on the number of attempts, counting the first try.
    pub max_attempts: u32,
    /// Wait inserted before the second attempt.
    pub initial_delay: Duration,
    /// Ceiling for any single wait between attempts.
    pub max_delay: Duration,
    /// Factor the wait is multiplied by after each failed attempt.
    pub exponential_base: u32,
}
17
18impl Default for RetryConfig {
19 fn default() -> Self {
20 Self {
21 max_attempts: 5,
22 initial_delay: Duration::from_secs(2),
23 max_delay: Duration::from_secs(30),
24 exponential_base: 2,
25 }
26 }
27}
28
29#[derive(Clone, Debug)]
30pub struct SyncApiClient {
31 client: Client,
32 api_url: String,
33 token: String,
34 hostname: Option<String>,
35 sync_token: Option<String>,
36 retry_config: RetryConfig,
37}
38
39#[derive(Debug, Deserialize)]
40pub struct RegistryToken {
41 pub registry: String,
42 pub username: String,
43 pub token: String,
44}
45
/// JSON body returned by the file-upload endpoint.
#[derive(Debug, Clone, Copy, Deserialize)]
pub struct UploadResponse {
    /// Number of files the server reports as uploaded.
    pub files_uploaded: usize,
}
50
/// JSON body returned by the tenant deploy endpoint.
#[derive(Debug, Deserialize)]
pub struct DeployResponse {
    /// Status string reported by the server (the set of values is not defined here).
    pub status: String,
    /// URL of the deployed app, when the server includes one.
    pub app_url: Option<String>,
}
56
57impl SyncApiClient {
58 pub fn new(api_url: &str, token: &str) -> SyncResult<Self> {
59 Ok(Self {
60 client: Client::builder()
61 .connect_timeout(Duration::from_secs(10))
62 .timeout(Duration::from_secs(60))
63 .build()?,
64 api_url: api_url.to_string(),
65 token: token.to_string(),
66 hostname: None,
67 sync_token: None,
68 retry_config: RetryConfig::default(),
69 })
70 }
71
72 pub fn with_direct_sync(
73 mut self,
74 hostname: Option<String>,
75 sync_token: Option<String>,
76 ) -> Self {
77 self.hostname = hostname;
78 self.sync_token = sync_token;
79 self
80 }
81
82 fn direct_sync_credentials(&self) -> Option<(String, String)> {
83 match (&self.hostname, &self.sync_token) {
84 (Some(hostname), Some(token)) => {
85 let url = format!("https://{}/api/v1/sync/files", hostname);
86 Some((url, token.clone()))
87 },
88 _ => None,
89 }
90 }
91
92 fn calculate_next_delay(&self, current: Duration) -> Duration {
93 current
94 .saturating_mul(self.retry_config.exponential_base)
95 .min(self.retry_config.max_delay)
96 }
97
98 pub async fn upload_files(
99 &self,
100 tenant_id: &systemprompt_identifiers::TenantId,
101 data: Vec<u8>,
102 ) -> SyncResult<UploadResponse> {
103 let (url, token) = self.direct_sync_credentials().unwrap_or_else(|| {
104 (
105 format!("{}/api/v1/cloud/tenants/{}/files", self.api_url, tenant_id),
106 self.token.clone(),
107 )
108 });
109
110 let mut current_delay = self.retry_config.initial_delay;
111
112 for attempt in 1..=self.retry_config.max_attempts {
113 let response = self
114 .client
115 .post(&url)
116 .header("Authorization", format!("Bearer {}", token))
117 .header("Content-Type", "application/octet-stream")
118 .body(data.clone())
119 .send()
120 .await?;
121
122 match self.handle_json_response::<UploadResponse>(response).await {
123 Ok(upload) => return Ok(upload),
124 Err(error) if error.is_retryable() && attempt < self.retry_config.max_attempts => {
125 tracing::warn!(
126 attempt = attempt,
127 max_attempts = self.retry_config.max_attempts,
128 delay_ms = current_delay.as_millis() as u64,
129 error = %error,
130 "Retryable sync error, waiting before retry"
131 );
132 sleep(current_delay).await;
133 current_delay = self.calculate_next_delay(current_delay);
134 },
135 Err(error) => return Err(error),
136 }
137 }
138
139 Err(SyncError::ApiError {
140 status: 503,
141 message: "Max retry attempts exceeded".to_string(),
142 })
143 }
144
145 pub async fn download_files(
146 &self,
147 tenant_id: &systemprompt_identifiers::TenantId,
148 ) -> SyncResult<Vec<u8>> {
149 let (url, token) = self.direct_sync_credentials().unwrap_or_else(|| {
150 (
151 format!("{}/api/v1/cloud/tenants/{}/files", self.api_url, tenant_id),
152 self.token.clone(),
153 )
154 });
155
156 let mut current_delay = self.retry_config.initial_delay;
157
158 for attempt in 1..=self.retry_config.max_attempts {
159 let response = self
160 .client
161 .get(&url)
162 .header("Authorization", format!("Bearer {}", token))
163 .send()
164 .await?;
165
166 match self.handle_binary_response(response).await {
167 Ok(data) => return Ok(data),
168 Err(error) if error.is_retryable() && attempt < self.retry_config.max_attempts => {
169 tracing::warn!(
170 attempt = attempt,
171 max_attempts = self.retry_config.max_attempts,
172 delay_ms = current_delay.as_millis() as u64,
173 error = %error,
174 "Retryable sync error, waiting before retry"
175 );
176 sleep(current_delay).await;
177 current_delay = self.calculate_next_delay(current_delay);
178 },
179 Err(error) => return Err(error),
180 }
181 }
182
183 Err(SyncError::ApiError {
184 status: 503,
185 message: "Max retry attempts exceeded".to_string(),
186 })
187 }
188
189 pub async fn get_registry_token(
190 &self,
191 tenant_id: &systemprompt_identifiers::TenantId,
192 ) -> SyncResult<RegistryToken> {
193 let url = format!(
194 "{}/api/v1/cloud/tenants/{}/registry-token",
195 self.api_url, tenant_id
196 );
197 self.get(&url).await
198 }
199
200 pub async fn deploy(
201 &self,
202 tenant_id: &systemprompt_identifiers::TenantId,
203 image: &str,
204 ) -> SyncResult<DeployResponse> {
205 let url = format!("{}/api/v1/cloud/tenants/{}/deploy", self.api_url, tenant_id);
206 self.post(&url, &serde_json::json!({ "image": image }))
207 .await
208 }
209
210 pub async fn get_tenant_app_id(
211 &self,
212 tenant_id: &systemprompt_identifiers::TenantId,
213 ) -> SyncResult<String> {
214 #[derive(Deserialize)]
215 struct TenantInfo {
216 fly_app_name: Option<String>,
217 }
218 let url = format!("{}/api/v1/cloud/tenants/{}", self.api_url, tenant_id);
219 let info: TenantInfo = self.get(&url).await?;
220 info.fly_app_name.ok_or(SyncError::TenantNoApp)
221 }
222
223 pub async fn get_database_url(
224 &self,
225 tenant_id: &systemprompt_identifiers::TenantId,
226 ) -> SyncResult<String> {
227 #[derive(Deserialize)]
228 struct DatabaseInfo {
229 database_url: Option<String>,
230 }
231 let url = format!(
232 "{}/api/v1/cloud/tenants/{}/database",
233 self.api_url, tenant_id
234 );
235 let info: DatabaseInfo = self.get(&url).await?;
236 info.database_url.ok_or_else(|| SyncError::ApiError {
237 status: 404,
238 message: "Database URL not available for tenant".to_string(),
239 })
240 }
241
242 async fn get<T: DeserializeOwned>(&self, url: &str) -> SyncResult<T> {
243 let response = self
244 .client
245 .get(url)
246 .header("Authorization", format!("Bearer {}", self.token))
247 .send()
248 .await?;
249
250 self.handle_json_response(response).await
251 }
252
253 async fn post<T: DeserializeOwned, B: Serialize + Sync>(
254 &self,
255 url: &str,
256 body: &B,
257 ) -> SyncResult<T> {
258 let response = self
259 .client
260 .post(url)
261 .header("Authorization", format!("Bearer {}", self.token))
262 .json(body)
263 .send()
264 .await?;
265
266 self.handle_json_response(response).await
267 }
268
269 async fn handle_json_response<T: DeserializeOwned>(
270 &self,
271 response: reqwest::Response,
272 ) -> SyncResult<T> {
273 let status = response.status();
274 if status == StatusCode::UNAUTHORIZED {
275 return Err(SyncError::Unauthorized);
276 }
277 if !status.is_success() {
278 let message = response.text().await?;
279 return Err(SyncError::ApiError {
280 status: status.as_u16(),
281 message,
282 });
283 }
284 Ok(response.json().await?)
285 }
286
287 async fn handle_binary_response(&self, response: reqwest::Response) -> SyncResult<Vec<u8>> {
288 let status = response.status();
289 if !status.is_success() {
290 let message = response
291 .text()
292 .await
293 .unwrap_or_else(|e| format!("(body unreadable: {})", e));
294 return Err(SyncError::ApiError {
295 status: status.as_u16(),
296 message,
297 });
298 }
299 Ok(response.bytes().await?.to_vec())
300 }
301}