Skip to main content

systemprompt_api/services/middleware/jwt/
token.rs

1use anyhow::{Result, anyhow};
2use jsonwebtoken::{Algorithm, Validation, decode, decode_header};
3
4use systemprompt_identifiers::{Actor, ClientId, SessionId, UserId};
5use systemprompt_models::auth::UserType;
6use systemprompt_oauth::models::JwtClaims;
7use systemprompt_security::keys::authority;
8
9#[derive(Debug, Clone)]
10pub struct JwtUserContext {
11    pub user_id: UserId,
12    pub session_id: SessionId,
13    pub role: systemprompt_models::auth::Permission,
14    pub user_type: UserType,
15    pub client_id: Option<ClientId>,
16    pub roles: Vec<String>,
17    pub department: Option<String>,
18    pub act_chain: Vec<Actor>,
19    pub jti: String,
20    pub exp: i64,
21}
22
/// Stateless helper that verifies RS256 JWTs and turns their claims into a
/// `JwtUserContext`.
///
/// It carries no data, so it is `Copy` and trivially constructible via
/// `Default`.
#[derive(Clone, Copy, Debug)]
#[derive(Default)]
pub struct JwtExtractor;
25
26impl JwtExtractor {
27    #[must_use]
28    pub const fn new() -> Self {
29        Self
30    }
31
32    fn build_validation() -> Validation {
33        let mut validation = Validation::new(Algorithm::RS256);
34        validation.validate_exp = true;
35        validation.validate_aud = false;
36        validation
37    }
38
39    fn decoding_key_for(token: &str) -> Result<&'static jsonwebtoken::DecodingKey, String> {
40        let header = decode_header(token).map_err(|e| format!("invalid header: {e}"))?;
41        if header.alg != Algorithm::RS256 {
42            return Err("JWT must be RS256-signed".to_owned());
43        }
44        let kid = header
45            .kid
46            .as_deref()
47            .ok_or_else(|| "JWT missing `kid` header".to_owned())?;
48        authority::decoding_key_for_kid(kid)
49            .map_err(|e| format!("key lookup: {e}"))?
50            .ok_or_else(|| format!("unknown `kid` `{kid}`"))
51    }
52
53    #[expect(
54        clippy::unused_self,
55        reason = "method is on JwtMiddleware so future caching or context can be added without \
56                  changing the API"
57    )]
58    pub fn validate_token(&self, token: &str) -> Result<(), String> {
59        let key = Self::decoding_key_for(token)?;
60        match decode::<JwtClaims>(token, key, &Self::build_validation()) {
61            Ok(_) => Ok(()),
62            Err(err) => {
63                let reason = err.to_string();
64                if reason.contains("InvalidSignature") || reason.contains("invalid signature") {
65                    Err("Invalid signature".to_owned())
66                } else if reason.contains("ExpiredSignature") || reason.contains("token expired") {
67                    Err("Token expired".to_owned())
68                } else if reason.contains("MissingRequiredClaim") || reason.contains("missing") {
69                    Err("Missing required claim".to_owned())
70                } else {
71                    Err("Invalid token".to_owned())
72                }
73            },
74        }
75    }
76
77    #[expect(
78        clippy::unused_self,
79        reason = "method is on JwtMiddleware so future caching or context can be added without \
80                  changing the API"
81    )]
82    pub fn extract_user_context(&self, token: &str) -> Result<JwtUserContext> {
83        let key = Self::decoding_key_for(token).map_err(|e| anyhow!(e))?;
84        let token_data = decode::<JwtClaims>(token, key, &Self::build_validation())?;
85
86        let session_id_str = token_data
87            .claims
88            .session_id
89            .ok_or_else(|| anyhow!("JWT must contain session_id claim"))?;
90
91        let role = *token_data
92            .claims
93            .scope
94            .first()
95            .ok_or_else(|| anyhow!("JWT must contain valid scope claim"))?;
96
97        let client_id = token_data.claims.client_id.map(ClientId::new);
98
99        // Defence-in-depth: the `user_type` claim is set at mint time from the
100        // permission set; re-derive it here and reject any token whose claim
101        // disagrees, so a forged or mis-minted type cannot ride past the gate.
102        let derived_type = UserType::from_permissions(&token_data.claims.scope);
103        if derived_type != token_data.claims.user_type {
104            return Err(anyhow!(
105                "user_type claim '{}' does not match permissions (derived '{}')",
106                token_data.claims.user_type,
107                derived_type
108            ));
109        }
110
111        let act_chain = token_data
112            .claims
113            .act
114            .as_ref()
115            .map(systemprompt_models::auth::ActClaim::flatten_to_chain)
116            .unwrap_or_default();
117
118        Ok(JwtUserContext {
119            user_id: UserId::new(token_data.claims.sub),
120            session_id: SessionId::new(session_id_str),
121            role,
122            user_type: derived_type,
123            client_id,
124            roles: token_data.claims.roles,
125            department: token_data.claims.department,
126            act_chain,
127            jti: token_data.claims.jti,
128            exp: token_data.claims.exp,
129        })
130    }
131}