1use hmac::{Hmac, Mac};
8use sha1::Sha1;
9use sha2::{Sha256, Sha512};
10use zeroize::Zeroizing;
11
12use crate::errors::{SafeError, SafeResult};
13
// Concrete HMAC types, one per hash that `TotpAlgorithm` can select; used by `hmac_digest`.
type HmacSha1 = Hmac<Sha1>;
type HmacSha256 = Hmac<Sha256>;
type HmacSha512 = Hmac<Sha512>;
17
18#[derive(Debug, Clone, Copy, PartialEq, Eq)]
19enum TotpAlgorithm {
20 Sha1,
21 Sha256,
22 Sha512,
23}
24
25impl TotpAlgorithm {
26 fn parse(value: &str) -> SafeResult<Self> {
27 match value.to_ascii_uppercase().as_str() {
28 "SHA1" | "SHA-1" => Ok(Self::Sha1),
29 "SHA256" | "SHA-256" => Ok(Self::Sha256),
30 "SHA512" | "SHA-512" => Ok(Self::Sha512),
31 other => Err(SafeError::InvalidVault {
32 reason: format!("unsupported TOTP algorithm '{other}'"),
33 }),
34 }
35 }
36
37 fn as_uri_str(self) -> &'static str {
38 match self {
39 Self::Sha1 => "SHA1",
40 Self::Sha256 => "SHA256",
41 Self::Sha512 => "SHA512",
42 }
43 }
44}
45
46#[derive(Debug)]
47struct TotpConfig {
48 secret: Zeroizing<String>,
49 algorithm: TotpAlgorithm,
50 digits: u32,
51 period: u64,
52}
53
54impl TotpConfig {
55 fn default_for_secret(secret: String) -> Self {
56 Self {
57 secret: Zeroizing::new(secret),
58 algorithm: TotpAlgorithm::Sha1,
59 digits: 6,
60 period: 30,
61 }
62 }
63}
64
65pub fn extract_base32(input: &str) -> SafeResult<Zeroizing<String>> {
69 Ok(parse_totp_config(input)?.secret)
70}
71
72pub fn provisioning_uri(
77 label: &str,
78 input: &str,
79 algorithm: Option<&str>,
80 digits: Option<u32>,
81 period: Option<u64>,
82) -> SafeResult<String> {
83 let mut config = parse_totp_config(input)?;
84 if let Some(algorithm) = algorithm {
85 config.algorithm = TotpAlgorithm::parse(algorithm)?;
86 }
87 if let Some(digits) = digits {
88 config.digits = validate_digits(digits)?;
89 }
90 if let Some(period) = period {
91 config.period = validate_period(period)?;
92 }
93
94 Ok(format!(
95 "otpauth://totp/{label}?secret={}&algorithm={}&digits={}&period={}",
96 config.secret.as_str(),
97 config.algorithm.as_uri_str(),
98 config.digits,
99 config.period
100 ))
101}
102
103pub fn generate_code(input: &str) -> SafeResult<String> {
109 generate_code_at(input, unix_timestamp())
110}
111
112pub fn generate_code_at(input: &str, timestamp: u64) -> SafeResult<String> {
116 let config = parse_totp_config(input)?;
117 let key_bytes = decode_base32(&config.secret)?;
118 let counter = timestamp / config.period;
119 let code = hotp(&key_bytes, counter, config.digits, config.algorithm)?;
120 Ok(format_code(code, config.digits))
121}
122
123pub fn seconds_remaining() -> u64 {
125 let ts = unix_timestamp();
126 30 - (ts % 30)
127}
128
129pub fn seconds_remaining_for(input: &str) -> SafeResult<u64> {
131 seconds_remaining_for_at(input, unix_timestamp())
132}
133
134pub fn seconds_remaining_for_at(input: &str, timestamp: u64) -> SafeResult<u64> {
136 let config = parse_totp_config(input)?;
137 Ok(config.period - (timestamp % config.period))
138}
139
/// Current Unix time in whole seconds.
///
/// Returns `0` if the system clock reads earlier than the epoch.
fn unix_timestamp() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
148
149fn decode_base32(s: &str) -> SafeResult<Vec<u8>> {
150 base32::decode(base32::Alphabet::Rfc4648 { padding: false }, s)
152 .or_else(|| base32::decode(base32::Alphabet::Rfc4648 { padding: true }, s))
153 .ok_or_else(|| SafeError::InvalidVault {
154 reason: "invalid TOTP base32 secret".into(),
155 })
156}
157
158fn parse_totp_config(input: &str) -> SafeResult<TotpConfig> {
159 if input.starts_with("otpauth://") {
160 parse_otpauth_uri(input)
161 } else {
162 let normalised = normalise_base32(input);
163 decode_base32(&normalised)?;
164 Ok(TotpConfig::default_for_secret(normalised))
165 }
166}
167
168fn parse_otpauth_uri(input: &str) -> SafeResult<TotpConfig> {
169 let query_start = input.find('?').ok_or_else(|| SafeError::InvalidVault {
170 reason: "otpauth:// URI has no query string".into(),
171 })?;
172 let query = &input[query_start + 1..];
173
174 let mut secret = None;
175 let mut algorithm = TotpAlgorithm::Sha1;
176 let mut digits = 6;
177 let mut period = 30;
178
179 for pair in query.split('&') {
180 let Some((key, value)) = pair.split_once('=') else {
181 continue;
182 };
183 let value = decode_query_component(value)?;
184 match key.to_ascii_lowercase().as_str() {
185 "secret" if secret.is_none() => {
186 let normalised = normalise_base32(&value);
187 decode_base32(&normalised)?;
188 secret = Some(Zeroizing::new(normalised));
189 }
190 "algorithm" => {
191 algorithm = TotpAlgorithm::parse(&value)?;
192 }
193 "digits" => {
194 digits = parse_digits(&value)?;
195 }
196 "period" => {
197 period = parse_period(&value)?;
198 }
199 _ => {}
200 }
201 }
202
203 let secret = secret.ok_or_else(|| SafeError::InvalidVault {
204 reason: "otpauth:// URI is missing the 'secret' parameter".into(),
205 })?;
206
207 Ok(TotpConfig {
208 secret,
209 algorithm,
210 digits,
211 period,
212 })
213}
214
/// Removes whitespace and `-` separators and upper-cases ASCII letters.
///
/// Non-ASCII characters other than whitespace are kept unchanged.
fn normalise_base32(raw: &str) -> String {
    let mut cleaned = String::with_capacity(raw.len());
    for ch in raw.chars() {
        if ch == '-' || ch.is_whitespace() {
            continue;
        }
        cleaned.push(ch.to_ascii_uppercase());
    }
    cleaned
}
221
222fn decode_query_component(input: &str) -> SafeResult<String> {
223 let bytes = input.as_bytes();
224 let mut out = Vec::with_capacity(bytes.len());
225 let mut i = 0;
226 while i < bytes.len() {
227 match bytes[i] {
228 b'%' if i + 2 < bytes.len() => {
229 let high = from_hex(bytes[i + 1])?;
230 let low = from_hex(bytes[i + 2])?;
231 out.push((high << 4) | low);
232 i += 3;
233 }
234 b'%' => {
235 return Err(SafeError::InvalidVault {
236 reason: "invalid percent encoding in otpauth:// URI".into(),
237 });
238 }
239 b'+' => {
240 out.push(b' ');
241 i += 1;
242 }
243 byte => {
244 out.push(byte);
245 i += 1;
246 }
247 }
248 }
249 String::from_utf8(out).map_err(|e| SafeError::InvalidVault {
250 reason: format!("otpauth:// URI parameter is not UTF-8: {e}"),
251 })
252}
253
254fn from_hex(byte: u8) -> SafeResult<u8> {
255 match byte {
256 b'0'..=b'9' => Ok(byte - b'0'),
257 b'a'..=b'f' => Ok(byte - b'a' + 10),
258 b'A'..=b'F' => Ok(byte - b'A' + 10),
259 _ => Err(SafeError::InvalidVault {
260 reason: "invalid percent encoding in otpauth:// URI".into(),
261 }),
262 }
263}
264
265fn parse_digits(value: &str) -> SafeResult<u32> {
266 let digits = value.parse::<u32>().map_err(|_| SafeError::InvalidVault {
267 reason: format!("invalid TOTP digits '{value}'"),
268 })?;
269 validate_digits(digits)
270}
271
272fn validate_digits(digits: u32) -> SafeResult<u32> {
273 if (1..=10).contains(&digits) {
274 Ok(digits)
275 } else {
276 Err(SafeError::InvalidVault {
277 reason: "TOTP digits must be between 1 and 10".into(),
278 })
279 }
280}
281
282fn parse_period(value: &str) -> SafeResult<u64> {
283 let period = value.parse::<u64>().map_err(|_| SafeError::InvalidVault {
284 reason: format!("invalid TOTP period '{value}'"),
285 })?;
286 validate_period(period)
287}
288
289fn validate_period(period: u64) -> SafeResult<u64> {
290 if period > 0 {
291 Ok(period)
292 } else {
293 Err(SafeError::InvalidVault {
294 reason: "TOTP period must be at least 1 second".into(),
295 })
296 }
297}
298
299fn hotp(key: &[u8], counter: u64, digits: u32, algorithm: TotpAlgorithm) -> SafeResult<u64> {
301 let counter_bytes = counter.to_be_bytes();
302 let result = hmac_digest(key, &counter_bytes, algorithm)?;
303
304 let offset = (result[result.len() - 1] & 0x0f) as usize;
305 let code = u32::from_be_bytes([
306 result[offset] & 0x7f,
307 result[offset + 1],
308 result[offset + 2],
309 result[offset + 3],
310 ]);
311
312 let modulus = 10u64.pow(digits);
313 Ok(u64::from(code) % modulus)
314}
315
316fn hmac_digest(key: &[u8], counter_bytes: &[u8], algorithm: TotpAlgorithm) -> SafeResult<Vec<u8>> {
317 match algorithm {
318 TotpAlgorithm::Sha1 => {
319 let mut mac = HmacSha1::new_from_slice(key).map_err(hmac_key_error)?;
320 mac.update(counter_bytes);
321 Ok(mac.finalize().into_bytes().to_vec())
322 }
323 TotpAlgorithm::Sha256 => {
324 let mut mac = HmacSha256::new_from_slice(key).map_err(hmac_key_error)?;
325 mac.update(counter_bytes);
326 Ok(mac.finalize().into_bytes().to_vec())
327 }
328 TotpAlgorithm::Sha512 => {
329 let mut mac = HmacSha512::new_from_slice(key).map_err(hmac_key_error)?;
330 mac.update(counter_bytes);
331 Ok(mac.finalize().into_bytes().to_vec())
332 }
333 }
334}
335
336fn hmac_key_error(e: hmac::digest::InvalidLength) -> SafeError {
337 SafeError::InvalidVault {
338 reason: format!("HMAC key error: {e}"),
339 }
340}
341
/// Renders `code` in decimal, left-padded with zeros to at least `digits` characters.
///
/// Longer values are never truncated.
fn format_code(code: u64, digits: u32) -> String {
    let width = digits as usize;
    format!("{code:0width$}")
}
345
#[cfg(test)]
mod tests {
    use super::*;

