use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::random::rand_prelude::*;
use scirs2_core::random::Random;
use sklears_core::error::SklearsError;

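/// Hierarchical graph construction for multi-scale learning.
///
/// Builds a pyramid of similarity graphs: at each level a graph is constructed
/// over the current points ("knn", "epsilon", or "adaptive"), and the points
/// are then coarsened ("sampling", "clustering", or "pooling") before the
/// next, coarser level is built.
///
/// A minimal usage sketch (mirroring the tests at the bottom of this file; the
/// data values are only illustrative):
///
/// ```ignore
/// use scirs2_core::array;
///
/// let data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
/// let hierarchy = HierarchicalGraphConstruction::new()
///     .n_levels(2)
///     .base_k_neighbors(2)
///     .coarsening_ratio(0.5)
///     .random_state(42)
///     .fit(&data.view())
///     .unwrap();
/// assert_eq!(hierarchy.n_levels(), 2);
/// ```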
#[derive(Clone)]
pub struct HierarchicalGraphConstruction {
    /// Number of levels in the hierarchy.
    pub n_levels: usize,
    /// Number of nearest neighbors used at the finest level.
    pub base_k_neighbors: usize,
    /// Multiplicative factor applied to the neighbor count at each coarser level.
    pub neighbor_scaling: f64,
    /// Coarsening strategy: "sampling", "clustering", or "pooling".
    pub coarsening_method: String,
    /// Fraction of points kept when moving to the next coarser level.
    pub coarsening_ratio: f64,
    /// Graph construction strategy: "knn", "epsilon", or "adaptive".
    pub construction_method: String,
    /// Number of refinement passes over the hierarchy.
    pub refinement_iter: usize,
    /// Optional seed for the random number generator used during coarsening.
    pub random_state: Option<u64>,
}

impl HierarchicalGraphConstruction {
    pub fn new() -> Self {
        Self {
            n_levels: 3,
            base_k_neighbors: 5,
            neighbor_scaling: 1.5,
            coarsening_method: "clustering".to_string(),
            coarsening_ratio: 0.5,
            construction_method: "knn".to_string(),
            refinement_iter: 10,
            random_state: None,
        }
    }

    pub fn n_levels(mut self, levels: usize) -> Self {
        self.n_levels = levels;
        self
    }

    pub fn base_k_neighbors(mut self, k: usize) -> Self {
        self.base_k_neighbors = k;
        self
    }

    pub fn neighbor_scaling(mut self, scaling: f64) -> Self {
        self.neighbor_scaling = scaling;
        self
    }

    pub fn coarsening_method(mut self, method: String) -> Self {
        self.coarsening_method = method;
        self
    }

    pub fn coarsening_ratio(mut self, ratio: f64) -> Self {
        self.coarsening_ratio = ratio;
        self
    }

    pub fn construction_method(mut self, method: String) -> Self {
        self.construction_method = method;
        self
    }

    pub fn refinement_iter(mut self, iter: usize) -> Self {
        self.refinement_iter = iter;
        self
    }

    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    pub fn fit(&self, X: &ArrayView2<f64>) -> Result<HierarchicalGraph, SklearsError> {
        // Seed the RNG from `random_state` when provided, otherwise fall back
        // to a fixed default seed.
        let mut rng = if let Some(seed) = self.random_state {
            Random::seed(seed)
        } else {
            Random::seed(42)
        };

        let mut hierarchy = HierarchicalGraph::new();
        let mut current_data = X.to_owned();
        let mut current_indices: Vec<usize> = (0..X.nrows()).collect();

        for level in 0..self.n_levels {
            // Scale the neighbor count geometrically as the levels get coarser.
            let k_neighbors =
                (self.base_k_neighbors as f64 * self.neighbor_scaling.powi(level as i32)) as usize;

            let graph = self.construct_level_graph(&current_data.view(), k_neighbors)?;

            hierarchy.add_level(level, graph, current_data.clone(), current_indices.clone());

            if level < self.n_levels - 1 {
                let (coarsened_data, coarsened_indices) =
                    self.coarsen_level(&current_data.view(), &current_indices, &mut rng)?;
                current_data = coarsened_data;
                current_indices = coarsened_indices;
            }
        }

        hierarchy = self.refine_hierarchy(hierarchy)?;

        Ok(hierarchy)
    }

    fn construct_level_graph(
        &self,
        X: &ArrayView2<f64>,
        k_neighbors: usize,
    ) -> Result<Array2<f64>, SklearsError> {
        let n_samples = X.nrows();
        let mut graph = Array2::<f64>::zeros((n_samples, n_samples));

        match self.construction_method.as_str() {
            "knn" => {
                for i in 0..n_samples {
                    let mut distances: Vec<(f64, usize)> = Vec::new();

                    for j in 0..n_samples {
                        if i != j {
                            let dist = self.euclidean_distance(&X.row(i), &X.row(j));
                            distances.push((dist, j));
                        }
                    }

                    distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

                    // Connect each point to its k nearest neighbors with a
                    // Gaussian weight exp(-d^2 / 2).
                    for (dist, j) in distances.iter().take(k_neighbors.min(distances.len())) {
                        let weight = (-dist.powi(2) / 2.0).exp();
                        graph[[i, *j]] = weight;
                    }
                }

                // Symmetrize the kNN graph by averaging the two directed weights.
                for i in 0..n_samples {
                    for j in i + 1..n_samples {
                        let avg_weight = (graph[[i, j]] + graph[[j, i]]) / 2.0;
                        graph[[i, j]] = avg_weight;
                        graph[[j, i]] = avg_weight;
                    }
                }
            }
            "epsilon" => {
                let epsilon = self.compute_adaptive_epsilon(X, k_neighbors)?;

                for i in 0..n_samples {
                    for j in i + 1..n_samples {
                        let dist = self.euclidean_distance(&X.row(i), &X.row(j));
                        if dist <= epsilon {
                            let weight = (-dist.powi(2) / 2.0).exp();
                            graph[[i, j]] = weight;
                            graph[[j, i]] = weight;
                        }
                    }
                }
            }
            "adaptive" => {
                graph = self.construct_adaptive_graph(X, k_neighbors)?;
            }
            _ => {
                return Err(SklearsError::InvalidInput(format!(
                    "Unknown construction method: {}",
                    self.construction_method
                )));
            }
        }

        Ok(graph)
    }

    /// Adaptive epsilon radius: the median of each point's k-th
    /// nearest-neighbor distance.
    fn compute_adaptive_epsilon(
        &self,
        X: &ArrayView2<f64>,
        k_neighbors: usize,
    ) -> Result<f64, SklearsError> {
        let n_samples = X.nrows();
        let mut kth_distances = Vec::new();

        for i in 0..n_samples {
            let mut distances = Vec::new();
            for j in 0..n_samples {
                if i != j {
                    let dist = self.euclidean_distance(&X.row(i), &X.row(j));
                    distances.push(dist);
                }
            }

            if !distances.is_empty() {
                distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
                let k_idx = k_neighbors.min(distances.len()) - 1;
                kth_distances.push(distances[k_idx]);
            }
        }

        if kth_distances.is_empty() {
            return Ok(1.0);
        }

        kth_distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
        let median_idx = kth_distances.len() / 2;
        Ok(kth_distances[median_idx])
    }

    fn construct_adaptive_graph(
        &self,
        X: &ArrayView2<f64>,
        base_k: usize,
    ) -> Result<Array2<f64>, SklearsError> {
        let n_samples = X.nrows();
        let mut graph = Array2::<f64>::zeros((n_samples, n_samples));

        let densities = self.compute_local_densities(X, base_k)?;

        for i in 0..n_samples {
            // Use fewer neighbors in dense regions and more in sparse ones.
            let adaptive_k = (base_k as f64 * (1.0 / (1.0 + densities[i]))).max(1.0) as usize;

            let mut distances: Vec<(f64, usize)> = Vec::new();
            for j in 0..n_samples {
                if i != j {
                    let dist = self.euclidean_distance(&X.row(i), &X.row(j));
                    distances.push((dist, j));
                }
            }

            distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

            for (dist, j) in distances.iter().take(adaptive_k.min(distances.len())) {
                let weight = (-dist.powi(2) / 2.0).exp();
                graph[[i, *j]] = weight;
            }
        }

        for i in 0..n_samples {
            for j in i + 1..n_samples {
                let avg_weight = (graph[[i, j]] + graph[[j, i]]) / 2.0;
                graph[[i, j]] = avg_weight;
                graph[[j, i]] = avg_weight;
            }
        }

        Ok(graph)
    }

    fn compute_local_densities(
        &self,
        X: &ArrayView2<f64>,
        k: usize,
    ) -> Result<Array1<f64>, SklearsError> {
        let n_samples = X.nrows();
        let mut densities = Array1::zeros(n_samples);

        for i in 0..n_samples {
            let mut distances = Vec::new();
            for j in 0..n_samples {
                if i != j {
                    let dist = self.euclidean_distance(&X.row(i), &X.row(j));
                    distances.push(dist);
                }
            }

            if !distances.is_empty() {
                distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
                let k_idx = k.min(distances.len()) - 1;
                densities[i] = 1.0 / (1.0 + distances[k_idx]);
            }
        }

        Ok(densities)
    }

    fn coarsen_level<R>(
        &self,
        X: &ArrayView2<f64>,
        indices: &[usize],
        rng: &mut Random<R>,
    ) -> Result<(Array2<f64>, Vec<usize>), SklearsError>
    where
        R: scirs2_core::random::Rng,
    {
        let n_samples = X.nrows();
        let target_size = ((n_samples as f64) * self.coarsening_ratio).max(1.0) as usize;

        match self.coarsening_method.as_str() {
            "sampling" => self.coarsen_by_sampling(X, indices, target_size, rng),
            "clustering" => self.coarsen_by_clustering(X, indices, target_size),
            "pooling" => self.coarsen_by_pooling(X, indices, target_size),
            _ => Err(SklearsError::InvalidInput(format!(
                "Unknown coarsening method: {}",
                self.coarsening_method
            ))),
        }
    }

    fn coarsen_by_sampling<R>(
        &self,
        X: &ArrayView2<f64>,
        indices: &[usize],
        target_size: usize,
        rng: &mut Random<R>,
    ) -> Result<(Array2<f64>, Vec<usize>), SklearsError>
    where
        R: scirs2_core::random::Rng,
    {
        let n_samples = X.nrows();
        let mut selected_indices: Vec<usize> = (0..n_samples).collect();
        selected_indices.shuffle(rng);
        selected_indices.truncate(target_size);
        selected_indices.sort();

        let mut coarsened_data = Array2::<f64>::zeros((target_size, X.ncols()));
        let mut coarsened_indices = Vec::new();

        for (i, &idx) in selected_indices.iter().enumerate() {
            coarsened_data.row_mut(i).assign(&X.row(idx));
            coarsened_indices.push(indices[idx]);
        }

        Ok((coarsened_data, coarsened_indices))
    }

    fn coarsen_by_clustering(
        &self,
        X: &ArrayView2<f64>,
        indices: &[usize],
        target_size: usize,
    ) -> Result<(Array2<f64>, Vec<usize>), SklearsError> {
        let n_samples = X.nrows();
        let n_features = X.ncols();

        if target_size >= n_samples {
            return Ok((X.to_owned(), indices.to_vec()));
        }

        // Greedy farthest-point selection: start from the first point and
        // repeatedly add the point farthest from the current set of centers.
        let mut centers = Vec::new();
        let mut center_indices = Vec::new();

        centers.push(X.row(0).to_owned());
        center_indices.push(0);

        for _ in 1..target_size {
            let mut max_dist = 0.0;
            let mut farthest_idx = 0;

            for i in 0..n_samples {
                let mut min_dist_to_centers = f64::INFINITY;

                for center in &centers {
                    let dist = self.euclidean_distance(&X.row(i), &center.view());
                    min_dist_to_centers = min_dist_to_centers.min(dist);
                }

                if min_dist_to_centers > max_dist {
                    max_dist = min_dist_to_centers;
                    farthest_idx = i;
                }
            }

            centers.push(X.row(farthest_idx).to_owned());
            center_indices.push(farthest_idx);
        }

        let mut coarsened_data = Array2::<f64>::zeros((target_size, n_features));
        let mut coarsened_indices = Vec::new();

        for (i, &idx) in center_indices.iter().enumerate() {
            coarsened_data.row_mut(i).assign(&X.row(idx));
            coarsened_indices.push(indices[idx]);
        }

        Ok((coarsened_data, coarsened_indices))
    }

    fn coarsen_by_pooling(
        &self,
        X: &ArrayView2<f64>,
        indices: &[usize],
        target_size: usize,
    ) -> Result<(Array2<f64>, Vec<usize>), SklearsError> {
        let n_samples = X.nrows();
        let n_features = X.ncols();

        if target_size >= n_samples {
            return Ok((X.to_owned(), indices.to_vec()));
        }

        let pool_size = n_samples / target_size;
        let mut coarsened_data = Array2::<f64>::zeros((target_size, n_features));
        let mut coarsened_indices = Vec::new();

        for i in 0..target_size {
            let start_idx = i * pool_size;
            let end_idx = if i == target_size - 1 {
                n_samples
            } else {
                (i + 1) * pool_size
            };

            // Average the points that fall into this pool.
            let mut pool_mean = Array1::zeros(n_features);
            let mut count = 0;

            for j in start_idx..end_idx {
                pool_mean = pool_mean + X.row(j);
                count += 1;
            }

            if count > 0 {
                pool_mean /= count as f64;
            }

            coarsened_data.row_mut(i).assign(&pool_mean);
            coarsened_indices.push(indices[start_idx]);
        }

        Ok((coarsened_data, coarsened_indices))
    }

    fn refine_hierarchy(
        &self,
        mut hierarchy: HierarchicalGraph,
    ) -> Result<HierarchicalGraph, SklearsError> {
        for _iter in 0..self.refinement_iter {
            for level in 1..hierarchy.levels.len() {
                hierarchy = self.refine_level(hierarchy, level)?;
            }
        }
        Ok(hierarchy)
    }

    /// Refine the finer graph at `level - 1` by blending in structure from the
    /// coarser graph at `level` (level 0 is the finest level).
    fn refine_level(
        &self,
        mut hierarchy: HierarchicalGraph,
        level: usize,
    ) -> Result<HierarchicalGraph, SklearsError> {
        if level == 0 || level >= hierarchy.levels.len() {
            return Ok(hierarchy);
        }

        let finer_graph = hierarchy.levels[level - 1].graph.clone();
        let coarser_graph = hierarchy.levels[level].graph.clone();

        let refined_graph = self.interpolate_graphs(&finer_graph, &coarser_graph)?;
        hierarchy.levels[level - 1].graph = refined_graph;

        Ok(hierarchy)
    }

    fn interpolate_graphs(
        &self,
        fine_graph: &Array2<f64>,
        coarse_graph: &Array2<f64>,
    ) -> Result<Array2<f64>, SklearsError> {
        // Blend weight between the fine graph and the coarse-graph prior.
        let alpha = 0.8;
        let fine_size = fine_graph.nrows();
        let coarse_size = coarse_graph.nrows();

        if fine_size <= coarse_size {
            return Ok(fine_graph.clone());
        }

        let mut refined = fine_graph.clone();

        let scale_factor = coarse_size as f64 / fine_size as f64;

        for i in 0..fine_size {
            for j in 0..fine_size {
                let coarse_i = ((i as f64) * scale_factor) as usize;
                let coarse_j = ((j as f64) * scale_factor) as usize;

                if coarse_i < coarse_size && coarse_j < coarse_size {
                    let coarse_weight = coarse_graph[[coarse_i, coarse_j]];
                    refined[[i, j]] = alpha * fine_graph[[i, j]] + (1.0 - alpha) * coarse_weight;
                }
            }
        }

        Ok(refined)
    }

    fn euclidean_distance(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
        x1.iter()
            .zip(x2.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f64>()
            .sqrt()
    }
}

impl Default for HierarchicalGraphConstruction {
    fn default() -> Self {
        Self::new()
    }
}

/// A multi-level hierarchy of graphs produced by `HierarchicalGraphConstruction`.
#[derive(Clone)]
pub struct HierarchicalGraph {
    /// Levels ordered from finest (index 0) to coarsest.
    pub levels: Vec<HierarchyLevel>,
}

impl HierarchicalGraph {
    pub fn new() -> Self {
        Self { levels: Vec::new() }
    }

    pub fn add_level(
        &mut self,
        level_id: usize,
        graph: Array2<f64>,
        data: Array2<f64>,
        indices: Vec<usize>,
    ) {
        let level = HierarchyLevel {
            level_id,
            graph,
            data,
            indices,
        };
        self.levels.push(level);
    }

    pub fn finest_graph(&self) -> Option<&Array2<f64>> {
        self.levels.first().map(|level| &level.graph)
    }

    pub fn coarsest_graph(&self) -> Option<&Array2<f64>> {
        self.levels.last().map(|level| &level.graph)
    }

    pub fn level_graph(&self, level_id: usize) -> Option<&Array2<f64>> {
        self.levels.get(level_id).map(|level| &level.graph)
    }

    pub fn n_levels(&self) -> usize {
        self.levels.len()
    }
}

impl Default for HierarchicalGraph {
    fn default() -> Self {
        Self::new()
    }
}

/// One level of the hierarchy: its graph, the data points at that level, and
/// the indices of those points in the original data set.
#[derive(Clone)]
pub struct HierarchyLevel {
    pub level_id: usize,
    pub graph: Array2<f64>,
    pub data: Array2<f64>,
    pub indices: Vec<usize>,
}

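/// Multi-scale semi-supervised label propagation on a graph hierarchy.
///
/// Labeled samples carry their class index and unlabeled samples carry `-1`;
/// `fit` builds the hierarchy with the configured
/// `HierarchicalGraphConstruction` and propagates labels across it.
///
/// A minimal usage sketch (mirroring the tests at the bottom of this file):
///
/// ```ignore
/// use scirs2_core::array;
///
/// let data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
/// let targets = array![0, 1, -1, -1]; // -1 marks unlabeled samples
/// let labels = MultiScaleSemiSupervised::new()
///     .alpha(0.2)
///     .max_iter(100)
///     .fit(&data.view(), &targets.view())
///     .unwrap();
/// assert_eq!(labels.len(), 4);
/// ```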
#[derive(Clone)]
pub struct MultiScaleSemiSupervised {
    /// Builder used to construct the graph hierarchy.
    pub graph_builder: HierarchicalGraphConstruction,
    /// Propagation coefficient applied at each iteration.
    pub alpha: f64,
    /// Maximum number of propagation iterations.
    pub max_iter: usize,
    /// Convergence threshold on the total change of the label matrix.
    pub tolerance: f64,
    /// How levels are combined: "fine_to_coarse", "coarse_to_fine", or "simultaneous".
    pub combination_method: String,
    /// Optional random seed forwarded to the graph builder.
    pub random_state: Option<u64>,
}

impl MultiScaleSemiSupervised {
    pub fn new() -> Self {
        Self {
            graph_builder: HierarchicalGraphConstruction::new(),
            alpha: 0.2,
            max_iter: 1000,
            tolerance: 1e-6,
            combination_method: "fine_to_coarse".to_string(),
            random_state: None,
        }
    }

    pub fn graph_builder(mut self, builder: HierarchicalGraphConstruction) -> Self {
        self.graph_builder = builder;
        self
    }

    pub fn alpha(mut self, alpha: f64) -> Self {
        self.alpha = alpha;
        self
    }

    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    pub fn tolerance(mut self, tol: f64) -> Self {
        self.tolerance = tol;
        self
    }

    pub fn combination_method(mut self, method: String) -> Self {
        self.combination_method = method;
        self
    }

    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self.graph_builder = self.graph_builder.random_state(seed);
        self
    }

    pub fn fit(
        &self,
        X: &ArrayView2<f64>,
        y: &ArrayView1<i32>,
    ) -> Result<Array1<i32>, SklearsError> {
        let n_samples = X.nrows();

        if y.len() != n_samples {
            return Err(SklearsError::ShapeMismatch {
                expected: format!("X and y should have same number of samples: {}", X.nrows()),
                actual: format!("X has {} samples, y has {} samples", X.nrows(), y.len()),
            });
        }

        let hierarchy = self.graph_builder.fit(X)?;

        let labels = match self.combination_method.as_str() {
            "fine_to_coarse" => self.propagate_fine_to_coarse(&hierarchy, y)?,
            "coarse_to_fine" => self.propagate_coarse_to_fine(&hierarchy, y)?,
            "simultaneous" => self.propagate_simultaneous(&hierarchy, y)?,
            _ => {
                return Err(SklearsError::InvalidInput(format!(
                    "Unknown combination method: {}",
                    self.combination_method
                )))
            }
        };

        Ok(labels)
    }

    fn propagate_fine_to_coarse(
        &self,
        hierarchy: &HierarchicalGraph,
        y: &ArrayView1<i32>,
    ) -> Result<Array1<i32>, SklearsError> {
        // Fine-to-coarse combination propagates on the finest graph, which
        // already covers every original sample.
        let finest_graph = hierarchy
            .finest_graph()
            .ok_or_else(|| SklearsError::InvalidInput("Empty hierarchy".to_string()))?;

        let labels = self.propagate_labels(finest_graph, y)?;

        Ok(labels)
    }

    fn propagate_coarse_to_fine(
        &self,
        hierarchy: &HierarchicalGraph,
        y: &ArrayView1<i32>,
    ) -> Result<Array1<i32>, SklearsError> {
        if hierarchy.levels.is_empty() {
            return Err(SklearsError::InvalidInput("Empty hierarchy".to_string()));
        }

        // Start from the coarsest level and push the propagated labels down
        // through the finer levels.
        let coarsest_level = &hierarchy.levels[hierarchy.levels.len() - 1];

        let coarse_labels = self.map_labels_to_level(y, &coarsest_level.indices)?;

        let mut propagated_labels =
            self.propagate_labels(&coarsest_level.graph, &coarse_labels.view())?;

        for level_idx in (0..hierarchy.levels.len() - 1).rev() {
            let level = &hierarchy.levels[level_idx];
            let refined_labels = self.refine_labels_for_level(
                &propagated_labels,
                &level.indices,
                level.data.nrows(),
            )?;
            propagated_labels = self.propagate_labels(&level.graph, &refined_labels.view())?;
        }

        Ok(propagated_labels)
    }

    fn propagate_simultaneous(
        &self,
        hierarchy: &HierarchicalGraph,
        y: &ArrayView1<i32>,
    ) -> Result<Array1<i32>, SklearsError> {
        if hierarchy.levels.is_empty() {
            return Err(SklearsError::InvalidInput("Empty hierarchy".to_string()));
        }

        // The simultaneous strategy currently falls back to fine-to-coarse
        // propagation.
        self.propagate_fine_to_coarse(hierarchy, y)
    }

    fn map_labels_to_level(
        &self,
        y: &ArrayView1<i32>,
        level_indices: &[usize],
    ) -> Result<Array1<i32>, SklearsError> {
        let mut mapped_labels = Array1::from_elem(level_indices.len(), -1);

        for (i, &original_idx) in level_indices.iter().enumerate() {
            if original_idx < y.len() {
                mapped_labels[i] = y[original_idx];
            }
        }

        Ok(mapped_labels)
    }

    fn refine_labels_for_level(
        &self,
        coarse_labels: &Array1<i32>,
        level_indices: &[usize],
        level_size: usize,
    ) -> Result<Array1<i32>, SklearsError> {
        let mut refined_labels = Array1::from_elem(level_size, -1);

        for (i, &original_idx) in level_indices.iter().enumerate() {
            // Guard both indices so out-of-range entries are skipped rather
            // than causing an out-of-bounds panic.
            if i < coarse_labels.len() && original_idx < level_size {
                refined_labels[original_idx] = coarse_labels[i];
            }
        }

        Ok(refined_labels)
    }

    fn propagate_labels(
        &self,
        graph: &Array2<f64>,
        y: &ArrayView1<i32>,
    ) -> Result<Array1<i32>, SklearsError> {
        let n_samples = graph.nrows();

        if y.len() != n_samples {
            return Err(SklearsError::ShapeMismatch {
                expected: format!(
                    "Graph and labels should have same number of samples: {}",
                    graph.nrows()
                ),
                actual: format!(
                    "Graph has {} samples, labels has {} samples",
                    graph.nrows(),
                    y.len()
                ),
            });
        }

        let labeled_mask: Array1<bool> = y.iter().map(|&label| label != -1).collect();
        let unique_labels: Vec<i32> = y
            .iter()
            .filter(|&&label| label != -1)
            .cloned()
            .collect::<std::collections::HashSet<_>>()
            .into_iter()
            .collect();

        if unique_labels.is_empty() {
            return Ok(Array1::from_elem(n_samples, -1));
        }

        let n_classes = unique_labels.len();

        // One-hot label matrix: labeled samples get probability 1 for their class.
        let mut F = Array2::<f64>::zeros((n_samples, n_classes));

        for i in 0..n_samples {
            if labeled_mask[i] {
                if let Some(class_idx) = unique_labels.iter().position(|&x| x == y[i]) {
                    F[[i, class_idx]] = 1.0;
                }
            }
        }

        let P = self.normalize_graph(graph)?;

        // Iterate F <- alpha * P * F, re-clamping the labeled rows each step,
        // until the total change drops below the tolerance.
        for _iter in 0..self.max_iter {
            let F_old = F.clone();

            let propagated = P.dot(&F);
            F = &propagated * self.alpha;

            for i in 0..n_samples {
                if labeled_mask[i] {
                    F.row_mut(i).fill(0.0);
                    if let Some(class_idx) = unique_labels.iter().position(|&x| x == y[i]) {
                        F[[i, class_idx]] = 1.0;
                    }
                }
            }

            let change = (&F - &F_old).iter().map(|x| x.abs()).sum::<f64>();
            if change < self.tolerance {
                break;
            }
        }

        // Assign each sample the class with the highest propagated score.
        let mut labels = Array1::zeros(n_samples);
        for i in 0..n_samples {
            let mut max_prob = 0.0;
            let mut max_class = 0;

            for j in 0..n_classes {
                if F[[i, j]] > max_prob {
                    max_prob = F[[i, j]];
                    max_class = j;
                }
            }

            labels[i] = unique_labels[max_class];
        }

        Ok(labels)
    }
913
914 fn normalize_graph(&self, graph: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
916 let n_samples = graph.nrows();
917 let mut P = graph.clone();
918
919 for i in 0..n_samples {
920 let row_sum: f64 = P.row(i).sum();
921 if row_sum > 0.0 {
922 for j in 0..n_samples {
923 P[[i, j]] /= row_sum;
924 }
925 }
926 }
927
928 Ok(P)
929 }
930}
931
932impl Default for MultiScaleSemiSupervised {
933 fn default() -> Self {
934 Self::new()
935 }
936}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::array;

    #[test]
    #[allow(non_snake_case)]
    fn test_hierarchical_graph_construction() {
        let X = array![
            [1.0, 2.0],
            [2.0, 3.0],
            [3.0, 4.0],
            [4.0, 5.0],
            [5.0, 6.0],
            [6.0, 7.0]
        ];

        let hgc = HierarchicalGraphConstruction::new()
            .n_levels(3)
            .base_k_neighbors(2)
            .coarsening_method("clustering".to_string())
            .coarsening_ratio(0.5);

        let result = hgc.fit(&X.view());
        assert!(result.is_ok());

        let hierarchy = result.unwrap();
        assert_eq!(hierarchy.n_levels(), 3);

        for level in 0..hierarchy.n_levels() {
            let graph = hierarchy.level_graph(level).unwrap();
            assert!(graph.nrows() > 0);
            assert_eq!(graph.nrows(), graph.ncols());
        }
    }
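
    // Illustrative check: the kNN-based level graph uses Gaussian edge weights
    // exp(-d^2 / 2) and is symmetrized by averaging, so every entry should lie
    // in [0, 1] and the matrix should be symmetric. This exercises the private
    // `construct_level_graph` helper directly.
    #[test]
    #[allow(non_snake_case)]
    fn test_knn_graph_weights_symmetric_and_bounded() {
        let X = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [2.0, 2.0]];

        let hgc = HierarchicalGraphConstruction::new();
        let graph = hgc.construct_level_graph(&X.view(), 2).unwrap();

        for i in 0..graph.nrows() {
            for j in 0..graph.ncols() {
                assert!(graph[[i, j]] >= 0.0 && graph[[i, j]] <= 1.0);
                assert_abs_diff_eq!(graph[[i, j]], graph[[j, i]], epsilon = 1e-12);
            }
        }
    }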

    #[test]
    #[allow(non_snake_case)]
    fn test_coarsening_methods() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];

        let methods = vec!["sampling", "clustering", "pooling"];

        for method in methods {
            let hgc = HierarchicalGraphConstruction::new()
                .n_levels(2)
                .coarsening_method(method.to_string())
                .coarsening_ratio(0.5)
                .random_state(42);

            let result = hgc.fit(&X.view());
            assert!(result.is_ok());

            let hierarchy = result.unwrap();
            assert_eq!(hierarchy.n_levels(), 2);
        }
    }
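
    // Illustrative check: the "clustering" coarsening uses greedy
    // farthest-point selection, so for these well-separated points it should
    // keep the requested number of distinct rows.
    #[test]
    #[allow(non_snake_case)]
    fn test_clustering_coarsening_keeps_distinct_points() {
        let X = array![[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [10.0, 0.0]];
        let indices: Vec<usize> = (0..4).collect();

        let hgc = HierarchicalGraphConstruction::new();
        let (coarse, kept) = hgc.coarsen_by_clustering(&X.view(), &indices, 2).unwrap();

        assert_eq!(coarse.nrows(), 2);
        assert_eq!(kept.len(), 2);
        // With distinct input points, the farthest point from the first center
        // is never the center itself.
        assert_ne!(kept[0], kept[1]);
    }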

    #[test]
    #[allow(non_snake_case)]
    fn test_construction_methods() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let methods = vec!["knn", "epsilon", "adaptive"];

        for method in methods {
            let hgc = HierarchicalGraphConstruction::new()
                .n_levels(2)
                .construction_method(method.to_string())
                .base_k_neighbors(2);

            let result = hgc.fit(&X.view());
            assert!(result.is_ok());

            let hierarchy = result.unwrap();
            assert_eq!(hierarchy.n_levels(), 2);
        }
    }
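
    // Illustrative check: `normalize_graph` divides each row with nonzero sum
    // by that sum, so the propagation matrix is row-stochastic.
    #[test]
    #[allow(non_snake_case)]
    fn test_normalize_graph_rows_sum_to_one() {
        let graph = array![[0.0, 2.0, 1.0], [2.0, 0.0, 3.0], [1.0, 3.0, 0.0]];

        let mssl = MultiScaleSemiSupervised::new();
        let P = mssl.normalize_graph(&graph).unwrap();

        for i in 0..P.nrows() {
            assert_abs_diff_eq!(P.row(i).sum(), 1.0, epsilon = 1e-12);
        }
    }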

    #[test]
    #[allow(non_snake_case)]
    fn test_multi_scale_semi_supervised() {
        let X = array![
            [1.0, 2.0],
            [2.0, 3.0],
            [3.0, 4.0],
            [4.0, 5.0],
            [5.0, 6.0],
            [6.0, 7.0]
        ];
        // -1 marks unlabeled samples.
        let y = array![0, 1, -1, -1, -1, -1];

        let graph_builder = HierarchicalGraphConstruction::new()
            .n_levels(2)
            .base_k_neighbors(2)
            .coarsening_ratio(0.5)
            .random_state(42);

        let mssl = MultiScaleSemiSupervised::new()
            .graph_builder(graph_builder)
            .alpha(0.2)
            .max_iter(100)
            .combination_method("fine_to_coarse".to_string());

        let result = mssl.fit(&X.view(), &y.view());
        assert!(result.is_ok());

        let labels = result.unwrap();
        assert_eq!(labels.len(), 6);

        assert_eq!(labels[0], 0);
        assert_eq!(labels[1], 1);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_combination_methods() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1];

        let methods = vec!["fine_to_coarse", "coarse_to_fine", "simultaneous"];

        for method in methods {
            let graph_builder = HierarchicalGraphConstruction::new()
                .n_levels(2)
                .base_k_neighbors(2)
                .random_state(42);

            let mssl = MultiScaleSemiSupervised::new()
                .graph_builder(graph_builder)
                .combination_method(method.to_string())
                .max_iter(50);

            let result = mssl.fit(&X.view(), &y.view());
            assert!(result.is_ok());

            let labels = result.unwrap();
            assert_eq!(labels.len(), 4);
            assert!(labels[0] == 0 || labels[0] == 1);
            assert!(labels[1] == 0 || labels[1] == 1);
        }
    }
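
    // Illustrative check: on a graph made of two disconnected pairs of nodes,
    // `propagate_labels` should copy each seed label onto the unlabeled node
    // it is connected to.
    #[test]
    fn test_propagate_labels_two_components() {
        // Nodes 0-1 form one component, nodes 2-3 the other.
        let graph = array![
            [0.0, 1.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
            [0.0, 0.0, 1.0, 0.0]
        ];
        let y = array![0, -1, 1, -1];

        let mssl = MultiScaleSemiSupervised::new().max_iter(100);
        let labels = mssl.propagate_labels(&graph, &y.view()).unwrap();

        assert_eq!(labels[0], 0);
        assert_eq!(labels[1], 0);
        assert_eq!(labels[2], 1);
        assert_eq!(labels[3], 1);
    }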

    #[test]
    #[allow(non_snake_case)]
    fn test_hierarchical_graph_error_cases() {
        let hgc = HierarchicalGraphConstruction::new().construction_method("invalid".to_string());

        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let result = hgc.fit(&X.view());
        assert!(result.is_err());
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_multi_scale_error_cases() {
        let mssl = MultiScaleSemiSupervised::new();

        let X = array![[1.0, 2.0], [2.0, 3.0]];
        // y has fewer samples than X, so fit must return a shape-mismatch error.
        let y = array![0];
        let result = mssl.fit(&X.view(), &y.view());
        assert!(result.is_err());
    }
}