1use chrono::{Duration, Utc};
4use sayiir_core::codec::{self, Decoder, Encoder};
5use sayiir_core::snapshot::{ExecutionPosition, WorkflowSnapshot, WorkflowSnapshotState};
6use sayiir_core::task_claim::{AvailableTask, TaskClaim};
7use sayiir_persistence::{BackendError, SnapshotStore, TaskClaimStore};
8use sqlx::Row;
9
10use crate::backend::PostgresBackend;
11use crate::error::PgError;
12
13impl<C> TaskClaimStore for PostgresBackend<C>
14where
15 C: Encoder
16 + Decoder
17 + codec::sealed::EncodeValue<WorkflowSnapshot>
18 + codec::sealed::DecodeValue<WorkflowSnapshot>,
19{
20 #[tracing::instrument(
21 name = "db.claim_task",
22 skip(self),
23 fields(db.system = "postgresql"),
24 err(level = tracing::Level::ERROR),
25 )]
26 async fn claim_task(
27 &self,
28 instance_id: &str,
29 task_id: &str,
30 worker_id: &str,
31 ttl: Option<Duration>,
32 ) -> Result<Option<TaskClaim>, BackendError> {
33 tracing::debug!("claiming task");
34 let expires_at = ttl.and_then(|d| Utc::now().checked_add_signed(d));
35
36 let row = sqlx::query(
38 "INSERT INTO sayiir_task_claims (instance_id, task_id, worker_id, expires_at)
39 VALUES ($1, $2, $3, $4)
40 ON CONFLICT (instance_id, task_id) DO UPDATE
41 SET worker_id = $3, claimed_at = now(), expires_at = $4
42 WHERE sayiir_task_claims.expires_at IS NOT NULL AND sayiir_task_claims.expires_at < now()
43 RETURNING instance_id, task_id, worker_id,
44 EXTRACT(EPOCH FROM claimed_at)::BIGINT AS claimed_epoch,
45 EXTRACT(EPOCH FROM expires_at)::BIGINT AS expires_epoch",
46 )
47 .bind(instance_id)
48 .bind(task_id)
49 .bind(worker_id)
50 .bind(expires_at)
51 .fetch_optional(&self.pool)
52 .await
53 .map_err(PgError)?;
54
55 Ok(row.map(|r| TaskClaim {
56 instance_id: r.get("instance_id"),
57 task_id: r.get("task_id"),
58 worker_id: r.get("worker_id"),
59 claimed_at: r.get::<i64, _>("claimed_epoch").cast_unsigned(),
60 expires_at: r
61 .get::<Option<i64>, _>("expires_epoch")
62 .map(i64::cast_unsigned),
63 }))
64 }
65
66 #[tracing::instrument(
67 name = "db.release_task_claim",
68 skip(self),
69 fields(db.system = "postgresql"),
70 err(level = tracing::Level::ERROR),
71 )]
72 async fn release_task_claim(
73 &self,
74 instance_id: &str,
75 task_id: &str,
76 worker_id: &str,
77 ) -> Result<(), BackendError> {
78 tracing::debug!("releasing task claim");
79 let row = sqlx::query(
81 "SELECT worker_id FROM sayiir_task_claims WHERE instance_id = $1 AND task_id = $2",
82 )
83 .bind(instance_id)
84 .bind(task_id)
85 .fetch_optional(&self.pool)
86 .await
87 .map_err(PgError)?
88 .ok_or_else(|| BackendError::NotFound(format!("{instance_id}:{task_id}")))?;
89
90 let owner: String = row.get("worker_id");
91 if owner != worker_id {
92 return Err(BackendError::Backend(format!(
93 "Claim owned by different worker: {owner}"
94 )));
95 }
96
97 sqlx::query(
98 "DELETE FROM sayiir_task_claims
99 WHERE instance_id = $1 AND task_id = $2 AND worker_id = $3",
100 )
101 .bind(instance_id)
102 .bind(task_id)
103 .bind(worker_id)
104 .execute(&self.pool)
105 .await
106 .map_err(PgError)?;
107
108 Ok(())
109 }
110
111 #[tracing::instrument(
112 name = "db.extend_task_claim",
113 skip(self),
114 fields(db.system = "postgresql"),
115 err(level = tracing::Level::ERROR),
116 )]
117 async fn extend_task_claim(
118 &self,
119 instance_id: &str,
120 task_id: &str,
121 worker_id: &str,
122 additional_duration: Duration,
123 ) -> Result<(), BackendError> {
124 tracing::debug!("extending task claim");
125 let row = sqlx::query(
126 "SELECT worker_id, expires_at FROM sayiir_task_claims
127 WHERE instance_id = $1 AND task_id = $2",
128 )
129 .bind(instance_id)
130 .bind(task_id)
131 .fetch_optional(&self.pool)
132 .await
133 .map_err(PgError)?
134 .ok_or_else(|| BackendError::NotFound(format!("{instance_id}:{task_id}")))?;
135
136 let owner: String = row.get("worker_id");
137 if owner != worker_id {
138 return Err(BackendError::Backend(format!(
139 "Claim owned by different worker: {owner}"
140 )));
141 }
142
143 let expires_at: Option<chrono::DateTime<Utc>> = row.get("expires_at");
145 if let Some(exp) = expires_at {
146 let new_exp = exp
147 .checked_add_signed(additional_duration)
148 .ok_or_else(|| BackendError::Backend("Time overflow".to_string()))?;
149
150 sqlx::query(
151 "UPDATE sayiir_task_claims SET expires_at = $1
152 WHERE instance_id = $2 AND task_id = $3",
153 )
154 .bind(new_exp)
155 .bind(instance_id)
156 .bind(task_id)
157 .execute(&self.pool)
158 .await
159 .map_err(PgError)?;
160 }
161
162 Ok(())
163 }
164
165 #[tracing::instrument(
166 name = "db.find_available_tasks",
167 skip(self),
168 fields(db.system = "postgresql"),
169 err(level = tracing::Level::ERROR),
170 )]
171 async fn find_available_tasks(
172 &self,
173 worker_id: &str,
174 limit: usize,
175 ) -> Result<Vec<AvailableTask>, BackendError> {
176 sqlx::query(
178 "DELETE FROM sayiir_task_claims WHERE expires_at IS NOT NULL AND expires_at < now()",
179 )
180 .execute(&self.pool)
181 .await
182 .map_err(PgError)?;
183
184 let rows = sqlx::query(
186 "SELECT s.instance_id, s.data, s.trace_parent
187 FROM sayiir_workflow_snapshots s
188 WHERE s.status = 'InProgress'
189 AND NOT EXISTS (
190 SELECT 1 FROM sayiir_task_claims c
191 WHERE c.instance_id = s.instance_id
192 AND c.task_id = s.current_task_id
193 AND (c.expires_at IS NULL OR c.expires_at > now())
194 )
195 AND NOT EXISTS (
196 SELECT 1 FROM sayiir_workflow_signals sig
197 WHERE sig.instance_id = s.instance_id
198 )
199 ORDER BY s.updated_at ASC
200 LIMIT $1",
201 )
202 .bind(i64::try_from(limit).unwrap_or(i64::MAX))
203 .fetch_all(&self.pool)
204 .await
205 .map_err(PgError)?;
206
207 let mut available = Vec::new();
209 for row in &rows {
210 let raw: &[u8] = row.get("data");
211 let mut snapshot = self.decode(raw)?;
212 snapshot.trace_parent = row.get("trace_parent");
213
214 match &snapshot.state {
215 WorkflowSnapshotState::InProgress {
217 position:
218 ExecutionPosition::AtDelay {
219 wake_at,
220 next_task_id,
221 delay_id,
222 ..
223 },
224 ..
225 } if Utc::now() >= *wake_at => {
226 if let Some(next_id) = next_task_id.clone() {
227 snapshot.update_position(ExecutionPosition::AtTask { task_id: next_id });
228 self.save_snapshot(&snapshot).await?;
229
230 if let WorkflowSnapshotState::InProgress {
231 position: ExecutionPosition::AtTask { task_id },
232 completed_tasks,
233 ..
234 } = &snapshot.state
235 && let Some(task) =
236 build_available_task(&snapshot, task_id, completed_tasks, worker_id)
237 {
238 available.push(task);
239 }
240 } else {
241 let output = snapshot.get_task_result_bytes(delay_id).unwrap_or_default();
243 snapshot.mark_completed(output);
244 self.save_snapshot(&snapshot).await?;
245 }
246 }
247
248 WorkflowSnapshotState::InProgress {
250 position: ExecutionPosition::AtTask { task_id },
251 completed_tasks,
252 ..
253 } => {
254 if completed_tasks.contains_key(task_id) {
256 continue;
257 }
258
259 if let Some(rs) = snapshot.task_retries.get(task_id)
261 && Utc::now() < rs.next_retry_at
262 {
263 continue;
264 }
265
266 if let Some(task) =
267 build_available_task(&snapshot, task_id, completed_tasks, worker_id)
268 {
269 available.push(task);
270 }
271 }
272
273 _ => continue,
274 }
275
276 if available.len() >= limit {
277 break;
278 }
279 }
280
281 tracing::debug!(count = available.len(), "available tasks found");
282 Ok(available)
283 }
284}
285
286fn build_available_task(
288 snapshot: &WorkflowSnapshot,
289 task_id: &str,
290 completed_tasks: &std::collections::HashMap<String, sayiir_core::snapshot::TaskResult>,
291 _worker_id: &str,
292) -> Option<AvailableTask> {
293 let input = if completed_tasks.is_empty() {
294 snapshot.initial_input_bytes()
295 } else {
296 snapshot.get_last_task_output()
297 };
298
299 input.map(|input_bytes| AvailableTask {
300 instance_id: snapshot.instance_id.clone(),
301 task_id: task_id.to_string(),
302 input: input_bytes,
303 workflow_definition_hash: snapshot.definition_hash.clone(),
304 trace_parent: snapshot.trace_parent.clone(),
305 })
306}