1use std::collections::HashMap;
6use std::time::SystemTime;
7
8use async_trait::async_trait;
9use oauth2::RefreshToken;
10use serde::{Deserialize, Serialize};
11
12use turbomcp_protocol::{Error as McpError, Result as McpResult};
13
14use super::config::AuthProviderType;
15
/// Authentication state describing an authenticated user.
///
/// Bundles the user's identity, roles and permissions, session and token
/// details, and timing information. The type can be cloned and round-tripped
/// through serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthContext {
    /// Identifier of the authenticated user.
    ///
    /// NOTE(review): presumably the same value as `user.id` — confirm against
    /// the providers that build this struct.
    pub user_id: String,
    /// Profile information for the authenticated user.
    pub user: UserInfo,
    /// Role names assigned to the user.
    pub roles: Vec<String>,
    /// Permission names granted to the user.
    pub permissions: Vec<String>,
    /// Identifier of the session this context belongs to.
    pub session_id: String,
    /// Token details, when a token is associated with this context.
    pub token: Option<TokenInfo>,
    /// Name of the provider that produced this context.
    ///
    /// NOTE(review): presumably matches `AuthProvider::name` — confirm.
    pub provider: String,
    /// Point in time at which authentication took place.
    pub authenticated_at: SystemTime,
    /// Point in time at which this context expires, if one is recorded.
    ///
    /// NOTE(review): `None` presumably means "no expiry tracked" — confirm
    /// how consumers interpret it.
    pub expires_at: Option<SystemTime>,
    /// Additional JSON values keyed by name.
    pub metadata: HashMap<String, serde_json::Value>,
}
40
/// Profile information about a user.
///
/// Only `id` and `username` are mandatory; the remaining profile fields are
/// optional. The type can be cloned and round-tripped through serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserInfo {
    /// Identifier of the user.
    pub id: String,
    /// Username of the user.
    pub username: String,
    /// Email address, when known.
    pub email: Option<String>,
    /// Human-readable display name, when known.
    pub display_name: Option<String>,
    /// URL of the user's avatar image, when known.
    pub avatar_url: Option<String>,
    /// Additional JSON values keyed by name.
    pub metadata: HashMap<String, serde_json::Value>,
}
57
/// Token details as plain strings, suitable for (de)serialization.
///
/// NOTE(review): `Debug` is derived, so formatting this value with `{:?}`
/// prints `access_token` and `refresh_token` verbatim. Avoid logging it, or
/// consider a redacting `Debug` implementation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenInfo {
    /// The access token value.
    pub access_token: String,
    /// Type of the token (presumably values such as `"Bearer"` — confirm
    /// against the providers that populate it).
    pub token_type: String,
    /// Refresh token value, when one was issued.
    pub refresh_token: Option<String>,
    /// Token lifetime, when known.
    ///
    /// NOTE(review): presumably seconds, following the OAuth 2.0
    /// `expires_in` convention — confirm the unit.
    pub expires_in: Option<u64>,
    /// Scope associated with the token, when known (presumably a
    /// space-delimited list as in OAuth 2.0 — confirm).
    pub scope: Option<String>,
}
72
/// Interface implemented by authentication providers.
///
/// Implementors must be `Send + Sync` and implement `Debug`. All fallible
/// operations report failures through [`McpResult`].
#[async_trait]
pub trait AuthProvider: Send + Sync + std::fmt::Debug {
    /// Returns the provider's name.
    fn name(&self) -> &str;

    /// Returns the kind of provider, as an [`AuthProviderType`].
    fn provider_type(&self) -> AuthProviderType;

    /// Authenticates the supplied `credentials`.
    ///
    /// # Errors
    ///
    /// Returns an error when authentication does not succeed; the exact
    /// conditions are defined by the implementor.
    async fn authenticate(&self, credentials: AuthCredentials) -> McpResult<AuthContext>;

    /// Validates `token` and returns the [`AuthContext`] it resolves to.
    ///
    /// # Errors
    ///
    /// Returns an error when the token is not accepted; the exact conditions
    /// are defined by the implementor.
    async fn validate_token(&self, token: &str) -> McpResult<AuthContext>;

    /// Exchanges `refresh_token` for new token details.
    ///
    /// # Errors
    ///
    /// Returns an error when the refresh does not succeed.
    async fn refresh_token(&self, refresh_token: &str) -> McpResult<TokenInfo>;

    /// Revokes `token`.
    ///
    /// # Errors
    ///
    /// Returns an error when the revocation does not succeed.
    async fn revoke_token(&self, token: &str) -> McpResult<()>;

    /// Looks up the [`UserInfo`] associated with `token`.
    ///
    /// # Errors
    ///
    /// Returns an error when the user information cannot be obtained.
    async fn get_user_info(&self, token: &str) -> McpResult<UserInfo>;
}
97
/// Credentials that can be handed to [`AuthProvider::authenticate`].
///
/// NOTE(review): `Debug` is derived, so formatting a value with `{:?}` prints
/// passwords, API keys, authorization codes and tokens verbatim. Avoid
/// logging credentials, or consider a redacting `Debug` implementation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AuthCredentials {
    /// A username together with its password.
    UsernamePassword {
        /// The username.
        username: String,
        /// The password, held as a plain string.
        password: String,
    },
    /// An API key.
    ApiKey {
        /// The key value.
        key: String,
    },
    /// An OAuth 2.0 authorization code.
    OAuth2Code {
        /// The authorization code.
        code: String,
        /// The `state` value accompanying the code.
        state: String,
    },
    /// A JWT.
    JwtToken {
        /// The token value.
        token: String,
    },
    /// Provider-specific credentials.
    Custom {
        /// Free-form JSON values keyed by name; their meaning is defined by
        /// the provider that consumes them.
        data: HashMap<String, serde_json::Value>,
    },
}
131
/// Interface for storing and retrieving tokens per user.
///
/// Implementors must be `Send + Sync` and implement `Debug`. All operations
/// are asynchronous and report failures through [`McpResult`].
#[async_trait]
pub trait TokenStorage: Send + Sync + std::fmt::Debug {
    /// Stores the access `token` for `user_id`.
    ///
    /// NOTE(review): presumably replaces a previously stored access token for
    /// the same user — confirm against implementations.
    async fn store_access_token(&self, user_id: &str, token: &AccessToken) -> McpResult<()>;

