surrealml_core/execution/
compute.rs

1//! Defines the operations around performing computations on a loaded model.
2use crate::storage::surml_file::SurMlFile;
3use std::collections::HashMap;
4use ndarray::ArrayD;
5use ort::value::ValueType;
6use ort::session::Session;
7
8use crate::safe_eject;
9use crate::errors::error::{SurrealError, SurrealErrorStatus};
10use crate::execution::session::get_session;
11
12
/// A wrapper around a loaded machine learning model so computations can be performed on it.
///
/// The wrapper does not own the model: it holds a mutable borrow for the lifetime `'a`.
///
/// # Attributes
/// * `surml_file` - Exclusive (`&mut`) borrow of the loaded model file. The methods on this type read
///   its `header` (key order and normalisers) and its `model` field when computing.
pub struct ModelComputation<'a> {
    pub surml_file: &'a mut SurMlFile,
}
20
21
22impl <'a>ModelComputation<'a> {
23
24    /// Creates a Tensor that can be used as input to the loaded model from a hashmap of keys and values.
25    /// 
26    /// # Arguments
27    /// * `input_values` - A hashmap of keys and values that will be used to create the input tensor.
28    /// 
29    /// # Returns
30    /// A Tensor that can be used as input to the loaded model.
31    pub fn input_tensor_from_key_bindings(&self, input_values: HashMap<String, f32>) -> Result<ArrayD<f32>, SurrealError> {
32        let buffer = self.input_vector_from_key_bindings(input_values)?;
33        Ok(ndarray::arr1::<f32>(&buffer).into_dyn())
34    }
35
36    /// Creates a vector of dimensions for the input tensor from the loaded model.
37    /// 
38    /// # Arguments
39    /// * `input_dims` - The input dimensions from the loaded model.
40    /// 
41    /// # Returns
42    /// A vector of dimensions for the input tensor to be reshaped into from the loaded model.
43    fn process_input_dims(session_ref: &Session) -> Vec<usize> {
44        let some_dims = match &session_ref.inputs[0].input_type {
45            ValueType::Tensor{ ty: _, dimensions: new_dims, dimension_symbols: _ } => Some(new_dims),
46            _ => None
47        };
48        let mut dims_cache = Vec::new();
49        for dim in some_dims.unwrap() {
50            if dim < &0 {
51                dims_cache.push((dim * -1) as usize);
52            }
53            else {
54                dims_cache.push(*dim as usize);
55            }
56        }
57        dims_cache
58    }
59
60    /// Creates a Vector that can be used manipulated with other operations such as normalisation from a hashmap of keys and values.
61    /// 
62    /// # Arguments
63    /// * `input_values` - A hashmap of keys and values that will be used to create the input vector.
64    /// 
65    /// # Returns
66    /// A Vector that can be used manipulated with other operations such as normalisation.
67    pub fn input_vector_from_key_bindings(&self, mut input_values: HashMap<String, f32>) -> Result<Vec<f32>, SurrealError> {
68        let mut buffer = Vec::with_capacity(self.surml_file.header.keys.store.len());
69
70        for key in &self.surml_file.header.keys.store {
71            let value = match input_values.get_mut(key) {
72                Some(value) => value,
73                None => return Err(SurrealError::new(format!("src/execution/compute.rs 67: Key {} not found in input values", key), SurrealErrorStatus::NotFound))
74            };
75            buffer.push(std::mem::take(value));
76        }
77
78        Ok(buffer)
79    }
80
81    /// Performs a raw computation on the loaded model.
82    /// 
83    /// # Arguments
84    /// * `tensor` - The input tensor to the loaded model.
85    /// 
86    /// # Returns
87    /// The computed output tensor from the loaded model.
88    pub fn raw_compute(&self, tensor: ArrayD<f32>, _dims: Option<(i32, i32)>) -> Result<Vec<f32>, SurrealError> {
89        let session = get_session(self.surml_file.model.clone())?;
90        let dims_cache = ModelComputation::process_input_dims(&session);
91        let tensor = match tensor.into_shape_with_order(dims_cache) {
92            Ok(tensor) => tensor,
93            Err(_) => return Err(SurrealError::new("Failed to reshape tensor to input dimensions".to_string(), SurrealErrorStatus::Unknown))
94        };
95        let tensor = match ort::value::Tensor::from_array(tensor) {
96            Ok(tensor) => tensor,
97            Err(_) => return Err(SurrealError::new("Failed to convert tensor to ort tensor".to_string(), SurrealErrorStatus::Unknown))
98        };
99        let x = match ort::inputs![tensor] {
100            Ok(x) => x,
101            Err(_) => return Err(SurrealError::new("Failed to create input tensor".to_string(), SurrealErrorStatus::Unknown))
102        };
103        let outputs = safe_eject!(session.run(x), SurrealErrorStatus::Unknown);
104
105        let mut buffer: Vec<f32> = Vec::new();
106
107        // extract the output tensor converting the values to f32 if they are i64
108        match outputs[0].try_extract_tensor::<f32>() {
109            Ok(y) => {
110                for i in y.view().clone().into_iter() {
111                    buffer.push(*i);
112                }
113            },
114            Err(_) => {
115                for i in safe_eject!(outputs[0].try_extract_tensor::<i64>(), SurrealErrorStatus::Unknown).view().clone().into_iter() {
116                    buffer.push(*i as f32);
117                }
118            }
119        };
120        return Ok(buffer)
121    }
122
123    /// Checks the header applying normalisers if present and then performs a raw computation on the loaded model. Will
124    /// also apply inverse normalisers if present on the outputs.
125    /// 
126    /// # Notes
127    /// This function is fairly coupled and will consider breaking out the functions later on if needed.
128    /// 
129    /// # Arguments
130    /// * `input_values` - A hashmap of keys and values that will be used to create the input tensor.
131    /// 
132    /// # Returns
133    /// The computed output tensor from the loaded model.
134    pub fn buffered_compute(&self, input_values: &mut HashMap<String, f32>) -> Result<Vec<f32>, SurrealError> {
135        // applying normalisers if present
136        for (key, value) in &mut *input_values {
137            let value_ref = value.clone();
138            match self.surml_file.header.get_normaliser(&key.to_string())? {
139                Some(normaliser) => {
140                    *value = normaliser.normalise(value_ref);
141                },
142                None => {}
143            }
144        }
145        let tensor = self.input_tensor_from_key_bindings(input_values.clone())?;
146        let output = self.raw_compute(tensor, None)?;
147        
148        // if no normaliser is present, return the output
149        if self.surml_file.header.output.normaliser == None {
150            return Ok(output)
151        }
152
153        // apply the normaliser to the output
154        let output_normaliser = match self.surml_file.header.output.normaliser.as_ref() {
155            Some(normaliser) => normaliser,
156            None => return Err(SurrealError::new(
157                String::from("No normaliser present for output which shouldn't happen as passed initial check for").to_string(), 
158                SurrealErrorStatus::Unknown
159            ))
160        };
161        let mut buffer = Vec::with_capacity(output.len());
162
163        for value in output {
164            buffer.push(output_normaliser.inverse_normalise(value));
165        }
166        return Ok(buffer)
167    }
168
169}
170
171
#[cfg(test)]
mod tests {

