surrealml_core/storage/header/
mod.rs

1//! Handles the loading, saving, and utilisation of all the data in the header of the model file.
2pub mod keys;
3pub mod normalisers;
4pub mod output;
5pub mod string_value;
6pub mod version;
7pub mod engine;
8pub mod origin;
9pub mod input_dims;
10
11use keys::KeyBindings;
12use normalisers::wrapper::NormaliserType;
13use normalisers::NormaliserMap;
14use output::Output;
15use string_value::StringValue;
16use version::Version;
17use engine::Engine;
18use origin::Origin;
19use input_dims::InputDims;
20use crate::safe_eject;
21use crate::errors::error::{SurrealError, SurrealErrorStatus};
22
23
24/// The header of the model file.
25/// 
26/// # Fields
27/// * `keys` - The key bindings where the order of the input columns is stored.
28/// * `normalisers` - The normalisers where the normalisation functions are stored per column if there are any.
29/// * `output` - The output where the output column name and normaliser are stored if there are any.
30/// * `name` - The name of the model.
31/// * `version` - The version of the model.
32/// * `description` - The description of the model.
33/// * `engine` - The engine of the model (could be native or pytorch).
34/// * `origin` - The origin of the model which is where the model was created and who the author is.
35#[derive(Debug, PartialEq)]
36pub struct Header {
37    pub keys: KeyBindings,
38    pub normalisers: NormaliserMap,
39    pub output: Output,
40    pub name: StringValue,
41    pub version: Version,
42    pub description: StringValue,
43    pub engine: Engine,
44    pub origin: Origin,
45    pub input_dims: InputDims,
46}
47
48
49impl Header {
50
51    /// Creates a new header with no columns or normalisers.
52    /// 
53    /// # Returns
54    /// A new header with no columns or normalisers.
55    pub fn fresh() -> Self {
56        Header {
57            keys: KeyBindings::fresh(),
58            normalisers: NormaliserMap::fresh(),
59            output: Output::fresh(),
60            name: StringValue::fresh(),
61            version: Version::fresh(),
62            description: StringValue::fresh(),
63            engine: Engine::fresh(),
64            origin: Origin::fresh(),
65            input_dims: InputDims::fresh(),
66        }
67    }
68
69    /// Adds a model name to the `self.name` field.
70    /// 
71    /// # Arguments
72    /// * `model_name` - The name of the model to be added.
73    pub fn add_name(&mut self, model_name: String) {
74        self.name = StringValue::from_string(model_name);
75    }
76
77    /// Adds a version to the `self.version` field.
78    /// 
79    /// # Arguments
80    /// * `version` - The version to be added.
81    pub fn add_version(&mut self, version: String) -> Result<(), SurrealError> {
82        self.version = Version::from_string(version)?;
83        Ok(())
84    }
85
86    /// Adds a description to the `self.description` field.
87    /// 
88    /// # Arguments
89    /// * `description` - The description to be added.
90    pub fn add_description(&mut self, description: String) {
91        self.description = StringValue::from_string(description);
92    }
93
94    /// Adds a column name to the `self.keys` field. It must be noted that the order in which the columns are added is 
95    /// the order in which they will be expected in the input data. We can do this with the followng example:
96    /// 
97    /// # Arguments
98    /// * `column_name` - The name of the column to be added.
99    pub fn add_column(&mut self, column_name: String) {
100        self.keys.add_column(column_name);
101    }
102
103    /// Adds a normaliser to the `self.normalisers` field.
104    /// 
105    /// # Arguments
106    /// * `column_name` - The name of the column to which the normaliser will be applied.
107    /// * `normaliser` - The normaliser to be applied to the column.
108    pub fn add_normaliser(&mut self, column_name: String, normaliser: NormaliserType) -> Result<(), SurrealError> {
109        let _ =  self.normalisers.add_normaliser(normaliser, column_name, &self.keys)?;
110        Ok(())
111    }
112
113    /// Gets the normaliser for a given column name.
114    /// 
115    /// # Arguments
116    /// * `column_name` - The name of the column to which the normaliser will be applied.
117    /// 
118    /// # Returns
119    /// The normaliser for the given column name.
120    pub fn get_normaliser(&self, column_name: &String) -> Result<Option<&NormaliserType>, SurrealError> {
121        self.normalisers.get_normaliser(column_name.to_string(), &self.keys)
122    }
123
124    /// Adds an output column to the `self.output` field.
125    /// 
126    /// # Arguments
127    /// * `column_name` - The name of the column to be added.
128    /// * `normaliser` - The normaliser to be applied to the column.
129    pub fn add_output(&mut self, column_name: String, normaliser: Option<NormaliserType>) {
130        self.output.name = Some(column_name);
131        self.output.normaliser = normaliser;
132    }
133
134    /// Adds an engine to the `self.engine` field.
135    /// 
136    /// # Arguments
137    /// * `engine` - The engine to be added.
138    pub fn add_engine(&mut self, engine: String) {
139        self.engine = Engine::from_string(engine);
140    }
141
142    /// Adds an author to the `self.origin` field.
143    /// 
144    /// # Arguments
145    /// * `author` - The author to be added.
146    pub fn add_author(&mut self, author: String) {
147        self.origin.add_author(author);
148    }
149
150    /// Adds an origin to the `self.origin` field.
151    /// 
152    /// # Arguments
153    /// * `origin` - The origin to be added.
154    pub fn add_origin(&mut self, origin: String) -> Result<(), SurrealError> {
155        self.origin.add_origin(origin)
156    }
157
158    /// The standard delimiter used to seperate each field in the header.
159    fn delimiter() -> &'static str {
160        "//=>"
161    } 
162
163    /// Constructs the `Header` struct from bytes.
164    /// 
165    /// # Arguments
166    /// * `data` - The bytes to be converted into a `Header` struct.
167    /// 
168    /// # Returns
169    /// The `Header` struct.
170    pub fn from_bytes(data: Vec<u8>) -> Result<Self, SurrealError> {
171
172        let string_data = safe_eject!(String::from_utf8(data), SurrealErrorStatus::BadRequest);
173
174        let buffer = string_data.split(Self::delimiter()).collect::<Vec<&str>>();
175
176        let keys: KeyBindings = KeyBindings::from_string(buffer.get(1).unwrap_or(&"").to_string());
177        let normalisers = NormaliserMap::from_string(buffer.get(2).unwrap_or(&"").to_string(), &keys)?;
178        let output = Output::from_string(buffer.get(3).unwrap_or(&"").to_string())?;
179        let name = StringValue::from_string(buffer.get(4).unwrap_or(&"").to_string());
180        let version = Version::from_string(buffer.get(5).unwrap_or(&"").to_string())?;
181        let description = StringValue::from_string(buffer.get(6).unwrap_or(&"").to_string());
182        let engine = Engine::from_string(buffer.get(7).unwrap_or(&"").to_string());
183        let origin = Origin::from_string(buffer.get(8).unwrap_or(&"").to_string())?;
184        let input_dims = InputDims::from_string(buffer.get(9).unwrap_or(&"").to_string());
185        Ok(Header {keys, normalisers, output, name, version, description, engine, origin, input_dims})
186    }
187
188    /// Converts the `Header` struct into bytes.
189    /// 
190    /// # Returns
191    /// A tuple containing the number of bytes in the header and the bytes themselves.
192    pub fn to_bytes(&self) -> (i32, Vec<u8>) {
193        let buffer = vec![
194            "".to_string(),
195            self.keys.to_string(),
196            self.normalisers.to_string(),
197            self.output.to_string(),
198            self.name.to_string(),
199            self.version.to_string(),
200            self.description.to_string(),
201            self.engine.to_string(),
202            self.origin.to_string(),
203            self.input_dims.to_string(),
204            "".to_string(),
205        ];
206        let buffer = buffer.join(Self::delimiter()).into_bytes();
207        (buffer.len() as i32, buffer)
208    }
209}
210
211
#[cfg(test)]
mod tests {

