1use std::{future::Future, marker::PhantomData};
2
3#[cfg(not(target_arch = "wasm32"))]
4use futures::future::BoxFuture;
5#[cfg(target_arch = "wasm32")]
6use futures::future::LocalBoxFuture;
7use thiserror::Error;
8
9pub mod infer;
10pub mod loader;
11pub mod model;
12pub mod softmax;
13pub mod v4;
14pub mod v5;
15pub mod v6;
16pub mod v7;
17
/// Lightweight description of the job needed to process the next chunk of an input.
///
/// The runtimes obtain these from the input's `IntoIterator` impl and use them both to
/// dispatch jobs and to decide whether an already-dispatched job can be reused.
pub trait JobInfo
where
    Self: 'static + Send + Sync + Clone,
{
    /// Returns whether a job that was dispatched for `info` can serve `self`.
    ///
    /// The Tokio runtime calls this as `current.check(&queued_key)` when matching a new
    /// submission against speculatively dispatched jobs.
    fn check(&self, info: &Self) -> bool;
}
24
/// A stateful input that is consumed one chunk per job.
///
/// The runtimes take `chunk()` before running a job and call `step()` only after that
/// job's output has been read back, then hand the stepped input back to the caller.
pub trait JobInput
where
    Self: 'static + Send + Sync,
{
    /// The data loaded into a job via `Job::load`.
    type Chunk: 'static + Send + Sync;

    /// Produces the chunk for the current position of this input.
    fn chunk(&self) -> Self::Chunk;

    /// Moves this input forward once the current chunk's job has completed
    /// (presumably past the chunk just returned by `chunk`; depends on the implementor).
    fn step(&mut self);
}
34
/// A unit of work built by a [`Dispatcher`].
///
/// Both runtimes in this module drive a job in the same order: `load` the input chunk,
/// `submit`, then consume the job with `back` and await its output.
pub trait Job {
    /// The input type whose chunks this job loads.
    type Input: JobInput;
    /// The value produced by awaiting [`Job::back`].
    type Output: Send + Sync + 'static;

