use crate::dtype::DType;
/// Numeric element type with a runtime [`DType`] tag and a set of
/// element-wise operations exposed as plain by-value methods.
///
/// Every operation takes and returns `Self`; integer implementors are expected
/// to give truncating/integer semantics where the operation is inherently
/// real-valued.
pub trait Scalar: Clone + core::fmt::Debug + 'static {
    // ---- type metadata and constants ----

    /// Runtime tag identifying this element type.
    fn dtype() -> DType;
    /// Size of a single element, in bytes.
    fn byte_size() -> usize;
    /// Additive identity.
    fn zero() -> Self;
    /// Multiplicative identity.
    fn one() -> Self;
    /// Largest finite value representable by the type.
    fn max_value() -> Self;
    /// Most negative finite value representable by the type.
    fn min_value() -> Self;
    /// Tolerance used by approximate comparisons in [`Scalar::is_equal`]
    /// (zero for exact integer types).
    fn epsilon() -> Self;

    // ---- conversions ----

    /// Converts the value to `f32`.
    fn into_f32(self) -> f32;
    /// Converts the value to `f64`.
    fn into_f64(self) -> f64;
    /// Converts the value to `i32`.
    fn into_i32(self) -> i32;

    // ---- unary operations ----

    /// `1 / self`.
    fn reciprocal(self) -> Self;
    /// Arithmetic negation, `-self`.
    fn neg(self) -> Self;
    /// Rectified linear unit, `max(self, 0)`.
    fn relu(self) -> Self;
    /// Sine (argument in radians).
    fn sin(self) -> Self;
    /// Cosine (argument in radians).
    fn cos(self) -> Self;
    /// Natural logarithm.
    fn ln(self) -> Self;
    /// Natural exponential, `e^self`.
    fn exp(self) -> Self;
    /// Hyperbolic tangent.
    fn tanh(self) -> Self;
    /// Square root.
    fn sqrt(self) -> Self;

    // ---- binary operations ----

    /// `self + rhs`.
    fn add(self, rhs: Self) -> Self;
    /// `self - rhs`.
    fn sub(self, rhs: Self) -> Self;
    /// `self * rhs`.
    fn mul(self, rhs: Self) -> Self;
    /// `self / rhs`.
    fn div(self, rhs: Self) -> Self;
    /// `self` raised to the power `rhs`.
    fn pow(self, rhs: Self) -> Self;
    /// The larger of `self` and `rhs`.
    fn max(self, rhs: Self) -> Self;

    // ---- comparisons ----

    /// [`Scalar::one`] when `self < rhs`, otherwise [`Scalar::zero`].
    fn cmplt(self, rhs: Self) -> Self;
    /// Equality check; floating-point implementors compare approximately.
    fn is_equal(self, rhs: Self) -> bool;
}
/// Single-precision IEEE-754 float scalar.
impl Scalar for f32 {
    fn dtype() -> DType {
        DType::F32
    }
    fn zero() -> Self {
        0.
    }
    fn one() -> Self {
        1.
    }
    fn byte_size() -> usize {
        4
    }
    fn into_f32(self) -> f32 {
        self
    }
    fn into_f64(self) -> f64 {
        // Widening conversion: every f32 is exactly representable in f64.
        self as f64
    }
    fn into_i32(self) -> i32 {
        // `as` truncates toward zero, saturates at i32::MIN/MAX and maps NaN to 0.
        self as i32
    }
    fn reciprocal(self) -> Self {
        1.0 / self
    }
    fn neg(self) -> Self {
        -self
    }
    fn relu(self) -> Self {
        // Inherent `f32::max` returns the non-NaN operand, so relu(NaN) == 0.
        f32::max(self, 0.)
    }
    fn sin(self) -> Self {
        f32::sin(self)
    }
    fn cos(self) -> Self {
        f32::cos(self)
    }
    fn exp(self) -> Self {
        f32::exp(self)
    }
    fn ln(self) -> Self {
        f32::ln(self)
    }
    fn tanh(self) -> Self {
        f32::tanh(self)
    }
    fn sqrt(self) -> Self {
        // Correctly rounded IEEE-754 square root: exact on perfect squares,
        // NaN for negative inputs, sqrt(+inf) == +inf and sqrt(-0.0) == -0.0.
        // (A bit-shift approximation is off by several percent, e.g. sqrt(2) ~ 1.5.)
        f32::sqrt(self)
    }
    fn add(self, rhs: Self) -> Self {
        self + rhs
    }
    fn sub(self, rhs: Self) -> Self {
        self - rhs
    }
    fn mul(self, rhs: Self) -> Self {
        self * rhs
    }
    fn div(self, rhs: Self) -> Self {
        self / rhs
    }
    fn pow(self, rhs: Self) -> Self {
        f32::powf(self, rhs)
    }
    fn cmplt(self, rhs: Self) -> Self {
        // 1.0 when self < rhs, else 0.0 (also 0.0 if either operand is NaN).
        (self < rhs) as i32 as f32
    }
    fn max(self, rhs: Self) -> Self {
        // If exactly one operand is NaN, the other one is returned.
        f32::max(self, rhs)
    }
    fn max_value() -> Self {
        f32::MAX
    }
    fn min_value() -> Self {
        // Most negative finite f32 (not the smallest positive value).
        f32::MIN
    }
    fn epsilon() -> Self {
        // Absolute tolerance used by `is_equal`; much larger than `f32::EPSILON`.
        0.00001
    }
    /// Approximate equality. Values compare equal if any of these holds:
    /// they are exactly equal (this covers equal infinities), their absolute
    /// difference is below [`Scalar::epsilon`], or their difference is within
    /// 1% of `|self|`. NaN never compares equal.
    fn is_equal(self, rhs: Self) -> bool {
        // The exact check must come first: inf - inf is NaN, so equal infinities
        // of either sign would otherwise fail both tolerance tests.
        self == rhs
            || (self - rhs).abs() < Self::epsilon()
            // Relative tolerance, measured against `self` only.
            || (self - rhs).abs() < self.abs() * 0.01
    }
}
/// Double-precision IEEE-754 float scalar.
impl Scalar for f64 {
    fn dtype() -> DType {
        DType::F64
    }
    fn zero() -> Self {
        0.
    }
    fn one() -> Self {
        1.
    }
    fn byte_size() -> usize {
        8
    }
    fn into_f32(self) -> f32 {
        // Narrowing: rounds to nearest f32; values beyond f32 range become ±inf.
        self as f32
    }
    fn into_f64(self) -> f64 {
        self
    }
    fn into_i32(self) -> i32 {
        // `as` truncates toward zero, saturates at i32::MIN/MAX and maps NaN to 0.
        self as i32
    }
    fn reciprocal(self) -> Self {
        1.0 / self
    }
    fn neg(self) -> Self {
        -self
    }
    fn relu(self) -> Self {
        // Inherent `f64::max` returns the non-NaN operand, so relu(NaN) == 0.
        f64::max(self, 0.)
    }
    fn sin(self) -> Self {
        f64::sin(self)
    }
    fn cos(self) -> Self {
        f64::cos(self)
    }
    fn exp(self) -> Self {
        f64::exp(self)
    }
    fn ln(self) -> Self {
        f64::ln(self)
    }
    fn tanh(self) -> Self {
        f64::tanh(self)
    }
    fn sqrt(self) -> Self {
        f64::sqrt(self)
    }
    fn add(self, rhs: Self) -> Self {
        self + rhs
    }
    fn sub(self, rhs: Self) -> Self {
        self - rhs
    }
    fn mul(self, rhs: Self) -> Self {
        self * rhs
    }
    fn div(self, rhs: Self) -> Self {
        self / rhs
    }
    fn pow(self, rhs: Self) -> Self {
        f64::powf(self, rhs)
    }
    fn cmplt(self, rhs: Self) -> Self {
        // 1.0 when self < rhs, else 0.0 (also 0.0 if either operand is NaN).
        (self < rhs) as i32 as f64
    }
    fn max(self, rhs: Self) -> Self {
        // If exactly one operand is NaN, the other one is returned.
        f64::max(self, rhs)
    }
    fn max_value() -> Self {
        f64::MAX
    }
    fn min_value() -> Self {
        // Most negative finite f64 (not the smallest positive value).
        f64::MIN
    }
    fn epsilon() -> Self {
        // Absolute tolerance used by `is_equal`; much larger than `f64::EPSILON`.
        0.00001
    }
    /// Approximate equality. Values compare equal if any of these holds:
    /// they are exactly equal (this covers equal infinities), their absolute
    /// difference is below [`Scalar::epsilon`], or their difference is within
    /// 1% of `|self|`. NaN never compares equal.
    fn is_equal(self, rhs: Self) -> bool {
        // The exact check must come first: inf - inf is NaN, so equal infinities
        // of either sign would otherwise fail both tolerance tests.
        self == rhs
            || (self - rhs).abs() < Self::epsilon()
            // Relative tolerance, measured against `self` only.
            || (self - rhs).abs() < self.abs() * 0.01
    }
}
/// 32-bit signed integer scalar.
///
/// `sqrt` and the transcendental functions are evaluated in `f64`, which
/// represents every `i32` exactly. The result is then converted back with
/// `as i32`, which truncates toward zero, saturates out-of-range values at
/// `i32::MIN`/`i32::MAX` and maps NaN to 0.
impl Scalar for i32 {
    fn dtype() -> DType {
        DType::I32
    }
    fn zero() -> Self {
        0
    }
    fn one() -> Self {
        1
    }
    fn byte_size() -> usize {
        4
    }
    fn into_f32(self) -> f32 {
        // Rounds to nearest for |self| > 2^24 (f32 has a 24-bit significand).
        self as f32
    }
    fn into_f64(self) -> f64 {
        f64::from(self)
    }
    fn into_i32(self) -> i32 {
        self
    }
    fn reciprocal(self) -> Self {
        // Integer division: 0 for |self| > 1; panics when self == 0.
        1 / self
    }
    fn neg(self) -> Self {
        // -i32::MIN overflows (panics in debug builds, wraps in release).
        -self
    }
    fn relu(self) -> Self {
        // Fully qualified so the inherent/Ord `max` is used, not `Scalar::max`.
        <i32 as Ord>::max(self, 0)
    }
    fn sin(self) -> Self {
        f64::sin(f64::from(self)) as i32
    }
    fn cos(self) -> Self {
        f64::cos(f64::from(self)) as i32
    }
    fn exp(self) -> Self {
        // Going through f32 loses integer precision above 2^24, e.g. exp(20).
        f64::exp(f64::from(self)) as i32
    }
    fn ln(self) -> Self {
        // ln(0) = -inf saturates to i32::MIN; negative inputs give NaN -> 0.
        f64::ln(f64::from(self)) as i32
    }
    fn tanh(self) -> Self {
        f64::tanh(f64::from(self)) as i32
    }
    fn sqrt(self) -> Self {
        // f64's correctly rounded sqrt yields floor(sqrt(n)) exactly for every
        // non-negative i32. Via f32, sqrt(46340^2 - 1) would round up to 46340.
        // Negative inputs give NaN, which `as i32` maps to 0.
        f64::sqrt(f64::from(self)) as i32
    }
    fn add(self, rhs: Self) -> Self {
        self + rhs
    }
    fn sub(self, rhs: Self) -> Self {
        self - rhs
    }
    fn mul(self, rhs: Self) -> Self {
        self * rhs
    }
    fn div(self, rhs: Self) -> Self {
        // Truncating division; panics when rhs == 0 or for i32::MIN / -1.
        self / rhs
    }
    /// Integer power.
    ///
    /// Non-negative exponents use `i32::pow`; overflow panics in debug builds
    /// and wraps in release builds. Negative exponents follow `1 / self^|rhs|`
    /// with truncating division:
    /// - 0 when |self| > 1;
    /// - ±1 when self is ±1;
    /// - a division-by-zero panic when self is 0, like `reciprocal(0)`.
    fn pow(self, rhs: Self) -> Self {
        if rhs >= 0 {
            // Lossless: rhs is non-negative here.
            i32::pow(self, rhs as u32)
        } else if self.unsigned_abs() > 1 {
            // |self^rhs| < 1 for a negative exponent, which truncates to 0.
            0
        } else {
            // self is -1, 0 or 1, so self^|rhs| is `self` for an odd exponent
            // and `self * self` for an even one. `rhs % 2` cannot overflow,
            // even for i32::MIN.
            let magnitude = if rhs % 2 == 0 { self * self } else { self };
            1 / magnitude
        }
    }
    fn cmplt(self, rhs: Self) -> Self {
        // 1 when self < rhs, else 0.
        (self < rhs) as i32
    }
    fn max(self, rhs: Self) -> Self {
        <i32 as Ord>::max(self, rhs)
    }
    fn max_value() -> Self {
        i32::MAX
    }
    fn min_value() -> Self {
        i32::MIN
    }
    fn epsilon() -> Self {
        // Integers compare exactly; no tolerance.
        0
    }
    fn is_equal(self, rhs: Self) -> bool {
        self == rhs
    }
}