trz_gateway_server/
auth_code.rs1use std::sync::Mutex;
4use std::time::Duration;
5
6use futures::FutureExt as _;
7use nameth::NamedEnumValues as _;
8use nameth::NamedType as _;
9use nameth::nameth;
10use tokio::sync::oneshot;
11use tracing::Instrument as _;
12use trz_gateway_common::declare_identifier;
13use uuid::Uuid;
14
/// Interval between two rotations of the auth code by the background updater task.
const AUTH_CODE_UPDATE_PERIOD: Duration = Duration::new(60, 0);
// `AuthCode` is generated by `declare_identifier!` from `trz_gateway_common`.
// The code in this file relies on it supporting `Clone`, `PartialEq`/`==` and
// construction via `AuthCode::from(String)`; presumably it is a string-backed
// newtype — NOTE(review): confirm against the macro definition.
declare_identifier!(AuthCode);
17
18impl AuthCode {
19 pub fn current() -> Self {
20 let mut lock = CURRENT_CODE.lock().unwrap();
21 if let Some(current_code) = &*lock {
22 return current_code.current.clone();
23 }
24
25 let (tx, rx) = oneshot::channel();
26 let current = AuthCode::new();
27 let current_code = CurrentCode {
28 periodic_updater: tx,
29 previous: current.clone(),
30 current: current.clone(),
31 };
32 *lock = Some(current_code);
33 drop(lock);
34 start_periodic_updates(rx);
35 return current;
36 }
37
38 pub fn is_valid(&self) -> bool {
39 let lock = CURRENT_CODE.lock().unwrap();
40 let Some(current_code) = &*lock else {
41 return false;
42 };
43 return *self == current_code.current || *self == current_code.previous;
44 }
45
46 pub fn stop_periodic_updates() -> Result<(), StopPeriodicUpdatesError> {
47 CURRENT_CODE
48 .lock()
49 .unwrap()
50 .take()
51 .ok_or(StopPeriodicUpdatesError::NotRunning)?
52 .periodic_updater
53 .send(())
54 .map_err(|()| StopPeriodicUpdatesError::SignalFailed)
55 }
56
57 fn new() -> Self {
58 Self::from(Uuid::new_v4().to_string())
59 }
60}
61
62#[nameth]
63#[derive(thiserror::Error, Debug)]
64pub enum StopPeriodicUpdatesError {
65 #[error("[{n}] Periodic {t} updates are not scheduled", n = self.name(), t = AuthCode::type_name())]
66 NotRunning,
67
68 #[error("[{n}] Failed to send signal to stop periodic {t} updates", n = self.name(), t = AuthCode::type_name())]
69 SignalFailed,
70}
71
/// Global auth-code state.
///
/// `None` until the first `AuthCode::current()` call initializes it, and reset to `None`
/// by `AuthCode::stop_periodic_updates()`.
static CURRENT_CODE: Mutex<Option<CurrentCode>> = Mutex::new(None);
73
/// Auth-code state shared by the `AuthCode` accessors and the periodic updater task.
struct CurrentCode {
    /// Sending `()` on this channel tells the updater task spawned by
    /// `start_periodic_updates` to stop.
    periodic_updater: oneshot::Sender<()>,
    /// The code that was current before the last rotation; still accepted by `is_valid`.
    previous: AuthCode,
    /// The code returned by `AuthCode::current()`.
    current: AuthCode,
}
79
80impl CurrentCode {
81 fn renew(&mut self) {
82 self.previous = std::mem::replace(&mut self.current, AuthCode::new())
83 }
84}
85
86fn start_periodic_updates(rx: oneshot::Receiver<()>) {
87 tokio::spawn(
88 async {
89 let rx = rx.shared();
90 loop {
91 tokio::select! {
92 _ = tokio::time::sleep(AUTH_CODE_UPDATE_PERIOD) => {}
93 _ = rx.clone() => { break; }
94 }
95
96 let mut lock = CURRENT_CODE.lock().unwrap();
97 let Some(current_code) = &mut *lock else {
98 return;
99 };
100 current_code.renew();
101 }
102 }
103 .in_current_span(),
104 );
105}
106
#[cfg(test)]
mod tests {
    use super::AuthCode;
    use super::StopPeriodicUpdatesError;

    /// Serializes the tests below: they all read and mutate the same global auth-code state.
    static LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());

    #[tokio::test]
    async fn current() {
        let _guard = LOCK.lock().await;

        // The first call starts the rotation and hands out a valid code.
        let first = AuthCode::current();
        assert!(first.is_valid());

        AuthCode::stop_periodic_updates().unwrap();

        // After stopping, the next call starts over with a brand-new code and the old one
        // is no longer accepted.
        let second = AuthCode::current();
        assert!(!first.is_valid());
        assert!(second.is_valid());
        assert_ne!(first, second);

        AuthCode::stop_periodic_updates().unwrap();
    }

    #[tokio::test]
    async fn not_running() {
        let _guard = LOCK.lock().await;

        let err = AuthCode::stop_periodic_updates().unwrap_err();
        assert!(matches!(err, StopPeriodicUpdatesError::NotRunning));
        assert_eq!(
            "[NotRunning] Periodic AuthCode updates are not scheduled",
            err.to_string()
        );
    }
}