tract_linalg/
lib.rs

1#![allow(clippy::missing_safety_doc)]
2#![allow(clippy::redundant_closure_call)]
3#![allow(clippy::len_zero)]
4#![allow(clippy::excessive_precision)]
5#![allow(clippy::approx_constant)]
6#![allow(unexpected_cfgs)]
7#![allow(unused_macros)]
8#[macro_use]
9extern crate derive_new;
10extern crate lazy_static;
11extern crate log;
12extern crate num_traits;
13#[macro_use]
14extern crate paste;
15#[cfg(test)]
16extern crate proptest;
17
18include!(concat!(env!("OUT_DIR"), "/extern_kernel_macro.rs"));
19
20#[macro_use]
21mod frame;
22pub mod generic;
23pub mod multithread;
24pub use frame::weights::WeightType;
25pub use generic::{ScaleShiftAndRound, Scaler};
26use lazy_static::lazy_static;
27use mmm::{MMMInputFormat, MatMatMul, PanelExtractor};
28use tract_data::internal::TensorView;
29#[cfg(target_arch = "x86_64")]
30pub mod x86_64_fma;
31
32#[cfg(target_arch = "aarch64")]
33pub mod arm64;
34
35#[cfg(target_arch = "aarch64")]
36pub use arm64::has_fp16;
37use tract_itertools::Itertools;
38
/// Reports whether fp16 support is available on this target.
///
/// This is the fallback for every target except aarch64 (where `arm64::has_fp16`
/// is re-exported instead); it unconditionally answers `false`, which makes
/// `bin_by_scalar` and `bin_unicast` refuse `F16` requests.
#[cfg(not(target_arch = "aarch64"))]
pub fn has_fp16() -> bool {
    false
}
43
// 32-bit ARM kernels. Keep this predicate identical to the one guarding the
// `arm32::plug` call in `best()`.
#[cfg(any(target_arch = "arm", target_arch = "armv7"))]
pub mod arm32;
46
47#[cfg(all(target_family = "wasm", target_feature = "simd128"))]
48pub mod wasm;
49
50pub use self::frame::*;
51
52use tract_data::prelude::*;
53
54pub type MMMImpl = Box<
55    dyn Fn(Option<usize>, Option<usize>, Option<usize>) -> Box<dyn mmm::MatMatMul> + Send + Sync,
56>;
57
58type MMVImpl = Box<dyn Fn(Option<usize>, Option<usize>) -> Box<dyn mmm::MatMatMul> + Send + Sync>;
59
60#[allow(clippy::type_complexity)]
61pub struct Ops {
62    mmm_impls: Vec<Box<dyn mmm::MatMatMul>>,
63    panel_extractors: Vec<mmm::PanelExtractor>,
64
65    mmm_f64: MMMImpl,
66    mmv_f64: MMVImpl,
67
68    mmm_f32: MMMImpl,
69    mmv_f32: MMVImpl,
70
71    mmm_f16: MMMImpl,
72    mmv_f16: MMVImpl,
73
74    qmmm_i32: MMMImpl,
75    qmmv_i32: MMVImpl,
76
77    pub leaky_relu_f16: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f16, f16>> + Send + Sync>,
78    pub leaky_relu_f32: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32, f32>> + Send + Sync>,
79    pub mul_by_scalar_f32:
80        Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32, f32>> + Send + Sync>,
81    pub mul_by_scalar_f16:
82        Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f16, f16>> + Send + Sync>,
83
84    pub sigmoid_f16: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f16>> + Send + Sync>,
85    pub sigmoid_f32: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32>> + Send + Sync>,
86    pub tanh_f16: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f16>> + Send + Sync>,
87    pub tanh_f32: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32>> + Send + Sync>,
88    pub erf_f32: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32>> + Send + Sync>,
89    pub lut_u8: Box<dyn Fn(&[u8]) -> Box<dyn lut::Lut> + Send + Sync>,
90
91    pub max_f16: Box<dyn Fn() -> Box<dyn reduce::Reduce<f16>> + Send + Sync>,
92    pub max_f32: Box<dyn Fn() -> Box<dyn reduce::Reduce<f32>> + Send + Sync>,
93
94    pub sum_f16: Box<dyn Fn() -> Box<dyn reduce::Reduce<f16>> + Send + Sync>,
95    pub sum_f32: Box<dyn Fn() -> Box<dyn reduce::Reduce<f32>> + Send + Sync>,
96
97    pub softmax2_fastcompact_f16:
98        Box<dyn Fn() -> Box<dyn reduce::MapReduce<f16, f16>> + Send + Sync>,
99    pub softmax2_fastcompact_f32:
100        Box<dyn Fn() -> Box<dyn reduce::MapReduce<f32, f32>> + Send + Sync>,
101}
102
103impl Ops {
104    pub fn mmm_impls(&self) -> &[Box<dyn mmm::MatMatMul>] {
105        &self.mmm_impls
106    }
107
108    pub fn all_possible_packing(
109        &self,
110        weight_type: impl Into<WeightType>,
111    ) -> impl Iterator<Item = &dyn MMMInputFormat> {
112        let weight_type = weight_type.into();
113        self.mmm_impls
114            .iter()
115            .flat_map(|m| m.packings())
116            .map(|p| &*p.0)
117            .flat_map(move |p| {
118                let mut packs: Vec<&dyn MMMInputFormat> = vec![];
119                if p.precursor() == weight_type {
120                    packs.push(p)
121                };
122                for pe in &self.panel_extractors {
123                    if pe.from.precursor() == weight_type && pe.to.same_as(p) {
124                        packs.push(&*pe.from);
125                    }
126                }
127                packs.into_iter()
128            })
129            .sorted_by_key(|p| p.to_string())
130            .dedup()
131    }
132
133    pub fn filter_impls<'o>(
134        &'o self,
135        weight: &'o dyn MMMInputFormat,
136        acc: &[DatumType],
137        act: DatumType,
138        store: DatumType,
139    ) -> impl Iterator<
140        Item = (
141            &'o dyn MatMatMul,
142            usize,
143            &'o dyn MMMInputFormat,
144            Option<&'o PanelExtractor>,
145            &'o dyn MMMInputFormat,
146        ),
147    > {
148        let acc = acc.to_vec();
149        self.mmm_impls
150            .iter()
151            .filter(move |mmm| acc.contains(&mmm.internal_type()) && mmm.stores().contains(&store))
152            .flat_map(|mmm| {
153                mmm.packings()
154                    .iter()
155                    .enumerate()
156                    .map(|(pack_ix, (a, b))| (&**mmm, pack_ix, &**a, &**b))
157            })
158            .filter_map(|(mmm, ix, a, b)| {
159                if a.same_as(weight) {
160                    Some((mmm, ix, a, None, b))
161                } else {
162                    self.panel_extractors
163                        .iter()
164                        .find(|pe| pe.from.same_as(weight) && pe.to.same_as(a))
165                        .map(|pe| (mmm, ix, a, Some(pe), b))
166                }
167            })
168            .filter(move |(_mmm, _ix, _a, _pe, b)| {
169                b.precursor().as_dt().is_some_and(|dt| dt == act)
170            })
171    }
172
173    pub fn panel_extractors(&self) -> &[mmm::panel_extract::PanelExtractor] {
174        &self.panel_extractors
175    }
176
177    pub fn mmm(
178        &self,
179        accumulator: DatumType,
180        m: Option<usize>,
181        k: Option<usize>,
182        n: Option<usize>,
183    ) -> Option<Box<dyn mmm::MatMatMul>> {
184        use DatumType::*;
185        match accumulator {
186            F64 => Some(if n == Some(1) { (self.mmv_f64)(m, k) } else { (self.mmm_f64)(m, k, n) }),
187            F32 => Some(if n == Some(1) { (self.mmv_f32)(m, k) } else { (self.mmm_f32)(m, k, n) }),
188            F16 => Some(if n == Some(1) { (self.mmv_f16)(m, k) } else { (self.mmm_f16)(m, k, n) }),
189            I32 => {
190                Some(if n == Some(1) { (self.qmmv_i32)(m, k) } else { (self.qmmm_i32)(m, k, n) })
191            }
192            _ => None,
193        }
194    }
195}
196
197pub fn generic() -> Ops {
198    use crate::generic::mmm::*;
199    use element_wise::ElementWiseKer;
200    use reduce::{MapReduceKer, ReduceKer};
201    let mut ops = Ops {
202        mmm_impls: vec![],
203        panel_extractors: vec![],
204        mmm_f64: Box::new(|_, _, _| generic_f64_4x4.mmm()),
205        mmv_f64: Box::new(|_, _| generic_f64_4x1.mmm()),
206        mmm_f32: Box::new(|_, _, _| generic_f32_4x4.mmm()),
207        mmv_f32: Box::new(|_, _| generic_f32_4x1.mmm()),
208        mmm_f16: Box::new(|_, _, _| generic_f16_4x4.mmm()),
209        mmv_f16: Box::new(|_, _| generic_f16_4x1.mmm()),
210        qmmm_i32: Box::new(|_, _, _| generic_i32_4x4.mmm()),
211        qmmv_i32: Box::new(|_, _| generic_i32_4x4.mmm()),
212        leaky_relu_f16: Box::new(|| generic::HLeakyRelu8::ew()),
213        leaky_relu_f32: Box::new(|| generic::SLeakyRelu4::ew()),
214        mul_by_scalar_f16: Box::new(|| generic::HMulByScalar8::ew()),
215        mul_by_scalar_f32: Box::new(|| generic::SMulByScalar4::ew()),
216        sigmoid_f16: Box::new(|| generic::HSigmoid8::ew()),
217        sigmoid_f32: Box::new(|| generic::SSigmoid4::ew()),
218        tanh_f16: Box::new(|| generic::HTanh8::ew()),
219        tanh_f32: Box::new(|| generic::STanh4::ew()),
220        erf_f32: Box::new(|| generic::SErf4::ew()),
221        lut_u8: Box::new(|table: &[u8]| Box::new(lut::LutImpl::<generic::GenericLut8>::new(table))),
222        max_f16: Box::new(|| generic::reduce::max::HMax8::red()),
223        max_f32: Box::new(|| generic::reduce::max::SMax4::red()),
224        sum_f16: Box::new(|| generic::reduce::sum::HSum8::red()),
225        sum_f32: Box::new(|| generic::reduce::sum::SSum4::red()),
226        /*
227        activation_f32: Box::new(|microcode| generic::SActivation::new(microcode))
228        */
229        softmax2_fastcompact_f16: Box::new(|| generic::reduce::softmax_l2::HSoftMaxL2::red()),
230        softmax2_fastcompact_f32: Box::new(|| generic::reduce::softmax_l2::SSoftMaxL2::red()),
231    };
232    crate::generic::mmm::plug(&mut ops);
233    ops
234}
235
236#[allow(unreachable_code, unused_mut, unexpected_cfgs)]
237pub fn best() -> Ops {
238    let mut ops = generic();
239    #[cfg(target_arch = "x86_64")]
240    x86_64_fma::plug(&mut ops);
241    #[cfg(any(target_arch = "arm", target_arch = "armv7"))]
242    arm32::plug(&mut ops);
243    #[cfg(target_arch = "aarch64")]
244    arm64::plug(&mut ops);
245    #[cfg(all(target_family = "wasm", target_feature = "simd128"))]
246    wasm::plug(&mut ops);
247
248    ops
249}
250
251lazy_static::lazy_static! {
252    static ref OPS: Ops = {
253        best()
254    };
255}
256
/// Binary operators with dedicated linalg kernels (see `bin_unicast` / `bin_by_scalar`).
///
/// `SubF` is the operand-swapped counterpart of `Sub` (see [`BinOp::flip`]).
/// Presumably `SubF` computes `b - a` where `Sub` computes `a - b`; confirm against
/// the kernels.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum BinOp {
    Min,
    Max,
    Add,
    Mul,
    Sub,
    SubF,
}

