1use std::error::Error;
2
3use async_trait::async_trait;
4use chrono::{DateTime, Utc};
5use log::{debug, info};
6use serde::{Deserialize, Serialize};
7use serde_json;
8use sqlx::postgres::PgPool;
9use uuid::Uuid;
10
11use crate::contract::Repository;
12use crate::model::*;
13
14pub struct RepositoryPostgres {
15 conn: PgPool,
16}
17
18impl RepositoryPostgres {
19 pub fn new(conn: PgPool) -> RepositoryPostgres {
20 info!("constructing new repository");
21
22 RepositoryPostgres { conn }
23 }
24
25 pub async fn migrate(&self) -> Result<(), sqlx::Error> {
26 info!("migrating");
27
28 sqlx::migrate!().run(&self.conn).await?;
29
30 Ok(())
31 }
32
33 pub async fn clear_all(&self) -> Result<(), Box<dyn Error>> {
34 info!("deleting all transmissions");
35
36 let _ = sqlx::query!("delete from transmission;")
37 .execute(&self.conn)
38 .await?;
39
40 Ok(())
41 }
42}
43
44#[async_trait]
45impl Repository for RepositoryPostgres {
46 async fn store_transmission(
47 &self,
48 schedule: &Transmission,
49 ) -> Result<(), Box<dyn Error + Send + Sync>> {
50 info!("storing transmission");
51
52 let schedule_sql = TransmissionSql::from(schedule);
53
54 let _ = sqlx::query!(
55 "
56INSERT INTO transmission (
57 id, message, next, schedule, transmission_count, inserted_at, is_locked
58) VALUES (
59 $1, $2, $3, $4, $5, now(), false
60);
61 ",
62 schedule_sql.id,
63 schedule_sql.message,
64 schedule_sql.next,
65 schedule_sql.schedule,
66 schedule_sql.transmission_count as i32,
67 )
68 .execute(&self.conn)
69 .await?;
70
71 Ok(())
72 }
73
74 async fn poll_transmissions(
75 &self,
76 before: DateTime<Utc>,
77 batch_size: u32,
78 ) -> Result<Vec<Transmission>, Box<dyn Error>> {
79 debug!("polling batch");
80 let message_schedules_sql = sqlx::query_as!(
81 TransmissionSql,
82 "
83WITH locked_schedules AS (
84 UPDATE transmission
85 SET is_locked = true
86 WHERE id IN (
87 SELECT id
88 FROM (
89 SELECT id, MAX(inserted_at) AS latest_inserted_at
90 FROM transmission
91 GROUP BY id
92 ) latest_entries
93 WHERE inserted_at = latest_inserted_at
94 )
95 AND next IS NOT NULL
96 AND next < $1
97 AND is_locked = false
98 RETURNING id, message, next, schedule, transmission_count
99)
100SELECT * FROM locked_schedules
101LIMIT $2;
102 ",
103 before,
104 batch_size as i64,
105 )
106 .fetch_all(&self.conn)
107 .await?;
108
109 let message_schedules: Vec<Transmission> = message_schedules_sql
110 .iter()
111 .map(|message_schedule_sql| Transmission::from((*message_schedule_sql).clone()))
112 .collect();
113
114 Ok(message_schedules)
115 }
116
117 async fn save(&self, schedule: &Transmission) -> Result<(), Box<dyn Error + Send + Sync>> {
118 self.store_transmission(schedule).await
119 }
120
121 async fn reschedule(
122 &self,
123 transmission_id: &uuid::Uuid,
124 ) -> Result<(), Box<dyn Error + Send + Sync>> {
125 let _ = sqlx::query!(
126 "
127UPDATE transmission
128SET is_locked = true
129WHERE id = $1
130 AND is_locked = false
131 AND (id, inserted_at) = (
132 SELECT id, MAX(inserted_at)
133 FROM transmission
134 WHERE id = $2
135 AND is_locked = false
136 GROUP BY id
137 );
138 ",
139 transmission_id,
140 transmission_id,
141 )
142 .execute(&self.conn)
143 .await?;
144
145 Ok(())
146 }
147}
148
149#[derive(Debug, Serialize, Deserialize, Clone, sqlx::FromRow)]
150struct TransmissionSql {
151 id: Uuid,
152 message: String,
153 schedule: String,
154 next: Option<DateTime<Utc>>,
155 transmission_count: i32,
156}
157
158impl From<&Transmission> for TransmissionSql {
159 fn from(schedule: &Transmission) -> TransmissionSql {
160 TransmissionSql {
161 id: schedule.id,
162 message: serde_json::to_string(&schedule.message).expect("Failed to serialize message"),
163 schedule: serde_json::to_string(&schedule.schedule)
164 .expect("Failed to serialize schedule"),
165 transmission_count: schedule.transmission_count as i32,
166 next: schedule.next,
167 }
168 }
169}
170
171impl From<TransmissionSql> for Transmission {
172 fn from(schedule_sql: TransmissionSql) -> Transmission {
173 Transmission {
174 id: schedule_sql.id,
175 schedule: match serde_json::from_str(&schedule_sql.schedule) {
176 Ok(schedule) => schedule,
177 Err(err) => panic!("Failed to deserialize schedule: {:?}", err),
178 },
179 message: match serde_json::from_str(&schedule_sql.message) {
180 Ok(message) => message,
181 Err(err) => panic!("Failed to deserialize message: {:?}", err),
182 },
183 transmission_count: schedule_sql.transmission_count as u32,
184 next: schedule_sql.next,
185 }
186 }
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    use crate::postgres;

    /// End-to-end round trip against a local Postgres (`localhost:5432`, user and password
    /// `postgres`, database `transmit`).
    ///
    /// The test stores one due and one future transmission and checks that only the due one
    /// is polled. It then saves transmitted versions of both and checks that nothing is due
    /// any more.
    #[tokio::test]
    async fn test_store() {
        let connection = postgres::connect_to_test_database(postgres::Config {
            name: "transmit".into(),
            host: "localhost".into(),
            port: 5432,
            user: "postgres".into(),
            password: "postgres".into(),
            ssl: false,
        })
        .await
        .expect("connecting to postgres failed. Is postgres running on port 5432?");

        let repository = RepositoryPostgres::new(connection);
        repository
            .migrate()
            .await
            .expect("could not run migrations");
        repository.clear_all().await.expect("could not clear table");

        let now = Utc::now();
        let past = now - chrono::Duration::milliseconds(100);
        let future = now + chrono::Duration::milliseconds(100);

        // An empty table has nothing due.
        let initially_due = repository
            .poll_transmissions(now, 100)
            .await
            .expect("poll batch should be ok");
        assert_eq!(initially_due.len(), 0);

        let due_transmission = Transmission::new(
            Schedule::Delayed(Delayed::new(past)),
            Message::NatsEvent(NatsEvent::new(
                "ARBITRARY.subject".into(),
                "first payload".into(),
            )),
        );
        let later_transmission = Transmission::new(
            Schedule::Delayed(Delayed::new(future)),
            Message::NatsEvent(NatsEvent::new(
                "ARBITRARY.subject".into(),
                "second payload".into(),
            )),
        );
        let schedules = vec![due_transmission, later_transmission];

        // Only the transmission scheduled in the past should come back from a poll at `now`.
        let expected_due = vec![Transmission {
            id: schedules[0].id,
            schedule: Schedule::Delayed(Delayed::new(past)),
            message: schedules[0].message.clone(),
            next: Some(past),
            transmission_count: 0,
        }];

        for transmission in &schedules {
            repository
                .store_transmission(transmission)
                .await
                .expect("store schedule should be ok");
        }

        let polled_due = repository
            .poll_transmissions(now, 100)
            .await
            .expect("poll batch should be ok");
        assert_eq!(polled_due, expected_due);

        // Move every transmission to its transmitted state and persist that version.
        for transmission in &schedules {
            let transmitted = transmission.transmitted().unwrap_or_else(|err| {
                panic!("failed to transition to transmitted state: {err}")
            });
            assert!(transmitted.next.is_none());
            if let Err(err) = repository.save(&transmitted).await {
                panic!("failed to save: {err}");
            }
        }

        // Transmitted versions have no `next`, so nothing is due any more.
        let after_transmission = repository
            .poll_transmissions(now, 100)
            .await
            .expect("last poll batch should be ok");
        assert_eq!(after_transmission, vec![]);
    }
}