// wmark/watermark/future.rs

1use async_channel::{unbounded, Receiver, Sender};
2use atomic_refcell::AtomicRefCell as RefCell;
3use crossbeam_utils::CachePadded;
4use futures_channel::oneshot;
5use futures_util::FutureExt;
6use smallvec_wrapper::MediumVec;
7
8use core::{
9  cmp::Reverse,
10  sync::atomic::{AtomicU64, Ordering},
11};
12
13#[cfg(feature = "std")]
14use std::{
15  borrow::Cow,
16  collections::{BinaryHeap, HashMap},
17  sync::Arc,
18};
19
20#[cfg(not(feature = "std"))]
21use alloc::{borrow::Cow, collections::BinaryHeap};
22
23#[cfg(not(feature = "std"))]
24use hashbrown::HashMap;
25
26use crate::{AsyncCloser, AsyncSpawner, WaterMarkError};
27
/// Result alias used throughout this module; the error type is always [`WaterMarkError`].
type Result<T> = core::result::Result<T, WaterMarkError>;
29
/// The index payload carried by a [`Mark`].
#[derive(Debug)]
enum MarkIndex {
  /// One index (sent by `begin`, `done` and `wait_for_mark`).
  Single(u64),
  /// A batch of indices sent as a single message (by `begin_many` / `done_many`);
  /// the background loop applies them one at a time, in order.
  Multiple(MediumVec<u64>),
}
35
/// A message sent from an [`AsyncWaterMark`] handle to its background `process` loop.
#[derive(Debug)]
struct Mark {
  /// The index or indices this message refers to.
  index: MarkIndex,
  /// `Some` turns the message into a wait request (from `wait_for_mark`): the loop
  /// parks this sender and drops it once the index is done. `None` for `begin`/`done`.
  waiter: Option<oneshot::Sender<()>>,
  /// `true` for `done`/`done_many`, `false` for `begin`/`begin_many`. Not consulted
  /// by the loop when `waiter` is `Some`.
  done: bool,
}
42
/// State shared between an [`AsyncWaterMark`] handle and its background `process`
/// task: the handle keeps it in an `Arc`, and `init` moves a clone of that `Arc` into
/// the spawned task.
#[derive(Debug)]
struct Inner<S> {
  /// Largest index `k` such that every index `<= k` is done. Advanced by `process`;
  /// also overwritten directly by `AsyncWaterMark::set_done_util`.
  /// (`CachePadded` gives each counter its own cache line, presumably to avoid false
  /// sharing between the two hot atomics.)
  done_until: CachePadded<AtomicU64>,
  /// Last index passed to `begin`, or the final element of the last `begin_many` batch.
  last_index: CachePadded<AtomicU64>,
  /// Name returned by `AsyncWaterMark::name` and included in `process`'s assertion message.
  name: Cow<'static, str>,
  /// Sending half of the mark channel, used by every `AsyncWaterMark` method that
  /// enqueues work. Both halves live here, so a receiver exists for as long as a
  /// handle can send.
  mark_tx: Sender<Mark>,
  /// Receiving half of the mark channel, drained by `process`.
  mark_rx: Receiver<Mark>,
  /// Ties the spawner type `S` to this struct without storing a value of it.
  _spawner: core::marker::PhantomData<S>,
}
52
53impl<S: AsyncSpawner> Inner<S> {
54  async fn process(&self, closer: AsyncCloser<S>) {
55    scopeguard::defer!(closer.done(););
56
57    let mut indices: BinaryHeap<Reverse<u64>> = BinaryHeap::new();
58    // pending maps raft proposal index to the number of pending mutations for this proposal.
59    let pending: RefCell<HashMap<u64, i64>> = RefCell::new(HashMap::new());
60    let waiters: RefCell<HashMap<u64, MediumVec<oneshot::Sender<()>>>> =
61      RefCell::new(HashMap::new());
62
63    let mut process_one = |idx: u64, done: bool| {
64      // If not already done, then set. Otherwise, don't undo a done entry.
65      let mut pending = pending.borrow_mut();
66      let mut waiters = waiters.borrow_mut();
67
68      if !pending.contains_key(&idx) {
69        indices.push(Reverse(idx));
70      }
71
72      let mut delta = 1;
73      if done {
74        delta = -1;
75      }
76      pending
77        .entry(idx)
78        .and_modify(|v| *v += delta)
79        .or_insert(delta);
80
81      // Update mark by going through all indices in order; and checking if they have
82      // been done. Stop at the first index, which isn't done.
83      let done_until = self.done_until.load(Ordering::SeqCst);
84      assert!(
85        done_until <= idx,
86        "name: {}, done_until: {}, idx: {}",
87        self.name,
88        done_until,
89        idx
90      );
91
92      let mut until = done_until;
93
94      while !indices.is_empty() {
95        let min = indices.peek().unwrap().0;
96        if let Some(done) = pending.get(&min) {
97          if done.gt(&0) {
98            break; // len(indices) will be > 0.
99          }
100        }
101        // Even if done is called multiple times causing it to become
102        // negative, we should still pop the index.
103        indices.pop();
104        pending.remove(&min);
105        until = min;
106      }
107
108      if until != done_until {
109        assert_eq!(
110          self
111            .done_until
112            .compare_exchange(done_until, until, Ordering::SeqCst, Ordering::Acquire),
113          Ok(done_until)
114        );
115      }
116
117      if until - done_until <= waiters.len() as u64 {
118        // Close channel and remove from waiters.
119        (done_until + 1..=until).for_each(|idx| {
120          let _ = waiters.remove(&idx);
121        });
122      } else {
123        // Close and drop idx <= util channels.
124        waiters.retain(|idx, _| *idx > until);
125      }
126    };
127
128    let closer = closer.listen();
129    loop {
130      futures_util::select_biased! {
131        _ = closer.wait().fuse() => return,
132        mark = self.mark_rx.recv().fuse() => match mark {
133          Ok(mark) => {
134            if let Some(wait_tx) = mark.waiter {
135              if let MarkIndex::Single(index) = mark.index {
136                let done_until = self.done_until.load(Ordering::SeqCst);
137                if done_until >= index {
138                  let _ = wait_tx; // Close channel.
139                } else {
140                  waiters.borrow_mut().entry(index).or_default().push(wait_tx);
141                }
142              }
143            } else {
144              match mark.index {
145                MarkIndex::Single(idx) => process_one(idx, mark.done),
146                MarkIndex::Multiple(indices) => indices.into_iter().for_each(|idx| process_one(idx, mark.done)),
147              }
148            }
149          },
150          Err(_) => {
151            // Channel closed.
152            #[cfg(feature = "tracing")]
153            tracing::error!(target: "watermark", err = "watermark has been dropped.");
154            return;
155          }
156        },
157      }
158    }
159  }
160}
161
162/// WaterMark is used to keep track of the minimum un-finished index. Typically, an index k becomes
163/// finished or "done" according to a WaterMark once `done(k)` has been called
164///  1. as many times as `begin(k)` has, AND
165///  2. a positive number of times.
166///
167/// An index may also become "done" by calling `set_done_until` at a time such that it is not
168/// inter-mingled with `begin/done` calls.
169///
170/// Since `done_until` and `last_index` addresses are passed to sync/atomic packages, we ensure that they
171/// are 64-bit aligned by putting them at the beginning of the structure.
172#[derive(Debug)]
173pub struct AsyncWaterMark<S: AsyncSpawner> {
174  inner: Arc<Inner<S>>,
175  initialized: bool,
176}
177
178impl<S: AsyncSpawner> AsyncWaterMark<S> {
179  /// Create a new WaterMark with the given name.
180  ///
181  /// **Note**: Before using the watermark, you must call `init` to start the background thread.
182  #[inline]
183  pub fn new(name: Cow<'static, str>) -> Self {
184    let (mark_tx, mark_rx) = unbounded();
185    Self {
186      inner: Arc::new(Inner {
187        done_until: CachePadded::new(AtomicU64::new(0)),
188        last_index: CachePadded::new(AtomicU64::new(0)),
189        name,
190        mark_tx,
191        mark_rx,
192        _spawner: core::marker::PhantomData,
193      }),
194      initialized: false,
195    }
196  }
197
198  /// Returns the name of the watermark.
199  #[inline(always)]
200  pub fn name(&self) -> &str {
201    self.inner.name.as_ref()
202  }
203
204  /// Initializes a WaterMark struct. MUST be called before using it.
205  #[inline]
206  pub fn init(&mut self, closer: AsyncCloser<S>) {
207    if self.initialized {
208      return;
209    }
210
211    let inner = self.inner.clone();
212    self.initialized = true;
213
214    S::spawn_detach(async move {
215      inner.process(closer).await;
216    });
217  }
218
219  /// Sets the last index to the given value.
220  #[inline]
221  pub fn begin(&self, index: u64) -> Result<()> {
222    self.check()?;
223    self.inner.last_index.store(index, Ordering::SeqCst);
224    self
225      .inner
226      .mark_tx
227      .try_send(Mark {
228        index: MarkIndex::Single(index),
229        waiter: None,
230        done: false,
231      })
232      .unwrap(); // we hold both rx and tx, so cannot fail
233    Ok(())
234  }
235
236  /// Works like [`begin`] but accepts multiple indices.
237  #[inline]
238  pub fn begin_many(&self, indices: MediumVec<u64>) -> Result<()> {
239    if indices.is_empty() {
240      return Ok(());
241    }
242
243    self.check()?;
244
245    let last_index = *indices.last().unwrap();
246    self.inner.last_index.store(last_index, Ordering::SeqCst);
247    self
248      .inner
249      .mark_tx
250      .try_send(Mark {
251        index: MarkIndex::Multiple(indices),
252        waiter: None,
253        done: false,
254      })
255      .unwrap(); // we hold both rx and tx, so cannot fail
256    Ok(())
257  }
258
259  /// Sets a single index as done.
260  #[inline]
261  pub fn done(&self, index: u64) -> Result<()> {
262    self.check()?;
263    self
264      .inner
265      .mark_tx
266      .try_send(Mark {
267        index: MarkIndex::Single(index),
268        waiter: None,
269        done: true,
270      })
271      .unwrap(); // we hold both rx and tx, so cannot fail
272    Ok(())
273  }
274
275  /// Sets multiple indices as done.
276  #[inline]
277  pub fn done_many(&self, indices: MediumVec<u64>) -> Result<()> {
278    self.check()?;
279    self
280      .inner
281      .mark_tx
282      .try_send(Mark {
283        index: MarkIndex::Multiple(indices),
284        waiter: None,
285        done: true,
286      })
287      .unwrap(); // we hold both rx and tx, so cannot fail
288    Ok(())
289  }
290
291  /// Returns the maximum index that has the property that all indices
292  /// less than or equal to it are done.
293  #[inline]
294  pub fn done_until(&self) -> Result<u64> {
295    self
296      .check()
297      .map(|_| self.inner.done_until.load(Ordering::SeqCst))
298  }
299
300  /// Sets the maximum index that has the property that all indices
301  /// less than or equal to it are done.
302  #[inline]
303  pub fn set_done_util(&self, val: u64) -> Result<()> {
304    self
305      .check()
306      .map(|_| self.inner.done_until.store(val, Ordering::SeqCst))
307  }
308
309  /// Returns the last index for which `begin` has been called.
310  #[inline]
311  pub fn last_index(&self) -> Result<u64> {
312    self
313      .check()
314      .map(|_| self.inner.last_index.load(Ordering::SeqCst))
315  }
316
317  /// Waits until the given index is marked as done.
318  #[inline]
319  pub async fn wait_for_mark(&self, index: u64) -> Result<()> {
320    if self.inner.done_until.load(Ordering::SeqCst) >= index {
321      return Ok(());
322    }
323
324    let (wait_tx, wait_rx) = oneshot::channel();
325    self
326      .inner
327      .mark_tx
328      .try_send(Mark {
329        index: MarkIndex::Single(index),
330        waiter: Some(wait_tx),
331        done: false,
332      })
333      .unwrap(); // we hold both rx and tx, so cannot fail?
334
335    let _ = wait_rx.await;
336    Ok(())
337  }
338
339  #[inline]
340  fn check(&self) -> Result<()> {
341    if !self.initialized {
342      Err(WaterMarkError::Uninitialized)
343    } else {
344      Ok(())
345    }
346  }
347}
348
#[cfg(test)]
mod tests {
  use super::*;
  use core::future::Future;

