1use crate::{bail, Error, Result};
3
/// The list of dimension sizes of a tensor.
#[derive(Hash, Ord, PartialOrd, Eq, PartialEq, Clone)]
pub struct Shape(Vec<usize>);

/// The shape with no dims at all (rank 0).
pub const SCALAR: Shape = Shape(Vec::new());
8
9impl std::fmt::Debug for Shape {
10 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
11 write!(f, "{:?}", &self.dims())
12 }
13}
14
15impl<const C: usize> From<&[usize; C]> for Shape {
16 fn from(dims: &[usize; C]) -> Self {
17 Self(dims.to_vec())
18 }
19}
20
21impl From<&[usize]> for Shape {
22 fn from(dims: &[usize]) -> Self {
23 Self(dims.to_vec())
24 }
25}
26
27impl From<&Shape> for Shape {
28 fn from(shape: &Shape) -> Self {
29 Self(shape.0.to_vec())
30 }
31}
32
33impl From<()> for Shape {
34 fn from(_: ()) -> Self {
35 Self(vec![])
36 }
37}
38
39impl From<usize> for Shape {
40 fn from(d1: usize) -> Self {
41 Self(vec![d1])
42 }
43}
44
45impl From<(usize,)> for Shape {
46 fn from(d1: (usize,)) -> Self {
47 Self(vec![d1.0])
48 }
49}
50
51impl From<(usize, usize)> for Shape {
52 fn from(d12: (usize, usize)) -> Self {
53 Self(vec![d12.0, d12.1])
54 }
55}
56
57impl From<(usize, usize, usize)> for Shape {
58 fn from(d123: (usize, usize, usize)) -> Self {
59 Self(vec![d123.0, d123.1, d123.2])
60 }
61}
62
63impl From<(usize, usize, usize, usize)> for Shape {
64 fn from(d1234: (usize, usize, usize, usize)) -> Self {
65 Self(vec![d1234.0, d1234.1, d1234.2, d1234.3])
66 }
67}
68
69impl From<(usize, usize, usize, usize, usize)> for Shape {
70 fn from(d12345: (usize, usize, usize, usize, usize)) -> Self {
71 Self(vec![d12345.0, d12345.1, d12345.2, d12345.3, d12345.4])
72 }
73}
74
75impl From<(usize, usize, usize, usize, usize, usize)> for Shape {
76 fn from(d123456: (usize, usize, usize, usize, usize, usize)) -> Self {
77 Self(vec![d123456.0, d123456.1, d123456.2, d123456.3, d123456.4, d123456.5])
78 }
79}
80
81impl From<Vec<usize>> for Shape {
82 fn from(dims: Vec<usize>) -> Self {
83 Self(dims)
84 }
85}
86
/// Generates rank-specific dim extractors. For a fixed dim count `$cnt`, it emits:
/// - a free function `$fn_name(&[usize]) -> Result<$out_type>` that applies `$dims` to the slice,
/// - an inherent method `Shape::$fn_name` delegating to that function,
/// - a `TryInto<$out_type>` impl for `Shape` built on the method.
///
/// Each of them returns an error when the slice does not have exactly `$cnt` dims.
macro_rules! extract_dims {
    ($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
        pub fn $fn_name(dims: &[usize]) -> Result<$out_type> {
            if dims.len() == $cnt {
                Ok($dims(dims))
            } else {
                bail!(
                    "unexpected number of dims, expected {} got {} shape {:?}",
                    $cnt,
                    dims.len(),
                    dims
                )
            }
        }

        impl Shape {
            pub fn $fn_name(&self) -> Result<$out_type> {
                $fn_name(self.dims())
            }
        }

        impl std::convert::TryInto<$out_type> for Shape {
            type Error = crate::Error;
            fn try_into(self) -> std::result::Result<$out_type, Self::Error> {
                Shape::$fn_name(&self)
            }
        }
    };
}
115
116impl Shape {
117 pub fn num_elements(&self) -> usize {
118 self.dims().iter().product()
119 }
120
121 pub fn from_dims(dims: &[usize]) -> Self {
122 Self(dims.to_vec())
123 }
124
125 pub fn rank(&self) -> usize {
127 self.0.len()
128 }
129
130 pub fn into_dims(self) -> Vec<usize> {
131 self.0
132 }
133
134 pub fn dims(&self) -> &[usize] {
136 &self.0
137 }
138
139 pub fn elem_count(&self) -> usize {
141 self.0.iter().product()
142 }
143
144 pub fn stride_contiguous(&self) -> Vec<usize> {
147 let mut stride: Vec<_> = self
148 .0
149 .iter()
150 .rev()
151 .scan(1, |prod, u| {
152 let prod_pre_mult = *prod;
153 *prod *= u;
154 Some(prod_pre_mult)
155 })
156 .collect();
157 stride.reverse();
158 stride
159 }
160
161 pub fn is_contiguous(&self, stride: &[usize]) -> bool {
163 if self.0.len() != stride.len() {
164 return false;
165 }
166 let mut acc = 1;
167 for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() {
168 if dim > 1 && stride != acc {
169 return false;
170 }
171 acc *= dim;
172 }
173 true
174 }
175
176 pub fn is_fortran_contiguous(&self, stride: &[usize]) -> bool {
178 if self.0.len() != stride.len() {
179 return false;
180 }
181 let mut acc = 1;
182 for (&stride, &dim) in stride.iter().zip(self.0.iter()) {
183 if dim > 1 && stride != acc {
184 return false;
185 }
186 acc *= dim;
187 }
188 true
189 }
190
191 pub fn extend(mut self, additional_dims: &[usize]) -> Self {
194 self.0.extend(additional_dims);
195 self
196 }
197
198 pub fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
201 let lhs = self;
202 let lhs_dims = lhs.dims();
203 let rhs_dims = rhs.dims();
204 let lhs_ndims = lhs_dims.len();
205 let rhs_ndims = rhs_dims.len();
206 let bcast_ndims = usize::max(lhs_ndims, rhs_ndims);
207 let mut bcast_dims = vec![0; bcast_ndims];
208 for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() {
209 let rev_idx = bcast_ndims - idx;
210 let l_value = if lhs_ndims < rev_idx { 1 } else { lhs_dims[lhs_ndims - rev_idx] };
211 let r_value = if rhs_ndims < rev_idx { 1 } else { rhs_dims[rhs_ndims - rev_idx] };
212 *bcast_value = if l_value == r_value {
213 l_value
214 } else if l_value == 1 {
215 r_value
216 } else if r_value == 1 {
217 l_value
218 } else {
219 bail!("shape mismatch in binary op '{op}', lhs: {lhs:?} rhs: {rhs:?}")
220 }
221 }
222 Ok(Shape::from(bcast_dims))
223 }
224}
225
226pub trait Dim {
227 fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize>;
228 fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize>;
229}
230
231impl Dim for usize {
232 fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> {
233 let dim = *self;
234 if dim >= shape.dims().len() {
235 bail!("dim out of range in '{op}', dim: {dim}, shape: {shape:?}")
236 }
237 Ok(dim)
238 }
239
240 fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
241 let dim = *self;
242 if dim > shape.dims().len() {
243 bail!("dim out of range in '{op}', dim: {dim}, shape: {shape:?}")
244 }
245 Ok(dim)
246 }
247}
248
249impl Dim for i32 {
250 fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> {
251 let dim = *self;
252 if dim >= 0 {
253 (dim as usize).to_index(shape, op)
254 } else {
255 D::Minus((-dim) as usize).to_index(shape, op)
256 }
257 }
258
259 fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
260 let dim = *self;
261 if dim >= 0 {
262 (dim as usize).to_index_plus_one(shape, op)
263 } else {
264 D::Minus((-dim) as usize).to_index_plus_one(shape, op)
265 }
266 }
267}
268
269#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
270pub enum D {
271 Minus1,
272 Minus2,
273 Minus(usize),
274}
275
276impl D {
277 fn out_of_range(&self, shape: &Shape, op: &'static str) -> Error {
278 let dim = match self {
279 Self::Minus1 => -1,
280 Self::Minus2 => -2,
281 Self::Minus(u) => -(*u as i32),
282 };
283 Error::Msg(format!("dim out of range in '{op}', dim: {dim}, shape: {shape:?}")).bt()
284 }
285}
286
287impl Dim for D {
288 fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> {
289 let rank = shape.rank();
290 match self {
291 Self::Minus1 if rank >= 1 => Ok(rank - 1),
292 Self::Minus2 if rank >= 2 => Ok(rank - 2),
293 Self::Minus(u) if *u > 0 && rank >= *u => Ok(rank - *u),
294 _ => Err(self.out_of_range(shape, op)),
295 }
296 }
297
298 fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
299 let rank = shape.rank();
300 match self {
301 Self::Minus1 => Ok(rank),
302 Self::Minus2 if rank >= 1 => Ok(rank - 1),
303 Self::Minus(u) if *u > 0 && rank + 1 >= *u => Ok(rank + 1 - *u),
304 _ => Err(self.out_of_range(shape, op)),
305 }
306 }
307}
308
309pub trait Dims: Sized {
310 fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>>;
311
312 fn to_indexes(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
313 let dims = self.to_indexes_internal(shape, op)?;
314 for (i, &dim) in dims.iter().enumerate() {
315 if dims[..i].contains(&dim) {
316 bail!("duplicate dim indexes in '{op}', dims: {dims:?}, shape: {shape:?}")
317 }
318 if dim >= shape.rank() {
319 bail!("dim out of range in '{op}', dim: {dim}, shape: {shape:?}")
320 }
321 }
322 Ok(dims)
323 }
324}
325
326impl Dims for Vec<usize> {
327 fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
328 Ok(self)
329 }
330}
331
332impl<const N: usize> Dims for [usize; N] {
333 fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
334 Ok(self.to_vec())
335 }
336}
337
338impl Dims for &[usize] {
339 fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
340 Ok(self.to_vec())
341 }
342}
343
344impl Dims for () {
345 fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
346 Ok(vec![])
347 }
348}
349
350impl<D: Dim + Sized> Dims for D {
351 fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
352 let dim = self.to_index(shape, op)?;
353 Ok(vec![dim])
354 }
355}
356
357impl<D: Dim> Dims for (D,) {
358 fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
359 let dim = self.0.to_index(shape, op)?;
360 Ok(vec![dim])
361 }
362}
363
364impl<D1: Dim, D2: Dim> Dims for (D1, D2) {
365 fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
366 let d0 = self.0.to_index(shape, op)?;
367 let d1 = self.1.to_index(shape, op)?;
368 Ok(vec![d0, d1])
369 }
370}
371
372impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
373 fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
374 let d0 = self.0.to_index(shape, op)?;
375 let d1 = self.1.to_index(shape, op)?;
376 let d2 = self.2.to_index(shape, op)?;
377 Ok(vec![d0, d1, d2])
378 }
379}
380
381impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim> Dims for (D1, D2, D3, D4) {
382 fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
383 let d0 = self.0.to_index(shape, op)?;
384 let d1 = self.1.to_index(shape, op)?;
385 let d2 = self.2.to_index(shape, op)?;
386 let d3 = self.3.to_index(shape, op)?;
387 Ok(vec![d0, d1, d2, d3])
388 }
389}
390
391impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim> Dims for (D1, D2, D3, D4, D5) {
392 fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
393 let d0 = self.0.to_index(shape, op)?;
394 let d1 = self.1.to_index(shape, op)?;
395 let d2 = self.2.to_index(shape, op)?;
396 let d3 = self.3.to_index(shape, op)?;
397 let d4 = self.4.to_index(shape, op)?;
398 Ok(vec![d0, d1, d2, d3, d4])
399 }
400}
401
402impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim, D6: Dim> Dims for (D1, D2, D3, D4, D5, D6) {
403 fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
404 let d0 = self.0.to_index(shape, op)?;
405 let d1 = self.1.to_index(shape, op)?;
406 let d2 = self.2.to_index(shape, op)?;
407 let d3 = self.3.to_index(shape, op)?;
408 let d4 = self.4.to_index(shape, op)?;
409 let d5 = self.5.to_index(shape, op)?;
410 Ok(vec![d0, d1, d2, d3, d4, d5])
411 }
412}
413
414extract_dims!(dims0, 0, |_: &[usize]| (), ());
415extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
416extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
417extract_dims!(dims3, 3, |d: &[usize]| (d[0], d[1], d[2]), (usize, usize, usize));
418extract_dims!(dims4, 4, |d: &[usize]| (d[0], d[1], d[2], d[3]), (usize, usize, usize, usize));
419extract_dims!(
420 dims5,
421 5,
422 |d: &[usize]| (d[0], d[1], d[2], d[3], d[4]),
423 (usize, usize, usize, usize, usize)
424);
425
426pub trait ShapeWithOneHole {
427 fn into_shape(self, el_count: usize) -> Result<Shape>;
428}
429
430impl<S: Into<Shape>> ShapeWithOneHole for S {
431 fn into_shape(self, _el_count: usize) -> Result<Shape> {
432 Ok(self.into())
433 }
434}
435
436impl ShapeWithOneHole for ((),) {
437 fn into_shape(self, el_count: usize) -> Result<Shape> {
438 Ok(el_count.into())
439 }
440}
441
442fn hole_size(el_count: usize, prod_d: usize, s: &dyn std::fmt::Debug) -> Result<usize> {
443 if prod_d == 0 {
444 bail!("cannot reshape tensor of {el_count} elements to {s:?}")
445 }
446 if el_count % prod_d != 0 {
447 bail!("cannot reshape tensor with {el_count} elements to {s:?}")
448 }
449 Ok(el_count / prod_d)
450}
451
452impl ShapeWithOneHole for ((), usize) {
453 fn into_shape(self, el_count: usize) -> Result<Shape> {
454 let ((), d1) = self;
455 Ok((hole_size(el_count, d1, &self)?, d1).into())
456 }
457}
458
459impl ShapeWithOneHole for (usize, ()) {
460 fn into_shape(self, el_count: usize) -> Result<Shape> {
461 let (d1, ()) = self;
462 Ok((d1, hole_size(el_count, d1, &self)?).into())
463 }
464}
465
466impl ShapeWithOneHole for ((), usize, usize) {
467 fn into_shape(self, el_count: usize) -> Result<Shape> {
468 let ((), d1, d2) = self;
469 Ok((hole_size(el_count, d1 * d2, &self)?, d1, d2).into())
470 }
471}
472
473impl ShapeWithOneHole for (usize, (), usize) {
474 fn into_shape(self, el_count: usize) -> Result<Shape> {
475 let (d1, (), d2) = self;
476 Ok((d1, hole_size(el_count, d1 * d2, &self)?, d2).into())
477 }
478}
479
480impl ShapeWithOneHole for (usize, usize, ()) {
481 fn into_shape(self, el_count: usize) -> Result<Shape> {
482 let (d1, d2, ()) = self;
483 Ok((d1, d2, hole_size(el_count, d1 * d2, &self)?).into())
484 }
485}
486
487impl ShapeWithOneHole for ((), usize, usize, usize) {
488 fn into_shape(self, el_count: usize) -> Result<Shape> {
489 let ((), d1, d2, d3) = self;
490 let d = hole_size(el_count, d1 * d2 * d3, &self)?;
491 Ok((d, d1, d2, d3).into())
492 }
493}
494
495impl ShapeWithOneHole for (usize, (), usize, usize) {
496 fn into_shape(self, el_count: usize) -> Result<Shape> {
497 let (d1, (), d2, d3) = self;
498 let d = hole_size(el_count, d1 * d2 * d3, &self)?;
499 Ok((d1, d, d2, d3).into())
500 }
501}
502
503impl ShapeWithOneHole for (usize, usize, (), usize) {
504 fn into_shape(self, el_count: usize) -> Result<Shape> {
505 let (d1, d2, (), d3) = self;
506 let d = hole_size(el_count, d1 * d2 * d3, &self)?;
507 Ok((d1, d2, d, d3).into())
508 }
509}
510
511impl ShapeWithOneHole for (usize, usize, usize, ()) {
512 fn into_shape(self, el_count: usize) -> Result<Shape> {
513 let (d1, d2, d3, ()) = self;
514 let d = hole_size(el_count, d1 * d2 * d3, &self)?;
515 Ok((d1, d2, d3, d).into())
516 }
517}
518
519impl ShapeWithOneHole for ((), usize, usize, usize, usize) {
520 fn into_shape(self, el_count: usize) -> Result<Shape> {
521 let ((), d1, d2, d3, d4) = self;
522 let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
523 Ok((d, d1, d2, d3, d4).into())
524 }
525}
526
527impl ShapeWithOneHole for (usize, (), usize, usize, usize) {
528 fn into_shape(self, el_count: usize) -> Result<Shape> {
529 let (d1, (), d2, d3, d4) = self;
530 let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
531 Ok((d1, d, d2, d3, d4).into())
532 }
533}
534
535impl ShapeWithOneHole for (usize, usize, (), usize, usize) {
536 fn into_shape(self, el_count: usize) -> Result<Shape> {
537 let (d1, d2, (), d3, d4) = self;
538 let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
539 Ok((d1, d2, d, d3, d4).into())
540 }
541}
542
543impl ShapeWithOneHole for (usize, usize, usize, (), usize) {
544 fn into_shape(self, el_count: usize) -> Result<Shape> {
545 let (d1, d2, d3, (), d4) = self;
546 let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
547 Ok((d1, d2, d3, d, d4).into())
548 }
549}
550
551impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
552 fn into_shape(self, el_count: usize) -> Result<Shape> {
553 let (d1, d2, d3, d4, ()) = self;
554 let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
555 Ok((d1, d2, d3, d4, d).into())
556 }
557}
558
#[cfg(test)]
mod tests {
    use super::*;

