Function batch_run

Source
pub fn batch_run(
    networks: &mut [NeuralNetwork],
    inputs: &[f64],
) -> Vec<Vec<f64>>
Expand description

Returns a `Vec<Vec<f64>>` holding the output of every `NeuralNetwork` passed to this function, one inner `Vec` per network in the same order as `networks`. It works by calling `NeuralNetwork::run()` on each network in turn with the same `inputs`, so it is no more efficient than doing that yourself; it exists purely for convenience.

Examples found in repository?
examples/simple_use_case.rs (line 49)
21fn main() {
22    /// This is our target, we will consider `NeuralNetworks` that get closer to this to be learning.
23    const TARGET:f64 = 10.0;
24
25    /// The margin of error for the networks.
26    const MARGIN:f64 = 0.1;
27
28    /// This is the `INPUT` for the networks, the goal being for them to "learn" how to take this number as
29    /// an input and output `TARGET`. When they output a value within `MARGIN` of `TARGET` that `NeuralNetwork`
30    /// will be considered to have reach the goal.
31    const INPUT:f64 = 1.0;
32
33    /// The maximum number of `GENERATIONS` that will be ran, less may be ran if a network gets within `margin` of
34    /// the target before `GENERATIONS` number of generations are ran.
35    const GENERATIONS:usize = 10_000;
36
37    // Create a `Vector` containing all `NeuralNetwork`s in the current generation using `batch_mutate()` and `NeuralNetwork::new()`
38    let mut networks:Vec<NeuralNetwork> = batch_mutate(5,0.25,&mut NeuralNetwork::new(1.0,1,3,1), true);
39
40    // This stores the closest value found by the network, it defaults to negative infinity.
41    let mut closest_value:f64 = f64::NEG_INFINITY;
42
43    // Get the current instant so that it can later be used to time how long it took to finish/fail
44    let time:Instant = Instant::now();
45
46    // For `generation` in `GENERATIONS` perform `batch_run()` and check if the networks are getting closer.
47    for generation in 0..GENERATIONS {
48        // Run the networks using `INPUT` as an input and store the output in `output`
49        let outputs:Vec<Vec<f64>> = batch_run(&mut networks, &vec![INPUT]);
50
51        // The `closest_network` used for creating the next generation.
52        let mut closest_network:usize = 0;
53
54        // Loop through every value in `outputs` checking to see if any of the outputs are within `MARGIN` of `TARGET`
55        // And use a range so that we can track the index of the output easily.
56        for output in 0..outputs.len() {
57            // Since the networks are only outputting a single value we can simply grab the first value of the `Vector`
58            // and check if thats within `MARGIN` of `TARGET` using a range.
59            if (TARGET-MARGIN..=TARGET+MARGIN).contains(&outputs[output][0]) {
60                // If true then print the value found by the network, the network itself, the current generation, and exit from the program.
61                println!("Finished in {:?}!\nGeneration: {:?}\nValue: {:?}\nNetwork: {{\nHiddenLayers: {:?}\nOutputLayer: {:?}\n}}", time.elapsed(),generation, outputs[output][0], networks[output].get_weights(), networks[output].get_output_weights());
62                // Exit code 0 on Linux means no problem, on Windows however this should be 256 but that is outside the scope of this example.
63                std::process::exit(0);
64            } else {
65                // If the `output` was not within `MARGIN` of `TARGET` then check if this value is closer to `TARGET` than the last `closest_value`.
66                // and set `closest_value` to `output` if it is closer.
67                if outputs[output][0] < TARGET && outputs[output][0] > closest_value {
68                    closest_value = outputs[output][0];
69                    closest_network = output;
70                } else if outputs[output][0] > TARGET && outputs[output][0] < closest_value {
71                    closest_value = outputs[output][0];
72                    closest_network = output;
73                }
74            }
75        }
76
77        // Set all `networks` to various mutations of the `closest_network`.
78        networks = batch_mutate(5, 0.25, &networks[closest_network], true);
79    }
80
81    // If we managed to get through `GENERATIONS` number of generations without getting within `MARGIN` of `TARGET` then output the `closest_value` we found.
82    println!("Failed to get within `MARGIN` within {:?} number of generations, this is the `closest_value` obtained: {:?}. In {:?}", GENERATIONS, closest_value, time.elapsed());
83}