1use rayon::prelude::*;
2use scirs2_core::ndarray::{s, Array1, Array2, Axis};
3use scirs2_core::random::rngs::StdRng;
4use scirs2_core::random::Rng;
5use scirs2_core::random::{thread_rng, SeedableRng};
6use scirs2_core::StandardNormal;
7use sklears_core::error::{Result, SklearsError};
8
/// Strategy used to split the input samples across workers.
#[derive(Debug, Clone)]
pub enum PartitionStrategy {
    /// Shuffle all sample indices, then hand out contiguous chunks of the
    /// shuffled order.
    Random,
    /// Contiguous, in-order blocks of samples.
    Block,
    /// Stratified split. NOTE(review): not implemented yet — currently falls
    /// back to block partitioning (see `initialize_workers`).
    Stratified,
    /// Caller-supplied partitioner: `(n_samples, n_workers) -> partitions`,
    /// one index vector per worker.
    Custom(fn(usize, usize) -> Vec<Vec<usize>>),
}
27
/// Topology workers would use to exchange intermediate results.
///
/// NOTE(review): informational only in this file — none of the samplers
/// below branch on this value yet; confirm intended use before relying on it.
#[derive(Debug, Clone)]
pub enum CommunicationPattern {
    /// Every worker communicates with every other worker.
    AllToAll,
    /// A single master coordinates all workers.
    MasterWorker,
    /// Workers pass messages around a ring.
    Ring,
    /// Workers communicate along a tree hierarchy.
    Tree,
}
41
/// How per-worker results would be combined into a global result.
///
/// NOTE(review): informational only in this file — the samplers below always
/// concatenate worker outputs and do not read this value yet.
#[derive(Debug, Clone)]
pub enum AggregationMethod {
    /// Unweighted mean of worker results.
    Average,
    /// Mean weighted by some per-worker criterion.
    WeightedAverage,
    /// Stack worker results side by side.
    Concatenate,
    /// Keep only the best-quality worker result.
    BestQuality,
    /// Combine worker results as an ensemble.
    Ensemble,
}
57
/// Configuration for distributed kernel-approximation runs.
#[derive(Debug, Clone)]
pub struct DistributedConfig {
    /// Number of parallel workers. Must be > 0; the partitioning code divides
    /// by this value.
    pub n_workers: usize,
    /// How samples are split across workers.
    pub partition_strategy: PartitionStrategy,
    /// Worker communication topology (currently informational — see enum docs).
    pub communication_pattern: CommunicationPattern,
    /// How worker results are combined (currently informational — see enum docs).
    pub aggregation_method: AggregationMethod,
    /// Optional per-chunk size hint. NOTE(review): not read anywhere in this
    /// file — confirm whether callers elsewhere use it.
    pub chunk_size: Option<usize>,
    /// Fraction of overlap between partitions. NOTE(review): not read in this
    /// file either.
    pub overlap_ratio: f64,
    /// Whether to tolerate worker failures (not read in this file).
    pub fault_tolerance: bool,
    /// Whether to rebalance work across workers (not read in this file).
    pub load_balancing: bool,
}
79
80impl Default for DistributedConfig {
81 fn default() -> Self {
82 Self {
83 n_workers: num_cpus::get(),
84 partition_strategy: PartitionStrategy::Block,
85 communication_pattern: CommunicationPattern::MasterWorker,
86 aggregation_method: AggregationMethod::Average,
87 chunk_size: None,
88 overlap_ratio: 0.1,
89 fault_tolerance: false,
90 load_balancing: true,
91 }
92 }
93}
94
/// Book-keeping record for a single logical worker.
#[derive(Debug)]
pub struct Worker {
    /// Zero-based worker identifier.
    pub id: usize,
    /// Indices of the samples assigned to this worker.
    pub data_indices: Vec<usize>,
    /// Features computed locally by this worker, if any have been cached.
    pub local_features: Option<Array2<f64>>,
    /// Whether the worker is currently participating.
    pub is_active: bool,
    /// Accumulated computation time (seconds, presumably — not written in
    /// this file; confirm against the code that updates it).
    pub computation_time: f64,
    /// Reported memory usage in bytes (not written in this file).
    pub memory_usage: usize,
}
112
113impl Worker {
114 pub fn new(id: usize, data_indices: Vec<usize>) -> Self {
115 Self {
116 id,
117 data_indices,
118 local_features: None,
119 is_active: true,
120 computation_time: 0.0,
121 memory_usage: 0,
122 }
123 }
124}
125
/// Distributed random-Fourier-feature approximation of an RBF kernel:
/// weight/bias generation and feature computation are split across workers.
pub struct DistributedRBFSampler {
    // Number of random features to generate.
    n_components: usize,
    // RBF kernel bandwidth parameter.
    gamma: f64,
    // Distributed-execution settings (worker count, partitioning, ...).
    config: DistributedConfig,
    // Worker table, populated by `initialize_workers` during `fit`.
    workers: Vec<Worker>,
    // (n_components, n_features) projection matrix; set by `fit`.
    global_weights: Option<Array2<f64>>,
    // Per-component phase offsets in [0, 2*pi); set by `fit`.
    global_bias: Option<Array1<f64>>,
    // Optional seed making `fit` deterministic.
    random_state: Option<u64>,
}
139
140impl DistributedRBFSampler {
141 pub fn new(n_components: usize, gamma: f64) -> Self {
143 Self {
144 n_components,
145 gamma,
146 config: DistributedConfig::default(),
147 workers: Vec::new(),
148 global_weights: None,
149 global_bias: None,
150 random_state: None,
151 }
152 }
153
154 pub fn with_config(mut self, config: DistributedConfig) -> Self {
156 self.config = config;
157 self
158 }
159
160 pub fn with_random_state(mut self, random_state: u64) -> Self {
162 self.random_state = Some(random_state);
163 self
164 }
165
166 pub fn fit(&mut self, x: &Array2<f64>) -> Result<()> {
168 let (n_samples, n_features) = x.dim();
169
170 self.initialize_workers(n_samples)?;
172
173 let components_per_worker = self.n_components / self.config.n_workers;
175 let mut all_weights = Vec::new();
176 let mut all_bias = Vec::new();
177
178 let weight_results: Vec<(Array2<f64>, Array1<f64>)> = (0..self.config.n_workers)
180 .into_par_iter()
181 .map(|worker_id| {
182 let mut rng = match self.random_state {
183 Some(seed) => StdRng::seed_from_u64(seed + worker_id as u64),
184 None => StdRng::from_seed(thread_rng().gen()),
185 };
186
187 let worker_components = if worker_id == self.config.n_workers - 1 {
188 self.n_components - components_per_worker * worker_id
190 } else {
191 components_per_worker
192 };
193
194 let mut worker_weights = Array2::zeros((worker_components, n_features));
196 for i in 0..worker_components {
197 for j in 0..n_features {
198 worker_weights[[i, j]] =
199 rng.sample::<f64, _>(StandardNormal) * (2.0 * self.gamma).sqrt();
200 }
201 }
202
203 let mut worker_bias = Array1::zeros(worker_components);
205 for i in 0..worker_components {
206 worker_bias[i] = rng.gen_range(0.0..2.0 * std::f64::consts::PI);
207 }
208
209 (worker_weights, worker_bias)
210 })
211 .collect();
212
213 for (weights, bias) in weight_results {
214 all_weights.push(weights);
215 all_bias.push(bias);
216 }
217
218 self.global_weights = Some(
220 scirs2_core::ndarray::concatenate(
221 Axis(0),
222 &all_weights
223 .iter()
224 .map(|w: &Array2<f64>| w.view())
225 .collect::<Vec<_>>(),
226 )
227 .map_err(|e| SklearsError::Other(format!("Failed to concatenate weights: {}", e)))?,
228 );
229
230 self.global_bias = Some(
231 scirs2_core::ndarray::concatenate(
232 Axis(0),
233 &all_bias
234 .iter()
235 .map(|b: &Array1<f64>| b.view())
236 .collect::<Vec<_>>(),
237 )
238 .map_err(|e| SklearsError::Other(format!("Failed to concatenate bias: {}", e)))?,
239 );
240
241 Ok(())
242 }
243
244 pub fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
246 let weights = self
247 .global_weights
248 .as_ref()
249 .ok_or_else(|| SklearsError::NotFitted {
250 operation: "transform".to_string(),
251 })?;
252 let bias = self
253 .global_bias
254 .as_ref()
255 .ok_or_else(|| SklearsError::NotFitted {
256 operation: "transform".to_string(),
257 })?;
258
259 let (n_samples, _) = x.dim();
260
261 let samples_per_worker = n_samples / self.config.n_workers;
263
264 let feature_results: Vec<Array2<f64>> = (0..self.config.n_workers)
265 .into_par_iter()
266 .map(|worker_id| {
267 let start_idx = worker_id * samples_per_worker;
268 let end_idx = if worker_id == self.config.n_workers - 1 {
269 n_samples
270 } else {
271 (worker_id + 1) * samples_per_worker
272 };
273
274 let worker_data = x.slice(s![start_idx..end_idx, ..]);
275 self.compute_features(&worker_data, weights, bias)
276 })
277 .collect();
278
279 let combined_features = scirs2_core::ndarray::concatenate(
281 Axis(0),
282 &feature_results.iter().map(|f| f.view()).collect::<Vec<_>>(),
283 )
284 .map_err(|e| SklearsError::Other(format!("Failed to concatenate features: {}", e)))?;
285
286 Ok(combined_features)
287 }
288
289 fn compute_features(
291 &self,
292 x: &scirs2_core::ndarray::ArrayView2<f64>,
293 weights: &Array2<f64>,
294 bias: &Array1<f64>,
295 ) -> Array2<f64> {
296 let (n_samples, _) = x.dim();
297 let n_components = weights.nrows();
298
299 let projection = x.dot(&weights.t()) + bias;
301
302 let mut features = Array2::zeros((n_samples, n_components));
304 let norm_factor = (2.0 / n_components as f64).sqrt();
305
306 for i in 0..n_samples {
307 for j in 0..n_components {
308 features[[i, j]] = norm_factor * projection[[i, j]].cos();
309 }
310 }
311
312 features
313 }
314
315 fn initialize_workers(&mut self, n_samples: usize) -> Result<()> {
317 self.workers.clear();
318
319 let partitions = match &self.config.partition_strategy {
320 PartitionStrategy::Block => self.create_block_partitions(n_samples),
321 PartitionStrategy::Random => self.create_random_partitions(n_samples),
322 PartitionStrategy::Stratified => {
323 self.create_block_partitions(n_samples)
325 }
326 PartitionStrategy::Custom(partition_fn) => {
327 partition_fn(n_samples, self.config.n_workers)
328 }
329 };
330
331 for (worker_id, indices) in partitions.into_iter().enumerate() {
332 self.workers.push(Worker::new(worker_id, indices));
333 }
334
335 Ok(())
336 }
337
338 fn create_block_partitions(&self, n_samples: usize) -> Vec<Vec<usize>> {
340 let samples_per_worker = n_samples / self.config.n_workers;
341 let mut partitions = Vec::new();
342
343 for worker_id in 0..self.config.n_workers {
344 let start_idx = worker_id * samples_per_worker;
345 let end_idx = if worker_id == self.config.n_workers - 1 {
346 n_samples
347 } else {
348 (worker_id + 1) * samples_per_worker
349 };
350
351 partitions.push((start_idx..end_idx).collect());
352 }
353
354 partitions
355 }
356
357 fn create_random_partitions(&self, n_samples: usize) -> Vec<Vec<usize>> {
359 let mut rng = match self.random_state {
360 Some(seed) => StdRng::seed_from_u64(seed),
361 None => StdRng::from_seed(thread_rng().gen()),
362 };
363
364 let mut indices: Vec<usize> = (0..n_samples).collect();
365
366 for i in (1..indices.len()).rev() {
368 let j = rng.gen_range(0..i + 1);
369 indices.swap(i, j);
370 }
371
372 let samples_per_worker = n_samples / self.config.n_workers;
374 let mut partitions = Vec::new();
375
376 for worker_id in 0..self.config.n_workers {
377 let start_idx = worker_id * samples_per_worker;
378 let end_idx = if worker_id == self.config.n_workers - 1 {
379 n_samples
380 } else {
381 (worker_id + 1) * samples_per_worker
382 };
383
384 partitions.push(indices[start_idx..end_idx].to_vec());
385 }
386
387 partitions
388 }
389
390 pub fn worker_stats(&self) -> Vec<(usize, usize, bool)> {
392 self.workers
393 .iter()
394 .map(|w| (w.id, w.data_indices.len(), w.is_active))
395 .collect()
396 }
397
398 pub fn total_memory_usage(&self) -> usize {
400 self.workers.iter().map(|w| w.memory_usage).sum()
401 }
402}
403
/// Nystroem approximation of an RBF kernel built from a random subset of
/// inducing points.
pub struct DistributedNystroem {
    // Number of inducing points / output feature dimensions.
    n_components: usize,
    // RBF kernel bandwidth parameter.
    gamma: f64,
    // Distributed-execution settings. NOTE(review): currently unused by this
    // type's visible methods.
    config: DistributedConfig,
    // Worker table. NOTE(review): never populated in this file.
    workers: Vec<Worker>,
    // Eigenvalues of the inducing-point kernel matrix; set by `fit`.
    eigenvalues: Option<Array1<f64>>,
    // Corresponding eigenvectors; set by `fit`.
    eigenvectors: Option<Array2<f64>>,
    // The selected inducing rows of the training data; set by `fit`.
    inducing_points: Option<Array2<f64>>,
    // Optional seed making inducing-point selection deterministic.
    random_state: Option<u64>,
}
418
419impl DistributedNystroem {
420 pub fn new(n_components: usize, gamma: f64) -> Self {
422 Self {
423 n_components,
424 gamma,
425 config: DistributedConfig::default(),
426 workers: Vec::new(),
427 eigenvalues: None,
428 eigenvectors: None,
429 inducing_points: None,
430 random_state: None,
431 }
432 }
433
434 pub fn with_config(mut self, config: DistributedConfig) -> Self {
436 self.config = config;
437 self
438 }
439
440 pub fn with_random_state(mut self, random_state: u64) -> Self {
442 self.random_state = Some(random_state);
443 self
444 }
445
446 pub fn fit(&mut self, x: &Array2<f64>) -> Result<()> {
448 let (_n_samples, _) = x.dim();
449
450 let inducing_indices = self.select_inducing_points(x)?;
452 let inducing_points = x.select(Axis(0), &inducing_indices);
453
454 let kernel_matrix = self.compute_kernel_matrix(&inducing_points)?;
456
457 let (eigenvalues, eigenvectors) = self.eigendecomposition(&kernel_matrix)?;
459
460 self.inducing_points = Some(inducing_points);
462 self.eigenvalues = Some(eigenvalues);
463 self.eigenvectors = Some(eigenvectors);
464
465 Ok(())
466 }
467
468 pub fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
470 let inducing_points =
471 self.inducing_points
472 .as_ref()
473 .ok_or_else(|| SklearsError::NotFitted {
474 operation: "transform".to_string(),
475 })?;
476 let eigenvalues = self
477 .eigenvalues
478 .as_ref()
479 .ok_or_else(|| SklearsError::NotFitted {
480 operation: "transform".to_string(),
481 })?;
482 let eigenvectors = self
483 .eigenvectors
484 .as_ref()
485 .ok_or_else(|| SklearsError::NotFitted {
486 operation: "transform".to_string(),
487 })?;
488
489 let kernel_x_inducing = self.compute_kernel(x, inducing_points)?;
491
492 let mut features = kernel_x_inducing.dot(eigenvectors);
494
495 for i in 0..eigenvalues.len() {
497 if eigenvalues[i] > 1e-12 {
498 let scale = 1.0 / eigenvalues[i].sqrt();
499 for j in 0..features.nrows() {
500 features[[j, i]] *= scale;
501 }
502 }
503 }
504
505 Ok(features)
506 }
507
508 fn select_inducing_points(&self, x: &Array2<f64>) -> Result<Vec<usize>> {
510 let n_samples = x.nrows();
511 let mut rng = match self.random_state {
512 Some(seed) => StdRng::seed_from_u64(seed),
513 None => StdRng::from_seed(thread_rng().gen()),
514 };
515
516 let mut indices = Vec::new();
518 for _ in 0..self.n_components {
519 indices.push(rng.gen_range(0..n_samples));
520 }
521
522 Ok(indices)
523 }
524
525 fn compute_kernel_matrix(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
527 let n_samples = x.nrows();
528 let mut kernel_matrix = Array2::zeros((n_samples, n_samples));
529
530 for i in 0..n_samples {
531 for j in i..n_samples {
532 let diff = &x.row(i) - &x.row(j);
533 let squared_dist = diff.mapv(|x| x * x).sum();
534 let kernel_val = (-self.gamma * squared_dist).exp();
535 kernel_matrix[[i, j]] = kernel_val;
536 kernel_matrix[[j, i]] = kernel_val;
537 }
538 }
539
540 Ok(kernel_matrix)
541 }
542
543 fn compute_kernel(&self, x: &Array2<f64>, y: &Array2<f64>) -> Result<Array2<f64>> {
545 let (n_samples_x, _) = x.dim();
546 let (n_samples_y, _) = y.dim();
547 let mut kernel_matrix = Array2::zeros((n_samples_x, n_samples_y));
548
549 for i in 0..n_samples_x {
550 for j in 0..n_samples_y {
551 let diff = &x.row(i) - &y.row(j);
552 let squared_dist = diff.mapv(|x| x * x).sum();
553 let kernel_val = (-self.gamma * squared_dist).exp();
554 kernel_matrix[[i, j]] = kernel_val;
555 }
556 }
557
558 Ok(kernel_matrix)
559 }
560
561 fn eigendecomposition(&self, matrix: &Array2<f64>) -> Result<(Array1<f64>, Array2<f64>)> {
563 let n = matrix.nrows();
566 let eigenvalues = Array1::ones(self.n_components.min(n));
567 let eigenvectors = Array2::eye(n)
568 .slice(s![.., ..self.n_components.min(n)])
569 .to_owned();
570
571 Ok((eigenvalues, eigenvectors))
572 }
573}
574
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_distributed_rbf_sampler_basic() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let mut sampler = DistributedRBFSampler::new(100, 0.1).with_random_state(42);
        sampler.fit(&data).unwrap();

        // Four input rows mapped into 100 random features.
        let features = sampler.transform(&data).unwrap();
        assert_eq!(features.dim(), (4, 100));
    }

    #[test]
    fn test_distributed_config() {
        let config = DistributedConfig {
            n_workers: 4,
            partition_strategy: PartitionStrategy::Random,
            communication_pattern: CommunicationPattern::AllToAll,
            aggregation_method: AggregationMethod::WeightedAverage,
            ..Default::default()
        };

        assert_eq!(config.n_workers, 4);
        assert!(matches!(
            config.partition_strategy,
            PartitionStrategy::Random
        ));
    }

    #[test]
    fn test_worker_initialization() {
        let worker = Worker::new(0, vec![0, 1, 2, 3]);

        assert_eq!(worker.id, 0);
        assert_eq!(worker.data_indices.len(), 4);
        assert!(worker.is_active);
    }

    #[test]
    fn test_distributed_nystroem_basic() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let mut nystroem = DistributedNystroem::new(3, 0.1).with_random_state(42);
        nystroem.fit(&data).unwrap();

        let features = nystroem.transform(&data).unwrap();
        assert_eq!(features.dim(), (4, 3));
    }

    #[test]
    fn test_partition_strategies() {
        let mut sampler = DistributedRBFSampler::new(50, 0.1);
        sampler.config.n_workers = 2;

        // Block partitioning: two contiguous, evenly sized chunks.
        sampler.config.partition_strategy = PartitionStrategy::Block;
        sampler.initialize_workers(10).unwrap();
        assert_eq!(sampler.workers.len(), 2);
        assert_eq!(sampler.workers[0].data_indices.len(), 5);
        assert_eq!(sampler.workers[1].data_indices.len(), 5);

        // Random partitioning still yields one partition per worker.
        sampler.config.partition_strategy = PartitionStrategy::Random;
        sampler.random_state = Some(42);
        sampler.initialize_workers(10).unwrap();
        assert_eq!(sampler.workers.len(), 2);
    }

    #[test]
    fn test_worker_stats() {
        let mut sampler = DistributedRBFSampler::new(50, 0.1);
        sampler.config.n_workers = 3;
        sampler.initialize_workers(12).unwrap();

        // 12 samples over 3 workers: every worker gets 4.
        let stats = sampler.worker_stats();
        assert_eq!(stats.len(), 3);
        for entry in &stats {
            assert_eq!(entry.1, 4);
        }
    }

    #[test]
    fn test_reproducibility() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let mut first = DistributedRBFSampler::new(50, 0.1).with_random_state(42);
        first.fit(&data).unwrap();
        let features_a = first.transform(&data).unwrap();

        let mut second = DistributedRBFSampler::new(50, 0.1).with_random_state(42);
        second.fit(&data).unwrap();
        let features_b = second.transform(&data).unwrap();

        // Identical seeds must produce identical feature maps.
        let total_diff = (features_a - features_b).mapv(f64::abs).sum();
        assert!(total_diff < 1e-10);
    }

    #[test]
    fn test_different_worker_counts() {
        let data = array![
            [1.0, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [9.0, 10.0],
            [11.0, 12.0]
        ];

        // Output shape must not depend on how the work was split.
        for n_workers in [1, 2, 3, 6] {
            let config = DistributedConfig {
                n_workers,
                ..Default::default()
            };

            let mut sampler = DistributedRBFSampler::new(50, 0.1)
                .with_config(config)
                .with_random_state(42);
            sampler.fit(&data).unwrap();

            let features = sampler.transform(&data).unwrap();
            assert_eq!(features.dim(), (6, 50));
        }
    }
}