1use std::time::Duration;
4
5use chrono::{DateTime, Utc};
6use hmac::{Hmac, Mac};
7use serde::{Deserialize, Serialize};
8use sha2::Sha256;
9use thiserror::Error;
10use url::Url;
11
12use crate::{Blob, Variant, urlsafe_decode, urlsafe_encode};
13
14type HmacSha256 = Hmac<Sha256>;
15
16#[derive(Debug, Error)]
18pub enum SignedUrlError {
19 #[error("invalid url: {0}")]
21 InvalidUrl(String),
22 #[error("signature verification failed")]
24 InvalidSignature,
25 #[error("invalid token payload")]
27 InvalidPayload,
28 #[error("signed url has expired")]
30 Expired,
31}
32
33#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
35pub enum SignedResource {
36 Blob { key: String },
38 Variant { key: String },
40 Redirect { location: String },
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
45struct SignedClaims {
46 resource: SignedResource,
47 expires_at: i64,
48}
49
50#[derive(Debug, Clone)]
52pub struct SignedUrlGenerator {
53 base_url: Url,
54 secret: Vec<u8>,
55}
56
57impl SignedUrlGenerator {
58 pub fn new(
64 base_url: impl AsRef<str>,
65 secret: impl Into<Vec<u8>>,
66 ) -> Result<Self, SignedUrlError> {
67 Ok(Self {
68 base_url: Url::parse(base_url.as_ref())
69 .map_err(|error| SignedUrlError::InvalidUrl(error.to_string()))?,
70 secret: secret.into(),
71 })
72 }
73
74 pub fn blob_url(&self, blob: &Blob, expires_in: Duration) -> Result<Url, SignedUrlError> {
80 self.signed_url(
81 SignedResource::Blob {
82 key: blob.key().to_owned(),
83 },
84 expires_in,
85 )
86 }
87
88 pub fn variant_url(
94 &self,
95 variant: &Variant,
96 expires_in: Duration,
97 ) -> Result<Url, SignedUrlError> {
98 self.signed_url(
99 SignedResource::Variant {
100 key: variant.key().to_owned(),
101 },
102 expires_in,
103 )
104 }
105
106 pub fn redirect_url(
112 &self,
113 location: &Url,
114 expires_in: Duration,
115 ) -> Result<Url, SignedUrlError> {
116 self.signed_url(
117 SignedResource::Redirect {
118 location: location.to_string(),
119 },
120 expires_in,
121 )
122 }
123
124 pub fn verify(&self, url: &Url) -> Result<SignedResource, SignedUrlError> {
130 self.verify_at(url, Utc::now())
131 }
132
133 pub fn verify_at(
139 &self,
140 url: &Url,
141 now: DateTime<Utc>,
142 ) -> Result<SignedResource, SignedUrlError> {
143 let token = url
144 .query_pairs()
145 .find(|(key, _)| key == "token")
146 .map(|(_, value)| value.into_owned())
147 .ok_or(SignedUrlError::InvalidPayload)?;
148 let (payload, signature) = token
149 .split_once('.')
150 .ok_or(SignedUrlError::InvalidPayload)?;
151 let payload_bytes = urlsafe_decode(payload).map_err(|_| SignedUrlError::InvalidPayload)?;
152 let signature_bytes =
153 urlsafe_decode(signature).map_err(|_| SignedUrlError::InvalidPayload)?;
154 let expected = sign_bytes(&self.secret, &payload_bytes)?;
155 if expected != signature_bytes {
156 return Err(SignedUrlError::InvalidSignature);
157 }
158 let claims: SignedClaims =
159 serde_json::from_slice(&payload_bytes).map_err(|_| SignedUrlError::InvalidPayload)?;
160 if now.timestamp() > claims.expires_at {
161 return Err(SignedUrlError::Expired);
162 }
163 Ok(claims.resource)
164 }
165
166 fn signed_url(
167 &self,
168 resource: SignedResource,
169 expires_in: Duration,
170 ) -> Result<Url, SignedUrlError> {
171 let expires_at = Utc::now()
172 + chrono::Duration::from_std(expires_in).map_err(|_| SignedUrlError::InvalidPayload)?;
173 let claims = SignedClaims {
174 resource,
175 expires_at: expires_at.timestamp(),
176 };
177 let payload = serde_json::to_vec(&claims).map_err(|_| SignedUrlError::InvalidPayload)?;
178 let signature = sign_bytes(&self.secret, &payload)?;
179 let token = format!("{}.{}", urlsafe_encode(&payload), urlsafe_encode(signature));
180 let mut url = self.base_url.clone();
181 url.query_pairs_mut().append_pair("token", &token);
182 Ok(url)
183 }
184}
185
186pub(crate) fn sign_payload(secret: &[u8], payload: &[u8]) -> Result<String, SignedUrlError> {
187 let signature = sign_bytes(secret, payload)?;
188 Ok(format!(
189 "{}.{}",
190 urlsafe_encode(payload),
191 urlsafe_encode(signature)
192 ))
193}
194
195pub(crate) fn verify_payload(token: &str, secret: &[u8]) -> Result<Vec<u8>, SignedUrlError> {
196 let (payload, signature) = token
197 .split_once('.')
198 .ok_or(SignedUrlError::InvalidPayload)?;
199 let payload_bytes = urlsafe_decode(payload).map_err(|_| SignedUrlError::InvalidPayload)?;
200 let signature_bytes = urlsafe_decode(signature).map_err(|_| SignedUrlError::InvalidPayload)?;
201 let expected = sign_bytes(secret, &payload_bytes)?;
202 if expected != signature_bytes {
203 return Err(SignedUrlError::InvalidSignature);
204 }
205 Ok(payload_bytes)
206}
207
208fn sign_bytes(secret: &[u8], payload: &[u8]) -> Result<Vec<u8>, SignedUrlError> {
209 let mut mac = HmacSha256::new_from_slice(secret).map_err(|_| SignedUrlError::InvalidPayload)?;
210 mac.update(payload);
211 Ok(mac.finalize().into_bytes().to_vec())
212}
213
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use super::*;

    fn generator() -> SignedUrlGenerator {
        SignedUrlGenerator::new("https://example.test/storage", b"secret".to_vec())
            .expect("generator should build")
    }

    fn blob() -> Blob {
        Blob::create(
            Bytes::from_static(b"hello"),
            "hello.txt",
            None,
            Default::default(),
            "memory",
        )
        .expect("blob should build")
    }

    #[test]
    fn test_blob_url_round_trip_verification() {
        let generator = generator();
        let url = generator
            .blob_url(&blob(), Duration::from_secs(60))
            .expect("url should build");
        let resource = generator.verify(&url).expect("url should verify");
        assert!(matches!(resource, SignedResource::Blob { .. }));
    }

    #[test]
    fn test_variant_url_round_trip_verification() {
        let generator = generator();
        let variant = Variant::new(blob(), Default::default());
        let url = generator
            .variant_url(&variant, Duration::from_secs(60))
            .expect("url should build");
        let resource = generator.verify(&url).expect("url should verify");
        assert!(matches!(resource, SignedResource::Variant { .. }));
    }

    #[test]
    fn test_redirect_url_round_trip_verification() {
        let generator = generator();
        let location = Url::parse("https://cdn.example/files/1").expect("url should parse");
        let url = generator
            .redirect_url(&location, Duration::from_secs(60))
            .expect("url should build");
        let resource = generator.verify(&url).expect("url should verify");
        assert_eq!(
            resource,
            SignedResource::Redirect {
                location: location.to_string()
            }
        );
    }

    #[test]
    fn test_verify_rejects_expired_url() {
        let generator = generator();
        let url = generator
            .blob_url(&blob(), Duration::from_secs(1))
            .expect("url should build");
        let future = Utc::now() + chrono::Duration::seconds(2);
        let error = generator
            .verify_at(&url, future)
            .expect_err("url should be expired");
        assert!(matches!(error, SignedUrlError::Expired));
    }

    #[test]
    fn test_verify_rejects_malformed_token() {
        let generator = generator();
        let mut url = generator
            .blob_url(&blob(), Duration::from_secs(60))
            .expect("url should build");
        url.query_pairs_mut()
            .clear()
            .append_pair("token", "tampered");
        let error = generator.verify(&url).expect_err("url should fail");
        assert!(matches!(error, SignedUrlError::InvalidPayload));
    }

    // A well-formed token signed with a different key must fail on the signature check,
    // not on parsing.
    #[test]
    fn test_verify_rejects_url_signed_with_other_secret() {
        let url = generator()
            .blob_url(&blob(), Duration::from_secs(60))
            .expect("url should build");
        let other = SignedUrlGenerator::new("https://example.test/storage", b"other".to_vec())
            .expect("generator should build");
        let error = other
            .verify(&url)
            .expect_err("foreign signature should fail");
        assert!(matches!(error, SignedUrlError::InvalidSignature));
    }

    // Re-using a valid signature with a different payload must be rejected.
    #[test]
    fn test_verify_payload_rejects_swapped_payload() {
        let token = sign_payload(b"secret", b"hello").expect("token should build");
        let (_, signature) = token
            .split_once('.')
            .expect("token should contain a separator");
        let forged = format!("{}.{}", urlsafe_encode(&b"goodbye"[..]), signature);
        let error = verify_payload(&forged, b"secret").expect_err("forged token should fail");
        assert!(matches!(error, SignedUrlError::InvalidSignature));
    }

    #[test]
    fn test_sign_payload_and_verify_payload_round_trip() {
        let payload = b"hello";
        let token = sign_payload(b"secret", payload).expect("token should build");
        let decoded = verify_payload(&token, b"secret").expect("token should verify");
        assert_eq!(decoded, payload);
    }
}