Skip to main content

ExprGraph

Struct ExprGraph 

Source
pub struct ExprGraph { /* private fields */ }
Expand description

Arena-based expression graph with structural interning.

Identical subexpressions always return the same ExprId — this gives automatic common subexpression elimination (CSE) for free.

Implementations§

Source§

impl ExprGraph

Source

pub fn to_kernel( &self, outputs: &[ExprId], n_inputs: usize, dialect: Dialect, ) -> ComputeKernel

Generate a compute kernel for the given dialect.

Each work item reads n_inputs f32 values and writes outputs.len() f32 values. Shared subexpressions are computed once per thread.

Source§

impl ExprGraph

Source

pub fn compile(&self, expr: ExprId) -> CompiledExpr

Compile a single expression to a closure &[f64] -> f64.

Shared subexpressions (from interning) are computed once. Dead nodes (not reachable from output) are skipped.

Examples found in repository?
examples/symbolic_regression.rs (line 260)
257fn fitness(expr: &Expr, data: &[(f64, f64)]) -> f64 {
258    let mut g = ExprGraph::new();
259    let root = expr.to_expr(&mut g);
260    let compiled = g.compile(root);
261
262    let mut mse = 0.0;
263    for &(x, y) in data {
264        let pred = compiled(&[x]);
265        if pred.is_nan() || pred.is_infinite() {
266            return f64::INFINITY;
267        }
268        let err = pred - y;
269        mse += err * err;
270    }
271    mse /= data.len() as f64;
272
273    let penalty = 0.001 * expr.size() as f64;
274    mse + penalty
275}
276
277// --- Selection ---------------------------------------------------------------
278
279fn tournament<'a>(pop: &'a [(Expr, f64)], k: usize, rng: &mut Lcg) -> &'a Expr {
280    let mut best_idx = rng.range(pop.len());
281    for _ in 1..k {
282        let idx = rng.range(pop.len());
283        if pop[idx].1 < pop[best_idx].1 {
284            best_idx = idx;
285        }
286    }
287    &pop[best_idx].0
288}
289
290// --- Main --------------------------------------------------------------------
291
292fn main() {
293    println!("=== Symbolic Regression ===\n");
294    println!("target: f(x) = x^2 * sin(x)\n");
295
296    let mut rng = Lcg::new(42);
297    let data = generate_data(50, &mut rng);
298
299    const POP_SIZE: usize = 200;
300    const GENERATIONS: usize = 100;
301    const TOURNAMENT_K: usize = 5;
302
303    // Initialize population
304    let mut population: Vec<(Expr, f64)> = (0..POP_SIZE)
305        .map(|_| {
306            let expr = random_expr(4, &mut rng);
307            let fit = fitness(&expr, &data);
308            (expr, fit)
309        })
310        .collect();
311
312    let mut best_expr: Option<Expr> = None;
313    let mut best_fitness = f64::INFINITY;
314
315    for gen in 0..GENERATIONS {
316        // Track best
317        for (expr, fit) in &population {
318            if *fit < best_fitness {
319                best_fitness = *fit;
320                best_expr = Some(expr.clone());
321            }
322        }
323
324        if (gen + 1) % 10 == 0 || gen == 0 {
325            let b = best_expr.as_ref().unwrap();
326            println!(
327                "gen {:>3}: best fitness = {:.6}  nodes = {:>3}  expr = {}",
328                gen + 1,
329                best_fitness,
330                b.size(),
331                b.format(),
332            );
333        }
334
335        // Build next generation
336        let mut next_pop = Vec::with_capacity(POP_SIZE);
337
338        // Elitism: keep top 5
339        let mut indices: Vec<usize> = (0..population.len()).collect();
340        indices.sort_by(|&a, &b| population[a].1.partial_cmp(&population[b].1).unwrap());
341        for &i in indices.iter().take(5) {
342            next_pop.push(population[i].clone());
343        }
344
345        // Fill rest via mutation
346        while next_pop.len() < POP_SIZE {
347            let parent = tournament(&population, TOURNAMENT_K, &mut rng);
348            let child = match rng.range(10) {
349                0..=3 => mutate_grow(parent, &mut rng),
350                4..=6 => mutate_point(parent, &mut rng),
351                7..=8 => mutate_simplify(parent),
352                _ => random_expr(4, &mut rng), // fresh blood
353            };
354
355            // Skip overly large expressions
356            if child.size() > 50 {
357                continue;
358            }
359
360            let fit = fitness(&child, &data);
361            next_pop.push((child, fit));
362        }
363
364        population = next_pop;
365    }
366
367    // Final results
368    println!();
369    let best = best_expr.unwrap();
370    let mut g = ExprGraph::new();
371    let root = best.to_expr(&mut g);
372    let simplified = g.simplify(root);
373    let expr_str = g.fmt_expr(simplified);
374    println!("best expression: {}", expr_str);
375    println!("best fitness:    {:.6}", best_fitness);
376
377    // Verify on test points
378    println!("\nverification:");
379    let compiled = g.compile(simplified);
380    for &x in &[-2.0, -1.0, 0.0, 1.0, 2.0] {
381        let pred = compiled(&[x]);
382        let exact = target(x);
383        println!(
384            "  f({:>5.1}) = {:>8.4}  (predicted: {:>8.4}, error: {:>8.4})",
385            x, exact, pred, (pred - exact).abs()
386        );
387    }
388
389    // Symbolic derivative via ExprGraph
390    let dx = g.diff(simplified, 0);
391    let dx = g.simplify(dx);
392    println!("\nsymbolic derivative: d/dx [{}] = {}", expr_str, g.fmt_expr(dx));
393}
Source

pub fn compile_many(&self, exprs: &[ExprId]) -> CompiledMany

Compile multiple output expressions to a single closure.

The resulting closure writes its results into the output slice.

Source

pub fn live_set(&self, outputs: &[ExprId]) -> HashSet<usize>

Find all node indices reachable from the given outputs.

Source§

impl ExprGraph

Source

pub fn diff(&mut self, expr: ExprId, var: u16) -> ExprId

Differentiate expr with respect to Var(var).

Returns a new ExprId in the same graph. Uses memoization to avoid recomputing derivatives of shared subexpressions.

