sklears_semi_supervised/
composable_graph.rs

1//! Composable graph construction methods
2//!
3//! This module provides a flexible, composable framework for constructing graphs
4//! used in semi-supervised learning. It allows combining different graph construction
5//! strategies and applying transformations in a pipeline.
6
7use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView2};
8use sklears_core::error::{Result as SklResult, SklearsError};
9use sklears_core::types::Float;
10
11/// Trait for graph construction strategies
12pub trait GraphBuilder: Clone {
13    /// Build a graph from data
14    fn build(&self, X: &ArrayView2<Float>) -> SklResult<Array2<f64>>;
15}
16
17/// Trait for graph transformations
18pub trait GraphTransform: Clone {
19    /// Transform a graph
20    fn transform(&self, graph: &Array2<f64>) -> SklResult<Array2<f64>>;
21}
22
/// Builder for k-nearest-neighbour graphs.
///
/// Holds the neighbour count, whether edges carry kernel weights, and the
/// kernel bandwidth used when they do.
#[derive(Debug, Clone)]
pub struct KNNGraphBuilder {
    n_neighbors: usize,
    weighted: bool,
    sigma: f64,
}

impl KNNGraphBuilder {
    /// Create a builder that links each sample to `n_neighbors` neighbours.
    ///
    /// Defaults: weighted edges enabled, bandwidth `sigma = 1.0`.
    pub fn new(n_neighbors: usize) -> Self {
        KNNGraphBuilder {
            n_neighbors,
            weighted: true,
            sigma: 1.0,
        }
    }

    /// Choose between kernel-weighted edges (`true`) and unit edges (`false`).
    pub fn weighted(self, weighted: bool) -> Self {
        KNNGraphBuilder { weighted, ..self }
    }

    /// Set the kernel bandwidth used for weighted edges.
    pub fn sigma(self, sigma: f64) -> Self {
        KNNGraphBuilder { sigma, ..self }
    }
}
53
54impl GraphBuilder for KNNGraphBuilder {
55    fn build(&self, X: &ArrayView2<Float>) -> SklResult<Array2<f64>> {
56        let n_samples = X.nrows();
57        let mut graph = Array2::<f64>::zeros((n_samples, n_samples));
58
59        for i in 0..n_samples {
60            let mut distances: Vec<(usize, f64)> = Vec::new();
61
62            for j in 0..n_samples {
63                if i != j {
64                    let diff = &X.row(i) - &X.row(j);
65                    let dist = diff.mapv(|x| x * x).sum().sqrt();
66                    distances.push((j, dist));
67                }
68            }
69
70            distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
71
72            for &(j, dist) in distances.iter().take(self.n_neighbors) {
73                if self.weighted {
74                    let weight = (-dist * dist / (2.0 * self.sigma * self.sigma)).exp();
75                    graph[[i, j]] = weight;
76                } else {
77                    graph[[i, j]] = 1.0;
78                }
79            }
80        }
81
82        Ok(graph)
83    }
84}
85
/// Builder for epsilon-ball graphs.
///
/// Holds the distance radius, whether edges carry kernel weights, and the
/// kernel bandwidth used when they do.
#[derive(Debug, Clone)]
pub struct EpsilonGraphBuilder {
    epsilon: f64,
    weighted: bool,
    sigma: f64,
}

impl EpsilonGraphBuilder {
    /// Create a builder with radius `epsilon`.
    ///
    /// Defaults: weighted edges enabled, bandwidth `sigma = 1.0`.
    pub fn new(epsilon: f64) -> Self {
        EpsilonGraphBuilder {
            epsilon,
            weighted: true,
            sigma: 1.0,
        }
    }

    /// Choose between kernel-weighted edges (`true`) and unit edges (`false`).
    pub fn weighted(self, weighted: bool) -> Self {
        EpsilonGraphBuilder { weighted, ..self }
    }

