1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
use super::{super::json, session::SessionSendError};
use async_tungstenite::tungstenite::Message as TungsteniteMessage;
use futures_channel::mpsc::UnboundedSender;
use serde::{Deserialize, Serialize};
use std::{
    collections::VecDeque,
    convert::TryInto,
    sync::{
        atomic::{AtomicU32, AtomicU64, Ordering},
        Arc, Mutex,
    },
    time::{Duration, Instant},
};
use twilight_model::gateway::payload::Heartbeat;

/// Information about the latency of a [`Shard`]'s websocket connection.
///
/// This is obtained through [`Shard::info`].
///
/// [`Shard`]: crate::shard::Shard
/// [`Shard::info`]: crate::shard::Shard::info
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Latency {
    /// Mean acknowledgement round-trip time; `None` while no acknowledgement
    /// has been counted.
    average: Option<Duration>,
    /// Number of heartbeat acknowledgements the average is computed over.
    heartbeats: u32,
    /// Most recent round-trip times, oldest first.
    recent: VecDeque<Duration>,
    /// When the last acknowledgement was received. Skipped by serde, so it is
    /// `None` after deserializing.
    #[serde(skip)]
    received: Option<Instant>,
    /// When the last heartbeat was sent. Skipped by serde, so it is `None`
    /// after deserializing.
    #[serde(skip)]
    sent: Option<Instant>,
}

impl Latency {
    /// The average time it took to receive an acknowledgement for a
    /// heartbeat, over every acknowledged heartbeat of the session.
    ///
    /// For example, a reasonable value for this may be between 10 to 100
    /// milliseconds depending on the network connection and physical location.
    ///
    /// # Note
    ///
    /// If this is `None`, the shard has not received a heartbeat
    /// acknowledgement yet.
    pub fn average(&self) -> Option<Duration> {
        self.average
    }

    /// The total number of heartbeats that have been acknowledged during this
    /// session.
    ///
    /// This is the number of acknowledgements that [`average`] is computed
    /// over, not the number of heartbeats that were sent.
    ///
    /// [`average`]: Self::average
    pub fn heartbeats(&self) -> u32 {
        self.heartbeats
    }

    /// The most recent latency times, at most 5 of them.
    ///
    /// Index 0 is the oldest and the last index is the most recent. Fewer
    /// than 5 entries are present until 5 round-trips have been measured.
    pub fn recent(&self) -> &VecDeque<Duration> {
        &self.recent
    }

    /// When the last heartbeat acknowledgement was received.
    ///
    /// `None` if no acknowledgement has arrived since the last heartbeat was
    /// sent, or if this value was deserialized.
    pub fn received(&self) -> Option<Instant> {
        self.received
    }

    /// When the last heartbeat was sent.
    ///
    /// `None` if no heartbeat has been sent yet, or if this value was
    /// deserialized.
    pub fn sent(&self) -> Option<Instant> {
        self.sent
    }
}

/// Thread-safe record of heartbeat timings, from which a [`Latency`]
/// snapshot is built.
#[derive(Debug)]
pub struct Heartbeats {
    /// When the acknowledgement for the latest heartbeat was received;
    /// cleared every time a heartbeat is sent.
    received: Mutex<Option<Instant>>,
    /// Up to the 5 most recent round-trip times in milliseconds, oldest first.
    recent: Mutex<VecDeque<u64>>,
    /// When the latest heartbeat was sent.
    sent: Mutex<Option<Instant>>,
    /// Number of acknowledgements counted towards the average latency.
    total_iterations: AtomicU32,
    /// Sum of all measured round-trip times, in milliseconds.
    total_time: AtomicU64,
}

