Skip to main content

srv_session/
lib.rs

1use argon2::{Algorithm, Argon2, Params, Version};
2use chrono::{DateTime, Duration, Utc};
3use log::{debug, info};
4use std::time::{Duration as StdDuration, Instant};
5
/// Default length, in bytes, of a session identifier.
pub const DEFAULT_SESSION_LEN: usize = 32;
/// Default length, in bytes, of the Argon2id password-hash output.
pub const DEFAULT_HASH_LEN: usize = 32;
/// Default length, in bytes, of the random per-account salt.
pub const DEFAULT_SALT_LEN: usize = 16;
/// Default length, in bytes, of the pepper mixed into every password hash.
pub const DEFAULT_PEPPER_LEN: usize = 16;
10
11#[derive(Clone)]
12pub struct HashConfig<const PEPPER_LEN: usize = DEFAULT_PEPPER_LEN> {
13    pub pepper: [u8; PEPPER_LEN],
14    pub memory_kib: u32,
15    pub time_cost: u32,
16    pub lanes: u32,
17}
18
19impl<const PEPPER_LEN: usize> serde::Serialize for HashConfig<PEPPER_LEN> {
20    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
21    where
22        S: serde::Serializer,
23    {
24        use serde::ser::SerializeStruct;
25        let mut state = serializer.serialize_struct("HashConfig", 4)?;
26        let mut hex_pepper = String::with_capacity(PEPPER_LEN * 2);
27        for b in &self.pepper {
28            use std::fmt::Write;
29            write!(&mut hex_pepper, "{:02x}", b).unwrap();
30        }
31        state.serialize_field("pepper", &hex_pepper)?;
32        state.serialize_field("memory_kib", &self.memory_kib)?;
33        state.serialize_field("time_cost", &self.time_cost)?;
34        state.serialize_field("lanes", &self.lanes)?;
35        state.end()
36    }
37}
38
39impl<'de, const PEPPER_LEN: usize> serde::Deserialize<'de> for HashConfig<PEPPER_LEN> {
40    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
41    where
42        D: serde::Deserializer<'de>,
43    {
44        #[derive(serde::Deserialize)]
45        struct HashConfigHelper {
46            pepper: String,
47            memory_kib: u32,
48            time_cost: u32,
49            lanes: u32,
50        }
51        let helper = HashConfigHelper::deserialize(deserializer)?;
52        let s = helper.pepper.as_str();
53        if s.len() != PEPPER_LEN * 2 {
54            return Err(serde::de::Error::invalid_length(
55                s.len() / 2,
56                &format!("expected {} bytes for pepper", PEPPER_LEN).as_str(),
57            ));
58        }
59        let mut pepper = [0u8; PEPPER_LEN];
60        for (i, byte) in pepper.iter_mut().enumerate() {
61            let idx = i * 2;
62            *byte = u8::from_str_radix(&s[idx..idx + 2], 16).map_err(serde::de::Error::custom)?;
63        }
64        Ok(HashConfig {
65            pepper,
66            memory_kib: helper.memory_kib,
67            time_cost: helper.time_cost,
68            lanes: helper.lanes,
69        })
70    }
71}
72
73impl<const PEPPER_LEN: usize> HashConfig<PEPPER_LEN> {
74    pub fn benchmark(target_ms: u64) -> Self {
75        info!("Benchmarking HashConfig parameters...");
76        let test_password = "benchmark_password";
77        let salt = [0u8; 16];
78        let target_duration = StdDuration::from_millis(target_ms);
79
80        info!(
81            "Benchmark assumptions: target_duration={:?}, test_password='{}', salt={:?}",
82            target_duration, test_password, salt
83        );
84
85        let pepper = Self::generate_random_pepper();
86        info!("Generated random pepper for benchmark");
87
88        let best_memory = Self::binary_search_param(
89            target_duration,
90            |memory| {
91                let params = Params::new(memory, 3, 1, Some(32)).expect("argon2 params for memory");
92                let hasher = Argon2::new(Algorithm::Argon2id, Version::V0x13, params);
93                let start = Instant::now();
94                let mut out = [0u8; 32];
95                let mut adv = Vec::new();
96                adv.extend_from_slice(&salt);
97                adv.extend_from_slice(&pepper);
98                hasher
99                    .hash_password_into(test_password.as_bytes(), &adv, &mut out)
100                    .expect("hash during memory benchmark");
101                let duration = start.elapsed();
102                debug!("Memory {} KiB: duration={:?}", memory, duration);
103                duration
104            },
105            32768,
106            1048576,
107        );
108
109        let best_time = Self::binary_search_param(
110            target_duration,
111            |time| {
112                let params =
113                    Params::new(best_memory, time, 1, Some(32)).expect("argon2 params for time");
114                let hasher = Argon2::new(Algorithm::Argon2id, Version::V0x13, params);
115                let start = Instant::now();
116                let mut out = [0u8; 32];
117                let mut adv = Vec::new();
118                adv.extend_from_slice(&salt);
119                adv.extend_from_slice(&pepper);
120                hasher
121                    .hash_password_into(test_password.as_bytes(), &adv, &mut out)
122                    .expect("hash during time benchmark");
123                let duration = start.elapsed();
124                debug!("Time {}: duration={:?}", time, duration);
125                duration
126            },
127            1,
128            10,
129        );
130
131        let best_lanes = Self::binary_search_param(
132            target_duration,
133            |lanes| {
134                let params = Params::new(best_memory, best_time, lanes, Some(32))
135                    .expect("argon2 params for lanes");
136                let hasher = Argon2::new(Algorithm::Argon2id, Version::V0x13, params);
137                let start = Instant::now();
138                let mut out = [0u8; 32];
139                let mut adv = Vec::new();
140                adv.extend_from_slice(&salt);
141                adv.extend_from_slice(&pepper);
142                hasher
143                    .hash_password_into(test_password.as_bytes(), &adv, &mut out)
144                    .expect("hash during lanes benchmark");
145                let duration = start.elapsed();
146                debug!("Lanes {}: duration={:?}", lanes, duration);
147                duration
148            },
149            1,
150            8,
151        );
152
153        let best_config = Self {
154            pepper,
155            memory_kib: best_memory,
156            time_cost: best_time,
157            lanes: best_lanes,
158        };
159
160        let params = Params::new(best_memory, best_time, best_lanes, Some(32))
161            .expect("argon2 params for final measurement");
162        let hasher = Argon2::new(Algorithm::Argon2id, Version::V0x13, params);
163        let start = Instant::now();
164        let mut out = [0u8; 32];
165        let mut adv = Vec::new();
166        adv.extend_from_slice(&salt);
167        adv.extend_from_slice(&best_config.pepper);
168        hasher
169            .hash_password_into(test_password.as_bytes(), &adv, &mut out)
170            .expect("hash during final benchmark");
171        let final_duration = start.elapsed();
172
173        info!(
174            "Best HashConfig: memory={} KiB, time={}, lanes={}, duration={:?}",
175            best_config.memory_kib, best_config.time_cost, best_config.lanes, final_duration
176        );
177        best_config
178    }
179
180    fn generate_random_pepper() -> [u8; PEPPER_LEN] {
181        let mut bytes = [0u8; PEPPER_LEN];
182        getrandom::fill(&mut bytes).expect("generate random pepper");
183        bytes
184    }
185
186    fn binary_search_param<F>(target: StdDuration, measure: F, min: u32, max: u32) -> u32
187    where
188        F: Fn(u32) -> StdDuration,
189    {
190        let mut low = min;
191        let mut high = max;
192        let mut best = min;
193        let mut best_diff = StdDuration::from_secs(1000);
194
195        while low <= high {
196            let mid = low + (high - low) / 2;
197            let duration = measure(mid);
198            let diff = if duration > target {
199                duration - target
200            } else {
201                target - duration
202            };
203
204            if diff < best_diff {
205                best = mid;
206                best_diff = diff;
207            }
208
209            if duration < target {
210                low = mid + 1;
211            } else {
212                if mid == 0 {
213                    break;
214                }
215                high = mid - 1;
216            }
217        }
218
219        best
220    }
221}
222
/// Key-value store abstraction used for both sessions and accounts.
///
/// Do not access the database directly from an implementation. Put a layered
/// structure such as an in-memory cache in front of it, and keep latency as
/// low as possible. A concurrent map such as `DashMap` is recommended.
///
/// Every method takes `&self`, so implementations need interior mutability.
pub trait KVTrait<K: ?Sized, V> {
    /// Returns the value stored under `key`, if there is one.
    fn get(&self, key: &K) -> Option<V>;

