senax_actix_session/
session.rs

1use actix_utils::future::{Ready, ready};
2use actix_web::{
3    FromRequest, HttpMessage, HttpRequest,
4    dev::{Extensions, Payload, ServiceRequest, ServiceResponse},
5    error::Error,
6};
7use anyhow::{Context, Result, bail};
8use senax_common::session::SessionKey;
9use senax_common::session::interface::{SaveError, SessionData, SessionStore};
10use serde::{Serialize, de::DeserializeOwned};
11use sha2::{Digest, Sha256};
12use std::{collections::HashMap, mem, sync::Arc, sync::Mutex};
13use time::Duration;
14
15use crate::{config::Configuration, middleware::e500};
16
/// Number of `SaveError::Retryable` save retries that `Session::update` and
/// `Session::renew` tolerate before giving up with an error.
const MAX_RETRY_COUNT: usize = 10;
18
19#[derive(Clone)]
20pub struct Session<Store: SessionStore + 'static>(Arc<Mutex<SessionInner<Store>>>);
21
/// What happened to the session during the current request.
///
/// Read back via `Session::status` / `Session::get_status`.
/// NOTE(review): presumably used by the middleware to decide how to treat the
/// session cookie; confirm against the caller.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SessionStatus {
    /// Initial state: no new key has been issued (also restored by `Session::reset`).
    Unchanged,
    /// A new session key was stored, or the TTL was refreshed while a cookie
    /// max-age is configured.
    Changed,
    /// The session was deleted through `Session::purge`.
    Purged,
}
28
29pub struct SessionInner<Store: SessionStore + 'static> {
30    session_key: Option<SessionKey>,
31    guest_zone: HashMap<String, Vec<u8>>,
32    user_zone: HashMap<String, Vec<u8>>,
33    debug_zone: HashMap<String, Vec<u8>>,
34    update: bool,
35    status: SessionStatus,
36    state_ttl: Duration,
37    version: u32,
38    storage: Arc<Store>,
39}
40
41impl<Store: SessionStore + 'static> SessionInner<Store> {
42    pub fn get_from_guest_zone<T: DeserializeOwned>(&mut self, key: &str) -> Result<Option<T>> {
43        if let Some(val) = self.guest_zone.get(key) {
44            Ok(Some(ciborium::from_reader(val.as_slice())?))
45        } else {
46            Ok(None)
47        }
48    }
49
50    pub fn insert_to_guest_zone<T: Serialize>(
51        &mut self,
52        key: impl Into<String>,
53        value: T,
54    ) -> Result<()> {
55        self.update = true;
56        let mut buf = Vec::new();
57        ciborium::into_writer(&value, &mut buf)?;
58        self.guest_zone.insert(key.into(), buf);
59        Ok(())
60    }
61
62    pub fn remove_from_guest_zone(&mut self, key: &str) {
63        self.update = true;
64        self.guest_zone.remove(key);
65    }
66
67    pub fn remove_from_guest_zone_as<T: DeserializeOwned>(
68        &mut self,
69        key: &str,
70    ) -> Option<Result<T>> {
71        self.update = true;
72        self.guest_zone
73            .remove(key)
74            .map(|val| Ok(ciborium::from_reader(val.as_slice())?))
75    }
76
77    pub fn clear_guest_zone(&mut self) {
78        self.update = true;
79        self.guest_zone.clear();
80    }
81
82    pub fn get_from_user_zone<T: DeserializeOwned>(&mut self, key: &str) -> Result<Option<T>> {
83        if let Some(val) = self.user_zone.get(key) {
84            Ok(Some(ciborium::from_reader(val.as_slice())?))
85        } else {
86            Ok(None)
87        }
88    }
89
90    pub fn insert_to_user_zone<T: Serialize>(
91        &mut self,
92        key: impl Into<String>,
93        value: T,
94    ) -> Result<()> {
95        self.update = true;
96        let mut buf = Vec::new();
97        ciborium::into_writer(&value, &mut buf)?;
98        self.user_zone.insert(key.into(), buf);
99        Ok(())
100    }
101
102    pub fn remove_from_user_zone(&mut self, key: &str) {
103        self.update = true;
104        self.user_zone.remove(key);
105    }
106
107    pub fn remove_from_user_zone_as<T: DeserializeOwned>(
108        &mut self,
109        key: &str,
110    ) -> Option<Result<T>> {
111        self.update = true;
112        self.user_zone
113            .remove(key)
114            .map(|val| Ok(ciborium::from_reader(val.as_slice())?))
115    }
116
117    pub fn clear_user_zone(&mut self) {
118        self.update = true;
119        self.user_zone.clear();
120    }
121
122    pub fn get_from_debug_zone<T: DeserializeOwned>(&mut self, key: &str) -> Result<Option<T>> {
123        if cfg!(debug_assertions) {
124            if let Some(val) = self.debug_zone.get(key) {
125                Ok(Some(ciborium::from_reader(val.as_slice())?))
126            } else {
127                Ok(None)
128            }
129        } else {
130            Ok(None)
131        }
132    }
133
134    pub fn insert_to_debug_zone<T: Serialize>(
135        &mut self,
136        key: impl Into<String>,
137        value: T,
138    ) -> Result<()> {
139        if cfg!(debug_assertions) {
140            self.update = true;
141            let mut buf = Vec::new();
142            ciborium::into_writer(&value, &mut buf)?;
143            self.debug_zone.insert(key.into(), buf);
144        }
145        Ok(())
146    }
147
148    pub fn remove_from_debug_zone(&mut self, key: &str) {
149        if cfg!(debug_assertions) {
150            self.update = true;
151            self.debug_zone.remove(key);
152        }
153    }
154
155    pub fn remove_from_debug_zone_as<T: DeserializeOwned>(
156        &mut self,
157        key: &str,
158    ) -> Option<Result<T>> {
159        if cfg!(debug_assertions) {
160            self.update = true;
161            self.debug_zone
162                .remove(key)
163                .map(|val| Ok(ciborium::from_reader(val.as_slice())?))
164        } else {
165            None
166        }
167    }
168
169    pub fn clear_debug_zone(&mut self) {
170        if cfg!(debug_assertions) {
171            self.update = true;
172            self.debug_zone.clear();
173        }
174    }
175}
176
177impl<Store: SessionStore + 'static> Session<Store> {
178    pub fn session_key(&self) -> Option<SessionKey> {
179        self.0.lock().unwrap().session_key.as_ref().cloned()
180    }
181
182    pub fn csrf_token(&self) -> Option<String> {
183        use std::fmt::Write;
184        self.0.lock().unwrap().session_key.as_ref().map(|v| {
185            Sha256::digest(String::from(v))
186                .iter()
187                .take(8)
188                .fold(String::new(), |mut output, x| {
189                    write!(output, "{:02X}", x).unwrap();
190                    output
191                })
192        })
193    }
194
195    pub fn contains_in_guest_zone(&self, key: &str) -> bool {
196        self.0.lock().unwrap().guest_zone.contains_key(key)
197    }
198
199    pub fn contains_in_user_zone(&self, key: &str) -> bool {
200        self.0.lock().unwrap().user_zone.contains_key(key)
201    }
202
203    pub fn contains_in_debug_zone(&self, key: &str) -> bool {
204        if cfg!(debug_assertions) {
205            self.0.lock().unwrap().debug_zone.contains_key(key)
206        } else {
207            false
208        }
209    }
210
211    pub fn get_from_guest_zone<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
212        if let Some(val) = self.0.lock().unwrap().guest_zone.get(key) {
213            Ok(Some(ciborium::from_reader(val.as_slice())?))
214        } else {
215            Ok(None)
216        }
217    }
218
219    pub fn get_from_user_zone<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
220        if let Some(val) = self.0.lock().unwrap().user_zone.get(key) {
221            Ok(Some(ciborium::from_reader(val.as_slice())?))
222        } else {
223            Ok(None)
224        }
225    }
226
227    pub fn get_from_debug_zone<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
228        if cfg!(debug_assertions) {
229            if let Some(val) = self.0.lock().unwrap().debug_zone.get(key) {
230                Ok(Some(ciborium::from_reader(val.as_slice())?))
231            } else {
232                Ok(None)
233            }
234        } else {
235            Ok(None)
236        }
237    }
238
239    pub fn keys_of_guest_zone(&self) -> Vec<String> {
240        self.0.lock().unwrap().guest_zone.keys().cloned().collect()
241    }
242
243    pub fn keys_of_user_zone(&self) -> Vec<String> {
244        self.0.lock().unwrap().user_zone.keys().cloned().collect()
245    }
246
247    pub fn keys_of_debug_zone(&self) -> Vec<String> {
248        self.0.lock().unwrap().debug_zone.keys().cloned().collect()
249    }
250
251    pub fn status(&self) -> SessionStatus {
252        self.0.lock().unwrap().status
253    }
254
255    pub async fn update<F, R>(&self, f: F) -> Result<R>
256    where
257        F: Fn(&mut SessionInner<Store>) -> Result<R>,
258    {
259        let mut retry_count = 0;
260        loop {
261            let (f_result, session_data, key, storage) = {
262                let mut inner = self.0.lock().unwrap();
263                let f_result = f(&mut inner);
264                if !inner.update {
265                    return f_result;
266                }
267                inner.update = false;
268                let list = vec![&inner.debug_zone, &inner.user_zone, &inner.guest_zone];
269                let mut buf = Vec::new();
270                ciborium::into_writer(&list, &mut buf)?;
271                let session_data = SessionData::from((buf, inner.state_ttl, inner.version));
272                let key = inner.session_key.as_ref().cloned();
273                let storage = Arc::clone(&inner.storage);
274                (f_result, session_data, key, storage)
275            };
276            match storage.save(key, session_data).await {
277                Ok(key) => {
278                    let mut inner = self.0.lock().unwrap();
279                    if !matches!(&inner.session_key, Some(x) if x == &key) {
280                        inner.status = SessionStatus::Changed;
281                        inner.session_key = Some(key);
282                    }
283                }
284                Err(SaveError::Retryable) => {
285                    retry_count += 1;
286                    if retry_count > MAX_RETRY_COUNT {
287                        bail!("too many session update retry");
288                    }
289                    tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
290                    self.reload().await?;
291                    continue;
292                }
293                Err(SaveError::RetryableWithData(data)) => {
294                    self.reload_from_data(data).await?;
295                    continue;
296                }
297                Err(SaveError::Other(e)) => {
298                    return Err(e);
299                }
300            }
301            return f_result;
302        }
303    }
304
305    pub async fn purge(&self) -> Result<()> {
306        let key = self.0.lock().unwrap().session_key.as_ref().cloned();
307        let storage = Arc::clone(&self.0.lock().unwrap().storage);
308        if key.is_some() {
309            storage.delete(key.as_ref().unwrap()).await?;
310        }
311        let mut inner = self.0.lock().unwrap();
312        inner.status = SessionStatus::Purged;
313        inner.session_key = None;
314        inner.guest_zone.clear();
315        inner.user_zone.clear();
316        inner.debug_zone.clear();
317        inner.version = 0;
318        Ok(())
319    }
320
321    /// Update session key.
322    /// The old session will not be deleted for the following reasons:
323    /// * If there are simultaneous accesses, the access that comes after the renew will create a new session and cancel the previous cookie.
324    /// * Consider cases where modified cookies cannot be received due to communication errors.
325    ///
326    /// Thus, since previous session data is not deleted, RENEW should not be used for logout.
327    pub async fn renew<F, R>(&self, f: F) -> Result<R>
328    where
329        F: Fn(&mut SessionInner<Store>) -> Result<R>,
330    {
331        self.reload().await?;
332        let mut retry_count = 0;
333        loop {
334            let (f_result, session_data, storage) = {
335                let mut inner = self.0.lock().unwrap();
336                let f_result = f(&mut inner);
337                inner.update = false;
338                let list = vec![&inner.debug_zone, &inner.user_zone, &inner.guest_zone];
339                let mut buf = Vec::new();
340                ciborium::into_writer(&list, &mut buf)?;
341                let session_data = SessionData::from((buf, inner.state_ttl, inner.version));
342                let storage = Arc::clone(&inner.storage);
343                (f_result, session_data, storage)
344            };
345            match storage.save(None, session_data).await {
346                Ok(key) => {
347                    let mut inner = self.0.lock().unwrap();
348                    inner.status = SessionStatus::Changed;
349                    inner.session_key = Some(key);
350                }
351                Err(SaveError::Retryable) => {
352                    retry_count += 1;
353                    if retry_count > MAX_RETRY_COUNT {
354                        bail!("too many session renew retry");
355                    }
356                    continue;
357                }
358                Err(SaveError::RetryableWithData(_)) => {
359                    bail!("unreachable error");
360                }
361                Err(SaveError::Other(e)) => {
362                    return Err(e);
363                }
364            }
365            return f_result;
366        }
367    }
368
369    pub(crate) async fn set_session(
370        req: &mut ServiceRequest,
371        session_key: Option<SessionKey>,
372        storage: Arc<Store>,
373        configuration: &Configuration,
374    ) -> Result<()> {
375        let (session_key, mut data) = if let Some(session_key) = session_key {
376            let data = storage.load(&session_key).await?;
377            if let Some(data) = data {
378                if data.ttl_as_duration().is_positive() {
379                    (Some(session_key), data)
380                } else {
381                    (None, SessionData::default())
382                }
383            } else {
384                (None, SessionData::default())
385            }
386        } else {
387            (None, SessionData::default())
388        };
389        let mut status = SessionStatus::Unchanged;
390        let ttl = configuration.session.state_ttl.whole_seconds();
391        if let Some(session_key) = &session_key {
392            if (ttl - data.ttl()) > (ttl >> 6) {
393                data.set_ttl(configuration.session.state_ttl);
394                let _ = storage.update_ttl(session_key, &data).await.map_err(|e| {
395                    log::warn!("{}", e);
396                });
397                if configuration.cookie.max_age.is_some() {
398                    status = SessionStatus::Changed;
399                }
400            }
401        }
402        let mut list: Vec<HashMap<String, Vec<u8>>> = if data.is_empty_data() {
403            Vec::new()
404        } else {
405            ciborium::from_reader(data.data())?
406        };
407        let inner = SessionInner::<Store> {
408            session_key,
409            guest_zone: list.pop().unwrap_or_default(),
410            user_zone: list.pop().unwrap_or_default(),
411            debug_zone: list.pop().unwrap_or_default(),
412            update: false,
413            status,
414            state_ttl: configuration.session.state_ttl,
415            version: data.version(),
416            storage,
417        };
418        let inner = Arc::new(Mutex::new(inner));
419        req.extensions_mut().insert(inner);
420        Ok(())
421    }
422
423    pub async fn reset(&self) {
424        let mut inner = self.0.lock().unwrap();
425        inner.status = SessionStatus::Unchanged;
426        inner.session_key = None;
427        inner.guest_zone.clear();
428        inner.user_zone.clear();
429        inner.debug_zone.clear();
430        inner.version = 0;
431    }
432
433    pub async fn load(&self, key: &str) -> Result<()> {
434        let mut session_key = Some(key.to_string().try_into()?);
435        let (data, mut list) = {
436            let storage = Arc::clone(&self.0.lock().unwrap().storage);
437            let data = {
438                let data = storage.load(session_key.as_ref().unwrap()).await?;
439                if let Some(data) = data {
440                    data
441                } else {
442                    session_key = None;
443                    SessionData::default()
444                }
445            };
446            let list: Vec<HashMap<String, Vec<u8>>> = if data.is_empty_data() {
447                Vec::new()
448            } else {
449                ciborium::from_reader(data.data())?
450            };
451            (data, list)
452        };
453        let mut inner = self.0.lock().unwrap();
454        inner.session_key = session_key;
455        inner.guest_zone = list.pop().unwrap_or_default();
456        inner.user_zone = list.pop().unwrap_or_default();
457        inner.debug_zone = list.pop().unwrap_or_default();
458        inner.version = data.version();
459        Ok(())
460    }
461
462    pub(crate) async fn reload(&self) -> Result<()> {
463        let (data, mut list) = {
464            let key = self.0.lock().unwrap().session_key.as_ref().cloned();
465            let storage = Arc::clone(&self.0.lock().unwrap().storage);
466            let data = if let Some(session_key) = key {
467                let data = storage.reload(&session_key).await?;
468                if let Some(data) = data {
469                    data
470                } else {
471                    self.0.lock().unwrap().session_key = None;
472                    return Ok(());
473                }
474            } else {
475                return Ok(());
476            };
477            let list: Vec<HashMap<String, Vec<u8>>> = if data.is_empty_data() {
478                Vec::new()
479            } else {
480                ciborium::from_reader(data.data())?
481            };
482            (data, list)
483        };
484        let mut inner = self.0.lock().unwrap();
485        inner.guest_zone = list.pop().unwrap_or_default();
486        inner.user_zone = list.pop().unwrap_or_default();
487        inner.debug_zone = list.pop().unwrap_or_default();
488        inner.version = data.version();
489        Ok(())
490    }
491
492    pub(crate) async fn reload_from_data(&self, data: SessionData) -> Result<()> {
493        let (data, mut list) = {
494            let list: Vec<HashMap<String, Vec<u8>>> = if data.is_empty_data() {
495                Vec::new()
496            } else {
497                ciborium::from_reader(data.data())?
498            };
499            (data, list)
500        };
501        let mut inner = self.0.lock().unwrap();
502        inner.guest_zone = list.pop().unwrap_or_default();
503        inner.user_zone = list.pop().unwrap_or_default();
504        inner.debug_zone = list.pop().unwrap_or_default();
505        inner.version = data.version();
506        Ok(())
507    }
508
509    pub(crate) fn get_status<B>(
510        res: &mut ServiceResponse<B>,
511    ) -> (SessionStatus, Option<SessionKey>) {
512        if let Some(s_impl) = res
513            .request()
514            .extensions()
515            .get::<Arc<Mutex<SessionInner<Store>>>>()
516        {
517            let session_key = mem::take(&mut s_impl.lock().unwrap().session_key);
518            (s_impl.lock().unwrap().status, session_key)
519        } else {
520            (SessionStatus::Unchanged, None)
521        }
522    }
523
524    pub(crate) fn get_session(extensions: &mut Extensions) -> Result<Session<Store>> {
525        let s_impl = extensions
526            .get::<Arc<Mutex<SessionInner<Store>>>>()
527            .with_context(|| "No session is set up.")?;
528        Ok(Session(Arc::clone(s_impl)))
529    }
530}
531
532impl<Store: SessionStore + 'static> std::fmt::Debug for Session<Store> {
533    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
534        let inner = self.0.lock().unwrap();
535        let list: Vec<HashMap<&String, ciborium::Value>> = vec![
536            inner
537                .guest_zone
538                .iter()
539                .map(|(k, v)| (k, ciborium::from_reader(v.as_slice()).unwrap()))
540                .collect(),
541            inner
542                .user_zone
543                .iter()
544                .map(|(k, v)| (k, ciborium::from_reader(v.as_slice()).unwrap()))
545                .collect(),
546            inner
547                .debug_zone
548                .iter()
549                .map(|(k, v)| (k, ciborium::from_reader(v.as_slice()).unwrap()))
550                .collect(),
551        ];
552        f.debug_tuple("Session")
553            .field(&inner.session_key)
554            .field(&list)
555            .finish()
556    }
557}
558
559impl<Store: SessionStore + 'static> FromRequest for Session<Store> {
560    type Error = Error;
561    type Future = Ready<Result<Session<Store>, Error>>;
562
563    #[inline]
564    fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
565        ready(Session::get_session(&mut req.extensions_mut()).map_err(e500))
566    }
567}