1use rmp_serde::{decode, encode};
2use serde::{Deserialize, Serialize};
3use std::fs::{self, File};
4use std::io::{Read, Write};
5
6use crate::error::NetworkError;
7use crate::layer::Layer;
8use crate::{EarlyStopper, LossFunction, Normalization, OptimizerConfig, Regularization};
9
10pub trait NetworkIO {
11 fn save(&self, network: NetworkSerialized) -> Result<(), NetworkError>;
12 fn load(&self) -> Result<NetworkSerialized, NetworkError>;
13}
14
/// JSON-backed [`NetworkIO`] implementation produced by [`JSON::build`].
#[derive(Clone)]
struct JSONNetworkIO {
    /// Directory handed to the file helpers as the parent of the target file.
    directory: String,
    /// File name (without extension) handed to the file helpers.
    filename: String,
}
20
21impl NetworkIO for JSONNetworkIO {
22 fn save(&self, network_s: NetworkSerialized) -> Result<(), NetworkError> {
23 let serialized_data = serde_json::to_vec(&network_s);
24 match serialized_data {
25 Ok(data) => save(self.filename.clone(), self.directory.clone(), data)?,
26 Err(_) => return Err(NetworkError::IoError("Failed to serialize to JSON".to_string())),
27 };
28 Ok(())
29 }
30
31 fn load(&self) -> Result<NetworkSerialized, NetworkError> {
32 let serialized_data = load(self.filename.clone(), self.directory.clone())?;
33 let network_s = serde_json::from_slice(&serialized_data);
34 if network_s.is_err() {
35 return Err(NetworkError::IoError("Failed to deserialize from JSON".to_string()));
36 }
37 Ok(network_s.unwrap())
38 }
39}
40
/// Builder for a JSON-backed [`NetworkIO`] implementation.
///
/// Start from [`JSON::default`], optionally override the target with [`JSON::file_name`]
/// and [`JSON::directory`], then call [`JSON::build`].
#[derive(Debug, Clone)]
pub struct JSON {
    /// Name of the target file without extension; the default is `"network"`.
    pub file_name: String,
    /// Directory holding the target file; the default is `"."`.
    pub directory: String,
}
51
52impl JSON {
53 fn new() -> Self {
58 JSON {
59 file_name: "network".to_string(),
60 directory: ".".to_string(),
61 }
62 }
63
64 pub fn file_name(mut self, filename: &str) -> Self {
66 self.file_name = filename.to_string();
67 self
68 }
69
70 pub fn directory(mut self, directory: &str) -> Self {
72 self.directory = directory.to_string();
73 self
74 }
75
76 fn validate(&self) -> Result<(), NetworkError> {
81 if self.file_name.is_empty() {
82 return Err(NetworkError::ConfigError("Filename cannot be empty".to_string()));
83 }
84 if self.directory.is_empty() {
85 return Err(NetworkError::ConfigError("Directory cannot be empty".to_string()));
86 }
87
88 if !std::path::Path::new(&self.directory).exists() {
89 fs::create_dir_all(&self.directory).map_err(|e| {
90 NetworkError::IoError(format!("Failed to create output directory '{}': {}", self.directory, e))
91 })?;
92 }
93
94 Ok(())
95 }
96
97 pub fn build(self) -> Result<impl NetworkIO, NetworkError> {
99 self.validate()?;
100 Ok(JSONNetworkIO {
101 filename: self.file_name,
102 directory: self.directory,
103 })
104 }
105}
106
107impl Default for JSON {
108 fn default() -> Self {
113 Self::new()
114 }
115}
116
/// MessagePack-backed [`NetworkIO`] implementation produced by [`MessagePack::build`].
#[derive(Clone)]
struct MessagePackNetworkIO {
    /// Directory handed to the file helpers as the parent of the target file.
    directory: String,
    /// File name (without extension) handed to the file helpers.
    filename: String,
}
122
123impl NetworkIO for MessagePackNetworkIO {
124 fn save(&self, network_s: NetworkSerialized) -> Result<(), NetworkError> {
125 let serialized_data = encode::to_vec(&network_s);
126 match serialized_data {
127 Ok(data) => save(self.filename.clone(), self.directory.clone(), data)?,
128 Err(_) => return Err(NetworkError::IoError("Failed to serialize to MessagePack".to_string())),
129 };
130 Ok(())
131 }
132
133 fn load(&self) -> Result<NetworkSerialized, NetworkError> {
134 let serialized_data = load(self.filename.clone(), self.directory.clone())?;
135 let network_s = decode::from_slice(&serialized_data);
136 if network_s.is_err() {
137 return Err(NetworkError::IoError("Failed to deserialize from MessagePack".to_string()));
138 }
139 Ok(network_s.unwrap())
140 }
141}
/// Builder for a MessagePack-backed [`NetworkIO`] implementation.
///
/// Start from [`MessagePack::default`], optionally override the target with
/// [`MessagePack::file_name`] and [`MessagePack::directory`], then call [`MessagePack::build`].
#[derive(Debug, Clone)]
pub struct MessagePack {
    /// Name of the target file without extension; the default is `"network"`.
    pub file_name: String,
    /// Directory holding the target file; the default is `"."`.
    pub directory: String,
}
152
153impl MessagePack {
154 fn new() -> Self {
159 MessagePack {
160 file_name: "network".to_string(),
161 directory: ".".to_string(),
162 }
163 }
164
165 pub fn file_name(mut self, filename: &str) -> Self {
167 self.file_name = filename.to_string();
168 self
169 }
170
171 pub fn directory(mut self, directory: &str) -> Self {
173 self.directory = directory.to_string();
174 self
175 }
176
177 fn validate(&self) -> Result<(), NetworkError> {
182 if self.file_name.is_empty() {
183 return Err(NetworkError::ConfigError("Filename cannot be empty".to_string()));
184 }
185 if self.directory.is_empty() {
186 return Err(NetworkError::ConfigError("Directory cannot be empty".to_string()));
187 }
188
189 if !std::path::Path::new(&self.directory).exists() {
190 fs::create_dir_all(&self.directory).map_err(|e| {
191 NetworkError::IoError(format!("Failed to create output directory '{}': {}", self.directory, e))
192 })?;
193 }
194
195 Ok(())
196 }
197
198 pub fn build(self) -> Result<impl NetworkIO, NetworkError> {
200 self.validate()?;
201 Ok(MessagePackNetworkIO {
202 filename: self.file_name,
203 directory: self.directory,
204 })
205 }
206}
207
208impl Default for MessagePack {
209 fn default() -> Self {
214 Self::new()
215 }
216}
217
218fn save(file_name: String, directory: String, serialized_data: Vec<u8>) -> Result<(), NetworkError> {
219 let file = File::create(format!("{}/{}.json", directory, file_name));
220 if file.is_err() {
221 return Err(NetworkError::IoError("Failed to create file".to_string()));
222 }
223 let res = file.unwrap().write_all(&serialized_data);
224 if res.is_err() {
225 return Err(NetworkError::IoError("Failed to write to file".to_string()));
226 }
227 Ok(())
228}
229
230fn load(file_name: String, directory: String) -> Result<Vec<u8>, NetworkError> {
231 let file = File::open(format!("{}/{}.json", directory, file_name));
232 if file.is_err() {
233 return Err(NetworkError::IoError("Failed to open file".to_string()));
234 }
235 let mut buffer = Vec::new();
236 let res = file.unwrap().read_to_end(&mut buffer);
237 if res.is_err() {
238 return Err(NetworkError::IoError("Failed to read file".to_string()));
239 }
240 Ok(buffer)
241}
242
243#[derive(Serialize, Deserialize)]
244pub struct NetworkSerialized {
245 pub(crate) input_size: usize,
246 pub(crate) output_size: usize,
247 pub(crate) layers: Vec<Box<dyn Layer>>,
248 pub(crate) loss_function: Box<dyn LossFunction>,
249 pub(crate) optimizer_config: Box<dyn OptimizerConfig>,
250 pub(crate) regularizations: Vec<Box<dyn Regularization>>,
251 pub(crate) batch_size: usize,
252 pub(crate) batch_group_size: usize,
253 pub(crate) epochs: usize,
254 pub(crate) clip_threshold: f32,
255 pub(crate) seed: u64,
256 pub(crate) early_stopper: Option<Box<dyn EarlyStopper>>,
257 pub(crate) debug: bool,
258 pub(crate) normalize_input: Option<Box<dyn Normalization>>,
259 pub(crate) normalize_output: Option<Box<dyn Normalization>>,
260 pub(crate) summary_writer: Option<Box<dyn crate::summary::SummaryWriter>>,
261 pub(crate) parallelize: usize,
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dense_layer::Dense;
    use crate::dropout::Dropout;
    use crate::mean_squared_error::MeanSquared;
    use crate::network::network_model::Network;
    use crate::network::network_model::NetworkBuilder;
    use crate::relu::ReLU;
    use crate::softmax::Softmax;
    use crate::sgd::SGD;

    // Every test works in its own directory. The test harness runs tests in parallel and
    // `build()` creates the configured directory, so a shared directory could be deleted by one
    // test while another still expects it to exist.

    /// Saves a small network as JSON and loads it back.
    #[test]
    fn test_json_io() {
        let json_io = JSON::new()
            .file_name("test_network")
            .directory("./test_dir_123")
            .build()
            .unwrap();

        let network = NetworkBuilder::new(4, 3)
            .layer(Dense::default().size(5).activation(ReLU::build()).build())
            .layer(Dense::default().size(3).activation(Softmax::build()).build())
            .loss_function(MeanSquared.build())
            .optimizer(SGD::default().learning_rate(0.01).build())
            .regularization(Dropout::default().dropout_rate(0.5).seed(42).build())
            .seed(42)
            .epochs(10)
            .batch_size(2)
            .build()
            .unwrap();

        let _res = network.save(json_io);
        let loaded_network = Network::load(
            JSON::new()
                .file_name("test_network")
                .directory("./test_dir_123")
                .build()
                .unwrap(),
        )
        .unwrap();
        assert_eq!(loaded_network.input_size, 4);
        assert_eq!(loaded_network.output_size, 3);

        let cleanup = fs::remove_dir_all("./test_dir_123");
        assert!(cleanup.is_ok());
    }

    /// Saves a small network as MessagePack and loads it back.
    #[test]
    fn test_message_pack_io() {
        let msgpack_io = MessagePack::new()
            .file_name("test_network")
            .directory("./test_dir_1234")
            .build()
            .unwrap();

        let network = NetworkBuilder::new(4, 3)
            .layer(Dense::default().size(5).activation(ReLU::build()).build())
            .layer(Dense::default().size(3).activation(Softmax::build()).build())
            .loss_function(MeanSquared.build())
            .optimizer(SGD::default().learning_rate(0.01).build())
            .regularization(Dropout::default().dropout_rate(0.5).seed(42).build())
            .seed(42)
            .epochs(10)
            .batch_size(2)
            .build()
            .unwrap();

        let _res = network.save(msgpack_io);
        let loaded_network = Network::load(
            MessagePack::new()
                .file_name("test_network")
                .directory("./test_dir_1234")
                .build()
                .unwrap(),
        )
        .unwrap();
        assert_eq!(loaded_network.input_size, 4);
        assert_eq!(loaded_network.output_size, 3);

        let cleanup = fs::remove_dir_all("./test_dir_1234");
        assert!(cleanup.is_ok());
    }

    /// Loading through the JSON backend fails with an I/O error when the file is absent.
    #[test]
    fn test_save_load_invalid_file() {
        let json_io = JSON::new()
            .file_name("invalid_network")
            .directory("./invalid_dir_json")
            .build()
            .unwrap();

        match json_io.load() {
            Err(NetworkError::IoError(msg)) => assert_eq!(msg, "Failed to open file"),
            _ => panic!("Expected IoError"),
        }

        let cleanup = fs::remove_dir_all("./invalid_dir_json");
        assert!(cleanup.is_ok());
    }

    /// Loading through the MessagePack backend fails with an I/O error when the file is absent.
    #[test]
    fn test_save_load_invalid_directory() {
        let msgpack_io = MessagePack::new()
            .file_name("invalid_network")
            .directory("./invalid_dir_msgpack")
            .build()
            .unwrap();

        match msgpack_io.load() {
            Err(NetworkError::IoError(msg)) => assert_eq!(msg, "Failed to open file"),
            _ => panic!("Expected IoError"),
        }

        let cleanup = fs::remove_dir_all("./invalid_dir_msgpack");
        assert!(cleanup.is_ok());
    }
}