1use std::{error::Error, fmt, sync::Arc, time::SystemTimeError};
2
3use serde::{Serialize, de::DeserializeOwned};
4use serde_json::Error as JsonError;
5use tokio::sync::Mutex;
6
7use crate::{
8 backend::SessionBackend,
9 utils::{decode_value, encode_value},
10 value::{Value, ValueRef},
11};
12
/// A handle to the data of a single session, identified by `id`, stored in a shared backend.
///
/// Cloning a handle copies the session ID and shares the same backend instance
/// (only the `Arc` reference count is bumped).
pub struct Session<B> {
    /// Identifier passed to the backend on every operation.
    id: String,
    /// Backend shared between all clones of this handle.
    backend: Arc<Mutex<B>>,
}

// Implemented by hand instead of `#[derive(Clone)]`: the derive would add a `B: Clone` bound,
// although the backend itself is never cloned — only the `Arc` pointing to it.
impl<B> Clone for Session<B> {
    fn clone(&self) -> Self {
        Self {
            id: self.id.clone(),
            backend: Arc::clone(&self.backend),
        }
    }
}
19
20impl<B> Session<B>
21where
22 B: SessionBackend,
23{
24 pub(crate) fn new<I>(id: I, backend: Arc<Mutex<B>>) -> Self
25 where
26 I: Into<String>,
27 {
28 Self { id: id.into(), backend }
29 }
30
31 async fn read_value(&mut self, key: &str) -> Result<Option<Value>, SessionError> {
32 let mut backend = self.backend.lock().await;
33 match backend
34 .read_value(&self.id, key.as_ref())
35 .await
36 .map_err(SessionError::backend)?
37 {
38 Some(value) => {
39 let value = decode_value(&value).map_err(SessionError::DecodeValue)?;
40 Ok(Some(value))
41 }
42 None => Ok(None),
43 }
44 }
45
46 async fn write_value<V: Serialize>(&mut self, key: &str, value: V) -> Result<(), SessionError> {
47 let mut backend = self.backend.lock().await;
48 let data = encode_value(&value).map_err(SessionError::EncodeValue)?;
49 backend
50 .write_value(&self.id, key.as_ref(), &data)
51 .await
52 .map_err(SessionError::backend)?;
53 Ok(())
54 }
55
56 pub async fn set<K, V>(&mut self, key: K, value: &V) -> Result<(), SessionError>
58 where
59 K: AsRef<str>,
60 V: Serialize,
61 {
62 let key = key.as_ref();
63 let mut value = ValueRef::new(&value);
64 if let Some(old_value) = self.read_value(key).await?
65 && !old_value.is_expired().map_err(SessionError::CheckExpired)?
66 && let Some(expires_at) = old_value.get_expires_at()
67 {
68 value.set_expires_at(expires_at);
69 };
70 self.write_value(key, value).await?;
71 Ok(())
72 }
73
74 pub async fn get<K, O>(&mut self, key: K) -> Result<Option<O>, SessionError>
76 where
77 K: AsRef<str>,
78 O: DeserializeOwned,
79 {
80 Ok(
81 if let Some(value) = self.read_value(key.as_ref()).await.map_err(SessionError::backend)? {
82 if value.is_expired().map_err(SessionError::CheckExpired)? {
83 None
84 } else {
85 Some(value.into_parsed().map_err(SessionError::ParseValue)?)
86 }
87 } else {
88 None
89 },
90 )
91 }
92
93 pub async fn expire<K>(&mut self, key: K, seconds: u64) -> Result<(), SessionError>
95 where
96 K: AsRef<str>,
97 {
98 let key = key.as_ref();
99 if let Some(mut value) = self.read_value(key).await.map_err(SessionError::backend)? {
100 value.set_lifetime(seconds).map_err(SessionError::ExpireValue)?;
101 self.write_value(key, value).await.map_err(SessionError::backend)?;
102 }
103 Ok(())
104 }
105
106 pub async fn remove<K>(&mut self, key: K) -> Result<(), SessionError>
108 where
109 K: AsRef<str>,
110 {
111 let mut backend = self.backend.lock().await;
112 backend
113 .remove_value(&self.id, key.as_ref())
114 .await
115 .map_err(SessionError::backend)
116 }
117}
118
/// An error that can occur while working with a session.
#[derive(Debug)]
pub enum SessionError {
    /// An error reported by the storage backend, boxed to erase its concrete type.
    Backend(Box<dyn Error + Send + Sync>),
    /// Failed to check whether a stored value has expired.
    CheckExpired(SystemTimeError),
    /// Failed to decode the data read from the backend.
    DecodeValue(JsonError),
    /// Failed to encode a value before writing it to the backend.
    EncodeValue(JsonError),
    /// Failed to set the lifetime of a stored value.
    ExpireValue(SystemTimeError),
    /// Failed to parse a stored value into the requested type.
    ParseValue(JsonError),
}
135
136impl SessionError {
137 fn backend<E: Error + Send + Sync + 'static>(err: E) -> Self {
138 Self::Backend(Box::new(err))
139 }
140}
141
142impl Error for SessionError {
143 fn source(&self) -> Option<&(dyn Error + 'static)> {
144 match self {
145 SessionError::Backend(err) => Some(err.as_ref()),
146 SessionError::CheckExpired(err) => Some(err),
147 SessionError::DecodeValue(err) => Some(err),
148 SessionError::EncodeValue(err) => Some(err),
149 SessionError::ExpireValue(err) => Some(err),
150 SessionError::ParseValue(err) => Some(err),
151 }
152 }
153}
154
155impl fmt::Display for SessionError {
156 fn fmt(&self, out: &mut fmt::Formatter) -> fmt::Result {
157 match self {
158 SessionError::Backend(err) => write!(out, "backend error: {err}"),
159 SessionError::CheckExpired(err) => {
160 write!(out, "failed to check whether value expired: {err}")
161 }
162 SessionError::DecodeValue(err) => write!(out, "failed to decode value: {err}"),
163 SessionError::EncodeValue(err) => write!(out, "failed to encode value: {err}"),
164 SessionError::ExpireValue(err) => write!(out, "failed to expire value: {err}"),
165 SessionError::ParseValue(err) => write!(out, "failed to parse value: {err}"),
166 }
167 }
168}