zenoh_sync/
event.rs

1//
2// Copyright (c) 2024 ZettaScale Technology
3//
4// This program and the accompanying materials are made available under the
5// terms of the Eclipse Public License 2.0 which is available at
6// http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
7// which is available at https://www.apache.org/licenses/LICENSE-2.0.
8//
9// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
10//
11// Contributors:
12//   ZettaScale Zenoh Team, <zenoh@zettascale.tech>
13//
14use std::{
15    fmt,
16    sync::{
17        atomic::{AtomicU16, AtomicU8, Ordering},
18        Arc,
19    },
20    time::{Duration, Instant},
21};
22
23use event_listener::{Event as EventLib, Listener};
24
// Error types

/// Message rendered by [`WaitError`] and by the `WaitError` variants of the
/// deadline/timeout error enums.
const WAIT_ERR_STR: &str = "No notifier available";

/// Error returned by [`Waiter`] wait operations when every [`Notifier`] has
/// been dropped, so no notification can ever arrive.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct WaitError;

impl fmt::Display for WaitError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(WAIT_ERR_STR)
    }
}

impl fmt::Debug for WaitError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Debug intentionally renders the same human-readable message as Display.
        f.write_str(WAIT_ERR_STR)
    }
}

impl std::error::Error for WaitError {}
42
43#[repr(u8)]
44pub enum WaitDeadlineError {
45    Deadline,
46    WaitError,
47}
48
49impl fmt::Display for WaitDeadlineError {
50    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51        write!(f, "{self:?}")
52    }
53}
54
55impl fmt::Debug for WaitDeadlineError {
56    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57        match self {
58            Self::Deadline => f.write_str("Deadline reached"),
59            Self::WaitError => f.write_str(WAIT_ERR_STR),
60        }
61    }
62}
63
64impl std::error::Error for WaitDeadlineError {}
65
66#[repr(u8)]
67pub enum WaitTimeoutError {
68    Timeout,
69    WaitError,
70}
71
72impl fmt::Display for WaitTimeoutError {
73    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74        write!(f, "{self:?}")
75    }
76}
77
78impl fmt::Debug for WaitTimeoutError {
79    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80        match self {
81            Self::Timeout => f.write_str("Timeout expired"),
82            Self::WaitError => f.write_str(WAIT_ERR_STR),
83        }
84    }
85}
86
87impl std::error::Error for WaitTimeoutError {}
88
/// Message rendered by [`NotifyError`].
const NOTIFY_ERR_STR: &str = "No waiter available";

/// Error returned by [`Notifier::notify`] when every [`Waiter`] has been
/// dropped, so the notification could never be consumed.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct NotifyError;

impl fmt::Display for NotifyError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(NOTIFY_ERR_STR)
    }
}

impl fmt::Debug for NotifyError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Debug intentionally renders the same human-readable message as Display.
        f.write_str(NOTIFY_ERR_STR)
    }
}

impl std::error::Error for NotifyError {}
105
106// Inner
107struct EventInner {
108    event: EventLib,
109    flag: AtomicU8,
110    notifiers: AtomicU16,
111    waiters: AtomicU16,
112}
113
/// Flag value meaning: no notification pending and the event is still open.
const UNSET: u8 = 0b00;
/// Flag bit raised by `EventInner::set` and consumed by `EventInner::check`.
const OK: u8 = 0b01;
/// Flag bit stored by `EventInner::err` once the event has been closed.
const ERR: u8 = 0b10;

/// What `EventInner::check` found when it consumed the flag.
#[repr(u8)]
enum EventCheck {
    /// Nothing pending yet.
    Unset = UNSET,
    /// A notification was pending and has now been consumed.
    Ok = OK,
    /// The event is closed.
    Err = ERR,
}

