Skip to main content

sayiir_postgres/
snapshot_store.rs

//! [`SnapshotStore`] implementation for Postgres.
2
3use sayiir_core::codec::{self, Decoder, Encoder};
4use sayiir_core::snapshot::{SnapshotStatus, WorkflowSnapshot};
5use sayiir_persistence::{BackendError, SnapshotStore};
6use sqlx::Row;
7
8use crate::backend::PostgresBackend;
9use crate::error::PgError;
10use crate::helpers::{
11    completed_task_count, current_task_id, delay_wake_at, error_message, is_terminal,
12    position_kind, status_str,
13};
14
15impl<C> SnapshotStore for PostgresBackend<C>
16where
17    C: Encoder
18        + Decoder
19        + codec::sealed::EncodeValue<WorkflowSnapshot>
20        + codec::sealed::DecodeValue<WorkflowSnapshot>,
21{
22    #[tracing::instrument(
23        name = "db.save_snapshot",
24        skip(self, snapshot),
25        fields(
26            db.system = "postgresql",
27            instance_id = %snapshot.instance_id,
28            status = %status_str(&snapshot.state),
29        ),
30        err(level = tracing::Level::ERROR),
31    )]
32    #[allow(clippy::too_many_lines)]
33    async fn save_snapshot(&self, snapshot: &WorkflowSnapshot) -> Result<(), BackendError> {
34        tracing::debug!("saving snapshot");
35        let data = self.encode(snapshot)?;
36        let status = status_str(&snapshot.state);
37        let task_id = current_task_id(snapshot).map(ToString::to_string);
38        let task_count = completed_task_count(snapshot);
39        let error = error_message(snapshot).map(ToString::to_string);
40        let terminal = is_terminal(snapshot);
41        let pos_kind = position_kind(snapshot);
42        let wake_at = delay_wake_at(snapshot);
43
44        let mut tx = self.pool.begin().await.map_err(PgError)?;
45
46        // Upsert current state
47        sqlx::query(
48            "INSERT INTO sayiir_workflow_snapshots
49                (instance_id, status, definition_hash, current_task_id,
50                 completed_task_count, data, error, position_kind, delay_wake_at,
51                 trace_parent, completed_at, updated_at)
52             VALUES ($1, $2, $3, $4, $5, $6, $7, $9, $10, $11,
53                     CASE WHEN $8 THEN now() ELSE NULL END, now())
54             ON CONFLICT (instance_id) DO UPDATE SET
55                status = $2,
56                definition_hash = $3,
57                current_task_id = $4,
58                completed_task_count = $5,
59                data = $6,
60                error = $7,
61                position_kind = $9,
62                delay_wake_at = $10,
63                trace_parent = $11,
64                completed_at = CASE WHEN $8 THEN now() ELSE sayiir_workflow_snapshots.completed_at END,
65                updated_at = now()",
66        )
67        .bind(&snapshot.instance_id) // $1
68        .bind(status) // $2
69        .bind(&snapshot.definition_hash) // $3
70        .bind(&task_id) // $4
71        .bind(task_count) // $5
72        .bind(&data) // $6
73        .bind(&error) // $7
74        .bind(terminal) // $8
75        .bind(pos_kind) // $9
76        .bind(wake_at) // $10
77        .bind(snapshot.trace_parent.as_deref()) // $11
78        .execute(&mut *tx)
79        .await
80        .map_err(PgError)?;
81
82        // Append to history
83        sqlx::query(
84            "INSERT INTO sayiir_workflow_snapshot_history
85                (instance_id, version, status, current_task_id, data)
86             VALUES (
87                $1,
88                (SELECT COALESCE(MAX(version), 0) + 1
89                 FROM sayiir_workflow_snapshot_history WHERE instance_id = $1),
90                $2, $3, $4
91             )",
92        )
93        .bind(&snapshot.instance_id)
94        .bind(status)
95        .bind(&task_id)
96        .bind(&data)
97        .execute(&mut *tx)
98        .await
99        .map_err(PgError)?;
100
101        // --- Maintain sayiir_workflow_tasks lifecycle ---
102
103        // If at a task, mark it as active
104        if let Some(ref tid) = task_id {
105            sqlx::query(
106                "INSERT INTO sayiir_workflow_tasks (instance_id, task_id, status, started_at)
107                 VALUES ($1, $2, 'active', now())
108                 ON CONFLICT (instance_id, task_id) DO UPDATE SET
109                    status = CASE
110                        WHEN sayiir_workflow_tasks.status = 'completed' THEN sayiir_workflow_tasks.status
111                        ELSE 'active'
112                    END,
113                    started_at = COALESCE(sayiir_workflow_tasks.started_at, now())",
114            )
115            .bind(&snapshot.instance_id)
116            .bind(tid)
117            .execute(&mut *tx)
118            .await
119            .map_err(PgError)?;
120        }
121
122        // On terminal states, mark any still-active task as failed/cancelled
123        if terminal {
124            let terminal_status = match SnapshotStatus::from(&snapshot.state) {
125                SnapshotStatus::Failed => "failed",
126                SnapshotStatus::Cancelled => "cancelled",
127                _ => "completed",
128            };
129            sqlx::query(
130                "UPDATE sayiir_workflow_tasks SET status = $1, completed_at = now(), error = $2
131                 WHERE instance_id = $3 AND status = 'active'",
132            )
133            .bind(terminal_status)
134            .bind(&error)
135            .bind(&snapshot.instance_id)
136            .execute(&mut *tx)
137            .await
138            .map_err(PgError)?;
139        }
140
141        tx.commit().await.map_err(PgError)?;
142        Ok(())
143    }
144
145    #[tracing::instrument(
146        name = "db.save_task_result",
147        skip(self, output),
148        fields(db.system = "postgresql"),
149        err(level = tracing::Level::ERROR),
150    )]
151    async fn save_task_result(
152        &self,
153        instance_id: &str,
154        task_id: &str,
155        output: bytes::Bytes,
156    ) -> Result<(), BackendError> {
157        tracing::debug!("saving task result");
158        let mut tx = self.pool.begin().await.map_err(PgError)?;
159
160        // Lock and load the snapshot
161        let row = sqlx::query(
162            "SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
163        )
164        .bind(instance_id)
165        .fetch_optional(&mut *tx)
166        .await
167        .map_err(PgError)?
168        .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
169
170        let raw: &[u8] = row.get("data");
171        let mut snapshot = self.decode(raw)?;
172        snapshot.mark_task_completed(task_id.to_string(), output);
173
174        let data = self.encode(&snapshot)?;
175        let status = status_str(&snapshot.state);
176        let current = current_task_id(&snapshot).map(ToString::to_string);
177        let task_count = completed_task_count(&snapshot);
178
179        sqlx::query(
180            "UPDATE sayiir_workflow_snapshots
181             SET data = $1, status = $2, current_task_id = $3,
182                 completed_task_count = $4, updated_at = now()
183             WHERE instance_id = $5",
184        )
185        .bind(&data)
186        .bind(status)
187        .bind(&current)
188        .bind(task_count)
189        .bind(instance_id)
190        .execute(&mut *tx)
191        .await
192        .map_err(PgError)?;
193
194        // Mark task as completed in sayiir_workflow_tasks
195        sqlx::query(
196            "INSERT INTO sayiir_workflow_tasks (instance_id, task_id, status, completed_at)
197             VALUES ($1, $2, 'completed', now())
198             ON CONFLICT (instance_id, task_id) DO UPDATE SET
199                status = 'completed', completed_at = now(), error = NULL",
200        )
201        .bind(instance_id)
202        .bind(task_id)
203        .execute(&mut *tx)
204        .await
205        .map_err(PgError)?;
206
207        tx.commit().await.map_err(PgError)?;
208        Ok(())
209    }
210
211    #[tracing::instrument(
212        name = "db.load_snapshot",
213        skip(self),
214        fields(db.system = "postgresql"),
215        err(level = tracing::Level::ERROR),
216    )]
217    async fn load_snapshot(&self, instance_id: &str) -> Result<WorkflowSnapshot, BackendError> {
218        tracing::debug!("loading snapshot");
219        let row = sqlx::query(
220            "SELECT data, trace_parent FROM sayiir_workflow_snapshots WHERE instance_id = $1",
221        )
222        .bind(instance_id)
223        .fetch_optional(&self.pool)
224        .await
225        .map_err(PgError)?
226        .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
227
228        let raw: &[u8] = row.get("data");
229        let mut snapshot = self.decode(raw)?;
230        snapshot.trace_parent = row.get("trace_parent");
231        Ok(snapshot)
232    }
233
234    #[tracing::instrument(
235        name = "db.delete_snapshot",
236        skip(self),
237        fields(db.system = "postgresql"),
238        err(level = tracing::Level::ERROR),
239    )]
240    async fn delete_snapshot(&self, instance_id: &str) -> Result<(), BackendError> {
241        tracing::debug!("deleting snapshot");
242        let result = sqlx::query("DELETE FROM sayiir_workflow_snapshots WHERE instance_id = $1")
243            .bind(instance_id)
244            .execute(&self.pool)
245            .await
246            .map_err(PgError)?;
247
248        if result.rows_affected() == 0 {
249            return Err(BackendError::NotFound(instance_id.to_string()));
250        }
251        Ok(())
252    }
253
254    #[tracing::instrument(
255        name = "db.list_snapshots",
256        skip(self),
257        fields(db.system = "postgresql"),
258        err(level = tracing::Level::ERROR),
259    )]
260    async fn list_snapshots(&self) -> Result<Vec<String>, BackendError> {
261        tracing::debug!("listing snapshots");
262        let rows = sqlx::query("SELECT instance_id FROM sayiir_workflow_snapshots")
263            .fetch_all(&self.pool)
264            .await
265            .map_err(PgError)?;
266
267        Ok(rows.iter().map(|r| r.get("instance_id")).collect())
268    }
269}