Skip to main content

scirs2_integrate/distributed/
load_balancing.rs

1//! Load balancing strategies for distributed integration
2//!
3//! This module provides various load balancing strategies for distributing
4//! work across compute nodes in the distributed integration system.
5
6use crate::common::IntegrateFloat;
7use crate::distributed::types::{
8    ChunkId, DistributedError, DistributedResult, JobId, LoadBalancingStrategy, NodeId, NodeInfo,
9    WorkChunk,
10};
11use scirs2_core::ndarray::Array1;
12use std::collections::{HashMap, VecDeque};
13use std::sync::atomic::{AtomicUsize, Ordering};
14use std::sync::{Arc, Mutex, RwLock};
15use std::time::{Duration, Instant};
16
17/// Load balancer for distributing work chunks across nodes
18pub struct LoadBalancer<F: IntegrateFloat> {
19    /// Current strategy
20    strategy: RwLock<LoadBalancingStrategy>,
21    /// Node performance history
22    node_performance: RwLock<HashMap<NodeId, NodePerformance>>,
23    /// Work assignment history
24    assignment_history: Mutex<VecDeque<Assignment>>,
25    /// Round-robin counter
26    round_robin_counter: AtomicUsize,
27    /// Configuration
28    config: LoadBalancerConfig,
29    /// Phantom for float type
30    _phantom: std::marker::PhantomData<F>,
31}
32
/// Tuning parameters for [`LoadBalancer`].
#[derive(Debug, Clone)]
pub struct LoadBalancerConfig {
    /// Upper bound on retained assignment records; the oldest are dropped first.
    pub max_history: usize,
    /// Completed-chunk samples (summed over all nodes) required before the
    /// adaptive strategy stops falling back to round-robin.
    pub min_samples_for_adaptation: usize,
    /// Intended EMA weight for per-node performance smoothing.
    ///
    /// NOTE(review): `NodePerformance::update` in this module applies its own
    /// fixed weight and does not read this value — confirm whether anything
    /// else consumes it.
    pub smoothing_factor: f64,
    /// Relative deviation of a node's load from the mean load above which the
    /// cluster is treated as imbalanced.
    pub imbalance_threshold: f64,
    /// Whether work stealing is enabled.
    ///
    /// NOTE(review): not read anywhere in this module.
    pub enable_work_stealing: bool,
    /// Work-stealing threshold (fraction of work to steal).
    ///
    /// NOTE(review): not read anywhere in this module.
    pub work_stealing_threshold: f64,
}

