Skip to main content

systemprompt_api/services/middleware/jwt/
token.rs

use anyhow::{Result, anyhow};
use jsonwebtoken::errors::ErrorKind;
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};

use systemprompt_identifiers::{ClientId, SessionId, UserId};
use systemprompt_models::auth::UserType;
use systemprompt_oauth::models::JwtClaims;

8#[derive(Debug, Clone)]
9pub struct JwtUserContext {
10    pub user_id: UserId,
11    pub session_id: SessionId,
12    pub role: systemprompt_models::auth::Permission,
13    pub user_type: UserType,
14    pub client_id: Option<ClientId>,
15    pub roles: Vec<String>,
16    pub department: Option<String>,
17}
18
19pub struct JwtExtractor {
20    decoding_key: DecodingKey,
21    validation: Validation,
22}
23
24impl std::fmt::Debug for JwtExtractor {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        f.debug_struct("JwtExtractor")
27            .field("decoding_key", &"<DecodingKey>")
28            .field("validation", &self.validation)
29            .finish()
30    }
31}
32
33impl JwtExtractor {
34    pub fn new(jwt_secret: &str) -> Self {
35        let mut validation = Validation::new(Algorithm::HS256);
36        validation.validate_exp = true;
37        validation.validate_aud = false;
38
39        Self {
40            decoding_key: DecodingKey::from_secret(jwt_secret.as_bytes()),
41            validation,
42        }
43    }
44
45    pub fn validate_token(&self, token: &str) -> Result<(), String> {
46        match decode::<JwtClaims>(token, &self.decoding_key, &self.validation) {
47            Ok(_) => Ok(()),
48            Err(err) => {
49                let reason = err.to_string();
50                if reason.contains("InvalidSignature") || reason.contains("invalid signature") {
51                    Err("Invalid signature".to_string())
52                } else if reason.contains("ExpiredSignature") || reason.contains("token expired") {
53                    Err("Token expired".to_string())
54                } else if reason.contains("MissingRequiredClaim") || reason.contains("missing") {
55                    Err("Missing required claim".to_string())
56                } else {
57                    Err("Invalid token".to_string())
58                }
59            },
60        }
61    }
62
63    pub fn extract_user_context(&self, token: &str) -> Result<JwtUserContext> {
64        let token_data = decode::<JwtClaims>(token, &self.decoding_key, &self.validation)?;
65
66        let session_id_str = token_data
67            .claims
68            .session_id
69            .ok_or_else(|| anyhow!("JWT must contain session_id claim"))?;
70
71        let role = *token_data
72            .claims
73            .scope
74            .first()
75            .ok_or_else(|| anyhow!("JWT must contain valid scope claim"))?;
76
77        let client_id = token_data.claims.client_id.map(ClientId::new);
78
79        Ok(JwtUserContext {
80            user_id: UserId::new(token_data.claims.sub),
81            session_id: SessionId::new(session_id_str),
82            role,
83            user_type: token_data.claims.user_type,
84            client_id,
85            roles: token_data.claims.roles,
86            department: token_data.claims.department,
87        })
88    }
89}