1use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView2};
8use sklears_core::error::{Result as SklResult, SklearsError};
9use sklears_core::types::Float;
10
11pub trait GraphBuilder: Clone {
13 fn build(&self, X: &ArrayView2<Float>) -> SklResult<Array2<f64>>;
15}
16
17pub trait GraphTransform: Clone {
19 fn transform(&self, graph: &Array2<f64>) -> SklResult<Array2<f64>>;
21}
22
/// Builds a k-nearest-neighbour graph using Euclidean distance.
///
/// Each sample receives outgoing edges to (up to) its `n_neighbors` closest
/// other samples. The resulting matrix is therefore generally not symmetric;
/// combine it with [`SymmetrizeTransform`] when an undirected graph is needed.
#[derive(Debug, Clone)]
pub struct KNNGraphBuilder {
    /// Number of outgoing edges created for every sample.
    n_neighbors: usize,
    /// `true`: Gaussian weights `exp(-d^2 / (2 * sigma^2))`; `false`: every edge weighs `1.0`.
    weighted: bool,
    /// Bandwidth of the Gaussian kernel; only used when `weighted` is `true`.
    sigma: f64,
}

impl KNNGraphBuilder {
    /// Creates a builder with Gaussian weighting enabled and `sigma = 1.0`.
    pub fn new(n_neighbors: usize) -> Self {
        let weighted = true;
        let sigma = 1.0;
        KNNGraphBuilder {
            n_neighbors,
            weighted,
            sigma,
        }
    }

    /// Selects Gaussian (`true`) or binary (`false`) edge weights.
    pub fn weighted(self, weighted: bool) -> Self {
        Self { weighted, ..self }
    }

    /// Sets the bandwidth of the Gaussian kernel.
    pub fn sigma(self, sigma: f64) -> Self {
        Self { sigma, ..self }
    }
}
53
54impl GraphBuilder for KNNGraphBuilder {
55 fn build(&self, X: &ArrayView2<Float>) -> SklResult<Array2<f64>> {
56 let n_samples = X.nrows();
57 let mut graph = Array2::<f64>::zeros((n_samples, n_samples));
58
59 for i in 0..n_samples {
60 let mut distances: Vec<(usize, f64)> = Vec::new();
61
62 for j in 0..n_samples {
63 if i != j {
64 let diff = &X.row(i) - &X.row(j);
65 let dist = diff.mapv(|x| x * x).sum().sqrt();
66 distances.push((j, dist));
67 }
68 }
69
70 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
71
72 for &(j, dist) in distances.iter().take(self.n_neighbors) {
73 if self.weighted {
74 let weight = (-dist * dist / (2.0 * self.sigma * self.sigma)).exp();
75 graph[[i, j]] = weight;
76 } else {
77 graph[[i, j]] = 1.0;
78 }
79 }
80 }
81
82 Ok(graph)
83 }
84}
85
/// Builds an epsilon-neighbourhood graph using Euclidean distance.
///
/// Two distinct samples are connected when their distance is strictly less
/// than `epsilon`. Distances are symmetric, so the resulting graph is too.
#[derive(Debug, Clone)]
pub struct EpsilonGraphBuilder {
    /// Connection radius; a pair at exactly this distance is not connected.
    epsilon: f64,
    /// `true`: Gaussian weights `exp(-d^2 / (2 * sigma^2))`; `false`: every edge weighs `1.0`.
    weighted: bool,
    /// Bandwidth of the Gaussian kernel; only used when `weighted` is `true`.
    sigma: f64,
}

impl EpsilonGraphBuilder {
    /// Creates a builder with radius `epsilon`, Gaussian weighting and `sigma = 1.0`.
    pub fn new(epsilon: f64) -> Self {
        EpsilonGraphBuilder {
            epsilon,
            weighted: true,
            sigma: 1.0,
        }
    }

    /// Selects Gaussian (`true`) or binary (`false`) edge weights.
    pub fn weighted(self, weighted: bool) -> Self {
        Self { weighted, ..self }
    }

