1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use crate::{
    message::Message,
    socket::{ring::Producer, task::events},
};
use core::{
    future::Future,
    pin::Pin,
    task::{Context, Poll},
};
use s2n_quic_core::task::cooldown::Cooldown;

pub use events::RxEvents as Events;

/// A message source driven by [`Receiver`].
///
/// `Receiver` hands each call a slice of free ring slots to fill. After each
/// successful call it reads back the number of filled slots with
/// `events.take_count()`. It stops calling `recv` for the current poll once
/// `events.take_blocked()` returns `true`.
pub trait Socket<T: Message> {
    /// Error reported by a failed `recv`; `Receiver` returns it as `Some(err)`
    /// from its future.
    type Error;

    /// Receives messages into `entries`, recording progress in `events`.
    ///
    /// `entries` is the ring producer's `data()` slice at the time of the call.
    /// `cx` is the task context of the polling `Receiver`. Presumably
    /// implementations use it to register interest when the socket would
    /// block — NOTE(review): confirm against implementations.
    fn recv(
        &mut self,
        cx: &mut Context,
        entries: &mut [T],
        events: &mut Events,
    ) -> Result<(), Self::Error>;
}

/// Future that reads messages from a [`Socket`] directly into free slots of a
/// ring [`Producer`], then releases the filled slots through that producer.
pub struct Receiver<T: Message, S: Socket<T>> {
    /// Producer half of the ring; its free slots (`data()`) are the receive
    /// buffers handed to `rx`.
    ring: Producer<T>,
    /// Implementation of a socket that fills free slots in the ring buffer
    rx: S,
    /// Controls how `poll_ring` waits for free slots. In the loop state it
    /// uses `acquire` (no `Context`); otherwise it uses `poll_acquire`.
    ring_cooldown: Cooldown,
    /// Reset (`on_ready`) whenever a receive yields messages. Notified via
    /// `on_pending_task` once the socket reports it is blocked.
    io_cooldown: Cooldown,
}

impl<T, S> Receiver<T, S>
where
    T: Message + Unpin,
    S: Socket<T> + Unpin,
{
    /// Builds a receiver that reads from `rx` into slots acquired from `ring`.
    ///
    /// `cooldown` seeds two independent copies: one used while waiting for
    /// free ring slots, and one used for socket I/O.
    #[inline]
    pub fn new(ring: Producer<T>, rx: S, cooldown: Cooldown) -> Self {
        let ring_cooldown = cooldown.clone();
        let io_cooldown = cooldown;

        Self {
            ring,
            rx,
            ring_cooldown,
            io_cooldown,
        }
    }

    /// Waits until at least one free slot can be acquired from the ring.
    ///
    /// Returns:
    /// - `Ready(Ok(()))` once slots are available; this also resets the ring cooldown.
    /// - `Ready(Err(()))` if `poll_acquire` is pending and the ring reports it is closed.
    /// - `Pending` otherwise.
    ///
    /// While the ring cooldown is in its loop state, slots are taken with
    /// `acquire` (which takes no `Context`). The attempt is retried as long
    /// as `on_pending_task` reports `is_sleep()`.
    #[inline]
    fn poll_ring(&mut self, watermark: u32, cx: &mut Context) -> Poll<Result<(), ()>> {
        loop {
            let spinning = self.ring_cooldown.state().is_loop();

            let acquired = if !spinning {
                match self.ring.poll_acquire(watermark, cx) {
                    Poll::Ready(slots) => slots,
                    Poll::Pending => {
                        if !self.ring.is_open() {
                            // nothing more can be acquired from a closed ring
                            return Poll::Ready(Err(()));
                        }
                        0
                    }
                }
            } else {
                self.ring.acquire(watermark)
            };

            // free slots became available: reset the cooldown and let the caller proceed
            if acquired > 0 {
                self.ring_cooldown.on_ready();
                return Poll::Ready(Ok(()));
            }

            // Only the loop state consults the cooldown (short-circuit keeps
            // `on_pending_task` from running otherwise).
            let retry = spinning && self.ring_cooldown.on_pending_task(cx).is_sleep();
            if !retry {
                break Poll::Pending;
            }
        }
    }
}

impl<T, S> Future for Receiver<T, S>
where
    T: Message + Unpin,
    S: Socket<T> + Unpin,
{
    /// `None` when the ring is reported closed by `poll_ring`.
    /// `Some(err)` when the socket's `recv` fails.
    type Output = Option<S::Error>;

    /// Repeatedly acquires free ring slots and receives into them until the
    /// socket reports it is blocked, then returns `Pending`.
    ///
    /// Filled slots are released with `release_no_wake`. A single
    /// `ring.wake()` is issued before this poll returns, on every exit path
    /// except the ring-closed one.
    #[inline]
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let this = self.get_mut();

        let mut events = Events::default();

        // Set once at least one batch has been released without a wakeup; the
        // wakeup is coalesced and issued once before returning.
        let mut pending_wake = false;

        while !events.take_blocked() {
            match this.poll_ring(u32::MAX, cx) {
                Poll::Ready(Ok(_)) => {}
                // the ring is closed; there is no one left to notify
                Poll::Ready(Err(_)) => return None.into(),
                Poll::Pending => {
                    if pending_wake {
                        this.ring.wake();
                    }
                    return Poll::Pending;
                }
            }

            let entries = this.ring.data();

            // perform the recv syscall
            match this.rx.recv(cx, entries, &mut events) {
                Ok(_) => {
                    // Number of messages received in this call. Assumed to fit
                    // in u32 since it is bounded by the ring's slot count —
                    // NOTE(review): confirm `take_count` is bounded by `entries.len()`.
                    let count = events.take_count() as u32;

                    if count > 0 {
                        this.ring.release_no_wake(count);
                        this.io_cooldown.on_ready();
                        pending_wake = true;
                    }
                }
                Err(err) => {
                    // Batches from earlier iterations were released without a
                    // wakeup. Flush it here so the consumer is not left unaware
                    // of entries that are already available.
                    if pending_wake {
                        this.ring.wake();
                    }
                    return Some(err).into();
                }
            }
        }

        this.io_cooldown.on_pending_task(cx);

        if pending_wake {
            this.ring.wake();
        }

        Poll::Pending
    }
}