1use std::collections::HashMap;
77use std::fmt::Debug;
78
79#[non_exhaustive]
81#[derive(Debug, Clone, PartialEq, thiserror::Error)]
82pub enum SheafError {
83 #[error("Node {0} not found")]
85 NodeNotFound(usize),
86 #[error("Edge ({0}, {1}) not found")]
88 EdgeNotFound(usize, usize),
89 #[error("Dimension mismatch: expected {expected}, got {actual}")]
91 DimensionMismatch {
92 expected: usize,
94 actual: usize,
96 },
97 #[error("Invalid restriction: {0}")]
99 InvalidRestriction(String),
100}
101
102pub trait RestrictionMap: Clone + Debug {
110 type Scalar: Clone + Debug;
112 type Vector: Clone + Debug;
114
115 fn in_dim(&self) -> usize;
117
118 fn out_dim(&self) -> usize;
120
121 fn apply(&self, x: &Self::Vector) -> Result<Self::Vector, SheafError>;
123
124 fn apply_transpose(&self, x: &Self::Vector) -> Result<Self::Vector, SheafError>;
127
128 fn as_matrix(&self) -> Vec<Vec<Self::Scalar>>;
130
131 fn frobenius_norm(&self) -> Self::Scalar;
133}
134
135pub trait Stalk: Clone + Debug {
139 type Scalar: Clone + Debug;
141 type Vector: Clone + Debug;
143
144 fn dim(&self) -> usize;
146
147 fn value(&self) -> &Self::Vector;
149
150 fn set_value(&mut self, v: Self::Vector) -> Result<(), SheafError>;
152
153 fn zero(&self) -> Self::Vector;
155}
156
157#[derive(Debug, Clone)]
159pub struct SheafEdge<R: RestrictionMap> {
160 pub source: usize,
162 pub target: usize,
164 pub restriction_source: R,
166 pub restriction_target: R,
168 pub weight: f32,
170}
171
172pub trait SheafGraph: Debug {
177 type Scalar: Clone + Debug + Default;
179 type Vector: Clone + Debug;
181 type Restriction: RestrictionMap<Scalar = Self::Scalar, Vector = Self::Vector>;
183 type Stalk: Stalk<Scalar = Self::Scalar, Vector = Self::Vector>;
185
186 fn num_nodes(&self) -> usize;
188
189 fn num_edges(&self) -> usize;
191
192 fn stalk(&self, node: usize) -> Result<&Self::Stalk, SheafError>;
194
195 fn stalk_mut(&mut self, node: usize) -> Result<&mut Self::Stalk, SheafError>;
197
198 fn edge(
200 &self,
201 source: usize,
202 target: usize,
203 ) -> Result<&SheafEdge<Self::Restriction>, SheafError>;
204
205 fn edges(&self) -> impl Iterator<Item = &SheafEdge<Self::Restriction>>;
207
208 fn neighbors(&self, node: usize) -> Result<Vec<usize>, SheafError>;
210
211 fn dirichlet_energy(&self) -> Result<Self::Scalar, SheafError>;
217
218 fn laplacian_at(&self, node: usize) -> Result<Self::Vector, SheafError>;
222
223 fn diffusion_step(&mut self, step_size: Self::Scalar) -> Result<(), SheafError>;
229}
230
/// Which flavour of sheaf Laplacian a diffusion run is configured for.
///
/// NOTE(review): no code in this file branches on this value, so the exact
/// meaning of each variant is not established here — confirm with its consumer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LaplacianType {
    /// Presumably orthogonal ("connection") restriction maps — TODO confirm.
    Connection,
    /// Presumably unrestricted linear restriction maps — TODO confirm.
    /// This is the variant chosen by `DiffusionConfig::default()`.
    General,
    /// Presumably diagonal restriction maps — TODO confirm.
    Diagonal,
}

/// Parameters for an iterative sheaf-diffusion run.
#[derive(Debug, Clone)]
pub struct DiffusionConfig {
    /// Maximum number of diffusion steps.
    pub num_steps: usize,
    /// Step size handed to each diffusion step.
    pub step_size: f32,
    /// NOTE(review): not read by `diffuse_until_convergence`; intended use unconfirmed.
    pub normalize: bool,
    /// NOTE(review): not read by `diffuse_until_convergence`; intended use unconfirmed.
    pub laplacian_type: LaplacianType,
}

