1use chrono::{DateTime, Utc};
2use rust_decimal::Decimal;
3use sqlx::{PgPool, Postgres, QueryBuilder, Row, Transaction};
4use tracing::instrument;
5use uuid::Uuid;
6
7use std::{collections::HashMap, str::FromStr};
8
9use super::entity::*;
10use crate::{error::*, primitives::*};
11
12#[derive(Debug, Clone)]
14pub struct Entries {
15 pool: PgPool,
16}
17
18#[derive(Debug)]
19pub(crate) struct StagedEntry {
20 pub(crate) account_id: AccountId,
21 pub(crate) entry_id: EntryId,
22 pub(crate) units: Decimal,
23 pub(crate) currency: Currency,
24 pub(crate) direction: DebitOrCredit,
25 pub(crate) layer: Layer,
26 pub(crate) created_at: DateTime<Utc>,
27}
28
29impl Entries {
30 pub fn new(pool: &PgPool) -> Self {
31 Self { pool: pool.clone() }
32 }
33
34 #[instrument(
35 level = "trace",
36 name = "sqlx_ledger.entries.create_all",
37 skip(self, tx)
38 )]
39 pub(crate) async fn create_all<'a>(
40 &self,
41 journal_id: JournalId,
42 transaction_id: TransactionId,
43 entries: Vec<NewEntry>,
44 tx: &mut Transaction<'a, Postgres>,
45 ) -> Result<Vec<StagedEntry>, SqlxLedgerError> {
46 let mut query_builder: QueryBuilder<Postgres> = QueryBuilder::new(
47 r#"WITH new_entries as (
48 INSERT INTO sqlx_ledger_entries
49 (id, transaction_id, journal_id, entry_type, layer,
50 units, currency, direction, description, sequence, account_id)"#,
51 );
52 let mut partial_ret = HashMap::new();
53 let mut sequence = 1;
54 query_builder.push_values(
55 entries,
56 |mut builder,
57 NewEntry {
58 account_id,
59 entry_type,
60 layer,
61 units,
62 currency,
63 direction,
64 description,
65 }: NewEntry| {
66 builder.push("gen_random_uuid()");
67 builder.push_bind(transaction_id);
68 builder.push_bind(journal_id);
69 builder.push_bind(entry_type);
70 builder.push_bind(layer);
71 builder.push_bind(units);
72 builder.push_bind(currency.code());
73 builder.push_bind(direction);
74 builder.push_bind(description);
75 builder.push_bind(sequence);
76 builder.push("(SELECT id FROM sqlx_ledger_accounts WHERE id = ");
77 builder.push_bind_unseparated(account_id);
78 builder.push_unseparated(")");
79 partial_ret.insert(sequence, (account_id, units, currency, layer, direction));
80 sequence += 1;
81 },
82 );
83 query_builder.push(
84 "RETURNING id, sequence, created_at ) SELECT * FROM new_entries ORDER BY sequence",
85 );
86 let query = query_builder.build();
87 let records = query.fetch_all(&mut **tx).await?;
88
89 let mut ret = Vec::new();
90 sequence = 1;
91 for r in records {
92 let entry_id: Uuid = r.get("id");
93 let created_at = r.get("created_at");
94 let (account_id, units, currency, layer, direction) =
95 partial_ret.remove(&sequence).expect("sequence not found");
96 ret.push(StagedEntry {
97 entry_id: entry_id.into(),
98 account_id,
99 units,
100 currency,
101 layer,
102 direction,
103 created_at,
104 });
105 sequence += 1;
106 }
107
108 Ok(ret)
109 }
110
111 pub async fn list_by_transaction_ids(
112 &self,
113 tx_ids: impl IntoIterator<Item = impl std::borrow::Borrow<TransactionId>>,
114 ) -> Result<HashMap<TransactionId, Vec<Entry>>, SqlxLedgerError> {
115 let tx_ids: Vec<Uuid> = tx_ids
116 .into_iter()
117 .map(|id| Uuid::from(id.borrow()))
118 .collect();
119 let records = sqlx::query!(
120 r#"SELECT id, version, transaction_id, account_id, journal_id, entry_type, layer as "layer: Layer", units, currency, direction as "direction: DebitOrCredit", sequence, description, created_at, modified_at
121 FROM sqlx_ledger_entries
122 WHERE transaction_id = ANY($1) ORDER BY transaction_id ASC, sequence ASC, version DESC"#,
123 &tx_ids[..]
124 ).fetch_all(&self.pool).await?;
125
126 let mut transactions: HashMap<TransactionId, Vec<Entry>> = HashMap::new();
127
128 let mut current_tx_id = TransactionId::new();
129 let mut last_sequence = 0;
130 for row in records {
131 let transaction_id = TransactionId::from(row.transaction_id);
132 if last_sequence == row.sequence && transaction_id == current_tx_id {
134 continue;
135 }
136 current_tx_id = transaction_id;
137 last_sequence = row.sequence;
138
139 let entry = transactions.entry(transaction_id).or_default();
140
141 entry.push(Entry {
142 id: EntryId::from(row.id),
143 transaction_id,
144 version: row.version as u32,
145 account_id: AccountId::from(row.account_id),
146 journal_id: JournalId::from(row.journal_id),
147 entry_type: row.entry_type,
148 layer: row.layer,
149 units: row.units,
150 currency: Currency::from_str(row.currency.as_str())
151 .expect("Couldn't convert currency"),
152 direction: row.direction,
153 sequence: row.sequence as u32,
154 description: row.description,
155 created_at: row.created_at,
156 modified_at: row.modified_at,
157 })
158 }
159
160 Ok(transactions)
161 }
162}