sayiir_runtime/execution/
executors.rs1use backon::{BlockingRetryable, Retryable};
4use bytes::Bytes;
5use sayiir_core::error::{BoxError, WorkflowError};
6use sayiir_core::workflow::WorkflowContinuation;
7use sayiir_persistence::SignalStore;
8use std::future::Future;
9use std::sync::Arc;
10
11use crate::error::RuntimeError;
12
13use super::fork::{
14 JoinResolution, collect_cached_branches, execute_fork_branches_sequential, resolve_join,
15 settle_fork_outcome,
16};
17use super::helpers::{
18 check_guards, execute_task_step, park_at_delay, park_at_signal, policy_to_backoff,
19};
20
21pub fn execute_continuation_sync<F>(
36 continuation: &WorkflowContinuation,
37 input: Bytes,
38 execute_task: &F,
39) -> Result<Bytes, RuntimeError>
40where
41 F: Fn(&str, Bytes) -> Result<Bytes, BoxError>,
42{
43 let mut current = continuation;
44 let mut current_input = input;
45
46 loop {
47 match current {
48 WorkflowContinuation::Task {
49 id,
50 retry_policy,
51 next,
52 ..
53 } => {
54 let output = (|| execute_task(id, current_input.clone()))
55 .retry(policy_to_backoff(retry_policy.as_ref()))
56 .sleep(std::thread::sleep)
57 .notify(|e, dur: std::time::Duration| {
58 tracing::info!(
59 task_id = %id,
60 delay_ms = dur.as_millis(),
61 error = %e,
62 "Retrying task (sync)"
63 );
64 })
65 .call()
66 .map_err(RuntimeError::from)?;
67
68 match next {
69 Some(next_cont) => {
70 current = next_cont;
71 current_input = output;
72 }
73 None => return Ok(output),
74 }
75 }
76 WorkflowContinuation::Fork { branches, join, .. } => {
77 let mut branch_results = Vec::with_capacity(branches.len());
79
80 for branch in branches {
81 let branch_id = branch.id().to_string();
82 let output =
83 execute_continuation_sync(branch, current_input.clone(), execute_task)?;
84 branch_results.push((branch_id, output));
85 }
86
87 match resolve_join(join.as_deref(), &branch_results)? {
88 JoinResolution::Continue { next, input } => {
89 current = next;
90 current_input = input;
91 }
92 JoinResolution::Done(output) => return Ok(output),
93 }
94 }
95 WorkflowContinuation::Delay { duration, next, .. } => {
96 std::thread::sleep(*duration);
97 match next {
98 Some(next_cont) => {
99 current = next_cont;
100 }
101 None => return Ok(current_input),
102 }
103 }
104 WorkflowContinuation::AwaitSignal { id, .. } => {
105 return Err(WorkflowError::ResumeError(format!(
107 "AwaitSignal '{id}' not supported in sync executor"
108 ))
109 .into());
110 }
111 }
112 }
113}
114
115pub async fn execute_continuation_async(
129 continuation: &WorkflowContinuation,
130 input: Bytes,
131) -> Result<Bytes, RuntimeError> {
132 execute_async_inner(continuation, input, true).await
133}
134
135async fn run_task_with_retry(
140 id: &str,
141 input: Bytes,
142 func: &dyn sayiir_core::task::CoreTask<
143 Input = Bytes,
144 Output = Bytes,
145 Future = sayiir_core::task::BytesFuture,
146 >,
147 timeout: Option<&std::time::Duration>,
148 retry_policy: Option<&sayiir_core::task::RetryPolicy>,
149) -> Result<Bytes, RuntimeError> {
150 (|| async {
151 let task_input = input.clone();
152 if let Some(d) = timeout {
153 match tokio::time::timeout(*d, func.run(task_input)).await {
154 Ok(result) => result.map_err(RuntimeError::from),
155 Err(_) => Err(WorkflowError::TaskTimedOut {
156 task_id: id.to_string(),
157 timeout: *d,
158 }
159 .into()),
160 }
161 } else {
162 func.run(task_input).await.map_err(RuntimeError::from)
163 }
164 })
165 .retry(policy_to_backoff(retry_policy))
166 .notify(|e, dur: std::time::Duration| {
167 tracing::info!(
168 task_id = %id,
169 delay_ms = dur.as_millis(),
170 error = %e,
171 "Retrying task"
172 );
173 })
174 .await
175}
176
177fn execute_async_inner<'a>(
185 continuation: &'a WorkflowContinuation,
186 input: Bytes,
187 parallel_branches: bool,
188) -> std::pin::Pin<Box<dyn Future<Output = Result<Bytes, RuntimeError>> + Send + 'a>> {
189 Box::pin(async move {
190 let mut current = continuation;
191 let mut current_input = input;
192
193 loop {
194 match current {
195 WorkflowContinuation::Task {
196 id,
197 func: Some(func),
198 timeout,
199 retry_policy,
200 next,
201 } => {
202 let output = run_task_with_retry(
203 id,
204 current_input.clone(),
205 func.as_ref(),
206 timeout.as_ref(),
207 retry_policy.as_ref(),
208 )
209 .await?;
210
211 match next {
212 Some(next_cont) => {
213 current = next_cont;
214 current_input = output;
215 }
216 None => return Ok(output),
217 }
218 }
219 WorkflowContinuation::Task { func: None, id, .. } => {
220 return Err(WorkflowError::TaskNotImplemented(id.clone()).into());
221 }
222 WorkflowContinuation::Delay { duration, next, .. } => {
223 tokio::time::sleep(*duration).await;
224 match next {
225 Some(next_cont) => {
226 current = next_cont;
227 }
228 None => return Ok(current_input),
229 }
230 }
231 WorkflowContinuation::AwaitSignal { id, .. } => {
232 return Err(WorkflowError::ResumeError(format!(
234 "AwaitSignal '{id}' not supported in non-durable async executor"
235 ))
236 .into());
237 }
238 WorkflowContinuation::Fork { branches, join, .. } => {
239 let branch_results = if parallel_branches && branches.len() > 1 {
240 let mut set = tokio::task::JoinSet::new();
242 for branch in branches {
243 let branch_id = branch.id().to_string();
244 let branch = Arc::clone(branch);
245 let branch_input = current_input.clone();
246 set.spawn(async move {
247 execute_async_inner(&branch, branch_input, false)
248 .await
249 .map(|output| (branch_id, output))
250 });
251 }
252
253 let mut results = Vec::with_capacity(set.len());
254 while let Some(res) = set.join_next().await {
255 results.push(res??);
256 }
257 results
258 } else {
259 let mut results = Vec::with_capacity(branches.len());
261 for branch in branches {
262 let branch_id = branch.id().to_string();
263 let output =
264 execute_async_inner(branch, current_input.clone(), false).await?;
265 results.push((branch_id, output));
266 }
267 results
268 };
269
270 match resolve_join(join.as_deref(), &branch_results)? {
271 JoinResolution::Continue { next, input } => {
272 current = next;
273 current_input = input;
274 }
275 JoinResolution::Done(output) => return Ok(output),
276 }
277 }
278 }
279 }
280 })
281}
282
283#[allow(clippy::too_many_lines)]
305pub async fn execute_continuation_with_checkpointing<F, Fut, B>(
306 continuation: &WorkflowContinuation,
307 input: Bytes,
308 snapshot: &mut sayiir_core::snapshot::WorkflowSnapshot,
309 backend: &B,
310 execute_task: &F,
311) -> Result<Bytes, RuntimeError>
312where
313 B: SignalStore,
314 F: Fn(&str, Bytes) -> Fut + Send + Sync,
315 Fut: Future<Output = Result<Bytes, BoxError>> + Send,
316{
317 let mut current = continuation;
318 let mut current_input = input;
319
320 loop {
321 match current {
322 WorkflowContinuation::Task {
323 id,
324 timeout,
325 retry_policy,
326 next,
327 ..
328 } => {
329 let output = execute_task_step(
330 id,
331 timeout.as_ref(),
332 retry_policy.as_ref(),
333 next.as_deref(),
334 current_input.clone(),
335 snapshot,
336 backend,
337 |i| execute_task(id, i),
338 )
339 .await?;
340
341 match next {
342 Some(next_continuation) => {
343 current = next_continuation;
344 current_input = output;
345 }
346 None => return Ok(output),
347 }
348 }
349 WorkflowContinuation::Delay { id, duration, next } => {
350 check_guards(backend, &snapshot.instance_id, Some(id)).await?;
351
352 if snapshot.get_task_result(id).is_some() {
353 match next {
354 Some(n) => {
355 current = n;
356 continue;
357 }
358 None => return Ok(current_input),
359 }
360 }
361
362 return Err(park_at_delay(
363 id,
364 duration,
365 next.as_deref(),
366 current_input,
367 snapshot,
368 backend,
369 )
370 .await);
371 }
372 WorkflowContinuation::AwaitSignal {
373 id,
374 signal_name,
375 timeout,
376 next,
377 } => {
378 check_guards(backend, &snapshot.instance_id, Some(id)).await?;
379
380 if snapshot.get_task_result(id).is_some() {
382 match next {
383 Some(n) => {
384 current = n;
385 current_input =
387 snapshot.get_task_result_bytes(id).unwrap_or(current_input);
388 continue;
389 }
390 None => return Ok(current_input),
391 }
392 }
393
394 let err = park_at_signal(
395 id,
396 signal_name,
397 timeout.as_ref(),
398 next.as_deref(),
399 snapshot,
400 backend,
401 )
402 .await;
403
404 if matches!(err, RuntimeError::Workflow(WorkflowError::SignalConsumed)) {
407 if let Some(n) = next {
408 current = n;
409 current_input = snapshot.get_task_result_bytes(id).unwrap_or(current_input);
410 continue;
411 }
412 let output = snapshot.get_task_result_bytes(id).unwrap_or(current_input);
413 return Ok(output);
414 }
415
416 return Err(err);
417 }
418 WorkflowContinuation::Fork {
419 id: fork_id,
420 branches,
421 join,
422 } => {
423 check_guards(backend, &snapshot.instance_id, None).await?;
424
425 let branch_results =
426 if let Some(cached) = collect_cached_branches(branches, snapshot) {
427 cached
428 } else {
429 let outcome = execute_fork_branches_sequential(
430 branches,
431 ¤t_input,
432 snapshot,
433 backend,
434 execute_task,
435 )
436 .await?;
437 settle_fork_outcome(fork_id, outcome, join.as_deref(), snapshot, backend)
438 .await?
439 };
440
441 match resolve_join(join.as_deref(), &branch_results)? {
442 JoinResolution::Continue { next, input } => {
443 current = next;
444 current_input = input;
445 }
446 JoinResolution::Done(output) => return Ok(output),
447 }
448 }
449 }
450 }
451}