1use super::core::Embedding;
13use crate::base::{EdgeWeight, Graph, Node};
14use crate::error::{GraphError, Result};
15use crate::spectral::LaplacianType;
16use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
17use scirs2_core::random::Rng;
18use scirs2_core::simd_ops::SimdUnifiedOps;
19use std::collections::HashMap;
20
/// Configuration for [`SpectralEmbedding`].
#[derive(Debug, Clone)]
pub struct SpectralEmbeddingConfig {
    /// Number of columns in the resulting embedding (one per retained eigenvector).
    pub dimensions: usize,
    /// Which graph Laplacian the eigenvectors are taken from.
    pub laplacian_type: SpectralLaplacianType,
    /// Power iteration for an eigenvector stops once the change of its
    /// eigenvalue estimate between two iterations drops below this value.
    pub tolerance: f64,
    /// Upper bound on power-iteration steps spent on each eigenvector.
    pub max_iterations: usize,
    /// Rescale every embedding row to unit Euclidean length. The
    /// `NormalizedNgJordanWeiss` Laplacian type always does this, whatever this flag says.
    pub normalize: bool,
    /// Skip the eigenvector with the smallest eigenvalue; one extra eigenvector
    /// is computed so the embedding still has `dimensions` columns.
    pub drop_first: bool,
}
37
38impl Default for SpectralEmbeddingConfig {
39 fn default() -> Self {
40 SpectralEmbeddingConfig {
41 dimensions: 2,
42 laplacian_type: SpectralLaplacianType::NormalizedNgJordanWeiss,
43 tolerance: 1e-8,
44 max_iterations: 300,
45 normalize: true,
46 drop_first: true,
47 }
48 }
49}
50
/// Selects the Laplacian whose smallest eigenvectors form the embedding.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SpectralLaplacianType {
    /// The crate's standard Laplacian (`LaplacianType::Standard`; presumably
    /// `L = D - A` — confirm in `crate::spectral`).
    Unnormalized,
    /// The crate's normalized Laplacian (`LaplacianType::Normalized`).
    Normalized,
    /// The crate's random-walk Laplacian (`LaplacianType::RandomWalk`).
    RandomWalk,
    /// The normalized Laplacian, followed by rescaling every embedding row to
    /// unit Euclidean length (Ng–Jordan–Weiss style), independent of `normalize`.
    NormalizedNgJordanWeiss,
}
67
/// Spectral embedding of a graph's nodes into a low-dimensional Euclidean space
/// built from the smallest eigenvectors of a graph Laplacian.
///
/// Construct with `SpectralEmbedding::new`, then call `fit`; the accessors
/// return an error until a fit has succeeded.
pub struct SpectralEmbedding<N: Node> {
    /// Settings used by `fit`.
    config: SpectralEmbeddingConfig,
    /// Maps each node to its row in `embedding_matrix`.
    node_to_idx: HashMap<N, usize>,
    /// Maps each row of `embedding_matrix` back to its node.
    idx_to_node: Vec<N>,
    /// Row-per-node embedding (`nodes x dimensions`); `None` until fitted.
    embedding_matrix: Option<Array2<f64>>,
    /// Eigenvalues belonging to the embedding columns, in column order; `None` until fitted.
    eigenvalues: Option<Array1<f64>>,
}
84
85impl<N: Node + std::fmt::Debug> SpectralEmbedding<N> {
86 pub fn new(config: SpectralEmbeddingConfig) -> Self {
88 SpectralEmbedding {
89 config,
90 node_to_idx: HashMap::new(),
91 idx_to_node: Vec::new(),
92 embedding_matrix: None,
93 eigenvalues: None,
94 }
95 }
96
97 pub fn fit<E, Ix>(&mut self, graph: &Graph<N, E, Ix>) -> Result<()>
99 where
100 N: Clone,
101 E: EdgeWeight + Into<f64> + scirs2_core::numeric::Zero + scirs2_core::numeric::One + Copy,
102 Ix: petgraph::graph::IndexType,
103 {
104 let n = graph.node_count();
105 if n == 0 {
106 return Err(GraphError::InvalidGraph(
107 "Cannot compute spectral embedding for empty graph".to_string(),
108 ));
109 }
110
111 let needed_dims = if self.config.drop_first {
112 self.config.dimensions + 1
113 } else {
114 self.config.dimensions
115 };
116
117 if needed_dims > n {
118 return Err(GraphError::InvalidParameter {
119 param: "dimensions".to_string(),
120 value: self.config.dimensions.to_string(),
121 expected: format!(
122 "at most {} (number of nodes{})",
123 if self.config.drop_first { n - 1 } else { n },
124 if self.config.drop_first { " - 1" } else { "" }
125 ),
126 context: "Spectral embedding requires dimensions <= number of nodes".to_string(),
127 });
128 }
129
130 let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
132 self.node_to_idx.clear();
133 self.idx_to_node = nodes.clone();
134 for (i, node) in nodes.iter().enumerate() {
135 self.node_to_idx.insert(node.clone(), i);
136 }
137
138 let lap_type = match self.config.laplacian_type {
140 SpectralLaplacianType::Unnormalized => LaplacianType::Standard,
141 SpectralLaplacianType::Normalized | SpectralLaplacianType::NormalizedNgJordanWeiss => {
142 LaplacianType::Normalized
143 }
144 SpectralLaplacianType::RandomWalk => LaplacianType::RandomWalk,
145 };
146
147 let laplacian = crate::spectral::laplacian(graph, lap_type)?;
148
149 let (eigenvalues, eigenvectors) =
151 self.compute_smallest_eigenvectors(&laplacian, needed_dims)?;
152
153 let start_idx = if self.config.drop_first { 1 } else { 0 };
155 let end_idx = start_idx + self.config.dimensions;
156
157 let mut embedding = Array2::zeros((n, self.config.dimensions));
158 for i in 0..n {
159 for (j, col_idx) in (start_idx..end_idx).enumerate() {
160 embedding[[i, j]] = eigenvectors[[i, col_idx]];
161 }
162 }
163
164 if self.config.laplacian_type == SpectralLaplacianType::NormalizedNgJordanWeiss {
166 for i in 0..n {
167 let row = embedding.row(i);
168 let norm = if let Some(slice) = row.as_slice() {
169 let view = ArrayView1::from(slice);
170 f64::simd_norm(&view)
171 } else {
172 row.iter().map(|x| x * x).sum::<f64>().sqrt()
173 };
174 if norm > 1e-15 {
175 for j in 0..self.config.dimensions {
176 embedding[[i, j]] /= norm;
177 }
178 }
179 }
180 }
181
182 if self.config.normalize
184 && self.config.laplacian_type != SpectralLaplacianType::NormalizedNgJordanWeiss
185 {
186 for i in 0..n {
187 let row = embedding.row(i);
188 let norm = row.iter().map(|x| x * x).sum::<f64>().sqrt();
189 if norm > 1e-15 {
190 for j in 0..self.config.dimensions {
191 embedding[[i, j]] /= norm;
192 }
193 }
194 }
195 }
196
197 let selected_eigenvalues = Array1::from_vec(
199 eigenvalues
200 .iter()
201 .skip(start_idx)
202 .take(self.config.dimensions)
203 .copied()
204 .collect(),
205 );
206
207 self.embedding_matrix = Some(embedding);
208 self.eigenvalues = Some(selected_eigenvalues);
209
210 Ok(())
211 }
212
213 fn compute_smallest_eigenvectors(
218 &self,
219 matrix: &Array2<f64>,
220 k: usize,
221 ) -> Result<(Vec<f64>, Array2<f64>)> {
222 let n = matrix.nrows();
223 let mut eigenvalues = Vec::with_capacity(k);
224 let mut eigenvectors = Array2::zeros((n, k));
225 let mut rng = scirs2_core::random::rng();
226
227 let mut max_gershgorin = 0.0_f64;
234 for i in 0..n {
235 let diag = matrix[[i, i]];
236 let off_diag_sum: f64 = (0..n)
237 .filter(|&j| j != i)
238 .map(|j| matrix[[i, j]].abs())
239 .sum();
240 max_gershgorin = max_gershgorin.max(diag + off_diag_sum);
241 }
242
243 let sigma = max_gershgorin + 1.0;
246 let mut shifted = Array2::zeros((n, n));
247 for i in 0..n {
248 for j in 0..n {
249 shifted[[i, j]] = -matrix[[i, j]];
250 }
251 shifted[[i, i]] += sigma;
252 }
253
254 let mut deflation_vectors: Vec<Array1<f64>> = Vec::new();
256
257 for kk in 0..k {
258 let mut v = Array1::from_vec(
260 (0..n)
261 .map(|_| rng.random::<f64>() - 0.5)
262 .collect::<Vec<f64>>(),
263 );
264
265 let norm = v.iter().map(|x| x * x).sum::<f64>().sqrt();
267 if norm > 1e-15 {
268 v.mapv_inplace(|x| x / norm);
269 }
270
271 let mut eigenvalue = 0.0;
272
273 for iter in 0..self.config.max_iterations {
274 let mut w = Array1::zeros(n);
276 for i in 0..n {
277 let row = shifted.row(i);
278 w[i] = if let (Some(row_s), Some(v_s)) = (row.as_slice(), v.as_slice()) {
279 let rv = ArrayView1::from(row_s);
280 let vv = ArrayView1::from(v_s);
281 f64::simd_dot(&rv, &vv)
282 } else {
283 row.dot(&v)
284 };
285 }
286
287 for prev_v in &deflation_vectors {
289 let proj = w.dot(prev_v);
290 for i in 0..n {
291 w[i] -= proj * prev_v[i];
292 }
293 }
294
295 let new_eigenvalue = v.dot(&w);
297
298 let w_norm = w.iter().map(|x| x * x).sum::<f64>().sqrt();
300 if w_norm < 1e-15 {
301 break;
302 }
303 w.mapv_inplace(|x| x / w_norm);
304
305 if iter > 0 && (new_eigenvalue - eigenvalue).abs() < self.config.tolerance {
307 eigenvalue = new_eigenvalue;
308 v = w;
309 break;
310 }
311
312 eigenvalue = new_eigenvalue;
313 v = w;
314 }
315
316 let actual_eigenvalue = sigma - eigenvalue;
318 eigenvalues.push(actual_eigenvalue);
319
320 for i in 0..n {
322 eigenvectors[[i, kk]] = v[i];
323 }
324
325 deflation_vectors.push(v);
326 }
327
328 let mut indices: Vec<usize> = (0..k).collect();
330 indices.sort_by(|&a, &b| {
331 eigenvalues[a]
332 .partial_cmp(&eigenvalues[b])
333 .unwrap_or(std::cmp::Ordering::Equal)
334 });
335
336 let sorted_eigenvalues: Vec<f64> = indices.iter().map(|&i| eigenvalues[i]).collect();
337 let mut sorted_eigenvectors = Array2::zeros((n, k));
338 for (new_col, &old_col) in indices.iter().enumerate() {
339 for i in 0..n {
340 sorted_eigenvectors[[i, new_col]] = eigenvectors[[i, old_col]];
341 }
342 }
343
344 Ok((sorted_eigenvalues, sorted_eigenvectors))
345 }
346
347 pub fn get_embedding(&self, node: &N) -> Result<Embedding> {
349 let idx = self
350 .node_to_idx
351 .get(node)
352 .ok_or_else(|| GraphError::node_not_found(format!("{node:?}")))?;
353
354 let matrix = self.embedding_matrix.as_ref().ok_or_else(|| {
355 GraphError::AlgorithmError("Spectral embedding not computed yet".to_string())
356 })?;
357
358 let row = matrix.row(*idx);
359 Ok(Embedding {
360 vector: row.to_vec(),
361 })
362 }
363
364 pub fn embeddings(&self) -> Result<HashMap<N, Embedding>>
366 where
367 N: Clone,
368 {
369 let matrix = self.embedding_matrix.as_ref().ok_or_else(|| {
370 GraphError::AlgorithmError("Spectral embedding not computed yet".to_string())
371 })?;
372
373 let mut result = HashMap::new();
374 for (i, node) in self.idx_to_node.iter().enumerate() {
375 let row = matrix.row(i);
376 result.insert(
377 node.clone(),
378 Embedding {
379 vector: row.to_vec(),
380 },
381 );
382 }
383
384 Ok(result)
385 }
386
387 pub fn embedding_matrix(&self) -> Result<&Array2<f64>> {
389 self.embedding_matrix.as_ref().ok_or_else(|| {
390 GraphError::AlgorithmError("Spectral embedding not computed yet".to_string())
391 })
392 }
393
394 pub fn eigenvalues(&self) -> Result<&Array1<f64>> {
396 self.eigenvalues.as_ref().ok_or_else(|| {
397 GraphError::AlgorithmError("Spectral embedding not computed yet".to_string())
398 })
399 }
400
401 pub fn dimensions(&self) -> usize {
403 self.config.dimensions
404 }
405
406 pub fn pairwise_distances(&self) -> Result<Array2<f64>> {
408 let matrix = self.embedding_matrix.as_ref().ok_or_else(|| {
409 GraphError::AlgorithmError("Spectral embedding not computed yet".to_string())
410 })?;
411
412 let n = matrix.nrows();
413 let d = matrix.ncols();
414 let mut distances = Array2::zeros((n, n));
415
416 for i in 0..n {
417 for j in (i + 1)..n {
418 let mut dist_sq = 0.0;
419 for k in 0..d {
420 let diff = matrix[[i, k]] - matrix[[j, k]];
421 dist_sq += diff * diff;
422 }
423 let dist = dist_sq.sqrt();
424 distances[[i, j]] = dist;
425 distances[[j, i]] = dist;
426 }
427 }
428
429 Ok(distances)
430 }
431
432 pub fn compute_stress<E, Ix>(&self, graph: &Graph<N, E, Ix>) -> Result<f64>
435 where
436 N: Clone,
437 E: EdgeWeight + Into<f64> + Clone,
438 Ix: petgraph::graph::IndexType,
439 {
440 let matrix = self.embedding_matrix.as_ref().ok_or_else(|| {
441 GraphError::AlgorithmError("Spectral embedding not computed yet".to_string())
442 })?;
443
444 let n = matrix.nrows();
445 let d = matrix.ncols();
446 let mut stress_num = 0.0;
447 let mut stress_den = 0.0;
448
449 let edges = graph.edges();
450 for edge in &edges {
451 let i = self
452 .node_to_idx
453 .get(&edge.source)
454 .copied()
455 .ok_or_else(|| GraphError::node_not_found("source"))?;
456 let j = self
457 .node_to_idx
458 .get(&edge.target)
459 .copied()
460 .ok_or_else(|| GraphError::node_not_found("target"))?;
461
462 let graph_dist: f64 = edge.weight.clone().into();
463 let graph_dist = if graph_dist > 0.0 {
464 1.0 / graph_dist
465 } else {
466 1.0
467 };
468
469 let mut emb_dist_sq = 0.0;
470 for k in 0..d {
471 let diff = matrix[[i, k]] - matrix[[j, k]];
472 emb_dist_sq += diff * diff;
473 }
474 let emb_dist = emb_dist_sq.sqrt();
475
476 stress_num += (graph_dist - emb_dist).powi(2);
477 stress_den += graph_dist.powi(2);
478 }
479
480 if stress_den > 0.0 {
481 Ok(stress_num / stress_den)
482 } else {
483 Ok(0.0)
484 }
485 }
486}
487
#[cfg(test)]
mod tests {
    use super::*;