  /// Collects a slice of indices into the batch type taken by `begin_many` / `done_many`.
  fn batch(indices: &[u64]) -> MediumVec<u64> {
    indices.iter().copied().collect()
  }

  /// Builds and initializes a watermark named "watermark" and hands it to `f`. Then it
  /// signals the closer and waits for the background task to wind down.
  async fn init_and_close<S, Fut, F>(f: F)
  where
    Fut: Future,
    F: FnOnce(AsyncWaterMark<S>) -> Fut,
    S: AsyncSpawner,
  {
    let closer = AsyncCloser::new(1);

    let mut wm = AsyncWaterMark::new("watermark".into());
    wm.init(closer.clone());
    assert_eq!(wm.name(), "watermark");

    f(wm).await;

    closer.signal_and_wait().await;
  }

  #[tokio::test]
  async fn test_basic() {
    init_and_close::<crate::TokioSpawner, _, _>(|_wm| async {}).await;
  }

  #[tokio::test]
  async fn test_begin_done() {
    init_and_close::<crate::TokioSpawner, _, _>(|wm| async move {
      wm.begin(1).unwrap();
      wm.begin_many(batch(&[2, 3])).unwrap();

      wm.done(1).unwrap();
      wm.done_many(batch(&[2, 3])).unwrap();
    })
    .await;
  }

