Skip to main content

wesichain_core/
runnable_parallel.rs

1use std::collections::BTreeMap;
2use std::sync::Arc;
3
4use crate::{Runnable, StreamEvent, WesichainError};
5use async_trait::async_trait;
6use futures::future::join_all;
7use futures::stream::{self, BoxStream, StreamExt};
8
9pub struct RunnableParallel<Input, Output> {
10    steps: BTreeMap<String, Arc<dyn Runnable<Input, Output> + Send + Sync>>,
11}
12
13impl<Input, Output> RunnableParallel<Input, Output> {
14    pub fn new(steps: BTreeMap<String, Arc<dyn Runnable<Input, Output> + Send + Sync>>) -> Self {
15        Self { steps }
16    }
17}
18
19#[async_trait]
20impl<Input, Output> Runnable<Input, BTreeMap<String, Output>> for RunnableParallel<Input, Output>
21where
22    Input: Clone + Send + Sync + 'static,
23    Output: Send + Sync + 'static,
24{
25    async fn invoke(&self, input: Input) -> Result<BTreeMap<String, Output>, WesichainError> {
26        let mut keys = Vec::new();
27        let mut futures = Vec::new();
28
29        for (key, step) in &self.steps {
30            keys.push(key.clone());
31            futures.push(step.invoke(input.clone()));
32        }
33
34        let results = join_all(futures).await;
35
36        let mut output = BTreeMap::new();
37        for (key, result) in keys.into_iter().zip(results) {
38            output.insert(key, result?);
39        }
40
41        Ok(output)
42    }
43
44    fn stream<'a>(&'a self, input: Input) -> BoxStream<'a, Result<StreamEvent, WesichainError>> {
45        // Fan-out: each branch emits a Metadata event tagging its name, then its own stream.
46        // All branch streams are merged with select_all for interleaved output.
47        let streams: Vec<BoxStream<'a, Result<StreamEvent, WesichainError>>> = self
48            .steps
49            .iter()
50            .map(|(key, step)| {
51                let branch_name = key.clone();
52                let metadata_event = Ok(StreamEvent::Metadata {
53                    key: "parallel_step".to_string(),
54                    value: serde_json::json!(branch_name),
55                });
56                stream::once(std::future::ready(metadata_event))
57                    .chain(step.stream(input.clone()))
58                    .boxed()
59            })
60            .collect();
61
62        if streams.is_empty() {
63            return stream::empty().boxed();
64        }
65
66        futures::stream::select_all(streams).boxed()
67    }
68
69    fn to_serializable(&self) -> Option<crate::serde::SerializableRunnable> {
70        let mut steps = std::collections::HashMap::new();
71        for (key, step) in &self.steps {
72            steps.insert(key.clone(), step.to_serializable()?);
73        }
74        Some(crate::serde::SerializableRunnable::Parallel { steps })
75    }
76}
77
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double that ignores its input and always yields the same string.
    struct ConstRunnable(String);

    #[async_trait]
    impl Runnable<String, String> for ConstRunnable {
        async fn invoke(&self, _input: String) -> Result<String, WesichainError> {
            Ok(self.0.clone())
        }

        fn stream<'a>(
            &'a self,
            _input: String,
        ) -> BoxStream<'a, Result<StreamEvent, WesichainError>> {
            stream::iter(vec![Ok(StreamEvent::ContentChunk(self.0.clone()))]).boxed()
        }
    }

    /// Builds a `RunnableParallel` with one `ConstRunnable` branch per
    /// `(branch_name, constant_output)` pair.
    fn parallel_of(pairs: &[(&str, &str)]) -> RunnableParallel<String, String> {
        let steps: BTreeMap<String, Arc<dyn Runnable<String, String> + Send + Sync>> = pairs
            .iter()
            .map(|&(name, value)| {
                let branch: Arc<dyn Runnable<String, String> + Send + Sync> =
                    Arc::new(ConstRunnable(value.to_string()));
                (name.to_string(), branch)
            })
            .collect();
        RunnableParallel::new(steps)
    }

    /// Returns the branch names carried by `parallel_step` metadata events,
    /// sorted so the result does not depend on `select_all` interleaving.
    fn announced_branches(events: &[Result<StreamEvent, WesichainError>]) -> Vec<String> {
        let mut names: Vec<String> = events
            .iter()
            .filter_map(|event| match event {
                Ok(StreamEvent::Metadata { key, value }) if key == "parallel_step" => {
                    value.as_str().map(str::to_string)
                }
                _ => None,
            })
            .collect();
        names.sort();
        names
    }

    #[tokio::test]
    async fn test_parallel_invoke_two_branches() {
        let parallel = parallel_of(&[("a", "hello"), ("b", "world")]);
        let result = parallel.invoke("input".to_string()).await.unwrap();
        assert_eq!(result.get("a").unwrap(), "hello");
        assert_eq!(result.get("b").unwrap(), "world");
    }

    #[tokio::test]
    async fn test_parallel_invoke_three_branches() {
        let parallel = parallel_of(&[("x", "1"), ("y", "2"), ("z", "3")]);
        let result = parallel.invoke("input".to_string()).await.unwrap();
        assert_eq!(result.len(), 3);
        assert_eq!(result.get("x").unwrap(), "1");
        assert_eq!(result.get("y").unwrap(), "2");
        assert_eq!(result.get("z").unwrap(), "3");
    }

    #[tokio::test]
    async fn test_parallel_invoke_empty() {
        // With no branches there is nothing to join; the result is an empty map.
        let parallel = parallel_of(&[]);
        let result = parallel.invoke("input".to_string()).await.unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn test_parallel_stream_emits_from_all_branches() {
        let parallel = parallel_of(&[("a", "hello"), ("b", "world")]);
        let events: Vec<_> = parallel.stream("input".to_string()).collect().await;

        // Each branch emits 1 Metadata + 1 ContentChunk → 4 total events
        assert_eq!(events.len(), 4);

        // Every branch announces itself exactly once, tagged with its own name.
        assert_eq!(
            announced_branches(&events),
            vec!["a".to_string(), "b".to_string()]
        );

        // Every branch's content arrives, in whatever order select_all yields it.
        let mut chunks: Vec<&str> = events
            .iter()
            .filter_map(|event| match event {
                Ok(StreamEvent::ContentChunk(text)) => Some(text.as_str()),
                _ => None,
            })
            .collect();
        chunks.sort();
        assert_eq!(chunks, vec!["hello", "world"]);
    }

    #[tokio::test]
    async fn test_parallel_stream_empty() {
        let parallel = parallel_of(&[]);
        let events: Vec<_> = parallel.stream("input".to_string()).collect().await;
        assert!(events.is_empty());
    }
}