surrealml_core/storage/
surml_file.rs

1//! Defines the saving and loading of the entire `surml` file.
2use std::fs::File;
3use std::io::{Read, Write};
4
5use crate::{
6    safe_eject_internal,
7    safe_eject,
8    storage::header::Header,
9    errors::error::{
10        SurrealError,
11        SurrealErrorStatus
12    }
13};
14
15
16/// The `SurMlFile` struct represents the entire `surml` file.
17/// 
18/// # Fields
19/// * `header` - The header of the `surml` file containing data such as key bindings for inputs and normalisers.
20/// * `model` - The PyTorch model in C.
21pub struct SurMlFile {
22    pub header: Header,
23    pub model: Vec<u8>,
24}
25
26
27impl SurMlFile {
28
29    /// Creates a new `SurMlFile` struct with an empty header.
30    /// 
31    /// # Arguments
32    /// * `model` - The PyTorch model in C.
33    /// 
34    /// # Returns
35    /// A new `SurMlFile` struct with no columns or normalisers.
36    pub fn fresh(model: Vec<u8>) -> Self {
37        Self {
38            header: Header::fresh(),
39            model
40        }
41    }
42
43    /// Creates a new `SurMlFile` struct.
44    /// 
45    /// # Arguments
46    /// * `header` - The header of the `surml` file containing data such as key bindings for inputs and normalisers.
47    /// * `model` - The PyTorch model in C.
48    /// 
49    /// # Returns
50    /// A new `SurMlFile` struct.
51    pub fn new(header: Header, model: Vec<u8>) -> Self {
52        Self {
53            header,
54            model,
55        }
56    }
57
58    /// Creates a new `SurMlFile` struct from a vector of bytes.
59    /// 
60    /// # Arguments
61    /// * `bytes` - A vector of bytes representing the header and the model.
62    /// 
63    /// # Returns
64    /// A new `SurMlFile` struct.
65    pub fn from_bytes(bytes: Vec<u8>) -> Result<Self, SurrealError> {
66        // check to see if there is enough bytes to read
67        if bytes.len() < 4 {
68            return Err(
69                SurrealError::new(
70                    "Not enough bytes to read".to_string(),
71                    SurrealErrorStatus::BadRequest
72                )
73            );
74        }
75        let mut header_bytes = Vec::new();
76        let mut model_bytes = Vec::new();
77
78        // extract the first 4 bytes as an integer to get the length of the header
79        let mut buffer = [0u8; 4];
80        buffer.copy_from_slice(&bytes[0..4]);
81        let integer_value = u32::from_be_bytes(buffer);
82
83        // check to see if there is enough bytes to read
84        if bytes.len() < (4 + integer_value as usize) {
85            return Err(
86                SurrealError::new(
87                    "Not enough bytes to read for header, maybe the file format is not correct".to_string(),
88                    SurrealErrorStatus::BadRequest
89                )
90            );
91        }
92
93        // Read the next integer_value bytes for the header
94        header_bytes.extend_from_slice(&bytes[4..(4 + integer_value as usize)]);
95
96        // Read the remaining bytes for the model
97        model_bytes.extend_from_slice(&bytes[(4 + integer_value as usize)..]);
98
99        // construct the header and C model from the bytes
100        let header = Header::from_bytes(header_bytes)?;
101        let model = model_bytes;
102        Ok(Self {
103            header,
104            model,
105        })
106    }
107
108    /// Creates a new `SurMlFile` struct from a file.
109    /// 
110    /// # Arguments
111    /// * `file_path` - The path to the `surml` file.
112    /// 
113    /// # Returns
114    /// A new `SurMlFile` struct.
115    pub fn from_file(file_path: &str) -> Result<Self, SurrealError> {
116        let mut file = safe_eject!(File::open(file_path), SurrealErrorStatus::NotFound);
117
118        // extract the first 4 bytes as an integer to get the length of the header
119        let mut buffer = [0u8; 4];
120        safe_eject!(file.read_exact(&mut buffer), SurrealErrorStatus::BadRequest);
121        let integer_value = u32::from_be_bytes(buffer);
122
123        // Read the next integer_value bytes for the header
124        let mut header_buffer = vec![0u8; integer_value as usize];
125        safe_eject!(file.read_exact(&mut header_buffer), SurrealErrorStatus::BadRequest);
126
127        // Create a Vec<u8> to store the data
128        let mut model_buffer = Vec::new();
129
130        // Read the rest of the file into the buffer
131        safe_eject!(file.take(usize::MAX as u64).read_to_end(&mut model_buffer), SurrealErrorStatus::BadRequest);
132
133        // construct the header and C model from the bytes
134        let header = Header::from_bytes(header_buffer)?;
135        Ok(Self {
136            header,
137            model: model_buffer,
138        })
139    }
140
141    /// Converts the header and the model to a vector of bytes.
142    /// 
143    /// # Returns
144    /// A vector of bytes representing the header and the model.
145    pub fn to_bytes(&self) -> Vec<u8> {
146        // compile the header into bytes.
147        let (num, header_bytes) = self.header.to_bytes();
148        let num_bytes = i32::to_be_bytes(num).to_vec();
149        
150        // combine the bytes into a single vector
151        let mut combined_vec: Vec<u8> = Vec::new();
152        combined_vec.extend(num_bytes);
153        combined_vec.extend(header_bytes);
154        combined_vec.extend(self.model.clone());
155        return combined_vec
156    }
157
158    /// Writes the header and the model to a `surml` file.
159    /// 
160    /// # Arguments
161    /// * `file_path` - The path to the `surml` file.
162    /// 
163    /// # Returns
164    /// An `io::Result` indicating whether the write was successful.
165    pub fn write(&self, file_path: &str) -> Result<(), SurrealError> {
166        let combined_vec = self.to_bytes();
167
168        // write the bytes to a file
169        let mut file = safe_eject_internal!(File::create(file_path));
170        safe_eject_internal!(file.write(&combined_vec));
171        Ok(())
172    }
173}
174
175
#[cfg(test)]
mod tests {

