1use argon2::{Algorithm, Argon2, Params, Version};
2use chrono::{DateTime, Duration, Utc};
3use log::{debug, info};
4use std::time::{Duration as StdDuration, Instant};
5
/// Default size, in bytes, of a derived password hash (`HASH_LEN` of `AuthManager`).
pub const DEFAULT_HASH_LEN: usize = 32;
/// Default size, in bytes, of the pepper (`PEPPER_LEN` of `HashConfig` and `AuthManager`).
pub const DEFAULT_PEPPER_LEN: usize = 16;
/// Default size, in bytes, of a per-account salt (`SALT_LEN` of `AuthManager`).
pub const DEFAULT_SALT_LEN: usize = 16;
/// Default size, in bytes, of a session identifier (`SESSION_LEN` of `AuthManager`).
pub const DEFAULT_SESSION_LEN: usize = 32;
10
11#[derive(Clone)]
12pub struct HashConfig<const PEPPER_LEN: usize = DEFAULT_PEPPER_LEN> {
13 pub pepper: [u8; PEPPER_LEN],
14 pub memory_kib: u32,
15 pub time_cost: u32,
16 pub lanes: u32,
17}
18
19impl<const PEPPER_LEN: usize> serde::Serialize for HashConfig<PEPPER_LEN> {
20 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
21 where
22 S: serde::Serializer,
23 {
24 use serde::ser::SerializeStruct;
25 let mut state = serializer.serialize_struct("HashConfig", 4)?;
26 let mut hex_pepper = String::with_capacity(PEPPER_LEN * 2);
27 for b in &self.pepper {
28 use std::fmt::Write;
29 write!(&mut hex_pepper, "{:02x}", b).unwrap();
30 }
31 state.serialize_field("pepper", &hex_pepper)?;
32 state.serialize_field("memory_kib", &self.memory_kib)?;
33 state.serialize_field("time_cost", &self.time_cost)?;
34 state.serialize_field("lanes", &self.lanes)?;
35 state.end()
36 }
37}
38
39impl<'de, const PEPPER_LEN: usize> serde::Deserialize<'de> for HashConfig<PEPPER_LEN> {
40 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
41 where
42 D: serde::Deserializer<'de>,
43 {
44 #[derive(serde::Deserialize)]
45 struct HashConfigHelper {
46 pepper: String,
47 memory_kib: u32,
48 time_cost: u32,
49 lanes: u32,
50 }
51 let helper = HashConfigHelper::deserialize(deserializer)?;
52 let s = helper.pepper.as_str();
53 if s.len() != PEPPER_LEN * 2 {
54 return Err(serde::de::Error::invalid_length(
55 s.len() / 2,
56 &format!("expected {} bytes for pepper", PEPPER_LEN).as_str(),
57 ));
58 }
59 let mut pepper = [0u8; PEPPER_LEN];
60 for (i, byte) in pepper.iter_mut().enumerate() {
61 let idx = i * 2;
62 *byte = u8::from_str_radix(&s[idx..idx + 2], 16).map_err(serde::de::Error::custom)?;
63 }
64 Ok(HashConfig {
65 pepper,
66 memory_kib: helper.memory_kib,
67 time_cost: helper.time_cost,
68 lanes: helper.lanes,
69 })
70 }
71}
72
73impl<const PEPPER_LEN: usize> HashConfig<PEPPER_LEN> {
74 pub fn benchmark(target_ms: u64) -> Self {
75 info!("Benchmarking HashConfig parameters...");
76 let test_password = "benchmark_password";
77 let salt = [0u8; 16];
78 let target_duration = StdDuration::from_millis(target_ms);
79
80 info!(
81 "Benchmark assumptions: target_duration={:?}, test_password='{}', salt={:?}",
82 target_duration, test_password, salt
83 );
84
85 let pepper = Self::generate_random_pepper();
86 info!("Generated random pepper for benchmark");
87
88 let best_memory = Self::binary_search_param(
89 target_duration,
90 |memory| {
91 let params = Params::new(memory, 3, 1, Some(32)).expect("argon2 params for memory");
92 let hasher = Argon2::new(Algorithm::Argon2id, Version::V0x13, params);
93 let start = Instant::now();
94 let mut out = [0u8; 32];
95 let mut adv = Vec::new();
96 adv.extend_from_slice(&salt);
97 adv.extend_from_slice(&pepper);
98 hasher
99 .hash_password_into(test_password.as_bytes(), &adv, &mut out)
100 .expect("hash during memory benchmark");
101 let duration = start.elapsed();
102 debug!("Memory {} KiB: duration={:?}", memory, duration);
103 duration
104 },
105 32768,
106 1048576,
107 );
108
109 let best_time = Self::binary_search_param(
110 target_duration,
111 |time| {
112 let params =
113 Params::new(best_memory, time, 1, Some(32)).expect("argon2 params for time");
114 let hasher = Argon2::new(Algorithm::Argon2id, Version::V0x13, params);
115 let start = Instant::now();
116 let mut out = [0u8; 32];
117 let mut adv = Vec::new();
118 adv.extend_from_slice(&salt);
119 adv.extend_from_slice(&pepper);
120 hasher
121 .hash_password_into(test_password.as_bytes(), &adv, &mut out)
122 .expect("hash during time benchmark");
123 let duration = start.elapsed();
124 debug!("Time {}: duration={:?}", time, duration);
125 duration
126 },
127 1,
128 10,
129 );
130
131 let best_lanes = Self::binary_search_param(
132 target_duration,
133 |lanes| {
134 let params = Params::new(best_memory, best_time, lanes, Some(32))
135 .expect("argon2 params for lanes");
136 let hasher = Argon2::new(Algorithm::Argon2id, Version::V0x13, params);
137 let start = Instant::now();
138 let mut out = [0u8; 32];
139 let mut adv = Vec::new();
140 adv.extend_from_slice(&salt);
141 adv.extend_from_slice(&pepper);
142 hasher
143 .hash_password_into(test_password.as_bytes(), &adv, &mut out)
144 .expect("hash during lanes benchmark");
145 let duration = start.elapsed();
146 debug!("Lanes {}: duration={:?}", lanes, duration);
147 duration
148 },
149 1,
150 8,
151 );
152
153 let best_config = Self {
154 pepper,
155 memory_kib: best_memory,
156 time_cost: best_time,
157 lanes: best_lanes,
158 };
159
160 let params = Params::new(best_memory, best_time, best_lanes, Some(32))
161 .expect("argon2 params for final measurement");
162 let hasher = Argon2::new(Algorithm::Argon2id, Version::V0x13, params);
163 let start = Instant::now();
164 let mut out = [0u8; 32];
165 let mut adv = Vec::new();
166 adv.extend_from_slice(&salt);
167 adv.extend_from_slice(&best_config.pepper);
168 hasher
169 .hash_password_into(test_password.as_bytes(), &adv, &mut out)
170 .expect("hash during final benchmark");
171 let final_duration = start.elapsed();
172
173 info!(
174 "Best HashConfig: memory={} KiB, time={}, lanes={}, duration={:?}",
175 best_config.memory_kib, best_config.time_cost, best_config.lanes, final_duration
176 );
177 best_config
178 }
179
180 fn generate_random_pepper() -> [u8; PEPPER_LEN] {
181 let mut bytes = [0u8; PEPPER_LEN];
182 getrandom::fill(&mut bytes).expect("generate random pepper");
183 bytes
184 }
185
186 fn binary_search_param<F>(target: StdDuration, measure: F, min: u32, max: u32) -> u32
187 where
188 F: Fn(u32) -> StdDuration,
189 {
190 let mut low = min;
191 let mut high = max;
192 let mut best = min;
193 let mut best_diff = StdDuration::from_secs(1000);
194
195 while low <= high {
196 let mid = low + (high - low) / 2;
197 let duration = measure(mid);
198 let diff = if duration > target {
199 duration - target
200 } else {
201 target - duration
202 };
203
204 if diff < best_diff {
205 best = mid;
206 best_diff = diff;
207 }
208
209 if duration < target {
210 low = mid + 1;
211 } else {
212 if mid == 0 {
213 break;
214 }
215 high = mid - 1;
216 }
217 }
218
219 best
220 }
221}
222
/// Minimal key-value store interface used by `AuthManager` for sessions and accounts.
///
/// Keys are passed by reference and may be unsized (for example `str`); values are
/// passed and returned by value. Every method takes `&self`, so an implementation
/// that mutates state needs some form of interior mutability.
pub trait KVTrait<K: ?Sized, V> {
    /// Returns the value stored under `key`, or `None` when there is none.
    fn get(&self, key: &K) -> Option<V>;
    /// Stores `value` under `key`.
    fn set(&self, key: &K, value: V);
    /// Reports whether an entry exists for `key`.
    fn contains(&self, key: &K) -> bool;
    /// Removes the entry for `key`.
    ///
    /// NOTE(review): the meaning of the returned flag is not defined here;
    /// `AuthManager` hands it to its callers unchanged — presumably `true` when an
    /// entry was removed. Confirm against the implementations.
    fn delete(&self, key: &K) -> bool;
}
235
236pub struct SessionValue<const SESSION_LEN: usize> {
238 pub session_key: [u8; SESSION_LEN],
239 pub linked_accounts: Vec<Box<str>>,
240 pub last_time: DateTime<Utc>,
241 pub created_time: DateTime<Utc>,
242}
243
244pub struct AccountValue<const SALT_LEN: usize, const HASH_LEN: usize, const SESSION_LEN: usize> {
246 pub password_hash: [u8; HASH_LEN],
247 pub salt: [u8; SALT_LEN],
248 pub last_time: DateTime<Utc>,
249 pub linked_sessions: Vec<[u8; SESSION_LEN]>,
250}
251
252pub struct AuthManager<
254 S,
255 A,
256 const SESSION_LEN: usize = DEFAULT_SESSION_LEN,
257 const HASH_LEN: usize = DEFAULT_HASH_LEN,
258 const PEPPER_LEN: usize = DEFAULT_PEPPER_LEN,
259 const SALT_LEN: usize = DEFAULT_SALT_LEN,
260> where
261 S: KVTrait<[u8; SESSION_LEN], SessionValue<SESSION_LEN>> + Send + Sync,
262 A: KVTrait<str, AccountValue<SALT_LEN, HASH_LEN, SESSION_LEN>> + Send + Sync,
263{
264 pub sessions: S,
265 pub accounts: A,
266 pub session_timeout: Duration,
267 pub account_timeout: Duration,
268 pub password_hasher: Argon2<'static>,
269 pub pepper: [u8; PEPPER_LEN],
270}
271
272impl<
273 S,
274 A,
275 const SESSION_LEN: usize,
276 const HASH_LEN: usize,
277 const PEPPER_LEN: usize,
278 const SALT_LEN: usize,
279> AuthManager<S, A, SESSION_LEN, HASH_LEN, PEPPER_LEN, SALT_LEN>
280where
281 S: KVTrait<[u8; SESSION_LEN], SessionValue<SESSION_LEN>> + Send + Sync,
282 A: KVTrait<str, AccountValue<SALT_LEN, HASH_LEN, SESSION_LEN>> + Send + Sync,
283{
284 pub fn new(
285 sessions: S,
286 accounts: A,
287 session_timeout: Duration,
288 account_timeout: Duration,
289 hash_config: HashConfig<PEPPER_LEN>,
290 ) -> Self {
291 Self {
292 sessions,
293 accounts,
294 session_timeout,
295 account_timeout,
296 password_hasher: Argon2::new(
297 argon2::Algorithm::Argon2id,
298 argon2::Version::V0x13,
299 argon2::Params::new(
300 hash_config.memory_kib,
301 hash_config.time_cost,
302 hash_config.lanes,
303 Some(HASH_LEN),
304 )
305 .expect("argon2 hash params"),
306 ),
307 pepper: hash_config.pepper,
308 }
309 }
310
311 pub fn create_session(&self) -> [u8; SESSION_LEN] {
313 let session_id = Self::generate_session();
314 if self.sessions.contains(&session_id) {
315 return self.create_session(); }
317 let session_value = SessionValue::<SESSION_LEN> {
318 session_key: session_id,
319 linked_accounts: Vec::new(),
320 last_time: Utc::now(),
321 created_time: Utc::now(),
322 };
323 self.sessions.set(&session_id, session_value);
324 session_id
325 }
326
327 pub fn delete_session(&self, session_id: &[u8; SESSION_LEN]) -> bool {
329 if let Some(session) = self.sessions.get(session_id) {
330 for account in session.linked_accounts {
331 if let Some(mut account_value) = self.accounts.get(&account) {
332 account_value.linked_sessions.retain(|s| s != session_id);
333 self.accounts.set(&account, account_value);
334 }
335 }
336 self.sessions.delete(session_id)
337 } else {
338 false
339 }
340 }
341
342 pub fn get_session(&self, session_id: &[u8; SESSION_LEN]) -> Option<SessionValue<SESSION_LEN>> {
344 self.sessions.get(session_id)
345 }
346
347 pub fn add_account(&self, username: &str, password: &str) {
349 let salt = Self::generate_random_salt();
350 let password_hash = self.hash_password(password, &salt);
351 let account_value = AccountValue::<SALT_LEN, HASH_LEN, SESSION_LEN> {
352 password_hash,
353 salt,
354 last_time: Utc::now(),
355 linked_sessions: Vec::new(),
356 };
357 self.accounts.set(username, account_value);
358 }
359
360 pub fn delete_account(&self, username: &str) -> bool {
362 if let Some(account) = self.accounts.get(username) {
363 for session_id in account.linked_sessions {
364 if let Some(mut session_value) = self.sessions.get(&session_id) {
365 session_value
366 .linked_accounts
367 .retain(|a| a.as_ref() != username);
368 self.sessions.set(&session_id, session_value);
369 }
370 }
371 self.accounts.delete(username)
372 } else {
373 false
374 }
375 }
376
377 pub fn get_account(
379 &self,
380 username: &str,
381 ) -> Option<AccountValue<SALT_LEN, HASH_LEN, SESSION_LEN>> {
382 self.accounts.get(username)
383 }
384
385 pub fn check_and_gc_session(
387 &self,
388 session_id: &[u8; SESSION_LEN],
389 ) -> Option<SessionValue<SESSION_LEN>> {
390 if let Some(session) = self.sessions.get(session_id) {
391 let now = Utc::now();
392 if now - session.last_time > self.session_timeout {
393 self.delete_session(session_id);
394 return Some(session);
395 }
396 }
397 None
398 }
399
400 pub fn auth_login(
402 &self,
403 session_id: &[u8; SESSION_LEN],
404 username: &str,
405 password: &str,
406 ) -> bool {
407 if let Some(account) = self.accounts.get(username) {
408 let expected_hash = self.hash_password(password, &account.salt);
409 if expected_hash == account.password_hash
410 && let Some(session) = self.sessions.get(session_id)
411 {
412 return self.link_account_to_session(username, session_id, account, session);
413 }
414 }
415 false
416 }
417
418 pub fn auth_logout(&self, session_id: &[u8; SESSION_LEN], username: &str) -> bool {
420 if let Some(account) = self.accounts.get(username) {
421 if let Some(session) = self.sessions.get(session_id) {
422 return self.unlink_account_from_session(username, session_id, account, session);
423 }
424 }
425 false
426 }
427
428 pub fn link_account_to_session(
430 &self,
431 username: &str,
432 session_id: &[u8; SESSION_LEN],
433 mut account: AccountValue<SALT_LEN, HASH_LEN, SESSION_LEN>,
434 mut session: SessionValue<SESSION_LEN>,
435 ) -> bool {
436 account.linked_sessions.push(*session_id);
437 session.linked_accounts.push(username.into());
438 self.accounts.set(username, account);
439 self.sessions.set(session_id, session);
440 true
441 }
442
443 pub fn unlink_account_from_session(
445 &self,
446 username: &str,
447 session_id: &[u8; SESSION_LEN],
448 mut account: AccountValue<SALT_LEN, HASH_LEN, SESSION_LEN>,
449 mut session: SessionValue<SESSION_LEN>,
450 ) -> bool {
451 account.linked_sessions.retain(|s| s != session_id);
452 session.linked_accounts.retain(|a| a.as_ref() != username);
453 self.accounts.set(username, account);
454 self.sessions.set(session_id, session);
455 true
456 }
457
458 fn hash_password(&self, password: &str, salt: &[u8; SALT_LEN]) -> [u8; HASH_LEN] {
459 let mut out = [0u8; HASH_LEN];
460 let mut adv = Vec::new();
461 adv.extend_from_slice(salt);
462 adv.extend_from_slice(&self.pepper);
463 self.password_hasher
464 .hash_password_into(password.as_bytes(), &adv, &mut out)
465 .unwrap();
466 out
467 }
468
469 fn generate_random_salt() -> [u8; SALT_LEN] {
470 let mut salt = [0u8; SALT_LEN];
471 getrandom::fill(&mut salt).expect("generate random salt");
472 salt
473 }
474
475 fn generate_session() -> [u8; SESSION_LEN] {
476 let mut session_id = [0u8; SESSION_LEN];
477 getrandom::fill(&mut session_id).expect("generate random session ID");
478 session_id
479 }
480}