impl Heartbeats {
    /// Build a [`Latency`] snapshot from the statistics gathered so far.
    ///
    /// The average is `None` until at least one round-trip has been measured.
    /// Each statistic is read under its own lock/atomic load, so the snapshot
    /// is not taken atomically as a whole.
    pub fn latency(&self) -> Latency {
        let iterations = self.total_iterations();
        let recent = self
            .recent
            .lock()
            .expect("recent poisoned")
            .iter()
            .copied()
            .map(Duration::from_millis)
            .collect();

        Latency {
            average: self.total_time().checked_div(iterations),
            heartbeats: iterations,
            recent,
            received: self.received(),
            sent: self.sent(),
        }
    }

    /// Whether an acknowledgement has been recorded since the last heartbeat
    /// was sent.
    ///
    /// [`send`](Self::send) clears the received timestamp and
    /// [`receive`](Self::receive) sets it.
    pub fn last_acked(&self) -> bool {
        self.received().is_some()
    }

    /// Record that a heartbeat acknowledgement was received.
    ///
    /// The received timestamp is always updated. If a heartbeat has been sent,
    /// the time elapsed since then is recorded as a latency sample:
    /// - it is added to the running total;
    /// - it counts towards the number of heartbeats the average is computed
    ///   over;
    /// - it is pushed onto the list of the 5 most recent samples.
    ///
    /// Acknowledgements that cannot be measured (nothing sent yet, or a
    /// duration too large for `u64` milliseconds) are not counted, so they
    /// cannot skew the average.
    pub fn receive(&self) {
        self.set_received(Instant::now());

        let dur = match self.sent() {
            Some(sent) => sent.elapsed(),
            // Nothing has been sent yet, so there is no round-trip to measure.
            None => return,
        };

        let millis: u64 = match dur.as_millis().try_into() {
            Ok(millis) => millis,
            Err(_) => {
                tracing::error!("duration millis is more than u64: {:?}", dur);

                return;
            }
        };

        // Only measured round-trips are counted, keeping the total and the
        // count in agreement. They are updated separately, so a concurrent
        // `latency` call may briefly observe one update without the other.
        self.total_time.fetch_add(millis, Ordering::SeqCst);
        self.total_iterations.fetch_add(1, Ordering::SeqCst);

        let mut recent = self.recent.lock().expect("recent poisoned");

        // Keep only the 5 most recent samples, oldest at the front.
        while recent.len() >= 5 {
            recent.pop_front();
        }

        recent.push_back(millis);
    }

    /// Record that a heartbeat is being sent now.
    ///
    /// Clears the received timestamp so that [`last_acked`](Self::last_acked)
    /// reports `false` until the new heartbeat is acknowledged.
    pub fn send(&self) {
        self.received.lock().expect("received poisoned").take();
        self.sent
            .lock()
            .expect("sent poisoned")
            .replace(Instant::now());
    }

    /// When the last acknowledgement was received, if any since the last send.
    fn received(&self) -> Option<Instant> {
        *self.received.lock().expect("received poisoned")
    }

    /// Overwrite the received timestamp.
    fn set_received(&self, received: Instant) {
        self.received
            .lock()
            .expect("received poisoned")
            .replace(received);
    }

    /// When the last heartbeat was sent, if any.
    fn sent(&self) -> Option<Instant> {
        *self.sent.lock().expect("sent poisoned")
    }

    /// Number of acknowledgements counted towards the average.
    fn total_iterations(&self) -> u32 {
        self.total_iterations.load(Ordering::Relaxed)
    }

    /// Sum of all measured round-trip times.
    fn total_time(&self) -> Duration {
        Duration::from_millis(self.total_time.load(Ordering::Relaxed))
    }
}

impl Default for Heartbeats {
    fn default() -> Self {
        Self {
            received: Mutex::new(None),
            recent: Mutex::new(VecDeque::with_capacity(5)),
            sent: Mutex::new(None),
            total_iterations: AtomicU32::new(0),
            total_time: AtomicU64::new(0),
        }
    }
}

