Skip to main content

sayiir_postgres/
task_claim_store.rs

1//! [`TaskClaimStore`] implementation for Postgres.
2
3use chrono::{Duration, Utc};
4use sayiir_core::codec::{self, Decoder, Encoder};
5use sayiir_core::snapshot::{ExecutionPosition, WorkflowSnapshot, WorkflowSnapshotState};
6use sayiir_core::task_claim::{AvailableTask, TaskClaim};
7use sayiir_persistence::{BackendError, SnapshotStore, TaskClaimStore};
8use sqlx::Row;
9
10use crate::backend::PostgresBackend;
11use crate::error::PgError;
12
13impl<C> TaskClaimStore for PostgresBackend<C>
14where
15    C: Encoder
16        + Decoder
17        + codec::sealed::EncodeValue<WorkflowSnapshot>
18        + codec::sealed::DecodeValue<WorkflowSnapshot>,
19{
20    #[tracing::instrument(
21        name = "db.claim_task",
22        skip(self),
23        fields(db.system = "postgresql"),
24        err(level = tracing::Level::ERROR),
25    )]
26    async fn claim_task(
27        &self,
28        instance_id: &str,
29        task_id: &str,
30        worker_id: &str,
31        ttl: Option<Duration>,
32    ) -> Result<Option<TaskClaim>, BackendError> {
33        tracing::debug!("claiming task");
34        let expires_at = ttl.and_then(|d| Utc::now().checked_add_signed(d));
35
36        // Insert claim; on conflict only replace if the existing claim has expired.
37        let row = sqlx::query(
38            "INSERT INTO sayiir_task_claims (instance_id, task_id, worker_id, expires_at)
39             VALUES ($1, $2, $3, $4)
40             ON CONFLICT (instance_id, task_id) DO UPDATE
41                SET worker_id = $3, claimed_at = now(), expires_at = $4
42                WHERE sayiir_task_claims.expires_at IS NOT NULL AND sayiir_task_claims.expires_at < now()
43             RETURNING instance_id, task_id, worker_id,
44                       EXTRACT(EPOCH FROM claimed_at)::BIGINT AS claimed_epoch,
45                       EXTRACT(EPOCH FROM expires_at)::BIGINT AS expires_epoch",
46        )
47        .bind(instance_id)
48        .bind(task_id)
49        .bind(worker_id)
50        .bind(expires_at)
51        .fetch_optional(&self.pool)
52        .await
53        .map_err(PgError)?;
54
55        Ok(row.map(|r| TaskClaim {
56            instance_id: r.get("instance_id"),
57            task_id: r.get("task_id"),
58            worker_id: r.get("worker_id"),
59            claimed_at: r.get::<i64, _>("claimed_epoch").cast_unsigned(),
60            expires_at: r
61                .get::<Option<i64>, _>("expires_epoch")
62                .map(i64::cast_unsigned),
63        }))
64    }
65
66    #[tracing::instrument(
67        name = "db.release_task_claim",
68        skip(self),
69        fields(db.system = "postgresql"),
70        err(level = tracing::Level::ERROR),
71    )]
72    async fn release_task_claim(
73        &self,
74        instance_id: &str,
75        task_id: &str,
76        worker_id: &str,
77    ) -> Result<(), BackendError> {
78        tracing::debug!("releasing task claim");
79        // Check ownership first
80        let row = sqlx::query(
81            "SELECT worker_id FROM sayiir_task_claims WHERE instance_id = $1 AND task_id = $2",
82        )
83        .bind(instance_id)
84        .bind(task_id)
85        .fetch_optional(&self.pool)
86        .await
87        .map_err(PgError)?
88        .ok_or_else(|| BackendError::NotFound(format!("{instance_id}:{task_id}")))?;
89
90        let owner: String = row.get("worker_id");
91        if owner != worker_id {
92            return Err(BackendError::Backend(format!(
93                "Claim owned by different worker: {owner}"
94            )));
95        }
96
97        sqlx::query(
98            "DELETE FROM sayiir_task_claims
99             WHERE instance_id = $1 AND task_id = $2 AND worker_id = $3",
100        )
101        .bind(instance_id)
102        .bind(task_id)
103        .bind(worker_id)
104        .execute(&self.pool)
105        .await
106        .map_err(PgError)?;
107
108        Ok(())
109    }
110
111    #[tracing::instrument(
112        name = "db.extend_task_claim",
113        skip(self),
114        fields(db.system = "postgresql"),
115        err(level = tracing::Level::ERROR),
116    )]
117    async fn extend_task_claim(
118        &self,
119        instance_id: &str,
120        task_id: &str,
121        worker_id: &str,
122        additional_duration: Duration,
123    ) -> Result<(), BackendError> {
124        tracing::debug!("extending task claim");
125        let row = sqlx::query(
126            "SELECT worker_id, expires_at FROM sayiir_task_claims
127             WHERE instance_id = $1 AND task_id = $2",
128        )
129        .bind(instance_id)
130        .bind(task_id)
131        .fetch_optional(&self.pool)
132        .await
133        .map_err(PgError)?
134        .ok_or_else(|| BackendError::NotFound(format!("{instance_id}:{task_id}")))?;
135
136        let owner: String = row.get("worker_id");
137        if owner != worker_id {
138            return Err(BackendError::Backend(format!(
139                "Claim owned by different worker: {owner}"
140            )));
141        }
142
143        // Only extend if there's an expiration set
144        let expires_at: Option<chrono::DateTime<Utc>> = row.get("expires_at");
145        if let Some(exp) = expires_at {
146            let new_exp = exp
147                .checked_add_signed(additional_duration)
148                .ok_or_else(|| BackendError::Backend("Time overflow".to_string()))?;
149
150            sqlx::query(
151                "UPDATE sayiir_task_claims SET expires_at = $1
152                 WHERE instance_id = $2 AND task_id = $3",
153            )
154            .bind(new_exp)
155            .bind(instance_id)
156            .bind(task_id)
157            .execute(&self.pool)
158            .await
159            .map_err(PgError)?;
160        }
161
162        Ok(())
163    }
164
165    #[tracing::instrument(
166        name = "db.find_available_tasks",
167        skip(self),
168        fields(db.system = "postgresql"),
169        err(level = tracing::Level::ERROR),
170    )]
171    #[allow(clippy::cast_precision_loss, clippy::too_many_lines)]
172    async fn find_available_tasks(
173        &self,
174        worker_id: &str,
175        limit: usize,
176        aging_interval: Duration,
177        worker_tags: &[String],
178    ) -> Result<Vec<AvailableTask>, BackendError> {
179        // Step 1: Clean expired claims
180        sqlx::query(
181            "DELETE FROM sayiir_task_claims WHERE expires_at IS NOT NULL AND expires_at < now()",
182        )
183        .execute(&self.pool)
184        .await
185        .map_err(PgError)?;
186
187        // Step 2: Fetch candidate workflows ordered by effective priority with aging.
188        // effective_priority = task_priority - (seconds_waiting / aging_interval)
189        // Clamp to a minimum of 1s to prevent division by zero in the SQL expression.
190        let aging_secs = (aging_interval.num_milliseconds() as f64 / 1000.0).max(1.0);
191        let worker_tags_vec: Vec<&str> = worker_tags.iter().map(String::as_str).collect();
192
193        // When worker has tags, add filter: task_tags must be a subset of
194        // worker_tags. The `<@` operator checks array containment; an empty
195        // array is a subset of every array, so untagged tasks always pass.
196        let tag_filter = if worker_tags.is_empty() {
197            ""
198        } else {
199            "AND s.task_tags <@ $3"
200        };
201
202        let query = format!(
203            "SELECT s.instance_id, s.data, s.trace_parent
204             FROM sayiir_workflow_snapshots s
205             WHERE s.status = 'InProgress'
206               AND NOT EXISTS (
207                   SELECT 1 FROM sayiir_task_claims c
208                   WHERE c.instance_id = s.instance_id
209                     AND c.task_id = s.current_task_id
210                     AND (c.expires_at IS NULL OR c.expires_at > now())
211               )
212               AND NOT EXISTS (
213                   SELECT 1 FROM sayiir_workflow_signals sig
214                   WHERE sig.instance_id = s.instance_id
215               )
216               {tag_filter}
217             ORDER BY
218               (s.task_priority - EXTRACT(EPOCH FROM (now() - s.updated_at)) / $2) ASC,
219               s.updated_at ASC
220             LIMIT $1"
221        );
222
223        let mut q = sqlx::query(&query)
224            .bind(i64::try_from(limit).unwrap_or(i64::MAX))
225            .bind(aging_secs);
226        if !worker_tags.is_empty() {
227            q = q.bind(&worker_tags_vec);
228        }
229        let rows = q.fetch_all(&self.pool).await.map_err(PgError)?;
230
231        // Step 3: App-level evaluation per candidate.
232        // Collect (worker_failed_here, task) pairs so we can stable-sort by
233        // worker bias afterwards without re-decoding.
234        let mut available: Vec<(bool, AvailableTask)> = Vec::with_capacity(rows.len());
235        for row in &rows {
236            let raw: &[u8] = row.get("data");
237            let mut snapshot = self.decode(raw)?;
238            snapshot.trace_parent = row.get("trace_parent");
239
240            match &snapshot.state {
241                // Delay: if expired, advance past it
242                WorkflowSnapshotState::InProgress {
243                    position:
244                        ExecutionPosition::AtDelay {
245                            wake_at,
246                            next_task_id,
247                            delay_id,
248                            ..
249                        },
250                    ..
251                } if Utc::now() >= *wake_at => {
252                    if let Some(next_id) = next_task_id.clone() {
253                        snapshot.update_position(ExecutionPosition::AtTask { task_id: next_id });
254                        self.save_snapshot(&snapshot).await?;
255
256                        if let WorkflowSnapshotState::InProgress {
257                            position: ExecutionPosition::AtTask { task_id },
258                            completed_tasks,
259                            ..
260                        } = &snapshot.state
261                            && let Some(task) =
262                                build_available_task(&snapshot, task_id, completed_tasks, worker_id)
263                        {
264                            let bias = snapshot.has_failed_on_worker(task_id, worker_id);
265                            available.push((bias, task));
266                        }
267                    } else {
268                        // Delay is the last node — complete the workflow
269                        let output = snapshot.get_task_result_bytes(delay_id).unwrap_or_default();
270                        snapshot.mark_completed(output);
271                        self.save_snapshot(&snapshot).await?;
272                    }
273                }
274
275                // Task: check retry backoff, then add to available
276                WorkflowSnapshotState::InProgress {
277                    position: ExecutionPosition::AtTask { task_id },
278                    completed_tasks,
279                    ..
280                } => {
281                    // Skip if task is already completed
282                    if completed_tasks.contains_key(task_id) {
283                        continue;
284                    }
285
286                    // Skip if retry backoff hasn't elapsed
287                    if let Some(rs) = snapshot.task_retries.get(task_id)
288                        && Utc::now() < rs.next_retry_at
289                    {
290                        continue;
291                    }
292
293                    if let Some(task) =
294                        build_available_task(&snapshot, task_id, completed_tasks, worker_id)
295                    {
296                        let bias = snapshot.has_failed_on_worker(task_id, worker_id);
297                        available.push((bias, task));
298                    }
299                }
300
301                _ => continue,
302            }
303
304            if available.len() >= limit {
305                break;
306            }
307        }
308
309        // Step 4: Stable-sort by worker bias so tasks whose last failure was on
310        // this worker sink to the bottom, while preserving the effective-priority
311        // order from the SQL query for everything else.
312        available.sort_by_key(|(bias, _)| *bias);
313
314        tracing::debug!(count = available.len(), "available tasks found");
315        Ok(available.into_iter().map(|(_, task)| task).collect())
316    }
317}
318
319/// Build an [`AvailableTask`] from a snapshot at a task position.
320fn build_available_task(
321    snapshot: &WorkflowSnapshot,
322    task_id: &str,
323    completed_tasks: &std::collections::HashMap<String, sayiir_core::snapshot::TaskResult>,
324    _worker_id: &str,
325) -> Option<AvailableTask> {
326    let input = if completed_tasks.is_empty() {
327        snapshot.initial_input_bytes()
328    } else {
329        snapshot.get_last_task_output()
330    };
331
332    input.map(|input_bytes| AvailableTask {
333        instance_id: snapshot.instance_id.clone(),
334        task_id: task_id.to_string(),
335        input: input_bytes,
336        workflow_definition_hash: snapshot.definition_hash.clone(),
337        trace_parent: snapshot.trace_parent.clone(),
338    })
339}