    /// Sets the bandwidth of the Gaussian kernel.
    pub fn sigma(self, sigma: f64) -> Self {
        Self { sigma, ..self }
    }
}
116
117impl GraphBuilder for EpsilonGraphBuilder {
118 fn build(&self, X: &ArrayView2<Float>) -> SklResult<Array2<f64>> {
119 let n_samples = X.nrows();
120 let mut graph = Array2::<f64>::zeros((n_samples, n_samples));
121
122 for i in 0..n_samples {
123 for j in 0..n_samples {
124 if i != j {
125 let diff = &X.row(i) - &X.row(j);
126 let dist = diff.mapv(|x| x * x).sum().sqrt();
127
128 if dist < self.epsilon {
129 if self.weighted {
130 let weight = (-dist * dist / (2.0 * self.sigma * self.sigma)).exp();
131 graph[[i, j]] = weight;
132 } else {
133 graph[[i, j]] = 1.0;
134 }
135 }
136 }
137 }
138 }
139
140 Ok(graph)
141 }
142}
143
/// Makes a (possibly directed) graph symmetric by combining `[i, j]` with `[j, i]`.
///
/// The combination rule is named by `method` (e.g. `"max"` or `"average"`).
/// It is validated when the transform runs, not at construction time.
#[derive(Debug, Clone)]
pub struct SymmetrizeTransform {
    /// Name of the combination rule.
    method: String,
}

impl SymmetrizeTransform {
    /// Creates the transform; an unknown `method` is reported by `transform`.
    pub fn new(method: String) -> Self {
        SymmetrizeTransform { method }
    }
}
156
157impl GraphTransform for SymmetrizeTransform {
158 fn transform(&self, graph: &Array2<f64>) -> SklResult<Array2<f64>> {
159 let n = graph.nrows();
160 let mut symmetric = graph.clone();
161
162 match self.method.as_str() {
163 "max" => {
164 for i in 0..n {
165 for j in (i + 1)..n {
166 let value = graph[[i, j]].max(graph[[j, i]]);
167 symmetric[[i, j]] = value;
168 symmetric[[j, i]] = value;
169 }
170 }
171 }
172 "average" => {
173 for i in 0..n {
174 for j in (i + 1)..n {
175 let value = (graph[[i, j]] + graph[[j, i]]) / 2.0;
176 symmetric[[i, j]] = value;
177 symmetric[[j, i]] = value;
178 }
179 }
180 }
181 _ => {
182 return Err(SklearsError::InvalidInput(format!(
183 "Unknown symmetrization method: {}",
184 self.method
185 )));
186 }
187 }
188
189 Ok(symmetric)
190 }
191}
192
/// Normalises the edge weights of a graph.
///
/// The scheme is named by `method` (e.g. `"row"` or `"symmetric"`). It is
/// validated when the transform runs, not at construction time.
#[derive(Debug, Clone)]
pub struct NormalizeTransform {
    /// Name of the normalisation scheme.
    method: String,
}

impl NormalizeTransform {
    /// Creates the transform; an unknown `method` is reported by `transform`.
    pub fn new(method: String) -> Self {
        NormalizeTransform { method }
    }
}
205
206impl GraphTransform for NormalizeTransform {
207 fn transform(&self, graph: &Array2<f64>) -> SklResult<Array2<f64>> {
208 let n = graph.nrows();
209 let mut normalized = graph.clone();
210
211 match self.method.as_str() {
212 "row" => {
213 for i in 0..n {
214 let row_sum: f64 = graph.row(i).sum();
215 if row_sum > 0.0 {
216 for j in 0..n {
217 normalized[[i, j]] /= row_sum;
218 }
219 }
220 }
221 }
222 "symmetric" => {
223 let mut degrees = Array1::<f64>::zeros(n);
225 for i in 0..n {
226 degrees[i] = graph.row(i).sum();
227 }
228
229 for i in 0..n {
230 for j in 0..n {
231 if degrees[i] > 0.0 && degrees[j] > 0.0 {
232 normalized[[i, j]] = graph[[i, j]] / (degrees[i] * degrees[j]).sqrt();
233 }
234 }
235 }
236 }
237 _ => {
238 return Err(SklearsError::InvalidInput(format!(
239 "Unknown normalization method: {}",
240 self.method
241 )));
242 }
243 }
244
245 Ok(normalized)
246 }
247}
248
/// Drops weak edges by zeroing every weight below a fixed threshold.
#[derive(Debug, Clone)]
pub struct SparsifyTransform {
    /// Weights strictly smaller than this value are set to `0.0`.
    threshold: f64,
}

