systemprompt_api/services/middleware/
auth.rs1use axum::extract::Request;
2use axum::http::HeaderMap;
3use axum::middleware;
4use axum::middleware::Next;
5use axum::response::Response;
6use systemprompt_models::modules::ApiPaths;
7use systemprompt_security::TokenExtractor;
8
/// Configuration for the API auth middleware.
///
/// `public_paths` lists path prefixes for which the middleware does not attempt to
/// extract an authenticated user from the request.
#[derive(Clone, Debug)]
pub struct ApiAuthMiddlewareConfig {
    /// Path prefixes treated as public.
    pub public_paths: Vec<&'static str>,
}
13
14impl Default for ApiAuthMiddlewareConfig {
15 fn default() -> Self {
16 Self {
17 public_paths: vec![
18 ApiPaths::OAUTH_SESSION,
19 ApiPaths::OAUTH_REGISTER,
20 ApiPaths::OAUTH_AUTHORIZE,
21 ApiPaths::OAUTH_TOKEN,
22 ApiPaths::OAUTH_CALLBACK,
23 ApiPaths::OAUTH_CONSENT,
24 ApiPaths::OAUTH_WEBAUTHN_COMPLETE,
25 ApiPaths::WELLKNOWN_BASE,
26 ApiPaths::STREAM_BASE,
27 ApiPaths::CONTEXTS_WEBHOOK,
28 ApiPaths::DISCOVERY,
29 ],
30 }
31 }
32}
33
34impl ApiAuthMiddlewareConfig {
35 pub fn new() -> Self {
36 Self::default()
37 }
38
39 pub fn is_public_path(&self, path: &str) -> bool {
40 if !path.starts_with(ApiPaths::API_BASE) && !path.starts_with(ApiPaths::WELLKNOWN_BASE) {
41 return true;
42 }
43
44 self.public_paths.iter().any(|p| path.starts_with(p))
45 || path.starts_with(ApiPaths::WELLKNOWN_BASE)
46 }
47}
48
/// Zero-sized handle that installs the API auth middleware on a router.
#[derive(Clone, Copy, Debug)]
pub struct AuthMiddleware;
51
52impl AuthMiddleware {
53 pub fn apply_auth_layer(router: axum::Router) -> axum::Router {
54 router.layer(middleware::from_fn(move |req, next| {
55 let config = ApiAuthMiddlewareConfig::default();
56 async move { auth_middleware(config, req, next).await }
57 }))
58 }
59}
60
61pub async fn auth_middleware(
62 config: ApiAuthMiddlewareConfig,
63 mut req: Request,
64 next: Next,
65) -> Response {
66 let path = req.uri().path();
67
68 if config.is_public_path(path) {
69 return next.run(req).await;
70 }
71
72 if let Some(user) = extract_optional_user(req.headers()) {
73 req.extensions_mut().insert(user);
74 }
75
76 next.run(req).await
77}
78
79fn extract_optional_user(headers: &HeaderMap) -> Option<systemprompt_models::AuthenticatedUser> {
80 use systemprompt_models::SecretsBootstrap;
81 use systemprompt_oauth::validate_jwt_token;
82 use uuid::Uuid;
83
84 let token = TokenExtractor::browser_only().extract(headers).ok()?;
85
86 if token.trim().is_empty() {
87 return None;
88 }
89
90 let jwt_secret = SecretsBootstrap::jwt_secret().ok()?;
91 let config = systemprompt_models::Config::get().ok()?;
92 let claims = match validate_jwt_token(
93 &token,
94 jwt_secret,
95 &config.jwt_issuer,
96 &config.jwt_audiences,
97 ) {
98 Ok(claims) => claims,
99 Err(e) => {
100 tracing::warn!(error = %e, "JWT validation failed");
101 return None;
102 },
103 };
104
105 let user_id = Uuid::parse_str(&claims.sub).ok()?;
106
107 let permissions = claims.scope;
108 let roles = claims.roles;
109
110 Some(systemprompt_models::AuthenticatedUser::new_with_roles(
111 user_id,
112 claims.username,
113 claims.email,
114 permissions,
115 roles,
116 ))
117}