strands_agents/multiagent/
graph.rs

1//! Graph-based multi-agent orchestration.
2//!
3//! Provides a deterministic graph-based agent orchestration system where
4//! agents are nodes in a graph, executed according to edge dependencies,
5//! with output from one node passed as input to connected nodes.
6
7use std::collections::{HashMap, HashSet, VecDeque};
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use async_trait::async_trait;
12use futures::StreamExt;
13
14use super::base::{
15    InterruptState, InvocationState, MultiAgentBase, MultiAgentEvent,
16    MultiAgentEventStream, MultiAgentInput, MultiAgentResult, NodeResult, NodeResultValue, Status,
17};
18use crate::agent::Agent;
19use crate::hooks::{
20    AfterInvocationEvent, AfterToolCallEvent, BeforeInvocationEvent, BeforeToolCallEvent,
21    HookEvent, HookRegistry,
22};
23use crate::types::tools::{ToolResult as ToolResultType, ToolUse};
24use crate::types::errors::{Result, StrandsError};
25use crate::types::streaming::{Metrics, Usage};
26
27/// Type alias for edge conditions.
28pub type EdgeCondition = Arc<dyn Fn(&GraphState) -> bool + Send + Sync>;
29
30/// An edge connecting two nodes in the graph.
31pub struct GraphEdge {
32    pub from_node: String,
33    pub to_node: String,
34    pub condition: Option<EdgeCondition>,
35}
36
37impl GraphEdge {
38    /// Creates an unconditional edge.
39    pub fn new(from: impl Into<String>, to: impl Into<String>) -> Self {
40        Self {
41            from_node: from.into(),
42            to_node: to.into(),
43            condition: None,
44        }
45    }
46
47    /// Creates a conditional edge.
48    pub fn conditional(
49        from: impl Into<String>,
50        to: impl Into<String>,
51        condition: impl Fn(&GraphState) -> bool + Send + Sync + 'static,
52    ) -> Self {
53        Self {
54            from_node: from.into(),
55            to_node: to.into(),
56            condition: Some(Arc::new(condition)),
57        }
58    }
59
60    /// Check if this edge should be traversed.
61    pub fn should_traverse(&self, state: &GraphState) -> bool {
62        match &self.condition {
63            Some(cond) => cond(state),
64            None => true,
65        }
66    }
67}
68
69/// A node in the agent graph.
70pub struct GraphNode {
71    pub node_id: String,
72    pub agent: Agent,
73    pub dependencies: HashSet<String>,
74    pub status: Status,
75    pub result: Option<NodeResult>,
76    pub execution_time_ms: u64,
77}
78
79impl GraphNode {
80    pub fn new(node_id: impl Into<String>, agent: Agent) -> Self {
81        Self {
82            node_id: node_id.into(),
83            agent,
84            dependencies: HashSet::new(),
85            status: Status::Pending,
86            result: None,
87            execution_time_ms: 0,
88        }
89    }
90
91    /// Resets the node state for re-execution.
92    pub fn reset(&mut self) {
93        self.status = Status::Pending;
94        self.result = None;
95        self.execution_time_ms = 0;
96    }
97}
98
99/// State of graph execution.
100#[derive(Debug, Clone, Default)]
101pub struct GraphState {
102    pub status: Status,
103    pub task: String,
104    pub completed_nodes: HashSet<String>,
105    pub failed_nodes: HashSet<String>,
106    pub execution_order: Vec<String>,
107    pub results: HashMap<String, NodeResult>,
108    pub accumulated_usage: Usage,
109    pub accumulated_metrics: Metrics,
110    pub execution_count: u32,
111    pub execution_time_ms: u64,
112    pub start_time: Option<Instant>,
113    pub total_nodes: usize,
114}
115
116impl GraphState {
117    /// Check if graph execution should continue.
118    pub fn should_continue(
119        &self,
120        max_node_executions: Option<usize>,
121        execution_timeout: Option<Duration>,
122    ) -> (bool, &'static str) {
123        if let Some(max) = max_node_executions {
124            if self.execution_order.len() >= max {
125                return (false, "Max node executions reached");
126            }
127        }
128
129        if let (Some(timeout), Some(start)) = (execution_timeout, self.start_time) {
130            if start.elapsed() > timeout {
131                return (false, "Execution timed out");
132            }
133        }
134
135        (true, "Continuing")
136    }
137}
138
139/// Result from graph execution.
140#[derive(Debug, Clone)]
141pub struct GraphResult {
142    pub status: Status,
143    pub results: HashMap<String, NodeResult>,
144    pub execution_order: Vec<String>,
145    pub accumulated_usage: Usage,
146    pub accumulated_metrics: Metrics,
147    pub execution_time_ms: u64,
148    pub total_nodes: usize,
149    pub completed_nodes: usize,
150    pub failed_nodes: usize,
151    pub entry_points: Vec<String>,
152}
153
154impl From<GraphResult> for MultiAgentResult {
155    fn from(gr: GraphResult) -> Self {
156        MultiAgentResult {
157            status: gr.status,
158            results: gr.results,
159            accumulated_usage: gr.accumulated_usage,
160            accumulated_metrics: gr.accumulated_metrics,
161            execution_count: gr.execution_order.len() as u32,
162            execution_time_ms: gr.execution_time_ms,
163            interrupts: Vec::new(),
164        }
165    }
166}
167
/// Limits and policies applied while a [`Graph`] executes.
///
/// The [`Default`] values are:
/// - 100 node executions,
/// - a 15-minute overall timeout,
/// - a 5-minute per-node timeout,
/// - no reset on revisit.
#[derive(Debug, Clone)]
pub struct GraphConfig {
    /// Upper bound on node executions in one run (`None` = unbounded).
    pub max_node_executions: Option<usize>,
    /// Wall-clock budget for a whole run, checked before each node starts.
    pub execution_timeout: Option<Duration>,
    /// Time budget for a single node's execution.
    pub node_timeout: Option<Duration>,
    /// Whether a previously completed node is reset before it runs again.
    pub reset_on_revisit: bool,
}

