1use std::collections::HashMap;
6use std::time::{SystemTime, UNIX_EPOCH};
7
8use async_trait::async_trait;
9use oauth2::RefreshToken;
10use serde::{Deserialize, Serialize};
11
12use turbomcp_protocol::{Error as McpError, Result as McpResult};
13
14use super::config::AuthProviderType;
15
/// Legacy authentication context.
///
/// Deprecated in favor of [`crate::context::AuthContext`]; an existing value can be
/// converted with [`AuthContext::to_unified`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[deprecated(
    since = "2.0.5",
    note = "Use context::AuthContext instead. This type is legacy and will be removed in 3.0.0"
)]
pub struct AuthContext {
    /// Identifier of the authenticated user (becomes `sub` in the unified context).
    pub user_id: String,
    /// Profile information for the user.
    pub user: UserInfo,
    /// Role names associated with the user.
    pub roles: Vec<String>,
    /// Permission names associated with the user.
    pub permissions: Vec<String>,
    /// Identifier of the request this context belongs to.
    pub request_id: String,
    /// Token details, if any were attached to this context.
    pub token: Option<TokenInfo>,
    /// Name of the provider that authenticated the user
    /// (NOTE(review): presumably an `AuthProvider::name()` value — confirm).
    pub provider: String,
    /// Time at which authentication took place.
    pub authenticated_at: SystemTime,
    /// Time at which this authentication expires, or `None` if no expiry is recorded.
    pub expires_at: Option<SystemTime>,
    /// Additional free-form data.
    pub metadata: HashMap<String, serde_json::Value>,
}
53
54#[allow(deprecated)]
55impl AuthContext {
56 pub fn to_unified(&self) -> crate::context::AuthContext {
58 crate::context::AuthContext {
59 sub: self.user_id.clone(),
60 iss: None, aud: None, exp: self
63 .expires_at
64 .and_then(|t| t.duration_since(UNIX_EPOCH).ok().map(|d| d.as_secs())),
65 iat: self
66 .authenticated_at
67 .duration_since(UNIX_EPOCH)
68 .ok()
69 .map(|d| d.as_secs()),
70 nbf: None, jti: None, user: self.user.clone(),
73 roles: self.roles.clone(),
74 permissions: self.permissions.clone(),
75 scopes: Vec::new(), request_id: Some(self.request_id.clone()),
77 authenticated_at: self.authenticated_at,
78 expires_at: self.expires_at,
79 token: self.token.clone(),
80 provider: self.provider.clone(),
81 #[cfg(feature = "dpop")]
82 dpop_jkt: None, metadata: self.metadata.clone(),
84 }
85 }
86}
87
/// Profile information about a user.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserInfo {
    /// User identifier.
    pub id: String,
    /// Username.
    pub username: String,
    /// Email address, if known.
    pub email: Option<String>,
    /// Human-readable display name, if known.
    pub display_name: Option<String>,
    /// Avatar image URL, if known.
    pub avatar_url: Option<String>,
    /// Additional free-form data.
    pub metadata: HashMap<String, serde_json::Value>,
}
104
/// Token details returned by a provider (for example from `AuthProvider::refresh_token`).
///
/// `Debug` is implemented by hand so that `access_token` and `refresh_token` are redacted.
/// The `Serialize` derive has no skip attributes, so serialization includes every field,
/// secrets included.
#[derive(Clone, Serialize, Deserialize)]
pub struct TokenInfo {
    /// The access token string.
    pub access_token: String,
    /// The token type string (NOTE(review): presumably e.g. "Bearer" — confirm).
    pub token_type: String,
    /// Refresh token, if one was issued.
    pub refresh_token: Option<String>,
    /// Token lifetime (NOTE(review): presumably seconds, as with OAuth 2.0 `expires_in`
    /// — confirm).
    pub expires_in: Option<u64>,
    /// Granted scope string (NOTE(review): presumably space-delimited as in OAuth 2.0
    /// — confirm).
    pub scope: Option<String>,
}
119
120impl std::fmt::Debug for TokenInfo {
122 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123 f.debug_struct("TokenInfo")
124 .field("access_token", &"[REDACTED]")
125 .field("token_type", &self.token_type)
126 .field(
127 "refresh_token",
128 &self.refresh_token.as_ref().map(|_| "[REDACTED]"),
129 )
130 .field("expires_in", &self.expires_in)
131 .field("scope", &self.scope)
132 .finish()
133 }
134}
135
/// A pluggable authentication backend.
///
/// The supertraits require implementations to be `Send + Sync + Debug`.
#[async_trait]
pub trait AuthProvider: Send + Sync + std::fmt::Debug {
    /// Name of this provider.
    fn name(&self) -> &str;

