1use std::collections::HashMap;
27use std::sync::Arc;
28use thiserror::Error;
29
30#[derive(Error, Debug, Clone)]
32pub enum AuthError {
33 #[error("Authentication token missing")]
35 TokenMissing,
36
37 #[error("Invalid token format: {0}")]
39 InvalidFormat(String),
40
41 #[error("Token has expired")]
43 TokenExpired,
44
45 #[error("Invalid token signature")]
47 InvalidSignature,
48
49 #[error("Token validation failed: {0}")]
51 ValidationFailed(String),
52
53 #[error("Insufficient permissions: {0}")]
55 InsufficientPermissions(String),
56}
57
58impl AuthError {
59 pub fn validation_failed(msg: impl Into<String>) -> Self {
61 Self::ValidationFailed(msg.into())
62 }
63
64 pub fn invalid_format(msg: impl Into<String>) -> Self {
66 Self::InvalidFormat(msg.into())
67 }
68
69 pub fn insufficient_permissions(msg: impl Into<String>) -> Self {
71 Self::InsufficientPermissions(msg.into())
72 }
73}
74
/// Claims returned by a validator for an accepted token.
#[derive(Debug, Clone)]
pub struct Claims {
    /// Subject identifier carried by the claims.
    pub sub: String,
    /// Additional string-valued claims, keyed by claim name.
    pub extra: HashMap<String, String>,
}

impl Claims {
    /// Creates claims for `sub` with no extra entries.
    pub fn new(sub: impl Into<String>) -> Self {
        Self::with_extra(sub, HashMap::new())
    }

    /// Creates claims for `sub` that start out with the given `extra` map.
    pub fn with_extra(sub: impl Into<String>, extra: HashMap<String, String>) -> Self {
        let sub = sub.into();
        Self { sub, extra }
    }

    /// Returns the subject as a string slice.
    pub fn subject(&self) -> &str {
        self.sub.as_str()
    }

    /// Looks up the extra claim stored under `key`, if any.
    pub fn get(&self, key: &str) -> Option<&str> {
        self.extra.get(key).map(String::as_str)
    }

    /// Stores an extra claim, replacing any value already stored under `key`.
    pub fn insert(&mut self, key: impl Into<String>, value: impl Into<String>) {
        let (key, value) = (key.into(), value.into());
        self.extra.insert(key, value);
    }
}
118
/// Where in an incoming request the authentication token is looked up.
#[derive(Debug, Clone)]
pub enum TokenExtractor {
    /// Read the token from the named header; an optional leading `Bearer `
    /// scheme prefix is stripped from the value.
    Header(String),
    /// Read the token from the named query-string parameter.
    Query(String),
    /// Read the token from the `Sec-WebSocket-Protocol` header.
    Protocol,
}

