supervised_regression/supervised_regression.rs

//! Supervised regression with a small feed-forward network (Train Station public API)
//!
//! - Continuous regression on a noisy y = 2*x1 - 3*x2 + 0.5 dataset
//! - Reuses the `feedforward_network.rs` building block
//! - Proper parameter linking, zero_grad, backward, step, and graph clearing
//! - Gradient clipping and helpful console logging (loss, RMSE, R^2)
//!
//! Run:
//!   cargo run --release --example supervised_regression
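//!
//! Metrics (RMSE, R^2) are printed every 20 epochs, followed by sample predictions.
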
use train_station::{
    gradtrack::clear_all_graphs_known,
    optimizers::{Adam, Optimizer},
    Tensor,
};

#[allow(clippy::duplicate_mod)]
#[path = "../neural_networks/feedforward_network.rs"]
mod feedforward_network;
use feedforward_network::{FeedForwardConfig, FeedForwardNetwork};

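/// Scales all gradients in place so their global L2 norm does not exceed
/// `max_norm`, using the factor `max_norm / (norm + eps)`.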
fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
    // Accumulate the squared L2 norm across every parameter gradient
    let mut total_sq = 0.0f32;
    for p in parameters.iter() {
        if let Some(g) = p.grad_owned() {
            for &v in g.data() {
                total_sq += v * v;
            }
        }
    }
    let norm = total_sq.sqrt();
    if norm > max_norm {
        // Rescale every gradient by the same factor to preserve direction
        let scale = max_norm / (norm + eps);
        for p in parameters.iter_mut() {
            if let Some(g) = p.grad_owned() {
                p.set_grad(g.mul_scalar(scale));
            }
        }
    }
}

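/// Mean squared error as a differentiable scalar tensor.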
fn mse(pred: &Tensor, target: &Tensor) -> Tensor {
    pred.sub_tensor(target).pow_scalar(2.0).mean()
}

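/// Root mean squared error, extracted as a plain `f32` for logging.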
fn rmse(pred: &Tensor, target: &Tensor) -> f32 {
    mse(pred, target).sqrt().value()
}

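/// Coefficient of determination: 1.0 is a perfect fit; values below 0.0 mean
/// the model predicts worse than the target mean would.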
fn r2_score(pred: &Tensor, target: &Tensor) -> f32 {
    // R^2 = 1 - SS_res / SS_tot
    let y = target;
    let y_mean = y.mean();
    let ss_res = pred.sub_tensor(y).pow_scalar(2.0).sum();
    let ss_tot = y.sub_tensor(&y_mean).pow_scalar(2.0).sum();
    let ss_res_v = ss_res.value();
    let ss_tot_v = ss_tot.value().max(1e-12); // avoid divide by zero
    1.0 - (ss_res_v / ss_tot_v)
}

pub fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== Supervised Regression Example (MSE) ===");

    // Generate simple synthetic data: y = 2*x1 - 3*x2 + 0.5 + noise
    let n = 1024usize;
    let mut xs: Vec<f32> = Vec::with_capacity(n * 2);
    let mut ys: Vec<f32> = Vec::with_capacity(n);
    // Simple LCG RNG for reproducibility
    let mut state: u64 = 123456789;
    let mut rand_f32 = || {
        state = state.wrapping_mul(1664525).wrapping_add(1013904223);
        // Take the high 32 bits so the result lands in [0, 1]
        (state >> 32) as f32 / (u32::MAX as f32)
    };
    for _ in 0..n {
        let x1 = rand_f32() * 2.0 - 1.0;
        let x2 = rand_f32() * 2.0 - 1.0;
        let noise = (rand_f32() * 2.0 - 1.0) * 0.05;
        let y = 2.0 * x1 - 3.0 * x2 + 0.5 + noise;
        xs.push(x1);
        xs.push(x2);
        ys.push(y);
    }

    // Normalize targets to [-1, 1] (max-abs scaling) for reasonable loss magnitudes
    let mut max_abs = 0.0f32;
    for &v in &ys {
        let a = v.abs();
        if a > max_abs {
            max_abs = a;
        }
    }
    if max_abs < 1e-8 {
        max_abs = 1.0;
    }
    for v in ys.iter_mut() {
        *v /= max_abs;
    }

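    // Worked example of the max-abs scaling above: if max_abs were 5.0, a raw
    // target of 2.25 would map to 2.25 / 5.0 = 0.45.
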
    // Normalize inputs per-feature to [-1, 1] (min-max scaling)
    let mut min1 = f32::INFINITY;
    let mut max1 = f32::NEG_INFINITY;
    let mut min2 = f32::INFINITY;
    let mut max2 = f32::NEG_INFINITY;
    for i in (0..xs.len()).step_by(2) {
        let a = xs[i];
        let b = xs[i + 1];
        if a < min1 {
            min1 = a;
        }
        if a > max1 {
            max1 = a;
        }
        if b < min2 {
            min2 = b;
        }
        if b > max2 {
            max2 = b;
        }
    }
    let rng1 = (max1 - min1).max(1e-8);
    let rng2 = (max2 - min2).max(1e-8);
    for i in (0..xs.len()).step_by(2) {
        let a = xs[i];
        let b = xs[i + 1];
        xs[i] = 2.0 * (a - min1) / rng1 - 1.0;
        xs[i + 1] = 2.0 * (b - min2) / rng2 - 1.0;
    }

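    // Worked example of the min-max scaling above: with min1 = -1.0 and
    // max1 = 1.0, a raw x1 of 0.5 maps to 2.0 * (0.5 + 1.0) / 2.0 - 1.0 = 0.5.
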
    // Train/Val split (80/20); tensors are full-batch: inputs [rows, 2], targets [rows, 1]
    let n_train = (n as f32 * 0.8) as usize;
    let x_train = Tensor::from_slice(&xs[..n_train * 2], vec![n_train, 2]).unwrap();
    let y_train = Tensor::from_slice(&ys[..n_train], vec![n_train, 1]).unwrap();
    let x_val = Tensor::from_slice(&xs[n_train * 2..], vec![n - n_train, 2]).unwrap();
    let y_val = Tensor::from_slice(&ys[n_train..], vec![n - n_train, 1]).unwrap();

    // Model config: 2 -> 64 -> 64 -> 1
    let cfg = FeedForwardConfig {
        input_size: 2,
        hidden_sizes: vec![64, 64],
        output_size: 1,
        use_bias: true,
    };
    let mut net = FeedForwardNetwork::new(cfg, Some(2025));

    // Optimizer and parameter linking: register each parameter with Adam once
    let mut opt = Adam::with_learning_rate(1e-3);
    for p in net.parameters() {
        opt.add_parameter(p);
    }

    let epochs = 400usize;
    let max_grad_norm = 1.0f32;
    let mut best_val_rmse = f32::INFINITY;
    let mut best_val_r2 = -f32::INFINITY;

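    // Full-batch training loop: zero grads, forward, backward, clip, step,
    // log metrics, then free the autograd graph before the next epoch.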
    for e in 0..epochs {
        // Zero grads
        {
            let mut params = net.parameters();
            opt.zero_grad(&mut params);
        }

        // Forward and backward
        let pred = net.forward(&x_train);
        let mut loss = mse(&pred, &y_train);
        loss.backward(None);

        // Step: only parameters that actually received gradients participate
        {
            let params = net.parameters();
            let mut with_grads: Vec<&mut Tensor> = Vec::new();
            for p in params {
                if p.grad_owned().is_some() {
                    with_grads.push(p);
                }
            }
            if !with_grads.is_empty() {
                clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
                opt.step(&mut with_grads);
                // Clear the consumed gradients so the next epoch starts clean
                opt.zero_grad(&mut with_grads);
            }
        }

        // Metrics
        let train_rmse = rmse(&pred, &y_train);
        let train_r2 = r2_score(&pred, &y_train);
        let val_pred = net.forward(&x_val);
        let val_rmse = rmse(&val_pred, &y_val);
        let val_r2 = r2_score(&val_pred, &y_val);
        if val_rmse < best_val_rmse {
            best_val_rmse = val_rmse;
        }
        if val_r2 > best_val_r2 {
            best_val_r2 = val_r2;
        }

        if e % 20 == 0 || e + 1 == epochs {
            // Clamp displayed R^2 to avoid huge negative prints on early epochs
            let train_r2_disp = train_r2.max(-10.0);
            let val_r2_disp = val_r2.max(-10.0);
            println!(
                "epoch {:4} | train_rmse={:.4} r2={:.3} | val_rmse={:.4} r2={:.3} | best_val_rmse={:.4} best_val_r2={:.3}",
                e, train_rmse, train_r2_disp, val_rmse, val_r2_disp, best_val_rmse, best_val_r2
            );
        }

        // Free this epoch's computation graphs to keep memory bounded
        clear_all_graphs_known();
    }

    // Quick sanity predictions on two raw samples. These inputs skip the
    // min-max scaling applied to the training data, and the outputs are in
    // normalized target units (raw y divided by max_abs).
    let sample = Tensor::from_slice(&[0.5, -0.25, -0.8, 0.3], vec![2, 2]).unwrap();
    let sample_pred = net.forward(&sample);
    println!("sample preds: {:?}", sample_pred.data());
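    // For reference, the noiseless targets for these rows are
    // 2.0 * 0.5 - 3.0 * (-0.25) + 0.5 = 2.25 and 2.0 * (-0.8) - 3.0 * 0.3 + 0.5 = -2.0,
    // so the printed values should land roughly near 2.25 / max_abs and -2.0 / max_abs.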

    println!("=== Supervised regression finished ===");
    Ok(())
}