Skip to main content

tract_linalg/
lib.rs

1#![allow(clippy::missing_safety_doc)]
2#![allow(clippy::redundant_closure_call)]
3#![allow(clippy::len_zero)]
4#![allow(clippy::excessive_precision)]
5#![allow(clippy::approx_constant)]
6#![allow(clippy::manual_is_multiple_of)]
7#![allow(unexpected_cfgs)]
8#![allow(unused_macros)]
9#[macro_use]
10extern crate derive_new;
11extern crate lazy_static;
12extern crate log;
13extern crate num_traits;
14#[macro_use]
15extern crate pastey;
16#[cfg(test)]
17extern crate proptest;
18
19include!(concat!(env!("OUT_DIR"), "/extern_kernel_macro.rs"));
20
21#[macro_use]
22mod frame;
23pub mod generic;
24pub mod multithread;
25pub use frame::weights::WeightType;
26pub use generic::{ScaleShiftAndRound, Scaler};
27use lazy_static::lazy_static;
28use mmm::{MMMInputFormat, MatMatMul, PanelExtractor};
29use tract_data::internal::TensorView;
30#[cfg(target_arch = "x86_64")]
31pub mod x86_64_fma;
32
33pub mod hwbench;
34
35#[cfg(target_arch = "aarch64")]
36pub mod arm64;
37
38#[cfg(target_arch = "aarch64")]
39pub use arm64::has_fp16;
40use tract_itertools::Itertools;
41
/// Fallback used on every target that is not `aarch64`: no fp16 support is
/// ever reported there. On `aarch64` the `arm64` module's `has_fp16` is
/// re-exported in its place.
#[cfg(not(target_arch = "aarch64"))]
pub fn has_fp16() -> bool {
    false
}
46
// 32-bit ARM kernels. The gate mirrors the one guarding `arm32::plug` in
// `best()`; the previous predicate listed `target_arch = "arm"` twice.
#[cfg(any(target_arch = "arm", target_arch = "armv7"))]
pub mod arm32;
49
50#[cfg(all(target_family = "wasm", target_feature = "simd128"))]
51pub mod wasm;
52
53pub use self::frame::*;
54
55use tract_data::prelude::*;
56
57pub type MMMImpl = Box<
58    dyn Fn(Option<usize>, Option<usize>, Option<usize>) -> Box<dyn mmm::MatMatMul> + Send + Sync,
59>;
60
61type MMVImpl = Box<dyn Fn(Option<usize>, Option<usize>) -> Box<dyn mmm::MatMatMul> + Send + Sync>;
62
63#[allow(clippy::type_complexity)]
64pub struct Ops {
65    mmm_impls: Vec<Box<dyn mmm::MatMatMul>>,
66    panel_extractors: Vec<mmm::PanelExtractor>,
67
68    mmm_f64: MMMImpl,
69    mmv_f64: MMVImpl,
70
71    mmm_f32: MMMImpl,
72    mmv_f32: MMVImpl,
73
74    mmm_f16: MMMImpl,
75    mmv_f16: MMVImpl,
76
77    qmmm_i32: MMMImpl,
78    qmmv_i32: MMVImpl,
79
80    pub leaky_relu_f16: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f16, f16>> + Send + Sync>,
81    pub leaky_relu_f32: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32, f32>> + Send + Sync>,
82    pub mul_by_scalar_f32:
83        Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32, f32>> + Send + Sync>,
84    pub mul_by_scalar_f16:
85        Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f16, f16>> + Send + Sync>,
86
87    pub sigmoid_f16: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f16>> + Send + Sync>,
88    pub sigmoid_f32: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32>> + Send + Sync>,
89    pub tanh_f16: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f16>> + Send + Sync>,
90    pub tanh_f32: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32>> + Send + Sync>,
91    pub erf_f32: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32>> + Send + Sync>,
92    pub hardswish_f16: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f16>> + Send + Sync>,
93    pub hardswish_f32: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32>> + Send + Sync>,
94    pub silu_f16: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f16>> + Send + Sync>,
95    pub silu_f32: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32>> + Send + Sync>,
96    pub gelu_f16: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f16>> + Send + Sync>,
97    pub gelu_f32: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32>> + Send + Sync>,
98    pub lut_u8: Box<dyn Fn(&[u8]) -> Box<dyn lut::Lut> + Send + Sync>,
99
100    pub max_f16: Box<dyn Fn() -> Box<dyn reduce::Reduce<f16>> + Send + Sync>,
101    pub max_f32: Box<dyn Fn() -> Box<dyn reduce::Reduce<f32>> + Send + Sync>,
102
103    pub sum_f16: Box<dyn Fn() -> Box<dyn reduce::Reduce<f16>> + Send + Sync>,
104    pub sum_f32: Box<dyn Fn() -> Box<dyn reduce::Reduce<f32>> + Send + Sync>,
105
106    pub softmax2_fastcompact_f16:
107        Box<dyn Fn() -> Box<dyn reduce::MapReduce<f16, f16>> + Send + Sync>,
108    pub softmax2_fastcompact_f32:
109        Box<dyn Fn() -> Box<dyn reduce::MapReduce<f32, f32>> + Send + Sync>,
110
111    /// Fused row-wise RmsNorm: out_i = x_i * rsqrt(mean(x_i²) + eps).
112    /// Replaces a 4-call composition (MeanOfSquares + Add + Rsqrt + Mul) with
113    /// a single 2-pass kernel. Called once per row by `core::ops::nn::RmsNorm`
114    /// when the input is f32 and the axis is the last (contiguous) one.
115    pub rms_norm_f32: Box<dyn Fn(&mut [f32], f32) + Send + Sync>,
116}
117
118impl Ops {
119    pub fn mmm_impls(&self) -> &[Box<dyn mmm::MatMatMul>] {
120        &self.mmm_impls
121    }
122
123    pub fn all_possible_packing(
124        &self,
125        weight_type: impl Into<WeightType>,
126    ) -> impl Iterator<Item = &dyn MMMInputFormat> {
127        let weight_type = weight_type.into();
128        self.mmm_impls
129            .iter()
130            .flat_map(|m| m.packings())
131            .map(|p| &*p.0)
132            .flat_map(move |p| {
133                let mut packs: Vec<&dyn MMMInputFormat> = vec![];
134                if p.precursor() == weight_type {
135                    packs.push(p)
136                };
137                for pe in &self.panel_extractors {
138                    if pe.from.precursor() == weight_type && pe.to.dyn_eq(p) {
139                        packs.push(&*pe.from);
140                    }
141                }
142                packs.into_iter()
143            })
144            .sorted_by_key(|p| p.to_string())
145            .dedup()
146    }
147
148    pub fn filter_impls<'o>(
149        &'o self,
150        weight: &'o dyn MMMInputFormat,
151        acc: &[DatumType],
152        act: DatumType,
153        store: DatumType,
154    ) -> impl Iterator<
155        Item = (
156            &'o dyn MatMatMul,
157            usize,
158            &'o dyn MMMInputFormat,
159            Option<&'o PanelExtractor>,
160            &'o dyn MMMInputFormat,
161        ),
162    > {
163        let acc = acc.to_vec();
164        self.mmm_impls
165            .iter()
166            .filter(move |mmm| acc.contains(&mmm.internal_type()) && mmm.stores().contains(&store))
167            .flat_map(|mmm| {
168                mmm.packings()
169                    .iter()
170                    .enumerate()
171                    .map(|(pack_ix, (a, b))| (&**mmm, pack_ix, &**a, &**b))
172            })
173            .filter_map(|(mmm, ix, a, b)| {
174                if a.dyn_eq(weight) {
175                    Some((mmm, ix, a, None, b))
176                } else {
177                    self.panel_extractors
178                        .iter()
179                        .find(|pe| pe.from.dyn_eq(weight) && pe.to.dyn_eq(a))
180                        .map(|pe| (mmm, ix, a, Some(pe), b))
181                }
182            })
183            .filter(move |(_mmm, _ix, _a, _pe, b)| {
184                b.precursor().as_dt().is_some_and(|dt| dt == act)
185            })
186    }
187
188    pub fn panel_extractors(&self) -> &[mmm::panel_extract::PanelExtractor] {
189        &self.panel_extractors
190    }
191
192    pub fn mmm(
193        &self,
194        accumulator: DatumType,
195        m: Option<usize>,
196        k: Option<usize>,
197        n: Option<usize>,
198    ) -> Option<Box<dyn mmm::MatMatMul>> {
199        use DatumType::*;
200        match accumulator {
201            F64 => Some(if n == Some(1) { (self.mmv_f64)(m, k) } else { (self.mmm_f64)(m, k, n) }),
202            F32 => Some(if n == Some(1) { (self.mmv_f32)(m, k) } else { (self.mmm_f32)(m, k, n) }),
203            F16 => Some(if n == Some(1) { (self.mmv_f16)(m, k) } else { (self.mmm_f16)(m, k, n) }),
204            I32 => {
205                Some(if n == Some(1) { (self.qmmv_i32)(m, k) } else { (self.qmmm_i32)(m, k, n) })
206            }
207            _ => None,
208        }
209    }
210}
211
212pub fn generic() -> Ops {
213    use crate::generic::mmm::*;
214    use element_wise::ElementWiseKer;
215    use reduce::{MapReduceKer, ReduceKer};
216    let mut ops = Ops {
217        mmm_impls: vec![],
218        panel_extractors: vec![],
219        mmm_f64: Box::new(|_, _, _| generic_f64_4x4.mmm()),
220        mmv_f64: Box::new(|_, _| generic_f64_4x1.mmm()),
221        mmm_f32: Box::new(|_, _, _| generic_f32_4x4.mmm()),
222        mmv_f32: Box::new(|_, _| generic_f32_4x1.mmm()),
223        mmm_f16: Box::new(|_, _, _| generic_f16_4x4.mmm()),
224        mmv_f16: Box::new(|_, _| generic_f16_4x1.mmm()),
225        qmmm_i32: Box::new(|_, _, _| generic_i32_4x4.mmm()),
226        qmmv_i32: Box::new(|_, _| generic_i32_4x4.mmm()),
227        leaky_relu_f16: Box::new(|| generic::HLeakyRelu8::ew()),
228        leaky_relu_f32: Box::new(|| generic::SLeakyRelu4::ew()),
229        mul_by_scalar_f16: Box::new(|| generic::HMulByScalar8::ew()),
230        mul_by_scalar_f32: Box::new(|| generic::SMulByScalar4::ew()),
231        sigmoid_f16: Box::new(|| generic::HSigmoid8::ew()),
232        sigmoid_f32: Box::new(|| generic::SSigmoid4::ew()),
233        tanh_f16: Box::new(|| generic::HTanh8::ew()),
234        tanh_f32: Box::new(|| generic::STanh4::ew()),
235        erf_f32: Box::new(|| generic::SErf4::ew()),
236        hardswish_f16: Box::new(|| generic::HHardSwish8::ew()),
237        hardswish_f32: Box::new(|| generic::SHardSwish4::ew()),
238        silu_f16: Box::new(|| generic::HSiLU8::ew()),
239        silu_f32: Box::new(|| generic::SSiLU4::ew()),
240        gelu_f16: Box::new(|| generic::HGelu8::ew()),
241        gelu_f32: Box::new(|| generic::SGelu4::ew()),
242        lut_u8: Box::new(|table: &[u8]| Box::new(lut::LutImpl::<generic::GenericLut8>::new(table))),
243        max_f16: Box::new(|| generic::reduce::max::HMax8::red()),
244        max_f32: Box::new(|| generic::reduce::max::SMax4::red()),
245        sum_f16: Box::new(|| generic::reduce::sum::HSum8::red()),
246        sum_f32: Box::new(|| generic::reduce::sum::SSum4::red()),
247        /*
248        activation_f32: Box::new(|microcode| generic::SActivation::new(microcode))
249        */
250        softmax2_fastcompact_f16: Box::new(|| generic::reduce::softmax_l2::HSoftMaxL2::red()),
251        softmax2_fastcompact_f32: Box::new(|| generic::reduce::softmax_l2::SSoftMaxL2::red()),
252        rms_norm_f32: Box::new(generic::rms_norm::rms_norm_f32),
253    };
254    crate::generic::mmm::plug(&mut ops);
255    ops
256}
257
258#[allow(unreachable_code, unused_mut, unexpected_cfgs)]
259pub fn best() -> Ops {
260    let mut ops = generic();
261    #[cfg(target_arch = "x86_64")]
262    x86_64_fma::plug(&mut ops);
263    #[cfg(any(target_arch = "arm", target_arch = "armv7"))]
264    arm32::plug(&mut ops);
265    #[cfg(target_arch = "aarch64")]
266    arm64::plug(&mut ops);
267    #[cfg(all(target_family = "wasm", target_feature = "simd128"))]
268    wasm::plug(&mut ops);
269
270    ops
271}
272
273lazy_static::lazy_static! {
274    static ref OPS: Ops = {
275        best()
276    };
277}
278
/// Binary operations addressable in the unicast / by-scalar registries.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum BinOp {
    Min,
    Max,
    Add,
    Mul,
    Sub,
    SubF,
}

