Skip to main content

systemprompt_oauth/services/
auth_provider.rs

1use async_trait::async_trait;
2use std::sync::Arc;
3use systemprompt_models::auth::JwtAudience;
4use systemprompt_traits::{
5    AuthAction, AuthPermission, AuthProvider, AuthProviderError, AuthResult, AuthorizationProvider,
6    TokenClaims, TokenPair,
7};
8
9use crate::models::JwtClaims as OAuthJwtClaims;
10use crate::services::validation::jwt as jwt_validation;
11
12#[derive(Debug, Clone)]
13pub struct JwtAuthProvider {
14    secret: String,
15    issuer: String,
16    audiences: Vec<JwtAudience>,
17}
18
19impl JwtAuthProvider {
20    pub const fn new(secret: String, issuer: String, audiences: Vec<JwtAudience>) -> Self {
21        Self {
22            secret,
23            issuer,
24            audiences,
25        }
26    }
27
28    pub fn from_config() -> anyhow::Result<Self> {
29        let config = systemprompt_models::Config::get()?;
30        Ok(Self {
31            secret: systemprompt_models::SecretsBootstrap::jwt_secret()?.to_string(),
32            issuer: config.jwt_issuer.clone(),
33            audiences: config.jwt_audiences.clone(),
34        })
35    }
36}
37
38fn convert_claims(claims: OAuthJwtClaims) -> TokenClaims {
39    TokenClaims {
40        subject: claims.sub,
41        username: claims.username,
42        email: Some(claims.email),
43        audiences: claims.aud.iter().map(ToString::to_string).collect(),
44        permissions: claims.scope.iter().map(ToString::to_string).collect(),
45        expires_at: claims.exp,
46        issued_at: claims.iat,
47    }
48}
49
50#[async_trait]
51impl AuthProvider for JwtAuthProvider {
52    async fn validate_token(&self, token: &str) -> AuthResult<TokenClaims> {
53        let claims =
54            jwt_validation::validate_jwt_token(token, &self.secret, &self.issuer, &self.audiences)
55                .map_err(|e| {
56                    AuthProviderError::Internal(format!("Token validation failed: {e}"))
57                })?;
58
59        Ok(convert_claims(claims))
60    }
61
62    async fn refresh_token(&self, _refresh_token: &str) -> AuthResult<TokenPair> {
63        Err(AuthProviderError::Internal(
64            "Token refresh not yet implemented via trait".to_string(),
65        ))
66    }
67
68    async fn revoke_token(&self, _token: &str) -> AuthResult<()> {
69        Err(AuthProviderError::Internal(
70            "Token revocation not yet implemented via trait".to_string(),
71        ))
72    }
73}
74
75#[derive(Debug, Clone)]
76pub struct JwtAuthorizationProvider {
77    secret: String,
78    issuer: String,
79    audiences: Vec<JwtAudience>,
80}
81
82impl JwtAuthorizationProvider {
83    pub const fn new(secret: String, issuer: String, audiences: Vec<JwtAudience>) -> Self {
84        Self {
85            secret,
86            issuer,
87            audiences,
88        }
89    }
90}
91
92#[async_trait]
93impl AuthorizationProvider for JwtAuthorizationProvider {
94    async fn authorize(
95        &self,
96        _user_id: &str,
97        _resource: &str,
98        _action: &AuthAction,
99    ) -> AuthResult<bool> {
100        Ok(true)
101    }
102
103    async fn get_permissions(&self, _user_id: &str) -> AuthResult<Vec<AuthPermission>> {
104        Ok(vec![])
105    }
106
107    async fn has_audience(&self, token: &str, audience: &str) -> AuthResult<bool> {
108        let claims =
109            jwt_validation::validate_jwt_token(token, &self.secret, &self.issuer, &self.audiences)
110                .map_err(|e| {
111                    AuthProviderError::Internal(format!("Token validation failed: {e}"))
112                })?;
113
114        let has_aud = claims.aud.iter().any(|a| a.to_string() == audience);
115        Ok(has_aud)
116    }
117}
118
119#[derive(Clone)]
120pub struct TraitBasedAuthService {
121    auth_provider: Arc<dyn AuthProvider>,
122    authorization_provider: Arc<dyn AuthorizationProvider>,
123}
124
125impl std::fmt::Debug for TraitBasedAuthService {
126    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127        f.debug_struct("TraitBasedAuthService")
128            .field("auth_provider", &"AuthProvider")
129            .field("authorization_provider", &"AuthorizationProvider")
130            .finish()
131    }
132}
133
134impl TraitBasedAuthService {
135    pub fn new(
136        auth_provider: Arc<dyn AuthProvider>,
137        authorization_provider: Arc<dyn AuthorizationProvider>,
138    ) -> Self {
139        Self {
140            auth_provider,
141            authorization_provider,
142        }
143    }
144
145    pub fn from_config() -> anyhow::Result<Self> {
146        let config = systemprompt_models::Config::get()?;
147        let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret()?.to_string();
148        let auth = Arc::new(JwtAuthProvider::new(
149            jwt_secret.clone(),
150            config.jwt_issuer.clone(),
151            config.jwt_audiences.clone(),
152        ));
153        let authz = Arc::new(JwtAuthorizationProvider::new(
154            jwt_secret,
155            config.jwt_issuer.clone(),
156            config.jwt_audiences.clone(),
157        ));
158        Ok(Self::new(auth, authz))
159    }
160
161    pub fn auth_provider(&self) -> &Arc<dyn AuthProvider> {
162        &self.auth_provider
163    }
164
165    pub fn authorization_provider(&self) -> &Arc<dyn AuthorizationProvider> {
166        &self.authorization_provider
167    }
168
169    pub async fn validate_token(&self, token: &str) -> AuthResult<TokenClaims> {
170        self.auth_provider.validate_token(token).await
171    }
172
173    pub async fn has_audience(&self, token: &str, audience: &str) -> AuthResult<bool> {
174        self.authorization_provider
175            .has_audience(token, audience)
176            .await
177    }
178}