1use ahash::RandomState;
2use std::collections::{HashMap, HashSet, VecDeque};
3use std::hash::Hash;
4use std::sync::Arc;
5
6use chrono::Utc;
7use futures::stream::{self, BoxStream, StreamExt};
8use petgraph::graph::Graph;
9use tokio::sync::mpsc;
10use tokio::task::JoinSet;
11
12use crate::observer::ObserverCallbackAdapter;
13use crate::{
14 Checkpoint, Checkpointer, EdgeKind, ExecutionConfig, ExecutionOptions, GraphError, GraphEvent,
15 GraphProgram, GraphState, NodeData, Observer, StateSchema, StateUpdate, END, START,
16};
17use serde_json::json;
18use wesichain_core::{
19 ensure_object, AgentEvent, CallbackManager, RunContext, RunType, Runnable, ToTraceInput,
20 ToTraceOutput, WesichainError,
21};
22
23pub type Condition<S> = Box<dyn Fn(&GraphState<S>) -> Vec<String> + Send + Sync>;
24
25pub struct GraphContext {
26 pub remaining_steps: Option<usize>,
27 pub observer: Option<Arc<dyn Observer>>,
28 pub node_id: String,
29}
30
31async fn emit_status_event(
32 sender: &Option<mpsc::Sender<AgentEvent>>,
33 step: &mut usize,
34 thread_id: &str,
35 stage: impl Into<String>,
36 message: impl Into<String>,
37) {
38 if let Some(sender) = sender {
39 *step += 1;
40 let _ = sender
41 .send(AgentEvent::Status {
42 stage: stage.into(),
43 message: message.into(),
44 step: *step,
45 thread_id: thread_id.to_string(),
46 })
47 .await;
48 }
49}
50
51async fn emit_error_event(
52 sender: &Option<mpsc::Sender<AgentEvent>>,
53 step: &mut usize,
54 message: impl Into<String>,
55 source: Option<String>,
56) {
57 if let Some(sender) = sender {
58 *step += 1;
59 let _ = sender
60 .send(AgentEvent::Error {
61 message: message.into(),
62 step: *step,
63 recoverable: false,
64 source,
65 })
66 .await;
67 }
68}
69
70#[async_trait::async_trait]
71pub trait GraphNode<S: StateSchema>: Send + Sync {
72 async fn invoke_with_context(
73 &self,
74 input: GraphState<S>,
75 context: &GraphContext,
76 ) -> Result<StateUpdate<S>, WesichainError>;
77}
78
79#[async_trait::async_trait]
80impl<S, R> GraphNode<S> for R
81where
82 S: StateSchema,
83 R: Runnable<GraphState<S>, StateUpdate<S>> + Send + Sync,
84{
85 async fn invoke_with_context(
86 &self,
87 input: GraphState<S>,
88 _context: &GraphContext,
89 ) -> Result<StateUpdate<S>, WesichainError> {
90 self.invoke(input).await
91 }
92}
93
94pub struct GraphBuilder<S: StateSchema> {
95 nodes: HashMap<String, Arc<dyn GraphNode<S>>>,
96 edges: HashMap<String, Vec<String>>,
97 conditional: HashMap<String, Condition<S>>,
98 checkpointer: Option<(Box<dyn Checkpointer<S>>, String)>,
99 observer: Option<Arc<dyn Observer>>,
100 default_config: ExecutionConfig,
101 entry: Option<String>,
102 interrupt_before: Vec<String>,
103 interrupt_after: Vec<String>,
104}
105
106impl<S: StateSchema> Default for GraphBuilder<S> {
107 fn default() -> Self {
108 Self::new()
109 }
110}
111
112impl<S: StateSchema> GraphBuilder<S> {
113 pub fn new() -> Self {
114 Self {
115 nodes: HashMap::new(),
116 edges: HashMap::new(),
117 conditional: HashMap::new(),
118 checkpointer: None,
119 observer: None,
120 default_config: ExecutionConfig::default(),
121 entry: None,
122 interrupt_before: Vec::new(),
123 interrupt_after: Vec::new(),
124 }
125 }
126
127 pub fn add_node<R>(mut self, name: &str, node: R) -> Self
128 where
129 R: GraphNode<S> + 'static,
130 {
131 self.nodes.insert(name.to_string(), Arc::new(node));
132 self
133 }
134
135 pub fn set_entry(mut self, name: &str) -> Self {
136 self.entry = Some(name.to_string());
137 self
138 }
139
140 pub fn add_edge(mut self, from: &str, to: &str) -> Self {
141 self.edges
142 .entry(from.to_string())
143 .or_default()
144 .push(to.to_string());
145 self
146 }
147
148 pub fn add_edges(mut self, from: &str, targets: &[&str]) -> Self {
149 let entry = self.edges.entry(from.to_string()).or_default();
150 for target in targets {
151 entry.push(target.to_string());
152 }
153 self
154 }
155
156 pub fn add_conditional_edge<F>(mut self, from: &str, condition: F) -> Self
157 where
158 F: Fn(&GraphState<S>) -> Vec<String> + Send + Sync + 'static,
159 {
160 self.conditional
161 .insert(from.to_string(), Box::new(condition));
162 self
163 }
164 #[deprecated(since = "0.3.0", note = "Use `with_default_config` instead")]
165 pub fn with_config(mut self, config: ExecutionConfig) -> Self {
166 self.default_config = config;
167 self
168 }
169
170 pub fn with_checkpointer<C>(mut self, checkpointer: C, thread_id: &str) -> Self
171 where
172 C: Checkpointer<S> + 'static,
173 {
174 self.checkpointer = Some((Box::new(checkpointer), thread_id.to_string()));
175 self
176 }
177
178 pub fn with_observer(mut self, observer: Arc<dyn Observer>) -> Self {
179 self.observer = Some(observer);
180 self
181 }
182
183 pub fn with_default_config(mut self, config: ExecutionConfig) -> Self {
184 self.default_config = config;
185 self
186 }
187
188 pub fn with_interrupt_before<I, S2>(mut self, nodes: I) -> Self
189 where
190 I: IntoIterator<Item = S2>,
191 S2: Into<String>,
192 {
193 self.interrupt_before = nodes.into_iter().map(Into::into).collect();
194 self
195 }
196
197 pub fn with_interrupt_after<I, S2>(mut self, nodes: I) -> Self
198 where
199 I: IntoIterator<Item = S2>,
200 S2: Into<String>,
201 {
202 self.interrupt_after = nodes.into_iter().map(Into::into).collect();
203 self
204 }
205
206 pub fn build(self) -> ExecutableGraph<S> {
207 ExecutableGraph {
208 nodes: self.nodes,
209 edges: self.edges,
210 conditional: self.conditional,
211 checkpointer: self.checkpointer,
212 observer: self.observer,
213 default_config: self.default_config,
214 entry: self.entry.expect("entry"),
215 interrupt_before: self.interrupt_before,
216 interrupt_after: self.interrupt_after,
217 }
218 }
219
220 pub fn build_program(self) -> GraphProgram<S> {
221 let GraphBuilder { nodes, edges, .. } = self;
222 let mut graph = Graph::new();
223 let mut name_to_index = HashMap::new();
224
225 for (name, runnable) in nodes {
226 let index = graph.add_node(NodeData {
227 name: name.clone(),
228 runnable,
229 });
230 name_to_index.insert(name, index);
231 }
232
233 for (from, targets) in edges.iter() {
234 if from == START {
235 continue;
236 }
237 if let Some(from_idx) = name_to_index.get(from) {
238 for to in targets {
239 if to == END {
240 continue;
241 }
242 if let Some(to_idx) = name_to_index.get(to) {
243 graph.add_edge(*from_idx, *to_idx, EdgeKind::Default);
244 }
245 }
246 }
247 }
248
249 GraphProgram::new(graph, name_to_index)
250 }
251}
252
253fn stable_hash<T: Hash + ?Sized>(t: &T) -> u64 {
254 RandomState::with_seeds(0x517cc1b727220a95, 0x6ed9eba1999cd92d, 0, 0).hash_one(t)
255}
256
257pub struct ExecutableGraph<S: StateSchema> {
258 nodes: HashMap<String, Arc<dyn GraphNode<S>>>,
259 edges: HashMap<String, Vec<String>>,
260 conditional: HashMap<String, Condition<S>>,
261 checkpointer: Option<(Box<dyn Checkpointer<S>>, String)>,
262 observer: Option<Arc<dyn Observer>>,
263 default_config: ExecutionConfig,
264 entry: String,
265 interrupt_before: Vec<String>,
266 interrupt_after: Vec<String>,
267}
268
269impl<S: StateSchema<Update = S>> ExecutableGraph<S> {
270 pub async fn invoke_graph(&self, state: GraphState<S>) -> Result<GraphState<S>, GraphError> {
271 self.invoke_graph_with_options(state, ExecutionOptions::default())
272 .await
273 }
274
275 pub fn stream_invoke(
276 &self,
277 state: GraphState<S>,
278 ) -> BoxStream<'_, Result<GraphEvent<S>, GraphError>> {
279 self.stream_invoke_with_options(state, ExecutionOptions::default())
280 }
281
282 pub fn stream_invoke_with_options(
283 &self,
284 state: GraphState<S>,
285 options: ExecutionOptions,
286 ) -> BoxStream<'_, Result<GraphEvent<S>, GraphError>> {
287 let checkpoint_thread_id = options.checkpoint_thread_id.clone().or_else(|| {
288 self.checkpointer
289 .as_ref()
290 .map(|(_, thread_id)| thread_id.clone())
291 });
292
293 let observer = options.observer.clone().or_else(|| self.observer.clone());
295 let mut run_config = options.run_config.clone().unwrap_or_default();
296
297 if let Some(obs) = observer {
298 let adapter = Arc::new(ObserverCallbackAdapter(obs));
299 let handlers = if let Some(mut manager) = run_config.callbacks.take() {
300 manager.add_handler(adapter);
302 manager
303 } else {
304 CallbackManager::new(vec![adapter])
305 };
306 run_config.callbacks = Some(handlers);
307 }
308
309 let run_config_option = Some(run_config);
310
311 struct StreamState<S: StateSchema> {
318 state: GraphState<S>,
319 step_count: usize,
320 recent: VecDeque<String>,
321 pending_events: VecDeque<GraphEvent<S>>,
322 effective: ExecutionConfig,
323 queue: VecDeque<(String, u64)>,
324 join_set: JoinSet<(String, Result<StateUpdate<S>, WesichainError>, u64)>,
325 start_time: std::time::Instant,
326 visit_counts: HashMap<String, u32>,
327 path_visits: HashMap<(String, u64), u32>,
328 active_tasks: HashSet<(String, u64)>,
330 callbacks: Option<(CallbackManager, RunContext)>,
331 callback_nodes: HashMap<(String, u64), RunContext>,
332 agent_event_sender: Option<mpsc::Sender<AgentEvent>>,
333 agent_event_thread_id: String,
334 agent_event_step: usize,
335 checkpoint_thread_id: Option<String>,
336 initialized: bool,
337 run_config: Option<wesichain_core::RunConfig>, observer: Option<Arc<dyn Observer>>,
339 }
340
341 if !self.nodes.contains_key(&self.entry) {
342 return stream::iter(vec![Ok(GraphEvent::Error(GraphError::MissingNode {
343 node: self.entry.clone(),
344 }))])
345 .boxed();
346 }
347
348 let effective = self.default_config.merge(&options);
349
350 let agent_event_thread_id = options
351 .agent_event_thread_id
352 .clone()
353 .or_else(|| checkpoint_thread_id.clone())
354 .unwrap_or_else(|| "graph".to_string());
355
356 let initial_queue = options
357 .initial_queue
358 .clone()
359 .map(VecDeque::from)
360 .unwrap_or_else(|| VecDeque::from([(self.entry.clone(), 0)]));
361
362 let initial_step = options.initial_step.unwrap_or(0);
363
364 let stream_state = StreamState {
365 state,
366 step_count: initial_step,
367 recent: VecDeque::new(),
368 pending_events: VecDeque::new(),
369 effective,
370 queue: initial_queue,
371 join_set: JoinSet::new(),
372 start_time: std::time::Instant::now(),
373 visit_counts: HashMap::new(),
374 path_visits: HashMap::new(),
375 active_tasks: HashSet::new(),
376 callbacks: None, callback_nodes: HashMap::new(),
378 agent_event_sender: options.agent_event_sender,
379 agent_event_thread_id,
380 agent_event_step: 0,
381 checkpoint_thread_id,
382 initialized: false,
383 run_config: run_config_option,
384 observer: options.observer,
385 };
386
387 stream::unfold(stream_state, move |mut ctx| async move {
388 loop {
389 if !ctx.initialized {
391 ctx.initialized = true;
392
393 if let Some(run_config) = ctx.run_config.take() {
395 if let Some(manager) = run_config.callbacks {
396 if !manager.is_noop() {
397 let name = run_config
398 .name_override
399 .unwrap_or_else(|| "graph_execution".to_string());
400 let root = RunContext::root(
401 RunType::Graph,
402 name,
403 run_config.tags,
404 run_config.metadata,
405 );
406 let inputs = ensure_object(ctx.state.to_trace_input());
407 manager.on_start(&root, &inputs).await;
408 ctx.callbacks = Some((manager, root));
409 }
410 }
411 }
412 }
413
414 if let Some(event) = ctx.pending_events.pop_front() {
416 return Some((Ok(event), ctx));
417 }
418
419 if let Some((current, path_id)) = ctx.queue.pop_front() {
421 if let Some(duration) = ctx.effective.max_duration {
424 if ctx.start_time.elapsed() > duration {
425 let error = GraphError::Timeout {
426 node: "global".to_string(),
427 elapsed: ctx.start_time.elapsed(),
428 };
429 if let Some((manager, root)) = &ctx.callbacks {
431 let error_value =
432 ensure_object(error.to_string().to_trace_output());
433 let duration_ms = root.start_instant.elapsed().as_millis();
434 manager.on_error(root, &error_value, duration_ms).await;
435 }
436
437 ctx.join_set.shutdown().await;
438 ctx.pending_events.push_back(GraphEvent::Error(error));
439 continue;
440 }
441 }
442
443 if let Some(max) = ctx.effective.max_steps {
445 if ctx.step_count >= max {
446 let error = GraphError::MaxStepsExceeded {
447 max,
448 reached: ctx.step_count,
449 };
450 if let Some((manager, root)) = &ctx.callbacks {
451 let error_value =
452 ensure_object(error.to_string().to_trace_output());
453 let duration_ms = root.start_instant.elapsed().as_millis();
454 manager.on_error(root, &error_value, duration_ms).await;
455 }
456 ctx.join_set.shutdown().await;
457 ctx.pending_events.push_back(GraphEvent::Error(error));
458 continue;
459 }
460 }
461
462 if let Some(max_visits) = ctx.effective.max_visits {
464 let count = ctx.visit_counts.entry(current.clone()).or_insert(0);
465 *count += 1;
466 if *count > max_visits {
467 let error = GraphError::MaxVisitsExceeded {
468 node: current.clone(),
469 max: max_visits,
470 };
471 if let Some((manager, root)) = &ctx.callbacks {
472 let error_value =
473 ensure_object(error.to_string().to_trace_output());
474 let duration_ms = root.start_instant.elapsed().as_millis();
475 manager.on_error(root, &error_value, duration_ms).await;
476 }
477 ctx.join_set.shutdown().await;
478 ctx.pending_events.push_back(GraphEvent::Error(error));
479 continue;
480 }
481 }
482
483 if let Some(max_loops) = ctx.effective.max_loop_iterations {
485 let key = (current.clone(), path_id);
486 let count = ctx.path_visits.entry(key).or_insert(0);
487 *count += 1;
488 if *count > max_loops {
489 let error = GraphError::MaxLoopIterationsExceeded {
490 node: current.clone(),
491 max: max_loops,
492 path_id,
493 };
494 if let Some((manager, root)) = &ctx.callbacks {
495 let error_value =
496 ensure_object(error.to_string().to_trace_output());
497 let duration_ms = root.start_instant.elapsed().as_millis();
498 manager.on_error(root, &error_value, duration_ms).await;
499 }
500 ctx.join_set.shutdown().await;
501 ctx.pending_events.push_back(GraphEvent::Error(error));
502 continue;
503 }
504 }
505
506 ctx.step_count += 1;
507
508 if ctx.effective.cycle_detection {
510 if ctx.recent.len() == ctx.effective.cycle_window {
511 ctx.recent.pop_front();
512 }
513 ctx.recent.push_back(current.clone());
514 let count = ctx.recent.iter().filter(|node| **node == current).count();
515 if count >= 2 {
516 let error = GraphError::CycleDetected {
517 node: current.clone(),
518 recent: ctx.recent.iter().cloned().collect(),
519 };
520 if let Some((manager, root)) = &ctx.callbacks {
521 let error_value =
522 ensure_object(error.to_string().to_trace_output());
523 let duration_ms = root.start_instant.elapsed().as_millis();
524 manager.on_error(root, &error_value, duration_ms).await;
525 }
526 ctx.join_set.shutdown().await;
527 ctx.pending_events.push_back(GraphEvent::Error(error));
528 continue;
529 }
530 }
531
532 if ctx.effective.interrupt_before.contains(¤t)
534 || self.interrupt_before.contains(¤t)
535 {
536 let error = GraphError::Interrupted;
537 if let Some((manager, root)) = &ctx.callbacks {
538 let error_value = ensure_object(error.to_string().to_trace_output());
539 let duration_ms = root.start_instant.elapsed().as_millis();
540 manager.on_error(root, &error_value, duration_ms).await;
541 }
542
543 if let (Some((checkpointer, _)), Some(thread_id)) = (
545 self.checkpointer.as_ref(),
546 ctx.checkpoint_thread_id.as_deref(),
547 ) {
548 let mut full_queue = ctx.queue.iter().cloned().collect::<Vec<_>>();
549 full_queue.push((current.clone(), path_id));
550 full_queue.extend(ctx.active_tasks.iter().cloned());
551
552 let checkpoint = Checkpoint::new(
553 thread_id.to_string(),
554 ctx.state.clone(),
555 ctx.step_count as u64,
556 current.clone(),
557 full_queue,
558 );
559 if let Err(e) = checkpointer.save(&checkpoint).await {
560 let graph_err = GraphError::from(e);
561 if let Some((manager, root)) = &ctx.callbacks {
562 let error_value =
563 ensure_object(graph_err.to_string().to_trace_output());
564 let duration_ms = root.start_instant.elapsed().as_millis();
565 manager.on_error(root, &error_value, duration_ms).await;
566 }
567 ctx.pending_events.push_back(GraphEvent::Error(graph_err));
568 } else {
569 ctx.pending_events.push_back(GraphEvent::CheckpointSaved {
570 node: current.clone(),
571 timestamp: Utc::now().timestamp_millis() as u64,
572 });
573 if let Some((manager, root)) = &ctx.callbacks {
574 manager
576 .on_event(
577 root,
578 "checkpoint_saved",
579 &json!({"node_id": current}),
580 )
581 .await;
582 }
583 }
584 }
585
586 ctx.join_set.shutdown().await;
587 ctx.pending_events.push_back(GraphEvent::Error(error));
588 continue;
589 }
590
591 let node = match self.nodes.get(¤t) {
593 Some(node) => node.clone(),
594 None => {
595 let error = GraphError::InvalidEdge {
596 node: current.clone(),
597 };
598 ctx.pending_events.push_back(GraphEvent::Error(error));
600 continue;
601 }
602 };
603
604 emit_status_event(
606 &ctx.agent_event_sender,
607 &mut ctx.agent_event_step,
608 &ctx.agent_event_thread_id,
609 "node_start",
610 format!("Starting node {current}"),
611 )
612 .await;
613
614 if let Some((manager, root)) = &ctx.callbacks {
615 let node_ctx = root.child(RunType::Chain, current.clone());
616 let node_inputs = ensure_object(ctx.state.to_trace_input());
617 manager.on_start(&node_ctx, &node_inputs).await;
618 ctx.callback_nodes
619 .insert((current.clone(), path_id), node_ctx);
620 }
621
622 ctx.pending_events.push_back(GraphEvent::NodeEnter {
623 node: current.clone(),
624 timestamp: Utc::now().timestamp_millis() as u64,
625 });
626
627 let input_state = ctx.state.clone();
629 let node_ctx_obs = ctx.observer.clone();
631 let node_id = current.clone();
632 let effective_config_spawn = ctx.effective.clone();
633 let remaining = effective_config_spawn
634 .max_steps
635 .map(|m| m.saturating_sub(ctx.step_count)); let context = GraphContext {
638 remaining_steps: remaining,
639 observer: node_ctx_obs,
640 node_id: node_id.clone(),
641 };
642
643 ctx.active_tasks.insert((current.clone(), path_id));
644
645 ctx.join_set.spawn(async move {
647 let future = node.invoke_with_context(input_state, &context);
648 let result = if let Some(timeout) = effective_config_spawn.node_timeout {
649 match tokio::time::timeout(timeout, future).await {
650 Ok(res) => res,
651 Err(_) => Err(WesichainError::Custom(format!(
652 "Node {} timed out after {:?}",
653 node_id, timeout
654 ))),
655 }
656 } else {
657 future.await
658 };
659 (current, result, path_id)
660 });
661
662 continue; }
664
665 if !ctx.join_set.is_empty() {
667 if let Some(join_res) = ctx.join_set.join_next().await {
668 let (current, invoke_res, path_id) = match join_res {
669 Ok(r) => r,
670 Err(err) => {
671 let error = GraphError::System(err.to_string());
672 ctx.join_set.shutdown().await;
673 ctx.pending_events.push_back(GraphEvent::Error(error));
674 continue;
675 }
676 };
677
678 ctx.active_tasks.remove(&(current.clone(), path_id));
679
680 match invoke_res {
681 Ok(update) => {
682 let output_debug =
684 serde_json::to_string(&update).unwrap_or_default();
685 ctx.state = ctx.state.apply_update(update.clone());
686
687 ctx.pending_events.push_back(GraphEvent::NodeFinished {
688 node: current.clone(),
689 output: output_debug,
690 timestamp: Utc::now().timestamp_millis() as u64,
691 });
692
693 ctx.pending_events
695 .push_back(GraphEvent::StateUpdate(update));
696
697 if let Some((manager, _root)) = &ctx.callbacks {
699 if let Some(node_ctx) =
700 ctx.callback_nodes.remove(&(current.clone(), path_id))
701 {
702 let node_outputs =
703 ensure_object(ctx.state.to_trace_output());
704 let duration_ms =
705 node_ctx.start_instant.elapsed().as_millis();
706 manager.on_end(&node_ctx, &node_outputs, duration_ms).await;
707 }
708 }
709 emit_status_event(
711 &ctx.agent_event_sender,
712 &mut ctx.agent_event_step,
713 &ctx.agent_event_thread_id,
714 "node_end",
715 format!("Completed node {current}"),
716 )
717 .await;
718
719 ctx.pending_events.push_back(GraphEvent::NodeExit {
720 node: current.clone(),
721 timestamp: Utc::now().timestamp_millis() as u64,
722 });
723
724 if let Some(condition) = self.conditional.get(¤t) {
726 let targets = condition(&ctx.state);
727 let next_paths: Vec<(String, u64)> = if targets.len() > 1 {
728 targets
729 .into_iter()
730 .map(|t| {
731 if t == END {
732 (t, path_id)
733 } else {
734 let h = stable_hash(&(path_id, &t));
735 (t, h)
736 }
737 })
738 .collect()
739 } else {
740 targets.into_iter().map(|t| (t, path_id)).collect()
741 };
742
743 for (next, next_path_id) in next_paths {
744 if next == END {
745 continue;
746 }
747 if !self.nodes.contains_key(&next) {
748 let error =
750 GraphError::InvalidEdge { node: next.clone() };
751 ctx.pending_events.push_back(GraphEvent::Error(error));
752 ctx.join_set.shutdown().await;
753 continue; }
755 ctx.queue.push_back((next, next_path_id));
756 }
757 } else if let Some(targets) = self.edges.get(¤t) {
758 let next_paths: Vec<(String, u64)> = if targets.len() > 1 {
759 targets
760 .iter()
761 .map(|t| {
762 if *t == END {
763 (t.clone(), path_id)
764 } else {
765 (t.clone(), stable_hash(&(path_id, t)))
766 }
767 })
768 .collect()
769 } else {
770 targets.iter().cloned().map(|t| (t, path_id)).collect()
771 };
772
773 for (next, next_path_id) in next_paths {
774 if next == END {
775 continue;
776 }
777 if !self.nodes.contains_key(&next) {
778 let error =
779 GraphError::InvalidEdge { node: next.clone() };
780 ctx.pending_events.push_back(GraphEvent::Error(error));
781 ctx.join_set.shutdown().await;
782 continue;
783 }
784 ctx.queue.push_back((next, next_path_id));
785 }
786 }
787
788 if let (Some((checkpointer, _)), Some(thread_id)) = (
790 self.checkpointer.as_ref(),
791 ctx.checkpoint_thread_id.as_deref(),
792 ) {
793 let mut full_queue =
794 ctx.queue.iter().cloned().collect::<Vec<_>>();
795 full_queue.extend(ctx.active_tasks.iter().cloned());
796
797 let checkpoint = Checkpoint::new(
798 thread_id.to_string(),
799 ctx.state.clone(),
800 ctx.step_count as u64,
801 current.clone(),
802 full_queue,
803 );
804
805 if let Err(e) = checkpointer.save(&checkpoint).await {
806 let graph_err = GraphError::from(e);
807 if let Some((manager, root)) = &ctx.callbacks {
808 let error_value = ensure_object(
809 graph_err.to_string().to_trace_output(),
810 );
811 let duration_ms =
812 root.start_instant.elapsed().as_millis();
813 manager.on_error(root, &error_value, duration_ms).await;
814 }
815 ctx.pending_events.push_back(GraphEvent::Error(graph_err));
816 ctx.join_set.shutdown().await;
817 continue;
818 } else {
819 ctx.pending_events.push_back(GraphEvent::CheckpointSaved {
820 node: current.clone(),
821 timestamp: Utc::now().timestamp_millis() as u64,
822 });
823
824 if let Some((manager, root)) = &ctx.callbacks {
825 manager
827 .on_event(
828 root,
829 "checkpoint_saved",
830 &json!({"node_id": current}),
831 )
832 .await;
833 }
834 }
835 }
836
837 if ctx.effective.interrupt_after.contains(¤t)
839 || self.interrupt_after.contains(¤t)
840 {
841 let error = GraphError::Interrupted;
842 if let Some((manager, root)) = &ctx.callbacks {
843 let error_value =
844 ensure_object(error.to_string().to_trace_output());
845 let duration_ms = root.start_instant.elapsed().as_millis();
846 manager.on_error(root, &error_value, duration_ms).await;
847 }
848 ctx.pending_events.push_back(GraphEvent::Error(error));
849 continue;
850 }
851 }
852 Err(e) => {
853 let error = GraphError::NodeFailed {
855 node: current.clone(),
856 source: Box::new(e),
857 };
858 if let Some((manager, _root)) = &ctx.callbacks {
859 if let Some(node_ctx) =
860 ctx.callback_nodes.remove(&(current.clone(), path_id))
861 {
862 let error_value =
863 ensure_object(error.to_string().to_trace_output());
864 let duration_ms =
865 node_ctx.start_instant.elapsed().as_millis();
866 manager
867 .on_error(&node_ctx, &error_value, duration_ms)
868 .await;
869 }
870 }
871 ctx.join_set.shutdown().await;
872 ctx.pending_events.push_back(GraphEvent::Error(error));
873 continue;
874 }
875 }
876 }
877 } else if ctx.queue.is_empty() {
878 if let Some((manager, root)) = &ctx.callbacks {
880 let outputs = ensure_object(ctx.state.to_trace_output());
881 let duration_ms = root.start_instant.elapsed().as_millis();
882 manager.on_end(root, &outputs, duration_ms).await;
883 }
884
885 emit_status_event(
886 &ctx.agent_event_sender,
887 &mut ctx.agent_event_step,
888 &ctx.agent_event_thread_id,
889 "completed",
890 "Graph execution completed",
891 )
892 .await;
893
894 return None;
895 }
896 }
897 })
898 .boxed()
899 }
900
901 pub async fn invoke_graph_with_options(
902 &self,
903 mut state: GraphState<S>,
904 mut options: ExecutionOptions,
905 ) -> Result<GraphState<S>, GraphError> {
906 let checkpoint_thread_id = options.checkpoint_thread_id.clone().or_else(|| {
907 self.checkpointer
908 .as_ref()
909 .map(|(_, thread_id)| thread_id.clone())
910 });
911
912 let agent_event_sender = options.agent_event_sender.clone();
913 let _agent_event_thread_id = options
914 .agent_event_thread_id
915 .clone()
916 .or_else(|| checkpoint_thread_id.clone())
917 .unwrap_or_else(|| "graph".to_string());
918 let mut agent_event_step = 0usize;
919
920 if options.auto_resume {
921 if let (Some((checkpointer, _)), Some(thread_id)) =
922 (self.checkpointer.as_ref(), checkpoint_thread_id.as_deref())
923 {
924 match checkpointer.load(thread_id).await {
925 Ok(Some(saved)) => {
926 state = saved.state;
927 if !saved.queue.is_empty() {
929 options.initial_queue = Some(saved.queue);
930 options.initial_step = Some(saved.step as usize + 1);
931 } else {
932 }
936 }
937 Ok(None) => {}
938 Err(error) => return Err(error.into()),
939 }
940 }
941 }
942
943 if !self.nodes.contains_key(&self.entry) {
944 let error = GraphError::MissingNode {
945 node: self.entry.clone(),
946 };
947 emit_error_event(
948 &agent_event_sender,
949 &mut agent_event_step,
950 error.to_string(),
951 Some("graph".to_string()),
952 )
953 .await;
954 return Err(error);
955 }
956
957 let mut stream = self.stream_invoke_with_options(state.clone(), options);
958
959 while let Some(event) = stream.next().await {
960 match event {
961 Ok(GraphEvent::StateUpdate(update)) => {
962 state = state.apply_update(update);
963 }
964 Ok(GraphEvent::Error(e)) | Err(e) => return Err(e),
965 _ => {}
968 }
969 }
970
971 Ok(state)
972 }
973
974 pub async fn invoke(&self, state: GraphState<S>) -> Result<GraphState<S>, WesichainError> {
975 self.invoke_graph(state)
976 .await
977 .map_err(|err| WesichainError::Custom(err.to_string()))
978 }
979
980 pub async fn invoke_with_options(
981 &self,
982 state: GraphState<S>,
983 options: ExecutionOptions,
984 ) -> Result<GraphState<S>, WesichainError> {
985 self.invoke_graph_with_options(state, options)
986 .await
987 .map_err(|err| WesichainError::Custom(err.to_string()))
988 }
989
990 pub async fn get_state(&self, thread_id: &str) -> Result<Option<GraphState<S>>, GraphError> {
991 if let Some((checkpointer, _)) = &self.checkpointer {
992 let checkpoint = checkpointer.load(thread_id).await?;
993 Ok(checkpoint.map(|cp| cp.state))
994 } else {
995 Ok(None)
996 }
997 }
998
999 pub async fn resume(
1000 &self,
1001 checkpoint: Checkpoint<S>,
1002 mut options: ExecutionOptions,
1003 ) -> Result<GraphState<S>, GraphError> {
1004 options.initial_queue = Some(checkpoint.queue);
1005 options.initial_step = Some(checkpoint.step as usize + 1);
1007 self.invoke_graph_with_options(checkpoint.state, options)
1008 .await
1009 }
1010
1011 pub async fn update_state(
1012 &self,
1013 thread_id: &str,
1014 values: S,
1015 as_node: Option<String>,
1016 ) -> Result<(), GraphError> {
1017 if let Some((checkpointer, _)) = &self.checkpointer {
1018 let (mut state, step) = if let Some(checkpoint) = checkpointer.load(thread_id).await? {
1020 (checkpoint.state, checkpoint.step + 1)
1021 } else {
1022 (GraphState::new(S::default()), 1)
1023 };
1024
1025 let update = StateUpdate::new(values);
1027 state = state.apply_update(update);
1028
1029 let node = as_node.unwrap_or_else(|| "user".to_string());
1031 let checkpoint = Checkpoint::new(thread_id.to_string(), state, step, node, vec![]);
1032 checkpointer.save(&checkpoint).await?;
1033 Ok(())
1034 } else {
1035 Err(GraphError::Checkpoint("Checkpointer not configured".into()))
1036 }
1037 }
1038}
1039
1040#[async_trait::async_trait]
1041impl<S: StateSchema<Update = S>> Runnable<GraphState<S>, StateUpdate<S>> for ExecutableGraph<S> {
1042 async fn invoke(&self, input: GraphState<S>) -> Result<StateUpdate<S>, WesichainError> {
1043 let result = self
1044 .invoke_graph(input)
1045 .await
1046 .map_err(|e| WesichainError::Custom(e.to_string()))?;
1047 Ok(StateUpdate::new(result.data))
1048 }
1049
1050 fn stream<'a>(
1051 &'a self,
1052 input: GraphState<S>,
1053 ) -> BoxStream<'a, Result<wesichain_core::StreamEvent, WesichainError>> {
1054 let stream = self.stream_invoke(input);
1055
1056 stream
1057 .filter_map(|event_res| async move {
1058 match event_res {
1059 Ok(GraphEvent::Error(e)) | Err(e) => {
1060 Some(Err(WesichainError::Custom(e.to_string())))
1061 }
1062 _ => None,
1066 }
1067 })
1068 .boxed()
1069 }
1070}
1071
#[cfg(test)]
mod tests {
    use super::*;