    /// Kind of this provider.
    fn provider_type(&self) -> AuthProviderType;

    /// Authenticates the given credentials and returns the resulting unified auth context.
    ///
    /// # Errors
    /// Returns an error if authentication fails; the specific error values are defined by
    /// each implementation.
    async fn authenticate(
        &self,
        credentials: AuthCredentials,
    ) -> McpResult<crate::context::AuthContext>;

    /// Validates `token` and returns the auth context it represents.
    async fn validate_token(&self, token: &str) -> McpResult<crate::context::AuthContext>;

    /// Uses `refresh_token` to obtain new token details.
    async fn refresh_token(&self, refresh_token: &str) -> McpResult<TokenInfo>;

    /// Revokes `token`.
    async fn revoke_token(&self, token: &str) -> McpResult<()>;

    /// Fetches profile information for the user associated with `token`.
    async fn get_user_info(&self, token: &str) -> McpResult<UserInfo>;
}
163
/// Credentials passed to `AuthProvider::authenticate`.
///
/// `Debug` is implemented by hand so that secret fields are redacted. The `Serialize`
/// derive has no skip attributes, so serialization includes every field verbatim.
#[derive(Clone, Serialize, Deserialize)]
pub enum AuthCredentials {
    /// A username/password pair.
    UsernamePassword {
        /// Account name (shown in `Debug` output).
        username: String,
        /// Password (redacted in `Debug` output).
        password: String,
    },
    /// An API key.
    ApiKey {
        /// The key (redacted in `Debug` output).
        key: String,
    },
    /// OAuth2 authorization-code parameters.
    OAuth2Code {
        /// Authorization code (redacted in `Debug` output).
        code: String,
        /// State value (shown in `Debug` output).
        state: String,
    },
    /// A JWT.
    JwtToken {
        /// The encoded token (redacted in `Debug` output).
        token: String,
    },
    /// Provider-specific credential data.
    Custom {
        /// Arbitrary key/value data (redacted as a whole in `Debug` output).
        data: HashMap<String, serde_json::Value>,
    },
}
197
198impl std::fmt::Debug for AuthCredentials {
200 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201 match self {
202 AuthCredentials::UsernamePassword { username, .. } => f
203 .debug_struct("AuthCredentials::UsernamePassword")
204 .field("username", username)
205 .field("password", &"[REDACTED]")
206 .finish(),
207 AuthCredentials::ApiKey { .. } => f
208 .debug_struct("AuthCredentials::ApiKey")
209 .field("key", &"[REDACTED]")
210 .finish(),
211 AuthCredentials::OAuth2Code { state, .. } => f
212 .debug_struct("AuthCredentials::OAuth2Code")
213 .field("code", &"[REDACTED]")
214 .field("state", state)
215 .finish(),
216 AuthCredentials::JwtToken { .. } => f
217 .debug_struct("AuthCredentials::JwtToken")
218 .field("token", &"[REDACTED]")
219 .finish(),
220 AuthCredentials::Custom { .. } => f
221 .debug_struct("AuthCredentials::Custom")
222 .field("data", &"[REDACTED]")
223 .finish(),
224 }
225 }
226}
227
/// Storage backend for per-user access and refresh tokens.
///
/// The supertraits require implementations to be `Send + Sync + Debug`.
#[async_trait]
pub trait TokenStorage: Send + Sync + std::fmt::Debug {
    /// Stores `token` as the access token for `user_id`.
    async fn store_access_token(&self, user_id: &str, token: &AccessToken) -> McpResult<()>;

    /// Returns the access token stored for `user_id`, or `None` if there is none.
    async fn get_access_token(&self, user_id: &str) -> McpResult<Option<AccessToken>>;

