Skip to main content

scirs2_graph/condensation/
types.rs

1//! Types for graph condensation (dataset distillation for graphs).
2//!
3//! This module defines configuration, result, and quality metric types
4//! used throughout the condensation pipeline.
5
6use scirs2_core::ndarray::Array2;
7
/// Method used for graph condensation.
///
/// Each variant represents a different algorithmic approach to
/// reducing a graph while preserving its structural properties.
///
/// The enum is `Copy`, `Eq` and `Hash`, so a method can be used directly as a
/// key in a `HashMap` / `HashSet`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum CondensationMethod {
    /// Greedy farthest-point sampling (coreset).
    ///
    /// Picks nodes that maximize the minimum distance to the selected set.
    KCenter,
    /// Degree-weighted sampling with feature diversity.
    ///
    /// Combines structural importance (degree) with feature-space coverage.
    ImportanceSampling,
    /// Gradient matching between original and synthetic graphs.
    ///
    /// Optimises a small synthetic graph so that GNN gradients match
    /// those computed on the full graph.
    GradientMatching,
    /// Kernel herding selection.
    ///
    /// Greedily picks points that minimize the Maximum Mean Discrepancy
    /// (MMD) between the selected subset and the full dataset.
    Herding,
}
30
31/// Configuration for graph condensation.
32#[derive(Debug, Clone)]
33pub struct CondensationConfig {
34    /// Number of nodes in the condensed graph.
35    pub target_nodes: usize,
36    /// Method to use for condensation.
37    pub method: CondensationMethod,
38    /// Maximum number of iterations (for iterative methods such as GradientMatching).
39    pub max_iterations: usize,
40    /// Learning rate (for iterative methods such as GradientMatching).
41    pub learning_rate: f64,
42}
43
44impl Default for CondensationConfig {
45    fn default() -> Self {
46        Self {
47            target_nodes: 100,
48            method: CondensationMethod::KCenter,
49            max_iterations: 200,
50            learning_rate: 0.01,
51        }
52    }
53}
54
55/// A condensed (distilled) graph produced by condensation.
56#[derive(Debug, Clone)]
57pub struct CondensedGraph {
58    /// Adjacency matrix of the condensed graph (target_nodes x target_nodes).
59    pub adjacency: Array2<f64>,
60    /// Feature matrix of the condensed graph (target_nodes x feature_dim).
61    pub features: Array2<f64>,
62    /// Node labels in the condensed graph.
63    pub labels: Vec<usize>,
64    /// Mapping from condensed node index to original node index.
65    /// For synthetic methods (e.g. GradientMatching) this maps to the
66    /// nearest original node.
67    pub source_mapping: Vec<usize>,
68}
69
/// Quality metrics that measure how well the condensed graph
/// preserves the properties of the original graph.
///
/// `PartialEq` compares the `f64` fields exactly. When checking computed
/// metrics against expected values, compare field-wise with a tolerance.
#[derive(Debug, Clone, PartialEq)]
pub struct QualityMetrics {
    /// KL divergence between degree distributions of original and condensed graphs.
    ///
    /// Lower is better; 0.0 means identical distributions.
    pub degree_distribution_distance: f64,
    /// L2 distance between the top-k eigenvalues of the graph Laplacians.
    ///
    /// Lower is better; 0.0 means identical spectral properties.
    pub spectral_distance: f64,
    /// Fraction of original label classes that are present in the condensed graph.
    ///
    /// 1.0 means all labels are covered.
    pub label_coverage: f64,
}
84
85/// Full result of a condensation operation.
86#[derive(Debug, Clone)]
87pub struct CondensationResult {
88    /// The condensed graph.
89    pub condensed: CondensedGraph,
90    /// Compression ratio: original_nodes / condensed_nodes.
91    pub compression_ratio: f64,
92    /// Quality metrics comparing the condensed graph to the original.
93    pub quality_metrics: QualityMetrics,
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    /// Clones through the `Clone` trait in a generic context, so the derived
    /// `Clone` impl is genuinely exercised (a `.clone()` call on a `Copy`
    /// value would be flagged by clippy's `clone_on_copy`).
    fn clone_via_trait<T: Clone>(value: &T) -> T {
        value.clone()
    }

    #[test]
    fn test_default_config() {
        let config = CondensationConfig::default();
        assert_eq!(config.target_nodes, 100);
        assert_eq!(config.method, CondensationMethod::KCenter);
        assert_eq!(config.max_iterations, 200);
        assert!((config.learning_rate - 0.01).abs() < 1e-12);
    }

    #[test]
    fn test_condensation_method_variants() {
        let methods = [
            CondensationMethod::KCenter,
            CondensationMethod::ImportanceSampling,
            CondensationMethod::GradientMatching,
            CondensationMethod::Herding,
        ];
        // Ensure all variants are distinct
        for i in 0..methods.len() {
            for j in (i + 1)..methods.len() {
                assert_ne!(methods[i], methods[j]);
            }
        }
    }

    #[test]
    fn test_condensation_method_clone_and_copy() {
        let method = CondensationMethod::Herding;
        let cloned = clone_via_trait(&method);
        let copied = method;
        // `method` is still usable after being bound to `copied` only because
        // the type is `Copy`.
        assert_eq!(method, cloned);
        assert_eq!(method, copied);
    }

    #[test]
    fn test_condensed_graph_creation() {
        let adj = Array2::<f64>::zeros((3, 3));
        let features = Array2::<f64>::ones((3, 2));
        let labels = vec![0, 1, 0];
        let source_mapping = vec![0, 5, 10];

        let graph = CondensedGraph {
            adjacency: adj.clone(),
            features: features.clone(),
            labels: labels.clone(),
            source_mapping: source_mapping.clone(),
        };

        assert_eq!(graph.adjacency.nrows(), 3);
        assert_eq!(graph.adjacency.ncols(), 3);
        assert_eq!(graph.features.nrows(), 3);
        assert_eq!(graph.features.ncols(), 2);
        // The clones above are kept so the stored data can be checked
        // against the inputs it was built from.
        assert_eq!(graph.adjacency, adj);
        assert_eq!(graph.features, features);
        assert_eq!(graph.labels, labels);
        assert_eq!(graph.source_mapping, source_mapping);
    }

    #[test]
    fn test_condensed_graph_clone() {
        let graph = CondensedGraph {
            adjacency: Array2::<f64>::eye(2),
            features: Array2::<f64>::ones((2, 3)),
            labels: vec![1, 2],
            source_mapping: vec![0, 1],
        };

        let cloned = graph.clone();
        assert_eq!(cloned.labels, graph.labels);
        assert_eq!(cloned.source_mapping, graph.source_mapping);
        assert_eq!(cloned.adjacency, graph.adjacency);
        assert_eq!(cloned.features, graph.features);
    }

    #[test]
    fn test_quality_metrics_creation() {
        let metrics = QualityMetrics {
            degree_distribution_distance: 0.05,
            spectral_distance: 0.1,
            label_coverage: 0.95,
        };

        assert!((metrics.degree_distribution_distance - 0.05).abs() < 1e-12);
        assert!((metrics.spectral_distance - 0.1).abs() < 1e-12);
        assert!((metrics.label_coverage - 0.95).abs() < 1e-12);
    }

    #[test]
    fn test_quality_metrics_perfect() {
        let metrics = QualityMetrics {
            degree_distribution_distance: 0.0,
            spectral_distance: 0.0,
            label_coverage: 1.0,
        };

        assert!((metrics.degree_distribution_distance).abs() < 1e-12);
        assert!((metrics.spectral_distance).abs() < 1e-12);
        assert!((metrics.label_coverage - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_condensation_result_creation() {
        let result = CondensationResult {
            condensed: CondensedGraph {
                adjacency: Array2::<f64>::zeros((2, 2)),
                features: Array2::<f64>::ones((2, 4)),
                labels: vec![0, 1],
                source_mapping: vec![3, 7],
            },
            compression_ratio: 5.0,
            quality_metrics: QualityMetrics {
                degree_distribution_distance: 0.1,
                spectral_distance: 0.2,
                label_coverage: 1.0,
            },
        };

        assert!((result.compression_ratio - 5.0).abs() < 1e-12);
        assert_eq!(result.condensed.labels.len(), 2);
        assert!((result.quality_metrics.label_coverage - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_config_custom() {
        let config = CondensationConfig {
            target_nodes: 50,
            method: CondensationMethod::GradientMatching,
            max_iterations: 500,
            learning_rate: 0.001,
        };

        assert_eq!(config.target_nodes, 50);
        assert_eq!(config.method, CondensationMethod::GradientMatching);
        assert_eq!(config.max_iterations, 500);
        assert!((config.learning_rate - 0.001).abs() < 1e-12);
    }
}