tract_linalg/
generic.rs

1pub mod by_scalar;
2pub mod erf;
3pub mod leaky_relu;
4pub mod lut;
5pub mod mmm;
6pub mod reduce;
7pub mod rounding;
8pub mod sigmoid;
9pub mod tanh;
10pub mod unicast;
11
12use tract_data::prelude::DatumType;
13
14use crate::by_scalar::ByScalarKer;
15use crate::unicast::UnicastKer;
16use crate::{BinOp, LinalgRegistry};
17
18pub use self::by_scalar::{HMulByScalar8, SMulByScalar4};
19pub use self::erf::SErf4;
20pub use self::leaky_relu::{HLeakyRelu8, SLeakyRelu4};
21pub use self::lut::GenericLut8;
22pub use self::reduce::softmax_l2::SSoftMaxL2;
23pub use self::rounding::{ScaleShiftAndRound, Scaler};
24pub use self::sigmoid::{HSigmoid8, SSigmoid4};
25pub use self::tanh::{HTanh8, STanh4};
26
27pub(crate) fn register_all_unicast(registry: &mut LinalgRegistry) {
28    registry.insert((BinOp::Mul, DatumType::F32), Box::new(|| unicast::SUnicastMul4::bin()));
29    registry.insert((BinOp::Mul, DatumType::F16), Box::new(|| unicast::HUnicastMul8::bin()));
30    registry.insert((BinOp::Add, DatumType::F32), Box::new(|| unicast::SUnicastAdd4::bin()));
31    registry.insert((BinOp::Add, DatumType::F16), Box::new(|| unicast::HUnicastAdd8::bin()));
32    registry.insert((BinOp::Sub, DatumType::F32), Box::new(|| unicast::SUnicastSub4::bin()));
33    registry.insert((BinOp::Sub, DatumType::F16), Box::new(|| unicast::HUnicastSub8::bin()));
34    registry.insert((BinOp::SubF, DatumType::F32), Box::new(|| unicast::SUnicastSubF4::bin()));
35    registry.insert((BinOp::SubF, DatumType::F16), Box::new(|| unicast::HUnicastSubF8::bin()));
36    registry.insert((BinOp::Min, DatumType::F32), Box::new(|| unicast::SUnicastMin4::bin()));
37    registry.insert((BinOp::Min, DatumType::F16), Box::new(|| unicast::HUnicastMin8::bin()));
38    registry.insert((BinOp::Max, DatumType::F32), Box::new(|| unicast::SUnicastMax4::bin()));
39    registry.insert((BinOp::Max, DatumType::F16), Box::new(|| unicast::HUnicastMax8::bin()));
40}
41
42pub(crate) fn register_all_by_scalar(registry: &mut LinalgRegistry) {
43    registry.insert((BinOp::Mul, DatumType::F32), Box::new(|| by_scalar::SMulByScalar4::bin()));
44    registry.insert((BinOp::Mul, DatumType::F16), Box::new(|| by_scalar::HMulByScalar8::bin()));
45    registry.insert((BinOp::Add, DatumType::F32), Box::new(|| by_scalar::SAddByScalar4::bin()));
46    registry.insert((BinOp::Add, DatumType::F16), Box::new(|| by_scalar::HAddByScalar8::bin()));
47    registry.insert((BinOp::Sub, DatumType::F32), Box::new(|| by_scalar::SSubByScalar4::bin()));
48    registry.insert((BinOp::Sub, DatumType::F16), Box::new(|| by_scalar::HSubByScalar8::bin()));
49    registry.insert((BinOp::SubF, DatumType::F32), Box::new(|| by_scalar::SSubFByScalar4::bin()));
50    registry.insert((BinOp::SubF, DatumType::F16), Box::new(|| by_scalar::HSubFByScalar8::bin()));
51    registry.insert((BinOp::Min, DatumType::F32), Box::new(|| by_scalar::SMinByScalar4::bin()));
52    registry.insert((BinOp::Min, DatumType::F16), Box::new(|| by_scalar::HMinByScalar8::bin()));
53    registry.insert((BinOp::Max, DatumType::F32), Box::new(|| by_scalar::SMaxByScalar4::bin()));
54    registry.insert((BinOp::Max, DatumType::F16), Box::new(|| by_scalar::HMaxByScalar8::bin()));
55}