Skip to main content

tower_mcp/
auth.rs

1//! Authentication middleware helpers for MCP servers
2//!
3//! This module provides helper types and layers for common authentication patterns.
4//! Since tower-mcp is built on Tower, standard tower middleware can be used directly.
5//!
6//! # Patterns
7//!
8//! ## API Key Authentication
9//!
//! ```rust,ignore
//! // Requires the `http` feature
//! use tower_mcp::auth::{ApiKeyValidator, AuthLayer};
//! use tower_mcp::{McpRouter, HttpTransport};
//!
//! // Simple in-memory API key validator
//! let valid_keys = vec!["sk-test-key-123".to_string()];
//! let validator = ApiKeyValidator::new(valid_keys);
//!
//! let router = McpRouter::new().server_info("my-server", "1.0.0");
//! let transport = HttpTransport::new(router);
//!
//! // The auth layer extracts the key from the `Authorization` header
//! // (`Bearer <key>`, `ApiKey <key>`, or a raw key) and validates it with
//! // the provided validator. Apply it to your HTTP service, e.g. with
//! // `tower::ServiceBuilder::new().layer(auth_layer)`.
//! let auth_layer = AuthLayer::new(validator);
//! ```
26//!
27//! ## Bearer Token Authentication
28//!
//! For OAuth2/JWT tokens, implement the [`Validate`] trait with custom
//! validation logic (e.g., JWT verification, token introspection).
//! [`StaticBearerValidator`] is a minimal implementation that checks tokens
//! against a fixed in-memory set.
31//!
32//! ## Custom Authentication
33//!
34//! You can implement custom auth by creating a Tower layer. See the examples
35//! directory for a complete example.
36
37use std::collections::HashSet;
38use std::future::Future;
39use std::sync::Arc;
40
41use tower::Layer;
42
43/// Result of an authentication attempt
44#[derive(Debug, Clone)]
45pub enum AuthResult {
46    /// Authentication succeeded with optional user/client info
47    Authenticated(Option<AuthInfo>),
48    /// Authentication failed with a reason
49    Failed(AuthError),
50}
51
52/// Information about an authenticated client
53#[derive(Debug, Clone)]
54pub struct AuthInfo {
55    /// Client/user identifier
56    pub client_id: String,
57    /// Optional additional claims or metadata
58    pub claims: Option<serde_json::Value>,
59}
60
/// Reason an authentication attempt was rejected.
///
/// Displays as `<code>: <message>`.
#[derive(Debug, Clone)]
pub struct AuthError {
    /// Machine-readable error code (e.g., "invalid_token", "expired_token")
    pub code: String,
    /// Human-readable error message
    pub message: String,
}

impl std::fmt::Display for AuthError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.code)?;
        f.write_str(": ")?;
        f.write_str(&self.message)
    }
}

impl std::error::Error for AuthError {}
77
78// =============================================================================
79// Validation Trait
80// =============================================================================
81
82/// Trait for validating authentication credentials.
83///
84/// Implement this trait to provide custom authentication logic for use
85/// with [`AuthLayer`] and [`AuthService`].
86///
87/// The credential string passed to [`validate`](Validate::validate) is the
88/// value extracted from the configured request header after parsing
89/// (e.g., the token portion of `"Bearer sk-123"`).
90///
91/// # Example
92///
93/// ```rust
94/// use tower_mcp::auth::{Validate, AuthResult, AuthInfo, AuthError};
95///
96/// #[derive(Clone)]
97/// struct MyValidator;
98///
99/// impl Validate for MyValidator {
100///     async fn validate(&self, credential: &str) -> AuthResult {
101///         if credential.starts_with("sk-") {
102///             AuthResult::Authenticated(Some(AuthInfo {
103///                 client_id: credential.to_string(),
104///                 claims: None,
105///             }))
106///         } else {
107///             AuthResult::Failed(AuthError {
108///                 code: "invalid_credential".to_string(),
109///                 message: "Credential must start with sk-".to_string(),
110///             })
111///         }
112///     }
113/// }
114/// ```
115pub trait Validate: Clone + Send + Sync + 'static {
116    /// Validate a credential and return the authentication result.
117    fn validate(&self, credential: &str) -> impl Future<Output = AuthResult> + Send;
118}
119
120// =============================================================================
121// API Key Authentication
122// =============================================================================
123
/// API key validator backed by a fixed, in-memory set of keys.
///
/// The key set is held in an [`Arc`], so cloning a validator shares the set.
/// [`add_key`](Self::add_key) goes through [`Arc::make_mut`], which copies
/// the set first if it is shared. Other clones therefore keep their original
/// keys.
///
/// For production use, consider:
/// - Database-backed validation
/// - Caching with TTL
/// - Rate limiting per key
#[derive(Debug, Clone)]
pub struct ApiKeyValidator {
    valid_keys: Arc<HashSet<String>>,
}

