//! tract_gpu/utils.rs — shared device/host fact helpers and copy-kernel utilities.
1use tract_core::internal::*;
2use tract_core::tract_linalg::block_quant::BlockQuant;
3use tract_linalg::block_quant::{BlockQuantFact, BlockQuantStorage, Q4_0};
4
5use crate::fact::*;
6use crate::tensor::DeviceTensor;
7
8pub fn facts_to_device_facts(
9    facts: &[&TypedFact],
10    resolve_facts: impl Fn(&[&TypedFact]) -> TractResult<TVec<TypedFact>>,
11) -> TractResult<TVec<TypedFact>> {
12    if facts.iter().all(|it| it.as_device_fact().is_some()) {
13        let device_facts = facts
14            .iter()
15            .map(|it| it.to_device_fact().map(|it| it.as_ref()))
16            .collect::<TractResult<TVec<_>>>()?;
17        let output_facts = (resolve_facts)(device_facts.as_slice())?;
18        Ok(output_facts
19            .into_iter()
20            .map(|it| Ok(DeviceFact::new(DeviceTensorOrigin::FromDevice, it)?.into_exotic_fact()))
21            .collect::<TractResult<_>>()?)
22    } else if facts.iter().all(|it| it.as_device_fact().is_none()) {
23        (resolve_facts)(facts)
24    } else {
25        bail!("Inconsistent facts: mix of device and host facts");
26    }
27}
28
29pub fn get_device_facts<'a, 'b: 'a, T>(
30    facts: &'a [&'b TypedFact],
31    map_facts: impl Fn(&[&'b TypedFact]) -> TractResult<T>,
32) -> TractResult<T> {
33    if facts.iter().all(|it| it.as_device_fact().is_some()) {
34        let device_facts = facts
35            .iter()
36            .map(|it| it.to_device_fact().map(|it| it.as_ref()))
37            .collect::<TractResult<TVec<_>>>()?;
38        (map_facts)(device_facts.as_slice())
39    } else if facts.iter().all(|it| it.as_device_fact().is_none()) {
40        (map_facts)(facts)
41    } else {
42        bail!("Inconsistent facts: mix of device and host facts");
43    }
44}
45
46pub fn get_device_fact<'a, T: 'a>(
47    fact: &'a TypedFact,
48    map_fact: impl Fn(&'a TypedFact) -> TractResult<T>,
49) -> TractResult<T> {
50    if fact.as_device_fact().is_some() {
51        (map_fact)(fact.to_device_fact()?)
52    } else {
53        (map_fact)(fact)
54    }
55}
56
57pub fn as_quant_fact<'a>(
58    fact: &'a TypedFact,
59    format: &dyn BlockQuant,
60) -> Option<&'a BlockQuantFact> {
61    fact.exotic_fact
62        .as_ref()
63        .and_then(|of| of.downcast_ref::<BlockQuantFact>())
64        .and_then(|bqf| if bqf.format.dyn_eq(format) { Some(bqf) } else { None })
65}
66
67pub fn as_q40_tensor(a: &Tensor) -> Option<&BlockQuantStorage> {
68    a.storage_as::<BlockQuantStorage>().filter(|bqs| bqs.format().dyn_eq(&Q4_0))
69}
70
71pub fn get_quant_fact(t: &DeviceTensor, format: &dyn BlockQuant) -> Option<BlockQuantFact> {
72    if let DeviceTensor::Owned(t) = t {
73        t.exotic_fact()
74            .and_then(|of| of.downcast_ref::<BlockQuantFact>())
75            .cloned()
76            .filter(|bqf| bqf.format.dyn_eq(format))
77    } else {
78        None
79    }
80}
81
82// --- Shared array/copy utilities ---
83
/// Broadcast pattern selector for element-wise device kernels.
///
/// The variant names map directly onto kernel-name fragments (see `name()`):
/// same-shape element-wise (`Unicast`), one scalar operand
/// (`ByScalarLeft`/`ByScalarRight`), or a rank-specialized N-dimensional
/// broadcast (`Nd1`..`Nd6`, selected by `from_rank`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BroadcastKind {
    // Operands share the same shape ("unicast" kernels).
    Unicast,
    // Left-hand operand is the scalar ("by_scalar_lhs" kernels).
    ByScalarLeft,
    // Right-hand operand is the scalar ("by_scalar_rhs" kernels).
    ByScalarRight,
    // Rank-specialized broadcast variants for ranks 1 through 6.
    Nd1,
    Nd2,
    Nd3,
    Nd4,
    Nd5,
    Nd6,
}
96
97impl BroadcastKind {
98    pub const ALL: [BroadcastKind; 8] = [
99        Self::Unicast,
100        Self::ByScalarLeft,
101        Self::ByScalarRight,
102        Self::Nd1,
103        Self::Nd2,
104        Self::Nd3,
105        Self::Nd4,
106        Self::Nd5,
107    ];
108
109    pub fn from_rank(rank: usize) -> TractResult<Self> {
110        match rank {
111            1 => Ok(Self::Nd1),
112            2 => Ok(Self::Nd2),
113            3 => Ok(Self::Nd3),
114            4 => Ok(Self::Nd4),
115            5 => Ok(Self::Nd5),
116            6 => Ok(Self::Nd6),
117            _ => bail!("Unsupported rank {rank} for broadcasting"),
118        }
119    }
120
121    pub fn name(&self) -> &'static str {
122        match self {
123            Self::Unicast => "unicast",
124            Self::ByScalarLeft => "by_scalar_lhs",
125            Self::ByScalarRight => "by_scalar_rhs",
126            Self::Nd1 => "nd1",
127            Self::Nd2 => "nd2",
128            Self::Nd3 => "nd3",
129            Self::Nd4 => "nd4",
130            Self::Nd5 => "nd5",
131            Self::Nd6 => "nd6",
132        }
133    }
134
135    /// Map datum type to the copy kernel type name based on element size.
136    /// Copy kernels only care about element size, not the actual type.
137    pub fn copy_tname(dt: DatumType) -> &'static str {
138        match dt.size_of() {
139            1 => "u8",
140            2 => "u16",
141            4 => "u32",
142            8 => "u64",
143            _ => panic!("Unsupported element size {} for copy kernel", dt.size_of()),
144        }
145    }
146
147    pub fn copy_kernel_name(&self, dt: DatumType, prefix: &str) -> TractResult<String> {
148        Ok(format!("{prefix}copy_{}_{}", self.name(), Self::copy_tname(dt)))
149    }
150
151    pub fn all_copy_kernel_names(prefix: &str) -> Vec<String> {
152        let copy_types = ["u8", "u16", "u32", "u64"];
153        Self::ALL
154            .into_iter()
155            .flat_map(|bk| {
156                copy_types
157                    .into_iter()
158                    .map(move |tname| format!("{prefix}copy_{}_{tname}", bk.name()))
159            })
160            .collect()
161    }
162}
163
164pub fn compute_broadcast_strides<T: num_traits::Zero + Copy + 'static>(
165    shape: &[usize],
166    strides: &[isize],
167) -> TractResult<TVec<T>>
168where
169    isize: num_traits::AsPrimitive<T>,
170{
171    use num_traits::AsPrimitive;
172    ensure!(
173        shape.len() == strides.len(),
174        "Mismatch between shape and strides length while computing broadcast strides"
175    );
176    Ok(strides
177        .iter()
178        .zip(shape)
179        .map(|(s, dim)| if *dim == 1 { T::zero() } else { s.as_() })
180        .collect::<TVec<T>>())
181}
182
183pub fn reshape_to_rank_2(shape: &[usize], axis: usize) -> TVec<usize> {
184    let dim_axis_0 = shape[0..axis].iter().product::<usize>();
185    let dim_axis_2 = shape[axis..].iter().product::<usize>();
186    tvec![dim_axis_0, dim_axis_2]
187}
188
189pub fn reshape_to_rank_3(shape: &[usize], axis: usize) -> TVec<usize> {
190    let dim_axis_0 = shape[0..axis].iter().product::<usize>();
191    let dim_axis_1 = shape[axis];
192    let dim_axis_2 = shape[axis + 1..].iter().product::<usize>();
193    tvec![dim_axis_0, dim_axis_1, dim_axis_2]
194}
195
196pub fn check_strides_validity(shape: TVec<usize>, strides: TVec<isize>) -> TractResult<()> {
197    let mut zipped_shape_strides: Vec<_> = shape.into_iter().zip(strides).collect();
198    zipped_shape_strides.sort_by_key(|&(_, stride)| stride);
199
200    let mut prev_stride = 1;
201    for (dim, stride) in zipped_shape_strides {
202        ensure!((stride == prev_stride) || (dim == 1), "Invalid strides");
203        prev_stride *= dim as isize;
204    }
205    Ok(())
206}