systemprompt_security/auth/
validation.rs1use axum::http::HeaderMap;
2use systemprompt_identifiers::{Actor, ContextId, SessionId, UserId};
3use systemprompt_models::auth::{JwtAudience, MAX_ACT_CHAIN_DEPTH, Permission, UserType};
4use systemprompt_models::execution::context::RequestContext;
5
6use crate::error::{AuthError, AuthResult};
7use crate::extraction::{HeaderExtractor, TokenExtractor};
8use crate::jwt::{ValidationPolicy, decode_rs256_claims};
9use crate::session::ValidatedSessionClaims;
10
11#[derive(Debug)]
12pub struct AuthValidationService {
13 issuer: String,
14 audiences: Vec<JwtAudience>,
15}
16
17impl AuthValidationService {
18 #[must_use]
19 pub const fn new(issuer: String, audiences: Vec<JwtAudience>) -> Self {
20 Self { issuer, audiences }
21 }
22
23 pub fn validate_request(&self, headers: &HeaderMap) -> AuthResult<RequestContext> {
24 let token = TokenExtractor::extract_from_authorization(headers)
25 .map_err(|_e| AuthError::MissingAuthorization)?;
26 let claims = self.validate_token(&token)?;
27 Ok(Self::create_context_from_claims(&claims, &token, headers))
28 }
29
30 fn validate_token(&self, token: &str) -> AuthResult<ValidatedSessionClaims> {
31 let policy = ValidationPolicy::issuer_scoped(&self.issuer, &self.audiences);
32 let claims = decode_rs256_claims(token, &policy)?;
33
34 if let Some(ref act) = claims.act {
35 let depth = act.depth();
36 if depth > MAX_ACT_CHAIN_DEPTH {
37 return Err(AuthError::ActChainTooDeep {
38 depth,
39 max: MAX_ACT_CHAIN_DEPTH,
40 });
41 }
42 }
43
44 let user_type = if claims.scope.contains(&Permission::Admin) {
45 UserType::Admin
46 } else {
47 claims.user_type
48 };
49
50 Ok(ValidatedSessionClaims {
51 user_id: UserId::new(claims.sub),
52 session_id: claims
53 .session_id
54 .map(SessionId::new)
55 .ok_or(AuthError::MissingSessionId)?,
56 user_type,
57 jti: claims.jti,
58 exp: claims.exp,
59 })
60 }
61
62 fn create_context_from_claims(
63 claims: &ValidatedSessionClaims,
64 token: &str,
65 headers: &HeaderMap,
66 ) -> RequestContext {
67 let session_id = claims.session_id.clone();
68 let user_id = claims.user_id.clone();
69
70 RequestContext::new(
71 session_id,
72 HeaderExtractor::extract_trace_id(headers),
73 HeaderExtractor::extract_context_id(headers).unwrap_or_else(ContextId::generate),
74 HeaderExtractor::extract_agent_name(headers),
75 )
76 .with_actor(Actor::user(user_id))
77 .with_auth_token(token)
78 .with_user_type(claims.user_type)
79 .with_jti(claims.jti.clone())
80 .with_token_exp(claims.exp)
81 }
82}