1use sayiir_core::codec::{self, Decoder, Encoder};
7use sayiir_core::snapshot::{
8 PauseRequest, SignalKind, SignalRequest, SnapshotStatus, WorkflowSnapshot,
9};
10use sayiir_persistence::{BackendError, SignalStore};
11use sqlx::Row;
12
13use crate::backend::PostgresBackend;
14use crate::error::PgError;
15use crate::helpers::{
16 completed_task_count, current_task_id, delay_wake_at, error_message, position_kind, status_str,
17};
18
19impl<C> SignalStore for PostgresBackend<C>
20where
21 C: Encoder
22 + Decoder
23 + codec::sealed::EncodeValue<WorkflowSnapshot>
24 + codec::sealed::DecodeValue<WorkflowSnapshot>,
25{
26 async fn store_signal(
27 &self,
28 instance_id: &str,
29 kind: SignalKind,
30 request: SignalRequest,
31 ) -> Result<(), BackendError> {
32 tracing::debug!(instance_id, kind = %kind.as_ref(), "storing signal");
33 let row =
35 sqlx::query("SELECT status FROM sayiir_workflow_snapshots WHERE instance_id = $1")
36 .bind(instance_id)
37 .fetch_optional(&self.pool)
38 .await
39 .map_err(PgError)?
40 .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
41
42 let status: String = row.get("status");
43 validate_signal_allowed(&status, kind)?;
44
45 sqlx::query(
46 "INSERT INTO sayiir_workflow_signals (instance_id, kind, reason, requested_by)
47 VALUES ($1, $2, $3, $4)
48 ON CONFLICT (instance_id, kind) DO UPDATE SET
49 reason = $3, requested_by = $4, created_at = now()",
50 )
51 .bind(instance_id)
52 .bind(kind.as_ref())
53 .bind(&request.reason)
54 .bind(&request.requested_by)
55 .execute(&self.pool)
56 .await
57 .map_err(PgError)?;
58
59 Ok(())
60 }
61
62 async fn get_signal(
63 &self,
64 instance_id: &str,
65 kind: SignalKind,
66 ) -> Result<Option<SignalRequest>, BackendError> {
67 tracing::debug!(instance_id, kind = %kind.as_ref(), "getting signal");
68 let row = sqlx::query(
69 "SELECT reason, requested_by, created_at
70 FROM sayiir_workflow_signals
71 WHERE instance_id = $1 AND kind = $2",
72 )
73 .bind(instance_id)
74 .bind(kind.as_ref())
75 .fetch_optional(&self.pool)
76 .await
77 .map_err(PgError)?;
78
79 Ok(row.map(|r| SignalRequest {
80 reason: r.get("reason"),
81 requested_by: r.get("requested_by"),
82 requested_at: r.get("created_at"),
83 }))
84 }
85
86 async fn clear_signal(&self, instance_id: &str, kind: SignalKind) -> Result<(), BackendError> {
87 tracing::debug!(instance_id, kind = %kind.as_ref(), "clearing signal");
88 sqlx::query("DELETE FROM sayiir_workflow_signals WHERE instance_id = $1 AND kind = $2")
89 .bind(instance_id)
90 .bind(kind.as_ref())
91 .execute(&self.pool)
92 .await
93 .map_err(PgError)?;
94 Ok(())
95 }
96
97 async fn send_event(
98 &self,
99 instance_id: &str,
100 signal_name: &str,
101 payload: bytes::Bytes,
102 ) -> Result<(), BackendError> {
103 tracing::debug!(instance_id, signal_name, "buffering external event");
104 sqlx::query(
105 "INSERT INTO sayiir_workflow_events (instance_id, signal_name, payload)
106 VALUES ($1, $2, $3)",
107 )
108 .bind(instance_id)
109 .bind(signal_name)
110 .bind(payload.as_ref())
111 .execute(&self.pool)
112 .await
113 .map_err(PgError)?;
114 Ok(())
115 }
116
117 async fn consume_event(
118 &self,
119 instance_id: &str,
120 signal_name: &str,
121 ) -> Result<Option<bytes::Bytes>, BackendError> {
122 tracing::debug!(instance_id, signal_name, "consuming oldest buffered event");
123 let row = sqlx::query(
125 "DELETE FROM sayiir_workflow_events
126 WHERE id = (
127 SELECT id FROM sayiir_workflow_events
128 WHERE instance_id = $1 AND signal_name = $2
129 ORDER BY id ASC
130 LIMIT 1
131 FOR UPDATE SKIP LOCKED
132 )
133 RETURNING payload",
134 )
135 .bind(instance_id)
136 .bind(signal_name)
137 .fetch_optional(&self.pool)
138 .await
139 .map_err(PgError)?;
140
141 Ok(row.map(|r| {
142 let raw: Vec<u8> = r.get("payload");
143 bytes::Bytes::from(raw)
144 }))
145 }
146
147 async fn check_and_cancel(
150 &self,
151 instance_id: &str,
152 interrupted_at_task: Option<&str>,
153 ) -> Result<bool, BackendError> {
154 tracing::debug!(instance_id, "checking for cancel signal");
155 let mut tx = self.pool.begin().await.map_err(PgError)?;
156
157 let signal_row = sqlx::query(
159 "SELECT reason, requested_by
160 FROM sayiir_workflow_signals
161 WHERE instance_id = $1 AND kind = $2
162 FOR UPDATE",
163 )
164 .bind(instance_id)
165 .bind(SignalKind::Cancel.as_ref())
166 .fetch_optional(&mut *tx)
167 .await
168 .map_err(PgError)?;
169
170 let Some(signal_row) = signal_row else {
171 tx.rollback().await.map_err(PgError)?;
172 return Ok(false);
173 };
174
175 let snap_row = sqlx::query(
177 "SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
178 )
179 .bind(instance_id)
180 .fetch_one(&mut *tx)
181 .await
182 .map_err(PgError)?;
183
184 let raw: &[u8] = snap_row.get("data");
185 let mut snapshot = self.decode(raw)?;
186
187 if !snapshot.state.is_in_progress() {
188 tx.rollback().await.map_err(PgError)?;
189 return Ok(false);
190 }
191
192 let reason: Option<String> = signal_row.get("reason");
193 let requested_by: Option<String> = signal_row.get("requested_by");
194 snapshot.mark_cancelled(reason, requested_by, interrupted_at_task.map(String::from));
195
196 let data = self.encode(&snapshot)?;
197 let status = status_str(&snapshot.state);
198 let error = error_message(&snapshot).map(ToString::to_string);
199 let pos_kind = position_kind(&snapshot);
200 let wake_at = delay_wake_at(&snapshot);
201
202 sqlx::query(
203 "UPDATE sayiir_workflow_snapshots
204 SET data = $1, status = $2, error = $3,
205 position_kind = $4, delay_wake_at = $5,
206 completed_at = now(), updated_at = now()
207 WHERE instance_id = $6",
208 )
209 .bind(&data)
210 .bind(status)
211 .bind(&error)
212 .bind(pos_kind)
213 .bind(wake_at)
214 .bind(instance_id)
215 .execute(&mut *tx)
216 .await
217 .map_err(PgError)?;
218
219 sqlx::query(
221 "UPDATE sayiir_workflow_tasks SET status = 'cancelled', completed_at = now()
222 WHERE instance_id = $1 AND status = 'active'",
223 )
224 .bind(instance_id)
225 .execute(&mut *tx)
226 .await
227 .map_err(PgError)?;
228
229 sqlx::query("DELETE FROM sayiir_workflow_signals WHERE instance_id = $1 AND kind = $2")
231 .bind(instance_id)
232 .bind(SignalKind::Cancel.as_ref())
233 .execute(&mut *tx)
234 .await
235 .map_err(PgError)?;
236
237 tx.commit().await.map_err(PgError)?;
238 tracing::info!(instance_id, "workflow cancelled");
239 Ok(true)
240 }
241
242 async fn check_and_pause(&self, instance_id: &str) -> Result<bool, BackendError> {
243 tracing::debug!(instance_id, "checking for pause signal");
244 let mut tx = self.pool.begin().await.map_err(PgError)?;
245
246 let signal_row = sqlx::query(
248 "SELECT reason, requested_by
249 FROM sayiir_workflow_signals
250 WHERE instance_id = $1 AND kind = $2
251 FOR UPDATE",
252 )
253 .bind(instance_id)
254 .bind(SignalKind::Pause.as_ref())
255 .fetch_optional(&mut *tx)
256 .await
257 .map_err(PgError)?;
258
259 let Some(signal_row) = signal_row else {
260 tx.rollback().await.map_err(PgError)?;
261 return Ok(false);
262 };
263
264 let snap_row = sqlx::query(
266 "SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
267 )
268 .bind(instance_id)
269 .fetch_one(&mut *tx)
270 .await
271 .map_err(PgError)?;
272
273 let raw: &[u8] = snap_row.get("data");
274 let mut snapshot = self.decode(raw)?;
275
276 if !snapshot.state.is_in_progress() {
277 tx.rollback().await.map_err(PgError)?;
278 return Ok(false);
279 }
280
281 let reason: Option<String> = signal_row.get("reason");
282 let requested_by: Option<String> = signal_row.get("requested_by");
283 let pause_request = PauseRequest::new(reason, requested_by);
284 snapshot.mark_paused(&pause_request);
285
286 let data = self.encode(&snapshot)?;
287 let status = status_str(&snapshot.state);
288 let task_id = current_task_id(&snapshot).map(ToString::to_string);
289 let task_count = completed_task_count(&snapshot);
290 let pos_kind = position_kind(&snapshot);
291 let wake_at = delay_wake_at(&snapshot);
292
293 sqlx::query(
294 "UPDATE sayiir_workflow_snapshots
295 SET data = $1, status = $2, current_task_id = $3,
296 completed_task_count = $4, position_kind = $5,
297 delay_wake_at = $6, updated_at = now()
298 WHERE instance_id = $7",
299 )
300 .bind(&data)
301 .bind(status)
302 .bind(&task_id)
303 .bind(task_count)
304 .bind(pos_kind)
305 .bind(wake_at)
306 .bind(instance_id)
307 .execute(&mut *tx)
308 .await
309 .map_err(PgError)?;
310
311 sqlx::query("DELETE FROM sayiir_workflow_signals WHERE instance_id = $1 AND kind = $2")
313 .bind(instance_id)
314 .bind(SignalKind::Pause.as_ref())
315 .execute(&mut *tx)
316 .await
317 .map_err(PgError)?;
318
319 tx.commit().await.map_err(PgError)?;
320 tracing::info!(instance_id, "workflow paused");
321 Ok(true)
322 }
323
324 async fn unpause(&self, instance_id: &str) -> Result<WorkflowSnapshot, BackendError> {
325 tracing::debug!(instance_id, "unpausing workflow");
326 let mut tx = self.pool.begin().await.map_err(PgError)?;
327
328 let row = sqlx::query(
329 "SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
330 )
331 .bind(instance_id)
332 .fetch_optional(&mut *tx)
333 .await
334 .map_err(PgError)?
335 .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
336
337 let raw: &[u8] = row.get("data");
338 let mut snapshot = self.decode(raw)?;
339
340 if !snapshot.state.is_paused() {
341 let state_name = status_str(&snapshot.state);
342 return Err(BackendError::CannotPause(format!(
343 "Workflow is not paused (current state: {state_name:?})"
344 )));
345 }
346
347 snapshot.mark_unpaused();
348
349 let data = self.encode(&snapshot)?;
350 let status = status_str(&snapshot.state);
351 let task_id = current_task_id(&snapshot).map(ToString::to_string);
352 let task_count = completed_task_count(&snapshot);
353 let pos_kind = position_kind(&snapshot);
354 let wake_at = delay_wake_at(&snapshot);
355
356 sqlx::query(
357 "UPDATE sayiir_workflow_snapshots
358 SET data = $1, status = $2, current_task_id = $3,
359 completed_task_count = $4, position_kind = $5,
360 delay_wake_at = $6, updated_at = now()
361 WHERE instance_id = $7",
362 )
363 .bind(&data)
364 .bind(status)
365 .bind(&task_id)
366 .bind(task_count)
367 .bind(pos_kind)
368 .bind(wake_at)
369 .bind(instance_id)
370 .execute(&mut *tx)
371 .await
372 .map_err(PgError)?;
373
374 tx.commit().await.map_err(PgError)?;
375 tracing::info!(instance_id, "workflow unpaused");
376 Ok(snapshot)
377 }
378}
379
380fn validate_signal_allowed(status: &str, kind: SignalKind) -> Result<(), BackendError> {
382 use std::str::FromStr;
383
384 let Ok(status) = SnapshotStatus::from_str(status) else {
385 return Ok(());
387 };
388
389 match kind {
390 SignalKind::Cancel => match status {
391 SnapshotStatus::Completed | SnapshotStatus::Failed => {
392 Err(BackendError::CannotCancel(status.as_ref().to_string()))
393 }
394 _ => Ok(()),
395 },
396 SignalKind::Pause => match status {
397 SnapshotStatus::Completed | SnapshotStatus::Failed | SnapshotStatus::Cancelled => {
398 Err(BackendError::CannotPause(status.as_ref().to_string()))
399 }
400 _ => Ok(()),
401 },
402 }
403}