sos_database/archive/
import.rs

1use super::{types::ManifestVersion3, Error, Result};
2use crate::entity::{
3    AccountEntity, AccountRecord, AccountRow, EventEntity, EventRecordRow,
4    FolderEntity, FolderRow, PreferenceEntity, PreferenceRow, SecretRow,
5    ServerEntity, ServerRow, SystemMessageEntity, SystemMessageRow,
6};
7use async_sqlite::rusqlite::Connection;
8use sha2::{Digest, Sha256};
9use sos_archive::{sanitize_file_path, ZipReader};
10use sos_core::{
11    commit::CommitHash,
12    constants::{BLOBS_DIR, DATABASE_FILE},
13    events::EventLogType,
14    AccountId, ExternalFile, ExternalFileName, Paths, SecretId, SecretPath,
15    VaultId,
16};
17use sos_vfs as vfs;
18use std::{
19    collections::HashMap,
20    io::{self, BufWriter, Write},
21    path::Path,
22};
23use tempfile::NamedTempFile;
24use tokio::io::BufReader;
25
26struct HashingWriter<W: Write, H: Digest> {
27    inner: W,
28    hasher: H,
29}
30
31impl<W: Write, H: Digest> Write for HashingWriter<W, H> {
32    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
33        self.hasher.update(buf);
34        self.inner.write(buf)
35    }
36
37    fn flush(&mut self) -> io::Result<()> {
38        self.inner.flush()
39    }
40}
41
42/// Data source for an account import.
43struct ImportDataSource {
44    account_row: AccountRow,
45    account_events: Vec<EventRecordRow>,
46    login_folder: (FolderRow, Vec<SecretRow>, Vec<EventRecordRow>),
47    device_folder: Option<(FolderRow, Vec<SecretRow>, Vec<EventRecordRow>)>,
48    user_folders: Vec<(FolderRow, Vec<SecretRow>, Vec<EventRecordRow>)>,
49    file_events: Vec<EventRecordRow>,
50    servers: Vec<ServerRow>,
51    account_preferences: Vec<PreferenceRow>,
52    system_messages: Vec<SystemMessageRow>,
53}
54
/// Backup import.
///
/// Owns the connections, archive reader and temporary database file
/// needed to inspect the accounts in a backup archive and import them
/// into a target database.
///
/// Fields are dropped in declaration order, so the source connection
/// is closed before the temporary file backing it is removed; keep
/// that in mind before reordering fields.
pub struct BackupImport {
    // Box the connection so it implements Deref<Target = Connection>
    // which database entities use so they can accept transactions
    //
    // Connection to the database extracted from the archive.
    source_db: Box<Connection>,
    // Connection to the database accounts are imported into.
    target_db: Box<Connection>,
    // Base paths; narrowed to a specific account when its blobs
    // are extracted.
    paths: Paths,
    // Manifest read from the archive; retained but not read after
    // construction, hence the lint allowance.
    #[allow(dead_code)]
    manifest: ManifestVersion3,
    // Ensure the temp file is not deleted
    // until this struct is dropped
    #[allow(dead_code)]
    db_temp: NamedTempFile,
    // External file blobs found in the archive, grouped by account.
    blobs: HashMap<AccountId, Vec<ExternalFile>>,
    // Reader for the archive, used to extract blobs during import.
    zip_reader: ZipReader<BufReader<vfs::File>>,
}
71
72impl BackupImport {
73    /// List accounts in the temporary source database.
74    pub fn list_source_accounts(&self) -> Result<Vec<AccountRecord>> {
75        let accounts = AccountEntity::new(&self.source_db);
76        let rows = accounts.list_accounts()?;
77        let mut records = Vec::new();
78        for row in rows {
79            records.push(row.try_into()?);
80        }
81        Ok(records)
82    }
83
84    /// List accounts in the target database.
85    pub fn list_target_accounts(&self) -> Result<Vec<AccountRecord>> {
86        let accounts = AccountEntity::new(&self.target_db);
87        let rows = accounts.list_accounts()?;
88        let mut records = Vec::new();
89        for row in rows {
90            records.push(row.try_into()?);
91        }
92        Ok(records)
93    }
94
95    /// Run migrations on the temporary source database.
96    pub fn migrate_source(&mut self) -> Result<refinery::Report> {
97        Ok(crate::migrations::migrate_connection(&mut self.source_db)?)
98    }
99
100    /// Run migrations on the target database.
101    pub fn migrate_target(&mut self) -> Result<refinery::Report> {
102        Ok(crate::migrations::migrate_connection(&mut *self.target_db)?)
103    }
104
105    /// Try to import an account from the source to the
106    /// target database.
107    ///
108    /// It is an error if the account already exists in
109    /// the target database.
110    pub async fn import_account(
111        &mut self,
112        record: &AccountRecord,
113    ) -> Result<()> {
114        // Check account exists in the source db
115        let account_row = {
116            let accounts = AccountEntity::new(&self.source_db);
117            let account =
118                accounts.find_optional(record.identity.account_id())?;
119
120            let Some(account_row) = account else {
121                return Err(Error::ImportSourceNotExists(
122                    *record.identity.account_id(),
123                ));
124            };
125
126            account_row
127        };
128
129        // Check account does not exist in target
130        {
131            let accounts = AccountEntity::new(&self.target_db);
132            let target_account =
133                accounts.find_optional(record.identity.account_id())?;
134
135            if target_account.is_some() {
136                return Err(Error::ImportTargetExists(
137                    *record.identity.account_id(),
138                ));
139            }
140        }
141
142        let account_paths =
143            self.paths.with_account_id(record.identity.account_id());
144
145        // Read data from the source db
146        let data_source = self.read_import_data_source(account_row)?;
147
148        // Write data to the target db
149        self.write_import_data_source(data_source)?;
150
151        // Extract blobs for this account
152        if let Some(files) = self.blobs.get(record.identity.account_id()) {
153            for file in files {
154                let entry_name = format!(
155                    "{}/{}/{}/{}/{}",
156                    BLOBS_DIR,
157                    record.identity.account_id(),
158                    file.vault_id(),
159                    file.secret_id(),
160                    file.file_name(),
161                );
162                let target = account_paths.into_file_path(file);
163                let blob_buffer =
164                    self.zip_reader.by_name(&entry_name).await?.unwrap();
165
166                if let Some(parent) = target.parent() {
167                    vfs::create_dir_all(parent).await?;
168                }
169                vfs::write(&target, &blob_buffer).await?;
170            }
171        }
172
173        Ok(())
174    }
175
176    /// Read import data into memory from the source db.
177    fn read_import_data_source(
178        &self,
179        account_row: AccountRow,
180    ) -> Result<ImportDataSource> {
181        let account_id = account_row.row_id;
182
183        let folder_entity = FolderEntity::new(&self.source_db);
184        let event_entity = EventEntity::new(&self.source_db);
185        let server_entity = ServerEntity::new(&self.source_db);
186        let preference_entity = PreferenceEntity::new(&self.source_db);
187        let system_messages_entity =
188            SystemMessageEntity::new(&self.source_db);
189
190        // Account events
191        let account_events = event_entity.load_events(
192            EventLogType::Account,
193            account_id,
194            None,
195        )?;
196
197        // Login folder
198        let login_folder = folder_entity.find_login_folder(account_id)?;
199        let login_secrets =
200            folder_entity.load_secrets(login_folder.row_id)?;
201        let login_events = event_entity.load_events(
202            EventLogType::Identity,
203            account_id,
204            Some(login_folder.row_id),
205        )?;
206
207        // Device folder
208        let device_folder = folder_entity.find_device_folder(account_id)?;
209        let device_folder = if let Some(device_folder) = device_folder {
210            let device_events = event_entity.load_events(
211                EventLogType::Identity,
212                account_id,
213                Some(device_folder.row_id),
214            )?;
215            let device_secrets =
216                folder_entity.load_secrets(device_folder.row_id)?;
217            Some((device_folder, device_secrets, device_events))
218        } else {
219            None
220        };
221
222        // User defined folders
223        let folders = folder_entity.list_user_folders(account_id)?;
224        let mut user_folders = Vec::new();
225        for user_folder in folders {
226            let folder_events = event_entity.load_events(
227                EventLogType::Identity,
228                account_id,
229                Some(user_folder.row_id),
230            )?;
231            let folder_secrets =
232                folder_entity.load_secrets(user_folder.row_id)?;
233            user_folders.push((user_folder, folder_secrets, folder_events));
234        }
235
236        // File events
237        let file_events = event_entity.load_events(
238            EventLogType::Files,
239            account_id,
240            None,
241        )?;
242
243        // Servers, preferences and system messages
244        let servers = server_entity.load_servers(account_id)?;
245        let account_preferences =
246            preference_entity.load_preferences(Some(account_id))?;
247        let system_messages =
248            system_messages_entity.load_system_messages(account_id)?;
249
250        // Data source
251        let data_source = ImportDataSource {
252            account_row,
253            account_events,
254            login_folder: (login_folder, login_secrets, login_events),
255            device_folder,
256            user_folders,
257            file_events,
258            servers,
259            account_preferences,
260            system_messages,
261        };
262
263        Ok(data_source)
264    }
265
266    /// Write import data source into the target db using a transaction.
267    fn write_import_data_source(
268        &mut self,
269        data: ImportDataSource,
270    ) -> Result<()> {
271        let tx = self.target_db.transaction()?;
272
273        let account_entity = AccountEntity::new(&tx);
274        let folder_entity = FolderEntity::new(&tx);
275        let event_entity = EventEntity::new(&tx);
276        let server_entity = ServerEntity::new(&tx);
277        let preference_entity = PreferenceEntity::new(&tx);
278        let system_messages_entity = SystemMessageEntity::new(&tx);
279
280        // Insert the account
281        let account_id = account_entity.insert(&data.account_row)?;
282
283        // Create account events
284        event_entity
285            .insert_account_events(account_id, &data.account_events)?;
286
287        // Login folder
288        let login_folder_id =
289            folder_entity.insert_folder(account_id, &data.login_folder.0)?;
290        folder_entity
291            .insert_folder_secrets(login_folder_id, &data.login_folder.1)?;
292        event_entity
293            .insert_folder_events(login_folder_id, &data.login_folder.2)?;
294        account_entity.insert_login_folder(account_id, login_folder_id)?;
295
296        // Device folder
297        if let Some((device_folder, device_secrets, device_events)) =
298            &data.device_folder
299        {
300            let device_folder_id =
301                folder_entity.insert_folder(account_id, device_folder)?;
302            folder_entity
303                .insert_folder_secrets(device_folder_id, device_secrets)?;
304            event_entity
305                .insert_device_events(device_folder_id, device_events)?;
306            account_entity
307                .insert_device_folder(account_id, device_folder_id)?;
308        }
309
310        // User folders
311        for (folder, secrets, events) in &data.user_folders {
312            let folder_id =
313                folder_entity.insert_folder(account_id, folder)?;
314            folder_entity.insert_folder_secrets(folder_id, secrets)?;
315            event_entity.insert_folder_events(folder_id, events)?;
316        }
317
318        // Create file events
319        event_entity.insert_file_events(account_id, &data.file_events)?;
320
321        // Servers, preferences and system messages
322        server_entity.insert_servers(account_id, &data.servers)?;
323        preference_entity.insert_preferences(
324            Some(account_id),
325            &data.account_preferences,
326        )?;
327        system_messages_entity
328            .insert_system_messages(account_id, &data.system_messages)?;
329
330        tx.commit()?;
331
332        Ok(())
333    }
334}
335
336/// Start importing a backup archive.
337///
338/// Reads the archive manifest and extracts the archive database file
339/// to a temporary file and prepares the database connections.
340///
341/// The returned struct will hold the temporary file and connections
342/// in memory until dropped and can be used to inspect the accounts in the
343/// archive and perform imports.
344pub(crate) async fn start(
345    // target_db: &'conn mut Connection,
346    target_db: impl AsRef<Path>,
347    paths: &Paths,
348    input: impl AsRef<Path>,
349    // progress: fn(backup::Progress),
350) -> Result<BackupImport> {
351    if !vfs::try_exists(input.as_ref()).await? {
352        return Err(Error::ArchiveFileNotExists(input.as_ref().to_owned()));
353    }
354
355    let zip_file = BufReader::new(vfs::File::open(input.as_ref()).await?);
356    let mut zip_reader = ZipReader::new(zip_file).await?;
357    let manifest = zip_reader
358        .find_manifest::<ManifestVersion3>()
359        .await?
360        .ok_or_else(|| {
361            Error::InvalidArchiveManifest(input.as_ref().to_owned())
362        })?;
363
364    let blobs = find_blobs(&zip_reader)?;
365
366    // Extract the database and write to a temp file
367    let db_buffer =
368        zip_reader.by_name(DATABASE_FILE).await?.ok_or_else(|| {
369            Error::NoDatabaseFile(
370                input.as_ref().to_owned(),
371                DATABASE_FILE.to_owned(),
372            )
373        })?;
374    let mut db_temp = NamedTempFile::new()?;
375
376    let checksum = {
377        let buf_writer = BufWriter::new(db_temp.as_file_mut());
378        let mut hash_writer = HashingWriter {
379            inner: buf_writer,
380            hasher: Sha256::new(),
381        };
382
383        hash_writer.write_all(&db_buffer)?;
384        hash_writer.flush()?;
385
386        let digest = hash_writer.hasher.finalize();
387        CommitHash(digest.as_slice().try_into()?)
388    };
389
390    if checksum != manifest.checksum {
391        return Err(Error::DatabaseChecksum(manifest.checksum, checksum));
392    }
393
394    let source_db = Connection::open(db_temp.path())?;
395    let target_db = Connection::open(target_db.as_ref())?;
396    let import = BackupImport {
397        target_db: Box::new(target_db),
398        paths: paths.clone(),
399        manifest,
400        db_temp,
401        source_db: Box::new(source_db),
402        blobs,
403        zip_reader,
404    };
405
406    Ok(import)
407}
408
409/// Find blobs embedded in the archive.
410fn find_blobs(
411    reader: &ZipReader<BufReader<vfs::File>>,
412) -> Result<HashMap<AccountId, Vec<ExternalFile>>> {
413    let mut out = HashMap::new();
414    for index in 0..reader.inner().file().entries().len() {
415        let entry = reader.inner().file().entries().get(index).unwrap();
416        let is_dir = entry.dir().map_err(sos_archive::Error::from)?;
417        if !is_dir {
418            let file_name = entry.filename();
419            let path = sanitize_file_path(
420                file_name.as_str().map_err(sos_archive::Error::from)?,
421            );
422            let mut it = path.iter();
423            if let (
424                Some(first),
425                Some(second),
426                Some(third),
427                Some(fourth),
428                Some(fifth),
429            ) = (it.next(), it.next(), it.next(), it.next(), it.next())
430            {
431                if first == BLOBS_DIR {
432                    if let Ok(account_id) =
433                        second.to_string_lossy().parse::<AccountId>()
434                    {
435                        let files =
436                            out.entry(account_id).or_insert(Vec::new());
437
438                        if let (Ok(folder_id), Ok(secret_id), Ok(file_name)) = (
439                            third.to_string_lossy().parse::<VaultId>(),
440                            fourth.to_string_lossy().parse::<SecretId>(),
441                            fifth
442                                .to_string_lossy()
443                                .parse::<ExternalFileName>(),
444                        ) {
445                            files.push(ExternalFile::new(
446                                SecretPath(folder_id, secret_id),
447                                file_name,
448                            ));
449                        }
450                    }
451                }
452            }
453        }
454    }
455    Ok(out)
456}