1use crate::common::IntegrateFloat;
7use crate::distributed::types::{
8 ChunkId, DistributedError, DistributedResult, JobId, LoadBalancingStrategy, NodeId, NodeInfo,
9 WorkChunk,
10};
11use scirs2_core::ndarray::Array1;
12use std::collections::{HashMap, VecDeque};
13use std::sync::atomic::{AtomicUsize, Ordering};
14use std::sync::{Arc, Mutex, RwLock};
15use std::time::{Duration, Instant};
16
/// Distributes work chunks across compute nodes according to a configurable
/// load-balancing strategy, tracking per-node performance to inform decisions.
pub struct LoadBalancer<F: IntegrateFloat> {
    /// Currently active assignment strategy (swappable at runtime via `set_strategy`).
    strategy: RwLock<LoadBalancingStrategy>,
    /// Per-node performance statistics, keyed by node id.
    node_performance: RwLock<HashMap<NodeId, NodePerformance>>,
    /// Recent chunk-to-node assignments, bounded by `config.max_history`.
    assignment_history: Mutex<VecDeque<Assignment>>,
    /// Monotonic counter driving the round-robin strategy.
    round_robin_counter: AtomicUsize,
    /// Tuning parameters for balancing behavior.
    config: LoadBalancerConfig,
    /// Ties the balancer to the scalar type `F` of its work chunks.
    _phantom: std::marker::PhantomData<F>,
}
32
/// Tuning parameters for [`LoadBalancer`].
#[derive(Debug, Clone)]
pub struct LoadBalancerConfig {
    /// Maximum number of assignments retained in the history buffer.
    pub max_history: usize,
    /// Minimum completed chunks (cluster-wide) before the adaptive strategy
    /// trusts its performance statistics instead of falling back to round-robin.
    pub min_samples_for_adaptation: usize,
    /// Exponential-smoothing factor in (0, 1).
    /// NOTE(review): not read by `NodePerformance::update`, which hard-codes
    /// its own alpha — confirm whether they should be wired together.
    pub smoothing_factor: f64,
    /// Relative load deviation from the mean above which rebalancing is flagged.
    pub imbalance_threshold: f64,
    /// Whether work stealing between nodes is permitted.
    /// NOTE(review): not referenced in this file — presumably consumed elsewhere.
    pub enable_work_stealing: bool,
    /// Load-ratio threshold that triggers work stealing.
    /// NOTE(review): not referenced in this file — presumably consumed elsewhere.
    pub work_stealing_threshold: f64,
}
49
50impl Default for LoadBalancerConfig {
51 fn default() -> Self {
52 Self {
53 max_history: 1000,
54 min_samples_for_adaptation: 10,
55 smoothing_factor: 0.3,
56 imbalance_threshold: 0.3,
57 enable_work_stealing: true,
58 work_stealing_threshold: 0.5,
59 }
60 }
61}
62
/// Rolling performance statistics for a single compute node.
#[derive(Debug, Clone)]
pub struct NodePerformance {
    /// Node these statistics describe.
    pub node_id: NodeId,
    /// Exponential moving average of seconds spent per unit of estimated cost
    /// (lower is faster); seeded optimistically at 1.0.
    pub avg_time_per_cost: f64,
    /// Standard deviation of the recent per-cost timings window.
    pub time_stddev: f64,
    /// Successfully completed chunks.
    pub chunks_processed: usize,
    /// Total processing time accumulated over successful chunks.
    pub total_time: Duration,
    /// Failed chunk attempts.
    pub failures: usize,
    /// chunks_processed / (chunks_processed + failures); starts at 1.0.
    pub success_rate: f64,
    /// Chunks currently assigned to this node and not yet reported complete.
    pub current_load: usize,
    /// Bounded window of recent per-cost timings used for the stddev estimate.
    recent_times: VecDeque<f64>,
}
85
86impl NodePerformance {
87 pub fn new(node_id: NodeId) -> Self {
89 Self {
90 node_id,
91 avg_time_per_cost: 1.0,
92 time_stddev: 0.0,
93 chunks_processed: 0,
94 total_time: Duration::ZERO,
95 failures: 0,
96 success_rate: 1.0,
97 current_load: 0,
98 recent_times: VecDeque::with_capacity(100),
99 }
100 }
101
102 pub fn update(&mut self, processing_time: Duration, estimated_cost: f64, success: bool) {
104 if success {
105 let time_per_cost = processing_time.as_secs_f64() / estimated_cost.max(0.001);
106
107 if self.chunks_processed == 0 {
109 self.avg_time_per_cost = time_per_cost;
110 } else {
111 let alpha = 0.3;
112 self.avg_time_per_cost =
113 alpha * time_per_cost + (1.0 - alpha) * self.avg_time_per_cost;
114 }
115
116 self.recent_times.push_back(time_per_cost);
118 if self.recent_times.len() > 100 {
119 self.recent_times.pop_front();
120 }
121
122 if self.recent_times.len() >= 2 {
124 let mean: f64 =
125 self.recent_times.iter().sum::<f64>() / self.recent_times.len() as f64;
126 let variance: f64 = self
127 .recent_times
128 .iter()
129 .map(|t| (t - mean).powi(2))
130 .sum::<f64>()
131 / self.recent_times.len() as f64;
132 self.time_stddev = variance.sqrt();
133 }
134
135 self.chunks_processed += 1;
136 self.total_time += processing_time;
137 } else {
138 self.failures += 1;
139 }
140
141 let total_attempts = self.chunks_processed + self.failures;
143 if total_attempts > 0 {
144 self.success_rate = self.chunks_processed as f64 / total_attempts as f64;
145 }
146 }
147
148 pub fn expected_time(&self, estimated_cost: f64) -> Duration {
150 Duration::from_secs_f64(self.avg_time_per_cost * estimated_cost)
151 }
152
153 pub fn assignment_score(&self, estimated_cost: f64) -> f64 {
155 let speed_score = 1.0 / (self.avg_time_per_cost + 0.001);
157 let reliability_score = self.success_rate;
158 let load_penalty = 1.0 / (1.0 + self.current_load as f64);
159
160 speed_score * reliability_score * load_penalty
161 }
162}
163
/// One recorded chunk-to-node assignment, kept in the balancer's bounded history.
#[derive(Debug, Clone)]
struct Assignment {
    // Chunk that was assigned.
    chunk_id: ChunkId,
    // Node that received it.
    node_id: NodeId,
    // When the assignment was made.
    timestamp: Instant,
    // Cost estimate at assignment time.
    estimated_cost: f64,
}
176
177impl<F: IntegrateFloat> LoadBalancer<F> {
178 pub fn new(strategy: LoadBalancingStrategy, config: LoadBalancerConfig) -> Self {
180 Self {
181 strategy: RwLock::new(strategy),
182 node_performance: RwLock::new(HashMap::new()),
183 assignment_history: Mutex::new(VecDeque::new()),
184 round_robin_counter: AtomicUsize::new(0),
185 config,
186 _phantom: std::marker::PhantomData,
187 }
188 }
189
190 pub fn register_node(&self, node_id: NodeId) -> DistributedResult<()> {
192 match self.node_performance.write() {
193 Ok(mut perf) => {
194 perf.insert(node_id, NodePerformance::new(node_id));
195 Ok(())
196 }
197 Err(_) => Err(DistributedError::ConfigError(
198 "Failed to register node".to_string(),
199 )),
200 }
201 }
202
203 pub fn deregister_node(&self, node_id: NodeId) -> DistributedResult<()> {
205 match self.node_performance.write() {
206 Ok(mut perf) => {
207 perf.remove(&node_id);
208 Ok(())
209 }
210 Err(_) => Err(DistributedError::ConfigError(
211 "Failed to deregister node".to_string(),
212 )),
213 }
214 }
215
216 pub fn get_strategy(&self) -> LoadBalancingStrategy {
218 match self.strategy.read() {
219 Ok(s) => *s,
220 Err(_) => LoadBalancingStrategy::RoundRobin,
221 }
222 }
223
224 pub fn set_strategy(&self, strategy: LoadBalancingStrategy) {
226 if let Ok(mut s) = self.strategy.write() {
227 *s = strategy;
228 }
229 }
230
231 pub fn assign_chunk(
233 &self,
234 chunk: &WorkChunk<F>,
235 available_nodes: &[NodeInfo],
236 ) -> DistributedResult<NodeId> {
237 if available_nodes.is_empty() {
238 return Err(DistributedError::ResourceExhausted(
239 "No available nodes".to_string(),
240 ));
241 }
242
243 let strategy = self.get_strategy();
244 let node_id = match strategy {
245 LoadBalancingStrategy::RoundRobin => self.round_robin_assignment(available_nodes)?,
246 LoadBalancingStrategy::CapabilityBased => {
247 self.capability_based_assignment(chunk, available_nodes)?
248 }
249 LoadBalancingStrategy::WorkStealing => {
250 self.work_stealing_assignment(chunk, available_nodes)?
251 }
252 LoadBalancingStrategy::Adaptive => self.adaptive_assignment(chunk, available_nodes)?,
253 LoadBalancingStrategy::LocalityAware => {
254 self.locality_aware_assignment(chunk, available_nodes)?
255 }
256 };
257
258 self.record_assignment(chunk.id, node_id, chunk.estimated_cost);
260
261 if let Ok(mut perf) = self.node_performance.write() {
263 if let Some(p) = perf.get_mut(&node_id) {
264 p.current_load += 1;
265 }
266 }
267
268 Ok(node_id)
269 }
270
271 fn round_robin_assignment(&self, nodes: &[NodeInfo]) -> DistributedResult<NodeId> {
273 let idx = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) % nodes.len();
274 Ok(nodes[idx].id)
275 }
276
277 fn capability_based_assignment(
279 &self,
280 chunk: &WorkChunk<F>,
281 nodes: &[NodeInfo],
282 ) -> DistributedResult<NodeId> {
283 let best_node = nodes
285 .iter()
286 .max_by(|a, b| {
287 let score_a = Self::capability_score(a, chunk.estimated_cost);
288 let score_b = Self::capability_score(b, chunk.estimated_cost);
289 score_a
290 .partial_cmp(&score_b)
291 .unwrap_or(std::cmp::Ordering::Equal)
292 })
293 .ok_or_else(|| DistributedError::ResourceExhausted("No suitable node".to_string()))?;
294
295 Ok(best_node.id)
296 }
297
298 fn capability_score(node: &NodeInfo, estimated_cost: f64) -> f64 {
300 let cpu_score = node.capabilities.cpu_cores as f64;
301 let memory_score = (node.capabilities.memory_bytes as f64 / 1e9).min(32.0) / 32.0;
302 let gpu_bonus = if node.capabilities.has_gpu { 5.0 } else { 0.0 };
303 let latency_penalty = (node.capabilities.latency_us as f64 / 10000.0).min(1.0);
304
305 (cpu_score + memory_score + gpu_bonus) * (1.0 - latency_penalty * 0.1)
306 }
307
308 fn work_stealing_assignment(
310 &self,
311 chunk: &WorkChunk<F>,
312 nodes: &[NodeInfo],
313 ) -> DistributedResult<NodeId> {
314 match self.node_performance.read() {
316 Ok(perf) => {
317 let best_node = nodes
318 .iter()
319 .min_by(|a, b| {
320 let load_a = perf.get(&a.id).map(|p| p.current_load).unwrap_or(0);
321 let load_b = perf.get(&b.id).map(|p| p.current_load).unwrap_or(0);
322 load_a.cmp(&load_b)
323 })
324 .ok_or_else(|| {
325 DistributedError::ResourceExhausted("No suitable node".to_string())
326 })?;
327
328 Ok(best_node.id)
329 }
330 Err(_) => self.round_robin_assignment(nodes),
331 }
332 }
333
334 fn adaptive_assignment(
336 &self,
337 chunk: &WorkChunk<F>,
338 nodes: &[NodeInfo],
339 ) -> DistributedResult<NodeId> {
340 match self.node_performance.read() {
341 Ok(perf) => {
342 let total_samples: usize = perf.values().map(|p| p.chunks_processed).sum();
344
345 if total_samples < self.config.min_samples_for_adaptation {
346 return self.round_robin_assignment(nodes);
348 }
349
350 let best_node = nodes
352 .iter()
353 .max_by(|a, b| {
354 let score_a = perf
355 .get(&a.id)
356 .map(|p| p.assignment_score(chunk.estimated_cost))
357 .unwrap_or(0.0);
358 let score_b = perf
359 .get(&b.id)
360 .map(|p| p.assignment_score(chunk.estimated_cost))
361 .unwrap_or(0.0);
362 score_a
363 .partial_cmp(&score_b)
364 .unwrap_or(std::cmp::Ordering::Equal)
365 })
366 .ok_or_else(|| {
367 DistributedError::ResourceExhausted("No suitable node".to_string())
368 })?;
369
370 Ok(best_node.id)
371 }
372 Err(_) => self.round_robin_assignment(nodes),
373 }
374 }
375
376 fn locality_aware_assignment(
378 &self,
379 chunk: &WorkChunk<F>,
380 nodes: &[NodeInfo],
381 ) -> DistributedResult<NodeId> {
382 let job_mod = chunk.job_id.value() as usize % nodes.len();
384 let chunk_mod = chunk.id.value() as usize % nodes.len();
385
386 let idx = (job_mod + chunk_mod) % nodes.len();
388 Ok(nodes[idx].id)
389 }
390
391 fn record_assignment(&self, chunk_id: ChunkId, node_id: NodeId, estimated_cost: f64) {
393 if let Ok(mut history) = self.assignment_history.lock() {
394 history.push_back(Assignment {
395 chunk_id,
396 node_id,
397 timestamp: Instant::now(),
398 estimated_cost,
399 });
400
401 while history.len() > self.config.max_history {
403 history.pop_front();
404 }
405 }
406 }
407
408 pub fn report_completion(
410 &self,
411 node_id: NodeId,
412 estimated_cost: f64,
413 processing_time: Duration,
414 success: bool,
415 ) {
416 if let Ok(mut perf) = self.node_performance.write() {
417 if let Some(p) = perf.get_mut(&node_id) {
418 p.update(processing_time, estimated_cost, success);
419 if p.current_load > 0 {
420 p.current_load -= 1;
421 }
422 }
423 }
424 }
425
426 pub fn get_load_distribution(&self) -> HashMap<NodeId, usize> {
428 match self.node_performance.read() {
429 Ok(perf) => perf.iter().map(|(id, p)| (*id, p.current_load)).collect(),
430 Err(_) => HashMap::new(),
431 }
432 }
433
434 pub fn needs_rebalancing(&self) -> bool {
436 match self.node_performance.read() {
437 Ok(perf) => {
438 if perf.is_empty() {
439 return false;
440 }
441
442 let loads: Vec<f64> = perf.values().map(|p| p.current_load as f64).collect();
443
444 if loads.is_empty() {
445 return false;
446 }
447
448 let mean = loads.iter().sum::<f64>() / loads.len() as f64;
449 if mean <= 0.0 {
450 return false;
451 }
452
453 let max_deviation = loads
454 .iter()
455 .map(|l| (l - mean).abs() / mean)
456 .fold(0.0_f64, f64::max);
457
458 max_deviation > self.config.imbalance_threshold
459 }
460 Err(_) => false,
461 }
462 }
463
464 pub fn get_overloaded_nodes(&self) -> Vec<(NodeId, usize)> {
466 match self.node_performance.read() {
467 Ok(perf) => {
468 let loads: Vec<_> = perf.iter().map(|(id, p)| (*id, p.current_load)).collect();
469
470 if loads.is_empty() {
471 return Vec::new();
472 }
473
474 let mean_load: f64 =
475 loads.iter().map(|(_, l)| *l as f64).sum::<f64>() / loads.len() as f64;
476 let threshold = (mean_load * (1.0 + self.config.imbalance_threshold)) as usize;
477
478 loads
479 .into_iter()
480 .filter(|(_, load)| *load > threshold)
481 .collect()
482 }
483 Err(_) => Vec::new(),
484 }
485 }
486
487 pub fn get_underloaded_nodes(&self) -> Vec<(NodeId, usize)> {
489 match self.node_performance.read() {
490 Ok(perf) => {
491 let loads: Vec<_> = perf.iter().map(|(id, p)| (*id, p.current_load)).collect();
492
493 if loads.is_empty() {
494 return Vec::new();
495 }
496
497 let mean_load: f64 =
498 loads.iter().map(|(_, l)| *l as f64).sum::<f64>() / loads.len() as f64;
499 let threshold = (mean_load * (1.0 - self.config.imbalance_threshold)) as usize;
500
501 loads
502 .into_iter()
503 .filter(|(_, load)| *load < threshold)
504 .collect()
505 }
506 Err(_) => Vec::new(),
507 }
508 }
509
510 pub fn get_statistics(&self) -> LoadBalancerStatistics {
512 match self.node_performance.read() {
513 Ok(perf) => {
514 let node_count = perf.len();
515 let total_chunks: usize = perf.values().map(|p| p.chunks_processed).sum();
516 let total_failures: usize = perf.values().map(|p| p.failures).sum();
517
518 let loads: Vec<f64> = perf.values().map(|p| p.current_load as f64).collect();
519 let load_variance = if !loads.is_empty() {
520 let mean = loads.iter().sum::<f64>() / loads.len() as f64;
521 loads.iter().map(|l| (l - mean).powi(2)).sum::<f64>() / loads.len() as f64
522 } else {
523 0.0
524 };
525
526 LoadBalancerStatistics {
527 node_count,
528 total_chunks_assigned: total_chunks,
529 total_failures,
530 load_variance,
531 current_strategy: self.get_strategy(),
532 }
533 }
534 Err(_) => LoadBalancerStatistics::default(),
535 }
536 }
537}
538
/// Aggregate snapshot of balancer state, produced by `LoadBalancer::get_statistics`.
#[derive(Debug, Clone, Default)]
pub struct LoadBalancerStatistics {
    /// Number of nodes with a performance record.
    pub node_count: usize,
    /// Sum of successfully completed chunks across all nodes
    /// (despite the name, failed attempts are not included).
    pub total_chunks_assigned: usize,
    /// Sum of failed chunk attempts across all nodes.
    pub total_failures: usize,
    /// Population variance of the nodes' current in-flight loads.
    pub load_variance: f64,
    /// Strategy that was active when the snapshot was taken.
    pub current_strategy: LoadBalancingStrategy,
}
553
#[allow(clippy::derivable_impls)]
impl Default for LoadBalancingStrategy {
    /// Adaptive balancing is the general-purpose default: it degrades to
    /// round-robin until enough performance samples accumulate.
    fn default() -> Self {
        Self::Adaptive
    }
}
560
/// Splits a job's integration time span into work chunks, minting
/// monotonically increasing chunk ids.
pub struct ChunkDistributor<F: IntegrateFloat> {
    /// Job that all produced chunks belong to.
    job_id: JobId,
    /// Next chunk id to mint (fetch-add on every chunk created).
    next_chunk_id: AtomicUsize,
    /// Ties the distributor to the scalar type of its chunks.
    _phantom: std::marker::PhantomData<F>,
}
570
571impl<F: IntegrateFloat> ChunkDistributor<F> {
572 pub fn new(job_id: JobId) -> Self {
574 Self {
575 job_id,
576 next_chunk_id: AtomicUsize::new(0),
577 _phantom: std::marker::PhantomData,
578 }
579 }
580
581 pub fn create_chunks(
583 &self,
584 t_span: (F, F),
585 initial_state: Array1<F>,
586 num_chunks: usize,
587 ) -> Vec<WorkChunk<F>> {
588 let t_start = t_span.0;
589 let t_end = t_span.1;
590 let dt = (t_end - t_start) / F::from(num_chunks).unwrap_or(F::one());
591
592 let mut chunks = Vec::with_capacity(num_chunks);
593
594 for i in 0..num_chunks {
595 let chunk_t_start = t_start + dt * F::from(i).unwrap_or(F::zero());
596 let chunk_t_end = if i == num_chunks - 1 {
597 t_end
598 } else {
599 t_start + dt * F::from(i + 1).unwrap_or(F::one())
600 };
601
602 let chunk_id = ChunkId::new(self.next_chunk_id.fetch_add(1, Ordering::SeqCst) as u64);
603
604 let state = if i == 0 {
607 initial_state.clone()
608 } else {
609 Array1::zeros(initial_state.len())
610 };
611
612 chunks.push(WorkChunk::new(
613 chunk_id,
614 self.job_id,
615 (chunk_t_start, chunk_t_end),
616 state,
617 ));
618 }
619
620 chunks
621 }
622
623 pub fn subdivide_chunk(&self, chunk: &WorkChunk<F>, num_parts: usize) -> Vec<WorkChunk<F>> {
625 let (t_start, t_end) = chunk.time_interval;
626 let dt = (t_end - t_start) / F::from(num_parts).unwrap_or(F::one());
627
628 let mut sub_chunks = Vec::with_capacity(num_parts);
629
630 for i in 0..num_parts {
631 let sub_t_start = t_start + dt * F::from(i).unwrap_or(F::zero());
632 let sub_t_end = if i == num_parts - 1 {
633 t_end
634 } else {
635 t_start + dt * F::from(i + 1).unwrap_or(F::one())
636 };
637
638 let sub_chunk_id =
639 ChunkId::new(self.next_chunk_id.fetch_add(1, Ordering::SeqCst) as u64);
640
641 let state = if i == 0 {
642 chunk.initial_state.clone()
643 } else {
644 Array1::zeros(chunk.initial_state.len())
645 };
646
647 let mut sub_chunk =
648 WorkChunk::new(sub_chunk_id, chunk.job_id, (sub_t_start, sub_t_end), state);
649
650 sub_chunk.priority = chunk.priority;
651 sub_chunks.push(sub_chunk);
652 }
653
654 sub_chunks
655 }
656}
657
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributed::types::NodeCapabilities;
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};

    /// Builds `n` loopback `NodeInfo`s with default capabilities, each with a
    /// distinct node id and port.
    fn create_test_nodes(n: usize) -> Vec<NodeInfo> {
        (0..n)
            .map(|i| {
                let addr =
                    SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080 + i as u16);
                let mut info = NodeInfo::new(NodeId::new(i as u64), addr);
                info.capabilities = NodeCapabilities::default();
                info
            })
            .collect()
    }

    /// Round-robin must cycle deterministically: with 3 nodes, assignment i
    /// and assignment i + 3 land on the same node.
    #[test]
    fn test_round_robin_assignment() {
        let balancer: LoadBalancer<f64> = LoadBalancer::new(
            LoadBalancingStrategy::RoundRobin,
            LoadBalancerConfig::default(),
        );

        let nodes = create_test_nodes(3);

        for node in &nodes {
            balancer.register_node(node.id).expect("Failed to register");
        }

        let chunk = WorkChunk::new(ChunkId::new(1), JobId::new(1), (0.0, 1.0), Array1::zeros(3));

        let assignments: Vec<_> = (0..6)
            .map(|_| {
                balancer
                    .assign_chunk(&chunk, &nodes)
                    .expect("Assignment failed")
            })
            .collect();

        // Two full cycles over three nodes: positions i and i + 3 match.
        for i in 0..3 {
            assert_eq!(assignments[i], assignments[i + 3]);
        }
    }

    /// A success followed by a failure leaves one processed chunk, one
    /// failure, and a success rate strictly below 1.0.
    #[test]
    fn test_performance_update() {
        let mut perf = NodePerformance::new(NodeId::new(1));

        perf.update(Duration::from_millis(100), 1.0, true);
        assert_eq!(perf.chunks_processed, 1);
        assert!(perf.success_rate > 0.9);

        perf.update(Duration::from_millis(50), 1.0, false);
        assert_eq!(perf.failures, 1);
        assert!(perf.success_rate < 1.0);
    }

    /// Chunking [0, 10] into 5 parts must cover the full span exactly, with
    /// the last chunk's end pinned to 10.
    #[test]
    fn test_chunk_distributor() {
        let distributor: ChunkDistributor<f64> = ChunkDistributor::new(JobId::new(1));

        let chunks = distributor.create_chunks((0.0, 10.0), Array1::from_vec(vec![1.0, 2.0]), 5);

        assert_eq!(chunks.len(), 5);
        assert!((chunks[0].time_interval.0 - 0.0).abs() < 1e-10);
        assert!((chunks[4].time_interval.1 - 10.0).abs() < 1e-10);
    }

    /// Every registered node must appear in the load-distribution snapshot.
    #[test]
    fn test_load_distribution() {
        let balancer: LoadBalancer<f64> = LoadBalancer::new(
            LoadBalancingStrategy::Adaptive,
            LoadBalancerConfig::default(),
        );

        let nodes = create_test_nodes(3);
        for node in &nodes {
            balancer.register_node(node.id).expect("Failed to register");
        }

        for i in 0..10 {
            let chunk =
                WorkChunk::new(ChunkId::new(i), JobId::new(1), (0.0, 1.0), Array1::zeros(3));
            let _ = balancer.assign_chunk(&chunk, &nodes);
        }

        let distribution = balancer.get_load_distribution();
        assert_eq!(distribution.len(), 3);
    }
}