    /// Path graph `0 - 1 - 2 - 3` with unit edge weights.
    fn make_path_graph() -> Graph<i32, f64> {
        let mut g = Graph::new();
        for i in 0..4 {
            g.add_node(i);
        }
        let _ = g.add_edge(0, 1, 1.0);
        let _ = g.add_edge(1, 2, 1.0);
        let _ = g.add_edge(2, 3, 1.0);
        g
    }

    /// Complete graph K4 on nodes `0..4` with unit edge weights.
    fn make_complete_graph() -> Graph<i32, f64> {
        let mut g = Graph::new();
        for i in 0..4 {
            g.add_node(i);
        }
        for i in 0..4 {
            for j in (i + 1)..4 {
                let _ = g.add_edge(i, j, 1.0);
            }
        }
        g
    }

    /// Two K4 cliques (`0..4` and `4..8`) joined by the single bridge edge `3 - 4`.
    fn make_two_community_graph() -> Graph<i32, f64> {
        let mut g = Graph::new();
        for i in 0..8 {
            g.add_node(i);
        }
        for i in 0..4 {
            for j in (i + 1)..4 {
                let _ = g.add_edge(i, j, 1.0);
            }
        }
        for i in 4..8 {
            for j in (i + 1)..8 {
                let _ = g.add_edge(i, j, 1.0);
            }
        }
        let _ = g.add_edge(3, 4, 1.0);
        g
    }

