rust_auth_utils/
otp.rs

1// based on https://github.com/better-auth/utils/blob/main/src/otp.ts
2
3use crate::base32::Base32;
4use crate::hmac::HmacBuilder;
5use crate::types::SHAFamily;
6use std::time::{SystemTime, UNIX_EPOCH};
7use url::Url;
8
/// Default TOTP time-step length, in seconds (divides the Unix time in seconds).
const DEFAULT_PERIOD: u32 = 30;
/// Default number of digits in a generated one-time code.
const DEFAULT_DIGITS: u32 = 6;
11
/// Error type returned by every fallible function in this module.
///
/// Wraps a human-readable message; `Display` prints that message verbatim.
#[derive(Debug)]
pub struct Error(String);

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message: &str = &self.0;
        f.write_str(message)
    }
}

/// Marker impl so `Error` works with `?`, `Box<dyn Error>`, etc.; it has no source.
impl std::error::Error for Error {}
22
23pub async fn generate_hotp(
24    secret: &str,
25    counter: u64,
26    digits: Option<u32>,
27    hash: Option<SHAFamily>,
28) -> Result<String, Error> {
29    let digits = digits.unwrap_or(DEFAULT_DIGITS);
30    if digits < 1 || digits > 8 {
31        return Err(Error("Digits must be between 1 and 8".to_string()));
32    }
33
34    let mut buffer = [0u8; 8];
35    buffer.copy_from_slice(&counter.to_be_bytes());
36
37    let hmac = HmacBuilder::new(hash, None);
38    let hmac_result = hmac
39        .sign(secret.as_bytes(), &buffer)
40        .map_err(|e| Error(e.to_string()))?;
41
42    let offset = hmac_result[hmac_result.len() - 1] & 0x0f;
43    let truncated = ((hmac_result[offset as usize] & 0x7f) as u32) << 24
44        | ((hmac_result[(offset + 1) as usize] & 0xff) as u32) << 16
45        | ((hmac_result[(offset + 2) as usize] & 0xff) as u32) << 8
46        | (hmac_result[(offset + 3) as usize] & 0xff) as u32;
47
48    let otp = truncated % 10u32.pow(digits);
49    Ok(format!("{:0width$}", otp, width = digits as usize))
50}
51
52pub async fn generate_totp(
53    secret: &str,
54    period: Option<u32>,
55    digits: Option<u32>,
56    hash: Option<SHAFamily>,
57) -> Result<String, Error> {
58    let period = period.unwrap_or(DEFAULT_PERIOD);
59    let now = SystemTime::now()
60        .duration_since(UNIX_EPOCH)
61        .map_err(|e| Error(e.to_string()))?
62        .as_secs();
63    let counter = now / period as u64;
64
65    generate_hotp(secret, counter, digits, hash).await
66}
67
68pub async fn verify_totp(
69    otp: &str,
70    secret: &str,
71    window: Option<i32>,
72    digits: Option<u32>,
73    period: Option<u32>,
74) -> Result<bool, Error> {
75    let window = window.unwrap_or(1);
76    let period = period.unwrap_or(DEFAULT_PERIOD);
77    let now = SystemTime::now()
78        .duration_since(UNIX_EPOCH)
79        .map_err(|e| Error(e.to_string()))?
80        .as_secs();
81    let counter = now / period as u64;
82
83    for i in -window..=window {
84        let current_counter = counter
85            .checked_add_signed(i as i64)
86            .ok_or_else(|| Error("Counter overflow".to_string()))?;
87        let generated_otp = generate_hotp(secret, current_counter, digits, None).await?;
88        if otp == generated_otp {
89            return Ok(true);
90        }
91    }
92    Ok(false)
93}
94
95pub fn generate_qr_code(
96    issuer: &str,
97    account: &str,
98    secret: &str,
99    digits: Option<u32>,
100    period: Option<u32>,
101) -> Result<String, Error> {
102    let digits = digits.unwrap_or(DEFAULT_DIGITS);
103    let period = period.unwrap_or(DEFAULT_PERIOD);
104
105    let base_uri = format!(
106        "otpauth://totp/{}:{}",
107        urlencoding::encode(issuer),
108        urlencoding::encode(account)
109    );
110
111    let mut url = Url::parse(&base_uri).map_err(|e| Error(e.to_string()))?;
112
113    let encoded_secret = Base32::encode(secret.as_bytes(), Some(false));
114    url.query_pairs_mut()
115        .append_pair("secret", &encoded_secret)
116        .append_pair("issuer", issuer)
117        .append_pair("digits", &digits.to_string())
118        .append_pair("period", &period.to_string());
119
120    Ok(url.to_string())
121}
122
123pub struct OTP {
124    secret: String,
125    digits: u32,
126    period: u32,
127}
128
129impl OTP {
130    pub fn new(secret: &str, digits: Option<u32>, period: Option<u32>) -> Self {
131        Self {
132            secret: secret.to_string(),
133            digits: digits.unwrap_or(DEFAULT_DIGITS),
134            period: period.unwrap_or(DEFAULT_PERIOD),
135        }
136    }
137
138    pub async fn hotp(&self, counter: u64) -> Result<String, Error> {
139        generate_hotp(&self.secret, counter, Some(self.digits), None).await
140    }
141
142    pub async fn totp(&self) -> Result<String, Error> {
143        generate_totp(&self.secret, Some(self.period), Some(self.digits), None).await
144    }
145
146    pub async fn verify(&self, otp: &str, window: Option<i32>) -> Result<bool, Error> {
147        verify_totp(
148            otp,
149            &self.secret,
150            window,
151            Some(self.digits),
152            Some(self.period),
153        )
154        .await
155    }
156
157    pub fn url(&self, issuer: &str, account: &str) -> Result<String, Error> {
158        generate_qr_code(
159            issuer,
160            account,
161            &self.secret,
162            Some(self.digits),
163            Some(self.period),
164        )
165    }
166}