/// What `EventInner::set` found when it raised the notification bit.
#[repr(u8)]
enum EventSet {
    /// The notification is now pending.
    Ok = OK,
    /// The event was already closed.
    Err = ERR,
}
130
131impl EventInner {
132    fn check(&self) -> EventCheck {
133        let f = self.flag.fetch_and(!OK, Ordering::SeqCst);
134        if f & ERR != 0 {
135            return EventCheck::Err;
136        }
137        if f == OK {
138            return EventCheck::Ok;
139        }
140        EventCheck::Unset
141    }
142
143    fn set(&self) -> EventSet {
144        let f = self.flag.fetch_or(OK, Ordering::SeqCst);
145        if f & ERR != 0 {
146            return EventSet::Err;
147        }
148        EventSet::Ok
149    }
150
151    fn err(&self) {
152        self.flag.store(ERR, Ordering::SeqCst);
153    }
154}
155
156/// Creates a new lock-free event variable. Every time a [`Notifier`] calls ['Notifier::notify`], one [`Waiter`] will be waken-up.
157/// If no waiter is waiting when the `notify` is called, the notification will not be lost. That means the next waiter will return
158/// immediately when calling `wait`.
159pub fn new() -> (Notifier, Waiter) {
160    let inner = Arc::new(EventInner {
161        event: EventLib::new(),
162        flag: AtomicU8::new(UNSET),
163        notifiers: AtomicU16::new(1),
164        waiters: AtomicU16::new(1),
165    });
166    (Notifier(inner.clone()), Waiter(inner))
167}
168
169/// A [`Notifier`] is used to notify and wake up one and only one [`Waiter`].
170#[repr(transparent)]
171pub struct Notifier(Arc<EventInner>);
172
173impl Notifier {
174    /// Notifies one pending listener
175    #[inline]
176    pub fn notify(&self) -> Result<(), NotifyError> {
177        // Set the flag.
178        match self.0.set() {
179            EventSet::Ok => {
180                self.0.event.notify_additional_relaxed(1);
181                Ok(())
182            }
183            EventSet::Err => Err(NotifyError),
184        }
185    }
186}
187
188impl Clone for Notifier {
189    fn clone(&self) -> Self {
190        let n = self.0.notifiers.fetch_add(1, Ordering::SeqCst);
191        // Panic on overflow
192        assert!(n != 0);
193        Self(self.0.clone())
194    }
195}
196
197impl Drop for Notifier {
198    fn drop(&mut self) {
199        let n = self.0.notifiers.fetch_sub(1, Ordering::SeqCst);
200        if n == 1 {
201            // The last Notifier has been dropped, close the event and notify everyone
202            self.0.err();
203            self.0.event.notify(usize::MAX);
204        }
205    }
206}
207
208#[repr(transparent)]
209pub struct Waiter(Arc<EventInner>);
210
211impl Waiter {
212    /// Waits for the condition to be notified
213    #[inline]
214    pub async fn wait_async(&self) -> Result<(), WaitError> {
215        // Wait until the flag is set.
216        loop {
217            // Check the flag.
218            match self.0.check() {
219                EventCheck::Ok => break,
220                EventCheck::Unset => {}
221                EventCheck::Err => return Err(WaitError),
222            }
223
224            // Start listening for events.
225            let listener = self.0.event.listen();
226
227            // Check the flag again after creating the listener.
228            match self.0.check() {
229                EventCheck::Ok => break,
230                EventCheck::Unset => {}
231                EventCheck::Err => return Err(WaitError),
232            }
233
234            // Wait for a notification and continue the loop.
235            listener.await;
236        }
237
238        Ok(())
239    }
240
241    /// Waits for the condition to be notified
242    #[inline]
243    pub fn wait(&self) -> Result<(), WaitError> {
244        // Wait until the flag is set.
245        loop {
246            // Check the flag.
247            match self.0.check() {
248                EventCheck::Ok => break,
249                EventCheck::Unset => {}
250                EventCheck::Err => return Err(WaitError),
251            }
252
253            // Start listening for events.
254            let listener = self.0.event.listen();
255
256            // Check the flag again after creating the listener.
257            match self.0.check() {
258                EventCheck::Ok => break,
259                EventCheck::Unset => {}
260                EventCheck::Err => return Err(WaitError),
261            }
262
263            // Wait for a notification and continue the loop.
264            listener.wait();
265        }
266
267        Ok(())
268    }
269
270    /// Waits for the condition to be notified or returns an error when the deadline is reached
271    #[inline]
272    pub fn wait_deadline(&self, deadline: Instant) -> Result<(), WaitDeadlineError> {
273        // Wait until the flag is set.
274        loop {
275            // Check the flag.
276            match self.0.check() {
277                EventCheck::Ok => break,
278                EventCheck::Unset => {}
279                EventCheck::Err => return Err(WaitDeadlineError::WaitError),
280            }
281
282            // Start listening for events.
283            let listener = self.0.event.listen();
284
285            // Check the flag again after creating the listener.
286            match self.0.check() {
287                EventCheck::Ok => break,
288                EventCheck::Unset => {}
289                EventCheck::Err => return Err(WaitDeadlineError::WaitError),
290            }
291
292            // Wait for a notification and continue the loop.
293            if listener.wait_deadline(deadline).is_none() {
294                return Err(WaitDeadlineError::Deadline);
295            }
296        }
297
298        Ok(())
299    }
300
301    /// Waits for the condition to be notified or returns an error when the timeout is expired
302    #[inline]
303    pub fn wait_timeout(&self, timeout: Duration) -> Result<(), WaitTimeoutError> {
304        // Wait until the flag is set.
305        loop {
306            // Check the flag.
307            match self.0.check() {
308                EventCheck::Ok => break,
309                EventCheck::Unset => {}
310                EventCheck::Err => return Err(WaitTimeoutError::WaitError),
311            }
312
313            // Start listening for events.
314            let listener = self.0.event.listen();
315
316            // Check the flag again after creating the listener.
317            match self.0.check() {
318                EventCheck::Ok => break,
319                EventCheck::Unset => {}
320                EventCheck::Err => return Err(WaitTimeoutError::WaitError),
321            }
322
323            // Wait for a notification and continue the loop.
324            if listener.wait_timeout(timeout).is_none() {
325                return Err(WaitTimeoutError::Timeout);
326            }
327        }
328
329        Ok(())
330    }
331}
332
333impl Clone for Waiter {
334    fn clone(&self) -> Self {
335        let n = self.0.waiters.fetch_add(1, Ordering::Relaxed);
336        // Panic on overflow
337        assert!(n != 0);
338        Self(self.0.clone())
339    }
340}
341
342impl Drop for Waiter {
343    fn drop(&mut self) {
344        let n = self.0.waiters.fetch_sub(1, Ordering::SeqCst);
345        if n == 1 {
346            // The last Waiter has been dropped, close the event
347            self.0.err();
348        }
349    }
350}
351
#[cfg(test)]
mod tests {
    /// Single waiter using `wait_timeout`: a pending notification is delivered,
    /// repeated notifications are coalesced, and dropping the notifier closes the event.
    #[test]
    fn event_timeout() {
        use std::{
            sync::{Arc, Barrier},
            time::Duration,
        };

        use super::WaitTimeoutError;

        let barrier = Arc::new(Barrier::new(2));
        let (notifier, waiter) = super::new();
        let tslot = Duration::from_secs(1);

        let bs = barrier.clone();
        let s = std::thread::spawn(move || {
            // 1 - Wait one notification
            match waiter.wait_timeout(tslot) {
                Ok(()) => {}
                Err(WaitTimeoutError::Timeout) => panic!("Timeout {tslot:#?}"),
                Err(WaitTimeoutError::WaitError) => panic!("Event closed"),
            }

            bs.wait();

            // 2 - Being notified twice but waiting only once
            bs.wait();

            match waiter.wait_timeout(tslot) {
                Ok(()) => {}
                Err(WaitTimeoutError::Timeout) => panic!("Timeout {tslot:#?}"),
                Err(WaitTimeoutError::WaitError) => panic!("Event closed"),
            }

            match waiter.wait_timeout(tslot) {
                Ok(()) => panic!("Event Ok but it should be Timeout"),
                Err(WaitTimeoutError::Timeout) => {}
                Err(WaitTimeoutError::WaitError) => panic!("Event closed"),
            }

            bs.wait();

            // 3 - Notifier has been dropped
            bs.wait();

            waiter.wait().unwrap_err();

            bs.wait();
        });

        let bp = barrier.clone();
        let p = std::thread::spawn(move || {
            // 1 - Notify once
            notifier.notify().unwrap();

            bp.wait();

            // 2 - Notify twice
            notifier.notify().unwrap();
            notifier.notify().unwrap();

            bp.wait();
            bp.wait();

            // 3 - Drop notifier yielding an error in the waiter
            drop(notifier);

            bp.wait();
            bp.wait();
        });

        s.join().unwrap();
        p.join().unwrap();
    }

