// systemprompt_api/services/middleware/jwt/token.rs
token.rs1use anyhow::{Result, anyhow};
2use jsonwebtoken::{Algorithm, Validation, decode, decode_header};
3
4use systemprompt_identifiers::{Actor, ClientId, SessionId, UserId};
5use systemprompt_models::auth::UserType;
6use systemprompt_oauth::models::JwtClaims;
7use systemprompt_security::keys::authority;
8
/// Authenticated caller context decoded from a verified JWT.
///
/// Built by `JwtExtractor::extract_user_context` once the token's signature
/// and expiry have been validated.
#[derive(Debug, Clone)]
pub struct JwtUserContext {
    /// Subject of the token (the `sub` claim).
    pub user_id: UserId,
    /// Session identifier from the `session_id` claim (required; extraction fails without it).
    pub session_id: SessionId,
    /// Primary permission: the first entry of the `scope` claim.
    pub role: systemprompt_models::auth::Permission,
    /// User type derived from the `scope` permissions; extraction verifies it
    /// equals the token's `user_type` claim.
    pub user_type: UserType,
    /// Client identifier from the `client_id` claim, if present.
    pub client_id: Option<ClientId>,
    /// Role names copied from the `roles` claim.
    pub roles: Vec<String>,
    /// Department from the `department` claim, if present.
    pub department: Option<String>,
    /// Actor chain produced by `ActClaim::flatten_to_chain` from the `act`
    /// claim; empty when the claim is absent.
    pub act_chain: Vec<Actor>,
    /// Token identifier (the `jti` claim).
    pub jti: String,
    /// Expiry (the `exp` claim); presumably seconds since the Unix epoch per
    /// RFC 7519 — confirm against the issuer.
    pub exp: i64,
}
22
/// Stateless validator and decoder for RS256-signed JWTs.
///
/// Carries no data; the verification key for each token is looked up from
/// the token's `kid` header through `authority::decoding_key_for_kid`.
#[derive(Debug, Default, Clone, Copy)]
pub struct JwtExtractor;
25
26impl JwtExtractor {
27 #[must_use]
28 pub const fn new() -> Self {
29 Self
30 }
31
32 fn build_validation() -> Validation {
33 let mut validation = Validation::new(Algorithm::RS256);
34 validation.validate_exp = true;
35 validation.validate_aud = false;
36 validation
37 }
38
39 fn decoding_key_for(token: &str) -> Result<&'static jsonwebtoken::DecodingKey, String> {
40 let header = decode_header(token).map_err(|e| format!("invalid header: {e}"))?;
41 if header.alg != Algorithm::RS256 {
42 return Err("JWT must be RS256-signed".to_owned());
43 }
44 let kid = header
45 .kid
46 .as_deref()
47 .ok_or_else(|| "JWT missing `kid` header".to_owned())?;
48 authority::decoding_key_for_kid(kid)
49 .map_err(|e| format!("key lookup: {e}"))?
50 .ok_or_else(|| format!("unknown `kid` `{kid}`"))
51 }
52
53 #[expect(
54 clippy::unused_self,
55 reason = "method is on JwtMiddleware so future caching or context can be added without \
56 changing the API"
57 )]
58 pub fn validate_token(&self, token: &str) -> Result<(), String> {
59 let key = Self::decoding_key_for(token)?;
60 match decode::<JwtClaims>(token, key, &Self::build_validation()) {
61 Ok(_) => Ok(()),
62 Err(err) => {
63 let reason = err.to_string();
64 if reason.contains("InvalidSignature") || reason.contains("invalid signature") {
65 Err("Invalid signature".to_owned())
66 } else if reason.contains("ExpiredSignature") || reason.contains("token expired") {
67 Err("Token expired".to_owned())
68 } else if reason.contains("MissingRequiredClaim") || reason.contains("missing") {
69 Err("Missing required claim".to_owned())
70 } else {
71 Err("Invalid token".to_owned())
72 }
73 },
74 }
75 }
76
77 #[expect(
78 clippy::unused_self,
79 reason = "method is on JwtMiddleware so future caching or context can be added without \
80 changing the API"
81 )]
82 pub fn extract_user_context(&self, token: &str) -> Result<JwtUserContext> {
83 let key = Self::decoding_key_for(token).map_err(|e| anyhow!(e))?;
84 let token_data = decode::<JwtClaims>(token, key, &Self::build_validation())?;
85
86 let session_id_str = token_data
87 .claims
88 .session_id
89 .ok_or_else(|| anyhow!("JWT must contain session_id claim"))?;
90
91 let role = *token_data
92 .claims
93 .scope
94 .first()
95 .ok_or_else(|| anyhow!("JWT must contain valid scope claim"))?;
96
97 let client_id = token_data.claims.client_id.map(ClientId::new);
98
99 let derived_type = UserType::from_permissions(&token_data.claims.scope);
103 if derived_type != token_data.claims.user_type {
104 return Err(anyhow!(
105 "user_type claim '{}' does not match permissions (derived '{}')",
106 token_data.claims.user_type,
107 derived_type
108 ));
109 }
110
111 let act_chain = token_data
112 .claims
113 .act
114 .as_ref()
115 .map(systemprompt_models::auth::ActClaim::flatten_to_chain)
116 .unwrap_or_default();
117
118 Ok(JwtUserContext {
119 user_id: UserId::new(token_data.claims.sub),
120 session_id: SessionId::new(session_id_str),
121 role,
122 user_type: derived_type,
123 client_id,
124 roles: token_data.claims.roles,
125 department: token_data.claims.department,
126 act_chain,
127 jti: token_data.claims.jti,
128 exp: token_data.claims.exp,
129 })
130 }
131}