use std::{cmp::Ordering, hash::Hash};
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use web_rwkv_derive::{Deref, DerefMut};
use super::TensorError;
/// Conversion of a value into an owned byte buffer.
pub trait IntoBytes {
/// Consumes `self` and returns its contents as a `Vec<u8>`.
fn into_bytes(self) -> Vec<u8>;
}
/// A four-component shape stored as `[x, y, z, w]`.
///
/// Component 0 is the fastest-varying one when a multi-dimensional index is flattened by
/// `Shape::shape_index`. The derived `Default` yields the all-zero shape.
#[derive(
Debug, Default, Clone, Copy, Deref, DerefMut, PartialEq, Eq, Hash, Serialize, Deserialize,
)]
pub struct Shape([usize; 4]);
impl Shape {
    /// Builds a shape from its four components, given in `x, y, z, w` order.
    pub fn new(x: usize, y: usize, z: usize, w: usize) -> Self {
        Self([x, y, z, w])
    }

    /// Builds a shape from the leading entries of `slice`.
    ///
    /// Components that `slice` does not cover are set to `1`; entries past the fourth are
    /// ignored.
    pub fn from_slice(slice: &[usize]) -> Self {
        let mut dims = [1; 4];
        // `zip` stops at the shorter side, which caps the copy at four entries.
        for (slot, &dim) in dims.iter_mut().zip(slice) {
            *slot = dim;
        }
        Self(dims)
    }

    /// Total number of elements, i.e. the product of the four components.
    pub fn len(&self) -> usize {
        self.0.iter().product()
    }

    /// Returns `true` when at least one component is zero.
    pub fn is_empty(&self) -> bool {
        self.0.contains(&0)
    }

