tract_linalg/
lib.rs

1#![allow(clippy::missing_safety_doc)]
2#![allow(clippy::redundant_closure_call)]
3#![allow(clippy::len_zero)]
4#![allow(clippy::excessive_precision)]
5#![allow(clippy::approx_constant)]
6#![allow(unexpected_cfgs)]
7#![allow(unused_macros)]
8#[macro_use]
9extern crate derive_new;
10extern crate lazy_static;
11extern crate log;
12extern crate num_traits;
13#[macro_use]
14extern crate paste;
15#[cfg(test)]
16extern crate proptest;
17
18include!(concat!(env!("OUT_DIR"), "/extern_kernel_macro.rs"));
19
20#[macro_use]
21pub mod frame;
22pub mod generic;
23pub mod multithread;
24use frame::by_scalar::ByScalarKer;
25use frame::element_wise::ElementWiseKer;
26use frame::mmm::panel_extract::PanelExtractor;
27use frame::mmm::{MMMInputFormat, MMMKit, WeightType};
28use frame::reduce::{MapReduceKer, ReduceKer};
29use frame::unicast::UnicastKer;
30use frame::{reduce, MatMatMul};
31pub use generic::{ScaleShiftAndRound, Scaler};
32use lazy_static::lazy_static;
33use tract_data::internal::TensorView;
34#[cfg(target_arch = "x86_64")]
35pub mod x86_64_fma;
36
37#[cfg(target_arch = "aarch64")]
38pub mod arm64;
39
40#[cfg(target_arch = "aarch64")]
41pub use arm64::has_fp16;
42use tract_itertools::Itertools;
43
/// Reports whether half-precision (fp16) kernels may be used.
///
/// This definition is only compiled when the target architecture is not
/// `aarch64`, and it always answers `false`. On `aarch64` the same name is
/// re-exported from the `arm64` module instead.
#[cfg(not(target_arch = "aarch64"))]
pub fn has_fp16() -> bool {
    false
}
48
// 32-bit ARM kernels. The predicate is the same one that guards the
// `arm32::plug` call in `best()`; the duplicated `target_arch = "arm"`
// alternative was redundant and has been removed.
#[cfg(any(target_arch = "arm", target_arch = "armv7"))]
pub mod arm32;
51
52#[cfg(all(target_family = "wasm", target_feature = "simd128"))]
53pub mod wasm;
54
55pub use self::frame::{element_wise, lut, mmm};
56
57use tract_data::prelude::*;
58
59pub type MMMImpl = Box<
60    dyn Fn(Option<usize>, Option<usize>, Option<usize>) -> Box<dyn mmm::MatMatMul> + Send + Sync,
61>;
62
63type MMVImpl = Box<dyn Fn(Option<usize>, Option<usize>) -> Box<dyn mmm::MatMatMul> + Send + Sync>;
64
65#[allow(clippy::type_complexity)]
66pub struct Ops {
67    mmm_impls: Vec<Box<dyn MatMatMul>>,
68    panel_extractors: Vec<PanelExtractor>,
69    mmm_kits: Vec<MMMKit>,
70    // default_kit: Box<dyn Fn(WeightType) -> Box<dyn MMMInputFormat>>,
71    mmm_f64: MMMImpl,
72    mmv_f64: MMVImpl,
73
74    mmm_f32: MMMImpl,
75    mmv_f32: MMVImpl,
76
77    mmm_f16: MMMImpl,
78    mmv_f16: MMVImpl,
79
80    qmmm_i32: MMMImpl,
81    qmmv_i32: MMVImpl,
82
83    pub leaky_relu_f16: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f16, f16>> + Send + Sync>,
84    pub leaky_relu_f32: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32, f32>> + Send + Sync>,
85    pub mul_by_scalar_f32:
86        Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32, f32>> + Send + Sync>,
87    pub mul_by_scalar_f16:
88        Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f16, f16>> + Send + Sync>,
89
90    pub sigmoid_f16: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f16>> + Send + Sync>,
91    pub sigmoid_f32: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32>> + Send + Sync>,
92    pub tanh_f16: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f16>> + Send + Sync>,
93    pub tanh_f32: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32>> + Send + Sync>,
94    pub erf_f32: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32>> + Send + Sync>,
95    pub lut_u8: Box<dyn Fn(&[u8]) -> Box<dyn lut::Lut> + Send + Sync>,
96
97    pub max_f16: Box<dyn Fn() -> Box<dyn reduce::Reduce<f16>> + Send + Sync>,
98    pub max_f32: Box<dyn Fn() -> Box<dyn reduce::Reduce<f32>> + Send + Sync>,
99
100    pub sum_f16: Box<dyn Fn() -> Box<dyn reduce::Reduce<f16>> + Send + Sync>,
101    pub sum_f32: Box<dyn Fn() -> Box<dyn reduce::Reduce<f32>> + Send + Sync>,
102
103    pub softmax2_fastcompact_f16:
104        Box<dyn Fn() -> Box<dyn reduce::MapReduce<f16, f16>> + Send + Sync>,
105    pub softmax2_fastcompact_f32:
106        Box<dyn Fn() -> Box<dyn reduce::MapReduce<f32, f32>> + Send + Sync>,
107}
108
109impl Ops {
110    pub fn mmm_impls(&self) -> &[Box<dyn MatMatMul>] {
111        &self.mmm_impls
112    }
113
114    pub fn mmm_kits(&self) -> &[MMMKit] {
115        &self.mmm_kits
116    }
117
118    pub fn kit_input_format(&self, w: WeightType) -> Box<dyn MMMInputFormat> {
119        self.mmm_kits
120            .iter()
121            .filter(|kit| kit.weight == w)
122            .sorted_by_key(|kit| kit.generic_fallback as usize)
123            .next()
124            .unwrap()
125            .static_packer
126            .clone()
127    }
128
129    pub fn panel_extractors(&self) -> &[PanelExtractor] {
130        &self.panel_extractors
131    }
132
133    pub fn mmm(
134        &self,
135        accumulator: DatumType,
136        m: Option<usize>,
137        k: Option<usize>,
138        n: Option<usize>,
139    ) -> Option<Box<dyn mmm::MatMatMul>> {
140        use DatumType::*;
141        match accumulator {
142            F64 => Some(if n == Some(1) { (self.mmv_f64)(m, k) } else { (self.mmm_f64)(m, k, n) }),
143            F32 => Some(if n == Some(1) { (self.mmv_f32)(m, k) } else { (self.mmm_f32)(m, k, n) }),
144            F16 => Some(if n == Some(1) { (self.mmv_f16)(m, k) } else { (self.mmm_f16)(m, k, n) }),
145            I32 => {
146                Some(if n == Some(1) { (self.qmmv_i32)(m, k) } else { (self.qmmm_i32)(m, k, n) })
147            }
148            _ => None,
149        }
150    }
151}
152
153pub fn generic() -> Ops {
154    use crate::generic::mmm::*;
155    let mut ops = Ops {
156        mmm_impls: vec![generic_f32_4x4.mmm(), generic_f32_4x1.mmm()],
157        mmm_kits: vec![],
158        panel_extractors: vec![],
159        mmm_f64: Box::new(|_, _, _| generic_f64_4x4.mmm()),
160        mmv_f64: Box::new(|_, _| generic_f64_4x1.mmm()),
161        mmm_f32: Box::new(|_, _, _| generic_f32_4x4.mmm()),
162        mmv_f32: Box::new(|_, _| generic_f32_4x1.mmm()),
163        mmm_f16: Box::new(|_, _, _| generic_f16_4x4.mmm()),
164        mmv_f16: Box::new(|_, _| generic_f16_4x1.mmm()),
165        qmmm_i32: Box::new(|_, _, _| generic_i32_4x4.mmm()),
166        qmmv_i32: Box::new(|_, _| generic_i32_4x4.mmm()),
167        leaky_relu_f16: Box::new(|| generic::HLeakyRelu8::ew()),
168        leaky_relu_f32: Box::new(|| generic::SLeakyRelu4::ew()),
169        mul_by_scalar_f16: Box::new(|| generic::HMulByScalar8::ew()),
170        mul_by_scalar_f32: Box::new(|| generic::SMulByScalar4::ew()),
171        sigmoid_f16: Box::new(|| generic::HSigmoid8::ew()),
172        sigmoid_f32: Box::new(|| generic::SSigmoid4::ew()),
173        tanh_f16: Box::new(|| generic::HTanh8::ew()),
174        tanh_f32: Box::new(|| generic::STanh4::ew()),
175        erf_f32: Box::new(|| generic::SErf4::ew()),
176        lut_u8: Box::new(|table: &[u8]| Box::new(lut::LutImpl::<generic::GenericLut8>::new(table))),
177        max_f16: Box::new(|| generic::reduce::max::HMax8::red()),
178        max_f32: Box::new(|| generic::reduce::max::SMax4::red()),
179        sum_f16: Box::new(|| generic::reduce::sum::HSum8::red()),
180        sum_f32: Box::new(|| generic::reduce::sum::SSum4::red()),
181        /*
182        activation_f32: Box::new(|microcode| generic::SActivation::new(microcode))
183        */
184        softmax2_fastcompact_f16: Box::new(|| generic::reduce::softmax_l2::HSoftMaxL2::red()),
185        softmax2_fastcompact_f32: Box::new(|| generic::reduce::softmax_l2::SSoftMaxL2::red()),
186    };
187    crate::generic::mmm::plug(&mut ops);
188    ops
189}
190
191#[allow(unreachable_code, unused_mut, unexpected_cfgs)]
192pub fn best() -> Ops {
193    let mut ops = generic();
194    #[cfg(target_arch = "x86_64")]
195    x86_64_fma::plug(&mut ops);
196    #[cfg(any(target_arch = "arm", target_arch = "armv7"))]
197    arm32::plug(&mut ops);
198    #[cfg(target_arch = "aarch64")]
199    arm64::plug(&mut ops);
200    #[cfg(all(target_family = "wasm", target_feature = "simd128"))]
201    wasm::plug(&mut ops);
202
203    ops
204}
205
206lazy_static::lazy_static! {
207    static ref OPS: Ops = {
208        best()
209    };
210}
211
/// Binary operations for which linalg kernels can be registered.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum BinOp {
    Min,
    Max,
    Add,
    Mul,
    Sub,
    /// Counterpart of `Sub` under [`BinOp::flip`]; presumably subtraction
    /// with its operands swapped (confirm against the kernels).
    SubF,
}

