1use axum::extract::Request;
6use axum::http::StatusCode;
7use axum::middleware::Next;
8use axum::response::{IntoResponse, Response};
9use std::sync::Arc;
10
/// Authentication settings for the HTTP API.
///
/// When `enabled` is `false` every request is accepted. The API key is kept
/// private and is redacted from `Debug` output so it cannot leak into logs
/// (this also covers any type that derives `Debug` while holding an `AuthConfig`).
#[derive(Clone)]
pub struct AuthConfig {
    /// Whether requests must carry credentials.
    pub enabled: bool,
    /// Expected API key; while `enabled` is `true`, `None` means no API key is accepted.
    api_key: Option<String>,
}

impl std::fmt::Debug for AuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthConfig")
            .field("enabled", &self.enabled)
            // Only reveal whether a key is configured, never its value.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
19
20impl AuthConfig {
21 pub const fn disabled() -> Self {
23 Self {
24 enabled: false,
25 api_key: None,
26 }
27 }
28
29 pub const fn with_api_key(api_key: String) -> Self {
31 Self {
32 enabled: true,
33 api_key: Some(api_key),
34 }
35 }
36
37 pub fn validate_key(&self, provided_key: &str) -> bool {
39 if !self.enabled {
40 return true;
41 }
42
43 match &self.api_key {
44 Some(key) => constant_time_compare(key, provided_key),
45 None => false,
46 }
47 }
48
49 pub const fn is_required(&self) -> bool {
51 self.enabled
52 }
53
54 pub fn api_key(&self) -> Option<&str> {
56 self.api_key.as_deref()
57 }
58}
59
60impl Default for AuthConfig {
61 fn default() -> Self {
62 Self::disabled()
63 }
64}
65
/// Error returned by the credential-extraction helpers in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthError {
    /// No credential was found (e.g. an empty header, or no `api_key`/`token`
    /// query parameter with a value).
    MissingCredentials,
    /// A supplied credential was rejected; displayed as "Invalid API key".
    /// NOTE(review): not produced by the extraction helpers in this file.
    InvalidCredentials,
    /// An `Authorization` header named a `Bearer`/`ApiKey` scheme with no
    /// credential after it.
    MalformedHeader,
}
76
77impl std::fmt::Display for AuthError {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 match self {
80 Self::MissingCredentials => write!(f, "Authentication required"),
81 Self::InvalidCredentials => write!(f, "Invalid API key"),
82 Self::MalformedHeader => write!(f, "Malformed authorization header"),
83 }
84 }
85}
86
// Lets `AuthError` be used wherever a `std::error::Error` is expected; the
// message comes from the `Display` impl and there is no underlying source.
impl std::error::Error for AuthError {}
88
/// Result alias used by the credential-extraction functions in this module.
pub type AuthResult<T> = Result<T, AuthError>;
91
92pub fn extract_from_header(header_value: &str) -> AuthResult<String> {
99 let header = header_value.trim();
100
101 if header.is_empty() {
102 return Err(AuthError::MissingCredentials);
103 }
104
105 if let Some(rest) = header.strip_prefix("Bearer ") {
107 let key = rest.trim();
108 if key.is_empty() {
109 return Err(AuthError::MalformedHeader);
110 }
111 return Ok(key.to_string());
112 }
113
114 if let Some(rest) = header.strip_prefix("Bearer\t") {
116 let key = rest.trim();
117 if key.is_empty() {
118 return Err(AuthError::MalformedHeader);
119 }
120 return Ok(key.to_string());
121 }
122
123 if header == "Bearer" {
125 return Err(AuthError::MalformedHeader);
126 }
127
128 if let Some(rest) = header.strip_prefix("ApiKey ") {
130 let key = rest.trim();
131 if key.is_empty() {
132 return Err(AuthError::MalformedHeader);
133 }
134 return Ok(key.to_string());
135 }
136
137 if let Some(rest) = header.strip_prefix("ApiKey\t") {
139 let key = rest.trim();
140 if key.is_empty() {
141 return Err(AuthError::MalformedHeader);
142 }
143 return Ok(key.to_string());
144 }
145
146 if header == "ApiKey" {
148 return Err(AuthError::MalformedHeader);
149 }
150
151 Ok(header.to_string())
153}
154
155pub fn extract_from_ws_protocol(header: &str) -> AuthResult<String> {
161 for protocol in header.split(',') {
162 let protocol = protocol.trim();
163 if let Some(key) = protocol.strip_prefix("varpulis-auth.") {
164 if !key.is_empty() {
165 return Ok(key.to_string());
166 }
167 }
168 }
169 Err(AuthError::MissingCredentials)
170}
171
172pub fn extract_from_query(query: &str) -> AuthResult<String> {
176 if query.is_empty() {
177 return Err(AuthError::MissingCredentials);
178 }
179
180 for pair in query.split('&') {
182 let mut parts = pair.splitn(2, '=');
183 let key = parts.next().unwrap_or("");
184 let value = parts.next().unwrap_or("");
185
186 if (key == "api_key" || key == "token") && !value.is_empty() {
187 let decoded = url_decode(value);
189 return Ok(decoded);
190 }
191 }
192
193 Err(AuthError::MissingCredentials)
194}
195
/// Decodes a `application/x-www-form-urlencoded` component.
///
/// - `+` becomes a space.
/// - `%XX` (two hex digits, either case) becomes the byte `0xXX`.
/// - A `%` not followed by two hex digits is kept literally. The characters
///   after it are processed normally, so `%%41` decodes to `%A`.
///
/// Decoding works on bytes, so multi-byte UTF-8 escapes such as `%C3%A9`
/// decode to the intended character (`é`). Byte sequences that are not valid
/// UTF-8 are replaced with U+FFFD.
fn url_decode(s: &str) -> String {
    let hex_digit = |b: u8| (b as char).to_digit(16);
    let bytes = s.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;

    while i < bytes.len() {
        match bytes[i] {
            b'%' => {
                let hi = bytes.get(i + 1).copied().and_then(hex_digit);
                let lo = bytes.get(i + 2).copied().and_then(hex_digit);
                if let (Some(hi), Some(lo)) = (hi, lo) {
                    // Both digits are < 16, so the combined value always fits in a u8.
                    decoded.push((hi * 16 + lo) as u8);
                    i += 3;
                    continue;
                }
                // Not a valid escape: keep the '%' and let the next bytes be
                // decoded on their own (they may start a valid escape).
                decoded.push(b'%');
            }
            b'+' => decoded.push(b' '),
            other => decoded.push(other),
        }
        i += 1;
    }

    String::from_utf8_lossy(&decoded).into_owned()
}
223
224pub fn constant_time_compare(a: &str, b: &str) -> bool {
229 varpulis_core::security::constant_time_compare(a, b)
230}
231
232pub fn generate_api_key() -> String {
237 use rand::Rng;
238
239 let mut rng = rand::thread_rng();
240 let mut key = String::with_capacity(32);
241 const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
242
243 for _ in 0..32 {
244 let idx = rng.gen_range(0..CHARSET.len());
245 key.push(CHARSET[idx] as char);
246 }
247
248 key
249}
250
251pub fn auth_middleware(config: Arc<AuthConfig>) -> impl tower::Layer<axum::routing::Route> + Clone {
262 axum::middleware::from_fn_with_state::<_, _, ()>(config, auth_middleware_fn)
263}
264
/// State handed to the authentication middleware.
#[derive(Debug, Clone)]
pub struct AuthState {
    /// API-key configuration; also decides whether authentication is enforced at all.
    pub config: Arc<AuthConfig>,
    /// OAuth state used to verify JWTs (cookie or Bearer token); `None`
    /// disables JWT authentication.
    pub oauth_state: Option<crate::oauth::SharedOAuthState>,
}
271
272pub fn auth_middleware_with_jwt(
274 config: Arc<AuthConfig>,
275 oauth_state: Option<crate::oauth::SharedOAuthState>,
276) -> impl tower::Layer<axum::routing::Route> + Clone {
277 let state = AuthState {
278 config,
279 oauth_state,
280 };
281 axum::middleware::from_fn_with_state::<_, _, ()>(state, auth_middleware_jwt_fn)
282}
283
284pub async fn auth_middleware_fn(
286 axum::extract::State(config): axum::extract::State<Arc<AuthConfig>>,
287 req: Request,
288 next: Next,
289) -> Result<Response, AuthRejection> {
290 let state = AuthState {
291 config,
292 oauth_state: None,
293 };
294 check_auth(&state, &req).await?;
295 Ok(next.run(req).await)
296}
297
298async fn auth_middleware_jwt_fn(
300 axum::extract::State(state): axum::extract::State<AuthState>,
301 req: Request,
302 next: Next,
303) -> Result<Response, AuthRejection> {
304 check_auth(&state, &req).await?;
305 Ok(next.run(req).await)
306}
307
308pub async fn check_auth(state: &AuthState, req: &Request) -> Result<(), AuthRejection> {
310 check_auth_from_parts(state, req.headers(), req.uri()).await
311}
312
313pub async fn check_auth_from_parts(
318 state: &AuthState,
319 headers: &axum::http::HeaderMap,
320 uri: &axum::http::Uri,
321) -> Result<(), AuthRejection> {
322 let config = &state.config;
323 let oauth = &state.oauth_state;
324
325 if !config.is_required() {
327 return Ok(());
328 }
329
330 let auth_header = headers
331 .get("authorization")
332 .and_then(|v| v.to_str().ok())
333 .map(|s| s.to_string());
334 let cookie_header = headers
335 .get("cookie")
336 .and_then(|v| v.to_str().ok())
337 .map(|s| s.to_string());
338 let ws_protocol = headers
339 .get("sec-websocket-protocol")
340 .and_then(|v| v.to_str().ok())
341 .map(|s| s.to_string());
342 let query = uri.query().unwrap_or("").to_string();
343
344 if let Some(header) = &auth_header {
346 match extract_from_header(header) {
347 Ok(key) if config.validate_key(&key) => return Ok(()),
348 Ok(_) => return Err(AuthRejection::InvalidCredentials),
349 Err(AuthError::MalformedHeader) => return Err(AuthRejection::MalformedHeader),
350 Err(_) => {} }
352 }
353
354 if let Some(ref cookie) = cookie_header {
356 if let Some(jwt) = crate::oauth::extract_jwt_from_cookie(cookie) {
357 if let Some(ref state) = oauth {
358 let hash = crate::oauth::token_hash(&jwt);
360 if !state.sessions.read().await.is_revoked(&hash)
361 && crate::oauth::verify_jwt(&state.config, &jwt).is_ok()
362 {
363 return Ok(());
364 }
365 }
366 }
367 }
368
369 if let Some(ref header) = auth_header {
371 if let Some(token) = header.strip_prefix("Bearer ") {
372 let token = token.trim();
373 if !token.is_empty() {
374 if let Some(ref state) = oauth {
375 let hash = crate::oauth::token_hash(token);
376 if !state.sessions.read().await.is_revoked(&hash)
377 && crate::oauth::verify_jwt(&state.config, token).is_ok()
378 {
379 return Ok(());
380 }
381 }
382 }
383 }
384 }
385
386 if let Some(ref protocol) = ws_protocol {
388 match extract_from_ws_protocol(protocol) {
389 Ok(key) if config.validate_key(&key) => return Ok(()),
390 Ok(_) => return Err(AuthRejection::InvalidCredentials),
391 Err(_) => {} }
393 }
394
395 match extract_from_query(&query) {
397 Ok(key) if config.validate_key(&key) => Ok(()),
398 Ok(_) => Err(AuthRejection::InvalidCredentials),
399 Err(_) => Err(AuthRejection::MissingCredentials),
400 }
401}
402
/// Rejection produced by the authentication middleware; turned into an HTTP
/// response by its `IntoResponse` impl.
#[derive(Debug)]
pub enum AuthRejection {
    /// No credential was supplied (HTTP 401).
    MissingCredentials,
    /// A supplied credential was not accepted (HTTP 401).
    InvalidCredentials,
    /// The `Authorization` header named a scheme without a credential (HTTP 400).
    MalformedHeader,
}
410
411impl IntoResponse for AuthRejection {
412 fn into_response(self) -> Response {
413 let (code, message) = match self {
414 Self::MissingCredentials => (StatusCode::UNAUTHORIZED, "Authentication required"),
415 Self::InvalidCredentials => (StatusCode::UNAUTHORIZED, "Invalid API key"),
416 Self::MalformedHeader => (StatusCode::BAD_REQUEST, "Malformed authorization header"),
417 };
418 (code, axum::Json(serde_json::json!({ "error": message }))).into_response()
419 }
420}
421
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `AuthState` that requires `key` and has no OAuth support.
    fn api_key_state(key: &str) -> AuthState {
        AuthState {
            config: Arc::new(AuthConfig::with_api_key(key.to_string())),
            oauth_state: None,
        }
    }

    /// Builds a request for `uri` carrying the given `(name, value)` headers.
    fn request(uri: &str, headers: &[(&str, &str)]) -> Request {
        let mut builder = Request::builder().uri(uri);
        for (name, value) in headers {
            builder = builder.header(*name, *value);
        }
        builder.body(axum::body::Body::empty()).unwrap()
    }

    // ---- AuthConfig ----

    #[test]
    fn test_auth_config_disabled() {
        let config = AuthConfig::disabled();
        assert!(!config.enabled);
        assert!(!config.is_required());
    }

    #[test]
    fn test_auth_config_with_api_key() {
        let config = AuthConfig::with_api_key("secret123".to_string());
        assert!(config.enabled);
        assert!(config.is_required());
    }

    #[test]
    fn test_auth_config_validate_key_disabled() {
        let config = AuthConfig::disabled();
        assert!(config.validate_key("anything"));
        assert!(config.validate_key(""));
    }

    #[test]
    fn test_auth_config_validate_key_correct() {
        let config = AuthConfig::with_api_key("secret123".to_string());
        assert!(config.validate_key("secret123"));
    }

    #[test]
    fn test_auth_config_validate_key_incorrect() {
        let config = AuthConfig::with_api_key("secret123".to_string());
        assert!(!config.validate_key("wrong"));
        assert!(!config.validate_key(""));
        assert!(!config.validate_key("secret1234"));
        assert!(!config.validate_key("secret12"));
    }

    #[test]
    fn test_auth_config_enabled_without_key_rejects_everything() {
        let config = AuthConfig {
            enabled: true,
            api_key: None,
        };
        assert!(config.is_required());
        assert!(!config.validate_key("anything"));
        assert!(!config.validate_key(""));
    }

    #[test]
    fn test_auth_config_default() {
        let config = AuthConfig::default();
        assert!(!config.enabled);
    }

    // ---- Authorization header extraction ----

    #[test]
    fn test_extract_from_header_bearer() {
        let result = extract_from_header("Bearer my-api-key");
        assert_eq!(result, Ok("my-api-key".to_string()));
    }

    #[test]
    fn test_extract_from_header_bearer_with_spaces() {
        let result = extract_from_header(" Bearer my-api-key ");
        assert_eq!(result, Ok("my-api-key".to_string()));
    }

    #[test]
    fn test_extract_from_header_bearer_tab() {
        let result = extract_from_header("Bearer\tmy-api-key");
        assert_eq!(result, Ok("my-api-key".to_string()));
    }

    #[test]
    fn test_extract_from_header_apikey() {
        let result = extract_from_header("ApiKey secret-key");
        assert_eq!(result, Ok("secret-key".to_string()));
    }

    #[test]
    fn test_extract_from_header_raw() {
        let result = extract_from_header("raw-key-without-prefix");
        assert_eq!(result, Ok("raw-key-without-prefix".to_string()));
    }

    #[test]
    fn test_extract_from_header_empty() {
        let result = extract_from_header("");
        assert_eq!(result, Err(AuthError::MissingCredentials));
    }

    #[test]
    fn test_extract_from_header_bearer_empty_key() {
        let result = extract_from_header("Bearer ");
        assert_eq!(result, Err(AuthError::MalformedHeader));
    }

    #[test]
    fn test_extract_from_header_apikey_empty_key() {
        let result = extract_from_header("ApiKey ");
        assert_eq!(result, Err(AuthError::MalformedHeader));
    }

    // ---- Query string extraction ----

    #[test]
    fn test_extract_from_query_api_key() {
        let result = extract_from_query("api_key=my-secret");
        assert_eq!(result, Ok("my-secret".to_string()));
    }

    #[test]
    fn test_extract_from_query_token() {
        let result = extract_from_query("token=my-token");
        assert_eq!(result, Ok("my-token".to_string()));
    }

    #[test]
    fn test_extract_from_query_with_other_params() {
        let result = extract_from_query("foo=bar&api_key=secret&baz=qux");
        assert_eq!(result, Ok("secret".to_string()));
    }

    #[test]
    fn test_extract_from_query_empty() {
        let result = extract_from_query("");
        assert_eq!(result, Err(AuthError::MissingCredentials));
    }

    #[test]
    fn test_extract_from_query_no_key() {
        let result = extract_from_query("foo=bar&baz=qux");
        assert_eq!(result, Err(AuthError::MissingCredentials));
    }

    #[test]
    fn test_extract_from_query_empty_value() {
        let result = extract_from_query("api_key=");
        assert_eq!(result, Err(AuthError::MissingCredentials));
    }

    #[test]
    fn test_extract_from_query_url_encoded() {
        let result = extract_from_query("api_key=key%20with%20spaces");
        assert_eq!(result, Ok("key with spaces".to_string()));
    }

    #[test]
    fn test_extract_from_query_plus_sign() {
        let result = extract_from_query("api_key=key+with+plus");
        assert_eq!(result, Ok("key with plus".to_string()));
    }

    // ---- WebSocket subprotocol extraction ----

    #[test]
    fn test_extract_from_ws_protocol_valid() {
        let result = extract_from_ws_protocol("varpulis-v1, varpulis-auth.my-secret-key");
        assert_eq!(result, Ok("my-secret-key".to_string()));
    }

    #[test]
    fn test_extract_from_ws_protocol_only_auth() {
        let result = extract_from_ws_protocol("varpulis-auth.abc123");
        assert_eq!(result, Ok("abc123".to_string()));
    }

    #[test]
    fn test_extract_from_ws_protocol_no_auth() {
        let result = extract_from_ws_protocol("varpulis-v1");
        assert!(result.is_err());
    }

    #[test]
    fn test_extract_from_ws_protocol_empty() {
        let result = extract_from_ws_protocol("");
        assert!(result.is_err());
    }

    #[test]
    fn test_extract_from_ws_protocol_empty_key() {
        let result = extract_from_ws_protocol("varpulis-auth.");
        assert!(result.is_err());
    }

    // ---- URL decoding ----

    #[test]
    fn test_url_decode_plain() {
        assert_eq!(url_decode("hello"), "hello");
    }

    #[test]
    fn test_url_decode_spaces() {
        assert_eq!(url_decode("hello%20world"), "hello world");
    }

    #[test]
    fn test_url_decode_plus() {
        assert_eq!(url_decode("hello+world"), "hello world");
    }

    #[test]
    fn test_url_decode_special_chars() {
        assert_eq!(url_decode("%21%40%23"), "!@#");
    }

    #[test]
    fn test_url_decode_trailing_percent() {
        assert_eq!(url_decode("100%"), "100%");
    }

    #[test]
    fn test_url_decode_invalid_escape() {
        assert_eq!(url_decode("%zz"), "%zz");
    }

    // ---- Constant-time comparison ----

    #[test]
    fn test_constant_time_compare_equal() {
        assert!(constant_time_compare("abc", "abc"));
        assert!(constant_time_compare("", ""));
        assert!(constant_time_compare(
            "longer-string-123",
            "longer-string-123"
        ));
    }

    #[test]
    fn test_constant_time_compare_not_equal() {
        assert!(!constant_time_compare("abc", "abd"));
        assert!(!constant_time_compare("abc", "ab"));
        assert!(!constant_time_compare("abc", "abcd"));
        assert!(!constant_time_compare("", "a"));
    }

    // ---- API key generation ----

    #[test]
    fn test_generate_api_key_length() {
        let key = generate_api_key();
        assert_eq!(key.len(), 32);
    }

    #[test]
    fn test_generate_api_key_alphanumeric() {
        let key = generate_api_key();
        assert!(key.chars().all(|c| c.is_ascii_alphanumeric()));
    }

    #[test]
    fn test_generate_api_key_unique() {
        // The RNG does not depend on wall-clock time, so no delay is needed
        // between the two draws.
        let key1 = generate_api_key();
        let key2 = generate_api_key();
        assert_ne!(key1, key2);
    }

    // ---- AuthError display ----

    #[test]
    fn test_auth_error_display_missing() {
        let err = AuthError::MissingCredentials;
        assert_eq!(format!("{err}"), "Authentication required");
    }

    #[test]
    fn test_auth_error_display_invalid() {
        let err = AuthError::InvalidCredentials;
        assert_eq!(format!("{err}"), "Invalid API key");
    }

    #[test]
    fn test_auth_error_display_malformed() {
        let err = AuthError::MalformedHeader;
        assert_eq!(format!("{err}"), "Malformed authorization header");
    }

    // ---- check_auth ----

    #[tokio::test]
    async fn test_with_auth_disabled() {
        let state = AuthState {
            config: Arc::new(AuthConfig::disabled()),
            oauth_state: None,
        };
        let result = check_auth(&state, &request("/", &[])).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_with_auth_valid_header() {
        let req = request("/", &[("authorization", "Bearer secret")]);
        let result = check_auth(&api_key_state("secret"), &req).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_with_auth_valid_query() {
        let req = request("/?api_key=secret", &[]);
        let result = check_auth(&api_key_state("secret"), &req).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_with_auth_valid_ws_protocol() {
        let req = request(
            "/",
            &[("sec-websocket-protocol", "varpulis-v1, varpulis-auth.secret")],
        );
        let result = check_auth(&api_key_state("secret"), &req).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_with_auth_invalid_key() {
        let req = request("/", &[("authorization", "Bearer wrong")]);
        let result = check_auth(&api_key_state("secret"), &req).await;
        assert!(matches!(result, Err(AuthRejection::InvalidCredentials)));
    }

    #[tokio::test]
    async fn test_with_auth_invalid_header_not_rescued_by_query() {
        let req = request("/?api_key=secret", &[("authorization", "Bearer wrong")]);
        let result = check_auth(&api_key_state("secret"), &req).await;
        assert!(matches!(result, Err(AuthRejection::InvalidCredentials)));
    }

    #[tokio::test]
    async fn test_with_auth_invalid_ws_protocol() {
        let req = request("/", &[("sec-websocket-protocol", "varpulis-auth.wrong")]);
        let result = check_auth(&api_key_state("secret"), &req).await;
        assert!(matches!(result, Err(AuthRejection::InvalidCredentials)));
    }

    #[tokio::test]
    async fn test_with_auth_malformed_header() {
        let req = request("/", &[("authorization", "Bearer")]);
        let result = check_auth(&api_key_state("secret"), &req).await;
        assert!(matches!(result, Err(AuthRejection::MalformedHeader)));
    }

    #[tokio::test]
    async fn test_with_auth_missing_credentials() {
        let result = check_auth(&api_key_state("secret"), &request("/", &[])).await;
        assert!(matches!(result, Err(AuthRejection::MissingCredentials)));
    }
}