Skip to main content

torsh_distributed/expert_parallelism/
stats.rs

1//! Routing statistics and monitoring for expert parallelism
2//!
3//! This module provides comprehensive statistics tracking and monitoring capabilities
4//! for expert routing decisions, load balancing effectiveness, and system performance.
5
6use super::router::RoutingDecision;
7use log::info;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::time::{Duration, Instant};
11
12/// Routing statistics for monitoring and debugging
13///
14/// Tracks various metrics related to expert routing performance, efficiency,
15/// and utilization patterns over time.
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct RoutingStats {
18    /// Total number of routing operations performed
19    pub total_routings: u64,
20    /// Total number of tokens processed
21    pub total_tokens: u64,
22    /// Total number of tokens dropped due to capacity constraints
23    pub tokens_dropped: u64,
24    /// Expert utilization rates (running average per expert)
25    pub expert_utilization: Vec<f32>,
26    /// Average load balance loss across all routing operations
27    pub average_load_balance_loss: f32,
28    /// Average router z-loss for numerical stability
29    pub average_router_z_loss: f32,
30    /// Overall routing efficiency as a percentage
31    pub routing_efficiency: f32,
32    /// Per-expert token assignment counts
33    pub expert_token_counts: Vec<u64>,
34    /// Routing latency statistics
35    pub routing_latency_stats: LatencyStats,
36    /// Expert load variance over time
37    pub load_variance_history: Vec<f32>,
38    /// Capacity utilization statistics
39    pub capacity_stats: CapacityStats,
40}
41
42impl RoutingStats {
43    /// Create a new routing statistics tracker
44    pub fn new() -> Self {
45        Self {
46            total_routings: 0,
47            total_tokens: 0,
48            tokens_dropped: 0,
49            expert_utilization: Vec::new(),
50            average_load_balance_loss: 0.0,
51            average_router_z_loss: 0.0,
52            routing_efficiency: 0.0,
53            expert_token_counts: Vec::new(),
54            routing_latency_stats: LatencyStats::new(),
55            load_variance_history: Vec::new(),
56            capacity_stats: CapacityStats::new(),
57        }
58    }
59
60    /// Record a routing decision and update statistics
61    ///
62    /// # Arguments
63    ///
64    /// * `routing_decision` - The routing decision to record
65    pub fn record_routing(&mut self, routing_decision: &RoutingDecision) {
66        self.total_routings += 1;
67        self.total_tokens += routing_decision.total_tokens as u64;
68        self.tokens_dropped += routing_decision.tokens_dropped as u64;
69
70        // Update running averages using exponential moving average
71        let alpha = 1.0 / self.total_routings as f32;
72        self.average_load_balance_loss = alpha * routing_decision.load_balance_loss
73            + (1.0 - alpha) * self.average_load_balance_loss;
74        self.average_router_z_loss =
75            alpha * routing_decision.router_z_loss + (1.0 - alpha) * self.average_router_z_loss;
76
77        // Calculate routing efficiency
78        if self.total_tokens > 0 {
79            self.routing_efficiency =
80                (self.total_tokens - self.tokens_dropped) as f32 / self.total_tokens as f32 * 100.0;
81        }
82
83        // Update expert utilization
84        if self.expert_utilization.len() != routing_decision.expert_capacities.len() {
85            self.expert_utilization = vec![0.0; routing_decision.expert_capacities.len()];
86            self.expert_token_counts = vec![0; routing_decision.expert_capacities.len()];
87        }
88
89        for (i, &capacity) in routing_decision.expert_capacities.iter().enumerate() {
90            if i < self.expert_utilization.len() {
91                let utilization = if routing_decision.total_tokens > 0 {
92                    capacity as f32 / routing_decision.total_tokens as f32
93                } else {
94                    0.0
95                };
96                self.expert_utilization[i] =
97                    alpha * utilization + (1.0 - alpha) * self.expert_utilization[i];
98                self.expert_token_counts[i] += capacity as u64;
99            }
100        }
101
102        // Calculate and record load variance
103        let load_variance = self.calculate_load_variance(&routing_decision.expert_capacities);
104        self.load_variance_history.push(load_variance);
105
106        // Limit history size
107        if self.load_variance_history.len() > 1000 {
108            self.load_variance_history.remove(0);
109        }
110
111        // Update capacity statistics
112        self.capacity_stats.update(routing_decision);
113    }
114
115    /// Record routing latency for performance monitoring
116    ///
117    /// # Arguments
118    ///
119    /// * `latency` - The latency of the routing operation
120    pub fn record_routing_latency(&mut self, latency: Duration) {
121        self.routing_latency_stats.record_latency(latency);
122    }
123
124    /// Calculate load variance for a given set of expert capacities
125    fn calculate_load_variance(&self, capacities: &[usize]) -> f32 {
126        if capacities.is_empty() {
127            return 0.0;
128        }
129
130        let mean = capacities.iter().sum::<usize>() as f32 / capacities.len() as f32;
131        let variance = capacities
132            .iter()
133            .map(|&cap| {
134                let diff = cap as f32 - mean;
135                diff * diff
136            })
137            .sum::<f32>()
138            / capacities.len() as f32;
139
140        variance
141    }
142
143    /// Get the coefficient of variation for expert utilization
144    pub fn utilization_cv(&self) -> f32 {
145        if self.expert_utilization.is_empty() {
146            return 0.0;
147        }
148
149        let mean =
150            self.expert_utilization.iter().sum::<f32>() / self.expert_utilization.len() as f32;
151        if mean <= 0.0 {
152            return 0.0;
153        }
154
155        let variance = self
156            .expert_utilization
157            .iter()
158            .map(|&util| {
159                let diff = util - mean;
160                diff * diff
161            })
162            .sum::<f32>()
163            / self.expert_utilization.len() as f32;
164
165        variance.sqrt() / mean
166    }
167
168    /// Get the most utilized expert
169    pub fn most_utilized_expert(&self) -> Option<(usize, f32)> {
170        self.expert_utilization
171            .iter()
172            .enumerate()
173            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
174            .map(|(idx, &util)| (idx, util))
175    }
176
177    /// Get the least utilized expert
178    pub fn least_utilized_expert(&self) -> Option<(usize, f32)> {
179        self.expert_utilization
180            .iter()
181            .enumerate()
182            .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
183            .map(|(idx, &util)| (idx, util))
184    }
185
186    /// Get utilization statistics summary
187    pub fn utilization_summary(&self) -> HashMap<String, f32> {
188        let mut summary = HashMap::new();
189
190        if !self.expert_utilization.is_empty() {
191            let mean =
192                self.expert_utilization.iter().sum::<f32>() / self.expert_utilization.len() as f32;
193            let min = self
194                .expert_utilization
195                .iter()
196                .copied()
197                .fold(f32::INFINITY, f32::min);
198            let max = self
199                .expert_utilization
200                .iter()
201                .copied()
202                .fold(f32::NEG_INFINITY, f32::max);
203
204            summary.insert("mean_utilization".to_string(), mean);
205            summary.insert("min_utilization".to_string(), min);
206            summary.insert("max_utilization".to_string(), max);
207            summary.insert("utilization_cv".to_string(), self.utilization_cv());
208        }
209
210        summary.insert("routing_efficiency".to_string(), self.routing_efficiency);
211        summary.insert(
212            "average_load_balance_loss".to_string(),
213            self.average_load_balance_loss,
214        );
215        summary.insert(
216            "average_router_z_loss".to_string(),
217            self.average_router_z_loss,
218        );
219
220        summary
221    }
222
223    /// Get recent load variance trend
224    pub fn recent_load_variance_trend(&self, window: usize) -> f32 {
225        if self.load_variance_history.len() < 2 {
226            return 0.0;
227        }
228
229        let start_idx = self.load_variance_history.len().saturating_sub(window);
230        let recent_variances = &self.load_variance_history[start_idx..];
231
232        if recent_variances.len() < 2 {
233            return 0.0;
234        }
235
236        // Simple linear trend calculation
237        let n = recent_variances.len() as f32;
238        let sum_x: f32 = (0..recent_variances.len()).map(|i| i as f32).sum();
239        let sum_y: f32 = recent_variances.iter().sum();
240        let sum_xy: f32 = recent_variances
241            .iter()
242            .enumerate()
243            .map(|(i, &y)| i as f32 * y)
244            .sum();
245        let sum_x2: f32 = (0..recent_variances.len())
246            .map(|i| (i as f32).powi(2))
247            .sum();
248
249        let denominator = n * sum_x2 - sum_x.powi(2);
250        if denominator.abs() < f32::EPSILON {
251            0.0
252        } else {
253            (n * sum_xy - sum_x * sum_y) / denominator
254        }
255    }
256
257    /// Reset all statistics
258    pub fn reset(&mut self) {
259        *self = Self::new();
260    }
261
262    /// Get throughput statistics
263    pub fn throughput_stats(&self) -> ThroughputStats {
264        ThroughputStats {
265            total_tokens: self.total_tokens,
266            total_routings: self.total_routings,
267            tokens_per_routing: if self.total_routings > 0 {
268                self.total_tokens as f32 / self.total_routings as f32
269            } else {
270                0.0
271            },
272            routing_efficiency: self.routing_efficiency,
273            average_latency: self.routing_latency_stats.average_latency(),
274        }
275    }
276}
277
278impl Default for RoutingStats {
279    fn default() -> Self {
280        Self::new()
281    }
282}
283
284/// Latency statistics for routing operations
285#[derive(Debug, Clone, Serialize, Deserialize)]
286pub struct LatencyStats {
287    /// Total latency measurements
288    pub total_measurements: u64,
289    /// Sum of all latencies in milliseconds
290    pub total_latency_ms: f64,
291    /// Minimum observed latency
292    pub min_latency_ms: f64,
293    /// Maximum observed latency
294    pub max_latency_ms: f64,
295    /// Recent latency measurements for percentile calculation
296    pub recent_latencies: Vec<f64>,
297}
298
299impl LatencyStats {
300    /// Create new latency statistics
301    pub fn new() -> Self {
302        Self {
303            total_measurements: 0,
304            total_latency_ms: 0.0,
305            min_latency_ms: f64::INFINITY,
306            max_latency_ms: 0.0,
307            recent_latencies: Vec::new(),
308        }
309    }
310
311    /// Record a latency measurement
312    pub fn record_latency(&mut self, latency: Duration) {
313        let latency_ms = latency.as_secs_f64() * 1000.0;
314
315        self.total_measurements += 1;
316        self.total_latency_ms += latency_ms;
317        self.min_latency_ms = self.min_latency_ms.min(latency_ms);
318        self.max_latency_ms = self.max_latency_ms.max(latency_ms);
319
320        self.recent_latencies.push(latency_ms);
321
322        // Keep only recent measurements for percentile calculation
323        if self.recent_latencies.len() > 1000 {
324            self.recent_latencies.remove(0);
325        }
326    }
327
328    /// Get average latency in milliseconds
329    pub fn average_latency(&self) -> f64 {
330        if self.total_measurements > 0 {
331            self.total_latency_ms / self.total_measurements as f64
332        } else {
333            0.0
334        }
335    }
336
337    /// Get latency percentile
338    pub fn percentile(&self, p: f64) -> f64 {
339        if self.recent_latencies.is_empty() {
340            return 0.0;
341        }
342
343        let mut sorted_latencies = self.recent_latencies.clone();
344        sorted_latencies.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
345
346        let index = ((p / 100.0) * (sorted_latencies.len() - 1) as f64) as usize;
347        sorted_latencies[index.min(sorted_latencies.len() - 1)]
348    }
349
350    /// Get 95th percentile latency
351    pub fn p95_latency(&self) -> f64 {
352        self.percentile(95.0)
353    }
354
355    /// Get 99th percentile latency
356    pub fn p99_latency(&self) -> f64 {
357        self.percentile(99.0)
358    }
359}
360
361impl Default for LatencyStats {
362    fn default() -> Self {
363        Self::new()
364    }
365}
366
367/// Capacity utilization statistics
368#[derive(Debug, Clone, Serialize, Deserialize)]
369pub struct CapacityStats {
370    /// Average capacity utilization across all experts
371    pub average_utilization: f32,
372    /// Peak capacity utilization observed
373    pub peak_utilization: f32,
374    /// Number of times capacity was exceeded
375    pub capacity_exceeded_count: u64,
376    /// Total available capacity across all experts
377    pub total_capacity: u64,
378    /// Total used capacity across all experts
379    pub total_used: u64,
380}
381
382impl CapacityStats {
383    /// Create new capacity statistics
384    pub fn new() -> Self {
385        Self {
386            average_utilization: 0.0,
387            peak_utilization: 0.0,
388            capacity_exceeded_count: 0,
389            total_capacity: 0,
390            total_used: 0,
391        }
392    }
393
394    /// Update capacity statistics with a routing decision
395    pub fn update(&mut self, routing_decision: &RoutingDecision) {
396        let current_utilization = routing_decision.expert_capacities.iter().sum::<usize>() as f32
397            / (routing_decision.expert_capacities.len() as f32 * 100.0); // Assuming capacity of 100 per expert
398
399        // Update average utilization
400        let alpha = 0.1; // Smoothing factor
401        self.average_utilization =
402            alpha * current_utilization + (1.0 - alpha) * self.average_utilization;
403
404        // Update peak utilization
405        self.peak_utilization = self.peak_utilization.max(current_utilization);
406
407        // Count capacity exceeded events
408        if routing_decision.tokens_dropped > 0 {
409            self.capacity_exceeded_count += 1;
410        }
411
412        // Update totals
413        self.total_used += routing_decision.expert_capacities.iter().sum::<usize>() as u64;
414        self.total_capacity += routing_decision.expert_capacities.len() as u64 * 100;
415        // Assuming capacity of 100 per expert
416    }
417
418    /// Get overall utilization percentage
419    pub fn overall_utilization(&self) -> f32 {
420        if self.total_capacity > 0 {
421            (self.total_used as f32 / self.total_capacity as f32) * 100.0
422        } else {
423            0.0
424        }
425    }
426}
427
428impl Default for CapacityStats {
429    fn default() -> Self {
430        Self::new()
431    }
432}
433
434/// Throughput statistics
435#[derive(Debug, Clone, Serialize, Deserialize)]
436pub struct ThroughputStats {
437    /// Total tokens processed
438    pub total_tokens: u64,
439    /// Total routing operations
440    pub total_routings: u64,
441    /// Average tokens per routing operation
442    pub tokens_per_routing: f32,
443    /// Routing efficiency percentage
444    pub routing_efficiency: f32,
445    /// Average routing latency in milliseconds
446    pub average_latency: f64,
447}
448
449impl ThroughputStats {
450    /// Calculate tokens per second
451    pub fn tokens_per_second(&self) -> f64 {
452        if self.average_latency > 0.0 {
453            (self.tokens_per_routing as f64 * 1000.0) / self.average_latency
454        } else {
455            0.0
456        }
457    }
458
459    /// Calculate routings per second
460    pub fn routings_per_second(&self) -> f64 {
461        if self.average_latency > 0.0 {
462            1000.0 / self.average_latency
463        } else {
464            0.0
465        }
466    }
467}
468
469/// Performance monitoring utilities
470pub mod monitoring {
471    use super::*;
472
473    /// Performance monitor for expert routing system
474    pub struct PerformanceMonitor {
475        stats: RoutingStats,
476        start_time: Instant,
477        last_report_time: Instant,
478        report_interval: Duration,
479    }
480
481    impl PerformanceMonitor {
482        /// Create a new performance monitor
483        pub fn new(report_interval: Duration) -> Self {
484            let now = Instant::now();
485            Self {
486                stats: RoutingStats::new(),
487                start_time: now,
488                last_report_time: now,
489                report_interval,
490            }
491        }
492
493        /// Record a routing decision
494        pub fn record_routing(&mut self, routing_decision: &RoutingDecision, latency: Duration) {
495            self.stats.record_routing(routing_decision);
496            self.stats.record_routing_latency(latency);
497
498            // Check if it's time for a report
499            if self.last_report_time.elapsed() >= self.report_interval {
500                self.print_report();
501                self.last_report_time = Instant::now();
502            }
503        }
504
505        /// Print performance report
506        pub fn print_report(&self) {
507            let uptime = self.start_time.elapsed();
508            let throughput = self.stats.throughput_stats();
509
510            info!("🔍 Expert Routing Performance Report");
511            info!("  Uptime: {:.2}s", uptime.as_secs_f64());
512            info!("  Total routings: {}", self.stats.total_routings);
513            info!("  Total tokens: {}", self.stats.total_tokens);
514            info!(
515                "  Routing efficiency: {:.2}%",
516                self.stats.routing_efficiency
517            );
518            info!("  Tokens/second: {:.2}", throughput.tokens_per_second());
519            info!(
520                "  Average latency: {:.2}ms",
521                self.stats.routing_latency_stats.average_latency()
522            );
523            info!(
524                "  P95 latency: {:.2}ms",
525                self.stats.routing_latency_stats.p95_latency()
526            );
527            info!("  Utilization CV: {:.3}", self.stats.utilization_cv());
528
529            if let Some((idx, util)) = self.stats.most_utilized_expert() {
530                info!("  Most utilized expert: {} ({:.2}%)", idx, util * 100.0);
531            }
532            if let Some((idx, util)) = self.stats.least_utilized_expert() {
533                info!("  Least utilized expert: {} ({:.2}%)", idx, util * 100.0);
534            }
535        }
536
537        /// Get current statistics
538        pub fn stats(&self) -> &RoutingStats {
539            &self.stats
540        }
541
542        /// Reset statistics
543        pub fn reset(&mut self) {
544            self.stats.reset();
545            self.start_time = Instant::now();
546            self.last_report_time = Instant::now();
547        }
548    }
549}
550
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expert_parallelism::router::{ExpertAssignment, RoutingDecision};

    /// A fresh tracker starts with every counter at zero.
    #[test]
    fn test_routing_stats_creation() {
        let fresh = RoutingStats::new();
        assert_eq!(fresh.routing_efficiency, 0.0);
        assert_eq!(fresh.total_tokens, 0);
        assert_eq!(fresh.total_routings, 0);
    }

    /// Recording one drop-free decision updates counters and sizes per-expert data.
    #[test]
    fn test_routing_stats_recording() {
        let decision = RoutingDecision {
            expert_assignments: vec![vec![ExpertAssignment::new(0, 0.8, 0, 0)]],
            expert_capacities: vec![5, 3, 2, 0],
            total_tokens: 10,
            tokens_dropped: 0,
            load_balance_loss: 0.1,
            router_z_loss: 0.05,
            auxiliary_loss: 0.15,
        };
        let mut tracker = RoutingStats::new();

        tracker.record_routing(&decision);

        assert_eq!(tracker.expert_utilization.len(), 4);
        assert_eq!(tracker.routing_efficiency, 100.0);
        assert_eq!(tracker.tokens_dropped, 0);
        assert_eq!(tracker.total_tokens, 10);
        assert_eq!(tracker.total_routings, 1);
    }

    /// Count, mean, min and max follow the recorded latencies.
    #[test]
    fn test_latency_stats() {
        let mut latencies = LatencyStats::new();
        for millis in [10, 20, 30] {
            latencies.record_latency(Duration::from_millis(millis));
        }

        assert_eq!(latencies.total_measurements, 3);
        assert_eq!(latencies.average_latency(), 20.0);
        assert_eq!(latencies.min_latency_ms, 10.0);
        assert_eq!(latencies.max_latency_ms, 30.0);
    }

    /// Uneven utilization yields a strictly positive coefficient of variation.
    #[test]
    fn test_utilization_cv() {
        let uneven = RoutingStats {
            expert_utilization: vec![0.1, 0.2, 0.3, 0.4],
            ..RoutingStats::new()
        };

        assert!(uneven.utilization_cv() > 0.0);
    }

    /// A partially used set of experts registers positive utilization.
    #[test]
    fn test_capacity_stats() {
        let decision = RoutingDecision {
            expert_assignments: vec![],
            expert_capacities: vec![50, 75, 25, 100],
            total_tokens: 250,
            tokens_dropped: 0,
            load_balance_loss: 0.0,
            router_z_loss: 0.0,
            auxiliary_loss: 0.0,
        };
        let mut capacity = CapacityStats::new();

        capacity.update(&decision);

        assert!(capacity.peak_utilization > 0.0);
        assert!(capacity.average_utilization > 0.0);
    }

    /// Rates scale per-routing amounts by 1000 / average latency in ms.
    #[test]
    fn test_throughput_stats() {
        let snapshot = ThroughputStats {
            total_tokens: 1000,
            total_routings: 10,
            tokens_per_routing: 100.0,
            routing_efficiency: 95.0,
            average_latency: 50.0,
        };

        // (100 tokens * 1000) / 50 ms and 1000 / 50 ms respectively.
        assert_eq!(snapshot.routings_per_second(), 20.0);
        assert_eq!(snapshot.tokens_per_second(), 2000.0);
    }

    /// The monitor forwards recorded routings into its statistics.
    #[test]
    fn test_performance_monitor() {
        let decision = RoutingDecision {
            expert_assignments: vec![],
            expert_capacities: vec![10, 20, 30],
            total_tokens: 60,
            tokens_dropped: 0,
            load_balance_loss: 0.1,
            router_z_loss: 0.05,
            auxiliary_loss: 0.15,
        };
        let mut monitor = monitoring::PerformanceMonitor::new(Duration::from_secs(1));

        monitor.record_routing(&decision, Duration::from_millis(25));

        let recorded = monitor.stats();
        assert_eq!(recorded.total_routings, 1);
        assert_eq!(recorded.total_tokens, 60);
    }
}