Skip to main content

subsume/
sheaf.rs

1//! # Sheaf Diffusion
2//!
3//! Algebraic structures for enforcing transitivity and consistency in graphs.
4//!
5//! A **sheaf** on a graph assigns vector spaces (stalks) to nodes and linear maps
6//! (restriction maps) to edges. The key insight: if data is "consistent" across
7//! the graph, the Dirichlet energy is zero. Non-zero energy indicates inconsistency.
8//!
9//! This module provides the mathematical primitives (sheaf Laplacian, Euler-step
10//! diffusion) from Hansen & Ghrist (2019) and Bodnar et al. (ICLR 2022).
11//! It does **not** implement learnable restriction maps or neural architectures --
12//! those would be built on top of these primitives.
13//!
14//! # Why Sheaves for Coreference?
15//!
16//! Coreference requires transitivity: if A=B and B=C, then A=C.
17//! Traditional approaches enforce this post-hoc (transitive closure).
18//! A sheaf Laplacian can enforce it structurally: diffusion drives stalks toward
19//! consistency across the graph.
20//!
21//! ```text
22//! Mention A ──[restriction]──> Mention B ──[restriction]──> Mention C
23//!     │                            │                            │
24//!   stalk_A                     stalk_B                      stalk_C
25//!     │                            │                            │
26//!     └── If A=B=C, stalks should be "compatible" under restrictions
27//! ```
28//!
29//! # Mathematical Background
30//!
31//! Given a graph G = (V, E), a **cellular sheaf** F assigns:
32//! - To each vertex v: a vector space F(v) (the "stalk")
33//! - To each edge e = (u,v): a linear map F_{u←e}: F(u) → F(e) (restriction)
34//!
35//! The **sheaf Laplacian** L_F is defined as:
36//!
37//! L_F = δ^T · δ
38//!
39//! where δ is the coboundary operator. The **Dirichlet energy** is:
40//!
41//! E(x) = x^T L_F x = Σ_{(u,v) ∈ E} ||F_{u←e}(x_u) - F_{v←e}(x_v)||²
42//!
43//! Low energy means the signal x is "consistent" across the sheaf.
44//!
45//! # Relationship to Graph Neural Networks
46//!
47//! | Model | Message | Aggregation | Transitivity |
48//! |-------|---------|-------------|--------------|
49//! | GCN | Identity | Sum | Implicit |
50//! | GAT | Attention-weighted | Sum | Implicit |
51//! | Sheaf | Restriction maps | Laplacian diffusion | **Explicit** |
52//!
53//! Sheaf neural networks generalize GNNs by learning edge-specific linear maps
54//! instead of using identity or scalar weights.
55//!
56//! # Implementation
57//!
58//! This module provides framework-agnostic traits. Implementations live in:
59//! - `subsume`: CPU-based with ndarray
60//! - `subsume-candle`: GPU-accelerated with candle
61//!
62//! # References
63//!
64//! - Hansen & Ghrist (2019): "Toward a spectral theory of cellular sheaves"
65//! - Bodnar et al. (2022): "Neural Sheaf Diffusion" (ICLR)
66//! - Barbero et al. (2022): "Sheaf Neural Networks with Connection Laplacians"
67//! - Bodnar (2023): "Topological Deep Learning: Graphs, Complexes, Sheaves"
68//!   (Cambridge PhD thesis) -- connects sheaf structure to asymptotic behavior of
69//!   message passing, providing theoretical grounding
70//! - Zaghen (2024): "Nonlinear Sheaf Diffusion in Graph Neural Networks" -- introduces
71//!   nonlinear Laplacians for sheaf diffusion; the current linear restriction maps could
72//!   be extended with nonlinear variants for heterophilic graphs
73//! - Hu (2026): "Sheaf-Theoretic and Topological Perspective on Complex Network Modeling"
74//!   -- comprehensive survey of sheaf neural networks and sheaf attention mechanisms
75
76use std::collections::HashMap;
77use std::fmt::Debug;
78
/// Error type for sheaf operations.
///
/// Marked `#[non_exhaustive]` so new variants can be added without a
/// semver-breaking change; downstream `match` expressions must include a
/// wildcard arm.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, thiserror::Error)]
pub enum SheafError {
    /// Node not found in the graph.
    #[error("Node {0} not found")]
    NodeNotFound(usize),
    /// Edge not found in the graph.
    #[error("Edge ({0}, {1}) not found")]
    EdgeNotFound(usize, usize),
    /// Dimension mismatch in linear map.
    #[error("Dimension mismatch: expected {expected}, got {actual}")]
    DimensionMismatch {
        /// Expected dimension.
        expected: usize,
        /// Actual dimension.
        actual: usize,
    },
    /// Invalid restriction map (e.g. source/target maps whose output
    /// dimensions disagree, so they cannot share an edge space).
    #[error("Invalid restriction: {0}")]
    InvalidRestriction(String),
}
101
/// A restriction map (linear transformation) on an edge.
///
/// For edge (u, v), this maps from the stalk at u to the edge space.
/// The restriction map captures "how information flows" along the edge.
///
/// For coreference: if mentions u and v are coreferent, their stalks
/// should map to the same point in the edge space.
pub trait RestrictionMap: Clone + Debug {
    /// Scalar type for the map.
    type Scalar: Clone + Debug;
    /// Vector type for input/output.
    type Vector: Clone + Debug;

    /// Input dimension (stalk dimension at source node).
    fn in_dim(&self) -> usize;

    /// Output dimension (edge space dimension).
    fn out_dim(&self) -> usize;

    /// Apply the restriction map to a stalk vector.
    ///
    /// # Errors
    ///
    /// Implementations should return [`SheafError::DimensionMismatch`]
    /// when `x` does not have length [`in_dim`](Self::in_dim).
    fn apply(&self, x: &Self::Vector) -> Result<Self::Vector, SheafError>;

    /// Apply the transpose (adjoint) of the restriction map.
    /// Used in Laplacian computation.
    ///
    /// # Errors
    ///
    /// Implementations should return [`SheafError::DimensionMismatch`]
    /// when `x` does not have length [`out_dim`](Self::out_dim).
    fn apply_transpose(&self, x: &Self::Vector) -> Result<Self::Vector, SheafError>;

