// wesichain_core/chain.rs

1use async_trait::async_trait;
2use std::marker::PhantomData;
3
4use futures::stream::{self, BoxStream};
5use futures::{StreamExt, TryStreamExt};
6
7use crate::{Retrying, Runnable, StreamEvent, WesichainError};
8
/// Sequential composition of two runnables: `head` runs first and its output (of type `Mid`)
/// becomes the input of `tail`.
///
/// `Mid` only appears in the type so that the `Runnable` impl can name the intermediate type;
/// no `Mid` value is ever stored. The marker is therefore `PhantomData<fn() -> Mid>` rather than
/// `PhantomData<Mid>`. That form is still covariant in `Mid`, but `Chain` neither inherits
/// `Mid`'s `Send`/`Sync` status nor claims to own a `Mid` for drop-check purposes.
pub struct Chain<Head, Tail, Mid> {
    /// First stage; receives the chain's input.
    head: Head,
    /// Second stage; receives the output of `head`.
    tail: Tail,
    _marker: PhantomData<fn() -> Mid>,
}
14
15impl<Head, Tail, Mid> Chain<Head, Tail, Mid> {
16    pub fn new(head: Head, tail: Tail) -> Self {
17        Self {
18            head,
19            tail,
20            _marker: PhantomData,
21        }
22    }
23}
24
25#[async_trait]
26impl<Input, Mid, Output, Head, Tail> Runnable<Input, Output> for Chain<Head, Tail, Mid>
27where
28    Input: Send + 'static,
29    Mid: Send + Sync + 'static,
30    Output: Send + 'static,
31    Head: Runnable<Input, Mid> + Send + Sync,
32    Tail: Runnable<Mid, Output> + Send + Sync,
33{
34    async fn invoke(&self, input: Input) -> Result<Output, WesichainError> {
35        let mid = self.head.invoke(input).await?;
36        self.tail.invoke(mid).await
37    }
38
39    fn stream(&self, input: Input) -> BoxStream<'_, Result<StreamEvent, WesichainError>> {
40        // v0: streaming reflects the tail runnable only; the head is executed via invoke.
41        let head = &self.head;
42        let tail = &self.tail;
43        let stream = stream::once(async move { head.invoke(input).await })
44            .map_ok(move |mid| tail.stream(mid))
45            .try_flatten();
46        stream.boxed()
47    }
48
49    fn to_serializable(&self) -> Option<crate::serde::SerializableRunnable> {
50        let head_ser = self.head.to_serializable()?;
51        let tail_ser = self.tail.to_serializable()?;
52
53        // Attempt to flatten if head is also a Chain
54        let mut steps = Vec::new();
55        match head_ser {
56            crate::serde::SerializableRunnable::Chain { steps: mut s } => steps.append(&mut s),
57            _ => steps.push(head_ser),
58        }
59
60        // Same for tail?? No, tail is just the next step.
61        // Actually, Chain<A, Chain<B, C>> is A -> B -> C.
62        // So if tail is a chain, we append its steps.
63        match tail_ser {
64            crate::serde::SerializableRunnable::Chain { steps: mut s } => steps.append(&mut s),
65            _ => steps.push(tail_ser),
66        }
67
68        Some(crate::serde::SerializableRunnable::Chain { steps })
69    }
70}
71
72pub trait RunnableExt<Input: Send + 'static, Output: Send + 'static>:
73    Runnable<Input, Output> + Sized
74{
75    fn then<NextOutput, Next>(self, next: Next) -> Chain<Self, Next, Output>
76    where
77        Next: Runnable<Output, NextOutput> + Send + Sync,
78        NextOutput: Send + 'static,
79    {
80        Chain::new(self, next)
81    }
82
83    fn with_retries(self, max_attempts: usize) -> Retrying<Self>
84    where
85        Self: Send + Sync,
86        Input: Clone,
87    {
88        Retrying::new(self, max_attempts)
89    }
90
91    fn bind(self, args: crate::Value) -> crate::RunnableBinding<Self, Input, Output>
92    where
93        Self: Send + Sync,
94        Input: crate::Bindable + Clone + Send + 'static,
95        Output: Send + Sync + 'static,
96    {
97        crate::RunnableBinding::new(self, args)
98    }
99
100    fn with_fallbacks(
101        self,
102        fallbacks: Vec<std::sync::Arc<dyn Runnable<Input, Output> + Send + Sync>>,
103    ) -> crate::RunnableWithFallbacks<Input, Output>
104    where
105        Self: Send + Sync + 'static,
106        Input: Clone + Send + 'static,
107    {
108        crate::RunnableWithFallbacks::new(std::sync::Arc::new(self), fallbacks)
109    }
110
111    fn with_timeout(self, duration: std::time::Duration) -> crate::TimeLimited<Self>
112    where
113        Self: Send + Sync,
114        Input: Clone,
115    {
116        crate::TimeLimited::new(self, duration)
117    }
118
119    fn with_rate_limit(self, requests_per_minute: u32) -> crate::RateLimited<Self>
120    where
121        Self: Send + Sync,
122        Input: Clone,
123    {
124        crate::RateLimited::new(self, requests_per_minute)
125    }
126}
127
128impl<Input: Send + 'static, Output: Send + 'static, T> RunnableExt<Input, Output> for T where
129    T: Runnable<Input, Output> + Sized
130{
131}
132
133use crate::Value;
134use std::sync::Arc;
135
136/// A runtime-constructed chain that operates on `Value`.
137/// This is used for deserialization where types are not known at compile time.
138pub struct RuntimeChain {
139    steps: Vec<Arc<dyn Runnable<Value, Value>>>,
140}
141
142impl RuntimeChain {
143    pub fn new(steps: Vec<Arc<dyn Runnable<Value, Value>>>) -> Self {
144        Self { steps }
145    }
146}
147
148#[async_trait]
149impl Runnable<Value, Value> for RuntimeChain {
150    async fn invoke(&self, input: Value) -> Result<Value, WesichainError> {
151        let mut current = input;
152        for step in &self.steps {
153            current = step.invoke(current).await?;
154        }
155        Ok(current)
156    }
157
158    fn stream<'a>(&'a self, input: Value) -> BoxStream<'a, Result<StreamEvent, WesichainError>> {
159        if self.steps.is_empty() {
160            return stream::empty().boxed();
161        }
162
163        let steps = &self.steps;
164
165        let s = async move {
166            let mut current = input;
167            let last_idx = steps.len() - 1;
168            for (i, step) in steps.iter().enumerate() {
169                if i == last_idx {
170                    break;
171                }
172                current = step.invoke(current).await?;
173            }
174            Ok::<Value, WesichainError>(current)
175        };
176
177        stream::once(s)
178            .map_ok(move |val| steps.last().unwrap().stream(val))
179            .try_flatten()
180            .boxed()
181    }
182
183    fn to_serializable(&self) -> Option<crate::serde::SerializableRunnable> {
184        let mut steps = Vec::new();
185        for step in &self.steps {
186            steps.push(step.to_serializable()?);
187        }
188        Some(crate::serde::SerializableRunnable::Chain { steps })
189    }
190}