// sqlite_collections/ds/set.rs

1use crate::format::Format;
2use crate::OpenError;
3use crate::Savepointable;
4use crate::{db, identifier::Identifier};
5use rusqlite::{params, Connection, OptionalExtension, Savepoint};
6
7use std::marker::PhantomData;
8
9mod error;
10mod iter;
11pub use iter::Iter;
12
13pub use error::Error;
14
15#[derive(Debug, Clone, PartialEq, PartialOrd, Ord, Eq, Hash)]
16pub struct Config<'db, 'tbl> {
17    pub database: Identifier<'db>,
18    pub table: Identifier<'tbl>,
19}
20
21impl Default for Config<'static, 'static> {
22    fn default() -> Self {
23        Config {
24            database: "main".try_into().unwrap(),
25            table: "ds::set".try_into().unwrap(),
26        }
27    }
28}
29
30/// Deterministic store set.
31pub struct Set<'db, 'tbl, S, C>
32where
33    S: Format,
34    C: Savepointable,
35{
36    connection: C,
37    database: Identifier<'db>,
38    table: Identifier<'tbl>,
39    serializer: PhantomData<S>,
40}
41
42impl<S, C> Set<'static, 'static, S, C>
43where
44    S: Format,
45    C: Savepointable,
46{
47    pub fn open(connection: C) -> Result<Self, OpenError> {
48        Set::open_with_config(connection, Config::default())
49    }
50
51    /// Open a set without creating it or checking if it exists.  This is safe
52    /// if you call a safe open in (or under) the same transaction or savepoint
53    /// beforehand.
54    pub fn unchecked_open(connection: C) -> Self {
55        Set::unchecked_open_with_config(connection, Config::default())
56    }
57}
58impl<'db, 'tbl, S, C> Set<'db, 'tbl, S, C>
59where
60    S: Format,
61    C: Savepointable,
62{
63    pub fn open_with_config(
64        mut connection: C,
65        config: Config<'db, 'tbl>,
66    ) -> Result<Self, OpenError> {
67        let database = config.database;
68        let table = config.table;
69
70        {
71            let sp = connection.savepoint()?;
72
73            let mut version = db::setup(&sp, &database, &table, "ds::set")?;
74            if version < 0 {
75                return Err(OpenError::TableVersion(version));
76            }
77            let prev_version = version;
78            if version < 1 {
79                let trailer = db::strict_without_rowid();
80                let sql_type = S::sql_type();
81
82                sp.execute(
83                    &format!(
84                        "CREATE TABLE {database}.{table} (
85                            key {sql_type} UNIQUE PRIMARY KEY NOT NULL
86                        ){trailer}"
87                    ),
88                    [],
89                )?;
90                version = 1;
91            }
92            if version > 1 {
93                return Err(OpenError::TableVersion(version));
94            }
95            if prev_version != version {
96                db::set_version(&sp, &database, &table, version)?;
97            }
98
99            sp.commit()?;
100        }
101        Ok(Self {
102            connection,
103            database,
104            table,
105            serializer: PhantomData,
106        })
107    }
108
109    /// Open a set without creating it or checking if it exists.  This is safe
110    /// if you call a safe open in (or under) the same transaction or savepoint
111    /// beforehand.
112    pub fn unchecked_open_with_config(connection: C, config: Config<'db, 'tbl>) -> Self {
113        let database = config.database;
114        let table = config.table;
115
116        Self {
117            connection,
118            database,
119            table,
120            serializer: PhantomData,
121        }
122    }
123
124    pub fn insert(&mut self, value: &S::In) -> Result<bool, Error<S>> {
125        let database = &self.database;
126        let table = &self.table;
127        let serialized = S::serialize(value).map_err(Error::Serialize)?;
128
129        let sp = self.connection.savepoint()?;
130        let ret = if db::has_upsert() {
131            sp.prepare_cached(&format!(
132                "INSERT INTO {database}.{table} (key) VALUES (?) ON CONFLICT DO NOTHING"
133            ))?
134            .execute(params![serialized])?;
135            sp.changes() > 0
136        } else if Self::contains_serialized(database, table, &sp, &serialized)? {
137            false
138        } else {
139            sp.prepare_cached(&format!("INSERT INTO {database}.{table} (key) VALUES (?)"))?
140                .execute(params![serialized])?;
141            true
142        };
143        sp.commit()?;
144        Ok(ret)
145    }
146
147    pub fn contains(&mut self, value: &S::In) -> Result<bool, Error<S>> {
148        let serialized = S::serialize(value).map_err(|e| Error::Serialize(e))?;
149        Self::contains_serialized(
150            &self.database,
151            &self.table,
152            &*self.connection.savepoint()?,
153            &serialized,
154        )
155    }
156
157    fn contains_serialized(
158        database: &Identifier,
159        table: &Identifier,
160        connection: &Connection,
161        value: &S::Buffer,
162    ) -> Result<bool, Error<S>> {
163        Ok(connection
164            .prepare_cached(&format!("SELECT 1 FROM {database}.{table} WHERE key = ?"))?
165            .query_row(params![value], |_| Ok(()))
166            .optional()?
167            .is_some())
168    }
169
170    pub fn remove(&mut self, value: &S::In) -> Result<bool, Error<S>> {
171        let database = &self.database;
172        let table = &self.table;
173        let serialized = S::serialize(value).map_err(Error::Serialize)?;
174
175        let sp = self.connection.savepoint()?;
176        let changes = sp
177            .prepare_cached(&format!("DELETE FROM {database}.{table} WHERE key = ?"))?
178            .execute(params![serialized])?;
179
180        sp.commit()?;
181
182        Ok(changes > 0)
183    }
184
185    pub fn clear(&mut self) -> Result<(), Error<S>> {
186        let database = &self.database;
187        let table = &self.table;
188        let sp = self.connection.savepoint()?;
189        sp.prepare_cached(&format!("DELETE FROM {database}.{table}"))?
190            .execute([])?;
191        sp.commit()?;
192        Ok(())
193    }
194
195    pub fn first(&mut self) -> Result<Option<S::Out>, Error<S>> {
196        let database = &self.database;
197        let table = &self.table;
198
199        let serialized: Option<S::Buffer> = self
200            .connection
201            .savepoint()?
202            .prepare_cached(&format!(
203                "SELECT key FROM {database}.{table} ORDER BY key ASC"
204            ))?
205            .query_row([], |row| row.get(0))
206            .optional()?;
207
208        serialized
209            .map(|s| S::deserialize(&s))
210            .transpose()
211            .map_err(Error::Deserialize)
212    }
213
214    pub fn last(&mut self) -> Result<Option<S::Out>, Error<S>> {
215        let database = &self.database;
216        let table = &self.table;
217        let serialized: Option<S::Buffer> = self
218            .connection
219            .savepoint()?
220            .prepare_cached(&format!(
221                "SELECT key FROM {database}.{table} ORDER BY key DESC"
222            ))?
223            .query_row([], |row| row.get(0))
224            .optional()?;
225
226        serialized
227            .map(|s| S::deserialize(&s))
228            .transpose()
229            .map_err(Error::Deserialize)
230    }
231
232    pub fn len(&mut self) -> Result<u64, Error<S>> {
233        let database = &self.database;
234        let table = &self.table;
235        Ok(self
236            .connection
237            .savepoint()?
238            .prepare_cached(&format!("SELECT COUNT(*) FROM {database}.{table}"))?
239            .query_row([], |row| row.get(0))?)
240    }
241
242    pub fn is_empty(&mut self) -> Result<bool, Error<S>> {
243        let database = &self.database;
244        let table = &self.table;
245        Ok(self
246            .connection
247            .savepoint()?
248            .prepare_cached(&format!("SELECT 1 FROM {database}.{table} LIMIT 1"))?
249            .query_row([], |_| Ok(()))
250            .optional()?
251            .is_none())
252    }
253
254    pub fn iter(&mut self) -> Result<Iter<'db, 'tbl, S, Savepoint<'_>>, Error<S>> {
255        Ok(Iter::new(
256            self.connection.savepoint()?,
257            self.database.clone(),
258            self.table.clone(),
259        )?)
260    }
261
262    /// Retains only the elements specified by the predicate.
263    ///
264    /// In other words, remove all elements e for which f(e) returns false. The
265    /// elements are visited in ascending serialized order.
266    ///
267    /// This is all done in a single transaction.
268    pub fn retain<F>(&mut self, mut f: F) -> Result<(), Error<S>>
269    where
270        F: FnMut(S::Out) -> bool,
271    {
272        let database = &self.database;
273        let table = &self.table;
274
275        let sp = self.connection.savepoint()?;
276        {
277            let mut maybe_serialized = sp
278                .prepare_cached(&format!(
279                    "SELECT key FROM {database}.{table} ORDER BY key ASC LIMIT 1"
280                ))?
281                .query_row([], |row| row.get(0))
282                .optional()?;
283            let mut deleter =
284                sp.prepare_cached(&format!("DELETE FROM {database}.{table} WHERE key = ?"))?;
285            let mut select_next = sp.prepare_cached(&format!(
286                "SELECT key FROM {database}.{table} WHERE key > ? ORDER BY key ASC LIMIT 1"
287            ))?;
288            while let Some(serialized) = maybe_serialized {
289                let item = S::deserialize(&serialized).map_err(|e| Error::Deserialize(e))?;
290                if !f(item) {
291                    deleter.execute(params![serialized])?;
292                }
293                maybe_serialized = select_next
294                    .query_row(params![serialized], |row| row.get(0))
295                    .optional()?;
296            }
297        }
298        sp.commit()?;
299        Ok(())
300    }
301}
302
303#[cfg(test)]
304mod test;