1use csv::ReaderBuilder;
2use env_logger::{Builder, Target};
3use log::{error, info};
4use runn::{
5 adam::Adam,
6 cross_entropy::CrossEntropy,
7 csv::CSV,
8 dense_layer::Dense,
9 helper,
10 matrix::{DMat, DenseMatrix},
11 network::network_model::{Network, NetworkBuilder},
12 network_io::JSON,
13 network_search::NetworkSearchBuilder,
14 numbers::{Numbers, SequentialNumbers},
15 relu::ReLU,
16 softmax::Softmax,
17};
18use std::error::Error;
19use std::fs::File;
20use std::{env, fs};
21
22const EXP_NAME: &str = "iris";
23
24fn main() {
40 initialize_logger(EXP_NAME);
41
42 let args: Vec<String> = env::args().collect();
43 if args.contains(&"-search".to_string()) {
44 search();
45 } else {
46 train_and_validate();
47 }
48}
49
50fn train_and_validate() {
51 let network_file = format!("{}_network", EXP_NAME);
52
53 let (training_inputs, training_targets) = iris_inputs_outputs("train", 7, 4).unwrap();
54 let mut network = iris_network(training_inputs.cols(), training_targets.cols());
55
56 let training_result = network.train(&training_inputs, &training_targets);
57 match training_result {
58 Ok(_) => {
59 info!("Training successfully completed");
60 network
61 .save(
62 JSON::default()
63 .directory(EXP_NAME)
64 .file_name(&network_file)
65 .build()
66 .unwrap(),
67 )
68 .unwrap();
69 let net_results = network.predict(&training_inputs, &training_targets).unwrap();
70 info!(
71 "{}",
72 helper::pretty_compare_matrices(
73 &training_inputs,
74 &training_targets,
75 &net_results.predictions,
76 helper::CompareMode::Classification
77 )
78 );
79 info!("Training: {}", net_results.display_metrics());
80 }
81 Err(e) => {
82 eprintln!("Training failed: {}", e);
83 }
84 }
85
86 network = Network::load(
87 JSON::default()
88 .directory(EXP_NAME)
89 .file_name(&network_file)
90 .build()
91 .unwrap(),
92 )
93 .unwrap();
94 let (validation_inputs, validation_targets) = iris_inputs_outputs("test", 7, 4).unwrap();
95 let net_results = network.predict(&validation_inputs, &validation_targets).unwrap();
96 log::info!(
97 "{}",
98 helper::pretty_compare_matrices(
99 &validation_inputs,
100 &validation_targets,
101 &net_results.predictions,
102 helper::CompareMode::Classification
103 )
104 );
105 info!("Validation: {}", net_results.display_metrics());
106}
107
108fn iris_network(inp_size: usize, targ_size: usize) -> Network {
109 let network = NetworkBuilder::new(inp_size, targ_size)
110 .layer(Dense::default().size(12).activation(ReLU::build()).build())
111 .layer(Dense::default().size(12).activation(ReLU::build()).build())
112 .layer(Dense::default().size(targ_size).activation(Softmax::build()).build())
113 .loss_function(CrossEntropy::default().epsilon(1e-8).build())
114 .optimizer(Adam::default().beta1(0.99).beta2(0.999).learning_rate(0.0035).build())
115 .batch_size(9)
116 .batch_group_size(2)
117 .parallelize(2)
118 .epochs(3000)
119 .seed(55)
120 .build();
121
122 match network {
123 Ok(net) => net,
124 Err(e) => {
125 eprintln!("Failed to build network: {}", e);
126 std::process::exit(1);
127 }
128 }
129}
130
131fn search() {
132 let (training_inputs, training_targets) = iris_inputs_outputs("train", 7, 4).unwrap();
133 let (validation_inputs, validation_targets) = iris_inputs_outputs("test", 7, 4).unwrap();
134
135 let network = iris_network(training_inputs.cols(), training_targets.cols());
136
137 let network_search = NetworkSearchBuilder::new()
138 .network(network)
139 .parallelize(4)
140 .learning_rates(
141 SequentialNumbers::new()
142 .lower_limit(0.0025)
143 .upper_limit(0.0035)
144 .increment(0.0005)
145 .floats(),
146 )
147 .batch_sizes(
148 SequentialNumbers::new()
149 .lower_limit(7.0)
150 .upper_limit(10.0)
151 .increment(1.0)
152 .ints(),
153 )
154 .hidden_layer(
155 SequentialNumbers::new()
156 .lower_limit(12.0)
157 .upper_limit(20.0)
158 .increment(4.0)
159 .ints(),
160 ReLU::build(),
161 )
162 .hidden_layer(
163 SequentialNumbers::new()
164 .lower_limit(12.0)
165 .upper_limit(20.0)
166 .increment(4.0)
167 .ints(),
168 ReLU::build(),
169 )
170 .export(
171 CSV::default()
172 .directory(EXP_NAME)
173 .file_name(&format!("{}_search", EXP_NAME))
174 .build(),
175 )
176 .build();
177
178 let mut network_search = match network_search {
179 Ok(ns) => ns,
180 Err(e) => {
181 error!("Failed to build network_search: {}", e);
182 std::process::exit(1);
183 }
184 };
185
186 let search_res = network_search
187 .search(&training_inputs, &training_targets, &validation_inputs, &validation_targets)
188 .unwrap();
189
190 info!("Num Results: {}", search_res.len());
191}
192
193pub fn iris_inputs_outputs(
194 name: &str, fields_count: usize, input_count: usize,
195) -> Result<(DMat, DMat), Box<dyn Error>> {
196 let target_count = fields_count - input_count;
197
198 let file_path = format!("./examples/iris/{}.csv", name);
199 let file = File::open(&file_path)?;
200 let mut reader = ReaderBuilder::new().has_headers(true).from_reader(file);
201
202 let mut inputs_data = Vec::new();
203 let mut labels_data = Vec::new();
204
205 for result in reader.records() {
206 let record = result?;
207 for (i, value) in record.iter().enumerate() {
208 let parsed_val: f32 = value.parse()?;
209 if i >= fields_count - target_count {
210 labels_data.push(parsed_val);
211 } else {
212 inputs_data.push(parsed_val);
213 }
214 }
215 }
216
217 let data_length = inputs_data.len() / input_count;
218
219 let inputs = DenseMatrix::new(data_length, input_count)
220 .data(&inputs_data)
221 .build()
222 .unwrap();
223 let labels = DenseMatrix::new(data_length, target_count)
224 .data(&labels_data)
225 .build()
226 .unwrap();
227
228 Ok((inputs, labels))
229}
230
231fn initialize_logger(name: &str) {
235 if !std::path::Path::new(name).exists() {
237 let _res = fs::create_dir_all(name).map_err(|e| {
238 eprintln!("Failed to create log directory: {}", e);
239 });
240 }
241
242 let log_file = match File::create(format!("./{}/{}.log", name, name)) {
244 Ok(file) => file,
245 Err(e) => {
246 eprintln!("Failed to create log file: {}", e);
247 return;
248 }
249 };
250
251 let log_level = env::var("LOG").unwrap_or_else(|_| "info".to_string()); Builder::new()
256 .target(Target::Pipe(Box::new(log_file)))
257 .parse_filters(&log_level) .init();
259}