use chrono::{Duration, Utc};
use sayiir_core::codec::{self, Decoder, Encoder};
use sayiir_core::snapshot::{ExecutionPosition, WorkflowSnapshot, WorkflowSnapshotState};
use sayiir_core::task_claim::{AvailableTask, TaskClaim};
use sayiir_persistence::{BackendError, SnapshotStore, TaskClaimStore};
use sqlx::Row;

use crate::backend::PostgresBackend;
use crate::error::PgError;

impl<C> TaskClaimStore for PostgresBackend<C>
where
    C: Encoder
        + Decoder
        + codec::sealed::EncodeValue<WorkflowSnapshot>
        + codec::sealed::DecodeValue<WorkflowSnapshot>,
{
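    /// Attempts to claim `task_id` on `instance_id` for `worker_id`, with an
    /// optional TTL that sets the claim's expiry.
    ///
    /// The claim is a single upsert: a new row is inserted, or an existing
    /// claim is taken over only if it has already expired. If an unexpired
    /// claim exists, this resolves to `Ok(None)`.
    ///
    /// Rough call-site sketch (hypothetical identifiers, not part of this
    /// crate's API surface):
    ///
    /// ```ignore
    /// let claim = backend
    ///     .claim_task("wf-1", "send-email", "worker-a", Some(Duration::seconds(30)))
    ///     .await?;
    /// ```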
    async fn claim_task(
        &self,
        instance_id: &str,
        task_id: &str,
        worker_id: &str,
        ttl: Option<Duration>,
    ) -> Result<Option<TaskClaim>, BackendError> {
        tracing::debug!(instance_id, task_id, worker_id, "claiming task");
        let expires_at = ttl.and_then(|d| Utc::now().checked_add_signed(d));

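        // Single-statement claim: insert a fresh row, or take over an existing
        // claim only when it has expired. An unexpired claim is left untouched,
        // so the statement returns no row and the claim attempt yields `None`.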
        let row = sqlx::query(
            "INSERT INTO sayiir_task_claims (instance_id, task_id, worker_id, expires_at)
             VALUES ($1, $2, $3, $4)
             ON CONFLICT (instance_id, task_id) DO UPDATE
             SET worker_id = $3, claimed_at = now(), expires_at = $4
             WHERE sayiir_task_claims.expires_at IS NOT NULL AND sayiir_task_claims.expires_at < now()
             RETURNING instance_id, task_id, worker_id,
                 EXTRACT(EPOCH FROM claimed_at)::BIGINT AS claimed_epoch,
                 EXTRACT(EPOCH FROM expires_at)::BIGINT AS expires_epoch",
        )
        .bind(instance_id)
        .bind(task_id)
        .bind(worker_id)
        .bind(expires_at)
        .fetch_optional(&self.pool)
        .await
        .map_err(PgError)?;

        Ok(row.map(|r| TaskClaim {
            instance_id: r.get("instance_id"),
            task_id: r.get("task_id"),
            worker_id: r.get("worker_id"),
            claimed_at: r.get::<i64, _>("claimed_epoch").cast_unsigned(),
            expires_at: r
                .get::<Option<i64>, _>("expires_epoch")
                .map(i64::cast_unsigned),
        }))
    }

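    /// Releases a claim previously taken via `claim_task`.
    ///
    /// Fails with `BackendError::NotFound` if no claim exists for the task and
    /// with `BackendError::Backend` if the claim is owned by another worker.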
    async fn release_task_claim(
        &self,
        instance_id: &str,
        task_id: &str,
        worker_id: &str,
    ) -> Result<(), BackendError> {
        tracing::debug!(instance_id, task_id, worker_id, "releasing task claim");
        let row = sqlx::query(
            "SELECT worker_id FROM sayiir_task_claims WHERE instance_id = $1 AND task_id = $2",
        )
        .bind(instance_id)
        .bind(task_id)
        .fetch_optional(&self.pool)
        .await
        .map_err(PgError)?
        .ok_or_else(|| BackendError::NotFound(format!("{instance_id}:{task_id}")))?;

        let owner: String = row.get("worker_id");
        if owner != worker_id {
            return Err(BackendError::Backend(format!(
                "Claim owned by different worker: {owner}"
            )));
        }

        sqlx::query(
            "DELETE FROM sayiir_task_claims
             WHERE instance_id = $1 AND task_id = $2 AND worker_id = $3",
        )
        .bind(instance_id)
        .bind(task_id)
        .bind(worker_id)
        .execute(&self.pool)
        .await
        .map_err(PgError)?;

        Ok(())
    }

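    /// Pushes the expiry of an existing claim forward by `additional_duration`.
    ///
    /// Ownership is verified first; a claim created without a TTL has no
    /// expiry and is left unchanged. Fails with `BackendError::NotFound` if no
    /// claim exists and with `BackendError::Backend` on an ownership mismatch
    /// or datetime overflow.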
    async fn extend_task_claim(
        &self,
        instance_id: &str,
        task_id: &str,
        worker_id: &str,
        additional_duration: Duration,
    ) -> Result<(), BackendError> {
        tracing::debug!(instance_id, task_id, worker_id, "extending task claim");
        let row = sqlx::query(
            "SELECT worker_id, expires_at FROM sayiir_task_claims
             WHERE instance_id = $1 AND task_id = $2",
        )
        .bind(instance_id)
        .bind(task_id)
        .fetch_optional(&self.pool)
        .await
        .map_err(PgError)?
        .ok_or_else(|| BackendError::NotFound(format!("{instance_id}:{task_id}")))?;

        let owner: String = row.get("worker_id");
        if owner != worker_id {
            return Err(BackendError::Backend(format!(
                "Claim owned by different worker: {owner}"
            )));
        }

        let expires_at: Option<chrono::DateTime<Utc>> = row.get("expires_at");
        if let Some(exp) = expires_at {
            let new_exp = exp
                .checked_add_signed(additional_duration)
                .ok_or_else(|| BackendError::Backend("Time overflow".to_string()))?;

            sqlx::query(
                "UPDATE sayiir_task_claims SET expires_at = $1
                 WHERE instance_id = $2 AND task_id = $3",
            )
            .bind(new_exp)
            .bind(instance_id)
            .bind(task_id)
            .execute(&self.pool)
            .await
            .map_err(PgError)?;
        }

        Ok(())
    }

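    /// Finds up to `limit` tasks that a worker could pick up.
    ///
    /// Expired claims are pruned first; then in-progress snapshots whose
    /// current task is unclaimed and which have no pending signals are
    /// scanned, oldest first. Delays whose wake-up time has passed are
    /// advanced (or the workflow is marked completed) before a task is
    /// offered.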
    async fn find_available_tasks(
        &self,
        worker_id: &str,
        limit: usize,
    ) -> Result<Vec<AvailableTask>, BackendError> {
        tracing::debug!(worker_id, limit, "finding available tasks");
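        // Prune expired claims so their tasks become claimable again.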
        sqlx::query(
            "DELETE FROM sayiir_task_claims WHERE expires_at IS NOT NULL AND expires_at < now()",
        )
        .execute(&self.pool)
        .await
        .map_err(PgError)?;

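        // Candidate workflows: in progress, current task not actively claimed,
        // and no pending signals, ordered by least recently updated.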
        let rows = sqlx::query(
            "SELECT s.instance_id, s.data
             FROM sayiir_workflow_snapshots s
             WHERE s.status = 'InProgress'
               AND NOT EXISTS (
                   SELECT 1 FROM sayiir_task_claims c
                   WHERE c.instance_id = s.instance_id
                     AND c.task_id = s.current_task_id
                     AND (c.expires_at IS NULL OR c.expires_at > now())
               )
               AND NOT EXISTS (
                   SELECT 1 FROM sayiir_workflow_signals sig
                   WHERE sig.instance_id = s.instance_id
               )
             ORDER BY s.updated_at ASC
             LIMIT $1",
        )
        .bind(i64::try_from(limit).unwrap_or(i64::MAX))
        .fetch_all(&self.pool)
        .await
        .map_err(PgError)?;

        let mut available = Vec::new();
        for row in &rows {
            let raw: &[u8] = row.get("data");
            let mut snapshot = self.decode(raw)?;

            match &snapshot.state {
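                // A delay whose wake-up time has passed: advance to the next
                // task, or complete the workflow if the delay had no successor.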
                WorkflowSnapshotState::InProgress {
                    position:
                        ExecutionPosition::AtDelay {
                            wake_at,
                            next_task_id,
                            delay_id,
                            ..
                        },
                    ..
                } if Utc::now() >= *wake_at => {
                    if let Some(next_id) = next_task_id.clone() {
                        snapshot.update_position(ExecutionPosition::AtTask { task_id: next_id });
                        self.save_snapshot(&snapshot).await?;

                        if let WorkflowSnapshotState::InProgress {
                            position: ExecutionPosition::AtTask { task_id },
                            completed_tasks,
                            ..
                        } = &snapshot.state
                            && let Some(task) =
                                build_available_task(&snapshot, task_id, completed_tasks, worker_id)
                        {
                            available.push(task);
                        }
                    } else {
                        let output = snapshot.get_task_result_bytes(delay_id).unwrap_or_default();
                        snapshot.mark_completed(output);
                        self.save_snapshot(&snapshot).await?;
                    }
                }

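                // A regular task position: offer the task unless it has already
                // completed or is still inside its retry backoff window.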
                WorkflowSnapshotState::InProgress {
                    position: ExecutionPosition::AtTask { task_id },
                    completed_tasks,
                    ..
                } => {
                    if completed_tasks.contains_key(task_id) {
                        continue;
                    }

                    if let Some(rs) = snapshot.task_retries.get(task_id)
                        && Utc::now() < rs.next_retry_at
                    {
                        continue;
                    }

                    if let Some(task) =
                        build_available_task(&snapshot, task_id, completed_tasks, worker_id)
                    {
                        available.push(task);
                    }
                }

                _ => continue,
            }

            if available.len() >= limit {
                break;
            }
        }

        tracing::debug!(worker_id, count = available.len(), "available tasks found");
        Ok(available)
    }
}

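/// Builds an `AvailableTask` for `task_id`, using the workflow's initial input
/// when no task has completed yet and the previous task's output otherwise.
/// Returns `None` when no input bytes can be resolved.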
fn build_available_task(
    snapshot: &WorkflowSnapshot,
    task_id: &str,
    completed_tasks: &std::collections::HashMap<String, sayiir_core::snapshot::TaskResult>,
    _worker_id: &str,
) -> Option<AvailableTask> {
    let input = if completed_tasks.is_empty() {
        snapshot.initial_input_bytes()
    } else {
        snapshot.get_last_task_output()
    };

    input.map(|input_bytes| AvailableTask {
        instance_id: snapshot.instance_id.clone(),
        task_id: task_id.to_string(),
        input: input_bytes,
        workflow_definition_hash: snapshot.definition_hash.clone(),
    })
}