impl ApiKeyValidator {
    /// Builds a validator that accepts exactly the given API keys.
    pub fn new(keys: impl IntoIterator<Item = String>) -> Self {
        let key_set: HashSet<String> = keys.into_iter().collect();
        Self {
            valid_keys: Arc::new(key_set),
        }
    }

    /// Registers one more accepted key.
    pub fn add_key(&mut self, key: String) {
        let keys = Arc::make_mut(&mut self.valid_keys);
        keys.insert(key);
    }

    /// Returns `true` when `key` is one of the accepted keys.
    pub fn is_valid(&self, key: &str) -> bool {
        let keys: &HashSet<String> = &self.valid_keys;
        keys.contains(key)
    }
}
153
154impl Validate for ApiKeyValidator {
155    async fn validate(&self, key: &str) -> AuthResult {
156        if self.valid_keys.contains(key) {
157            AuthResult::Authenticated(Some(AuthInfo {
158                client_id: format!("api_key:{}", &key[..8.min(key.len())]),
159                claims: None,
160            }))
161        } else {
162            AuthResult::Failed(AuthError {
163                code: "invalid_api_key".to_string(),
164                message: "The provided API key is not valid".to_string(),
165            })
166        }
167    }
168}
169
170// =============================================================================
171// Bearer Token Authentication
172// =============================================================================
173
/// Bearer-token validator that accepts a fixed, in-memory set of tokens.
///
/// For production, implement [`Validate`] with:
/// - JWT verification using a signing key
/// - OAuth2 token introspection
/// - OIDC ID token validation
#[derive(Debug, Clone)]
pub struct StaticBearerValidator {
    valid_tokens: Arc<HashSet<String>>,
}

