use bytemuck::Pod;
use half::f16;
use safetensors::Dtype;
/// Types with an additive identity.
///
/// Every implementation in this file returns the value `0` of its type.
pub trait Zero: Sized + core::ops::Add<Self, Output = Self> {
    /// Returns the additive identity (zero) of `Self`.
    fn zero() -> Self;
}

// Generates one `impl Zero for $ty` per `$ty => $value` pair, where `$value`
// is the expression returned from `zero()`.
macro_rules! impl_zero {
    ($($ty:ty => $value:expr),+ $(,)?) => {
        $(
            impl Zero for $ty {
                fn zero() -> Self {
                    $value
                }
            }
        )+
    };
}

impl_zero! {
    f32 => 0.0,
    f16 => f16::ZERO,
    u8 => 0,
    u16 => 0,
    u32 => 0,
}
/// Types with a multiplicative identity.
///
/// Every implementation in this file returns the value `1` of its type.
pub trait One: Sized + core::ops::Mul<Self, Output = Self> {
    /// Returns the multiplicative identity (one) of `Self`.
    fn one() -> Self;
}

// Generates one `impl One for $ty` per `$ty => $value` pair, where `$value`
// is the expression returned from `one()`.
macro_rules! impl_one {
    ($($ty:ty => $value:expr),+ $(,)?) => {
        $(
            impl One for $ty {
                fn one() -> Self {
                    $value
                }
            }
        )+
    };
}

impl_one! {
    f32 => 1.0,
    f16 => f16::ONE,
    u8 => 1,
    u16 => 1,
    u32 => 1,
}
pub trait Scalar: Sized + Clone + Copy + Pod + Zero + One + Send + Sync + sealed::Sealed {
fn size() -> usize {
std::mem::size_of::<Self>()
}
const DATA_TYPE: Dtype;
}
impl Scalar for f32 {
const DATA_TYPE: Dtype = Dtype::F32;
}
impl Scalar for f16 {
const DATA_TYPE: Dtype = Dtype::F16;
}
impl Scalar for u8 {
const DATA_TYPE: Dtype = Dtype::U8;
}
impl Scalar for u16 {
const DATA_TYPE: Dtype = Dtype::U16;
}
impl Scalar for u32 {
const DATA_TYPE: Dtype = Dtype::U32;
}
/// Floating-point [`Scalar`]s that convert both to and from `f16` and `f32`.
pub trait Float: Scalar + Hom<f16> + Hom<f32> + CoHom<f16> + CoHom<f32> {
    /// Short precision tag for the type: `"FP16"` for `f16`, `"FP32"` for `f32`.
    // NOTE(review): presumably consumed as a textual flag/define by other
    // code; confirm the expected format at the use sites before changing it.
    const DEF: &'static str;
}

impl Float for f16 {
    const DEF: &'static str = "FP16";
}

impl Float for f32 {
    const DEF: &'static str = "FP32";
}
/// By-value conversion from `Self` into `T`.
///
/// In this file it is implemented among the float scalars `f32` and `f16`,
/// including the identity conversions.
///
/// The type parameter is named `T` rather than `Into`, so it does not shadow
/// the prelude trait `std::convert::Into`.
pub trait Hom<T> {
    /// Converts `self` into a `T`.
    fn hom(self) -> T;
}

/// Identity conversion.
impl Hom<f32> for f32 {
    fn hom(self) -> f32 {
        self
    }
}

/// Narrowing conversion to half precision; may lose precision and range.
impl Hom<f16> for f32 {
    fn hom(self) -> f16 {
        f16::from_f32(self)
    }
}

/// Widening conversion; every `f16` value is exactly representable as `f32`.
impl Hom<f32> for f16 {
    fn hom(self) -> f32 {
        self.to_f32()
    }
}

/// Identity conversion.
impl Hom<f16> for f16 {
    fn hom(self) -> f16 {
        self
    }
}
/// The reverse direction of [`Hom`]: constructs `Self` from a `T`.
///
/// This relates to [`Hom`] the way [`From`] relates to [`Into`]. In this file,
/// every implementation comes from the blanket impl below, which delegates to
/// [`Hom::hom`].
pub trait CoHom<T> {
    /// Builds a `Self` from `value`.
    fn co_hom(value: T) -> Self;
}

// The type parameters are named `T`/`U` rather than `From`/`Into`, so they do
// not shadow the prelude conversion traits inside this impl.
impl<T, U> CoHom<T> for U
where
    T: Hom<U>,
{
    fn co_hom(value: T) -> Self {
        value.hom()
    }
}
mod sealed {
    //! Private marker used to seal `Scalar`.
    //!
    //! This module is not `pub`, so code outside the crate cannot name
    //! `Sealed` and therefore cannot implement `Scalar`.

    use half::f16;

    /// Marker supertrait of `Scalar`, implemented only for the supported element types.
    pub trait Sealed {}

    // Floating-point element types.
    impl Sealed for f16 {}
    impl Sealed for f32 {}

    // Unsigned integer element types.
    impl Sealed for u8 {}
    impl Sealed for u16 {}
    impl Sealed for u32 {}
}