    /// Well-known 80-bit demo secret, already in normalised form.
    const KNOWN_B32: &str = "JBSWY3DPEHPK3PXP";

    #[test]
    fn extract_base32_plain_returns_normalised() {
        let result = extract_base32(KNOWN_B32).unwrap();
        assert_eq!(*result, KNOWN_B32);
    }

    #[test]
    fn extract_base32_lowercase_is_normalised_to_upper() {
        let result = extract_base32(&KNOWN_B32.to_lowercase()).unwrap();
        assert_eq!(*result, KNOWN_B32);
    }

    #[test]
    fn extract_base32_strips_spaces_and_hyphens() {
        // Authenticator apps often display secrets in space/hyphen-separated groups.
        let spaced = "JBSWY 3DP-EHPK 3PXP";
        let result = extract_base32(spaced).unwrap();
        assert_eq!(*result, KNOWN_B32);
    }

    #[test]
    fn extract_base32_parses_otpauth_uri() {
        let uri = format!("otpauth://totp/Alice?secret={KNOWN_B32}&issuer=Example");
        let result = extract_base32(&uri).unwrap();
        assert_eq!(*result, KNOWN_B32);
    }

    #[test]
    fn extract_base32_otpauth_uri_secret_case_insensitive_param_name() {
        let uri = format!("otpauth://totp/Alice?SECRET={KNOWN_B32}");
        let result = extract_base32(&uri).unwrap();
        assert_eq!(*result, KNOWN_B32);
    }