impl Default for DiffusionConfig {
    /// Five steps of size `0.1`, `normalize = true`, and `LaplacianType::General`.
    fn default() -> Self {
        DiffusionConfig {
            laplacian_type: LaplacianType::General,
            normalize: true,
            step_size: 0.1,
            num_steps: 5,
        }
    }
}
274
/// A dense `rows x cols` restriction matrix stored in row-major order, so
/// entry `(i, j)` lives at `data[i * cols + j]`.
///
/// The fields are public; `DenseRestriction::new` is the constructor that
/// checks `data.len() == rows * cols`.
#[derive(Debug, Clone)]
pub struct DenseRestriction {
    /// Matrix entries, row after row.
    pub data: Vec<f32>,
    /// Number of rows (the output dimension).
    pub rows: usize,
    /// Number of columns (the input dimension).
    pub cols: usize,
}
289
290impl DenseRestriction {
291 pub fn new(data: Vec<f32>, rows: usize, cols: usize) -> Result<Self, SheafError> {
293 if data.len() != rows * cols {
294 return Err(SheafError::DimensionMismatch {
295 expected: rows * cols,
296 actual: data.len(),
297 });
298 }
299 Ok(Self { data, rows, cols })
300 }
301
302 pub fn identity(dim: usize) -> Self {
304 let mut data = vec![0.0; dim * dim];
305 for i in 0..dim {
306 data[i * dim + i] = 1.0;
307 }
308 Self {
309 data,
310 rows: dim,
311 cols: dim,
312 }
313 }
314
315 #[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
320 #[cfg(feature = "rand")]
321 #[allow(deprecated)]
322 pub fn random_orthogonal(dim: usize) -> Self {
323 use rand::Rng;
324 let mut rng = rand::thread_rng();
325
326 let mut data: Vec<f32> = (0..dim * dim).map(|_| rng.gen_range(-0.5..0.5)).collect();
328
329 for i in 0..dim {
331 let mut norm: f32 = 0.0;
333 for j in 0..dim {
334 norm += data[j * dim + i] * data[j * dim + i];
335 }
336 norm = norm.sqrt();
337 if norm > 1e-6 {
338 for j in 0..dim {
339 data[j * dim + i] /= norm;
340 }
341 }
342
343 for k in (i + 1)..dim {
345 let mut dot = 0.0;
346 for j in 0..dim {
347 dot += data[j * dim + i] * data[j * dim + k];
348 }
349 for j in 0..dim {
350 data[j * dim + k] -= dot * data[j * dim + i];
351 }
352 }
353 }
354
355 Self {
356 data,
357 rows: dim,
358 cols: dim,
359 }
360 }
361}
362
363impl RestrictionMap for DenseRestriction {
364 type Scalar = f32;
365 type Vector = Vec<f32>;
366
367 fn in_dim(&self) -> usize {
368 self.cols
369 }
370
371 fn out_dim(&self) -> usize {
372 self.rows
373 }
374
375 fn apply(&self, x: &Self::Vector) -> Result<Self::Vector, SheafError> {
376 if x.len() != self.cols {
377 return Err(SheafError::DimensionMismatch {
378 expected: self.cols,
379 actual: x.len(),
380 });
381 }
382
383 let mut result = vec![0.0; self.rows];
386 #[allow(clippy::needless_range_loop)]
387 for i in 0..self.rows {
388 for j in 0..self.cols {
389 result[i] += self.data[i * self.cols + j] * x[j];
390 }
391 }
392 Ok(result)
393 }
394
395 fn apply_transpose(&self, x: &Self::Vector) -> Result<Self::Vector, SheafError> {
396 if x.len() != self.rows {
397 return Err(SheafError::DimensionMismatch {
398 expected: self.rows,
399 actual: x.len(),
400 });
401 }
402
403 let mut result = vec![0.0; self.cols];
405 #[allow(clippy::needless_range_loop)]
406 for j in 0..self.cols {
407 for i in 0..self.rows {
408 result[j] += self.data[i * self.cols + j] * x[i];
409 }
410 }
411 Ok(result)
412 }
413
414 fn as_matrix(&self) -> Vec<Vec<Self::Scalar>> {
415 let mut matrix = vec![vec![0.0; self.cols]; self.rows];
417 #[allow(clippy::needless_range_loop)]
418 for i in 0..self.rows {
419 for j in 0..self.cols {
420 matrix[i][j] = self.data[i * self.cols + j];
421 }
422 }
423 matrix
424 }
425
426 fn frobenius_norm(&self) -> Self::Scalar {
427 self.data.iter().map(|x| x * x).sum::<f32>().sqrt()
428 }
429}
430
/// A stalk backed by a plain `Vec<f32>`; the vector's length is the stalk's
/// dimension.
#[derive(Debug, Clone)]
pub struct VecStalk {
    value: Vec<f32>,
}

