Skip to main content

trojan_auth/
lib.rs

//! Authentication backends for trojan.
2
3use std::collections::HashSet;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use parking_lot::RwLock;
8use sha2::{Digest, Sha224};
9
/// Outcome of a credential check that succeeded.
#[derive(Clone, Debug)]
pub struct AuthResult {
    /// Identifier of the authenticated user, when the backend tracks one.
    /// The in-memory backend never does and always reports `None`.
    pub user_id: Option<String>,
}
14
15#[async_trait]
16pub trait AuthBackend: Send + Sync {
17    async fn verify(&self, hash: &str) -> Result<AuthResult, AuthError>;
18}
19
20#[derive(Debug, thiserror::Error)]
21pub enum AuthError {
22    #[error("invalid credential")]
23    Invalid,
24    #[error("backend error: {0}")]
25    Backend(String),
26}
27
28pub fn sha224_hex(input: &str) -> String {
29    let mut hasher = Sha224::new();
30    hasher.update(input.as_bytes());
31    let digest = hasher.finalize();
32    hex::encode(digest)
33}
34
/// In-memory backend holding a fixed set of accepted credential hashes.
///
/// `verify` accepts any caller that presents one of the stored strings
/// verbatim, so every stored value is effectively a usable credential.
/// `Debug` is therefore implemented by hand: it reports only how many
/// hashes are configured and never their contents, which keeps them out
/// of `{:?}` log output.
#[derive(Clone)]
pub struct MemoryAuth {
    hashes: HashSet<String>,
}

impl std::fmt::Debug for MemoryAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Expose the count only; the values themselves are secrets.
        f.debug_struct("MemoryAuth")
            .field("hash_count", &self.hashes.len())
            .finish_non_exhaustive()
    }
}
39
40impl MemoryAuth {
41    pub fn from_hashes<I, S>(hashes: I) -> Self
42    where
43        I: IntoIterator<Item = S>,
44        S: Into<String>,
45    {
46        let hashes = hashes.into_iter().map(Into::into).collect();
47        Self { hashes }
48    }
49
50    pub fn from_plain<I, S>(passwords: I) -> Self
51    where
52        I: IntoIterator<Item = S>,
53        S: AsRef<str>,
54    {
55        let hashes = passwords
56            .into_iter()
57            .map(|p| sha224_hex(p.as_ref()))
58            .collect();
59        Self { hashes }
60    }
61}
62
63#[async_trait]
64impl AuthBackend for MemoryAuth {
65    async fn verify(&self, hash: &str) -> Result<AuthResult, AuthError> {
66        if self.hashes.contains(hash) {
67            Ok(AuthResult { user_id: None })
68        } else {
69            Err(AuthError::Invalid)
70        }
71    }
72}
73
74/// Blanket implementation for Arc<A> where A: AuthBackend.
75/// This allows passing Arc<ReloadableAuth> directly to functions expecting impl AuthBackend.
76#[async_trait]
77impl<A: AuthBackend + ?Sized> AuthBackend for Arc<A> {
78    async fn verify(&self, hash: &str) -> Result<AuthResult, AuthError> {
79        (**self).verify(hash).await
80    }
81}
82
83/// A reloadable auth backend that wraps another backend and allows hot-swapping.
84/// Uses parking_lot::RwLock which doesn't poison on panic.
85pub struct ReloadableAuth {
86    inner: RwLock<Arc<dyn AuthBackend>>,
87}
88
89impl ReloadableAuth {
90    /// Create a new reloadable auth with the given initial backend.
91    pub fn new<A: AuthBackend + 'static>(auth: A) -> Self {
92        Self {
93            inner: RwLock::new(Arc::new(auth)),
94        }
95    }
96
97    /// Replace the auth backend with a new one.
98    pub fn reload<A: AuthBackend + 'static>(&self, auth: A) {
99        let mut inner = self.inner.write();
100        *inner = Arc::new(auth);
101    }
102}
103
104#[async_trait]
105impl AuthBackend for ReloadableAuth {
106    async fn verify(&self, hash: &str) -> Result<AuthResult, AuthError> {
107        // Clone the Arc so we don't hold the lock across await
108        let backend = {
109            let guard = self.inner.read();
110            guard.clone()
111        };
112        backend.verify(hash).await
113    }
114}