impl StaticBearerValidator {
    /// Builds a validator that accepts exactly the given tokens.
    pub fn new(tokens: impl IntoIterator<Item = String>) -> Self {
        let token_set = tokens.into_iter().collect::<HashSet<_>>();
        Self {
            valid_tokens: Arc::new(token_set),
        }
    }
}
193
194impl Validate for StaticBearerValidator {
195    async fn validate(&self, token: &str) -> AuthResult {
196        if self.valid_tokens.contains(token) {
197            AuthResult::Authenticated(Some(AuthInfo {
198                client_id: format!("bearer:{}", &token[..8.min(token.len())]),
199                claims: None,
200            }))
201        } else {
202            AuthResult::Failed(AuthError {
203                code: "invalid_token".to_string(),
204                message: "The provided bearer token is not valid".to_string(),
205            })
206        }
207    }
208}
209
210// =============================================================================
211// Authorization Header Parsing
212// =============================================================================
213
/// Extract an API key from an `Authorization`-style header value.
///
/// Supported formats (surrounding whitespace is ignored):
/// - `Bearer <key>` (standard)
/// - `ApiKey <key>`
/// - `<key>` (raw key, which must not contain whitespace)
///
/// Scheme names are matched case-sensitively. Returns `None` when no key can
/// be extracted:
/// - an empty or whitespace-only value;
/// - a bare scheme name with no key (`"Bearer"`, `"ApiKey"`);
/// - any other `<scheme> <credentials>` form, such as `Basic user:pass`.
pub fn extract_api_key(auth_header: &str) -> Option<&str> {
    let value = auth_header.trim();

    // Without this guard, `""` would be returned as an (empty) key, and a bare
    // scheme name would fall through to the raw-key branch and be returned as
    // the key itself.
    if value.is_empty() || value == "Bearer" || value == "ApiKey" {
        return None;
    }

    if let Some(key) = value
        .strip_prefix("Bearer ")
        .or_else(|| value.strip_prefix("ApiKey "))
    {
        // `value` is trimmed and strictly longer than the prefix, so `key`
        // is non-empty after trimming.
        return Some(key.trim());
    }

    // Raw key. Any embedded whitespace (space or tab, both permitted in header
    // values) means this is some other `<scheme> <credentials>` form.
    if value.contains(char::is_whitespace) {
        None
    } else {
        Some(value)
    }
}
234
/// Extract a bearer token from an `Authorization` header value.
///
/// The whole value is trimmed first. The text after a case-sensitive
/// `Bearer ` prefix is then returned, also trimmed. Returns `None` when the
/// prefix is absent.
pub fn extract_bearer_token(auth_header: &str) -> Option<&str> {
    let trimmed = auth_header.trim();
    let token = trimmed.strip_prefix("Bearer ")?;
    Some(token.trim())
}
239
240// =============================================================================
241// Generic Auth Layer
242// =============================================================================
243
/// Tower [`Layer`] that authenticates requests with a [`Validate`]
/// implementation.
///
/// Each service this layer produces gets its own clone of the validator.
/// Credentials are read from the configured header: `Authorization` by
/// default, or whatever is set with [`header_name`](Self::header_name).
#[derive(Clone)]
pub struct AuthLayer<V> {
    validator: V,
    header_name: String,
}

impl<V> AuthLayer<V> {
    /// Wraps `validator` in a layer that reads the `Authorization` header.
    pub fn new(validator: V) -> Self {
        let header_name = String::from("Authorization");
        Self {
            validator,
            header_name,
        }
    }

