sqlx_ledger/balance/
repo.rs

1use sqlx::{PgPool, Postgres, QueryBuilder, Row, Transaction};
2use tracing::instrument;
3use uuid::Uuid;
4
5use std::{collections::HashMap, str::FromStr};
6
7use super::entity::*;
8use crate::{error::*, primitives::*};
9
/// Repository for reading and updating `AccountBalance` entities.
///
/// The public lookups (`find`, `find_all`) run against the stored pool, while the
/// crate-internal `find_for_update` / `update_balances` run inside a transaction
/// supplied by the caller.
#[derive(Debug, Clone)]
pub struct Balances {
    /// Connection pool used by the lookups that are not given an explicit transaction.
    pool: PgPool,
}
15
16impl Balances {
17    pub fn new(pool: &PgPool) -> Self {
18        Self { pool: pool.clone() }
19    }
20
21    #[instrument(name = "sqlx_ledger.balances.find", skip(self))]
22    pub async fn find(
23        &self,
24        journal_id: JournalId,
25        account_id: AccountId,
26        currency: Currency,
27    ) -> Result<Option<AccountBalance>, SqlxLedgerError> {
28        let record = sqlx::query!(
29            r#"SELECT
30              a.normal_balance_type as "normal_balance_type: DebitOrCredit", b.journal_id, b.account_id, entry_id, b.currency,
31              settled_dr_balance, settled_cr_balance, settled_entry_id, settled_modified_at,
32              pending_dr_balance, pending_cr_balance, pending_entry_id, pending_modified_at,
33              encumbered_dr_balance, encumbered_cr_balance, encumbered_entry_id, encumbered_modified_at,
34              c.version, modified_at, created_at
35                FROM sqlx_ledger_balances b JOIN (
36                  SELECT * FROM sqlx_ledger_current_balances WHERE journal_id = $1 AND account_id = $2 AND currency = $3 ) c
37                ON b.journal_id = c.journal_id AND b.account_id = c.account_id AND b.currency = c.currency AND b.version = c.version
38                JOIN ( SELECT id, normal_balance_type FROM sqlx_ledger_accounts WHERE id = $2 LIMIT 1 ) a
39                  ON a.id = b.account_id"#,
40            journal_id as JournalId,
41            account_id as AccountId,
42            currency.code()
43        )
44        .fetch_optional(&self.pool)
45        .await?;
46        Ok(record.map(|record| AccountBalance {
47            balance_type: record.normal_balance_type,
48            details: BalanceDetails {
49                journal_id,
50                account_id,
51                entry_id: EntryId::from(record.entry_id),
52                currency,
53                settled_dr_balance: record.settled_dr_balance,
54                settled_cr_balance: record.settled_cr_balance,
55                settled_entry_id: EntryId::from(record.settled_entry_id),
56                settled_modified_at: record.settled_modified_at,
57                pending_dr_balance: record.pending_dr_balance,
58                pending_cr_balance: record.pending_cr_balance,
59                pending_entry_id: EntryId::from(record.pending_entry_id),
60                pending_modified_at: record.pending_modified_at,
61                encumbered_dr_balance: record.encumbered_dr_balance,
62                encumbered_cr_balance: record.encumbered_cr_balance,
63                encumbered_entry_id: EntryId::from(record.encumbered_entry_id),
64                encumbered_modified_at: record.encumbered_modified_at,
65                version: record.version,
66                modified_at: record.modified_at,
67                created_at: record.created_at,
68            },
69        }))
70    }
71
72    #[instrument(name = "sqlx_ledger.balances.find_all", skip(self, accounts))]
73    pub async fn find_all(
74        &self,
75        journal_id: JournalId,
76        accounts: impl IntoIterator<Item = AccountId>,
77    ) -> Result<HashMap<AccountId, HashMap<Currency, AccountBalance>>, SqlxLedgerError> {
78        let account_ids: Vec<Uuid> = accounts.into_iter().map(Uuid::from).collect();
79        let rows = sqlx::query!(
80            r#"SELECT
81              a.normal_balance_type as "normal_balance_type: DebitOrCredit", b.journal_id, b.account_id, entry_id, b.currency,
82              settled_dr_balance, settled_cr_balance, settled_entry_id, settled_modified_at,
83              pending_dr_balance, pending_cr_balance, pending_entry_id, pending_modified_at,
84              encumbered_dr_balance, encumbered_cr_balance, encumbered_entry_id, encumbered_modified_at,
85              c.version, modified_at, created_at
86                FROM sqlx_ledger_balances b JOIN (
87                  SELECT * FROM sqlx_ledger_current_balances WHERE journal_id = $1 AND account_id = ANY($2)) c
88                ON b.journal_id = c.journal_id AND b.account_id = c.account_id AND b.currency = c.currency AND b.version = c.version
89                JOIN ( SELECT DISTINCT(id), normal_balance_type FROM sqlx_ledger_accounts WHERE id = ANY($2)) a
90                  ON a.id = b.account_id"#,
91            journal_id as JournalId,
92            &account_ids[..]
93        )
94        .fetch_all(&self.pool)
95        .await?;
96        let mut ret = HashMap::new();
97        for row in rows {
98            ret.entry(AccountId::from(row.account_id))
99                .or_insert_with(HashMap::new)
100                .insert(
101                    row.currency.parse().expect("Currency code is invalid"),
102                    AccountBalance {
103                        balance_type: row.normal_balance_type,
104                        details: BalanceDetails {
105                            journal_id,
106                            account_id: AccountId::from(row.account_id),
107                            entry_id: EntryId::from(row.entry_id),
108                            currency: row.currency.parse().unwrap(),
109                            settled_dr_balance: row.settled_dr_balance,
110                            settled_cr_balance: row.settled_cr_balance,
111                            settled_entry_id: EntryId::from(row.settled_entry_id),
112                            settled_modified_at: row.settled_modified_at,
113                            pending_dr_balance: row.pending_dr_balance,
114                            pending_cr_balance: row.pending_cr_balance,
115                            pending_entry_id: EntryId::from(row.pending_entry_id),
116                            pending_modified_at: row.pending_modified_at,
117                            encumbered_dr_balance: row.encumbered_dr_balance,
118                            encumbered_cr_balance: row.encumbered_cr_balance,
119                            encumbered_entry_id: EntryId::from(row.encumbered_entry_id),
120                            encumbered_modified_at: row.encumbered_modified_at,
121                            version: row.version,
122                            modified_at: row.modified_at,
123                            created_at: row.created_at,
124                        },
125                    },
126                );
127        }
128        Ok(ret)
129    }
130
131    #[instrument(
132        level = "trace",
133        name = "sqlx_ledger.balances.find_for_update",
134        skip(self, tx)
135    )]
136    pub(crate) async fn find_for_update<'a>(
137        &self,
138        journal_id: JournalId,
139        ids: Vec<(AccountId, &Currency)>,
140        tx: &mut Transaction<'a, Postgres>,
141    ) -> Result<HashMap<(AccountId, Currency), BalanceDetails>, SqlxLedgerError> {
142        let mut query_builder: QueryBuilder<Postgres> = QueryBuilder::new(
143            r#"SELECT
144              b.journal_id, b.account_id, entry_id, b.currency,
145              settled_dr_balance, settled_cr_balance, settled_entry_id, settled_modified_at,
146              pending_dr_balance, pending_cr_balance, pending_entry_id, pending_modified_at,
147              encumbered_dr_balance, encumbered_cr_balance, encumbered_entry_id, encumbered_modified_at,
148              c.version, modified_at, created_at
149                FROM sqlx_ledger_balances b JOIN (
150                    SELECT * FROM sqlx_ledger_current_balances WHERE journal_id = "#,
151        );
152        query_builder.push_bind(journal_id);
153        query_builder.push(r#" AND (account_id, currency) IN"#);
154        query_builder.push_tuples(ids, |mut builder, (id, currency)| {
155            builder.push_bind(id);
156            builder.push_bind(currency.code());
157        });
158        query_builder.push(
159            r#"FOR UPDATE ) c ON
160                b.journal_id = c.journal_id AND b.account_id = c.account_id AND b.currency = c.currency AND b.version = c.version"#,
161        );
162
163        let query = query_builder.build();
164        let records = query.fetch_all(&mut **tx).await?;
165        let mut ret = HashMap::new();
166        for r in records {
167            let account_id = AccountId::from(r.get::<Uuid, _>("account_id"));
168            let currency =
169                Currency::from_str(r.get("currency")).expect("currency code should be valid");
170            ret.insert(
171                (account_id, currency),
172                BalanceDetails {
173                    account_id,
174                    journal_id: JournalId::from(r.get::<Uuid, _>("journal_id")),
175                    entry_id: EntryId::from(r.get::<Uuid, _>("entry_id")),
176                    currency: r.get::<&str, _>("currency").parse()?,
177                    settled_dr_balance: r.get("settled_dr_balance"),
178                    settled_cr_balance: r.get("settled_cr_balance"),
179                    settled_entry_id: EntryId::from(r.get::<Uuid, _>("settled_entry_id")),
180                    settled_modified_at: r.get("settled_modified_at"),
181                    pending_dr_balance: r.get("pending_dr_balance"),
182                    pending_cr_balance: r.get("pending_cr_balance"),
183                    pending_entry_id: EntryId::from(r.get::<Uuid, _>("pending_entry_id")),
184                    pending_modified_at: r.get("pending_modified_at"),
185                    encumbered_dr_balance: r.get("encumbered_dr_balance"),
186                    encumbered_cr_balance: r.get("encumbered_cr_balance"),
187                    encumbered_entry_id: EntryId::from(r.get::<Uuid, _>("encumbered_entry_id")),
188                    encumbered_modified_at: r.get("encumbered_modified_at"),
189                    version: r.get("version"),
190                    modified_at: r.get("modified_at"),
191                    created_at: r.get("created_at"),
192                },
193            );
194        }
195        Ok(ret)
196    }
197
198    #[instrument(
199        level = "trace",
200        name = "sqlx_ledger.balances.update_balances",
201        skip(self, tx)
202    )]
203    pub(crate) async fn update_balances<'a>(
204        &self,
205        journal_id: JournalId,
206        new_balances: Vec<BalanceDetails>,
207        tx: &mut Transaction<'a, Postgres>,
208    ) -> Result<(), SqlxLedgerError> {
209        let mut latest_versions = HashMap::new();
210        let mut previous_versions = HashMap::new();
211        for BalanceDetails {
212            account_id,
213            version,
214            currency,
215            ..
216        } in new_balances.iter()
217        {
218            latest_versions.insert((account_id, currency), version);
219            if previous_versions.contains_key(&(account_id, currency)) {
220                continue;
221            }
222            previous_versions.insert((account_id, currency), version - 1);
223        }
224        let expected_accounts_effected = latest_versions.len();
225        let mut query_builder: QueryBuilder<Postgres> = QueryBuilder::new(
226            r#"INSERT INTO sqlx_ledger_current_balances
227                  (journal_id, account_id, currency, version)"#,
228        );
229        let mut any_new = false;
230        query_builder.push_values(
231            previous_versions.iter().filter(|(_, v)| **v == 0),
232            |mut builder, ((account_id, currency), version)| {
233                any_new = true;
234                builder.push_bind(journal_id);
235                builder.push_bind(**account_id);
236                builder.push_bind(currency.code());
237                builder.push_bind(version);
238            },
239        );
240        if any_new {
241            query_builder.build().execute(&mut **tx).await?;
242        }
243        let mut query_builder: QueryBuilder<Postgres> =
244            QueryBuilder::new(r#"UPDATE sqlx_ledger_current_balances SET version = CASE"#);
245        let mut bind_numbers = HashMap::new();
246        let mut next_bind_number = 1;
247        for ((account_id, currency), version) in latest_versions {
248            bind_numbers.insert((account_id, currency), next_bind_number);
249            next_bind_number += 3;
250            query_builder.push(" WHEN account_id = ");
251            query_builder.push_bind(*account_id);
252            query_builder.push(" AND currency = ");
253            query_builder.push_bind(currency.code());
254            query_builder.push(" THEN ");
255            query_builder.push_bind(version);
256        }
257        query_builder.push(" END WHERE journal_id = ");
258        query_builder.push_bind(journal_id);
259        query_builder.push(" AND (account_id, currency, version) IN");
260        query_builder.push_tuples(
261            previous_versions,
262            |mut builder, ((account_id, currency), version)| {
263                let n = bind_numbers.remove(&(account_id, currency)).unwrap();
264                builder.push(format!("${}, ${}", n, n + 1));
265                builder.push_bind(version);
266            },
267        );
268        let result = query_builder.build().execute(&mut **tx).await?;
269        if result.rows_affected() != (expected_accounts_effected as u64) {
270            return Err(SqlxLedgerError::OptimisticLockingError);
271        }
272
273        let mut query_builder: QueryBuilder<Postgres> = QueryBuilder::new(
274            r#"INSERT INTO sqlx_ledger_balances (
275                 journal_id, account_id, entry_id, currency,
276                 settled_dr_balance, settled_cr_balance, settled_entry_id, settled_modified_at,
277                 pending_dr_balance, pending_cr_balance, pending_entry_id, pending_modified_at,
278                 encumbered_dr_balance, encumbered_cr_balance, encumbered_entry_id, encumbered_modified_at,
279                 version, modified_at, created_at)
280            "#,
281        );
282        query_builder.push_values(new_balances, |mut builder, b| {
283            builder.push_bind(b.journal_id);
284            builder.push_bind(b.account_id);
285            builder.push_bind(b.entry_id);
286            builder.push_bind(b.currency.code());
287            builder.push_bind(b.settled_dr_balance);
288            builder.push_bind(b.settled_cr_balance);
289            builder.push_bind(b.settled_entry_id);
290            builder.push_bind(b.settled_modified_at);
291            builder.push_bind(b.pending_dr_balance);
292            builder.push_bind(b.pending_cr_balance);
293            builder.push_bind(b.pending_entry_id);
294            builder.push_bind(b.pending_modified_at);
295            builder.push_bind(b.encumbered_dr_balance);
296            builder.push_bind(b.encumbered_cr_balance);
297            builder.push_bind(b.encumbered_entry_id);
298            builder.push_bind(b.encumbered_modified_at);
299            builder.push_bind(b.version);
300            builder.push_bind(b.modified_at);
301            builder.push_bind(b.created_at);
302        });
303        query_builder.build().execute(&mut **tx).await?;
304        Ok(())
305    }
306}