// shrew_core/layout.rs
use crate::error::{Error, Result};
use crate::shape::Shape;

// Layout — Memory layout of a tensor (shape + strides + offset)
//
// The Layout decouples the *logical* shape of a tensor from how its data is
// arranged in memory. This is what makes operations like transpose, reshape,
// and slicing "free" (no data copy needed — just change the layout).
//
// KEY CONCEPTS:
//
// 1. **Strides**: How many elements to skip in the flat storage to move one
//    step along each dimension. A contiguous [2,3] matrix has strides [3,1]:
//    - row 0 starts at offset 0, row 1 starts at offset 3
//    - within a row, consecutive elements are 1 apart
//
// 2. **Transpose**: Just swap the strides (and shape). No data movement!
//    [2,3] with strides [3,1] → transpose → [3,2] with strides [1,3]
//    The same data, but now read column-major.
//
// 3. **Narrow/Slice**: Just adjust the offset and shape. Still same storage.
//    narrow(dim=1, start=1, len=2) on [2,3] →
//    Shape [2,2], offset += 1 * stride[1] = 1, same strides
//
// 4. **Contiguous check**: A tensor is contiguous if its strides match the
//    default row-major strides for its shape. Non-contiguous tensors need
//    to be made contiguous (data copy) before certain operations (like
//    passing to BLAS or CUDA kernels that expect contiguous memory).
/// Layout describes how a tensor's logical shape maps to flat storage.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Layout {
    // Logical dimension sizes of the tensor.
    shape: Shape,
    // Elements to skip in flat storage per unit step along each dimension.
    // Same length as the shape's rank.
    strides: Vec<usize>,
    /// Offset into the storage buffer where this tensor's data starts.
    /// Used by slicing/narrow operations to create views into existing storage.
    offset: usize,
}

40impl Layout {
41    /// Create a new contiguous layout for the given shape.
42    /// Strides are computed as row-major (C-order).
43    pub fn contiguous(shape: Shape) -> Self {
44        let strides = shape.stride_contiguous();
45        Layout {
46            shape,
47            strides,
48            offset: 0,
49        }
50    }
51
52    /// Create a layout with explicit strides and offset (for views).
53    pub fn new(shape: Shape, strides: Vec<usize>, offset: usize) -> Self {
54        Layout {
55            shape,
56            strides,
57            offset,
58        }
59    }
60
61    pub fn shape(&self) -> &Shape {
62        &self.shape
63    }
64
65    pub fn strides(&self) -> &[usize] {
66        &self.strides
67    }
68
69    pub fn offset(&self) -> usize {
70        self.offset
71    }
72
73    pub fn rank(&self) -> usize {
74        self.shape.rank()
75    }
76
77    pub fn dims(&self) -> &[usize] {
78        self.shape.dims()
79    }
80
81    pub fn elem_count(&self) -> usize {
82        self.shape.elem_count()
83    }
84
85    /// Check if this layout is contiguous (row-major, no gaps).
86    /// A tensor is contiguous if its strides equal the default strides
87    /// for its shape AND offset is 0.
88    pub fn is_contiguous(&self) -> bool {
89        self.offset == 0 && self.strides == self.shape.stride_contiguous()
90    }
91
92    /// Transpose two dimensions. Returns a new layout with swapped shape/strides.
93    /// This is a "free" operation — no data is copied.
94    ///
95    /// Example: [2, 3, 4] transpose(0, 2) → [4, 3, 2]
96    ///          strides [12, 4, 1]         → [1, 4, 12]
97    pub fn transpose(&self, dim0: usize, dim1: usize) -> Result<Layout> {
98        let rank = self.rank();
99        if dim0 >= rank || dim1 >= rank {
100            return Err(Error::DimOutOfRange {
101                dim: dim0.max(dim1),
102                rank,
103            });
104        }
105        let mut new_dims = self.shape.dims().to_vec();
106        let mut new_strides = self.strides.clone();
107        new_dims.swap(dim0, dim1);
108        new_strides.swap(dim0, dim1);
109        Ok(Layout::new(Shape::new(new_dims), new_strides, self.offset))
110    }
111
112    /// Narrow (slice) along a dimension. Returns a new layout that is a view
113    /// into the same storage with adjusted shape and offset.
114    ///
115    /// Example: tensor of shape [4, 6], narrow(dim=1, start=2, len=3)
116    /// → shape [4, 3], offset += 2 * stride[1]
117    pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Layout> {
118        let rank = self.rank();
119        if dim >= rank {
120            return Err(Error::DimOutOfRange { dim, rank });
121        }
122        let dim_size = self.shape.dims()[dim];
123        if start + len > dim_size {
124            return Err(Error::NarrowOutOfBounds {
125                dim,
126                start,
127                len,
128                dim_size,
129            });
130        }
131        let mut new_dims = self.shape.dims().to_vec();
132        new_dims[dim] = len;
133        let new_offset = self.offset + start * self.strides[dim];
134        Ok(Layout::new(
135            Shape::new(new_dims),
136            self.strides.clone(),
137            new_offset,
138        ))
139    }
140
141    /// Compute the flat index into storage for a given multi-dimensional index.
142    /// This is the core formula: flat_index = offset + sum(index[i] * stride[i])
143    pub fn flat_index(&self, index: &[usize]) -> usize {
144        let mut flat = self.offset;
145        for (i, &idx) in index.iter().enumerate() {
146            flat += idx * self.strides[i];
147        }
148        flat
149    }
150
151    /// Iterator over all flat indices of this layout, in logical order.
152    /// This handles non-contiguous layouts correctly by walking through
153    /// multi-dimensional indices and converting via strides.
154    pub fn strided_indices(&self) -> StridedIter {
155        StridedIter::new(self)
156    }
157}
158
// StridedIter — Iterates over flat storage indices respecting strides
//
// This iterator is essential for non-contiguous tensors. When a tensor has
// been transposed or sliced, the data in memory is no longer sequential.
// StridedIter walks through the logical elements in order and produces
// the actual storage index for each one.
//
// For a contiguous tensor, this just counts 0, 1, 2, 3, ...
// For a transposed tensor, it jumps around in memory following the strides.

