1use crate::agent_providers::{
2 create_agent_provider, AgentProvider, AgentProviderContext, AgentProviderResult,
3 AgentProviderRunInput, AgentRunIsolation, AgentUsage, AgentUsageCost,
4};
5use crate::js_runtime::rquickjs::RQuickJSWorkflowRuntime;
6use crate::js_runtime::{
7 WorkflowBudgetSnapshot, WorkflowJSRuntime, WorkflowModuleInput, WorkflowModuleOutput,
8 WorkflowRef, WorkflowRuntimeCall, WorkflowRuntimeExecution, WorkflowRuntimePoll,
9 WorkflowRuntimeRequest, WorkflowRuntimeRequestResolution,
10};
11use crate::metadata::{read_workflow_metadata, WorkflowMetadata};
12use anyhow::{anyhow, bail, Context};
13use serde::{Deserialize, Serialize};
14use serde_json::Value;
15use std::collections::{BTreeMap, BTreeSet, VecDeque};
16use std::fs;
17use std::path::{Path, PathBuf};
18use std::process::Command as StdCommand;
19use std::sync::Arc;
20use std::time::{Duration, Instant};
21use tokio::sync::{mpsc, watch};
22use tokio::task::{JoinSet, LocalSet};
23
24pub use crate::events::{
25 WorkflowEvent, WorkflowEventMetadata, WorkflowEventSink, WorkflowEventType,
26};
27
/// Destination for the results of agent runs started by a workflow.
///
/// The runtime calls [`AgentSessionLogSink::write_agent_result`] from inside each agent task
/// after a successful run. When a sink is configured, workflow cancellation awaits in-flight
/// agent tasks instead of aborting them, so the write performed inside each task is not cut off.
#[async_trait::async_trait]
pub trait AgentSessionLogSink: Send + Sync {
    /// Persists the result of one successful agent run.
    ///
    /// `provider` names the provider that handled the run: the request's provider override when
    /// one was given, otherwise the workflow's default provider.
    ///
    /// # Errors
    /// An error returned here is reported to the workflow script as the failure of that agent
    /// request, even though the provider itself succeeded.
    async fn write_agent_result(
        &self,
        provider: &str,
        result: &AgentProviderResult,
    ) -> anyhow::Result<()>;
}
36
/// Strategy the workflow runtime uses to execute agent requests and sleep requests.
///
/// [`DirectWorkflowAgentRunner`] is used when [`RunWorkflowOptions::agent_runner`] is `None`.
#[async_trait::async_trait]
pub trait WorkflowAgentRunner: Send + Sync {
    /// Runs a single agent request.
    ///
    /// `default_provider` is the workflow's configured provider. `provider_override` is the
    /// provider name taken from the request's options, if any.
    async fn run_agent(
        &self,
        default_provider: Arc<dyn AgentProvider>,
        provider_override: Option<String>,
        input: AgentProviderRunInput,
    ) -> anyhow::Result<AgentProviderResult>;

    /// Whether the runtime should wrap [`WorkflowAgentRunner::run_agent`] in its own retry
    /// loop (`run_agent_runner_with_retry`). When this returns `false`, `run_agent` is called
    /// exactly once per request (presumably for runners that retry on their own). Defaults
    /// to `true`.
    fn retry_in_runtime(&self) -> bool {
        true
    }

