1use std::{collections::BTreeMap, time::Duration};
4
5use chrono::{DateTime, Utc};
6use rustrails_support::runtime;
7use serde::{Deserialize, Serialize};
8use thiserror::Error;
9use url::Url;
10
11use crate::{
12 Blob,
13 service::StorageService,
14 urls::{SignedUrlError, sign_payload, verify_payload},
15};
16
17#[derive(Debug, Error)]
19pub enum DirectUploadError {
20 #[error(transparent)]
22 Storage(#[from] crate::service::StorageError),
23 #[error(transparent)]
25 SignedUrl(#[from] SignedUrlError),
26 #[error("invalid direct upload token")]
28 InvalidToken,
29 #[error("direct upload token has expired")]
31 Expired,
32}
33
34#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
36pub struct DirectUploadTokenClaims {
37 pub blob_id: uuid::Uuid,
39 pub key: String,
41 pub byte_size: u64,
43 pub checksum: String,
45 pub service_name: String,
47 pub expires_at: i64,
49}
50
51#[derive(Debug, Clone)]
53pub struct DirectUploadRequest {
54 pub blob: Blob,
56 pub upload_url: Url,
58 pub headers: BTreeMap<String, String>,
60 pub token: String,
62 pub expires_at: DateTime<Utc>,
64}
65
/// Issues and checks signed direct-upload tokens using a shared secret.
///
/// `Debug` is implemented by hand rather than derived, so the signing secret never
/// appears in `{:?}` output such as logs or panic messages.
#[derive(Clone)]
pub struct DirectUploadManager {
    /// Key material used to sign and verify token payloads.
    secret: Vec<u8>,
}

impl std::fmt::Debug for DirectUploadManager {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Deliberately omit `secret`: a derived impl would print the raw key bytes.
        f.debug_struct("DirectUploadManager").finish_non_exhaustive()
    }
}
71
72impl DirectUploadManager {
73 #[must_use]
75 pub fn new(secret: impl Into<Vec<u8>>) -> Self {
76 Self {
77 secret: secret.into(),
78 }
79 }
80
81 pub async fn prepare<S: StorageService + ?Sized>(
87 &self,
88 blob: Blob,
89 service: &S,
90 expires_in: Duration,
91 ) -> Result<DirectUploadRequest, DirectUploadError> {
92 let expires_at = Utc::now()
93 + chrono::Duration::from_std(expires_in)
94 .map_err(|_| DirectUploadError::InvalidToken)?;
95 let claims = DirectUploadTokenClaims {
96 blob_id: blob.id(),
97 key: blob.key().to_owned(),
98 byte_size: blob.byte_size(),
99 checksum: blob.checksum().to_owned(),
100 service_name: blob.service_name().to_owned(),
101 expires_at: expires_at.timestamp(),
102 };
103 let payload = serde_json::to_vec(&claims).map_err(|_| DirectUploadError::InvalidToken)?;
104 let token = sign_payload(&self.secret, &payload)?;
105 let upload_url = service.url(blob.key(), expires_in).await?;
106 let mut headers = BTreeMap::new();
107 headers.insert(
108 "x-rustrails-checksum".to_owned(),
109 blob.checksum().to_owned(),
110 );
111 headers.insert(
112 "x-rustrails-byte-size".to_owned(),
113 blob.byte_size().to_string(),
114 );
115 if let Some(content_type) = blob.content_type() {
116 headers.insert("content-type".to_owned(), content_type.to_owned());
117 }
118 Ok(DirectUploadRequest {
119 blob,
120 upload_url,
121 headers,
122 token,
123 expires_at,
124 })
125 }
126
127 pub fn prepare_sync<S: StorageService + ?Sized>(
133 &self,
134 blob: Blob,
135 service: &S,
136 expires_in: Duration,
137 ) -> Result<DirectUploadRequest, DirectUploadError> {
138 runtime::block_on(self.prepare(blob, service, expires_in))
139 }
140
141 pub fn verify(
147 &self,
148 token: &str,
149 now: DateTime<Utc>,
150 ) -> Result<DirectUploadTokenClaims, DirectUploadError> {
151 let payload = verify_payload(token, &self.secret)?;
152 let claims: DirectUploadTokenClaims =
153 serde_json::from_slice(&payload).map_err(|_| DirectUploadError::InvalidToken)?;
154 if now.timestamp() > claims.expires_at {
155 return Err(DirectUploadError::Expired);
156 }
157 Ok(claims)
158 }
159}
160
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use super::*;
    use crate::{blob::Blob, service::memory::MemoryService, test_support::run_sync_test};

    /// Builds a small in-memory blob ("hello") attached to the `memory` service.
    fn blob() -> Blob {
        Blob::create(
            Bytes::from_static(b"hello"),
            "hello.txt",
            None,
            Default::default(),
            "memory",
        )
        .expect("blob should build")
    }

    #[tokio::test]
    async fn test_prepare_builds_upload_request() {
        let manager = DirectUploadManager::new(b"secret".to_vec());
        let service = MemoryService::new("memory").expect("service should build");
        let request = manager
            .prepare(blob(), &service, Duration::from_secs(60))
            .await
            .expect("request should build");
        assert!(request.upload_url.as_str().contains("expires_in=60"));
        assert!(request.headers.contains_key("x-rustrails-checksum"));
    }

    #[test]
    fn test_prepare_sync_builds_upload_request() {
        run_sync_test(|| {
            let manager = DirectUploadManager::new(b"secret".to_vec());
            let service = MemoryService::new("memory").expect("service should build");
            let request = manager
                .prepare_sync(blob(), &service, Duration::from_secs(60))
                .expect("request should build");
            assert!(request.upload_url.as_str().contains("expires_in=60"));
            assert!(request.headers.contains_key("x-rustrails-checksum"));
        });
    }

    #[tokio::test]
    async fn test_prepare_includes_byte_size_header() {
        let manager = DirectUploadManager::new(b"secret".to_vec());
        let service = MemoryService::new("memory").expect("service should build");
        let blob = blob();
        let expected_byte_size = blob.byte_size().to_string();

        let request = manager
            .prepare(blob, &service, Duration::from_secs(60))
            .await
            .expect("request should build");

        assert_eq!(
            request
                .headers
                .get("x-rustrails-byte-size")
                .map(String::as_str),
            Some(expected_byte_size.as_str())
        );
    }

    #[tokio::test]
    async fn test_prepare_preserves_blob_reference() {
        let manager = DirectUploadManager::new(b"secret".to_vec());
        let service = MemoryService::new("memory").expect("service should build");
        let blob = blob();
        let request = manager
            .prepare(blob.clone(), &service, Duration::from_secs(60))
            .await
            .expect("request should build");
        assert_eq!(request.blob.id(), blob.id());
    }

    #[tokio::test]
    async fn test_verify_round_trips_claims() {
        let manager = DirectUploadManager::new(b"secret".to_vec());
        let service = MemoryService::new("memory").expect("service should build");
        let request = manager
            .prepare(blob(), &service, Duration::from_secs(60))
            .await
            .expect("request should build");
        let claims = manager
            .verify(&request.token, Utc::now())
            .expect("token should verify");
        assert_eq!(claims.key, request.blob.key());
        assert_eq!(claims.checksum, request.blob.checksum());
    }

    #[tokio::test]
    async fn test_prepare_round_trips_blob_metadata_service_name_and_expiration() {
        let manager = DirectUploadManager::new(b"secret".to_vec());
        let service = MemoryService::new("public").expect("service should build");
        let mut metadata = BTreeMap::new();
        metadata.insert("custom".to_owned(), serde_json::json!("value"));
        let blob = Blob::create_before_direct_upload(
            "direct-key",
            "hello.txt",
            6,
            "checksum",
            Some("text/plain"),
            metadata.clone(),
            "mirror",
        )
        .expect("blob should build");
        let blob_id = blob.id();

        let request = manager
            .prepare(blob, &service, Duration::from_secs(60))
            .await
            .expect("request should build");
        let claims = manager
            .verify(&request.token, Utc::now())
            .expect("token should verify");

        assert_eq!(request.blob.metadata(), &metadata);
        assert_eq!(claims.blob_id, blob_id);
        assert_eq!(claims.service_name, "mirror");
        assert_eq!(claims.expires_at, request.expires_at.timestamp());
    }

    #[tokio::test]
    async fn test_verify_rejects_expired_token() {
        let manager = DirectUploadManager::new(b"secret".to_vec());
        let service = MemoryService::new("memory").expect("service should build");
        let request = manager
            .prepare(blob(), &service, Duration::from_secs(1))
            .await
            .expect("request should build");
        let future = Utc::now() + chrono::Duration::seconds(2);
        let error = manager
            .verify(&request.token, future)
            .expect_err("token should fail");
        assert!(matches!(error, DirectUploadError::Expired));
    }

    // `verify` is synchronous, so this test needs no async runtime.
    #[test]
    fn test_verify_rejects_tampered_token() {
        let manager = DirectUploadManager::new(b"secret".to_vec());
        let error = manager
            .verify("tampered", Utc::now())
            .expect_err("token should fail");
        assert!(matches!(error, DirectUploadError::SignedUrl(_)));
    }

    // A correctly signed payload that does not decode into claims must be reported as
    // `InvalidToken`, not as a signature failure.
    #[test]
    fn test_verify_rejects_signed_payload_that_is_not_claims() {
        let secret = b"secret".to_vec();
        let manager = DirectUploadManager::new(secret.clone());
        let token =
            sign_payload(&secret, &b"not json".to_vec()).expect("payload should sign");
        let error = manager
            .verify(&token, Utc::now())
            .expect_err("token should fail");
        assert!(matches!(error, DirectUploadError::InvalidToken));
    }

    #[tokio::test]
    async fn test_verify_rejects_token_signed_with_different_secret() {
        let manager = DirectUploadManager::new(b"secret".to_vec());
        let other_manager = DirectUploadManager::new(b"other-secret".to_vec());
        let service = MemoryService::new("memory").expect("service should build");
        let request = manager
            .prepare(blob(), &service, Duration::from_secs(60))
            .await
            .expect("request should build");

        let error = other_manager
            .verify(&request.token, Utc::now())
            .expect_err("token should fail");

        assert!(matches!(error, DirectUploadError::SignedUrl(_)));
    }

    #[tokio::test]
    async fn test_verify_accepts_exact_expiration_timestamp() {
        let manager = DirectUploadManager::new(b"secret".to_vec());
        let service = MemoryService::new("memory").expect("service should build");
        let request = manager
            .prepare(blob(), &service, Duration::from_secs(1))
            .await
            .expect("request should build");
        let claims = manager
            .verify(&request.token, Utc::now())
            .expect("token should verify");
        let boundary = chrono::DateTime::<Utc>::from_timestamp(claims.expires_at, 0)
            .expect("timestamp should be valid");

        let boundary_claims = manager
            .verify(&request.token, boundary)
            .expect("boundary token should verify");

        assert_eq!(boundary_claims, claims);
    }

    #[tokio::test]
    async fn test_prepare_includes_content_type_header_when_known() {
        let manager = DirectUploadManager::new(b"secret".to_vec());
        let service = MemoryService::new("memory").expect("service should build");
        let blob = Blob::create(
            Bytes::from_static(b"hello"),
            "hello.txt",
            Some("text/plain"),
            Default::default(),
            "memory",
        )
        .expect("blob should build");
        let request = manager
            .prepare(blob, &service, Duration::from_secs(60))
            .await
            .expect("request should build");
        assert_eq!(
            request.headers.get("content-type").map(String::as_str),
            Some("text/plain")
        );
    }

    #[tokio::test]
    async fn test_prepare_omits_content_type_header_when_unknown() {
        let manager = DirectUploadManager::new(b"secret".to_vec());
        let service = MemoryService::new("memory").expect("service should build");
        let blob = Blob::create_before_direct_upload(
            "direct-key",
            "unknown_file",
            100,
            "checksum",
            None,
            BTreeMap::new(),
            "memory",
        )
        .expect("blob should build");

        let request = manager
            .prepare(blob, &service, Duration::from_secs(60))
            .await
            .expect("request should build");

        assert!(!request.headers.contains_key("content-type"));
    }
}