1#![allow(clippy::missing_safety_doc)]
2#![allow(clippy::redundant_closure_call)]
3#![allow(clippy::len_zero)]
4#![allow(clippy::excessive_precision)]
5#![allow(clippy::approx_constant)]
6#![allow(unexpected_cfgs)]
7#![allow(unused_macros)]
8#[macro_use]
9extern crate derive_new;
10extern crate lazy_static;
11extern crate log;
12extern crate num_traits;
13#[macro_use]
14extern crate paste;
15#[cfg(test)]
16extern crate proptest;
17
18include!(concat!(env!("OUT_DIR"), "/extern_kernel_macro.rs"));
19
20#[macro_use]
21mod frame;
22pub mod generic;
23pub mod multithread;
24pub use frame::weights::WeightType;
25pub use generic::{ScaleShiftAndRound, Scaler};
26use lazy_static::lazy_static;
27use mmm::{MMMInputFormat, MatMatMul, PanelExtractor};
28use tract_data::internal::TensorView;
29#[cfg(target_arch = "x86_64")]
30pub mod x86_64_fma;
31
32#[cfg(target_arch = "aarch64")]
33pub mod arm64;
34
35#[cfg(target_arch = "aarch64")]
36pub use arm64::has_fp16;
37use tract_itertools::Itertools;
38
/// Fallback for every target other than `aarch64`, where `arm64::has_fp16` is
/// re-exported under this name instead.
///
/// Always returns `false`.
#[cfg(not(target_arch = "aarch64"))]
pub fn has_fp16() -> bool {
    false
}
43
// 32-bit ARM kernels. The predicate mirrors the one guarding `arm32::plug` in
// `best()`; the duplicated `target_arch = "arm"` alternative was redundant.
#[cfg(any(target_arch = "arm", target_arch = "armv7"))]
pub mod arm32;
46
47#[cfg(all(target_family = "wasm", target_feature = "simd128"))]
48pub mod wasm;
49
50pub use self::frame::*;
51
52use tract_data::prelude::*;
53
54pub type MMMImpl = Box<
55 dyn Fn(Option<usize>, Option<usize>, Option<usize>) -> Box<dyn mmm::MatMatMul> + Send + Sync,
56>;
57
58type MMVImpl = Box<dyn Fn(Option<usize>, Option<usize>) -> Box<dyn mmm::MatMatMul> + Send + Sync>;
59
60#[allow(clippy::type_complexity)]
61pub struct Ops {
62 mmm_impls: Vec<Box<dyn mmm::MatMatMul>>,
63 panel_extractors: Vec<mmm::PanelExtractor>,
64
65 mmm_f64: MMMImpl,
66 mmv_f64: MMVImpl,
67
68 mmm_f32: MMMImpl,
69 mmv_f32: MMVImpl,
70
71 mmm_f16: MMMImpl,
72 mmv_f16: MMVImpl,
73
74 qmmm_i32: MMMImpl,
75 qmmv_i32: MMVImpl,
76
77 pub leaky_relu_f16: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f16, f16>> + Send + Sync>,
78 pub leaky_relu_f32: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32, f32>> + Send + Sync>,
79 pub mul_by_scalar_f32:
80 Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32, f32>> + Send + Sync>,
81 pub mul_by_scalar_f16:
82 Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f16, f16>> + Send + Sync>,
83
84 pub sigmoid_f16: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f16>> + Send + Sync>,
85 pub sigmoid_f32: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32>> + Send + Sync>,
86 pub tanh_f16: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f16>> + Send + Sync>,
87 pub tanh_f32: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32>> + Send + Sync>,
88 pub erf_f32: Box<dyn Fn() -> Box<dyn element_wise::ElementWise<f32>> + Send + Sync>,
89 pub lut_u8: Box<dyn Fn(&[u8]) -> Box<dyn lut::Lut> + Send + Sync>,
90
91 pub max_f16: Box<dyn Fn() -> Box<dyn reduce::Reduce<f16>> + Send + Sync>,
92 pub max_f32: Box<dyn Fn() -> Box<dyn reduce::Reduce<f32>> + Send + Sync>,
93
94 pub sum_f16: Box<dyn Fn() -> Box<dyn reduce::Reduce<f16>> + Send + Sync>,
95 pub sum_f32: Box<dyn Fn() -> Box<dyn reduce::Reduce<f32>> + Send + Sync>,
96
97 pub softmax2_fastcompact_f16:
98 Box<dyn Fn() -> Box<dyn reduce::MapReduce<f16, f16>> + Send + Sync>,
99 pub softmax2_fastcompact_f32:
100 Box<dyn Fn() -> Box<dyn reduce::MapReduce<f32, f32>> + Send + Sync>,
101}
102
103impl Ops {
104 pub fn mmm_impls(&self) -> &[Box<dyn mmm::MatMatMul>] {
105 &self.mmm_impls
106 }
107
108 pub fn all_possible_packing(
109 &self,
110 weight_type: impl Into<WeightType>,
111 ) -> impl Iterator<Item = &dyn MMMInputFormat> {
112 let weight_type = weight_type.into();
113 self.mmm_impls
114 .iter()
115 .flat_map(|m| m.packings())
116 .map(|p| &*p.0)
117 .flat_map(move |p| {
118 let mut packs: Vec<&dyn MMMInputFormat> = vec![];
119 if p.precursor() == weight_type {
120 packs.push(p)
121 };
122 for pe in &self.panel_extractors {
123 if pe.from.precursor() == weight_type && pe.to.same_as(p) {
124 packs.push(&*pe.from);
125 }
126 }
127 packs.into_iter()
128 })
129 .sorted_by_key(|p| p.to_string())
130 .dedup()
131 }
132
133 pub fn filter_impls<'o>(
134 &'o self,
135 weight: &'o dyn MMMInputFormat,
136 acc: &[DatumType],
137 act: DatumType,
138 store: DatumType,
139 ) -> impl Iterator<
140 Item = (
141 &'o dyn MatMatMul,
142 usize,
143 &'o dyn MMMInputFormat,
144 Option<&'o PanelExtractor>,
145 &'o dyn MMMInputFormat,
146 ),
147 > {
148 let acc = acc.to_vec();
149 self.mmm_impls
150 .iter()
151 .filter(move |mmm| acc.contains(&mmm.internal_type()) && mmm.stores().contains(&store))
152 .flat_map(|mmm| {
153 mmm.packings()
154 .iter()
155 .enumerate()
156 .map(|(pack_ix, (a, b))| (&**mmm, pack_ix, &**a, &**b))
157 })
158 .filter_map(|(mmm, ix, a, b)| {
159 if a.same_as(weight) {
160 Some((mmm, ix, a, None, b))
161 } else {
162 self.panel_extractors
163 .iter()
164 .find(|pe| pe.from.same_as(weight) && pe.to.same_as(a))
165 .map(|pe| (mmm, ix, a, Some(pe), b))
166 }
167 })
168 .filter(move |(_mmm, _ix, _a, _pe, b)| {
169 b.precursor().as_dt().is_some_and(|dt| dt == act)
170 })
171 }
172
173 pub fn panel_extractors(&self) -> &[mmm::panel_extract::PanelExtractor] {
174 &self.panel_extractors
175 }
176
177 pub fn mmm(
178 &self,
179 accumulator: DatumType,
180 m: Option<usize>,
181 k: Option<usize>,
182 n: Option<usize>,
183 ) -> Option<Box<dyn mmm::MatMatMul>> {
184 use DatumType::*;
185 match accumulator {
186 F64 => Some(if n == Some(1) { (self.mmv_f64)(m, k) } else { (self.mmm_f64)(m, k, n) }),
187 F32 => Some(if n == Some(1) { (self.mmv_f32)(m, k) } else { (self.mmm_f32)(m, k, n) }),
188 F16 => Some(if n == Some(1) { (self.mmv_f16)(m, k) } else { (self.mmm_f16)(m, k, n) }),
189 I32 => {
190 Some(if n == Some(1) { (self.qmmv_i32)(m, k) } else { (self.qmmm_i32)(m, k, n) })
191 }
192 _ => None,
193 }
194 }
195}
196
197pub fn generic() -> Ops {
198 use crate::generic::mmm::*;
199 use element_wise::ElementWiseKer;
200 use reduce::{MapReduceKer, ReduceKer};
201 let mut ops = Ops {
202 mmm_impls: vec![],
203 panel_extractors: vec![],
204 mmm_f64: Box::new(|_, _, _| generic_f64_4x4.mmm()),
205 mmv_f64: Box::new(|_, _| generic_f64_4x1.mmm()),
206 mmm_f32: Box::new(|_, _, _| generic_f32_4x4.mmm()),
207 mmv_f32: Box::new(|_, _| generic_f32_4x1.mmm()),
208 mmm_f16: Box::new(|_, _, _| generic_f16_4x4.mmm()),
209 mmv_f16: Box::new(|_, _| generic_f16_4x1.mmm()),
210 qmmm_i32: Box::new(|_, _, _| generic_i32_4x4.mmm()),
211 qmmv_i32: Box::new(|_, _| generic_i32_4x4.mmm()),
212 leaky_relu_f16: Box::new(|| generic::HLeakyRelu8::ew()),
213 leaky_relu_f32: Box::new(|| generic::SLeakyRelu4::ew()),
214 mul_by_scalar_f16: Box::new(|| generic::HMulByScalar8::ew()),
215 mul_by_scalar_f32: Box::new(|| generic::SMulByScalar4::ew()),
216 sigmoid_f16: Box::new(|| generic::HSigmoid8::ew()),
217 sigmoid_f32: Box::new(|| generic::SSigmoid4::ew()),
218 tanh_f16: Box::new(|| generic::HTanh8::ew()),
219 tanh_f32: Box::new(|| generic::STanh4::ew()),
220 erf_f32: Box::new(|| generic::SErf4::ew()),
221 lut_u8: Box::new(|table: &[u8]| Box::new(lut::LutImpl::<generic::GenericLut8>::new(table))),
222 max_f16: Box::new(|| generic::reduce::max::HMax8::red()),
223 max_f32: Box::new(|| generic::reduce::max::SMax4::red()),
224 sum_f16: Box::new(|| generic::reduce::sum::HSum8::red()),
225 sum_f32: Box::new(|| generic::reduce::sum::SSum4::red()),
226 softmax2_fastcompact_f16: Box::new(|| generic::reduce::softmax_l2::HSoftMaxL2::red()),
230 softmax2_fastcompact_f32: Box::new(|| generic::reduce::softmax_l2::SSoftMaxL2::red()),
231 };
232 crate::generic::mmm::plug(&mut ops);
233 ops
234}
235
236#[allow(unreachable_code, unused_mut, unexpected_cfgs)]
237pub fn best() -> Ops {
238 let mut ops = generic();
239 #[cfg(target_arch = "x86_64")]
240 x86_64_fma::plug(&mut ops);
241 #[cfg(any(target_arch = "arm", target_arch = "armv7"))]
242 arm32::plug(&mut ops);
243 #[cfg(target_arch = "aarch64")]
244 arm64::plug(&mut ops);
245 #[cfg(all(target_family = "wasm", target_feature = "simd128"))]
246 wasm::plug(&mut ops);
247
248 ops
249}
250
251lazy_static::lazy_static! {
252 static ref OPS: Ops = {
253 best()
254 };
255}
256
/// Binary operations used as keys of the linalg function registries.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum BinOp {
    Min,
    Max,
    Add,
    Mul,
    Sub,
    /// Counterpart of `Sub` under [`BinOp::flip`]; presumably subtraction with
    /// its operands swapped — confirm against the kernels.
    SubF,
}