    /// Get the matrix representation (for debugging/serialization).
    fn as_matrix(&self) -> Vec<Vec<Self::Scalar>>;

    /// Frobenius norm of the map (for regularization).
    fn frobenius_norm(&self) -> Self::Scalar;
}
134
/// A stalk (vector space) at a node.
///
/// For coreference: the stalk at a mention node contains its embedding.
pub trait Stalk: Clone + Debug {
    /// Scalar type.
    type Scalar: Clone + Debug;
    /// Vector type.
    type Vector: Clone + Debug;

    /// Dimension of the stalk.
    fn dim(&self) -> usize;

    /// Get the current value (signal on the stalk).
    fn value(&self) -> &Self::Vector;

    /// Set the value.
    ///
    /// # Errors
    ///
    /// Implementations should return [`SheafError::DimensionMismatch`]
    /// when `v` does not match the stalk's dimension.
    fn set_value(&mut self, v: Self::Vector) -> Result<(), SheafError>;

    /// Zero vector in this stalk.
    fn zero(&self) -> Self::Vector;
}
156
/// Edge data in a sheaf graph.
///
/// Both restriction maps must share the same output (edge-space)
/// dimension so that their images can be compared; `SimpleSheafGraph::add_edge`
/// validates this invariant.
#[derive(Debug, Clone)]
pub struct SheafEdge<R: RestrictionMap> {
    /// Source node ID.
    pub source: usize,
    /// Target node ID.
    pub target: usize,
    /// Restriction map from source stalk to edge space.
    pub restriction_source: R,
    /// Restriction map from target stalk to edge space.
    pub restriction_target: R,
    /// Edge weight (optional, for weighted Laplacian); scales this edge's
    /// contribution to both the Dirichlet energy and the Laplacian action.
    pub weight: f32,
}
171
/// A sheaf on a graph.
///
/// This is the main data structure for sheaf neural networks.
/// It assigns stalks to nodes and restriction maps to edges.
pub trait SheafGraph: Debug {
    /// Scalar type.
    type Scalar: Clone + Debug + Default;
    /// Vector type.
    type Vector: Clone + Debug;
    /// Restriction map type.
    type Restriction: RestrictionMap<Scalar = Self::Scalar, Vector = Self::Vector>;
    /// Stalk type.
    type Stalk: Stalk<Scalar = Self::Scalar, Vector = Self::Vector>;

    /// Number of nodes.
    fn num_nodes(&self) -> usize;

    /// Number of edges.
    fn num_edges(&self) -> usize;

    /// Get stalk at node.
    ///
    /// # Errors
    ///
    /// Returns [`SheafError::NodeNotFound`] if `node` is not in the graph.
    fn stalk(&self, node: usize) -> Result<&Self::Stalk, SheafError>;

    /// Get mutable stalk at node.
    ///
    /// # Errors
    ///
    /// Returns [`SheafError::NodeNotFound`] if `node` is not in the graph.
    fn stalk_mut(&mut self, node: usize) -> Result<&mut Self::Stalk, SheafError>;

    /// Get edge data.
    ///
    /// # Errors
    ///
    /// Returns [`SheafError::EdgeNotFound`] if no edge connects the pair.
    fn edge(
        &self,
        source: usize,
        target: usize,
    ) -> Result<&SheafEdge<Self::Restriction>, SheafError>;

    /// Iterate over all edges.
    fn edges(&self) -> impl Iterator<Item = &SheafEdge<Self::Restriction>>;

    /// Get neighbors of a node.
    ///
    /// # Errors
    ///
    /// Returns [`SheafError::NodeNotFound`] if `node` is not in the graph.
    fn neighbors(&self, node: usize) -> Result<Vec<usize>, SheafError>;

    /// Compute Dirichlet energy for the current stalk values.
    ///
    /// E(x) = Σ_{(u,v) ∈ E} w_{uv} ||R_u(x_u) - R_v(x_v)||²
    ///
    /// Low energy means consistent signal across the sheaf.
    fn dirichlet_energy(&self) -> Result<Self::Scalar, SheafError>;

    /// Compute the sheaf Laplacian action on a given node.
    ///
    /// (L_F x)_v = Σ_{u ~ v} R_v^T (R_v x_v - R_u x_u)
    fn laplacian_at(&self, node: usize) -> Result<Self::Vector, SheafError>;

    /// Perform one step of sheaf diffusion.
    ///
    /// x_{t+1} = x_t - α * L_F * x_t
    ///
    /// This smooths the signal according to sheaf structure.
    fn diffusion_step(&mut self, step_size: Self::Scalar) -> Result<(), SheafError>;
}
230
231// =============================================================================
232// Sheaf Laplacian Types
233// =============================================================================
234
/// Describes the structure of a sheaf Laplacian.
///
/// The Laplacian can be:
/// - **Connection Laplacian**: Uses orthogonal restriction maps (preserves norms)
/// - **General Laplacian**: Arbitrary linear maps
/// - **Diagonal Laplacian**: Scalar weights only (reduces to graph Laplacian)
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LaplacianType {
    /// Orthogonal restriction maps (O(d) valued).
    Connection,
    /// General linear maps (GL(d) valued).
    General,
    /// Scalar weights (diagonal, equivalent to weighted graph Laplacian).
    Diagonal,
}
250
/// Configuration for sheaf diffusion.
///
/// NOTE(review): only `num_steps` and `step_size` are consulted by
/// `diffuse_until_convergence` in this file; `normalize` and
/// `laplacian_type` are presumably consumed by the framework backends
/// (`subsume`, `subsume-candle`) — confirm against those crates.
#[derive(Debug, Clone)]
pub struct DiffusionConfig {
    /// Number of diffusion steps.
    pub num_steps: usize,
    /// Step size (learning rate for diffusion).
    pub step_size: f32,
    /// Whether to normalize the Laplacian (D^{-1/2} L D^{-1/2}).
    pub normalize: bool,
    /// Type of Laplacian to use (Connection, General, or Diagonal).
    pub laplacian_type: LaplacianType,
}
263
264impl Default for DiffusionConfig {
265    fn default() -> Self {
266        Self {
267            num_steps: 5,
268            step_size: 0.1,
269            normalize: true,
270            laplacian_type: LaplacianType::General,
271        }
272    }
273}
274
275// =============================================================================
276// Simple In-Memory Implementation (f32, Vec<f32>)
277// =============================================================================
278
/// Simple restriction map using a dense matrix.
///
/// Invariant (enforced by [`DenseRestriction::new`]):
/// `data.len() == rows * cols`, with entry (i, j) stored at
/// `data[i * cols + j]` (row-major).
#[derive(Debug, Clone)]
pub struct DenseRestriction {
    /// Matrix data in row-major order.
    pub data: Vec<f32>,
    /// Number of rows (output dimension).
    pub rows: usize,
    /// Number of columns (input dimension).
    pub cols: usize,
}
289
impl DenseRestriction {
    /// Create a new restriction map.
    ///
    /// `data` is interpreted as a `rows x cols` matrix in row-major order.
    ///
    /// # Errors
    ///
    /// Returns [`SheafError::DimensionMismatch`] when
    /// `data.len() != rows * cols`.
    pub fn new(data: Vec<f32>, rows: usize, cols: usize) -> Result<Self, SheafError> {
        if data.len() != rows * cols {
            return Err(SheafError::DimensionMismatch {
                expected: rows * cols,
                actual: data.len(),
            });
        }
        Ok(Self { data, rows, cols })
    }