    use super::*;

    /// Writes a linear model with a populated header and checks the model bytes survive a
    /// round trip through `write` and `from_file`.
    #[test]
    fn test_write() {
        let mut header = Header::fresh();
        header.add_column(String::from("squarefoot"));
        header.add_column(String::from("num_floors"));
        header.add_output(String::from("house_price"), None);

        let mut file = File::open("./stash/linear_test.onnx").unwrap();

        let mut model_bytes = Vec::new();
        file.read_to_end(&mut model_bytes).unwrap();

        let surml_file = SurMlFile::new(header, model_bytes);
        surml_file.write("./stash/test.surml").unwrap();

        let loaded = SurMlFile::from_file("./stash/test.surml").unwrap();
        // `assert!` rather than `assert_eq!` so a failure does not dump the whole model
        assert!(
            loaded.model == surml_file.model,
            "model bytes changed during the write/read round trip"
        );
    }

    /// Same round trip as `test_write`, using the forest model and an empty header.
    #[test]
    fn test_write_forrest() {
        let header = Header::fresh();

        let mut file = File::open("./stash/forrest_test.onnx").unwrap();

        let mut model_bytes = Vec::new();
        file.read_to_end(&mut model_bytes).unwrap();

        let surml_file = SurMlFile::new(header, model_bytes);
        surml_file.write("./stash/forrest.surml").unwrap();

        let loaded = SurMlFile::from_file("./stash/forrest.surml").unwrap();
        assert!(
            loaded.model == surml_file.model,
            "model bytes changed during the write/read round trip"
        );
    }

    /// An empty buffer cannot even hold the 4-byte length prefix.
    #[test]
    fn test_empty_buffer() {
        let bytes = vec![0u8; 0];
        match SurMlFile::from_bytes(bytes) {
            Ok(_) => panic!("an empty buffer must be rejected"),
            Err(error) => {
                assert_eq!(error.status, SurrealErrorStatus::BadRequest);
                assert_eq!(error.to_string(), "Not enough bytes to read");
            }
        }
    }

    /// A length prefix that promises more header bytes than the buffer holds is rejected.
    #[test]
    fn test_truncated_header() {
        // the prefix declares a 10-byte header but only 2 bytes follow
        let bytes = vec![0u8, 0, 0, 10, 1, 2];
        match SurMlFile::from_bytes(bytes) {
            Ok(_) => panic!("a truncated header must be rejected"),
            Err(error) => {
                assert_eq!(error.status, SurrealErrorStatus::BadRequest);
                assert_eq!(
                    error.to_string(),
                    "Not enough bytes to read for header, maybe the file format is not correct"
                );
            }
        }
    }
}