    use super::*;
    use super::keys::tests::generate_string as generate_key_string;
    use super::normalisers::tests::generate_string as generate_normaliser_string;
    use super::normalisers::{
        clipping::Clipping,
        linear_scaling::LinearScaling,
        log_scale::LogScaling,
        z_score::ZScore,
    };

    /// Builds a serialised header string with every field populated.
    ///
    /// The empty first and last entries make the join emit a leading and trailing delimiter,
    /// matching the layout produced by `Header::to_bytes`.
    pub fn generate_string() -> String {
        let fields = [
            String::new(),
            generate_key_string().to_string(),
            generate_normaliser_string().to_string(),
            "g=>linear_scaling(0.0,1.0)".to_string(),
            "test model name".to_string(),
            "0.0.1".to_string(),
            "test description".to_string(),
            Engine::PyTorch.to_string(),
            Origin::from_string("author=>local".to_string()).unwrap().to_string(),
            InputDims::from_string("1,2".to_string()).to_string(),
            String::new(),
        ];
        fields.join(Header::delimiter())
    }

    /// The bytes of [`generate_string`].
    pub fn generate_bytes() -> Vec<u8> {
        generate_string().into_bytes()
    }

    #[test]
    fn test_from_bytes() {
        let header = Header::from_bytes(generate_bytes()).unwrap();

        assert_eq!(header.keys.store.len(), 6);
        assert_eq!(header.keys.reference.len(), 6);
        assert_eq!(header.normalisers.store.len(), 4);

        assert_eq!(header.keys.store[0], "a");
        assert_eq!(header.keys.store[1], "b");
        assert_eq!(header.keys.store[2], "c");
        assert_eq!(header.keys.store[3], "d");
        assert_eq!(header.keys.store[4], "e");
        assert_eq!(header.keys.store[5], "f");
    }

