1use anyhow::anyhow;
2use num::traits::AsPrimitive;
3use std::io::Read;
4use std::os::raw::c_uchar;
5use std::path::PathBuf;
6use std::sync::Arc;
7use std::{collections::HashMap, io, path::Path};
8
9use ndarray::{ArrayD, IxDyn};
10use protobuf::Enum;
11
12use crate::common::FileInputs;
13use crate::onnx::tensor_proto::DataType;
14use crate::onnx::{NodeProto, TensorProto, ValueInfoProto};
15use crate::onnxparser::onnx;
16use crate::{common::*, print_at_level};
17use half::{bf16, f16};
18
19pub fn shape_safe_product<
21 'a,
22 B: 'a + std::iter::Product<&'a B> + std::default::Default + Copy + 'static,
23 A: IntoIterator<Item = &'a B>,
24>(
25 shape: A,
26) -> B
27where
28 usize: AsPrimitive<B>,
29{
30 let mut piter = shape.into_iter().peekable();
31 if piter.peek().is_none() {
32 1_usize.as_()
33 } else {
34 piter.product()
35 }
36}
37
38pub fn log_array_to_file<A: ndarray_npy::WritableElement, D: ndarray::Dimension>(
40 operation: &str,
41 name: &str,
42 a: &ndarray::ArrayBase<ndarray::ViewRepr<&A>, D>,
43) -> BoxResult<()> {
44 let verbose_flag = VERBOSE.load(std::sync::atomic::Ordering::Relaxed);
45 if verbose_flag == VerbosityLevel::Intermediate as usize {
46 static mut COUNTER: usize = 0;
47 unsafe {
48 ndarray_npy::write_npy(
49 format!(
50 "{}_intermediate_outputs/{}_{}.npy",
51 operation, COUNTER, name
52 ),
53 a,
54 )?;
55 COUNTER += 1;
56 }
57 }
58 Ok(())
59}
60
/// Dump an ndarray to an `.npy` file via `utils::log_array_to_file`, naming
/// the file after the operation and either the variable's identifier or an
/// explicit name expression.
#[macro_export]
macro_rules! named_array_to_file {
    // File name taken from the variable's own identifier.
    ($op:ident, $name:ident) => {{
        let $name = $name.view();
        $crate::utils::log_array_to_file(stringify!($op), stringify!($name), &$name).unwrap();
    }};
    // File name supplied as a separate expression.
    ($op:ident, $var:ident, $name:expr) => {{
        let $var = $var.view();
        $crate::utils::log_array_to_file(stringify!($op), &$name, &$var).unwrap();
    }};
}
73
/// Create the `<name>_intermediate_outputs/` directory when the verbosity
/// level is `Intermediate`; an already-existing directory is not an error.
///
/// Must be expanded in a function returning a `Result` (it `return`s on
/// failure) with `VERBOSE` and `anyhow!` in scope.
#[macro_export]
macro_rules! create_intermediate_output_dir_for {
    ($name:ident) => {{
        use $crate::common::VerbosityLevel;
        let verbose_flag = VERBOSE.load(std::sync::atomic::Ordering::Relaxed);
        // VERBOSE stores the level as a plain integer, so compare `as usize`
        // (matches the check in `utils::log_array_to_file`); comparing the
        // raw usize against the enum value did not type-check.
        if verbose_flag == VerbosityLevel::Intermediate as usize {
            match std::fs::create_dir(concat!(stringify!($name), "_intermediate_outputs")) {
                Ok(_) => {}
                Err(e) => {
                    if e.kind() != std::io::ErrorKind::AlreadyExists {
                        // Name the directory that actually failed (the old
                        // message hard-coded "rust_conv_outputs").
                        return Err(anyhow!(
                            "Error creating {} directory: {}",
                            concat!(stringify!($name), "_intermediate_outputs"),
                            e
                        ));
                    }
                }
            }
        }
    }};
}
92
/// Metadata for a graph value (input/output) extracted from a
/// `ValueInfoProto`.
#[derive(Debug, Clone)]
pub struct ValueInfo {
    /// Name from the proto, or the `UNKNOWN` placeholder when unnamed.
    pub name: String,
    /// Element type paired with the declared per-axis `dim_value`s.
    pub type_: (ValueType, Vec<i64>),
    /// Free-form documentation string from the proto (or `UNKNOWN`).
    pub doc_string: String,
}
100
/// A graph output: its declared metadata plus the tensor actually produced
/// for it (`None` until one is recorded).
#[derive(Debug, Clone)]
pub struct OutputInfo {
    /// Declared name/type/shape of the output.
    pub valueinfo: ValueInfo,
    /// Produced tensor, filled in after execution.
    pub data: Option<TensorType>,
}
107
108impl OutputInfo {
109 fn new(valueinfo: ValueInfo) -> Self {
110 Self {
111 valueinfo,
112 data: None,
113 }
114 }
115}
116
117impl ValueInfo {
118 fn from_proto(proto: &ValueInfoProto) -> BoxResult<Self> {
120 if let Some(onnx::type_proto::Value::TensorType(tensor)) = &proto.type_.value {
121 let dt = onnx::tensor_proto::DataType::from_i32(tensor.elem_type.unwrap_or_default())
122 .unwrap_or_default();
123 Ok(Self {
124 name: proto
125 .name
126 .as_ref()
127 .map_or_else(|| UNKNOWN.to_owned(), |v| v.clone()),
128 type_: (
129 ValueType::new(dt)?,
130 tensor.shape.dim.iter().map(|v| v.dim_value()).collect(),
131 ),
132 doc_string: proto
133 .doc_string
134 .as_ref()
135 .map_or_else(|| UNKNOWN.to_owned(), |v| v.clone()),
136 })
137 } else {
138 todo!("ValueInfoProto type not supported: {:?}", proto.type_)
139 }
140 }
141}
142
143pub fn make_tensor_from_proto(proto: &TensorProto) -> BoxResult<TensorType> {
146 let shape = &proto.dims;
147 if proto.data_location() != onnx::tensor_proto::DataLocation::DEFAULT {
148 return Err(anyhow!("External data location not supported"));
149 }
150 make_tensor(shape, proto, proto.data_type())
151}
152
153fn get_raw_data(proto: &TensorProto) -> BoxResult<(&[u8], usize)> {
155 if let Some(ref raw_data) = proto.raw_data {
156 Ok((raw_data.as_slice(), 1))
157 } else if !proto.int32_data.is_empty() {
158 Ok((
159 bytemuck::try_cast_slice(proto.int32_data.as_slice()).map_err(|e| anyhow!(e))?,
160 4,
161 ))
162 } else if !proto.int64_data.is_empty() {
163 Ok((
164 bytemuck::try_cast_slice(proto.int64_data.as_slice()).map_err(|e| anyhow!(e))?,
165 8,
166 ))
167 } else if !proto.float_data.is_empty() {
168 Ok((
169 bytemuck::try_cast_slice(proto.float_data.as_slice()).map_err(|e| anyhow!(e))?,
170 4,
171 ))
172 } else if !proto.double_data.is_empty() {
173 Ok((
174 bytemuck::try_cast_slice(proto.double_data.as_slice()).map_err(|e| anyhow!(e))?,
175 8,
176 ))
177 } else if !proto.uint64_data.is_empty() {
178 Ok((
179 bytemuck::try_cast_slice(proto.uint64_data.as_slice()).map_err(|e| anyhow!(e))?,
180 8,
181 ))
182 } else {
183 Ok((&[], 0))
184 }
185}
186
/// Convert a `TensorProto` payload into a [`TensorType`] array whose element
/// kind is selected by `data_type` (an ONNX `DataType` discriminant).
///
/// The payload comes from [`get_raw_data`], which returns the backing bytes
/// plus `origin_elem_size`, the element size of the proto field the bytes
/// came from (1 when they came from `raw_data` directly, 0 when no payload
/// field was set).
pub fn make_tensor(shape: &[i64], proto: &TensorProto, data_type: i32) -> BoxResult<TensorType> {
    // Unknown discriminants collapse to the enum default (UNDEFINED) and are
    // rejected in the match below.
    let enum_dt = DataType::from_i32(data_type).unwrap_or_default();
    let shape = shape.iter().map(|v| *v as usize).collect::<Vec<usize>>();
    let (bytedata, origin_elem_size) = get_raw_data(proto)?;
    match enum_dt {
        DataType::UNDEFINED => Err(anyhow!("Undefined data type")),
        DataType::INT8 => match bytemuck::try_cast_slice::<u8, i8>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len() / origin_elem_size, shape_safe_product(&shape));
                let a = if origin_elem_size == std::mem::size_of::<i8>() {
                    ArrayD::<i8>::from_shape_vec(IxDyn(&shape), data.to_vec())?
                } else {
                    // Payload came from a wider typed field (e.g. int32_data):
                    // keep one byte per source element — presumably the
                    // little-endian low byte; TODO confirm.
                    ArrayD::<i8>::from_shape_vec(
                        IxDyn(&shape),
                        data.iter().step_by(origin_elem_size).copied().collect(),
                    )?
                };
                Ok(TensorType::I8(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        DataType::INT16 => match bytemuck::try_cast_slice::<u8, i16>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len() / origin_elem_size, shape_safe_product(&shape));
                let a = ArrayD::<i16>::from_shape_vec(IxDyn(&shape), data.to_vec())?;
                Ok(TensorType::I16(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        // The INT32/INT64/UINT64/DOUBLE/FLOAT branches bypass get_raw_data's
        // result and read the proto fields directly, so an absent payload can
        // be materialized as a zero-filled array.
        DataType::INT32 => {
            let data = if let Some(data) = &proto.raw_data {
                if data.is_empty() {
                    &[]
                } else {
                    match bytemuck::try_cast_slice::<u8, i32>(data) {
                        Ok(data) => data,
                        Err(e) => return Err(anyhow!(e)),
                    }
                }
            } else {
                proto.int32_data.as_slice()
            };
            let dlen = data.len();
            let slen = if !shape.is_empty() {
                shape_safe_product(&shape)
            } else {
                0
            };
            // NOTE(review): this only rejects a scalar shape (slen == 0) that
            // carries more than one element; other length mismatches fall
            // through and are caught by from_shape_vec below.
            if dlen != slen && (slen == 0 && dlen != 1) {
                return Err(anyhow!(
                    "Data length {} does not match shape length {}",
                    dlen,
                    slen
                ));
            }
            let a = if data.is_empty() {
                ArrayD::<i32>::zeros(IxDyn(&shape))
            } else {
                ArrayD::<i32>::from_shape_vec(IxDyn(&shape), data.to_vec())?
            };
            Ok(TensorType::I32(a))
        }
        DataType::INT64 => {
            let data = if let Some(data) = &proto.raw_data {
                if data.is_empty() {
                    &[]
                } else {
                    match bytemuck::try_cast_slice::<u8, i64>(data) {
                        Ok(data) => data,
                        Err(e) => return Err(anyhow!(e)),
                    }
                }
            } else {
                proto.int64_data.as_slice()
            };
            let dlen = data.len();
            let slen = if !shape.is_empty() {
                shape_safe_product(&shape)
            } else {
                0
            };
            // Same scalar-only length check as the INT32 branch above.
            if dlen != slen && (slen == 0 && dlen != 1) {
                return Err(anyhow!(
                    "Data length {} does not match shape length {}",
                    dlen,
                    slen
                ));
            }
            let a = if data.is_empty() {
                ArrayD::<i64>::zeros(IxDyn(&shape))
            } else {
                ArrayD::<i64>::from_shape_vec(IxDyn(&shape), data.to_vec())?
            };
            Ok(TensorType::I64(a))
        }
        DataType::UINT8 => match bytemuck::try_cast_slice::<u8, u8>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len() / origin_elem_size, shape_safe_product(&shape));
                let a = if origin_elem_size == std::mem::size_of::<u8>() {
                    ArrayD::<u8>::from_shape_vec(IxDyn(&shape), data.to_vec())?
                } else {
                    // Same one-byte-per-element narrowing as the INT8 branch.
                    ArrayD::<u8>::from_shape_vec(
                        IxDyn(&shape),
                        data.iter().step_by(origin_elem_size).copied().collect(),
                    )?
                };
                Ok(TensorType::U8(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        DataType::UINT16 => match bytemuck::try_cast_slice::<u8, u16>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len() / origin_elem_size, shape_safe_product(&shape));
                let a = ArrayD::<u16>::from_shape_vec(IxDyn(&shape), data.to_vec())?;
                Ok(TensorType::U16(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        DataType::UINT32 => match bytemuck::try_cast_slice::<u8, u32>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len() / origin_elem_size, shape_safe_product(&shape));
                let a = ArrayD::<u32>::from_shape_vec(IxDyn(&shape), data.to_vec())?;
                Ok(TensorType::U32(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        DataType::UINT64 => {
            let data = if let Some(data) = &proto.raw_data {
                if data.is_empty() {
                    &[]
                } else {
                    match bytemuck::try_cast_slice::<u8, u64>(data) {
                        Ok(data) => data,
                        Err(e) => return Err(anyhow!(e)),
                    }
                }
            } else {
                proto.uint64_data.as_slice()
            };
            let dlen = data.len();
            let slen = if !shape.is_empty() {
                shape_safe_product(&shape)
            } else {
                0
            };
            // Same scalar-only length check as the INT32 branch above.
            if dlen != slen && (slen == 0 && dlen != 1) {
                return Err(anyhow!(
                    "Data length {} does not match shape length {}",
                    dlen,
                    slen
                ));
            }
            let a = if data.is_empty() {
                ArrayD::<u64>::zeros(IxDyn(&shape))
            } else {
                ArrayD::<u64>::from_shape_vec(IxDyn(&shape), data.to_vec())?
            };
            Ok(TensorType::U64(a))
        }
        // f16 payloads are raw IEEE-754 half bit patterns stored as u16.
        DataType::FLOAT16 => match bytemuck::try_cast_slice::<u8, u16>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len() / origin_elem_size, shape_safe_product(&shape));
                let a = ArrayD::<f16>::from_shape_vec(
                    IxDyn(&shape),
                    data.iter().map(|x| f16::from_bits(*x)).collect(),
                )?;
                Ok(TensorType::F16(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        // NOTE(review): unlike FLOAT16 this reads the payload as full f32
        // values and narrows them to bf16 — it assumes producers store
        // bfloat16 as 4-byte floats here; TODO confirm against the data
        // actually feeding this path.
        DataType::BFLOAT16 => match bytemuck::try_cast_slice::<u8, f32>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len() / origin_elem_size, shape_safe_product(&shape));
                let a = ArrayD::<bf16>::from_shape_vec(
                    IxDyn(&shape),
                    data.iter().map(|x| bf16::from_f32(*x)).collect(),
                )?;
                Ok(TensorType::BF16(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        DataType::DOUBLE => {
            let data = if let Some(data) = &proto.raw_data {
                if data.is_empty() {
                    &[]
                } else {
                    match bytemuck::try_cast_slice::<u8, f64>(data) {
                        Ok(data) => data,
                        Err(e) => return Err(anyhow!(e)),
                    }
                }
            } else {
                proto.double_data.as_slice()
            };
            let dlen = data.len();
            let slen = if !shape.is_empty() {
                shape_safe_product(&shape)
            } else {
                0
            };
            // Same scalar-only length check as the INT32 branch above.
            if dlen != slen && (slen == 0 && dlen != 1) {
                return Err(anyhow!(
                    "Data length {} does not match shape length {}",
                    dlen,
                    slen
                ));
            }
            let a = if data.is_empty() {
                ArrayD::<f64>::zeros(IxDyn(&shape))
            } else {
                ArrayD::<f64>::from_shape_vec(IxDyn(&shape), data.to_vec())?
            };
            Ok(TensorType::F64(a))
        }
        // Strings come from their own repeated bytes field; invalid UTF-8 is
        // replaced (lossy), not rejected.
        DataType::STRING => {
            let bytedata = &proto.string_data;
            let a = ArrayD::<String>::from_shape_vec(
                IxDyn(&shape),
                bytedata
                    .iter()
                    .map(|v| String::from_utf8_lossy(v.as_ref()).to_string())
                    .collect(),
            )?;
            Ok(TensorType::Str(a))
        }
        // One byte per element; any nonzero byte is true.
        DataType::BOOL => match bytemuck::try_cast_slice::<u8, c_uchar>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len() / origin_elem_size, shape_safe_product(&shape));
                let a = ArrayD::<bool>::from_shape_vec(
                    IxDyn(&shape),
                    data.iter().map(|x| *x != 0).collect(),
                )?;
                Ok(TensorType::Bool(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        // 8-bit float formats are not implemented in this path.
        DataType::FLOAT8E4M3FN
        | DataType::FLOAT8E4M3FNUZ
        | DataType::FLOAT8E5M2FNUZ
        | DataType::FLOAT8E5M2 => {
            todo!("Data type {:?} not supported", enum_dt);
        }
        DataType::FLOAT => {
            let data = if let Some(data) = &proto.raw_data {
                if data.is_empty() {
                    &[]
                } else {
                    match bytemuck::try_cast_slice::<u8, f32>(data) {
                        Ok(data) => data,
                        Err(e) => return Err(anyhow!(e)),
                    }
                }
            } else {
                proto.float_data.as_slice()
            };
            let dlen = data.len();
            let slen = if !shape.is_empty() {
                shape_safe_product(&shape)
            } else {
                0
            };
            // Same scalar-only length check as the INT32 branch above.
            if dlen != slen && (slen == 0 && dlen != 1) {
                return Err(anyhow!(
                    "Data length {} does not match shape length {}",
                    dlen,
                    slen
                ));
            }
            let a = if data.is_empty() {
                ArrayD::<f32>::zeros(IxDyn(&shape))
            } else {
                ArrayD::<f32>::from_shape_vec(IxDyn(&shape), data.to_vec())?
            };
            Ok(TensorType::F32(a))
        }
        // Complex payloads are (re, im) pairs; the *Repr structs exist to
        // give bytemuck a castable layout.
        DataType::COMPLEX64 => match bytemuck::try_cast_slice::<u8, Complex64Repr>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len() / origin_elem_size, shape_safe_product(&shape));
                let a = ArrayD::<Complex64>::from_shape_vec(
                    IxDyn(&shape),
                    data.iter()
                        .map(|v| Complex64::new(v._val[0], v._val[1]))
                        .collect(),
                )?;
                Ok(TensorType::C64(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        DataType::COMPLEX128 => match bytemuck::try_cast_slice::<u8, Complex128Repr>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len() / origin_elem_size, shape_safe_product(&shape));
                let a = ArrayD::<Complex128>::from_shape_vec(
                    IxDyn(&shape),
                    data.iter()
                        .map(|v| Complex128::new(v._val[0], v._val[1]))
                        .collect(),
                )?;
                Ok(TensorType::C128(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
    }
}
501
/// Build a [`TensorType`] directly from a raw byte buffer (the layout used by
/// `TensorProto.raw_data`), reinterpreted according to `data_type`.
pub fn make_tensor_from_raw(
    shape: &[i64],
    bytedata: &[u8],
    data_type: i32,
) -> BoxResult<TensorType> {
    // Unknown discriminants collapse to the enum default (UNDEFINED).
    let enum_dt = DataType::from_i32(data_type).unwrap_or_default();
    let shape = shape.iter().map(|v| *v as usize).collect::<Vec<usize>>();
    match enum_dt {
        DataType::UNDEFINED => Err(anyhow!("Undefined data type")),
        DataType::INT8 => match bytemuck::try_cast_slice::<u8, i8>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len(), shape_safe_product(&shape));
                let a = ArrayD::<i8>::from_shape_vec(IxDyn(&shape), data.to_vec())?;
                Ok(TensorType::I8(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        DataType::INT16 => match bytemuck::try_cast_slice::<u8, i16>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len(), shape_safe_product(&shape));
                let a = ArrayD::<i16>::from_shape_vec(IxDyn(&shape), data.to_vec())?;
                Ok(TensorType::I16(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        DataType::INT32 => match bytemuck::try_cast_slice::<u8, i32>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len(), shape_safe_product(&shape));
                let a = ArrayD::<i32>::from_shape_vec(IxDyn(&shape), data.to_vec())?;
                Ok(TensorType::I32(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        // INT64 tolerates an empty payload (materialized as zeros), unlike
        // the other integer widths which assert an exact length.
        DataType::INT64 => match bytemuck::try_cast_slice::<u8, i64>(bytedata) {
            Ok(data) => {
                let dlen = data.len();
                let slen = if !shape.is_empty() {
                    shape_safe_product(&shape)
                } else {
                    0
                };
                // NOTE(review): only rejects a scalar shape (slen == 0) with
                // more than one element; other mismatches are caught by
                // from_shape_vec below.
                if dlen != slen && (slen == 0 && dlen != 1) {
                    return Err(anyhow!(
                        "Data length {} does not match shape length {}",
                        dlen,
                        slen
                    ));
                }
                let a = if data.is_empty() {
                    ArrayD::<i64>::zeros(IxDyn(&shape))
                } else {
                    ArrayD::<i64>::from_shape_vec(IxDyn(&shape), data.to_vec())?
                };
                Ok(TensorType::I64(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        DataType::UINT8 => match bytemuck::try_cast_slice::<u8, u8>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len(), shape_safe_product(&shape));
                let a = ArrayD::<u8>::from_shape_vec(IxDyn(&shape), data.to_vec())?;
                Ok(TensorType::U8(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        DataType::UINT16 => match bytemuck::try_cast_slice::<u8, u16>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len(), shape_safe_product(&shape));
                let a = ArrayD::<u16>::from_shape_vec(IxDyn(&shape), data.to_vec())?;
                Ok(TensorType::U16(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        DataType::UINT32 => match bytemuck::try_cast_slice::<u8, u32>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len(), shape_safe_product(&shape));
                let a = ArrayD::<u32>::from_shape_vec(IxDyn(&shape), data.to_vec())?;
                Ok(TensorType::U32(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        DataType::UINT64 => match bytemuck::try_cast_slice::<u8, u64>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len(), shape_safe_product(&shape));
                let a = ArrayD::<u64>::from_shape_vec(IxDyn(&shape), data.to_vec())?;
                Ok(TensorType::U64(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        // f16 payloads are raw IEEE-754 half bit patterns stored as u16.
        DataType::FLOAT16 => match bytemuck::try_cast_slice::<u8, u16>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len(), shape_safe_product(&shape));
                let a = ArrayD::<f16>::from_shape_vec(
                    IxDyn(&shape),
                    data.iter().map(|x| f16::from_bits(*x)).collect(),
                )?;
                Ok(TensorType::F16(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        // NOTE(review): unlike FLOAT16 this reads the buffer as full f32
        // values and narrows them to bf16 — it assumes the raw buffer holds
        // 4-byte floats for bfloat16; TODO confirm against the producers of
        // this data.
        DataType::BFLOAT16 => match bytemuck::try_cast_slice::<u8, f32>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len(), shape_safe_product(&shape));
                let a = ArrayD::<bf16>::from_shape_vec(
                    IxDyn(&shape),
                    data.iter().map(|x| bf16::from_f32(*x)).collect(),
                )?;
                Ok(TensorType::BF16(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        DataType::DOUBLE => match bytemuck::try_cast_slice::<u8, f64>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len(), shape_safe_product(&shape));
                let a = ArrayD::<f64>::from_shape_vec(IxDyn(&shape), data.to_vec())?;
                Ok(TensorType::F64(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        DataType::STRING => Err(anyhow!(
            "String data type not supported, use make_string_tensor()"
        )),
        // One byte per element; any nonzero byte is true.
        DataType::BOOL => match bytemuck::try_cast_slice::<u8, c_uchar>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len(), shape_safe_product(&shape));
                let a = ArrayD::<bool>::from_shape_vec(
                    IxDyn(&shape),
                    data.iter().map(|x| *x != 0).collect(),
                )?;
                Ok(TensorType::Bool(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        // The 8-bit float formats are funneled through the f32 path here.
        DataType::FLOAT
        | DataType::FLOAT8E4M3FN
        | DataType::FLOAT8E4M3FNUZ
        | DataType::FLOAT8E5M2FNUZ
        | DataType::FLOAT8E5M2 => match bytemuck::try_cast_slice::<u8, f32>(bytedata) {
            Ok(data) => {
                let dlen = data.len();
                let slen = if !shape.is_empty() {
                    shape_safe_product(&shape)
                } else {
                    0
                };
                // Same scalar-only length check as the INT64 branch above.
                if dlen != slen && (slen == 0 && dlen != 1) {
                    return Err(anyhow!(
                        "Data length {} does not match shape length {}",
                        dlen,
                        slen
                    ));
                }
                let a = if data.is_empty() {
                    ArrayD::<f32>::zeros(IxDyn(&shape))
                } else {
                    ArrayD::<f32>::from_shape_vec(IxDyn(&shape), data.to_vec())?
                };
                Ok(TensorType::F32(a))
            }
            Err(e) => {
                // Cast failed (presumably an alignment issue — bytemuck
                // requires the slice to be f32-aligned); fall back to copying
                // each 4-byte little-endian chunk.
                eprintln!("Copying data of tensor as f32 because {}", e);
                let mut copied_data = vec![];
                for float_slice in bytedata.chunks_exact(std::mem::size_of::<f32>()) {
                    copied_data.push(f32::from_le_bytes(float_slice.try_into()?));
                }
                let a = ArrayD::<f32>::from_shape_vec(IxDyn(&shape), copied_data)?;
                Ok(TensorType::F32(a))
            }
        },
        // Complex payloads are (re, im) pairs; the *Repr structs exist to
        // give bytemuck a castable layout.
        DataType::COMPLEX64 => match bytemuck::try_cast_slice::<u8, Complex64Repr>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len(), shape_safe_product(&shape));
                let a = ArrayD::<Complex64>::from_shape_vec(
                    IxDyn(&shape),
                    data.iter()
                        .map(|v| Complex64::new(v._val[0], v._val[1]))
                        .collect(),
                )?;
                Ok(TensorType::C64(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        DataType::COMPLEX128 => match bytemuck::try_cast_slice::<u8, Complex128Repr>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len(), shape_safe_product(&shape));
                let a = ArrayD::<Complex128>::from_shape_vec(
                    IxDyn(&shape),
                    data.iter()
                        .map(|v| Complex128::new(v._val[0], v._val[1]))
                        .collect(),
                )?;
                Ok(TensorType::C128(a))
            }
            Err(e) => Err(anyhow!(e.to_string())),
        },
    }
}
704
705pub fn make_initializers(graph: &onnx::GraphProto) -> BoxResult<HashMap<String, TensorType>> {
707 let mut initializers: HashMap<String, TensorType> = HashMap::new();
708 for tensor in graph.initializer.iter() {
709 let tensor_name = tensor.name.as_ref().map_or(UNKNOWN, |v| v.as_str());
710 if !tensor.has_data_type() {
711 eprintln!(" Tensor: {} has no data type", tensor_name);
712 } else {
713 initializers.insert(tensor_name.to_string(), make_tensor_from_proto(tensor)?);
714 }
715 }
716 Ok(initializers)
717}
718
/// Resolve the graph's declared inputs against tensors read from `files`,
/// falling back to `initializers`; leftover initializers are passed through.
///
/// Takes `initializers` by value so tensors can be moved into the result
/// without cloning.
fn make_input_tensors_from_files(
    graph: &onnx::GraphProto,
    files: &[PathBuf],
    mut initializers: HashMap<String, TensorType>,
) -> BoxResult<HashMap<String, Arc<TensorType>>> {
    let mut map = HashMap::new();
    // Index the on-disk tensors by the name embedded in each proto.
    let mut external_inputs_map = HashMap::new();
    for input in files.iter() {
        let input_tensor = read_tensor(input)?;
        external_inputs_map.insert(
            input_tensor
                .name
                .as_ref()
                .map_or_else(|| UNKNOWN.to_owned(), |v| v.clone()),
            input_tensor,
        );
    }
    for input in graph.input.iter() {
        let input_name = input.name.as_ref().map_or(UNKNOWN, |v| v.as_str());
        // A file-supplied tensor takes precedence over an initializer.
        if let Some(input_from_file) = external_inputs_map.get(input_name) {
            let tensor = make_tensor_from_proto(input_from_file)?;
            print_at_level!(
                VerbosityLevel::Informational,
                " Input {} from file has shape {:?} and type {:?}",
                input_name,
                tensor.shape(),
                tensor.value_type()
            );
            map.insert(input_name.to_string(), Arc::new(tensor));
        } else if let Some((_, init)) = initializers.remove_entry(input_name) {
            // remove_entry moves the initializer out so the passthrough loop
            // below cannot insert it a second time.
            print_at_level!(
                VerbosityLevel::Informational,
                " Input {} from initializer has shape {:?} and type {:?}",
                input_name,
                init.shape(),
                init.value_type()
            );
            map.insert(input_name.to_string(), Arc::new(init));
        } else {
            return Err(anyhow!(
                "Input {} not found in inputs file or graph initializers",
                input_name
            ));
        }
    }
    // Initializers that are not graph inputs are still made available.
    for (k, v) in initializers {
        map.insert(k, Arc::new(v));
    }
    Ok(map)
}
770
771fn make_output_tensors_from_files(
773 graph: &onnx::GraphProto,
774 files: &[PathBuf],
775) -> BoxResult<HashMap<String, TensorType>> {
776 let mut map = HashMap::new();
777 let mut external_outputs_map = HashMap::new();
778 for output in files.iter() {
779 let ouput_tensor = read_tensor(output)?;
780 external_outputs_map.insert(
781 ouput_tensor
782 .name
783 .as_ref()
784 .map_or_else(|| UNKNOWN.to_owned(), |v| v.clone()),
785 ouput_tensor,
786 );
787 }
788 for output in graph.output.iter() {
789 let output_name = output.name.as_ref().map_or(UNKNOWN, |v| v.as_str());
790 if let Some(output_from_file) = external_outputs_map.get(output_name) {
791 map.insert(
792 output_name.to_string(),
793 make_tensor_from_proto(output_from_file)?,
794 );
795 } else {
796 return Err(anyhow!("Output {} not found in inputs file", output_name));
797 }
798 }
799 Ok(map)
800}
801
802pub fn initialize_nodes(
804 graph: &onnx::GraphProto,
805 fileinputs: &FileInputs,
806 initializers: HashMap<String, TensorType>,
807) -> BoxResult<HashMap<String, Arc<TensorType>>> {
808 if fileinputs.inputs.is_empty() {
809 return Ok(HashMap::new());
810 }
811 make_input_tensors_from_files(graph, &fileinputs.inputs, initializers)
812}
813
814pub fn make_external_outputs(
816 graph: &onnx::GraphProto,
817 fileinputs: &FileInputs,
818) -> BoxResult<HashMap<String, TensorType>> {
819 if fileinputs.outputs.is_empty() {
820 return Ok(HashMap::new());
821 }
822 make_output_tensors_from_files(graph, &fileinputs.outputs)
823}
824
825pub fn make_graph_outputs(graph: &onnx::GraphProto) -> BoxResult<HashMap<String, OutputInfo>> {
827 let mut map = HashMap::new();
828 for output in graph.output.iter() {
829 let output_name = output.name.as_ref().map_or(UNKNOWN, |v| v.as_str());
830 map.insert(
831 output_name.to_string(),
832 OutputInfo::new(ValueInfo::from_proto(output)?),
833 );
834 }
835 Ok(map)
836}
837
838fn read_model_text(p: &Path) -> BoxResult<onnx::ModelProto> {
840 let file = std::fs::File::open(p)?;
841 let mut reader = io::BufReader::new(file);
842 let mut buf = String::new();
843 reader.read_to_string(&mut buf)?;
844 let model = protobuf::text_format::parse_from_str(&buf)?;
845 Ok(model)
846}
847
848fn read_model_binary(p: &Path) -> BoxResult<onnx::ModelProto> {
850 let file = std::fs::File::open(p)?;
851 let mut reader = io::BufReader::new(file);
852 let model: onnx::ModelProto = protobuf::Message::parse_from_reader(&mut reader)?;
853 Ok(model)
854}
855
856pub fn read_model(p: &Path) -> BoxResult<onnx::ModelProto> {
858 print_at_level!(VerbosityLevel::Minimal, "Reading model from {}", p.display());
859 let merr = read_model_binary(p);
860 match merr {
861 Ok(m) => Ok(m),
862 Err(e) => {
863 eprintln!("Error reading binary model: {}", e);
864 read_model_text(p)
865 }
866 }
867}
868
869pub fn read_tensor(p: &Path) -> BoxResult<onnx::TensorProto> {
871 let file = std::fs::File::open(p)?;
872 let mut reader = io::BufReader::new(file);
873 let model: onnx::TensorProto = protobuf::Message::parse_from_reader(&mut reader)?;
874 Ok(model)
875}
876
/// Pick the highest available opset version that does not exceed
/// `target_ver`, or 0 when none qualifies (matching the original
/// accumulator's starting value, which also excluded non-positive versions).
pub fn pick_opset_version(target_ver: i64, opset_versions: &[i64]) -> i64 {
    opset_versions
        .iter()
        .copied()
        .filter(|&v| v > 0 && v <= target_ver)
        .max()
        .unwrap_or(0)
}
887
/// Placeholder handler registered for ONNX operators that have no
/// implementation yet; always panics via `todo!` when invoked.
pub fn operator_not_implemented(
    _inputs: &[&TensorType],
    _node: &NodeProto,
    _opset_version: i64,
    _output_len: usize,
) -> BoxResult<OperatorResult> {
    todo!("operator not implemented");
}