systemprompt_security/auth/
validation.rs1use axum::http::HeaderMap;
2use systemprompt_identifiers::{AgentName, ContextId, SessionId, TraceId, UserId};
3use systemprompt_models::auth::{JwtAudience, JwtClaims, Permission, UserType};
4use systemprompt_models::execution::context::RequestContext;
5
6use crate::error::{AuthError, AuthResult};
7use crate::extraction::HeaderExtractor;
8use crate::session::ValidatedSessionClaims;
9
/// Scheme prefix expected at the start of the `Authorization` header value.
const BEARER_PREFIX: &str = "Bearer ";

/// Session identifier assigned to requests that carry no usable credentials.
const ANONYMOUS_SESSION_ID: &str = "anonymous";

// Fixed identifiers used to build the request context when authentication is
// disabled (`AuthMode::Disabled`).
const TEST_SESSION_ID: &str = "test";
const TEST_TRACE_ID: &str = "test-trace";
const TEST_CONTEXT_ID: &str = "test-context";
const TEST_AGENT_NAME: &str = "test-agent";
const TEST_USER_ID: &str = "test-user";
17
/// How strictly [`AuthValidationService::validate_request`] treats request
/// credentials.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AuthMode {
    /// A valid bearer token must be present; a missing or invalid token is an
    /// error.
    Required,
    /// A valid bearer token authenticates the request; a missing or invalid
    /// token results in an anonymous request context instead of an error.
    Optional,
    /// Request headers are ignored and a fixed test context is used.
    Disabled,
}
24
25#[derive(Debug)]
26pub struct AuthValidationService {
27 secret: String,
28 issuer: String,
29 audiences: Vec<JwtAudience>,
30}
31
32impl AuthValidationService {
33 #[must_use]
34 pub const fn new(secret: String, issuer: String, audiences: Vec<JwtAudience>) -> Self {
35 Self {
36 secret,
37 issuer,
38 audiences,
39 }
40 }
41
42 pub fn validate_request(
43 &self,
44 headers: &HeaderMap,
45 mode: AuthMode,
46 ) -> AuthResult<RequestContext> {
47 match mode {
48 AuthMode::Required => self.validate_and_fail_fast(headers),
49 AuthMode::Optional => Ok(self.try_validate_or_anonymous(headers)),
50 AuthMode::Disabled => Ok(Self::create_test_context()),
51 }
52 }
53
54 fn validate_and_fail_fast(&self, headers: &HeaderMap) -> AuthResult<RequestContext> {
55 let token = Self::extract_token(headers).ok_or(AuthError::MissingAuthorization)?;
56
57 let claims = self.validate_token(token)?;
58 Ok(Self::create_context_from_claims(&claims, token, headers))
59 }
60
61 fn try_validate_or_anonymous(&self, headers: &HeaderMap) -> RequestContext {
62 Self::extract_token(headers).map_or_else(
63 || Self::create_anonymous_context(headers),
64 |token| {
65 self.validate_token(token)
66 .map_err(|e| {
67 tracing::debug!(error = %e, "Token validation failed, falling back to anonymous");
68 e
69 })
70 .map_or_else(
71 |_| Self::create_anonymous_context(headers),
72 |claims| Self::create_context_from_claims(&claims, token, headers),
73 )
74 },
75 )
76 }
77
78 fn extract_token(headers: &HeaderMap) -> Option<&str> {
79 headers
80 .get("authorization")
81 .and_then(|h| {
82 h.to_str()
83 .map_err(|e| {
84 tracing::debug!(error = %e, "Authorization header contains non-ASCII characters");
85 e
86 })
87 .ok()
88 })
89 .and_then(|s| s.strip_prefix(BEARER_PREFIX))
90 }
91
92 fn validate_token(&self, token: &str) -> AuthResult<ValidatedSessionClaims> {
93 use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
94
95 let mut validation = Validation::new(Algorithm::HS256);
96
97 validation.set_issuer(&[&self.issuer]);
98
99 let audience_strs: Vec<&str> = self.audiences.iter().map(JwtAudience::as_str).collect();
100 validation.set_audience(&audience_strs);
101
102 let token_data = decode::<JwtClaims>(
103 token,
104 &DecodingKey::from_secret(self.secret.as_bytes()),
105 &validation,
106 )
107 .map_err(AuthError::InvalidToken)?;
108
109 let claims = token_data.claims;
110
111 let user_type = if claims.scope.contains(&Permission::Admin) {
112 UserType::Admin
113 } else {
114 claims.user_type
115 };
116
117 Ok(ValidatedSessionClaims {
118 user_id: UserId::new(claims.sub),
119 session_id: claims
120 .session_id
121 .map(SessionId::new)
122 .ok_or(AuthError::MissingSessionId)?,
123 user_type,
124 })
125 }
126
127 fn create_context_from_claims(
128 claims: &ValidatedSessionClaims,
129 token: &str,
130 headers: &HeaderMap,
131 ) -> RequestContext {
132 let session_id = claims.session_id.clone();
133 let user_id = claims.user_id.clone();
134
135 RequestContext::new(
136 session_id,
137 HeaderExtractor::extract_trace_id(headers),
138 HeaderExtractor::extract_context_id(headers),
139 HeaderExtractor::extract_agent_name(headers),
140 )
141 .with_user_id(user_id)
142 .with_auth_token(token)
143 .with_user_type(claims.user_type)
144 }
145
146 fn create_anonymous_context(headers: &HeaderMap) -> RequestContext {
147 RequestContext::new(
148 SessionId::new(ANONYMOUS_SESSION_ID.to_string()),
149 HeaderExtractor::extract_trace_id(headers),
150 HeaderExtractor::extract_context_id(headers),
151 HeaderExtractor::extract_agent_name(headers),
152 )
153 .with_user_id(UserId::anonymous())
154 .with_user_type(UserType::Anon)
155 }
156
157 fn create_test_context() -> RequestContext {
158 RequestContext::new(
159 SessionId::new(TEST_SESSION_ID.to_string()),
160 TraceId::new(TEST_TRACE_ID.to_string()),
161 ContextId::new(TEST_CONTEXT_ID.to_string()),
162 AgentName::new(TEST_AGENT_NAME.to_string()),
163 )
164 .with_user_id(UserId::new(TEST_USER_ID.to_string()))
165 .with_user_type(UserType::User)
166 }
167}