systemprompt_api/services/middleware/
auth.rs1use axum::extract::{FromRequestParts, Request};
2use axum::http::request::Parts;
3use axum::http::{HeaderMap, StatusCode};
4use axum::middleware;
5use axum::middleware::Next;
6use axum::response::Response;
7use systemprompt_models::AuthenticatedUser;
8use systemprompt_models::modules::ApiPaths;
9use systemprompt_security::TokenExtractor;
10
11#[derive(Debug, Clone)]
12pub struct ApiAuthMiddlewareConfig {
13 pub public_paths: Vec<&'static str>,
14}
15
16impl Default for ApiAuthMiddlewareConfig {
17 fn default() -> Self {
18 Self {
19 public_paths: vec![
20 ApiPaths::OAUTH_SESSION,
21 ApiPaths::OAUTH_REGISTER,
22 ApiPaths::OAUTH_AUTHORIZE,
23 ApiPaths::OAUTH_TOKEN,
24 ApiPaths::OAUTH_CALLBACK,
25 ApiPaths::OAUTH_CONSENT,
26 ApiPaths::OAUTH_WEBAUTHN_COMPLETE,
27 ApiPaths::WELLKNOWN_BASE,
28 ApiPaths::STREAM_BASE,
29 ApiPaths::CONTEXTS_WEBHOOK,
30 ApiPaths::DISCOVERY,
31 ],
32 }
33 }
34}
35
36impl ApiAuthMiddlewareConfig {
37 pub fn new() -> Self {
38 Self::default()
39 }
40
41 pub fn is_public_path(&self, path: &str) -> bool {
42 if !path.starts_with(ApiPaths::API_BASE) && !path.starts_with(ApiPaths::WELLKNOWN_BASE) {
43 return true;
44 }
45
46 self.public_paths.iter().any(|p| path.starts_with(p))
47 || path.starts_with(ApiPaths::WELLKNOWN_BASE)
48 }
49}
50
51#[derive(Debug, Clone, Copy)]
52pub struct AuthMiddleware;
53
54impl AuthMiddleware {
55 pub fn apply_auth_enrichment_layer(router: axum::Router) -> axum::Router {
56 router.layer(middleware::from_fn(move |req, next| {
57 let config = ApiAuthMiddlewareConfig::default();
58 async move { auth_enrichment_middleware(config, req, next).await }
59 }))
60 }
61}
62
63pub async fn auth_enrichment_middleware(
64 config: ApiAuthMiddlewareConfig,
65 mut req: Request,
66 next: Next,
67) -> Response {
68 let path = req.uri().path();
69
70 if config.is_public_path(path) {
71 return next.run(req).await;
72 }
73
74 if let Some(user) = extract_optional_user(req.headers()) {
75 req.extensions_mut().insert(user);
76 }
77
78 next.run(req).await
79}
80
81#[derive(Debug, Clone)]
82pub struct RequireAuth(pub AuthenticatedUser);
83
84impl<S: Send + Sync> FromRequestParts<S> for RequireAuth {
85 type Rejection = (StatusCode, &'static str);
86
87 async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
88 parts
89 .extensions
90 .get::<AuthenticatedUser>()
91 .cloned()
92 .map(RequireAuth)
93 .ok_or((StatusCode::UNAUTHORIZED, "authentication required"))
94 }
95}
96
97fn extract_optional_user(headers: &HeaderMap) -> Option<AuthenticatedUser> {
98 use systemprompt_config::SecretsBootstrap;
99 use systemprompt_oauth::validate_jwt_token;
100 use uuid::Uuid;
101
102 let token = TokenExtractor::browser_only().extract(headers).ok()?;
103
104 if token.trim().is_empty() {
105 return None;
106 }
107
108 let jwt_secret = SecretsBootstrap::jwt_secret().ok()?;
109 let config = systemprompt_models::Config::get().ok()?;
110 let claims = match validate_jwt_token(
111 &token,
112 jwt_secret,
113 &config.jwt_issuer,
114 &config.jwt_audiences,
115 ) {
116 Ok(claims) => claims,
117 Err(e) => {
118 tracing::warn!(error = %e, "JWT validation failed");
119 return None;
120 },
121 };
122
123 let user_id = Uuid::parse_str(&claims.sub).ok()?;
124
125 let permissions = claims.scope;
126 let roles = claims.roles;
127
128 Some(AuthenticatedUser::new_with_roles(
129 user_id,
130 claims.username,
131 claims.email,
132 permissions,
133 roles,
134 ))
135}