stonnx_api/executor/
mod.rs

1use std::{
2    collections::{HashMap, HashSet},
3    path::Path,
4    sync::{Arc, RwLock, RwLockReadGuard, atomic::Ordering},
5};
6
7use crate::{
8    common::{
9        Args, BoxResult, FileInputs, OperationFn, OperatorResult, TensorType, VerbosityLevel,
10        MAX_OPSET_VERSION, VERBOSE,
11    },
12    onnx,
13    operators::OPERATION_MAP,
14    print_at_level,
15    protograph::{build_graph_from_proto, GraphOutputType},
16    read_model,
17    utils::{initialize_nodes, make_initializers},
18    utils::{make_external_outputs, make_graph_outputs, operator_not_implemented, OutputInfo},
19};
20
21use anyhow::anyhow;
22use smallvec::SmallVec;
#[derive(Debug, Clone, PartialEq)]
/// Represent an ONNX execution node, with an ID, the function to execute it and a reference to the internal ONNX node
pub struct ONNXNode {
    // Position of the node in the graph's node list; also the sole hash key (see `Hash` impl below).
    id: usize,
    // Operator implementation to invoke for this node (falls back to `operator_not_implemented`).
    op_func: OperationFn,
    // Shared ownership of the protobuf node so clones of this struct stay cheap.
    node_ref: Arc<onnx::NodeProto>,
}
30
31impl ONNXNode {
32    /// Create a new ONNXNode
33    fn new(id: usize, op_func: OperationFn, node_ref: onnx::NodeProto) -> Self {
34        Self {
35            id,
36            op_func,
37            node_ref: Arc::new(node_ref),
38        }
39    }
40
41    /// Executes the ONNX node given the inputs map and the opset version
42    fn execute(
43        &self,
44        node_inputs: RwLockReadGuard<HashMap<String, Arc<TensorType>>>,
45        opset_version: i64,
46    ) -> BoxResult<OperatorResult> {
47        let mut inputs = vec![];
48        let mut outputs = vec![];
49        let mut all_nodes_have_init = true;
50        for input in self.node_ref.input.iter() {
51            if let Some(k) = node_inputs.get(input) {
52                inputs.push(k.clone());
53            } else {
54                all_nodes_have_init = false;
55            }
56        }
57        drop(node_inputs); // drop the rwlock as soon as possible
58        for output in self.node_ref.output.iter() {
59            outputs.push(output);
60        }
61        if !all_nodes_have_init {
62            return Err(anyhow!("Some nodes in this operation have not been initialized yet, this means the operations aren't in order, fix the code to account for this"));
63        }
64        let input_names = self
65            .node_ref
66            .input
67            .iter()
68            .map(|s| s.as_str())
69            .collect::<Vec<&str>>();
70        let output_names = self
71            .node_ref
72            .output
73            .iter()
74            .map(|s| s.as_str())
75            .collect::<Vec<&str>>();
76        print_at_level!(
77            VerbosityLevel::Informational,
78            "Running {} operator (id: {}, thread: {:?}) between {:?} to get {:?}",
79            self.node_ref.op_type(),
80            self.id,
81            std::thread::current().id(),
82            input_names,
83            output_names
84        );
85        // most operators have 2/3 inputs, so we use a smallvec to avoid heap allocations
86        let inputs: SmallVec<[&TensorType; 4]> = inputs.iter().map(|x| x.as_ref()).collect();
87        (self.op_func)(
88            &inputs,
89            self.node_ref.as_ref(),
90            opset_version,
91            output_names.len(),
92        )
93    }
94}
95
/// Nodes are hashed by `id` alone — two nodes that compare equal (derived
/// `PartialEq`, which includes `id`) therefore always hash equal, satisfying
/// the `Hash`/`Eq` contract.
impl std::hash::Hash for ONNXNode {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Only the id participates in the hash; it is unique per graph node.
        self.id.hash(state);
    }
}

