1use sayiir_core::codec::{self, Decoder, Encoder};
4use sayiir_core::snapshot::{SnapshotStatus, WorkflowSnapshot};
5use sayiir_persistence::{BackendError, SnapshotStore};
6use sqlx::Row;
7
8use crate::backend::PostgresBackend;
9use crate::error::PgError;
10use crate::helpers::{
11 completed_task_count, current_task_id, delay_wake_at, error_message, is_terminal,
12 position_kind, status_str,
13};
14
15impl<C> SnapshotStore for PostgresBackend<C>
16where
17 C: Encoder
18 + Decoder
19 + codec::sealed::EncodeValue<WorkflowSnapshot>
20 + codec::sealed::DecodeValue<WorkflowSnapshot>,
21{
22 #[tracing::instrument(
23 name = "db.save_snapshot",
24 skip(self, snapshot),
25 fields(
26 db.system = "postgresql",
27 instance_id = %snapshot.instance_id,
28 status = %status_str(&snapshot.state),
29 ),
30 err(level = tracing::Level::ERROR),
31 )]
32 #[allow(clippy::too_many_lines)]
33 async fn save_snapshot(&self, snapshot: &WorkflowSnapshot) -> Result<(), BackendError> {
34 tracing::debug!("saving snapshot");
35 let data = self.encode(snapshot)?;
36 let status = status_str(&snapshot.state);
37 let task_id = current_task_id(snapshot).map(ToString::to_string);
38 let task_count = completed_task_count(snapshot);
39 let error = error_message(snapshot).map(ToString::to_string);
40 let terminal = is_terminal(snapshot);
41 let pos_kind = position_kind(snapshot);
42 let wake_at = delay_wake_at(snapshot);
43
44 let mut tx = self.pool.begin().await.map_err(PgError)?;
45
46 sqlx::query(
48 "INSERT INTO sayiir_workflow_snapshots
49 (instance_id, status, definition_hash, current_task_id,
50 completed_task_count, data, error, position_kind, delay_wake_at,
51 trace_parent, completed_at, updated_at)
52 VALUES ($1, $2, $3, $4, $5, $6, $7, $9, $10, $11,
53 CASE WHEN $8 THEN now() ELSE NULL END, now())
54 ON CONFLICT (instance_id) DO UPDATE SET
55 status = $2,
56 definition_hash = $3,
57 current_task_id = $4,
58 completed_task_count = $5,
59 data = $6,
60 error = $7,
61 position_kind = $9,
62 delay_wake_at = $10,
63 trace_parent = $11,
64 completed_at = CASE WHEN $8 THEN now() ELSE sayiir_workflow_snapshots.completed_at END,
65 updated_at = now()",
66 )
67 .bind(&snapshot.instance_id) .bind(status) .bind(&snapshot.definition_hash) .bind(&task_id) .bind(task_count) .bind(&data) .bind(&error) .bind(terminal) .bind(pos_kind) .bind(wake_at) .bind(snapshot.trace_parent.as_deref()) .execute(&mut *tx)
79 .await
80 .map_err(PgError)?;
81
82 sqlx::query(
84 "INSERT INTO sayiir_workflow_snapshot_history
85 (instance_id, version, status, current_task_id, data)
86 VALUES (
87 $1,
88 (SELECT COALESCE(MAX(version), 0) + 1
89 FROM sayiir_workflow_snapshot_history WHERE instance_id = $1),
90 $2, $3, $4
91 )",
92 )
93 .bind(&snapshot.instance_id)
94 .bind(status)
95 .bind(&task_id)
96 .bind(&data)
97 .execute(&mut *tx)
98 .await
99 .map_err(PgError)?;
100
101 if let Some(ref tid) = task_id {
105 sqlx::query(
106 "INSERT INTO sayiir_workflow_tasks (instance_id, task_id, status, started_at)
107 VALUES ($1, $2, 'active', now())
108 ON CONFLICT (instance_id, task_id) DO UPDATE SET
109 status = CASE
110 WHEN sayiir_workflow_tasks.status = 'completed' THEN sayiir_workflow_tasks.status
111 ELSE 'active'
112 END,
113 started_at = COALESCE(sayiir_workflow_tasks.started_at, now())",
114 )
115 .bind(&snapshot.instance_id)
116 .bind(tid)
117 .execute(&mut *tx)
118 .await
119 .map_err(PgError)?;
120 }
121
122 if terminal {
124 let terminal_status = match SnapshotStatus::from(&snapshot.state) {
125 SnapshotStatus::Failed => "failed",
126 SnapshotStatus::Cancelled => "cancelled",
127 _ => "completed",
128 };
129 sqlx::query(
130 "UPDATE sayiir_workflow_tasks SET status = $1, completed_at = now(), error = $2
131 WHERE instance_id = $3 AND status = 'active'",
132 )
133 .bind(terminal_status)
134 .bind(&error)
135 .bind(&snapshot.instance_id)
136 .execute(&mut *tx)
137 .await
138 .map_err(PgError)?;
139 }
140
141 tx.commit().await.map_err(PgError)?;
142 Ok(())
143 }
144
145 #[tracing::instrument(
146 name = "db.save_task_result",
147 skip(self, output),
148 fields(db.system = "postgresql"),
149 err(level = tracing::Level::ERROR),
150 )]
151 async fn save_task_result(
152 &self,
153 instance_id: &str,
154 task_id: &str,
155 output: bytes::Bytes,
156 ) -> Result<(), BackendError> {
157 tracing::debug!("saving task result");
158 let mut tx = self.pool.begin().await.map_err(PgError)?;
159
160 let row = sqlx::query(
162 "SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
163 )
164 .bind(instance_id)
165 .fetch_optional(&mut *tx)
166 .await
167 .map_err(PgError)?
168 .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
169
170 let raw: &[u8] = row.get("data");
171 let mut snapshot = self.decode(raw)?;
172 snapshot.mark_task_completed(task_id.to_string(), output);
173
174 let data = self.encode(&snapshot)?;
175 let status = status_str(&snapshot.state);
176 let current = current_task_id(&snapshot).map(ToString::to_string);
177 let task_count = completed_task_count(&snapshot);
178
179 sqlx::query(
180 "UPDATE sayiir_workflow_snapshots
181 SET data = $1, status = $2, current_task_id = $3,
182 completed_task_count = $4, updated_at = now()
183 WHERE instance_id = $5",
184 )
185 .bind(&data)
186 .bind(status)
187 .bind(¤t)
188 .bind(task_count)
189 .bind(instance_id)
190 .execute(&mut *tx)
191 .await
192 .map_err(PgError)?;
193
194 sqlx::query(
196 "INSERT INTO sayiir_workflow_tasks (instance_id, task_id, status, completed_at)
197 VALUES ($1, $2, 'completed', now())
198 ON CONFLICT (instance_id, task_id) DO UPDATE SET
199 status = 'completed', completed_at = now(), error = NULL",
200 )
201 .bind(instance_id)
202 .bind(task_id)
203 .execute(&mut *tx)
204 .await
205 .map_err(PgError)?;
206
207 tx.commit().await.map_err(PgError)?;
208 Ok(())
209 }
210
211 #[tracing::instrument(
212 name = "db.load_snapshot",
213 skip(self),
214 fields(db.system = "postgresql"),
215 err(level = tracing::Level::ERROR),
216 )]
217 async fn load_snapshot(&self, instance_id: &str) -> Result<WorkflowSnapshot, BackendError> {
218 tracing::debug!("loading snapshot");
219 let row = sqlx::query(
220 "SELECT data, trace_parent FROM sayiir_workflow_snapshots WHERE instance_id = $1",
221 )
222 .bind(instance_id)
223 .fetch_optional(&self.pool)
224 .await
225 .map_err(PgError)?
226 .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
227
228 let raw: &[u8] = row.get("data");
229 let mut snapshot = self.decode(raw)?;
230 snapshot.trace_parent = row.get("trace_parent");
231 Ok(snapshot)
232 }
233
234 #[tracing::instrument(
235 name = "db.delete_snapshot",
236 skip(self),
237 fields(db.system = "postgresql"),
238 err(level = tracing::Level::ERROR),
239 )]
240 async fn delete_snapshot(&self, instance_id: &str) -> Result<(), BackendError> {
241 tracing::debug!("deleting snapshot");
242 let result = sqlx::query("DELETE FROM sayiir_workflow_snapshots WHERE instance_id = $1")
243 .bind(instance_id)
244 .execute(&self.pool)
245 .await
246 .map_err(PgError)?;
247
248 if result.rows_affected() == 0 {
249 return Err(BackendError::NotFound(instance_id.to_string()));
250 }
251 Ok(())
252 }
253
254 #[tracing::instrument(
255 name = "db.list_snapshots",
256 skip(self),
257 fields(db.system = "postgresql"),
258 err(level = tracing::Level::ERROR),
259 )]
260 async fn list_snapshots(&self) -> Result<Vec<String>, BackendError> {
261 tracing::debug!("listing snapshots");
262 let rows = sqlx::query("SELECT instance_id FROM sayiir_workflow_snapshots")
263 .fetch_all(&self.pool)
264 .await
265 .map_err(PgError)?;
266
267 Ok(rows.iter().map(|r| r.get("instance_id")).collect())
268 }
269}