sqlite_wasm_rs/vfs/memory.rs

//! Memory VFS, registered as SQLite's default VFS.
2
3use crate::libsqlite3::*;
4use crate::vfs::utils::{
5    check_import_db, ImportDbError, MemChunksFile, SQLiteIoMethods, SQLiteVfs, SQLiteVfsFile,
6    VfsAppData, VfsError, VfsFile, VfsResult, VfsStore,
7};
8
9use once_cell::sync::OnceCell;
10use parking_lot::RwLock;
11use std::collections::HashMap;
12
13type Result<T> = std::result::Result<T, MemVfsError>;
14
15enum MemFile {
16    Main(MemChunksFile),
17    Temp(MemChunksFile),
18}
19
20impl MemFile {
21    fn new(flags: i32) -> Self {
22        if flags & SQLITE_OPEN_MAIN_DB == 0 {
23            Self::Temp(MemChunksFile::default())
24        } else {
25            Self::Main(MemChunksFile::waiting_for_write())
26        }
27    }
28
29    fn file(&self) -> &MemChunksFile {
30        let (MemFile::Main(file) | MemFile::Temp(file)) = self;
31        file
32    }
33
34    fn file_mut(&mut self) -> &mut MemChunksFile {
35        let (MemFile::Main(file) | MemFile::Temp(file)) = self;
36        file
37    }
38}
39
40impl VfsFile for MemFile {
41    fn read(&self, buf: &mut [u8], offset: usize) -> VfsResult<i32> {
42        self.file().read(buf, offset)
43    }
44
45    fn write(&mut self, buf: &[u8], offset: usize) -> VfsResult<()> {
46        self.file_mut().write(buf, offset)
47    }
48
49    fn truncate(&mut self, size: usize) -> VfsResult<()> {
50        self.file_mut().truncate(size)
51    }
52
53    fn flush(&mut self) -> VfsResult<()> {
54        self.file_mut().flush()
55    }
56
57    fn size(&self) -> VfsResult<usize> {
58        self.file().size()
59    }
60}
61
62type MemAppData = RwLock<HashMap<String, MemFile>>;
63
64struct MemStore;
65
66impl VfsStore<MemFile, MemAppData> for MemStore {
67    fn add_file(vfs: *mut sqlite3_vfs, file: &str, flags: i32) -> VfsResult<()> {
68        let app_data = unsafe { Self::app_data(vfs) };
69        app_data.write().insert(file.into(), MemFile::new(flags));
70        Ok(())
71    }
72
73    fn contains_file(vfs: *mut sqlite3_vfs, file: &str) -> VfsResult<bool> {
74        let app_data = unsafe { Self::app_data(vfs) };
75        Ok(app_data.read().contains_key(file))
76    }
77
78    fn delete_file(vfs: *mut sqlite3_vfs, file: &str) -> VfsResult<()> {
79        let app_data = unsafe { Self::app_data(vfs) };
80        if app_data.write().remove(file).is_none() {
81            return Err(VfsError::new(
82                SQLITE_IOERR_DELETE,
83                format!("{file} not found"),
84            ));
85        }
86        Ok(())
87    }
88
89    fn with_file<F: Fn(&MemFile) -> i32>(vfs_file: &SQLiteVfsFile, f: F) -> VfsResult<i32> {
90        let name = unsafe { vfs_file.name() };
91        let app_data = unsafe { Self::app_data(vfs_file.vfs) };
92        match app_data.read().get(name) {
93            Some(file) => Ok(f(file)),
94            None => Err(VfsError::new(SQLITE_IOERR, format!("{name} not found"))),
95        }
96    }
97
98    fn with_file_mut<F: Fn(&mut MemFile) -> i32>(vfs_file: &SQLiteVfsFile, f: F) -> VfsResult<i32> {
99        let name = unsafe { vfs_file.name() };
100        let app_data = unsafe { Self::app_data(vfs_file.vfs) };
101        match app_data.write().get_mut(name) {
102            Some(file) => Ok(f(file)),
103            None => Err(VfsError::new(SQLITE_IOERR, format!("{name} not found"))),
104        }
105    }
106}
107
108struct MemIoMethods;
109
110impl SQLiteIoMethods for MemIoMethods {
111    type File = MemFile;
112    type AppData = MemAppData;
113    type Store = MemStore;
114
115    const VERSION: ::std::os::raw::c_int = 1;
116}
117
118struct MemVfs;
119
120impl SQLiteVfs<MemIoMethods> for MemVfs {
121    const VERSION: ::std::os::raw::c_int = 1;
122}
123
124static APP_DATA: OnceCell<&'static VfsAppData<MemAppData>> = OnceCell::new();
125
126fn app_data() -> &'static VfsAppData<MemAppData> {
127    APP_DATA.get_or_init(|| unsafe { &*VfsAppData::new(MemAppData::default()).leak() })
128}
129
130#[derive(thiserror::Error, Debug)]
131pub enum MemVfsError {
132    #[error(transparent)]
133    ImportDb(#[from] ImportDbError),
134    #[error("Generic error: {0}")]
135    Generic(String),
136}
137
138/// MemVfs management tools exposed to clients.
139pub struct MemVfsUtil(&'static VfsAppData<MemAppData>);
140
141impl Default for MemVfsUtil {
142    fn default() -> Self {
143        MemVfsUtil::new()
144    }
145}
146
147impl MemVfsUtil {
148    /// Get management tool
149    pub fn new() -> Self {
150        MemVfsUtil(app_data())
151    }
152}
153
154impl MemVfsUtil {
155    fn import_db_unchecked_impl(
156        &self,
157        path: &str,
158        bytes: &[u8],
159        page_size: usize,
160        clear_wal: bool,
161    ) -> Result<()> {
162        if self.exists(path) {
163            return Err(MemVfsError::Generic(format!("{path} file already exists")));
164        }
165
166        self.0.write().insert(path.into(), {
167            let mut file = MemFile::Main(MemChunksFile::new(page_size));
168            file.write(bytes, 0).unwrap();
169            if clear_wal {
170                file.write(&[1, 1], 18).unwrap();
171            }
172            file
173        });
174
175        Ok(())
176    }
177
178    /// Import the database.
179    ///
180    /// If the database is imported with WAL mode enabled,
181    /// it will be forced to write back to legacy mode, see
182    /// <https://sqlite.org/forum/forumpost/67882c5b04>
183    ///
184    /// If the imported database is encrypted, use `import_db_unchecked` instead.
185    pub fn import_db(&self, path: &str, bytes: &[u8]) -> Result<()> {
186        let page_size = check_import_db(bytes)?;
187        self.import_db_unchecked_impl(path, bytes, page_size, true)
188    }
189
190    /// `import_db` without checking, can be used to import encrypted database.
191    pub fn import_db_unchecked(&self, path: &str, bytes: &[u8], page_size: usize) -> Result<()> {
192        self.import_db_unchecked_impl(path, bytes, page_size, false)
193    }
194
195    /// Export the database.
196    pub fn export_db(&self, path: &str) -> Result<Vec<u8>> {
197        let name2file = self.0.read();
198
199        if let Some(file) = name2file.get(path) {
200            let file_size = file.size().unwrap();
201            let mut ret = vec![0; file_size];
202            file.read(&mut ret, 0).unwrap();
203            Ok(ret)
204        } else {
205            Err(MemVfsError::Generic(
206                "The file to be exported does not exist".into(),
207            ))
208        }
209    }
210
211    /// Delete the specified database, please make sure that the database is closed.
212    pub fn delete_db(&self, path: &str) {
213        self.0.write().remove(path);
214    }
215
216    /// Delete all database, please make sure that all database is closed.
217    pub fn clear_all(&self) {
218        std::mem::take(&mut *self.0.write());
219    }
220
221    /// Does the database exists.
222    pub fn exists(&self, path: &str) -> bool {
223        self.0.read().contains_key(path)
224    }
225
226    /// List all file paths.
227    pub fn list(&self) -> Vec<String> {
228        self.0.read().keys().cloned().collect()
229    }
230
231    /// Number of files.
232    pub fn count(&self) -> usize {
233        self.0.read().len()
234    }
235}
236
237pub(crate) fn install() -> ::std::os::raw::c_int {
238    let app_data = app_data();
239    let vfs = Box::leak(Box::new(MemVfs::vfs(
240        c"memvfs".as_ptr().cast(),
241        app_data as *const _ as *mut _,
242    )));
243    unsafe { sqlite3_vfs_register(vfs, 1) }
244}