  #[tokio::test]
  async fn test_wait_for_mark() {
    init_and_close::<crate::TokioSpawner, _, _>(|wm| async move {
      wm.begin_many(batch(&[1, 2, 3])).unwrap();
      wm.done_many(batch(&[2, 3])).unwrap();

      // Index 1 is still outstanding, so the watermark cannot have moved yet.
      assert_eq!(wm.done_until().unwrap(), 0);

      wm.done(1).unwrap();
      for target in [1, 3] {
        wm.wait_for_mark(target).await.unwrap();
      }
      assert_eq!(wm.done_until().unwrap(), 3);
    })
    .await;
  }

  #[tokio::test]
  async fn test_set_done_until() {
    init_and_close::<crate::TokioSpawner, _, _>(|wm| async move {
      wm.set_done_util(1).unwrap();
      assert_eq!(wm.done_until().unwrap(), 1);
    })
    .await;
  }

  #[tokio::test]
  async fn test_last_index() {
    init_and_close::<crate::TokioSpawner, _, _>(|wm| async move {
      wm.begin_many(batch(&[1, 2, 3])).unwrap();
      wm.done_many(batch(&[2, 3])).unwrap();

      // The final element of the last `begin_many` batch is the last index.
      assert_eq!(wm.last_index().unwrap(), 3);
    })
    .await;
  }