    #[test]
    fn test_spectral_embedding_basic() {
        let g = make_path_graph();
        let config = SpectralEmbeddingConfig {
            dimensions: 2,
            laplacian_type: SpectralLaplacianType::Unnormalized,
            tolerance: 1e-6,
            max_iterations: 200,
            normalize: false,
            drop_first: true,
        };

        let mut se = SpectralEmbedding::new(config);
        let result = se.fit(&g);
        assert!(
            result.is_ok(),
            "Spectral embedding should succeed: {:?}",
            result.err()
        );

        for node in 0..4 {
            let emb = se.get_embedding(&node);
            assert!(emb.is_ok(), "Node {node} should have embedding");
            let emb = emb.expect("embedding should be valid");
            assert_eq!(emb.vector.len(), 2);
        }
    }

    #[test]
    fn test_spectral_embedding_eigenvalues() {
        let g = make_complete_graph();
        let config = SpectralEmbeddingConfig {
            dimensions: 2,
            laplacian_type: SpectralLaplacianType::Unnormalized,
            tolerance: 1e-8,
            max_iterations: 300,
            normalize: false,
            drop_first: true,
        };

        let mut se = SpectralEmbedding::new(config);
        se.fit(&g).expect("fit should succeed on K4");

        let eigenvalues = se
            .eigenvalues()
            .expect("eigenvalues should be available after fit");
        assert_eq!(eigenvalues.len(), 2);

        // The Laplacian of K4 has spectrum {0, 4, 4, 4}; dropping the first
        // eigenvalue leaves only the repeated eigenvalue 4.
        for &val in eigenvalues.iter() {
            assert!(
                val > 0.0,
                "Non-trivial eigenvalues of K4 should be positive, got {val}"
            );
            assert!(
                (val - 4.0).abs() < 1e-4,
                "Non-trivial eigenvalues of K4 should equal 4, got {val}"
            );
        }
    }

