1use crate::base::SelectorMixin;
83use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
84use sklears_core::error::{Result as SklResult, SklearsError};
85use sklears_core::traits::{Estimator, Fit, Transform};
86use std::collections::HashMap;
87use std::marker::PhantomData;
88
89type Result<T> = SklResult<T>;
90type Float = f64;
91
/// Type-state marker for a selector that has not been fitted yet.
#[derive(Debug, Clone)]
pub struct Untrained;
94
/// Type-state marker plus payload produced by a successful `fit`.
#[derive(Debug, Clone)]
pub struct Trained {
    /// Indices (into the original feature axis) of the retained features.
    selected_features: Vec<usize>,
    /// Combined per-feature score used for the selection decision.
    feature_scores: Array1<Float>,
    /// Per-node centrality vectors keyed by centrality type, when computed.
    centrality_scores: Option<HashMap<String, Array1<Float>>>,
    /// Community label per graph node, when community detection ran.
    community_assignments: Option<Array1<usize>>,
    /// Per-feature structural (clustering-based) scores, when computed.
    structural_scores: Option<Array1<Float>>,
    /// Number of features seen at fit time; used to validate transforms.
    n_features: usize,
}
104
/// Graph-based feature selector.
///
/// Scores features by correlating them with graph-derived node statistics
/// (centrality, within-community variance, clustering structure) computed
/// from a user-supplied adjacency matrix whose nodes correspond to samples,
/// then keeps the top-`k` features or those above `centrality_threshold`.
///
/// Uses the type-state pattern: `State` is `Untrained` until `fit` is
/// called, which yields a `GraphFeatureSelector<Trained>`.
#[derive(Debug, Clone)]
pub struct GraphFeatureSelector<State = Untrained> {
    /// Whether centrality-correlation scores contribute to feature scores.
    include_centrality: bool,
    /// Whether community-variance scores contribute to feature scores.
    include_community: bool,
    /// Whether clustering-coefficient scores contribute to feature scores.
    include_structural: bool,
    /// Minimum combined score for threshold-based selection (used when `k` is `None`).
    centrality_threshold: Float,
    /// Centrality measures to compute ("degree", "pagerank", "betweenness", "closeness").
    centrality_types: Vec<String>,
    /// Community-detection algorithm name ("modularity" or "louvain").
    community_method: String,
    /// Communities smaller than this are merged into a neighbour community.
    min_community_size: usize,
    /// Multiplier applied to community-based feature scores.
    community_weight: Float,
    /// Multiplier applied to structural feature scores.
    structural_weight: Float,
    /// When set, keep exactly the top-k features instead of thresholding.
    k: Option<usize>,
    /// PageRank damping factor.
    damping_factor: Float,
    /// Iteration cap for iterative algorithms (PageRank).
    max_iterations: usize,
    /// Convergence tolerance for iterative algorithms.
    tolerance: Float,
    /// Sample-by-sample adjacency matrix; required before fitting.
    adjacency: Option<Array2<Float>>,
    /// Zero-sized type-state marker.
    state: PhantomData<State>,
    /// Fit results; `Some` only in the `Trained` state.
    trained_state: Option<Trained>,
}
131
/// `Default` delegates to [`GraphFeatureSelector::new`].
impl Default for GraphFeatureSelector<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}
137
impl GraphFeatureSelector<Untrained> {
    /// Creates a selector with default settings: all three scoring families
    /// enabled, degree + pagerank centralities, modularity-based community
    /// detection, and threshold-based selection (no fixed `k`).
    ///
    /// An adjacency matrix must still be supplied (via the builder's
    /// `with_adjacency`) before `fit` will succeed.
    pub fn new() -> Self {
        Self {
            include_centrality: true,
            include_community: true,
            include_structural: true,
            centrality_threshold: 0.1,
            centrality_types: vec!["degree".to_string(), "pagerank".to_string()],
            community_method: "modularity".to_string(),
            min_community_size: 2,
            community_weight: 0.5,
            structural_weight: 0.3,
            k: None,
            damping_factor: 0.85,
            max_iterations: 100,
            tolerance: 1e-6,
            adjacency: None,
            state: PhantomData,
            trained_state: None,
        }
    }

