// rust_langgraph/pregel/branch.rs
use std::future::Future;

use async_trait::async_trait;

use crate::errors::Result;
use crate::state::State;
use crate::types::Send as SendType;
10
11#[derive(Debug, Clone)]
16pub enum BranchResult {
17 Single(String),
19
20 Multiple(Vec<String>),
22
23 Send(Vec<SendType>),
25
26 End,
28}
29
30impl BranchResult {
31 pub fn single(node: impl Into<String>) -> Self {
33 BranchResult::Single(node.into())
34 }
35
36 pub fn multiple(nodes: Vec<impl Into<String>>) -> Self {
38 BranchResult::Multiple(nodes.into_iter().map(|n| n.into()).collect())
39 }
40
41 pub fn send(sends: Vec<SendType>) -> Self {
43 BranchResult::Send(sends)
44 }
45
46 pub fn end() -> Self {
48 BranchResult::End
49 }
50
51 pub fn is_end(&self) -> bool {
53 matches!(self, BranchResult::End)
54 }
55
56 pub fn node_names(&self) -> Vec<String> {
58 match self {
59 BranchResult::Single(name) => vec![name.clone()],
60 BranchResult::Multiple(names) => names.clone(),
61 BranchResult::Send(_) => vec![],
62 BranchResult::End => vec![],
63 }
64 }
65}
66
67#[async_trait]
91pub trait Branch<S>: std::marker::Send + Sync
92where
93 S: State,
94{
95 async fn route(&self, state: &S) -> Result<BranchResult>;
105}
106
107#[async_trait]
109impl<S, F, Fut> Branch<S> for F
110where
111 S: State,
112 F: Fn(&S) -> Fut + std::marker::Send + Sync,
113 Fut: Future<Output = Result<BranchResult>> + std::marker::Send,
114{
115 async fn route(&self, state: &S) -> Result<BranchResult> {
116 self(state).await
117 }
118}
119
120pub type BranchBox<S> = Box<dyn Branch<S>>;
122
123pub fn branch_fn<S, F, Fut>(f: F) -> impl Branch<S>
125where
126 S: State,
127 F: Fn(&S) -> Fut + std::marker::Send + Sync + 'static,
128 Fut: Future<Output = Result<BranchResult>> + std::marker::Send + 'static,
129{
130 f
131}
132
133pub struct StaticBranch {
135 target: String,
136}
137
138impl StaticBranch {
139 pub fn new(target: impl Into<String>) -> Self {
141 Self {
142 target: target.into(),
143 }
144 }
145}
146
147#[async_trait]
148impl<S: State> Branch<S> for StaticBranch {
149 async fn route(&self, _state: &S) -> Result<BranchResult> {
150 Ok(BranchResult::Single(self.target.clone()))
151 }
152}
153
154pub struct EndBranch;
156
157#[async_trait]
158impl<S: State> Branch<S> for EndBranch {
159 async fn route(&self, _state: &S) -> Result<BranchResult> {
160 Ok(BranchResult::End)
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::DictState;

    #[test]
    fn test_branch_result_creation() {
        let single = BranchResult::single("node1");
        assert!(matches!(single, BranchResult::Single(_)));
        assert_eq!(single.node_names(), vec!["node1"]);
        assert!(!single.is_end());

        let multiple = BranchResult::multiple(vec!["node1", "node2"]);
        assert!(matches!(multiple, BranchResult::Multiple(_)));
        assert_eq!(multiple.node_names(), vec!["node1", "node2"]);
        assert!(!multiple.is_end());

        let end = BranchResult::end();
        assert!(end.is_end());
        assert!(end.node_names().is_empty());
    }

    #[tokio::test]
    async fn test_static_branch() {
        let branch = StaticBranch::new("target");
        let state = DictState::new();
        let result = branch.route(&state).await.unwrap();

        assert!(matches!(result, BranchResult::Single(_)));
        assert_eq!(result.node_names(), vec!["target"]);
    }

    #[tokio::test]
    async fn test_end_branch() {
        let branch = EndBranch;
        let state = DictState::new();
        let result = branch.route(&state).await.unwrap();

        assert!(result.is_end());
    }

    #[tokio::test]
    async fn test_branch_closure() {
        let branch = |_state: &DictState| async { Ok(BranchResult::single("dynamic_node")) };

        let state = DictState::new();
        let result = branch.route(&state).await.unwrap();

        assert_eq!(result.node_names(), vec!["dynamic_node"]);
    }

    #[tokio::test]
    async fn test_branch_fn_and_boxed_branch() {
        let state = DictState::new();

        let wrapped = branch_fn(|_state: &DictState| async { Ok(BranchResult::end()) });
        assert!(wrapped.route(&state).await.unwrap().is_end());

        let boxed: BranchBox<DictState> = Box::new(StaticBranch::new("boxed"));
        assert_eq!(boxed.route(&state).await.unwrap().node_names(), vec!["boxed"]);
    }

    // Pure constructor check: nothing is awaited, so no async runtime is needed.
    #[test]
    fn test_branch_with_send() {
        let sends = vec![
            SendType::new("process", serde_json::json!({"id": 1})),
            SendType::new("process", serde_json::json!({"id": 2})),
        ];

        let result = BranchResult::send(sends);
        assert!(matches!(&result, BranchResult::Send(packets) if packets.len() == 2));
        assert!(!result.is_end());
        // Send packets are not reported as by-name routing targets.
        assert!(result.node_names().is_empty());
    }
}