1use crate::graph::ExprGraph;
4use crate::node::{ExprId, Node};
5
6impl ExprGraph {
7 pub fn deps(&self, expr: ExprId) -> u64 {
12 let n = expr.0 as usize + 1;
13 let mut masks = vec![0u64; n];
14
15 for i in 0..n {
16 let m = match self.node(ExprId(i as u32)) {
17 Node::Var(idx) => {
18 assert!(idx < 64, "deps() supports at most 64 variables");
19 1u64 << idx
20 }
21 Node::Lit(_) => 0,
22 Node::Add(a, b) | Node::Mul(a, b) | Node::Atan2(a, b) => {
23 masks[a.0 as usize] | masks[b.0 as usize]
24 }
25 Node::Neg(a)
26 | Node::Recip(a)
27 | Node::Sqrt(a)
28 | Node::Sin(a)
29 | Node::Exp2(a)
30 | Node::Log2(a) => masks[a.0 as usize],
31 Node::Select(c, a, b) => {
32 masks[c.0 as usize] | masks[a.0 as usize] | masks[b.0 as usize]
33 }
34 };
35 masks[i] = m;
36 }
37
38 masks[expr.0 as usize]
39 }
40
41 pub fn jacobian_sparsity(&self, outputs: &[ExprId], n_vars: usize) -> Vec<u64> {
46 if outputs.is_empty() {
47 return Vec::new();
48 }
49
50 let max_id = outputs.iter().map(|e| e.0).max().unwrap() as usize;
51 let n = max_id + 1;
52 let mut masks = vec![0u64; n];
53
54 for i in 0..n {
55 let m = match self.node(ExprId(i as u32)) {
56 Node::Var(idx) => {
57 if (idx as usize) < n_vars {
58 1u64 << idx
59 } else {
60 0
61 }
62 }
63 Node::Lit(_) => 0,
64 Node::Add(a, b) | Node::Mul(a, b) | Node::Atan2(a, b) => {
65 masks[a.0 as usize] | masks[b.0 as usize]
66 }
67 Node::Neg(a)
68 | Node::Recip(a)
69 | Node::Sqrt(a)
70 | Node::Sin(a)
71 | Node::Exp2(a)
72 | Node::Log2(a) => masks[a.0 as usize],
73 Node::Select(c, a, b) => {
74 masks[c.0 as usize] | masks[a.0 as usize] | masks[b.0 as usize]
75 }
76 };
77 masks[i] = m;
78 }
79
80 outputs.iter().map(|e| masks[e.0 as usize]).collect()
81 }
82}
83
#[cfg(test)]
mod tests {
    use crate::graph::ExprGraph;

    #[test]
    fn deps_var() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        assert_eq!(g.deps(x), 0b1);
        let y = g.var(1);
        assert_eq!(g.deps(y), 0b10);
    }

    #[test]
    fn deps_lit() {
        let mut g = ExprGraph::new();
        let c = g.lit(42.0);
        assert_eq!(g.deps(c), 0);
    }

    #[test]
    fn deps_add() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        let sum = g.add(x, y);
        assert_eq!(g.deps(sum), 0b11);
    }

    #[test]
    fn deps_ignores_literal_operands() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let c = g.lit(2.0);
        let e = g.add(x, c);
        assert_eq!(g.deps(e), 0b1);
    }

    #[test]
    fn deps_through_unary() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(3);
        let p = g.mul(x, y);
        let s = g.sin(p);
        assert_eq!(g.deps(s), 0b1001);
    }

    #[test]
    fn deps_dot_product() {
        let mut g = ExprGraph::new();
        let x0 = g.var(0);
        let x1 = g.var(1);
        let x2 = g.var(2);
        let x3 = g.var(3);
        let x4 = g.var(4);
        let x5 = g.var(5);

        let t0 = g.mul(x0, x3);
        let t1 = g.mul(x1, x4);
        let t2 = g.mul(x2, x5);
        let s = g.add(t0, t1);
        let dot = g.add(s, t2);

        assert_eq!(g.deps(dot), 0b111111);
    }

    #[test]
    fn jacobian_sparsity_basic() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        let z = g.var(2);

        let f0 = g.add(x, y);
        let f1 = g.mul(y, z);
        let f2 = g.sin(x);

        let sparsity = g.jacobian_sparsity(&[f0, f1, f2], 3);
        assert_eq!(sparsity, vec![0b011_u64, 0b110, 0b001]);
    }

    #[test]
    fn jacobian_sparsity_empty_outputs() {
        let g = ExprGraph::new();
        assert!(g.jacobian_sparsity(&[], 3).is_empty());
    }

    #[test]
    fn jacobian_sparsity_ignores_vars_beyond_n_vars_and_keeps_order() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        let w = g.var(5);
        let f0 = g.add(x, w);
        let f1 = g.mul(y, w);

        assert_eq!(g.jacobian_sparsity(&[f0, f1], 2), vec![0b01_u64, 0b10]);
        assert_eq!(g.jacobian_sparsity(&[f1, f0], 2), vec![0b10_u64, 0b01]);
    }
}