    /// Returns a builder for fluent configuration of all selector options.
    pub fn builder() -> GraphFeatureSelectorBuilder {
        GraphFeatureSelectorBuilder::new()
    }
}
166
/// Fluent builder for [`GraphFeatureSelector`].
///
/// Fields mirror the selector's configuration one-to-one; see the selector
/// struct for each field's meaning. Finish with [`GraphFeatureSelectorBuilder::build`].
#[derive(Debug)]
pub struct GraphFeatureSelectorBuilder {
    include_centrality: bool,
    include_community: bool,
    include_structural: bool,
    centrality_threshold: Float,
    centrality_types: Vec<String>,
    community_method: String,
    min_community_size: usize,
    community_weight: Float,
    structural_weight: Float,
    k: Option<usize>,
    damping_factor: Float,
    max_iterations: usize,
    tolerance: Float,
    adjacency: Option<Array2<Float>>,
}
185
/// `Default` delegates to [`GraphFeatureSelectorBuilder::new`].
impl Default for GraphFeatureSelectorBuilder {
    fn default() -> Self {
        Self::new()
    }
}
191
impl GraphFeatureSelectorBuilder {
    /// Starts a builder pre-loaded with the same defaults as
    /// `GraphFeatureSelector::new`.
    pub fn new() -> Self {
        Self {
            include_centrality: true,
            include_community: true,
            include_structural: true,
            centrality_threshold: 0.1,
            centrality_types: vec!["degree".to_string(), "pagerank".to_string()],
            community_method: "modularity".to_string(),
            min_community_size: 2,
            community_weight: 0.5,
            structural_weight: 0.3,
            k: None,
            damping_factor: 0.85,
            max_iterations: 100,
            tolerance: 1e-6,
            adjacency: None,
        }
    }

    /// Enables/disables centrality-based feature scoring.
    pub fn include_centrality(mut self, include: bool) -> Self {
        self.include_centrality = include;
        self
    }

    /// Enables/disables community-based feature scoring.
    pub fn include_community(mut self, include: bool) -> Self {
        self.include_community = include;
        self
    }

    /// Enables/disables structural (clustering-coefficient) feature scoring.
    pub fn include_structural(mut self, include: bool) -> Self {
        self.include_structural = include;
        self
    }

    /// Sets the minimum combined score for threshold-based selection.
    pub fn centrality_threshold(mut self, threshold: Float) -> Self {
        self.centrality_threshold = threshold;
        self
    }

    /// Sets which centrality measures to compute
    /// (e.g. "degree", "pagerank", "betweenness", "closeness").
    pub fn centrality_types(mut self, types: Vec<&str>) -> Self {
        self.centrality_types = types.iter().map(|s| s.to_string()).collect();
        self
    }

    /// Sets the community-detection algorithm ("modularity" or "louvain").
    pub fn community_method(mut self, method: &str) -> Self {
        self.community_method = method.to_string();
        self
    }

    /// Sets the minimum community size; smaller communities are merged.
    pub fn min_community_size(mut self, size: usize) -> Self {
        self.min_community_size = size;
        self
    }

    /// Sets the multiplier applied to community-based feature scores.
    pub fn community_weight(mut self, weight: Float) -> Self {
        self.community_weight = weight;
        self
    }

    /// Sets the multiplier applied to structural feature scores.
    pub fn structural_weight(mut self, weight: Float) -> Self {
        self.structural_weight = weight;
        self
    }

    /// Selects exactly the top-`k` features (overrides threshold selection).
    pub fn k(mut self, k: usize) -> Self {
        self.k = Some(k);
        self
    }

    /// Sets the PageRank damping factor.
    pub fn damping_factor(mut self, factor: Float) -> Self {
        self.damping_factor = factor;
        self
    }

    /// Sets the iteration cap for iterative algorithms (PageRank).
    pub fn max_iterations(mut self, iterations: usize) -> Self {
        self.max_iterations = iterations;
        self
    }

    /// Sets the convergence tolerance for iterative algorithms.
    pub fn tolerance(mut self, tol: Float) -> Self {
        self.tolerance = tol;
        self
    }

    /// Provides the sample-by-sample adjacency matrix (required for fitting).
    pub fn with_adjacency(mut self, adjacency: Array2<Float>) -> Self {
        self.adjacency = Some(adjacency);
        self
    }

