Skip to main content

sayiir_postgres/
task_claim_store.rs

1//! [`TaskClaimStore`] implementation for Postgres.
2
3use chrono::{Duration, Utc};
4use sayiir_core::codec::{self, Decoder, Encoder};
5use sayiir_core::snapshot::{ExecutionPosition, WorkflowSnapshot, WorkflowSnapshotState};
6use sayiir_core::task_claim::{AvailableTask, TaskClaim};
7use sayiir_persistence::{BackendError, SnapshotStore, TaskClaimStore};
8use sqlx::Row;
9
10use crate::backend::PostgresBackend;
11use crate::error::PgError;
12
13impl<C> TaskClaimStore for PostgresBackend<C>
14where
15    C: Encoder
16        + Decoder
17        + codec::sealed::EncodeValue<WorkflowSnapshot>
18        + codec::sealed::DecodeValue<WorkflowSnapshot>,
19{
20    async fn claim_task(
21        &self,
22        instance_id: &str,
23        task_id: &str,
24        worker_id: &str,
25        ttl: Option<Duration>,
26    ) -> Result<Option<TaskClaim>, BackendError> {
27        tracing::debug!(instance_id, task_id, worker_id, "claiming task");
28        let expires_at = ttl.and_then(|d| Utc::now().checked_add_signed(d));
29
30        // Insert claim; on conflict only replace if the existing claim has expired.
31        let row = sqlx::query(
32            "INSERT INTO sayiir_task_claims (instance_id, task_id, worker_id, expires_at)
33             VALUES ($1, $2, $3, $4)
34             ON CONFLICT (instance_id, task_id) DO UPDATE
35                SET worker_id = $3, claimed_at = now(), expires_at = $4
36                WHERE sayiir_task_claims.expires_at IS NOT NULL AND sayiir_task_claims.expires_at < now()
37             RETURNING instance_id, task_id, worker_id,
38                       EXTRACT(EPOCH FROM claimed_at)::BIGINT AS claimed_epoch,
39                       EXTRACT(EPOCH FROM expires_at)::BIGINT AS expires_epoch",
40        )
41        .bind(instance_id)
42        .bind(task_id)
43        .bind(worker_id)
44        .bind(expires_at)
45        .fetch_optional(&self.pool)
46        .await
47        .map_err(PgError)?;
48
49        Ok(row.map(|r| TaskClaim {
50            instance_id: r.get("instance_id"),
51            task_id: r.get("task_id"),
52            worker_id: r.get("worker_id"),
53            claimed_at: r.get::<i64, _>("claimed_epoch") as u64,
54            expires_at: r.get::<Option<i64>, _>("expires_epoch").map(|e| e as u64),
55        }))
56    }
57
58    async fn release_task_claim(
59        &self,
60        instance_id: &str,
61        task_id: &str,
62        worker_id: &str,
63    ) -> Result<(), BackendError> {
64        tracing::debug!(instance_id, task_id, worker_id, "releasing task claim");
65        // Check ownership first
66        let row = sqlx::query(
67            "SELECT worker_id FROM sayiir_task_claims WHERE instance_id = $1 AND task_id = $2",
68        )
69        .bind(instance_id)
70        .bind(task_id)
71        .fetch_optional(&self.pool)
72        .await
73        .map_err(PgError)?
74        .ok_or_else(|| BackendError::NotFound(format!("{instance_id}:{task_id}")))?;
75
76        let owner: String = row.get("worker_id");
77        if owner != worker_id {
78            return Err(BackendError::Backend(format!(
79                "Claim owned by different worker: {owner}"
80            )));
81        }
82
83        sqlx::query(
84            "DELETE FROM sayiir_task_claims
85             WHERE instance_id = $1 AND task_id = $2 AND worker_id = $3",
86        )
87        .bind(instance_id)
88        .bind(task_id)
89        .bind(worker_id)
90        .execute(&self.pool)
91        .await
92        .map_err(PgError)?;
93
94        Ok(())
95    }
96
97    async fn extend_task_claim(
98        &self,
99        instance_id: &str,
100        task_id: &str,
101        worker_id: &str,
102        additional_duration: Duration,
103    ) -> Result<(), BackendError> {
104        tracing::debug!(instance_id, task_id, worker_id, "extending task claim");
105        let row = sqlx::query(
106            "SELECT worker_id, expires_at FROM sayiir_task_claims
107             WHERE instance_id = $1 AND task_id = $2",
108        )
109        .bind(instance_id)
110        .bind(task_id)
111        .fetch_optional(&self.pool)
112        .await
113        .map_err(PgError)?
114        .ok_or_else(|| BackendError::NotFound(format!("{instance_id}:{task_id}")))?;
115
116        let owner: String = row.get("worker_id");
117        if owner != worker_id {
118            return Err(BackendError::Backend(format!(
119                "Claim owned by different worker: {owner}"
120            )));
121        }
122
123        // Only extend if there's an expiration set
124        let expires_at: Option<chrono::DateTime<Utc>> = row.get("expires_at");
125        if let Some(exp) = expires_at {
126            let new_exp = exp
127                .checked_add_signed(additional_duration)
128                .ok_or_else(|| BackendError::Backend("Time overflow".to_string()))?;
129
130            sqlx::query(
131                "UPDATE sayiir_task_claims SET expires_at = $1
132                 WHERE instance_id = $2 AND task_id = $3",
133            )
134            .bind(new_exp)
135            .bind(instance_id)
136            .bind(task_id)
137            .execute(&self.pool)
138            .await
139            .map_err(PgError)?;
140        }
141
142        Ok(())
143    }
144
145    async fn find_available_tasks(
146        &self,
147        worker_id: &str,
148        limit: usize,
149    ) -> Result<Vec<AvailableTask>, BackendError> {
150        tracing::debug!(worker_id, limit, "finding available tasks");
151        // Step 1: Clean expired claims
152        sqlx::query(
153            "DELETE FROM sayiir_task_claims WHERE expires_at IS NOT NULL AND expires_at < now()",
154        )
155        .execute(&self.pool)
156        .await
157        .map_err(PgError)?;
158
159        // Step 2: Fetch candidate workflows via SQL bulk filter
160        let rows = sqlx::query(
161            "SELECT s.instance_id, s.data
162             FROM sayiir_workflow_snapshots s
163             WHERE s.status = 'InProgress'
164               AND NOT EXISTS (
165                   SELECT 1 FROM sayiir_task_claims c
166                   WHERE c.instance_id = s.instance_id
167                     AND c.task_id = s.current_task_id
168                     AND (c.expires_at IS NULL OR c.expires_at > now())
169               )
170               AND NOT EXISTS (
171                   SELECT 1 FROM sayiir_workflow_signals sig
172                   WHERE sig.instance_id = s.instance_id
173               )
174             ORDER BY s.updated_at ASC
175             LIMIT $1",
176        )
177        .bind(limit as i64)
178        .fetch_all(&self.pool)
179        .await
180        .map_err(PgError)?;
181
182        // Step 3: App-level evaluation per candidate
183        let mut available = Vec::new();
184        for row in &rows {
185            let raw: &[u8] = row.get("data");
186            let mut snapshot = self.decode(raw)?;
187
188            match &snapshot.state {
189                // Delay: if expired, advance past it
190                WorkflowSnapshotState::InProgress {
191                    position:
192                        ExecutionPosition::AtDelay {
193                            wake_at,
194                            next_task_id,
195                            delay_id,
196                            ..
197                        },
198                    ..
199                } if Utc::now() >= *wake_at => {
200                    if let Some(next_id) = next_task_id.clone() {
201                        snapshot.update_position(ExecutionPosition::AtTask { task_id: next_id });
202                        self.save_snapshot(&snapshot).await?;
203
204                        if let WorkflowSnapshotState::InProgress {
205                            position: ExecutionPosition::AtTask { task_id },
206                            completed_tasks,
207                            ..
208                        } = &snapshot.state
209                            && let Some(task) =
210                                build_available_task(&snapshot, task_id, completed_tasks, worker_id)
211                        {
212                            available.push(task);
213                        }
214                    } else {
215                        // Delay is the last node — complete the workflow
216                        let output = snapshot.get_task_result_bytes(delay_id).unwrap_or_default();
217                        snapshot.mark_completed(output);
218                        self.save_snapshot(&snapshot).await?;
219                    }
220                }
221
222                // Task: check retry backoff, then add to available
223                WorkflowSnapshotState::InProgress {
224                    position: ExecutionPosition::AtTask { task_id },
225                    completed_tasks,
226                    ..
227                } => {
228                    // Skip if task is already completed
229                    if completed_tasks.contains_key(task_id) {
230                        continue;
231                    }
232
233                    // Skip if retry backoff hasn't elapsed
234                    if let Some(rs) = snapshot.task_retries.get(task_id)
235                        && Utc::now() < rs.next_retry_at
236                    {
237                        continue;
238                    }
239
240                    if let Some(task) =
241                        build_available_task(&snapshot, task_id, completed_tasks, worker_id)
242                    {
243                        available.push(task);
244                    }
245                }
246
247                _ => continue,
248            }
249
250            if available.len() >= limit {
251                break;
252            }
253        }
254
255        tracing::debug!(worker_id, count = available.len(), "available tasks found");
256        Ok(available)
257    }
258}
259
260/// Build an [`AvailableTask`] from a snapshot at a task position.
261fn build_available_task(
262    snapshot: &WorkflowSnapshot,
263    task_id: &str,
264    completed_tasks: &std::collections::HashMap<String, sayiir_core::snapshot::TaskResult>,
265    _worker_id: &str,
266) -> Option<AvailableTask> {
267    let input = if completed_tasks.is_empty() {
268        snapshot.initial_input_bytes()
269    } else {
270        snapshot.get_last_task_output()
271    };
272
273    input.map(|input_bytes| AvailableTask {
274        instance_id: snapshot.instance_id.clone(),
275        task_id: task_id.to_string(),
276        input: input_bytes,
277        workflow_definition_hash: snapshot.definition_hash.clone(),
278    })
279}