pub struct ExprGraph { /* private fields */ }
Arena-based expression graph with structural interning.
Identical subexpressions always return the same ExprId — this gives
automatic common subexpression elimination (CSE) for free.
Implementations§
Source§impl ExprGraph
impl ExprGraph
Sourcepub fn compile(&self, expr: ExprId) -> CompiledExpr
pub fn compile(&self, expr: ExprId) -> CompiledExpr
Compile a single expression to a closure `&[f64] -> f64`.
Shared subexpressions (from interning) are computed once. Dead nodes (not reachable from the output) are skipped.
Examples found in repository?
257fn fitness(expr: &Expr, data: &[(f64, f64)]) -> f64 {
258 let mut g = ExprGraph::new();
259 let root = expr.to_expr(&mut g);
260 let compiled = g.compile(root);
261
262 let mut mse = 0.0;
263 for &(x, y) in data {
264 let pred = compiled(&[x]);
265 if pred.is_nan() || pred.is_infinite() {
266 return f64::INFINITY;
267 }
268 let err = pred - y;
269 mse += err * err;
270 }
271 mse /= data.len() as f64;
272
273 let penalty = 0.001 * expr.size() as f64;
274 mse + penalty
275}
276
277// --- Selection ---------------------------------------------------------------
278
279fn tournament<'a>(pop: &'a [(Expr, f64)], k: usize, rng: &mut Lcg) -> &'a Expr {
280 let mut best_idx = rng.range(pop.len());
281 for _ in 1..k {
282 let idx = rng.range(pop.len());
283 if pop[idx].1 < pop[best_idx].1 {
284 best_idx = idx;
285 }
286 }
287 &pop[best_idx].0
288}
289
290// --- Main --------------------------------------------------------------------
291
292fn main() {
293 println!("=== Symbolic Regression ===\n");
294 println!("target: f(x) = x^2 * sin(x)\n");
295
296 let mut rng = Lcg::new(42);
297 let data = generate_data(50, &mut rng);
298
299 const POP_SIZE: usize = 200;
300 const GENERATIONS: usize = 100;
301 const TOURNAMENT_K: usize = 5;
302
303 // Initialize population
304 let mut population: Vec<(Expr, f64)> = (0..POP_SIZE)
305 .map(|_| {
306 let expr = random_expr(4, &mut rng);
307 let fit = fitness(&expr, &data);
308 (expr, fit)
309 })
310 .collect();
311
312 let mut best_expr: Option<Expr> = None;
313 let mut best_fitness = f64::INFINITY;
314
315 for gen in 0..GENERATIONS {
316 // Track best
317 for (expr, fit) in &population {
318 if *fit < best_fitness {
319 best_fitness = *fit;
320 best_expr = Some(expr.clone());
321 }
322 }
323
324 if (gen + 1) % 10 == 0 || gen == 0 {
325 let b = best_expr.as_ref().unwrap();
326 println!(
327 "gen {:>3}: best fitness = {:.6} nodes = {:>3} expr = {}",
328 gen + 1,
329 best_fitness,
330 b.size(),
331 b.format(),
332 );
333 }
334
335 // Build next generation
336 let mut next_pop = Vec::with_capacity(POP_SIZE);
337
338 // Elitism: keep top 5
339 let mut indices: Vec<usize> = (0..population.len()).collect();
340 indices.sort_by(|&a, &b| population[a].1.partial_cmp(&population[b].1).unwrap());
341 for &i in indices.iter().take(5) {
342 next_pop.push(population[i].clone());
343 }
344
345 // Fill rest via mutation
346 while next_pop.len() < POP_SIZE {
347 let parent = tournament(&population, TOURNAMENT_K, &mut rng);
348 let child = match rng.range(10) {
349 0..=3 => mutate_grow(parent, &mut rng),
350 4..=6 => mutate_point(parent, &mut rng),
351 7..=8 => mutate_simplify(parent),
352 _ => random_expr(4, &mut rng), // fresh blood
353 };
354
355 // Skip overly large expressions
356 if child.size() > 50 {
357 continue;
358 }
359
360 let fit = fitness(&child, &data);
361 next_pop.push((child, fit));
362 }
363
364 population = next_pop;
365 }
366
367 // Final results
368 println!();
369 let best = best_expr.unwrap();
370 let mut g = ExprGraph::new();
371 let root = best.to_expr(&mut g);
372 let simplified = g.simplify(root);
373 let expr_str = g.fmt_expr(simplified);
374 println!("best expression: {}", expr_str);
375 println!("best fitness: {:.6}", best_fitness);
376
377 // Verify on test points
378 println!("\nverification:");
379 let compiled = g.compile(simplified);
380 for &x in &[-2.0, -1.0, 0.0, 1.0, 2.0] {
381 let pred = compiled(&[x]);
382 let exact = target(x);
383 println!(
384 " f({:>5.1}) = {:>8.4} (predicted: {:>8.4}, error: {:>8.4})",
385 x, exact, pred, (pred - exact).abs()
386 );
387 }
388
389 // Symbolic derivative via ExprGraph
390 let dx = g.diff(simplified, 0);
391 let dx = g.simplify(dx);
392 println!("\nsymbolic derivative: d/dx [{}] = {}", expr_str, g.fmt_expr(dx));
393}
Sourcepub fn compile_many(&self, exprs: &[ExprId]) -> CompiledMany
pub fn compile_many(&self, exprs: &[ExprId]) -> CompiledMany
Compile multiple output expressions to a single closure.
Writes results into the output slice.
Source§impl ExprGraph
impl ExprGraph
Sourcepub fn diff(&mut self, expr: ExprId, var: u16) -> ExprId
pub fn diff(&mut self, expr: ExprId, var: u16) -> ExprId
Differentiate `expr` with respect to `Var(var)`.
Returns a new ExprId in the same graph. Uses memoization to avoid recomputing derivatives of shared subexpressions.
Examples found in repository?
292fn main() {
293 println!("=== Symbolic Regression ===\n");
294 println!("target: f(x) = x^2 * sin(x)\n");
295
296 let mut rng = Lcg::new(42);
297 let data = generate_data(50, &mut rng);
298
299 const POP_SIZE: usize = 200;
300 const GENERATIONS: usize = 100;
301 const TOURNAMENT_K: usize = 5;
302
303 // Initialize population
304 let mut population: Vec<(Expr, f64)> = (0..POP_SIZE)
305 .map(|_| {
306 let expr = random_expr(4, &mut rng);
307 let fit = fitness(&expr, &data);
308 (expr, fit)
309 })
310 .collect();
311
312 let mut best_expr: Option<Expr> = None;
313 let mut best_fitness = f64::INFINITY;
314
315 for gen in 0..GENERATIONS {
316 // Track best
317 for (expr, fit) in &population {
318 if *fit < best_fitness {
319 best_fitness = *fit;
320 best_expr = Some(expr.clone());
321 }
322 }
323
324 if (gen + 1) % 10 == 0 || gen == 0 {
325 let b = best_expr.as_ref().unwrap();
326 println!(
327 "gen {:>3}: best fitness = {:.6} nodes = {:>3} expr = {}",
328 gen + 1,
329 best_fitness,
330 b.size(),
331 b.format(),
332 );
333 }
334
335 // Build next generation
336 let mut next_pop = Vec::with_capacity(POP_SIZE);
337
338 // Elitism: keep top 5
339 let mut indices: Vec<usize> = (0..population.len()).collect();
340 indices.sort_by(|&a, &b| population[a].1.partial_cmp(&population[b].1).unwrap());
341 for &i in indices.iter().take(5) {
342 next_pop.push(population[i].clone());
343 }
344
345 // Fill rest via mutation
346 while next_pop.len() < POP_SIZE {
347 let parent = tournament(&population, TOURNAMENT_K, &mut rng);
348 let child = match rng.range(10) {
349 0..=3 => mutate_grow(parent, &mut rng),
350 4..=6 => mutate_point(parent, &mut rng),
351 7..=8 => mutate_simplify(parent),
352 _ => random_expr(4, &mut rng), // fresh blood
353 };
354
355 // Skip overly large expressions
356 if child.size() > 50 {
357 continue;
358 }
359
360 let fit = fitness(&child, &data);
361 next_pop.push((child, fit));
362 }
363
364 population = next_pop;
365 }
366
367 // Final results
368 println!();
369 let best = best_expr.unwrap();
370 let mut g = ExprGraph::new();
371 let root = best.to_expr(&mut g);
372 let simplified = g.simplify(root);
373 let expr_str = g.fmt_expr(simplified);
374 println!("best expression: {}", expr_str);
375 println!("best fitness: {:.6}", best_fitness);
376
377 // Verify on test points
378 println!("\nverification:");
379 let compiled = g.compile(simplified);
380 for &x in &[-2.0, -1.0, 0.0, 1.0, 2.0] {
381 let pred = compiled(&[x]);
382 let exact = target(x);
383 println!(
384 " f({:>5.1}) = {:>8.4} (predicted: {:>8.4}, error: {:>8.4})",
385 x, exact, pred, (pred - exact).abs()
386 );
387 }
388
389 // Symbolic derivative via ExprGraph
390 let dx = g.diff(simplified, 0);
391 let dx = g.simplify(dx);
392 println!("\nsymbolic derivative: d/dx [{}] = {}", expr_str, g.fmt_expr(dx));
393}
Source§impl ExprGraph
impl ExprGraph
Sourcepub fn fmt_expr(&self, expr: ExprId) -> String
pub fn fmt_expr(&self, expr: ExprId) -> String
Format an expression as a human-readable string.
Examples found in repository?
84 fn format(&self) -> String {
85 let mut g = ExprGraph::new();
86 let root = self.to_expr(&mut g);
87 g.fmt_expr(root)
88 }
89
90 /// Evaluate at a point using ExprGraph's compiled eval.
91 fn eval_at(&self, x: f64) -> f64 {
92 let mut g = ExprGraph::new();
93 let root = self.to_expr(&mut g);
94 g.eval(root, &[x])
95 }
96}
97
98// --- Random expression generation -------------------------------------------
99
100fn random_expr(depth: usize, rng: &mut Lcg) -> Expr {
101 if depth == 0 || (depth < 3 && rng.uniform() < 0.3) {
102 return match rng.range(3) {
103 0 => Expr::X,
104 1 => Expr::Lit((rng.uniform() * 4.0 - 2.0) * 10.0_f64.powi(-((rng.uniform() * 2.0) as i32))),
105 _ => Expr::X,
106 };
107 }
108
109 match rng.range(5) {
110 0 => Expr::Add(
111 Box::new(random_expr(depth - 1, rng)),
112 Box::new(random_expr(depth - 1, rng)),
113 ),
114 1 | 2 => Expr::Mul(
115 Box::new(random_expr(depth - 1, rng)),
116 Box::new(random_expr(depth - 1, rng)),
117 ),
118 3 => Expr::Sin(Box::new(random_expr(depth - 1, rng))),
119 _ => Expr::Neg(Box::new(random_expr(depth - 1, rng))),
120 }
121}
122
123// --- Mutation operators ------------------------------------------------------
124
125/// Grow mutation: replace a random subtree with a new random one.
126fn mutate_grow(expr: &Expr, rng: &mut Lcg) -> Expr {
127 let size = expr.size();
128 let target = rng.range(size);
129 grow_at(expr, target, &mut 0, rng)
130}
131
132fn grow_at(expr: &Expr, target: usize, counter: &mut usize, rng: &mut Lcg) -> Expr {
133 if *counter == target {
134 *counter += expr.size(); // skip subtree
135 return random_expr(3, rng);
136 }
137 *counter += 1;
138 match expr {
139 Expr::X => Expr::X,
140 Expr::Lit(v) => Expr::Lit(*v),
141 Expr::Add(a, b) => Expr::Add(
142 Box::new(grow_at(a, target, counter, rng)),
143 Box::new(grow_at(b, target, counter, rng)),
144 ),
145 Expr::Mul(a, b) => Expr::Mul(
146 Box::new(grow_at(a, target, counter, rng)),
147 Box::new(grow_at(b, target, counter, rng)),
148 ),
149 Expr::Sin(a) => Expr::Sin(Box::new(grow_at(a, target, counter, rng))),
150 Expr::Neg(a) => Expr::Neg(Box::new(grow_at(a, target, counter, rng))),
151 }
152}
153
154/// Point mutation: change a single node's operation.
155fn mutate_point(expr: &Expr, rng: &mut Lcg) -> Expr {
156 let size = expr.size();
157 let target = rng.range(size);
158 point_at(expr, target, &mut 0, rng)
159}
160
161fn point_at(expr: &Expr, target: usize, counter: &mut usize, rng: &mut Lcg) -> Expr {
162 if *counter == target {
163 *counter += 1;
164 return match expr {
165 Expr::X => Expr::Lit(rng.uniform() * 2.0 - 1.0),
166 Expr::Lit(_) => Expr::X,
167 Expr::Add(a, b) => Expr::Mul(a.clone(), b.clone()),
168 Expr::Mul(a, b) => Expr::Add(a.clone(), b.clone()),
169 Expr::Sin(a) => Expr::Neg(a.clone()),
170 Expr::Neg(a) => Expr::Sin(a.clone()),
171 };
172 }
173 *counter += 1;
174 match expr {
175 Expr::X => Expr::X,
176 Expr::Lit(v) => Expr::Lit(*v),
177 Expr::Add(a, b) => Expr::Add(
178 Box::new(point_at(a, target, counter, rng)),
179 Box::new(point_at(b, target, counter, rng)),
180 ),
181 Expr::Mul(a, b) => Expr::Mul(
182 Box::new(point_at(a, target, counter, rng)),
183 Box::new(point_at(b, target, counter, rng)),
184 ),
185 Expr::Sin(a) => Expr::Sin(Box::new(point_at(a, target, counter, rng))),
186 Expr::Neg(a) => Expr::Neg(Box::new(point_at(a, target, counter, rng))),
187 }
188}
189
190/// Simplify mutation: convert to ExprGraph, simplify, convert back.
191/// Falls back to identity if conversion back would be complex.
192fn mutate_simplify(expr: &Expr) -> Expr {
193 let mut g = ExprGraph::new();
194 let root = expr.to_expr(&mut g);
195 let simplified = g.simplify(root);
196 let s = g.fmt_expr(simplified);
197
198 // Quick check: if simplified form is much shorter, it's better
199 let orig = expr.format();
200 if s.len() < orig.len() {
201 // Re-evaluate to confirm it still works
202 let test = g.eval::<f64>(simplified, &[1.0]);
203 let orig_test = expr.eval_at(1.0);
204 if (test - orig_test).abs() < 1e-10 || (test.is_nan() && orig_test.is_nan()) {
205 // Return a new tree built from the simplified graph
206 return expr_from_graph(&g, simplified);
207 }
208 }
209 expr.clone()
210}
211
212/// Reconstruct an Expr tree from an ExprGraph node.
213fn expr_from_graph(g: &ExprGraph, id: ExprId) -> Expr {
214 match g.node(id) {
215 tang_expr::node::Node::Var(_) => Expr::X,
216 tang_expr::node::Node::Lit(bits) => {
217 let v = f64::from_bits(bits);
218 if v == 0.0 {
219 Expr::Lit(0.0)
220 } else {
221 Expr::Lit(v)
222 }
223 }
224 tang_expr::node::Node::Add(a, b) => {
225 Expr::Add(Box::new(expr_from_graph(g, a)), Box::new(expr_from_graph(g, b)))
226 }
227 tang_expr::node::Node::Mul(a, b) => {
228 Expr::Mul(Box::new(expr_from_graph(g, a)), Box::new(expr_from_graph(g, b)))
229 }
230 tang_expr::node::Node::Neg(a) => Expr::Neg(Box::new(expr_from_graph(g, a))),
231 tang_expr::node::Node::Sin(a) => Expr::Sin(Box::new(expr_from_graph(g, a))),
232 // For operations we don't represent in our AST, just evaluate as literal
233 _ => {
234 let v = g.eval::<f64>(id, &[1.0]); // fallback
235 Expr::Lit(v)
236 }
237 }
238}
239
240// --- Fitness evaluation ------------------------------------------------------
241
242fn target(x: f64) -> f64 {
243 x * x * x.sin()
244}
245
246fn generate_data(n: usize, rng: &mut Lcg) -> Vec<(f64, f64)> {
247 (0..n)
248 .map(|i| {
249 let x = -3.0 + 6.0 * i as f64 / (n - 1) as f64;
250 let noise = (rng.uniform() - 0.5) * 0.01;
251 (x, target(x) + noise)
252 })
253 .collect()
254}
255
256/// Evaluate fitness: MSE + complexity penalty.
257fn fitness(expr: &Expr, data: &[(f64, f64)]) -> f64 {
258 let mut g = ExprGraph::new();
259 let root = expr.to_expr(&mut g);
260 let compiled = g.compile(root);
261
262 let mut mse = 0.0;
263 for &(x, y) in data {
264 let pred = compiled(&[x]);
265 if pred.is_nan() || pred.is_infinite() {
266 return f64::INFINITY;
267 }
268 let err = pred - y;
269 mse += err * err;
270 }
271 mse /= data.len() as f64;
272
273 let penalty = 0.001 * expr.size() as f64;
274 mse + penalty
275}
276
277// --- Selection ---------------------------------------------------------------
278
279fn tournament<'a>(pop: &'a [(Expr, f64)], k: usize, rng: &mut Lcg) -> &'a Expr {
280 let mut best_idx = rng.range(pop.len());
281 for _ in 1..k {
282 let idx = rng.range(pop.len());
283 if pop[idx].1 < pop[best_idx].1 {
284 best_idx = idx;
285 }
286 }
287 &pop[best_idx].0
288}
289
290// --- Main --------------------------------------------------------------------
291
292fn main() {
293 println!("=== Symbolic Regression ===\n");
294 println!("target: f(x) = x^2 * sin(x)\n");
295
296 let mut rng = Lcg::new(42);
297 let data = generate_data(50, &mut rng);
298
299 const POP_SIZE: usize = 200;
300 const GENERATIONS: usize = 100;
301 const TOURNAMENT_K: usize = 5;
302
303 // Initialize population
304 let mut population: Vec<(Expr, f64)> = (0..POP_SIZE)
305 .map(|_| {
306 let expr = random_expr(4, &mut rng);
307 let fit = fitness(&expr, &data);
308 (expr, fit)
309 })
310 .collect();
311
312 let mut best_expr: Option<Expr> = None;
313 let mut best_fitness = f64::INFINITY;
314
315 for gen in 0..GENERATIONS {
316 // Track best
317 for (expr, fit) in &population {
318 if *fit < best_fitness {
319 best_fitness = *fit;
320 best_expr = Some(expr.clone());
321 }
322 }
323
324 if (gen + 1) % 10 == 0 || gen == 0 {
325 let b = best_expr.as_ref().unwrap();
326 println!(
327 "gen {:>3}: best fitness = {:.6} nodes = {:>3} expr = {}",
328 gen + 1,
329 best_fitness,
330 b.size(),
331 b.format(),
332 );
333 }
334
335 // Build next generation
336 let mut next_pop = Vec::with_capacity(POP_SIZE);
337
338 // Elitism: keep top 5
339 let mut indices: Vec<usize> = (0..population.len()).collect();
340 indices.sort_by(|&a, &b| population[a].1.partial_cmp(&population[b].1).unwrap());
341 for &i in indices.iter().take(5) {
342 next_pop.push(population[i].clone());
343 }
344
345 // Fill rest via mutation
346 while next_pop.len() < POP_SIZE {
347 let parent = tournament(&population, TOURNAMENT_K, &mut rng);
348 let child = match rng.range(10) {
349 0..=3 => mutate_grow(parent, &mut rng),
350 4..=6 => mutate_point(parent, &mut rng),
351 7..=8 => mutate_simplify(parent),
352 _ => random_expr(4, &mut rng), // fresh blood
353 };
354
355 // Skip overly large expressions
356 if child.size() > 50 {
357 continue;
358 }
359
360 let fit = fitness(&child, &data);
361 next_pop.push((child, fit));
362 }
363
364 population = next_pop;
365 }
366
367 // Final results
368 println!();
369 let best = best_expr.unwrap();
370 let mut g = ExprGraph::new();
371 let root = best.to_expr(&mut g);
372 let simplified = g.simplify(root);
373 let expr_str = g.fmt_expr(simplified);
374 println!("best expression: {}", expr_str);
375 println!("best fitness: {:.6}", best_fitness);
376
377 // Verify on test points
378 println!("\nverification:");
379 let compiled = g.compile(simplified);
380 for &x in &[-2.0, -1.0, 0.0, 1.0, 2.0] {
381 let pred = compiled(&[x]);
382 let exact = target(x);
383 println!(
384 " f({:>5.1}) = {:>8.4} (predicted: {:>8.4}, error: {:>8.4})",
385 x, exact, pred, (pred - exact).abs()
386 );
387 }
388
389 // Symbolic derivative via ExprGraph
390 let dx = g.diff(simplified, 0);
391 let dx = g.simplify(dx);
392 println!("\nsymbolic derivative: d/dx [{}] = {}", expr_str, g.fmt_expr(dx));
393}Source§impl ExprGraph
impl ExprGraph
Sourcepub fn eval<S: Scalar>(&self, expr: ExprId, inputs: &[S]) -> S
pub fn eval<S: Scalar>(&self, expr: ExprId, inputs: &[S]) -> S
Evaluate an expression with concrete scalar inputs.
`inputs[n]` provides the value for `Var(n)`. Walks the graph in
topological order (which is just index order, since children are
always created before parents).
Examples found in repository?
91 fn eval_at(&self, x: f64) -> f64 {
92 let mut g = ExprGraph::new();
93 let root = self.to_expr(&mut g);
94 g.eval(root, &[x])
95 }
96}
97
98// --- Random expression generation -------------------------------------------
99
100fn random_expr(depth: usize, rng: &mut Lcg) -> Expr {
101 if depth == 0 || (depth < 3 && rng.uniform() < 0.3) {
102 return match rng.range(3) {
103 0 => Expr::X,
104 1 => Expr::Lit((rng.uniform() * 4.0 - 2.0) * 10.0_f64.powi(-((rng.uniform() * 2.0) as i32))),
105 _ => Expr::X,
106 };
107 }
108
109 match rng.range(5) {
110 0 => Expr::Add(
111 Box::new(random_expr(depth - 1, rng)),
112 Box::new(random_expr(depth - 1, rng)),
113 ),
114 1 | 2 => Expr::Mul(
115 Box::new(random_expr(depth - 1, rng)),
116 Box::new(random_expr(depth - 1, rng)),
117 ),
118 3 => Expr::Sin(Box::new(random_expr(depth - 1, rng))),
119 _ => Expr::Neg(Box::new(random_expr(depth - 1, rng))),
120 }
121}
122
123// --- Mutation operators ------------------------------------------------------
124
125/// Grow mutation: replace a random subtree with a new random one.
126fn mutate_grow(expr: &Expr, rng: &mut Lcg) -> Expr {
127 let size = expr.size();
128 let target = rng.range(size);
129 grow_at(expr, target, &mut 0, rng)
130}
131
132fn grow_at(expr: &Expr, target: usize, counter: &mut usize, rng: &mut Lcg) -> Expr {
133 if *counter == target {
134 *counter += expr.size(); // skip subtree
135 return random_expr(3, rng);
136 }
137 *counter += 1;
138 match expr {
139 Expr::X => Expr::X,
140 Expr::Lit(v) => Expr::Lit(*v),
141 Expr::Add(a, b) => Expr::Add(
142 Box::new(grow_at(a, target, counter, rng)),
143 Box::new(grow_at(b, target, counter, rng)),
144 ),
145 Expr::Mul(a, b) => Expr::Mul(
146 Box::new(grow_at(a, target, counter, rng)),
147 Box::new(grow_at(b, target, counter, rng)),
148 ),
149 Expr::Sin(a) => Expr::Sin(Box::new(grow_at(a, target, counter, rng))),
150 Expr::Neg(a) => Expr::Neg(Box::new(grow_at(a, target, counter, rng))),
151 }
152}
153
154/// Point mutation: change a single node's operation.
155fn mutate_point(expr: &Expr, rng: &mut Lcg) -> Expr {
156 let size = expr.size();
157 let target = rng.range(size);
158 point_at(expr, target, &mut 0, rng)
159}
160
161fn point_at(expr: &Expr, target: usize, counter: &mut usize, rng: &mut Lcg) -> Expr {
162 if *counter == target {
163 *counter += 1;
164 return match expr {
165 Expr::X => Expr::Lit(rng.uniform() * 2.0 - 1.0),
166 Expr::Lit(_) => Expr::X,
167 Expr::Add(a, b) => Expr::Mul(a.clone(), b.clone()),
168 Expr::Mul(a, b) => Expr::Add(a.clone(), b.clone()),
169 Expr::Sin(a) => Expr::Neg(a.clone()),
170 Expr::Neg(a) => Expr::Sin(a.clone()),
171 };
172 }
173 *counter += 1;
174 match expr {
175 Expr::X => Expr::X,
176 Expr::Lit(v) => Expr::Lit(*v),
177 Expr::Add(a, b) => Expr::Add(
178 Box::new(point_at(a, target, counter, rng)),
179 Box::new(point_at(b, target, counter, rng)),
180 ),
181 Expr::Mul(a, b) => Expr::Mul(
182 Box::new(point_at(a, target, counter, rng)),
183 Box::new(point_at(b, target, counter, rng)),
184 ),
185 Expr::Sin(a) => Expr::Sin(Box::new(point_at(a, target, counter, rng))),
186 Expr::Neg(a) => Expr::Neg(Box::new(point_at(a, target, counter, rng))),
187 }
188}
189
190/// Simplify mutation: convert to ExprGraph, simplify, convert back.
191/// Falls back to identity if conversion back would be complex.
192fn mutate_simplify(expr: &Expr) -> Expr {
193 let mut g = ExprGraph::new();
194 let root = expr.to_expr(&mut g);
195 let simplified = g.simplify(root);
196 let s = g.fmt_expr(simplified);
197
198 // Quick check: if simplified form is much shorter, it's better
199 let orig = expr.format();
200 if s.len() < orig.len() {
201 // Re-evaluate to confirm it still works
202 let test = g.eval::<f64>(simplified, &[1.0]);
203 let orig_test = expr.eval_at(1.0);
204 if (test - orig_test).abs() < 1e-10 || (test.is_nan() && orig_test.is_nan()) {
205 // Return a new tree built from the simplified graph
206 return expr_from_graph(&g, simplified);
207 }
208 }
209 expr.clone()
210}
211
212/// Reconstruct an Expr tree from an ExprGraph node.
213fn expr_from_graph(g: &ExprGraph, id: ExprId) -> Expr {
214 match g.node(id) {
215 tang_expr::node::Node::Var(_) => Expr::X,
216 tang_expr::node::Node::Lit(bits) => {
217 let v = f64::from_bits(bits);
218 if v == 0.0 {
219 Expr::Lit(0.0)
220 } else {
221 Expr::Lit(v)
222 }
223 }
224 tang_expr::node::Node::Add(a, b) => {
225 Expr::Add(Box::new(expr_from_graph(g, a)), Box::new(expr_from_graph(g, b)))
226 }
227 tang_expr::node::Node::Mul(a, b) => {
228 Expr::Mul(Box::new(expr_from_graph(g, a)), Box::new(expr_from_graph(g, b)))
229 }
230 tang_expr::node::Node::Neg(a) => Expr::Neg(Box::new(expr_from_graph(g, a))),
231 tang_expr::node::Node::Sin(a) => Expr::Sin(Box::new(expr_from_graph(g, a))),
232 // For operations we don't represent in our AST, just evaluate as literal
233 _ => {
234 let v = g.eval::<f64>(id, &[1.0]); // fallback
235 Expr::Lit(v)
236 }
237 }
238}
Source§impl ExprGraph
impl ExprGraph
Sourcepub fn new() -> Self
pub fn new() -> Self
Create a new graph pre-populated with `ZERO`, `ONE`, and `TWO`.
Examples found in repository?
84 fn format(&self) -> String {
85 let mut g = ExprGraph::new();
86 let root = self.to_expr(&mut g);
87 g.fmt_expr(root)
88 }
89
90 /// Evaluate at a point using ExprGraph's compiled eval.
91 fn eval_at(&self, x: f64) -> f64 {
92 let mut g = ExprGraph::new();
93 let root = self.to_expr(&mut g);
94 g.eval(root, &[x])
95 }
96}
97
98// --- Random expression generation -------------------------------------------
99
100fn random_expr(depth: usize, rng: &mut Lcg) -> Expr {
101 if depth == 0 || (depth < 3 && rng.uniform() < 0.3) {
102 return match rng.range(3) {
103 0 => Expr::X,
104 1 => Expr::Lit((rng.uniform() * 4.0 - 2.0) * 10.0_f64.powi(-((rng.uniform() * 2.0) as i32))),
105 _ => Expr::X,
106 };
107 }
108
109 match rng.range(5) {
110 0 => Expr::Add(
111 Box::new(random_expr(depth - 1, rng)),
112 Box::new(random_expr(depth - 1, rng)),
113 ),
114 1 | 2 => Expr::Mul(
115 Box::new(random_expr(depth - 1, rng)),
116 Box::new(random_expr(depth - 1, rng)),
117 ),
118 3 => Expr::Sin(Box::new(random_expr(depth - 1, rng))),
119 _ => Expr::Neg(Box::new(random_expr(depth - 1, rng))),
120 }
121}
122
123// --- Mutation operators ------------------------------------------------------
124
125/// Grow mutation: replace a random subtree with a new random one.
126fn mutate_grow(expr: &Expr, rng: &mut Lcg) -> Expr {
127 let size = expr.size();
128 let target = rng.range(size);
129 grow_at(expr, target, &mut 0, rng)
130}
131
132fn grow_at(expr: &Expr, target: usize, counter: &mut usize, rng: &mut Lcg) -> Expr {
133 if *counter == target {
134 *counter += expr.size(); // skip subtree
135 return random_expr(3, rng);
136 }
137 *counter += 1;
138 match expr {
139 Expr::X => Expr::X,
140 Expr::Lit(v) => Expr::Lit(*v),
141 Expr::Add(a, b) => Expr::Add(
142 Box::new(grow_at(a, target, counter, rng)),
143 Box::new(grow_at(b, target, counter, rng)),
144 ),
145 Expr::Mul(a, b) => Expr::Mul(
146 Box::new(grow_at(a, target, counter, rng)),
147 Box::new(grow_at(b, target, counter, rng)),
148 ),
149 Expr::Sin(a) => Expr::Sin(Box::new(grow_at(a, target, counter, rng))),
150 Expr::Neg(a) => Expr::Neg(Box::new(grow_at(a, target, counter, rng))),
151 }
152}
153
154/// Point mutation: change a single node's operation.
155fn mutate_point(expr: &Expr, rng: &mut Lcg) -> Expr {
156 let size = expr.size();
157 let target = rng.range(size);
158 point_at(expr, target, &mut 0, rng)
159}
160
161fn point_at(expr: &Expr, target: usize, counter: &mut usize, rng: &mut Lcg) -> Expr {
162 if *counter == target {
163 *counter += 1;
164 return match expr {
165 Expr::X => Expr::Lit(rng.uniform() * 2.0 - 1.0),
166 Expr::Lit(_) => Expr::X,
167 Expr::Add(a, b) => Expr::Mul(a.clone(), b.clone()),
168 Expr::Mul(a, b) => Expr::Add(a.clone(), b.clone()),
169 Expr::Sin(a) => Expr::Neg(a.clone()),
170 Expr::Neg(a) => Expr::Sin(a.clone()),
171 };
172 }
173 *counter += 1;
174 match expr {
175 Expr::X => Expr::X,
176 Expr::Lit(v) => Expr::Lit(*v),
177 Expr::Add(a, b) => Expr::Add(
178 Box::new(point_at(a, target, counter, rng)),
179 Box::new(point_at(b, target, counter, rng)),
180 ),
181 Expr::Mul(a, b) => Expr::Mul(
182 Box::new(point_at(a, target, counter, rng)),
183 Box::new(point_at(b, target, counter, rng)),
184 ),
185 Expr::Sin(a) => Expr::Sin(Box::new(point_at(a, target, counter, rng))),
186 Expr::Neg(a) => Expr::Neg(Box::new(point_at(a, target, counter, rng))),
187 }
188}
189
190/// Simplify mutation: convert to ExprGraph, simplify, convert back.
191/// Falls back to identity if conversion back would be complex.
192fn mutate_simplify(expr: &Expr) -> Expr {
193 let mut g = ExprGraph::new();
194 let root = expr.to_expr(&mut g);
195 let simplified = g.simplify(root);
196 let s = g.fmt_expr(simplified);
197
198 // Quick check: if simplified form is much shorter, it's better
199 let orig = expr.format();
200 if s.len() < orig.len() {
201 // Re-evaluate to confirm it still works
202 let test = g.eval::<f64>(simplified, &[1.0]);
203 let orig_test = expr.eval_at(1.0);
204 if (test - orig_test).abs() < 1e-10 || (test.is_nan() && orig_test.is_nan()) {
205 // Return a new tree built from the simplified graph
206 return expr_from_graph(&g, simplified);
207 }
208 }
209 expr.clone()
210}
211
212/// Reconstruct an Expr tree from an ExprGraph node.
213fn expr_from_graph(g: &ExprGraph, id: ExprId) -> Expr {
214 match g.node(id) {
215 tang_expr::node::Node::Var(_) => Expr::X,
216 tang_expr::node::Node::Lit(bits) => {
217 let v = f64::from_bits(bits);
218 if v == 0.0 {
219 Expr::Lit(0.0)
220 } else {
221 Expr::Lit(v)
222 }
223 }
224 tang_expr::node::Node::Add(a, b) => {
225 Expr::Add(Box::new(expr_from_graph(g, a)), Box::new(expr_from_graph(g, b)))
226 }
227 tang_expr::node::Node::Mul(a, b) => {
228 Expr::Mul(Box::new(expr_from_graph(g, a)), Box::new(expr_from_graph(g, b)))
229 }
230 tang_expr::node::Node::Neg(a) => Expr::Neg(Box::new(expr_from_graph(g, a))),
231 tang_expr::node::Node::Sin(a) => Expr::Sin(Box::new(expr_from_graph(g, a))),
232 // For operations we don't represent in our AST, just evaluate as literal
233 _ => {
234 let v = g.eval::<f64>(id, &[1.0]); // fallback
235 Expr::Lit(v)
236 }
237 }
238}
239
240// --- Fitness evaluation ------------------------------------------------------
241
242fn target(x: f64) -> f64 {
243 x * x * x.sin()
244}
245
246fn generate_data(n: usize, rng: &mut Lcg) -> Vec<(f64, f64)> {
247 (0..n)
248 .map(|i| {
249 let x = -3.0 + 6.0 * i as f64 / (n - 1) as f64;
250 let noise = (rng.uniform() - 0.5) * 0.01;
251 (x, target(x) + noise)
252 })
253 .collect()
254}
255
256/// Evaluate fitness: MSE + complexity penalty.
257fn fitness(expr: &Expr, data: &[(f64, f64)]) -> f64 {
258 let mut g = ExprGraph::new();
259 let root = expr.to_expr(&mut g);
260 let compiled = g.compile(root);
261
262 let mut mse = 0.0;
263 for &(x, y) in data {
264 let pred = compiled(&[x]);
265 if pred.is_nan() || pred.is_infinite() {
266 return f64::INFINITY;
267 }
268 let err = pred - y;
269 mse += err * err;
270 }
271 mse /= data.len() as f64;
272
273 let penalty = 0.001 * expr.size() as f64;
274 mse + penalty
275}
276
277// --- Selection ---------------------------------------------------------------
278
279fn tournament<'a>(pop: &'a [(Expr, f64)], k: usize, rng: &mut Lcg) -> &'a Expr {
280 let mut best_idx = rng.range(pop.len());
281 for _ in 1..k {
282 let idx = rng.range(pop.len());
283 if pop[idx].1 < pop[best_idx].1 {
284 best_idx = idx;
285 }
286 }
287 &pop[best_idx].0
288}
289
290// --- Main --------------------------------------------------------------------
291
292fn main() {
293 println!("=== Symbolic Regression ===\n");
294 println!("target: f(x) = x^2 * sin(x)\n");
295
296 let mut rng = Lcg::new(42);
297 let data = generate_data(50, &mut rng);
298
299 const POP_SIZE: usize = 200;
300 const GENERATIONS: usize = 100;
301 const TOURNAMENT_K: usize = 5;
302
303 // Initialize population
304 let mut population: Vec<(Expr, f64)> = (0..POP_SIZE)
305 .map(|_| {
306 let expr = random_expr(4, &mut rng);
307 let fit = fitness(&expr, &data);
308 (expr, fit)
309 })
310 .collect();
311
312 let mut best_expr: Option<Expr> = None;
313 let mut best_fitness = f64::INFINITY;
314
315 for gen in 0..GENERATIONS {
316 // Track best
317 for (expr, fit) in &population {
318 if *fit < best_fitness {
319 best_fitness = *fit;
320 best_expr = Some(expr.clone());
321 }
322 }
323
324 if (gen + 1) % 10 == 0 || gen == 0 {
325 let b = best_expr.as_ref().unwrap();
326 println!(
327 "gen {:>3}: best fitness = {:.6} nodes = {:>3} expr = {}",
328 gen + 1,
329 best_fitness,
330 b.size(),
331 b.format(),
332 );
333 }
334
335 // Build next generation
336 let mut next_pop = Vec::with_capacity(POP_SIZE);
337
338 // Elitism: keep top 5
339 let mut indices: Vec<usize> = (0..population.len()).collect();
340 indices.sort_by(|&a, &b| population[a].1.partial_cmp(&population[b].1).unwrap());
341 for &i in indices.iter().take(5) {
342 next_pop.push(population[i].clone());
343 }
344
345 // Fill rest via mutation
346 while next_pop.len() < POP_SIZE {
347 let parent = tournament(&population, TOURNAMENT_K, &mut rng);
348 let child = match rng.range(10) {
349 0..=3 => mutate_grow(parent, &mut rng),
350 4..=6 => mutate_point(parent, &mut rng),
351 7..=8 => mutate_simplify(parent),
352 _ => random_expr(4, &mut rng), // fresh blood
353 };
354
355 // Skip overly large expressions
356 if child.size() > 50 {
357 continue;
358 }
359
360 let fit = fitness(&child, &data);
361 next_pop.push((child, fit));
362 }
363
364 population = next_pop;
365 }
366
367 // Final results
368 println!();
369 let best = best_expr.unwrap();
370 let mut g = ExprGraph::new();
371 let root = best.to_expr(&mut g);
372 let simplified = g.simplify(root);
373 let expr_str = g.fmt_expr(simplified);
374 println!("best expression: {}", expr_str);
375 println!("best fitness: {:.6}", best_fitness);
376
377 // Verify on test points
378 println!("\nverification:");
379 let compiled = g.compile(simplified);
380 for &x in &[-2.0, -1.0, 0.0, 1.0, 2.0] {
381 let pred = compiled(&[x]);
382 let exact = target(x);
383 println!(
384 " f({:>5.1}) = {:>8.4} (predicted: {:>8.4}, error: {:>8.4})",
385 x, exact, pred, (pred - exact).abs()
386 );
387 }
388
389 // Symbolic derivative via ExprGraph
390 let dx = g.diff(simplified, 0);
391 let dx = g.simplify(dx);
392 println!("\nsymbolic derivative: d/dx [{}] = {}", expr_str, g.fmt_expr(dx));
393}
Source: pub fn node(&self, id: ExprId) -> Node
pub fn node(&self, id: ExprId) -> Node
Look up the node for an ExprId.
Examples found in repository?
213fn expr_from_graph(g: &ExprGraph, id: ExprId) -> Expr {
214 match g.node(id) {
215 tang_expr::node::Node::Var(_) => Expr::X,
216 tang_expr::node::Node::Lit(bits) => {
217 let v = f64::from_bits(bits);
218 if v == 0.0 {
219 Expr::Lit(0.0)
220 } else {
221 Expr::Lit(v)
222 }
223 }
224 tang_expr::node::Node::Add(a, b) => {
225 Expr::Add(Box::new(expr_from_graph(g, a)), Box::new(expr_from_graph(g, b)))
226 }
227 tang_expr::node::Node::Mul(a, b) => {
228 Expr::Mul(Box::new(expr_from_graph(g, a)), Box::new(expr_from_graph(g, b)))
229 }
230 tang_expr::node::Node::Neg(a) => Expr::Neg(Box::new(expr_from_graph(g, a))),
231 tang_expr::node::Node::Sin(a) => Expr::Sin(Box::new(expr_from_graph(g, a))),
232 // For operations we don't represent in our AST, just evaluate as literal
233 _ => {
234 let v = g.eval::<f64>(id, &[1.0]); // fallback
235 Expr::Lit(v)
236 }
237 }
238}
Source: pub fn nodes_slice(&self) -> &[Node]
pub fn nodes_slice(&self) -> &[Node]
Read-only access to the node arena for serialization.
Source: pub fn var(&mut self, n: u16) -> ExprId
pub fn var(&mut self, n: u16) -> ExprId
Create a variable node.
Examples found in repository?
58 fn to_expr(&self, g: &mut ExprGraph) -> ExprId {
59 match self {
60 Expr::X => g.var(0),
61 Expr::Lit(v) => g.lit(*v),
62 Expr::Add(a, b) => {
63 let a = a.to_expr(g);
64 let b = b.to_expr(g);
65 g.add(a, b)
66 }
67 Expr::Mul(a, b) => {
68 let a = a.to_expr(g);
69 let b = b.to_expr(g);
70 g.mul(a, b)
71 }
72 Expr::Sin(a) => {
73 let a = a.to_expr(g);
74 g.sin(a)
75 }
76 Expr::Neg(a) => {
77 let a = a.to_expr(g);
78 g.neg(a)
79 }
80 }
81 }
Source: pub fn lit(&mut self, v: f64) -> ExprId
pub fn lit(&mut self, v: f64) -> ExprId
Create a literal node.
Examples found in repository?
58 fn to_expr(&self, g: &mut ExprGraph) -> ExprId {
59 match self {
60 Expr::X => g.var(0),
61 Expr::Lit(v) => g.lit(*v),
62 Expr::Add(a, b) => {
63 let a = a.to_expr(g);
64 let b = b.to_expr(g);
65 g.add(a, b)
66 }
67 Expr::Mul(a, b) => {
68 let a = a.to_expr(g);
69 let b = b.to_expr(g);
70 g.mul(a, b)
71 }
72 Expr::Sin(a) => {
73 let a = a.to_expr(g);
74 g.sin(a)
75 }
76 Expr::Neg(a) => {
77 let a = a.to_expr(g);
78 g.neg(a)
79 }
80 }
81 }
Source: pub fn add(&mut self, a: ExprId, b: ExprId) -> ExprId
pub fn add(&mut self, a: ExprId, b: ExprId) -> ExprId
Add two expressions.
Examples found in repository?
58 fn to_expr(&self, g: &mut ExprGraph) -> ExprId {
59 match self {
60 Expr::X => g.var(0),
61 Expr::Lit(v) => g.lit(*v),
62 Expr::Add(a, b) => {
63 let a = a.to_expr(g);
64 let b = b.to_expr(g);
65 g.add(a, b)
66 }
67 Expr::Mul(a, b) => {
68 let a = a.to_expr(g);
69 let b = b.to_expr(g);
70 g.mul(a, b)
71 }
72 Expr::Sin(a) => {
73 let a = a.to_expr(g);
74 g.sin(a)
75 }
76 Expr::Neg(a) => {
77 let a = a.to_expr(g);
78 g.neg(a)
79 }
80 }
81 }
Source: pub fn mul(&mut self, a: ExprId, b: ExprId) -> ExprId
pub fn mul(&mut self, a: ExprId, b: ExprId) -> ExprId
Multiply two expressions.
Examples found in repository?
58 fn to_expr(&self, g: &mut ExprGraph) -> ExprId {
59 match self {
60 Expr::X => g.var(0),
61 Expr::Lit(v) => g.lit(*v),
62 Expr::Add(a, b) => {
63 let a = a.to_expr(g);
64 let b = b.to_expr(g);
65 g.add(a, b)
66 }
67 Expr::Mul(a, b) => {
68 let a = a.to_expr(g);
69 let b = b.to_expr(g);
70 g.mul(a, b)
71 }
72 Expr::Sin(a) => {
73 let a = a.to_expr(g);
74 g.sin(a)
75 }
76 Expr::Neg(a) => {
77 let a = a.to_expr(g);
78 g.neg(a)
79 }
80 }
81 }
Source: pub fn neg(&mut self, a: ExprId) -> ExprId
pub fn neg(&mut self, a: ExprId) -> ExprId
Negate an expression.
Examples found in repository?
58 fn to_expr(&self, g: &mut ExprGraph) -> ExprId {
59 match self {
60 Expr::X => g.var(0),
61 Expr::Lit(v) => g.lit(*v),
62 Expr::Add(a, b) => {
63 let a = a.to_expr(g);
64 let b = b.to_expr(g);
65 g.add(a, b)
66 }
67 Expr::Mul(a, b) => {
68 let a = a.to_expr(g);
69 let b = b.to_expr(g);
70 g.mul(a, b)
71 }
72 Expr::Sin(a) => {
73 let a = a.to_expr(g);
74 g.sin(a)
75 }
76 Expr::Neg(a) => {
77 let a = a.to_expr(g);
78 g.neg(a)
79 }
80 }
81 }
Source: pub fn sin(&mut self, a: ExprId) -> ExprId
pub fn sin(&mut self, a: ExprId) -> ExprId
Compute the sine of an expression.
Examples found in repository?
58 fn to_expr(&self, g: &mut ExprGraph) -> ExprId {
59 match self {
60 Expr::X => g.var(0),
61 Expr::Lit(v) => g.lit(*v),
62 Expr::Add(a, b) => {
63 let a = a.to_expr(g);
64 let b = b.to_expr(g);
65 g.add(a, b)
66 }
67 Expr::Mul(a, b) => {
68 let a = a.to_expr(g);
69 let b = b.to_expr(g);
70 g.mul(a, b)
71 }
72 Expr::Sin(a) => {
73 let a = a.to_expr(g);
74 g.sin(a)
75 }
76 Expr::Neg(a) => {
77 let a = a.to_expr(g);
78 g.neg(a)
79 }
80 }
81 }
Source: impl ExprGraph
impl ExprGraph
Source: pub fn simplify(&mut self, expr: ExprId) -> ExprId
pub fn simplify(&mut self, expr: ExprId) -> ExprId
Simplify an expression by applying rewrite rules until a fixpoint is reached.
Bottom-up: simplify the children first, then match the parent. Iterates until no more changes occur.
Examples found in repository?
192fn mutate_simplify(expr: &Expr) -> Expr {
193 let mut g = ExprGraph::new();
194 let root = expr.to_expr(&mut g);
195 let simplified = g.simplify(root);
196 let s = g.fmt_expr(simplified);
197
198 // Quick check: if simplified form is much shorter, it's better
199 let orig = expr.format();
200 if s.len() < orig.len() {
201 // Re-evaluate to confirm it still works
202 let test = g.eval::<f64>(simplified, &[1.0]);
203 let orig_test = expr.eval_at(1.0);
204 if (test - orig_test).abs() < 1e-10 || (test.is_nan() && orig_test.is_nan()) {
205 // Return a new tree built from the simplified graph
206 return expr_from_graph(&g, simplified);
207 }
208 }
209 expr.clone()
210}
211
212/// Reconstruct an Expr tree from an ExprGraph node.
213fn expr_from_graph(g: &ExprGraph, id: ExprId) -> Expr {
214 match g.node(id) {
215 tang_expr::node::Node::Var(_) => Expr::X,
216 tang_expr::node::Node::Lit(bits) => {
217 let v = f64::from_bits(bits);
218 if v == 0.0 {
219 Expr::Lit(0.0)
220 } else {
221 Expr::Lit(v)
222 }
223 }
224 tang_expr::node::Node::Add(a, b) => {
225 Expr::Add(Box::new(expr_from_graph(g, a)), Box::new(expr_from_graph(g, b)))
226 }
227 tang_expr::node::Node::Mul(a, b) => {
228 Expr::Mul(Box::new(expr_from_graph(g, a)), Box::new(expr_from_graph(g, b)))
229 }
230 tang_expr::node::Node::Neg(a) => Expr::Neg(Box::new(expr_from_graph(g, a))),
231 tang_expr::node::Node::Sin(a) => Expr::Sin(Box::new(expr_from_graph(g, a))),
232 // For operations we don't represent in our AST, just evaluate as literal
233 _ => {
234 let v = g.eval::<f64>(id, &[1.0]); // fallback
235 Expr::Lit(v)
236 }
237 }
238}
239
240// --- Fitness evaluation ------------------------------------------------------
241
242fn target(x: f64) -> f64 {
243 x * x * x.sin()
244}
245
246fn generate_data(n: usize, rng: &mut Lcg) -> Vec<(f64, f64)> {
247 (0..n)
248 .map(|i| {
249 let x = -3.0 + 6.0 * i as f64 / (n - 1) as f64;
250 let noise = (rng.uniform() - 0.5) * 0.01;
251 (x, target(x) + noise)
252 })
253 .collect()
254}
255
256/// Evaluate fitness: MSE + complexity penalty.
257fn fitness(expr: &Expr, data: &[(f64, f64)]) -> f64 {
258 let mut g = ExprGraph::new();
259 let root = expr.to_expr(&mut g);
260 let compiled = g.compile(root);
261
262 let mut mse = 0.0;
263 for &(x, y) in data {
264 let pred = compiled(&[x]);
265 if pred.is_nan() || pred.is_infinite() {
266 return f64::INFINITY;
267 }
268 let err = pred - y;
269 mse += err * err;
270 }
271 mse /= data.len() as f64;
272
273 let penalty = 0.001 * expr.size() as f64;
274 mse + penalty
275}
276
277// --- Selection ---------------------------------------------------------------
278
279fn tournament<'a>(pop: &'a [(Expr, f64)], k: usize, rng: &mut Lcg) -> &'a Expr {
280 let mut best_idx = rng.range(pop.len());
281 for _ in 1..k {
282 let idx = rng.range(pop.len());
283 if pop[idx].1 < pop[best_idx].1 {
284 best_idx = idx;
285 }
286 }
287 &pop[best_idx].0
288}
289
290// --- Main --------------------------------------------------------------------
291
292fn main() {
293 println!("=== Symbolic Regression ===\n");
294 println!("target: f(x) = x^2 * sin(x)\n");
295
296 let mut rng = Lcg::new(42);
297 let data = generate_data(50, &mut rng);
298
299 const POP_SIZE: usize = 200;
300 const GENERATIONS: usize = 100;
301 const TOURNAMENT_K: usize = 5;
302
303 // Initialize population
304 let mut population: Vec<(Expr, f64)> = (0..POP_SIZE)
305 .map(|_| {
306 let expr = random_expr(4, &mut rng);
307 let fit = fitness(&expr, &data);
308 (expr, fit)
309 })
310 .collect();
311
312 let mut best_expr: Option<Expr> = None;
313 let mut best_fitness = f64::INFINITY;
314
315 for gen in 0..GENERATIONS {
316 // Track best
317 for (expr, fit) in &population {
318 if *fit < best_fitness {
319 best_fitness = *fit;
320 best_expr = Some(expr.clone());
321 }
322 }
323
324 if (gen + 1) % 10 == 0 || gen == 0 {
325 let b = best_expr.as_ref().unwrap();
326 println!(
327 "gen {:>3}: best fitness = {:.6} nodes = {:>3} expr = {}",
328 gen + 1,
329 best_fitness,
330 b.size(),
331 b.format(),
332 );
333 }
334
335 // Build next generation
336 let mut next_pop = Vec::with_capacity(POP_SIZE);
337
338 // Elitism: keep top 5
339 let mut indices: Vec<usize> = (0..population.len()).collect();
340 indices.sort_by(|&a, &b| population[a].1.partial_cmp(&population[b].1).unwrap());
341 for &i in indices.iter().take(5) {
342 next_pop.push(population[i].clone());
343 }
344
345 // Fill rest via mutation
346 while next_pop.len() < POP_SIZE {
347 let parent = tournament(&population, TOURNAMENT_K, &mut rng);
348 let child = match rng.range(10) {
349 0..=3 => mutate_grow(parent, &mut rng),
350 4..=6 => mutate_point(parent, &mut rng),
351 7..=8 => mutate_simplify(parent),
352 _ => random_expr(4, &mut rng), // fresh blood
353 };
354
355 // Skip overly large expressions
356 if child.size() > 50 {
357 continue;
358 }
359
360 let fit = fitness(&child, &data);
361 next_pop.push((child, fit));
362 }
363
364 population = next_pop;
365 }
366
367 // Final results
368 println!();
369 let best = best_expr.unwrap();
370 let mut g = ExprGraph::new();
371 let root = best.to_expr(&mut g);
372 let simplified = g.simplify(root);
373 let expr_str = g.fmt_expr(simplified);
374 println!("best expression: {}", expr_str);
375 println!("best fitness: {:.6}", best_fitness);
376
377 // Verify on test points
378 println!("\nverification:");
379 let compiled = g.compile(simplified);
380 for &x in &[-2.0, -1.0, 0.0, 1.0, 2.0] {
381 let pred = compiled(&[x]);
382 let exact = target(x);
383 println!(
384 " f({:>5.1}) = {:>8.4} (predicted: {:>8.4}, error: {:>8.4})",
385 x, exact, pred, (pred - exact).abs()
386 );
387 }
388
389 // Symbolic derivative via ExprGraph
390 let dx = g.diff(simplified, 0);
391 let dx = g.simplify(dx);
392 println!("\nsymbolic derivative: d/dx [{}] = {}", expr_str, g.fmt_expr(dx));
393}
Source: impl ExprGraph
impl ExprGraph
Source: impl ExprGraph
impl ExprGraph
Source: pub fn to_wgsl(&self, outputs: &[ExprId], n_inputs: usize) -> WgslKernel
pub fn to_wgsl(&self, outputs: &[ExprId], n_inputs: usize) -> WgslKernel
Generate a WGSL compute shader that evaluates expressions in parallel.
Each work item reads n_inputs values and writes outputs.len() values.
The generated shader computes in f32 (native GPU precision) rather than f64. Shared subexpressions are
computed once per thread.
The caller handles device/pipeline/dispatch (no wgpu dependency here).