1use sayiir_core::codec::{self, Decoder, Encoder};
4use sayiir_core::snapshot::{SnapshotStatus, WorkflowSnapshot};
5use sayiir_persistence::{BackendError, SnapshotStore};
6use sqlx::Row;
7
8use crate::backend::PostgresBackend;
9use crate::error::PgError;
10use crate::helpers::{
11 completed_task_count, current_task_id, delay_wake_at, error_message, is_terminal,
12 position_kind, status_str,
13};
14
15impl<C> SnapshotStore for PostgresBackend<C>
16where
17 C: Encoder
18 + Decoder
19 + codec::sealed::EncodeValue<WorkflowSnapshot>
20 + codec::sealed::DecodeValue<WorkflowSnapshot>,
21{
22 async fn save_snapshot(&self, snapshot: &WorkflowSnapshot) -> Result<(), BackendError> {
23 tracing::debug!(
24 instance_id = %snapshot.instance_id,
25 status = %status_str(&snapshot.state),
26 "saving snapshot"
27 );
28 let data = self.encode(snapshot)?;
29 let status = status_str(&snapshot.state);
30 let task_id = current_task_id(snapshot).map(ToString::to_string);
31 let task_count = completed_task_count(snapshot);
32 let error = error_message(snapshot).map(ToString::to_string);
33 let terminal = is_terminal(snapshot);
34 let pos_kind = position_kind(snapshot);
35 let wake_at = delay_wake_at(snapshot);
36
37 let mut tx = self.pool.begin().await.map_err(PgError)?;
38
39 sqlx::query(
41 "INSERT INTO sayiir_workflow_snapshots
42 (instance_id, status, definition_hash, current_task_id,
43 completed_task_count, data, error, position_kind, delay_wake_at,
44 completed_at, updated_at)
45 VALUES ($1, $2, $3, $4, $5, $6, $7, $9, $10,
46 CASE WHEN $8 THEN now() ELSE NULL END, now())
47 ON CONFLICT (instance_id) DO UPDATE SET
48 status = $2,
49 definition_hash = $3,
50 current_task_id = $4,
51 completed_task_count = $5,
52 data = $6,
53 error = $7,
54 position_kind = $9,
55 delay_wake_at = $10,
56 completed_at = CASE WHEN $8 THEN now() ELSE sayiir_workflow_snapshots.completed_at END,
57 updated_at = now()",
58 )
59 .bind(&snapshot.instance_id) .bind(status) .bind(&snapshot.definition_hash) .bind(&task_id) .bind(task_count) .bind(&data) .bind(&error) .bind(terminal) .bind(pos_kind) .bind(wake_at) .execute(&mut *tx)
70 .await
71 .map_err(PgError)?;
72
73 sqlx::query(
75 "INSERT INTO sayiir_workflow_snapshot_history
76 (instance_id, version, status, current_task_id, data)
77 VALUES (
78 $1,
79 (SELECT COALESCE(MAX(version), 0) + 1
80 FROM sayiir_workflow_snapshot_history WHERE instance_id = $1),
81 $2, $3, $4
82 )",
83 )
84 .bind(&snapshot.instance_id)
85 .bind(status)
86 .bind(&task_id)
87 .bind(&data)
88 .execute(&mut *tx)
89 .await
90 .map_err(PgError)?;
91
92 if let Some(ref tid) = task_id {
96 sqlx::query(
97 "INSERT INTO sayiir_workflow_tasks (instance_id, task_id, status, started_at)
98 VALUES ($1, $2, 'active', now())
99 ON CONFLICT (instance_id, task_id) DO UPDATE SET
100 status = CASE
101 WHEN sayiir_workflow_tasks.status = 'completed' THEN sayiir_workflow_tasks.status
102 ELSE 'active'
103 END,
104 started_at = COALESCE(sayiir_workflow_tasks.started_at, now())",
105 )
106 .bind(&snapshot.instance_id)
107 .bind(tid)
108 .execute(&mut *tx)
109 .await
110 .map_err(PgError)?;
111 }
112
113 if terminal {
115 let terminal_status = match SnapshotStatus::from(&snapshot.state) {
116 SnapshotStatus::Failed => "failed",
117 SnapshotStatus::Cancelled => "cancelled",
118 _ => "completed",
119 };
120 sqlx::query(
121 "UPDATE sayiir_workflow_tasks SET status = $1, completed_at = now(), error = $2
122 WHERE instance_id = $3 AND status = 'active'",
123 )
124 .bind(terminal_status)
125 .bind(&error)
126 .bind(&snapshot.instance_id)
127 .execute(&mut *tx)
128 .await
129 .map_err(PgError)?;
130 }
131
132 tx.commit().await.map_err(PgError)?;
133 tracing::debug!(instance_id = %snapshot.instance_id, "snapshot saved");
134 Ok(())
135 }
136
137 async fn save_task_result(
138 &self,
139 instance_id: &str,
140 task_id: &str,
141 output: bytes::Bytes,
142 ) -> Result<(), BackendError> {
143 tracing::debug!(instance_id, task_id, "saving task result");
144 let mut tx = self.pool.begin().await.map_err(PgError)?;
145
146 let row = sqlx::query(
148 "SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
149 )
150 .bind(instance_id)
151 .fetch_optional(&mut *tx)
152 .await
153 .map_err(PgError)?
154 .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
155
156 let raw: &[u8] = row.get("data");
157 let mut snapshot = self.decode(raw)?;
158 snapshot.mark_task_completed(task_id.to_string(), output);
159
160 let data = self.encode(&snapshot)?;
161 let status = status_str(&snapshot.state);
162 let current = current_task_id(&snapshot).map(ToString::to_string);
163 let task_count = completed_task_count(&snapshot);
164
165 sqlx::query(
166 "UPDATE sayiir_workflow_snapshots
167 SET data = $1, status = $2, current_task_id = $3,
168 completed_task_count = $4, updated_at = now()
169 WHERE instance_id = $5",
170 )
171 .bind(&data)
172 .bind(status)
173 .bind(¤t)
174 .bind(task_count)
175 .bind(instance_id)
176 .execute(&mut *tx)
177 .await
178 .map_err(PgError)?;
179
180 sqlx::query(
182 "INSERT INTO sayiir_workflow_tasks (instance_id, task_id, status, completed_at)
183 VALUES ($1, $2, 'completed', now())
184 ON CONFLICT (instance_id, task_id) DO UPDATE SET
185 status = 'completed', completed_at = now(), error = NULL",
186 )
187 .bind(instance_id)
188 .bind(task_id)
189 .execute(&mut *tx)
190 .await
191 .map_err(PgError)?;
192
193 tx.commit().await.map_err(PgError)?;
194 tracing::debug!(instance_id, task_id, "task result saved");
195 Ok(())
196 }
197
198 async fn load_snapshot(&self, instance_id: &str) -> Result<WorkflowSnapshot, BackendError> {
199 tracing::debug!(instance_id, "loading snapshot");
200 let row = sqlx::query("SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1")
201 .bind(instance_id)
202 .fetch_optional(&self.pool)
203 .await
204 .map_err(PgError)?
205 .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
206
207 let raw: &[u8] = row.get("data");
208 self.decode(raw)
209 }
210
211 async fn delete_snapshot(&self, instance_id: &str) -> Result<(), BackendError> {
212 tracing::debug!(instance_id, "deleting snapshot");
213 let result = sqlx::query("DELETE FROM sayiir_workflow_snapshots WHERE instance_id = $1")
214 .bind(instance_id)
215 .execute(&self.pool)
216 .await
217 .map_err(PgError)?;
218
219 if result.rows_affected() == 0 {
220 return Err(BackendError::NotFound(instance_id.to_string()));
221 }
222 Ok(())
223 }
224
225 async fn list_snapshots(&self) -> Result<Vec<String>, BackendError> {
226 tracing::debug!("listing snapshots");
227 let rows = sqlx::query("SELECT instance_id FROM sayiir_workflow_snapshots")
228 .fetch_all(&self.pool)
229 .await
230 .map_err(PgError)?;
231
232 Ok(rows.iter().map(|r| r.get("instance_id")).collect())
233 }
234}