Skip to main content

systemprompt_api/services/middleware/jwt/token.rs

1use anyhow::{Result, anyhow};
2use jsonwebtoken::{Algorithm, Validation, decode, decode_header};
3
4use systemprompt_identifiers::{Actor, ClientId, SessionId, UserId};
5use systemprompt_models::auth::UserType;
6use systemprompt_oauth::models::JwtClaims;
7use systemprompt_security::keys::authority;
8
9#[derive(Debug, Clone)]
10pub struct JwtUserContext {
11    pub user_id: UserId,
12    pub session_id: SessionId,
13    pub role: systemprompt_models::auth::Permission,
14    pub user_type: UserType,
15    pub client_id: Option<ClientId>,
16    pub roles: Vec<String>,
17    pub department: Option<String>,
18    pub act_chain: Vec<Actor>,
19    pub jti: String,
20    pub exp: i64,
21}
22
23#[derive(Debug, Default, Clone, Copy)]
24pub struct JwtExtractor;
25
26impl JwtExtractor {
27    #[must_use]
28    pub const fn new() -> Self {
29        Self
30    }
31
32    fn build_validation() -> Validation {
33        let mut validation = Validation::new(Algorithm::RS256);
34        validation.validate_exp = true;
35        validation.validate_aud = false;
36        validation
37    }
38
39    fn decoding_key_for(token: &str) -> Result<&'static jsonwebtoken::DecodingKey, String> {
40        let header = decode_header(token).map_err(|e| format!("invalid header: {e}"))?;
41        if header.alg != Algorithm::RS256 {
42            return Err("JWT must be RS256-signed".to_string());
43        }
44        let kid = header
45            .kid
46            .as_deref()
47            .ok_or_else(|| "JWT missing `kid` header".to_string())?;
48        authority::decoding_key_for_kid(kid)
49            .map_err(|e| format!("key lookup: {e}"))?
50            .ok_or_else(|| format!("unknown `kid` `{kid}`"))
51    }
52
53    #[allow(clippy::unused_self)]
54    pub fn validate_token(&self, token: &str) -> Result<(), String> {
55        let key = Self::decoding_key_for(token)?;
56        match decode::<JwtClaims>(token, key, &Self::build_validation()) {
57            Ok(_) => Ok(()),
58            Err(err) => {
59                let reason = err.to_string();
60                if reason.contains("InvalidSignature") || reason.contains("invalid signature") {
61                    Err("Invalid signature".to_string())
62                } else if reason.contains("ExpiredSignature") || reason.contains("token expired") {
63                    Err("Token expired".to_string())
64                } else if reason.contains("MissingRequiredClaim") || reason.contains("missing") {
65                    Err("Missing required claim".to_string())
66                } else {
67                    Err("Invalid token".to_string())
68                }
69            },
70        }
71    }
72
73    #[allow(clippy::unused_self)]
74    pub fn extract_user_context(&self, token: &str) -> Result<JwtUserContext> {
75        let key = Self::decoding_key_for(token).map_err(|e| anyhow!(e))?;
76        let token_data = decode::<JwtClaims>(token, key, &Self::build_validation())?;
77
78        let session_id_str = token_data
79            .claims
80            .session_id
81            .ok_or_else(|| anyhow!("JWT must contain session_id claim"))?;
82
83        let role = *token_data
84            .claims
85            .scope
86            .first()
87            .ok_or_else(|| anyhow!("JWT must contain valid scope claim"))?;
88
89        let client_id = token_data.claims.client_id.map(ClientId::new);
90
91        // Defence-in-depth: the `user_type` claim is set at mint time from the
92        // permission set; re-derive it here and reject any token whose claim
93        // disagrees, so a forged or mis-minted type cannot ride past the gate.
94        let derived_type = UserType::from_permissions(&token_data.claims.scope);
95        if derived_type != token_data.claims.user_type {
96            return Err(anyhow!(
97                "user_type claim '{}' does not match permissions (derived '{}')",
98                token_data.claims.user_type,
99                derived_type
100            ));
101        }
102
103        let act_chain = token_data
104            .claims
105            .act
106            .as_ref()
107            .map(systemprompt_models::auth::ActClaim::flatten_to_chain)
108            .unwrap_or_default();
109
110        Ok(JwtUserContext {
111            user_id: UserId::new(token_data.claims.sub),
112            session_id: SessionId::new(session_id_str),
113            role,
114            user_type: derived_type,
115            client_id,
116            roles: token_data.claims.roles,
117            department: token_data.claims.department,
118            act_chain,
119            jti: token_data.claims.jti,
120            exp: token_data.claims.exp,
121        })
122    }
123}