Skip to main content

synwire_core/runnables/
branch.rs

//! Conditional branching runnable: routes each input to the first runnable whose condition matches.
2
3use std::sync::Arc;
4
5use crate::BoxFuture;
6use crate::error::SynwireError;
7use crate::runnables::config::RunnableConfig;
8use crate::runnables::core::RunnableCore;
9use serde_json::Value;
10
11/// Type alias for a branch condition function.
12type ConditionFn = Arc<dyn Fn(&Value) -> bool + Send + Sync>;
13
14/// Routes input to different runnables based on conditions.
15///
16/// Evaluates conditions in order; the first branch whose condition
17/// returns `true` is invoked. If no conditions match, the default
18/// runnable is invoked.
19///
20/// # Example
21///
22/// ```rust,no_run
23/// # use std::sync::Arc;
24/// # use synwire_core::runnables::{RunnableBranch, RunnablePassthrough, RunnableCore};
25/// let branch = RunnableBranch::new(
26///     vec![
27///         (Arc::new(|v: &serde_json::Value| v.is_number()), Box::new(RunnablePassthrough) as Box<dyn RunnableCore>),
28///     ],
29///     Box::new(RunnablePassthrough),
30/// );
31/// ```
32pub struct RunnableBranch {
33    branches: Vec<(ConditionFn, Box<dyn RunnableCore>)>,
34    default: Box<dyn RunnableCore>,
35}
36
37impl RunnableBranch {
38    /// Create a new branch with condition-runnable pairs and a default.
39    pub fn new(
40        branches: Vec<(ConditionFn, Box<dyn RunnableCore>)>,
41        default: Box<dyn RunnableCore>,
42    ) -> Self {
43        Self { branches, default }
44    }
45}
46
47impl RunnableCore for RunnableBranch {
48    fn invoke<'a>(
49        &'a self,
50        input: Value,
51        config: Option<&'a RunnableConfig>,
52    ) -> BoxFuture<'a, Result<Value, SynwireError>> {
53        Box::pin(async move {
54            for (condition, runnable) in &self.branches {
55                if condition(&input) {
56                    return runnable.invoke(input, config).await;
57                }
58            }
59            self.default.invoke(input, config).await
60        })
61    }
62
63    #[allow(clippy::unnecessary_literal_bound)]
64    fn name(&self) -> &str {
65        "RunnableBranch"
66    }
67}
68
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::runnables::lambda::RunnableLambda;
    use crate::runnables::passthrough::RunnablePassthrough;

    #[tokio::test]
    async fn test_branch_routes_correctly() {
        let is_number: ConditionFn = Arc::new(|v: &Value| v.is_number());
        let double = RunnableLambda::new(|v: Value| {
            Box::pin(async move {
                let n = v.as_i64().unwrap() * 2;
                Ok(Value::from(n))
            })
        });

        let branch = RunnableBranch::new(
            vec![(is_number, Box::new(double) as Box<dyn RunnableCore>)],
            Box::new(RunnablePassthrough),
        );

        // Number input should be doubled.
        let result = branch.invoke(Value::from(5), None).await.unwrap();
        assert_eq!(result, Value::from(10));

        // String input should pass through via default.
        let result = branch.invoke(Value::from("hello"), None).await.unwrap();
        assert_eq!(result, Value::from("hello"));
    }

    #[tokio::test]
    async fn test_branch_default_when_no_match() {
        let never_true: ConditionFn = Arc::new(|_: &Value| false);
        let branch = RunnableBranch::new(
            vec![(
                never_true,
                Box::new(RunnablePassthrough) as Box<dyn RunnableCore>,
            )],
            Box::new(RunnableLambda::new(|_| {
                Box::pin(async { Ok(Value::from("default")) })
            })),
        );

        let result = branch.invoke(Value::from(1), None).await.unwrap();
        assert_eq!(result, Value::from("default"));
    }

    /// When several conditions match, only the earliest branch may run.
    #[tokio::test]
    async fn test_branch_first_match_wins() {
        let first_cond: ConditionFn = Arc::new(|_: &Value| true);
        let second_cond: ConditionFn = Arc::new(|_: &Value| true);
        let branch = RunnableBranch::new(
            vec![
                (
                    first_cond,
                    Box::new(RunnableLambda::new(|_| {
                        Box::pin(async { Ok(Value::from("first")) })
                    })) as Box<dyn RunnableCore>,
                ),
                (
                    second_cond,
                    Box::new(RunnableLambda::new(|_| {
                        Box::pin(async { Ok(Value::from("second")) })
                    })) as Box<dyn RunnableCore>,
                ),
            ],
            Box::new(RunnablePassthrough),
        );

        let result = branch.invoke(Value::from(0), None).await.unwrap();
        assert_eq!(result, Value::from("first"));
    }

    /// With no branches configured, every input goes to the default runnable.
    #[tokio::test]
    async fn test_branch_empty_uses_default() {
        let branch = RunnableBranch::new(vec![], Box::new(RunnablePassthrough));

        let result = branch.invoke(Value::from("x"), None).await.unwrap();
        assert_eq!(result, Value::from("x"));
    }
}