rust_auth_utils/
hmac.rs

1// based on https://github.com/better-auth/utils/blob/main/src/hmac.ts
2
3use crate::types::{EncodingFormat, SHAFamily};
4use crate::{base64, hex};
5use hmac::{Hmac, Mac};
6use sha1::Sha1;
7use sha2::{Sha256, Sha384, Sha512};
8
9pub struct HmacBuilder {
10    algorithm: SHAFamily,
11    encoding: EncodingFormat,
12}
13
14impl Default for HmacBuilder {
15    fn default() -> Self {
16        Self {
17            algorithm: SHAFamily::SHA256,
18            encoding: EncodingFormat::None,
19        }
20    }
21}
22
23impl HmacBuilder {
24    pub fn new(algorithm: Option<SHAFamily>, encoding: Option<EncodingFormat>) -> Self {
25        Self {
26            algorithm: algorithm.unwrap_or(SHAFamily::SHA256),
27            encoding: encoding.unwrap_or(EncodingFormat::None),
28        }
29    }
30
31    pub fn sign(&self, key: &[u8], data: &[u8]) -> Result<Vec<u8>, &'static str> {
32        let signature = match self.algorithm {
33            SHAFamily::SHA1 => {
34                let mut mac =
35                    Hmac::<Sha1>::new_from_slice(key).map_err(|_| "Failed to create HMAC")?;
36                mac.update(data);
37                mac.finalize().into_bytes().to_vec()
38            }
39            SHAFamily::SHA256 => {
40                let mut mac =
41                    Hmac::<Sha256>::new_from_slice(key).map_err(|_| "Failed to create HMAC")?;
42                mac.update(data);
43                mac.finalize().into_bytes().to_vec()
44            }
45            SHAFamily::SHA384 => {
46                let mut mac =
47                    Hmac::<Sha384>::new_from_slice(key).map_err(|_| "Failed to create HMAC")?;
48                mac.update(data);
49                mac.finalize().into_bytes().to_vec()
50            }
51            SHAFamily::SHA512 => {
52                let mut mac =
53                    Hmac::<Sha512>::new_from_slice(key).map_err(|_| "Failed to create HMAC")?;
54                mac.update(data);
55                mac.finalize().into_bytes().to_vec()
56            }
57        };
58
59        match self.encoding {
60            EncodingFormat::Hex => Ok(hex::Hex::encode(&signature).as_bytes().to_vec()),
61            EncodingFormat::Base64 => Ok(base64::Base64::encode(&signature, Some(true))
62                .as_bytes()
63                .to_vec()),
64            EncodingFormat::Base64Url => Ok(base64::Base64Url::encode(&signature, Some(true))
65                .as_bytes()
66                .to_vec()),
67            EncodingFormat::Base64UrlNoPad => {
68                Ok(base64::Base64Url::encode(&signature, Some(false))
69                    .as_bytes()
70                    .to_vec())
71            }
72            EncodingFormat::None => Ok(signature),
73        }
74    }
75
76    pub fn verify(&self, key: &[u8], data: &[u8], signature: &[u8]) -> Result<bool, &'static str> {
77        let decoded_signature = match self.encoding {
78            EncodingFormat::Hex => {
79                let hex_str = std::str::from_utf8(signature).map_err(|_| "Invalid UTF-8")?;
80                hex::Hex::decode(hex_str).map_err(|_| "Invalid hex encoding")?
81            }
82            EncodingFormat::Base64 => {
83                let base64_str = std::str::from_utf8(signature).map_err(|_| "Invalid UTF-8")?;
84                // Strict base64 format validation
85                if base64_str.len() % 4 != 0 {
86                    return Err("Invalid base64 encoding: length not multiple of 4");
87                }
88                if !base64_str
89                    .chars()
90                    .all(|c| c.is_ascii_alphanumeric() || c == '+' || c == '/' || c == '=')
91                {
92                    return Err("Invalid base64 encoding");
93                }
94                // Validate padding
95                let padding_count = base64_str.chars().rev().take_while(|&c| c == '=').count();
96                if padding_count > 2 {
97                    return Err("Invalid base64 padding");
98                }
99                // Ensure no padding characters except at the end
100                if base64_str[..base64_str.len() - padding_count].contains('=') {
101                    return Err("Invalid base64 encoding: padding in wrong position");
102                }
103                // Ensure it's actually base64 by checking for base64-specific characters
104                let non_padding_part = &base64_str[..base64_str.len() - padding_count];
105                if !non_padding_part.is_empty()
106                    && !non_padding_part.contains(|c| c == '+' || c == '/')
107                {
108                    return Err("Invalid base64 encoding: missing base64-specific characters");
109                }
110                base64::Base64::decode(base64_str).map_err(|_| "Invalid base64 encoding")?
111            }
112            EncodingFormat::Base64Url | EncodingFormat::Base64UrlNoPad => {
113                let base64_str = std::str::from_utf8(signature).map_err(|_| "Invalid UTF-8")?;
114                // Strict base64url format validation
115                if !base64_str
116                    .chars()
117                    .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_' || c == '=')
118                {
119                    return Err("Invalid base64url encoding");
120                }
121                // For Base64UrlNoPad, reject if there's padding
122                if self.encoding == EncodingFormat::Base64UrlNoPad && base64_str.contains('=') {
123                    return Err("Invalid base64url encoding: unexpected padding");
124                }
125                // For Base64Url, validate padding
126                if self.encoding == EncodingFormat::Base64Url {
127                    if base64_str.len() % 4 != 0 {
128                        return Err("Invalid base64url encoding: length not multiple of 4");
129                    }
130                    let padding_count = base64_str.chars().rev().take_while(|&c| c == '=').count();
131                    if padding_count > 2 {
132                        return Err("Invalid base64url padding");
133                    }
134                    // Ensure no padding characters except at the end
135                    if base64_str[..base64_str.len() - padding_count].contains('=') {
136                        return Err("Invalid base64url encoding: padding in wrong position");
137                    }
138                }
139                // Ensure it's actually base64url by checking for base64url-specific characters
140                let non_padding_part = if self.encoding == EncodingFormat::Base64Url {
141                    let padding_count = base64_str.chars().rev().take_while(|&c| c == '=').count();
142                    &base64_str[..base64_str.len() - padding_count]
143                } else {
144                    base64_str
145                };
146                if !non_padding_part.is_empty()
147                    && !non_padding_part.contains(|c| c == '-' || c == '_')
148                {
149                    return Err(
150                        "Invalid base64url encoding: missing base64url-specific characters",
151                    );
152                }
153                base64::Base64Url::decode(base64_str).map_err(|_| "Invalid base64url encoding")?
154            }
155            EncodingFormat::None => signature.to_vec(),
156        };
157
158        let result = match self.algorithm {
159            SHAFamily::SHA1 => {
160                let mut mac =
161                    Hmac::<Sha1>::new_from_slice(key).map_err(|_| "Failed to create HMAC")?;
162                mac.update(data);
163                mac.verify_slice(&decoded_signature).is_ok()
164            }
165            SHAFamily::SHA256 => {
166                let mut mac =
167                    Hmac::<Sha256>::new_from_slice(key).map_err(|_| "Failed to create HMAC")?;
168                mac.update(data);
169                mac.verify_slice(&decoded_signature).is_ok()
170            }
171            SHAFamily::SHA384 => {
172                let mut mac =
173                    Hmac::<Sha384>::new_from_slice(key).map_err(|_| "Failed to create HMAC")?;
174                mac.update(data);
175                mac.verify_slice(&decoded_signature).is_ok()
176            }
177            SHAFamily::SHA512 => {
178                let mut mac =
179                    Hmac::<Sha512>::new_from_slice(key).map_err(|_| "Failed to create HMAC")?;
180                mac.update(data);
181                mac.verify_slice(&decoded_signature).is_ok()
182            }
183        };
184
185        Ok(result)
186    }
187}