1use sayiir_core::codec::{self, Decoder, Encoder};
4use sayiir_core::snapshot::{SnapshotStatus, WorkflowSnapshot};
5use sayiir_persistence::{BackendError, SnapshotStore};
6use sqlx::Row;
7
8use crate::backend::PostgresBackend;
9use crate::error::PgError;
10use crate::helpers::{
11 completed_task_count, current_task_id, delay_wake_at, error_message, is_terminal,
12 position_kind, status_str,
13};
14
15impl<C> SnapshotStore for PostgresBackend<C>
16where
17 C: Encoder
18 + Decoder
19 + codec::sealed::EncodeValue<WorkflowSnapshot>
20 + codec::sealed::DecodeValue<WorkflowSnapshot>,
21{
22 #[allow(clippy::too_many_lines)]
23 async fn save_snapshot(&self, snapshot: &WorkflowSnapshot) -> Result<(), BackendError> {
24 tracing::debug!(
25 instance_id = %snapshot.instance_id,
26 status = %status_str(&snapshot.state),
27 "saving snapshot"
28 );
29 let data = self.encode(snapshot)?;
30 let status = status_str(&snapshot.state);
31 let task_id = current_task_id(snapshot).map(ToString::to_string);
32 let task_count = completed_task_count(snapshot);
33 let error = error_message(snapshot).map(ToString::to_string);
34 let terminal = is_terminal(snapshot);
35 let pos_kind = position_kind(snapshot);
36 let wake_at = delay_wake_at(snapshot);
37
38 let mut tx = self.pool.begin().await.map_err(PgError)?;
39
40 sqlx::query(
42 "INSERT INTO sayiir_workflow_snapshots
43 (instance_id, status, definition_hash, current_task_id,
44 completed_task_count, data, error, position_kind, delay_wake_at,
45 completed_at, updated_at)
46 VALUES ($1, $2, $3, $4, $5, $6, $7, $9, $10,
47 CASE WHEN $8 THEN now() ELSE NULL END, now())
48 ON CONFLICT (instance_id) DO UPDATE SET
49 status = $2,
50 definition_hash = $3,
51 current_task_id = $4,
52 completed_task_count = $5,
53 data = $6,
54 error = $7,
55 position_kind = $9,
56 delay_wake_at = $10,
57 completed_at = CASE WHEN $8 THEN now() ELSE sayiir_workflow_snapshots.completed_at END,
58 updated_at = now()",
59 )
60 .bind(&snapshot.instance_id) .bind(status) .bind(&snapshot.definition_hash) .bind(&task_id) .bind(task_count) .bind(&data) .bind(&error) .bind(terminal) .bind(pos_kind) .bind(wake_at) .execute(&mut *tx)
71 .await
72 .map_err(PgError)?;
73
74 sqlx::query(
76 "INSERT INTO sayiir_workflow_snapshot_history
77 (instance_id, version, status, current_task_id, data)
78 VALUES (
79 $1,
80 (SELECT COALESCE(MAX(version), 0) + 1
81 FROM sayiir_workflow_snapshot_history WHERE instance_id = $1),
82 $2, $3, $4
83 )",
84 )
85 .bind(&snapshot.instance_id)
86 .bind(status)
87 .bind(&task_id)
88 .bind(&data)
89 .execute(&mut *tx)
90 .await
91 .map_err(PgError)?;
92
93 if let Some(ref tid) = task_id {
97 sqlx::query(
98 "INSERT INTO sayiir_workflow_tasks (instance_id, task_id, status, started_at)
99 VALUES ($1, $2, 'active', now())
100 ON CONFLICT (instance_id, task_id) DO UPDATE SET
101 status = CASE
102 WHEN sayiir_workflow_tasks.status = 'completed' THEN sayiir_workflow_tasks.status
103 ELSE 'active'
104 END,
105 started_at = COALESCE(sayiir_workflow_tasks.started_at, now())",
106 )
107 .bind(&snapshot.instance_id)
108 .bind(tid)
109 .execute(&mut *tx)
110 .await
111 .map_err(PgError)?;
112 }
113
114 if terminal {
116 let terminal_status = match SnapshotStatus::from(&snapshot.state) {
117 SnapshotStatus::Failed => "failed",
118 SnapshotStatus::Cancelled => "cancelled",
119 _ => "completed",
120 };
121 sqlx::query(
122 "UPDATE sayiir_workflow_tasks SET status = $1, completed_at = now(), error = $2
123 WHERE instance_id = $3 AND status = 'active'",
124 )
125 .bind(terminal_status)
126 .bind(&error)
127 .bind(&snapshot.instance_id)
128 .execute(&mut *tx)
129 .await
130 .map_err(PgError)?;
131 }
132
133 tx.commit().await.map_err(PgError)?;
134 tracing::debug!(instance_id = %snapshot.instance_id, "snapshot saved");
135 Ok(())
136 }
137
138 async fn save_task_result(
139 &self,
140 instance_id: &str,
141 task_id: &str,
142 output: bytes::Bytes,
143 ) -> Result<(), BackendError> {
144 tracing::debug!(instance_id, task_id, "saving task result");
145 let mut tx = self.pool.begin().await.map_err(PgError)?;
146
147 let row = sqlx::query(
149 "SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
150 )
151 .bind(instance_id)
152 .fetch_optional(&mut *tx)
153 .await
154 .map_err(PgError)?
155 .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
156
157 let raw: &[u8] = row.get("data");
158 let mut snapshot = self.decode(raw)?;
159 snapshot.mark_task_completed(task_id.to_string(), output);
160
161 let data = self.encode(&snapshot)?;
162 let status = status_str(&snapshot.state);
163 let current = current_task_id(&snapshot).map(ToString::to_string);
164 let task_count = completed_task_count(&snapshot);
165
166 sqlx::query(
167 "UPDATE sayiir_workflow_snapshots
168 SET data = $1, status = $2, current_task_id = $3,
169 completed_task_count = $4, updated_at = now()
170 WHERE instance_id = $5",
171 )
172 .bind(&data)
173 .bind(status)
174 .bind(¤t)
175 .bind(task_count)
176 .bind(instance_id)
177 .execute(&mut *tx)
178 .await
179 .map_err(PgError)?;
180
181 sqlx::query(
183 "INSERT INTO sayiir_workflow_tasks (instance_id, task_id, status, completed_at)
184 VALUES ($1, $2, 'completed', now())
185 ON CONFLICT (instance_id, task_id) DO UPDATE SET
186 status = 'completed', completed_at = now(), error = NULL",
187 )
188 .bind(instance_id)
189 .bind(task_id)
190 .execute(&mut *tx)
191 .await
192 .map_err(PgError)?;
193
194 tx.commit().await.map_err(PgError)?;
195 tracing::debug!(instance_id, task_id, "task result saved");
196 Ok(())
197 }
198
199 async fn load_snapshot(&self, instance_id: &str) -> Result<WorkflowSnapshot, BackendError> {
200 tracing::debug!(instance_id, "loading snapshot");
201 let row = sqlx::query("SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1")
202 .bind(instance_id)
203 .fetch_optional(&self.pool)
204 .await
205 .map_err(PgError)?
206 .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
207
208 let raw: &[u8] = row.get("data");
209 self.decode(raw)
210 }
211
212 async fn delete_snapshot(&self, instance_id: &str) -> Result<(), BackendError> {
213 tracing::debug!(instance_id, "deleting snapshot");
214 let result = sqlx::query("DELETE FROM sayiir_workflow_snapshots WHERE instance_id = $1")
215 .bind(instance_id)
216 .execute(&self.pool)
217 .await
218 .map_err(PgError)?;
219
220 if result.rows_affected() == 0 {
221 return Err(BackendError::NotFound(instance_id.to_string()));
222 }
223 Ok(())
224 }
225
226 async fn list_snapshots(&self) -> Result<Vec<String>, BackendError> {
227 tracing::debug!("listing snapshots");
228 let rows = sqlx::query("SELECT instance_id FROM sayiir_workflow_snapshots")
229 .fetch_all(&self.pool)
230 .await
231 .map_err(PgError)?;
232
233 Ok(rows.iter().map(|r| r.get("instance_id")).collect())
234 }
235}