use crate::serializer::Serializer;
use crate::Savepointable;
use crate::{db, identifier::Identifier};
use rusqlite::{params, Connection, OptionalExtension, Savepoint};
use std::{borrow::Borrow, marker::PhantomData};
mod error;
pub use error::{Error, OpenError};
/// Location of the table that backs a [`Map`].
///
/// `database` is the SQLite schema name and `table` the table inside it; the
/// map addresses its rows as `{database}.{table}` in every statement it runs.
#[derive(Debug, Clone, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub struct Config<'db, 'tbl> {
    /// Schema the table lives in (the default configuration uses `main`).
    pub database: Identifier<'db>,
    /// Name of the table holding the key/value rows.
    pub table: Identifier<'tbl>,
}
impl Default for Config<'static, 'static> {
fn default() -> Self {
Config {
database: "main".try_into().unwrap(),
table: "ds::map".try_into().unwrap(),
}
}
}
/// A persistent key/value map stored in a single SQLite table.
///
/// `K` and `V` convert keys and values to and from their stored form, and
/// `C` provides the savepoints that each mutating operation runs inside.
pub struct Map<'db, 'tbl, K: Serializer, V: Serializer, C: Savepointable> {
    /// Handle through which all SQL is issued.
    connection: C,
    /// Schema containing the backing table.
    database: Identifier<'db>,
    /// Name of the backing table.
    table: Identifier<'tbl>,
    /// Type-level marker for the key serializer; holds no data.
    key_serializer: PhantomData<K>,
    /// Type-level marker for the value serializer; holds no data.
    value_serializer: PhantomData<V>,
}
impl<K, V, C> Map<'static, 'static, K, V, C>
where
    K: Serializer,
    V: Serializer,
    C: Savepointable,
{
    /// Opens the map at the default location (see [`Config::default`]),
    /// creating its table or validating the stored schema version.
    ///
    /// # Errors
    /// Returns whatever [`OpenError`] `open_with_config` produces.
    pub fn open(connection: C) -> Result<Self, OpenError> {
        let default_config = Config::default();
        Self::open_with_config(connection, default_config)
    }

    /// Wraps `connection` as a map at the default location, skipping the table
    /// creation and version checks that `open` performs.
    pub fn unchecked_open(connection: C) -> Self {
        let default_config = Config::default();
        Self::unchecked_open_with_config(connection, default_config)
    }
}
impl<'db, 'tbl, K, V, C> Map<'db, 'tbl, K, V, C>
where
K: Serializer,
V: Serializer,
C: Savepointable,
{
pub fn open_with_config(
mut connection: C,
config: Config<'db, 'tbl>,
) -> Result<Self, OpenError> {
let database = config.database;
let table = config.table;
{
let sp = connection.savepoint()?;
let mut version = db::setup(&sp, &database, &table, "ds::map")?;
if version < 0 {
return Err(OpenError::TableVersion(version));
}
let prev_version = version;
if version < 1 {
let trailer = db::strict_without_rowid();
let sql_type = K::sql_type();
sp.execute(
&format!(
"CREATE TABLE {database}.{table} (
key {sql_type} UNIQUE PRIMARY KEY NOT NULL,
value {sql_type} NOT NULL
){trailer}"
),
[],
)?;
version = 1;
}
if version > 1 {
return Err(OpenError::TableVersion(version));
}
if prev_version != version {
db::set_version(&sp, &database, &table, version)?;
}
sp.commit()?;
}
Ok(Self {
connection,
database,
table,
key_serializer: PhantomData,
value_serializer: PhantomData,
})
}
pub fn unchecked_open_with_config(connection: C, config: Config<'db, 'tbl>) -> Self {
let database = config.database;
let table = config.table;
Self {
connection,
database,
table,
key_serializer: PhantomData,
value_serializer: PhantomData,
}
}
pub fn insert(
&mut self,
key: &K::TargetBorrowed,
value: &V::TargetBorrowed,
) -> Result<Option<V::Target>, Error<K, V>> {
let database = &self.database;
let table = &self.table;
let key = K::serialize(key).map_err(|e| Error::KeySerialize(e))?;
let value = V::serialize(value).map_err(|e| Error::ValueSerialize(e))?;
let sp = self.connection.savepoint()?;
let prev_value = Self::get_from_serialized(database, table, &sp, key.borrow())?;
if db::has_upsert() {
sp.prepare_cached(&format!(
"INSERT INTO {database}.{table} (key, value) VALUES (?, ?) ON CONFLICT DO UPDATE SET value=excluded.value"
))?
.execute(params![key.borrow(), value.borrow()])?;
} else if prev_value.is_some() {
sp.prepare_cached(&format!(
"UPDATE {database}.{table} SET value=? WHERE key=?"
))?
.execute(params![value.borrow(), key.borrow()])?;
} else {
sp.prepare_cached(&format!(
"INSERT INTO {database}.{table} (key, value) VALUES (?, ?)"
))?
.execute(params![key.borrow(), value.borrow()])?;
};
sp.commit()?;
Ok(prev_value)
}
fn contains_from_serialized(
database: &Identifier,
table: &Identifier,
connection: &Connection,
key: &K::BufferBorrowed,
) -> Result<bool, Error<K, V>> {
Ok(connection
.prepare_cached(&format!("SELECT 1 FROM {database}.{table} WHERE key = ?"))?
.query_row(params![key], |_| Ok(()))
.optional()?
.is_some())
}
fn get_from_serialized(
database: &Identifier,
table: &Identifier,
connection: &Connection,
key: &K::BufferBorrowed,
) -> Result<Option<V::Target>, Error<K, V>> {
Ok(
Self::get_serialized_from_serialized(database, table, connection, key)?
.map(|b| V::deserialize(b.borrow()).map_err(|e| Error::ValueDeserialize(e)))
.transpose()?,
)
}
fn get_serialized_from_serialized(
database: &Identifier,
table: &Identifier,
connection: &Connection,
key: &K::BufferBorrowed,
) -> Result<Option<V::Buffer>, Error<K, V>> {
Ok(connection
.prepare_cached(&format!(
"SELECT value FROM {database}.{table} WHERE key = ?"
))?
.query_row(params![key], |row| row.get(0))
.optional()?)
}
}