surrealml_core/storage/
surml_file.rs1use std::fs::File;
3use std::io::{Read, Write};
4
5use crate::{
6 safe_eject_internal,
7 safe_eject,
8 storage::header::Header,
9 errors::error::{
10 SurrealError,
11 SurrealErrorStatus
12 }
13};
14
15
16pub struct SurMlFile {
22 pub header: Header,
23 pub model: Vec<u8>,
24}
25
26
27impl SurMlFile {
28
29 pub fn fresh(model: Vec<u8>) -> Self {
37 Self {
38 header: Header::fresh(),
39 model
40 }
41 }
42
43 pub fn new(header: Header, model: Vec<u8>) -> Self {
52 Self {
53 header,
54 model,
55 }
56 }
57
58 pub fn from_bytes(bytes: Vec<u8>) -> Result<Self, SurrealError> {
66 if bytes.len() < 4 {
68 return Err(
69 SurrealError::new(
70 "Not enough bytes to read".to_string(),
71 SurrealErrorStatus::BadRequest
72 )
73 );
74 }
75 let mut header_bytes = Vec::new();
76 let mut model_bytes = Vec::new();
77
78 let mut buffer = [0u8; 4];
80 buffer.copy_from_slice(&bytes[0..4]);
81 let integer_value = u32::from_be_bytes(buffer);
82
83 if bytes.len() < (4 + integer_value as usize) {
85 return Err(
86 SurrealError::new(
87 "Not enough bytes to read for header, maybe the file format is not correct".to_string(),
88 SurrealErrorStatus::BadRequest
89 )
90 );
91 }
92
93 header_bytes.extend_from_slice(&bytes[4..(4 + integer_value as usize)]);
95
96 model_bytes.extend_from_slice(&bytes[(4 + integer_value as usize)..]);
98
99 let header = Header::from_bytes(header_bytes)?;
101 let model = model_bytes;
102 Ok(Self {
103 header,
104 model,
105 })
106 }
107
108 pub fn from_file(file_path: &str) -> Result<Self, SurrealError> {
116 let mut file = safe_eject!(File::open(file_path), SurrealErrorStatus::NotFound);
117
118 let mut buffer = [0u8; 4];
120 safe_eject!(file.read_exact(&mut buffer), SurrealErrorStatus::BadRequest);
121 let integer_value = u32::from_be_bytes(buffer);
122
123 let mut header_buffer = vec![0u8; integer_value as usize];
125 safe_eject!(file.read_exact(&mut header_buffer), SurrealErrorStatus::BadRequest);
126
127 let mut model_buffer = Vec::new();
129
130 safe_eject!(file.take(usize::MAX as u64).read_to_end(&mut model_buffer), SurrealErrorStatus::BadRequest);
132
133 let header = Header::from_bytes(header_buffer)?;
135 Ok(Self {
136 header,
137 model: model_buffer,
138 })
139 }
140
141 pub fn to_bytes(&self) -> Vec<u8> {
146 let (num, header_bytes) = self.header.to_bytes();
148 let num_bytes = i32::to_be_bytes(num).to_vec();
149
150 let mut combined_vec: Vec<u8> = Vec::new();
152 combined_vec.extend(num_bytes);
153 combined_vec.extend(header_bytes);
154 combined_vec.extend(self.model.clone());
155 return combined_vec
156 }
157
158 pub fn write(&self, file_path: &str) -> Result<(), SurrealError> {
166 let combined_vec = self.to_bytes();
167
168 let mut file = safe_eject_internal!(File::create(file_path));
170 safe_eject_internal!(file.write(&combined_vec));
171 Ok(())
172 }
173}
174
175
#[cfg(test)]
mod tests {

    use super::*;

    /// Writes a file with a populated header around the linear fixture model and
    /// checks that the model bytes read back unchanged.
    #[test]
    fn test_write() {
        let mut header = Header::fresh();
        header.add_column(String::from("squarefoot"));
        header.add_column(String::from("num_floors"));
        header.add_output(String::from("house_price"), None);

        let mut file = File::open("./stash/linear_test.onnx").unwrap();

        let mut model_bytes = Vec::new();
        file.read_to_end(&mut model_bytes).unwrap();

        let surml_file = SurMlFile::new(header, model_bytes);
        surml_file.write("./stash/test.surml").unwrap();

        let loaded = SurMlFile::from_file("./stash/test.surml").unwrap();
        assert_eq!(loaded.model, surml_file.model);
    }

    /// Same round trip as `test_write`, using a fresh header and the forest fixture.
    #[test]
    fn test_write_forrest() {
        let header = Header::fresh();

        let mut file = File::open("./stash/forrest_test.onnx").unwrap();

        let mut model_bytes = Vec::new();
        file.read_to_end(&mut model_bytes).unwrap();

        let surml_file = SurMlFile::new(header, model_bytes);
        surml_file.write("./stash/forrest.surml").unwrap();

        let loaded = SurMlFile::from_file("./stash/forrest.surml").unwrap();
        assert_eq!(loaded.model, surml_file.model);
    }

    /// Serializes and parses entirely in memory, with no fixture files involved.
    #[test]
    fn test_round_trip_bytes() {
        let original = SurMlFile::fresh(vec![1, 2, 3, 4, 5]);
        let loaded = match SurMlFile::from_bytes(original.to_bytes()) {
            Ok(file) => file,
            Err(error) => panic!("round trip failed: {}", error),
        };
        assert_eq!(loaded.model, vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_empty_buffer() {
        let bytes = vec![0u8; 0];
        match SurMlFile::from_bytes(bytes) {
            Ok(_) => panic!("expected an empty buffer to be rejected"),
            Err(error) => {
                assert_eq!(error.status, SurrealErrorStatus::BadRequest);
                assert_eq!(error.to_string(), "Not enough bytes to read");
            }
        }
    }

    /// The prefix declares a 10-byte header, but only 2 bytes follow it.
    #[test]
    fn test_truncated_header() {
        let bytes = vec![0, 0, 0, 10, 1, 2];
        match SurMlFile::from_bytes(bytes) {
            Ok(_) => panic!("expected a truncated header to be rejected"),
            Err(error) => {
                assert_eq!(error.status, SurrealErrorStatus::BadRequest);
                assert_eq!(
                    error.to_string(),
                    "Not enough bytes to read for header, maybe the file format is not correct"
                );
            }
        }
    }

    /// A maximal length prefix must be rejected cleanly rather than overflowing.
    #[test]
    fn test_oversized_header_length() {
        let bytes = vec![0xFF, 0xFF, 0xFF, 0xFF, 1, 2, 3];
        match SurMlFile::from_bytes(bytes) {
            Ok(_) => panic!("expected an oversized header length to be rejected"),
            Err(error) => assert_eq!(error.status, SurrealErrorStatus::BadRequest),
        }
    }
}