/// `Eq` is sound here: the derived `PartialEq` is a total equivalence on the
/// concrete field types used.
impl std::cmp::Eq for ONNXNode {}
103
104/// Handle the output of an operator, saving it to disk if verbose is VerbosityLevel::Results or higher, and saving it to graph_outputs
105fn handle_output(
106    result: OperatorResult,
107    node: &ONNXNode,
108    outputs_dir: &Path,
109    node_inputs: &mut HashMap<String, Arc<TensorType>>,
110    graph_outputs: &mut HashMap<String, OutputInfo>,
111) -> BoxResult<Vec<String>> {
112    let node = node.node_ref.clone();
113    let outputs = node.output.to_vec();
114    let result = result.result; // we love rust
115    assert_eq!(outputs.len(), result.len());
116    for (output_name, res) in outputs.iter().zip(result.into_iter()) {
117        print_at_level!(
118            VerbosityLevel::Informational,
119            "\tOutput {} has shape {:?}",
120            output_name,
121            res.shape()
122        );
123        if VERBOSE.load(Ordering::Relaxed) >= VerbosityLevel::Results
124        {
125            res.to_file(outputs_dir, output_name)?;
126        }
127        node_inputs.insert(output_name.to_string(), Arc::new(res));
128    }
129    for output_name in outputs.iter() {
130        if let Some(gout) = graph_outputs.get_mut(output_name) {
131            if let Some(produced) = node_inputs.get(output_name) {
132                gout.data = Some(produced.as_ref().clone());
133            }
134        }
135    }
136    Ok(outputs)
137}
138
/// A dependency graph of ONNX nodes
struct DependencyGraph {
    /// A map of ONNX nodes to their input requirements (the input names that
    /// are not yet available; a node is ready to run when its list is empty)
    pub node_input_requirements: HashMap<ONNXNode, Vec<String>>,
    /// A map of input names to the nodes that require them (used to clear
    /// requirements when an output with that name is produced)
    pub input_link_map: HashMap<String, Vec<ONNXNode>>,
    /// A set of not implemented operators encountered while building the graph
    pub not_implemented: HashSet<String>,
}
148
149/// Create a dependency graph from an ONNX graph and a map of input names to their corresponding tensors
150fn create_links_and_requirements(
151    graph: &onnx::GraphProto,
152    node_inputs: &HashMap<String, Arc<TensorType>>,
153) -> BoxResult<DependencyGraph> {
154    let mut node_input_requirements: HashMap<ONNXNode, Vec<String>> = HashMap::new();
155    let mut input_link_map: HashMap<String, Vec<ONNXNode>> = HashMap::new();
156    let mut not_implemented = HashSet::new();
157    for (counter, node) in graph.node.iter().enumerate() {
158        let input_names = node
159            .input
160            .iter()
161            .filter_map(|s| {
162                if node_inputs.contains_key(s.as_str()) {
163                    None
164                } else {
165                    Some(s.clone())
166                }
167            })
168            .collect::<Vec<String>>();
169        if let Some(name) = node.op_type.as_deref() {
170            for input_name in input_names.iter() {
171                input_link_map
172                    .entry(input_name.to_string())
173                    .or_default()
174                    .push(ONNXNode::new(
175                        counter,
176                        *OPERATION_MAP
177                            .get(name)
178                            .unwrap_or(&(operator_not_implemented as OperationFn)),
179                        node.clone(),
180                    ));
181            }
182            if let Some(op_func) = OPERATION_MAP.get(name) {
183                node_input_requirements
184                    .insert(ONNXNode::new(counter, *op_func, node.clone()), input_names);
185            } else {
186                node_input_requirements.insert(
187                    ONNXNode::new(
188                        counter,
189                        operator_not_implemented as OperationFn,
190                        node.clone(),
191                    ),
192                    input_names,
193                );
194                not_implemented.insert(name.to_string());
195            }
196        }
197    }
198    Ok(DependencyGraph {
199        node_input_requirements,
200        input_link_map,
201        not_implemented,
202    })
203}
204
205/// Compare the expected outputs to the actual outputs
206pub fn compare_outputs(
207    expected_outputs: HashMap<String, TensorType>,
208    mut graph_outputs: HashMap<String, OutputInfo>,
209) -> BoxResult<HashMap<String, OutputInfo>> {
210    let mut results = HashMap::new();
211    for (name, value) in expected_outputs.iter() {
212        if let Some((namestring, gout)) = graph_outputs.remove_entry(name) {
213            if let Some(data) = &gout.data {
214                if value.shape() != data.shape() {
215                    return Err(anyhow!(
216                        "Expected output {} to have shape {:?} but got {:?}",
217                        name,
218                        value.shape(),
219                        data.shape()
220                    ));
221                } else {
222                    print_at_level!(
223                        VerbosityLevel::Minimal,
224                        "Output {} has shape {:?} as expected",
225                        name,
226                        value.shape()
227                    );
228                }
229                if value.value_type() != data.value_type() {
230                    return Err(anyhow!(
231                        "Expected output {} to have type {:?} but got {:?}",
232                        name,
233                        value.value_type(),
234                        data.value_type()
235                    ));
236                } else {
237                    print_at_level!(
238                        VerbosityLevel::Minimal,
239                        "Output {} has type {:?} as expected",
240                        name,
241                        value.value_type()
242                    );
243                }
244                match (value, data) {
245                    (TensorType::F32(v), TensorType::F32(d)) => {
246                        let mut count = 0;
247                        let mut diff = vec![];
248                        for (i, (v, d)) in v.iter().zip(d.iter()).enumerate() {
249                            if (v - d).abs() > 0.0001 {
250                                count += 1;
251                            }
252                            diff.push((i, v, d, (v - d).abs()));
253                        }
254                        let max = diff
255                            .iter()
256                            .max_by(|(_, _, _, d1), (_, _, _, d2)| {
257                                d1.partial_cmp(d2).unwrap_or(std::cmp::Ordering::Less)
258                            })
259                            .expect("Failed to get max difference");
260                        print_at_level!(
261                            VerbosityLevel::Minimal,
262                            "Output {} has {} values with absolute difference of more than .0001\n\tMax difference: {:?}",
263                            name,
264                            count,
265                            max
266                        );
267                    }
268                    _ => todo!(
269                        "Compare output {:?} with {:?}",
270                        value.value_type(),
271                        data.value_type()
272                    ),
273                }
274            }
275            results.insert(namestring, gout);
276        }
277    }
278    Ok(results)
279}
280
#[cfg(feature = "custom-threadpool")]
/// Create the worker pool used to run operators, sized to 2/3rds of the
/// available threads/cores.
fn create_pool(parallelism: usize) -> BoxResult<crate::parallel::ThreadPool> {
    // `parallelism / 3 * 2` is 0 when fewer than 3 cores are available, which
    // would create a pool with no workers; clamp to at least one thread.
    Ok(crate::parallel::ThreadPool::new((parallelism / 3 * 2).max(1)))
}
285
#[cfg(feature = "custom-threadpool")]
/// Block until all tasks queued on the custom pool have finished.
fn wait_pool(pool: &crate::parallel::ThreadPool) {
    pool.wait()
}
290
291#[cfg(not(feature = "custom-threadpool"))]
292fn create_pool(parallelism: usize) -> BoxResult<rayon::ThreadPool> {
293    Ok(rayon::ThreadPoolBuilder::new()
294        .num_threads(parallelism / 3 * 2)
295        .build()?)
296}
297
#[cfg(not(feature = "custom-threadpool"))]
/// No-op for rayon: results are collected via the mpsc channel, and the loop
/// in `execute_model` only exits once every node has reported back.
fn wait_pool(_pool: &rayon::ThreadPool) {
    // do nothing
}
302
303pub fn execute_model(args: &Args) -> BoxResult<HashMap<String, OutputInfo>> {
304    VERBOSE
305        .store(VerbosityLevel::new(args.verbose) as usize, Ordering::Relaxed);
306    print_at_level!(
307        VerbosityLevel::Minimal,
308        "Running model: {}",
309        args.model.display()
310    );
311    let inputspath = if args.model.is_relative() {
312        Path::new("models").join(&args.model).join("inputs.json")
313    } else {
314        args.model.join("inputs.json")
315    };
316    let inputs_file = std::fs::File::open(inputspath)?;
317    let mut fileinputs: FileInputs = serde_json::from_reader(inputs_file)?;
318    fileinputs.extend_paths(&args.model);
319    let model = read_model(Path::new(&fileinputs.modelpath))?;
320    let outputs_dir = Path::new("outputs").join(&args.model);
321    let parallelism: usize = std::thread::available_parallelism()?.into();
322    let pool = create_pool(parallelism)?;
323    if VERBOSE
324        .load(Ordering::Relaxed) >= VerbosityLevel::Results
325    {
326        std::fs::create_dir_all(&outputs_dir)?;
327    }
328    let opset_version = if let Some(v) = model.opset_import.first() {
329        if let Some(v) = v.version {
330            v
331        } else {
332            MAX_OPSET_VERSION
333        }
334    } else {
335        MAX_OPSET_VERSION
336    };
337    if opset_version > MAX_OPSET_VERSION {
338        return Err(anyhow!(
339            "Opset version {} is not supported, max supported version is {}",
340            opset_version,
341            MAX_OPSET_VERSION
342        ));
343    }
344    let graph = model.graph.get_or_default();
345    if args.gengraph {
346        build_graph_from_proto(
347            graph,
348            &fileinputs.modelpath,
349            match args.graphtype.as_str() {
350                "json" => GraphOutputType::Json,
351                "dot" => GraphOutputType::Dot,
352                _ => return Err(anyhow!("Invalid graph type")),
353            },
354        )?;
355    }
356    let initializers = make_initializers(graph)?;
357    let node_inputs = initialize_nodes(graph, &fileinputs, initializers)?;
358    let expected_outputs = make_external_outputs(graph, &fileinputs)?;
359    let mut graph_outputs = make_graph_outputs(graph)?;
360    let mut dependency_graph = create_links_and_requirements(graph, &node_inputs)?;
361    let node_inputs = Arc::new(RwLock::new(node_inputs));
362    let (tx, rx) = std::sync::mpsc::channel();
363    for vi in graph.value_info.iter() {
364        if let Some(onnx::type_proto::Value::TensorType(_)) = vi.type_.value {
365            // If the type is Tensor, then we are fine because that's implemented
366        } else {
367            unimplemented!("ValueInfoProto type {:?}", vi.type_)
368        }
369    }
370
371    print_at_level!(
372        VerbosityLevel::Informational,
373        "Number of not implemented operators: {}",
374        dependency_graph.not_implemented.len()
375    );
376    for name in dependency_graph.not_implemented.iter() {
377        eprintln!("Model uses operator {} which is not implemented yet", name);
378    }
379    if !dependency_graph.not_implemented.is_empty() && args.failfast {
380        return Err(anyhow!("Not implemented operators found"));
381    }
382    loop {
383        let mut nodes_ready = vec![];
384        for (node, inputs) in dependency_graph.node_input_requirements.iter() {
385            if inputs.is_empty() {
386                nodes_ready.push(node.clone());
387            }
388        }
389        dependency_graph
390            .node_input_requirements
391            .retain(|_, v| !v.is_empty());
392        for node in nodes_ready {
393            let node_inputs_ref = node_inputs.clone();
394            let tx = tx.clone();
395            pool.spawn(move || {
396                let r = {
397                    let node_inputs_lock =
398                        node_inputs_ref.read().expect("Failed to lock node inputs");
399                    node.execute(node_inputs_lock, opset_version)
400                };
401                tx.send((r, node)).expect("Failed to send result");
402            });
403        }
404        // first, block until we get a result
405        match rx.recv() {
406            Ok((r, node)) => {
407                let outputs = {
408                    let mut node_inputs_lock =
409                        node_inputs.write().expect("Failed to lock node inputs");
410                    handle_output(
411                        r?,
412                        &node,
413                        &outputs_dir,
414                        &mut node_inputs_lock,
415                        &mut graph_outputs,
416                    )?
417                };
418                for output in outputs {
419                    if let Some(n) = dependency_graph.input_link_map.remove(&output) {
420                        for node in n {
421                            dependency_graph
422                                .node_input_requirements
423                                .entry(node)
424                                .and_modify(|v| {
425                                    v.retain(|x| *x != output);
426                                });
427                        }
428                    }
429                }
430            }
431            Err(e) => {
432                return Err(anyhow!("Failed to receive result: {:?}", e));
433            }
434        }
435        // then, check if we have more results
436        loop {
437            match rx.try_recv() {
438                Ok((r, node)) => {
439                    let outputs = {
440                        let mut node_inputs_lock =
441                            node_inputs.write().expect("Failed to lock node inputs");
442                        handle_output(
443                            r?,
444                            &node,
445                            &outputs_dir,
446                            &mut node_inputs_lock,
447                            &mut graph_outputs,
448                        )?
449                    };
450                    for output in outputs {
451                        if let Some(n) = dependency_graph.input_link_map.remove(&output) {
452                            for node in n {
453                                dependency_graph
454                                    .node_input_requirements
455                                    .entry(node)
456                                    .and_modify(|v| {
457                                        v.retain(|x| *x != output);
458                                    });
459                            }
460                        }
461                    }
462                }
463                Err(std::sync::mpsc::TryRecvError::Empty) => break,
464                Err(e) => {
465                    return Err(anyhow!("Failed to receive result: {:?}", e));
466                }
467            }
468        }
469        if dependency_graph.node_input_requirements.is_empty() {
470            break;
471        }
472    }
473    wait_pool(&pool);
474    compare_outputs(expected_outputs, graph_outputs)
475}