    #[test]
    fn test_spectral_embedding_normalized() {
        let g = make_path_graph();
        let config = SpectralEmbeddingConfig {
            dimensions: 2,
            laplacian_type: SpectralLaplacianType::Normalized,
            normalize: true,
            ..Default::default()
        };

        let mut se = SpectralEmbedding::new(config);
        let result = se.fit(&g);
        assert!(result.is_ok(), "fit should succeed: {:?}", result.err());

        for node in 0..4 {
            let emb = se.get_embedding(&node);
            assert!(emb.is_ok());
            let emb = emb.expect("embedding should be valid");
            let norm: f64 = emb.vector.iter().map(|x| x * x).sum::<f64>().sqrt();
            assert!(
                (norm - 1.0).abs() < 0.1 || norm < 0.01,
                "Normalized embedding norm should be close to 1.0 or near zero, got {norm}"
            );
        }
    }

    #[test]
    fn test_spectral_embedding_two_communities() {
        let g = make_two_community_graph();
        let config = SpectralEmbeddingConfig {
            dimensions: 2,
            laplacian_type: SpectralLaplacianType::Unnormalized,
            tolerance: 1e-8,
            max_iterations: 500,
            normalize: false,
            drop_first: true,
        };

        let mut se = SpectralEmbedding::new(config);
        let result = se.fit(&g);
        assert!(
            result.is_ok(),
            "Should succeed for two-community graph: {:?}",
            result.err()
        );

        for i in 0..8 {
            let emb = se.get_embedding(&i);
            assert!(emb.is_ok(), "Node {i} should have an embedding");
        }

        let eigenvalues = se.eigenvalues();
        assert!(eigenvalues.is_ok());
        let eigenvalues = eigenvalues.expect("eigenvalues should be valid");
        assert!(
            eigenvalues.len() == 2,
            "Should have 2 eigenvalues, got {}",
            eigenvalues.len()
        );

        // The first retained coordinate is the Fiedler vector. Its eigenvalue
        // is simple, and its sign splits the two cliques, bridge nodes 3 and 4
        // included.
        let fiedler = |node: i32| {
            se.get_embedding(&node)
                .expect("embedding should be valid")
                .vector[0]
        };
        let left_sign = fiedler(0).signum();
        for node in 0..8 {
            let value = fiedler(node);
            let expected_sign = if node < 4 { left_sign } else { -left_sign };
            assert!(
                value.abs() > 1e-6,
                "Fiedler coordinate of node {node} should be non-zero, got {value}"
            );
            assert_eq!(
                value.signum(),
                expected_sign,
                "Node {node} is on the wrong side of the Fiedler cut (value {value})"
            );
        }

        let distances = se.pairwise_distances();
        assert!(distances.is_ok());
        let distances = distances.expect("distances should be valid");

        let mut within_sum = 0.0;
        let mut within_count = 0;
        let mut between_sum = 0.0;
        let mut between_count = 0;

        for i in 0..8 {
            for j in (i + 1)..8 {
                let d = distances[[i, j]];
                if (i < 4 && j < 4) || (i >= 4 && j >= 4) {
                    within_sum += d;
                    within_count += 1;
                } else {
                    between_sum += d;
                    between_count += 1;
                }
            }
        }

        let avg_within = if within_count > 0 {
            within_sum / within_count as f64
        } else {
            0.0
        };
        let avg_between = if between_count > 0 {
            between_sum / between_count as f64
        } else {
            0.0
        };

        assert!(
            avg_within.is_finite(),
            "Within-community distance should be finite"
        );
        assert!(
            avg_between.is_finite(),
            "Between-community distance should be finite"
        );
    }

