1use std::time::Duration;
2
3use reqwest::{Client, StatusCode};
4use serde::de::DeserializeOwned;
5use serde::{Deserialize, Serialize};
6use systemprompt_models::net::{HTTP_CONNECT_TIMEOUT, HTTP_SYNC_DEPLOY_TIMEOUT};
7use tokio::time::sleep;
8
9use crate::error::{SyncError, SyncResult};
10
/// Exponential-backoff policy applied to retryable sync requests.
#[derive(Clone, Copy, Debug)]
pub struct RetryConfig {
    /// Upper bound on the number of request attempts.
    pub max_attempts: u32,
    /// Delay before the first retry.
    pub initial_delay: Duration,
    /// Ceiling that the growing delay is clamped to.
    pub max_delay: Duration,
    /// Factor the delay is multiplied by after every retry.
    pub exponential_base: u32,
}

impl Default for RetryConfig {
    /// Five attempts, starting at a 2 s delay that doubles up to a 30 s ceiling.
    fn default() -> Self {
        let initial_delay = Duration::from_millis(2_000);
        let max_delay = Duration::from_secs(30);
        Self {
            initial_delay,
            max_delay,
            exponential_base: 2,
            max_attempts: 5,
        }
    }
}
29
30#[derive(Clone, Debug)]
31pub struct SyncApiClient {
32 client: Client,
33 api_url: String,
34 token: String,
35 hostname: Option<String>,
36 sync_token: Option<String>,
37 retry_config: RetryConfig,
38}
39
40#[derive(Debug, Deserialize)]
41pub struct RegistryToken {
42 pub registry: String,
43 pub username: String,
44 pub token: String,
45}
46
47#[derive(Debug, Clone, Copy, Deserialize)]
48pub struct UploadResponse {
49 pub files_uploaded: usize,
50}
51
52#[derive(Debug, Deserialize)]
53pub struct DeployResponse {
54 pub status: String,
55 pub app_url: Option<String>,
56}
57
58impl SyncApiClient {
59 pub fn new(api_url: &str, token: &str) -> SyncResult<Self> {
60 Ok(Self {
61 client: Client::builder()
62 .connect_timeout(HTTP_CONNECT_TIMEOUT)
63 .timeout(HTTP_SYNC_DEPLOY_TIMEOUT)
64 .build()?,
65 api_url: api_url.to_string(),
66 token: token.to_string(),
67 hostname: None,
68 sync_token: None,
69 retry_config: RetryConfig::default(),
70 })
71 }
72
73 pub fn with_direct_sync(
74 mut self,
75 hostname: Option<String>,
76 sync_token: Option<String>,
77 ) -> Self {
78 self.hostname = hostname;
79 self.sync_token = sync_token;
80 self
81 }
82
83 fn direct_sync_credentials(&self) -> Option<(String, String)> {
84 match (&self.hostname, &self.sync_token) {
85 (Some(hostname), Some(token)) => {
86 let url = format!("https://{}/api/v1/sync/files", hostname);
87 Some((url, token.clone()))
88 },
89 _ => None,
90 }
91 }
92
93 fn calculate_next_delay(&self, current: Duration) -> Duration {
94 current
95 .saturating_mul(self.retry_config.exponential_base)
96 .min(self.retry_config.max_delay)
97 }
98
99 pub async fn upload_files(
100 &self,
101 tenant_id: &systemprompt_identifiers::TenantId,
102 data: Vec<u8>,
103 ) -> SyncResult<UploadResponse> {
104 let (url, token) = self.direct_sync_credentials().unwrap_or_else(|| {
105 (
106 format!("{}/api/v1/cloud/tenants/{}/files", self.api_url, tenant_id),
107 self.token.clone(),
108 )
109 });
110
111 let mut current_delay = self.retry_config.initial_delay;
112
113 for attempt in 1..=self.retry_config.max_attempts {
114 let response = self
115 .client
116 .post(&url)
117 .header("Authorization", format!("Bearer {}", token))
118 .header("Content-Type", "application/octet-stream")
119 .body(data.clone())
120 .send()
121 .await?;
122
123 match self.handle_json_response::<UploadResponse>(response).await {
124 Ok(upload) => return Ok(upload),
125 Err(error) if error.is_retryable() && attempt < self.retry_config.max_attempts => {
126 tracing::warn!(
127 attempt = attempt,
128 max_attempts = self.retry_config.max_attempts,
129 delay_ms = current_delay.as_millis() as u64,
130 error = %error,
131 "Retryable sync error, waiting before retry"
132 );
133 sleep(current_delay).await;
134 current_delay = self.calculate_next_delay(current_delay);
135 },
136 Err(error) => return Err(error),
137 }
138 }
139
140 Err(SyncError::ApiError {
141 status: 503,
142 message: "Max retry attempts exceeded".to_string(),
143 })
144 }
145
146 pub async fn download_files(
147 &self,
148 tenant_id: &systemprompt_identifiers::TenantId,
149 ) -> SyncResult<Vec<u8>> {
150 let (url, token) = self.direct_sync_credentials().unwrap_or_else(|| {
151 (
152 format!("{}/api/v1/cloud/tenants/{}/files", self.api_url, tenant_id),
153 self.token.clone(),
154 )
155 });
156
157 let mut current_delay = self.retry_config.initial_delay;
158
159 for attempt in 1..=self.retry_config.max_attempts {
160 let response = self
161 .client
162 .get(&url)
163 .header("Authorization", format!("Bearer {}", token))
164 .send()
165 .await?;
166
167 match self.handle_binary_response(response).await {
168 Ok(data) => return Ok(data),
169 Err(error) if error.is_retryable() && attempt < self.retry_config.max_attempts => {
170 tracing::warn!(
171 attempt = attempt,
172 max_attempts = self.retry_config.max_attempts,
173 delay_ms = current_delay.as_millis() as u64,
174 error = %error,
175 "Retryable sync error, waiting before retry"
176 );
177 sleep(current_delay).await;
178 current_delay = self.calculate_next_delay(current_delay);
179 },
180 Err(error) => return Err(error),
181 }
182 }
183
184 Err(SyncError::ApiError {
185 status: 503,
186 message: "Max retry attempts exceeded".to_string(),
187 })
188 }
189
190 pub async fn get_registry_token(
191 &self,
192 tenant_id: &systemprompt_identifiers::TenantId,
193 ) -> SyncResult<RegistryToken> {
194 let url = format!(
195 "{}/api/v1/cloud/tenants/{}/registry-token",
196 self.api_url, tenant_id
197 );
198 self.get(&url).await
199 }
200
201 pub async fn deploy(
202 &self,
203 tenant_id: &systemprompt_identifiers::TenantId,
204 image: &str,
205 ) -> SyncResult<DeployResponse> {
206 let url = format!("{}/api/v1/cloud/tenants/{}/deploy", self.api_url, tenant_id);
207 self.post(&url, &serde_json::json!({ "image": image }))
208 .await
209 }
210
211 pub async fn get_tenant_app_id(
212 &self,
213 tenant_id: &systemprompt_identifiers::TenantId,
214 ) -> SyncResult<String> {
215 #[derive(Deserialize)]
216 struct TenantInfo {
217 fly_app_name: Option<String>,
218 }
219 let url = format!("{}/api/v1/cloud/tenants/{}", self.api_url, tenant_id);
220 let info: TenantInfo = self.get(&url).await?;
221 info.fly_app_name.ok_or(SyncError::TenantNoApp)
222 }
223
224 pub async fn get_database_url(
225 &self,
226 tenant_id: &systemprompt_identifiers::TenantId,
227 ) -> SyncResult<String> {
228 #[derive(Deserialize)]
229 struct DatabaseInfo {
230 database_url: Option<String>,
231 }
232 let url = format!(
233 "{}/api/v1/cloud/tenants/{}/database",
234 self.api_url, tenant_id
235 );
236 let info: DatabaseInfo = self.get(&url).await?;
237 info.database_url.ok_or_else(|| SyncError::ApiError {
238 status: 404,
239 message: "Database URL not available for tenant".to_string(),
240 })
241 }
242
243 async fn get<T: DeserializeOwned>(&self, url: &str) -> SyncResult<T> {
244 let response = self
245 .client
246 .get(url)
247 .header("Authorization", format!("Bearer {}", self.token))
248 .send()
249 .await?;
250
251 self.handle_json_response(response).await
252 }
253
254 async fn post<T: DeserializeOwned, B: Serialize + Sync>(
255 &self,
256 url: &str,
257 body: &B,
258 ) -> SyncResult<T> {
259 let response = self
260 .client
261 .post(url)
262 .header("Authorization", format!("Bearer {}", self.token))
263 .json(body)
264 .send()
265 .await?;
266
267 self.handle_json_response(response).await
268 }
269
270 async fn handle_json_response<T: DeserializeOwned>(
271 &self,
272 response: reqwest::Response,
273 ) -> SyncResult<T> {
274 let status = response.status();
275 if status == StatusCode::UNAUTHORIZED {
276 return Err(SyncError::Unauthorized);
277 }
278 if !status.is_success() {
279 let message = response.text().await?;
280 return Err(SyncError::ApiError {
281 status: status.as_u16(),
282 message,
283 });
284 }
285 Ok(response.json().await?)
286 }
287
288 async fn handle_binary_response(&self, response: reqwest::Response) -> SyncResult<Vec<u8>> {
289 let status = response.status();
290 if !status.is_success() {
291 let message = response
292 .text()
293 .await
294 .unwrap_or_else(|e| format!("(body unreadable: {})", e));
295 return Err(SyncError::ApiError {
296 status: status.as_u16(),
297 message,
298 });
299 }
300 Ok(response.bytes().await?.to_vec())
301 }
302}