    /// Reads credentials from the header called `name` instead of the default.
    pub fn header_name(self, name: impl Into<String>) -> Self {
        Self {
            header_name: name.into(),
            ..self
        }
    }
}
271
272impl<S, V: Clone> Layer<S> for AuthLayer<V> {
273    type Service = AuthService<S, V>;
274
275    fn layer(&self, inner: S) -> Self::Service {
276        AuthService {
277            inner,
278            validator: self.validator.clone(),
279            header_name: self.header_name.clone(),
280        }
281    }
282}
283
/// Tower service that authenticates each incoming request.
///
/// Built by [`AuthLayer`]. With the `http` feature enabled, it:
/// 1. reads the configured header;
/// 2. parses it with [`extract_api_key`];
/// 3. checks the credential with the [`Validate`] implementation.
///
/// Accepted requests are passed to the inner service. When the validator
/// supplies an [`AuthInfo`], it is inserted into the request extensions.
/// Missing or rejected credentials produce an HTTP 401 response instead.
///
/// # Example
///
/// ```rust,ignore
/// // Requires the `http` feature
/// use tower::ServiceBuilder;
/// use tower_mcp::auth::{AuthLayer, ApiKeyValidator};
///
/// let validator = ApiKeyValidator::new(vec!["sk-test-key-123".to_string()]);
///
/// let service = ServiceBuilder::new()
///     .layer(AuthLayer::new(validator))
///     .service(inner_service);
/// ```
#[derive(Clone)]
// The only `Service` impl is gated on the `http` feature. Without it, these
// fields are never read.
#[cfg_attr(not(feature = "http"), allow(dead_code))]
pub struct AuthService<S, V> {
    inner: S,
    validator: V,
    header_name: String,
}
311
#[cfg(feature = "http")]
impl<S, V> tower_service::Service<axum::http::Request<axum::body::Body>> for AuthService<S, V>
where
    S: tower_service::Service<
            axum::http::Request<axum::body::Body>,
            Response = axum::response::Response,
        > + Clone
        + Send
        + 'static,
    S::Future: Send,
    S::Error: Into<crate::BoxError> + Send,
    V: Validate,
{
    type Response = axum::response::Response;
    type Error = S::Error;
    type Future =
        std::pin::Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Delegates readiness to the wrapped service.
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Authenticates `req`. The request is then either forwarded to the inner
    /// service or answered with 401 Unauthorized.
    fn call(&mut self, mut req: axum::http::Request<axum::body::Body>) -> Self::Future {
        let credential = req
            .headers()
            .get(&self.header_name)
            .and_then(|v| v.to_str().ok())
            .and_then(extract_api_key)
            .map(|s| s.to_owned());

        let Some(credential) = credential else {
            // Name the header that is actually configured. Hard-coding
            // "Authorization" is wrong after `header_name("X-API-Key")`.
            let message = format!(
                "Missing authentication credentials. Provide via {} header.",
                self.header_name
            );
            return Box::pin(async move { Ok::<_, S::Error>(unauthorized_response(&message)) });
        };

        // `poll_ready` drove `self.inner` to readiness, so that same instance
        // must serve this request. A fresh clone may never have been polled
        // ready, which breaks inner services that reserve capacity in
        // `poll_ready` (buffers, concurrency limits). Swap the ready service
        // out and leave a clone behind for the next poll_ready/call cycle.
        let replacement = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, replacement);
        let validator = self.validator.clone();

        Box::pin(async move {
            match validator.validate(&credential).await {
                AuthResult::Authenticated(info) => {
                    if let Some(info) = info {
                        req.extensions_mut().insert(info);
                    }
                    inner.call(req).await
                }
                AuthResult::Failed(err) => Ok(unauthorized_response(&err.message)),
            }
        })
    }
}
368
/// Builds an HTTP 401 Unauthorized response whose JSON body is a JSON-RPC
/// error object carrying `message`.
#[cfg(feature = "http")]
fn unauthorized_response(message: &str) -> axum::response::Response {
    use axum::response::IntoResponse as _;

    // JSON-RPC error code this module uses for authentication failures.
    const AUTH_ERROR_CODE: i64 = -32001;

    // No request id is available here, so `id` is null.
    let payload = serde_json::json!({
        "jsonrpc": "2.0",
        "error": { "code": AUTH_ERROR_CODE, "message": message },
        "id": null
    });
    let status = axum::http::StatusCode::UNAUTHORIZED;
    (status, axum::Json(payload)).into_response()
}
386
387// =============================================================================
388// Helper for building auth middleware
389// =============================================================================
390
/// Builder for creating auth middleware configurations.
///
/// NOTE(review): nothing in this module reads `AuthConfig`. [`AuthLayer`] and
/// [`AuthService`] only use their own header name. Callers that want
/// `allow_anonymous` / `public_paths` semantics must apply them themselves,
/// e.g. via [`AuthConfig::is_public`].
#[derive(Debug, Clone)]
pub struct AuthConfig {
    /// Whether to allow unauthenticated requests to pass through
    pub allow_anonymous: bool,
    /// Path prefixes that don't require authentication. They are matched on
    /// whole path segments; see [`AuthConfig::is_public`].
    pub public_paths: Vec<String>,
    /// Custom header name for auth token
    pub header_name: String,
}

impl Default for AuthConfig {
    /// Requires authentication on every path and reads the `Authorization`
    /// header.
    fn default() -> Self {
        Self {
            allow_anonymous: false,
            public_paths: Vec::new(),
            header_name: "Authorization".to_string(),
        }
    }
}

impl AuthConfig {
    /// Create a new auth config (same as [`AuthConfig::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Allow anonymous requests (no auth required)
    pub fn allow_anonymous(mut self, allow: bool) -> Self {
        self.allow_anonymous = allow;
        self
    }

    /// Add a path prefix that doesn't require authentication
    pub fn public_path(mut self, path: impl Into<String>) -> Self {
        self.public_paths.push(path.into());
        self
    }

    /// Set the header name for auth tokens
    pub fn header_name(mut self, name: impl Into<String>) -> Self {
        self.header_name = name.into();
        self
    }

