Skip to main content

sayiir_postgres/
signal_store.rs

1//! [`SignalStore`] implementation for Postgres.
2//!
3//! Overrides the 3 default composite methods with single-transaction
4//! implementations for true ACID atomicity.
5
6use sayiir_core::codec::{self, Decoder, Encoder};
7use sayiir_core::snapshot::{
8    PauseRequest, SignalKind, SignalRequest, SnapshotStatus, WorkflowSnapshot,
9};
10use sayiir_persistence::{BackendError, SignalStore};
11use sqlx::Row;
12
13use crate::backend::PostgresBackend;
14use crate::error::PgError;
15use crate::helpers::{
16    completed_task_count, current_task_id, delay_wake_at, error_message, position_kind, status_str,
17};
18
19impl<C> SignalStore for PostgresBackend<C>
20where
21    C: Encoder
22        + Decoder
23        + codec::sealed::EncodeValue<WorkflowSnapshot>
24        + codec::sealed::DecodeValue<WorkflowSnapshot>,
25{
26    #[tracing::instrument(
27        name = "db.store_signal",
28        skip(self, request),
29        fields(db.system = "postgresql", kind = %kind.as_ref()),
30        err(level = tracing::Level::ERROR),
31    )]
32    async fn store_signal(
33        &self,
34        instance_id: &str,
35        kind: SignalKind,
36        request: SignalRequest,
37    ) -> Result<(), BackendError> {
38        tracing::debug!("storing signal");
39        // Validate workflow state first
40        let row =
41            sqlx::query("SELECT status FROM sayiir_workflow_snapshots WHERE instance_id = $1")
42                .bind(instance_id)
43                .fetch_optional(&self.pool)
44                .await
45                .map_err(PgError)?
46                .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
47
48        let status: String = row.get("status");
49        validate_signal_allowed(&status, kind)?;
50
51        sqlx::query(
52            "INSERT INTO sayiir_workflow_signals (instance_id, kind, reason, requested_by)
53             VALUES ($1, $2, $3, $4)
54             ON CONFLICT (instance_id, kind) DO UPDATE SET
55                reason = $3, requested_by = $4, created_at = now()",
56        )
57        .bind(instance_id)
58        .bind(kind.as_ref())
59        .bind(&request.reason)
60        .bind(&request.requested_by)
61        .execute(&self.pool)
62        .await
63        .map_err(PgError)?;
64
65        Ok(())
66    }
67
68    #[tracing::instrument(
69        name = "db.get_signal",
70        skip(self),
71        fields(db.system = "postgresql", kind = %kind.as_ref()),
72        err(level = tracing::Level::ERROR),
73    )]
74    async fn get_signal(
75        &self,
76        instance_id: &str,
77        kind: SignalKind,
78    ) -> Result<Option<SignalRequest>, BackendError> {
79        tracing::debug!("getting signal");
80        let row = sqlx::query(
81            "SELECT reason, requested_by, created_at
82             FROM sayiir_workflow_signals
83             WHERE instance_id = $1 AND kind = $2",
84        )
85        .bind(instance_id)
86        .bind(kind.as_ref())
87        .fetch_optional(&self.pool)
88        .await
89        .map_err(PgError)?;
90
91        Ok(row.map(|r| SignalRequest {
92            reason: r.get("reason"),
93            requested_by: r.get("requested_by"),
94            requested_at: r.get("created_at"),
95        }))
96    }
97
98    #[tracing::instrument(
99        name = "db.clear_signal",
100        skip(self),
101        fields(db.system = "postgresql", kind = %kind.as_ref()),
102        err(level = tracing::Level::ERROR),
103    )]
104    async fn clear_signal(&self, instance_id: &str, kind: SignalKind) -> Result<(), BackendError> {
105        tracing::debug!("clearing signal");
106        sqlx::query("DELETE FROM sayiir_workflow_signals WHERE instance_id = $1 AND kind = $2")
107            .bind(instance_id)
108            .bind(kind.as_ref())
109            .execute(&self.pool)
110            .await
111            .map_err(PgError)?;
112        Ok(())
113    }
114
115    #[tracing::instrument(
116        name = "db.send_event",
117        skip(self, payload),
118        fields(db.system = "postgresql"),
119        err(level = tracing::Level::ERROR),
120    )]
121    async fn send_event(
122        &self,
123        instance_id: &str,
124        signal_name: &str,
125        payload: bytes::Bytes,
126    ) -> Result<(), BackendError> {
127        tracing::debug!("buffering external event");
128        sqlx::query(
129            "INSERT INTO sayiir_workflow_events (instance_id, signal_name, payload)
130             VALUES ($1, $2, $3)",
131        )
132        .bind(instance_id)
133        .bind(signal_name)
134        .bind(payload.as_ref())
135        .execute(&self.pool)
136        .await
137        .map_err(PgError)?;
138        Ok(())
139    }
140
141    #[tracing::instrument(
142        name = "db.consume_event",
143        skip(self),
144        fields(db.system = "postgresql"),
145        err(level = tracing::Level::ERROR),
146    )]
147    async fn consume_event(
148        &self,
149        instance_id: &str,
150        signal_name: &str,
151    ) -> Result<Option<bytes::Bytes>, BackendError> {
152        tracing::debug!("consuming oldest buffered event");
153        // Atomically delete-and-return the oldest event for this (instance, signal).
154        let row = sqlx::query(
155            "DELETE FROM sayiir_workflow_events
156             WHERE id = (
157                 SELECT id FROM sayiir_workflow_events
158                 WHERE instance_id = $1 AND signal_name = $2
159                 ORDER BY id ASC
160                 LIMIT 1
161                 FOR UPDATE SKIP LOCKED
162             )
163             RETURNING payload",
164        )
165        .bind(instance_id)
166        .bind(signal_name)
167        .fetch_optional(&self.pool)
168        .await
169        .map_err(PgError)?;
170
171        Ok(row.map(|r| {
172            let raw: Vec<u8> = r.get("payload");
173            bytes::Bytes::from(raw)
174        }))
175    }
176
177    // --- Overridden composites: single ACID transactions ---
178
179    #[tracing::instrument(
180        name = "db.check_and_cancel",
181        skip(self),
182        fields(db.system = "postgresql"),
183        err(level = tracing::Level::ERROR),
184    )]
185    async fn check_and_cancel(
186        &self,
187        instance_id: &str,
188        interrupted_at_task: Option<&str>,
189    ) -> Result<bool, BackendError> {
190        tracing::debug!("checking for cancel signal");
191        let mut tx = self.pool.begin().await.map_err(PgError)?;
192
193        // Check for cancel signal (lock the row)
194        let signal_row = sqlx::query(
195            "SELECT reason, requested_by
196             FROM sayiir_workflow_signals
197             WHERE instance_id = $1 AND kind = $2
198             FOR UPDATE",
199        )
200        .bind(instance_id)
201        .bind(SignalKind::Cancel.as_ref())
202        .fetch_optional(&mut *tx)
203        .await
204        .map_err(PgError)?;
205
206        let Some(signal_row) = signal_row else {
207            tx.rollback().await.map_err(PgError)?;
208            return Ok(false);
209        };
210
211        // Lock and load the snapshot
212        let snap_row = sqlx::query(
213            "SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
214        )
215        .bind(instance_id)
216        .fetch_one(&mut *tx)
217        .await
218        .map_err(PgError)?;
219
220        let raw: &[u8] = snap_row.get("data");
221        let mut snapshot = self.decode(raw)?;
222
223        if !snapshot.state.is_in_progress() {
224            tx.rollback().await.map_err(PgError)?;
225            return Ok(false);
226        }
227
228        let reason: Option<String> = signal_row.get("reason");
229        let requested_by: Option<String> = signal_row.get("requested_by");
230        snapshot.mark_cancelled(reason, requested_by, interrupted_at_task.map(String::from));
231
232        let data = self.encode(&snapshot)?;
233        let status = status_str(&snapshot.state);
234        let error = error_message(&snapshot).map(ToString::to_string);
235        let pos_kind = position_kind(&snapshot);
236        let wake_at = delay_wake_at(&snapshot);
237
238        sqlx::query(
239            "UPDATE sayiir_workflow_snapshots
240             SET data = $1, status = $2, error = $3,
241                 position_kind = $4, delay_wake_at = $5,
242                 completed_at = now(), updated_at = now()
243             WHERE instance_id = $6",
244        )
245        .bind(&data)
246        .bind(status)
247        .bind(&error)
248        .bind(pos_kind)
249        .bind(wake_at)
250        .bind(instance_id)
251        .execute(&mut *tx)
252        .await
253        .map_err(PgError)?;
254
255        // Mark any still-active tasks as cancelled
256        sqlx::query(
257            "UPDATE sayiir_workflow_tasks SET status = 'cancelled', completed_at = now()
258             WHERE instance_id = $1 AND status = 'active'",
259        )
260        .bind(instance_id)
261        .execute(&mut *tx)
262        .await
263        .map_err(PgError)?;
264
265        // Clear the signal
266        sqlx::query("DELETE FROM sayiir_workflow_signals WHERE instance_id = $1 AND kind = $2")
267            .bind(instance_id)
268            .bind(SignalKind::Cancel.as_ref())
269            .execute(&mut *tx)
270            .await
271            .map_err(PgError)?;
272
273        tx.commit().await.map_err(PgError)?;
274        tracing::info!(instance_id, "workflow cancelled");
275        Ok(true)
276    }
277
278    #[tracing::instrument(
279        name = "db.check_and_pause",
280        skip(self),
281        fields(db.system = "postgresql"),
282        err(level = tracing::Level::ERROR),
283    )]
284    async fn check_and_pause(&self, instance_id: &str) -> Result<bool, BackendError> {
285        tracing::debug!("checking for pause signal");
286        let mut tx = self.pool.begin().await.map_err(PgError)?;
287
288        // Check for pause signal (lock the row)
289        let signal_row = sqlx::query(
290            "SELECT reason, requested_by
291             FROM sayiir_workflow_signals
292             WHERE instance_id = $1 AND kind = $2
293             FOR UPDATE",
294        )
295        .bind(instance_id)
296        .bind(SignalKind::Pause.as_ref())
297        .fetch_optional(&mut *tx)
298        .await
299        .map_err(PgError)?;
300
301        let Some(signal_row) = signal_row else {
302            tx.rollback().await.map_err(PgError)?;
303            return Ok(false);
304        };
305
306        // Lock and load the snapshot
307        let snap_row = sqlx::query(
308            "SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
309        )
310        .bind(instance_id)
311        .fetch_one(&mut *tx)
312        .await
313        .map_err(PgError)?;
314
315        let raw: &[u8] = snap_row.get("data");
316        let mut snapshot = self.decode(raw)?;
317
318        if !snapshot.state.is_in_progress() {
319            tx.rollback().await.map_err(PgError)?;
320            return Ok(false);
321        }
322
323        let reason: Option<String> = signal_row.get("reason");
324        let requested_by: Option<String> = signal_row.get("requested_by");
325        let pause_request = PauseRequest::new(reason, requested_by);
326        snapshot.mark_paused(&pause_request);
327
328        let data = self.encode(&snapshot)?;
329        let status = status_str(&snapshot.state);
330        let task_id = current_task_id(&snapshot).map(ToString::to_string);
331        let task_count = completed_task_count(&snapshot);
332        let pos_kind = position_kind(&snapshot);
333        let wake_at = delay_wake_at(&snapshot);
334
335        sqlx::query(
336            "UPDATE sayiir_workflow_snapshots
337             SET data = $1, status = $2, current_task_id = $3,
338                 completed_task_count = $4, position_kind = $5,
339                 delay_wake_at = $6, updated_at = now()
340             WHERE instance_id = $7",
341        )
342        .bind(&data)
343        .bind(status)
344        .bind(&task_id)
345        .bind(task_count)
346        .bind(pos_kind)
347        .bind(wake_at)
348        .bind(instance_id)
349        .execute(&mut *tx)
350        .await
351        .map_err(PgError)?;
352
353        // Clear the signal
354        sqlx::query("DELETE FROM sayiir_workflow_signals WHERE instance_id = $1 AND kind = $2")
355            .bind(instance_id)
356            .bind(SignalKind::Pause.as_ref())
357            .execute(&mut *tx)
358            .await
359            .map_err(PgError)?;
360
361        tx.commit().await.map_err(PgError)?;
362        tracing::info!(instance_id, "workflow paused");
363        Ok(true)
364    }
365
366    #[tracing::instrument(
367        name = "db.unpause",
368        skip(self),
369        fields(db.system = "postgresql"),
370        err(level = tracing::Level::ERROR),
371    )]
372    async fn unpause(&self, instance_id: &str) -> Result<WorkflowSnapshot, BackendError> {
373        tracing::debug!("unpausing workflow");
374        let mut tx = self.pool.begin().await.map_err(PgError)?;
375
376        let row = sqlx::query(
377            "SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
378        )
379        .bind(instance_id)
380        .fetch_optional(&mut *tx)
381        .await
382        .map_err(PgError)?
383        .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
384
385        let raw: &[u8] = row.get("data");
386        let mut snapshot = self.decode(raw)?;
387
388        if !snapshot.state.is_paused() {
389            let state_name = status_str(&snapshot.state);
390            return Err(BackendError::CannotPause(format!(
391                "Workflow is not paused (current state: {state_name:?})"
392            )));
393        }
394
395        snapshot.mark_unpaused();
396
397        let data = self.encode(&snapshot)?;
398        let status = status_str(&snapshot.state);
399        let task_id = current_task_id(&snapshot).map(ToString::to_string);
400        let task_count = completed_task_count(&snapshot);
401        let pos_kind = position_kind(&snapshot);
402        let wake_at = delay_wake_at(&snapshot);
403
404        sqlx::query(
405            "UPDATE sayiir_workflow_snapshots
406             SET data = $1, status = $2, current_task_id = $3,
407                 completed_task_count = $4, position_kind = $5,
408                 delay_wake_at = $6, updated_at = now()
409             WHERE instance_id = $7",
410        )
411        .bind(&data)
412        .bind(status)
413        .bind(&task_id)
414        .bind(task_count)
415        .bind(pos_kind)
416        .bind(wake_at)
417        .bind(instance_id)
418        .execute(&mut *tx)
419        .await
420        .map_err(PgError)?;
421
422        tx.commit().await.map_err(PgError)?;
423        tracing::info!(instance_id, "workflow unpaused");
424        Ok(snapshot)
425    }
426}
427
428/// Validate that a signal can be sent to a workflow in the given state.
429fn validate_signal_allowed(status: &str, kind: SignalKind) -> Result<(), BackendError> {
430    use std::str::FromStr;
431
432    let Ok(status) = SnapshotStatus::from_str(status) else {
433        // Unknown status from DB — be permissive (forward compatibility).
434        return Ok(());
435    };
436
437    match kind {
438        SignalKind::Cancel => match status {
439            SnapshotStatus::Completed | SnapshotStatus::Failed => {
440                Err(BackendError::CannotCancel(status.as_ref().to_string()))
441            }
442            _ => Ok(()),
443        },
444        SignalKind::Pause => match status {
445            SnapshotStatus::Completed | SnapshotStatus::Failed | SnapshotStatus::Cancelled => {
446                Err(BackendError::CannotPause(status.as_ref().to_string()))
447            }
448            _ => Ok(()),
449        },
450    }
451}