//! Strided tensor layouts and slice parameters (`tenflowers_core/strided.rs`):
//! zero-copy views over a flat buffer via shape / strides / offset.
use crate::{Result, TensorError};
use std::ops::Range;
3
/// Slice parameters with stride support.
///
/// Mirrors Python's `start:end:step` slicing: each field is optional and
/// negative values count from the end of the dimension (see
/// `SliceParams::normalize` for the exact clamping rules).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SliceParams {
    // Inclusive start index; `None` means the default for the step direction.
    pub start: Option<isize>,
    // Exclusive end index; `None` means the default for the step direction.
    pub end: Option<isize>,
    // Step between elements; `None` is treated as 1, zero is rejected.
    pub step: Option<isize>,
}
11
12impl SliceParams {
13    /// Create a new slice parameters with default values
14    pub fn new() -> Self {
15        Self {
16            start: None,
17            end: None,
18            step: Some(1),
19        }
20    }
21
22    /// Create slice parameters with start, end, and step
23    pub fn with_step(start: Option<isize>, end: Option<isize>, step: Option<isize>) -> Self {
24        Self { start, end, step }
25    }
26
27    /// Convert to normalized start, end, step for a given dimension size
28    pub fn normalize(&self, size: usize) -> Result<(usize, usize, isize)> {
29        let size = size as isize;
30        let step = self.step.unwrap_or(1);
31
32        if step == 0 {
33            return Err(TensorError::invalid_argument(
34                "Slice step cannot be zero".to_string(),
35            ));
36        }
37
38        let (start, end) = if step > 0 {
39            let start = match self.start {
40                Some(s) if s < 0 => (size + s).max(0) as usize,
41                Some(s) => (s as usize).min(size as usize),
42                None => 0,
43            };
44            let end = match self.end {
45                Some(e) if e < 0 => (size + e).max(0) as usize,
46                Some(e) => (e as usize).min(size as usize),
47                None => size as usize,
48            };
49            (start, end)
50        } else {
51            let start = match self.start {
52                Some(s) if s < 0 => (size + s).max(-1) as usize,
53                Some(s) => (s as usize).min(size as usize - 1),
54                None => size as usize - 1,
55            };
56            let end = match self.end {
57                Some(e) if e < 0 => (size + e).max(-1) as usize,
58                Some(e) => (e as usize).min(size as usize - 1),
59                None => 0,
60            };
61            (start, end)
62        };
63
64        Ok((start, end, step))
65    }
66}
67
68impl Default for SliceParams {
69    fn default() -> Self {
70        Self::new()
71    }
72}
73
74impl From<Range<usize>> for SliceParams {
75    fn from(range: Range<usize>) -> Self {
76        Self {
77            start: Some(range.start as isize),
78            end: Some(range.end as isize),
79            step: Some(1),
80        }
81    }
82}
83
/// Strided tensor layout for efficient views and slicing.
///
/// Describes how a logical multi-dimensional index maps onto a flat buffer:
/// `linear = offset + sum(index[i] * strides[i])`. Strides are measured in
/// elements (not bytes) and may be negative (reversed views) or zero
/// (broadcast views).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StridedLayout {
    // Extent of each dimension.
    shape: Vec<usize>,
    // Element stride per dimension; signed to permit reversed views.
    strides: Vec<isize>,
    // Linear index of the element at the all-zeros multi-index.
    offset: usize,
}
91
92impl StridedLayout {
93    /// Create a new strided layout with default C-contiguous strides
94    pub fn new(shape: Vec<usize>) -> Self {
95        let strides = Self::compute_strides(&shape);
96        Self {
97            shape,
98            strides,
99            offset: 0,
100        }
101    }
102
103    /// Create layout with custom strides
104    pub fn with_strides(shape: Vec<usize>, strides: Vec<isize>, offset: usize) -> Result<Self> {
105        if shape.len() != strides.len() {
106            return Err(TensorError::invalid_argument(format!(
107                "Shape and strides must have same length: {} != {}",
108                strides.len(),
109                shape.len()
110            )));
111        }
112
113        Ok(Self {
114            shape,
115            strides,
116            offset,
117        })
118    }
119
120    /// Compute C-contiguous strides for a shape
121    fn compute_strides(shape: &[usize]) -> Vec<isize> {
122        let mut strides = vec![1isize; shape.len()];
123        for i in (0..shape.len() - 1).rev() {
124            strides[i] = strides[i + 1] * shape[i + 1] as isize;
125        }
126        strides
127    }
128
129    /// Get the shape
130    pub fn shape(&self) -> &[usize] {
131        &self.shape
132    }
133
134    /// Get the strides
135    pub fn strides(&self) -> &[isize] {
136        &self.strides
137    }
138
139    /// Get the offset
140    pub fn offset(&self) -> usize {
141        self.offset
142    }
143
144    /// Get total number of elements
145    pub fn numel(&self) -> usize {
146        self.shape.iter().product()
147    }
148
149    /// Check if the layout is contiguous (C-order)
150    pub fn is_contiguous(&self) -> bool {
151        if self.offset != 0 {
152            return false;
153        }
154
155        let expected_strides = Self::compute_strides(&self.shape);
156        self.strides == expected_strides
157    }
158
159    /// Check if the layout is Fortran contiguous (F-order)
160    pub fn is_fortran_contiguous(&self) -> bool {
161        if self.offset != 0 {
162            return false;
163        }
164
165        let mut expected_strides = vec![1isize; self.shape.len()];
166        for i in 1..self.shape.len() {
167            expected_strides[i] = expected_strides[i - 1] * self.shape[i - 1] as isize;
168        }
169
170        self.strides == expected_strides
171    }
172
173    /// Compute the linear index for a multi-dimensional index
174    pub fn linear_index(&self, indices: &[usize]) -> Result<usize> {
175        if indices.len() != self.shape.len() {
176            return Err(TensorError::invalid_argument(format!(
177                "Index dimension mismatch: {} != {}",
178                indices.len(),
179                self.shape.len()
180            )));
181        }
182
183        let mut linear_idx = self.offset as isize;
184        for (i, &idx) in indices.iter().enumerate() {
185            if idx >= self.shape[i] {
186                return Err(TensorError::invalid_argument(format!(
187                    "Index out of bounds: {} >= {}",
188                    idx, self.shape[i]
189                )));
190            }
191            linear_idx += idx as isize * self.strides[i];
192        }
193
194        Ok(linear_idx as usize)
195    }
196
197    /// Create a view by slicing along dimensions
198    pub fn slice(&self, ranges: &[Range<usize>]) -> Result<Self> {
199        if ranges.len() != self.shape.len() {
200            return Err(TensorError::invalid_argument(format!(
201                "Slice dimension mismatch: {} != {}",
202                ranges.len(),
203                self.shape.len()
204            )));
205        }
206
207        let mut new_shape = Vec::with_capacity(self.shape.len());
208        let mut new_offset = self.offset as isize;
209
210        for (i, range) in ranges.iter().enumerate() {
211            if range.start > range.end || range.end > self.shape[i] {
212                return Err(TensorError::invalid_argument(format!(
213                    "Invalid slice range {:?} for dimension size {}",
214                    range, self.shape[i]
215                )));
216            }
217
218            new_shape.push(range.end - range.start);
219            new_offset += range.start as isize * self.strides[i];
220        }
221
222        if new_offset < 0 {
223            return Err(TensorError::invalid_argument(
224                "Slice operation resulted in negative offset".to_string(),
225            ));
226        }
227
228        Ok(Self {
229            shape: new_shape,
230            strides: self.strides.clone(),
231            offset: new_offset as usize,
232        })
233    }
234
235    /// Create a view by slicing along dimensions with stride support
236    pub fn slice_with_stride(&self, slice_params: &[SliceParams]) -> Result<Self> {
237        if slice_params.len() != self.shape.len() {
238            return Err(TensorError::invalid_argument(format!(
239                "Slice dimension mismatch: {} != {}",
240                slice_params.len(),
241                self.shape.len()
242            )));
243        }
244
245        let mut new_shape = Vec::with_capacity(self.shape.len());
246        let mut new_strides = Vec::with_capacity(self.strides.len());
247        let mut new_offset = self.offset as isize;
248
249        for (i, slice_param) in slice_params.iter().enumerate() {
250            let (start, end, step) = slice_param.normalize(self.shape[i])?;
251
252            // Calculate the new dimension size
253            let new_dim_size = if step > 0 {
254                if start >= end {
255                    0
256                } else {
257                    ((end - start) as isize + step - 1) / step
258                }
259            } else if start <= end {
260                0
261            } else {
262                ((start as isize - end as isize) + (-step) - 1) / (-step)
263            };
264
265            new_shape.push(new_dim_size.max(0) as usize);
266            new_strides.push(self.strides[i] * step);
267            new_offset += start as isize * self.strides[i];
268        }
269
270        if new_offset < 0 {
271            return Err(TensorError::invalid_argument(
272                "Slice operation resulted in negative offset".to_string(),
273            ));
274        }
275
276        Ok(Self {
277            shape: new_shape,
278            strides: new_strides,
279            offset: new_offset as usize,
280        })
281    }
282
283    /// Transpose the layout
284    pub fn transpose(&self, axes: Option<&[usize]>) -> Result<Self> {
285        let axes = if let Some(axes) = axes {
286            if axes.len() != self.shape.len() {
287                return Err(TensorError::invalid_argument(String::new()));
288            }
289            axes.to_vec()
290        } else {
291            // Default: reverse all axes
292            (0..self.shape.len()).rev().collect()
293        };
294
295        // Check for valid permutation
296        let mut seen = vec![false; self.shape.len()];
297        for &ax in &axes {
298            if ax >= self.shape.len() {
299                return Err(TensorError::invalid_argument(String::new()));
300            }
301            if seen[ax] {
302                return Err(TensorError::invalid_argument(String::new()));
303            }
304            seen[ax] = true;
305        }
306
307        let new_shape: Vec<_> = axes.iter().map(|&i| self.shape[i]).collect();
308        let new_strides: Vec<_> = axes.iter().map(|&i| self.strides[i]).collect();
309
310        Ok(Self {
311            shape: new_shape,
312            strides: new_strides,
313            offset: self.offset,
314        })
315    }
316
317    /// Reshape the layout (only works if contiguous)
318    pub fn reshape(&self, new_shape: Vec<usize>) -> Result<Self> {
319        if !self.is_contiguous() {
320            return Err(TensorError::invalid_argument(String::new()));
321        }
322
323        let old_numel: usize = self.shape.iter().product();
324        let new_numel: usize = new_shape.iter().product();
325
326        if old_numel != new_numel {
327            return Err(TensorError::invalid_argument(String::new()));
328        }
329
330        Ok(Self::new(new_shape))
331    }
332
333    /// Broadcast to a new shape
334    pub fn broadcast_to(&self, target_shape: &[usize]) -> Result<Self> {
335        // Check if we can broadcast
336        if target_shape.len() < self.shape.len() {
337            return Err(TensorError::invalid_argument(String::new()));
338        }
339
340        // Prepare new shape and strides
341        let mut new_shape = vec![1; target_shape.len()];
342        let mut new_strides = vec![0; target_shape.len()];
343        let offset = target_shape.len() - self.shape.len();
344
345        // Process existing dimensions
346        for i in 0..self.shape.len() {
347            let target_dim = target_shape[i + offset];
348            let self_dim = self.shape[i];
349
350            // Validate broadcast compatibility
351            if self_dim != 1 && self_dim != target_dim {
352                return Err(TensorError::invalid_argument(format!(
353                    "Cannot broadcast dimension {self_dim} to {target_dim} at axis {i}"
354                )));
355            }
356
357            new_shape[i + offset] = target_dim;
358            new_strides[i + offset] = if self_dim == 1 { 0 } else { self.strides[i] };
359        }
360
361        // Set leading dimensions
362        for i in 0..offset {
363            new_shape[i] = target_shape[i];
364            new_strides[i] = 0;
365        }
366
367        Ok(Self {
368            shape: new_shape,
369            strides: new_strides,
370            offset: self.offset,
371        })
372    }
373
374    /// Create an iterator over all valid indices
375    pub fn indices_iter(&self) -> StridedIndicesIter {
376        StridedIndicesIter::new(&self.shape)
377    }
378}
379
/// Iterator over multi-dimensional indices in C (row-major) order.
pub struct StridedIndicesIter {
    // Extent of each dimension being enumerated.
    shape: Vec<usize>,
    // The next index tuple to yield.
    current: Vec<usize>,
    // Set once all indices have been produced (or the shape is empty-sized).
    done: bool,
}
386
387impl StridedIndicesIter {
388    fn new(shape: &[usize]) -> Self {
389        Self {
390            shape: shape.to_vec(),
391            current: vec![0; shape.len()],
392            done: shape.contains(&0),
393        }
394    }
395}
396
397impl Iterator for StridedIndicesIter {
398    type Item = Vec<usize>;
399
400    fn next(&mut self) -> Option<Self::Item> {
401        if self.done {
402            return None;
403        }
404
405        let result = self.current.clone();
406
407        // Increment indices
408        for i in (0..self.shape.len()).rev() {
409            self.current[i] += 1;
410            if self.current[i] < self.shape[i] {
411                break;
412            }
413            if i == 0 {
414                self.done = true;
415            } else {
416                self.current[i] = 0;
417            }
418        }
419
420        Some(result)
421    }
422}
423
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_strided_layout_basic() {
        // C-contiguous [2, 3, 4] => strides [3*4, 4, 1].
        let layout = StridedLayout::new(vec![2, 3, 4]);
        assert_eq!(layout.shape(), &[2, 3, 4]);
        assert_eq!(layout.strides(), &[12, 4, 1]);
        assert_eq!(layout.offset(), 0);
        assert!(layout.is_contiguous());
    }

    #[test]
    fn test_linear_index() {
        // linear = 12*i + 4*j + 1*k for shape [2, 3, 4].
        let layout = StridedLayout::new(vec![2, 3, 4]);
        assert_eq!(
            layout
                .linear_index(&[0, 0, 0])
                .expect("test: linear_index should succeed"),
            0
        );
        assert_eq!(
            layout
                .linear_index(&[1, 2, 3])
                .expect("test: linear_index should succeed"),
            23
        );
        assert_eq!(
            layout
                .linear_index(&[1, 0, 0])
                .expect("test: linear_index should succeed"),
            12
        );
    }

    #[test]
    fn test_slice() {
        // Plain range slicing keeps strides; only shape and offset change.
        let layout = StridedLayout::new(vec![4, 5, 6]);
        let sliced = layout
            .slice(&[1..3, 0..5, 2..4])
            .expect("test: slice should succeed");
        assert_eq!(sliced.shape(), &[2, 5, 2]);
        assert_eq!(sliced.strides(), &[30, 6, 1]);
        assert_eq!(sliced.offset(), 32); // 1*30 + 0*6 + 2*1
    }

    #[test]
    fn test_transpose() {
        // Permutation [2, 0, 1] reorders both shape and strides.
        let layout = StridedLayout::new(vec![2, 3, 4]);
        let transposed = layout
            .transpose(Some(&[2, 0, 1]))
            .expect("test: operation should succeed");
        assert_eq!(transposed.shape(), &[4, 2, 3]);
        assert_eq!(transposed.strides(), &[1, 12, 4]);
    }

    #[test]
    fn test_broadcast() {
        // Size-1 dimensions broadcast with stride 0 (element is revisited).
        let layout = StridedLayout::new(vec![1, 3, 1]);
        let broadcasted = layout
            .broadcast_to(&[2, 3, 4])
            .expect("test: broadcast_to should succeed");
        assert_eq!(broadcasted.shape(), &[2, 3, 4]);
        assert_eq!(broadcasted.strides(), &[0, 1, 0]);
    }

    #[test]
    fn test_slice_params_normalize() {
        let params = SliceParams::with_step(Some(1), Some(4), Some(2));
        let (start, end, step) = params.normalize(6).expect("test: normalize should succeed");
        assert_eq!(start, 1);
        assert_eq!(end, 4);
        assert_eq!(step, 2);

        // Test negative indices
        let params = SliceParams::with_step(Some(-2), Some(-1), Some(1));
        let (start, end, step) = params.normalize(6).expect("test: normalize should succeed");
        assert_eq!(start, 4);
        assert_eq!(end, 5);
        assert_eq!(step, 1);
    }

    #[test]
    fn test_slice_with_stride() {
        let layout = StridedLayout::new(vec![6, 4]);

        // Test basic stride slicing - every 2nd element
        let slice_params = vec![
            SliceParams::with_step(Some(0), Some(6), Some(2)),
            SliceParams::with_step(Some(0), Some(4), Some(1)),
        ];
        let sliced = layout
            .slice_with_stride(&slice_params)
            .expect("test: slice_with_stride should succeed");
        assert_eq!(sliced.shape(), &[3, 4]);
        assert_eq!(sliced.strides(), &[8, 1]); // stride doubled for dimension 0
        assert_eq!(sliced.offset(), 0);

        // Test negative step
        let slice_params = vec![
            SliceParams::with_step(Some(5), Some(0), Some(-2)),
            SliceParams::with_step(Some(0), Some(4), Some(1)),
        ];
        let sliced = layout
            .slice_with_stride(&slice_params)
            .expect("test: slice_with_stride should succeed");
        assert_eq!(sliced.shape(), &[3, 4]);
        assert_eq!(sliced.strides(), &[-8, 1]); // negative stride for dimension 0
        assert_eq!(sliced.offset(), 20); // 5*4 + 0*1
    }

    #[test]
    fn test_slice_with_stride_default_params() {
        let layout = StridedLayout::new(vec![4, 4]);

        // Test with default parameters (equivalent to full slice)
        let slice_params = vec![SliceParams::default(), SliceParams::default()];
        let sliced = layout
            .slice_with_stride(&slice_params)
            .expect("test: slice_with_stride should succeed");
        assert_eq!(sliced.shape(), &[4, 4]);
        assert_eq!(sliced.strides(), &[4, 1]);
        assert_eq!(sliced.offset(), 0);
    }

    #[test]
    fn test_slice_with_stride_from_range() {
        let layout = StridedLayout::new(vec![6, 4]);

        // Test converting from Range to SliceParams
        let slice_params = vec![SliceParams::from(1..5), SliceParams::from(0..4)];
        let sliced = layout
            .slice_with_stride(&slice_params)
            .expect("test: slice_with_stride should succeed");
        assert_eq!(sliced.shape(), &[4, 4]);
        assert_eq!(sliced.strides(), &[4, 1]);
        assert_eq!(sliced.offset(), 4); // 1*4 + 0*1
    }
}