1use crate::{
4 base64::base64_encode,
5 error::{CryptoError, CryptoResult},
6 hmac::{HmacAlgorithm, hmac_sign},
7};
8use rand::Rng;
9
/// Hash function used inside the HOTP/TOTP HMAC.
///
/// Defaults to [`TotpAlgorithm::SHA1`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum TotpAlgorithm {
    /// HMAC-SHA-1.
    #[default]
    SHA1,
    /// HMAC-SHA-256.
    SHA256,
    /// HMAC-SHA-512.
    SHA512,
}
26
/// RFC 4648 Base32 alphabet; a 5-bit value `v` maps to `BASE32_CHARS[v]`.
const BASE32_CHARS: &[u8] = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567".as_bytes();
29
/// A shared OTP secret, kept both as raw key bytes and as Base32 text.
///
/// `Debug` is implemented by hand so that `{:?}` output (for example in logs,
/// or in the derived `Debug` of a struct that contains a `TotpSecret`) never
/// reveals key material; only the key length is shown. Use the `Display` impl
/// or `as_base32` when the encoded secret is genuinely needed.
#[derive(Clone)]
pub struct TotpSecret {
    /// Raw key bytes.
    bytes: Vec<u8>,
    /// Base32 text corresponding to `bytes`.
    base32: String,
}

