rust_mcp_extra/auth_provider/
work_os.rs1use crate::token_verifier::{
43 GenericOauthTokenVerifier, TokenVerifierOptions, VerificationStrategies,
44};
45use async_trait::async_trait;
46use bytes::Bytes;
47use http::{header::CONTENT_TYPE, StatusCode};
48use http_body_util::{BodyExt, Full};
49use rust_mcp_sdk::{
50 auth::{
51 create_discovery_endpoints, AuthInfo, AuthMetadataBuilder, AuthProvider,
52 AuthenticationError, AuthorizationServerMetadata, OauthEndpoint,
53 OauthProtectedResourceMetadata, OauthTokenVerifier,
54 },
55 error::McpSdkError,
56 mcp_http::{middleware::CorsMiddleware, GenericBody, GenericBodyExt, Middleware},
57 mcp_server::{
58 error::{TransportServerError, TransportServerResult},
59 join_url, McpAppState,
60 },
61};
62use std::{collections::HashMap, sync::Arc, vec};
63
/// Scopes advertised in the authorization-server metadata when the caller does not
/// pin an explicit set of required scopes.
static SCOPES_SUPPORTED: &[&'static str] = &[
    "email",
    "offline_access",
    "openid",
    "profile",
];
65
66pub struct WorkOSAuthOptions<'a> {
68 pub authkit_domain: String,
69 pub mcp_server_url: String,
70 pub required_scopes: Option<Vec<&'a str>>,
71 pub token_verifier: Option<Box<dyn OauthTokenVerifier>>,
72 pub resource_name: Option<String>,
73 pub resource_documentation: Option<String>,
74}
75
76pub struct WorkOsAuthProvider {
82 auth_server_meta: AuthorizationServerMetadata,
83 protected_resource_meta: OauthProtectedResourceMetadata,
84 endpoint_map: HashMap<String, OauthEndpoint>,
85 protected_resource_metadata_url: String,
86 token_verifier: Box<dyn OauthTokenVerifier>,
87}
88
89impl WorkOsAuthProvider {
90 pub fn new(mut options: WorkOSAuthOptions) -> Result<Self, McpSdkError> {
108 let (endpoint_map, protected_resource_metadata_url) =
109 create_discovery_endpoints(&options.mcp_server_url)?;
110
111 let required_scopes = options.required_scopes.take();
112 let scopes_supported = required_scopes.clone().unwrap_or(SCOPES_SUPPORTED.to_vec());
113
114 let mut builder = AuthMetadataBuilder::new(&options.mcp_server_url)
115 .issuer(&options.authkit_domain)
116 .authorization_servers(vec![&options.authkit_domain])
117 .authorization_endpoint("/oauth2/authorize")
118 .introspection_endpoint("/oauth2/introspection")
119 .registration_endpoint("/oauth2/register")
120 .token_endpoint("/oauth2/token")
121 .jwks_uri("/oauth2/jwks")
122 .scopes_supported(scopes_supported);
123
124 if let Some(scopes) = required_scopes {
125 builder = builder.reqquired_scopes(scopes)
126 }
127 if let Some(resource_name) = options.resource_name.as_ref() {
128 builder = builder.resource_name(resource_name)
129 }
130 if let Some(resource_documentation) = options.resource_documentation.as_ref() {
131 builder = builder.service_documentation(resource_documentation)
132 }
133
134 let (auth_server_meta, protected_resource_meta) = builder.build()?;
135
136 let Some(jwks_uri) = auth_server_meta.jwks_uri.as_ref().map(|s| s.to_string()) else {
137 return Err(McpSdkError::Internal {
138 description: "jwks_uri is not defined!".to_string(),
139 });
140 };
141
142 let userinfo_uri = join_url(&auth_server_meta.issuer, "oauth2/userinfo")
143 .map_err(|err| McpSdkError::Internal {
144 description: format!("invalid userinfo url :{err}"),
145 })?
146 .to_string();
147
148 let token_verifier: Box<dyn OauthTokenVerifier> = match options.token_verifier {
149 Some(verifier) => verifier,
150 None => Box::new(GenericOauthTokenVerifier::new(TokenVerifierOptions {
151 strategies: vec![
152 VerificationStrategies::JWKs { jwks_uri },
153 VerificationStrategies::UserInfo { userinfo_uri },
154 ],
155 validate_audience: None,
156 validate_issuer: Some(options.authkit_domain.clone()),
157 cache_capacity: None,
158 })?),
159 };
160
161 Ok(Self {
162 endpoint_map,
163 protected_resource_metadata_url,
164 token_verifier,
165 auth_server_meta,
166 protected_resource_meta,
167 })
168 }
169
170 fn handle_authorization_server_metadata(
172 response_str: String,
173 ) -> TransportServerResult<http::Response<GenericBody>> {
174 let body = Full::new(Bytes::from(response_str))
175 .map_err(|err| TransportServerError::HttpError(err.to_string()))
176 .boxed();
177 http::Response::builder()
178 .status(StatusCode::OK)
179 .header(CONTENT_TYPE, "application/json")
180 .body(body)
181 .map_err(|err| TransportServerError::HttpError(err.to_string()))
182 }
183
184 fn handle_protected_resource_metadata(
186 response_str: String,
187 ) -> TransportServerResult<http::Response<GenericBody>> {
188 use http_body_util::BodyExt;
189
190 let body = Full::new(Bytes::from(response_str))
191 .map_err(|err| TransportServerError::HttpError(err.to_string()))
192 .boxed();
193 http::Response::builder()
194 .status(StatusCode::OK)
195 .header(CONTENT_TYPE, "application/json")
196 .body(body)
197 .map_err(|err| TransportServerError::HttpError(err.to_string()))
198 }
199}
200
201#[async_trait]
202impl AuthProvider for WorkOsAuthProvider {
203 fn auth_endpoints(&self) -> Option<&HashMap<String, OauthEndpoint>> {
205 Some(&self.endpoint_map)
206 }
207
208 async fn handle_request(
210 &self,
211 request: http::Request<&str>,
212 state: Arc<McpAppState>,
213 ) -> Result<http::Response<GenericBody>, TransportServerError> {
214 let Some(endpoint) = self.endpoint_type(&request) else {
215 return http::Response::builder()
216 .status(StatusCode::NOT_FOUND)
217 .body(GenericBody::empty())
218 .map_err(|err| TransportServerError::HttpError(err.to_string()));
219 };
220
221 if let Some(response) = self.validate_allowed_methods(endpoint, request.method()) {
223 return Ok(response);
224 }
225
226 match endpoint {
227 OauthEndpoint::AuthorizationServerMetadata => {
228 let json_payload = serde_json::to_string(&self.auth_server_meta)
229 .map_err(|err| TransportServerError::HttpError(err.to_string()))?;
230 let cors = &CorsMiddleware::default();
231 cors.handle(
232 request,
233 state,
234 Box::new(move |_req, _state| {
235 Box::pin(
236 async move { Self::handle_authorization_server_metadata(json_payload) },
237 )
238 }),
239 )
240 .await
241 }
242 OauthEndpoint::ProtectedResourceMetadata => {
243 let json_payload = serde_json::to_string(&self.protected_resource_meta)
244 .map_err(|err| TransportServerError::HttpError(err.to_string()))?;
245 let cors = &CorsMiddleware::default();
246 cors.handle(
247 request,
248 state,
249 Box::new(move |_req, _state| {
250 Box::pin(
251 async move { Self::handle_protected_resource_metadata(json_payload) },
252 )
253 }),
254 )
255 .await
256 }
257 _ => Ok(GenericBody::create_404_response()),
258 }
259 }
260
261 async fn verify_token(&self, access_token: String) -> Result<AuthInfo, AuthenticationError> {
265 self.token_verifier.verify_token(access_token).await
266 }
267
268 fn protected_resource_metadata_url(&self) -> Option<&str> {
270 Some(self.protected_resource_metadata_url.as_str())
271 }
272}