Examples found in repository?
examples/symbolic_regression.rs (line 390)
292fn main() {
293    println!("=== Symbolic Regression ===\n");
294    println!("target: f(x) = x^2 * sin(x)\n");
295
296    let mut rng = Lcg::new(42);
297    let data = generate_data(50, &mut rng);
298
299    const POP_SIZE: usize = 200;
300    const GENERATIONS: usize = 100;
301    const TOURNAMENT_K: usize = 5;
302
303    // Initialize population
304    let mut population: Vec<(Expr, f64)> = (0..POP_SIZE)
305        .map(|_| {
306            let expr = random_expr(4, &mut rng);
307            let fit = fitness(&expr, &data);
308            (expr, fit)
309        })
310        .collect();
311
312    let mut best_expr: Option<Expr> = None;
313    let mut best_fitness = f64::INFINITY;
314
315    for gen in 0..GENERATIONS {
316        // Track best
317        for (expr, fit) in &population {
318            if *fit < best_fitness {
319                best_fitness = *fit;
320                best_expr = Some(expr.clone());
321            }
322        }
323
324        if (gen + 1) % 10 == 0 || gen == 0 {
325            let b = best_expr.as_ref().unwrap();
326            println!(
327                "gen {:>3}: best fitness = {:.6}  nodes = {:>3}  expr = {}",
328                gen + 1,
329                best_fitness,
330                b.size(),
331                b.format(),
332            );
333        }
334
335        // Build next generation
336        let mut next_pop = Vec::with_capacity(POP_SIZE);
337
338        // Elitism: keep top 5
339        let mut indices: Vec<usize> = (0..population.len()).collect();
340        indices.sort_by(|&a, &b| population[a].1.partial_cmp(&population[b].1).unwrap());
341        for &i in indices.iter().take(5) {
342            next_pop.push(population[i].clone());
343        }
344
345        // Fill rest via mutation
346        while next_pop.len() < POP_SIZE {
347            let parent = tournament(&population, TOURNAMENT_K, &mut rng);
348            let child = match rng.range(10) {
349                0..=3 => mutate_grow(parent, &mut rng),
350                4..=6 => mutate_point(parent, &mut rng),
351                7..=8 => mutate_simplify(parent),
352                _ => random_expr(4, &mut rng), // fresh blood
353            };
354
355            // Skip overly large expressions
356            if child.size() > 50 {
357                continue;
358            }
359
360            let fit = fitness(&child, &data);
361            next_pop.push((child, fit));
362        }
363
364        population = next_pop;
365    }
366
367    // Final results
368    println!();
369    let best = best_expr.unwrap();
370    let mut g = ExprGraph::new();
371    let root = best.to_expr(&mut g);
372    let simplified = g.simplify(root);
373    let expr_str = g.fmt_expr(simplified);
374    println!("best expression: {}", expr_str);
375    println!("best fitness:    {:.6}", best_fitness);
376
377    // Verify on test points
378    println!("\nverification:");
379    let compiled = g.compile(simplified);
380    for &x in &[-2.0, -1.0, 0.0, 1.0, 2.0] {
381        let pred = compiled(&[x]);
382        let exact = target(x);
383        println!(
384            "  f({:>5.1}) = {:>8.4}  (predicted: {:>8.4}, error: {:>8.4})",
385            x, exact, pred, (pred - exact).abs()
386        );
387    }
388
389    // Symbolic derivative via ExprGraph
390    let dx = g.diff(simplified, 0);
391    let dx = g.simplify(dx);
392    println!("\nsymbolic derivative: d/dx [{}] = {}", expr_str, g.fmt_expr(dx));
393}
Source§

impl ExprGraph

Source

pub fn fmt_expr(&self, expr: ExprId) -> String

Format an expression as a human-readable string.

