Skip to main content

sayiir_postgres/
task_claim_store.rs

1//! [`TaskClaimStore`] implementation for Postgres.
2
3use chrono::{Duration, Utc};
4use sayiir_core::codec::{self, Decoder, Encoder};
5use sayiir_core::snapshot::{ExecutionPosition, WorkflowSnapshot, WorkflowSnapshotState};
6use sayiir_core::task_claim::{AvailableTask, TaskClaim};
7use sayiir_persistence::{BackendError, SnapshotStore, TaskClaimStore};
8use sqlx::Row;
9
10use crate::backend::PostgresBackend;
11use crate::error::PgError;
12
13impl<C> TaskClaimStore for PostgresBackend<C>
14where
15    C: Encoder
16        + Decoder
17        + codec::sealed::EncodeValue<WorkflowSnapshot>
18        + codec::sealed::DecodeValue<WorkflowSnapshot>,
19{
20    #[tracing::instrument(
21        name = "db.claim_task",
22        skip(self),
23        fields(db.system = "postgresql"),
24        err(level = tracing::Level::ERROR),
25    )]
26    async fn claim_task(
27        &self,
28        instance_id: &str,
29        task_id: &str,
30        worker_id: &str,
31        ttl: Option<Duration>,
32    ) -> Result<Option<TaskClaim>, BackendError> {
33        tracing::debug!("claiming task");
34        let expires_at = ttl.and_then(|d| Utc::now().checked_add_signed(d));
35
36        // Insert claim; on conflict only replace if the existing claim has expired.
37        let row = sqlx::query(
38            "INSERT INTO sayiir_task_claims (instance_id, task_id, worker_id, expires_at)
39             VALUES ($1, $2, $3, $4)
40             ON CONFLICT (instance_id, task_id) DO UPDATE
41                SET worker_id = $3, claimed_at = now(), expires_at = $4
42                WHERE sayiir_task_claims.expires_at IS NOT NULL AND sayiir_task_claims.expires_at < now()
43             RETURNING instance_id, task_id, worker_id,
44                       EXTRACT(EPOCH FROM claimed_at)::BIGINT AS claimed_epoch,
45                       EXTRACT(EPOCH FROM expires_at)::BIGINT AS expires_epoch",
46        )
47        .bind(instance_id)
48        .bind(task_id)
49        .bind(worker_id)
50        .bind(expires_at)
51        .fetch_optional(&self.pool)
52        .await
53        .map_err(PgError)?;
54
55        Ok(row.map(|r| TaskClaim {
56            instance_id: r.get("instance_id"),
57            task_id: r.get("task_id"),
58            worker_id: r.get("worker_id"),
59            claimed_at: r.get::<i64, _>("claimed_epoch").cast_unsigned(),
60            expires_at: r
61                .get::<Option<i64>, _>("expires_epoch")
62                .map(i64::cast_unsigned),
63        }))
64    }
65
66    #[tracing::instrument(
67        name = "db.release_task_claim",
68        skip(self),
69        fields(db.system = "postgresql"),
70        err(level = tracing::Level::ERROR),
71    )]
72    async fn release_task_claim(
73        &self,
74        instance_id: &str,
75        task_id: &str,
76        worker_id: &str,
77    ) -> Result<(), BackendError> {
78        tracing::debug!("releasing task claim");
79        // Check ownership first
80        let row = sqlx::query(
81            "SELECT worker_id FROM sayiir_task_claims WHERE instance_id = $1 AND task_id = $2",
82        )
83        .bind(instance_id)
84        .bind(task_id)
85        .fetch_optional(&self.pool)
86        .await
87        .map_err(PgError)?
88        .ok_or_else(|| BackendError::NotFound(format!("{instance_id}:{task_id}")))?;
89
90        let owner: String = row.get("worker_id");
91        if owner != worker_id {
92            return Err(BackendError::Backend(format!(
93                "Claim owned by different worker: {owner}"
94            )));
95        }
96
97        sqlx::query(
98            "DELETE FROM sayiir_task_claims
99             WHERE instance_id = $1 AND task_id = $2 AND worker_id = $3",
100        )
101        .bind(instance_id)
102        .bind(task_id)
103        .bind(worker_id)
104        .execute(&self.pool)
105        .await
106        .map_err(PgError)?;
107
108        Ok(())
109    }
110
111    #[tracing::instrument(
112        name = "db.extend_task_claim",
113        skip(self),
114        fields(db.system = "postgresql"),
115        err(level = tracing::Level::ERROR),
116    )]
117    async fn extend_task_claim(
118        &self,
119        instance_id: &str,
120        task_id: &str,
121        worker_id: &str,
122        additional_duration: Duration,
123    ) -> Result<(), BackendError> {
124        tracing::debug!("extending task claim");
125        let row = sqlx::query(
126            "SELECT worker_id, expires_at FROM sayiir_task_claims
127             WHERE instance_id = $1 AND task_id = $2",
128        )
129        .bind(instance_id)
130        .bind(task_id)
131        .fetch_optional(&self.pool)
132        .await
133        .map_err(PgError)?
134        .ok_or_else(|| BackendError::NotFound(format!("{instance_id}:{task_id}")))?;
135
136        let owner: String = row.get("worker_id");
137        if owner != worker_id {
138            return Err(BackendError::Backend(format!(
139                "Claim owned by different worker: {owner}"
140            )));
141        }
142
143        // Only extend if there's an expiration set
144        let expires_at: Option<chrono::DateTime<Utc>> = row.get("expires_at");
145        if let Some(exp) = expires_at {
146            let new_exp = exp
147                .checked_add_signed(additional_duration)
148                .ok_or_else(|| BackendError::Backend("Time overflow".to_string()))?;
149
150            sqlx::query(
151                "UPDATE sayiir_task_claims SET expires_at = $1
152                 WHERE instance_id = $2 AND task_id = $3",
153            )
154            .bind(new_exp)
155            .bind(instance_id)
156            .bind(task_id)
157            .execute(&self.pool)
158            .await
159            .map_err(PgError)?;
160        }
161
162        Ok(())
163    }
164
165    #[tracing::instrument(
166        name = "db.find_available_tasks",
167        skip(self),
168        fields(db.system = "postgresql"),
169        err(level = tracing::Level::ERROR),
170    )]
171    async fn find_available_tasks(
172        &self,
173        worker_id: &str,
174        limit: usize,
175    ) -> Result<Vec<AvailableTask>, BackendError> {
176        // Step 1: Clean expired claims
177        sqlx::query(
178            "DELETE FROM sayiir_task_claims WHERE expires_at IS NOT NULL AND expires_at < now()",
179        )
180        .execute(&self.pool)
181        .await
182        .map_err(PgError)?;
183
184        // Step 2: Fetch candidate workflows via SQL bulk filter
185        let rows = sqlx::query(
186            "SELECT s.instance_id, s.data, s.trace_parent
187             FROM sayiir_workflow_snapshots s
188             WHERE s.status = 'InProgress'
189               AND NOT EXISTS (
190                   SELECT 1 FROM sayiir_task_claims c
191                   WHERE c.instance_id = s.instance_id
192                     AND c.task_id = s.current_task_id
193                     AND (c.expires_at IS NULL OR c.expires_at > now())
194               )
195               AND NOT EXISTS (
196                   SELECT 1 FROM sayiir_workflow_signals sig
197                   WHERE sig.instance_id = s.instance_id
198               )
199             ORDER BY s.updated_at ASC
200             LIMIT $1",
201        )
202        .bind(i64::try_from(limit).unwrap_or(i64::MAX))
203        .fetch_all(&self.pool)
204        .await
205        .map_err(PgError)?;
206
207        // Step 3: App-level evaluation per candidate
208        let mut available = Vec::new();
209        for row in &rows {
210            let raw: &[u8] = row.get("data");
211            let mut snapshot = self.decode(raw)?;
212            snapshot.trace_parent = row.get("trace_parent");
213
214            match &snapshot.state {
215                // Delay: if expired, advance past it
216                WorkflowSnapshotState::InProgress {
217                    position:
218                        ExecutionPosition::AtDelay {
219                            wake_at,
220                            next_task_id,
221                            delay_id,
222                            ..
223                        },
224                    ..
225                } if Utc::now() >= *wake_at => {
226                    if let Some(next_id) = next_task_id.clone() {
227                        snapshot.update_position(ExecutionPosition::AtTask { task_id: next_id });
228                        self.save_snapshot(&snapshot).await?;
229
230                        if let WorkflowSnapshotState::InProgress {
231                            position: ExecutionPosition::AtTask { task_id },
232                            completed_tasks,
233                            ..
234                        } = &snapshot.state
235                            && let Some(task) =
236                                build_available_task(&snapshot, task_id, completed_tasks, worker_id)
237                        {
238                            available.push(task);
239                        }
240                    } else {
241                        // Delay is the last node — complete the workflow
242                        let output = snapshot.get_task_result_bytes(delay_id).unwrap_or_default();
243                        snapshot.mark_completed(output);
244                        self.save_snapshot(&snapshot).await?;
245                    }
246                }
247
248                // Task: check retry backoff, then add to available
249                WorkflowSnapshotState::InProgress {
250                    position: ExecutionPosition::AtTask { task_id },
251                    completed_tasks,
252                    ..
253                } => {
254                    // Skip if task is already completed
255                    if completed_tasks.contains_key(task_id) {
256                        continue;
257                    }
258
259                    // Skip if retry backoff hasn't elapsed
260                    if let Some(rs) = snapshot.task_retries.get(task_id)
261                        && Utc::now() < rs.next_retry_at
262                    {
263                        continue;
264                    }
265
266                    if let Some(task) =
267                        build_available_task(&snapshot, task_id, completed_tasks, worker_id)
268                    {
269                        available.push(task);
270                    }
271                }
272
273                _ => continue,
274            }
275
276            if available.len() >= limit {
277                break;
278            }
279        }
280
281        tracing::debug!(count = available.len(), "available tasks found");
282        Ok(available)
283    }
284}
285
286/// Build an [`AvailableTask`] from a snapshot at a task position.
287fn build_available_task(
288    snapshot: &WorkflowSnapshot,
289    task_id: &str,
290    completed_tasks: &std::collections::HashMap<String, sayiir_core::snapshot::TaskResult>,
291    _worker_id: &str,
292) -> Option<AvailableTask> {
293    let input = if completed_tasks.is_empty() {
294        snapshot.initial_input_bytes()
295    } else {
296        snapshot.get_last_task_output()
297    };
298
299    input.map(|input_bytes| AvailableTask {
300        instance_id: snapshot.instance_id.clone(),
301        task_id: task_id.to_string(),
302        input: input_bytes,
303        workflow_definition_hash: snapshot.definition_hash.clone(),
304        trace_parent: snapshot.trace_parent.clone(),
305    })
306}