1use sayiir_core::codec::{self, Decoder, Encoder};
4use sayiir_core::snapshot::{SnapshotStatus, WorkflowSnapshot};
5use sayiir_persistence::{BackendError, SnapshotStore};
6use sqlx::Row;
7
8use crate::backend::PostgresBackend;
9use crate::error::PgError;
10
11impl<C> SnapshotStore for PostgresBackend<C>
12where
13 C: Encoder
14 + Decoder
15 + codec::sealed::EncodeValue<WorkflowSnapshot>
16 + codec::sealed::DecodeValue<WorkflowSnapshot>,
17{
18 #[tracing::instrument(
19 name = "db.save_snapshot",
20 skip(self, snapshot),
21 fields(
22 db.system = "postgresql",
23 instance_id = %snapshot.instance_id,
24 status = %snapshot.state.as_ref(),
25 ),
26 err(level = tracing::Level::ERROR),
27 )]
28 #[allow(clippy::too_many_lines)]
29 async fn save_snapshot(&self, snapshot: &WorkflowSnapshot) -> Result<(), BackendError> {
30 tracing::debug!("saving snapshot");
31 let data = self.encode(snapshot)?;
32 let status = snapshot.state.as_ref();
33 let task_id = snapshot.current_task_id().map(ToString::to_string);
34 let task_count = snapshot.completed_task_count();
35 let error = snapshot.error_message().map(ToString::to_string);
36 let terminal = snapshot.state.is_terminal();
37 let pos_kind = snapshot.position_kind();
38 let wake_at = snapshot.delay_wake_at();
39
40 let task_priority = i16::from(snapshot.current_task_priority());
41 let task_tags: Vec<&str> = snapshot
42 .current_task_tags()
43 .iter()
44 .map(String::as_str)
45 .collect();
46
47 let mut tx = self.pool.begin().await.map_err(PgError)?;
48
49 sqlx::query(
51 "INSERT INTO sayiir_workflow_snapshots
52 (instance_id, status, definition_hash, current_task_id,
53 completed_task_count, data, error, position_kind, delay_wake_at,
54 trace_parent, task_priority, task_tags, completed_at, updated_at)
55 VALUES ($1, $2, $3, $4, $5, $6, $7, $9, $10, $11, $12, $13,
56 CASE WHEN $8 THEN now() ELSE NULL END, now())
57 ON CONFLICT (instance_id) DO UPDATE SET
58 status = $2,
59 definition_hash = $3,
60 current_task_id = $4,
61 completed_task_count = $5,
62 data = $6,
63 error = $7,
64 position_kind = $9,
65 delay_wake_at = $10,
66 trace_parent = $11,
67 task_priority = $12,
68 task_tags = $13,
69 completed_at = CASE WHEN $8 THEN now() ELSE sayiir_workflow_snapshots.completed_at END,
70 updated_at = now()",
71 )
72 .bind(&snapshot.instance_id) .bind(status) .bind(&snapshot.definition_hash) .bind(&task_id) .bind(task_count) .bind(&data) .bind(&error) .bind(terminal) .bind(pos_kind) .bind(wake_at) .bind(snapshot.trace_parent.as_deref()) .bind(task_priority) .bind(&task_tags) .execute(&mut *tx)
86 .await
87 .map_err(PgError)?;
88
89 sqlx::query(
91 "INSERT INTO sayiir_workflow_snapshot_history
92 (instance_id, version, status, current_task_id, data)
93 VALUES (
94 $1,
95 (SELECT COALESCE(MAX(version), 0) + 1
96 FROM sayiir_workflow_snapshot_history WHERE instance_id = $1),
97 $2, $3, $4
98 )",
99 )
100 .bind(&snapshot.instance_id)
101 .bind(status)
102 .bind(&task_id)
103 .bind(&data)
104 .execute(&mut *tx)
105 .await
106 .map_err(PgError)?;
107
108 if let Some(ref tid) = task_id {
112 sqlx::query(
113 "INSERT INTO sayiir_workflow_tasks (instance_id, task_id, status, started_at)
114 VALUES ($1, $2, 'active', now())
115 ON CONFLICT (instance_id, task_id) DO UPDATE SET
116 status = CASE
117 WHEN sayiir_workflow_tasks.status = 'completed' THEN sayiir_workflow_tasks.status
118 ELSE 'active'
119 END,
120 started_at = COALESCE(sayiir_workflow_tasks.started_at, now())",
121 )
122 .bind(&snapshot.instance_id)
123 .bind(tid)
124 .execute(&mut *tx)
125 .await
126 .map_err(PgError)?;
127 }
128
129 if terminal {
131 let terminal_status = match SnapshotStatus::from(&snapshot.state) {
132 SnapshotStatus::Failed => "failed",
133 SnapshotStatus::Cancelled => "cancelled",
134 _ => "completed",
135 };
136 sqlx::query(
137 "UPDATE sayiir_workflow_tasks SET status = $1, completed_at = now(), error = $2
138 WHERE instance_id = $3 AND status = 'active'",
139 )
140 .bind(terminal_status)
141 .bind(&error)
142 .bind(&snapshot.instance_id)
143 .execute(&mut *tx)
144 .await
145 .map_err(PgError)?;
146 }
147
148 tx.commit().await.map_err(PgError)?;
149 Ok(())
150 }
151
152 #[tracing::instrument(
153 name = "db.save_task_result",
154 skip(self, output),
155 fields(db.system = "postgresql"),
156 err(level = tracing::Level::ERROR),
157 )]
158 async fn save_task_result(
159 &self,
160 instance_id: &str,
161 task_id: &str,
162 output: bytes::Bytes,
163 ) -> Result<(), BackendError> {
164 tracing::debug!("saving task result");
165 let mut tx = self.pool.begin().await.map_err(PgError)?;
166
167 let row = sqlx::query(
169 "SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
170 )
171 .bind(instance_id)
172 .fetch_optional(&mut *tx)
173 .await
174 .map_err(PgError)?
175 .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
176
177 let raw: &[u8] = row.get("data");
178 let mut snapshot = self.decode(raw)?;
179 snapshot.mark_task_completed(task_id.to_string(), output);
180
181 let data = self.encode(&snapshot)?;
182 let status = snapshot.state.as_ref();
183 let current = snapshot.current_task_id().map(ToString::to_string);
184 let task_count = snapshot.completed_task_count();
185
186 sqlx::query(
187 "UPDATE sayiir_workflow_snapshots
188 SET data = $1, status = $2, current_task_id = $3,
189 completed_task_count = $4, updated_at = now()
190 WHERE instance_id = $5",
191 )
192 .bind(&data)
193 .bind(status)
194 .bind(¤t)
195 .bind(task_count)
196 .bind(instance_id)
197 .execute(&mut *tx)
198 .await
199 .map_err(PgError)?;
200
201 sqlx::query(
203 "INSERT INTO sayiir_workflow_tasks (instance_id, task_id, status, completed_at)
204 VALUES ($1, $2, 'completed', now())
205 ON CONFLICT (instance_id, task_id) DO UPDATE SET
206 status = 'completed', completed_at = now(), error = NULL",
207 )
208 .bind(instance_id)
209 .bind(task_id)
210 .execute(&mut *tx)
211 .await
212 .map_err(PgError)?;
213
214 tx.commit().await.map_err(PgError)?;
215 Ok(())
216 }
217
218 #[tracing::instrument(
219 name = "db.load_snapshot",
220 skip(self),
221 fields(db.system = "postgresql"),
222 err(level = tracing::Level::ERROR),
223 )]
224 async fn load_snapshot(&self, instance_id: &str) -> Result<WorkflowSnapshot, BackendError> {
225 tracing::debug!("loading snapshot");
226 let row = sqlx::query(
227 "SELECT data, trace_parent FROM sayiir_workflow_snapshots WHERE instance_id = $1",
228 )
229 .bind(instance_id)
230 .fetch_optional(&self.pool)
231 .await
232 .map_err(PgError)?
233 .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
234
235 let raw: &[u8] = row.get("data");
236 let mut snapshot = self.decode(raw)?;
237 snapshot.trace_parent = row.get("trace_parent");
238 Ok(snapshot)
239 }
240
241 #[tracing::instrument(
242 name = "db.delete_snapshot",
243 skip(self),
244 fields(db.system = "postgresql"),
245 err(level = tracing::Level::ERROR),
246 )]
247 async fn delete_snapshot(&self, instance_id: &str) -> Result<(), BackendError> {
248 tracing::debug!("deleting snapshot");
249 let result = sqlx::query("DELETE FROM sayiir_workflow_snapshots WHERE instance_id = $1")
250 .bind(instance_id)
251 .execute(&self.pool)
252 .await
253 .map_err(PgError)?;
254
255 if result.rows_affected() == 0 {
256 return Err(BackendError::NotFound(instance_id.to_string()));
257 }
258 Ok(())
259 }
260
261 #[tracing::instrument(
262 name = "db.list_snapshots",
263 skip(self),
264 fields(db.system = "postgresql"),
265 err(level = tracing::Level::ERROR),
266 )]
267 async fn list_snapshots(&self) -> Result<Vec<String>, BackendError> {
268 tracing::debug!("listing snapshots");
269 let rows = sqlx::query("SELECT instance_id FROM sayiir_workflow_snapshots")
270 .fetch_all(&self.pool)
271 .await
272 .map_err(PgError)?;
273
274 Ok(rows.iter().map(|r| r.get("instance_id")).collect())
275 }
276}