    /// Check if a path is public (doesn't require auth).
    ///
    /// A configured entry `p` matches `path` when `path` equals `p`, or when
    /// `path` continues `p` with a new `/`-separated segment. For example,
    /// `/health` matches `/health` and `/health/live`, but **not** `/healthz`.
    /// An entry that already ends in `/` matches anything beneath it.
    ///
    /// A plain `starts_with` check would let `/health` also expose
    /// `/health-admin`, silently widening the unauthenticated surface.
    pub fn is_public(&self, path: &str) -> bool {
        self.public_paths.iter().any(|prefix| {
            path.strip_prefix(prefix.as_str()).is_some_and(|rest| {
                rest.is_empty() || rest.starts_with('/') || prefix.ends_with('/')
            })
        })
    }
}
441
#[cfg(test)]
mod tests {
    use super::*;

    /// Panics unless `result` is a success that carries an [`AuthInfo`].
    fn expect_authenticated_with_info(result: AuthResult) {
        match result {
            AuthResult::Authenticated(info) => assert!(info.is_some()),
            AuthResult::Failed(_) => panic!("Expected authentication to succeed"),
        }
    }

    /// Panics unless `result` is a failure whose error code equals `code`.
    fn expect_failure_code(result: AuthResult, code: &str) {
        match result {
            AuthResult::Authenticated(_) => panic!("Expected authentication to fail"),
            AuthResult::Failed(err) => assert_eq!(err.code, code),
        }
    }

    #[test]
    fn test_extract_api_key_bearer() {
        for header in ["Bearer sk-123", "Bearer  sk-123 "] {
            assert_eq!(extract_api_key(header), Some("sk-123"));
        }
    }

    #[test]
    fn test_extract_api_key_apikey_prefix() {
        let parsed = extract_api_key("ApiKey sk-123");
        assert_eq!(parsed, Some("sk-123"));
    }

    #[test]
    fn test_extract_api_key_raw() {
        let parsed = extract_api_key("sk-123");
        assert_eq!(parsed, Some("sk-123"));
    }

    #[test]
    fn test_extract_api_key_invalid() {
        assert!(extract_api_key("Basic user:pass").is_none());
    }

    #[test]
    fn test_extract_bearer_token() {
        let cases = [
            ("Bearer abc123", Some("abc123")),
            // The `Bearer` scheme is matched case-sensitively.
            ("bearer abc123", None),
            ("abc123", None),
        ];
        for (header, expected) in cases {
            assert_eq!(extract_bearer_token(header), expected);
        }
    }

    #[tokio::test]
    async fn test_api_key_validator() {
        let validator = ApiKeyValidator::new(vec!["valid-key".to_string()]);

        expect_authenticated_with_info(validator.validate("valid-key").await);
        expect_failure_code(validator.validate("invalid-key").await, "invalid_api_key");
    }

    #[tokio::test]
    async fn test_bearer_validator() {
        let validator = StaticBearerValidator::new(vec!["token123".to_string()]);

        expect_authenticated_with_info(validator.validate("token123").await);
        expect_failure_code(validator.validate("bad-token").await, "invalid_token");
    }

    #[test]
    fn test_auth_config() {
        let config = AuthConfig::new()
            .header_name("X-API-Key")
            .public_path("/health")
            .public_path("/metrics")
            .allow_anonymous(false);

        assert!(!config.allow_anonymous);
        assert_eq!(config.header_name, "X-API-Key");
        for public in ["/health", "/metrics/cpu"] {
            assert!(config.is_public(public), "{public} should be public");
        }
        assert!(!config.is_public("/api/tools"));
    }

    #[test]
    fn test_auth_layer_creates_service() {
        // Wrapping a unit "service" is enough to exercise the `Layer` impl.
        let layer = AuthLayer::new(ApiKeyValidator::new(vec!["key".to_string()]));
        let _service: AuthService<(), ApiKeyValidator> = layer.layer(());
    }

    #[cfg(feature = "http")]
    mod http_tests {
        use super::*;
        use std::pin::Pin;
        use std::task::{Context, Poll};

        use axum::body::Body;
        use axum::http::{Request, StatusCode};
        use tower::ServiceExt;
        use tower_service::Service;

