1use crate::libsqlite3::*;
4use crate::vfs::utils::{
5 check_db_and_page_size, check_import_db, page_read, ImportDbError, MemLinearFile,
6 SQLiteIoMethods, SQLiteVfs, SQLiteVfsFile, VfsAppData, VfsError, VfsFile, VfsResult, VfsStore,
7};
8
9use once_cell::sync::OnceCell;
10use parking_lot::RwLock;
11use std::collections::HashMap;
12
13type Result<T> = std::result::Result<T, MemVfsError>;
14
15type MemAppData = RwLock<HashMap<String, MemFile>>;
16
/// Main-database storage: fixed-size pages keyed by the byte offset at which
/// each page starts.
#[derive(Default)]
struct MemPageFile {
    /// Page size in bytes, taken from the most recent write or import.
    page_size: usize,
    /// Logical length of the file in bytes.
    file_size: usize,
    /// Page contents, keyed by starting byte offset.
    pages: HashMap<usize, Vec<u8>>,
}
23
24impl VfsFile for MemPageFile {
25 fn read(&self, buf: &mut [u8], offset: usize) -> VfsResult<i32> {
26 Ok(page_read(
27 buf,
28 self.page_size,
29 self.file_size,
30 offset,
31 |addr| self.pages.get(&addr),
32 |page, buf, (start, end)| {
33 buf.copy_from_slice(&page[start..end]);
34 },
35 ))
36 }
37
38 fn write(&mut self, buf: &[u8], offset: usize) -> VfsResult<()> {
39 let page_size = buf.len();
40
41 for fill in (self.file_size..offset).step_by(page_size) {
42 self.pages.insert(fill, vec![0; page_size]);
43 }
44 if let Some(buffer) = self.pages.get_mut(&offset) {
45 buffer.copy_from_slice(buf);
46 } else {
47 self.pages.insert(offset, buf.to_vec());
48 }
49
50 self.page_size = page_size;
51 self.file_size = self.file_size.max(offset + page_size);
52
53 Ok(())
54 }
55
56 fn truncate(&mut self, size: usize) -> VfsResult<()> {
57 for offset in size..self.file_size {
58 self.pages.remove(&offset);
59 }
60 self.file_size = size;
61 Ok(())
62 }
63
64 fn flush(&mut self) -> VfsResult<()> {
65 Ok(())
66 }
67
68 fn size(&self) -> VfsResult<usize> {
69 Ok(self.file_size)
70 }
71}
72
73enum MemFile {
74 Main(MemPageFile),
75 Temp(MemLinearFile),
76}
77
78impl MemFile {
79 fn new(flags: i32) -> Self {
80 if flags & SQLITE_OPEN_MAIN_DB == 0 {
81 Self::Temp(MemLinearFile::default())
82 } else {
83 Self::Main(MemPageFile::default())
84 }
85 }
86}
87
88impl VfsFile for MemFile {
89 fn read(&self, buf: &mut [u8], offset: usize) -> VfsResult<i32> {
90 match self {
91 MemFile::Main(mem_page_file) => mem_page_file.read(buf, offset),
92 MemFile::Temp(mem_linear_file) => mem_linear_file.read(buf, offset),
93 }
94 }
95
96 fn write(&mut self, buf: &[u8], offset: usize) -> VfsResult<()> {
97 match self {
98 MemFile::Main(mem_page_file) => mem_page_file.write(buf, offset),
99 MemFile::Temp(mem_linear_file) => mem_linear_file.write(buf, offset),
100 }
101 }
102
103 fn truncate(&mut self, size: usize) -> VfsResult<()> {
104 match self {
105 MemFile::Main(mem_page_file) => mem_page_file.truncate(size),
106 MemFile::Temp(mem_linear_file) => mem_linear_file.truncate(size),
107 }
108 }
109
110 fn flush(&mut self) -> VfsResult<()> {
111 match self {
112 MemFile::Main(mem_page_file) => mem_page_file.flush(),
113 MemFile::Temp(mem_linear_file) => mem_linear_file.flush(),
114 }
115 }
116
117 fn size(&self) -> VfsResult<usize> {
118 match self {
119 MemFile::Main(mem_page_file) => mem_page_file.size(),
120 MemFile::Temp(mem_linear_file) => mem_linear_file.size(),
121 }
122 }
123}
124
125struct MemStore;
126
127impl VfsStore<MemFile, MemAppData> for MemStore {
128 fn add_file(vfs: *mut sqlite3_vfs, file: &str, flags: i32) -> VfsResult<()> {
129 let app_data = unsafe { Self::app_data(vfs) };
130 app_data.write().insert(file.into(), MemFile::new(flags));
131 Ok(())
132 }
133
134 fn contains_file(vfs: *mut sqlite3_vfs, file: &str) -> VfsResult<bool> {
135 let app_data = unsafe { Self::app_data(vfs) };
136 Ok(app_data.read().contains_key(file))
137 }
138
139 fn delete_file(vfs: *mut sqlite3_vfs, file: &str) -> VfsResult<()> {
140 let app_data = unsafe { Self::app_data(vfs) };
141 if app_data.write().remove(file).is_none() {
142 return Err(VfsError::new(
143 SQLITE_IOERR_DELETE,
144 format!("{file} not found"),
145 ));
146 }
147 Ok(())
148 }
149
150 fn with_file<F: Fn(&MemFile) -> i32>(vfs_file: &SQLiteVfsFile, f: F) -> VfsResult<i32> {
151 let name = unsafe { vfs_file.name() };
152 let app_data = unsafe { Self::app_data(vfs_file.vfs) };
153 match app_data.read().get(name) {
154 Some(file) => Ok(f(file)),
155 None => Err(VfsError::new(SQLITE_IOERR, format!("{name} not found"))),
156 }
157 }
158
159 fn with_file_mut<F: Fn(&mut MemFile) -> i32>(vfs_file: &SQLiteVfsFile, f: F) -> VfsResult<i32> {
160 let name = unsafe { vfs_file.name() };
161 let app_data = unsafe { Self::app_data(vfs_file.vfs) };
162 match app_data.write().get_mut(name) {
163 Some(file) => Ok(f(file)),
164 None => Err(VfsError::new(SQLITE_IOERR, format!("{name} not found"))),
165 }
166 }
167}
168
169struct MemIoMethods;
170
171impl SQLiteIoMethods for MemIoMethods {
172 type File = MemFile;
173 type AppData = MemAppData;
174 type Store = MemStore;
175
176 const VERSION: ::std::os::raw::c_int = 1;
177}
178
179struct MemVfs;
180
181impl SQLiteVfs<MemIoMethods> for MemVfs {
182 const VERSION: ::std::os::raw::c_int = 1;
183}
184
185static APP_DATA: OnceCell<&'static VfsAppData<MemAppData>> = OnceCell::new();
186
187fn app_data() -> &'static VfsAppData<MemAppData> {
188 APP_DATA.get_or_init(|| unsafe { &*VfsAppData::new(MemAppData::default()).leak() })
189}
190
191#[derive(thiserror::Error, Debug)]
192pub enum MemVfsError {
193 #[error(transparent)]
194 ImportDb(#[from] ImportDbError),
195 #[error("Generic error: {0}")]
196 Generic(String),
197}
198
199pub struct MemVfsUtil(&'static VfsAppData<MemAppData>);
201
202impl Default for MemVfsUtil {
203 fn default() -> Self {
204 MemVfsUtil::new()
205 }
206}
207
208impl MemVfsUtil {
209 fn import_db_unchecked_impl(
210 &self,
211 path: &str,
212 bytes: &[u8],
213 page_size: usize,
214 clear_wal: bool,
215 ) -> Result<()> {
216 check_db_and_page_size(bytes.len(), page_size)?;
217
218 if self.exists(path) {
219 return Err(MemVfsError::Generic(format!("{path} file already exists")));
220 }
221
222 let mut pages: HashMap<usize, Vec<u8>> = bytes
223 .chunks(page_size)
224 .enumerate()
225 .map(|(idx, buffer)| (idx * page_size, buffer.to_vec()))
226 .collect();
227
228 if clear_wal {
229 let header = pages.get_mut(&0).unwrap();
231 header[18] = 1;
232 header[19] = 1;
233 }
234
235 self.0.write().insert(
236 path.into(),
237 MemFile::Main(MemPageFile {
238 file_size: pages.len() * page_size,
239 page_size,
240 pages,
241 }),
242 );
243
244 Ok(())
245 }
246
247 pub fn new() -> Self {
249 MemVfsUtil(app_data())
250 }
251
252 pub fn import_db(&self, path: &str, bytes: &[u8]) -> Result<()> {
260 let page_size = check_import_db(bytes)?;
261 self.import_db_unchecked_impl(path, bytes, page_size, true)
262 }
263
264 pub fn import_db_unchecked(&self, path: &str, bytes: &[u8], page_size: usize) -> Result<()> {
266 self.import_db_unchecked_impl(path, bytes, page_size, false)
267 }
268
269 pub fn export_db(&self, name: &str) -> Result<Vec<u8>> {
271 let name2file = self.0.read();
272
273 if let Some(file) = name2file.get(name) {
274 if let MemFile::Main(file) = file {
275 let file_size = file.file_size;
276 let mut ret = vec![0; file.file_size];
277 for (&offset, buffer) in &file.pages {
278 if offset >= file_size {
279 continue;
280 }
281 ret[offset..offset + file.page_size].copy_from_slice(buffer);
282 }
283 Ok(ret)
284 } else {
285 Err(MemVfsError::Generic(
286 "Does not support dumping temporary files".into(),
287 ))
288 }
289 } else {
290 Err(MemVfsError::Generic(
291 "The file to be exported does not exist".into(),
292 ))
293 }
294 }
295
296 pub fn delete_db(&self, name: &str) {
298 self.0.write().remove(name);
299 }
300
301 pub fn clear_all(&self) {
303 std::mem::take(&mut *self.0.write());
304 }
305
306 pub fn exists(&self, file: &str) -> bool {
308 self.0.read().contains_key(file)
309 }
310}
311
312pub(crate) fn install() -> ::std::os::raw::c_int {
313 let app_data = app_data();
314 let vfs = Box::leak(Box::new(MemVfs::vfs(
315 c"memvfs".as_ptr().cast(),
316 app_data as *const _ as *mut _,
317 )));
318 unsafe { sqlite3_vfs_register(vfs, 1) }
319}