    /// Loads one chunk of input (as returned by [`JobInput::chunk`]) into the job.
    fn load(&self, input: &<Self::Input as JobInput>::Chunk) -> Result<(), RuntimeError>;
    /// Submits the loaded work; called by the runtimes only after `load` succeeded.
    fn submit(&mut self);
    /// Consumes the job and returns a future resolving to its output.
    ///
    /// On non-wasm targets the future must be `Send` so the Tokio runtime can await it on
    /// a spawned task.
    #[cfg(not(target_arch = "wasm32"))]
    fn back(self) -> impl Future<Output = Result<Self::Output, RuntimeError>> + Send;
    /// wasm32 variant of `back`: the returned future is not required to be `Send`.
    #[cfg(target_arch = "wasm32")]
    fn back(self) -> impl Future<Output = Result<Self::Output, RuntimeError>>;
}
51
52pub trait Dispatcher<J: Job> {
53 type Info;
54
55 fn dispatch(&self, info: Self::Info) -> Result<J, RuntimeError>;
58}
59
/// One request sent from `TokioRuntime::infer` to the background dispatch loop.
#[cfg(all(not(target_arch = "wasm32"), feature = "tokio"))]
// The reply channel's payload type is a nested `Result<(..), ..>`; clippy flags it as complex.
#[allow(clippy::type_complexity)]
#[derive(Debug)]
struct Submission<I: infer::Infer> {
    /// The input to run one job for.
    input: I::Input,
    /// Reply channel (created with capacity 1 by `infer`): the loop sends either the
    /// stepped input with its output, or the error that stopped the job.
    sender: flume::Sender<Result<(I::Input, I::Output), RuntimeError>>,
}
67
/// Errors reported by the runtimes while dispatching, loading, or reading back a job.
#[derive(Debug, Error)]
pub enum RuntimeError {
    /// The input's job-info iterator yielded nothing, so there was no job to run.
    #[error("input iterator exhausted")]
    InputExhausted,
    /// Wraps an error from `crate::tensor`.
    #[error("tensor error")]
    TensorError(#[from] crate::tensor::TensorError),
    /// Receiving on a `flume` channel failed; in this module, the reply channel of
    /// `TokioRuntime::infer` was disconnected before a result arrived.
    #[error("recv error")]
    RecvError(#[from] flume::RecvError),
    /// Awaiting a spawned Tokio task (e.g. a blocking dispatch) failed because the task
    /// panicked or was cancelled.
    #[cfg(all(not(target_arch = "wasm32"), feature = "tokio"))]
    #[error("join error")]
    JoinError(#[from] tokio::task::JoinError),
}
80
/// Handle to a pipelined runtime whose dispatch loop runs as a background Tokio task.
///
/// The handle only holds the submission channel, so clones are cheap and all of them
/// submit to the same background loop.
#[cfg(all(not(target_arch = "wasm32"), feature = "tokio"))]
#[derive(Debug)]
pub struct TokioRuntime<I: infer::Infer>(flume::Sender<Submission<I>>);

// Implemented by hand: `#[derive(Clone)]` would add an `I: Clone` bound, although only the
// channel sender (cloneable for any payload) is actually cloned.
#[cfg(all(not(target_arch = "wasm32"), feature = "tokio"))]
impl<I: infer::Infer> Clone for TokioRuntime<I> {
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}
84
// The job queue and reply types nest several generics; clippy flags them as complex.
#[cfg(all(not(target_arch = "wasm32"), feature = "tokio"))]
#[allow(clippy::type_complexity)]
impl<I, T, F> TokioRuntime<I>
where
    I: infer::Infer,
    T: JobInfo,
    F: Iterator<Item = T> + Send + 'static,
    for<'a> &'a I::Input: IntoIterator<Item = T, IntoIter = F>,
{
    /// Spawns the background dispatch loop for `bundle` and returns a handle to it.
    ///
    /// `bundle` is moved into an `Arc` shared by the loop and its blocking dispatch tasks.
    /// A second task awaits the loop and logs the error if the loop panics or is cancelled.
    ///
    /// # Panics
    /// `tokio::spawn` panics when called outside a Tokio runtime.
    pub async fn new<M, J>(bundle: M) -> Self
    where
        M: Dispatcher<J, Info = T> + Send + Sync + 'static,
        J: Job<Input = I::Input, Output = I::Output> + Send + 'static,
    {
        // Capacity 1: at most one submission is buffered; further `infer` calls wait in
        // `send_async` until the loop takes it.
        let (sender, receiver) = flume::bounded(1);
        let handle = tokio::spawn(Self::run(bundle.into(), receiver));
        tokio::spawn(async move {
            if let Err(err) = handle.await {
                log::error!("{err}");
            }
        });
        Self(sender)
    }

    /// The dispatch loop: handles submissions one at a time, in arrival order.
    ///
    /// Dispatching (`Dispatcher::dispatch`) runs on the blocking pool and is speculated
    /// ahead: besides the job for the current info, the loop dispatches jobs for the next
    /// infos yielded by the current input's iterator. A later submission whose info
    /// `check`s against a queued key reuses that in-flight dispatch instead of starting a
    /// new one. Each job's readback (`Job::back`) is awaited on its own task, so the loop
    /// can move on to the next submission immediately after `submit`.
    ///
    /// NOTE(review): termination of the inner `loop` assumes `info.check(&info)` is true
    /// for the first info of an input; otherwise no candidate is ever found.
    async fn run<M, J>(model: std::sync::Arc<M>, receiver: flume::Receiver<Submission<I>>)
    where
        M: Dispatcher<J, Info = T> + Send + Sync + 'static,
        J: Job<Input = I::Input, Output = I::Output> + Send + 'static,
    {
        // Dispatches in flight, keyed by the info they were dispatched for.
        let mut queue: Vec<(T, tokio::task::JoinHandle<Result<J, RuntimeError>>)> = vec![];
        // Iterator over the infos being predicted; reset whenever a prediction misses.
        let mut iter: Option<F> = None;
        // Number of new dispatches to launch this round; cycles 2 -> 1 -> 0 -> 2.
        let mut predict: usize = 0;

        'main: while let Ok(Submission { input, sender }) = receiver.recv_async().await {
            // An input whose iterator is already empty has no job to run.
            let Some(info) = (&input).into_iter().next() else {
                let _ = sender.send(Err(RuntimeError::InputExhausted));
                continue 'main;
            };

            let chunk = input.chunk();

            let mut job = loop {
                // Partition the queue against the current info: non-matching jobs queued
                // before the first match are stale and get aborted; non-matching jobs after
                // the first match are kept for later submissions.
                let mut candidates = vec![];
                let mut remain = vec![];
                for (key, handle) in queue {
                    match (candidates.is_empty(), info.check(&key)) {
                        (true, false) => handle.abort(),
                        (false, false) => remain.push((key, handle)),
                        (_, true) => candidates.push(handle),
                    }
                }
                queue = remain;

                predict = match predict {
                    2 => 1,
                    1 => 0,
                    0 => 2,
                    _ => unreachable!(),
                };

                // Prediction missed (or this is the first submission): restart prediction
                // from this input's first info.
                if candidates.is_empty() || iter.is_none() {
                    iter = Some((&input).into_iter());
                    predict = 2;
                }
                let iter = iter
                    .as_mut()
                    .expect("prediction iterator is initialized just above");

                for upcoming in iter.take(predict) {
                    #[cfg(feature = "trace")]
                    tracing::event!(
                        tracing::Level::TRACE,
                        "launch ({queue}, {candidates}, {predict})",
                        queue = queue.len(),
                        candidates = candidates.len(),
                        predict = predict
                    );

                    let key = upcoming.clone();
                    let model = model.clone();
                    let handle = tokio::task::spawn_blocking(move || model.dispatch(key));
                    queue.push((upcoming, handle));
                }

                if !candidates.is_empty() {
                    // Use whichever matching dispatch finishes first.
                    let (job, _, remain) = futures::future::select_all(candidates).await;
                    // NOTE(review): unused matches are re-keyed with the current `info`
                    // rather than the key they were dispatched for — confirm intended.
                    let mut remain = remain
                        .into_iter()
                        .map(|handle| (info.clone(), handle))
                        .collect();
                    // Put the unused matches ahead of the dispatches launched above.
                    std::mem::swap(&mut queue, &mut remain);
                    queue.append(&mut remain);

                    break match job {
                        Ok(Ok(job)) => job,
                        Ok(Err(error)) => {
                            let _ = sender.send(Err(error));
                            continue 'main;
                        }
                        Err(error) => {
                            let _ = sender.send(Err(error.into()));
                            continue 'main;
                        }
                    };
                }
            };

            if let Err(error) = job.load(&chunk) {
                let _ = sender.send(Err(error));
                continue 'main;
            }

            #[cfg(feature = "trace")]
            let _span = tracing::trace_span!("submit").entered();
            job.submit();

            // Await the readback on its own task so the loop can accept the next submission.
            tokio::spawn(async move {
                let output = job.back().await;
                let mut input = input;
                input.step();
                // The caller may have stopped waiting; a dropped receiver is not an error here.
                let _ = sender.send(output.map(|output| (input, output)));
            });
        }
    }

    /// Submits `input` to the dispatch loop and waits for the result.
    ///
    /// On success returns the input (already stepped) together with the job's output.
    ///
    /// # Errors
    /// - `RuntimeError::InputExhausted` if `input` yields no job info.
    /// - Errors from `dispatch`, `load`, or `back`.
    /// - `RuntimeError::JoinError` if the blocking dispatch task failed.
    /// - `RuntimeError::RecvError` if the loop is gone before replying.
    pub async fn infer(&self, input: I::Input) -> Result<(I::Input, I::Output), RuntimeError> {
        let (sender, receiver) = flume::bounded(1);
        let submission = Submission { input, sender };
        // If the loop has shut down, the failed send drops `submission` (and with it the
        // only reply sender), so the `recv_async` below reports it as `RecvError`.
        let _ = self.0.send_async(submission).await;
        receiver.recv_async().await?
    }
}
219
/// A runtime that runs each submission inline on the caller's task (dispatch, load,
/// submit, await readback) with no speculative dispatching.
///
/// `I` and `J` only appear as [`PhantomData`], so `Clone` and `Debug` are implemented by
/// hand: deriving them would needlessly require the inference type `I` and the job type
/// `J` to implement those traits as well.
pub struct SimpleRuntime<M, I, J>(M, PhantomData<(I, J)>);

impl<M: Clone, I, J> Clone for SimpleRuntime<M, I, J> {
    fn clone(&self) -> Self {
        Self(self.0.clone(), PhantomData)
    }
}

impl<M: std::fmt::Debug, I, J> std::fmt::Debug for SimpleRuntime<M, I, J> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same output as `#[derive(Debug)]`; `PhantomData<T>` is `Debug` for every `T`.
        f.debug_tuple("SimpleRuntime")
            .field(&self.0)
            .field(&self.1)
            .finish()
    }
}
222
223impl<M, I, J> SimpleRuntime<M, I, J> {
224 #[inline]
225 pub fn new<T, F>(bundle: M) -> Self
226 where
227 M: Dispatcher<J, Info = T>,
228 I: infer::Infer,
229 J: Job<Input = I::Input, Output = I::Output>,
230 T: JobInfo,
231 F: Iterator<Item = T> + Send + 'static,
232 for<'a> &'a I::Input: IntoIterator<Item = T, IntoIter = F>,
233 {
234 Self(bundle, PhantomData)
235 }
236
237 pub async fn infer<T, F>(
238 &self,
239 mut input: I::Input,
240 ) -> Result<(I::Input, I::Output), RuntimeError>
241 where
242 M: Dispatcher<J, Info = T>,
243 I: infer::Infer,
244 J: Job<Input = I::Input, Output = I::Output>,
245 T: JobInfo,
246 F: Iterator<Item = T> + Send + 'static,
247 for<'a> &'a I::Input: IntoIterator<Item = T, IntoIter = F>,
248 {
249 let Some(info) = (&input).into_iter().next() else {
250 return Err(RuntimeError::InputExhausted);
251 };
252 let chunk = input.chunk();
253
254 let mut job = self.0.dispatch(info)?;
255 job.load(&chunk)?;
256 job.submit();
257
258 let output = job.back().await?;
259 input.step();
260
261 Ok((input, output))
262 }
263}
264
/// Object-safe interface over the runtimes in this module (`SimpleRuntime` and
/// `TokioRuntime`); `infer` returns a boxed future so the trait can be used as `dyn`.
// The boxed `Result<(I::Input, I::Output), RuntimeError>` future trips clippy's
// type-complexity lint.
#[allow(clippy::type_complexity)]
pub trait Runtime<I: infer::Infer> {
    /// Runs one job for `input` and returns the input together with the job's output.
    ///
    /// The implementations in this module step the input (`JobInput::step`) before
    /// returning it.
    #[cfg(not(target_arch = "wasm32"))]
    fn infer(&self, input: I::Input) -> BoxFuture<'_, Result<(I::Input, I::Output), RuntimeError>>;

