Skip to main content

sayiir_postgres/
snapshot_store.rs

1//! [`SnapshotStore`] implementation for Postgres.
2
3use sayiir_core::codec::{self, Decoder, Encoder};
4use sayiir_core::snapshot::{SnapshotStatus, WorkflowSnapshot};
5use sayiir_persistence::{BackendError, SnapshotStore};
6use sqlx::Row;
7
8use crate::backend::PostgresBackend;
9use crate::error::PgError;
10use crate::helpers::{
11    completed_task_count, current_task_id, delay_wake_at, error_message, is_terminal,
12    position_kind, status_str,
13};
14
15impl<C> SnapshotStore for PostgresBackend<C>
16where
17    C: Encoder
18        + Decoder
19        + codec::sealed::EncodeValue<WorkflowSnapshot>
20        + codec::sealed::DecodeValue<WorkflowSnapshot>,
21{
22    async fn save_snapshot(&self, snapshot: &WorkflowSnapshot) -> Result<(), BackendError> {
23        tracing::debug!(
24            instance_id = %snapshot.instance_id,
25            status = %status_str(&snapshot.state),
26            "saving snapshot"
27        );
28        let data = self.encode(snapshot)?;
29        let status = status_str(&snapshot.state);
30        let task_id = current_task_id(snapshot).map(ToString::to_string);
31        let task_count = completed_task_count(snapshot);
32        let error = error_message(snapshot).map(ToString::to_string);
33        let terminal = is_terminal(snapshot);
34        let pos_kind = position_kind(snapshot);
35        let wake_at = delay_wake_at(snapshot);
36
37        let mut tx = self.pool.begin().await.map_err(PgError)?;
38
39        // Upsert current state
40        sqlx::query(
41            "INSERT INTO sayiir_workflow_snapshots
42                (instance_id, status, definition_hash, current_task_id,
43                 completed_task_count, data, error, position_kind, delay_wake_at,
44                 completed_at, updated_at)
45             VALUES ($1, $2, $3, $4, $5, $6, $7, $9, $10,
46                     CASE WHEN $8 THEN now() ELSE NULL END, now())
47             ON CONFLICT (instance_id) DO UPDATE SET
48                status = $2,
49                definition_hash = $3,
50                current_task_id = $4,
51                completed_task_count = $5,
52                data = $6,
53                error = $7,
54                position_kind = $9,
55                delay_wake_at = $10,
56                completed_at = CASE WHEN $8 THEN now() ELSE sayiir_workflow_snapshots.completed_at END,
57                updated_at = now()",
58        )
59        .bind(&snapshot.instance_id) // $1
60        .bind(status) // $2
61        .bind(&snapshot.definition_hash) // $3
62        .bind(&task_id) // $4
63        .bind(task_count) // $5
64        .bind(&data) // $6
65        .bind(&error) // $7
66        .bind(terminal) // $8
67        .bind(pos_kind) // $9
68        .bind(wake_at) // $10
69        .execute(&mut *tx)
70        .await
71        .map_err(PgError)?;
72
73        // Append to history
74        sqlx::query(
75            "INSERT INTO sayiir_workflow_snapshot_history
76                (instance_id, version, status, current_task_id, data)
77             VALUES (
78                $1,
79                (SELECT COALESCE(MAX(version), 0) + 1
80                 FROM sayiir_workflow_snapshot_history WHERE instance_id = $1),
81                $2, $3, $4
82             )",
83        )
84        .bind(&snapshot.instance_id)
85        .bind(status)
86        .bind(&task_id)
87        .bind(&data)
88        .execute(&mut *tx)
89        .await
90        .map_err(PgError)?;
91
92        // --- Maintain sayiir_workflow_tasks lifecycle ---
93
94        // If at a task, mark it as active
95        if let Some(ref tid) = task_id {
96            sqlx::query(
97                "INSERT INTO sayiir_workflow_tasks (instance_id, task_id, status, started_at)
98                 VALUES ($1, $2, 'active', now())
99                 ON CONFLICT (instance_id, task_id) DO UPDATE SET
100                    status = CASE
101                        WHEN sayiir_workflow_tasks.status = 'completed' THEN sayiir_workflow_tasks.status
102                        ELSE 'active'
103                    END,
104                    started_at = COALESCE(sayiir_workflow_tasks.started_at, now())",
105            )
106            .bind(&snapshot.instance_id)
107            .bind(tid)
108            .execute(&mut *tx)
109            .await
110            .map_err(PgError)?;
111        }
112
113        // On terminal states, mark any still-active task as failed/cancelled
114        if terminal {
115            let terminal_status = match SnapshotStatus::from(&snapshot.state) {
116                SnapshotStatus::Failed => "failed",
117                SnapshotStatus::Cancelled => "cancelled",
118                _ => "completed",
119            };
120            sqlx::query(
121                "UPDATE sayiir_workflow_tasks SET status = $1, completed_at = now(), error = $2
122                 WHERE instance_id = $3 AND status = 'active'",
123            )
124            .bind(terminal_status)
125            .bind(&error)
126            .bind(&snapshot.instance_id)
127            .execute(&mut *tx)
128            .await
129            .map_err(PgError)?;
130        }
131
132        tx.commit().await.map_err(PgError)?;
133        tracing::debug!(instance_id = %snapshot.instance_id, "snapshot saved");
134        Ok(())
135    }
136
137    async fn save_task_result(
138        &self,
139        instance_id: &str,
140        task_id: &str,
141        output: bytes::Bytes,
142    ) -> Result<(), BackendError> {
143        tracing::debug!(instance_id, task_id, "saving task result");
144        let mut tx = self.pool.begin().await.map_err(PgError)?;
145
146        // Lock and load the snapshot
147        let row = sqlx::query(
148            "SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1 FOR UPDATE",
149        )
150        .bind(instance_id)
151        .fetch_optional(&mut *tx)
152        .await
153        .map_err(PgError)?
154        .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
155
156        let raw: &[u8] = row.get("data");
157        let mut snapshot = self.decode(raw)?;
158        snapshot.mark_task_completed(task_id.to_string(), output);
159
160        let data = self.encode(&snapshot)?;
161        let status = status_str(&snapshot.state);
162        let current = current_task_id(&snapshot).map(ToString::to_string);
163        let task_count = completed_task_count(&snapshot);
164
165        sqlx::query(
166            "UPDATE sayiir_workflow_snapshots
167             SET data = $1, status = $2, current_task_id = $3,
168                 completed_task_count = $4, updated_at = now()
169             WHERE instance_id = $5",
170        )
171        .bind(&data)
172        .bind(status)
173        .bind(&current)
174        .bind(task_count)
175        .bind(instance_id)
176        .execute(&mut *tx)
177        .await
178        .map_err(PgError)?;
179
180        // Mark task as completed in sayiir_workflow_tasks
181        sqlx::query(
182            "INSERT INTO sayiir_workflow_tasks (instance_id, task_id, status, completed_at)
183             VALUES ($1, $2, 'completed', now())
184             ON CONFLICT (instance_id, task_id) DO UPDATE SET
185                status = 'completed', completed_at = now(), error = NULL",
186        )
187        .bind(instance_id)
188        .bind(task_id)
189        .execute(&mut *tx)
190        .await
191        .map_err(PgError)?;
192
193        tx.commit().await.map_err(PgError)?;
194        tracing::debug!(instance_id, task_id, "task result saved");
195        Ok(())
196    }
197
198    async fn load_snapshot(&self, instance_id: &str) -> Result<WorkflowSnapshot, BackendError> {
199        tracing::debug!(instance_id, "loading snapshot");
200        let row = sqlx::query("SELECT data FROM sayiir_workflow_snapshots WHERE instance_id = $1")
201            .bind(instance_id)
202            .fetch_optional(&self.pool)
203            .await
204            .map_err(PgError)?
205            .ok_or_else(|| BackendError::NotFound(instance_id.to_string()))?;
206
207        let raw: &[u8] = row.get("data");
208        self.decode(raw)
209    }
210
211    async fn delete_snapshot(&self, instance_id: &str) -> Result<(), BackendError> {
212        tracing::debug!(instance_id, "deleting snapshot");
213        let result = sqlx::query("DELETE FROM sayiir_workflow_snapshots WHERE instance_id = $1")
214            .bind(instance_id)
215            .execute(&self.pool)
216            .await
217            .map_err(PgError)?;
218
219        if result.rows_affected() == 0 {
220            return Err(BackendError::NotFound(instance_id.to_string()));
221        }
222        Ok(())
223    }
224
225    async fn list_snapshots(&self) -> Result<Vec<String>, BackendError> {
226        tracing::debug!("listing snapshots");
227        let rows = sqlx::query("SELECT instance_id FROM sayiir_workflow_snapshots")
228            .fetch_all(&self.pool)
229            .await
230            .map_err(PgError)?;
231
232        Ok(rows.iter().map(|r| r.get("instance_id")).collect())
233    }
234}