1use std::time::{Duration, SystemTime, UNIX_EPOCH};
46
47use axum::{
48 extract::{FromRequestParts, OriginalUri, Request},
49 http::{request::Parts, StatusCode},
50 middleware::Next,
51 response::{IntoResponse, Response},
52 Json,
53};
54use hmac::{Hmac, Mac};
55use serde::Deserialize;
56use sha2::Sha256;
57use subtle::ConstantTimeEq;
58use tracing::{debug, warn};
59use url::form_urlencoded;
60
61use super::handlers::ErrorResponse;
62
/// HMAC-SHA256, the MAC used for every path signature and viewer token in this module.
type HmacSha256 = Hmac<Sha256>;
69
/// Reasons a request can fail signed-URL or viewer-token authentication.
///
/// Each variant is mapped to an HTTP status and a machine-readable error code by the
/// `IntoResponse` implementation for this type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthError {
    /// The `sig` query parameter was absent.
    MissingSignature,

    /// The `exp` query parameter was absent.
    MissingExpiry,

    /// The presented expiry timestamp is already in the past.
    Expired {
        /// Expiry that was presented, in Unix seconds.
        expired_at: u64,
        /// Server time at which the check ran, in Unix seconds.
        current_time: u64,
    },

    /// The signature or viewer token decoded fine but did not match.
    InvalidSignature,

    /// The signature or viewer token was not valid hex, or was supplied more than once.
    InvalidSignatureFormat,

    /// `exp` was not an unsigned integer, or was supplied more than once.
    InvalidExpiryFormat,
}

impl std::fmt::Display for AuthError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            // The only variant whose message carries data.
            AuthError::Expired {
                expired_at,
                current_time,
            } => {
                return write!(
                    f,
                    "Signature expired at {} (current time: {})",
                    expired_at, current_time
                );
            }
            AuthError::MissingSignature => "Missing signature parameter",
            AuthError::MissingExpiry => "Missing expiry parameter",
            AuthError::InvalidSignature => "Invalid signature",
            AuthError::InvalidSignatureFormat => "Invalid signature format",
            AuthError::InvalidExpiryFormat => "Invalid expiry format",
        };
        f.write_str(text)
    }
}

