// systemprompt_api/services/middleware/jwt/token.rs
use anyhow::{Result, anyhow};
use jsonwebtoken::errors::ErrorKind;
use jsonwebtoken::{Algorithm, Validation, decode, decode_header};

use systemprompt_identifiers::{Actor, ClientId, SessionId, UserId};
use systemprompt_models::auth::UserType;
use systemprompt_oauth::models::JwtClaims;
use systemprompt_security::keys::authority;

9#[derive(Debug, Clone)]
10pub struct JwtUserContext {
11 pub user_id: UserId,
12 pub session_id: SessionId,
13 pub role: systemprompt_models::auth::Permission,
14 pub user_type: UserType,
15 pub client_id: Option<ClientId>,
16 pub roles: Vec<String>,
17 pub department: Option<String>,
18 pub act_chain: Vec<Actor>,
19 pub jti: String,
20 pub exp: i64,
21}
22
/// Stateless RS256 JWT validator/decoder.
///
/// Verification keys are looked up per token from its `kid` header, so the
/// extractor itself carries no data and is freely copyable.
#[derive(Clone, Copy, Debug, Default)]
pub struct JwtExtractor;

26impl JwtExtractor {
27 #[must_use]
28 pub const fn new() -> Self {
29 Self
30 }
31
32 fn build_validation() -> Validation {
33 let mut validation = Validation::new(Algorithm::RS256);
34 validation.validate_exp = true;
35 validation.validate_aud = false;
36 validation
37 }
38
39 fn decoding_key_for(token: &str) -> Result<&'static jsonwebtoken::DecodingKey, String> {
40 let header = decode_header(token).map_err(|e| format!("invalid header: {e}"))?;
41 if header.alg != Algorithm::RS256 {
42 return Err("JWT must be RS256-signed".to_string());
43 }
44 let kid = header
45 .kid
46 .as_deref()
47 .ok_or_else(|| "JWT missing `kid` header".to_string())?;
48 authority::decoding_key_for_kid(kid)
49 .map_err(|e| format!("key lookup: {e}"))?
50 .ok_or_else(|| format!("unknown `kid` `{kid}`"))
51 }
52
53 #[allow(clippy::unused_self)]
54 pub fn validate_token(&self, token: &str) -> Result<(), String> {
55 let key = Self::decoding_key_for(token)?;
56 match decode::<JwtClaims>(token, key, &Self::build_validation()) {
57 Ok(_) => Ok(()),
58 Err(err) => {
59 let reason = err.to_string();
60 if reason.contains("InvalidSignature") || reason.contains("invalid signature") {
61 Err("Invalid signature".to_string())
62 } else if reason.contains("ExpiredSignature") || reason.contains("token expired") {
63 Err("Token expired".to_string())
64 } else if reason.contains("MissingRequiredClaim") || reason.contains("missing") {
65 Err("Missing required claim".to_string())
66 } else {
67 Err("Invalid token".to_string())
68 }
69 },
70 }
71 }
72
73 #[allow(clippy::unused_self)]
74 pub fn extract_user_context(&self, token: &str) -> Result<JwtUserContext> {
75 let key = Self::decoding_key_for(token).map_err(|e| anyhow!(e))?;
76 let token_data = decode::<JwtClaims>(token, key, &Self::build_validation())?;
77
78 let session_id_str = token_data
79 .claims
80 .session_id
81 .ok_or_else(|| anyhow!("JWT must contain session_id claim"))?;
82
83 let role = *token_data
84 .claims
85 .scope
86 .first()
87 .ok_or_else(|| anyhow!("JWT must contain valid scope claim"))?;
88
89 let client_id = token_data.claims.client_id.map(ClientId::new);
90
91 let derived_type = UserType::from_permissions(&token_data.claims.scope);
95 if derived_type != token_data.claims.user_type {
96 return Err(anyhow!(
97 "user_type claim '{}' does not match permissions (derived '{}')",
98 token_data.claims.user_type,
99 derived_type
100 ));
101 }
102
103 let act_chain = token_data
104 .claims
105 .act
106 .as_ref()
107 .map(systemprompt_models::auth::ActClaim::flatten_to_chain)
108 .unwrap_or_default();
109
110 Ok(JwtUserContext {
111 user_id: UserId::new(token_data.claims.sub),
112 session_id: SessionId::new(session_id_str),
113 role,
114 user_type: derived_type,
115 client_id,
116 roles: token_data.claims.roles,
117 department: token_data.claims.department,
118 act_chain,
119 jti: token_data.claims.jti,
120 exp: token_data.claims.exp,
121 })
122 }
123}