1use tract_core::internal::*;
2use tract_core::tract_linalg::block_quant::BlockQuant;
3use tract_linalg::block_quant::{BlockQuantFact, BlockQuantStorage, Q4_0};
4
5use crate::fact::*;
6use crate::tensor::DeviceTensor;
7
8pub fn facts_to_device_facts(
9 facts: &[&TypedFact],
10 resolve_facts: impl Fn(&[&TypedFact]) -> TractResult<TVec<TypedFact>>,
11) -> TractResult<TVec<TypedFact>> {
12 if facts.iter().all(|it| it.as_device_fact().is_some()) {
13 let device_facts = facts
14 .iter()
15 .map(|it| it.to_device_fact().map(|it| it.as_ref()))
16 .collect::<TractResult<TVec<_>>>()?;
17 let output_facts = (resolve_facts)(device_facts.as_slice())?;
18 Ok(output_facts
19 .into_iter()
20 .map(|it| Ok(DeviceFact::new(DeviceTensorOrigin::FromDevice, it)?.into_exotic_fact()))
21 .collect::<TractResult<_>>()?)
22 } else if facts.iter().all(|it| it.as_device_fact().is_none()) {
23 (resolve_facts)(facts)
24 } else {
25 bail!("Inconsistent facts: mix of device and host facts");
26 }
27}
28
29pub fn get_device_facts<'a, 'b: 'a, T>(
30 facts: &'a [&'b TypedFact],
31 map_facts: impl Fn(&[&'b TypedFact]) -> TractResult<T>,
32) -> TractResult<T> {
33 if facts.iter().all(|it| it.as_device_fact().is_some()) {
34 let device_facts = facts
35 .iter()
36 .map(|it| it.to_device_fact().map(|it| it.as_ref()))
37 .collect::<TractResult<TVec<_>>>()?;
38 (map_facts)(device_facts.as_slice())
39 } else if facts.iter().all(|it| it.as_device_fact().is_none()) {
40 (map_facts)(facts)
41 } else {
42 bail!("Inconsistent facts: mix of device and host facts");
43 }
44}
45
46pub fn get_device_fact<'a, T: 'a>(
47 fact: &'a TypedFact,
48 map_fact: impl Fn(&'a TypedFact) -> TractResult<T>,
49) -> TractResult<T> {
50 if fact.as_device_fact().is_some() {
51 (map_fact)(fact.to_device_fact()?)
52 } else {
53 (map_fact)(fact)
54 }
55}
56
57pub fn as_quant_fact<'a>(
58 fact: &'a TypedFact,
59 format: &dyn BlockQuant,
60) -> Option<&'a BlockQuantFact> {
61 fact.exotic_fact
62 .as_ref()
63 .and_then(|of| of.downcast_ref::<BlockQuantFact>())
64 .and_then(|bqf| if bqf.format.dyn_eq(format) { Some(bqf) } else { None })
65}
66
67pub fn as_q40_tensor(a: &Tensor) -> Option<&BlockQuantStorage> {
68 a.storage_as::<BlockQuantStorage>().filter(|bqs| bqs.format().dyn_eq(&Q4_0))
69}
70
71pub fn get_quant_fact(t: &DeviceTensor, format: &dyn BlockQuant) -> Option<BlockQuantFact> {
72 if let DeviceTensor::Owned(t) = t {
73 t.exotic_fact()
74 .and_then(|of| of.downcast_ref::<BlockQuantFact>())
75 .cloned()
76 .filter(|bqf| bqf.format.dyn_eq(format))
77 } else {
78 None
79 }
80}
81
/// Broadcast pattern a device kernel is specialized for.
///
/// `Nd1`..`Nd6` are generic rank-N broadcasts (see `from_rank`). `Unicast`
/// and `ByScalarLeft`/`ByScalarRight` presumably denote same-shape and
/// scalar-operand binary ops respectively — inferred from the kernel-name
/// suffixes in `name()`; confirm against the kernel sources.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BroadcastKind {
    Unicast,
    ByScalarLeft,
    ByScalarRight,
    Nd1,
    Nd2,
    Nd3,
    Nd4,
    Nd5,
    Nd6,
}
96
97impl BroadcastKind {
98 pub const ALL: [BroadcastKind; 8] = [
99 Self::Unicast,
100 Self::ByScalarLeft,
101 Self::ByScalarRight,
102 Self::Nd1,
103 Self::Nd2,
104 Self::Nd3,
105 Self::Nd4,
106 Self::Nd5,
107 ];
108
109 pub fn from_rank(rank: usize) -> TractResult<Self> {
110 match rank {
111 1 => Ok(Self::Nd1),
112 2 => Ok(Self::Nd2),
113 3 => Ok(Self::Nd3),
114 4 => Ok(Self::Nd4),
115 5 => Ok(Self::Nd5),
116 6 => Ok(Self::Nd6),
117 _ => bail!("Unsupported rank {rank} for broadcasting"),
118 }
119 }
120
121 pub fn name(&self) -> &'static str {
122 match self {
123 Self::Unicast => "unicast",
124 Self::ByScalarLeft => "by_scalar_lhs",
125 Self::ByScalarRight => "by_scalar_rhs",
126 Self::Nd1 => "nd1",
127 Self::Nd2 => "nd2",
128 Self::Nd3 => "nd3",
129 Self::Nd4 => "nd4",
130 Self::Nd5 => "nd5",
131 Self::Nd6 => "nd6",
132 }
133 }
134
135 pub fn copy_tname(dt: DatumType) -> &'static str {
138 match dt.size_of() {
139 1 => "u8",
140 2 => "u16",
141 4 => "u32",
142 8 => "u64",
143 _ => panic!("Unsupported element size {} for copy kernel", dt.size_of()),
144 }
145 }
146
147 pub fn copy_kernel_name(&self, dt: DatumType, prefix: &str) -> TractResult<String> {
148 Ok(format!("{prefix}copy_{}_{}", self.name(), Self::copy_tname(dt)))
149 }
150
151 pub fn all_copy_kernel_names(prefix: &str) -> Vec<String> {
152 let copy_types = ["u8", "u16", "u32", "u64"];
153 Self::ALL
154 .into_iter()
155 .flat_map(|bk| {
156 copy_types
157 .into_iter()
158 .map(move |tname| format!("{prefix}copy_{}_{tname}", bk.name()))
159 })
160 .collect()
161 }
162}
163
164pub fn compute_broadcast_strides<T: num_traits::Zero + Copy + 'static>(
165 shape: &[usize],
166 strides: &[isize],
167) -> TractResult<TVec<T>>
168where
169 isize: num_traits::AsPrimitive<T>,
170{
171 use num_traits::AsPrimitive;
172 ensure!(
173 shape.len() == strides.len(),
174 "Mismatch between shape and strides length while computing broadcast strides"
175 );
176 Ok(strides
177 .iter()
178 .zip(shape)
179 .map(|(s, dim)| if *dim == 1 { T::zero() } else { s.as_() })
180 .collect::<TVec<T>>())
181}
182
183pub fn reshape_to_rank_2(shape: &[usize], axis: usize) -> TVec<usize> {
184 let dim_axis_0 = shape[0..axis].iter().product::<usize>();
185 let dim_axis_2 = shape[axis..].iter().product::<usize>();
186 tvec![dim_axis_0, dim_axis_2]
187}
188
189pub fn reshape_to_rank_3(shape: &[usize], axis: usize) -> TVec<usize> {
190 let dim_axis_0 = shape[0..axis].iter().product::<usize>();
191 let dim_axis_1 = shape[axis];
192 let dim_axis_2 = shape[axis + 1..].iter().product::<usize>();
193 tvec![dim_axis_0, dim_axis_1, dim_axis_2]
194}
195
196pub fn check_strides_validity(shape: TVec<usize>, strides: TVec<isize>) -> TractResult<()> {
197 let mut zipped_shape_strides: Vec<_> = shape.into_iter().zip(strides).collect();
198 zipped_shape_strides.sort_by_key(|&(_, stride)| stride);
199
200 let mut prev_stride = 1;
201 for (dim, stride) in zipped_shape_strides {
202 ensure!((stride == prev_stride) || (dim == 1), "Invalid strides");
203 prev_stride *= dim as isize;
204 }
205 Ok(())
206}