sqlite_wasm_rs/vfs/
memory.rs

1//! Memory VFS, used as the default VFS
2
3use crate::libsqlite3::*;
4use crate::vfs::utils::{
5    check_db_and_page_size, check_import_db, page_read, ImportDbError, MemLinearFile,
6    SQLiteIoMethods, SQLiteVfs, SQLiteVfsFile, VfsAppData, VfsError, VfsFile, VfsResult, VfsStore,
7};
8
9use once_cell::sync::OnceCell;
10use parking_lot::RwLock;
11use std::collections::HashMap;
12
13type Result<T> = std::result::Result<T, MemVfsError>;
14
15type MemAppData = RwLock<HashMap<String, MemFile>>;
16
/// Page-granular in-memory storage; this is the layout wrapped by `MemFile::Main`.
#[derive(Default)]
struct MemPageFile {
    /// Size in bytes of one page; refreshed from the buffer length on every write.
    page_size: usize,
    /// Logical length of the file in bytes.
    file_size: usize,
    /// Page contents, keyed by the byte offset at which each page starts.
    pages: HashMap<usize, Vec<u8>>,
}
23
24impl VfsFile for MemPageFile {
25    fn read(&self, buf: &mut [u8], offset: usize) -> VfsResult<i32> {
26        Ok(page_read(
27            buf,
28            self.page_size,
29            self.file_size,
30            offset,
31            |addr| self.pages.get(&addr),
32            |page, buf, (start, end)| {
33                buf.copy_from_slice(&page[start..end]);
34            },
35        ))
36    }
37
38    fn write(&mut self, buf: &[u8], offset: usize) -> VfsResult<()> {
39        let page_size = buf.len();
40
41        for fill in (self.file_size..offset).step_by(page_size) {
42            self.pages.insert(fill, vec![0; page_size]);
43        }
44        if let Some(buffer) = self.pages.get_mut(&offset) {
45            buffer.copy_from_slice(buf);
46        } else {
47            self.pages.insert(offset, buf.to_vec());
48        }
49
50        self.page_size = page_size;
51        self.file_size = self.file_size.max(offset + page_size);
52
53        Ok(())
54    }
55
56    fn truncate(&mut self, size: usize) -> VfsResult<()> {
57        for offset in size..self.file_size {
58            self.pages.remove(&offset);
59        }
60        self.file_size = size;
61        Ok(())
62    }
63
64    fn flush(&mut self) -> VfsResult<()> {
65        Ok(())
66    }
67
68    fn size(&self) -> VfsResult<usize> {
69        Ok(self.file_size)
70    }
71}
72
73enum MemFile {
74    Main(MemPageFile),
75    Temp(MemLinearFile),
76}
77
78impl MemFile {
79    fn new(flags: i32) -> Self {
80        if flags & SQLITE_OPEN_MAIN_DB == 0 {
81            Self::Temp(MemLinearFile::default())
82        } else {
83            Self::Main(MemPageFile::default())
84        }
85    }
86}
87
88impl VfsFile for MemFile {
89    fn read(&self, buf: &mut [u8], offset: usize) -> VfsResult<i32> {
90        match self {
91            MemFile::Main(mem_page_file) => mem_page_file.read(buf, offset),
92            MemFile::Temp(mem_linear_file) => mem_linear_file.read(buf, offset),
93        }
94    }
95
96    fn write(&mut self, buf: &[u8], offset: usize) -> VfsResult<()> {
97        match self {
98            MemFile::Main(mem_page_file) => mem_page_file.write(buf, offset),
99            MemFile::Temp(mem_linear_file) => mem_linear_file.write(buf, offset),
100        }
101    }
102
103    fn truncate(&mut self, size: usize) -> VfsResult<()> {
104        match self {
105            MemFile::Main(mem_page_file) => mem_page_file.truncate(size),
106            MemFile::Temp(mem_linear_file) => mem_linear_file.truncate(size),
107        }
108    }
109
110    fn flush(&mut self) -> VfsResult<()> {
111        match self {
112            MemFile::Main(mem_page_file) => mem_page_file.flush(),
113            MemFile::Temp(mem_linear_file) => mem_linear_file.flush(),
114        }
115    }
116
117    fn size(&self) -> VfsResult<usize> {
118        match self {
119            MemFile::Main(mem_page_file) => mem_page_file.size(),
120            MemFile::Temp(mem_linear_file) => mem_linear_file.size(),
121        }
122    }
123}
124
125struct MemStore;
126
127impl VfsStore<MemFile, MemAppData> for MemStore {
128    fn add_file(vfs: *mut sqlite3_vfs, file: &str, flags: i32) -> VfsResult<()> {
129        let app_data = unsafe { Self::app_data(vfs) };
130        app_data.write().insert(file.into(), MemFile::new(flags));
131        Ok(())
132    }
133
134    fn contains_file(vfs: *mut sqlite3_vfs, file: &str) -> VfsResult<bool> {
135        let app_data = unsafe { Self::app_data(vfs) };
136        Ok(app_data.read().contains_key(file))
137    }
138
139    fn delete_file(vfs: *mut sqlite3_vfs, file: &str) -> VfsResult<()> {
140        let app_data = unsafe { Self::app_data(vfs) };
141        if app_data.write().remove(file).is_none() {
142            return Err(VfsError::new(
143                SQLITE_IOERR_DELETE,
144                format!("{file} not found"),
145            ));
146        }
147        Ok(())
148    }
149
150    fn with_file<F: Fn(&MemFile) -> i32>(vfs_file: &SQLiteVfsFile, f: F) -> VfsResult<i32> {
151        let name = unsafe { vfs_file.name() };
152        let app_data = unsafe { Self::app_data(vfs_file.vfs) };
153        match app_data.read().get(name) {
154            Some(file) => Ok(f(file)),
155            None => Err(VfsError::new(SQLITE_IOERR, format!("{name} not found"))),
156        }
157    }
158
159    fn with_file_mut<F: Fn(&mut MemFile) -> i32>(vfs_file: &SQLiteVfsFile, f: F) -> VfsResult<i32> {
160        let name = unsafe { vfs_file.name() };
161        let app_data = unsafe { Self::app_data(vfs_file.vfs) };
162        match app_data.write().get_mut(name) {
163            Some(file) => Ok(f(file)),
164            None => Err(VfsError::new(SQLITE_IOERR, format!("{name} not found"))),
165        }
166    }
167}
168
169struct MemIoMethods;
170
171impl SQLiteIoMethods for MemIoMethods {
172    type File = MemFile;
173    type AppData = MemAppData;
174    type Store = MemStore;
175
176    const VERSION: ::std::os::raw::c_int = 1;
177}
178
179struct MemVfs;
180
181impl SQLiteVfs<MemIoMethods> for MemVfs {
182    const VERSION: ::std::os::raw::c_int = 1;
183}
184
185static APP_DATA: OnceCell<&'static VfsAppData<MemAppData>> = OnceCell::new();
186
187fn app_data() -> &'static VfsAppData<MemAppData> {
188    APP_DATA.get_or_init(|| unsafe { &*VfsAppData::new(MemAppData::default()).leak() })
189}
190
191#[derive(thiserror::Error, Debug)]
192pub enum MemVfsError {
193    #[error(transparent)]
194    ImportDb(#[from] ImportDbError),
195    #[error("Generic error: {0}")]
196    Generic(String),
197}
198
199/// MemVfs management tools exposed to clients.
200pub struct MemVfsUtil(&'static VfsAppData<MemAppData>);
201
202impl Default for MemVfsUtil {
203    fn default() -> Self {
204        MemVfsUtil::new()
205    }
206}
207
208impl MemVfsUtil {
209    fn import_db_unchecked_impl(
210        &self,
211        path: &str,
212        bytes: &[u8],
213        page_size: usize,
214        clear_wal: bool,
215    ) -> Result<()> {
216        check_db_and_page_size(bytes.len(), page_size)?;
217
218        if self.exists(path) {
219            return Err(MemVfsError::Generic(format!("{path} file already exists")));
220        }
221
222        let mut pages: HashMap<usize, Vec<u8>> = bytes
223            .chunks(page_size)
224            .enumerate()
225            .map(|(idx, buffer)| (idx * page_size, buffer.to_vec()))
226            .collect();
227
228        if clear_wal {
229            // header
230            let header = pages.get_mut(&0).unwrap();
231            header[18] = 1;
232            header[19] = 1;
233        }
234
235        self.0.write().insert(
236            path.into(),
237            MemFile::Main(MemPageFile {
238                file_size: pages.len() * page_size,
239                page_size,
240                pages,
241            }),
242        );
243
244        Ok(())
245    }
246
247    /// Get management tool
248    pub fn new() -> Self {
249        MemVfsUtil(app_data())
250    }
251
252    /// Import the db file
253    ///
254    /// If the database is imported with WAL mode enabled,
255    /// it will be forced to write back to legacy mode, see
256    /// <https://sqlite.org/forum/forumpost/67882c5b04>
257    ///
258    /// If the imported DB is encrypted, use `import_db_unchecked` instead.
259    pub fn import_db(&self, path: &str, bytes: &[u8]) -> Result<()> {
260        let page_size = check_import_db(bytes)?;
261        self.import_db_unchecked_impl(path, bytes, page_size, true)
262    }
263
264    /// Can be used to import encrypted DB
265    pub fn import_db_unchecked(&self, path: &str, bytes: &[u8], page_size: usize) -> Result<()> {
266        self.import_db_unchecked_impl(path, bytes, page_size, false)
267    }
268
269    /// Export database
270    pub fn export_db(&self, name: &str) -> Result<Vec<u8>> {
271        let name2file = self.0.read();
272
273        if let Some(file) = name2file.get(name) {
274            if let MemFile::Main(file) = file {
275                let file_size = file.file_size;
276                let mut ret = vec![0; file.file_size];
277                for (&offset, buffer) in &file.pages {
278                    if offset >= file_size {
279                        continue;
280                    }
281                    ret[offset..offset + file.page_size].copy_from_slice(buffer);
282                }
283                Ok(ret)
284            } else {
285                Err(MemVfsError::Generic(
286                    "Does not support dumping temporary files".into(),
287                ))
288            }
289        } else {
290            Err(MemVfsError::Generic(
291                "The file to be exported does not exist".into(),
292            ))
293        }
294    }
295
296    /// Delete the specified db, please make sure that the db is closed.
297    pub fn delete_db(&self, name: &str) {
298        self.0.write().remove(name);
299    }
300
301    /// Delete all dbs, please make sure that all dbs is closed.
302    pub fn clear_all(&self) {
303        std::mem::take(&mut *self.0.write());
304    }
305
306    /// Does the DB exist.
307    pub fn exists(&self, file: &str) -> bool {
308        self.0.read().contains_key(file)
309    }
310}
311
312pub(crate) fn install() -> ::std::os::raw::c_int {
313    let app_data = app_data();
314    let vfs = Box::leak(Box::new(MemVfs::vfs(
315        c"memvfs".as_ptr().cast(),
316        app_data as *const _ as *mut _,
317    )));
318    unsafe { sqlite3_vfs_register(vfs, 1) }
319}