    /// Same scenario as `event_timeout`, using `wait_deadline`.
    #[test]
    fn event_deadline() {
        use std::{
            sync::{Arc, Barrier},
            time::{Duration, Instant},
        };

        use super::WaitDeadlineError;

        let barrier = Arc::new(Barrier::new(2));
        let (notifier, waiter) = super::new();
        let tslot = Duration::from_secs(1);

        let bs = barrier.clone();
        let s = std::thread::spawn(move || {
            // 1 - Wait one notification
            match waiter.wait_deadline(Instant::now() + tslot) {
                Ok(()) => {}
                Err(WaitDeadlineError::Deadline) => panic!("Timeout {tslot:#?}"),
                Err(WaitDeadlineError::WaitError) => panic!("Event closed"),
            }

            bs.wait();

            // 2 - Being notified twice but waiting only once
            bs.wait();

            match waiter.wait_deadline(Instant::now() + tslot) {
                Ok(()) => {}
                Err(WaitDeadlineError::Deadline) => panic!("Timeout {tslot:#?}"),
                Err(WaitDeadlineError::WaitError) => panic!("Event closed"),
            }

            match waiter.wait_deadline(Instant::now() + tslot) {
                Ok(()) => panic!("Event Ok but it should be Timeout"),
                Err(WaitDeadlineError::Deadline) => {}
                Err(WaitDeadlineError::WaitError) => panic!("Event closed"),
            }

            bs.wait();

            // 3 - Notifier has been dropped
            bs.wait();

            waiter.wait().unwrap_err();

            bs.wait();
        });

        let bp = barrier.clone();
        let p = std::thread::spawn(move || {
            // 1 - Notify once
            notifier.notify().unwrap();

            bp.wait();

            // 2 - Notify twice
            notifier.notify().unwrap();
            notifier.notify().unwrap();

            bp.wait();
            bp.wait();

            // 3 - Drop notifier yielding an error in the waiter
            drop(notifier);

            bp.wait();
            bp.wait();
        });

        s.join().unwrap();
        p.join().unwrap();
    }

