sine_wave/
sine_wave.rs

1use std::f32::consts::PI;
2use vexus::{NeuralNetwork, Sigmoid};
3
/// Linearly rescales `x` from the interval `[min, max]` onto `[0, 1]`.
///
/// `min` maps to `0.0` and `max` maps to `1.0`; values outside the interval
/// land outside `[0, 1]`. No clamping is performed. When `min == max` the
/// division is by zero, so the result is NaN or infinite.
fn normalize(x: f32, min: f32, max: f32) -> f32 {
    let span = max - min;
    let offset = x - min;
    offset / span
}
7
/// Maps a value `x` expressed on the unit scale back onto `[min, max]`.
///
/// `0.0` maps to `min` and `1.0` maps to `max`; this is the inverse of the
/// `(x - min) / (max - min)` rescaling. No clamping is performed.
fn denormalize(x: f32, min: f32, max: f32) -> f32 {
    let span = max - min;
    min + x * span
}
11
12fn main() {
13    // Create a network with 1 input, two hidden layers, and 1 output
14    // Larger architecture to handle the complexity of sine function
15    let mut nn = NeuralNetwork::new(vec![1, 4, 4, 1], 0.001, Box::new(Sigmoid));
16
17    // Generate training data: sin(x) for x in [0, 2π]
18    let training_data: Vec<(Vec<f32>, Vec<f32>)> = (0..200)
19        .map(|i| {
20            let x = (i as f32) * 2.0 * PI / 200.0;
21            let normalized_x = normalize(x, 0.0, 2.0 * PI);
22            let normalized_sin = normalize(x.sin(), -1.0, 1.0);
23            (vec![normalized_x], vec![normalized_sin])
24        })
25        .collect();
26
27    // Train the network
28    println!("Training...");
29    for epoch in 0..1000000 {
30        let mut total_error = 0.0;
31        for (input, expected) in &training_data {
32            let _outputs = nn.forward(&vec![input[0]]);
33            nn.backpropagate(&vec![expected[0]]);
34            total_error += nn.errors(&vec![expected[0]]);
35        }
36
37        if epoch % 1000 == 0 {
38            println!(
39                "Epoch {}: MSE = {:.6}",
40                epoch,
41                total_error / training_data.len() as f32
42            );
43        }
44    }
45
46    // Test the network
47    println!("\nTesting...");
48    let test_points = vec![0.0, PI / 4.0, PI / 2.0, PI, 3.0 * PI / 2.0, 2.0 * PI];
49    for x in test_points {
50        let normalized_x = normalize(x, 0.0, 2.0 * PI);
51        let predicted = denormalize(nn.forward(&vec![normalized_x])[0], -1.0, 1.0);
52        println!(
53            "x = {:.3}, sin(x) = {:.3}, predicted = {:.3}, error = {:.3}",
54            x,
55            x.sin(),
56            predicted,
57            (x.sin() - predicted).abs()
58        );
59    }
60}