use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tracing::{debug, trace};

use sentinel_common::errors::{SentinelError, SentinelResult};

use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};

/// Configuration for the least-tokens-queued load balancing strategy.
#[derive(Debug, Clone)]
pub struct LeastTokensQueuedConfig {
    /// Smoothing factor for the tokens-per-second EWMA (higher weights recent samples more).
    pub ewma_alpha: f64,
    /// Initial tokens-per-second estimate for targets with no completed requests yet.
    pub default_tps: f64,
    /// Lower bound applied to the TPS estimate to avoid division by (near) zero.
    pub min_tps: f64,
}

impl Default for LeastTokensQueuedConfig {
    fn default() -> Self {
        Self {
            ewma_alpha: 0.3,
            default_tps: 100.0,
            min_tps: 1.0,
        }
    }
}

/// Per-target counters tracking queued work and observed throughput.
struct TargetMetrics {
    /// Tokens currently queued against this target (enqueued but not yet completed).
    queued_tokens: AtomicU64,
    /// Requests currently in flight against this target.
    queued_requests: AtomicU64,
    /// Exponentially weighted moving average of tokens per second.
    tps_ewma: parking_lot::Mutex<f64>,
    /// Lifetime total of completed tokens.
    total_tokens: AtomicU64,
    /// Lifetime total of completed requests.
    total_requests: AtomicU64,
}

impl TargetMetrics {
    fn new(default_tps: f64) -> Self {
        Self {
            queued_tokens: AtomicU64::new(0),
            queued_requests: AtomicU64::new(0),
            tps_ewma: parking_lot::Mutex::new(default_tps),
            total_tokens: AtomicU64::new(0),
            total_requests: AtomicU64::new(0),
        }
    }

    /// Estimated seconds needed to drain this target's queued tokens at its current rate.
    fn estimated_queue_time(&self, min_tps: f64) -> f64 {
        let queued = self.queued_tokens.load(Ordering::Relaxed) as f64;
        let tps = (*self.tps_ewma.lock()).max(min_tps);
        queued / tps
    }

    /// Record a request being dispatched: add its estimated tokens to the queue.
    fn enqueue(&self, tokens: u64) {
        self.queued_tokens.fetch_add(tokens, Ordering::AcqRel);
        self.queued_requests.fetch_add(1, Ordering::AcqRel);
    }

    /// Record a completed request: drain its tokens from the queue, bump lifetime
    /// totals, and fold the observed throughput into the TPS EWMA.
    fn dequeue(&self, tokens: u64, duration: Duration, ewma_alpha: f64) {
        self.queued_tokens.fetch_saturating_sub(tokens);
        self.queued_requests.fetch_saturating_sub(1);

        self.total_tokens.fetch_add(tokens, Ordering::Relaxed);
        self.total_requests.fetch_add(1, Ordering::Relaxed);

        if duration.as_secs_f64() > 0.0 {
            let measured_tps = tokens as f64 / duration.as_secs_f64();
            let mut tps = self.tps_ewma.lock();
            *tps = ewma_alpha * measured_tps + (1.0 - ewma_alpha) * *tps;
        }
    }
}
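
// Worked example of the two estimates above, using the default configuration
// (ewma_alpha = 0.3, default_tps = 100.0); the request sizes and durations are
// illustrative only:
//
//   dequeue of 600 tokens completed in 2 s  -> measured rate = 300 tok/s
//   updated EWMA = 0.3 * 300.0 + 0.7 * 100.0 = 160.0 tok/s
//   with 800 tokens still queued, estimated_queue_time = 800 / 160 = 5.0 s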

/// Extension trait adding a saturating subtraction to `AtomicU64`, so queue
/// counters never wrap below zero if completions race ahead of enqueues.
trait AtomicSaturatingSub {
    fn fetch_saturating_sub(&self, val: u64);
}

impl AtomicSaturatingSub for AtomicU64 {
    fn fetch_saturating_sub(&self, val: u64) {
        // Compare-and-swap loop: retry until the saturated result is stored.
        loop {
            let current = self.load(Ordering::Acquire);
            let new = current.saturating_sub(val);
            if self
                .compare_exchange(current, new, Ordering::AcqRel, Ordering::Relaxed)
                .is_ok()
            {
                break;
            }
        }
    }
}
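
// Note: the same semantics could likely be expressed with the standard library's
// `AtomicU64::fetch_update`, e.g. (sketch, not wired in here):
//
//     let _ = self.fetch_update(Ordering::AcqRel, Ordering::Acquire, |cur| {
//         Some(cur.saturating_sub(val))
//     });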

/// Load balancer that routes each request to the target with the shortest
/// estimated queue-drain time: queued tokens divided by that target's
/// EWMA-smoothed tokens-per-second throughput.
pub struct LeastTokensQueuedBalancer {
    targets: Vec<UpstreamTarget>,
    metrics: Arc<HashMap<String, TargetMetrics>>,
    health_status: Arc<RwLock<HashMap<String, bool>>>,
    config: LeastTokensQueuedConfig,
}

impl LeastTokensQueuedBalancer {
    pub fn new(targets: Vec<UpstreamTarget>, config: LeastTokensQueuedConfig) -> Self {
        let mut metrics = HashMap::new();
        let mut health_status = HashMap::new();

        for target in &targets {
            let addr = target.full_address();
            metrics.insert(addr.clone(), TargetMetrics::new(config.default_tps));
            health_status.insert(addr, true);
        }

        Self {
            targets,
            metrics: Arc::new(metrics),
            health_status: Arc::new(RwLock::new(health_status)),
            config,
        }
    }

    /// Record estimated tokens as queued against `address` when a request is dispatched.
    pub fn enqueue_tokens(&self, address: &str, estimated_tokens: u64) {
        if let Some(metrics) = self.metrics.get(address) {
            metrics.enqueue(estimated_tokens);
            trace!(
                target = address,
                tokens = estimated_tokens,
                queued = metrics.queued_tokens.load(Ordering::Relaxed),
                "Enqueued tokens for target"
            );
        }
    }

    /// Record a completed request for `address`, draining its tokens and updating the TPS EWMA.
    pub fn dequeue_tokens(&self, address: &str, actual_tokens: u64, duration: Duration) {
        if let Some(metrics) = self.metrics.get(address) {
            metrics.dequeue(actual_tokens, duration, self.config.ewma_alpha);
            debug!(
                target = address,
                tokens = actual_tokens,
                duration_ms = duration.as_millis() as u64,
                queued = metrics.queued_tokens.load(Ordering::Relaxed),
                tps = *metrics.tps_ewma.lock(),
                "Dequeued tokens for target"
            );
        }
    }

    /// Snapshot of the tracked metrics for `address`, if it is a known target.
    pub fn target_metrics(&self, address: &str) -> Option<LeastTokensQueuedTargetStats> {
        self.metrics
            .get(address)
            .map(|m| LeastTokensQueuedTargetStats {
                queued_tokens: m.queued_tokens.load(Ordering::Relaxed),
                queued_requests: m.queued_requests.load(Ordering::Relaxed),
                tokens_per_second: *m.tps_ewma.lock(),
                total_tokens: m.total_tokens.load(Ordering::Relaxed),
                total_requests: m.total_requests.load(Ordering::Relaxed),
            })
    }

    /// Estimated queue-drain time in seconds for every currently healthy target.
    pub async fn queue_times(&self) -> Vec<(String, f64)> {
        let health = self.health_status.read().await;
        self.targets
            .iter()
            .filter_map(|t| {
                let addr = t.full_address();
                if *health.get(&addr).unwrap_or(&true) {
                    self.metrics
                        .get(&addr)
                        .map(|m| (addr, m.estimated_queue_time(self.config.min_tps)))
                } else {
                    None
                }
            })
            .collect()
    }
}

/// Point-in-time statistics for a single target, as returned by `target_metrics`.
#[derive(Debug, Clone)]
pub struct LeastTokensQueuedTargetStats {
    pub queued_tokens: u64,
    pub queued_requests: u64,
    pub tokens_per_second: f64,
    pub total_tokens: u64,
    pub total_requests: u64,
}

#[async_trait]
impl LoadBalancer for LeastTokensQueuedBalancer {
    async fn select(&self, _context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
        trace!(
            total_targets = self.targets.len(),
            algorithm = "least_tokens_queued",
            "Selecting upstream target"
        );

        let health = self.health_status.read().await;

        let mut best_target = None;
        let mut min_queue_time = f64::MAX;

        for target in &self.targets {
            let addr = target.full_address();

            if !*health.get(&addr).unwrap_or(&true) {
                trace!(
                    target = %addr,
                    algorithm = "least_tokens_queued",
                    "Skipping unhealthy target"
                );
                continue;
            }

            let queue_time = self
                .metrics
                .get(&addr)
                .map(|m| m.estimated_queue_time(self.config.min_tps))
                .unwrap_or(0.0);

            trace!(
                target = %addr,
                queue_time_secs = queue_time,
                "Evaluating target queue time"
            );

            if queue_time < min_queue_time {
                min_queue_time = queue_time;
                best_target = Some(target);
            }
        }

        match best_target {
            Some(target) => {
                debug!(
                    selected_target = %target.full_address(),
                    queue_time_secs = min_queue_time,
                    algorithm = "least_tokens_queued",
                    "Selected target with lowest queue time"
                );
                Ok(TargetSelection {
                    address: target.full_address(),
                    weight: target.weight,
                    metadata: HashMap::new(),
                })
            }
            None => {
                tracing::warn!(
                    total_targets = self.targets.len(),
                    algorithm = "least_tokens_queued",
                    "No healthy upstream targets available"
                );
                Err(SentinelError::NoHealthyUpstream)
            }
        }
    }

    async fn report_health(&self, address: &str, healthy: bool) {
        trace!(
            target = %address,
            healthy = healthy,
            algorithm = "least_tokens_queued",
            "Updating target health status"
        );
        self.health_status
            .write()
            .await
            .insert(address.to_string(), healthy);
    }

    async fn healthy_targets(&self) -> Vec<String> {
        self.health_status
            .read()
            .await
            .iter()
            .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
            .collect()
    }

    async fn report_result(
        &self,
        selection: &TargetSelection,
        success: bool,
        _latency: Option<Duration>,
    ) {
        // Latency is not folded into the throughput estimate here; callers report
        // token throughput directly via `dequeue_tokens`.
        self.report_health(&selection.address, success).await;
    }

    async fn report_result_with_latency(
        &self,
        address: &str,
        success: bool,
        _latency: Option<Duration>,
    ) {
        self.report_health(address, success).await;
    }
}
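
// Intended call flow for a single proxied request (a sketch; where the token
// estimate and the actual token count come from is up to the caller):
//
//     let selection = balancer.select(None).await?;
//     balancer.enqueue_tokens(&selection.address, estimated_tokens);
//     // ... forward the request and await the upstream response ...
//     balancer.dequeue_tokens(&selection.address, actual_tokens, elapsed);
//     balancer.report_result(&selection, success, Some(elapsed)).await;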

#[cfg(test)]
mod tests {
    use super::*;

    fn test_targets() -> Vec<UpstreamTarget> {
        vec![
            UpstreamTarget::new("server1", 8080, 100),
            UpstreamTarget::new("server2", 8080, 100),
            UpstreamTarget::new("server3", 8080, 100),
        ]
    }

    #[tokio::test]
    async fn test_basic_selection() {
        let balancer =
            LeastTokensQueuedBalancer::new(test_targets(), LeastTokensQueuedConfig::default());

        let selection = balancer.select(None).await.unwrap();
        assert!(!selection.address.is_empty());
    }

    #[tokio::test]
    async fn test_selects_least_queued() {
        let balancer =
            LeastTokensQueuedBalancer::new(test_targets(), LeastTokensQueuedConfig::default());

        balancer.enqueue_tokens("server1:8080", 1000);
        balancer.enqueue_tokens("server2:8080", 500);

        // server3 has nothing queued, so it should win the selection.
        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, "server3:8080");
    }

    #[tokio::test]
    async fn test_dequeue_updates_tps() {
        let balancer =
            LeastTokensQueuedBalancer::new(test_targets(), LeastTokensQueuedConfig::default());

        balancer.enqueue_tokens("server1:8080", 1000);
        balancer.dequeue_tokens("server1:8080", 1000, Duration::from_secs(1));

        let stats = balancer.target_metrics("server1:8080").unwrap();
        assert_eq!(stats.total_tokens, 1000);
        assert_eq!(stats.total_requests, 1);
    }
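
    // Illustrative companion test (a sketch assuming the default config above:
    // ewma_alpha = 0.3, default_tps = 100.0). Draining 1000 tokens in 1 s should
    // empty the queue and move the EWMA from 100 tok/s toward the measured
    // 1000 tok/s: 0.3 * 1000 + 0.7 * 100 = 370.
    #[tokio::test]
    async fn test_ewma_moves_toward_measured_rate() {
        let balancer =
            LeastTokensQueuedBalancer::new(test_targets(), LeastTokensQueuedConfig::default());

        balancer.enqueue_tokens("server1:8080", 1000);
        balancer.dequeue_tokens("server1:8080", 1000, Duration::from_secs(1));

        let stats = balancer.target_metrics("server1:8080").unwrap();
        assert_eq!(stats.queued_tokens, 0);
        assert_eq!(stats.queued_requests, 0);
        assert!((stats.tokens_per_second - 370.0).abs() < 1e-6);
    }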

    #[tokio::test]
    async fn test_unhealthy_target_skipped() {
        let balancer =
            LeastTokensQueuedBalancer::new(test_targets(), LeastTokensQueuedConfig::default());

        balancer.report_health("server3:8080", false).await;

        balancer.enqueue_tokens("server1:8080", 1000);

        // server3 is unhealthy and server1 has queued work, so server2 should win.
        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, "server2:8080");
    }
}
395}