    /// Waits `duration_ms` milliseconds. It services the script's sleep requests.
    /// The default implementation uses `tokio::time::sleep` and never fails.
    async fn sleep(&self, duration_ms: u64) -> anyhow::Result<()> {
        tokio::time::sleep(std::time::Duration::from_millis(duration_ms)).await;
        Ok(())
    }
}
68
69#[derive(Debug, Default)]
70pub struct DirectWorkflowAgentRunner;
71
72#[async_trait::async_trait]
73impl WorkflowAgentRunner for DirectWorkflowAgentRunner {
74 async fn run_agent(
75 &self,
76 default_provider: Arc<dyn AgentProvider>,
77 provider_override: Option<String>,
78 input: AgentProviderRunInput,
79 ) -> anyhow::Result<AgentProviderResult> {
80 run_agent_provider(default_provider, provider_override, input).await
81 }
82}
83
/// Inputs for [`run_workflow`].
pub struct RunWorkflowOptions {
    /// Path of the workflow script. It is canonicalized before use, and its parent directory
    /// becomes the working directory passed to agent providers.
    pub script_path: PathBuf,
    /// Arguments handed to the workflow module.
    pub args: Value,
    /// Provider used for agent requests that do not name a `provider` override.
    pub agent_provider: Arc<dyn AgentProvider>,
    /// Model alias map passed to `resolve_model_options` when preparing agent runs.
    pub model_map: BTreeMap<String, String>,
    /// Total budget, if the run is budgeted.
    pub budget_total: Option<u64>,
    /// Budget already spent before this run starts.
    pub budget_spent: u64,
    /// Depth of this workflow in the nested-workflow tree. It is stamped on events as
    /// `workflow_depth` and mixed into generated step ids.
    pub nesting_depth: usize,
    /// Maximum number of concurrently running agent requests; `None` or `Some(0)` means no limit.
    pub max_parallel_agent_requests: Option<usize>,
    /// Runner for agent and sleep requests; defaults to [`DirectWorkflowAgentRunner`].
    pub agent_runner: Option<Arc<dyn WorkflowAgentRunner>>,
    /// Cancellation flag; the workflow is cancelled once it reads `true`. `None` disables
    /// cancellation.
    pub cancel_rx: Option<watch::Receiver<bool>>,
    /// Destination for workflow events. When set, `started` and `result`/`error` lifecycle
    /// events are emitted as well.
    pub event_sink: Option<Arc<dyn WorkflowEventSink>>,
    /// Step id of the parent workflow step. It is attached to events and mixed into generated
    /// step ids.
    pub event_parent_step_id: Option<String>,
    /// Reference instant for event `elapsed_nanos`; defaults to the moment the run starts.
    pub event_stream_start: Option<Instant>,
    /// Receives every successful agent result. It also makes cancellation await in-flight
    /// agents instead of aborting them.
    pub session_log_sink: Option<Arc<dyn AgentSessionLogSink>>,
}
100
/// Everything [`run_workflow`] collected for a script that ran to completion.
#[derive(Debug)]
pub struct RunWorkflowResult {
    /// Output produced by the workflow module.
    pub output: WorkflowModuleOutput,
    /// Arguments of every log call, in call order.
    pub logs: Vec<Vec<Value>>,
    /// Every phase call, in call order.
    pub phases: Vec<WorkflowPhaseCall>,
    /// Every agent request the script issued, including requests whose preparation failed.
    pub agent_calls: Vec<WorkflowRuntimeRequest>,
    /// Every nested-workflow request the script issued.
    pub workflow_calls: Vec<WorkflowRuntimeRequest>,
    /// Budget snapshot at completion.
    pub budget: WorkflowBudgetSnapshot,
    /// Aggregated token usage.
    /// NOTE(review): accumulated outside this view (presumably in `apply_agent_result`);
    /// confirm whether nested workflows contribute.
    pub token_usage: WorkflowTokenUsage,
    /// Token usage grouped by phase name (same caveat as `token_usage`).
    pub token_usage_by_phase: std::collections::BTreeMap<String, WorkflowTokenUsage>,
    /// Summaries recorded for agent runs.
    pub agent_runs: Vec<WorkflowAgentRunSummary>,
}
113
/// Token counters, plus an optional cost, aggregated over a set of agent runs.
///
/// Serialized with camelCase keys; `cost` is omitted when it is `None`. The helpers in this
/// module accumulate counters with saturating addition, treating unreported values as zero.
#[derive(Debug, Clone, Default, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct WorkflowTokenUsage {
    /// Sum of reported input tokens.
    pub input_tokens: u64,
    /// Sum of reported output tokens.
    pub output_tokens: u64,
    /// Sum of reported cache-read tokens.
    pub cache_read_tokens: u64,
    /// Sum of reported cache-write tokens.
    pub cache_write_tokens: u64,
    /// Sum of the totals reported by providers; not recomputed from the other counters.
    pub total_tokens: u64,
    /// Merged cost; stays `None` until some run reports a cost.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cost: Option<AgentUsageCost>,
}
125
/// Summary of one agent run, exposed in [`RunWorkflowResult::agent_runs`].
///
/// Serialized with camelCase keys; `None` fields are omitted.
/// NOTE(review): instances are built by `record_agent_run`, which is outside this view.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct WorkflowAgentRunSummary {
    /// Identifier of the run (presumably the runtime request id; verify in `record_agent_run`).
    pub id: String,
    /// Phase the agent call was made under, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub phase: Option<String>,
    /// Name of the provider that handled the run.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider: Option<String>,
    /// Model used for the run.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Provider-side session identifier.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider_session_id: Option<String>,
    /// Usage reported by the provider for this run.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<AgentUsage>,
    /// Isolation settings the run executed with.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub isolation: Option<AgentRunIsolation>,
}
143
/// One phase call made by the workflow script (recorded from `WorkflowRuntimeCall::Phase`).
#[derive(Debug, Clone, PartialEq)]
pub struct WorkflowPhaseCall {
    /// Phase name passed by the script.
    pub name: String,
    /// Optional phase options passed by the script, kept verbatim.
    pub options: Option<Value>,
}
149
150pub async fn run_workflow(options: RunWorkflowOptions) -> anyhow::Result<RunWorkflowResult> {
151 LocalSet::new().run_until(run_workflow_inner(options)).await
152}
153
154async fn run_workflow_inner(options: RunWorkflowOptions) -> anyhow::Result<RunWorkflowResult> {
155 log::debug!(
156 "run_workflow start script={} provider={} nesting_depth={} budget_total={:?} budget_spent={}",
157 options.script_path.display(),
158 options.agent_provider.name(),
159 options.nesting_depth,
160 options.budget_total,
161 options.budget_spent
162 );
163 let script_path = fs::canonicalize(&options.script_path).with_context(|| {
164 format!(
165 "failed to resolve workflow script {}",
166 options.script_path.display()
167 )
168 })?;
169 let metadata = read_workflow_metadata(&script_path)?.ok_or_else(|| {
170 anyhow!("Workflow script must export valid literal metadata as `export const meta = {{ name, description, ... }}`")
171 })?;
172 log::debug!(
173 "workflow metadata loaded name={} phases={}",
174 metadata.name,
175 metadata.phases.len()
176 );
177 let source = fs::read_to_string(&script_path)
178 .with_context(|| format!("failed to read workflow script {}", script_path.display()))?;
179 let runtime = RQuickJSWorkflowRuntime::new();
180 let execution = runtime.start_module(WorkflowModuleInput {
181 source,
182 source_name: script_path.to_string_lossy().into_owned(),
183 args: options.args,
184 budget: WorkflowBudgetSnapshot {
185 total: options.budget_total,
186 spent: options.budget_spent,
187 },
188 sandbox: Default::default(),
189 })?;
190
191 let (js_commands, js_command_rx) = mpsc::channel::<JsCommand>(64);
192 let (js_event_tx, mut js_events) = mpsc::channel::<JsEvent>(64);
193 let js_task = tokio::task::spawn_local(js_runtime_actor(execution, js_command_rx, js_event_tx));
194
195 let emit_lifecycle_events = options.event_sink.is_some();
196 let event_start = options.event_stream_start.unwrap_or_else(Instant::now);
197
198 let mut state = RunState {
199 script_path,
200 metadata,
201 event_start,
202 agent_provider: options.agent_provider,
203 model_map: options.model_map,
204 logs: Vec::new(),
205 phases: Vec::new(),
206 agent_calls: Vec::new(),
207 workflow_calls: Vec::new(),
208 budget: WorkflowBudgetSnapshot {
209 total: options.budget_total,
210 spent: options.budget_spent,
211 },
212 token_usage: WorkflowTokenUsage::default(),
213 token_usage_by_phase: Default::default(),
214 agent_runs: Vec::new(),
215 active_request_ids: BTreeSet::new(),
216 nesting_depth: options.nesting_depth,
217 max_parallel_agent_requests: options.max_parallel_agent_requests,
218 agent_runner: options
219 .agent_runner
220 .unwrap_or_else(|| Arc::new(DirectWorkflowAgentRunner)),
221 cancel_rx: options.cancel_rx,
222 event_sink: options.event_sink,
223 event_parent_step_id: options.event_parent_step_id,
224 session_log_sink: options.session_log_sink,
225 };
226
227 let mut pending_requests = VecDeque::<WorkflowRuntimeRequest>::new();
228 let mut agent_tasks = JoinSet::<AgentTaskCompletion>::new();
229 let mut sleep_tasks = JoinSet::<SleepTaskCompletion>::new();
230
231 if emit_lifecycle_events {
232 if let Err(error) = state
233 .emit_event(WorkflowEvent::started(rfc3339_now()?))
234 .await
235 {
236 let _ = send_js_command(&js_commands, JsCommand::Shutdown).await;
237 let _ = js_task.await;
238 return Err(error);
239 }
240 }
241
242 let workflow_result: anyhow::Result<RunWorkflowResult> = loop {
243 if let Err(error) = state
244 .start_pending_requests(
245 &mut pending_requests,
246 &mut agent_tasks,
247 &mut sleep_tasks,
248 &js_commands,
249 )
250 .await
251 {
252 break Err(error);
253 }
254
255 tokio::select! {
256 biased;
257 () = wait_for_cancellation(&mut state.cancel_rx) => {
258 break state.cancel_workflow(
259 &mut pending_requests,
260 &mut agent_tasks,
261 &mut sleep_tasks,
262 &js_commands,
263 &mut js_events,
264 ).await;
265 }
266 event = js_events.recv() => {
267 let event = match event {
268 Some(event) => event,
269 None => break Err(anyhow!("JavaScript runtime actor stopped unexpectedly")),
270 };
271 match state.handle_js_event(event, &mut pending_requests).await {
272 Ok(Some(result)) => break Ok(result),
273 Ok(None) => {}
274 Err(error) => break Err(error),
275 }
276 }
277 completion = agent_tasks.join_next(), if !agent_tasks.is_empty() => {
278 let completion = match completion {
279 Some(Ok(completion)) => completion,
280 Some(Err(error)) => break Err(anyhow!("agent provider task failed: {error}")),
281 None => break Err(anyhow!("agent task set ended unexpectedly")),
282 };
283 let AgentTaskCompletion { id, input, provider, result } = completion;
284 state.active_request_ids.remove(&id);
285 let resolution = match result {
286 Ok(result) => match state.apply_agent_result(&id, &input, provider, result).await {
287 Ok(value) => WorkflowRuntimeRequestResolution::OkWithBudget {
288 value,
289 budget: state.budget.clone(),
290 },
291 Err(error) => WorkflowRuntimeRequestResolution::Err {
292 message: error.to_string(),
293 },
294 },
295 Err(error) => {
296 let message = error.to_string();
297 if let Err(emit_error) = state.emit_agent_failed_event(&id, provider.as_deref(), &message).await {
298 log::debug!("failed to emit agent failure event: {emit_error:#}");
299 }
300 WorkflowRuntimeRequestResolution::Err { message }
301 },
302 };
303 if let Err(error) = send_js_command(&js_commands, JsCommand::ResolveRequest { id, resolution }).await {
304 break Err(error);
305 }
306 }
307 completion = sleep_tasks.join_next(), if !sleep_tasks.is_empty() => {
308 let completion = match completion {
309 Some(Ok(completion)) => completion,
310 Some(Err(error)) => break Err(anyhow!("sleep task failed: {error}")),
311 None => break Err(anyhow!("sleep task set ended unexpectedly")),
312 };
313 let SleepTaskCompletion { id, result } = completion;
314 state.active_request_ids.remove(&id);
315 let resolution = match result {
316 Ok(()) => WorkflowRuntimeRequestResolution::OkUndefined,
317 Err(error) => WorkflowRuntimeRequestResolution::Err {
318 message: error.to_string(),
319 },
320 };
321 if let Err(error) = send_js_command(&js_commands, JsCommand::ResolveRequest { id, resolution }).await {
322 break Err(error);
323 }
324 }
325 }
326 };
327
328 let _ = send_js_command(&js_commands, JsCommand::Shutdown).await;
329 let _ = js_task.await;
330
331 if emit_lifecycle_events {
332 match &workflow_result {
333 Ok(result) => {
334 state
335 .emit_event(WorkflowEvent::result(
336 result.token_usage.input_tokens,
337 result.token_usage.output_tokens,
338 result.token_usage.total_tokens,
339 result.output.result.clone(),
340 ))
341 .await?
342 }
343 Err(error) => {
344 state
345 .emit_event(WorkflowEvent::error(error.to_string(), None))
346 .await?;
347 }
348 }
349 }
350
351 workflow_result
352}
353
/// Messages sent from the host loop to [`js_runtime_actor`].
enum JsCommand {
    /// Delivers the outcome of a runtime request previously announced with
    /// [`JsEvent::Request`]; the actor passes it to `resolve_request`.
    ResolveRequest {
        id: String,
        resolution: WorkflowRuntimeRequestResolution,
    },
    /// Asks the actor to exit.
    Shutdown,
}
361
/// Messages sent from [`js_runtime_actor`] to the host loop.
enum JsEvent {
    /// A host call (log or phase) made by the script; it needs no resolution.
    Call(WorkflowRuntimeCall),
    /// A request (agent, sleep or nested workflow) that the host must resolve with
    /// [`JsCommand::ResolveRequest`].
    Request(WorkflowRuntimeRequest),
    /// The module finished with this output. The actor exits after sending it.
    Complete(WorkflowModuleOutput),
    /// The runtime failed with this message. The actor exits after sending it.
    Error(String),
}
368
369async fn js_runtime_actor(
370 mut execution: Box<dyn WorkflowRuntimeExecution>,
371 mut commands: mpsc::Receiver<JsCommand>,
372 events: mpsc::Sender<JsEvent>,
373) {
374 let mut outstanding_requests = 0usize;
375 loop {
376 match execution.poll() {
377 Ok(WorkflowRuntimePoll::Call(call)) => {
378 if events.send(JsEvent::Call(call)).await.is_err() {
379 return;
380 }
381 }
382 Ok(WorkflowRuntimePoll::Request(request)) => {
383 let requests = match execution.take_pending_requests() {
384 Ok(requests) if requests.is_empty() => vec![request],
385 Ok(requests) => requests,
386 Err(error) => {
387 let _ = events.send(JsEvent::Error(error.to_string())).await;
388 return;
389 }
390 };
391 outstanding_requests = outstanding_requests.saturating_add(requests.len());
392 for request in requests {
393 if events.send(JsEvent::Request(request)).await.is_err() {
394 return;
395 }
396 }
397 }
398 Ok(WorkflowRuntimePoll::Complete(output)) => {
399 let _ = events.send(JsEvent::Complete(output)).await;
400 return;
401 }
402 Ok(WorkflowRuntimePoll::Pending) => {
403 if outstanding_requests == 0 {
404 tokio::time::sleep(std::time::Duration::from_millis(1)).await;
405 continue;
406 }
407 match commands.recv().await {
408 Some(JsCommand::ResolveRequest { id, resolution }) => {
409 outstanding_requests = outstanding_requests.saturating_sub(1);
410 if let Err(error) = execution.resolve_request(&id, resolution) {
411 let _ = events.send(JsEvent::Error(error.to_string())).await;
412 return;
413 }
414 }
415 Some(JsCommand::Shutdown) | None => return,
416 }
417 }
418 Err(error) => {
419 let _ = events.send(JsEvent::Error(error.to_string())).await;
420 return;
421 }
422 }
423 }
424}
425
426async fn send_js_command(
427 commands: &mpsc::Sender<JsCommand>,
428 command: JsCommand,
429) -> anyhow::Result<()> {
430 commands
431 .send(command)
432 .await
433 .map_err(|_| anyhow!("JavaScript runtime actor stopped unexpectedly"))
434}
435
/// Mutable host-side state of one workflow run, owned by `run_workflow_inner`.
struct RunState {
    /// Canonicalized script path; its parent directory is the agents' working directory.
    script_path: PathBuf,
    /// Literal `meta` exported by the script; supplies phase defaults for agent options.
    metadata: WorkflowMetadata,
    /// Reference instant for event `elapsed_nanos`.
    event_start: Instant,
    /// Provider for agent requests without a `provider` override.
    agent_provider: Arc<dyn AgentProvider>,
    /// Model alias map passed to `resolve_model_options`.
    model_map: BTreeMap<String, String>,
    /// Arguments of every log call, in order.
    logs: Vec<Vec<Value>>,
    /// Every phase call, in order.
    phases: Vec<WorkflowPhaseCall>,
    /// Every agent request received, recorded before preparation.
    agent_calls: Vec<WorkflowRuntimeRequest>,
    /// Every nested-workflow request received.
    workflow_calls: Vec<WorkflowRuntimeRequest>,
    /// Current budget; sent back to the script with each `OkWithBudget` resolution.
    budget: WorkflowBudgetSnapshot,
    /// Aggregate token usage. NOTE(review): updated outside this view (presumably
    /// `apply_agent_result`).
    token_usage: WorkflowTokenUsage,
    /// Token usage per phase name (same caveat as `token_usage`).
    token_usage_by_phase: std::collections::BTreeMap<String, WorkflowTokenUsage>,
    /// Summaries added by `record_agent_run`.
    agent_runs: Vec<WorkflowAgentRunSummary>,
    /// Ids of agent and sleep requests whose tasks were spawned but not yet resolved.
    active_request_ids: BTreeSet<String>,
    /// Depth of this workflow in the nested-workflow tree; stamped on events.
    nesting_depth: usize,
    /// Concurrent agent task limit; `None` or `Some(0)` means unlimited.
    max_parallel_agent_requests: Option<usize>,
    /// Runner used for agent and sleep requests.
    agent_runner: Arc<dyn WorkflowAgentRunner>,
    /// Cancellation flag receiver; `None` means the run can never be cancelled.
    cancel_rx: Option<watch::Receiver<bool>>,
    /// Event destination; when `None`, events are built but not delivered anywhere.
    event_sink: Option<Arc<dyn WorkflowEventSink>>,
    /// Parent step id attached to events and mixed into generated step ids.
    event_parent_step_id: Option<String>,
    /// Sink for successful agent results. Its presence also makes cancellation await
    /// in-flight agents instead of aborting them.
    session_log_sink: Option<Arc<dyn AgentSessionLogSink>>,
}
459
/// An agent request after option resolution, ready to be spawned.
struct PreparedAgentRun {
    /// Provider requested through the request's options; `None` selects the default provider.
    provider_override: Option<String>,
    /// Input handed to the agent runner.
    input: AgentProviderRunInput,
}
464
/// Value returned by a spawned agent task.
struct AgentTaskCompletion {
    /// Runtime request id of the agent call.
    id: String,
    /// Copy of the input the run was started with.
    input: AgentProviderRunInput,
    /// Provider that handled the run; `spawn_agent_task` always sets it.
    provider: Option<String>,
    /// Run outcome. A session-log write failure also surfaces here as an error.
    result: anyhow::Result<AgentProviderResult>,
}
471
/// Value returned by a spawned sleep task.
struct SleepTaskCompletion {
    /// Runtime request id of the sleep call.
    id: String,
    /// Outcome of `WorkflowAgentRunner::sleep`.
    result: anyhow::Result<()>,
}
476
477fn add_usage(total: &mut WorkflowTokenUsage, usage: Option<&AgentUsage>) {
478 let Some(usage) = usage else {
479 return;
480 };
481
482 total.input_tokens = total
483 .input_tokens
484 .saturating_add(usage.input_tokens.unwrap_or_default());
485 total.output_tokens = total
486 .output_tokens
487 .saturating_add(usage.output_tokens.unwrap_or_default());
488 total.cache_read_tokens = total
489 .cache_read_tokens
490 .saturating_add(usage.cache_read_tokens.unwrap_or_default());
491 total.cache_write_tokens = total
492 .cache_write_tokens
493 .saturating_add(usage.cache_write_tokens.unwrap_or_default());
494 total.total_tokens = total
495 .total_tokens
496 .saturating_add(usage.total_tokens.unwrap_or_default());
497
498 if let Some(cost) = usage.cost.as_ref() {
499 total.cost = Some(merge_cost(total.cost.as_ref(), cost));
500 }
501}
502
503fn merge_token_usage(total: &mut WorkflowTokenUsage, usage: &WorkflowTokenUsage) {
504 total.input_tokens = total.input_tokens.saturating_add(usage.input_tokens);
505 total.output_tokens = total.output_tokens.saturating_add(usage.output_tokens);
506 total.cache_read_tokens = total
507 .cache_read_tokens
508 .saturating_add(usage.cache_read_tokens);
509 total.cache_write_tokens = total
510 .cache_write_tokens
511 .saturating_add(usage.cache_write_tokens);
512 total.total_tokens = total.total_tokens.saturating_add(usage.total_tokens);
513 if let Some(cost) = usage.cost.as_ref() {
514 total.cost = Some(merge_cost(total.cost.as_ref(), cost));
515 }
516}
517
518fn merge_cost(left: Option<&AgentUsageCost>, right: &AgentUsageCost) -> AgentUsageCost {
519 AgentUsageCost {
520 input: sum_f64(left.and_then(|cost| cost.input), right.input),
521 output: sum_f64(left.and_then(|cost| cost.output), right.output),
522 cache_read: sum_f64(left.and_then(|cost| cost.cache_read), right.cache_read),
523 cache_write: sum_f64(left.and_then(|cost| cost.cache_write), right.cache_write),
524 total: sum_f64(left.and_then(|cost| cost.total), right.total),
525 currency: right
526 .currency
527 .clone()
528 .or_else(|| left.and_then(|cost| cost.currency.clone())),
529 }
530}
531
/// Nanoseconds elapsed since `start`, clamped to `u64::MAX`.
fn elapsed_nanos(start: Instant) -> u64 {
    let nanos = start.elapsed().as_nanos();
    // After clamping to the u64 range the cast cannot truncate.
    nanos.min(u128::from(u64::MAX)) as u64
}
535
536fn rfc3339_now() -> anyhow::Result<String> {
537 Ok(time::OffsetDateTime::now_utc().format(&time::format_description::well_known::Rfc3339)?)
538}
539
540fn raw_agent_event_payloads(raw: &Value) -> Vec<Value> {
541 if let Some(events) = raw.get("events").and_then(Value::as_array) {
542 events.clone()
543 } else if let Some(items) = raw.as_array() {
544 items.clone()
545 } else {
546 vec![raw.clone()]
547 }
548}
549
550fn agent_session_event_payload(provider_event: Value, metadata: &WorkflowEventMetadata) -> Value {
551 let mut payload = serde_json::Map::new();
552 if let Some(provider) = metadata.provider.as_ref() {
553 payload.insert("provider".to_string(), Value::String(provider.clone()));
554 }
555 if let Some(session_id) = metadata.session_id.as_ref() {
556 payload.insert("sessionId".to_string(), Value::String(session_id.clone()));
557 }
558 if let Some(run_id) = metadata.run_id.as_ref() {
559 payload.insert("runId".to_string(), Value::String(run_id.clone()));
560 }
561 if let Some(step_id) = metadata.step_id.as_ref() {
562 payload.insert("stepId".to_string(), Value::String(step_id.clone()));
563 }
564 payload.insert("providerEvent".to_string(), provider_event);
565 Value::Object(payload)
566}
567
/// Keeps at most `max_chars` characters (Unicode scalar values) of `value`. If anything was
/// cut off, appends `…`.
fn truncate_for_event(value: &str, max_chars: usize) -> String {
    // The character at index `max_chars`, if any, is the first one that does not fit.
    match value.char_indices().nth(max_chars) {
        Some((cut_at, _)) => format!("{}…", &value[..cut_at]),
        None => value.to_string(),
    }
}
577
578fn format_log_message(values: &[Value]) -> String {
579 values
580 .iter()
581 .map(|value| match value {
582 Value::String(value) => value.clone(),
583 value => serde_json::to_string(value).unwrap_or_else(|_| String::from("<unprintable>")),
584 })
585 .collect::<Vec<_>>()
586 .join(" ")
587}
588
/// Adds two optional amounts, treating a missing side as zero.
/// Returns `None` only when both sides are missing.
fn sum_f64(left: Option<f64>, right: Option<f64>) -> Option<f64> {
    left.or(right)
        .map(|_| left.unwrap_or_default() + right.unwrap_or_default())
}
595
596async fn wait_for_cancellation(cancel_rx: &mut Option<watch::Receiver<bool>>) {
597 let Some(cancel_rx) = cancel_rx else {
598 std::future::pending::<()>().await;
599 return;
600 };
601 while !*cancel_rx.borrow() {
602 if cancel_rx.changed().await.is_err() {
603 return;
604 }
605 }
606}
607
608impl RunState {
609 async fn handle_js_event(
610 &mut self,
611 event: JsEvent,
612 pending_requests: &mut VecDeque<WorkflowRuntimeRequest>,
613 ) -> anyhow::Result<Option<RunWorkflowResult>> {
614 match event {
615 JsEvent::Call(call) => self.handle_call(call).await?,
616 JsEvent::Request(request) => {
617 log::debug!(
618 "workflow runtime request id={} kind={}",
619 request.id(),
620 request.kind()
621 );
622 pending_requests.push_back(request);
623 }
624 JsEvent::Complete(output) => {
625 log::debug!(
626 "run_workflow complete script={} budget_spent={}",
627 self.script_path.display(),
628 self.budget.spent
629 );
630 return Ok(Some(RunWorkflowResult {
631 output,
632 logs: std::mem::take(&mut self.logs),
633 phases: std::mem::take(&mut self.phases),
634 agent_calls: std::mem::take(&mut self.agent_calls),
635 workflow_calls: std::mem::take(&mut self.workflow_calls),
636 budget: self.budget.clone(),
637 token_usage: std::mem::take(&mut self.token_usage),
638 token_usage_by_phase: std::mem::take(&mut self.token_usage_by_phase),
639 agent_runs: std::mem::take(&mut self.agent_runs),
640 }));
641 }
642 JsEvent::Error(message) => bail!(message),
643 }
644 Ok(None)
645 }
646
    /// Starts queued runtime requests in FIFO order. It stops when the queue is empty or when
    /// the request at the head is an agent request that would exceed
    /// `max_parallel_agent_requests`.
    ///
    /// * Agent requests are recorded, prepared and spawned onto `agent_tasks`. A request whose
    ///   preparation fails is resolved with that error right away.
    /// * Sleep requests are spawned onto `sleep_tasks`.
    /// * Nested workflow requests are recorded and run to completion inline (awaited here).
    ///   Their resolution is sent before the next queued request is considered.
    ///
    /// # Errors
    /// Fails if the agent-started event cannot be emitted or the runtime actor has stopped.
    async fn start_pending_requests(
        &mut self,
        pending_requests: &mut VecDeque<WorkflowRuntimeRequest>,
        agent_tasks: &mut JoinSet<AgentTaskCompletion>,
        sleep_tasks: &mut JoinSet<SleepTaskCompletion>,
        js_commands: &mpsc::Sender<JsCommand>,
    ) -> anyhow::Result<()> {
        loop {
            let Some(request) = pending_requests.front() else {
                return Ok(());
            };
            // Peek before popping: an agent request that does not fit under the parallelism
            // limit stays at the head, and everything queued behind it waits too, so requests
            // are always started in the order the script issued them.
            if matches!(request, WorkflowRuntimeRequest::Agent { .. })
                && !self.agent_capacity_available(agent_tasks.len())
            {
                return Ok(());
            }

            let request = pending_requests
                .pop_front()
                .expect("pending request should exist");
            match request {
                WorkflowRuntimeRequest::Agent { .. } => match self.prepare_agent_request(request) {
                    Ok((id, prepared)) => {
                        self.emit_agent_started_event(&id, &prepared).await?;
                        self.spawn_agent_task(agent_tasks, id, prepared);
                    }
                    // Preparation failed (e.g. model resolution or retry-policy validation),
                    // so no task is spawned; the script sees the error directly.
                    Err((id, error)) => {
                        send_js_command(
                            js_commands,
                            JsCommand::ResolveRequest {
                                id,
                                resolution: WorkflowRuntimeRequestResolution::Err {
                                    message: error.to_string(),
                                },
                            },
                        )
                        .await?;
                    }
                },
                WorkflowRuntimeRequest::Sleep { id, duration_ms } => {
                    self.spawn_sleep_task(sleep_tasks, id, duration_ms);
                }
                WorkflowRuntimeRequest::Workflow {
                    id,
                    workflow_ref,
                    args,
                } => {
                    self.workflow_calls.push(WorkflowRuntimeRequest::Workflow {
                        id: id.clone(),
                        workflow_ref: workflow_ref.clone(),
                        args: args.clone(),
                    });
                    // Nested workflows run inline: the host loop handles no other events
                    // (including cancellation) until this await completes.
                    let parent_event_step_id = self.event_step_id(&id);
                    let resolution = match self
                        .handle_workflow(parent_event_step_id, workflow_ref, args)
                        .await
                    {
                        Ok(value) => WorkflowRuntimeRequestResolution::OkWithBudget {
                            value,
                            budget: self.budget.clone(),
                        },
                        Err(error) => WorkflowRuntimeRequestResolution::Err {
                            message: error.to_string(),
                        },
                    };
                    send_js_command(js_commands, JsCommand::ResolveRequest { id, resolution })
                        .await?;
                }
            }
        }
    }
718
719 async fn cancel_workflow(
720 &mut self,
721 pending_requests: &mut VecDeque<WorkflowRuntimeRequest>,
722 agent_tasks: &mut JoinSet<AgentTaskCompletion>,
723 sleep_tasks: &mut JoinSet<SleepTaskCompletion>,
724 js_commands: &mpsc::Sender<JsCommand>,
725 js_events: &mut mpsc::Receiver<JsEvent>,
726 ) -> anyhow::Result<RunWorkflowResult> {
727 log::debug!(
728 "workflow cancellation requested script={}",
729 self.script_path.display()
730 );
731
732 if pending_requests.is_empty()
733 && self.active_request_ids.is_empty()
734 && agent_tasks.is_empty()
735 && sleep_tasks.is_empty()
736 && self
737 .reject_next_runtime_request_for_cancellation(js_commands, js_events)
738 .await
739 {
740 bail!("workflow cancelled");
741 }
742
743 self.reject_pending_requests_for_cancellation(pending_requests, js_commands)
744 .await;
745 sleep_tasks.abort_all();
746 self.reject_active_sleep_requests_for_cancellation(sleep_tasks, js_commands)
747 .await;
748
749 if self.session_log_sink.is_some() {
750 while let Some(completion) = agent_tasks.join_next().await {
751 match completion {
752 Ok(AgentTaskCompletion {
753 id,
754 input,
755 provider,
756 result: Ok(result),
757 }) => {
758 self.active_request_ids.remove(&id);
759 if let Err(error) = self
760 .emit_agent_result_events(&id, provider.as_deref(), &result)
761 .await
762 {
763 log::debug!("failed to emit drained agent events during cancellation: {error:#}");
764 }
765 if let Err(error) = self
766 .emit_agent_completed_event(&id, provider.as_deref(), &result)
767 .await
768 {
769 log::debug!("failed to emit drained agent completion during cancellation: {error:#}");
770 }
771 self.record_agent_run(&id, &input, provider, &result);
772 self.reject_request_for_cancellation(id, js_commands).await;
773 }
774 Ok(AgentTaskCompletion {
775 id,
776 provider,
777 result: Err(error),
778 ..
779 }) => {
780 self.active_request_ids.remove(&id);
781 let message = error.to_string();
782 if let Err(error) = self
783 .emit_agent_failed_event(&id, provider.as_deref(), &message)
784 .await
785 {
786 log::debug!("failed to emit drained agent failure during cancellation: {error:#}");
787 }
788 log::debug!("agent task failed while draining cancellation: {message}");
789 self.reject_request_for_cancellation(id, js_commands).await;
790 }
791 Err(error) => {
792 log::debug!("agent task join failed while draining cancellation: {error}");
793 }
794 }
795 }
796 } else {
797 let ids: Vec<String> = self.active_request_ids.iter().cloned().collect();
798 agent_tasks.abort_all();
799 for id in ids {
800 self.active_request_ids.remove(&id);
801 self.reject_request_for_cancellation(id, js_commands).await;
802 }
803 }
804
805 self.reject_remaining_active_requests_for_cancellation(js_commands)
806 .await;
807 self.drain_runtime_after_cancellation(js_events).await;
808 let _ = send_js_command(js_commands, JsCommand::Shutdown).await;
809 bail!("workflow cancelled")
810 }
811
812 async fn reject_next_runtime_request_for_cancellation(
813 &mut self,
814 js_commands: &mpsc::Sender<JsCommand>,
815 js_events: &mut mpsc::Receiver<JsEvent>,
816 ) -> bool {
817 loop {
818 match js_events.recv().await {
819 Some(JsEvent::Call(call)) => {
820 let _ = self.handle_call(call).await;
821 }
822 Some(JsEvent::Request(request)) => {
823 self.reject_request_for_cancellation(request.id().to_string(), js_commands)
824 .await;
825 return false;
826 }
827 Some(JsEvent::Complete(_)) | Some(JsEvent::Error(_)) | None => return true,
828 }
829 }
830 }
831
832 async fn reject_pending_requests_for_cancellation(
833 &mut self,
834 pending_requests: &mut VecDeque<WorkflowRuntimeRequest>,
835 js_commands: &mpsc::Sender<JsCommand>,
836 ) {
837 while let Some(request) = pending_requests.pop_front() {
838 self.reject_request_for_cancellation(request.id().to_string(), js_commands)
839 .await;
840 }
841 }
842
843 async fn reject_active_sleep_requests_for_cancellation(
844 &mut self,
845 sleep_tasks: &mut JoinSet<SleepTaskCompletion>,
846 js_commands: &mpsc::Sender<JsCommand>,
847 ) {
848 while let Some(completion) = sleep_tasks.join_next().await {
849 if let Ok(SleepTaskCompletion { id, .. }) = completion {
850 self.active_request_ids.remove(&id);
851 self.reject_request_for_cancellation(id, js_commands).await;
852 }
853 }
854 }
855
856 async fn reject_remaining_active_requests_for_cancellation(
857 &mut self,
858 js_commands: &mpsc::Sender<JsCommand>,
859 ) {
860 let ids: Vec<String> = self.active_request_ids.iter().cloned().collect();
861 for id in ids {
862 self.active_request_ids.remove(&id);
863 self.reject_request_for_cancellation(id, js_commands).await;
864 }
865 }
866
867 async fn reject_request_for_cancellation(
868 &self,
869 id: String,
870 js_commands: &mpsc::Sender<JsCommand>,
871 ) {
872 let _ = send_js_command(
873 js_commands,
874 JsCommand::ResolveRequest {
875 id,
876 resolution: WorkflowRuntimeRequestResolution::Err {
877 message: "workflow cancelled".to_string(),
878 },
879 },
880 )
881 .await;
882 }
883
884 async fn drain_runtime_after_cancellation(&mut self, js_events: &mut mpsc::Receiver<JsEvent>) {
885 while let Some(event) = js_events.recv().await {
886 match event {
887 JsEvent::Call(call) => {
888 let _ = self.handle_call(call).await;
889 }
890 JsEvent::Request(request) => {
891 log::debug!(
892 "ignoring request after cancellation id={} kind={}",
893 request.id(),
894 request.kind()
895 );
896 }
897 JsEvent::Complete(_) | JsEvent::Error(_) => break,
898 }
899 }
900 }
901
902 fn event_step_id(&self, runtime_request_id: &str) -> String {
903 let parent = self.event_parent_step_id.as_deref().unwrap_or("");
904 let hash = blake3::hash(
905 format!("{parent}:{}:{runtime_request_id}", self.nesting_depth).as_bytes(),
906 );
907 format!("step_{}", &hash.to_hex()[..16])
908 }
909
910 async fn emit_event(&self, mut event: WorkflowEvent) -> anyhow::Result<()> {
911 if (event.event_type.as_str() != "workflow.started" || self.nesting_depth > 0)
912 && event.elapsed_nanos.is_none()
913 {
914 event.elapsed_nanos = Some(elapsed_nanos(self.event_start));
915 }
916 let metadata = event
917 .metadata
918 .get_or_insert_with(WorkflowEventMetadata::default);
919 if metadata.workflow_depth.is_none() {
920 metadata.workflow_depth = Some(u32::try_from(self.nesting_depth).unwrap_or(u32::MAX));
921 }
922 if metadata.parent_step_id.is_none() {
923 metadata.parent_step_id = self.event_parent_step_id.clone();
924 }
925 if let Some(event_sink) = self.event_sink.as_ref() {
926 event_sink.emit(event).await?;
927 }
928 Ok(())
929 }
930
931 async fn handle_call(&mut self, call: WorkflowRuntimeCall) -> anyhow::Result<()> {
932 match call {
933 WorkflowRuntimeCall::Log { values } => {
934 self.emit_event(WorkflowEvent::log(format_log_message(&values)))
935 .await?;
936 self.logs.push(values);
937 }
938 WorkflowRuntimeCall::Phase { name, options } => {
939 let phase = WorkflowPhaseCall { name, options };
940 self.emit_event(WorkflowEvent::phase(
941 phase.name.clone(),
942 phase.options.clone(),
943 ))
944 .await?;
945 self.phases.push(phase);
946 }
947 }
948 Ok(())
949 }
950
951 fn agent_capacity_available(&self, in_flight: usize) -> bool {
952 let max_parallel = self
953 .max_parallel_agent_requests
954 .filter(|value| *value > 0)
955 .unwrap_or(usize::MAX);
956 in_flight < max_parallel
957 }
958
959 fn prepare_agent_request(
960 &mut self,
961 request: WorkflowRuntimeRequest,
962 ) -> Result<(String, PreparedAgentRun), (String, anyhow::Error)> {
963 match request {
964 WorkflowRuntimeRequest::Agent {
965 id,
966 prompt,
967 options,
968 } => {
969 self.agent_calls.push(WorkflowRuntimeRequest::Agent {
970 id: id.clone(),
971 prompt: prompt.clone(),
972 options: options.clone(),
973 });
974 match self.prepare_agent_run(prompt, options) {
975 Ok(prepared) => Ok((id, prepared)),
976 Err(error) => Err((id, error)),
977 }
978 }
979 WorkflowRuntimeRequest::Workflow { .. } | WorkflowRuntimeRequest::Sleep { .. } => {
980 unreachable!("prepare_agent_request only accepts agent requests")
981 }
982 }
983 }
984
985 fn spawn_agent_task(
986 &mut self,
987 agent_tasks: &mut JoinSet<AgentTaskCompletion>,
988 id: String,
989 prepared: PreparedAgentRun,
990 ) {
991 let default_provider_name = self.agent_provider.name().to_string();
992 let default_provider = Arc::clone(&self.agent_provider);
993 let agent_runner = Arc::clone(&self.agent_runner);
994 let retry_in_runtime = agent_runner.retry_in_runtime();
995 let cancel_rx = self.cancel_rx.clone();
996 let completion_input = prepared.input.clone();
997 let completion_provider = prepared
998 .provider_override
999 .clone()
1000 .or(Some(default_provider_name));
1001 let session_log_sink = self.session_log_sink.clone();
1002 let max_parallel = self
1003 .max_parallel_agent_requests
1004 .filter(|value| *value > 0)
1005 .unwrap_or(usize::MAX);
1006 log::debug!(
1007 "starting agent request id={} in_flight_after_start={} max_parallel={}",
1008 id,
1009 agent_tasks.len() + 1,
1010 max_parallel
1011 );
1012 self.active_request_ids.insert(id.clone());
1013 agent_tasks.spawn(async move {
1014 let result = if retry_in_runtime {
1015 run_agent_runner_with_retry(
1016 Arc::clone(&agent_runner),
1017 default_provider,
1018 prepared.provider_override,
1019 prepared.input,
1020 cancel_rx,
1021 )
1022 .await
1023 } else {
1024 agent_runner
1025 .run_agent(default_provider, prepared.provider_override, prepared.input)
1026 .await
1027 };
1028 let result = match result {
1029 Ok(result) => {
1030 if let Some(session_log_sink) = session_log_sink.as_ref() {
1031 let provider_name = completion_provider
1032 .as_deref()
1033 .expect("completion provider should always be set");
1034 match session_log_sink
1035 .write_agent_result(provider_name, &result)
1036 .await
1037 {
1038 Ok(()) => Ok(result),
1039 Err(error) => Err(error),
1040 }
1041 } else {
1042 Ok(result)
1043 }
1044 }
1045 Err(error) => Err(error),
1046 };
1047 AgentTaskCompletion {
1048 id,
1049 input: completion_input,
1050 provider: completion_provider,
1051 result,
1052 }
1053 });
1054 }
1055
1056 fn spawn_sleep_task(
1057 &mut self,
1058 sleep_tasks: &mut JoinSet<SleepTaskCompletion>,
1059 id: String,
1060 duration_ms: u64,
1061 ) {
1062 let agent_runner = Arc::clone(&self.agent_runner);
1063 log::debug!(
1064 "starting sleep request id={} duration_ms={}",
1065 id,
1066 duration_ms
1067 );
1068 self.active_request_ids.insert(id.clone());
1069 sleep_tasks.spawn(async move {
1070 SleepTaskCompletion {
1071 id,
1072 result: agent_runner.sleep(duration_ms).await,
1073 }
1074 });
1075 }
1076
1077 fn prepare_agent_run(
1078 &self,
1079 prompt: String,
1080 options: Option<Value>,
1081 ) -> anyhow::Result<PreparedAgentRun> {
1082 let options = apply_phase_defaults(options, &self.metadata);
1083 let context = AgentProviderContext {
1084 phase: options
1085 .as_ref()
1086 .and_then(|options| options.get("phase"))
1087 .and_then(Value::as_str)
1088 .map(ToString::to_string),
1089 cwd: self.script_path.parent().map(Path::to_path_buf),
1090 };
1091 let provider_override = options
1092 .as_ref()
1093 .and_then(|options| options.get("provider"))
1094 .and_then(Value::as_str)
1095 .map(ToString::to_string);
1096 let provider_name = provider_override
1097 .as_deref()
1098 .unwrap_or_else(|| self.agent_provider.name());
1099 let options = resolve_model_options(options, provider_name, &self.model_map)?;
1100 agent_retry_policy(&options)?;
1101 log::debug!(
1102 "agent call provider={} phase={:?} model={:?} prompt_len={}",
1103 provider_name,
1104 context.phase.as_deref(),
1105 options
1106 .as_ref()
1107 .and_then(|options| options.get("model"))
1108 .and_then(Value::as_str),
1109 prompt.len()
1110 );
1111 Ok(PreparedAgentRun {
1112 provider_override,
1113 input: AgentProviderRunInput {
1114 prompt,
1115 options,
1116 context,
1117 },
1118 })
1119 }
1120
1121 async fn emit_agent_started_event(
1122 &self,
1123 id: &str,
1124 prepared: &PreparedAgentRun,
1125 ) -> anyhow::Result<()> {
1126 let provider = prepared
1127 .provider_override
1128 .as_deref()
1129 .unwrap_or_else(|| self.agent_provider.name());
1130 let metadata = self.agent_event_metadata(id, Some(provider), None);
1131 self.emit_event(WorkflowEvent::agent_started(
1132 serde_json::json!({
1133 "phase": prepared.input.context.phase,
1134 "promptPreview": truncate_for_event(&prepared.input.prompt, 200),
1135 }),
1136 metadata,
1137 ))
1138 .await
1139 }
1140
1141 async fn apply_agent_result(
1142 &mut self,
1143 id: &str,
1144 input: &AgentProviderRunInput,
1145 provider: Option<String>,
1146 result: AgentProviderResult,
1147 ) -> anyhow::Result<Value> {
1148 if let Some(output_tokens) = result.usage.as_ref().and_then(|usage| usage.output_tokens) {
1149 self.budget.spent = self.budget.spent.saturating_add(output_tokens);
1150 }
1151 self.emit_agent_result_events(id, provider.as_deref(), &result)
1152 .await?;
1153 self.emit_agent_completed_event(id, provider.as_deref(), &result)
1154 .await?;
1155 self.record_agent_run(id, input, provider, &result);
1156 log::debug!(
1157 "agent call complete session_id={:?} output_tokens={:?} budget_spent={}",
1158 result.session_id,
1159 result.usage.as_ref().and_then(|usage| usage.output_tokens),
1160 self.budget.spent
1161 );
1162 Ok(result.output)
1163 }
1164
1165 async fn emit_agent_result_events(
1166 &self,
1167 id: &str,
1168 provider: Option<&str>,
1169 result: &AgentProviderResult,
1170 ) -> anyhow::Result<()> {
1171 let Some(raw) = result.raw.as_ref() else {
1172 return Ok(());
1173 };
1174 let metadata = self.agent_event_metadata(id, provider, result.session_id.clone());
1175 for provider_event in raw_agent_event_payloads(raw) {
1176 let event_data = agent_session_event_payload(provider_event, &metadata);
1177 self.emit_event(WorkflowEvent::agent_event(event_data, metadata.clone()))
1178 .await?;
1179 }
1180 Ok(())
1181 }
1182
1183 async fn emit_agent_completed_event(
1184 &self,
1185 id: &str,
1186 provider: Option<&str>,
1187 result: &AgentProviderResult,
1188 ) -> anyhow::Result<()> {
1189 let metadata = self.agent_event_metadata(id, provider, result.session_id.clone());
1190 self.emit_event(WorkflowEvent::agent_completed(
1191 serde_json::json!({
1192 "sessionId": result.session_id,
1193 "model": result.model,
1194 "usage": result.usage,
1195 }),
1196 metadata,
1197 ))
1198 .await
1199 }
1200
1201 async fn emit_agent_failed_event(
1202 &self,
1203 id: &str,
1204 provider: Option<&str>,
1205 message: &str,
1206 ) -> anyhow::Result<()> {
1207 let metadata = self.agent_event_metadata(id, provider, None);
1208 self.emit_event(WorkflowEvent::agent_failed(
1209 serde_json::json!({ "message": message }),
1210 metadata,
1211 ))
1212 .await
1213 }
1214
1215 fn agent_event_metadata(
1216 &self,
1217 id: &str,
1218 provider: Option<&str>,
1219 session_id: Option<String>,
1220 ) -> WorkflowEventMetadata {
1221 WorkflowEventMetadata {
1222 run_id: None,
1223 step_id: Some(self.event_step_id(id)),
1224 provider: Some(
1225 provider
1226 .unwrap_or_else(|| self.agent_provider.name())
1227 .to_string(),
1228 ),
1229 session_id,
1230 workflow_depth: None,
1231 parent_step_id: None,
1232 }
1233 }
1234
1235 fn record_agent_run(
1236 &mut self,
1237 id: &str,
1238 input: &AgentProviderRunInput,
1239 provider: Option<String>,
1240 result: &AgentProviderResult,
1241 ) {
1242 add_usage(&mut self.token_usage, result.usage.as_ref());
1243 if let Some(phase) = input.context.phase.as_ref() {
1244 let phase_usage = self.token_usage_by_phase.entry(phase.clone()).or_default();
1245 add_usage(phase_usage, result.usage.as_ref());
1246 }
1247 let model = result.model.clone().or_else(|| {
1248 input
1249 .options
1250 .as_ref()
1251 .and_then(|options| options.get("model"))
1252 .and_then(Value::as_str)
1253 .map(ToString::to_string)
1254 });
1255 self.agent_runs.push(WorkflowAgentRunSummary {
1256 id: id.to_string(),
1257 phase: input.context.phase.clone(),
1258 provider,
1259 model,
1260 provider_session_id: result.session_id.clone(),
1261 usage: result.usage.clone(),
1262 isolation: result.isolation.clone(),
1263 });
1264 }
1265
    /// Runs a nested `workflow()` call and merges its results into this run.
    ///
    /// Only one level of nesting is allowed. The child script is resolved either relative
    /// to this workflow's script (`WorkflowRef::ScriptPath`) or by metadata name
    /// (`WorkflowRef::Name`). The child runs with this run's agent provider, model map,
    /// budget (total and amount spent so far), parallelism limit, agent runner,
    /// cancellation channel, and event/session-log sinks.
    ///
    /// On success, the child's budget replaces this run's budget. Its logs, phases,
    /// calls, token usage and agent runs are appended or merged, and the child's result
    /// value is returned.
    ///
    /// # Errors
    /// Fails when already running as a nested workflow, when a named workflow cannot be
    /// resolved, or when the child workflow run fails.
    async fn handle_workflow(
        &mut self,
        parent_step_id: String,
        workflow_ref: WorkflowRef,
        args: Option<Value>,
    ) -> anyhow::Result<Value> {
        if self.nesting_depth >= 1 {
            bail!("Nested workflow() calls are limited to one level");
        }
        let script_path = match workflow_ref {
            WorkflowRef::ScriptPath { script_path } => {
                resolve_relative_script(&self.script_path, &script_path)
            }
            WorkflowRef::Name(name) => resolve_named_workflow(&name)?,
        };
        log::debug!("child workflow call script={}", script_path.display());
        // Boxed so the async call has a fixed-size future; presumably `run_workflow_inner`
        // can reach this method again (recursive async) — NOTE(review): confirm.
        let child = Box::pin(run_workflow_inner(RunWorkflowOptions {
            script_path,
            args: args.unwrap_or(Value::Null),
            agent_provider: Arc::clone(&self.agent_provider),
            model_map: self.model_map.clone(),
            budget_total: self.budget.total,
            budget_spent: self.budget.spent,
            nesting_depth: self.nesting_depth + 1,
            max_parallel_agent_requests: self.max_parallel_agent_requests,
            agent_runner: Some(Arc::clone(&self.agent_runner)),
            cancel_rx: self.cancel_rx.clone(),
            event_sink: self.event_sink.clone(),
            event_parent_step_id: Some(parent_step_id),
            event_stream_start: Some(self.event_start),
            session_log_sink: self.session_log_sink.clone(),
        }))
        .await?;
        // The child started from this run's spent amount, so its budget supersedes ours.
        self.budget = child.budget;
        self.logs.extend(child.logs);
        self.phases.extend(child.phases);
        self.agent_calls.extend(child.agent_calls);
        self.workflow_calls.extend(child.workflow_calls);
        merge_token_usage(&mut self.token_usage, &child.token_usage);
        for (phase, usage) in child.token_usage_by_phase {
            merge_token_usage(self.token_usage_by_phase.entry(phase).or_default(), &usage);
        }
        self.agent_runs.extend(child.agent_runs);
        Ok(child.output.result)
    }
1311}
1312
1313async fn run_agent_runner_with_retry(
1314 agent_runner: Arc<dyn WorkflowAgentRunner>,
1315 default_provider: Arc<dyn AgentProvider>,
1316 provider_override: Option<String>,
1317 input: AgentProviderRunInput,
1318 mut cancel_rx: Option<watch::Receiver<bool>>,
1319) -> anyhow::Result<AgentProviderResult> {
1320 let retry = agent_retry_policy(&input.options)?;
1321 let mut final_result = None;
1322 for attempt in 1..=retry.max_attempts {
1323 let attempt_result = agent_runner
1324 .run_agent(
1325 Arc::clone(&default_provider),
1326 provider_override.clone(),
1327 input.clone(),
1328 )
1329 .await;
1330 match attempt_result {
1331 Ok(result) => {
1332 final_result = Some(Ok(result));
1333 break;
1334 }
1335 Err(error) if attempt < retry.max_attempts => {
1336 log::debug!(
1337 "agent call failed on attempt {attempt}/{}; retrying after {}ms: {error:#}",
1338 retry.max_attempts,
1339 retry.backoff_ms
1340 );
1341 sleep_retry_backoff(retry.backoff_ms, &mut cancel_rx).await?;
1342 }
1343 Err(error) => {
1344 final_result = Some(Err(error));
1345 break;
1346 }
1347 }
1348 }
1349 final_result.unwrap_or_else(|| Err(anyhow!("agent retry loop finished without a result")))
1350}
1351
1352async fn sleep_retry_backoff(
1353 backoff_ms: u64,
1354 cancel_rx: &mut Option<watch::Receiver<bool>>,
1355) -> anyhow::Result<()> {
1356 if backoff_ms == 0 {
1357 return Ok(());
1358 }
1359 let Some(cancel_rx) = cancel_rx.as_mut() else {
1360 tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
1361 return Ok(());
1362 };
1363 if *cancel_rx.borrow() {
1364 bail!("workflow cancelled");
1365 }
1366 let sleep = tokio::time::sleep(Duration::from_millis(backoff_ms));
1367 tokio::pin!(sleep);
1368 loop {
1369 tokio::select! {
1370 _ = &mut sleep => return Ok(()),
1371 changed = cancel_rx.changed() => {
1372 match changed {
1373 Ok(()) if *cancel_rx.borrow() => bail!("workflow cancelled"),
1374 Ok(()) => continue,
1375 Err(_) => {
1376 sleep.await;
1377 return Ok(());
1378 }
1379 }
1380 }
1381 }
1382 }
1383}
1384
1385pub(crate) async fn run_agent_provider_with_retry(
1386 default_provider: Arc<dyn AgentProvider>,
1387 provider_override: Option<String>,
1388 input: AgentProviderRunInput,
1389 mut cancel_rx: Option<watch::Receiver<bool>>,
1390) -> anyhow::Result<AgentProviderResult> {
1391 let retry = agent_retry_policy(&input.options)?;
1392 let provider = resolve_agent_provider(default_provider, provider_override)?;
1393 let mut final_result = None;
1394 for attempt in 1..=retry.max_attempts {
1395 let attempt_result =
1396 run_agent_with_optional_isolation(Arc::clone(&provider), input.clone()).await;
1397 match attempt_result {
1398 Ok(result) => {
1399 final_result = Some(Ok(result));
1400 break;
1401 }
1402 Err(error) if attempt < retry.max_attempts => {
1403 log::debug!(
1404 "agent provider failed on attempt {attempt}/{}; retrying after {}ms: {error:#}",
1405 retry.max_attempts,
1406 retry.backoff_ms
1407 );
1408 sleep_retry_backoff(retry.backoff_ms, &mut cancel_rx).await?;
1409 }
1410 Err(error) => {
1411 final_result = Some(Err(error));
1412 break;
1413 }
1414 }
1415 }
1416 final_result.unwrap_or_else(|| Err(anyhow!("agent retry loop finished without a result")))
1417}
1418
1419pub(crate) async fn run_agent_provider(
1420 default_provider: Arc<dyn AgentProvider>,
1421 provider_override: Option<String>,
1422 input: AgentProviderRunInput,
1423) -> anyhow::Result<AgentProviderResult> {
1424 let provider = resolve_agent_provider(default_provider, provider_override)?;
1425 run_agent_with_optional_isolation(provider, input).await
1426}
1427
1428fn resolve_agent_provider(
1429 default_provider: Arc<dyn AgentProvider>,
1430 provider_override: Option<String>,
1431) -> anyhow::Result<Arc<dyn AgentProvider>> {
1432 if let Some(provider_override) = provider_override {
1433 Ok(Arc::from(create_agent_provider(&provider_override)?))
1434 } else {
1435 Ok(default_provider)
1436 }
1437}
1438
/// Retry settings for a single `agent()` call, parsed from its `retry` option by
/// `agent_retry_policy`.
#[derive(Debug, Clone, Copy)]
pub(crate) struct AgentRetryPolicy {
    /// Total number of attempts including the first; `agent_retry_policy` only produces
    /// values of at least 1.
    pub max_attempts: u32,
    /// Delay between attempts in milliseconds.
    pub backoff_ms: u64,
}
1444
1445pub(crate) fn agent_retry_policy(options: &Option<Value>) -> anyhow::Result<AgentRetryPolicy> {
1446 let default = AgentRetryPolicy {
1447 max_attempts: 1,
1448 backoff_ms: 0,
1449 };
1450 let Some(retry) = options.as_ref().and_then(|options| options.get("retry")) else {
1451 return Ok(default);
1452 };
1453 if retry.is_null() {
1454 return Ok(default);
1455 }
1456 let object = retry
1457 .as_object()
1458 .ok_or_else(|| anyhow!("agent retry option must be an object"))?;
1459 let max_attempts = match object.get("maxAttempts") {
1460 Some(value) => {
1461 let value = value
1462 .as_u64()
1463 .ok_or_else(|| anyhow!("agent retry.maxAttempts must be a positive integer"))?;
1464 if value == 0 || value > u32::MAX as u64 {
1465 bail!("agent retry.maxAttempts must be between 1 and {}", u32::MAX);
1466 }
1467 value as u32
1468 }
1469 None => default.max_attempts,
1470 };
1471 let backoff_ms = match object.get("backoffMs") {
1472 Some(value) => value
1473 .as_u64()
1474 .ok_or_else(|| anyhow!("agent retry.backoffMs must be a non-negative integer"))?,
1475 None => default.backoff_ms,
1476 };
1477 Ok(AgentRetryPolicy {
1478 max_attempts,
1479 backoff_ms,
1480 })
1481}
1482
1483async fn run_agent_with_optional_isolation(
1484 provider: Arc<dyn AgentProvider>,
1485 input: AgentProviderRunInput,
1486) -> anyhow::Result<AgentProviderResult> {
1487 if !requests_worktree_isolation(&input.options) {
1488 return run_agent_with_schema_validation(provider, input).await;
1489 }
1490
1491 let isolation = WorktreeIsolation::create(input.context.cwd.as_deref())?;
1492 let isolation_info = isolation.info();
1493 let mut isolated_input = input;
1494 isolated_input.context.cwd = Some(isolation.cwd.clone());
1495 let mut result = run_agent_with_schema_validation(provider, isolated_input).await;
1496 if let Ok(result) = &mut result {
1497 result.isolation = Some(isolation_info);
1498 }
1499 if let Err(error) = isolation.cleanup() {
1500 log::warn!("failed to cleanup isolated agent worktree: {error:#}");
1501 }
1502 result
1503}
1504
1505fn requests_worktree_isolation(options: &Option<Value>) -> bool {
1506 options
1507 .as_ref()
1508 .and_then(|options| options.get("isolation"))
1509 .and_then(Value::as_str)
1510 == Some("worktree")
1511}
1512
/// A temporary git worktree, checked out on its own branch, used to isolate one agent run.
///
/// Created by `WorktreeIsolation::create`. It is removed by `cleanup`; otherwise the
/// `Drop` impl removes it on a best-effort basis.
struct WorktreeIsolation {
    /// Canonical root of the repository owning the worktree; git commands run here.
    repo_root: PathBuf,
    /// Root of the checked-out worktree, inside the temporary directory.
    worktree_root: PathBuf,
    /// Agent working directory inside the worktree (same repo-relative position as the
    /// original cwd).
    cwd: PathBuf,
    /// Branch created for this worktree (`smol-wf/agent-run/<lowercase ulid>`).
    branch_name: String,
    /// Whether `cleanup` has taken care of removal; while false, `Drop` attempts it.
    cleaned: bool,
    /// Owns the temporary directory. Fields are dropped after `Drop::drop` runs, so the
    /// directory is deleted only after the worktree removal has been attempted.
    _temp_dir: tempfile::TempDir,
}
1521
1522impl WorktreeIsolation {
1523 fn create(cwd: Option<&Path>) -> anyhow::Result<Self> {
1524 let cwd = cwd
1525 .map(Path::to_path_buf)
1526 .unwrap_or(std::env::current_dir()?)
1527 .canonicalize()
1528 .context("failed to canonicalize workflow cwd for worktree isolation")?;
1529 let repo_root = git_output(&cwd, &["rev-parse", "--show-toplevel"]).context(
1530 "agent isolation='worktree' requires the workflow cwd to be inside a git repository",
1531 )?;
1532 let repo_root = PathBuf::from(repo_root.trim())
1533 .canonicalize()
1534 .context("failed to canonicalize git repository root for worktree isolation")?;
1535 let relative_cwd = cwd.strip_prefix(&repo_root).with_context(|| {
1536 format!(
1537 "workflow cwd {} is not under git repository root {}",
1538 cwd.display(),
1539 repo_root.display()
1540 )
1541 })?;
1542
1543 let temp_dir = tempfile::Builder::new()
1544 .prefix("smol-wf-agent-worktree-")
1545 .tempdir()
1546 .context("failed to create temp directory for agent worktree isolation")?;
1547 let worktree_root = temp_dir.path().join("worktree");
1548 let worktree_arg = path_arg(&worktree_root);
1549 let branch_name = format!(
1550 "smol-wf/agent-run/{}",
1551 ulid::Ulid::new().to_string().to_ascii_lowercase()
1552 );
1553 git_status(
1554 &repo_root,
1555 &[
1556 "worktree",
1557 "add",
1558 "--quiet",
1559 "-b",
1560 &branch_name,
1561 &worktree_arg,
1562 "HEAD",
1563 ],
1564 )
1565 .context("failed to create isolated git worktree for agent run")?;
1566 let isolated_cwd = if relative_cwd.as_os_str().is_empty() {
1567 worktree_root.clone()
1568 } else {
1569 worktree_root.join(relative_cwd)
1570 };
1571 Ok(Self {
1572 repo_root,
1573 worktree_root,
1574 cwd: isolated_cwd,
1575 branch_name,
1576 cleaned: false,
1577 _temp_dir: temp_dir,
1578 })
1579 }
1580
1581 fn info(&self) -> AgentRunIsolation {
1582 AgentRunIsolation {
1583 kind: "worktree".to_string(),
1584 branch: Some(self.branch_name.clone()),
1585 worktree_path: Some(path_arg(&self.worktree_root)),
1586 cwd: Some(path_arg(&self.cwd)),
1587 }
1588 }
1589
1590 fn cleanup(mut self) -> anyhow::Result<()> {
1591 self.remove_worktree()?;
1592 self.delete_branch()?;
1593 self.cleaned = true;
1594 Ok(())
1595 }
1596
1597 fn remove_worktree(&self) -> anyhow::Result<()> {
1598 let worktree_arg = path_arg(&self.worktree_root);
1599 git_status(
1600 &self.repo_root,
1601 &["worktree", "remove", "--force", &worktree_arg],
1602 )
1603 .context("failed to remove isolated git worktree")
1604 }
1605
1606 fn delete_branch(&self) -> anyhow::Result<()> {
1607 git_status(&self.repo_root, &["branch", "-D", &self.branch_name])
1608 .context("failed to delete isolated agent worktree branch")
1609 }
1610}
1611
1612impl Drop for WorktreeIsolation {
1613 fn drop(&mut self) {
1614 if !self.cleaned {
1615 if let Err(error) = self.remove_worktree() {
1616 log::warn!("failed to cleanup isolated agent worktree during drop: {error:#}");
1617 }
1618 if let Err(error) = self.delete_branch() {
1619 log::warn!(
1620 "failed to delete isolated agent worktree branch during drop: {error:#}"
1621 );
1622 }
1623 }
1624 }
1625}
1626
/// Converts a path into an owned command-line argument, replacing invalid Unicode with
/// U+FFFD.
fn path_arg(path: &Path) -> String {
    String::from(path.to_string_lossy())
}
1630
1631fn git_output(cwd: &Path, args: &[&str]) -> anyhow::Result<String> {
1632 let output = StdCommand::new("git")
1633 .args(args)
1634 .current_dir(cwd)
1635 .output()
1636 .with_context(|| format!("failed to run git {}", args.join(" ")))?;
1637 if output.status.success() {
1638 Ok(String::from_utf8_lossy(&output.stdout).trim().to_string())
1639 } else {
1640 bail!(
1641 "git {} failed with {}{}",
1642 args.join(" "),
1643 status_text(output.status.code()),
1644 command_stderr(&output.stderr)
1645 )
1646 }
1647}
1648
1649fn git_status(cwd: &Path, args: &[&str]) -> anyhow::Result<()> {
1650 let output = StdCommand::new("git")
1651 .args(args)
1652 .current_dir(cwd)
1653 .output()
1654 .with_context(|| format!("failed to run git {}", args.join(" ")))?;
1655 if output.status.success() {
1656 Ok(())
1657 } else {
1658 bail!(
1659 "git {} failed with {}{}",
1660 args.join(" "),
1661 status_text(output.status.code()),
1662 command_stderr(&output.stderr)
1663 )
1664 }
1665}
1666
/// Describes a process exit for error messages: `code N`, or `signal` when the process
/// had no exit code.
fn status_text(code: Option<i32>) -> String {
    match code {
        Some(code) => format!("code {code}"),
        None => "signal".to_string(),
    }
}
1671
/// Formats captured stderr as an error-message suffix (`": <trimmed stderr>"`), or an
/// empty string when stderr is blank.
fn command_stderr(stderr: &[u8]) -> String {
    let text = String::from_utf8_lossy(stderr);
    match text.trim() {
        "" => String::new(),
        trimmed => format!(": {trimmed}"),
    }
}
1681
1682async fn run_agent_with_schema_validation(
1683 provider: Arc<dyn AgentProvider>,
1684 input: AgentProviderRunInput,
1685) -> anyhow::Result<AgentProviderResult> {
1686 let Some(schema) = input
1687 .options
1688 .as_ref()
1689 .and_then(|options| options.get("schema"))
1690 .cloned()
1691 else {
1692 return provider.run(input).await;
1693 };
1694
1695 let max_attempts = 2;
1696 let original_prompt = input.prompt.clone();
1697 let mut attempt_input = input;
1698 let mut last_errors = Vec::new();
1699
1700 for attempt in 1..=max_attempts {
1701 let result = provider.run(attempt_input.clone()).await?;
1702 match validate_structured_output(&schema, &result.output) {
1703 Ok(()) => return Ok(result),
1704 Err(errors) => {
1705 last_errors = errors;
1706 if attempt < max_attempts {
1707 attempt_input.prompt =
1708 with_structured_output_retry_prompt(&original_prompt, &last_errors);
1709 }
1710 }
1711 }
1712 }
1713
1714 bail!(
1715 "{}",
1716 format_structured_output_validation_error(&last_errors)
1717 )
1718}
1719
1720fn validate_structured_output(schema: &Value, output: &Value) -> Result<(), Vec<String>> {
1721 let validator = jsonschema::validator_for(schema)
1722 .map_err(|error| vec![format!("/ schema is invalid: {}", error)])?;
1723 let errors = validator
1724 .iter_errors(output)
1725 .map(|error| {
1726 let path = error.instance_path().to_string();
1727 let path = if path.is_empty() {
1728 "/".to_string()
1729 } else {
1730 path
1731 };
1732 format!("{path} {error}")
1733 })
1734 .collect::<Vec<_>>();
1735
1736 if errors.is_empty() {
1737 Ok(())
1738 } else {
1739 Err(errors)
1740 }
1741}
1742
/// Builds the final error message for output that never matched the schema, listing all
/// validation errors separated by `"; "`.
fn format_structured_output_validation_error(errors: &[String]) -> String {
    let details = errors.join("; ");
    format!("Structured output did not match JSON Schema: {details}")
}
1749
/// Builds the retry prompt. It is the original prompt, a blank line, fixed instructions
/// about the failed schema validation, and one `- <error>` line per validation error,
/// all joined with newlines.
fn with_structured_output_retry_prompt(prompt: &str, errors: &[String]) -> String {
    const RETRY_INSTRUCTIONS: [&str; 3] = [
        "Previous structured output failed JSON Schema validation.",
        "Return a corrected structured output that satisfies the original JSON Schema.",
        "Validation errors:",
    ];
    let mut text = String::from(prompt);
    // Terminates the prompt line; the next push starts the blank separator line.
    text.push('\n');
    for line in RETRY_INSTRUCTIONS {
        text.push('\n');
        text.push_str(line);
    }
    for error in errors {
        text.push('\n');
        text.push_str("- ");
        text.push_str(error);
    }
    text
}
1761
/// A model selector after alias lookup and parsing.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ResolvedModelSelector {
    /// The model name exactly as given in the call's options.
    requested: String,
    /// The selector that was parsed: the `model_map` alias target, or `requested` itself.
    selector: String,
    /// Model id without any provider qualifier or query string.
    model_id: String,
    /// Provider from the `provider/model` form or `?provider=`, if present.
    model_provider: Option<String>,
    /// Value of `?thinking=`, if present.
    thinking: Option<String>,
}

