sqlite_wasm_rs/vfs/
memory.rs

//! Memory VFS, used as the default VFS.
2
3use crate::libsqlite3::*;
4use crate::vfs::utils::{
5    check_import_db, ImportDbError, MemChunksFile, SQLiteIoMethods, SQLiteVfs, SQLiteVfsFile,
6    VfsAppData, VfsError, VfsFile, VfsResult, VfsStore,
7};
8
9use once_cell::sync::OnceCell;
10use parking_lot::RwLock;
11use std::collections::HashMap;
12
13type Result<T> = std::result::Result<T, MemVfsError>;
14
15enum MemFile {
16    Main(MemChunksFile),
17    Temp(MemChunksFile),
18}
19
20impl MemFile {
21    fn new(flags: i32) -> Self {
22        if flags & SQLITE_OPEN_MAIN_DB == 0 {
23            Self::Temp(MemChunksFile::default())
24        } else {
25            Self::Main(MemChunksFile::waiting_for_write())
26        }
27    }
28
29    fn file(&self) -> &MemChunksFile {
30        let (MemFile::Main(file) | MemFile::Temp(file)) = self;
31        file
32    }
33
34    fn file_mut(&mut self) -> &mut MemChunksFile {
35        let (MemFile::Main(file) | MemFile::Temp(file)) = self;
36        file
37    }
38}
39
40impl VfsFile for MemFile {
41    fn read(&self, buf: &mut [u8], offset: usize) -> VfsResult<bool> {
42        self.file().read(buf, offset)
43    }
44
45    fn write(&mut self, buf: &[u8], offset: usize) -> VfsResult<()> {
46        self.file_mut().write(buf, offset)
47    }
48
49    fn truncate(&mut self, size: usize) -> VfsResult<()> {
50        self.file_mut().truncate(size)
51    }
52
53    fn flush(&mut self) -> VfsResult<()> {
54        self.file_mut().flush()
55    }
56
57    fn size(&self) -> VfsResult<usize> {
58        self.file().size()
59    }
60}
61
62type MemAppData = RwLock<HashMap<String, MemFile>>;
63
64struct MemStore;
65
66impl VfsStore<MemFile, MemAppData> for MemStore {
67    fn add_file(vfs: *mut sqlite3_vfs, file: &str, flags: i32) -> VfsResult<()> {
68        let app_data = unsafe { Self::app_data(vfs) };
69        app_data.write().insert(file.into(), MemFile::new(flags));
70        Ok(())
71    }
72
73    fn contains_file(vfs: *mut sqlite3_vfs, file: &str) -> VfsResult<bool> {
74        let app_data = unsafe { Self::app_data(vfs) };
75        Ok(app_data.read().contains_key(file))
76    }
77
78    fn delete_file(vfs: *mut sqlite3_vfs, file: &str) -> VfsResult<()> {
79        let app_data = unsafe { Self::app_data(vfs) };
80        if app_data.write().remove(file).is_none() {
81            return Err(VfsError::new(
82                SQLITE_IOERR_DELETE,
83                format!("{file} not found"),
84            ));
85        }
86        Ok(())
87    }
88
89    fn with_file<F: Fn(&MemFile) -> VfsResult<i32>>(
90        vfs_file: &SQLiteVfsFile,
91        f: F,
92    ) -> VfsResult<i32> {
93        let name = unsafe { vfs_file.name() };
94        let app_data = unsafe { Self::app_data(vfs_file.vfs) };
95        match app_data.read().get(name) {
96            Some(file) => f(file),
97            None => Err(VfsError::new(SQLITE_IOERR, format!("{name} not found"))),
98        }
99    }
100
101    fn with_file_mut<F: Fn(&mut MemFile) -> VfsResult<i32>>(
102        vfs_file: &SQLiteVfsFile,
103        f: F,
104    ) -> VfsResult<i32> {
105        let name = unsafe { vfs_file.name() };
106        let app_data = unsafe { Self::app_data(vfs_file.vfs) };
107        match app_data.write().get_mut(name) {
108            Some(file) => f(file),
109            None => Err(VfsError::new(SQLITE_IOERR, format!("{name} not found"))),
110        }
111    }
112}
113
114struct MemIoMethods;
115
116impl SQLiteIoMethods for MemIoMethods {
117    type File = MemFile;
118    type AppData = MemAppData;
119    type Store = MemStore;
120
121    const VERSION: ::std::os::raw::c_int = 1;
122}
123
124struct MemVfs;
125
126impl SQLiteVfs<MemIoMethods> for MemVfs {
127    const VERSION: ::std::os::raw::c_int = 1;
128}
129
130static APP_DATA: OnceCell<&'static VfsAppData<MemAppData>> = OnceCell::new();
131
132fn app_data() -> &'static VfsAppData<MemAppData> {
133    APP_DATA.get_or_init(|| unsafe { &*VfsAppData::new(MemAppData::default()).leak() })
134}
135
136#[derive(thiserror::Error, Debug)]
137pub enum MemVfsError {
138    #[error(transparent)]
139    ImportDb(#[from] ImportDbError),
140    #[error("Generic error: {0}")]
141    Generic(String),
142}
143
144/// MemVfs management tools exposed to clients.
145pub struct MemVfsUtil(&'static VfsAppData<MemAppData>);
146
147impl Default for MemVfsUtil {
148    fn default() -> Self {
149        MemVfsUtil::new()
150    }
151}
152
153impl MemVfsUtil {
154    /// Get management tool
155    pub fn new() -> Self {
156        MemVfsUtil(app_data())
157    }
158}
159
160impl MemVfsUtil {
161    fn import_db_unchecked_impl(
162        &self,
163        filename: &str,
164        bytes: &[u8],
165        page_size: usize,
166        clear_wal: bool,
167    ) -> Result<()> {
168        if self.exists(filename) {
169            return Err(MemVfsError::Generic(format!(
170                "{filename} file already exists"
171            )));
172        }
173
174        self.0.write().insert(filename.into(), {
175            let mut file = MemFile::Main(MemChunksFile::new(page_size));
176            file.write(bytes, 0).unwrap();
177            if clear_wal {
178                file.write(&[1, 1], 18).unwrap();
179            }
180            file
181        });
182
183        Ok(())
184    }
185
186    /// Import the database.
187    ///
188    /// If the database is imported with WAL mode enabled,
189    /// it will be forced to write back to legacy mode, see
190    /// <https://sqlite.org/forum/forumpost/67882c5b04>
191    ///
192    /// If the imported database is encrypted, use `import_db_unchecked` instead.
193    pub fn import_db(&self, filename: &str, bytes: &[u8]) -> Result<()> {
194        let page_size = check_import_db(bytes)?;
195        self.import_db_unchecked_impl(filename, bytes, page_size, true)
196    }
197
198    /// `import_db` without checking, can be used to import encrypted database.
199    pub fn import_db_unchecked(
200        &self,
201        filename: &str,
202        bytes: &[u8],
203        page_size: usize,
204    ) -> Result<()> {
205        self.import_db_unchecked_impl(filename, bytes, page_size, false)
206    }
207
208    /// Export the database.
209    pub fn export_db(&self, filename: &str) -> Result<Vec<u8>> {
210        let name2file = self.0.read();
211
212        if let Some(file) = name2file.get(filename) {
213            let file_size = file.size().unwrap();
214            let mut ret = vec![0; file_size];
215            file.read(&mut ret, 0).unwrap();
216            Ok(ret)
217        } else {
218            Err(MemVfsError::Generic(
219                "The file to be exported does not exist".into(),
220            ))
221        }
222    }
223
224    /// Delete the specified database, please make sure that the database is closed.
225    pub fn delete_db(&self, filename: &str) {
226        self.0.write().remove(filename);
227    }
228
229    /// Delete all database, please make sure that all database is closed.
230    pub fn clear_all(&self) {
231        std::mem::take(&mut *self.0.write());
232    }
233
234    /// Does the database exists.
235    pub fn exists(&self, filename: &str) -> bool {
236        self.0.read().contains_key(filename)
237    }
238
239    /// List all files.
240    pub fn list(&self) -> Vec<String> {
241        self.0.read().keys().cloned().collect()
242    }
243
244    /// Number of files.
245    pub fn count(&self) -> usize {
246        self.0.read().len()
247    }
248}
249
250pub(crate) fn install() -> ::std::os::raw::c_int {
251    let app_data = app_data();
252    let vfs = Box::leak(Box::new(MemVfs::vfs(
253        c"memvfs".as_ptr().cast(),
254        app_data as *const _ as *mut _,
255    )));
256    unsafe { sqlite3_vfs_register(vfs, 1) }
257}
258
#[cfg(test)]
mod tests {
    use crate::mem_vfs::{MemAppData, MemFile, MemStore};
    use crate::utils::test_suite::test_vfs_store;
    use crate::utils::VfsAppData;
    use wasm_bindgen_test::wasm_bindgen_test;

    /// Runs the `test_vfs_store` suite against the in-memory store, using a
    /// fresh, empty app data instance.
    #[wasm_bindgen_test]
    fn test_memory_vfs_store() {
        let app_data = VfsAppData::new(MemAppData::default());
        test_vfs_store::<MemAppData, MemFile, MemStore>(app_data).unwrap();
    }
}