winsfs_core/sfs/
generics.rs1use std::{fmt, slice};
9
/// A multi-dimensional shape: the size of each axis, viewable as `&[usize]`.
///
/// Implementors only have to supply [`Shape::strides`]. The remaining methods
/// have default implementations that read the axis sizes through the
/// `AsRef<[usize]>` supertrait.
pub trait Shape: AsRef<[usize]> + Clone + fmt::Debug + Eq + PartialEq {
    /// Returns the strides associated with this shape, in the same container
    /// type as the shape itself.
    ///
    /// The implementations in this module produce row-major strides, where the
    /// last axis has stride 1 (e.g. `[3, 7, 5]` yields `[35, 5, 1]`).
    fn strides(&self) -> Self;

    /// Returns an iterator over the size of each axis, in order.
    fn iter(&self) -> slice::Iter<'_, usize> {
        let axes: &[usize] = self.as_ref();
        axes.iter()
    }

    /// Returns the number of axes (the dimensionality of the shape).
    fn len(&self) -> usize {
        let axes: &[usize] = self.as_ref();
        axes.len()
    }

    /// Returns `true` when the shape has no axes at all.
    fn is_empty(&self) -> bool {
        let axes: &[usize] = self.as_ref();
        axes.is_empty()
    }
}
34
35pub type ConstShape<const D: usize> = [usize; D];
37
38impl<const D: usize> Shape for ConstShape<D> {
39 fn strides(&self) -> Self {
40 let mut strides = [1; D];
41 compute_strides(&self[..], &mut strides);
42 strides
43 }
44
45 fn len(&self) -> usize {
46 D
47 }
48}
49
50pub type DynShape = Box<[usize]>;
52
53impl Shape for DynShape {
54 fn strides(&self) -> Self {
55 let mut strides = vec![1; self.len()];
56 compute_strides(self, &mut strides);
57 strides.into_boxed_slice()
58 }
59}
60
/// Type-level flag distinguishing normalised from unnormalised values.
///
/// The flag is exposed as the associated constant [`Normalisation::NORM`] on
/// the zero-sized marker types [`Norm`] and [`Unnorm`]. It is presumably used
/// to parameterise SFS types elsewhere in the crate (confirm against callers).
pub trait Normalisation {
    /// `true` for the normalised marker, `false` for the unnormalised one.
    const NORM: bool;
}

/// Marker type for the normalised variant (`NORM == true`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Norm {}

/// Marker type for the unnormalised variant (`NORM == false`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Unnorm {}

impl Normalisation for Norm {
    const NORM: bool = true;
}

impl Normalisation for Unnorm {
    const NORM: bool = false;
}
85
/// Scales `strides` in place into row-major strides for `shape`.
///
/// Each `strides[j]` is multiplied by the product of every axis size after
/// axis `j`. The last entry is left untouched, and `shape[0]` never takes part.
/// Callers in this module pass an all-ones buffer, which yields plain
/// row-major strides.
///
/// Both slices are expected to have the same length; this is only checked in
/// debug builds.
fn compute_strides(shape: &[usize], strides: &mut [usize]) {
    debug_assert_eq!(shape.len(), strides.len());

    // Pair each stride with the size of the axis immediately following it.
    // Then walk from the innermost axis outwards, carrying a running product
    // of the trailing axis sizes.
    let mut trailing = 1;
    for (stride, &next_axis) in strides.iter_mut().zip(shape.iter().skip(1)).rev() {
        trailing *= next_axis;
        *stride *= trailing;
    }
}
94
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_array_strides() {
        assert_eq!([7].strides(), [1]);
        assert_eq!([9, 3].strides(), [3, 1]);
        assert_eq!([3, 7, 5].strides(), [35, 5, 1]);
        assert_eq!([9, 3, 5, 7].strides(), [105, 35, 7, 1]);
    }

    // The boxed-slice implementation must agree with the array one.
    #[test]
    fn test_boxed_strides() {
        let shape: DynShape = vec![7].into_boxed_slice();
        assert_eq!(shape.strides(), vec![1].into_boxed_slice());

        let shape: DynShape = vec![3, 7, 5].into_boxed_slice();
        assert_eq!(shape.strides(), vec![35, 5, 1].into_boxed_slice());

        let shape: DynShape = vec![9, 3, 5, 7].into_boxed_slice();
        assert_eq!(shape.strides(), vec![105, 35, 7, 1].into_boxed_slice());
    }

    // Shapes with no axes have no strides, and computing them must not panic.
    #[test]
    fn test_empty_strides() {
        let empty: ConstShape<0> = [];
        assert_eq!(empty.strides(), empty);

        let empty_dyn: DynShape = Vec::new().into_boxed_slice();
        assert_eq!(empty_dyn.strides(), empty_dyn);
    }

    // A zero-sized axis zeroes the strides of all axes before it.
    #[test]
    fn test_zero_sized_axis_strides() {
        let shape: ConstShape<3> = [4, 0, 3];
        assert_eq!(shape.strides(), [0, 3, 1]);

        let shape: DynShape = vec![4, 0, 3].into_boxed_slice();
        assert_eq!(shape.strides(), vec![0, 3, 1].into_boxed_slice());
    }

    #[test]
    fn test_len_and_is_empty() {
        let shape: ConstShape<3> = [3, 7, 5];
        assert_eq!(Shape::len(&shape), 3);
        assert!(!Shape::is_empty(&shape));

        let empty: ConstShape<0> = [];
        assert_eq!(Shape::len(&empty), 0);
        assert!(Shape::is_empty(&empty));

        let shape: DynShape = vec![3, 7, 5].into_boxed_slice();
        assert_eq!(Shape::len(&shape), 3);
        assert!(!Shape::is_empty(&shape));
        assert_eq!(Shape::iter(&shape).copied().collect::<Vec<_>>(), vec![3, 7, 5]);
    }
}