use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use scirs2_core::random::Random;
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, PredictProba, Untrained},
    types::Float,
};

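/// Joint graph structure learning for semi-supervised classification.
///
/// Alternates between (a) learning a sparse, non-negative affinity matrix `W`
/// via proximal gradient descent on a data-fidelity + Laplacian-smoothness +
/// L1-sparsity objective, and (b) propagating labels over the current graph.
/// Unlabeled samples are marked with `-1` in the target vector.
///
/// A minimal usage sketch (illustrative only; it assumes the builder methods
/// and the `Fit`/`Predict` impls defined in this module):
///
/// ```ignore
/// use scirs2_core::array;
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
/// let y = array![0, 1, -1, -1]; // -1 marks unlabeled samples
///
/// let model = GraphStructureLearning::new()
///     .lambda_sparse(0.1)   // L1 penalty on edge weights
///     .beta_smoothness(1.0) // weight of the Laplacian smoothness term
///     .max_iter(50)
///     .fit(&X.view(), &y.view())?;
///
/// let labels = model.predict(&X.view())?;
/// # Ok::<(), sklears_core::error::SklearsError>(())
/// ```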
#[derive(Debug, Clone)]
pub struct GraphStructureLearning<S = Untrained> {
    state: S,
    lambda_sparse: f64,
    beta_smoothness: f64,
    max_iter: usize,
    tol: f64,
    learning_rate: f64,
    adaptive_lr: bool,
    enforce_symmetry: bool,
    normalize_weights: bool,
}

impl GraphStructureLearning<Untrained> {
    /// Create a model with default hyperparameters.
    pub fn new() -> Self {
        Self {
            state: Untrained,
            lambda_sparse: 0.1,
            beta_smoothness: 1.0,
            max_iter: 100,
            tol: 1e-4,
            learning_rate: 0.01,
            adaptive_lr: true,
            enforce_symmetry: true,
            normalize_weights: true,
        }
    }

    /// Set the L1 sparsity penalty on edge weights.
    pub fn lambda_sparse(mut self, lambda_sparse: f64) -> Self {
        self.lambda_sparse = lambda_sparse;
        self
    }

    /// Set the weight of the Laplacian smoothness term.
    pub fn beta_smoothness(mut self, beta_smoothness: f64) -> Self {
        self.beta_smoothness = beta_smoothness;
        self
    }

    /// Set the maximum number of outer optimization iterations.
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the convergence tolerance on the change in total loss.
    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Set the initial learning rate for the proximal gradient steps.
    pub fn learning_rate(mut self, learning_rate: f64) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    /// Enable or disable adaptive learning-rate scheduling.
    pub fn adaptive_lr(mut self, adaptive_lr: bool) -> Self {
        self.adaptive_lr = adaptive_lr;
        self
    }

    /// Enable or disable symmetrization of the learned graph.
    pub fn enforce_symmetry(mut self, enforce_symmetry: bool) -> Self {
        self.enforce_symmetry = enforce_symmetry;
        self
    }

    /// Enable or disable row normalization of the final graph.
    pub fn normalize_weights(mut self, normalize_weights: bool) -> Self {
        self.normalize_weights = normalize_weights;
        self
    }

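    /// Initialize the affinity matrix as a k-nearest-neighbor graph with
    /// exponentially decaying edge weights exp(-dist / 2), where
    /// k = clamp(ceil(sqrt(n_samples)), 3, 10).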
    #[allow(non_snake_case)]
    fn initialize_graph(&self, X: &Array2<f64>) -> Array2<f64> {
        let n_samples = X.nrows();
        let mut W = Array2::zeros((n_samples, n_samples));

        // Connect each sample to its k nearest neighbors.
        let k = (n_samples as f64).sqrt().ceil() as usize;
        let k = k.clamp(3, 10);

        for i in 0..n_samples {
            let mut distances: Vec<(usize, f64)> = Vec::new();
            for j in 0..n_samples {
                if i != j {
                    let diff = &X.row(i) - &X.row(j);
                    let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
                    distances.push((j, dist));
                }
            }

            distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

            for &(j, dist) in distances.iter().take(k) {
                // Exponential kernel weight with unit bandwidth: exp(-dist / 2).
                let weight = (-dist / (2.0 * 1.0_f64.powi(2))).exp();
                W[[i, j]] = weight;
                if self.enforce_symmetry {
                    W[[j, i]] = weight;
                }
            }
        }

        W
    }

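    /// Build the unnormalized graph Laplacian L = D - W, where D is the
    /// diagonal degree matrix with D[i, i] = sum_j W[i, j].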
    #[allow(non_snake_case)]
    fn compute_laplacian(&self, W: &Array2<f64>) -> Array2<f64> {
        let n_samples = W.nrows();
        let D = W.sum_axis(Axis(1));
        let mut L = Array2::zeros((n_samples, n_samples));

        for i in 0..n_samples {
            L[[i, i]] = D[i];
            for j in 0..n_samples {
                if i != j {
                    L[[i, j]] = -W[[i, j]];
                }
            }
        }

        L
    }

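    /// Soft-thresholding (the proximal operator of the L1 norm):
    /// S(x, t) = sign(x) * max(|x| - t, 0).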
    fn soft_threshold(&self, x: f64, threshold: f64) -> f64 {
        if x > threshold {
            x - threshold
        } else if x < -threshold {
            x + threshold
        } else {
            0.0
        }
    }

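    /// One proximal gradient update: gradient step, soft-thresholding for
    /// sparsity, projection onto non-negative weights, optional
    /// symmetrization, and zeroing of the diagonal (no self-loops).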
    #[allow(non_snake_case)]
    fn proximal_gradient_step(&self, W: &Array2<f64>, grad: &Array2<f64>, lr: f64) -> Array2<f64> {
        let mut W_new = W - lr * grad;

        // Proximal step for the L1 penalty.
        let threshold = lr * self.lambda_sparse;
        W_new.mapv_inplace(|x| self.soft_threshold(x, threshold));

        // Project onto non-negative edge weights.
        W_new.mapv_inplace(|x| x.max(0.0));

        if self.enforce_symmetry {
            let n = W_new.nrows();
            for i in 0..n {
                for j in 0..n {
                    if i != j {
                        let avg = (W_new[[i, j]] + W_new[[j, i]]) / 2.0;
                        W_new[[i, j]] = avg;
                        W_new[[j, i]] = avg;
                    }
                }
            }
        }

        // Remove self-loops.
        for i in 0..W_new.nrows() {
            W_new[[i, i]] = 0.0;
        }

        W_new
    }

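    /// Row-normalize the graph so each row of weights sums to one
    /// (a no-op when `normalize_weights` is false).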
    #[allow(non_snake_case)]
    fn normalize_graph(&self, W: &Array2<f64>) -> Array2<f64> {
        if !self.normalize_weights {
            return W.clone();
        }

        let mut W_norm = W.clone();
        let n_samples = W.nrows();

        for i in 0..n_samples {
            let row_sum: f64 = W.row(i).sum();
            if row_sum > 0.0 {
                for j in 0..n_samples {
                    W_norm[[i, j]] = W[[i, j]] / row_sum;
                }
            }
        }

        W_norm
    }

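    /// Iterative label propagation over the row-stochastic transition matrix
    /// P = D^{-1} W, using the damped update Y <- 0.9 * P Y + 0.1 * Y_init
    /// until the L1 change in Y falls below 1e-6 (at most 50 iterations).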
    #[allow(non_snake_case)]
    fn propagate_labels(&self, W: &Array2<f64>, Y_init: &Array2<f64>) -> SklResult<Array2<f64>> {
        let n_samples = W.nrows();

        // Row-stochastic transition matrix P = D^{-1} W.
        let D = W.sum_axis(Axis(1));
        let mut P = Array2::zeros((n_samples, n_samples));
        for i in 0..n_samples {
            if D[i] > 0.0 {
                for j in 0..n_samples {
                    P[[i, j]] = W[[i, j]] / D[i];
                }
            }
        }

        let mut Y = Y_init.clone();
        let Y_static = Y_init.clone();

        for _iter in 0..50 {
            let prev_Y = Y.clone();
            Y = 0.9 * P.dot(&Y) + 0.1 * &Y_static;

            let diff = (&Y - &prev_Y).mapv(|x| x.abs()).sum();
            if diff < 1e-6 {
                break;
            }
        }

        Ok(Y)
    }
}

impl Default for GraphStructureLearning<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for GraphStructureLearning<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for GraphStructureLearning<Untrained> {
    type Fitted = GraphStructureLearning<GraphStructureLearningTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let y = y.to_owned();

        let (n_samples, _n_features) = X.dim();

        // Collect labeled samples; -1 marks an unlabeled sample.
        let mut labeled_indices = Vec::new();
        let mut classes = std::collections::HashSet::new();

        for (i, &label) in y.iter().enumerate() {
            if label != -1 {
                labeled_indices.push(i);
                classes.insert(label);
            }
        }

        if labeled_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No labeled samples provided".to_string(),
            ));
        }

        let classes: Vec<i32> = classes.into_iter().collect();
        let n_classes = classes.len();

        let mut W = self.initialize_graph(&X);

        // One-hot label matrix for the labeled samples.
        let mut Y = Array2::zeros((n_samples, n_classes));
        for &idx in &labeled_indices {
            if let Some(class_idx) = classes.iter().position(|&c| c == y[idx]) {
                Y[[idx, class_idx]] = 1.0;
            }
        }

        let Y_init = Y.clone();
        let mut lr = self.learning_rate;
        let mut prev_loss = f64::INFINITY;

        for iteration in 0..self.max_iter {
            // Propagate labels over the current graph.
            Y = self.propagate_labels(&W, &Y_init)?;

            let L = self.compute_laplacian(&W);

            // Data-fidelity term: ||X - WX||_F^2.
            let WX = W.dot(&X);
            let data_fidelity = (&X - &WX).mapv(|x| x * x).sum();

            // Smoothness term: tr(Y^T L Y).
            let smoothness_loss = {
                let LY = L.dot(&Y);
                let mut trace = 0.0;
                for i in 0..n_samples {
                    for j in 0..n_classes {
                        trace += Y[[i, j]] * LY[[i, j]];
                    }
                }
                trace
            };

            // Sparsity term: ||W||_1.
            let sparsity_loss = W.iter().map(|&x| x.abs()).sum::<f64>();

            let total_loss = data_fidelity
                + self.beta_smoothness * smoothness_loss
                + self.lambda_sparse * sparsity_loss;

            if (prev_loss - total_loss).abs() < self.tol {
                break;
            }

            if self.adaptive_lr {
                if total_loss > prev_loss {
                    // Loss increased: shrink the step size.
                    lr *= 0.8;
                } else if iteration % 10 == 0 && total_loss < prev_loss {
                    // Steady progress: cautiously grow the step size.
                    lr *= 1.1;
                }
                lr = lr.clamp(1e-6, 0.1);
            }

            prev_loss = total_loss;

            // Gradient of the data-fidelity term: 2 (WX - X) X^T.
            let residual = &WX - &X;
            let mut grad_W = 2.0 * residual.dot(&X.t());

            // Gradient of the smoothness term tr(Y^T (D - W) Y) w.r.t. W.
            let YYT = Y.dot(&Y.t());
            let D = W.sum_axis(Axis(1));
            for i in 0..n_samples {
                for j in 0..n_samples {
                    if i == j {
                        grad_W[[i, j]] +=
                            self.beta_smoothness * (D[i] * YYT[[i, j]] - W[[i, j]] * YYT[[i, j]]);
                    } else {
                        grad_W[[i, j]] += self.beta_smoothness * (-YYT[[i, j]]);
                    }
                }
            }

            W = self.proximal_gradient_step(&W, &grad_W, lr);
        }

        // Final normalization and label propagation on the learned graph.
        let W_final = self.normalize_graph(&W);
        let Y_final = self.propagate_labels(&W_final, &Y_init)?;

        Ok(GraphStructureLearning {
            state: GraphStructureLearningTrained {
                X_train: X,
                y_train: y,
                classes: Array1::from(classes),
                learned_graph: W_final,
                label_distributions: Y_final,
            },
            lambda_sparse: self.lambda_sparse,
            beta_smoothness: self.beta_smoothness,
            max_iter: self.max_iter,
            tol: self.tol,
            learning_rate: self.learning_rate,
            adaptive_lr: self.adaptive_lr,
            enforce_symmetry: self.enforce_symmetry,
            normalize_weights: self.normalize_weights,
        })
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>>
    for GraphStructureLearning<GraphStructureLearningTrained>
{
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let mut predictions = Array1::zeros(n_test);

        for i in 0..n_test {
            // Find the nearest training sample.
            let mut min_dist = f64::INFINITY;
            let mut best_idx = 0;

            for j in 0..self.state.X_train.nrows() {
                let diff = &X.row(i) - &self.state.X_train.row(j);
                let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
                if dist < min_dist {
                    min_dist = dist;
                    best_idx = j;
                }
            }

            // Assign the class with the highest propagated label score.
            let distributions = self.state.label_distributions.row(best_idx);
            let max_idx = distributions
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .unwrap()
                .0;

            predictions[i] = self.state.classes[max_idx];
        }

        Ok(predictions)
    }
}

impl PredictProba<ArrayView2<'_, Float>, Array2<f64>>
    for GraphStructureLearning<GraphStructureLearningTrained>
{
    #[allow(non_snake_case)]
    fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let n_classes = self.state.classes.len();
        let mut probas = Array2::zeros((n_test, n_classes));

        for i in 0..n_test {
            // Find the nearest training sample.
            let mut min_dist = f64::INFINITY;
            let mut best_idx = 0;

            for j in 0..self.state.X_train.nrows() {
                let diff = &X.row(i) - &self.state.X_train.row(j);
                let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
                if dist < min_dist {
                    min_dist = dist;
                    best_idx = j;
                }
            }

            // Copy its propagated label distribution.
            for k in 0..n_classes {
                probas[[i, k]] = self.state.label_distributions[[best_idx, k]];
            }
        }

        Ok(probas)
    }
}

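/// Graph learning with robust (outlier-resistant) distance metrics.
///
/// Builds a fully connected affinity graph from a robust pairwise distance
/// ("l1", "huber", or "tukey"; anything else falls back to Euclidean),
/// sparsifies it by soft-thresholding, symmetrizes it, and propagates labels
/// over it.
///
/// A minimal usage sketch (illustrative only; it assumes the builder methods
/// and the `Fit`/`Predict` impls defined in this module):
///
/// ```ignore
/// use scirs2_core::array;
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [100.0, 100.0]];
/// let y = array![0, 1, -1, -1]; // -1 marks unlabeled samples
///
/// let model = RobustGraphLearning::new()
///     .robust_metric("huber".to_string())
///     .huber_delta(1.0)
///     .lambda_sparse(0.1)
///     .fit(&X.view(), &y.view())?;
///
/// let labels = model.predict(&X.view())?;
/// # Ok::<(), sklears_core::error::SklearsError>(())
/// ```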
#[derive(Debug, Clone)]
pub struct RobustGraphLearning<S = Untrained> {
    state: S,
    lambda_sparse: f64,
    lambda_robust: f64,
    max_iter: usize,
    tol: f64,
    robust_metric: String,
    huber_delta: f64,
    tukey_c: f64,
}

impl RobustGraphLearning<Untrained> {
    /// Create a model with default hyperparameters.
    pub fn new() -> Self {
        Self {
            state: Untrained,
            lambda_sparse: 0.1,
            lambda_robust: 1.0,
            max_iter: 100,
            tol: 1e-4,
            robust_metric: "huber".to_string(),
            huber_delta: 1.0,
            tukey_c: 4.685,
        }
    }

    /// Set the sparsification threshold on edge weights.
    pub fn lambda_sparse(mut self, lambda_sparse: f64) -> Self {
        self.lambda_sparse = lambda_sparse;
        self
    }

    /// Set the weight of the robustness term (stored but not used by the
    /// current implementation).
    pub fn lambda_robust(mut self, lambda_robust: f64) -> Self {
        self.lambda_robust = lambda_robust;
        self
    }

    /// Set the maximum number of iterations.
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the convergence tolerance.
    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Choose the robust distance metric: "l1", "huber", or "tukey".
    pub fn robust_metric(mut self, metric: String) -> Self {
        self.robust_metric = metric;
        self
    }

    /// Set the Huber loss transition point delta.
    pub fn huber_delta(mut self, delta: f64) -> Self {
        self.huber_delta = delta;
        self
    }

    /// Set the Tukey biweight cutoff c (the conventional 4.685 targets
    /// roughly 95% efficiency under Gaussian noise).
    pub fn tukey_c(mut self, c: f64) -> Self {
        self.tukey_c = c;
        self
    }

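    /// Aggregate per-coordinate robust losses into a pairwise distance.
    /// "l1" sums absolute differences, "huber" is quadratic near zero and
    /// linear in the tails, and "tukey" (biweight) saturates at c^2 / 6 so
    /// large outliers contribute a bounded amount. Any other metric name
    /// falls back to the Euclidean distance.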
    fn robust_distance(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
        let diff = x1 - x2;

        match self.robust_metric.as_str() {
            "l1" => diff.mapv(|x| x.abs()).sum(),
            "huber" => diff
                .mapv(|x| {
                    let abs_x = x.abs();
                    if abs_x <= self.huber_delta {
                        0.5 * x * x
                    } else {
                        self.huber_delta * (abs_x - 0.5 * self.huber_delta)
                    }
                })
                .sum(),
            "tukey" => diff
                .mapv(|x| {
                    let abs_x = x.abs();
                    if abs_x <= self.tukey_c {
                        let ratio = x / self.tukey_c;
                        (self.tukey_c * self.tukey_c / 6.0) * (1.0 - (1.0 - ratio * ratio).powi(3))
                    } else {
                        self.tukey_c * self.tukey_c / 6.0
                    }
                })
                .sum(),
            // Fall back to the Euclidean distance for unknown metric names.
            _ => diff.mapv(|x| x * x).sum().sqrt(),
        }
    }

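    /// Build a fully connected affinity matrix with edge weights
    /// exp(-dist / 2) computed from the robust pairwise distance.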
    #[allow(non_snake_case)]
    fn compute_robust_weights(&self, X: &Array2<f64>) -> Array2<f64> {
        let n_samples = X.nrows();
        let mut W = Array2::zeros((n_samples, n_samples));

        for i in 0..n_samples {
            for j in 0..n_samples {
                if i != j {
                    let dist = self.robust_distance(&X.row(i), &X.row(j));
                    W[[i, j]] = (-dist / (2.0 * 1.0_f64.powi(2))).exp();
                }
            }
        }

        W
    }
}

impl Default for RobustGraphLearning<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for RobustGraphLearning<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for RobustGraphLearning<Untrained> {
    type Fitted = RobustGraphLearning<RobustGraphLearningTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let y = y.to_owned();

        // Collect labeled samples; -1 marks an unlabeled sample.
        let mut labeled_indices = Vec::new();
        let mut classes = std::collections::HashSet::new();

        for (i, &label) in y.iter().enumerate() {
            if label != -1 {
                labeled_indices.push(i);
                classes.insert(label);
            }
        }

        if labeled_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No labeled samples provided".to_string(),
            ));
        }

        let classes: Vec<i32> = classes.into_iter().collect();
        let n_classes = classes.len();
        let n_samples = X.nrows();

        let mut W = self.compute_robust_weights(&X);

        // Sparsify by soft-thresholding (the result is already non-negative).
        let threshold = self.lambda_sparse;
        W.mapv_inplace(|x| if x > threshold { x - threshold } else { 0.0 });

        // Remove self-loops.
        for i in 0..n_samples {
            W[[i, i]] = 0.0;
        }

        // Symmetrize.
        for i in 0..n_samples {
            for j in i + 1..n_samples {
                let avg = (W[[i, j]] + W[[j, i]]) / 2.0;
                W[[i, j]] = avg;
                W[[j, i]] = avg;
            }
        }

        // One-hot label matrix for the labeled samples.
        let mut Y = Array2::zeros((n_samples, n_classes));
        for &idx in &labeled_indices {
            if let Some(class_idx) = classes.iter().position(|&c| c == y[idx]) {
                Y[[idx, class_idx]] = 1.0;
            }
        }

        // Row-stochastic transition matrix P = D^{-1} W.
        let D = W.sum_axis(Axis(1));
        let mut P = Array2::zeros((n_samples, n_samples));
        for i in 0..n_samples {
            if D[i] > 0.0 {
                for j in 0..n_samples {
                    P[[i, j]] = W[[i, j]] / D[i];
                }
            }
        }

        let Y_static = Y.clone();

        // Damped label propagation.
        for _iter in 0..50 {
            let prev_Y = Y.clone();
            Y = 0.9 * P.dot(&Y) + 0.1 * &Y_static;

            let diff = (&Y - &prev_Y).mapv(|x| x.abs()).sum();
            if diff < 1e-6 {
                break;
            }
        }

        Ok(RobustGraphLearning {
            state: RobustGraphLearningTrained {
                X_train: X,
                y_train: y,
                classes: Array1::from(classes),
                learned_graph: W,
                label_distributions: Y,
            },
            lambda_sparse: self.lambda_sparse,
            lambda_robust: self.lambda_robust,
            max_iter: self.max_iter,
            tol: self.tol,
            robust_metric: self.robust_metric,
            huber_delta: self.huber_delta,
            tukey_c: self.tukey_c,
        })
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>>
    for RobustGraphLearning<RobustGraphLearningTrained>
{
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let mut predictions = Array1::zeros(n_test);

        for i in 0..n_test {
            let mut min_dist = f64::INFINITY;
            let mut best_idx = 0;

            for j in 0..self.state.X_train.nrows() {
                let diff = &X.row(i) - &self.state.X_train.row(j);
                let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
                if dist < min_dist {
                    min_dist = dist;
                    best_idx = j;
                }
            }

            let distributions = self.state.label_distributions.row(best_idx);
            let max_idx = distributions
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .unwrap()
                .0;

            predictions[i] = self.state.classes[max_idx];
        }

        Ok(predictions)
    }
}

impl PredictProba<ArrayView2<'_, Float>, Array2<f64>>
    for RobustGraphLearning<RobustGraphLearningTrained>
{
    #[allow(non_snake_case)]
    fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let n_classes = self.state.classes.len();
        let mut probas = Array2::zeros((n_test, n_classes));

        for i in 0..n_test {
            let mut min_dist = f64::INFINITY;
            let mut best_idx = 0;

            for j in 0..self.state.X_train.nrows() {
                let diff = &X.row(i) - &self.state.X_train.row(j);
                let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
                if dist < min_dist {
                    min_dist = dist;
                    best_idx = j;
                }
            }

            for k in 0..n_classes {
                probas[[i, k]] = self.state.label_distributions[[best_idx, k]];
            }
        }

        Ok(probas)
    }
}

/// Trained state for [`GraphStructureLearning`].
#[derive(Debug, Clone)]
pub struct GraphStructureLearningTrained {
    /// Training feature matrix.
    pub X_train: Array2<f64>,
    /// Training labels (-1 for unlabeled samples).
    pub y_train: Array1<i32>,
    /// Distinct class labels seen during fitting.
    pub classes: Array1<i32>,
    /// Learned affinity matrix.
    pub learned_graph: Array2<f64>,
    /// Propagated per-class label scores for each training sample.
    pub label_distributions: Array2<f64>,
}

/// Trained state for [`RobustGraphLearning`].
#[derive(Debug, Clone)]
pub struct RobustGraphLearningTrained {
    /// Training feature matrix.
    pub X_train: Array2<f64>,
    /// Training labels (-1 for unlabeled samples).
    pub y_train: Array1<i32>,
    /// Distinct class labels seen during fitting.
    pub classes: Array1<i32>,
    /// Learned affinity matrix.
    pub learned_graph: Array2<f64>,
    /// Propagated per-class label scores for each training sample.
    pub label_distributions: Array2<f64>,
}

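/// Distributed graph structure learning over node partitions.
///
/// Splits the samples across `n_workers` overlapping partitions, learns a
/// local kNN graph per partition, aligns shared edges over several
/// communication rounds, merges the local graphs by averaging, and finally
/// propagates labels on the merged global graph.
///
/// A minimal usage sketch (illustrative only; it assumes the builder methods
/// and the `Fit`/`Predict` impls defined in this module):
///
/// ```ignore
/// use scirs2_core::array;
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
/// let y = array![0, 1, -1, -1]; // -1 marks unlabeled samples
///
/// let model = DistributedGraphLearning::new()
///     .n_workers(2)
///     .partition_strategy("spectral".to_string())
///     .communication_rounds(5)
///     .fit(&X.view(), &y.view())?;
///
/// let labels = model.predict(&X.view())?;
/// # Ok::<(), sklears_core::error::SklearsError>(())
/// ```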
#[derive(Debug, Clone)]
pub struct DistributedGraphLearning<S = Untrained> {
    state: S,
    n_workers: usize,
    lambda_sparse: f64,
    beta_smoothness: f64,
    max_iter: usize,
    tol: f64,
    learning_rate: f64,
    partition_strategy: String,
    communication_rounds: usize,
    overlap_ratio: f64,
    consensus_weight: f64,
}

impl DistributedGraphLearning<Untrained> {
    /// Create a model with default hyperparameters.
    pub fn new() -> Self {
        Self {
            state: Untrained,
            n_workers: 2,
            lambda_sparse: 0.1,
            beta_smoothness: 1.0,
            max_iter: 100,
            tol: 1e-4,
            learning_rate: 0.01,
            partition_strategy: "spectral".to_string(),
            communication_rounds: 10,
            overlap_ratio: 0.1,
            consensus_weight: 0.5,
        }
    }

    /// Set the number of workers (partitions).
    pub fn n_workers(mut self, n_workers: usize) -> Self {
        self.n_workers = n_workers;
        self
    }

    /// Set the sparsification threshold on edge weights.
    pub fn lambda_sparse(mut self, lambda_sparse: f64) -> Self {
        self.lambda_sparse = lambda_sparse;
        self
    }

    /// Set the weight of the Laplacian smoothness term.
    pub fn beta_smoothness(mut self, beta_smoothness: f64) -> Self {
        self.beta_smoothness = beta_smoothness;
        self
    }

    /// Set the maximum number of iterations.
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the convergence tolerance.
    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Set the learning rate.
    pub fn learning_rate(mut self, learning_rate: f64) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    /// Choose the partition strategy: "random", "spectral", or contiguous
    /// chunks for any other value.
    pub fn partition_strategy(mut self, strategy: String) -> Self {
        self.partition_strategy = strategy;
        self
    }

    /// Set the number of boundary-alignment communication rounds.
    pub fn communication_rounds(mut self, rounds: usize) -> Self {
        self.communication_rounds = rounds;
        self
    }

    /// Set the fraction of a partition shared with its neighbors.
    pub fn overlap_ratio(mut self, ratio: f64) -> Self {
        self.overlap_ratio = ratio;
        self
    }

    /// Set the consensus weight (stored but not used by the current
    /// implementation).
    pub fn consensus_weight(mut self, weight: f64) -> Self {
        self.consensus_weight = weight;
        self
    }

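    /// Split node indices into `n_workers` partitions, each extended by a
    /// symmetric overlap of `overlap_ratio * nodes_per_worker` nodes.
    /// Supported strategies: "random", "spectral", and contiguous index
    /// chunks as the fallback.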
    fn partition_nodes(&self, n_samples: usize) -> Vec<Vec<usize>> {
        let nodes_per_worker = (n_samples + self.n_workers - 1) / self.n_workers;
        let overlap_size = (nodes_per_worker as f64 * self.overlap_ratio) as usize;

        let mut partitions = Vec::with_capacity(self.n_workers);

        match self.partition_strategy.as_str() {
            "random" => {
                // Shuffle node indices with a fixed seed for reproducibility.
                let mut nodes: Vec<usize> = (0..n_samples).collect();
                use scirs2_core::random::rand_prelude::SliceRandom;
                let mut rng = Random::seed(42);
                nodes.shuffle(&mut rng);

                for i in 0..self.n_workers {
                    let start = i * nodes_per_worker;
                    let end = ((i + 1) * nodes_per_worker).min(n_samples);
                    let overlap_start = start.saturating_sub(overlap_size);
                    let overlap_end = (end + overlap_size).min(n_samples);

                    let mut partition = Vec::new();
                    for j in overlap_start..overlap_end {
                        if j < nodes.len() {
                            partition.push(nodes[j]);
                        }
                    }
                    partitions.push(partition);
                }
            }
            "spectral" => {
                self.spectral_partition(n_samples, &mut partitions, nodes_per_worker, overlap_size);
            }
            _ => {
                // Fallback: contiguous index ranges.
                for i in 0..self.n_workers {
                    let start = i * nodes_per_worker;
                    let end = ((i + 1) * nodes_per_worker).min(n_samples);
                    let overlap_start = start.saturating_sub(overlap_size);
                    let overlap_end = (end + overlap_size).min(n_samples);

                    let partition: Vec<usize> = (overlap_start..overlap_end).collect();
                    partitions.push(partition);
                }
            }
        }

        partitions
    }

    /// Order nodes by distance from the index midpoint as a cheap stand-in
    /// for a true spectral (Fiedler-vector) ordering, then cut the ordering
    /// into overlapping contiguous partitions.
    fn spectral_partition(
        &self,
        n_samples: usize,
        partitions: &mut Vec<Vec<usize>>,
        nodes_per_worker: usize,
        overlap_size: usize,
    ) {
        let mut spectral_order: Vec<usize> = (0..n_samples).collect();

        let center = n_samples / 2;
        spectral_order.sort_by_key(|&i| i.abs_diff(center));

        for i in 0..self.n_workers {
            let start = i * nodes_per_worker;
            let end = ((i + 1) * nodes_per_worker).min(n_samples);
            let overlap_start = start.saturating_sub(overlap_size);
            let overlap_end = (end + overlap_size).min(n_samples);

            let mut partition = Vec::new();
            for j in overlap_start..overlap_end {
                if j < spectral_order.len() {
                    partition.push(spectral_order[j]);
                }
            }
            partitions.push(partition);
        }
    }

    #[allow(non_snake_case)]
    fn extract_subgraph(&self, X: &Array2<f64>, partition: &[usize]) -> Array2<f64> {
        let n_nodes = partition.len();
        let n_features = X.ncols();
        let mut X_sub = Array2::zeros((n_nodes, n_features));

        for (i, &node_idx) in partition.iter().enumerate() {
            if node_idx < X.nrows() {
                X_sub.row_mut(i).assign(&X.row(node_idx));
            }
        }

        X_sub
    }

    fn extract_sublabels(&self, y: &Array1<i32>, partition: &[usize]) -> Array1<i32> {
        let n_nodes = partition.len();
        let mut y_sub = Array1::from_elem(n_nodes, -1);

        for (i, &node_idx) in partition.iter().enumerate() {
            if node_idx < y.len() {
                y_sub[i] = y[node_idx];
            }
        }

        y_sub
    }

    /// Learn a local kNN graph for one partition. Labels are accepted but
    /// not used by the current local-graph construction.
    #[allow(non_snake_case)]
    fn learn_local_graph(
        &self,
        X_sub: &Array2<f64>,
        _y_sub: &Array1<i32>,
    ) -> SklResult<Array2<f64>> {
        let n_samples = X_sub.nrows();
        let mut W = Array2::zeros((n_samples, n_samples));

        let k = (n_samples as f64).sqrt().ceil() as usize;
        let k = k.clamp(3, 10);

        for i in 0..n_samples {
            let mut distances: Vec<(usize, f64)> = Vec::new();
            for j in 0..n_samples {
                if i != j {
                    let diff = &X_sub.row(i) - &X_sub.row(j);
                    let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
                    distances.push((j, dist));
                }
            }

            distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

            for &(j, dist) in distances.iter().take(k) {
                let weight = (-dist / (2.0 * 1.0_f64.powi(2))).exp();
                W[[i, j]] = weight;
                W[[j, i]] = weight;
            }
        }

        // Sparsify by soft-thresholding and remove self-loops.
        let threshold = self.lambda_sparse;
        W.mapv_inplace(|x| if x > threshold { x - threshold } else { 0.0 });

        for i in 0..n_samples {
            W[[i, i]] = 0.0;
        }

        Ok(W)
    }

    /// One communication round: for every pair of workers, average the edge
    /// weights on nodes their partitions share so overlapping subgraphs
    /// agree on their boundary.
    fn communicate_boundaries(
        &self,
        local_graphs: &[Array2<f64>],
        partitions: &[Vec<usize>],
    ) -> Vec<Array2<f64>> {
        let mut updated_graphs = local_graphs.to_vec();

        for i in 0..self.n_workers {
            for j in (i + 1)..self.n_workers {
                // Map each shared node to its local index in both partitions.
                let common_nodes: Vec<(usize, usize)> = partitions[i]
                    .iter()
                    .enumerate()
                    .filter_map(|(idx_i, &node)| {
                        partitions[j]
                            .iter()
                            .position(|&n| n == node)
                            .map(|idx_j| (idx_i, idx_j))
                    })
                    .collect();

                for &(idx_i, idx_j) in &common_nodes {
                    if idx_i < updated_graphs[i].nrows() && idx_j < updated_graphs[j].nrows() {
                        for &(other_i, other_j) in &common_nodes {
                            if other_i < updated_graphs[i].ncols()
                                && other_j < updated_graphs[j].ncols()
                            {
                                let weight_i = updated_graphs[i][[idx_i, other_i]];
                                let weight_j = updated_graphs[j][[idx_j, other_j]];
                                let avg_weight = (weight_i + weight_j) / 2.0;

                                updated_graphs[i][[idx_i, other_i]] = avg_weight;
                                updated_graphs[j][[idx_j, other_j]] = avg_weight;
                            }
                        }
                    }
                }
            }
        }

        updated_graphs
    }

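    /// Merge local graphs into one global affinity matrix by summing each
    /// edge's weight over every partition that contains it and dividing by
    /// the number of partitions that assigned it a positive weight.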
    fn merge_graphs(
        &self,
        local_graphs: &[Array2<f64>],
        partitions: &[Vec<usize>],
        n_total: usize,
    ) -> Array2<f64> {
        let mut global_graph = Array2::zeros((n_total, n_total));
        let mut weight_counts: Array2<f64> = Array2::zeros((n_total, n_total));

        for (local_graph, partition) in local_graphs.iter().zip(partitions.iter()) {
            for (i, &node_i) in partition.iter().enumerate() {
                for (j, &node_j) in partition.iter().enumerate() {
                    if i < local_graph.nrows()
                        && j < local_graph.ncols()
                        && node_i < n_total
                        && node_j < n_total
                    {
                        global_graph[[node_i, node_j]] += local_graph[[i, j]];
                        if local_graph[[i, j]] > 0.0 {
                            weight_counts[[node_i, node_j]] += 1.0;
                        }
                    }
                }
            }
        }

        // Average edges that appear in multiple partitions.
        for i in 0..n_total {
            for j in 0..n_total {
                if weight_counts[[i, j]] > 0.0 {
                    global_graph[[i, j]] /= weight_counts[[i, j]];
                }
            }
        }

        global_graph
    }
}

impl Default for DistributedGraphLearning<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for DistributedGraphLearning<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for DistributedGraphLearning<Untrained> {
    type Fitted = DistributedGraphLearning<DistributedGraphLearningTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let y = y.to_owned();
        let (n_samples, _n_features) = X.dim();

        // Collect labeled samples; -1 marks an unlabeled sample.
        let mut labeled_indices = Vec::new();
        let mut classes = std::collections::HashSet::new();

        for (i, &label) in y.iter().enumerate() {
            if label != -1 {
                labeled_indices.push(i);
                classes.insert(label);
            }
        }

        if labeled_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No labeled samples provided".to_string(),
            ));
        }

        let classes: Vec<i32> = classes.into_iter().collect();

        // Partition the nodes and learn one local graph per worker.
        let partitions = self.partition_nodes(n_samples);

        let mut local_graphs = Vec::with_capacity(self.n_workers);
        for partition in &partitions {
            let X_sub = self.extract_subgraph(&X, partition);
            let y_sub = self.extract_sublabels(&y, partition);
            let local_graph = self.learn_local_graph(&X_sub, &y_sub)?;
            local_graphs.push(local_graph);
        }

        // Align boundary edges across workers.
        for _round in 0..self.communication_rounds {
            local_graphs = self.communicate_boundaries(&local_graphs, &partitions);
        }

        let global_graph = self.merge_graphs(&local_graphs, &partitions, n_samples);

        // One-hot label matrix for the labeled samples.
        let n_classes = classes.len();
        let mut Y = Array2::zeros((n_samples, n_classes));
        for &idx in &labeled_indices {
            if let Some(class_idx) = classes.iter().position(|&c| c == y[idx]) {
                Y[[idx, class_idx]] = 1.0;
            }
        }

        // Row-stochastic transition matrix P = D^{-1} W.
        let D = global_graph.sum_axis(Axis(1));
        let mut P = Array2::zeros((n_samples, n_samples));
        for i in 0..n_samples {
            if D[i] > 0.0 {
                for j in 0..n_samples {
                    P[[i, j]] = global_graph[[i, j]] / D[i];
                }
            }
        }

        // Damped label propagation on the merged graph.
        let Y_static = Y.clone();
        for _iter in 0..50 {
            let prev_Y = Y.clone();
            Y = 0.9 * P.dot(&Y) + 0.1 * &Y_static;

            let diff = (&Y - &prev_Y).mapv(|x| x.abs()).sum();
            if diff < 1e-6 {
                break;
            }
        }

        Ok(DistributedGraphLearning {
            state: DistributedGraphLearningTrained {
                X_train: X,
                y_train: y,
                classes: Array1::from(classes),
                global_graph,
                label_distributions: Y,
                partitions,
            },
            n_workers: self.n_workers,
            lambda_sparse: self.lambda_sparse,
            beta_smoothness: self.beta_smoothness,
            max_iter: self.max_iter,
            tol: self.tol,
            learning_rate: self.learning_rate,
            partition_strategy: self.partition_strategy,
            communication_rounds: self.communication_rounds,
            overlap_ratio: self.overlap_ratio,
            consensus_weight: self.consensus_weight,
        })
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>>
    for DistributedGraphLearning<DistributedGraphLearningTrained>
{
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let mut predictions = Array1::zeros(n_test);

        for i in 0..n_test {
            let mut min_dist = f64::INFINITY;
            let mut best_idx = 0;

            for j in 0..self.state.X_train.nrows() {
                let diff = &X.row(i) - &self.state.X_train.row(j);
                let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
                if dist < min_dist {
                    min_dist = dist;
                    best_idx = j;
                }
            }

            let distributions = self.state.label_distributions.row(best_idx);
            let max_idx = distributions
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .unwrap()
                .0;

            predictions[i] = self.state.classes[max_idx];
        }

        Ok(predictions)
    }
}

impl PredictProba<ArrayView2<'_, Float>, Array2<f64>>
    for DistributedGraphLearning<DistributedGraphLearningTrained>
{
    #[allow(non_snake_case)]
    fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let n_classes = self.state.classes.len();
        let mut probas = Array2::zeros((n_test, n_classes));

        for i in 0..n_test {
            let mut min_dist = f64::INFINITY;
            let mut best_idx = 0;

            for j in 0..self.state.X_train.nrows() {
                let diff = &X.row(i) - &self.state.X_train.row(j);
                let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
                if dist < min_dist {
                    min_dist = dist;
                    best_idx = j;
                }
            }

            for k in 0..n_classes {
                probas[[i, k]] = self.state.label_distributions[[best_idx, k]];
            }
        }

        Ok(probas)
    }
}

/// Trained state for [`DistributedGraphLearning`].
#[derive(Debug, Clone)]
pub struct DistributedGraphLearningTrained {
    /// Training feature matrix.
    pub X_train: Array2<f64>,
    /// Training labels (-1 for unlabeled samples).
    pub y_train: Array1<i32>,
    /// Distinct class labels seen during fitting.
    pub classes: Array1<i32>,
    /// Merged global affinity matrix.
    pub global_graph: Array2<f64>,
    /// Propagated per-class label scores for each training sample.
    pub label_distributions: Array2<f64>,
    /// Node index partitions assigned to each worker.
    pub partitions: Vec<Vec<usize>>,
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::array;

    #[test]
    #[allow(non_snake_case)]
    fn test_graph_structure_learning() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1]; // -1 marks unlabeled samples

        let gsl = GraphStructureLearning::new()
            .lambda_sparse(0.1)
            .beta_smoothness(1.0)
            .max_iter(20);
        let fitted = gsl.fit(&X.view(), &y.view()).unwrap();

        let predictions = fitted.predict(&X.view()).unwrap();
        assert_eq!(predictions.len(), 4);

        let probas = fitted.predict_proba(&X.view()).unwrap();
        assert_eq!(probas.dim(), (4, 2));

        // The learned graph should be sparser than a fully connected one.
        let n_edges = fitted
            .state
            .learned_graph
            .iter()
            .filter(|&&x| x > 0.0)
            .count();
        let total_edges = 4 * 4 - 4; // off-diagonal entries
        assert!(n_edges < total_edges);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_robust_graph_learning() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1]; // -1 marks unlabeled samples

        let rgl = RobustGraphLearning::new()
            .lambda_sparse(0.1)
            .robust_metric("huber".to_string())
            .max_iter(20);
        let fitted = rgl.fit(&X.view(), &y.view()).unwrap();

        let predictions = fitted.predict(&X.view()).unwrap();
        assert_eq!(predictions.len(), 4);

        let probas = fitted.predict_proba(&X.view()).unwrap();
        assert_eq!(probas.dim(), (4, 2));
    }

    #[test]
    fn test_robust_distance_metrics() {
        let rgl = RobustGraphLearning::new();
        let x1 = array![1.0, 2.0];
        let x2 = array![3.0, 4.0];

        let rgl_l1 = rgl.clone().robust_metric("l1".to_string());
        let dist_l1 = rgl_l1.robust_distance(&x1.view(), &x2.view());
        assert_eq!(dist_l1, 4.0); // |1 - 3| + |2 - 4| = 4

        let rgl_huber = rgl
            .clone()
            .robust_metric("huber".to_string())
            .huber_delta(1.0);
        let dist_huber = rgl_huber.robust_distance(&x1.view(), &x2.view());
        assert!(dist_huber > 0.0);

        let rgl_tukey = rgl.robust_metric("tukey".to_string());
        let dist_tukey = rgl_tukey.robust_distance(&x1.view(), &x2.view());
        assert!(dist_tukey > 0.0);
    }

    #[test]
    fn test_soft_threshold() {
        let gsl = GraphStructureLearning::new();

        assert_eq!(gsl.soft_threshold(2.0, 1.0), 1.0);
        assert_eq!(gsl.soft_threshold(-2.0, 1.0), -1.0);
        assert_eq!(gsl.soft_threshold(0.5, 1.0), 0.0);
        assert_eq!(gsl.soft_threshold(-0.5, 1.0), 0.0);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_symmetry_enforcement() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![0, 1, -1];

        let gsl = GraphStructureLearning::new()
            .enforce_symmetry(true)
            .max_iter(5)
            .lambda_sparse(0.01);
        let fitted = gsl.fit(&X.view(), &y.view()).unwrap();

        let W = &fitted.state.learned_graph;
        let n = W.nrows();

        let mut max_asymmetry = 0.0_f64;
        for i in 0..n {
            for j in 0..n {
                let asymmetry = (W[[i, j]] - W[[j, i]]).abs();
                max_asymmetry = max_asymmetry.max(asymmetry);
            }
        }
        assert!(
            max_asymmetry < 0.5,
            "Maximum asymmetry: {} - optimization may not maintain perfect symmetry",
            max_asymmetry
        );

        // The diagonal must stay zero (no self-loops).
        for i in 0..n {
            assert_eq!(W[[i, i]], 0.0);
        }

        // All edge weights must be non-negative.
        for i in 0..n {
            for j in 0..n {
                assert!(W[[i, j]] >= 0.0);
            }
        }
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_distributed_graph_learning() {
        let X = array![
            [1.0, 2.0],
            [2.0, 3.0],
            [3.0, 4.0],
            [4.0, 5.0],
            [5.0, 6.0],
            [6.0, 7.0],
            [7.0, 8.0],
            [8.0, 9.0]
        ];
        let y = array![0, 1, -1, -1, 0, 1, -1, -1]; // -1 marks unlabeled samples

        let dgl = DistributedGraphLearning::new()
            .n_workers(2)
            .lambda_sparse(0.05)
            .communication_rounds(5)
            .partition_strategy("spectral".to_string());
        let fitted = dgl.fit(&X.view(), &y.view()).unwrap();

        let predictions = fitted.predict(&X.view()).unwrap();
        assert_eq!(predictions.len(), 8);

        let probas = fitted.predict_proba(&X.view()).unwrap();
        assert_eq!(probas.dim(), (8, 2));

        assert_eq!(fitted.state.global_graph.dim(), (8, 8));

        assert_eq!(fitted.state.partitions.len(), 2);
        assert!(!fitted.state.partitions[0].is_empty());
        assert!(!fitted.state.partitions[1].is_empty());

        // Every prediction must be one of the observed classes.
        for &pred in predictions.iter() {
            assert!(pred == 0 || pred == 1);
        }
    }

    #[test]
    fn test_distributed_graph_learning_partitioning() {
        let dgl = DistributedGraphLearning::new()
            .n_workers(3)
            .overlap_ratio(0.2);

        let partitions_default = dgl.partition_nodes(10);
        assert_eq!(partitions_default.len(), 3);

        let partitions_random = dgl
            .clone()
            .partition_strategy("random".to_string())
            .partition_nodes(10);
        assert_eq!(partitions_random.len(), 3);

        let partitions_spectral = dgl
            .clone()
            .partition_strategy("spectral".to_string())
            .partition_nodes(10);
        assert_eq!(partitions_spectral.len(), 3);

        // Together, the partitions must cover every node.
        let mut all_nodes = std::collections::HashSet::new();
        for partition in &partitions_default {
            for &node in partition {
                all_nodes.insert(node);
            }
        }
        assert_eq!(all_nodes.len(), 10);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_distributed_graph_learning_communication() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1];

        let dgl = DistributedGraphLearning::new()
            .n_workers(2)
            .communication_rounds(3)
            .overlap_ratio(0.3);

        let partitions = dgl.partition_nodes(4);
        let X_sub1 = dgl.extract_subgraph(&X, &partitions[0]);
        let X_sub2 = dgl.extract_subgraph(&X, &partitions[1]);
        let y_sub1 = dgl.extract_sublabels(&y, &partitions[0]);
        let y_sub2 = dgl.extract_sublabels(&y, &partitions[1]);

        let graph1 = dgl.learn_local_graph(&X_sub1, &y_sub1).unwrap();
        let graph2 = dgl.learn_local_graph(&X_sub2, &y_sub2).unwrap();

        let local_graphs = vec![graph1, graph2];
        let updated_graphs = dgl.communicate_boundaries(&local_graphs, &partitions);

        assert_eq!(updated_graphs.len(), 2);
        assert_eq!(updated_graphs[0].dim(), local_graphs[0].dim());
        assert_eq!(updated_graphs[1].dim(), local_graphs[1].dim());
    }
}