1use std::pin::Pin;
2
3use async_trait::async_trait;
4use futures::Stream;
5use synaptic_core::{RunnableConfig, SynapseError};
6
7pub type RunnableOutputStream<'a, O> =
9 Pin<Box<dyn Stream<Item = Result<O, SynapseError>> + Send + 'a>>;
10
11#[async_trait]
17pub trait Runnable<I, O>: Send + Sync
18where
19 I: Send + 'static,
20 O: Send + 'static,
21{
22 async fn invoke(&self, input: I, config: &RunnableConfig) -> Result<O, SynapseError>;
24
25 async fn batch(&self, inputs: Vec<I>, config: &RunnableConfig) -> Vec<Result<O, SynapseError>> {
27 let mut results = Vec::with_capacity(inputs.len());
28 for input in inputs {
29 results.push(self.invoke(input, config).await);
30 }
31 results
32 }
33
34 fn stream<'a>(&'a self, input: I, config: &'a RunnableConfig) -> RunnableOutputStream<'a, O>
37 where
38 I: 'a,
39 {
40 Box::pin(async_stream::stream! {
41 match self.invoke(input, config).await {
42 Ok(output) => yield Ok(output),
43 Err(e) => yield Err(e),
44 }
45 })
46 }
47
48 fn boxed(self) -> BoxRunnable<I, O>
50 where
51 Self: Sized + 'static,
52 {
53 BoxRunnable {
54 inner: Box::new(self),
55 }
56 }
57}
58
59trait RunnableStream<I: Send + 'static, O: Send + 'static>: Runnable<I, O> {
61 fn stream_boxed<'a>(
62 &'a self,
63 input: I,
64 config: &'a RunnableConfig,
65 ) -> RunnableOutputStream<'a, O>
66 where
67 I: 'a;
68}
69
70impl<I: Send + 'static, O: Send + 'static, T: Runnable<I, O>> RunnableStream<I, O> for T {
71 fn stream_boxed<'a>(
72 &'a self,
73 input: I,
74 config: &'a RunnableConfig,
75 ) -> RunnableOutputStream<'a, O>
76 where
77 I: 'a,
78 {
79 self.stream(input, config)
80 }
81}
82
83pub struct BoxRunnable<I: Send + 'static, O: Send + 'static> {
90 inner: Box<dyn RunnableStream<I, O>>,
91}
92
93impl<I: Send + 'static, O: Send + 'static> BoxRunnable<I, O> {
94 pub fn new<R: Runnable<I, O> + 'static>(runnable: R) -> Self {
95 Self {
96 inner: Box::new(runnable),
97 }
98 }
99
100 pub fn stream<'a>(
102 &'a self,
103 input: I,
104 config: &'a RunnableConfig,
105 ) -> RunnableOutputStream<'a, O> {
106 self.inner.stream_boxed(input, config)
107 }
108
109 pub fn bind(
112 self,
113 transform: impl Fn(RunnableConfig) -> RunnableConfig + Send + Sync + 'static,
114 ) -> BoxRunnable<I, O> {
115 BoxRunnable::new(RunnableBind {
116 inner: self,
117 config_transform: Box::new(transform),
118 })
119 }
120
121 pub fn with_config(self, config: RunnableConfig) -> BoxRunnable<I, O> {
124 self.bind(move |_| config.clone())
125 }
126
127 pub fn with_listeners(
129 self,
130 on_start: impl Fn(&RunnableConfig) + Send + Sync + 'static,
131 on_end: impl Fn(&RunnableConfig) + Send + Sync + 'static,
132 ) -> BoxRunnable<I, O> {
133 BoxRunnable::new(RunnableWithListeners {
134 inner: self,
135 on_start: Box::new(on_start),
136 on_end: Box::new(on_end),
137 })
138 }
139}
140
141impl<I: Send + 'static, O: Send + 'static> BoxRunnable<Vec<I>, Vec<O>> {
142 pub fn map_each(inner: BoxRunnable<I, O>) -> BoxRunnable<Vec<I>, Vec<O>> {
148 BoxRunnable::new(crate::each::RunnableEach::new(inner))
149 }
150}
151
152#[async_trait]
153impl<I: Send + 'static, O: Send + 'static> Runnable<I, O> for BoxRunnable<I, O> {
154 async fn invoke(&self, input: I, config: &RunnableConfig) -> Result<O, SynapseError> {
155 self.inner.invoke(input, config).await
156 }
157
158 async fn batch(&self, inputs: Vec<I>, config: &RunnableConfig) -> Vec<Result<O, SynapseError>> {
159 self.inner.batch(inputs, config).await
160 }
161
162 fn stream<'a>(&'a self, input: I, config: &'a RunnableConfig) -> RunnableOutputStream<'a, O>
163 where
164 I: 'a,
165 {
166 self.inner.stream_boxed(input, config)
167 }
168}
169
170struct RunnableBind<I: Send + 'static, O: Send + 'static> {
172 inner: BoxRunnable<I, O>,
173 config_transform: Box<dyn Fn(RunnableConfig) -> RunnableConfig + Send + Sync>,
174}
175
176#[async_trait]
177impl<I: Send + 'static, O: Send + 'static> Runnable<I, O> for RunnableBind<I, O> {
178 async fn invoke(&self, input: I, config: &RunnableConfig) -> Result<O, SynapseError> {
179 let transformed = (self.config_transform)(config.clone());
180 self.inner.invoke(input, &transformed).await
181 }
182
183 fn stream<'a>(&'a self, input: I, config: &'a RunnableConfig) -> RunnableOutputStream<'a, O>
184 where
185 I: 'a,
186 {
187 Box::pin(async_stream::stream! {
188 let transformed = (self.config_transform)(config.clone());
189 let mut inner_stream = std::pin::pin!(self.inner.stream(input, &transformed));
190 use futures::StreamExt;
191 while let Some(item) = inner_stream.next().await {
192 yield item;
193 }
194 })
195 }
196}
197
198struct RunnableWithListeners<I: Send + 'static, O: Send + 'static> {
200 inner: BoxRunnable<I, O>,
201 on_start: Box<dyn Fn(&RunnableConfig) + Send + Sync>,
202 on_end: Box<dyn Fn(&RunnableConfig) + Send + Sync>,
203}
204
205#[async_trait]
206impl<I: Send + 'static, O: Send + 'static> Runnable<I, O> for RunnableWithListeners<I, O> {
207 async fn invoke(&self, input: I, config: &RunnableConfig) -> Result<O, SynapseError> {
208 (self.on_start)(config);
209 let result = self.inner.invoke(input, config).await;
210 (self.on_end)(config);
211 result
212 }
213
214 fn stream<'a>(&'a self, input: I, config: &'a RunnableConfig) -> RunnableOutputStream<'a, O>
215 where
216 I: 'a,
217 {
218 Box::pin(async_stream::stream! {
219 (self.on_start)(config);
220 use futures::StreamExt;
221 let mut s = std::pin::pin!(self.inner.stream(input, config));
222 while let Some(item) = s.next().await {
223 yield item;
224 }
225 (self.on_end)(config);
226 })
227 }
228}