wesichain_core/
runnable_parallel.rs1use std::collections::BTreeMap;
2use std::sync::Arc;
3
4use crate::{Runnable, StreamEvent, WesichainError};
5use async_trait::async_trait;
6use futures::future::join_all;
7use futures::stream::{self, BoxStream, StreamExt};
8
9pub struct RunnableParallel<Input, Output> {
10 steps: BTreeMap<String, Arc<dyn Runnable<Input, Output> + Send + Sync>>,
11}
12
13impl<Input, Output> RunnableParallel<Input, Output> {
14 pub fn new(steps: BTreeMap<String, Arc<dyn Runnable<Input, Output> + Send + Sync>>) -> Self {
15 Self { steps }
16 }
17}
18
19#[async_trait]
20impl<Input, Output> Runnable<Input, BTreeMap<String, Output>> for RunnableParallel<Input, Output>
21where
22 Input: Clone + Send + Sync + 'static,
23 Output: Send + Sync + 'static,
24{
25 async fn invoke(&self, input: Input) -> Result<BTreeMap<String, Output>, WesichainError> {
26 let mut keys = Vec::new();
27 let mut futures = Vec::new();
28
29 for (key, step) in &self.steps {
30 keys.push(key.clone());
31 futures.push(step.invoke(input.clone()));
32 }
33
34 let results = join_all(futures).await;
35
36 let mut output = BTreeMap::new();
37 for (key, result) in keys.into_iter().zip(results) {
38 output.insert(key, result?);
39 }
40
41 Ok(output)
42 }
43
44 fn stream<'a>(&'a self, input: Input) -> BoxStream<'a, Result<StreamEvent, WesichainError>> {
45 let streams: Vec<BoxStream<'a, Result<StreamEvent, WesichainError>>> = self
48 .steps
49 .iter()
50 .map(|(key, step)| {
51 let branch_name = key.clone();
52 let metadata_event = Ok(StreamEvent::Metadata {
53 key: "parallel_step".to_string(),
54 value: serde_json::json!(branch_name),
55 });
56 stream::once(std::future::ready(metadata_event))
57 .chain(step.stream(input.clone()))
58 .boxed()
59 })
60 .collect();
61
62 if streams.is_empty() {
63 return stream::empty().boxed();
64 }
65
66 futures::stream::select_all(streams).boxed()
67 }
68
69 fn to_serializable(&self) -> Option<crate::serde::SerializableRunnable> {
70 let mut steps = std::collections::HashMap::new();
71 for (key, step) in &self.steps {
72 steps.insert(key.clone(), step.to_serializable()?);
73 }
74 Some(crate::serde::SerializableRunnable::Parallel { steps })
75 }
76}
77
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double that ignores its input: `invoke` returns the wrapped string
    /// and `stream` yields it as a single `ContentChunk`.
    struct ConstRunnable(String);

    #[async_trait]
    impl Runnable<String, String> for ConstRunnable {
        async fn invoke(&self, _input: String) -> Result<String, WesichainError> {
            Ok(self.0.clone())
        }

        fn stream<'a>(
            &'a self,
            _input: String,
        ) -> BoxStream<'a, Result<StreamEvent, WesichainError>> {
            stream::iter(vec![Ok(StreamEvent::ContentChunk(self.0.clone()))]).boxed()
        }
    }

    #[tokio::test]
    async fn test_parallel_invoke_two_branches() {
        let mut steps: BTreeMap<String, Arc<dyn Runnable<String, String> + Send + Sync>> =
            BTreeMap::new();
        steps.insert("a".to_string(), Arc::new(ConstRunnable("hello".to_string())));
        steps.insert("b".to_string(), Arc::new(ConstRunnable("world".to_string())));
        let parallel = RunnableParallel::new(steps);
        let result = parallel.invoke("input".to_string()).await.unwrap();
        assert_eq!(result.get("a").unwrap(), "hello");
        assert_eq!(result.get("b").unwrap(), "world");
    }

    #[tokio::test]
    async fn test_parallel_invoke_three_branches() {
        let mut steps: BTreeMap<String, Arc<dyn Runnable<String, String> + Send + Sync>> =
            BTreeMap::new();
        steps.insert("x".to_string(), Arc::new(ConstRunnable("1".to_string())));
        steps.insert("y".to_string(), Arc::new(ConstRunnable("2".to_string())));
        steps.insert("z".to_string(), Arc::new(ConstRunnable("3".to_string())));
        let parallel = RunnableParallel::new(steps);
        let result = parallel.invoke("input".to_string()).await.unwrap();
        assert_eq!(result.len(), 3);
        assert_eq!(result.get("x").unwrap(), "1");
        assert_eq!(result.get("y").unwrap(), "2");
        assert_eq!(result.get("z").unwrap(), "3");
    }

    /// With no branches, `invoke` must succeed with an empty map (mirrors the
    /// empty-stream test below).
    #[tokio::test]
    async fn test_parallel_invoke_empty() {
        let steps: BTreeMap<String, Arc<dyn Runnable<String, String> + Send + Sync>> =
            BTreeMap::new();
        let parallel = RunnableParallel::new(steps);
        let result = parallel.invoke("input".to_string()).await.unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn test_parallel_stream_emits_from_all_branches() {
        let mut steps: BTreeMap<String, Arc<dyn Runnable<String, String> + Send + Sync>> =
            BTreeMap::new();
        steps.insert("a".to_string(), Arc::new(ConstRunnable("hello".to_string())));
        steps.insert("b".to_string(), Arc::new(ConstRunnable("world".to_string())));
        let parallel = RunnableParallel::new(steps);
        let events: Vec<_> = parallel.stream("input".to_string()).collect().await;

        assert_eq!(events.len(), 4);

        let metadata_count = events
            .iter()
            .filter(|e| matches!(e, Ok(StreamEvent::Metadata { key, .. }) if key == "parallel_step"))
            .count();
        assert_eq!(metadata_count, 2);

        let content_count = events
            .iter()
            .filter(|e| matches!(e, Ok(StreamEvent::ContentChunk(_))))
            .count();
        assert_eq!(content_count, 2);
    }

    /// Each `parallel_step` metadata event must carry its branch name as a
    /// JSON string. Names are sorted because merged streams may interleave.
    #[tokio::test]
    async fn test_parallel_stream_metadata_names_each_branch() {
        let mut steps: BTreeMap<String, Arc<dyn Runnable<String, String> + Send + Sync>> =
            BTreeMap::new();
        steps.insert("a".to_string(), Arc::new(ConstRunnable("hello".to_string())));
        steps.insert("b".to_string(), Arc::new(ConstRunnable("world".to_string())));
        let parallel = RunnableParallel::new(steps);
        let events: Vec<_> = parallel.stream("input".to_string()).collect().await;

        let mut branch_names: Vec<String> = events
            .iter()
            .filter_map(|e| match e {
                Ok(StreamEvent::Metadata { key, value }) if key == "parallel_step" => {
                    value.as_str().map(str::to_owned)
                }
                _ => None,
            })
            .collect();
        branch_names.sort();
        assert_eq!(branch_names, vec!["a".to_string(), "b".to_string()]);
    }

    #[tokio::test]
    async fn test_parallel_stream_empty() {
        let steps: BTreeMap<String, Arc<dyn Runnable<String, String> + Send + Sync>> =
            BTreeMap::new();
        let parallel = RunnableParallel::new(steps);
        let events: Vec<_> = parallel.stream("input".to_string()).collect().await;
        assert!(events.is_empty());
    }
}