// web_rwkv/runtime/mod.rs

1use std::{future::Future, marker::PhantomData};
2
3#[cfg(not(target_arch = "wasm32"))]
4use futures::future::BoxFuture;
5#[cfg(target_arch = "wasm32")]
6use futures::future::LocalBoxFuture;
7use thiserror::Error;
8
9pub mod infer;
10pub mod loader;
11pub mod model;
12pub mod softmax;
13pub mod v4;
14pub mod v5;
15pub mod v6;
16pub mod v7;
17
18// const MAX_QUEUE_SIZE: usize = 2;
19
/// Description of a job.
///
/// Two infos are compared with [`JobInfo::check`] to decide whether they are
/// compatible with each other.
pub trait JobInfo: Clone + Send + Sync + 'static {
    /// Check if the info are compatible.
    ///
    /// Returns `true` when `info` is compatible with `self`.
    fn check(&self, info: &Self) -> bool;
}
24
/// An input that is consumed one chunk per step.
pub trait JobInput: Send + Sync + 'static {
    /// One chunk of the whole input at a step.
    type Chunk: Send + Sync + 'static;

    /// Advance the input for a step.
    fn step(&mut self);
    /// The current step's chunk to feed into the job.
    ///
    /// Takes `&self`, so reading the chunk does not advance the input;
    /// call [`JobInput::step`] for that.
    fn chunk(&self) -> Self::Chunk;
}
34
35/// A [`Job`] to be executed on GPU.
36pub trait Job {
37    type Input: JobInput;
38    type Output: Send + Sync + 'static;
39
40    /// Load the data from CPU to GPU.
41    fn load(&self, input: &<Self::Input as JobInput>::Chunk) -> Result<(), RuntimeError>;
42    /// Submit the job to GPU and execute it immediately.
43    fn submit(&mut self);
44    #[cfg(not(target_arch = "wasm32"))]
45    /// Wait for the job to finish and read the data back.
46    fn back(self) -> impl Future<Output = Result<Self::Output, RuntimeError>> + Send;
47    #[cfg(target_arch = "wasm32")]
48    /// Wait for the job to finish and read the data back.
49    fn back(self) -> impl Future<Output = Result<Self::Output, RuntimeError>>;
50}
51
52pub trait Dispatcher<J: Job> {
53    type Info;
54
55    /// Build a [`Job`] from the given info.
56    /// This usually involves creating a list of GPU commands (but not actually execution).
57    fn dispatch(&self, info: Self::Info) -> Result<J, RuntimeError>;
58}
59
/// A single inference request handed to the [`TokioRuntime`] worker task.
#[cfg(all(not(target_arch = "wasm32"), feature = "tokio"))]
#[allow(clippy::type_complexity)]
#[derive(Debug)]
struct Submission<I: infer::Infer> {
    /// The input to run one inference step on.
    input: I::Input,
    /// Reply channel: receives the stepped input together with the output, or the error.
    sender: flume::Sender<Result<(I::Input, I::Output), RuntimeError>>,
}
67
68#[derive(Debug, Error)]
69pub enum RuntimeError {
70    #[error("input iterator exhausted")]
71    InputExhausted,
72    #[error("tensor error")]
73    TensorError(#[from] crate::tensor::TensorError),
74    #[error("recv error")]
75    RecvError(#[from] flume::RecvError),
76    #[cfg(all(not(target_arch = "wasm32"), feature = "tokio"))]
77    #[error("join error")]
78    JoinError(#[from] tokio::task::JoinError),
79}
80
/// A runtime backed by a Tokio worker task.
///
/// The wrapped sender is the handle used to pass submissions to that worker.
#[cfg(all(not(target_arch = "wasm32"), feature = "tokio"))]
#[derive(Debug, Clone)]
pub struct TokioRuntime<I: infer::Infer>(flume::Sender<Submission<I>>);
84
#[cfg(all(not(target_arch = "wasm32"), feature = "tokio"))]
#[allow(clippy::type_complexity)]
impl<I, T, F> TokioRuntime<I>
where
    I: infer::Infer,
    T: JobInfo,
    F: Iterator<Item = T> + Send + 'static,
    for<'a> &'a I::Input: IntoIterator<Item = T, IntoIter = F>,
{
    /// Create a runtime that builds jobs with `bundle`.
    ///
    /// Spawns the worker loop as a Tokio task, plus a second task that logs the
    /// join error if the worker terminates abnormally. Because it calls
    /// `tokio::spawn`, it has to be called from within a Tokio runtime.
    pub async fn new<M, J>(bundle: M) -> Self
    where
        M: Dispatcher<J, Info = T> + Send + Sync + 'static,
        J: Job<Input = I::Input, Output = I::Output> + Send + 'static,
    {
        let (sender, receiver) = flume::bounded(1);
        let handle = tokio::spawn(Self::run(bundle.into(), receiver));
        // Watch the worker so that its failure is at least logged.
        tokio::spawn(async move {
            if let Err(err) = handle.await {
                log::error!("{err}");
            }
        });
        Self(sender)
    }

    /// The worker loop: serves submissions until every sender has been dropped.
    ///
    /// Jobs are built ahead of time on blocking threads and kept in `queue`,
    /// keyed by the info they were built from. A submission is served by a
    /// queued job whose key is compatible with the submission's current info;
    /// if there is none, the pipeline is restarted from the submission's input.
    async fn run<M, J>(model: std::sync::Arc<M>, receiver: flume::Receiver<Submission<I>>)
    where
        M: Dispatcher<J, Info = T> + Send + Sync + 'static,
        J: Job<Input = I::Input, Output = I::Output> + Send + 'static,
    {
        // Jobs being built, each paired with the info it was dispatched with.
        let mut queue: Vec<(T, tokio::task::JoinHandle<Result<J, RuntimeError>>)> = vec![];
        // Source of the infos for the jobs to be launched ahead.
        let mut iter: Option<F> = None;
        // How many jobs to launch in the current round; cycles 2 -> 1 -> 0 -> 2.
        let mut predict: usize = 0;

        'main: while let Ok(Submission { input, sender }) = receiver.recv_async().await {
            // The info of the step to run now; without one there is nothing to do.
            let Some(info) = (&input).into_iter().next() else {
                let _ = sender.send(Err(RuntimeError::InputExhausted));
                continue 'main;
            };

            let chunk = input.chunk();

            let mut job = loop {
                // Split the queue: compatible jobs become candidates; incompatible
                // ones ahead of the first candidate are aborted, later ones are kept.
                let mut candidates = vec![];
                let mut remain = vec![];
                for (key, handle) in queue {
                    match (candidates.is_empty(), info.check(&key)) {
                        (true, false) => handle.abort(),
                        (false, false) => remain.push((key, handle)),
                        (_, true) => candidates.push(handle),
                    }
                }
                queue = remain;

                predict = match predict {
                    2 => 1,
                    1 => 0,
                    0 => 2,
                    _ => unreachable!(),
                };

                // we have a cache miss, restart the pipeline
                if candidates.is_empty() || iter.is_none() {
                    iter = Some((&input).into_iter());
                    predict = 2;
                }
                let iter = iter.as_mut().unwrap();

                // Launch up to `predict` job builds on blocking threads.
                for info in iter.take(predict) {
                    #[cfg(feature = "trace")]
                    tracing::event!(
                        tracing::Level::TRACE,
                        "launch ({queue}, {candidates}, {predict})",
                        queue = queue.len(),
                        candidates = candidates.len(),
                        predict = predict
                    );

                    let key = info.clone();
                    let model = model.clone();
                    let handle = tokio::task::spawn_blocking(move || model.dispatch(key));
                    queue.push((info.clone(), handle));
                }

                // With no candidate yet, loop again: the jobs just launched are
                // re-examined against `info` on the next pass.
                if !candidates.is_empty() {
                    // Take whichever candidate finishes first.
                    let (job, _, remain) = futures::future::select_all(candidates).await;
                    // Put the unfinished candidates back at the front of the queue,
                    // keyed by the current info.
                    let mut remain = remain
                        .into_iter()
                        .map(|handle| (info.clone(), handle))
                        .collect();
                    std::mem::swap(&mut queue, &mut remain);
                    queue.append(&mut remain);

                    break match job {
                        Ok(Ok(job)) => job,
                        // Building the job failed.
                        Ok(Err(error)) => {
                            let _ = sender.send(Err(error));
                            continue 'main;
                        }
                        // The build task itself failed to join.
                        Err(error) => {
                            let _ = sender.send(Err(error.into()));
                            continue 'main;
                        }
                    };
                }
            };

            if let Err(error) = job.load(&chunk) {
                let _ = sender.send(Err(error));
                continue 'main;
            }

            #[cfg(feature = "trace")]
            let _span = tracing::trace_span!("submit").entered();
            job.submit();

            // read back the results asynchronously
            tokio::spawn(async move {
                let output = job.back().await;
                let mut input = input;
                input.step();
                let _ = sender.send(output.map(|output| (input, output)));
            });
        }
    }

    /// Perform (partial) inference and return the remaining input and (perhaps partial) output.
    /// The amount of input processed during one call is bound by the input chunk size.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the worker for this submission, or
    /// [`RuntimeError::RecvError`] if no reply could be received.
    pub async fn infer(&self, input: I::Input) -> Result<(I::Input, I::Output), RuntimeError> {
        let (sender, receiver) = flume::bounded(1);
        let submission = Submission { input, sender };
        // The send result is ignored on purpose: a failed send drops the submission
        // and its reply sender, so the receive below reports the failure.
        let _ = self.0.send_async(submission).await;
        receiver.recv_async().await?
    }
}
219
/// A runtime that wraps a dispatcher `M` directly.
///
/// `I` and `J` are carried only as type markers via [`PhantomData`].
#[derive(Debug, Clone)]
pub struct SimpleRuntime<M, I, J>(M, PhantomData<(I, J)>);
222
223impl<M, I, J> SimpleRuntime<M, I, J> {
224    #[inline]
225    pub fn new<T, F>(bundle: M) -> Self
226    where
227        M: Dispatcher<J, Info = T>,
228        I: infer::Infer,
229        J: Job<Input = I::Input, Output = I::Output>,
230        T: JobInfo,
231        F: Iterator<Item = T> + Send + 'static,
232        for<'a> &'a I::Input: IntoIterator<Item = T, IntoIter = F>,
233    {
234        Self(bundle, PhantomData)
235    }
236
237    pub async fn infer<T, F>(
238        &self,
239        mut input: I::Input,
240    ) -> Result<(I::Input, I::Output), RuntimeError>
241    where
242        M: Dispatcher<J, Info = T>,
243        I: infer::Infer,
244        J: Job<Input = I::Input, Output = I::Output>,
245        T: JobInfo,
246        F: Iterator<Item = T> + Send + 'static,
247        for<'a> &'a I::Input: IntoIterator<Item = T, IntoIter = F>,
248    {
249        let Some(info) = (&input).into_iter().next() else {
250            return Err(RuntimeError::InputExhausted);
251        };
252        let chunk = input.chunk();
253
254        let mut job = self.0.dispatch(info)?;
255        job.load(&chunk)?;
256        job.submit();
257
258        let output = job.back().await?;
259        input.step();
260
261        Ok((input, output))
262    }
263}
264
265#[allow(clippy::type_complexity)]
266pub trait Runtime<I: infer::Infer> {
267    #[cfg(not(target_arch = "wasm32"))]
268    fn infer(&self, input: I::Input) -> BoxFuture<'_, Result<(I::Input, I::Output), RuntimeError>>;
269
270    #[cfg(target_arch = "wasm32")]
271    fn infer(
272        &self,
273        input: I::Input,
274    ) -> LocalBoxFuture<'_, Result<(I::Input, I::Output), RuntimeError>>;
275}
276
/// [`Runtime`] for [`TokioRuntime`]: boxes the inherent `infer` future.
#[cfg(all(not(target_arch = "wasm32"), feature = "tokio"))]
#[allow(clippy::type_complexity)]
impl<I, T, F> Runtime<I> for TokioRuntime<I>
where
    I: infer::Infer,
    T: JobInfo,
    F: Iterator<Item = T> + Send + 'static,
    for<'a> &'a I::Input: IntoIterator<Item = T, IntoIter = F>,
{
    // No per-method `cfg` is needed: the whole impl is already native-only.
    fn infer(&self, input: I::Input) -> BoxFuture<'_, Result<(I::Input, I::Output), RuntimeError>> {
        Box::pin(self.infer(input))
    }
}
291
292#[cfg(not(target_arch = "wasm32"))]
293#[allow(clippy::type_complexity)]
294impl<M, I, J, T, F> Runtime<I> for SimpleRuntime<M, I, J>
295where
296    I: infer::Infer,
297    J: Job<Input = I::Input, Output = I::Output> + Send + Sync,
298    M: Dispatcher<J, Info = T> + Sync,
299    T: JobInfo,
300    F: Iterator<Item = T> + Send + 'static,
301    for<'a> &'a I::Input: IntoIterator<Item = T, IntoIter = F>,
302{
303    fn infer(&self, input: I::Input) -> BoxFuture<'_, Result<(I::Input, I::Output), RuntimeError>> {
304        Box::pin(self.infer(input))
305    }
306}
307
/// [`Runtime`] for [`SimpleRuntime`] on wasm32: boxes the inherent `infer`
/// future into a [`LocalBoxFuture`], so no `Send`/`Sync` bounds are placed on
/// `J` and `M`.
#[cfg(target_arch = "wasm32")]
#[allow(clippy::type_complexity)]
impl<M, I, J, T, F> Runtime<I> for SimpleRuntime<M, I, J>
where
    I: infer::Infer,
    J: Job<Input = I::Input, Output = I::Output>,
    M: Dispatcher<J, Info = T>,
    T: JobInfo,
    F: Iterator<Item = T> + Send + 'static,
    for<'a> &'a I::Input: IntoIterator<Item = T, IntoIter = F>,
{
    fn infer(
        &self,
        input: I::Input,
    ) -> LocalBoxFuture<'_, Result<(I::Input, I::Output), RuntimeError>> {
        Box::pin(self.infer(input))
    }
}