impl Default for GraphConfig {
    fn default() -> Self {
        const DEFAULT_EXECUTION_TIMEOUT: Duration = Duration::from_secs(15 * 60);
        const DEFAULT_NODE_TIMEOUT: Duration = Duration::from_secs(5 * 60);

        GraphConfig {
            reset_on_revisit: false,
            max_node_executions: Some(100),
            node_timeout: Some(DEFAULT_NODE_TIMEOUT),
            execution_timeout: Some(DEFAULT_EXECUTION_TIMEOUT),
        }
    }
}
187
188/// Builder for constructing graphs.
189pub struct GraphBuilder {
190    nodes: HashMap<String, GraphNode>,
191    edges: Vec<GraphEdge>,
192    entry_points: HashSet<String>,
193    config: GraphConfig,
194    id: String,
195    hooks: HookRegistry,
196}
197
198impl Default for GraphBuilder {
199    fn default() -> Self {
200        Self::new()
201    }
202}
203
204impl GraphBuilder {
205    pub fn new() -> Self {
206        Self {
207            nodes: HashMap::new(),
208            edges: Vec::new(),
209            entry_points: HashSet::new(),
210            config: GraphConfig::default(),
211            id: "default_graph".to_string(),
212            hooks: HookRegistry::new(),
213        }
214    }
215
216    /// Sets the graph ID.
217    pub fn id(mut self, id: impl Into<String>) -> Self {
218        self.id = id.into();
219        self
220    }
221
222    /// Adds a node to the graph.
223    pub fn add_node(mut self, node_id: impl Into<String>, agent: Agent) -> Self {
224        let node_id = node_id.into();
225        self.nodes.insert(node_id.clone(), GraphNode::new(node_id, agent));
226        self
227    }
228
229    /// Adds an edge between two nodes.
230    pub fn add_edge(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
231        let from = from.into();
232        let to = to.into();
233
234        if let Some(node) = self.nodes.get_mut(&to) {
235            node.dependencies.insert(from.clone());
236        }
237
238        self.edges.push(GraphEdge::new(from, to));
239        self
240    }
241
242    /// Adds a conditional edge between two nodes.
243    pub fn add_conditional_edge<F>(
244        mut self,
245        from: impl Into<String>,
246        to: impl Into<String>,
247        condition: F,
248    ) -> Self
249    where
250        F: Fn(&GraphState) -> bool + Send + Sync + 'static,
251    {
252        let from = from.into();
253        let to = to.into();
254
255        if let Some(node) = self.nodes.get_mut(&to) {
256            node.dependencies.insert(from.clone());
257        }
258
259        self.edges.push(GraphEdge::conditional(from, to, condition));
260        self
261    }
262
263    /// Sets explicit entry points for the graph.
264    pub fn set_entry_points(mut self, entry_points: impl IntoIterator<Item = impl Into<String>>) -> Self {
265        self.entry_points = entry_points.into_iter().map(Into::into).collect();
266        self
267    }
268
269    /// Sets a single entry point.
270    pub fn set_entry_point(mut self, node_id: impl Into<String>) -> Self {
271        self.entry_points.insert(node_id.into());
272        self
273    }
274
275    /// Sets the graph configuration.
276    pub fn config(mut self, config: GraphConfig) -> Self {
277        self.config = config;
278        self
279    }
280
281    /// Sets the maximum number of node executions.
282    pub fn max_node_executions(mut self, max: usize) -> Self {
283        self.config.max_node_executions = Some(max);
284        self
285    }
286
287    /// Sets the execution timeout.
288    pub fn execution_timeout(mut self, timeout: Duration) -> Self {
289        self.config.execution_timeout = Some(timeout);
290        self
291    }
292
293    /// Sets the node timeout.
294    pub fn node_timeout(mut self, timeout: Duration) -> Self {
295        self.config.node_timeout = Some(timeout);
296        self
297    }
298
299    /// Enables reset on revisit.
300    pub fn reset_on_revisit(mut self, enabled: bool) -> Self {
301        self.config.reset_on_revisit = enabled;
302        self
303    }
304
305    /// Sets the hook registry.
306    pub fn hooks(mut self, hooks: HookRegistry) -> Self {
307        self.hooks = hooks;
308        self
309    }
310
311    /// Builds the graph.
312    pub fn build(self) -> Result<Graph> {
313        if self.nodes.is_empty() {
314            return Err(StrandsError::ConfigurationError {
315                message: "Graph must have at least one node".to_string(),
316            });
317        }
318
319        let entry_points = if self.entry_points.is_empty() {
320            self.nodes
321                .values()
322                .filter(|n| n.dependencies.is_empty())
323                .map(|n| n.node_id.clone())
324                .collect()
325        } else {
326            self.entry_points
327        };
328
329        if entry_points.is_empty() {
330            return Err(StrandsError::ConfigurationError {
331                message: "Graph has no entry points (all nodes have dependencies)".to_string(),
332            });
333        }
334
335        Ok(Graph {
336            id: self.id,
337            nodes: self.nodes,
338            edges: self.edges,
339            entry_points,
340            config: self.config,
341            state: GraphState::default(),
342            hooks: self.hooks,
343            interrupt_state: InterruptState::new(),
344        })
345    }
346}
347
348/// A graph of agents for orchestrated execution.
349pub struct Graph {
350    id: String,
351    nodes: HashMap<String, GraphNode>,
352    edges: Vec<GraphEdge>,
353    entry_points: HashSet<String>,
354    config: GraphConfig,
355    state: GraphState,
356    hooks: HookRegistry,
357    interrupt_state: InterruptState,
358}
359
360impl Graph {
361    /// Creates a new graph builder.
362    pub fn builder() -> GraphBuilder {
363        GraphBuilder::new()
364    }
365
366    /// Returns the graph ID.
367    pub fn graph_id(&self) -> &str {
368        &self.id
369    }
370
371    /// Returns the current graph state.
372    pub fn state(&self) -> &GraphState {
373        &self.state
374    }
375
376    /// Returns an iterator over node IDs.
377    pub fn node_ids(&self) -> impl Iterator<Item = &str> {
378        self.nodes.keys().map(|s| s.as_str())
379    }
380
381    /// Returns the entry point node IDs.
382    pub fn entry_points(&self) -> &HashSet<String> {
383        &self.entry_points
384    }
385
386    /// Returns a reference to the interrupt state.
387    pub fn interrupt_state(&self) -> &InterruptState {
388        &self.interrupt_state
389    }
390
391    /// Returns a mutable reference to the interrupt state.
392    pub fn interrupt_state_mut(&mut self) -> &mut InterruptState {
393        &mut self.interrupt_state
394    }
395
396
397    /// Invokes the graph synchronously.
398    pub fn call(&mut self, task: impl Into<MultiAgentInput>) -> Result<GraphResult> {
399        tokio::task::block_in_place(|| {
400            tokio::runtime::Handle::current().block_on(self.invoke_async(task.into(), None))
401        })
402    }
403
404    /// Invokes the graph asynchronously and returns the result.
405    pub async fn invoke_async(
406        &mut self,
407        task: MultiAgentInput,
408        invocation_state: Option<&InvocationState>,
409    ) -> Result<GraphResult> {
410        let total_nodes = self.nodes.len();
411        let entry_points_vec: Vec<String> = self.entry_points.iter().cloned().collect();
412
413        let mut stream = self.stream_async(task, invocation_state);
414        let mut final_result = None;
415
416        while let Some(event) = stream.next().await {
417            if let MultiAgentEvent::Result(result) = event {
418                final_result = Some(result);
419            }
420        }
421
422        drop(stream);
423
424        final_result
425            .map(|r| GraphResult {
426                status: r.status,
427                results: r.results,
428                execution_order: self.state.execution_order.clone(),
429                accumulated_usage: r.accumulated_usage,
430                accumulated_metrics: r.accumulated_metrics,
431                execution_time_ms: r.execution_time_ms,
432                total_nodes,
433                completed_nodes: self.state.completed_nodes.len(),
434                failed_nodes: self.state.failed_nodes.len(),
435                entry_points: entry_points_vec,
436            })
437            .ok_or_else(|| StrandsError::MultiAgentError {
438                message: "Graph execution completed without result".to_string(),
439            })
440    }
441
442    /// Streams events during graph execution.
443    pub fn stream_async<'a>(
444        &'a mut self,
445        task: MultiAgentInput,
446        _invocation_state: Option<&'a InvocationState>,
447    ) -> MultiAgentEventStream<'a> {
448        let task_str = task.to_string_lossy();
449
450        Box::pin(async_stream::stream! {
451            self.hooks.invoke(&HookEvent::BeforeInvocation(BeforeInvocationEvent)).await;
452
453            self.state = GraphState {
454                status: Status::Executing,
455                task: task_str.clone(),
456                start_time: Some(Instant::now()),
457                total_nodes: self.nodes.len(),
458                ..Default::default()
459            };
460
461            let mut queue: VecDeque<String> = self.entry_points.iter().cloned().collect();
462            let mut processed: HashSet<String> = HashSet::new();
463
464            while let Some(node_id) = queue.pop_front() {
465                if processed.contains(&node_id) {
466                    continue;
467                }
468
469                let (should_continue, reason) = self.state.should_continue(
470                    self.config.max_node_executions,
471                    self.config.execution_timeout,
472                );
473                if !should_continue {
474                    tracing::warn!("Graph execution stopped: {reason}");
475                    self.state.status = Status::Failed;
476                    break;
477                }
478
479                let deps_met = {
480                    if let Some(node) = self.nodes.get(&node_id) {
481                        node.dependencies.iter().all(|dep| self.state.completed_nodes.contains(dep))
482                    } else {
483                        false
484                    }
485                };
486
487                if !deps_met {
488                    queue.push_back(node_id);
489                    continue;
490                }
491
492                if self.config.reset_on_revisit && self.state.completed_nodes.contains(&node_id) {
493                    if let Some(node) = self.nodes.get_mut(&node_id) {
494                        node.reset();
495                    }
496                    self.state.completed_nodes.remove(&node_id);
497                }
498
499                yield MultiAgentEvent::node_start(&node_id, "agent");
500
501                self.hooks.invoke(&HookEvent::BeforeToolCall(BeforeToolCallEvent::new(
502                    ToolUse::new(&node_id, &node_id, serde_json::json!({}))
503                ))).await;
504
505                let result = self.execute_node(&node_id, &task_str).await;
506
507                match result {
508                    Ok(node_result) => {
509
510                        if node_result.status == Status::Interrupted {
511                            self.interrupt_state.deactivate();
512                            tracing::error!("user raised interrupt from agent | interrupts are not yet supported in graphs");
513                            self.state.status = Status::Failed;
514                            yield MultiAgentEvent::node_stop(&node_id, node_result);
515                            break;
516                        }
517
518                        self.state.completed_nodes.insert(node_id.clone());
519                        self.state.execution_order.push(node_id.clone());
520                        self.state.accumulated_usage.add(&node_result.accumulated_usage);
521                        self.state.accumulated_metrics.latency_ms += node_result.accumulated_metrics.latency_ms;
522                        self.state.execution_count += 1;
523
524                        if let Some(node) = self.nodes.get_mut(&node_id) {
525                            node.status = Status::Completed;
526                            node.execution_time_ms = node_result.execution_time_ms;
527                        }
528
529                        yield MultiAgentEvent::node_stop(&node_id, node_result.clone());
530
531                        self.state.results.insert(node_id.clone(), node_result);
532
533                        let mut next_nodes = Vec::new();
534                        for edge in &self.edges {
535                            if edge.from_node == node_id && edge.should_traverse(&self.state) {
536                                if !processed.contains(&edge.to_node) {
537                                    next_nodes.push(edge.to_node.clone());
538                                }
539                            }
540                        }
541
542                        if !next_nodes.is_empty() {
543                            yield MultiAgentEvent::handoff(
544                                vec![node_id.clone()],
545                                next_nodes.clone(),
546                                None,
547                            );
548                            for next in next_nodes {
549                                queue.push_back(next);
550                            }
551                        }
552                    }
553                    Err(e) => {
554                        tracing::error!("Node {node_id} failed: {e}");
555                        self.state.failed_nodes.insert(node_id.clone());
556                        if let Some(node) = self.nodes.get_mut(&node_id) {
557                            node.status = Status::Failed;
558                        }
559
560                        let error_result = NodeResult::from_error(e.to_string(), 0);
561                        yield MultiAgentEvent::node_stop(&node_id, error_result);
562                    }
563                }
564
565                self.hooks.invoke(&HookEvent::AfterToolCall(AfterToolCallEvent::new(
566                    ToolUse::new(&node_id, &node_id, serde_json::json!({})),
567                    ToolResultType::success(&node_id, "completed")
568                ))).await;
569                processed.insert(node_id);
570            }
571
572            if self.state.failed_nodes.is_empty() && self.state.status == Status::Executing {
573                self.state.status = Status::Completed;
574            } else if !self.state.failed_nodes.is_empty() {
575                self.state.status = Status::Failed;
576            }
577
578            self.state.execution_time_ms = self.state.start_time
579                .map(|s| s.elapsed().as_millis() as u64)
580                .unwrap_or(0);
581
582            self.hooks.invoke(&HookEvent::AfterInvocation(AfterInvocationEvent::new(None))).await;
583
584            let result = MultiAgentResult {
585                status: self.state.status,
586                results: self.state.results.clone(),
587                accumulated_usage: self.state.accumulated_usage.clone(),
588                accumulated_metrics: self.state.accumulated_metrics.clone(),
589                execution_count: self.state.execution_count,
590                execution_time_ms: self.state.execution_time_ms,
591                interrupts: Vec::new(),
592            };
593
594            yield MultiAgentEvent::result(result);
595        })
596    }
597
598    async fn execute_node(&mut self, node_id: &str, task: &str) -> Result<NodeResult> {
599        let start = Instant::now();
600
601        let input = self.build_node_input(node_id, task);
602
603        let node = self.nodes.get_mut(node_id).ok_or_else(|| StrandsError::InternalError {
604            message: format!("Node '{node_id}' not found"),
605        })?;
606
607        node.status = Status::Executing;
608
609        let agent_result = node.agent.invoke_async(input.as_str()).await?;
610        let execution_time_ms = start.elapsed().as_millis() as u64;
611
612        let usage = agent_result.usage.clone();
613
614        Ok(NodeResult {
615            result: NodeResultValue::Agent(agent_result),
616            execution_time_ms,
617            status: Status::Completed,
618            accumulated_usage: usage,
619            accumulated_metrics: Metrics { latency_ms: execution_time_ms, time_to_first_byte_ms: 0 },
620            execution_count: 1,
621            interrupts: Vec::new(),
622        })
623    }
624
625    fn build_node_input(&self, node_id: &str, task: &str) -> String {
626        let mut input = String::new();
627
628        let node = match self.nodes.get(node_id) {
629            Some(n) => n,
630            None => {
631                input.push_str(&format!("Task: {task}"));
632                return input;
633            }
634        };
635
636        if node.dependencies.is_empty() {
637            input.push_str(&format!("Task: {task}"));
638        } else {
639            input.push_str(&format!("Original Task: {task}\n\n"));
640            input.push_str("Inputs from previous nodes:\n\n");
641
642            for dep in &node.dependencies {
643                if let Some(result) = self.state.results.get(dep) {
644                    input.push_str(&format!("From {dep}:\n"));
645                    for agent_result in result.get_agent_results() {
646                        let text = agent_result.text();
647                        input.push_str(&format!("  - Agent: {text}\n"));
648                    }
649                }
650            }
651        }
652
653        input
654    }
655}
656
657#[async_trait]
658impl MultiAgentBase for Graph {
659    fn id(&self) -> &str {
660        &self.id
661    }
662
663    async fn invoke_async(
664        &mut self,
665        task: MultiAgentInput,
666        invocation_state: Option<&InvocationState>,
667    ) -> Result<MultiAgentResult> {
668        self.invoke_async(task, invocation_state).await.map(Into::into)
669    }
670
671    fn stream_async<'a>(
672        &'a mut self,
673        task: MultiAgentInput,
674        invocation_state: Option<&'a InvocationState>,
675    ) -> MultiAgentEventStream<'a> {
676        self.stream_async(task, invocation_state)
677    }
678
679    fn serialize_state(&self) -> serde_json::Value {
680        serde_json::json!({
681            "type": "graph",
682            "id": self.id,
683            "status": format!("{:?}", self.state.status).to_lowercase(),
684            "completed_nodes": self.state.completed_nodes.iter().collect::<Vec<_>>(),
685            "failed_nodes": self.state.failed_nodes.iter().collect::<Vec<_>>(),
686            "execution_order": self.state.execution_order,
687            "current_task": self.state.task,
688        })
689    }
690
691    fn deserialize_state(&mut self, payload: &serde_json::Value) -> Result<()> {
692        if let Some(status_str) = payload.get("status").and_then(|v| v.as_str()) {
693            self.state.status = match status_str {
694                "pending" => Status::Pending,
695                "executing" => Status::Executing,
696                "completed" => Status::Completed,
697                "failed" => Status::Failed,
698                "interrupted" => Status::Interrupted,
699                _ => Status::Pending,
700            };
701        }
702
703        if let Some(completed) = payload.get("completed_nodes").and_then(|v| v.as_array()) {
704            self.state.completed_nodes = completed
705                .iter()
706                .filter_map(|v| v.as_str().map(|s| s.to_string()))
707                .collect();
708        }
709
710        if let Some(task) = payload.get("current_task").and_then(|v| v.as_str()) {
711            self.state.task = task.to_string();
712        }
713
714        Ok(())
715    }
716}
717
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_graph_no_nodes() {
        let result = Graph::builder().build();
        assert!(result.is_err());
    }

    #[test]
    fn test_graph_state_should_continue() {
        let state = GraphState::default();
        let (should_continue, _) = state.should_continue(Some(10), None);
        assert!(should_continue);

        let mut state = GraphState::default();
        state.execution_order = vec!["a".to_string(); 10];
        let (should_continue, reason) = state.should_continue(Some(10), None);
        assert!(!should_continue);
        assert_eq!(reason, "Max node executions reached");
    }

    #[test]
    fn test_graph_state_should_continue_without_limits() {
        let state = GraphState::default();
        assert_eq!(state.should_continue(None, None), (true, "Continuing"));
    }

    #[test]
    fn test_graph_state_should_continue_timeout() {
        let mut state = GraphState::default();
        // `checked_sub` avoids a panic on platforms where the monotonic clock
        // started less than 10 s ago.
        if let Some(started) = Instant::now().checked_sub(Duration::from_secs(10)) {
            state.start_time = Some(started);
            assert_eq!(
                state.should_continue(None, Some(Duration::from_secs(1))),
                (false, "Execution timed out")
            );
            // With no timeout configured, elapsed time is irrelevant.
            assert!(state.should_continue(None, None).0);
        }
    }

    #[test]
    fn test_edge_conditions() {
        let mut state = GraphState::default();
        assert!(GraphEdge::new("a", "b").should_traverse(&state));

        let edge = GraphEdge::conditional("a", "b", |s: &GraphState| s.execution_count > 0);
        assert!(!edge.should_traverse(&state));
        state.execution_count = 1;
        assert!(edge.should_traverse(&state));
    }

    #[test]
    fn test_graph_config_defaults() {
        let config = GraphConfig::default();
        assert_eq!(config.max_node_executions, Some(100));
        assert_eq!(config.execution_timeout, Some(Duration::from_secs(900)));
        assert_eq!(config.node_timeout, Some(Duration::from_secs(300)));
        assert!(!config.reset_on_revisit);
    }

    #[test]
    fn test_node_result() {
        let result = NodeResult::from_error("test error", 100);
        assert!(result.is_error());
        assert_eq!(result.execution_time_ms, 100);
    }

    #[test]
    fn test_status_default() {
        let status = Status::default();
        assert_eq!(status, Status::Pending);
    }
}