sos_database/entity/
account.rs

1use crate::{
2    entity::{FolderEntity, FolderRecord, SecretRow},
3    Error, Result,
4};
5use async_sqlite::{
6    rusqlite::{
7        Connection, Error as SqlError, OptionalExtension, Row, Transaction,
8    },
9    Client,
10};
11use sos_core::{AccountId, PublicIdentity, UtcDateTime, VaultCommit};
12use sos_vault::Vault;
13use sql_query_builder as sql;
14use std::ops::Deref;
15
16use super::FolderRow;
17
/// Account row from the database.
///
/// Raw, untyped representation of a row in the `accounts` table; convert it
/// into an [`AccountRecord`] to get parsed timestamps and identifier.
#[doc(hidden)]
#[derive(Default, Debug)]
pub struct AccountRow {
    /// Row identifier (the `account_id` column).
    pub row_id: i64,
    /// Creation time as an RFC3339 string.
    created_at: String,
    /// Last modification time as an RFC3339 string.
    modified_at: String,
    /// Account identifier in its string form.
    identifier: String,
    /// Account name.
    name: String,
}
33
34impl AccountRow {
35    /// Create an account row for insertion.
36    pub fn new_insert(account_id: &AccountId, name: String) -> Result<Self> {
37        Ok(AccountRow {
38            identifier: account_id.to_string(),
39            name,
40            created_at: UtcDateTime::default().to_rfc3339()?,
41            modified_at: UtcDateTime::default().to_rfc3339()?,
42            ..Default::default()
43        })
44    }
45}
46
47impl<'a> TryFrom<&Row<'a>> for AccountRow {
48    type Error = SqlError;
49    fn try_from(row: &Row<'a>) -> std::result::Result<Self, Self::Error> {
50        Ok(AccountRow {
51            row_id: row.get(0)?,
52            created_at: row.get(1)?,
53            modified_at: row.get(2)?,
54            identifier: row.get(3)?,
55            name: row.get(4)?,
56        })
57    }
58}
59
60/// Account record from the database.
61#[derive(Debug)]
62pub struct AccountRecord {
63    /// Row identifier.
64    pub row_id: i64,
65    /// Created date and time.
66    pub created_at: UtcDateTime,
67    /// Modified date and time.
68    pub modified_at: UtcDateTime,
69    /// Account identity.
70    pub identity: PublicIdentity,
71}
72
73impl TryFrom<AccountRow> for AccountRecord {
74    type Error = Error;
75
76    fn try_from(value: AccountRow) -> std::result::Result<Self, Self::Error> {
77        let created_at = UtcDateTime::parse_rfc3339(&value.created_at)?;
78        let modified_at = UtcDateTime::parse_rfc3339(&value.modified_at)?;
79        let account_id: AccountId = value.identifier.parse()?;
80        Ok(AccountRecord {
81            row_id: value.row_id,
82            created_at,
83            modified_at,
84            identity: PublicIdentity::new(account_id, value.name),
85        })
86    }
87}
88
89/// Account entity.
90pub struct AccountEntity<'conn, C>
91where
92    C: Deref<Target = Connection>,
93{
94    conn: &'conn C,
95}
96
97impl<'conn> AccountEntity<'conn, Box<Connection>> {
98    /// Liat all accounts.
99    pub async fn list_all_accounts(
100        client: &Client,
101    ) -> Result<Vec<AccountRecord>> {
102        let account_rows = client
103            .conn_and_then(move |conn| {
104                let account = AccountEntity::new(&conn);
105                account.list_accounts()
106            })
107            .await?;
108
109        let mut accounts = Vec::new();
110        for row in account_rows {
111            accounts.push(row.try_into()?);
112        }
113        Ok(accounts)
114    }
115
116    /// Find an account and login folder.
117    pub async fn find_account_with_login(
118        client: &Client,
119        account_id: &AccountId,
120    ) -> Result<(AccountRecord, FolderRecord)> {
121        let (account, folder_row) =
122            Self::find_account_with_login_optional(client, account_id)
123                .await?;
124
125        let account_id = account.row_id;
126        Ok((
127            account,
128            folder_row.ok_or_else(|| Error::NoLoginFolder(account_id))?,
129        ))
130    }
131
132    /// Find an account and optional login folder.
133    pub async fn find_account_with_login_optional(
134        client: &Client,
135        account_id: &AccountId,
136    ) -> Result<(AccountRecord, Option<FolderRecord>)> {
137        let account_id = *account_id;
138        let (account_row, folder_row) = client
139            .conn_and_then(move |conn| {
140                let account = AccountEntity::new(&conn);
141                let account_row = account.find_one(&account_id)?;
142                let folders = FolderEntity::new(&conn);
143                let folder_row =
144                    folders.find_login_folder_optional(account_row.row_id)?;
145                Ok::<_, Error>((account_row, folder_row))
146            })
147            .await?;
148
149        let login_folder = if let Some(folder_row) = folder_row {
150            Some(FolderRecord::from_row(folder_row).await?)
151        } else {
152            None
153        };
154        Ok((account_row.try_into()?, login_folder))
155    }
156}
157
158impl<'conn> AccountEntity<'conn, Transaction<'conn>> {
159    /// Upsert the login folder.
160    pub async fn upsert_login_folder(
161        client: &Client,
162        account_id: &AccountId,
163        vault: &Vault,
164    ) -> Result<(AccountRecord, i64)> {
165        // Check if we already have a login folder
166        let (account, folder) =
167            AccountEntity::find_account_with_login_optional(
168                client, account_id,
169            )
170            .await?;
171
172        // TODO: folder creation and join should be merged into a single
173        // TODO: transaction
174
175        // Create or update the folder and secrets
176        let (folder_row_id, _) = FolderEntity::upsert_folder_and_secrets(
177            client,
178            account.row_id,
179            vault,
180        )
181        .await?;
182
183        let account_row_id = account.row_id;
184
185        // Update or insert the join
186        if folder.is_some() {
187            client
188                .conn(move |conn| {
189                    let account_entity = AccountEntity::new(&conn);
190                    account_entity
191                        .update_login_folder(account_row_id, folder_row_id)
192                })
193                .await?;
194        } else {
195            client
196                .conn(move |conn| {
197                    let account_entity = AccountEntity::new(&conn);
198                    account_entity
199                        .insert_login_folder(account_row_id, folder_row_id)
200                })
201                .await?;
202        }
203
204        Ok((account, folder_row_id))
205    }
206
207    /// Replace the login folder.
208    pub async fn replace_login_folder(
209        client: &mut Client,
210        account_id: &AccountId,
211        vault: &Vault,
212    ) -> Result<()> {
213        // Check if we already have a login folder
214        let (account, login_folder) =
215            AccountEntity::find_account_with_login(client, account_id)
216                .await?;
217
218        let login_folder_id = *login_folder.summary.id();
219        let new_login_folder = FolderRow::new_insert(vault).await?;
220
221        let mut secret_rows = Vec::new();
222        for (secret_id, commit) in vault.iter() {
223            let VaultCommit(commit, entry) = commit;
224            secret_rows.push(SecretRow::new(secret_id, commit, entry).await?);
225        }
226
227        client
228            .conn_mut_and_then(move |conn| {
229                let tx = conn.transaction()?;
230                let account_entity = AccountEntity::new(&tx);
231                let folder_entity = FolderEntity::new(&tx);
232
233                // Delete the old folder
234                folder_entity.delete_folder(&login_folder_id)?;
235
236                // Create the new folder
237                let folder_row_id = folder_entity
238                    .insert_folder(account.row_id, &new_login_folder)?;
239
240                // Insert the secrets
241                folder_entity.insert_folder_secrets(
242                    folder_row_id,
243                    secret_rows.as_slice(),
244                )?;
245
246                // Update the join
247                account_entity
248                    .update_login_folder(account.row_id, folder_row_id)?;
249
250                tx.commit()?;
251                Ok::<_, Error>(())
252            })
253            .await?;
254
255        Ok(())
256    }
257}
258
259impl<'conn, C> AccountEntity<'conn, C>
260where
261    C: Deref<Target = Connection>,
262{
263    /// Create a new account entity.
264    pub fn new(conn: &'conn C) -> Self {
265        Self { conn }
266    }
267
268    fn account_select_columns(&self, sql: sql::Select) -> sql::Select {
269        sql.select(
270            r#"
271                account_id,
272                created_at,
273                modified_at,
274                identifier,
275                name
276            "#,
277        )
278    }
279
280    /// Find an account in the database.
281    pub fn find_one(
282        &self,
283        account_id: &AccountId,
284    ) -> std::result::Result<AccountRow, SqlError> {
285        let query = self
286            .account_select_columns(sql::Select::new())
287            .from("accounts")
288            .where_clause("identifier = ?1");
289        let mut stmt = self.conn.prepare_cached(&query.as_string())?;
290        stmt
291            .query_row([account_id.to_string()], |row| row.try_into())
292    }
293
294    /// Find an optional account in the database.
295    pub fn find_optional(
296        &self,
297        account_id: &AccountId,
298    ) -> std::result::Result<Option<AccountRow>, SqlError> {
299        let query = self
300            .account_select_columns(sql::Select::new())
301            .from("accounts")
302            .where_clause("identifier = ?1");
303        let mut stmt = self.conn.prepare_cached(&query.as_string())?;
304        stmt
305            .query_row([account_id.to_string()], |row| row.try_into())
306            .optional()
307    }
308
309    /// List accounts.
310    pub fn list_accounts(&self) -> Result<Vec<AccountRow>> {
311        let query = self
312            .account_select_columns(sql::Select::new())
313            .from("accounts");
314
315        let mut stmt = self.conn.prepare_cached(&query.as_string())?;
316
317        fn convert_row(row: &Row<'_>) -> Result<AccountRow> {
318            Ok(row.try_into()?)
319        }
320
321        let rows = stmt.query_and_then([], convert_row)?;
322        let mut accounts = Vec::new();
323        for row in rows {
324            accounts.push(row?);
325        }
326        Ok(accounts)
327    }
328
329    /// Create the account entity in the database.
330    pub fn insert(
331        &self,
332        row: &AccountRow,
333    ) -> std::result::Result<i64, SqlError> {
334        let query = sql::Insert::new()
335            .insert_into(
336                "accounts (created_at, modified_at, identifier, name)",
337            )
338            .values("(?1, ?2, ?3, ?4)");
339        self.conn.execute(
340            &query.as_string(),
341            (
342                &row.created_at,
343                &row.modified_at,
344                &row.identifier,
345                &row.name,
346            ),
347        )?;
348        Ok(self.conn.last_insert_rowid())
349    }
350
351    /// Create the join for the account login folder.
352    pub fn insert_login_folder(
353        &self,
354        account_id: i64,
355        folder_id: i64,
356    ) -> std::result::Result<i64, SqlError> {
357        let query = sql::Insert::new()
358            .insert_into("account_login_folder (account_id, folder_id)")
359            .values("(?1, ?2)");
360        self.conn
361            .execute(&query.as_string(), [account_id, folder_id])?;
362        Ok(self.conn.last_insert_rowid())
363    }
364
365    /// Update the join for an account login folder.
366    pub fn update_login_folder(
367        &self,
368        account_id: i64,
369        folder_id: i64,
370    ) -> std::result::Result<(), SqlError> {
371        let query = sql::Update::new()
372            .update("account_login_folder")
373            .set("folder_id = ?2")
374            .where_clause("account_id = ?1");
375        self.conn
376            .execute(&query.as_string(), [account_id, folder_id])?;
377        Ok(())
378    }
379
380    /// Create the join for the account device folder.
381    pub fn insert_device_folder(
382        &self,
383        account_id: i64,
384        folder_id: i64,
385    ) -> std::result::Result<i64, SqlError> {
386        let query = sql::Insert::new()
387            .insert_into("account_device_folder (account_id, folder_id)")
388            .values("(?1, ?2)");
389        self.conn
390            .execute(&query.as_string(), [account_id, folder_id])?;
391        Ok(self.conn.last_insert_rowid())
392    }
393
394    /// Rename the account.
395    pub fn rename_account(&self, account_id: i64, name: &str) -> Result<()> {
396        let modified_at = UtcDateTime::default().to_rfc3339()?;
397        let query = sql::Update::new()
398            .update("accounts")
399            .set("name = ?1, modified_at = ?2")
400            .where_clause("account_id = ?3");
401        let mut stmt = self.conn.prepare_cached(&query.as_string())?;
402        stmt.execute((name, modified_at, account_id))?;
403        Ok(())
404    }
405
406    /// Delete the account from the database.
407    pub fn delete_account(
408        &self,
409        account_id: &AccountId,
410    ) -> std::result::Result<(), SqlError> {
411        let query = sql::Delete::new()
412            .delete_from("accounts")
413            .where_clause("identifier = ?1");
414        self.conn
415            .execute(&query.as_string(), [account_id.to_string()])?;
416        Ok(())
417    }
418}