systemprompt_api/services/middleware/jwt/
token.rs1use anyhow::{Result, anyhow};
2use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
3
4use systemprompt_identifiers::{ClientId, SessionId, UserId};
5use systemprompt_models::auth::UserType;
6use systemprompt_oauth::models::JwtClaims;
7
8#[derive(Debug, Clone)]
9pub struct JwtUserContext {
10 pub user_id: UserId,
11 pub session_id: SessionId,
12 pub role: systemprompt_models::auth::Permission,
13 pub user_type: UserType,
14 pub client_id: Option<ClientId>,
15 pub roles: Vec<String>,
16 pub department: Option<String>,
17}
18
19pub struct JwtExtractor {
20 decoding_key: DecodingKey,
21 validation: Validation,
22}
23
24impl std::fmt::Debug for JwtExtractor {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 f.debug_struct("JwtExtractor")
27 .field("decoding_key", &"<DecodingKey>")
28 .field("validation", &self.validation)
29 .finish()
30 }
31}
32
33impl JwtExtractor {
34 pub fn new(jwt_secret: &str) -> Self {
35 let mut validation = Validation::new(Algorithm::HS256);
36 validation.validate_exp = true;
37 validation.validate_aud = false;
38
39 Self {
40 decoding_key: DecodingKey::from_secret(jwt_secret.as_bytes()),
41 validation,
42 }
43 }
44
45 pub fn validate_token(&self, token: &str) -> Result<(), String> {
46 match decode::<JwtClaims>(token, &self.decoding_key, &self.validation) {
47 Ok(_) => Ok(()),
48 Err(err) => {
49 let reason = err.to_string();
50 if reason.contains("InvalidSignature") || reason.contains("invalid signature") {
51 Err("Invalid signature".to_string())
52 } else if reason.contains("ExpiredSignature") || reason.contains("token expired") {
53 Err("Token expired".to_string())
54 } else if reason.contains("MissingRequiredClaim") || reason.contains("missing") {
55 Err("Missing required claim".to_string())
56 } else {
57 Err("Invalid token".to_string())
58 }
59 },
60 }
61 }
62
63 pub fn extract_user_context(&self, token: &str) -> Result<JwtUserContext> {
64 let token_data = decode::<JwtClaims>(token, &self.decoding_key, &self.validation)?;
65
66 let session_id_str = token_data
67 .claims
68 .session_id
69 .ok_or_else(|| anyhow!("JWT must contain session_id claim"))?;
70
71 let role = *token_data
72 .claims
73 .scope
74 .first()
75 .ok_or_else(|| anyhow!("JWT must contain valid scope claim"))?;
76
77 let client_id = token_data.claims.client_id.map(ClientId::new);
78
79 Ok(JwtUserContext {
80 user_id: UserId::new(token_data.claims.sub),
81 session_id: SessionId::new(session_id_str),
82 role,
83 user_type: token_data.claims.user_type,
84 client_id,
85 roles: token_data.claims.roles,
86 department: token_data.claims.department,
87 })
88 }
89}