sos_database/archive/
import.rs

1use super::{types::ManifestVersion3, zip::Reader, Error, Result};
2use crate::entity::{
3    AccountEntity, AccountRecord, AccountRow, EventEntity, EventRecordRow,
4    FolderEntity, FolderRow, PreferenceEntity, PreferenceRow, SecretRow,
5    ServerEntity, ServerRow, SystemMessageEntity, SystemMessageRow,
6};
7use async_sqlite::rusqlite::Connection;
8use sha2::{Digest, Sha256};
9use sos_core::{
10    commit::CommitHash,
11    constants::{BLOBS_DIR, DATABASE_FILE},
12    events::EventLogType,
13    AccountId, ExternalFile, Paths,
14};
15use sos_vfs as vfs;
16use std::{
17    collections::HashMap,
18    io::{self, BufWriter, Write},
19    path::Path,
20};
21use tempfile::NamedTempFile;
22use tokio::io::BufReader;
23
24struct HashingWriter<W: Write, H: Digest> {
25    inner: W,
26    hasher: H,
27}
28
29impl<W: Write, H: Digest> Write for HashingWriter<W, H> {
30    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
31        self.hasher.update(buf);
32        self.inner.write(buf)
33    }
34
35    fn flush(&mut self) -> io::Result<()> {
36        self.inner.flush()
37    }
38}
39
40/// Data source for an account import.
41struct ImportDataSource {
42    account_row: AccountRow,
43    account_events: Vec<EventRecordRow>,
44    login_folder: (FolderRow, Vec<SecretRow>, Vec<EventRecordRow>),
45    device_folder: Option<(FolderRow, Vec<SecretRow>, Vec<EventRecordRow>)>,
46    user_folders: Vec<(FolderRow, Vec<SecretRow>, Vec<EventRecordRow>)>,
47    file_events: Vec<EventRecordRow>,
48    servers: Vec<ServerRow>,
49    account_preferences: Vec<PreferenceRow>,
50    system_messages: Vec<SystemMessageRow>,
51}
52
/// Backup import.
///
/// Holds a connection to the database extracted from a backup archive
/// (`source_db`), the destination database (`target_db`) and the open
/// archive reader, so accounts can be listed and imported one at a time.
/// Instances are created by `start`.
///
/// Fields are dropped in declaration order, so `source_db` is closed
/// before `db_temp` removes the temporary database file.
pub struct BackupImport<'conn> {
    // Boxed so that it implements `Deref<Target = Connection>`, which the
    // database entities rely on so they can accept either a connection or
    // a transaction.
    source_db: Box<Connection>,
    /// Database that imported accounts are written to.
    target_db: &'conn mut Connection,
    /// Paths used to locate where extracted blobs are written.
    paths: Paths,
    /// Manifest read from the archive; not read by any method of this
    /// type, hence `allow(dead_code)`.
    #[allow(dead_code)]
    manifest: ManifestVersion3,
    // Ensure the temp file is not deleted
    // until this struct is dropped
    #[allow(dead_code)]
    db_temp: NamedTempFile,
    /// External file blobs listed in the archive, keyed by account.
    blobs: HashMap<AccountId, Vec<ExternalFile>>,
    /// Reader over the archive, used to extract blob entries on import.
    zip_reader: Reader<BufReader<vfs::File>>,
}
69
70impl<'conn> BackupImport<'conn> {
71    /// List accounts in the temporary source database.
72    pub fn list_source_accounts(&self) -> Result<Vec<AccountRecord>> {
73        let accounts = AccountEntity::new(&self.source_db);
74        let rows = accounts.list_accounts()?;
75        let mut records = Vec::new();
76        for row in rows {
77            records.push(row.try_into()?);
78        }
79        Ok(records)
80    }
81
82    /// List accounts in the target database.
83    pub fn list_target_accounts(&self) -> Result<Vec<AccountRecord>> {
84        let accounts = AccountEntity::new(&self.target_db);
85        let rows = accounts.list_accounts()?;
86        let mut records = Vec::new();
87        for row in rows {
88            records.push(row.try_into()?);
89        }
90        Ok(records)
91    }
92
93    /// Run migrations on the temporary source database.
94    pub fn migrate_source(&mut self) -> Result<refinery::Report> {
95        Ok(crate::migrations::migrate_connection(&mut self.source_db)?)
96    }
97
98    /// Run migrations on the target database.
99    pub fn migrate_target(&mut self) -> Result<refinery::Report> {
100        Ok(crate::migrations::migrate_connection(self.target_db)?)
101    }
102
103    /// Try to import an account from the source to the
104    /// target database.
105    ///
106    /// It is an error if the account already exists in
107    /// the target database.
108    pub async fn import_account(
109        &mut self,
110        record: &AccountRecord,
111    ) -> Result<()> {
112        // Check account exists in the source db
113        let account_row = {
114            let accounts = AccountEntity::new(&self.source_db);
115            let account =
116                accounts.find_optional(record.identity.account_id())?;
117
118            let Some(account_row) = account else {
119                return Err(Error::ImportSourceNotExists(
120                    *record.identity.account_id(),
121                ));
122            };
123
124            account_row
125        };
126
127        // Check account does not exist in target
128        {
129            let accounts = AccountEntity::new(&self.target_db);
130            let target_account =
131                accounts.find_optional(record.identity.account_id())?;
132
133            if target_account.is_some() {
134                return Err(Error::ImportTargetExists(
135                    *record.identity.account_id(),
136                ));
137            }
138        }
139
140        let account_paths =
141            self.paths.with_account_id(record.identity.account_id());
142
143        // Read data from the source db
144        let data_source = self.read_import_data_source(account_row)?;
145
146        // Write data to the target db
147        self.write_import_data_source(data_source)?;
148
149        // Extract blobs for this account
150        if let Some(files) = self.blobs.get(record.identity.account_id()) {
151            for file in files {
152                let entry_name = format!(
153                    "{}/{}/{}/{}/{}",
154                    BLOBS_DIR,
155                    record.identity.account_id(),
156                    file.vault_id(),
157                    file.secret_id(),
158                    file.file_name(),
159                );
160                let target = account_paths.blob_location(
161                    file.vault_id(),
162                    file.secret_id(),
163                    file.file_name().to_string(),
164                );
165                let blob_buffer =
166                    self.zip_reader.by_name(&entry_name).await?.unwrap();
167
168                if let Some(parent) = target.parent() {
169                    vfs::create_dir_all(parent).await?;
170                }
171                vfs::write(&target, &blob_buffer).await?;
172            }
173        }
174
175        Ok(())
176    }
177
178    /// Read import data into memory from the source db.
179    fn read_import_data_source(
180        &self,
181        account_row: AccountRow,
182    ) -> Result<ImportDataSource> {
183        let account_id = account_row.row_id;
184
185        let folder_entity = FolderEntity::new(&self.source_db);
186        let event_entity = EventEntity::new(&self.source_db);
187        let server_entity = ServerEntity::new(&self.source_db);
188        let preference_entity = PreferenceEntity::new(&self.source_db);
189        let system_messages_entity =
190            SystemMessageEntity::new(&self.source_db);
191
192        // Account events
193        let account_events = event_entity.load_events(
194            EventLogType::Account,
195            account_id,
196            None,
197        )?;
198
199        // Login folder
200        let login_folder = folder_entity.find_login_folder(account_id)?;
201        let login_secrets =
202            folder_entity.load_secrets(login_folder.row_id)?;
203        let login_events = event_entity.load_events(
204            EventLogType::Identity,
205            account_id,
206            Some(login_folder.row_id),
207        )?;
208
209        // Device folder
210        let device_folder = folder_entity.find_device_folder(account_id)?;
211        let device_folder = if let Some(device_folder) = device_folder {
212            let device_events = event_entity.load_events(
213                EventLogType::Identity,
214                account_id,
215                Some(device_folder.row_id),
216            )?;
217            let device_secrets =
218                folder_entity.load_secrets(device_folder.row_id)?;
219            Some((device_folder, device_secrets, device_events))
220        } else {
221            None
222        };
223
224        // User defined folders
225        let folders = folder_entity.list_user_folders(account_id)?;
226        let mut user_folders = Vec::new();
227        for user_folder in folders {
228            let folder_events = event_entity.load_events(
229                EventLogType::Identity,
230                account_id,
231                Some(user_folder.row_id),
232            )?;
233            let folder_secrets =
234                folder_entity.load_secrets(user_folder.row_id)?;
235            user_folders.push((user_folder, folder_secrets, folder_events));
236        }
237
238        // File events
239        let file_events = event_entity.load_events(
240            EventLogType::Files,
241            account_id,
242            None,
243        )?;
244
245        // Servers, preferences and system messages
246        let servers = server_entity.load_servers(account_id)?;
247        let account_preferences =
248            preference_entity.load_preferences(Some(account_id))?;
249        let system_messages =
250            system_messages_entity.load_system_messages(account_id)?;
251
252        // Data source
253        let data_source = ImportDataSource {
254            account_row,
255            account_events,
256            login_folder: (login_folder, login_secrets, login_events),
257            device_folder,
258            user_folders,
259            file_events,
260            servers,
261            account_preferences,
262            system_messages,
263        };
264
265        Ok(data_source)
266    }
267
268    /// Write import data source into the target db using a transaction.
269    fn write_import_data_source(
270        &mut self,
271        data: ImportDataSource,
272    ) -> Result<()> {
273        let tx = self.target_db.transaction()?;
274
275        let account_entity = AccountEntity::new(&tx);
276        let folder_entity = FolderEntity::new(&tx);
277        let event_entity = EventEntity::new(&tx);
278        let server_entity = ServerEntity::new(&tx);
279        let preference_entity = PreferenceEntity::new(&tx);
280        let system_messages_entity = SystemMessageEntity::new(&tx);
281
282        // Insert the account
283        let account_id = account_entity.insert(&data.account_row)?;
284
285        // Create account events
286        event_entity
287            .insert_account_events(account_id, &data.account_events)?;
288
289        // Login folder
290        let login_folder_id =
291            folder_entity.insert_folder(account_id, &data.login_folder.0)?;
292        folder_entity
293            .insert_folder_secrets(login_folder_id, &data.login_folder.1)?;
294        event_entity
295            .insert_folder_events(login_folder_id, &data.login_folder.2)?;
296        account_entity.insert_login_folder(account_id, login_folder_id)?;
297
298        // Device folder
299        if let Some((device_folder, device_secrets, device_events)) =
300            &data.device_folder
301        {
302            let device_folder_id =
303                folder_entity.insert_folder(account_id, device_folder)?;
304            folder_entity
305                .insert_folder_secrets(device_folder_id, device_secrets)?;
306            event_entity
307                .insert_device_events(device_folder_id, device_events)?;
308            account_entity
309                .insert_device_folder(account_id, device_folder_id)?;
310        }
311
312        // User folders
313        for (folder, secrets, events) in &data.user_folders {
314            let folder_id =
315                folder_entity.insert_folder(account_id, folder)?;
316            folder_entity.insert_folder_secrets(folder_id, secrets)?;
317            event_entity.insert_folder_events(folder_id, events)?;
318        }
319
320        // Create file events
321        event_entity.insert_file_events(account_id, &data.file_events)?;
322
323        // Servers, preferences and system messages
324        server_entity.insert_servers(account_id, &data.servers)?;
325        preference_entity.insert_preferences(
326            Some(account_id),
327            &data.account_preferences,
328        )?;
329        system_messages_entity
330            .insert_system_messages(account_id, &data.system_messages)?;
331
332        tx.commit()?;
333
334        Ok(())
335    }
336}
337
338/// Start importing a backup archive.
339///
340/// Reads the archive manifest and extracts the archive database file
341/// to a temporary file and prepares the database connections.
342///
343/// The returned struct will hold the temporary file and connections
344/// in memory until dropped and can be used to inspect the accounts in the
345/// archive and perform imports.
346pub(crate) async fn start<'conn>(
347    target_db: &'conn mut Connection,
348    paths: &Paths,
349    input: impl AsRef<Path>,
350    // progress: fn(backup::Progress),
351) -> Result<BackupImport<'conn>> {
352    if !vfs::try_exists(input.as_ref()).await? {
353        return Err(Error::ArchiveFileNotExists(input.as_ref().to_owned()));
354    }
355
356    let zip_file = BufReader::new(vfs::File::open(input.as_ref()).await?);
357    let mut zip_reader = Reader::new(zip_file).await?;
358    let manifest = zip_reader.find_manifest().await?.ok_or_else(|| {
359        Error::InvalidArchiveManifest(input.as_ref().to_owned())
360    })?;
361
362    let blobs = zip_reader.find_blobs()?;
363
364    // Extract the database and write to a temp file
365    let db_buffer =
366        zip_reader.by_name(DATABASE_FILE).await?.ok_or_else(|| {
367            Error::NoDatabaseFile(
368                input.as_ref().to_owned(),
369                DATABASE_FILE.to_owned(),
370            )
371        })?;
372    let mut db_temp = NamedTempFile::new()?;
373
374    let checksum = {
375        let buf_writer = BufWriter::new(db_temp.as_file_mut());
376        let mut hash_writer = HashingWriter {
377            inner: buf_writer,
378            hasher: Sha256::new(),
379        };
380
381        hash_writer.write_all(&db_buffer)?;
382        hash_writer.flush()?;
383
384        let digest = hash_writer.hasher.finalize();
385        CommitHash(digest.as_slice().try_into()?)
386    };
387
388    if checksum != manifest.checksum {
389        return Err(Error::DatabaseChecksum(manifest.checksum, checksum));
390    }
391
392    let source_db = Connection::open(db_temp.path())?;
393    let import = BackupImport {
394        target_db,
395        paths: paths.clone(),
396        manifest,
397        db_temp,
398        source_db: Box::new(source_db),
399        blobs,
400        zip_reader,
401    };
402
403    Ok(import)
404}