Skip to main content

wasm_rs_shared_channel/
spsc.rs

1//! # Single Publisher Single Consumer Channel
2//!
//! A simple single-publisher, single-consumer channel that can be used to communicate from the
//! main thread to a worker thread, or between worker threads.
5//!
6// NOTE: Current algorithm used to send and receive messages
7// has been modeled after https://github.com/willemt/bipbuffer and has not been extensively
8// tested for suitability and/or correctness.
9//
// This is an ongoing area of development and the algorithm might change at any moment, so
// one should not base expectations on the particularities of the current algorithm.
12use super::*;
13use js_sys::{Array, Atomics, Int32Array, SharedArrayBuffer, Uint8Array};
14use std::marker::PhantomData;
15use std::time::Duration;
16#[cfg(test)]
17#[allow(unused_imports)]
18use wasm_rs_dbg::dbg;
19
20/// Shared single publisher, single-consumer channel
21///
22/// A channel can be passed between different threads with their own instances of a WebAssembly
23/// module by caling [`wasm_bindgen::JsValue::from`] on this channel and subsequently calling
24/// [`SharedChannel::from`] on the value in a different thread.
25pub struct SharedChannel<T>
26where
27    T: Shareable,
28{
29    _header: SharedArrayBuffer,
30    _data: SharedArrayBuffer,
31    header: Int32Array,
32    data: Uint8Array,
33    len: u32,
34    phantom_data: PhantomData<T>,
35}
36
37impl<T> From<SharedChannel<T>> for JsValue
38where
39    T: Shareable,
40{
41    fn from(channel: SharedChannel<T>) -> JsValue {
42        let array = Array::new();
43        array.push(&channel._header);
44        array.push(&channel._data);
45        array.into()
46    }
47}
48
49impl<T> From<JsValue> for SharedChannel<T>
50where
51    T: Shareable,
52{
53    fn from(array: JsValue) -> SharedChannel<T> {
54        let array: Array = array.into();
55        let header = array.shift();
56        let data = array.shift();
57        channel_(header.into(), data.into())
58    }
59}
60
61const A_START: u32 = 0;
62const A_END: u32 = 1;
63const B_END: u32 = 2;
64const B_USE: u32 = 3;
65
66impl<T> SharedChannel<T>
67where
68    T: Shareable,
69{
70    fn unused(&self) -> Result<u32, JsValue> {
71        let b_use = (Atomics::load(&self.header, B_USE)? as u32) == 1;
72        if b_use {
73            let a_start = Atomics::load(&self.header, A_START)? as u32;
74            let b_end = Atomics::load(&self.header, B_END)? as u32;
75            Ok(a_start - b_end)
76        } else {
77            let a_end = Atomics::load(&self.header, A_END)? as u32;
78            Ok(self.len - a_end)
79        }
80    }
81
82    fn maybe_switch(&self) -> Result<(), JsValue> {
83        let a_start = Atomics::load(&self.header, A_START)? as u32;
84        let a_end = Atomics::load(&self.header, A_END)? as u32;
85        let b_end = Atomics::load(&self.header, B_END)? as u32;
86        if self.len - a_end < a_start - b_end {
87            Atomics::store(&self.header, B_USE, 1i32)?;
88        }
89        Ok(())
90    }
91
92    /// Consumes and splits channel into a [`Sender`] and a [`Receiver`]
93    ///
94    /// Splitting it into allows us to ensure roles aren't mixed up.
95    pub fn split(self) -> (Sender<T>, Receiver<T>) {
96        (Sender(self.clone()), Receiver(self))
97    }
98}
99
100impl<T> Clone for SharedChannel<T>
101where
102    T: Shareable,
103{
104    fn clone(&self) -> Self {
105        Self {
106            _header: self._header.clone(),
107            _data: self._data.clone(),
108            header: self.header.clone(),
109            data: self.data.clone(),
110            len: self.len,
111            phantom_data: PhantomData,
112        }
113    }
114}
115
116/// Sender part of the channel
117#[derive(Clone)]
118pub struct Sender<T>(pub SharedChannel<T>)
119where
120    T: Shareable;
121
122/// Receiver part of the channel
123pub struct Receiver<T>(pub SharedChannel<T>)
124where
125    T: Shareable;
126
127/// Creates a channel of `len` bytes
128pub fn channel<T>(len: u32) -> SharedChannel<T>
129where
130    T: Shareable,
131{
132    let header = SharedArrayBuffer::new(4 * (std::mem::size_of::<u32>() as u32));
133    let data = SharedArrayBuffer::new(len);
134    channel_(header, data)
135}
136
137fn channel_<T>(header: SharedArrayBuffer, data: SharedArrayBuffer) -> SharedChannel<T>
138where
139    T: Shareable,
140{
141    let header_ = Int32Array::new(&header);
142    let data_ = Uint8Array::new(&data);
143    let len = data_.byte_length();
144    SharedChannel {
145        _header: header,
146        _data: data,
147        header: header_,
148        data: data_,
149        len,
150        phantom_data: PhantomData,
151    }
152}
153
154impl<T> Sender<T>
155where
156    T: Shareable,
157{
158    /// Sends a value into the channel
159    ///
160    /// If there isn't enough space currently in the channel to accommodate
161    /// the value, it'll throw a JavaScript exception (`"not enough space"`)
162    pub fn send(&self, value: &T) -> Result<(), JsValue> {
163        let bytes = value
164            .to_bytes()
165            .map_err(|e| JsValue::from(format!("serialization error: {}", e)))?;
166        let len = bytes.byte_length();
167        if self.0.unused()? < len {
168            return Err("not enough space".to_string().into());
169        }
170        let b_use = (Atomics::load(&self.0.header, B_USE)? as u32) == 1;
171        let end_header = if b_use { B_END } else { A_END };
172
173        let end = Atomics::load(&self.0.header, end_header)? as u32;
174        for i in 0..len {
175            self.0.data.set_index(end + i, bytes.get_index(i));
176        }
177        Atomics::store(&self.0.header, end_header, (end + len) as i32)?;
178        Atomics::notify(&self.0.header, end_header)?;
179        Atomics::notify(&self.0.header, A_START)?;
180
181        self.0.maybe_switch()?;
182
183        Ok(())
184    }
185}
186
187impl<T> Receiver<T>
188where
189    T: Shareable,
190{
191    /// Receives a value from the channel
192    ///
193    /// If `timeout` is `None`, if there is no message, it'll immediately return
194    /// `Ok(None)`.
195    ///
196    /// If `timeout` is `Some(duration)` it will return `Ok(Some(value))` if there was a value,
197    /// otherwise, it'll return `Ok(None)` when timed out.
198    ///
199    /// There's no way to specify an infinite timeout. Instead, a sufficiently large
200    /// [`std::time::Duration`] should be used.
201    pub fn recv(&self, timeout: Option<Duration>) -> Result<Option<T>, JsValue> {
202        let mut array = Uint8Array::new_with_length(0);
203        loop {
204            match T::from(&array)
205                .map_err(|e| JsValue::from(format!("deserialization error: {}", e)))?
206            {
207                Ok(value) => {
208                    return Ok(Some(value));
209                }
210                Err(Expects(sz)) => {
211                    array = Uint8Array::new_with_length(sz);
212                    let mut a_start = Atomics::load(&self.0.header, A_START)? as u32;
213                    let mut a_end = Atomics::load(&self.0.header, A_END)? as u32;
214                    if a_start == a_end || self.0.len < a_start + sz {
215                        match timeout {
216                            None => return Ok(None),
217                            Some(duration) => {
218                                let result = Atomics::wait_with_timeout(
219                                    &self.0.header,
220                                    A_START,
221                                    a_start as i32,
222                                    duration.as_millis() as f64,
223                                )?;
224                                if result == "timed-out" {
225                                    return Ok(None);
226                                }
227                                continue;
228                            }
229                        }
230                    }
231                    for i in 0..sz {
232                        array.set_index(i, self.0.data.get_index(a_start + i));
233                    }
234                    a_start += sz;
235                    let mut b_end = Atomics::load(&self.0.header, B_END)? as u32;
236                    let mut b_use = (Atomics::load(&self.0.header, B_USE)? as u32) == 1;
237                    if a_start == a_end {
238                        if b_use {
239                            a_start = 0;
240                            a_end = b_end;
241                            b_end = 0;
242                            b_use = false;
243                        } else {
244                            a_start = 0;
245                            a_end = 0;
246                        }
247                    }
248                    if T::from(&array)
249                        .map_err(|e| JsValue::from(format!("deserialization error: {}", e)))?
250                        .is_ok()
251                    {
252                        Atomics::store(&self.0.header, B_USE, if b_use { 1i32 } else { 0i32 })?;
253                        Atomics::store(&self.0.header, A_START, a_start as i32)?;
254                        Atomics::store(&self.0.header, A_END, a_end as i32)?;
255                        Atomics::store(&self.0.header, B_END, b_end as i32)?;
256                        self.0.maybe_switch()?;
257                    }
258                }
259            }
260        }
261    }
262}
263
264#[cfg(test)]
265mod tests {
266    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);
267
268    use super::*;
269    use wasm_bindgen_test::*;
270
271    #[wasm_bindgen_test]
272    fn test() {
273        let sz = 0u8.to_bytes().unwrap().byte_length();
274        let (sender, receiver) = channel::<u8>(2 * sz).split();
275        sender.send(&1).unwrap();
276        sender.send(&2).unwrap();
277        assert_eq!(receiver.recv(None).unwrap().unwrap(), 1);
278        assert_eq!(receiver.recv(None).unwrap().unwrap(), 2);
279    }
280
281    #[wasm_bindgen_test]
282    fn not_enough_space() {
283        let sz = 0u8.to_bytes().unwrap().byte_length();
284        let (sender, _receiver) = channel::<u8>(1 * sz).split();
285        sender.send(&1).unwrap();
286        assert!(sender.send(&2).is_err());
287    }
288
289    #[wasm_bindgen_test]
290    fn circular() {
291        let sz = 0u8.to_bytes().unwrap().byte_length();
292        let (sender, receiver) = channel::<u8>(8 * sz).split();
293        sender.send(&1).unwrap();
294        sender.send(&2).unwrap();
295        sender.send(&3).unwrap();
296        sender.send(&4).unwrap();
297        sender.send(&5).unwrap();
298        sender.send(&6).unwrap();
299        sender.send(&7).unwrap();
300        sender.send(&8).unwrap();
301        assert_eq!(receiver.recv(None).unwrap().unwrap(), 1);
302        assert_eq!(receiver.recv(None).unwrap().unwrap(), 2);
303        assert_eq!(receiver.recv(None).unwrap().unwrap(), 3);
304        sender.send(&9).unwrap();
305        sender.send(&10).unwrap();
306        sender.send(&11).unwrap();
307        assert_eq!(receiver.recv(None).unwrap().unwrap(), 4);
308        assert_eq!(receiver.recv(None).unwrap().unwrap(), 5);
309        assert_eq!(receiver.recv(None).unwrap().unwrap(), 6);
310        assert_eq!(receiver.recv(None).unwrap().unwrap(), 7);
311        assert_eq!(receiver.recv(None).unwrap().unwrap(), 8);
312        assert_eq!(receiver.recv(None).unwrap().unwrap(), 9);
313        assert_eq!(receiver.recv(None).unwrap().unwrap(), 10);
314        assert_eq!(receiver.recv(None).unwrap().unwrap(), 11);
315    }
316
317    #[wasm_bindgen_test]
318    fn jsvalue() {
319        let sz = 0u8.to_bytes().unwrap().byte_length();
320        let ch = channel::<u8>(2 * sz);
321        let js_value: JsValue = ch.into();
322        let ch: SharedChannel<u8> = js_value.into();
323        let (sender, receiver) = ch.split();
324        sender.send(&1).unwrap();
325        sender.send(&2).unwrap();
326        assert_eq!(receiver.recv(None).unwrap().unwrap(), 1);
327        assert_eq!(receiver.recv(None).unwrap().unwrap(), 2);
328    }
329}