systemprompt_security/jwt/
decode.rs1use jsonwebtoken::{Algorithm, Validation, decode, decode_header};
12use std::collections::BTreeMap;
13use systemprompt_identifiers::{Actor, ClientId, SessionId, UserId};
14use systemprompt_models::auth::{JwtClaims, Permission, UserType};
15
16use crate::error::{AuthError, AuthResult};
17use crate::keys::authority;
18
19#[derive(Debug, Clone)]
20pub struct JwtUserContext {
21 pub user_id: UserId,
22 pub session_id: SessionId,
23 pub role: Permission,
24 pub user_type: UserType,
25 pub client_id: Option<ClientId>,
26 pub act_chain: Vec<Actor>,
27 pub attributes: BTreeMap<String, serde_json::Value>,
28 pub jti: String,
29 pub exp: i64,
30}
31
32pub fn extract_user_context(token: &str) -> AuthResult<JwtUserContext> {
33 let header = decode_header(token).map_err(AuthError::InvalidToken)?;
34 if header.alg != Algorithm::RS256 {
35 return Err(AuthError::UnsupportedAlgorithm);
36 }
37 let kid = header.kid.as_deref().ok_or(AuthError::MissingKid)?;
38 let key = authority::decoding_key_for_kid(kid)
39 .map_err(|e| AuthError::KeyLookup(e.to_string()))?
40 .ok_or_else(|| AuthError::UnknownKid(kid.to_owned()))?;
41
42 let mut validation = Validation::new(Algorithm::RS256);
43 validation.validate_exp = true;
44 validation.validate_aud = false;
45
46 let claims = decode::<JwtClaims>(token, key, &validation)
47 .map_err(AuthError::InvalidToken)?
48 .claims;
49
50 let session_id = claims.session_id.ok_or(AuthError::MissingSessionId)?;
51 let role = *claims.scope.first().ok_or(AuthError::MissingScope)?;
52 let derived_type = UserType::from_permissions(&claims.scope);
53 if derived_type != claims.user_type {
54 return Err(AuthError::UserTypeMismatch {
55 claimed: claims.user_type,
56 derived: derived_type,
57 });
58 }
59 let act_chain = claims
60 .act
61 .as_ref()
62 .map(systemprompt_models::auth::ActClaim::flatten_to_chain)
63 .unwrap_or_default();
64
65 Ok(JwtUserContext {
66 user_id: UserId::new(claims.sub),
67 session_id,
68 role,
69 user_type: derived_type,
70 client_id: claims.client_id,
71 act_chain,
72 attributes: claims.attributes,
73 jti: claims.jti,
74 exp: claims.exp,
75 })
76}