1use crate::{Result, TensorError};
2use std::ops::Range;
3
/// Python-style slice description for a single dimension.
///
/// Each component is optional. Negative `start`/`end` values are interpreted
/// relative to the end of the dimension when the slice is normalized.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SliceParams {
    /// First index of the slice; `None` selects the default for the step direction.
    pub start: Option<isize>,
    /// Exclusive end index of the slice; `None` selects the default for the step direction.
    pub end: Option<isize>,
    /// Distance between selected indices; `None` is treated as `1`.
    pub step: Option<isize>,
}
11
12impl SliceParams {
13 pub fn new() -> Self {
15 Self {
16 start: None,
17 end: None,
18 step: Some(1),
19 }
20 }
21
22 pub fn with_step(start: Option<isize>, end: Option<isize>, step: Option<isize>) -> Self {
24 Self { start, end, step }
25 }
26
27 pub fn normalize(&self, size: usize) -> Result<(usize, usize, isize)> {
29 let size = size as isize;
30 let step = self.step.unwrap_or(1);
31
32 if step == 0 {
33 return Err(TensorError::invalid_argument(
34 "Slice step cannot be zero".to_string(),
35 ));
36 }
37
38 let (start, end) = if step > 0 {
39 let start = match self.start {
40 Some(s) if s < 0 => (size + s).max(0) as usize,
41 Some(s) => (s as usize).min(size as usize),
42 None => 0,
43 };
44 let end = match self.end {
45 Some(e) if e < 0 => (size + e).max(0) as usize,
46 Some(e) => (e as usize).min(size as usize),
47 None => size as usize,
48 };
49 (start, end)
50 } else {
51 let start = match self.start {
52 Some(s) if s < 0 => (size + s).max(-1) as usize,
53 Some(s) => (s as usize).min(size as usize - 1),
54 None => size as usize - 1,
55 };
56 let end = match self.end {
57 Some(e) if e < 0 => (size + e).max(-1) as usize,
58 Some(e) => (e as usize).min(size as usize - 1),
59 None => 0,
60 };
61 (start, end)
62 };
63
64 Ok((start, end, step))
65 }
66}
67
68impl Default for SliceParams {
69 fn default() -> Self {
70 Self::new()
71 }
72}
73
74impl From<Range<usize>> for SliceParams {
75 fn from(range: Range<usize>) -> Self {
76 Self {
77 start: Some(range.start as isize),
78 end: Some(range.end as isize),
79 step: Some(1),
80 }
81 }
82}
83
/// Describes how a multi-dimensional index maps onto a flat element position.
///
/// The position of an index is `offset + sum(index[i] * strides[i])`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct StridedLayout {
    /// Extent of each dimension.
    shape: Vec<usize>,
    /// Signed element step taken when the index along each dimension grows by one.
    strides: Vec<isize>,
    /// Position of the element at the all-zero index.
    offset: usize,
}
91
92impl StridedLayout {
93 pub fn new(shape: Vec<usize>) -> Self {
95 let strides = Self::compute_strides(&shape);
96 Self {
97 shape,
98 strides,
99 offset: 0,
100 }
101 }
102
103 pub fn with_strides(shape: Vec<usize>, strides: Vec<isize>, offset: usize) -> Result<Self> {
105 if shape.len() != strides.len() {
106 return Err(TensorError::invalid_argument(format!(
107 "Shape and strides must have same length: {} != {}",
108 strides.len(),
109 shape.len()
110 )));
111 }
112
113 Ok(Self {
114 shape,
115 strides,
116 offset,
117 })
118 }
119
120 fn compute_strides(shape: &[usize]) -> Vec<isize> {
122 let mut strides = vec![1isize; shape.len()];
123 for i in (0..shape.len() - 1).rev() {
124 strides[i] = strides[i + 1] * shape[i + 1] as isize;
125 }
126 strides
127 }
128
129 pub fn shape(&self) -> &[usize] {
131 &self.shape
132 }
133
134 pub fn strides(&self) -> &[isize] {
136 &self.strides
137 }
138
139 pub fn offset(&self) -> usize {
141 self.offset
142 }
143
144 pub fn numel(&self) -> usize {
146 self.shape.iter().product()
147 }
148
149 pub fn is_contiguous(&self) -> bool {
151 if self.offset != 0 {
152 return false;
153 }
154
155 let expected_strides = Self::compute_strides(&self.shape);
156 self.strides == expected_strides
157 }
158
159 pub fn is_fortran_contiguous(&self) -> bool {
161 if self.offset != 0 {
162 return false;
163 }
164
165 let mut expected_strides = vec![1isize; self.shape.len()];
166 for i in 1..self.shape.len() {
167 expected_strides[i] = expected_strides[i - 1] * self.shape[i - 1] as isize;
168 }
169
170 self.strides == expected_strides
171 }
172
173 pub fn linear_index(&self, indices: &[usize]) -> Result<usize> {
175 if indices.len() != self.shape.len() {
176 return Err(TensorError::invalid_argument(format!(
177 "Index dimension mismatch: {} != {}",
178 indices.len(),
179 self.shape.len()
180 )));
181 }
182
183 let mut linear_idx = self.offset as isize;
184 for (i, &idx) in indices.iter().enumerate() {
185 if idx >= self.shape[i] {
186 return Err(TensorError::invalid_argument(format!(
187 "Index out of bounds: {} >= {}",
188 idx, self.shape[i]
189 )));
190 }
191 linear_idx += idx as isize * self.strides[i];
192 }
193
194 Ok(linear_idx as usize)
195 }
196
197 pub fn slice(&self, ranges: &[Range<usize>]) -> Result<Self> {
199 if ranges.len() != self.shape.len() {
200 return Err(TensorError::invalid_argument(format!(
201 "Slice dimension mismatch: {} != {}",
202 ranges.len(),
203 self.shape.len()
204 )));
205 }
206
207 let mut new_shape = Vec::with_capacity(self.shape.len());
208 let mut new_offset = self.offset as isize;
209
210 for (i, range) in ranges.iter().enumerate() {
211 if range.start > range.end || range.end > self.shape[i] {
212 return Err(TensorError::invalid_argument(format!(
213 "Invalid slice range {:?} for dimension size {}",
214 range, self.shape[i]
215 )));
216 }
217
218 new_shape.push(range.end - range.start);
219 new_offset += range.start as isize * self.strides[i];
220 }
221
222 if new_offset < 0 {
223 return Err(TensorError::invalid_argument(
224 "Slice operation resulted in negative offset".to_string(),
225 ));
226 }
227
228 Ok(Self {
229 shape: new_shape,
230 strides: self.strides.clone(),
231 offset: new_offset as usize,
232 })
233 }
234
235 pub fn slice_with_stride(&self, slice_params: &[SliceParams]) -> Result<Self> {
237 if slice_params.len() != self.shape.len() {
238 return Err(TensorError::invalid_argument(format!(
239 "Slice dimension mismatch: {} != {}",
240 slice_params.len(),
241 self.shape.len()
242 )));
243 }
244
245 let mut new_shape = Vec::with_capacity(self.shape.len());
246 let mut new_strides = Vec::with_capacity(self.strides.len());
247 let mut new_offset = self.offset as isize;
248
249 for (i, slice_param) in slice_params.iter().enumerate() {
250 let (start, end, step) = slice_param.normalize(self.shape[i])?;
251
252 let new_dim_size = if step > 0 {
254 if start >= end {
255 0
256 } else {
257 ((end - start) as isize + step - 1) / step
258 }
259 } else if start <= end {
260 0
261 } else {
262 ((start as isize - end as isize) + (-step) - 1) / (-step)
263 };
264
265 new_shape.push(new_dim_size.max(0) as usize);
266 new_strides.push(self.strides[i] * step);
267 new_offset += start as isize * self.strides[i];
268 }
269
270 if new_offset < 0 {
271 return Err(TensorError::invalid_argument(
272 "Slice operation resulted in negative offset".to_string(),
273 ));
274 }
275
276 Ok(Self {
277 shape: new_shape,
278 strides: new_strides,
279 offset: new_offset as usize,
280 })
281 }
282
283 pub fn transpose(&self, axes: Option<&[usize]>) -> Result<Self> {
285 let axes = if let Some(axes) = axes {
286 if axes.len() != self.shape.len() {
287 return Err(TensorError::invalid_argument(String::new()));
288 }
289 axes.to_vec()
290 } else {
291 (0..self.shape.len()).rev().collect()
293 };
294
295 let mut seen = vec![false; self.shape.len()];
297 for &ax in &axes {
298 if ax >= self.shape.len() {
299 return Err(TensorError::invalid_argument(String::new()));
300 }
301 if seen[ax] {
302 return Err(TensorError::invalid_argument(String::new()));
303 }
304 seen[ax] = true;
305 }
306
307 let new_shape: Vec<_> = axes.iter().map(|&i| self.shape[i]).collect();
308 let new_strides: Vec<_> = axes.iter().map(|&i| self.strides[i]).collect();
309
310 Ok(Self {
311 shape: new_shape,
312 strides: new_strides,
313 offset: self.offset,
314 })
315 }
316
317 pub fn reshape(&self, new_shape: Vec<usize>) -> Result<Self> {
319 if !self.is_contiguous() {
320 return Err(TensorError::invalid_argument(String::new()));
321 }
322
323 let old_numel: usize = self.shape.iter().product();
324 let new_numel: usize = new_shape.iter().product();
325
326 if old_numel != new_numel {
327 return Err(TensorError::invalid_argument(String::new()));
328 }
329
330 Ok(Self::new(new_shape))
331 }
332
333 pub fn broadcast_to(&self, target_shape: &[usize]) -> Result<Self> {
335 if target_shape.len() < self.shape.len() {
337 return Err(TensorError::invalid_argument(String::new()));
338 }
339
340 let mut new_shape = vec![1; target_shape.len()];
342 let mut new_strides = vec![0; target_shape.len()];
343 let offset = target_shape.len() - self.shape.len();
344
345 for i in 0..self.shape.len() {
347 let target_dim = target_shape[i + offset];
348 let self_dim = self.shape[i];
349
350 if self_dim != 1 && self_dim != target_dim {
352 return Err(TensorError::invalid_argument(format!(
353 "Cannot broadcast dimension {self_dim} to {target_dim} at axis {i}"
354 )));
355 }
356
357 new_shape[i + offset] = target_dim;
358 new_strides[i + offset] = if self_dim == 1 { 0 } else { self.strides[i] };
359 }
360
361 for i in 0..offset {
363 new_shape[i] = target_shape[i];
364 new_strides[i] = 0;
365 }
366
367 Ok(Self {
368 shape: new_shape,
369 strides: new_strides,
370 offset: self.offset,
371 })
372 }
373
374 pub fn indices_iter(&self) -> StridedIndicesIter {
376 StridedIndicesIter::new(&self.shape)
377 }
378}
379
/// Iterator state for walking every multi-dimensional index of a shape.
#[derive(Debug, Clone)]
pub struct StridedIndicesIter {
    /// Extent of each dimension being iterated.
    shape: Vec<usize>,
    /// Index that the next call to `next` will return.
    current: Vec<usize>,
    /// Set once every index has been produced.
    done: bool,
}
386
387impl StridedIndicesIter {
388 fn new(shape: &[usize]) -> Self {
389 Self {
390 shape: shape.to_vec(),
391 current: vec![0; shape.len()],
392 done: shape.contains(&0),
393 }
394 }
395}
396
397impl Iterator for StridedIndicesIter {
398 type Item = Vec<usize>;
399
400 fn next(&mut self) -> Option<Self::Item> {
401 if self.done {
402 return None;
403 }
404
405 let result = self.current.clone();
406
407 for i in (0..self.shape.len()).rev() {
409 self.current[i] += 1;
410 if self.current[i] < self.shape[i] {
411 break;
412 }
413 if i == 0 {
414 self.done = true;
415 } else {
416 self.current[i] = 0;
417 }
418 }
419
420 Some(result)
421 }
422}
423
#[cfg(test)]
mod tests {
    use super::*;

