// transmit/repository_postgres.rs

1use std::error::Error;
2
3use async_trait::async_trait;
4use chrono::{DateTime, Utc};
5use log::{debug, info};
6use serde::{Deserialize, Serialize};
7use serde_json;
8use sqlx::postgres::PgPool;
9use uuid::Uuid;
10
11use crate::contract::Repository;
12use crate::model::*;
13
14pub struct RepositoryPostgres {
15    conn: PgPool,
16}
17
18impl RepositoryPostgres {
19    pub fn new(conn: PgPool) -> RepositoryPostgres {
20        info!("constructing new repository");
21
22        RepositoryPostgres { conn }
23    }
24
25    pub async fn migrate(&self) -> Result<(), sqlx::Error> {
26        info!("migrating");
27
28        sqlx::migrate!().run(&self.conn).await?;
29
30        Ok(())
31    }
32
33    pub async fn clear_all(&self) -> Result<(), Box<dyn Error>> {
34        info!("deleting all transmissions");
35
36        let _ = sqlx::query!("delete from transmission;")
37            .execute(&self.conn)
38            .await?;
39
40        Ok(())
41    }
42}
43
44#[async_trait]
45impl Repository for RepositoryPostgres {
46    async fn store_transmission(
47        &self,
48        schedule: &Transmission,
49    ) -> Result<(), Box<dyn Error + Send + Sync>> {
50        info!("storing transmission");
51
52        let schedule_sql = TransmissionSql::from(schedule);
53
54        let _ = sqlx::query!(
55            "
56INSERT INTO transmission (
57    id, message, next, schedule, transmission_count, inserted_at, is_locked
58) VALUES (
59    $1, $2, $3, $4, $5, now(), false
60);
61        ",
62            schedule_sql.id,
63            schedule_sql.message,
64            schedule_sql.next,
65            schedule_sql.schedule,
66            schedule_sql.transmission_count as i32,
67        )
68        .execute(&self.conn)
69        .await?;
70
71        Ok(())
72    }
73
74    async fn poll_transmissions(
75        &self,
76        before: DateTime<Utc>,
77        batch_size: u32,
78    ) -> Result<Vec<Transmission>, Box<dyn Error>> {
79        debug!("polling batch");
80        let message_schedules_sql = sqlx::query_as!(
81            TransmissionSql,
82            "
83WITH locked_schedules AS (
84    UPDATE transmission
85    SET is_locked = true
86    WHERE id IN (
87        SELECT id
88        FROM (
89            SELECT id, MAX(inserted_at) AS latest_inserted_at
90            FROM transmission
91            GROUP BY id
92        ) latest_entries
93        WHERE inserted_at = latest_inserted_at
94    )
95    AND next IS NOT NULL
96    AND next < $1
97    AND is_locked = false
98    RETURNING id, message, next, schedule, transmission_count
99)
100SELECT * FROM locked_schedules
101LIMIT $2;
102        ",
103            before,
104            batch_size as i64,
105        )
106        .fetch_all(&self.conn)
107        .await?;
108
109        let message_schedules: Vec<Transmission> = message_schedules_sql
110            .iter()
111            .map(|message_schedule_sql| Transmission::from((*message_schedule_sql).clone()))
112            .collect();
113
114        Ok(message_schedules)
115    }
116
117    async fn save(&self, schedule: &Transmission) -> Result<(), Box<dyn Error + Send + Sync>> {
118        self.store_transmission(schedule).await
119    }
120
121    async fn reschedule(
122        &self,
123        transmission_id: &uuid::Uuid,
124    ) -> Result<(), Box<dyn Error + Send + Sync>> {
125        let _ = sqlx::query!(
126            "
127UPDATE transmission
128SET is_locked = true
129WHERE id = $1
130  AND is_locked = false
131  AND (id, inserted_at) = (
132    SELECT id, MAX(inserted_at)
133    FROM transmission
134    WHERE id = $2
135      AND is_locked = false
136    GROUP BY id
137  );
138        ",
139            transmission_id,
140            transmission_id,
141        )
142        .execute(&self.conn)
143        .await?;
144
145        Ok(())
146    }
147}
148
149#[derive(Debug, Serialize, Deserialize, Clone, sqlx::FromRow)]
150struct TransmissionSql {
151    id: Uuid,
152    message: String,
153    schedule: String,
154    next: Option<DateTime<Utc>>,
155    transmission_count: i32,
156}
157
158impl From<&Transmission> for TransmissionSql {
159    fn from(schedule: &Transmission) -> TransmissionSql {
160        TransmissionSql {
161            id: schedule.id,
162            message: serde_json::to_string(&schedule.message).expect("Failed to serialize message"),
163            schedule: serde_json::to_string(&schedule.schedule)
164                .expect("Failed to serialize schedule"),
165            transmission_count: schedule.transmission_count as i32,
166            next: schedule.next,
167        }
168    }
169}
170
171impl From<TransmissionSql> for Transmission {
172    fn from(schedule_sql: TransmissionSql) -> Transmission {
173        Transmission {
174            id: schedule_sql.id,
175            schedule: match serde_json::from_str(&schedule_sql.schedule) {
176                Ok(schedule) => schedule,
177                Err(err) => panic!("Failed to deserialize schedule: {:?}", err),
178            },
179            message: match serde_json::from_str(&schedule_sql.message) {
180                Ok(message) => message,
181                Err(err) => panic!("Failed to deserialize message: {:?}", err),
182            },
183            transmission_count: schedule_sql.transmission_count as u32,
184            next: schedule_sql.next,
185        }
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    use crate::postgres;

    /// Connects to the local test database, migrates it and empties the table.
    async fn fresh_repository() -> RepositoryPostgres {
        let config = postgres::Config {
            name: "transmit".into(),
            host: "localhost".into(),
            port: 5432,
            user: "postgres".into(),
            password: "postgres".into(),
            ssl: false,
        };
        let pool = postgres::connect_to_test_database(config)
            .await
            .expect("connecting to postgres failed. Is postgres running on port 5432?");

        let repository = RepositoryPostgres::new(pool);
        repository
            .migrate()
            .await
            .expect("could not run migrations");
        repository.clear_all().await.expect("could not clear table");
        repository
    }

    /// A delayed NATS transmission due at `at` carrying `payload`.
    fn delayed_nats(at: DateTime<Utc>, payload: &str) -> Transmission {
        Transmission::new(
            Schedule::Delayed(Delayed::new(at)),
            Message::NatsEvent(NatsEvent::new("ARBITRARY.subject".into(), payload.into())),
        )
    }

    #[tokio::test]
    async fn test_store() {
        let repository = fresh_repository().await;

        let now = Utc::now();
        let past = now - chrono::Duration::milliseconds(100);
        let future = now + chrono::Duration::milliseconds(100);

        // Nothing stored yet, so nothing can be due.
        let before_store = repository
            .poll_transmissions(now, 100)
            .await
            .expect("poll batch should be ok");
        assert_eq!(before_store.len(), 0);

        let schedules = vec![
            delayed_nats(past, "first payload"),
            delayed_nats(future, "second payload"),
        ];
        // Only the transmission scheduled in the past is due at `now`.
        let expected_due = vec![Transmission {
            id: schedules[0].id,
            schedule: Schedule::Delayed(Delayed::new(past)),
            message: schedules[0].message.clone(),
            next: Some(past),
            transmission_count: 0,
        }];

        for schedule in &schedules {
            repository
                .store_transmission(schedule)
                .await
                .expect("store schedule should be ok");
        }

        let due = repository
            .poll_transmissions(now, 100)
            .await
            .expect("poll batch should be ok");
        assert_eq!(due, expected_due);

        // Record every transmission as sent; sent ones have no next due time.
        for schedule in &schedules {
            let sent = schedule
                .transmitted()
                .unwrap_or_else(|err| panic!("failed to transition to transmitted state: {err}"));
            assert!(sent.next.is_none());
            if let Err(err) = repository.save(&sent).await {
                panic!("failed to save: {err}");
            }
        }

        let after_sending = repository
            .poll_transmissions(now, 100)
            .await
            .expect("last poll batch should be ok");
        assert_eq!(after_sending, vec![]);
    }
}
283}