1use sayiir_core::codec::{self, Decoder, Encoder};
7use sayiir_core::snapshot::{
8 PauseRequest, SignalKind, SignalRequest, SnapshotStatus, WorkflowSnapshot,
9};
10use sayiir_persistence::{BackendError, SignalStore};
11use sqlx::Row;
12
13use crate::backend::PostgresBackend;
14use crate::error::PgError;
15use crate::helpers::{
16 completed_task_count, current_task_id, delay_wake_at, error_message, position_kind, status_str,
17};
18
19impl<C> SignalStore for PostgresBackend<C>
20where
21 C: Encoder
22 + Decoder
23 + codec::sealed::EncodeValue<WorkflowSnapshot>
24 + codec::sealed::DecodeValue<WorkflowSnapshot>,
25{
26 #[tracing::instrument(
27 name = "db.store_signal",
28 skip(self, request),
29 fields(db.system = "postgresql", kind = %kind.as_ref()),
30 err(level = tracing::Level::ERROR),
31 )]
32 async fn store_signal(
33 &self,
34 instance_id: &str,
35 kind: SignalKind,
36 request: SignalRequest,
37 ) -> Result<(), BackendError> {
38 tracing::debug!("storing signal");
39 let row =
41 sqlx::query("SELECT status FROM sayiir_workflow_snapshots WHERE instance_id = $1")
42 .bind(instance_id)
43 .fetch_optional(&self.pool)
44 .await
45 .map_err(PgError)?
46 .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
47
48 let status: String = row.get("status");
49 validate_signal_allowed(&status, kind)?;
50
51 sqlx::query(
52 "INSERT INTO sayiir_workflow_signals (instance_id, kind, reason, requested_by)
53 VALUES ($1, $2, $3, $4)
54 ON CONFLICT (instance_id, kind) DO UPDATE SET
55 reason = $3, requested_by = $4, created_at = now()",
56 )
57 .bind(instance_id)
58 .bind(kind.as_ref())
59 .bind(&request.reason)
60 .bind(&request.requested_by)
61 .execute(&self.pool)
62 .await
63 .map_err(PgError)?;
64
65 Ok(())
66 }
67
68 #[tracing::instrument(
69 name = "db.get_signal",
70 skip(self),
71 fields(db.system = "postgresql", kind = %kind.as_ref()),
72 err(level = tracing::Level::ERROR),
73 )]
74 async fn get_signal(
75 &self,
76 instance_id: &str,
77 kind: SignalKind,
78 ) -> Result<Option<SignalRequest>, BackendError> {
79 tracing::debug!("getting signal");
80 let row = sqlx::query(
81 "SELECT reason, requested_by, created_at
82 FROM sayiir_workflow_signals
83 WHERE instance_id = $1 AND kind = $2",
84 )
85 .bind(instance_id)
86 .bind(kind.as_ref())
87 .fetch_optional(&self.pool)
88 .await
89 .map_err(PgError)?;
90
91 Ok(row.map(|r| SignalRequest {
92 reason: r.get("reason"),
93 requested_by: r.get("requested_by"),
94 requested_at: r.get("created_at"),
95 }))
96 }
97
98 #[tracing::instrument(
99 name = "db.clear_signal",
100 skip(self),
101 fields(db.system = "postgresql", kind = %kind.as_ref()),
102 err(level = tracing::Level::ERROR),
103 )]
104 async fn clear_signal(&self, instance_id: &str, kind: SignalKind) -> Result<(), BackendError> {
105 tracing::debug!("clearing signal");
106 sqlx::query("DELETE FROM sayiir_workflow_signals WHERE instance_id = $1 AND kind = $2")
107 .bind(instance_id)
108 .bind(kind.as_ref())
109 .execute(&self.pool)
110 .await
111 .map_err(PgError)?;
112 Ok(())
113 }
114
115 #[tracing::instrument(
116 name = "db.send_event",
117 skip(self, payload),
118 fields(db.system = "postgresql"),
119 err(level = tracing::Level::ERROR),
120 )]
121 async fn send_event(
122 &self,
123 instance_id: &str,
124 signal_name: &str,
125 payload: bytes::Bytes,
126 ) -> Result<(), BackendError> {
127 tracing::debug!("buffering external event");
128 sqlx::query(
129 "INSERT INTO sayiir_workflow_events (instance_id, signal_name, payload)
130 VALUES ($1, $2, $3)",
131 )
132 .bind(instance_id)
133 .bind(signal_name)
134 .bind(payload.as_ref())
135 .execute(&self.pool)
136 .await
137 .map_err(PgError)?;
138 Ok(())
139 }
140
141 #[tracing::instrument(
142 name = "db.consume_event",
143 skip(self),
144 fields(db.system = "postgresql"),
145 err(level = tracing::Level::ERROR),
146 )]
147 async fn consume_event(
148 &self,
149 instance_id: &str,
150 signal_name: &str,
151 ) -> Result<Option<bytes::Bytes>, BackendError> {
152 tracing::debug!("consuming oldest buffered event");
153 let row = sqlx::query(
155 "DELETE FROM sayiir_workflow_events
156 WHERE id = (
157 SELECT id FROM sayiir_workflow_events
158 WHERE instance_id = $1 AND signal_name = $2
159 ORDER BY id ASC
160 LIMIT 1
161 FOR UPDATE SKIP LOCKED
162 )
163 RETURNING payload",
164 )
165 .bind(instance_id)
166 .bind(signal_name)
167 .fetch_optional(&self.pool)
168 .await
169 .map_err(PgError)?;
170
171 Ok(row.map(|r| {
172 let raw: Vec<u8> = r.get("payload");
173 bytes::Bytes::from(raw)
174 }))
175 }
176
177 #[tracing::instrument(
180 name = "db.check_and_cancel",
181 skip(self),
182 fields(db.system = "postgresql"),
183 err(level = tracing::Level::ERROR),
184 )]
185 async fn check_and_cancel(
186 &self,
187 instance_id: &str,
188 interrupted_at_task: Option<&str>,
189 ) -> Result<bool, BackendError> {
190 tracing::debug!("checking for cancel signal");
191 let mut tx = self.pool.begin().await.map_err(PgError)?;
192
193 let signal_row = sqlx::query(
195 "SELECT reason, requested_by
196 FROM sayiir_workflow_signals
197 WHERE instance_id = $1 AND kind = $2
198 FOR UPDATE",
199 )
200 .bind(instance_id)
201 .bind(SignalKind::Cancel.as_ref())
202 .fetch_optional(&mut *tx)
203 .await
204 .map_err(PgError)?;
205
206 let Some(signal_row) = signal_row else {
207 tx.rollback().await.map_err(PgError)?;
208 return Ok(false);
209 };
210
211 let snap_row = sqlx::query(
213 "SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
214 )
215 .bind(instance_id)
216 .fetch_one(&mut *tx)
217 .await
218 .map_err(PgError)?;
219
220 let raw: &[u8] = snap_row.get("data");
221 let mut snapshot = self.decode(raw)?;
222
223 if !snapshot.state.is_in_progress() {
224 tx.rollback().await.map_err(PgError)?;
225 return Ok(false);
226 }
227
228 let reason: Option<String> = signal_row.get("reason");
229 let requested_by: Option<String> = signal_row.get("requested_by");
230 snapshot.mark_cancelled(reason, requested_by, interrupted_at_task.map(String::from));
231
232 let data = self.encode(&snapshot)?;
233 let status = status_str(&snapshot.state);
234 let error = error_message(&snapshot).map(ToString::to_string);
235 let pos_kind = position_kind(&snapshot);
236 let wake_at = delay_wake_at(&snapshot);
237
238 sqlx::query(
239 "UPDATE sayiir_workflow_snapshots
240 SET data = $1, status = $2, error = $3,
241 position_kind = $4, delay_wake_at = $5,
242 completed_at = now(), updated_at = now()
243 WHERE instance_id = $6",
244 )
245 .bind(&data)
246 .bind(status)
247 .bind(&error)
248 .bind(pos_kind)
249 .bind(wake_at)
250 .bind(instance_id)
251 .execute(&mut *tx)
252 .await
253 .map_err(PgError)?;
254
255 sqlx::query(
257 "UPDATE sayiir_workflow_tasks SET status = 'cancelled', completed_at = now()
258 WHERE instance_id = $1 AND status = 'active'",
259 )
260 .bind(instance_id)
261 .execute(&mut *tx)
262 .await
263 .map_err(PgError)?;
264
265 sqlx::query("DELETE FROM sayiir_workflow_signals WHERE instance_id = $1 AND kind = $2")
267 .bind(instance_id)
268 .bind(SignalKind::Cancel.as_ref())
269 .execute(&mut *tx)
270 .await
271 .map_err(PgError)?;
272
273 tx.commit().await.map_err(PgError)?;
274 tracing::info!(instance_id, "workflow cancelled");
275 Ok(true)
276 }
277
278 #[tracing::instrument(
279 name = "db.check_and_pause",
280 skip(self),
281 fields(db.system = "postgresql"),
282 err(level = tracing::Level::ERROR),
283 )]
284 async fn check_and_pause(&self, instance_id: &str) -> Result<bool, BackendError> {
285 tracing::debug!("checking for pause signal");
286 let mut tx = self.pool.begin().await.map_err(PgError)?;
287
288 let signal_row = sqlx::query(
290 "SELECT reason, requested_by
291 FROM sayiir_workflow_signals
292 WHERE instance_id = $1 AND kind = $2
293 FOR UPDATE",
294 )
295 .bind(instance_id)
296 .bind(SignalKind::Pause.as_ref())
297 .fetch_optional(&mut *tx)
298 .await
299 .map_err(PgError)?;
300
301 let Some(signal_row) = signal_row else {
302 tx.rollback().await.map_err(PgError)?;
303 return Ok(false);
304 };
305
306 let snap_row = sqlx::query(
308 "SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
309 )
310 .bind(instance_id)
311 .fetch_one(&mut *tx)
312 .await
313 .map_err(PgError)?;
314
315 let raw: &[u8] = snap_row.get("data");
316 let mut snapshot = self.decode(raw)?;
317
318 if !snapshot.state.is_in_progress() {
319 tx.rollback().await.map_err(PgError)?;
320 return Ok(false);
321 }
322
323 let reason: Option<String> = signal_row.get("reason");
324 let requested_by: Option<String> = signal_row.get("requested_by");
325 let pause_request = PauseRequest::new(reason, requested_by);
326 snapshot.mark_paused(&pause_request);
327
328 let data = self.encode(&snapshot)?;
329 let status = status_str(&snapshot.state);
330 let task_id = current_task_id(&snapshot).map(ToString::to_string);
331 let task_count = completed_task_count(&snapshot);
332 let pos_kind = position_kind(&snapshot);
333 let wake_at = delay_wake_at(&snapshot);
334
335 sqlx::query(
336 "UPDATE sayiir_workflow_snapshots
337 SET data = $1, status = $2, current_task_id = $3,
338 completed_task_count = $4, position_kind = $5,
339 delay_wake_at = $6, updated_at = now()
340 WHERE instance_id = $7",
341 )
342 .bind(&data)
343 .bind(status)
344 .bind(&task_id)
345 .bind(task_count)
346 .bind(pos_kind)
347 .bind(wake_at)
348 .bind(instance_id)
349 .execute(&mut *tx)
350 .await
351 .map_err(PgError)?;
352
353 sqlx::query("DELETE FROM sayiir_workflow_signals WHERE instance_id = $1 AND kind = $2")
355 .bind(instance_id)
356 .bind(SignalKind::Pause.as_ref())
357 .execute(&mut *tx)
358 .await
359 .map_err(PgError)?;
360
361 tx.commit().await.map_err(PgError)?;
362 tracing::info!(instance_id, "workflow paused");
363 Ok(true)
364 }
365
366 #[tracing::instrument(
367 name = "db.unpause",
368 skip(self),
369 fields(db.system = "postgresql"),
370 err(level = tracing::Level::ERROR),
371 )]
372 async fn unpause(&self, instance_id: &str) -> Result<WorkflowSnapshot, BackendError> {
373 tracing::debug!("unpausing workflow");
374 let mut tx = self.pool.begin().await.map_err(PgError)?;
375
376 let row = sqlx::query(
377 "SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
378 )
379 .bind(instance_id)
380 .fetch_optional(&mut *tx)
381 .await
382 .map_err(PgError)?
383 .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
384
385 let raw: &[u8] = row.get("data");
386 let mut snapshot = self.decode(raw)?;
387
388 if !snapshot.state.is_paused() {
389 let state_name = status_str(&snapshot.state);
390 return Err(BackendError::CannotPause(format!(
391 "Workflow is not paused (current state: {state_name:?})"
392 )));
393 }
394
395 snapshot.mark_unpaused();
396
397 let data = self.encode(&snapshot)?;
398 let status = status_str(&snapshot.state);
399 let task_id = current_task_id(&snapshot).map(ToString::to_string);
400 let task_count = completed_task_count(&snapshot);
401 let pos_kind = position_kind(&snapshot);
402 let wake_at = delay_wake_at(&snapshot);
403
404 sqlx::query(
405 "UPDATE sayiir_workflow_snapshots
406 SET data = $1, status = $2, current_task_id = $3,
407 completed_task_count = $4, position_kind = $5,
408 delay_wake_at = $6, updated_at = now()
409 WHERE instance_id = $7",
410 )
411 .bind(&data)
412 .bind(status)
413 .bind(&task_id)
414 .bind(task_count)
415 .bind(pos_kind)
416 .bind(wake_at)
417 .bind(instance_id)
418 .execute(&mut *tx)
419 .await
420 .map_err(PgError)?;
421
422 tx.commit().await.map_err(PgError)?;
423 tracing::info!(instance_id, "workflow unpaused");
424 Ok(snapshot)
425 }
426}
427
428fn validate_signal_allowed(status: &str, kind: SignalKind) -> Result<(), BackendError> {
430 use std::str::FromStr;
431
432 let Ok(status) = SnapshotStatus::from_str(status) else {
433 return Ok(());
435 };
436
437 match kind {
438 SignalKind::Cancel => match status {
439 SnapshotStatus::Completed | SnapshotStatus::Failed => {
440 Err(BackendError::CannotCancel(status.as_ref().to_string()))
441 }
442 _ => Ok(()),
443 },
444 SignalKind::Pause => match status {
445 SnapshotStatus::Completed | SnapshotStatus::Failed | SnapshotStatus::Cancelled => {
446 Err(BackendError::CannotPause(status.as_ref().to_string()))
447 }
448 _ => Ok(()),
449 },
450 }
451}