Skip to main content

trace_share_core/
worker.rs

1use anyhow::{Context, Result};
2use rand::{Rng, thread_rng};
3use serde::{Deserialize, Serialize};
4use std::collections::BTreeMap;
5use tokio::time::{Duration, sleep};
6
7use crate::{config::AppConfig, episode::EpisodeRecord};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct WorkerUploadResponse {
11    pub episode_id: String,
12    pub object_key: String,
13    pub etag: Option<String>,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17struct PresignUploadResponse {
18    upload_url: String,
19    object_key: String,
20    headers: Option<BTreeMap<String, String>>,
21}
22
23pub async fn upload_episode(
24    config: &AppConfig,
25    episode: &EpisodeRecord,
26) -> Result<WorkerUploadResponse> {
27    let mode = config.worker.upload_mode.to_ascii_lowercase();
28    if mode == "presigned" {
29        return upload_episode_presigned(config, episode).await;
30    }
31    if mode == "auto" {
32        match upload_episode_presigned(config, episode).await {
33            Ok(v) => return Ok(v),
34            Err(e) => {
35                let text = format!("{e:#}");
36                if !(text.contains("status=404") || text.contains("status=501")) {
37                    return Err(e).context("worker upload (auto mode, presigned path)");
38                }
39            }
40        }
41    }
42    upload_episode_legacy(config, episode).await
43}
44
45async fn upload_episode_legacy(
46    config: &AppConfig,
47    episode: &EpisodeRecord,
48) -> Result<WorkerUploadResponse> {
49    let base_url = config
50        .worker
51        .base_url
52        .as_ref()
53        .context("missing TRACE_SHARE_WORKER_BASE_URL")?;
54
55    let endpoint = format!("{}/v1/episodes", base_url.trim_end_matches('/'));
56    let client = reqwest::Client::builder()
57        .timeout(std::time::Duration::from_secs(
58            config.worker.timeout_seconds.max(5),
59        ))
60        .build()?;
61
62    let mut attempt: u32 = 0;
63    loop {
64        let mut req = client.post(&endpoint).json(episode);
65        if let Some(token) = config.worker.api_token.as_ref() {
66            req = req.bearer_auth(token);
67        }
68
69        let resp = req.send().await;
70        match resp {
71            Ok(resp) => {
72                let status = resp.status();
73                if status.is_success() {
74                    return Ok(resp.json::<WorkerUploadResponse>().await?);
75                }
76
77                let body = resp.text().await.unwrap_or_default();
78                if !should_retry_status(status) || attempt >= 4 {
79                    anyhow::bail!("worker upload failed: status={} body={}", status, body);
80                }
81            }
82            Err(e) => {
83                let retryable_transport = e.is_timeout() || e.is_connect() || e.is_request();
84                if !retryable_transport || attempt >= 4 {
85                    return Err(e).context("worker upload request failed after retries");
86                }
87            }
88        }
89
90        attempt += 1;
91        let jitter: u64 = thread_rng().gen_range(50..300);
92        let wait_ms = (2u64.pow(attempt) * 200) + jitter;
93        sleep(Duration::from_millis(wait_ms)).await;
94    }
95}
96
97async fn upload_episode_presigned(
98    config: &AppConfig,
99    episode: &EpisodeRecord,
100) -> Result<WorkerUploadResponse> {
101    let base_url = config
102        .worker
103        .base_url
104        .as_ref()
105        .context("missing TRACE_SHARE_WORKER_BASE_URL")?;
106    let presign_endpoint = format!("{}/v1/episodes/presign", base_url.trim_end_matches('/'));
107    let complete_endpoint = format!("{}/v1/episodes/complete", base_url.trim_end_matches('/'));
108    let client = reqwest::Client::builder()
109        .timeout(std::time::Duration::from_secs(
110            config.worker.timeout_seconds.max(5),
111        ))
112        .build()?;
113
114    let presign_payload = serde_json::json!({
115        "episode_id": episode.id,
116        "content_hash": episode.content_hash,
117        "content_type": "application/json",
118    });
119
120    let presign = post_with_retry_json::<PresignUploadResponse>(
121        &client,
122        &presign_endpoint,
123        config.worker.api_token.as_deref(),
124        &presign_payload,
125        "worker episode presign",
126    )
127    .await?;
128
129    let episode_bytes = serde_json::to_vec(episode)?;
130    put_with_retry(
131        &client,
132        &presign.upload_url,
133        presign.headers.as_ref(),
134        &episode_bytes,
135        "worker episode upload",
136    )
137    .await?;
138
139    let complete_payload = serde_json::json!({
140        "episode_id": episode.id,
141        "object_key": presign.object_key,
142        "content_hash": episode.content_hash,
143    });
144
145    post_with_retry_json::<WorkerUploadResponse>(
146        &client,
147        &complete_endpoint,
148        config.worker.api_token.as_deref(),
149        &complete_payload,
150        "worker episode complete",
151    )
152    .await
153}
154
155pub async fn push_revocation(
156    config: &AppConfig,
157    episode_id: &str,
158    revoked_at: &str,
159    reason: Option<&str>,
160) -> Result<()> {
161    let base_url = config
162        .worker
163        .base_url
164        .as_ref()
165        .context("missing TRACE_SHARE_WORKER_BASE_URL")?;
166
167    let endpoint = format!("{}/v1/revocations", base_url.trim_end_matches('/'));
168    let client = reqwest::Client::builder()
169        .timeout(std::time::Duration::from_secs(
170            config.worker.timeout_seconds.max(5),
171        ))
172        .build()?;
173
174    let payload = serde_json::json!({
175        "episode_id": episode_id,
176        "revoked_at": revoked_at,
177        "reason": reason,
178    });
179
180    let mut attempt: u32 = 0;
181    loop {
182        let mut req = client.post(&endpoint).json(&payload);
183        if let Some(token) = config.worker.api_token.as_ref() {
184            req = req.bearer_auth(token);
185        }
186
187        let resp = req.send().await;
188        match resp {
189            Ok(resp) => {
190                let status = resp.status();
191                if status.is_success() {
192                    return Ok(());
193                }
194
195                let body = resp.text().await.unwrap_or_default();
196                if !should_retry_status(status) || attempt >= 4 {
197                    anyhow::bail!(
198                        "worker revocation push failed: status={} body={}",
199                        status,
200                        body
201                    );
202                }
203            }
204            Err(e) => {
205                let retryable_transport = e.is_timeout() || e.is_connect() || e.is_request();
206                if !retryable_transport || attempt >= 4 {
207                    return Err(e).context("worker revocation request failed after retries");
208                }
209            }
210        }
211
212        attempt += 1;
213        let jitter: u64 = thread_rng().gen_range(50..300);
214        let wait_ms = (2u64.pow(attempt) * 200) + jitter;
215        sleep(Duration::from_millis(wait_ms)).await;
216    }
217}
218
219async fn post_with_retry_json<T: for<'de> Deserialize<'de>>(
220    client: &reqwest::Client,
221    endpoint: &str,
222    bearer_token: Option<&str>,
223    payload: &serde_json::Value,
224    label: &str,
225) -> Result<T> {
226    let mut attempt: u32 = 0;
227    loop {
228        let mut req = client.post(endpoint).json(payload);
229        if let Some(token) = bearer_token {
230            req = req.bearer_auth(token);
231        }
232
233        let resp = req.send().await;
234        match resp {
235            Ok(resp) => {
236                let status = resp.status();
237                if status.is_success() {
238                    return Ok(resp.json::<T>().await?);
239                }
240                let body = resp.text().await.unwrap_or_default();
241                if !should_retry_status(status) || attempt >= 4 {
242                    anyhow::bail!("{label} failed: status={} body={}", status, body);
243                }
244            }
245            Err(e) => {
246                let retryable_transport = e.is_timeout() || e.is_connect() || e.is_request();
247                if !retryable_transport || attempt >= 4 {
248                    return Err(e).with_context(|| format!("{label} request failed after retries"));
249                }
250            }
251        }
252
253        attempt += 1;
254        let jitter: u64 = thread_rng().gen_range(50..300);
255        let wait_ms = (2u64.pow(attempt) * 200) + jitter;
256        sleep(Duration::from_millis(wait_ms)).await;
257    }
258}
259
260async fn put_with_retry(
261    client: &reqwest::Client,
262    endpoint: &str,
263    headers: Option<&BTreeMap<String, String>>,
264    body: &[u8],
265    label: &str,
266) -> Result<()> {
267    let mut attempt: u32 = 0;
268    loop {
269        let mut req = client.put(endpoint).body(body.to_vec());
270        if let Some(headers) = headers {
271            for (k, v) in headers {
272                req = req.header(k, v);
273            }
274        }
275
276        let resp = req.send().await;
277        match resp {
278            Ok(resp) => {
279                let status = resp.status();
280                if status.is_success() {
281                    return Ok(());
282                }
283                let body = resp.text().await.unwrap_or_default();
284                if !should_retry_status(status) || attempt >= 4 {
285                    anyhow::bail!("{label} failed: status={} body={}", status, body);
286                }
287            }
288            Err(e) => {
289                let retryable_transport = e.is_timeout() || e.is_connect() || e.is_request();
290                if !retryable_transport || attempt >= 4 {
291                    return Err(e).with_context(|| format!("{label} request failed after retries"));
292                }
293            }
294        }
295
296        attempt += 1;
297        let jitter: u64 = thread_rng().gen_range(50..300);
298        let wait_ms = (2u64.pow(attempt) * 200) + jitter;
299        sleep(Duration::from_millis(wait_ms)).await;
300    }
301}
302
303fn should_retry_status(status: reqwest::StatusCode) -> bool {
304    status == reqwest::StatusCode::TOO_MANY_REQUESTS || status.is_server_error()
305}