impl Default for TokenExtractor {
    /// Defaults to reading the `Authorization` header.
    fn default() -> Self {
        TokenExtractor::Header(String::from("Authorization"))
    }
}
135
136impl TokenExtractor {
137 pub fn header(name: impl Into<String>) -> Self {
139 Self::Header(name.into())
140 }
141
142 pub fn query(name: impl Into<String>) -> Self {
144 Self::Query(name.into())
145 }
146
147 pub fn protocol() -> Self {
149 Self::Protocol
150 }
151
152 pub fn extract<B>(&self, req: &http::Request<B>) -> Option<String> {
154 match self {
155 TokenExtractor::Header(name) => {
156 req.headers()
157 .get(name)
158 .and_then(|v| v.to_str().ok())
159 .map(|s| {
160 if let Some(token) = s.strip_prefix("Bearer ") {
162 token.to_string()
163 } else {
164 s.to_string()
165 }
166 })
167 }
168 TokenExtractor::Query(name) => req.uri().query().and_then(|query| {
169 url::form_urlencoded::parse(query.as_bytes())
170 .find(|(key, _)| key == name)
171 .map(|(_, value)| value.into_owned())
172 }),
173 TokenExtractor::Protocol => req
174 .headers()
175 .get("Sec-WebSocket-Protocol")
176 .and_then(|v| v.to_str().ok())
177 .map(|s| s.to_string()),
178 }
179 }
180}
181
182#[async_trait::async_trait]
186pub trait TokenValidator: Send + Sync {
187 async fn validate(&self, token: &str) -> Result<Claims, AuthError>;
189}
190
191#[derive(Clone)]
193pub struct WsAuthConfig {
194 pub extractor: TokenExtractor,
196 pub validator: Arc<dyn TokenValidator>,
198 pub required: bool,
200}
201
202impl WsAuthConfig {
203 pub fn new<V: TokenValidator + 'static>(validator: V) -> Self {
205 Self {
206 extractor: TokenExtractor::default(),
207 validator: Arc::new(validator),
208 required: true,
209 }
210 }
211
212 pub fn extractor(mut self, extractor: TokenExtractor) -> Self {
214 self.extractor = extractor;
215 self
216 }
217
218 pub fn required(mut self, required: bool) -> Self {
220 self.required = required;
221 self
222 }
223
224 pub async fn authenticate<B>(
226 &self,
227 req: &http::Request<B>,
228 ) -> Result<Option<Claims>, AuthError> {
229 match self.extractor.extract(req) {
230 Some(token) => {
231 let claims = self.validator.validate(&token).await?;
232 Ok(Some(claims))
233 }
234 None if self.required => Err(AuthError::TokenMissing),
235 None => Ok(None),
236 }
237 }
238}
239
240impl std::fmt::Debug for WsAuthConfig {
241 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
242 f.debug_struct("WsAuthConfig")
243 .field("extractor", &self.extractor)
244 .field("required", &self.required)
245 .finish()
246 }
247}
248
/// Validator that accepts any non-empty token and uses the token text as the
/// subject of the returned claims.
///
/// Derives `Debug`, `Clone`, `Copy` and `Default` so it can be embedded in
/// caller types that derive those traits.
#[derive(Debug, Clone, Copy, Default)]
pub struct AcceptAllValidator;
253
254#[async_trait::async_trait]
255impl TokenValidator for AcceptAllValidator {
256 async fn validate(&self, token: &str) -> Result<Claims, AuthError> {
257 if token.is_empty() {
258 return Err(AuthError::invalid_format("Token cannot be empty"));
259 }
260 Ok(Claims::new(token))
261 }
262}
263
/// Validator that rejects every token with a validation-failed error.
///
/// Derives `Debug`, `Clone`, `Copy` and `Default` so it can be embedded in
/// caller types that derive those traits.
#[derive(Debug, Clone, Copy, Default)]
pub struct RejectAllValidator;
268
269#[async_trait::async_trait]
270impl TokenValidator for RejectAllValidator {
271 async fn validate(&self, _token: &str) -> Result<Claims, AuthError> {
272 Err(AuthError::validation_failed("All tokens rejected"))
273 }
274}
275
276pub struct StaticTokenValidator {
278 tokens: HashMap<String, Claims>,
279}
280
281impl StaticTokenValidator {
282 pub fn new() -> Self {
284 Self {
285 tokens: HashMap::new(),
286 }
287 }
288
289 pub fn add_token(mut self, token: impl Into<String>, claims: Claims) -> Self {
291 self.tokens.insert(token.into(), claims);
292 self
293 }
294}
295
296impl Default for StaticTokenValidator {
297 fn default() -> Self {
298 Self::new()
299 }
300}
301
302#[async_trait::async_trait]
303impl TokenValidator for StaticTokenValidator {
304 async fn validate(&self, token: &str) -> Result<Claims, AuthError> {
305 self.tokens
306 .get(token)
307 .cloned()
308 .ok_or_else(|| AuthError::validation_failed("Invalid token"))
309 }
310}
311
#[cfg(test)]
mod tests {
    use super::*;
    use http::Request;

    /// Builds a body-less request without any headers.
    fn bare_request() -> Request<()> {
        Request::builder().body(()).unwrap()
    }

    /// Builds a body-less request carrying exactly one header.
    fn request_with_header(name: &str, value: &str) -> Request<()> {
        Request::builder().header(name, value).body(()).unwrap()
    }

    #[test]
    fn test_token_extractor_header() {
        let request = request_with_header("Authorization", "Bearer test-token");

        let extracted = TokenExtractor::header("Authorization").extract(&request);

        assert_eq!(extracted, Some("test-token".to_string()));
    }

    #[test]
    fn test_token_extractor_header_no_bearer() {
        let request = request_with_header("X-API-Key", "my-api-key");

        let extracted = TokenExtractor::header("X-API-Key").extract(&request);

        assert_eq!(extracted, Some("my-api-key".to_string()));
    }

    #[test]
    fn test_token_extractor_query() {
        let request = Request::builder()
            .uri("ws://localhost/ws?token=query-token&other=value")
            .body(())
            .unwrap();

        let extracted = TokenExtractor::query("token").extract(&request);

        assert_eq!(extracted, Some("query-token".to_string()));
    }

    #[test]
    fn test_token_extractor_protocol() {
        let request = request_with_header("Sec-WebSocket-Protocol", "my-protocol-token");

        let extracted = TokenExtractor::protocol().extract(&request);

        assert_eq!(extracted, Some("my-protocol-token".to_string()));
    }

