toad_common/
stem.rs

1use core::ops::{Deref, DerefMut};
2
3#[cfg(feature = "std")]
4type Inner<T> = std::sync::RwLock<T>;
5
6#[cfg(not(feature = "std"))]
7type Inner<T> = core::cell::RefCell<T>;
8
9/// A thread-safe mutable memory location that allows
10/// for many concurrent readers or a single writer.
11///
12/// When feature `std` enabled, this uses [`std::sync::RwLock`].
13/// When `std` disabled, uses [`core::cell::Cell`].
14#[derive(Debug, Default)]
15pub struct Stem<T>(Inner<T>);
16
17impl<T> Stem<T> {
18  /// Create a new Stem cell
19  pub const fn new(t: T) -> Self {
20    Self(Inner::new(t))
21  }
22
23  /// Map a reference to `T` to a new type
24  ///
25  /// This will block if called concurrently with `map_mut`.
26  ///
27  /// There can be any number of concurrent `map_ref`
28  /// sections running at a given time.
29  pub fn map_ref<F, R>(&self, f: F) -> R
30    where F: for<'a> FnMut(&'a T) -> R
31  {
32    self.0.map_ref(f)
33  }
34
35  /// Map a mutable reference to `T` to a new type
36  ///
37  /// This will block if called concurrently with `map_ref` or `map_mut`.
38  pub fn map_mut<F, R>(&self, f: F) -> R
39    where F: for<'a> FnMut(&'a mut T) -> R
40  {
41    self.0.map_mut(f)
42  }
43}
44
// NOTE(orion): I chose to use a trait here to tie `RwLock`
// and `RefCell` together in a testable way, to keep the actual
// code behind feature flags extremely thin.
48
/// Interior-mutable storage that grants scoped access to a `T`.
///
/// This trait is the abstraction underneath [`Stem`]; code should use
/// [`Stem`] directly rather than this trait.
pub trait StemCellBehavior<T> {
  /// Construct `Self` around the initial value `t`.
  fn new(t: T) -> Self
  where
    Self: Sized;

  /// Call `f` with a shared reference to the inner `T` and hand back its result.
  ///
  /// If a `map_mut` is in progress, an implementor may either wait
  /// for it to finish or panic.
  fn map_ref<F, R>(&self, f: F) -> R
  where
    F: FnMut(&T) -> R;

  /// Call `f` with an exclusive reference to the inner `T` and hand back its result.
  ///
  /// If any `map_ref` or `map_mut` is in progress, an implementor may
  /// either wait for it to finish or panic.
  fn map_mut<F, R>(&self, f: F) -> R
  where
    F: FnMut(&mut T) -> R;
}
72
// Backs `Stem` when the `std` feature is enabled: readers share the lock,
// writers get it exclusively, and conflicting callers block.
#[cfg(feature = "std")]
impl<T> StemCellBehavior<T> for std::sync::RwLock<T> {
  fn new(t: T) -> Self {
    // Inherent `RwLock::new` takes priority over this trait method,
    // so this is not a recursive call.
    std::sync::RwLock::new(t)
  }

  fn map_ref<F, R>(&self, mut func: F) -> R
  where
    F: for<'a> FnMut(&'a T) -> R,
  {
    // Panics if the lock is poisoned. The read lock is held until
    // `guard` drops, i.e. after `func` has returned.
    let guard = self.read().unwrap();
    func(guard.deref())
  }

  fn map_mut<F, R>(&self, mut func: F) -> R
  where
    F: for<'a> FnMut(&'a mut T) -> R,
  {
    // Panics if the lock is poisoned. The write lock is held until
    // `guard` drops, i.e. after `func` has returned.
    let mut guard = self.write().unwrap();
    func(guard.deref_mut())
  }
}
91
92impl<T> StemCellBehavior<T> for core::cell::RefCell<T> {
93  fn new(t: T) -> Self {
94    Self::new(t)
95  }
96
97  fn map_ref<F, R>(&self, mut f: F) -> R
98    where F: for<'a> FnMut(&'a T) -> R
99  {
100    f(self.borrow().deref())
101  }
102
103  fn map_mut<F, R>(&self, mut f: F) -> R
104    where F: for<'a> FnMut(&'a mut T) -> R
105  {
106    f(self.borrow_mut().deref_mut())
107  }
108}
109
#[cfg(test)]
mod test {
  use core::cell::RefCell;
  use core::sync::atomic::{AtomicBool, Ordering};
  use std::sync::{Barrier, RwLock};

  use super::*;

  #[test]
  fn refcell_modify() {
    let s = RefCell::new(Vec::<usize>::new());
    s.map_mut(|v| v.push(12));
    s.map_ref(|v| assert_eq!(v, &vec![12usize]));
  }

  // Nested (same-thread) shared borrows must be allowed.
  #[test]
  fn refcell_concurrent_read_does_not_panic() {
    let s = RefCell::new(Vec::<usize>::new());
    s.map_ref(|_| s.map_ref(|_| ()));
  }

  #[test]
  fn rwlock_modify() {
    let s = RwLock::new(Vec::<usize>::new());
    s.map_mut(|v| v.push(12));
    s.map_ref(|v| assert_eq!(v, &vec![12usize]));
  }

  // NOTE(review): std documents that `RwLock::read` *might* panic when the
  // current thread already holds the lock, so this nested (same-thread)
  // read relies on platform behavior rather than a guarantee.
  #[test]
  fn rwlock_concurrent_read_does_not_panic() {
    let s = RwLock::new(Vec::<usize>::new());
    s.map_ref(|_| s.map_ref(|_| ()));
  }

  /// A `map_mut` issued while another thread is inside `map_ref` must
  /// wait until that `map_ref` section has ended before it runs.
  #[test]
  fn stem_modify_blocks_until_refs_dropped() {
    let stem: Stem<Vec<usize>> = Stem::new(Vec::new());

    let start = Barrier::new(3);
    let reading = Barrier::new(3);
    let reading_done = Barrier::new(2);
    let modify_done = Barrier::new(3);

    // Set by the reader as the last thing it does inside `map_ref`;
    // the writer must never observe it as `false`.
    let read_released = AtomicBool::new(false);

    std::thread::scope(|scope| {
      // Workers return what they observed instead of asserting in place,
      // so a failed check cannot leave the other threads stuck on a barrier.
      let reader = scope.spawn(|| {
        start.wait();
        let saw_empty = stem.map_ref(|v| {
          let empty = v.is_empty();
          reading.wait();
          reading_done.wait();
          read_released.store(true, Ordering::SeqCst);
          empty
        });
        modify_done.wait();
        saw_empty
      });

      let writer = scope.spawn(|| {
        start.wait();
        reading.wait();
        // The reader is inside `map_ref` now, so this must block until it returns.
        let saw_read_released = stem.map_mut(|v| {
          v.push(12);
          read_released.load(Ordering::SeqCst)
        });
        modify_done.wait();
        saw_read_released
      });

      start.wait();
      reading.wait();
      reading_done.wait(); // lets the reader leave its `map_ref` section
      modify_done.wait();

      assert!(reader.join().unwrap(), "reader should see the initial empty vec");
      assert!(writer.join().unwrap(),
              "map_mut ran while a map_ref section was still active");
    });

    stem.map_ref(|v| assert_eq!(v, &vec![12]));
  }
}