1use actix_utils::future::{Ready, ready};
2use actix_web::{
3 FromRequest, HttpMessage, HttpRequest,
4 dev::{Extensions, Payload, ServiceRequest, ServiceResponse},
5 error::Error,
6};
7use anyhow::{Context, Result, bail};
8use senax_common::session::SessionKey;
9use senax_common::session::interface::{SaveError, SessionData, SessionStore};
10use serde::{Serialize, de::DeserializeOwned};
11use sha2::{Digest, Sha256};
12use std::{collections::HashMap, mem, sync::Arc, sync::Mutex};
13use time::Duration;
14
15use crate::{config::Configuration, middleware::e500};
16
/// Upper bound on how many times `Session::update` / `Session::renew` retry a
/// session save that the store reported as retryable before giving up with an error.
const MAX_RETRY_COUNT: usize = 10;
18
19#[derive(Clone)]
20pub struct Session<Store: SessionStore + 'static>(Arc<Mutex<SessionInner<Store>>>);
21
/// State of the session as seen by the response side of the middleware
/// (reported through `Session::get_status`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionStatus {
    /// No key change recorded: the initial state, and the state after `Session::reset`.
    Unchanged,
    /// A save assigned a new session key, or the TTL was refreshed while a cookie
    /// `max_age` is configured (presumably so the cookie is re-sent — confirm in the middleware).
    Changed,
    /// The session was deleted via `Session::purge`.
    Purged,
}
28
/// Mutable session state shared behind a [`Session`] handle.
///
/// Each zone maps a key to a CBOR-encoded value (encoded/decoded with `ciborium`
/// by the accessor methods). All three zones are persisted together as one list.
pub struct SessionInner<Store: SessionStore + 'static> {
    /// Key of the persisted session; `None` when no stored session is associated.
    session_key: Option<SessionKey>,
    /// Guest-zone values, CBOR-encoded.
    guest_zone: HashMap<String, Vec<u8>>,
    /// User-zone values, CBOR-encoded.
    user_zone: HashMap<String, Vec<u8>>,
    /// Debug-zone values, CBOR-encoded; the accessors only touch it when
    /// `debug_assertions` is enabled.
    debug_zone: HashMap<String, Vec<u8>>,
    /// Dirty flag set by the mutating accessors; `Session::update` saves only when set.
    update: bool,
    /// Status reported back to the middleware via `Session::get_status`.
    status: SessionStatus,
    /// TTL handed to the store together with the serialized zones on save.
    state_ttl: Duration,
    /// Version read from the store and written back on save
    /// (presumably for optimistic concurrency — confirm against the `SessionStore` impl).
    version: u32,
    /// Backing session store.
    storage: Arc<Store>,
}
40
41impl<Store: SessionStore + 'static> SessionInner<Store> {
42 pub fn get_from_guest_zone<T: DeserializeOwned>(&mut self, key: &str) -> Result<Option<T>> {
43 if let Some(val) = self.guest_zone.get(key) {
44 Ok(Some(ciborium::from_reader(val.as_slice())?))
45 } else {
46 Ok(None)
47 }
48 }
49
50 pub fn insert_to_guest_zone<T: Serialize>(
51 &mut self,
52 key: impl Into<String>,
53 value: T,
54 ) -> Result<()> {
55 self.update = true;
56 let mut buf = Vec::new();
57 ciborium::into_writer(&value, &mut buf)?;
58 self.guest_zone.insert(key.into(), buf);
59 Ok(())
60 }
61
62 pub fn remove_from_guest_zone(&mut self, key: &str) {
63 self.update = true;
64 self.guest_zone.remove(key);
65 }
66
67 pub fn remove_from_guest_zone_as<T: DeserializeOwned>(
68 &mut self,
69 key: &str,
70 ) -> Option<Result<T>> {
71 self.update = true;
72 self.guest_zone
73 .remove(key)
74 .map(|val| Ok(ciborium::from_reader(val.as_slice())?))
75 }
76
77 pub fn clear_guest_zone(&mut self) {
78 self.update = true;
79 self.guest_zone.clear();
80 }
81
82 pub fn get_from_user_zone<T: DeserializeOwned>(&mut self, key: &str) -> Result<Option<T>> {
83 if let Some(val) = self.user_zone.get(key) {
84 Ok(Some(ciborium::from_reader(val.as_slice())?))
85 } else {
86 Ok(None)
87 }
88 }
89
90 pub fn insert_to_user_zone<T: Serialize>(
91 &mut self,
92 key: impl Into<String>,
93 value: T,
94 ) -> Result<()> {
95 self.update = true;
96 let mut buf = Vec::new();
97 ciborium::into_writer(&value, &mut buf)?;
98 self.user_zone.insert(key.into(), buf);
99 Ok(())
100 }
101
102 pub fn remove_from_user_zone(&mut self, key: &str) {
103 self.update = true;
104 self.user_zone.remove(key);
105 }
106
107 pub fn remove_from_user_zone_as<T: DeserializeOwned>(
108 &mut self,
109 key: &str,
110 ) -> Option<Result<T>> {
111 self.update = true;
112 self.user_zone
113 .remove(key)
114 .map(|val| Ok(ciborium::from_reader(val.as_slice())?))
115 }
116
117 pub fn clear_user_zone(&mut self) {
118 self.update = true;
119 self.user_zone.clear();
120 }
121
122 pub fn get_from_debug_zone<T: DeserializeOwned>(&mut self, key: &str) -> Result<Option<T>> {
123 if cfg!(debug_assertions) {
124 if let Some(val) = self.debug_zone.get(key) {
125 Ok(Some(ciborium::from_reader(val.as_slice())?))
126 } else {
127 Ok(None)
128 }
129 } else {
130 Ok(None)
131 }
132 }
133
134 pub fn insert_to_debug_zone<T: Serialize>(
135 &mut self,
136 key: impl Into<String>,
137 value: T,
138 ) -> Result<()> {
139 if cfg!(debug_assertions) {
140 self.update = true;
141 let mut buf = Vec::new();
142 ciborium::into_writer(&value, &mut buf)?;
143 self.debug_zone.insert(key.into(), buf);
144 }
145 Ok(())
146 }
147
148 pub fn remove_from_debug_zone(&mut self, key: &str) {
149 if cfg!(debug_assertions) {
150 self.update = true;
151 self.debug_zone.remove(key);
152 }
153 }
154
155 pub fn remove_from_debug_zone_as<T: DeserializeOwned>(
156 &mut self,
157 key: &str,
158 ) -> Option<Result<T>> {
159 if cfg!(debug_assertions) {
160 self.update = true;
161 self.debug_zone
162 .remove(key)
163 .map(|val| Ok(ciborium::from_reader(val.as_slice())?))
164 } else {
165 None
166 }
167 }
168
169 pub fn clear_debug_zone(&mut self) {
170 if cfg!(debug_assertions) {
171 self.update = true;
172 self.debug_zone.clear();
173 }
174 }
175}
176
177impl<Store: SessionStore + 'static> Session<Store> {
178 pub fn session_key(&self) -> Option<SessionKey> {
179 self.0.lock().unwrap().session_key.as_ref().cloned()
180 }
181
182 pub fn csrf_token(&self) -> Option<String> {
183 use std::fmt::Write;
184 self.0.lock().unwrap().session_key.as_ref().map(|v| {
185 Sha256::digest(String::from(v))
186 .iter()
187 .take(8)
188 .fold(String::new(), |mut output, x| {
189 write!(output, "{:02X}", x).unwrap();
190 output
191 })
192 })
193 }
194
195 pub fn contains_in_guest_zone(&self, key: &str) -> bool {
196 self.0.lock().unwrap().guest_zone.contains_key(key)
197 }
198
199 pub fn contains_in_user_zone(&self, key: &str) -> bool {
200 self.0.lock().unwrap().user_zone.contains_key(key)
201 }
202
203 pub fn contains_in_debug_zone(&self, key: &str) -> bool {
204 if cfg!(debug_assertions) {
205 self.0.lock().unwrap().debug_zone.contains_key(key)
206 } else {
207 false
208 }
209 }
210
211 pub fn get_from_guest_zone<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
212 if let Some(val) = self.0.lock().unwrap().guest_zone.get(key) {
213 Ok(Some(ciborium::from_reader(val.as_slice())?))
214 } else {
215 Ok(None)
216 }
217 }
218
219 pub fn get_from_user_zone<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
220 if let Some(val) = self.0.lock().unwrap().user_zone.get(key) {
221 Ok(Some(ciborium::from_reader(val.as_slice())?))
222 } else {
223 Ok(None)
224 }
225 }
226
227 pub fn get_from_debug_zone<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
228 if cfg!(debug_assertions) {
229 if let Some(val) = self.0.lock().unwrap().debug_zone.get(key) {
230 Ok(Some(ciborium::from_reader(val.as_slice())?))
231 } else {
232 Ok(None)
233 }
234 } else {
235 Ok(None)
236 }
237 }
238
239 pub fn keys_of_guest_zone(&self) -> Vec<String> {
240 self.0.lock().unwrap().guest_zone.keys().cloned().collect()
241 }
242
243 pub fn keys_of_user_zone(&self) -> Vec<String> {
244 self.0.lock().unwrap().user_zone.keys().cloned().collect()
245 }
246
247 pub fn keys_of_debug_zone(&self) -> Vec<String> {
248 self.0.lock().unwrap().debug_zone.keys().cloned().collect()
249 }
250
251 pub fn status(&self) -> SessionStatus {
252 self.0.lock().unwrap().status
253 }
254
255 pub async fn update<F, R>(&self, f: F) -> Result<R>
256 where
257 F: Fn(&mut SessionInner<Store>) -> Result<R>,
258 {
259 let mut retry_count = 0;
260 loop {
261 let (f_result, session_data, key, storage) = {
262 let mut inner = self.0.lock().unwrap();
263 let f_result = f(&mut inner);
264 if !inner.update {
265 return f_result;
266 }
267 inner.update = false;
268 let list = vec![&inner.debug_zone, &inner.user_zone, &inner.guest_zone];
269 let mut buf = Vec::new();
270 ciborium::into_writer(&list, &mut buf)?;
271 let session_data = SessionData::from((buf, inner.state_ttl, inner.version));
272 let key = inner.session_key.as_ref().cloned();
273 let storage = Arc::clone(&inner.storage);
274 (f_result, session_data, key, storage)
275 };
276 match storage.save(key, session_data).await {
277 Ok(key) => {
278 let mut inner = self.0.lock().unwrap();
279 if !matches!(&inner.session_key, Some(x) if x == &key) {
280 inner.status = SessionStatus::Changed;
281 inner.session_key = Some(key);
282 }
283 }
284 Err(SaveError::Retryable) => {
285 retry_count += 1;
286 if retry_count > MAX_RETRY_COUNT {
287 bail!("too many session update retry");
288 }
289 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
290 self.reload().await?;
291 continue;
292 }
293 Err(SaveError::RetryableWithData(data)) => {
294 self.reload_from_data(data).await?;
295 continue;
296 }
297 Err(SaveError::Other(e)) => {
298 return Err(e);
299 }
300 }
301 return f_result;
302 }
303 }
304
305 pub async fn purge(&self) -> Result<()> {
306 let key = self.0.lock().unwrap().session_key.as_ref().cloned();
307 let storage = Arc::clone(&self.0.lock().unwrap().storage);
308 if key.is_some() {
309 storage.delete(key.as_ref().unwrap()).await?;
310 }
311 let mut inner = self.0.lock().unwrap();
312 inner.status = SessionStatus::Purged;
313 inner.session_key = None;
314 inner.guest_zone.clear();
315 inner.user_zone.clear();
316 inner.debug_zone.clear();
317 inner.version = 0;
318 Ok(())
319 }
320
321 pub async fn renew<F, R>(&self, f: F) -> Result<R>
328 where
329 F: Fn(&mut SessionInner<Store>) -> Result<R>,
330 {
331 self.reload().await?;
332 let mut retry_count = 0;
333 loop {
334 let (f_result, session_data, storage) = {
335 let mut inner = self.0.lock().unwrap();
336 let f_result = f(&mut inner);
337 inner.update = false;
338 let list = vec![&inner.debug_zone, &inner.user_zone, &inner.guest_zone];
339 let mut buf = Vec::new();
340 ciborium::into_writer(&list, &mut buf)?;
341 let session_data = SessionData::from((buf, inner.state_ttl, inner.version));
342 let storage = Arc::clone(&inner.storage);
343 (f_result, session_data, storage)
344 };
345 match storage.save(None, session_data).await {
346 Ok(key) => {
347 let mut inner = self.0.lock().unwrap();
348 inner.status = SessionStatus::Changed;
349 inner.session_key = Some(key);
350 }
351 Err(SaveError::Retryable) => {
352 retry_count += 1;
353 if retry_count > MAX_RETRY_COUNT {
354 bail!("too many session renew retry");
355 }
356 continue;
357 }
358 Err(SaveError::RetryableWithData(_)) => {
359 bail!("unreachable error");
360 }
361 Err(SaveError::Other(e)) => {
362 return Err(e);
363 }
364 }
365 return f_result;
366 }
367 }
368
369 pub(crate) async fn set_session(
370 req: &mut ServiceRequest,
371 session_key: Option<SessionKey>,
372 storage: Arc<Store>,
373 configuration: &Configuration,
374 ) -> Result<()> {
375 let (session_key, mut data) = if let Some(session_key) = session_key {
376 let data = storage.load(&session_key).await?;
377 if let Some(data) = data {
378 if data.ttl_as_duration().is_positive() {
379 (Some(session_key), data)
380 } else {
381 (None, SessionData::default())
382 }
383 } else {
384 (None, SessionData::default())
385 }
386 } else {
387 (None, SessionData::default())
388 };
389 let mut status = SessionStatus::Unchanged;
390 let ttl = configuration.session.state_ttl.whole_seconds();
391 if let Some(session_key) = &session_key {
392 if (ttl - data.ttl()) > (ttl >> 6) {
393 data.set_ttl(configuration.session.state_ttl);
394 let _ = storage.update_ttl(session_key, &data).await.map_err(|e| {
395 log::warn!("{}", e);
396 });
397 if configuration.cookie.max_age.is_some() {
398 status = SessionStatus::Changed;
399 }
400 }
401 }
402 let mut list: Vec<HashMap<String, Vec<u8>>> = if data.is_empty_data() {
403 Vec::new()
404 } else {
405 ciborium::from_reader(data.data())?
406 };
407 let inner = SessionInner::<Store> {
408 session_key,
409 guest_zone: list.pop().unwrap_or_default(),
410 user_zone: list.pop().unwrap_or_default(),
411 debug_zone: list.pop().unwrap_or_default(),
412 update: false,
413 status,
414 state_ttl: configuration.session.state_ttl,
415 version: data.version(),
416 storage,
417 };
418 let inner = Arc::new(Mutex::new(inner));
419 req.extensions_mut().insert(inner);
420 Ok(())
421 }
422
423 pub async fn reset(&self) {
424 let mut inner = self.0.lock().unwrap();
425 inner.status = SessionStatus::Unchanged;
426 inner.session_key = None;
427 inner.guest_zone.clear();
428 inner.user_zone.clear();
429 inner.debug_zone.clear();
430 inner.version = 0;
431 }
432
433 pub async fn load(&self, key: &str) -> Result<()> {
434 let mut session_key = Some(key.to_string().try_into()?);
435 let (data, mut list) = {
436 let storage = Arc::clone(&self.0.lock().unwrap().storage);
437 let data = {
438 let data = storage.load(session_key.as_ref().unwrap()).await?;
439 if let Some(data) = data {
440 data
441 } else {
442 session_key = None;
443 SessionData::default()
444 }
445 };
446 let list: Vec<HashMap<String, Vec<u8>>> = if data.is_empty_data() {
447 Vec::new()
448 } else {
449 ciborium::from_reader(data.data())?
450 };
451 (data, list)
452 };
453 let mut inner = self.0.lock().unwrap();
454 inner.session_key = session_key;
455 inner.guest_zone = list.pop().unwrap_or_default();
456 inner.user_zone = list.pop().unwrap_or_default();
457 inner.debug_zone = list.pop().unwrap_or_default();
458 inner.version = data.version();
459 Ok(())
460 }
461
462 pub(crate) async fn reload(&self) -> Result<()> {
463 let (data, mut list) = {
464 let key = self.0.lock().unwrap().session_key.as_ref().cloned();
465 let storage = Arc::clone(&self.0.lock().unwrap().storage);
466 let data = if let Some(session_key) = key {
467 let data = storage.reload(&session_key).await?;
468 if let Some(data) = data {
469 data
470 } else {
471 self.0.lock().unwrap().session_key = None;
472 return Ok(());
473 }
474 } else {
475 return Ok(());
476 };
477 let list: Vec<HashMap<String, Vec<u8>>> = if data.is_empty_data() {
478 Vec::new()
479 } else {
480 ciborium::from_reader(data.data())?
481 };
482 (data, list)
483 };
484 let mut inner = self.0.lock().unwrap();
485 inner.guest_zone = list.pop().unwrap_or_default();
486 inner.user_zone = list.pop().unwrap_or_default();
487 inner.debug_zone = list.pop().unwrap_or_default();
488 inner.version = data.version();
489 Ok(())
490 }
491
492 pub(crate) async fn reload_from_data(&self, data: SessionData) -> Result<()> {
493 let (data, mut list) = {
494 let list: Vec<HashMap<String, Vec<u8>>> = if data.is_empty_data() {
495 Vec::new()
496 } else {
497 ciborium::from_reader(data.data())?
498 };
499 (data, list)
500 };
501 let mut inner = self.0.lock().unwrap();
502 inner.guest_zone = list.pop().unwrap_or_default();
503 inner.user_zone = list.pop().unwrap_or_default();
504 inner.debug_zone = list.pop().unwrap_or_default();
505 inner.version = data.version();
506 Ok(())
507 }
508
509 pub(crate) fn get_status<B>(
510 res: &mut ServiceResponse<B>,
511 ) -> (SessionStatus, Option<SessionKey>) {
512 if let Some(s_impl) = res
513 .request()
514 .extensions()
515 .get::<Arc<Mutex<SessionInner<Store>>>>()
516 {
517 let session_key = mem::take(&mut s_impl.lock().unwrap().session_key);
518 (s_impl.lock().unwrap().status, session_key)
519 } else {
520 (SessionStatus::Unchanged, None)
521 }
522 }
523
524 pub(crate) fn get_session(extensions: &mut Extensions) -> Result<Session<Store>> {
525 let s_impl = extensions
526 .get::<Arc<Mutex<SessionInner<Store>>>>()
527 .with_context(|| "No session is set up.")?;
528 Ok(Session(Arc::clone(s_impl)))
529 }
530}
531
532impl<Store: SessionStore + 'static> std::fmt::Debug for Session<Store> {
533 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
534 let inner = self.0.lock().unwrap();
535 let list: Vec<HashMap<&String, ciborium::Value>> = vec![
536 inner
537 .guest_zone
538 .iter()
539 .map(|(k, v)| (k, ciborium::from_reader(v.as_slice()).unwrap()))
540 .collect(),
541 inner
542 .user_zone
543 .iter()
544 .map(|(k, v)| (k, ciborium::from_reader(v.as_slice()).unwrap()))
545 .collect(),
546 inner
547 .debug_zone
548 .iter()
549 .map(|(k, v)| (k, ciborium::from_reader(v.as_slice()).unwrap()))
550 .collect(),
551 ];
552 f.debug_tuple("Session")
553 .field(&inner.session_key)
554 .field(&list)
555 .finish()
556 }
557}
558
559impl<Store: SessionStore + 'static> FromRequest for Session<Store> {
560 type Error = Error;
561 type Future = Ready<Result<Session<Store>, Error>>;
562
563 #[inline]
564 fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
565 ready(Session::get_session(&mut req.extensions_mut()).map_err(e500))
566 }
567}