Skip to main content

triplets/
triplets.rs

1mod data;
2
3use std::{
4    env,
5    fs::{self, File},
6};
7
8use env_logger::{Builder, Target};
9use log::{error, info};
10use runn::{
11    adam::Adam,
12    cross_entropy::CrossEntropy,
13    csv::CSV,
14    dense_layer::Dense,
15    helper,
16    network::network_model::{Network, NetworkBuilder},
17    network_io::JSON,
18    network_search::NetworkSearchBuilder,
19    numbers::{Numbers, SequentialNumbers},
20    relu::ReLU,
21    softmax::Softmax,
22};
23
/// Experiment name: used as the output directory and as the prefix of the log,
/// saved-network (`<name>_network`), and search-result (`<name>_search`) file names.
const EXP_NAME: &str = "triplets";
25
// Triplets is a multi-class classification problem with one-hot encoded targets over 3 classes:
// predict 1,0,0 if all input elements are the same
// predict 0,1,0 if exactly two of the input elements are the same
// predict 0,0,1 if none of the input elements are the same
31
32/// This example demonstrates how to train and validate a neural network on the Triplets dataset.
33/// Triplets is a Multi-class classification problem.
34///
35///  - predict 1,0,0 if all input elements are same
36///  - predict 0,1,0 if only two of the input elements are same
37///  - predict 0,0,1 if none of the input elements are same
38///
39/// The code includes functions to load the dataset, build the neural network,
40/// train the network, validate its performance, and perform a hyperparameter search.
41///
42/// to run the example:
43/// ```bash
44/// cargo run --example triplets
45/// ```
46/// to run the hyperparameter search:
47/// ```bash
48/// cargo run --example triplets -- -search
49/// ```
50/// The hyperparameter search will create a CSV file with the results in the `triplets` directory.
51/// The training and validation results will be logged in the `triplets` directory.
52fn main() {
53    initialize_logger(EXP_NAME);
54
55    let args: Vec<String> = env::args().collect();
56    if args.contains(&"-search".to_string()) {
57        search();
58    } else {
59        train_and_validate();
60    }
61}
62
63fn train_and_validate() {
64    let network_file = format!("{}_network", EXP_NAME);
65    let training_inputs = data::training_inputs();
66    let training_targets = data::training_targets();
67    let mut network = triplets_network(training_inputs.cols(), training_targets.cols());
68
69    let train_result = network.train(&training_inputs, &training_targets);
70    match train_result {
71        Ok(_) => {
72            info!("Training successfully completed");
73            network
74                .save(
75                    JSON::default()
76                        .directory(EXP_NAME)
77                        .file_name(&network_file)
78                        .build()
79                        .unwrap(),
80                )
81                .unwrap();
82            let net_results = network.predict(&training_inputs, &training_targets).unwrap();
83            info!(
84                "{}",
85                helper::pretty_compare_matrices(
86                    &training_inputs,
87                    &training_targets,
88                    &net_results.predictions,
89                    helper::CompareMode::Classification
90                )
91            );
92            info!("Training: {}", net_results.display_metrics());
93        }
94        Err(e) => {
95            eprintln!("Training failed: {}", e);
96        }
97    }
98
99    network = Network::load(
100        JSON::default()
101            .directory(EXP_NAME)
102            .file_name(&network_file)
103            .build()
104            .unwrap(),
105    )
106    .unwrap();
107    let validation_inputs = data::validation_inputs();
108    let validation_targets = data::validation_targets();
109    let net_results = network.predict(&validation_inputs, &validation_targets).unwrap();
110    log::info!(
111        "{}",
112        helper::pretty_compare_matrices(
113            &validation_inputs,
114            &validation_targets,
115            &net_results.predictions,
116            helper::CompareMode::Classification
117        )
118    );
119    info!("Validation: {}", net_results.display_metrics());
120}
121
122fn search() {
123    let training_inputs = data::training_inputs();
124    let training_targets = data::training_targets();
125
126    let validation_inputs = data::validation_inputs();
127    let validation_targets = data::validation_targets();
128
129    let network = triplets_network(training_inputs.cols(), training_targets.cols());
130
131    let network_search = NetworkSearchBuilder::new()
132        .network(network)
133        .parallelize(4)
134        .learning_rates(
135            SequentialNumbers::new()
136                .lower_limit(0.0025)
137                .upper_limit(0.0035)
138                .increment(0.0005)
139                .floats(),
140        )
141        .batch_sizes(
142            SequentialNumbers::new()
143                .lower_limit(5.0)
144                .upper_limit(10.0)
145                .increment(1.0)
146                .ints(),
147        )
148        .hidden_layer(
149            SequentialNumbers::new()
150                .lower_limit(12.0)
151                .upper_limit(24.0)
152                .increment(4.0)
153                .ints(),
154            ReLU::build(),
155        )
156        .export(
157            CSV::default()
158                .directory(EXP_NAME)
159                .file_name(&format!("{}_search", EXP_NAME))
160                .build(),
161        )
162        .build();
163
164    let mut network_search = match network_search {
165        Ok(ns) => ns,
166        Err(e) => {
167            error!("Failed to build network_search: {}", e);
168            std::process::exit(1);
169        }
170    };
171
172    let search_res = network_search
173        .search(&training_inputs, &training_targets, &validation_inputs, &validation_targets)
174        .unwrap();
175
176    info!("Num Results: {}", search_res.len());
177}
178
179fn triplets_network(inp_size: usize, targ_size: usize) -> Network {
180    let network = NetworkBuilder::new(inp_size, targ_size)
181        .layer(Dense::default().size(24).activation(ReLU::build()).build())
182        .layer(Dense::default().size(targ_size).activation(Softmax::build()).build())
183        .loss_function(CrossEntropy::default().epsilon(1e-8).build())
184        .optimizer(Adam::default().beta1(0.99).beta2(0.999).learning_rate(0.0035).build())
185        .batch_size(8)
186        .batch_group_size(2)
187        .parallelize(2)
188        .epochs(1000)
189        .seed(55)
190        .build();
191
192    match network {
193        Ok(net) => net,
194        Err(e) => {
195            eprintln!("Failed to build network: {}", e);
196            std::process::exit(1);
197        }
198    }
199}
200
201/// Initializes the logger for the application.
202/// The LOG environment variable is used to define the log level (e.g., info, debug, warn, error).
203/// If the LOG variable is not set, it defaults to info.
204fn initialize_logger(name: &str) {
205    // Check if the directory exists, and attempt to create it if it doesn't
206    if !std::path::Path::new(name).exists() {
207        let _res = fs::create_dir_all(name).map_err(|e| {
208            eprintln!("Failed to create log directory: {}", e);
209        });
210    }
211
212    // Attempt to create a log file
213    let log_file = match File::create(format!("./{}/{}.log", name, name)) {
214        Ok(file) => file,
215        Err(e) => {
216            eprintln!("Failed to create log file: {}", e);
217            return;
218        }
219    };
220
221    // Check if the "LOG" environment variable is set
222    let log_level = env::var("LOG").unwrap_or_else(|_| "info".to_string()); // Default to "info"
223
224    // Initialize the logger with the specified log level
225    Builder::new()
226        .target(Target::Pipe(Box::new(log_file)))
227        .parse_filters(&log_level) // Use the log level from the environment variable
228        .init();
229}