Skip to main content

systemprompt_api/services/middleware/
auth.rs

1use axum::extract::{FromRequestParts, Request};
2use axum::http::request::Parts;
3use axum::http::{HeaderMap, StatusCode};
4use axum::middleware;
5use axum::middleware::Next;
6use axum::response::Response;
7use systemprompt_models::AuthenticatedUser;
8use systemprompt_models::modules::ApiPaths;
9use systemprompt_security::TokenExtractor;
10
/// Configuration for the API authentication-enrichment middleware.
///
/// Any request path beginning with one of the `public_paths` prefixes is passed
/// through without attempting to resolve an authenticated user.
#[derive(Clone, Debug)]
pub struct ApiAuthMiddlewareConfig {
    /// Path prefixes that skip user enrichment.
    pub public_paths: Vec<&'static str>,
}
15
16impl Default for ApiAuthMiddlewareConfig {
17    fn default() -> Self {
18        Self {
19            public_paths: vec![
20                ApiPaths::OAUTH_SESSION,
21                ApiPaths::OAUTH_REGISTER,
22                ApiPaths::OAUTH_AUTHORIZE,
23                ApiPaths::OAUTH_TOKEN,
24                ApiPaths::OAUTH_CALLBACK,
25                ApiPaths::OAUTH_CONSENT,
26                ApiPaths::OAUTH_WEBAUTHN_COMPLETE,
27                ApiPaths::WELLKNOWN_BASE,
28                ApiPaths::STREAM_BASE,
29                ApiPaths::CONTEXTS_WEBHOOK,
30                ApiPaths::DISCOVERY,
31            ],
32        }
33    }
34}
35
36impl ApiAuthMiddlewareConfig {
37    pub fn new() -> Self {
38        Self::default()
39    }
40
41    pub fn is_public_path(&self, path: &str) -> bool {
42        if !path.starts_with(ApiPaths::API_BASE) && !path.starts_with(ApiPaths::WELLKNOWN_BASE) {
43            return true;
44        }
45
46        self.public_paths.iter().any(|p| path.starts_with(p))
47            || path.starts_with(ApiPaths::WELLKNOWN_BASE)
48    }
49}
50
51#[derive(Debug, Clone, Copy)]
52pub struct AuthMiddleware;
53
54impl AuthMiddleware {
55    pub fn apply_auth_enrichment_layer(router: axum::Router) -> axum::Router {
56        router.layer(middleware::from_fn(move |req, next| {
57            let config = ApiAuthMiddlewareConfig::default();
58            async move { auth_enrichment_middleware(config, req, next).await }
59        }))
60    }
61}
62
63pub async fn auth_enrichment_middleware(
64    config: ApiAuthMiddlewareConfig,
65    mut req: Request,
66    next: Next,
67) -> Response {
68    let path = req.uri().path();
69
70    if config.is_public_path(path) {
71        return next.run(req).await;
72    }
73
74    if let Some(user) = extract_optional_user(req.headers()) {
75        req.extensions_mut().insert(user);
76    }
77
78    next.run(req).await
79}
80
81#[derive(Debug, Clone)]
82pub struct RequireAuth(pub AuthenticatedUser);
83
84impl<S: Send + Sync> FromRequestParts<S> for RequireAuth {
85    type Rejection = (StatusCode, &'static str);
86
87    async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
88        parts
89            .extensions
90            .get::<AuthenticatedUser>()
91            .cloned()
92            .map(RequireAuth)
93            .ok_or((StatusCode::UNAUTHORIZED, "authentication required"))
94    }
95}
96
97fn extract_optional_user(headers: &HeaderMap) -> Option<AuthenticatedUser> {
98    use systemprompt_config::SecretsBootstrap;
99    use systemprompt_oauth::validate_jwt_token;
100    use uuid::Uuid;
101
102    let token = TokenExtractor::browser_only().extract(headers).ok()?;
103
104    if token.trim().is_empty() {
105        return None;
106    }
107
108    let jwt_secret = SecretsBootstrap::jwt_secret().ok()?;
109    let config = systemprompt_models::Config::get().ok()?;
110    let claims = match validate_jwt_token(
111        &token,
112        jwt_secret,
113        &config.jwt_issuer,
114        &config.jwt_audiences,
115    ) {
116        Ok(claims) => claims,
117        Err(e) => {
118            tracing::warn!(error = %e, "JWT validation failed");
119            return None;
120        },
121    };
122
123    let user_id = Uuid::parse_str(&claims.sub).ok()?;
124
125    let permissions = claims.scope;
126    let roles = claims.roles;
127
128    Some(AuthenticatedUser::new_with_roles(
129        user_id,
130        claims.username,
131        claims.email,
132        permissions,
133        roles,
134    ))
135}