    /// A new layout is row-major, starts at offset zero and is contiguous.
    #[test]
    fn test_strided_layout_basic() {
        let layout = StridedLayout::new(vec![2, 3, 4]);

        assert!(layout.is_contiguous());
        assert_eq!(layout.offset(), 0);
        assert_eq!(layout.strides(), &[12, 4, 1]);
        assert_eq!(layout.shape(), &[2, 3, 4]);
    }

    /// Flat positions follow `sum(index * stride)` for a row-major layout.
    #[test]
    fn test_linear_index() {
        let layout = StridedLayout::new(vec![2, 3, 4]);
        let cases: [([usize; 3], usize); 3] =
            [([0, 0, 0], 0), ([1, 2, 3], 23), ([1, 0, 0], 12)];

        for (indices, expected) in cases.iter() {
            let actual = layout
                .linear_index(indices)
                .expect("test: linear_index should succeed");
            assert_eq!(actual, *expected);
        }
    }

    /// Range slicing shrinks the shape, keeps strides and moves the offset.
    #[test]
    fn test_slice() {
        let ranges = [1..3, 0..5, 2..4];
        let sliced = StridedLayout::new(vec![4, 5, 6])
            .slice(&ranges)
            .expect("test: slice should succeed");

        assert_eq!(sliced.shape(), &[2, 5, 2]);
        assert_eq!(sliced.strides(), &[30, 6, 1]);
        // 1 * 30 + 0 * 6 + 2 * 1
        assert_eq!(sliced.offset(), 32);
    }

