// shrew_core/backend.rs
1use crate::dtype::DType;
2use crate::error::Result;
3use crate::layout::Layout;
4use crate::shape::Shape;
5use std::fmt;

// Backend — Abstraction over compute devices (CPU, CUDA, etc.)
//
// The Backend trait is the central abstraction that makes Shrew extensible.
// Each backend (CPU, CUDA, ROCm, etc.) implements this trait, providing
// its own storage type and operation implementations.
//
// WHY A TRAIT AND NOT AN ENUM?
//
// Using a trait (vs. an enum like `Device::Cpu | Device::Cuda`) means:
// - New backends can be added as separate crates without modifying shrew-core
// - Each backend can have different associated types (e.g., CudaStorage vs CpuStorage)
// - The compiler can monomorphize for performance
//
// The tradeoff is that Tensor becomes generic: Tensor<B: Backend>.
// This is similar to Burn's approach and provides maximum flexibility.

/// Identifies a compute device (e.g., "CPU", "CUDA:0", "CUDA:1").
///
/// The `Clone + Send + Sync + 'static` bounds let device handles be
/// duplicated and shared across threads freely.
pub trait BackendDevice: Clone + fmt::Debug + Send + Sync + 'static {
    /// A human-readable name for this device (e.g., "cpu", "cuda:0").
    fn name(&self) -> String;
}
29/// A storage buffer that holds tensor data on a specific device.
30///
31/// For CPU, this is a `Vec<f32>` (or enum over dtypes).
32/// For CUDA, this is a device memory allocation (`CudaSlice`).
33pub trait BackendStorage: Clone + Send + Sync + 'static {
34    /// The data type of the elements in this storage.
35    fn dtype(&self) -> DType;
36
37    /// Total number of elements that fit in this storage.
38    fn len(&self) -> usize;
39
40    fn is_empty(&self) -> bool {
41        self.len() == 0
42    }
43}

// Binary and Unary operation enums
//
// These enums serve two purposes:
// 1. They parameterize the backend ops (so we have one trait method per category)
// 2. They are recorded in the Op enum for autograd (knowing WHICH binary op
//    was performed is needed to compute the correct gradient)
/// Element-wise binary operations.
///
/// Recorded in the autograd op graph so the backward pass knows which
/// gradient rule to apply. `Hash` is derived so ops can key maps/sets
/// (e.g., kernel caches).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinaryOp {
    Add,
    Sub,
    Mul,
    Div,
}
/// Element-wise unary operations.
///
/// Recorded in the autograd op graph so the backward pass knows which
/// gradient rule to apply. `Hash` is derived so ops can key maps/sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOp {
    Neg,
    Abs,
    Exp,
    Log,
    Sqrt,
    Relu,
    Sigmoid,
    Tanh,
    Gelu,
    Silu,
    Sin,
    Cos,
    Square,
    Floor,
    Ceil,
    Round,
}
/// Reduction operations along dimension(s).
///
/// Recorded in the autograd op graph so the backward pass knows which
/// gradient rule to apply. `Hash` is derived so ops can key maps/sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReduceOp {
    Sum,
    Mean,
    Max,
    Min,
    ArgMax,
    ArgMin,
}
/// Comparison operations (produce boolean tensors, stored as u8 — see
/// [`Backend::cmp_op`]). `Hash` is derived so ops can key maps/sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CmpOp {
    Eq,
    Ne,
    Gt,
    Ge,
    Lt,
    Le,
}

// Backend Trait — The core interface every backend must implement

