1use std::collections::HashMap;
6use std::sync::Arc;
7#[cfg(feature = "dpop")]
8use std::time::Duration;
9
10use secrecy::{ExposeSecret, SecretString};
11use serde::{Deserialize, Serialize};
12use tokio::sync::RwLock;
13
14use turbomcp_protocol::{Error as McpError, Result as McpResult};
15
16#[cfg(feature = "dpop")]
18use super::dpop::DpopAlgorithm;
19
/// Top-level authentication settings: the configured providers plus the
/// authorization (role/permission) settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthConfig {
    /// Whether authentication is turned on.
    /// NOTE(review): enforcement of this flag happens in the code that consumes the
    /// config, not in this file.
    pub enabled: bool,
    /// Provider entries; each carries its own `enabled` flag and `priority`.
    pub providers: Vec<AuthProviderConfig>,
    /// Role-based authorization settings.
    pub authorization: AuthorizationConfig,
}
36
/// Configuration entry for a single authentication provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthProviderConfig {
    /// Identifier for this provider entry.
    pub name: String,
    /// Which kind of provider this entry configures.
    pub provider_type: AuthProviderType,
    /// Free-form, provider-specific settings as raw JSON values.
    /// NOTE(review): the expected keys are not defined in this file — confirm per provider.
    pub settings: HashMap<String, serde_json::Value>,
    /// Whether this provider entry is enabled.
    pub enabled: bool,
    /// Ordering priority of this provider.
    /// NOTE(review): whether lower or higher values win is not visible here — confirm with
    /// the code that orders providers.
    pub priority: u32,
}
51
/// Kind of authentication provider described by an [`AuthProviderConfig`] entry.
///
/// No `rename_all` is applied, so serde uses the variant names as written
/// (`"OAuth2"`, `"ApiKey"`, `"Jwt"`, `"Custom"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum AuthProviderType {
    /// OAuth 2.0 provider.
    OAuth2,
    /// API-key provider.
    ApiKey,
    /// JWT provider.
    Jwt,
    /// Custom provider.
    /// NOTE(review): how a custom provider is resolved is not defined in this file.
    Custom,
}
64
/// Security profile used by `OAuth2Config::security_level` in this file.
///
/// `Standard` is the default.
/// NOTE(review): the concrete behavior of `Enhanced` and `Maximum` is implemented
/// elsewhere (possibly enabling stricter checks such as DPoP) — confirm against the consumers.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
pub enum SecurityLevel {
    /// Baseline level; the `Default` value.
    #[default]
    Standard,
    /// Stricter than `Standard`; exact checks are defined by the consumers.
    Enhanced,
    /// Strictest level; exact checks are defined by the consumers.
    Maximum,
}
76
#[cfg(feature = "dpop")]
/// DPoP proof configuration; only compiled with the `dpop` feature.
///
/// `key_algorithm` is required when deserializing. The other fields fall back to serde
/// defaults: 60 s proof lifetime, 300 s clock-skew tolerance, and in-memory key storage.
///
/// The `Duration` fields use serde's standard representation for `std::time::Duration`
/// (an object with `secs` and `nanos`), not a bare number of seconds.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DpopConfig {
    /// Signing algorithm for DPoP keys.
    pub key_algorithm: DpopAlgorithm,
    /// Lifetime of a proof; defaults to `default_proof_lifetime()` (60 s).
    #[serde(default = "default_proof_lifetime")]
    pub proof_lifetime: Duration,
    /// Allowed clock skew; defaults to `default_clock_skew()` (300 s).
    #[serde(default = "default_clock_skew")]
    pub clock_skew_tolerance: Duration,
    /// Key storage backend; defaults to `DpopKeyStorageConfig::Memory`.
    #[serde(default)]
    pub key_storage: DpopKeyStorageConfig,
}
93
94#[cfg(feature = "dpop")]
/// Serde default for `DpopConfig::proof_lifetime`: one minute.
fn default_proof_lifetime() -> Duration {
    const PROOF_LIFETIME_SECS: u64 = 60;
    Duration::from_secs(PROOF_LIFETIME_SECS)
}
98
99#[cfg(feature = "dpop")]
/// Serde default for `DpopConfig::clock_skew_tolerance`: five minutes.
fn default_clock_skew() -> Duration {
    const MINUTES: u64 = 5;
    Duration::from_secs(MINUTES * 60)
}
103
#[cfg(feature = "dpop")]
/// Storage backend selection for DPoP keys; defaults to `Memory`.
///
/// Uses serde's default externally tagged enum representation, e.g. `"Memory"` or
/// `{"Redis": {"url": "..."}}`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub enum DpopKeyStorageConfig {
    /// In-process memory storage (the default).
    #[default]
    Memory,
    /// Redis-backed storage.
    Redis {
        /// Redis connection URL.
        url: String,
    },
    /// Hardware-security-module storage.
    Hsm {
        /// Opaque HSM-specific configuration as raw JSON.
        /// NOTE(review): the schema is not defined in this file.
        config: serde_json::Value,
    },
}
122
#[cfg(feature = "dpop")]
impl Default for DpopConfig {
    /// ES256 keys held in memory, with the same proof lifetime (60 s) and clock-skew
    /// tolerance (300 s) that serde uses for missing fields.
    fn default() -> Self {
        let key_storage = DpopKeyStorageConfig::default();
        let proof_lifetime = default_proof_lifetime();
        let clock_skew_tolerance = default_clock_skew();

        Self {
            key_algorithm: DpopAlgorithm::ES256,
            proof_lifetime,
            clock_skew_tolerance,
            key_storage,
        }
    }
}
134
/// Role-based access control settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthorizationConfig {
    /// Whether RBAC checks are enabled.
    pub rbac_enabled: bool,
    /// Default role names.
    /// NOTE(review): presumably assigned when no explicit roles are present — confirm
    /// with the consumer.
    pub default_roles: Vec<String>,
    /// Role inheritance map.
    /// NOTE(review): presumably role name → roles it inherits from; verify the direction.
    pub inheritance_rules: HashMap<String, Vec<String>>,
    /// Resource permission map.
    /// NOTE(review): presumably resource identifier → permission names; verify.
    pub resource_permissions: HashMap<String, Vec<String>>,
}
147
/// Settings for one OAuth 2.0 client.
///
/// These fields are required when deserializing: `client_id`, `client_secret`,
/// `auth_url`, `token_url`, `redirect_uri`, `scopes`, `flow_type` and
/// `additional_params`. Every other field has a serde default.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuth2Config {
    /// OAuth client identifier.
    pub client_id: String,
    /// Client secret. It is held as a `SecretString` in memory but (de)serialized as a
    /// plain string via `serialize_secret` / `deserialize_secret`. Serialized output
    /// therefore contains the secret in cleartext.
    #[serde(
        serialize_with = "serialize_secret",
        deserialize_with = "deserialize_secret"
    )]
    pub client_secret: SecretString,
    /// Authorization endpoint URL.
    pub auth_url: String,
    /// Token endpoint URL.
    pub token_url: String,
    /// Optional token revocation endpoint URL; `None` when absent from the input.
    #[serde(default)]
    pub revocation_url: Option<String>,
    /// Redirect URI for this client.
    pub redirect_uri: String,
    /// Scopes to request.
    pub scopes: Vec<String>,
    /// OAuth 2.0 flow this client uses.
    pub flow_type: OAuth2FlowType,
    /// Extra key/value parameters.
    /// NOTE(review): presumably appended to authorization requests — confirm at the use site.
    pub additional_params: HashMap<String, String>,
    /// Security profile; defaults to `SecurityLevel::Standard`.
    #[serde(default)]
    pub security_level: SecurityLevel,
    /// Optional DPoP settings (`dpop` feature only); `None` when absent.
    #[cfg(feature = "dpop")]
    #[serde(default)]
    pub dpop_config: Option<DpopConfig>,
    /// MCP resource URI associated with this client; `None` when absent.
    /// NOTE(review): presumably sent as an RFC 8707 `resource` indicator — confirm.
    #[serde(default)]
    pub mcp_resource_uri: Option<String>,
    /// Whether resource indicators are added automatically. Defaults to `true` via
    /// `default_auto_resource_indicators`.
    #[serde(default = "default_auto_resource_indicators")]
    pub auto_resource_indicators: bool,
}
190
191fn serialize_secret<S>(secret: &SecretString, serializer: S) -> Result<S::Ok, S::Error>
193where
194 S: serde::Serializer,
195{
196 serializer.serialize_str(secret.expose_secret())
197}
198
199fn deserialize_secret<'de, D>(deserializer: D) -> Result<SecretString, D::Error>
201where
202 D: serde::Deserializer<'de>,
203{
204 let s: String = serde::Deserialize::deserialize(deserializer)?;
205 Ok(SecretString::new(s))
206}
207
/// Serde default for `OAuth2Config::auto_resource_indicators`: enabled.
fn default_auto_resource_indicators() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    ENABLED_BY_DEFAULT
}
212
/// OAuth 2.0 flow selected by `OAuth2Config::flow_type`.
///
/// No `rename_all` is applied, so serde uses the variant names as written.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum OAuth2FlowType {
    /// Authorization code grant.
    AuthorizationCode,
    /// Client credentials grant.
    ClientCredentials,
    /// Device authorization grant.
    DeviceCode,
    /// Implicit grant.
    /// NOTE(review): the OAuth 2.0 Security Best Current Practice advises against this grant.
    Implicit,
}
225
/// Result of starting an OAuth 2.0 authorization.
///
/// NOTE(review): which optional fields are populated presumably depends on the flow.
/// Likely `code_verifier` for PKCE authorization-code flows, and `device_code` /
/// `user_code` / `verification_uri` for the device flow. Confirm against the code that
/// builds this value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuth2AuthResult {
    /// URL to send the user to.
    pub auth_url: String,
    /// `state` value associated with this authorization attempt.
    pub state: String,
    /// Code verifier, when one was generated.
    pub code_verifier: Option<String>,
    /// Device code, when the device flow is used.
    pub device_code: Option<String>,
    /// User code to display, when the device flow is used.
    pub user_code: Option<String>,
    /// Verification URI to display, when the device flow is used.
    pub verification_uri: Option<String>,
}
242
/// Protected-resource metadata for one MCP resource.
///
/// `None` optionals are omitted when serialized. Entries in `additional_metadata` are
/// flattened into the top-level JSON object; on deserialization, unrecognized keys are
/// collected into it.
///
/// NOTE(review): RFC 9728 names the field `authorization_servers` (a JSON array), but this
/// struct serializes a single string under `authorization_server`. Confirm which shape the
/// consuming clients expect.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtectedResourceMetadata {
    /// Resource URI this metadata describes.
    pub resource: String,
    /// Authorization server for this resource.
    pub authorization_server: String,
    /// Scopes associated with the resource.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scopes_supported: Option<Vec<String>>,
    /// Accepted bearer-token transports.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bearer_methods_supported: Option<Vec<BearerTokenMethod>>,
    /// Human-readable documentation reference.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub resource_documentation: Option<String>,
    /// Extra top-level metadata members (flattened).
    #[serde(flatten)]
    pub additional_metadata: HashMap<String, serde_json::Value>,
}
263
/// Bearer-token transport method.
///
/// Serialized in lowercase: `"header"`, `"query"`, `"body"`. The default is `Header`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "lowercase")]
pub enum BearerTokenMethod {
    /// Token in the `Authorization` header (the default).
    #[default]
    Header,
    /// Token in the URI query string.
    Query,
    /// Token in the request body.
    Body,
}
276
/// In-memory registry of protected-resource metadata, keyed by full resource URI.
///
/// The map lives behind `Arc<tokio::sync::RwLock<..>>`. Cloned registries therefore share
/// the same map and see each other's writes.
#[derive(Debug, Clone)]
pub struct McpResourceRegistry {
    /// Resource URI → metadata.
    resources: Arc<RwLock<HashMap<String, ProtectedResourceMetadata>>>,
    /// Copied into `authorization_server` of every resource registered through this registry.
    default_auth_server: String,
    /// Prefix that resource ids are joined onto to build resource URIs.
    base_resource_uri: String,
}
287
288impl McpResourceRegistry {
289 #[must_use]
291 pub fn new(base_resource_uri: String, auth_server: String) -> Self {
292 Self {
293 resources: Arc::new(RwLock::new(HashMap::new())),
294 default_auth_server: auth_server,
295 base_resource_uri,
296 }
297 }
298
299 pub async fn register_resource(
301 &self,
302 resource_id: &str,
303 scopes: Vec<String>,
304 documentation: Option<String>,
305 ) -> McpResult<()> {
306 let resource_uri = format!(
307 "{}/{}",
308 self.base_resource_uri.trim_end_matches('/'),
309 resource_id
310 );
311
312 let metadata = ProtectedResourceMetadata {
313 resource: resource_uri.clone(),
314 authorization_server: self.default_auth_server.clone(),
315 scopes_supported: Some(scopes),
316 bearer_methods_supported: Some(vec![
317 BearerTokenMethod::Header, BearerTokenMethod::Body, ]),
320 resource_documentation: documentation,
321 additional_metadata: HashMap::new(),
322 };
323
324 self.resources.write().await.insert(resource_uri, metadata);
325 Ok(())
326 }
327
328 pub async fn get_resource_metadata(
330 &self,
331 resource_uri: &str,
332 ) -> Option<ProtectedResourceMetadata> {
333 self.resources.read().await.get(resource_uri).cloned()
334 }
335
336 pub async fn list_resources(&self) -> Vec<String> {
338 self.resources.read().await.keys().cloned().collect()
339 }
340
341 pub async fn generate_well_known_metadata(&self) -> HashMap<String, ProtectedResourceMetadata> {
343 self.resources.read().await.clone()
344 }
345
346 pub async fn validate_scope_for_resource(
348 &self,
349 resource_uri: &str,
350 token_scopes: &[String],
351 ) -> McpResult<bool> {
352 if let Some(metadata) = self.get_resource_metadata(resource_uri).await {
353 if let Some(required_scopes) = metadata.scopes_supported {
354 let has_required_scope = required_scopes
356 .iter()
357 .any(|scope| token_scopes.contains(scope));
358 Ok(has_required_scope)
359 } else {
360 Ok(true)
362 }
363 } else {
364 Err(McpError::validation(format!(
365 "Unknown resource: {}",
366 resource_uri
367 )))
368 }
369 }
370}
371
/// Client metadata sent to an RFC 7591 dynamic client registration endpoint.
///
/// Every field is optional and omitted from the JSON body when `None`.
/// `DynamicClientRegistration::register_client` fills unset `application_type`,
/// `grant_types` and `response_types` with its defaults before sending.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientRegistrationRequest {
    /// Redirect URIs for the client.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub redirect_uris: Option<Vec<String>>,
    /// OAuth response types, e.g. `"code"`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_types: Option<Vec<String>>,
    /// OAuth grant types, e.g. `"authorization_code"`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub grant_types: Option<Vec<String>>,
    /// Application type (`"web"` or `"native"`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub application_type: Option<ApplicationType>,
    /// Human-readable client name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_name: Option<String>,
    /// Client home page URL.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_uri: Option<String>,
    /// Client logo URL.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logo_uri: Option<String>,
    /// Space-separated scope string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scope: Option<String>,
    /// Contact addresses for the client.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub contacts: Option<Vec<String>>,
    /// Terms-of-service URL.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tos_uri: Option<String>,
    /// Privacy policy URL.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub policy_uri: Option<String>,
    /// Identifier of the client software.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub software_id: Option<String>,
    /// Version of the client software.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub software_version: Option<String>,
}
415
/// Successful response from an RFC 7591 dynamic client registration endpoint.
///
/// Only `client_id` is required when deserializing. Every other field is optional and
/// omitted from serialized output when `None`.
///
/// NOTE(review): `client_secret` and `registration_access_token` are plain `String`s, and
/// the derived `Debug` prints them. Avoid logging this value with `{:?}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientRegistrationResponse {
    /// Issued client identifier.
    pub client_id: String,
    /// Issued client secret, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_secret: Option<String>,
    /// Token for the client configuration endpoint, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub registration_access_token: Option<String>,
    /// Client configuration endpoint URL, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub registration_client_uri: Option<String>,
    /// Issue time of `client_id` (RFC 7591: seconds since the Unix epoch).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_id_issued_at: Option<i64>,
    /// Expiry of `client_secret` (RFC 7591: seconds since the Unix epoch; 0 = never).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_secret_expires_at: Option<i64>,
    /// Registered redirect URIs, as echoed by the server.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub redirect_uris: Option<Vec<String>>,
    /// Registered response types, as echoed by the server.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_types: Option<Vec<String>>,
    /// Registered grant types, as echoed by the server.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub grant_types: Option<Vec<String>>,
    /// Registered application type, as echoed by the server.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub application_type: Option<ApplicationType>,
    /// Registered client name, as echoed by the server.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_name: Option<String>,
    /// Registered scope string, as echoed by the server.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scope: Option<String>,
}
455
/// Client application type; serialized as `"web"` or `"native"`. The default is `Web`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "lowercase")]
pub enum ApplicationType {
    /// Web application (the default).
    #[default]
    Web,
    /// Native application.
    Native,
}
466
/// Error body returned by an RFC 7591 registration endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientRegistrationError {
    /// Machine-readable error code.
    pub error: ClientRegistrationErrorCode,
    /// Optional human-readable description; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error_description: Option<String>,
}
476
/// RFC 7591 §3.2.2 registration error codes, serialized in snake_case
/// (e.g. `"invalid_redirect_uri"`).
///
/// There is no catch-all variant, so any other code string fails to deserialize.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ClientRegistrationErrorCode {
    /// `invalid_redirect_uri`
    InvalidRedirectUri,
    /// `invalid_client_metadata`
    InvalidClientMetadata,
    /// `invalid_software_statement`
    InvalidSoftwareStatement,
    /// `unapproved_software_statement`
    UnapprovedSoftwareStatement,
}
490
/// Client for an RFC 7591 dynamic client registration endpoint.
#[derive(Debug, Clone)]
pub struct DynamicClientRegistration {
    /// Endpoint URL that registration requests are POSTed to.
    registration_endpoint: String,
    /// Used when a request leaves `application_type` unset.
    default_application_type: ApplicationType,
    /// Used when a request leaves `grant_types` unset.
    default_grant_types: Vec<String>,
    /// Used when a request leaves `response_types` unset.
    default_response_types: Vec<String>,
    /// HTTP client used for the registration call.
    client: reqwest::Client,
}
505
506impl DynamicClientRegistration {
507 #[must_use]
509 pub fn new(registration_endpoint: String) -> Self {
510 Self {
511 registration_endpoint,
512 default_application_type: ApplicationType::Web,
513 default_grant_types: vec!["authorization_code".to_string()],
514 default_response_types: vec!["code".to_string()],
515 client: reqwest::Client::new(),
516 }
517 }
518
519 pub async fn register_client(
521 &self,
522 request: ClientRegistrationRequest,
523 ) -> McpResult<ClientRegistrationResponse> {
524 let mut registration_request = request;
526
527 if registration_request.application_type.is_none() {
529 registration_request.application_type = Some(self.default_application_type.clone());
530 }
531 if registration_request.grant_types.is_none() {
532 registration_request.grant_types = Some(self.default_grant_types.clone());
533 }
534 if registration_request.response_types.is_none() {
535 registration_request.response_types = Some(self.default_response_types.clone());
536 }
537
538 let response = self
540 .client
541 .post(&self.registration_endpoint)
542 .header("Content-Type", "application/json")
543 .json(®istration_request)
544 .send()
545 .await
546 .map_err(|e| McpError::validation(format!("Registration request failed: {}", e)))?;
547
548 if response.status().is_success() {
550 let registration_response: ClientRegistrationResponse =
551 response.json().await.map_err(|e| {
552 McpError::validation(format!("Invalid registration response: {}", e))
553 })?;
554 Ok(registration_response)
555 } else {
556 let error_response: ClientRegistrationError = response
558 .json()
559 .await
560 .map_err(|e| McpError::validation(format!("Invalid error response: {}", e)))?;
561 Err(McpError::validation(format!(
562 "Client registration failed: {} - {}",
563 error_response.error as u32,
564 error_response.error_description.unwrap_or_default()
565 )))
566 }
567 }
568
569 #[must_use]
571 pub fn create_mcp_client_request(
572 client_name: &str,
573 redirect_uris: Vec<String>,
574 mcp_server_uri: &str,
575 ) -> ClientRegistrationRequest {
576 ClientRegistrationRequest {
577 redirect_uris: Some(redirect_uris),
578 response_types: Some(vec!["code".to_string()]),
579 grant_types: Some(vec!["authorization_code".to_string()]),
580 application_type: Some(ApplicationType::Web),
581 client_name: Some(format!("MCP Client: {}", client_name)),
582 client_uri: Some(mcp_server_uri.to_string()),
583 scope: Some(
584 "mcp:tools:read mcp:tools:execute mcp:resources:read mcp:prompts:read".to_string(),
585 ),
586 software_id: Some("turbomcp".to_string()),
587 software_version: Some(env!("CARGO_PKG_VERSION").to_string()),
588 logo_uri: None,
589 contacts: None,
590 tos_uri: None,
591 policy_uri: None,
592 }
593 }
594}
595
596#[derive(Debug, Clone, Serialize, Deserialize)]
598pub struct DeviceAuthorizationResponse {
599 pub device_code: String,
601 pub user_code: String,
603 pub verification_uri: String,
605 pub verification_uri_complete: Option<String>,
607 pub expires_in: u64,
609 pub interval: u64,
611}
612
/// Provider-specific OAuth 2.0 settings.
///
/// Derives only `Debug`/`Clone`, so it is not (de)serializable.
#[derive(Debug, Clone)]
pub struct ProviderConfig {
    /// Identity provider this configuration targets.
    pub provider_type: ProviderType,
    /// Default scopes for this provider.
    pub default_scopes: Vec<String>,
    /// Token refresh strategy for this provider.
    pub refresh_behavior: RefreshBehavior,
    /// Optional user-info endpoint URL.
    pub userinfo_endpoint: Option<String>,
    /// Extra provider-specific parameters.
    /// NOTE(review): where these are sent is not visible in this file.
    pub additional_params: HashMap<String, String>,
}
627
/// Identity providers with named variants, plus `Generic` and a named `Custom` variant.
#[derive(Debug, Clone, PartialEq)]
pub enum ProviderType {
    /// Google.
    Google,
    /// Microsoft.
    Microsoft,
    /// GitHub.
    GitHub,
    /// GitLab.
    GitLab,
    /// Apple.
    Apple,
    /// Okta.
    Okta,
    /// Auth0.
    Auth0,
    /// Keycloak.
    Keycloak,
    /// Provider without provider-specific handling.
    /// NOTE(review): confirm how `Generic` differs from `Custom` at the use sites.
    Generic,
    /// Any other provider, identified by the contained name.
    Custom(String),
}
652
/// How access tokens are refreshed for a provider.
///
/// NOTE(review): the exact semantics live in the token-management code, not here.
/// Presumably `Proactive` refreshes before expiry and `Reactive` refreshes after
/// expiry or rejection — confirm.
#[derive(Debug, Clone)]
pub enum RefreshBehavior {
    /// Refresh ahead of need (presumably before expiry — confirm).
    Proactive,
    /// Refresh on demand (presumably after expiry or rejection — confirm).
    Reactive,
    /// Refresh policy defined elsewhere.
    Custom,
}