sklears_feature_selection/domain_specific/graph_features.rs

//! Graph-based feature selection for network and graph-structured data.
//!
//! This module provides specialized feature selection capabilities for graph-structured data,
//! including social networks, biological networks, knowledge graphs, and other network data.
//! It implements various graph-theoretic measures to identify important features based on
//! network topology and structural properties.
//!
//! # Features
//!
//! - **Centrality-based selection**: Uses degree, betweenness, closeness, and PageRank centrality measures
//! - **Community detection**: Leverages community structure for feature importance
//! - **Structural properties**: Considers clustering coefficients, path lengths, and connectivity
//! - **Multi-scale analysis**: Analyzes features at different graph scales
//! - **Dynamic graphs**: Supports temporal graph feature selection
//!
//! # Examples
//!
//! ## Basic Graph Feature Selection
//!
//! ```rust,ignore
//! use sklears_feature_selection::domain_specific::graph_features::GraphFeatureSelector;
//! use scirs2_core::ndarray::{Array2, Array1};
//!
//! // Sample adjacency matrix for a small graph
//! let adjacency = Array2::from_shape_vec((5, 5), vec![
//!     0.0, 1.0, 1.0, 0.0, 0.0,
//!     1.0, 0.0, 1.0, 1.0, 0.0,
//!     1.0, 1.0, 0.0, 0.0, 1.0,
//!     0.0, 1.0, 0.0, 0.0, 1.0,
//!     0.0, 0.0, 1.0, 1.0, 0.0,
//! ]).unwrap();
//!
//! // Node features (each row is a node, columns are features)
//! let features = Array2::from_shape_vec((5, 4), vec![
//!     1.0, 2.0, 3.0, 4.0,
//!     2.0, 3.0, 1.0, 5.0,
//!     3.0, 1.0, 4.0, 2.0,
//!     1.0, 4.0, 2.0, 3.0,
//!     4.0, 2.0, 5.0, 1.0,
//! ]).unwrap();
//!
//! // Target values for supervised selection
//! let target = Array1::from_vec(vec![0.0, 1.0, 1.0, 0.0, 1.0]);
//!
//! // The adjacency matrix is supplied through the builder; fit and transform
//! // then take only the feature matrix (and target).
//! let selector = GraphFeatureSelector::builder()
//!     .include_centrality(true)
//!     .include_community(true)
//!     .include_structural(true)
//!     .centrality_threshold(0.3)
//!     .k(2)
//!     .with_adjacency(adjacency)
//!     .build();
//!
//! let trained = selector.fit(&features, &target)?;
//! let selected_features = trained.transform(&features)?;
//! ```
//!
//! ## Centrality-based Feature Selection
//!
//! ```rust,ignore
//! use sklears_feature_selection::domain_specific::graph_features::GraphFeatureSelector;
//!
//! let selector = GraphFeatureSelector::builder()
//!     .include_centrality(true)
//!     .include_community(false)
//!     .include_structural(false)
//!     .centrality_types(vec!["degree", "pagerank", "betweenness"])
//!     .centrality_threshold(0.5)
//!     .build();
//! ```
//!
//! ## Community-aware Feature Selection
//!
//! ```rust,ignore
//! let selector = GraphFeatureSelector::builder()
//!     .include_community(true)
//!     .community_method("modularity")
//!     .min_community_size(3)
//!     .community_weight(0.7)
//!     .build();
//! ```

use crate::base::SelectorMixin;
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use sklears_core::error::{Result as SklResult, SklearsError};
use sklears_core::traits::{Estimator, Fit, Transform};
use std::collections::HashMap;
use std::marker::PhantomData;

type Result<T> = SklResult<T>;
type Float = f64;

#[derive(Debug, Clone)]
pub struct Untrained;

#[derive(Debug, Clone)]
pub struct Trained {
    selected_features: Vec<usize>,
    feature_scores: Array1<Float>,
    centrality_scores: Option<HashMap<String, Array1<Float>>>,
    community_assignments: Option<Array1<usize>>,
    structural_scores: Option<Array1<Float>>,
    n_features: usize,
}

/// Graph-based feature selector for network-structured data.
///
/// This selector uses graph topology and structure to identify important features
/// by analyzing centrality measures, community structure, and other graph properties.
/// It's particularly useful for social networks, biological networks, citation networks,
/// and other graph-structured data where network topology provides valuable information
/// about feature importance.
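///
/// # Example
///
/// An illustrative sketch of the intended workflow (not compiled as a doc-test);
/// the adjacency matrix is supplied through the builder before calling `fit`:
///
/// ```rust,ignore
/// use sklears_feature_selection::domain_specific::graph_features::GraphFeatureSelector;
/// use scirs2_core::ndarray::{Array1, Array2};
/// use sklears_core::traits::{Fit, Transform};
///
/// // Placeholder data: a 4-node path graph with 3 features per node.
/// let adjacency = Array2::from_shape_vec((4, 4), vec![
///     0.0, 1.0, 0.0, 0.0,
///     1.0, 0.0, 1.0, 0.0,
///     0.0, 1.0, 0.0, 1.0,
///     0.0, 0.0, 1.0, 0.0,
/// ]).unwrap();
/// let features = Array2::<f64>::zeros((4, 3));
/// let target = Array1::<f64>::zeros(4);
///
/// let selector = GraphFeatureSelector::builder()
///     .k(2)
///     .with_adjacency(adjacency)
///     .build();
///
/// let trained = selector.fit(&features, &target)?;
/// let reduced = trained.transform(&features)?;
/// assert_eq!(reduced.ncols(), 2);
/// ```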
#[derive(Debug, Clone)]
pub struct GraphFeatureSelector<State = Untrained> {
    include_centrality: bool,
    include_community: bool,
    include_structural: bool,
    centrality_threshold: Float,
    centrality_types: Vec<String>,
    community_method: String,
    min_community_size: usize,
    community_weight: Float,
    structural_weight: Float,
    k: Option<usize>,
    damping_factor: Float,
    max_iterations: usize,
    tolerance: Float,
    adjacency: Option<Array2<Float>>,
    state: PhantomData<State>,
    trained_state: Option<Trained>,
}

impl Default for GraphFeatureSelector<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl GraphFeatureSelector<Untrained> {
    /// Creates a new GraphFeatureSelector with default parameters.
    pub fn new() -> Self {
        Self {
            include_centrality: true,
            include_community: true,
            include_structural: true,
            centrality_threshold: 0.1,
            centrality_types: vec!["degree".to_string(), "pagerank".to_string()],
            community_method: "modularity".to_string(),
            min_community_size: 2,
            community_weight: 0.5,
            structural_weight: 0.3,
            k: None,
            damping_factor: 0.85,
            max_iterations: 100,
            tolerance: 1e-6,
            adjacency: None,
            state: PhantomData,
            trained_state: None,
        }
    }

    /// Creates a builder for configuring the GraphFeatureSelector.
    pub fn builder() -> GraphFeatureSelectorBuilder {
        GraphFeatureSelectorBuilder::new()
    }
}

/// Builder for GraphFeatureSelector configuration.
#[derive(Debug)]
pub struct GraphFeatureSelectorBuilder {
    include_centrality: bool,
    include_community: bool,
    include_structural: bool,
    centrality_threshold: Float,
    centrality_types: Vec<String>,
    community_method: String,
    min_community_size: usize,
    community_weight: Float,
    structural_weight: Float,
    k: Option<usize>,
    damping_factor: Float,
    max_iterations: usize,
    tolerance: Float,
    adjacency: Option<Array2<Float>>,
}

impl Default for GraphFeatureSelectorBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl GraphFeatureSelectorBuilder {
    pub fn new() -> Self {
        Self {
            include_centrality: true,
            include_community: true,
            include_structural: true,
            centrality_threshold: 0.1,
            centrality_types: vec!["degree".to_string(), "pagerank".to_string()],
            community_method: "modularity".to_string(),
            min_community_size: 2,
            community_weight: 0.5,
            structural_weight: 0.3,
            k: None,
            damping_factor: 0.85,
            max_iterations: 100,
            tolerance: 1e-6,
            adjacency: None,
        }
    }

    /// Whether to include centrality-based features.
    pub fn include_centrality(mut self, include: bool) -> Self {
        self.include_centrality = include;
        self
    }

    /// Whether to include community-based features.
    pub fn include_community(mut self, include: bool) -> Self {
        self.include_community = include;
        self
    }

    /// Whether to include structural property features.
    pub fn include_structural(mut self, include: bool) -> Self {
        self.include_structural = include;
        self
    }

    /// Minimum centrality score threshold for feature selection.
    pub fn centrality_threshold(mut self, threshold: Float) -> Self {
        self.centrality_threshold = threshold;
        self
    }

    /// Types of centrality measures to compute.
    pub fn centrality_types(mut self, types: Vec<&str>) -> Self {
        self.centrality_types = types.iter().map(|s| s.to_string()).collect();
        self
    }

    /// Community detection method to use.
    pub fn community_method(mut self, method: &str) -> Self {
        self.community_method = method.to_string();
        self
    }

    /// Minimum size for communities to be considered.
    pub fn min_community_size(mut self, size: usize) -> Self {
        self.min_community_size = size;
        self
    }

    /// Weight for community-based features in scoring.
    pub fn community_weight(mut self, weight: Float) -> Self {
        self.community_weight = weight;
        self
    }

    /// Weight for structural property features in scoring.
    pub fn structural_weight(mut self, weight: Float) -> Self {
        self.structural_weight = weight;
        self
    }

    /// Number of top features to select.
    pub fn k(mut self, k: usize) -> Self {
        self.k = Some(k);
        self
    }

    /// Damping factor for PageRank computation.
    pub fn damping_factor(mut self, factor: Float) -> Self {
        self.damping_factor = factor;
        self
    }

    /// Maximum iterations for iterative algorithms.
    pub fn max_iterations(mut self, iterations: usize) -> Self {
        self.max_iterations = iterations;
        self
    }

    /// Convergence tolerance for iterative algorithms.
    pub fn tolerance(mut self, tol: Float) -> Self {
        self.tolerance = tol;
        self
    }

    /// Sets the adjacency matrix for graph feature selection.
    pub fn with_adjacency(mut self, adjacency: Array2<Float>) -> Self {
        self.adjacency = Some(adjacency);
        self
    }

    /// Builds the GraphFeatureSelector.
    pub fn build(self) -> GraphFeatureSelector<Untrained> {
        GraphFeatureSelector {
            include_centrality: self.include_centrality,
            include_community: self.include_community,
            include_structural: self.include_structural,
            centrality_threshold: self.centrality_threshold,
            centrality_types: self.centrality_types,
            community_method: self.community_method,
            min_community_size: self.min_community_size,
            community_weight: self.community_weight,
            structural_weight: self.structural_weight,
            k: self.k,
            damping_factor: self.damping_factor,
            max_iterations: self.max_iterations,
            tolerance: self.tolerance,
            adjacency: self.adjacency,
            state: PhantomData,
            trained_state: None,
        }
    }
}

impl Estimator for GraphFeatureSelector<Untrained> {
    type Config = ();
    type Error = sklears_core::error::SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Estimator for GraphFeatureSelector<Trained> {
    type Config = ();
    type Error = sklears_core::error::SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, Array1<Float>> for GraphFeatureSelector<Untrained> {
    type Fitted = GraphFeatureSelector<Trained>;

    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
        let (n_samples, n_features) = x.dim();

        if y.len() != n_samples {
            return Err(SklearsError::InvalidInput(
                "Number of samples in X and y must match".to_string(),
            ));
        }

        let adjacency = self.adjacency.ok_or_else(|| {
            SklearsError::InvalidInput(
                "Adjacency matrix is required for graph feature selection".to_string(),
            )
        })?;

        if adjacency.dim() != (n_samples, n_samples) {
            return Err(SklearsError::InvalidInput(
                "Adjacency matrix must be square with same number of nodes as samples".to_string(),
            ));
        }

        let mut centrality_scores = None;
        let mut community_assignments = None;
        let mut structural_scores = None;
        let mut combined_scores = Array1::zeros(n_features);

        // Compute centrality-based scores
        if self.include_centrality {
            let mut centrality_map = HashMap::new();

            for centrality_type in &self.centrality_types {
                let scores = match centrality_type.as_str() {
                    "degree" => compute_degree_centrality(&adjacency.view()),
                    "pagerank" => compute_pagerank_centrality(
                        &adjacency.view(),
                        self.damping_factor,
                        self.max_iterations,
                        self.tolerance,
                    ),
                    "betweenness" => compute_betweenness_centrality(&adjacency.view()),
                    "closeness" => compute_closeness_centrality(&adjacency.view()),
                    _ => Array1::zeros(n_samples),
                };
                centrality_map.insert(centrality_type.clone(), scores);
            }

            // Combine centrality scores with feature correlations
            let feature_centrality_scores = compute_feature_centrality_scores(x, &centrality_map)?;
            combined_scores = &combined_scores + &feature_centrality_scores;
            centrality_scores = Some(centrality_map);
        }

        // Compute community-based scores
        if self.include_community {
            let communities = match self.community_method.as_str() {
                "modularity" => {
                    detect_communities_modularity(&adjacency.view(), self.min_community_size)
                }
                "louvain" => detect_communities_louvain(&adjacency.view(), self.min_community_size),
                _ => Array1::zeros(n_samples),
            };

            let community_feature_scores =
                compute_community_feature_scores(x, &communities, self.community_weight)?;
            combined_scores = &combined_scores + &community_feature_scores;
            community_assignments = Some(communities);
        }

        // Compute structural property scores
        if self.include_structural {
            let struct_scores =
                compute_structural_feature_scores(x, &adjacency.view(), self.structural_weight)?;
            combined_scores = &combined_scores + &struct_scores;
            structural_scores = Some(struct_scores);
        }

        // Select features based on combined scores
        let selected_features = if let Some(k) = self.k {
            select_top_k_features(&combined_scores, k)
        } else {
            select_features_by_threshold(&combined_scores, self.centrality_threshold)
        };

        let trained_state = Trained {
            selected_features,
            feature_scores: combined_scores,
            centrality_scores,
            community_assignments,
            structural_scores,
            n_features,
        };

        Ok(GraphFeatureSelector {
            include_centrality: self.include_centrality,
            include_community: self.include_community,
            include_structural: self.include_structural,
            centrality_threshold: self.centrality_threshold,
            centrality_types: self.centrality_types,
            community_method: self.community_method,
            min_community_size: self.min_community_size,
            community_weight: self.community_weight,
            structural_weight: self.structural_weight,
            k: self.k,
            damping_factor: self.damping_factor,
            max_iterations: self.max_iterations,
            tolerance: self.tolerance,
            adjacency: None,
            state: PhantomData,
            trained_state: Some(trained_state),
        })
    }
}

impl Transform<Array2<Float>> for GraphFeatureSelector<Trained> {
    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
        let trained = self.trained_state.as_ref().ok_or_else(|| {
            SklearsError::InvalidState("Selector must be fitted before transforming".to_string())
        })?;

        let (_n_samples, n_features) = x.dim();

        if n_features != trained.n_features {
            return Err(SklearsError::InvalidInput(format!(
                "Expected {} features, got {}",
                trained.n_features, n_features
            )));
        }

        if trained.selected_features.is_empty() {
            return Err(SklearsError::InvalidState(
                "No features were selected".to_string(),
            ));
        }

        let selected_data = x.select(Axis(1), &trained.selected_features);
        Ok(selected_data)
    }
}

impl SelectorMixin for GraphFeatureSelector<Trained> {
    fn get_support(&self) -> Result<Array1<bool>> {
        let trained = self.trained_state.as_ref().ok_or_else(|| {
            SklearsError::InvalidState("Selector must be fitted before getting support".to_string())
        })?;

        let mut support = Array1::from_elem(trained.n_features, false);
        for &idx in &trained.selected_features {
            support[idx] = true;
        }
        Ok(support)
    }

    fn transform_features(&self, indices: &[usize]) -> SklResult<Vec<usize>> {
        let trained = self.trained_state.as_ref().ok_or_else(|| {
            SklearsError::InvalidState(
                "Selector must be fitted before transforming features".to_string(),
            )
        })?;

        let selected: Vec<usize> = indices
            .iter()
            .filter(|&&idx| trained.selected_features.contains(&idx))
            .cloned()
            .collect();
        Ok(selected)
    }
}

// Centrality computation functions

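/// Degree centrality: each node's degree (sum of its adjacency row) divided by
/// `n - 1`, i.e. the fraction of the other nodes it is directly connected to.
/// Returns all zeros for graphs with fewer than two nodes.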
fn compute_degree_centrality(adjacency: &ArrayView2<Float>) -> Array1<Float> {
    let n = adjacency.nrows();
    let mut centrality = Array1::zeros(n);

    // Guard against division by zero for graphs with fewer than two nodes.
    if n < 2 {
        return centrality;
    }

    for i in 0..n {
        let degree: Float = adjacency.row(i).sum();
        centrality[i] = degree / (n - 1) as Float;
    }

    centrality
}

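/// PageRank centrality computed by power iteration: each round, a node's rank is
/// distributed to its neighbors in proportion to edge weight and scaled by
/// `damping`, then a uniform `(1 - damping) / n` teleport term is added. Iteration
/// stops after `max_iter` rounds or once the L1 change between successive rank
/// vectors falls below `tolerance`.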
fn compute_pagerank_centrality(
    adjacency: &ArrayView2<Float>,
    damping: Float,
    max_iter: usize,
    tolerance: Float,
) -> Array1<Float> {
    let n = adjacency.nrows();
    let mut pagerank = Array1::from_elem(n, 1.0 / n as Float);
    let mut new_pagerank = Array1::zeros(n);

    for _ in 0..max_iter {
        new_pagerank.fill(0.0);

        for i in 0..n {
            let out_degree: Float = adjacency.row(i).sum();
            if out_degree > 0.0 {
                for j in 0..n {
                    if adjacency[[i, j]] > 0.0 {
                        // Distribute rank along outgoing edges in proportion to edge
                        // weight (an even split for a binary adjacency matrix).
                        new_pagerank[j] += damping * adjacency[[i, j]] * pagerank[i] / out_degree;
                    }
                }
            }
        }

        // Add random jump probability
        for i in 0..n {
            new_pagerank[i] += (1.0 - damping) / n as Float;
        }

        // Check convergence; update the estimate first so the converged vector
        // (not the previous iterate) is returned.
        let diff: Float = (&new_pagerank - &pagerank).mapv(|x| x.abs()).sum();
        pagerank.assign(&new_pagerank);
        if diff < tolerance {
            break;
        }
    }

    pagerank
}

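/// Approximate betweenness centrality: instead of full shortest-path counting,
/// for every ordered pair `(j, k)` the function only checks whether the two-hop
/// path through node `i` beats a direct edge, then normalizes by `(n - 1)(n - 2)`.
/// Returns all zeros for graphs with fewer than three nodes.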
fn compute_betweenness_centrality(adjacency: &ArrayView2<Float>) -> Array1<Float> {
    let n = adjacency.nrows();
    let mut betweenness = Array1::zeros(n);

    // The normalization below divides by (n - 1)(n - 2), so graphs with fewer
    // than three nodes have no meaningful betweenness.
    if n < 3 {
        return betweenness;
    }

    // Simplified betweenness centrality (approximate)
    for i in 0..n {
        let mut local_betweenness = 0.0;
        for j in 0..n {
            if i != j {
                for k in 0..n {
                    if k != i && k != j {
                        // Check if node i is on shortest path from j to k
                        let direct_jk = if adjacency[[j, k]] > 0.0 {
                            1.0
                        } else {
                            f64::INFINITY
                        };
                        let via_i = if adjacency[[j, i]] > 0.0 && adjacency[[i, k]] > 0.0 {
                            2.0
                        } else {
                            f64::INFINITY
                        };

                        if via_i < direct_jk {
                            local_betweenness += 1.0;
                        }
                    }
                }
            }
        }
        betweenness[i] = local_betweenness / ((n - 1) * (n - 2)) as Float;
    }

    betweenness
}

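/// Approximate closeness centrality with distances truncated at two hops:
/// directly connected nodes count as distance 1, nodes reachable through a
/// shared neighbor as distance 2, and farther nodes are ignored. The score is
/// the number of reachable nodes divided by the total truncated distance.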
fn compute_closeness_centrality(adjacency: &ArrayView2<Float>) -> Array1<Float> {
    let n = adjacency.nrows();
    let mut closeness = Array1::zeros(n);

    for i in 0..n {
        let mut total_distance = 0.0;
        let mut reachable_nodes = 0;

        for j in 0..n {
            if i != j {
                // Simplified distance calculation (direct connection = 1, else 2 if connected through neighbors)
                let distance = if adjacency[[i, j]] > 0.0 {
                    1.0
                } else {
                    // Check for 2-hop connection
                    let mut found_path = false;
                    for k in 0..n {
                        if adjacency[[i, k]] > 0.0 && adjacency[[k, j]] > 0.0 {
                            found_path = true;
                            break;
                        }
                    }
                    if found_path {
                        2.0
                    } else {
                        f64::INFINITY
                    }
                };

                if distance.is_finite() {
                    total_distance += distance;
                    reachable_nodes += 1;
                }
            }
        }

        if reachable_nodes > 0 {
            closeness[i] = reachable_nodes as Float / total_distance;
        }
    }

    closeness
}

// Feature scoring functions

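/// Scores each feature column by the weighted average of its absolute Pearson
/// correlation with every computed centrality vector; PageRank, betweenness,
/// and closeness are weighted slightly above degree centrality.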
fn compute_feature_centrality_scores(
    x: &Array2<Float>,
    centrality_scores: &HashMap<String, Array1<Float>>,
) -> Result<Array1<Float>> {
    let (_n_samples, n_features) = x.dim();
    let mut feature_scores = Array1::zeros(n_features);

    for j in 0..n_features {
        let feature = x.column(j);
        let mut total_score = 0.0;
        let mut weight_sum = 0.0;

        for (centrality_type, centrality) in centrality_scores {
            let weight = match centrality_type.as_str() {
                "degree" => 1.0,
                "pagerank" => 1.5,
                "betweenness" => 1.2,
                "closeness" => 1.1,
                _ => 1.0,
            };

            let correlation = compute_pearson_correlation(&feature, &centrality.view());
            total_score += weight * correlation.abs();
            weight_sum += weight;
        }

        feature_scores[j] = if weight_sum > 0.0 {
            total_score / weight_sum
        } else {
            0.0
        };
    }

    Ok(feature_scores)
}

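/// Scores each feature column by summing its variance within every detected
/// community containing at least two nodes, scaled by `weight`.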
fn compute_community_feature_scores(
    x: &Array2<Float>,
    communities: &Array1<usize>,
    weight: Float,
) -> Result<Array1<Float>> {
    let (_n_samples, n_features) = x.dim();
    let mut feature_scores = Array1::zeros(n_features);

    // Find unique communities
    let max_community = communities.iter().max().cloned().unwrap_or(0);

    for j in 0..n_features {
        let feature = x.column(j);
        let mut community_variance = 0.0;

        for c in 0..=max_community {
            let community_indices: Vec<usize> = communities
                .iter()
                .enumerate()
                .filter(|(_, &comm)| comm == c)
                .map(|(i, _)| i)
                .collect();

            if community_indices.len() > 1 {
                let community_values: Vec<Float> =
                    community_indices.iter().map(|&i| feature[i]).collect();

                let mean = community_values.iter().sum::<Float>() / community_values.len() as Float;
                let variance = community_values
                    .iter()
                    .map(|&val| (val - mean).powi(2))
                    .sum::<Float>()
                    / community_values.len() as Float;

                community_variance += variance;
            }
        }

        feature_scores[j] = weight * community_variance;
    }

    Ok(feature_scores)
}

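/// Scores each feature column by its absolute Pearson correlation with the
/// per-node clustering coefficients, scaled by `weight`.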
fn compute_structural_feature_scores(
    x: &Array2<Float>,
    adjacency: &ArrayView2<Float>,
    weight: Float,
) -> Result<Array1<Float>> {
    let (_n_samples, n_features) = x.dim();
    let mut feature_scores = Array1::zeros(n_features);

    // Compute clustering coefficients
    let clustering_coeffs = compute_clustering_coefficients(adjacency);

    for j in 0..n_features {
        let feature = x.column(j);
        let correlation = compute_pearson_correlation(&feature, &clustering_coeffs.view());
        feature_scores[j] = weight * correlation.abs();
    }

    Ok(feature_scores)
}

// Community detection functions

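/// Greedy modularity-based community detection: every node starts in its own
/// community, nodes are repeatedly moved to the community offering the largest
/// (approximate) modularity gain until no move helps, and communities smaller
/// than `min_size` are merged into the neighboring community they are most
/// strongly connected to.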
fn detect_communities_modularity(adjacency: &ArrayView2<Float>, min_size: usize) -> Array1<usize> {
    let n = adjacency.nrows();
    let mut communities = Array1::from_iter(0..n);

    // Simple modularity-based community detection (simplified implementation)
    let total_edges: Float = adjacency.sum() / 2.0;

    if total_edges == 0.0 {
        return communities;
    }

    // Greedy modularity optimization
    let mut improved = true;
    while improved {
        improved = false;

        for i in 0..n {
            let current_community = communities[i];
            let mut best_community = current_community;
            let mut best_modularity_gain = 0.0;

            // Try moving node i to different communities
            for j in 0..n {
                if i != j {
                    let target_community = communities[j];
                    if target_community != current_community {
                        let modularity_gain = compute_modularity_gain(
                            i,
                            current_community,
                            target_community,
                            adjacency,
                            &communities,
                            total_edges,
                        );
                        if modularity_gain > best_modularity_gain {
                            best_modularity_gain = modularity_gain;
                            best_community = target_community;
                        }
                    }
                }
            }

            if best_community != current_community {
                communities[i] = best_community;
                improved = true;
            }
        }
    }

    // Ensure minimum community size
    let mut community_counts = HashMap::new();
    for &comm in communities.iter() {
        *community_counts.entry(comm).or_insert(0) += 1;
    }

    let small_communities: Vec<usize> = community_counts
        .iter()
        .filter(|(_, &count)| count < min_size)
        .map(|(&comm, _)| comm)
        .collect();

    // Merge small communities with largest neighbor
    for &small_comm in &small_communities {
        let nodes_in_small: Vec<usize> = communities
            .iter()
            .enumerate()
            .filter(|(_, &comm)| comm == small_comm)
            .map(|(i, _)| i)
            .collect();

        if !nodes_in_small.is_empty() {
            let target_comm = find_best_merge_community(&nodes_in_small, adjacency, &communities);
            for &node in &nodes_in_small {
                communities[node] = target_comm;
            }
        }
    }

    communities
}

fn detect_communities_louvain(adjacency: &ArrayView2<Float>, min_size: usize) -> Array1<usize> {
    // Simplified Louvain algorithm (placeholder implementation)
    detect_communities_modularity(adjacency, min_size)
}

// Utility functions

fn compute_modularity_gain(
    node: usize,
    from_comm: usize,
    to_comm: usize,
    adjacency: &ArrayView2<Float>,
    communities: &Array1<usize>,
    total_edges: Float,
) -> Float {
    if total_edges == 0.0 {
        return 0.0;
    }

    // Simplified modularity gain calculation
    let node_degree: Float = adjacency.row(node).sum();

    let mut edges_to_from = 0.0;
    let mut edges_to_to = 0.0;

    for i in 0..adjacency.nrows() {
        if communities[i] == from_comm && i != node {
            edges_to_from += adjacency[[node, i]];
        }
        if communities[i] == to_comm {
            edges_to_to += adjacency[[node, i]];
        }
    }

    (edges_to_to - edges_to_from) / (2.0 * total_edges)
        - node_degree * node_degree / (4.0 * total_edges * total_edges)
}

fn find_best_merge_community(
    nodes: &[usize],
    adjacency: &ArrayView2<Float>,
    communities: &Array1<usize>,
) -> usize {
    let mut community_connections = HashMap::new();

    for &node in nodes {
        for i in 0..adjacency.nrows() {
            if adjacency[[node, i]] > 0.0 && !nodes.contains(&i) {
                let comm = communities[i];
                *community_connections.entry(comm).or_insert(0.0) += adjacency[[node, i]];
            }
        }
    }

    community_connections
        .into_iter()
        .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(comm, _)| comm)
        .unwrap_or(0)
}

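/// Local clustering coefficient per node: the number of edges among the node's
/// neighbors divided by the `d * (d - 1) / 2` pairs that could exist, where `d`
/// is the node's degree; nodes with fewer than two neighbors score 0.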
fn compute_clustering_coefficients(adjacency: &ArrayView2<Float>) -> Array1<Float> {
    let n = adjacency.nrows();
    let mut clustering = Array1::zeros(n);

    for i in 0..n {
        let neighbors: Vec<usize> = (0..n)
            .filter(|&j| i != j && adjacency[[i, j]] > 0.0)
            .collect();

        let degree = neighbors.len();
        if degree < 2 {
            clustering[i] = 0.0;
            continue;
        }

        let mut triangles = 0;
        for j in 0..neighbors.len() {
            for k in (j + 1)..neighbors.len() {
                if adjacency[[neighbors[j], neighbors[k]]] > 0.0 {
                    triangles += 1;
                }
            }
        }

        let possible_triangles = degree * (degree - 1) / 2;
        clustering[i] = if possible_triangles > 0 {
            triangles as Float / possible_triangles as Float
        } else {
            0.0
        };
    }

    clustering
}

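/// Pearson correlation coefficient between two equal-length vectors; returns
/// 0.0 for empty or mismatched inputs and when either vector has zero variance.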
fn compute_pearson_correlation(x: &ArrayView1<Float>, y: &ArrayView1<Float>) -> Float {
    let n = x.len();
    if n != y.len() || n == 0 {
        return 0.0;
    }

    let mean_x = x.sum() / n as Float;
    let mean_y = y.sum() / n as Float;

    let mut numerator = 0.0;
    let mut sum_sq_x = 0.0;
    let mut sum_sq_y = 0.0;

    for i in 0..n {
        let dx = x[i] - mean_x;
        let dy = y[i] - mean_y;
        numerator += dx * dy;
        sum_sq_x += dx * dx;
        sum_sq_y += dy * dy;
    }

    let denominator = (sum_sq_x * sum_sq_y).sqrt();
    if denominator == 0.0 {
        0.0
    } else {
        numerator / denominator
    }
}

fn select_top_k_features(scores: &Array1<Float>, k: usize) -> Vec<usize> {
    let mut indexed_scores: Vec<(usize, Float)> = scores
        .iter()
        .enumerate()
        .map(|(i, &score)| (i, score))
        .collect();

    indexed_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

    indexed_scores
        .into_iter()
        .take(k.min(scores.len()))
        .map(|(i, _)| i)
        .collect()
}

fn select_features_by_threshold(scores: &Array1<Float>, threshold: Float) -> Vec<usize> {
    scores
        .iter()
        .enumerate()
        .filter(|(_, &score)| score >= threshold)
        .map(|(i, _)| i)
        .collect()
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_graph_feature_selector_creation() {
        let selector = GraphFeatureSelector::new();
        assert!(selector.include_centrality);
        assert!(selector.include_community);
        assert!(selector.include_structural);
    }

    #[test]
    fn test_graph_feature_selector_builder() {
        let selector = GraphFeatureSelector::builder()
            .include_centrality(true)
            .include_community(false)
            .centrality_threshold(0.5)
            .k(3)
            .build();

        assert!(selector.include_centrality);
        assert!(!selector.include_community);
        assert_eq!(selector.centrality_threshold, 0.5);
        assert_eq!(selector.k, Some(3));
    }

    #[test]
    fn test_degree_centrality() {
        let adjacency =
            Array2::from_shape_vec((3, 3), vec![0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0])
                .unwrap();

        let centrality = compute_degree_centrality(&adjacency.view());

        // All nodes have degree 2, so centrality should be 2/2 = 1.0 for all
        assert_eq!(centrality.len(), 3);
        for &c in centrality.iter() {
            assert!((c - 1.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_fit_transform_basic() {
        let adjacency = Array2::from_shape_vec(
            (4, 4),
            vec![
                0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
            ],
        )
        .unwrap();

        let features = Array2::from_shape_vec(
            (4, 3),
            vec![1.0, 2.0, 3.0, 2.0, 3.0, 1.0, 3.0, 1.0, 4.0, 1.0, 4.0, 2.0],
        )
        .unwrap();

        let target = Array1::from_vec(vec![0.0, 1.0, 1.0, 0.0]);

        let selector = GraphFeatureSelector::builder()
            .k(2)
            .with_adjacency(adjacency)
            .build();

        let trained = selector.fit(&features, &target).unwrap();
        let transformed = trained.transform(&features).unwrap();

        assert_eq!(transformed.ncols(), 2);
        assert_eq!(transformed.nrows(), 4);
    }

    #[test]
    fn test_get_support() {
        let adjacency =
            Array2::from_shape_vec((3, 3), vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0])
                .unwrap();

        let features = Array2::from_shape_vec(
            (3, 4),
            vec![1.0, 2.0, 3.0, 4.0, 2.0, 3.0, 1.0, 5.0, 3.0, 1.0, 4.0, 2.0],
        )
        .unwrap();

        let target = Array1::from_vec(vec![0.0, 1.0, 1.0]);

        let selector = GraphFeatureSelector::builder()
            .k(2)
            .with_adjacency(adjacency)
            .build();

        let trained = selector.fit(&features, &target).unwrap();
        let support = trained.get_support().unwrap();

        assert_eq!(support.len(), 4);
        assert_eq!(support.iter().filter(|&&x| x).count(), 2);
    }

    #[test]
    fn test_clustering_coefficients() {
        let adjacency = Array2::from_shape_vec(
            (4, 4),
            vec![
                0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0,
            ],
        )
        .unwrap();

        let clustering = compute_clustering_coefficients(&adjacency.view());

        assert_eq!(clustering.len(), 4);
        // Node 1 has 3 neighbors (0, 2, 3) with 2 connections between them
        // So clustering coefficient should be 2/3
        assert!((clustering[1] - 2.0 / 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_pagerank_centrality() {
        let adjacency =
            Array2::from_shape_vec((3, 3), vec![0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0])
                .unwrap();

        let pagerank = compute_pagerank_centrality(&adjacency.view(), 0.85, 100, 1e-6);

        assert_eq!(pagerank.len(), 3);
        // Node 0 should have highest PageRank as it receives links from both other nodes
        assert!(pagerank[0] > pagerank[1]);
        assert!(pagerank[0] > pagerank[2]);
    }
}