sentinel_proxy/upstream/
least_tokens.rs

//! Least Tokens Queued load balancer for inference workloads
//!
//! This load balancer selects upstreams based on the estimated number of tokens
//! currently being processed, optimized for LLM/AI inference traffic where
//! request processing time correlates strongly with token count.

use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tracing::{debug, trace};

use sentinel_common::errors::{SentinelError, SentinelResult};

use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};

/// Configuration for the least tokens queued balancer
#[derive(Debug, Clone)]
pub struct LeastTokensQueuedConfig {
    /// Smoothing factor for tokens-per-second EWMA (0.0-1.0)
    /// Higher values = more responsive to recent measurements
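    /// Applied on each completion as: new_tps = alpha * measured_tps + (1 - alpha) * old_tps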
    pub ewma_alpha: f64,
    /// Default tokens-per-second estimate for new targets
    pub default_tps: f64,
    /// Minimum tokens-per-second floor, avoiding division by (near-)zero in queue-time estimates
    pub min_tps: f64,
}

impl Default for LeastTokensQueuedConfig {
    fn default() -> Self {
        Self {
            ewma_alpha: 0.3,
            default_tps: 100.0, // Conservative default
            min_tps: 1.0,
        }
    }
}

/// Per-target metrics for token-aware load balancing
struct TargetMetrics {
    /// Currently queued tokens (estimated)
    queued_tokens: AtomicU64,
    /// Currently queued requests
    queued_requests: AtomicU64,
    /// Exponentially weighted moving average of tokens per second
    tps_ewma: parking_lot::Mutex<f64>,
    /// Total tokens processed (for debugging/metrics)
    total_tokens: AtomicU64,
    /// Total requests processed
    total_requests: AtomicU64,
}

impl TargetMetrics {
    fn new(default_tps: f64) -> Self {
        Self {
            queued_tokens: AtomicU64::new(0),
            queued_requests: AtomicU64::new(0),
            tps_ewma: parking_lot::Mutex::new(default_tps),
            total_tokens: AtomicU64::new(0),
            total_requests: AtomicU64::new(0),
        }
    }

    /// Get the estimated queue time: queued_tokens / tokens_per_second
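    /// (e.g. 500 queued tokens at an EWMA of 100 tokens/sec estimates 5.0s of queue)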
    fn estimated_queue_time(&self, min_tps: f64) -> f64 {
        let queued = self.queued_tokens.load(Ordering::Relaxed) as f64;
        let tps = (*self.tps_ewma.lock()).max(min_tps);
        queued / tps
    }

    /// Add tokens to the queue (when request starts)
    fn enqueue(&self, tokens: u64) {
        self.queued_tokens.fetch_add(tokens, Ordering::AcqRel);
        self.queued_requests.fetch_add(1, Ordering::AcqRel);
    }

    /// Remove tokens from queue and update TPS (when request completes)
    fn dequeue(&self, tokens: u64, duration: Duration, ewma_alpha: f64) {
        // Remove from queue
        self.queued_tokens.fetch_saturating_sub(tokens);
        self.queued_requests.fetch_saturating_sub(1);

        // Update totals
        self.total_tokens.fetch_add(tokens, Ordering::Relaxed);
        self.total_requests.fetch_add(1, Ordering::Relaxed);

        // Update TPS EWMA
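        // e.g. with alpha = 0.3, a prior EWMA of 100 tok/s, and a measured
        // 1000 tok/s: 0.3 * 1000.0 + 0.7 * 100.0 = 370 tok/s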
        if duration.as_secs_f64() > 0.0 {
            let measured_tps = tokens as f64 / duration.as_secs_f64();
            let mut tps = self.tps_ewma.lock();
            *tps = ewma_alpha * measured_tps + (1.0 - ewma_alpha) * *tps;
        }
    }
}

/// Extension trait for AtomicU64 to add saturating_sub
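/// A plain `fetch_sub` would wrap below zero whenever a dequeue's actual token
/// count exceeds the estimate that was enqueued, so the subtraction clamps at 0.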
trait AtomicSaturatingSub {
    fn fetch_saturating_sub(&self, val: u64);
}

impl AtomicSaturatingSub for AtomicU64 {
    fn fetch_saturating_sub(&self, val: u64) {
        loop {
            let current = self.load(Ordering::Acquire);
            let new = current.saturating_sub(val);
            if self
                .compare_exchange(current, new, Ordering::AcqRel, Ordering::Relaxed)
                .is_ok()
            {
                break;
            }
        }
    }
}

/// Least Tokens Queued load balancer
///
/// Selects the upstream with the lowest estimated queue time,
/// calculated as: queued_tokens / tokens_per_second
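///
/// # Example
///
/// A minimal sketch of the intended lifecycle (the `UpstreamTarget::new`
/// signature is assumed from the tests below; error handling is elided):
///
/// ```ignore
/// let balancer = LeastTokensQueuedBalancer::new(
///     vec![UpstreamTarget::new("server1", 8080, 100)],
///     LeastTokensQueuedConfig::default(),
/// );
///
/// // Pick the target with the shortest estimated queue.
/// let selection = balancer.select(None).await?;
///
/// // Reserve the estimated token count before proxying.
/// balancer.enqueue_tokens(&selection.address, 1024);
///
/// let start = std::time::Instant::now();
/// // ... proxy the request to `selection.address` ...
///
/// // Release with the actual token count so the TPS EWMA can learn.
/// balancer.dequeue_tokens(&selection.address, 987, start.elapsed());
/// ```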
pub struct LeastTokensQueuedBalancer {
    targets: Vec<UpstreamTarget>,
    metrics: Arc<HashMap<String, TargetMetrics>>,
    health_status: Arc<RwLock<HashMap<String, bool>>>,
    config: LeastTokensQueuedConfig,
}

impl LeastTokensQueuedBalancer {
    /// Create a new least tokens queued balancer
    pub fn new(targets: Vec<UpstreamTarget>, config: LeastTokensQueuedConfig) -> Self {
        let mut metrics = HashMap::new();
        let mut health_status = HashMap::new();

        for target in &targets {
            let addr = target.full_address();
            metrics.insert(addr.clone(), TargetMetrics::new(config.default_tps));
            health_status.insert(addr, true);
        }

        Self {
            targets,
            metrics: Arc::new(metrics),
            health_status: Arc::new(RwLock::new(health_status)),
            config,
        }
    }

    /// Enqueue tokens for a target (call when request starts)
    pub fn enqueue_tokens(&self, address: &str, estimated_tokens: u64) {
        if let Some(metrics) = self.metrics.get(address) {
            metrics.enqueue(estimated_tokens);
            trace!(
                target = address,
                tokens = estimated_tokens,
                queued = metrics.queued_tokens.load(Ordering::Relaxed),
                "Enqueued tokens for target"
            );
        }
    }

    /// Dequeue tokens for a target (call when request completes)
    pub fn dequeue_tokens(&self, address: &str, actual_tokens: u64, duration: Duration) {
        if let Some(metrics) = self.metrics.get(address) {
            metrics.dequeue(actual_tokens, duration, self.config.ewma_alpha);
            debug!(
                target = address,
                tokens = actual_tokens,
                duration_ms = duration.as_millis() as u64,
                queued = metrics.queued_tokens.load(Ordering::Relaxed),
                tps = *metrics.tps_ewma.lock(),
                "Dequeued tokens for target"
            );
        }
    }

    /// Get current metrics for a target (for debugging/observability)
    pub fn target_metrics(&self, address: &str) -> Option<LeastTokensQueuedTargetStats> {
        self.metrics.get(address).map(|m| LeastTokensQueuedTargetStats {
            queued_tokens: m.queued_tokens.load(Ordering::Relaxed),
            queued_requests: m.queued_requests.load(Ordering::Relaxed),
            tokens_per_second: *m.tps_ewma.lock(),
            total_tokens: m.total_tokens.load(Ordering::Relaxed),
            total_requests: m.total_requests.load(Ordering::Relaxed),
        })
    }

    /// Get all targets' current queue times for debugging
    pub async fn queue_times(&self) -> Vec<(String, f64)> {
        let health = self.health_status.read().await;
        self.targets
            .iter()
            .filter_map(|t| {
                let addr = t.full_address();
                if *health.get(&addr).unwrap_or(&true) {
                    self.metrics
                        .get(&addr)
                        .map(|m| (addr, m.estimated_queue_time(self.config.min_tps)))
                } else {
                    None
                }
            })
            .collect()
    }
}

/// Target statistics for observability
#[derive(Debug, Clone)]
pub struct LeastTokensQueuedTargetStats {
    pub queued_tokens: u64,
    pub queued_requests: u64,
    pub tokens_per_second: f64,
    pub total_tokens: u64,
    pub total_requests: u64,
}

#[async_trait]
impl LoadBalancer for LeastTokensQueuedBalancer {
    async fn select(&self, _context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
        trace!(
            total_targets = self.targets.len(),
            algorithm = "least_tokens_queued",
            "Selecting upstream target"
        );

        let health = self.health_status.read().await;

        let mut best_target = None;
        let mut min_queue_time = f64::MAX;

        for target in &self.targets {
            let addr = target.full_address();

            // Skip unhealthy targets
            if !*health.get(&addr).unwrap_or(&true) {
                trace!(
                    target = %addr,
                    algorithm = "least_tokens_queued",
                    "Skipping unhealthy target"
                );
                continue;
            }

            // Calculate estimated queue time
            let queue_time = self
                .metrics
                .get(&addr)
                .map(|m| m.estimated_queue_time(self.config.min_tps))
                .unwrap_or(0.0);

            trace!(
                target = %addr,
                queue_time_secs = queue_time,
                "Evaluating target queue time"
            );

            if queue_time < min_queue_time {
                min_queue_time = queue_time;
                best_target = Some(target);
            }
        }

        match best_target {
            Some(target) => {
                debug!(
                    selected_target = %target.full_address(),
                    queue_time_secs = min_queue_time,
                    algorithm = "least_tokens_queued",
                    "Selected target with lowest queue time"
                );
                Ok(TargetSelection {
                    address: target.full_address(),
                    weight: target.weight,
                    metadata: HashMap::new(),
                })
            }
            None => {
                tracing::warn!(
                    total_targets = self.targets.len(),
                    algorithm = "least_tokens_queued",
                    "No healthy upstream targets available"
                );
                Err(SentinelError::NoHealthyUpstream)
            }
        }
    }

    async fn report_health(&self, address: &str, healthy: bool) {
        trace!(
            target = %address,
            healthy = healthy,
            algorithm = "least_tokens_queued",
            "Updating target health status"
        );
        self.health_status
            .write()
            .await
            .insert(address.to_string(), healthy);
    }

    async fn healthy_targets(&self) -> Vec<String> {
        self.health_status
            .read()
            .await
            .iter()
            .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
            .collect()
    }

    async fn report_result(
        &self,
        selection: &TargetSelection,
        success: bool,
        _latency: Option<Duration>,
    ) {
        // Update health based on success
        self.report_health(&selection.address, success).await;

        // Note: Token dequeuing should be done explicitly via dequeue_tokens()
        // when the actual token count is known from the response
    }

    async fn report_result_with_latency(
        &self,
        address: &str,
        success: bool,
        _latency: Option<Duration>,
    ) {
        self.report_health(address, success).await;
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn test_targets() -> Vec<UpstreamTarget> {
        vec![
            UpstreamTarget::new("server1", 8080, 100),
            UpstreamTarget::new("server2", 8080, 100),
            UpstreamTarget::new("server3", 8080, 100),
        ]
    }

    #[tokio::test]
    async fn test_basic_selection() {
        let balancer =
            LeastTokensQueuedBalancer::new(test_targets(), LeastTokensQueuedConfig::default());

        // All targets start with 0 queued tokens, so selection should work
        let selection = balancer.select(None).await.unwrap();
        assert!(!selection.address.is_empty());
    }

    #[tokio::test]
    async fn test_selects_least_queued() {
        let balancer =
            LeastTokensQueuedBalancer::new(test_targets(), LeastTokensQueuedConfig::default());

        // Add tokens to server1 and server2
        balancer.enqueue_tokens("server1:8080", 1000);
        balancer.enqueue_tokens("server2:8080", 500);
        // server3 has 0 tokens

        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, "server3:8080");
    }

    #[tokio::test]
    async fn test_dequeue_updates_tps() {
        let balancer =
            LeastTokensQueuedBalancer::new(test_targets(), LeastTokensQueuedConfig::default());

        // Enqueue and then dequeue with timing
        balancer.enqueue_tokens("server1:8080", 1000);
        balancer.dequeue_tokens("server1:8080", 1000, Duration::from_secs(1));

        // Check that totals were updated and the TPS EWMA moved up from the
        // 100 tok/s default (1000 tokens in 1s measures 1000 tok/s)
        let stats = balancer.target_metrics("server1:8080").unwrap();
        assert_eq!(stats.total_tokens, 1000);
        assert_eq!(stats.total_requests, 1);
        assert!(stats.tokens_per_second > 100.0);
    }

    #[tokio::test]
    async fn test_unhealthy_target_skipped() {
        let balancer =
            LeastTokensQueuedBalancer::new(test_targets(), LeastTokensQueuedConfig::default());

        // Mark server3 as unhealthy
        balancer.report_health("server3:8080", false).await;

        // Add tokens to server1
        balancer.enqueue_tokens("server1:8080", 1000);

        // Should select server2 (healthy and lowest queue)
        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, "server2:8080");
    }
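
    #[tokio::test]
    async fn test_dequeue_saturates_at_zero() {
        // Over-dequeuing (the actual token count exceeding the enqueued
        // estimate) must clamp the counters at zero rather than wrap the u64.
        let balancer =
            LeastTokensQueuedBalancer::new(test_targets(), LeastTokensQueuedConfig::default());

        balancer.enqueue_tokens("server1:8080", 100);
        balancer.dequeue_tokens("server1:8080", 500, Duration::from_secs(1));

        let stats = balancer.target_metrics("server1:8080").unwrap();
        assert_eq!(stats.queued_tokens, 0);
        assert_eq!(stats.queued_requests, 0);
    }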
}