/// Lets `AuthError` be boxed as `dyn Error` and propagated with `?` into generic error types.
impl std::error::Error for AuthError {}
116
117impl IntoResponse for AuthError {
118 fn into_response(self) -> Response {
119 let (status, error_type, message) = match &self {
120 AuthError::MissingSignature => (
121 StatusCode::UNAUTHORIZED,
122 "missing_signature",
123 self.to_string(),
124 ),
125 AuthError::MissingExpiry => {
126 (StatusCode::UNAUTHORIZED, "missing_expiry", self.to_string())
127 }
128 AuthError::Expired { .. } => (
129 StatusCode::UNAUTHORIZED,
130 "signature_expired",
131 self.to_string(),
132 ),
133 AuthError::InvalidSignature => (
134 StatusCode::UNAUTHORIZED,
135 "invalid_signature",
136 self.to_string(),
137 ),
138 AuthError::InvalidSignatureFormat => (
139 StatusCode::BAD_REQUEST,
140 "invalid_signature_format",
141 self.to_string(),
142 ),
143 AuthError::InvalidExpiryFormat => (
144 StatusCode::BAD_REQUEST,
145 "invalid_expiry_format",
146 self.to_string(),
147 ),
148 };
149
150 match &self {
154 AuthError::InvalidSignature => {
155 warn!(
156 error_type = error_type,
157 status = status.as_u16(),
158 "Authentication failed: {}",
159 message
160 );
161 }
162 AuthError::Expired { .. } => {
163 debug!(
164 error_type = error_type,
165 status = status.as_u16(),
166 "Authentication failed: {}",
167 message
168 );
169 }
170 _ => {
171 debug!(
172 error_type = error_type,
173 status = status.as_u16(),
174 "Authentication failed: {}",
175 message
176 );
177 }
178 }
179
180 let error_response = ErrorResponse::with_status(error_type, message, status);
181 (status, Json(error_response)).into_response()
182 }
183}
184
/// Signs and verifies time-limited URLs and slide viewer tokens with HMAC-SHA256.
///
/// Cloning copies the secret key. The type does not derive `Debug`, so the key cannot
/// be printed through `{:?}`.
#[derive(Clone)]
pub struct SignedUrlAuth {
    /// Raw HMAC key bytes.
    secret_key: Vec<u8>,
}
198
199impl SignedUrlAuth {
200 pub fn new(secret_key: impl AsRef<[u8]>) -> Self {
207 Self {
208 secret_key: secret_key.as_ref().to_vec(),
209 }
210 }
211
212 pub fn sign(&self, path: &str, ttl: Duration) -> (String, u64) {
225 self.sign_with_params(path, ttl, &[])
226 }
227
228 pub fn sign_with_params(
232 &self,
233 path: &str,
234 ttl: Duration,
235 params: &[(&str, &str)],
236 ) -> (String, u64) {
237 let expiry = SystemTime::now()
238 .duration_since(UNIX_EPOCH)
239 .unwrap()
240 .as_secs()
241 + ttl.as_secs();
242
243 let signature = self.compute_signature(path, expiry, params);
244 (signature, expiry)
245 }
246
247 pub fn sign_with_expiry(&self, path: &str, expiry: u64) -> String {
260 self.sign_with_expiry_and_params(path, expiry, &[])
261 }
262
263 pub fn sign_with_expiry_and_params(
267 &self,
268 path: &str,
269 expiry: u64,
270 params: &[(&str, &str)],
271 ) -> String {
272 self.compute_signature(path, expiry, params)
273 }
274
275 pub fn verify(
287 &self,
288 path: &str,
289 signature: &str,
290 expiry: u64,
291 params: &[(&str, &str)],
292 ) -> Result<(), AuthError> {
293 let current_time = SystemTime::now()
295 .duration_since(UNIX_EPOCH)
296 .unwrap()
297 .as_secs();
298
299 if current_time > expiry {
300 return Err(AuthError::Expired {
301 expired_at: expiry,
302 current_time,
303 });
304 }
305
306 let provided_sig = hex::decode(signature).map_err(|_| AuthError::InvalidSignatureFormat)?;
308
309 let expected_sig_hex = self.compute_signature(path, expiry, params);
311 let expected_sig =
312 hex::decode(&expected_sig_hex).map_err(|_| AuthError::InvalidSignatureFormat)?;
313
314 if provided_sig.ct_eq(&expected_sig).into() {
316 Ok(())
317 } else {
318 Err(AuthError::InvalidSignature)
319 }
320 }
321
322 fn compute_signature(&self, path: &str, expiry: u64, params: &[(&str, &str)]) -> String {
324 let message = signature_base(path, expiry, params);
325
326 let mut mac =
328 HmacSha256::new_from_slice(&self.secret_key).expect("HMAC can take key of any size");
329 mac.update(message.as_bytes());
330 let result = mac.finalize();
331
332 hex::encode(result.into_bytes())
334 }
335
336 pub fn generate_signed_url(
349 &self,
350 base_url: &str,
351 path: &str,
352 ttl: Duration,
353 extra_params: &[(&str, &str)],
354 ) -> String {
355 let (signature, expiry) = self.sign_with_params(path, ttl, extra_params);
356
357 let mut url = format!("{}{}", base_url, path);
358
359 let mut serializer = form_urlencoded::Serializer::new(String::new());
360 for (key, value) in extra_params {
361 serializer.append_pair(key, value);
362 }
363 serializer.append_pair("exp", &expiry.to_string());
364 serializer.append_pair("sig", &signature);
365
366 url.push('?');
367 url.push_str(&serializer.finish());
368
369 url
370 }
371
372 pub fn generate_viewer_token(&self, slide_id: &str, ttl: Duration) -> (String, u64) {
387 let expiry = SystemTime::now()
388 .duration_since(UNIX_EPOCH)
389 .unwrap()
390 .as_secs()
391 + ttl.as_secs();
392
393 let message = format!("viewer:{}:{}", slide_id, expiry);
394
395 let mut mac =
396 HmacSha256::new_from_slice(&self.secret_key).expect("HMAC can take key of any size");
397 mac.update(message.as_bytes());
398 let result = mac.finalize();
399
400 (hex::encode(result.into_bytes()), expiry)
401 }
402
403 pub fn verify_viewer_token(
415 &self,
416 slide_id: &str,
417 token: &str,
418 expiry: u64,
419 ) -> Result<(), AuthError> {
420 let current_time = SystemTime::now()
422 .duration_since(UNIX_EPOCH)
423 .unwrap()
424 .as_secs();
425
426 if current_time > expiry {
427 return Err(AuthError::Expired {
428 expired_at: expiry,
429 current_time,
430 });
431 }
432
433 let provided_token = hex::decode(token).map_err(|_| AuthError::InvalidSignatureFormat)?;
435
436 let message = format!("viewer:{}:{}", slide_id, expiry);
438 let mut mac =
439 HmacSha256::new_from_slice(&self.secret_key).expect("HMAC can take key of any size");
440 mac.update(message.as_bytes());
441 let expected_token = mac.finalize().into_bytes();
442
443 if provided_token.ct_eq(&expected_token).into() {
445 Ok(())
446 } else {
447 Err(AuthError::InvalidSignature)
448 }
449 }
450}
451
452fn signature_base(path: &str, expiry: u64, params: &[(&str, &str)]) -> String {
453 let mut all_params: Vec<(String, String)> = Vec::with_capacity(params.len() + 1);
454 for (key, value) in params {
455 all_params.push(((*key).to_string(), (*value).to_string()));
456 }
457 all_params.push(("exp".to_string(), expiry.to_string()));
458
459 let canonical = canonical_query(&all_params);
460 if canonical.is_empty() {
461 path.to_string()
462 } else {
463 format!("{}?{}", path, canonical)
464 }
465}
466
/// Builds the canonical, order-independent query string that is fed into the HMAC.
///
/// Pairs are sorted by key, then by value, and emitted as `key=value` joined with `&`.
/// Inside each key and value, `%`, `&` and `=` are percent-escaped so a component can
/// never be mistaken for a separator. Without this, `[("a", "1&b=2")]` and
/// `[("a", "1"), ("b", "2")]` would canonicalize identically and share one valid signature.
/// Components that contain none of those three characters are emitted verbatim.
fn canonical_query(params: &[(String, String)]) -> String {
    /// Injective escaping of the canonical-form metacharacters.
    fn escape(component: &str) -> String {
        let mut out = String::with_capacity(component.len());
        for ch in component.chars() {
            match ch {
                // `%` must be escaped too, otherwise a literal "%26" would collide with an escaped `&`.
                '%' => out.push_str("%25"),
                '&' => out.push_str("%26"),
                '=' => out.push_str("%3D"),
                other => out.push(other),
            }
        }
        out
    }

    // Tuple ordering is key-then-value, matching the previous explicit comparator.
    let mut pairs: Vec<&(String, String)> = params.iter().collect();
    pairs.sort_unstable();

    pairs
        .into_iter()
        .map(|(key, value)| format!("{}={}", escape(key), escape(value)))
        .collect::<Vec<_>>()
        .join("&")
}
476
/// Authentication query parameters as a `serde`-deserializable struct.
///
/// NOTE(review): `auth_middleware` in this file parses the query string by hand and does
/// not use this type. Confirm whether anything else still depends on it.
#[derive(Debug, Deserialize)]
pub struct AuthQueryParams {
    /// Hex-encoded HMAC signature (the `sig` query parameter).
    pub sig: Option<String>,

