Skip to main content

wesichain_graph/
graph.rs

1use ahash::RandomState;
2use std::collections::{HashMap, HashSet, VecDeque};
3use std::hash::Hash;
4use std::sync::Arc;
5
6use chrono::Utc;
7use futures::stream::{self, BoxStream, StreamExt};
8use petgraph::graph::Graph;
9use tokio::sync::mpsc;
10use tokio::task::JoinSet;
11
12use crate::observer::ObserverCallbackAdapter;
13use crate::{
14    Checkpoint, Checkpointer, EdgeKind, ExecutionConfig, ExecutionOptions, GraphError, GraphEvent,
15    GraphProgram, GraphState, NodeData, Observer, StateSchema, StateUpdate, END, START,
16};
17use serde_json::json;
18use wesichain_core::{
19    ensure_object, AgentEvent, CallbackManager, RunContext, RunType, Runnable, ToTraceInput,
20    ToTraceOutput, WesichainError,
21};
22
/// Routing function registered with `GraphBuilder::add_conditional_edge`.
///
/// The executor calls it with the graph state *after* the source node's update has been
/// applied, and queues the returned node names to run next. When more than one name is
/// returned, each non-`END` target is given its own derived path id (fan-out); `END`
/// entries are skipped, and names that are not registered nodes are reported as
/// `GraphError::InvalidEdge`.
pub type Condition<S> = Box<dyn Fn(&GraphState<S>) -> Vec<String> + Send + Sync>;
24
/// Per-invocation context handed to [`GraphNode::invoke_with_context`].
pub struct GraphContext {
    /// Steps still allowed under `max_steps` when this node was scheduled, or `None` when
    /// no step limit is configured. Computed at scheduling time, so it is only an
    /// approximation while several nodes run concurrently.
    pub remaining_steps: Option<usize>,
    /// Observer for this run. As populated by `stream_invoke_with_options`, this comes from
    /// `ExecutionOptions::observer` only; a builder-level observer is not forwarded here.
    pub observer: Option<Arc<dyn Observer>>,
    /// Name of the node being invoked.
    pub node_id: String,
}
30
31async fn emit_status_event(
32    sender: &Option<mpsc::Sender<AgentEvent>>,
33    step: &mut usize,
34    thread_id: &str,
35    stage: impl Into<String>,
36    message: impl Into<String>,
37) {
38    if let Some(sender) = sender {
39        *step += 1;
40        let _ = sender
41            .send(AgentEvent::Status {
42                stage: stage.into(),
43                message: message.into(),
44                step: *step,
45                thread_id: thread_id.to_string(),
46            })
47            .await;
48    }
49}
50
51async fn emit_error_event(
52    sender: &Option<mpsc::Sender<AgentEvent>>,
53    step: &mut usize,
54    message: impl Into<String>,
55    source: Option<String>,
56) {
57    if let Some(sender) = sender {
58        *step += 1;
59        let _ = sender
60            .send(AgentEvent::Error {
61                message: message.into(),
62                step: *step,
63                recoverable: false,
64                source,
65            })
66            .await;
67    }
68}
69
/// A node in the graph: an async step that maps the current state to a state update.
///
/// Any `Runnable<GraphState<S>, StateUpdate<S>>` is a `GraphNode` through the blanket impl
/// in this module. Implement this trait directly when the node needs the [`GraphContext`].
#[async_trait::async_trait]
pub trait GraphNode<S: StateSchema>: Send + Sync {
    /// Runs the node.
    ///
    /// `input` is a clone of the graph state taken when the node was scheduled. The
    /// returned update is applied to the shared state by the executor. When a
    /// `node_timeout` is configured, the executor stops waiting after that duration and
    /// reports a `WesichainError::Custom` timeout error instead.
    async fn invoke_with_context(
        &self,
        input: GraphState<S>,
        context: &GraphContext,
    ) -> Result<StateUpdate<S>, WesichainError>;
}
78
79#[async_trait::async_trait]
80impl<S, R> GraphNode<S> for R
81where
82    S: StateSchema,
83    R: Runnable<GraphState<S>, StateUpdate<S>> + Send + Sync,
84{
85    async fn invoke_with_context(
86        &self,
87        input: GraphState<S>,
88        _context: &GraphContext,
89    ) -> Result<StateUpdate<S>, WesichainError> {
90        self.invoke(input).await
91    }
92}
93
/// Fluent builder for an [`ExecutableGraph`] (via `build`) or a static [`GraphProgram`]
/// (via `build_program`).
pub struct GraphBuilder<S: StateSchema> {
    /// Registered nodes, keyed by name.
    nodes: HashMap<String, Arc<dyn GraphNode<S>>>,
    /// Static successors of each node, in insertion order.
    edges: HashMap<String, Vec<String>>,
    /// At most one conditional router per node (later registrations replace earlier ones).
    conditional: HashMap<String, Condition<S>>,
    /// Optional checkpointer together with its default thread id.
    checkpointer: Option<(Box<dyn Checkpointer<S>>, String)>,
    /// Default observer, used when a run's options do not supply one.
    observer: Option<Arc<dyn Observer>>,
    /// Execution config merged with each run's `ExecutionOptions`.
    default_config: ExecutionConfig,
    /// Starting node; `build` panics if it was never set.
    entry: Option<String>,
    /// Node names before which execution is interrupted.
    interrupt_before: Vec<String>,
    /// Node names registered with `with_interrupt_after`; forwarded unchanged by `build`.
    interrupt_after: Vec<String>,
}
105
106impl<S: StateSchema> Default for GraphBuilder<S> {
107    fn default() -> Self {
108        Self::new()
109    }
110}
111
112impl<S: StateSchema> GraphBuilder<S> {
113    pub fn new() -> Self {
114        Self {
115            nodes: HashMap::new(),
116            edges: HashMap::new(),
117            conditional: HashMap::new(),
118            checkpointer: None,
119            observer: None,
120            default_config: ExecutionConfig::default(),
121            entry: None,
122            interrupt_before: Vec::new(),
123            interrupt_after: Vec::new(),
124        }
125    }
126
127    pub fn add_node<R>(mut self, name: &str, node: R) -> Self
128    where
129        R: GraphNode<S> + 'static,
130    {
131        self.nodes.insert(name.to_string(), Arc::new(node));
132        self
133    }
134
135    pub fn set_entry(mut self, name: &str) -> Self {
136        self.entry = Some(name.to_string());
137        self
138    }
139
140    pub fn add_edge(mut self, from: &str, to: &str) -> Self {
141        self.edges
142            .entry(from.to_string())
143            .or_default()
144            .push(to.to_string());
145        self
146    }
147
148    pub fn add_edges(mut self, from: &str, targets: &[&str]) -> Self {
149        let entry = self.edges.entry(from.to_string()).or_default();
150        for target in targets {
151            entry.push(target.to_string());
152        }
153        self
154    }
155
156    pub fn add_conditional_edge<F>(mut self, from: &str, condition: F) -> Self
157    where
158        F: Fn(&GraphState<S>) -> Vec<String> + Send + Sync + 'static,
159    {
160        self.conditional
161            .insert(from.to_string(), Box::new(condition));
162        self
163    }
164    #[deprecated(since = "0.3.0", note = "Use `with_default_config` instead")]
165    pub fn with_config(mut self, config: ExecutionConfig) -> Self {
166        self.default_config = config;
167        self
168    }
169
170    pub fn with_checkpointer<C>(mut self, checkpointer: C, thread_id: &str) -> Self
171    where
172        C: Checkpointer<S> + 'static,
173    {
174        self.checkpointer = Some((Box::new(checkpointer), thread_id.to_string()));
175        self
176    }
177
178    pub fn with_observer(mut self, observer: Arc<dyn Observer>) -> Self {
179        self.observer = Some(observer);
180        self
181    }
182
183    pub fn with_default_config(mut self, config: ExecutionConfig) -> Self {
184        self.default_config = config;
185        self
186    }
187
188    pub fn with_interrupt_before<I, S2>(mut self, nodes: I) -> Self
189    where
190        I: IntoIterator<Item = S2>,
191        S2: Into<String>,
192    {
193        self.interrupt_before = nodes.into_iter().map(Into::into).collect();
194        self
195    }
196
197    pub fn with_interrupt_after<I, S2>(mut self, nodes: I) -> Self
198    where
199        I: IntoIterator<Item = S2>,
200        S2: Into<String>,
201    {
202        self.interrupt_after = nodes.into_iter().map(Into::into).collect();
203        self
204    }
205
206    pub fn build(self) -> ExecutableGraph<S> {
207        ExecutableGraph {
208            nodes: self.nodes,
209            edges: self.edges,
210            conditional: self.conditional,
211            checkpointer: self.checkpointer,
212            observer: self.observer,
213            default_config: self.default_config,
214            entry: self.entry.expect("entry"),
215            interrupt_before: self.interrupt_before,
216            interrupt_after: self.interrupt_after,
217        }
218    }
219
220    pub fn build_program(self) -> GraphProgram<S> {
221        let GraphBuilder { nodes, edges, .. } = self;
222        let mut graph = Graph::new();
223        let mut name_to_index = HashMap::new();
224
225        for (name, runnable) in nodes {
226            let index = graph.add_node(NodeData {
227                name: name.clone(),
228                runnable,
229            });
230            name_to_index.insert(name, index);
231        }
232
233        for (from, targets) in edges.iter() {
234            if from == START {
235                continue;
236            }
237            if let Some(from_idx) = name_to_index.get(from) {
238                for to in targets {
239                    if to == END {
240                        continue;
241                    }
242                    if let Some(to_idx) = name_to_index.get(to) {
243                        graph.add_edge(*from_idx, *to_idx, EdgeKind::Default);
244                    }
245                }
246            }
247        }
248
249        GraphProgram::new(graph, name_to_index)
250    }
251}
252
253fn stable_hash<T: Hash + ?Sized>(t: &T) -> u64 {
254    RandomState::with_seeds(0x517cc1b727220a95, 0x6ed9eba1999cd92d, 0, 0).hash_one(t)
255}
256
/// A runnable graph produced by [`GraphBuilder::build`].
pub struct ExecutableGraph<S: StateSchema> {
    /// Node implementations, keyed by name.
    nodes: HashMap<String, Arc<dyn GraphNode<S>>>,
    /// Static successors per node; consulted only when no conditional router exists.
    edges: HashMap<String, Vec<String>>,
    /// Conditional routers per node; these take precedence over `edges`.
    conditional: HashMap<String, Condition<S>>,
    /// Optional checkpointer plus the default thread id, used when a run's options do not
    /// specify `checkpoint_thread_id`.
    checkpointer: Option<(Box<dyn Checkpointer<S>>, String)>,
    /// Default observer; a run's `ExecutionOptions::observer` takes precedence.
    observer: Option<Arc<dyn Observer>>,
    /// Builder-level config, merged with each run's `ExecutionOptions`.
    default_config: ExecutionConfig,
    /// Node queued first, unless the run supplies its own `initial_queue`.
    entry: String,
    /// Nodes that interrupt the run before executing; checked alongside the run config's
    /// own `interrupt_before` list.
    interrupt_before: Vec<String>,
    /// Node names registered with `GraphBuilder::with_interrupt_after`.
    interrupt_after: Vec<String>,
}
268
269impl<S: StateSchema<Update = S>> ExecutableGraph<S> {
270    pub async fn invoke_graph(&self, state: GraphState<S>) -> Result<GraphState<S>, GraphError> {
271        self.invoke_graph_with_options(state, ExecutionOptions::default())
272            .await
273    }
274
275    pub fn stream_invoke(
276        &self,
277        state: GraphState<S>,
278    ) -> BoxStream<'_, Result<GraphEvent<S>, GraphError>> {
279        self.stream_invoke_with_options(state, ExecutionOptions::default())
280    }
281
282    pub fn stream_invoke_with_options(
283        &self,
284        state: GraphState<S>,
285        options: ExecutionOptions,
286    ) -> BoxStream<'_, Result<GraphEvent<S>, GraphError>> {
287        let checkpoint_thread_id = options.checkpoint_thread_id.clone().or_else(|| {
288            self.checkpointer
289                .as_ref()
290                .map(|(_, thread_id)| thread_id.clone())
291        });
292
293        // Initialize Callbacks and Observer (Unified)
294        let observer = options.observer.clone().or_else(|| self.observer.clone());
295        let mut run_config = options.run_config.clone().unwrap_or_default();
296
297        if let Some(obs) = observer {
298            let adapter = Arc::new(ObserverCallbackAdapter(obs));
299            let handlers = if let Some(mut manager) = run_config.callbacks.take() {
300                // Merge the adapter into the existing CallbackManager
301                manager.add_handler(adapter);
302                manager
303            } else {
304                CallbackManager::new(vec![adapter])
305            };
306            run_config.callbacks = Some(handlers);
307        }
308
309        let run_config_option = Some(run_config);
310
311        // We need to run initialization async to call on_start
312        // Since stream::unfold expects an initial state, we'll use a wrapper enum or
313        // handle initialization in the first step of the loop.
314        // Or better yet, we can't easily do async setup *outside* the stream if we return a stream immediately.
315        // So we'll trigger the start events in the first iteration.
316
317        struct StreamState<S: StateSchema> {
318            state: GraphState<S>,
319            step_count: usize,
320            recent: VecDeque<String>,
321            pending_events: VecDeque<GraphEvent<S>>,
322            effective: ExecutionConfig,
323            queue: VecDeque<(String, u64)>,
324            join_set: JoinSet<(String, Result<StateUpdate<S>, WesichainError>, u64)>,
325            start_time: std::time::Instant,
326            visit_counts: HashMap<String, u32>,
327            path_visits: HashMap<(String, u64), u32>,
328            // Unified fields
329            active_tasks: HashSet<(String, u64)>,
330            callbacks: Option<(CallbackManager, RunContext)>,
331            callback_nodes: HashMap<(String, u64), RunContext>,
332            agent_event_sender: Option<mpsc::Sender<AgentEvent>>,
333            agent_event_thread_id: String,
334            agent_event_step: usize,
335            checkpoint_thread_id: Option<String>,
336            initialized: bool,
337            run_config: Option<wesichain_core::RunConfig>, // Store for delayed init
338            observer: Option<Arc<dyn Observer>>,
339        }
340
341        if !self.nodes.contains_key(&self.entry) {
342            return stream::iter(vec![Ok(GraphEvent::Error(GraphError::MissingNode {
343                node: self.entry.clone(),
344            }))])
345            .boxed();
346        }
347
348        let effective = self.default_config.merge(&options);
349
350        let agent_event_thread_id = options
351            .agent_event_thread_id
352            .clone()
353            .or_else(|| checkpoint_thread_id.clone())
354            .unwrap_or_else(|| "graph".to_string());
355
356        let initial_queue = options
357            .initial_queue
358            .clone()
359            .map(VecDeque::from)
360            .unwrap_or_else(|| VecDeque::from([(self.entry.clone(), 0)]));
361
362        let initial_step = options.initial_step.unwrap_or(0);
363
364        let stream_state = StreamState {
365            state,
366            step_count: initial_step,
367            recent: VecDeque::new(),
368            pending_events: VecDeque::new(),
369            effective,
370            queue: initial_queue,
371            join_set: JoinSet::new(),
372            start_time: std::time::Instant::now(),
373            visit_counts: HashMap::new(),
374            path_visits: HashMap::new(),
375            active_tasks: HashSet::new(),
376            callbacks: None, // Will init in loop
377            callback_nodes: HashMap::new(),
378            agent_event_sender: options.agent_event_sender,
379            agent_event_thread_id,
380            agent_event_step: 0,
381            checkpoint_thread_id,
382            initialized: false,
383            run_config: run_config_option,
384            observer: options.observer,
385        };
386
387        stream::unfold(stream_state, move |mut ctx| async move {
388            loop {
389                // 1. delayed initialization (on first poll)
390                if !ctx.initialized {
391                    ctx.initialized = true;
392
393                    // Initialize callbacks
394                    if let Some(run_config) = ctx.run_config.take() {
395                        if let Some(manager) = run_config.callbacks {
396                            if !manager.is_noop() {
397                                let name = run_config
398                                    .name_override
399                                    .unwrap_or_else(|| "graph_execution".to_string());
400                                let root = RunContext::root(
401                                    RunType::Graph,
402                                    name,
403                                    run_config.tags,
404                                    run_config.metadata,
405                                );
406                                let inputs = ensure_object(ctx.state.to_trace_input());
407                                manager.on_start(&root, &inputs).await;
408                                ctx.callbacks = Some((manager, root));
409                            }
410                        }
411                    }
412                }
413
414                // 2. Emit pending events
415                if let Some(event) = ctx.pending_events.pop_front() {
416                    return Some((Ok(event), ctx));
417                }
418
419                // 3. Process Queue
420                if let Some((current, path_id)) = ctx.queue.pop_front() {
421                    // Safety Checks
422                    // Global Timer
423                    if let Some(duration) = ctx.effective.max_duration {
424                        if ctx.start_time.elapsed() > duration {
425                            let error = GraphError::Timeout {
426                                node: "global".to_string(),
427                                elapsed: ctx.start_time.elapsed(),
428                            };
429                            // callbacks error
430                            if let Some((manager, root)) = &ctx.callbacks {
431                                let error_value =
432                                    ensure_object(error.to_string().to_trace_output());
433                                let duration_ms = root.start_instant.elapsed().as_millis();
434                                manager.on_error(root, &error_value, duration_ms).await;
435                            }
436
437                            ctx.join_set.shutdown().await;
438                            ctx.pending_events.push_back(GraphEvent::Error(error));
439                            continue;
440                        }
441                    }
442
443                    // Max Steps
444                    if let Some(max) = ctx.effective.max_steps {
445                        if ctx.step_count >= max {
446                            let error = GraphError::MaxStepsExceeded {
447                                max,
448                                reached: ctx.step_count,
449                            };
450                            if let Some((manager, root)) = &ctx.callbacks {
451                                let error_value =
452                                    ensure_object(error.to_string().to_trace_output());
453                                let duration_ms = root.start_instant.elapsed().as_millis();
454                                manager.on_error(root, &error_value, duration_ms).await;
455                            }
456                            ctx.join_set.shutdown().await;
457                            ctx.pending_events.push_back(GraphEvent::Error(error));
458                            continue;
459                        }
460                    }
461
462                    // Max Visits
463                    if let Some(max_visits) = ctx.effective.max_visits {
464                        let count = ctx.visit_counts.entry(current.clone()).or_insert(0);
465                        *count += 1;
466                        if *count > max_visits {
467                            let error = GraphError::MaxVisitsExceeded {
468                                node: current.clone(),
469                                max: max_visits,
470                            };
471                            if let Some((manager, root)) = &ctx.callbacks {
472                                let error_value =
473                                    ensure_object(error.to_string().to_trace_output());
474                                let duration_ms = root.start_instant.elapsed().as_millis();
475                                manager.on_error(root, &error_value, duration_ms).await;
476                            }
477                            ctx.join_set.shutdown().await;
478                            ctx.pending_events.push_back(GraphEvent::Error(error));
479                            continue;
480                        }
481                    }
482
483                    // Path loops
484                    if let Some(max_loops) = ctx.effective.max_loop_iterations {
485                        let key = (current.clone(), path_id);
486                        let count = ctx.path_visits.entry(key).or_insert(0);
487                        *count += 1;
488                        if *count > max_loops {
489                            let error = GraphError::MaxLoopIterationsExceeded {
490                                node: current.clone(),
491                                max: max_loops,
492                                path_id,
493                            };
494                            if let Some((manager, root)) = &ctx.callbacks {
495                                let error_value =
496                                    ensure_object(error.to_string().to_trace_output());
497                                let duration_ms = root.start_instant.elapsed().as_millis();
498                                manager.on_error(root, &error_value, duration_ms).await;
499                            }
500                            ctx.join_set.shutdown().await;
501                            ctx.pending_events.push_back(GraphEvent::Error(error));
502                            continue;
503                        }
504                    }
505
506                    ctx.step_count += 1;
507
508                    // Cycle detection
509                    if ctx.effective.cycle_detection {
510                        if ctx.recent.len() == ctx.effective.cycle_window {
511                            ctx.recent.pop_front();
512                        }
513                        ctx.recent.push_back(current.clone());
514                        let count = ctx.recent.iter().filter(|node| **node == current).count();
515                        if count >= 2 {
516                            let error = GraphError::CycleDetected {
517                                node: current.clone(),
518                                recent: ctx.recent.iter().cloned().collect(),
519                            };
520                            if let Some((manager, root)) = &ctx.callbacks {
521                                let error_value =
522                                    ensure_object(error.to_string().to_trace_output());
523                                let duration_ms = root.start_instant.elapsed().as_millis();
524                                manager.on_error(root, &error_value, duration_ms).await;
525                            }
526                            ctx.join_set.shutdown().await;
527                            ctx.pending_events.push_back(GraphEvent::Error(error));
528                            continue;
529                        }
530                    }
531
532                    // Interrupt Before
533                    if ctx.effective.interrupt_before.contains(&current)
534                        || self.interrupt_before.contains(&current)
535                    {
536                        let error = GraphError::Interrupted;
537                        if let Some((manager, root)) = &ctx.callbacks {
538                            let error_value = ensure_object(error.to_string().to_trace_output());
539                            let duration_ms = root.start_instant.elapsed().as_millis();
540                            manager.on_error(root, &error_value, duration_ms).await;
541                        }
542
543                        // Save checkpoint on interrupt
544                        if let (Some((checkpointer, _)), Some(thread_id)) = (
545                            self.checkpointer.as_ref(),
546                            ctx.checkpoint_thread_id.as_deref(),
547                        ) {
548                            let mut full_queue = ctx.queue.iter().cloned().collect::<Vec<_>>();
549                            full_queue.push((current.clone(), path_id));
550                            full_queue.extend(ctx.active_tasks.iter().cloned());
551
552                            let checkpoint = Checkpoint::new(
553                                thread_id.to_string(),
554                                ctx.state.clone(),
555                                ctx.step_count as u64,
556                                current.clone(),
557                                full_queue,
558                            );
559                            if let Err(e) = checkpointer.save(&checkpoint).await {
560                                let graph_err = GraphError::from(e);
561                                if let Some((manager, root)) = &ctx.callbacks {
562                                    let error_value =
563                                        ensure_object(graph_err.to_string().to_trace_output());
564                                    let duration_ms = root.start_instant.elapsed().as_millis();
565                                    manager.on_error(root, &error_value, duration_ms).await;
566                                }
567                                ctx.pending_events.push_back(GraphEvent::Error(graph_err));
568                            } else {
569                                ctx.pending_events.push_back(GraphEvent::CheckpointSaved {
570                                    node: current.clone(),
571                                    timestamp: Utc::now().timestamp_millis() as u64,
572                                });
573                                if let Some((manager, root)) = &ctx.callbacks {
574                                    // Checkpoint saved event
575                                    manager
576                                        .on_event(
577                                            root,
578                                            "checkpoint_saved",
579                                            &json!({"node_id": current}),
580                                        )
581                                        .await;
582                                }
583                            }
584                        }
585
586                        ctx.join_set.shutdown().await;
587                        ctx.pending_events.push_back(GraphEvent::Error(error));
588                        continue;
589                    }
590
591                    // Get Node
592                    let node = match self.nodes.get(&current) {
593                        Some(node) => node.clone(),
594                        None => {
595                            let error = GraphError::InvalidEdge {
596                                node: current.clone(),
597                            };
598                            // observers...
599                            ctx.pending_events.push_back(GraphEvent::Error(error));
600                            continue;
601                        }
602                    };
603
604                    // Side Effects: Node Start
605                    emit_status_event(
606                        &ctx.agent_event_sender,
607                        &mut ctx.agent_event_step,
608                        &ctx.agent_event_thread_id,
609                        "node_start",
610                        format!("Starting node {current}"),
611                    )
612                    .await;
613
614                    if let Some((manager, root)) = &ctx.callbacks {
615                        let node_ctx = root.child(RunType::Chain, current.clone());
616                        let node_inputs = ensure_object(ctx.state.to_trace_input());
617                        manager.on_start(&node_ctx, &node_inputs).await;
618                        ctx.callback_nodes
619                            .insert((current.clone(), path_id), node_ctx);
620                    }
621
622                    ctx.pending_events.push_back(GraphEvent::NodeEnter {
623                        node: current.clone(),
624                        timestamp: Utc::now().timestamp_millis() as u64,
625                    });
626
627                    // Prepare Node Execution
628                    let input_state = ctx.state.clone();
629                    // We need a node-specific context
630                    let node_ctx_obs = ctx.observer.clone();
631                    let node_id = current.clone();
632                    let effective_config_spawn = ctx.effective.clone();
633                    let remaining = effective_config_spawn
634                        .max_steps
635                        .map(|m| m.saturating_sub(ctx.step_count)); // approximate
636
637                    let context = GraphContext {
638                        remaining_steps: remaining,
639                        observer: node_ctx_obs,
640                        node_id: node_id.clone(),
641                    };
642
643                    ctx.active_tasks.insert((current.clone(), path_id));
644
645                    // Spawn
646                    ctx.join_set.spawn(async move {
647                        let future = node.invoke_with_context(input_state, &context);
648                        let result = if let Some(timeout) = effective_config_spawn.node_timeout {
649                            match tokio::time::timeout(timeout, future).await {
650                                Ok(res) => res,
651                                Err(_) => Err(WesichainError::Custom(format!(
652                                    "Node {} timed out after {:?}",
653                                    node_id, timeout
654                                ))),
655                            }
656                        } else {
657                            future.await
658                        };
659                        (current, result, path_id)
660                    });
661
662                    continue; // Loop back to pick up next event or task
663                }
664
665                // 4. Process Completed Tasks
666                if !ctx.join_set.is_empty() {
667                    if let Some(join_res) = ctx.join_set.join_next().await {
668                        let (current, invoke_res, path_id) = match join_res {
669                            Ok(r) => r,
670                            Err(err) => {
671                                let error = GraphError::System(err.to_string());
672                                ctx.join_set.shutdown().await;
673                                ctx.pending_events.push_back(GraphEvent::Error(error));
674                                continue;
675                            }
676                        };
677
678                        ctx.active_tasks.remove(&(current.clone(), path_id));
679
680                        match invoke_res {
681                            Ok(update) => {
682                                // Node Success
683                                let output_debug =
684                                    serde_json::to_string(&update).unwrap_or_default();
685                                ctx.state = ctx.state.apply_update(update.clone());
686
687                                ctx.pending_events.push_back(GraphEvent::NodeFinished {
688                                    node: current.clone(),
689                                    output: output_debug,
690                                    timestamp: Utc::now().timestamp_millis() as u64,
691                                });
692
693                                // CRITICAL: Emit StateUpdate for invoke_graph consumers
694                                ctx.pending_events
695                                    .push_back(GraphEvent::StateUpdate(update));
696
697                                // Callbacks end
698                                if let Some((manager, _root)) = &ctx.callbacks {
699                                    if let Some(node_ctx) =
700                                        ctx.callback_nodes.remove(&(current.clone(), path_id))
701                                    {
702                                        let node_outputs =
703                                            ensure_object(ctx.state.to_trace_output());
704                                        let duration_ms =
705                                            node_ctx.start_instant.elapsed().as_millis();
706                                        manager.on_end(&node_ctx, &node_outputs, duration_ms).await;
707                                    }
708                                }
709                                // Observer end (handled by callbacks)
710                                emit_status_event(
711                                    &ctx.agent_event_sender,
712                                    &mut ctx.agent_event_step,
713                                    &ctx.agent_event_thread_id,
714                                    "node_end",
715                                    format!("Completed node {current}"),
716                                )
717                                .await;
718
719                                ctx.pending_events.push_back(GraphEvent::NodeExit {
720                                    node: current.clone(),
721                                    timestamp: Utc::now().timestamp_millis() as u64,
722                                });
723
724                                // 4c. Route Next (moved before Checkpoint)
725                                if let Some(condition) = self.conditional.get(&current) {
726                                    let targets = condition(&ctx.state);
727                                    let next_paths: Vec<(String, u64)> = if targets.len() > 1 {
728                                        targets
729                                            .into_iter()
730                                            .map(|t| {
731                                                if t == END {
732                                                    (t, path_id)
733                                                } else {
734                                                    let h = stable_hash(&(path_id, &t));
735                                                    (t, h)
736                                                }
737                                            })
738                                            .collect()
739                                    } else {
740                                        targets.into_iter().map(|t| (t, path_id)).collect()
741                                    };
742
743                                    for (next, next_path_id) in next_paths {
744                                        if next == END {
745                                            continue;
746                                        }
747                                        if !self.nodes.contains_key(&next) {
748                                            // Error
749                                            let error =
750                                                GraphError::InvalidEdge { node: next.clone() };
751                                            ctx.pending_events.push_back(GraphEvent::Error(error));
752                                            ctx.join_set.shutdown().await;
753                                            continue; // Outer loop continues, catches next event
754                                        }
755                                        ctx.queue.push_back((next, next_path_id));
756                                    }
757                                } else if let Some(targets) = self.edges.get(&current) {
758                                    let next_paths: Vec<(String, u64)> = if targets.len() > 1 {
759                                        targets
760                                            .iter()
761                                            .map(|t| {
762                                                if *t == END {
763                                                    (t.clone(), path_id)
764                                                } else {
765                                                    (t.clone(), stable_hash(&(path_id, t)))
766                                                }
767                                            })
768                                            .collect()
769                                    } else {
770                                        targets.iter().cloned().map(|t| (t, path_id)).collect()
771                                    };
772
773                                    for (next, next_path_id) in next_paths {
774                                        if next == END {
775                                            continue;
776                                        }
777                                        if !self.nodes.contains_key(&next) {
778                                            let error =
779                                                GraphError::InvalidEdge { node: next.clone() };
780                                            ctx.pending_events.push_back(GraphEvent::Error(error));
781                                            ctx.join_set.shutdown().await;
782                                            continue;
783                                        }
784                                        ctx.queue.push_back((next, next_path_id));
785                                    }
786                                }
787
788                                // 4a. Checkpoint
789                                if let (Some((checkpointer, _)), Some(thread_id)) = (
790                                    self.checkpointer.as_ref(),
791                                    ctx.checkpoint_thread_id.as_deref(),
792                                ) {
793                                    let mut full_queue =
794                                        ctx.queue.iter().cloned().collect::<Vec<_>>();
795                                    full_queue.extend(ctx.active_tasks.iter().cloned());
796
797                                    let checkpoint = Checkpoint::new(
798                                        thread_id.to_string(),
799                                        ctx.state.clone(),
800                                        ctx.step_count as u64,
801                                        current.clone(),
802                                        full_queue,
803                                    );
804
805                                    if let Err(e) = checkpointer.save(&checkpoint).await {
806                                        let graph_err = GraphError::from(e);
807                                        if let Some((manager, root)) = &ctx.callbacks {
808                                            let error_value = ensure_object(
809                                                graph_err.to_string().to_trace_output(),
810                                            );
811                                            let duration_ms =
812                                                root.start_instant.elapsed().as_millis();
813                                            manager.on_error(root, &error_value, duration_ms).await;
814                                        }
815                                        ctx.pending_events.push_back(GraphEvent::Error(graph_err));
816                                        ctx.join_set.shutdown().await;
817                                        continue;
818                                    } else {
819                                        ctx.pending_events.push_back(GraphEvent::CheckpointSaved {
820                                            node: current.clone(),
821                                            timestamp: Utc::now().timestamp_millis() as u64,
822                                        });
823
824                                        if let Some((manager, root)) = &ctx.callbacks {
825                                            // Checkpoint saved event
826                                            manager
827                                                .on_event(
828                                                    root,
829                                                    "checkpoint_saved",
830                                                    &json!({"node_id": current}),
831                                                )
832                                                .await;
833                                        }
834                                    }
835                                }
836
837                                // 4b. Interrupt After (AFTER checkpoint)
838                                if ctx.effective.interrupt_after.contains(&current)
839                                    || self.interrupt_after.contains(&current)
840                                {
841                                    let error = GraphError::Interrupted;
842                                    if let Some((manager, root)) = &ctx.callbacks {
843                                        let error_value =
844                                            ensure_object(error.to_string().to_trace_output());
845                                        let duration_ms = root.start_instant.elapsed().as_millis();
846                                        manager.on_error(root, &error_value, duration_ms).await;
847                                    }
848                                    ctx.pending_events.push_back(GraphEvent::Error(error));
849                                    continue;
850                                }
851                            }
852                            Err(e) => {
853                                // Node Failure
854                                let error = GraphError::NodeFailed {
855                                    node: current.clone(),
856                                    source: Box::new(e),
857                                };
858                                if let Some((manager, _root)) = &ctx.callbacks {
859                                    if let Some(node_ctx) =
860                                        ctx.callback_nodes.remove(&(current.clone(), path_id))
861                                    {
862                                        let error_value =
863                                            ensure_object(error.to_string().to_trace_output());
864                                        let duration_ms =
865                                            node_ctx.start_instant.elapsed().as_millis();
866                                        manager
867                                            .on_error(&node_ctx, &error_value, duration_ms)
868                                            .await;
869                                    }
870                                }
871                                ctx.join_set.shutdown().await;
872                                ctx.pending_events.push_back(GraphEvent::Error(error));
873                                continue;
874                            }
875                        }
876                    }
877                } else if ctx.queue.is_empty() {
878                    // Done!
879                    if let Some((manager, root)) = &ctx.callbacks {
880                        let outputs = ensure_object(ctx.state.to_trace_output());
881                        let duration_ms = root.start_instant.elapsed().as_millis();
882                        manager.on_end(root, &outputs, duration_ms).await;
883                    }
884
885                    emit_status_event(
886                        &ctx.agent_event_sender,
887                        &mut ctx.agent_event_step,
888                        &ctx.agent_event_thread_id,
889                        "completed",
890                        "Graph execution completed",
891                    )
892                    .await;
893
894                    return None;
895                }
896            }
897        })
898        .boxed()
899    }
900
901    pub async fn invoke_graph_with_options(
902        &self,
903        mut state: GraphState<S>,
904        mut options: ExecutionOptions,
905    ) -> Result<GraphState<S>, GraphError> {
906        let checkpoint_thread_id = options.checkpoint_thread_id.clone().or_else(|| {
907            self.checkpointer
908                .as_ref()
909                .map(|(_, thread_id)| thread_id.clone())
910        });
911
912        let agent_event_sender = options.agent_event_sender.clone();
913        let _agent_event_thread_id = options
914            .agent_event_thread_id
915            .clone()
916            .or_else(|| checkpoint_thread_id.clone())
917            .unwrap_or_else(|| "graph".to_string());
918        let mut agent_event_step = 0usize;
919
920        if options.auto_resume {
921            if let (Some((checkpointer, _)), Some(thread_id)) =
922                (self.checkpointer.as_ref(), checkpoint_thread_id.as_deref())
923            {
924                match checkpointer.load(thread_id).await {
925                    Ok(Some(saved)) => {
926                        state = saved.state;
927                        // Important: when resuming, we must respect the saved queue and step
928                        if !saved.queue.is_empty() {
929                            options.initial_queue = Some(saved.queue);
930                            options.initial_step = Some(saved.step as usize + 1);
931                        } else {
932                            // If queue is empty, it means the previous run finished.
933                            // We use the loaded state but allow the default (or provided) initial_queue
934                            // to start a new execution path from this state.
935                        }
936                    }
937                    Ok(None) => {}
938                    Err(error) => return Err(error.into()),
939                }
940            }
941        }
942
943        if !self.nodes.contains_key(&self.entry) {
944            let error = GraphError::MissingNode {
945                node: self.entry.clone(),
946            };
947            emit_error_event(
948                &agent_event_sender,
949                &mut agent_event_step,
950                error.to_string(),
951                Some("graph".to_string()),
952            )
953            .await;
954            return Err(error);
955        }
956
957        let mut stream = self.stream_invoke_with_options(state.clone(), options);
958
959        while let Some(event) = stream.next().await {
960            match event {
961                Ok(GraphEvent::StateUpdate(update)) => {
962                    state = state.apply_update(update);
963                }
964                Ok(GraphEvent::Error(e)) | Err(e) => return Err(e),
965                // Other events (NodeEnter, etc.) can be ignored by invoke_graph
966                // as they are handled by stream side effects (observers/callbacks).
967                _ => {}
968            }
969        }
970
971        Ok(state)
972    }
973
974    pub async fn invoke(&self, state: GraphState<S>) -> Result<GraphState<S>, WesichainError> {
975        self.invoke_graph(state)
976            .await
977            .map_err(|err| WesichainError::Custom(err.to_string()))
978    }
979
980    pub async fn invoke_with_options(
981        &self,
982        state: GraphState<S>,
983        options: ExecutionOptions,
984    ) -> Result<GraphState<S>, WesichainError> {
985        self.invoke_graph_with_options(state, options)
986            .await
987            .map_err(|err| WesichainError::Custom(err.to_string()))
988    }
989
990    pub async fn get_state(&self, thread_id: &str) -> Result<Option<GraphState<S>>, GraphError> {
991        if let Some((checkpointer, _)) = &self.checkpointer {
992            let checkpoint = checkpointer.load(thread_id).await?;
993            Ok(checkpoint.map(|cp| cp.state))
994        } else {
995            Ok(None)
996        }
997    }
998
999    pub async fn resume(
1000        &self,
1001        checkpoint: Checkpoint<S>,
1002        mut options: ExecutionOptions,
1003    ) -> Result<GraphState<S>, GraphError> {
1004        options.initial_queue = Some(checkpoint.queue);
1005        // Start from next logical step
1006        options.initial_step = Some(checkpoint.step as usize + 1);
1007        self.invoke_graph_with_options(checkpoint.state, options)
1008            .await
1009    }
1010
1011    pub async fn update_state(
1012        &self,
1013        thread_id: &str,
1014        values: S,
1015        as_node: Option<String>,
1016    ) -> Result<(), GraphError> {
1017        if let Some((checkpointer, _)) = &self.checkpointer {
1018            // Load current state or default
1019            let (mut state, step) = if let Some(checkpoint) = checkpointer.load(thread_id).await? {
1020                (checkpoint.state, checkpoint.step + 1)
1021            } else {
1022                (GraphState::new(S::default()), 1)
1023            };
1024
1025            // Apply update
1026            let update = StateUpdate::new(values);
1027            state = state.apply_update(update);
1028
1029            // Save new checkpoint
1030            let node = as_node.unwrap_or_else(|| "user".to_string());
1031            let checkpoint = Checkpoint::new(thread_id.to_string(), state, step, node, vec![]);
1032            checkpointer.save(&checkpoint).await?;
1033            Ok(())
1034        } else {
1035            Err(GraphError::Checkpoint("Checkpointer not configured".into()))
1036        }
1037    }
1038}
1039
1040#[async_trait::async_trait]
1041impl<S: StateSchema<Update = S>> Runnable<GraphState<S>, StateUpdate<S>> for ExecutableGraph<S> {
1042    async fn invoke(&self, input: GraphState<S>) -> Result<StateUpdate<S>, WesichainError> {
1043        let result = self
1044            .invoke_graph(input)
1045            .await
1046            .map_err(|e| WesichainError::Custom(e.to_string()))?;
1047        Ok(StateUpdate::new(result.data))
1048    }
1049
1050    fn stream<'a>(
1051        &'a self,
1052        input: GraphState<S>,
1053    ) -> BoxStream<'a, Result<wesichain_core::StreamEvent, WesichainError>> {
1054        let stream = self.stream_invoke(input);
1055
1056        stream
1057            .filter_map(|event_res| async move {
1058                match event_res {
1059                    Ok(GraphEvent::Error(e)) | Err(e) => {
1060                        Some(Err(WesichainError::Custom(e.to_string())))
1061                    }
1062                    // In a real implementation, we would map Node events to metadata
1063                    // or if the graph output was compatible, stream chunks.
1064                    // For now, subgraphs are mostly opaque unless we add a specific event mapper.
1065                    _ => None,
1066                }
1067            })
1068            .boxed()
1069    }
1070}
1071
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed keys pinned by this test (presumably the ones `stable_hash` uses — confirm).
    const FIXED_SEEDS: (u64, u64, u64, u64) = (0x517cc1b727220a95, 0x6ed9eba1999cd92d, 0, 0);

    /// Builds a fresh `RandomState` from [`FIXED_SEEDS`].
    fn fixed_state() -> RandomState {
        let (k0, k1, k2, k3) = FIXED_SEEDS;
        RandomState::with_seeds(k0, k1, k2, k3)
    }

    #[test]
    fn test_stable_path_hashing() {
        let parent_id = 12345u64;
        let node_name = "test_node";

        // Hash with two independently constructed states that share the fixed keys.
        // Re-hashing with one instance would pass even for randomly seeded states,
        // so it would not demonstrate determinism.
        let hash1 = fixed_state().hash_one((parent_id, node_name));
        let expected1 = fixed_state().hash_one((parent_id, node_name));
        assert_eq!(hash1, expected1, "Hash MUST be deterministic");

        let different_hash =
            RandomState::with_seeds(123, 456, 0, 0).hash_one((parent_id, node_name));
        assert_ne!(hash1, different_hash, "Should differ from arbitrary keys");
    }

    #[test]
    fn test_stable_hash_branch_ids() {
        // Mirrors how the executor derives per-branch path ids: `(parent_path_id, &target)`.
        let parent_id = 12345u64;
        let left = "left".to_string();
        let right = "right".to_string();

        assert_eq!(
            stable_hash(&(parent_id, &left)),
            stable_hash(&(parent_id, &left)),
            "stable_hash must be deterministic"
        );
        assert_ne!(
            stable_hash(&(parent_id, &left)),
            stable_hash(&(parent_id, &right)),
            "distinct targets must get distinct path ids"
        );
        assert_ne!(
            stable_hash(&(parent_id, &left)),
            stable_hash(&(parent_id + 1, &left)),
            "distinct parent paths must get distinct path ids"
        );
    }
}