    /// Consumes the builder and produces an untrained selector.
    pub fn build(self) -> GraphFeatureSelector<Untrained> {
        GraphFeatureSelector {
            include_centrality: self.include_centrality,
            include_community: self.include_community,
            include_structural: self.include_structural,
            centrality_threshold: self.centrality_threshold,
            centrality_types: self.centrality_types,
            community_method: self.community_method,
            min_community_size: self.min_community_size,
            community_weight: self.community_weight,
            structural_weight: self.structural_weight,
            k: self.k,
            damping_factor: self.damping_factor,
            max_iterations: self.max_iterations,
            tolerance: self.tolerance,
            adjacency: self.adjacency,
            state: PhantomData,
            trained_state: None,
        }
    }
}
318
/// Estimator plumbing for the untrained selector. Configuration happens
/// through the builder, so the associated `Config` is the unit type.
impl Estimator for GraphFeatureSelector<Untrained> {
    type Config = ();
    type Error = sklears_core::error::SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}
328
/// Estimator plumbing for the trained selector (unit `Config`, same as the
/// untrained state).
impl Estimator for GraphFeatureSelector<Trained> {
    type Config = ();
    type Error = sklears_core::error::SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}
338
impl Fit<Array2<Float>, Array1<Float>> for GraphFeatureSelector<Untrained> {
    type Fitted = GraphFeatureSelector<Trained>;

    /// Fits the selector: scores every feature by combining the enabled
    /// graph signals (centrality correlation, within-community variance,
    /// clustering-structure correlation), then records which features
    /// survive the `k` / threshold cut.
    ///
    /// Consumes `self`; the adjacency matrix is used during fitting but is
    /// NOT retained in the trained selector.
    ///
    /// # Errors
    /// - `InvalidInput` if `y.len() != x.nrows()`.
    /// - `InvalidInput` if no adjacency matrix was provided.
    /// - `InvalidInput` if the adjacency matrix is not `(n_samples, n_samples)`.
    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
        let (n_samples, n_features) = x.dim();

        if y.len() != n_samples {
            return Err(SklearsError::InvalidInput(
                "Number of samples in X and y must match".to_string(),
            ));
        }

        // Graph nodes correspond to samples, so the adjacency matrix is required
        // and must be square with one row/column per sample.
        let adjacency = self.adjacency.ok_or_else(|| {
            SklearsError::InvalidInput(
                "Adjacency matrix is required for graph feature selection".to_string(),
            )
        })?;

        if adjacency.dim() != (n_samples, n_samples) {
            return Err(SklearsError::InvalidInput(
                "Adjacency matrix must be square with same number of nodes as samples".to_string(),
            ));
        }

        let mut centrality_scores = None;
        let mut community_assignments = None;
        let mut structural_scores = None;
        // Scores from each enabled family are summed into this accumulator.
        let mut combined_scores = Array1::zeros(n_features);

        if self.include_centrality {
            // One centrality vector per requested type; unrecognized names
            // silently contribute a zero vector.
            let mut centrality_map = HashMap::new();

            for centrality_type in &self.centrality_types {
                let scores = match centrality_type.as_str() {
                    "degree" => compute_degree_centrality(&adjacency.view()),
                    "pagerank" => compute_pagerank_centrality(
                        &adjacency.view(),
                        self.damping_factor,
                        self.max_iterations,
                        self.tolerance,
                    ),
                    "betweenness" => compute_betweenness_centrality(&adjacency.view()),
                    "closeness" => compute_closeness_centrality(&adjacency.view()),
                    _ => Array1::zeros(n_samples),
                };
                centrality_map.insert(centrality_type.clone(), scores);
            }

            let feature_centrality_scores = compute_feature_centrality_scores(x, &centrality_map)?;
            combined_scores = &combined_scores + &feature_centrality_scores;
            centrality_scores = Some(centrality_map);
        }

        if self.include_community {
            // Unrecognized method names fall back to a single all-zero labeling.
            let communities = match self.community_method.as_str() {
                "modularity" => {
                    detect_communities_modularity(&adjacency.view(), self.min_community_size)
                }
                "louvain" => detect_communities_louvain(&adjacency.view(), self.min_community_size),
                _ => Array1::zeros(n_samples),
            };

            let community_feature_scores =
                compute_community_feature_scores(x, &communities, self.community_weight)?;
            combined_scores = &combined_scores + &community_feature_scores;
            community_assignments = Some(communities);
        }