        type Reply = Result<axum::response::Response, std::convert::Infallible>;
        type TestFuture = Pin<Box<dyn Future<Output = Reply> + Send>>;

        /// Builds a body-less response with the given status.
        fn status_response(status: StatusCode) -> axum::response::Response {
            axum::response::Response::builder()
                .status(status)
                .body(Body::empty())
                .unwrap()
        }

        /// Builds a request for `/`, optionally carrying one header.
        fn request_with(header: Option<(&str, &str)>) -> Request<Body> {
            let builder = Request::builder().uri("/");
            let builder = match header {
                Some((name, value)) => builder.header(name, value),
                None => builder,
            };
            builder.body(Body::empty()).unwrap()
        }

        /// Auth layer that accepts exactly one API key.
        fn single_key_layer(key: &str) -> AuthLayer<ApiKeyValidator> {
            AuthLayer::new(ApiKeyValidator::new(vec![key.to_string()]))
        }

        /// Inner service that answers every request with 200 OK.
        #[derive(Clone)]
        struct OkService;

        impl Service<Request<Body>> for OkService {
            type Response = axum::response::Response;
            type Error = std::convert::Infallible;
            type Future = TestFuture;

            fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _req: Request<Body>) -> Self::Future {
                Box::pin(async { Ok(status_response(StatusCode::OK)) })
            }
        }

        /// Inner service that answers 200 when an `AuthInfo` was inserted into
        /// the request extensions, and 500 otherwise.
        #[derive(Clone)]
        struct CheckAuthInfo;

        impl Service<Request<Body>> for CheckAuthInfo {
            type Response = axum::response::Response;
            type Error = std::convert::Infallible;
            type Future = TestFuture;

            fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, req: Request<Body>) -> Self::Future {
                let status = if req.extensions().get::<AuthInfo>().is_some() {
                    StatusCode::OK
                } else {
                    StatusCode::INTERNAL_SERVER_ERROR
                };
                Box::pin(async move { Ok(status_response(status)) })
            }
        }

        #[tokio::test]
        async fn test_auth_service_rejects_missing_credentials() {
            let mut service = single_key_layer("sk-test-123").layer(OkService);

            let resp = service
                .ready()
                .await
                .unwrap()
                .call(request_with(None))
                .await
                .unwrap();
            assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        }

        #[tokio::test]
        async fn test_auth_service_rejects_invalid_key() {
            let mut service = single_key_layer("sk-test-123").layer(OkService);
            let req = request_with(Some(("Authorization", "Bearer sk-wrong-key")));

            let resp = service.ready().await.unwrap().call(req).await.unwrap();
            assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        }

        #[tokio::test]
        async fn test_auth_service_accepts_valid_key() {
            let mut service = single_key_layer("sk-test-123").layer(OkService);
            let req = request_with(Some(("Authorization", "Bearer sk-test-123")));

            let resp = service.ready().await.unwrap().call(req).await.unwrap();
            assert_eq!(resp.status(), StatusCode::OK);
        }

        #[tokio::test]
        async fn test_auth_service_injects_auth_info() {
            let mut service = single_key_layer("sk-test-123").layer(CheckAuthInfo);
            let req = request_with(Some(("Authorization", "Bearer sk-test-123")));

            let resp = service.ready().await.unwrap().call(req).await.unwrap();
            assert_eq!(resp.status(), StatusCode::OK);
        }

        #[tokio::test]
        async fn test_auth_service_custom_header() {
            let mut service = single_key_layer("my-key")
                .header_name("X-API-Key")
                .layer(OkService);

            // Once a custom header is configured, the standard one is ignored.
            let req = request_with(Some(("Authorization", "Bearer my-key")));
            let resp = service.ready().await.unwrap().call(req).await.unwrap();
            assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);

            // The configured header is honoured.
            let req = request_with(Some(("X-API-Key", "my-key")));
            let resp = service.ready().await.unwrap().call(req).await.unwrap();
            assert_eq!(resp.status(), StatusCode::OK);
        }
    }
}