    /// One notifier and one waiter in lock-step for `N` rounds.
    #[test]
    fn event_loop() {
        use std::{
            sync::{
                atomic::{AtomicUsize, Ordering},
                Arc, Barrier,
            },
            time::{Duration, Instant},
        };

        const N: usize = 1_000;
        static COUNTER: AtomicUsize = AtomicUsize::new(0);

        let (notifier, waiter) = super::new();
        let barrier = Arc::new(Barrier::new(2));

        let bs = barrier.clone();
        let s = std::thread::spawn(move || {
            for _ in 0..N {
                waiter.wait().unwrap();
                COUNTER.fetch_add(1, Ordering::Relaxed);
                bs.wait();
            }
        });
        let p = std::thread::spawn(move || {
            for _ in 0..N {
                notifier.notify().unwrap();
                barrier.wait();
            }
        });

        let start = Instant::now();
        let tout = Duration::from_secs(60);
        loop {
            let n = COUNTER.load(Ordering::Relaxed);
            if n == N {
                break;
            }
            if start.elapsed() > tout {
                panic!("Timeout {tout:#?}. Counter: {n}/{N}");
            }

            std::thread::sleep(Duration::from_millis(100));
        }

        s.join().unwrap();
        p.join().unwrap();
    }

    /// Two cloned notifiers against two cloned waiters.
    #[test]
    fn event_multiple() {
        use std::{
            sync::atomic::{AtomicUsize, Ordering},
            time::{Duration, Instant},
        };

        const N: usize = 1_000;
        static COUNTER: AtomicUsize = AtomicUsize::new(0);

        let (notifier, waiter) = super::new();

        // Keep `waiter` alive on this thread until both producers have been joined.
        // Once every waiter is dropped the event is closed and `notify` fails, which
        // would race with the producers' final `notify().unwrap()` calls.
        let w1 = waiter.clone();
        let s1 = std::thread::spawn(move || {
            let mut n = 0;
            while COUNTER.fetch_add(1, Ordering::Relaxed) < N - 2 {
                w1.wait().unwrap();
                n += 1;
            }
            println!("S1: {n}");
        });
        let w2 = waiter.clone();
        let s2 = std::thread::spawn(move || {
            let mut n = 0;
            while COUNTER.fetch_add(1, Ordering::Relaxed) < N - 2 {
                w2.wait().unwrap();
                n += 1;
            }
            println!("S2: {n}");
        });

        let n1 = notifier.clone();
        let p1 = std::thread::spawn(move || {
            let mut n = 0;
            while COUNTER.load(Ordering::Relaxed) < N {
                n1.notify().unwrap();
                n += 1;
                std::thread::sleep(Duration::from_millis(1));
            }
            println!("P1: {n}");
        });
        let p2 = std::thread::spawn(move || {
            let mut n = 0;
            while COUNTER.load(Ordering::Relaxed) < N {
                notifier.notify().unwrap();
                n += 1;
                std::thread::sleep(Duration::from_millis(1));
            }
            println!("P2: {n}");
        });

        // Watchdog on the test thread itself, so that a timeout actually fails the
        // test instead of panicking in a detached thread.
        let start = Instant::now();
        let tout = Duration::from_secs(60);
        loop {
            let n = COUNTER.load(Ordering::Relaxed);
            if n >= N {
                break;
            }
            if start.elapsed() > tout {
                panic!("Timeout {tout:#?}. Counter: {n}/{N}");
            }

            std::thread::sleep(Duration::from_millis(100));
        }

        p1.join().unwrap();
        p2.join().unwrap();

        s1.join().unwrap();
        s2.join().unwrap();

        // Only now may the last waiter handle go away.
        drop(waiter);
    }
}