impl VecStalk {
    /// Wraps `value`. Its length becomes the dimension that later
    /// `set_value` calls are checked against.
    pub fn new(value: Vec<f32>) -> Self {
        VecStalk { value }
    }
}
443
444impl Stalk for VecStalk {
445 type Scalar = f32;
446 type Vector = Vec<f32>;
447
448 fn dim(&self) -> usize {
449 self.value.len()
450 }
451
452 fn value(&self) -> &Self::Vector {
453 &self.value
454 }
455
456 fn set_value(&mut self, v: Self::Vector) -> Result<(), SheafError> {
457 if v.len() != self.value.len() {
458 return Err(SheafError::DimensionMismatch {
459 expected: self.value.len(),
460 actual: v.len(),
461 });
462 }
463 self.value = v;
464 Ok(())
465 }
466
467 fn zero(&self) -> Self::Vector {
468 vec![0.0; self.value.len()]
469 }
470}
471
472#[derive(Debug, Clone)]
474pub struct SimpleSheafGraph {
475 stalks: Vec<VecStalk>,
476 edges: Vec<SheafEdge<DenseRestriction>>,
477 adjacency: HashMap<usize, Vec<usize>>,
478}
479
480impl SimpleSheafGraph {
481 pub fn new() -> Self {
483 Self {
484 stalks: Vec::new(),
485 edges: Vec::new(),
486 adjacency: HashMap::new(),
487 }
488 }
489
490 pub fn add_node(&mut self, value: Vec<f32>) -> usize {
492 let id = self.stalks.len();
493 self.stalks.push(VecStalk::new(value));
494 self.adjacency.insert(id, Vec::new());
495 id
496 }
497
498 pub fn add_edge(
500 &mut self,
501 source: usize,
502 target: usize,
503 restriction_source: DenseRestriction,
504 restriction_target: DenseRestriction,
505 weight: f32,
506 ) -> Result<(), SheafError> {
507 if source >= self.stalks.len() {
508 return Err(SheafError::NodeNotFound(source));
509 }
510 if target >= self.stalks.len() {
511 return Err(SheafError::NodeNotFound(target));
512 }
513
514 if restriction_source.in_dim() != self.stalks[source].dim() {
516 return Err(SheafError::DimensionMismatch {
517 expected: self.stalks[source].dim(),
518 actual: restriction_source.in_dim(),
519 });
520 }
521 if restriction_target.in_dim() != self.stalks[target].dim() {
522 return Err(SheafError::DimensionMismatch {
523 expected: self.stalks[target].dim(),
524 actual: restriction_target.in_dim(),
525 });
526 }
527 if restriction_source.out_dim() != restriction_target.out_dim() {
528 return Err(SheafError::InvalidRestriction(
529 "Source and target restrictions must have same output dimension".into(),
530 ));
531 }
532
533 self.edges.push(SheafEdge {
534 source,
535 target,
536 restriction_source,
537 restriction_target,
538 weight,
539 });
540
541 self.adjacency.entry(source).or_default().push(target);
542 self.adjacency.entry(target).or_default().push(source);
543
544 Ok(())
545 }
546}
547
548impl Default for SimpleSheafGraph {
549 fn default() -> Self {
550 Self::new()
551 }
552}
553
554impl SheafGraph for SimpleSheafGraph {
555 type Scalar = f32;
556 type Vector = Vec<f32>;
557 type Restriction = DenseRestriction;
558 type Stalk = VecStalk;
559
560 fn num_nodes(&self) -> usize {
561 self.stalks.len()
562 }
563
564 fn num_edges(&self) -> usize {
565 self.edges.len()
566 }
567
568 fn stalk(&self, node: usize) -> Result<&Self::Stalk, SheafError> {
569 self.stalks.get(node).ok_or(SheafError::NodeNotFound(node))
570 }
571
572 fn stalk_mut(&mut self, node: usize) -> Result<&mut Self::Stalk, SheafError> {
573 self.stalks
574 .get_mut(node)
575 .ok_or(SheafError::NodeNotFound(node))
576 }
577
578 fn edge(
579 &self,
580 source: usize,
581 target: usize,
582 ) -> Result<&SheafEdge<Self::Restriction>, SheafError> {
583 self.edges
584 .iter()
585 .find(|e| {
586 (e.source == source && e.target == target)
587 || (e.source == target && e.target == source)
588 })
589 .ok_or(SheafError::EdgeNotFound(source, target))
590 }
591
592 fn edges(&self) -> impl Iterator<Item = &SheafEdge<Self::Restriction>> {
593 self.edges.iter()
594 }
595
596 fn neighbors(&self, node: usize) -> Result<Vec<usize>, SheafError> {
597 self.adjacency
598 .get(&node)
599 .cloned()
600 .ok_or(SheafError::NodeNotFound(node))
601 }
602
603 fn dirichlet_energy(&self) -> Result<Self::Scalar, SheafError> {
604 let mut energy = 0.0;
605
606 for edge in &self.edges {
607 let x_u = self.stalks[edge.source].value();
608 let x_v = self.stalks[edge.target].value();
609
610 let r_u = edge.restriction_source.apply(x_u)?;
611 let r_v = edge.restriction_target.apply(x_v)?;
612
613 let diff_sq: f32 = r_u
615 .iter()
616 .zip(r_v.iter())
617 .map(|(a, b)| (a - b) * (a - b))
618 .sum();
619
620 energy += edge.weight * diff_sq;
621 }
622
623 Ok(energy)
624 }
625
626 fn laplacian_at(&self, node: usize) -> Result<Self::Vector, SheafError> {
627 let stalk = self.stalk(node)?;
628 let mut result = stalk.zero();
629
630 for edge in &self.edges {
631 let (is_source, other) = if edge.source == node {
632 (true, edge.target)
633 } else if edge.target == node {
634 (false, edge.source)
635 } else {
636 continue;
637 };
638
639 let x_node = self.stalks[node].value();
640 let x_other = self.stalks[other].value();
641
642 let (r_node, r_other) = if is_source {
643 (&edge.restriction_source, &edge.restriction_target)
644 } else {
645 (&edge.restriction_target, &edge.restriction_source)
646 };
647
648 let r_x_node = r_node.apply(x_node)?;
650 let r_x_other = r_other.apply(x_other)?;
651
652 let diff: Vec<f32> = r_x_node
653 .iter()
654 .zip(r_x_other.iter())
655 .map(|(a, b)| a - b)
656 .collect();
657
658 let contrib = r_node.apply_transpose(&diff)?;
660
661 for (i, c) in contrib.iter().enumerate() {
663 result[i] += edge.weight * c;
664 }
665 }
666
667 Ok(result)
668 }
669
670 fn diffusion_step(&mut self, step_size: Self::Scalar) -> Result<(), SheafError> {
671 let laplacians: Vec<Vec<f32>> = (0..self.num_nodes())
673 .map(|i| self.laplacian_at(i))
674 .collect::<Result<_, _>>()?;
675
676 for (i, lap) in laplacians.into_iter().enumerate() {
678 let stalk = &mut self.stalks[i];
679 let new_value: Vec<f32> = stalk
680 .value()
681 .iter()
682 .zip(lap.iter())
683 .map(|(x, l)| x - step_size * l)
684 .collect();
685 stalk.set_value(new_value)?;
686 }
687
688 Ok(())
689 }
690}
691
692pub fn consistency_score(graph: &impl SheafGraph<Scalar = f32>) -> Result<f32, SheafError> {
700 let energy = graph.dirichlet_energy()?;
701 Ok((-energy).exp())
703}
704
705pub fn diffuse_until_convergence(
707 graph: &mut SimpleSheafGraph,
708 config: &DiffusionConfig,
709 tolerance: f32,
710) -> Result<usize, SheafError> {
711 let mut prev_energy = graph.dirichlet_energy()?;
712
713 for step in 0..config.num_steps {
714 graph.diffusion_step(config.step_size)?;
715 let energy = graph.dirichlet_energy()?;
716
717 if (prev_energy - energy).abs() < tolerance {
718 return Ok(step + 1);
719 }
720 prev_energy = energy;
721 }
722
723 Ok(config.num_steps)
724}
725
// Unit tests for restriction maps, stalks, graph construction, energy,
// Laplacian, and diffusion behaviour.
#[cfg(test)]
mod tests {
    use super::*;