    /// Flattens the multi-dimensional position `indices` into a linear offset.
    ///
    /// Component 0 varies fastest: the result is
    /// `i0 + s0 * (i1 + s1 * (i2 + s2 * i3))`, where `s` is `self` and `i` is `indices`.
    /// No range check is performed on `indices`.
    pub fn shape_index(&self, indices: Shape) -> usize {
        let mut offset = 0;
        // Horner-style accumulation from the slowest axis down to the fastest one.
        for axis in (0..4).rev() {
            offset = offset * self.0[axis] + indices.0[axis];
        }
        offset
    }
}
impl IntoBytes for Shape {
    /// Serializes the shape as four `u32` values in `x, y, z, w` order.
    ///
    /// Each component is narrowed with `as u32`, so a value above `u32::MAX` is silently
    /// truncated. The byte layout is whatever `bytemuck::pod_collect_to_vec` produces for a
    /// `[u32; 4]`.
    fn into_bytes(self) -> Vec<u8> {
        let [x, y, z, w] = self.0;
        let words = [x as u32, y as u32, z as u32, w as u32];
        bytemuck::pod_collect_to_vec(&words)
    }
}
impl std::cmp::PartialOrd for Shape {
    /// Component-wise (product) ordering of two shapes.
    ///
    /// Returns `Some(Equal)` when all components match, `Some(Less)` / `Some(Greater)` when
    /// every differing component differs in that same direction, and `None` when some
    /// components are smaller while others are larger.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let mut verdict = Ordering::Equal;
        for (lhs, rhs) in self.0.iter().zip(other.0.iter()) {
            match (verdict, lhs.cmp(rhs)) {
                // An equal component never changes the outcome.
                (_, Ordering::Equal) => {}
                // First differing component decides the direction.
                (Ordering::Equal, axis) => verdict = axis,
                // Later components must agree with that direction.
                (seen, axis) if seen == axis => {}
                _ => return None,
            }
        }
        Some(verdict)
    }
}
impl std::fmt::Display for Shape {
    /// Writes the shape as `(x, y, z, w)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let [x, y, z, w] = self.0;
        write!(f, "({}, {}, {}, {})", x, y, z, w)
    }
}
/// Read access to a single component by position.
///
/// Indexing is forwarded to the inner four-element array, so an `index` of 4 or more panics.
impl std::ops::Index<usize> for Shape {
type Output = usize;
fn index(&self, index: usize) -> &Self::Output {
&self.0[index]
}
}
/// Write access to a single component by position.
///
/// Indexing is forwarded to the inner four-element array, so an `index` of 4 or more panics.
impl std::ops::IndexMut<usize> for Shape {
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
&mut self.0[index]
}
}
impl std::ops::Add<Shape> for Shape {
type Output = Self;
fn add(self, rhs: Shape) -> Self::Output {
Self::new(
self[0] + rhs[0],
self[1] + rhs[1],
self[2] + rhs[2],
self[3] + rhs[3],
)
}
}
impl std::ops::Sub<Shape> for Shape {
type Output = Self;
fn sub(self, rhs: Shape) -> Self::Output {
Self::new(
self[0] - rhs[0],
self[1] - rhs[1],
self[2] - rhs[2],
self[3] - rhs[3],
)
}
}
/// In-place component-wise addition, defined as `*self = *self + rhs`.
impl std::ops::AddAssign<Shape> for Shape {
fn add_assign(&mut self, rhs: Shape) {
// Delegates to the `Add` impl; `Shape` is `Copy`, so `*self` is read by value.
*self = *self + rhs;
}
}
/// In-place component-wise subtraction, defined as `*self = *self - rhs`.
impl std::ops::SubAssign<Shape> for Shape {
fn sub_assign(&mut self, rhs: Shape) {
// Delegates to the `Sub` impl; `Shape` is `Copy`, so `*self` is read by value.
*self = *self - rhs;
}
}
/// A selection over the axes of a tensor, resolvable against a concrete `Shape`.
pub trait TensorSlice {
/// Resolves the selection against `shape` and returns the per-axis bounds packed into two
/// shapes: `(start, end)`.
fn shape_bounds(&self, shape: Shape) -> Result<(Shape, Shape), TensorError>;
/// Resolves the selection against `shape` as a single `(start, end)` pair of linear element
/// offsets, failing when the selection cannot be expressed as one such range.
fn contiguous_bounds(&self, shape: Shape) -> Result<(usize, usize), TensorError>;
}
/// A selector for one axis of a tensor, such as a single index or a range.
pub trait TensorAxis: Clone + PartialEq + Eq + Hash {
/// Resolves the selector against an axis of length `dim` and returns its `(start, end)`
/// bounds. The implementations in this file treat `end` as exclusive.
fn bounds(&self, dim: usize) -> Result<(usize, usize), TensorError>;
}
/// Validates the half-open range `[start, end)` against an axis of length `dim`.
///
/// The range is accepted only when `start <= end`, `start < dim` and `end <= dim`. The middle
/// condition also rejects an empty range positioned exactly at `dim`, and therefore every
/// range when `dim == 0`.
///
/// # Errors
/// Returns `TensorError::SliceOutOfRange` carrying `dim`, `start` and `end` when any of the
/// conditions fails.
#[inline]
fn check_bounds(dim: usize, start: usize, end: usize) -> Result<(usize, usize), TensorError> {
    let ordered = start <= end;
    let within = start < dim && end <= dim;
    if ordered && within {
        Ok((start, end))
    } else {
        Err(TensorError::SliceOutOfRange { dim, start, end })
    }
}
impl TensorAxis for usize {
    /// Selects the single position `*self`, i.e. the half-open range `[*self, *self + 1)`.
    ///
    /// # Errors
    /// Returns `TensorError::SliceOutOfRange` when the index is not below `dim`.
    fn bounds(&self, dim: usize) -> Result<(usize, usize), TensorError> {
        let start = *self;
        // `start + 1` would overflow for `usize::MAX` and panic in builds with overflow
        // checks. Saturating leaves `end == start` in that case, and `check_bounds` still
        // rejects it because `start >= dim` holds for every possible `dim`.
        let end = start.saturating_add(1);
        check_bounds(dim, start, end)
    }
}
/// `..` selects the whole axis.
impl TensorAxis for std::ops::RangeFull {
/// Always succeeds with `(0, dim)`; no validation is performed, so this also holds for
/// `dim == 0`.
fn bounds(&self, dim: usize) -> Result<(usize, usize), TensorError> {
Ok((0, dim))
}
}
/// `start..end` selects the half-open range as written.
impl TensorAxis for std::ops::Range<usize> {
/// Passes `self.start` and `self.end` unchanged to `check_bounds`, which decides whether the
/// range is valid for an axis of length `dim`.
fn bounds(&self, dim: usize) -> Result<(usize, usize), TensorError> {
check_bounds(dim, self.start, self.end)
}
}
impl TensorAxis for std::ops::RangeInclusive<usize> {
fn bounds(&self, dim: usize) -> Result<(usize, usize), TensorError> {
let start = *self.start();
let end = self.end() + 1;
check_bounds(dim, start, end)
}
}
/// `start..` selects everything from `start` to the end of the axis.
impl TensorAxis for std::ops::RangeFrom<usize> {
/// Validates `(self.start, dim)` with `check_bounds`. Because `check_bounds` requires
/// `start < dim`, a `start` equal to `dim` is an error rather than an empty selection.
fn bounds(&self, dim: usize) -> Result<(usize, usize), TensorError> {
check_bounds(dim, self.start, dim)
}
}
/// `..end` selects everything from the start of the axis up to, but excluding, `end`.
impl TensorAxis for std::ops::RangeTo<usize> {
/// Validates `(0, self.end)` with `check_bounds`. `..0` yields the empty range `(0, 0)` as
/// long as `dim` is non-zero.
fn bounds(&self, dim: usize) -> Result<(usize, usize), TensorError> {
check_bounds(dim, 0, self.end)
}
}
impl TensorAxis for std::ops::RangeToInclusive<usize> {
fn bounds(&self, dim: usize) -> Result<(usize, usize), TensorError> {
check_bounds(dim, 0, self.end + 1)
}
}
/// Number of positions a slice selects along one axis, bucketed for the contiguity check in
/// `contiguous_bounds`.
///
/// The derived ordering follows declaration order (`Zero < One < Plural`); the contiguity check
/// relies on it when it tests `< Plural`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum SliceQuantState {
/// `end - start == 0`.
Zero,
/// `end - start == 1`.
One,
/// `end - start` is 2 or more.
Plural,
}
/// Running state of the contiguity check in `contiguous_bounds` while it walks the axes.
enum SliceFillState {
/// Some axis examined so far was covered only partially.
NotFull,
/// Every axis examined so far either spans its whole dimension or is empty.
Full,
}
impl<X, Y, Z, W> TensorSlice for (X, Y, Z, W)
where
    X: TensorAxis,
    Y: TensorAxis,
    Z: TensorAxis,
    W: TensorAxis,
{
    /// Resolves each axis selector against the matching component of `shape`.
    ///
    /// Returns `(start, end)`, where component `i` of each shape holds the bounds reported by
    /// the `i`-th selector. The selectors are evaluated in order and the first failure aborts
    /// the call with that selector's error.
    fn shape_bounds(&self, shape: Shape) -> Result<(Shape, Shape), TensorError> {
        let (x0, x1) = self.0.bounds(shape[0])?;
        let (y0, y1) = self.1.bounds(shape[1])?;
        let (z0, z1) = self.2.bounds(shape[2])?;
        let (w0, w1) = self.3.bounds(shape[3])?;
        Ok((Shape::new(x0, y0, z0, w0), Shape::new(x1, y1, z1, w1)))
    }

    /// Resolves the selection to one range `(start, end)` of linear element offsets.
    ///
    /// The axes are walked from component 0 upwards. Leading axes that are fully covered (or
    /// empty) are skipped; once an axis is only partially covered, every later axis may select
    /// at most one position.
    ///
    /// # Errors
    /// Propagates any error from `shape_bounds`, and returns `TensorError::Contiguous` when
    /// the rule above is violated.
    fn contiguous_bounds(&self, shape: Shape) -> Result<(usize, usize), TensorError> {
        use SliceFillState::{Full, NotFull};
        use SliceQuantState::{One, Plural, Zero};

        let (start, end) = self.shape_bounds(shape)?;

        let mut fill = Full;
        let mut contiguous = true;
        for axis in 0..4 {
            let (lo, hi, dim) = (start[axis], end[axis], shape[axis]);
            match fill {
                Full => {
                    // Still in the fully covered prefix: classify this axis.
                    let covers_axis = lo == 0 && hi == dim;
                    fill = if covers_axis || lo == hi { Full } else { NotFull };
                }
                NotFull if contiguous => {
                    // Past the partial axis: only zero or one position is allowed.
                    let quant = match hi - lo {
                        0 => Zero,
                        1 => One,
                        _ => Plural,
                    };
                    contiguous = quant < Plural;
                }
                // Already known to be non-contiguous; nothing left to decide.
                NotFull => {}
            }
        }
        if !contiguous {
            return Err(TensorError::Contiguous);
        }

        let count = (end - start).len();
        let offset = shape.shape_index(start);
        Ok((offset, offset + count))
    }
}
/// How one component of a target shape is specified for `TensorDimension::deduce`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum TensorDimension {
/// Keep the component of the source shape at the same position. This is the default.
#[default]
Full,
/// Let `deduce` compute this component from the total element count; at most one component
/// may be `Auto`.
Auto,
/// Use exactly the given size.
Dimension(usize),
}
impl TensorDimension {
    /// Computes the shape obtained by re-interpreting `shape` with the requested components.
    ///
    /// `Full` keeps the component of `shape` at the same position, `Dimension(n)` uses `n`, and
    /// a single `Auto` is filled with `shape.len()` divided by the product of all other
    /// components.
    ///
    /// # Errors
    /// * `TensorError::Deduce` when the product of the non-`Auto` components is zero, or when
    ///   more than one component is `Auto`.
    /// * `TensorError::Size(deduced_len, len)` when the resulting shape does not contain
    ///   exactly `shape.len()` elements.
    pub fn deduce(shape: Shape, x: Self, y: Self, z: Self, w: Self) -> Result<Shape, TensorError> {
        use TensorDimension::{Auto, Dimension, Full};

        let len = shape.len();
        let requested = [x, y, z, w];

        // One pass to gather the product of the known components and the number of `Auto`s.
        let mut known: usize = 1;
        let mut autos: usize = 0;
        for (axis, dim) in requested.iter().enumerate() {
            match *dim {
                Full => known *= shape[axis],
                Dimension(size) => known *= size,
                Auto => autos += 1,
            }
        }
        if known == 0 || autos > 1 {
            return Err(TensorError::Deduce);
        }

        // `known` is non-zero here, so the division cannot fail.
        let resolved = requested
            .iter()
            .enumerate()
            .map(|(axis, dim)| match *dim {
                Full => shape[axis],
                Dimension(size) => size,
                Auto => len / known,
            })
            .collect_vec();
        let deduced = Shape::from_slice(&resolved);

        match deduced.len() {
            total if total == len => Ok(deduced),
            total => Err(TensorError::Size(total, len)),
        }
    }
}
// Unit tests for shape indexing and slice-bound resolution.
#[cfg(test)]
mod tests {
use itertools::Itertools;
use wgpu::PowerPreference;
use super::{Shape, TensorSlice};
use crate::{
context::{Context, ContextBuilder, Instance},
tensor::{TensorCpu, TensorInit},
};
// Requests a high-performance adapter and builds a context with the default pipelines,
// blocking on the async setup. Any failure is returned to the caller.
fn create_context() -> Result<Context, anyhow::Error> {
let adapter = pollster::block_on(async {
let instance = Instance::new();
instance.adapter(PowerPreference::HighPerformance).await
})?;
let context = pollster::block_on(async {
ContextBuilder::new(adapter)
.with_default_pipelines()
.build()
.await
})?;
Ok(context)
}
// Component 0 is the fastest-varying axis: offset = x + y * 1024 + z * 1024 * 768.
#[test]
fn test_shape_index() {
let shape = Shape::new(1024, 768, 12, 1);
let indices = Shape::new(35, 42, 9, 0);
let index = shape.shape_index(indices);
assert_eq!(index, 35 + 42 * 1024 + 9 * 1024 * 768);
}
// Exercises `shape_bounds`, `contiguous_bounds` and tensor slicing.
#[test]
fn test_slice() -> Result<(), anyhow::Error> {
// The test is skipped (reported as passing) when no context can be created.
let context = match create_context() {
Ok(context) => context,
Err(_) => return Ok(()),
};
let x: TensorCpu<f32> = context.tensor_init(Shape::new(1024, 768, 3, 1));
// 12 + 7 * 1024 + 1 * 1024 * 768 = 793612, and the slice holds 30 elements.
assert_eq!(
(12..42, 7..8, 1, 0).contiguous_bounds(x.shape)?,
(793612, 793642)
);
assert_eq!(
(.., 42..56, 2..=2, ..).shape_bounds(x.shape)?,
(Shape::new(0, 42, 2, 0), Shape::new(1024, 56, 3, 1))
);
// A partially covered axis may only be followed by axes selecting at most one position.
assert!((.., 42..56, 2..3, ..).contiguous_bounds(x.shape).is_ok());
assert!((0..1, 0..1, 0..1, ..).contiguous_bounds(x.shape).is_ok());
assert!((.., 42..56, 0..2, ..).contiguous_bounds(x.shape).is_err());
assert!((0, 0..2, 1..2, ..).contiguous_bounds(x.shape).is_err());
let x: TensorCpu<f32> = context.tensor_init(Shape::new(1, 1024, 6, 1));
assert_eq!(
(.., 0..256, 3..=3, ..).contiguous_bounds(x.shape)?,
(3072, 3328)
);
let x: TensorCpu<f32> = context.tensor_init(Shape::new(1024, 768, 1, 1));
assert!((.., 0..256, .., ..).contiguous_bounds(x.shape).is_ok());
let x: TensorCpu<f32> = context.tensor_init(Shape::new(1, 768, 1, 1));
assert!((.., 256..512, .., ..).contiguous_bounds(x.shape).is_ok());
// Slicing real data: the tensor holds 0.0, 1.0, ... in linear order.
let shape = Shape::new(4, 2, 3, 1);
let x = (0..shape.len()).map(|x| x as f32).collect_vec();
let x = TensorCpu::from_data(&context, shape, x)?;
let y: Vec<_> = x.slice(.., 1..2, 1..2, ..)?.into();
assert_eq!(y, vec![12.0, 13.0, 14.0, 15.0]);
let y: Vec<_> = x.slice(.., .., 1..2, ..)?.into();
assert_eq!(y, vec![8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0]);
// `..0` selects nothing along axis 2, so the result is empty.
let y: Vec<_> = x.into_slice(2.., 1.., ..0, ..)?.into();
assert_eq!(y, Vec::<f32>::new());
Ok(())
}
}