1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};

/// A byte budget paired with the timestamp of its most recent refresh.
///
/// Both fields are atomics. The derived `Default` zero-initializes them
/// (zero bytes available, last-refresh timestamp of zero).
#[derive(Default)]
pub struct DataBudget {
    // Number of bytes currently available in the budget to send.
    bytes: AtomicUsize,
    // Timestamp, in milliseconds, of the last time the byte count was
    // refreshed. Compared against the current time to decide whether the
    // budget may be refreshed again.
    last_timestamp_ms: AtomicU64,
}

impl DataBudget {
    /// Creates a budget that starts with zero bytes and whose last-update
    /// timestamp is pinned to `u64::MAX`, so the elapsed-time check in
    /// `can_update` fails for any clock reading below `u64::MAX`.
    /// Used for tests.
    pub fn restricted() -> Self {
        let bytes = AtomicUsize::default();
        let last_timestamp_ms = AtomicU64::new(u64::MAX);
        Self {
            bytes,
            last_timestamp_ms,
        }
    }

    /// Attempts to withdraw `size` bytes from the budget.
    ///
    /// Returns `true` and deducts `size` when at least that many bytes are
    /// available; returns `false` and leaves the budget untouched otherwise.
    #[must_use]
    pub fn take(&self, size: usize) -> bool {
        // `checked_sub` yields `None` exactly when the budget is smaller than
        // `size`, which makes `fetch_update` bail out without storing.
        self.bytes
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |available| {
                available.checked_sub(size)
            })
            .is_ok()
    }

    // Returns true, and records the current time as the last-update
    // timestamp, when at least `duration_millis` milliseconds have passed
    // since the previous update. Returns false without touching the
    // timestamp otherwise.
    fn can_update(&self, duration_millis: u64) -> bool {
        let now = solana_sdk::timing::timestamp();
        self.last_timestamp_ms
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |previous| {
                // Saturating add keeps a `u64::MAX` timestamp from wrapping.
                let earliest_allowed = previous.saturating_add(duration_millis);
                if now >= earliest_allowed {
                    Some(now)
                } else {
                    None
                }
            })
            .is_ok()
    }

    /// Refreshes the budget if at least `duration_millis` milliseconds have
    /// passed since the last refresh.
    ///
    /// `updater` maps the current byte count to the new one; it may be
    /// invoked more than once if the stored value changes concurrently.
    /// Returns the byte count read after the (possible) refresh.
    pub fn update<F>(&self, duration_millis: u64, updater: F) -> usize
    where
        F: Fn(usize) -> usize,
    {
        if self.can_update(duration_millis) {
            // The closure always yields `Some`, so this retries until the
            // store succeeds; the returned previous value is not needed.
            let _ = self
                .bytes
                .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                    Some(updater(current))
                });
        }
        self.bytes.load(Ordering::Acquire)
    }

    // Copies the two fields one after the other, so the pair is not captured
    // as a single atomic snapshot. Only for tests and simulations.
    pub fn clone_non_atomic(&self) -> Self {
        let bytes = self.bytes.load(Ordering::Acquire);
        let last_timestamp_ms = self.last_timestamp_ms.load(Ordering::Acquire);
        Self {
            bytes: AtomicUsize::new(bytes),
            last_timestamp_ms: AtomicU64::new(last_timestamp_ms),
        }
    }
}

#[cfg(test)]
mod tests {
    use {super::*, std::time::Duration};

    // Walks a default (empty) budget through refreshes and withdrawals,
    // checking both the time gate on `update` and the balance check in `take`.
    // NOTE(review): the "no update" step below relies on fewer than 30 ms
    // elapsing since the first refresh, so it is sensitive to scheduling delays.
    #[test]
    fn test_data_budget() {
        let budget = DataBudget::default();
        assert!(!budget.take(1)); // budget = 0.

        // The default last-update timestamp is 0, so this first refresh goes
        // through (presumably `timestamp()` is wall-clock ms, far above 1000).
        assert_eq!(budget.update(1000, |bytes| bytes + 5), 5); // budget updates to 5.
        assert!(budget.take(1));
        assert!(budget.take(2));
        assert!(!budget.take(3)); // budget = 2, out of budget.

        // Too soon after the previous refresh: the updater is not applied.
        assert_eq!(budget.update(30, |_| 10), 2); // no update, budget = 2.
        assert!(!budget.take(3)); // budget = 2, out of budget.

        // Wait longer than the 30 ms interval so the next refresh is allowed.
        std::thread::sleep(Duration::from_millis(50));
        assert_eq!(budget.update(30, |bytes| bytes * 2), 4); // budget updates to 4.

        assert!(budget.take(3));
        assert!(budget.take(1));
        assert!(!budget.take(1)); // budget = 0.
    }
}