    #[test]
    fn test_spectral_embedding_empty_graph_error() {
        let g: Graph<i32, f64> = Graph::new();
        let config = SpectralEmbeddingConfig::default();

        let mut se = SpectralEmbedding::new(config);
        let result = se.fit(&g);
        assert!(result.is_err(), "Should fail for empty graph");
    }

    #[test]
    fn test_spectral_embedding_too_many_dims_error() {
        let g = make_path_graph();
        let config = SpectralEmbeddingConfig {
            dimensions: 10,
            drop_first: true,
            ..Default::default()
        };

        let mut se = SpectralEmbedding::new(config);
        let result = se.fit(&g);
        assert!(result.is_err(), "Should fail when dimensions > nodes");
    }

    #[test]
    fn test_spectral_embedding_accessors_before_fit_error() {
        let se: SpectralEmbedding<i32> =
            SpectralEmbedding::new(SpectralEmbeddingConfig::default());

        assert!(se.embedding_matrix().is_err());
        assert!(se.eigenvalues().is_err());
        assert!(se.embeddings().is_err());
        assert!(se.pairwise_distances().is_err());
        assert!(se.get_embedding(&0).is_err());
    }

    #[test]
    fn test_spectral_embedding_pairwise_distances() {
        let g = make_path_graph();
        let config = SpectralEmbeddingConfig {
            dimensions: 2,
            normalize: false,
            drop_first: true,
            ..Default::default()
        };

        let mut se = SpectralEmbedding::new(config);
        se.fit(&g).expect("fit should succeed on the path graph");

        let distances = se.pairwise_distances();
        assert!(distances.is_ok());
        let distances = distances.expect("distances should be valid");

        for i in 0..4 {
            assert!(
                distances[[i, i]].abs() < 1e-10,
                "Self-distance should be zero"
            );
        }

        for i in 0..4 {
            for j in 0..4 {
                assert!(
                    (distances[[i, j]] - distances[[j, i]]).abs() < 1e-10,
                    "Distance matrix should be symmetric"
                );
            }
        }
    }

