1extern crate alloc;
2use crate::error::ZyxError;
3use crate::tensor::{tensor, IntoTensor, Tensor};
4use crate::utils::SizedIterator;
5use crate::{dtype::DType, node::Node, scalar::Scalar, shape::Shape, tensor::Id};
6use alloc::{
7 collections::{BTreeMap, BTreeSet},
8 string::String,
9 vec::Vec,
10};
11use core::ops::Range;
12
13pub trait Backend: Copy {
16 #[must_use]
18 fn plot_graph<'a, B: Backend + 'a>(
19 self,
20 tensors: impl IntoIterator<Item = &'a Tensor<B>>,
21 ) -> String;
22
23 #[must_use]
25 fn tensor(self, data: impl IntoTensor<Self>) -> Result<Tensor<Self>, ZyxError> {
26 Ok(data.into_tensor(self))
27 }
28
29 #[must_use]
31 fn randn(self, shape: impl Into<Shape>, dtype: DType) -> Result<Tensor<Self>, ZyxError>;
32
33 #[must_use]
35 fn uniform(
36 self,
37 shape: impl Into<Shape>,
38 range: Range<impl Scalar>,
39 ) -> Result<Tensor<Self>, ZyxError>;
40
41 #[must_use]
43 fn full(self, shape: impl Into<Shape>, value: impl Scalar) -> Result<Tensor<Self>, ZyxError> {
44 Ok(tensor(self.store([value])?, self).expand(shape))
45 }
46
47 #[must_use]
49 fn zeros(self, shape: impl Into<Shape>, dtype: DType) -> Result<Tensor<Self>, ZyxError> {
50 match dtype {
51 DType::F32 => self.full(shape, 0f32),
52 DType::F64 => self.full(shape, 0f64),
53 DType::I32 => self.full(shape, 0),
54 }
55 }
56
57 #[must_use]
59 fn ones(self, shape: impl Into<Shape>, dtype: DType) -> Result<Tensor<Self>, ZyxError> {
60 match dtype {
61 DType::F32 => self.full(shape, 1f32),
62 DType::F64 => self.full(shape, 1f64),
63 DType::I32 => self.full(shape, 1),
64 }
65 }
66
67 #[must_use]
69 fn eye(self, n: usize, dtype: DType) -> Result<Tensor<Self>, ZyxError> {
70 Ok(tensor(
71 match dtype {
72 DType::F32 => self.store(
73 (0..n)
74 .flat_map(move |i| (0..n).map(move |j| if j == i { 1f32 } else { 0. }))
75 .make_sized(n * n),
76 )?,
77 DType::F64 => self.store(
78 (0..n)
79 .flat_map(move |i| (0..n).map(move |j| if j == i { 1f64 } else { 0. }))
80 .make_sized(n * n),
81 )?,
82 DType::I32 => self.store(
83 (0..n)
84 .flat_map(move |i| (0..n).map(move |j| if j == i { 1i32 } else { 0 }))
85 .make_sized(n * n),
86 )?,
87 },
88 self,
89 )
90 .reshape([n, n]))
91 }
92
93 #[must_use]
95 fn shape(self, x: Id) -> Shape;
96 #[must_use]
98 fn dtype(self, x: Id) -> DType;
99 fn backward(self, x: Id, sources: &BTreeSet<Id>) -> Result<BTreeMap<Id, Id>, ZyxError>;
102 fn load<T: Scalar>(self, id: Id) -> Result<Vec<T>, ZyxError>;
104 fn store<T: Scalar, IT>(self, iter: IT) -> Result<Id, ZyxError>
106 where
107 IT: IntoIterator<Item = T>,
108 IT::IntoIter: ExactSizeIterator;
109 fn push(self, node: Node) -> Result<Id, ZyxError>;
111 fn release(self, x: Id) -> Result<(), ZyxError>;
113 fn retain(self, x: Id);
115}