impl SparsifyTransform {
    /// Creates the transform with the given cut-off.
    pub fn new(threshold: f64) -> Self {
        SparsifyTransform { threshold }
    }
}
261
262impl GraphTransform for SparsifyTransform {
263 fn transform(&self, graph: &Array2<f64>) -> SklResult<Array2<f64>> {
264 let mut sparse = graph.clone();
265 let n = graph.nrows();
266
267 for i in 0..n {
268 for j in 0..n {
269 if sparse[[i, j]] < self.threshold {
270 sparse[[i, j]] = 0.0;
271 }
272 }
273 }
274
275 Ok(sparse)
276 }
277}
278
279#[derive(Clone)]
281pub struct GraphPipeline {
282 builder: Box<dyn GraphBuilderTrait>,
283 transforms: Vec<Box<dyn GraphTransformTrait>>,
284}
285
286impl std::fmt::Debug for GraphPipeline {
287 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
288 f.debug_struct("GraphPipeline")
289 .field("builder", &"Box<dyn GraphBuilderTrait>")
290 .field(
291 "transforms",
292 &format!("{} transforms", self.transforms.len()),
293 )
294 .finish()
295 }
296}
297
298trait GraphBuilderTrait {
300 fn build_graph(&self, X: &ArrayView2<Float>) -> SklResult<Array2<f64>>;
301 fn clone_box(&self) -> Box<dyn GraphBuilderTrait>;
302}
303
304trait GraphTransformTrait {
305 fn transform_graph(&self, graph: &Array2<f64>) -> SklResult<Array2<f64>>;
306 fn clone_box(&self) -> Box<dyn GraphTransformTrait>;
307}
308
309impl<T: GraphBuilder + 'static> GraphBuilderTrait for T {
310 fn build_graph(&self, X: &ArrayView2<Float>) -> SklResult<Array2<f64>> {
311 self.build(X)
312 }
313
314 fn clone_box(&self) -> Box<dyn GraphBuilderTrait> {
315 Box::new(self.clone())
316 }
317}
318
319impl<T: GraphTransform + 'static> GraphTransformTrait for T {
320 fn transform_graph(&self, graph: &Array2<f64>) -> SklResult<Array2<f64>> {
321 self.transform(graph)
322 }
323
324 fn clone_box(&self) -> Box<dyn GraphTransformTrait> {
325 Box::new(self.clone())
326 }
327}
328
329impl Clone for Box<dyn GraphBuilderTrait> {
330 fn clone(&self) -> Self {
331 self.clone_box()
332 }
333}
334
335impl Clone for Box<dyn GraphTransformTrait> {
336 fn clone(&self) -> Self {
337 self.clone_box()
338 }
339}
340
341impl GraphPipeline {
342 pub fn new<B: GraphBuilder + 'static>(builder: B) -> Self {
344 Self {
345 builder: Box::new(builder),
346 transforms: Vec::new(),
347 }
348 }
349
350 pub fn add_transform<T: GraphTransform + 'static>(mut self, transform: T) -> Self {
352 self.transforms.push(Box::new(transform));
353 self
354 }
355
356 pub fn build(&self, X: &ArrayView2<Float>) -> SklResult<Array2<f64>> {
358 let mut graph = self.builder.build_graph(X)?;
359
360 for transform in &self.transforms {
361 graph = transform.transform_graph(&graph)?;
362 }
363
364 Ok(graph)
365 }
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::array;

    #[test]
    #[allow(non_snake_case)]
    fn test_knn_graph_builder() {
        let X = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0]];
        let builder = KNNGraphBuilder::new(2).weighted(true).sigma(1.0);

        let graph = builder.build(&X.view()).unwrap();

        assert_eq!(graph.dim(), (4, 4));
        for i in 0..4 {
            // Every sample has exactly `n_neighbors` outgoing edges.
            let row_nonzero = graph.row(i).iter().filter(|&&x| x > 0.0).count();
            assert_eq!(row_nonzero, 2);
        }
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_knn_graph_selects_nearest_neighbours() {
        let X = array![[0.0, 0.0], [1.0, 0.0], [3.0, 0.0], [7.0, 0.0]];

        let binary = KNNGraphBuilder::new(1)
            .weighted(false)
            .build(&X.view())
            .unwrap();
        assert_eq!(binary[[0, 1]], 1.0);
        assert_eq!(binary[[1, 0]], 1.0);
        assert_eq!(binary[[2, 1]], 1.0);
        assert_eq!(binary[[3, 2]], 1.0);
        assert_eq!(binary.iter().filter(|&&w| w != 0.0).count(), 4);
        for i in 0..4 {
            assert_eq!(binary[[i, i]], 0.0);
        }

        let weighted = KNNGraphBuilder::new(1).sigma(1.0).build(&X.view()).unwrap();
        // Gaussian kernel exp(-d^2 / (2 sigma^2)) with d = 1, sigma = 1.
        assert!((weighted[[0, 1]] - (-0.5f64).exp()).abs() < 1e-12);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_epsilon_graph_builder() {
        let X = array![[0.0, 0.0], [1.0, 0.0], [10.0, 0.0]];
        let builder = EpsilonGraphBuilder::new(2.0).weighted(false);

        let graph = builder.build(&X.view()).unwrap();

        assert_eq!(graph.dim(), (3, 3));
        assert_eq!(graph[[0, 1]], 1.0);
        assert_eq!(graph[[1, 0]], 1.0);
        assert_eq!(graph[[0, 2]], 0.0);
        assert_eq!(graph[[2, 0]], 0.0);
        assert_eq!(graph[[1, 2]], 0.0);
        for i in 0..3 {
            assert_eq!(graph[[i, i]], 0.0);
        }
    }

    #[test]
    fn test_symmetrize_transform() {
        let mut graph = Array2::<f64>::zeros((3, 3));
        graph[[0, 1]] = 1.0;
        graph[[1, 0]] = 2.0;
        graph[[1, 2]] = 3.0;
        graph[[2, 1]] = 4.0;

        let transform = SymmetrizeTransform::new("max".to_string());
        let symmetric = transform.transform(&graph).unwrap();

        assert_eq!(symmetric[[0, 1]], 2.0);
        assert_eq!(symmetric[[1, 0]], 2.0);
        assert_eq!(symmetric[[1, 2]], 4.0);
        assert_eq!(symmetric[[2, 1]], 4.0);
    }

    #[test]
    fn test_symmetrize_average_keeps_diagonal() {
        let mut graph = Array2::<f64>::zeros((2, 2));
        graph[[0, 1]] = 1.0;
        graph[[1, 0]] = 2.0;
        graph[[1, 1]] = 5.0;

        let symmetric = SymmetrizeTransform::new("average".to_string())
            .transform(&graph)
            .unwrap();

        assert_eq!(symmetric[[0, 1]], 1.5);
        assert_eq!(symmetric[[1, 0]], 1.5);
        assert_eq!(symmetric[[1, 1]], 5.0);
    }

    #[test]
    fn test_normalize_transform() {
        let mut graph = Array2::<f64>::zeros((3, 3));
        graph[[0, 1]] = 2.0;
        graph[[0, 2]] = 2.0;
        graph[[1, 0]] = 1.0;

        let transform = NormalizeTransform::new("row".to_string());
        let normalized = transform.transform(&graph).unwrap();

        let row_sum: f64 = normalized.row(0).sum();
        assert!((row_sum - 1.0).abs() < 1e-10);
        assert_eq!(normalized[[0, 1]], 0.5);
        assert_eq!(normalized[[1, 0]], 1.0);
        // An all-zero row has no positive sum and is left untouched.
        assert!(normalized.row(2).iter().all(|&w| w == 0.0));
    }

    #[test]
    fn test_symmetric_normalize() {
        let mut graph = Array2::<f64>::zeros((3, 3));
        graph[[0, 1]] = 1.0;
        graph[[1, 0]] = 1.0;
        graph[[1, 2]] = 3.0;
        graph[[2, 1]] = 3.0;
        // Degrees: d0 = 1, d1 = 4, d2 = 3.

        let normalized = NormalizeTransform::new("symmetric".to_string())
            .transform(&graph)
            .unwrap();

        assert!((normalized[[0, 1]] - 0.5).abs() < 1e-12);
        assert!((normalized[[1, 0]] - 0.5).abs() < 1e-12);
        assert!((normalized[[1, 2]] - 3.0 / 12.0_f64.sqrt()).abs() < 1e-12);
        assert_eq!(normalized[[0, 2]], 0.0);
    }

    #[test]
    fn test_unknown_methods_are_rejected() {
        let graph = Array2::<f64>::zeros((2, 2));
        assert!(SymmetrizeTransform::new("bogus".to_string())
            .transform(&graph)
            .is_err());
        assert!(NormalizeTransform::new("bogus".to_string())
            .transform(&graph)
            .is_err());
    }

    #[test]
    fn test_sparsify_transform() {
        let mut graph = Array2::<f64>::zeros((3, 3));
        graph[[0, 1]] = 0.1;
        graph[[0, 2]] = 0.5;
        graph[[1, 2]] = 0.3;
        graph[[2, 0]] = 0.2;

        let transform = SparsifyTransform::new(0.2);
        let sparse = transform.transform(&graph).unwrap();

        assert_eq!(sparse[[0, 1]], 0.0);
        assert_eq!(sparse[[0, 2]], 0.5);
        assert_eq!(sparse[[1, 2]], 0.3);
        // Weights equal to the threshold are kept.
        assert_eq!(sparse[[2, 0]], 0.2);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_graph_pipeline() {
        let X = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0]];

        let pipeline = GraphPipeline::new(KNNGraphBuilder::new(2).weighted(true))
            .add_transform(SymmetrizeTransform::new("average".to_string()))
            .add_transform(NormalizeTransform::new("row".to_string()));

        let graph = pipeline.build(&X.view()).unwrap();

        assert_eq!(graph.dim(), (4, 4));

        for i in 0..4 {
            // Rows are stochastic unless the sample ended up isolated.
            let row_sum: f64 = graph.row(i).sum();
            assert!((row_sum - 1.0).abs() < 1e-6 || row_sum == 0.0);
        }

        let total_edges: usize = graph.iter().filter(|&&x| x > 0.0).count();
        assert!(total_edges > 0);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_pipeline_clone_debug_and_errors() {
        let X = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]];

        let pipeline = GraphPipeline::new(KNNGraphBuilder::new(1))
            .add_transform(SymmetrizeTransform::new("max".to_string()))
            .add_transform(SparsifyTransform::new(0.0));
        let cloned = pipeline.clone();

        assert_eq!(
            pipeline.build(&X.view()).unwrap(),
            cloned.build(&X.view()).unwrap()
        );
        assert!(format!("{:?}", pipeline).contains("2 transforms"));

        // A failing transform aborts the whole pipeline.
        let failing = GraphPipeline::new(KNNGraphBuilder::new(1))
            .add_transform(NormalizeTransform::new("bogus".to_string()));
        assert!(failing.build(&X.view()).is_err());
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_symmetric_pipeline() {
        let X = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]];

        let pipeline = GraphPipeline::new(KNNGraphBuilder::new(1).weighted(true))
            .add_transform(SymmetrizeTransform::new("max".to_string()));

        let graph = pipeline.build(&X.view()).unwrap();

        assert_eq!(graph.dim(), (3, 3));

        for i in 0..3 {
            for j in 0..3 {
                assert!((graph[[i, j]] - graph[[j, i]]).abs() < 1e-10);
            }
        }
    }
}