1#![allow(clippy::cast_possible_truncation)]
7
8use ring::hmac;
9use std::collections::HashSet;
10use std::time::{SystemTime, UNIX_EPOCH};
11use subtle::ConstantTimeEq;
12use zeroize::{Zeroize, ZeroizeOnDrop};
13
14use crate::error::{Result, ShieldError};
15
/// Length in bytes of the secrets produced by `TOTP::generate_secret`:
/// 160 bits, i.e. the size of an HMAC-SHA1 output.
const DEFAULT_SECRET_LEN: usize = 160 / 8;
18
19#[derive(Zeroize, ZeroizeOnDrop)]
23pub struct TOTP {
24 secret: Vec<u8>,
25 #[zeroize(skip)]
26 digits: usize,
27 #[zeroize(skip)]
28 interval: u64,
29}
30
31impl TOTP {
32 #[must_use]
34 pub fn new(secret: Vec<u8>, digits: usize, interval: u64) -> Self {
35 Self {
36 secret,
37 digits: digits.clamp(6, 9),
38 interval: if interval == 0 { 30 } else { interval },
39 }
40 }
41
42 #[must_use]
44 pub fn with_secret(secret: Vec<u8>) -> Self {
45 Self::new(secret, 6, 30)
46 }
47
48 pub fn generate_secret() -> Result<Vec<u8>> {
50 crate::random::random_vec(DEFAULT_SECRET_LEN)
51 }
52
53 #[must_use]
55 pub fn generate(&self, timestamp: Option<u64>) -> String {
56 let time = timestamp.unwrap_or_else(|| {
57 SystemTime::now()
58 .duration_since(UNIX_EPOCH)
59 .map_or(0, |d| d.as_secs())
60 });
61
62 let counter = time / self.interval;
63 self.generate_hotp(counter)
64 }
65
66 fn generate_hotp(&self, counter: u64) -> String {
68 let counter_bytes = counter.to_be_bytes();
69 let key = hmac::Key::new(hmac::HMAC_SHA1_FOR_LEGACY_USE_ONLY, &self.secret);
70 let tag = hmac::sign(&key, &counter_bytes);
71 let hash = tag.as_ref();
72
73 let offset = (hash[19] & 0xf) as usize;
75 let code = u32::from_be_bytes([
76 hash[offset] & 0x7f,
77 hash[offset + 1],
78 hash[offset + 2],
79 hash[offset + 3],
80 ]);
81
82 let modulo = 10u32.pow(self.digits as u32);
83 format!("{:0width$}", code % modulo, width = self.digits)
84 }
85
86 #[must_use]
88 pub fn verify(&self, code: &str, timestamp: Option<u64>, window: u32) -> bool {
89 let time = timestamp.unwrap_or_else(|| {
90 SystemTime::now()
91 .duration_since(UNIX_EPOCH)
92 .map_or(0, |d| d.as_secs())
93 });
94
95 for i in 0..=window {
96 let t = time.saturating_sub(u64::from(i) * self.interval);
97 let expected = self.generate(Some(t));
98 if expected.as_bytes().ct_eq(code.as_bytes()).unwrap_u8() == 1 {
99 return true;
100 }
101 if i > 0 {
102 let t = time + u64::from(i) * self.interval;
103 let expected = self.generate(Some(t));
104 if expected.as_bytes().ct_eq(code.as_bytes()).unwrap_u8() == 1 {
105 return true;
106 }
107 }
108 }
109 false
110 }
111
112 #[must_use]
114 pub fn secret_to_base32(secret: &[u8]) -> String {
115 const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
116 let mut result = String::new();
117 let mut buffer = 0u64;
118 let mut bits = 0;
119
120 for &byte in secret {
121 buffer = (buffer << 8) | u64::from(byte);
122 bits += 8;
123 while bits >= 5 {
124 bits -= 5;
125 result.push(ALPHABET[((buffer >> bits) & 0x1f) as usize] as char);
126 }
127 }
128
129 if bits > 0 {
130 result.push(ALPHABET[((buffer << (5 - bits)) & 0x1f) as usize] as char);
131 }
132
133 result
134 }
135
136 pub fn secret_from_base32(encoded: &str) -> Result<Vec<u8>> {
138 const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
139 let mut result = Vec::new();
140 let mut buffer = 0u64;
141 let mut bits = 0;
142
143 for c in encoded.chars() {
144 let c = c.to_ascii_uppercase();
145 if c == '=' {
146 continue;
147 }
148 let val = ALPHABET
149 .iter()
150 .position(|&b| b == c as u8)
151 .ok_or(ShieldError::InvalidFormat)?;
152
153 buffer = (buffer << 5) | (val as u64);
154 bits += 5;
155
156 if bits >= 8 {
157 bits -= 8;
158 result.push((buffer >> bits) as u8);
159 }
160 }
161
162 Ok(result)
163 }
164
165 #[must_use]
167 pub fn provisioning_uri(&self, account: &str, issuer: &str) -> String {
168 let secret_b32 = Self::secret_to_base32(&self.secret);
169 format!(
170 "otpauth://totp/{}:{}?secret={}&issuer={}&algorithm=SHA1&digits={}&period={}",
171 issuer, account, secret_b32, issuer, self.digits, self.interval
172 )
173 }
174
175 #[must_use]
177 pub fn secret(&self) -> &[u8] {
178 &self.secret
179 }
180}
181
/// A set of single-use recovery (backup) codes.
///
/// Codes are removed from the set once they have been redeemed.
pub struct RecoveryCodes {
    /// Codes not yet redeemed, stored in `XXXX-XXXX-XXXX-XXXX` form.
    codes: HashSet<String>,

