sklears_semi_supervised/hierarchical_graph.rs

//! Hierarchical graph learning methods for semi-supervised learning
//!
//! This module provides hierarchical and multi-scale graph construction algorithms
//! for semi-supervised learning, enabling analysis at different levels of granularity.

use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::random::rand_prelude::*;
use scirs2_core::random::Random;
use sklears_core::error::SklearsError;

/// Hierarchical graph construction with multiple scales
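///
/// # Example
///
/// A minimal usage sketch mirroring the tests at the bottom of this file (marked
/// `ignore` since it is illustrative only and assumes the `scirs2_core::array` macro
/// is available to the caller):
///
/// ```ignore
/// use scirs2_core::array;
///
/// // Six 2-D points; build a two-level hierarchy with small neighborhoods.
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0], [5.0, 6.0], [6.0, 7.0]];
/// let hierarchy = HierarchicalGraphConstruction::new()
///     .n_levels(2)
///     .base_k_neighbors(2)
///     .coarsening_ratio(0.5)
///     .random_state(42)
///     .fit(&X.view())
///     .expect("hierarchical graph construction failed");
/// assert_eq!(hierarchy.n_levels(), 2);
/// ```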
#[derive(Clone)]
pub struct HierarchicalGraphConstruction {
    /// Number of hierarchy levels
    pub n_levels: usize,
    /// Base number of neighbors (scaled per level)
    pub base_k_neighbors: usize,
    /// Scaling factor for neighbors at each level
    pub neighbor_scaling: f64,
    /// Coarsening method: "sampling", "clustering", "pooling"
    pub coarsening_method: String,
    /// Coarsening ratio between levels
    pub coarsening_ratio: f64,
    /// Graph construction method: "knn", "epsilon", "adaptive"
    pub construction_method: String,
    /// Refinement iterations
    pub refinement_iter: usize,
    /// Random state for reproducibility
    pub random_state: Option<u64>,
}

impl HierarchicalGraphConstruction {
    /// Create a new hierarchical graph construction instance
    pub fn new() -> Self {
        Self {
            n_levels: 3,
            base_k_neighbors: 5,
            neighbor_scaling: 1.5,
            coarsening_method: "clustering".to_string(),
            coarsening_ratio: 0.5,
            construction_method: "knn".to_string(),
            refinement_iter: 10,
            random_state: None,
        }
    }

    /// Set the number of hierarchy levels
    pub fn n_levels(mut self, levels: usize) -> Self {
        self.n_levels = levels;
        self
    }

    /// Set the base number of neighbors
    pub fn base_k_neighbors(mut self, k: usize) -> Self {
        self.base_k_neighbors = k;
        self
    }

    /// Set the neighbor scaling factor
    pub fn neighbor_scaling(mut self, scaling: f64) -> Self {
        self.neighbor_scaling = scaling;
        self
    }

    /// Set the coarsening method
    pub fn coarsening_method(mut self, method: String) -> Self {
        self.coarsening_method = method;
        self
    }

    /// Set the coarsening ratio
    pub fn coarsening_ratio(mut self, ratio: f64) -> Self {
        self.coarsening_ratio = ratio;
        self
    }

    /// Set the construction method
    pub fn construction_method(mut self, method: String) -> Self {
        self.construction_method = method;
        self
    }

    /// Set the refinement iterations
    pub fn refinement_iter(mut self, iter: usize) -> Self {
        self.refinement_iter = iter;
        self
    }

    /// Set the random state
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Construct hierarchical graph from data
    pub fn fit(&self, X: &ArrayView2<f64>) -> Result<HierarchicalGraph, SklearsError> {
        let mut rng = if let Some(seed) = self.random_state {
            Random::seed(seed)
        } else {
            // No random state provided: fall back to a fixed default seed
            Random::seed(42)
        };

        // Build hierarchy from fine to coarse
        let mut hierarchy = HierarchicalGraph::new();
        let mut current_data = X.to_owned();
        let mut current_indices: Vec<usize> = (0..X.nrows()).collect();

        for level in 0..self.n_levels {
            let k_neighbors =
                (self.base_k_neighbors as f64 * self.neighbor_scaling.powi(level as i32)) as usize;

            // Construct graph at current level
            let graph = self.construct_level_graph(&current_data.view(), k_neighbors)?;

            // Store level information
            hierarchy.add_level(level, graph, current_data.clone(), current_indices.clone());

            // Coarsen for next level (if not the last level)
            if level < self.n_levels - 1 {
                let (coarsened_data, coarsened_indices) =
                    self.coarsen_level(&current_data.view(), &current_indices, &mut rng)?;
                current_data = coarsened_data;
                current_indices = coarsened_indices;
            }
        }

        // Refine hierarchy
        hierarchy = self.refine_hierarchy(hierarchy)?;

        Ok(hierarchy)
    }

    /// Construct graph at a specific level
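    ///
    /// Edge weights use a Gaussian kernel over pairwise Euclidean distances,
    /// w_ij = exp(-d_ij^2 / 2), for all construction methods.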
    fn construct_level_graph(
        &self,
        X: &ArrayView2<f64>,
        k_neighbors: usize,
    ) -> Result<Array2<f64>, SklearsError> {
        let n_samples = X.nrows();
        let mut graph = Array2::<f64>::zeros((n_samples, n_samples));

        match self.construction_method.as_str() {
            "knn" => {
                for i in 0..n_samples {
                    let mut distances: Vec<(f64, usize)> = Vec::new();

                    for j in 0..n_samples {
                        if i != j {
                            let dist = self.euclidean_distance(&X.row(i), &X.row(j));
                            distances.push((dist, j));
                        }
                    }

                    distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

                    for (dist, j) in distances.iter().take(k_neighbors.min(distances.len())) {
                        let weight = (-dist.powi(2) / 2.0).exp();
                        graph[[i, *j]] = weight;
                    }
                }

                // Make symmetric
                for i in 0..n_samples {
                    for j in i + 1..n_samples {
                        let avg_weight = (graph[[i, j]] + graph[[j, i]]) / 2.0;
                        graph[[i, j]] = avg_weight;
                        graph[[j, i]] = avg_weight;
                    }
                }
            }
            "epsilon" => {
                let epsilon = self.compute_adaptive_epsilon(X, k_neighbors)?;

                for i in 0..n_samples {
                    for j in i + 1..n_samples {
                        let dist = self.euclidean_distance(&X.row(i), &X.row(j));
                        if dist <= epsilon {
                            let weight = (-dist.powi(2) / 2.0).exp();
                            graph[[i, j]] = weight;
                            graph[[j, i]] = weight;
                        }
                    }
                }
            }
            "adaptive" => {
                graph = self.construct_adaptive_graph(X, k_neighbors)?;
            }
            _ => {
                return Err(SklearsError::InvalidInput(format!(
                    "Unknown construction method: {}",
                    self.construction_method
                )));
            }
        }

        Ok(graph)
    }

    /// Compute adaptive epsilon for epsilon-graph construction
    fn compute_adaptive_epsilon(
        &self,
        X: &ArrayView2<f64>,
        k_neighbors: usize,
    ) -> Result<f64, SklearsError> {
        let n_samples = X.nrows();
        let mut kth_distances = Vec::new();

        for i in 0..n_samples {
            let mut distances = Vec::new();
            for j in 0..n_samples {
                if i != j {
                    let dist = self.euclidean_distance(&X.row(i), &X.row(j));
                    distances.push(dist);
                }
            }

            if !distances.is_empty() {
                distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
                let k_idx = k_neighbors.min(distances.len()) - 1;
                kth_distances.push(distances[k_idx]);
            }
        }

        if kth_distances.is_empty() {
            return Ok(1.0);
        }

        // Use median of k-th nearest neighbor distances
        kth_distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
        let median_idx = kth_distances.len() / 2;
        Ok(kth_distances[median_idx])
    }

    /// Construct adaptive graph with variable neighborhood sizes
    fn construct_adaptive_graph(
        &self,
        X: &ArrayView2<f64>,
        base_k: usize,
    ) -> Result<Array2<f64>, SklearsError> {
        let n_samples = X.nrows();
        let mut graph = Array2::<f64>::zeros((n_samples, n_samples));

        // Compute local density for each point
        let densities = self.compute_local_densities(X, base_k)?;

        for i in 0..n_samples {
            // Adaptive k based on local density
            let adaptive_k = (base_k as f64 * (1.0 / (1.0 + densities[i]))).max(1.0) as usize;

            let mut distances: Vec<(f64, usize)> = Vec::new();
            for j in 0..n_samples {
                if i != j {
                    let dist = self.euclidean_distance(&X.row(i), &X.row(j));
                    distances.push((dist, j));
                }
            }

            distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

            for (dist, j) in distances.iter().take(adaptive_k.min(distances.len())) {
                let weight = (-dist.powi(2) / 2.0).exp();
                graph[[i, *j]] = weight;
            }
        }

        // Make symmetric
        for i in 0..n_samples {
            for j in i + 1..n_samples {
                let avg_weight = (graph[[i, j]] + graph[[j, i]]) / 2.0;
                graph[[i, j]] = avg_weight;
                graph[[j, i]] = avg_weight;
            }
        }

        Ok(graph)
    }

    /// Compute local densities for adaptive graph construction
    fn compute_local_densities(
        &self,
        X: &ArrayView2<f64>,
        k: usize,
    ) -> Result<Array1<f64>, SklearsError> {
        let n_samples = X.nrows();
        let mut densities = Array1::zeros(n_samples);

        for i in 0..n_samples {
            let mut distances = Vec::new();
            for j in 0..n_samples {
                if i != j {
                    let dist = self.euclidean_distance(&X.row(i), &X.row(j));
                    distances.push(dist);
                }
            }

            if !distances.is_empty() {
                distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
                let k_idx = k.min(distances.len()) - 1;
                densities[i] = 1.0 / (1.0 + distances[k_idx]); // Inverse distance density
            }
        }

        Ok(densities)
    }

    /// Coarsen data for next hierarchy level
    fn coarsen_level<R>(
        &self,
        X: &ArrayView2<f64>,
        indices: &[usize],
        rng: &mut Random<R>,
    ) -> Result<(Array2<f64>, Vec<usize>), SklearsError>
    where
        R: scirs2_core::random::Rng,
    {
        let n_samples = X.nrows();
        let target_size = ((n_samples as f64) * self.coarsening_ratio).max(1.0) as usize;

        match self.coarsening_method.as_str() {
            "sampling" => self.coarsen_by_sampling(X, indices, target_size, rng),
            "clustering" => self.coarsen_by_clustering(X, indices, target_size),
            "pooling" => self.coarsen_by_pooling(X, indices, target_size),
            _ => Err(SklearsError::InvalidInput(format!(
                "Unknown coarsening method: {}",
                self.coarsening_method
            ))),
        }
    }

    /// Coarsen by random sampling
    fn coarsen_by_sampling<R>(
        &self,
        X: &ArrayView2<f64>,
        indices: &[usize],
        target_size: usize,
        rng: &mut Random<R>,
    ) -> Result<(Array2<f64>, Vec<usize>), SklearsError>
    where
        R: scirs2_core::random::Rng,
    {
        let n_samples = X.nrows();
        let mut selected_indices: Vec<usize> = (0..n_samples).collect();
        selected_indices.shuffle(rng);
        selected_indices.truncate(target_size);
        selected_indices.sort();

        let mut coarsened_data = Array2::<f64>::zeros((target_size, X.ncols()));
        let mut coarsened_indices = Vec::new();

        for (i, &idx) in selected_indices.iter().enumerate() {
            coarsened_data.row_mut(i).assign(&X.row(idx));
            coarsened_indices.push(indices[idx]);
        }

        Ok((coarsened_data, coarsened_indices))
    }

    /// Coarsen by clustering: select representative points via farthest-point sampling
    fn coarsen_by_clustering(
        &self,
        X: &ArrayView2<f64>,
        indices: &[usize],
        target_size: usize,
    ) -> Result<(Array2<f64>, Vec<usize>), SklearsError> {
        let n_samples = X.nrows();
        let n_features = X.ncols();

        if target_size >= n_samples {
            return Ok((X.to_owned(), indices.to_vec()));
        }

        // Simple clustering: use farthest point sampling
        let mut centers = Vec::new();
        let mut center_indices = Vec::new();

        // Start with first point
        centers.push(X.row(0).to_owned());
        center_indices.push(0);

        for _ in 1..target_size {
            let mut max_dist = 0.0;
            let mut farthest_idx = 0;

            for i in 0..n_samples {
                let mut min_dist_to_centers = f64::INFINITY;

                for center in &centers {
                    let dist = self.euclidean_distance(&X.row(i), &center.view());
                    min_dist_to_centers = min_dist_to_centers.min(dist);
                }

                if min_dist_to_centers > max_dist {
                    max_dist = min_dist_to_centers;
                    farthest_idx = i;
                }
            }

            centers.push(X.row(farthest_idx).to_owned());
            center_indices.push(farthest_idx);
        }

        let mut coarsened_data = Array2::<f64>::zeros((target_size, n_features));
        let mut coarsened_indices = Vec::new();

        for (i, &idx) in center_indices.iter().enumerate() {
            coarsened_data.row_mut(i).assign(&X.row(idx));
            coarsened_indices.push(indices[idx]);
        }

        Ok((coarsened_data, coarsened_indices))
    }

    /// Coarsen by pooling (average neighboring points)
    fn coarsen_by_pooling(
        &self,
        X: &ArrayView2<f64>,
        indices: &[usize],
        target_size: usize,
    ) -> Result<(Array2<f64>, Vec<usize>), SklearsError> {
        let n_samples = X.nrows();
        let n_features = X.ncols();

        if target_size >= n_samples {
            return Ok((X.to_owned(), indices.to_vec()));
        }

        let pool_size = n_samples / target_size;
        let mut coarsened_data = Array2::<f64>::zeros((target_size, n_features));
        let mut coarsened_indices = Vec::new();

        for i in 0..target_size {
            let start_idx = i * pool_size;
            let end_idx = if i == target_size - 1 {
                n_samples
            } else {
                (i + 1) * pool_size
            };

            // Average the points in this pool
            let mut pool_mean = Array1::zeros(n_features);
            let mut count = 0;

            for j in start_idx..end_idx {
                pool_mean = pool_mean + X.row(j);
                count += 1;
            }

            if count > 0 {
                pool_mean /= count as f64;
            }

            coarsened_data.row_mut(i).assign(&pool_mean);
            coarsened_indices.push(indices[start_idx]); // Use first index as representative
        }

        Ok((coarsened_data, coarsened_indices))
    }

    /// Refine hierarchy using iterative improvement
    fn refine_hierarchy(
        &self,
        mut hierarchy: HierarchicalGraph,
    ) -> Result<HierarchicalGraph, SklearsError> {
        for _iter in 0..self.refinement_iter {
            // Refine each level using the adjacent coarser level. Levels are stored from
            // finest to coarsest, so the last level has no coarser neighbor to draw from.
            for level in 0..hierarchy.levels.len().saturating_sub(1) {
                hierarchy = self.refine_level(hierarchy, level)?;
            }
        }
        Ok(hierarchy)
    }

    /// Refine a specific level in the hierarchy using the next (coarser) level
    fn refine_level(
        &self,
        mut hierarchy: HierarchicalGraph,
        level: usize,
    ) -> Result<HierarchicalGraph, SklearsError> {
        if level + 1 >= hierarchy.levels.len() {
            return Ok(hierarchy);
        }

        // Get the graphs at the current (finer) and next (coarser) levels
        let current_graph = hierarchy.levels[level].graph.clone();
        let coarser_graph = hierarchy.levels[level + 1].graph.clone();

        // Simple refinement: adjust edge weights based on the coarser level
        let refined_graph = self.interpolate_graphs(&current_graph, &coarser_graph)?;
        hierarchy.levels[level].graph = refined_graph;

        Ok(hierarchy)
    }

    /// Interpolate between graphs at different levels
    fn interpolate_graphs(
        &self,
        fine_graph: &Array2<f64>,
        coarse_graph: &Array2<f64>,
    ) -> Result<Array2<f64>, SklearsError> {
        // Simple interpolation: weighted average
        let alpha = 0.8; // Weight for fine graph
        let fine_size = fine_graph.nrows();
        let coarse_size = coarse_graph.nrows();

        if fine_size <= coarse_size {
            return Ok(fine_graph.clone());
        }

        let mut refined = fine_graph.clone();

        // Map fine graph indices to coarse graph indices
        let scale_factor = coarse_size as f64 / fine_size as f64;

        for i in 0..fine_size {
            for j in 0..fine_size {
                let coarse_i = ((i as f64) * scale_factor) as usize;
                let coarse_j = ((j as f64) * scale_factor) as usize;

                if coarse_i < coarse_size && coarse_j < coarse_size {
                    let coarse_weight = coarse_graph[[coarse_i, coarse_j]];
                    refined[[i, j]] = alpha * fine_graph[[i, j]] + (1.0 - alpha) * coarse_weight;
                }
            }
        }

        Ok(refined)
    }

    /// Compute Euclidean distance between two vectors
    fn euclidean_distance(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
        x1.iter()
            .zip(x2.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f64>()
            .sqrt()
    }
}

impl Default for HierarchicalGraphConstruction {
    fn default() -> Self {
        Self::new()
    }
}

/// Hierarchical graph structure
#[derive(Clone)]
pub struct HierarchicalGraph {
    /// Levels in the hierarchy (from finest to coarsest)
    pub levels: Vec<HierarchyLevel>,
}

impl HierarchicalGraph {
    /// Create a new hierarchical graph
    pub fn new() -> Self {
        Self { levels: Vec::new() }
    }

    /// Add a level to the hierarchy
    pub fn add_level(
        &mut self,
        level_id: usize,
        graph: Array2<f64>,
        data: Array2<f64>,
        indices: Vec<usize>,
    ) {
        let level = HierarchyLevel {
            level_id,
            graph,
            data,
            indices,
        };
        self.levels.push(level);
    }

    /// Get the finest level graph
    pub fn finest_graph(&self) -> Option<&Array2<f64>> {
        self.levels.first().map(|level| &level.graph)
    }

    /// Get the coarsest level graph
    pub fn coarsest_graph(&self) -> Option<&Array2<f64>> {
        self.levels.last().map(|level| &level.graph)
    }

    /// Get graph at specific level
    pub fn level_graph(&self, level_id: usize) -> Option<&Array2<f64>> {
        self.levels.get(level_id).map(|level| &level.graph)
    }

    /// Get number of levels
    pub fn n_levels(&self) -> usize {
        self.levels.len()
    }
}

impl Default for HierarchicalGraph {
    fn default() -> Self {
        Self::new()
    }
}

/// Single level in the hierarchy
#[derive(Clone)]
pub struct HierarchyLevel {
    /// Level identifier
    pub level_id: usize,
    /// Graph at this level
    pub graph: Array2<f64>,
    /// Data at this level
    pub data: Array2<f64>,
    /// Original indices of points at this level
    pub indices: Vec<usize>,
}

/// Multi-scale semi-supervised learning using hierarchical graphs
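///
/// # Example
///
/// A minimal usage sketch mirroring the tests at the bottom of this file (marked
/// `ignore`; it assumes the `scirs2_core::array` macro is available and uses `-1`
/// to mark unlabeled samples):
///
/// ```ignore
/// use scirs2_core::array;
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
/// let y = array![0, 1, -1, -1]; // -1 marks unlabeled samples
///
/// let labels = MultiScaleSemiSupervised::new()
///     .alpha(0.2)
///     .max_iter(100)
///     .combination_method("fine_to_coarse".to_string())
///     .fit(&X.view(), &y.view())
///     .expect("multi-scale propagation failed");
/// assert_eq!(labels.len(), 4);
/// ```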
#[derive(Clone)]
pub struct MultiScaleSemiSupervised {
    /// Hierarchical graph construction parameters
    pub graph_builder: HierarchicalGraphConstruction,
    /// Label propagation parameters
    pub alpha: f64,
    /// Maximum iterations for propagation
    pub max_iter: usize,
    /// Convergence tolerance
    pub tolerance: f64,
    /// Scale combination method: "fine_to_coarse", "coarse_to_fine", "simultaneous"
    pub combination_method: String,
    /// Random state for reproducibility
    pub random_state: Option<u64>,
}

impl MultiScaleSemiSupervised {
    /// Create a new multi-scale semi-supervised learner
    pub fn new() -> Self {
        Self {
            graph_builder: HierarchicalGraphConstruction::new(),
            alpha: 0.2,
            max_iter: 1000,
            tolerance: 1e-6,
            combination_method: "fine_to_coarse".to_string(),
            random_state: None,
        }
    }

    /// Set the graph builder
    pub fn graph_builder(mut self, builder: HierarchicalGraphConstruction) -> Self {
        self.graph_builder = builder;
        self
    }

    /// Set the alpha parameter
    pub fn alpha(mut self, alpha: f64) -> Self {
        self.alpha = alpha;
        self
    }

    /// Set the maximum iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the tolerance
    pub fn tolerance(mut self, tol: f64) -> Self {
        self.tolerance = tol;
        self
    }

    /// Set the combination method
    pub fn combination_method(mut self, method: String) -> Self {
        self.combination_method = method;
        self
    }

    /// Set the random state
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self.graph_builder = self.graph_builder.random_state(seed);
        self
    }

    /// Fit multi-scale semi-supervised model
    pub fn fit(
        &self,
        X: &ArrayView2<f64>,
        y: &ArrayView1<i32>,
    ) -> Result<Array1<i32>, SklearsError> {
        let n_samples = X.nrows();

        if y.len() != n_samples {
            return Err(SklearsError::ShapeMismatch {
                expected: format!("X and y should have same number of samples: {}", X.nrows()),
                actual: format!("X has {} samples, y has {} samples", X.nrows(), y.len()),
            });
        }

        // Build hierarchical graph
        let hierarchy = self.graph_builder.fit(X)?;

        // Perform multi-scale label propagation
        let labels = match self.combination_method.as_str() {
            "fine_to_coarse" => self.propagate_fine_to_coarse(&hierarchy, y)?,
            "coarse_to_fine" => self.propagate_coarse_to_fine(&hierarchy, y)?,
            "simultaneous" => self.propagate_simultaneous(&hierarchy, y)?,
            _ => {
                return Err(SklearsError::InvalidInput(format!(
                    "Unknown combination method: {}",
                    self.combination_method
                )))
            }
        };

        Ok(labels)
    }

    /// Propagate labels from fine to coarse levels
    fn propagate_fine_to_coarse(
        &self,
        hierarchy: &HierarchicalGraph,
        y: &ArrayView1<i32>,
    ) -> Result<Array1<i32>, SklearsError> {
        let finest_graph = hierarchy
            .finest_graph()
            .ok_or_else(|| SklearsError::InvalidInput("Empty hierarchy".to_string()))?;

        // Start with standard label propagation on finest level
        let labels = self.propagate_labels(finest_graph, y)?;

        // Optionally refine using coarser levels (simplified implementation)
        Ok(labels)
    }

    /// Propagate labels from coarse to fine levels
    fn propagate_coarse_to_fine(
        &self,
        hierarchy: &HierarchicalGraph,
        y: &ArrayView1<i32>,
    ) -> Result<Array1<i32>, SklearsError> {
        if hierarchy.levels.is_empty() {
            return Err(SklearsError::InvalidInput("Empty hierarchy".to_string()));
        }

        // Start from coarsest level
        let coarsest_level = &hierarchy.levels[hierarchy.levels.len() - 1];

        // Map labels to coarsest level
        let coarse_labels = self.map_labels_to_level(y, &coarsest_level.indices)?;

        // Propagate on coarsest level
        let mut propagated_labels =
            self.propagate_labels(&coarsest_level.graph, &coarse_labels.view())?;

        // Refine through each finer level
        for level_idx in (0..hierarchy.levels.len() - 1).rev() {
            let level = &hierarchy.levels[level_idx];
            let refined_labels = self.refine_labels_for_level(
                &propagated_labels,
                &level.indices,
                level.data.nrows(),
            )?;
            propagated_labels = self.propagate_labels(&level.graph, &refined_labels.view())?;
        }

        Ok(propagated_labels)
    }

    /// Simultaneous propagation across all levels
    fn propagate_simultaneous(
        &self,
        hierarchy: &HierarchicalGraph,
        y: &ArrayView1<i32>,
    ) -> Result<Array1<i32>, SklearsError> {
        if hierarchy.levels.is_empty() {
            return Err(SklearsError::InvalidInput("Empty hierarchy".to_string()));
        }

        // For simplicity, use fine-to-coarse propagation
        self.propagate_fine_to_coarse(hierarchy, y)
    }

    /// Map labels to a specific hierarchy level
    fn map_labels_to_level(
        &self,
        y: &ArrayView1<i32>,
        level_indices: &[usize],
    ) -> Result<Array1<i32>, SklearsError> {
        let mut mapped_labels = Array1::from_elem(level_indices.len(), -1);

        for (i, &original_idx) in level_indices.iter().enumerate() {
            if original_idx < y.len() {
                mapped_labels[i] = y[original_idx];
            }
        }

        Ok(mapped_labels)
    }

    /// Refine labels for a specific level
    fn refine_labels_for_level(
        &self,
        coarse_labels: &Array1<i32>,
        level_indices: &[usize],
        level_size: usize,
    ) -> Result<Array1<i32>, SklearsError> {
        let mut refined_labels = Array1::from_elem(level_size, -1);

        // Simple refinement: map coarse labels onto the finer level. This assumes the
        // level's original indices are valid positions at this level (true at the finest
        // level); out-of-range indices are skipped rather than indexing out of bounds.
        for (i, &original_idx) in level_indices.iter().enumerate() {
            if i < coarse_labels.len() && original_idx < level_size {
                refined_labels[original_idx] = coarse_labels[i];
            }
        }

        Ok(refined_labels)
    }

    /// Perform label propagation on a single graph
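    ///
    /// Iteratively applies F ← α · P · F on the row-normalized graph P, hard-clamping the
    /// rows of labeled samples back to their one-hot encodings after each sweep, until the
    /// total change falls below `tolerance` or `max_iter` sweeps are reached.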
    fn propagate_labels(
        &self,
        graph: &Array2<f64>,
        y: &ArrayView1<i32>,
    ) -> Result<Array1<i32>, SklearsError> {
        let n_samples = graph.nrows();

        if y.len() != n_samples {
            return Err(SklearsError::ShapeMismatch {
                expected: format!(
                    "Graph and labels should have same number of samples: {}",
                    graph.nrows()
                ),
                actual: format!(
                    "Graph has {} samples, labels has {} samples",
                    graph.nrows(),
                    y.len()
                ),
            });
        }

        // Identify labeled and unlabeled samples
        let labeled_mask: Array1<bool> = y.iter().map(|&label| label != -1).collect();
        let unique_labels: Vec<i32> = y
            .iter()
            .filter(|&&label| label != -1)
            .cloned()
            .collect::<std::collections::HashSet<_>>()
            .into_iter()
            .collect();

        if unique_labels.is_empty() {
            return Ok(Array1::from_elem(n_samples, -1));
        }

        let n_classes = unique_labels.len();

        // Initialize label probability matrix
        let mut F = Array2::<f64>::zeros((n_samples, n_classes));

        // Set initial labels for labeled samples
        for i in 0..n_samples {
            if labeled_mask[i] {
                if let Some(class_idx) = unique_labels.iter().position(|&x| x == y[i]) {
                    F[[i, class_idx]] = 1.0;
                }
            }
        }

        // Normalize graph to get transition matrix
        let P = self.normalize_graph(graph)?;

        // Iterative label propagation
        for _iter in 0..self.max_iter {
            let F_old = F.clone();

            // Propagate labels: F ← α * P * F, then hard-clamp labeled rows back to their
            // one-hot labels below (unlabeled rows have Y = 0, so no (1 - α) * Y term is needed)
            let propagated = P.dot(&F);
            F = &propagated * self.alpha;

            // Reset labeled samples
            for i in 0..n_samples {
                if labeled_mask[i] {
                    F.row_mut(i).fill(0.0);
                    if let Some(class_idx) = unique_labels.iter().position(|&x| x == y[i]) {
                        F[[i, class_idx]] = 1.0;
                    }
                }
            }

            // Check convergence
            let change = (&F - &F_old).iter().map(|x| x.abs()).sum::<f64>();
            if change < self.tolerance {
                break;
            }
        }

        // Convert probabilities to labels
        let mut labels = Array1::zeros(n_samples);
        for i in 0..n_samples {
            let mut max_prob = 0.0;
            let mut max_class = 0;

            for j in 0..n_classes {
                if F[[i, j]] > max_prob {
                    max_prob = F[[i, j]];
                    max_class = j;
                }
            }

            labels[i] = unique_labels[max_class];
        }

        Ok(labels)
    }

    /// Normalize graph to get transition matrix
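    ///
    /// Rows with a nonzero sum are scaled to sum to 1 (row-stochastic); all-zero rows are
    /// left unchanged.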
    fn normalize_graph(&self, graph: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
        let n_samples = graph.nrows();
        let mut P = graph.clone();

        for i in 0..n_samples {
            let row_sum: f64 = P.row(i).sum();
            if row_sum > 0.0 {
                for j in 0..n_samples {
                    P[[i, j]] /= row_sum;
                }
            }
        }

        Ok(P)
    }
}

impl Default for MultiScaleSemiSupervised {
    fn default() -> Self {
        Self::new()
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::array;

    #[test]
    #[allow(non_snake_case)]
    fn test_hierarchical_graph_construction() {
        let X = array![
            [1.0, 2.0],
            [2.0, 3.0],
            [3.0, 4.0],
            [4.0, 5.0],
            [5.0, 6.0],
            [6.0, 7.0]
        ];

        let hgc = HierarchicalGraphConstruction::new()
            .n_levels(3)
            .base_k_neighbors(2)
            .coarsening_method("clustering".to_string())
            .coarsening_ratio(0.5);

        let result = hgc.fit(&X.view());
        assert!(result.is_ok());

        let hierarchy = result.unwrap();
        assert_eq!(hierarchy.n_levels(), 3);

        // Check that each level has a valid graph
        for level in 0..hierarchy.n_levels() {
            let graph = hierarchy.level_graph(level).unwrap();
            assert!(graph.nrows() > 0);
            assert_eq!(graph.nrows(), graph.ncols());
        }
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_coarsening_methods() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];

        let methods = vec!["sampling", "clustering", "pooling"];

        for method in methods {
            let hgc = HierarchicalGraphConstruction::new()
                .n_levels(2)
                .coarsening_method(method.to_string())
                .coarsening_ratio(0.5)
                .random_state(42);

            let result = hgc.fit(&X.view());
            assert!(result.is_ok());

            let hierarchy = result.unwrap();
            assert_eq!(hierarchy.n_levels(), 2);
        }
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_construction_methods() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let methods = vec!["knn", "epsilon", "adaptive"];

        for method in methods {
            let hgc = HierarchicalGraphConstruction::new()
                .n_levels(2)
                .construction_method(method.to_string())
                .base_k_neighbors(2);

            let result = hgc.fit(&X.view());
            assert!(result.is_ok());

            let hierarchy = result.unwrap();
            assert_eq!(hierarchy.n_levels(), 2);
        }
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_multi_scale_semi_supervised() {
        let X = array![
            [1.0, 2.0],
            [2.0, 3.0],
            [3.0, 4.0],
            [4.0, 5.0],
            [5.0, 6.0],
            [6.0, 7.0]
        ];
        let y = array![0, 1, -1, -1, -1, -1]; // -1 indicates unlabeled

        let graph_builder = HierarchicalGraphConstruction::new()
            .n_levels(2)
            .base_k_neighbors(2)
            .coarsening_ratio(0.5)
            .random_state(42);

        let mssl = MultiScaleSemiSupervised::new()
            .graph_builder(graph_builder)
            .alpha(0.2)
            .max_iter(100)
            .combination_method("fine_to_coarse".to_string());

        let result = mssl.fit(&X.view(), &y.view());
        assert!(result.is_ok());

        let labels = result.unwrap();
        assert_eq!(labels.len(), 6);

        // Check that labeled samples retain their labels
        assert_eq!(labels[0], 0);
        assert_eq!(labels[1], 1);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_combination_methods() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1];

        let methods = vec!["fine_to_coarse", "coarse_to_fine", "simultaneous"];

        for method in methods {
            let graph_builder = HierarchicalGraphConstruction::new()
                .n_levels(2)
                .base_k_neighbors(2)
                .random_state(42);

            let mssl = MultiScaleSemiSupervised::new()
                .graph_builder(graph_builder)
                .combination_method(method.to_string())
                .max_iter(50);

            let result = mssl.fit(&X.view(), &y.view());
            assert!(result.is_ok());

            let labels = result.unwrap();
            assert_eq!(labels.len(), 4);
            // Check that labeled samples retain their original labels (or a reasonable prediction).
            // In semi-supervised learning, exact results can vary based on graph construction.
            assert!(labels[0] == 0 || labels[0] == 1); // First sample should be classified
            assert!(labels[1] == 0 || labels[1] == 1); // Second sample should be classified
        }
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_hierarchical_graph_error_cases() {
        let hgc = HierarchicalGraphConstruction::new().construction_method("invalid".to_string());

        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let result = hgc.fit(&X.view());
        assert!(result.is_err());
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_multi_scale_error_cases() {
        let mssl = MultiScaleSemiSupervised::new();

        // Test with mismatched dimensions
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![0]; // Wrong size

        let result = mssl.fit(&X.view(), &y.view());
        assert!(result.is_err());
    }
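
    // Additional test sketches (not part of the original suite): quick sanity checks for two
    // internal helpers, written against the private methods above, which this child test
    // module can access.
    #[test]
    #[allow(non_snake_case)]
    fn test_normalize_graph_is_row_stochastic() {
        let mssl = MultiScaleSemiSupervised::new();
        let graph = array![[0.0, 2.0, 2.0], [1.0, 0.0, 3.0], [0.0, 0.0, 0.0]];

        let P = mssl.normalize_graph(&graph).unwrap();

        // Rows with nonzero sums should sum to 1; the all-zero row stays zero.
        assert_abs_diff_eq!(P.row(0).sum(), 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(P.row(1).sum(), 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(P.row(2).sum(), 0.0, epsilon = 1e-12);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_euclidean_distance() {
        let hgc = HierarchicalGraphConstruction::new();
        let x1 = array![0.0, 0.0];
        let x2 = array![3.0, 4.0];

        // Classic 3-4-5 right triangle.
        assert_abs_diff_eq!(
            hgc.euclidean_distance(&x1.view(), &x2.view()),
            5.0,
            epsilon = 1e-12
        );
    }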
}