// s2n_quic_core/task/waker/contract.rs
use alloc::{sync::Arc, task::Wake};
use core::{
sync::atomic::{AtomicBool, Ordering},
task::{Context, Poll, Waker},
};
/// Checks that a poll function upholds the waker contract: if it returns
/// [`Poll::Pending`], it must either keep a clone of the [`Context`]'s [`Waker`] alive
/// or have called `wake`/`wake_by_ref` on it.
///
/// Create one with [`Contract::new`], poll through [`Contract::context`], then hand
/// the result to [`Contract::check_outcome`].
#[must_use = "a `Contract` checks nothing unless `check_outcome` is called"]
pub struct Contract {
    /// Shared bookkeeping; `waker` holds a second strong reference to the same state.
    state: Arc<State>,
    /// Instrumented waker exposed to the callee through [`Contract::context`].
    waker: Waker,
}
/// Bookkeeping shared between a [`Contract`] and the instrumented waker it hands out.
struct State {
    /// The waker taken from the caller's [`Context`]; every wake-up is forwarded to it.
    inner: Waker,
    /// Set to `true` once `wake` or `wake_by_ref` is called on the instrumented waker.
    wake_called: AtomicBool,
}
impl Wake for State {
#[inline]
fn wake(self: Arc<Self>) {
Wake::wake_by_ref(&self)
}
#[inline]
fn wake_by_ref(self: &Arc<Self>) {
self.wake_called.store(true, Ordering::Release);
self.inner.wake_by_ref();
}
}
impl Contract {
    /// Captures the waker from `cx` and wraps it in an instrumented [`Waker`] that
    /// records whether it has been woken.
    #[inline]
    pub fn new(cx: &mut Context) -> Self {
        let state = Arc::new(State {
            inner: cx.waker().clone(),
            wake_called: AtomicBool::new(false),
        });
        // The instrumented waker shares `state`, giving it a second strong reference.
        let waker = Waker::from(Arc::clone(&state));
        Self { state, waker }
    }

    /// Returns a [`Context`] whose waker is the instrumented one.
    #[inline]
    pub fn context(&self) -> Context {
        Context::from_waker(&self.waker)
    }

    /// Consumes the contract and checks `outcome` against it.
    ///
    /// # Panics
    ///
    /// Panics if `outcome` is [`Poll::Pending`], no extra clone of the instrumented
    /// waker is alive, and the waker was never woken.
    #[inline]
    #[track_caller]
    pub fn check_outcome<T>(self, outcome: &Poll<T>) {
        // A `Ready` result needs no future wake-up, so only `Pending` is checked.
        if outcome.is_pending() {
            // `self.state` and `self.waker` account for two strong references; any
            // more means the callee is still holding a clone of the waker.
            const OWN_REFS: usize = 2;

            let strong_count = Arc::strong_count(&self.state);
            let is_cloned = strong_count > OWN_REFS;
            let wake_called = self.state.wake_called.load(Ordering::Acquire);

            assert!(
                is_cloned || wake_called,
                "strong_count = {strong_count}; is_cloned = {is_cloned}; wake_called = {wake_called}"
            );
        }
    }
}
/// Polls `f` with an instrumented [`Context`] derived from `cx` and returns its result
/// unchanged.
///
/// # Panics
///
/// Panics if `f` returns [`Poll::Pending`] without waking the provided waker or
/// keeping a clone of it alive.
#[inline(always)]
#[track_caller]
pub fn assert_contract<F: FnOnce(&mut Context) -> Poll<R>, R>(cx: &mut Context, f: F) -> Poll<R> {
    let contract = Contract::new(cx);
    // The instrumented context is a temporary that only lives for the call to `f`.
    let outcome = f(&mut contract.context());
    contract.check_outcome(&outcome);
    outcome
}
/// Like [`assert_contract`], but only enforces the contract when `debug_assertions`
/// are enabled; otherwise `f` is called directly with `cx`.
#[inline(always)]
#[track_caller]
pub fn debug_assert_contract<F: FnOnce(&mut Context) -> Poll<R>, R>(
    cx: &mut Context,
    f: F,
) -> Poll<R> {
    if cfg!(debug_assertions) {
        assert_contract(cx, f)
    } else {
        f(cx)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::task::waker;

    /// Every way of satisfying the contract must pass without panicking.
    #[test]
    fn correct_test() {
        let waker = waker::noop();
        let mut cx = Context::from_waker(&waker);

        // `Ready` is always acceptable, whatever happened to the waker.
        let _ = assert_contract(&mut cx, |_cx| Poll::Ready(()));

        // Waking by reference satisfies the contract.
        let _ = assert_contract(&mut cx, |cx| {
            cx.waker().wake_by_ref();
            Poll::<()>::Pending
        });

        // Waking through a consumed clone satisfies the contract.
        let _ = assert_contract(&mut cx, |cx| {
            let waker = cx.waker().clone();
            waker.wake();
            Poll::<()>::Pending
        });

        // Holding on to a clone satisfies the contract. `stored` is declared outside
        // the closure so the clone is still alive when the outcome is checked.
        let mut stored = None;
        let _ = assert_contract(&mut cx, |cx| {
            stored = Some(cx.waker().clone());
            Poll::<()>::Pending
        });
        assert!(stored.is_some(), "the closure should have stored a waker clone");
    }

    /// Returning `Pending` without touching the waker violates the contract.
    #[test]
    #[should_panic(expected = "wake_called = false")]
    fn incorrect_test() {
        let waker = waker::noop();
        let mut cx = Context::from_waker(&waker);
        let _ = assert_contract(&mut cx, |_cx| Poll::<()>::Pending);
    }

    /// Cloning the waker but dropping the clone before returning `Pending` leaves no
    /// extra reference and no recorded wake-up, so it must also be rejected.
    #[test]
    #[should_panic(expected = "is_cloned = false")]
    fn dropped_clone_test() {
        let waker = waker::noop();
        let mut cx = Context::from_waker(&waker);
        let _ = assert_contract(&mut cx, |cx| {
            let _clone = cx.waker().clone();
            Poll::<()>::Pending
        });
    }
}