Skip to main content

tang_expr/
sparsity.rs

1//! Sparsity analysis via dependency bitmasks.
2
3use crate::graph::ExprGraph;
4use crate::node::{ExprId, Node};
5
6impl ExprGraph {
7    /// Compute a bitmask of which `Var(n)` indices appear in `expr`.
8    ///
9    /// Bit `n` is set if `Var(n)` is reachable from `expr`.
10    /// Supports up to 64 variables.
11    pub fn deps(&self, expr: ExprId) -> u64 {
12        let n = expr.0 as usize + 1;
13        let mut masks = vec![0u64; n];
14
15        for i in 0..n {
16            let m = match self.node(ExprId(i as u32)) {
17                Node::Var(idx) => {
18                    assert!(idx < 64, "deps() supports at most 64 variables");
19                    1u64 << idx
20                }
21                Node::Lit(_) => 0,
22                Node::Add(a, b) | Node::Mul(a, b) | Node::Atan2(a, b) => {
23                    masks[a.0 as usize] | masks[b.0 as usize]
24                }
25                Node::Neg(a)
26                | Node::Recip(a)
27                | Node::Sqrt(a)
28                | Node::Sin(a)
29                | Node::Exp2(a)
30                | Node::Log2(a) => masks[a.0 as usize],
31                Node::Select(c, a, b) => {
32                    masks[c.0 as usize] | masks[a.0 as usize] | masks[b.0 as usize]
33                }
34            };
35            masks[i] = m;
36        }
37
38        masks[expr.0 as usize]
39    }
40
41    /// Compute the Jacobian sparsity pattern.
42    ///
43    /// Returns one `u64` bitmask per output expression. Bit `j` of `result[i]`
44    /// is set if `outputs[i]` depends on `Var(j)`.
45    pub fn jacobian_sparsity(&self, outputs: &[ExprId], n_vars: usize) -> Vec<u64> {
46        if outputs.is_empty() {
47            return Vec::new();
48        }
49
50        let max_id = outputs.iter().map(|e| e.0).max().unwrap() as usize;
51        let n = max_id + 1;
52        let mut masks = vec![0u64; n];
53
54        for i in 0..n {
55            let m = match self.node(ExprId(i as u32)) {
56                Node::Var(idx) => {
57                    if (idx as usize) < n_vars {
58                        1u64 << idx
59                    } else {
60                        0
61                    }
62                }
63                Node::Lit(_) => 0,
64                Node::Add(a, b) | Node::Mul(a, b) | Node::Atan2(a, b) => {
65                    masks[a.0 as usize] | masks[b.0 as usize]
66                }
67                Node::Neg(a)
68                | Node::Recip(a)
69                | Node::Sqrt(a)
70                | Node::Sin(a)
71                | Node::Exp2(a)
72                | Node::Log2(a) => masks[a.0 as usize],
73                Node::Select(c, a, b) => {
74                    masks[c.0 as usize] | masks[a.0 as usize] | masks[b.0 as usize]
75                }
76            };
77            masks[i] = m;
78        }
79
80        outputs.iter().map(|e| masks[e.0 as usize]).collect()
81    }
82}
83
#[cfg(test)]
mod tests {
    // Unit tests for `ExprGraph::deps` and `ExprGraph::jacobian_sparsity`.
    use crate::graph::ExprGraph;

    #[test]
    fn deps_var() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        assert_eq!(g.deps(x), 0b1);
        let y = g.var(1);
        assert_eq!(g.deps(y), 0b10);
    }

    #[test]
    fn deps_lit() {
        let mut g = ExprGraph::new();
        let c = g.lit(42.0);
        assert_eq!(g.deps(c), 0);
    }

    #[test]
    fn deps_add() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        let sum = g.add(x, y);
        assert_eq!(g.deps(sum), 0b11);
    }

    #[test]
    fn deps_dot_product() {
        let mut g = ExprGraph::new();
        let x0 = g.var(0);
        let x1 = g.var(1);
        let x2 = g.var(2);
        let x3 = g.var(3);
        let x4 = g.var(4);
        let x5 = g.var(5);

        let t0 = g.mul(x0, x3);
        let t1 = g.mul(x1, x4);
        let t2 = g.mul(x2, x5);
        let s = g.add(t0, t1);
        let dot = g.add(s, t2);

        assert_eq!(g.deps(dot), 0b111111);
    }

    #[test]
    fn deps_through_unary() {
        let mut g = ExprGraph::new();
        let _ = g.var(0);
        let x = g.var(1);
        let s = g.sin(x);
        assert_eq!(g.deps(s), 0b10);
    }

    #[test]
    fn deps_ignores_nodes_after_expr() {
        // Nodes created after `sum` must not leak into its mask.
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        let sum = g.add(x, y);
        let z = g.var(2);
        let _ = g.mul(sum, z);
        assert_eq!(g.deps(sum), 0b11);
    }

    #[test]
    fn jacobian_sparsity_basic() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        let z = g.var(2);

        let f0 = g.add(x, y); // depends on x0, x1
        let f1 = g.mul(y, z); // depends on x1, x2
        let f2 = g.sin(x); // depends on x0

        let sparsity = g.jacobian_sparsity(&[f0, f1, f2], 3);
        assert_eq!(sparsity[0], 0b011); // f0 depends on x0, x1
        assert_eq!(sparsity[1], 0b110); // f1 depends on x1, x2
        assert_eq!(sparsity[2], 0b001); // f2 depends on x0
    }

    #[test]
    fn jacobian_sparsity_empty_outputs() {
        let mut g = ExprGraph::new();
        let _ = g.var(0);
        assert!(g.jacobian_sparsity(&[], 1).is_empty());
    }

    #[test]
    fn jacobian_sparsity_ignores_vars_beyond_n_vars() {
        // With n_vars = 2, Var(2) is treated as a constant.
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let z = g.var(2);
        let f = g.mul(x, z);
        assert_eq!(g.jacobian_sparsity(&[f, z], 2), vec![0b001, 0]);
    }

    #[test]
    fn jacobian_sparsity_constant_output() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let c = g.lit(1.5);
        let f = g.add(x, c);
        assert_eq!(g.jacobian_sparsity(&[c, f], 1), vec![0, 0b1]);
    }
}