zyx_compiler/
lib.rs

1//! Zyx compiler to IR
2
3#![no_std]
4#![forbid(unsafe_code)]
5#![forbid(rustdoc::broken_intra_doc_links)]
6#![forbid(rustdoc::private_intra_doc_links)]
7//#![forbid(missing_docs)]
8#![forbid(rustdoc::missing_crate_level_docs)]
9//#![forbid(rustdoc::missing_doc_code_examples)]
10#![forbid(rustdoc::private_doc_tests)]
11#![forbid(rustdoc::invalid_codeblock_attributes)]
12#![forbid(rustdoc::invalid_html_tags)]
13#![forbid(rustdoc::invalid_rust_codeblocks)]
14#![forbid(rustdoc::bare_urls)]
15#![forbid(rustdoc::unescaped_backticks)]
16#![forbid(rustdoc::redundant_explicit_links)]
17
18use zyx_core::error::ZyxError;
19
20mod ast;
21mod ir;
22
23use ast::Kernel;
24pub use ir::{Op, IR, UOp, BOp};
25
26#[cfg(feature = "std")]
27extern crate std;
28
29extern crate alloc;
30
31use alloc::{collections::BTreeMap, vec::Vec};
32use zyx_core::axes::Axes;
33use zyx_core::dtype::DType;
34use zyx_core::scalar::Scalar;
35use zyx_core::shape::Shape;
36use zyx_core::tensor::Id;
37use zyx_core::view::View;
38
39/// Compiled backend that holds compiler, buffers and programs
40pub struct CompiledBackend<C: Compiler> {
41    compiler: C,
42    kernels: BTreeMap<Id, Kernel>,
43    buffers: BTreeMap<Id, C::Buffer>,
44    programs: BTreeMap<AST, C::Program>,
45}
46
47/// Implement this trait for compiled backends
48pub trait Compiler {
49    /// Buffer holds actual values in memory
50    type Buffer;
51    /// Program is kernel executable on the device, can be compiled at runtime
52    type Program;
53    /// Store iter into buffer
54    fn store<T: Scalar>(
55        &mut self,
56        iter: impl IntoIterator<Item = T>,
57    ) -> Result<Self::Buffer, ZyxError>;
58    /// Load buffer into vec
59    fn load<T: Scalar>(&mut self, buffer: &Self::Buffer, numel: usize) -> Result<Vec<T>, ZyxError>;
60    /// Drop Buffer
61    fn drop_buffer(&mut self, buffer: &mut Self::Buffer) -> Result<(), ZyxError>;
62    /// Drop Program
63    fn drop_program(&mut self, program: &mut Self::Program) -> Result<(), ZyxError>;
64    /// Launch program with args
65    fn launch(
66        &mut self,
67        program: &Self::Program,
68        args: &[&Self::Buffer],
69        flop: usize,
70        bytes: usize,
71    ) -> Result<Self::Buffer, ZyxError>;
72    /// Compile ast into program
73    fn compile(&mut self, ir: &IR) -> Result<Self::Program, ZyxError>;
74}
75
76#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
77enum ASTUOp {
78    Cast(DType),
79    Neg,
80    ReLU,
81    Sin,
82    Cos,
83    Exp,
84    Ln,
85    Tanh,
86    Sqrt,
87}
88
/// Binary operations that can appear in an AST.
///
/// The enum has no fields, so it also derives `Copy`: values can be passed
/// around and matched on by value without moving or cloning them.
///
/// The comparison traits are derived, so the ordering of two values follows
/// the declaration order of the variants below. Keep that order stable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum ASTBOp {
    Add,
    Sub,
    Mul,
    Div,
    Pow,
    Cmplt,
}
98
/// Reduce operations that can appear in an AST.
///
/// The enum has no fields, so it also derives `Copy`: values can be passed
/// around and matched on by value without moving or cloning them.
///
/// The comparison traits are derived, so `Sum` orders before `Max`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum ASTROp {
    Sum,
    Max,
}
104
105#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
106enum ASTOp {
107    Leaf(u8),
108    Unary(u8, ASTUOp),
109    Binary(u8, u8, ASTBOp),
110    #[allow(dead_code)]
111    Where(u8, u8, u8),
112    Reduce(u8, ASTROp),
113}
114
115/// Abstract syntax tree that can be compiled into program.
116/// Consists of kernel arguments, elementwise ops, optional reduce op
117/// and elementwise ops after reduce.
118/// This struct is immutable.
119#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
120struct AST {
121    /// AST argument views
122    pub arg_views: Vec<View>,
123    /// AST argument dtypes
124    pub arg_dtypes: Vec<DType>,
125    /// AST ops
126    pub ops: Vec<ASTOp>,
127    /// Shape of the result, this is before any reduce ops
128    pub shape: Shape,
129    /// DType of the result
130    pub dtype: DType,
131    /// Reduce axes, if any
132    pub reduce_axes: Option<Axes>,
133    /// DType of accumulated elements, if any
134    pub reduce_dtype: Option<DType>,
135}