Skip to main content

sayiir_postgres/
signal_store.rs

//! [`SignalStore`] implementation for Postgres.
//!
//! Overrides the three default composite methods (`check_and_cancel`,
//! `check_and_pause`, and `unpause`) with single-transaction
//! implementations for true ACID atomicity.
5
6use sayiir_core::codec::{self, Decoder, Encoder};
7use sayiir_core::snapshot::{
8    PauseRequest, SignalKind, SignalRequest, SnapshotStatus, WorkflowSnapshot,
9};
10use sayiir_persistence::{BackendError, SignalStore};
11use sqlx::Row;
12
13use crate::backend::PostgresBackend;
14use crate::error::PgError;
15use crate::helpers::{
16    completed_task_count, current_task_id, delay_wake_at, error_message, position_kind, status_str,
17};
18
19impl<C> SignalStore for PostgresBackend<C>
20where
21    C: Encoder
22        + Decoder
23        + codec::sealed::EncodeValue<WorkflowSnapshot>
24        + codec::sealed::DecodeValue<WorkflowSnapshot>,
25{
26    async fn store_signal(
27        &self,
28        instance_id: &str,
29        kind: SignalKind,
30        request: SignalRequest,
31    ) -> Result<(), BackendError> {
32        tracing::debug!(instance_id, kind = %kind.as_ref(), "storing signal");
33        // Validate workflow state first
34        let row =
35            sqlx::query("SELECT status FROM sayiir_workflow_snapshots WHERE instance_id = $1")
36                .bind(instance_id)
37                .fetch_optional(&self.pool)
38                .await
39                .map_err(PgError)?
40                .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
41
42        let status: String = row.get("status");
43        validate_signal_allowed(&status, kind)?;
44
45        sqlx::query(
46            "INSERT INTO sayiir_workflow_signals (instance_id, kind, reason, requested_by)
47             VALUES ($1, $2, $3, $4)
48             ON CONFLICT (instance_id, kind) DO UPDATE SET
49                reason = $3, requested_by = $4, created_at = now()",
50        )
51        .bind(instance_id)
52        .bind(kind.as_ref())
53        .bind(&request.reason)
54        .bind(&request.requested_by)
55        .execute(&self.pool)
56        .await
57        .map_err(PgError)?;
58
59        Ok(())
60    }
61
62    async fn get_signal(
63        &self,
64        instance_id: &str,
65        kind: SignalKind,
66    ) -> Result<Option<SignalRequest>, BackendError> {
67        tracing::debug!(instance_id, kind = %kind.as_ref(), "getting signal");
68        let row = sqlx::query(
69            "SELECT reason, requested_by, created_at
70             FROM sayiir_workflow_signals
71             WHERE instance_id = $1 AND kind = $2",
72        )
73        .bind(instance_id)
74        .bind(kind.as_ref())
75        .fetch_optional(&self.pool)
76        .await
77        .map_err(PgError)?;
78
79        Ok(row.map(|r| SignalRequest {
80            reason: r.get("reason"),
81            requested_by: r.get("requested_by"),
82            requested_at: r.get("created_at"),
83        }))
84    }
85
86    async fn clear_signal(&self, instance_id: &str, kind: SignalKind) -> Result<(), BackendError> {
87        tracing::debug!(instance_id, kind = %kind.as_ref(), "clearing signal");
88        sqlx::query("DELETE FROM sayiir_workflow_signals WHERE instance_id = $1 AND kind = $2")
89            .bind(instance_id)
90            .bind(kind.as_ref())
91            .execute(&self.pool)
92            .await
93            .map_err(PgError)?;
94        Ok(())
95    }
96
97    async fn send_event(
98        &self,
99        instance_id: &str,
100        signal_name: &str,
101        payload: bytes::Bytes,
102    ) -> Result<(), BackendError> {
103        tracing::debug!(instance_id, signal_name, "buffering external event");
104        sqlx::query(
105            "INSERT INTO sayiir_workflow_events (instance_id, signal_name, payload)
106             VALUES ($1, $2, $3)",
107        )
108        .bind(instance_id)
109        .bind(signal_name)
110        .bind(payload.as_ref())
111        .execute(&self.pool)
112        .await
113        .map_err(PgError)?;
114        Ok(())
115    }
116
117    async fn consume_event(
118        &self,
119        instance_id: &str,
120        signal_name: &str,
121    ) -> Result<Option<bytes::Bytes>, BackendError> {
122        tracing::debug!(instance_id, signal_name, "consuming oldest buffered event");
123        // Atomically delete-and-return the oldest event for this (instance, signal).
124        let row = sqlx::query(
125            "DELETE FROM sayiir_workflow_events
126             WHERE id = (
127                 SELECT id FROM sayiir_workflow_events
128                 WHERE instance_id = $1 AND signal_name = $2
129                 ORDER BY id ASC
130                 LIMIT 1
131                 FOR UPDATE SKIP LOCKED
132             )
133             RETURNING payload",
134        )
135        .bind(instance_id)
136        .bind(signal_name)
137        .fetch_optional(&self.pool)
138        .await
139        .map_err(PgError)?;
140
141        Ok(row.map(|r| {
142            let raw: Vec<u8> = r.get("payload");
143            bytes::Bytes::from(raw)
144        }))
145    }
146
147    // --- Overridden composites: single ACID transactions ---
148
149    async fn check_and_cancel(
150        &self,
151        instance_id: &str,
152        interrupted_at_task: Option<&str>,
153    ) -> Result<bool, BackendError> {
154        tracing::debug!(instance_id, "checking for cancel signal");
155        let mut tx = self.pool.begin().await.map_err(PgError)?;
156
157        // Check for cancel signal (lock the row)
158        let signal_row = sqlx::query(
159            "SELECT reason, requested_by
160             FROM sayiir_workflow_signals
161             WHERE instance_id = $1 AND kind = $2
162             FOR UPDATE",
163        )
164        .bind(instance_id)
165        .bind(SignalKind::Cancel.as_ref())
166        .fetch_optional(&mut *tx)
167        .await
168        .map_err(PgError)?;
169
170        let Some(signal_row) = signal_row else {
171            tx.rollback().await.map_err(PgError)?;
172            return Ok(false);
173        };
174
175        // Lock and load the snapshot
176        let snap_row = sqlx::query(
177            "SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
178        )
179        .bind(instance_id)
180        .fetch_one(&mut *tx)
181        .await
182        .map_err(PgError)?;
183
184        let raw: &[u8] = snap_row.get("data");
185        let mut snapshot = self.decode(raw)?;
186
187        if !snapshot.state.is_in_progress() {
188            tx.rollback().await.map_err(PgError)?;
189            return Ok(false);
190        }
191
192        let reason: Option<String> = signal_row.get("reason");
193        let requested_by: Option<String> = signal_row.get("requested_by");
194        snapshot.mark_cancelled(reason, requested_by, interrupted_at_task.map(String::from));
195
196        let data = self.encode(&snapshot)?;
197        let status = status_str(&snapshot.state);
198        let error = error_message(&snapshot).map(ToString::to_string);
199        let pos_kind = position_kind(&snapshot);
200        let wake_at = delay_wake_at(&snapshot);
201
202        sqlx::query(
203            "UPDATE sayiir_workflow_snapshots
204             SET data = $1, status = $2, error = $3,
205                 position_kind = $4, delay_wake_at = $5,
206                 completed_at = now(), updated_at = now()
207             WHERE instance_id = $6",
208        )
209        .bind(&data)
210        .bind(status)
211        .bind(&error)
212        .bind(pos_kind)
213        .bind(wake_at)
214        .bind(instance_id)
215        .execute(&mut *tx)
216        .await
217        .map_err(PgError)?;
218
219        // Mark any still-active tasks as cancelled
220        sqlx::query(
221            "UPDATE sayiir_workflow_tasks SET status = 'cancelled', completed_at = now()
222             WHERE instance_id = $1 AND status = 'active'",
223        )
224        .bind(instance_id)
225        .execute(&mut *tx)
226        .await
227        .map_err(PgError)?;
228
229        // Clear the signal
230        sqlx::query("DELETE FROM sayiir_workflow_signals WHERE instance_id = $1 AND kind = $2")
231            .bind(instance_id)
232            .bind(SignalKind::Cancel.as_ref())
233            .execute(&mut *tx)
234            .await
235            .map_err(PgError)?;
236
237        tx.commit().await.map_err(PgError)?;
238        tracing::info!(instance_id, "workflow cancelled");
239        Ok(true)
240    }
241
242    async fn check_and_pause(&self, instance_id: &str) -> Result<bool, BackendError> {
243        tracing::debug!(instance_id, "checking for pause signal");
244        let mut tx = self.pool.begin().await.map_err(PgError)?;
245
246        // Check for pause signal (lock the row)
247        let signal_row = sqlx::query(
248            "SELECT reason, requested_by
249             FROM sayiir_workflow_signals
250             WHERE instance_id = $1 AND kind = $2
251             FOR UPDATE",
252        )
253        .bind(instance_id)
254        .bind(SignalKind::Pause.as_ref())
255        .fetch_optional(&mut *tx)
256        .await
257        .map_err(PgError)?;
258
259        let Some(signal_row) = signal_row else {
260            tx.rollback().await.map_err(PgError)?;
261            return Ok(false);
262        };
263
264        // Lock and load the snapshot
265        let snap_row = sqlx::query(
266            "SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
267        )
268        .bind(instance_id)
269        .fetch_one(&mut *tx)
270        .await
271        .map_err(PgError)?;
272
273        let raw: &[u8] = snap_row.get("data");
274        let mut snapshot = self.decode(raw)?;
275
276        if !snapshot.state.is_in_progress() {
277            tx.rollback().await.map_err(PgError)?;
278            return Ok(false);
279        }
280
281        let reason: Option<String> = signal_row.get("reason");
282        let requested_by: Option<String> = signal_row.get("requested_by");
283        let pause_request = PauseRequest::new(reason, requested_by);
284        snapshot.mark_paused(&pause_request);
285
286        let data = self.encode(&snapshot)?;
287        let status = status_str(&snapshot.state);
288        let task_id = current_task_id(&snapshot).map(ToString::to_string);
289        let task_count = completed_task_count(&snapshot);
290        let pos_kind = position_kind(&snapshot);
291        let wake_at = delay_wake_at(&snapshot);
292
293        sqlx::query(
294            "UPDATE sayiir_workflow_snapshots
295             SET data = $1, status = $2, current_task_id = $3,
296                 completed_task_count = $4, position_kind = $5,
297                 delay_wake_at = $6, updated_at = now()
298             WHERE instance_id = $7",
299        )
300        .bind(&data)
301        .bind(status)
302        .bind(&task_id)
303        .bind(task_count)
304        .bind(pos_kind)
305        .bind(wake_at)
306        .bind(instance_id)
307        .execute(&mut *tx)
308        .await
309        .map_err(PgError)?;
310
311        // Clear the signal
312        sqlx::query("DELETE FROM sayiir_workflow_signals WHERE instance_id = $1 AND kind = $2")
313            .bind(instance_id)
314            .bind(SignalKind::Pause.as_ref())
315            .execute(&mut *tx)
316            .await
317            .map_err(PgError)?;
318
319        tx.commit().await.map_err(PgError)?;
320        tracing::info!(instance_id, "workflow paused");
321        Ok(true)
322    }
323
324    async fn unpause(&self, instance_id: &str) -> Result<WorkflowSnapshot, BackendError> {
325        tracing::debug!(instance_id, "unpausing workflow");
326        let mut tx = self.pool.begin().await.map_err(PgError)?;
327
328        let row = sqlx::query(
329            "SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
330        )
331        .bind(instance_id)
332        .fetch_optional(&mut *tx)
333        .await
334        .map_err(PgError)?
335        .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
336
337        let raw: &[u8] = row.get("data");
338        let mut snapshot = self.decode(raw)?;
339
340        if !snapshot.state.is_paused() {
341            let state_name = status_str(&snapshot.state);
342            return Err(BackendError::CannotPause(format!(
343                "Workflow is not paused (current state: {state_name:?})"
344            )));
345        }
346
347        snapshot.mark_unpaused();
348
349        let data = self.encode(&snapshot)?;
350        let status = status_str(&snapshot.state);
351        let task_id = current_task_id(&snapshot).map(ToString::to_string);
352        let task_count = completed_task_count(&snapshot);
353        let pos_kind = position_kind(&snapshot);
354        let wake_at = delay_wake_at(&snapshot);
355
356        sqlx::query(
357            "UPDATE sayiir_workflow_snapshots
358             SET data = $1, status = $2, current_task_id = $3,
359                 completed_task_count = $4, position_kind = $5,
360                 delay_wake_at = $6, updated_at = now()
361             WHERE instance_id = $7",
362        )
363        .bind(&data)
364        .bind(status)
365        .bind(&task_id)
366        .bind(task_count)
367        .bind(pos_kind)
368        .bind(wake_at)
369        .bind(instance_id)
370        .execute(&mut *tx)
371        .await
372        .map_err(PgError)?;
373
374        tx.commit().await.map_err(PgError)?;
375        tracing::info!(instance_id, "workflow unpaused");
376        Ok(snapshot)
377    }
378}
379
380/// Validate that a signal can be sent to a workflow in the given state.
381fn validate_signal_allowed(status: &str, kind: SignalKind) -> Result<(), BackendError> {
382    use std::str::FromStr;
383
384    let Ok(status) = SnapshotStatus::from_str(status) else {
385        // Unknown status from DB — be permissive (forward compatibility).
386        return Ok(());
387    };
388
389    match kind {
390        SignalKind::Cancel => match status {
391            SnapshotStatus::Completed | SnapshotStatus::Failed => {
392                Err(BackendError::CannotCancel(status.as_ref().to_string()))
393            }
394            _ => Ok(()),
395        },
396        SignalKind::Pause => match status {
397            SnapshotStatus::Completed | SnapshotStatus::Failed | SnapshotStatus::Cancelled => {
398                Err(BackendError::CannotPause(status.as_ref().to_string()))
399            }
400            _ => Ok(()),
401        },
402    }
403}