use async_trait::async_trait;
use bson::{doc, to_document};
use mongodb::{options::UpdateOptions, Client, Collection};
use serde::{Deserialize, Serialize};
use time::OffsetDateTime;
use crate::{session::Id, ExpiredDeletion, Session, SessionStore};
/// Errors that can be returned by [`MongoDBStore`] operations.
///
/// Every variant wraps an underlying library error via `#[from]`, so `?` converts those
/// errors into this type automatically.
#[derive(thiserror::Error, Debug)]
pub enum MongoDBStoreError {
    /// A MongoDB driver call (`find_one`, `update_one`, `delete_one`, `delete_many`) failed.
    #[error("MongoDB error: {0}")]
    MongoDB(#[from] mongodb::error::Error),
    /// Converting a session record into a BSON document (`bson::to_document`) failed.
    #[error("Bson serialize error: {0}")]
    BsonSerialize(#[from] bson::ser::Error),
    /// Wraps a BSON deserialization error.
    // NOTE(review): no code path visible in this file produces a `bson::de::Error` directly;
    // confirm whether this variant is still needed.
    #[error("Bson deserialize error: {0}")]
    BsonDeserialize(#[from] bson::de::Error),
    /// MessagePack-encoding a session (`rmp_serde::to_vec`) failed.
    #[error("Rust MsgPack encode error: {0}")]
    RmpSerdeEncode(#[from] rmp_serde::encode::Error),
    /// MessagePack-decoding a stored session (`rmp_serde::from_slice`) failed.
    #[error("Rust MsgPack decode error: {0}")]
    RmpSerdeDecode(#[from] rmp_serde::decode::Error),
}
/// Shape of one session document as stored in the collection.
///
/// The session id is stored as the document's `_id` by the store's filters; `_id` is not a
/// field of this struct.
#[derive(Serialize, Deserialize, Debug)]
struct MongoDBSessionRecord {
    /// The complete session, MessagePack-encoded with `rmp_serde` and stored as BSON binary.
    data: bson::Binary,
    /// When the session expires; persisted under the key `expireAt`, which is the field the
    /// store's load and expired-deletion queries filter on.
    // NOTE(review): `expireAt` is the conventional field for a MongoDB TTL index, but no
    // index is created in this file — confirm whether one is expected to exist.
    #[serde(rename = "expireAt")]
    expiry_date: bson::DateTime,
}
/// A session store that persists sessions as documents in a MongoDB collection.
#[derive(Clone, Debug)]
pub struct MongoDBStore {
    /// Typed handle to the collection holding one document per session.
    collection: Collection<MongoDBSessionRecord>,
}
impl MongoDBStore {
    /// Builds a store that keeps its sessions in the `sessions` collection of the
    /// database named `database`, reached through `client`.
    pub fn new(client: Client, database: String) -> Self {
        let collection = client.database(&database).collection("sessions");
        Self { collection }
    }
}
#[async_trait]
impl ExpiredDeletion for MongoDBStore {
    /// Removes every stored session whose `expireAt` is at or before the current UTC time.
    ///
    /// The cutoff uses `$lte` so that it is the exact complement of the liveness filter in
    /// this store's `load`, which only returns records with `expireAt` strictly after now.
    /// With `$lt`, a record expiring at exactly "now" would be unloadable yet left in place.
    ///
    /// # Errors
    ///
    /// Returns [`MongoDBStoreError::MongoDB`] if the `delete_many` call fails.
    async fn delete_expired(&self) -> Result<(), Self::Error> {
        self.collection
            .delete_many(
                doc! { "expireAt": { "$lte": OffsetDateTime::now_utc() } },
                None,
            )
            .await?;
        Ok(())
    }
}
#[async_trait]
impl SessionStore for MongoDBStore {
    type Error = MongoDBStoreError;

    /// Writes `session` to the collection, creating its document if none exists yet.
    ///
    /// The whole session is MessagePack-encoded (`rmp_serde`) into a generic BSON binary
    /// stored under `data`, and its expiry is stored under `expireAt`. Both fields are
    /// applied with `$set` to the document whose `_id` is the session id's string form.
    ///
    /// # Errors
    ///
    /// Fails if MessagePack encoding, BSON serialization, or the MongoDB update fails.
    async fn save(&self, session: &Session) -> Result<(), Self::Error> {
        let encoded = rmp_serde::to_vec(session)?;
        let record = MongoDBSessionRecord {
            data: bson::Binary {
                subtype: bson::spec::BinarySubtype::Generic,
                bytes: encoded,
            },
            expiry_date: bson::DateTime::from(session.expiry_date()),
        };
        let fields = to_document(&record)?;

        let filter = doc! { "_id": session.id().to_string() };
        let update = doc! { "$set": fields };
        let upsert = UpdateOptions::builder().upsert(true).build();

        self.collection.update_one(filter, update, upsert).await?;
        Ok(())
    }

    /// Fetches the session stored under `session_id`, if it has not yet expired.
    ///
    /// Only a document whose `expireAt` is strictly later than the current UTC time
    /// matches. Returns `Ok(None)` when nothing matches.
    ///
    /// # Errors
    ///
    /// Fails if the MongoDB query fails or the stored bytes cannot be MessagePack-decoded.
    async fn load(&self, session_id: &Id) -> Result<Option<Session>, Self::Error> {
        let filter = doc! {
            "_id": session_id.to_string(),
            "expireAt": { "$gt": OffsetDateTime::now_utc() }
        };
        let record = self.collection.find_one(filter, None).await?;

        // Decode only if a live record was found; `transpose` lifts the decode error out.
        let session: Option<Session> = record
            .map(|found| rmp_serde::from_slice(&found.data.bytes))
            .transpose()?;
        Ok(session)
    }

    /// Removes the document stored under `session_id`.
    ///
    /// The driver's delete result is discarded, so whether a document actually matched is
    /// not reported to the caller.
    ///
    /// # Errors
    ///
    /// Fails if the MongoDB delete fails.
    async fn delete(&self, session_id: &Id) -> Result<(), Self::Error> {
        let filter = doc! { "_id": session_id.to_string() };
        self.collection.delete_one(filter, None).await?;
        Ok(())
    }
}