    /// Stores `value` under `key`. Callers use this to write back modified
    /// values, so it is expected to replace any existing entry.
    fn set(&self, key: &K, value: V);

    /// Reports whether `key` currently has a value.
    fn contains(&self, key: &K) -> bool;

    /// Removes `key`. The result is presumably whether an entry was actually
    /// removed; confirm against implementations.
    fn delete(&self, key: &K) -> bool;
}
235
236/// セッションのデータ構造
237pub struct SessionValue<const SESSION_LEN: usize> {
238    pub session_key: [u8; SESSION_LEN],
239    pub linked_accounts: Vec<Box<str>>,
240    pub last_time: DateTime<Utc>,
241    pub created_time: DateTime<Utc>,
242}
243
244/// アカウントのデータ構造
245pub struct AccountValue<const SALT_LEN: usize, const HASH_LEN: usize, const SESSION_LEN: usize> {
246    pub password_hash: [u8; HASH_LEN],
247    pub salt: [u8; SALT_LEN],
248    pub last_time: DateTime<Utc>,
249    pub linked_sessions: Vec<[u8; SESSION_LEN]>,
250}
251
252/// 認証マネージャー
253pub struct AuthManager<
254    S,
255    A,
256    const SESSION_LEN: usize = DEFAULT_SESSION_LEN,
257    const HASH_LEN: usize = DEFAULT_HASH_LEN,
258    const PEPPER_LEN: usize = DEFAULT_PEPPER_LEN,
259    const SALT_LEN: usize = DEFAULT_SALT_LEN,
260> where
261    S: KVTrait<[u8; SESSION_LEN], SessionValue<SESSION_LEN>> + Send + Sync,
262    A: KVTrait<str, AccountValue<SALT_LEN, HASH_LEN, SESSION_LEN>> + Send + Sync,
263{
264    pub sessions: S,
265    pub accounts: A,
266    pub session_timeout: Duration,
267    pub account_timeout: Duration,
268    pub password_hasher: Argon2<'static>,
269    pub pepper: [u8; PEPPER_LEN],
270}
271
272impl<
273    S,
274    A,
275    const SESSION_LEN: usize,
276    const HASH_LEN: usize,
277    const PEPPER_LEN: usize,
278    const SALT_LEN: usize,
279> AuthManager<S, A, SESSION_LEN, HASH_LEN, PEPPER_LEN, SALT_LEN>
280where
281    S: KVTrait<[u8; SESSION_LEN], SessionValue<SESSION_LEN>> + Send + Sync,
282    A: KVTrait<str, AccountValue<SALT_LEN, HASH_LEN, SESSION_LEN>> + Send + Sync,
283{
284    pub fn new(
285        sessions: S,
286        accounts: A,
287        session_timeout: Duration,
288        account_timeout: Duration,
289        hash_config: HashConfig<PEPPER_LEN>,
290    ) -> Self {
291        Self {
292            sessions,
293            accounts,
294            session_timeout,
295            account_timeout,
296            password_hasher: Argon2::new(
297                argon2::Algorithm::Argon2id,
298                argon2::Version::V0x13,
299                argon2::Params::new(
300                    hash_config.memory_kib,
301                    hash_config.time_cost,
302                    hash_config.lanes,
303                    Some(HASH_LEN),
304                )
305                .expect("argon2 hash params"),
306            ),
307            pepper: hash_config.pepper,
308        }
309    }
310
311    /// 新しいセッションを追加
312    pub fn create_session(&self) -> [u8; SESSION_LEN] {
313        let session_id = Self::generate_session();
314        if self.sessions.contains(&session_id) {
315            return self.create_session(); // Regenerate if collision occurs
316        }
317        let session_value = SessionValue::<SESSION_LEN> {
318            session_key: session_id,
319            linked_accounts: Vec::new(),
320            last_time: Utc::now(),
321            created_time: Utc::now(),
322        };
323        self.sessions.set(&session_id, session_value);
324        session_id
325    }
326
327    /// セッションを削除
328    pub fn delete_session(&self, session_id: &[u8; SESSION_LEN]) -> bool {
329        if let Some(session) = self.sessions.get(session_id) {
330            for account in session.linked_accounts {
331                if let Some(mut account_value) = self.accounts.get(&account) {
332                    account_value.linked_sessions.retain(|s| s != session_id);
333                    self.accounts.set(&account, account_value);
334                }
335            }
336            self.sessions.delete(session_id)
337        } else {
338            false
339        }
340    }
341
342    /// セッションを取得
343    pub fn get_session(&self, session_id: &[u8; SESSION_LEN]) -> Option<SessionValue<SESSION_LEN>> {
344        self.sessions.get(session_id)
345    }
346
347    /// 新しいアカウントを追加
348    pub fn add_account(&self, username: &str, password: &str) {
349        let salt = Self::generate_random_salt();
350        let password_hash = self.hash_password(password, &salt);
351        let account_value = AccountValue::<SALT_LEN, HASH_LEN, SESSION_LEN> {
352            password_hash,
353            salt,
354            last_time: Utc::now(),
355            linked_sessions: Vec::new(),
356        };
357        self.accounts.set(username, account_value);
358    }
359
360    /// アカウントを削除
361    pub fn delete_account(&self, username: &str) -> bool {
362        if let Some(account) = self.accounts.get(username) {
363            for session_id in account.linked_sessions {
364                if let Some(mut session_value) = self.sessions.get(&session_id) {
365                    session_value
366                        .linked_accounts
367                        .retain(|a| a.as_ref() != username);
368                    self.sessions.set(&session_id, session_value);
369                }
370            }
371            self.accounts.delete(username)
372        } else {
373            false
374        }
375    }
376
377    /// アカウントを取得
378    pub fn get_account(
379        &self,
380        username: &str,
381    ) -> Option<AccountValue<SALT_LEN, HASH_LEN, SESSION_LEN>> {
382        self.accounts.get(username)
383    }
384
385    /// セッションを検査して、期限切れなら削除
386    pub fn check_and_gc_session(
387        &self,
388        session_id: &[u8; SESSION_LEN],
389    ) -> Option<SessionValue<SESSION_LEN>> {
390        if let Some(session) = self.sessions.get(session_id) {
391            let now = Utc::now();
392            if now - session.last_time > self.session_timeout {
393                self.delete_session(session_id);
394                return Some(session);
395            }
396        }
397        None
398    }
399
400    /// ログイン処理
401    pub fn auth_login(
402        &self,
403        session_id: &[u8; SESSION_LEN],
404        username: &str,
405        password: &str,
406    ) -> bool {
407        if let Some(account) = self.accounts.get(username) {
408            let expected_hash = self.hash_password(password, &account.salt);
409            if expected_hash == account.password_hash
410                && let Some(session) = self.sessions.get(session_id)
411            {
412                return self.link_account_to_session(username, session_id, account, session);
413            }
414        }
415        false
416    }
417
418    /// ログアウト処理
419    pub fn auth_logout(&self, session_id: &[u8; SESSION_LEN], username: &str) -> bool {
420        if let Some(account) = self.accounts.get(username) {
421            if let Some(session) = self.sessions.get(session_id) {
422                return self.unlink_account_from_session(username, session_id, account, session);
423            }
424        }
425        false
426    }
427
428    /// セッションとアカウントをリンク
429    pub fn link_account_to_session(
430        &self,
431        username: &str,
432        session_id: &[u8; SESSION_LEN],
433        mut account: AccountValue<SALT_LEN, HASH_LEN, SESSION_LEN>,
434        mut session: SessionValue<SESSION_LEN>,
435    ) -> bool {
436        account.linked_sessions.push(*session_id);
437        session.linked_accounts.push(username.into());
438        self.accounts.set(username, account);
439        self.sessions.set(session_id, session);
440        true
441    }
442
443    /// セッションとアカウントのリンクを解除
444    pub fn unlink_account_from_session(
445        &self,
446        username: &str,
447        session_id: &[u8; SESSION_LEN],
448        mut account: AccountValue<SALT_LEN, HASH_LEN, SESSION_LEN>,
449        mut session: SessionValue<SESSION_LEN>,
450    ) -> bool {
451        account.linked_sessions.retain(|s| s != session_id);
452        session.linked_accounts.retain(|a| a.as_ref() != username);
453        self.accounts.set(username, account);
454        self.sessions.set(session_id, session);
455        true
456    }
457
458    fn hash_password(&self, password: &str, salt: &[u8; SALT_LEN]) -> [u8; HASH_LEN] {
459        let mut out = [0u8; HASH_LEN];
460        let mut adv = Vec::new();
461        adv.extend_from_slice(salt);
462        adv.extend_from_slice(&self.pepper);
463        self.password_hasher
464            .hash_password_into(password.as_bytes(), &adv, &mut out)
465            .unwrap();
466        out
467    }
468
469    fn generate_random_salt() -> [u8; SALT_LEN] {
470        let mut salt = [0u8; SALT_LEN];
471        getrandom::fill(&mut salt).expect("generate random salt");
472        salt
473    }
474
475    fn generate_session() -> [u8; SESSION_LEN] {
476        let mut session_id = [0u8; SESSION_LEN];
477        getrandom::fill(&mut session_id).expect("generate random session ID");
478        session_id
479    }
480}