1use std::{error::Error as StdError, fmt};
36
37use async_trait::async_trait;
38use time::OffsetDateTime;
39
40use crate::SessionKey;
41
// Module-private shorthand: the error type defaults to this module's `Error`,
// so `Result<T>` means `std::result::Result<T, Error>`.
type Result<T, E = Error> = std::result::Result<T, E>;
43
44pub trait SessionStore<T>: 'static + Send + Sync + SessionStoreImpl<T> {}
55
56#[doc(hidden)]
59#[async_trait]
60pub trait SessionStoreImpl<T>: 'static + Send + Sync {
61 async fn create(&self, data: &T, ttl: Ttl) -> Result<SessionKey>;
62
63 async fn load(&self, session_key: &SessionKey) -> Result<Option<Record<T>>>;
64
65 async fn update(&self, session_key: &SessionKey, data: &T, ttl: Ttl) -> Result<()>;
66
67 async fn update_ttl(&self, session_key: &SessionKey, ttl: Ttl) -> Result<()>;
68
69 async fn delete(&self, session_key: &SessionKey) -> Result<()>;
70}
71
/// Session time-to-live, expressed as an absolute point in time.
///
/// Despite the name, this is an [`OffsetDateTime`] (a timestamp), not a duration.
pub type Ttl = OffsetDateTime;
77
78#[derive(Clone, Debug)]
80#[non_exhaustive]
81pub struct Record<T> {
82 pub data: T,
83 pub ttl: Ttl,
84}
85
86impl<T> Record<T> {
87 pub fn new(data: T, ttl: Ttl) -> Record<T> {
88 Record { data, ttl }
89 }
90
91 pub fn unix_timestamp(&self) -> i64 {
92 self.ttl.unix_timestamp()
93 }
94}
95
/// Error type returned by [`SessionStoreImpl`] operations.
///
/// Build one with [`Error::store`], [`Error::serde`] or [`Error::message`], and inspect it
/// with [`Error::kind`].
pub struct Error {
    // Not public; exposed read-only through `Error::kind`.
    kind: ErrorKind,
}
101
/// The category of a store [`Error`], as returned by `Error::kind`.
///
/// Derives `Debug` so callers that obtain it through `Error::kind` can print or log it.
#[derive(Debug)]
#[non_exhaustive]
pub enum ErrorKind {
    /// Failure reported by the session store backend (`Display`: "session store error").
    /// The wrapped error is the error's `source()`.
    Store(Box<dyn StdError + Send + Sync>),
    /// Failure serializing or deserializing session data
    /// (`Display`: "session serialization error"). The wrapped error is the error's `source()`.
    Serde(Box<dyn StdError + Send + Sync>),
    /// A free-form message with no underlying source error.
    Message(Box<str>),
}
112
113impl Error {
115 #[must_use]
118 pub fn store(err: impl Into<Box<dyn StdError + Send + Sync + 'static>>) -> Error {
119 Error {
120 kind: ErrorKind::Store(err.into()),
121 }
122 }
123
124 #[must_use]
127 pub fn serde(err: impl Into<Box<dyn StdError + Send + Sync + 'static>>) -> Error {
128 Error {
129 kind: ErrorKind::Serde(err.into()),
130 }
131 }
132
133 #[must_use]
135 pub fn message(msg: impl Into<Box<str>>) -> Error {
136 Error {
137 kind: ErrorKind::Message(msg.into()),
138 }
139 }
140
141 pub fn kind(&self) -> &ErrorKind {
143 &self.kind
144 }
145}
146
147impl fmt::Debug for Error {
148 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
149 let mut builder = f.debug_struct("store::Error");
150
151 use ErrorKind::*;
152 match &self.kind {
153 Message(msg) => {
154 builder.field("message", msg);
155 }
156 Store(err) => {
157 builder.field("kind", &"Store");
158 builder.field("source", err);
159 }
160 Serde(err) => {
161 builder.field("kind", &"Serde");
162 builder.field("source", err);
163 }
164 }
165
166 builder.finish()
167 }
168}
169
170impl fmt::Display for Error {
171 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
172 use ErrorKind::*;
173 match &self.kind {
174 Message(msg) => f.write_str(msg),
175 Store(_) => f.write_str("session store error"),
176 Serde(_) => f.write_str("session serialization error"),
177 }
178 }
179}
180
181impl StdError for Error {
182 fn source(&self) -> Option<&(dyn StdError + 'static)> {
183 use ErrorKind::*;
184 match &self.kind {
185 Message(_) => None,
186 Store(err) => Some(err.as_ref()),
187 Serde(err) => Some(err.as_ref()),
188 }
189 }
190}
191
#[cfg(test)]
mod test {
    use std::iter;