impl std::fmt::Debug for TotpSecret {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TotpSecret")
            .field("len", &self.bytes.len())
            .finish_non_exhaustive()
    }
}
36
37impl TotpSecret {
38 pub fn generate(length: usize) -> CryptoResult<Self> {
40 let mut bytes = vec![0u8; length];
41 rand::rng().fill_bytes(&mut bytes);
42 Self::from_bytes(&bytes)
43 }
44
45 pub fn generate_default() -> CryptoResult<Self> {
47 Self::generate(20)
48 }
49
50 pub fn from_bytes(bytes: &[u8]) -> CryptoResult<Self> {
52 let base32 = Self::encode_base32(bytes);
53 Ok(Self { bytes: bytes.to_vec(), base32 })
54 }
55
56 pub fn from_base32(s: &str) -> CryptoResult<Self> {
58 let bytes = Self::decode_base32(s)?;
59 Ok(Self { bytes, base32: s.to_uppercase().replace(" ", "") })
60 }
61
62 pub fn as_bytes(&self) -> &[u8] {
64 &self.bytes
65 }
66
67 pub fn as_base32(&self) -> &str {
69 &self.base32
70 }
71
72 pub fn len(&self) -> usize {
74 self.bytes.len()
75 }
76
77 pub fn is_empty(&self) -> bool {
79 self.bytes.is_empty()
80 }
81
82 fn encode_base32(data: &[u8]) -> String {
84 let mut result = String::new();
85 let mut i = 0;
86 let n = data.len();
87
88 while i < n {
89 let mut word: u64 = 0;
90 let mut bits = 0;
91
92 for j in 0..5 {
93 if i + j < n {
94 word = (word << 8) | (data[i + j] as u64);
95 bits += 8;
96 }
97 }
98
99 i += 5;
100
101 while bits >= 5 {
102 bits -= 5;
103 let index = ((word >> bits) & 0x1F) as usize;
104 result.push(BASE32_CHARS[index] as char);
105 }
106
107 if bits > 0 {
108 let index = ((word << (5 - bits)) & 0x1F) as usize;
109 result.push(BASE32_CHARS[index] as char);
110 }
111 }
112
113 result
114 }
115
116 fn decode_base32(s: &str) -> CryptoResult<Vec<u8>> {
118 let s = s.to_uppercase().replace(" ", "").replace("-", "");
119 let mut result = Vec::new();
120 let chars: Vec<char> = s.chars().collect();
121
122 let mut i = 0;
123 while i < chars.len() {
124 let mut word: u64 = 0;
125 let mut bits = 0;
126
127 for j in 0..8 {
128 if i + j < chars.len() {
129 let val = Self::base32_char_to_value(chars[i + j])?;
130 word = (word << 5) | (val as u64);
131 bits += 5;
132 }
133 }
134
135 i += 8;
136
137 while bits >= 8 {
138 bits -= 8;
139 result.push(((word >> bits) & 0xFF) as u8);
140 }
141 }
142
143 Ok(result)
144 }
145
146 fn base32_char_to_value(c: char) -> CryptoResult<u8> {
148 match c {
149 'A'..='Z' => Ok((c as u8) - b'A'),
150 '2'..='7' => Ok((c as u8) - b'2' + 26),
151 _ => Err(CryptoError::Base32Error(format!("Invalid character: {}", c))),
152 }
153 }
154}
155
156impl std::fmt::Display for TotpSecret {
157 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158 write!(f, "{}", self.base32)
159 }
160}
161
/// Output encodings supported by `TotpSecret::format`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SecretFormat {
    /// The secret's stored Base32 text.
    Base32,
    /// The same Base32 text, split into space-separated groups of four characters.
    Base32Spaced,
    /// Lowercase hexadecimal, two characters per key byte.
    Raw,
    /// The raw key bytes passed through `base64_encode`.
    Base64,
}
174
175impl TotpSecret {
176 pub fn format(&self, format: SecretFormat) -> String {
178 match format {
179 SecretFormat::Base32 => self.base32.clone(),
180 SecretFormat::Base32Spaced => self
181 .base32
182 .as_bytes()
183 .chunks(4)
184 .map(|chunk| std::str::from_utf8(chunk).unwrap_or(""))
185 .collect::<Vec<_>>()
186 .join(" "),
187 SecretFormat::Raw => self.bytes.iter().map(|b| format!("{:02x}", b)).collect(),
188 SecretFormat::Base64 => base64_encode(&self.bytes),
189 }
190 }
191}
192
/// Applies RFC 4226 dynamic truncation to an HMAC digest, then reduces the
/// result modulo `10^digits`.
///
/// The low nibble of the last digest byte selects a 4-byte window. That
/// window is read big-endian with its top bit cleared, giving a 31-bit value.
/// When `10^digits` exceeds every 31-bit value (`digits >= 10`), the value is
/// returned unreduced instead of overflowing `u32`.
///
/// # Panics
/// Panics if `hmac_result` is empty, or shorter than `offset + 4` bytes for
/// the selected offset. A digest of at least 19 bytes always suffices.
fn dynamic_truncate(hmac_result: &[u8], digits: u32) -> u32 {
    let last = *hmac_result
        .last()
        .expect("HMAC digest must not be empty");
    let offset = usize::from(last & 0x0F);
    let binary = u32::from_be_bytes([
        hmac_result[offset] & 0x7F,
        hmac_result[offset + 1],
        hmac_result[offset + 2],
        hmac_result[offset + 3],
    ]);

    // 10^digits no longer fits in u32 at digits = 10; reduce in u64 and treat an
    // unrepresentable modulus as larger than any 31-bit value.
    match 10u64.checked_pow(digits) {
        Some(modulus) => (u64::from(binary) % modulus) as u32,
        None => binary,
    }
}
204
205fn compute_hmac(algorithm: TotpAlgorithm, key: &[u8], counter: &[u8]) -> CryptoResult<Vec<u8>> {
207 let hmac_alg = match algorithm {
208 TotpAlgorithm::SHA1 => HmacAlgorithm::SHA1,
209 TotpAlgorithm::SHA256 => HmacAlgorithm::SHA256,
210 TotpAlgorithm::SHA512 => HmacAlgorithm::SHA512,
211 };
212 hmac_sign(hmac_alg, key, counter)
213}
214
215pub fn generate_hotp(secret: &[u8], counter: u64, digits: u32, algorithm: TotpAlgorithm) -> CryptoResult<String> {
217 let counter_bytes = counter.to_be_bytes();
218 let hmac_result = compute_hmac(algorithm, secret, &counter_bytes)?;
219 let code = dynamic_truncate(&hmac_result, digits);
220 Ok(format!("{:0width$}", code, width = digits as usize))
221}
222
223pub fn generate_totp(
225 secret: &[u8],
226 timestamp: u64,
227 time_step: u64,
228 digits: u32,
229 algorithm: TotpAlgorithm,
230) -> CryptoResult<String> {
231 let counter = timestamp / time_step;
232 generate_hotp(secret, counter, digits, algorithm)
233}
234
235pub fn verify_totp(
237 secret: &[u8],
238 code: &str,
239 timestamp: u64,
240 time_step: u64,
241 digits: u32,
242 algorithm: TotpAlgorithm,
243 window: u32,
244) -> CryptoResult<bool> {
245 let current_counter = timestamp / time_step;
246
247 for i in -(window as i64)..=(window as i64) {
248 let counter = (current_counter as i64 + i) as u64;
249 let expected = generate_hotp(secret, counter, digits, algorithm)?;
250
251 if constant_time_compare(code, &expected) {
252 return Ok(true);
253 }
254 }
255
256 Ok(false)
257}
258
259pub fn verify_hotp(secret: &[u8], code: &str, counter: u64, digits: u32, algorithm: TotpAlgorithm) -> CryptoResult<bool> {
261 let expected = generate_hotp(secret, counter, digits, algorithm)?;
262 Ok(constant_time_compare(code, &expected))
263}
264
/// Compares two strings byte-by-byte without stopping at the first difference.
///
/// Length is treated as non-secret: strings of different lengths return
/// `false` immediately.
fn constant_time_compare(a: &str, b: &str) -> bool {
    if a.len() != b.len() {
        return false;
    }
    // OR together the XOR of every byte pair; any mismatch leaves a non-zero bit.
    let difference = a
        .bytes()
        .zip(b.bytes())
        .fold(0u8, |acc, (left, right)| acc | (left ^ right));
    difference == 0
}
281
/// Returns the index of the `time_step`-second window containing `timestamp`,
/// i.e. the TOTP counter.
///
/// # Panics
/// Panics if `time_step` is zero (integer division by zero).
pub fn get_time_step(timestamp: u64, time_step: u64) -> u64 {
    let window_index = timestamp / time_step;
    window_index
}
286
/// Returns how many seconds remain before the current `time_step` window ends.
///
/// On an exact window boundary the full `time_step` is returned.
///
/// # Panics
/// Panics if `time_step` is zero (remainder by zero).
pub fn get_remaining_seconds(timestamp: u64, time_step: u64) -> u64 {
    let elapsed_in_window = timestamp % time_step;
    time_step - elapsed_in_window
}