Skip to main content

tract_gpu/
utils.rs

1use tract_core::internal::*;
2use tract_core::tract_linalg::block_quant::BlockQuant;
3use tract_linalg::block_quant::{BlockQuantFact, Q4_0};
4
5use crate::fact::*;
6use crate::tensor::DeviceTensor;
7
8pub fn facts_to_device_facts(
9    facts: &[&TypedFact],
10    resolve_facts: impl Fn(&[&TypedFact]) -> TractResult<TVec<TypedFact>>,
11) -> TractResult<TVec<TypedFact>> {
12    if facts.iter().all(|it| it.datum_type == DatumType::Opaque) {
13        let device_facts = facts
14            .iter()
15            .map(|it| it.to_device_fact().map(|it| it.as_ref()))
16            .collect::<TractResult<TVec<_>>>()?;
17        let output_facts = (resolve_facts)(device_facts.as_slice())?;
18        Ok(output_facts
19            .into_iter()
20            .map(|it| Ok(DeviceFact::new(DeviceTensorOrigin::FromDevice, it)?.into_opaque_fact()))
21            .collect::<TractResult<_>>()?)
22    } else if facts.iter().all(|it| it.datum_type != DatumType::Opaque) {
23        (resolve_facts)(facts)
24    } else {
25        bail!(
26            "Inconsistent facts datum type: {:?}",
27            facts.iter().map(|it| it.datum_type).collect::<TVec<_>>()
28        );
29    }
30}
31
32pub fn get_device_facts<'a, 'b: 'a, T>(
33    facts: &'a [&'b TypedFact],
34    map_facts: impl Fn(&[&'b TypedFact]) -> TractResult<T>,
35) -> TractResult<T> {
36    if facts.iter().all(|it| it.datum_type == DatumType::Opaque) {
37        let device_facts = facts
38            .iter()
39            .map(|it| it.to_device_fact().map(|it| it.as_ref()))
40            .collect::<TractResult<TVec<_>>>()?;
41        (map_facts)(device_facts.as_slice())
42    } else if facts.iter().all(|it| it.datum_type != DatumType::Opaque) {
43        (map_facts)(facts)
44    } else {
45        bail!(
46            "Inconsistent facts datum type: {:?}",
47            facts.iter().map(|it| it.datum_type).collect::<Vec<_>>()
48        );
49    }
50}
51
52pub fn get_device_fact<'a, T: 'a>(
53    fact: &'a TypedFact,
54    map_fact: impl Fn(&'a TypedFact) -> TractResult<T>,
55) -> TractResult<T> {
56    if fact.datum_type == DatumType::Opaque {
57        (map_fact)(fact.to_device_fact()?)
58    } else {
59        (map_fact)(fact)
60    }
61}
62
63pub fn as_quant_fact<'a>(
64    fact: &'a TypedFact,
65    format: &dyn BlockQuant,
66) -> Option<&'a BlockQuantFact> {
67    fact.opaque_fact
68        .as_ref()
69        .and_then(|of| of.downcast_ref::<BlockQuantFact>())
70        .and_then(|bqf| if bqf.format.same_as(format) { Some(bqf) } else { None })
71        .or_else(|| {
72            fact.konst
73                .as_ref()
74                .and_then(|k| k.try_as_dense().ok()?.to_scalar::<Opaque>().ok())
75                .and_then(|o| o.downcast_ref::<BlobWithFact>())
76                .and_then(|bwf| bwf.fact.downcast_ref::<BlockQuantFact>())
77                .and_then(|bqf| if bqf.format.same_as(format) { Some(bqf) } else { None })
78        })
79}
80
81pub fn as_q40_tensor(a: &Tensor) -> Option<&BlobWithFact> {
82    let od = a.try_as_dense().ok()?.to_scalar::<Opaque>().ok()?;
83    od.downcast_ref::<BlobWithFact>().and_then(|bwf| {
84        if bwf.fact.downcast_ref::<BlockQuantFact>().is_some_and(|bqf| bqf.format.same_as(&Q4_0)) {
85            Some(bwf)
86        } else {
87            None
88        }
89    })
90}
91
92pub fn get_quant_fact(t: &DeviceTensor, format: &dyn BlockQuant) -> Option<BlockQuantFact> {
93    if let DeviceTensor::Owned(t) = t {
94        t.opaque_fact()
95            .and_then(|of| of.downcast_ref::<BlockQuantFact>())
96            .cloned()
97            .filter(|bqf| bqf.format.same_as(format))
98    } else {
99        None
100    }
101}
102
103pub fn check_strides_validity(shape: TVec<usize>, strides: TVec<isize>) -> TractResult<()> {
104    let mut zipped_shape_strides: Vec<_> = shape.into_iter().zip(strides).collect();
105    zipped_shape_strides.sort_by_key(|&(_, stride)| stride);
106
107    let mut prev_stride = 1;
108    for (dim, stride) in zipped_shape_strides {
109        ensure!((stride == prev_stride) || (dim == 1), "Invalid strides");
110        prev_stride *= dim as isize;
111    }
112    Ok(())
113}