turbomcp_auth/
types.rs

1//! Core Authentication Types
2//!
3//! This module contains core types used throughout the TurboMCP authentication system.
4
5use std::collections::HashMap;
6use std::time::{SystemTime, UNIX_EPOCH};
7
8use async_trait::async_trait;
9use oauth2::RefreshToken;
10use serde::{Deserialize, Serialize};
11
12use turbomcp_protocol::{Error as McpError, Result as McpResult};
13
14use super::config::AuthProviderType;
15
/// Authentication context (LEGACY - use `context::AuthContext` instead)
///
/// Records who was authenticated, by which provider, and until when. This is the legacy
/// shape; new code should use `crate::context::AuthContext` (the unified canonical type).
///
/// This type will be removed in version 3.0.0. To migrate, call `to_unified()`: it copies
/// every field below into the unified type and leaves the claims this type has no source
/// for (issuer, audience, not-before, JWT ID, scopes) empty.
///
/// Note: the derived `Debug` prints `token` through `TokenInfo`'s hand-written `Debug`, which
/// redacts token values; the derived `Serialize` does include them.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[deprecated(
    since = "2.0.5",
    note = "Use context::AuthContext instead. This type is legacy and will be removed in 3.0.0"
)]
pub struct AuthContext {
    /// Identifier of the authenticated user (becomes `sub` in the unified context)
    pub user_id: String,
    /// Profile information for the authenticated user
    pub user: UserInfo,
    /// Roles granted to the user
    pub roles: Vec<String>,
    /// Permissions granted to the user
    pub permissions: Vec<String>,
    /// Request ID for replay protection (MCP compliant - NOT session-based)
    ///
    /// Per MCP specification, authentication is stateless. This field is for
    /// request-level binding (DPoP nonces, one-time tokens), not session management.
    pub request_id: String,
    /// Token details associated with this authentication, if any
    pub token: Option<TokenInfo>,
    /// Name of the authentication provider that produced this context
    pub provider: String,
    /// When authentication happened (exposed as `iat`, in Unix seconds, by `to_unified()`)
    pub authenticated_at: SystemTime,
    /// When this authentication expires, if known (exposed as `exp`, in Unix seconds)
    pub expires_at: Option<SystemTime>,
    /// Additional free-form metadata
    pub metadata: HashMap<String, serde_json::Value>,
}
53
54#[allow(deprecated)]
55impl AuthContext {
56    /// Convert legacy types::AuthContext to unified context::AuthContext
57    pub fn to_unified(&self) -> crate::context::AuthContext {
58        crate::context::AuthContext {
59            sub: self.user_id.clone(),
60            iss: None, // Not present in legacy type
61            aud: None, // Not present in legacy type
62            exp: self
63                .expires_at
64                .and_then(|t| t.duration_since(UNIX_EPOCH).ok().map(|d| d.as_secs())),
65            iat: self
66                .authenticated_at
67                .duration_since(UNIX_EPOCH)
68                .ok()
69                .map(|d| d.as_secs()),
70            nbf: None, // Not present in legacy type
71            jti: None, // Not present in legacy type
72            user: self.user.clone(),
73            roles: self.roles.clone(),
74            permissions: self.permissions.clone(),
75            scopes: Vec::new(), // Not present in legacy type
76            request_id: Some(self.request_id.clone()),
77            authenticated_at: self.authenticated_at,
78            expires_at: self.expires_at,
79            token: self.token.clone(),
80            provider: self.provider.clone(),
81            #[cfg(feature = "dpop")]
82            dpop_jkt: None, // Not present in legacy type
83            metadata: self.metadata.clone(),
84        }
85    }
86}
87
/// User information
///
/// Profile data for an authenticated user. It is embedded in the legacy `AuthContext`,
/// copied into the unified context by `to_unified()`, and returned by
/// `AuthProvider::get_user_info`. Only `id` and `username` are mandatory.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserInfo {
    /// User identifier
    pub id: String,
    /// Username
    pub username: String,
    /// Email address, if known
    pub email: Option<String>,
    /// Human-readable display name, if known
    pub display_name: Option<String>,
    /// URL of the user's avatar image, if known
    pub avatar_url: Option<String>,
    /// Additional free-form user metadata
    pub metadata: HashMap<String, serde_json::Value>,
}
104
/// Token information
///
/// Token set returned by a provider (for example from `AuthProvider::refresh_token`).
///
/// `Debug` is implemented by hand and redacts `access_token` and `refresh_token`. The
/// derived `Serialize` does NOT redact: serialized output contains both tokens in plain
/// text, so treat it as secret material.
#[derive(Clone, Serialize, Deserialize)]
pub struct TokenInfo {
    /// Access token value (secret)
    pub access_token: String,
    /// Token type (Bearer, etc.)
    pub token_type: String,
    /// Refresh token, if one was issued (secret)
    pub refresh_token: Option<String>,
    /// Token lifetime in seconds.
    /// NOTE(review): presumably relative to issuance, like the OAuth 2.0 `expires_in`
    /// response field — confirm with the providers that populate it.
    pub expires_in: Option<u64>,
    /// Granted scope string.
    /// NOTE(review): likely space-delimited as in OAuth 2.0 — confirm.
    pub scope: Option<String>,
}
119
120// Manual Debug impl to prevent token exposure in logs (Sprint 3.6)
121impl std::fmt::Debug for TokenInfo {
122    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123        f.debug_struct("TokenInfo")
124            .field("access_token", &"[REDACTED]")
125            .field("token_type", &self.token_type)
126            .field(
127                "refresh_token",
128                &self.refresh_token.as_ref().map(|_| "[REDACTED]"),
129            )
130            .field("expires_in", &self.expires_in)
131            .field("scope", &self.scope)
132            .finish()
133    }
134}
135
/// Authentication provider trait
///
/// A pluggable backend that turns credentials or tokens into an authenticated
/// `crate::context::AuthContext`. Implementations must be `Send + Sync` (so they can be
/// shared across threads) and `Debug`.
#[async_trait]
pub trait AuthProvider: Send + Sync + std::fmt::Debug {
    /// Name of this provider instance
    fn name(&self) -> &str;