    /// wasm32 variant of `infer`: the returned future is not required to be `Send`.
    #[cfg(target_arch = "wasm32")]
    fn infer(
        &self,
        input: I::Input,
    ) -> LocalBoxFuture<'_, Result<(I::Input, I::Output), RuntimeError>>;
}
276
// The boxed result type trips clippy's type-complexity lint.
#[cfg(all(not(target_arch = "wasm32"), feature = "tokio"))]
#[allow(clippy::type_complexity)]
impl<I, T, F> Runtime<I> for TokioRuntime<I>
where
    T: JobInfo,
    F: Iterator<Item = T> + Send + 'static,
    I: infer::Infer,
    for<'a> &'a I::Input: IntoIterator<Item = T, IntoIter = F>,
{
    /// Boxes the future of the inherent `TokioRuntime::infer`.
    ///
    /// The whole impl is already gated on non-wasm targets, so the method needs no
    /// cfg of its own.
    fn infer(&self, input: I::Input) -> BoxFuture<'_, Result<(I::Input, I::Output), RuntimeError>> {
        // Method-call syntax resolves to the inherent `async fn`, not to this trait method.
        let pending = self.infer(input);
        Box::pin(pending)
    }
}
291
292#[cfg(not(target_arch = "wasm32"))]
293#[allow(clippy::type_complexity)]
294impl<M, I, J, T, F> Runtime<I> for SimpleRuntime<M, I, J>
295where
296 I: infer::Infer,
297 J: Job<Input = I::Input, Output = I::Output> + Send + Sync,
298 M: Dispatcher<J, Info = T> + Sync,
299 T: JobInfo,
300 F: Iterator<Item = T> + Send + 'static,
301 for<'a> &'a I::Input: IntoIterator<Item = T, IntoIter = F>,
302{
303 fn infer(&self, input: I::Input) -> BoxFuture<'_, Result<(I::Input, I::Output), RuntimeError>> {
304 Box::pin(self.infer(input))
305 }
306}
307
// The boxed result type trips clippy's type-complexity lint.
#[cfg(target_arch = "wasm32")]
#[allow(clippy::type_complexity)]
impl<M, I, J, T, F> Runtime<I> for SimpleRuntime<M, I, J>
where
    T: JobInfo,
    F: Iterator<Item = T> + Send + 'static,
    for<'a> &'a I::Input: IntoIterator<Item = T, IntoIter = F>,
    I: infer::Infer,
    M: Dispatcher<J, Info = T>,
    J: Job<Input = I::Input, Output = I::Output>,
{
    /// Boxes the future of the inherent `SimpleRuntime::infer` as a local (non-`Send`)
    /// future for wasm32.
    fn infer(
        &self,
        input: I::Input,
    ) -> LocalBoxFuture<'_, Result<(I::Input, I::Output), RuntimeError>> {
        // Method-call syntax resolves to the inherent `async fn`, not to this trait method.
        let pending = self.infer(input);
        Box::pin(pending)
    }
}