    #[test]
    fn extract_base32_otpauth_uri_missing_query_string_errors() {
        let result = extract_base32("otpauth://totp/Alice");
        assert!(matches!(result, Err(SafeError::InvalidVault { .. })));
    }

    #[test]
    fn extract_base32_otpauth_uri_missing_secret_param_errors() {
        let result = extract_base32("otpauth://totp/Alice?issuer=Example");
        assert!(matches!(result, Err(SafeError::InvalidVault { .. })));
    }

    #[test]
    fn extract_base32_invalid_base32_chars_errors() {
        let result = extract_base32("!!!NOT-VALID-BASE32!!!");
        assert!(matches!(result, Err(SafeError::InvalidVault { .. })));
    }

    #[test]
    fn provisioning_uri_preserves_otpauth_parameters() {
        let seed = encode_base32(b"12345678901234567890123456789012");
        let input =
            format!("otpauth://totp/Alice?secret={seed}&algorithm=SHA256&digits=8&period=60");

        let uri = provisioning_uri("GITHUB_2FA", &input, None, None, None).unwrap();

        assert_eq!(
            uri,
            format!("otpauth://totp/GITHUB_2FA?secret={seed}&algorithm=SHA256&digits=8&period=60")
        );
    }

    #[test]
    fn provisioning_uri_overrides_otpauth_parameters() {
        let seed = encode_base32(b"12345678901234567890123456789012");
        let input =
            format!("otpauth://totp/Alice?secret={seed}&algorithm=SHA256&digits=8&period=60");

        let uri = provisioning_uri("GITHUB_2FA", &input, Some("SHA1"), Some(6), Some(30)).unwrap();

        assert_eq!(
            uri,
            format!("otpauth://totp/GITHUB_2FA?secret={seed}&algorithm=SHA1&digits=6&period=30")
        );
    }