Examples found in repository?
examples/symbolic_regression.rs (line 87)
84    fn format(&self) -> String {
85        let mut g = ExprGraph::new();
86        let root = self.to_expr(&mut g);
87        g.fmt_expr(root)
88    }
89
90    /// Evaluate at a point by building an ExprGraph and calling its graph-walking `eval` (not `compile`).
91    fn eval_at(&self, x: f64) -> f64 {
92        let mut g = ExprGraph::new();
93        let root = self.to_expr(&mut g);
94        g.eval(root, &[x])
95    }
96}
97
98// --- Random expression generation -------------------------------------------
99
100fn random_expr(depth: usize, rng: &mut Lcg) -> Expr {
101    if depth == 0 || (depth < 3 && rng.uniform() < 0.3) {
102        return match rng.range(3) {
103            0 => Expr::X,
104            1 => Expr::Lit((rng.uniform() * 4.0 - 2.0) * 10.0_f64.powi(-((rng.uniform() * 2.0) as i32))),
105            _ => Expr::X,
106        };
107    }
108
109    match rng.range(5) {
110        0 => Expr::Add(
111            Box::new(random_expr(depth - 1, rng)),
112            Box::new(random_expr(depth - 1, rng)),
113        ),
114        1 | 2 => Expr::Mul(
115            Box::new(random_expr(depth - 1, rng)),
116            Box::new(random_expr(depth - 1, rng)),
117        ),
118        3 => Expr::Sin(Box::new(random_expr(depth - 1, rng))),
119        _ => Expr::Neg(Box::new(random_expr(depth - 1, rng))),
120    }
121}
122
123// --- Mutation operators ------------------------------------------------------
124
125/// Grow mutation: replace a random subtree with a new random one.
126fn mutate_grow(expr: &Expr, rng: &mut Lcg) -> Expr {
127    let size = expr.size();
128    let target = rng.range(size);
129    grow_at(expr, target, &mut 0, rng)
130}
131
132fn grow_at(expr: &Expr, target: usize, counter: &mut usize, rng: &mut Lcg) -> Expr {
133    if *counter == target {
134        *counter += expr.size(); // skip subtree
135        return random_expr(3, rng);
136    }
137    *counter += 1;
138    match expr {
139        Expr::X => Expr::X,
140        Expr::Lit(v) => Expr::Lit(*v),
141        Expr::Add(a, b) => Expr::Add(
142            Box::new(grow_at(a, target, counter, rng)),
143            Box::new(grow_at(b, target, counter, rng)),
144        ),
145        Expr::Mul(a, b) => Expr::Mul(
146            Box::new(grow_at(a, target, counter, rng)),
147            Box::new(grow_at(b, target, counter, rng)),
148        ),
149        Expr::Sin(a) => Expr::Sin(Box::new(grow_at(a, target, counter, rng))),
150        Expr::Neg(a) => Expr::Neg(Box::new(grow_at(a, target, counter, rng))),
151    }
152}
153
154/// Point mutation: change a single node's operation.
155fn mutate_point(expr: &Expr, rng: &mut Lcg) -> Expr {
156    let size = expr.size();
157    let target = rng.range(size);
158    point_at(expr, target, &mut 0, rng)
159}
160
161fn point_at(expr: &Expr, target: usize, counter: &mut usize, rng: &mut Lcg) -> Expr {
162    if *counter == target {
163        *counter += 1;
164        return match expr {
165            Expr::X => Expr::Lit(rng.uniform() * 2.0 - 1.0),
166            Expr::Lit(_) => Expr::X,
167            Expr::Add(a, b) => Expr::Mul(a.clone(), b.clone()),
168            Expr::Mul(a, b) => Expr::Add(a.clone(), b.clone()),
169            Expr::Sin(a) => Expr::Neg(a.clone()),
170            Expr::Neg(a) => Expr::Sin(a.clone()),
171        };
172    }
173    *counter += 1;
174    match expr {
175        Expr::X => Expr::X,
176        Expr::Lit(v) => Expr::Lit(*v),
177        Expr::Add(a, b) => Expr::Add(
178            Box::new(point_at(a, target, counter, rng)),
179            Box::new(point_at(b, target, counter, rng)),
180        ),
181        Expr::Mul(a, b) => Expr::Mul(
182            Box::new(point_at(a, target, counter, rng)),
183            Box::new(point_at(b, target, counter, rng)),
184        ),
185        Expr::Sin(a) => Expr::Sin(Box::new(point_at(a, target, counter, rng))),
186        Expr::Neg(a) => Expr::Neg(Box::new(point_at(a, target, counter, rng))),
187    }
188}
189
190/// Simplify mutation: convert to ExprGraph, simplify, convert back.
191/// Falls back to a clone of the input unless the simplified form prints shorter and agrees with the original at x = 1.0.
192fn mutate_simplify(expr: &Expr) -> Expr {
193    let mut g = ExprGraph::new();
194    let root = expr.to_expr(&mut g);
195    let simplified = g.simplify(root);
196    let s = g.fmt_expr(simplified);
197
198    // Quick check: only accept the simplified form if its printed string is strictly shorter
199    let orig = expr.format();
200    if s.len() < orig.len() {
201        // Re-evaluate to confirm it still works
202        let test = g.eval::<f64>(simplified, &[1.0]);
203        let orig_test = expr.eval_at(1.0);
204        if (test - orig_test).abs() < 1e-10 || (test.is_nan() && orig_test.is_nan()) {
205            // Return a new tree built from the simplified graph
206            return expr_from_graph(&g, simplified);
207        }
208    }
209    expr.clone()
210}
211
212/// Reconstruct an Expr tree from an ExprGraph node.
213fn expr_from_graph(g: &ExprGraph, id: ExprId) -> Expr {
214    match g.node(id) {
215        tang_expr::node::Node::Var(_) => Expr::X,
216        tang_expr::node::Node::Lit(bits) => {
217            let v = f64::from_bits(bits);
218            if v == 0.0 {
219                Expr::Lit(0.0)
220            } else {
221                Expr::Lit(v)
222            }
223        }
224        tang_expr::node::Node::Add(a, b) => {
225            Expr::Add(Box::new(expr_from_graph(g, a)), Box::new(expr_from_graph(g, b)))
226        }
227        tang_expr::node::Node::Mul(a, b) => {
228            Expr::Mul(Box::new(expr_from_graph(g, a)), Box::new(expr_from_graph(g, b)))
229        }
230        tang_expr::node::Node::Neg(a) => Expr::Neg(Box::new(expr_from_graph(g, a))),
231        tang_expr::node::Node::Sin(a) => Expr::Sin(Box::new(expr_from_graph(g, a))),
232        // For operations our AST can't represent, collapse the node to a literal sampled at x = 1.0 (exact only if the subexpression is constant)
233        _ => {
234            let v = g.eval::<f64>(id, &[1.0]); // fallback
235            Expr::Lit(v)
236        }
237    }
238}
239
240// --- Fitness evaluation ------------------------------------------------------
241
242fn target(x: f64) -> f64 {
243    x * x * x.sin()
244}
245
246fn generate_data(n: usize, rng: &mut Lcg) -> Vec<(f64, f64)> {
247    (0..n)
248        .map(|i| {
249            let x = -3.0 + 6.0 * i as f64 / (n - 1) as f64;
250            let noise = (rng.uniform() - 0.5) * 0.01;
251            (x, target(x) + noise)
252        })
253        .collect()
254}
255
256/// Evaluate fitness: MSE + complexity penalty.
257fn fitness(expr: &Expr, data: &[(f64, f64)]) -> f64 {
258    let mut g = ExprGraph::new();
259    let root = expr.to_expr(&mut g);
260    let compiled = g.compile(root);
261
262    let mut mse = 0.0;
263    for &(x, y) in data {
264        let pred = compiled(&[x]);
265        if pred.is_nan() || pred.is_infinite() {
266            return f64::INFINITY;
267        }
268        let err = pred - y;
269        mse += err * err;
270    }
271    mse /= data.len() as f64;
272
273    let penalty = 0.001 * expr.size() as f64;
274    mse + penalty
275}
276
277// --- Selection ---------------------------------------------------------------
278
279fn tournament<'a>(pop: &'a [(Expr, f64)], k: usize, rng: &mut Lcg) -> &'a Expr {
280    let mut best_idx = rng.range(pop.len());
281    for _ in 1..k {
282        let idx = rng.range(pop.len());
283        if pop[idx].1 < pop[best_idx].1 {
284            best_idx = idx;
285        }
286    }
287    &pop[best_idx].0
288}
289
290// --- Main --------------------------------------------------------------------
291
292fn main() {
293    println!("=== Symbolic Regression ===\n");
294    println!("target: f(x) = x^2 * sin(x)\n");
295
296    let mut rng = Lcg::new(42);
297    let data = generate_data(50, &mut rng);
298
299    const POP_SIZE: usize = 200;
300    const GENERATIONS: usize = 100;
301    const TOURNAMENT_K: usize = 5;
302
303    // Initialize population
304    let mut population: Vec<(Expr, f64)> = (0..POP_SIZE)
305        .map(|_| {
306            let expr = random_expr(4, &mut rng);
307            let fit = fitness(&expr, &data);
308            (expr, fit)
309        })
310        .collect();
311
312    let mut best_expr: Option<Expr> = None;
313    let mut best_fitness = f64::INFINITY;
314
315    for gen in 0..GENERATIONS {
316        // Track best
317        for (expr, fit) in &population {
318            if *fit < best_fitness {
319                best_fitness = *fit;
320                best_expr = Some(expr.clone());
321            }
322        }
323
324        if (gen + 1) % 10 == 0 || gen == 0 {
325            let b = best_expr.as_ref().unwrap();
326            println!(
327                "gen {:>3}: best fitness = {:.6}  nodes = {:>3}  expr = {}",
328                gen + 1,
329                best_fitness,
330                b.size(),
331                b.format(),
332            );
333        }
334
335        // Build next generation
336        let mut next_pop = Vec::with_capacity(POP_SIZE);
337
338        // Elitism: keep top 5
339        let mut indices: Vec<usize> = (0..population.len()).collect();
340        indices.sort_by(|&a, &b| population[a].1.partial_cmp(&population[b].1).unwrap());
341        for &i in indices.iter().take(5) {
342            next_pop.push(population[i].clone());
343        }
344
345        // Fill rest via mutation
346        while next_pop.len() < POP_SIZE {
347            let parent = tournament(&population, TOURNAMENT_K, &mut rng);
348            let child = match rng.range(10) {
349                0..=3 => mutate_grow(parent, &mut rng),
350                4..=6 => mutate_point(parent, &mut rng),
351                7..=8 => mutate_simplify(parent),
352                _ => random_expr(4, &mut rng), // fresh blood
353            };
354
355            // Skip overly large expressions
356            if child.size() > 50 {
357                continue;
358            }
359
360            let fit = fitness(&child, &data);
361            next_pop.push((child, fit));
362        }
363
364        population = next_pop;
365    }
366
367    // Final results
368    println!();
369    let best = best_expr.unwrap();
370    let mut g = ExprGraph::new();
371    let root = best.to_expr(&mut g);
372    let simplified = g.simplify(root);
373    let expr_str = g.fmt_expr(simplified);
374    println!("best expression: {}", expr_str);
375    println!("best fitness:    {:.6}", best_fitness);
376
377    // Verify on test points
378    println!("\nverification:");
379    let compiled = g.compile(simplified);
380    for &x in &[-2.0, -1.0, 0.0, 1.0, 2.0] {
381        let pred = compiled(&[x]);
382        let exact = target(x);
383        println!(
384            "  f({:>5.1}) = {:>8.4}  (predicted: {:>8.4}, error: {:>8.4})",
385            x, exact, pred, (pred - exact).abs()
386        );
387    }
388
389    // Symbolic derivative via ExprGraph
390    let dx = g.diff(simplified, 0);
391    let dx = g.simplify(dx);
392    println!("\nsymbolic derivative: d/dx [{}] = {}", expr_str, g.fmt_expr(dx));
393}
Source§

