Skip to main content

rust_langgraph/pregel/
branch.rs

//! Branch logic for conditional routing.
//!
//! Branches enable dynamic routing in graphs based on state conditions.
4
5use crate::errors::Result;
6use crate::state::State;
7use crate::types::Send as SendType;
8use async_trait::async_trait;
9use std::future::Future;
10
11/// The result of evaluating a branch condition.
12///
13/// Branches can route to a single node, multiple nodes in parallel,
14/// dynamic Send targets, or end execution.
15#[derive(Debug, Clone)]
16pub enum BranchResult {
17    /// Go to a single next node
18    Single(String),
19
20    /// Go to multiple nodes in parallel
21    Multiple(Vec<String>),
22
23    /// Dynamic routing with Send (map-reduce pattern)
24    Send(Vec<SendType>),
25
26    /// End execution
27    End,
28}
29
30impl BranchResult {
31    /// Create a Single variant
32    pub fn single(node: impl Into<String>) -> Self {
33        BranchResult::Single(node.into())
34    }
35
36    /// Create a Multiple variant
37    pub fn multiple(nodes: Vec<impl Into<String>>) -> Self {
38        BranchResult::Multiple(nodes.into_iter().map(|n| n.into()).collect())
39    }
40
41    /// Create a Send variant
42    pub fn send(sends: Vec<SendType>) -> Self {
43        BranchResult::Send(sends)
44    }
45
46    /// Create an End variant
47    pub fn end() -> Self {
48        BranchResult::End
49    }
50
51    /// Check if this is an End result
52    pub fn is_end(&self) -> bool {
53        matches!(self, BranchResult::End)
54    }
55
56    /// Get the list of node names to route to (if Single or Multiple)
57    pub fn node_names(&self) -> Vec<String> {
58        match self {
59            BranchResult::Single(name) => vec![name.clone()],
60            BranchResult::Multiple(names) => names.clone(),
61            BranchResult::Send(_) => vec![],
62            BranchResult::End => vec![],
63        }
64    }
65}
66
67/// Trait for conditional routing logic.
68///
69/// Branches examine state and decide which node(s) to execute next.
70///
71/// # Example
72///
73/// ```rust
74/// use rust_langgraph::pregel::{Branch, BranchResult};
75/// use rust_langgraph::Error;
76/// use async_trait::async_trait;
77///
78/// struct MyBranch;
79///
80/// #[async_trait]
81/// impl<S> Branch<S> for MyBranch
82/// where
83///     S: Send + Sync + 'static,
84/// {
85///     async fn route(&self, _state: &S) -> Result<BranchResult, Error> {
86///         Ok(BranchResult::single("next_node"))
87///     }
88/// }
89/// ```
90#[async_trait]
91pub trait Branch<S>: std::marker::Send + Sync
92where
93    S: State,
94{
95    /// Evaluate the branch condition and return routing decision.
96    ///
97    /// # Arguments
98    ///
99    /// * `state` - The current graph state
100    ///
101    /// # Returns
102    ///
103    /// A BranchResult indicating where to route next
104    async fn route(&self, state: &S) -> Result<BranchResult>;
105}
106
107// Implement Branch for async closures
108#[async_trait]
109impl<S, F, Fut> Branch<S> for F
110where
111    S: State,
112    F: Fn(&S) -> Fut + std::marker::Send + Sync,
113    Fut: Future<Output = Result<BranchResult>> + std::marker::Send,
114{
115    async fn route(&self, state: &S) -> Result<BranchResult> {
116        self(state).await
117    }
118}
119
120/// Type alias for boxed branches
121pub type BranchBox<S> = Box<dyn Branch<S>>;
122
123/// Helper to create a branch from a closure
124pub fn branch_fn<S, F, Fut>(f: F) -> impl Branch<S>
125where
126    S: State,
127    F: Fn(&S) -> Fut + std::marker::Send + Sync + 'static,
128    Fut: Future<Output = Result<BranchResult>> + std::marker::Send + 'static,
129{
130    f
131}
132
/// A simple branch that always routes to the same node, regardless of state.
///
/// Derives the common value traits so it can be logged, duplicated and
/// compared like any other piece of graph configuration.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StaticBranch {
    /// Name of the node every evaluation routes to.
    target: String,
}
137
138impl StaticBranch {
139    /// Create a new static branch
140    pub fn new(target: impl Into<String>) -> Self {
141        Self {
142            target: target.into(),
143        }
144    }
145}
146
147#[async_trait]
148impl<S: State> Branch<S> for StaticBranch {
149    async fn route(&self, _state: &S) -> Result<BranchResult> {
150        Ok(BranchResult::Single(self.target.clone()))
151    }
152}
153
/// A branch that always ends execution, regardless of state.
///
/// A stateless unit struct, so it is trivially `Copy` and `Default`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct EndBranch;
156
157#[async_trait]
158impl<S: State> Branch<S> for EndBranch {
159    async fn route(&self, _state: &S) -> Result<BranchResult> {
160        Ok(BranchResult::End)
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::DictState;

    #[test]
    fn test_branch_result_creation() {
        let single = BranchResult::single("node1");
        assert!(matches!(single, BranchResult::Single(_)));
        assert!(!single.is_end());
        assert_eq!(single.node_names(), vec!["node1"]);

        let multiple = BranchResult::multiple(vec!["node1", "node2"]);
        assert!(matches!(multiple, BranchResult::Multiple(_)));
        assert!(!multiple.is_end());
        assert_eq!(multiple.node_names(), vec!["node1", "node2"]);

        let end = BranchResult::end();
        assert!(end.is_end());
        assert!(end.node_names().is_empty());
    }

    #[tokio::test]
    async fn test_static_branch() {
        let branch = StaticBranch::new("target");
        let state = DictState::new();
        let result = branch.route(&state).await.unwrap();

        assert!(matches!(result, BranchResult::Single(_)));
        assert_eq!(result.node_names(), vec!["target"]);
    }

    #[tokio::test]
    async fn test_end_branch() {
        let branch = EndBranch;
        let state = DictState::new();
        let result = branch.route(&state).await.unwrap();

        assert!(result.is_end());
    }

    #[tokio::test]
    async fn test_branch_closure() {
        let branch = |_state: &DictState| async {
            Ok(BranchResult::single("dynamic_node"))
        };

        let state = DictState::new();
        let result = branch.route(&state).await.unwrap();

        assert_eq!(result.node_names(), vec!["dynamic_node"]);
    }

    #[tokio::test]
    async fn test_branch_fn_helper() {
        // Routing through `branch_fn` must behave like calling the closure directly.
        let branch = branch_fn(|_state: &DictState| async {
            let decision: Result<BranchResult> =
                Ok(BranchResult::multiple(vec!["left", "right"]));
            decision
        });

        let state = DictState::new();
        let result = branch.route(&state).await.unwrap();

        assert_eq!(result.node_names(), vec!["left", "right"]);
    }

    #[tokio::test]
    async fn test_boxed_branches() {
        // Different branch types behind `BranchBox` are dispatched dynamically.
        let static_branch: BranchBox<DictState> = Box::new(StaticBranch::new("target"));
        let end_branch: BranchBox<DictState> = Box::new(EndBranch);
        let state = DictState::new();

        let first = static_branch.route(&state).await.unwrap();
        assert_eq!(first.node_names(), vec!["target"]);

        let second = end_branch.route(&state).await.unwrap();
        assert!(second.is_end());
    }

    #[tokio::test]
    async fn test_branch_with_send() {
        let sends = vec![
            SendType::new("process", serde_json::json!({"id": 1})),
            SendType::new("process", serde_json::json!({"id": 2})),
        ];

        let result = BranchResult::send(sends);
        assert!(matches!(&result, BranchResult::Send(targets) if targets.len() == 2));
        assert!(!result.is_end());
        // `node_names` only reports statically named nodes, never `Send` targets.
        assert!(result.node_names().is_empty());
    }
}