//! Core data types for graph condensation: the condensation configuration,
//! the condensed-graph output, and the quality metrics describing it.

use scirs2_core::ndarray::Array2;
/// Strategy used to reduce a graph to a smaller condensed graph.
///
/// The enum is `#[non_exhaustive]`. Code outside this crate must therefore
/// include a wildcard arm when matching on it, which lets new strategies be
/// added without a breaking change.
///
/// `CondensationMethod::default()` is [`CondensationMethod::KCenter`]. This is
/// the same method that `CondensationConfig::default()` selects.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[non_exhaustive]
pub enum CondensationMethod {
    /// K-center strategy. This is the `Default` variant.
    #[default]
    KCenter,
    /// Importance-sampling strategy.
    ImportanceSampling,
    /// Gradient-matching strategy.
    GradientMatching,
    /// Herding strategy.
    Herding,
}
30
31#[derive(Debug, Clone)]
33pub struct CondensationConfig {
34 pub target_nodes: usize,
36 pub method: CondensationMethod,
38 pub max_iterations: usize,
40 pub learning_rate: f64,
42}
43
44impl Default for CondensationConfig {
45 fn default() -> Self {
46 Self {
47 target_nodes: 100,
48 method: CondensationMethod::KCenter,
49 max_iterations: 200,
50 learning_rate: 0.01,
51 }
52 }
53}
54
55#[derive(Debug, Clone)]
57pub struct CondensedGraph {
58 pub adjacency: Array2<f64>,
60 pub features: Array2<f64>,
62 pub labels: Vec<usize>,
64 pub source_mapping: Vec<usize>,
68}
69
/// Measures of how faithfully a condensed graph preserves the original one.
///
/// Distances of `0.0` together with a `label_coverage` of `1.0` describe a
/// perfect condensation.
#[derive(Debug, Clone, PartialEq)]
pub struct QualityMetrics {
    /// Distance between the degree distributions of the original and the
    /// condensed graph (`0.0` = identical).
    pub degree_distribution_distance: f64,
    /// Distance between spectral characteristics of the two graphs
    /// (`0.0` = identical). NOTE(review): which spectrum is compared is not
    /// visible here; confirm.
    pub spectral_distance: f64,
    /// Fraction of labels represented in the condensed graph (`1.0` = full
    /// coverage; presumably in `[0, 1]`; confirm).
    pub label_coverage: f64,
}
84
85#[derive(Debug, Clone)]
87pub struct CondensationResult {
88 pub condensed: CondensedGraph,
90 pub compression_ratio: f64,
92 pub quality_metrics: QualityMetrics,
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing floats that were assigned, not computed.
    const EPS: f64 = 1e-12;

    /// Returns `true` when `a` and `b` differ by less than [`EPS`].
    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < EPS
    }

    #[test]
    fn test_default_config() {
        let config = CondensationConfig::default();
        assert_eq!(config.target_nodes, 100);
        assert_eq!(config.method, CondensationMethod::KCenter);
        assert_eq!(config.max_iterations, 200);
        assert!(approx_eq(config.learning_rate, 0.01));
    }

    #[test]
    fn test_condensation_method_variants() {
        let methods = [
            CondensationMethod::KCenter,
            CondensationMethod::ImportanceSampling,
            CondensationMethod::GradientMatching,
            CondensationMethod::Herding,
        ];
        // Compare each unordered pair of variants exactly once.
        for (i, a) in methods.iter().enumerate() {
            for b in &methods[i + 1..] {
                assert_ne!(a, b);
            }
        }
    }

    #[test]
    fn test_condensation_method_clone_and_copy() {
        let method = CondensationMethod::Herding;
        // The UFCS call really exercises `Clone`; a plain `let x = method;`
        // only exercises `Copy`. It also avoids `clippy::clone_on_copy`.
        let cloned = Clone::clone(&method);
        let copied = method;
        assert_eq!(method, cloned);
        assert_eq!(method, copied);
    }

    #[test]
    fn test_condensed_graph_creation() {
        let graph = CondensedGraph {
            adjacency: Array2::<f64>::zeros((3, 3)),
            features: Array2::<f64>::ones((3, 2)),
            labels: vec![0, 1, 0],
            source_mapping: vec![0, 5, 10],
        };

        assert_eq!(graph.adjacency.nrows(), 3);
        assert_eq!(graph.adjacency.ncols(), 3);
        assert_eq!(graph.features.nrows(), 3);
        assert_eq!(graph.features.ncols(), 2);
        assert_eq!(graph.labels, vec![0, 1, 0]);
        assert_eq!(graph.source_mapping, vec![0, 5, 10]);
    }

    #[test]
    fn test_condensed_graph_clone() {
        let graph = CondensedGraph {
            adjacency: Array2::<f64>::eye(2),
            features: Array2::<f64>::ones((2, 3)),
            labels: vec![1, 2],
            source_mapping: vec![0, 1],
        };

        let cloned = graph.clone();
        assert_eq!(cloned.labels, graph.labels);
        assert_eq!(cloned.source_mapping, graph.source_mapping);
        assert_eq!(cloned.adjacency, graph.adjacency);
        assert_eq!(cloned.features, graph.features);
    }

    #[test]
    fn test_quality_metrics_creation() {
        let metrics = QualityMetrics {
            degree_distribution_distance: 0.05,
            spectral_distance: 0.1,
            label_coverage: 0.95,
        };

        assert!(approx_eq(metrics.degree_distribution_distance, 0.05));
        assert!(approx_eq(metrics.spectral_distance, 0.1));
        assert!(approx_eq(metrics.label_coverage, 0.95));
    }

    #[test]
    fn test_quality_metrics_perfect() {
        let metrics = QualityMetrics {
            degree_distribution_distance: 0.0,
            spectral_distance: 0.0,
            label_coverage: 1.0,
        };

        assert!(approx_eq(metrics.degree_distribution_distance, 0.0));
        assert!(approx_eq(metrics.spectral_distance, 0.0));
        assert!(approx_eq(metrics.label_coverage, 1.0));
    }

    #[test]
    fn test_condensation_result_creation() {
        let result = CondensationResult {
            condensed: CondensedGraph {
                adjacency: Array2::<f64>::zeros((2, 2)),
                features: Array2::<f64>::ones((2, 4)),
                labels: vec![0, 1],
                source_mapping: vec![3, 7],
            },
            compression_ratio: 5.0,
            quality_metrics: QualityMetrics {
                degree_distribution_distance: 0.1,
                spectral_distance: 0.2,
                label_coverage: 1.0,
            },
        };

        assert!(approx_eq(result.compression_ratio, 5.0));
        assert_eq!(result.condensed.labels.len(), 2);
        assert!(approx_eq(result.quality_metrics.label_coverage, 1.0));
    }

    #[test]
    fn test_config_custom() {
        let config = CondensationConfig {
            target_nodes: 50,
            method: CondensationMethod::GradientMatching,
            max_iterations: 500,
            learning_rate: 0.001,
        };

        assert_eq!(config.target_nodes, 50);
        assert_eq!(config.method, CondensationMethod::GradientMatching);
        assert_eq!(config.max_iterations, 500);
        assert!(approx_eq(config.learning_rate, 0.001));
    }
}