impl ExprGraph

Source

pub fn eval<S: Scalar>(&self, expr: ExprId, inputs: &[S]) -> S

Evaluate an expression with concrete scalar inputs.

inputs[n] provides the value for Var(n). Walks the graph in topological order (which is just index order, since children are always created before parents).

Examples found in repository?
examples/symbolic_regression.rs (line 94)
91    fn eval_at(&self, x: f64) -> f64 {
92        let mut g = ExprGraph::new();
93        let root = self.to_expr(&mut g);
94        g.eval(root, &[x])
95    }
96}
97
98// --- Random expression generation -------------------------------------------
99
100fn random_expr(depth: usize, rng: &mut Lcg) -> Expr {
101    if depth == 0 || (depth < 3 && rng.uniform() < 0.3) {
102        return match rng.range(3) {
103            0 => Expr::X,
104            1 => Expr::Lit((rng.uniform() * 4.0 - 2.0) * 10.0_f64.powi(-((rng.uniform() * 2.0) as i32))),
105            _ => Expr::X,
106        };
107    }
108
109    match rng.range(5) {
110        0 => Expr::Add(
111            Box::new(random_expr(depth - 1, rng)),
112            Box::new(random_expr(depth - 1, rng)),
113        ),
114        1 | 2 => Expr::Mul(
115            Box::new(random_expr(depth - 1, rng)),
116            Box::new(random_expr(depth - 1, rng)),
117        ),
118        3 => Expr::Sin(Box::new(random_expr(depth - 1, rng))),
119        _ => Expr::Neg(Box::new(random_expr(depth - 1, rng))),
120    }
121}
122
123// --- Mutation operators ------------------------------------------------------
124
125/// Grow mutation: replace a random subtree with a new random one.
126fn mutate_grow(expr: &Expr, rng: &mut Lcg) -> Expr {
127    let size = expr.size();
128    let target = rng.range(size);
129    grow_at(expr, target, &mut 0, rng)
130}
131
132fn grow_at(expr: &Expr, target: usize, counter: &mut usize, rng: &mut Lcg) -> Expr {
133    if *counter == target {
134        *counter += expr.size(); // skip subtree
135        return random_expr(3, rng);
136    }
137    *counter += 1;
138    match expr {
139        Expr::X => Expr::X,
140        Expr::Lit(v) => Expr::Lit(*v),
141        Expr::Add(a, b) => Expr::Add(
142            Box::new(grow_at(a, target, counter, rng)),
143            Box::new(grow_at(b, target, counter, rng)),
144        ),
145        Expr::Mul(a, b) => Expr::Mul(
146            Box::new(grow_at(a, target, counter, rng)),
147            Box::new(grow_at(b, target, counter, rng)),
148        ),
149        Expr::Sin(a) => Expr::Sin(Box::new(grow_at(a, target, counter, rng))),
150        Expr::Neg(a) => Expr::Neg(Box::new(grow_at(a, target, counter, rng))),
151    }
152}
153
154/// Point mutation: change a single node's operation.
155fn mutate_point(expr: &Expr, rng: &mut Lcg) -> Expr {
156    let size = expr.size();
157    let target = rng.range(size);
158    point_at(expr, target, &mut 0, rng)
159}
160
161fn point_at(expr: &Expr, target: usize, counter: &mut usize, rng: &mut Lcg) -> Expr {
162    if *counter == target {
163        *counter += 1;
164        return match expr {
165            Expr::X => Expr::Lit(rng.uniform() * 2.0 - 1.0),
166            Expr::Lit(_) => Expr::X,
167            Expr::Add(a, b) => Expr::Mul(a.clone(), b.clone()),
168            Expr::Mul(a, b) => Expr::Add(a.clone(), b.clone()),
169            Expr::Sin(a) => Expr::Neg(a.clone()),
170            Expr::Neg(a) => Expr::Sin(a.clone()),
171        };
172    }
173    *counter += 1;
174    match expr {
175        Expr::X => Expr::X,
176        Expr::Lit(v) => Expr::Lit(*v),
177        Expr::Add(a, b) => Expr::Add(
178            Box::new(point_at(a, target, counter, rng)),
179            Box::new(point_at(b, target, counter, rng)),
180        ),
181        Expr::Mul(a, b) => Expr::Mul(
182            Box::new(point_at(a, target, counter, rng)),
183            Box::new(point_at(b, target, counter, rng)),
184        ),
185        Expr::Sin(a) => Expr::Sin(Box::new(point_at(a, target, counter, rng))),
186        Expr::Neg(a) => Expr::Neg(Box::new(point_at(a, target, counter, rng))),
187    }
188}
189
190/// Simplify mutation: convert to ExprGraph, simplify, convert back.
191/// Falls back to a clone of the input unless the simplified form prints shorter and agrees with the original at x = 1.0.
192fn mutate_simplify(expr: &Expr) -> Expr {
193    let mut g = ExprGraph::new();
194    let root = expr.to_expr(&mut g);
195    let simplified = g.simplify(root);
196    let s = g.fmt_expr(simplified);
197
198    // Quick check: only accept the simplified form if its printed string is strictly shorter
199    let orig = expr.format();
200    if s.len() < orig.len() {
201        // Re-evaluate to confirm it still works
202        let test = g.eval::<f64>(simplified, &[1.0]);
203        let orig_test = expr.eval_at(1.0);
204        if (test - orig_test).abs() < 1e-10 || (test.is_nan() && orig_test.is_nan()) {
205            // Return a new tree built from the simplified graph
206            return expr_from_graph(&g, simplified);
207        }
208    }
209    expr.clone()
210}
211
212/// Reconstruct an Expr tree from an ExprGraph node.
213fn expr_from_graph(g: &ExprGraph, id: ExprId) -> Expr {
214    match g.node(id) {
215        tang_expr::node::Node::Var(_) => Expr::X,
216        tang_expr::node::Node::Lit(bits) => {
217            let v = f64::from_bits(bits);
218            if v == 0.0 {
219                Expr::Lit(0.0)
220            } else {
221                Expr::Lit(v)
222            }
223        }
224        tang_expr::node::Node::Add(a, b) => {
225            Expr::Add(Box::new(expr_from_graph(g, a)), Box::new(expr_from_graph(g, b)))
226        }
227        tang_expr::node::Node::Mul(a, b) => {
228            Expr::Mul(Box::new(expr_from_graph(g, a)), Box::new(expr_from_graph(g, b)))
229        }
230        tang_expr::node::Node::Neg(a) => Expr::Neg(Box::new(expr_from_graph(g, a))),
231        tang_expr::node::Node::Sin(a) => Expr::Sin(Box::new(expr_from_graph(g, a))),
232        // For operations our AST can't represent, collapse the node to a literal sampled at x = 1.0 (exact only if the subexpression is constant)
233        _ => {
234            let v = g.eval::<f64>(id, &[1.0]); // fallback
235            Expr::Lit(v)
236        }
237    }
238}
Source

pub fn eval_many<S: Scalar>(&self, exprs: &[ExprId], inputs: &[S]) -> Vec<S>

Evaluate multiple output expressions, sharing intermediate values.

Source§

impl ExprGraph

