1use tract_core::internal::*;
2use tract_linalg::block_quant::{BlockQuantFact, BlockQuantValue, Q4_0};
3
4use crate::fact::*;
5
6pub fn facts_to_device_facts(
7 facts: &[&TypedFact],
8 resolve_facts: impl Fn(&[&TypedFact]) -> TractResult<TVec<TypedFact>>,
9) -> TractResult<TVec<TypedFact>> {
10 if facts.iter().all(|it| it.datum_type == DatumType::Opaque) {
11 let device_facts = facts
12 .iter()
13 .map(|it| it.to_device_fact().map(|it| it.as_ref()))
14 .collect::<TractResult<TVec<_>>>()?;
15 let output_facts = (resolve_facts)(device_facts.as_slice())?;
16 Ok(output_facts
17 .into_iter()
18 .map(|it| Ok(DeviceFact::new(DeviceTensorOrigin::FromDevice, it)?.into_opaque_fact()))
19 .collect::<TractResult<_>>()?)
20 } else if facts.iter().all(|it| it.datum_type != DatumType::Opaque) {
21 (resolve_facts)(facts)
22 } else {
23 bail!(
24 "Inconsistent facts datum type: {:?}",
25 facts.iter().map(|it| it.datum_type).collect::<TVec<_>>()
26 );
27 }
28}
29
30pub fn get_device_facts<'a, 'b: 'a, T>(
31 facts: &'a [&'b TypedFact],
32 map_facts: impl Fn(&[&'b TypedFact]) -> TractResult<T>,
33) -> TractResult<T> {
34 if facts.iter().all(|it| it.datum_type == DatumType::Opaque) {
35 let device_facts = facts
36 .iter()
37 .map(|it| it.to_device_fact().map(|it| it.as_ref()))
38 .collect::<TractResult<TVec<_>>>()?;
39 (map_facts)(device_facts.as_slice())
40 } else if facts.iter().all(|it| it.datum_type != DatumType::Opaque) {
41 (map_facts)(facts)
42 } else {
43 bail!(
44 "Inconsistent facts datum type: {:?}",
45 facts.iter().map(|it| it.datum_type).collect::<Vec<_>>()
46 );
47 }
48}
49
50pub fn get_device_fact<'a, T: 'a>(
51 fact: &'a TypedFact,
52 map_fact: impl Fn(&'a TypedFact) -> TractResult<T>,
53) -> TractResult<T> {
54 if fact.datum_type == DatumType::Opaque {
55 (map_fact)(fact.to_device_fact()?)
56 } else {
57 (map_fact)(fact)
58 }
59}
60
61pub fn as_q40_fact(fact: &TypedFact) -> Option<&BlockQuantFact> {
62 fact.opaque_fact
63 .as_ref()
64 .and_then(|of| of.downcast_ref::<BlockQuantFact>())
65 .and_then(|bqf| if bqf.format.same_as(&Q4_0) { Some(bqf) } else { None })
66 .or_else(|| {
67 fact.konst
68 .as_ref()
69 .and_then(|k| k.to_scalar::<Opaque>().ok())
70 .and_then(|o| o.downcast_ref::<BlockQuantValue>())
71 .map(|v| &v.fact)
72 .and_then(|bqf| if bqf.format.same_as(&Q4_0) { Some(bqf) } else { None })
73 })
74}
75
76pub fn as_q40_tensor(a: &Tensor) -> Option<&BlockQuantValue> {
77 a.to_scalar::<Opaque>().ok().and_then(|od| {
78 od.downcast_ref::<BlockQuantValue>()
79 .and_then(|bqv| if bqv.fact.format.same_as(&Q4_0) { Some(bqv) } else { None })
80 })
81}
82
83pub fn check_strides_validity(shape: TVec<usize>, strides: TVec<isize>) -> TractResult<()> {
84 let mut zipped_shape_strides: Vec<_> = shape.into_iter().zip(strides).collect();
85 zipped_shape_strides.sort_by_key(|&(_, stride)| stride);
86
87 let mut prev_stride = 1;
88 for (dim, stride) in zipped_shape_strides {
89 ensure!((stride == prev_stride) || (dim == 1), "Invalid strides");
90 prev_stride *= dim as isize;
91 }
92 Ok(())
93}