    #[test]
    fn test_token_extractor_missing() {
        let request = bare_request();

        let extracted = TokenExtractor::header("Authorization").extract(&request);

        assert_eq!(extracted, None);
    }

    #[tokio::test]
    async fn test_accept_all_validator() {
        let outcome = AcceptAllValidator.validate("any-token").await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().subject(), "any-token");
    }

    #[tokio::test]
    async fn test_accept_all_validator_empty() {
        let outcome = AcceptAllValidator.validate("").await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_reject_all_validator() {
        let outcome = RejectAllValidator.validate("any-token").await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_static_token_validator() {
        let validator =
            StaticTokenValidator::new().add_token("valid-token", Claims::new("user-123"));

        let accepted = validator.validate("valid-token").await;
        assert!(accepted.is_ok());
        assert_eq!(accepted.unwrap().subject(), "user-123");

        let rejected = validator.validate("invalid-token").await;
        assert!(rejected.is_err());
    }

    #[tokio::test]
    async fn test_ws_auth_config_required() {
        let config = WsAuthConfig::new(AcceptAllValidator)
            .extractor(TokenExtractor::header("Authorization"))
            .required(true);
        let request = bare_request();

        let outcome = config.authenticate(&request).await;

        assert!(matches!(outcome, Err(AuthError::TokenMissing)));
    }

    #[tokio::test]
    async fn test_ws_auth_config_optional() {
        let config = WsAuthConfig::new(AcceptAllValidator)
            .extractor(TokenExtractor::header("Authorization"))
            .required(false);
        let request = bare_request();

        let outcome = config.authenticate(&request).await;

        assert!(outcome.is_ok());
        assert!(outcome.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_ws_auth_config_with_token() {
        let config = WsAuthConfig::new(AcceptAllValidator)
            .extractor(TokenExtractor::header("Authorization"));
        let request = request_with_header("Authorization", "Bearer my-token");

        let outcome = config.authenticate(&request).await;

        assert!(outcome.is_ok());
        let claims = outcome.unwrap().unwrap();
        assert_eq!(claims.subject(), "my-token");
    }

    #[test]
    fn test_claims_extra() {
        let mut claims = Claims::new("user-123");
        for (key, value) in [("role", "admin"), ("tenant", "acme")] {
            claims.insert(key, value);
        }

        assert_eq!(claims.subject(), "user-123");
        assert_eq!(claims.get("role"), Some("admin"));
        assert_eq!(claims.get("tenant"), Some("acme"));
        assert_eq!(claims.get("missing"), None);
    }

    #[test]
    fn test_auth_error_display() {
        let missing = AuthError::TokenMissing;
        assert_eq!(missing.to_string(), "Authentication token missing");

        let failed = AuthError::validation_failed("custom error");
        assert_eq!(failed.to_string(), "Token validation failed: custom error");
    }

    #[test]
    fn test_token_extractor_default() {
        if let TokenExtractor::Header(name) = TokenExtractor::default() {
            assert_eq!(name, "Authorization");
        } else {
            panic!("Expected Header extractor");
        }
    }
}
486
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    /// Result every property body evaluates to.
    type CaseResult = Result<(), proptest::test_runner::TestCaseError>;

    /// Drives one asynchronous property case to completion on a fresh Tokio runtime.
    fn run_async<F>(case: F) -> CaseResult
    where
        F: std::future::Future<Output = CaseResult>,
    {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        runtime.block_on(case)
    }

    /// Tokens of 1 to 100 characters drawn from letters, digits, `.`, `_` and `-`.
    fn token_strategy() -> impl Strategy<Value = String> {
        prop::string::string_regex("[a-zA-Z0-9._-]{1,100}").unwrap()
    }

    /// Header names: a letter followed by up to 30 letters, digits or `-`.
    fn header_name_strategy() -> impl Strategy<Value = String> {
        prop::string::string_regex("[A-Za-z][A-Za-z0-9-]{0,30}").unwrap()
    }

    /// Query parameter names: a lowercase letter followed by up to 20 of `[a-z0-9_]`.
    fn query_param_strategy() -> impl Strategy<Value = String> {
        prop::string::string_regex("[a-z][a-z0-9_]{0,20}").unwrap()
    }

    /// Any of the three extractor kinds, with generated names where applicable.
    fn extractor_strategy() -> impl Strategy<Value = TokenExtractor> {
        prop_oneof![
            header_name_strategy().prop_map(TokenExtractor::Header),
            query_param_strategy().prop_map(TokenExtractor::Query),
            Just(TokenExtractor::Protocol),
        ]
    }

    proptest! {
        // A required token that is absent yields `TokenMissing` for every extractor.
        #[test]
        fn prop_auth_required_rejects_missing_token(
            extractor in extractor_strategy()
        ) {
            run_async(async {
                let request = http::Request::builder()
                    .uri("ws://localhost/ws")
                    .body(())
                    .unwrap();
                let config = WsAuthConfig::new(AcceptAllValidator)
                    .extractor(extractor)
                    .required(true);

                let outcome = config.authenticate(&request).await;
                prop_assert!(matches!(outcome, Err(AuthError::TokenMissing)));
                Ok(())
            })?;
        }

        // A bearer token in the configured header is accepted and becomes the subject.
        #[test]
        fn prop_auth_accepts_valid_token_in_header(
            token in token_strategy(),
            header_name in header_name_strategy()
        ) {
            run_async(async {
                let request = http::Request::builder()
                    .uri("ws://localhost/ws")
                    .header(&header_name, format!("Bearer {}", token))
                    .body(())
                    .unwrap();
                let config = WsAuthConfig::new(AcceptAllValidator)
                    .extractor(TokenExtractor::Header(header_name.clone()))
                    .required(true);

                let outcome = config.authenticate(&request).await;
                prop_assert!(outcome.is_ok());
                let maybe_claims = outcome.unwrap();
                prop_assert!(maybe_claims.is_some());
                let claims = maybe_claims.unwrap();
                prop_assert_eq!(claims.subject(), &token);
                Ok(())
            })?;
        }

        // A token in the configured query parameter is accepted and becomes the subject.
        #[test]
        fn prop_auth_accepts_valid_token_in_query(
            token in token_strategy(),
            param_name in query_param_strategy()
        ) {
            run_async(async {
                let uri = format!("ws://localhost/ws?{}={}", param_name, token);
                let request = http::Request::builder()
                    .uri(&uri)
                    .body(())
                    .unwrap();
                let config = WsAuthConfig::new(AcceptAllValidator)
                    .extractor(TokenExtractor::Query(param_name.clone()))
                    .required(true);

                let outcome = config.authenticate(&request).await;
                prop_assert!(outcome.is_ok());
                let maybe_claims = outcome.unwrap();
                prop_assert!(maybe_claims.is_some());
                let claims = maybe_claims.unwrap();
                prop_assert_eq!(claims.subject(), &token);
                Ok(())
            })?;
        }

        // A validator that rejects everything surfaces `ValidationFailed`.
        #[test]
        fn prop_auth_rejects_invalid_token(
            token in token_strategy()
        ) {
            run_async(async {
                let request = http::Request::builder()
                    .uri("ws://localhost/ws")
                    .header("Authorization", format!("Bearer {}", token))
                    .body(())
                    .unwrap();
                let config = WsAuthConfig::new(RejectAllValidator)
                    .extractor(TokenExtractor::Header("Authorization".to_string()))
                    .required(true);

                let outcome = config.authenticate(&request).await;
                prop_assert!(outcome.is_err());
                prop_assert!(matches!(outcome, Err(AuthError::ValidationFailed(_))));
                Ok(())
            })?;
        }

        // With optional auth, a missing token yields `Ok(None)` for every extractor.
        #[test]
        fn prop_optional_auth_allows_missing_token(
            extractor in extractor_strategy()
        ) {
            run_async(async {
                let request = http::Request::builder()
                    .uri("ws://localhost/ws")
                    .body(())
                    .unwrap();
                let config = WsAuthConfig::new(AcceptAllValidator)
                    .extractor(extractor)
                    .required(false);

                let outcome = config.authenticate(&request).await;
                prop_assert!(outcome.is_ok());
                prop_assert!(outcome.unwrap().is_none());
                Ok(())
            })?;
        }

        // The static validator accepts exactly the registered token.
        #[test]
        fn prop_static_validator_only_accepts_known_tokens(
            valid_token in token_strategy(),
            test_token in token_strategy(),
            user_id in "[a-z]{3,10}"
        ) {
            run_async(async {
                let validator = StaticTokenValidator::new()
                    .add_token(valid_token.clone(), Claims::new(user_id.clone()));

                let outcome = validator.validate(&test_token).await;

                if test_token != valid_token {
                    prop_assert!(outcome.is_err());
                } else {
                    prop_assert!(outcome.is_ok());
                    let claims = outcome.unwrap();
                    prop_assert_eq!(claims.subject(), &user_id);
                }
                Ok(())
            })?;
        }
    }
}