Source

pub fn new() -> Self

Create a new graph pre-populated with ZERO, ONE, TWO.

Examples found in repository?
examples/symbolic_regression.rs (line 85)
84    fn format(&self) -> String {
85        let mut g = ExprGraph::new();
86        let root = self.to_expr(&mut g);
87        g.fmt_expr(root)
88    }
89
90    /// Evaluate at a point using ExprGraph's compiled eval.
91    fn eval_at(&self, x: f64) -> f64 {
92        let mut g = ExprGraph::new();
93        let root = self.to_expr(&mut g);
94        g.eval(root, &[x])
95    }
96}
97
98// --- Random expression generation -------------------------------------------
99
100fn random_expr(depth: usize, rng: &mut Lcg) -> Expr {
101    if depth == 0 || (depth < 3 && rng.uniform() < 0.3) {
102        return match rng.range(3) {
103            0 => Expr::X,
104            1 => Expr::Lit((rng.uniform() * 4.0 - 2.0) * 10.0_f64.powi(-((rng.uniform() * 2.0) as i32))),
105            _ => Expr::X,
106        };
107    }
108
109    match rng.range(5) {
110        0 => Expr::Add(
111            Box::new(random_expr(depth - 1, rng)),
112            Box::new(random_expr(depth - 1, rng)),
113        ),
114        1 | 2 => Expr::Mul(
115            Box::new(random_expr(depth - 1, rng)),
116            Box::new(random_expr(depth - 1, rng)),
117        ),
118        3 => Expr::Sin(Box::new(random_expr(depth - 1, rng))),
119        _ => Expr::Neg(Box::new(random_expr(depth - 1, rng))),
120    }
121}
122
123// --- Mutation operators ------------------------------------------------------
124
125/// Grow mutation: replace a random subtree with a new random one.
126fn mutate_grow(expr: &Expr, rng: &mut Lcg) -> Expr {
127    let size = expr.size();
128    let target = rng.range(size);
129    grow_at(expr, target, &mut 0, rng)
130}
131
132fn grow_at(expr: &Expr, target: usize, counter: &mut usize, rng: &mut Lcg) -> Expr {
133    if *counter == target {
134        *counter += expr.size(); // skip subtree
135        return random_expr(3, rng);
136    }
137    *counter += 1;
138    match expr {
139        Expr::X => Expr::X,
140        Expr::Lit(v) => Expr::Lit(*v),
141        Expr::Add(a, b) => Expr::Add(
142            Box::new(grow_at(a, target, counter, rng)),
143            Box::new(grow_at(b, target, counter, rng)),
144        ),
145        Expr::Mul(a, b) => Expr::Mul(
146            Box::new(grow_at(a, target, counter, rng)),
147            Box::new(grow_at(b, target, counter, rng)),
148        ),
149        Expr::Sin(a) => Expr::Sin(Box::new(grow_at(a, target, counter, rng))),
150        Expr::Neg(a) => Expr::Neg(Box::new(grow_at(a, target, counter, rng))),
151    }
152}
153
154/// Point mutation: change a single node's operation.
155fn mutate_point(expr: &Expr, rng: &mut Lcg) -> Expr {
156    let size = expr.size();
157    let target = rng.range(size);
158    point_at(expr, target, &mut 0, rng)
159}
160
161fn point_at(expr: &Expr, target: usize, counter: &mut usize, rng: &mut Lcg) -> Expr {
162    if *counter == target {
163        *counter += 1;
164        return match expr {
165            Expr::X => Expr::Lit(rng.uniform() * 2.0 - 1.0),
166            Expr::Lit(_) => Expr::X,
167            Expr::Add(a, b) => Expr::Mul(a.clone(), b.clone()),
168            Expr::Mul(a, b) => Expr::Add(a.clone(), b.clone()),
169            Expr::Sin(a) => Expr::Neg(a.clone()),
170            Expr::Neg(a) => Expr::Sin(a.clone()),
171        };
172    }
173    *counter += 1;
174    match expr {
175        Expr::X => Expr::X,
176        Expr::Lit(v) => Expr::Lit(*v),
177        Expr::Add(a, b) => Expr::Add(
178            Box::new(point_at(a, target, counter, rng)),
179            Box::new(point_at(b, target, counter, rng)),
180        ),
181        Expr::Mul(a, b) => Expr::Mul(
182            Box::new(point_at(a, target, counter, rng)),
183            Box::new(point_at(b, target, counter, rng)),
184        ),
185        Expr::Sin(a) => Expr::Sin(Box::new(point_at(a, target, counter, rng))),
186        Expr::Neg(a) => Expr::Neg(Box::new(point_at(a, target, counter, rng))),
187    }
188}
189
190/// Simplify mutation: convert to ExprGraph, simplify, convert back.
191/// Falls back to identity if conversion back would be complex.
192fn mutate_simplify(expr: &Expr) -> Expr {
193    let mut g = ExprGraph::new();
194    let root = expr.to_expr(&mut g);
195    let simplified = g.simplify(root);
196    let s = g.fmt_expr(simplified);
197
198    // Quick check: if simplified form is much shorter, it's better
199    let orig = expr.format();
200    if s.len() < orig.len() {
201        // Re-evaluate to confirm it still works
202        let test = g.eval::<f64>(simplified, &[1.0]);
203        let orig_test = expr.eval_at(1.0);
204        if (test - orig_test).abs() < 1e-10 || (test.is_nan() && orig_test.is_nan()) {
205            // Return a new tree built from the simplified graph
206            return expr_from_graph(&g, simplified);
207        }
208    }
209    expr.clone()
210}
211
212/// Reconstruct an Expr tree from an ExprGraph node.
213fn expr_from_graph(g: &ExprGraph, id: ExprId) -> Expr {
214    match g.node(id) {
215        tang_expr::node::Node::Var(_) => Expr::X,
216        tang_expr::node::Node::Lit(bits) => {
217            let v = f64::from_bits(bits);
218            if v == 0.0 {
219                Expr::Lit(0.0)
220            } else {
221                Expr::Lit(v)
222            }
223        }
224        tang_expr::node::Node::Add(a, b) => {
225            Expr::Add(Box::new(expr_from_graph(g, a)), Box::new(expr_from_graph(g, b)))
226        }
227        tang_expr::node::Node::Mul(a, b) => {
228            Expr::Mul(Box::new(expr_from_graph(g, a)), Box::new(expr_from_graph(g, b)))
229        }
230        tang_expr::node::Node::Neg(a) => Expr::Neg(Box::new(expr_from_graph(g, a))),
231        tang_expr::node::Node::Sin(a) => Expr::Sin(Box::new(expr_from_graph(g, a))),
232        // For operations we don't represent in our AST, just evaluate as literal
233        _ => {
234            let v = g.eval::<f64>(id, &[1.0]); // fallback
235            Expr::Lit(v)
236        }
237    }
238}
239
240// --- Fitness evaluation ------------------------------------------------------
241
/// Ground-truth function the regression tries to recover: `x^2 * sin(x)`.
fn target(x: f64) -> f64 {
    let squared = x * x;
    squared * x.sin()
}
245
246fn generate_data(n: usize, rng: &mut Lcg) -> Vec<(f64, f64)> {
247    (0..n)
248        .map(|i| {
249            let x = -3.0 + 6.0 * i as f64 / (n - 1) as f64;
250            let noise = (rng.uniform() - 0.5) * 0.01;
251            (x, target(x) + noise)
252        })
253        .collect()
254}
255
256/// Evaluate fitness: MSE + complexity penalty.
257fn fitness(expr: &Expr, data: &[(f64, f64)]) -> f64 {
258    let mut g = ExprGraph::new();
259    let root = expr.to_expr(&mut g);
260    let compiled = g.compile(root);
261
262    let mut mse = 0.0;
263    for &(x, y) in data {
264        let pred = compiled(&[x]);
265        if pred.is_nan() || pred.is_infinite() {
266            return f64::INFINITY;
267        }
268        let err = pred - y;
269        mse += err * err;
270    }
271    mse /= data.len() as f64;
272
273    let penalty = 0.001 * expr.size() as f64;
274    mse + penalty
275}
276
277// --- Selection ---------------------------------------------------------------
278
279fn tournament<'a>(pop: &'a [(Expr, f64)], k: usize, rng: &mut Lcg) -> &'a Expr {
280    let mut best_idx = rng.range(pop.len());
281    for _ in 1..k {
282        let idx = rng.range(pop.len());
283        if pop[idx].1 < pop[best_idx].1 {
284            best_idx = idx;
285        }
286    }
287    &pop[best_idx].0
288}
289
290// --- Main --------------------------------------------------------------------
291
292fn main() {
293    println!("=== Symbolic Regression ===\n");
294    println!("target: f(x) = x^2 * sin(x)\n");
295
296    let mut rng = Lcg::new(42);
297    let data = generate_data(50, &mut rng);
298
299    const POP_SIZE: usize = 200;
300    const GENERATIONS: usize = 100;
301    const TOURNAMENT_K: usize = 5;
302
303    // Initialize population
304    let mut population: Vec<(Expr, f64)> = (0..POP_SIZE)
305        .map(|_| {
306            let expr = random_expr(4, &mut rng);
307            let fit = fitness(&expr, &data);
308            (expr, fit)
309        })
310        .collect();
311
312    let mut best_expr: Option<Expr> = None;
313    let mut best_fitness = f64::INFINITY;
314
315    for gen in 0..GENERATIONS {
316        // Track best
317        for (expr, fit) in &population {
318            if *fit < best_fitness {
319                best_fitness = *fit;
320                best_expr = Some(expr.clone());
321            }
322        }
323
324        if (gen + 1) % 10 == 0 || gen == 0 {
325            let b = best_expr.as_ref().unwrap();
326            println!(
327                "gen {:>3}: best fitness = {:.6}  nodes = {:>3}  expr = {}",
328                gen + 1,
329                best_fitness,
330                b.size(),
331                b.format(),
332            );
333        }
334
335        // Build next generation
336        let mut next_pop = Vec::with_capacity(POP_SIZE);
337
338        // Elitism: keep top 5
339        let mut indices: Vec<usize> = (0..population.len()).collect();
340        indices.sort_by(|&a, &b| population[a].1.partial_cmp(&population[b].1).unwrap());
341        for &i in indices.iter().take(5) {
342            next_pop.push(population[i].clone());
343        }
344
345        // Fill rest via mutation
346        while next_pop.len() < POP_SIZE {
347            let parent = tournament(&population, TOURNAMENT_K, &mut rng);
348            let child = match rng.range(10) {
349                0..=3 => mutate_grow(parent, &mut rng),
350                4..=6 => mutate_point(parent, &mut rng),
351                7..=8 => mutate_simplify(parent),
352                _ => random_expr(4, &mut rng), // fresh blood
353            };
354
355            // Skip overly large expressions
356            if child.size() > 50 {
357                continue;
358            }
359
360            let fit = fitness(&child, &data);
361            next_pop.push((child, fit));
362        }
363
364        population = next_pop;
365    }
366
367    // Final results
368    println!();
369    let best = best_expr.unwrap();
370    let mut g = ExprGraph::new();
371    let root = best.to_expr(&mut g);
372    let simplified = g.simplify(root);
373    let expr_str = g.fmt_expr(simplified);
374    println!("best expression: {}", expr_str);
375    println!("best fitness:    {:.6}", best_fitness);
376
377    // Verify on test points
378    println!("\nverification:");
379    let compiled = g.compile(simplified);
380    for &x in &[-2.0, -1.0, 0.0, 1.0, 2.0] {
381        let pred = compiled(&[x]);
382        let exact = target(x);
383        println!(
384            "  f({:>5.1}) = {:>8.4}  (predicted: {:>8.4}, error: {:>8.4})",
385            x, exact, pred, (pred - exact).abs()
386        );
387    }
388
389    // Symbolic derivative via ExprGraph
390    let dx = g.diff(simplified, 0);
391    let dx = g.simplify(dx);
392    println!("\nsymbolic derivative: d/dx [{}] = {}", expr_str, g.fmt_expr(dx));
393}
Source

