surrealml_core/execution/
compute.rs1use crate::storage::surml_file::SurMlFile;
3use std::collections::HashMap;
4use ndarray::ArrayD;
5use ort::value::ValueType;
6use ort::session::Session;
7
8use crate::safe_eject;
9use crate::errors::error::{SurrealError, SurrealErrorStatus};
10use crate::execution::session::get_session;
11
12
13pub struct ModelComputation<'a> {
18 pub surml_file: &'a mut SurMlFile,
19}
20
21
22impl <'a>ModelComputation<'a> {
23
24 pub fn input_tensor_from_key_bindings(&self, input_values: HashMap<String, f32>) -> Result<ArrayD<f32>, SurrealError> {
32 let buffer = self.input_vector_from_key_bindings(input_values)?;
33 Ok(ndarray::arr1::<f32>(&buffer).into_dyn())
34 }
35
36 fn process_input_dims(session_ref: &Session) -> Vec<usize> {
44 let some_dims = match &session_ref.inputs[0].input_type {
45 ValueType::Tensor{ ty: _, dimensions: new_dims, dimension_symbols: _ } => Some(new_dims),
46 _ => None
47 };
48 let mut dims_cache = Vec::new();
49 for dim in some_dims.unwrap() {
50 if dim < &0 {
51 dims_cache.push((dim * -1) as usize);
52 }
53 else {
54 dims_cache.push(*dim as usize);
55 }
56 }
57 dims_cache
58 }
59
60 pub fn input_vector_from_key_bindings(&self, mut input_values: HashMap<String, f32>) -> Result<Vec<f32>, SurrealError> {
68 let mut buffer = Vec::with_capacity(self.surml_file.header.keys.store.len());
69
70 for key in &self.surml_file.header.keys.store {
71 let value = match input_values.get_mut(key) {
72 Some(value) => value,
73 None => return Err(SurrealError::new(format!("src/execution/compute.rs 67: Key {} not found in input values", key), SurrealErrorStatus::NotFound))
74 };
75 buffer.push(std::mem::take(value));
76 }
77
78 Ok(buffer)
79 }
80
81 pub fn raw_compute(&self, tensor: ArrayD<f32>, _dims: Option<(i32, i32)>) -> Result<Vec<f32>, SurrealError> {
89 let session = get_session(self.surml_file.model.clone())?;
90 let dims_cache = ModelComputation::process_input_dims(&session);
91 let tensor = match tensor.into_shape_with_order(dims_cache) {
92 Ok(tensor) => tensor,
93 Err(_) => return Err(SurrealError::new("Failed to reshape tensor to input dimensions".to_string(), SurrealErrorStatus::Unknown))
94 };
95 let tensor = match ort::value::Tensor::from_array(tensor) {
96 Ok(tensor) => tensor,
97 Err(_) => return Err(SurrealError::new("Failed to convert tensor to ort tensor".to_string(), SurrealErrorStatus::Unknown))
98 };
99 let x = match ort::inputs![tensor] {
100 Ok(x) => x,
101 Err(_) => return Err(SurrealError::new("Failed to create input tensor".to_string(), SurrealErrorStatus::Unknown))
102 };
103 let outputs = safe_eject!(session.run(x), SurrealErrorStatus::Unknown);
104
105 let mut buffer: Vec<f32> = Vec::new();
106
107 match outputs[0].try_extract_tensor::<f32>() {
109 Ok(y) => {
110 for i in y.view().clone().into_iter() {
111 buffer.push(*i);
112 }
113 },
114 Err(_) => {
115 for i in safe_eject!(outputs[0].try_extract_tensor::<i64>(), SurrealErrorStatus::Unknown).view().clone().into_iter() {
116 buffer.push(*i as f32);
117 }
118 }
119 };
120 return Ok(buffer)
121 }
122
123 pub fn buffered_compute(&self, input_values: &mut HashMap<String, f32>) -> Result<Vec<f32>, SurrealError> {
135 for (key, value) in &mut *input_values {
137 let value_ref = value.clone();
138 match self.surml_file.header.get_normaliser(&key.to_string())? {
139 Some(normaliser) => {
140 *value = normaliser.normalise(value_ref);
141 },
142 None => {}
143 }
144 }
145 let tensor = self.input_tensor_from_key_bindings(input_values.clone())?;
146 let output = self.raw_compute(tensor, None)?;
147
148 if self.surml_file.header.output.normaliser == None {
150 return Ok(output)
151 }
152
153 let output_normaliser = match self.surml_file.header.output.normaliser.as_ref() {
155 Some(normaliser) => normaliser,
156 None => return Err(SurrealError::new(
157 String::from("No normaliser present for output which shouldn't happen as passed initial check for").to_string(),
158 SurrealErrorStatus::Unknown
159 ))
160 };
161 let mut buffer = Vec::with_capacity(output.len());
162
163 for value in output {
164 buffer.push(output_normaliser.inverse_normalise(value));
165 }
166 return Ok(buffer)
167 }
168
169}
170
171
#[cfg(test)]
mod tests {