impl BinOp {
    /// Returns the operator to apply when the two operands are swapped.
    ///
    /// `Sub` and `SubF` map to each other. `Min`, `Max`, `Add` and `Mul` are symmetric
    /// in their operands, so they map to themselves.
    pub fn flip(&self) -> BinOp {
        // Exhaustive on purpose: a new non-commutative variant must get an explicit
        // decision here instead of silently being treated as symmetric.
        match self {
            BinOp::Sub => BinOp::SubF,
            BinOp::SubF => BinOp::Sub,
            BinOp::Min | BinOp::Max | BinOp::Add | BinOp::Mul => *self,
        }
    }
}
277
278fn register_all_unicast(registry: &mut LinalgRegistry) {
279    generic::register_all_unicast(registry);
280    #[cfg(target_arch = "aarch64")]
281    arm64::register_all_unicast(registry);
282}
283
284fn register_all_by_scalar(registry: &mut LinalgRegistry) {
285    generic::register_all_by_scalar(registry);
286    #[cfg(target_arch = "aarch64")]
287    arm64::register_all_by_scalar(registry);
288}
289
290pub type LinalgFn = dyn Fn(&mut TensorView, &TensorView) -> TractResult<()> + Send + Sync;
291type LinalgRegistry = HashMap<(BinOp, DatumType), Box<dyn Fn() -> Box<LinalgFn> + Send + Sync>>;
292lazy_static! {
293    static ref BIN_UNICAST_OPS: Mutex<LinalgRegistry> = {
294        let mut registry = HashMap::default();
295        register_all_unicast(&mut registry);
296        Mutex::new(registry)
297    };
298    static ref BIN_BY_SCALAR_OPS: Mutex<LinalgRegistry> = {
299        let mut registry = HashMap::default();
300        register_all_by_scalar(&mut registry);
301        Mutex::new(registry)
302    };
303}
304
305pub fn bin_by_scalar(dt: DatumType, bin: BinOp) -> Option<Box<LinalgFn>> {
306    let map = BIN_BY_SCALAR_OPS.lock().unwrap();
307    if (dt == DatumType::F16) && !has_fp16() {
308        return None;
309    }
310    map.get(&(bin, dt)).map(|it| (it)())
311}
312
313pub fn bin_unicast(dt: DatumType, bin: BinOp) -> Option<Box<LinalgFn>> {
314    let map = BIN_UNICAST_OPS.lock().unwrap();
315    if (dt == DatumType::F16) && !has_fp16() {
316        return None;
317    }
318    map.get(&(bin, dt)).map(|it| (it)())
319}
320
321pub fn ops() -> &'static Ops {
322    &OPS
323}
324
325use num_traits::*;
326use std::collections::HashMap;
327use std::fmt::Debug;
328use std::ops::*;
329use std::sync::Mutex;
330
331pub trait LADatum:
332    Sized
333    + std::fmt::Display
334    + Debug
335    + Copy
336    + Clone
337    + Zero
338    + One
339    + 'static
340    + Add<Output = Self>
341    + Sub<Output = Self>
342    + Mul
343    + AddAssign
344    + PartialOrd
345    + Bounded
346    + tract_data::prelude::Datum
347{
348    #[cfg(test)]
349    fn strat() -> proptest::prelude::BoxedStrategy<Self>;
350}
351
352#[cfg(test)]
353use proptest::prelude::*;
354
355impl LADatum for f16 {
356    #[cfg(test)]
357    fn strat() -> BoxedStrategy<Self> {
358        f32::strat().prop_map(|f| f.as_()).boxed()
359    }
360}
361
362impl LADatum for f32 {
363    #[cfg(test)]
364    fn strat() -> BoxedStrategy<Self> {
365        (-1000isize..1000).prop_map(|i| i as f32 / 1000.0).boxed()
366    }
367}
368
369impl LADatum for f64 {
370    #[cfg(test)]
371    fn strat() -> BoxedStrategy<Self> {
372        (-1000isize..1000).prop_map(|i| i as f64 / 1000.0).boxed()
373    }
374}
375
376impl LADatum for u8 {
377    #[cfg(test)]
378    fn strat() -> BoxedStrategy<Self> {
379        any::<u8>().boxed()
380    }
381}
382
383impl LADatum for i8 {
384    #[cfg(test)]
385    fn strat() -> BoxedStrategy<Self> {
386        any::<i8>().boxed()
387    }
388}
389
390impl LADatum for i32 {
391    #[cfg(test)]
392    fn strat() -> BoxedStrategy<Self> {
393        any::<i32>().boxed()
394    }
395}
396
/// Installs `env_logger` for tests, configured from the `TRACT_LOG` environment
/// variable.
///
/// Initialization failures are deliberately ignored, so calling this from several
/// tests is harmless.
#[cfg(test)]
#[allow(dead_code)]
fn setup_test_logger() {
    env_logger::Builder::from_env("TRACT_LOG").try_init().ok();
}