    /// `stable_hash` must be deterministic and equal to hashing with the
    /// fixed-seed ahash state it wraps.
    #[test]
    fn test_stable_path_hashing() {
        let parent_id = 12345u64;
        let node_name = "test_node";

        let hash1 = stable_hash(&(parent_id, node_name));
        let expected1 = stable_hash(&(parent_id, node_name));
        assert_eq!(hash1, expected1, "Hash MUST be deterministic");

        let seeded = RandomState::with_seeds(0x517cc1b727220a95, 0x6ed9eba1999cd92d, 0, 0);
        assert_eq!(
            hash1,
            seeded.hash_one((parent_id, node_name)),
            "stable_hash must use the fixed seeds"
        );

        let different_hash =
            RandomState::with_seeds(123, 456, 0, 0).hash_one((parent_id, node_name));
        assert_ne!(hash1, different_hash, "Should differ from arbitrary keys");
    }

    /// Routing hashes `(path_id, &String)`; that must agree with the `&str`
    /// form, and different parents must fork into different child ids.
    #[test]
    fn test_stable_path_hashing_matches_routing_call_sites() {
        let parent_id = 7u64;
        let owned = String::from("worker");

        assert_eq!(
            stable_hash(&(parent_id, &owned)),
            stable_hash(&(parent_id, "worker"))
        );
        assert_ne!(
            stable_hash(&(parent_id, "worker")),
            stable_hash(&(parent_id + 1, "worker"))
        );
    }
}