web_rwkv/tensor/
shape.rs

1use std::{cmp::Ordering, hash::Hash};
2
3use itertools::Itertools;
4use serde::{Deserialize, Serialize};
5
6use super::{TensorError, TensorErrorKind};
7
/// Conversion of a value into its raw byte representation.
pub trait IntoBytes {
    /// Consumes `self` and returns its byte-level representation.
    fn into_bytes(self) -> Vec<u8>;
}
11
/// Indices along each dimension.
///
/// Axis 0 is the fastest-moving axis, matching the convention of [`Shape`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ShapedIndex([usize; 4]);
15
16impl ShapedIndex {
17    pub fn new(x: usize, y: usize, z: usize, w: usize) -> Self {
18        Self([x, y, z, w])
19    }
20
21    pub fn iter(&self) -> impl Iterator<Item = usize> {
22        self.0.into_iter()
23    }
24}
25
26impl From<[usize; 4]> for ShapedIndex {
27    fn from(value: [usize; 4]) -> Self {
28        Self(value)
29    }
30}
31
32impl From<(usize, usize, usize, usize)> for ShapedIndex {
33    fn from((x, y, z, w): (usize, usize, usize, usize)) -> Self {
34        Self([x, y, z, w])
35    }
36}
37
38impl std::fmt::Display for ShapedIndex {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        write!(f, "({}, {}, {}, {})", self[0], self[1], self[2], self[3])
41    }
42}
43
44impl std::ops::Index<usize> for ShapedIndex {
45    type Output = usize;
46
47    fn index(&self, index: usize) -> &Self::Output {
48        &self.0[index]
49    }
50}
51
52impl std::ops::IndexMut<usize> for ShapedIndex {
53    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
54        &mut self.0[index]
55    }
56}
57
58impl std::ops::Add<ShapedIndex> for ShapedIndex {
59    type Output = Self;
60
61    fn add(self, rhs: ShapedIndex) -> Self::Output {
62        Self::new(
63            self[0] + rhs[0],
64            self[1] + rhs[1],
65            self[2] + rhs[2],
66            self[3] + rhs[3],
67        )
68    }
69}
70
71impl std::ops::Sub<ShapedIndex> for ShapedIndex {
72    type Output = Self;
73
74    fn sub(self, rhs: ShapedIndex) -> Self::Output {
75        Self::new(
76            self[0] - rhs[0],
77            self[1] - rhs[1],
78            self[2] - rhs[2],
79            self[3] - rhs[3],
80        )
81    }
82}
83
84impl std::ops::AddAssign<ShapedIndex> for ShapedIndex {
85    fn add_assign(&mut self, rhs: ShapedIndex) {
86        *self = *self + rhs;
87    }
88}
89
90impl std::ops::SubAssign<ShapedIndex> for ShapedIndex {
91    fn sub_assign(&mut self, rhs: ShapedIndex) {
92        *self = *self - rhs;
93    }
94}
95
/// The shape of a [`Tensor`](super::Tensor).
/// Note that the fastest-moving axis occupies the lowest shape index, which is opposite to that in `torch`.
///
/// A shape with any zero dimension contains no elements (see [`Shape::is_empty`]).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Shape([usize; 4]);
100
101impl Shape {
102    pub fn new(x: usize, y: usize, z: usize, w: usize) -> Self {
103        Self([x, y, z, w])
104    }
105
106    pub fn from_slice(slice: &[usize]) -> Self {
107        let mut shape = Self::new(1, 1, 1, 1);
108        for (index, &dim) in slice.iter().take(4).enumerate() {
109            shape[index] = dim;
110        }
111        shape
112    }
113
114    pub fn from_slice_rev(shape: &[usize]) -> Result<Self, TensorError> {
115        let shape = match shape[..] {
116            [] => Shape::new(0, 0, 0, 0),
117            [x] => Shape::new(x, 1, 1, 1),
118            [y, x] => Shape::new(x, y, 1, 1),
119            [z, y, x] => Shape::new(x, y, z, 1),
120            [w, z, y, x] => Shape::new(x, y, z, w),
121            _ => Err(TensorErrorKind::Deduce)?,
122        };
123        Ok(shape)
124    }
125
126    pub fn len(&self) -> usize {
127        self.0.into_iter().product()
128    }
129
130    pub fn is_empty(&self) -> bool {
131        self.0.into_iter().any(|x| x == 0)
132    }
133
134    pub fn iter(&self) -> impl Iterator<Item = usize> {
135        self.0.into_iter()
136    }
137
138    /// Convert a shaped index into a linear index.
139    pub fn linear_index(&self, index: impl Into<ShapedIndex>) -> usize {
140        let index: ShapedIndex = index.into();
141        Iterator::zip(self.0.into_iter().rev(), index.0.into_iter().rev())
142            .fold(0, |acc, (shape, index)| acc * shape + index)
143    }
144
145    /// Iterate through all indices within the shape's bound.
146    pub fn cartesian_product(&self) -> impl Iterator<Item = ShapedIndex> {
147        (0..self[3])
148            .cartesian_product(0..self[2])
149            .cartesian_product(0..self[1])
150            .cartesian_product(0..self[0])
151            .map(|(((w, z), y), x)| ShapedIndex::new(x, y, z, w))
152    }
153}
154
155impl From<ShapedIndex> for Shape {
156    fn from(value: ShapedIndex) -> Self {
157        Self(value.0)
158    }
159}
160
161impl From<[usize; 4]> for Shape {
162    fn from(value: [usize; 4]) -> Self {
163        Self(value)
164    }
165}
166
167impl From<Shape> for [usize; 4] {
168    fn from(value: Shape) -> Self {
169        value.0
170    }
171}
172
173impl IntoBytes for Shape {
174    fn into_bytes(self) -> Vec<u8> {
175        let data = self.0.map(|x| x as u32);
176        bytemuck::pod_collect_to_vec(&data)
177    }
178}
179
180impl std::cmp::PartialOrd for Shape {
181    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
182        use Ordering::Equal;
183        match (
184            self[0].cmp(&other[0]),
185            self[1].cmp(&other[1]),
186            self[2].cmp(&other[2]),
187            self[3].cmp(&other[3]),
188        ) {
189            (x, y, z, w) if x == y && y == z && z == w => Some(x),
190            (x, y, z, Equal) if x == y && y == z => Some(x),
191            (x, y, Equal, w) if x == y && y == w => Some(y),
192            (x, Equal, z, w) if x == z && z == w => Some(z),
193            (Equal, y, z, w) if y == z && z == w => Some(w),
194            (x, y, Equal, Equal) if x == y => Some(x),
195            (x, Equal, z, Equal) if x == z => Some(x),
196            (x, Equal, Equal, w) if x == w => Some(x),
197            (Equal, y, z, Equal) if y == z => Some(y),
198            (Equal, y, Equal, w) if y == w => Some(y),
199            (Equal, Equal, z, w) if z == w => Some(z),
200            (x, Equal, Equal, Equal) => Some(x),
201            (Equal, y, Equal, Equal) => Some(y),
202            (Equal, Equal, z, Equal) => Some(z),
203            (Equal, Equal, Equal, w) => Some(w),
204            _ => None,
205        }
206    }
207}
208
209impl std::fmt::Display for Shape {
210    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
211        write!(f, "({}, {}, {}, {})", self[0], self[1], self[2], self[3])
212    }
213}
214
215impl std::ops::Index<usize> for Shape {
216    type Output = usize;
217
218    fn index(&self, index: usize) -> &Self::Output {
219        &self.0[index]
220    }
221}
222
223impl std::ops::IndexMut<usize> for Shape {
224    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
225        &mut self.0[index]
226    }
227}
228
/// A selection over all four axes of a tensor of a given [`Shape`].
pub trait TensorSlice {
    /// Resolves the selection against `shape` into per-axis `(start, end)` bounds.
    fn shaped_bounds(&self, shape: Shape) -> Result<(ShapedIndex, ShapedIndex), TensorError>;
    /// Resolves the selection into a linear, contiguous `[start, end)` element range.
    fn linear_bounds(&self, shape: Shape) -> Result<(usize, usize), TensorError>;
}
233
/// A selector for a single tensor axis (an index or any standard range type).
pub trait TensorAxis: Clone + PartialEq + Eq + Hash {
    /// Resolves the selector against an axis of size `dim` into a half-open
    /// `(start, end)` range, or errors if it falls outside the axis.
    fn bounds(&self, dim: usize) -> Result<(usize, usize), TensorError>;
}
237
238#[inline]
239fn checked_bounds(dim: usize, start: usize, end: usize) -> Result<(usize, usize), TensorError> {
240    if start > end || end - start > dim || end > dim {
241        Err(TensorErrorKind::SliceOutOfRange { dim, start, end })?
242    } else {
243        Ok((start, end))
244    }
245}
246
247impl TensorAxis for usize {
248    fn bounds(&self, dim: usize) -> Result<(usize, usize), TensorError> {
249        let start = *self;
250        let end = start + 1;
251        checked_bounds(dim, start, end)
252    }
253}
254
255impl TensorAxis for std::ops::RangeFull {
256    fn bounds(&self, dim: usize) -> Result<(usize, usize), TensorError> {
257        Ok((0, dim))
258    }
259}
260
261impl TensorAxis for std::ops::Range<usize> {
262    fn bounds(&self, dim: usize) -> Result<(usize, usize), TensorError> {
263        checked_bounds(dim, self.start, self.end)
264    }
265}
266
267impl TensorAxis for std::ops::RangeInclusive<usize> {
268    fn bounds(&self, dim: usize) -> Result<(usize, usize), TensorError> {
269        let start = *self.start();
270        let end = self.end() + 1;
271        checked_bounds(dim, start, end)
272    }
273}
274
275impl TensorAxis for std::ops::RangeFrom<usize> {
276    fn bounds(&self, dim: usize) -> Result<(usize, usize), TensorError> {
277        checked_bounds(dim, self.start, dim)
278    }
279}
280
281impl TensorAxis for std::ops::RangeTo<usize> {
282    fn bounds(&self, dim: usize) -> Result<(usize, usize), TensorError> {
283        checked_bounds(dim, 0, self.end)
284    }
285}
286
287impl TensorAxis for std::ops::RangeToInclusive<usize> {
288    fn bounds(&self, dim: usize) -> Result<(usize, usize), TensorError> {
289        checked_bounds(dim, 0, self.end + 1)
290    }
291}
292
293// impl<T: std::ops::RangeBounds<usize>> TensorAxis for T {
294//     fn bounds(&self, dim: usize) -> Result<(usize, usize), TensorError> {
295//         let start = match self.start_bound() {
296//             Bound::Included(&bound) => bound,
297//             Bound::Excluded(&bound) => bound + 1,
298//             Bound::Unbounded => 0,
299//         };
300//         let end = match self.end_bound() {
301//             Bound::Included(&bound) => bound + 1,
302//             Bound::Excluded(&bound) => bound,
303//             Bound::Unbounded => dim,
304//         };
305//         if start > end || start >= dim || end > dim {
306//             Err(TensorError::SliceOutOfRange { dim, start, end })
307//         } else {
308//             Ok((start, end))
309//         }
310//     }
311// }
312
/// How many elements a per-axis range selects.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum SliceQuantState {
    /// The range is empty (`start == end`).
    Zero,
    /// Exactly one element is selected.
    One,
    /// Two or more elements are selected.
    Plural,
}
319
/// Whether a per-axis range covers its whole axis.
enum SliceFillState {
    /// The range covers only part of the axis.
    NotFull,
    /// The range covers the full axis (`0..dim`), or is empty.
    Full,
}
324
impl<X, Y, Z, W> TensorSlice for (X, Y, Z, W)
where
    X: TensorAxis,
    Y: TensorAxis,
    Z: TensorAxis,
    W: TensorAxis,
{
    /// Resolves each of the four axis selectors against the matching dimension
    /// of `shape`, returning per-axis `(start, end)` bounds.
    fn shaped_bounds(&self, shape: Shape) -> Result<(ShapedIndex, ShapedIndex), TensorError> {
        let mut start = ShapedIndex::default();
        let mut end = ShapedIndex::default();
        (start[0], end[0]) = self.0.bounds(shape[0])?;
        (start[1], end[1]) = self.1.bounds(shape[1])?;
        (start[2], end[2]) = self.2.bounds(shape[2])?;
        (start[3], end[3]) = self.3.bounds(shape[3])?;
        Ok((start, end))
    }

    /// Resolves the slice into a linear `[start, end)` element range.
    ///
    /// Only slices that are contiguous in linear memory are accepted: walking
    /// from the fastest axis (axis 0) upward, every axis must cover its full
    /// dimension (or be empty) until the first partially-covered axis; each
    /// axis after that one may select at most a single element. Any other
    /// shape of selection errors with [`TensorErrorKind::SliceInvalid`].
    fn linear_bounds(&self, shape: Shape) -> Result<(usize, usize), TensorError> {
        use SliceFillState::{Full, NotFull};
        use SliceQuantState::{One, Plural, Zero};

        // Classifies how many elements an axis range selects.
        let quant_state = |start, end| match end - start {
            0 => Zero,
            1 => One,
            _ => Plural,
        };

        // An axis counts as "full" when it covers `0..dim`, or is empty.
        let fill_state = |start, end, dim| match (start, end) {
            (0, end) if end == dim => Full,
            (start, end) if start == end => Full,
            _ => NotFull,
        };

        let (start, end) = self.shaped_bounds(shape)?;
        // Fold from axis 0 to axis 3: stay in `Full` while axes are fully
        // covered; once `NotFull`, every later axis must select <= 1 element
        // for the slice to remain contiguous.
        let (_, valid) = itertools::multizip((start.iter(), end.iter(), shape.iter())).fold(
            (Full, true),
            |(state, valid), (start, end, dim)| match (state, valid) {
                (Full, valid) => (fill_state(start, end, dim), valid),
                (NotFull, true) => (NotFull, quant_state(start, end) < Plural),
                (NotFull, false) => (NotFull, false),
            },
        );
        if !valid {
            Err(TensorErrorKind::SliceInvalid)?;
        }

        // Contiguity guarantees the range is exactly `len` elements starting
        // at the linearized start index.
        let len = Shape::from(end - start).len();
        let start = shape.linear_index(start);
        Ok((start, start + len))
    }
}
376
/// A requested size for one axis when deducing a new [`Shape`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum TensorDimension {
    /// Keep the source shape's dimension on this axis.
    #[default]
    Full,
    /// Deduce this dimension from the total element count (at most one axis may be `Auto`).
    Auto,
    /// Use the given fixed dimension.
    Size(usize),
}
384
385impl TensorDimension {
386    pub fn deduce(shape: Shape, x: Self, y: Self, z: Self, w: Self) -> Result<Shape, TensorError> {
387        use TensorDimension::{Auto, Full, Size};
388        let len = shape.len();
389
390        let deduced = [x, y, z, w]
391            .into_iter()
392            .enumerate()
393            .map(|(index, dim)| match dim {
394                Full => Some(shape[index]),
395                Auto => None,
396                Size(dim) => Some(dim),
397            });
398        let remain: usize = deduced.clone().flatten().product();
399
400        if remain == 0 || deduced.clone().filter(|x| x.is_none()).count() > 1 {
401            Err(TensorErrorKind::Deduce)?;
402        };
403
404        let deduced = deduced.map(|x| x.unwrap_or(len / remain)).collect_vec();
405        let deduced = Shape::from_slice(&deduced);
406
407        if deduced.len() != len {
408            Err(TensorErrorKind::Size(deduced.len(), len))?
409        } else {
410            Ok(deduced)
411        }
412    }
413}
414
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use itertools::Itertools;
    use wgpu::{Instance, PowerPreference};

    use super::{Shape, TensorSlice};
    use crate::{
        context::{Context, ContextBuilder, InstanceExt},
        tensor::{shape::ShapedIndex, TensorCpu, TensorInit},
    };

    /// Builds a GPU context on the high-performance adapter for tensor tests.
    async fn create_context() -> Result<Context> {
        let instance = Instance::default();
        let adapter = instance.adapter(PowerPreference::HighPerformance).await?;
        let context = ContextBuilder::new(adapter)
            // .features(Features::TIMESTAMP_QUERY | Features::TIMESTAMP_QUERY_INSIDE_PASSES)
            .build()
            .await?;
        Ok(context)
    }

    /// `linear_index` must fold axes with axis 0 moving fastest.
    #[test]
    fn test_shaped_index() {
        let shape = Shape::new(1024, 768, 12, 1);
        let index = ShapedIndex::new(35, 42, 9, 0);
        let index = shape.linear_index(index);
        assert_eq!(index, 35 + 42 * 1024 + 9 * 1024 * 768);
    }

    /// Exercises `shaped_bounds`/`linear_bounds` over contiguous and
    /// non-contiguous selections, plus slicing of real CPU tensors.
    #[cfg(feature = "tokio")]
    #[tokio::test]
    async fn test_slice() -> Result<()> {
        let context = create_context().await?;

        let x: TensorCpu<f32> = context.tensor_init([1024, 768, 3, 1]);
        // Contiguous: partial axis 0, singletons above it.
        assert_eq!(
            (12..42, 7..8, 1, 0).linear_bounds(x.shape)?,
            (793612, 793642)
        );
        assert_eq!(
            (.., 42..56, 2..=2, ..).shaped_bounds(x.shape)?,
            (
                ShapedIndex::new(0, 42, 2, 0),
                ShapedIndex::new(1024, 56, 3, 1)
            )
        );
        assert!((.., 42..56, 2..3, ..).linear_bounds(x.shape).is_ok());
        assert!((0..1, 0..1, 0..1, ..).linear_bounds(x.shape).is_ok());
        // Non-contiguous: a plural axis above a partial one must be rejected.
        assert!((.., 42..56, 0..2, ..).linear_bounds(x.shape).is_err());
        assert!((0, 0..2, 1..2, ..).linear_bounds(x.shape).is_err());

        let x: TensorCpu<f32> = context.tensor_init([1, 1024, 6, 1]);
        assert_eq!(
            (.., 0..256, 3..=3, ..).linear_bounds(x.shape)?,
            (3072, 3328)
        );

        let x: TensorCpu<f32> = context.tensor_init([1024, 768, 1, 1]);
        assert!((.., 0..256, .., ..).linear_bounds(x.shape).is_ok());

        let x: TensorCpu<f32> = context.tensor_init([1, 768, 1, 1]);
        assert!((.., 256..512, .., ..).linear_bounds(x.shape).is_ok());

        let shape = Shape::new(4, 2, 3, 1);
        let x = (0..shape.len()).map(|x| x as f32).collect_vec();
        let x = TensorCpu::from_data(shape, x)?;

        let y: Vec<_> = x.slice(.., 1..2, 1..2, ..)?.into();
        assert_eq!(y, vec![12.0, 13.0, 14.0, 15.0]);

        let y: Vec<_> = x.slice(.., .., 1..2, ..)?.into();
        assert_eq!(y, vec![8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0]);

        // An empty range yields an empty slice.
        let y: Vec<_> = x.slice(2.., 1.., ..0, ..)?.into();
        assert_eq!(y, Vec::<f32>::new());

        Ok(())
    }

    /// `cartesian_product` must enumerate indices with axis 0 varying fastest.
    #[test]
    fn test_cartesian_product() -> Result<()> {
        let shape = Shape::new(4, 3, 2, 1);
        let indices = shape.cartesian_product().collect_vec();
        assert_eq!(
            indices,
            vec![
                [0, 0, 0, 0],
                [1, 0, 0, 0],
                [2, 0, 0, 0],
                [3, 0, 0, 0],
                [0, 1, 0, 0],
                [1, 1, 0, 0],
                [2, 1, 0, 0],
                [3, 1, 0, 0],
                [0, 2, 0, 0],
                [1, 2, 0, 0],
                [2, 2, 0, 0],
                [3, 2, 0, 0],
                [0, 0, 1, 0],
                [1, 0, 1, 0],
                [2, 0, 1, 0],
                [3, 0, 1, 0],
                [0, 1, 1, 0],
                [1, 1, 1, 0],
                [2, 1, 1, 0],
                [3, 1, 1, 0],
                [0, 2, 1, 0],
                [1, 2, 1, 0],
                [2, 2, 1, 0],
                [3, 2, 1, 0],
            ]
            .into_iter()
            .map(ShapedIndex::from)
            .collect_vec()
        );

        Ok(())
    }
}