use crate::serializer::Serializer;
use crate::Savepointable;
use crate::{db, identifier::Identifier};
use rusqlite::{params, Connection, OptionalExtension, Savepoint};
use std::{borrow::Borrow, marker::PhantomData};
mod error;
mod iter;
pub use iter::Iter;
pub use error::{Error, OpenError};
/// Where a [`Set`]'s backing table lives: the schema (`database`) and the
/// table name inside it. Both are interpolated as `{database}.{table}` into
/// the SQL the set issues.
// NOTE: field order matters — the derived `PartialOrd`/`Ord` compare
// `database` first, then `table`.
#[derive(Debug, Clone, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub struct Config<'db, 'tbl> {
    /// Schema name that qualifies the table, e.g. `main`.
    pub database: Identifier<'db>,
    /// Name of the table inside `database`.
    pub table: Identifier<'tbl>,
}
impl Default for Config<'static, 'static> {
fn default() -> Self {
Config {
database: "main".try_into().unwrap(),
table: "ds::set".try_into().unwrap(),
}
}
}
/// A set of values persisted as rows of one SQLite table.
///
/// Elements are converted with the serializer `S` and stored in the table
/// named by `database`.`table`. All SQL runs inside savepoints opened on
/// `connection`.
pub struct Set<'db, 'tbl, S: Serializer, C: Savepointable> {
    /// Handle that savepoints are opened on for every operation.
    connection: C,
    /// Schema that qualifies the backing table.
    database: Identifier<'db>,
    /// Name of the backing table.
    table: Identifier<'tbl>,
    /// Records the serializer type; no serializer value is stored.
    serializer: PhantomData<S>,
}
impl<S, C> Set<'static, 'static, S, C>
where
    S: Serializer,
    C: Savepointable,
{
    /// Opens the set at the default location ([`Config::default`]),
    /// creating or upgrading its backing table as needed.
    ///
    /// # Errors
    ///
    /// Same as [`Set::open_with_config`].
    pub fn open(connection: C) -> Result<Self, OpenError> {
        let config = Config::default();
        Self::open_with_config(connection, config)
    }

    /// Wraps `connection` as a set at the default location without touching
    /// the database. No table is created and no version is checked.
    pub fn unchecked_open(connection: C) -> Self {
        let config = Config::default();
        Self::unchecked_open_with_config(connection, config)
    }
}
impl<'db, 'tbl, S, C> Set<'db, 'tbl, S, C>
where
S: Serializer,
C: Savepointable,
{
pub fn open_with_config(
mut connection: C,
config: Config<'db, 'tbl>,
) -> Result<Self, OpenError> {
let database = config.database;
let table = config.table;
{
let sp = connection.savepoint()?;
let mut version = db::setup(&sp, &database, &table, "ds::set")?;
if version < 0 {
return Err(OpenError::TableVersion(version));
}
let prev_version = version;
if version < 1 {
let trailer = db::strict_without_rowid();
let sql_type = S::sql_type();
sp.execute(
&format!(
"CREATE TABLE {database}.{table} (
key {sql_type} UNIQUE PRIMARY KEY NOT NULL
){trailer}"
),
[],
)?;
version = 1;
}
if version > 1 {
return Err(OpenError::TableVersion(version));
}
if prev_version != version {
db::set_version(&sp, &database, &table, version)?;
}
sp.commit()?;
}
Ok(Self {
connection,
database,
table,
serializer: PhantomData,
})
}
pub fn unchecked_open_with_config(connection: C, config: Config<'db, 'tbl>) -> Self {
let database = config.database;
let table = config.table;
Self {
connection,
database,
table,
serializer: PhantomData,
}
}
pub fn insert(&mut self, value: &S::TargetBorrowed) -> Result<bool, Error<S>> {
let database = &self.database;
let table = &self.table;
let serialized = match S::serialize(value) {
Ok(s) => s,
Err(e) => return Err(Error::Serialize(e)),
};
let sp = self.connection.savepoint()?;
let ret = if db::has_upsert() {
sp.prepare_cached(&format!(
"INSERT INTO {database}.{table} (key) VALUES (?) ON CONFLICT DO NOTHING"
))?
.execute(params![serialized.borrow()])?;
sp.changes() > 0
} else if Self::contains_serialized(database, table, &sp, serialized.borrow())? {
false
} else {
sp.prepare_cached(&format!("INSERT INTO {database}.{table} (key) VALUES (?)"))?
.execute(params![serialized.borrow()])?;
true
};
sp.commit()?;
Ok(ret)
}
pub fn contains(&mut self, value: &S::TargetBorrowed) -> Result<bool, Error<S>> {
let serialized = S::serialize(value).map_err(|e| Error::Serialize(e))?;
Self::contains_serialized(
&self.database,
&self.table,
&*self.connection.savepoint()?,
serialized.borrow(),
)
}
fn contains_serialized(
database: &Identifier,
table: &Identifier,
connection: &Connection,
value: &S::BufferBorrowed,
) -> Result<bool, Error<S>> {
Ok(connection
.prepare_cached(&format!("SELECT 1 FROM {database}.{table} WHERE key = ?"))?
.query_row(params![value], |_| Ok(()))
.optional()?
.is_some())
}
pub fn remove<Q>(&mut self, value: &S::TargetBorrowed) -> Result<bool, Error<S>> {
let database = &self.database;
let table = &self.table;
let serialized = match S::serialize(value) {
Ok(s) => s,
Err(e) => return Err(Error::Serialize(e)),
};
let sp = self.connection.savepoint()?;
let changes = sp
.prepare_cached(&format!("DELETE FROM {database}.{table} WHERE key = ?"))?
.execute(params![serialized.borrow()])?;
sp.commit()?;
Ok(changes > 0)
}
pub fn clear(&mut self) -> Result<(), Error<S>> {
let database = &self.database;
let table = &self.table;
let sp = self.connection.savepoint()?;
sp.prepare_cached(&format!("DELETE FROM {database}.{table}"))?
.execute([])?;
sp.commit()?;
Ok(())
}
pub fn first(&mut self) -> Result<Option<S::Target>, Error<S>> {
let database = &self.database;
let table = &self.table;
let serialized: Option<S::Buffer> = self
.connection
.savepoint()?
.prepare_cached(&format!(
"SELECT key FROM {database}.{table} ORDER BY key ASC"
))?
.query_row([], |row| row.get(0))
.optional()?;
match serialized.map(|s| S::deserialize(s.borrow())).transpose() {
Ok(s) => Ok(s),
Err(e) => Err(Error::Deserialize(e)),
}
}
pub fn last(&mut self) -> Result<Option<S::Target>, Error<S>> {
let database = &self.database;
let table = &self.table;
let serialized: Option<S::Buffer> = self
.connection
.savepoint()?
.prepare_cached(&format!(
"SELECT key FROM {database}.{table} ORDER BY key DESC"
))?
.query_row([], |row| row.get(0))
.optional()?;
match serialized.map(|s| S::deserialize(s.borrow())).transpose() {
Ok(s) => Ok(s),
Err(e) => Err(Error::Deserialize(e)),
}
}
pub fn len(&mut self) -> Result<u64, Error<S>> {
let database = &self.database;
let table = &self.table;
Ok(self
.connection
.savepoint()?
.prepare_cached(&format!("SELECT COUNT(*) FROM {database}.{table}"))?
.query_row([], |row| row.get(0))?)
}
pub fn iter(&mut self) -> Result<Iter<'db, 'tbl, S, Savepoint<'_>>, Error<S>> {
Ok(Iter::new(
self.connection.savepoint()?,
self.database.clone(),
self.table.clone(),
)?)
}
}
#[cfg(test)]
mod test;