Skip to main content

tsafe_core/
totp.rs

1//! TOTP (Time-based One-Time Password) — RFC 6238 code generation and secret management.
2//!
3//! Secrets are stored in the vault as base32-encoded TOTP seeds.  All intermediate
4//! strings holding the decoded secret use [`Zeroizing`] to ensure the raw bytes are
5//! wiped from heap memory when they go out of scope.
6
7use hmac::{Hmac, Mac};
8use sha1::Sha1;
9use sha2::{Sha256, Sha512};
10use zeroize::Zeroizing;
11
12use crate::errors::{SafeError, SafeResult};
13
14type HmacSha1 = Hmac<Sha1>;
15type HmacSha256 = Hmac<Sha256>;
16type HmacSha512 = Hmac<Sha512>;
17
18#[derive(Debug, Clone, Copy, PartialEq, Eq)]
19enum TotpAlgorithm {
20    Sha1,
21    Sha256,
22    Sha512,
23}
24
25impl TotpAlgorithm {
26    fn parse(value: &str) -> SafeResult<Self> {
27        match value.to_ascii_uppercase().as_str() {
28            "SHA1" | "SHA-1" => Ok(Self::Sha1),
29            "SHA256" | "SHA-256" => Ok(Self::Sha256),
30            "SHA512" | "SHA-512" => Ok(Self::Sha512),
31            other => Err(SafeError::InvalidVault {
32                reason: format!("unsupported TOTP algorithm '{other}'"),
33            }),
34        }
35    }
36
37    fn as_uri_str(self) -> &'static str {
38        match self {
39            Self::Sha1 => "SHA1",
40            Self::Sha256 => "SHA256",
41            Self::Sha512 => "SHA512",
42        }
43    }
44}
45
46#[derive(Debug)]
47struct TotpConfig {
48    secret: Zeroizing<String>,
49    algorithm: TotpAlgorithm,
50    digits: u32,
51    period: u64,
52}
53
54impl TotpConfig {
55    fn default_for_secret(secret: String) -> Self {
56        Self {
57            secret: Zeroizing::new(secret),
58            algorithm: TotpAlgorithm::Sha1,
59            digits: 6,
60            period: 30,
61        }
62    }
63}
64
65/// Parse base32 secret from either a raw base32 string or an otpauth:// URI.
66/// Returns bare base32 string (uppercase, no spaces), wrapped in `Zeroizing`
67/// so the secret is wiped from memory when the caller drops it.
68pub fn extract_base32(input: &str) -> SafeResult<Zeroizing<String>> {
69    Ok(parse_totp_config(input)?.secret)
70}
71
72/// Build a canonical `otpauth://` URI from raw base32 or pasted URI input.
73///
74/// Existing URI algorithm, digits, and period parameters are preserved unless
75/// the caller supplies an explicit override.
76pub fn provisioning_uri(
77    label: &str,
78    input: &str,
79    algorithm: Option<&str>,
80    digits: Option<u32>,
81    period: Option<u64>,
82) -> SafeResult<String> {
83    let mut config = parse_totp_config(input)?;
84    if let Some(algorithm) = algorithm {
85        config.algorithm = TotpAlgorithm::parse(algorithm)?;
86    }
87    if let Some(digits) = digits {
88        config.digits = validate_digits(digits)?;
89    }
90    if let Some(period) = period {
91        config.period = validate_period(period)?;
92    }
93
94    Ok(format!(
95        "otpauth://totp/{label}?secret={}&algorithm={}&digits={}&period={}",
96        config.secret.as_str(),
97        config.algorithm.as_uri_str(),
98        config.digits,
99        config.period
100    ))
101}
102
103/// Compute the current TOTP code for a raw base32 secret or `otpauth://` URI.
104///
105/// Raw base32 input uses the common SHA1 / 6 digits / 30 seconds profile.
106/// `otpauth://` input honors `algorithm`, `digits`, and `period` query
107/// parameters, defaulting omitted values to SHA1 / 6 / 30.
108pub fn generate_code(input: &str) -> SafeResult<String> {
109    generate_code_at(input, unix_timestamp())
110}
111
112/// Compute a TOTP code at a specific Unix timestamp.
113///
114/// This deterministic API is useful for tests and non-wall-clock consumers.
115pub fn generate_code_at(input: &str, timestamp: u64) -> SafeResult<String> {
116    let config = parse_totp_config(input)?;
117    let key_bytes = decode_base32(&config.secret)?;
118    let counter = timestamp / config.period;
119    let code = hotp(&key_bytes, counter, config.digits, config.algorithm)?;
120    Ok(format_code(code, config.digits))
121}
122
123/// Seconds remaining in the current 30-second TOTP window.
124pub fn seconds_remaining() -> u64 {
125    let ts = unix_timestamp();
126    30 - (ts % 30)
127}
128
129/// Seconds remaining in the current TOTP window for a raw base32 secret or URI.
130pub fn seconds_remaining_for(input: &str) -> SafeResult<u64> {
131    seconds_remaining_for_at(input, unix_timestamp())
132}
133
134/// Seconds remaining in a TOTP window for a specific Unix timestamp.
135pub fn seconds_remaining_for_at(input: &str, timestamp: u64) -> SafeResult<u64> {
136    let config = parse_totp_config(input)?;
137    Ok(config.period - (timestamp % config.period))
138}
139
140// ── internal ──────────────────────────────────────────────────────────────────
141
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0` instead of an error.
fn unix_timestamp() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
148
149fn decode_base32(s: &str) -> SafeResult<Vec<u8>> {
150    // Try without padding first, then with padding.
151    base32::decode(base32::Alphabet::Rfc4648 { padding: false }, s)
152        .or_else(|| base32::decode(base32::Alphabet::Rfc4648 { padding: true }, s))
153        .ok_or_else(|| SafeError::InvalidVault {
154            reason: "invalid TOTP base32 secret".into(),
155        })
156}
157
158fn parse_totp_config(input: &str) -> SafeResult<TotpConfig> {
159    if input.starts_with("otpauth://") {
160        parse_otpauth_uri(input)
161    } else {
162        let normalised = normalise_base32(input);
163        decode_base32(&normalised)?;
164        Ok(TotpConfig::default_for_secret(normalised))
165    }
166}
167
168fn parse_otpauth_uri(input: &str) -> SafeResult<TotpConfig> {
169    let query_start = input.find('?').ok_or_else(|| SafeError::InvalidVault {
170        reason: "otpauth:// URI has no query string".into(),
171    })?;
172    let query = &input[query_start + 1..];
173
174    let mut secret = None;
175    let mut algorithm = TotpAlgorithm::Sha1;
176    let mut digits = 6;
177    let mut period = 30;
178
179    for pair in query.split('&') {
180        let Some((key, value)) = pair.split_once('=') else {
181            continue;
182        };
183        let value = decode_query_component(value)?;
184        match key.to_ascii_lowercase().as_str() {
185            "secret" if secret.is_none() => {
186                let normalised = normalise_base32(&value);
187                decode_base32(&normalised)?;
188                secret = Some(Zeroizing::new(normalised));
189            }
190            "algorithm" => {
191                algorithm = TotpAlgorithm::parse(&value)?;
192            }
193            "digits" => {
194                digits = parse_digits(&value)?;
195            }
196            "period" => {
197                period = parse_period(&value)?;
198            }
199            _ => {}
200        }
201    }
202
203    let secret = secret.ok_or_else(|| SafeError::InvalidVault {
204        reason: "otpauth:// URI is missing the 'secret' parameter".into(),
205    })?;
206
207    Ok(TotpConfig {
208        secret,
209        algorithm,
210        digits,
211        period,
212    })
213}
214
/// Normalise a human-entered base32 secret by dropping whitespace and `-`
/// separators and upper-casing ASCII letters.
///
/// The output can never be longer than `raw`, because characters are only
/// removed or ASCII-case-mapped (which preserves byte length). The buffer is
/// therefore allocated once at `raw.len()` bytes. This matters for secret
/// hygiene: a growing `String` reallocates and leaves earlier, un-zeroized
/// copies of the secret's prefix behind in freed heap memory.
fn normalise_base32(raw: &str) -> String {
    let mut normalised = String::with_capacity(raw.len());
    normalised.extend(
        raw.chars()
            .filter(|c| !c.is_whitespace() && *c != '-')
            .map(|c| c.to_ascii_uppercase()),
    );
    normalised
}
221
222fn decode_query_component(input: &str) -> SafeResult<String> {
223    let bytes = input.as_bytes();
224    let mut out = Vec::with_capacity(bytes.len());
225    let mut i = 0;
226    while i < bytes.len() {
227        match bytes[i] {
228            b'%' if i + 2 < bytes.len() => {
229                let high = from_hex(bytes[i + 1])?;
230                let low = from_hex(bytes[i + 2])?;
231                out.push((high << 4) | low);
232                i += 3;
233            }
234            b'%' => {
235                return Err(SafeError::InvalidVault {
236                    reason: "invalid percent encoding in otpauth:// URI".into(),
237                });
238            }
239            b'+' => {
240                out.push(b' ');
241                i += 1;
242            }
243            byte => {
244                out.push(byte);
245                i += 1;
246            }
247        }
248    }
249    String::from_utf8(out).map_err(|e| SafeError::InvalidVault {
250        reason: format!("otpauth:// URI parameter is not UTF-8: {e}"),
251    })
252}
253
254fn from_hex(byte: u8) -> SafeResult<u8> {
255    match byte {
256        b'0'..=b'9' => Ok(byte - b'0'),
257        b'a'..=b'f' => Ok(byte - b'a' + 10),
258        b'A'..=b'F' => Ok(byte - b'A' + 10),
259        _ => Err(SafeError::InvalidVault {
260            reason: "invalid percent encoding in otpauth:// URI".into(),
261        }),
262    }
263}
264
265fn parse_digits(value: &str) -> SafeResult<u32> {
266    let digits = value.parse::<u32>().map_err(|_| SafeError::InvalidVault {
267        reason: format!("invalid TOTP digits '{value}'"),
268    })?;
269    validate_digits(digits)
270}
271
272fn validate_digits(digits: u32) -> SafeResult<u32> {
273    if (1..=10).contains(&digits) {
274        Ok(digits)
275    } else {
276        Err(SafeError::InvalidVault {
277            reason: "TOTP digits must be between 1 and 10".into(),
278        })
279    }
280}
281
282fn parse_period(value: &str) -> SafeResult<u64> {
283    let period = value.parse::<u64>().map_err(|_| SafeError::InvalidVault {
284        reason: format!("invalid TOTP period '{value}'"),
285    })?;
286    validate_period(period)
287}
288
289fn validate_period(period: u64) -> SafeResult<u64> {
290    if period > 0 {
291        Ok(period)
292    } else {
293        Err(SafeError::InvalidVault {
294            reason: "TOTP period must be at least 1 second".into(),
295        })
296    }
297}
298
299/// RFC 4226 HOTP: HMAC + dynamic truncation.
300fn hotp(key: &[u8], counter: u64, digits: u32, algorithm: TotpAlgorithm) -> SafeResult<u64> {
301    let counter_bytes = counter.to_be_bytes();
302    let result = hmac_digest(key, &counter_bytes, algorithm)?;
303
304    let offset = (result[result.len() - 1] & 0x0f) as usize;
305    let code = u32::from_be_bytes([
306        result[offset] & 0x7f,
307        result[offset + 1],
308        result[offset + 2],
309        result[offset + 3],
310    ]);
311
312    let modulus = 10u64.pow(digits);
313    Ok(u64::from(code) % modulus)
314}
315
316fn hmac_digest(key: &[u8], counter_bytes: &[u8], algorithm: TotpAlgorithm) -> SafeResult<Vec<u8>> {
317    match algorithm {
318        TotpAlgorithm::Sha1 => {
319            let mut mac = HmacSha1::new_from_slice(key).map_err(hmac_key_error)?;
320            mac.update(counter_bytes);
321            Ok(mac.finalize().into_bytes().to_vec())
322        }
323        TotpAlgorithm::Sha256 => {
324            let mut mac = HmacSha256::new_from_slice(key).map_err(hmac_key_error)?;
325            mac.update(counter_bytes);
326            Ok(mac.finalize().into_bytes().to_vec())
327        }
328        TotpAlgorithm::Sha512 => {
329            let mut mac = HmacSha512::new_from_slice(key).map_err(hmac_key_error)?;
330            mac.update(counter_bytes);
331            Ok(mac.finalize().into_bytes().to_vec())
332        }
333    }
334}
335
336fn hmac_key_error(e: hmac::digest::InvalidLength) -> SafeError {
337    SafeError::InvalidVault {
338        reason: format!("HMAC key error: {e}"),
339    }
340}
341
/// Render `code` in decimal, left-padded with zeros to `digits` characters.
///
/// A code that already has at least `digits` digits is printed unchanged.
fn format_code(code: u64, digits: u32) -> String {
    let width = digits as usize;
    format!("{code:0width$}")
}
345
#[cfg(test)]
mod tests {
    use super::*;

    /// Widely used example base32 secret; it decodes to `b"Hello!\xDE\xAD\xBE\xEF"`.
    ///
    /// This is not an RFC 6238 seed. The RFC Appendix B vectors further down use
    /// ASCII digit seeds instead.
    const KNOWN_B32: &str = "JBSWY3DPEHPK3PXP";

    // ── extract_base32 ────────────────────────────────────────────────────────

    #[test]
    fn extract_base32_plain_returns_normalised() {
        let result = extract_base32(KNOWN_B32).unwrap();
        assert_eq!(*result, KNOWN_B32);
    }

    #[test]
    fn extract_base32_lowercase_is_normalised_to_upper() {
        let result = extract_base32(&KNOWN_B32.to_lowercase()).unwrap();
        assert_eq!(*result, KNOWN_B32);
    }

    #[test]
    fn extract_base32_strips_spaces_and_hyphens() {
        // Authenticator apps often display secrets with spaces or hyphens for readability.
        let spaced = "JBSWY 3DP-EHPK 3PXP";
        let result = extract_base32(spaced).unwrap();
        assert_eq!(*result, KNOWN_B32);
    }

    #[test]
    fn extract_base32_parses_otpauth_uri() {
        let uri = format!("otpauth://totp/Alice?secret={KNOWN_B32}&issuer=Example");
        let result = extract_base32(&uri).unwrap();
        assert_eq!(*result, KNOWN_B32);
    }

    #[test]
    fn extract_base32_otpauth_uri_secret_case_insensitive_param_name() {
        let uri = format!("otpauth://totp/Alice?SECRET={KNOWN_B32}");
        let result = extract_base32(&uri).unwrap();
        assert_eq!(*result, KNOWN_B32);
    }

    #[test]
    fn extract_base32_otpauth_uri_missing_query_string_errors() {
        let result = extract_base32("otpauth://totp/Alice");
        assert!(matches!(result, Err(SafeError::InvalidVault { .. })));
    }

    #[test]
    fn extract_base32_otpauth_uri_missing_secret_param_errors() {
        let result = extract_base32("otpauth://totp/Alice?issuer=Example");
        assert!(matches!(result, Err(SafeError::InvalidVault { .. })));
    }

    #[test]
    fn extract_base32_invalid_base32_chars_errors() {
        let result = extract_base32("!!!NOT-VALID-BASE32!!!");
        assert!(matches!(result, Err(SafeError::InvalidVault { .. })));
    }

    #[test]
    fn provisioning_uri_preserves_otpauth_parameters() {
        let seed = encode_base32(b"12345678901234567890123456789012");
        let input =
            format!("otpauth://totp/Alice?secret={seed}&algorithm=SHA256&digits=8&period=60");

        let uri = provisioning_uri("GITHUB_2FA", &input, None, None, None).unwrap();

        assert_eq!(
            uri,
            format!("otpauth://totp/GITHUB_2FA?secret={seed}&algorithm=SHA256&digits=8&period=60")
        );
    }

    #[test]
    fn provisioning_uri_overrides_otpauth_parameters() {
        let seed = encode_base32(b"12345678901234567890123456789012");
        let input =
            format!("otpauth://totp/Alice?secret={seed}&algorithm=SHA256&digits=8&period=60");

        let uri = provisioning_uri("GITHUB_2FA", &input, Some("SHA1"), Some(6), Some(30)).unwrap();

        assert_eq!(
            uri,
            format!("otpauth://totp/GITHUB_2FA?secret={seed}&algorithm=SHA1&digits=6&period=30")
        );
    }

    // ── generate_code ─────────────────────────────────────────────────────────

    #[test]
    fn generate_code_returns_six_digit_string() {
        let code = generate_code(KNOWN_B32).unwrap();
        assert_eq!(
            code.len(),
            6,
            "TOTP code must be exactly 6 chars, got {code:?}"
        );
        assert!(
            code.chars().all(|c| c.is_ascii_digit()),
            "TOTP code must be all digits, got {code:?}"
        );
    }

    #[test]
    fn generate_code_is_stable_within_same_30s_window() {
        // 1_111_111_080 and 1_111_111_109 both fall in time step 37_037_036 of a
        // 30-second period, so they must yield the same code. Pinned timestamps
        // avoid racing a wall-clock window boundary.
        let start = generate_code_at(KNOWN_B32, 1_111_111_080).unwrap();
        let end = generate_code_at(KNOWN_B32, 1_111_111_109).unwrap();
        assert_eq!(start, end, "codes differed within one 30-second window");
    }

    #[test]
    fn generate_code_rejects_invalid_base32() {
        let result = generate_code("!!!INVALID!!!");
        assert!(matches!(result, Err(SafeError::InvalidVault { .. })));
    }

    #[test]
    fn generate_code_zero_pads_to_six_digits() {
        // The RFC 6238 SHA1 seed at T = 1111111109 gives the 8-digit code
        // 07081804, so the default 6-digit code is 081804. That value has a
        // leading zero, which formatting must keep.
        let seed = encode_base32(b"12345678901234567890");
        let code = generate_code_at(&seed, 1_111_111_109).unwrap();
        assert_eq!(code, "081804");
    }

    #[test]
    fn generate_code_at_matches_rfc6238_vectors() {
        // RFC 6238 Appendix B uses these ASCII seeds:
        // SHA1:   "12345678901234567890"
        // SHA256: "12345678901234567890123456789012"
        // SHA512: "1234567890123456789012345678901234567890123456789012345678901234"
        let seed_sha1 = encode_base32(b"12345678901234567890");
        let seed_sha256 = encode_base32(b"12345678901234567890123456789012");
        let seed_sha512 =
            encode_base32(b"1234567890123456789012345678901234567890123456789012345678901234");
        let vectors = [
            (59, "94287082", "46119246", "90693936"),
            (1_111_111_109, "07081804", "68084774", "25091201"),
            (1_111_111_111, "14050471", "67062674", "99943326"),
            (1_234_567_890, "89005924", "91819424", "93441116"),
            (2_000_000_000, "69279037", "90698825", "38618901"),
            (20_000_000_000, "65353130", "77737706", "47863826"),
        ];

        for (timestamp, sha1, sha256, sha512) in vectors {
            assert_eq!(
                generate_code_at(&format_uri(&seed_sha1, "SHA1", 8, 30), timestamp).unwrap(),
                sha1
            );
            assert_eq!(
                generate_code_at(&format_uri(&seed_sha256, "SHA256", 8, 30), timestamp).unwrap(),
                sha256
            );
            assert_eq!(
                generate_code_at(&format_uri(&seed_sha512, "SHA512", 8, 30), timestamp).unwrap(),
                sha512
            );
        }
    }

    #[test]
    fn generate_code_at_honors_digits_parameter() {
        let seed = encode_base32(b"12345678901234567890");
        let uri = format_uri(&seed, "SHA1", 8, 30);

        assert_eq!(generate_code_at(&uri, 59).unwrap(), "94287082");
    }

    #[test]
    fn generate_code_at_honors_period_parameter() {
        let seed = encode_base32(b"12345678901234567890");
        let uri = format_uri(&seed, "SHA1", 8, 60);

        assert_eq!(generate_code_at(&uri, 59).unwrap(), "84755224");
        assert_eq!(generate_code_at(&uri, 60).unwrap(), "94287082");
    }

    #[test]
    fn generate_code_at_parses_lowercase_otpauth_parameters() {
        let seed = encode_base32(b"12345678901234567890123456789012");
        let uri = format!("otpauth://totp/Alice?secret={seed}&algorithm=sha256&digits=8&period=30");

        assert_eq!(generate_code_at(&uri, 59).unwrap(), "46119246");
    }

    #[test]
    fn generate_code_at_rejects_invalid_otpauth_parameters() {
        let seed = encode_base32(b"12345678901234567890");

        for uri in [
            format!("otpauth://totp/Alice?secret={seed}&algorithm=MD5"),
            format!("otpauth://totp/Alice?secret={seed}&digits=0"),
            format!("otpauth://totp/Alice?secret={seed}&digits=11"),
            format!("otpauth://totp/Alice?secret={seed}&period=0"),
            format!("otpauth://totp/Alice?secret={seed}&period=abc"),
        ] {
            let result = generate_code_at(&uri, 59);
            assert!(
                matches!(result, Err(SafeError::InvalidVault { .. })),
                "expected invalid parameter error for {uri:?}, got {result:?}"
            );
        }
    }

    #[test]
    fn seconds_remaining_for_honors_period_parameter() {
        let seed = encode_base32(b"12345678901234567890");
        let uri = format_uri(&seed, "SHA1", 6, 60);

        assert_eq!(seconds_remaining_for_at(&uri, 59).unwrap(), 1);
        assert_eq!(seconds_remaining_for_at(&uri, 60).unwrap(), 60);
    }

    // ── seconds_remaining ─────────────────────────────────────────────────────

    #[test]
    fn seconds_remaining_is_in_range_1_to_30() {
        let secs = seconds_remaining();
        assert!(
            (1..=30).contains(&secs),
            "seconds_remaining() returned {secs}, expected 1..=30"
        );
    }

    fn encode_base32(bytes: &[u8]) -> String {
        base32::encode(base32::Alphabet::Rfc4648 { padding: false }, bytes)
    }

    fn format_uri(secret: &str, algorithm: &str, digits: u32, period: u64) -> String {
        format!(
            "otpauth://totp/{}?secret={secret}&algorithm={algorithm}&digits={digits}&period={period}",
            "Alice"
        )
    }
}