Skip to main content

rustrails_storage/
direct_upload.rs

//! Direct upload flow helpers.
2
3use std::{collections::BTreeMap, time::Duration};
4
5use chrono::{DateTime, Utc};
6use rustrails_support::runtime;
7use serde::{Deserialize, Serialize};
8use thiserror::Error;
9use url::Url;
10
11use crate::{
12    Blob,
13    service::StorageService,
14    urls::{SignedUrlError, sign_payload, verify_payload},
15};
16
17/// Errors returned by direct upload helpers.
18#[derive(Debug, Error)]
19pub enum DirectUploadError {
20    /// The storage service URL could not be generated.
21    #[error(transparent)]
22    Storage(#[from] crate::service::StorageError),
23    /// The upload token was invalid.
24    #[error(transparent)]
25    SignedUrl(#[from] SignedUrlError),
26    /// The token payload could not be decoded.
27    #[error("invalid direct upload token")]
28    InvalidToken,
29    /// The token has expired.
30    #[error("direct upload token has expired")]
31    Expired,
32}
33
34/// Signed claims embedded in a direct upload token.
35#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
36pub struct DirectUploadTokenClaims {
37    /// Blob identifier.
38    pub blob_id: uuid::Uuid,
39    /// Blob storage key.
40    pub key: String,
41    /// Expected byte size.
42    pub byte_size: u64,
43    /// Expected checksum.
44    pub checksum: String,
45    /// Target service name.
46    pub service_name: String,
47    /// Expiration timestamp.
48    pub expires_at: i64,
49}
50
51/// Full direct upload response.
52#[derive(Debug, Clone)]
53pub struct DirectUploadRequest {
54    /// The blob metadata that was pre-created.
55    pub blob: Blob,
56    /// The signed upload URL returned to the client.
57    pub upload_url: Url,
58    /// Headers the client should include while uploading.
59    pub headers: BTreeMap<String, String>,
60    /// The signed integrity token.
61    pub token: String,
62    /// Expiration timestamp.
63    pub expires_at: DateTime<Utc>,
64}
65
/// Creates and verifies direct upload tokens.
///
/// The `Debug` output deliberately redacts the signing secret so a manager can be logged or
/// included in error reports without leaking key material.
#[derive(Clone)]
pub struct DirectUploadManager {
    secret: Vec<u8>,
}

impl std::fmt::Debug for DirectUploadManager {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A derived impl would print every secret byte; never expose it.
        f.debug_struct("DirectUploadManager")
            .field("secret", &"<redacted>")
            .finish()
    }
}
71
72impl DirectUploadManager {
73    /// Creates a new manager using the provided signing secret.
74    #[must_use]
75    pub fn new(secret: impl Into<Vec<u8>>) -> Self {
76        Self {
77            secret: secret.into(),
78        }
79    }
80
81    /// Prepares a direct upload request for a blob.
82    ///
83    /// # Errors
84    ///
85    /// Returns an error when the storage service cannot generate an upload URL or the token cannot be signed.
86    pub async fn prepare<S: StorageService + ?Sized>(
87        &self,
88        blob: Blob,
89        service: &S,
90        expires_in: Duration,
91    ) -> Result<DirectUploadRequest, DirectUploadError> {
92        let expires_at = Utc::now()
93            + chrono::Duration::from_std(expires_in)
94                .map_err(|_| DirectUploadError::InvalidToken)?;
95        let claims = DirectUploadTokenClaims {
96            blob_id: blob.id(),
97            key: blob.key().to_owned(),
98            byte_size: blob.byte_size(),
99            checksum: blob.checksum().to_owned(),
100            service_name: blob.service_name().to_owned(),
101            expires_at: expires_at.timestamp(),
102        };
103        let payload = serde_json::to_vec(&claims).map_err(|_| DirectUploadError::InvalidToken)?;
104        let token = sign_payload(&self.secret, &payload)?;
105        let upload_url = service.url(blob.key(), expires_in).await?;
106        let mut headers = BTreeMap::new();
107        headers.insert(
108            "x-rustrails-checksum".to_owned(),
109            blob.checksum().to_owned(),
110        );
111        headers.insert(
112            "x-rustrails-byte-size".to_owned(),
113            blob.byte_size().to_string(),
114        );
115        if let Some(content_type) = blob.content_type() {
116            headers.insert("content-type".to_owned(), content_type.to_owned());
117        }
118        Ok(DirectUploadRequest {
119            blob,
120            upload_url,
121            headers,
122            token,
123            expires_at,
124        })
125    }
126
127    /// Prepares a direct upload request for a blob using the thread-local runtime.
128    ///
129    /// # Errors
130    ///
131    /// Returns an error when the storage service cannot generate an upload URL or the token cannot be signed.
132    pub fn prepare_sync<S: StorageService + ?Sized>(
133        &self,
134        blob: Blob,
135        service: &S,
136        expires_in: Duration,
137    ) -> Result<DirectUploadRequest, DirectUploadError> {
138        runtime::block_on(self.prepare(blob, service, expires_in))
139    }
140
141    /// Verifies a direct upload token.
142    ///
143    /// # Errors
144    ///
145    /// Returns an error when the token is invalid or expired.
146    pub fn verify(
147        &self,
148        token: &str,
149        now: DateTime<Utc>,
150    ) -> Result<DirectUploadTokenClaims, DirectUploadError> {
151        let payload = verify_payload(token, &self.secret)?;
152        let claims: DirectUploadTokenClaims =
153            serde_json::from_slice(&payload).map_err(|_| DirectUploadError::InvalidToken)?;
154        if now.timestamp() > claims.expires_at {
155            return Err(DirectUploadError::Expired);
156        }
157        Ok(claims)
158    }
159}
160
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use super::*;
    use crate::{blob::Blob, service::memory::MemoryService, test_support::run_sync_test};

    /// Upload lifetime, in seconds, for tests that do not exercise expiry.
    const MINUTE_SECS: u64 = 60;

    /// Builds a manager signing with the shared test secret.
    fn manager() -> DirectUploadManager {
        DirectUploadManager::new(b"secret".to_vec())
    }

    /// Builds an in-memory storage service registered under `name`.
    fn memory_service(name: &'static str) -> MemoryService {
        MemoryService::new(name).expect("service should build")
    }

    /// Builds a small blob without a content type on the `memory` service.
    fn blob() -> Blob {
        Blob::create(
            Bytes::from_static(b"hello"),
            "hello.txt",
            None,
            Default::default(),
            "memory",
        )
        .expect("blob should build")
    }

    /// Prepares `blob` against `service` with an upload lifetime of `secs` seconds.
    async fn prepare_for(
        manager: &DirectUploadManager,
        blob: Blob,
        service: &MemoryService,
        secs: u64,
    ) -> DirectUploadRequest {
        manager
            .prepare(blob, service, Duration::from_secs(secs))
            .await
            .expect("request should build")
    }

    #[tokio::test]
    async fn test_prepare_builds_upload_request() {
        let service = memory_service("memory");
        let request = prepare_for(&manager(), blob(), &service, MINUTE_SECS).await;
        assert!(request.upload_url.as_str().contains("expires_in=60"));
        assert!(request.headers.contains_key("x-rustrails-checksum"));
    }

    #[test]
    fn test_prepare_sync_builds_upload_request() {
        run_sync_test(|| {
            let service = memory_service("memory");
            let request = manager()
                .prepare_sync(blob(), &service, Duration::from_secs(MINUTE_SECS))
                .expect("request should build");
            assert!(request.upload_url.as_str().contains("expires_in=60"));
            assert!(request.headers.contains_key("x-rustrails-checksum"));
        });
    }

    #[tokio::test]
    async fn test_prepare_includes_byte_size_header() {
        let service = memory_service("memory");
        let source = blob();
        let expected_byte_size = source.byte_size().to_string();

        let request = prepare_for(&manager(), source, &service, MINUTE_SECS).await;

        let byte_size = request
            .headers
            .get("x-rustrails-byte-size")
            .map(String::as_str);
        assert_eq!(byte_size, Some(expected_byte_size.as_str()));
    }

    #[tokio::test]
    async fn test_prepare_preserves_blob_reference() {
        let service = memory_service("memory");
        let source = blob();
        let expected_id = source.id();
        let request = prepare_for(&manager(), source, &service, MINUTE_SECS).await;
        assert_eq!(request.blob.id(), expected_id);
    }

    #[tokio::test]
    async fn test_verify_round_trips_claims() {
        let manager = manager();
        let service = memory_service("memory");
        let request = prepare_for(&manager, blob(), &service, MINUTE_SECS).await;
        let claims = manager
            .verify(&request.token, Utc::now())
            .expect("token should verify");
        assert_eq!(claims.key, request.blob.key());
        assert_eq!(claims.checksum, request.blob.checksum());
    }

    #[tokio::test]
    async fn test_prepare_round_trips_blob_metadata_service_name_and_expiration() {
        let manager = manager();
        let service = memory_service("public");
        let metadata = BTreeMap::from([("custom".to_owned(), serde_json::json!("value"))]);
        let source = Blob::create_before_direct_upload(
            "direct-key",
            "hello.txt",
            6,
            "checksum",
            Some("text/plain"),
            metadata.clone(),
            "mirror",
        )
        .expect("blob should build");
        let blob_id = source.id();

        let request = prepare_for(&manager, source, &service, MINUTE_SECS).await;
        let claims = manager
            .verify(&request.token, Utc::now())
            .expect("token should verify");

        assert_eq!(request.blob.metadata(), &metadata);
        assert_eq!(claims.blob_id, blob_id);
        assert_eq!(claims.service_name, "mirror");
        assert_eq!(claims.expires_at, request.expires_at.timestamp());
    }

    #[tokio::test]
    async fn test_verify_rejects_expired_token() {
        let manager = manager();
        let service = memory_service("memory");
        let request = prepare_for(&manager, blob(), &service, 1).await;
        let two_seconds_later = Utc::now() + chrono::Duration::seconds(2);
        let error = manager
            .verify(&request.token, two_seconds_later)
            .expect_err("token should fail");
        assert!(matches!(error, DirectUploadError::Expired));
    }

    #[tokio::test]
    async fn test_verify_rejects_tampered_token() {
        let error = manager()
            .verify("tampered", Utc::now())
            .expect_err("token should fail");
        assert!(matches!(error, DirectUploadError::SignedUrl(_)));
    }

    #[tokio::test]
    async fn test_verify_rejects_token_signed_with_different_secret() {
        let service = memory_service("memory");
        let request = prepare_for(&manager(), blob(), &service, MINUTE_SECS).await;

        let impostor = DirectUploadManager::new(b"other-secret".to_vec());
        let error = impostor
            .verify(&request.token, Utc::now())
            .expect_err("token should fail");

        assert!(matches!(error, DirectUploadError::SignedUrl(_)));
    }

    #[tokio::test]
    async fn test_verify_accepts_exact_expiration_timestamp() {
        let manager = manager();
        let service = memory_service("memory");
        let request = prepare_for(&manager, blob(), &service, 1).await;
        let claims = manager
            .verify(&request.token, Utc::now())
            .expect("token should verify");
        let boundary = DateTime::<Utc>::from_timestamp(claims.expires_at, 0)
            .expect("timestamp should be valid");

        let boundary_claims = manager
            .verify(&request.token, boundary)
            .expect("boundary token should verify");

        assert_eq!(boundary_claims, claims);
    }

    #[tokio::test]
    async fn test_prepare_includes_content_type_header_when_known() {
        let service = memory_service("memory");
        let text_blob = Blob::create(
            Bytes::from_static(b"hello"),
            "hello.txt",
            Some("text/plain"),
            Default::default(),
            "memory",
        )
        .expect("blob should build");
        let request = prepare_for(&manager(), text_blob, &service, MINUTE_SECS).await;
        let content_type = request.headers.get("content-type").map(String::as_str);
        assert_eq!(content_type, Some("text/plain"));
    }

    #[tokio::test]
    async fn test_prepare_omits_content_type_header_when_unknown() {
        let service = memory_service("memory");
        let untyped_blob = Blob::create_before_direct_upload(
            "direct-key",
            "unknown_file",
            100,
            "checksum",
            None,
            BTreeMap::new(),
            "memory",
        )
        .expect("blob should build");

        let request = prepare_for(&manager(), untyped_blob, &service, MINUTE_SECS).await;

        assert!(!request.headers.contains_key("content-type"));
    }
}