Skip to main content

streamweave_attractor/
runner.rs

1//! Compiled graph runner: compile AttractorGraph to StreamWeave graph and run it.
2//!
3//! - [run_streamweave_graph]: run a compiled graph (one trigger in, first output out).
4//! - [run_compiled_graph]: compile AST then run, return [crate::nodes::execution_loop::AttractorResult].
5
6use crate::checkpoint_io::{self, CHECKPOINT_FILENAME};
7use crate::nodes::execution_loop::AttractorResult;
8use crate::nodes::execution_loop::{RunLoopResult, run_execution_loop_once};
9use crate::nodes::init_context::create_initial_state;
10use crate::types::{AttractorGraph, Checkpoint, ExecutionLog, GraphPayload, NodeOutcome};
11use std::path::Path;
12use std::sync::Arc;
13use tracing::instrument;
14
15/// Runs a compiled StreamWeave graph: feeds one trigger into the "input" port,
16/// runs until the graph produces output on the "output" port, then returns the first output item.
17///
18/// The graph must have been built with `input` and `output` port names (as produced by
19/// [crate::compile_attractor_graph].
20#[instrument(level = "trace", skip(graph, initial))]
21pub async fn run_streamweave_graph(
22  mut graph: streamweave::graph::Graph,
23  initial: GraphPayload,
24) -> Result<Option<Arc<dyn std::any::Any + Send + Sync>>, String> {
25  let (tx_in, rx_in) = tokio::sync::mpsc::channel(1);
26  let (_tx_out, mut rx_out) = tokio::sync::mpsc::channel(16);
27
28  graph
29    .connect_input_channel("input", rx_in)
30    .map_err(|e| e.to_string())?;
31  graph
32    .connect_output_channel("output", _tx_out)
33    .map_err(|e| e.to_string())?;
34
35  tx_in
36    .send(Arc::new(initial) as Arc<dyn std::any::Any + Send + Sync>)
37    .await
38    .map_err(|e| e.to_string())?;
39  drop(tx_in);
40
41  tracing::trace!("run_streamweave_graph: calling graph.execute()");
42  graph.execute().await.map_err(|e| e.to_string())?;
43  tracing::trace!("run_streamweave_graph: execute done, waiting for output on rx_out.recv()");
44  let first = rx_out.recv().await;
45  tracing::trace!("run_streamweave_graph: received output, calling wait_for_completion()");
46  graph
47    .wait_for_completion()
48    .await
49    .map_err(|e| e.to_string())?;
50  Ok(first)
51}
52
/// Options for [run_compiled_graph].
///
/// Every field is optional; `RunOptions::default()` leaves all of them unset, so callers can use
/// struct-update syntax (`RunOptions { run_dir: Some(dir), ..Default::default() }`).
#[derive(Debug, Clone, Default)]
pub struct RunOptions<'a> {
  /// If set, a checkpoint file ([`CHECKPOINT_FILENAME`](crate::checkpoint_io::CHECKPOINT_FILENAME))
  /// is written into this directory when the run finishes without error.
  pub run_dir: Option<&'a Path>,
  /// Command for agent/codergen nodes (e.g. cursor-agent). Required if the graph has codergen nodes.
  pub agent_cmd: Option<String>,
  /// Directory for agent outcome.json and staging. When unset, the compiled-graph path falls back
  /// to `crate::DEFAULT_STAGE_DIR`.
  pub stage_dir: Option<std::path::PathBuf>,
  /// If set, execution steps are recorded and written to this path as execution.log.json (on success and failure).
  pub execution_log_path: Option<std::path::PathBuf>,
}
64
65/// Writes execution.log.json to the given path (on both success and failure).
66fn write_execution_log(
67  path: &Path,
68  goal: &str,
69  started_at: &str,
70  final_status: &str,
71  completed_nodes: &[String],
72  steps: Vec<crate::types::ExecutionStepEntry>,
73) -> Result<(), String> {
74  let finished_at = chrono::Utc::now().to_rfc3339();
75  let log = ExecutionLog {
76    version: 1,
77    goal: goal.to_string(),
78    started_at: started_at.to_string(),
79    finished_at: Some(finished_at),
80    final_status: final_status.to_string(),
81    completed_nodes: completed_nodes.to_vec(),
82    steps,
83  };
84  let json = serde_json::to_string_pretty(&log).map_err(|e| e.to_string())?;
85  std::fs::write(path, json).map_err(|e| e.to_string())?;
86  Ok(())
87}
88
89/// Compiles the Attractor graph to a StreamWeave graph, runs it, and returns an [AttractorResult].
90/// Uses [crate::compile_attractor_graph]. Initial context includes the graph goal.
91/// When [RunOptions::execution_log_path] is set, runs via the execution loop and writes execution.log.json.
92#[instrument(level = "trace", skip(ast, options))]
93pub async fn run_compiled_graph(
94  ast: &AttractorGraph,
95  options: RunOptions<'_>,
96) -> Result<AttractorResult, String> {
97  if let Some(ref log_path) = options.execution_log_path {
98    let started_at = chrono::Utc::now().to_rfc3339();
99    let mut state = create_initial_state(ast.clone(), Some(vec![]));
100    match run_execution_loop_once(&mut state) {
101      RunLoopResult::Ok(result) => {
102        let steps = state.step_log.unwrap_or_default();
103        write_execution_log(
104          log_path,
105          &ast.goal,
106          &started_at,
107          "success",
108          &result.completed_nodes,
109          steps,
110        )?;
111        if let Some(run_dir) = options.run_dir {
112          let cp = Checkpoint {
113            context: result.context.clone(),
114            current_node_id: result.completed_nodes.last().cloned().unwrap_or_default(),
115            completed_nodes: result.completed_nodes.clone(),
116          };
117          let path = run_dir.join(CHECKPOINT_FILENAME);
118          checkpoint_io::save_checkpoint(&path, &cp).map_err(|e| e.to_string())?;
119        }
120        return Ok(result);
121      }
122      RunLoopResult::Err(e) => {
123        let steps = state.step_log.unwrap_or_default();
124        let completed = state.completed_nodes.clone();
125        write_execution_log(log_path, &ast.goal, &started_at, "error", &completed, steps)?;
126        return Err(e);
127      }
128    }
129  }
130
131  let stage_dir = options
132    .stage_dir
133    .as_deref()
134    .or_else(|| Some(std::path::Path::new(crate::DEFAULT_STAGE_DIR)));
135  let mut graph =
136    crate::compiler::compile_attractor_graph(ast, None, options.agent_cmd.as_deref(), stage_dir)?;
137  let mut ctx = std::collections::HashMap::new();
138  ctx.insert("goal".to_string(), ast.goal.clone());
139  ctx.insert("graph.goal".to_string(), ast.goal.clone());
140  let start_id = ast
141    .find_start()
142    .map(|n| n.id.clone())
143    .ok_or("missing start node")?;
144  let initial = GraphPayload::initial(ctx, start_id);
145
146  let (tx_in, rx_in) = tokio::sync::mpsc::channel(1);
147  let (_tx_out, mut rx_out) = tokio::sync::mpsc::channel(16);
148  let (_tx_err, mut rx_err) = tokio::sync::mpsc::channel(16);
149
150  graph
151    .connect_input_channel("input", rx_in)
152    .map_err(|e| e.to_string())?;
153  graph
154    .connect_output_channel("output", _tx_out)
155    .map_err(|e| e.to_string())?;
156  let has_error_port = graph.connect_output_channel("error", _tx_err).is_ok();
157
158  tx_in
159    .send(Arc::new(initial) as Arc<dyn std::any::Any + Send + Sync>)
160    .await
161    .map_err(|e| e.to_string())?;
162  drop(tx_in);
163
164  tracing::trace!("run_streamweave_graph: calling graph.execute()");
165  graph.execute().await.map_err(|e| e.to_string())?;
166  tracing::trace!("run_streamweave_graph: execute done, waiting for first of output or error");
167  let first = if has_error_port {
168    tokio::select! {
169      Some(arc) = rx_out.recv() => Some(arc),
170      Some(arc) = rx_err.recv() => Some(arc),
171      else => None,
172    }
173  } else {
174    rx_out.recv().await
175  };
176  // Do not wait_for_completion(); first result decides outcome, avoids hang on merge graphs.
177
178  let payload = first
179    .and_then(|arc| arc.downcast::<GraphPayload>().ok())
180    .map(|p| (*p).clone());
181  let (context, last_outcome, completed_nodes, current_node_id) = payload
182    .as_ref()
183    .map(|p| {
184      (
185        p.context.clone(),
186        p.outcome
187          .clone()
188          .unwrap_or_else(|| NodeOutcome::success("Exit")),
189        p.completed_nodes.clone(),
190        p.current_node_id.clone(),
191      )
192    })
193    .unwrap_or_else(|| {
194      (
195        std::collections::HashMap::new(),
196        NodeOutcome::success("Exit"),
197        vec![],
198        String::new(),
199      )
200    });
201
202  if let Some(run_dir) = options.run_dir {
203    let cp = Checkpoint {
204      context: context.clone(),
205      current_node_id: current_node_id.clone(),
206      completed_nodes: completed_nodes.clone(),
207    };
208    let path = run_dir.join(CHECKPOINT_FILENAME);
209    checkpoint_io::save_checkpoint(&path, &cp).map_err(|e| e.to_string())?;
210  }
211
212  Ok(AttractorResult {
213    last_outcome,
214    completed_nodes,
215    context,
216  })
217}