impl BinOp {
    /// Returns the operation to use once the two operands are exchanged:
    /// `Sub` and `SubF` map to each other, every other variant maps to itself.
    pub fn flip(&self) -> BinOp {
        match *self {
            BinOp::Sub => BinOp::SubF,
            BinOp::SubF => BinOp::Sub,
            unchanged @ (BinOp::Min | BinOp::Max | BinOp::Add | BinOp::Mul) => unchanged,
        }
    }
}
299
300fn register_all_unicast(registry: &mut LinalgRegistry) {
301    generic::register_all_unicast(registry);
302    #[cfg(target_arch = "aarch64")]
303    arm64::register_all_unicast(registry);
304}
305
306fn register_all_by_scalar(registry: &mut LinalgRegistry) {
307    generic::register_all_by_scalar(registry);
308    #[cfg(target_arch = "aarch64")]
309    arm64::register_all_by_scalar(registry);
310}
311
312pub type LinalgFn = dyn Fn(&mut TensorView, &TensorView) -> TractResult<()> + Send + Sync;
313type LinalgRegistry = HashMap<(BinOp, DatumType), Box<dyn Fn() -> Box<LinalgFn> + Send + Sync>>;
314lazy_static! {
315    static ref BIN_UNICAST_OPS: Mutex<LinalgRegistry> = {
316        let mut registry = HashMap::default();
317        register_all_unicast(&mut registry);
318        Mutex::new(registry)
319    };
320    static ref BIN_BY_SCALAR_OPS: Mutex<LinalgRegistry> = {
321        let mut registry = HashMap::default();
322        register_all_by_scalar(&mut registry);
323        Mutex::new(registry)
324    };
325}
326
327pub fn bin_by_scalar(dt: DatumType, bin: BinOp) -> Option<Box<LinalgFn>> {
328    let map = BIN_BY_SCALAR_OPS.lock().unwrap();
329    if (dt == DatumType::F16) && !has_fp16() {
330        return None;
331    }
332    map.get(&(bin, dt)).map(|it| (it)())
333}
334
335pub fn bin_unicast(dt: DatumType, bin: BinOp) -> Option<Box<LinalgFn>> {
336    let map = BIN_UNICAST_OPS.lock().unwrap();
337    if (dt == DatumType::F16) && !has_fp16() {
338        return None;
339    }
340    map.get(&(bin, dt)).map(|it| (it)())
341}
342
343pub fn ops() -> &'static Ops {
344    &OPS
345}
346
347use dyn_eq::DynEq;
348use num_traits::*;
349use std::collections::HashMap;
350use std::fmt::Debug;
351use std::ops::*;
352use std::sync::Mutex;
353
354pub trait LADatum:
355    Sized
356    + std::fmt::Display
357    + Debug
358    + Copy
359    + Clone
360    + Zero
361    + One
362    + 'static
363    + Add<Output = Self>
364    + Sub<Output = Self>
365    + Mul
366    + AddAssign
367    + PartialOrd
368    + Bounded
369    + tract_data::prelude::Datum
370{
371    #[cfg(test)]
372    fn strat() -> proptest::prelude::BoxedStrategy<Self>;
373}
374
375#[cfg(test)]
376use proptest::prelude::*;
377
378impl LADatum for f16 {
379    #[cfg(test)]
380    fn strat() -> BoxedStrategy<Self> {
381        f32::strat().prop_map(|f| f.as_()).boxed()
382    }
383}
384
385impl LADatum for f32 {
386    #[cfg(test)]
387    fn strat() -> BoxedStrategy<Self> {
388        (-1000isize..1000).prop_map(|i| i as f32 / 1000.0).boxed()
389    }
390}
391
392impl LADatum for f64 {
393    #[cfg(test)]
394    fn strat() -> BoxedStrategy<Self> {
395        (-1000isize..1000).prop_map(|i| i as f64 / 1000.0).boxed()
396    }
397}
398
399impl LADatum for u8 {
400    #[cfg(test)]
401    fn strat() -> BoxedStrategy<Self> {
402        any::<u8>().boxed()
403    }
404}
405
406impl LADatum for i8 {
407    #[cfg(test)]
408    fn strat() -> BoxedStrategy<Self> {
409        any::<i8>().boxed()
410    }
411}
412
413impl LADatum for i32 {
414    #[cfg(test)]
415    fn strat() -> BoxedStrategy<Self> {
416        any::<i32>().boxed()
417    }
418}
419
/// Test helper: initialises `env_logger` from the `TRACT_LOG` environment
/// variable. The outcome of `try_init` is deliberately discarded.
#[cfg(test)]
#[allow(dead_code)]
fn setup_test_logger() {
    let mut builder = env_logger::Builder::from_env("TRACT_LOG");
    let _ = builder.try_init();
}
424}