systemprompt_oauth/services/
auth_provider.rs1use async_trait::async_trait;
2use std::sync::Arc;
3use systemprompt_models::auth::JwtAudience;
4use systemprompt_traits::{
5 AuthAction, AuthPermission, AuthProvider, AuthProviderError, AuthResult, AuthorizationProvider,
6 TokenClaims, TokenPair,
7};
8
9use crate::models::JwtClaims as OAuthJwtClaims;
10use crate::services::validation::jwt as jwt_validation;
11
12#[derive(Debug, Clone)]
13pub struct JwtAuthProvider {
14 secret: String,
15 issuer: String,
16 audiences: Vec<JwtAudience>,
17}
18
19impl JwtAuthProvider {
20 pub const fn new(secret: String, issuer: String, audiences: Vec<JwtAudience>) -> Self {
21 Self {
22 secret,
23 issuer,
24 audiences,
25 }
26 }
27
28 pub fn from_config() -> anyhow::Result<Self> {
29 let config = systemprompt_models::Config::get()?;
30 Ok(Self {
31 secret: systemprompt_models::SecretsBootstrap::jwt_secret()?.to_string(),
32 issuer: config.jwt_issuer.clone(),
33 audiences: config.jwt_audiences.clone(),
34 })
35 }
36}
37
38fn convert_claims(claims: OAuthJwtClaims) -> TokenClaims {
39 TokenClaims {
40 subject: claims.sub,
41 username: claims.username,
42 email: Some(claims.email),
43 audiences: claims.aud.iter().map(ToString::to_string).collect(),
44 permissions: claims.scope.iter().map(ToString::to_string).collect(),
45 expires_at: claims.exp,
46 issued_at: claims.iat,
47 }
48}
49
50#[async_trait]
51impl AuthProvider for JwtAuthProvider {
52 async fn validate_token(&self, token: &str) -> AuthResult<TokenClaims> {
53 let claims =
54 jwt_validation::validate_jwt_token(token, &self.secret, &self.issuer, &self.audiences)
55 .map_err(|e| {
56 AuthProviderError::Internal(format!("Token validation failed: {e}"))
57 })?;
58
59 Ok(convert_claims(claims))
60 }
61
62 async fn refresh_token(&self, _refresh_token: &str) -> AuthResult<TokenPair> {
63 Err(AuthProviderError::Internal(
64 "Token refresh not yet implemented via trait".to_string(),
65 ))
66 }
67
68 async fn revoke_token(&self, _token: &str) -> AuthResult<()> {
69 Err(AuthProviderError::Internal(
70 "Token revocation not yet implemented via trait".to_string(),
71 ))
72 }
73}
74
75#[derive(Debug, Clone)]
76pub struct JwtAuthorizationProvider {
77 secret: String,
78 issuer: String,
79 audiences: Vec<JwtAudience>,
80}
81
82impl JwtAuthorizationProvider {
83 pub const fn new(secret: String, issuer: String, audiences: Vec<JwtAudience>) -> Self {
84 Self {
85 secret,
86 issuer,
87 audiences,
88 }
89 }
90}
91
92#[async_trait]
93impl AuthorizationProvider for JwtAuthorizationProvider {
94 async fn authorize(
95 &self,
96 _user_id: &str,
97 _resource: &str,
98 _action: &AuthAction,
99 ) -> AuthResult<bool> {
100 Ok(true)
101 }
102
103 async fn get_permissions(&self, _user_id: &str) -> AuthResult<Vec<AuthPermission>> {
104 Ok(vec![])
105 }
106
107 async fn has_audience(&self, token: &str, audience: &str) -> AuthResult<bool> {
108 let claims =
109 jwt_validation::validate_jwt_token(token, &self.secret, &self.issuer, &self.audiences)
110 .map_err(|e| {
111 AuthProviderError::Internal(format!("Token validation failed: {e}"))
112 })?;
113
114 let has_aud = claims.aud.iter().any(|a| a.to_string() == audience);
115 Ok(has_aud)
116 }
117}
118
119#[derive(Clone)]
120pub struct TraitBasedAuthService {
121 auth_provider: Arc<dyn AuthProvider>,
122 authorization_provider: Arc<dyn AuthorizationProvider>,
123}
124
125impl std::fmt::Debug for TraitBasedAuthService {
126 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127 f.debug_struct("TraitBasedAuthService")
128 .field("auth_provider", &"AuthProvider")
129 .field("authorization_provider", &"AuthorizationProvider")
130 .finish()
131 }
132}
133
134impl TraitBasedAuthService {
135 pub fn new(
136 auth_provider: Arc<dyn AuthProvider>,
137 authorization_provider: Arc<dyn AuthorizationProvider>,
138 ) -> Self {
139 Self {
140 auth_provider,
141 authorization_provider,
142 }
143 }
144
145 pub fn from_config() -> anyhow::Result<Self> {
146 let config = systemprompt_models::Config::get()?;
147 let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret()?.to_string();
148 let auth = Arc::new(JwtAuthProvider::new(
149 jwt_secret.clone(),
150 config.jwt_issuer.clone(),
151 config.jwt_audiences.clone(),
152 ));
153 let authz = Arc::new(JwtAuthorizationProvider::new(
154 jwt_secret,
155 config.jwt_issuer.clone(),
156 config.jwt_audiences.clone(),
157 ));
158 Ok(Self::new(auth, authz))
159 }
160
161 pub fn auth_provider(&self) -> &Arc<dyn AuthProvider> {
162 &self.auth_provider
163 }
164
165 pub fn authorization_provider(&self) -> &Arc<dyn AuthorizationProvider> {
166 &self.authorization_provider
167 }
168
169 pub async fn validate_token(&self, token: &str) -> AuthResult<TokenClaims> {
170 self.auth_provider.validate_token(token).await
171 }
172
173 pub async fn has_audience(&self, token: &str, audience: &str) -> AuthResult<bool> {
174 self.authorization_provider
175 .has_audience(token, audience)
176 .await
177 }
178}