1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
pub mod error;

mod bucket;
mod headers;

pub use self::{
    error::{RatelimitError, RatelimitResult},
    headers::RatelimitHeaders,
};

use crate::routing::Path;
use bucket::{Bucket, BucketQueueTask, TimeRemaining};
use futures_channel::oneshot::{self, Receiver, Sender};
use futures_util::lock::Mutex;
use std::{
    collections::hash_map::{Entry, HashMap},
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    },
    time::Duration,
};

/// Global lock. We use a pair to avoid actually locking the mutex every check.
/// This allows futures to only wait on the global lock when a global ratelimit
/// is in place by, in turn, waiting for a guard, and then each immediately
/// dropping it.
///
/// Field `0` is the mutex that futures wait on. Field `1` is the flag read by
/// [`GlobalLockPair::is_locked`] to decide whether waiting is needed at all.
#[derive(Debug, Default)]
struct GlobalLockPair(Mutex<()>, AtomicBool);

impl GlobalLockPair {
    /// Flag that a global ratelimit is in effect.
    ///
    /// This only sets the flag; it does not acquire the mutex in field `0`.
    /// NOTE(review): callers presumably hold that mutex's guard while the flag
    /// is set, so that waiters block on it — confirm in the `bucket` module.
    pub fn lock(&self) {
        self.1.store(true, Ordering::Release);
    }

    /// Clear the global ratelimit flag.
    ///
    /// This only clears the flag; it does not release the mutex in field `0`.
    pub fn unlock(&self) {
        self.1.store(false, Ordering::Release);
    }

    /// Whether a global ratelimit is currently flagged.
    pub fn is_locked(&self) -> bool {
        // `Acquire` pairs with the `Release` stores in `lock`/`unlock`; a
        // `Relaxed` load would give those `Release` orderings no effect.
        self.1.load(Ordering::Acquire)
    }
}

/// Per-path ratelimiter that queues requests into buckets keyed by [`Path`].
///
/// Both fields are reference-counted, so clones of a `Ratelimiter` share the
/// same buckets and the same global lock.
#[derive(Clone, Debug, Default)]
pub struct Ratelimiter {
    /// Buckets by path. Each spawned [`BucketQueueTask`] also receives a handle
    /// to this map.
    buckets: Arc<Mutex<HashMap<Path, Arc<Bucket>>>>,
    /// Global ratelimit flag and lock, shared with every bucket queue task.
    global: Arc<GlobalLockPair>,
}

impl Ratelimiter {
    /// Create a new ratelimiter.
    ///
    /// Most users won't need to use this directly. If you're creating your own
    /// HTTP proxy then this is good to use for your own ratelimiting.
    pub fn new() -> Self {
        Self::default()
    }

    /// Queue a request for `path`.
    ///
    /// A oneshot channel is created. Its sending half is pushed onto the queue
    /// of the bucket for `path`, and its receiving half is returned. If no
    /// bucket existed for `path`, one is created and a [`BucketQueueTask`] is
    /// spawned to drive it.
    ///
    /// NOTE(review): presumably the queue task sends a `Sender` through this
    /// channel once the request may proceed, and the caller reports the
    /// response's ratelimit headers (or `None`) back through that `Sender` —
    /// confirm against the `bucket` module.
    ///
    /// # Panics
    ///
    /// Panics if a new bucket must be created and this is not running inside a
    /// Tokio runtime, because `tokio::spawn` requires one.
    pub async fn get(&self, path: Path) -> Receiver<Sender<Option<RatelimitHeaders>>> {
        tracing::debug!("getting bucket for path: {:?}", path);

        let (tx, rx) = oneshot::channel();
        let (bucket, fresh) = self.entry(path.clone(), tx).await;

        // Only the call that created the bucket starts its queue task.
        if fresh {
            tokio::spawn(
                BucketQueueTask::new(
                    bucket,
                    Arc::clone(&self.buckets),
                    Arc::clone(&self.global),
                    path,
                )
                .run(),
            );
        }

        rx
    }

    /// Provide an estimate for the time left until a path can be used
    /// without being ratelimited.
    ///
    /// This method is not guaranteed to be accurate. It returns `None` if
    /// there is no bucket for `path`, or if the bucket reports its window as
    /// finished or not started.
    pub async fn time_until_available(&self, path: &Path) -> Option<Duration> {
        // Take our own handle to the bucket and release the map lock before
        // awaiting. Otherwise every concurrent `get` (for any path) would wait
        // on this bucket's `time_remaining`.
        let bucket = {
            let buckets = self.buckets.lock().await;
            let bucket = buckets.get(path)?;
            Arc::clone(bucket)
        };

        match bucket.time_remaining().await {
            TimeRemaining::Finished | TimeRemaining::NotStarted => None,
            TimeRemaining::Some(duration) => Some(duration),
        }
    }

    /// Push `tx` onto the queue of the bucket for `path`, creating the bucket
    /// first if none exists.
    ///
    /// Returns the bucket and `true` if it was newly created (the caller must
    /// then start its queue task), or `false` if it already existed.
    async fn entry(
        &self,
        path: Path,
        tx: Sender<Sender<Option<RatelimitHeaders>>>,
    ) -> (Arc<Bucket>, bool) {
        // nb: not realistically a point of contention
        let mut buckets = self.buckets.lock().await;

        match buckets.entry(path.clone()) {
            Entry::Occupied(entry) => {
                tracing::debug!("got existing bucket: {:?}", path);

                // Only a shared borrow is needed: the queue is pushed through `&Bucket`.
                let bucket = entry.get();
                bucket.queue.push(tx);
                tracing::debug!("added request into bucket queue: {:?}", path);

                (Arc::clone(bucket), false)
            }
            Entry::Vacant(entry) => {
                tracing::debug!("making new bucket for path: {:?}", path);
                // `path` is not used again in this arm, so move it in.
                let bucket = Bucket::new(path);
                bucket.queue.push(tx);

                let bucket = Arc::new(bucket);
                entry.insert(Arc::clone(&bucket));

                (bucket, true)
            }
        }
    }
}