1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
use super::{DayLimiter, Queue};
use futures_channel::{
    mpsc::{unbounded, UnboundedReceiver, UnboundedSender},
    oneshot::{self, Sender},
};
use futures_util::{sink::SinkExt, stream::StreamExt};
use std::{fmt::Debug, future::Future, pin::Pin, time::Duration};
use tokio::time::delay_for;

/// Queue built for single-process clusters that require identifying via
/// [Sharding for Very Large Bots].
///
/// Usage with other processes will cause inconsistencies between each process
/// cluster's ratelimit buckets. If you use multiple processes for clusters,
/// then refer to the [module-level] documentation.
///
/// [Sharding for Very Large Bots]: https://discord.com/developers/docs/topics/gateway#sharding-for-very-large-bots
/// [module-level]: ./index.html
#[derive(Debug)]
pub struct LargeBotQueue {
    buckets: Vec<UnboundedSender<Sender<()>>>,
    limiter: DayLimiter,
}

impl LargeBotQueue {
    /// Create a new large bot queue.
    ///
    /// You must provide the number of buckets Discord requires your bot to
    /// connect with. A value of `0` is treated as `1`. At least one bucket is
    /// needed to route identify requests, and an empty bucket list would make
    /// every [`Queue::request`] call panic on a remainder by zero.
    ///
    /// One background task is spawned per bucket with [`tokio::spawn`], so
    /// this must be awaited from within a Tokio runtime.
    ///
    /// # Panics
    ///
    /// Panics if the initial session limits cannot be fetched through `http`,
    /// for example when no network connection is available.
    pub async fn new(buckets: usize, http: &twilight_http::Client) -> Self {
        // `request` computes `shard_id % buckets.len()` and indexes the bucket
        // list with the result, so the list must never be empty.
        let bucket_count = buckets.max(1);

        let mut queues = Vec::with_capacity(bucket_count);
        for _ in 0..bucket_count {
            let (tx, rx) = unbounded();

            tokio::spawn(waiter(rx));

            queues.push(tx);
        }

        let limiter = DayLimiter::new(http).await.expect(
            "Getting the first session limits failed, \
             Is network connection available?",
        );

        // `level_enabled!` is only a cheap pre-check. It lets us skip locking
        // the limiter when INFO events are disabled at compile time (tracing's
        // `max_level_*` / `release_max_level_*` features). It may still return
        // true when a subscriber filters INFO out at runtime. In that case the
        // `info!` call below is simply a no-op.
        if tracing::level_enabled!(tracing::Level::INFO) {
            let lock = limiter.0.lock().await;
            tracing::info!(
                "{}/{} identifies used before next reset in {:.2?}",
                lock.current,
                lock.total,
                lock.next_reset
            );
        }

        Self {
            buckets: queues,
            limiter,
        }
    }
}

/// Serves one bucket by handing out identify allowances one at a time.
///
/// Queued senders are completed in the order they were received. After each
/// successful grant, the bucket sleeps for six seconds before serving the next
/// request. Discord documents the identify window for large bots as 5 seconds,
/// so the extra second is presumably headroom.
///
/// The task ends once the channel is closed and drained, i.e. when every
/// sender for this bucket has been dropped.
async fn waiter(mut rx: UnboundedReceiver<Sender<()>>) {
    const DUR: Duration = Duration::from_secs(6);
    while let Some(req) = rx.next().await {
        if let Err(err) = req.send(()) {
            // The requesting future was dropped before its turn came, so no
            // identify will happen. Move on immediately instead of holding
            // up every later request in this bucket for a full cooldown.
            tracing::warn!("skipping, send failed with: {:?}", err);
            continue;
        }
        delay_for(DUR).await;
    }
}

impl Queue for LargeBotQueue {
    /// Ask for permission to identify with the gateway.
    ///
    /// The daily session limit is awaited first. The request is then queued
    /// on bucket `shard_id[0] % bucket_count`, behind every request already
    /// waiting there. The returned future completes once that bucket's waiter
    /// grants the allowance. It also completes immediately if the bucket can
    /// no longer accept requests.
    fn request(&'_ self, shard_id: [u64; 2]) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
        let [shard, _] = shard_id;

        // The remainder is strictly smaller than `self.buckets.len()`, which is
        // itself a `usize`, so narrowing it back to `usize` cannot truncate.
        #[allow(clippy::cast_possible_truncation)]
        let index = (shard % (self.buckets.len() as u64)) as usize;

        let (allowance_tx, allowance_rx) = oneshot::channel();

        Box::pin(async move {
            self.limiter.get().await;

            let queued = self.buckets[index].clone().send(allowance_tx).await;
            match queued {
                Ok(()) => {
                    tracing::info!("waiting for allowance on shard {}", shard);
                    // An error only means the waiter dropped the sender, so
                    // there is nothing further to wait for either way.
                    let _ = allowance_rx.await;
                }
                Err(err) => {
                    tracing::warn!("skipping, send failed with: {:?}", err);
                }
            }
        })
    }
}

#[cfg(test)]
mod tests {
    use super::{LargeBotQueue, Queue};
    use static_assertions::assert_impl_all;
    use std::fmt::Debug;

    // Compile-time checks: the queue can be shared across threads, printed
    // for diagnostics, and used through the `Queue` trait.
    assert_impl_all!(LargeBotQueue: Send, Sync);
    assert_impl_all!(LargeBotQueue: Debug, Queue);
}