    use super::*;
    use crate::execution::session::set_environment;

    /// Input map shared by every linear-model test: `squarefoot` = 1000.0
    /// and `num_floors` = 2.0.
    #[cfg(any(
        feature = "sklearn-tests",
        feature = "onnx-tests",
        feature = "torch-tests",
        feature = "tensorflow-tests"
    ))]
    fn linear_inputs() -> HashMap<String, f32> {
        let mut inputs = HashMap::new();
        inputs.insert(String::from("squarefoot"), 1000.0);
        inputs.insert(String::from("num_floors"), 2.0);
        inputs
    }

    #[cfg(feature = "sklearn-tests")]
    #[test]
    fn test_raw_compute_linear_sklearn() {
        set_environment().unwrap();
        let mut surml = SurMlFile::from_file("./model_stash/sklearn/surml/linear.surml").unwrap();
        let computation = ModelComputation { surml_file: &mut surml };

        let raw_input = computation
            .input_tensor_from_key_bindings(linear_inputs())
            .unwrap();

        let prediction = computation.raw_compute(raw_input, Some((1, 2))).unwrap();
        assert_eq!(prediction.len(), 1);
        assert_eq!(prediction[0], 985.57745);
    }

    #[cfg(feature = "sklearn-tests")]
    #[test]
    fn test_buffered_compute_linear_sklearn() {
        set_environment().unwrap();
        let mut surml = SurMlFile::from_file("./model_stash/sklearn/surml/linear.surml").unwrap();
        let computation = ModelComputation { surml_file: &mut surml };

        let mut inputs = linear_inputs();

        let prediction = computation.buffered_compute(&mut inputs).unwrap();
        assert_eq!(prediction.len(), 1);
    }

    #[cfg(feature = "onnx-tests")]
    #[test]
    fn test_raw_compute_linear_onnx() {
        set_environment().unwrap();
        let mut surml = SurMlFile::from_file("./model_stash/onnx/surml/linear.surml").unwrap();
        let computation = ModelComputation { surml_file: &mut surml };

        let raw_input = computation
            .input_tensor_from_key_bindings(linear_inputs())
            .unwrap();

        let prediction = computation.raw_compute(raw_input, Some((1, 2))).unwrap();
        assert_eq!(prediction.len(), 1);
        assert_eq!(prediction[0], 985.57745);
    }

    #[cfg(feature = "onnx-tests")]
    #[test]
    fn test_buffered_compute_linear_onnx() {
        set_environment().unwrap();
        let mut surml = SurMlFile::from_file("./model_stash/onnx/surml/linear.surml").unwrap();
        let computation = ModelComputation { surml_file: &mut surml };

        let mut inputs = linear_inputs();

        let prediction = computation.buffered_compute(&mut inputs).unwrap();
        assert_eq!(prediction.len(), 1);
    }

    #[cfg(feature = "torch-tests")]
    #[test]
    fn test_raw_compute_linear_torch() {
        set_environment().unwrap();
        let mut surml = SurMlFile::from_file("./model_stash/torch/surml/linear.surml").unwrap();
        let computation = ModelComputation { surml_file: &mut surml };

        let raw_input = computation
            .input_tensor_from_key_bindings(linear_inputs())
            .unwrap();

        let prediction = computation.raw_compute(raw_input, None).unwrap();
        assert_eq!(prediction.len(), 1);
    }

    #[cfg(feature = "torch-tests")]
    #[test]
    fn test_buffered_compute_linear_torch() {
        set_environment().unwrap();
        let mut surml = SurMlFile::from_file("./model_stash/torch/surml/linear.surml").unwrap();
        let computation = ModelComputation { surml_file: &mut surml };

        let mut inputs = linear_inputs();

        let prediction = computation.buffered_compute(&mut inputs).unwrap();
        assert_eq!(prediction.len(), 1);
    }

    #[cfg(feature = "tensorflow-tests")]
    #[test]
    fn test_raw_compute_linear_tensorflow() {
        set_environment().unwrap();
        let mut surml = SurMlFile::from_file("./model_stash/tensorflow/surml/linear.surml").unwrap();
        let computation = ModelComputation { surml_file: &mut surml };

        let raw_input = computation
            .input_tensor_from_key_bindings(linear_inputs())
            .unwrap();

        let prediction = computation.raw_compute(raw_input, None).unwrap();
        assert_eq!(prediction.len(), 1);
    }

    #[cfg(feature = "tensorflow-tests")]
    #[test]
    fn test_buffered_compute_linear_tensorflow() {
        set_environment().unwrap();
        let mut surml = SurMlFile::from_file("./model_stash/tensorflow/surml/linear.surml").unwrap();
        let computation = ModelComputation { surml_file: &mut surml };

        let mut inputs = linear_inputs();

        let prediction = computation.buffered_compute(&mut inputs).unwrap();
        assert_eq!(prediction.len(), 1);
    }
}