1use std::pin::Pin;
2
3use async_trait::async_trait;
4use futures::Stream;
5use synaptic_core::{RunnableConfig, SynapticError};
6
7pub type RunnableOutputStream<'a, O> =
9 Pin<Box<dyn Stream<Item = Result<O, SynapticError>> + Send + 'a>>;
10
11#[async_trait]
17pub trait Runnable<I, O>: Send + Sync
18where
19 I: Send + 'static,
20 O: Send + 'static,
21{
22 async fn invoke(&self, input: I, config: &RunnableConfig) -> Result<O, SynapticError>;
24
25 async fn batch(
27 &self,
28 inputs: Vec<I>,
29 config: &RunnableConfig,
30 ) -> Vec<Result<O, SynapticError>> {
31 let mut results = Vec::with_capacity(inputs.len());
32 for input in inputs {
33 results.push(self.invoke(input, config).await);
34 }
35 results
36 }
37
38 fn stream<'a>(&'a self, input: I, config: &'a RunnableConfig) -> RunnableOutputStream<'a, O>
41 where
42 I: 'a,
43 {
44 Box::pin(async_stream::stream! {
45 match self.invoke(input, config).await {
46 Ok(output) => yield Ok(output),
47 Err(e) => yield Err(e),
48 }
49 })
50 }
51
52 fn boxed(self) -> BoxRunnable<I, O>
54 where
55 Self: Sized + 'static,
56 {
57 BoxRunnable {
58 inner: Box::new(self),
59 }
60 }
61}
62
63trait RunnableStream<I: Send + 'static, O: Send + 'static>: Runnable<I, O> {
65 fn stream_boxed<'a>(
66 &'a self,
67 input: I,
68 config: &'a RunnableConfig,
69 ) -> RunnableOutputStream<'a, O>
70 where
71 I: 'a;
72}
73
74impl<I: Send + 'static, O: Send + 'static, T: Runnable<I, O>> RunnableStream<I, O> for T {
75 fn stream_boxed<'a>(
76 &'a self,
77 input: I,
78 config: &'a RunnableConfig,
79 ) -> RunnableOutputStream<'a, O>
80 where
81 I: 'a,
82 {
83 self.stream(input, config)
84 }
85}
86
87pub struct BoxRunnable<I: Send + 'static, O: Send + 'static> {
94 inner: Box<dyn RunnableStream<I, O>>,
95}
96
97impl<I: Send + 'static, O: Send + 'static> BoxRunnable<I, O> {
98 pub fn new<R: Runnable<I, O> + 'static>(runnable: R) -> Self {
99 Self {
100 inner: Box::new(runnable),
101 }
102 }
103
104 pub fn stream<'a>(
106 &'a self,
107 input: I,
108 config: &'a RunnableConfig,
109 ) -> RunnableOutputStream<'a, O> {
110 self.inner.stream_boxed(input, config)
111 }
112
113 pub fn bind(
116 self,
117 transform: impl Fn(RunnableConfig) -> RunnableConfig + Send + Sync + 'static,
118 ) -> BoxRunnable<I, O> {
119 BoxRunnable::new(RunnableBind {
120 inner: self,
121 config_transform: Box::new(transform),
122 })
123 }
124
125 pub fn with_config(self, config: RunnableConfig) -> BoxRunnable<I, O> {
128 self.bind(move |_| config.clone())
129 }
130
131 pub fn with_listeners(
133 self,
134 on_start: impl Fn(&RunnableConfig) + Send + Sync + 'static,
135 on_end: impl Fn(&RunnableConfig) + Send + Sync + 'static,
136 ) -> BoxRunnable<I, O> {
137 BoxRunnable::new(RunnableWithListeners {
138 inner: self,
139 on_start: Box::new(on_start),
140 on_end: Box::new(on_end),
141 })
142 }
143}
144
145impl<I: Send + 'static, O: Send + 'static> BoxRunnable<Vec<I>, Vec<O>> {
146 pub fn map_each(inner: BoxRunnable<I, O>) -> BoxRunnable<Vec<I>, Vec<O>> {
152 BoxRunnable::new(crate::each::RunnableEach::new(inner))
153 }
154}
155
156#[async_trait]
157impl<I: Send + 'static, O: Send + 'static> Runnable<I, O> for BoxRunnable<I, O> {
158 async fn invoke(&self, input: I, config: &RunnableConfig) -> Result<O, SynapticError> {
159 self.inner.invoke(input, config).await
160 }
161
162 async fn batch(
163 &self,
164 inputs: Vec<I>,
165 config: &RunnableConfig,
166 ) -> Vec<Result<O, SynapticError>> {
167 self.inner.batch(inputs, config).await
168 }
169
170 fn stream<'a>(&'a self, input: I, config: &'a RunnableConfig) -> RunnableOutputStream<'a, O>
171 where
172 I: 'a,
173 {
174 self.inner.stream_boxed(input, config)
175 }
176}
177
178struct RunnableBind<I: Send + 'static, O: Send + 'static> {
180 inner: BoxRunnable<I, O>,
181 config_transform: Box<dyn Fn(RunnableConfig) -> RunnableConfig + Send + Sync>,
182}
183
184#[async_trait]
185impl<I: Send + 'static, O: Send + 'static> Runnable<I, O> for RunnableBind<I, O> {
186 async fn invoke(&self, input: I, config: &RunnableConfig) -> Result<O, SynapticError> {
187 let transformed = (self.config_transform)(config.clone());
188 self.inner.invoke(input, &transformed).await
189 }
190
191 fn stream<'a>(&'a self, input: I, config: &'a RunnableConfig) -> RunnableOutputStream<'a, O>
192 where
193 I: 'a,
194 {
195 Box::pin(async_stream::stream! {
196 let transformed = (self.config_transform)(config.clone());
197 let mut inner_stream = std::pin::pin!(self.inner.stream(input, &transformed));
198 use futures::StreamExt;
199 while let Some(item) = inner_stream.next().await {
200 yield item;
201 }
202 })
203 }
204}
205
206struct RunnableWithListeners<I: Send + 'static, O: Send + 'static> {
208 inner: BoxRunnable<I, O>,
209 on_start: Box<dyn Fn(&RunnableConfig) + Send + Sync>,
210 on_end: Box<dyn Fn(&RunnableConfig) + Send + Sync>,
211}
212
213#[async_trait]
214impl<I: Send + 'static, O: Send + 'static> Runnable<I, O> for RunnableWithListeners<I, O> {
215 async fn invoke(&self, input: I, config: &RunnableConfig) -> Result<O, SynapticError> {
216 (self.on_start)(config);
217 let result = self.inner.invoke(input, config).await;
218 (self.on_end)(config);
219 result
220 }
221
222 fn stream<'a>(&'a self, input: I, config: &'a RunnableConfig) -> RunnableOutputStream<'a, O>
223 where
224 I: 'a,
225 {
226 Box::pin(async_stream::stream! {
227 (self.on_start)(config);
228 use futures::StreamExt;
229 let mut s = std::pin::pin!(self.inner.stream(input, config));
230 while let Some(item) = s.next().await {
231 yield item;
232 }
233 (self.on_end)(config);
234 })
235 }
236}