1#![no_std]
4#![forbid(unsafe_code)]
5#![forbid(rustdoc::broken_intra_doc_links)]
6#![forbid(rustdoc::private_intra_doc_links)]
7#![forbid(rustdoc::missing_crate_level_docs)]
9#![forbid(rustdoc::private_doc_tests)]
11#![forbid(rustdoc::invalid_codeblock_attributes)]
12#![forbid(rustdoc::invalid_html_tags)]
13#![forbid(rustdoc::invalid_rust_codeblocks)]
14#![forbid(rustdoc::bare_urls)]
15#![forbid(rustdoc::unescaped_backticks)]
16#![forbid(rustdoc::redundant_explicit_links)]
17
18use zyx_core::error::ZyxError;
19
20mod ast;
21mod ir;
22
23use ast::Kernel;
24pub use ir::{Op, IR, UOp, BOp};
25
26#[cfg(feature = "std")]
27extern crate std;
28
29extern crate alloc;
30
31use alloc::{collections::BTreeMap, vec::Vec};
32use zyx_core::axes::Axes;
33use zyx_core::dtype::DType;
34use zyx_core::scalar::Scalar;
35use zyx_core::shape::Shape;
36use zyx_core::tensor::Id;
37use zyx_core::view::View;
38
39pub struct CompiledBackend<C: Compiler> {
41 compiler: C,
42 kernels: BTreeMap<Id, Kernel>,
43 buffers: BTreeMap<Id, C::Buffer>,
44 programs: BTreeMap<AST, C::Program>,
45}
46
47pub trait Compiler {
49 type Buffer;
51 type Program;
53 fn store<T: Scalar>(
55 &mut self,
56 iter: impl IntoIterator<Item = T>,
57 ) -> Result<Self::Buffer, ZyxError>;
58 fn load<T: Scalar>(&mut self, buffer: &Self::Buffer, numel: usize) -> Result<Vec<T>, ZyxError>;
60 fn drop_buffer(&mut self, buffer: &mut Self::Buffer) -> Result<(), ZyxError>;
62 fn drop_program(&mut self, program: &mut Self::Program) -> Result<(), ZyxError>;
64 fn launch(
66 &mut self,
67 program: &Self::Program,
68 args: &[&Self::Buffer],
69 flop: usize,
70 bytes: usize,
71 ) -> Result<Self::Buffer, ZyxError>;
72 fn compile(&mut self, ir: &IR) -> Result<Self::Program, ZyxError>;
74}
75
76#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
77enum ASTUOp {
78 Cast(DType),
79 Neg,
80 ReLU,
81 Sin,
82 Cos,
83 Exp,
84 Ln,
85 Tanh,
86 Sqrt,
87}
88
/// Binary operation carried by an `ASTOp::Binary` node.
///
/// Variant order is significant. The derived `Ord` compares variants in
/// declaration order, and that ordering feeds into `AST`, which is used as a
/// `BTreeMap` key.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
enum ASTBOp {
    Add,
    Sub,
    Mul,
    Div,
    Pow,
    /// NOTE(review): presumably a "less than" comparison — confirm.
    Cmplt,
}
98
/// Reduction carried by an `ASTOp::Reduce` node.
///
/// Variant order is significant. The derived `Ord` compares variants in
/// declaration order, so `Sum` sorts before `Max`.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
enum ASTROp {
    Sum,
    Max,
}
104
105#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
106enum ASTOp {
107 Leaf(u8),
108 Unary(u8, ASTUOp),
109 Binary(u8, u8, ASTBOp),
110 #[allow(dead_code)]
111 Where(u8, u8, u8),
112 Reduce(u8, ASTROp),
113}
114
115#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
120struct AST {
121 pub arg_views: Vec<View>,
123 pub arg_dtypes: Vec<DType>,
125 pub ops: Vec<ASTOp>,
127 pub shape: Shape,
129 pub dtype: DType,
131 pub reduce_axes: Option<Axes>,
133 pub reduce_dtype: Option<DType>,
135}