extern crate alloc;
use crate::error::ZyxError;
use crate::tensor::{tensor, IntoTensor, Tensor};
use crate::utils::SizedIterator;
use crate::{dtype::DType, node::Node, scalar::Scalar, shape::Shape, tensor::Id};
use alloc::{
collections::{BTreeMap, BTreeSet},
string::String,
vec::Vec,
};
use core::ops::Range;
/// Operations a tensor backend has to provide.
///
/// The trait requires [`Copy`], so every method takes `self` by value.
/// Tensor constructors (`tensor`, `full`, `zeros`, `ones`, `eye`) come with
/// default implementations built on top of [`Backend::store`]; the remaining
/// methods must be supplied by the implementor.
pub trait Backend: Copy {
    /// Renders the given tensors as a `String`.
    ///
    /// NOTE(review): the output format is decided by the implementing backend
    /// and is not visible from this trait — confirm before parsing it.
    #[must_use]
    fn plot_graph<'a, B: Backend + 'a>(
        self,
        tensors: impl IntoIterator<Item = &'a Tensor<B>>,
    ) -> String;

    /// Converts `data` into a tensor that belongs to this backend.
    ///
    /// # Errors
    ///
    /// The default implementation always returns `Ok`; implementors that
    /// override it may report a [`ZyxError`].
    fn tensor(self, data: impl IntoTensor<Self>) -> Result<Tensor<Self>, ZyxError> {
        Ok(data.into_tensor(self))
    }

    /// Creates a tensor of the requested `shape` and `dtype` filled by the backend.
    ///
    /// NOTE(review): by its name the values are presumably drawn from a normal
    /// distribution — confirm against the implementing backend.
    ///
    /// # Errors
    ///
    /// Returns whatever [`ZyxError`] the implementing backend reports.
    fn randn(self, shape: impl Into<Shape>, dtype: DType) -> Result<Tensor<Self>, ZyxError>;

    /// Creates a tensor of the requested `shape` from the scalar `range`.
    ///
    /// NOTE(review): by its name the values are presumably sampled uniformly
    /// from `range` — confirm against the implementing backend.
    ///
    /// # Errors
    ///
    /// Returns whatever [`ZyxError`] the implementing backend reports.
    fn uniform(
        self,
        shape: impl Into<Shape>,
        range: Range<impl Scalar>,
    ) -> Result<Tensor<Self>, ZyxError>;

    /// Creates a tensor of the requested `shape` from a single `value`.
    ///
    /// Only one element is stored; the resulting tensor is then expanded
    /// to `shape`.
    ///
    /// # Errors
    ///
    /// Propagates the [`ZyxError`] returned by [`Backend::store`].
    fn full(self, shape: impl Into<Shape>, value: impl Scalar) -> Result<Tensor<Self>, ZyxError> {
        let id = self.store([value])?;
        Ok(tensor(id, self).expand(shape))
    }

    /// Creates a tensor of the requested `shape` filled with the zero of `dtype`.
    ///
    /// # Errors
    ///
    /// Propagates the [`ZyxError`] returned by [`Backend::full`].
    fn zeros(self, shape: impl Into<Shape>, dtype: DType) -> Result<Tensor<Self>, ZyxError> {
        // Literal types are spelled out so the stored element type visibly
        // matches the requested dtype in every arm.
        match dtype {
            DType::F32 => self.full(shape, 0f32),
            DType::F64 => self.full(shape, 0f64),
            DType::I32 => self.full(shape, 0i32),
        }
    }

    /// Creates a tensor of the requested `shape` filled with the one of `dtype`.
    ///
    /// # Errors
    ///
    /// Propagates the [`ZyxError`] returned by [`Backend::full`].
    fn ones(self, shape: impl Into<Shape>, dtype: DType) -> Result<Tensor<Self>, ZyxError> {
        match dtype {
            DType::F32 => self.full(shape, 1f32),
            DType::F64 => self.full(shape, 1f64),
            DType::I32 => self.full(shape, 1i32),
        }
    }

    /// Creates an `n` x `n` tensor with ones where the row index equals the
    /// column index and zeros everywhere else.
    ///
    /// All `n * n` elements are generated and stored, then the result is
    /// reshaped to `[n, n]`.
    ///
    /// # Errors
    ///
    /// Propagates the [`ZyxError`] returned by [`Backend::store`].
    fn eye(self, n: usize, dtype: DType) -> Result<Tensor<Self>, ZyxError> {
        // `flat_map` does not yield an `ExactSizeIterator`, which `store`
        // requires, so each iterator is wrapped with its known length `n * n`.
        let id = match dtype {
            DType::F32 => self.store(
                (0..n)
                    .flat_map(move |i| (0..n).map(move |j| if i == j { 1f32 } else { 0f32 }))
                    .make_sized(n * n),
            )?,
            DType::F64 => self.store(
                (0..n)
                    .flat_map(move |i| (0..n).map(move |j| if i == j { 1f64 } else { 0f64 }))
                    .make_sized(n * n),
            )?,
            DType::I32 => self.store(
                (0..n)
                    .flat_map(move |i| (0..n).map(move |j| if i == j { 1i32 } else { 0i32 }))
                    .make_sized(n * n),
            )?,
        };
        Ok(tensor(id, self).reshape([n, n]))
    }

    /// Returns the shape associated with id `x`.
    #[must_use]
    fn shape(self, x: Id) -> Shape;

    /// Returns the dtype associated with id `x`.
    #[must_use]
    fn dtype(self, x: Id) -> DType;

    /// Runs the backward pass for `x` with respect to `sources`.
    ///
    /// NOTE(review): the returned map presumably goes from each source id to
    /// the id of its gradient — confirm against the implementing backend.
    ///
    /// # Errors
    ///
    /// Returns whatever [`ZyxError`] the implementing backend reports.
    fn backward(self, x: Id, sources: &BTreeSet<Id>) -> Result<BTreeMap<Id, Id>, ZyxError>;

    /// Reads the data behind `id` into a `Vec` of scalars of type `T`.
    ///
    /// # Errors
    ///
    /// Returns whatever [`ZyxError`] the implementing backend reports.
    fn load<T: Scalar>(self, id: Id) -> Result<Vec<T>, ZyxError>;

    /// Stores the scalars yielded by `iter` and returns the id of the result.
    ///
    /// The iterator must report its exact length.
    ///
    /// # Errors
    ///
    /// Returns whatever [`ZyxError`] the implementing backend reports.
    fn store<T: Scalar, IT>(self, iter: IT) -> Result<Id, ZyxError>
    where
        IT: IntoIterator<Item = T>,
        IT::IntoIter: ExactSizeIterator;

    /// Hands `node` to the backend and returns the id assigned to it.
    ///
    /// # Errors
    ///
    /// Returns whatever [`ZyxError`] the implementing backend reports.
    fn push(self, node: Node) -> Result<Id, ZyxError>;

    /// Releases `x`.
    ///
    /// NOTE(review): presumably the counterpart of [`Backend::retain`]
    /// (reference counting) — confirm against the implementing backend.
    ///
    /// # Errors
    ///
    /// Returns whatever [`ZyxError`] the implementing backend reports.
    fn release(self, x: Id) -> Result<(), ZyxError>;

    /// Retains `x`.
    ///
    /// NOTE(review): presumably the counterpart of [`Backend::release`]
    /// (reference counting) — confirm against the implementing backend.
    fn retain(self, x: Id);
}