1use super::*;
13use js_sys::{Array, Atomics, Int32Array, SharedArrayBuffer, Uint8Array};
14use std::marker::PhantomData;
15use std::time::Duration;
16#[cfg(test)]
17#[allow(unused_imports)]
18use wasm_rs_dbg::dbg;
19
20pub struct SharedChannel<T>
26where
27 T: Shareable,
28{
29 _header: SharedArrayBuffer,
30 _data: SharedArrayBuffer,
31 header: Int32Array,
32 data: Uint8Array,
33 len: u32,
34 phantom_data: PhantomData<T>,
35}
36
37impl<T> From<SharedChannel<T>> for JsValue
38where
39 T: Shareable,
40{
41 fn from(channel: SharedChannel<T>) -> JsValue {
42 let array = Array::new();
43 array.push(&channel._header);
44 array.push(&channel._data);
45 array.into()
46 }
47}
48
49impl<T> From<JsValue> for SharedChannel<T>
50where
51 T: Shareable,
52{
53 fn from(array: JsValue) -> SharedChannel<T> {
54 let array: Array = array.into();
55 let header = array.shift();
56 let data = array.shift();
57 channel_(header.into(), data.into())
58 }
59}
60
// Indices of the `i32` slots in the channel header (allocated by `channel`).

/// Offset of the first unread byte of region A.
const A_START: u32 = 0;
/// Offset one past the last written byte of region A.
const A_END: u32 = 1;
/// Offset one past the last written byte of region B; region B always starts at 0.
const B_END: u32 = 2;
/// `1` while the sender appends to region B instead of region A, otherwise `0`.
const B_USE: u32 = 3;
65
66impl<T> SharedChannel<T>
67where
68 T: Shareable,
69{
70 fn unused(&self) -> Result<u32, JsValue> {
71 let b_use = (Atomics::load(&self.header, B_USE)? as u32) == 1;
72 if b_use {
73 let a_start = Atomics::load(&self.header, A_START)? as u32;
74 let b_end = Atomics::load(&self.header, B_END)? as u32;
75 Ok(a_start - b_end)
76 } else {
77 let a_end = Atomics::load(&self.header, A_END)? as u32;
78 Ok(self.len - a_end)
79 }
80 }
81
82 fn maybe_switch(&self) -> Result<(), JsValue> {
83 let a_start = Atomics::load(&self.header, A_START)? as u32;
84 let a_end = Atomics::load(&self.header, A_END)? as u32;
85 let b_end = Atomics::load(&self.header, B_END)? as u32;
86 if self.len - a_end < a_start - b_end {
87 Atomics::store(&self.header, B_USE, 1i32)?;
88 }
89 Ok(())
90 }
91
92 pub fn split(self) -> (Sender<T>, Receiver<T>) {
96 (Sender(self.clone()), Receiver(self))
97 }
98}
99
100impl<T> Clone for SharedChannel<T>
101where
102 T: Shareable,
103{
104 fn clone(&self) -> Self {
105 Self {
106 _header: self._header.clone(),
107 _data: self._data.clone(),
108 header: self.header.clone(),
109 data: self.data.clone(),
110 len: self.len,
111 phantom_data: PhantomData,
112 }
113 }
114}
115
116#[derive(Clone)]
118pub struct Sender<T>(pub SharedChannel<T>)
119where
120 T: Shareable;
121
122pub struct Receiver<T>(pub SharedChannel<T>)
124where
125 T: Shareable;
126
127pub fn channel<T>(len: u32) -> SharedChannel<T>
129where
130 T: Shareable,
131{
132 let header = SharedArrayBuffer::new(4 * (std::mem::size_of::<u32>() as u32));
133 let data = SharedArrayBuffer::new(len);
134 channel_(header, data)
135}
136
137fn channel_<T>(header: SharedArrayBuffer, data: SharedArrayBuffer) -> SharedChannel<T>
138where
139 T: Shareable,
140{
141 let header_ = Int32Array::new(&header);
142 let data_ = Uint8Array::new(&data);
143 let len = data_.byte_length();
144 SharedChannel {
145 _header: header,
146 _data: data,
147 header: header_,
148 data: data_,
149 len,
150 phantom_data: PhantomData,
151 }
152}
153
154impl<T> Sender<T>
155where
156 T: Shareable,
157{
158 pub fn send(&self, value: &T) -> Result<(), JsValue> {
163 let bytes = value
164 .to_bytes()
165 .map_err(|e| JsValue::from(format!("serialization error: {}", e)))?;
166 let len = bytes.byte_length();
167 if self.0.unused()? < len {
168 return Err("not enough space".to_string().into());
169 }
170 let b_use = (Atomics::load(&self.0.header, B_USE)? as u32) == 1;
171 let end_header = if b_use { B_END } else { A_END };
172
173 let end = Atomics::load(&self.0.header, end_header)? as u32;
174 for i in 0..len {
175 self.0.data.set_index(end + i, bytes.get_index(i));
176 }
177 Atomics::store(&self.0.header, end_header, (end + len) as i32)?;
178 Atomics::notify(&self.0.header, end_header)?;
179 Atomics::notify(&self.0.header, A_START)?;
180
181 self.0.maybe_switch()?;
182
183 Ok(())
184 }
185}
186
187impl<T> Receiver<T>
188where
189 T: Shareable,
190{
191 pub fn recv(&self, timeout: Option<Duration>) -> Result<Option<T>, JsValue> {
202 let mut array = Uint8Array::new_with_length(0);
203 loop {
204 match T::from(&array)
205 .map_err(|e| JsValue::from(format!("deserialization error: {}", e)))?
206 {
207 Ok(value) => {
208 return Ok(Some(value));
209 }
210 Err(Expects(sz)) => {
211 array = Uint8Array::new_with_length(sz);
212 let mut a_start = Atomics::load(&self.0.header, A_START)? as u32;
213 let mut a_end = Atomics::load(&self.0.header, A_END)? as u32;
214 if a_start == a_end || self.0.len < a_start + sz {
215 match timeout {
216 None => return Ok(None),
217 Some(duration) => {
218 let result = Atomics::wait_with_timeout(
219 &self.0.header,
220 A_START,
221 a_start as i32,
222 duration.as_millis() as f64,
223 )?;
224 if result == "timed-out" {
225 return Ok(None);
226 }
227 continue;
228 }
229 }
230 }
231 for i in 0..sz {
232 array.set_index(i, self.0.data.get_index(a_start + i));
233 }
234 a_start += sz;
235 let mut b_end = Atomics::load(&self.0.header, B_END)? as u32;
236 let mut b_use = (Atomics::load(&self.0.header, B_USE)? as u32) == 1;
237 if a_start == a_end {
238 if b_use {
239 a_start = 0;
240 a_end = b_end;
241 b_end = 0;
242 b_use = false;
243 } else {
244 a_start = 0;
245 a_end = 0;
246 }
247 }
248 if T::from(&array)
249 .map_err(|e| JsValue::from(format!("deserialization error: {}", e)))?
250 .is_ok()
251 {
252 Atomics::store(&self.0.header, B_USE, if b_use { 1i32 } else { 0i32 })?;
253 Atomics::store(&self.0.header, A_START, a_start as i32)?;
254 Atomics::store(&self.0.header, A_END, a_end as i32)?;
255 Atomics::store(&self.0.header, B_END, b_end as i32)?;
256 self.0.maybe_switch()?;
257 }
258 }
259 }
260 }
261 }
262}
263
#[cfg(test)]
mod tests {
    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);

    use super::*;
    use wasm_bindgen_test::*;

    /// Serialized size of a single `u8` message.
    fn item_size() -> u32 {
        0u8.to_bytes().unwrap().byte_length()
    }

    /// Non-blocking receive that panics if no value is available.
    fn recv_now(receiver: &Receiver<u8>) -> u8 {
        receiver.recv(None).unwrap().unwrap()
    }

    #[wasm_bindgen_test]
    fn test() {
        let (sender, receiver) = channel::<u8>(2 * item_size()).split();
        for value in 1..=2u8 {
            sender.send(&value).unwrap();
        }
        for expected in 1..=2u8 {
            assert_eq!(recv_now(&receiver), expected);
        }
    }

    #[wasm_bindgen_test]
    fn not_enough_space() {
        let (sender, _receiver) = channel::<u8>(item_size()).split();
        sender.send(&1).unwrap();
        assert!(sender.send(&2).is_err());
    }

    #[wasm_bindgen_test]
    fn circular() {
        let (sender, receiver) = channel::<u8>(8 * item_size()).split();
        // Fill the buffer completely.
        for value in 1..=8u8 {
            sender.send(&value).unwrap();
        }
        // Free the first three slots.
        for expected in 1..=3u8 {
            assert_eq!(recv_now(&receiver), expected);
        }
        // These writes wrap around into the freed space at the start.
        for value in 9..=11u8 {
            sender.send(&value).unwrap();
        }
        // Everything comes back in send order across the wrap.
        for expected in 4..=11u8 {
            assert_eq!(recv_now(&receiver), expected);
        }
    }

    #[wasm_bindgen_test]
    fn jsvalue() {
        let exported: JsValue = channel::<u8>(2 * item_size()).into();
        let restored: SharedChannel<u8> = exported.into();
        let (sender, receiver) = restored.split();
        for value in 1..=2u8 {
            sender.send(&value).unwrap();
        }
        for expected in 1..=2u8 {
            assert_eq!(recv_now(&receiver), expected);
        }
    }
}