1use std::{
2 collections::{HashMap, HashSet},
3 path::Path,
4 sync::{Arc, RwLock, RwLockReadGuard, atomic::Ordering},
5};
6
7use crate::{
8 common::{
9 Args, BoxResult, FileInputs, OperationFn, OperatorResult, TensorType, VerbosityLevel,
10 MAX_OPSET_VERSION, VERBOSE,
11 },
12 onnx,
13 operators::OPERATION_MAP,
14 print_at_level,
15 protograph::{build_graph_from_proto, GraphOutputType},
16 read_model,
17 utils::{initialize_nodes, make_initializers},
18 utils::{make_external_outputs, make_graph_outputs, operator_not_implemented, OutputInfo},
19};
20
21use anyhow::anyhow;
22use smallvec::SmallVec;
/// A single operator invocation in the ONNX graph, paired with the function
/// that implements it.
#[derive(Debug, Clone, PartialEq)]
pub struct ONNXNode {
    // Index of the node in the graph's node list. This is the sole key used
    // by the manual `Hash` impl (and the basis of `Eq`), so ids must be
    // unique within a graph.
    id: usize,
    // Implementation of this node's operator (or a "not implemented" stub).
    op_func: OperationFn,
    // Shared, immutable reference to the underlying protobuf node definition.
    node_ref: Arc<onnx::NodeProto>,
}
30
31impl ONNXNode {
32 fn new(id: usize, op_func: OperationFn, node_ref: onnx::NodeProto) -> Self {
34 Self {
35 id,
36 op_func,
37 node_ref: Arc::new(node_ref),
38 }
39 }
40
41 fn execute(
43 &self,
44 node_inputs: RwLockReadGuard<HashMap<String, Arc<TensorType>>>,
45 opset_version: i64,
46 ) -> BoxResult<OperatorResult> {
47 let mut inputs = vec![];
48 let mut outputs = vec![];
49 let mut all_nodes_have_init = true;
50 for input in self.node_ref.input.iter() {
51 if let Some(k) = node_inputs.get(input) {
52 inputs.push(k.clone());
53 } else {
54 all_nodes_have_init = false;
55 }
56 }
57 drop(node_inputs); for output in self.node_ref.output.iter() {
59 outputs.push(output);
60 }
61 if !all_nodes_have_init {
62 return Err(anyhow!("Some nodes in this operation have not been initialized yet, this means the operations aren't in order, fix the code to account for this"));
63 }
64 let input_names = self
65 .node_ref
66 .input
67 .iter()
68 .map(|s| s.as_str())
69 .collect::<Vec<&str>>();
70 let output_names = self
71 .node_ref
72 .output
73 .iter()
74 .map(|s| s.as_str())
75 .collect::<Vec<&str>>();
76 print_at_level!(
77 VerbosityLevel::Informational,
78 "Running {} operator (id: {}, thread: {:?}) between {:?} to get {:?}",
79 self.node_ref.op_type(),
80 self.id,
81 std::thread::current().id(),
82 input_names,
83 output_names
84 );
85 let inputs: SmallVec<[&TensorType; 4]> = inputs.iter().map(|x| x.as_ref()).collect();
87 (self.op_func)(
88 &inputs,
89 self.node_ref.as_ref(),
90 opset_version,
91 output_names.len(),
92 )
93 }
94}
95
96impl std::hash::Hash for ONNXNode {
97 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
98 self.id.hash(state);
99 }
100}
101
// Marker impl: the derived `PartialEq` is a total equivalence here, and node
// identity is keyed on the unique `id` (see the `Hash` impl above).
impl std::cmp::Eq for ONNXNode {}
103
104fn handle_output(
106 result: OperatorResult,
107 node: &ONNXNode,
108 outputs_dir: &Path,
109 node_inputs: &mut HashMap<String, Arc<TensorType>>,
110 graph_outputs: &mut HashMap<String, OutputInfo>,
111) -> BoxResult<Vec<String>> {
112 let node = node.node_ref.clone();
113 let outputs = node.output.to_vec();
114 let result = result.result; assert_eq!(outputs.len(), result.len());
116 for (output_name, res) in outputs.iter().zip(result.into_iter()) {
117 print_at_level!(
118 VerbosityLevel::Informational,
119 "\tOutput {} has shape {:?}",
120 output_name,
121 res.shape()
122 );
123 if VERBOSE.load(Ordering::Relaxed) >= VerbosityLevel::Results
124 {
125 res.to_file(outputs_dir, output_name)?;
126 }
127 node_inputs.insert(output_name.to_string(), Arc::new(res));
128 }
129 for output_name in outputs.iter() {
130 if let Some(gout) = graph_outputs.get_mut(output_name) {
131 if let Some(produced) = node_inputs.get(output_name) {
132 gout.data = Some(produced.as_ref().clone());
133 }
134 }
135 }
136 Ok(outputs)
137}
138
/// Scheduling bookkeeping derived from the ONNX graph: which nodes still wait
/// on which tensors, and which nodes consume each tensor.
struct DependencyGraph {
    // For every node, the names of inputs that have not been produced yet.
    // Nodes with an empty list are ready to run.
    pub node_input_requirements: HashMap<ONNXNode, Vec<String>>,
    // Maps a tensor name to the nodes that consume it, so that producing a
    // tensor can clear it from its dependents' requirement lists.
    pub input_link_map: HashMap<String, Vec<ONNXNode>>,
    // Operator names used by the model that have no implementation.
    pub not_implemented: HashSet<String>,
}
148
149fn create_links_and_requirements(
151 graph: &onnx::GraphProto,
152 node_inputs: &HashMap<String, Arc<TensorType>>,
153) -> BoxResult<DependencyGraph> {
154 let mut node_input_requirements: HashMap<ONNXNode, Vec<String>> = HashMap::new();
155 let mut input_link_map: HashMap<String, Vec<ONNXNode>> = HashMap::new();
156 let mut not_implemented = HashSet::new();
157 for (counter, node) in graph.node.iter().enumerate() {
158 let input_names = node
159 .input
160 .iter()
161 .filter_map(|s| {
162 if node_inputs.contains_key(s.as_str()) {
163 None
164 } else {
165 Some(s.clone())
166 }
167 })
168 .collect::<Vec<String>>();
169 if let Some(name) = node.op_type.as_deref() {
170 for input_name in input_names.iter() {
171 input_link_map
172 .entry(input_name.to_string())
173 .or_default()
174 .push(ONNXNode::new(
175 counter,
176 *OPERATION_MAP
177 .get(name)
178 .unwrap_or(&(operator_not_implemented as OperationFn)),
179 node.clone(),
180 ));
181 }
182 if let Some(op_func) = OPERATION_MAP.get(name) {
183 node_input_requirements
184 .insert(ONNXNode::new(counter, *op_func, node.clone()), input_names);
185 } else {
186 node_input_requirements.insert(
187 ONNXNode::new(
188 counter,
189 operator_not_implemented as OperationFn,
190 node.clone(),
191 ),
192 input_names,
193 );
194 not_implemented.insert(name.to_string());
195 }
196 }
197 }
198 Ok(DependencyGraph {
199 node_input_requirements,
200 input_link_map,
201 not_implemented,
202 })
203}
204
205pub fn compare_outputs(
207 expected_outputs: HashMap<String, TensorType>,
208 mut graph_outputs: HashMap<String, OutputInfo>,
209) -> BoxResult<HashMap<String, OutputInfo>> {
210 let mut results = HashMap::new();
211 for (name, value) in expected_outputs.iter() {
212 if let Some((namestring, gout)) = graph_outputs.remove_entry(name) {
213 if let Some(data) = &gout.data {
214 if value.shape() != data.shape() {
215 return Err(anyhow!(
216 "Expected output {} to have shape {:?} but got {:?}",
217 name,
218 value.shape(),
219 data.shape()
220 ));
221 } else {
222 print_at_level!(
223 VerbosityLevel::Minimal,
224 "Output {} has shape {:?} as expected",
225 name,
226 value.shape()
227 );
228 }
229 if value.value_type() != data.value_type() {
230 return Err(anyhow!(
231 "Expected output {} to have type {:?} but got {:?}",
232 name,
233 value.value_type(),
234 data.value_type()
235 ));
236 } else {
237 print_at_level!(
238 VerbosityLevel::Minimal,
239 "Output {} has type {:?} as expected",
240 name,
241 value.value_type()
242 );
243 }
244 match (value, data) {
245 (TensorType::F32(v), TensorType::F32(d)) => {
246 let mut count = 0;
247 let mut diff = vec![];
248 for (i, (v, d)) in v.iter().zip(d.iter()).enumerate() {
249 if (v - d).abs() > 0.0001 {
250 count += 1;
251 }
252 diff.push((i, v, d, (v - d).abs()));
253 }
254 let max = diff
255 .iter()
256 .max_by(|(_, _, _, d1), (_, _, _, d2)| {
257 d1.partial_cmp(d2).unwrap_or(std::cmp::Ordering::Less)
258 })
259 .expect("Failed to get max difference");
260 print_at_level!(
261 VerbosityLevel::Minimal,
262 "Output {} has {} values with absolute difference of more than .0001\n\tMax difference: {:?}",
263 name,
264 count,
265 max
266 );
267 }
268 _ => todo!(
269 "Compare output {:?} with {:?}",
270 value.value_type(),
271 data.value_type()
272 ),
273 }
274 }
275 results.insert(namestring, gout);
276 }
277 }
278 Ok(results)
279}
280
#[cfg(feature = "custom-threadpool")]
/// Creates the custom thread pool sized to roughly two thirds of the
/// available parallelism.
///
/// Clamped to at least one worker: the bare `parallelism / 3 * 2` integer
/// division yields 0 for `parallelism <= 2`, which would leave the pool with
/// no threads to run any task.
fn create_pool(parallelism: usize) -> BoxResult<crate::parallel::ThreadPool> {
    Ok(crate::parallel::ThreadPool::new((parallelism / 3 * 2).max(1)))
}
285
#[cfg(feature = "custom-threadpool")]
/// Blocks until every task queued on the custom thread pool has finished.
fn wait_pool(pool: &crate::parallel::ThreadPool) {
    pool.wait()
}
290
291#[cfg(not(feature = "custom-threadpool"))]
292fn create_pool(parallelism: usize) -> BoxResult<rayon::ThreadPool> {
293 Ok(rayon::ThreadPoolBuilder::new()
294 .num_threads(parallelism / 3 * 2)
295 .build()?)
296}
297
#[cfg(not(feature = "custom-threadpool"))]
/// No-op for the rayon backend; the caller tracks completion through the
/// result channel rather than by draining the pool.
fn wait_pool(_pool: &rayon::ThreadPool) {
}
302
/// Loads and executes an ONNX model end to end, then compares the produced
/// graph outputs against the expected outputs and returns them.
///
/// High-level flow:
/// 1. Resolve paths, read `inputs.json` and the model protobuf.
/// 2. Validate the opset, optionally emit a graph visualization.
/// 3. Build initializers, the dependency graph, and a thread pool.
/// 4. Scheduling loop: spawn every node whose inputs are all available,
///    collect finished results over an mpsc channel, and unblock dependents,
///    until no node has unmet requirements.
///
/// # Errors
/// Fails on I/O or parse errors, unsupported opset versions, unimplemented
/// operators when `--failfast` is set, or any operator execution failure.
pub fn execute_model(args: &Args) -> BoxResult<HashMap<String, OutputInfo>> {
    // Publish the requested verbosity so every worker thread sees it.
    VERBOSE
        .store(VerbosityLevel::new(args.verbose) as usize, Ordering::Relaxed);
    print_at_level!(
        VerbosityLevel::Minimal,
        "Running model: {}",
        args.model.display()
    );
    // Relative model paths are resolved under the local "models" directory.
    let inputspath = if args.model.is_relative() {
        Path::new("models").join(&args.model).join("inputs.json")
    } else {
        args.model.join("inputs.json")
    };
    let inputs_file = std::fs::File::open(inputspath)?;
    let mut fileinputs: FileInputs = serde_json::from_reader(inputs_file)?;
    fileinputs.extend_paths(&args.model);
    let model = read_model(Path::new(&fileinputs.modelpath))?;
    let outputs_dir = Path::new("outputs").join(&args.model);
    let parallelism: usize = std::thread::available_parallelism()?.into();
    let pool = create_pool(parallelism)?;
    // Tensor dumps are only written at `Results` verbosity or above, so only
    // create the output directory in that case.
    if VERBOSE
        .load(Ordering::Relaxed) >= VerbosityLevel::Results
    {
        std::fs::create_dir_all(&outputs_dir)?;
    }
    // Fall back to the maximum supported opset when the model doesn't state
    // one; reject models requiring a newer opset than we support.
    let opset_version = if let Some(v) = model.opset_import.first() {
        if let Some(v) = v.version {
            v
        } else {
            MAX_OPSET_VERSION
        }
    } else {
        MAX_OPSET_VERSION
    };
    if opset_version > MAX_OPSET_VERSION {
        return Err(anyhow!(
            "Opset version {} is not supported, max supported version is {}",
            opset_version,
            MAX_OPSET_VERSION
        ));
    }
    let graph = model.graph.get_or_default();
    // Optionally emit a graph visualization (JSON or DOT) before running.
    if args.gengraph {
        build_graph_from_proto(
            graph,
            &fileinputs.modelpath,
            match args.graphtype.as_str() {
                "json" => GraphOutputType::Json,
                "dot" => GraphOutputType::Dot,
                _ => return Err(anyhow!("Invalid graph type")),
            },
        )?;
    }
    // Seed the tensor table with initializers and external inputs, and build
    // the dependency bookkeeping used by the scheduling loop below.
    let initializers = make_initializers(graph)?;
    let node_inputs = initialize_nodes(graph, &fileinputs, initializers)?;
    let expected_outputs = make_external_outputs(graph, &fileinputs)?;
    let mut graph_outputs = make_graph_outputs(graph)?;
    let mut dependency_graph = create_links_and_requirements(graph, &node_inputs)?;
    // Shared tensor table: many concurrent readers (executing nodes), one
    // writer at a time (result handling on this thread).
    let node_inputs = Arc::new(RwLock::new(node_inputs));
    let (tx, rx) = std::sync::mpsc::channel();
    // Only plain tensor value_info entries are supported.
    for vi in graph.value_info.iter() {
        if let Some(onnx::type_proto::Value::TensorType(_)) = vi.type_.value {
        } else {
            unimplemented!("ValueInfoProto type {:?}", vi.type_)
        }
    }

    print_at_level!(
        VerbosityLevel::Informational,
        "Number of not implemented operators: {}",
        dependency_graph.not_implemented.len()
    );
    for name in dependency_graph.not_implemented.iter() {
        eprintln!("Model uses operator {} which is not implemented yet", name);
    }
    if !dependency_graph.not_implemented.is_empty() && args.failfast {
        return Err(anyhow!("Not implemented operators found"));
    }
    // Scheduling loop: each iteration spawns all currently-runnable nodes,
    // blocks for at least one result, drains any further finished results,
    // and repeats until the requirement table is empty.
    loop {
        // Nodes with no outstanding inputs are ready to run.
        let mut nodes_ready = vec![];
        for (node, inputs) in dependency_graph.node_input_requirements.iter() {
            if inputs.is_empty() {
                nodes_ready.push(node.clone());
            }
        }
        // Drop the ready nodes from the table so they are spawned only once.
        dependency_graph
            .node_input_requirements
            .retain(|_, v| !v.is_empty());
        for node in nodes_ready {
            let node_inputs_ref = node_inputs.clone();
            let tx = tx.clone();
            pool.spawn(move || {
                // Read-lock only for input resolution; `execute` drops the
                // guard before running the operator.
                let r = {
                    let node_inputs_lock =
                        node_inputs_ref.read().expect("Failed to lock node inputs");
                    node.execute(node_inputs_lock, opset_version)
                };
                tx.send((r, node)).expect("Failed to send result");
            });
        }
        // Block for one finished node so the loop always makes progress.
        match rx.recv() {
            Ok((r, node)) => {
                let outputs = {
                    let mut node_inputs_lock =
                        node_inputs.write().expect("Failed to lock node inputs");
                    handle_output(
                        r?,
                        &node,
                        &outputs_dir,
                        &mut node_inputs_lock,
                        &mut graph_outputs,
                    )?
                };
                // Each produced tensor clears itself from the requirement
                // lists of every consumer node.
                for output in outputs {
                    if let Some(n) = dependency_graph.input_link_map.remove(&output) {
                        for node in n {
                            dependency_graph
                                .node_input_requirements
                                .entry(node)
                                .and_modify(|v| {
                                    v.retain(|x| *x != output);
                                });
                        }
                    }
                }
            }
            Err(e) => {
                return Err(anyhow!("Failed to receive result: {:?}", e));
            }
        }
        // Opportunistically drain any other results that finished meanwhile
        // (same handling as the blocking receive above).
        loop {
            match rx.try_recv() {
                Ok((r, node)) => {
                    let outputs = {
                        let mut node_inputs_lock =
                            node_inputs.write().expect("Failed to lock node inputs");
                        handle_output(
                            r?,
                            &node,
                            &outputs_dir,
                            &mut node_inputs_lock,
                            &mut graph_outputs,
                        )?
                    };
                    for output in outputs {
                        if let Some(n) = dependency_graph.input_link_map.remove(&output) {
                            for node in n {
                                dependency_graph
                                    .node_input_requirements
                                    .entry(node)
                                    .and_modify(|v| {
                                        v.retain(|x| *x != output);
                                    });
                            }
                        }
                    }
                }
                Err(std::sync::mpsc::TryRecvError::Empty) => break,
                Err(e) => {
                    return Err(anyhow!("Failed to receive result: {:?}", e));
                }
            }
        }
        // Done once every node has been scheduled and accounted for.
        if dependency_graph.node_input_requirements.is_empty() {
            break;
        }
    }
    wait_pool(&pool);
    compare_outputs(expected_outputs, graph_outputs)
}