Skip to main content

systemprompt_api/services/middleware/jwt/
token.rs

1use anyhow::{anyhow, Result};
2use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
3
4use systemprompt_identifiers::{ClientId, SessionId, UserId};
5use systemprompt_models::auth::UserType;
6use systemprompt_oauth::models::JwtClaims;
7
8#[derive(Debug, Clone)]
9pub struct JwtUserContext {
10    pub user_id: UserId,
11    pub session_id: SessionId,
12    pub role: systemprompt_models::auth::Permission,
13    pub user_type: UserType,
14    pub client_id: Option<ClientId>,
15}
16
17pub struct JwtExtractor {
18    decoding_key: DecodingKey,
19    validation: Validation,
20}
21
22impl std::fmt::Debug for JwtExtractor {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        f.debug_struct("JwtExtractor")
25            .field("decoding_key", &"<DecodingKey>")
26            .field("validation", &self.validation)
27            .finish()
28    }
29}
30
31impl JwtExtractor {
32    pub fn new(jwt_secret: &str) -> Self {
33        let mut validation = Validation::new(Algorithm::HS256);
34        validation.validate_exp = true;
35        validation.validate_aud = false;
36
37        Self {
38            decoding_key: DecodingKey::from_secret(jwt_secret.as_bytes()),
39            validation,
40        }
41    }
42
43    pub fn validate_token(&self, token: &str) -> Result<(), String> {
44        match decode::<JwtClaims>(token, &self.decoding_key, &self.validation) {
45            Ok(_) => Ok(()),
46            Err(err) => {
47                let reason = err.to_string();
48                if reason.contains("InvalidSignature") || reason.contains("invalid signature") {
49                    Err("Invalid signature".to_string())
50                } else if reason.contains("ExpiredSignature") || reason.contains("token expired") {
51                    Err("Token expired".to_string())
52                } else if reason.contains("MissingRequiredClaim") || reason.contains("missing") {
53                    Err("Missing required claim".to_string())
54                } else {
55                    Err("Invalid token".to_string())
56                }
57            },
58        }
59    }
60
61    pub fn extract_user_context(&self, token: &str) -> Result<JwtUserContext> {
62        let token_data = decode::<JwtClaims>(token, &self.decoding_key, &self.validation)?;
63
64        let session_id_str = token_data
65            .claims
66            .session_id
67            .ok_or_else(|| anyhow!("JWT must contain session_id claim"))?;
68
69        let role = *token_data
70            .claims
71            .scope
72            .first()
73            .ok_or_else(|| anyhow!("JWT must contain valid scope claim"))?;
74
75        let client_id = token_data.claims.client_id.map(ClientId::new);
76
77        Ok(JwtUserContext {
78            user_id: UserId::new(token_data.claims.sub),
79            session_id: SessionId::new(session_id_str),
80            role,
81            user_type: token_data.claims.user_type,
82            client_id,
83        })
84    }
85}