surrealml_core/storage/header/
mod.rs1pub mod keys;
3pub mod normalisers;
4pub mod output;
5pub mod string_value;
6pub mod version;
7pub mod engine;
8pub mod origin;
9pub mod input_dims;
10
11use keys::KeyBindings;
12use normalisers::wrapper::NormaliserType;
13use normalisers::NormaliserMap;
14use output::Output;
15use string_value::StringValue;
16use version::Version;
17use engine::Engine;
18use origin::Origin;
19use input_dims::InputDims;
20use crate::safe_eject;
21use crate::errors::error::{SurrealError, SurrealErrorStatus};
22
23
24#[derive(Debug, PartialEq)]
36pub struct Header {
37 pub keys: KeyBindings,
38 pub normalisers: NormaliserMap,
39 pub output: Output,
40 pub name: StringValue,
41 pub version: Version,
42 pub description: StringValue,
43 pub engine: Engine,
44 pub origin: Origin,
45 pub input_dims: InputDims,
46}
47
48
49impl Header {
50
51 pub fn fresh() -> Self {
56 Header {
57 keys: KeyBindings::fresh(),
58 normalisers: NormaliserMap::fresh(),
59 output: Output::fresh(),
60 name: StringValue::fresh(),
61 version: Version::fresh(),
62 description: StringValue::fresh(),
63 engine: Engine::fresh(),
64 origin: Origin::fresh(),
65 input_dims: InputDims::fresh(),
66 }
67 }
68
69 pub fn add_name(&mut self, model_name: String) {
74 self.name = StringValue::from_string(model_name);
75 }
76
77 pub fn add_version(&mut self, version: String) -> Result<(), SurrealError> {
82 self.version = Version::from_string(version)?;
83 Ok(())
84 }
85
86 pub fn add_description(&mut self, description: String) {
91 self.description = StringValue::from_string(description);
92 }
93
94 pub fn add_column(&mut self, column_name: String) {
100 self.keys.add_column(column_name);
101 }
102
103 pub fn add_normaliser(&mut self, column_name: String, normaliser: NormaliserType) -> Result<(), SurrealError> {
109 let _ = self.normalisers.add_normaliser(normaliser, column_name, &self.keys)?;
110 Ok(())
111 }
112
113 pub fn get_normaliser(&self, column_name: &String) -> Result<Option<&NormaliserType>, SurrealError> {
121 self.normalisers.get_normaliser(column_name.to_string(), &self.keys)
122 }
123
124 pub fn add_output(&mut self, column_name: String, normaliser: Option<NormaliserType>) {
130 self.output.name = Some(column_name);
131 self.output.normaliser = normaliser;
132 }
133
134 pub fn add_engine(&mut self, engine: String) {
139 self.engine = Engine::from_string(engine);
140 }
141
142 pub fn add_author(&mut self, author: String) {
147 self.origin.add_author(author);
148 }
149
150 pub fn add_origin(&mut self, origin: String) -> Result<(), SurrealError> {
155 self.origin.add_origin(origin)
156 }
157
158 fn delimiter() -> &'static str {
160 "//=>"
161 }
162
163 pub fn from_bytes(data: Vec<u8>) -> Result<Self, SurrealError> {
171
172 let string_data = safe_eject!(String::from_utf8(data), SurrealErrorStatus::BadRequest);
173
174 let buffer = string_data.split(Self::delimiter()).collect::<Vec<&str>>();
175
176 let keys: KeyBindings = KeyBindings::from_string(buffer.get(1).unwrap_or(&"").to_string());
177 let normalisers = NormaliserMap::from_string(buffer.get(2).unwrap_or(&"").to_string(), &keys)?;
178 let output = Output::from_string(buffer.get(3).unwrap_or(&"").to_string())?;
179 let name = StringValue::from_string(buffer.get(4).unwrap_or(&"").to_string());
180 let version = Version::from_string(buffer.get(5).unwrap_or(&"").to_string())?;
181 let description = StringValue::from_string(buffer.get(6).unwrap_or(&"").to_string());
182 let engine = Engine::from_string(buffer.get(7).unwrap_or(&"").to_string());
183 let origin = Origin::from_string(buffer.get(8).unwrap_or(&"").to_string())?;
184 let input_dims = InputDims::from_string(buffer.get(9).unwrap_or(&"").to_string());
185 Ok(Header {keys, normalisers, output, name, version, description, engine, origin, input_dims})
186 }
187
188 pub fn to_bytes(&self) -> (i32, Vec<u8>) {
193 let buffer = vec![
194 "".to_string(),
195 self.keys.to_string(),
196 self.normalisers.to_string(),
197 self.output.to_string(),
198 self.name.to_string(),
199 self.version.to_string(),
200 self.description.to_string(),
201 self.engine.to_string(),
202 self.origin.to_string(),
203 self.input_dims.to_string(),
204 "".to_string(),
205 ];
206 let buffer = buffer.join(Self::delimiter()).into_bytes();
207 (buffer.len() as i32, buffer)
208 }
209}
210
211
#[cfg(test)]
mod tests {

    use super::*;
    use super::keys::tests::generate_string as generate_key_string;
    use super::normalisers::tests::generate_string as generate_normaliser_string;
    use super::normalisers::{
        clipping::Clipping,
        linear_scaling::LinearScaling,
        log_scale::LogScaling,
        z_score::ZScore,
    };