    #[test]
    fn test_empty_header() {
        let string = "//=>//=>//=>//=>//=>//=>//=>//=>//=>".to_string();
        let data = string.as_bytes();
        let header = Header::from_bytes(data.to_vec()).unwrap();

        assert_eq!(header, Header::fresh());

        let string = "".to_string();
        let data = string.as_bytes();
        let header = Header::from_bytes(data.to_vec()).unwrap();

        assert_eq!(header, Header::fresh());
    }

    #[test]
    fn test_to_bytes() {
        let header = Header::from_bytes(generate_bytes()).unwrap();
        let (bytes_num, bytes) = header.to_bytes();
        let string = String::from_utf8(bytes).unwrap();

        // The numeric values round-trip correctly, but whole-number floats are serialised
        // without a decimal point (e.g. the input `linear_scaling(0.0,1.0)` is written back as
        // `linear_scaling(0,1)`). TODO: decide whether the format should preserve the `.0`.
        let expected_string = "//=>a=>b=>c=>d=>e=>f//=>a=>linear_scaling(0,1)//b=>clipping(0,1.5)//c=>log_scaling(10,0)//e=>z_score(0,1)//=>g=>linear_scaling(0,1)//=>test model name//=>0.0.1//=>test description//=>pytorch//=>author=>local//=>1,2//=>".to_string();

        assert_eq!(string, expected_string);
        assert_eq!(bytes_num, expected_string.len() as i32);

        let empty_header = Header::fresh();
        let (bytes_num, bytes) = empty_header.to_bytes();
        let string = String::from_utf8(bytes).unwrap();
        let expected_string = "//=>//=>//=>//=>//=>//=>//=>//=>//=>//=>".to_string();

        assert_eq!(string, expected_string);
        assert_eq!(bytes_num, expected_string.len() as i32);
    }

    #[test]
    fn test_add_column() {
        let mut header = Header::fresh();
        for column in ["a", "b", "c", "d", "e", "f"] {
            header.add_column(column.to_string());
        }

        assert_eq!(header.keys.store.len(), 6);
        assert_eq!(header.keys.reference.len(), 6);

        assert_eq!(header.keys.store[0], "a");
        assert_eq!(header.keys.store[1], "b");
        assert_eq!(header.keys.store[2], "c");
        assert_eq!(header.keys.store[3], "d");
        assert_eq!(header.keys.store[4], "e");
        assert_eq!(header.keys.store[5], "f");
    }

    #[test]
    fn test_add_normalizer() {
        let mut header = Header::fresh();
        for column in ["a", "b", "c", "d", "e", "f"] {
            header.add_column(column.to_string());
        }

        // Unwrap each result so a rejected normaliser fails the test at the call site
        // instead of being silently discarded.
        header.add_normaliser(
            "a".to_string(),
            NormaliserType::LinearScaling(LinearScaling { min: 0.0, max: 1.0 })
        ).expect("adding linear scaling to column a");
        header.add_normaliser(
            "b".to_string(),
            NormaliserType::Clipping(Clipping { min: Some(0.0), max: Some(1.5) })
        ).expect("adding clipping to column b");
        header.add_normaliser(
            "c".to_string(),
            NormaliserType::LogScaling(LogScaling { base: 10.0, min: 0.0 })
        ).expect("adding log scaling to column c");
        header.add_normaliser(
            "e".to_string(),
            NormaliserType::ZScore(ZScore { mean: 0.0, std_dev: 1.0 })
        ).expect("adding z-score to column e");

        assert_eq!(header.normalisers.store.len(), 4);
        assert_eq!(header.normalisers.store[0], NormaliserType::LinearScaling(LinearScaling { min: 0.0, max: 1.0 }));
        assert_eq!(header.normalisers.store[1], NormaliserType::Clipping(Clipping { min: Some(0.0), max: Some(1.5) }));
        assert_eq!(header.normalisers.store[2], NormaliserType::LogScaling(LogScaling { base: 10.0, min: 0.0 }));
        assert_eq!(header.normalisers.store[3], NormaliserType::ZScore(ZScore { mean: 0.0, std_dev: 1.0 }));
    }

}
367
368