    use super::*;
    use crate::execution::session::set_environment;

    /// Feature values shared by every linear-model fixture.
    // Only the feature-gated tests call this, so it is dead code when none of those features are enabled.
    #[allow(dead_code)]
    fn linear_inputs() -> HashMap<String, f32> {
        HashMap::from([
            (String::from("squarefoot"), 1000.0),
            (String::from("num_floors"), 2.0),
        ])
    }

    /// Loads the model at `path` and runs `raw_compute` on the shared inputs.
    // Only the feature-gated tests call this, so it is dead code when none of those features are enabled.
    #[allow(dead_code)]
    fn raw_output(path: &str, dims: Option<(i32, i32)>) -> Vec<f32> {
        set_environment().unwrap();
        let mut file = SurMlFile::from_file(path).unwrap();
        let model_computation = ModelComputation {
            surml_file: &mut file,
        };

        let raw_input = model_computation.input_tensor_from_key_bindings(linear_inputs()).unwrap();
        model_computation.raw_compute(raw_input, dims).unwrap()
    }

    /// Loads the model at `path` and runs `buffered_compute` on the shared inputs.
    // Only the feature-gated tests call this, so it is dead code when none of those features are enabled.
    #[allow(dead_code)]
    fn buffered_output(path: &str) -> Vec<f32> {
        set_environment().unwrap();
        let mut file = SurMlFile::from_file(path).unwrap();
        let model_computation = ModelComputation {
            surml_file: &mut file,
        };

        let mut input_values = linear_inputs();
        model_computation.buffered_compute(&mut input_values).unwrap()
    }

    #[cfg(feature = "sklearn-tests")]
    #[test]
    fn test_raw_compute_linear_sklearn() {
        let output = raw_output("./model_stash/sklearn/surml/linear.surml", Some((1, 2)));
        assert_eq!(output.len(), 1);
        assert_eq!(output[0], 985.57745);
    }

    #[cfg(feature = "sklearn-tests")]
    #[test]
    fn test_buffered_compute_linear_sklearn() {
        let output = buffered_output("./model_stash/sklearn/surml/linear.surml");
        assert_eq!(output.len(), 1);
    }

    #[cfg(feature = "onnx-tests")]
    #[test]
    fn test_raw_compute_linear_onnx() {
        let output = raw_output("./model_stash/onnx/surml/linear.surml", Some((1, 2)));
        assert_eq!(output.len(), 1);
        assert_eq!(output[0], 985.57745);
    }

    #[cfg(feature = "onnx-tests")]
    #[test]
    fn test_buffered_compute_linear_onnx() {
        let output = buffered_output("./model_stash/onnx/surml/linear.surml");
        assert_eq!(output.len(), 1);
    }

    #[cfg(feature = "torch-tests")]
    #[test]
    fn test_raw_compute_linear_torch() {
        let output = raw_output("./model_stash/torch/surml/linear.surml", None);
        assert_eq!(output.len(), 1);
    }

    #[cfg(feature = "torch-tests")]
    #[test]
    fn test_buffered_compute_linear_torch() {
        let output = buffered_output("./model_stash/torch/surml/linear.surml");
        assert_eq!(output.len(), 1);
    }

    #[cfg(feature = "tensorflow-tests")]
    #[test]
    fn test_raw_compute_linear_tensorflow() {
        let output = raw_output("./model_stash/tensorflow/surml/linear.surml", None);
        assert_eq!(output.len(), 1);
    }

    #[cfg(feature = "tensorflow-tests")]
    #[test]
    fn test_buffered_compute_linear_tensorflow() {
        let output = buffered_output("./model_stash/tensorflow/surml/linear.surml");
        assert_eq!(output.len(), 1);
    }
}