sqlite_wasm_rs/vfs/
memory.rs

1//! Memory VFS, used as the default VFS
2
3use crate::libsqlite3::*;
4use crate::vfs::utils::{
5    check_import_db, ImportDbError, MemChunksFile, SQLiteIoMethods, SQLiteVfs, SQLiteVfsFile,
6    VfsAppData, VfsError, VfsFile, VfsResult, VfsStore,
7};
8
9use once_cell::sync::OnceCell;
10use parking_lot::RwLock;
11use std::collections::HashMap;
12
13type Result<T> = std::result::Result<T, MemVfsError>;
14
15enum MemFile {
16    Main(MemChunksFile),
17    Temp(MemChunksFile),
18}
19
20impl MemFile {
21    fn new(flags: i32) -> Self {
22        if flags & SQLITE_OPEN_MAIN_DB == 0 {
23            Self::Temp(MemChunksFile::default())
24        } else {
25            Self::Main(MemChunksFile::waiting_for_write())
26        }
27    }
28
29    fn file(&self) -> &MemChunksFile {
30        let (MemFile::Main(file) | MemFile::Temp(file)) = self;
31        file
32    }
33
34    fn file_mut(&mut self) -> &mut MemChunksFile {
35        let (MemFile::Main(file) | MemFile::Temp(file)) = self;
36        file
37    }
38}
39
40impl VfsFile for MemFile {
41    fn read(&self, buf: &mut [u8], offset: usize) -> VfsResult<i32> {
42        self.file().read(buf, offset)
43    }
44
45    fn write(&mut self, buf: &[u8], offset: usize) -> VfsResult<()> {
46        self.file_mut().write(buf, offset)
47    }
48
49    fn truncate(&mut self, size: usize) -> VfsResult<()> {
50        self.file_mut().truncate(size)
51    }
52
53    fn flush(&mut self) -> VfsResult<()> {
54        self.file_mut().flush()
55    }
56
57    fn size(&self) -> VfsResult<usize> {
58        self.file().size()
59    }
60}
61
62type MemAppData = RwLock<HashMap<String, MemFile>>;
63
64struct MemStore;
65
66impl VfsStore<MemFile, MemAppData> for MemStore {
67    fn add_file(vfs: *mut sqlite3_vfs, file: &str, flags: i32) -> VfsResult<()> {
68        let app_data = unsafe { Self::app_data(vfs) };
69        app_data.write().insert(file.into(), MemFile::new(flags));
70        Ok(())
71    }
72
73    fn contains_file(vfs: *mut sqlite3_vfs, file: &str) -> VfsResult<bool> {
74        let app_data = unsafe { Self::app_data(vfs) };
75        Ok(app_data.read().contains_key(file))
76    }
77
78    fn delete_file(vfs: *mut sqlite3_vfs, file: &str) -> VfsResult<()> {
79        let app_data = unsafe { Self::app_data(vfs) };
80        if app_data.write().remove(file).is_none() {
81            return Err(VfsError::new(
82                SQLITE_IOERR_DELETE,
83                format!("{file} not found"),
84            ));
85        }
86        Ok(())
87    }
88
89    fn with_file<F: Fn(&MemFile) -> i32>(vfs_file: &SQLiteVfsFile, f: F) -> VfsResult<i32> {
90        let name = unsafe { vfs_file.name() };
91        let app_data = unsafe { Self::app_data(vfs_file.vfs) };
92        match app_data.read().get(name) {
93            Some(file) => Ok(f(file)),
94            None => Err(VfsError::new(SQLITE_IOERR, format!("{name} not found"))),
95        }
96    }
97
98    fn with_file_mut<F: Fn(&mut MemFile) -> i32>(vfs_file: &SQLiteVfsFile, f: F) -> VfsResult<i32> {
99        let name = unsafe { vfs_file.name() };
100        let app_data = unsafe { Self::app_data(vfs_file.vfs) };
101        match app_data.write().get_mut(name) {
102            Some(file) => Ok(f(file)),
103            None => Err(VfsError::new(SQLITE_IOERR, format!("{name} not found"))),
104        }
105    }
106}
107
108struct MemIoMethods;
109
110impl SQLiteIoMethods for MemIoMethods {
111    type File = MemFile;
112    type AppData = MemAppData;
113    type Store = MemStore;
114
115    const VERSION: ::std::os::raw::c_int = 1;
116}
117
118struct MemVfs;
119
120impl SQLiteVfs<MemIoMethods> for MemVfs {
121    const VERSION: ::std::os::raw::c_int = 1;
122}
123
124static APP_DATA: OnceCell<&'static VfsAppData<MemAppData>> = OnceCell::new();
125
126fn app_data() -> &'static VfsAppData<MemAppData> {
127    APP_DATA.get_or_init(|| unsafe { &*VfsAppData::new(MemAppData::default()).leak() })
128}
129
130#[derive(thiserror::Error, Debug)]
131pub enum MemVfsError {
132    #[error(transparent)]
133    ImportDb(#[from] ImportDbError),
134    #[error("Generic error: {0}")]
135    Generic(String),
136}
137
138/// MemVfs management tools exposed to clients.
139pub struct MemVfsUtil(&'static VfsAppData<MemAppData>);
140
141impl Default for MemVfsUtil {
142    fn default() -> Self {
143        MemVfsUtil::new()
144    }
145}
146
147impl MemVfsUtil {
148    fn import_db_unchecked_impl(
149        &self,
150        path: &str,
151        bytes: &[u8],
152        page_size: usize,
153        clear_wal: bool,
154    ) -> Result<()> {
155        if self.exists(path) {
156            return Err(MemVfsError::Generic(format!("{path} file already exists")));
157        }
158
159        self.0.write().insert(path.into(), {
160            let mut file = MemFile::Main(MemChunksFile::new(page_size));
161            file.write(bytes, 0).unwrap();
162            if clear_wal {
163                file.write(&[1, 1], 18).unwrap();
164            }
165            file
166        });
167
168        Ok(())
169    }
170
171    /// Get management tool
172    pub fn new() -> Self {
173        MemVfsUtil(app_data())
174    }
175
176    /// Import the db file
177    ///
178    /// If the database is imported with WAL mode enabled,
179    /// it will be forced to write back to legacy mode, see
180    /// <https://sqlite.org/forum/forumpost/67882c5b04>
181    ///
182    /// If the imported DB is encrypted, use `import_db_unchecked` instead.
183    pub fn import_db(&self, path: &str, bytes: &[u8]) -> Result<()> {
184        let page_size = check_import_db(bytes)?;
185        self.import_db_unchecked_impl(path, bytes, page_size, true)
186    }
187
188    /// Can be used to import encrypted DB
189    pub fn import_db_unchecked(&self, path: &str, bytes: &[u8], page_size: usize) -> Result<()> {
190        self.import_db_unchecked_impl(path, bytes, page_size, false)
191    }
192
193    /// Export database
194    pub fn export_db(&self, name: &str) -> Result<Vec<u8>> {
195        let name2file = self.0.read();
196
197        if let Some(file) = name2file.get(name) {
198            let file_size = file.size().unwrap();
199            let mut ret = vec![0; file_size];
200            file.read(&mut ret, 0).unwrap();
201            Ok(ret)
202        } else {
203            Err(MemVfsError::Generic(
204                "The file to be exported does not exist".into(),
205            ))
206        }
207    }
208
209    /// Delete the specified db, please make sure that the db is closed.
210    pub fn delete_db(&self, name: &str) {
211        self.0.write().remove(name);
212    }
213
214    /// Delete all dbs, please make sure that all dbs is closed.
215    pub fn clear_all(&self) {
216        std::mem::take(&mut *self.0.write());
217    }
218
219    /// Does the DB exist.
220    pub fn exists(&self, file: &str) -> bool {
221        self.0.read().contains_key(file)
222    }
223}
224
225pub(crate) fn install() -> ::std::os::raw::c_int {
226    let app_data = app_data();
227    let vfs = Box::leak(Box::new(MemVfs::vfs(
228        c"memvfs".as_ptr().cast(),
229        app_data as *const _ as *mut _,
230    )));
231    unsafe { sqlite3_vfs_register(vfs, 1) }
232}