rust_mcp_sdk/auth/auth_provider/remote_auth_provider.rs

1use crate::{
2    auth::{
3        create_protected_resource_metadata_url, AuthInfo, AuthProvider, AuthenticationError,
4        AuthorizationServerMetadata, OauthEndpoint, OauthProtectedResourceMetadata,
5        OauthTokenVerifier, WELL_KNOWN_OAUTH_AUTHORIZATION_SERVER,
6    },
7    mcp_http::{
8        middleware::CorsMiddleware, url_base, GenericBody, GenericBodyExt, McpAppState, Middleware,
9    },
10    mcp_server::error::{TransportServerError, TransportServerResult},
11};
12use async_trait::async_trait;
13use bytes::Bytes;
14use http::{header::CONTENT_TYPE, StatusCode};
15use http_body_util::{BodyExt, Full};
16use reqwest::Client;
17use std::{collections::HashMap, sync::Arc};
18
19/// Represents a **Remote OAuth authentication provider** integrated with the MCP server.
20/// This struct defines how the MCP server interacts with an external identity provider
21/// that supports **Dynamic Client Registration (DCR)**.
22/// The [`RemoteAuthProvider`] enables enterprise-grade authentication by leveraging
23/// external OAuth infrastructure, while maintaining secure token verification and
24/// identity validation within the MCP server.
25pub struct RemoteAuthProvider {
26    auth_server_meta: AuthorizationServerMetadata,
27    protected_resource_meta: OauthProtectedResourceMetadata,
28    token_verifier: Box<dyn OauthTokenVerifier>,
29    endpoint_map: HashMap<String, OauthEndpoint>,
30    required_scopes: Option<Vec<String>>,
31    protected_resource_metadata_url: String,
32}
33
34impl RemoteAuthProvider {
35    pub fn new(
36        auth_server_meta: AuthorizationServerMetadata,
37        protected_resource_meta: OauthProtectedResourceMetadata,
38        token_verifier: Box<dyn OauthTokenVerifier>,
39        required_scopes: Option<Vec<String>>,
40    ) -> Self {
41        let mut endpoint_map = HashMap::new();
42        endpoint_map.insert(
43            WELL_KNOWN_OAUTH_AUTHORIZATION_SERVER.to_string(),
44            OauthEndpoint::AuthorizationServerMetadata,
45        );
46
47        let resource_url = &protected_resource_meta.resource;
48        let relative_url = create_protected_resource_metadata_url(resource_url.path());
49        let base_url = url_base(resource_url);
50        let protected_resource_metadata_url =
51            format!("{}{relative_url}", base_url.trim_end_matches('/'));
52
53        endpoint_map.insert(relative_url, OauthEndpoint::ProtectedResourceMetadata);
54
55        Self {
56            auth_server_meta,
57            protected_resource_meta,
58            token_verifier,
59            endpoint_map,
60            required_scopes,
61            protected_resource_metadata_url,
62        }
63    }
64
65    pub async fn with_remote_metadata_url(
66        authorization_server_metadata_url: &str,
67        protected_resource_meta: OauthProtectedResourceMetadata,
68        token_verifier: Box<dyn OauthTokenVerifier>,
69        required_scopes: Option<Vec<String>>,
70    ) -> Result<Self, reqwest::Error> {
71        let client = Client::new();
72
73        let auth_server_meta = client
74            .get(authorization_server_metadata_url)
75            .send()
76            .await?
77            .json::<AuthorizationServerMetadata>()
78            .await?;
79
80        Ok(Self::new(
81            auth_server_meta,
82            protected_resource_meta,
83            token_verifier,
84            required_scopes,
85        ))
86    }
87
88    fn handle_authorization_server_metadata(
89        response_str: String,
90    ) -> TransportServerResult<http::Response<GenericBody>> {
91        let body = Full::new(Bytes::from(response_str))
92            .map_err(|err| TransportServerError::HttpError(err.to_string()))
93            .boxed();
94        http::Response::builder()
95            .status(StatusCode::OK)
96            .header(CONTENT_TYPE, "application/json")
97            .body(body)
98            .map_err(|err| TransportServerError::HttpError(err.to_string()))
99    }
100
101    fn handle_protected_resource_metadata(
102        response_str: String,
103    ) -> TransportServerResult<http::Response<GenericBody>> {
104        use http_body_util::BodyExt;
105
106        let body = Full::new(Bytes::from(response_str))
107            .map_err(|err| TransportServerError::HttpError(err.to_string()))
108            .boxed();
109        http::Response::builder()
110            .status(StatusCode::OK)
111            .header(CONTENT_TYPE, "application/json")
112            .body(body)
113            .map_err(|err| TransportServerError::HttpError(err.to_string()))
114    }
115}
116
117#[async_trait]
118impl AuthProvider for RemoteAuthProvider {
119    fn protected_resource_metadata_url(&self) -> Option<&str> {
120        Some(self.protected_resource_metadata_url.as_str())
121    }
122
123    async fn verify_token(&self, access_token: String) -> Result<AuthInfo, AuthenticationError> {
124        self.token_verifier.verify_token(access_token).await
125    }
126
127    fn required_scopes(&self) -> Option<&Vec<String>> {
128        self.required_scopes.as_ref()
129    }
130
131    async fn handle_request(
132        &self,
133        request: http::Request<&str>,
134        state: Arc<McpAppState>,
135    ) -> Result<http::Response<GenericBody>, TransportServerError> {
136        let Some(endpoint) = self.endpoint_type(&request) else {
137            return http::Response::builder()
138                .status(StatusCode::NOT_FOUND)
139                .body(GenericBody::empty())
140                .map_err(|err| TransportServerError::HttpError(err.to_string()));
141        };
142
143        // return early if method is not allowed
144        if let Some(response) = self.validate_allowed_methods(endpoint, request.method()) {
145            return Ok(response);
146        }
147
148        match endpoint {
149            OauthEndpoint::AuthorizationServerMetadata => {
150                let json_payload = serde_json::to_string(&self.auth_server_meta)
151                    .map_err(|err| TransportServerError::HttpError(err.to_string()))?;
152                let cors = &CorsMiddleware::default();
153                cors.handle(
154                    request,
155                    state,
156                    Box::new(move |_req, _state| {
157                        Box::pin(
158                            async move { Self::handle_authorization_server_metadata(json_payload) },
159                        )
160                    }),
161                )
162                .await
163            }
164            OauthEndpoint::ProtectedResourceMetadata => {
165                let json_payload = serde_json::to_string(&self.protected_resource_meta)
166                    .map_err(|err| TransportServerError::HttpError(err.to_string()))?;
167
168                let cors = &CorsMiddleware::default();
169                cors.handle(
170                    request,
171                    state,
172                    Box::new(move |_req, _state| {
173                        Box::pin(
174                            async move { Self::handle_protected_resource_metadata(json_payload) },
175                        )
176                    }),
177                )
178                .await
179            }
180            _ => Ok(GenericBody::create_404_response()),
181        }
182    }
183
184    fn auth_endpoints(&self) -> Option<&HashMap<String, OauthEndpoint>> {
185        Some(&self.endpoint_map)
186    }
187}