1use async_channel::{unbounded, Receiver, Sender};
2use atomic_refcell::AtomicRefCell as RefCell;
3use crossbeam_utils::CachePadded;
4use futures_channel::oneshot;
5use futures_util::FutureExt;
6use smallvec_wrapper::MediumVec;
7
8use core::{
9 cmp::Reverse,
10 sync::atomic::{AtomicU64, Ordering},
11};
12
13#[cfg(feature = "std")]
14use std::{
15 borrow::Cow,
16 collections::{BinaryHeap, HashMap},
17 sync::Arc,
18};
19
20#[cfg(not(feature = "std"))]
21use alloc::{borrow::Cow, collections::BinaryHeap};
22
23#[cfg(not(feature = "std"))]
24use hashbrown::HashMap;
25
26use crate::{AsyncCloser, AsyncSpawner, WaterMarkError};
27
/// Crate-local result alias with the error fixed to [`WaterMarkError`].
type Result<T> = core::result::Result<T, WaterMarkError>;
29
/// Index payload of a [`Mark`]: a single index or a small batch.
#[derive(Debug)]
enum MarkIndex {
  // One begin/done/wait index.
  Single(u64),
  // A batch of indices that all share the same begin/done flag.
  Multiple(MediumVec<u64>),
}
35
/// Message shipped to the background processing task: either a begin/done
/// notification for one or more indices, or a wait request carrying a
/// oneshot sender to resolve.
#[derive(Debug)]
struct Mark {
  // Index (or batch of indices) this mark refers to.
  index: MarkIndex,
  // When `Some`, this is a wait request: the processor holds the sender
  // until `done_until` covers `index` (waiters only accompany `Single`).
  waiter: Option<oneshot::Sender<()>>,
  // `false` = begin, `true` = done. Not consulted for wait requests.
  done: bool,
}
42
/// State shared between the public [`AsyncWaterMark`] handle and its
/// background processing task.
#[derive(Debug)]
struct Inner<S> {
  // Highest index `i` such that everything at or below `i` is done.
  // Padded to avoid false sharing with `last_index`.
  done_until: CachePadded<AtomicU64>,
  // Most recent index passed to `begin`/`begin_many`.
  last_index: CachePadded<AtomicU64>,
  // Name used in assertion/diagnostic messages.
  name: Cow<'static, str>,
  // Sender half: the public methods push `Mark`s here.
  mark_tx: Sender<Mark>,
  // Receiver half: drained by `process`. Kept here so the channel stays
  // open for as long as the watermark itself is alive.
  mark_rx: Receiver<Mark>,
  // Ties the spawner type parameter to this struct without storing one.
  _spawner: core::marker::PhantomData<S>,
}
52
impl<S: AsyncSpawner> Inner<S> {
  /// Background loop: consumes [`Mark`]s from `mark_rx`, tracks per-index
  /// begin/done counts, advances `done_until`, and releases waiters.
  ///
  /// Runs until `closer` is signalled or every sender of the mark channel
  /// is dropped; either way the deferred `closer.done()` fires on exit.
  async fn process(&self, closer: AsyncCloser<S>) {
    // Always notify the closer when this task exits, however it exits.
    scopeguard::defer!(closer.done(););

    // Min-heap (via `Reverse`) of the indices currently in `pending`.
    let mut indices: BinaryHeap<Reverse<u64>> = BinaryHeap::new();
    // index -> (#begins - #dones); an index is complete once <= 0.
    let pending: RefCell<HashMap<u64, i64>> = RefCell::new(HashMap::new());
    // index -> oneshot senders released once `done_until` reaches it.
    let waiters: RefCell<HashMap<u64, MediumVec<oneshot::Sender<()>>>> =
      RefCell::new(HashMap::new());

    let mut process_one = |idx: u64, done: bool| {
      let mut pending = pending.borrow_mut();
      let mut waiters = waiters.borrow_mut();

      // First sighting of this index: track it in the heap too.
      if !pending.contains_key(&idx) {
        indices.push(Reverse(idx));
      }

      // A begin increments the outstanding count, a done decrements it.
      let mut delta = 1;
      if done {
        delta = -1;
      }
      pending
        .entry(idx)
        .and_modify(|v| *v += delta)
        .or_insert(delta);

      let done_until = self.done_until.load(Ordering::SeqCst);
      // Invariant: marks never arrive for an index already reported done.
      assert!(
        done_until <= idx,
        "name: {}, done_until: {}, idx: {}",
        self.name,
        done_until,
        idx
      );

      let mut until = done_until;

      // Pop fully-completed indices in ascending order; stop at the first
      // index that still has outstanding begins.
      while !indices.is_empty() {
        let min = indices.peek().unwrap().0;
        if let Some(done) = pending.get(&min) {
          if done.gt(&0) {
            break;
          }
        }
        indices.pop();
        pending.remove(&min);
        until = min;
      }

      if until != done_until {
        // This task is the sole writer of `done_until` on this path, so
        // the CAS is expected to always succeed.
        assert_eq!(
          self
            .done_until
            .compare_exchange(done_until, until, Ordering::SeqCst, Ordering::Acquire),
          Ok(done_until)
        );
      }

      // Release every waiter at or below the new `done_until`: removing a
      // map entry drops its senders, which resolves the paired receivers.
      // Pick the cheaper traversal — the freed index range or the map.
      if until - done_until <= waiters.len() as u64 {
        (done_until + 1..=until).for_each(|idx| {
          let _ = waiters.remove(&idx);
        });
      } else {
        waiters.retain(|idx, _| *idx > until);
      }
    };

    let closer = closer.listen();
    loop {
      futures_util::select_biased! {
        // Shutdown requested by the closer.
        _ = closer.wait().fuse() => return,
        mark = self.mark_rx.recv().fuse() => match mark {
          Ok(mark) => {
            if let Some(wait_tx) = mark.waiter {
              // Wait request: only single indices carry waiters.
              if let MarkIndex::Single(index) = mark.index {
                let done_until = self.done_until.load(Ordering::SeqCst);
                if done_until >= index {
                  // Already done: drop the sender to wake the waiter now.
                  let _ = wait_tx;
                } else {
                  waiters.borrow_mut().entry(index).or_default().push(wait_tx);
                }
              }
            } else {
              // Begin/done notification for one index or a batch.
              match mark.index {
                MarkIndex::Single(idx) => process_one(idx, mark.done),
                MarkIndex::Multiple(indices) => indices.into_iter().for_each(|idx| process_one(idx, mark.done)),
              }
            }
          },
          Err(_) => {
            // All senders dropped: the owning watermark is gone.
            #[cfg(feature = "tracing")]
            tracing::error!(target: "watermark", err = "watermark has been dropped.");
            return;
          }
        },
      }
    }
  }
}
161
/// An async watermark: tracks the highest index up to which all begun work
/// has been marked done, and lets callers await a given index.
#[derive(Debug)]
pub struct AsyncWaterMark<S: AsyncSpawner> {
  // Shared with the background processing task spawned by `init`.
  inner: Arc<Inner<S>>,
  // Set by `init`; guards every public method via `check`.
  initialized: bool,
}
177
178impl<S: AsyncSpawner> AsyncWaterMark<S> {
179 #[inline]
183 pub fn new(name: Cow<'static, str>) -> Self {
184 let (mark_tx, mark_rx) = unbounded();
185 Self {
186 inner: Arc::new(Inner {
187 done_until: CachePadded::new(AtomicU64::new(0)),
188 last_index: CachePadded::new(AtomicU64::new(0)),
189 name,
190 mark_tx,
191 mark_rx,
192 _spawner: core::marker::PhantomData,
193 }),
194 initialized: false,
195 }
196 }
197
198 #[inline(always)]
200 pub fn name(&self) -> &str {
201 self.inner.name.as_ref()
202 }
203
204 #[inline]
206 pub fn init(&mut self, closer: AsyncCloser<S>) {
207 if self.initialized {
208 return;
209 }
210
211 let inner = self.inner.clone();
212 self.initialized = true;
213
214 S::spawn_detach(async move {
215 inner.process(closer).await;
216 });
217 }
218
219 #[inline]
221 pub fn begin(&self, index: u64) -> Result<()> {
222 self.check()?;
223 self.inner.last_index.store(index, Ordering::SeqCst);
224 self
225 .inner
226 .mark_tx
227 .try_send(Mark {
228 index: MarkIndex::Single(index),
229 waiter: None,
230 done: false,
231 })
232 .unwrap(); Ok(())
234 }
235
236 #[inline]
238 pub fn begin_many(&self, indices: MediumVec<u64>) -> Result<()> {
239 if indices.is_empty() {
240 return Ok(());
241 }
242
243 self.check()?;
244
245 let last_index = *indices.last().unwrap();
246 self.inner.last_index.store(last_index, Ordering::SeqCst);
247 self
248 .inner
249 .mark_tx
250 .try_send(Mark {
251 index: MarkIndex::Multiple(indices),
252 waiter: None,
253 done: false,
254 })
255 .unwrap(); Ok(())
257 }
258
259 #[inline]
261 pub fn done(&self, index: u64) -> Result<()> {
262 self.check()?;
263 self
264 .inner
265 .mark_tx
266 .try_send(Mark {
267 index: MarkIndex::Single(index),
268 waiter: None,
269 done: true,
270 })
271 .unwrap(); Ok(())
273 }
274
275 #[inline]
277 pub fn done_many(&self, indices: MediumVec<u64>) -> Result<()> {
278 self.check()?;
279 self
280 .inner
281 .mark_tx
282 .try_send(Mark {
283 index: MarkIndex::Multiple(indices),
284 waiter: None,
285 done: true,
286 })
287 .unwrap(); Ok(())
289 }
290
291 #[inline]
294 pub fn done_until(&self) -> Result<u64> {
295 self
296 .check()
297 .map(|_| self.inner.done_until.load(Ordering::SeqCst))
298 }
299
300 #[inline]
303 pub fn set_done_util(&self, val: u64) -> Result<()> {
304 self
305 .check()
306 .map(|_| self.inner.done_until.store(val, Ordering::SeqCst))
307 }
308
309 #[inline]
311 pub fn last_index(&self) -> Result<u64> {
312 self
313 .check()
314 .map(|_| self.inner.last_index.load(Ordering::SeqCst))
315 }
316
317 #[inline]
319 pub async fn wait_for_mark(&self, index: u64) -> Result<()> {
320 if self.inner.done_until.load(Ordering::SeqCst) >= index {
321 return Ok(());
322 }
323
324 let (wait_tx, wait_rx) = oneshot::channel();
325 self
326 .inner
327 .mark_tx
328 .try_send(Mark {
329 index: MarkIndex::Single(index),
330 waiter: Some(wait_tx),
331 done: false,
332 })
333 .unwrap(); let _ = wait_rx.await;
336 Ok(())
337 }
338
339 #[inline]
340 fn check(&self) -> Result<()> {
341 if !self.initialized {
342 Err(WaterMarkError::Uninitialized)
343 } else {
344 Ok(())
345 }
346 }
347}
348
#[cfg(test)]
mod tests {
  use super::*;
  use core::future::Future;

