Skip to main content

rust_mcp_extra/auth_provider/
work_os.rs

1//! # WorkOS AuthKit OAuth2 Provider for MCP Servers
2//!
//! This module implements an OAuth2 provider specifically designed to integrate
4//! [WorkOS AuthKit](https://workos.com/docs/authkit) as the identity
5//! provider (IdP) in an MCP (Model Context Protocol) server ecosystem.
6//!
7//! It enables your MCP server to:
8//! - Expose standard OAuth2/.well-known endpoints
9//! - Serve authorization server metadata (`/.well-known/oauth-authorization-server`)
10//! - Serve protected resource metadata (custom per MCP)
11//! - Verify incoming access tokens using JWKs + UserInfo endpoint validation
12//!
13//! ## Features
14//!
15//! - Zero-downtime token verification with cached JWKs
16//! - Automatic construction of OAuth2 discovery documents
17//! - Built-in CORS support for metadata endpoints
18//! - Pluggable into `rust-mcp-sdk`'s authentication system via the `AuthProvider` trait
19//!
20//! ## Example
21//!
//! ```rust,ignore
//! let auth_provider = WorkOsAuthProvider::new(WorkOSAuthOptions {
//!     // Your AuthKit app domain (found in WorkOS dashboard)
//!     authkit_domain: "https://your-app.authkit.app".to_string(),
//!     // Base URL of your MCP server (used to build protected resource metadata URL)
//!     mcp_server_url: "http://localhost:3000/mcp".to_string(),
//!     // Optional settings; `None` selects the provider's defaults
//!     required_scopes: None,
//!     token_verifier: None,
//!     resource_name: None,
//!     resource_documentation: None,
//! })?;
//!
//! // Register in your MCP server
//! let server = hyper_server::create_server(
//!     server_details,
//!     handler,
//!     HyperServerOptions {
//!         host: "localhost".to_string(),
//!         port: 3000,
//!         auth: Some(Arc::new(auth_provider)),
//!         ..Default::default()
//!     },
//! );
//! ```
42use crate::token_verifier::{
43    GenericOauthTokenVerifier, TokenVerifierOptions, VerificationStrategies,
44};
45use async_trait::async_trait;
46use bytes::Bytes;
47use http::{header::CONTENT_TYPE, StatusCode};
48use http_body_util::{BodyExt, Full};
49use rust_mcp_sdk::{
50    auth::{
51        create_discovery_endpoints, AuthInfo, AuthMetadataBuilder, AuthProvider,
52        AuthenticationError, AuthorizationServerMetadata, OauthEndpoint,
53        OauthProtectedResourceMetadata, OauthTokenVerifier,
54    },
55    error::McpSdkError,
56    mcp_http::{middleware::CorsMiddleware, GenericBody, GenericBodyExt, Middleware},
57    mcp_server::{
58        error::{TransportServerError, TransportServerResult},
59        join_url, McpAppState,
60    },
61};
62use std::{collections::HashMap, sync::Arc, vec};
63
/// Scopes advertised in the OAuth2 metadata when the caller does not supply
/// `required_scopes` in [`WorkOSAuthOptions`].
static SCOPES_SUPPORTED: &[&str] = &["email", "offline_access", "openid", "profile"];
65
66/// Configuration options for the WorkOS AuthKit OAuth provider.
67pub struct WorkOSAuthOptions<'a> {
68    pub authkit_domain: String,
69    pub mcp_server_url: String,
70    pub required_scopes: Option<Vec<&'a str>>,
71    pub token_verifier: Option<Box<dyn OauthTokenVerifier>>,
72    pub resource_name: Option<String>,
73    pub resource_documentation: Option<String>,
74}
75
76/// WorkOS AuthKit integration implementing `AuthProvider` for MCP servers.
77///
78/// This provider makes your MCP server compatible with clients that expect standard
79/// OAuth2 authorization server and protected resource discovery endpoints when using
80/// WorkOS AuthKit as the identity provider.
81pub struct WorkOsAuthProvider {
82    auth_server_meta: AuthorizationServerMetadata,
83    protected_resource_meta: OauthProtectedResourceMetadata,
84    endpoint_map: HashMap<String, OauthEndpoint>,
85    protected_resource_metadata_url: String,
86    token_verifier: Box<dyn OauthTokenVerifier>,
87}
88
89impl WorkOsAuthProvider {
90    /// Creates a new `WorkOsAuthProvider` instance.
91    ///
92    /// This performs:
93    /// - Validation and parsing of URLs
94    /// - Construction of OAuth2 metadata documents
95    /// - Setup of token verification using JWKs and UserInfo endpoint
96    ///
97    /// /// # Example
98    ///
99    /// ```rust,ignore
100    /// use rust_mcp_extra::auth_provider::work_os::{WorkOSAuthOptions, WorkOsAuthProvider};
101    ///
102    /// let auth_provider = WorkOsAuthProvider::new(WorkOSAuthOptions {
103    ///    authkit_domain: "https://your-app.authkit.app".to_string(),
104    ///    mcp_server_url: "http://localhost:3000/mcp".to_string(),
105    /// })?;
106    ///
107    pub fn new(mut options: WorkOSAuthOptions) -> Result<Self, McpSdkError> {
108        let (endpoint_map, protected_resource_metadata_url) =
109            create_discovery_endpoints(&options.mcp_server_url)?;
110
111        let required_scopes = options.required_scopes.take();
112        let scopes_supported = required_scopes.clone().unwrap_or(SCOPES_SUPPORTED.to_vec());
113
114        let mut builder = AuthMetadataBuilder::new(&options.mcp_server_url)
115            .issuer(&options.authkit_domain)
116            .authorization_servers(vec![&options.authkit_domain])
117            .authorization_endpoint("/oauth2/authorize")
118            .introspection_endpoint("/oauth2/introspection")
119            .registration_endpoint("/oauth2/register")
120            .token_endpoint("/oauth2/token")
121            .jwks_uri("/oauth2/jwks")
122            .scopes_supported(scopes_supported);
123
124        if let Some(scopes) = required_scopes {
125            builder = builder.reqquired_scopes(scopes)
126        }
127        if let Some(resource_name) = options.resource_name.as_ref() {
128            builder = builder.resource_name(resource_name)
129        }
130        if let Some(resource_documentation) = options.resource_documentation.as_ref() {
131            builder = builder.service_documentation(resource_documentation)
132        }
133
134        let (auth_server_meta, protected_resource_meta) = builder.build()?;
135
136        let Some(jwks_uri) = auth_server_meta.jwks_uri.as_ref().map(|s| s.to_string()) else {
137            return Err(McpSdkError::Internal {
138                description: "jwks_uri is not defined!".to_string(),
139            });
140        };
141
142        let userinfo_uri = join_url(&auth_server_meta.issuer, "oauth2/userinfo")
143            .map_err(|err| McpSdkError::Internal {
144                description: format!("invalid userinfo url :{err}"),
145            })?
146            .to_string();
147
148        let token_verifier: Box<dyn OauthTokenVerifier> = match options.token_verifier {
149            Some(verifier) => verifier,
150            None => Box::new(GenericOauthTokenVerifier::new(TokenVerifierOptions {
151                strategies: vec![
152                    VerificationStrategies::JWKs { jwks_uri },
153                    VerificationStrategies::UserInfo { userinfo_uri },
154                ],
155                validate_audience: None,
156                validate_issuer: Some(options.authkit_domain.clone()),
157                cache_capacity: None,
158            })?),
159        };
160
161        Ok(Self {
162            endpoint_map,
163            protected_resource_metadata_url,
164            token_verifier,
165            auth_server_meta,
166            protected_resource_meta,
167        })
168    }
169
170    /// Helper to build JSON response for authorization server metadata with CORS.
171    fn handle_authorization_server_metadata(
172        response_str: String,
173    ) -> TransportServerResult<http::Response<GenericBody>> {
174        let body = Full::new(Bytes::from(response_str))
175            .map_err(|err| TransportServerError::HttpError(err.to_string()))
176            .boxed();
177        http::Response::builder()
178            .status(StatusCode::OK)
179            .header(CONTENT_TYPE, "application/json")
180            .body(body)
181            .map_err(|err| TransportServerError::HttpError(err.to_string()))
182    }
183
184    /// Helper to build JSON response for protected resource metadata with permissive CORS.
185    fn handle_protected_resource_metadata(
186        response_str: String,
187    ) -> TransportServerResult<http::Response<GenericBody>> {
188        use http_body_util::BodyExt;
189
190        let body = Full::new(Bytes::from(response_str))
191            .map_err(|err| TransportServerError::HttpError(err.to_string()))
192            .boxed();
193        http::Response::builder()
194            .status(StatusCode::OK)
195            .header(CONTENT_TYPE, "application/json")
196            .body(body)
197            .map_err(|err| TransportServerError::HttpError(err.to_string()))
198    }
199}
200
201#[async_trait]
202impl AuthProvider for WorkOsAuthProvider {
203    /// Returns the map of supported OAuth discovery endpoints.
204    fn auth_endpoints(&self) -> Option<&HashMap<String, OauthEndpoint>> {
205        Some(&self.endpoint_map)
206    }
207
208    /// Handles incoming requests to OAuth metadata endpoints.
209    async fn handle_request(
210        &self,
211        request: http::Request<&str>,
212        state: Arc<McpAppState>,
213    ) -> Result<http::Response<GenericBody>, TransportServerError> {
214        let Some(endpoint) = self.endpoint_type(&request) else {
215            return http::Response::builder()
216                .status(StatusCode::NOT_FOUND)
217                .body(GenericBody::empty())
218                .map_err(|err| TransportServerError::HttpError(err.to_string()));
219        };
220
221        // return early if method is not allowed
222        if let Some(response) = self.validate_allowed_methods(endpoint, request.method()) {
223            return Ok(response);
224        }
225
226        match endpoint {
227            OauthEndpoint::AuthorizationServerMetadata => {
228                let json_payload = serde_json::to_string(&self.auth_server_meta)
229                    .map_err(|err| TransportServerError::HttpError(err.to_string()))?;
230                let cors = &CorsMiddleware::default();
231                cors.handle(
232                    request,
233                    state,
234                    Box::new(move |_req, _state| {
235                        Box::pin(
236                            async move { Self::handle_authorization_server_metadata(json_payload) },
237                        )
238                    }),
239                )
240                .await
241            }
242            OauthEndpoint::ProtectedResourceMetadata => {
243                let json_payload = serde_json::to_string(&self.protected_resource_meta)
244                    .map_err(|err| TransportServerError::HttpError(err.to_string()))?;
245                let cors = &CorsMiddleware::default();
246                cors.handle(
247                    request,
248                    state,
249                    Box::new(move |_req, _state| {
250                        Box::pin(
251                            async move { Self::handle_protected_resource_metadata(json_payload) },
252                        )
253                    }),
254                )
255                .await
256            }
257            _ => Ok(GenericBody::create_404_response()),
258        }
259    }
260
261    /// Verifies an access token using JWKs and optional UserInfo validation.
262    ///
263    /// Returns authenticated `AuthInfo` on success.
264    async fn verify_token(&self, access_token: String) -> Result<AuthInfo, AuthenticationError> {
265        self.token_verifier.verify_token(access_token).await
266    }
267
268    /// Returns the full URL to the protected resource metadata document.
269    fn protected_resource_metadata_url(&self) -> Option<&str> {
270        Some(self.protected_resource_metadata_url.as_str())
271    }
272}