1#![no_std]
14#![forbid(rustdoc::broken_intra_doc_links)]
16#![forbid(rustdoc::private_intra_doc_links)]
17#![forbid(missing_docs)]
18#![forbid(rustdoc::missing_crate_level_docs)]
19#![forbid(rustdoc::private_doc_tests)]
21#![forbid(rustdoc::invalid_codeblock_attributes)]
22#![forbid(rustdoc::invalid_html_tags)]
23#![forbid(rustdoc::invalid_rust_codeblocks)]
24#![forbid(rustdoc::bare_urls)]
25#![forbid(rustdoc::unescaped_backticks)]
26#![forbid(rustdoc::redundant_explicit_links)]
27
28#[cfg(feature = "std")]
29extern crate std;
30
31mod interpreter;
32use crate::interpreter::Interpreter;
33
extern crate alloc;
use alloc::{
    collections::{BTreeMap, BTreeSet},
    vec::Vec,
};
use core::cell::RefCell;
use core::ops::Range;
#[cfg(feature = "std")]
pub use zyx_core::io::save;
use zyx_core::{
    backend::Backend,
    node::Node,
    runtime::Runtime,
    scalar::Scalar,
    shape::Shape,
    tensor::Id,
    tensor::{tensor, IntoTensor},
};
pub use zyx_core::{dtype::DType, error::ZyxError, tensor::Tensor};
53
54pub struct CPU(RefCell<Runtime<Interpreter>>);
56
57pub fn device() -> Result<CPU, ZyxError> {
59 Ok(CPU(RefCell::new(Runtime::new(Interpreter::new()))))
60}
61
62impl CPU {
63 #[must_use]
65 pub fn tensor<'a>(&'a self, data: impl IntoTensor<&'a Self>) -> Tensor<&'a Self> {
66 <&Self as Backend>::tensor(self, data).unwrap()
67 }
68
69 #[must_use]
71 pub fn randn(&self, shape: impl Into<Shape>, dtype: DType) -> Tensor<&Self> {
72 <&Self as Backend>::randn(self, shape, dtype).unwrap()
73 }
74
75 #[must_use]
77 pub fn uniform(&self, shape: impl Into<Shape>, range: Range<impl Scalar>) -> Tensor<&Self> {
78 <&Self as Backend>::uniform(self, shape, range).unwrap()
79 }
80
81 #[must_use]
83 pub fn full(&self, shape: impl Into<Shape>, value: impl Scalar) -> Tensor<&Self> {
84 <&Self as Backend>::full(self, shape, value).unwrap()
85 }
86
87 #[must_use]
89 pub fn zeros(&self, shape: impl Into<Shape>, dtype: DType) -> Tensor<&Self> {
90 <&Self as Backend>::zeros(self, shape, dtype).unwrap()
91 }
92
93 #[must_use]
95 pub fn ones(&self, shape: impl Into<Shape>, dtype: DType) -> Tensor<&Self> {
96 <&Self as Backend>::ones(self, shape, dtype).unwrap()
97 }
98
99 #[must_use]
101 pub fn eye(&self, n: usize, dtype: DType) -> Tensor<&Self> {
102 <&Self as Backend>::eye(self, n, dtype).unwrap()
103 }
104
105 #[must_use]
107 pub fn plot_graph<'a, B: Backend + 'a>(
108 &self,
109 tensors: impl IntoIterator<Item = &'a Tensor<B>>,
110 ) -> alloc::string::String {
111 <&Self as Backend>::plot_graph(self, tensors)
112 }
113
114 #[cfg(feature = "std")]
116 pub fn load(&self, path: impl AsRef<std::path::Path>) -> Result<Vec<Tensor<&CPU>>, ZyxError> {
117 zyx_core::io::load(self, path)
118 }
119}
120
121impl Backend for &CPU {
122 fn plot_graph<'a, B: Backend + 'a>(
123 self,
124 tensors: impl IntoIterator<Item = &'a Tensor<B>>,
125 ) -> alloc::string::String {
126 let ids: Vec<Id> = tensors.into_iter().map(|t| t.id()).collect();
127 self.0.borrow().plot_graph_dot(&ids)
128 }
129
130 fn randn(self, shape: impl Into<Shape>, dtype: DType) -> Result<Tensor<Self>, ZyxError> {
131 Ok(tensor(
132 self.0.borrow_mut().randn(shape.into(), dtype)?,
133 self,
134 ))
135 }
136
137 fn uniform(
138 self,
139 shape: impl Into<Shape>,
140 range: Range<impl Scalar>,
141 ) -> Result<Tensor<Self>, ZyxError> {
142 Ok(tensor(
143 self.0.borrow_mut().uniform(shape.into(), range)?,
144 self,
145 ))
146 }
147
148 fn shape(self, x: Id) -> Shape {
149 self.0.borrow().shape(x).clone()
150 }
151
152 fn dtype(self, x: Id) -> DType {
153 self.0.borrow().dtype(x)
154 }
155
156 fn backward(self, x: Id, sources: &BTreeSet<Id>) -> Result<BTreeMap<Id, Id>, ZyxError> {
157 self.0.borrow_mut().backward(x, sources)
158 }
159
160 fn load<T: Scalar>(self, x: Id) -> Result<Vec<T>, ZyxError> {
161 self.0.borrow_mut().load(x)
162 }
163
164 fn store<T: Scalar, IT>(self, iter: IT) -> Result<Id, ZyxError>
165 where
166 IT: IntoIterator<Item = T>,
167 IT::IntoIter: ExactSizeIterator,
168 {
169 self.0.borrow_mut().store(iter)
170 }
171
172 fn push(self, node: Node) -> Result<Id, ZyxError> {
173 self.0.borrow_mut().push(node)
174 }
175
176 fn release(self, x: Id) -> Result<(), ZyxError> {
177 self.0.borrow_mut().release(x)
178 }
179
180 fn retain(self, x: Id) {
181 self.0.borrow_mut().retain(x);
182 }
183}
184
185