rust_mcp_sdk/auth/auth_provider/
remote_auth_provider.rs1use crate::{
2 auth::{
3 create_protected_resource_metadata_url, AuthInfo, AuthProvider, AuthenticationError,
4 AuthorizationServerMetadata, OauthEndpoint, OauthProtectedResourceMetadata,
5 OauthTokenVerifier, WELL_KNOWN_OAUTH_AUTHORIZATION_SERVER,
6 },
7 mcp_http::{
8 middleware::CorsMiddleware, url_base, GenericBody, GenericBodyExt, McpAppState, Middleware,
9 },
10 mcp_server::error::{TransportServerError, TransportServerResult},
11};
12use async_trait::async_trait;
13use bytes::Bytes;
14use http::{header::CONTENT_TYPE, StatusCode};
15use http_body_util::{BodyExt, Full};
16use reqwest::Client;
17use std::{collections::HashMap, sync::Arc};
18
19pub struct RemoteAuthProvider {
26 auth_server_meta: AuthorizationServerMetadata,
27 protected_resource_meta: OauthProtectedResourceMetadata,
28 token_verifier: Box<dyn OauthTokenVerifier>,
29 endpoint_map: HashMap<String, OauthEndpoint>,
30 required_scopes: Option<Vec<String>>,
31 protected_resource_metadata_url: String,
32}
33
34impl RemoteAuthProvider {
35 pub fn new(
36 auth_server_meta: AuthorizationServerMetadata,
37 protected_resource_meta: OauthProtectedResourceMetadata,
38 token_verifier: Box<dyn OauthTokenVerifier>,
39 required_scopes: Option<Vec<String>>,
40 ) -> Self {
41 let mut endpoint_map = HashMap::new();
42 endpoint_map.insert(
43 WELL_KNOWN_OAUTH_AUTHORIZATION_SERVER.to_string(),
44 OauthEndpoint::AuthorizationServerMetadata,
45 );
46
47 let resource_url = &protected_resource_meta.resource;
48 let relative_url = create_protected_resource_metadata_url(resource_url.path());
49 let base_url = url_base(resource_url);
50 let protected_resource_metadata_url =
51 format!("{}{relative_url}", base_url.trim_end_matches('/'));
52
53 endpoint_map.insert(relative_url, OauthEndpoint::ProtectedResourceMetadata);
54
55 Self {
56 auth_server_meta,
57 protected_resource_meta,
58 token_verifier,
59 endpoint_map,
60 required_scopes,
61 protected_resource_metadata_url,
62 }
63 }
64
65 pub async fn with_remote_metadata_url(
66 authorization_server_metadata_url: &str,
67 protected_resource_meta: OauthProtectedResourceMetadata,
68 token_verifier: Box<dyn OauthTokenVerifier>,
69 required_scopes: Option<Vec<String>>,
70 ) -> Result<Self, reqwest::Error> {
71 let client = Client::new();
72
73 let auth_server_meta = client
74 .get(authorization_server_metadata_url)
75 .send()
76 .await?
77 .json::<AuthorizationServerMetadata>()
78 .await?;
79
80 Ok(Self::new(
81 auth_server_meta,
82 protected_resource_meta,
83 token_verifier,
84 required_scopes,
85 ))
86 }
87
88 fn handle_authorization_server_metadata(
89 response_str: String,
90 ) -> TransportServerResult<http::Response<GenericBody>> {
91 let body = Full::new(Bytes::from(response_str))
92 .map_err(|err| TransportServerError::HttpError(err.to_string()))
93 .boxed();
94 http::Response::builder()
95 .status(StatusCode::OK)
96 .header(CONTENT_TYPE, "application/json")
97 .body(body)
98 .map_err(|err| TransportServerError::HttpError(err.to_string()))
99 }
100
101 fn handle_protected_resource_metadata(
102 response_str: String,
103 ) -> TransportServerResult<http::Response<GenericBody>> {
104 use http_body_util::BodyExt;
105
106 let body = Full::new(Bytes::from(response_str))
107 .map_err(|err| TransportServerError::HttpError(err.to_string()))
108 .boxed();
109 http::Response::builder()
110 .status(StatusCode::OK)
111 .header(CONTENT_TYPE, "application/json")
112 .body(body)
113 .map_err(|err| TransportServerError::HttpError(err.to_string()))
114 }
115}
116
117#[async_trait]
118impl AuthProvider for RemoteAuthProvider {
119 fn protected_resource_metadata_url(&self) -> Option<&str> {
120 Some(self.protected_resource_metadata_url.as_str())
121 }
122
123 async fn verify_token(&self, access_token: String) -> Result<AuthInfo, AuthenticationError> {
124 self.token_verifier.verify_token(access_token).await
125 }
126
127 fn required_scopes(&self) -> Option<&Vec<String>> {
128 self.required_scopes.as_ref()
129 }
130
131 async fn handle_request(
132 &self,
133 request: http::Request<&str>,
134 state: Arc<McpAppState>,
135 ) -> Result<http::Response<GenericBody>, TransportServerError> {
136 let Some(endpoint) = self.endpoint_type(&request) else {
137 return http::Response::builder()
138 .status(StatusCode::NOT_FOUND)
139 .body(GenericBody::empty())
140 .map_err(|err| TransportServerError::HttpError(err.to_string()));
141 };
142
143 if let Some(response) = self.validate_allowed_methods(endpoint, request.method()) {
145 return Ok(response);
146 }
147
148 match endpoint {
149 OauthEndpoint::AuthorizationServerMetadata => {
150 let json_payload = serde_json::to_string(&self.auth_server_meta)
151 .map_err(|err| TransportServerError::HttpError(err.to_string()))?;
152 let cors = &CorsMiddleware::default();
153 cors.handle(
154 request,
155 state,
156 Box::new(move |_req, _state| {
157 Box::pin(
158 async move { Self::handle_authorization_server_metadata(json_payload) },
159 )
160 }),
161 )
162 .await
163 }
164 OauthEndpoint::ProtectedResourceMetadata => {
165 let json_payload = serde_json::to_string(&self.protected_resource_meta)
166 .map_err(|err| TransportServerError::HttpError(err.to_string()))?;
167
168 let cors = &CorsMiddleware::default();
169 cors.handle(
170 request,
171 state,
172 Box::new(move |_req, _state| {
173 Box::pin(
174 async move { Self::handle_protected_resource_metadata(json_payload) },
175 )
176 }),
177 )
178 .await
179 }
180 _ => Ok(GenericBody::create_404_response()),
181 }
182 }
183
184 fn auth_endpoints(&self) -> Option<&HashMap<String, OauthEndpoint>> {
185 Some(&self.endpoint_map)
186 }
187}