1use crate::format::Format;
2use crate::OpenError;
3use crate::Savepointable;
4use crate::{db, identifier::Identifier};
5use rusqlite::Savepoint;
6use rusqlite::{params, Connection, OptionalExtension};
7
8use std::marker::PhantomData;
9
10mod error;
11pub mod iter;
12
13use iter::KeyIter;
14use iter::KeyValueIter;
15use iter::ValueIter;
16
17pub use error::Error;
18
19#[derive(Debug, Clone, PartialEq, PartialOrd, Ord, Eq, Hash)]
20pub struct Config<'db, 'tbl> {
21 pub database: Identifier<'db>,
22 pub table: Identifier<'tbl>,
23}
24
25impl Default for Config<'static, 'static> {
26 fn default() -> Self {
27 Config {
28 database: "main".try_into().unwrap(),
29 table: "ds::map".try_into().unwrap(),
30 }
31 }
32}
33
34pub struct Map<'db, 'tbl, K, V, C>
36where
37 K: Format,
38 V: Format,
39 C: Savepointable,
40{
41 connection: C,
42 database: Identifier<'db>,
43 table: Identifier<'tbl>,
44 key_serializer: PhantomData<K>,
45 value_serializer: PhantomData<V>,
46}
47
48impl<K, V, C> Map<'static, 'static, K, V, C>
49where
50 K: Format,
51 V: Format,
52 C: Savepointable,
53{
54 pub fn open(connection: C) -> Result<Self, OpenError> {
55 Map::open_with_config(connection, Config::default())
56 }
57
58 pub fn unchecked_open(connection: C) -> Self {
62 Map::unchecked_open_with_config(connection, Config::default())
63 }
64}
65
66impl<'db, 'tbl, K, V, C> Map<'db, 'tbl, K, V, C>
67where
68 K: Format,
69 V: Format,
70 C: Savepointable,
71{
72 pub fn open_with_config(
73 mut connection: C,
74 config: Config<'db, 'tbl>,
75 ) -> Result<Self, OpenError> {
76 let database = config.database;
77 let table = config.table;
78
79 {
80 let sp = connection.savepoint()?;
81
82 let mut version = db::setup(&sp, &database, &table, "ds::map")?;
83 if version < 0 {
84 return Err(OpenError::TableVersion(version));
85 }
86 let prev_version = version;
87 if version < 1 {
88 let trailer = db::strict_without_rowid();
89 let sql_type = K::sql_type();
90
91 sp.execute(
92 &format!(
93 "CREATE TABLE {database}.{table} (
94 key {sql_type} UNIQUE PRIMARY KEY NOT NULL,
95 value {sql_type} NOT NULL
96 ){trailer}"
97 ),
98 [],
99 )?;
100 version = 1;
101 }
102 if version > 1 {
103 return Err(OpenError::TableVersion(version));
104 }
105 if prev_version != version {
106 db::set_version(&sp, &database, &table, version)?;
107 }
108
109 sp.commit()?;
110 }
111 Ok(Self {
112 connection,
113 database,
114 table,
115 key_serializer: PhantomData,
116 value_serializer: PhantomData,
117 })
118 }
119
120 pub fn unchecked_open_with_config(connection: C, config: Config<'db, 'tbl>) -> Self {
124 let database = config.database;
125 let table = config.table;
126
127 Self {
128 connection,
129 database,
130 table,
131 key_serializer: PhantomData,
132 value_serializer: PhantomData,
133 }
134 }
135
136 pub fn insert(&mut self, key: &K::In, value: &V::In) -> Result<Option<V::Out>, Error<K, V>> {
137 let database = &self.database;
138 let table = &self.table;
139 let key = K::serialize(key).map_err(|e| Error::KeySerialize(e))?;
140 let value = V::serialize(value).map_err(|e| Error::ValueSerialize(e))?;
141
142 let sp = self.connection.savepoint()?;
143 let prev_value = Self::get_from_serialized(database, table, &sp, &key)?;
144 if db::has_upsert() {
145 sp.prepare_cached(&format!(
146 "INSERT INTO {database}.{table} (key, value) VALUES (?, ?) ON CONFLICT DO UPDATE SET value=excluded.value"
147 ))?
148 .execute(params![key, value])?;
149 } else if prev_value.is_some() {
150 sp.prepare_cached(&format!(
151 "UPDATE {database}.{table} SET value=? WHERE key=?"
152 ))?
153 .execute(params![value, key])?;
154 } else {
155 sp.prepare_cached(&format!(
156 "INSERT INTO {database}.{table} (key, value) VALUES (?, ?)"
157 ))?
158 .execute(params![key, value])?;
159 };
160 sp.commit()?;
161 Ok(prev_value)
162 }
163
164 pub fn len(&mut self) -> Result<u64, Error<K, V>> {
165 let database = &self.database;
166 let table = &self.table;
167 Ok(self
168 .connection
169 .savepoint()?
170 .prepare_cached(&format!("SELECT COUNT(*) FROM {database}.{table}"))?
171 .query_row([], |row| row.get(0))?)
172 }
173
174 pub fn is_empty(&mut self) -> Result<bool, Error<K, V>> {
175 let database = &self.database;
176 let table = &self.table;
177 Ok(self
178 .connection
179 .savepoint()?
180 .prepare_cached(&format!("SELECT 1 FROM {database}.{table} LIMIT 1"))?
181 .query_row([], |_| Ok(()))
182 .optional()?
183 .is_none())
184 }
185
186 fn contains_from_serialized(
187 database: &Identifier,
188 table: &Identifier,
189 connection: &Connection,
190 key: &K::Buffer,
191 ) -> Result<bool, Error<K, V>> {
192 Ok(connection
193 .prepare_cached(&format!("SELECT 1 FROM {database}.{table} WHERE key = ?"))?
194 .query_row(params![key], |_| Ok(()))
195 .optional()?
196 .is_some())
197 }
198
199 fn get_from_serialized(
200 database: &Identifier,
201 table: &Identifier,
202 connection: &Connection,
203 key: &K::Buffer,
204 ) -> Result<Option<V::Out>, Error<K, V>> {
205 Ok(
206 Self::get_serialized_from_serialized(database, table, connection, key)?
207 .map(|b| V::deserialize(&b).map_err(|e| Error::ValueDeserialize(e)))
208 .transpose()?,
209 )
210 }
211
212 fn get_serialized_from_serialized(
213 database: &Identifier,
214 table: &Identifier,
215 connection: &Connection,
216 key: &K::Buffer,
217 ) -> Result<Option<V::Buffer>, Error<K, V>> {
218 Ok(connection
219 .prepare_cached(&format!(
220 "SELECT value FROM {database}.{table} WHERE key = ?"
221 ))?
222 .query_row(params![key], |row| row.get(0))
223 .optional()?)
224 }
225
226 pub fn get(&mut self, key: &K::In) -> Result<Option<V::Out>, Error<K, V>> {
227 let database = &self.database;
228 let table = &self.table;
229 let key = K::serialize(key).map_err(|e| Error::KeySerialize(e))?;
230
231 let sp = self.connection.savepoint()?;
232 let result = Self::get_from_serialized(database, table, &sp, &key)?;
233 sp.commit()?;
234 Ok(result)
235 }
236
237 pub fn remove(&mut self, key: &K::In) -> Result<Option<V::Out>, Error<K, V>> {
238 let database = &self.database;
239 let table = &self.table;
240 let key = K::serialize(key).map_err(|e| Error::KeySerialize(e))?;
241
242 let sp = self.connection.savepoint()?;
243 let result = Self::get_from_serialized(database, table, &sp, &key)?;
244 sp.prepare_cached(&format!("DELETE FROM {database}.{table} WHERE key = ?"))?
245 .execute(params![key])?;
246 sp.commit()?;
247 Ok(result)
248 }
249
250 pub fn contains_key(&mut self, key: &K::In) -> Result<bool, Error<K, V>> {
251 let database = &self.database;
252 let table = &self.table;
253 let key = K::serialize(key).map_err(|e| Error::KeySerialize(e))?;
254
255 let sp = self.connection.savepoint()?;
256 let result = Self::contains_from_serialized(database, table, &sp, &key)?;
257 sp.commit()?;
258 Ok(result)
259 }
260
261 pub fn clear(&mut self) -> Result<(), Error<K, V>> {
262 let database = &self.database;
263 let table = &self.table;
264 let sp = self.connection.savepoint()?;
265 sp.prepare_cached(&format!("DELETE FROM {database}.{table}"))?
266 .execute([])?;
267 sp.commit()?;
268 Ok(())
269 }
270
271 pub fn iter(&mut self) -> Result<KeyValueIter<'db, 'tbl, K, V, Savepoint<'_>>, Error<K, V>> {
272 Ok(KeyValueIter::new(
273 self.connection.savepoint()?,
274 self.database.clone(),
275 self.table.clone(),
276 )?)
277 }
278
279 pub fn keys(&mut self) -> Result<KeyIter<'db, 'tbl, K, V, Savepoint<'_>>, Error<K, V>> {
280 Ok(KeyIter::new(
281 self.connection.savepoint()?,
282 self.database.clone(),
283 self.table.clone(),
284 )?)
285 }
286 pub fn values(&mut self) -> Result<ValueIter<'db, 'tbl, K, V, Savepoint<'_>>, Error<K, V>> {
287 Ok(ValueIter::new(
288 self.connection.savepoint()?,
289 self.database.clone(),
290 self.table.clone(),
291 )?)
292 }
293
294 pub fn retain<F>(&mut self, mut f: F) -> Result<(), Error<K, V>>
301 where
302 F: FnMut(K::Out, V::Out) -> bool,
303 {
304 let database = &self.database;
305 let table = &self.table;
306
307 let sp = self.connection.savepoint()?;
308 {
309 let mut maybe_serialized = sp
310 .prepare_cached(&format!(
311 "SELECT key, value FROM {database}.{table} ORDER BY key ASC LIMIT 1"
312 ))?
313 .query_row([], |row| Ok((row.get(0)?, row.get(1)?)))
314 .optional()?;
315 let mut deleter =
316 sp.prepare_cached(&format!("DELETE FROM {database}.{table} WHERE key = ?"))?;
317 let mut select_next = sp.prepare_cached(&format!(
318 "SELECT key, value FROM {database}.{table} WHERE key > ? ORDER BY key ASC LIMIT 1"
319 ))?;
320 while let Some((serialized_key, value)) = maybe_serialized {
321 let key = K::deserialize(&serialized_key).map_err(Error::KeyDeserialize)?;
322 let value = V::deserialize(&value).map_err(Error::ValueDeserialize)?;
323 if !f(key, value) {
324 deleter.execute(params![serialized_key])?;
325 }
326 maybe_serialized = select_next
327 .query_row(params![serialized_key], |row| {
328 Ok((row.get(0)?, row.get(1)?))
329 })
330 .optional()?;
331 }
332 }
333 sp.commit()?;
334 Ok(())
335 }
336}
337
338#[cfg(test)]
339mod test;