    /// Expiry in Unix seconds (the `exp` query parameter).
    pub exp: Option<u64>,
}
490
491pub async fn auth_middleware(
519 axum::extract::State(auth): axum::extract::State<SignedUrlAuth>,
520 OriginalUri(original_uri): OriginalUri,
521 request: Request,
522 next: Next,
523) -> Result<Response, AuthError> {
524 let query = original_uri.query().unwrap_or("");
525 let mut signature: Option<String> = None;
526 let mut viewer_token: Option<String> = None;
527 let mut expiry: Option<u64> = None;
528 let mut extra_params: Vec<(String, String)> = Vec::new();
529
530 for (key, value) in form_urlencoded::parse(query.as_bytes()) {
531 if key == "sig" {
532 if signature.is_some() {
533 return Err(AuthError::InvalidSignatureFormat);
534 }
535 signature = Some(value.into_owned());
536 continue;
537 }
538 if key == "vt" {
539 if viewer_token.is_some() {
540 return Err(AuthError::InvalidSignatureFormat);
541 }
542 viewer_token = Some(value.into_owned());
543 continue;
544 }
545 if key == "exp" {
546 if expiry.is_some() {
547 return Err(AuthError::InvalidExpiryFormat);
548 }
549 let parsed = value
550 .parse::<u64>()
551 .map_err(|_| AuthError::InvalidExpiryFormat)?;
552 expiry = Some(parsed);
553 continue;
554 }
555
556 extra_params.push((key.into_owned(), value.into_owned()));
557 }
558
559 let expiry = expiry.ok_or(AuthError::MissingExpiry)?;
560 let path = original_uri.path();
561
562 if let Some(token) = viewer_token {
564 let slide_id = extract_slide_id_from_path(path);
567 if let Some(slide_id) = slide_id {
568 auth.verify_viewer_token(&slide_id, &token, expiry)?;
569 return Ok(next.run(request).await);
570 }
571 }
573
574 let signature = signature.ok_or(AuthError::MissingSignature)?;
576
577 let extra_params_ref: Vec<(&str, &str)> = extra_params
579 .iter()
580 .map(|(key, value)| (key.as_str(), value.as_str()))
581 .collect();
582 auth.verify(path, &signature, expiry, &extra_params_ref)?;
583
584 Ok(next.run(request).await)
586}
587
588fn extract_slide_id_from_path(path: &str) -> Option<String> {
596 let parts: Vec<&str> = path.split('/').collect();
597
598 if parts.len() < 3 {
600 return None;
601 }
602
603 match parts[1] {
604 "tiles" | "slides" => {
605 urlencoding::decode(parts[2]).ok().map(|s| s.into_owned())
607 }
608 _ => None,
609 }
610}
611
/// Extractor reporting whether a request carries signed-URL query parameters.
///
/// Extraction never fails. `authenticated` only reflects that `sig` and `exp` appear in
/// the query string; it does NOT mean the signature was verified.
#[derive(Debug, Clone)]
pub struct OptionalAuth {
    /// `true` when the query string carries both `sig` and `exp` (presence only, unverified).
    pub authenticated: bool,
}
622
623impl<S> FromRequestParts<S> for OptionalAuth
624where
625 S: Send + Sync,
626{
627 type Rejection = std::convert::Infallible;
628
629 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
630 let query = parts.uri.query().unwrap_or("");
632 let has_sig = query.contains("sig=");
633 let has_exp = query.contains("exp=");
634
635 Ok(OptionalAuth {
636 authenticated: has_sig && has_exp,
637 })
638 }
639}
640
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// A freshly signed path verifies with the same key, path and expiry.
    #[test]
    fn test_sign_and_verify() {
        let auth = SignedUrlAuth::new("test-secret-key");
        let path = "/tiles/slides/sample.svs/0/1/2.jpg";
        let ttl = Duration::from_secs(3600);

        let (signature, expiry) = auth.sign(path, ttl);

        assert!(auth.verify(path, &signature, expiry, &[]).is_ok());
    }

    /// A well-formed but wrong signature is rejected as `InvalidSignature`.
    #[test]
    fn test_verify_wrong_signature() {
        let auth = SignedUrlAuth::new("test-secret-key");
        let path = "/tiles/slides/sample.svs/0/1/2.jpg";
        let ttl = Duration::from_secs(3600);

        let (_, expiry) = auth.sign(path, ttl);

        // 64 hex chars = 32 bytes, the length of an HMAC-SHA256 output.
        let wrong_sig = "0".repeat(64);
        let result = auth.verify(path, &wrong_sig, expiry, &[]);
        assert!(matches!(result, Err(AuthError::InvalidSignature)));
    }

    /// A signature is bound to the exact path it was issued for.
    #[test]
    fn test_verify_wrong_path() {
        let auth = SignedUrlAuth::new("test-secret-key");
        let path = "/tiles/slides/sample.svs/0/1/2.jpg";
        let ttl = Duration::from_secs(3600);

        let (signature, expiry) = auth.sign(path, ttl);

        let wrong_path = "/tiles/slides/other.svs/0/1/2.jpg";
        let result = auth.verify(wrong_path, &signature, expiry, &[]);
        assert!(matches!(result, Err(AuthError::InvalidSignature)));
    }

    /// A correctly signed but past expiry is rejected as `Expired`.
    #[test]
    fn test_verify_expired() {
        let auth = SignedUrlAuth::new("test-secret-key");
        let path = "/tiles/slides/sample.svs/0/1/2.jpg";

        let expired_time = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
            - 100;
        let signature = auth.sign_with_expiry(path, expired_time);

        let result = auth.verify(path, &signature, expired_time, &[]);
        assert!(matches!(result, Err(AuthError::Expired { .. })));
    }

    /// Non-hex signatures are reported as a format error, not a mismatch.
    #[test]
    fn test_verify_invalid_hex() {
        let auth = SignedUrlAuth::new("test-secret-key");
        let path = "/tiles/slides/sample.svs/0/1/2.jpg";
        let expiry = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
            + 3600;

        let result = auth.verify(path, "not-valid-hex!", expiry, &[]);
        assert!(matches!(result, Err(AuthError::InvalidSignatureFormat)));
    }

    /// Signatures from one key never verify under another.
    #[test]
    fn test_different_keys_different_signatures() {
        let auth1 = SignedUrlAuth::new("key1");
        let auth2 = SignedUrlAuth::new("key2");
        let path = "/tiles/slides/sample.svs/0/1/2.jpg";
        let ttl = Duration::from_secs(3600);

        let (sig1, expiry) = auth1.sign(path, ttl);
        let sig2 = auth2.sign_with_expiry(path, expiry);

        assert_ne!(sig1, sig2);

        assert!(auth1.verify(path, &sig1, expiry, &[]).is_ok());
        assert!(auth1.verify(path, &sig2, expiry, &[]).is_err());
        assert!(auth2.verify(path, &sig2, expiry, &[]).is_ok());
        assert!(auth2.verify(path, &sig1, expiry, &[]).is_err());
    }

    /// Signing the same inputs twice yields the same signature.
    #[test]
    fn test_signature_is_deterministic() {
        let auth = SignedUrlAuth::new("test-secret-key");
        let path = "/tiles/slides/sample.svs/0/1/2.jpg";
        let expiry = 1735689600u64;

        let sig1 = auth.sign_with_expiry(path, expiry);
        let sig2 = auth.sign_with_expiry(path, expiry);

        assert_eq!(sig1, sig2);
    }

    /// Extra parameters are signed: order is irrelevant, but any change invalidates.
    #[test]
    fn test_sign_and_verify_with_params() {
        let auth = SignedUrlAuth::new("test-secret-key");
        let path = "/tiles/slides/sample.svs/0/1/2.jpg";
        let ttl = Duration::from_secs(3600);

        let (signature, expiry) =
            auth.sign_with_params(path, ttl, &[("quality", "80"), ("format", "jpg")]);

        assert!(auth
            .verify(path, &signature, expiry, &[("format", "jpg"), ("quality", "80")])
            .is_ok());
        assert!(matches!(
            auth.verify(path, &signature, expiry, &[("quality", "80")]),
            Err(AuthError::InvalidSignature)
        ));
        assert!(matches!(
            auth.verify(path, &signature, expiry, &[("quality", "90"), ("format", "jpg")]),
            Err(AuthError::InvalidSignature)
        ));
        assert!(matches!(
            auth.verify(
                path,
                &signature,
                expiry,
                &[("quality", "80"), ("format", "jpg"), ("extra", "1")]
            ),
            Err(AuthError::InvalidSignature)
        ));
    }

    /// The generated URL has the expected shape.
    #[test]
    fn test_generate_signed_url() {
        let auth = SignedUrlAuth::new("test-secret-key");
        let base_url = "https://example.com";
        let path = "/tiles/slides/sample.svs/0/1/2.jpg";
        let ttl = Duration::from_secs(3600);

        let url = auth.generate_signed_url(base_url, path, ttl, &[("quality", "80")]);

        assert!(url.starts_with("https://example.com/tiles/slides/sample.svs/0/1/2.jpg?"));
        assert!(url.contains("quality=80"));
        assert!(url.contains("exp="));
        assert!(url.contains("sig="));
    }

    /// A generated URL's query, parsed the way the middleware parses it, verifies.
    #[test]
    fn test_generate_signed_url_round_trip() {
        let auth = SignedUrlAuth::new("test-secret-key");
        let path = "/tiles/slides/sample.svs/0/1/2.jpg";
        let ttl = Duration::from_secs(3600);

        let url = auth.generate_signed_url(
            "https://example.com",
            path,
            ttl,
            &[("quality", "80"), ("label", "a b&c")],
        );
        let query = url.split_once('?').expect("URL has a query").1;

        let mut sig = None;
        let mut exp = None;
        let mut extra: Vec<(String, String)> = Vec::new();
        for (key, value) in form_urlencoded::parse(query.as_bytes()) {
            match &*key {
                "sig" => sig = Some(value.into_owned()),
                "exp" => exp = Some(value.parse::<u64>().expect("numeric exp")),
                _ => extra.push((key.into_owned(), value.into_owned())),
            }
        }
        let extra_ref: Vec<(&str, &str)> = extra
            .iter()
            .map(|(key, value)| (key.as_str(), value.as_str()))
            .collect();

        assert!(auth
            .verify(path, &sig.expect("sig present"), exp.expect("exp present"), &extra_ref)
            .is_ok());
    }

    /// Every `AuthError` variant renders its documented message.
    #[test]
    fn test_auth_error_display() {
        let err = AuthError::MissingSignature;
        assert_eq!(err.to_string(), "Missing signature parameter");

        let err = AuthError::MissingExpiry;
        assert_eq!(err.to_string(), "Missing expiry parameter");

        let err = AuthError::Expired {
            expired_at: 1000,
            current_time: 2000,
        };
        assert!(err.to_string().contains("1000"));
        assert!(err.to_string().contains("2000"));

        let err = AuthError::InvalidSignature;
        assert_eq!(err.to_string(), "Invalid signature");

        let err = AuthError::InvalidSignatureFormat;
        assert_eq!(err.to_string(), "Invalid signature format");

        let err = AuthError::InvalidExpiryFormat;
        assert_eq!(err.to_string(), "Invalid expiry format");
    }

    /// A single-character change at the start, middle, or end is rejected.
    ///
    /// This checks correctness only; it cannot observe timing.
    #[test]
    fn test_constant_time_comparison() {
        let auth = SignedUrlAuth::new("test-secret-key");
        let path = "/tiles/slides/sample.svs/0/1/2.jpg";
        let expiry = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
            + 3600;

        let correct_sig = auth.sign_with_expiry(path, expiry);

        // Maps each lowercase hex digit to a different one, keeping the string valid hex.
        fn flip_hex_char(c: char) -> char {
            match c {
                '0'..='8' => ((c as u8) + 1) as char,
                '9' => '0',
                'a'..='e' => ((c as u8) + 1) as char,
                'f' => 'a',
                _ => '0',
            }
        }

        let mut wrong_first = correct_sig.clone();
        let first_char = correct_sig.chars().next().unwrap();
        wrong_first.replace_range(0..1, &flip_hex_char(first_char).to_string());

        let mut wrong_middle = correct_sig.clone();
        let mid = correct_sig.len() / 2;
        let mid_char = correct_sig.chars().nth(mid).unwrap();
        wrong_middle.replace_range(mid..mid + 1, &flip_hex_char(mid_char).to_string());

        let mut wrong_last = correct_sig.clone();
        let last = correct_sig.len() - 1;
        let last_char = correct_sig.chars().nth(last).unwrap();
        wrong_last.replace_range(last..last + 1, &flip_hex_char(last_char).to_string());

        assert!(auth.verify(path, &wrong_first, expiry, &[]).is_err());
        assert!(auth.verify(path, &wrong_middle, expiry, &[]).is_err());
        assert!(auth.verify(path, &wrong_last, expiry, &[]).is_err());
    }

    /// A viewer token verifies for the slide it was issued for.
    #[test]
    fn test_viewer_token_generate_and_verify() {
        let auth = SignedUrlAuth::new("test-secret-key");
        let slide_id = "slides/sample.svs";
        let ttl = Duration::from_secs(3600);

        let (token, expiry) = auth.generate_viewer_token(slide_id, ttl);

        assert!(auth.verify_viewer_token(slide_id, &token, expiry).is_ok());
    }

    /// A viewer token does not verify for a different slide.
    #[test]
    fn test_viewer_token_wrong_slide() {
        let auth = SignedUrlAuth::new("test-secret-key");
        let slide_id = "slides/sample.svs";
        let wrong_slide = "slides/other.svs";
        let ttl = Duration::from_secs(3600);

        let (token, expiry) = auth.generate_viewer_token(slide_id, ttl);

        assert!(auth
            .verify_viewer_token(wrong_slide, &token, expiry)
            .is_err());
    }

    /// A correctly MAC'd viewer token with a past expiry is rejected as `Expired`.
    #[test]
    fn test_viewer_token_expired() {
        let auth = SignedUrlAuth::new("test-secret-key");
        let slide_id = "slides/sample.svs";

        let expired_time = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
            - 100;
        // Build the token by hand because generate_viewer_token only accepts a TTL.
        let message = format!("viewer:{}:{}", slide_id, expired_time);
        let mut mac = HmacSha256::new_from_slice(b"test-secret-key").unwrap();
        mac.update(message.as_bytes());
        let token = hex::encode(mac.finalize().into_bytes());

        let result = auth.verify_viewer_token(slide_id, &token, expired_time);
        assert!(matches!(result, Err(AuthError::Expired { .. })));
    }

    /// Viewer tokens from one key never verify under another.
    #[test]
    fn test_viewer_token_different_keys() {
        let auth1 = SignedUrlAuth::new("key1");
        let auth2 = SignedUrlAuth::new("key2");
        let slide_id = "slides/sample.svs";
        let ttl = Duration::from_secs(3600);

        let (token, expiry) = auth1.generate_viewer_token(slide_id, ttl);

        assert!(auth2.verify_viewer_token(slide_id, &token, expiry).is_err());
    }

    /// Slide ids are taken from `/tiles/{id}/…` and percent-decoded.
    #[test]
    fn test_extract_slide_id_from_path_tiles() {
        assert_eq!(
            extract_slide_id_from_path("/tiles/sample.svs/0/1/2.jpg"),
            Some("sample.svs".to_string())
        );
        assert_eq!(
            extract_slide_id_from_path("/tiles/folder%2Fsample.svs/0/1/2.jpg"),
            Some("folder/sample.svs".to_string())
        );
    }

    /// Slide ids are taken from `/slides/{id}` with or without trailing segments.
    #[test]
    fn test_extract_slide_id_from_path_slides() {
        assert_eq!(
            extract_slide_id_from_path("/slides/sample.svs"),
            Some("sample.svs".to_string())
        );
        assert_eq!(
            extract_slide_id_from_path("/slides/sample.svs/dzi"),
            Some("sample.svs".to_string())
        );
        assert_eq!(
            extract_slide_id_from_path("/slides/sample.svs/thumbnail"),
            Some("sample.svs".to_string())
        );
    }

    /// Non-slide routes and too-short paths yield no slide id.
    #[test]
    fn test_extract_slide_id_from_path_invalid() {
        assert_eq!(extract_slide_id_from_path("/health"), None);
        assert_eq!(extract_slide_id_from_path("/view/sample.svs"), None);
        assert_eq!(extract_slide_id_from_path("/"), None);
        assert_eq!(extract_slide_id_from_path(""), None);
    }
}