Skip to main content

sayiir_postgres/
snapshot_store.rs

1//! [`SnapshotStore`] implementation for Postgres.
2
3use sayiir_core::codec::{self, Decoder, Encoder};
4use sayiir_core::snapshot::{SnapshotStatus, WorkflowSnapshot};
5use sayiir_persistence::{BackendError, SnapshotStore};
6use sqlx::Row;
7
8use crate::backend::PostgresBackend;
9use crate::error::PgError;
10
11impl<C> SnapshotStore for PostgresBackend<C>
12where
13    C: Encoder
14        + Decoder
15        + codec::sealed::EncodeValue<WorkflowSnapshot>
16        + codec::sealed::DecodeValue<WorkflowSnapshot>,
17{
18    #[tracing::instrument(
19        name = "db.save_snapshot",
20        skip(self, snapshot),
21        fields(
22            db.system = "postgresql",
23            instance_id = %snapshot.instance_id,
24            status = %snapshot.state.as_ref(),
25        ),
26        err(level = tracing::Level::ERROR),
27    )]
28    #[allow(clippy::too_many_lines)]
29    async fn save_snapshot(&self, snapshot: &WorkflowSnapshot) -> Result<(), BackendError> {
30        tracing::debug!("saving snapshot");
31        let data = self.encode(snapshot)?;
32        let status = snapshot.state.as_ref();
33        let task_id = snapshot.current_task_id().map(ToString::to_string);
34        let task_count = snapshot.completed_task_count();
35        let error = snapshot.error_message().map(ToString::to_string);
36        let terminal = snapshot.state.is_terminal();
37        let pos_kind = snapshot.position_kind();
38        let wake_at = snapshot.delay_wake_at();
39
40        let task_priority = i16::from(snapshot.current_task_priority());
41        let task_tags: Vec<&str> = snapshot
42            .current_task_tags()
43            .iter()
44            .map(String::as_str)
45            .collect();
46
47        let mut tx = self.pool.begin().await.map_err(PgError)?;
48
49        // Upsert current state
50        sqlx::query(
51            "INSERT INTO sayiir_workflow_snapshots
52                (instance_id, status, definition_hash, current_task_id,
53                 completed_task_count, data, error, position_kind, delay_wake_at,
54                 trace_parent, task_priority, task_tags, completed_at, updated_at)
55             VALUES ($1, $2, $3, $4, $5, $6, $7, $9, $10, $11, $12, $13,
56                     CASE WHEN $8 THEN now() ELSE NULL END, now())
57             ON CONFLICT (instance_id) DO UPDATE SET
58                status = $2,
59                definition_hash = $3,
60                current_task_id = $4,
61                completed_task_count = $5,
62                data = $6,
63                error = $7,
64                position_kind = $9,
65                delay_wake_at = $10,
66                trace_parent = $11,
67                task_priority = $12,
68                task_tags = $13,
69                completed_at = CASE WHEN $8 THEN now() ELSE sayiir_workflow_snapshots.completed_at END,
70                updated_at = now()",
71        )
72        .bind(&snapshot.instance_id) // $1
73        .bind(status) // $2
74        .bind(&snapshot.definition_hash) // $3
75        .bind(&task_id) // $4
76        .bind(task_count) // $5
77        .bind(&data) // $6
78        .bind(&error) // $7
79        .bind(terminal) // $8
80        .bind(pos_kind) // $9
81        .bind(wake_at) // $10
82        .bind(snapshot.trace_parent.as_deref()) // $11
83        .bind(task_priority) // $12
84        .bind(&task_tags) // $13
85        .execute(&mut *tx)
86        .await
87        .map_err(PgError)?;
88
89        // Append to history
90        sqlx::query(
91            "INSERT INTO sayiir_workflow_snapshot_history
92                (instance_id, version, status, current_task_id, data)
93             VALUES (
94                $1,
95                (SELECT COALESCE(MAX(version), 0) + 1
96                 FROM sayiir_workflow_snapshot_history WHERE instance_id = $1),
97                $2, $3, $4
98             )",
99        )
100        .bind(&snapshot.instance_id)
101        .bind(status)
102        .bind(&task_id)
103        .bind(&data)
104        .execute(&mut *tx)
105        .await
106        .map_err(PgError)?;
107
108        // --- Maintain sayiir_workflow_tasks lifecycle ---
109
110        // If at a task, mark it as active
111        if let Some(ref tid) = task_id {
112            sqlx::query(
113                "INSERT INTO sayiir_workflow_tasks (instance_id, task_id, status, started_at)
114                 VALUES ($1, $2, 'active', now())
115                 ON CONFLICT (instance_id, task_id) DO UPDATE SET
116                    status = CASE
117                        WHEN sayiir_workflow_tasks.status = 'completed' THEN sayiir_workflow_tasks.status
118                        ELSE 'active'
119                    END,
120                    started_at = COALESCE(sayiir_workflow_tasks.started_at, now())",
121            )
122            .bind(&snapshot.instance_id)
123            .bind(tid)
124            .execute(&mut *tx)
125            .await
126            .map_err(PgError)?;
127        }
128
129        // On terminal states, mark any still-active task as failed/cancelled
130        if terminal {
131            let terminal_status = match SnapshotStatus::from(&snapshot.state) {
132                SnapshotStatus::Failed => "failed",
133                SnapshotStatus::Cancelled => "cancelled",
134                _ => "completed",
135            };
136            sqlx::query(
137                "UPDATE sayiir_workflow_tasks SET status = $1, completed_at = now(), error = $2
138                 WHERE instance_id = $3 AND status = 'active'",
139            )
140            .bind(terminal_status)
141            .bind(&error)
142            .bind(&snapshot.instance_id)
143            .execute(&mut *tx)
144            .await
145            .map_err(PgError)?;
146        }
147
148        tx.commit().await.map_err(PgError)?;
149        Ok(())
150    }
151
152    #[tracing::instrument(
153        name = "db.save_task_result",
154        skip(self, output),
155        fields(db.system = "postgresql"),
156        err(level = tracing::Level::ERROR),
157    )]
158    async fn save_task_result(
159        &self,
160        instance_id: &str,
161        task_id: &str,
162        output: bytes::Bytes,
163    ) -> Result<(), BackendError> {
164        tracing::debug!("saving task result");
165        let mut tx = self.pool.begin().await.map_err(PgError)?;
166
167        // Lock and load the snapshot
168        let row = sqlx::query(
169            "SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
170        )
171        .bind(instance_id)
172        .fetch_optional(&mut *tx)
173        .await
174        .map_err(PgError)?
175        .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
176
177        let raw: &[u8] = row.get("data");
178        let mut snapshot = self.decode(raw)?;
179        snapshot.mark_task_completed(task_id.to_string(), output);
180
181        let data = self.encode(&snapshot)?;
182        let status = snapshot.state.as_ref();
183        let current = snapshot.current_task_id().map(ToString::to_string);
184        let task_count = snapshot.completed_task_count();
185
186        sqlx::query(
187            "UPDATE sayiir_workflow_snapshots
188             SET data = $1, status = $2, current_task_id = $3,
189                 completed_task_count = $4, updated_at = now()
190             WHERE instance_id = $5",
191        )
192        .bind(&data)
193        .bind(status)
194        .bind(&current)
195        .bind(task_count)
196        .bind(instance_id)
197        .execute(&mut *tx)
198        .await
199        .map_err(PgError)?;
200
201        // Mark task as completed in sayiir_workflow_tasks
202        sqlx::query(
203            "INSERT INTO sayiir_workflow_tasks (instance_id, task_id, status, completed_at)
204             VALUES ($1, $2, 'completed', now())
205             ON CONFLICT (instance_id, task_id) DO UPDATE SET
206                status = 'completed', completed_at = now(), error = NULL",
207        )
208        .bind(instance_id)
209        .bind(task_id)
210        .execute(&mut *tx)
211        .await
212        .map_err(PgError)?;
213
214        tx.commit().await.map_err(PgError)?;
215        Ok(())
216    }
217
218    #[tracing::instrument(
219        name = "db.load_snapshot",
220        skip(self),
221        fields(db.system = "postgresql"),
222        err(level = tracing::Level::ERROR),
223    )]
224    async fn load_snapshot(&self, instance_id: &str) -> Result<WorkflowSnapshot, BackendError> {
225        tracing::debug!("loading snapshot");
226        let row = sqlx::query(
227            "SELECT data, trace_parent FROM sayiir_workflow_snapshots WHERE instance_id = $1",
228        )
229        .bind(instance_id)
230        .fetch_optional(&self.pool)
231        .await
232        .map_err(PgError)?
233        .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
234
235        let raw: &[u8] = row.get("data");
236        let mut snapshot = self.decode(raw)?;
237        snapshot.trace_parent = row.get("trace_parent");
238        Ok(snapshot)
239    }
240
241    #[tracing::instrument(
242        name = "db.delete_snapshot",
243        skip(self),
244        fields(db.system = "postgresql"),
245        err(level = tracing::Level::ERROR),
246    )]
247    async fn delete_snapshot(&self, instance_id: &str) -> Result<(), BackendError> {
248        tracing::debug!("deleting snapshot");
249        let result = sqlx::query("DELETE FROM sayiir_workflow_snapshots WHERE instance_id = $1")
250            .bind(instance_id)
251            .execute(&self.pool)
252            .await
253            .map_err(PgError)?;
254
255        if result.rows_affected() == 0 {
256            return Err(BackendError::NotFound(instance_id.to_string()));
257        }
258        Ok(())
259    }
260
261    #[tracing::instrument(
262        name = "db.list_snapshots",
263        skip(self),
264        fields(db.system = "postgresql"),
265        err(level = tracing::Level::ERROR),
266    )]
267    async fn list_snapshots(&self) -> Result<Vec<String>, BackendError> {
268        tracing::debug!("listing snapshots");
269        let rows = sqlx::query("SELECT instance_id FROM sayiir_workflow_snapshots")
270            .fetch_all(&self.pool)
271            .await
272            .map_err(PgError)?;
273
274        Ok(rows.iter().map(|r| r.get("instance_id")).collect())
275    }
276}