Skip to main content

sayiir_postgres/
signal_store.rs

//! [`SignalStore`] implementation for Postgres.
//!
//! Overrides the three default composite methods (`check_and_cancel`,
//! `check_and_pause` and `unpause`) with single-transaction implementations
//! for true ACID atomicity.
5
6use sayiir_core::codec::{self, Decoder, Encoder};
7use sayiir_core::snapshot::{PauseRequest, SignalKind, SignalRequest, WorkflowSnapshot};
8use sayiir_persistence::validation::validate_signal_allowed;
9use sayiir_persistence::{BackendError, SignalStore};
10use sqlx::Row;
11
12use crate::backend::PostgresBackend;
13use crate::error::PgError;
14
15impl<C> SignalStore for PostgresBackend<C>
16where
17    C: Encoder
18        + Decoder
19        + codec::sealed::EncodeValue<WorkflowSnapshot>
20        + codec::sealed::DecodeValue<WorkflowSnapshot>,
21{
22    #[tracing::instrument(
23        name = "db.store_signal",
24        skip(self, request),
25        fields(db.system = "postgresql", kind = %kind.as_ref()),
26        err(level = tracing::Level::ERROR),
27    )]
28    async fn store_signal(
29        &self,
30        instance_id: &str,
31        kind: SignalKind,
32        request: SignalRequest,
33    ) -> Result<(), BackendError> {
34        tracing::debug!("storing signal");
35        // Validate workflow state first
36        let row =
37            sqlx::query("SELECT status FROM sayiir_workflow_snapshots WHERE instance_id = $1")
38                .bind(instance_id)
39                .fetch_optional(&self.pool)
40                .await
41                .map_err(PgError)?
42                .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
43
44        let status: String = row.get("status");
45        validate_signal_allowed(&status, kind)?;
46
47        sqlx::query(
48            "INSERT INTO sayiir_workflow_signals (instance_id, kind, reason, requested_by)
49             VALUES ($1, $2, $3, $4)
50             ON CONFLICT (instance_id, kind) DO UPDATE SET
51                reason = $3, requested_by = $4, created_at = now()",
52        )
53        .bind(instance_id)
54        .bind(kind.as_ref())
55        .bind(&request.reason)
56        .bind(&request.requested_by)
57        .execute(&self.pool)
58        .await
59        .map_err(PgError)?;
60
61        Ok(())
62    }
63
64    #[tracing::instrument(
65        name = "db.get_signal",
66        skip(self),
67        fields(db.system = "postgresql", kind = %kind.as_ref()),
68        err(level = tracing::Level::ERROR),
69    )]
70    async fn get_signal(
71        &self,
72        instance_id: &str,
73        kind: SignalKind,
74    ) -> Result<Option<SignalRequest>, BackendError> {
75        tracing::debug!("getting signal");
76        let row = sqlx::query(
77            "SELECT reason, requested_by, created_at
78             FROM sayiir_workflow_signals
79             WHERE instance_id = $1 AND kind = $2",
80        )
81        .bind(instance_id)
82        .bind(kind.as_ref())
83        .fetch_optional(&self.pool)
84        .await
85        .map_err(PgError)?;
86
87        Ok(row.map(|r| SignalRequest {
88            reason: r.get("reason"),
89            requested_by: r.get("requested_by"),
90            requested_at: r.get("created_at"),
91        }))
92    }
93
94    #[tracing::instrument(
95        name = "db.clear_signal",
96        skip(self),
97        fields(db.system = "postgresql", kind = %kind.as_ref()),
98        err(level = tracing::Level::ERROR),
99    )]
100    async fn clear_signal(&self, instance_id: &str, kind: SignalKind) -> Result<(), BackendError> {
101        tracing::debug!("clearing signal");
102        sqlx::query("DELETE FROM sayiir_workflow_signals WHERE instance_id = $1 AND kind = $2")
103            .bind(instance_id)
104            .bind(kind.as_ref())
105            .execute(&self.pool)
106            .await
107            .map_err(PgError)?;
108        Ok(())
109    }
110
111    #[tracing::instrument(
112        name = "db.send_event",
113        skip(self, payload),
114        fields(db.system = "postgresql"),
115        err(level = tracing::Level::ERROR),
116    )]
117    async fn send_event(
118        &self,
119        instance_id: &str,
120        signal_name: &str,
121        payload: bytes::Bytes,
122    ) -> Result<(), BackendError> {
123        tracing::debug!("buffering external event");
124        sqlx::query(
125            "INSERT INTO sayiir_workflow_events (instance_id, signal_name, payload)
126             VALUES ($1, $2, $3)",
127        )
128        .bind(instance_id)
129        .bind(signal_name)
130        .bind(payload.as_ref())
131        .execute(&self.pool)
132        .await
133        .map_err(PgError)?;
134        Ok(())
135    }
136
137    #[tracing::instrument(
138        name = "db.consume_event",
139        skip(self),
140        fields(db.system = "postgresql"),
141        err(level = tracing::Level::ERROR),
142    )]
143    async fn consume_event(
144        &self,
145        instance_id: &str,
146        signal_name: &str,
147    ) -> Result<Option<bytes::Bytes>, BackendError> {
148        tracing::debug!("consuming oldest buffered event");
149        // Atomically delete-and-return the oldest event for this (instance, signal).
150        let row = sqlx::query(
151            "DELETE FROM sayiir_workflow_events
152             WHERE id = (
153                 SELECT id FROM sayiir_workflow_events
154                 WHERE instance_id = $1 AND signal_name = $2
155                 ORDER BY id ASC
156                 LIMIT 1
157                 FOR UPDATE SKIP LOCKED
158             )
159             RETURNING payload",
160        )
161        .bind(instance_id)
162        .bind(signal_name)
163        .fetch_optional(&self.pool)
164        .await
165        .map_err(PgError)?;
166
167        Ok(row.map(|r| {
168            let raw: Vec<u8> = r.get("payload");
169            bytes::Bytes::from(raw)
170        }))
171    }
172
173    // --- Overridden composites: single ACID transactions ---
174
175    #[tracing::instrument(
176        name = "db.check_and_cancel",
177        skip(self),
178        fields(db.system = "postgresql"),
179        err(level = tracing::Level::ERROR),
180    )]
181    async fn check_and_cancel(
182        &self,
183        instance_id: &str,
184        interrupted_at_task: Option<&str>,
185    ) -> Result<bool, BackendError> {
186        tracing::debug!("checking for cancel signal");
187        let mut tx = self.pool.begin().await.map_err(PgError)?;
188
189        // Check for cancel signal (lock the row)
190        let signal_row = sqlx::query(
191            "SELECT reason, requested_by
192             FROM sayiir_workflow_signals
193             WHERE instance_id = $1 AND kind = $2
194             FOR UPDATE",
195        )
196        .bind(instance_id)
197        .bind(SignalKind::Cancel.as_ref())
198        .fetch_optional(&mut *tx)
199        .await
200        .map_err(PgError)?;
201
202        let Some(signal_row) = signal_row else {
203            tx.rollback().await.map_err(PgError)?;
204            return Ok(false);
205        };
206
207        // Lock and load the snapshot
208        let snap_row = sqlx::query(
209            "SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
210        )
211        .bind(instance_id)
212        .fetch_one(&mut *tx)
213        .await
214        .map_err(PgError)?;
215
216        let raw: &[u8] = snap_row.get("data");
217        let mut snapshot = self.decode(raw)?;
218
219        if !snapshot.state.is_in_progress() {
220            tx.rollback().await.map_err(PgError)?;
221            return Ok(false);
222        }
223
224        let reason: Option<String> = signal_row.get("reason");
225        let requested_by: Option<String> = signal_row.get("requested_by");
226        snapshot.mark_cancelled(reason, requested_by, interrupted_at_task.map(String::from));
227
228        let data = self.encode(&snapshot)?;
229        let status = snapshot.state.as_ref();
230        let error = snapshot.error_message().map(ToString::to_string);
231        let pos_kind = snapshot.position_kind();
232        let wake_at = snapshot.delay_wake_at();
233
234        sqlx::query(
235            "UPDATE sayiir_workflow_snapshots
236             SET data = $1, status = $2, error = $3,
237                 position_kind = $4, delay_wake_at = $5,
238                 completed_at = now(), updated_at = now()
239             WHERE instance_id = $6",
240        )
241        .bind(&data)
242        .bind(status)
243        .bind(&error)
244        .bind(pos_kind)
245        .bind(wake_at)
246        .bind(instance_id)
247        .execute(&mut *tx)
248        .await
249        .map_err(PgError)?;
250
251        // Mark any still-active tasks as cancelled
252        sqlx::query(
253            "UPDATE sayiir_workflow_tasks SET status = 'cancelled', completed_at = now()
254             WHERE instance_id = $1 AND status = 'active'",
255        )
256        .bind(instance_id)
257        .execute(&mut *tx)
258        .await
259        .map_err(PgError)?;
260
261        // Clear the signal
262        sqlx::query("DELETE FROM sayiir_workflow_signals WHERE instance_id = $1 AND kind = $2")
263            .bind(instance_id)
264            .bind(SignalKind::Cancel.as_ref())
265            .execute(&mut *tx)
266            .await
267            .map_err(PgError)?;
268
269        tx.commit().await.map_err(PgError)?;
270        tracing::info!(instance_id, "workflow cancelled");
271        Ok(true)
272    }
273
274    #[tracing::instrument(
275        name = "db.check_and_pause",
276        skip(self),
277        fields(db.system = "postgresql"),
278        err(level = tracing::Level::ERROR),
279    )]
280    async fn check_and_pause(&self, instance_id: &str) -> Result<bool, BackendError> {
281        tracing::debug!("checking for pause signal");
282        let mut tx = self.pool.begin().await.map_err(PgError)?;
283
284        // Check for pause signal (lock the row)
285        let signal_row = sqlx::query(
286            "SELECT reason, requested_by
287             FROM sayiir_workflow_signals
288             WHERE instance_id = $1 AND kind = $2
289             FOR UPDATE",
290        )
291        .bind(instance_id)
292        .bind(SignalKind::Pause.as_ref())
293        .fetch_optional(&mut *tx)
294        .await
295        .map_err(PgError)?;
296
297        let Some(signal_row) = signal_row else {
298            tx.rollback().await.map_err(PgError)?;
299            return Ok(false);
300        };
301
302        // Lock and load the snapshot
303        let snap_row = sqlx::query(
304            "SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
305        )
306        .bind(instance_id)
307        .fetch_one(&mut *tx)
308        .await
309        .map_err(PgError)?;
310
311        let raw: &[u8] = snap_row.get("data");
312        let mut snapshot = self.decode(raw)?;
313
314        if !snapshot.state.is_in_progress() {
315            tx.rollback().await.map_err(PgError)?;
316            return Ok(false);
317        }
318
319        let reason: Option<String> = signal_row.get("reason");
320        let requested_by: Option<String> = signal_row.get("requested_by");
321        let pause_request = PauseRequest::new(reason, requested_by);
322        snapshot.mark_paused(&pause_request);
323
324        let data = self.encode(&snapshot)?;
325        let status = snapshot.state.as_ref();
326        let task_id = snapshot.current_task_id().map(ToString::to_string);
327        let task_count = snapshot.completed_task_count();
328        let pos_kind = snapshot.position_kind();
329        let wake_at = snapshot.delay_wake_at();
330
331        sqlx::query(
332            "UPDATE sayiir_workflow_snapshots
333             SET data = $1, status = $2, current_task_id = $3,
334                 completed_task_count = $4, position_kind = $5,
335                 delay_wake_at = $6, updated_at = now()
336             WHERE instance_id = $7",
337        )
338        .bind(&data)
339        .bind(status)
340        .bind(&task_id)
341        .bind(task_count)
342        .bind(pos_kind)
343        .bind(wake_at)
344        .bind(instance_id)
345        .execute(&mut *tx)
346        .await
347        .map_err(PgError)?;
348
349        // Clear the signal
350        sqlx::query("DELETE FROM sayiir_workflow_signals WHERE instance_id = $1 AND kind = $2")
351            .bind(instance_id)
352            .bind(SignalKind::Pause.as_ref())
353            .execute(&mut *tx)
354            .await
355            .map_err(PgError)?;
356
357        tx.commit().await.map_err(PgError)?;
358        tracing::info!(instance_id, "workflow paused");
359        Ok(true)
360    }
361
362    #[tracing::instrument(
363        name = "db.unpause",
364        skip(self),
365        fields(db.system = "postgresql"),
366        err(level = tracing::Level::ERROR),
367    )]
368    async fn unpause(&self, instance_id: &str) -> Result<WorkflowSnapshot, BackendError> {
369        tracing::debug!("unpausing workflow");
370        let mut tx = self.pool.begin().await.map_err(PgError)?;
371
372        let row = sqlx::query(
373            "SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
374        )
375        .bind(instance_id)
376        .fetch_optional(&mut *tx)
377        .await
378        .map_err(PgError)?
379        .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
380
381        let raw: &[u8] = row.get("data");
382        let mut snapshot = self.decode(raw)?;
383
384        if !snapshot.state.is_paused() {
385            let state_name = snapshot.state.as_ref();
386            return Err(BackendError::CannotPause(format!(
387                "Workflow is not paused (current state: {state_name:?})"
388            )));
389        }
390
391        snapshot.mark_unpaused();
392
393        let data = self.encode(&snapshot)?;
394        let status = snapshot.state.as_ref();
395        let task_id = snapshot.current_task_id().map(ToString::to_string);
396        let task_count = snapshot.completed_task_count();
397        let pos_kind = snapshot.position_kind();
398        let wake_at = snapshot.delay_wake_at();
399
400        sqlx::query(
401            "UPDATE sayiir_workflow_snapshots
402             SET data = $1, status = $2, current_task_id = $3,
403                 completed_task_count = $4, position_kind = $5,
404                 delay_wake_at = $6, updated_at = now()
405             WHERE instance_id = $7",
406        )
407        .bind(&data)
408        .bind(status)
409        .bind(&task_id)
410        .bind(task_count)
411        .bind(pos_kind)
412        .bind(wake_at)
413        .bind(instance_id)
414        .execute(&mut *tx)
415        .await
416        .map_err(PgError)?;
417
418        tx.commit().await.map_err(PgError)?;
419        tracing::info!(instance_id, "workflow unpaused");
420        Ok(snapshot)
421    }
422}