Skip to main content

systemprompt_security/auth/
validation.rs

1use axum::http::HeaderMap;
2use systemprompt_identifiers::{AgentName, ContextId, SessionId, TraceId, UserId};
3use systemprompt_models::auth::{JwtAudience, JwtClaims, Permission, UserType};
4use systemprompt_models::execution::context::RequestContext;
5
6use crate::error::{AuthError, AuthResult};
7use crate::extraction::HeaderExtractor;
8use crate::session::ValidatedSessionClaims;
9
/// Scheme prefix (including the separating space) expected in the `Authorization` header.
const BEARER_PREFIX: &str = "Bearer ";

/// Session id given to contexts built for requests without usable credentials.
const ANONYMOUS_SESSION_ID: &str = "anonymous";

// Fixed identifiers used for the context returned when authentication is disabled.
const TEST_SESSION_ID: &str = "test";
const TEST_TRACE_ID: &str = "test-trace";
const TEST_CONTEXT_ID: &str = "test-context";
const TEST_AGENT_NAME: &str = "test-agent";
const TEST_USER_ID: &str = "test-user";
17
/// How strictly [`AuthValidationService::validate_request`] treats incoming credentials.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AuthMode {
    /// A valid bearer token must be present, otherwise the request is rejected.
    Required,
    /// A valid token is used when present; a missing or invalid one yields an anonymous context.
    Optional,
    /// No validation is performed and a fixed test context is returned.
    Disabled,
}
24
25#[derive(Debug)]
26pub struct AuthValidationService {
27    secret: String,
28    issuer: String,
29    audiences: Vec<JwtAudience>,
30}
31
32impl AuthValidationService {
33    #[must_use]
34    pub const fn new(secret: String, issuer: String, audiences: Vec<JwtAudience>) -> Self {
35        Self {
36            secret,
37            issuer,
38            audiences,
39        }
40    }
41
42    pub fn validate_request(
43        &self,
44        headers: &HeaderMap,
45        mode: AuthMode,
46    ) -> AuthResult<RequestContext> {
47        match mode {
48            AuthMode::Required => self.validate_and_fail_fast(headers),
49            AuthMode::Optional => Ok(self.try_validate_or_anonymous(headers)),
50            AuthMode::Disabled => Ok(Self::create_test_context()),
51        }
52    }
53
54    fn validate_and_fail_fast(&self, headers: &HeaderMap) -> AuthResult<RequestContext> {
55        let token = Self::extract_token(headers).ok_or(AuthError::MissingAuthorization)?;
56
57        let claims = self.validate_token(token)?;
58        Ok(Self::create_context_from_claims(&claims, token, headers))
59    }
60
61    fn try_validate_or_anonymous(&self, headers: &HeaderMap) -> RequestContext {
62        Self::extract_token(headers).map_or_else(
63            || Self::create_anonymous_context(headers),
64            |token| {
65                self.validate_token(token)
66                    .map_err(|e| {
67                        tracing::debug!(error = %e, "Token validation failed, falling back to anonymous");
68                        e
69                    })
70                    .map_or_else(
71                        |_| Self::create_anonymous_context(headers),
72                        |claims| Self::create_context_from_claims(&claims, token, headers),
73                    )
74            },
75        )
76    }
77
78    fn extract_token(headers: &HeaderMap) -> Option<&str> {
79        headers
80            .get("authorization")
81            .and_then(|h| {
82                h.to_str()
83                    .map_err(|e| {
84                        tracing::debug!(error = %e, "Authorization header contains non-ASCII characters");
85                        e
86                    })
87                    .ok()
88            })
89            .and_then(|s| s.strip_prefix(BEARER_PREFIX))
90    }
91
92    fn validate_token(&self, token: &str) -> AuthResult<ValidatedSessionClaims> {
93        use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
94
95        let mut validation = Validation::new(Algorithm::HS256);
96
97        validation.set_issuer(&[&self.issuer]);
98
99        let audience_strs: Vec<&str> = self.audiences.iter().map(JwtAudience::as_str).collect();
100        validation.set_audience(&audience_strs);
101
102        let token_data = decode::<JwtClaims>(
103            token,
104            &DecodingKey::from_secret(self.secret.as_bytes()),
105            &validation,
106        )
107        .map_err(AuthError::InvalidToken)?;
108
109        let claims = token_data.claims;
110
111        let user_type = if claims.scope.contains(&Permission::Admin) {
112            UserType::Admin
113        } else {
114            claims.user_type
115        };
116
117        Ok(ValidatedSessionClaims {
118            user_id: UserId::new(claims.sub),
119            session_id: claims
120                .session_id
121                .map(SessionId::new)
122                .ok_or(AuthError::MissingSessionId)?,
123            user_type,
124        })
125    }
126
127    fn create_context_from_claims(
128        claims: &ValidatedSessionClaims,
129        token: &str,
130        headers: &HeaderMap,
131    ) -> RequestContext {
132        let session_id = claims.session_id.clone();
133        let user_id = claims.user_id.clone();
134
135        RequestContext::new(
136            session_id,
137            HeaderExtractor::extract_trace_id(headers),
138            HeaderExtractor::extract_context_id(headers),
139            HeaderExtractor::extract_agent_name(headers),
140        )
141        .with_user_id(user_id)
142        .with_auth_token(token)
143        .with_user_type(claims.user_type)
144    }
145
146    fn create_anonymous_context(headers: &HeaderMap) -> RequestContext {
147        RequestContext::new(
148            SessionId::new(ANONYMOUS_SESSION_ID.to_string()),
149            HeaderExtractor::extract_trace_id(headers),
150            HeaderExtractor::extract_context_id(headers),
151            HeaderExtractor::extract_agent_name(headers),
152        )
153        .with_user_id(UserId::anonymous())
154        .with_user_type(UserType::Anon)
155    }
156
157    fn create_test_context() -> RequestContext {
158        RequestContext::new(
159            SessionId::new(TEST_SESSION_ID.to_string()),
160            TraceId::new(TEST_TRACE_ID.to_string()),
161            ContextId::new(TEST_CONTEXT_ID.to_string()),
162            AgentName::new(TEST_AGENT_NAME.to_string()),
163        )
164        .with_user_id(UserId::new(TEST_USER_ID.to_string()))
165        .with_user_type(UserType::User)
166    }
167}