1use hmac::{Hmac, Mac};
8use sha1::Sha1;
9use zeroize::Zeroizing;
10
11use crate::errors::{SafeError, SafeResult};
12
/// HMAC keyed with SHA-1; `hotp` uses it to authenticate the big-endian counter.
type HmacSha1 = Hmac<Sha1>;
14
15pub fn extract_base32(input: &str) -> SafeResult<Zeroizing<String>> {
19 let raw = if input.starts_with("otpauth://") {
20 let query_start = input.find('?').ok_or_else(|| SafeError::InvalidVault {
23 reason: "otpauth:// URI has no query string".into(),
24 })?;
25 let query = &input[query_start + 1..];
26 let secret = query
27 .split('&')
28 .find_map(|pair| {
29 let (k, v) = pair.split_once('=')?;
30 if k.eq_ignore_ascii_case("secret") {
31 Some(v)
32 } else {
33 None
34 }
35 })
36 .ok_or_else(|| SafeError::InvalidVault {
37 reason: "otpauth:// URI is missing the 'secret' parameter".into(),
38 })?;
39 secret.to_string()
40 } else {
41 input.to_string()
42 };
43
44 let normalised: String = raw
46 .chars()
47 .filter(|c| !c.is_whitespace() && *c != '-')
48 .map(|c| c.to_ascii_uppercase())
49 .collect();
50
51 decode_base32(&normalised)?;
53 Ok(Zeroizing::new(normalised))
54}
55
56pub fn generate_code(base32_secret: &str) -> SafeResult<String> {
59 let key_bytes = decode_base32(base32_secret)?;
60 let counter = current_counter();
61 let code = hotp(&key_bytes, counter, 6)?;
62 Ok(format!("{code:0>6}"))
63}
64
65pub fn seconds_remaining() -> u64 {
67 let ts = unix_timestamp();
68 30 - (ts % 30)
69}
70
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, this falls back to 0
/// instead of failing.
fn unix_timestamp() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        // Clock is set before 1970-01-01: treat it as the epoch itself.
        Err(_) => 0,
    }
}
79
80fn current_counter() -> u64 {
81 unix_timestamp() / 30
82}
83
84fn decode_base32(s: &str) -> SafeResult<Vec<u8>> {
85 base32::decode(base32::Alphabet::Rfc4648 { padding: false }, s)
87 .or_else(|| base32::decode(base32::Alphabet::Rfc4648 { padding: true }, s))
88 .ok_or_else(|| SafeError::InvalidVault {
89 reason: "invalid TOTP base32 secret".into(),
90 })
91}
92
93fn hotp(key: &[u8], counter: u64, digits: u32) -> SafeResult<u32> {
95 let counter_bytes = counter.to_be_bytes();
96
97 let mut mac = HmacSha1::new_from_slice(key).map_err(|e| SafeError::InvalidVault {
98 reason: format!("HMAC key error: {e}"),
99 })?;
100 mac.update(&counter_bytes);
101 let result = mac.finalize().into_bytes();
102 let result = result.as_slice();
103
104 let offset = (result[19] & 0x0f) as usize;
106 let code = u32::from_be_bytes([
107 result[offset] & 0x7f,
108 result[offset + 1],
109 result[offset + 2],
110 result[offset + 3],
111 ]);
112
113 let modulus = 10u32.pow(digits);
114 Ok(code % modulus)
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    /// Well-formed base32 secret used throughout the tests.
    const KNOWN_B32: &str = "JBSWY3DPEHPK3PXP";

    /// Shared secret from the RFC 4226 Appendix D test vectors.
    const RFC4226_KEY: &[u8] = b"12345678901234567890";

    /// Expected 6-digit HOTP values for counters 0..=9 (RFC 4226 Appendix D).
    const RFC4226_CODES: [u32; 10] = [
        755_224, 287_082, 359_152, 969_429, 338_314, 254_676, 287_922, 162_583, 399_871, 520_489,
    ];

    #[test]
    fn extract_base32_plain_returns_normalised() {
        let result = extract_base32(KNOWN_B32).unwrap();
        assert_eq!(*result, KNOWN_B32);
    }

    #[test]
    fn extract_base32_lowercase_is_normalised_to_upper() {
        let result = extract_base32(&KNOWN_B32.to_lowercase()).unwrap();
        assert_eq!(*result, KNOWN_B32);
    }

    #[test]
    fn extract_base32_strips_spaces_and_hyphens() {
        let spaced = "JBSWY 3DP-EHPK 3PXP";
        let result = extract_base32(spaced).unwrap();
        assert_eq!(*result, KNOWN_B32);
    }

    #[test]
    fn extract_base32_parses_otpauth_uri() {
        let uri = format!("otpauth://totp/Alice?secret={KNOWN_B32}&issuer=Example");
        let result = extract_base32(&uri).unwrap();
        assert_eq!(*result, KNOWN_B32);
    }

    #[test]
    fn extract_base32_otpauth_uri_secret_case_insensitive_param_name() {
        let uri = format!("otpauth://totp/Alice?SECRET={KNOWN_B32}");
        let result = extract_base32(&uri).unwrap();
        assert_eq!(*result, KNOWN_B32);
    }

    #[test]
    fn extract_base32_otpauth_uri_missing_query_string_errors() {
        let result = extract_base32("otpauth://totp/Alice");
        assert!(matches!(result, Err(SafeError::InvalidVault { .. })));
    }

    #[test]
    fn extract_base32_otpauth_uri_missing_secret_param_errors() {
        let result = extract_base32("otpauth://totp/Alice?issuer=Example");
        assert!(matches!(result, Err(SafeError::InvalidVault { .. })));
    }

    #[test]
    fn extract_base32_invalid_base32_chars_errors() {
        let result = extract_base32("!!!NOT-VALID-BASE32!!!");
        assert!(matches!(result, Err(SafeError::InvalidVault { .. })));
    }

    #[test]
    fn hotp_matches_rfc4226_test_vectors() {
        // Deterministic known-answer test: does not depend on the wall clock.
        for (counter, expected) in (0u64..).zip(RFC4226_CODES.iter()) {
            let code = hotp(RFC4226_KEY, counter, 6).unwrap();
            assert_eq!(code, *expected, "wrong HOTP value for counter {counter}");
        }
    }

    #[test]
    fn generate_code_returns_six_digit_string() {
        let code = generate_code(KNOWN_B32).unwrap();
        assert_eq!(
            code.len(),
            6,
            "TOTP code must be exactly 6 chars, got {code:?}"
        );
        assert!(
            code.chars().all(|c| c.is_ascii_digit()),
            "TOTP code must be all digits, got {code:?}"
        );
    }

    #[test]
    fn generate_code_is_stable_within_same_30s_window() {
        // Two calls can straddle a 30-second boundary, in which case the codes
        // legitimately differ. Only compare them when the time-step counter was
        // the same before and after both calls; otherwise try again.
        for _ in 0..3 {
            let before = current_counter();
            let a = generate_code(KNOWN_B32).unwrap();
            let b = generate_code(KNOWN_B32).unwrap();
            if current_counter() == before {
                assert_eq!(a, b, "codes differed between two rapid calls");
                return;
            }
        }
        panic!("clock crossed a 30-second boundary on every attempt");
    }

    #[test]
    fn generate_code_rejects_invalid_base32() {
        let result = generate_code("!!!INVALID!!!");
        assert!(matches!(result, Err(SafeError::InvalidVault { .. })));
    }

    #[test]
    fn generate_code_zero_pads_to_six_digits() {
        for _ in 0..3 {
            let code = generate_code(KNOWN_B32).unwrap();
            // The width must be 6 even when the numeric value is below 100_000.
            assert_eq!(code.len(), 6, "code {code:?} is not zero-padded to 6");
            let n: u32 = code.parse().expect("should parse as integer");
            assert!(n < 1_000_000, "code {n} must be < 1_000_000");
        }
    }

    #[test]
    fn seconds_remaining_is_in_range_1_to_30() {
        let secs = seconds_remaining();
        assert!(
            (1..=30).contains(&secs),
            "seconds_remaining() returned {secs}, expected 1..=30"
        );
    }
}