  /// Builds a watermark wired to a one-task closer, hands it to `f`, then
  /// signals the closer and waits for the processing task to exit.
  async fn init_and_close<S, Fut, F>(f: F)
  where
    Fut: Future,
    F: FnOnce(AsyncWaterMark<S>) -> Fut,
    S: AsyncSpawner,
  {
    let closer = AsyncCloser::new(1);

    let mut wm = AsyncWaterMark::new("watermark".into());
    wm.init(closer.clone());
    assert_eq!("watermark", wm.name());

    f(wm).await;

    closer.signal_and_wait().await;
  }

  #[tokio::test]
  async fn test_basic() {
    // Initialize and immediately tear down; no marks are issued.
    init_and_close::<crate::TokioSpawner, _, _>(|_| async {}).await;
  }

  #[tokio::test]
  async fn test_begin_done() {
    init_and_close::<crate::TokioSpawner, _, _>(|wm| async move {
      wm.begin(1).unwrap();
      wm.begin_many((2..=3).collect()).unwrap();

      wm.done(1).unwrap();
      wm.done_many((2..=3).collect()).unwrap();
    })
    .await;
  }

  #[tokio::test]
  async fn test_wait_for_mark() {
    init_and_close::<crate::TokioSpawner, _, _>(|wm| async move {
      wm.begin_many((1..=3).collect()).unwrap();
      wm.done_many((2..=3).collect()).unwrap();

      // Index 1 is still outstanding, so nothing is reported done yet.
      assert_eq!(wm.done_until().unwrap(), 0);

      wm.done(1).unwrap();
      wm.wait_for_mark(1).await.unwrap();
      wm.wait_for_mark(3).await.unwrap();
      assert_eq!(wm.done_until().unwrap(), 3);
    })
    .await;
  }