impl ResolvedModelSelector {
    /// Model string to pass on: `provider/model_id` when a provider is set, otherwise
    /// just `model_id`.
    fn provider_model(&self) -> String {
        if let Some(provider) = self.model_provider.as_deref() {
            format!("{provider}/{}", self.model_id)
        } else {
            self.model_id.clone()
        }
    }
}
1779
/// Resolves the `model` option through `model_map` and normalizes it for `agent_provider`.
///
/// When `options` has no string `model` (or is not an object), it is returned unchanged.
/// Otherwise the model name is looked up in `model_map`, falling back to the name
/// itself, and the result is parsed as a selector. The selector is checked against the
/// provider's rules, and the returned options object gets:
/// - `model` set to `provider/model_id` or `model_id`;
/// - `requestedModel` / `modelSelector` set when an alias matched or the selector carried
///   a query, provider or thinking qualifier, and removed otherwise;
/// - `modelProvider` / `thinking` set from the selector, and removed when absent.
///
/// # Errors
/// Propagates selector parse errors and provider-compatibility errors.
fn resolve_model_options(
    options: Option<Value>,
    agent_provider: &str,
    model_map: &BTreeMap<String, String>,
) -> anyhow::Result<Option<Value>> {
    let Some(model) = options
        .as_ref()
        .and_then(Value::as_object)
        .and_then(|object| object.get("model"))
        .and_then(Value::as_str)
        .map(ToString::to_string)
    else {
        return Ok(options);
    };

    // `model_map` aliases take precedence; an unmapped name is parsed as a selector itself.
    let mapped_selector = model_map.get(&model).cloned();
    let alias_matched = mapped_selector.is_some();
    let selector = mapped_selector.unwrap_or_else(|| model.clone());
    let resolved = parse_model_selector(&model, &selector)?;
    validate_model_selector_for_provider(&resolved, agent_provider)?;

    let mut object = options
        .and_then(|value| value.as_object().cloned())
        .unwrap_or_default();
    object.insert(
        "model".to_string(),
        Value::String(resolved.provider_model()),
    );

    // Record the original name and full selector whenever `model` alone would not show
    // how it was derived; otherwise drop any stale values supplied by the caller.
    let selector_has_extra_parts = alias_matched
        || resolved.selector.contains('?')
        || resolved.model_provider.is_some()
        || resolved.thinking.is_some();
    if selector_has_extra_parts {
        object.insert(
            "requestedModel".to_string(),
            Value::String(resolved.requested.clone()),
        );
        object.insert(
            "modelSelector".to_string(),
            Value::String(resolved.selector.clone()),
        );
    } else {
        object.remove("requestedModel");
        object.remove("modelSelector");
    }

    // The selector is the sole source of these keys. NOTE(review): a caller-supplied
    // `thinking` option is removed when the selector has no `?thinking=`; confirm this
    // is intended.
    if let Some(provider) = resolved.model_provider {
        object.insert("modelProvider".to_string(), Value::String(provider));
    } else {
        object.remove("modelProvider");
    }
    if let Some(thinking) = resolved.thinking {
        object.insert("thinking".to_string(), Value::String(thinking));
    } else {
        object.remove("thinking");
    }
    Ok(Some(Value::Object(object)))
}
1839
1840fn parse_model_selector(requested: &str, selector: &str) -> anyhow::Result<ResolvedModelSelector> {
1841 let (model_part, query) = selector.split_once('?').unwrap_or((selector, ""));
1842 if model_part.trim().is_empty() {
1843 bail!("model selector must include a model id: {selector}");
1844 }
1845
1846 let (slash_provider, model_id) = match model_part.split_once('/') {
1847 Some((provider, model_id)) if !provider.is_empty() && !model_id.is_empty() => {
1848 (Some(provider.to_string()), model_id.to_string())
1849 }
1850 Some(_) => bail!("model selector provider/model form is invalid: {selector}"),
1851 None => (None, model_part.to_string()),
1852 };
1853
1854 let mut query_provider = None::<String>;
1855 let mut thinking = None::<String>;
1856 if !query.is_empty() {
1857 for pair in query.split('&') {
1858 if pair.is_empty() {
1859 continue;
1860 }
1861 let (key, value) = pair.split_once('=').ok_or_else(|| {
1862 anyhow!("model selector query parameter must use key=value: {pair}")
1863 })?;
1864 let key = percent_decode(key)?;
1865 let value = percent_decode(value)?;
1866 if value.is_empty() {
1867 bail!("model selector query parameter `{key}` must not be empty");
1868 }
1869 match key.as_str() {
1870 "provider" => set_unique_query_value(&mut query_provider, key, value)?,
1871 "thinking" => set_unique_query_value(&mut thinking, key, value)?,
1872 _ => bail!("unknown model selector query parameter `{key}`"),
1873 }
1874 }
1875 }
1876
1877 let model_provider = match (slash_provider, query_provider) {
1878 (Some(slash), Some(query)) if slash != query => bail!(
1879 "conflicting model provider qualifiers in selector `{selector}`: `{slash}` and `{query}`"
1880 ),
1881 (Some(provider), Some(_)) | (Some(provider), None) | (None, Some(provider)) => {
1882 Some(provider)
1883 }
1884 (None, None) => None,
1885 };
1886
1887 Ok(ResolvedModelSelector {
1888 requested: requested.to_string(),
1889 selector: selector.to_string(),
1890 model_id,
1891 model_provider,
1892 thinking,
1893 })
1894}
1895
1896fn set_unique_query_value(
1897 target: &mut Option<String>,
1898 key: String,
1899 value: String,
1900) -> anyhow::Result<()> {
1901 if target.replace(value).is_some() {
1902 bail!("duplicate model selector query parameter `{key}`");
1903 }
1904 Ok(())
1905}
1906
1907fn percent_decode(value: &str) -> anyhow::Result<String> {
1908 let bytes = value.as_bytes();
1909 let mut output = Vec::with_capacity(bytes.len());
1910 let mut index = 0;
1911 while index < bytes.len() {
1912 match bytes[index] {
1913 b'%' => {
1914 if index + 2 >= bytes.len() {
1915 bail!("invalid percent escape in model selector query: {value}");
1916 }
1917 let high = hex_value(bytes[index + 1]).ok_or_else(|| {
1918 anyhow!("invalid percent escape in model selector query: {value}")
1919 })?;
1920 let low = hex_value(bytes[index + 2]).ok_or_else(|| {
1921 anyhow!("invalid percent escape in model selector query: {value}")
1922 })?;
1923 output.push((high << 4) | low);
1924 index += 3;
1925 }
1926 b'+' => {
1927 output.push(b' ');
1928 index += 1;
1929 }
1930 byte => {
1931 output.push(byte);
1932 index += 1;
1933 }
1934 }
1935 }
1936 String::from_utf8(output).context("model selector query is not valid UTF-8")
1937}
1938
/// Value of an ASCII hex digit (`0-9`, `a-f`, `A-F`), or `None` for any other byte.
fn hex_value(byte: u8) -> Option<u8> {
    // `char::to_digit` only recognises ASCII digits and letters, so bytes >= 0x80 (seen
    // here as Latin-1 chars) are rejected just like other non-hex bytes.
    (byte as char).to_digit(16).map(|digit| digit as u8)
}
1947
1948fn validate_model_selector_for_provider(
1949 resolved: &ResolvedModelSelector,
1950 agent_provider: &str,
1951) -> anyhow::Result<()> {
1952 match agent_provider {
1953 "codex" => {
1954 if resolved.model_provider.is_some() {
1955 bail!("Codex model selectors do not support ?provider=... or provider/model form");
1956 }
1957 if resolved.thinking.is_some() {
1958 bail!("Codex model selectors do not support thinking=...");
1959 }
1960 }
1961 "claude-code" if resolved.model_provider.is_some() => {
1962 bail!(
1963 "Claude Code model selectors do not support ?provider=... or provider/model form"
1964 );
1965 }
1966 "opencode" if resolved.model_provider.is_none() => {
1967 bail!("OpenCode model selectors must use provider/model or ?provider=...");
1968 }
1969 "debug" | "pi" => {}
1970 _ => {}
1971 }
1972 Ok(())
1973}
1974
1975fn apply_phase_defaults(options: Option<Value>, metadata: &WorkflowMetadata) -> Option<Value> {
1976 let phase_name = options
1977 .as_ref()
1978 .and_then(|options| options.get("phase"))
1979 .and_then(Value::as_str)
1980 .map(ToString::to_string);
1981 let phase_metadata = phase_name.as_ref().and_then(|phase_name| {
1982 metadata
1983 .phases
1984 .iter()
1985 .find(|phase| phase.title == *phase_name)
1986 });
1987
1988 if phase_name.is_none() && phase_metadata.is_none() {
1989 return options;
1990 }
1991
1992 let mut object = options
1993 .and_then(|value| value.as_object().cloned())
1994 .unwrap_or_default();
1995
1996 if let Some(phase_name) = phase_name {
1997 object
1998 .entry("phase".to_string())
1999 .or_insert(Value::String(phase_name));
2000 }
2001 if let Some(model) = phase_metadata.and_then(|phase| phase.model.clone()) {
2002 object
2003 .entry("model".to_string())
2004 .or_insert(Value::String(model));
2005 }
2006 if let Some(provider) = phase_metadata.and_then(|phase| phase.provider.clone()) {
2007 object
2008 .entry("provider".to_string())
2009 .or_insert(Value::String(provider));
2010 }
2011
2012 Some(Value::Object(object))
2013}
2014
/// Resolves a child workflow path. An absolute path is kept as-is; a relative one is
/// joined onto the current script's directory, or `.` when it has no parent.
fn resolve_relative_script(current_script_path: &Path, script_path: &str) -> PathBuf {
    let requested = Path::new(script_path);
    if requested.is_absolute() {
        return requested.to_path_buf();
    }
    let base_dir = current_script_path.parent().unwrap_or(Path::new("."));
    base_dir.join(requested)
}
2026
2027fn resolve_named_workflow(name: &str) -> anyhow::Result<PathBuf> {
2028 let workflows_dir = PathBuf::from(".claude/workflows");
2029 for entry in fs::read_dir(&workflows_dir).unwrap_or_else(|_| fs::read_dir(".").unwrap()) {
2030 let entry = entry?;
2031 let path = entry.path();
2032 if path.extension().and_then(|extension| extension.to_str()) != Some("js") {
2033 continue;
2034 }
2035 if read_workflow_metadata(&path)?.is_some_and(|metadata| metadata.name == name) {
2036 return Ok(path);
2037 }
2038 }
2039 bail!("Unknown workflow: {name}")
2040}