    // --- DenseRestriction -------------------------------------------------

    #[test]
    fn test_identity_restriction() {
        let r = DenseRestriction::identity(3);
        let x = vec![1.0, 2.0, 3.0];
        let y = r.apply(&x).unwrap();
        assert_eq!(y, x);
    }

    #[test]
    fn test_restriction_transpose() {
        // Row-major 2x3 matrix [[1, 2, 3], [4, 5, 6]].
        let r = DenseRestriction::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3).unwrap();

        // A x maps a 3-vector to a 2-vector.
        let x = vec![1.0, 2.0, 3.0];
        let y = r.apply(&x).unwrap();
        assert_eq!(y.len(), 2);

        // Aᵀ [1, 1] sums the two rows: [5, 7, 9].
        let z = vec![1.0, 1.0];
        let w = r.apply_transpose(&z).unwrap();
        assert_eq!(w.len(), 3);
        assert_eq!(w, vec![5.0, 7.0, 9.0]);
    }

    // --- Graph basics -----------------------------------------------------

    #[test]
    fn test_simple_sheaf_graph() {
        let mut graph = SimpleSheafGraph::new();

        let n0 = graph.add_node(vec![1.0, 0.0]);
        let n1 = graph.add_node(vec![0.0, 1.0]);

        let r = DenseRestriction::identity(2);
        graph.add_edge(n0, n1, r.clone(), r.clone(), 1.0).unwrap();

        assert_eq!(graph.num_nodes(), 2);
        assert_eq!(graph.num_edges(), 1);

        // ‖(1, 0) − (0, 1)‖² = 2 with unit weight.
        let energy = graph.dirichlet_energy().unwrap();
        assert!((energy - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_diffusion_reduces_energy() {
        let mut graph = SimpleSheafGraph::new();

        // Path graph 0 - 1 - 2 with disagreeing signals.
        let n0 = graph.add_node(vec![1.0, 0.0]);
        let n1 = graph.add_node(vec![0.5, 0.5]);
        let n2 = graph.add_node(vec![0.0, 1.0]);

        let r = DenseRestriction::identity(2);
        graph.add_edge(n0, n1, r.clone(), r.clone(), 1.0).unwrap();
        graph.add_edge(n1, n2, r.clone(), r.clone(), 1.0).unwrap();

        let initial_energy = graph.dirichlet_energy().unwrap();

        for _ in 0..10 {
            graph.diffusion_step(0.1).unwrap();
        }

        let final_energy = graph.dirichlet_energy().unwrap();
        assert!(
            final_energy < initial_energy,
            "Diffusion should reduce energy"
        );
    }

    #[test]
    fn test_consistency_score() {
        let mut graph = SimpleSheafGraph::new();

        // Identical stalks: zero energy, so the score is exp(0) = 1.
        graph.add_node(vec![1.0, 2.0]);
        graph.add_node(vec![1.0, 2.0]);

        let r = DenseRestriction::identity(2);
        graph.add_edge(0, 1, r.clone(), r.clone(), 1.0).unwrap();

        let score = consistency_score(&graph).unwrap();
        assert!((score - 1.0).abs() < 1e-6);
    }

    // --- DenseRestriction edge cases --------------------------------------

    #[test]
    fn test_dense_restriction_new_dimension_mismatch() {
        // 2x2 needs 4 entries; only 3 supplied.
        let result = DenseRestriction::new(vec![1.0, 2.0, 3.0], 2, 2);
        assert!(matches!(
            result,
            Err(SheafError::DimensionMismatch {
                expected: 4,
                actual: 3
            })
        ));
    }

    #[test]
    fn test_dense_restriction_1x1() {
        // A 1x1 matrix is just scalar multiplication, and it is its own transpose.
        let r = DenseRestriction::new(vec![3.0], 1, 1).unwrap();
        let x = vec![2.0];
        let y = r.apply(&x).unwrap();
        assert_eq!(y, vec![6.0]);

        let yt = r.apply_transpose(&vec![2.0]).unwrap();
        assert_eq!(yt, vec![6.0]);
    }

    #[test]
    fn test_dense_restriction_apply_wrong_dim() {
        let r = DenseRestriction::identity(3);
        // Input length 2 for a map expecting 3 columns.
        let x = vec![1.0, 2.0];
        let result = r.apply(&x);
        assert!(matches!(
            result,
            Err(SheafError::DimensionMismatch {
                expected: 3,
                actual: 2
            })
        ));
    }

    #[test]
    fn test_dense_restriction_apply_transpose_wrong_dim() {
        let r = DenseRestriction::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3).unwrap();
        // The transpose of a 2x3 matrix needs a length-2 input, not 3.
        let x = vec![1.0, 2.0, 3.0];
        let result = r.apply_transpose(&x);
        assert!(matches!(
            result,
            Err(SheafError::DimensionMismatch {
                expected: 2,
                actual: 3
            })
        ));
    }

    #[test]
    fn test_dense_restriction_as_matrix() {
        let r = DenseRestriction::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3).unwrap();
        let m = r.as_matrix();
        assert_eq!(m.len(), 2);
        assert_eq!(m[0], vec![1.0, 2.0, 3.0]);
        assert_eq!(m[1], vec![4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_dense_restriction_frobenius_norm() {
        // sqrt(3² + 4²) = 5.
        let r = DenseRestriction::new(vec![3.0, 4.0], 1, 2).unwrap();
        let norm = r.frobenius_norm();
        assert!((norm - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_identity_restriction_is_identity() {
        let r = DenseRestriction::identity(4);
        assert_eq!(r.in_dim(), 4);
        assert_eq!(r.out_dim(), 4);
        let x = vec![1.0, 2.0, 3.0, 4.0];
        assert_eq!(r.apply(&x).unwrap(), x);
        assert_eq!(r.apply_transpose(&x).unwrap(), x);
    }

    // --- VecStalk ---------------------------------------------------------

    #[test]
    fn test_vec_stalk_set_value_dimension_mismatch() {
        let mut s = VecStalk::new(vec![1.0, 2.0]);
        let result = s.set_value(vec![1.0]);
        assert!(matches!(
            result,
            Err(SheafError::DimensionMismatch {
                expected: 2,
                actual: 1
            })
        ));
    }

    #[test]
    fn test_vec_stalk_zero() {
        let s = VecStalk::new(vec![5.0, 6.0, 7.0]);
        assert_eq!(s.zero(), vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_vec_stalk_roundtrip() {
        let mut s = VecStalk::new(vec![1.0, 2.0]);
        s.set_value(vec![3.0, 4.0]).unwrap();
        assert_eq!(s.value(), &vec![3.0, 4.0]);
        assert_eq!(s.dim(), 2);
    }

    // --- add_edge validation ----------------------------------------------

    #[test]
    fn test_add_edge_source_not_found() {
        let mut graph = SimpleSheafGraph::new();
        graph.add_node(vec![1.0]);
        let r = DenseRestriction::identity(1);
        let result = graph.add_edge(5, 0, r.clone(), r.clone(), 1.0);
        assert!(matches!(result, Err(SheafError::NodeNotFound(5))));
    }

    #[test]
    fn test_add_edge_target_not_found() {
        let mut graph = SimpleSheafGraph::new();
        graph.add_node(vec![1.0]);
        let r = DenseRestriction::identity(1);
        let result = graph.add_edge(0, 99, r.clone(), r.clone(), 1.0);
        assert!(matches!(result, Err(SheafError::NodeNotFound(99))));
    }

    #[test]
    fn test_add_edge_restriction_dim_mismatch_source() {
        let mut graph = SimpleSheafGraph::new();
        // Two 2-dim stalks; the source restriction expects 3-dim input.
        graph.add_node(vec![1.0, 2.0]);
        graph.add_node(vec![1.0, 2.0]);
        let r_wrong = DenseRestriction::identity(3);
        let r_ok = DenseRestriction::identity(2);
        let result = graph.add_edge(0, 1, r_wrong, r_ok, 1.0);
        assert!(matches!(result, Err(SheafError::DimensionMismatch { .. })));
    }

    #[test]
    fn test_add_edge_restriction_output_dim_mismatch() {
        let mut graph = SimpleSheafGraph::new();
        graph.add_node(vec![1.0, 2.0]);
        graph.add_node(vec![1.0, 2.0]);
        // Source maps into R^3, target into R^2: no common edge space.
        let r_src = DenseRestriction::new(vec![1.0; 6], 3, 2).unwrap();
        let r_tgt = DenseRestriction::identity(2);
        let result = graph.add_edge(0, 1, r_src, r_tgt, 1.0);
        assert!(matches!(result, Err(SheafError::InvalidRestriction(_))));
    }

    // --- Lookups ------------------------------------------------------------

    #[test]
    fn test_stalk_not_found() {
        let graph = SimpleSheafGraph::new();
        assert!(matches!(graph.stalk(0), Err(SheafError::NodeNotFound(0))));
    }

    #[test]
    fn test_edge_not_found() {
        let mut graph = SimpleSheafGraph::new();
        graph.add_node(vec![1.0]);
        graph.add_node(vec![1.0]);
        // Nodes exist but were never connected.
        assert!(matches!(
            graph.edge(0, 1),
            Err(SheafError::EdgeNotFound(0, 1))
        ));
    }

    #[test]
    fn test_neighbors_not_found() {
        let graph = SimpleSheafGraph::new();
        assert!(matches!(
            graph.neighbors(0),
            Err(SheafError::NodeNotFound(0))
        ));
    }

    #[test]
    fn test_edge_lookup_bidirectional() {
        let mut graph = SimpleSheafGraph::new();
        graph.add_node(vec![1.0]);
        graph.add_node(vec![2.0]);
        let r = DenseRestriction::identity(1);
        graph.add_edge(0, 1, r.clone(), r.clone(), 1.0).unwrap();

        // The edge is found regardless of argument order.
        assert!(graph.edge(0, 1).is_ok());
        assert!(graph.edge(1, 0).is_ok());
    }

    #[test]
    fn test_neighbors_bidirectional() {
        let mut graph = SimpleSheafGraph::new();
        graph.add_node(vec![1.0]);
        graph.add_node(vec![2.0]);
        graph.add_node(vec![3.0]);
        let r = DenseRestriction::identity(1);
        graph.add_edge(0, 1, r.clone(), r.clone(), 1.0).unwrap();
        graph.add_edge(1, 2, r.clone(), r.clone(), 1.0).unwrap();

        // The middle node sees both the incoming (0) and outgoing (2) neighbour.
        let n1 = graph.neighbors(1).unwrap();
        assert_eq!(n1.len(), 2);
    }

    // --- Energy and Laplacian ---------------------------------------------

    #[test]
    fn test_dirichlet_energy_zero_for_identical_stalks() {
        let mut graph = SimpleSheafGraph::new();
        graph.add_node(vec![1.0, 2.0, 3.0]);
        graph.add_node(vec![1.0, 2.0, 3.0]);
        graph.add_node(vec![1.0, 2.0, 3.0]);
        let r = DenseRestriction::identity(3);
        graph.add_edge(0, 1, r.clone(), r.clone(), 1.0).unwrap();
        graph.add_edge(1, 2, r.clone(), r.clone(), 1.0).unwrap();
        let energy = graph.dirichlet_energy().unwrap();
        assert!(
            (energy - 0.0).abs() < 1e-6,
            "identical stalks should have zero energy"
        );
    }

    #[test]
    fn test_dirichlet_energy_weighted() {
        let mut graph = SimpleSheafGraph::new();
        graph.add_node(vec![1.0, 0.0]);
        graph.add_node(vec![0.0, 1.0]);
        let r = DenseRestriction::identity(2);
        graph.add_edge(0, 1, r.clone(), r.clone(), 2.0).unwrap();
        // weight 2 × squared distance 2 = 4.
        let energy = graph.dirichlet_energy().unwrap();
        assert!((energy - 4.0).abs() < 1e-6);
    }

    #[test]
    fn test_laplacian_at_zero_for_consistent_signal() {
        let mut graph = SimpleSheafGraph::new();
        graph.add_node(vec![1.0, 2.0]);
        graph.add_node(vec![1.0, 2.0]);
        let r = DenseRestriction::identity(2);
        graph.add_edge(0, 1, r.clone(), r.clone(), 1.0).unwrap();

        let lap = graph.laplacian_at(0).unwrap();
        assert!(
            lap.iter().all(|&x| x.abs() < 1e-6),
            "Laplacian should be zero for consistent signal"
        );
    }

    #[test]
    fn test_laplacian_symmetry() {
        let mut graph = SimpleSheafGraph::new();
        // With identity restrictions the two endpoint Laplacians are negatives
        // of each other.
        graph.add_node(vec![1.0, 0.0]);
        graph.add_node(vec![0.0, 1.0]);
        let r = DenseRestriction::identity(2);
        graph.add_edge(0, 1, r.clone(), r.clone(), 1.0).unwrap();

        let lap0 = graph.laplacian_at(0).unwrap();
        let lap1 = graph.laplacian_at(1).unwrap();
        for i in 0..2 {
            assert!(
                (lap0[i] + lap1[i]).abs() < 1e-6,
                "Laplacian should sum to zero"
            );
        }
    }

    // --- diffuse_until_convergence ----------------------------------------

    #[test]
    fn test_diffuse_until_convergence_identical_stalks() {
        let mut graph = SimpleSheafGraph::new();
        graph.add_node(vec![1.0, 1.0]);
        graph.add_node(vec![1.0, 1.0]);
        let r = DenseRestriction::identity(2);
        graph.add_edge(0, 1, r.clone(), r.clone(), 1.0).unwrap();

        let config = DiffusionConfig {
            num_steps: 100,
            step_size: 0.1,
            ..Default::default()
        };

        // Energy is already zero, so the first step's energy change is below tolerance.
        let steps = diffuse_until_convergence(&mut graph, &config, 1e-8).unwrap();
        assert!(
            steps <= 2,
            "already-converged graph should converge immediately, took {steps}"
        );
    }

    #[test]
    fn test_diffuse_until_convergence_reaches_max_steps() {
        let mut graph = SimpleSheafGraph::new();
        // Large disagreement plus a small step: energy keeps changing for 3 steps.
        graph.add_node(vec![100.0, 0.0]);
        graph.add_node(vec![0.0, 100.0]);
        let r = DenseRestriction::identity(2);
        graph.add_edge(0, 1, r.clone(), r.clone(), 1.0).unwrap();

        let config = DiffusionConfig {
            num_steps: 3,
            step_size: 0.01,
            ..Default::default()
        };

        let steps = diffuse_until_convergence(&mut graph, &config, 1e-12).unwrap();
        assert_eq!(steps, 3, "should reach max steps");
    }

    #[test]
    fn test_consistency_score_decreases_with_distance() {
        let mut g1 = SimpleSheafGraph::new();
        // Nearly agreeing stalks.
        g1.add_node(vec![1.0, 0.0]);
        g1.add_node(vec![0.9, 0.1]);
        let r = DenseRestriction::identity(2);
        g1.add_edge(0, 1, r.clone(), r.clone(), 1.0).unwrap();
        let score1 = consistency_score(&g1).unwrap();

        // Orthogonal stalks.
        let mut g2 = SimpleSheafGraph::new();
        g2.add_node(vec![1.0, 0.0]);
        g2.add_node(vec![0.0, 1.0]);
        g2.add_edge(0, 1, r.clone(), r.clone(), 1.0).unwrap();
        let score2 = consistency_score(&g2).unwrap();

        assert!(
            score1 > score2,
            "closer stalks should have higher consistency"
        );
    }

    // --- Config, errors, defaults -----------------------------------------

    #[test]
    fn test_diffusion_config_default() {
        let config = DiffusionConfig::default();
        assert_eq!(config.num_steps, 5);
        assert!((config.step_size - 0.1).abs() < 1e-6);
        assert!(config.normalize);
        assert_eq!(config.laplacian_type, LaplacianType::General);
    }

    #[test]
    fn test_sheaf_error_display() {
        // Pins the exact user-facing messages produced by the `#[error]` attributes.
        assert_eq!(
            format!("{}", SheafError::NodeNotFound(5)),
            "Node 5 not found"
        );
        assert_eq!(
            format!("{}", SheafError::EdgeNotFound(1, 2)),
            "Edge (1, 2) not found"
        );
        assert_eq!(
            format!(
                "{}",
                SheafError::DimensionMismatch {
                    expected: 3,
                    actual: 2
                }
            ),
            "Dimension mismatch: expected 3, got 2"
        );
        assert!(format!("{}", SheafError::InvalidRestriction("bad".into())).contains("bad"));
    }

    #[test]
    fn test_simple_sheaf_graph_default() {
        let graph = SimpleSheafGraph::default();
        assert_eq!(graph.num_nodes(), 0);
        assert_eq!(graph.num_edges(), 0);
    }

    // --- Non-square restrictions ------------------------------------------

    #[test]
    fn test_non_square_restriction_maps() {
        let mut graph = SimpleSheafGraph::new();
        // 3-dim stalks projected onto their first two coordinates.
        graph.add_node(vec![1.0, 0.0, 0.0]);
        graph.add_node(vec![0.0, 1.0, 0.0]);

        let proj = DenseRestriction::new(vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0], 2, 3).unwrap();
        graph
            .add_edge(0, 1, proj.clone(), proj.clone(), 1.0)
            .unwrap();

        // Projections are (1, 0) and (0, 1): squared distance 2.
        let energy = graph.dirichlet_energy().unwrap();
        assert!((energy - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_diffusion_with_non_square_restrictions() {
        let mut graph = SimpleSheafGraph::new();
        graph.add_node(vec![1.0, 0.0, 0.0]);
        graph.add_node(vec![0.0, 1.0, 0.0]);

        let proj = DenseRestriction::new(vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0], 2, 3).unwrap();
        graph
            .add_edge(0, 1, proj.clone(), proj.clone(), 1.0)
            .unwrap();

        let initial_energy = graph.dirichlet_energy().unwrap();
        graph.diffusion_step(0.1).unwrap();
        let final_energy = graph.dirichlet_energy().unwrap();
        assert!(
            final_energy < initial_energy,
            "diffusion should reduce energy with non-square maps"
        );
    }

    // --- Degenerate graphs ------------------------------------------------

    #[test]
    fn test_empty_graph_energy() {
        let graph = SimpleSheafGraph::new();
        let energy = graph.dirichlet_energy().unwrap();
        assert_eq!(energy, 0.0);
    }

    #[test]
    fn test_single_node_graph() {
        let mut graph = SimpleSheafGraph::new();
        graph.add_node(vec![1.0, 2.0]);
        assert_eq!(graph.num_nodes(), 1);
        assert_eq!(graph.num_edges(), 0);
        let energy = graph.dirichlet_energy().unwrap();
        assert_eq!(energy, 0.0);
    }
}