    /// Create an identity restriction (for same-dimension stalks).
    pub fn identity(dim: usize) -> Self {
        let mut data = vec![0.0; dim * dim];
        for i in 0..dim {
            // Set the diagonal entries of the row-major dim x dim matrix.
            data[i * dim + i] = 1.0;
        }
        Self {
            data,
            rows: dim,
            cols: dim,
        }
    }

    /// Create a random orthogonal restriction map.
    ///
    /// Draws entries uniformly from [-0.5, 0.5) and orthonormalizes the
    /// columns with Gram-Schmidt. Useful for connection Laplacians.
    ///
    /// NOTE(review): a column whose norm is <= 1e-6 is left unnormalized
    /// (the division is skipped), so the result is only approximately
    /// orthogonal in that degenerate (and, for random input, very
    /// unlikely) case.
    #[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
    #[cfg(feature = "rand")]
    // allow(deprecated): presumably for `thread_rng`/`gen_range`, which newer
    // rand versions deprecate — TODO confirm against the pinned rand version.
    #[allow(deprecated)]
    pub fn random_orthogonal(dim: usize) -> Self {
        use rand::Rng;
        let mut rng = rand::thread_rng();

        // Generate random matrix (row-major, dim x dim; entry (j, i) at j*dim + i)
        let mut data: Vec<f32> = (0..dim * dim).map(|_| rng.gen_range(-0.5..0.5)).collect();

        // Modified Gram-Schmidt over columns: normalize column i, then
        // remove its component from every later column.
        for i in 0..dim {
            // Normalize column i
            let mut norm: f32 = 0.0;
            for j in 0..dim {
                norm += data[j * dim + i] * data[j * dim + i];
            }
            norm = norm.sqrt();
            if norm > 1e-6 {
                for j in 0..dim {
                    data[j * dim + i] /= norm;
                }
            }

            // Subtract projections from remaining columns
            for k in (i + 1)..dim {
                let mut dot = 0.0;
                for j in 0..dim {
                    dot += data[j * dim + i] * data[j * dim + k];
                }
                for j in 0..dim {
                    data[j * dim + k] -= dot * data[j * dim + i];
                }
            }
        }

        Self {
            data,
            rows: dim,
            cols: dim,
        }
    }
}
362
363impl RestrictionMap for DenseRestriction {
364    type Scalar = f32;
365    type Vector = Vec<f32>;
366
367    fn in_dim(&self) -> usize {
368        self.cols
369    }
370
371    fn out_dim(&self) -> usize {
372        self.rows
373    }
374
375    fn apply(&self, x: &Self::Vector) -> Result<Self::Vector, SheafError> {
376        if x.len() != self.cols {
377            return Err(SheafError::DimensionMismatch {
378                expected: self.cols,
379                actual: x.len(),
380            });
381        }
382
383        // Matrix multiplication: result = A * x
384        // Using explicit indexing for clarity in matrix math
385        let mut result = vec![0.0; self.rows];
386        #[allow(clippy::needless_range_loop)]
387        for i in 0..self.rows {
388            for j in 0..self.cols {
389                result[i] += self.data[i * self.cols + j] * x[j];
390            }
391        }
392        Ok(result)
393    }
394
395    fn apply_transpose(&self, x: &Self::Vector) -> Result<Self::Vector, SheafError> {
396        if x.len() != self.rows {
397            return Err(SheafError::DimensionMismatch {
398                expected: self.rows,
399                actual: x.len(),
400            });
401        }
402
403        // Matrix transpose multiplication: result = A^T * x
404        let mut result = vec![0.0; self.cols];
405        #[allow(clippy::needless_range_loop)]
406        for j in 0..self.cols {
407            for i in 0..self.rows {
408                result[j] += self.data[i * self.cols + j] * x[i];
409            }
410        }
411        Ok(result)
412    }
413
414    fn as_matrix(&self) -> Vec<Vec<Self::Scalar>> {
415        // Convert flat storage to row-major matrix
416        let mut matrix = vec![vec![0.0; self.cols]; self.rows];
417        #[allow(clippy::needless_range_loop)]
418        for i in 0..self.rows {
419            for j in 0..self.cols {
420                matrix[i][j] = self.data[i * self.cols + j];
421            }
422        }
423        matrix
424    }
425
426    fn frobenius_norm(&self) -> Self::Scalar {
427        self.data.iter().map(|x| x * x).sum::<f32>().sqrt()
428    }
429}
430
/// Simple stalk holding a `Vec<f32>`.
///
/// The dimension is fixed by the initial value; `set_value` rejects
/// vectors of a different length.
#[derive(Debug, Clone)]
pub struct VecStalk {
    // Current signal on the stalk; its length is the stalk dimension.
    value: Vec<f32>,
}
436
impl VecStalk {
    /// Create a new stalk with given value.
    ///
    /// The stalk's dimension is fixed to `value.len()`; later calls to
    /// `set_value` must supply a vector of the same length.
    pub fn new(value: Vec<f32>) -> Self {
        Self { value }
    }
}
443
444impl Stalk for VecStalk {
445    type Scalar = f32;
446    type Vector = Vec<f32>;
447
448    fn dim(&self) -> usize {
449        self.value.len()
450    }
451
452    fn value(&self) -> &Self::Vector {
453        &self.value
454    }
455
456    fn set_value(&mut self, v: Self::Vector) -> Result<(), SheafError> {
457        if v.len() != self.value.len() {
458            return Err(SheafError::DimensionMismatch {
459                expected: self.value.len(),
460                actual: v.len(),
461            });
462        }
463        self.value = v;
464        Ok(())
465    }
466
467    fn zero(&self) -> Self::Vector {
468        vec![0.0; self.value.len()]
469    }
470}
471
/// Simple in-memory sheaf graph.
#[derive(Debug, Clone)]
pub struct SimpleSheafGraph {
    // Stalks indexed by node id (ids are assigned densely by add_node).
    stalks: Vec<VecStalk>,
    // Edge list; edges are treated as undirected by `edge()` lookup.
    edges: Vec<SheafEdge<DenseRestriction>>,
    // node id -> neighbor ids; add_edge records both directions.
    adjacency: HashMap<usize, Vec<usize>>,
}
479
480impl SimpleSheafGraph {
481    /// Create a new empty sheaf graph.
482    pub fn new() -> Self {
483        Self {
484            stalks: Vec::new(),
485            edges: Vec::new(),
486            adjacency: HashMap::new(),
487        }
488    }
489
490    /// Add a node with initial stalk value.
491    pub fn add_node(&mut self, value: Vec<f32>) -> usize {
492        let id = self.stalks.len();
493        self.stalks.push(VecStalk::new(value));
494        self.adjacency.insert(id, Vec::new());
495        id
496    }
497
498    /// Add an edge with restriction maps.
499    pub fn add_edge(
500        &mut self,
501        source: usize,
502        target: usize,
503        restriction_source: DenseRestriction,
504        restriction_target: DenseRestriction,
505        weight: f32,
506    ) -> Result<(), SheafError> {
507        if source >= self.stalks.len() {
508            return Err(SheafError::NodeNotFound(source));
509        }
510        if target >= self.stalks.len() {
511            return Err(SheafError::NodeNotFound(target));
512        }
513
514        // Verify dimensions
515        if restriction_source.in_dim() != self.stalks[source].dim() {
516            return Err(SheafError::DimensionMismatch {
517                expected: self.stalks[source].dim(),
518                actual: restriction_source.in_dim(),
519            });
520        }
521        if restriction_target.in_dim() != self.stalks[target].dim() {
522            return Err(SheafError::DimensionMismatch {
523                expected: self.stalks[target].dim(),
524                actual: restriction_target.in_dim(),
525            });
526        }
527        if restriction_source.out_dim() != restriction_target.out_dim() {
528            return Err(SheafError::InvalidRestriction(
529                "Source and target restrictions must have same output dimension".into(),
530            ));
531        }
532
533        self.edges.push(SheafEdge {
534            source,
535            target,
536            restriction_source,
537            restriction_target,
538            weight,
539        });
540
541        self.adjacency.entry(source).or_default().push(target);
542        self.adjacency.entry(target).or_default().push(source);
543
544        Ok(())
545    }
546}
547
548impl Default for SimpleSheafGraph {
549    fn default() -> Self {
550        Self::new()
551    }
552}
553
554impl SheafGraph for SimpleSheafGraph {
555    type Scalar = f32;
556    type Vector = Vec<f32>;
557    type Restriction = DenseRestriction;
558    type Stalk = VecStalk;
559
560    fn num_nodes(&self) -> usize {
561        self.stalks.len()
562    }
563
564    fn num_edges(&self) -> usize {
565        self.edges.len()
566    }
567
568    fn stalk(&self, node: usize) -> Result<&Self::Stalk, SheafError> {
569        self.stalks.get(node).ok_or(SheafError::NodeNotFound(node))
570    }
571
572    fn stalk_mut(&mut self, node: usize) -> Result<&mut Self::Stalk, SheafError> {
573        self.stalks
574            .get_mut(node)
575            .ok_or(SheafError::NodeNotFound(node))
576    }
577
578    fn edge(
579        &self,
580        source: usize,
581        target: usize,
582    ) -> Result<&SheafEdge<Self::Restriction>, SheafError> {
583        self.edges
584            .iter()
585            .find(|e| {
586                (e.source == source && e.target == target)
587                    || (e.source == target && e.target == source)
588            })
589            .ok_or(SheafError::EdgeNotFound(source, target))
590    }
591
592    fn edges(&self) -> impl Iterator<Item = &SheafEdge<Self::Restriction>> {
593        self.edges.iter()
594    }
595
596    fn neighbors(&self, node: usize) -> Result<Vec<usize>, SheafError> {
597        self.adjacency
598            .get(&node)
599            .cloned()
600            .ok_or(SheafError::NodeNotFound(node))
601    }
602
603    fn dirichlet_energy(&self) -> Result<Self::Scalar, SheafError> {
604        let mut energy = 0.0;
605
606        for edge in &self.edges {
607            let x_u = self.stalks[edge.source].value();
608            let x_v = self.stalks[edge.target].value();
609
610            let r_u = edge.restriction_source.apply(x_u)?;
611            let r_v = edge.restriction_target.apply(x_v)?;
612
613            // ||R_u(x_u) - R_v(x_v)||²
614            let diff_sq: f32 = r_u
615                .iter()
616                .zip(r_v.iter())
617                .map(|(a, b)| (a - b) * (a - b))
618                .sum();
619
620            energy += edge.weight * diff_sq;
621        }
622
623        Ok(energy)
624    }
625
626    fn laplacian_at(&self, node: usize) -> Result<Self::Vector, SheafError> {
627        let stalk = self.stalk(node)?;
628        let mut result = stalk.zero();
629
630        for edge in &self.edges {
631            let (is_source, other) = if edge.source == node {
632                (true, edge.target)
633            } else if edge.target == node {
634                (false, edge.source)
635            } else {
636                continue;
637            };
638
639            let x_node = self.stalks[node].value();
640            let x_other = self.stalks[other].value();
641
642            let (r_node, r_other) = if is_source {
643                (&edge.restriction_source, &edge.restriction_target)
644            } else {
645                (&edge.restriction_target, &edge.restriction_source)
646            };
647
648            // R_node(x_node) - R_other(x_other)
649            let r_x_node = r_node.apply(x_node)?;
650            let r_x_other = r_other.apply(x_other)?;
651
652            let diff: Vec<f32> = r_x_node
653                .iter()
654                .zip(r_x_other.iter())
655                .map(|(a, b)| a - b)
656                .collect();
657
658            // R_node^T (diff)
659            let contrib = r_node.apply_transpose(&diff)?;
660
661            // Accumulate weighted contribution
662            for (i, c) in contrib.iter().enumerate() {
663                result[i] += edge.weight * c;
664            }
665        }
666
667        Ok(result)
668    }
669
670    fn diffusion_step(&mut self, step_size: Self::Scalar) -> Result<(), SheafError> {
671        // Compute Laplacian at all nodes first (to avoid borrowing issues)
672        let laplacians: Vec<Vec<f32>> = (0..self.num_nodes())
673            .map(|i| self.laplacian_at(i))
674            .collect::<Result<_, _>>()?;
675
676        // Update all stalks: x = x - step_size * L_F * x
677        for (i, lap) in laplacians.into_iter().enumerate() {
678            let stalk = &mut self.stalks[i];
679            let new_value: Vec<f32> = stalk
680                .value()
681                .iter()
682                .zip(lap.iter())
683                .map(|(x, l)| x - step_size * l)
684                .collect();
685            stalk.set_value(new_value)?;
686        }
687
688        Ok(())
689    }
690}
691
692// =============================================================================
693// Utility Functions
694// =============================================================================
695
696/// Compute consistency score for a sheaf graph.
697///
698/// Returns 1.0 for perfect consistency (zero energy), decreasing toward 0.
699pub fn consistency_score(graph: &impl SheafGraph<Scalar = f32>) -> Result<f32, SheafError> {
700    let energy = graph.dirichlet_energy()?;
701    // Use exponential decay: exp(-energy)
702    Ok((-energy).exp())
703}
704
705/// Run sheaf diffusion until convergence or max iterations.
706pub fn diffuse_until_convergence(
707    graph: &mut SimpleSheafGraph,
708    config: &DiffusionConfig,
709    tolerance: f32,
710) -> Result<usize, SheafError> {
711    let mut prev_energy = graph.dirichlet_energy()?;
712
713    for step in 0..config.num_steps {
714        graph.diffusion_step(config.step_size)?;
715        let energy = graph.dirichlet_energy()?;
716
717        if (prev_energy - energy).abs() < tolerance {
718            return Ok(step + 1);
719        }
720        prev_energy = energy;
721    }
722
723    Ok(config.num_steps)
724}
725
726#[cfg(test)]
727mod tests {
728    use super::*;
729
730    #[test]
731    fn test_identity_restriction() {
732        let r = DenseRestriction::identity(3);
733        let x = vec![1.0, 2.0, 3.0];
734        let y = r.apply(&x).unwrap();
735        assert_eq!(y, x);
736    }
737
738    #[test]
739    fn test_restriction_transpose() {
740        // 2x3 matrix
741        let r = DenseRestriction::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3).unwrap();
742
743        let x = vec![1.0, 2.0, 3.0];
744        let y = r.apply(&x).unwrap();
745        assert_eq!(y.len(), 2);
746
747        let z = vec![1.0, 1.0];
748        let w = r.apply_transpose(&z).unwrap();
749        assert_eq!(w.len(), 3);
750        // Transpose of [[1,2,3],[4,5,6]] is [[1,4],[2,5],[3,6]]
751        // [1,4] · [1,1] = 5, [2,5] · [1,1] = 7, [3,6] · [1,1] = 9
752        assert_eq!(w, vec![5.0, 7.0, 9.0]);
753    }
754
755    #[test]
756    fn test_simple_sheaf_graph() {
757        let mut graph = SimpleSheafGraph::new();
758
759        // Two nodes with 2D stalks
760        let n0 = graph.add_node(vec![1.0, 0.0]);
761        let n1 = graph.add_node(vec![0.0, 1.0]);
762
763        // Identity restrictions (simplest case)
764        let r = DenseRestriction::identity(2);
765        graph.add_edge(n0, n1, r.clone(), r.clone(), 1.0).unwrap();
766
767        assert_eq!(graph.num_nodes(), 2);
768        assert_eq!(graph.num_edges(), 1);
769
770        // Dirichlet energy should be ||[1,0] - [0,1]||² = 2
771        let energy = graph.dirichlet_energy().unwrap();
772        assert!((energy - 2.0).abs() < 1e-6);
773    }
774
775    #[test]
776    fn test_diffusion_reduces_energy() {
777        let mut graph = SimpleSheafGraph::new();
778
779        // Three nodes forming a chain
780        let n0 = graph.add_node(vec![1.0, 0.0]);
781        let n1 = graph.add_node(vec![0.5, 0.5]);
782        let n2 = graph.add_node(vec![0.0, 1.0]);
783
784        let r = DenseRestriction::identity(2);
785        graph.add_edge(n0, n1, r.clone(), r.clone(), 1.0).unwrap();
786        graph.add_edge(n1, n2, r.clone(), r.clone(), 1.0).unwrap();
787
788        let initial_energy = graph.dirichlet_energy().unwrap();
789
790        // Run diffusion
791        for _ in 0..10 {
792            graph.diffusion_step(0.1).unwrap();
793        }
794
795        let final_energy = graph.dirichlet_energy().unwrap();
796        assert!(
797            final_energy < initial_energy,
798            "Diffusion should reduce energy"
799        );
800    }
801
802    #[test]
803    fn test_consistency_score() {
804        let mut graph = SimpleSheafGraph::new();
805
806        // Two nodes with identical stalks
807        graph.add_node(vec![1.0, 2.0]);
808        graph.add_node(vec![1.0, 2.0]);
809
810        let r = DenseRestriction::identity(2);
811        graph.add_edge(0, 1, r.clone(), r.clone(), 1.0).unwrap();
812
813        // Perfect consistency: score should be 1.0
814        let score = consistency_score(&graph).unwrap();
815        assert!((score - 1.0).abs() < 1e-6);
816    }
817
818    // =========================================================================
819    // DenseRestriction edge cases
820    // =========================================================================
821
822    #[test]
823    fn test_dense_restriction_new_dimension_mismatch() {
824        let result = DenseRestriction::new(vec![1.0, 2.0, 3.0], 2, 2);
825        assert!(matches!(
826            result,
827            Err(SheafError::DimensionMismatch {
828                expected: 4,
829                actual: 3
830            })
831        ));
832    }
833
834    #[test]
835    fn test_dense_restriction_1x1() {
836        let r = DenseRestriction::new(vec![3.0], 1, 1).unwrap();
837        let x = vec![2.0];
838        let y = r.apply(&x).unwrap();
839        assert_eq!(y, vec![6.0]);
840
841        let yt = r.apply_transpose(&vec![2.0]).unwrap();
842        assert_eq!(yt, vec![6.0]); // Transpose of 1x1 is itself
843    }
844
845    #[test]
846    fn test_dense_restriction_apply_wrong_dim() {
847        let r = DenseRestriction::identity(3);
848        let x = vec![1.0, 2.0]; // Wrong dimension
849        let result = r.apply(&x);
850        assert!(matches!(
851            result,
852            Err(SheafError::DimensionMismatch {
853                expected: 3,
854                actual: 2
855            })
856        ));
857    }
858
859    #[test]
860    fn test_dense_restriction_apply_transpose_wrong_dim() {
861        let r = DenseRestriction::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3).unwrap();
862        let x = vec![1.0, 2.0, 3.0]; // 3 elements but rows = 2
863        let result = r.apply_transpose(&x);
864        assert!(matches!(
865            result,
866            Err(SheafError::DimensionMismatch {
867                expected: 2,
868                actual: 3
869            })
870        ));
871    }
872
873    #[test]
874    fn test_dense_restriction_as_matrix() {
875        let r = DenseRestriction::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3).unwrap();
876        let m = r.as_matrix();
877        assert_eq!(m.len(), 2);
878        assert_eq!(m[0], vec![1.0, 2.0, 3.0]);
879        assert_eq!(m[1], vec![4.0, 5.0, 6.0]);
880    }
881
882    #[test]
883    fn test_dense_restriction_frobenius_norm() {
884        let r = DenseRestriction::new(vec![3.0, 4.0], 1, 2).unwrap();
885        let norm = r.frobenius_norm();
886        assert!((norm - 5.0).abs() < 1e-6); // sqrt(9 + 16) = 5
887    }
888
889    #[test]
890    fn test_identity_restriction_is_identity() {
891        let r = DenseRestriction::identity(4);
892        assert_eq!(r.in_dim(), 4);
893        assert_eq!(r.out_dim(), 4);
894        let x = vec![1.0, 2.0, 3.0, 4.0];
895        assert_eq!(r.apply(&x).unwrap(), x);
896        assert_eq!(r.apply_transpose(&x).unwrap(), x); // I^T = I
897    }
898
899    // =========================================================================
900    // VecStalk edge cases
901    // =========================================================================
902
903    #[test]
904    fn test_vec_stalk_set_value_dimension_mismatch() {
905        let mut s = VecStalk::new(vec![1.0, 2.0]);
906        let result = s.set_value(vec![1.0]);
907        assert!(matches!(
908            result,
909            Err(SheafError::DimensionMismatch {
910                expected: 2,
911                actual: 1
912            })
913        ));
914    }
915
916    #[test]
917    fn test_vec_stalk_zero() {
918        let s = VecStalk::new(vec![5.0, 6.0, 7.0]);
919        assert_eq!(s.zero(), vec![0.0, 0.0, 0.0]);
920    }
921
922    #[test]
923    fn test_vec_stalk_roundtrip() {
924        let mut s = VecStalk::new(vec![1.0, 2.0]);
925        s.set_value(vec![3.0, 4.0]).unwrap();
926        assert_eq!(s.value(), &vec![3.0, 4.0]);
927        assert_eq!(s.dim(), 2);
928    }
929
930    // =========================================================================
931    // SimpleSheafGraph error paths
932    // =========================================================================
933
934    #[test]
935    fn test_add_edge_source_not_found() {
936        let mut graph = SimpleSheafGraph::new();
937        graph.add_node(vec![1.0]);
938        let r = DenseRestriction::identity(1);
939        let result = graph.add_edge(5, 0, r.clone(), r.clone(), 1.0);
940        assert!(matches!(result, Err(SheafError::NodeNotFound(5))));
941    }
942
943    #[test]
944    fn test_add_edge_target_not_found() {
945        let mut graph = SimpleSheafGraph::new();
946        graph.add_node(vec![1.0]);
947        let r = DenseRestriction::identity(1);
948        let result = graph.add_edge(0, 99, r.clone(), r.clone(), 1.0);
949        assert!(matches!(result, Err(SheafError::NodeNotFound(99))));
950    }
951
952    #[test]
953    fn test_add_edge_restriction_dim_mismatch_source() {
954        let mut graph = SimpleSheafGraph::new();
955        graph.add_node(vec![1.0, 2.0]); // dim 2
956        graph.add_node(vec![1.0, 2.0]); // dim 2
957        let r_wrong = DenseRestriction::identity(3); // dim 3
958        let r_ok = DenseRestriction::identity(2);
959        let result = graph.add_edge(0, 1, r_wrong, r_ok, 1.0);
960        assert!(matches!(result, Err(SheafError::DimensionMismatch { .. })));
961    }
962
963    #[test]
964    fn test_add_edge_restriction_output_dim_mismatch() {
965        let mut graph = SimpleSheafGraph::new();
966        graph.add_node(vec![1.0, 2.0]);
967        graph.add_node(vec![1.0, 2.0]);
968        // Source restriction: 2->3, target restriction: 2->2 (output dims differ)
969        let r_src = DenseRestriction::new(vec![1.0; 6], 3, 2).unwrap();
970        let r_tgt = DenseRestriction::identity(2);
971        let result = graph.add_edge(0, 1, r_src, r_tgt, 1.0);
972        assert!(matches!(result, Err(SheafError::InvalidRestriction(_))));
973    }
974
975    #[test]
976    fn test_stalk_not_found() {
977        let graph = SimpleSheafGraph::new();
978        assert!(matches!(graph.stalk(0), Err(SheafError::NodeNotFound(0))));
979    }
980
981    #[test]
982    fn test_edge_not_found() {
983        let mut graph = SimpleSheafGraph::new();
984        graph.add_node(vec![1.0]);
985        graph.add_node(vec![1.0]);
986        // No edge added
987        assert!(matches!(
988            graph.edge(0, 1),
989            Err(SheafError::EdgeNotFound(0, 1))
990        ));
991    }
992
993    #[test]
994    fn test_neighbors_not_found() {
995        let graph = SimpleSheafGraph::new();
996        assert!(matches!(
997            graph.neighbors(0),
998            Err(SheafError::NodeNotFound(0))
999        ));
1000    }
1001
1002    #[test]
1003    fn test_edge_lookup_bidirectional() {
1004        let mut graph = SimpleSheafGraph::new();
1005        graph.add_node(vec![1.0]);
1006        graph.add_node(vec![2.0]);
1007        let r = DenseRestriction::identity(1);
1008        graph.add_edge(0, 1, r.clone(), r.clone(), 1.0).unwrap();
1009
1010        // Both directions should find the edge
1011        assert!(graph.edge(0, 1).is_ok());
1012        assert!(graph.edge(1, 0).is_ok());
1013    }
1014
1015    #[test]
1016    fn test_neighbors_bidirectional() {
1017        let mut graph = SimpleSheafGraph::new();
1018        graph.add_node(vec![1.0]);
1019        graph.add_node(vec![2.0]);
1020        graph.add_node(vec![3.0]);
1021        let r = DenseRestriction::identity(1);
1022        graph.add_edge(0, 1, r.clone(), r.clone(), 1.0).unwrap();
1023        graph.add_edge(1, 2, r.clone(), r.clone(), 1.0).unwrap();
1024
1025        let n1 = graph.neighbors(1).unwrap();
1026        assert_eq!(n1.len(), 2); // connected to both 0 and 2
1027    }
1028
1029    // =========================================================================
1030    // Dirichlet energy and Laplacian
1031    // =========================================================================
1032
1033    #[test]
1034    fn test_dirichlet_energy_zero_for_identical_stalks() {
1035        let mut graph = SimpleSheafGraph::new();
1036        graph.add_node(vec![1.0, 2.0, 3.0]);
1037        graph.add_node(vec![1.0, 2.0, 3.0]);
1038        graph.add_node(vec![1.0, 2.0, 3.0]);
1039        let r = DenseRestriction::identity(3);
1040        graph.add_edge(0, 1, r.clone(), r.clone(), 1.0).unwrap();
1041        graph.add_edge(1, 2, r.clone(), r.clone(), 1.0).unwrap();
1042        let energy = graph.dirichlet_energy().unwrap();
1043        assert!(
1044            (energy - 0.0).abs() < 1e-6,
1045            "identical stalks should have zero energy"
1046        );
1047    }
1048
1049    #[test]
1050    fn test_dirichlet_energy_weighted() {
1051        let mut graph = SimpleSheafGraph::new();
1052        graph.add_node(vec![1.0, 0.0]);
1053        graph.add_node(vec![0.0, 1.0]);
1054        let r = DenseRestriction::identity(2);
1055        // Weight 2.0 should double the energy
1056        graph.add_edge(0, 1, r.clone(), r.clone(), 2.0).unwrap();
1057        let energy = graph.dirichlet_energy().unwrap();
1058        // ||[1,0] - [0,1]||^2 = 2, weight 2.0 => 4.0
1059        assert!((energy - 4.0).abs() < 1e-6);
1060    }
1061
1062    #[test]
1063    fn test_laplacian_at_zero_for_consistent_signal() {
1064        let mut graph = SimpleSheafGraph::new();
1065        graph.add_node(vec![1.0, 2.0]);
1066        graph.add_node(vec![1.0, 2.0]);
1067        let r = DenseRestriction::identity(2);
1068        graph.add_edge(0, 1, r.clone(), r.clone(), 1.0).unwrap();
1069
1070        let lap = graph.laplacian_at(0).unwrap();
1071        assert!(
1072            lap.iter().all(|&x| x.abs() < 1e-6),
1073            "Laplacian should be zero for consistent signal"
1074        );
1075    }
1076
1077    #[test]
1078    fn test_laplacian_symmetry() {
1079        // For identity restrictions on an edge (0,1), the Laplacian at 0
1080        // should be the negative of the Laplacian at 1 (conservation).
1081        let mut graph = SimpleSheafGraph::new();
1082        graph.add_node(vec![1.0, 0.0]);
1083        graph.add_node(vec![0.0, 1.0]);
1084        let r = DenseRestriction::identity(2);
1085        graph.add_edge(0, 1, r.clone(), r.clone(), 1.0).unwrap();
1086
1087        let lap0 = graph.laplacian_at(0).unwrap();
1088        let lap1 = graph.laplacian_at(1).unwrap();
1089        // L(0) + L(1) = 0 (conservation for identity restrictions)
1090        for i in 0..2 {
1091            assert!(
1092                (lap0[i] + lap1[i]).abs() < 1e-6,
1093                "Laplacian should sum to zero"
1094            );
1095        }
1096    }
1097
1098    // =========================================================================
1099    // Diffusion convergence
1100    // =========================================================================
1101
1102    #[test]
1103    fn test_diffuse_until_convergence_identical_stalks() {
1104        let mut graph = SimpleSheafGraph::new();
1105        graph.add_node(vec![1.0, 1.0]);
1106        graph.add_node(vec![1.0, 1.0]);
1107        let r = DenseRestriction::identity(2);
1108        graph.add_edge(0, 1, r.clone(), r.clone(), 1.0).unwrap();
1109
1110        let config = DiffusionConfig {
1111            num_steps: 100,
1112            step_size: 0.1,
1113            ..Default::default()
1114        };
1115
1116        // Already converged: should return 1 (converges on first step)
1117        let steps = diffuse_until_convergence(&mut graph, &config, 1e-8).unwrap();
1118        assert!(
1119            steps <= 2,
1120            "already-converged graph should converge immediately, took {steps}"
1121        );
1122    }
1123
1124    #[test]
1125    fn test_diffuse_until_convergence_reaches_max_steps() {
1126        let mut graph = SimpleSheafGraph::new();
1127        graph.add_node(vec![100.0, 0.0]);
1128        graph.add_node(vec![0.0, 100.0]);
1129        let r = DenseRestriction::identity(2);
1130        graph.add_edge(0, 1, r.clone(), r.clone(), 1.0).unwrap();
1131
1132        let config = DiffusionConfig {
1133            num_steps: 3,
1134            step_size: 0.01, // Very small step: won't converge in 3 steps
1135            ..Default::default()
1136        };
1137
1138        let steps = diffuse_until_convergence(&mut graph, &config, 1e-12).unwrap();
1139        assert_eq!(steps, 3, "should reach max steps");
1140    }
1141
1142    #[test]
1143    fn test_consistency_score_decreases_with_distance() {
1144        // Larger stalk differences should produce lower consistency
1145        let mut g1 = SimpleSheafGraph::new();
1146        g1.add_node(vec![1.0, 0.0]);
1147        g1.add_node(vec![0.9, 0.1]);
1148        let r = DenseRestriction::identity(2);
1149        g1.add_edge(0, 1, r.clone(), r.clone(), 1.0).unwrap();
1150        let score1 = consistency_score(&g1).unwrap();
1151
1152        let mut g2 = SimpleSheafGraph::new();
1153        g2.add_node(vec![1.0, 0.0]);
1154        g2.add_node(vec![0.0, 1.0]);
1155        g2.add_edge(0, 1, r.clone(), r.clone(), 1.0).unwrap();
1156        let score2 = consistency_score(&g2).unwrap();
1157
1158        assert!(
1159            score1 > score2,
1160            "closer stalks should have higher consistency"
1161        );
1162    }
1163
1164    // =========================================================================
1165    // Default and Display impls
1166    // =========================================================================
1167
1168    #[test]
1169    fn test_diffusion_config_default() {
1170        let config = DiffusionConfig::default();
1171        assert_eq!(config.num_steps, 5);
1172        assert!((config.step_size - 0.1).abs() < 1e-6);
1173        assert!(config.normalize);
1174        assert_eq!(config.laplacian_type, LaplacianType::General);
1175    }
1176
1177    #[test]
1178    fn test_sheaf_error_display() {
1179        assert_eq!(
1180            format!("{}", SheafError::NodeNotFound(5)),
1181            "Node 5 not found"
1182        );
1183        assert_eq!(
1184            format!("{}", SheafError::EdgeNotFound(1, 2)),
1185            "Edge (1, 2) not found"
1186        );
1187        assert_eq!(
1188            format!(
1189                "{}",
1190                SheafError::DimensionMismatch {
1191                    expected: 3,
1192                    actual: 2
1193                }
1194            ),
1195            "Dimension mismatch: expected 3, got 2"
1196        );
1197        assert!(format!("{}", SheafError::InvalidRestriction("bad".into())).contains("bad"));
1198    }
1199
1200    #[test]
1201    fn test_simple_sheaf_graph_default() {
1202        let graph = SimpleSheafGraph::default();
1203        assert_eq!(graph.num_nodes(), 0);
1204        assert_eq!(graph.num_edges(), 0);
1205    }
1206
1207    // =========================================================================
1208    // Non-square restriction maps
1209    // =========================================================================
1210
1211    #[test]
1212    fn test_non_square_restriction_maps() {
1213        // Project 3D stalks into 2D edge space
1214        let mut graph = SimpleSheafGraph::new();
1215        graph.add_node(vec![1.0, 0.0, 0.0]);
1216        graph.add_node(vec![0.0, 1.0, 0.0]);
1217
1218        // Projection: keep first 2 dims (2x3 matrix)
1219        let proj = DenseRestriction::new(vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0], 2, 3).unwrap();
1220        graph
1221            .add_edge(0, 1, proj.clone(), proj.clone(), 1.0)
1222            .unwrap();
1223
1224        let energy = graph.dirichlet_energy().unwrap();
1225        // R(x0) = [1,0], R(x1) = [0,1], ||diff||^2 = 2
1226        assert!((energy - 2.0).abs() < 1e-6);
1227    }
1228
1229    #[test]
1230    fn test_diffusion_with_non_square_restrictions() {
1231        // Verify diffusion works with non-square maps
1232        let mut graph = SimpleSheafGraph::new();
1233        graph.add_node(vec![1.0, 0.0, 0.0]);
1234        graph.add_node(vec![0.0, 1.0, 0.0]);
1235
1236        let proj = DenseRestriction::new(vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0], 2, 3).unwrap();
1237        graph
1238            .add_edge(0, 1, proj.clone(), proj.clone(), 1.0)
1239            .unwrap();
1240
1241        let initial_energy = graph.dirichlet_energy().unwrap();
1242        graph.diffusion_step(0.1).unwrap();
1243        let final_energy = graph.dirichlet_energy().unwrap();
1244        assert!(
1245            final_energy < initial_energy,
1246            "diffusion should reduce energy with non-square maps"
1247        );
1248    }
1249
1250    // =========================================================================
1251    // Empty and single-node graphs
1252    // =========================================================================
1253
1254    #[test]
1255    fn test_empty_graph_energy() {
1256        let graph = SimpleSheafGraph::new();
1257        let energy = graph.dirichlet_energy().unwrap();
1258        assert_eq!(energy, 0.0);
1259    }
1260
1261    #[test]
1262    fn test_single_node_graph() {
1263        let mut graph = SimpleSheafGraph::new();
1264        graph.add_node(vec![1.0, 2.0]);
1265        assert_eq!(graph.num_nodes(), 1);
1266        assert_eq!(graph.num_edges(), 0);
1267        let energy = graph.dirichlet_energy().unwrap();
1268        assert_eq!(energy, 0.0);
1269    }
1270}