sine_wave/
sine_wave.rs

1use std::f32::consts::PI;
2use vexus::NeuralNetwork;
3
/// Min-max scales `x` from the interval `[min, max]` into `[0, 1]`.
///
/// No clamping is performed: values outside `[min, max]` map outside `[0, 1]`.
/// If `min == max` the denominator is zero and the result is NaN or ±infinity
/// (IEEE-754 division), so callers must pass a non-degenerate range.
fn normalize(x: f32, min: f32, max: f32) -> f32 {
    (x - min) / (max - min)
}
7
/// Inverse of [`normalize`]: maps `x` from `[0, 1]` back onto `[min, max]`.
///
/// No clamping is performed: values outside `[0, 1]` map outside `[min, max]`.
fn denormalize(x: f32, min: f32, max: f32) -> f32 {
    x * (max - min) + min
}
11
12fn main() {
13    // Create a network with 1 input, two hidden layers, and 1 output
14    // Larger architecture to handle the complexity of sine function
15    let mut nn = NeuralNetwork::new(vec![1, 32, 32, 1], 0.005);
16
17    // Generate training data: sin(x) for x in [0, 2π]
18    let training_data: Vec<(Vec<f32>, Vec<f32>)> = (0..200)
19        .map(|i| {
20            let x = (i as f32) * 2.0 * PI / 200.0;
21            let normalized_x = normalize(x, 0.0, 2.0 * PI);
22            let normalized_sin = normalize(x.sin(), -1.0, 1.0);
23            (vec![normalized_x], vec![normalized_sin])
24        })
25        .collect();
26
27    // Train the network
28    println!("Training...");
29    for epoch in 0..20000 {
30        let mut total_error = 0.0;
31        for (input, expected) in &training_data {
32            nn.forward(input.clone());
33            let output = nn.get_outputs();
34            let error = expected[0] - output[0];
35            total_error += error * error;
36            nn.backwards(vec![error]);
37        }
38
39        if epoch % 1000 == 0 {
40            println!(
41                "Epoch {}: MSE = {:.6}",
42                epoch,
43                total_error / training_data.len() as f32
44            );
45        }
46    }
47
48    // Test the network
49    println!("\nTesting...");
50    let test_points = vec![0.0, PI / 4.0, PI / 2.0, PI, 3.0 * PI / 2.0, 2.0 * PI];
51    for x in test_points {
52        let normalized_x = normalize(x, 0.0, 2.0 * PI);
53        nn.forward(vec![normalized_x]);
54        let predicted = denormalize(nn.get_outputs()[0], -1.0, 1.0);
55        println!(
56            "x = {:.3}, sin(x) = {:.3}, predicted = {:.3}, error = {:.3}",
57            x,
58            x.sin(),
59            predicted,
60            (x.sin() - predicted).abs()
61        );
62    }
63}