// sayiir_postgres/signal_store.rs

//! [`SignalStore`] implementation for Postgres.
//!
//! Overrides the three default composite methods (`check_and_cancel`,
//! `check_and_pause` and `unpause`) with single-transaction implementations
//! for true ACID atomicity.
5
6use sayiir_core::codec::{self, Decoder, Encoder};
7use sayiir_core::snapshot::{
8    PauseRequest, SignalKind, SignalRequest, SnapshotStatus, WorkflowSnapshot,
9};
10use sayiir_persistence::{BackendError, SignalStore};
11use sqlx::Row;
12
13use crate::backend::PostgresBackend;
14use crate::error::PgError;
15use crate::helpers::{
16    completed_task_count, current_task_id, delay_wake_at, error_message, position_kind, status_str,
17};
18
19impl<C> SignalStore for PostgresBackend<C>
20where
21    C: Encoder
22        + Decoder
23        + codec::sealed::EncodeValue<WorkflowSnapshot>
24        + codec::sealed::DecodeValue<WorkflowSnapshot>,
25{
26    async fn store_signal(
27        &self,
28        instance_id: &str,
29        kind: SignalKind,
30        request: SignalRequest,
31    ) -> Result<(), BackendError> {
32        tracing::debug!(instance_id, kind = %kind.as_ref(), "storing signal");
33        // Validate workflow state first
34        let row =
35            sqlx::query("SELECT status FROM sayiir_workflow_snapshots WHERE instance_id = $1")
36                .bind(instance_id)
37                .fetch_optional(&self.pool)
38                .await
39                .map_err(PgError)?
40                .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
41
42        let status: String = row.get("status");
43        validate_signal_allowed(&status, &kind)?;
44
45        sqlx::query(
46            "INSERT INTO sayiir_workflow_signals (instance_id, kind, reason, requested_by)
47             VALUES ($1, $2, $3, $4)
48             ON CONFLICT (instance_id, kind) DO UPDATE SET
49                reason = $3, requested_by = $4, created_at = now()",
50        )
51        .bind(instance_id)
52        .bind(kind.as_ref())
53        .bind(&request.reason)
54        .bind(&request.requested_by)
55        .execute(&self.pool)
56        .await
57        .map_err(PgError)?;
58
59        Ok(())
60    }
61
62    async fn get_signal(
63        &self,
64        instance_id: &str,
65        kind: SignalKind,
66    ) -> Result<Option<SignalRequest>, BackendError> {
67        tracing::debug!(instance_id, kind = %kind.as_ref(), "getting signal");
68        let row = sqlx::query(
69            "SELECT reason, requested_by, created_at
70             FROM sayiir_workflow_signals
71             WHERE instance_id = $1 AND kind = $2",
72        )
73        .bind(instance_id)
74        .bind(kind.as_ref())
75        .fetch_optional(&self.pool)
76        .await
77        .map_err(PgError)?;
78
79        Ok(row.map(|r| SignalRequest {
80            reason: r.get("reason"),
81            requested_by: r.get("requested_by"),
82            requested_at: r.get("created_at"),
83        }))
84    }
85
86    async fn clear_signal(&self, instance_id: &str, kind: SignalKind) -> Result<(), BackendError> {
87        tracing::debug!(instance_id, kind = %kind.as_ref(), "clearing signal");
88        sqlx::query("DELETE FROM sayiir_workflow_signals WHERE instance_id = $1 AND kind = $2")
89            .bind(instance_id)
90            .bind(kind.as_ref())
91            .execute(&self.pool)
92            .await
93            .map_err(PgError)?;
94        Ok(())
95    }
96
97    // --- Overridden composites: single ACID transactions ---
98
99    async fn check_and_cancel(
100        &self,
101        instance_id: &str,
102        interrupted_at_task: Option<&str>,
103    ) -> Result<bool, BackendError> {
104        tracing::debug!(instance_id, "checking for cancel signal");
105        let mut tx = self.pool.begin().await.map_err(PgError)?;
106
107        // Check for cancel signal (lock the row)
108        let signal_row = sqlx::query(
109            "SELECT reason, requested_by
110             FROM sayiir_workflow_signals
111             WHERE instance_id = $1 AND kind = $2
112             FOR UPDATE",
113        )
114        .bind(instance_id)
115        .bind(SignalKind::Cancel.as_ref())
116        .fetch_optional(&mut *tx)
117        .await
118        .map_err(PgError)?;
119
120        let Some(signal_row) = signal_row else {
121            tx.rollback().await.map_err(PgError)?;
122            return Ok(false);
123        };
124
125        // Lock and load the snapshot
126        let snap_row = sqlx::query(
127            "SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
128        )
129        .bind(instance_id)
130        .fetch_one(&mut *tx)
131        .await
132        .map_err(PgError)?;
133
134        let raw: &[u8] = snap_row.get("data");
135        let mut snapshot = self.decode(raw)?;
136
137        if !snapshot.state.is_in_progress() {
138            tx.rollback().await.map_err(PgError)?;
139            return Ok(false);
140        }
141
142        let reason: Option<String> = signal_row.get("reason");
143        let requested_by: Option<String> = signal_row.get("requested_by");
144        snapshot.mark_cancelled(reason, requested_by, interrupted_at_task.map(String::from));
145
146        let data = self.encode(&snapshot)?;
147        let status = status_str(&snapshot.state);
148        let error = error_message(&snapshot).map(ToString::to_string);
149        let pos_kind = position_kind(&snapshot);
150        let wake_at = delay_wake_at(&snapshot);
151
152        sqlx::query(
153            "UPDATE sayiir_workflow_snapshots
154             SET data = $1, status = $2, error = $3,
155                 position_kind = $4, delay_wake_at = $5,
156                 completed_at = now(), updated_at = now()
157             WHERE instance_id = $6",
158        )
159        .bind(&data)
160        .bind(status)
161        .bind(&error)
162        .bind(pos_kind)
163        .bind(wake_at)
164        .bind(instance_id)
165        .execute(&mut *tx)
166        .await
167        .map_err(PgError)?;
168
169        // Mark any still-active tasks as cancelled
170        sqlx::query(
171            "UPDATE sayiir_workflow_tasks SET status = 'cancelled', completed_at = now()
172             WHERE instance_id = $1 AND status = 'active'",
173        )
174        .bind(instance_id)
175        .execute(&mut *tx)
176        .await
177        .map_err(PgError)?;
178
179        // Clear the signal
180        sqlx::query("DELETE FROM sayiir_workflow_signals WHERE instance_id = $1 AND kind = $2")
181            .bind(instance_id)
182            .bind(SignalKind::Cancel.as_ref())
183            .execute(&mut *tx)
184            .await
185            .map_err(PgError)?;
186
187        tx.commit().await.map_err(PgError)?;
188        tracing::info!(instance_id, "workflow cancelled");
189        Ok(true)
190    }
191
192    async fn check_and_pause(&self, instance_id: &str) -> Result<bool, BackendError> {
193        tracing::debug!(instance_id, "checking for pause signal");
194        let mut tx = self.pool.begin().await.map_err(PgError)?;
195
196        // Check for pause signal (lock the row)
197        let signal_row = sqlx::query(
198            "SELECT reason, requested_by
199             FROM sayiir_workflow_signals
200             WHERE instance_id = $1 AND kind = $2
201             FOR UPDATE",
202        )
203        .bind(instance_id)
204        .bind(SignalKind::Pause.as_ref())
205        .fetch_optional(&mut *tx)
206        .await
207        .map_err(PgError)?;
208
209        let Some(signal_row) = signal_row else {
210            tx.rollback().await.map_err(PgError)?;
211            return Ok(false);
212        };
213
214        // Lock and load the snapshot
215        let snap_row = sqlx::query(
216            "SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
217        )
218        .bind(instance_id)
219        .fetch_one(&mut *tx)
220        .await
221        .map_err(PgError)?;
222
223        let raw: &[u8] = snap_row.get("data");
224        let mut snapshot = self.decode(raw)?;
225
226        if !snapshot.state.is_in_progress() {
227            tx.rollback().await.map_err(PgError)?;
228            return Ok(false);
229        }
230
231        let reason: Option<String> = signal_row.get("reason");
232        let requested_by: Option<String> = signal_row.get("requested_by");
233        let pause_request = PauseRequest::new(reason, requested_by);
234        snapshot.mark_paused(&pause_request);
235
236        let data = self.encode(&snapshot)?;
237        let status = status_str(&snapshot.state);
238        let task_id = current_task_id(&snapshot).map(ToString::to_string);
239        let task_count = completed_task_count(&snapshot);
240        let pos_kind = position_kind(&snapshot);
241        let wake_at = delay_wake_at(&snapshot);
242
243        sqlx::query(
244            "UPDATE sayiir_workflow_snapshots
245             SET data = $1, status = $2, current_task_id = $3,
246                 completed_task_count = $4, position_kind = $5,
247                 delay_wake_at = $6, updated_at = now()
248             WHERE instance_id = $7",
249        )
250        .bind(&data)
251        .bind(status)
252        .bind(&task_id)
253        .bind(task_count)
254        .bind(pos_kind)
255        .bind(wake_at)
256        .bind(instance_id)
257        .execute(&mut *tx)
258        .await
259        .map_err(PgError)?;
260
261        // Clear the signal
262        sqlx::query("DELETE FROM sayiir_workflow_signals WHERE instance_id = $1 AND kind = $2")
263            .bind(instance_id)
264            .bind(SignalKind::Pause.as_ref())
265            .execute(&mut *tx)
266            .await
267            .map_err(PgError)?;
268
269        tx.commit().await.map_err(PgError)?;
270        tracing::info!(instance_id, "workflow paused");
271        Ok(true)
272    }
273
274    async fn unpause(&self, instance_id: &str) -> Result<WorkflowSnapshot, BackendError> {
275        tracing::debug!(instance_id, "unpausing workflow");
276        let mut tx = self.pool.begin().await.map_err(PgError)?;
277
278        let row = sqlx::query(
279            "SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
280        )
281        .bind(instance_id)
282        .fetch_optional(&mut *tx)
283        .await
284        .map_err(PgError)?
285        .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
286
287        let raw: &[u8] = row.get("data");
288        let mut snapshot = self.decode(raw)?;
289
290        if !snapshot.state.is_paused() {
291            let state_name = status_str(&snapshot.state);
292            return Err(BackendError::CannotPause(format!(
293                "Workflow is not paused (current state: {state_name:?})"
294            )));
295        }
296
297        snapshot.mark_unpaused();
298
299        let data = self.encode(&snapshot)?;
300        let status = status_str(&snapshot.state);
301        let task_id = current_task_id(&snapshot).map(ToString::to_string);
302        let task_count = completed_task_count(&snapshot);
303        let pos_kind = position_kind(&snapshot);
304        let wake_at = delay_wake_at(&snapshot);
305
306        sqlx::query(
307            "UPDATE sayiir_workflow_snapshots
308             SET data = $1, status = $2, current_task_id = $3,
309                 completed_task_count = $4, position_kind = $5,
310                 delay_wake_at = $6, updated_at = now()
311             WHERE instance_id = $7",
312        )
313        .bind(&data)
314        .bind(status)
315        .bind(&task_id)
316        .bind(task_count)
317        .bind(pos_kind)
318        .bind(wake_at)
319        .bind(instance_id)
320        .execute(&mut *tx)
321        .await
322        .map_err(PgError)?;
323
324        tx.commit().await.map_err(PgError)?;
325        tracing::info!(instance_id, "workflow unpaused");
326        Ok(snapshot)
327    }
328}
329
330/// Validate that a signal can be sent to a workflow in the given state.
331fn validate_signal_allowed(status: &str, kind: &SignalKind) -> Result<(), BackendError> {
332    use std::str::FromStr;
333
334    let Ok(status) = SnapshotStatus::from_str(status) else {
335        // Unknown status from DB — be permissive (forward compatibility).
336        return Ok(());
337    };
338
339    match kind {
340        SignalKind::Cancel => match status {
341            SnapshotStatus::Completed | SnapshotStatus::Failed => {
342                Err(BackendError::CannotCancel(status.as_ref().to_string()))
343            }
344            _ => Ok(()),
345        },
346        SignalKind::Pause => match status {
347            SnapshotStatus::Completed | SnapshotStatus::Failed | SnapshotStatus::Cancelled => {
348                Err(BackendError::CannotPause(status.as_ref().to_string()))
349            }
350            _ => Ok(()),
351        },
352    }
353}