//! shrew_core/layout.rs
1use crate::error::{Error, Result};
2use crate::shape::Shape;
3
4// Layout — Memory layout of a tensor (shape + strides + offset)
5//
6// The Layout decouples the *logical* shape of a tensor from how its data is
7// arranged in memory. This is what makes operations like transpose, reshape,
8// and slicing "free" (no data copy needed — just change the layout).
9//
10// KEY CONCEPTS:
11//
12// 1. **Strides**: How many elements to skip in the flat storage to move one
13// step along each dimension. A contiguous [2,3] matrix has strides [3,1]:
14// - row 0 starts at offset 0, row 1 starts at offset 3
15// - within a row, consecutive elements are 1 apart
16//
17// 2. **Transpose**: Just swap the strides (and shape). No data movement!
18// [2,3] with strides [3,1] → transpose → [3,2] with strides [1,3]
19// The same data, but now read column-major.
20//
21// 3. **Narrow/Slice**: Just adjust the offset and shape. Still same storage.
22// narrow(dim=1, start=1, len=2) on [2,3] →
23// Shape [2,2], offset += 1 * stride[1] = 1, same strides
24//
25// 4. **Contiguous check**: A tensor is contiguous if its strides match the
26// default row-major strides for its shape. Non-contiguous tensors need
27// to be made contiguous (data copy) before certain operations (like
28// passing to BLAS or CUDA kernels that expect contiguous memory).
29
/// Layout describes how a tensor's logical shape maps to flat storage.
///
/// A layout is the triple `(shape, strides, offset)`: the flat storage index
/// of logical position `index` is `offset + Σ index[i] * strides[i]`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Layout {
    /// Logical dimensions of the tensor.
    shape: Shape,
    /// Elements to skip in flat storage to move one step along each dimension.
    strides: Vec<usize>,
    /// Offset into the storage buffer where this tensor's data starts.
    /// Used by slicing/narrow operations to create views into existing storage.
    offset: usize,
}
39
40impl Layout {
41 /// Create a new contiguous layout for the given shape.
42 /// Strides are computed as row-major (C-order).
43 pub fn contiguous(shape: Shape) -> Self {
44 let strides = shape.stride_contiguous();
45 Layout {
46 shape,
47 strides,
48 offset: 0,
49 }
50 }
51
52 /// Create a layout with explicit strides and offset (for views).
53 pub fn new(shape: Shape, strides: Vec<usize>, offset: usize) -> Self {
54 Layout {
55 shape,
56 strides,
57 offset,
58 }
59 }
60
61 pub fn shape(&self) -> &Shape {
62 &self.shape
63 }
64
65 pub fn strides(&self) -> &[usize] {
66 &self.strides
67 }
68
69 pub fn offset(&self) -> usize {
70 self.offset
71 }
72
73 pub fn rank(&self) -> usize {
74 self.shape.rank()
75 }
76
77 pub fn dims(&self) -> &[usize] {
78 self.shape.dims()
79 }
80
81 pub fn elem_count(&self) -> usize {
82 self.shape.elem_count()
83 }
84
85 /// Check if this layout is contiguous (row-major, no gaps).
86 /// A tensor is contiguous if its strides equal the default strides
87 /// for its shape AND offset is 0.
88 pub fn is_contiguous(&self) -> bool {
89 self.offset == 0 && self.strides == self.shape.stride_contiguous()
90 }
91
92 /// Transpose two dimensions. Returns a new layout with swapped shape/strides.
93 /// This is a "free" operation — no data is copied.
94 ///
95 /// Example: [2, 3, 4] transpose(0, 2) → [4, 3, 2]
96 /// strides [12, 4, 1] → [1, 4, 12]
97 pub fn transpose(&self, dim0: usize, dim1: usize) -> Result<Layout> {
98 let rank = self.rank();
99 if dim0 >= rank || dim1 >= rank {
100 return Err(Error::DimOutOfRange {
101 dim: dim0.max(dim1),
102 rank,
103 });
104 }
105 let mut new_dims = self.shape.dims().to_vec();
106 let mut new_strides = self.strides.clone();
107 new_dims.swap(dim0, dim1);
108 new_strides.swap(dim0, dim1);
109 Ok(Layout::new(Shape::new(new_dims), new_strides, self.offset))
110 }
111
112 /// Narrow (slice) along a dimension. Returns a new layout that is a view
113 /// into the same storage with adjusted shape and offset.
114 ///
115 /// Example: tensor of shape [4, 6], narrow(dim=1, start=2, len=3)
116 /// → shape [4, 3], offset += 2 * stride[1]
117 pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Layout> {
118 let rank = self.rank();
119 if dim >= rank {
120 return Err(Error::DimOutOfRange { dim, rank });
121 }
122 let dim_size = self.shape.dims()[dim];
123 if start + len > dim_size {
124 return Err(Error::NarrowOutOfBounds {
125 dim,
126 start,
127 len,
128 dim_size,
129 });
130 }
131 let mut new_dims = self.shape.dims().to_vec();
132 new_dims[dim] = len;
133 let new_offset = self.offset + start * self.strides[dim];
134 Ok(Layout::new(
135 Shape::new(new_dims),
136 self.strides.clone(),
137 new_offset,
138 ))
139 }
140
141 /// Compute the flat index into storage for a given multi-dimensional index.
142 /// This is the core formula: flat_index = offset + sum(index[i] * stride[i])
143 pub fn flat_index(&self, index: &[usize]) -> usize {
144 let mut flat = self.offset;
145 for (i, &idx) in index.iter().enumerate() {
146 flat += idx * self.strides[i];
147 }
148 flat
149 }
150
151 /// Iterator over all flat indices of this layout, in logical order.
152 /// This handles non-contiguous layouts correctly by walking through
153 /// multi-dimensional indices and converting via strides.
154 pub fn strided_indices(&self) -> StridedIter {
155 StridedIter::new(self)
156 }
157}
158
159// StridedIter — Iterates over flat storage indices respecting strides
160//
161// This iterator is essential for non-contiguous tensors. When a tensor has
162// been transposed or sliced, the data in memory is no longer sequential.
163// StridedIter walks through the logical elements in order and produces
164// the actual storage index for each one.
165//
166// For a contiguous tensor, this just counts 0, 1, 2, 3, ...
167// For a transposed tensor, it jumps around in memory following the strides.
168
/// Iterator that yields flat storage indices for each element of a Layout.
///
/// Works like an odometer over the multi-dimensional index space: the
/// rightmost dimension varies fastest, and each logical position is mapped
/// to storage via `offset + Σ current[i] * strides[i]`.
pub struct StridedIter {
    /// Current multi-dimensional index (e.g., [0, 0, 0]).
    current: Vec<usize>,
    /// The shape dimensions.
    dims: Vec<usize>,
    /// The strides for each dimension.
    strides: Vec<usize>,
    /// Base offset into storage.
    offset: usize,
    /// Total elements remaining.
    remaining: usize,
    /// Whether we've started yet. The first `next()` yields the initial
    /// all-zeros index as-is; subsequent calls advance before yielding.
    started: bool,
}
184
185impl StridedIter {
186 fn new(layout: &Layout) -> Self {
187 let rank = layout.rank();
188 StridedIter {
189 current: vec![0; rank],
190 dims: layout.dims().to_vec(),
191 strides: layout.strides().to_vec(),
192 offset: layout.offset(),
193 remaining: layout.elem_count(),
194 started: false,
195 }
196 }
197
198 /// Compute flat index from current multi-dim index.
199 fn flat_index(&self) -> usize {
200 let mut idx = self.offset;
201 for i in 0..self.current.len() {
202 idx += self.current[i] * self.strides[i];
203 }
204 idx
205 }
206
207 /// Advance the multi-dimensional index by one (rightmost dimension first).
208 fn advance(&mut self) {
209 let rank = self.dims.len();
210 for i in (0..rank).rev() {
211 self.current[i] += 1;
212 if self.current[i] < self.dims[i] {
213 return;
214 }
215 self.current[i] = 0;
216 }
217 }
218}
219
220impl Iterator for StridedIter {
221 type Item = usize;
222
223 fn next(&mut self) -> Option<usize> {
224 if self.remaining == 0 {
225 return None;
226 }
227 if self.started {
228 self.advance();
229 }
230 self.started = true;
231 self.remaining -= 1;
232 Some(self.flat_index())
233 }
234
235 fn size_hint(&self) -> (usize, Option<usize>) {
236 (self.remaining, Some(self.remaining))
237 }
238}
239
// Safe to claim an exact length: `size_hint` returns
// `(remaining, Some(remaining))`, so lower and upper bounds always agree.
impl ExactSizeIterator for StridedIter {}
241
242#[cfg(test)]
243mod tests {
244 use super::*;
245 use crate::shape::Shape;
246
247 #[test]
248 fn test_contiguous_layout() {
249 let layout = Layout::contiguous(Shape::from((2, 3)));
250 assert!(layout.is_contiguous());
251 assert_eq!(layout.strides(), &[3, 1]);
252 assert_eq!(layout.offset(), 0);
253 }
254
255 #[test]
256 fn test_contiguous_indices() {
257 // For [2, 3] contiguous, indices should be 0,1,2,3,4,5
258 let layout = Layout::contiguous(Shape::from((2, 3)));
259 let indices: Vec<usize> = layout.strided_indices().collect();
260 assert_eq!(indices, vec![0, 1, 2, 3, 4, 5]);
261 }
262
263 #[test]
264 fn test_transpose_layout() {
265 let layout = Layout::contiguous(Shape::from((2, 3)));
266 let transposed = layout.transpose(0, 1).unwrap();
267 // Shape becomes [3, 2], strides become [1, 3]
268 assert_eq!(transposed.dims(), &[3, 2]);
269 assert_eq!(transposed.strides(), &[1, 3]);
270 assert!(!transposed.is_contiguous());
271 }
272
273 #[test]
274 fn test_transpose_indices() {
275 // Original [2,3]:
276 // [[0, 1, 2],
277 // [3, 4, 5]]
278 //
279 // Transposed [3,2] should read column-major:
280 // [[0, 3],
281 // [1, 4],
282 // [2, 5]]
283 //
284 // So the flat indices in row-order of transposed are: 0, 3, 1, 4, 2, 5
285 let layout = Layout::contiguous(Shape::from((2, 3)));
286 let transposed = layout.transpose(0, 1).unwrap();
287 let indices: Vec<usize> = transposed.strided_indices().collect();
288 assert_eq!(indices, vec![0, 3, 1, 4, 2, 5]);
289 }
290
291 #[test]
292 fn test_narrow() {
293 // [4, 6] narrow(dim=1, start=2, len=3) → [4, 3] with offset=2
294 let layout = Layout::contiguous(Shape::from((4, 6)));
295 let narrowed = layout.narrow(1, 2, 3).unwrap();
296 assert_eq!(narrowed.dims(), &[4, 3]);
297 assert_eq!(narrowed.offset(), 2);
298 assert_eq!(narrowed.strides(), &[6, 1]); // strides unchanged
299 }
300
301 #[test]
302 fn test_narrow_out_of_bounds() {
303 let layout = Layout::contiguous(Shape::from((4, 6)));
304 assert!(layout.narrow(1, 5, 3).is_err()); // 5+3 = 8 > 6
305 }
306
307 #[test]
308 fn test_flat_index() {
309 let layout = Layout::contiguous(Shape::from((2, 3, 4)));
310 // Element at [1, 2, 3]: 1*12 + 2*4 + 3*1 = 23
311 assert_eq!(layout.flat_index(&[1, 2, 3]), 23);
312 // Element at [0, 0, 0]: 0
313 assert_eq!(layout.flat_index(&[0, 0, 0]), 0);
314 }
315}