impl Default for LoadBalancerConfig {
    /// Defaults: 1000 history entries, 10 samples before adapting, 0.3 EMA
    /// weight, 0.3 imbalance threshold, work stealing on with a 0.5 threshold.
    fn default() -> Self {
        LoadBalancerConfig {
            max_history: 1_000,
            min_samples_for_adaptation: 10,
            smoothing_factor: 0.3,
            imbalance_threshold: 0.3,
            enable_work_stealing: true,
            work_stealing_threshold: 0.5,
        }
    }
}
62
63/// Performance metrics for a node
64#[derive(Debug, Clone)]
65pub struct NodePerformance {
66    /// Node ID
67    pub node_id: NodeId,
68    /// Average processing time per unit of estimated cost
69    pub avg_time_per_cost: f64,
70    /// Standard deviation of processing times
71    pub time_stddev: f64,
72    /// Number of chunks processed
73    pub chunks_processed: usize,
74    /// Total processing time
75    pub total_time: Duration,
76    /// Number of failures
77    pub failures: usize,
78    /// Success rate (0.0 to 1.0)
79    pub success_rate: f64,
80    /// Current load (number of pending chunks)
81    pub current_load: usize,
82    /// Recent processing times for variance calculation
83    recent_times: VecDeque<f64>,
84}
85
86impl NodePerformance {
87    /// Create new performance metrics
88    pub fn new(node_id: NodeId) -> Self {
89        Self {
90            node_id,
91            avg_time_per_cost: 1.0,
92            time_stddev: 0.0,
93            chunks_processed: 0,
94            total_time: Duration::ZERO,
95            failures: 0,
96            success_rate: 1.0,
97            current_load: 0,
98            recent_times: VecDeque::with_capacity(100),
99        }
100    }
101
102    /// Update performance with a new sample
103    pub fn update(&mut self, processing_time: Duration, estimated_cost: f64, success: bool) {
104        if success {
105            let time_per_cost = processing_time.as_secs_f64() / estimated_cost.max(0.001);
106
107            // Update EMA of time per cost
108            if self.chunks_processed == 0 {
109                self.avg_time_per_cost = time_per_cost;
110            } else {
111                let alpha = 0.3;
112                self.avg_time_per_cost =
113                    alpha * time_per_cost + (1.0 - alpha) * self.avg_time_per_cost;
114            }
115
116            // Track recent times for variance
117            self.recent_times.push_back(time_per_cost);
118            if self.recent_times.len() > 100 {
119                self.recent_times.pop_front();
120            }
121
122            // Update variance
123            if self.recent_times.len() >= 2 {
124                let mean: f64 =
125                    self.recent_times.iter().sum::<f64>() / self.recent_times.len() as f64;
126                let variance: f64 = self
127                    .recent_times
128                    .iter()
129                    .map(|t| (t - mean).powi(2))
130                    .sum::<f64>()
131                    / self.recent_times.len() as f64;
132                self.time_stddev = variance.sqrt();
133            }
134
135            self.chunks_processed += 1;
136            self.total_time += processing_time;
137        } else {
138            self.failures += 1;
139        }
140
141        // Update success rate
142        let total_attempts = self.chunks_processed + self.failures;
143        if total_attempts > 0 {
144            self.success_rate = self.chunks_processed as f64 / total_attempts as f64;
145        }
146    }
147
148    /// Get expected processing time for a given cost
149    pub fn expected_time(&self, estimated_cost: f64) -> Duration {
150        Duration::from_secs_f64(self.avg_time_per_cost * estimated_cost)
151    }
152
153    /// Calculate node score for assignment (higher is better)
154    pub fn assignment_score(&self, estimated_cost: f64) -> f64 {
155        // Factors: speed, reliability, current load
156        let speed_score = 1.0 / (self.avg_time_per_cost + 0.001);
157        let reliability_score = self.success_rate;
158        let load_penalty = 1.0 / (1.0 + self.current_load as f64);
159
160        speed_score * reliability_score * load_penalty
161    }
162}
163
164/// Record of a work assignment
165#[derive(Debug, Clone)]
166struct Assignment {
167    /// Chunk ID
168    chunk_id: ChunkId,
169    /// Assigned node
170    node_id: NodeId,
171    /// Timestamp
172    timestamp: Instant,
173    /// Estimated cost
174    estimated_cost: f64,
175}
176
177impl<F: IntegrateFloat> LoadBalancer<F> {
178    /// Create a new load balancer
179    pub fn new(strategy: LoadBalancingStrategy, config: LoadBalancerConfig) -> Self {
180        Self {
181            strategy: RwLock::new(strategy),
182            node_performance: RwLock::new(HashMap::new()),
183            assignment_history: Mutex::new(VecDeque::new()),
184            round_robin_counter: AtomicUsize::new(0),
185            config,
186            _phantom: std::marker::PhantomData,
187        }
188    }
189
190    /// Register a new node
191    pub fn register_node(&self, node_id: NodeId) -> DistributedResult<()> {
192        match self.node_performance.write() {
193            Ok(mut perf) => {
194                perf.insert(node_id, NodePerformance::new(node_id));
195                Ok(())
196            }
197            Err(_) => Err(DistributedError::ConfigError(
198                "Failed to register node".to_string(),
199            )),
200        }
201    }
202
203    /// Deregister a node
204    pub fn deregister_node(&self, node_id: NodeId) -> DistributedResult<()> {
205        match self.node_performance.write() {
206            Ok(mut perf) => {
207                perf.remove(&node_id);
208                Ok(())
209            }
210            Err(_) => Err(DistributedError::ConfigError(
211                "Failed to deregister node".to_string(),
212            )),
213        }
214    }
215
216    /// Get current strategy
217    pub fn get_strategy(&self) -> LoadBalancingStrategy {
218        match self.strategy.read() {
219            Ok(s) => *s,
220            Err(_) => LoadBalancingStrategy::RoundRobin,
221        }
222    }
223
224    /// Set strategy
225    pub fn set_strategy(&self, strategy: LoadBalancingStrategy) {
226        if let Ok(mut s) = self.strategy.write() {
227            *s = strategy;
228        }
229    }
230
231    /// Assign a work chunk to a node
232    pub fn assign_chunk(
233        &self,
234        chunk: &WorkChunk<F>,
235        available_nodes: &[NodeInfo],
236    ) -> DistributedResult<NodeId> {
237        if available_nodes.is_empty() {
238            return Err(DistributedError::ResourceExhausted(
239                "No available nodes".to_string(),
240            ));
241        }
242
243        let strategy = self.get_strategy();
244        let node_id = match strategy {
245            LoadBalancingStrategy::RoundRobin => self.round_robin_assignment(available_nodes)?,
246            LoadBalancingStrategy::CapabilityBased => {
247                self.capability_based_assignment(chunk, available_nodes)?
248            }
249            LoadBalancingStrategy::WorkStealing => {
250                self.work_stealing_assignment(chunk, available_nodes)?
251            }
252            LoadBalancingStrategy::Adaptive => self.adaptive_assignment(chunk, available_nodes)?,
253            LoadBalancingStrategy::LocalityAware => {
254                self.locality_aware_assignment(chunk, available_nodes)?
255            }
256        };
257
258        // Record assignment
259        self.record_assignment(chunk.id, node_id, chunk.estimated_cost);
260
261        // Update current load
262        if let Ok(mut perf) = self.node_performance.write() {
263            if let Some(p) = perf.get_mut(&node_id) {
264                p.current_load += 1;
265            }
266        }
267
268        Ok(node_id)
269    }
270
271    /// Round-robin assignment
272    fn round_robin_assignment(&self, nodes: &[NodeInfo]) -> DistributedResult<NodeId> {
273        let idx = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) % nodes.len();
274        Ok(nodes[idx].id)
275    }
276
277    /// Capability-based assignment
278    fn capability_based_assignment(
279        &self,
280        chunk: &WorkChunk<F>,
281        nodes: &[NodeInfo],
282    ) -> DistributedResult<NodeId> {
283        // Score nodes by capabilities
284        let best_node = nodes
285            .iter()
286            .max_by(|a, b| {
287                let score_a = Self::capability_score(a, chunk.estimated_cost);
288                let score_b = Self::capability_score(b, chunk.estimated_cost);
289                score_a
290                    .partial_cmp(&score_b)
291                    .unwrap_or(std::cmp::Ordering::Equal)
292            })
293            .ok_or_else(|| DistributedError::ResourceExhausted("No suitable node".to_string()))?;
294
295        Ok(best_node.id)
296    }
297
298    /// Calculate capability score for a node
299    fn capability_score(node: &NodeInfo, estimated_cost: f64) -> f64 {
300        let cpu_score = node.capabilities.cpu_cores as f64;
301        let memory_score = (node.capabilities.memory_bytes as f64 / 1e9).min(32.0) / 32.0;
302        let gpu_bonus = if node.capabilities.has_gpu { 5.0 } else { 0.0 };
303        let latency_penalty = (node.capabilities.latency_us as f64 / 10000.0).min(1.0);
304
305        (cpu_score + memory_score + gpu_bonus) * (1.0 - latency_penalty * 0.1)
306    }
307
308    /// Work-stealing-aware assignment
309    fn work_stealing_assignment(
310        &self,
311        chunk: &WorkChunk<F>,
312        nodes: &[NodeInfo],
313    ) -> DistributedResult<NodeId> {
314        // Find node with lowest current load, considering performance
315        match self.node_performance.read() {
316            Ok(perf) => {
317                let best_node = nodes
318                    .iter()
319                    .min_by(|a, b| {
320                        let load_a = perf.get(&a.id).map(|p| p.current_load).unwrap_or(0);
321                        let load_b = perf.get(&b.id).map(|p| p.current_load).unwrap_or(0);
322                        load_a.cmp(&load_b)
323                    })
324                    .ok_or_else(|| {
325                        DistributedError::ResourceExhausted("No suitable node".to_string())
326                    })?;
327
328                Ok(best_node.id)
329            }
330            Err(_) => self.round_robin_assignment(nodes),
331        }
332    }
333
334    /// Adaptive assignment based on performance history
335    fn adaptive_assignment(
336        &self,
337        chunk: &WorkChunk<F>,
338        nodes: &[NodeInfo],
339    ) -> DistributedResult<NodeId> {
340        match self.node_performance.read() {
341            Ok(perf) => {
342                // Check if we have enough samples for adaptation
343                let total_samples: usize = perf.values().map(|p| p.chunks_processed).sum();
344
345                if total_samples < self.config.min_samples_for_adaptation {
346                    // Not enough data, use round-robin
347                    return self.round_robin_assignment(nodes);
348                }
349
350                // Score each node
351                let best_node = nodes
352                    .iter()
353                    .max_by(|a, b| {
354                        let score_a = perf
355                            .get(&a.id)
356                            .map(|p| p.assignment_score(chunk.estimated_cost))
357                            .unwrap_or(0.0);
358                        let score_b = perf
359                            .get(&b.id)
360                            .map(|p| p.assignment_score(chunk.estimated_cost))
361                            .unwrap_or(0.0);
362                        score_a
363                            .partial_cmp(&score_b)
364                            .unwrap_or(std::cmp::Ordering::Equal)
365                    })
366                    .ok_or_else(|| {
367                        DistributedError::ResourceExhausted("No suitable node".to_string())
368                    })?;
369
370                Ok(best_node.id)
371            }
372            Err(_) => self.round_robin_assignment(nodes),
373        }
374    }
375
376    /// Locality-aware assignment (keeps related chunks together)
377    fn locality_aware_assignment(
378        &self,
379        chunk: &WorkChunk<F>,
380        nodes: &[NodeInfo],
381    ) -> DistributedResult<NodeId> {
382        // For now, use job ID modulo to keep related chunks on same nodes
383        let job_mod = chunk.job_id.value() as usize % nodes.len();
384        let chunk_mod = chunk.id.value() as usize % nodes.len();
385
386        // Combine job and chunk locality
387        let idx = (job_mod + chunk_mod) % nodes.len();
388        Ok(nodes[idx].id)
389    }
390
391    /// Record an assignment
392    fn record_assignment(&self, chunk_id: ChunkId, node_id: NodeId, estimated_cost: f64) {
393        if let Ok(mut history) = self.assignment_history.lock() {
394            history.push_back(Assignment {
395                chunk_id,
396                node_id,
397                timestamp: Instant::now(),
398                estimated_cost,
399            });
400
401            // Trim history
402            while history.len() > self.config.max_history {
403                history.pop_front();
404            }
405        }
406    }
407
408    /// Report chunk completion
409    pub fn report_completion(
410        &self,
411        node_id: NodeId,
412        estimated_cost: f64,
413        processing_time: Duration,
414        success: bool,
415    ) {
416        if let Ok(mut perf) = self.node_performance.write() {
417            if let Some(p) = perf.get_mut(&node_id) {
418                p.update(processing_time, estimated_cost, success);
419                if p.current_load > 0 {
420                    p.current_load -= 1;
421                }
422            }
423        }
424    }
425
426    /// Get current load distribution
427    pub fn get_load_distribution(&self) -> HashMap<NodeId, usize> {
428        match self.node_performance.read() {
429            Ok(perf) => perf.iter().map(|(id, p)| (*id, p.current_load)).collect(),
430            Err(_) => HashMap::new(),
431        }
432    }
433
434    /// Check if rebalancing is needed
435    pub fn needs_rebalancing(&self) -> bool {
436        match self.node_performance.read() {
437            Ok(perf) => {
438                if perf.is_empty() {
439                    return false;
440                }
441
442                let loads: Vec<f64> = perf.values().map(|p| p.current_load as f64).collect();
443
444                if loads.is_empty() {
445                    return false;
446                }
447
448                let mean = loads.iter().sum::<f64>() / loads.len() as f64;
449                if mean <= 0.0 {
450                    return false;
451                }
452
453                let max_deviation = loads
454                    .iter()
455                    .map(|l| (l - mean).abs() / mean)
456                    .fold(0.0_f64, f64::max);
457
458                max_deviation > self.config.imbalance_threshold
459            }
460            Err(_) => false,
461        }
462    }
463
464    /// Get nodes with excess work (candidates for work stealing)
465    pub fn get_overloaded_nodes(&self) -> Vec<(NodeId, usize)> {
466        match self.node_performance.read() {
467            Ok(perf) => {
468                let loads: Vec<_> = perf.iter().map(|(id, p)| (*id, p.current_load)).collect();
469
470                if loads.is_empty() {
471                    return Vec::new();
472                }
473
474                let mean_load: f64 =
475                    loads.iter().map(|(_, l)| *l as f64).sum::<f64>() / loads.len() as f64;
476                let threshold = (mean_load * (1.0 + self.config.imbalance_threshold)) as usize;
477
478                loads
479                    .into_iter()
480                    .filter(|(_, load)| *load > threshold)
481                    .collect()
482            }
483            Err(_) => Vec::new(),
484        }
485    }
486
487    /// Get nodes with room for more work
488    pub fn get_underloaded_nodes(&self) -> Vec<(NodeId, usize)> {
489        match self.node_performance.read() {
490            Ok(perf) => {
491                let loads: Vec<_> = perf.iter().map(|(id, p)| (*id, p.current_load)).collect();
492
493                if loads.is_empty() {
494                    return Vec::new();
495                }
496
497                let mean_load: f64 =
498                    loads.iter().map(|(_, l)| *l as f64).sum::<f64>() / loads.len() as f64;
499                let threshold = (mean_load * (1.0 - self.config.imbalance_threshold)) as usize;
500
501                loads
502                    .into_iter()
503                    .filter(|(_, load)| *load < threshold)
504                    .collect()
505            }
506            Err(_) => Vec::new(),
507        }
508    }
509
510    /// Get performance statistics
511    pub fn get_statistics(&self) -> LoadBalancerStatistics {
512        match self.node_performance.read() {
513            Ok(perf) => {
514                let node_count = perf.len();
515                let total_chunks: usize = perf.values().map(|p| p.chunks_processed).sum();
516                let total_failures: usize = perf.values().map(|p| p.failures).sum();
517
518                let loads: Vec<f64> = perf.values().map(|p| p.current_load as f64).collect();
519                let load_variance = if !loads.is_empty() {
520                    let mean = loads.iter().sum::<f64>() / loads.len() as f64;
521                    loads.iter().map(|l| (l - mean).powi(2)).sum::<f64>() / loads.len() as f64
522                } else {
523                    0.0
524                };
525
526                LoadBalancerStatistics {
527                    node_count,
528                    total_chunks_assigned: total_chunks,
529                    total_failures,
530                    load_variance,
531                    current_strategy: self.get_strategy(),
532                }
533            }
534            Err(_) => LoadBalancerStatistics::default(),
535        }
536    }
537}
538
539/// Statistics about load balancer performance
540#[derive(Debug, Clone, Default)]
541pub struct LoadBalancerStatistics {
542    /// Number of registered nodes
543    pub node_count: usize,
544    /// Total chunks assigned
545    pub total_chunks_assigned: usize,
546    /// Total failures
547    pub total_failures: usize,
548    /// Current load variance
549    pub load_variance: f64,
550    /// Current strategy
551    pub current_strategy: LoadBalancingStrategy,
552}
553
554#[allow(clippy::derivable_impls)]
555impl Default for LoadBalancingStrategy {
556    fn default() -> Self {
557        Self::Adaptive
558    }
559}
560
561/// Work chunk distributor for initial distribution
562pub struct ChunkDistributor<F: IntegrateFloat> {
563    /// Job ID
564    job_id: JobId,
565    /// Next chunk ID
566    next_chunk_id: AtomicUsize,
567    /// Phantom for float type
568    _phantom: std::marker::PhantomData<F>,
569}
570
571impl<F: IntegrateFloat> ChunkDistributor<F> {
572    /// Create a new chunk distributor
573    pub fn new(job_id: JobId) -> Self {
574        Self {
575            job_id,
576            next_chunk_id: AtomicUsize::new(0),
577            _phantom: std::marker::PhantomData,
578        }
579    }
580
581    /// Create work chunks from a time interval
582    pub fn create_chunks(
583        &self,
584        t_span: (F, F),
585        initial_state: Array1<F>,
586        num_chunks: usize,
587    ) -> Vec<WorkChunk<F>> {
588        let t_start = t_span.0;
589        let t_end = t_span.1;
590        let dt = (t_end - t_start) / F::from(num_chunks).unwrap_or(F::one());
591
592        let mut chunks = Vec::with_capacity(num_chunks);
593
594        for i in 0..num_chunks {
595            let chunk_t_start = t_start + dt * F::from(i).unwrap_or(F::zero());
596            let chunk_t_end = if i == num_chunks - 1 {
597                t_end
598            } else {
599                t_start + dt * F::from(i + 1).unwrap_or(F::one())
600            };
601
602            let chunk_id = ChunkId::new(self.next_chunk_id.fetch_add(1, Ordering::SeqCst) as u64);
603
604            // Initial state for first chunk, placeholder for others
605            // (will be filled in by boundary exchange)
606            let state = if i == 0 {
607                initial_state.clone()
608            } else {
609                Array1::zeros(initial_state.len())
610            };
611
612            chunks.push(WorkChunk::new(
613                chunk_id,
614                self.job_id,
615                (chunk_t_start, chunk_t_end),
616                state,
617            ));
618        }
619
620        chunks
621    }
622
623    /// Subdivide a chunk into smaller chunks
624    pub fn subdivide_chunk(&self, chunk: &WorkChunk<F>, num_parts: usize) -> Vec<WorkChunk<F>> {
625        let (t_start, t_end) = chunk.time_interval;
626        let dt = (t_end - t_start) / F::from(num_parts).unwrap_or(F::one());
627
628        let mut sub_chunks = Vec::with_capacity(num_parts);
629
630        for i in 0..num_parts {
631            let sub_t_start = t_start + dt * F::from(i).unwrap_or(F::zero());
632            let sub_t_end = if i == num_parts - 1 {
633                t_end
634            } else {
635                t_start + dt * F::from(i + 1).unwrap_or(F::one())
636            };
637
638            let sub_chunk_id =
639                ChunkId::new(self.next_chunk_id.fetch_add(1, Ordering::SeqCst) as u64);
640
641            let state = if i == 0 {
642                chunk.initial_state.clone()
643            } else {
644                Array1::zeros(chunk.initial_state.len())
645            };
646
647            let mut sub_chunk =
648                WorkChunk::new(sub_chunk_id, chunk.job_id, (sub_t_start, sub_t_end), state);
649
650            sub_chunk.priority = chunk.priority;
651            sub_chunks.push(sub_chunk);
652        }
653
654        sub_chunks
655    }
656}
657
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributed::types::NodeCapabilities;
    use std::collections::HashSet;
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};

    /// Build `n` localhost nodes with ids `0..n` and ports `8080..8080+n`.
    fn create_test_nodes(n: usize) -> Vec<NodeInfo> {
        (0..n)
            .map(|i| {
                let addr =
                    SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080 + i as u16);
                let mut info = NodeInfo::new(NodeId::new(i as u64), addr);
                info.capabilities = NodeCapabilities::default();
                info
            })
            .collect()
    }

    #[test]
    fn test_round_robin_assignment() {
        let balancer: LoadBalancer<f64> = LoadBalancer::new(
            LoadBalancingStrategy::RoundRobin,
            LoadBalancerConfig::default(),
        );

        let nodes = create_test_nodes(3);
        for node in &nodes {
            balancer.register_node(node.id).expect("Failed to register");
        }

        let chunk = WorkChunk::new(ChunkId::new(1), JobId::new(1), (0.0, 1.0), Array1::zeros(3));

        let assignments: Vec<NodeId> = (0..6)
            .map(|_| {
                balancer
                    .assign_chunk(&chunk, &nodes)
                    .expect("Assignment failed")
            })
            .collect();

        // The first cycle must visit every node exactly once...
        let first_cycle: HashSet<NodeId> = assignments[..3].iter().copied().collect();
        assert_eq!(first_cycle.len(), 3);
        // ...and the second cycle must repeat it in the same order.
        assert_eq!(&assignments[..3], &assignments[3..]);
    }

    #[test]
    fn test_performance_update() {
        let mut perf = NodePerformance::new(NodeId::new(1));

        perf.update(Duration::from_millis(100), 1.0, true);
        assert_eq!(perf.chunks_processed, 1);
        assert!((perf.success_rate - 1.0).abs() < 1e-12);
        // The first successful sample seeds the average directly.
        assert!((perf.avg_time_per_cost - 0.1).abs() < 1e-12);

        perf.update(Duration::from_millis(50), 1.0, false);
        assert_eq!(perf.failures, 1);
        // One success out of two attempts.
        assert!((perf.success_rate - 0.5).abs() < 1e-12);
        // Failures must not feed the timing estimate.
        assert!((perf.avg_time_per_cost - 0.1).abs() < 1e-12);
    }

    #[test]
    fn test_chunk_distributor() {
        let distributor: ChunkDistributor<f64> = ChunkDistributor::new(JobId::new(1));

        let chunks = distributor.create_chunks((0.0, 10.0), Array1::from_vec(vec![1.0, 2.0]), 5);

        assert_eq!(chunks.len(), 5);
        assert!((chunks[0].time_interval.0 - 0.0).abs() < 1e-10);
        assert!((chunks[4].time_interval.1 - 10.0).abs() < 1e-10);

        // Chunks must tile the span without gaps or overlaps.
        for pair in chunks.windows(2) {
            assert_eq!(pair[0].time_interval.1, pair[1].time_interval.0);
        }

        // Only the first chunk carries the real initial state.
        assert_eq!(chunks[0].initial_state, Array1::from_vec(vec![1.0, 2.0]));
        assert!(chunks[1..]
            .iter()
            .all(|c| c.initial_state.len() == 2 && c.initial_state.iter().all(|&x| x == 0.0)));
    }

    #[test]
    fn test_subdivide_chunk() {
        let distributor: ChunkDistributor<f64> = ChunkDistributor::new(JobId::new(1));
        let parent = WorkChunk::new(
            ChunkId::new(99),
            JobId::new(1),
            (0.0, 4.0),
            Array1::from_vec(vec![3.0]),
        );

        let parts = distributor.subdivide_chunk(&parent, 4);

        assert_eq!(parts.len(), 4);
        assert!((parts[0].time_interval.0 - 0.0).abs() < 1e-12);
        assert!((parts[3].time_interval.1 - 4.0).abs() < 1e-12);
        for pair in parts.windows(2) {
            assert_eq!(pair[0].time_interval.1, pair[1].time_interval.0);
        }
        assert_eq!(parts[0].initial_state, parent.initial_state);
    }

    #[test]
    fn test_load_distribution() {
        let balancer: LoadBalancer<f64> = LoadBalancer::new(
            LoadBalancingStrategy::Adaptive,
            LoadBalancerConfig::default(),
        );

        let nodes = create_test_nodes(3);
        for node in &nodes {
            balancer.register_node(node.id).expect("Failed to register");
        }

        for i in 0..10 {
            let chunk =
                WorkChunk::new(ChunkId::new(i), JobId::new(1), (0.0, 1.0), Array1::zeros(3));
            balancer
                .assign_chunk(&chunk, &nodes)
                .expect("Assignment failed");
        }

        let distribution = balancer.get_load_distribution();
        assert_eq!(distribution.len(), 3);
        // Every assignment to a registered node bumps exactly one pending load.
        let total: usize = distribution.values().sum();
        assert_eq!(total, 10);
        // Too few samples to adapt, so round-robin spreads the work as 4/3/3,
        // which is within the default imbalance threshold.
        assert!(!balancer.needs_rebalancing());
    }
}