1use tang_expr::{ExprGraph, ExprId};
11
/// Small linear congruential generator; the wrapped `u64` is the current state.
///
/// The whole sequence is determined by the seed passed to [`Lcg::new`].
struct Lcg(u64);

impl Lcg {
    /// Creates a generator whose initial state is `seed`.
    fn new(seed: u64) -> Self {
        Lcg(seed)
    }

    /// Advances the state by one step (wrapping multiply, then wrapping add)
    /// and returns the new state.
    fn next(&mut self) -> u64 {
        const MULTIPLIER: u64 = 6364136223846793005;
        const INCREMENT: u64 = 1442695040888963407;

        let advanced = self.0.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
        self.0 = advanced;
        advanced
    }

    /// Returns a value in `[0, 1)` built from the top 53 bits of the next state.
    fn uniform(&mut self) -> f64 {
        let top_bits = self.next() >> 11;
        let scale = (1u64 << 53) as f64;
        top_bits as f64 / scale
    }

    /// Returns an index in `0..n`.
    ///
    /// # Panics
    ///
    /// Panics when `n == 0`, because the final remainder divides by zero.
    fn range(&mut self, n: usize) -> usize {
        let scaled = self.uniform() * n as f64;
        // The trailing `% n` keeps the result in bounds even if the
        // float-to-integer conversion were to land exactly on `n`.
        scaled as usize % n
    }
}
34
/// Expression tree over a single variable, used as the genome that the
/// search mutates and evaluates.
#[derive(Debug, Clone)]
enum Expr {
    /// The input variable.
    X,
    /// A floating-point constant.
    Lit(f64),
    /// Sum of two subexpressions.
    Add(Box<Expr>, Box<Expr>),
    /// Product of two subexpressions.
    Mul(Box<Expr>, Box<Expr>),
    /// Sine of a subexpression.
    Sin(Box<Expr>),
    /// Negation of a subexpression.
    Neg(Box<Expr>),
}
46
47impl Expr {
48 fn size(&self) -> usize {
50 match self {
51 Expr::X | Expr::Lit(_) => 1,
52 Expr::Sin(a) | Expr::Neg(a) => 1 + a.size(),
53 Expr::Add(a, b) | Expr::Mul(a, b) => 1 + a.size() + b.size(),
54 }
55 }
56
57 fn to_expr(&self, g: &mut ExprGraph) -> ExprId {
59 match self {
60 Expr::X => g.var(0),
61 Expr::Lit(v) => g.lit(*v),
62 Expr::Add(a, b) => {
63 let a = a.to_expr(g);
64 let b = b.to_expr(g);
65 g.add(a, b)
66 }
67 Expr::Mul(a, b) => {
68 let a = a.to_expr(g);
69 let b = b.to_expr(g);
70 g.mul(a, b)
71 }
72 Expr::Sin(a) => {
73 let a = a.to_expr(g);
74 g.sin(a)
75 }
76 Expr::Neg(a) => {
77 let a = a.to_expr(g);
78 g.neg(a)
79 }
80 }
81 }
82
83 fn format(&self) -> String {
85 let mut g = ExprGraph::new();
86 let root = self.to_expr(&mut g);
87 g.fmt_expr(root)
88 }
89
90 fn eval_at(&self, x: f64) -> f64 {
92 let mut g = ExprGraph::new();
93 let root = self.to_expr(&mut g);
94 g.eval(root, &[x])
95 }
96}
97
98fn random_expr(depth: usize, rng: &mut Lcg) -> Expr {
101 if depth == 0 || (depth < 3 && rng.uniform() < 0.3) {
102 return match rng.range(3) {
103 0 => Expr::X,
104 1 => Expr::Lit((rng.uniform() * 4.0 - 2.0) * 10.0_f64.powi(-((rng.uniform() * 2.0) as i32))),
105 _ => Expr::X,
106 };
107 }
108
109 match rng.range(5) {
110 0 => Expr::Add(
111 Box::new(random_expr(depth - 1, rng)),
112 Box::new(random_expr(depth - 1, rng)),
113 ),
114 1 | 2 => Expr::Mul(
115 Box::new(random_expr(depth - 1, rng)),
116 Box::new(random_expr(depth - 1, rng)),
117 ),
118 3 => Expr::Sin(Box::new(random_expr(depth - 1, rng))),
119 _ => Expr::Neg(Box::new(random_expr(depth - 1, rng))),
120 }
121}
122
123fn mutate_grow(expr: &Expr, rng: &mut Lcg) -> Expr {
127 let size = expr.size();
128 let target = rng.range(size);
129 grow_at(expr, target, &mut 0, rng)
130}
131
132fn grow_at(expr: &Expr, target: usize, counter: &mut usize, rng: &mut Lcg) -> Expr {
133 if *counter == target {
134 *counter += expr.size(); return random_expr(3, rng);
136 }
137 *counter += 1;
138 match expr {
139 Expr::X => Expr::X,
140 Expr::Lit(v) => Expr::Lit(*v),
141 Expr::Add(a, b) => Expr::Add(
142 Box::new(grow_at(a, target, counter, rng)),
143 Box::new(grow_at(b, target, counter, rng)),
144 ),
145 Expr::Mul(a, b) => Expr::Mul(
146 Box::new(grow_at(a, target, counter, rng)),
147 Box::new(grow_at(b, target, counter, rng)),
148 ),
149 Expr::Sin(a) => Expr::Sin(Box::new(grow_at(a, target, counter, rng))),
150 Expr::Neg(a) => Expr::Neg(Box::new(grow_at(a, target, counter, rng))),
151 }
152}
153
154fn mutate_point(expr: &Expr, rng: &mut Lcg) -> Expr {
156 let size = expr.size();
157 let target = rng.range(size);
158 point_at(expr, target, &mut 0, rng)
159}
160
161fn point_at(expr: &Expr, target: usize, counter: &mut usize, rng: &mut Lcg) -> Expr {
162 if *counter == target {
163 *counter += 1;
164 return match expr {
165 Expr::X => Expr::Lit(rng.uniform() * 2.0 - 1.0),
166 Expr::Lit(_) => Expr::X,
167 Expr::Add(a, b) => Expr::Mul(a.clone(), b.clone()),
168 Expr::Mul(a, b) => Expr::Add(a.clone(), b.clone()),
169 Expr::Sin(a) => Expr::Neg(a.clone()),
170 Expr::Neg(a) => Expr::Sin(a.clone()),
171 };
172 }
173 *counter += 1;
174 match expr {
175 Expr::X => Expr::X,
176 Expr::Lit(v) => Expr::Lit(*v),
177 Expr::Add(a, b) => Expr::Add(
178 Box::new(point_at(a, target, counter, rng)),
179 Box::new(point_at(b, target, counter, rng)),
180 ),
181 Expr::Mul(a, b) => Expr::Mul(
182 Box::new(point_at(a, target, counter, rng)),
183 Box::new(point_at(b, target, counter, rng)),
184 ),
185 Expr::Sin(a) => Expr::Sin(Box::new(point_at(a, target, counter, rng))),
186 Expr::Neg(a) => Expr::Neg(Box::new(point_at(a, target, counter, rng))),
187 }
188}
189
190fn mutate_simplify(expr: &Expr) -> Expr {
193 let mut g = ExprGraph::new();
194 let root = expr.to_expr(&mut g);
195 let simplified = g.simplify(root);
196 let s = g.fmt_expr(simplified);
197
198 let orig = expr.format();
200 if s.len() < orig.len() {
201 let test = g.eval::<f64>(simplified, &[1.0]);
203 let orig_test = expr.eval_at(1.0);
204 if (test - orig_test).abs() < 1e-10 || (test.is_nan() && orig_test.is_nan()) {
205 return expr_from_graph(&g, simplified);
207 }
208 }
209 expr.clone()
210}
211
212fn expr_from_graph(g: &ExprGraph, id: ExprId) -> Expr {
214 match g.node(id) {
215 tang_expr::node::Node::Var(_) => Expr::X,
216 tang_expr::node::Node::Lit(bits) => {
217 let v = f64::from_bits(bits);
218 if v == 0.0 {
219 Expr::Lit(0.0)
220 } else {
221 Expr::Lit(v)
222 }
223 }
224 tang_expr::node::Node::Add(a, b) => {
225 Expr::Add(Box::new(expr_from_graph(g, a)), Box::new(expr_from_graph(g, b)))
226 }
227 tang_expr::node::Node::Mul(a, b) => {
228 Expr::Mul(Box::new(expr_from_graph(g, a)), Box::new(expr_from_graph(g, b)))
229 }
230 tang_expr::node::Node::Neg(a) => Expr::Neg(Box::new(expr_from_graph(g, a))),
231 tang_expr::node::Node::Sin(a) => Expr::Sin(Box::new(expr_from_graph(g, a))),
232 _ => {
234 let v = g.eval::<f64>(id, &[1.0]); Expr::Lit(v)
236 }
237 }
238}
239
/// The function the search tries to rediscover: `x * x * sin(x)`.
fn target(x: f64) -> f64 {
    let squared = x * x;
    squared * x.sin()
}
245
246fn generate_data(n: usize, rng: &mut Lcg) -> Vec<(f64, f64)> {
247 (0..n)
248 .map(|i| {
249 let x = -3.0 + 6.0 * i as f64 / (n - 1) as f64;
250 let noise = (rng.uniform() - 0.5) * 0.01;
251 (x, target(x) + noise)
252 })
253 .collect()
254}
255
256fn fitness(expr: &Expr, data: &[(f64, f64)]) -> f64 {
258 let mut g = ExprGraph::new();
259 let root = expr.to_expr(&mut g);
260 let compiled = g.compile(root);
261
262 let mut mse = 0.0;
263 for &(x, y) in data {
264 let pred = compiled(&[x]);
265 if pred.is_nan() || pred.is_infinite() {
266 return f64::INFINITY;
267 }
268 let err = pred - y;
269 mse += err * err;
270 }
271 mse /= data.len() as f64;
272
273 let penalty = 0.001 * expr.size() as f64;
274 mse + penalty
275}
276
277fn tournament<'a>(pop: &'a [(Expr, f64)], k: usize, rng: &mut Lcg) -> &'a Expr {
280 let mut best_idx = rng.range(pop.len());
281 for _ in 1..k {
282 let idx = rng.range(pop.len());
283 if pop[idx].1 < pop[best_idx].1 {
284 best_idx = idx;
285 }
286 }
287 &pop[best_idx].0
288}
289
290fn main() {
293 println!("=== Symbolic Regression ===\n");
294 println!("target: f(x) = x^2 * sin(x)\n");
295
296 let mut rng = Lcg::new(42);
297 let data = generate_data(50, &mut rng);
298
299 const POP_SIZE: usize = 200;
300 const GENERATIONS: usize = 100;
301 const TOURNAMENT_K: usize = 5;
302
303 let mut population: Vec<(Expr, f64)> = (0..POP_SIZE)
305 .map(|_| {
306 let expr = random_expr(4, &mut rng);
307 let fit = fitness(&expr, &data);
308 (expr, fit)
309 })
310 .collect();
311
312 let mut best_expr: Option<Expr> = None;
313 let mut best_fitness = f64::INFINITY;
314
315 for gen in 0..GENERATIONS {
316 for (expr, fit) in &population {
318 if *fit < best_fitness {
319 best_fitness = *fit;
320 best_expr = Some(expr.clone());
321 }
322 }
323
324 if (gen + 1) % 10 == 0 || gen == 0 {
325 let b = best_expr.as_ref().unwrap();
326 println!(
327 "gen {:>3}: best fitness = {:.6} nodes = {:>3} expr = {}",
328 gen + 1,
329 best_fitness,
330 b.size(),
331 b.format(),
332 );
333 }
334
335 let mut next_pop = Vec::with_capacity(POP_SIZE);
337
338 let mut indices: Vec<usize> = (0..population.len()).collect();
340 indices.sort_by(|&a, &b| population[a].1.partial_cmp(&population[b].1).unwrap());
341 for &i in indices.iter().take(5) {
342 next_pop.push(population[i].clone());
343 }
344
345 while next_pop.len() < POP_SIZE {
347 let parent = tournament(&population, TOURNAMENT_K, &mut rng);
348 let child = match rng.range(10) {
349 0..=3 => mutate_grow(parent, &mut rng),
350 4..=6 => mutate_point(parent, &mut rng),
351 7..=8 => mutate_simplify(parent),
352 _ => random_expr(4, &mut rng), };
354
355 if child.size() > 50 {
357 continue;
358 }
359
360 let fit = fitness(&child, &data);
361 next_pop.push((child, fit));
362 }
363
364 population = next_pop;
365 }
366
367 println!();
369 let best = best_expr.unwrap();
370 let mut g = ExprGraph::new();
371 let root = best.to_expr(&mut g);
372 let simplified = g.simplify(root);
373 let expr_str = g.fmt_expr(simplified);
374 println!("best expression: {}", expr_str);
375 println!("best fitness: {:.6}", best_fitness);
376
377 println!("\nverification:");
379 let compiled = g.compile(simplified);
380 for &x in &[-2.0, -1.0, 0.0, 1.0, 2.0] {
381 let pred = compiled(&[x]);
382 let exact = target(x);
383 println!(
384 " f({:>5.1}) = {:>8.4} (predicted: {:>8.4}, error: {:>8.4})",
385 x, exact, pred, (pred - exact).abs()
386 );
387 }
388
389 let dx = g.diff(simplified, 0);
391 let dx = g.simplify(dx);
392 println!("\nsymbolic derivative: d/dx [{}] = {}", expr_str, g.fmt_expr(dx));
393}