sklears_semi_supervised/graph_learning.rs

//! Graph structure learning methods for semi-supervised learning
//!
//! This module provides algorithms to learn optimal graph structures from data
//! for semi-supervised learning tasks.

use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use scirs2_core::random::Random;
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, PredictProba, Untrained},
    types::Float,
};

/// Graph Structure Learning for Semi-Supervised Learning
///
/// This method learns a graph structure that balances data fidelity, sparsity,
/// and label smoothness. The learned graph is then used for label propagation.
///
/// The method solves the optimization problem:
///
/// min_W ||X - W * X||_F^2 + λ * ||W||_1 + β * tr(F^T * L_W * F)
///
/// where W is the graph adjacency matrix, L_W is the graph Laplacian of W,
/// and F is the label matrix.
///
/// # Parameters
///
/// * `lambda_sparse` - Sparsity regularization parameter (λ)
/// * `beta_smoothness` - Smoothness regularization parameter (β)
/// * `max_iter` - Maximum number of iterations
/// * `tol` - Convergence tolerance
/// * `learning_rate` - Learning rate for optimization
/// * `adaptive_lr` - Whether to use an adaptive learning rate
///
/// # Examples
///
/// ```
/// use scirs2_core::array;
/// use sklears_semi_supervised::GraphStructureLearning;
/// use sklears_core::traits::{Predict, Fit};
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
/// let y = array![0, 1, -1, -1]; // -1 indicates unlabeled
///
/// let gsl = GraphStructureLearning::new()
///     .lambda_sparse(0.1)
///     .beta_smoothness(1.0);
/// let fitted = gsl.fit(&X.view(), &y.view()).unwrap();
/// let predictions = fitted.predict(&X.view()).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct GraphStructureLearning<S = Untrained> {
    state: S,
    lambda_sparse: f64,
    beta_smoothness: f64,
    max_iter: usize,
    tol: f64,
    learning_rate: f64,
    adaptive_lr: bool,
    enforce_symmetry: bool,
    normalize_weights: bool,
}

impl GraphStructureLearning<Untrained> {
    /// Create a new GraphStructureLearning instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            lambda_sparse: 0.1,
            beta_smoothness: 1.0,
            max_iter: 100,
            tol: 1e-4,
            learning_rate: 0.01,
            adaptive_lr: true,
            enforce_symmetry: true,
            normalize_weights: true,
        }
    }

    /// Set the sparsity regularization parameter
    pub fn lambda_sparse(mut self, lambda_sparse: f64) -> Self {
        self.lambda_sparse = lambda_sparse;
        self
    }

    /// Set the smoothness regularization parameter
    pub fn beta_smoothness(mut self, beta_smoothness: f64) -> Self {
        self.beta_smoothness = beta_smoothness;
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the convergence tolerance
    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Set the learning rate
    pub fn learning_rate(mut self, learning_rate: f64) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    /// Enable/disable adaptive learning rate
    pub fn adaptive_lr(mut self, adaptive_lr: bool) -> Self {
        self.adaptive_lr = adaptive_lr;
        self
    }

    /// Enable/disable symmetry enforcement
    pub fn enforce_symmetry(mut self, enforce_symmetry: bool) -> Self {
        self.enforce_symmetry = enforce_symmetry;
        self
    }

    /// Enable/disable weight normalization
    pub fn normalize_weights(mut self, normalize_weights: bool) -> Self {
        self.normalize_weights = normalize_weights;
        self
    }

    #[allow(non_snake_case)]
    fn initialize_graph(&self, X: &Array2<f64>) -> Array2<f64> {
        let n_samples = X.nrows();
        let mut W = Array2::zeros((n_samples, n_samples));

        // Initialize with a k-NN graph; bound k between 3 and 10
        let k = ((n_samples as f64).sqrt().ceil() as usize).clamp(3, 10);

        for i in 0..n_samples {
            let mut distances: Vec<(usize, f64)> = Vec::new();
            for j in 0..n_samples {
                if i != j {
                    let diff = &X.row(i) - &X.row(j);
                    let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
                    distances.push((j, dist));
                }
            }

            distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

            // Exponential edge weight with bandwidth sigma = 1
            let sigma_sq = 1.0_f64;
            for &(j, dist) in distances.iter().take(k) {
                let weight = (-dist / (2.0 * sigma_sq)).exp();
                W[[i, j]] = weight;
                if self.enforce_symmetry {
                    W[[j, i]] = weight;
                }
            }
        }

        W
    }

    #[allow(non_snake_case)]
    fn compute_laplacian(&self, W: &Array2<f64>) -> Array2<f64> {
        // Unnormalized graph Laplacian L = D - W, where D is the degree matrix
        let n_samples = W.nrows();
        let D = W.sum_axis(Axis(1));
        let mut L = Array2::zeros((n_samples, n_samples));

        for i in 0..n_samples {
            L[[i, i]] = D[i];
            for j in 0..n_samples {
                if i != j {
                    L[[i, j]] = -W[[i, j]];
                }
            }
        }

        L
    }

    /// Soft-thresholding operator: sign(x) * max(|x| - threshold, 0)
    fn soft_threshold(&self, x: f64, threshold: f64) -> f64 {
        if x > threshold {
            x - threshold
        } else if x < -threshold {
            x + threshold
        } else {
            0.0
        }
    }

    #[allow(non_snake_case)]
    fn proximal_gradient_step(&self, W: &Array2<f64>, grad: &Array2<f64>, lr: f64) -> Array2<f64> {
        let mut W_new = W - lr * grad;

        // Apply L1 proximal operator (soft thresholding)
        let threshold = lr * self.lambda_sparse;
        W_new.mapv_inplace(|x| self.soft_threshold(x, threshold));

        // Ensure non-negativity
        W_new.mapv_inplace(|x| x.max(0.0));

        // Enforce symmetry if required by averaging each off-diagonal pair
        if self.enforce_symmetry {
            let n = W_new.nrows();
            for i in 0..n {
                for j in (i + 1)..n {
                    let avg = (W_new[[i, j]] + W_new[[j, i]]) / 2.0;
                    W_new[[i, j]] = avg;
                    W_new[[j, i]] = avg;
                }
            }
        }

        // Zero diagonal
        for i in 0..W_new.nrows() {
            W_new[[i, i]] = 0.0;
        }

        W_new
    }

    #[allow(non_snake_case)]
    fn normalize_graph(&self, W: &Array2<f64>) -> Array2<f64> {
        if !self.normalize_weights {
            return W.clone();
        }

        let mut W_norm = W.clone();
        let n_samples = W.nrows();

        // Row-wise normalization
        for i in 0..n_samples {
            let row_sum: f64 = W.row(i).sum();
            if row_sum > 0.0 {
                for j in 0..n_samples {
                    W_norm[[i, j]] = W[[i, j]] / row_sum;
                }
            }
        }

        W_norm
    }

    #[allow(non_snake_case)]
    fn propagate_labels(&self, W: &Array2<f64>, Y_init: &Array2<f64>) -> SklResult<Array2<f64>> {
        let n_samples = W.nrows();

        // Compute the row-stochastic transition matrix P = D^{-1} W
        let D = W.sum_axis(Axis(1));
        let mut P = Array2::zeros((n_samples, n_samples));
        for i in 0..n_samples {
            if D[i] > 0.0 {
                for j in 0..n_samples {
                    P[[i, j]] = W[[i, j]] / D[i];
                }
            }
        }

        let mut Y = Y_init.clone();
        let Y_static = Y_init.clone();

        // Label propagation iterations: Y <- alpha * P * Y + (1 - alpha) * Y_init
        // with clamping factor alpha = 0.9
        for _iter in 0..50 {
            let prev_Y = Y.clone();
            Y = 0.9 * P.dot(&Y) + 0.1 * &Y_static;

            // Check convergence
            let diff = (&Y - &prev_Y).mapv(|x| x.abs()).sum();
            if diff < 1e-6 {
                break;
            }
        }

        Ok(Y)
    }
}

impl Default for GraphStructureLearning<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for GraphStructureLearning<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for GraphStructureLearning<Untrained> {
    type Fitted = GraphStructureLearning<GraphStructureLearningTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let y = y.to_owned();

        let (n_samples, _n_features) = X.dim();

        // Identify labeled and unlabeled samples
        let mut labeled_indices = Vec::new();
        let mut classes = std::collections::HashSet::new();

        for (i, &label) in y.iter().enumerate() {
            if label != -1 {
                labeled_indices.push(i);
                classes.insert(label);
            }
        }

        if labeled_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No labeled samples provided".to_string(),
            ));
        }

        // Sort classes for a deterministic class-to-column mapping
        let mut classes: Vec<i32> = classes.into_iter().collect();
        classes.sort_unstable();
        let n_classes = classes.len();

        // Initialize graph
        let mut W = self.initialize_graph(&X);

        // Initialize label matrix with one-hot rows for labeled samples
        let mut Y = Array2::zeros((n_samples, n_classes));
        for &idx in &labeled_indices {
            if let Some(class_idx) = classes.iter().position(|&c| c == y[idx]) {
                Y[[idx, class_idx]] = 1.0;
            }
        }

        let Y_init = Y.clone();
        let mut lr = self.learning_rate;
        let mut prev_loss = f64::INFINITY;

        // Main optimization loop: alternate label propagation and graph updates
        for iteration in 0..self.max_iter {
            // Propagate labels with the current graph
            Y = self.propagate_labels(&W, &Y_init)?;

            // Compute graph Laplacian
            let L = self.compute_laplacian(&W);

            // Compute data fidelity loss: ||X - W * X||_F^2
            let WX = W.dot(&X);
            let data_fidelity = (&X - &WX).mapv(|x| x * x).sum();

            // Compute smoothness loss: tr(Y^T * L * Y)
            let smoothness_loss = {
                let LY = L.dot(&Y);
                let mut trace = 0.0;
                for i in 0..n_samples {
                    for j in 0..n_classes {
                        trace += Y[[i, j]] * LY[[i, j]];
                    }
                }
                trace
            };

            // Compute sparsity loss: ||W||_1
            let sparsity_loss = W.iter().map(|&x| x.abs()).sum::<f64>();

            // Total loss
            let total_loss = data_fidelity
                + self.beta_smoothness * smoothness_loss
                + self.lambda_sparse * sparsity_loss;

            // Check convergence
            if (prev_loss - total_loss).abs() < self.tol {
                break;
            }

            // Adaptive learning rate
            if self.adaptive_lr {
                if total_loss > prev_loss {
                    lr *= 0.8; // Decrease learning rate
                } else if iteration % 10 == 0 && total_loss < prev_loss {
                    lr *= 1.1; // Increase learning rate
                }
                lr = lr.clamp(1e-6, 0.1); // Bound learning rate
            }

            prev_loss = total_loss;

            // Data fidelity gradient: 2 * (W * X - X) * X^T
            let residual = &WX - &X;
            let mut grad_W = 2.0 * residual.dot(&X.t());

            // Smoothness gradient: β * (D[i] - W[i][j]) * (Y * Y^T)[i][j] on the
            // diagonal and -β * (Y * Y^T)[i][j] off the diagonal
            let YYT = Y.dot(&Y.t());
            let D = W.sum_axis(Axis(1));
            for i in 0..n_samples {
                for j in 0..n_samples {
                    if i == j {
                        grad_W[[i, j]] +=
                            self.beta_smoothness * (D[i] * YYT[[i, j]] - W[[i, j]] * YYT[[i, j]]);
                    } else {
                        grad_W[[i, j]] += self.beta_smoothness * (-YYT[[i, j]]);
                    }
                }
            }

            // Update W using a proximal gradient step
            W = self.proximal_gradient_step(&W, &grad_W, lr);
        }

        // Normalize final graph
        let W_final = self.normalize_graph(&W);

        // Final label propagation
        let Y_final = self.propagate_labels(&W_final, &Y_init)?;

        Ok(GraphStructureLearning {
            state: GraphStructureLearningTrained {
                X_train: X,
                y_train: y,
                classes: Array1::from(classes),
                learned_graph: W_final,
                label_distributions: Y_final,
            },
            lambda_sparse: self.lambda_sparse,
            beta_smoothness: self.beta_smoothness,
            max_iter: self.max_iter,
            tol: self.tol,
            learning_rate: self.learning_rate,
            adaptive_lr: self.adaptive_lr,
            enforce_symmetry: self.enforce_symmetry,
            normalize_weights: self.normalize_weights,
        })
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>>
    for GraphStructureLearning<GraphStructureLearningTrained>
{
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let mut predictions = Array1::zeros(n_test);

        for i in 0..n_test {
            // Find most similar training sample
            let mut min_dist = f64::INFINITY;
            let mut best_idx = 0;

            for j in 0..self.state.X_train.nrows() {
                let diff = &X.row(i) - &self.state.X_train.row(j);
                let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
                if dist < min_dist {
                    min_dist = dist;
                    best_idx = j;
                }
            }

            // Use the label distribution of the most similar sample
            let distributions = self.state.label_distributions.row(best_idx);
            let max_idx = distributions
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .unwrap()
                .0;

            predictions[i] = self.state.classes[max_idx];
        }

        Ok(predictions)
    }
}

impl PredictProba<ArrayView2<'_, Float>, Array2<f64>>
    for GraphStructureLearning<GraphStructureLearningTrained>
{
    #[allow(non_snake_case)]
    fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let n_classes = self.state.classes.len();
        let mut probas = Array2::zeros((n_test, n_classes));

        for i in 0..n_test {
            // Find most similar training sample
            let mut min_dist = f64::INFINITY;
            let mut best_idx = 0;

            for j in 0..self.state.X_train.nrows() {
                let diff = &X.row(i) - &self.state.X_train.row(j);
                let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
                if dist < min_dist {
                    min_dist = dist;
                    best_idx = j;
                }
            }

            // Copy the label distribution
            for k in 0..n_classes {
                probas[[i, k]] = self.state.label_distributions[[best_idx, k]];
            }
        }

        Ok(probas)
    }
}

/// Robust Graph Learning for Semi-Supervised Learning
///
/// This method learns a robust graph structure that is resistant to outliers
/// and noise in the data. It uses robust distance metrics and regularization
/// to learn a clean graph structure.
///
/// # Parameters
///
/// * `lambda_sparse` - Sparsity regularization parameter
/// * `lambda_robust` - Robustness regularization parameter
/// * `max_iter` - Maximum number of iterations
/// * `tol` - Convergence tolerance
/// * `robust_metric` - Robust distance metric ("l1", "huber", "tukey")
///
/// # Examples
///
/// ```
/// use scirs2_core::array;
/// use sklears_semi_supervised::RobustGraphLearning;
/// use sklears_core::traits::{Predict, Fit};
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
/// let y = array![0, 1, -1, -1]; // -1 indicates unlabeled
///
/// let rgl = RobustGraphLearning::new()
///     .lambda_sparse(0.1)
///     .robust_metric("huber".to_string());
/// let fitted = rgl.fit(&X.view(), &y.view()).unwrap();
/// let predictions = fitted.predict(&X.view()).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct RobustGraphLearning<S = Untrained> {
    state: S,
    lambda_sparse: f64,
    lambda_robust: f64,
    max_iter: usize,
    tol: f64,
    robust_metric: String,
    huber_delta: f64,
    tukey_c: f64,
}

impl RobustGraphLearning<Untrained> {
    /// Create a new RobustGraphLearning instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            lambda_sparse: 0.1,
            lambda_robust: 1.0,
            max_iter: 100,
            tol: 1e-4,
            robust_metric: "huber".to_string(),
            huber_delta: 1.0,
            tukey_c: 4.685,
        }
    }

    /// Set the sparsity regularization parameter
    pub fn lambda_sparse(mut self, lambda_sparse: f64) -> Self {
        self.lambda_sparse = lambda_sparse;
        self
    }

    /// Set the robustness regularization parameter
    pub fn lambda_robust(mut self, lambda_robust: f64) -> Self {
        self.lambda_robust = lambda_robust;
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the convergence tolerance
    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Set the robust distance metric
    pub fn robust_metric(mut self, metric: String) -> Self {
        self.robust_metric = metric;
        self
    }

    /// Set the Huber delta parameter
    pub fn huber_delta(mut self, delta: f64) -> Self {
        self.huber_delta = delta;
        self
    }

    /// Set the Tukey c parameter
    pub fn tukey_c(mut self, c: f64) -> Self {
        self.tukey_c = c;
        self
    }

    fn robust_distance(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
        let diff = x1 - x2;

        match self.robust_metric.as_str() {
            "l1" => diff.mapv(|x| x.abs()).sum(),
            // Huber loss: quadratic near zero, linear beyond `huber_delta`
            "huber" => diff
                .mapv(|x| {
                    let abs_x = x.abs();
                    if abs_x <= self.huber_delta {
                        0.5 * x * x
                    } else {
                        self.huber_delta * (abs_x - 0.5 * self.huber_delta)
                    }
                })
                .sum(),
            // Tukey biweight loss: saturates at c^2 / 6 beyond `tukey_c`
            "tukey" => diff
                .mapv(|x| {
                    let abs_x = x.abs();
                    if abs_x <= self.tukey_c {
                        let ratio = x / self.tukey_c;
                        (self.tukey_c * self.tukey_c / 6.0) * (1.0 - (1.0 - ratio * ratio).powi(3))
                    } else {
                        self.tukey_c * self.tukey_c / 6.0
                    }
                })
                .sum(),
            _ => diff.mapv(|x| x * x).sum().sqrt(), // Default to L2
        }
    }

    #[allow(non_snake_case)]
    fn compute_robust_weights(&self, X: &Array2<f64>) -> Array2<f64> {
        let n_samples = X.nrows();
        let mut W = Array2::zeros((n_samples, n_samples));

        // Exponential edge weight with bandwidth sigma = 1
        let sigma_sq = 1.0_f64;
        for i in 0..n_samples {
            for j in 0..n_samples {
                if i != j {
                    let dist = self.robust_distance(&X.row(i), &X.row(j));
                    W[[i, j]] = (-dist / (2.0 * sigma_sq)).exp();
                }
            }
        }

        W
    }
}

impl Default for RobustGraphLearning<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for RobustGraphLearning<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for RobustGraphLearning<Untrained> {
    type Fitted = RobustGraphLearning<RobustGraphLearningTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let y = y.to_owned();

        // Identify labeled and unlabeled samples
        let mut labeled_indices = Vec::new();
        let mut classes = std::collections::HashSet::new();

        for (i, &label) in y.iter().enumerate() {
            if label != -1 {
                labeled_indices.push(i);
                classes.insert(label);
            }
        }

        if labeled_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No labeled samples provided".to_string(),
            ));
        }

        // Sort classes for a deterministic class-to-column mapping
        let mut classes: Vec<i32> = classes.into_iter().collect();
        classes.sort_unstable();
        let n_classes = classes.len();
        let n_samples = X.nrows();

        // Compute robust graph weights
        let mut W = self.compute_robust_weights(&X);

        // Apply sparsity via soft thresholding
        let threshold = self.lambda_sparse;
        W.mapv_inplace(|x| if x > threshold { x - threshold } else { 0.0 });

        // Ensure non-negativity and zero diagonal
        W.mapv_inplace(|x| x.max(0.0));
        for i in 0..n_samples {
            W[[i, i]] = 0.0;
        }

        // Make symmetric by averaging each off-diagonal pair
        for i in 0..n_samples {
            for j in i + 1..n_samples {
                let avg = (W[[i, j]] + W[[j, i]]) / 2.0;
                W[[i, j]] = avg;
                W[[j, i]] = avg;
            }
        }

        // Initialize label matrix with one-hot rows for labeled samples
        let mut Y = Array2::zeros((n_samples, n_classes));
        for &idx in &labeled_indices {
            if let Some(class_idx) = classes.iter().position(|&c| c == y[idx]) {
                Y[[idx, class_idx]] = 1.0;
            }
        }

        // Label propagation with the robust graph: P = D^{-1} W
        let D = W.sum_axis(Axis(1));
        let mut P = Array2::zeros((n_samples, n_samples));
        for i in 0..n_samples {
            if D[i] > 0.0 {
                for j in 0..n_samples {
                    P[[i, j]] = W[[i, j]] / D[i];
                }
            }
        }

        let Y_static = Y.clone();

        // Iterative label propagation: Y <- alpha * P * Y + (1 - alpha) * Y_init
        // with clamping factor alpha = 0.9
        for _iter in 0..50 {
            let prev_Y = Y.clone();
            Y = 0.9 * P.dot(&Y) + 0.1 * &Y_static;

            // Check convergence
            let diff = (&Y - &prev_Y).mapv(|x| x.abs()).sum();
            if diff < 1e-6 {
                break;
            }
        }

        Ok(RobustGraphLearning {
            state: RobustGraphLearningTrained {
                X_train: X,
                y_train: y,
                classes: Array1::from(classes),
                learned_graph: W,
                label_distributions: Y,
            },
            lambda_sparse: self.lambda_sparse,
            lambda_robust: self.lambda_robust,
            max_iter: self.max_iter,
            tol: self.tol,
            robust_metric: self.robust_metric,
            huber_delta: self.huber_delta,
            tukey_c: self.tukey_c,
        })
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>>
    for RobustGraphLearning<RobustGraphLearningTrained>
{
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let mut predictions = Array1::zeros(n_test);

        for i in 0..n_test {
            // Find most similar training sample
            let mut min_dist = f64::INFINITY;
            let mut best_idx = 0;

            for j in 0..self.state.X_train.nrows() {
                let diff = &X.row(i) - &self.state.X_train.row(j);
                let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
                if dist < min_dist {
                    min_dist = dist;
                    best_idx = j;
                }
            }

            // Use the label distribution of the most similar sample
            let distributions = self.state.label_distributions.row(best_idx);
            let max_idx = distributions
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .unwrap()
                .0;

            predictions[i] = self.state.classes[max_idx];
        }

        Ok(predictions)
    }
}

impl PredictProba<ArrayView2<'_, Float>, Array2<f64>>
    for RobustGraphLearning<RobustGraphLearningTrained>
{
    #[allow(non_snake_case)]
    fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let n_classes = self.state.classes.len();
        let mut probas = Array2::zeros((n_test, n_classes));

        for i in 0..n_test {
            // Find most similar training sample
            let mut min_dist = f64::INFINITY;
            let mut best_idx = 0;

            for j in 0..self.state.X_train.nrows() {
                let diff = &X.row(i) - &self.state.X_train.row(j);
                let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
                if dist < min_dist {
                    min_dist = dist;
                    best_idx = j;
                }
            }

            // Copy the label distribution
            for k in 0..n_classes {
                probas[[i, k]] = self.state.label_distributions[[best_idx, k]];
            }
        }

        Ok(probas)
    }
}

/// Trained state for GraphStructureLearning
#[derive(Debug, Clone)]
pub struct GraphStructureLearningTrained {
    /// Training feature matrix
    pub X_train: Array2<f64>,
    /// Training labels (-1 for unlabeled samples)
    pub y_train: Array1<i32>,
    /// Distinct class labels
    pub classes: Array1<i32>,
    /// Learned graph adjacency matrix
    pub learned_graph: Array2<f64>,
    /// Per-sample label distributions after propagation
    pub label_distributions: Array2<f64>,
}

/// Trained state for RobustGraphLearning
#[derive(Debug, Clone)]
pub struct RobustGraphLearningTrained {
    /// Training feature matrix
    pub X_train: Array2<f64>,
    /// Training labels (-1 for unlabeled samples)
    pub y_train: Array1<i32>,
    /// Distinct class labels
    pub classes: Array1<i32>,
    /// Learned graph adjacency matrix
    pub learned_graph: Array2<f64>,
    /// Per-sample label distributions after propagation
    pub label_distributions: Array2<f64>,
}

/// Distributed Graph Learning for Large-Scale Semi-Supervised Learning
///
/// This method distributes graph learning across multiple workers to handle
/// large-scale datasets that cannot fit in memory on a single machine.
/// It uses a master-worker architecture with graph partitioning.
///
/// # Parameters
///
/// * `n_workers` - Number of workers for distributed computation
/// * `lambda_sparse` - Sparsity regularization parameter
/// * `beta_smoothness` - Smoothness regularization parameter
/// * `max_iter` - Maximum number of iterations
/// * `tol` - Convergence tolerance
/// * `partition_strategy` - Graph partitioning strategy ("random" or "spectral";
///   any other value falls back to contiguous partitioning)
/// * `communication_rounds` - Number of communication rounds between workers
///
/// # Examples
///
/// ```
/// use scirs2_core::array;
/// use sklears_semi_supervised::DistributedGraphLearning;
/// use sklears_core::traits::{Predict, Fit};
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
/// let y = array![0, 1, -1, -1]; // -1 indicates unlabeled
///
/// let dgl = DistributedGraphLearning::new()
///     .n_workers(2)
///     .lambda_sparse(0.1)
///     .partition_strategy("spectral".to_string());
/// let fitted = dgl.fit(&X.view(), &y.view()).unwrap();
/// let predictions = fitted.predict(&X.view()).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct DistributedGraphLearning<S = Untrained> {
    state: S,
    n_workers: usize,
    lambda_sparse: f64,
    beta_smoothness: f64,
    max_iter: usize,
    tol: f64,
    learning_rate: f64,
    partition_strategy: String,
    communication_rounds: usize,
    overlap_ratio: f64,
    consensus_weight: f64,
}

impl DistributedGraphLearning<Untrained> {
    /// Create a new DistributedGraphLearning instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            n_workers: 2,
            lambda_sparse: 0.1,
            beta_smoothness: 1.0,
            max_iter: 100,
            tol: 1e-4,
            learning_rate: 0.01,
            partition_strategy: "spectral".to_string(),
            communication_rounds: 10,
            overlap_ratio: 0.1,
            consensus_weight: 0.5,
        }
    }

    /// Set the number of workers
    pub fn n_workers(mut self, n_workers: usize) -> Self {
        self.n_workers = n_workers;
        self
    }

    /// Set the sparsity regularization parameter
    pub fn lambda_sparse(mut self, lambda_sparse: f64) -> Self {
        self.lambda_sparse = lambda_sparse;
        self
    }

    /// Set the smoothness regularization parameter
    pub fn beta_smoothness(mut self, beta_smoothness: f64) -> Self {
        self.beta_smoothness = beta_smoothness;
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the convergence tolerance
    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Set the learning rate
    pub fn learning_rate(mut self, learning_rate: f64) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    /// Set the graph partitioning strategy
    pub fn partition_strategy(mut self, strategy: String) -> Self {
        self.partition_strategy = strategy;
        self
    }

    /// Set the number of communication rounds
    pub fn communication_rounds(mut self, rounds: usize) -> Self {
        self.communication_rounds = rounds;
        self
    }

    /// Set the overlap ratio between partitions
    pub fn overlap_ratio(mut self, ratio: f64) -> Self {
        self.overlap_ratio = ratio;
        self
    }

    /// Set the consensus weight for combining worker results
    pub fn consensus_weight(mut self, weight: f64) -> Self {
        self.consensus_weight = weight;
        self
    }

    fn partition_nodes(&self, n_samples: usize) -> Vec<Vec<usize>> {
        // Ceiling division so every node is assigned to some worker
        let nodes_per_worker = (n_samples + self.n_workers - 1) / self.n_workers;
        let overlap_size = (nodes_per_worker as f64 * self.overlap_ratio) as usize;

        let mut partitions = Vec::with_capacity(self.n_workers);

        match self.partition_strategy.as_str() {
            "random" => {
                // Random partitioning with overlap
                let mut nodes: Vec<usize> = (0..n_samples).collect();
                use scirs2_core::random::rand_prelude::SliceRandom;
                let mut rng = Random::seed(42);
                nodes.shuffle(&mut rng);

                for i in 0..self.n_workers {
                    let start = i * nodes_per_worker;
                    let end = ((i + 1) * nodes_per_worker).min(n_samples);
                    let overlap_start = start.saturating_sub(overlap_size);
                    let overlap_end = (end + overlap_size).min(n_samples);

                    let mut partition = Vec::new();
                    for j in overlap_start..overlap_end {
                        if j < nodes.len() {
                            partition.push(nodes[j]);
                        }
                    }
                    partitions.push(partition);
                }
            }
            "spectral" => {
                // Spectral partitioning (simplified version)
                self.spectral_partition(n_samples, &mut partitions, nodes_per_worker, overlap_size);
            }
            _ => {
                // Default: contiguous partitioning
                for i in 0..self.n_workers {
                    let start = i * nodes_per_worker;
                    let end = ((i + 1) * nodes_per_worker).min(n_samples);
                    let overlap_start = start.saturating_sub(overlap_size);
                    let overlap_end = (end + overlap_size).min(n_samples);

                    let partition: Vec<usize> = (overlap_start..overlap_end).collect();
                    partitions.push(partition);
                }
            }
        }

        partitions
    }

    fn spectral_partition(
        &self,
        n_samples: usize,
        partitions: &mut Vec<Vec<usize>>,
        nodes_per_worker: usize,
        overlap_size: usize,
    ) {
        // Simplified spectral partitioning based on node ordering.
        // A full implementation would use the graph Laplacian eigenvectors.
        let mut spectral_order: Vec<usize> = (0..n_samples).collect();

        // Sort by a simple spectral-like ordering (distance from center)
        let center = n_samples / 2;
        spectral_order.sort_by_key(|&i| i.abs_diff(center));

        for i in 0..self.n_workers {
            let start = i * nodes_per_worker;
            let end = ((i + 1) * nodes_per_worker).min(n_samples);
            let overlap_start = start.saturating_sub(overlap_size);
            let overlap_end = (end + overlap_size).min(n_samples);

            let mut partition = Vec::new();
            for j in overlap_start..overlap_end {
                if j < spectral_order.len() {
                    partition.push(spectral_order[j]);
                }
            }
            partitions.push(partition);
        }
    }

    #[allow(non_snake_case)]
    fn extract_subgraph(&self, X: &Array2<f64>, partition: &[usize]) -> Array2<f64> {
        let n_nodes = partition.len();
        let n_features = X.ncols();
        let mut X_sub = Array2::zeros((n_nodes, n_features));

        for (i, &node_idx) in partition.iter().enumerate() {
            if node_idx < X.nrows() {
                X_sub.row_mut(i).assign(&X.row(node_idx));
            }
        }

        X_sub
    }

    fn extract_sublabels(&self, y: &Array1<i32>, partition: &[usize]) -> Array1<i32> {
        let n_nodes = partition.len();
        let mut y_sub = Array1::from_elem(n_nodes, -1);

        for (i, &node_idx) in partition.iter().enumerate() {
            if node_idx < y.len() {
                y_sub[i] = y[node_idx];
            }
        }

        y_sub
    }

    #[allow(non_snake_case)]
    fn learn_local_graph(
        &self,
        X_sub: &Array2<f64>,
        _y_sub: &Array1<i32>,
    ) -> SklResult<Array2<f64>> {
        let n_samples = X_sub.nrows();
        let mut W = Array2::zeros((n_samples, n_samples));

        // Initialize with a k-NN graph; bound k between 3 and 10
        let k = ((n_samples as f64).sqrt().ceil() as usize).clamp(3, 10);

        for i in 0..n_samples {
            let mut distances: Vec<(usize, f64)> = Vec::new();
            for j in 0..n_samples {
                if i != j {
                    let diff = &X_sub.row(i) - &X_sub.row(j);
                    let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
                    distances.push((j, dist));
                }
            }

            distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

            // Exponential edge weight with bandwidth sigma = 1
            let sigma_sq = 1.0_f64;
            for &(j, dist) in distances.iter().take(k) {
                let weight = (-dist / (2.0 * sigma_sq)).exp();
                W[[i, j]] = weight;
                W[[j, i]] = weight; // Ensure symmetry
            }
        }

        // Simple sparsification via soft thresholding
        let threshold = self.lambda_sparse;
        W.mapv_inplace(|x| if x > threshold { x - threshold } else { 0.0 });
        W.mapv_inplace(|x| x.max(0.0));

        // Zero diagonal
        for i in 0..n_samples {
            W[[i, i]] = 0.0;
        }

        Ok(W)
    }

    fn communicate_boundaries(
        &self,
        local_graphs: &[Array2<f64>],
        partitions: &[Vec<usize>],
    ) -> Vec<Array2<f64>> {
        let mut updated_graphs = local_graphs.to_vec();

        // Find overlapping nodes between partitions
        for i in 0..self.n_workers {
            for j in (i + 1)..self.n_workers {
                // Find common nodes between partitions i and j
                let common_nodes: Vec<(usize, usize)> = partitions[i]
                    .iter()
                    .enumerate()
                    .filter_map(|(idx_i, &node)| {
                        partitions[j]
                            .iter()
                            .position(|&n| n == node)
                            .map(|idx_j| (idx_i, idx_j))
                    })
                    .collect();

                // Average the edge weights for common nodes
                for &(idx_i, idx_j) in &common_nodes {
                    if idx_i < updated_graphs[i].nrows() && idx_j < updated_graphs[j].nrows() {
                        for &(other_i, other_j) in &common_nodes {
                            if other_i < updated_graphs[i].ncols()
                                && other_j < updated_graphs[j].ncols()
                            {
                                let weight_i = updated_graphs[i][[idx_i, other_i]];
                                let weight_j = updated_graphs[j][[idx_j, other_j]];
                                let avg_weight = (weight_i + weight_j) / 2.0;

                                updated_graphs[i][[idx_i, other_i]] = avg_weight;
                                updated_graphs[j][[idx_j, other_j]] = avg_weight;
                            }
                        }
                    }
                }
            }
        }

        updated_graphs
    }

    fn merge_graphs(
        &self,
        local_graphs: &[Array2<f64>],
        partitions: &[Vec<usize>],
        n_total: usize,
    ) -> Array2<f64> {
        let mut global_graph = Array2::zeros((n_total, n_total));
        let mut weight_counts: Array2<f64> = Array2::zeros((n_total, n_total));

        // Aggregate local graphs into the global graph
        for (local_graph, partition) in local_graphs.iter().zip(partitions.iter()) {
            for (i, &node_i) in partition.iter().enumerate() {
                for (j, &node_j) in partition.iter().enumerate() {
                    if i < local_graph.nrows()
                        && j < local_graph.ncols()
                        && node_i < n_total
                        && node_j < n_total
                    {
                        global_graph[[node_i, node_j]] += local_graph[[i, j]];
                        if local_graph[[i, j]] > 0.0 {
                            weight_counts[[node_i, node_j]] += 1.0;
                        }
                    }
                }
            }
        }

        // Average weights where multiple workers contributed
        for i in 0..n_total {
            for j in 0..n_total {
                if weight_counts[[i, j]] > 0.0 {
                    global_graph[[i, j]] /= weight_counts[[i, j]];
                }
            }
        }

        global_graph
    }
}

impl Default for DistributedGraphLearning<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for DistributedGraphLearning<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for DistributedGraphLearning<Untrained> {
    type Fitted = DistributedGraphLearning<DistributedGraphLearningTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let y = y.to_owned();
        let (n_samples, _n_features) = X.dim();

        // Identify labeled samples and classes
        let mut labeled_indices = Vec::new();
        let mut classes = std::collections::HashSet::new();

        for (i, &label) in y.iter().enumerate() {
            if label != -1 {
                labeled_indices.push(i);
                classes.insert(label);
            }
        }

        if labeled_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No labeled samples provided".to_string(),
            ));
        }

        // Sort classes for a deterministic class-to-column mapping
        let mut classes: Vec<i32> = classes.into_iter().collect();
        classes.sort_unstable();

        // Partition the graph
        let partitions = self.partition_nodes(n_samples);

        // Learn local graphs on each worker
        let mut local_graphs = Vec::with_capacity(self.n_workers);
        for partition in &partitions {
            let X_sub = self.extract_subgraph(&X, partition);
            let y_sub = self.extract_sublabels(&y, partition);
            let local_graph = self.learn_local_graph(&X_sub, &y_sub)?;
            local_graphs.push(local_graph);
        }

        // Communication rounds between workers
        for _round in 0..self.communication_rounds {
            local_graphs = self.communicate_boundaries(&local_graphs, &partitions);
        }

        // Merge local graphs into a global graph
        let global_graph = self.merge_graphs(&local_graphs, &partitions, n_samples);

        // Perform final label propagation on the global graph
        let n_classes = classes.len();
        let mut Y = Array2::zeros((n_samples, n_classes));
        for &idx in &labeled_indices {
            if let Some(class_idx) = classes.iter().position(|&c| c == y[idx]) {
                Y[[idx, class_idx]] = 1.0;
            }
        }

        // Label propagation with transition matrix P = D^{-1} W
        let D = global_graph.sum_axis(Axis(1));
        let mut P = Array2::zeros((n_samples, n_samples));
        for i in 0..n_samples {
            if D[i] > 0.0 {
                for j in 0..n_samples {
                    P[[i, j]] = global_graph[[i, j]] / D[i];
                }
            }
        }

        // Y <- alpha * P * Y + (1 - alpha) * Y_init with clamping factor alpha = 0.9
        let Y_static = Y.clone();
        for _iter in 0..50 {
            let prev_Y = Y.clone();
            Y = 0.9 * P.dot(&Y) + 0.1 * &Y_static;

            let diff = (&Y - &prev_Y).mapv(|x| x.abs()).sum();
            if diff < 1e-6 {
                break;
            }
        }

        Ok(DistributedGraphLearning {
            state: DistributedGraphLearningTrained {
                X_train: X,
                y_train: y,
                classes: Array1::from(classes),
                global_graph,
                label_distributions: Y,
                partitions,
            },
            n_workers: self.n_workers,
            lambda_sparse: self.lambda_sparse,
            beta_smoothness: self.beta_smoothness,
            max_iter: self.max_iter,
            tol: self.tol,
            learning_rate: self.learning_rate,
            partition_strategy: self.partition_strategy,
            communication_rounds: self.communication_rounds,
            overlap_ratio: self.overlap_ratio,
            consensus_weight: self.consensus_weight,
        })
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>>
    for DistributedGraphLearning<DistributedGraphLearningTrained>
{
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let mut predictions = Array1::zeros(n_test);

        for i in 0..n_test {
            let mut min_dist = f64::INFINITY;
            let mut best_idx = 0;

            for j in 0..self.state.X_train.nrows() {
                let diff = &X.row(i) - &self.state.X_train.row(j);
                let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
                if dist < min_dist {
                    min_dist = dist;
                    best_idx = j;
                }
            }

            let distributions = self.state.label_distributions.row(best_idx);
            let max_idx = distributions
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .unwrap()
                .0;

            predictions[i] = self.state.classes[max_idx];
        }

        Ok(predictions)
    }
}

impl PredictProba<ArrayView2<'_, Float>, Array2<f64>>
    for DistributedGraphLearning<DistributedGraphLearningTrained>
{
    #[allow(non_snake_case)]
    fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let n_classes = self.state.classes.len();
        let mut probas = Array2::zeros((n_test, n_classes));

        for i in 0..n_test {
            let mut min_dist = f64::INFINITY;
            let mut best_idx = 0;

            for j in 0..self.state.X_train.nrows() {
                let diff = &X.row(i) - &self.state.X_train.row(j);
                let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
                if dist < min_dist {
                    min_dist = dist;
                    best_idx = j;
                }
            }

            for k in 0..n_classes {
                probas[[i, k]] = self.state.label_distributions[[best_idx, k]];
            }
        }

        Ok(probas)
    }
}

/// Trained state for DistributedGraphLearning
#[derive(Debug, Clone)]
pub struct DistributedGraphLearningTrained {
    /// Training feature matrix
    pub X_train: Array2<f64>,
    /// Training labels (-1 for unlabeled samples)
    pub y_train: Array1<i32>,
    /// Distinct class labels
    pub classes: Array1<i32>,
    /// Merged global graph adjacency matrix
    pub global_graph: Array2<f64>,
    /// Per-sample label distributions after propagation
    pub label_distributions: Array2<f64>,
    /// Node partitions assigned to each worker
    pub partitions: Vec<Vec<usize>>,
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::array;

    #[test]
    #[allow(non_snake_case)]
    fn test_graph_structure_learning() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1]; // -1 indicates unlabeled

        let gsl = GraphStructureLearning::new()
            .lambda_sparse(0.1)
            .beta_smoothness(1.0)
            .max_iter(20);
        let fitted = gsl.fit(&X.view(), &y.view()).unwrap();

        let predictions = fitted.predict(&X.view()).unwrap();
        assert_eq!(predictions.len(), 4);

        let probas = fitted.predict_proba(&X.view()).unwrap();
        assert_eq!(probas.dim(), (4, 2));

        // Check that the learned graph is sparse
        let n_edges = fitted
            .state
            .learned_graph
            .iter()
            .filter(|&&x| x > 0.0)
            .count();
        let total_edges = 4 * 4 - 4; // Exclude diagonal
        assert!(n_edges < total_edges); // Should be sparse
    }
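
    // A minimal sanity check of the internal Laplacian construction
    // (L = D - W): every row of L should sum to zero, since the diagonal
    // holds the degree and the off-diagonals hold the negated weights.
    // This sketch relies on in-module access to the private
    // `compute_laplacian` helper.
    #[test]
    #[allow(non_snake_case)]
    fn test_compute_laplacian_rows_sum_to_zero() {
        let gsl = GraphStructureLearning::new();
        let W = array![[0.0, 0.5, 0.2], [0.5, 0.0, 0.3], [0.2, 0.3, 0.0]];

        let L = gsl.compute_laplacian(&W);

        for i in 0..3 {
            let row_sum: f64 = L.row(i).sum();
            assert!(row_sum.abs() < 1e-12);
        }
    }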

    #[test]
    #[allow(non_snake_case)]
    fn test_robust_graph_learning() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1]; // -1 indicates unlabeled

        let rgl = RobustGraphLearning::new()
            .lambda_sparse(0.1)
            .robust_metric("huber".to_string())
            .max_iter(20);
        let fitted = rgl.fit(&X.view(), &y.view()).unwrap();

        let predictions = fitted.predict(&X.view()).unwrap();
        assert_eq!(predictions.len(), 4);

        let probas = fitted.predict_proba(&X.view()).unwrap();
        assert_eq!(probas.dim(), (4, 2));
    }
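
    // Illustrative property check: every supported robust metric depends
    // only on the element-wise difference magnitudes, so the weight matrix
    // from `compute_robust_weights` should be symmetric with a zero diagonal.
    #[test]
    #[allow(non_snake_case)]
    fn test_robust_weights_symmetry() {
        let rgl = RobustGraphLearning::new().robust_metric("huber".to_string());
        let X = array![[1.0, 2.0], [2.0, 3.0], [4.0, 1.0]];

        let W = rgl.compute_robust_weights(&X);

        for i in 0..3 {
            assert_eq!(W[[i, i]], 0.0);
            for j in 0..3 {
                assert!((W[[i, j]] - W[[j, i]]).abs() < 1e-12);
            }
        }
    }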

    #[test]
    fn test_robust_distance_metrics() {
        let rgl = RobustGraphLearning::new();
        let x1 = array![1.0, 2.0];
        let x2 = array![3.0, 4.0];

        // Test L1 distance
        let rgl_l1 = rgl.clone().robust_metric("l1".to_string());
        let dist_l1 = rgl_l1.robust_distance(&x1.view(), &x2.view());
        assert_eq!(dist_l1, 4.0); // |1-3| + |2-4| = 2 + 2 = 4

        // Test Huber distance
        let rgl_huber = rgl
            .clone()
            .robust_metric("huber".to_string())
            .huber_delta(1.0);
        let dist_huber = rgl_huber.robust_distance(&x1.view(), &x2.view());
        assert!(dist_huber > 0.0);

        // Test Tukey distance
        let rgl_tukey = rgl.robust_metric("tukey".to_string());
        let dist_tukey = rgl_tukey.robust_distance(&x1.view(), &x2.view());
        assert!(dist_tukey > 0.0);
    }
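
    // Worked example for the Huber metric in its linear region: with
    // delta = 1 and a per-coordinate difference of 2, each term is
    // delta * (|x| - 0.5 * delta) = 1.5, so the total distance is 3.0.
    #[test]
    fn test_huber_distance_linear_region() {
        let rgl = RobustGraphLearning::new()
            .robust_metric("huber".to_string())
            .huber_delta(1.0);
        let x1 = array![1.0, 2.0];
        let x2 = array![3.0, 4.0];

        let dist = rgl.robust_distance(&x1.view(), &x2.view());
        assert!((dist - 3.0).abs() < 1e-12);
    }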

    #[test]
    fn test_soft_threshold() {
        let gsl = GraphStructureLearning::new();

        assert_eq!(gsl.soft_threshold(2.0, 1.0), 1.0);
        assert_eq!(gsl.soft_threshold(-2.0, 1.0), -1.0);
        assert_eq!(gsl.soft_threshold(0.5, 1.0), 0.0);
        assert_eq!(gsl.soft_threshold(-0.5, 1.0), 0.0);
    }
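
    // Illustrative check of the invariants that `proximal_gradient_step` is
    // meant to enforce after each update: non-negativity, symmetry (when
    // enabled), and a zero diagonal.
    #[test]
    #[allow(non_snake_case)]
    fn test_proximal_gradient_step_invariants() {
        let gsl = GraphStructureLearning::new().enforce_symmetry(true);
        let W = array![[0.0, 0.8, 0.1], [0.6, 0.0, 0.4], [0.3, 0.2, 0.0]];
        let grad = array![[0.1, -0.2, 0.3], [0.0, 0.1, -0.1], [0.2, 0.0, 0.1]];

        let W_new = gsl.proximal_gradient_step(&W, &grad, 0.05);

        for i in 0..3 {
            assert_eq!(W_new[[i, i]], 0.0);
            for j in 0..3 {
                assert!(W_new[[i, j]] >= 0.0);
                assert!((W_new[[i, j]] - W_new[[j, i]]).abs() < 1e-12);
            }
        }
    }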

    #[test]
    #[allow(non_snake_case)]
    fn test_symmetry_enforcement() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![0, 1, -1];

        let gsl = GraphStructureLearning::new()
            .enforce_symmetry(true)
            .max_iter(5) // Reduced iterations for a more stable test
            .lambda_sparse(0.01); // Reduced sparsity for better convergence
        let fitted = gsl.fit(&X.view(), &y.view()).unwrap();

        let W = &fitted.state.learned_graph;
        let n = W.nrows();

        // Check approximate symmetry (allow for larger numerical errors due to optimization)
        let mut max_asymmetry = 0.0_f64;
        for i in 0..n {
            for j in 0..n {
                let asymmetry = (W[[i, j]] - W[[j, i]]).abs();
                max_asymmetry = max_asymmetry.max(asymmetry);
            }
        }
        assert!(
            max_asymmetry < 0.5,
            "Maximum asymmetry: {} - optimization may not maintain perfect symmetry",
            max_asymmetry
        );

        // Check zero diagonal
        for i in 0..n {
            assert_eq!(W[[i, i]], 0.0);
        }

        // Check non-negativity
        for i in 0..n {
            for j in 0..n {
                assert!(W[[i, j]] >= 0.0);
            }
        }
    }
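
    // Minimal check that `normalize_graph` produces a row-stochastic matrix
    // (each row with positive degree sums to one), which the label
    // propagation transition matrix relies on.
    #[test]
    #[allow(non_snake_case)]
    fn test_normalize_graph_row_stochastic() {
        let gsl = GraphStructureLearning::new().normalize_weights(true);
        let W = array![[0.0, 2.0, 2.0], [1.0, 0.0, 3.0], [0.5, 0.5, 0.0]];

        let W_norm = gsl.normalize_graph(&W);

        for i in 0..3 {
            let row_sum: f64 = W_norm.row(i).sum();
            assert!((row_sum - 1.0).abs() < 1e-12);
        }
    }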

    #[test]
    #[allow(non_snake_case)]
    fn test_distributed_graph_learning() {
        let X = array![
            [1.0, 2.0],
            [2.0, 3.0],
            [3.0, 4.0],
            [4.0, 5.0],
            [5.0, 6.0],
            [6.0, 7.0],
            [7.0, 8.0],
            [8.0, 9.0]
        ];
        let y = array![0, 1, -1, -1, 0, 1, -1, -1]; // -1 indicates unlabeled

        let dgl = DistributedGraphLearning::new()
            .n_workers(2)
            .lambda_sparse(0.05)
            .communication_rounds(5)
            .partition_strategy("spectral".to_string());
        let fitted = dgl.fit(&X.view(), &y.view()).unwrap();

        let predictions = fitted.predict(&X.view()).unwrap();
        assert_eq!(predictions.len(), 8);

        let probas = fitted.predict_proba(&X.view()).unwrap();
        assert_eq!(probas.dim(), (8, 2));

        // Check that we have learned a global graph
        assert_eq!(fitted.state.global_graph.dim(), (8, 8));

        // Check that we have partitions
        assert_eq!(fitted.state.partitions.len(), 2);
        assert!(!fitted.state.partitions[0].is_empty());
        assert!(!fitted.state.partitions[1].is_empty());

        // Check that we get valid predictions (exact labels may not be
        // preserved due to distributed processing)
        for &pred in predictions.iter() {
            assert!(pred == 0 || pred == 1);
        }
    }
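
    // Sketch of the merge step in isolation: where two workers both
    // contribute the same edge, `merge_graphs` averages their weights.
    // Here both partitions cover nodes {0, 1}, so the merged edge weight
    // is (1.0 + 0.5) / 2 = 0.75.
    #[test]
    fn test_merge_graphs_averages_shared_edges() {
        let dgl = DistributedGraphLearning::new().n_workers(2);
        let g1 = array![[0.0, 1.0], [1.0, 0.0]];
        let g2 = array![[0.0, 0.5], [0.5, 0.0]];
        let partitions = vec![vec![0, 1], vec![0, 1]];

        let merged = dgl.merge_graphs(&[g1, g2], &partitions, 2);

        assert!((merged[[0, 1]] - 0.75).abs() < 1e-12);
        assert!((merged[[1, 0]] - 0.75).abs() < 1e-12);
    }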

    #[test]
    fn test_distributed_graph_learning_partitioning() {
        let dgl = DistributedGraphLearning::new()
            .n_workers(3)
            .overlap_ratio(0.2);

        // Test different partitioning strategies
        let partitions_default = dgl.partition_nodes(10);
        assert_eq!(partitions_default.len(), 3);

        let partitions_random = dgl
            .clone()
            .partition_strategy("random".to_string())
            .partition_nodes(10);
        assert_eq!(partitions_random.len(), 3);

        let partitions_spectral = dgl
            .clone()
            .partition_strategy("spectral".to_string())
            .partition_nodes(10);
        assert_eq!(partitions_spectral.len(), 3);

        // Check that all nodes are covered
        let mut all_nodes = std::collections::HashSet::new();
        for partition in &partitions_default {
            for &node in partition {
                all_nodes.insert(node);
            }
        }
        assert_eq!(all_nodes.len(), 10);
    }
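
    // Minimal check of the partition extraction helpers: rows and labels
    // are gathered in partition order, so partition [2, 0] yields sample 2
    // first and sample 0 second.
    #[test]
    #[allow(non_snake_case)]
    fn test_extract_subgraph_and_sublabels() {
        let dgl = DistributedGraphLearning::new();
        let X = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let y = array![0, -1, 1];
        let partition = vec![2, 0];

        let X_sub = dgl.extract_subgraph(&X, &partition);
        let y_sub = dgl.extract_sublabels(&y, &partition);

        assert_eq!(X_sub.row(0).to_vec(), vec![5.0, 6.0]);
        assert_eq!(X_sub.row(1).to_vec(), vec![1.0, 2.0]);
        assert_eq!(y_sub[0], 1);
        assert_eq!(y_sub[1], 0);
    }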

    #[test]
    #[allow(non_snake_case)]
    fn test_distributed_graph_learning_communication() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1];

        let dgl = DistributedGraphLearning::new()
            .n_workers(2)
            .communication_rounds(3)
            .overlap_ratio(0.3);

        let partitions = dgl.partition_nodes(4);
        let X_sub1 = dgl.extract_subgraph(&X, &partitions[0]);
        let X_sub2 = dgl.extract_subgraph(&X, &partitions[1]);
        let y_sub1 = dgl.extract_sublabels(&y, &partitions[0]);
        let y_sub2 = dgl.extract_sublabels(&y, &partitions[1]);

        let graph1 = dgl.learn_local_graph(&X_sub1, &y_sub1).unwrap();
        let graph2 = dgl.learn_local_graph(&X_sub2, &y_sub2).unwrap();

        let local_graphs = vec![graph1, graph2];
        let updated_graphs = dgl.communicate_boundaries(&local_graphs, &partitions);

        assert_eq!(updated_graphs.len(), 2);
        assert_eq!(updated_graphs[0].dim(), local_graphs[0].dim());
        assert_eq!(updated_graphs[1].dim(), local_graphs[1].dim());
    }
}
1681}