Skip to main content

shield_core/
totp.rs

1//! TOTP (Time-based One-Time Password) implementation.
2//!
//! RFC 6238-compliant TOTP generation and verification, plus single-use recovery codes.
4
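// Numeric casts in this module never lose meaningful bits: TOTP digits are clamped
// to 6-9 (so 10^digits fits in u32), Base32 indices are masked to 0-31, and the
// Base32 decoder deliberately keeps only the low 8 bits of its bit buffer.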
6#![allow(clippy::cast_possible_truncation)]
7
8use ring::hmac;
9use std::collections::HashSet;
10use std::time::{SystemTime, UNIX_EPOCH};
11use subtle::ConstantTimeEq;
12use zeroize::{Zeroize, ZeroizeOnDrop};
13
14use crate::error::{Result, ShieldError};
15
/// Length, in bytes, of secrets produced by `TOTP::generate_secret`.
///
/// 160 bits: the shared-secret length recommended by RFC 4226, equal to the
/// HMAC-SHA1 output size.
const DEFAULT_SECRET_LEN: usize = 160 / 8;
18
19/// TOTP generator and validator.
20///
21/// Secret is securely zeroized from memory when dropped.
22#[derive(Zeroize, ZeroizeOnDrop)]
23pub struct TOTP {
24    secret: Vec<u8>,
25    #[zeroize(skip)]
26    digits: usize,
27    #[zeroize(skip)]
28    interval: u64,
29}
30
31impl TOTP {
32    /// Create new TOTP with secret.
33    #[must_use]
34    pub fn new(secret: Vec<u8>, digits: usize, interval: u64) -> Self {
35        Self {
36            secret,
37            digits: digits.clamp(6, 9),
38            interval: if interval == 0 { 30 } else { interval },
39        }
40    }
41
42    /// Create with default settings (6 digits, 30 second interval).
43    #[must_use]
44    pub fn with_secret(secret: Vec<u8>) -> Self {
45        Self::new(secret, 6, 30)
46    }
47
48    /// Generate a random secret.
49    pub fn generate_secret() -> Result<Vec<u8>> {
50        crate::random::random_vec(DEFAULT_SECRET_LEN)
51    }
52
53    /// Generate TOTP code for given time.
54    #[must_use]
55    pub fn generate(&self, timestamp: Option<u64>) -> String {
56        let time = timestamp.unwrap_or_else(|| {
57            SystemTime::now()
58                .duration_since(UNIX_EPOCH)
59                .map_or(0, |d| d.as_secs())
60        });
61
62        let counter = time / self.interval;
63        self.generate_hotp(counter)
64    }
65
66    /// Generate HOTP code for counter.
67    fn generate_hotp(&self, counter: u64) -> String {
68        let counter_bytes = counter.to_be_bytes();
69        let key = hmac::Key::new(hmac::HMAC_SHA1_FOR_LEGACY_USE_ONLY, &self.secret);
70        let tag = hmac::sign(&key, &counter_bytes);
71        let hash = tag.as_ref();
72
73        // Dynamic truncation (RFC 4226)
74        let offset = (hash[19] & 0xf) as usize;
75        let code = u32::from_be_bytes([
76            hash[offset] & 0x7f,
77            hash[offset + 1],
78            hash[offset + 2],
79            hash[offset + 3],
80        ]);
81
82        let modulo = 10u32.pow(self.digits as u32);
83        format!("{:0width$}", code % modulo, width = self.digits)
84    }
85
86    /// Verify TOTP code with time window.
87    #[must_use]
88    pub fn verify(&self, code: &str, timestamp: Option<u64>, window: u32) -> bool {
89        let time = timestamp.unwrap_or_else(|| {
90            SystemTime::now()
91                .duration_since(UNIX_EPOCH)
92                .map_or(0, |d| d.as_secs())
93        });
94
95        for i in 0..=window {
96            let t = time.saturating_sub(u64::from(i) * self.interval);
97            let expected = self.generate(Some(t));
98            if expected.as_bytes().ct_eq(code.as_bytes()).unwrap_u8() == 1 {
99                return true;
100            }
101            if i > 0 {
102                let t = time + u64::from(i) * self.interval;
103                let expected = self.generate(Some(t));
104                if expected.as_bytes().ct_eq(code.as_bytes()).unwrap_u8() == 1 {
105                    return true;
106                }
107            }
108        }
109        false
110    }
111
112    /// Convert secret to Base32.
113    #[must_use]
114    pub fn secret_to_base32(secret: &[u8]) -> String {
115        const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
116        let mut result = String::new();
117        let mut buffer = 0u64;
118        let mut bits = 0;
119
120        for &byte in secret {
121            buffer = (buffer << 8) | u64::from(byte);
122            bits += 8;
123            while bits >= 5 {
124                bits -= 5;
125                result.push(ALPHABET[((buffer >> bits) & 0x1f) as usize] as char);
126            }
127        }
128
129        if bits > 0 {
130            result.push(ALPHABET[((buffer << (5 - bits)) & 0x1f) as usize] as char);
131        }
132
133        result
134    }
135
136    /// Decode Base32 secret.
137    pub fn secret_from_base32(encoded: &str) -> Result<Vec<u8>> {
138        const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
139        let mut result = Vec::new();
140        let mut buffer = 0u64;
141        let mut bits = 0;
142
143        for c in encoded.chars() {
144            let c = c.to_ascii_uppercase();
145            if c == '=' {
146                continue;
147            }
148            let val = ALPHABET
149                .iter()
150                .position(|&b| b == c as u8)
151                .ok_or(ShieldError::InvalidFormat)?;
152
153            buffer = (buffer << 5) | (val as u64);
154            bits += 5;
155
156            if bits >= 8 {
157                bits -= 8;
158                result.push((buffer >> bits) as u8);
159            }
160        }
161
162        Ok(result)
163    }
164
165    /// Generate provisioning URI for QR codes.
166    #[must_use]
167    pub fn provisioning_uri(&self, account: &str, issuer: &str) -> String {
168        let secret_b32 = Self::secret_to_base32(&self.secret);
169        format!(
170            "otpauth://totp/{}:{}?secret={}&issuer={}&algorithm=SHA1&digits={}&period={}",
171            issuer, account, secret_b32, issuer, self.digits, self.interval
172        )
173    }
174
175    /// Get the secret.
176    #[must_use]
177    pub fn secret(&self) -> &[u8] {
178        &self.secret
179    }
180}
181
/// Single-use backup codes for completing 2FA without the authenticator.
pub struct RecoveryCodes {
    /// Codes that have not been consumed yet.
    codes: HashSet<String>,
    /// How many codes the set held when it was created.
    original_count: usize,
}
187
188impl RecoveryCodes {
189    /// Generate new recovery codes.
190    pub fn new(count: usize) -> Result<Self> {
191        let codes = Self::generate_codes(count)?;
192        let original_count = codes.len();
193        Ok(Self {
194            codes: codes.into_iter().collect(),
195            original_count,
196        })
197    }
198
199    /// Generate codes list.
200    pub fn generate_codes(count: usize) -> Result<Vec<String>> {
201        let mut codes = Vec::with_capacity(count);
202
203        for _ in 0..count {
204            let bytes: [u8; 8] = crate::random::random_bytes()?;
205            let code = format!(
206                "{:04X}-{:04X}-{:04X}-{:04X}",
207                u16::from_be_bytes([bytes[0], bytes[1]]),
208                u16::from_be_bytes([bytes[2], bytes[3]]),
209                u16::from_be_bytes([bytes[4], bytes[5]]),
210                u16::from_be_bytes([bytes[6], bytes[7]])
211            );
212            codes.push(code);
213        }
214
215        Ok(codes)
216    }
217
218    /// Verify and consume a recovery code.
219    pub fn verify(&mut self, code: &str) -> bool {
220        let normalized = code.to_uppercase().replace([' ', '-'], "");
221        let formatted = if normalized.len() == 16 {
222            format!(
223                "{}-{}-{}-{}",
224                &normalized[0..4],
225                &normalized[4..8],
226                &normalized[8..12],
227                &normalized[12..16]
228            )
229        } else {
230            code.to_uppercase()
231        };
232
233        self.codes.remove(&formatted)
234    }
235
236    /// Get remaining code count.
237    #[must_use]
238    pub fn remaining(&self) -> usize {
239        self.codes.len()
240    }
241
242    /// Get original code count (how many were initially generated).
243    #[must_use]
244    pub fn original_count(&self) -> usize {
245        self.original_count
246    }
247
248    /// Get used code count (original - remaining).
249    #[must_use]
250    pub fn used_count(&self) -> usize {
251        self.original_count - self.codes.len()
252    }
253
254    /// Get all codes (for display to user).
255    #[must_use]
256    pub fn codes(&self) -> Vec<String> {
257        self.codes.iter().cloned().collect()
258    }
259
260    /// Check if any codes remain.
261    #[must_use]
262    pub fn has_codes(&self) -> bool {
263        !self.codes.is_empty()
264    }
265}
266
267impl Drop for RecoveryCodes {
268    fn drop(&mut self) {
269        for code in self.codes.drain() {
270            // Zeroize the String's bytes before dropping
271            let mut bytes = code.into_bytes();
272            bytes.zeroize();
273        }
274    }
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    /// ASCII secret used by the RFC 4226 and RFC 6238 (SHA-1) test vectors.
    const RFC_SECRET: &[u8] = b"12345678901234567890";

    #[test]
    fn test_totp_generate_verify() {
        let secret = TOTP::generate_secret().unwrap();
        let totp = TOTP::with_secret(secret);
        let code = totp.generate(None);
        assert!(totp.verify(&code, None, 1));
    }

    #[test]
    fn test_totp_minimum_digits_clamping() {
        // Requested lengths below 6 are raised to 6; 6 itself is kept.
        for digits in [1, 5, 0, 6] {
            let totp = TOTP::new(RFC_SECRET.to_vec(), digits, 30);
            assert_eq!(totp.generate(Some(59)).len(), 6);
        }
    }

    #[test]
    fn test_totp_known_vector() {
        // RFC 6238 test vector
        let totp = TOTP::new(RFC_SECRET.to_vec(), 8, 30);
        assert_eq!(totp.generate(Some(59)), "94287082");
    }

    #[test]
    fn test_totp_rfc6238_sha1_vectors() {
        // RFC 6238 Appendix B, SHA-1 column (8 digits, 30 second step).
        let totp = TOTP::new(RFC_SECRET.to_vec(), 8, 30);
        let vectors = [
            (59, "94287082"),
            (1_111_111_109, "07081804"),
            (1_111_111_111, "14050471"),
            (1_234_567_890, "89005924"),
            (2_000_000_000, "69279037"),
            (20_000_000_000, "65353130"),
        ];
        for (time, expected) in vectors {
            assert_eq!(totp.generate(Some(time)), expected);
        }
    }

    #[test]
    fn test_hotp_rfc4226_vectors() {
        // RFC 4226 Appendix D: 6-digit HOTP values for counters 0..=9. With a
        // 30 second step, timestamp `counter * 30` maps to exactly that counter.
        let totp = TOTP::with_secret(RFC_SECRET.to_vec());
        let expected = [
            "755224", "287082", "359152", "969429", "338314", "254676", "287922", "162583",
            "399871", "520489",
        ];
        for (counter, code) in (0u64..).zip(expected) {
            assert_eq!(totp.generate(Some(counter * 30)), code);
        }
    }

    #[test]
    fn test_base32_roundtrip() {
        let secret = TOTP::generate_secret().unwrap();
        let encoded = TOTP::secret_to_base32(&secret);
        let decoded = TOTP::secret_from_base32(&encoded).unwrap();
        assert_eq!(secret, decoded);
    }

    #[test]
    fn test_base32_known_vectors() {
        // RFC 4648 test vector; the encoder omits `=` padding.
        assert_eq!(TOTP::secret_to_base32(b"foobar"), "MZXW6YTBOI");
        assert_eq!(
            TOTP::secret_to_base32(RFC_SECRET),
            "GEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQ"
        );
        // Decoding is case-insensitive and skips padding.
        assert_eq!(
            TOTP::secret_from_base32("mzxw6ytboi======").unwrap(),
            b"foobar"
        );
        // '0' is not part of the Base32 alphabet.
        assert!(TOTP::secret_from_base32("MZXW6YTB0I").is_err());
    }

    #[test]
    fn test_recovery_codes() {
        let mut rc = RecoveryCodes::new(10).unwrap();
        assert_eq!(rc.remaining(), 10);

        let codes = rc.codes();
        assert!(rc.verify(&codes[0]));
        assert_eq!(rc.remaining(), 9);

        // Can't reuse
        assert!(!rc.verify(&codes[0]));
    }

    #[test]
    fn test_recovery_code_normalization_and_counts() {
        let mut rc = RecoveryCodes::new(3).unwrap();
        let codes = rc.codes();

        // Lowercase, space-separated input still matches the stored code.
        let sloppy = codes[0].to_lowercase().replace('-', " ");
        assert!(rc.verify(&sloppy));

        assert_eq!(rc.original_count(), 3);
        assert_eq!(rc.used_count(), 1);
        assert!(rc.has_codes());
        assert!(!rc.verify("not-a-code"));
    }

    #[test]
    fn test_provisioning_uri() {
        let totp = TOTP::with_secret(vec![1, 2, 3, 4, 5]);
        let uri = totp.provisioning_uri("user@example.com", "TestApp");
        assert!(uri.starts_with("otpauth://totp/"));
        assert!(uri.contains("TestApp"));
    }
}