1use chrono::{Duration, Utc};
4use sayiir_core::codec::{self, Decoder, Encoder};
5use sayiir_core::snapshot::{ExecutionPosition, WorkflowSnapshot, WorkflowSnapshotState};
6use sayiir_core::task_claim::{AvailableTask, TaskClaim};
7use sayiir_persistence::{BackendError, SnapshotStore, TaskClaimStore};
8use sqlx::Row;
9
10use crate::backend::PostgresBackend;
11use crate::error::PgError;
12
13impl<C> TaskClaimStore for PostgresBackend<C>
14where
15 C: Encoder
16 + Decoder
17 + codec::sealed::EncodeValue<WorkflowSnapshot>
18 + codec::sealed::DecodeValue<WorkflowSnapshot>,
19{
20 async fn claim_task(
21 &self,
22 instance_id: &str,
23 task_id: &str,
24 worker_id: &str,
25 ttl: Option<Duration>,
26 ) -> Result<Option<TaskClaim>, BackendError> {
27 tracing::debug!(instance_id, task_id, worker_id, "claiming task");
28 let expires_at = ttl.and_then(|d| Utc::now().checked_add_signed(d));
29
30 let row = sqlx::query(
32 "INSERT INTO sayiir_task_claims (instance_id, task_id, worker_id, expires_at)
33 VALUES ($1, $2, $3, $4)
34 ON CONFLICT (instance_id, task_id) DO UPDATE
35 SET worker_id = $3, claimed_at = now(), expires_at = $4
36 WHERE sayiir_task_claims.expires_at IS NOT NULL AND sayiir_task_claims.expires_at < now()
37 RETURNING instance_id, task_id, worker_id,
38 EXTRACT(EPOCH FROM claimed_at)::BIGINT AS claimed_epoch,
39 EXTRACT(EPOCH FROM expires_at)::BIGINT AS expires_epoch",
40 )
41 .bind(instance_id)
42 .bind(task_id)
43 .bind(worker_id)
44 .bind(expires_at)
45 .fetch_optional(&self.pool)
46 .await
47 .map_err(PgError)?;
48
49 Ok(row.map(|r| TaskClaim {
50 instance_id: r.get("instance_id"),
51 task_id: r.get("task_id"),
52 worker_id: r.get("worker_id"),
53 claimed_at: r.get::<i64, _>("claimed_epoch") as u64,
54 expires_at: r.get::<Option<i64>, _>("expires_epoch").map(|e| e as u64),
55 }))
56 }
57
58 async fn release_task_claim(
59 &self,
60 instance_id: &str,
61 task_id: &str,
62 worker_id: &str,
63 ) -> Result<(), BackendError> {
64 tracing::debug!(instance_id, task_id, worker_id, "releasing task claim");
65 let row = sqlx::query(
67 "SELECT worker_id FROM sayiir_task_claims WHERE instance_id = $1 AND task_id = $2",
68 )
69 .bind(instance_id)
70 .bind(task_id)
71 .fetch_optional(&self.pool)
72 .await
73 .map_err(PgError)?
74 .ok_or_else(|| BackendError::NotFound(format!("{instance_id}:{task_id}")))?;
75
76 let owner: String = row.get("worker_id");
77 if owner != worker_id {
78 return Err(BackendError::Backend(format!(
79 "Claim owned by different worker: {owner}"
80 )));
81 }
82
83 sqlx::query(
84 "DELETE FROM sayiir_task_claims
85 WHERE instance_id = $1 AND task_id = $2 AND worker_id = $3",
86 )
87 .bind(instance_id)
88 .bind(task_id)
89 .bind(worker_id)
90 .execute(&self.pool)
91 .await
92 .map_err(PgError)?;
93
94 Ok(())
95 }
96
97 async fn extend_task_claim(
98 &self,
99 instance_id: &str,
100 task_id: &str,
101 worker_id: &str,
102 additional_duration: Duration,
103 ) -> Result<(), BackendError> {
104 tracing::debug!(instance_id, task_id, worker_id, "extending task claim");
105 let row = sqlx::query(
106 "SELECT worker_id, expires_at FROM sayiir_task_claims
107 WHERE instance_id = $1 AND task_id = $2",
108 )
109 .bind(instance_id)
110 .bind(task_id)
111 .fetch_optional(&self.pool)
112 .await
113 .map_err(PgError)?
114 .ok_or_else(|| BackendError::NotFound(format!("{instance_id}:{task_id}")))?;
115
116 let owner: String = row.get("worker_id");
117 if owner != worker_id {
118 return Err(BackendError::Backend(format!(
119 "Claim owned by different worker: {owner}"
120 )));
121 }
122
123 let expires_at: Option<chrono::DateTime<Utc>> = row.get("expires_at");
125 if let Some(exp) = expires_at {
126 let new_exp = exp
127 .checked_add_signed(additional_duration)
128 .ok_or_else(|| BackendError::Backend("Time overflow".to_string()))?;
129
130 sqlx::query(
131 "UPDATE sayiir_task_claims SET expires_at = $1
132 WHERE instance_id = $2 AND task_id = $3",
133 )
134 .bind(new_exp)
135 .bind(instance_id)
136 .bind(task_id)
137 .execute(&self.pool)
138 .await
139 .map_err(PgError)?;
140 }
141
142 Ok(())
143 }
144
145 async fn find_available_tasks(
146 &self,
147 worker_id: &str,
148 limit: usize,
149 ) -> Result<Vec<AvailableTask>, BackendError> {
150 tracing::debug!(worker_id, limit, "finding available tasks");
151 sqlx::query(
153 "DELETE FROM sayiir_task_claims WHERE expires_at IS NOT NULL AND expires_at < now()",
154 )
155 .execute(&self.pool)
156 .await
157 .map_err(PgError)?;
158
159 let rows = sqlx::query(
161 "SELECT s.instance_id, s.data
162 FROM sayiir_workflow_snapshots s
163 WHERE s.status = 'InProgress'
164 AND NOT EXISTS (
165 SELECT 1 FROM sayiir_task_claims c
166 WHERE c.instance_id = s.instance_id
167 AND c.task_id = s.current_task_id
168 AND (c.expires_at IS NULL OR c.expires_at > now())
169 )
170 AND NOT EXISTS (
171 SELECT 1 FROM sayiir_workflow_signals sig
172 WHERE sig.instance_id = s.instance_id
173 )
174 ORDER BY s.updated_at ASC
175 LIMIT $1",
176 )
177 .bind(limit as i64)
178 .fetch_all(&self.pool)
179 .await
180 .map_err(PgError)?;
181
182 let mut available = Vec::new();
184 for row in &rows {
185 let raw: &[u8] = row.get("data");
186 let mut snapshot = self.decode(raw)?;
187
188 match &snapshot.state {
189 WorkflowSnapshotState::InProgress {
191 position:
192 ExecutionPosition::AtDelay {
193 wake_at,
194 next_task_id,
195 delay_id,
196 ..
197 },
198 ..
199 } if Utc::now() >= *wake_at => {
200 if let Some(next_id) = next_task_id.clone() {
201 snapshot.update_position(ExecutionPosition::AtTask { task_id: next_id });
202 self.save_snapshot(&snapshot).await?;
203
204 if let WorkflowSnapshotState::InProgress {
205 position: ExecutionPosition::AtTask { task_id },
206 completed_tasks,
207 ..
208 } = &snapshot.state
209 && let Some(task) =
210 build_available_task(&snapshot, task_id, completed_tasks, worker_id)
211 {
212 available.push(task);
213 }
214 } else {
215 let output = snapshot.get_task_result_bytes(delay_id).unwrap_or_default();
217 snapshot.mark_completed(output);
218 self.save_snapshot(&snapshot).await?;
219 }
220 }
221
222 WorkflowSnapshotState::InProgress {
224 position: ExecutionPosition::AtTask { task_id },
225 completed_tasks,
226 ..
227 } => {
228 if completed_tasks.contains_key(task_id) {
230 continue;
231 }
232
233 if let Some(rs) = snapshot.task_retries.get(task_id)
235 && Utc::now() < rs.next_retry_at
236 {
237 continue;
238 }
239
240 if let Some(task) =
241 build_available_task(&snapshot, task_id, completed_tasks, worker_id)
242 {
243 available.push(task);
244 }
245 }
246
247 _ => continue,
248 }
249
250 if available.len() >= limit {
251 break;
252 }
253 }
254
255 tracing::debug!(worker_id, count = available.len(), "available tasks found");
256 Ok(available)
257 }
258}
259
260fn build_available_task(
262 snapshot: &WorkflowSnapshot,
263 task_id: &str,
264 completed_tasks: &std::collections::HashMap<String, sayiir_core::snapshot::TaskResult>,
265 _worker_id: &str,
266) -> Option<AvailableTask> {
267 let input = if completed_tasks.is_empty() {
268 snapshot.initial_input_bytes()
269 } else {
270 snapshot.get_last_task_output()
271 };
272
273 input.map(|input_bytes| AvailableTask {
274 instance_id: snapshot.instance_id.clone(),
275 task_id: task_id.to_string(),
276 input: input_bytes,
277 workflow_definition_hash: snapshot.definition_hash.clone(),
278 })
279}