//! shrew_core/backend.rs
1use crate::dtype::DType;
2use crate::error::Result;
3use crate::layout::Layout;
4use crate::shape::Shape;
5use std::fmt;
6
7// Backend — Abstraction over compute devices (CPU, CUDA, etc.)
8//
9// The Backend trait is the central abstraction that makes Shrew extensible.
10// Each backend (CPU, CUDA, ROCm, etc.) implements this trait, providing
11// its own storage type and operation implementations.
12//
13// WHY A TRAIT AND NOT AN ENUM?
14//
15// Using a trait (vs. an enum like `Device::Cpu | Device::Cuda`) means:
16// - New backends can be added as separate crates without modifying shrew-core
17// - Each backend can have different associated types (e.g., CudaStorage vs CpuStorage)
18// - The compiler can monomorphize for performance
19//
20// The tradeoff is that Tensor becomes generic: Tensor<B: Backend>.
21// This is similar to Burn's approach and provides maximum flexibility.
22
/// Identifies a compute device (e.g., "CPU", "CUDA:0", "CUDA:1").
///
/// Implementations must be `Clone + Send + Sync + 'static` so a device handle
/// can be captured and shared freely across threads.
///
/// NOTE(review): this trait exposes no equality operation; `name()` is the
/// only way to distinguish two devices through the trait — confirm that
/// backends keep names unique per device.
pub trait BackendDevice: Clone + fmt::Debug + Send + Sync + 'static {
    /// A human-readable name for this device (e.g., "cpu", "cuda:0").
    fn name(&self) -> String;
}
28
/// A storage buffer that holds tensor data on a specific device.
///
/// For CPU, this is a `Vec<f32>` (or enum over dtypes).
/// For CUDA, this is a device memory allocation (`CudaSlice`).
///
/// NOTE(review): the cost of `Clone` is backend-defined (deep copy for a CPU
/// vector vs. possibly a handle copy for a device buffer) — verify per
/// backend before relying on clone being cheap.
pub trait BackendStorage: Clone + Send + Sync + 'static {
    /// The data type of the elements in this storage.
    fn dtype(&self) -> DType;

    /// Total number of elements that fit in this storage.
    fn len(&self) -> usize;

    /// Returns `true` when the storage holds no elements.
    /// Provided in terms of `len`, so implementors get it for free.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
