tokio_cqrs_es2_store/impls/sql/postgres_store/
query_store.rs

1use async_trait::async_trait;
2use log::{
3    debug,
4    trace,
5};
6use std::marker::PhantomData;
7
8use sqlx::postgres::PgPool;
9
10use cqrs_es2::{
11    Error,
12    EventContext,
13    IAggregate,
14    ICommand,
15    IEvent,
16    IQuery,
17    QueryContext,
18};
19
20use crate::repository::{
21    IEventDispatcher,
22    IQueryStore,
23};
24
25use super::super::postgres_constants::*;
26
27/// Async Postgres query store
28pub struct QueryStore<
29    C: ICommand,
30    E: IEvent,
31    A: IAggregate<C, E>,
32    Q: IQuery<C, E>,
33> {
34    pool: PgPool,
35    _phantom: PhantomData<(C, E, A, Q)>,
36}
37
38impl<
39        C: ICommand,
40        E: IEvent,
41        A: IAggregate<C, E>,
42        Q: IQuery<C, E>,
43    > QueryStore<C, E, A, Q>
44{
45    /// Constructor
46    pub fn new(pool: PgPool) -> Self {
47        let x = Self {
48            pool,
49            _phantom: PhantomData,
50        };
51
52        trace!("Created new async Postgres query store");
53
54        x
55    }
56}
57
58#[async_trait]
59impl<
60        C: ICommand,
61        E: IEvent,
62        A: IAggregate<C, E>,
63        Q: IQuery<C, E>,
64    > IQueryStore<C, E, A, Q> for QueryStore<C, E, A, Q>
65{
66    /// saves the updated query
67    async fn save_query(
68        &mut self,
69        context: QueryContext<C, E, Q>,
70    ) -> Result<(), Error> {
71        let aggregate_type = A::aggregate_type();
72        let query_type = Q::query_type();
73
74        let aggregate_id = context.aggregate_id;
75
76        debug!(
77            "storing a new query '{}' for aggregate id '{}'",
78            query_type, &aggregate_id
79        );
80
81        let sql = match context.version {
82            1 => INSERT_QUERY,
83            _ => UPDATE_QUERY,
84        };
85
86        let payload = match serde_json::to_value(context.payload) {
87            Ok(x) => x,
88            Err(e) => {
89                return Err(Error::new(
90                    format!(
91                        "unable to serialize the payload of query \
92                         '{}' with aggregate id '{}', error: {}",
93                        &query_type, &aggregate_id, e,
94                    )
95                    .as_str(),
96                ));
97            },
98        };
99
100        match sqlx::query(sql)
101            .bind(context.version)
102            .bind(&payload)
103            .bind(&aggregate_type)
104            .bind(&aggregate_id)
105            .bind(&query_type)
106            .execute(&self.pool)
107            .await
108        {
109            Ok(x) => {
110                if x.rows_affected() != 1 {
111                    return Err(Error::new(
112                        format!(
113                            "insert/update query failed for \
114                             aggregate id '{}'",
115                            &aggregate_id
116                        )
117                        .as_str(),
118                    ));
119                }
120            },
121            Err(e) => {
122                return Err(Error::new(
123                    format!(
124                        "unable to insert/update query for \
125                         aggregate id '{}' with error: {}",
126                        &aggregate_id, e
127                    )
128                    .as_str(),
129                ));
130            },
131        };
132
133        Ok(())
134    }
135
136    /// loads the most recent query
137    async fn load_query(
138        &mut self,
139        aggregate_id: &str,
140    ) -> Result<QueryContext<C, E, Q>, Error> {
141        let aggregate_type = A::aggregate_type();
142        let query_type = Q::query_type();
143
144        trace!(
145            "loading query '{}' for aggregate id '{}'",
146            query_type,
147            aggregate_id
148        );
149
150        let rows: Vec<(i64, serde_json::Value)> =
151            match sqlx::query_as(SELECT_QUERY)
152                .bind(&aggregate_type)
153                .bind(&aggregate_id)
154                .bind(&query_type)
155                .fetch_all(&self.pool)
156                .await
157            {
158                Ok(x) => x,
159                Err(e) => {
160                    return Err(Error::new(
161                        format!(
162                            "unable to load queries table for query \
163                             '{}' with aggregate id '{}', error: {}",
164                            &query_type, &aggregate_id, e,
165                        )
166                        .as_str(),
167                    ));
168                },
169            };
170
171        if rows.len() == 0 {
172            trace!(
173                "returning default query '{}' for aggregate id '{}'",
174                query_type,
175                aggregate_id
176            );
177
178            return Ok(QueryContext::new(
179                aggregate_id.to_string(),
180                0,
181                Default::default(),
182            ));
183        }
184
185        let row = rows[0].clone();
186
187        let payload = match serde_json::from_value(row.1) {
188            Ok(x) => x,
189            Err(e) => {
190                return Err(Error::new(
191                    format!(
192                        "bad payload found in queries table for \
193                         query '{}' with aggregate id '{}', error: \
194                         {}",
195                        &query_type, &aggregate_id, e,
196                    )
197                    .as_str(),
198                ));
199            },
200        };
201
202        Ok(QueryContext::new(
203            aggregate_id.to_string(),
204            row.0,
205            payload,
206        ))
207    }
208}
209
210#[async_trait]
211impl<
212        C: ICommand,
213        E: IEvent,
214        A: IAggregate<C, E>,
215        Q: IQuery<C, E>,
216    > IEventDispatcher<C, E> for QueryStore<C, E, A, Q>
217{
218    async fn dispatch(
219        &mut self,
220        aggregate_id: &str,
221        events: &Vec<EventContext<C, E>>,
222    ) -> Result<(), Error> {
223        self.dispatch_events(aggregate_id, events)
224            .await
225    }
226}