sklears_semi_supervised/
landmark_methods.rs

1//! Landmark-based methods for large-scale semi-supervised learning
2//!
3//! This module provides landmark-based algorithms that scale to very large datasets
4//! by selecting representative points (landmarks) and building graphs based on
5//! relationships to these landmarks rather than all pairwise relationships.
6
7use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
8use scirs2_core::random::rand_prelude::*;
9use scirs2_core::random::Random;
10use sklears_core::error::SklearsError;
11use std::collections::HashMap;
12
/// Landmark-based graph construction for scalable semi-supervised learning.
///
/// Instead of relating every pair of samples directly, a set of representative
/// samples (landmarks) is chosen and the graph is derived from each sample's
/// relationship to those landmarks.
#[derive(Clone, Debug)]
pub struct LandmarkGraphConstruction {
    /// Number of landmarks to select (capped at the number of samples when fitting)
    pub n_landmarks: usize,
    /// Number of nearest landmarks each sample is connected to
    /// (used by "knn_to_landmarks" and "interpolation")
    pub k_neighbors: usize,
    /// Landmark selection strategy: "random", "kmeans", "farthest_first", "density_based"
    pub selection_strategy: String,
    /// Graph construction method: "knn_to_landmarks", "rbf_to_landmarks", "interpolation"
    pub construction_method: String,
    /// Bandwidth `σ` of the Gaussian weights `exp(-d² / (2σ²))`
    pub bandwidth: f64,
    /// Maximum iterations for k-means landmark selection
    pub max_iter: usize,
    /// Optional seed intended to make landmark selection reproducible
    pub random_state: Option<u64>,
}
31
32impl LandmarkGraphConstruction {
33    /// Create a new landmark graph construction instance
34    pub fn new() -> Self {
35        Self {
36            n_landmarks: 100,
37            k_neighbors: 5,
38            selection_strategy: "kmeans".to_string(),
39            construction_method: "knn_to_landmarks".to_string(),
40            bandwidth: 1.0,
41            max_iter: 100,
42            random_state: None,
43        }
44    }
45
46    /// Set the number of landmarks
47    pub fn n_landmarks(mut self, n: usize) -> Self {
48        self.n_landmarks = n;
49        self
50    }
51
52    /// Set the number of neighbors per landmark
53    pub fn k_neighbors(mut self, k: usize) -> Self {
54        self.k_neighbors = k;
55        self
56    }
57
58    /// Set the landmark selection strategy
59    pub fn selection_strategy(mut self, strategy: String) -> Self {
60        self.selection_strategy = strategy;
61        self
62    }
63
64    /// Set the graph construction method
65    pub fn construction_method(mut self, method: String) -> Self {
66        self.construction_method = method;
67        self
68    }
69
70    /// Set the bandwidth parameter
71    pub fn bandwidth(mut self, bw: f64) -> Self {
72        self.bandwidth = bw;
73        self
74    }
75
76    /// Set the maximum iterations
77    pub fn max_iter(mut self, max_iter: usize) -> Self {
78        self.max_iter = max_iter;
79        self
80    }
81
82    /// Set the random state
83    pub fn random_state(mut self, seed: u64) -> Self {
84        self.random_state = Some(seed);
85        self
86    }
87
88    /// Construct landmark-based graph
89    pub fn fit(&self, X: &ArrayView2<f64>) -> Result<LandmarkGraphResult, SklearsError> {
90        let n_samples = X.nrows();
91
92        if n_samples == 0 {
93            return Err(SklearsError::InvalidInput(
94                "No samples provided".to_string(),
95            ));
96        }
97
98        let effective_landmarks = self.n_landmarks.min(n_samples);
99
100        let mut rng = Random::default();
101
102        // Select landmarks
103        let (landmark_indices, landmarks) =
104            self.select_landmarks(X, effective_landmarks, &mut rng)?;
105
106        // Construct graph based on landmarks
107        let adjacency_matrix = self.construct_landmark_graph(X, &landmarks, &landmark_indices)?;
108
109        Ok(LandmarkGraphResult {
110            adjacency_matrix,
111            landmark_indices,
112            landmarks,
113        })
114    }
115
116    /// Select landmarks using different strategies
117    fn select_landmarks(
118        &self,
119        X: &ArrayView2<f64>,
120        n_landmarks: usize,
121        rng: &mut Random,
122    ) -> Result<(Vec<usize>, Array2<f64>), SklearsError> {
123        match self.selection_strategy.as_str() {
124            "random" => self.random_landmarks(X, n_landmarks, rng),
125            "kmeans" => self.kmeans_landmarks(X, n_landmarks, rng),
126            "farthest_first" => self.farthest_first_landmarks(X, n_landmarks, rng),
127            "density_based" => self.density_based_landmarks(X, n_landmarks, rng),
128            _ => Err(SklearsError::InvalidInput(format!(
129                "Unknown selection strategy: {}",
130                self.selection_strategy
131            ))),
132        }
133    }
134
135    /// Random landmark selection
136    fn random_landmarks(
137        &self,
138        X: &ArrayView2<f64>,
139        n_landmarks: usize,
140        rng: &mut Random,
141    ) -> Result<(Vec<usize>, Array2<f64>), SklearsError> {
142        let n_samples = X.nrows();
143        let indices: Vec<usize> = (0..n_samples)
144            .choose_multiple(rng, n_landmarks)
145            .into_iter()
146            .collect();
147
148        let mut landmarks = Array2::<f64>::zeros((n_landmarks, X.ncols()));
149        for (i, &idx) in indices.iter().enumerate() {
150            landmarks.row_mut(i).assign(&X.row(idx));
151        }
152
153        Ok((indices, landmarks))
154    }
155
156    /// K-means based landmark selection
157    fn kmeans_landmarks(
158        &self,
159        X: &ArrayView2<f64>,
160        n_landmarks: usize,
161        rng: &mut Random,
162    ) -> Result<(Vec<usize>, Array2<f64>), SklearsError> {
163        let n_samples = X.nrows();
164        let n_features = X.ncols();
165
166        // Initialize centroids randomly
167        let mut centroids = Array2::<f64>::zeros((n_landmarks, n_features));
168        for i in 0..n_landmarks {
169            let sample_idx = rng.gen_range(0..n_samples);
170            centroids.row_mut(i).assign(&X.row(sample_idx));
171        }
172
173        let mut labels = Array1::<usize>::zeros(n_samples);
174
175        // K-means iterations
176        for _iter in 0..self.max_iter {
177            let mut changed = false;
178
179            // Assign points to nearest centroids
180            for i in 0..n_samples {
181                let mut min_dist = f64::INFINITY;
182                let mut best_cluster = 0;
183
184                for k in 0..n_landmarks {
185                    let dist = self.euclidean_distance(&X.row(i), &centroids.row(k));
186                    if dist < min_dist {
187                        min_dist = dist;
188                        best_cluster = k;
189                    }
190                }
191
192                if labels[i] != best_cluster {
193                    labels[i] = best_cluster;
194                    changed = true;
195                }
196            }
197
198            if !changed {
199                break;
200            }
201
202            // Update centroids
203            for k in 0..n_landmarks {
204                let mut count = 0;
205                let mut sum = Array1::<f64>::zeros(n_features);
206
207                for i in 0..n_samples {
208                    if labels[i] == k {
209                        count += 1;
210                        for j in 0..n_features {
211                            sum[j] += X[[i, j]];
212                        }
213                    }
214                }
215
216                if count > 0 {
217                    for j in 0..n_features {
218                        centroids[[k, j]] = sum[j] / count as f64;
219                    }
220                }
221            }
222        }
223
224        // Find closest actual points to centroids as landmarks
225        let mut landmark_indices = Vec::new();
226        for k in 0..n_landmarks {
227            let mut min_dist = f64::INFINITY;
228            let mut closest_idx = 0;
229
230            for i in 0..n_samples {
231                let dist = self.euclidean_distance(&X.row(i), &centroids.row(k));
232                if dist < min_dist {
233                    min_dist = dist;
234                    closest_idx = i;
235                }
236            }
237
238            landmark_indices.push(closest_idx);
239        }
240
241        // Remove duplicates and get unique landmarks
242        landmark_indices.sort_unstable();
243        landmark_indices.dedup();
244
245        // If we have fewer unique landmarks than requested, add random ones
246        while landmark_indices.len() < n_landmarks {
247            let new_idx = rng.gen_range(0..n_samples);
248            if !landmark_indices.contains(&new_idx) {
249                landmark_indices.push(new_idx);
250            }
251        }
252
253        let mut landmarks = Array2::<f64>::zeros((landmark_indices.len(), n_features));
254        for (i, &idx) in landmark_indices.iter().enumerate() {
255            landmarks.row_mut(i).assign(&X.row(idx));
256        }
257
258        Ok((landmark_indices, landmarks))
259    }
260
261    /// Farthest-first landmark selection
262    fn farthest_first_landmarks(
263        &self,
264        X: &ArrayView2<f64>,
265        n_landmarks: usize,
266        rng: &mut Random,
267    ) -> Result<(Vec<usize>, Array2<f64>), SklearsError> {
268        let n_samples = X.nrows();
269        let n_features = X.ncols();
270
271        let mut landmark_indices = Vec::new();
272
273        // Select first landmark randomly
274        let first_idx = rng.gen_range(0..n_samples);
275        landmark_indices.push(first_idx);
276
277        // Select remaining landmarks by maximizing minimum distance
278        for _ in 1..n_landmarks {
279            let mut max_min_dist = 0.0;
280            let mut best_idx = 0;
281
282            for i in 0..n_samples {
283                if landmark_indices.contains(&i) {
284                    continue;
285                }
286
287                // Find minimum distance to existing landmarks
288                let mut min_dist = f64::INFINITY;
289                for &landmark_idx in &landmark_indices {
290                    let dist = self.euclidean_distance(&X.row(i), &X.row(landmark_idx));
291                    if dist < min_dist {
292                        min_dist = dist;
293                    }
294                }
295
296                // Update if this point has larger minimum distance
297                if min_dist > max_min_dist {
298                    max_min_dist = min_dist;
299                    best_idx = i;
300                }
301            }
302
303            landmark_indices.push(best_idx);
304        }
305
306        let mut landmarks = Array2::<f64>::zeros((n_landmarks, n_features));
307        for (i, &idx) in landmark_indices.iter().enumerate() {
308            landmarks.row_mut(i).assign(&X.row(idx));
309        }
310
311        Ok((landmark_indices, landmarks))
312    }
313
314    /// Density-based landmark selection
315    fn density_based_landmarks(
316        &self,
317        X: &ArrayView2<f64>,
318        n_landmarks: usize,
319        rng: &mut Random,
320    ) -> Result<(Vec<usize>, Array2<f64>), SklearsError> {
321        let n_samples = X.nrows();
322        let n_features = X.ncols();
323
324        // Estimate local density for each point
325        let mut densities = Array1::<f64>::zeros(n_samples);
326        let radius = self.estimate_density_radius(X)?;
327
328        for i in 0..n_samples {
329            let mut neighbor_count = 0;
330            for j in 0..n_samples {
331                if i != j {
332                    let dist = self.euclidean_distance(&X.row(i), &X.row(j));
333                    if dist <= radius {
334                        neighbor_count += 1;
335                    }
336                }
337            }
338            densities[i] = neighbor_count as f64;
339        }
340
341        // Select landmarks with probability proportional to density
342        let mut landmark_indices = Vec::new();
343        let total_density: f64 = densities.sum();
344
345        if total_density > 0.0 {
346            for iteration in 0..n_landmarks {
347                let threshold = rng.gen::<f64>() * total_density;
348                let mut cumulative = 0.0;
349                let previous_len = landmark_indices.len();
350
351                for i in 0..n_samples {
352                    if landmark_indices.contains(&i) {
353                        continue;
354                    }
355
356                    cumulative += densities[i];
357                    if cumulative >= threshold {
358                        landmark_indices.push(i);
359                        break;
360                    }
361                }
362
363                // If we couldn't find a new landmark, add a random one
364                if landmark_indices.len() == previous_len {
365                    for i in 0..n_samples {
366                        if !landmark_indices.contains(&i) {
367                            landmark_indices.push(i);
368                            break;
369                        }
370                    }
371                }
372            }
373        } else {
374            // Fallback to random selection if density calculation fails
375            return self.random_landmarks(X, n_landmarks, rng);
376        }
377
378        let mut landmarks = Array2::<f64>::zeros((landmark_indices.len(), n_features));
379        for (i, &idx) in landmark_indices.iter().enumerate() {
380            landmarks.row_mut(i).assign(&X.row(idx));
381        }
382
383        Ok((landmark_indices, landmarks))
384    }
385
386    /// Estimate radius for density computation
387    fn estimate_density_radius(&self, X: &ArrayView2<f64>) -> Result<f64, SklearsError> {
388        let n_samples = X.nrows();
389        let sample_size = (n_samples / 10).clamp(10, 100);
390
391        let mut distances = Vec::new();
392
393        // Sample pairwise distances
394        for i in 0..sample_size {
395            for j in (i + 1)..sample_size {
396                if i < n_samples && j < n_samples {
397                    let dist = self.euclidean_distance(&X.row(i), &X.row(j));
398                    distances.push(dist);
399                }
400            }
401        }
402
403        if distances.is_empty() {
404            return Ok(1.0);
405        }
406
407        distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
408
409        // Use median distance as radius
410        let median_idx = distances.len() / 2;
411        Ok(distances[median_idx])
412    }
413
414    /// Construct graph based on landmarks
415    fn construct_landmark_graph(
416        &self,
417        X: &ArrayView2<f64>,
418        landmarks: &Array2<f64>,
419        landmark_indices: &[usize],
420    ) -> Result<Array2<f64>, SklearsError> {
421        match self.construction_method.as_str() {
422            "knn_to_landmarks" => self.knn_to_landmarks_graph(X, landmarks, landmark_indices),
423            "rbf_to_landmarks" => self.rbf_to_landmarks_graph(X, landmarks, landmark_indices),
424            "interpolation" => self.interpolation_graph(X, landmarks, landmark_indices),
425            _ => Err(SklearsError::InvalidInput(format!(
426                "Unknown construction method: {}",
427                self.construction_method
428            ))),
429        }
430    }
431
432    /// K-NN to landmarks graph construction
433    fn knn_to_landmarks_graph(
434        &self,
435        X: &ArrayView2<f64>,
436        landmarks: &Array2<f64>,
437        _landmark_indices: &[usize],
438    ) -> Result<Array2<f64>, SklearsError> {
439        let n_samples = X.nrows();
440        let n_landmarks = landmarks.nrows();
441        let mut adjacency = Array2::<f64>::zeros((n_samples, n_samples));
442
443        for i in 0..n_samples {
444            // Find k nearest landmarks for each point
445            let mut landmark_distances: Vec<(f64, usize)> = Vec::new();
446
447            for j in 0..n_landmarks {
448                let dist = self.euclidean_distance(&X.row(i), &landmarks.row(j));
449                landmark_distances.push((dist, j));
450            }
451
452            landmark_distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
453
454            let k_landmarks = self.k_neighbors.min(n_landmarks);
455            let mut weights = Vec::new();
456            let mut total_weight = 0.0;
457            #[allow(clippy::needless_range_loop)]
458            for k in 0..k_landmarks {
459                let weight =
460                    (-landmark_distances[k].0.powi(2) / (2.0 * self.bandwidth.powi(2))).exp();
461                weights.push(weight);
462                total_weight += weight;
463            }
464
465            // Normalize weights
466            if total_weight > 0.0 {
467                for weight in &mut weights {
468                    *weight /= total_weight;
469                }
470            }
471
472            // Connect to other points through shared landmarks
473            for j in (i + 1)..n_samples {
474                let mut shared_weight = 0.0;
475
476                // Find landmarks for point j
477                let mut j_landmark_distances: Vec<(f64, usize)> = Vec::new();
478                for l in 0..n_landmarks {
479                    let dist = self.euclidean_distance(&X.row(j), &landmarks.row(l));
480                    j_landmark_distances.push((dist, l));
481                }
482                j_landmark_distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
483
484                let mut j_weights = Vec::new();
485                let mut j_total_weight = 0.0;
486
487                #[allow(clippy::needless_range_loop)]
488                for k in 0..k_landmarks {
489                    let weight =
490                        (-j_landmark_distances[k].0.powi(2) / (2.0 * self.bandwidth.powi(2))).exp();
491                    j_weights.push(weight);
492                    j_total_weight += weight;
493                }
494
495                // Normalize weights for point j
496                if j_total_weight > 0.0 {
497                    for weight in &mut j_weights {
498                        *weight /= j_total_weight;
499                    }
500                }
501
502                // Compute shared landmark weight
503                for k in 0..k_landmarks {
504                    for l in 0..k_landmarks {
505                        if landmark_distances[k].1 == j_landmark_distances[l].1 {
506                            shared_weight += weights[k] * j_weights[l];
507                        }
508                    }
509                }
510
511                adjacency[[i, j]] = shared_weight;
512                adjacency[[j, i]] = shared_weight;
513            }
514        }
515
516        Ok(adjacency)
517    }
518
519    /// RBF to landmarks graph construction
520    fn rbf_to_landmarks_graph(
521        &self,
522        X: &ArrayView2<f64>,
523        landmarks: &Array2<f64>,
524        _landmark_indices: &[usize],
525    ) -> Result<Array2<f64>, SklearsError> {
526        let n_samples = X.nrows();
527        let n_landmarks = landmarks.nrows();
528        let mut adjacency = Array2::<f64>::zeros((n_samples, n_samples));
529
530        // Compute point-to-landmark weights
531        let mut point_landmark_weights = Array2::<f64>::zeros((n_samples, n_landmarks));
532
533        for i in 0..n_samples {
534            let mut total_weight = 0.0;
535            for j in 0..n_landmarks {
536                let dist = self.euclidean_distance(&X.row(i), &landmarks.row(j));
537                let weight = (-dist.powi(2) / (2.0 * self.bandwidth.powi(2))).exp();
538                point_landmark_weights[[i, j]] = weight;
539                total_weight += weight;
540            }
541
542            // Normalize weights
543            if total_weight > 0.0 {
544                for j in 0..n_landmarks {
545                    point_landmark_weights[[i, j]] /= total_weight;
546                }
547            }
548        }
549
550        // Compute point-to-point similarities through landmarks
551        for i in 0..n_samples {
552            for j in (i + 1)..n_samples {
553                let mut similarity = 0.0;
554
555                for l in 0..n_landmarks {
556                    similarity += point_landmark_weights[[i, l]] * point_landmark_weights[[j, l]];
557                }
558
559                adjacency[[i, j]] = similarity;
560                adjacency[[j, i]] = similarity;
561            }
562        }
563
564        Ok(adjacency)
565    }
566
567    /// Interpolation-based graph construction
568    fn interpolation_graph(
569        &self,
570        X: &ArrayView2<f64>,
571        landmarks: &Array2<f64>,
572        landmark_indices: &[usize],
573    ) -> Result<Array2<f64>, SklearsError> {
574        let n_samples = X.nrows();
575        let mut adjacency = Array2::<f64>::zeros((n_samples, n_samples));
576
577        // Connect landmarks to each other first
578        for i in 0..landmark_indices.len() {
579            for j in (i + 1)..landmark_indices.len() {
580                let idx_i = landmark_indices[i];
581                let idx_j = landmark_indices[j];
582                let dist = self.euclidean_distance(&X.row(idx_i), &X.row(idx_j));
583                let weight = (-dist.powi(2) / (2.0 * self.bandwidth.powi(2))).exp();
584                adjacency[[idx_i, idx_j]] = weight;
585                adjacency[[idx_j, idx_i]] = weight;
586            }
587        }
588
589        // Connect non-landmark points to landmarks and interpolate
590        for i in 0..n_samples {
591            if landmark_indices.contains(&i) {
592                continue; // Skip landmarks
593            }
594
595            // Find nearest landmarks
596            let mut landmark_distances: Vec<(f64, usize)> = Vec::new();
597            for (l_idx, &landmark_idx) in landmark_indices.iter().enumerate() {
598                let dist = self.euclidean_distance(&X.row(i), &X.row(landmark_idx));
599                landmark_distances.push((dist, landmark_idx));
600            }
601
602            landmark_distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
603
604            let k_connect = self.k_neighbors.min(landmark_distances.len());
605
606            // Connect to nearest landmarks
607            #[allow(clippy::needless_range_loop)]
608            for k in 0..k_connect {
609                let landmark_idx = landmark_distances[k].1;
610                let weight =
611                    (-landmark_distances[k].0.powi(2) / (2.0 * self.bandwidth.powi(2))).exp();
612                adjacency[[i, landmark_idx]] = weight;
613                adjacency[[landmark_idx, i]] = weight;
614            }
615        }
616
617        Ok(adjacency)
618    }
619
620    /// Compute Euclidean distance between two vectors
621    fn euclidean_distance(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
622        x1.iter()
623            .zip(x2.iter())
624            .map(|(a, b)| (a - b).powi(2))
625            .sum::<f64>()
626            .sqrt()
627    }
628}
629
630impl Default for LandmarkGraphConstruction {
631    fn default() -> Self {
632        Self::new()
633    }
634}
635
636/// Result of landmark graph construction
637#[derive(Clone, Debug)]
638pub struct LandmarkGraphResult {
639    /// The constructed adjacency matrix
640    pub adjacency_matrix: Array2<f64>,
641    /// Indices of selected landmarks in original data
642    pub landmark_indices: Vec<usize>,
643    /// Landmark points
644    pub landmarks: Array2<f64>,
645}
646
647impl LandmarkGraphResult {
648    /// Get the number of landmarks
649    pub fn n_landmarks(&self) -> usize {
650        self.landmarks.nrows()
651    }
652
653    /// Get the sparsity of the resulting graph
654    pub fn sparsity(&self) -> f64 {
655        let total_edges = self.adjacency_matrix.len();
656        let non_zero_edges = self.adjacency_matrix.iter().filter(|&&x| x > 0.0).count();
657        1.0 - (non_zero_edges as f64 / total_edges as f64)
658    }
659
660    /// Get landmark coverage statistics
661    pub fn landmark_coverage(&self) -> HashMap<String, f64> {
662        let mut stats = HashMap::new();
663        let n_samples = self.adjacency_matrix.nrows();
664        let n_landmarks = self.landmarks.nrows();
665
666        stats.insert(
667            "landmark_ratio".to_string(),
668            n_landmarks as f64 / n_samples as f64,
669        );
670        stats.insert("sparsity".to_string(), self.sparsity());
671        stats.insert("n_landmarks".to_string(), n_landmarks as f64);
672        stats.insert("n_samples".to_string(), n_samples as f64);
673
674        stats
675    }
676}
677
678/// Landmark-based label propagation for large-scale semi-supervised learning
679#[derive(Clone)]
680pub struct LandmarkLabelPropagation {
681    /// Graph construction parameters
682    pub graph_constructor: LandmarkGraphConstruction,
683    /// Maximum iterations for label propagation
684    pub max_iter: usize,
685    /// Convergence tolerance
686    pub tolerance: f64,
687    /// Alpha parameter for label spreading
688    pub alpha: f64,
689}
690
691impl LandmarkLabelPropagation {
692    /// Create a new landmark label propagation instance
693    pub fn new() -> Self {
694        Self {
695            graph_constructor: LandmarkGraphConstruction::new(),
696            max_iter: 1000,
697            tolerance: 1e-6,
698            alpha: 0.2,
699        }
700    }
701
702    /// Set the graph constructor
703    pub fn graph_constructor(mut self, constructor: LandmarkGraphConstruction) -> Self {
704        self.graph_constructor = constructor;
705        self
706    }
707
708    /// Set the maximum iterations
709    pub fn max_iter(mut self, max_iter: usize) -> Self {
710        self.max_iter = max_iter;
711        self
712    }
713
714    /// Set the convergence tolerance
715    pub fn tolerance(mut self, tol: f64) -> Self {
716        self.tolerance = tol;
717        self
718    }
719
720    /// Set the alpha parameter
721    pub fn alpha(mut self, alpha: f64) -> Self {
722        self.alpha = alpha;
723        self
724    }
725
726    /// Perform landmark-based label propagation
727    pub fn fit_predict(
728        &self,
729        X: &ArrayView2<f64>,
730        y: &ArrayView1<i32>,
731    ) -> Result<Array1<i32>, SklearsError> {
732        let n_samples = X.nrows();
733
734        if y.len() != n_samples {
735            return Err(SklearsError::ShapeMismatch {
736                expected: format!("X and y should have same number of samples: {}", n_samples),
737                actual: format!("X has {} samples, y has {} samples", n_samples, y.len()),
738            });
739        }
740
741        // Construct landmark graph
742        let graph_result = self.graph_constructor.fit(X)?;
743
744        // Perform label propagation on the landmark graph
745        self.propagate_labels(&graph_result.adjacency_matrix, y)
746    }
747
748    /// Propagate labels on the constructed graph
749    #[allow(non_snake_case)]
750    fn propagate_labels(
751        &self,
752        adjacency: &Array2<f64>,
753        y: &ArrayView1<i32>,
754    ) -> Result<Array1<i32>, SklearsError> {
755        let n_samples = adjacency.nrows();
756
757        // Identify labeled and unlabeled samples
758        let labeled_mask: Array1<bool> = y.iter().map(|&label| label != -1).collect();
759        let unique_labels: Vec<i32> = y
760            .iter()
761            .filter(|&&label| label != -1)
762            .cloned()
763            .collect::<std::collections::HashSet<_>>()
764            .into_iter()
765            .collect();
766
767        if unique_labels.is_empty() {
768            return Err(SklearsError::InvalidInput(
769                "No labeled samples found".to_string(),
770            ));
771        }
772
773        let n_classes = unique_labels.len();
774
775        // Initialize label probability matrix
776        let mut F = Array2::<f64>::zeros((n_samples, n_classes));
777
778        // Set initial labels for labeled samples
779        for i in 0..n_samples {
780            if labeled_mask[i] {
781                if let Some(class_idx) = unique_labels.iter().position(|&x| x == y[i]) {
782                    F[[i, class_idx]] = 1.0;
783                }
784            }
785        }
786
787        // Normalize adjacency matrix to get transition matrix
788        let P = self.normalize_adjacency(adjacency)?;
789
790        // Iterative label propagation
791        for _iter in 0..self.max_iter {
792            let F_old = F.clone();
793
794            // Propagate labels: F = α * P * F + (1-α) * Y
795            let propagated = P.dot(&F);
796            F = &propagated * self.alpha;
797
798            // Add back original labels with weight (1-α)
799            for i in 0..n_samples {
800                if labeled_mask[i] {
801                    if let Some(class_idx) = unique_labels.iter().position(|&x| x == y[i]) {
802                        F[[i, class_idx]] += 1.0 - self.alpha;
803                    }
804                }
805            }
806
807            // Normalize probabilities
808            for i in 0..n_samples {
809                let row_sum: f64 = F.row(i).sum();
810                if row_sum > 0.0 {
811                    for j in 0..n_classes {
812                        F[[i, j]] /= row_sum;
813                    }
814                }
815            }
816
817            // Check convergence
818            let change = (&F - &F_old).iter().map(|x| x.abs()).sum::<f64>();
819            if change < self.tolerance {
820                break;
821            }
822        }
823
824        // Convert probabilities to labels
825        let mut labels = Array1::zeros(n_samples);
826        for i in 0..n_samples {
827            let mut max_prob = 0.0;
828            let mut max_class = 0;
829
830            for j in 0..n_classes {
831                if F[[i, j]] > max_prob {
832                    max_prob = F[[i, j]];
833                    max_class = j;
834                }
835            }
836
837            labels[i] = unique_labels[max_class];
838        }
839
840        Ok(labels)
841    }
842
843    /// Normalize adjacency matrix to transition matrix
844    fn normalize_adjacency(&self, adjacency: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
845        let n_samples = adjacency.nrows();
846        let mut P = adjacency.clone();
847
848        for i in 0..n_samples {
849            let row_sum: f64 = P.row(i).sum();
850            if row_sum > 0.0 {
851                for j in 0..n_samples {
852                    P[[i, j]] /= row_sum;
853                }
854            }
855        }
856
857        Ok(P)
858    }
859}
860
861impl Default for LandmarkLabelPropagation {
862    fn default() -> Self {
863        Self::new()
864    }
865}
866
#[cfg(test)]
// Upper-case `X` / `P` mirror the usual mathematical notation for data and transition
// matrices, so the snake_case lint is silenced once for the whole test module.
#[allow(non_snake_case)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::array;

    #[test]
    fn test_landmark_graph_construction_random() {
        let X = array![
            [1.0, 2.0],
            [2.0, 3.0],
            [3.0, 4.0],
            [4.0, 5.0],
            [5.0, 6.0],
            [6.0, 7.0]
        ];

        let lgc = LandmarkGraphConstruction::new()
            .n_landmarks(3)
            .selection_strategy("random".to_string())
            .random_state(42);

        let result = lgc.fit(&X.view());
        assert!(result.is_ok());

        let graph_result = result.unwrap();
        assert_eq!(graph_result.adjacency_matrix.dim(), (6, 6));
        assert_eq!(graph_result.n_landmarks(), 3);
        assert_eq!(graph_result.landmark_indices.len(), 3);
        // Landmarks are selected from the data, so every index must address a sample.
        assert!(graph_result
            .landmark_indices
            .iter()
            .all(|&idx| idx < X.nrows()));
    }

    #[test]
    fn test_landmark_graph_construction_kmeans() {
        let X = array![
            [1.0, 1.0],
            [1.5, 1.5],
            [2.0, 2.0],
            [8.0, 8.0],
            [8.5, 8.5],
            [9.0, 9.0]
        ];

        let lgc = LandmarkGraphConstruction::new()
            .n_landmarks(2)
            .selection_strategy("kmeans".to_string())
            .random_state(42);

        let result = lgc.fit(&X.view());
        assert!(result.is_ok());

        let graph_result = result.unwrap();
        assert_eq!(graph_result.adjacency_matrix.dim(), (6, 6));
        assert_eq!(graph_result.n_landmarks(), 2);
    }

    #[test]
    fn test_landmark_graph_construction_farthest_first() {
        let X = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];

        let lgc = LandmarkGraphConstruction::new()
            .n_landmarks(2)
            .selection_strategy("farthest_first".to_string())
            .random_state(42);

        let result = lgc.fit(&X.view());
        assert!(result.is_ok());

        let graph_result = result.unwrap();
        assert_eq!(graph_result.adjacency_matrix.dim(), (4, 4));
        assert_eq!(graph_result.n_landmarks(), 2);
    }

    #[test]
    fn test_landmark_graph_construction_density_based() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];

        let lgc = LandmarkGraphConstruction::new()
            .n_landmarks(2)
            .selection_strategy("density_based".to_string())
            .random_state(42);

        let result = lgc.fit(&X.view());
        assert!(result.is_ok());

        let graph_result = result.unwrap();
        assert_eq!(graph_result.adjacency_matrix.dim(), (4, 4));
        assert_eq!(graph_result.n_landmarks(), 2);
    }

    #[test]
    fn test_different_construction_methods() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];

        for method in ["knn_to_landmarks", "rbf_to_landmarks", "interpolation"] {
            let lgc = LandmarkGraphConstruction::new()
                .n_landmarks(2)
                .construction_method(method.to_string())
                .random_state(42);

            let result = lgc.fit(&X.view());
            assert!(result.is_ok(), "construction method `{method}` failed");

            let graph_result = result.unwrap();
            assert_eq!(graph_result.adjacency_matrix.dim(), (4, 4));
        }
    }

    #[test]
    fn test_landmark_label_propagation() {
        let X = array![
            [1.0, 2.0],
            [2.0, 3.0],
            [3.0, 4.0],
            [4.0, 5.0],
            [5.0, 6.0],
            [6.0, 7.0]
        ];
        let y = array![0, 1, -1, -1, -1, -1]; // First two are labeled

        let graph_constructor = LandmarkGraphConstruction::new()
            .n_landmarks(3)
            .random_state(42);
        let llp = LandmarkLabelPropagation::new().graph_constructor(graph_constructor);

        let result = llp.fit_predict(&X.view(), &y.view());
        assert!(result.is_ok());

        let labels = result.unwrap();
        assert_eq!(labels.len(), 6);

        // Check that labeled samples keep their labels
        assert_eq!(labels[0], 0);
        assert_eq!(labels[1], 1);

        // Check that all labels are valid
        for &label in labels.iter() {
            assert!(label == 0 || label == 1);
        }
    }

    #[test]
    fn test_normalize_adjacency_rows_sum_to_one() {
        let adjacency = array![[0.0, 2.0, 2.0], [1.0, 0.0, 3.0], [0.0, 0.0, 0.0]];

        let llp = LandmarkLabelPropagation::new();
        let P = llp
            .normalize_adjacency(&adjacency)
            .expect("row normalization should not fail");

        assert_eq!(P.dim(), (3, 3));
        assert_abs_diff_eq!(P[[0, 1]], 0.5);
        assert_abs_diff_eq!(P[[0, 2]], 0.5);
        assert_abs_diff_eq!(P[[1, 0]], 0.25);
        assert_abs_diff_eq!(P[[1, 2]], 0.75);
        assert_abs_diff_eq!(P.row(0).sum(), 1.0);
        assert_abs_diff_eq!(P.row(1).sum(), 1.0);
        // An isolated node's all-zero row is left untouched instead of becoming NaN.
        for &value in P.row(2).iter() {
            assert_abs_diff_eq!(value, 0.0);
        }
    }

    #[test]
    fn test_graph_result_methods() {
        let adjacency = array![
            [0.0, 0.8, 0.0, 0.2],
            [0.8, 0.0, 0.3, 0.0],
            [0.0, 0.3, 0.0, 0.9],
            [0.2, 0.0, 0.9, 0.0]
        ];

        let landmark_indices = vec![0, 2];
        let landmarks = array![[1.0, 2.0], [3.0, 4.0]];

        let result = LandmarkGraphResult {
            adjacency_matrix: adjacency,
            landmark_indices,
            landmarks,
        };

        assert_eq!(result.n_landmarks(), 2);
        assert!(result.sparsity() > 0.0);
        assert!(result.sparsity() < 1.0);

        let coverage = result.landmark_coverage();
        assert!(coverage.contains_key("landmark_ratio"));
        assert!(coverage.contains_key("sparsity"));
        assert_abs_diff_eq!(coverage["n_landmarks"], 2.0);
        assert_abs_diff_eq!(coverage["n_samples"], 4.0);
    }

    #[test]
    fn test_landmark_graph_construction_builder() {
        let lgc = LandmarkGraphConstruction::new()
            .n_landmarks(50)
            .k_neighbors(8)
            .selection_strategy("farthest_first".to_string())
            .construction_method("rbf_to_landmarks".to_string())
            .bandwidth(2.0)
            .max_iter(200)
            .random_state(123);

        assert_eq!(lgc.n_landmarks, 50);
        assert_eq!(lgc.k_neighbors, 8);
        assert_eq!(lgc.selection_strategy, "farthest_first");
        assert_eq!(lgc.construction_method, "rbf_to_landmarks");
        assert_abs_diff_eq!(lgc.bandwidth, 2.0);
        assert_eq!(lgc.max_iter, 200);
        assert_eq!(lgc.random_state, Some(123));
    }

    #[test]
    fn test_error_cases() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];

        // `n_landmarks(1)` keeps the landmark count compatible with this 2-sample input.
        // The failures below should then come from the invalid option itself, not possibly
        // from the default landmark count exceeding the number of samples.
        let lgc = LandmarkGraphConstruction::new()
            .n_landmarks(1)
            .selection_strategy("invalid_strategy".to_string());

        let result = lgc.fit(&X.view());
        assert!(result.is_err());

        let lgc = LandmarkGraphConstruction::new()
            .n_landmarks(1)
            .construction_method("invalid_method".to_string());

        let result = lgc.fit(&X.view());
        assert!(result.is_err());

        // Test with empty dataset
        let empty_X = Array2::<f64>::zeros((0, 2));
        let lgc = LandmarkGraphConstruction::new();
        let result = lgc.fit(&empty_X.view());
        assert!(result.is_err());

        // Test label propagation with mismatched dimensions
        let llp = LandmarkLabelPropagation::new()
            .graph_constructor(LandmarkGraphConstruction::new().n_landmarks(1));
        let y = array![0]; // Wrong size
        let result = llp.fit_predict(&X.view(), &y.view());
        assert!(result.is_err());

        // Test with no labeled samples
        let y_unlabeled = array![-1, -1];
        let result = llp.fit_predict(&X.view(), &y_unlabeled.view());
        assert!(result.is_err());
    }
}