1use crate::{
4 config::Config,
5 error::{Result, ZeroTrustError},
6 types::{AuthResponse, User},
7};
8use reqwest::Client;
9use serde_json::json;
10use std::sync::Arc;
11
12#[derive(Debug, Clone)]
14pub struct AuthManager {
15 config: Arc<Config>,
16 http_client: Arc<Client>,
17}
18
19impl AuthManager {
20 pub(crate) fn new(config: Arc<Config>, http_client: Arc<Client>) -> Self {
22 Self {
23 config,
24 http_client,
25 }
26 }
27
28 pub async fn login<S1, S2>(&self, email: S1, password: S2) -> Result<AuthResponse>
54 where
55 S1: AsRef<str>,
56 S2: AsRef<str>,
57 {
58 let url = format!("{}/api/v1/auth/login", self.config.api_url);
59
60 let payload = json!({
61 "email": email.as_ref(),
62 "password": password.as_ref()
63 });
64
65 let response = self
66 .http_client
67 .post(&url)
68 .header("Content-Type", "application/json")
69 .json(&payload)
70 .send()
71 .await?;
72
73 if response.status().is_success() {
74 let auth_response: AuthResponse = response.json().await?;
75 Ok(auth_response)
76 } else {
77 let status = response.status();
78 let error_text = response.text().await.unwrap_or_default();
79 match status.as_u16() {
80 401 => Err(ZeroTrustError::auth("Invalid email or password")),
81 400 => Err(ZeroTrustError::validation(error_text)),
82 _ => Err(ZeroTrustError::generic(format!(
83 "Login failed: {}",
84 error_text
85 ))),
86 }
87 }
88 }
89
90 pub async fn register<S1, S2, S3>(
117 &self,
118 email: S1,
119 password: S2,
120 role: Option<S3>,
121 ) -> Result<AuthResponse>
122 where
123 S1: AsRef<str>,
124 S2: AsRef<str>,
125 S3: AsRef<str>,
126 {
127 let url = format!("{}/api/v1/auth/register", self.config.api_url);
128
129 let mut payload = json!({
130 "email": email.as_ref(),
131 "password": password.as_ref()
132 });
133
134 if let Some(role) = role {
135 payload["role"] = json!(role.as_ref());
136 }
137
138 let response = self
139 .http_client
140 .post(&url)
141 .header("Content-Type", "application/json")
142 .json(&payload)
143 .send()
144 .await?;
145
146 if response.status().is_success() {
147 let auth_response: AuthResponse = response.json().await?;
148 Ok(auth_response)
149 } else {
150 let status = response.status();
151 let error_text = response.text().await.unwrap_or_default();
152 match status.as_u16() {
153 409 => Err(ZeroTrustError::validation("Email already exists")),
154 400 => Err(ZeroTrustError::validation(error_text)),
155 _ => Err(ZeroTrustError::generic(format!(
156 "Registration failed: {}",
157 error_text
158 ))),
159 }
160 }
161 }
162
163 pub async fn logout(&self) -> Result<()> {
183 if !self.config.is_authenticated() {
184 return Err(ZeroTrustError::auth("Not currently logged in"));
185 }
186
187 let url = format!("{}/api/v1/auth/logout", self.config.api_url);
188
189 let response = self
190 .http_client
191 .post(&url)
192 .header("Content-Type", "application/json")
193 .header(
194 "Authorization",
195 format!("Bearer {}", self.config.token.as_ref().unwrap()),
196 )
197 .send()
198 .await?;
199
200 if response.status().is_success() {
201 Ok(())
202 } else {
203 let error_text = response.text().await.unwrap_or_default();
204 Err(ZeroTrustError::generic(format!(
205 "Logout failed: {}",
206 error_text
207 )))
208 }
209 }
210
211 pub async fn me(&self) -> Result<User> {
231 if !self.config.is_authenticated() {
232 return Err(ZeroTrustError::auth("Authentication required"));
233 }
234
235 let url = format!("{}/api/v1/auth/me", self.config.api_url);
236
237 let response = self
238 .http_client
239 .get(&url)
240 .header(
241 "Authorization",
242 format!("Bearer {}", self.config.token.as_ref().unwrap()),
243 )
244 .send()
245 .await?;
246
247 if response.status().is_success() {
248 let user: User = response.json().await?;
249 Ok(user)
250 } else {
251 let status = response.status();
252 let error_text = response.text().await.unwrap_or_default();
253 match status.as_u16() {
254 401 => Err(ZeroTrustError::auth("Authentication failed")),
255 _ => Err(ZeroTrustError::generic(format!(
256 "Failed to get user info: {}",
257 error_text
258 ))),
259 }
260 }
261 }
262
263 pub async fn refresh_token(&self) -> Result<String> {
283 if !self.config.is_authenticated() {
284 return Err(ZeroTrustError::auth("Authentication required"));
285 }
286
287 let url = format!("{}/api/v1/auth/refresh", self.config.api_url);
288
289 let response = self
290 .http_client
291 .post(&url)
292 .header("Content-Type", "application/json")
293 .header(
294 "Authorization",
295 format!("Bearer {}", self.config.token.as_ref().unwrap()),
296 )
297 .send()
298 .await?;
299
300 if response.status().is_success() {
301 let response_data: serde_json::Value = response.json().await?;
302 let new_token = response_data["token"]
303 .as_str()
304 .ok_or_else(|| ZeroTrustError::generic("Invalid token response"))?;
305 Ok(new_token.to_string())
306 } else {
307 let status = response.status();
308 let error_text = response.text().await.unwrap_or_default();
309 match status.as_u16() {
310 401 => Err(ZeroTrustError::auth("Token refresh failed")),
311 _ => Err(ZeroTrustError::generic(format!(
312 "Token refresh failed: {}",
313 error_text
314 ))),
315 }
316 }
317 }
318
319 pub async fn change_password<S1, S2>(
346 &self,
347 current_password: S1,
348 new_password: S2,
349 ) -> Result<()>
350 where
351 S1: AsRef<str>,
352 S2: AsRef<str>,
353 {
354 if !self.config.is_authenticated() {
355 return Err(ZeroTrustError::auth("Authentication required"));
356 }
357
358 let url = format!("{}/api/v1/auth/change-password", self.config.api_url);
359
360 let payload = json!({
361 "current_password": current_password.as_ref(),
362 "new_password": new_password.as_ref()
363 });
364
365 let response = self
366 .http_client
367 .post(&url)
368 .header("Content-Type", "application/json")
369 .header(
370 "Authorization",
371 format!("Bearer {}", self.config.token.as_ref().unwrap()),
372 )
373 .json(&payload)
374 .send()
375 .await?;
376
377 if response.status().is_success() {
378 Ok(())
379 } else {
380 let status = response.status();
381 let error_text = response.text().await.unwrap_or_default();
382 match status.as_u16() {
383 401 => Err(ZeroTrustError::auth("Current password is incorrect")),
384 400 => Err(ZeroTrustError::validation(error_text)),
385 _ => Err(ZeroTrustError::generic(format!(
386 "Password change failed: {}",
387 error_text
388 ))),
389 }
390 }
391 }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use mockito::Server;

    /// Starts a mock HTTP server and builds an `AuthManager` pointed at it.
    ///
    /// The server guard is returned so the server stays alive for the whole test.
    async fn create_test_auth_manager() -> (AuthManager, mockito::ServerGuard) {
        let server = Server::new_async().await;
        let url = server.url();

        let config = Config::new(&url).unwrap();
        let http_client = reqwest::Client::new();

        let auth_manager = AuthManager::new(Arc::new(config), Arc::new(http_client));

        (auth_manager, server)
    }

    #[tokio::test]
    async fn test_login_success() {
        let (auth_manager, mut server) = create_test_auth_manager().await;

        let mock = server
            .mock("POST", "/api/v1/auth/login")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{
                "token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9...",
                "user": {
                    "id": "123",
                    "email": "test@example.com",
                    "role": "user",
                    "created_at": "2023-01-01T00:00:00Z",
                    "wallet_address": null
                },
                "expires_at": "2023-01-01T01:00:00Z"
            }"#)
            .create_async()
            .await;

        let result = auth_manager.login("test@example.com", "password").await;
        assert!(result.is_ok());

        let auth_response = result.unwrap();
        assert_eq!(auth_response.user.email, "test@example.com");
        assert!(!auth_response.token.is_empty());

        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_login_failure() {
        let (auth_manager, mut server) = create_test_auth_manager().await;

        let mock = server
            .mock("POST", "/api/v1/auth/login")
            .with_status(401)
            .with_header("content-type", "application/json")
            .with_body(r#"{"error": "Invalid credentials"}"#)
            .create_async()
            .await;

        let result = auth_manager.login("test@example.com", "wrongpassword").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().is_auth_error());

        mock.assert_async().await;
    }

    /// A server-side failure (5xx) must not be reported as an authentication problem.
    #[tokio::test]
    async fn test_login_server_error_is_not_auth_error() {
        let (auth_manager, mut server) = create_test_auth_manager().await;

        let mock = server
            .mock("POST", "/api/v1/auth/login")
            .with_status(500)
            .with_body("internal error")
            .create_async()
            .await;

        let err = auth_manager
            .login("test@example.com", "password")
            .await
            .unwrap_err();
        assert!(!err.is_auth_error());

        mock.assert_async().await;
    }
}