Skip to main content

smol_workflow_engine/
workflow.rs

1use crate::agent_providers::{
2    create_agent_provider, AgentProvider, AgentProviderContext, AgentProviderResult,
3    AgentProviderRunInput, AgentRunIsolation, AgentUsage, AgentUsageCost,
4};
5use crate::js_runtime::rquickjs::RQuickJSWorkflowRuntime;
6use crate::js_runtime::{
7    WorkflowBudgetSnapshot, WorkflowJSRuntime, WorkflowModuleInput, WorkflowModuleOutput,
8    WorkflowRef, WorkflowRuntimeCall, WorkflowRuntimeExecution, WorkflowRuntimePoll,
9    WorkflowRuntimeRequest, WorkflowRuntimeRequestResolution,
10};
11use crate::metadata::{read_workflow_metadata, WorkflowMetadata};
12use anyhow::{anyhow, bail, Context};
13use serde::{Deserialize, Serialize};
14use serde_json::Value;
15use std::collections::{BTreeMap, BTreeSet, VecDeque};
16use std::fs;
17use std::path::{Path, PathBuf};
18use std::process::Command as StdCommand;
19use std::sync::Arc;
20use std::time::{Duration, Instant};
21use tokio::sync::{mpsc, watch};
22use tokio::task::{JoinSet, LocalSet};
23
24pub use crate::events::{
25    WorkflowEvent, WorkflowEventMetadata, WorkflowEventSink, WorkflowEventType,
26};
27
/// Sink that accepts agent provider results for logging.
///
/// Implementors must be `Send + Sync`; the sink is held behind an `Arc` in
/// [`RunWorkflowOptions::session_log_sink`].
#[async_trait::async_trait]
pub trait AgentSessionLogSink: Send + Sync {
    /// Writes a single agent `result` attributed to the named `provider`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementor reports for the write.
    async fn write_agent_result(
        &self,
        provider: &str,
        result: &AgentProviderResult,
    ) -> anyhow::Result<()>;
}
36
/// Strategy object through which the workflow driver runs agents and sleeps.
///
/// When no runner is supplied in [`RunWorkflowOptions`], the driver falls back
/// to [`DirectWorkflowAgentRunner`].
#[async_trait::async_trait]
pub trait WorkflowAgentRunner: Send + Sync {
    /// Runs one agent request.
    ///
    /// `default_provider` is the provider configured for the workflow run;
    /// `provider_override` optionally names a different provider for this call.
    async fn run_agent(
        &self,
        default_provider: Arc<dyn AgentProvider>,
        provider_override: Option<String>,
        input: AgentProviderRunInput,
    ) -> anyhow::Result<AgentProviderResult>;

    /// Whether the workflow scheduler should wrap this runner's `run_agent` call
    /// in the per-agent retry loop.
    ///
    /// Most runners should use the default `true`: their `run_agent` method is a
    /// single agent/provider boundary, so retrying it is safe and keeps retry
    /// behavior centralized in the runtime scheduler.
    ///
    /// Runners that perform their own checkpointing or replay should return
    /// `false` and apply retry internally around the nondeterministic provider
    /// call. For example, the SQLite durable runner must not have the scheduler
    /// retry its whole `run_agent` method, because each call advances durable
    /// occurrence counters and may create a distinct checkpoint such as
    /// `step:sig_x#2` instead of retrying the original durable step.
    fn retry_in_runtime(&self) -> bool {
        true
    }