impl BinOp {
    /// Returns the operation paired with `self` by this mapping: `Sub` and
    /// `SubF` map to each other, every other operation maps to itself.
    pub fn flip(&self) -> BinOp {
        match *self {
            BinOp::Sub => BinOp::SubF,
            BinOp::SubF => BinOp::Sub,
            BinOp::Min | BinOp::Max | BinOp::Add | BinOp::Mul => *self,
        }
    }
}
277
278fn register_all_unicast(registry: &mut LinalgRegistry) {
279 generic::register_all_unicast(registry);
280 #[cfg(target_arch = "aarch64")]
281 arm64::register_all_unicast(registry);
282}
283
284fn register_all_by_scalar(registry: &mut LinalgRegistry) {
285 generic::register_all_by_scalar(registry);
286 #[cfg(target_arch = "aarch64")]
287 arm64::register_all_by_scalar(registry);
288}
289
290pub type LinalgFn = dyn Fn(&mut TensorView, &TensorView) -> TractResult<()> + Send + Sync;
291type LinalgRegistry = HashMap<(BinOp, DatumType), Box<dyn Fn() -> Box<LinalgFn> + Send + Sync>>;
292lazy_static! {
293 static ref BIN_UNICAST_OPS: Mutex<LinalgRegistry> = {
294 let mut registry = HashMap::default();
295 register_all_unicast(&mut registry);
296 Mutex::new(registry)
297 };
298 static ref BIN_BY_SCALAR_OPS: Mutex<LinalgRegistry> = {
299 let mut registry = HashMap::default();
300 register_all_by_scalar(&mut registry);
301 Mutex::new(registry)
302 };
303}
304
305pub fn bin_by_scalar(dt: DatumType, bin: BinOp) -> Option<Box<LinalgFn>> {
306 let map = BIN_BY_SCALAR_OPS.lock().unwrap();
307 if (dt == DatumType::F16) && !has_fp16() {
308 return None;
309 }
310 map.get(&(bin, dt)).map(|it| (it)())
311}
312
313pub fn bin_unicast(dt: DatumType, bin: BinOp) -> Option<Box<LinalgFn>> {
314 let map = BIN_UNICAST_OPS.lock().unwrap();
315 if (dt == DatumType::F16) && !has_fp16() {
316 return None;
317 }
318 map.get(&(bin, dt)).map(|it| (it)())
319}
320
321pub fn ops() -> &'static Ops {
322 &OPS
323}
324
325use num_traits::*;
326use std::collections::HashMap;
327use std::fmt::Debug;
328use std::ops::*;
329use std::sync::Mutex;
330
331pub trait LADatum:
332 Sized
333 + std::fmt::Display
334 + Debug
335 + Copy
336 + Clone
337 + Zero
338 + One
339 + 'static
340 + Add<Output = Self>
341 + Sub<Output = Self>
342 + Mul
343 + AddAssign
344 + PartialOrd
345 + Bounded
346 + tract_data::prelude::Datum
347{
348 #[cfg(test)]
349 fn strat() -> proptest::prelude::BoxedStrategy<Self>;
350}
351
352#[cfg(test)]
353use proptest::prelude::*;
354
355impl LADatum for f16 {
356 #[cfg(test)]
357 fn strat() -> BoxedStrategy<Self> {
358 f32::strat().prop_map(|f| f.as_()).boxed()
359 }
360}
361
362impl LADatum for f32 {
363 #[cfg(test)]
364 fn strat() -> BoxedStrategy<Self> {
365 (-1000isize..1000).prop_map(|i| i as f32 / 1000.0).boxed()
366 }
367}
368
369impl LADatum for f64 {
370 #[cfg(test)]
371 fn strat() -> BoxedStrategy<Self> {
372 (-1000isize..1000).prop_map(|i| i as f64 / 1000.0).boxed()
373 }
374}
375
376impl LADatum for u8 {
377 #[cfg(test)]
378 fn strat() -> BoxedStrategy<Self> {
379 any::<u8>().boxed()
380 }
381}
382
383impl LADatum for i8 {
384 #[cfg(test)]
385 fn strat() -> BoxedStrategy<Self> {
386 any::<i8>().boxed()
387 }
388}
389
390impl LADatum for i32 {
391 #[cfg(test)]
392 fn strat() -> BoxedStrategy<Self> {
393 any::<i32>().boxed()
394 }
395}
396
/// Initialises `env_logger` from the `TRACT_LOG` environment variable for tests.
///
/// The result of `try_init` is deliberately ignored.
#[cfg(test)]
#[allow(dead_code)]
fn setup_test_logger() {
    let mut builder = env_logger::Builder::from_env("TRACT_LOG");
    let _ = builder.try_init();
}
401}