sos_database/entity/
account.rs

1use crate::{
2    entity::{FolderEntity, FolderRecord, SecretRow},
3    Error, Result,
4};
5use async_sqlite::{
6    rusqlite::{
7        Connection, Error as SqlError, OptionalExtension, Row, Transaction,
8    },
9    Client,
10};
11use sos_core::{AccountId, PublicIdentity, UtcDateTime, VaultCommit};
12use sos_vault::Vault;
13use sql_query_builder as sql;
14use std::ops::Deref;
15
16use super::FolderRow;
17
/// Raw account row as read from or written to the `accounts` table.
///
/// Date and time values are kept in their RFC3339 string form; convert
/// the row into an [`AccountRecord`] to obtain typed values.
#[doc(hidden)]
#[derive(Debug, Default)]
pub struct AccountRow {
    /// Row identifier (the `account_id` column).
    pub row_id: i64,
    /// Creation date and time as an RFC3339 string.
    created_at: String,
    /// Last modification date and time as an RFC3339 string.
    modified_at: String,
    /// Account identifier in its string form.
    identifier: String,
    /// Account name.
    name: String,
}
33
34impl AccountRow {
35    /// Create an account row for insertion.
36    pub fn new_insert(account_id: &AccountId, name: String) -> Result<Self> {
37        Ok(AccountRow {
38            identifier: account_id.to_string(),
39            name,
40            created_at: UtcDateTime::default().to_rfc3339()?,
41            modified_at: UtcDateTime::default().to_rfc3339()?,
42            ..Default::default()
43        })
44    }
45}
46
47impl<'a> TryFrom<&Row<'a>> for AccountRow {
48    type Error = SqlError;
49    fn try_from(row: &Row<'a>) -> std::result::Result<Self, Self::Error> {
50        Ok(AccountRow {
51            row_id: row.get(0)?,
52            created_at: row.get(1)?,
53            modified_at: row.get(2)?,
54            identifier: row.get(3)?,
55            name: row.get(4)?,
56        })
57    }
58}
59
60/// Account record from the database.
61#[derive(Debug)]
62pub struct AccountRecord {
63    /// Row identifier.
64    pub row_id: i64,
65    /// Created date and time.
66    pub created_at: UtcDateTime,
67    /// Modified date and time.
68    pub modified_at: UtcDateTime,
69    /// Account identity.
70    pub identity: PublicIdentity,
71}
72
73impl TryFrom<AccountRow> for AccountRecord {
74    type Error = Error;
75
76    fn try_from(value: AccountRow) -> std::result::Result<Self, Self::Error> {
77        let created_at = UtcDateTime::parse_rfc3339(&value.created_at)?;
78        let modified_at = UtcDateTime::parse_rfc3339(&value.modified_at)?;
79        let account_id: AccountId = value.identifier.parse()?;
80        Ok(AccountRecord {
81            row_id: value.row_id,
82            created_at,
83            modified_at,
84            identity: PublicIdentity::new(account_id, value.name),
85        })
86    }
87}
88
89/// Account entity.
90pub struct AccountEntity<'conn, C>
91where
92    C: Deref<Target = Connection>,
93{
94    conn: &'conn C,
95}
96
97impl<'conn> AccountEntity<'conn, Box<Connection>> {
98    /// Liat all accounts.
99    pub async fn list_all_accounts(
100        client: &Client,
101    ) -> Result<Vec<AccountRecord>> {
102        let account_rows = client
103            .conn_and_then(move |conn| {
104                let account = AccountEntity::new(&conn);
105                account.list_accounts()
106            })
107            .await?;
108
109        let mut accounts = Vec::new();
110        for row in account_rows {
111            accounts.push(row.try_into()?);
112        }
113        Ok(accounts)
114    }
115
116    /// Find an account and login folder.
117    pub async fn find_account_with_login(
118        client: &Client,
119        account_id: &AccountId,
120    ) -> Result<(AccountRecord, FolderRecord)> {
121        let (account, folder_row) =
122            Self::find_account_with_login_optional(client, account_id)
123                .await?;
124
125        let account_id = account.row_id;
126        Ok((
127            account,
128            folder_row.ok_or_else(|| Error::NoLoginFolder(account_id))?,
129        ))
130    }
131
132    /// Find an account and optional login folder.
133    pub async fn find_account_with_login_optional(
134        client: &Client,
135        account_id: &AccountId,
136    ) -> Result<(AccountRecord, Option<FolderRecord>)> {
137        let account_id = *account_id;
138        let (account_row, folder_row) = client
139            .conn_and_then(move |conn| {
140                let account = AccountEntity::new(&conn);
141                let account_row = account.find_one(&account_id)?;
142                let folders = FolderEntity::new(&conn);
143                let folder_row =
144                    folders.find_login_folder_optional(account_row.row_id)?;
145                Ok::<_, Error>((account_row, folder_row))
146            })
147            .await?;
148
149        let login_folder = if let Some(folder_row) = folder_row {
150            Some(FolderRecord::from_row(folder_row).await?)
151        } else {
152            None
153        };
154        Ok((account_row.try_into()?, login_folder))
155    }
156}
157
158impl<'conn> AccountEntity<'conn, Transaction<'conn>> {
159    /// Upsert the login folder.
160    pub async fn upsert_login_folder(
161        client: &Client,
162        account_id: &AccountId,
163        vault: &Vault,
164    ) -> Result<(AccountRecord, i64)> {
165        // Check if we already have a login folder
166        let (account, folder) =
167            AccountEntity::find_account_with_login_optional(
168                client, account_id,
169            )
170            .await?;
171
172        // TODO: folder creation and join should be merged into a single
173        // TODO: transaction
174
175        // Create or update the folder and secrets
176        let (folder_row_id, _) = FolderEntity::upsert_folder_and_secrets(
177            client,
178            account.row_id,
179            vault,
180        )
181        .await?;
182
183        let account_row_id = account.row_id;
184
185        // Update or insert the join
186        if folder.is_some() {
187            client
188                .conn(move |conn| {
189                    let account_entity = AccountEntity::new(&conn);
190                    account_entity
191                        .update_login_folder(account_row_id, folder_row_id)
192                })
193                .await?;
194        } else {
195            client
196                .conn(move |conn| {
197                    let account_entity = AccountEntity::new(&conn);
198                    account_entity
199                        .insert_login_folder(account_row_id, folder_row_id)
200                })
201                .await?;
202        }
203
204        Ok((account, folder_row_id))
205    }
206
207    /// Replace the login folder.
208    pub async fn replace_login_folder(
209        client: &mut Client,
210        account_id: &AccountId,
211        vault: &Vault,
212    ) -> Result<()> {
213        // Check if we already have a login folder
214        let (account, login_folder) =
215            AccountEntity::find_account_with_login(client, account_id)
216                .await?;
217
218        let login_folder_id = *login_folder.summary.id();
219        let new_login_folder = FolderRow::new_insert(vault).await?;
220
221        let mut secret_rows = Vec::new();
222        for (secret_id, commit) in vault.iter() {
223            let VaultCommit(commit, entry) = commit;
224            secret_rows.push(SecretRow::new(secret_id, commit, entry).await?);
225        }
226
227        client
228            .conn_mut_and_then(move |conn| {
229                let tx = conn.transaction()?;
230                let account_entity = AccountEntity::new(&tx);
231                let folder_entity = FolderEntity::new(&tx);
232
233                // Delete the old folder
234                folder_entity.delete_folder(&login_folder_id)?;
235
236                // Create the new folder
237                let folder_row_id = folder_entity
238                    .insert_folder(account.row_id, &new_login_folder)?;
239
240                // Insert the secrets
241                folder_entity.insert_folder_secrets(
242                    folder_row_id,
243                    secret_rows.as_slice(),
244                )?;
245
246                // Update the join
247                account_entity
248                    .update_login_folder(account.row_id, folder_row_id)?;
249
250                tx.commit()?;
251                Ok::<_, Error>(())
252            })
253            .await?;
254
255        Ok(())
256    }
257}
258
259impl<'conn, C> AccountEntity<'conn, C>
260where
261    C: Deref<Target = Connection>,
262{
263    /// Create a new account entity.
264    pub fn new(conn: &'conn C) -> Self {
265        Self { conn }
266    }
267
268    fn account_select_columns(&self, sql: sql::Select) -> sql::Select {
269        sql.select(
270            r#"
271                account_id,
272                created_at,
273                modified_at,
274                identifier,
275                name
276            "#,
277        )
278    }
279
280    /// Find an account in the database.
281    pub fn find_one(
282        &self,
283        account_id: &AccountId,
284    ) -> std::result::Result<AccountRow, SqlError> {
285        let query = self
286            .account_select_columns(sql::Select::new())
287            .from("accounts")
288            .where_clause("identifier = ?1");
289        let mut stmt = self.conn.prepare_cached(&query.as_string())?;
290        stmt.query_row([account_id.to_string()], |row| row.try_into())
291    }
292
293    /// Find an optional account in the database.
294    pub fn find_optional(
295        &self,
296        account_id: &AccountId,
297    ) -> std::result::Result<Option<AccountRow>, SqlError> {
298        let query = self
299            .account_select_columns(sql::Select::new())
300            .from("accounts")
301            .where_clause("identifier = ?1");
302        let mut stmt = self.conn.prepare_cached(&query.as_string())?;
303        stmt.query_row([account_id.to_string()], |row| row.try_into())
304            .optional()
305    }
306
307    /// List accounts.
308    pub fn list_accounts(&self) -> Result<Vec<AccountRow>> {
309        let query = self
310            .account_select_columns(sql::Select::new())
311            .from("accounts");
312
313        let mut stmt = self.conn.prepare_cached(&query.as_string())?;
314
315        fn convert_row(row: &Row<'_>) -> Result<AccountRow> {
316            Ok(row.try_into()?)
317        }
318
319        let rows = stmt.query_and_then([], convert_row)?;
320        let mut accounts = Vec::new();
321        for row in rows {
322            accounts.push(row?);
323        }
324        Ok(accounts)
325    }
326
327    /// Create the account entity in the database.
328    pub fn insert(
329        &self,
330        row: &AccountRow,
331    ) -> std::result::Result<i64, SqlError> {
332        let query = sql::Insert::new()
333            .insert_into(
334                "accounts (created_at, modified_at, identifier, name)",
335            )
336            .values("(?1, ?2, ?3, ?4)");
337        self.conn.execute(
338            &query.as_string(),
339            (
340                &row.created_at,
341                &row.modified_at,
342                &row.identifier,
343                &row.name,
344            ),
345        )?;
346        Ok(self.conn.last_insert_rowid())
347    }
348
349    /// Create the join for the account login folder.
350    pub fn insert_login_folder(
351        &self,
352        account_id: i64,
353        folder_id: i64,
354    ) -> std::result::Result<i64, SqlError> {
355        let query = sql::Insert::new()
356            .insert_into("account_login_folder (account_id, folder_id)")
357            .values("(?1, ?2)");
358        self.conn
359            .execute(&query.as_string(), [account_id, folder_id])?;
360        Ok(self.conn.last_insert_rowid())
361    }
362
363    /// Update the join for an account login folder.
364    pub fn update_login_folder(
365        &self,
366        account_id: i64,
367        folder_id: i64,
368    ) -> std::result::Result<(), SqlError> {
369        let query = sql::Update::new()
370            .update("account_login_folder")
371            .set("folder_id = ?2")
372            .where_clause("account_id = ?1");
373        self.conn
374            .execute(&query.as_string(), [account_id, folder_id])?;
375        Ok(())
376    }
377
378    /// Create the join for the account device folder.
379    pub fn insert_device_folder(
380        &self,
381        account_id: i64,
382        folder_id: i64,
383    ) -> std::result::Result<i64, SqlError> {
384        let query = sql::Insert::new()
385            .insert_into("account_device_folder (account_id, folder_id)")
386            .values("(?1, ?2)");
387        self.conn
388            .execute(&query.as_string(), [account_id, folder_id])?;
389        Ok(self.conn.last_insert_rowid())
390    }
391
392    /// Rename the account.
393    pub fn rename_account(&self, account_id: i64, name: &str) -> Result<()> {
394        let modified_at = UtcDateTime::default().to_rfc3339()?;
395        let query = sql::Update::new()
396            .update("accounts")
397            .set("name = ?1, modified_at = ?2")
398            .where_clause("account_id = ?3");
399        let mut stmt = self.conn.prepare_cached(&query.as_string())?;
400        stmt.execute((name, modified_at, account_id))?;
401        Ok(())
402    }
403
404    /// Delete the account from the database.
405    pub fn delete_account(
406        &self,
407        account_id: &AccountId,
408    ) -> std::result::Result<(), SqlError> {
409        let query = sql::Delete::new()
410            .delete_from("accounts")
411            .where_clause("identifier = ?1");
412        self.conn
413            .execute(&query.as_string(), [account_id.to_string()])?;
414        Ok(())
415    }
416}