pub fn len(&self) -> usize

Total number of nodes in the graph.

Source

pub fn is_empty(&self) -> bool

Whether the graph is empty. A graph built with new() is never empty, because it is pre-populated with ZERO, ONE, TWO.

Source

pub fn node(&self, id: ExprId) -> Node

Look up the node for an ExprId.

Examples found in repository?
examples/symbolic_regression.rs (line 214)
213fn expr_from_graph(g: &ExprGraph, id: ExprId) -> Expr {
214    match g.node(id) {
215        tang_expr::node::Node::Var(_) => Expr::X,
216        tang_expr::node::Node::Lit(bits) => {
217            let v = f64::from_bits(bits);
218            if v == 0.0 {
219                Expr::Lit(0.0)
220            } else {
221                Expr::Lit(v)
222            }
223        }
224        tang_expr::node::Node::Add(a, b) => {
225            Expr::Add(Box::new(expr_from_graph(g, a)), Box::new(expr_from_graph(g, b)))
226        }
227        tang_expr::node::Node::Mul(a, b) => {
228            Expr::Mul(Box::new(expr_from_graph(g, a)), Box::new(expr_from_graph(g, b)))
229        }
230        tang_expr::node::Node::Neg(a) => Expr::Neg(Box::new(expr_from_graph(g, a))),
231        tang_expr::node::Node::Sin(a) => Expr::Sin(Box::new(expr_from_graph(g, a))),
232        // For operations we don't represent in our AST, just evaluate as literal
233        _ => {
234            let v = g.eval::<f64>(id, &[1.0]); // fallback
235            Expr::Lit(v)
236        }
237    }
238}
Source

pub fn nodes_slice(&self) -> &[Node]

Read-only access to the node arena for serialization.

Source

pub fn var(&mut self, n: u16) -> ExprId

Create a variable node.

Examples found in repository?
examples/symbolic_regression.rs (line 60)
58    fn to_expr(&self, g: &mut ExprGraph) -> ExprId {
59        match self {
60            Expr::X => g.var(0),
61            Expr::Lit(v) => g.lit(*v),
62            Expr::Add(a, b) => {
63                let a = a.to_expr(g);
64                let b = b.to_expr(g);
65                g.add(a, b)
66            }
67            Expr::Mul(a, b) => {
68                let a = a.to_expr(g);
69                let b = b.to_expr(g);
70                g.mul(a, b)
71            }
72            Expr::Sin(a) => {
73                let a = a.to_expr(g);
74                g.sin(a)
75            }
76            Expr::Neg(a) => {
77                let a = a.to_expr(g);
78                g.neg(a)
79            }
80        }
81    }
Source

