trz_gateway_server/
auth_code.rs

//! Ephemeral code used to authenticate certificate generation.
2
3use std::sync::Mutex;
4use std::time::Duration;
5
6use futures::FutureExt as _;
7use nameth::NamedEnumValues as _;
8use nameth::NamedType as _;
9use nameth::nameth;
10use tokio::sync::oneshot;
11use tracing::Instrument as _;
12use trz_gateway_common::declare_identifier;
13use uuid::Uuid;
14
/// How often the background updater rotates the codes.
///
/// Each tick demotes the current code to "previous" and mints a fresh one.
/// `AuthCode::is_valid` accepts both, so while updates are running a code handed
/// out by `AuthCode::current` stays valid for roughly one to two periods.
const AUTH_CODE_UPDATE_PERIOD: Duration = Duration::from_secs(60);
// Declares the `AuthCode` identifier type via `trz_gateway_common::declare_identifier!`.
declare_identifier!(AuthCode);
17
18impl AuthCode {
19    pub fn current() -> Self {
20        let mut lock = CURRENT_CODE.lock().unwrap();
21        if let Some(current_code) = &*lock {
22            return current_code.current.clone();
23        }
24
25        let (tx, rx) = oneshot::channel();
26        let current = AuthCode::new();
27        let current_code = CurrentCode {
28            periodic_updater: tx,
29            previous: current.clone(),
30            current: current.clone(),
31        };
32        *lock = Some(current_code);
33        drop(lock);
34        start_periodic_updates(rx);
35        return current;
36    }
37
38    pub fn is_valid(&self) -> bool {
39        let lock = CURRENT_CODE.lock().unwrap();
40        let Some(current_code) = &*lock else {
41            return false;
42        };
43        return *self == current_code.current || *self == current_code.previous;
44    }
45
46    pub fn stop_periodic_updates() -> Result<(), StopPeriodicUpdatesError> {
47        CURRENT_CODE
48            .lock()
49            .unwrap()
50            .take()
51            .ok_or(StopPeriodicUpdatesError::NotRunning)?
52            .periodic_updater
53            .send(())
54            .map_err(|()| StopPeriodicUpdatesError::SignalFailed)
55    }
56
57    fn new() -> Self {
58        Self::from(Uuid::new_v4().to_string())
59    }
60}
61
/// Error returned by [`AuthCode::stop_periodic_updates`].
#[nameth]
#[derive(thiserror::Error, Debug)]
pub enum StopPeriodicUpdatesError {
    /// No code was installed, so there was no updater to stop.
    #[error("[{n}] Periodic {t} updates are not scheduled", n = self.name(), t = AuthCode::type_name())]
    NotRunning,

    /// The stop signal could not be delivered because its receiving end was
    /// already dropped (presumably the updater task had already exited).
    #[error("[{n}] Failed to send signal to stop periodic {t} updates", n = self.name(), t = AuthCode::type_name())]
    SignalFailed,
}
71
/// Process-wide code state.
///
/// `None` until [`AuthCode::current`] first runs, and again after
/// [`AuthCode::stop_periodic_updates`].
static CURRENT_CODE: Mutex<Option<CurrentCode>> = Mutex::new(None);
73
/// The codes accepted by [`AuthCode::is_valid`], plus the handle that stops
/// the periodic updater.
struct CurrentCode {
    /// Sending `()` here tells the periodic updater task to exit.
    periodic_updater: oneshot::Sender<()>,
    /// The code that was current before the latest renewal. Right after
    /// installation it is the same code as `current`.
    previous: AuthCode,
    /// The code returned by [`AuthCode::current`].
    current: AuthCode,
}
79
80impl CurrentCode {
81    fn renew(&mut self) {
82        self.previous = std::mem::replace(&mut self.current, AuthCode::new())
83    }
84}
85
86fn start_periodic_updates(rx: oneshot::Receiver<()>) {
87    tokio::spawn(
88        async {
89            let rx = rx.shared();
90            loop {
91                tokio::select! {
92                    _ = tokio::time::sleep(AUTH_CODE_UPDATE_PERIOD) => {}
93                    _ = rx.clone() => { break; }
94                }
95
96                let mut lock = CURRENT_CODE.lock().unwrap();
97                let Some(current_code) = &mut *lock else {
98                    return;
99                };
100                current_code.renew();
101            }
102        }
103        .in_current_span(),
104    );
105}
106
#[cfg(test)]
mod tests {
    use tokio::sync::Mutex;

    use super::AuthCode;
    use super::StopPeriodicUpdatesError;

    /// Serializes these tests. They all touch the process-wide code state,
    /// and Rust runs tests in parallel by default.
    static LOCK: Mutex<()> = Mutex::const_new(());

    #[tokio::test]
    async fn current() {
        let _serialized = LOCK.lock().await;

        let first = AuthCode::current();
        assert!(first.is_valid());

        // Stopping discards the state; the next call starts from scratch.
        let () = AuthCode::stop_periodic_updates().unwrap();

        let second = AuthCode::current();
        assert!(!first.is_valid());
        assert!(second.is_valid());
        assert_ne!(first, second);

        // Leave no updater running for the other tests.
        let () = AuthCode::stop_periodic_updates().unwrap();
    }

    #[tokio::test]
    async fn not_running() {
        let _serialized = LOCK.lock().await;

        let error = AuthCode::stop_periodic_updates().unwrap_err();
        let StopPeriodicUpdatesError::NotRunning = &error else {
            panic!("unexpected error: {error}");
        };
        let expected = "[NotRunning] Periodic AuthCode updates are not scheduled";
        assert_eq!(expected, error.to_string());
    }
}
144}