1use chrono::{Duration, Utc};
4use sayiir_core::codec::{self, Decoder, Encoder};
5use sayiir_core::snapshot::{ExecutionPosition, WorkflowSnapshot, WorkflowSnapshotState};
6use sayiir_core::task_claim::{AvailableTask, TaskClaim};
7use sayiir_persistence::{BackendError, SnapshotStore, TaskClaimStore};
8use sqlx::Row;
9
10use crate::backend::PostgresBackend;
11use crate::error::PgError;
12
13impl<C> TaskClaimStore for PostgresBackend<C>
14where
15 C: Encoder
16 + Decoder
17 + codec::sealed::EncodeValue<WorkflowSnapshot>
18 + codec::sealed::DecodeValue<WorkflowSnapshot>,
19{
20 #[tracing::instrument(
21 name = "db.claim_task",
22 skip(self),
23 fields(db.system = "postgresql"),
24 err(level = tracing::Level::ERROR),
25 )]
26 async fn claim_task(
27 &self,
28 instance_id: &str,
29 task_id: &str,
30 worker_id: &str,
31 ttl: Option<Duration>,
32 ) -> Result<Option<TaskClaim>, BackendError> {
33 tracing::debug!("claiming task");
34 let expires_at = ttl.and_then(|d| Utc::now().checked_add_signed(d));
35
36 let row = sqlx::query(
38 "INSERT INTO sayiir_task_claims (instance_id, task_id, worker_id, expires_at)
39 VALUES ($1, $2, $3, $4)
40 ON CONFLICT (instance_id, task_id) DO UPDATE
41 SET worker_id = $3, claimed_at = now(), expires_at = $4
42 WHERE sayiir_task_claims.expires_at IS NOT NULL AND sayiir_task_claims.expires_at < now()
43 RETURNING instance_id, task_id, worker_id,
44 EXTRACT(EPOCH FROM claimed_at)::BIGINT AS claimed_epoch,
45 EXTRACT(EPOCH FROM expires_at)::BIGINT AS expires_epoch",
46 )
47 .bind(instance_id)
48 .bind(task_id)
49 .bind(worker_id)
50 .bind(expires_at)
51 .fetch_optional(&self.pool)
52 .await
53 .map_err(PgError)?;
54
55 Ok(row.map(|r| TaskClaim {
56 instance_id: r.get("instance_id"),
57 task_id: r.get("task_id"),
58 worker_id: r.get("worker_id"),
59 claimed_at: r.get::<i64, _>("claimed_epoch").cast_unsigned(),
60 expires_at: r
61 .get::<Option<i64>, _>("expires_epoch")
62 .map(i64::cast_unsigned),
63 }))
64 }
65
66 #[tracing::instrument(
67 name = "db.release_task_claim",
68 skip(self),
69 fields(db.system = "postgresql"),
70 err(level = tracing::Level::ERROR),
71 )]
72 async fn release_task_claim(
73 &self,
74 instance_id: &str,
75 task_id: &str,
76 worker_id: &str,
77 ) -> Result<(), BackendError> {
78 tracing::debug!("releasing task claim");
79 let row = sqlx::query(
81 "SELECT worker_id FROM sayiir_task_claims WHERE instance_id = $1 AND task_id = $2",
82 )
83 .bind(instance_id)
84 .bind(task_id)
85 .fetch_optional(&self.pool)
86 .await
87 .map_err(PgError)?
88 .ok_or_else(|| BackendError::NotFound(format!("{instance_id}:{task_id}")))?;
89
90 let owner: String = row.get("worker_id");
91 if owner != worker_id {
92 return Err(BackendError::Backend(format!(
93 "Claim owned by different worker: {owner}"
94 )));
95 }
96
97 sqlx::query(
98 "DELETE FROM sayiir_task_claims
99 WHERE instance_id = $1 AND task_id = $2 AND worker_id = $3",
100 )
101 .bind(instance_id)
102 .bind(task_id)
103 .bind(worker_id)
104 .execute(&self.pool)
105 .await
106 .map_err(PgError)?;
107
108 Ok(())
109 }
110
111 #[tracing::instrument(
112 name = "db.extend_task_claim",
113 skip(self),
114 fields(db.system = "postgresql"),
115 err(level = tracing::Level::ERROR),
116 )]
117 async fn extend_task_claim(
118 &self,
119 instance_id: &str,
120 task_id: &str,
121 worker_id: &str,
122 additional_duration: Duration,
123 ) -> Result<(), BackendError> {
124 tracing::debug!("extending task claim");
125 let row = sqlx::query(
126 "SELECT worker_id, expires_at FROM sayiir_task_claims
127 WHERE instance_id = $1 AND task_id = $2",
128 )
129 .bind(instance_id)
130 .bind(task_id)
131 .fetch_optional(&self.pool)
132 .await
133 .map_err(PgError)?
134 .ok_or_else(|| BackendError::NotFound(format!("{instance_id}:{task_id}")))?;
135
136 let owner: String = row.get("worker_id");
137 if owner != worker_id {
138 return Err(BackendError::Backend(format!(
139 "Claim owned by different worker: {owner}"
140 )));
141 }
142
143 let expires_at: Option<chrono::DateTime<Utc>> = row.get("expires_at");
145 if let Some(exp) = expires_at {
146 let new_exp = exp
147 .checked_add_signed(additional_duration)
148 .ok_or_else(|| BackendError::Backend("Time overflow".to_string()))?;
149
150 sqlx::query(
151 "UPDATE sayiir_task_claims SET expires_at = $1
152 WHERE instance_id = $2 AND task_id = $3",
153 )
154 .bind(new_exp)
155 .bind(instance_id)
156 .bind(task_id)
157 .execute(&self.pool)
158 .await
159 .map_err(PgError)?;
160 }
161
162 Ok(())
163 }
164
165 #[tracing::instrument(
166 name = "db.find_available_tasks",
167 skip(self),
168 fields(db.system = "postgresql"),
169 err(level = tracing::Level::ERROR),
170 )]
171 #[allow(clippy::cast_precision_loss, clippy::too_many_lines)]
172 async fn find_available_tasks(
173 &self,
174 worker_id: &str,
175 limit: usize,
176 aging_interval: Duration,
177 worker_tags: &[String],
178 ) -> Result<Vec<AvailableTask>, BackendError> {
179 sqlx::query(
181 "DELETE FROM sayiir_task_claims WHERE expires_at IS NOT NULL AND expires_at < now()",
182 )
183 .execute(&self.pool)
184 .await
185 .map_err(PgError)?;
186
187 let aging_secs = (aging_interval.num_milliseconds() as f64 / 1000.0).max(1.0);
191 let worker_tags_vec: Vec<&str> = worker_tags.iter().map(String::as_str).collect();
192
193 let tag_filter = if worker_tags.is_empty() {
197 ""
198 } else {
199 "AND s.task_tags <@ $3"
200 };
201
202 let query = format!(
203 "SELECT s.instance_id, s.data, s.trace_parent
204 FROM sayiir_workflow_snapshots s
205 WHERE s.status = 'InProgress'
206 AND NOT EXISTS (
207 SELECT 1 FROM sayiir_task_claims c
208 WHERE c.instance_id = s.instance_id
209 AND c.task_id = s.current_task_id
210 AND (c.expires_at IS NULL OR c.expires_at > now())
211 )
212 AND NOT EXISTS (
213 SELECT 1 FROM sayiir_workflow_signals sig
214 WHERE sig.instance_id = s.instance_id
215 )
216 {tag_filter}
217 ORDER BY
218 (s.task_priority - EXTRACT(EPOCH FROM (now() - s.updated_at)) / $2) ASC,
219 s.updated_at ASC
220 LIMIT $1"
221 );
222
223 let mut q = sqlx::query(&query)
224 .bind(i64::try_from(limit).unwrap_or(i64::MAX))
225 .bind(aging_secs);
226 if !worker_tags.is_empty() {
227 q = q.bind(&worker_tags_vec);
228 }
229 let rows = q.fetch_all(&self.pool).await.map_err(PgError)?;
230
231 let mut available: Vec<(bool, AvailableTask)> = Vec::with_capacity(rows.len());
235 for row in &rows {
236 let raw: &[u8] = row.get("data");
237 let mut snapshot = self.decode(raw)?;
238 snapshot.trace_parent = row.get("trace_parent");
239
240 match &snapshot.state {
241 WorkflowSnapshotState::InProgress {
243 position:
244 ExecutionPosition::AtDelay {
245 wake_at,
246 next_task_id,
247 delay_id,
248 ..
249 },
250 ..
251 } if Utc::now() >= *wake_at => {
252 if let Some(next_id) = next_task_id.clone() {
253 snapshot.update_position(ExecutionPosition::AtTask { task_id: next_id });
254 self.save_snapshot(&snapshot).await?;
255
256 if let WorkflowSnapshotState::InProgress {
257 position: ExecutionPosition::AtTask { task_id },
258 completed_tasks,
259 ..
260 } = &snapshot.state
261 && let Some(task) =
262 build_available_task(&snapshot, task_id, completed_tasks, worker_id)
263 {
264 let bias = snapshot.has_failed_on_worker(task_id, worker_id);
265 available.push((bias, task));
266 }
267 } else {
268 let output = snapshot.get_task_result_bytes(delay_id).unwrap_or_default();
270 snapshot.mark_completed(output);
271 self.save_snapshot(&snapshot).await?;
272 }
273 }
274
275 WorkflowSnapshotState::InProgress {
277 position: ExecutionPosition::AtTask { task_id },
278 completed_tasks,
279 ..
280 } => {
281 if completed_tasks.contains_key(task_id) {
283 continue;
284 }
285
286 if let Some(rs) = snapshot.task_retries.get(task_id)
288 && Utc::now() < rs.next_retry_at
289 {
290 continue;
291 }
292
293 if let Some(task) =
294 build_available_task(&snapshot, task_id, completed_tasks, worker_id)
295 {
296 let bias = snapshot.has_failed_on_worker(task_id, worker_id);
297 available.push((bias, task));
298 }
299 }
300
301 _ => continue,
302 }
303
304 if available.len() >= limit {
305 break;
306 }
307 }
308
309 available.sort_by_key(|(bias, _)| *bias);
313
314 tracing::debug!(count = available.len(), "available tasks found");
315 Ok(available.into_iter().map(|(_, task)| task).collect())
316 }
317}
318
319fn build_available_task(
321 snapshot: &WorkflowSnapshot,
322 task_id: &str,
323 completed_tasks: &std::collections::HashMap<String, sayiir_core::snapshot::TaskResult>,
324 _worker_id: &str,
325) -> Option<AvailableTask> {
326 let input = if completed_tasks.is_empty() {
327 snapshot.initial_input_bytes()
328 } else {
329 snapshot.get_last_task_output()
330 };
331
332 input.map(|input_bytes| AvailableTask {
333 instance_id: snapshot.instance_id.clone(),
334 task_id: task_id.to_string(),
335 input: input_bytes,
336 workflow_definition_hash: snapshot.definition_hash.clone(),
337 trace_parent: snapshot.trace_parent.clone(),
338 })
339}