    #[test]
    fn generate_code_returns_six_digit_string() {
        let code = generate_code(KNOWN_B32).unwrap();
        assert_eq!(
            code.len(),
            6,
            "TOTP code must be exactly 6 chars, got {code:?}"
        );
        assert!(
            code.chars().all(|c| c.is_ascii_digit()),
            "TOTP code must be all digits, got {code:?}"
        );
    }

    #[test]
    fn generate_code_is_stable_within_same_30s_window() {
        // Pinned timestamps keep this deterministic: 30 and 59 both fall in
        // window 1. Two wall-clock calls could straddle a window boundary and
        // fail spuriously.
        let a = generate_code_at(KNOWN_B32, 30).unwrap();
        let b = generate_code_at(KNOWN_B32, 59).unwrap();
        assert_eq!(a, b, "codes differed within one 30s window");
    }

    #[test]
    fn generate_code_rejects_invalid_base32() {
        let result = generate_code("!!!INVALID!!!");
        assert!(matches!(result, Err(SafeError::InvalidVault { .. })));
    }

    #[test]
    fn generate_code_zero_pads_to_six_digits() {
        // RFC 6238 SHA-1 at T=1111111109 is 07081804 at 8 digits. The 6-digit
        // code therefore begins with a zero that must be preserved.
        let seed = encode_base32(b"12345678901234567890");
        let code = generate_code_at(&seed, 1_111_111_109).unwrap();
        assert_eq!(code, "081804");
    }

