// shield_core/totp.rs

//! TOTP (Time-based One-Time Password) implementation.
//!
//! RFC 6238-compliant code generation and verification, plus single-use
//! recovery codes for 2FA backup.
4
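// Narrowing `as` casts here are intentional: Base32 indices are masked to 0-31,
// decoded bytes are deliberately truncated to u8, and TOTP digit counts are
// expected to be small (RFC 6238 uses 6-8).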
6#![allow(clippy::cast_possible_truncation)]
7
8use ring::hmac;
9use ring::rand::{SecureRandom, SystemRandom};
10use std::collections::HashSet;
11use std::time::{SystemTime, UNIX_EPOCH};
12use zeroize::{Zeroize, ZeroizeOnDrop};
13
14use crate::error::{Result, ShieldError};
15
/// Size, in bytes, of secrets produced by `TOTP::generate_secret`: 160 bits,
/// which matches the HMAC-SHA1 output size.
const DEFAULT_SECRET_LEN: usize = 160 / 8;
18
19/// TOTP generator and validator.
20///
21/// Secret is securely zeroized from memory when dropped.
22#[derive(Zeroize, ZeroizeOnDrop)]
23pub struct TOTP {
24    secret: Vec<u8>,
25    #[zeroize(skip)]
26    digits: usize,
27    #[zeroize(skip)]
28    interval: u64,
29}
30
31impl TOTP {
32    /// Create new TOTP with secret.
33    #[must_use] 
34    pub fn new(secret: Vec<u8>, digits: usize, interval: u64) -> Self {
35        Self {
36            secret,
37            digits: if digits == 0 { 6 } else { digits },
38            interval: if interval == 0 { 30 } else { interval },
39        }
40    }
41
42    /// Create with default settings (6 digits, 30 second interval).
43    #[must_use] 
44    pub fn with_secret(secret: Vec<u8>) -> Self {
45        Self::new(secret, 6, 30)
46    }
47
48    /// Generate a random secret.
49    pub fn generate_secret() -> Result<Vec<u8>> {
50        let rng = SystemRandom::new();
51        let mut secret = vec![0u8; DEFAULT_SECRET_LEN];
52        rng.fill(&mut secret).map_err(|_| ShieldError::RandomFailed)?;
53        Ok(secret)
54    }
55
56    /// Generate TOTP code for given time.
57    #[must_use] 
58    pub fn generate(&self, timestamp: Option<u64>) -> String {
59        let time = timestamp.unwrap_or_else(|| {
60            SystemTime::now()
61                .duration_since(UNIX_EPOCH)
62                .unwrap()
63                .as_secs()
64        });
65
66        let counter = time / self.interval;
67        self.generate_hotp(counter)
68    }
69
70    /// Generate HOTP code for counter.
71    fn generate_hotp(&self, counter: u64) -> String {
72        let counter_bytes = counter.to_be_bytes();
73        let key = hmac::Key::new(hmac::HMAC_SHA1_FOR_LEGACY_USE_ONLY, &self.secret);
74        let tag = hmac::sign(&key, &counter_bytes);
75        let hash = tag.as_ref();
76
77        // Dynamic truncation (RFC 4226)
78        let offset = (hash[19] & 0xf) as usize;
79        let code = u32::from_be_bytes([
80            hash[offset] & 0x7f,
81            hash[offset + 1],
82            hash[offset + 2],
83            hash[offset + 3],
84        ]);
85
86        let modulo = 10u32.pow(self.digits as u32);
87        format!("{:0width$}", code % modulo, width = self.digits)
88    }
89
90    /// Verify TOTP code with time window.
91    #[must_use] 
92    pub fn verify(&self, code: &str, timestamp: Option<u64>, window: u32) -> bool {
93        let time = timestamp.unwrap_or_else(|| {
94            SystemTime::now()
95                .duration_since(UNIX_EPOCH)
96                .unwrap()
97                .as_secs()
98        });
99
100        let window = if window == 0 { 1 } else { window };
101
102        for i in 0..=window {
103            let t = time.saturating_sub(u64::from(i) * self.interval);
104            if self.generate(Some(t)) == code {
105                return true;
106            }
107            if i > 0 {
108                let t = time + u64::from(i) * self.interval;
109                if self.generate(Some(t)) == code {
110                    return true;
111                }
112            }
113        }
114        false
115    }
116
117    /// Convert secret to Base32.
118    #[must_use] 
119    pub fn secret_to_base32(secret: &[u8]) -> String {
120        const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
121        let mut result = String::new();
122        let mut buffer = 0u64;
123        let mut bits = 0;
124
125        for &byte in secret {
126            buffer = (buffer << 8) | u64::from(byte);
127            bits += 8;
128            while bits >= 5 {
129                bits -= 5;
130                result.push(ALPHABET[((buffer >> bits) & 0x1f) as usize] as char);
131            }
132        }
133
134        if bits > 0 {
135            result.push(ALPHABET[((buffer << (5 - bits)) & 0x1f) as usize] as char);
136        }
137
138        result
139    }
140
141    /// Decode Base32 secret.
142    pub fn secret_from_base32(encoded: &str) -> Result<Vec<u8>> {
143        const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
144        let mut result = Vec::new();
145        let mut buffer = 0u64;
146        let mut bits = 0;
147
148        for c in encoded.chars() {
149            let c = c.to_ascii_uppercase();
150            if c == '=' {
151                continue;
152            }
153            let val = ALPHABET
154                .iter()
155                .position(|&b| b == c as u8)
156                .ok_or(ShieldError::InvalidFormat)?;
157
158            buffer = (buffer << 5) | (val as u64);
159            bits += 5;
160
161            if bits >= 8 {
162                bits -= 8;
163                result.push((buffer >> bits) as u8);
164            }
165        }
166
167        Ok(result)
168    }
169
170    /// Generate provisioning URI for QR codes.
171    #[must_use] 
172    pub fn provisioning_uri(&self, account: &str, issuer: &str) -> String {
173        let secret_b32 = Self::secret_to_base32(&self.secret);
174        format!(
175            "otpauth://totp/{}:{}?secret={}&issuer={}&algorithm=SHA1&digits={}&period={}",
176            issuer, account, secret_b32, issuer, self.digits, self.interval
177        )
178    }
179
180    /// Get the secret.
181    #[must_use] 
182    pub fn secret(&self) -> &[u8] {
183        &self.secret
184    }
185}
186
187/// Recovery codes for 2FA backup.
188pub struct RecoveryCodes {
189    codes: HashSet<String>,
190    original_count: usize,
191}
192
193impl RecoveryCodes {
194    /// Generate new recovery codes.
195    pub fn new(count: usize) -> Result<Self> {
196        let codes = Self::generate_codes(count)?;
197        let original_count = codes.len();
198        Ok(Self {
199            codes: codes.into_iter().collect(),
200            original_count,
201        })
202    }
203
204    /// Generate codes list.
205    pub fn generate_codes(count: usize) -> Result<Vec<String>> {
206        let rng = SystemRandom::new();
207        let mut codes = Vec::with_capacity(count);
208
209        for _ in 0..count {
210            let mut bytes = [0u8; 4];
211            rng.fill(&mut bytes).map_err(|_| ShieldError::RandomFailed)?;
212            let code = format!(
213                "{:04X}-{:04X}",
214                u16::from_be_bytes([bytes[0], bytes[1]]),
215                u16::from_be_bytes([bytes[2], bytes[3]])
216            );
217            codes.push(code);
218        }
219
220        Ok(codes)
221    }
222
223    /// Verify and consume a recovery code.
224    pub fn verify(&mut self, code: &str) -> bool {
225        let normalized = code.to_uppercase().replace([' ', '-'], "");
226        let formatted = if normalized.len() == 8 {
227            format!("{}-{}", &normalized[0..4], &normalized[4..8])
228        } else {
229            code.to_uppercase()
230        };
231
232        self.codes.remove(&formatted)
233    }
234
235    /// Get remaining code count.
236    #[must_use] 
237    pub fn remaining(&self) -> usize {
238        self.codes.len()
239    }
240
241    /// Get original code count (how many were initially generated).
242    #[must_use] 
243    pub fn original_count(&self) -> usize {
244        self.original_count
245    }
246
247    /// Get used code count (original - remaining).
248    #[must_use] 
249    pub fn used_count(&self) -> usize {
250        self.original_count - self.codes.len()
251    }
252
253    /// Get all codes (for display to user).
254    #[must_use] 
255    pub fn codes(&self) -> Vec<String> {
256        self.codes.iter().cloned().collect()
257    }
258
259    /// Check if any codes remain.
260    #[must_use] 
261    pub fn has_codes(&self) -> bool {
262        !self.codes.is_empty()
263    }
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly generated code verifies at (about) the same instant.
    #[test]
    fn test_totp_generate_verify() {
        let totp = TOTP::with_secret(TOTP::generate_secret().unwrap());
        let current = totp.generate(None);
        assert!(totp.verify(&current, None, 1));
    }

    /// RFC 6238 Appendix B: SHA-1 secret "1234...890", T = 59 s, 8 digits.
    #[test]
    fn test_totp_known_vector() {
        let totp = TOTP::new(b"12345678901234567890".to_vec(), 8, 30);
        assert_eq!(totp.generate(Some(59)), "94287082");
    }

    /// Encoding then decoding a random secret yields the same bytes.
    #[test]
    fn test_base32_roundtrip() {
        let original = TOTP::generate_secret().unwrap();
        let text = TOTP::secret_to_base32(&original);
        let round_tripped = TOTP::secret_from_base32(&text).unwrap();
        assert_eq!(original, round_tripped);
    }

    /// A recovery code is accepted once, then rejected on reuse.
    #[test]
    fn test_recovery_codes() {
        let mut recovery = RecoveryCodes::new(10).unwrap();
        assert_eq!(recovery.remaining(), 10);

        let first = recovery.codes().into_iter().next().unwrap();
        assert!(recovery.verify(&first));
        assert_eq!(recovery.remaining(), 9);

        // A consumed code must not verify a second time.
        assert!(!recovery.verify(&first));
    }

    /// The provisioning URI uses the otpauth scheme and mentions the issuer.
    #[test]
    fn test_provisioning_uri() {
        let uri = TOTP::with_secret(vec![1, 2, 3, 4, 5])
            .provisioning_uri("user@example.com", "TestApp");
        assert!(uri.starts_with("otpauth://totp/"));
        assert!(uri.contains("TestApp"));
    }
}