Skip to main content

systemprompt_security/auth/
validation.rs

1use anyhow::{anyhow, Result};
2use axum::http::HeaderMap;
3use systemprompt_identifiers::{AgentName, ContextId, SessionId, TraceId, UserId};
4use systemprompt_models::auth::{JwtAudience, JwtClaims, Permission, UserType};
5use systemprompt_models::execution::context::RequestContext;
6
7use crate::extraction::HeaderExtractor;
8use crate::session::ValidatedSessionClaims;
9
/// Authorization scheme prefix (including the separating space) that precedes
/// the token in the `Authorization` header.
const BEARER_PREFIX: &str = "Bearer ";

/// Session identifier given to requests that end up unauthenticated.
const ANONYMOUS_SESSION_ID: &str = "anonymous";

// Fixed identifiers for the synthetic context produced when authentication
// is turned off entirely.
const TEST_USER_ID: &str = "test-user";
const TEST_SESSION_ID: &str = "test";
const TEST_TRACE_ID: &str = "test-trace";
const TEST_CONTEXT_ID: &str = "test-context";
const TEST_AGENT_NAME: &str = "test-agent";
17
/// How strictly an incoming request's credentials are enforced.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AuthMode {
    /// A valid bearer token must be present; anything else is an error.
    Required,
    /// A valid bearer token is honoured when present; otherwise the request
    /// proceeds as anonymous.
    Optional,
    /// No credential checks at all; a fixed test identity is used.
    Disabled,
}
24
25#[derive(Debug)]
26pub struct AuthValidationService {
27    secret: String,
28    issuer: String,
29    audiences: Vec<JwtAudience>,
30}
31
32impl AuthValidationService {
33    pub const fn new(secret: String, issuer: String, audiences: Vec<JwtAudience>) -> Self {
34        Self {
35            secret,
36            issuer,
37            audiences,
38        }
39    }
40
41    pub fn validate_request(&self, headers: &HeaderMap, mode: AuthMode) -> Result<RequestContext> {
42        match mode {
43            AuthMode::Required => self.validate_and_fail_fast(headers),
44            AuthMode::Optional => Ok(self.try_validate_or_anonymous(headers)),
45            AuthMode::Disabled => Ok(Self::create_test_context()),
46        }
47    }
48
49    fn validate_and_fail_fast(&self, headers: &HeaderMap) -> Result<RequestContext> {
50        let token =
51            Self::extract_token(headers).ok_or_else(|| anyhow!("Missing authorization header"))?;
52
53        let claims = self.validate_token(token)?;
54        Ok(Self::create_context_from_claims(&claims, token, headers))
55    }
56
57    fn try_validate_or_anonymous(&self, headers: &HeaderMap) -> RequestContext {
58        Self::extract_token(headers).map_or_else(
59            || Self::create_anonymous_context(headers),
60            |token| {
61                self.validate_token(token)
62                    .map_err(|e| {
63                        tracing::debug!(error = %e, "Token validation failed, falling back to anonymous");
64                        e
65                    })
66                    .map_or_else(
67                        |_| Self::create_anonymous_context(headers),
68                        |claims| Self::create_context_from_claims(&claims, token, headers),
69                    )
70            },
71        )
72    }
73
74    fn extract_token(headers: &HeaderMap) -> Option<&str> {
75        headers
76            .get("authorization")
77            .and_then(|h| {
78                h.to_str()
79                    .map_err(|e| {
80                        tracing::debug!(error = %e, "Authorization header contains non-ASCII characters");
81                        e
82                    })
83                    .ok()
84            })
85            .and_then(|s| s.strip_prefix(BEARER_PREFIX))
86    }
87
88    fn validate_token(&self, token: &str) -> Result<ValidatedSessionClaims> {
89        use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
90
91        let mut validation = Validation::new(Algorithm::HS256);
92
93        validation.set_issuer(&[&self.issuer]);
94
95        let audience_strs: Vec<&str> = self.audiences.iter().map(JwtAudience::as_str).collect();
96        validation.set_audience(&audience_strs);
97
98        let token_data = decode::<JwtClaims>(
99            token,
100            &DecodingKey::from_secret(self.secret.as_bytes()),
101            &validation,
102        )
103        .map_err(|e| anyhow!("Invalid JWT token: {e}"))?;
104
105        let claims = token_data.claims;
106
107        let user_type = if claims.scope.contains(&Permission::Admin) {
108            UserType::Admin
109        } else {
110            claims.user_type
111        };
112
113        Ok(ValidatedSessionClaims {
114            user_id: claims.sub,
115            session_id: claims
116                .session_id
117                .ok_or_else(|| anyhow!("Missing session_id in token"))?,
118            user_type,
119        })
120    }
121
122    fn create_context_from_claims(
123        claims: &ValidatedSessionClaims,
124        token: &str,
125        headers: &HeaderMap,
126    ) -> RequestContext {
127        let session_id = SessionId::new(claims.session_id.clone());
128        let user_id = UserId::new(claims.user_id.clone());
129
130        RequestContext::new(
131            session_id,
132            HeaderExtractor::extract_trace_id(headers),
133            HeaderExtractor::extract_context_id(headers),
134            HeaderExtractor::extract_agent_name(headers),
135        )
136        .with_user_id(user_id)
137        .with_auth_token(token)
138        .with_user_type(claims.user_type)
139    }
140
141    fn create_anonymous_context(headers: &HeaderMap) -> RequestContext {
142        RequestContext::new(
143            SessionId::new(ANONYMOUS_SESSION_ID.to_string()),
144            HeaderExtractor::extract_trace_id(headers),
145            HeaderExtractor::extract_context_id(headers),
146            HeaderExtractor::extract_agent_name(headers),
147        )
148        .with_user_id(UserId::anonymous())
149        .with_user_type(UserType::Anon)
150    }
151
152    fn create_test_context() -> RequestContext {
153        RequestContext::new(
154            SessionId::new(TEST_SESSION_ID.to_string()),
155            TraceId::new(TEST_TRACE_ID.to_string()),
156            ContextId::new(TEST_CONTEXT_ID.to_string()),
157            AgentName::new(TEST_AGENT_NAME.to_string()),
158        )
159        .with_user_id(UserId::new(TEST_USER_ID.to_string()))
160        .with_user_type(UserType::User)
161    }
162}