s2n_quic_core/task/waker/
contract.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use alloc::{sync::Arc, task::Wake};
use core::{
    sync::atomic::{AtomicBool, Ordering},
    task::{Context, Poll, Waker},
};

/// Checks that if a function returns [`Poll::Pending`], then the function called [`Waker::clone`],
/// [`Waker::wake`], or [`Waker::wake_by_ref`] on the [`Context`]'s [`Waker`].
pub struct Contract {
    state: Arc<State>,
    waker: Waker,
}

/// State shared between a [`Contract`] and every clone of the waker it hands out.
struct State {
    /// The caller's original waker; every wake on the instrumented waker is forwarded here.
    inner: Waker,
    /// Flipped to `true` the first time the instrumented waker is woken (by value or by ref).
    wake_called: AtomicBool,
}

impl Wake for State {
    /// Consuming wake: identical to waking by reference; the `Arc` is released afterwards.
    #[inline]
    fn wake(self: Arc<Self>) {
        Self::wake_by_ref(&self)
    }

    /// Records that a wake occurred, then forwards it to the wrapped waker.
    #[inline]
    fn wake_by_ref(self: &Arc<Self>) {
        let Self { inner, wake_called } = &**self;
        // Record the wake before forwarding it.
        wake_called.store(true, Ordering::Release);
        inner.wake_by_ref();
    }
}

impl Contract {
    /// Creates a checker around the [`Waker`] currently installed in `cx`.
    ///
    /// The caller's waker is cloned into shared state, and an instrumented [`Waker`] backed by
    /// that state is kept so [`Contract::context`] can hand it to the function under test.
    #[inline]
    pub fn new(cx: &mut Context) -> Self {
        let shared = Arc::new(State {
            inner: cx.waker().clone(),
            wake_called: AtomicBool::new(false),
        });
        let waker = Waker::from(Arc::clone(&shared));

        Self {
            state: shared,
            waker,
        }
    }

    /// Builds a [`Context`] whose waker is the instrumented one owned by this checker.
    #[inline]
    pub fn context(&self) -> Context {
        Context::from_waker(&self.waker)
    }

    /// Consumes the checker and verifies the waker contract against `outcome`.
    ///
    /// A [`Poll::Ready`] outcome always passes. A [`Poll::Pending`] outcome passes only if a clone
    /// of the instrumented waker is still alive, or if the waker was woken.
    ///
    /// # Panics
    ///
    /// Panics, reporting the caller's location, when `outcome` is pending and neither condition
    /// holds.
    #[inline]
    #[track_caller]
    pub fn check_outcome<T>(self, outcome: &Poll<T>) {
        // `self.state` and `self.waker` each hold one strong reference; any count beyond this
        // means the function under test still holds a clone of the waker.
        const OWNED_REFS: usize = 2;

        if outcome.is_pending() {
            let strong_count = Arc::strong_count(&self.state);
            let is_cloned = strong_count > OWNED_REFS;
            let wake_called = self.state.wake_called.load(Ordering::Acquire);

            assert!(
                is_cloned || wake_called,
                "strong_count = {strong_count}; is_cloned = {is_cloned}; wake_called = {wake_called}"
            );
        }
    }
}

/// Runs `f` with an instrumented [`Context`] and asserts the waker contract on its result.
///
/// The contract: whenever `f` returns [`Poll::Pending`], it must either still hold a clone of the
/// [`Context`]'s [`Waker`] (obtained via [`Waker::clone`]) or have woken it via [`Waker::wake`] /
/// [`Waker::wake_by_ref`]. The result of `f` is returned unchanged.
///
/// # Panics
///
/// Panics at the caller's location if `f` returns [`Poll::Pending`] without meeting the contract.
#[inline(always)]
#[track_caller]
pub fn assert_contract<F: FnOnce(&mut Context) -> Poll<R>, R>(cx: &mut Context, f: F) -> Poll<R> {
    let checker = Contract::new(cx);
    let outcome = {
        // The instrumented context only needs to live for the duration of `f`.
        let mut instrumented = checker.context();
        f(&mut instrumented)
    };
    checker.check_outcome(&outcome);
    outcome
}

/// Debug-only counterpart of [`assert_contract`].
///
/// When `debug_assertions` are enabled, this checks that a [`Poll::Pending`] result from `f` came
/// with a retained [`Waker::clone`] or a call to [`Waker::wake`] / [`Waker::wake_by_ref`] on the
/// [`Context`]'s [`Waker`]. Without `debug_assertions`, `f` is simply called with `cx` and nothing
/// is checked.
#[inline(always)]
#[track_caller]
pub fn debug_assert_contract<F: FnOnce(&mut Context) -> Poll<R>, R>(
    cx: &mut Context,
    f: F,
) -> Poll<R> {
    if cfg!(debug_assertions) {
        assert_contract(cx, f)
    } else {
        f(cx)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::task::waker;
    use core::sync::atomic::AtomicUsize;

    /// Inner waker that counts how often it is woken, so tests can observe forwarding.
    struct CountingWaker(AtomicUsize);

    impl Wake for CountingWaker {
        fn wake(self: Arc<Self>) {
            self.0.fetch_add(1, Ordering::Relaxed);
        }
    }

    #[test]
    fn correct_test() {
        let waker = waker::noop();
        let mut cx = Context::from_waker(&waker);

        // the contract isn't violated when returning Ready
        let _ = assert_contract(&mut cx, |_cx| Poll::Ready(()));

        // the contract isn't violated if the waker is immediately woken
        let _ = assert_contract(&mut cx, |cx| {
            cx.waker().wake_by_ref();
            Poll::<()>::Pending
        });

        // the contract isn't violated if the waker is cloned then immediately woken
        let _ = assert_contract(&mut cx, |cx| {
            let waker = cx.waker().clone();
            waker.wake();
            Poll::<()>::Pending
        });

        // the contract isn't violated if the waker is cloned and stored for later
        let mut stored = None;
        let _ = assert_contract(&mut cx, |cx| {
            stored = Some(cx.waker().clone());
            Poll::<()>::Pending
        });
        // `stored` still holds the clone that satisfied the contract above
        assert!(stored.is_some());
    }

    #[test]
    fn wake_is_forwarded_test() {
        let counter = Arc::new(CountingWaker(AtomicUsize::new(0)));
        let waker = Waker::from(counter.clone());
        let mut cx = Context::from_waker(&waker);

        // waking the instrumented waker by reference reaches the caller's waker
        let _ = assert_contract(&mut cx, |cx| {
            cx.waker().wake_by_ref();
            Poll::<()>::Pending
        });
        assert_eq!(counter.0.load(Ordering::Relaxed), 1);

        // waking a clone by value reaches the caller's waker as well
        let _ = assert_contract(&mut cx, |cx| {
            cx.waker().clone().wake();
            Poll::<()>::Pending
        });
        assert_eq!(counter.0.load(Ordering::Relaxed), 2);
    }

    #[test]
    #[should_panic]
    fn incorrect_test() {
        let waker = waker::noop();
        let mut cx = Context::from_waker(&waker);

        // the contract is violated if we return Pending without doing anything
        let _ = assert_contract(&mut cx, |_cx| Poll::<()>::Pending);
    }

    #[test]
    #[should_panic]
    fn cloned_then_dropped_test() {
        let waker = waker::noop();
        let mut cx = Context::from_waker(&waker);

        // the contract is violated if the waker is cloned but the clone is dropped without being
        // woken before returning Pending
        let _ = assert_contract(&mut cx, |cx| {
            let waker = cx.waker().clone();
            drop(waker);
            Poll::<()>::Pending
        });
    }
}