44
45// Binary and Unary operation enums
46//
47// These enums serve two purposes:
48// 1. They parameterize the backend ops (so we have one trait method per category)
49// 2. They are recorded in the Op enum for autograd (knowing WHICH binary op
50// was performed is needed to compute the correct gradient)
51
/// Element-wise binary operations.
///
/// Parameterizes [`Backend::binary_op`] (one trait method for the whole
/// category) and is recorded by autograd so the correct gradient rule for the
/// performed op can be selected.
///
/// `Hash` is derived (the enum is fieldless `Copy + Eq`, so it is free) so
/// ops can be used as map/set keys, e.g. for kernel or op-dispatch caches.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinaryOp {
    /// `lhs + rhs`
    Add,
    /// `lhs - rhs`
    Sub,
    /// `lhs * rhs`
    Mul,
    /// `lhs / rhs`
    Div,
}
60
/// Element-wise unary operations.
///
/// Parameterizes [`Backend::unary_op`] and is recorded by autograd so the
/// matching gradient rule can be applied.
///
/// `Hash` is derived (fieldless `Copy + Eq` enum, zero cost) so ops can be
/// used as map/set keys, e.g. for kernel-dispatch caches.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOp {
    /// Arithmetic negation `-x`.
    Neg,
    /// Absolute value `|x|`.
    Abs,
    /// Exponential `e^x`.
    Exp,
    /// Natural logarithm.
    Log,
    /// Square root.
    Sqrt,
    /// ReLU activation.
    Relu,
    /// Sigmoid activation.
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,
    /// GELU activation.
    Gelu,
    /// SiLU (swish) activation.
    Silu,
    /// Sine.
    Sin,
    /// Cosine.
    Cos,
    /// `x * x`.
    Square,
    /// Round toward negative infinity.
    Floor,
    /// Round toward positive infinity.
    Ceil,
    /// Round to nearest.
    Round,
}
81
/// Reduction operations along dimension(s).
///
/// Parameterizes [`Backend::reduce_op`]. `ArgMax`/`ArgMin` reduce to the
/// position of the extremum rather than its value.
///
/// `Hash` is derived (fieldless `Copy + Eq` enum, zero cost) so ops can be
/// used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReduceOp {
    /// Sum of elements.
    Sum,
    /// Arithmetic mean of elements.
    Mean,
    /// Maximum element value.
    Max,
    /// Minimum element value.
    Min,
    /// Index of the maximum element.
    ArgMax,
    /// Index of the minimum element.
    ArgMin,
}
92
/// Comparison operations (produce boolean tensors).
///
/// Parameterizes [`Backend::cmp_op`], which yields a u8 storage of 0/1 values.
///
/// `Hash` is derived (fieldless `Copy + Eq` enum, zero cost) so ops can be
/// used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CmpOp {
    /// `lhs == rhs`
    Eq,
    /// `lhs != rhs`
    Ne,
    /// `lhs > rhs`
    Gt,
    /// `lhs >= rhs`
    Ge,
    /// `lhs < rhs`
    Lt,
    /// `lhs <= rhs`
    Le,
}
103
104// Backend Trait — The core interface every backend must implement
105
/// The main Backend trait. Implementing this for a struct (e.g., CpuBackend)
/// makes that struct a complete compute backend for Shrew.
///
/// All operations take storage + layout (which encodes shape/strides) and
/// return new storage (immutable semantics — no in-place mutation by default).
///
/// Every method is an associated function (no `self` receiver); creation
/// methods take the target `Device` explicitly, while compute methods act on
/// already-allocated `Storage` values.
pub trait Backend: Clone + Send + Sync + fmt::Debug + 'static {
    /// The device type for this backend.
    type Device: BackendDevice;
    /// The storage type for this backend.
    type Storage: BackendStorage;

    // Creation

    /// Allocate storage filled with zeros.
    fn zeros(shape: &Shape, dtype: DType, device: &Self::Device) -> Result<Self::Storage>;

    /// Allocate storage filled with ones.
    fn ones(shape: &Shape, dtype: DType, device: &Self::Device) -> Result<Self::Storage>;

    /// Allocate storage filled with a constant value.
    /// `val` is passed as f64; the backend converts it to `dtype`.
    fn full(shape: &Shape, val: f64, dtype: DType, device: &Self::Device) -> Result<Self::Storage>;

    /// Create storage from a flat f64 slice, converting to the target dtype.
    fn from_f64_slice(data: &[f64], dtype: DType, device: &Self::Device) -> Result<Self::Storage>;

    /// Create storage with random uniform values in [0, 1).
    fn rand_uniform(shape: &Shape, dtype: DType, device: &Self::Device) -> Result<Self::Storage>;

    /// Create storage with random normal values (mean=0, std=1).
    fn rand_normal(shape: &Shape, dtype: DType, device: &Self::Device) -> Result<Self::Storage>;

    // Element-wise binary ops

    /// Apply a binary op element-wise: result[i] = op(lhs[i], rhs[i]).
    /// The layouts handle broadcasting and non-contiguous access.
    fn binary_op(
        op: BinaryOp,
        lhs: &Self::Storage,
        lhs_layout: &Layout,
        rhs: &Self::Storage,
        rhs_layout: &Layout,
    ) -> Result<Self::Storage>;

    // Element-wise unary ops

    /// Apply a unary op element-wise: result[i] = op(input[i]).
    fn unary_op(op: UnaryOp, input: &Self::Storage, layout: &Layout) -> Result<Self::Storage>;

    // Reductions

    /// Reduce along specific dimensions.
    /// If `dims` is empty, reduce over all elements.
    /// `keep_dim` selects whether reduced dimensions are kept with size 1 or
    /// dropped from the output shape.
    fn reduce_op(
        op: ReduceOp,
        input: &Self::Storage,
        layout: &Layout,
        dims: &[usize],
        keep_dim: bool,
    ) -> Result<Self::Storage>;

    // Matrix multiplication

    /// General matrix multiply: C = A @ B.
    /// Supports batched matmul for tensors with rank > 2.
    fn matmul(
        lhs: &Self::Storage,
        lhs_layout: &Layout,
        rhs: &Self::Storage,
        rhs_layout: &Layout,
    ) -> Result<Self::Storage>;

    // Data movement

    /// Make a contiguous copy of the storage following the given layout.
    /// If the layout is already contiguous, this may just clone the storage.
    fn to_contiguous(input: &Self::Storage, layout: &Layout) -> Result<Self::Storage>;

    /// Copy data from this storage to a Vec<f64> on the host (for inspection).
    fn to_f64_vec(input: &Self::Storage, layout: &Layout) -> Result<Vec<f64>>;

    // Comparison ops

    /// Element-wise comparison, returns a u8 storage (0 or 1).
    fn cmp_op(
        op: CmpOp,
        lhs: &Self::Storage,
        lhs_layout: &Layout,
        rhs: &Self::Storage,
        rhs_layout: &Layout,
    ) -> Result<Self::Storage>;

    // Affine / fused ops (optional but useful)

    /// Affine transform: result = input * mul + add.
    /// Used for normalization and other fused operations.
    fn affine(input: &Self::Storage, layout: &Layout, mul: f64, add: f64) -> Result<Self::Storage>;

    // Indexing

    /// Gather elements along a dimension using index tensor.
    fn index_select(
        input: &Self::Storage,
        input_layout: &Layout,
        indices: &Self::Storage,
        indices_layout: &Layout,
        dim: usize,
    ) -> Result<Self::Storage>;

    // Powf

    /// Element-wise power: result[i] = input[i] ^ exponent.
    fn powf(input: &Self::Storage, layout: &Layout, exponent: f64) -> Result<Self::Storage>;

    // Clamp

    /// Element-wise clamp: result[i] = clamp(input[i], min, max).
    fn clamp(input: &Self::Storage, layout: &Layout, min: f64, max: f64) -> Result<Self::Storage>;

    // Where / conditional select

    /// Element-wise conditional: result[i] = if mask[i] != 0 { on_true[i] } else { on_false[i] }.
    fn where_cond(
        mask: &Self::Storage,
        mask_layout: &Layout,
        on_true: &Self::Storage,
        on_true_layout: &Layout,
        on_false: &Self::Storage,
        on_false_layout: &Layout,
    ) -> Result<Self::Storage>;

    // Gather

    /// Gather elements along `dim` using `index` tensor.
    ///
    /// `output[i][j][k] = input[index[i][j][k]][j][k]` (when dim=0)
    /// `output[i][j][k] = input[i][index[i][j][k]][k]` (when dim=1)
    /// etc.
    ///
    /// `index` must have the same number of dimensions as `input`.
    fn gather(
        input: &Self::Storage,
        input_layout: &Layout,
        index: &Self::Storage,
        index_layout: &Layout,
        dim: usize,
    ) -> Result<Self::Storage>;

    // Concatenation

    /// Concatenate multiple storages along `dim` into a single contiguous storage.
    /// Each entry is (storage, layout) so non-contiguous inputs are handled correctly.
    /// `out_shape` is the pre-validated output shape (callers validate; backends
    /// may assume it is consistent with `inputs` and `dim`).
    fn cat(
        inputs: &[(&Self::Storage, &Layout)],
        out_shape: &Shape,
        dim: usize,
    ) -> Result<Self::Storage>;

    // Dtype conversion

    /// Cast storage to a different dtype on-device (no host round-trip).
    ///
    /// The default implementation falls back to `to_f64_vec` + `from_f64_slice`,
    /// which involves a host round-trip. Backends should override this with
    /// a native on-device kernel when possible.
    ///
    /// NOTE(review): the fallback round-trips through f64, so integer dtypes
    /// holding values beyond f64's 53-bit mantissa (if `DType` supports any)
    /// would lose precision — override in backends where that matters.
    fn cast(
        input: &Self::Storage,
        layout: &Layout,
        dtype: DType,
        device: &Self::Device,
    ) -> Result<Self::Storage> {
        // Read out via the layout (handles non-contiguous input)...
        let data = Self::to_f64_vec(input, layout)?;
        // ...then re-create the buffer in the requested dtype on `device`.
        Self::from_f64_slice(&data, dtype, device)
    }
}