shield_core/
totp.rs

1//! TOTP (Time-based One-Time Password) implementation.
2//!
//! RFC 6238-compliant TOTP generation and verification, with support for single-use recovery codes.
4
5// TOTP digits (6-8) and Base32 indices (0-31) never overflow u32
6#![allow(clippy::cast_possible_truncation)]
7
8use ring::hmac;
9use ring::rand::{SecureRandom, SystemRandom};
10use std::collections::HashSet;
11use std::time::{SystemTime, UNIX_EPOCH};
12use zeroize::{Zeroize, ZeroizeOnDrop};
13
14use crate::error::{Result, ShieldError};
15
/// Default secret length in bytes (160 bits).
///
/// RFC 4226 requires a shared secret of at least 128 bits and recommends 160 bits,
/// which is also the HMAC-SHA1 output size.
const DEFAULT_SECRET_LEN: usize = 160 / 8;
18
/// TOTP generator and validator (HMAC-SHA1 based).
///
/// Secret is securely zeroized from memory when dropped.
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct TOTP {
    /// Raw shared-secret bytes used as the HMAC key; the only field wiped on drop.
    secret: Vec<u8>,
    /// Number of decimal digits in generated codes. Published in the provisioning URI,
    /// so it is skipped by zeroize.
    #[zeroize(skip)]
    digits: usize,
    /// Time-step length in seconds (Unix time is divided by this to get the counter).
    /// Published in the provisioning URI, so it is skipped by zeroize.
    #[zeroize(skip)]
    interval: u64,
}
30
31impl TOTP {
32    /// Create new TOTP with secret.
33    #[must_use]
34    pub fn new(secret: Vec<u8>, digits: usize, interval: u64) -> Self {
35        Self {
36            secret,
37            digits: if digits == 0 { 6 } else { digits },
38            interval: if interval == 0 { 30 } else { interval },
39        }
40    }
41
42    /// Create with default settings (6 digits, 30 second interval).
43    #[must_use]
44    pub fn with_secret(secret: Vec<u8>) -> Self {
45        Self::new(secret, 6, 30)
46    }
47
48    /// Generate a random secret.
49    pub fn generate_secret() -> Result<Vec<u8>> {
50        let rng = SystemRandom::new();
51        let mut secret = vec![0u8; DEFAULT_SECRET_LEN];
52        rng.fill(&mut secret)
53            .map_err(|_| ShieldError::RandomFailed)?;
54        Ok(secret)
55    }
56
57    /// Generate TOTP code for given time.
58    #[must_use]
59    pub fn generate(&self, timestamp: Option<u64>) -> String {
60        let time = timestamp.unwrap_or_else(|| {
61            SystemTime::now()
62                .duration_since(UNIX_EPOCH)
63                .unwrap()
64                .as_secs()
65        });
66
67        let counter = time / self.interval;
68        self.generate_hotp(counter)
69    }
70
71    /// Generate HOTP code for counter.
72    fn generate_hotp(&self, counter: u64) -> String {
73        let counter_bytes = counter.to_be_bytes();
74        let key = hmac::Key::new(hmac::HMAC_SHA1_FOR_LEGACY_USE_ONLY, &self.secret);
75        let tag = hmac::sign(&key, &counter_bytes);
76        let hash = tag.as_ref();
77
78        // Dynamic truncation (RFC 4226)
79        let offset = (hash[19] & 0xf) as usize;
80        let code = u32::from_be_bytes([
81            hash[offset] & 0x7f,
82            hash[offset + 1],
83            hash[offset + 2],
84            hash[offset + 3],
85        ]);
86
87        let modulo = 10u32.pow(self.digits as u32);
88        format!("{:0width$}", code % modulo, width = self.digits)
89    }
90
91    /// Verify TOTP code with time window.
92    #[must_use]
93    pub fn verify(&self, code: &str, timestamp: Option<u64>, window: u32) -> bool {
94        let time = timestamp.unwrap_or_else(|| {
95            SystemTime::now()
96                .duration_since(UNIX_EPOCH)
97                .unwrap()
98                .as_secs()
99        });
100
101        let window = if window == 0 { 1 } else { window };
102
103        for i in 0..=window {
104            let t = time.saturating_sub(u64::from(i) * self.interval);
105            if self.generate(Some(t)) == code {
106                return true;
107            }
108            if i > 0 {
109                let t = time + u64::from(i) * self.interval;
110                if self.generate(Some(t)) == code {
111                    return true;
112                }
113            }
114        }
115        false
116    }
117
118    /// Convert secret to Base32.
119    #[must_use]
120    pub fn secret_to_base32(secret: &[u8]) -> String {
121        const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
122        let mut result = String::new();
123        let mut buffer = 0u64;
124        let mut bits = 0;
125
126        for &byte in secret {
127            buffer = (buffer << 8) | u64::from(byte);
128            bits += 8;
129            while bits >= 5 {
130                bits -= 5;
131                result.push(ALPHABET[((buffer >> bits) & 0x1f) as usize] as char);
132            }
133        }
134
135        if bits > 0 {
136            result.push(ALPHABET[((buffer << (5 - bits)) & 0x1f) as usize] as char);
137        }
138
139        result
140    }
141
142    /// Decode Base32 secret.
143    pub fn secret_from_base32(encoded: &str) -> Result<Vec<u8>> {
144        const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
145        let mut result = Vec::new();
146        let mut buffer = 0u64;
147        let mut bits = 0;
148
149        for c in encoded.chars() {
150            let c = c.to_ascii_uppercase();
151            if c == '=' {
152                continue;
153            }
154            let val = ALPHABET
155                .iter()
156                .position(|&b| b == c as u8)
157                .ok_or(ShieldError::InvalidFormat)?;
158
159            buffer = (buffer << 5) | (val as u64);
160            bits += 5;
161
162            if bits >= 8 {
163                bits -= 8;
164                result.push((buffer >> bits) as u8);
165            }
166        }
167
168        Ok(result)
169    }
170
171    /// Generate provisioning URI for QR codes.
172    #[must_use]
173    pub fn provisioning_uri(&self, account: &str, issuer: &str) -> String {
174        let secret_b32 = Self::secret_to_base32(&self.secret);
175        format!(
176            "otpauth://totp/{}:{}?secret={}&issuer={}&algorithm=SHA1&digits={}&period={}",
177            issuer, account, secret_b32, issuer, self.digits, self.interval
178        )
179    }
180
181    /// Get the secret.
182    #[must_use]
183    pub fn secret(&self) -> &[u8] {
184        &self.secret
185    }
186}
187
/// Recovery codes for 2FA backup.
///
/// Holds the codes that have not been used yet, plus the number of codes the set was
/// created with, so callers can report how many have been consumed.
// NOTE(review): unlike `TOTP`, this type does not derive `Zeroize`/`ZeroizeOnDrop`, so the
// code strings are not wiped from memory when dropped. Confirm whether that is intended.
pub struct RecoveryCodes {
    /// Codes still available; `verify` removes a code once it has been used.
    codes: HashSet<String>,
    /// Count recorded at construction; `used_count` subtracts `codes.len()` from it.
    original_count: usize,
}
193
194impl RecoveryCodes {
195    /// Generate new recovery codes.
196    pub fn new(count: usize) -> Result<Self> {
197        let codes = Self::generate_codes(count)?;
198        let original_count = codes.len();
199        Ok(Self {
200            codes: codes.into_iter().collect(),
201            original_count,
202        })
203    }
204
205    /// Generate codes list.
206    pub fn generate_codes(count: usize) -> Result<Vec<String>> {
207        let rng = SystemRandom::new();
208        let mut codes = Vec::with_capacity(count);
209
210        for _ in 0..count {
211            let mut bytes = [0u8; 4];
212            rng.fill(&mut bytes)
213                .map_err(|_| ShieldError::RandomFailed)?;
214            let code = format!(
215                "{:04X}-{:04X}",
216                u16::from_be_bytes([bytes[0], bytes[1]]),
217                u16::from_be_bytes([bytes[2], bytes[3]])
218            );
219            codes.push(code);
220        }
221
222        Ok(codes)
223    }
224
225    /// Verify and consume a recovery code.
226    pub fn verify(&mut self, code: &str) -> bool {
227        let normalized = code.to_uppercase().replace([' ', '-'], "");
228        let formatted = if normalized.len() == 8 {
229            format!("{}-{}", &normalized[0..4], &normalized[4..8])
230        } else {
231            code.to_uppercase()
232        };
233
234        self.codes.remove(&formatted)
235    }
236
237    /// Get remaining code count.
238    #[must_use]
239    pub fn remaining(&self) -> usize {
240        self.codes.len()
241    }
242
243    /// Get original code count (how many were initially generated).
244    #[must_use]
245    pub fn original_count(&self) -> usize {
246        self.original_count
247    }
248
249    /// Get used code count (original - remaining).
250    #[must_use]
251    pub fn used_count(&self) -> usize {
252        self.original_count - self.codes.len()
253    }
254
255    /// Get all codes (for display to user).
256    #[must_use]
257    pub fn codes(&self) -> Vec<String> {
258        self.codes.iter().cloned().collect()
259    }
260
261    /// Check if any codes remain.
262    #[must_use]
263    pub fn has_codes(&self) -> bool {
264        !self.codes.is_empty()
265    }
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference secret from RFC 4226 / RFC 6238 ("12345678901234567890" as ASCII).
    const RFC_SECRET: &[u8] = b"12345678901234567890";

    #[test]
    fn test_totp_generate_verify() {
        let secret = TOTP::generate_secret().unwrap();
        let totp = TOTP::with_secret(secret);
        let code = totp.generate(None);
        assert!(totp.verify(&code, None, 1));
    }

    #[test]
    fn test_totp_known_vector() {
        // RFC 6238 test vector
        let secret = b"12345678901234567890".to_vec();
        let totp = TOTP::new(secret, 8, 30);
        let code = totp.generate(Some(59));
        assert_eq!(code, "94287082");
    }

    #[test]
    fn test_totp_rfc6238_sha1_vectors() {
        // RFC 6238 Appendix B, SHA1 column.
        let totp = TOTP::new(RFC_SECRET.to_vec(), 8, 30);
        let vectors: [(u64, &str); 6] = [
            (59, "94287082"),
            (1_111_111_109, "07081804"),
            (1_111_111_111, "14050471"),
            (1_234_567_890, "89005924"),
            (2_000_000_000, "69279037"),
            (20_000_000_000, "65353130"),
        ];
        for &(time, expected) in &vectors {
            assert_eq!(totp.generate(Some(time)), expected, "t = {}", time);
        }
    }

    #[test]
    fn test_hotp_rfc4226_vectors() {
        // With a 1-second interval the counter equals the timestamp, so this checks the
        // RFC 4226 Appendix D HOTP values directly.
        let totp = TOTP::new(RFC_SECRET.to_vec(), 6, 1);
        let expected = [
            "755224", "287082", "359152", "969429", "338314", "254676", "287922", "162583",
            "399871", "520489",
        ];
        for (counter, code) in (0u64..).zip(expected.iter()) {
            assert_eq!(totp.generate(Some(counter)), *code, "counter = {}", counter);
        }
    }

    #[test]
    fn test_totp_verify_window() {
        let totp = TOTP::new(RFC_SECRET.to_vec(), 8, 30);
        // "94287082" is the code for counter 1 (t in 30..60).
        assert!(totp.verify("94287082", Some(59), 1));
        assert!(totp.verify("94287082", Some(89), 1)); // one step late
        assert!(totp.verify("94287082", Some(29), 1)); // one step early
        assert!(!totp.verify("94287082", Some(149), 1)); // three steps late
        assert!(totp.verify("94287082", Some(149), 3)); // ...but within a wider window
        assert!(!totp.verify("9428708", Some(59), 1)); // wrong length
        assert!(!totp.verify("", Some(59), 1));
    }

    #[test]
    fn test_base32_roundtrip() {
        let secret = TOTP::generate_secret().unwrap();
        let encoded = TOTP::secret_to_base32(&secret);
        let decoded = TOTP::secret_from_base32(&encoded).unwrap();
        assert_eq!(secret, decoded);
    }

    #[test]
    fn test_base32_known_vectors() {
        assert_eq!(
            TOTP::secret_to_base32(RFC_SECRET),
            "GEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQ"
        );
        // RFC 4648 §10 vector, without padding.
        assert_eq!(TOTP::secret_to_base32(b"foobar"), "MZXW6YTBOI");
        assert_eq!(TOTP::secret_to_base32(b""), "");

        assert_eq!(
            TOTP::secret_from_base32("gezdgnbvgy3tqojqgezdgnbvgy3tqojq").unwrap(),
            RFC_SECRET
        );
        assert_eq!(TOTP::secret_from_base32("MZXW6YTBOI======").unwrap(), b"foobar");
        assert!(TOTP::secret_from_base32("").unwrap().is_empty());
        assert!(TOTP::secret_from_base32("GEZ1").is_err());
    }

    #[test]
    fn test_recovery_codes() {
        let mut rc = RecoveryCodes::new(10).unwrap();
        assert_eq!(rc.remaining(), 10);

        let codes = rc.codes();
        assert!(rc.verify(&codes[0]));
        assert_eq!(rc.remaining(), 9);

        // Can't reuse
        assert!(!rc.verify(&codes[0]));
    }

    #[test]
    fn test_recovery_code_format() {
        let codes = RecoveryCodes::generate_codes(20).unwrap();
        assert_eq!(codes.len(), 20);
        for code in &codes {
            assert_eq!(code.len(), 9);
            assert_eq!(code.as_bytes()[4], b'-');
            assert!(code
                .chars()
                .filter(|&c| c != '-')
                .all(|c| c.is_ascii_digit() || ('A'..='F').contains(&c)));
        }
    }

    #[test]
    fn test_recovery_code_normalization() {
        let mut rc = RecoveryCodes::new(5).unwrap();
        assert_eq!(rc.original_count(), 5);
        let codes = rc.codes();

        // Lower-case, dash removed.
        assert!(rc.verify(&codes[0].to_lowercase().replace('-', "")));
        // Space instead of dash.
        assert!(rc.verify(&codes[1].replace('-', " ")));
        assert!(!rc.verify("not a code"));
        assert!(!rc.verify(""));

        assert_eq!(rc.remaining(), 3);
        assert_eq!(rc.used_count(), 2);
        assert!(rc.has_codes());
    }

    #[test]
    fn test_recovery_codes_empty() {
        let rc = RecoveryCodes::new(0).unwrap();
        assert_eq!(rc.remaining(), 0);
        assert_eq!(rc.used_count(), 0);
        assert!(!rc.has_codes());
    }

    #[test]
    fn test_provisioning_uri() {
        let secret = vec![1, 2, 3, 4, 5];
        let expected_secret = TOTP::secret_to_base32(&secret);
        let totp = TOTP::with_secret(secret);
        let uri = totp.provisioning_uri("user@example.com", "TestApp");
        assert!(uri.starts_with("otpauth://totp/TestApp:user@example.com?"));
        assert!(uri.contains(&format!("secret={}&", expected_secret)));
        assert!(uri.contains("issuer=TestApp"));
        assert!(uri.ends_with("&algorithm=SHA1&digits=6&period=30"));
    }
}