Skip to main content

runn/network/
network_io.rs

1use rmp_serde::{decode, encode};
2use serde::{Deserialize, Serialize};
3use std::fs::{self, File};
4use std::io::{Read, Write};
5
6use crate::error::NetworkError;
7use crate::layer::Layer;
8use crate::{EarlyStopper, LossFunction, Normalization, OptimizerConfig, Regularization};
9
10pub trait NetworkIO {
11    fn save(&self, network: NetworkSerialized) -> Result<(), NetworkError>;
12    fn load(&self) -> Result<NetworkSerialized, NetworkError>;
13}
14
15#[derive(Clone)]
16struct JSONNetworkIO {
17    filename: String,
18    directory: String,
19}
20
21impl NetworkIO for JSONNetworkIO {
22    fn save(&self, network_s: NetworkSerialized) -> Result<(), NetworkError> {
23        let serialized_data = serde_json::to_vec(&network_s);
24        match serialized_data {
25            Ok(data) => save(self.filename.clone(), self.directory.clone(), data)?,
26            Err(_) => return Err(NetworkError::IoError("Failed to serialize to JSON".to_string())),
27        };
28        Ok(())
29    }
30
31    fn load(&self) -> Result<NetworkSerialized, NetworkError> {
32        let serialized_data = load(self.filename.clone(), self.directory.clone())?;
33        let network_s = serde_json::from_slice(&serialized_data);
34        if network_s.is_err() {
35            return Err(NetworkError::IoError("Failed to deserialize from JSON".to_string()));
36        }
37        Ok(network_s.unwrap())
38    }
39}
40
/// Builder that configures and creates a JSON-backed [`NetworkIO`] exporter.
///
/// Use the fluent setters to pick the file name and output directory; the
/// network is stored at `<directory>/<file_name>.json`. Call [`JSON::build`]
/// to validate the configuration and obtain the concrete `NetworkIO`
/// implementation.
pub struct JSON {
    /// Base name of the output file, without the `.json` extension.
    pub file_name: String,
    /// Directory holding the output file; `build` creates it when missing.
    pub directory: String,
}
51
52impl JSON {
53    // Creates a new JSON builder.
54    // Default values:
55    // - File name: `"network"`
56    // - Directory: `"."` (current directory)
57    fn new() -> Self {
58        JSON {
59            file_name: "network".to_string(),
60            directory: ".".to_string(),
61        }
62    }
63
64    /// Sets the base name for the output JSON file.
65    pub fn file_name(mut self, filename: &str) -> Self {
66        self.file_name = filename.to_string();
67        self
68    }
69
70    /// Sets the output directory for the JSON file.
71    pub fn directory(mut self, directory: &str) -> Self {
72        self.directory = directory.to_string();
73        self
74    }
75
76    // Validates the configuration and checks if the output directory is writable.
77    //
78    // # Errors
79    // Returns a `NetworkError` if any field is invalid or the file system check fails.
80    fn validate(&self) -> Result<(), NetworkError> {
81        if self.file_name.is_empty() {
82            return Err(NetworkError::ConfigError("Filename cannot be empty".to_string()));
83        }
84        if self.directory.is_empty() {
85            return Err(NetworkError::ConfigError("Directory cannot be empty".to_string()));
86        }
87
88        if !std::path::Path::new(&self.directory).exists() {
89            fs::create_dir_all(&self.directory).map_err(|e| {
90                NetworkError::IoError(format!("Failed to create output directory '{}': {}", self.directory, e))
91            })?;
92        }
93
94        Ok(())
95    }
96
97    /// Finalizes the builder and constructs a `JSONNetworkIO` if the configuration is valid.
98    pub fn build(self) -> Result<impl NetworkIO, NetworkError> {
99        self.validate()?;
100        Ok(JSONNetworkIO {
101            filename: self.file_name,
102            directory: self.directory,
103        })
104    }
105}
106
107impl Default for JSON {
108    /// Creates a new JSON builder with default values.
109    /// Default values:
110    /// - File name: `"network"`
111    /// - Directory: `"."` (current directory)
112    fn default() -> Self {
113        Self::new()
114    }
115}
116
117#[derive(Clone)]
118struct MessagePackNetworkIO {
119    filename: String,
120    directory: String,
121}
122
123impl NetworkIO for MessagePackNetworkIO {
124    fn save(&self, network_s: NetworkSerialized) -> Result<(), NetworkError> {
125        let serialized_data = encode::to_vec(&network_s);
126        match serialized_data {
127            Ok(data) => save(self.filename.clone(), self.directory.clone(), data)?,
128            Err(_) => return Err(NetworkError::IoError("Failed to serialize to MessagePack".to_string())),
129        };
130        Ok(())
131    }
132
133    fn load(&self) -> Result<NetworkSerialized, NetworkError> {
134        let serialized_data = load(self.filename.clone(), self.directory.clone())?;
135        let network_s = decode::from_slice(&serialized_data);
136        if network_s.is_err() {
137            return Err(NetworkError::IoError("Failed to deserialize from MessagePack".to_string()));
138        }
139        Ok(network_s.unwrap())
140    }
141}
142/// A builder for configuring and creating a MessagePack network I/O exporter.
143///
144/// This struct provides a fluent interface to customize the file name
145/// and output directory for a MessagePack-encoded network representation.
146/// Use the `build` method to validate the configuration and return
147/// a concrete implementation of `NetworkIO`.
148pub struct MessagePack {
149    pub file_name: String,
150    pub directory: String,
151}
152
153impl MessagePack {
154    // Creates a new MessagePack builder.
155    // Default values:
156    // - File name: `"network"`
157    // - Directory: `"."` (current directory)
158    fn new() -> Self {
159        MessagePack {
160            file_name: "network".to_string(),
161            directory: ".".to_string(),
162        }
163    }
164
165    /// Sets the base name for the output MessagePack file.
166    pub fn file_name(mut self, filename: &str) -> Self {
167        self.file_name = filename.to_string();
168        self
169    }
170
171    /// Sets the output directory for the MessagePack file.
172    pub fn directory(mut self, directory: &str) -> Self {
173        self.directory = directory.to_string();
174        self
175    }
176
177    // Validates the configuration and checks if the output directory is writable.
178    //
179    // # Errors
180    // Returns a `NetworkError` if any field is invalid or the file system check fails.
181    fn validate(&self) -> Result<(), NetworkError> {
182        if self.file_name.is_empty() {
183            return Err(NetworkError::ConfigError("Filename cannot be empty".to_string()));
184        }
185        if self.directory.is_empty() {
186            return Err(NetworkError::ConfigError("Directory cannot be empty".to_string()));
187        }
188
189        if !std::path::Path::new(&self.directory).exists() {
190            fs::create_dir_all(&self.directory).map_err(|e| {
191                NetworkError::IoError(format!("Failed to create output directory '{}': {}", self.directory, e))
192            })?;
193        }
194
195        Ok(())
196    }
197
198    /// Finalizes the builder and constructs a `MessagePackNetworkIO` if the configuration is valid.
199    pub fn build(self) -> Result<impl NetworkIO, NetworkError> {
200        self.validate()?;
201        Ok(MessagePackNetworkIO {
202            filename: self.file_name,
203            directory: self.directory,
204        })
205    }
206}
207
208impl Default for MessagePack {
209    /// Creates a new MessagePack builder.
210    /// Default values:
211    /// - File name: `"network"`
212    /// - Directory: `"."` (current directory)
213    fn default() -> Self {
214        Self::new()
215    }
216}
217
218fn save(file_name: String, directory: String, serialized_data: Vec<u8>) -> Result<(), NetworkError> {
219    let file = File::create(format!("{}/{}.json", directory, file_name));
220    if file.is_err() {
221        return Err(NetworkError::IoError("Failed to create file".to_string()));
222    }
223    let res = file.unwrap().write_all(&serialized_data);
224    if res.is_err() {
225        return Err(NetworkError::IoError("Failed to write to file".to_string()));
226    }
227    Ok(())
228}
229
230fn load(file_name: String, directory: String) -> Result<Vec<u8>, NetworkError> {
231    let file = File::open(format!("{}/{}.json", directory, file_name));
232    if file.is_err() {
233        return Err(NetworkError::IoError("Failed to open file".to_string()));
234    }
235    let mut buffer = Vec::new();
236    let res = file.unwrap().read_to_end(&mut buffer);
237    if res.is_err() {
238        return Err(NetworkError::IoError("Failed to read file".to_string()));
239    }
240    Ok(buffer)
241}
242
243#[derive(Serialize, Deserialize)]
244pub struct NetworkSerialized {
245    pub(crate) input_size: usize,
246    pub(crate) output_size: usize,
247    pub(crate) layers: Vec<Box<dyn Layer>>,
248    pub(crate) loss_function: Box<dyn LossFunction>,
249    pub(crate) optimizer_config: Box<dyn OptimizerConfig>,
250    pub(crate) regularizations: Vec<Box<dyn Regularization>>,
251    pub(crate) batch_size: usize,
252    pub(crate) batch_group_size: usize,
253    pub(crate) epochs: usize,
254    pub(crate) clip_threshold: f32,
255    pub(crate) seed: u64,
256    pub(crate) early_stopper: Option<Box<dyn EarlyStopper>>,
257    pub(crate) debug: bool,
258    pub(crate) normalize_input: Option<Box<dyn Normalization>>,
259    pub(crate) normalize_output: Option<Box<dyn Normalization>>,
260    pub(crate) summary_writer: Option<Box<dyn crate::summary::SummaryWriter>>,
261    pub(crate) parallelize: usize,
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dense_layer::Dense;
    use crate::dropout::Dropout;
    use crate::mean_squared_error::MeanSquared;
    use crate::network::network_model::Network;
    use crate::network::network_model::NetworkBuilder;
    use crate::relu::ReLU;
    use crate::sgd::SGD;
    use crate::softmax::Softmax;

    // The test harness runs tests in parallel threads, so every test below uses
    // its own directory: a shared one lets a test delete a directory another test
    // is still using (or has already removed), which makes cleanup asserts flaky.

    #[test]
    fn test_json_io() {
        let json_io = JSON::new()
            .file_name("test_network")
            .directory("./test_dir_123")
            .build()
            .unwrap();

        let network = NetworkBuilder::new(4, 3)
            .layer(Dense::default().size(5).activation(ReLU::build()).build())
            .layer(Dense::default().size(3).activation(Softmax::build()).build())
            .loss_function(MeanSquared.build())
            .optimizer(SGD::default().learning_rate(0.01).build())
            .regularization(Dropout::default().dropout_rate(0.5).seed(42).build())
            .seed(42)
            .epochs(10)
            .batch_size(2)
            .build()
            .unwrap();

        // Fail here, rather than with a confusing load error, if saving breaks.
        network
            .save(json_io)
            .expect("saving the network as JSON should succeed");
        let loaded_network = Network::load(
            JSON::new()
                .file_name("test_network")
                .directory("./test_dir_123")
                .build()
                .unwrap(),
        )
        .unwrap();
        assert_eq!(loaded_network.input_size, 4);
        assert_eq!(loaded_network.output_size, 3);
        // Remove the saved file and the directory created by `build`.
        fs::remove_dir_all("./test_dir_123").expect("failed to remove ./test_dir_123");
    }

    #[test]
    fn test_message_pack_io() {
        let msgpack_io = MessagePack::new()
            .file_name("test_network")
            .directory("./test_dir_1234")
            .build()
            .unwrap();

        let network = NetworkBuilder::new(4, 3)
            .layer(Dense::default().size(5).activation(ReLU::build()).build())
            .layer(Dense::default().size(3).activation(Softmax::build()).build())
            .loss_function(MeanSquared.build())
            .optimizer(SGD::default().learning_rate(0.01).build())
            .regularization(Dropout::default().dropout_rate(0.5).seed(42).build())
            .seed(42)
            .epochs(10)
            .batch_size(2)
            .build()
            .unwrap();

        // Fail here, rather than with a confusing load error, if saving breaks.
        network
            .save(msgpack_io)
            .expect("saving the network as MessagePack should succeed");
        let loaded_network = Network::load(
            MessagePack::new()
                .file_name("test_network")
                .directory("./test_dir_1234")
                .build()
                .unwrap(),
        )
        .unwrap();
        assert_eq!(loaded_network.input_size, 4);
        assert_eq!(loaded_network.output_size, 3);
        // Remove the saved file and the directory created by `build`.
        fs::remove_dir_all("./test_dir_1234").expect("failed to remove ./test_dir_1234");
    }

    #[test]
    fn test_save_load_invalid_file() {
        let json_io = JSON::new()
            .file_name("invalid_network")
            .directory("./invalid_dir_json")
            .build()
            .unwrap();

        let result = json_io.load();
        assert!(result.is_err());
        if let Err(NetworkError::IoError(msg)) = result {
            assert_eq!(msg, "Failed to open file");
        } else {
            panic!("Expected IoError");
        }
        // Best-effort cleanup of the directory created by `build`.
        let _ = fs::remove_dir_all("./invalid_dir_json");
    }

    #[test]
    fn test_save_load_invalid_directory() {
        let msgpack_io = MessagePack::new()
            .file_name("invalid_network")
            .directory("./invalid_dir_msgpack")
            .build()
            .unwrap();

        let result = msgpack_io.load();
        assert!(result.is_err());
        if let Err(NetworkError::IoError(msg)) = result {
            assert_eq!(msg, "Failed to open file");
        } else {
            panic!("Expected IoError");
        }
        // Remove the directory created by `build`; it is private to this test.
        fs::remove_dir_all("./invalid_dir_msgpack")
            .expect("failed to remove ./invalid_dir_msgpack");
    }
}