    /// Suspends for `duration_ms` milliseconds.
    ///
    /// The default implementation is a plain Tokio timer and always returns
    /// `Ok(())`; runners may override it.
    async fn sleep(&self, duration_ms: u64) -> anyhow::Result<()> {
        tokio::time::sleep(std::time::Duration::from_millis(duration_ms)).await;
        Ok(())
    }
}
68
/// Stateless [`WorkflowAgentRunner`] used when the caller supplies no runner.
///
/// It forwards each agent call straight to `run_agent_provider` and inherits
/// the trait's default `retry_in_runtime` and `sleep` behavior.
#[derive(Debug, Default)]
pub struct DirectWorkflowAgentRunner;
71
#[async_trait::async_trait]
impl WorkflowAgentRunner for DirectWorkflowAgentRunner {
    /// Delegates to `run_agent_provider` with the arguments unchanged; no
    /// checkpointing or extra bookkeeping happens at this layer.
    async fn run_agent(
        &self,
        default_provider: Arc<dyn AgentProvider>,
        provider_override: Option<String>,
        input: AgentProviderRunInput,
    ) -> anyhow::Result<AgentProviderResult> {
        run_agent_provider(default_provider, provider_override, input).await
    }
}
83
/// Inputs for [`run_workflow`].
pub struct RunWorkflowOptions {
    /// Path of the workflow script; it is canonicalized and then read from disk.
    pub script_path: PathBuf,
    /// Arguments handed to the workflow module when it is started.
    pub args: Value,
    /// Agent provider configured for this run (its name is logged at start).
    pub agent_provider: Arc<dyn AgentProvider>,
    /// Model mapping carried on the run state.
    /// NOTE(review): presumably alias -> concrete model name; confirm at the use site.
    pub model_map: BTreeMap<String, String>,
    /// Initial `total` of the budget snapshot given to the runtime.
    pub budget_total: Option<u64>,
    /// Initial `spent` of the budget snapshot given to the runtime.
    pub budget_spent: u64,
    /// Nesting depth of this run (logged at start and carried on the run state).
    pub nesting_depth: usize,
    /// Optional limit on parallel agent requests.
    /// NOTE(review): enforced by `agent_capacity_available`, which is not shown here.
    pub max_parallel_agent_requests: Option<usize>,
    /// Runner used for agent calls; defaults to [`DirectWorkflowAgentRunner`] when `None`.
    pub agent_runner: Option<Arc<dyn WorkflowAgentRunner>>,
    /// Cancellation signal; the run is cancelled once the watched value is `true`.
    pub cancel_rx: Option<watch::Receiver<bool>>,
    /// Event sink; when present, started/result/error lifecycle events are emitted.
    pub event_sink: Option<Arc<dyn WorkflowEventSink>>,
    /// Parent step id carried on the run state for event emission.
    pub event_parent_step_id: Option<String>,
    /// Start instant for the event stream; defaults to `Instant::now()` when `None`.
    pub event_stream_start: Option<Instant>,
    /// Optional session log sink. When present, cancellation waits for running
    /// agent tasks to finish instead of aborting them.
    pub session_log_sink: Option<Arc<dyn AgentSessionLogSink>>,
}
100
/// Everything collected over a workflow run, returned once the module completes.
///
/// All fields except `output` are moved or cloned out of the internal run state
/// at the moment the JavaScript module reports completion.
#[derive(Debug)]
pub struct RunWorkflowResult {
    /// Output reported by the workflow module on completion.
    pub output: WorkflowModuleOutput,
    /// Log entries recorded during the run, one value list per entry.
    pub logs: Vec<Vec<Value>>,
    /// Phase calls recorded during the run.
    pub phases: Vec<WorkflowPhaseCall>,
    /// Agent requests recorded during the run.
    pub agent_calls: Vec<WorkflowRuntimeRequest>,
    /// Nested workflow requests recorded during the run.
    pub workflow_calls: Vec<WorkflowRuntimeRequest>,
    /// Budget snapshot at completion.
    pub budget: WorkflowBudgetSnapshot,
    /// Token usage accumulated over the run.
    pub token_usage: WorkflowTokenUsage,
    /// Token usage keyed by string.
    /// NOTE(review): the key is presumably the phase name; confirm at the insert site.
    pub token_usage_by_phase: std::collections::BTreeMap<String, WorkflowTokenUsage>,
    /// Per-agent run summaries recorded during the run.
    pub agent_runs: Vec<WorkflowAgentRunSummary>,
}
113
/// Accumulated token counters, serialized with camelCase keys.
///
/// `Default` yields all-zero counters and no cost.
#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct WorkflowTokenUsage {
    /// Input token count.
    pub input_tokens: u64,
    /// Output token count.
    pub output_tokens: u64,
    /// Cache-read token count.
    pub cache_read_tokens: u64,
    /// Cache-write token count.
    pub cache_write_tokens: u64,
    /// Total token count as reported (accumulated separately, not derived here).
    pub total_tokens: u64,
    /// Accumulated cost; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cost: Option<AgentUsageCost>,
}
125
/// Summary of one agent run, serialized with camelCase keys.
///
/// Every optional field is omitted from serialized output when it is `None`.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct WorkflowAgentRunSummary {
    /// Identifier of the run.
    pub id: String,
    /// Phase associated with the run, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub phase: Option<String>,
    /// Provider name, if known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider: Option<String>,
    /// Model name, if known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Session id reported by the provider, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider_session_id: Option<String>,
    /// Usage reported for the run, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<AgentUsage>,
    /// Isolation details for the run, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub isolation: Option<AgentRunIsolation>,
}
143
/// A recorded phase call: the phase name plus its optional JSON options.
#[derive(Debug, Clone, PartialEq)]
pub struct WorkflowPhaseCall {
    /// Name of the phase.
    pub name: String,
    /// Options supplied with the call, if any.
    pub options: Option<Value>,
}
149
150pub async fn run_workflow(options: RunWorkflowOptions) -> anyhow::Result<RunWorkflowResult> {
151    LocalSet::new().run_until(run_workflow_inner(options)).await
152}
153
/// Drives one workflow run: loads the script, starts the JavaScript runtime
/// actor, and multiplexes runtime events, agent tasks, sleep tasks and
/// cancellation until the module completes or fails.
///
/// Must run inside a Tokio `LocalSet` because the runtime actor is started with
/// `spawn_local`.
///
/// # Errors
///
/// Fails when the script cannot be resolved or read, when it lacks valid
/// metadata, when the runtime cannot start the module, when the run is
/// cancelled, when the runtime actor or a spawned task stops unexpectedly, or
/// when emitting a lifecycle event fails.
async fn run_workflow_inner(options: RunWorkflowOptions) -> anyhow::Result<RunWorkflowResult> {
    log::debug!(
        "run_workflow start script={} provider={} nesting_depth={} budget_total={:?} budget_spent={}",
        options.script_path.display(),
        options.agent_provider.name(),
        options.nesting_depth,
        options.budget_total,
        options.budget_spent
    );
    // Resolve the script to an absolute path before reading metadata and source.
    let script_path = fs::canonicalize(&options.script_path).with_context(|| {
        format!(
            "failed to resolve workflow script {}",
            options.script_path.display()
        )
    })?;
    // Missing metadata (`None`) is treated as a hard error.
    let metadata = read_workflow_metadata(&script_path)?.ok_or_else(|| {
        anyhow!("Workflow script must export valid literal metadata as `export const meta = {{ name, description, ... }}`")
    })?;
    log::debug!(
        "workflow metadata loaded name={} phases={}",
        metadata.name,
        metadata.phases.len()
    );
    let source = fs::read_to_string(&script_path)
        .with_context(|| format!("failed to read workflow script {}", script_path.display()))?;
    let runtime = RQuickJSWorkflowRuntime::new();
    let execution = runtime.start_module(WorkflowModuleInput {
        source,
        source_name: script_path.to_string_lossy().into_owned(),
        args: options.args,
        budget: WorkflowBudgetSnapshot {
            total: options.budget_total,
            spent: options.budget_spent,
        },
        sandbox: Default::default(),
    })?;

    // Two bounded channels connect this driver to the runtime actor:
    // commands flow driver -> actor, events flow actor -> driver.
    let (js_commands, js_command_rx) = mpsc::channel::<JsCommand>(64);
    let (js_event_tx, mut js_events) = mpsc::channel::<JsEvent>(64);
    let js_task = tokio::task::spawn_local(js_runtime_actor(execution, js_command_rx, js_event_tx));

    // Lifecycle events (started/result/error) are only emitted when a sink exists.
    let emit_lifecycle_events = options.event_sink.is_some();
    let event_start = options.event_stream_start.unwrap_or_else(Instant::now);

    let mut state = RunState {
        script_path,
        metadata,
        event_start,
        agent_provider: options.agent_provider,
        model_map: options.model_map,
        logs: Vec::new(),
        phases: Vec::new(),
        agent_calls: Vec::new(),
        workflow_calls: Vec::new(),
        budget: WorkflowBudgetSnapshot {
            total: options.budget_total,
            spent: options.budget_spent,
        },
        token_usage: WorkflowTokenUsage::default(),
        token_usage_by_phase: Default::default(),
        agent_runs: Vec::new(),
        active_request_ids: BTreeSet::new(),
        nesting_depth: options.nesting_depth,
        max_parallel_agent_requests: options.max_parallel_agent_requests,
        agent_runner: options
            .agent_runner
            .unwrap_or_else(|| Arc::new(DirectWorkflowAgentRunner)),
        cancel_rx: options.cancel_rx,
        event_sink: options.event_sink,
        event_parent_step_id: options.event_parent_step_id,
        session_log_sink: options.session_log_sink,
    };

    // Requests received from the runtime but not yet started, plus the task
    // sets for in-flight agent runs and sleeps.
    let mut pending_requests = VecDeque::<WorkflowRuntimeRequest>::new();
    let mut agent_tasks = JoinSet::<AgentTaskCompletion>::new();
    let mut sleep_tasks = JoinSet::<SleepTaskCompletion>::new();

    if emit_lifecycle_events {
        if let Err(error) = state
            .emit_event(WorkflowEvent::started(rfc3339_now()?))
            .await
        {
            // The actor is already running; ask it to stop before bailing out.
            let _ = send_js_command(&js_commands, JsCommand::Shutdown).await;
            let _ = js_task.await;
            return Err(error);
        }
    }

    let workflow_result: anyhow::Result<RunWorkflowResult> = loop {
        // Start as many queued requests as currently allowed before waiting.
        if let Err(error) = state
            .start_pending_requests(
                &mut pending_requests,
                &mut agent_tasks,
                &mut sleep_tasks,
                &js_commands,
            )
            .await
        {
            break Err(error);
        }

        tokio::select! {
            // `biased` makes the branches poll in source order, so cancellation
            // is checked first, then runtime events, then task completions.
            biased;
            () = wait_for_cancellation(&mut state.cancel_rx) => {
                break state.cancel_workflow(
                    &mut pending_requests,
                    &mut agent_tasks,
                    &mut sleep_tasks,
                    &js_commands,
                    &mut js_events,
                ).await;
            }
            event = js_events.recv() => {
                // `None` means the actor dropped its sender without completing.
                let event = match event {
                    Some(event) => event,
                    None => break Err(anyhow!("JavaScript runtime actor stopped unexpectedly")),
                };
                match state.handle_js_event(event, &mut pending_requests).await {
                    Ok(Some(result)) => break Ok(result),
                    Ok(None) => {}
                    Err(error) => break Err(error),
                }
            }
            completion = agent_tasks.join_next(), if !agent_tasks.is_empty() => {
                let completion = match completion {
                    Some(Ok(completion)) => completion,
                    Some(Err(error)) => break Err(anyhow!("agent provider task failed: {error}")),
                    None => break Err(anyhow!("agent task set ended unexpectedly")),
                };
                let AgentTaskCompletion { id, input, provider, result } = completion;
                state.active_request_ids.remove(&id);
                // An agent failure is reported back to the script as a rejected
                // request rather than aborting the whole workflow.
                let resolution = match result {
                    Ok(result) => match state.apply_agent_result(&id, &input, provider, result).await {
                        Ok(value) => WorkflowRuntimeRequestResolution::OkWithBudget {
                            value,
                            budget: state.budget.clone(),
                        },
                        Err(error) => WorkflowRuntimeRequestResolution::Err {
                            message: error.to_string(),
                        },
                    },
                    Err(error) => {
                        let message = error.to_string();
                        // Emitting the failure event is best-effort: errors are only logged.
                        if let Err(emit_error) = state.emit_agent_failed_event(&id, provider.as_deref(), &message).await {
                            log::debug!("failed to emit agent failure event: {emit_error:#}");
                        }
                        WorkflowRuntimeRequestResolution::Err { message }
                    },
                };
                if let Err(error) = send_js_command(&js_commands, JsCommand::ResolveRequest { id, resolution }).await {
                    break Err(error);
                }
            }
            completion = sleep_tasks.join_next(), if !sleep_tasks.is_empty() => {
                let completion = match completion {
                    Some(Ok(completion)) => completion,
                    Some(Err(error)) => break Err(anyhow!("sleep task failed: {error}")),
                    None => break Err(anyhow!("sleep task set ended unexpectedly")),
                };
                let SleepTaskCompletion { id, result } = completion;
                state.active_request_ids.remove(&id);
                let resolution = match result {
                    Ok(()) => WorkflowRuntimeRequestResolution::OkUndefined,
                    Err(error) => WorkflowRuntimeRequestResolution::Err {
                        message: error.to_string(),
                    },
                };
                if let Err(error) = send_js_command(&js_commands, JsCommand::ResolveRequest { id, resolution }).await {
                    break Err(error);
                }
            }
        }
    };

    // Stop the actor regardless of outcome; failures here are ignored because
    // the actor may already have exited.
    let _ = send_js_command(&js_commands, JsCommand::Shutdown).await;
    let _ = js_task.await;

    // NOTE(review): the `?` below returns the emit error in place of
    // `workflow_result`, so a sink failure can hide the workflow's own result
    // or error — confirm this is intended.
    if emit_lifecycle_events {
        match &workflow_result {
            Ok(result) => {
                state
                    .emit_event(WorkflowEvent::result(
                        result.token_usage.input_tokens,
                        result.token_usage.output_tokens,
                        result.token_usage.total_tokens,
                        result.output.result.clone(),
                    ))
                    .await?
            }
            Err(error) => {
                state
                    .emit_event(WorkflowEvent::error(error.to_string(), None))
                    .await?;
            }
        }
    }

    workflow_result
}
353
/// Commands sent from the workflow driver to the JavaScript runtime actor.
enum JsCommand {
    /// Resolve (or reject) the runtime request with the given `id`.
    ResolveRequest {
        id: String,
        resolution: WorkflowRuntimeRequestResolution,
    },
    /// Ask the actor to stop.
    Shutdown,
}
361
/// Events sent from the JavaScript runtime actor to the workflow driver.
enum JsEvent {
    /// The runtime produced a call, forwarded from `WorkflowRuntimePoll::Call`.
    Call(WorkflowRuntimeCall),
    /// The runtime issued a request that the driver must eventually resolve.
    Request(WorkflowRuntimeRequest),
    /// The module finished with this output; the actor exits after sending it.
    Complete(WorkflowModuleOutput),
    /// The runtime failed; carries the stringified error. The actor exits afterwards.
    Error(String),
}
368
369async fn js_runtime_actor(
370    mut execution: Box<dyn WorkflowRuntimeExecution>,
371    mut commands: mpsc::Receiver<JsCommand>,
372    events: mpsc::Sender<JsEvent>,
373) {
374    let mut outstanding_requests = 0usize;
375    loop {
376        match execution.poll() {
377            Ok(WorkflowRuntimePoll::Call(call)) => {
378                if events.send(JsEvent::Call(call)).await.is_err() {
379                    return;
380                }
381            }
382            Ok(WorkflowRuntimePoll::Request(request)) => {
383                let requests = match execution.take_pending_requests() {
384                    Ok(requests) if requests.is_empty() => vec![request],
385                    Ok(requests) => requests,
386                    Err(error) => {
387                        let _ = events.send(JsEvent::Error(error.to_string())).await;
388                        return;
389                    }
390                };
391                outstanding_requests = outstanding_requests.saturating_add(requests.len());
392                for request in requests {
393                    if events.send(JsEvent::Request(request)).await.is_err() {
394                        return;
395                    }
396                }
397            }
398            Ok(WorkflowRuntimePoll::Complete(output)) => {
399                let _ = events.send(JsEvent::Complete(output)).await;
400                return;
401            }
402            Ok(WorkflowRuntimePoll::Pending) => {
403                if outstanding_requests == 0 {
404                    tokio::time::sleep(std::time::Duration::from_millis(1)).await;
405                    continue;
406                }
407                match commands.recv().await {
408                    Some(JsCommand::ResolveRequest { id, resolution }) => {
409                        outstanding_requests = outstanding_requests.saturating_sub(1);
410                        if let Err(error) = execution.resolve_request(&id, resolution) {
411                            let _ = events.send(JsEvent::Error(error.to_string())).await;
412                            return;
413                        }
414                    }
415                    Some(JsCommand::Shutdown) | None => return,
416                }
417            }
418            Err(error) => {
419                let _ = events.send(JsEvent::Error(error.to_string())).await;
420                return;
421            }
422        }
423    }
424}
425
426async fn send_js_command(
427    commands: &mpsc::Sender<JsCommand>,
428    command: JsCommand,
429) -> anyhow::Result<()> {
430    commands
431        .send(command)
432        .await
433        .map_err(|_| anyhow!("JavaScript runtime actor stopped unexpectedly"))
434}
435
/// Mutable state of a single workflow run, owned by the driver loop.
///
/// Most fields are copied from [`RunWorkflowOptions`]; the collections start
/// empty and are moved into [`RunWorkflowResult`] when the module completes.
struct RunState {
    /// Canonicalized path of the workflow script.
    script_path: PathBuf,
    /// Metadata read from the script before it was started.
    metadata: WorkflowMetadata,
    /// Start instant of the event stream.
    event_start: Instant,
    /// Agent provider configured for the run.
    agent_provider: Arc<dyn AgentProvider>,
    /// Model mapping from the options.
    model_map: BTreeMap<String, String>,
    /// Log entries recorded so far.
    logs: Vec<Vec<Value>>,
    /// Phase calls recorded so far.
    phases: Vec<WorkflowPhaseCall>,
    /// Agent requests recorded so far.
    agent_calls: Vec<WorkflowRuntimeRequest>,
    /// Nested workflow requests recorded so far.
    workflow_calls: Vec<WorkflowRuntimeRequest>,
    /// Current budget snapshot; cloned into successful request resolutions.
    budget: WorkflowBudgetSnapshot,
    /// Token usage accumulated so far.
    token_usage: WorkflowTokenUsage,
    /// Token usage keyed by string (presumably phase name — confirm at the insert site).
    token_usage_by_phase: std::collections::BTreeMap<String, WorkflowTokenUsage>,
    /// Agent run summaries recorded so far.
    agent_runs: Vec<WorkflowAgentRunSummary>,
    /// Ids of requests currently in flight; ids are removed when their task completes.
    active_request_ids: BTreeSet<String>,
    /// Nesting depth of this run.
    nesting_depth: usize,
    /// Optional limit on parallel agent requests.
    max_parallel_agent_requests: Option<usize>,
    /// Runner used for agent calls.
    agent_runner: Arc<dyn WorkflowAgentRunner>,
    /// Cancellation signal, if any.
    cancel_rx: Option<watch::Receiver<bool>>,
    /// Event sink, if any.
    event_sink: Option<Arc<dyn WorkflowEventSink>>,
    /// Parent step id used for event emission, if any.
    event_parent_step_id: Option<String>,
    /// Session log sink, if any; its presence changes how cancellation treats
    /// running agent tasks (they are drained instead of aborted).
    session_log_sink: Option<Arc<dyn AgentSessionLogSink>>,
}
459
/// An agent request that has been prepared and is ready to be spawned.
struct PreparedAgentRun {
    /// Provider to use instead of the run's default provider, if any.
    provider_override: Option<String>,
    /// Input to pass to the agent provider.
    input: AgentProviderRunInput,
}
464
/// Value yielded by a finished agent task.
struct AgentTaskCompletion {
    /// Id of the runtime request this task served.
    id: String,
    /// Input the agent was run with.
    input: AgentProviderRunInput,
    /// Provider name associated with the run, if any.
    provider: Option<String>,
    /// Outcome of the agent run.
    result: anyhow::Result<AgentProviderResult>,
}
471
/// Value yielded by a finished sleep task.
struct SleepTaskCompletion {
    /// Id of the runtime request this task served.
    id: String,
    /// Outcome of the sleep.
    result: anyhow::Result<()>,
}
476
477fn add_usage(total: &mut WorkflowTokenUsage, usage: Option<&AgentUsage>) {
478    let Some(usage) = usage else {
479        return;
480    };
481
482    total.input_tokens = total
483        .input_tokens
484        .saturating_add(usage.input_tokens.unwrap_or_default());
485    total.output_tokens = total
486        .output_tokens
487        .saturating_add(usage.output_tokens.unwrap_or_default());
488    total.cache_read_tokens = total
489        .cache_read_tokens
490        .saturating_add(usage.cache_read_tokens.unwrap_or_default());
491    total.cache_write_tokens = total
492        .cache_write_tokens
493        .saturating_add(usage.cache_write_tokens.unwrap_or_default());
494    total.total_tokens = total
495        .total_tokens
496        .saturating_add(usage.total_tokens.unwrap_or_default());
497
498    if let Some(cost) = usage.cost.as_ref() {
499        total.cost = Some(merge_cost(total.cost.as_ref(), cost));
500    }
501}
502
503fn merge_token_usage(total: &mut WorkflowTokenUsage, usage: &WorkflowTokenUsage) {
504    total.input_tokens = total.input_tokens.saturating_add(usage.input_tokens);
505    total.output_tokens = total.output_tokens.saturating_add(usage.output_tokens);
506    total.cache_read_tokens = total
507        .cache_read_tokens
508        .saturating_add(usage.cache_read_tokens);
509    total.cache_write_tokens = total
510        .cache_write_tokens
511        .saturating_add(usage.cache_write_tokens);
512    total.total_tokens = total.total_tokens.saturating_add(usage.total_tokens);
513    if let Some(cost) = usage.cost.as_ref() {
514        total.cost = Some(merge_cost(total.cost.as_ref(), cost));
515    }
516}
517
518fn merge_cost(left: Option<&AgentUsageCost>, right: &AgentUsageCost) -> AgentUsageCost {
519    AgentUsageCost {
520        input: sum_f64(left.and_then(|cost| cost.input), right.input),
521        output: sum_f64(left.and_then(|cost| cost.output), right.output),
522        cache_read: sum_f64(left.and_then(|cost| cost.cache_read), right.cache_read),
523        cache_write: sum_f64(left.and_then(|cost| cost.cache_write), right.cache_write),
524        total: sum_f64(left.and_then(|cost| cost.total), right.total),
525        currency: right
526            .currency
527            .clone()
528            .or_else(|| left.and_then(|cost| cost.currency.clone())),
529    }
530}
531
/// Returns the nanoseconds elapsed since `start`, clamped to `u64::MAX` when
/// the value does not fit in a `u64`.
fn elapsed_nanos(start: Instant) -> u64 {
    let nanos = start.elapsed().as_nanos();
    match u64::try_from(nanos) {
        Ok(fits) => fits,
        Err(_) => u64::MAX,
    }
}
535
536fn rfc3339_now() -> anyhow::Result<String> {
537    Ok(time::OffsetDateTime::now_utc().format(&time::format_description::well_known::Rfc3339)?)
538}
539
540fn raw_agent_event_payloads(raw: &Value) -> Vec<Value> {
541    if let Some(events) = raw.get("events").and_then(Value::as_array) {
542        events.clone()
543    } else if let Some(items) = raw.as_array() {
544        items.clone()
545    } else {
546        vec![raw.clone()]
547    }
548}
549
550fn agent_session_event_payload(provider_event: Value, metadata: &WorkflowEventMetadata) -> Value {
551    let mut payload = serde_json::Map::new();
552    if let Some(provider) = metadata.provider.as_ref() {
553        payload.insert("provider".to_string(), Value::String(provider.clone()));
554    }
555    if let Some(session_id) = metadata.session_id.as_ref() {
556        payload.insert("sessionId".to_string(), Value::String(session_id.clone()));
557    }
558    if let Some(run_id) = metadata.run_id.as_ref() {
559        payload.insert("runId".to_string(), Value::String(run_id.clone()));
560    }
561    if let Some(step_id) = metadata.step_id.as_ref() {
562        payload.insert("stepId".to_string(), Value::String(step_id.clone()));
563    }
564    payload.insert("providerEvent".to_string(), provider_event);
565    Value::Object(payload)
566}
567
/// Shortens `value` to at most `max_chars` characters for event output.
///
/// Counting is per `char`, so multi-byte characters are never split. When
/// anything was cut off, a single `…` is appended after the kept characters.
fn truncate_for_event(value: &str, max_chars: usize) -> String {
    // The byte offset of the first character beyond the limit, if one exists,
    // marks where to cut.
    match value.char_indices().nth(max_chars) {
        Some((cut, _)) => format!("{}…", &value[..cut]),
        None => value.to_string(),
    }
}
577
578fn format_log_message(values: &[Value]) -> String {
579    values
580        .iter()
581        .map(|value| match value {
582            Value::String(value) => value.clone(),
583            value => serde_json::to_string(value).unwrap_or_else(|_| String::from("<unprintable>")),
584        })
585        .collect::<Vec<_>>()
586        .join(" ")
587}
588
/// Adds two optional amounts.
///
/// Returns `None` only when both sides are `None`; otherwise a missing side
/// counts as `0.0`.
fn sum_f64(left: Option<f64>, right: Option<f64>) -> Option<f64> {
    if left.is_none() && right.is_none() {
        return None;
    }
    Some(left.unwrap_or_default() + right.unwrap_or_default())
}
595
596async fn wait_for_cancellation(cancel_rx: &mut Option<watch::Receiver<bool>>) {
597    let Some(cancel_rx) = cancel_rx else {
598        std::future::pending::<()>().await;
599        return;
600    };
601    while !*cancel_rx.borrow() {
602        if cancel_rx.changed().await.is_err() {
603            return;
604        }
605    }
606}
607
608impl RunState {
609    async fn handle_js_event(
610        &mut self,
611        event: JsEvent,
612        pending_requests: &mut VecDeque<WorkflowRuntimeRequest>,
613    ) -> anyhow::Result<Option<RunWorkflowResult>> {
614        match event {
615            JsEvent::Call(call) => self.handle_call(call).await?,
616            JsEvent::Request(request) => {
617                log::debug!(
618                    "workflow runtime request id={} kind={}",
619                    request.id(),
620                    request.kind()
621                );
622                pending_requests.push_back(request);
623            }
624            JsEvent::Complete(output) => {
625                log::debug!(
626                    "run_workflow complete script={} budget_spent={}",
627                    self.script_path.display(),
628                    self.budget.spent
629                );
630                return Ok(Some(RunWorkflowResult {
631                    output,
632                    logs: std::mem::take(&mut self.logs),
633                    phases: std::mem::take(&mut self.phases),
634                    agent_calls: std::mem::take(&mut self.agent_calls),
635                    workflow_calls: std::mem::take(&mut self.workflow_calls),
636                    budget: self.budget.clone(),
637                    token_usage: std::mem::take(&mut self.token_usage),
638                    token_usage_by_phase: std::mem::take(&mut self.token_usage_by_phase),
639                    agent_runs: std::mem::take(&mut self.agent_runs),
640                }));
641            }
642            JsEvent::Error(message) => bail!(message),
643        }
644        Ok(None)
645    }
646
    /// Starts queued runtime requests in FIFO order until the queue is empty
    /// or the request at the front is an agent request with no capacity left.
    ///
    /// Agent and sleep requests are spawned onto their task sets. Nested
    /// workflow requests are awaited inline and resolved immediately.
    ///
    /// # Errors
    ///
    /// Fails when emitting the agent-started event fails or when a resolution
    /// cannot be delivered to the runtime actor.
    async fn start_pending_requests(
        &mut self,
        pending_requests: &mut VecDeque<WorkflowRuntimeRequest>,
        agent_tasks: &mut JoinSet<AgentTaskCompletion>,
        sleep_tasks: &mut JoinSet<SleepTaskCompletion>,
        js_commands: &mpsc::Sender<JsCommand>,
    ) -> anyhow::Result<()> {
        loop {
            // Peek first so a blocked agent request stays at the front.
            let Some(request) = pending_requests.front() else {
                return Ok(());
            };
            // A blocked agent request also holds back everything queued behind
            // it, because the queue is processed strictly in order.
            if matches!(request, WorkflowRuntimeRequest::Agent { .. })
                && !self.agent_capacity_available(agent_tasks.len())
            {
                return Ok(());
            }

            // Cannot fail: `front()` above just returned `Some`.
            let request = pending_requests
                .pop_front()
                .expect("pending request should exist");
            match request {
                WorkflowRuntimeRequest::Agent { .. } => match self.prepare_agent_request(request) {
                    Ok((id, prepared)) => {
                        self.emit_agent_started_event(&id, &prepared).await?;
                        self.spawn_agent_task(agent_tasks, id, prepared);
                    }
                    // A request that cannot be prepared is rejected back to the
                    // script instead of failing the whole workflow.
                    Err((id, error)) => {
                        send_js_command(
                            js_commands,
                            JsCommand::ResolveRequest {
                                id,
                                resolution: WorkflowRuntimeRequestResolution::Err {
                                    message: error.to_string(),
                                },
                            },
                        )
                        .await?;
                    }
                },
                WorkflowRuntimeRequest::Sleep { id, duration_ms } => {
                    self.spawn_sleep_task(sleep_tasks, id, duration_ms);
                }
                WorkflowRuntimeRequest::Workflow {
                    id,
                    workflow_ref,
                    args,
                } => {
                    // Record the call before running it.
                    self.workflow_calls.push(WorkflowRuntimeRequest::Workflow {
                        id: id.clone(),
                        workflow_ref: workflow_ref.clone(),
                        args: args.clone(),
                    });
                    let parent_event_step_id = self.event_step_id(&id);
                    // The nested workflow is awaited here, so the loop does not
                    // advance until it has finished.
                    let resolution = match self
                        .handle_workflow(parent_event_step_id, workflow_ref, args)
                        .await
                    {
                        Ok(value) => WorkflowRuntimeRequestResolution::OkWithBudget {
                            value,
                            budget: self.budget.clone(),
                        },
                        Err(error) => WorkflowRuntimeRequestResolution::Err {
                            message: error.to_string(),
                        },
                    };
                    send_js_command(js_commands, JsCommand::ResolveRequest { id, resolution })
                        .await?;
                }
            }
        }
    }
718
    /// Tears the run down after cancellation was requested.
    ///
    /// Always returns `Err("workflow cancelled")`; the `Ok` type exists only to
    /// match the driver loop's result type. Queued requests and sleeps are
    /// rejected. Running agent tasks are drained to completion when a session
    /// log sink is configured, and aborted otherwise.
    async fn cancel_workflow(
        &mut self,
        pending_requests: &mut VecDeque<WorkflowRuntimeRequest>,
        agent_tasks: &mut JoinSet<AgentTaskCompletion>,
        sleep_tasks: &mut JoinSet<SleepTaskCompletion>,
        js_commands: &mpsc::Sender<JsCommand>,
        js_events: &mut mpsc::Receiver<JsEvent>,
    ) -> anyhow::Result<RunWorkflowResult> {
        log::debug!(
            "workflow cancellation requested script={}",
            self.script_path.display()
        );

        // Fast path: nothing is queued or in flight. The helper is only
        // awaited when all preceding checks pass (`&&` short-circuits), and a
        // `true` result ends cancellation right here.
        if pending_requests.is_empty()
            && self.active_request_ids.is_empty()
            && agent_tasks.is_empty()
            && sleep_tasks.is_empty()
            && self
                .reject_next_runtime_request_for_cancellation(js_commands, js_events)
                .await
        {
            bail!("workflow cancelled");
        }

        self.reject_pending_requests_for_cancellation(pending_requests, js_commands)
            .await;
        sleep_tasks.abort_all();
        self.reject_active_sleep_requests_for_cancellation(sleep_tasks, js_commands)
            .await;

        if self.session_log_sink.is_some() {
            // Drain: let every running agent finish so its events and run
            // summary are still recorded; each request is then rejected.
            while let Some(completion) = agent_tasks.join_next().await {
                match completion {
                    Ok(AgentTaskCompletion {
                        id,
                        input,
                        provider,
                        result: Ok(result),
                    }) => {
                        self.active_request_ids.remove(&id);
                        // Event emission is best-effort during cancellation.
                        if let Err(error) = self
                            .emit_agent_result_events(&id, provider.as_deref(), &result)
                            .await
                        {
                            log::debug!("failed to emit drained agent events during cancellation: {error:#}");
                        }
                        if let Err(error) = self
                            .emit_agent_completed_event(&id, provider.as_deref(), &result)
                            .await
                        {
                            log::debug!("failed to emit drained agent completion during cancellation: {error:#}");
                        }
                        self.record_agent_run(&id, &input, provider, &result);
                        self.reject_request_for_cancellation(id, js_commands).await;
                    }
                    Ok(AgentTaskCompletion {
                        id,
                        provider,
                        result: Err(error),
                        ..
                    }) => {
                        self.active_request_ids.remove(&id);
                        let message = error.to_string();
                        if let Err(error) = self
                            .emit_agent_failed_event(&id, provider.as_deref(), &message)
                            .await
                        {
                            log::debug!("failed to emit drained agent failure during cancellation: {error:#}");
                        }
                        log::debug!("agent task failed while draining cancellation: {message}");
                        self.reject_request_for_cancellation(id, js_commands).await;
                    }
                    // A join error carries no request id, so it is only logged.
                    Err(error) => {
                        log::debug!("agent task join failed while draining cancellation: {error}");
                    }
                }
            }
        } else {
            // Abort: snapshot the ids first because the set is mutated in the loop.
            let ids: Vec<String> = self.active_request_ids.iter().cloned().collect();
            agent_tasks.abort_all();
            for id in ids {
                self.active_request_ids.remove(&id);
                self.reject_request_for_cancellation(id, js_commands).await;
            }
        }

        self.reject_remaining_active_requests_for_cancellation(js_commands)
            .await;
        self.drain_runtime_after_cancellation(js_events).await;
        // The actor may already be gone; a failed send is ignored.
        let _ = send_js_command(js_commands, JsCommand::Shutdown).await;
        bail!("workflow cancelled")
    }
811
812    async fn reject_next_runtime_request_for_cancellation(
813        &mut self,
814        js_commands: &mpsc::Sender<JsCommand>,
815        js_events: &mut mpsc::Receiver<JsEvent>,
816    ) -> bool {
817        loop {
818            match js_events.recv().await {
819                Some(JsEvent::Call(call)) => {
820                    let _ = self.handle_call(call).await;
821                }
822                Some(JsEvent::Request(request)) => {
823                    self.reject_request_for_cancellation(request.id().to_string(), js_commands)
824                        .await;
825                    return false;
826                }
827                Some(JsEvent::Complete(_)) | Some(JsEvent::Error(_)) | None => return true,
828            }
829        }
830    }
831
832    async fn reject_pending_requests_for_cancellation(
833        &mut self,
834        pending_requests: &mut VecDeque<WorkflowRuntimeRequest>,
835        js_commands: &mpsc::Sender<JsCommand>,
836    ) {
837        while let Some(request) = pending_requests.pop_front() {
838            self.reject_request_for_cancellation(request.id().to_string(), js_commands)
839                .await;
840        }
841    }
842
843    async fn reject_active_sleep_requests_for_cancellation(
844        &mut self,
845        sleep_tasks: &mut JoinSet<SleepTaskCompletion>,
846        js_commands: &mpsc::Sender<JsCommand>,
847    ) {
848        while let Some(completion) = sleep_tasks.join_next().await {
849            if let Ok(SleepTaskCompletion { id, .. }) = completion {
850                self.active_request_ids.remove(&id);
851                self.reject_request_for_cancellation(id, js_commands).await;
852            }
853        }
854    }
855
856    async fn reject_remaining_active_requests_for_cancellation(
857        &mut self,
858        js_commands: &mpsc::Sender<JsCommand>,
859    ) {
860        let ids: Vec<String> = self.active_request_ids.iter().cloned().collect();
861        for id in ids {
862            self.active_request_ids.remove(&id);
863            self.reject_request_for_cancellation(id, js_commands).await;
864        }
865    }
866
867    async fn reject_request_for_cancellation(
868        &self,
869        id: String,
870        js_commands: &mpsc::Sender<JsCommand>,
871    ) {
872        let _ = send_js_command(
873            js_commands,
874            JsCommand::ResolveRequest {
875                id,
876                resolution: WorkflowRuntimeRequestResolution::Err {
877                    message: "workflow cancelled".to_string(),
878                },
879            },
880        )
881        .await;
882    }
883
884    async fn drain_runtime_after_cancellation(&mut self, js_events: &mut mpsc::Receiver<JsEvent>) {
885        while let Some(event) = js_events.recv().await {
886            match event {
887                JsEvent::Call(call) => {
888                    let _ = self.handle_call(call).await;
889                }
890                JsEvent::Request(request) => {
891                    log::debug!(
892                        "ignoring request after cancellation id={} kind={}",
893                        request.id(),
894                        request.kind()
895                    );
896                }
897                JsEvent::Complete(_) | JsEvent::Error(_) => break,
898            }
899        }
900    }
901
902    fn event_step_id(&self, runtime_request_id: &str) -> String {
903        let parent = self.event_parent_step_id.as_deref().unwrap_or("");
904        let hash = blake3::hash(
905            format!("{parent}:{}:{runtime_request_id}", self.nesting_depth).as_bytes(),
906        );
907        format!("step_{}", &hash.to_hex()[..16])
908    }
909
910    async fn emit_event(&self, mut event: WorkflowEvent) -> anyhow::Result<()> {
911        if (event.event_type.as_str() != "workflow.started" || self.nesting_depth > 0)
912            && event.elapsed_nanos.is_none()
913        {
914            event.elapsed_nanos = Some(elapsed_nanos(self.event_start));
915        }
916        let metadata = event
917            .metadata
918            .get_or_insert_with(WorkflowEventMetadata::default);
919        if metadata.workflow_depth.is_none() {
920            metadata.workflow_depth = Some(u32::try_from(self.nesting_depth).unwrap_or(u32::MAX));
921        }
922        if metadata.parent_step_id.is_none() {
923            metadata.parent_step_id = self.event_parent_step_id.clone();
924        }
925        if let Some(event_sink) = self.event_sink.as_ref() {
926            event_sink.emit(event).await?;
927        }
928        Ok(())
929    }
930
931    async fn handle_call(&mut self, call: WorkflowRuntimeCall) -> anyhow::Result<()> {
932        match call {
933            WorkflowRuntimeCall::Log { values } => {
934                self.emit_event(WorkflowEvent::log(format_log_message(&values)))
935                    .await?;
936                self.logs.push(values);
937            }
938            WorkflowRuntimeCall::Phase { name, options } => {
939                let phase = WorkflowPhaseCall { name, options };
940                self.emit_event(WorkflowEvent::phase(
941                    phase.name.clone(),
942                    phase.options.clone(),
943                ))
944                .await?;
945                self.phases.push(phase);
946            }
947        }
948        Ok(())
949    }
950
951    fn agent_capacity_available(&self, in_flight: usize) -> bool {
952        let max_parallel = self
953            .max_parallel_agent_requests
954            .filter(|value| *value > 0)
955            .unwrap_or(usize::MAX);
956        in_flight < max_parallel
957    }
958
959    fn prepare_agent_request(
960        &mut self,
961        request: WorkflowRuntimeRequest,
962    ) -> Result<(String, PreparedAgentRun), (String, anyhow::Error)> {
963        match request {
964            WorkflowRuntimeRequest::Agent {
965                id,
966                prompt,
967                options,
968            } => {
969                self.agent_calls.push(WorkflowRuntimeRequest::Agent {
970                    id: id.clone(),
971                    prompt: prompt.clone(),
972                    options: options.clone(),
973                });
974                match self.prepare_agent_run(prompt, options) {
975                    Ok(prepared) => Ok((id, prepared)),
976                    Err(error) => Err((id, error)),
977                }
978            }
979            WorkflowRuntimeRequest::Workflow { .. } | WorkflowRuntimeRequest::Sleep { .. } => {
980                unreachable!("prepare_agent_request only accepts agent requests")
981            }
982        }
983    }
984
985    fn spawn_agent_task(
986        &mut self,
987        agent_tasks: &mut JoinSet<AgentTaskCompletion>,
988        id: String,
989        prepared: PreparedAgentRun,
990    ) {
991        let default_provider_name = self.agent_provider.name().to_string();
992        let default_provider = Arc::clone(&self.agent_provider);
993        let agent_runner = Arc::clone(&self.agent_runner);
994        let retry_in_runtime = agent_runner.retry_in_runtime();
995        let cancel_rx = self.cancel_rx.clone();
996        let completion_input = prepared.input.clone();
997        let completion_provider = prepared
998            .provider_override
999            .clone()
1000            .or(Some(default_provider_name));
1001        let session_log_sink = self.session_log_sink.clone();
1002        let max_parallel = self
1003            .max_parallel_agent_requests
1004            .filter(|value| *value > 0)
1005            .unwrap_or(usize::MAX);
1006        log::debug!(
1007            "starting agent request id={} in_flight_after_start={} max_parallel={}",
1008            id,
1009            agent_tasks.len() + 1,
1010            max_parallel
1011        );
1012        self.active_request_ids.insert(id.clone());
1013        agent_tasks.spawn(async move {
1014            let result = if retry_in_runtime {
1015                run_agent_runner_with_retry(
1016                    Arc::clone(&agent_runner),
1017                    default_provider,
1018                    prepared.provider_override,
1019                    prepared.input,
1020                    cancel_rx,
1021                )
1022                .await
1023            } else {
1024                agent_runner
1025                    .run_agent(default_provider, prepared.provider_override, prepared.input)
1026                    .await
1027            };
1028            let result = match result {
1029                Ok(result) => {
1030                    if let Some(session_log_sink) = session_log_sink.as_ref() {
1031                        let provider_name = completion_provider
1032                            .as_deref()
1033                            .expect("completion provider should always be set");
1034                        match session_log_sink
1035                            .write_agent_result(provider_name, &result)
1036                            .await
1037                        {
1038                            Ok(()) => Ok(result),
1039                            Err(error) => Err(error),
1040                        }
1041                    } else {
1042                        Ok(result)
1043                    }
1044                }
1045                Err(error) => Err(error),
1046            };
1047            AgentTaskCompletion {
1048                id,
1049                input: completion_input,
1050                provider: completion_provider,
1051                result,
1052            }
1053        });
1054    }
1055
1056    fn spawn_sleep_task(
1057        &mut self,
1058        sleep_tasks: &mut JoinSet<SleepTaskCompletion>,
1059        id: String,
1060        duration_ms: u64,
1061    ) {
1062        let agent_runner = Arc::clone(&self.agent_runner);
1063        log::debug!(
1064            "starting sleep request id={} duration_ms={}",
1065            id,
1066            duration_ms
1067        );
1068        self.active_request_ids.insert(id.clone());
1069        sleep_tasks.spawn(async move {
1070            SleepTaskCompletion {
1071                id,
1072                result: agent_runner.sleep(duration_ms).await,
1073            }
1074        });
1075    }
1076
1077    fn prepare_agent_run(
1078        &self,
1079        prompt: String,
1080        options: Option<Value>,
1081    ) -> anyhow::Result<PreparedAgentRun> {
1082        let options = apply_phase_defaults(options, &self.metadata);
1083        let context = AgentProviderContext {
1084            phase: options
1085                .as_ref()
1086                .and_then(|options| options.get("phase"))
1087                .and_then(Value::as_str)
1088                .map(ToString::to_string),
1089            cwd: self.script_path.parent().map(Path::to_path_buf),
1090        };
1091        let provider_override = options
1092            .as_ref()
1093            .and_then(|options| options.get("provider"))
1094            .and_then(Value::as_str)
1095            .map(ToString::to_string);
1096        let provider_name = provider_override
1097            .as_deref()
1098            .unwrap_or_else(|| self.agent_provider.name());
1099        let options = resolve_model_options(options, provider_name, &self.model_map)?;
1100        agent_retry_policy(&options)?;
1101        log::debug!(
1102            "agent call provider={} phase={:?} model={:?} prompt_len={}",
1103            provider_name,
1104            context.phase.as_deref(),
1105            options
1106                .as_ref()
1107                .and_then(|options| options.get("model"))
1108                .and_then(Value::as_str),
1109            prompt.len()
1110        );
1111        Ok(PreparedAgentRun {
1112            provider_override,
1113            input: AgentProviderRunInput {
1114                prompt,
1115                options,
1116                context,
1117            },
1118        })
1119    }
1120
1121    async fn emit_agent_started_event(
1122        &self,
1123        id: &str,
1124        prepared: &PreparedAgentRun,
1125    ) -> anyhow::Result<()> {
1126        let provider = prepared
1127            .provider_override
1128            .as_deref()
1129            .unwrap_or_else(|| self.agent_provider.name());
1130        let metadata = self.agent_event_metadata(id, Some(provider), None);
1131        self.emit_event(WorkflowEvent::agent_started(
1132            serde_json::json!({
1133                "phase": prepared.input.context.phase,
1134                "promptPreview": truncate_for_event(&prepared.input.prompt, 200),
1135            }),
1136            metadata,
1137        ))
1138        .await
1139    }
1140
1141    async fn apply_agent_result(
1142        &mut self,
1143        id: &str,
1144        input: &AgentProviderRunInput,
1145        provider: Option<String>,
1146        result: AgentProviderResult,
1147    ) -> anyhow::Result<Value> {
1148        if let Some(output_tokens) = result.usage.as_ref().and_then(|usage| usage.output_tokens) {
1149            self.budget.spent = self.budget.spent.saturating_add(output_tokens);
1150        }
1151        self.emit_agent_result_events(id, provider.as_deref(), &result)
1152            .await?;
1153        self.emit_agent_completed_event(id, provider.as_deref(), &result)
1154            .await?;
1155        self.record_agent_run(id, input, provider, &result);
1156        log::debug!(
1157            "agent call complete session_id={:?} output_tokens={:?} budget_spent={}",
1158            result.session_id,
1159            result.usage.as_ref().and_then(|usage| usage.output_tokens),
1160            self.budget.spent
1161        );
1162        Ok(result.output)
1163    }
1164
1165    async fn emit_agent_result_events(
1166        &self,
1167        id: &str,
1168        provider: Option<&str>,
1169        result: &AgentProviderResult,
1170    ) -> anyhow::Result<()> {
1171        let Some(raw) = result.raw.as_ref() else {
1172            return Ok(());
1173        };
1174        let metadata = self.agent_event_metadata(id, provider, result.session_id.clone());
1175        for provider_event in raw_agent_event_payloads(raw) {
1176            let event_data = agent_session_event_payload(provider_event, &metadata);
1177            self.emit_event(WorkflowEvent::agent_event(event_data, metadata.clone()))
1178                .await?;
1179        }
1180        Ok(())
1181    }
1182
1183    async fn emit_agent_completed_event(
1184        &self,
1185        id: &str,
1186        provider: Option<&str>,
1187        result: &AgentProviderResult,
1188    ) -> anyhow::Result<()> {
1189        let metadata = self.agent_event_metadata(id, provider, result.session_id.clone());
1190        self.emit_event(WorkflowEvent::agent_completed(
1191            serde_json::json!({
1192                "sessionId": result.session_id,
1193                "model": result.model,
1194                "usage": result.usage,
1195            }),
1196            metadata,
1197        ))
1198        .await
1199    }
1200
1201    async fn emit_agent_failed_event(
1202        &self,
1203        id: &str,
1204        provider: Option<&str>,
1205        message: &str,
1206    ) -> anyhow::Result<()> {
1207        let metadata = self.agent_event_metadata(id, provider, None);
1208        self.emit_event(WorkflowEvent::agent_failed(
1209            serde_json::json!({ "message": message }),
1210            metadata,
1211        ))
1212        .await
1213    }
1214
1215    fn agent_event_metadata(
1216        &self,
1217        id: &str,
1218        provider: Option<&str>,
1219        session_id: Option<String>,
1220    ) -> WorkflowEventMetadata {
1221        WorkflowEventMetadata {
1222            run_id: None,
1223            step_id: Some(self.event_step_id(id)),
1224            provider: Some(
1225                provider
1226                    .unwrap_or_else(|| self.agent_provider.name())
1227                    .to_string(),
1228            ),
1229            session_id,
1230            workflow_depth: None,
1231            parent_step_id: None,
1232        }
1233    }
1234
1235    fn record_agent_run(
1236        &mut self,
1237        id: &str,
1238        input: &AgentProviderRunInput,
1239        provider: Option<String>,
1240        result: &AgentProviderResult,
1241    ) {
1242        add_usage(&mut self.token_usage, result.usage.as_ref());
1243        if let Some(phase) = input.context.phase.as_ref() {
1244            let phase_usage = self.token_usage_by_phase.entry(phase.clone()).or_default();
1245            add_usage(phase_usage, result.usage.as_ref());
1246        }
1247        let model = result.model.clone().or_else(|| {
1248            input
1249                .options
1250                .as_ref()
1251                .and_then(|options| options.get("model"))
1252                .and_then(Value::as_str)
1253                .map(ToString::to_string)
1254        });
1255        self.agent_runs.push(WorkflowAgentRunSummary {
1256            id: id.to_string(),
1257            phase: input.context.phase.clone(),
1258            provider,
1259            model,
1260            provider_session_id: result.session_id.clone(),
1261            usage: result.usage.clone(),
1262            isolation: result.isolation.clone(),
1263        });
1264    }
1265
1266    async fn handle_workflow(
1267        &mut self,
1268        parent_step_id: String,
1269        workflow_ref: WorkflowRef,
1270        args: Option<Value>,
1271    ) -> anyhow::Result<Value> {
1272        if self.nesting_depth >= 1 {
1273            bail!("Nested workflow() calls are limited to one level");
1274        }
1275        let script_path = match workflow_ref {
1276            WorkflowRef::ScriptPath { script_path } => {
1277                resolve_relative_script(&self.script_path, &script_path)
1278            }
1279            WorkflowRef::Name(name) => resolve_named_workflow(&name)?,
1280        };
1281        log::debug!("child workflow call script={}", script_path.display());
1282        let child = Box::pin(run_workflow_inner(RunWorkflowOptions {
1283            script_path,
1284            args: args.unwrap_or(Value::Null),
1285            agent_provider: Arc::clone(&self.agent_provider),
1286            model_map: self.model_map.clone(),
1287            budget_total: self.budget.total,
1288            budget_spent: self.budget.spent,
1289            nesting_depth: self.nesting_depth + 1,
1290            max_parallel_agent_requests: self.max_parallel_agent_requests,
1291            agent_runner: Some(Arc::clone(&self.agent_runner)),
1292            cancel_rx: self.cancel_rx.clone(),
1293            event_sink: self.event_sink.clone(),
1294            event_parent_step_id: Some(parent_step_id),
1295            event_stream_start: Some(self.event_start),
1296            session_log_sink: self.session_log_sink.clone(),
1297        }))
1298        .await?;
1299        self.budget = child.budget;
1300        self.logs.extend(child.logs);
1301        self.phases.extend(child.phases);
1302        self.agent_calls.extend(child.agent_calls);
1303        self.workflow_calls.extend(child.workflow_calls);
1304        merge_token_usage(&mut self.token_usage, &child.token_usage);
1305        for (phase, usage) in child.token_usage_by_phase {
1306            merge_token_usage(self.token_usage_by_phase.entry(phase).or_default(), &usage);
1307        }
1308        self.agent_runs.extend(child.agent_runs);
1309        Ok(child.output.result)
1310    }
1311}
1312
1313async fn run_agent_runner_with_retry(
1314    agent_runner: Arc<dyn WorkflowAgentRunner>,
1315    default_provider: Arc<dyn AgentProvider>,
1316    provider_override: Option<String>,
1317    input: AgentProviderRunInput,
1318    mut cancel_rx: Option<watch::Receiver<bool>>,
1319) -> anyhow::Result<AgentProviderResult> {
1320    let retry = agent_retry_policy(&input.options)?;
1321    let mut final_result = None;
1322    for attempt in 1..=retry.max_attempts {
1323        let attempt_result = agent_runner
1324            .run_agent(
1325                Arc::clone(&default_provider),
1326                provider_override.clone(),
1327                input.clone(),
1328            )
1329            .await;
1330        match attempt_result {
1331            Ok(result) => {
1332                final_result = Some(Ok(result));
1333                break;
1334            }
1335            Err(error) if attempt < retry.max_attempts => {
1336                log::debug!(
1337                    "agent call failed on attempt {attempt}/{}; retrying after {}ms: {error:#}",
1338                    retry.max_attempts,
1339                    retry.backoff_ms
1340                );
1341                sleep_retry_backoff(retry.backoff_ms, &mut cancel_rx).await?;
1342            }
1343            Err(error) => {
1344                final_result = Some(Err(error));
1345                break;
1346            }
1347        }
1348    }
1349    final_result.unwrap_or_else(|| Err(anyhow!("agent retry loop finished without a result")))
1350}
1351
1352async fn sleep_retry_backoff(
1353    backoff_ms: u64,
1354    cancel_rx: &mut Option<watch::Receiver<bool>>,
1355) -> anyhow::Result<()> {
1356    if backoff_ms == 0 {
1357        return Ok(());
1358    }
1359    let Some(cancel_rx) = cancel_rx.as_mut() else {
1360        tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
1361        return Ok(());
1362    };
1363    if *cancel_rx.borrow() {
1364        bail!("workflow cancelled");
1365    }
1366    let sleep = tokio::time::sleep(Duration::from_millis(backoff_ms));
1367    tokio::pin!(sleep);
1368    loop {
1369        tokio::select! {
1370            _ = &mut sleep => return Ok(()),
1371            changed = cancel_rx.changed() => {
1372                match changed {
1373                    Ok(()) if *cancel_rx.borrow() => bail!("workflow cancelled"),
1374                    Ok(()) => continue,
1375                    Err(_) => {
1376                        sleep.await;
1377                        return Ok(());
1378                    }
1379                }
1380            }
1381        }
1382    }
1383}
1384
1385pub(crate) async fn run_agent_provider_with_retry(
1386    default_provider: Arc<dyn AgentProvider>,
1387    provider_override: Option<String>,
1388    input: AgentProviderRunInput,
1389    mut cancel_rx: Option<watch::Receiver<bool>>,
1390) -> anyhow::Result<AgentProviderResult> {
1391    let retry = agent_retry_policy(&input.options)?;
1392    let provider = resolve_agent_provider(default_provider, provider_override)?;
1393    let mut final_result = None;
1394    for attempt in 1..=retry.max_attempts {
1395        let attempt_result =
1396            run_agent_with_optional_isolation(Arc::clone(&provider), input.clone()).await;
1397        match attempt_result {
1398            Ok(result) => {
1399                final_result = Some(Ok(result));
1400                break;
1401            }
1402            Err(error) if attempt < retry.max_attempts => {
1403                log::debug!(
1404                    "agent provider failed on attempt {attempt}/{}; retrying after {}ms: {error:#}",
1405                    retry.max_attempts,
1406                    retry.backoff_ms
1407                );
1408                sleep_retry_backoff(retry.backoff_ms, &mut cancel_rx).await?;
1409            }
1410            Err(error) => {
1411                final_result = Some(Err(error));
1412                break;
1413            }
1414        }
1415    }
1416    final_result.unwrap_or_else(|| Err(anyhow!("agent retry loop finished without a result")))
1417}
1418
1419pub(crate) async fn run_agent_provider(
1420    default_provider: Arc<dyn AgentProvider>,
1421    provider_override: Option<String>,
1422    input: AgentProviderRunInput,
1423) -> anyhow::Result<AgentProviderResult> {
1424    let provider = resolve_agent_provider(default_provider, provider_override)?;
1425    run_agent_with_optional_isolation(provider, input).await
1426}
1427
1428fn resolve_agent_provider(
1429    default_provider: Arc<dyn AgentProvider>,
1430    provider_override: Option<String>,
1431) -> anyhow::Result<Arc<dyn AgentProvider>> {
1432    if let Some(provider_override) = provider_override {
1433        Ok(Arc::from(create_agent_provider(&provider_override)?))
1434    } else {
1435        Ok(default_provider)
1436    }
1437}
1438
/// Retry settings for a single agent call, parsed from the `retry` option.
#[derive(Debug, Clone, Copy)]
pub(crate) struct AgentRetryPolicy {
    /// Upper bound on attempts made by the retry loops, counting the first one.
    pub max_attempts: u32,
    /// Milliseconds to wait after a failed attempt before the next one.
    pub backoff_ms: u64,
}
1444
1445pub(crate) fn agent_retry_policy(options: &Option<Value>) -> anyhow::Result<AgentRetryPolicy> {
1446    let default = AgentRetryPolicy {
1447        max_attempts: 1,
1448        backoff_ms: 0,
1449    };
1450    let Some(retry) = options.as_ref().and_then(|options| options.get("retry")) else {
1451        return Ok(default);
1452    };
1453    if retry.is_null() {
1454        return Ok(default);
1455    }
1456    let object = retry
1457        .as_object()
1458        .ok_or_else(|| anyhow!("agent retry option must be an object"))?;
1459    let max_attempts = match object.get("maxAttempts") {
1460        Some(value) => {
1461            let value = value
1462                .as_u64()
1463                .ok_or_else(|| anyhow!("agent retry.maxAttempts must be a positive integer"))?;
1464            if value == 0 || value > u32::MAX as u64 {
1465                bail!("agent retry.maxAttempts must be between 1 and {}", u32::MAX);
1466            }
1467            value as u32
1468        }
1469        None => default.max_attempts,
1470    };
1471    let backoff_ms = match object.get("backoffMs") {
1472        Some(value) => value
1473            .as_u64()
1474            .ok_or_else(|| anyhow!("agent retry.backoffMs must be a non-negative integer"))?,
1475        None => default.backoff_ms,
1476    };
1477    Ok(AgentRetryPolicy {
1478        max_attempts,
1479        backoff_ms,
1480    })
1481}
1482
1483async fn run_agent_with_optional_isolation(
1484    provider: Arc<dyn AgentProvider>,
1485    input: AgentProviderRunInput,
1486) -> anyhow::Result<AgentProviderResult> {
1487    if !requests_worktree_isolation(&input.options) {
1488        return run_agent_with_schema_validation(provider, input).await;
1489    }
1490
1491    let isolation = WorktreeIsolation::create(input.context.cwd.as_deref())?;
1492    let isolation_info = isolation.info();
1493    let mut isolated_input = input;
1494    isolated_input.context.cwd = Some(isolation.cwd.clone());
1495    let mut result = run_agent_with_schema_validation(provider, isolated_input).await;
1496    if let Ok(result) = &mut result {
1497        result.isolation = Some(isolation_info);
1498    }
1499    if let Err(error) = isolation.cleanup() {
1500        log::warn!("failed to cleanup isolated agent worktree: {error:#}");
1501    }
1502    result
1503}
1504
1505fn requests_worktree_isolation(options: &Option<Value>) -> bool {
1506    options
1507        .as_ref()
1508        .and_then(|options| options.get("isolation"))
1509        .and_then(Value::as_str)
1510        == Some("worktree")
1511}
1512
/// A temporary git worktree on its own branch, used to isolate one agent run.
struct WorktreeIsolation {
    /// Canonicalized root of the git repository; git commands run from here.
    repo_root: PathBuf,
    /// Path of the worktree checkout inside the temporary directory.
    worktree_root: PathBuf,
    /// Directory inside the worktree that mirrors the original cwd's position
    /// relative to the repository root.
    cwd: PathBuf,
    /// Name of the branch created together with the worktree.
    branch_name: String,
    /// Set once `cleanup` has finished; `Drop` skips its cleanup when true.
    cleaned: bool,
    /// Temporary directory that contains the worktree; held so it lives as
    /// long as this value.
    _temp_dir: tempfile::TempDir,
}
1521
1522impl WorktreeIsolation {
1523    fn create(cwd: Option<&Path>) -> anyhow::Result<Self> {
1524        let cwd = cwd
1525            .map(Path::to_path_buf)
1526            .unwrap_or(std::env::current_dir()?)
1527            .canonicalize()
1528            .context("failed to canonicalize workflow cwd for worktree isolation")?;
1529        let repo_root = git_output(&cwd, &["rev-parse", "--show-toplevel"]).context(
1530            "agent isolation='worktree' requires the workflow cwd to be inside a git repository",
1531        )?;
1532        let repo_root = PathBuf::from(repo_root.trim())
1533            .canonicalize()
1534            .context("failed to canonicalize git repository root for worktree isolation")?;
1535        let relative_cwd = cwd.strip_prefix(&repo_root).with_context(|| {
1536            format!(
1537                "workflow cwd {} is not under git repository root {}",
1538                cwd.display(),
1539                repo_root.display()
1540            )
1541        })?;
1542
1543        let temp_dir = tempfile::Builder::new()
1544            .prefix("smol-wf-agent-worktree-")
1545            .tempdir()
1546            .context("failed to create temp directory for agent worktree isolation")?;
1547        let worktree_root = temp_dir.path().join("worktree");
1548        let worktree_arg = path_arg(&worktree_root);
1549        let branch_name = format!(
1550            "smol-wf/agent-run/{}",
1551            ulid::Ulid::new().to_string().to_ascii_lowercase()
1552        );
1553        git_status(
1554            &repo_root,
1555            &[
1556                "worktree",
1557                "add",
1558                "--quiet",
1559                "-b",
1560                &branch_name,
1561                &worktree_arg,
1562                "HEAD",
1563            ],
1564        )
1565        .context("failed to create isolated git worktree for agent run")?;
1566        let isolated_cwd = if relative_cwd.as_os_str().is_empty() {
1567            worktree_root.clone()
1568        } else {
1569            worktree_root.join(relative_cwd)
1570        };
1571        Ok(Self {
1572            repo_root,
1573            worktree_root,
1574            cwd: isolated_cwd,
1575            branch_name,
1576            cleaned: false,
1577            _temp_dir: temp_dir,
1578        })
1579    }
1580
1581    fn info(&self) -> AgentRunIsolation {
1582        AgentRunIsolation {
1583            kind: "worktree".to_string(),
1584            branch: Some(self.branch_name.clone()),
1585            worktree_path: Some(path_arg(&self.worktree_root)),
1586            cwd: Some(path_arg(&self.cwd)),
1587        }
1588    }
1589
1590    fn cleanup(mut self) -> anyhow::Result<()> {
1591        self.remove_worktree()?;
1592        self.delete_branch()?;
1593        self.cleaned = true;
1594        Ok(())
1595    }
1596
1597    fn remove_worktree(&self) -> anyhow::Result<()> {
1598        let worktree_arg = path_arg(&self.worktree_root);
1599        git_status(
1600            &self.repo_root,
1601            &["worktree", "remove", "--force", &worktree_arg],
1602        )
1603        .context("failed to remove isolated git worktree")
1604    }
1605
1606    fn delete_branch(&self) -> anyhow::Result<()> {
1607        git_status(&self.repo_root, &["branch", "-D", &self.branch_name])
1608            .context("failed to delete isolated agent worktree branch")
1609    }
1610}
1611
1612impl Drop for WorktreeIsolation {
1613    fn drop(&mut self) {
1614        if !self.cleaned {
1615            if let Err(error) = self.remove_worktree() {
1616                log::warn!("failed to cleanup isolated agent worktree during drop: {error:#}");
1617            }
1618            if let Err(error) = self.delete_branch() {
1619                log::warn!(
1620                    "failed to delete isolated agent worktree branch during drop: {error:#}"
1621                );
1622            }
1623        }
1624    }
1625}
1626
/// Renders a path as an owned `String`, replacing any non-UTF-8 sequences
/// with U+FFFD (lossy conversion).
fn path_arg(path: &Path) -> String {
    String::from(path.to_string_lossy())
}
1630
1631fn git_output(cwd: &Path, args: &[&str]) -> anyhow::Result<String> {
1632    let output = StdCommand::new("git")
1633        .args(args)
1634        .current_dir(cwd)
1635        .output()
1636        .with_context(|| format!("failed to run git {}", args.join(" ")))?;
1637    if output.status.success() {
1638        Ok(String::from_utf8_lossy(&output.stdout).trim().to_string())
1639    } else {
1640        bail!(
1641            "git {} failed with {}{}",
1642            args.join(" "),
1643            status_text(output.status.code()),
1644            command_stderr(&output.stderr)
1645        )
1646    }
1647}
1648
1649fn git_status(cwd: &Path, args: &[&str]) -> anyhow::Result<()> {
1650    let output = StdCommand::new("git")
1651        .args(args)
1652        .current_dir(cwd)
1653        .output()
1654        .with_context(|| format!("failed to run git {}", args.join(" ")))?;
1655    if output.status.success() {
1656        Ok(())
1657    } else {
1658        bail!(
1659            "git {} failed with {}{}",
1660            args.join(" "),
1661            status_text(output.status.code()),
1662            command_stderr(&output.stderr)
1663        )
1664    }
1665}
1666
/// Formats an optional exit code for error messages.
///
/// Yields `code <n>` when a code is present and `signal` when it is absent.
fn status_text(code: Option<i32>) -> String {
    match code {
        Some(code) => format!("code {code}"),
        None => "signal".to_string(),
    }
}
1671
/// Turns captured stderr bytes into a message suffix.
///
/// The bytes are decoded lossily and trimmed. Blank output yields an empty
/// string; otherwise the text is returned prefixed with `": "`.
fn command_stderr(stderr: &[u8]) -> String {
    match String::from_utf8_lossy(stderr).trim() {
        "" => String::new(),
        stderr => format!(": {stderr}"),
    }
}
1681
1682async fn run_agent_with_schema_validation(
1683    provider: Arc<dyn AgentProvider>,
1684    input: AgentProviderRunInput,
1685) -> anyhow::Result<AgentProviderResult> {
1686    let Some(schema) = input
1687        .options
1688        .as_ref()
1689        .and_then(|options| options.get("schema"))
1690        .cloned()
1691    else {
1692        return provider.run(input).await;
1693    };
1694
1695    let max_attempts = 2;
1696    let original_prompt = input.prompt.clone();
1697    let mut attempt_input = input;
1698    let mut last_errors = Vec::new();
1699
1700    for attempt in 1..=max_attempts {
1701        let result = provider.run(attempt_input.clone()).await?;
1702        match validate_structured_output(&schema, &result.output) {
1703            Ok(()) => return Ok(result),
1704            Err(errors) => {
1705                last_errors = errors;
1706                if attempt < max_attempts {
1707                    attempt_input.prompt =
1708                        with_structured_output_retry_prompt(&original_prompt, &last_errors);
1709                }
1710            }
1711        }
1712    }
1713
1714    bail!(
1715        "{}",
1716        format_structured_output_validation_error(&last_errors)
1717    )
1718}
1719
1720fn validate_structured_output(schema: &Value, output: &Value) -> Result<(), Vec<String>> {
1721    let validator = jsonschema::validator_for(schema)
1722        .map_err(|error| vec![format!("/ schema is invalid: {}", error)])?;
1723    let errors = validator
1724        .iter_errors(output)
1725        .map(|error| {
1726            let path = error.instance_path().to_string();
1727            let path = if path.is_empty() {
1728                "/".to_string()
1729            } else {
1730                path
1731            };
1732            format!("{path} {error}")
1733        })
1734        .collect::<Vec<_>>();
1735
1736    if errors.is_empty() {
1737        Ok(())
1738    } else {
1739        Err(errors)
1740    }
1741}
1742
/// Builds the final error message for structured output that failed schema
/// validation, listing every validation error separated by `"; "`.
fn format_structured_output_validation_error(errors: &[String]) -> String {
    let details = errors.join("; ");
    format!(
        "Structured output did not match JSON Schema: {}",
        details
    )
}
1749
/// Extends `prompt` with retry instructions after a failed schema validation.
///
/// The result is the original prompt, a blank line, a fixed three-line notice
/// and then one `- <error>` bullet per validation error, all separated by
/// newlines (no trailing newline).
fn with_structured_output_retry_prompt(prompt: &str, errors: &[String]) -> String {
    const NOTICE: [&str; 4] = [
        "",
        "Previous structured output failed JSON Schema validation.",
        "Return a corrected structured output that satisfies the original JSON Schema.",
        "Validation errors:",
    ];

    let mut message = String::from(prompt);
    for line in &NOTICE {
        message.push('\n');
        message.push_str(line);
    }
    for error in errors {
        message.push('\n');
        message.push_str("- ");
        message.push_str(error);
    }
    message
}
1761
/// A model selector broken down into its parts.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ResolvedModelSelector {
    /// The model string as originally requested.
    requested: String,
    /// The full selector string that was parsed.
    selector: String,
    /// The model id without any provider qualifier or query string.
    model_id: String,
    /// Model provider qualifier, when the selector carried one.
    model_provider: Option<String>,
    /// Value of the `thinking` query parameter, when present.
    thinking: Option<String>,
}