  #[tokio::test]
  async fn test_set_done_until() {
    init_and_close::<crate::TokioSpawner, _, _>(|wm| async move {
      wm.set_done_util(1).unwrap();
      assert_eq!(wm.done_until().unwrap(), 1);
    })
    .await;
  }

  #[tokio::test]
  async fn test_last_index() {
    init_and_close::<crate::TokioSpawner, _, _>(|wm| async move {
      wm.begin_many((1..=3).collect()).unwrap();
      wm.done_many((2..=3).collect()).unwrap();

      assert_eq!(wm.last_index().unwrap(), 3);
    })
    .await;
  }

  #[tokio::test]
  async fn test_multiple_singles() {
    // Redundant signals on a default closer must be harmless.
    let first = AsyncCloser::<crate::TokioSpawner>::default();
    first.signal();
    first.signal();
    first.signal_and_wait().await;

    // Likewise once its single task has already reported done.
    let second = AsyncCloser::<crate::TokioSpawner>::new(1);
    second.done();
    second.signal_and_wait().await;
    second.signal_and_wait().await;
    second.signal();
  }

  #[tokio::test]
  async fn test_closer() {
    let closer = AsyncCloser::<crate::TokioSpawner>::new(1);
    let worker = closer.clone();
    tokio::spawn(async move {
      worker.listen().wait().await;
      worker.done();
    });
    closer.signal_and_wait().await;
  }

  #[tokio::test]
  async fn test_closer_() {
    use async_channel::unbounded;
    use core::time::Duration;

    let (tx, rx) = unbounded();

    let root = AsyncCloser::<crate::TokioSpawner>::default();

    // Ten listeners, each acking over the channel once the signal lands.
    for _ in 0..10 {
      let handle = root.clone();
      let tx = tx.clone();
      tokio::spawn(async move {
        handle.listen().wait().await;
        tx.send(()).await.unwrap();
      });
    }
    root.signal();
    for _ in 0..10 {
      tokio::time::timeout(Duration::from_millis(1000), rx.recv())
        .await
        .unwrap()
        .unwrap();
    }
  }
}