    /// Transposing permutes shape and strides together.
    #[test]
    fn test_transpose() {
        let axes = [2, 0, 1];
        let transposed = StridedLayout::new(vec![2, 3, 4])
            .transpose(Some(&axes))
            .expect("test: operation should succeed");

        assert_eq!(transposed.shape(), &[4, 2, 3]);
        assert_eq!(transposed.strides(), &[1, 12, 4]);
    }

    /// Broadcast dimensions of extent one receive a zero stride.
    #[test]
    fn test_broadcast() {
        let target = [2, 3, 4];
        let broadcasted = StridedLayout::new(vec![1, 3, 1])
            .broadcast_to(&target)
            .expect("test: broadcast_to should succeed");

        assert_eq!(broadcasted.shape(), &[2, 3, 4]);
        assert_eq!(broadcasted.strides(), &[0, 1, 0]);
    }

    /// Normalization keeps in-range bounds and resolves negative ones.
    #[test]
    fn test_slice_params_normalize() {
        let in_range = SliceParams::with_step(Some(1), Some(4), Some(2))
            .normalize(6)
            .expect("test: normalize should succeed");
        assert_eq!(in_range, (1, 4, 2));

        // Negative bounds count back from the end of the dimension.
        let from_end = SliceParams::with_step(Some(-2), Some(-1), Some(1))
            .normalize(6)
            .expect("test: normalize should succeed");
        assert_eq!(from_end, (4, 5, 1));
    }

