1use std::cmp;
2use std::collections::HashMap;
4use std::io::Write;
5
/// Boxed error type used by the fallible functions below; the `Send + Sync`
/// bounds let errors be moved across threads.
type Error = Box<dyn std::error::Error + Sync + Send + 'static>;
7
8pub fn backfill(db: &mut sqlite_types::Db, wal: &sqlite_types::Wal) -> Result<(), Error> {
14 if db.header.page_size as u32 != wal.header.page_size {
15 return Err(format!(
16 "Error: page_size mismatch between WAL ({}) and DB ({}).",
17 wal.header.page_size, db.header.page_size
18 )
19 .into());
20 }
21
22 for frame in &wal.frames {
23 assert_eq!(wal.header.page_size as usize, frame.data.len());
24
25 if let Some(page) = db.pages.get_mut(&frame.header.page_number) {
26 if frame.header.page_number == 1 {
27 let new_header = sqlite_decoder::db::decode_header(&frame.data).unwrap();
29 db.header = new_header;
30 }
31
32 *page = frame.data.clone();
33 } else {
34 db.pages
35 .insert(frame.header.page_number, frame.data.clone());
36 db.header.db_size += 1;
37 }
38 }
39
40 Ok(())
41}
42
43pub fn hint_db_size(wal: &sqlite_types::Wal) -> Result<usize, Error> {
44 let mut max_page_count = 0u32;
45
46 for frame in &wal.frames {
47 max_page_count = cmp::min(max_page_count, frame.header.db_size_after_commit);
48 }
49
50 Ok(max_page_count as usize * wal.header.page_size as usize)
51}
52
53pub fn backfill_bytes(wal: &sqlite_types::Wal, db_bytes: &mut Vec<u8>) -> Result<(), Error> {
54 let db_header = sqlite_decoder::db::decode_header(&db_bytes)
55 .map_err(|err| format!("failed to decode database header: {}", err))?;
56
57 if db_header.page_size as u32 != wal.header.page_size {
58 return Err(format!(
59 "Error: page_size mismatch between WAL ({}) and DB ({}).",
60 wal.header.page_size, db_header.page_size
61 )
62 .into());
63 }
64
65 for frame in &wal.frames {
66 assert_eq!(wal.header.page_size as usize, frame.data.len());
67
68 let db_offset = (frame.header.page_number as usize - 1) * wal.header.page_size as usize;
69 let end = db_offset + wal.header.page_size as usize;
70
71 if end > db_bytes.len() {
72 db_bytes.resize(end, 0);
74 }
75
76 let wrote = (&mut db_bytes[db_offset..end])
77 .write(&frame.data)
78 .map_err(|err| format!("failed to write: {}", err))?;
79 assert_eq!(wrote, wal.header.page_size as usize);
80 }
81
82 Ok(())
83}
84
85pub fn to_db(
87 db_header: &sqlite_types::DbHeader,
88 wal: &sqlite_types::Wal,
89) -> Result<sqlite_types::Db, Error> {
90 let mut pages = HashMap::new();
91
92 {
94 let header_bytes =
95 sqlite_encoder::db::encode_header(&db_header).map_err(|err| -> Error {
96 format!("failed to encode database header: {}", err).into()
97 })?;
98 let mut first_page = vec![0u8; db_header.page_size as usize];
99 (&mut first_page[0..100])
100 .write(&header_bytes)
101 .map_err(|err| format!("failed to write header: {}", err))?;
102
103 pages.insert(1, first_page);
104 }
105
106 let mut db = sqlite_types::Db {
107 header: db_header.clone(),
108 pages,
109 };
110 backfill(&mut db, wal)?;
111
112 Ok(db)
113}
114
115pub fn merge(wal1: &mut sqlite_types::Wal, wal2: &sqlite_types::Wal) -> Result<(), Error> {
116 wal1.frames.extend(wal2.frames.clone());
117
118 Ok(())
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Encodes `db`, writes it to a temporary file, opens that file with
    /// SQLite and hands the connection to `check`.
    ///
    /// The connection is closed before the temporary file is removed.
    fn open_db(db: &sqlite_types::Db, check: impl FnOnce(&rusqlite::Connection)) {
        let encoded = sqlite_encoder::db::encode(db).unwrap();

        let mut tmp = NamedTempFile::new().unwrap();
        tmp.write_all(&encoded).unwrap();
        tmp.flush().unwrap();

        let conn = rusqlite::Connection::open(tmp.path()).unwrap();
        check(&conn);
        drop(conn);

        tmp.close().unwrap();
    }

    /// Returns the names (column 1) reported by `pragma table_list`.
    fn table_list(conn: &rusqlite::Connection) -> Vec<String> {
        let mut stmt = conn.prepare("pragma table_list;").unwrap();
        let names = stmt
            .query_map([], |row| row.get::<usize, String>(1))
            .unwrap();
        names.map(|name| name.unwrap()).collect()
    }

    /// Whether `pragma table_list` reports a table called `name`.
    fn has_table(conn: &rusqlite::Connection, name: &str) -> bool {
        table_list(conn).iter().any(|table| table == name)
    }

    /// Reads the single value produced by `pragma <name>;`.
    fn pragma<T: rusqlite::types::FromSql>(conn: &rusqlite::Connection, name: &str) -> T {
        conn.query_row(&format!("pragma {};", name), [], |row| {
            row.get::<usize, T>(0)
        })
        .unwrap()
    }

    /// Decodes a WAL fixture and backfills it onto `db`.
    fn apply_wal(db: &mut sqlite_types::Db, wal_bytes: &[u8]) {
        let wal = sqlite_decoder::wal::decode(wal_bytes).unwrap();
        backfill(db, &wal).unwrap();
    }

    #[test]
    fn it_converts_wal_to_db() {
        let db_header = sqlite_types::DbHeader {
            page_size: 4096,
            file_format_write_version: 2,
            file_format_read_version: 2,
            max_embedded_payload_frac: 64,
            min_embedded_payload_frac: 32,
            leaf_payload_frac: 32,
            file_change_counter: 1,
            db_size: 1,
            page_num_first_freelist: 0,
            page_count_freelist: 0,
            schema_cookie: 1,
            schema_format_number: 4,
            default_page_cache_size: 0,
            page_num_largest_root_btree: 0,
            text_encoding: 1,
            user_version: 0,
            vaccum_mode: 0,
            app_id: 0,
            version_valid_for: 1,
            sqlite_version: sqlite_types::SQLITE_3_37_2_VERSION,
        };

        let wal =
            sqlite_decoder::wal::decode(include_bytes!("../test/create-test-table.wal")).unwrap();
        let db = to_db(&db_header, &wal).unwrap();

        open_db(&db, |conn| {
            assert!(
                has_table(conn, "test"),
                "`test` table was not found; meaning WAL wasn't applied correctly."
            );
        });
    }

    #[test]
    fn it_applies_wal_on_top_of_db() {
        let mut db = sqlite_decoder::db::decode(include_bytes!("../test/existing.db3")).unwrap();

        // Step 1: create the `test` table.
        apply_wal(&mut db, include_bytes!("../test/create-test-table.wal"));
        open_db(&db, |conn| {
            assert!(
                has_table(conn, "test"),
                "`test` table was not found; WAL wasn't applied correctly."
            );
        });

        // Step 2: add the `test2` table.
        apply_wal(
            &mut db,
            include_bytes!("../test/create-test-and-test2-table.wal"),
        );
        open_db(&db, |conn| {
            assert!(has_table(conn, "test"), "`test` table was not found.");
            assert!(
                has_table(conn, "test2"),
                "`test2` table was not found; second WAL wasn't applied correctly."
            );
        });

        // Step 3: insert rows into `test`.
        apply_wal(&mut db, include_bytes!("../test/test-data.wal"));
        open_db(&db, |conn| {
            assert!(has_table(conn, "test"), "`test` table was not found.");

            let count: usize = conn
                .query_row("select count(*) from test;", [], |row| row.get(0))
                .unwrap();
            assert_eq!(count, 65);

            let page_count: usize = pragma(conn, "page_count");
            assert_eq!(page_count, 19);
        });

        // Step 4: drop the `test` table.
        apply_wal(&mut db, include_bytes!("../test/delete-test-table.wal"));
        open_db(&db, |conn| {
            assert!(
                !has_table(conn, "test"),
                "`test` table was found; WAL wasn't applied correctly"
            );

            let page_count: usize = pragma(conn, "page_count");
            assert_eq!(page_count, 18);
        });

        // Step 5: vacuum.
        apply_wal(&mut db, include_bytes!("../test/vacuum.wal"));
        open_db(&db, |conn| {
            let page_count: usize = pragma(conn, "page_count");
            assert_eq!(page_count, 1);
        });
    }
}