impl BinOp {
    /// Returns the operation to use when the two operands are exchanged.
    ///
    /// `Sub` and `SubF` map to each other; every other variant maps to itself.
    pub fn flip(&self) -> BinOp {
        match *self {
            BinOp::Sub => BinOp::SubF,
            BinOp::SubF => BinOp::Sub,
            BinOp::Min | BinOp::Max | BinOp::Add | BinOp::Mul => *self,
        }
    }
}
232
233fn register_all_unicast(registry: &mut LinalgRegistry) {
234    generic::register_all_unicast(registry);
235    #[cfg(target_arch = "aarch64")]
236    arm64::register_all_unicast(registry);
237}
238
239fn register_all_by_scalar(registry: &mut LinalgRegistry) {
240    generic::register_all_by_scalar(registry);
241    #[cfg(target_arch = "aarch64")]
242    arm64::register_all_by_scalar(registry);
243}
244
245pub type LinalgFn = dyn Fn(&mut TensorView, &TensorView) -> TractResult<()> + Send + Sync;
246type LinalgRegistry = HashMap<(BinOp, DatumType), Box<dyn Fn() -> Box<LinalgFn> + Send + Sync>>;
247lazy_static! {
248    static ref BIN_UNICAST_OPS: Mutex<LinalgRegistry> = {
249        let mut registry = HashMap::default();
250        register_all_unicast(&mut registry);
251        Mutex::new(registry)
252    };
253    static ref BIN_BY_SCALAR_OPS: Mutex<LinalgRegistry> = {
254        let mut registry = HashMap::default();
255        register_all_by_scalar(&mut registry);
256        Mutex::new(registry)
257    };
258}
259
260pub fn bin_by_scalar(dt: DatumType, bin: BinOp) -> Option<Box<LinalgFn>> {
261    let map = BIN_BY_SCALAR_OPS.lock().unwrap();
262    if (dt == DatumType::F16) && !has_fp16() {
263        return None;
264    }
265    map.get(&(bin, dt)).map(|it| (it)())
266}
267
268pub fn bin_unicast(dt: DatumType, bin: BinOp) -> Option<Box<LinalgFn>> {
269    let map = BIN_UNICAST_OPS.lock().unwrap();
270    if (dt == DatumType::F16) && !has_fp16() {
271        return None;
272    }
273    map.get(&(bin, dt)).map(|it| (it)())
274}
275
276pub fn ops() -> &'static Ops {
277    &OPS
278}
279
280use num_traits::*;
281use std::collections::HashMap;
282use std::fmt::Debug;
283use std::ops::*;
284use std::sync::Mutex;
285
286pub trait LADatum:
287    Sized
288    + std::fmt::Display
289    + Debug
290    + Copy
291    + Clone
292    + Zero
293    + One
294    + 'static
295    + Add<Output = Self>
296    + Sub<Output = Self>
297    + Mul
298    + AddAssign
299    + PartialOrd
300    + Bounded
301    + tract_data::prelude::Datum
302{
303    #[cfg(test)]
304    fn strat() -> proptest::prelude::BoxedStrategy<Self>;
305}
306
307#[cfg(test)]
308use proptest::prelude::*;
309
310impl LADatum for f16 {
311    #[cfg(test)]
312    fn strat() -> BoxedStrategy<Self> {
313        f32::strat().prop_map(|f| f.as_()).boxed()
314    }
315}
316
317impl LADatum for f32 {
318    #[cfg(test)]
319    fn strat() -> BoxedStrategy<Self> {
320        (-1000isize..1000).prop_map(|i| i as f32 / 1000.0).boxed()
321    }
322}
323
324impl LADatum for f64 {
325    #[cfg(test)]
326    fn strat() -> BoxedStrategy<Self> {
327        (-1000isize..1000).prop_map(|i| i as f64 / 1000.0).boxed()
328    }
329}
330
331impl LADatum for u8 {
332    #[cfg(test)]
333    fn strat() -> BoxedStrategy<Self> {
334        any::<u8>().boxed()
335    }
336}
337
338impl LADatum for i8 {
339    #[cfg(test)]
340    fn strat() -> BoxedStrategy<Self> {
341        any::<i8>().boxed()
342    }
343}
344
345impl LADatum for i32 {
346    #[cfg(test)]
347    fn strat() -> BoxedStrategy<Self> {
348        any::<i32>().boxed()
349    }
350}
351
/// Installs an `env_logger` logger configured from the `TRACT_LOG`
/// environment variable (test builds only).
#[cfg(test)]
#[allow(dead_code)]
fn setup_test_logger() {
    let mut builder = env_logger::Builder::from_env("TRACT_LOG");
    // The result is deliberately discarded; presumably the only failure is a
    // logger having been installed already by another test.
    let _ = builder.try_init();
}