    /// Set the kernel bandwidth used for weighted edges.
    pub fn sigma(self, sigma: f64) -> Self {
        EpsilonGraphBuilder { sigma, ..self }
    }
}
116
117impl GraphBuilder for EpsilonGraphBuilder {
118    fn build(&self, X: &ArrayView2<Float>) -> SklResult<Array2<f64>> {
119        let n_samples = X.nrows();
120        let mut graph = Array2::<f64>::zeros((n_samples, n_samples));
121
122        for i in 0..n_samples {
123            for j in 0..n_samples {
124                if i != j {
125                    let diff = &X.row(i) - &X.row(j);
126                    let dist = diff.mapv(|x| x * x).sum().sqrt();
127
128                    if dist < self.epsilon {
129                        if self.weighted {
130                            let weight = (-dist * dist / (2.0 * self.sigma * self.sigma)).exp();
131                            graph[[i, j]] = weight;
132                        } else {
133                            graph[[i, j]] = 1.0;
134                        }
135                    }
136                }
137            }
138        }
139
140        Ok(graph)
141    }
142}
143
/// Transform that makes a graph matrix symmetric.
///
/// The combination rule is selected by name; the transform implementation in
/// this file recognises `"max"` and `"average"`.
#[derive(Debug, Clone)]
pub struct SymmetrizeTransform {
    method: String,
}

impl SymmetrizeTransform {
    /// Create a symmetrize transform using the named `method`.
    ///
    /// The name is stored as given; it is only validated when the transform
    /// is applied.
    pub fn new(method: String) -> Self {
        SymmetrizeTransform { method }
    }
}
156
157impl GraphTransform for SymmetrizeTransform {
158    fn transform(&self, graph: &Array2<f64>) -> SklResult<Array2<f64>> {
159        let n = graph.nrows();
160        let mut symmetric = graph.clone();
161
162        match self.method.as_str() {
163            "max" => {
164                for i in 0..n {
165                    for j in (i + 1)..n {
166                        let value = graph[[i, j]].max(graph[[j, i]]);
167                        symmetric[[i, j]] = value;
168                        symmetric[[j, i]] = value;
169                    }
170                }
171            }
172            "average" => {
173                for i in 0..n {
174                    for j in (i + 1)..n {
175                        let value = (graph[[i, j]] + graph[[j, i]]) / 2.0;
176                        symmetric[[i, j]] = value;
177                        symmetric[[j, i]] = value;
178                    }
179                }
180            }
181            _ => {
182                return Err(SklearsError::InvalidInput(format!(
183                    "Unknown symmetrization method: {}",
184                    self.method
185                )));
186            }
187        }
188
189        Ok(symmetric)
190    }
191}
192
/// Transform that rescales the weights of a graph matrix.
///
/// The normalization scheme is selected by name; the transform implementation
/// in this file recognises `"row"` and `"symmetric"`.
#[derive(Debug, Clone)]
pub struct NormalizeTransform {
    method: String,
}

impl NormalizeTransform {
    /// Create a normalize transform using the named `method`.
    ///
    /// The name is stored as given; it is only validated when the transform
    /// is applied.
    pub fn new(method: String) -> Self {
        NormalizeTransform { method }
    }
}
205
206impl GraphTransform for NormalizeTransform {
207    fn transform(&self, graph: &Array2<f64>) -> SklResult<Array2<f64>> {
208        let n = graph.nrows();
209        let mut normalized = graph.clone();
210
211        match self.method.as_str() {
212            "row" => {
213                for i in 0..n {
214                    let row_sum: f64 = graph.row(i).sum();
215                    if row_sum > 0.0 {
216                        for j in 0..n {
217                            normalized[[i, j]] /= row_sum;
218                        }
219                    }
220                }
221            }
222            "symmetric" => {
223                // D^{-1/2} A D^{-1/2}
224                let mut degrees = Array1::<f64>::zeros(n);
225                for i in 0..n {
226                    degrees[i] = graph.row(i).sum();
227                }
228
229                for i in 0..n {
230                    for j in 0..n {
231                        if degrees[i] > 0.0 && degrees[j] > 0.0 {
232                            normalized[[i, j]] = graph[[i, j]] / (degrees[i] * degrees[j]).sqrt();
233                        }
234                    }
235                }
236            }
237            _ => {
238                return Err(SklearsError::InvalidInput(format!(
239                    "Unknown normalization method: {}",
240                    self.method
241                )));
242            }
243        }
244
245        Ok(normalized)
246    }
247}
248
/// Transform that drops weak edges from a graph matrix.
///
/// Stores the cut-off below which a weight is replaced by zero.
#[derive(Debug, Clone)]
pub struct SparsifyTransform {
    threshold: f64,
}

