// sayiir_runtime/execution/executors.rs
1//! Sync, async, and checkpointing execution loops.
2
3use backon::{BlockingRetryable, Retryable};
4use bytes::Bytes;
5use sayiir_core::error::{BoxError, WorkflowError};
6use sayiir_core::workflow::WorkflowContinuation;
7use sayiir_persistence::SignalStore;
8use std::future::Future;
9use std::sync::Arc;
10
11use crate::error::RuntimeError;
12
13use super::fork::{
14    JoinResolution, collect_cached_branches, execute_fork_branches_sequential, resolve_join,
15    settle_fork_outcome,
16};
17use super::helpers::{
18    check_guards, execute_task_step, park_at_delay, park_at_signal, policy_to_backoff,
19};
20
21// ── Sync ────────────────────────────────────────────────────────────────
22
23/// Execute a workflow continuation synchronously.
24///
25/// This is useful for environments that don't support async (like Python with GIL).
26/// Branches are executed sequentially.
27///
28/// # Arguments
29/// * `continuation` - The workflow continuation to execute
30/// * `input` - Input bytes for the first task
31/// * `execute_task` - Callback to execute a task: (`task_id`, input) -> `Result<output>`
32///
33/// # Errors
34/// Returns an error if task execution fails.
35pub fn execute_continuation_sync<F>(
36    continuation: &WorkflowContinuation,
37    input: Bytes,
38    execute_task: &F,
39) -> Result<Bytes, RuntimeError>
40where
41    F: Fn(&str, Bytes) -> Result<Bytes, BoxError>,
42{
43    let mut current = continuation;
44    let mut current_input = input;
45
46    loop {
47        match current {
48            WorkflowContinuation::Task {
49                id,
50                retry_policy,
51                next,
52                ..
53            } => {
54                let output = (|| execute_task(id, current_input.clone()))
55                    .retry(policy_to_backoff(retry_policy.as_ref()))
56                    .sleep(std::thread::sleep)
57                    .notify(|e, dur: std::time::Duration| {
58                        tracing::info!(
59                            task_id = %id,
60                            delay_ms = dur.as_millis(),
61                            error = %e,
62                            "Retrying task (sync)"
63                        );
64                    })
65                    .call()
66                    .map_err(RuntimeError::from)?;
67
68                match next {
69                    Some(next_cont) => {
70                        current = next_cont;
71                        current_input = output;
72                    }
73                    None => return Ok(output),
74                }
75            }
76            WorkflowContinuation::Fork { branches, join, .. } => {
77                // Execute branches sequentially
78                let mut branch_results = Vec::with_capacity(branches.len());
79
80                for branch in branches {
81                    let branch_id = branch.id().to_string();
82                    let output =
83                        execute_continuation_sync(branch, current_input.clone(), execute_task)?;
84                    branch_results.push((branch_id, output));
85                }
86
87                match resolve_join(join.as_deref(), &branch_results)? {
88                    JoinResolution::Continue { next, input } => {
89                        current = next;
90                        current_input = input;
91                    }
92                    JoinResolution::Done(output) => return Ok(output),
93                }
94            }
95            WorkflowContinuation::Delay { duration, next, .. } => {
96                std::thread::sleep(*duration);
97                match next {
98                    Some(next_cont) => {
99                        current = next_cont;
100                    }
101                    None => return Ok(current_input),
102                }
103            }
104            WorkflowContinuation::AwaitSignal { id, .. } => {
105                // Sync executor cannot wait for external signals
106                return Err(WorkflowError::ResumeError(format!(
107                    "AwaitSignal '{id}' not supported in sync executor"
108                ))
109                .into());
110            }
111        }
112    }
113}
114
115// ── Async ───────────────────────────────────────────────────────────────
116
117/// Execute a workflow continuation asynchronously with parallel branch execution.
118///
119/// Uses the `func` from each task in the continuation for execution, and spawns
120/// branches in parallel using tokio tasks.
121///
122/// # Arguments
123/// * `continuation` - The workflow continuation to execute
124/// * `input` - Input bytes for the first task
125///
126/// # Errors
127/// Returns an error if task execution fails.
128pub async fn execute_continuation_async(
129    continuation: &WorkflowContinuation,
130    input: Bytes,
131) -> Result<Bytes, RuntimeError> {
132    execute_async_inner(continuation, input, true).await
133}
134
135/// Execute a task function with optional timeout wrapping and retry backoff.
136///
137/// Wraps the task's `func.run(input)` in an optional `tokio::time::timeout`,
138/// then retries the whole closure according to `retry_policy` (via `backon`).
139async fn run_task_with_retry(
140    id: &str,
141    input: Bytes,
142    func: &dyn sayiir_core::task::CoreTask<
143        Input = Bytes,
144        Output = Bytes,
145        Future = sayiir_core::task::BytesFuture,
146    >,
147    timeout: Option<&std::time::Duration>,
148    retry_policy: Option<&sayiir_core::task::RetryPolicy>,
149) -> Result<Bytes, RuntimeError> {
150    (|| async {
151        let task_input = input.clone();
152        if let Some(d) = timeout {
153            match tokio::time::timeout(*d, func.run(task_input)).await {
154                Ok(result) => result.map_err(RuntimeError::from),
155                Err(_) => Err(WorkflowError::TaskTimedOut {
156                    task_id: id.to_string(),
157                    timeout: *d,
158                }
159                .into()),
160            }
161        } else {
162            func.run(task_input).await.map_err(RuntimeError::from)
163        }
164    })
165    .retry(policy_to_backoff(retry_policy))
166    .notify(|e, dur: std::time::Duration| {
167        tracing::info!(
168            task_id = %id,
169            delay_ms = dur.as_millis(),
170            error = %e,
171            "Retrying task"
172        );
173    })
174    .await
175}
176
177/// Shared implementation for async continuation execution.
178///
179/// When `parallel_branches` is `true`, top-level fork branches are spawned as
180/// parallel tokio tasks. When `false` (used inside branches to avoid unbounded
181/// spawning), branches run sequentially.
182///
183/// Returns a boxed future so the recursive call is provably `Send` for `tokio::spawn`.
184fn execute_async_inner<'a>(
185    continuation: &'a WorkflowContinuation,
186    input: Bytes,
187    parallel_branches: bool,
188) -> std::pin::Pin<Box<dyn Future<Output = Result<Bytes, RuntimeError>> + Send + 'a>> {
189    Box::pin(async move {
190        let mut current = continuation;
191        let mut current_input = input;
192
193        loop {
194            match current {
195                WorkflowContinuation::Task {
196                    id,
197                    func: Some(func),
198                    timeout,
199                    retry_policy,
200                    next,
201                } => {
202                    let output = run_task_with_retry(
203                        id,
204                        current_input.clone(),
205                        func.as_ref(),
206                        timeout.as_ref(),
207                        retry_policy.as_ref(),
208                    )
209                    .await?;
210
211                    match next {
212                        Some(next_cont) => {
213                            current = next_cont;
214                            current_input = output;
215                        }
216                        None => return Ok(output),
217                    }
218                }
219                WorkflowContinuation::Task { func: None, id, .. } => {
220                    return Err(WorkflowError::TaskNotImplemented(id.clone()).into());
221                }
222                WorkflowContinuation::Delay { duration, next, .. } => {
223                    tokio::time::sleep(*duration).await;
224                    match next {
225                        Some(next_cont) => {
226                            current = next_cont;
227                        }
228                        None => return Ok(current_input),
229                    }
230                }
231                WorkflowContinuation::AwaitSignal { id, .. } => {
232                    // Async executor (non-durable) cannot wait for external signals
233                    return Err(WorkflowError::ResumeError(format!(
234                        "AwaitSignal '{id}' not supported in non-durable async executor"
235                    ))
236                    .into());
237                }
238                WorkflowContinuation::Fork { branches, join, .. } => {
239                    let branch_results = if parallel_branches && branches.len() > 1 {
240                        // Multiple branches: spawn each as a tokio task for parallelism
241                        let mut set = tokio::task::JoinSet::new();
242                        for branch in branches {
243                            let branch_id = branch.id().to_string();
244                            let branch = Arc::clone(branch);
245                            let branch_input = current_input.clone();
246                            set.spawn(async move {
247                                execute_async_inner(&branch, branch_input, false)
248                                    .await
249                                    .map(|output| (branch_id, output))
250                            });
251                        }
252
253                        let mut results = Vec::with_capacity(set.len());
254                        while let Some(res) = set.join_next().await {
255                            results.push(res??);
256                        }
257                        results
258                    } else {
259                        // Single branch or non-parallel: run inline (no spawn overhead)
260                        let mut results = Vec::with_capacity(branches.len());
261                        for branch in branches {
262                            let branch_id = branch.id().to_string();
263                            let output =
264                                execute_async_inner(branch, current_input.clone(), false).await?;
265                            results.push((branch_id, output));
266                        }
267                        results
268                    };
269
270                    match resolve_join(join.as_deref(), &branch_results)? {
271                        JoinResolution::Continue { next, input } => {
272                            current = next;
273                            current_input = input;
274                        }
275                        JoinResolution::Done(output) => return Ok(output),
276                    }
277                }
278            }
279        }
280    })
281}
282
283// ── Checkpointing ───────────────────────────────────────────────────────
284
/// Execute a workflow continuation with checkpointing after each task.
///
/// This is the callback-based variant of `CheckpointingRunner::execute_with_checkpointing`.
/// Instead of calling `func.run(input)` directly (which requires real Rust task implementations),
/// it delegates task execution to a caller-supplied async callback. This enables environments
/// like Python bindings to provide task implementations while still getting full checkpointing,
/// cancellation, and resume support.
///
/// Fork branches are executed **sequentially** (correct for Python's GIL; parallel can come later).
///
/// Parking at a delay or signal is surfaced to the caller as an `Err` return
/// from `park_at_delay` / `park_at_signal`; on resume, previously completed
/// steps are skipped via `snapshot.get_task_result(..)` lookups.
///
/// # Arguments
/// * `continuation` - The workflow continuation to execute
/// * `input` - Input bytes for the first task
/// * `snapshot` - Mutable snapshot tracking execution progress
/// * `backend` - Persistent backend for saving checkpoints
/// * `execute_task` - Async callback: `(task_id, input) -> Result<output>`
///
/// # Errors
/// Returns an error if task execution, cancellation checking, or snapshot saving fails.
#[allow(clippy::too_many_lines)]
pub async fn execute_continuation_with_checkpointing<F, Fut, B>(
    continuation: &WorkflowContinuation,
    input: Bytes,
    snapshot: &mut sayiir_core::snapshot::WorkflowSnapshot,
    backend: &B,
    execute_task: &F,
) -> Result<Bytes, RuntimeError>
where
    B: SignalStore,
    F: Fn(&str, Bytes) -> Fut + Send + Sync,
    Fut: Future<Output = Result<Bytes, BoxError>> + Send,
{
    let mut current = continuation;
    let mut current_input = input;

    loop {
        match current {
            WorkflowContinuation::Task {
                id,
                timeout,
                retry_policy,
                next,
                ..
            } => {
                // Guard checks, result caching, retries, and checkpoint saving
                // are all handled inside `execute_task_step`; the callback is
                // invoked only when there is no cached result for `id`.
                let output = execute_task_step(
                    id,
                    timeout.as_ref(),
                    retry_policy.as_ref(),
                    next.as_deref(),
                    current_input.clone(),
                    snapshot,
                    backend,
                    |i| execute_task(id, i),
                )
                .await?;

                match next {
                    Some(next_continuation) => {
                        current = next_continuation;
                        current_input = output;
                    }
                    None => return Ok(output),
                }
            }
            WorkflowContinuation::Delay { id, duration, next } => {
                check_guards(backend, &snapshot.instance_id, Some(id)).await?;

                // A recorded result for this delay means it already elapsed in
                // a previous run — skip straight to the next step on resume.
                if snapshot.get_task_result(id).is_some() {
                    match next {
                        Some(n) => {
                            current = n;
                            continue;
                        }
                        None => return Ok(current_input),
                    }
                }

                // Not yet elapsed: persist state and park. The `Err` here is
                // control flow — the runner resumes after the delay fires.
                return Err(park_at_delay(
                    id,
                    duration,
                    next.as_deref(),
                    current_input,
                    snapshot,
                    backend,
                )
                .await);
            }
            WorkflowContinuation::AwaitSignal {
                id,
                signal_name,
                timeout,
                next,
            } => {
                check_guards(backend, &snapshot.instance_id, Some(id)).await?;

                // If we already consumed this signal (on resume), skip it
                if snapshot.get_task_result(id).is_some() {
                    match next {
                        Some(n) => {
                            current = n;
                            // Use the signal payload as input for the next step
                            current_input =
                                snapshot.get_task_result_bytes(id).unwrap_or(current_input);
                            continue;
                        }
                        None => return Ok(current_input),
                    }
                }

                let err = park_at_signal(
                    id,
                    signal_name,
                    timeout.as_ref(),
                    next.as_deref(),
                    snapshot,
                    backend,
                )
                .await;

                // If the signal was already buffered, park_at_signal consumed it
                // and updated the snapshot — continue execution
                if matches!(err, RuntimeError::Workflow(WorkflowError::SignalConsumed)) {
                    if let Some(n) = next {
                        current = n;
                        current_input = snapshot.get_task_result_bytes(id).unwrap_or(current_input);
                        continue;
                    }
                    let output = snapshot.get_task_result_bytes(id).unwrap_or(current_input);
                    return Ok(output);
                }

                // Any other error (including a genuine park) propagates.
                return Err(err);
            }
            WorkflowContinuation::Fork {
                id: fork_id,
                branches,
                join,
            } => {
                check_guards(backend, &snapshot.instance_id, None).await?;

                // On resume, all-branches-cached short-circuits re-execution;
                // otherwise branches run sequentially and the fork outcome is
                // settled (checkpointed) before join resolution.
                let branch_results =
                    if let Some(cached) = collect_cached_branches(branches, snapshot) {
                        cached
                    } else {
                        let outcome = execute_fork_branches_sequential(
                            branches,
                            &current_input,
                            snapshot,
                            backend,
                            execute_task,
                        )
                        .await?;
                        settle_fork_outcome(fork_id, outcome, join.as_deref(), snapshot, backend)
                            .await?
                    };

                match resolve_join(join.as_deref(), &branch_results)? {
                    JoinResolution::Continue { next, input } => {
                        current = next;
                        current_input = input;
                    }
                    JoinResolution::Done(output) => return Ok(output),
                }
            }
        }
    }
}