    #[test]
    fn test_spectral_embedding_stress() {
        let g = make_path_graph();
        let config = SpectralEmbeddingConfig {
            dimensions: 2,
            normalize: false,
            drop_first: true,
            ..Default::default()
        };

        let mut se = SpectralEmbedding::new(config);
        se.fit(&g).expect("fit should succeed on the path graph");

        let stress = se.compute_stress(&g);
        assert!(stress.is_ok());
        let stress = stress.expect("stress should be valid");
        assert!(stress.is_finite(), "Stress should be finite, got {stress}");
        assert!(stress >= 0.0, "Stress should be non-negative, got {stress}");
    }

    #[test]
    fn test_spectral_embedding_random_walk_laplacian() {
        let g = make_path_graph();
        let config = SpectralEmbeddingConfig {
            dimensions: 2,
            laplacian_type: SpectralLaplacianType::RandomWalk,
            tolerance: 1e-6,
            max_iterations: 200,
            normalize: false,
            drop_first: true,
        };

        let mut se = SpectralEmbedding::new(config);
        let result = se.fit(&g);
        assert!(
            result.is_ok(),
            "Random walk spectral embedding should succeed"
        );

        let embs = se.embeddings();
        assert!(embs.is_ok());
        assert_eq!(embs.expect("embeddings should be valid").len(), 4);
    }
}
816}