    /// Stores `token` as the refresh token for `user_id`.
    async fn store_refresh_token(&self, user_id: &str, token: &RefreshToken) -> McpResult<()>;

    /// Returns the refresh token stored for `user_id`, or `None` if there is none.
    async fn get_refresh_token(&self, user_id: &str) -> McpResult<Option<RefreshToken>>;

    /// Revokes the tokens held for `user_id`.
    async fn revoke_tokens(&self, user_id: &str) -> McpResult<()>;

    /// Lists user IDs known to the storage
    /// (NOTE(review): presumably those with stored tokens — confirm).
    async fn list_users(&self) -> McpResult<Vec<String>>;
}
249
/// An access token together with its expiry, scopes and metadata.
///
/// Fields are crate-private, and outside code reads them through the accessor methods.
/// `Debug` is implemented by hand so that the token string is redacted.
#[derive(Clone)]
pub struct AccessToken {
    /// The token string.
    pub(crate) token: String,
    /// Expiry time, or `None` if no expiry is recorded.
    pub(crate) expires_at: Option<SystemTime>,
    /// Scopes associated with the token.
    pub(crate) scopes: Vec<String>,
    /// Additional free-form data.
    pub(crate) metadata: HashMap<String, serde_json::Value>,
}
262
263impl std::fmt::Debug for AccessToken {
265 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
266 f.debug_struct("AccessToken")
267 .field("token", &"[REDACTED]")
268 .field("expires_at", &self.expires_at)
269 .field("scopes", &self.scopes)
270 .field("metadata", &self.metadata)
271 .finish()
272 }
273}
274
275impl AccessToken {
276 #[must_use]
278 pub fn new(
279 token: String,
280 expires_at: Option<SystemTime>,
281 scopes: Vec<String>,
282 metadata: HashMap<String, serde_json::Value>,
283 ) -> Self {
284 Self {
285 token,
286 expires_at,
287 scopes,
288 metadata,
289 }
290 }
291
292 #[must_use]
294 pub fn token(&self) -> &str {
295 &self.token
296 }
297
298 #[must_use]
300 pub fn expires_at(&self) -> Option<SystemTime> {
301 self.expires_at
302 }
303
304 #[must_use]
306 pub fn scopes(&self) -> &[String] {
307 &self.scopes
308 }
309
310 #[must_use]
312 pub fn metadata(&self) -> &HashMap<String, serde_json::Value> {
313 &self.metadata
314 }
315}
316
/// Request-level authentication hooks: extracting a credential and reacting to failures.
#[async_trait]
pub trait AuthMiddleware: Send + Sync {
    /// Extracts a credential from the request `headers`, or returns `None` if none is found.
    async fn extract_token(&self, headers: &HashMap<String, String>) -> Option<String>;

    /// Handles an authentication failure described by `error`.
    ///
    /// NOTE(review): presumably `Ok(())` lets the request proceed and `Err` rejects it
    /// — confirm against callers.
    async fn handle_auth_failure(&self, error: McpError) -> McpResult<()>;
}
326
/// Stock [`AuthMiddleware`] that reads credentials from the `Authorization` or `X-API-Key`
/// request headers and logs authentication failures before returning them as errors.
#[derive(Debug, Clone)]
pub struct DefaultAuthMiddleware;
330
331#[async_trait]
332impl AuthMiddleware for DefaultAuthMiddleware {
333 async fn extract_token(&self, headers: &HashMap<String, String>) -> Option<String> {
334 if let Some(auth_header) = headers
336 .get("authorization")
337 .or_else(|| headers.get("Authorization"))
338 {
339 if let Some(token) = auth_header.strip_prefix("Bearer ") {
340 return Some(token.to_string());
341 }
342 if let Some(token) = auth_header.strip_prefix("ApiKey ") {
343 return Some(token.to_string());
344 }
345 }
346
347 if let Some(api_key) = headers
349 .get("x-api-key")
350 .or_else(|| headers.get("X-API-Key"))
351 {
352 return Some(api_key.clone());
353 }
354
355 None
356 }
357
358 async fn handle_auth_failure(&self, error: McpError) -> McpResult<()> {
359 tracing::warn!("Authentication failed: {}", error);
360 Err(Box::new(error))
361 }
362}