//! Crate root: declares the kernel modules (a portable `generic` module plus
//! back-ends gated on `cfg(target_arch)`) and exposes the [`Ops`] table of
//! kernel constructors through [`generic()`], [`best()`] and [`ops()`].
#[macro_use]
extern crate derive_new;
extern crate lazy_static;
extern crate libc;
extern crate log;
#[macro_use]
extern crate objekt;
extern crate num_traits;
#[cfg(test)]
extern crate proptest;

pub mod align;
pub mod f16;
#[macro_use]
pub mod frame;
mod generic;

#[cfg(target_arch = "x86_64")]
pub mod x86_64_fma;

#[cfg(target_arch = "aarch64")]
pub mod arm64;

#[cfg(any(target_arch = "arm", target_arch = "armv7"))]
pub mod arm32;

pub use self::frame::mmm;
pub use self::frame::sigmoid;
pub use self::frame::tanh;
pub use self::frame::vecmatmul;

/// Table of factory closures used to obtain kernel implementations.
///
/// Every field is a boxed `Send + Sync` closure that returns a boxed trait
/// object. `generic()` fills the table with implementations taken from the
/// `generic` module; `best()` starts from that table and may overwrite entries
/// with architecture-specific ones.
pub struct Ops {
    /// Builds an `f32` `vecmatmul::VecMatMul` from two `usize` arguments
    /// (named `k` and `n` in `generic()`).
    pub svmm: Box<dyn Fn(usize, usize) -> Box<dyn vecmatmul::VecMatMul<f32>> + Send + Sync>,
    /// Builds an all-`f32` `mmm::MatMatMul` from three `usize` arguments
    /// (named `m`, `k` and `n` in `generic()`).
    pub smmm: Box<
        dyn Fn(usize, usize, usize) -> Box<dyn mmm::MatMatMul<f32, f32, f32, f32>> + Send + Sync,
    >,
    /// Builds an `mmm::QMatMatMul<i8, i8, i32, i32>` from three `usize`
    /// arguments (named `m`, `k` and `n` in `generic()`).
    pub qmmm_i8_i32: Box<
        dyn Fn(usize, usize, usize) -> Box<dyn mmm::QMatMatMul<i8, i8, i32, i32>> + Send + Sync,
    >,
    /// Builds an `mmm::QMatMatMul<u8, u8, i32, i32>` from three `usize`
    /// arguments (named `m`, `k` and `n` in `generic()`).
    pub qmmm_u8_i32: Box<
        dyn Fn(usize, usize, usize) -> Box<dyn mmm::QMatMatMul<u8, u8, i32, i32>> + Send + Sync,
    >,
    /// Builds an `f32` `sigmoid::Sigmoid`; takes no arguments.
    pub ssigmoid: Box<dyn Fn() -> Box<dyn sigmoid::Sigmoid<f32>> + Send + Sync>,
    /// Builds an `f32` `tanh::Tanh`; takes no arguments.
    pub stanh: Box<dyn Fn() -> Box<dyn tanh::Tanh<f32>> + Send + Sync>,
}

/// Builds an [`Ops`] table in which every entry is backed by a kernel from the
/// `generic` module.
///
/// This is the starting point that `best()` later specialises.
pub fn generic() -> Ops {
    // Local aliases for the fully spelled-out kernel types, so that the
    // closures below stay on one line each.
    type SVmm = vecmatmul::PackedVecMatMul<generic::SVecMatMul8, f32>;
    type SMmm = mmm::MatMatMulImpl<generic::GenericMmm4x4<f32, f32, f32, f32>, f32, f32, f32, f32>;
    type I8Mmm = mmm::MatMatMulImpl<generic::GenericMmm4x4<i8, i8, i32, i32>, i8, i8, i32, i32>;
    type U8Mmm = mmm::MatMatMulImpl<generic::GenericMmm4x4<u8, u8, i32, i32>, u8, u8, i32, i32>;
    type SSigmoid = sigmoid::SigmoidImpl<generic::SSigmoid4, f32>;
    type STanh = tanh::TanhImpl<generic::STanh4, f32>;

    Ops {
        svmm: Box::new(|k, n| Box::new(SVmm::new(k, n))),
        smmm: Box::new(|m, k, n| Box::new(SMmm::new(m, k, n))),
        // The quantized entries wrap the plain kernel via `QMatMatMulImpl::from`.
        qmmm_i8_i32: Box::new(|m, k, n| Box::new(mmm::QMatMatMulImpl::from(I8Mmm::new(m, k, n)))),
        qmmm_u8_i32: Box::new(|m, k, n| Box::new(mmm::QMatMatMulImpl::from(U8Mmm::new(m, k, n)))),
        ssigmoid: Box::new(|| Box::new(SSigmoid::new())),
        stanh: Box::new(|| Box::new(STanh::new())),
    }
}

/// Builds the [`Ops`] table for the current target, starting from
/// [`generic()`] and overriding entries where a specialised kernel applies.
///
/// * On `x86_64`, when the `fma` CPU feature is detected at run time, `smmm`
///   is replaced by the `x86_64_fma::mmm::SMatMatMul16x6` kernel.
/// * On `arm` / `armv7`, `arm32::plug` is given the table to update.
/// * On `aarch64`, `arm64::plug` is given the table to update.
///
/// On any other target the generic table is returned unchanged.
// `unused_mut`: on targets where none of the cfg-gated sections is compiled
// in, `ops` is never mutated. `unreachable_code` is kept as found.
// NOTE(review): nothing visible here needs `unreachable_code`; confirm before removing.
#[allow(unreachable_code, unused_mut)]
pub fn best() -> Ops {
    let mut ops = generic();
    #[cfg(target_arch = "x86_64")]
    {
        // Run-time check: the binary may run on x86_64 CPUs without FMA.
        if is_x86_feature_detected!("fma") {
            ops.smmm = Box::new(|m, k, n| {
                Box::new(
                    mmm::MatMatMulImpl::<x86_64_fma::mmm::SMatMatMul16x6, f32, f32, f32, f32>::new(
                        m, k, n,
                    ),
                )
            });
            log::info!("x86_64/fma activated");
        }
    }
    #[cfg(any(target_arch = "arm", target_arch = "armv7"))]
    arm32::plug(&mut ops);
    #[cfg(target_arch = "aarch64")]
    arm64::plug(&mut ops);
    ops
}

lazy_static::lazy_static! {
    // Shared kernel table handed out by `ops()`; initialised from `best()`.
    static ref OPS: Ops = best();
}

/// Returns a reference to the shared [`Ops`] table stored in `OPS`.
pub fn ops() -> &'static Ops {
    // The typed binding dereferences the lazy static down to the inner `Ops`.
    let table: &'static Ops = &OPS;
    table
}

/// Test helper: succeeds when `found` and `expected` have the same length and
/// every pair of corresponding elements differs by less than `0.001`.
///
/// On failure a proptest test-case failure is returned whose message prints
/// both slices.
#[cfg(test)]
pub(crate) fn check_close(
    found: &[f32],
    expected: &[f32],
) -> proptest::test_runner::TestCaseResult {
    // `zip` stops at the shorter slice, so without this check a truncated (or
    // empty) `found` would be accepted as "close" to any `expected`.
    proptest::prop_assert!(
        found.len() == expected.len(),
        "length mismatch, found: {:?} expected: {:?}",
        found,
        expected
    );
    proptest::prop_assert!(
        found.iter().zip(expected.iter()).all(|(a, b)| (a - b).abs() < 0.001),
        "found: {:?} expected: {:?}",
        found,
        expected
    );
    Ok(())
}