    /// Kind of provider (see `AuthProviderType`)
    fn provider_type(&self) -> AuthProviderType;

    /// Authenticate a user from the given credentials.
    ///
    /// # Errors
    /// Returns an error if authentication fails. NOTE(review): how unsupported
    /// `AuthCredentials` variants are handled is up to each implementation.
    async fn authenticate(
        &self,
        credentials: AuthCredentials,
    ) -> McpResult<crate::context::AuthContext>;

    /// Validate an existing token/session and return the context it represents.
    async fn validate_token(&self, token: &str) -> McpResult<crate::context::AuthContext>;

    /// Exchange a refresh token for a new token set.
    async fn refresh_token(&self, refresh_token: &str) -> McpResult<TokenInfo>;

    /// Revoke a token/session so it can no longer be used.
    async fn revoke_token(&self, token: &str) -> McpResult<()>;

    /// Fetch profile information for the user the token belongs to.
    async fn get_user_info(&self, token: &str) -> McpResult<UserInfo>;
}
163
/// Authentication credentials
///
/// Input to `AuthProvider::authenticate`. `Debug` is implemented by hand so that secret
/// fields print as `[REDACTED]`; the derived `Serialize` emits them in plain text.
#[derive(Clone, Serialize, Deserialize)]
pub enum AuthCredentials {
    /// Username and password
    UsernamePassword {
        /// Username (shown by `Debug`)
        username: String,
        /// Password (secret; redacted by `Debug`)
        password: String,
    },
    /// API key
    ApiKey {
        /// API key (secret; redacted by `Debug`)
        key: String,
    },
    /// OAuth 2.1 authorization code
    OAuth2Code {
        /// Authorization code (secret; redacted by `Debug`)
        code: String,
        /// OAuth `state` value received with the code (shown by `Debug`)
        state: String,
    },
    /// JWT token
    JwtToken {
        /// Encoded JWT (secret; redacted by `Debug`)
        token: String,
    },
    /// Custom credentials
    Custom {
        /// Provider-specific credential fields (entirely redacted by `Debug`)
        data: HashMap<String, serde_json::Value>,
    },
}
197
198// Manual Debug impl to prevent credential exposure in logs (Sprint 3.6)
199impl std::fmt::Debug for AuthCredentials {
200    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201        match self {
202            AuthCredentials::UsernamePassword { username, .. } => f
203                .debug_struct("AuthCredentials::UsernamePassword")
204                .field("username", username)
205                .field("password", &"[REDACTED]")
206                .finish(),
207            AuthCredentials::ApiKey { .. } => f
208                .debug_struct("AuthCredentials::ApiKey")
209                .field("key", &"[REDACTED]")
210                .finish(),
211            AuthCredentials::OAuth2Code { state, .. } => f
212                .debug_struct("AuthCredentials::OAuth2Code")
213                .field("code", &"[REDACTED]")
214                .field("state", state)
215                .finish(),
216            AuthCredentials::JwtToken { .. } => f
217                .debug_struct("AuthCredentials::JwtToken")
218                .field("token", &"[REDACTED]")
219                .finish(),
220            AuthCredentials::Custom { .. } => f
221                .debug_struct("AuthCredentials::Custom")
222                .field("data", &"[REDACTED]")
223                .finish(),
224        }
225    }
226}
227
/// Secure token storage abstraction
///
/// Persists access and refresh tokens keyed by user ID. Implementations must be
/// `Send + Sync` and `Debug`; their `Debug` output should not reveal stored tokens.
#[async_trait]
pub trait TokenStorage: Send + Sync + std::fmt::Debug {
    /// Store the access token for `user_id`.
    /// NOTE(review): presumably replaces any previously stored token — confirm per backend.
    async fn store_access_token(&self, user_id: &str, token: &AccessToken) -> McpResult<()>;