impl SparsifyTransform {
    /// Create a sparsify transform with the given weight `threshold`.
    pub fn new(threshold: f64) -> Self {
        SparsifyTransform { threshold }
    }
}
261
262impl GraphTransform for SparsifyTransform {
263    fn transform(&self, graph: &Array2<f64>) -> SklResult<Array2<f64>> {
264        let mut sparse = graph.clone();
265        let n = graph.nrows();
266
267        for i in 0..n {
268            for j in 0..n {
269                if sparse[[i, j]] < self.threshold {
270                    sparse[[i, j]] = 0.0;
271                }
272            }
273        }
274
275        Ok(sparse)
276    }
277}
278
279/// Composable graph pipeline
280#[derive(Clone)]
281pub struct GraphPipeline {
282    builder: Box<dyn GraphBuilderTrait>,
283    transforms: Vec<Box<dyn GraphTransformTrait>>,
284}
285
286impl std::fmt::Debug for GraphPipeline {
287    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
288        f.debug_struct("GraphPipeline")
289            .field("builder", &"Box<dyn GraphBuilderTrait>")
290            .field(
291                "transforms",
292                &format!("{} transforms", self.transforms.len()),
293            )
294            .finish()
295    }
296}
297
298// Helper traits with object safety
299trait GraphBuilderTrait {
300    fn build_graph(&self, X: &ArrayView2<Float>) -> SklResult<Array2<f64>>;
301    fn clone_box(&self) -> Box<dyn GraphBuilderTrait>;
302}
303
304trait GraphTransformTrait {
305    fn transform_graph(&self, graph: &Array2<f64>) -> SklResult<Array2<f64>>;
306    fn clone_box(&self) -> Box<dyn GraphTransformTrait>;
307}
308
309impl<T: GraphBuilder + 'static> GraphBuilderTrait for T {
310    fn build_graph(&self, X: &ArrayView2<Float>) -> SklResult<Array2<f64>> {
311        self.build(X)
312    }
313
314    fn clone_box(&self) -> Box<dyn GraphBuilderTrait> {
315        Box::new(self.clone())
316    }
317}
318
319impl<T: GraphTransform + 'static> GraphTransformTrait for T {
320    fn transform_graph(&self, graph: &Array2<f64>) -> SklResult<Array2<f64>> {
321        self.transform(graph)
322    }
323
324    fn clone_box(&self) -> Box<dyn GraphTransformTrait> {
325        Box::new(self.clone())
326    }
327}
328
329impl Clone for Box<dyn GraphBuilderTrait> {
330    fn clone(&self) -> Self {
331        self.clone_box()
332    }
333}
334
335impl Clone for Box<dyn GraphTransformTrait> {
336    fn clone(&self) -> Self {
337        self.clone_box()
338    }
339}
340
341impl GraphPipeline {
342    /// Create a new graph pipeline
343    pub fn new<B: GraphBuilder + 'static>(builder: B) -> Self {
344        Self {
345            builder: Box::new(builder),
346            transforms: Vec::new(),
347        }
348    }
349
350    /// Add a transformation to the pipeline
351    pub fn add_transform<T: GraphTransform + 'static>(mut self, transform: T) -> Self {
352        self.transforms.push(Box::new(transform));
353        self
354    }
355
356    /// Build the graph with all transformations
357    pub fn build(&self, X: &ArrayView2<Float>) -> SklResult<Array2<f64>> {
358        let mut graph = self.builder.build_graph(X)?;
359
360        for transform in &self.transforms {
361            graph = transform.transform_graph(&graph)?;
362        }
363
364        Ok(graph)
365    }
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::array;

    /// A weighted 2-NN graph over four points gives every node two outgoing edges.
    #[test]
    #[allow(non_snake_case)]
    fn test_knn_graph_builder() {
        let X = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0]];
        let knn = KNNGraphBuilder::new(2).weighted(true).sigma(1.0);

        let adjacency = knn.build(&X.view()).unwrap();

        assert_eq!(adjacency.dim(), (4, 4));
        // Each node should be connected to 2 neighbors
        for node in 0..4 {
            let out_degree = adjacency.row(node).iter().filter(|&&w| w > 0.0).count();
            assert_eq!(out_degree, 2);
        }
    }

    /// Unweighted epsilon graph: close points are linked, distant ones are not.
    #[test]
    #[allow(non_snake_case)]
    fn test_epsilon_graph_builder() {
        let X = array![[0.0, 0.0], [1.0, 0.0], [10.0, 0.0]];
        let ball = EpsilonGraphBuilder::new(2.0).weighted(false);

        let adjacency = ball.build(&X.view()).unwrap();

        assert_eq!(adjacency.dim(), (3, 3));
        // Nodes 0 and 1 should be connected (distance 1.0 < 2.0)
        assert_eq!(adjacency[[0, 1]], 1.0);
        assert_eq!(adjacency[[1, 0]], 1.0);
        // Nodes 0 and 2 should not be connected (distance 10.0 > 2.0)
        assert_eq!(adjacency[[0, 2]], 0.0);
    }

    /// "max" symmetrization keeps the larger of each mirrored pair.
    #[test]
    fn test_symmetrize_transform() {
        let mut directed = Array2::<f64>::zeros((3, 3));
        for &(r, c, w) in &[(0, 1, 1.0), (1, 0, 2.0), (1, 2, 3.0), (2, 1, 4.0)] {
            directed[[r, c]] = w;
        }

        let symmetrize = SymmetrizeTransform::new("max".to_string());
        let result = symmetrize.transform(&directed).unwrap();

        assert_eq!(result[[0, 1]], 2.0);
        assert_eq!(result[[1, 0]], 2.0);
        assert_eq!(result[[1, 2]], 4.0);
        assert_eq!(result[[2, 1]], 4.0);
    }

    /// Row normalization makes a non-empty row sum to one.
    #[test]
    fn test_normalize_transform() {
        let mut weights = Array2::<f64>::zeros((3, 3));
        for &(r, c, w) in &[(0, 1, 2.0), (0, 2, 2.0), (1, 0, 1.0)] {
            weights[[r, c]] = w;
        }

        let normalize = NormalizeTransform::new("row".to_string());
        let result = normalize.transform(&weights).unwrap();

        // Row 0 should sum to 1.0
        let first_row_total: f64 = result.row(0).sum();
        assert!((first_row_total - 1.0).abs() < 1e-10);
    }

    /// Sparsification zeroes entries below the threshold and keeps the rest.
    #[test]
    fn test_sparsify_transform() {
        let mut weights = Array2::<f64>::zeros((3, 3));
        for &(r, c, w) in &[(0, 1, 0.1), (0, 2, 0.5), (1, 2, 0.3)] {
            weights[[r, c]] = w;
        }

        let sparsify = SparsifyTransform::new(0.2);
        let result = sparsify.transform(&weights).unwrap();

        assert_eq!(result[[0, 1]], 0.0); // Below threshold
        assert_eq!(result[[0, 2]], 0.5); // Above threshold
        assert_eq!(result[[1, 2]], 0.3); // Above threshold
    }

    /// KNN -> average symmetrization -> row normalization yields stochastic rows.
    #[test]
    #[allow(non_snake_case)]
    fn test_graph_pipeline() {
        let X = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0]];

        let stages = GraphPipeline::new(KNNGraphBuilder::new(2).weighted(true))
            .add_transform(SymmetrizeTransform::new("average".to_string()))
            .add_transform(NormalizeTransform::new("row".to_string()));

        let result = stages.build(&X.view()).unwrap();

        assert_eq!(result.dim(), (4, 4));

        // Check that each row sums to approximately 1.0 (normalized)
        for node in 0..4 {
            let total: f64 = result.row(node).sum();
            assert!((total - 1.0).abs() < 1e-6 || total == 0.0);
        }

        // Row normalization breaks symmetry, so only connectivity is checked here.
        let edge_count: usize = result.iter().filter(|&&w| w > 0.0).count();
        assert!(edge_count > 0);
    }

    /// KNN -> max symmetrization yields a symmetric matrix.
    #[test]
    #[allow(non_snake_case)]
    fn test_symmetric_pipeline() {
        let X = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]];

        let stages = GraphPipeline::new(KNNGraphBuilder::new(1).weighted(true))
            .add_transform(SymmetrizeTransform::new("max".to_string()));

        let result = stages.build(&X.view()).unwrap();

        assert_eq!(result.dim(), (3, 3));

        // Check symmetry (without row normalization)
        for r in 0..3 {
            for c in 0..3 {
                assert!((result[[r, c]] - result[[c, r]]).abs() < 1e-10);
            }
        }
    }
}