1use crate::{AuthMethod, TokenValidator, error::AuthError, types::TokenClaims};
2use base64::Engine;
3use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::sync::RwLock;
6use tracing::{debug, warn};
7
8#[derive(Debug, Clone)]
10pub struct AuthContext {
11 pub user_id: Option<String>,
12 pub scopes: Vec<String>,
13 pub claims: Option<TokenClaims>,
14 pub auth_method: AuthMethod,
15 pub is_authenticated: bool,
16}
17
18impl AuthContext {
19 pub fn new() -> Self {
20 Self {
21 user_id: None,
22 scopes: Vec::new(),
23 claims: None,
24 auth_method: AuthMethod::None,
25 is_authenticated: false,
26 }
27 }
28
29 pub fn with_user_id(mut self, user_id: String) -> Self {
30 self.user_id = Some(user_id);
31 self
32 }
33
34 pub fn with_scopes(mut self, scopes: Vec<String>) -> Self {
35 self.scopes = scopes;
36 self
37 }
38
39 pub fn with_claims(mut self, claims: TokenClaims) -> Self {
40 self.claims = Some(claims);
41 self
42 }
43
44 pub fn with_auth_method(mut self, auth_method: AuthMethod) -> Self {
45 self.auth_method = auth_method;
46 self
47 }
48
49 pub fn authenticated(mut self) -> Self {
50 self.is_authenticated = true;
51 self
52 }
53
54 pub fn has_scope(&self, scope: &str) -> bool {
55 self.scopes.contains(&scope.to_string())
56 }
57
58 pub fn has_any_scope(&self, scopes: &[String]) -> bool {
59 scopes.iter().any(|scope| self.has_scope(scope))
60 }
61
62 pub fn has_all_scopes(&self, scopes: &[String]) -> bool {
63 scopes.iter().all(|scope| self.has_scope(scope))
64 }
65}
66
67impl Default for AuthContext {
68 fn default() -> Self {
69 Self::new()
70 }
71}
72
73pub struct ServerAuthMiddleware {
75 token_validator: Arc<TokenValidator>,
76 required_scopes: Vec<String>,
77 auth_enabled: bool,
78 session_store: Arc<RwLock<HashMap<String, AuthContext>>>,
79}
80
81impl ServerAuthMiddleware {
82 pub fn new(token_validator: TokenValidator) -> Self {
83 Self {
84 token_validator: Arc::new(token_validator),
85 required_scopes: Vec::new(),
86 auth_enabled: true,
87 session_store: Arc::new(RwLock::new(HashMap::new())),
88 }
89 }
90
91 pub fn with_required_scopes(mut self, scopes: Vec<String>) -> Self {
92 self.required_scopes = scopes;
93 self
94 }
95
96 pub fn with_auth_enabled(mut self, enabled: bool) -> Self {
97 self.auth_enabled = enabled;
98 self
99 }
100
101 pub async fn validate_request(
103 &self,
104 headers: &HashMap<String, String>,
105 ) -> Result<AuthContext, AuthError> {
106 if !self.auth_enabled {
107 return Ok(AuthContext::new().authenticated());
108 }
109
110 if let Some(auth_header) = headers.get("Authorization") {
112 return self.validate_auth_header(auth_header).await;
113 }
114
115 for (key, value) in headers {
117 if key.to_lowercase().contains("api-key") || key.to_lowercase().contains("x-api-key") {
118 return self.validate_api_key(key, value).await;
119 }
120 }
121
122 if self.required_scopes.is_empty() {
124 Ok(AuthContext::new())
126 } else {
127 Err(AuthError::InvalidCredentials)
128 }
129 }
130
131 async fn validate_auth_header(&self, auth_header: &str) -> Result<AuthContext, AuthError> {
133 if auth_header.starts_with("Bearer ") {
134 self.validate_bearer_token(auth_header).await
135 } else if auth_header.starts_with("Basic ") {
136 self.validate_basic_auth(auth_header).await
137 } else {
138 Err(AuthError::InvalidToken(
139 "Unsupported authorization scheme".to_string(),
140 ))
141 }
142 }
143
144 async fn validate_bearer_token(&self, auth_header: &str) -> Result<AuthContext, AuthError> {
146 let token = crate::validation::extract_bearer_token(auth_header)?;
147
148 let claims = self.token_validator.validate_token(token).await?;
150
151 if !self.required_scopes.is_empty() {
153 self.token_validator
154 .validate_scopes(&claims, &self.required_scopes)?;
155 }
156
157 let scopes = claims
158 .scope
159 .as_ref()
160 .map(|s| s.split_whitespace().map(|s| s.to_string()).collect())
161 .unwrap_or_default();
162
163 let auth_context = AuthContext::new()
164 .with_user_id(claims.sub.clone())
165 .with_scopes(scopes)
166 .with_claims(claims)
167 .with_auth_method(AuthMethod::bearer(token.to_string()))
168 .authenticated();
169
170 debug!(
171 "Bearer token validated for user: {}",
172 auth_context
173 .user_id
174 .as_ref()
175 .unwrap_or(&"unknown".to_string())
176 );
177 Ok(auth_context)
178 }
179
180 async fn validate_basic_auth(&self, auth_header: &str) -> Result<AuthContext, AuthError> {
182 let encoded = auth_header
184 .strip_prefix("Basic ")
185 .ok_or_else(|| AuthError::InvalidToken("Invalid Basic auth format".to_string()))?;
186
187 let decoded = base64::engine::general_purpose::STANDARD
188 .decode(encoded)
189 .map_err(|_| AuthError::InvalidToken("Invalid Basic auth encoding".to_string()))?;
190
191 let credentials = String::from_utf8(decoded)
192 .map_err(|_| AuthError::InvalidToken("Invalid Basic auth credentials".to_string()))?;
193
194 let parts: Vec<&str> = credentials.splitn(2, ':').collect();
195 if parts.len() != 2 {
196 return Err(AuthError::InvalidToken(
197 "Invalid Basic auth format".to_string(),
198 ));
199 }
200
201 let username = parts[0];
202 let password = parts[1];
203
204 if username.is_empty() || password.is_empty() {
207 return Err(AuthError::InvalidCredentials);
208 }
209
210 let auth_context = AuthContext::new()
211 .with_user_id(username.to_string())
212 .with_auth_method(AuthMethod::basic(
213 username.to_string(),
214 password.to_string(),
215 ))
216 .authenticated();
217
218 debug!("Basic auth validated for user: {}", username);
219 Ok(auth_context)
220 }
221
222 async fn validate_api_key(
224 &self,
225 _header_name: &str,
226 api_key: &str,
227 ) -> Result<AuthContext, AuthError> {
228 if api_key.is_empty() {
229 return Err(AuthError::InvalidCredentials);
230 }
231
232 let auth_context = AuthContext::new()
235 .with_user_id(format!("api_user_{}", &api_key[..8.min(api_key.len())]))
236 .with_auth_method(AuthMethod::api_key(api_key.to_string()))
237 .authenticated();
238
239 debug!(
240 "API key validated for user: {}",
241 auth_context
242 .user_id
243 .as_ref()
244 .unwrap_or(&"unknown".to_string())
245 );
246 Ok(auth_context)
247 }
248
249 pub async fn store_session(&self, session_id: String, auth_context: AuthContext) {
251 let mut sessions = self.session_store.write().await;
252 sessions.insert(session_id, auth_context);
253 }
254
255 pub async fn get_session(&self, session_id: &str) -> Option<AuthContext> {
257 let sessions = self.session_store.read().await;
258 sessions.get(session_id).cloned()
259 }
260
261 pub async fn remove_session(&self, session_id: &str) {
263 let mut sessions = self.session_store.write().await;
264 sessions.remove(session_id);
265 }
266
267 pub fn check_scopes(
269 &self,
270 auth_context: &AuthContext,
271 required_scopes: &[String],
272 ) -> Result<(), AuthError> {
273 if required_scopes.is_empty() {
274 return Ok(());
275 }
276
277 if !auth_context.has_all_scopes(required_scopes) {
278 let missing_scopes: Vec<String> = required_scopes
279 .iter()
280 .filter(|scope| !auth_context.has_scope(scope))
281 .cloned()
282 .collect();
283
284 return Err(AuthError::MissingScope {
285 scope: missing_scopes.join(", "),
286 });
287 }
288
289 Ok(())
290 }
291}
292
293pub struct ClientAuthMiddleware {
295 auth_method: AuthMethod,
296 auto_refresh: bool,
297}
298
299impl ClientAuthMiddleware {
300 pub fn new(auth_method: AuthMethod) -> Self {
301 Self {
302 auth_method,
303 auto_refresh: true,
304 }
305 }
306
307 pub fn with_auto_refresh(mut self, enabled: bool) -> Self {
308 self.auto_refresh = enabled;
309 self
310 }
311
312 pub async fn get_headers(&mut self) -> Result<HashMap<String, String>, AuthError> {
314 if self.auto_refresh && self.auth_method.requires_refresh() {
316 if let Err(e) = self.auth_method.refresh().await {
317 warn!("Failed to refresh authentication: {:?}", e);
318 }
319 }
320
321 self.auth_method.get_headers().await
322 }
323
324 pub fn with_auth_method(mut self, auth_method: AuthMethod) -> Self {
326 self.auth_method = auth_method;
327 self
328 }
329
330 pub fn get_auth_method(&self) -> &AuthMethod {
332 &self.auth_method
333 }
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    /// Server middleware with default settings and a fixed test secret.
    fn test_middleware() -> ServerAuthMiddleware {
        ServerAuthMiddleware::new(TokenValidator::new("test_secret".to_string()))
    }

    /// Header map holding a single `name: value` entry.
    fn headers_with(name: &str, value: &str) -> HashMap<String, String> {
        let mut headers = HashMap::new();
        headers.insert(name.to_string(), value.to_string());
        headers
    }

    #[tokio::test]
    async fn test_bearer_auth_validation() {
        let middleware = test_middleware()
            .with_required_scopes(vec!["read".to_string(), "write".to_string()]);
        let headers = headers_with("Authorization", "Bearer invalid_token");

        let result = middleware.validate_request(&headers).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_basic_auth_validation() {
        // "dXNlcjpwYXNz" is base64 for "user:pass".
        let headers = headers_with("Authorization", "Basic dXNlcjpwYXNz");

        let result = test_middleware().validate_request(&headers).await;
        assert!(matches!(
            result,
            Ok(ref ctx) if ctx.is_authenticated && ctx.user_id.as_deref() == Some("user")
        ));
    }

    #[tokio::test]
    async fn test_basic_auth_rejects_empty_password() {
        // "dXNlcjo=" is base64 for "user:".
        let headers = headers_with("Authorization", "Basic dXNlcjo=");

        let result = test_middleware().validate_request(&headers).await;
        assert!(matches!(result, Err(AuthError::InvalidCredentials)));
    }

    #[tokio::test]
    async fn test_basic_auth_rejects_missing_colon() {
        // "dXNlcg==" is base64 for "user".
        let headers = headers_with("Authorization", "Basic dXNlcg==");

        let result = test_middleware().validate_request(&headers).await;
        assert!(matches!(result, Err(AuthError::InvalidToken(_))));
    }

    #[tokio::test]
    async fn test_unsupported_scheme_is_rejected() {
        let headers = headers_with("Authorization", "Digest abc");

        let result = test_middleware().validate_request(&headers).await;
        assert!(matches!(result, Err(AuthError::InvalidToken(_))));
    }

    #[tokio::test]
    async fn test_api_key_validation() {
        let headers = headers_with("X-API-Key", "test_api_key");

        let result = test_middleware().validate_request(&headers).await;
        assert!(matches!(
            result,
            Ok(ref ctx) if ctx.is_authenticated
                && ctx.user_id.as_deref() == Some("api_user_test_api")
        ));
    }

    #[tokio::test]
    async fn test_empty_api_key_is_rejected() {
        let headers = headers_with("X-API-Key", "");

        let result = test_middleware().validate_request(&headers).await;
        assert!(matches!(result, Err(AuthError::InvalidCredentials)));
    }

    #[tokio::test]
    async fn test_auth_disabled_accepts_everything() {
        let middleware = test_middleware()
            .with_required_scopes(vec!["admin".to_string()])
            .with_auth_enabled(false);

        let result = middleware.validate_request(&HashMap::new()).await;
        assert!(matches!(result, Ok(ref ctx) if ctx.is_authenticated));
    }

    #[tokio::test]
    async fn test_no_credentials_is_anonymous_without_required_scopes() {
        let result = test_middleware().validate_request(&HashMap::new()).await;
        assert!(matches!(
            result,
            Ok(ref ctx) if !ctx.is_authenticated && ctx.user_id.is_none()
        ));
    }

    #[tokio::test]
    async fn test_no_credentials_rejected_with_required_scopes() {
        let middleware = test_middleware().with_required_scopes(vec!["read".to_string()]);

        let result = middleware.validate_request(&HashMap::new()).await;
        assert!(matches!(result, Err(AuthError::InvalidCredentials)));
    }

    #[test]
    fn test_check_scopes_reports_every_missing_scope() {
        let middleware = test_middleware();
        let context = AuthContext::new().with_scopes(vec!["read".to_string()]);

        assert!(middleware.check_scopes(&context, &[]).is_ok());
        assert!(middleware
            .check_scopes(&context, &["read".to_string()])
            .is_ok());

        let result = middleware.check_scopes(
            &context,
            &["read".to_string(), "write".to_string(), "admin".to_string()],
        );
        assert!(matches!(
            result,
            Err(AuthError::MissingScope { ref scope, .. }) if scope == "write, admin"
        ));
    }

    #[tokio::test]
    async fn test_session_round_trip() {
        let middleware = test_middleware();
        let context = AuthContext::new()
            .with_user_id("alice".to_string())
            .authenticated();

        middleware.store_session("s1".to_string(), context).await;
        let stored = middleware.get_session("s1").await;
        assert_eq!(
            stored.and_then(|ctx| ctx.user_id).as_deref(),
            Some("alice")
        );

        middleware.remove_session("s1").await;
        assert!(middleware.get_session("s1").await.is_none());
    }

    #[test]
    fn test_auth_context_scopes() {
        let context = AuthContext::new().with_scopes(vec!["read".to_string(), "write".to_string()]);

        assert!(context.has_scope("read"));
        assert!(context.has_scope("write"));
        assert!(!context.has_scope("delete"));

        assert!(context.has_any_scope(&["read".to_string(), "delete".to_string()]));
        assert!(!context.has_any_scope(&["delete".to_string()]));
        assert!(context.has_all_scopes(&["read".to_string(), "write".to_string()]));
        assert!(!context.has_all_scopes(&["read".to_string(), "delete".to_string()]));
    }
}