supervised_classification/supervised_classification.rs

//! Supervised classification with Feed-Forward Network (softmax + cross-entropy)
//!
//! - Multi-class toy dataset (3 classes) generated synthetically
//! - Reuses the `feedforward_network.rs` building block
//! - Numerically stable cross-entropy over logits (no softmax in the loss path)
//! - Input normalization to [-1, 1], parameter linking, gradient clipping, graph clearing
//! - Logs loss and accuracy
//!
//! Run:
//!   cargo run --release --example supervised_classification

use train_station::{
    gradtrack::clear_all_graphs_known,
    optimizers::{Adam, Optimizer},
    Tensor,
};

#[allow(clippy::duplicate_mod)]
#[path = "../neural_networks/feedforward_network.rs"]
mod feedforward_network;
use feedforward_network::{FeedForwardConfig, FeedForwardNetwork};
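
// Clips gradients by their global L2 norm: if the norm exceeds `max_norm`,
// every gradient is scaled by `max_norm / (norm + eps)`, bounding the update
// magnitude without changing its direction.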
fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
    let mut total_sq = 0.0f32;
    for p in parameters.iter() {
        if let Some(g) = p.grad_owned() {
            for &v in g.data() {
                total_sq += v * v;
            }
        }
    }
    let norm = total_sq.sqrt();
    if norm > max_norm {
        let scale = max_norm / (norm + eps);
        for p in parameters.iter_mut() {
            if let Some(g) = p.grad_owned() {
                p.set_grad(g.mul_scalar(scale));
            }
        }
    }
}

// Cross-entropy over logits: CE = -mean(log_softmax(logits)[i, labels[i]]).
// Working directly on logits keeps the loss path numerically stable: softmax
// is never materialized, and with one-hot targets the loss reduces to the
// negative log-probability of the true class (e.g. uniform logits over 3
// classes give ln(3) ≈ 1.0986 for any labels).
fn cross_entropy_logits(
    logits: &Tensor,
    labels: &[usize],
    batch: usize,
    _num_classes: usize,
) -> Tensor {
    // log_softmax = logits - logsumexp(logits, dim=1), computed with the
    // max-subtraction trick: logsumexp(z) = max(z) + log(sum(exp(z - max(z)))).
    let max_logits = logits.max_dims(&[1], true);
    let shifted = logits.sub_tensor(&max_logits);
    let exp = shifted.exp();
    let sum_exp = exp.sum_dims(&[1], true);
    let log_sum_exp = sum_exp.log();
    let log_softmax = shifted.sub_tensor(&log_sum_exp);
    let ll = log_softmax.gather(1, labels, &[batch, 1]); // selected log-probs
    ll.mul_scalar(-1.0).mean()
}
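
// Fraction of rows whose argmax over the logits matches the label.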
fn accuracy_from_logits(
    logits: &Tensor,
    labels: &[usize],
    batch: usize,
    num_classes: usize,
) -> f32 {
    let row = logits.data();
    let mut correct = 0usize;
    for (i, &label) in labels.iter().enumerate().take(batch) {
        let base = i * num_classes;
        let mut best_j = 0usize;
        let mut best_v = row[base];
        for j in 1..num_classes {
            let v = row[base + j];
            if v > best_v {
                best_v = v;
                best_j = j;
            }
        }
        if best_j == label {
            correct += 1;
        }
    }
    correct as f32 / batch as f32
}

pub fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== Supervised Classification Example (Cross-Entropy) ===");

    // Synthetic 2D inputs, 3 classes with roughly linear separations
    let n = 1200usize;
    let classes = 3usize;
    let mut xs: Vec<f32> = Vec::with_capacity(n * 2);
    let mut ys: Vec<usize> = Vec::with_capacity(n);

    // Simple deterministic LCG; truncating the shifted state to 32 bits keeps
    // the quotient in [0, 1].
    let mut state: u64 = 424242;
    let mut rand_f32 = || {
        state = state.wrapping_mul(1664525).wrapping_add(1013904223);
        (state >> 16) as u32 as f32 / (u32::MAX as f32)
    };

    for _ in 0..n {
        let x1 = rand_f32() * 4.0 - 2.0;
        let x2 = rand_f32() * 4.0 - 2.0;
        // Assign a class by a quadrant-like rule, then flip 5% of labels as noise.
        let mut c = if x1 + 0.5 * x2 > 0.5 {
            0
        } else if x1 - x2 < -0.5 {
            1
        } else {
            2
        };
        if rand_f32() < 0.05 {
            c = (c + 1) % classes;
        }
        xs.push(x1);
        xs.push(x2);
        ys.push(c);
    }

    // Normalize inputs per feature to [-1, 1]: x' = 2 * (x - min) / (max - min) - 1
    let mut min1 = f32::INFINITY;
    let mut max1 = f32::NEG_INFINITY;
    let mut min2 = f32::INFINITY;
    let mut max2 = f32::NEG_INFINITY;
    for i in (0..xs.len()).step_by(2) {
        let a = xs[i];
        let b = xs[i + 1];
        if a < min1 {
            min1 = a;
        }
        if a > max1 {
            max1 = a;
        }
        if b < min2 {
            min2 = b;
        }
        if b > max2 {
            max2 = b;
        }
    }
    let rng1 = (max1 - min1).max(1e-8);
    let rng2 = (max2 - min2).max(1e-8);
    for i in (0..xs.len()).step_by(2) {
        let a = xs[i];
        let b = xs[i + 1];
        xs[i] = 2.0 * (a - min1) / rng1 - 1.0;
        xs[i + 1] = 2.0 * (b - min2) / rng2 - 1.0;
    }

    // Train/Val split (80/20)
    let n_train = (n as f32 * 0.8) as usize;
    let x_train = Tensor::from_slice(&xs[..n_train * 2], vec![n_train, 2]).unwrap();
    let y_train = ys[..n_train].to_vec();
    let x_val = Tensor::from_slice(&xs[n_train * 2..], vec![n - n_train, 2]).unwrap();
    let y_val = ys[n_train..].to_vec();

    // Model: 2 -> 64 -> 64 -> 3 (logits)
    let cfg = FeedForwardConfig {
        input_size: 2,
        hidden_sizes: vec![64, 64],
        output_size: classes,
        use_bias: true,
    };
    let mut net = FeedForwardNetwork::new(cfg, Some(303));

    // Optimizer: Adam with a fixed learning rate; parameters are linked once up front.
    let mut opt = Adam::with_learning_rate(1e-3);
    for p in net.parameters() {
        opt.add_parameter(p);
    }

    let epochs = 300usize;
    let max_grad_norm = 1.0f32;
    let mut best_val_acc = 0.0f32;
    let mut best_val_loss = f32::INFINITY;

    for e in 0..epochs {
        // Zero any stale gradients before the forward/backward pass
        {
            let mut params = net.parameters();
            opt.zero_grad(&mut params);
        }

        // Forward pass to logits, then backprop the cross-entropy loss
        let logits = net.forward(&x_train);
        let mut loss = cross_entropy_logits(&logits, &y_train, n_train, classes);
        loss.backward(None);

        // Collect parameters that received gradients, clip by global norm, step
        {
            let params = net.parameters();
            let mut with_grads: Vec<&mut Tensor> = Vec::new();
            for p in params {
                if p.grad_owned().is_some() {
                    with_grads.push(p);
                }
            }
            if !with_grads.is_empty() {
                clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
                opt.step(&mut with_grads);
                opt.zero_grad(&mut with_grads);
            }
        }

        // Metrics on the train and validation splits
        let train_acc = accuracy_from_logits(&logits, &y_train, n_train, classes);
        let val_logits = net.forward(&x_val);
        let val_loss = cross_entropy_logits(&val_logits, &y_val, n - n_train, classes).value();
        let val_acc = accuracy_from_logits(&val_logits, &y_val, n - n_train, classes);
        if val_acc > best_val_acc {
            best_val_acc = val_acc;
        }
        if val_loss < best_val_loss {
            best_val_loss = val_loss;
        }

        if e % 10 == 0 || e + 1 == epochs {
            println!(
                "epoch {:4} | loss={:.4} acc={:.3} | val_loss={:.4} val_acc={:.3} | best_val_acc={:.3}",
                e, loss.value(), train_acc, val_loss, val_acc, best_val_acc
            );
        }

        // Release the computation graphs recorded this epoch so memory does not
        // grow across iterations.
        clear_all_graphs_known();
    }

    // Quick sample predictions via softmax over the trained logits
    let samples = Tensor::from_slice(&[-1.0, -1.0, 0.0, 0.0, 1.0, 1.0], vec![3, 2]).unwrap();
    let sm = net.forward(&samples).softmax(1);
    println!("sample class probs: {:?}", sm.data());

    println!(
        "best_val_loss={:.4} best_val_acc={:.3}",
        best_val_loss, best_val_acc
    );
    println!("=== Supervised classification finished ===");
    Ok(())
}
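
// Sanity checks for the helpers above: a minimal sketch that assumes this
// example can be built as a test target (e.g. `cargo test --example
// supervised_classification`); under `cargo run` the module compiles to nothing.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn accuracy_matches_argmax() {
        // Row argmaxes are class 2 and class 0.
        let logits = Tensor::from_slice(&[0.1, 0.2, 0.9, 0.8, 0.1, 0.1], vec![2, 3]).unwrap();
        assert_eq!(accuracy_from_logits(&logits, &[2, 0], 2, 3), 1.0);
        assert_eq!(accuracy_from_logits(&logits, &[0, 0], 2, 3), 0.5);
    }

    #[test]
    fn uniform_logits_give_log_num_classes() {
        // With all-zero logits, log_softmax is -ln(3) everywhere, so the
        // cross-entropy equals ln(3) for any labeling.
        let logits = Tensor::from_slice(&[0.0f32; 6], vec![2, 3]).unwrap();
        let loss = cross_entropy_logits(&logits, &[0, 1], 2, 3);
        assert!((loss.value() - 3.0f32.ln()).abs() < 1e-4);
    }
}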