pub fn lit(&mut self, v: f64) -> ExprId

Create a literal node.

Examples found in repository?
examples/symbolic_regression.rs (line 61)
58    fn to_expr(&self, g: &mut ExprGraph) -> ExprId {
59        match self {
60            Expr::X => g.var(0),
61            Expr::Lit(v) => g.lit(*v),
62            Expr::Add(a, b) => {
63                let a = a.to_expr(g);
64                let b = b.to_expr(g);
65                g.add(a, b)
66            }
67            Expr::Mul(a, b) => {
68                let a = a.to_expr(g);
69                let b = b.to_expr(g);
70                g.mul(a, b)
71            }
72            Expr::Sin(a) => {
73                let a = a.to_expr(g);
74                g.sin(a)
75            }
76            Expr::Neg(a) => {
77                let a = a.to_expr(g);
78                g.neg(a)
79            }
80        }
81    }
Source

pub fn add(&mut self, a: ExprId, b: ExprId) -> ExprId

Add two expressions.

Examples found in repository?
examples/symbolic_regression.rs (line 65)
58    fn to_expr(&self, g: &mut ExprGraph) -> ExprId {
59        match self {
60            Expr::X => g.var(0),
61            Expr::Lit(v) => g.lit(*v),
62            Expr::Add(a, b) => {
63                let a = a.to_expr(g);
64                let b = b.to_expr(g);
65                g.add(a, b)
66            }
67            Expr::Mul(a, b) => {
68                let a = a.to_expr(g);
69                let b = b.to_expr(g);
70                g.mul(a, b)
71            }
72            Expr::Sin(a) => {
73                let a = a.to_expr(g);
74                g.sin(a)
75            }
76            Expr::Neg(a) => {
77                let a = a.to_expr(g);
78                g.neg(a)
79            }
80        }
81    }
Source

pub fn mul(&mut self, a: ExprId, b: ExprId) -> ExprId

Multiply two expressions.

Examples found in repository?
examples/symbolic_regression.rs (line 70)
58    fn to_expr(&self, g: &mut ExprGraph) -> ExprId {
59        match self {
60            Expr::X => g.var(0),
61            Expr::Lit(v) => g.lit(*v),
62            Expr::Add(a, b) => {
63                let a = a.to_expr(g);
64                let b = b.to_expr(g);
65                g.add(a, b)
66            }
67            Expr::Mul(a, b) => {
68                let a = a.to_expr(g);
69                let b = b.to_expr(g);
70                g.mul(a, b)
71            }
72            Expr::Sin(a) => {
73                let a = a.to_expr(g);
74                g.sin(a)
75            }
76            Expr::Neg(a) => {
77                let a = a.to_expr(g);
78                g.neg(a)
79            }
80        }
81    }
Source

pub fn neg(&mut self, a: ExprId) -> ExprId

Negate an expression.

Examples found in repository?
examples/symbolic_regression.rs (line 78)
58    fn to_expr(&self, g: &mut ExprGraph) -> ExprId {
59        match self {
60            Expr::X => g.var(0),
61            Expr::Lit(v) => g.lit(*v),
62            Expr::Add(a, b) => {
63                let a = a.to_expr(g);
64                let b = b.to_expr(g);
65                g.add(a, b)
66            }
67            Expr::Mul(a, b) => {
68                let a = a.to_expr(g);
69                let b = b.to_expr(g);
70                g.mul(a, b)
71            }
72            Expr::Sin(a) => {
73                let a = a.to_expr(g);
74                g.sin(a)
75            }
76            Expr::Neg(a) => {
77                let a = a.to_expr(g);
78                g.neg(a)
79            }
80        }
81    }
Source

pub fn recip(&mut self, a: ExprId) -> ExprId

Reciprocal (1/x).

Source

pub fn sqrt(&mut self, a: ExprId) -> ExprId

Square root.

Source

pub fn sin(&mut self, a: ExprId) -> ExprId

Sine.

Examples found in repository?
examples/symbolic_regression.rs (line 74)
58    fn to_expr(&self, g: &mut ExprGraph) -> ExprId {
59        match self {
60            Expr::X => g.var(0),
61            Expr::Lit(v) => g.lit(*v),
62            Expr::Add(a, b) => {
63                let a = a.to_expr(g);
64                let b = b.to_expr(g);
65                g.add(a, b)
66            }
67            Expr::Mul(a, b) => {
68                let a = a.to_expr(g);
69                let b = b.to_expr(g);
70                g.mul(a, b)
71            }
72            Expr::Sin(a) => {
73                let a = a.to_expr(g);
74                g.sin(a)
75            }
76            Expr::Neg(a) => {
77                let a = a.to_expr(g);
78                g.neg(a)
79            }
80        }
81    }
Source

pub fn atan2(&mut self, y: ExprId, x: ExprId) -> ExprId

atan2(y, x).

Source

pub fn exp2(&mut self, a: ExprId) -> ExprId

Base-2 exponential.

Source

pub fn log2(&mut self, a: ExprId) -> ExprId

Base-2 logarithm.

Source

pub fn select(&mut self, cond: ExprId, a: ExprId, b: ExprId) -> ExprId

Branchless select: returns a if cond > 0, else b.

Source§

impl ExprGraph

Source

pub fn simplify(&mut self, expr: ExprId) -> ExprId

Simplify an expression by applying rewrite rules to fixpoint.

Bottom-up: simplify children first, then match parent. Iterates until no more changes occur.