    /// Stepped slices scale the stride; backwards slices start at `start`.
    #[test]
    fn test_slice_with_stride() {
        let layout = StridedLayout::new(vec![6, 4]);
        let all_columns = SliceParams::with_step(Some(0), Some(4), Some(1));

        // Every second row, all columns.
        let forward = [
            SliceParams::with_step(Some(0), Some(6), Some(2)),
            all_columns.clone(),
        ];
        let sliced = layout
            .slice_with_stride(&forward)
            .expect("test: slice_with_stride should succeed");
        assert_eq!(sliced.shape(), &[3, 4]);
        assert_eq!(sliced.strides(), &[8, 1]);
        assert_eq!(sliced.offset(), 0);

        // Rows 5, 3, 1 in that order.
        let backward = [
            SliceParams::with_step(Some(5), Some(0), Some(-2)),
            all_columns,
        ];
        let sliced = layout
            .slice_with_stride(&backward)
            .expect("test: slice_with_stride should succeed");
        assert_eq!(sliced.shape(), &[3, 4]);
        assert_eq!(sliced.strides(), &[-8, 1]);
        // 5 * 4
        assert_eq!(sliced.offset(), 20);
    }

    /// Default parameters select the whole layout unchanged.
    #[test]
    fn test_slice_with_stride_default_params() {
        let whole = [SliceParams::default(), SliceParams::default()];
        let sliced = StridedLayout::new(vec![4, 4])
            .slice_with_stride(&whole)
            .expect("test: slice_with_stride should succeed");

        assert_eq!(sliced.shape(), &[4, 4]);
        assert_eq!(sliced.strides(), &[4, 1]);
        assert_eq!(sliced.offset(), 0);
    }

    /// Parameters built from ranges behave like unit-step slices.
    #[test]
    fn test_slice_with_stride_from_range() {
        let from_ranges = [SliceParams::from(1..5), SliceParams::from(0..4)];
        let sliced = StridedLayout::new(vec![6, 4])
            .slice_with_stride(&from_ranges)
            .expect("test: slice_with_stride should succeed");

        assert_eq!(sliced.shape(), &[4, 4]);
        assert_eq!(sliced.strides(), &[4, 1]);
        // 1 * 4
        assert_eq!(sliced.offset(), 4);
    }
}