    /// Number of codes the set started with.
    original_count: usize,
}
187
188impl RecoveryCodes {
189 pub fn new(count: usize) -> Result<Self> {
191 let codes = Self::generate_codes(count)?;
192 let original_count = codes.len();
193 Ok(Self {
194 codes: codes.into_iter().collect(),
195 original_count,
196 })
197 }
198
199 pub fn generate_codes(count: usize) -> Result<Vec<String>> {
201 let mut codes = Vec::with_capacity(count);
202
203 for _ in 0..count {
204 let bytes: [u8; 8] = crate::random::random_bytes()?;
205 let code = format!(
206 "{:04X}-{:04X}-{:04X}-{:04X}",
207 u16::from_be_bytes([bytes[0], bytes[1]]),
208 u16::from_be_bytes([bytes[2], bytes[3]]),
209 u16::from_be_bytes([bytes[4], bytes[5]]),
210 u16::from_be_bytes([bytes[6], bytes[7]])
211 );
212 codes.push(code);
213 }
214
215 Ok(codes)
216 }
217
218 pub fn verify(&mut self, code: &str) -> bool {
220 let normalized = code.to_uppercase().replace([' ', '-'], "");
221 let formatted = if normalized.len() == 16 {
222 format!(
223 "{}-{}-{}-{}",
224 &normalized[0..4],
225 &normalized[4..8],
226 &normalized[8..12],
227 &normalized[12..16]
228 )
229 } else {
230 code.to_uppercase()
231 };
232
233 self.codes.remove(&formatted)
234 }
235
236 #[must_use]
238 pub fn remaining(&self) -> usize {
239 self.codes.len()
240 }
241
242 #[must_use]
244 pub fn original_count(&self) -> usize {
245 self.original_count
246 }
247
248 #[must_use]
250 pub fn used_count(&self) -> usize {
251 self.original_count - self.codes.len()
252 }
253
254 #[must_use]
256 pub fn codes(&self) -> Vec<String> {
257 self.codes.iter().cloned().collect()
258 }
259
260 #[must_use]
262 pub fn has_codes(&self) -> bool {
263 !self.codes.is_empty()
264 }
265}
266
267impl Drop for RecoveryCodes {
268 fn drop(&mut self) {
269 for code in self.codes.drain() {
270 let mut bytes = code.into_bytes();
272 bytes.zeroize();
273 }
274 }
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared secret used by the RFC 6238 Appendix B SHA-1 test vectors.
    const RFC_SECRET: &[u8] = b"12345678901234567890";

    #[test]
    fn test_totp_generate_verify() {
        let secret = TOTP::generate_secret().unwrap();
        let totp = TOTP::with_secret(secret);
        let code = totp.generate(None);
        assert!(totp.verify(&code, None, 1));
    }

    #[test]
    fn test_totp_minimum_digits_clamping() {
        // Any requested length below 6 is raised to 6.
        for &digits in &[0usize, 1, 5, 6] {
            let totp = TOTP::new(RFC_SECRET.to_vec(), digits, 30);
            assert_eq!(totp.generate(Some(59)).len(), 6, "digits = {}", digits);
        }
    }

    #[test]
    fn test_totp_maximum_digits_clamping() {
        let totp = TOTP::new(RFC_SECRET.to_vec(), 12, 30);
        assert_eq!(totp.generate(Some(59)).len(), 9);
    }

    #[test]
    fn test_totp_zero_interval_uses_default() {
        let totp = TOTP::new(RFC_SECRET.to_vec(), 8, 0);
        assert_eq!(totp.generate(Some(59)), "94287082");
    }

    #[test]
    fn test_totp_known_vector() {
        let totp = TOTP::new(RFC_SECRET.to_vec(), 8, 30);
        assert_eq!(totp.generate(Some(59)), "94287082");
    }

    #[test]
    fn test_totp_rfc6238_sha1_vectors() {
        let totp = TOTP::new(RFC_SECRET.to_vec(), 8, 30);
        let vectors: [(u64, &str); 6] = [
            (59, "94287082"),
            (1_111_111_109, "07081804"),
            (1_111_111_111, "14050471"),
            (1_234_567_890, "89005924"),
            (2_000_000_000, "69279037"),
            (20_000_000_000, "65353130"),
        ];
        for &(time, expected) in &vectors {
            assert_eq!(totp.generate(Some(time)), expected, "t = {}", time);
        }
    }

    #[test]
    fn test_totp_verify_window() {
        let totp = TOTP::new(RFC_SECRET.to_vec(), 8, 30);
        // "07081804" belongs to the step just before the one containing t = 1111111111.
        assert!(totp.verify("07081804", Some(1_111_111_111), 1));
        assert!(!totp.verify("07081804", Some(1_111_111_111), 0));
        // "14050471" belongs to the step just after the one containing t = 1111111109.
        assert!(totp.verify("14050471", Some(1_111_111_109), 1));
        assert!(!totp.verify("14050471", Some(1_111_111_109), 0));
        assert!(!totp.verify("", Some(59), 1));
    }

    #[test]
    fn test_base32_roundtrip() {
        let secret = TOTP::generate_secret().unwrap();
        let encoded = TOTP::secret_to_base32(&secret);
        let decoded = TOTP::secret_from_base32(&encoded).unwrap();
        assert_eq!(secret, decoded);
    }

    #[test]
    fn test_base32_roundtrip_all_lengths() {
        for len in 0u8..=10 {
            let data: Vec<u8> = (0..len).map(|i| i.wrapping_mul(37)).collect();
            let encoded = TOTP::secret_to_base32(&data);
            assert_eq!(TOTP::secret_from_base32(&encoded).unwrap(), data);
        }
    }

    #[test]
    fn test_base32_known_vector() {
        assert_eq!(
            TOTP::secret_to_base32(RFC_SECRET),
            "GEZDGNBVGY3DQOJQGEZDGNBVGY3DQOJQ"
        );
        // Decoding is case-insensitive and skips `=` padding.
        assert_eq!(
            TOTP::secret_from_base32("gezdgnbvgy3dqojqgezdgnbvgy3dqojq").unwrap(),
            RFC_SECRET
        );
        assert_eq!(
            TOTP::secret_from_base32("AEBAGBAF===").unwrap(),
            vec![1u8, 2, 3, 4, 5]
        );
    }

    #[test]
    fn test_base32_rejects_invalid_symbols() {
        assert!(TOTP::secret_from_base32("ABC1").is_err());
        assert!(TOTP::secret_from_base32("ABC!").is_err());
    }

    #[test]
    fn test_recovery_codes() {
        let mut rc = RecoveryCodes::new(10).unwrap();
        assert_eq!(rc.remaining(), 10);

        let codes = rc.codes();
        assert!(rc.verify(&codes[0]));
        assert_eq!(rc.remaining(), 9);

        // A code can be redeemed only once.
        assert!(!rc.verify(&codes[0]));
    }

    #[test]
    fn test_recovery_codes_format() {
        for code in RecoveryCodes::generate_codes(5).unwrap() {
            let groups: Vec<&str> = code.split('-').collect();
            assert_eq!(groups.len(), 4, "code = {}", code);
            assert!(groups.iter().all(|g| {
                g.len() == 4 && g.bytes().all(|b| b.is_ascii_hexdigit() && !b.is_ascii_lowercase())
            }));
        }
    }

    #[test]
    fn test_recovery_codes_input_normalization() {
        let mut rc = RecoveryCodes::new(3).unwrap();
        let codes = rc.codes();

        assert!(rc.verify(&codes[0].to_lowercase().replace('-', "")));
        assert!(rc.verify(&codes[1].replace('-', " ")));
        assert!(!rc.verify("not-a-code"));

        assert_eq!(rc.original_count(), 3);
        assert_eq!(rc.used_count(), 2);
        assert!(rc.has_codes());
    }

    #[test]
    fn test_provisioning_uri() {
        let totp = TOTP::with_secret(vec![1, 2, 3, 4, 5]);
        let uri = totp.provisioning_uri("user@example.com", "TestApp");
        assert!(uri.starts_with("otpauth://totp/TestApp:user@example.com?"));
        assert!(uri.contains("secret=AEBAGBAF"));
        assert!(uri.contains("issuer=TestApp"));
        assert!(uri.contains("digits=6"));
        assert!(uri.contains("period=30"));
    }
}