/// Iterator that yields flat storage indices for each element of a Layout.
pub struct StridedIter {
    /// Current multi-dimensional index (e.g., [0, 0, 0]).
    current: Vec<usize>,
    /// The shape dimensions.
    dims: Vec<usize>,
    /// The strides for each dimension.
    strides: Vec<usize>,
    /// Base offset into storage.
    offset: usize,
    /// Total elements remaining.
    remaining: usize,
    /// Whether we've started yet (the first `next()` yields the initial
    /// position without advancing).
    started: bool,
}

185impl StridedIter {
186    fn new(layout: &Layout) -> Self {
187        let rank = layout.rank();
188        StridedIter {
189            current: vec![0; rank],
190            dims: layout.dims().to_vec(),
191            strides: layout.strides().to_vec(),
192            offset: layout.offset(),
193            remaining: layout.elem_count(),
194            started: false,
195        }
196    }
197
198    /// Compute flat index from current multi-dim index.
199    fn flat_index(&self) -> usize {
200        let mut idx = self.offset;
201        for i in 0..self.current.len() {
202            idx += self.current[i] * self.strides[i];
203        }
204        idx
205    }
206
207    /// Advance the multi-dimensional index by one (rightmost dimension first).
208    fn advance(&mut self) {
209        let rank = self.dims.len();
210        for i in (0..rank).rev() {
211            self.current[i] += 1;
212            if self.current[i] < self.dims[i] {
213                return;
214            }
215            self.current[i] = 0;
216        }
217    }
218}
219
220impl Iterator for StridedIter {
221    type Item = usize;
222
223    fn next(&mut self) -> Option<usize> {
224        if self.remaining == 0 {
225            return None;
226        }
227        if self.started {
228            self.advance();
229        }
230        self.started = true;
231        self.remaining -= 1;
232        Some(self.flat_index())
233    }
234
235    fn size_hint(&self) -> (usize, Option<usize>) {
236        (self.remaining, Some(self.remaining))
237    }
238}
239
// `remaining` is tracked exactly, so size_hint's bounds are always equal
// and `len()` is correct.
impl ExactSizeIterator for StridedIter {}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::shape::Shape;

247    #[test]
248    fn test_contiguous_layout() {
249        let layout = Layout::contiguous(Shape::from((2, 3)));
250        assert!(layout.is_contiguous());
251        assert_eq!(layout.strides(), &[3, 1]);
252        assert_eq!(layout.offset(), 0);
253    }
254
255    #[test]
256    fn test_contiguous_indices() {
257        // For [2, 3] contiguous, indices should be 0,1,2,3,4,5
258        let layout = Layout::contiguous(Shape::from((2, 3)));
259        let indices: Vec<usize> = layout.strided_indices().collect();
260        assert_eq!(indices, vec![0, 1, 2, 3, 4, 5]);
261    }
262
263    #[test]
264    fn test_transpose_layout() {
265        let layout = Layout::contiguous(Shape::from((2, 3)));
266        let transposed = layout.transpose(0, 1).unwrap();
267        // Shape becomes [3, 2], strides become [1, 3]
268        assert_eq!(transposed.dims(), &[3, 2]);
269        assert_eq!(transposed.strides(), &[1, 3]);
270        assert!(!transposed.is_contiguous());
271    }
272
273    #[test]
274    fn test_transpose_indices() {
275        // Original [2,3]:
276        //   [[0, 1, 2],
277        //    [3, 4, 5]]
278        //
279        // Transposed [3,2] should read column-major:
280        //   [[0, 3],
281        //    [1, 4],
282        //    [2, 5]]
283        //
284        // So the flat indices in row-order of transposed are: 0, 3, 1, 4, 2, 5
285        let layout = Layout::contiguous(Shape::from((2, 3)));
286        let transposed = layout.transpose(0, 1).unwrap();
287        let indices: Vec<usize> = transposed.strided_indices().collect();
288        assert_eq!(indices, vec![0, 3, 1, 4, 2, 5]);
289    }
290
291    #[test]
292    fn test_narrow() {
293        // [4, 6] narrow(dim=1, start=2, len=3) → [4, 3] with offset=2
294        let layout = Layout::contiguous(Shape::from((4, 6)));
295        let narrowed = layout.narrow(1, 2, 3).unwrap();
296        assert_eq!(narrowed.dims(), &[4, 3]);
297        assert_eq!(narrowed.offset(), 2);
298        assert_eq!(narrowed.strides(), &[6, 1]); // strides unchanged
299    }
300
301    #[test]
302    fn test_narrow_out_of_bounds() {
303        let layout = Layout::contiguous(Shape::from((4, 6)));
304        assert!(layout.narrow(1, 5, 3).is_err()); // 5+3 = 8 > 6
305    }
306
307    #[test]
308    fn test_flat_index() {
309        let layout = Layout::contiguous(Shape::from((2, 3, 4)));
310        // Element at [1, 2, 3]: 1*12 + 2*4 + 3*1 = 23
311        assert_eq!(layout.flat_index(&[1, 2, 3]), 23);
312        // Element at [0, 0, 0]: 0
313        assert_eq!(layout.flat_index(&[0, 0, 0]), 0);
314    }
315}