1use async_trait::async_trait;
2use std::marker::PhantomData;
3
4use futures::stream::{self, BoxStream};
5use futures::{StreamExt, TryStreamExt};
6
7use crate::{Retrying, Runnable, StreamEvent, WesichainError};
8
/// Sequential composition of two runnables: `head` runs first and its output
/// (of type `Mid`) is handed to `tail` as input.
///
/// No value of type `Mid` is stored; the [`PhantomData`] marker only records
/// the intermediate type so the `Runnable` impl can name it.
pub struct Chain<Head, Tail, Mid> {
    /// Runnable executed first.
    head: Head,
    /// Runnable that receives `head`'s output.
    tail: Tail,
    /// Zero-sized marker for the intermediate type `Mid`.
    _marker: PhantomData<Mid>,
}

impl<Head, Tail, Mid> Chain<Head, Tail, Mid> {
    /// Builds a chain that runs `head` and then `tail`.
    pub fn new(head: Head, tail: Tail) -> Self {
        Chain {
            _marker: PhantomData,
            head,
            tail,
        }
    }
}
24
25#[async_trait]
26impl<Input, Mid, Output, Head, Tail> Runnable<Input, Output> for Chain<Head, Tail, Mid>
27where
28 Input: Send + 'static,
29 Mid: Send + Sync + 'static,
30 Output: Send + 'static,
31 Head: Runnable<Input, Mid> + Send + Sync,
32 Tail: Runnable<Mid, Output> + Send + Sync,
33{
34 async fn invoke(&self, input: Input) -> Result<Output, WesichainError> {
35 let mid = self.head.invoke(input).await?;
36 self.tail.invoke(mid).await
37 }
38
39 fn stream(&self, input: Input) -> BoxStream<'_, Result<StreamEvent, WesichainError>> {
40 let head = &self.head;
42 let tail = &self.tail;
43 let stream = stream::once(async move { head.invoke(input).await })
44 .map_ok(move |mid| tail.stream(mid))
45 .try_flatten();
46 stream.boxed()
47 }
48
49 fn to_serializable(&self) -> Option<crate::serde::SerializableRunnable> {
50 let head_ser = self.head.to_serializable()?;
51 let tail_ser = self.tail.to_serializable()?;
52
53 let mut steps = Vec::new();
55 match head_ser {
56 crate::serde::SerializableRunnable::Chain { steps: mut s } => steps.append(&mut s),
57 _ => steps.push(head_ser),
58 }
59
60 match tail_ser {
64 crate::serde::SerializableRunnable::Chain { steps: mut s } => steps.append(&mut s),
65 _ => steps.push(tail_ser),
66 }
67
68 Some(crate::serde::SerializableRunnable::Chain { steps })
69 }
70}
71
72pub trait RunnableExt<Input: Send + 'static, Output: Send + 'static>:
73 Runnable<Input, Output> + Sized
74{
75 fn then<NextOutput, Next>(self, next: Next) -> Chain<Self, Next, Output>
76 where
77 Next: Runnable<Output, NextOutput> + Send + Sync,
78 NextOutput: Send + 'static,
79 {
80 Chain::new(self, next)
81 }
82
83 fn with_retries(self, max_attempts: usize) -> Retrying<Self>
84 where
85 Self: Send + Sync,
86 Input: Clone,
87 {
88 Retrying::new(self, max_attempts)
89 }
90
91 fn bind(self, args: crate::Value) -> crate::RunnableBinding<Self, Input, Output>
92 where
93 Self: Send + Sync,
94 Input: crate::Bindable + Clone + Send + 'static,
95 Output: Send + Sync + 'static,
96 {
97 crate::RunnableBinding::new(self, args)
98 }
99
100 fn with_fallbacks(
101 self,
102 fallbacks: Vec<std::sync::Arc<dyn Runnable<Input, Output> + Send + Sync>>,
103 ) -> crate::RunnableWithFallbacks<Input, Output>
104 where
105 Self: Send + Sync + 'static,
106 Input: Clone + Send + 'static,
107 {
108 crate::RunnableWithFallbacks::new(std::sync::Arc::new(self), fallbacks)
109 }
110
111 fn with_timeout(self, duration: std::time::Duration) -> crate::TimeLimited<Self>
112 where
113 Self: Send + Sync,
114 Input: Clone,
115 {
116 crate::TimeLimited::new(self, duration)
117 }
118
119 fn with_rate_limit(self, requests_per_minute: u32) -> crate::RateLimited<Self>
120 where
121 Self: Send + Sync,
122 Input: Clone,
123 {
124 crate::RateLimited::new(self, requests_per_minute)
125 }
126}
127
128impl<Input: Send + 'static, Output: Send + 'static, T> RunnableExt<Input, Output> for T where
129 T: Runnable<Input, Output> + Sized
130{
131}
132
133use crate::Value;
134use std::sync::Arc;
135
136pub struct RuntimeChain {
139 steps: Vec<Arc<dyn Runnable<Value, Value>>>,
140}
141
142impl RuntimeChain {
143 pub fn new(steps: Vec<Arc<dyn Runnable<Value, Value>>>) -> Self {
144 Self { steps }
145 }
146}
147
148#[async_trait]
149impl Runnable<Value, Value> for RuntimeChain {
150 async fn invoke(&self, input: Value) -> Result<Value, WesichainError> {
151 let mut current = input;
152 for step in &self.steps {
153 current = step.invoke(current).await?;
154 }
155 Ok(current)
156 }
157
158 fn stream<'a>(&'a self, input: Value) -> BoxStream<'a, Result<StreamEvent, WesichainError>> {
159 if self.steps.is_empty() {
160 return stream::empty().boxed();
161 }
162
163 let steps = &self.steps;
164
165 let s = async move {
166 let mut current = input;
167 let last_idx = steps.len() - 1;
168 for (i, step) in steps.iter().enumerate() {
169 if i == last_idx {
170 break;
171 }
172 current = step.invoke(current).await?;
173 }
174 Ok::<Value, WesichainError>(current)
175 };
176
177 stream::once(s)
178 .map_ok(move |val| steps.last().unwrap().stream(val))
179 .try_flatten()
180 .boxed()
181 }
182
183 fn to_serializable(&self) -> Option<crate::serde::SerializableRunnable> {
184 let mut steps = Vec::new();
185 for step in &self.steps {
186 steps.push(step.to_serializable()?);
187 }
188 Some(crate::serde::SerializableRunnable::Chain { steps })
189 }
190}