Skip to main content

systemprompt_api/services/middleware/
auth.rs

1use axum::extract::Request;
2use axum::http::HeaderMap;
3use axum::middleware;
4use axum::middleware::Next;
5use axum::response::Response;
6use systemprompt_models::modules::ApiPaths;
7use systemprompt_security::TokenExtractor;
8
/// Settings consumed by the API authentication middleware.
///
/// The only setting is a list of path prefixes; `is_public_path` tests a
/// request path against each of them with `str::starts_with`.
#[derive(Clone, Debug)]
pub struct ApiAuthMiddlewareConfig {
    /// Path prefixes for which the middleware does not attempt to extract a user.
    pub public_paths: Vec<&'static str>,
}
13
14impl Default for ApiAuthMiddlewareConfig {
15    fn default() -> Self {
16        Self {
17            public_paths: vec![
18                ApiPaths::OAUTH_SESSION,
19                ApiPaths::OAUTH_REGISTER,
20                ApiPaths::OAUTH_AUTHORIZE,
21                ApiPaths::OAUTH_TOKEN,
22                ApiPaths::OAUTH_CALLBACK,
23                ApiPaths::OAUTH_CONSENT,
24                ApiPaths::OAUTH_WEBAUTHN_COMPLETE,
25                ApiPaths::WELLKNOWN_BASE,
26                ApiPaths::STREAM_BASE,
27                ApiPaths::CONTEXTS_WEBHOOK,
28                ApiPaths::DISCOVERY,
29            ],
30        }
31    }
32}
33
34impl ApiAuthMiddlewareConfig {
35    pub fn new() -> Self {
36        Self::default()
37    }
38
39    pub fn is_public_path(&self, path: &str) -> bool {
40        if !path.starts_with(ApiPaths::API_BASE) && !path.starts_with(ApiPaths::WELLKNOWN_BASE) {
41            return true;
42        }
43
44        self.public_paths.iter().any(|p| path.starts_with(p))
45            || path.starts_with(ApiPaths::WELLKNOWN_BASE)
46    }
47}
48
/// Field-less marker type; its associated functions install the
/// authentication middleware on a router.
#[derive(Clone, Copy, Debug)]
pub struct AuthMiddleware;
51
52impl AuthMiddleware {
53    pub fn apply_auth_layer(router: axum::Router) -> axum::Router {
54        router.layer(middleware::from_fn(move |req, next| {
55            let config = ApiAuthMiddlewareConfig::default();
56            async move { auth_middleware(config, req, next).await }
57        }))
58    }
59}
60
61pub async fn auth_middleware(
62    config: ApiAuthMiddlewareConfig,
63    mut req: Request,
64    next: Next,
65) -> Response {
66    let path = req.uri().path();
67
68    if config.is_public_path(path) {
69        return next.run(req).await;
70    }
71
72    if let Some(user) = extract_optional_user(req.headers()) {
73        req.extensions_mut().insert(user);
74    }
75
76    next.run(req).await
77}
78
79fn extract_optional_user(headers: &HeaderMap) -> Option<systemprompt_models::AuthenticatedUser> {
80    use systemprompt_models::SecretsBootstrap;
81    use systemprompt_oauth::validate_jwt_token;
82    use uuid::Uuid;
83
84    let token = TokenExtractor::browser_only().extract(headers).ok()?;
85
86    if token.trim().is_empty() {
87        return None;
88    }
89
90    let jwt_secret = SecretsBootstrap::jwt_secret().ok()?;
91    let config = systemprompt_models::Config::get().ok()?;
92    let claims = match validate_jwt_token(
93        &token,
94        jwt_secret,
95        &config.jwt_issuer,
96        &config.jwt_audiences,
97    ) {
98        Ok(claims) => claims,
99        Err(e) => {
100            tracing::warn!(error = %e, "JWT validation failed");
101            return None;
102        },
103    };
104
105    let user_id = Uuid::parse_str(&claims.sub).ok()?;
106
107    let permissions = claims.scope;
108    let roles = claims.roles;
109
110    Some(systemprompt_models::AuthenticatedUser::new_with_roles(
111        user_id,
112        claims.username,
113        claims.email,
114        permissions,
115        roles,
116    ))
117}