supervised_bce/
supervised_bce.rs

//! Supervised learning with a small Feed-Forward Network (Train Station public API)
//!
//! - Binary classification on the classic XOR dataset
//! - Reuses the `feedforward_network.rs` building block (no duplication)
//! - Proper parameter linking, zero_grad, backward, step, and graph clearing
//! - Gradient clipping and helpful console logging
//!
//! Run:
//!   cargo run --release --example supervised_bce

use train_station::{
    gradtrack::clear_all_graphs_known,
    optimizers::{Adam, Optimizer},
    Tensor,
};

#[allow(clippy::duplicate_mod)]
#[path = "../neural_networks/feedforward_network.rs"]
mod feedforward_network;
use feedforward_network::{FeedForwardConfig, FeedForwardNetwork};

// Small helper: global-norm gradient clipping.
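// If the global L2 norm of all gradients exceeds `max_norm`, every gradient
// is rescaled by max_norm / (norm + eps), so the combined norm shrinks to
// (approximately) max_norm. Worked example: gradients [3.0, 4.0] have norm
// 5.0; with max_norm = 1.0 the scale is ~0.2 and the clipped gradients
// become [0.6, 0.8].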
fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
    let mut total_sq = 0.0f32;
    for p in parameters.iter() {
        if let Some(g) = p.grad_owned() {
            for &v in g.data() {
                total_sq += v * v;
            }
        }
    }
    let norm = total_sq.sqrt();
    if norm > max_norm {
        let scale = max_norm / (norm + eps);
        for p in parameters.iter_mut() {
            if let Some(g) = p.grad_owned() {
                p.set_grad(g.mul_scalar(scale));
            }
        }
    }
}

fn accuracy(pred: &Tensor, targets: &Tensor) -> f32 {
    // pred: [B,1] sigmoid probabilities; threshold at 0.5
    let p = pred.data();
    let t = targets.data();
    let mut correct = 0usize;
    for i in 0..p.len() {
        let yhat = if p[i] >= 0.5 { 1.0 } else { 0.0 };
        if (yhat - t[i]).abs() < 1e-6 {
            correct += 1;
        }
    }
    correct as f32 / (p.len() as f32)
}

// Numerically stable BCE with logits:
// L = mean( relu(z) - z*y + log(1 + exp(-|z|)) )
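//
// This is the standard rewrite of the naive form
//   L = -mean( y*log(sigmoid(z)) + (1-y)*log(1 - sigmoid(z)) ),
// which overflows in exp() for large |z|. The expression below only ever
// exponentiates -|z| <= 0, so it stays finite for any logit.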
fn bce_with_logits(logits: &Tensor, targets: &Tensor) -> Tensor {
    let relu_z = logits.relu();
    let zy = logits.mul_tensor(targets);
    // |z| = relu(z) + relu(-z)
    let abs_z = relu_z.add_tensor(&logits.mul_scalar(-1.0).relu());
    let log_term = abs_z.mul_scalar(-1.0).exp().add_scalar(1.0).log();
    relu_z.sub_tensor(&zy).add_tensor(&log_term).mean()
}

pub fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== Supervised FFN Example (XOR) ===");

    // Dataset: XOR (repeated to form a small batch)
    let inputs: Vec<f32> = vec![
        0.0, 0.0, // -> 0
        0.0, 1.0, // -> 1
        1.0, 0.0, // -> 1
        1.0, 1.0, // -> 0
    ];
    let targets: Vec<f32> = vec![0.0, 1.0, 1.0, 0.0];

    // Repeat the base patterns to stabilize training
    let repeats = 64usize; // effective batch = 4 * repeats = 256
    let mut xs = Vec::with_capacity(repeats * inputs.len());
    let mut ys = Vec::with_capacity(repeats * targets.len());
    for _ in 0..repeats {
        xs.extend_from_slice(&inputs);
        ys.extend_from_slice(&targets);
    }

    let batch = xs.len() / 2; // two features per sample
    let x_t = Tensor::from_slice(&xs, vec![batch, 2]).unwrap();
    let y_t = Tensor::from_slice(&ys, vec![batch, 1]).unwrap();

    // Model config: 2 -> 32 -> 32 -> 1; the final sigmoid lives in the loss path
    let cfg = FeedForwardConfig {
        input_size: 2,
        hidden_sizes: vec![32, 32],
        output_size: 1,
        use_bias: true,
    };
    let mut net = FeedForwardNetwork::new(cfg, Some(777));

    // Optimizer and parameter linking
    let mut opt = Adam::with_learning_rate(1e-3);
    for p in net.parameters() {
        opt.add_parameter(p);
    }

    let epochs = 1000usize;
    let max_grad_norm = 1.0f32;
    let mut best_loss = f32::INFINITY;
    let mut best_acc = 0.0f32;

    for e in 0..epochs {
        // Zero grads each iteration
        {
            let mut params = net.parameters();
            opt.zero_grad(&mut params);
        }

        // Forward -> logits; use the numerically stable BCE-with-logits loss
        let logits = net.forward(&x_t);
        let mut loss = bce_with_logits(&logits, &y_t);
        loss.backward(None);

        // Step only the params that received grads
        {
            let params = net.parameters();
            let mut with_grads: Vec<&mut Tensor> = Vec::new();
            for p in params {
                if p.grad_owned().is_some() {
                    with_grads.push(p);
                }
            }
            if !with_grads.is_empty() {
                clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
                opt.step(&mut with_grads);
                opt.zero_grad(&mut with_grads);
            }
        }

        // Metrics (sigmoid is applied only for reporting accuracy)
        let preds = logits.sigmoid();
        let acc = accuracy(&preds, &y_t);
        if loss.value() < best_loss {
            best_loss = loss.value();
        }
        if acc > best_acc {
            best_acc = acc;
        }
        if e % 10 == 0 || e + 1 == epochs {
            println!(
                "epoch {:4} | loss={:.5} acc={:.3} | best_loss={:.5} best_acc={:.3}",
                e,
                loss.value(),
                acc,
                best_loss,
                best_acc
            );
        }

        // Clear graphs to avoid stale accumulation across epochs
        clear_all_graphs_known();
    }

    // Quick sanity check of the predictions on the four base patterns
    let test = Tensor::from_slice(&inputs, vec![4, 2]).unwrap();
    let out = net.forward(&test).sigmoid();
    println!("predictions (approx): {:?}", out.data());

    println!("=== Supervised training finished ===");
    Ok(())
}
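
// A minimal sanity-check sketch for the helpers above (not part of the
// original example). It assumes the `Tensor` accessors already used in this
// file (`from_slice`, `value`, `grad_owned`, `set_grad`, `data`) behave as in
// the training loop; it runs via `cargo test --example supervised_bce` only
// if the example target enables tests in Cargo.toml.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn bce_with_logits_matches_closed_form_at_zero() {
        // For z = 0, y = 1: relu(0) - 0*1 + log(1 + exp(0)) = ln(2).
        let logits = Tensor::from_slice(&[0.0f32], vec![1, 1]).unwrap();
        let targets = Tensor::from_slice(&[1.0f32], vec![1, 1]).unwrap();
        let loss = bce_with_logits(&logits, &targets);
        assert!((loss.value() - std::f32::consts::LN_2).abs() < 1e-5);
    }

    #[test]
    fn clip_gradients_scales_norm_down() {
        // Gradient [3, 4] has norm 5; clipping at 1.0 should rescale it
        // to ~[0.6, 0.8], i.e. a global norm of ~1.0.
        let mut p = Tensor::from_slice(&[0.0f32, 0.0], vec![2, 1]).unwrap();
        p.set_grad(Tensor::from_slice(&[3.0f32, 4.0], vec![2, 1]).unwrap());
        clip_gradients(&mut [&mut p], 1.0, 1e-6);
        let g = p.grad_owned().expect("grad present");
        let norm: f32 = g.data().iter().map(|v| v * v).sum::<f32>().sqrt();
        assert!(norm <= 1.0 + 1e-3);
        assert!(norm > 0.9); // scaled down, not zeroed out
    }
}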