Skip to main content

tsafe_core/
totp.rs

1//! TOTP (Time-based One-Time Password) — RFC 6238 code generation and secret management.
2//!
3//! Secrets are stored in the vault as base32-encoded TOTP seeds.  All intermediate
4//! strings holding the decoded secret use [`Zeroizing`] to ensure the raw bytes are
5//! wiped from heap memory when they go out of scope.
6
7use hmac::{Hmac, Mac};
8use sha1::Sha1;
9use zeroize::Zeroizing;
10
11use crate::errors::{SafeError, SafeResult};
12
13type HmacSha1 = Hmac<Sha1>;
14
15/// Parse base32 secret from either a raw base32 string or an otpauth:// URI.
16/// Returns bare base32 string (uppercase, no spaces), wrapped in `Zeroizing`
17/// so the secret is wiped from memory when the caller drops it.
18pub fn extract_base32(input: &str) -> SafeResult<Zeroizing<String>> {
19    let raw = if input.starts_with("otpauth://") {
20        // Parse the `secret` query parameter from the URI.
21        // Format: otpauth://totp/<label>?secret=BASE32&...
22        let query_start = input.find('?').ok_or_else(|| SafeError::InvalidVault {
23            reason: "otpauth:// URI has no query string".into(),
24        })?;
25        let query = &input[query_start + 1..];
26        let secret = query
27            .split('&')
28            .find_map(|pair| {
29                let (k, v) = pair.split_once('=')?;
30                if k.eq_ignore_ascii_case("secret") {
31                    Some(v)
32                } else {
33                    None
34                }
35            })
36            .ok_or_else(|| SafeError::InvalidVault {
37                reason: "otpauth:// URI is missing the 'secret' parameter".into(),
38            })?;
39        secret.to_string()
40    } else {
41        input.to_string()
42    };
43
44    // Normalise: uppercase, strip spaces and hyphens.
45    let normalised: String = raw
46        .chars()
47        .filter(|c| !c.is_whitespace() && *c != '-')
48        .map(|c| c.to_ascii_uppercase())
49        .collect();
50
51    // Validate by attempting a decode.
52    decode_base32(&normalised)?;
53    Ok(Zeroizing::new(normalised))
54}
55
56/// Compute the current 6-digit TOTP code for the given base32 secret.
57/// Uses the current Unix time, 30-second window, SHA1, 6 digits.
58pub fn generate_code(base32_secret: &str) -> SafeResult<String> {
59    let key_bytes = decode_base32(base32_secret)?;
60    let counter = current_counter();
61    let code = hotp(&key_bytes, counter, 6)?;
62    Ok(format!("{code:0>6}"))
63}
64
65/// Seconds remaining in the current 30-second TOTP window.
66pub fn seconds_remaining() -> u64 {
67    let ts = unix_timestamp();
68    30 - (ts % 30)
69}
70
71// ── internal ──────────────────────────────────────────────────────────────────
72
/// Whole seconds elapsed since the Unix epoch, according to the system clock.
///
/// If the clock reports a time before the epoch, this returns `0` instead of
/// failing.
fn unix_timestamp() -> u64 {
    match std::time::UNIX_EPOCH.elapsed() {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
79
80fn current_counter() -> u64 {
81    unix_timestamp() / 30
82}
83
84fn decode_base32(s: &str) -> SafeResult<Vec<u8>> {
85    // Try without padding first, then with padding.
86    base32::decode(base32::Alphabet::Rfc4648 { padding: false }, s)
87        .or_else(|| base32::decode(base32::Alphabet::Rfc4648 { padding: true }, s))
88        .ok_or_else(|| SafeError::InvalidVault {
89            reason: "invalid TOTP base32 secret".into(),
90        })
91}
92
93/// RFC 4226 HOTP: HMAC-SHA1 + dynamic truncation.
94fn hotp(key: &[u8], counter: u64, digits: u32) -> SafeResult<u32> {
95    let counter_bytes = counter.to_be_bytes();
96
97    let mut mac = HmacSha1::new_from_slice(key).map_err(|e| SafeError::InvalidVault {
98        reason: format!("HMAC key error: {e}"),
99    })?;
100    mac.update(&counter_bytes);
101    let result = mac.finalize().into_bytes();
102    let result = result.as_slice();
103
104    // Dynamic truncation.
105    let offset = (result[19] & 0x0f) as usize;
106    let code = u32::from_be_bytes([
107        result[offset] & 0x7f,
108        result[offset + 1],
109        result[offset + 2],
110        result[offset + 3],
111    ]);
112
113    let modulus = 10u32.pow(digits);
114    Ok(code % modulus)
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    /// Widely used example secret: the base32 encoding of
    /// `b"Hello!\xDE\xAD\xBE\xEF"`. It is NOT an RFC test vector; see
    /// `RFC_SECRET` for those.
    const KNOWN_B32: &str = "JBSWY3DPEHPK3PXP";

    /// Shared secret used by the RFC 4226 (Appendix D) and RFC 6238
    /// (Appendix B, SHA-1) test vectors.
    const RFC_SECRET: &[u8] = b"12345678901234567890";

    /// `RFC_SECRET` encoded as RFC 4648 base32.
    const RFC_SECRET_B32: &str = "GEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQ";

    // ── extract_base32 ────────────────────────────────────────────────────────

    #[test]
    fn extract_base32_plain_returns_normalised() {
        let result = extract_base32(KNOWN_B32).unwrap();
        assert_eq!(*result, KNOWN_B32);
    }

    #[test]
    fn extract_base32_lowercase_is_normalised_to_upper() {
        let result = extract_base32(&KNOWN_B32.to_lowercase()).unwrap();
        assert_eq!(*result, KNOWN_B32);
    }

    #[test]
    fn extract_base32_strips_spaces_and_hyphens() {
        // Authenticator apps often display secrets with spaces or hyphens for readability.
        let spaced = "JBSWY 3DP-EHPK 3PXP";
        let result = extract_base32(spaced).unwrap();
        assert_eq!(*result, KNOWN_B32);
    }

    #[test]
    fn extract_base32_parses_otpauth_uri() {
        let uri = format!("otpauth://totp/Alice?secret={KNOWN_B32}&issuer=Example");
        let result = extract_base32(&uri).unwrap();
        assert_eq!(*result, KNOWN_B32);
    }

    #[test]
    fn extract_base32_otpauth_uri_secret_case_insensitive_param_name() {
        let uri = format!("otpauth://totp/Alice?SECRET={KNOWN_B32}");
        let result = extract_base32(&uri).unwrap();
        assert_eq!(*result, KNOWN_B32);
    }

    #[test]
    fn extract_base32_otpauth_uri_missing_query_string_errors() {
        let result = extract_base32("otpauth://totp/Alice");
        assert!(matches!(result, Err(SafeError::InvalidVault { .. })));
    }

    #[test]
    fn extract_base32_otpauth_uri_missing_secret_param_errors() {
        let result = extract_base32("otpauth://totp/Alice?issuer=Example");
        assert!(matches!(result, Err(SafeError::InvalidVault { .. })));
    }

    #[test]
    fn extract_base32_invalid_base32_chars_errors() {
        let result = extract_base32("!!!NOT-VALID-BASE32!!!");
        assert!(matches!(result, Err(SafeError::InvalidVault { .. })));
    }

    // ── decode_base32 ─────────────────────────────────────────────────────────

    #[test]
    fn decode_base32_rfc_secret_round_trips() {
        let key = decode_base32(RFC_SECRET_B32).unwrap();
        assert_eq!(key.as_slice(), RFC_SECRET);
    }

    // ── hotp ──────────────────────────────────────────────────────────────────

    #[test]
    fn hotp_matches_rfc4226_appendix_d() {
        // Expected 6-digit HOTP values for counters 0..=9.
        let expected: [u32; 10] = [
            755_224, 287_082, 359_152, 969_429, 338_314, 254_676, 287_922, 162_583, 399_871,
            520_489,
        ];
        for (counter, want) in (0u64..).zip(expected.iter().copied()) {
            assert_eq!(
                hotp(RFC_SECRET, counter, 6).unwrap(),
                want,
                "counter {counter}"
            );
        }
    }

    #[test]
    fn hotp_matches_rfc6238_sha1_vectors() {
        // (Unix time, expected 8-digit code) from RFC 6238 Appendix B, SHA-1 rows.
        let vectors: [(u64, u32); 6] = [
            (59, 94_287_082),
            (1_111_111_109, 7_081_804),
            (1_111_111_111, 14_050_471),
            (1_234_567_890, 89_005_924),
            (2_000_000_000, 69_279_037),
            (20_000_000_000, 65_353_130),
        ];
        for &(time, want) in &vectors {
            assert_eq!(hotp(RFC_SECRET, time / 30, 8).unwrap(), want, "time {time}");
        }
    }

    // ── generate_code ─────────────────────────────────────────────────────────

    #[test]
    fn generate_code_returns_six_digit_string() {
        let code = generate_code(KNOWN_B32).unwrap();
        assert_eq!(
            code.len(),
            6,
            "TOTP code must be exactly 6 chars, got {code:?}"
        );
        assert!(
            code.chars().all(|c| c.is_ascii_digit()),
            "TOTP code must be all digits, got {code:?}"
        );
    }

    #[test]
    fn generate_code_is_stable_within_same_30s_window() {
        // Two calls within the same 30-second window must return identical codes.
        // Tiny probability of racing across a window boundary — acceptable for CI.
        let a = generate_code(KNOWN_B32).unwrap();
        let b = generate_code(KNOWN_B32).unwrap();
        assert_eq!(a, b, "codes differed between two rapid calls");
    }

    #[test]
    fn generate_code_rejects_invalid_base32() {
        let result = generate_code("!!!INVALID!!!");
        assert!(matches!(result, Err(SafeError::InvalidVault { .. })));
    }

    #[test]
    fn six_digit_code_keeps_leading_zero() {
        // generate_code reads the clock, so its exact output cannot be pinned here.
        // Instead use the RFC 6238 vector at T=1111111109 (8-digit 07081804). At 6
        // digits it truncates to 081804, a value with a leading zero, and check it
        // through the same zero-padded formatting generate_code applies.
        let code = hotp(RFC_SECRET, 1_111_111_109 / 30, 6).unwrap();
        assert_eq!(code, 81_804);
        assert_eq!(format!("{code:0>6}"), "081804");
    }

    // ── seconds_remaining ─────────────────────────────────────────────────────

    #[test]
    fn seconds_remaining_is_in_range_1_to_30() {
        let secs = seconds_remaining();
        assert!(
            (1..=30).contains(&secs),
            "seconds_remaining() returned {secs}, expected 1..=30"
        );
    }
}
228            (1..=30).contains(&secs),
229            "seconds_remaining() returned {secs}, expected 1..=30"
230        );
231    }
232}