/// Periodically serializes heartbeats and queues them on a websocket message
/// channel.
///
/// Created with [`Heartbeater::new`] and driven by [`Heartbeater::run`].
pub struct Heartbeater {
    /// Heartbeat statistics, shared through an `Arc`; marked on every send.
    heartbeats: Arc<Heartbeats>,
    /// Time to wait between heartbeats, in milliseconds.
    interval: u64,
    /// Sequence number read when building each heartbeat payload.
    seq: Arc<AtomicU64>,
    /// Channel that serialized heartbeat messages are queued on.
    tx: UnboundedSender<TungsteniteMessage>,
}

impl Heartbeater {
    /// Create a heartbeater.
    ///
    /// - `heartbeats`: shared statistics, marked every time a heartbeat is
    ///   sent and consulted to detect missing acknowledgements.
    /// - `interval`: time to wait between heartbeats, in milliseconds.
    /// - `seq`: sequence number read when building each heartbeat.
    /// - `tx`: channel the serialized heartbeat messages are queued on.
    pub fn new(
        heartbeats: Arc<Heartbeats>,
        interval: u64,
        seq: Arc<AtomicU64>,
        tx: UnboundedSender<TungsteniteMessage>,
    ) -> Self {
        Self {
            heartbeats,
            interval,
            seq,
            tx,
        }
    }

    /// Run the heartbeat loop until it stops, logging a warning if it stopped
    /// because of an error.
    pub async fn run(self) {
        if let Err(why) = self.try_run().await {
            tracing::warn!("Error sending heartbeat: {:?}", why);
        }
    }

    // Send a heartbeat every `interval` milliseconds.
    //
    // Returns `Ok(())` once two consecutive heartbeats go unacknowledged.
    //
    // If there's an issue sending over the channel, then odds are it got
    // disconnected due to the session ending. This task should *also* have
    // been aborted by then, so the error is logged by `run` because it
    // indicates a programming error.
    async fn try_run(self) -> Result<(), SessionSendError> {
        let duration = Duration::from_millis(self.interval);

        // Whether a heartbeat has been sent yet. Before the first one there is
        // nothing that could have been acknowledged, so no miss is counted.
        let mut sent_any = false;
        // Whether the previous heartbeat went unacknowledged.
        let mut missed_last = false;

        loop {
            tokio::time::sleep(duration).await;

            // One missed acknowledgement is tolerated. Two in a row mean
            // something is off (connection closed?), so stop heartbeating.
            if sent_any {
                if self.heartbeats.last_acked() {
                    missed_last = false;
                } else if !missed_last {
                    missed_last = true;
                } else {
                    tracing::warn!(
                        "two consecutive heartbeats were not acknowledged, stopping heartbeater"
                    );

                    return Ok(());
                }
            }

            let seq = self.seq.load(Ordering::Acquire);
            let heartbeat = Heartbeat::new(seq);
            let bytes = json::to_vec(&heartbeat)
                .map_err(|source| SessionSendError::Serializing { source })?;

            // Mark the heartbeat as sent *before* queueing it. `send` clears
            // the received timestamp, so doing it afterwards could erase an
            // acknowledgement processed in between. The next iteration would
            // then wrongly count this heartbeat as missed.
            self.heartbeats.send();
            sent_any = true;

            tracing::debug!(seq, "sending heartbeat");
            self.tx
                .unbounded_send(TungsteniteMessage::Binary(bytes))
                .map_err(|source| SessionSendError::Sending { source })?;
            tracing::debug!(seq, "sent heartbeat");
        }
    }
}

#[cfg(test)]
mod tests {
    use super::{Heartbeats, Latency};
    use serde::{de::DeserializeOwned, Serialize};
    use static_assertions::assert_impl_all;
    use std::fmt::Debug;

    // `Latency` is handed out to users and derives serde support; make sure
    // all of that keeps compiling.
    assert_impl_all!(Latency: Clone, Debug, DeserializeOwned, Send, Serialize, Sync);
    // `Heartbeats` is shared with the heartbeater through an `Arc`, so it must
    // be usable from multiple threads.
    assert_impl_all!(Heartbeats: Debug, Default, Send, Sync);
}