Skip to main content

trojan_auth/
reloadable.rs

1//! Hot-reloadable authentication backend wrapper.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use parking_lot::RwLock;
7
8use crate::error::AuthError;
9use crate::result::AuthResult;
10use crate::traits::AuthBackend;
11
12/// A wrapper that allows hot-swapping the underlying auth backend.
13///
14/// This is useful for reloading configuration without restarting the server.
15/// Uses `parking_lot::RwLock` which doesn't poison on panic.
16///
17/// # Example
18/// ```
19/// use trojan_auth::{ReloadableAuth, MemoryAuth};
20///
21/// let auth = ReloadableAuth::new(MemoryAuth::from_passwords(["initial"]));
22///
23/// // Later, reload with new passwords
24/// auth.reload(MemoryAuth::from_passwords(["new_password"]));
25/// ```
26pub struct ReloadableAuth {
27    inner: RwLock<Arc<dyn AuthBackend>>,
28}
29
30impl ReloadableAuth {
31    /// Create a new reloadable auth with the given initial backend.
32    pub fn new<A: AuthBackend + 'static>(auth: A) -> Self {
33        Self {
34            inner: RwLock::new(Arc::new(auth)),
35        }
36    }
37
38    /// Replace the auth backend with a new one.
39    ///
40    /// This is an atomic operation - in-flight requests will complete
41    /// with the old backend, new requests will use the new backend.
42    pub fn reload<A: AuthBackend + 'static>(&self, auth: A) {
43        let mut inner = self.inner.write();
44        *inner = Arc::new(auth);
45    }
46
47    /// Replace the auth backend with a pre-wrapped Arc.
48    pub fn reload_arc(&self, auth: Arc<dyn AuthBackend>) {
49        let mut inner = self.inner.write();
50        *inner = auth;
51    }
52
53    /// Get a clone of the current backend Arc.
54    ///
55    /// This is useful for passing the backend to other components
56    /// without holding the lock.
57    #[inline]
58    pub fn get(&self) -> Arc<dyn AuthBackend> {
59        self.inner.read().clone()
60    }
61}
62
63// Cannot derive Debug due to dyn AuthBackend
64impl std::fmt::Debug for ReloadableAuth {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        f.debug_struct("ReloadableAuth").finish_non_exhaustive()
67    }
68}
69
70#[async_trait]
71impl AuthBackend for ReloadableAuth {
72    async fn verify(&self, hash: &str) -> Result<AuthResult, AuthError> {
73        // Clone the Arc so we don't hold the lock across await
74        let backend = self.get();
75        backend.verify(hash).await
76    }
77
78    async fn record_traffic(&self, user_id: &str, bytes: u64) -> Result<(), AuthError> {
79        let backend = self.get();
80        backend.record_traffic(user_id, bytes).await
81    }
82}
83
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::hash::sha224_hex;
    use crate::memory::MemoryAuth;

    /// `reload` swaps the backend: afterwards only the new password verifies.
    #[tokio::test]
    async fn test_reload() {
        let auth = ReloadableAuth::new(MemoryAuth::from_passwords(["old_password"]));

        let old_hash = sha224_hex("old_password");
        let new_hash = sha224_hex("new_password");

        // Old password works
        assert!(auth.verify(&old_hash).await.is_ok());
        assert!(auth.verify(&new_hash).await.is_err());

        // Reload with new passwords
        auth.reload(MemoryAuth::from_passwords(["new_password"]));

        // Now new password works, old doesn't
        assert!(auth.verify(&old_hash).await.is_err());
        assert!(auth.verify(&new_hash).await.is_ok());
    }

    /// `reload_arc` behaves like `reload` for a pre-wrapped backend.
    #[tokio::test]
    async fn test_reload_arc() {
        let auth = ReloadableAuth::new(MemoryAuth::from_passwords(["old_password"]));

        let replacement: Arc<dyn AuthBackend> =
            Arc::new(MemoryAuth::from_passwords(["new_password"]));
        auth.reload_arc(replacement);

        assert!(auth.verify(&sha224_hex("old_password")).await.is_err());
        assert!(auth.verify(&sha224_hex("new_password")).await.is_ok());
    }

    /// A backend obtained via `get` before a reload keeps the old credentials,
    /// which is the "in-flight requests complete with the old backend" guarantee.
    #[tokio::test]
    async fn test_snapshot_survives_reload() {
        let auth = ReloadableAuth::new(MemoryAuth::from_passwords(["old_password"]));
        let snapshot = auth.get();

        auth.reload(MemoryAuth::from_passwords(["new_password"]));

        assert!(snapshot.verify(&sha224_hex("old_password")).await.is_ok());
        assert!(auth.verify(&sha224_hex("old_password")).await.is_err());
    }

    /// The manual `Debug` impl does not expose the wrapped backend.
    #[test]
    fn test_debug_is_opaque() {
        let auth = ReloadableAuth::new(MemoryAuth::from_passwords(["password"]));
        assert_eq!(format!("{:?}", auth), "ReloadableAuth { .. }");
    }
}