  #[tokio::test]
  async fn test_multiple_singles() {
    // Repeated signals on a default closer must not hang.
    let fresh = AsyncCloser::<crate::TokioSpawner>::default();
    fresh.signal();
    fresh.signal();
    fresh.signal_and_wait().await;

    // A closer whose single task already reported `done` tolerates repeated signals.
    let finished = AsyncCloser::<crate::TokioSpawner>::new(1);
    finished.done();
    finished.signal_and_wait().await;
    finished.signal_and_wait().await;
    finished.signal();
  }

  #[tokio::test]
  async fn test_closer() {
    let closer = AsyncCloser::<crate::TokioSpawner>::new(1);
    let worker = closer.clone();
    tokio::spawn(async move {
      worker.listen().wait().await;
      worker.done();
    });
    closer.signal_and_wait().await;
  }

  #[tokio::test]
  async fn test_closer_() {
    use async_channel::unbounded;
    use core::time::Duration;

    let (ack_tx, ack_rx) = unbounded();

    let closer = AsyncCloser::<crate::TokioSpawner>::default();

    // Ten listeners each acknowledge once they observe the signal.
    for _ in 0..10 {
      let listener = closer.clone();
      let ack = ack_tx.clone();
      tokio::spawn(async move {
        listener.listen().wait().await;
        ack.send(()).await.unwrap();
      });
    }
    closer.signal();

    for _ in 0..10 {
      tokio::time::timeout(Duration::from_millis(1000), ack_rx.recv())
        .await
        .unwrap()
        .unwrap();
    }
  }
}