1use sqlx::{PgPool, Postgres, QueryBuilder, Row, Transaction};
2use tracing::instrument;
3use uuid::Uuid;
4
5use std::{collections::HashMap, str::FromStr};
6
7use super::entity::*;
8use crate::{error::*, primitives::*};
9
10#[derive(Debug, Clone)]
12pub struct Balances {
13 pool: PgPool,
14}
15
16impl Balances {
17 pub fn new(pool: &PgPool) -> Self {
18 Self { pool: pool.clone() }
19 }
20
21 #[instrument(name = "sqlx_ledger.balances.find", skip(self))]
22 pub async fn find(
23 &self,
24 journal_id: JournalId,
25 account_id: AccountId,
26 currency: Currency,
27 ) -> Result<Option<AccountBalance>, SqlxLedgerError> {
28 let record = sqlx::query!(
29 r#"SELECT
30 a.normal_balance_type as "normal_balance_type: DebitOrCredit", b.journal_id, b.account_id, entry_id, b.currency,
31 settled_dr_balance, settled_cr_balance, settled_entry_id, settled_modified_at,
32 pending_dr_balance, pending_cr_balance, pending_entry_id, pending_modified_at,
33 encumbered_dr_balance, encumbered_cr_balance, encumbered_entry_id, encumbered_modified_at,
34 c.version, modified_at, created_at
35 FROM sqlx_ledger_balances b JOIN (
36 SELECT * FROM sqlx_ledger_current_balances WHERE journal_id = $1 AND account_id = $2 AND currency = $3 ) c
37 ON b.journal_id = c.journal_id AND b.account_id = c.account_id AND b.currency = c.currency AND b.version = c.version
38 JOIN ( SELECT id, normal_balance_type FROM sqlx_ledger_accounts WHERE id = $2 LIMIT 1 ) a
39 ON a.id = b.account_id"#,
40 journal_id as JournalId,
41 account_id as AccountId,
42 currency.code()
43 )
44 .fetch_optional(&self.pool)
45 .await?;
46 Ok(record.map(|record| AccountBalance {
47 balance_type: record.normal_balance_type,
48 details: BalanceDetails {
49 journal_id,
50 account_id,
51 entry_id: EntryId::from(record.entry_id),
52 currency,
53 settled_dr_balance: record.settled_dr_balance,
54 settled_cr_balance: record.settled_cr_balance,
55 settled_entry_id: EntryId::from(record.settled_entry_id),
56 settled_modified_at: record.settled_modified_at,
57 pending_dr_balance: record.pending_dr_balance,
58 pending_cr_balance: record.pending_cr_balance,
59 pending_entry_id: EntryId::from(record.pending_entry_id),
60 pending_modified_at: record.pending_modified_at,
61 encumbered_dr_balance: record.encumbered_dr_balance,
62 encumbered_cr_balance: record.encumbered_cr_balance,
63 encumbered_entry_id: EntryId::from(record.encumbered_entry_id),
64 encumbered_modified_at: record.encumbered_modified_at,
65 version: record.version,
66 modified_at: record.modified_at,
67 created_at: record.created_at,
68 },
69 }))
70 }
71
72 #[instrument(name = "sqlx_ledger.balances.find_all", skip(self, accounts))]
73 pub async fn find_all(
74 &self,
75 journal_id: JournalId,
76 accounts: impl IntoIterator<Item = AccountId>,
77 ) -> Result<HashMap<AccountId, HashMap<Currency, AccountBalance>>, SqlxLedgerError> {
78 let account_ids: Vec<Uuid> = accounts.into_iter().map(Uuid::from).collect();
79 let rows = sqlx::query!(
80 r#"SELECT
81 a.normal_balance_type as "normal_balance_type: DebitOrCredit", b.journal_id, b.account_id, entry_id, b.currency,
82 settled_dr_balance, settled_cr_balance, settled_entry_id, settled_modified_at,
83 pending_dr_balance, pending_cr_balance, pending_entry_id, pending_modified_at,
84 encumbered_dr_balance, encumbered_cr_balance, encumbered_entry_id, encumbered_modified_at,
85 c.version, modified_at, created_at
86 FROM sqlx_ledger_balances b JOIN (
87 SELECT * FROM sqlx_ledger_current_balances WHERE journal_id = $1 AND account_id = ANY($2)) c
88 ON b.journal_id = c.journal_id AND b.account_id = c.account_id AND b.currency = c.currency AND b.version = c.version
89 JOIN ( SELECT DISTINCT(id), normal_balance_type FROM sqlx_ledger_accounts WHERE id = ANY($2)) a
90 ON a.id = b.account_id"#,
91 journal_id as JournalId,
92 &account_ids[..]
93 )
94 .fetch_all(&self.pool)
95 .await?;
96 let mut ret = HashMap::new();
97 for row in rows {
98 ret.entry(AccountId::from(row.account_id))
99 .or_insert_with(HashMap::new)
100 .insert(
101 row.currency.parse().expect("Currency code is invalid"),
102 AccountBalance {
103 balance_type: row.normal_balance_type,
104 details: BalanceDetails {
105 journal_id,
106 account_id: AccountId::from(row.account_id),
107 entry_id: EntryId::from(row.entry_id),
108 currency: row.currency.parse().unwrap(),
109 settled_dr_balance: row.settled_dr_balance,
110 settled_cr_balance: row.settled_cr_balance,
111 settled_entry_id: EntryId::from(row.settled_entry_id),
112 settled_modified_at: row.settled_modified_at,
113 pending_dr_balance: row.pending_dr_balance,
114 pending_cr_balance: row.pending_cr_balance,
115 pending_entry_id: EntryId::from(row.pending_entry_id),
116 pending_modified_at: row.pending_modified_at,
117 encumbered_dr_balance: row.encumbered_dr_balance,
118 encumbered_cr_balance: row.encumbered_cr_balance,
119 encumbered_entry_id: EntryId::from(row.encumbered_entry_id),
120 encumbered_modified_at: row.encumbered_modified_at,
121 version: row.version,
122 modified_at: row.modified_at,
123 created_at: row.created_at,
124 },
125 },
126 );
127 }
128 Ok(ret)
129 }
130
131 #[instrument(
132 level = "trace",
133 name = "sqlx_ledger.balances.find_for_update",
134 skip(self, tx)
135 )]
136 pub(crate) async fn find_for_update<'a>(
137 &self,
138 journal_id: JournalId,
139 ids: Vec<(AccountId, &Currency)>,
140 tx: &mut Transaction<'a, Postgres>,
141 ) -> Result<HashMap<(AccountId, Currency), BalanceDetails>, SqlxLedgerError> {
142 let mut query_builder: QueryBuilder<Postgres> = QueryBuilder::new(
143 r#"SELECT
144 b.journal_id, b.account_id, entry_id, b.currency,
145 settled_dr_balance, settled_cr_balance, settled_entry_id, settled_modified_at,
146 pending_dr_balance, pending_cr_balance, pending_entry_id, pending_modified_at,
147 encumbered_dr_balance, encumbered_cr_balance, encumbered_entry_id, encumbered_modified_at,
148 c.version, modified_at, created_at
149 FROM sqlx_ledger_balances b JOIN (
150 SELECT * FROM sqlx_ledger_current_balances WHERE journal_id = "#,
151 );
152 query_builder.push_bind(journal_id);
153 query_builder.push(r#" AND (account_id, currency) IN"#);
154 query_builder.push_tuples(ids, |mut builder, (id, currency)| {
155 builder.push_bind(id);
156 builder.push_bind(currency.code());
157 });
158 query_builder.push(
159 r#"FOR UPDATE ) c ON
160 b.journal_id = c.journal_id AND b.account_id = c.account_id AND b.currency = c.currency AND b.version = c.version"#,
161 );
162
163 let query = query_builder.build();
164 let records = query.fetch_all(&mut **tx).await?;
165 let mut ret = HashMap::new();
166 for r in records {
167 let account_id = AccountId::from(r.get::<Uuid, _>("account_id"));
168 let currency =
169 Currency::from_str(r.get("currency")).expect("currency code should be valid");
170 ret.insert(
171 (account_id, currency),
172 BalanceDetails {
173 account_id,
174 journal_id: JournalId::from(r.get::<Uuid, _>("journal_id")),
175 entry_id: EntryId::from(r.get::<Uuid, _>("entry_id")),
176 currency: r.get::<&str, _>("currency").parse()?,
177 settled_dr_balance: r.get("settled_dr_balance"),
178 settled_cr_balance: r.get("settled_cr_balance"),
179 settled_entry_id: EntryId::from(r.get::<Uuid, _>("settled_entry_id")),
180 settled_modified_at: r.get("settled_modified_at"),
181 pending_dr_balance: r.get("pending_dr_balance"),
182 pending_cr_balance: r.get("pending_cr_balance"),
183 pending_entry_id: EntryId::from(r.get::<Uuid, _>("pending_entry_id")),
184 pending_modified_at: r.get("pending_modified_at"),
185 encumbered_dr_balance: r.get("encumbered_dr_balance"),
186 encumbered_cr_balance: r.get("encumbered_cr_balance"),
187 encumbered_entry_id: EntryId::from(r.get::<Uuid, _>("encumbered_entry_id")),
188 encumbered_modified_at: r.get("encumbered_modified_at"),
189 version: r.get("version"),
190 modified_at: r.get("modified_at"),
191 created_at: r.get("created_at"),
192 },
193 );
194 }
195 Ok(ret)
196 }
197
198 #[instrument(
199 level = "trace",
200 name = "sqlx_ledger.balances.update_balances",
201 skip(self, tx)
202 )]
203 pub(crate) async fn update_balances<'a>(
204 &self,
205 journal_id: JournalId,
206 new_balances: Vec<BalanceDetails>,
207 tx: &mut Transaction<'a, Postgres>,
208 ) -> Result<(), SqlxLedgerError> {
209 let mut latest_versions = HashMap::new();
210 let mut previous_versions = HashMap::new();
211 for BalanceDetails {
212 account_id,
213 version,
214 currency,
215 ..
216 } in new_balances.iter()
217 {
218 latest_versions.insert((account_id, currency), version);
219 if previous_versions.contains_key(&(account_id, currency)) {
220 continue;
221 }
222 previous_versions.insert((account_id, currency), version - 1);
223 }
224 let expected_accounts_effected = latest_versions.len();
225 let mut query_builder: QueryBuilder<Postgres> = QueryBuilder::new(
226 r#"INSERT INTO sqlx_ledger_current_balances
227 (journal_id, account_id, currency, version)"#,
228 );
229 let mut any_new = false;
230 query_builder.push_values(
231 previous_versions.iter().filter(|(_, v)| **v == 0),
232 |mut builder, ((account_id, currency), version)| {
233 any_new = true;
234 builder.push_bind(journal_id);
235 builder.push_bind(**account_id);
236 builder.push_bind(currency.code());
237 builder.push_bind(version);
238 },
239 );
240 if any_new {
241 query_builder.build().execute(&mut **tx).await?;
242 }
243 let mut query_builder: QueryBuilder<Postgres> =
244 QueryBuilder::new(r#"UPDATE sqlx_ledger_current_balances SET version = CASE"#);
245 let mut bind_numbers = HashMap::new();
246 let mut next_bind_number = 1;
247 for ((account_id, currency), version) in latest_versions {
248 bind_numbers.insert((account_id, currency), next_bind_number);
249 next_bind_number += 3;
250 query_builder.push(" WHEN account_id = ");
251 query_builder.push_bind(*account_id);
252 query_builder.push(" AND currency = ");
253 query_builder.push_bind(currency.code());
254 query_builder.push(" THEN ");
255 query_builder.push_bind(version);
256 }
257 query_builder.push(" END WHERE journal_id = ");
258 query_builder.push_bind(journal_id);
259 query_builder.push(" AND (account_id, currency, version) IN");
260 query_builder.push_tuples(
261 previous_versions,
262 |mut builder, ((account_id, currency), version)| {
263 let n = bind_numbers.remove(&(account_id, currency)).unwrap();
264 builder.push(format!("${}, ${}", n, n + 1));
265 builder.push_bind(version);
266 },
267 );
268 let result = query_builder.build().execute(&mut **tx).await?;
269 if result.rows_affected() != (expected_accounts_effected as u64) {
270 return Err(SqlxLedgerError::OptimisticLockingError);
271 }
272
273 let mut query_builder: QueryBuilder<Postgres> = QueryBuilder::new(
274 r#"INSERT INTO sqlx_ledger_balances (
275 journal_id, account_id, entry_id, currency,
276 settled_dr_balance, settled_cr_balance, settled_entry_id, settled_modified_at,
277 pending_dr_balance, pending_cr_balance, pending_entry_id, pending_modified_at,
278 encumbered_dr_balance, encumbered_cr_balance, encumbered_entry_id, encumbered_modified_at,
279 version, modified_at, created_at)
280 "#,
281 );
282 query_builder.push_values(new_balances, |mut builder, b| {
283 builder.push_bind(b.journal_id);
284 builder.push_bind(b.account_id);
285 builder.push_bind(b.entry_id);
286 builder.push_bind(b.currency.code());
287 builder.push_bind(b.settled_dr_balance);
288 builder.push_bind(b.settled_cr_balance);
289 builder.push_bind(b.settled_entry_id);
290 builder.push_bind(b.settled_modified_at);
291 builder.push_bind(b.pending_dr_balance);
292 builder.push_bind(b.pending_cr_balance);
293 builder.push_bind(b.pending_entry_id);
294 builder.push_bind(b.pending_modified_at);
295 builder.push_bind(b.encumbered_dr_balance);
296 builder.push_bind(b.encumbered_cr_balance);
297 builder.push_bind(b.encumbered_entry_id);
298 builder.push_bind(b.encumbered_modified_at);
299 builder.push_bind(b.version);
300 builder.push_bind(b.modified_at);
301 builder.push_bind(b.created_at);
302 });
303 query_builder.build().execute(&mut **tx).await?;
304 Ok(())
305 }
306}