Skip to main content

rustbac_client/
throttle.rs

1use rustbac_datalink::DataLinkAddress;
2use std::collections::HashMap;
3use std::sync::Arc;
4use std::time::Duration;
5use tokio::sync::{Mutex, OwnedSemaphorePermit, Semaphore};
6use tokio::time::Instant;
7
/// Effective limits applied to one device address.
#[derive(Debug, Clone, Copy)]
struct DeviceThrottleConfig {
    /// Number of permits the device's semaphore is created with.
    max_concurrent: usize,
    /// Minimum spacing between successive request starts; zero disables spacing.
    min_interval: Duration,
}
13
14/// Per-device request coordination primitive.
15///
16/// This utility lets callers limit concurrent requests per target address and
17/// enforce a minimum delay between request starts.
18#[derive(Debug)]
19pub struct DeviceThrottle {
20    semaphores: Mutex<HashMap<DataLinkAddress, Arc<Semaphore>>>,
21    last_request: Mutex<HashMap<DataLinkAddress, Instant>>,
22    overrides: Mutex<HashMap<DataLinkAddress, DeviceThrottleConfig>>,
23    default_max_concurrent: usize,
24    default_min_interval: Duration,
25}
26
27impl DeviceThrottle {
28    /// Creates a new throttle using the given defaults.
29    pub fn new(max_concurrent: usize, min_interval: Duration) -> Self {
30        Self {
31            semaphores: Mutex::new(HashMap::new()),
32            last_request: Mutex::new(HashMap::new()),
33            overrides: Mutex::new(HashMap::new()),
34            default_max_concurrent: max_concurrent.max(1),
35            default_min_interval: min_interval,
36        }
37    }
38
39    /// Sets (or replaces) a per-device override.
40    pub async fn set_device_limit(
41        &self,
42        address: DataLinkAddress,
43        max_concurrent: usize,
44        min_interval: Duration,
45    ) {
46        let max_concurrent = max_concurrent.max(1);
47        self.overrides.lock().await.insert(
48            address,
49            DeviceThrottleConfig {
50                max_concurrent,
51                min_interval,
52            },
53        );
54
55        // Swap in a new semaphore so future acquisitions use the new limit.
56        self.semaphores
57            .lock()
58            .await
59            .insert(address, Arc::new(Semaphore::new(max_concurrent)));
60    }
61
62    /// Acquires a permit for `address`, respecting per-device concurrency and
63    /// minimum interval between request starts.
64    pub async fn acquire(&self, address: DataLinkAddress) -> OwnedSemaphorePermit {
65        let config = {
66            let overrides = self.overrides.lock().await;
67            overrides
68                .get(&address)
69                .copied()
70                .unwrap_or(DeviceThrottleConfig {
71                    max_concurrent: self.default_max_concurrent,
72                    min_interval: self.default_min_interval,
73                })
74        };
75
76        let semaphore = {
77            let mut semaphores = self.semaphores.lock().await;
78            semaphores
79                .entry(address)
80                .or_insert_with(|| Arc::new(Semaphore::new(config.max_concurrent)))
81                .clone()
82        };
83
84        let permit = semaphore
85            .acquire_owned()
86            .await
87            .expect("device throttle semaphore closed unexpectedly");
88
89        if !config.min_interval.is_zero() {
90            let mut last_request = self.last_request.lock().await;
91            if let Some(last) = last_request.get(&address) {
92                let elapsed = last.elapsed();
93                if elapsed < config.min_interval {
94                    tokio::time::sleep(config.min_interval - elapsed).await;
95                }
96            }
97            last_request.insert(address, Instant::now());
98        }
99
100        permit
101    }
102}
103
104#[cfg(test)]
105mod tests {
106    use super::DeviceThrottle;
107    use rustbac_datalink::DataLinkAddress;
108    use std::net::{IpAddr, Ipv4Addr, SocketAddr};
109    use std::time::Duration;
110    use tokio::time::{timeout, Instant};
111
112    fn addr(port: u16) -> DataLinkAddress {
113        DataLinkAddress::Ip(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port))
114    }
115
116    #[tokio::test]
117    async fn enforces_concurrency_limit() {
118        let throttle = DeviceThrottle::new(1, Duration::ZERO);
119        let first = throttle.acquire(addr(47808)).await;
120
121        let blocked = timeout(Duration::from_millis(40), throttle.acquire(addr(47808))).await;
122        assert!(blocked.is_err());
123
124        drop(first);
125        let second = timeout(Duration::from_millis(200), throttle.acquire(addr(47808)))
126            .await
127            .expect("second permit should be acquired");
128        drop(second);
129    }
130
131    #[tokio::test]
132    async fn enforces_minimum_interval() {
133        let throttle = DeviceThrottle::new(1, Duration::from_millis(80));
134        let first = throttle.acquire(addr(47809)).await;
135        drop(first);
136
137        let started = Instant::now();
138        let second = throttle.acquire(addr(47809)).await;
139        let elapsed = started.elapsed();
140        drop(second);
141
142        assert!(
143            elapsed >= Duration::from_millis(70),
144            "elapsed {:?} was shorter than expected interval",
145            elapsed
146        );
147    }
148
149    #[tokio::test]
150    async fn applies_per_device_overrides() {
151        let throttle = DeviceThrottle::new(1, Duration::from_millis(120));
152        let target = addr(47810);
153        let other = addr(47811);
154
155        throttle
156            .set_device_limit(target, 2, Duration::from_millis(10))
157            .await;
158
159        let first = throttle.acquire(target).await;
160        let second = throttle.acquire(target).await;
161        let third = timeout(Duration::from_millis(40), throttle.acquire(target)).await;
162        assert!(
163            third.is_err(),
164            "third permit should block at override limit"
165        );
166        drop(first);
167        drop(second);
168
169        let first_other = throttle.acquire(other).await;
170        let blocked_other = timeout(Duration::from_millis(40), throttle.acquire(other)).await;
171        assert!(blocked_other.is_err(), "default limit should still be one");
172        drop(first_other);
173    }
174}