    /// Column names used by every test in this module, in insertion order.
    const COLUMNS: [&str; 6] = ["a", "b", "c", "d", "e", "f"];

    /// Builds a fresh header with every entry of `COLUMNS` added in order.
    fn header_with_columns() -> Header {
        let mut header = Header::fresh();
        for column in COLUMNS.iter() {
            header.add_column(column.to_string());
        }
        header
    }

    /// Asserts that the key bindings of `header` hold exactly `COLUMNS`, in order.
    fn assert_expected_columns(header: &Header) {
        assert_eq!(header.keys.store.len(), COLUMNS.len());
        assert_eq!(header.keys.reference.len(), COLUMNS.len());

        for (index, expected) in COLUMNS.iter().enumerate() {
            assert_eq!(header.keys.store[index], *expected);
        }
    }

    /// Builds a serialised header string with every section populated.
    ///
    /// The layout mirrors `Header::to_bytes`: the sections are joined by the delimiter
    /// and the string both starts and ends with a delimiter.
    pub fn generate_string() -> String {
        let sections = [
            // Empty first and last entries yield the leading and trailing delimiter.
            String::new(),
            generate_key_string(),
            generate_normaliser_string(),
            "g=>linear_scaling(0.0,1.0)".to_string(),
            "test model name".to_string(),
            "0.0.1".to_string(),
            "test description".to_string(),
            Engine::PyTorch.to_string(),
            Origin::from_string("author=>local".to_string()).unwrap().to_string(),
            InputDims::from_string("1,2".to_string()).to_string(),
            String::new(),
        ];
        sections.join(Header::delimiter())
    }

    /// The bytes of `generate_string`.
    pub fn generate_bytes() -> Vec<u8> {
        generate_string().into_bytes()
    }

    #[test]
    fn test_from_bytes() {
        let header = Header::from_bytes(generate_bytes()).unwrap();

        assert_expected_columns(&header);
        assert_eq!(header.normalisers.store.len(), 4);
    }

    #[test]
    fn test_empty_header() {
        // Delimiters only: every section is present but empty.
        let data = "//=>//=>//=>//=>//=>//=>//=>//=>//=>".as_bytes().to_vec();
        let header = Header::from_bytes(data).unwrap();

        assert_eq!(header, Header::fresh());

        // No data at all: every section is missing.
        let header = Header::from_bytes(Vec::new()).unwrap();

        assert_eq!(header, Header::fresh());
    }

    #[test]
    fn test_to_bytes() {
        let header = Header::from_bytes(generate_bytes()).unwrap();
        let (bytes_num, bytes) = header.to_bytes();
        let string = String::from_utf8(bytes).unwrap();

        let expected_string = "//=>a=>b=>c=>d=>e=>f//=>a=>linear_scaling(0,1)//b=>clipping(0,1.5)//c=>log_scaling(10,0)//e=>z_score(0,1)//=>g=>linear_scaling(0,1)//=>test model name//=>0.0.1//=>test description//=>pytorch//=>author=>local//=>1,2//=>".to_string();

        assert_eq!(string, expected_string);
        assert_eq!(bytes_num, expected_string.len() as i32);

        let empty_header = Header::fresh();
        let (bytes_num, bytes) = empty_header.to_bytes();
        let string = String::from_utf8(bytes).unwrap();
        let expected_string = "//=>//=>//=>//=>//=>//=>//=>//=>//=>//=>".to_string();

        assert_eq!(string, expected_string);
        assert_eq!(bytes_num, expected_string.len() as i32);
    }

    #[test]
    fn test_add_column() {
        let header = header_with_columns();

        assert_expected_columns(&header);
    }

    #[test]
    fn test_add_normalizer() {
        let mut header = header_with_columns();

        // Unwrap every insertion so a failure is reported where it happens instead of
        // surfacing later as an unexplained length mismatch.
        header
            .add_normaliser(
                "a".to_string(),
                NormaliserType::LinearScaling(LinearScaling { min: 0.0, max: 1.0 }),
            )
            .unwrap();
        header
            .add_normaliser(
                "b".to_string(),
                NormaliserType::Clipping(Clipping { min: Some(0.0), max: Some(1.5) }),
            )
            .unwrap();
        header
            .add_normaliser(
                "c".to_string(),
                NormaliserType::LogScaling(LogScaling { base: 10.0, min: 0.0 }),
            )
            .unwrap();
        header
            .add_normaliser(
                "e".to_string(),
                NormaliserType::ZScore(ZScore { mean: 0.0, std_dev: 1.0 }),
            )
            .unwrap();

        assert_eq!(header.normalisers.store.len(), 4);
        assert_eq!(header.normalisers.store[0], NormaliserType::LinearScaling(LinearScaling { min: 0.0, max: 1.0 }));
        assert_eq!(header.normalisers.store[1], NormaliserType::Clipping(Clipping { min: Some(0.0), max: Some(1.5) }));
        assert_eq!(header.normalisers.store[2], NormaliserType::LogScaling(LogScaling { base: 10.0, min: 0.0 }));
        assert_eq!(header.normalisers.store[3], NormaliserType::ZScore(ZScore { mean: 0.0, std_dev: 1.0 }));
    }

}
367
368