synwire_core/runnables/
chain.rs1use crate::BoxFuture;
4use crate::error::SynwireError;
5use crate::runnables::config::RunnableConfig;
6use crate::runnables::core::RunnableCore;
7use serde_json::Value;
8
9pub struct RunnableSequence {
23 steps: Vec<Box<dyn RunnableCore>>,
24 name: Option<String>,
25}
26
27impl RunnableSequence {
28 pub fn new(steps: Vec<Box<dyn RunnableCore>>) -> Self {
30 Self { steps, name: None }
31 }
32
33 #[must_use]
35 pub fn with_name(mut self, name: impl Into<String>) -> Self {
36 self.name = Some(name.into());
37 self
38 }
39}
40
41impl RunnableCore for RunnableSequence {
42 fn invoke<'a>(
43 &'a self,
44 input: Value,
45 config: Option<&'a RunnableConfig>,
46 ) -> BoxFuture<'a, Result<Value, SynwireError>> {
47 Box::pin(async move {
48 let mut current = input;
49 for step in &self.steps {
50 current = step.invoke(current, config).await?;
51 }
52 Ok(current)
53 })
54 }
55
56 fn name(&self) -> &str {
57 self.name.as_deref().unwrap_or("RunnableSequence")
58 }
59}
60
61pub fn pipe(first: Box<dyn RunnableCore>, second: Box<dyn RunnableCore>) -> RunnableSequence {
65 RunnableSequence::new(vec![first, second])
66}
67
68pub struct RunnableParallel {
83 steps: Vec<(String, Box<dyn RunnableCore>)>,
84}
85
86impl RunnableParallel {
87 pub fn new(steps: Vec<(String, Box<dyn RunnableCore>)>) -> Self {
89 Self { steps }
90 }
91}
92
93impl RunnableCore for RunnableParallel {
94 fn invoke<'a>(
95 &'a self,
96 input: Value,
97 config: Option<&'a RunnableConfig>,
98 ) -> BoxFuture<'a, Result<Value, SynwireError>> {
99 Box::pin(async move {
100 let futures: Vec<_> = self
101 .steps
102 .iter()
103 .map(|(name, runnable)| {
104 let input_clone = input.clone();
105 let name = name.clone();
106 async move {
107 let result = runnable.invoke(input_clone, config).await?;
108 Ok::<_, SynwireError>((name, result))
109 }
110 })
111 .collect();
112
113 let results = futures_util::future::try_join_all(futures).await?;
114 let mut map = serde_json::Map::new();
115 for (name, value) in results {
116 let _replaced = map.insert(name, value);
117 }
118 Ok(Value::Object(map))
119 })
120 }
121
122 #[allow(clippy::unnecessary_literal_bound)]
123 fn name(&self) -> &str {
124 "RunnableParallel"
125 }
126}
127
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::runnables::lambda::RunnableLambda;
    use crate::runnables::passthrough::RunnablePassthrough;

    /// Boxed step that adds one to an integer input.
    fn add_one() -> Box<dyn RunnableCore> {
        Box::new(RunnableLambda::new(|v: Value| {
            Box::pin(async move {
                let n = v.as_i64().unwrap() + 1;
                Ok(Value::from(n))
            })
        }))
    }

    /// Boxed step that doubles an integer input.
    fn times_two() -> Box<dyn RunnableCore> {
        Box::new(RunnableLambda::new(|v: Value| {
            Box::pin(async move {
                let n = v.as_i64().unwrap() * 2;
                Ok(Value::from(n))
            })
        }))
    }

    #[tokio::test]
    async fn test_runnable_sequence() {
        let seq = RunnableSequence::new(vec![add_one(), times_two()]);
        let output = seq.invoke(Value::from(5), None).await.unwrap();
        // (5 + 1) * 2
        assert_eq!(output, Value::from(12));
    }

    #[tokio::test]
    async fn test_pipe_composes() {
        let seq = pipe(add_one(), times_two());
        let output = seq.invoke(Value::from(10), None).await.unwrap();
        // (10 + 1) * 2
        assert_eq!(output, Value::from(22));
    }

    #[tokio::test]
    async fn test_runnable_parallel() {
        let passthrough: Box<dyn RunnableCore> = Box::new(RunnablePassthrough);
        let par = RunnableParallel::new(vec![
            (String::from("doubled"), times_two()),
            (String::from("original"), passthrough),
        ]);

        let output = par.invoke(Value::from(5), None).await.unwrap();
        let obj = output.as_object().unwrap();
        assert_eq!(obj.get("doubled"), Some(&Value::from(10)));
        assert_eq!(obj.get("original"), Some(&Value::from(5)));
    }

    #[tokio::test]
    async fn test_sequence_name_default() {
        let unnamed = RunnableSequence::new(Vec::new());
        assert_eq!(unnamed.name(), "RunnableSequence");
    }

    #[tokio::test]
    async fn test_sequence_name_custom() {
        let named = RunnableSequence::new(Vec::new()).with_name("my_chain");
        assert_eq!(named.name(), "my_chain");
    }
}