impl ResolvedModelSelector {
    /// Returns `provider/model_id` when a provider is set, otherwise just the
    /// model id.
    fn provider_model(&self) -> String {
        let model_id = self.model_id.as_str();
        self.model_provider.as_deref().map_or_else(
            || model_id.to_owned(),
            |provider| format!("{provider}/{model_id}"),
        )
    }
}
1779
1780fn resolve_model_options(
1781    options: Option<Value>,
1782    agent_provider: &str,
1783    model_map: &BTreeMap<String, String>,
1784) -> anyhow::Result<Option<Value>> {
1785    let Some(model) = options
1786        .as_ref()
1787        .and_then(Value::as_object)
1788        .and_then(|object| object.get("model"))
1789        .and_then(Value::as_str)
1790        .map(ToString::to_string)
1791    else {
1792        return Ok(options);
1793    };
1794
1795    let mapped_selector = model_map.get(&model).cloned();
1796    let alias_matched = mapped_selector.is_some();
1797    let selector = mapped_selector.unwrap_or_else(|| model.clone());
1798    let resolved = parse_model_selector(&model, &selector)?;
1799    validate_model_selector_for_provider(&resolved, agent_provider)?;
1800
1801    let mut object = options
1802        .and_then(|value| value.as_object().cloned())
1803        .unwrap_or_default();
1804    object.insert(
1805        "model".to_string(),
1806        Value::String(resolved.provider_model()),
1807    );
1808
1809    let selector_has_extra_parts = alias_matched
1810        || resolved.selector.contains('?')
1811        || resolved.model_provider.is_some()
1812        || resolved.thinking.is_some();
1813    if selector_has_extra_parts {
1814        object.insert(
1815            "requestedModel".to_string(),
1816            Value::String(resolved.requested.clone()),
1817        );
1818        object.insert(
1819            "modelSelector".to_string(),
1820            Value::String(resolved.selector.clone()),
1821        );
1822    } else {
1823        object.remove("requestedModel");
1824        object.remove("modelSelector");
1825    }
1826
1827    if let Some(provider) = resolved.model_provider {
1828        object.insert("modelProvider".to_string(), Value::String(provider));
1829    } else {
1830        object.remove("modelProvider");
1831    }
1832    if let Some(thinking) = resolved.thinking {
1833        object.insert("thinking".to_string(), Value::String(thinking));
1834    } else {
1835        object.remove("thinking");
1836    }
1837    Ok(Some(Value::Object(object)))
1838}
1839
1840fn parse_model_selector(requested: &str, selector: &str) -> anyhow::Result<ResolvedModelSelector> {
1841    let (model_part, query) = selector.split_once('?').unwrap_or((selector, ""));
1842    if model_part.trim().is_empty() {
1843        bail!("model selector must include a model id: {selector}");
1844    }
1845
1846    let (slash_provider, model_id) = match model_part.split_once('/') {
1847        Some((provider, model_id)) if !provider.is_empty() && !model_id.is_empty() => {
1848            (Some(provider.to_string()), model_id.to_string())
1849        }
1850        Some(_) => bail!("model selector provider/model form is invalid: {selector}"),
1851        None => (None, model_part.to_string()),
1852    };
1853
1854    let mut query_provider = None::<String>;
1855    let mut thinking = None::<String>;
1856    if !query.is_empty() {
1857        for pair in query.split('&') {
1858            if pair.is_empty() {
1859                continue;
1860            }
1861            let (key, value) = pair.split_once('=').ok_or_else(|| {
1862                anyhow!("model selector query parameter must use key=value: {pair}")
1863            })?;
1864            let key = percent_decode(key)?;
1865            let value = percent_decode(value)?;
1866            if value.is_empty() {
1867                bail!("model selector query parameter `{key}` must not be empty");
1868            }
1869            match key.as_str() {
1870                "provider" => set_unique_query_value(&mut query_provider, key, value)?,
1871                "thinking" => set_unique_query_value(&mut thinking, key, value)?,
1872                _ => bail!("unknown model selector query parameter `{key}`"),
1873            }
1874        }
1875    }
1876
1877    let model_provider = match (slash_provider, query_provider) {
1878        (Some(slash), Some(query)) if slash != query => bail!(
1879            "conflicting model provider qualifiers in selector `{selector}`: `{slash}` and `{query}`"
1880        ),
1881        (Some(provider), Some(_)) | (Some(provider), None) | (None, Some(provider)) => {
1882            Some(provider)
1883        }
1884        (None, None) => None,
1885    };
1886
1887    Ok(ResolvedModelSelector {
1888        requested: requested.to_string(),
1889        selector: selector.to_string(),
1890        model_id,
1891        model_provider,
1892        thinking,
1893    })
1894}
1895
1896fn set_unique_query_value(
1897    target: &mut Option<String>,
1898    key: String,
1899    value: String,
1900) -> anyhow::Result<()> {
1901    if target.replace(value).is_some() {
1902        bail!("duplicate model selector query parameter `{key}`");
1903    }
1904    Ok(())
1905}
1906
1907fn percent_decode(value: &str) -> anyhow::Result<String> {
1908    let bytes = value.as_bytes();
1909    let mut output = Vec::with_capacity(bytes.len());
1910    let mut index = 0;
1911    while index < bytes.len() {
1912        match bytes[index] {
1913            b'%' => {
1914                if index + 2 >= bytes.len() {
1915                    bail!("invalid percent escape in model selector query: {value}");
1916                }
1917                let high = hex_value(bytes[index + 1]).ok_or_else(|| {
1918                    anyhow!("invalid percent escape in model selector query: {value}")
1919                })?;
1920                let low = hex_value(bytes[index + 2]).ok_or_else(|| {
1921                    anyhow!("invalid percent escape in model selector query: {value}")
1922                })?;
1923                output.push((high << 4) | low);
1924                index += 3;
1925            }
1926            b'+' => {
1927                output.push(b' ');
1928                index += 1;
1929            }
1930            byte => {
1931                output.push(byte);
1932                index += 1;
1933            }
1934        }
1935    }
1936    String::from_utf8(output).context("model selector query is not valid UTF-8")
1937}
1938
/// Returns the numeric value (0-15) of an ASCII hexadecimal digit, accepting
/// both upper and lower case, or `None` for any other byte.
fn hex_value(byte: u8) -> Option<u8> {
    // A base-16 digit is always below 16, so narrowing to `u8` is lossless.
    char::from(byte).to_digit(16).map(|digit| digit as u8)
}
1947
1948fn validate_model_selector_for_provider(
1949    resolved: &ResolvedModelSelector,
1950    agent_provider: &str,
1951) -> anyhow::Result<()> {
1952    match agent_provider {
1953        "codex" => {
1954            if resolved.model_provider.is_some() {
1955                bail!("Codex model selectors do not support ?provider=... or provider/model form");
1956            }
1957            if resolved.thinking.is_some() {
1958                bail!("Codex model selectors do not support thinking=...");
1959            }
1960        }
1961        "claude-code" if resolved.model_provider.is_some() => {
1962            bail!(
1963                "Claude Code model selectors do not support ?provider=... or provider/model form"
1964            );
1965        }
1966        "opencode" if resolved.model_provider.is_none() => {
1967            bail!("OpenCode model selectors must use provider/model or ?provider=...");
1968        }
1969        "debug" | "pi" => {}
1970        _ => {}
1971    }
1972    Ok(())
1973}
1974
1975fn apply_phase_defaults(options: Option<Value>, metadata: &WorkflowMetadata) -> Option<Value> {
1976    let phase_name = options
1977        .as_ref()
1978        .and_then(|options| options.get("phase"))
1979        .and_then(Value::as_str)
1980        .map(ToString::to_string);
1981    let phase_metadata = phase_name.as_ref().and_then(|phase_name| {
1982        metadata
1983            .phases
1984            .iter()
1985            .find(|phase| phase.title == *phase_name)
1986    });
1987
1988    if phase_name.is_none() && phase_metadata.is_none() {
1989        return options;
1990    }
1991
1992    let mut object = options
1993        .and_then(|value| value.as_object().cloned())
1994        .unwrap_or_default();
1995
1996    if let Some(phase_name) = phase_name {
1997        object
1998            .entry("phase".to_string())
1999            .or_insert(Value::String(phase_name));
2000    }
2001    if let Some(model) = phase_metadata.and_then(|phase| phase.model.clone()) {
2002        object
2003            .entry("model".to_string())
2004            .or_insert(Value::String(model));
2005    }
2006    if let Some(provider) = phase_metadata.and_then(|phase| phase.provider.clone()) {
2007        object
2008            .entry("provider".to_string())
2009            .or_insert(Value::String(provider));
2010    }
2011
2012    Some(Value::Object(object))
2013}
2014
/// Resolves `script_path` relative to the directory of `current_script_path`.
///
/// Absolute paths are returned unchanged. When the current script path has no
/// parent, `.` is used as the base directory.
fn resolve_relative_script(current_script_path: &Path, script_path: &str) -> PathBuf {
    let requested = Path::new(script_path);
    if requested.is_absolute() {
        return requested.to_path_buf();
    }

    let base_dir = match current_script_path.parent() {
        Some(dir) => dir,
        None => Path::new("."),
    };
    base_dir.join(requested)
}
2026
2027fn resolve_named_workflow(name: &str) -> anyhow::Result<PathBuf> {
2028    let workflows_dir = PathBuf::from(".claude/workflows");
2029    for entry in fs::read_dir(&workflows_dir).unwrap_or_else(|_| fs::read_dir(".").unwrap()) {
2030        let entry = entry?;
2031        let path = entry.path();
2032        if path.extension().and_then(|extension| extension.to_str()) != Some("js") {
2033            continue;
2034        }
2035        if read_workflow_metadata(&path)?.is_some_and(|metadata| metadata.name == name) {
2036            return Ok(path);
2037        }
2038    }
2039    bail!("Unknown workflow: {name}")
2040}