1#![allow(clippy::cast_possible_truncation)]
7
8use ring::hmac;
9use ring::rand::{SecureRandom, SystemRandom};
10use std::collections::HashSet;
11use std::time::{SystemTime, UNIX_EPOCH};
12use zeroize::{Zeroize, ZeroizeOnDrop};
13
14use crate::error::{Result, ShieldError};
15
/// Length in bytes of secrets produced by `TOTP::generate_secret`:
/// 160 bits, the shared-secret length recommended by RFC 4226 §4.
const DEFAULT_SECRET_LEN: usize = 160 / 8;
18
19#[derive(Zeroize, ZeroizeOnDrop)]
23pub struct TOTP {
24 secret: Vec<u8>,
25 #[zeroize(skip)]
26 digits: usize,
27 #[zeroize(skip)]
28 interval: u64,
29}
30
31impl TOTP {
32 #[must_use]
34 pub fn new(secret: Vec<u8>, digits: usize, interval: u64) -> Self {
35 Self {
36 secret,
37 digits: if digits == 0 { 6 } else { digits },
38 interval: if interval == 0 { 30 } else { interval },
39 }
40 }
41
42 #[must_use]
44 pub fn with_secret(secret: Vec<u8>) -> Self {
45 Self::new(secret, 6, 30)
46 }
47
48 pub fn generate_secret() -> Result<Vec<u8>> {
50 let rng = SystemRandom::new();
51 let mut secret = vec![0u8; DEFAULT_SECRET_LEN];
52 rng.fill(&mut secret).map_err(|_| ShieldError::RandomFailed)?;
53 Ok(secret)
54 }
55
56 #[must_use]
58 pub fn generate(&self, timestamp: Option<u64>) -> String {
59 let time = timestamp.unwrap_or_else(|| {
60 SystemTime::now()
61 .duration_since(UNIX_EPOCH)
62 .unwrap()
63 .as_secs()
64 });
65
66 let counter = time / self.interval;
67 self.generate_hotp(counter)
68 }
69
70 fn generate_hotp(&self, counter: u64) -> String {
72 let counter_bytes = counter.to_be_bytes();
73 let key = hmac::Key::new(hmac::HMAC_SHA1_FOR_LEGACY_USE_ONLY, &self.secret);
74 let tag = hmac::sign(&key, &counter_bytes);
75 let hash = tag.as_ref();
76
77 let offset = (hash[19] & 0xf) as usize;
79 let code = u32::from_be_bytes([
80 hash[offset] & 0x7f,
81 hash[offset + 1],
82 hash[offset + 2],
83 hash[offset + 3],
84 ]);
85
86 let modulo = 10u32.pow(self.digits as u32);
87 format!("{:0width$}", code % modulo, width = self.digits)
88 }
89
90 #[must_use]
92 pub fn verify(&self, code: &str, timestamp: Option<u64>, window: u32) -> bool {
93 let time = timestamp.unwrap_or_else(|| {
94 SystemTime::now()
95 .duration_since(UNIX_EPOCH)
96 .unwrap()
97 .as_secs()
98 });
99
100 let window = if window == 0 { 1 } else { window };
101
102 for i in 0..=window {
103 let t = time.saturating_sub(u64::from(i) * self.interval);
104 if self.generate(Some(t)) == code {
105 return true;
106 }
107 if i > 0 {
108 let t = time + u64::from(i) * self.interval;
109 if self.generate(Some(t)) == code {
110 return true;
111 }
112 }
113 }
114 false
115 }
116
117 #[must_use]
119 pub fn secret_to_base32(secret: &[u8]) -> String {
120 const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
121 let mut result = String::new();
122 let mut buffer = 0u64;
123 let mut bits = 0;
124
125 for &byte in secret {
126 buffer = (buffer << 8) | u64::from(byte);
127 bits += 8;
128 while bits >= 5 {
129 bits -= 5;
130 result.push(ALPHABET[((buffer >> bits) & 0x1f) as usize] as char);
131 }
132 }
133
134 if bits > 0 {
135 result.push(ALPHABET[((buffer << (5 - bits)) & 0x1f) as usize] as char);
136 }
137
138 result
139 }
140
141 pub fn secret_from_base32(encoded: &str) -> Result<Vec<u8>> {
143 const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
144 let mut result = Vec::new();
145 let mut buffer = 0u64;
146 let mut bits = 0;
147
148 for c in encoded.chars() {
149 let c = c.to_ascii_uppercase();
150 if c == '=' {
151 continue;
152 }
153 let val = ALPHABET
154 .iter()
155 .position(|&b| b == c as u8)
156 .ok_or(ShieldError::InvalidFormat)?;
157
158 buffer = (buffer << 5) | (val as u64);
159 bits += 5;
160
161 if bits >= 8 {
162 bits -= 8;
163 result.push((buffer >> bits) as u8);
164 }
165 }
166
167 Ok(result)
168 }
169
170 #[must_use]
172 pub fn provisioning_uri(&self, account: &str, issuer: &str) -> String {
173 let secret_b32 = Self::secret_to_base32(&self.secret);
174 format!(
175 "otpauth://totp/{}:{}?secret={}&issuer={}&algorithm=SHA1&digits={}&period={}",
176 issuer, account, secret_b32, issuer, self.digits, self.interval
177 )
178 }
179
180 #[must_use]
182 pub fn secret(&self) -> &[u8] {
183 &self.secret
184 }
185}
186
/// A set of single-use recovery codes.
///
/// Each code is removed from the set when it is successfully verified.
pub struct RecoveryCodes {
    /// Codes that have not been redeemed yet, formatted as `XXXX-XXXX`.
    codes: HashSet<String>,
    /// Size of the set at creation, used to report how many were redeemed.
    original_count: usize,
}
192
193impl RecoveryCodes {
194 pub fn new(count: usize) -> Result<Self> {
196 let codes = Self::generate_codes(count)?;
197 let original_count = codes.len();
198 Ok(Self {
199 codes: codes.into_iter().collect(),
200 original_count,
201 })
202 }
203
204 pub fn generate_codes(count: usize) -> Result<Vec<String>> {
206 let rng = SystemRandom::new();
207 let mut codes = Vec::with_capacity(count);
208
209 for _ in 0..count {
210 let mut bytes = [0u8; 4];
211 rng.fill(&mut bytes).map_err(|_| ShieldError::RandomFailed)?;
212 let code = format!(
213 "{:04X}-{:04X}",
214 u16::from_be_bytes([bytes[0], bytes[1]]),
215 u16::from_be_bytes([bytes[2], bytes[3]])
216 );
217 codes.push(code);
218 }
219
220 Ok(codes)
221 }
222
223 pub fn verify(&mut self, code: &str) -> bool {
225 let normalized = code.to_uppercase().replace([' ', '-'], "");
226 let formatted = if normalized.len() == 8 {
227 format!("{}-{}", &normalized[0..4], &normalized[4..8])
228 } else {
229 code.to_uppercase()
230 };
231
232 self.codes.remove(&formatted)
233 }
234
235 #[must_use]
237 pub fn remaining(&self) -> usize {
238 self.codes.len()
239 }
240
241 #[must_use]
243 pub fn original_count(&self) -> usize {
244 self.original_count
245 }
246
247 #[must_use]
249 pub fn used_count(&self) -> usize {
250 self.original_count - self.codes.len()
251 }
252
253 #[must_use]
255 pub fn codes(&self) -> Vec<String> {
256 self.codes.iter().cloned().collect()
257 }
258
259 #[must_use]
261 pub fn has_codes(&self) -> bool {
262 !self.codes.is_empty()
263 }
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared secret used by the RFC 4226 and RFC 6238 test vectors.
    const RFC_SECRET: &[u8] = b"12345678901234567890";

    #[test]
    fn test_totp_generate_verify() {
        let secret = TOTP::generate_secret().unwrap();
        let totp = TOTP::with_secret(secret);
        let code = totp.generate(None);
        assert!(totp.verify(&code, None, 1));
    }

    #[test]
    fn test_totp_known_vector() {
        // RFC 6238 Appendix B, SHA-1 column: 8 digits, 30-second step.
        let totp = TOTP::new(RFC_SECRET.to_vec(), 8, 30);
        let vectors: [(u64, &str); 6] = [
            (59, "94287082"),
            (1_111_111_109, "07081804"),
            (1_111_111_111, "14050471"),
            (1_234_567_890, "89005924"),
            (2_000_000_000, "69279037"),
            (20_000_000_000, "65353130"),
        ];
        for &(time, expected) in &vectors {
            assert_eq!(totp.generate(Some(time)), expected, "t = {}", time);
        }
    }

    #[test]
    fn test_hotp_rfc4226_vectors() {
        // With a 1-second step the counter equals the timestamp, so
        // `generate` yields the RFC 4226 Appendix D HOTP values.
        let totp = TOTP::new(RFC_SECRET.to_vec(), 6, 1);
        let expected = [
            "755224", "287082", "359152", "969429", "338314", "254676", "287922", "162583",
            "399871", "520489",
        ];
        for (counter, code) in expected.iter().enumerate() {
            assert_eq!(totp.generate(Some(counter as u64)), *code);
        }
    }

    #[test]
    fn test_totp_verify_window() {
        let totp = TOTP::new(RFC_SECRET.to_vec(), 8, 30);
        // The code for step 1 (t = 59), checked from step 2 (t = 89): inside a window of 1.
        assert!(totp.verify("94287082", Some(89), 1));
        // The same code checked from step 4 (t = 149): outside the window.
        assert!(!totp.verify("94287082", Some(149), 1));
        // A code of the wrong length never matches.
        assert!(!totp.verify("9428708", Some(59), 1));
    }

    #[test]
    fn test_base32_roundtrip() {
        let secret = TOTP::generate_secret().unwrap();
        let encoded = TOTP::secret_to_base32(&secret);
        let decoded = TOTP::secret_from_base32(&encoded).unwrap();
        assert_eq!(secret, decoded);
    }

    #[test]
    fn test_base32_known_vectors() {
        // RFC 4648 §10 test vectors, without padding.
        let vectors = [
            ("", ""),
            ("f", "MY"),
            ("fo", "MZXQ"),
            ("foo", "MZXW6"),
            ("foob", "MZXW6YQ"),
            ("fooba", "MZXW6YTB"),
            ("foobar", "MZXW6YTBOI"),
        ];
        for &(raw, encoded) in &vectors {
            assert_eq!(TOTP::secret_to_base32(raw.as_bytes()), encoded);
            assert_eq!(TOTP::secret_from_base32(encoded).unwrap(), raw.as_bytes());
        }

        // Decoding is case-insensitive and ignores `=` padding.
        assert_eq!(TOTP::secret_from_base32("mzxw6ytboi").unwrap(), b"foobar");
        assert_eq!(TOTP::secret_from_base32("MZXW6YQ=").unwrap(), b"foob");
        // '1' is not in the base32 alphabet.
        assert!(TOTP::secret_from_base32("MZXW1").is_err());
    }

    #[test]
    fn test_recovery_codes() {
        let mut rc = RecoveryCodes::new(10).unwrap();
        assert_eq!(rc.remaining(), 10);
        assert_eq!(rc.original_count(), 10);

        let codes = rc.codes();
        for code in &codes {
            assert_eq!(code.len(), 9);
            assert_eq!(code.as_bytes()[4], b'-');
        }

        assert!(rc.verify(&codes[0]));
        assert_eq!(rc.remaining(), 9);
        // A code is single-use.
        assert!(!rc.verify(&codes[0]));

        // Input is normalized: case-insensitive, dash optional.
        assert!(rc.verify(&codes[1].to_lowercase().replace('-', "")));
        assert_eq!(rc.used_count(), 2);

        assert!(!rc.verify("not-a-code"));
        assert!(rc.has_codes());
    }

    #[test]
    fn test_recovery_codes_empty() {
        let rc = RecoveryCodes::new(0).unwrap();
        assert_eq!(rc.remaining(), 0);
        assert!(!rc.has_codes());
        assert!(rc.codes().is_empty());
    }

    #[test]
    fn test_provisioning_uri() {
        let totp = TOTP::with_secret(vec![1, 2, 3, 4, 5]);
        let uri = totp.provisioning_uri("user@example.com", "TestApp");
        assert!(uri.starts_with("otpauth://totp/"));
        assert!(uri.contains("TestApp"));
        assert!(uri.contains("secret=AEBAGBAF"));
        assert!(uri.contains("issuer=TestApp"));
        assert!(uri.contains("algorithm=SHA1"));
        assert!(uri.contains("digits=6"));
        assert!(uri.contains("period=30"));
    }
}