    use serde::Deserialize;

    use super::*;

    /// Test-only extension that renders an error followed by its `source()` chain.
    trait ErrorExt {
        fn display_chain(&self) -> DisplayChain<'_>;
    }

    impl<E> ErrorExt for E
    where
        E: StdError + 'static,
    {
        fn display_chain(&self) -> DisplayChain<'_> {
            DisplayChain { inner: self }
        }
    }

    /// `Display` adapter: writes `inner`, then each error in its `source()` chain, separated by `": "`.
    struct DisplayChain<'a> {
        inner: &'a (dyn StdError + 'static),
    }

    impl fmt::Display for DisplayChain<'_> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "{}", self.inner)?;

            // `successors` yields `inner` itself first; it was already written above, so skip it.
            for error in iter::successors(Some(self.inner), |err| (*err).source()).skip(1) {
                write!(f, ": {}", error)?;
            }

            Ok(())
        }
    }

    // Compile-time check that `SessionStore` stays usable as a trait object. The closure is
    // bound to an unnamed const and never called, so `todo!()` never runs.
    #[test]
    fn test_store_dyn_compatible() {
        use std::sync::Arc;

        const _: fn() = || {
            let _dyn_store: Arc<dyn SessionStore<()>> = todo!();
        };
    }

    // `Error` must be `Send + Sync + 'static`.
    #[test]
    fn test_error_constraints() {
        fn require_traits<T: Send + Sync + 'static>() {}

        require_traits::<Error>();
    }

    // Store error built from a plain string message.
    fn error_store() -> Error {
        let err = "Reconnecting failed: Connection refused (os error 111)";
        Error::store(err)
    }

    // Serde error from deliberately truncated JSON (the string value is never closed).
    fn error_serde() -> Error {
        #[derive(Debug, Deserialize)]
        struct Data {
            // Only used as a deserialization target; the value is never read.
            #[allow(dead_code)]
            hello: String,
        }

        let err = serde_json::from_str::<Data>(r#"{"hello": "world}"#).unwrap_err();
        Error::serde(err)
    }

    fn error_msg() -> Error {
        Error::message("max iterations reached when handling session key collisions")
    }

    // `Display` shows only the summary; `display_chain` appends the `source()` chain.
    #[test]
    fn test_error_display() {
        insta::assert_snapshot!(error_store(), @"session store error");
        insta::assert_snapshot!(
            error_store().display_chain(),
            @"session store error: Reconnecting failed: Connection refused (os error 111)"
        );
        insta::assert_snapshot!(error_serde(), @"session serialization error");
        insta::assert_snapshot!(
            error_serde().display_chain(),
            @"session serialization error: EOF while parsing a string at line 1 column 17"
        );
        insta::assert_snapshot!(error_msg(), @"max iterations reached when handling session key collisions");
    }

    // Pins the hand-written `Debug` layout for each error kind.
    #[test]
    fn test_error_debug() {
        insta::assert_debug_snapshot!( error_store(), @r#"
        store::Error {
            kind: "Store",
            source: "Reconnecting failed: Connection refused (os error 111)",
        }
        "#
        );
        insta::assert_debug_snapshot!(error_serde(), @r#"
        store::Error {
            kind: "Serde",
            source: Error("EOF while parsing a string", line: 1, column: 17),
        }
        "#
        );
        insta::assert_debug_snapshot!(error_msg(), @r#"
        store::Error {
            message: "max iterations reached when handling session key collisions",
        }
        "#);
    }
}