1use super::{types::ManifestVersion3, zip::Reader, Error, Result};
2use crate::entity::{
3 AccountEntity, AccountRecord, AccountRow, EventEntity, EventRecordRow,
4 FolderEntity, FolderRow, PreferenceEntity, PreferenceRow, SecretRow,
5 ServerEntity, ServerRow, SystemMessageEntity, SystemMessageRow,
6};
7use async_sqlite::rusqlite::Connection;
8use sha2::{Digest, Sha256};
9use sos_core::{
10 commit::CommitHash,
11 constants::{BLOBS_DIR, DATABASE_FILE},
12 events::EventLogType,
13 AccountId, ExternalFile, Paths,
14};
15use sos_vfs as vfs;
16use std::{
17 collections::HashMap,
18 io::{self, BufWriter, Write},
19 path::Path,
20};
21use tempfile::NamedTempFile;
22use tokio::io::BufReader;
23
24struct HashingWriter<W: Write, H: Digest> {
25 inner: W,
26 hasher: H,
27}
28
29impl<W: Write, H: Digest> Write for HashingWriter<W, H> {
30 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
31 self.hasher.update(buf);
32 self.inner.write(buf)
33 }
34
35 fn flush(&mut self) -> io::Result<()> {
36 self.inner.flush()
37 }
38}
39
40struct ImportDataSource {
42 account_row: AccountRow,
43 account_events: Vec<EventRecordRow>,
44 login_folder: (FolderRow, Vec<SecretRow>, Vec<EventRecordRow>),
45 device_folder: Option<(FolderRow, Vec<SecretRow>, Vec<EventRecordRow>)>,
46 user_folders: Vec<(FolderRow, Vec<SecretRow>, Vec<EventRecordRow>)>,
47 file_events: Vec<EventRecordRow>,
48 servers: Vec<ServerRow>,
49 account_preferences: Vec<PreferenceRow>,
50 system_messages: Vec<SystemMessageRow>,
51}
52
53pub struct BackupImport<'conn> {
55 source_db: Box<Connection>,
58 target_db: &'conn mut Connection,
59 paths: Paths,
60 #[allow(dead_code)]
61 manifest: ManifestVersion3,
62 #[allow(dead_code)]
65 db_temp: NamedTempFile,
66 blobs: HashMap<AccountId, Vec<ExternalFile>>,
67 zip_reader: Reader<BufReader<vfs::File>>,
68}
69
70impl<'conn> BackupImport<'conn> {
71 pub fn list_source_accounts(&self) -> Result<Vec<AccountRecord>> {
73 let accounts = AccountEntity::new(&self.source_db);
74 let rows = accounts.list_accounts()?;
75 let mut records = Vec::new();
76 for row in rows {
77 records.push(row.try_into()?);
78 }
79 Ok(records)
80 }
81
82 pub fn list_target_accounts(&self) -> Result<Vec<AccountRecord>> {
84 let accounts = AccountEntity::new(&self.target_db);
85 let rows = accounts.list_accounts()?;
86 let mut records = Vec::new();
87 for row in rows {
88 records.push(row.try_into()?);
89 }
90 Ok(records)
91 }
92
93 pub fn migrate_source(&mut self) -> Result<refinery::Report> {
95 Ok(crate::migrations::migrate_connection(&mut self.source_db)?)
96 }
97
98 pub fn migrate_target(&mut self) -> Result<refinery::Report> {
100 Ok(crate::migrations::migrate_connection(self.target_db)?)
101 }
102
103 pub async fn import_account(
109 &mut self,
110 record: &AccountRecord,
111 ) -> Result<()> {
112 let account_row = {
114 let accounts = AccountEntity::new(&self.source_db);
115 let account =
116 accounts.find_optional(record.identity.account_id())?;
117
118 let Some(account_row) = account else {
119 return Err(Error::ImportSourceNotExists(
120 *record.identity.account_id(),
121 ));
122 };
123
124 account_row
125 };
126
127 {
129 let accounts = AccountEntity::new(&self.target_db);
130 let target_account =
131 accounts.find_optional(record.identity.account_id())?;
132
133 if target_account.is_some() {
134 return Err(Error::ImportTargetExists(
135 *record.identity.account_id(),
136 ));
137 }
138 }
139
140 let account_paths =
141 self.paths.with_account_id(record.identity.account_id());
142
143 let data_source = self.read_import_data_source(account_row)?;
145
146 self.write_import_data_source(data_source)?;
148
149 if let Some(files) = self.blobs.get(record.identity.account_id()) {
151 for file in files {
152 let entry_name = format!(
153 "{}/{}/{}/{}/{}",
154 BLOBS_DIR,
155 record.identity.account_id(),
156 file.vault_id(),
157 file.secret_id(),
158 file.file_name(),
159 );
160 let target = account_paths.blob_location(
161 file.vault_id(),
162 file.secret_id(),
163 file.file_name().to_string(),
164 );
165 let blob_buffer =
166 self.zip_reader.by_name(&entry_name).await?.unwrap();
167
168 if let Some(parent) = target.parent() {
169 vfs::create_dir_all(parent).await?;
170 }
171 vfs::write(&target, &blob_buffer).await?;
172 }
173 }
174
175 Ok(())
176 }
177
178 fn read_import_data_source(
180 &self,
181 account_row: AccountRow,
182 ) -> Result<ImportDataSource> {
183 let account_id = account_row.row_id;
184
185 let folder_entity = FolderEntity::new(&self.source_db);
186 let event_entity = EventEntity::new(&self.source_db);
187 let server_entity = ServerEntity::new(&self.source_db);
188 let preference_entity = PreferenceEntity::new(&self.source_db);
189 let system_messages_entity =
190 SystemMessageEntity::new(&self.source_db);
191
192 let account_events = event_entity.load_events(
194 EventLogType::Account,
195 account_id,
196 None,
197 )?;
198
199 let login_folder = folder_entity.find_login_folder(account_id)?;
201 let login_secrets =
202 folder_entity.load_secrets(login_folder.row_id)?;
203 let login_events = event_entity.load_events(
204 EventLogType::Identity,
205 account_id,
206 Some(login_folder.row_id),
207 )?;
208
209 let device_folder = folder_entity.find_device_folder(account_id)?;
211 let device_folder = if let Some(device_folder) = device_folder {
212 let device_events = event_entity.load_events(
213 EventLogType::Identity,
214 account_id,
215 Some(device_folder.row_id),
216 )?;
217 let device_secrets =
218 folder_entity.load_secrets(device_folder.row_id)?;
219 Some((device_folder, device_secrets, device_events))
220 } else {
221 None
222 };
223
224 let folders = folder_entity.list_user_folders(account_id)?;
226 let mut user_folders = Vec::new();
227 for user_folder in folders {
228 let folder_events = event_entity.load_events(
229 EventLogType::Identity,
230 account_id,
231 Some(user_folder.row_id),
232 )?;
233 let folder_secrets =
234 folder_entity.load_secrets(user_folder.row_id)?;
235 user_folders.push((user_folder, folder_secrets, folder_events));
236 }
237
238 let file_events = event_entity.load_events(
240 EventLogType::Files,
241 account_id,
242 None,
243 )?;
244
245 let servers = server_entity.load_servers(account_id)?;
247 let account_preferences =
248 preference_entity.load_preferences(Some(account_id))?;
249 let system_messages =
250 system_messages_entity.load_system_messages(account_id)?;
251
252 let data_source = ImportDataSource {
254 account_row,
255 account_events,
256 login_folder: (login_folder, login_secrets, login_events),
257 device_folder,
258 user_folders,
259 file_events,
260 servers,
261 account_preferences,
262 system_messages,
263 };
264
265 Ok(data_source)
266 }
267
268 fn write_import_data_source(
270 &mut self,
271 data: ImportDataSource,
272 ) -> Result<()> {
273 let tx = self.target_db.transaction()?;
274
275 let account_entity = AccountEntity::new(&tx);
276 let folder_entity = FolderEntity::new(&tx);
277 let event_entity = EventEntity::new(&tx);
278 let server_entity = ServerEntity::new(&tx);
279 let preference_entity = PreferenceEntity::new(&tx);
280 let system_messages_entity = SystemMessageEntity::new(&tx);
281
282 let account_id = account_entity.insert(&data.account_row)?;
284
285 event_entity
287 .insert_account_events(account_id, &data.account_events)?;
288
289 let login_folder_id =
291 folder_entity.insert_folder(account_id, &data.login_folder.0)?;
292 folder_entity
293 .insert_folder_secrets(login_folder_id, &data.login_folder.1)?;
294 event_entity
295 .insert_folder_events(login_folder_id, &data.login_folder.2)?;
296 account_entity.insert_login_folder(account_id, login_folder_id)?;
297
298 if let Some((device_folder, device_secrets, device_events)) =
300 &data.device_folder
301 {
302 let device_folder_id =
303 folder_entity.insert_folder(account_id, device_folder)?;
304 folder_entity
305 .insert_folder_secrets(device_folder_id, device_secrets)?;
306 event_entity
307 .insert_device_events(device_folder_id, device_events)?;
308 account_entity
309 .insert_device_folder(account_id, device_folder_id)?;
310 }
311
312 for (folder, secrets, events) in &data.user_folders {
314 let folder_id =
315 folder_entity.insert_folder(account_id, folder)?;
316 folder_entity.insert_folder_secrets(folder_id, secrets)?;
317 event_entity.insert_folder_events(folder_id, events)?;
318 }
319
320 event_entity.insert_file_events(account_id, &data.file_events)?;
322
323 server_entity.insert_servers(account_id, &data.servers)?;
325 preference_entity.insert_preferences(
326 Some(account_id),
327 &data.account_preferences,
328 )?;
329 system_messages_entity
330 .insert_system_messages(account_id, &data.system_messages)?;
331
332 tx.commit()?;
333
334 Ok(())
335 }
336}
337
338pub(crate) async fn start<'conn>(
347 target_db: &'conn mut Connection,
348 paths: &Paths,
349 input: impl AsRef<Path>,
350 ) -> Result<BackupImport<'conn>> {
352 if !vfs::try_exists(input.as_ref()).await? {
353 return Err(Error::ArchiveFileNotExists(input.as_ref().to_owned()));
354 }
355
356 let zip_file = BufReader::new(vfs::File::open(input.as_ref()).await?);
357 let mut zip_reader = Reader::new(zip_file).await?;
358 let manifest = zip_reader.find_manifest().await?.ok_or_else(|| {
359 Error::InvalidArchiveManifest(input.as_ref().to_owned())
360 })?;
361
362 let blobs = zip_reader.find_blobs()?;
363
364 let db_buffer =
366 zip_reader.by_name(DATABASE_FILE).await?.ok_or_else(|| {
367 Error::NoDatabaseFile(
368 input.as_ref().to_owned(),
369 DATABASE_FILE.to_owned(),
370 )
371 })?;
372 let mut db_temp = NamedTempFile::new()?;
373
374 let checksum = {
375 let buf_writer = BufWriter::new(db_temp.as_file_mut());
376 let mut hash_writer = HashingWriter {
377 inner: buf_writer,
378 hasher: Sha256::new(),
379 };
380
381 hash_writer.write_all(&db_buffer)?;
382 hash_writer.flush()?;
383
384 let digest = hash_writer.hasher.finalize();
385 CommitHash(digest.as_slice().try_into()?)
386 };
387
388 if checksum != manifest.checksum {
389 return Err(Error::DatabaseChecksum(manifest.checksum, checksum));
390 }
391
392 let source_db = Connection::open(db_temp.path())?;
393 let import = BackupImport {
394 target_db,
395 paths: paths.clone(),
396 manifest,
397 db_temp,
398 source_db: Box::new(source_db),
399 blobs,
400 zip_reader,
401 };
402
403 Ok(import)
404}