106/// The main Backend trait. Implementing this for a struct (e.g., CpuBackend)
107/// makes that struct a complete compute backend for Shrew.
108///
109/// All operations take storage + layout (which encodes shape/strides) and
110/// return new storage (immutable semantics — no in-place mutation by default).
111pub trait Backend: Clone + Send + Sync + fmt::Debug + 'static {
112    /// The device type for this backend.
113    type Device: BackendDevice;
114    /// The storage type for this backend.
115    type Storage: BackendStorage;
116
117    //  Creation 
118
119    /// Allocate storage filled with zeros.
120    fn zeros(shape: &Shape, dtype: DType, device: &Self::Device) -> Result<Self::Storage>;
121
122    /// Allocate storage filled with ones.
123    fn ones(shape: &Shape, dtype: DType, device: &Self::Device) -> Result<Self::Storage>;
124
125    /// Allocate storage filled with a constant value.
126    fn full(shape: &Shape, val: f64, dtype: DType, device: &Self::Device) -> Result<Self::Storage>;
127
128    /// Create storage from a flat f64 slice, converting to the target dtype.
129    fn from_f64_slice(data: &[f64], dtype: DType, device: &Self::Device) -> Result<Self::Storage>;
130
131    /// Create storage with random uniform values in [0, 1).
132    fn rand_uniform(shape: &Shape, dtype: DType, device: &Self::Device) -> Result<Self::Storage>;
133
134    /// Create storage with random normal values (mean=0, std=1).
135    fn rand_normal(shape: &Shape, dtype: DType, device: &Self::Device) -> Result<Self::Storage>;
136
137    //  Element-wise binary ops 
138
139    /// Apply a binary op element-wise: result[i] = op(lhs[i], rhs[i]).
140    /// The layouts handle broadcasting and non-contiguous access.
141    fn binary_op(
142        op: BinaryOp,
143        lhs: &Self::Storage,
144        lhs_layout: &Layout,
145        rhs: &Self::Storage,
146        rhs_layout: &Layout,
147    ) -> Result<Self::Storage>;
148
149    //  Element-wise unary ops 
150
151    /// Apply a unary op element-wise: result[i] = op(input[i]).
152    fn unary_op(op: UnaryOp, input: &Self::Storage, layout: &Layout) -> Result<Self::Storage>;
153
154    //  Reductions 
155
156    /// Reduce along specific dimensions.
157    /// If `dims` is empty, reduce over all elements.
158    fn reduce_op(
159        op: ReduceOp,
160        input: &Self::Storage,
161        layout: &Layout,
162        dims: &[usize],
163        keep_dim: bool,
164    ) -> Result<Self::Storage>;
165
166    //  Matrix multiplication 
167
168    /// General matrix multiply: C = A @ B.
169    /// Supports batched matmul for tensors with rank > 2.
170    fn matmul(
171        lhs: &Self::Storage,
172        lhs_layout: &Layout,
173        rhs: &Self::Storage,
174        rhs_layout: &Layout,
175    ) -> Result<Self::Storage>;
176
177    //  Data movement 
178
179    /// Make a contiguous copy of the storage following the given layout.
180    /// If the layout is already contiguous, this may just clone the storage.
181    fn to_contiguous(input: &Self::Storage, layout: &Layout) -> Result<Self::Storage>;
182
183    /// Copy data from this storage to a Vec<f64> on the host (for inspection).
184    fn to_f64_vec(input: &Self::Storage, layout: &Layout) -> Result<Vec<f64>>;
185
186    //  Comparison ops 
187
188    /// Element-wise comparison, returns a u8 storage (0 or 1).
189    fn cmp_op(
190        op: CmpOp,
191        lhs: &Self::Storage,
192        lhs_layout: &Layout,
193        rhs: &Self::Storage,
194        rhs_layout: &Layout,
195    ) -> Result<Self::Storage>;
196
197    //  Affine / fused ops (optional but useful) 
198
199    /// Affine transform: result = input * mul + add.
200    /// Used for normalization and other fused operations.
201    fn affine(input: &Self::Storage, layout: &Layout, mul: f64, add: f64) -> Result<Self::Storage>;
202
203    //  Indexing 
204
205    /// Gather elements along a dimension using index tensor.
206    fn index_select(
207        input: &Self::Storage,
208        input_layout: &Layout,
209        indices: &Self::Storage,
210        indices_layout: &Layout,
211        dim: usize,
212    ) -> Result<Self::Storage>;
213
214    //  Powf 
215
216    /// Element-wise power: result[i] = input[i] ^ exponent.
217    fn powf(input: &Self::Storage, layout: &Layout, exponent: f64) -> Result<Self::Storage>;
218
219    //  Clamp 
220
221    /// Element-wise clamp: result[i] = clamp(input[i], min, max).
222    fn clamp(input: &Self::Storage, layout: &Layout, min: f64, max: f64) -> Result<Self::Storage>;
223
224    //  Where / conditional select 
225
226    /// Element-wise conditional: result[i] = if mask[i] != 0 { on_true[i] } else { on_false[i] }.
227    fn where_cond(
228        mask: &Self::Storage,
229        mask_layout: &Layout,
230        on_true: &Self::Storage,
231        on_true_layout: &Layout,
232        on_false: &Self::Storage,
233        on_false_layout: &Layout,
234    ) -> Result<Self::Storage>;
235
236    //  Gather 
237
238    /// Gather elements along `dim` using `index` tensor.
239    ///
240    /// `output[i][j][k] = input[index[i][j][k]][j][k]`  (when dim=0)
241    /// `output[i][j][k] = input[i][index[i][j][k]][k]`  (when dim=1)
242    /// etc.
243    ///
244    /// `index` must have the same number of dimensions as `input`.
245    fn gather(
246        input: &Self::Storage,
247        input_layout: &Layout,
248        index: &Self::Storage,
249        index_layout: &Layout,
250        dim: usize,
251    ) -> Result<Self::Storage>;
252
253    //  Concatenation 
254
255    /// Concatenate multiple storages along `dim` into a single contiguous storage.
256    /// Each entry is (storage, layout) so non-contiguous inputs are handled correctly.
257    /// `out_shape` is the pre-validated output shape.
258    fn cat(
259        inputs: &[(&Self::Storage, &Layout)],
260        out_shape: &Shape,
261        dim: usize,
262    ) -> Result<Self::Storage>;
263
264    //  Dtype conversion 
265
266    /// Cast storage to a different dtype on-device (no host round-trip).
267    ///
268    /// The default implementation falls back to `to_f64_vec` + `from_f64_slice`,
269    /// which involves a host round-trip. Backends should override this with
270    /// a native on-device kernel when possible.
271    fn cast(
272        input: &Self::Storage,
273        layout: &Layout,
274        dtype: DType,
275        device: &Self::Device,
276    ) -> Result<Self::Storage> {
277        let data = Self::to_f64_vec(input, layout)?;
278        Self::from_f64_slice(&data, dtype, device)
279    }
280}