synwire_core/runnables/
branch.rs1use std::sync::Arc;
4
5use crate::BoxFuture;
6use crate::error::SynwireError;
7use crate::runnables::config::RunnableConfig;
8use crate::runnables::core::RunnableCore;
9use serde_json::Value;
10
11type ConditionFn = Arc<dyn Fn(&Value) -> bool + Send + Sync>;
13
14pub struct RunnableBranch {
33 branches: Vec<(ConditionFn, Box<dyn RunnableCore>)>,
34 default: Box<dyn RunnableCore>,
35}
36
37impl RunnableBranch {
38 pub fn new(
40 branches: Vec<(ConditionFn, Box<dyn RunnableCore>)>,
41 default: Box<dyn RunnableCore>,
42 ) -> Self {
43 Self { branches, default }
44 }
45}
46
47impl RunnableCore for RunnableBranch {
48 fn invoke<'a>(
49 &'a self,
50 input: Value,
51 config: Option<&'a RunnableConfig>,
52 ) -> BoxFuture<'a, Result<Value, SynwireError>> {
53 Box::pin(async move {
54 for (condition, runnable) in &self.branches {
55 if condition(&input) {
56 return runnable.invoke(input, config).await;
57 }
58 }
59 self.default.invoke(input, config).await
60 })
61 }
62
63 #[allow(clippy::unnecessary_literal_bound)]
64 fn name(&self) -> &str {
65 "RunnableBranch"
66 }
67}
68
#[cfg(test)]
// Unwrapping is acceptable in tests: an unexpected error should fail the test loudly.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::runnables::lambda::RunnableLambda;
    use crate::runnables::passthrough::RunnablePassthrough;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Numbers are doubled by the matching branch; everything else passes through.
    #[tokio::test]
    async fn test_branch_routes_correctly() {
        let is_number: ConditionFn = Arc::new(|v: &Value| v.is_number());
        let double = RunnableLambda::new(|v: Value| {
            Box::pin(async move {
                let n = v.as_i64().unwrap() * 2;
                Ok(Value::from(n))
            })
        });

        let branch = RunnableBranch::new(
            vec![(is_number, Box::new(double) as Box<dyn RunnableCore>)],
            Box::new(RunnablePassthrough),
        );

        let result = branch.invoke(Value::from(5), None).await.unwrap();
        assert_eq!(result, Value::from(10));

        let result = branch.invoke(Value::from("hello"), None).await.unwrap();
        assert_eq!(result, Value::from("hello"));
    }

    /// When no condition matches, the default runnable produces the result.
    #[tokio::test]
    async fn test_branch_default_when_no_match() {
        let never_true: ConditionFn = Arc::new(|_: &Value| false);
        let branch = RunnableBranch::new(
            vec![(
                never_true,
                Box::new(RunnablePassthrough) as Box<dyn RunnableCore>,
            )],
            Box::new(RunnableLambda::new(|_| {
                Box::pin(async { Ok(Value::from("default")) })
            })),
        );

        let result = branch.invoke(Value::from(1), None).await.unwrap();
        assert_eq!(result, Value::from("default"));
    }

    /// When several conditions match, only the first branch runs and later
    /// predicates are not evaluated.
    #[tokio::test]
    async fn test_branch_first_match_wins() {
        let second_checked = Arc::new(AtomicBool::new(false));
        let second_flag = Arc::clone(&second_checked);

        let always: ConditionFn = Arc::new(|_: &Value| true);
        let also_always: ConditionFn = Arc::new(move |_: &Value| {
            second_flag.store(true, Ordering::SeqCst);
            true
        });

        let branch = RunnableBranch::new(
            vec![
                (
                    always,
                    Box::new(RunnableLambda::new(|_| {
                        Box::pin(async { Ok(Value::from("first")) })
                    })) as Box<dyn RunnableCore>,
                ),
                (
                    also_always,
                    Box::new(RunnableLambda::new(|_| {
                        Box::pin(async { Ok(Value::from("second")) })
                    })) as Box<dyn RunnableCore>,
                ),
            ],
            Box::new(RunnablePassthrough),
        );

        let result = branch.invoke(Value::from(0), None).await.unwrap();
        assert_eq!(result, Value::from("first"));
        assert!(!second_checked.load(Ordering::SeqCst));
    }

    /// With no branches at all, every input goes straight to the default runnable.
    #[tokio::test]
    async fn test_branch_empty_uses_default() {
        let branch = RunnableBranch::new(Vec::new(), Box::new(RunnablePassthrough));

        let result = branch.invoke(Value::from("x"), None).await.unwrap();
        assert_eq!(result, Value::from("x"));
    }
}