1use sayiir_core::codec::{self, Decoder, Encoder};
7use sayiir_core::snapshot::{
8 PauseRequest, SignalKind, SignalRequest, SnapshotStatus, WorkflowSnapshot,
9};
10use sayiir_persistence::{BackendError, SignalStore};
11use sqlx::Row;
12
13use crate::backend::PostgresBackend;
14use crate::error::PgError;
15use crate::helpers::{
16 completed_task_count, current_task_id, delay_wake_at, error_message, position_kind, status_str,
17};
18
19impl<C> SignalStore for PostgresBackend<C>
20where
21 C: Encoder
22 + Decoder
23 + codec::sealed::EncodeValue<WorkflowSnapshot>
24 + codec::sealed::DecodeValue<WorkflowSnapshot>,
25{
26 async fn store_signal(
27 &self,
28 instance_id: &str,
29 kind: SignalKind,
30 request: SignalRequest,
31 ) -> Result<(), BackendError> {
32 tracing::debug!(instance_id, kind = %kind.as_ref(), "storing signal");
33 let row =
35 sqlx::query("SELECT status FROM sayiir_workflow_snapshots WHERE instance_id = $1")
36 .bind(instance_id)
37 .fetch_optional(&self.pool)
38 .await
39 .map_err(PgError)?
40 .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
41
42 let status: String = row.get("status");
43 validate_signal_allowed(&status, &kind)?;
44
45 sqlx::query(
46 "INSERT INTO sayiir_workflow_signals (instance_id, kind, reason, requested_by)
47 VALUES ($1, $2, $3, $4)
48 ON CONFLICT (instance_id, kind) DO UPDATE SET
49 reason = $3, requested_by = $4, created_at = now()",
50 )
51 .bind(instance_id)
52 .bind(kind.as_ref())
53 .bind(&request.reason)
54 .bind(&request.requested_by)
55 .execute(&self.pool)
56 .await
57 .map_err(PgError)?;
58
59 Ok(())
60 }
61
62 async fn get_signal(
63 &self,
64 instance_id: &str,
65 kind: SignalKind,
66 ) -> Result<Option<SignalRequest>, BackendError> {
67 tracing::debug!(instance_id, kind = %kind.as_ref(), "getting signal");
68 let row = sqlx::query(
69 "SELECT reason, requested_by, created_at
70 FROM sayiir_workflow_signals
71 WHERE instance_id = $1 AND kind = $2",
72 )
73 .bind(instance_id)
74 .bind(kind.as_ref())
75 .fetch_optional(&self.pool)
76 .await
77 .map_err(PgError)?;
78
79 Ok(row.map(|r| SignalRequest {
80 reason: r.get("reason"),
81 requested_by: r.get("requested_by"),
82 requested_at: r.get("created_at"),
83 }))
84 }
85
86 async fn clear_signal(&self, instance_id: &str, kind: SignalKind) -> Result<(), BackendError> {
87 tracing::debug!(instance_id, kind = %kind.as_ref(), "clearing signal");
88 sqlx::query("DELETE FROM sayiir_workflow_signals WHERE instance_id = $1 AND kind = $2")
89 .bind(instance_id)
90 .bind(kind.as_ref())
91 .execute(&self.pool)
92 .await
93 .map_err(PgError)?;
94 Ok(())
95 }
96
97 async fn check_and_cancel(
100 &self,
101 instance_id: &str,
102 interrupted_at_task: Option<&str>,
103 ) -> Result<bool, BackendError> {
104 tracing::debug!(instance_id, "checking for cancel signal");
105 let mut tx = self.pool.begin().await.map_err(PgError)?;
106
107 let signal_row = sqlx::query(
109 "SELECT reason, requested_by
110 FROM sayiir_workflow_signals
111 WHERE instance_id = $1 AND kind = $2
112 FOR UPDATE",
113 )
114 .bind(instance_id)
115 .bind(SignalKind::Cancel.as_ref())
116 .fetch_optional(&mut *tx)
117 .await
118 .map_err(PgError)?;
119
120 let Some(signal_row) = signal_row else {
121 tx.rollback().await.map_err(PgError)?;
122 return Ok(false);
123 };
124
125 let snap_row = sqlx::query(
127 "SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
128 )
129 .bind(instance_id)
130 .fetch_one(&mut *tx)
131 .await
132 .map_err(PgError)?;
133
134 let raw: &[u8] = snap_row.get("data");
135 let mut snapshot = self.decode(raw)?;
136
137 if !snapshot.state.is_in_progress() {
138 tx.rollback().await.map_err(PgError)?;
139 return Ok(false);
140 }
141
142 let reason: Option<String> = signal_row.get("reason");
143 let requested_by: Option<String> = signal_row.get("requested_by");
144 snapshot.mark_cancelled(reason, requested_by, interrupted_at_task.map(String::from));
145
146 let data = self.encode(&snapshot)?;
147 let status = status_str(&snapshot.state);
148 let error = error_message(&snapshot).map(ToString::to_string);
149 let pos_kind = position_kind(&snapshot);
150 let wake_at = delay_wake_at(&snapshot);
151
152 sqlx::query(
153 "UPDATE sayiir_workflow_snapshots
154 SET data = $1, status = $2, error = $3,
155 position_kind = $4, delay_wake_at = $5,
156 completed_at = now(), updated_at = now()
157 WHERE instance_id = $6",
158 )
159 .bind(&data)
160 .bind(status)
161 .bind(&error)
162 .bind(pos_kind)
163 .bind(wake_at)
164 .bind(instance_id)
165 .execute(&mut *tx)
166 .await
167 .map_err(PgError)?;
168
169 sqlx::query(
171 "UPDATE sayiir_workflow_tasks SET status = 'cancelled', completed_at = now()
172 WHERE instance_id = $1 AND status = 'active'",
173 )
174 .bind(instance_id)
175 .execute(&mut *tx)
176 .await
177 .map_err(PgError)?;
178
179 sqlx::query("DELETE FROM sayiir_workflow_signals WHERE instance_id = $1 AND kind = $2")
181 .bind(instance_id)
182 .bind(SignalKind::Cancel.as_ref())
183 .execute(&mut *tx)
184 .await
185 .map_err(PgError)?;
186
187 tx.commit().await.map_err(PgError)?;
188 tracing::info!(instance_id, "workflow cancelled");
189 Ok(true)
190 }
191
192 async fn check_and_pause(&self, instance_id: &str) -> Result<bool, BackendError> {
193 tracing::debug!(instance_id, "checking for pause signal");
194 let mut tx = self.pool.begin().await.map_err(PgError)?;
195
196 let signal_row = sqlx::query(
198 "SELECT reason, requested_by
199 FROM sayiir_workflow_signals
200 WHERE instance_id = $1 AND kind = $2
201 FOR UPDATE",
202 )
203 .bind(instance_id)
204 .bind(SignalKind::Pause.as_ref())
205 .fetch_optional(&mut *tx)
206 .await
207 .map_err(PgError)?;
208
209 let Some(signal_row) = signal_row else {
210 tx.rollback().await.map_err(PgError)?;
211 return Ok(false);
212 };
213
214 let snap_row = sqlx::query(
216 "SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
217 )
218 .bind(instance_id)
219 .fetch_one(&mut *tx)
220 .await
221 .map_err(PgError)?;
222
223 let raw: &[u8] = snap_row.get("data");
224 let mut snapshot = self.decode(raw)?;
225
226 if !snapshot.state.is_in_progress() {
227 tx.rollback().await.map_err(PgError)?;
228 return Ok(false);
229 }
230
231 let reason: Option<String> = signal_row.get("reason");
232 let requested_by: Option<String> = signal_row.get("requested_by");
233 let pause_request = PauseRequest::new(reason, requested_by);
234 snapshot.mark_paused(&pause_request);
235
236 let data = self.encode(&snapshot)?;
237 let status = status_str(&snapshot.state);
238 let task_id = current_task_id(&snapshot).map(ToString::to_string);
239 let task_count = completed_task_count(&snapshot);
240 let pos_kind = position_kind(&snapshot);
241 let wake_at = delay_wake_at(&snapshot);
242
243 sqlx::query(
244 "UPDATE sayiir_workflow_snapshots
245 SET data = $1, status = $2, current_task_id = $3,
246 completed_task_count = $4, position_kind = $5,
247 delay_wake_at = $6, updated_at = now()
248 WHERE instance_id = $7",
249 )
250 .bind(&data)
251 .bind(status)
252 .bind(&task_id)
253 .bind(task_count)
254 .bind(pos_kind)
255 .bind(wake_at)
256 .bind(instance_id)
257 .execute(&mut *tx)
258 .await
259 .map_err(PgError)?;
260
261 sqlx::query("DELETE FROM sayiir_workflow_signals WHERE instance_id = $1 AND kind = $2")
263 .bind(instance_id)
264 .bind(SignalKind::Pause.as_ref())
265 .execute(&mut *tx)
266 .await
267 .map_err(PgError)?;
268
269 tx.commit().await.map_err(PgError)?;
270 tracing::info!(instance_id, "workflow paused");
271 Ok(true)
272 }
273
274 async fn unpause(&self, instance_id: &str) -> Result<WorkflowSnapshot, BackendError> {
275 tracing::debug!(instance_id, "unpausing workflow");
276 let mut tx = self.pool.begin().await.map_err(PgError)?;
277
278 let row = sqlx::query(
279 "SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
280 )
281 .bind(instance_id)
282 .fetch_optional(&mut *tx)
283 .await
284 .map_err(PgError)?
285 .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
286
287 let raw: &[u8] = row.get("data");
288 let mut snapshot = self.decode(raw)?;
289
290 if !snapshot.state.is_paused() {
291 let state_name = status_str(&snapshot.state);
292 return Err(BackendError::CannotPause(format!(
293 "Workflow is not paused (current state: {state_name:?})"
294 )));
295 }
296
297 snapshot.mark_unpaused();
298
299 let data = self.encode(&snapshot)?;
300 let status = status_str(&snapshot.state);
301 let task_id = current_task_id(&snapshot).map(ToString::to_string);
302 let task_count = completed_task_count(&snapshot);
303 let pos_kind = position_kind(&snapshot);
304 let wake_at = delay_wake_at(&snapshot);
305
306 sqlx::query(
307 "UPDATE sayiir_workflow_snapshots
308 SET data = $1, status = $2, current_task_id = $3,
309 completed_task_count = $4, position_kind = $5,
310 delay_wake_at = $6, updated_at = now()
311 WHERE instance_id = $7",
312 )
313 .bind(&data)
314 .bind(status)
315 .bind(&task_id)
316 .bind(task_count)
317 .bind(pos_kind)
318 .bind(wake_at)
319 .bind(instance_id)
320 .execute(&mut *tx)
321 .await
322 .map_err(PgError)?;
323
324 tx.commit().await.map_err(PgError)?;
325 tracing::info!(instance_id, "workflow unpaused");
326 Ok(snapshot)
327 }
328}
329
330fn validate_signal_allowed(status: &str, kind: &SignalKind) -> Result<(), BackendError> {
332 use std::str::FromStr;
333
334 let Ok(status) = SnapshotStatus::from_str(status) else {
335 return Ok(());
337 };
338
339 match kind {
340 SignalKind::Cancel => match status {
341 SnapshotStatus::Completed | SnapshotStatus::Failed => {
342 Err(BackendError::CannotCancel(status.as_ref().to_string()))
343 }
344 _ => Ok(()),
345 },
346 SignalKind::Pause => match status {
347 SnapshotStatus::Completed | SnapshotStatus::Failed | SnapshotStatus::Cancelled => {
348 Err(BackendError::CannotPause(status.as_ref().to_string()))
349 }
350 _ => Ok(()),
351 },
352 }
353}