1use sayiir_core::codec::{self, Decoder, Encoder};
7use sayiir_core::snapshot::{PauseRequest, SignalKind, SignalRequest, WorkflowSnapshot};
8use sayiir_persistence::validation::validate_signal_allowed;
9use sayiir_persistence::{BackendError, SignalStore};
10use sqlx::Row;
11
12use crate::backend::PostgresBackend;
13use crate::error::PgError;
14
15impl<C> SignalStore for PostgresBackend<C>
16where
17 C: Encoder
18 + Decoder
19 + codec::sealed::EncodeValue<WorkflowSnapshot>
20 + codec::sealed::DecodeValue<WorkflowSnapshot>,
21{
22 #[tracing::instrument(
23 name = "db.store_signal",
24 skip(self, request),
25 fields(db.system = "postgresql", kind = %kind.as_ref()),
26 err(level = tracing::Level::ERROR),
27 )]
28 async fn store_signal(
29 &self,
30 instance_id: &str,
31 kind: SignalKind,
32 request: SignalRequest,
33 ) -> Result<(), BackendError> {
34 tracing::debug!("storing signal");
35 let row =
37 sqlx::query("SELECT status FROM sayiir_workflow_snapshots WHERE instance_id = $1")
38 .bind(instance_id)
39 .fetch_optional(&self.pool)
40 .await
41 .map_err(PgError)?
42 .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
43
44 let status: String = row.get("status");
45 validate_signal_allowed(&status, kind)?;
46
47 sqlx::query(
48 "INSERT INTO sayiir_workflow_signals (instance_id, kind, reason, requested_by)
49 VALUES ($1, $2, $3, $4)
50 ON CONFLICT (instance_id, kind) DO UPDATE SET
51 reason = $3, requested_by = $4, created_at = now()",
52 )
53 .bind(instance_id)
54 .bind(kind.as_ref())
55 .bind(&request.reason)
56 .bind(&request.requested_by)
57 .execute(&self.pool)
58 .await
59 .map_err(PgError)?;
60
61 Ok(())
62 }
63
64 #[tracing::instrument(
65 name = "db.get_signal",
66 skip(self),
67 fields(db.system = "postgresql", kind = %kind.as_ref()),
68 err(level = tracing::Level::ERROR),
69 )]
70 async fn get_signal(
71 &self,
72 instance_id: &str,
73 kind: SignalKind,
74 ) -> Result<Option<SignalRequest>, BackendError> {
75 tracing::debug!("getting signal");
76 let row = sqlx::query(
77 "SELECT reason, requested_by, created_at
78 FROM sayiir_workflow_signals
79 WHERE instance_id = $1 AND kind = $2",
80 )
81 .bind(instance_id)
82 .bind(kind.as_ref())
83 .fetch_optional(&self.pool)
84 .await
85 .map_err(PgError)?;
86
87 Ok(row.map(|r| SignalRequest {
88 reason: r.get("reason"),
89 requested_by: r.get("requested_by"),
90 requested_at: r.get("created_at"),
91 }))
92 }
93
94 #[tracing::instrument(
95 name = "db.clear_signal",
96 skip(self),
97 fields(db.system = "postgresql", kind = %kind.as_ref()),
98 err(level = tracing::Level::ERROR),
99 )]
100 async fn clear_signal(&self, instance_id: &str, kind: SignalKind) -> Result<(), BackendError> {
101 tracing::debug!("clearing signal");
102 sqlx::query("DELETE FROM sayiir_workflow_signals WHERE instance_id = $1 AND kind = $2")
103 .bind(instance_id)
104 .bind(kind.as_ref())
105 .execute(&self.pool)
106 .await
107 .map_err(PgError)?;
108 Ok(())
109 }
110
111 #[tracing::instrument(
112 name = "db.send_event",
113 skip(self, payload),
114 fields(db.system = "postgresql"),
115 err(level = tracing::Level::ERROR),
116 )]
117 async fn send_event(
118 &self,
119 instance_id: &str,
120 signal_name: &str,
121 payload: bytes::Bytes,
122 ) -> Result<(), BackendError> {
123 tracing::debug!("buffering external event");
124 sqlx::query(
125 "INSERT INTO sayiir_workflow_events (instance_id, signal_name, payload)
126 VALUES ($1, $2, $3)",
127 )
128 .bind(instance_id)
129 .bind(signal_name)
130 .bind(payload.as_ref())
131 .execute(&self.pool)
132 .await
133 .map_err(PgError)?;
134 Ok(())
135 }
136
137 #[tracing::instrument(
138 name = "db.consume_event",
139 skip(self),
140 fields(db.system = "postgresql"),
141 err(level = tracing::Level::ERROR),
142 )]
143 async fn consume_event(
144 &self,
145 instance_id: &str,
146 signal_name: &str,
147 ) -> Result<Option<bytes::Bytes>, BackendError> {
148 tracing::debug!("consuming oldest buffered event");
149 let row = sqlx::query(
151 "DELETE FROM sayiir_workflow_events
152 WHERE id = (
153 SELECT id FROM sayiir_workflow_events
154 WHERE instance_id = $1 AND signal_name = $2
155 ORDER BY id ASC
156 LIMIT 1
157 FOR UPDATE SKIP LOCKED
158 )
159 RETURNING payload",
160 )
161 .bind(instance_id)
162 .bind(signal_name)
163 .fetch_optional(&self.pool)
164 .await
165 .map_err(PgError)?;
166
167 Ok(row.map(|r| {
168 let raw: Vec<u8> = r.get("payload");
169 bytes::Bytes::from(raw)
170 }))
171 }
172
173 #[tracing::instrument(
176 name = "db.check_and_cancel",
177 skip(self),
178 fields(db.system = "postgresql"),
179 err(level = tracing::Level::ERROR),
180 )]
181 async fn check_and_cancel(
182 &self,
183 instance_id: &str,
184 interrupted_at_task: Option<&str>,
185 ) -> Result<bool, BackendError> {
186 tracing::debug!("checking for cancel signal");
187 let mut tx = self.pool.begin().await.map_err(PgError)?;
188
189 let signal_row = sqlx::query(
191 "SELECT reason, requested_by
192 FROM sayiir_workflow_signals
193 WHERE instance_id = $1 AND kind = $2
194 FOR UPDATE",
195 )
196 .bind(instance_id)
197 .bind(SignalKind::Cancel.as_ref())
198 .fetch_optional(&mut *tx)
199 .await
200 .map_err(PgError)?;
201
202 let Some(signal_row) = signal_row else {
203 tx.rollback().await.map_err(PgError)?;
204 return Ok(false);
205 };
206
207 let snap_row = sqlx::query(
209 "SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
210 )
211 .bind(instance_id)
212 .fetch_one(&mut *tx)
213 .await
214 .map_err(PgError)?;
215
216 let raw: &[u8] = snap_row.get("data");
217 let mut snapshot = self.decode(raw)?;
218
219 if !snapshot.state.is_in_progress() {
220 tx.rollback().await.map_err(PgError)?;
221 return Ok(false);
222 }
223
224 let reason: Option<String> = signal_row.get("reason");
225 let requested_by: Option<String> = signal_row.get("requested_by");
226 snapshot.mark_cancelled(reason, requested_by, interrupted_at_task.map(String::from));
227
228 let data = self.encode(&snapshot)?;
229 let status = snapshot.state.as_ref();
230 let error = snapshot.error_message().map(ToString::to_string);
231 let pos_kind = snapshot.position_kind();
232 let wake_at = snapshot.delay_wake_at();
233
234 sqlx::query(
235 "UPDATE sayiir_workflow_snapshots
236 SET data = $1, status = $2, error = $3,
237 position_kind = $4, delay_wake_at = $5,
238 completed_at = now(), updated_at = now()
239 WHERE instance_id = $6",
240 )
241 .bind(&data)
242 .bind(status)
243 .bind(&error)
244 .bind(pos_kind)
245 .bind(wake_at)
246 .bind(instance_id)
247 .execute(&mut *tx)
248 .await
249 .map_err(PgError)?;
250
251 sqlx::query(
253 "UPDATE sayiir_workflow_tasks SET status = 'cancelled', completed_at = now()
254 WHERE instance_id = $1 AND status = 'active'",
255 )
256 .bind(instance_id)
257 .execute(&mut *tx)
258 .await
259 .map_err(PgError)?;
260
261 sqlx::query("DELETE FROM sayiir_workflow_signals WHERE instance_id = $1 AND kind = $2")
263 .bind(instance_id)
264 .bind(SignalKind::Cancel.as_ref())
265 .execute(&mut *tx)
266 .await
267 .map_err(PgError)?;
268
269 tx.commit().await.map_err(PgError)?;
270 tracing::info!(instance_id, "workflow cancelled");
271 Ok(true)
272 }
273
274 #[tracing::instrument(
275 name = "db.check_and_pause",
276 skip(self),
277 fields(db.system = "postgresql"),
278 err(level = tracing::Level::ERROR),
279 )]
280 async fn check_and_pause(&self, instance_id: &str) -> Result<bool, BackendError> {
281 tracing::debug!("checking for pause signal");
282 let mut tx = self.pool.begin().await.map_err(PgError)?;
283
284 let signal_row = sqlx::query(
286 "SELECT reason, requested_by
287 FROM sayiir_workflow_signals
288 WHERE instance_id = $1 AND kind = $2
289 FOR UPDATE",
290 )
291 .bind(instance_id)
292 .bind(SignalKind::Pause.as_ref())
293 .fetch_optional(&mut *tx)
294 .await
295 .map_err(PgError)?;
296
297 let Some(signal_row) = signal_row else {
298 tx.rollback().await.map_err(PgError)?;
299 return Ok(false);
300 };
301
302 let snap_row = sqlx::query(
304 "SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
305 )
306 .bind(instance_id)
307 .fetch_one(&mut *tx)
308 .await
309 .map_err(PgError)?;
310
311 let raw: &[u8] = snap_row.get("data");
312 let mut snapshot = self.decode(raw)?;
313
314 if !snapshot.state.is_in_progress() {
315 tx.rollback().await.map_err(PgError)?;
316 return Ok(false);
317 }
318
319 let reason: Option<String> = signal_row.get("reason");
320 let requested_by: Option<String> = signal_row.get("requested_by");
321 let pause_request = PauseRequest::new(reason, requested_by);
322 snapshot.mark_paused(&pause_request);
323
324 let data = self.encode(&snapshot)?;
325 let status = snapshot.state.as_ref();
326 let task_id = snapshot.current_task_id().map(ToString::to_string);
327 let task_count = snapshot.completed_task_count();
328 let pos_kind = snapshot.position_kind();
329 let wake_at = snapshot.delay_wake_at();
330
331 sqlx::query(
332 "UPDATE sayiir_workflow_snapshots
333 SET data = $1, status = $2, current_task_id = $3,
334 completed_task_count = $4, position_kind = $5,
335 delay_wake_at = $6, updated_at = now()
336 WHERE instance_id = $7",
337 )
338 .bind(&data)
339 .bind(status)
340 .bind(&task_id)
341 .bind(task_count)
342 .bind(pos_kind)
343 .bind(wake_at)
344 .bind(instance_id)
345 .execute(&mut *tx)
346 .await
347 .map_err(PgError)?;
348
349 sqlx::query("DELETE FROM sayiir_workflow_signals WHERE instance_id = $1 AND kind = $2")
351 .bind(instance_id)
352 .bind(SignalKind::Pause.as_ref())
353 .execute(&mut *tx)
354 .await
355 .map_err(PgError)?;
356
357 tx.commit().await.map_err(PgError)?;
358 tracing::info!(instance_id, "workflow paused");
359 Ok(true)
360 }
361
362 #[tracing::instrument(
363 name = "db.unpause",
364 skip(self),
365 fields(db.system = "postgresql"),
366 err(level = tracing::Level::ERROR),
367 )]
368 async fn unpause(&self, instance_id: &str) -> Result<WorkflowSnapshot, BackendError> {
369 tracing::debug!("unpausing workflow");
370 let mut tx = self.pool.begin().await.map_err(PgError)?;
371
372 let row = sqlx::query(
373 "SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
374 )
375 .bind(instance_id)
376 .fetch_optional(&mut *tx)
377 .await
378 .map_err(PgError)?
379 .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
380
381 let raw: &[u8] = row.get("data");
382 let mut snapshot = self.decode(raw)?;
383
384 if !snapshot.state.is_paused() {
385 let state_name = snapshot.state.as_ref();
386 return Err(BackendError::CannotPause(format!(
387 "Workflow is not paused (current state: {state_name:?})"
388 )));
389 }
390
391 snapshot.mark_unpaused();
392
393 let data = self.encode(&snapshot)?;
394 let status = snapshot.state.as_ref();
395 let task_id = snapshot.current_task_id().map(ToString::to_string);
396 let task_count = snapshot.completed_task_count();
397 let pos_kind = snapshot.position_kind();
398 let wake_at = snapshot.delay_wake_at();
399
400 sqlx::query(
401 "UPDATE sayiir_workflow_snapshots
402 SET data = $1, status = $2, current_task_id = $3,
403 completed_task_count = $4, position_kind = $5,
404 delay_wake_at = $6, updated_at = now()
405 WHERE instance_id = $7",
406 )
407 .bind(&data)
408 .bind(status)
409 .bind(&task_id)
410 .bind(task_count)
411 .bind(pos_kind)
412 .bind(wake_at)
413 .bind(instance_id)
414 .execute(&mut *tx)
415 .await
416 .map_err(PgError)?;
417
418 tx.commit().await.map_err(PgError)?;
419 tracing::info!(instance_id, "workflow unpaused");
420 Ok(snapshot)
421 }
422}