sqlx_ledger/entry/
repo.rs

1use chrono::{DateTime, Utc};
2use rust_decimal::Decimal;
3use sqlx::{PgPool, Postgres, QueryBuilder, Row, Transaction};
4use tracing::instrument;
5use uuid::Uuid;
6
7use std::{collections::HashMap, str::FromStr};
8
9use super::entity::*;
10use crate::{error::*, primitives::*};
11
12/// Repository for working with `Entry` (Debit/Credit) entities.
13#[derive(Debug, Clone)]
14pub struct Entries {
15    pool: PgPool,
16}
17
18#[derive(Debug)]
19pub(crate) struct StagedEntry {
20    pub(crate) account_id: AccountId,
21    pub(crate) entry_id: EntryId,
22    pub(crate) units: Decimal,
23    pub(crate) currency: Currency,
24    pub(crate) direction: DebitOrCredit,
25    pub(crate) layer: Layer,
26    pub(crate) created_at: DateTime<Utc>,
27}
28
29impl Entries {
30    pub fn new(pool: &PgPool) -> Self {
31        Self { pool: pool.clone() }
32    }
33
34    #[instrument(
35        level = "trace",
36        name = "sqlx_ledger.entries.create_all",
37        skip(self, tx)
38    )]
39    pub(crate) async fn create_all<'a>(
40        &self,
41        journal_id: JournalId,
42        transaction_id: TransactionId,
43        entries: Vec<NewEntry>,
44        tx: &mut Transaction<'a, Postgres>,
45    ) -> Result<Vec<StagedEntry>, SqlxLedgerError> {
46        let mut query_builder: QueryBuilder<Postgres> = QueryBuilder::new(
47            r#"WITH new_entries as (
48                 INSERT INTO sqlx_ledger_entries
49                  (id, transaction_id, journal_id, entry_type, layer,
50                   units, currency, direction, description, sequence, account_id)"#,
51        );
52        let mut partial_ret = HashMap::new();
53        let mut sequence = 1;
54        query_builder.push_values(
55            entries,
56            |mut builder,
57             NewEntry {
58                 account_id,
59                 entry_type,
60                 layer,
61                 units,
62                 currency,
63                 direction,
64                 description,
65             }: NewEntry| {
66                builder.push("gen_random_uuid()");
67                builder.push_bind(transaction_id);
68                builder.push_bind(journal_id);
69                builder.push_bind(entry_type);
70                builder.push_bind(layer);
71                builder.push_bind(units);
72                builder.push_bind(currency.code());
73                builder.push_bind(direction);
74                builder.push_bind(description);
75                builder.push_bind(sequence);
76                builder.push("(SELECT id FROM sqlx_ledger_accounts WHERE id = ");
77                builder.push_bind_unseparated(account_id);
78                builder.push_unseparated(")");
79                partial_ret.insert(sequence, (account_id, units, currency, layer, direction));
80                sequence += 1;
81            },
82        );
83        query_builder.push(
84            "RETURNING id, sequence, created_at ) SELECT * FROM new_entries ORDER BY sequence",
85        );
86        let query = query_builder.build();
87        let records = query.fetch_all(&mut **tx).await?;
88
89        let mut ret = Vec::new();
90        sequence = 1;
91        for r in records {
92            let entry_id: Uuid = r.get("id");
93            let created_at = r.get("created_at");
94            let (account_id, units, currency, layer, direction) =
95                partial_ret.remove(&sequence).expect("sequence not found");
96            ret.push(StagedEntry {
97                entry_id: entry_id.into(),
98                account_id,
99                units,
100                currency,
101                layer,
102                direction,
103                created_at,
104            });
105            sequence += 1;
106        }
107
108        Ok(ret)
109    }
110
111    pub async fn list_by_transaction_ids(
112        &self,
113        tx_ids: impl IntoIterator<Item = impl std::borrow::Borrow<TransactionId>>,
114    ) -> Result<HashMap<TransactionId, Vec<Entry>>, SqlxLedgerError> {
115        let tx_ids: Vec<Uuid> = tx_ids
116            .into_iter()
117            .map(|id| Uuid::from(id.borrow()))
118            .collect();
119        let records = sqlx::query!(
120            r#"SELECT id, version, transaction_id, account_id, journal_id, entry_type, layer as "layer: Layer", units, currency, direction as "direction: DebitOrCredit", sequence, description, created_at, modified_at
121            FROM sqlx_ledger_entries
122            WHERE transaction_id = ANY($1) ORDER BY transaction_id ASC, sequence ASC, version DESC"#,
123            &tx_ids[..]
124        ).fetch_all(&self.pool).await?;
125
126        let mut transactions: HashMap<TransactionId, Vec<Entry>> = HashMap::new();
127
128        let mut current_tx_id = TransactionId::new();
129        let mut last_sequence = 0;
130        for row in records {
131            let transaction_id = TransactionId::from(row.transaction_id);
132            // Skip old entry versions (description is mutable)
133            if last_sequence == row.sequence && transaction_id == current_tx_id {
134                continue;
135            }
136            current_tx_id = transaction_id;
137            last_sequence = row.sequence;
138
139            let entry = transactions.entry(transaction_id).or_default();
140
141            entry.push(Entry {
142                id: EntryId::from(row.id),
143                transaction_id,
144                version: row.version as u32,
145                account_id: AccountId::from(row.account_id),
146                journal_id: JournalId::from(row.journal_id),
147                entry_type: row.entry_type,
148                layer: row.layer,
149                units: row.units,
150                currency: Currency::from_str(row.currency.as_str())
151                    .expect("Couldn't convert currency"),
152                direction: row.direction,
153                sequence: row.sequence as u32,
154                description: row.description,
155                created_at: row.created_at,
156                modified_at: row.modified_at,
157            })
158        }
159
160        Ok(transactions)
161    }
162}