1mod data;
2
3use std::{
4 env,
5 fs::{self, File},
6};
7
8use env_logger::{Builder, Target};
9use log::{error, info};
10use runn::{
11 adam::Adam,
12 cross_entropy::CrossEntropy,
13 csv::CSV,
14 dense_layer::Dense,
15 helper,
16 network::network_model::{Network, NetworkBuilder},
17 network_io::JSON,
18 network_search::NetworkSearchBuilder,
19 numbers::{Numbers, SequentialNumbers},
20 relu::ReLU,
21 softmax::Softmax,
22};
23
/// Experiment name. Used as the directory for the log file, the saved network and
/// the search export, and as the prefix of the `_network` / `_search` file names
/// (the log file is `<EXP_NAME>/<EXP_NAME>.log`).
const EXP_NAME: &str = "triplets";
25
26fn main() {
53 initialize_logger(EXP_NAME);
54
55 let args: Vec<String> = env::args().collect();
56 if args.contains(&"-search".to_string()) {
57 search();
58 } else {
59 train_and_validate();
60 }
61}
62
63fn train_and_validate() {
64 let network_file = format!("{}_network", EXP_NAME);
65 let training_inputs = data::training_inputs();
66 let training_targets = data::training_targets();
67 let mut network = triplets_network(training_inputs.cols(), training_targets.cols());
68
69 let train_result = network.train(&training_inputs, &training_targets);
70 match train_result {
71 Ok(_) => {
72 info!("Training successfully completed");
73 network
74 .save(
75 JSON::default()
76 .directory(EXP_NAME)
77 .file_name(&network_file)
78 .build()
79 .unwrap(),
80 )
81 .unwrap();
82 let net_results = network.predict(&training_inputs, &training_targets).unwrap();
83 info!(
84 "{}",
85 helper::pretty_compare_matrices(
86 &training_inputs,
87 &training_targets,
88 &net_results.predictions,
89 helper::CompareMode::Classification
90 )
91 );
92 info!("Training: {}", net_results.display_metrics());
93 }
94 Err(e) => {
95 eprintln!("Training failed: {}", e);
96 }
97 }
98
99 network = Network::load(
100 JSON::default()
101 .directory(EXP_NAME)
102 .file_name(&network_file)
103 .build()
104 .unwrap(),
105 )
106 .unwrap();
107 let validation_inputs = data::validation_inputs();
108 let validation_targets = data::validation_targets();
109 let net_results = network.predict(&validation_inputs, &validation_targets).unwrap();
110 log::info!(
111 "{}",
112 helper::pretty_compare_matrices(
113 &validation_inputs,
114 &validation_targets,
115 &net_results.predictions,
116 helper::CompareMode::Classification
117 )
118 );
119 info!("Validation: {}", net_results.display_metrics());
120}
121
122fn search() {
123 let training_inputs = data::training_inputs();
124 let training_targets = data::training_targets();
125
126 let validation_inputs = data::validation_inputs();
127 let validation_targets = data::validation_targets();
128
129 let network = triplets_network(training_inputs.cols(), training_targets.cols());
130
131 let network_search = NetworkSearchBuilder::new()
132 .network(network)
133 .parallelize(4)
134 .learning_rates(
135 SequentialNumbers::new()
136 .lower_limit(0.0025)
137 .upper_limit(0.0035)
138 .increment(0.0005)
139 .floats(),
140 )
141 .batch_sizes(
142 SequentialNumbers::new()
143 .lower_limit(5.0)
144 .upper_limit(10.0)
145 .increment(1.0)
146 .ints(),
147 )
148 .hidden_layer(
149 SequentialNumbers::new()
150 .lower_limit(12.0)
151 .upper_limit(24.0)
152 .increment(4.0)
153 .ints(),
154 ReLU::build(),
155 )
156 .export(
157 CSV::default()
158 .directory(EXP_NAME)
159 .file_name(&format!("{}_search", EXP_NAME))
160 .build(),
161 )
162 .build();
163
164 let mut network_search = match network_search {
165 Ok(ns) => ns,
166 Err(e) => {
167 error!("Failed to build network_search: {}", e);
168 std::process::exit(1);
169 }
170 };
171
172 let search_res = network_search
173 .search(&training_inputs, &training_targets, &validation_inputs, &validation_targets)
174 .unwrap();
175
176 info!("Num Results: {}", search_res.len());
177}
178
179fn triplets_network(inp_size: usize, targ_size: usize) -> Network {
180 let network = NetworkBuilder::new(inp_size, targ_size)
181 .layer(Dense::default().size(24).activation(ReLU::build()).build())
182 .layer(Dense::default().size(targ_size).activation(Softmax::build()).build())
183 .loss_function(CrossEntropy::default().epsilon(1e-8).build())
184 .optimizer(Adam::default().beta1(0.99).beta2(0.999).learning_rate(0.0035).build())
185 .batch_size(8)
186 .batch_group_size(2)
187 .parallelize(2)
188 .epochs(1000)
189 .seed(55)
190 .build();
191
192 match network {
193 Ok(net) => net,
194 Err(e) => {
195 eprintln!("Failed to build network: {}", e);
196 std::process::exit(1);
197 }
198 }
199}
200
201fn initialize_logger(name: &str) {
205 if !std::path::Path::new(name).exists() {
207 let _res = fs::create_dir_all(name).map_err(|e| {
208 eprintln!("Failed to create log directory: {}", e);
209 });
210 }
211
212 let log_file = match File::create(format!("./{}/{}.log", name, name)) {
214 Ok(file) => file,
215 Err(e) => {
216 eprintln!("Failed to create log file: {}", e);
217 return;
218 }
219 };
220
221 let log_level = env::var("LOG").unwrap_or_else(|_| "info".to_string()); Builder::new()
226 .target(Target::Pipe(Box::new(log_file)))
227 .parse_filters(&log_level) .init();
229}