1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
use chrono::{DateTime, Utc};
use rust_decimal::Decimal;
use sqlx::{PgPool, Postgres, QueryBuilder, Row, Transaction};
use tracing::instrument;
use uuid::Uuid;

use std::{collections::HashMap, str::FromStr};

use super::entity::*;
use crate::{error::*, primitives::*};

/// Repository for working with `Entry` (Debit/Credit) entities.
#[derive(Debug, Clone)]
pub struct Entries {
    pool: PgPool,
}

#[derive(Debug)]
pub(crate) struct StagedEntry {
    pub(crate) account_id: AccountId,
    pub(crate) entry_id: EntryId,
    pub(crate) units: Decimal,
    pub(crate) currency: Currency,
    pub(crate) direction: DebitOrCredit,
    pub(crate) layer: Layer,
    pub(crate) created_at: DateTime<Utc>,
}

impl Entries {
    pub fn new(pool: &PgPool) -> Self {
        Self { pool: pool.clone() }
    }

    #[instrument(
        level = "trace",
        name = "sqlx_ledger.entries.create_all",
        skip(self, tx)
    )]
    pub(crate) async fn create_all<'a>(
        &self,
        journal_id: JournalId,
        transaction_id: TransactionId,
        entries: Vec<NewEntry>,
        tx: &mut Transaction<'a, Postgres>,
    ) -> Result<Vec<StagedEntry>, SqlxLedgerError> {
        let mut query_builder: QueryBuilder<Postgres> = QueryBuilder::new(
            r#"WITH new_entries as (
                 INSERT INTO sqlx_ledger_entries
                  (id, transaction_id, journal_id, entry_type, layer,
                   units, currency, direction, description, sequence, account_id)"#,
        );
        let mut partial_ret = HashMap::new();
        let mut sequence = 1;
        query_builder.push_values(
            entries,
            |mut builder,
             NewEntry {
                 account_id,
                 entry_type,
                 layer,
                 units,
                 currency,
                 direction,
                 description,
             }: NewEntry| {
                builder.push("gen_random_uuid()");
                builder.push_bind(transaction_id);
                builder.push_bind(journal_id);
                builder.push_bind(entry_type);
                builder.push_bind(layer);
                builder.push_bind(units);
                builder.push_bind(currency.code());
                builder.push_bind(direction);
                builder.push_bind(description);
                builder.push_bind(sequence);
                builder.push("(SELECT id FROM sqlx_ledger_accounts WHERE id = ");
                builder.push_bind_unseparated(account_id);
                builder.push_unseparated(")");
                partial_ret.insert(sequence, (account_id, units, currency, layer, direction));
                sequence += 1;
            },
        );
        query_builder.push(
            "RETURNING id, sequence, created_at ) SELECT * FROM new_entries ORDER BY sequence",
        );
        let query = query_builder.build();
        let records = query.fetch_all(&mut **tx).await?;

        let mut ret = Vec::new();
        sequence = 1;
        for r in records {
            let entry_id: Uuid = r.get("id");
            let created_at = r.get("created_at");
            let (account_id, units, currency, layer, direction) =
                partial_ret.remove(&sequence).expect("sequence not found");
            ret.push(StagedEntry {
                entry_id: entry_id.into(),
                account_id,
                units,
                currency,
                layer,
                direction,
                created_at,
            });
            sequence += 1;
        }

        Ok(ret)
    }

    pub async fn list_by_transaction_ids(
        &self,
        tx_ids: impl IntoIterator<Item = impl std::borrow::Borrow<TransactionId>>,
    ) -> Result<HashMap<TransactionId, Vec<Entry>>, SqlxLedgerError> {
        let tx_ids: Vec<Uuid> = tx_ids
            .into_iter()
            .map(|id| Uuid::from(id.borrow()))
            .collect();
        let records = sqlx::query!(
            r#"SELECT id, version, transaction_id, account_id, journal_id, entry_type, layer as "layer: Layer", units, currency, direction as "direction: DebitOrCredit", sequence, description, created_at, modified_at
            FROM sqlx_ledger_entries
            WHERE transaction_id = ANY($1) ORDER BY transaction_id ASC, sequence ASC, version DESC"#,
            &tx_ids[..]
        ).fetch_all(&self.pool).await?;

        let mut transactions: HashMap<TransactionId, Vec<Entry>> = HashMap::new();

        let mut current_tx_id = TransactionId::new();
        let mut last_sequence = 0;
        for row in records {
            let transaction_id = TransactionId::from(row.transaction_id);
            // Skip old entry versions (description is mutable)
            if last_sequence == row.sequence && transaction_id == current_tx_id {
                continue;
            }
            current_tx_id = transaction_id;
            last_sequence = row.sequence;

            let entry = transactions.entry(transaction_id).or_default();

            entry.push(Entry {
                id: EntryId::from(row.id),
                transaction_id,
                version: row.version as u32,
                account_id: AccountId::from(row.account_id),
                journal_id: JournalId::from(row.journal_id),
                entry_type: row.entry_type,
                layer: row.layer,
                units: row.units,
                currency: Currency::from_str(row.currency.as_str())
                    .expect("Couldn't convert currency"),
                direction: row.direction,
                sequence: row.sequence as u32,
                description: row.description,
                created_at: row.created_at,
                modified_at: row.modified_at,
            })
        }

        Ok(transactions)
    }
}