1#![allow(clippy::cast_possible_truncation)]
7
8use ring::hmac;
9use ring::rand::{SecureRandom, SystemRandom};
10use std::collections::HashSet;
11use std::time::{SystemTime, UNIX_EPOCH};
12use zeroize::{Zeroize, ZeroizeOnDrop};
13
14use crate::error::{Result, ShieldError};
15
/// Length in bytes of secrets produced by `TOTP::generate_secret`.
///
/// 20 bytes (160 bits) matches the output size of the HMAC-SHA1 primitive this
/// module uses for code generation.
const DEFAULT_SECRET_LEN: usize = 20;
18
/// Time-based one-time password generator built on HMAC-SHA1.
///
/// The shared secret is wiped when the value is dropped (`ZeroizeOnDrop`);
/// `digits` and `interval` are not secret and are excluded from zeroization.
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct TOTP {
    /// Raw shared secret, used directly as the HMAC key.
    secret: Vec<u8>,
    /// Number of decimal digits in each generated code.
    #[zeroize(skip)]
    digits: usize,
    /// Length of one time step in seconds (Unix time is divided by this value).
    #[zeroize(skip)]
    interval: u64,
}
30
31impl TOTP {
32 #[must_use]
34 pub fn new(secret: Vec<u8>, digits: usize, interval: u64) -> Self {
35 Self {
36 secret,
37 digits: if digits == 0 { 6 } else { digits },
38 interval: if interval == 0 { 30 } else { interval },
39 }
40 }
41
42 #[must_use]
44 pub fn with_secret(secret: Vec<u8>) -> Self {
45 Self::new(secret, 6, 30)
46 }
47
48 pub fn generate_secret() -> Result<Vec<u8>> {
50 let rng = SystemRandom::new();
51 let mut secret = vec![0u8; DEFAULT_SECRET_LEN];
52 rng.fill(&mut secret)
53 .map_err(|_| ShieldError::RandomFailed)?;
54 Ok(secret)
55 }
56
57 #[must_use]
59 pub fn generate(&self, timestamp: Option<u64>) -> String {
60 let time = timestamp.unwrap_or_else(|| {
61 SystemTime::now()
62 .duration_since(UNIX_EPOCH)
63 .unwrap()
64 .as_secs()
65 });
66
67 let counter = time / self.interval;
68 self.generate_hotp(counter)
69 }
70
71 fn generate_hotp(&self, counter: u64) -> String {
73 let counter_bytes = counter.to_be_bytes();
74 let key = hmac::Key::new(hmac::HMAC_SHA1_FOR_LEGACY_USE_ONLY, &self.secret);
75 let tag = hmac::sign(&key, &counter_bytes);
76 let hash = tag.as_ref();
77
78 let offset = (hash[19] & 0xf) as usize;
80 let code = u32::from_be_bytes([
81 hash[offset] & 0x7f,
82 hash[offset + 1],
83 hash[offset + 2],
84 hash[offset + 3],
85 ]);
86
87 let modulo = 10u32.pow(self.digits as u32);
88 format!("{:0width$}", code % modulo, width = self.digits)
89 }
90
91 #[must_use]
93 pub fn verify(&self, code: &str, timestamp: Option<u64>, window: u32) -> bool {
94 let time = timestamp.unwrap_or_else(|| {
95 SystemTime::now()
96 .duration_since(UNIX_EPOCH)
97 .unwrap()
98 .as_secs()
99 });
100
101 let window = if window == 0 { 1 } else { window };
102
103 for i in 0..=window {
104 let t = time.saturating_sub(u64::from(i) * self.interval);
105 if self.generate(Some(t)) == code {
106 return true;
107 }
108 if i > 0 {
109 let t = time + u64::from(i) * self.interval;
110 if self.generate(Some(t)) == code {
111 return true;
112 }
113 }
114 }
115 false
116 }
117
118 #[must_use]
120 pub fn secret_to_base32(secret: &[u8]) -> String {
121 const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
122 let mut result = String::new();
123 let mut buffer = 0u64;
124 let mut bits = 0;
125
126 for &byte in secret {
127 buffer = (buffer << 8) | u64::from(byte);
128 bits += 8;
129 while bits >= 5 {
130 bits -= 5;
131 result.push(ALPHABET[((buffer >> bits) & 0x1f) as usize] as char);
132 }
133 }
134
135 if bits > 0 {
136 result.push(ALPHABET[((buffer << (5 - bits)) & 0x1f) as usize] as char);
137 }
138
139 result
140 }
141
142 pub fn secret_from_base32(encoded: &str) -> Result<Vec<u8>> {
144 const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
145 let mut result = Vec::new();
146 let mut buffer = 0u64;
147 let mut bits = 0;
148
149 for c in encoded.chars() {
150 let c = c.to_ascii_uppercase();
151 if c == '=' {
152 continue;
153 }
154 let val = ALPHABET
155 .iter()
156 .position(|&b| b == c as u8)
157 .ok_or(ShieldError::InvalidFormat)?;
158
159 buffer = (buffer << 5) | (val as u64);
160 bits += 5;
161
162 if bits >= 8 {
163 bits -= 8;
164 result.push((buffer >> bits) as u8);
165 }
166 }
167
168 Ok(result)
169 }
170
171 #[must_use]
173 pub fn provisioning_uri(&self, account: &str, issuer: &str) -> String {
174 let secret_b32 = Self::secret_to_base32(&self.secret);
175 format!(
176 "otpauth://totp/{}:{}?secret={}&issuer={}&algorithm=SHA1&digits={}&period={}",
177 issuer, account, secret_b32, issuer, self.digits, self.interval
178 )
179 }
180
181 #[must_use]
183 pub fn secret(&self) -> &[u8] {
184 &self.secret
185 }
186}
187
/// A set of single-use recovery codes, each formatted as `XXXX-XXXX`
/// (two groups of four uppercase hex digits).
pub struct RecoveryCodes {
    /// Codes that have not been redeemed yet; a successful `verify` removes one.
    codes: HashSet<String>,
    /// Number of codes the set held when it was created.
    original_count: usize,
}
193
194impl RecoveryCodes {
195 pub fn new(count: usize) -> Result<Self> {
197 let codes = Self::generate_codes(count)?;
198 let original_count = codes.len();
199 Ok(Self {
200 codes: codes.into_iter().collect(),
201 original_count,
202 })
203 }
204
205 pub fn generate_codes(count: usize) -> Result<Vec<String>> {
207 let rng = SystemRandom::new();
208 let mut codes = Vec::with_capacity(count);
209
210 for _ in 0..count {
211 let mut bytes = [0u8; 4];
212 rng.fill(&mut bytes)
213 .map_err(|_| ShieldError::RandomFailed)?;
214 let code = format!(
215 "{:04X}-{:04X}",
216 u16::from_be_bytes([bytes[0], bytes[1]]),
217 u16::from_be_bytes([bytes[2], bytes[3]])
218 );
219 codes.push(code);
220 }
221
222 Ok(codes)
223 }
224
225 pub fn verify(&mut self, code: &str) -> bool {
227 let normalized = code.to_uppercase().replace([' ', '-'], "");
228 let formatted = if normalized.len() == 8 {
229 format!("{}-{}", &normalized[0..4], &normalized[4..8])
230 } else {
231 code.to_uppercase()
232 };
233
234 self.codes.remove(&formatted)
235 }
236
237 #[must_use]
239 pub fn remaining(&self) -> usize {
240 self.codes.len()
241 }
242
243 #[must_use]
245 pub fn original_count(&self) -> usize {
246 self.original_count
247 }
248
249 #[must_use]
251 pub fn used_count(&self) -> usize {
252 self.original_count - self.codes.len()
253 }
254
255 #[must_use]
257 pub fn codes(&self) -> Vec<String> {
258 self.codes.iter().cloned().collect()
259 }
260
261 #[must_use]
263 pub fn has_codes(&self) -> bool {
264 !self.codes.is_empty()
265 }
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared secret used by the RFC 4226 / RFC 6238 reference vectors.
    const RFC_SECRET: &[u8] = b"12345678901234567890";

    #[test]
    fn test_totp_generate_verify() {
        let secret = TOTP::generate_secret().unwrap();
        let totp = TOTP::with_secret(secret);
        let code = totp.generate(None);
        assert!(totp.verify(&code, None, 1));
    }

    #[test]
    fn test_totp_known_vector() {
        let totp = TOTP::new(RFC_SECRET.to_vec(), 8, 30);
        assert_eq!(totp.generate(Some(59)), "94287082");
    }

    #[test]
    fn test_totp_rfc6238_sha1_vectors() {
        // RFC 6238 Appendix B, SHA-1 column.
        let totp = TOTP::new(RFC_SECRET.to_vec(), 8, 30);
        let vectors: [(u64, &str); 6] = [
            (59, "94287082"),
            (1_111_111_109, "07081804"),
            (1_111_111_111, "14050471"),
            (1_234_567_890, "89005924"),
            (2_000_000_000, "69279037"),
            (20_000_000_000, "65353130"),
        ];
        for &(time, expected) in &vectors {
            assert_eq!(totp.generate(Some(time)), expected, "time = {}", time);
        }
    }

    #[test]
    fn test_hotp_rfc4226_vectors() {
        // RFC 4226 Appendix D; with a 30 s step, timestamp 30 * n selects counter n.
        let totp = TOTP::new(RFC_SECRET.to_vec(), 6, 30);
        let expected = [
            "755224", "287082", "359152", "969429", "338314", "254676", "287922", "162583",
            "399871", "520489",
        ];
        for (counter, code) in expected.iter().enumerate() {
            assert_eq!(totp.generate(Some(counter as u64 * 30)), *code);
        }
    }

    #[test]
    fn test_totp_verify_window() {
        let totp = TOTP::new(RFC_SECRET.to_vec(), 8, 30);
        // "94287082" is the code for step 1 (t = 30..=59).
        assert!(totp.verify("94287082", Some(59), 0));
        assert!(totp.verify("94287082", Some(89), 1));
        assert!(totp.verify("94287082", Some(0), 1));
        assert!(!totp.verify("94287082", Some(120), 1));
        assert!(!totp.verify("9428708", Some(59), 1));
    }

    #[test]
    fn test_base32_roundtrip() {
        let secret = TOTP::generate_secret().unwrap();
        let encoded = TOTP::secret_to_base32(&secret);
        let decoded = TOTP::secret_from_base32(&encoded).unwrap();
        assert_eq!(secret, decoded);
    }

    #[test]
    fn test_base32_known_vectors() {
        // RFC 4648 section 10 vectors, emitted without `=` padding.
        assert_eq!(TOTP::secret_to_base32(b""), "");
        assert_eq!(TOTP::secret_to_base32(b"f"), "MY");
        assert_eq!(TOTP::secret_to_base32(b"foobar"), "MZXW6YTBOI");
        assert_eq!(TOTP::secret_from_base32("MZXW6YTBOI======").unwrap(), b"foobar");
        assert_eq!(TOTP::secret_from_base32("mzxw6ytboi").unwrap(), b"foobar");
        assert!(TOTP::secret_from_base32("MZXW1").is_err());
    }

    #[test]
    fn test_recovery_codes() {
        let mut rc = RecoveryCodes::new(10).unwrap();
        assert_eq!(rc.remaining(), 10);

        let codes = rc.codes();
        assert!(rc.verify(&codes[0]));
        assert_eq!(rc.remaining(), 9);

        assert!(!rc.verify(&codes[0]));
    }

    #[test]
    fn test_recovery_codes_normalization() {
        let mut rc = RecoveryCodes::new(3).unwrap();
        let code = rc.codes()[0].clone();
        let sloppy = code.to_lowercase().replace('-', " ");
        assert!(rc.verify(&sloppy));
        assert!(!rc.verify(&code));
        assert_eq!(rc.used_count(), 1);
        assert_eq!(rc.original_count(), 3);
        assert!(rc.has_codes());
    }

    #[test]
    fn test_provisioning_uri() {
        let totp = TOTP::with_secret(vec![1, 2, 3, 4, 5]);
        let uri = totp.provisioning_uri("user@example.com", "TestApp");
        assert!(uri.starts_with("otpauth://totp/"));
        assert!(uri.contains("TestApp"));
        assert!(uri.contains("secret=AEBAGBAF"));
        assert!(uri.ends_with("&algorithm=SHA1&digits=6&period=30"));
    }
}