Skip to main content

sayiir_postgres/
task_claim_store.rs

1//! [`TaskClaimStore`] implementation for Postgres.
2
3use chrono::{Duration, Utc};
4use sayiir_core::codec::{self, Decoder, Encoder};
5use sayiir_core::snapshot::{ExecutionPosition, WorkflowSnapshot, WorkflowSnapshotState};
6use sayiir_core::task_claim::{AvailableTask, TaskClaim};
7use sayiir_persistence::{BackendError, SnapshotStore, TaskClaimStore};
8use sqlx::Row;
9
10use crate::backend::PostgresBackend;
11use crate::error::PgError;
12
13impl<C> TaskClaimStore for PostgresBackend<C>
14where
15    C: Encoder
16        + Decoder
17        + codec::sealed::EncodeValue<WorkflowSnapshot>
18        + codec::sealed::DecodeValue<WorkflowSnapshot>,
19{
20    async fn claim_task(
21        &self,
22        instance_id: &str,
23        task_id: &str,
24        worker_id: &str,
25        ttl: Option<Duration>,
26    ) -> Result<Option<TaskClaim>, BackendError> {
27        tracing::debug!(instance_id, task_id, worker_id, "claiming task");
28        let expires_at = ttl.and_then(|d| Utc::now().checked_add_signed(d));
29
30        // Insert claim; on conflict only replace if the existing claim has expired.
31        let row = sqlx::query(
32            "INSERT INTO sayiir_task_claims (instance_id, task_id, worker_id, expires_at)
33             VALUES ($1, $2, $3, $4)
34             ON CONFLICT (instance_id, task_id) DO UPDATE
35                SET worker_id = $3, claimed_at = now(), expires_at = $4
36                WHERE sayiir_task_claims.expires_at IS NOT NULL AND sayiir_task_claims.expires_at < now()
37             RETURNING instance_id, task_id, worker_id,
38                       EXTRACT(EPOCH FROM claimed_at)::BIGINT AS claimed_epoch,
39                       EXTRACT(EPOCH FROM expires_at)::BIGINT AS expires_epoch",
40        )
41        .bind(instance_id)
42        .bind(task_id)
43        .bind(worker_id)
44        .bind(expires_at)
45        .fetch_optional(&self.pool)
46        .await
47        .map_err(PgError)?;
48
49        Ok(row.map(|r| TaskClaim {
50            instance_id: r.get("instance_id"),
51            task_id: r.get("task_id"),
52            worker_id: r.get("worker_id"),
53            claimed_at: r.get::<i64, _>("claimed_epoch").cast_unsigned(),
54            expires_at: r
55                .get::<Option<i64>, _>("expires_epoch")
56                .map(i64::cast_unsigned),
57        }))
58    }
59
60    async fn release_task_claim(
61        &self,
62        instance_id: &str,
63        task_id: &str,
64        worker_id: &str,
65    ) -> Result<(), BackendError> {
66        tracing::debug!(instance_id, task_id, worker_id, "releasing task claim");
67        // Check ownership first
68        let row = sqlx::query(
69            "SELECT worker_id FROM sayiir_task_claims WHERE instance_id = $1 AND task_id = $2",
70        )
71        .bind(instance_id)
72        .bind(task_id)
73        .fetch_optional(&self.pool)
74        .await
75        .map_err(PgError)?
76        .ok_or_else(|| BackendError::NotFound(format!("{instance_id}:{task_id}")))?;
77
78        let owner: String = row.get("worker_id");
79        if owner != worker_id {
80            return Err(BackendError::Backend(format!(
81                "Claim owned by different worker: {owner}"
82            )));
83        }
84
85        sqlx::query(
86            "DELETE FROM sayiir_task_claims
87             WHERE instance_id = $1 AND task_id = $2 AND worker_id = $3",
88        )
89        .bind(instance_id)
90        .bind(task_id)
91        .bind(worker_id)
92        .execute(&self.pool)
93        .await
94        .map_err(PgError)?;
95
96        Ok(())
97    }
98
99    async fn extend_task_claim(
100        &self,
101        instance_id: &str,
102        task_id: &str,
103        worker_id: &str,
104        additional_duration: Duration,
105    ) -> Result<(), BackendError> {
106        tracing::debug!(instance_id, task_id, worker_id, "extending task claim");
107        let row = sqlx::query(
108            "SELECT worker_id, expires_at FROM sayiir_task_claims
109             WHERE instance_id = $1 AND task_id = $2",
110        )
111        .bind(instance_id)
112        .bind(task_id)
113        .fetch_optional(&self.pool)
114        .await
115        .map_err(PgError)?
116        .ok_or_else(|| BackendError::NotFound(format!("{instance_id}:{task_id}")))?;
117
118        let owner: String = row.get("worker_id");
119        if owner != worker_id {
120            return Err(BackendError::Backend(format!(
121                "Claim owned by different worker: {owner}"
122            )));
123        }
124
125        // Only extend if there's an expiration set
126        let expires_at: Option<chrono::DateTime<Utc>> = row.get("expires_at");
127        if let Some(exp) = expires_at {
128            let new_exp = exp
129                .checked_add_signed(additional_duration)
130                .ok_or_else(|| BackendError::Backend("Time overflow".to_string()))?;
131
132            sqlx::query(
133                "UPDATE sayiir_task_claims SET expires_at = $1
134                 WHERE instance_id = $2 AND task_id = $3",
135            )
136            .bind(new_exp)
137            .bind(instance_id)
138            .bind(task_id)
139            .execute(&self.pool)
140            .await
141            .map_err(PgError)?;
142        }
143
144        Ok(())
145    }
146
147    async fn find_available_tasks(
148        &self,
149        worker_id: &str,
150        limit: usize,
151    ) -> Result<Vec<AvailableTask>, BackendError> {
152        tracing::debug!(worker_id, limit, "finding available tasks");
153        // Step 1: Clean expired claims
154        sqlx::query(
155            "DELETE FROM sayiir_task_claims WHERE expires_at IS NOT NULL AND expires_at < now()",
156        )
157        .execute(&self.pool)
158        .await
159        .map_err(PgError)?;
160
161        // Step 2: Fetch candidate workflows via SQL bulk filter
162        let rows = sqlx::query(
163            "SELECT s.instance_id, s.data
164             FROM sayiir_workflow_snapshots s
165             WHERE s.status = 'InProgress'
166               AND NOT EXISTS (
167                   SELECT 1 FROM sayiir_task_claims c
168                   WHERE c.instance_id = s.instance_id
169                     AND c.task_id = s.current_task_id
170                     AND (c.expires_at IS NULL OR c.expires_at > now())
171               )
172               AND NOT EXISTS (
173                   SELECT 1 FROM sayiir_workflow_signals sig
174                   WHERE sig.instance_id = s.instance_id
175               )
176             ORDER BY s.updated_at ASC
177             LIMIT $1",
178        )
179        .bind(i64::try_from(limit).unwrap_or(i64::MAX))
180        .fetch_all(&self.pool)
181        .await
182        .map_err(PgError)?;
183
184        // Step 3: App-level evaluation per candidate
185        let mut available = Vec::new();
186        for row in &rows {
187            let raw: &[u8] = row.get("data");
188            let mut snapshot = self.decode(raw)?;
189
190            match &snapshot.state {
191                // Delay: if expired, advance past it
192                WorkflowSnapshotState::InProgress {
193                    position:
194                        ExecutionPosition::AtDelay {
195                            wake_at,
196                            next_task_id,
197                            delay_id,
198                            ..
199                        },
200                    ..
201                } if Utc::now() >= *wake_at => {
202                    if let Some(next_id) = next_task_id.clone() {
203                        snapshot.update_position(ExecutionPosition::AtTask { task_id: next_id });
204                        self.save_snapshot(&snapshot).await?;
205
206                        if let WorkflowSnapshotState::InProgress {
207                            position: ExecutionPosition::AtTask { task_id },
208                            completed_tasks,
209                            ..
210                        } = &snapshot.state
211                            && let Some(task) =
212                                build_available_task(&snapshot, task_id, completed_tasks, worker_id)
213                        {
214                            available.push(task);
215                        }
216                    } else {
217                        // Delay is the last node — complete the workflow
218                        let output = snapshot.get_task_result_bytes(delay_id).unwrap_or_default();
219                        snapshot.mark_completed(output);
220                        self.save_snapshot(&snapshot).await?;
221                    }
222                }
223
224                // Task: check retry backoff, then add to available
225                WorkflowSnapshotState::InProgress {
226                    position: ExecutionPosition::AtTask { task_id },
227                    completed_tasks,
228                    ..
229                } => {
230                    // Skip if task is already completed
231                    if completed_tasks.contains_key(task_id) {
232                        continue;
233                    }
234
235                    // Skip if retry backoff hasn't elapsed
236                    if let Some(rs) = snapshot.task_retries.get(task_id)
237                        && Utc::now() < rs.next_retry_at
238                    {
239                        continue;
240                    }
241
242                    if let Some(task) =
243                        build_available_task(&snapshot, task_id, completed_tasks, worker_id)
244                    {
245                        available.push(task);
246                    }
247                }
248
249                _ => continue,
250            }
251
252            if available.len() >= limit {
253                break;
254            }
255        }
256
257        tracing::debug!(worker_id, count = available.len(), "available tasks found");
258        Ok(available)
259    }
260}
261
262/// Build an [`AvailableTask`] from a snapshot at a task position.
263fn build_available_task(
264    snapshot: &WorkflowSnapshot,
265    task_id: &str,
266    completed_tasks: &std::collections::HashMap<String, sayiir_core::snapshot::TaskResult>,
267    _worker_id: &str,
268) -> Option<AvailableTask> {
269    let input = if completed_tasks.is_empty() {
270        snapshot.initial_input_bytes()
271    } else {
272        snapshot.get_last_task_output()
273    };
274
275    input.map(|input_bytes| AvailableTask {
276        instance_id: snapshot.instance_id.clone(),
277        task_id: task_id.to_string(),
278        input: input_bytes,
279        workflow_definition_hash: snapshot.definition_hash.clone(),
280    })
281}