systemprompt_security/auth/
validation.rs1use axum::http::HeaderMap;
2use systemprompt_identifiers::{AgentName, ContextId, SessionId, TraceId, UserId};
3use systemprompt_models::auth::{JwtAudience, JwtClaims, Permission, UserType};
4use systemprompt_models::execution::context::RequestContext;
5
6use crate::error::{AuthError, AuthResult};
7use crate::extraction::HeaderExtractor;
8use crate::session::ValidatedSessionClaims;
9
/// Auth-scheme prefix expected at the start of the `Authorization` header value.
const BEARER_PREFIX: &str = "Bearer ";

/// Session id given to requests that end up with an anonymous context.
const ANONYMOUS_SESSION_ID: &str = "anonymous";

// Fixed identity used when authentication is disabled (`AuthMode::Disabled`).
/// Session id of the synthetic test context.
const TEST_SESSION_ID: &str = "test";
/// Trace id of the synthetic test context.
const TEST_TRACE_ID: &str = "test-trace";
/// Agent name of the synthetic test context.
const TEST_AGENT_NAME: &str = "test-agent";
/// User id of the synthetic test context.
const TEST_USER_ID: &str = "test-user";
16
/// How strictly `AuthValidationService::validate_request` treats a request's credentials.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AuthMode {
    /// A valid bearer token must be present, otherwise validation returns an error.
    Required,
    /// A valid bearer token is used when present; otherwise an anonymous context is produced.
    Optional,
    /// Credentials are ignored entirely and a fixed test context is produced.
    Disabled,
}
23
24#[derive(Debug)]
25pub struct AuthValidationService {
26 secret: String,
27 issuer: String,
28 audiences: Vec<JwtAudience>,
29}
30
31impl AuthValidationService {
32 #[must_use]
33 pub const fn new(secret: String, issuer: String, audiences: Vec<JwtAudience>) -> Self {
34 Self {
35 secret,
36 issuer,
37 audiences,
38 }
39 }
40
41 pub fn validate_request(
42 &self,
43 headers: &HeaderMap,
44 mode: AuthMode,
45 ) -> AuthResult<RequestContext> {
46 match mode {
47 AuthMode::Required => self.validate_and_fail_fast(headers),
48 AuthMode::Optional => Ok(self.try_validate_or_anonymous(headers)),
49 AuthMode::Disabled => Ok(Self::create_test_context()),
50 }
51 }
52
53 fn validate_and_fail_fast(&self, headers: &HeaderMap) -> AuthResult<RequestContext> {
54 let token = Self::extract_token(headers).ok_or(AuthError::MissingAuthorization)?;
55
56 let claims = self.validate_token(token)?;
57 Ok(Self::create_context_from_claims(&claims, token, headers))
58 }
59
60 fn try_validate_or_anonymous(&self, headers: &HeaderMap) -> RequestContext {
61 Self::extract_token(headers).map_or_else(
62 || Self::create_anonymous_context(headers),
63 |token| {
64 self.validate_token(token)
65 .map_err(|e| {
66 tracing::debug!(error = %e, "Token validation failed, falling back to anonymous");
67 e
68 })
69 .map_or_else(
70 |_| Self::create_anonymous_context(headers),
71 |claims| Self::create_context_from_claims(&claims, token, headers),
72 )
73 },
74 )
75 }
76
77 fn extract_token(headers: &HeaderMap) -> Option<&str> {
78 headers
79 .get("authorization")
80 .and_then(|h| {
81 h.to_str()
82 .map_err(|e| {
83 tracing::debug!(error = %e, "Authorization header contains non-ASCII characters");
84 e
85 })
86 .ok()
87 })
88 .and_then(|s| s.strip_prefix(BEARER_PREFIX))
89 }
90
91 fn validate_token(&self, token: &str) -> AuthResult<ValidatedSessionClaims> {
92 use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
93
94 let mut validation = Validation::new(Algorithm::HS256);
95
96 validation.set_issuer(&[&self.issuer]);
97
98 let audience_strs: Vec<&str> = self.audiences.iter().map(JwtAudience::as_str).collect();
99 validation.set_audience(&audience_strs);
100
101 let token_data = decode::<JwtClaims>(
102 token,
103 &DecodingKey::from_secret(self.secret.as_bytes()),
104 &validation,
105 )
106 .map_err(AuthError::InvalidToken)?;
107
108 let claims = token_data.claims;
109
110 let user_type = if claims.scope.contains(&Permission::Admin) {
111 UserType::Admin
112 } else {
113 claims.user_type
114 };
115
116 Ok(ValidatedSessionClaims {
117 user_id: UserId::new(claims.sub),
118 session_id: claims
119 .session_id
120 .map(SessionId::new)
121 .ok_or(AuthError::MissingSessionId)?,
122 user_type,
123 })
124 }
125
126 fn create_context_from_claims(
127 claims: &ValidatedSessionClaims,
128 token: &str,
129 headers: &HeaderMap,
130 ) -> RequestContext {
131 let session_id = claims.session_id.clone();
132 let user_id = claims.user_id.clone();
133
134 RequestContext::new(
135 session_id,
136 HeaderExtractor::extract_trace_id(headers),
137 HeaderExtractor::extract_context_id(headers).unwrap_or_else(ContextId::generate),
138 HeaderExtractor::extract_agent_name(headers),
139 )
140 .with_user_id(user_id)
141 .with_auth_token(token)
142 .with_user_type(claims.user_type)
143 }
144
145 fn create_anonymous_context(headers: &HeaderMap) -> RequestContext {
146 RequestContext::new(
147 SessionId::new(ANONYMOUS_SESSION_ID.to_string()),
148 HeaderExtractor::extract_trace_id(headers),
149 HeaderExtractor::extract_context_id(headers).unwrap_or_else(ContextId::generate),
150 HeaderExtractor::extract_agent_name(headers),
151 )
152 .with_user_id(UserId::anonymous())
153 .with_user_type(UserType::Anon)
154 }
155
156 fn create_test_context() -> RequestContext {
157 RequestContext::new(
158 SessionId::new(TEST_SESSION_ID.to_string()),
159 TraceId::new(TEST_TRACE_ID.to_string()),
160 ContextId::generate(),
161 AgentName::new(TEST_AGENT_NAME.to_string()),
162 )
163 .with_user_id(UserId::new(TEST_USER_ID.to_string()))
164 .with_user_type(UserType::User)
165 }
166}