Skip to main content

sayiir_postgres/
snapshot_store.rs

1//! [`SnapshotStore`] implementation for Postgres.
2
3use sayiir_core::codec::{self, Decoder, Encoder};
4use sayiir_core::snapshot::{SnapshotStatus, WorkflowSnapshot};
5use sayiir_persistence::{BackendError, SnapshotStore};
6use sqlx::Row;
7
8use crate::backend::PostgresBackend;
9use crate::error::PgError;
10use crate::helpers::{
11    completed_task_count, current_task_id, delay_wake_at, error_message, is_terminal,
12    position_kind, status_str,
13};
14
15impl<C> SnapshotStore for PostgresBackend<C>
16where
17    C: Encoder
18        + Decoder
19        + codec::sealed::EncodeValue<WorkflowSnapshot>
20        + codec::sealed::DecodeValue<WorkflowSnapshot>,
21{
22    #[allow(clippy::too_many_lines)]
23    async fn save_snapshot(&self, snapshot: &WorkflowSnapshot) -> Result<(), BackendError> {
24        tracing::debug!(
25            instance_id = %snapshot.instance_id,
26            status = %status_str(&snapshot.state),
27            "saving snapshot"
28        );
29        let data = self.encode(snapshot)?;
30        let status = status_str(&snapshot.state);
31        let task_id = current_task_id(snapshot).map(ToString::to_string);
32        let task_count = completed_task_count(snapshot);
33        let error = error_message(snapshot).map(ToString::to_string);
34        let terminal = is_terminal(snapshot);
35        let pos_kind = position_kind(snapshot);
36        let wake_at = delay_wake_at(snapshot);
37
38        let mut tx = self.pool.begin().await.map_err(PgError)?;
39
40        // Upsert current state
41        sqlx::query(
42            "INSERT INTO sayiir_workflow_snapshots
43                (instance_id, status, definition_hash, current_task_id,
44                 completed_task_count, data, error, position_kind, delay_wake_at,
45                 completed_at, updated_at)
46             VALUES ($1, $2, $3, $4, $5, $6, $7, $9, $10,
47                     CASE WHEN $8 THEN now() ELSE NULL END, now())
48             ON CONFLICT (instance_id) DO UPDATE SET
49                status = $2,
50                definition_hash = $3,
51                current_task_id = $4,
52                completed_task_count = $5,
53                data = $6,
54                error = $7,
55                position_kind = $9,
56                delay_wake_at = $10,
57                completed_at = CASE WHEN $8 THEN now() ELSE sayiir_workflow_snapshots.completed_at END,
58                updated_at = now()",
59        )
60        .bind(&snapshot.instance_id) // $1
61        .bind(status) // $2
62        .bind(&snapshot.definition_hash) // $3
63        .bind(&task_id) // $4
64        .bind(task_count) // $5
65        .bind(&data) // $6
66        .bind(&error) // $7
67        .bind(terminal) // $8
68        .bind(pos_kind) // $9
69        .bind(wake_at) // $10
70        .execute(&mut *tx)
71        .await
72        .map_err(PgError)?;
73
74        // Append to history
75        sqlx::query(
76            "INSERT INTO sayiir_workflow_snapshot_history
77                (instance_id, version, status, current_task_id, data)
78             VALUES (
79                $1,
80                (SELECT COALESCE(MAX(version), 0) + 1
81                 FROM sayiir_workflow_snapshot_history WHERE instance_id = $1),
82                $2, $3, $4
83             )",
84        )
85        .bind(&snapshot.instance_id)
86        .bind(status)
87        .bind(&task_id)
88        .bind(&data)
89        .execute(&mut *tx)
90        .await
91        .map_err(PgError)?;
92
93        // --- Maintain sayiir_workflow_tasks lifecycle ---
94
95        // If at a task, mark it as active
96        if let Some(ref tid) = task_id {
97            sqlx::query(
98                "INSERT INTO sayiir_workflow_tasks (instance_id, task_id, status, started_at)
99                 VALUES ($1, $2, 'active', now())
100                 ON CONFLICT (instance_id, task_id) DO UPDATE SET
101                    status = CASE
102                        WHEN sayiir_workflow_tasks.status = 'completed' THEN sayiir_workflow_tasks.status
103                        ELSE 'active'
104                    END,
105                    started_at = COALESCE(sayiir_workflow_tasks.started_at, now())",
106            )
107            .bind(&snapshot.instance_id)
108            .bind(tid)
109            .execute(&mut *tx)
110            .await
111            .map_err(PgError)?;
112        }
113
114        // On terminal states, mark any still-active task as failed/cancelled
115        if terminal {
116            let terminal_status = match SnapshotStatus::from(&snapshot.state) {
117                SnapshotStatus::Failed => "failed",
118                SnapshotStatus::Cancelled => "cancelled",
119                _ => "completed",
120            };
121            sqlx::query(
122                "UPDATE sayiir_workflow_tasks SET status = $1, completed_at = now(), error = $2
123                 WHERE instance_id = $3 AND status = 'active'",
124            )
125            .bind(terminal_status)
126            .bind(&error)
127            .bind(&snapshot.instance_id)
128            .execute(&mut *tx)
129            .await
130            .map_err(PgError)?;
131        }
132
133        tx.commit().await.map_err(PgError)?;
134        tracing::debug!(instance_id = %snapshot.instance_id, "snapshot saved");
135        Ok(())
136    }
137
138    async fn save_task_result(
139        &self,
140        instance_id: &str,
141        task_id: &str,
142        output: bytes::Bytes,
143    ) -> Result<(), BackendError> {
144        tracing::debug!(instance_id, task_id, "saving task result");
145        let mut tx = self.pool.begin().await.map_err(PgError)?;
146
147        // Lock and load the snapshot
148        let row = sqlx::query(
149            "SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
150        )
151        .bind(instance_id)
152        .fetch_optional(&mut *tx)
153        .await
154        .map_err(PgError)?
155        .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
156
157        let raw: &[u8] = row.get("data");
158        let mut snapshot = self.decode(raw)?;
159        snapshot.mark_task_completed(task_id.to_string(), output);
160
161        let data = self.encode(&snapshot)?;
162        let status = status_str(&snapshot.state);
163        let current = current_task_id(&snapshot).map(ToString::to_string);
164        let task_count = completed_task_count(&snapshot);
165
166        sqlx::query(
167            "UPDATE sayiir_workflow_snapshots
168             SET data = $1, status = $2, current_task_id = $3,
169                 completed_task_count = $4, updated_at = now()
170             WHERE instance_id = $5",
171        )
172        .bind(&data)
173        .bind(status)
174        .bind(&current)
175        .bind(task_count)
176        .bind(instance_id)
177        .execute(&mut *tx)
178        .await
179        .map_err(PgError)?;
180
181        // Mark task as completed in sayiir_workflow_tasks
182        sqlx::query(
183            "INSERT INTO sayiir_workflow_tasks (instance_id, task_id, status, completed_at)
184             VALUES ($1, $2, 'completed', now())
185             ON CONFLICT (instance_id, task_id) DO UPDATE SET
186                status = 'completed', completed_at = now(), error = NULL",
187        )
188        .bind(instance_id)
189        .bind(task_id)
190        .execute(&mut *tx)
191        .await
192        .map_err(PgError)?;
193
194        tx.commit().await.map_err(PgError)?;
195        tracing::debug!(instance_id, task_id, "task result saved");
196        Ok(())
197    }
198
199    async fn load_snapshot(&self, instance_id: &str) -> Result<WorkflowSnapshot, BackendError> {
200        tracing::debug!(instance_id, "loading snapshot");
201        let row = sqlx::query("SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1")
202            .bind(instance_id)
203            .fetch_optional(&self.pool)
204            .await
205            .map_err(PgError)?
206            .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
207
208        let raw: &[u8] = row.get("data");
209        self.decode(raw)
210    }
211
212    async fn delete_snapshot(&self, instance_id: &str) -> Result<(), BackendError> {
213        tracing::debug!(instance_id, "deleting snapshot");
214        let result = sqlx::query("DELETE FROM sayiir_workflow_snapshots WHERE instance_id = $1")
215            .bind(instance_id)
216            .execute(&self.pool)
217            .await
218            .map_err(PgError)?;
219
220        if result.rows_affected() == 0 {
221            return Err(BackendError::NotFound(instance_id.to_string()));
222        }
223        Ok(())
224    }
225
226    async fn list_snapshots(&self) -> Result<Vec<String>, BackendError> {
227        tracing::debug!("listing snapshots");
228        let rows = sqlx::query("SELECT instance_id FROM sayiir_workflow_snapshots")
229            .fetch_all(&self.pool)
230            .await
231            .map_err(PgError)?;
232
233        Ok(rows.iter().map(|r| r.get("instance_id")).collect())
234    }
235}