Examples found in repository?
examples/symbolic_regression.rs (line 195)
192fn mutate_simplify(expr: &Expr) -> Expr {
193    let mut g = ExprGraph::new();
194    let root = expr.to_expr(&mut g);
195    let simplified = g.simplify(root);
196    let s = g.fmt_expr(simplified);
197
198    // Quick check: if simplified form is much shorter, it's better
199    let orig = expr.format();
200    if s.len() < orig.len() {
201        // Re-evaluate to confirm it still works
202        let test = g.eval::<f64>(simplified, &[1.0]);
203        let orig_test = expr.eval_at(1.0);
204        if (test - orig_test).abs() < 1e-10 || (test.is_nan() && orig_test.is_nan()) {
205            // Return a new tree built from the simplified graph
206            return expr_from_graph(&g, simplified);
207        }
208    }
209    expr.clone()
210}
211
212/// Reconstruct an Expr tree from an ExprGraph node.
213fn expr_from_graph(g: &ExprGraph, id: ExprId) -> Expr {
214    match g.node(id) {
215        tang_expr::node::Node::Var(_) => Expr::X,
216        tang_expr::node::Node::Lit(bits) => {
217            let v = f64::from_bits(bits);
218            if v == 0.0 {
219                Expr::Lit(0.0)
220            } else {
221                Expr::Lit(v)
222            }
223        }
224        tang_expr::node::Node::Add(a, b) => {
225            Expr::Add(Box::new(expr_from_graph(g, a)), Box::new(expr_from_graph(g, b)))
226        }
227        tang_expr::node::Node::Mul(a, b) => {
228            Expr::Mul(Box::new(expr_from_graph(g, a)), Box::new(expr_from_graph(g, b)))
229        }
230        tang_expr::node::Node::Neg(a) => Expr::Neg(Box::new(expr_from_graph(g, a))),
231        tang_expr::node::Node::Sin(a) => Expr::Sin(Box::new(expr_from_graph(g, a))),
232        // For operations we don't represent in our AST, just evaluate as literal
233        _ => {
234            let v = g.eval::<f64>(id, &[1.0]); // fallback
235            Expr::Lit(v)
236        }
237    }
238}
239
240// --- Fitness evaluation ------------------------------------------------------
241
/// Ground-truth function the regression tries to recover: `x^2 * sin(x)`.
fn target(x: f64) -> f64 {
    let squared = x * x;
    squared * x.sin()
}
245
246fn generate_data(n: usize, rng: &mut Lcg) -> Vec<(f64, f64)> {
247    (0..n)
248        .map(|i| {
249            let x = -3.0 + 6.0 * i as f64 / (n - 1) as f64;
250            let noise = (rng.uniform() - 0.5) * 0.01;
251            (x, target(x) + noise)
252        })
253        .collect()
254}
255
256/// Evaluate fitness: MSE + complexity penalty.
257fn fitness(expr: &Expr, data: &[(f64, f64)]) -> f64 {
258    let mut g = ExprGraph::new();
259    let root = expr.to_expr(&mut g);
260    let compiled = g.compile(root);
261
262    let mut mse = 0.0;
263    for &(x, y) in data {
264        let pred = compiled(&[x]);
265        if pred.is_nan() || pred.is_infinite() {
266            return f64::INFINITY;
267        }
268        let err = pred - y;
269        mse += err * err;
270    }
271    mse /= data.len() as f64;
272
273    let penalty = 0.001 * expr.size() as f64;
274    mse + penalty
275}
276
277// --- Selection ---------------------------------------------------------------
278
279fn tournament<'a>(pop: &'a [(Expr, f64)], k: usize, rng: &mut Lcg) -> &'a Expr {
280    let mut best_idx = rng.range(pop.len());
281    for _ in 1..k {
282        let idx = rng.range(pop.len());
283        if pop[idx].1 < pop[best_idx].1 {
284            best_idx = idx;
285        }
286    }
287    &pop[best_idx].0
288}
289
290// --- Main --------------------------------------------------------------------
291
292fn main() {
293    println!("=== Symbolic Regression ===\n");
294    println!("target: f(x) = x^2 * sin(x)\n");
295
296    let mut rng = Lcg::new(42);
297    let data = generate_data(50, &mut rng);
298
299    const POP_SIZE: usize = 200;
300    const GENERATIONS: usize = 100;
301    const TOURNAMENT_K: usize = 5;
302
303    // Initialize population
304    let mut population: Vec<(Expr, f64)> = (0..POP_SIZE)
305        .map(|_| {
306            let expr = random_expr(4, &mut rng);
307            let fit = fitness(&expr, &data);
308            (expr, fit)
309        })
310        .collect();
311
312    let mut best_expr: Option<Expr> = None;
313    let mut best_fitness = f64::INFINITY;
314
315    for gen in 0..GENERATIONS {
316        // Track best
317        for (expr, fit) in &population {
318            if *fit < best_fitness {
319                best_fitness = *fit;
320                best_expr = Some(expr.clone());
321            }
322        }
323
324        if (gen + 1) % 10 == 0 || gen == 0 {
325            let b = best_expr.as_ref().unwrap();
326            println!(
327                "gen {:>3}: best fitness = {:.6}  nodes = {:>3}  expr = {}",
328                gen + 1,
329                best_fitness,
330                b.size(),
331                b.format(),
332            );
333        }
334
335        // Build next generation
336        let mut next_pop = Vec::with_capacity(POP_SIZE);
337
338        // Elitism: keep top 5
339        let mut indices: Vec<usize> = (0..population.len()).collect();
340        indices.sort_by(|&a, &b| population[a].1.partial_cmp(&population[b].1).unwrap());
341        for &i in indices.iter().take(5) {
342            next_pop.push(population[i].clone());
343        }
344
345        // Fill rest via mutation
346        while next_pop.len() < POP_SIZE {
347            let parent = tournament(&population, TOURNAMENT_K, &mut rng);
348            let child = match rng.range(10) {
349                0..=3 => mutate_grow(parent, &mut rng),
350                4..=6 => mutate_point(parent, &mut rng),
351                7..=8 => mutate_simplify(parent),
352                _ => random_expr(4, &mut rng), // fresh blood
353            };
354
355            // Skip overly large expressions
356            if child.size() > 50 {
357                continue;
358            }
359
360            let fit = fitness(&child, &data);
361            next_pop.push((child, fit));
362        }
363
364        population = next_pop;
365    }
366
367    // Final results
368    println!();
369    let best = best_expr.unwrap();
370    let mut g = ExprGraph::new();
371    let root = best.to_expr(&mut g);
372    let simplified = g.simplify(root);
373    let expr_str = g.fmt_expr(simplified);
374    println!("best expression: {}", expr_str);
375    println!("best fitness:    {:.6}", best_fitness);
376
377    // Verify on test points
378    println!("\nverification:");
379    let compiled = g.compile(simplified);
380    for &x in &[-2.0, -1.0, 0.0, 1.0, 2.0] {
381        let pred = compiled(&[x]);
382        let exact = target(x);
383        println!(
384            "  f({:>5.1}) = {:>8.4}  (predicted: {:>8.4}, error: {:>8.4})",
385            x, exact, pred, (pred - exact).abs()
386        );
387    }
388
389    // Symbolic derivative via ExprGraph
390    let dx = g.diff(simplified, 0);
391    let dx = g.simplify(dx);
392    println!("\nsymbolic derivative: d/dx [{}] = {}", expr_str, g.fmt_expr(dx));
393}
Source§

impl ExprGraph

Source

pub fn deps(&self, expr: ExprId) -> u64

Compute a bitmask of which Var(n) indices appear in expr.

Bit n is set if Var(n) is reachable from expr. Supports up to 64 variables.

Source

pub fn jacobian_sparsity(&self, outputs: &[ExprId], n_vars: usize) -> Vec<u64>

Compute the Jacobian sparsity pattern.

Returns one u64 bitmask per output expression. Bit j of result[i] is set if outputs[i] depends on Var(j).

Source§

impl ExprGraph

Source

pub fn to_wgsl(&self, outputs: &[ExprId], n_inputs: usize) -> WgslKernel

Generate a WGSL compute shader that evaluates expressions in parallel.

Each work item reads n_inputs values and writes outputs.len() values. The generated shader uses f32 (GPU native). Shared subexpressions are computed once per thread.

The caller handles device/pipeline/dispatch (no wgpu dependency here).

Trait Implementations§

Source§

impl Clone for ExprGraph

Source§

fn clone(&self) -> ExprGraph

Returns a duplicate of the value. Read more
1.0.0 · Source§

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source. Read more
Source§

impl Default for ExprGraph

Source§

fn default() -> Self

Returns the “default value” for a type. Read more

Auto Trait Implementations§

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> CloneToUninit for T
where T: Clone,

Source§

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> ToOwned for T
where T: Clone,

Source§

type Owned = T

The resulting type after obtaining ownership.
Source§

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning. Read more
Source§

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning. Read more
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.