        if self.include_structural {
            let struct_scores =
                compute_structural_feature_scores(x, &adjacency.view(), self.structural_weight)?;
            combined_scores = &combined_scores + &struct_scores;
            structural_scores = Some(struct_scores);
        }

        // Fixed-size selection takes precedence over the score threshold.
        let selected_features = if let Some(k) = self.k {
            select_top_k_features(&combined_scores, k)
        } else {
            select_features_by_threshold(&combined_scores, self.centrality_threshold)
        };

        let trained_state = Trained {
            selected_features,
            feature_scores: combined_scores,
            centrality_scores,
            community_assignments,
            structural_scores,
            n_features,
        };

        Ok(GraphFeatureSelector {
            include_centrality: self.include_centrality,
            include_community: self.include_community,
            include_structural: self.include_structural,
            centrality_threshold: self.centrality_threshold,
            centrality_types: self.centrality_types,
            community_method: self.community_method,
            min_community_size: self.min_community_size,
            community_weight: self.community_weight,
            structural_weight: self.structural_weight,
            k: self.k,
            damping_factor: self.damping_factor,
            max_iterations: self.max_iterations,
            tolerance: self.tolerance,
            adjacency: None,
            state: PhantomData,
            trained_state: Some(trained_state),
        })
    }
}
454
455impl Transform<Array2<Float>> for GraphFeatureSelector<Trained> {
456 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
457 let trained = self.trained_state.as_ref().ok_or_else(|| {
458 SklearsError::InvalidState("Selector must be fitted before transforming".to_string())
459 })?;
460
461 let (_n_samples, n_features) = x.dim();
462
463 if n_features != trained.n_features {
464 return Err(SklearsError::InvalidInput(format!(
465 "Expected {} features, got {}",
466 trained.n_features, n_features
467 )));
468 }
469
470 if trained.selected_features.is_empty() {
471 return Err(SklearsError::InvalidState(
472 "No features were selected".to_string(),
473 ));
474 }
475
476 let selected_data = x.select(Axis(1), &trained.selected_features);
477 Ok(selected_data)
478 }
479}
480
481impl SelectorMixin for GraphFeatureSelector<Trained> {
482 fn get_support(&self) -> Result<Array1<bool>> {
483 let trained = self.trained_state.as_ref().ok_or_else(|| {
484 SklearsError::InvalidState("Selector must be fitted before getting support".to_string())
485 })?;
486
487 let mut support = Array1::from_elem(trained.n_features, false);
488 for &idx in &trained.selected_features {
489 support[idx] = true;
490 }
491 Ok(support)
492 }
493
494 fn transform_features(&self, indices: &[usize]) -> SklResult<Vec<usize>> {
495 let trained = self.trained_state.as_ref().ok_or_else(|| {
496 SklearsError::InvalidState(
497 "Selector must be fitted before transforming features".to_string(),
498 )
499 })?;
500
501 let selected: Vec<usize> = indices
502 .iter()
503 .filter(|&&idx| trained.selected_features.contains(&idx))
504 .cloned()
505 .collect();
506 Ok(selected)
507 }
508}
509
510fn compute_degree_centrality(adjacency: &ArrayView2<Float>) -> Array1<Float> {
513 let n = adjacency.nrows();
514 let mut centrality = Array1::zeros(n);
515
516 for i in 0..n {
517 let degree: Float = adjacency.row(i).sum();
518 centrality[i] = degree / (n - 1) as Float;
519 }
520
521 centrality
522}
523
524fn compute_pagerank_centrality(
525 adjacency: &ArrayView2<Float>,
526 damping: Float,
527 max_iter: usize,
528 tolerance: Float,
529) -> Array1<Float> {
530 let n = adjacency.nrows();
531 let mut pagerank = Array1::from_elem(n, 1.0 / n as Float);
532 let mut new_pagerank = Array1::zeros(n);
533
534 for _ in 0..max_iter {
535 new_pagerank.fill(0.0);
536
537 for i in 0..n {
538 let out_degree: Float = adjacency.row(i).sum();
539 if out_degree > 0.0 {
540 for j in 0..n {
541 if adjacency[[i, j]] > 0.0 {
542 new_pagerank[j] += damping * pagerank[i] / out_degree;
543 }
544 }
545 }
546 }
547
548 for i in 0..n {
550 new_pagerank[i] += (1.0 - damping) / n as Float;
551 }
552
553 let diff: Float = (&new_pagerank - &pagerank).mapv(|x| x.abs()).sum();
555 if diff < tolerance {
556 break;
557 }
558
559 pagerank.assign(&new_pagerank);
560 }
561
562 pagerank
563}
564
565fn compute_betweenness_centrality(adjacency: &ArrayView2<Float>) -> Array1<Float> {
566 let n = adjacency.nrows();
567 let mut betweenness = Array1::zeros(n);
568
569 for i in 0..n {
571 let mut local_betweenness = 0.0;
572 for j in 0..n {
573 if i != j {
574 for k in 0..n {
575 if k != i && k != j {
576 let direct_jk = if adjacency[[j, k]] > 0.0 {
578 1.0
579 } else {
580 f64::INFINITY
581 };
582 let via_i = if adjacency[[j, i]] > 0.0 && adjacency[[i, k]] > 0.0 {
583 2.0
584 } else {
585 f64::INFINITY
586 };
587
588 if via_i < direct_jk {
589 local_betweenness += 1.0;
590 }
591 }
592 }
593 }
594 }
595 betweenness[i] = local_betweenness / ((n - 1) * (n - 2)) as Float;
596 }
597
598 betweenness
599}
600
601fn compute_closeness_centrality(adjacency: &ArrayView2<Float>) -> Array1<Float> {
602 let n = adjacency.nrows();
603 let mut closeness = Array1::zeros(n);
604
605 for i in 0..n {
606 let mut total_distance = 0.0;
607 let mut reachable_nodes = 0;
608
609 for j in 0..n {
610 if i != j {
611 let distance = if adjacency[[i, j]] > 0.0 {
613 1.0
614 } else {
615 let mut found_path = false;
617 for k in 0..n {
618 if adjacency[[i, k]] > 0.0 && adjacency[[k, j]] > 0.0 {
619 found_path = true;
620 break;
621 }
622 }
623 if found_path {
624 2.0
625 } else {
626 f64::INFINITY
627 }
628 };
629
630 if distance.is_finite() {
631 total_distance += distance;
632 reachable_nodes += 1;
633 }
634 }
635 }
636
637 if reachable_nodes > 0 {
638 closeness[i] = reachable_nodes as Float / total_distance;
639 }
640 }
641
642 closeness
643}
644
645fn compute_feature_centrality_scores(
648 x: &Array2<Float>,
649 centrality_scores: &HashMap<String, Array1<Float>>,
650) -> Result<Array1<Float>> {
651 let (_n_samples, n_features) = x.dim();
652 let mut feature_scores = Array1::zeros(n_features);
653
654 for j in 0..n_features {
655 let feature = x.column(j);
656 let mut total_score = 0.0;
657 let mut weight_sum = 0.0;
658
659 for (centrality_type, centrality) in centrality_scores {
660 let weight = match centrality_type.as_str() {
661 "degree" => 1.0,
662 "pagerank" => 1.5,
663 "betweenness" => 1.2,
664 "closeness" => 1.1,
665 _ => 1.0,
666 };
667
668 let correlation = compute_pearson_correlation(&feature, ¢rality.view());
669 total_score += weight * correlation.abs();
670 weight_sum += weight;
671 }
672
673 feature_scores[j] = if weight_sum > 0.0 {
674 total_score / weight_sum
675 } else {
676 0.0
677 };
678 }
679
680 Ok(feature_scores)
681}
682
/// Scores each feature by the sum of its variances computed separately
/// within every detected community (of size > 1), scaled by `weight`.
///
/// NOTE(review): higher within-community variance yields a HIGHER score
/// here, i.e. features that vary a lot inside communities are preferred —
/// confirm this direction is intended (one might expect the opposite for
/// community-discriminative features).
fn compute_community_feature_scores(
    x: &Array2<Float>,
    communities: &Array1<usize>,
    weight: Float,
) -> Result<Array1<Float>> {
    let (_n_samples, n_features) = x.dim();
    let mut feature_scores = Array1::zeros(n_features);

    // Community labels are assumed contiguous-ish; iterate 0..=max label.
    let max_community = communities.iter().max().cloned().unwrap_or(0);

    for j in 0..n_features {
        let feature = x.column(j);
        let mut community_variance = 0.0;

        for c in 0..=max_community {
            // Samples belonging to community `c`.
            let community_indices: Vec<usize> = communities
                .iter()
                .enumerate()
                .filter(|(_, &comm)| comm == c)
                .map(|(i, _)| i)
                .collect();

            // Variance is only meaningful for communities with >= 2 members.
            if community_indices.len() > 1 {
                let community_values: Vec<Float> =
                    community_indices.iter().map(|&i| feature[i]).collect();

                let mean = community_values.iter().sum::<Float>() / community_values.len() as Float;
                // Population variance (divides by len, not len - 1).
                let variance = community_values
                    .iter()
                    .map(|&val| (val - mean).powi(2))
                    .sum::<Float>()
                    / community_values.len() as Float;

                community_variance += variance;
            }
        }

        feature_scores[j] = weight * community_variance;
    }

    Ok(feature_scores)
}
726
727fn compute_structural_feature_scores(
728 x: &Array2<Float>,
729 adjacency: &ArrayView2<Float>,
730 weight: Float,
731) -> Result<Array1<Float>> {
732 let (_n_samples, n_features) = x.dim();
733 let mut feature_scores = Array1::zeros(n_features);
734
735 let clustering_coeffs = compute_clustering_coefficients(adjacency);
737
738 for j in 0..n_features {
739 let feature = x.column(j);
740 let correlation = compute_pearson_correlation(&feature, &clustering_coeffs.view());
741 feature_scores[j] = weight * correlation.abs();
742 }
743
744 Ok(feature_scores)
745}
746
747fn detect_communities_modularity(adjacency: &ArrayView2<Float>, min_size: usize) -> Array1<usize> {
750 let n = adjacency.nrows();
751 let mut communities = Array1::from_iter(0..n);
752
753 let total_edges: Float = adjacency.sum() / 2.0;
755
756 if total_edges == 0.0 {
757 return communities;
758 }
759
760 let mut improved = true;
762 while improved {
763 improved = false;
764
765 for i in 0..n {
766 let current_community = communities[i];
767 let mut best_community = current_community;
768 let mut best_modularity_gain = 0.0;
769
770 for j in 0..n {
772 if i != j {
773 let target_community = communities[j];
774 if target_community != current_community {
775 let modularity_gain = compute_modularity_gain(
776 i,
777 current_community,
778 target_community,
779 adjacency,
780 &communities,
781 total_edges,
782 );
783 if modularity_gain > best_modularity_gain {
784 best_modularity_gain = modularity_gain;
785 best_community = target_community;
786 }
787 }
788 }
789 }
790
791 if best_community != current_community {
792 communities[i] = best_community;
793 improved = true;
794 }
795 }
796 }
797
798 let mut community_counts = HashMap::new();
800 for &comm in communities.iter() {
801 *community_counts.entry(comm).or_insert(0) += 1;
802 }
803
804 let small_communities: Vec<usize> = community_counts
805 .iter()
806 .filter(|(_, &count)| count < min_size)
807 .map(|(&comm, _)| comm)
808 .collect();
809
810 for &small_comm in &small_communities {
812 let nodes_in_small: Vec<usize> = communities
813 .iter()
814 .enumerate()
815 .filter(|(_, &comm)| comm == small_comm)
816 .map(|(i, _)| i)
817 .collect();
818
819 if !nodes_in_small.is_empty() {
820 let target_comm = find_best_merge_community(&nodes_in_small, adjacency, &communities);
821 for &node in &nodes_in_small {
822 communities[node] = target_comm;
823 }
824 }
825 }
826
827 communities
828}
829
/// "Louvain" community detection.
///
/// NOTE(review): currently a placeholder that delegates to the greedy
/// modularity routine; a true Louvain implementation (with hierarchical
/// community aggregation) has not been written yet.
fn detect_communities_louvain(adjacency: &ArrayView2<Float>, min_size: usize) -> Array1<usize> {
    detect_communities_modularity(adjacency, min_size)
}
834
/// Approximate modularity gain from moving `node` out of `from_comm` into
/// `to_comm`.
///
/// Computed as the difference in edge weight connecting `node` to the two
/// communities, normalised by total edge weight, minus a degree-based
/// penalty term. Returns 0 for an edgeless graph.
fn compute_modularity_gain(
    node: usize,
    from_comm: usize,
    to_comm: usize,
    adjacency: &ArrayView2<Float>,
    communities: &Array1<usize>,
    total_edges: Float,
) -> Float {
    if total_edges == 0.0 {
        return 0.0;
    }

    let node_degree: Float = adjacency.row(node).sum();

    // Edge weight from `node` into each of the two communities. The node
    // itself is excluded from its current community's tally.
    let mut edges_to_from = 0.0;
    let mut edges_to_to = 0.0;

    for i in 0..adjacency.nrows() {
        if communities[i] == from_comm && i != node {
            edges_to_from += adjacency[[node, i]];
        }
        if communities[i] == to_comm {
            edges_to_to += adjacency[[node, i]];
        }
    }

    (edges_to_to - edges_to_from) / (2.0 * total_edges)
        - node_degree * node_degree / (4.0 * total_edges * total_edges)
}
867
868fn find_best_merge_community(
869 nodes: &[usize],
870 adjacency: &ArrayView2<Float>,
871 communities: &Array1<usize>,
872) -> usize {
873 let mut community_connections = HashMap::new();
874
875 for &node in nodes {
876 for i in 0..adjacency.nrows() {
877 if adjacency[[node, i]] > 0.0 && !nodes.contains(&i) {
878 let comm = communities[i];
879 *community_connections.entry(comm).or_insert(0.0) += adjacency[[node, i]];
880 }
881 }
882 }
883
884 community_connections
885 .into_iter()
886 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
887 .map(|(comm, _)| comm)
888 .unwrap_or(0)
889}
890
891fn compute_clustering_coefficients(adjacency: &ArrayView2<Float>) -> Array1<Float> {
892 let n = adjacency.nrows();
893 let mut clustering = Array1::zeros(n);
894
895 for i in 0..n {
896 let neighbors: Vec<usize> = (0..n)
897 .filter(|&j| i != j && adjacency[[i, j]] > 0.0)
898 .collect();
899
900 let degree = neighbors.len();
901 if degree < 2 {
902 clustering[i] = 0.0;
903 continue;
904 }
905
906 let mut triangles = 0;
907 for j in 0..neighbors.len() {
908 for k in (j + 1)..neighbors.len() {
909 if adjacency[[neighbors[j], neighbors[k]]] > 0.0 {
910 triangles += 1;
911 }
912 }
913 }
914
915 let possible_triangles = degree * (degree - 1) / 2;
916 clustering[i] = if possible_triangles > 0 {
917 triangles as Float / possible_triangles as Float
918 } else {
919 0.0
920 };
921 }
922
923 clustering
924}
925
926fn compute_pearson_correlation(x: &ArrayView1<Float>, y: &ArrayView1<Float>) -> Float {
927 let n = x.len();
928 if n != y.len() || n == 0 {
929 return 0.0;
930 }
931
932 let mean_x = x.sum() / n as Float;
933 let mean_y = y.sum() / n as Float;
934
935 let mut numerator = 0.0;
936 let mut sum_sq_x = 0.0;
937 let mut sum_sq_y = 0.0;
938
939 for i in 0..n {
940 let dx = x[i] - mean_x;
941 let dy = y[i] - mean_y;
942 numerator += dx * dy;
943 sum_sq_x += dx * dx;
944 sum_sq_y += dy * dy;
945 }
946
947 let denominator = (sum_sq_x * sum_sq_y).sqrt();
948 if denominator == 0.0 {
949 0.0
950 } else {
951 numerator / denominator
952 }
953}
954
955fn select_top_k_features(scores: &Array1<Float>, k: usize) -> Vec<usize> {
956 let mut indexed_scores: Vec<(usize, Float)> = scores
957 .iter()
958 .enumerate()
959 .map(|(i, &score)| (i, score))
960 .collect();
961
962 indexed_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
963
964 indexed_scores
965 .into_iter()
966 .take(k.min(scores.len()))
967 .map(|(i, _)| i)
968 .collect()
969}
970
971fn select_features_by_threshold(scores: &Array1<Float>, threshold: Float) -> Vec<usize> {
972 scores
973 .iter()
974 .enumerate()
975 .filter(|(_, &score)| score >= threshold)
976 .map(|(i, _)| i)
977 .collect()
978}
979
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    // Default construction enables all three scoring families.
    #[test]
    fn test_graph_feature_selector_creation() {
        let selector = GraphFeatureSelector::new();
        assert!(selector.include_centrality);
        assert!(selector.include_community);
        assert!(selector.include_structural);
    }

    // Builder setters are reflected in the built selector.
    #[test]
    fn test_graph_feature_selector_builder() {
        let selector = GraphFeatureSelector::builder()
            .include_centrality(true)
            .include_community(false)
            .centrality_threshold(0.5)
            .k(3)
            .build();

        assert!(selector.include_centrality);
        assert!(!selector.include_community);
        assert_eq!(selector.centrality_threshold, 0.5);
        assert_eq!(selector.k, Some(3));
    }

    // In a complete 3-node graph every node has normalised degree 1.
    #[test]
    fn test_degree_centrality() {
        let adjacency =
            Array2::from_shape_vec((3, 3), vec![0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0])
                .unwrap();

        let centrality = compute_degree_centrality(&adjacency.view());

        assert_eq!(centrality.len(), 3);
        for &c in centrality.iter() {
            assert!((c - 1.0).abs() < 1e-6);
        }
    }

    // End-to-end: fit with k=2 then transform keeps exactly 2 columns.
    #[test]
    fn test_fit_transform_basic() {
        let adjacency = Array2::from_shape_vec(
            (4, 4),
            vec![
                0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
            ],
        )
        .unwrap();

        let features = Array2::from_shape_vec(
            (4, 3),
            vec![1.0, 2.0, 3.0, 2.0, 3.0, 1.0, 3.0, 1.0, 4.0, 1.0, 4.0, 2.0],
        )
        .unwrap();

        let target = Array1::from_vec(vec![0.0, 1.0, 1.0, 0.0]);

        let selector = GraphFeatureSelector::builder()
            .k(2)
            .with_adjacency(adjacency)
            .build();

        let trained = selector.fit(&features, &target).unwrap();
        let transformed = trained.transform(&features).unwrap();

        assert_eq!(transformed.ncols(), 2);
        assert_eq!(transformed.nrows(), 4);
    }

    // The support mask has one entry per input feature and exactly k trues.
    #[test]
    fn test_get_support() {
        let adjacency =
            Array2::from_shape_vec((3, 3), vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0])
                .unwrap();

        let features = Array2::from_shape_vec(
            (3, 4),
            vec![1.0, 2.0, 3.0, 4.0, 2.0, 3.0, 1.0, 5.0, 3.0, 1.0, 4.0, 2.0],
        )
        .unwrap();

        let target = Array1::from_vec(vec![0.0, 1.0, 1.0]);

        let selector = GraphFeatureSelector::builder()
            .k(2)
            .with_adjacency(adjacency)
            .build();

        let trained = selector.fit(&features, &target).unwrap();
        let support = trained.get_support().unwrap();

        assert_eq!(support.len(), 4);
        assert_eq!(support.iter().filter(|&&x| x).count(), 2);
    }

    // Node 1 has 3 neighbours with 2 of 3 possible pairs connected: 2/3.
    #[test]
    fn test_clustering_coefficients() {
        let adjacency = Array2::from_shape_vec(
            (4, 4),
            vec![
                0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0,
            ],
        )
        .unwrap();

        let clustering = compute_clustering_coefficients(&adjacency.view());

        assert_eq!(clustering.len(), 4);
        assert!((clustering[1] - 2.0 / 3.0).abs() < 1e-6);
    }

    // Node 0 receives links from both other nodes, so it ranks highest.
    #[test]
    fn test_pagerank_centrality() {
        let adjacency =
            Array2::from_shape_vec((3, 3), vec![0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0])
                .unwrap();

        let pagerank = compute_pagerank_centrality(&adjacency.view(), 0.85, 100, 1e-6);

        assert_eq!(pagerank.len(), 3);
        assert!(pagerank[0] > pagerank[1]);
        assert!(pagerank[0] > pagerank[2]);
    }
}