// zyx_core/backend.rs

extern crate alloc;

use alloc::{
    collections::{BTreeMap, BTreeSet},
    string::String,
    vec::Vec,
};
use core::ops::Range;

use crate::error::ZyxError;
use crate::tensor::{tensor, IntoTensor, Tensor};
use crate::utils::SizedIterator;
use crate::{dtype::DType, node::Node, scalar::Scalar, shape::Shape, tensor::Id};

13/// Backend for [tensors](Tensor).
14/// Tensor requires that all backends implement this trait and only this trait.
15pub trait Backend: Copy {
16    /// Create graph of operations between tensors in dot format for visualization
17    #[must_use]
18    fn plot_graph<'a, B: Backend + 'a>(
19        self,
20        tensors: impl IntoIterator<Item = &'a Tensor<B>>,
21    ) -> String;
22
23    /// Create new tensor
24    #[must_use]
25    fn tensor(self, data: impl IntoTensor<Self>) -> Result<Tensor<Self>, ZyxError> {
26        Ok(data.into_tensor(self))
27    }
28
29    /// Create new tensor using values from standard normal distribution
30    #[must_use]
31    fn randn(self, shape: impl Into<Shape>, dtype: DType) -> Result<Tensor<Self>, ZyxError>;
32
33    /// Create new tensor using values from uniform distribution
34    #[must_use]
35    fn uniform(
36        self,
37        shape: impl Into<Shape>,
38        range: Range<impl Scalar>,
39    ) -> Result<Tensor<Self>, ZyxError>;
40
41    /// Create new tensor by repeating single value
42    #[must_use]
43    fn full(self, shape: impl Into<Shape>, value: impl Scalar) -> Result<Tensor<Self>, ZyxError> {
44        Ok(tensor(self.store([value])?, self).expand(shape))
45    }
46
47    /// Create new tensor by repeating zeroes
48    #[must_use]
49    fn zeros(self, shape: impl Into<Shape>, dtype: DType) -> Result<Tensor<Self>, ZyxError> {
50        match dtype {
51            DType::F32 => self.full(shape, 0f32),
52            DType::F64 => self.full(shape, 0f64),
53            DType::I32 => self.full(shape, 0),
54        }
55    }
56
57    /// Create new tensor by repeating ones
58    #[must_use]
59    fn ones(self, shape: impl Into<Shape>, dtype: DType) -> Result<Tensor<Self>, ZyxError> {
60        match dtype {
61            DType::F32 => self.full(shape, 1f32),
62            DType::F64 => self.full(shape, 1f64),
63            DType::I32 => self.full(shape, 1),
64        }
65    }
66
67    /// Create eye tensor
68    #[must_use]
69    fn eye(self, n: usize, dtype: DType) -> Result<Tensor<Self>, ZyxError> {
70        Ok(tensor(
71            match dtype {
72                DType::F32 => self.store(
73                    (0..n)
74                        .flat_map(move |i| (0..n).map(move |j| if j == i { 1f32 } else { 0. }))
75                        .make_sized(n * n),
76                )?,
77                DType::F64 => self.store(
78                    (0..n)
79                        .flat_map(move |i| (0..n).map(move |j| if j == i { 1f64 } else { 0. }))
80                        .make_sized(n * n),
81                )?,
82                DType::I32 => self.store(
83                    (0..n)
84                        .flat_map(move |i| (0..n).map(move |j| if j == i { 1i32 } else { 0 }))
85                        .make_sized(n * n),
86                )?,
87            },
88            self,
89        )
90        .reshape([n, n]))
91    }
92
93    /// Get shape if tensor x
94    #[must_use]
95    fn shape(self, x: Id) -> Shape;
96    /// Get dtype of tensor x
97    #[must_use]
98    fn dtype(self, x: Id) -> DType;
99    /// Calculate derivatives of x w.r.t. sources.
100    /// Returns map source id -> gradient id
101    fn backward(self, x: Id, sources: &BTreeSet<Id>) -> Result<BTreeMap<Id, Id>, ZyxError>;
102    /// Returns iterator over data stored in backend
103    fn load<T: Scalar>(self, id: Id) -> Result<Vec<T>, ZyxError>;
104    /// Store iterator into backend as tensor
105    fn store<T: Scalar, IT>(self, iter: IT) -> Result<Id, ZyxError>
106    where
107        IT: IntoIterator<Item = T>,
108        IT::IntoIter: ExactSizeIterator;
109    /// Create new tensor from given operation
110    fn push(self, node: Node) -> Result<Id, ZyxError>;
111    /// Decrease reference count of tensor
112    fn release(self, x: Id) -> Result<(), ZyxError>;
113    /// Increase reference count of tensor
114    fn retain(self, x: Id);
115}