    #[test]
    fn generate_code_at_matches_rfc6238_vectors() {
        // RFC 6238 Appendix B: each hash uses a seed of the matching digest length.
        let seed_sha1 = encode_base32(b"12345678901234567890");
        let seed_sha256 = encode_base32(b"12345678901234567890123456789012");
        let seed_sha512 =
            encode_base32(b"1234567890123456789012345678901234567890123456789012345678901234");
        let vectors = [
            (59, "94287082", "46119246", "90693936"),
            (1_111_111_109, "07081804", "68084774", "25091201"),
            (1_111_111_111, "14050471", "67062674", "99943326"),
            (1_234_567_890, "89005924", "91819424", "93441116"),
            (2_000_000_000, "69279037", "90698825", "38618901"),
            (20_000_000_000, "65353130", "77737706", "47863826"),
        ];

        for (timestamp, sha1, sha256, sha512) in vectors {
            assert_eq!(
                generate_code_at(&format_uri(&seed_sha1, "SHA1", 8, 30), timestamp).unwrap(),
                sha1
            );
            assert_eq!(
                generate_code_at(&format_uri(&seed_sha256, "SHA256", 8, 30), timestamp).unwrap(),
                sha256
            );
            assert_eq!(
                generate_code_at(&format_uri(&seed_sha512, "SHA512", 8, 30), timestamp).unwrap(),
                sha512
            );
        }
    }

    #[test]
    fn generate_code_at_honors_digits_parameter() {
        let seed = encode_base32(b"12345678901234567890");

        // The same counter (1) rendered at two widths. RFC 6238 gives the
        // 8-digit value; RFC 4226 Appendix D gives the 6-digit one.
        assert_eq!(
            generate_code_at(&format_uri(&seed, "SHA1", 8, 30), 59).unwrap(),
            "94287082"
        );
        assert_eq!(
            generate_code_at(&format_uri(&seed, "SHA1", 6, 30), 59).unwrap(),
            "287082"
        );
    }

    #[test]
    fn generate_code_at_honors_period_parameter() {
        let seed = encode_base32(b"12345678901234567890");
        let uri = format_uri(&seed, "SHA1", 8, 60);

        assert_eq!(generate_code_at(&uri, 59).unwrap(), "84755224");
        assert_eq!(generate_code_at(&uri, 60).unwrap(), "94287082");
    }

    #[test]
    fn generate_code_at_parses_lowercase_otpauth_parameters() {
        let seed = encode_base32(b"12345678901234567890123456789012");
        let uri = format!("otpauth://totp/Alice?secret={seed}&algorithm=sha256&digits=8&period=30");

        assert_eq!(generate_code_at(&uri, 59).unwrap(), "46119246");
    }

    #[test]
    fn generate_code_at_rejects_invalid_otpauth_parameters() {
        let seed = encode_base32(b"12345678901234567890");

        for uri in [
            format!("otpauth://totp/Alice?secret={seed}&algorithm=MD5"),
            format!("otpauth://totp/Alice?secret={seed}&digits=0"),
            format!("otpauth://totp/Alice?secret={seed}&digits=11"),
            format!("otpauth://totp/Alice?secret={seed}&period=0"),
            format!("otpauth://totp/Alice?secret={seed}&period=abc"),
        ] {
            let result = generate_code_at(&uri, 59);
            assert!(
                matches!(result, Err(SafeError::InvalidVault { .. })),
                "expected invalid parameter error for {uri:?}, got {result:?}"
            );
        }
    }

    #[test]
    fn seconds_remaining_for_honors_period_parameter() {
        let seed = encode_base32(b"12345678901234567890");
        let uri = format_uri(&seed, "SHA1", 6, 60);

        assert_eq!(seconds_remaining_for_at(&uri, 59).unwrap(), 1);
        assert_eq!(seconds_remaining_for_at(&uri, 60).unwrap(), 60);
    }

    #[test]
    fn seconds_remaining_is_in_range_1_to_30() {
        let secs = seconds_remaining();
        assert!(
            (1..=30).contains(&secs),
            "seconds_remaining() returned {secs}, expected 1..=30"
        );
    }

    fn encode_base32(bytes: &[u8]) -> String {
        base32::encode(base32::Alphabet::Rfc4648 { padding: false }, bytes)
    }

    fn format_uri(secret: &str, algorithm: &str, digits: u32, period: u64) -> String {
        format!(
            "otpauth://totp/Alice?secret={secret}&algorithm={algorithm}&digits={digits}&period={period}"
        )
    }
}