    /// Returns the access token stored for `user_id`, or `None` when there is
    /// none.
    async fn get_access_token(&self, user_id: &str) -> McpResult<Option<AccessToken>>;

    /// Stores the refresh `token` for `user_id`.
    ///
    /// NOTE(review): presumably replaces a previously stored refresh token
    /// for the same user — confirm against implementations.
    async fn store_refresh_token(&self, user_id: &str, token: &RefreshToken) -> McpResult<()>;

    /// Returns the refresh token stored for `user_id`, or `None` when there
    /// is none.
    async fn get_refresh_token(&self, user_id: &str) -> McpResult<Option<RefreshToken>>;

    /// Revokes the tokens held for `user_id`.
    ///
    /// NOTE(review): presumably covers both access and refresh tokens —
    /// confirm against implementations.
    async fn revoke_tokens(&self, user_id: &str) -> McpResult<()>;

    /// Returns a list of user identifiers.
    ///
    /// NOTE(review): presumably the users that currently have tokens stored —
    /// confirm against implementations.
    async fn list_users(&self) -> McpResult<Vec<String>>;
}
153
/// An access token together with its expiry, scopes and metadata.
///
/// Fields are crate-private; outside the crate, values are built with
/// `AccessToken::new` and read through the accessor methods.
///
/// NOTE(review): `Debug` is derived, so formatting a value with `{:?}` prints
/// the raw token string. Avoid logging it, or consider a redacting `Debug`
/// implementation.
#[derive(Debug, Clone)]
pub struct AccessToken {
    /// The raw token string.
    pub(crate) token: String,
    /// Point in time at which the token expires, if one is recorded.
    pub(crate) expires_at: Option<SystemTime>,
    /// Scopes attached to the token.
    pub(crate) scopes: Vec<String>,
    /// Additional JSON values keyed by name.
    pub(crate) metadata: HashMap<String, serde_json::Value>,
}
166
167impl AccessToken {
168 #[must_use]
170 pub fn new(
171 token: String,
172 expires_at: Option<SystemTime>,
173 scopes: Vec<String>,
174 metadata: HashMap<String, serde_json::Value>,
175 ) -> Self {
176 Self {
177 token,
178 expires_at,
179 scopes,
180 metadata,
181 }
182 }
183
184 #[must_use]
186 pub fn token(&self) -> &str {
187 &self.token
188 }
189
190 #[must_use]
192 pub fn expires_at(&self) -> Option<SystemTime> {
193 self.expires_at
194 }
195
196 #[must_use]
198 pub fn scopes(&self) -> &[String] {
199 &self.scopes
200 }
201
202 #[must_use]
204 pub fn metadata(&self) -> &HashMap<String, serde_json::Value> {
205 &self.metadata
206 }
207}
208
/// Hooks used by request handling to obtain credentials and react to
/// authentication failures.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait AuthMiddleware: Send + Sync {
    /// Extracts a token from the request `headers`.
    ///
    /// Returns `None` when the headers carry no token the implementor
    /// recognizes.
    async fn extract_token(&self, headers: &HashMap<String, String>) -> Option<String>;

    /// Handles an authentication failure described by `error`.
    ///
    /// The implementor decides whether the failure is propagated (by
    /// returning `Err`) or absorbed (by returning `Ok(())`).
    async fn handle_auth_failure(&self, error: McpError) -> McpResult<()>;
}
218
/// Stateless default implementation of [`AuthMiddleware`].
///
/// Reads credentials from the `Authorization` header (`Bearer` and `ApiKey`
/// schemes) or from the `X-API-Key` header, and logs then propagates
/// authentication failures.
#[derive(Debug, Clone)]
pub struct DefaultAuthMiddleware;
222
223#[async_trait]
224impl AuthMiddleware for DefaultAuthMiddleware {
225 async fn extract_token(&self, headers: &HashMap<String, String>) -> Option<String> {
226 if let Some(auth_header) = headers
228 .get("authorization")
229 .or_else(|| headers.get("Authorization"))
230 {
231 if let Some(token) = auth_header.strip_prefix("Bearer ") {
232 return Some(token.to_string());
233 }
234 if let Some(token) = auth_header.strip_prefix("ApiKey ") {
235 return Some(token.to_string());
236 }
237 }
238
239 if let Some(api_key) = headers
241 .get("x-api-key")
242 .or_else(|| headers.get("X-API-Key"))
243 {
244 return Some(api_key.clone());
245 }
246
247 None
248 }
249
250 async fn handle_auth_failure(&self, error: McpError) -> McpResult<()> {
251 tracing::warn!("Authentication failed: {}", error);
252 Err(Box::new(error))
253 }
254}