systemprompt_api/services/middleware/jwt/
token.rs1use anyhow::{anyhow, Result};
2use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
3
4use systemprompt_identifiers::{ClientId, SessionId, UserId};
5use systemprompt_models::auth::UserType;
6use systemprompt_oauth::models::JwtClaims;
7
8#[derive(Debug, Clone)]
9pub struct JwtUserContext {
10 pub user_id: UserId,
11 pub session_id: SessionId,
12 pub role: systemprompt_models::auth::Permission,
13 pub user_type: UserType,
14 pub client_id: Option<ClientId>,
15}
16
17pub struct JwtExtractor {
18 decoding_key: DecodingKey,
19 validation: Validation,
20}
21
22impl std::fmt::Debug for JwtExtractor {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 f.debug_struct("JwtExtractor")
25 .field("decoding_key", &"<DecodingKey>")
26 .field("validation", &self.validation)
27 .finish()
28 }
29}
30
31impl JwtExtractor {
32 pub fn new(jwt_secret: &str) -> Self {
33 let mut validation = Validation::new(Algorithm::HS256);
34 validation.validate_exp = true;
35 validation.validate_aud = false;
36
37 Self {
38 decoding_key: DecodingKey::from_secret(jwt_secret.as_bytes()),
39 validation,
40 }
41 }
42
43 pub fn validate_token(&self, token: &str) -> Result<(), String> {
44 match decode::<JwtClaims>(token, &self.decoding_key, &self.validation) {
45 Ok(_) => Ok(()),
46 Err(err) => {
47 let reason = err.to_string();
48 if reason.contains("InvalidSignature") || reason.contains("invalid signature") {
49 Err("Invalid signature".to_string())
50 } else if reason.contains("ExpiredSignature") || reason.contains("token expired") {
51 Err("Token expired".to_string())
52 } else if reason.contains("MissingRequiredClaim") || reason.contains("missing") {
53 Err("Missing required claim".to_string())
54 } else {
55 Err("Invalid token".to_string())
56 }
57 },
58 }
59 }
60
61 pub fn extract_user_context(&self, token: &str) -> Result<JwtUserContext> {
62 let token_data = decode::<JwtClaims>(token, &self.decoding_key, &self.validation)?;
63
64 let session_id_str = token_data
65 .claims
66 .session_id
67 .ok_or_else(|| anyhow!("JWT must contain session_id claim"))?;
68
69 let role = *token_data
70 .claims
71 .scope
72 .first()
73 .ok_or_else(|| anyhow!("JWT must contain valid scope claim"))?;
74
75 let client_id = token_data.claims.client_id.map(ClientId::new);
76
77 Ok(JwtUserContext {
78 user_id: UserId::new(token_data.claims.sub),
79 session_id: SessionId::new(session_id_str),
80 role,
81 user_type: token_data.claims.user_type,
82 client_id,
83 })
84 }
85}