slave_pool/
oneshot.rs

1//! Underlying oneshot implementation
2use std::time;
3use core::{ptr, task, pin};
4use core::cell::{Cell, UnsafeCell};
5use core::mem::MaybeUninit;
6use core::sync::atomic::{Ordering, AtomicU8};
7use core::future::Future;
8
/// Initial state of a freshly created payload: no flag is set.
const UNINIT: u8 = 0;
/// The sender has written the message into the payload's value slot.
const READY: u8 = 1 << 0;
/// The receiver has stored a notifier (thread handle or waker) in the payload.
const WAKER_SET: u8 = 1 << 1;
/// The `Sender` has been dropped.
const SEND_CLOSED: u8 = 1 << 2;
/// The receiver has moved the message out of the payload.
const CONSUMED: u8 = 1 << 3;
/// The `Receiver` has been dropped.
const RECV_CLOSED: u8 = 1 << 4;
15
16use super::JoinError;
17
/// Handle the sending side uses to wake whoever is waiting on the receiver.
enum Notifier {
    /// Registered by the blocking receivers (`recv`, `recv_timeout`); woken with `unpark`.
    Thread(std::thread::Thread),
    /// Registered by `Future::poll`; woken with `Waker::wake`.
    Waker(task::Waker),
}
22
23struct Payload<T> {
24    state: AtomicU8,
25    value: UnsafeCell<MaybeUninit<T>>,
26    notifier: Cell<MaybeUninit<Notifier>>
27}
28
29impl<T> Payload<T> {
30    const fn new() -> Self {
31        Self {
32            state: AtomicU8::new(UNINIT),
33            value: UnsafeCell::new(MaybeUninit::uninit()),
34            notifier: Cell::new(MaybeUninit::uninit()),
35        }
36    }
37
38    #[inline(never)]
39    ///Sets notifier, updates state and returns previous state
40    fn set_notifier(&self, notifier: Notifier) -> u8 {
41        self.notifier.set(MaybeUninit::new(notifier));
42        self.state.fetch_or(WAKER_SET, Ordering::AcqRel)
43    }
44
45    #[inline(always)]
46    fn take_notifier(&self) -> Notifier {
47        let storage = self.notifier.replace(MaybeUninit::uninit());
48
49        unsafe {
50            storage.assume_init()
51        }
52    }
53}
54
55impl<T> Drop for Payload<T> {
56    fn drop(&mut self) {
57        let state = self.state.load(Ordering::Relaxed);
58        match (state & READY == READY) && (state & CONSUMED != CONSUMED) {
59            true => unsafe {
60                ptr::drop_in_place((*self.value.get()).as_mut_ptr());
61            },
62            _ => (),
63        }
64
65        //If no one is interested in waker, then just drop it without waking up
66        if state & WAKER_SET == WAKER_SET {
67            self.take_notifier();
68        }
69    }
70}
71
72#[repr(transparent)]
73///Sender end, allows to send message once
74///
75///On `Drop` will notify `Receiver`
76pub struct Sender<T> {
77    payload: ptr::NonNull<Payload<T>>,
78}
79
80impl<T> Sender<T> {
81    #[inline(always)]
82    fn payload(&self) -> &Payload<T> {
83        unsafe  {
84            &*self.payload.as_ptr()
85        }
86    }
87
88    ///Performs send of the message, waking receiver, if it awaits
89    pub fn send(self, value: T) {
90        //there is always only one sender
91        unsafe {
92            ptr::write((*self.payload().value.get()).as_mut_ptr(), value);
93        }
94
95        let state = self.payload().state.fetch_or(READY, Ordering::AcqRel);
96        if state & WAKER_SET == WAKER_SET {
97            let notifier = self.payload().take_notifier();
98            self.payload().state.fetch_and(!WAKER_SET, Ordering::Release);
99
100            match notifier {
101                Notifier::Thread(thread) => thread.unpark(),
102                Notifier::Waker(waker) => waker.wake(),
103            }
104        }
105    }
106}
107
108impl<T> Drop for Sender<T> {
109    fn drop(&mut self) {
110        //Make sure to guarantee we acquire RECV_CLOSED prior setting SEND_CLOSED
111        let mut state = self.payload().state.load(Ordering::Acquire);
112        if state & WAKER_SET == WAKER_SET {
113            let notifier = self.payload().take_notifier();
114            //Unset WAKER_SET and set SEND_CLOSED
115            state = self.payload().state.fetch_xor(WAKER_SET | SEND_CLOSED, Ordering::AcqRel);
116
117            match notifier {
118                Notifier::Thread(thread) => thread.unpark(),
119                Notifier::Waker(waker) => waker.wake(),
120            }
121        } else {
122            state = self.payload().state.fetch_or(SEND_CLOSED, Ordering::AcqRel);
123        }
124
125        if state & RECV_CLOSED == RECV_CLOSED {
126            unsafe {
127                let _ = Box::from_raw(self.payload.as_ptr());
128            }
129        }
130    }
131}
132
133unsafe impl<T: Send> Send for Sender<T> {}
134unsafe impl<T: Sync> Sync for Sender<T> {}
135
136#[repr(transparent)]
137///Receiver end to receive message
138///
139///Implements `Future`
140pub struct Receiver<T> {
141    payload: ptr::NonNull<Payload<T>>,
142}
143
144impl<T> Receiver<T> {
145    #[inline(always)]
146    fn payload(&self) -> &Payload<T> {
147        unsafe  {
148            &*self.payload.as_ptr()
149        }
150    }
151
152    fn consume(&self) -> T {
153        self.payload().state.fetch_or(CONSUMED, Ordering::Release);
154        let mut result = MaybeUninit::uninit();
155
156        unsafe {
157            ptr::swap(result.as_mut_ptr(), (*self.payload().value.get()).as_mut_ptr());
158
159            result.assume_init()
160        }
161    }
162
163    #[inline(always)]
164    ///Returns whether job has been finished
165    pub fn is_ready(&self) -> bool {
166        self.payload().state.load(Ordering::Acquire) & READY == READY
167    }
168
169    #[inline(always)]
170    ///Returns whether receiver has been 'consumed'
171    pub fn is_consumed(&self) -> bool {
172        self.payload().state.load(Ordering::Acquire) & CONSUMED == CONSUMED
173    }
174
175    ///Checks if message is received, returning it, if possible
176    ///Otherwise returns `None`
177    pub fn try_recv(&self) -> Result<Option<T>, JoinError> {
178        let state = self.payload().state.load(Ordering::Acquire);
179
180        if state & CONSUMED == CONSUMED {
181            Err(JoinError::AlreadyConsumed)
182        } else if state & READY == READY {
183            Ok(Some(self.consume()))
184        } else if state & SEND_CLOSED == SEND_CLOSED {
185            Err(JoinError::Disconnect)
186        } else {
187            Ok(None)
188        }
189    }
190
191    ///Awaits message blocking until message arrives, returning it
192    ///Or if `Sender` closes unexpectedly (e.g. due to panic) returns `JoinError::Disconnect`
193    pub fn recv(self) -> Result<T, JoinError> {
194        let mut state = self.payload().state.load(Ordering::Acquire);
195
196        if state & CONSUMED == CONSUMED {
197            return Err(JoinError::AlreadyConsumed);
198        } else if state & READY == READY {
199            return Ok(self.consume());
200        } else if state & SEND_CLOSED == SEND_CLOSED {
201            return Err(JoinError::Disconnect);
202        }
203
204        state = self.payload().set_notifier(Notifier::Thread(std::thread::current()));
205
206        while state & READY != READY {
207            //Make sure we're not dropped yet
208            if state & SEND_CLOSED == SEND_CLOSED {
209                return Err(JoinError::Disconnect);
210            }
211
212            std::thread::park();
213
214            state = self.payload().state.load(Ordering::Acquire);
215        }
216
217        Ok(self.consume())
218    }
219
220    ///Awaits message blocking for the duration of `time` until message arrives, returning it
221    ///Or if `Sender` closes unexpectedly (e.g. due to panic) returns `JoinError::Disconnect`
222    ///
223    ///If timeout expires, returns `Ok(None)`
224    pub fn recv_timeout(&self, mut time: time::Duration) -> Result<Option<T>, JoinError> {
225        let mut state = self.payload().state.load(Ordering::Acquire);
226
227        if state & CONSUMED == CONSUMED {
228            return Err(JoinError::AlreadyConsumed);
229        } else if state & READY == READY {
230            return Ok(Some(self.consume()));
231        } else if state & SEND_CLOSED == SEND_CLOSED {
232            return Err(JoinError::Disconnect);
233        }
234
235        state = self.payload().set_notifier(Notifier::Thread(std::thread::current()));
236
237        let start_time = time::Instant::now();
238        while state & READY != READY {
239            std::thread::park_timeout(time);
240
241            if let Some(left_over) = time.checked_sub(start_time.elapsed()) {
242                //If any time left reload state to check flag again before entering new loop
243                time = left_over;
244                state = self.payload().state.load(Ordering::Acquire);
245            } else {
246                break;
247            }
248        }
249        state = self.payload().state.fetch_and(!WAKER_SET, Ordering::AcqRel);
250
251        if state & WAKER_SET == WAKER_SET {
252            self.payload().take_notifier();
253        }
254
255        if state & READY == READY {
256            Ok(Some(self.consume()))
257        } else {
258            Ok(None)
259        }
260    }
261}
262
263impl<T> Drop for Receiver<T> {
264    #[inline(always)]
265    fn drop(&mut self) {
266        //Make sure to guarantee we acquire SEND_CLOSED prior setting RECV_CLOSED
267        let state = self.payload().state.fetch_or(RECV_CLOSED, Ordering::AcqRel);
268        if state & SEND_CLOSED == SEND_CLOSED {
269            unsafe {
270                let _ = Box::from_raw(self.payload.as_ptr());
271            }
272        }
273    }
274}
275
276impl<T> Future for Receiver<T> {
277    type Output = Result<T, JoinError>;
278
279    fn poll(self: pin::Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
280        let mut state = self.payload().state.load(Ordering::Acquire);
281
282        if state & CONSUMED == CONSUMED {
283            return task::Poll::Ready(Err(JoinError::AlreadyConsumed));
284        } else if state & READY == READY {
285            return task::Poll::Ready(Ok(self.consume()));
286        } else if state & SEND_CLOSED == SEND_CLOSED {
287            return task::Poll::Ready(Err(JoinError::Disconnect));
288        }
289
290        //Account for spontaneous wake up
291        if state & WAKER_SET == WAKER_SET {
292            state = self.payload().state.load(Ordering::Acquire);
293        } else {
294            state = self.payload().set_notifier(Notifier::Waker(cx.waker().clone()));
295        }
296
297        //Just in case double-check
298        if state & CONSUMED == CONSUMED {
299            return task::Poll::Ready(Err(JoinError::AlreadyConsumed));
300        } else if state & READY == READY {
301            return task::Poll::Ready(Ok(self.consume()));
302        } else if state & SEND_CLOSED == SEND_CLOSED {
303            return task::Poll::Ready(Err(JoinError::Disconnect));
304        } else {
305            task::Poll::Pending
306        }
307    }
308}
309
310unsafe impl<T: Send> Send for Receiver<T> {}
311impl<T> Unpin for Receiver<T> {}
312
313//Impossible to guarantee as we need to write waker without lock
314//unsafe impl<T> Sync for Receiver<T> {}
315
316///Creates new oneshot pipe
317pub fn oneshot<T>() -> (Sender<T>, Receiver<T>) {
318    let payload = ptr::NonNull::from(Box::leak(Box::new(Payload::new())));
319
320    let sender = Sender {
321        payload,
322    };
323
324    let receiver = Receiver {
325        payload,
326    };
327
328    (sender, receiver)
329}