    /// `stride_contiguous` yields row-major strides for ranks 0 through 3.
    #[test]
    fn stride() {
        let cases: Vec<(Shape, Vec<usize>)> = vec![
            (Shape::from(()), vec![]),
            (Shape::from(42), vec![1]),
            (Shape::from((42, 1337)), vec![1337, 1]),
            (Shape::from((299, 792, 458)), vec![458 * 792, 458, 1]),
        ];
        for (shape, expected) in cases {
            assert_eq!(shape.stride_contiguous(), expected, "shape {shape:?}");
        }
    }
}
575
576#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
577pub struct Layout {
578 shape: Shape,
579 strides: Vec<usize>,
580 offset: usize,
582}
583
584impl Layout {
585 pub fn can_be_compressed(&self) -> bool {
586 let strides = self.strides();
587 let dims = self.dims();
588 if dims.len() <= 1 {
589 return false;
590 }
591 for i in 0..dims.len() - 1 {
592 if strides[i] != strides[i + 1] * dims[i + 1] {
593 return false;
594 }
595 }
596 true
597 }
598
599 pub fn compress_all(&self) -> Result<Self> {
600 let strides = self.strides();
601 let dims = self.dims();
602 for i in 0..dims.len() - 1 {
603 if strides[i] != strides[i + 1] * dims[i + 1] {
604 bail!("cannot collapse dims, {self:?}")
605 }
606 }
607 let stride = strides.last().copied().unwrap_or(1);
608 let dim = self.num_elements();
609 Ok(Self { shape: Shape::from(dim), strides: vec![stride], offset: self.offset() })
610 }
611
612 pub fn from_shape<S: Into<Shape>>(shape: S) -> Self {
613 let shape = shape.into();
614 let mut strides = vec![];
615 let mut stride = 1;
616 for l in shape.dims().iter().rev() {
617 strides.push(stride);
618 stride *= l
619 }
620 strides.reverse();
621 Self { shape, strides, offset: 0 }
622 }
623
624 pub fn transpose(&self) -> Self {
625 let r = self.rank();
626 if r < 2 {
627 return self.clone();
628 }
629 let mut dims = self.dims().to_vec();
630 let mut strides = self.strides.to_vec();
631 dims.swap(r - 2, r - 1);
632 strides.swap(r - 2, r - 1);
633 Self { shape: dims.into(), offset: self.offset, strides }
634 }
635
636 pub fn num_elements(&self) -> usize {
637 self.shape.num_elements()
638 }
639
640 pub fn shape(&self) -> &Shape {
641 &self.shape
642 }
643
644 pub fn dims(&self) -> &[usize] {
645 self.shape.dims()
646 }
647
648 pub fn rank(&self) -> usize {
649 self.shape.rank()
650 }
651
652 pub fn strides(&self) -> &[usize] {
653 self.strides.as_slice()
654 }
655
656 pub fn offset(&self) -> usize {
657 self.offset
658 }
659
660 pub fn set_offset(&mut self, offset: usize) {
661 self.offset = offset
662 }
663
664 pub fn c_contiguous(&self) -> bool {
665 let mut prod_l = 1;
666 for (&s, &l) in self.strides.iter().zip(self.shape.dims().iter()).rev() {
667 if s != prod_l {
668 return false;
669 }
670 prod_l *= l
671 }
672 true
673 }
674}