systemprompt_security/auth/
validation.rs1use anyhow::{anyhow, Result};
2use axum::http::HeaderMap;
3use systemprompt_identifiers::{AgentName, ContextId, SessionId, TraceId, UserId};
4use systemprompt_models::auth::{JwtAudience, JwtClaims, Permission, UserType};
5use systemprompt_models::execution::context::RequestContext;
6
7use crate::extraction::HeaderExtractor;
8use crate::session::ValidatedSessionClaims;
9
/// Authorization-header scheme prefix (including the separating space) that
/// precedes the JWT.
const BEARER_PREFIX: &str = "Bearer ";

/// Session id given to requests that end up without authenticated claims.
const ANONYMOUS_SESSION_ID: &str = "anonymous";

// Fixed identifiers for the synthetic context returned when authentication is
// disabled (`AuthMode::Disabled`).
const TEST_SESSION_ID: &str = "test";
const TEST_TRACE_ID: &str = "test-trace";
const TEST_CONTEXT_ID: &str = "test-context";
const TEST_AGENT_NAME: &str = "test-agent";
const TEST_USER_ID: &str = "test-user";
17
/// How strictly `AuthValidationService::validate_request` treats the
/// credentials carried by a request.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AuthMode {
    /// A valid bearer token is mandatory; a missing or invalid one is an error.
    Required,
    /// A valid bearer token is used when present; otherwise the request
    /// proceeds with an anonymous context.
    Optional,
    /// Request headers are ignored and a fixed test context is produced.
    Disabled,
}
24
25#[derive(Debug)]
26pub struct AuthValidationService {
27 secret: String,
28 issuer: String,
29 audiences: Vec<JwtAudience>,
30}
31
32impl AuthValidationService {
33 pub const fn new(secret: String, issuer: String, audiences: Vec<JwtAudience>) -> Self {
34 Self {
35 secret,
36 issuer,
37 audiences,
38 }
39 }
40
41 pub fn validate_request(&self, headers: &HeaderMap, mode: AuthMode) -> Result<RequestContext> {
42 match mode {
43 AuthMode::Required => self.validate_and_fail_fast(headers),
44 AuthMode::Optional => Ok(self.try_validate_or_anonymous(headers)),
45 AuthMode::Disabled => Ok(Self::create_test_context()),
46 }
47 }
48
49 fn validate_and_fail_fast(&self, headers: &HeaderMap) -> Result<RequestContext> {
50 let token =
51 Self::extract_token(headers).ok_or_else(|| anyhow!("Missing authorization header"))?;
52
53 let claims = self.validate_token(token)?;
54 Ok(Self::create_context_from_claims(&claims, token, headers))
55 }
56
57 fn try_validate_or_anonymous(&self, headers: &HeaderMap) -> RequestContext {
58 Self::extract_token(headers).map_or_else(
59 || Self::create_anonymous_context(headers),
60 |token| {
61 self.validate_token(token)
62 .map_err(|e| {
63 tracing::debug!(error = %e, "Token validation failed, falling back to anonymous");
64 e
65 })
66 .map_or_else(
67 |_| Self::create_anonymous_context(headers),
68 |claims| Self::create_context_from_claims(&claims, token, headers),
69 )
70 },
71 )
72 }
73
74 fn extract_token(headers: &HeaderMap) -> Option<&str> {
75 headers
76 .get("authorization")
77 .and_then(|h| {
78 h.to_str()
79 .map_err(|e| {
80 tracing::debug!(error = %e, "Authorization header contains non-ASCII characters");
81 e
82 })
83 .ok()
84 })
85 .and_then(|s| s.strip_prefix(BEARER_PREFIX))
86 }
87
88 fn validate_token(&self, token: &str) -> Result<ValidatedSessionClaims> {
89 use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
90
91 let mut validation = Validation::new(Algorithm::HS256);
92
93 validation.set_issuer(&[&self.issuer]);
94
95 let audience_strs: Vec<&str> = self.audiences.iter().map(JwtAudience::as_str).collect();
96 validation.set_audience(&audience_strs);
97
98 let token_data = decode::<JwtClaims>(
99 token,
100 &DecodingKey::from_secret(self.secret.as_bytes()),
101 &validation,
102 )
103 .map_err(|e| anyhow!("Invalid JWT token: {e}"))?;
104
105 let claims = token_data.claims;
106
107 let user_type = if claims.scope.contains(&Permission::Admin) {
108 UserType::Admin
109 } else {
110 claims.user_type
111 };
112
113 Ok(ValidatedSessionClaims {
114 user_id: claims.sub,
115 session_id: claims
116 .session_id
117 .ok_or_else(|| anyhow!("Missing session_id in token"))?,
118 user_type,
119 })
120 }
121
122 fn create_context_from_claims(
123 claims: &ValidatedSessionClaims,
124 token: &str,
125 headers: &HeaderMap,
126 ) -> RequestContext {
127 let session_id = SessionId::new(claims.session_id.clone());
128 let user_id = UserId::new(claims.user_id.clone());
129
130 RequestContext::new(
131 session_id,
132 HeaderExtractor::extract_trace_id(headers),
133 HeaderExtractor::extract_context_id(headers),
134 HeaderExtractor::extract_agent_name(headers),
135 )
136 .with_user_id(user_id)
137 .with_auth_token(token)
138 .with_user_type(claims.user_type)
139 }
140
141 fn create_anonymous_context(headers: &HeaderMap) -> RequestContext {
142 RequestContext::new(
143 SessionId::new(ANONYMOUS_SESSION_ID.to_string()),
144 HeaderExtractor::extract_trace_id(headers),
145 HeaderExtractor::extract_context_id(headers),
146 HeaderExtractor::extract_agent_name(headers),
147 )
148 .with_user_id(UserId::anonymous())
149 .with_user_type(UserType::Anon)
150 }
151
152 fn create_test_context() -> RequestContext {
153 RequestContext::new(
154 SessionId::new(TEST_SESSION_ID.to_string()),
155 TraceId::new(TEST_TRACE_ID.to_string()),
156 ContextId::new(TEST_CONTEXT_ID.to_string()),
157 AgentName::new(TEST_AGENT_NAME.to_string()),
158 )
159 .with_user_id(UserId::new(TEST_USER_ID.to_string()))
160 .with_user_type(UserType::User)
161 }
162}