    /// Retrieve the access token for `user_id`; `Ok(None)` when nothing is stored.
    async fn get_access_token(&self, user_id: &str) -> McpResult<Option<AccessToken>>;

    /// Store the refresh token for `user_id`.
    /// Implementations are expected to keep refresh tokens encrypted at rest.
    async fn store_refresh_token(&self, user_id: &str, token: &RefreshToken) -> McpResult<()>;

    /// Retrieve the refresh token for `user_id`; `Ok(None)` when nothing is stored.
    async fn get_refresh_token(&self, user_id: &str) -> McpResult<Option<RefreshToken>>;

    /// Remove all tokens stored for `user_id` (logout).
    async fn revoke_tokens(&self, user_id: &str) -> McpResult<()>;

    /// List the IDs of all users that currently have stored tokens (for admin use).
    async fn list_users(&self) -> McpResult<Vec<String>>;
}
249
/// Secure access token with metadata
///
/// Fields are crate-private: construct with `AccessToken::new` and read through the accessor
/// methods. `Debug` is implemented by hand to redact the token value. Unlike `TokenInfo`,
/// this type does not derive `Serialize`/`Deserialize`.
#[derive(Clone)]
pub struct AccessToken {
    /// The raw token value (secret)
    pub(crate) token: String,
    /// When the token expires, if known
    pub(crate) expires_at: Option<SystemTime>,
    /// Scopes granted to the token
    pub(crate) scopes: Vec<String>,
    /// Provider metadata (printed as-is by `Debug`, so keep secrets out of it)
    pub(crate) metadata: HashMap<String, serde_json::Value>,
}
262
263// Manual Debug impl to prevent token exposure in logs (Sprint 3.6)
264impl std::fmt::Debug for AccessToken {
265    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
266        f.debug_struct("AccessToken")
267            .field("token", &"[REDACTED]")
268            .field("expires_at", &self.expires_at)
269            .field("scopes", &self.scopes)
270            .field("metadata", &self.metadata)
271            .finish()
272    }
273}
274
275impl AccessToken {
276    /// Create a new access token
277    #[must_use]
278    pub fn new(
279        token: String,
280        expires_at: Option<SystemTime>,
281        scopes: Vec<String>,
282        metadata: HashMap<String, serde_json::Value>,
283    ) -> Self {
284        Self {
285            token,
286            expires_at,
287            scopes,
288            metadata,
289        }
290    }
291
292    /// Get the token value
293    #[must_use]
294    pub fn token(&self) -> &str {
295        &self.token
296    }
297
298    /// Get the token expiration time
299    #[must_use]
300    pub fn expires_at(&self) -> Option<SystemTime> {
301        self.expires_at
302    }
303
304    /// Get the token scopes
305    #[must_use]
306    pub fn scopes(&self) -> &[String] {
307        &self.scopes
308    }
309
310    /// Get the token metadata
311    #[must_use]
312    pub fn metadata(&self) -> &HashMap<String, serde_json::Value> {
313        &self.metadata
314    }
315}
316
/// Authentication middleware trait
///
/// Hooks for pulling a credential out of request headers and for reacting when
/// authentication fails.
#[async_trait]
pub trait AuthMiddleware: Send + Sync {
    /// Extract an authentication token from request headers; `None` if none is present.
    async fn extract_token(&self, headers: &HashMap<String, String>) -> Option<String>;

    /// Handle an authentication failure.
    ///
    /// NOTE(review): presumably returning `Err` aborts the request while `Ok(())` lets it
    /// proceed; `DefaultAuthMiddleware` always returns the error — confirm the caller's
    /// contract.
    async fn handle_auth_failure(&self, error: McpError) -> McpResult<()>;
}
326
/// Default authentication middleware
///
/// Stateless implementation of `AuthMiddleware`. It reads credentials from the
/// `Authorization` header (`Bearer` / `ApiKey` schemes) or the `X-API-Key` header, and on
/// failure logs a warning and returns the error.
///
/// It carries no state, so it is `Copy` and `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultAuthMiddleware;
330
331#[async_trait]
332impl AuthMiddleware for DefaultAuthMiddleware {
333    async fn extract_token(&self, headers: &HashMap<String, String>) -> Option<String> {
334        // Try Authorization header first
335        if let Some(auth_header) = headers
336            .get("authorization")
337            .or_else(|| headers.get("Authorization"))
338        {
339            if let Some(token) = auth_header.strip_prefix("Bearer ") {
340                return Some(token.to_string());
341            }
342            if let Some(token) = auth_header.strip_prefix("ApiKey ") {
343                return Some(token.to_string());
344            }
345        }
346
347        // Try X-API-Key header
348        if let Some(api_key) = headers
349            .get("x-api-key")
350            .or_else(|| headers.get("X-API-Key"))
351        {
352            return Some(api_key.clone());
353        }
354
355        None
356    }
357
358    async fn handle_auth_failure(&self, error: McpError) -> McpResult<()> {
359        tracing::warn!("Authentication failed: {}", error);
360        Err(Box::new(error))
361    }
362}