1use std::{cmp::Ordering, hash::Hash};
2
3use itertools::Itertools;
4use serde::{Deserialize, Serialize};
5
6use super::{TensorError, TensorErrorKind};
7
/// Types that can consume themselves and produce a raw byte buffer,
/// e.g. for uploading as GPU-visible data.
pub trait IntoBytes {
    /// Consumes the value and returns its byte representation.
    fn into_bytes(self) -> Vec<u8>;
}
11
/// A 4-D index `(x, y, z, w)` into a tensor.
///
/// Throughout this module, component 0 (`x`) is the fastest-varying axis in
/// linear storage and component 3 (`w`) the slowest (see
/// `Shape::linear_index` and `Shape::cartesian_product`).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ShapedIndex([usize; 4]);
15
16impl ShapedIndex {
17 pub fn new(x: usize, y: usize, z: usize, w: usize) -> Self {
18 Self([x, y, z, w])
19 }
20
21 pub fn iter(&self) -> impl Iterator<Item = usize> {
22 self.0.into_iter()
23 }
24}
25
26impl From<[usize; 4]> for ShapedIndex {
27 fn from(value: [usize; 4]) -> Self {
28 Self(value)
29 }
30}
31
32impl From<(usize, usize, usize, usize)> for ShapedIndex {
33 fn from((x, y, z, w): (usize, usize, usize, usize)) -> Self {
34 Self([x, y, z, w])
35 }
36}
37
38impl std::fmt::Display for ShapedIndex {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 write!(f, "({}, {}, {}, {})", self[0], self[1], self[2], self[3])
41 }
42}
43
44impl std::ops::Index<usize> for ShapedIndex {
45 type Output = usize;
46
47 fn index(&self, index: usize) -> &Self::Output {
48 &self.0[index]
49 }
50}
51
52impl std::ops::IndexMut<usize> for ShapedIndex {
53 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
54 &mut self.0[index]
55 }
56}
57
58impl std::ops::Add<ShapedIndex> for ShapedIndex {
59 type Output = Self;
60
61 fn add(self, rhs: ShapedIndex) -> Self::Output {
62 Self::new(
63 self[0] + rhs[0],
64 self[1] + rhs[1],
65 self[2] + rhs[2],
66 self[3] + rhs[3],
67 )
68 }
69}
70
71impl std::ops::Sub<ShapedIndex> for ShapedIndex {
72 type Output = Self;
73
74 fn sub(self, rhs: ShapedIndex) -> Self::Output {
75 Self::new(
76 self[0] - rhs[0],
77 self[1] - rhs[1],
78 self[2] - rhs[2],
79 self[3] - rhs[3],
80 )
81 }
82}
83
84impl std::ops::AddAssign<ShapedIndex> for ShapedIndex {
85 fn add_assign(&mut self, rhs: ShapedIndex) {
86 *self = *self + rhs;
87 }
88}
89
90impl std::ops::SubAssign<ShapedIndex> for ShapedIndex {
91 fn sub_assign(&mut self, rhs: ShapedIndex) {
92 *self = *self - rhs;
93 }
94}
95
/// The extents of a 4-D tensor.
///
/// Component 0 (`x`) is the fastest-varying axis in linear storage and
/// component 3 (`w`) the slowest (see [`Shape::linear_index`]).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Shape([usize; 4]);
100
101impl Shape {
102 pub fn new(x: usize, y: usize, z: usize, w: usize) -> Self {
103 Self([x, y, z, w])
104 }
105
106 pub fn from_slice(slice: &[usize]) -> Self {
107 let mut shape = Self::new(1, 1, 1, 1);
108 for (index, &dim) in slice.iter().take(4).enumerate() {
109 shape[index] = dim;
110 }
111 shape
112 }
113
114 pub fn from_slice_rev(shape: &[usize]) -> Result<Self, TensorError> {
115 let shape = match shape[..] {
116 [] => Shape::new(0, 0, 0, 0),
117 [x] => Shape::new(x, 1, 1, 1),
118 [y, x] => Shape::new(x, y, 1, 1),
119 [z, y, x] => Shape::new(x, y, z, 1),
120 [w, z, y, x] => Shape::new(x, y, z, w),
121 _ => Err(TensorErrorKind::Deduce)?,
122 };
123 Ok(shape)
124 }
125
126 pub fn len(&self) -> usize {
127 self.0.into_iter().product()
128 }
129
130 pub fn is_empty(&self) -> bool {
131 self.0.into_iter().any(|x| x == 0)
132 }
133
134 pub fn iter(&self) -> impl Iterator<Item = usize> {
135 self.0.into_iter()
136 }
137
138 pub fn linear_index(&self, index: impl Into<ShapedIndex>) -> usize {
140 let index: ShapedIndex = index.into();
141 Iterator::zip(self.0.into_iter().rev(), index.0.into_iter().rev())
142 .fold(0, |acc, (shape, index)| acc * shape + index)
143 }
144
145 pub fn cartesian_product(&self) -> impl Iterator<Item = ShapedIndex> {
147 (0..self[3])
148 .cartesian_product(0..self[2])
149 .cartesian_product(0..self[1])
150 .cartesian_product(0..self[0])
151 .map(|(((w, z), y), x)| ShapedIndex::new(x, y, z, w))
152 }
153}
154
155impl From<ShapedIndex> for Shape {
156 fn from(value: ShapedIndex) -> Self {
157 Self(value.0)
158 }
159}
160
161impl From<[usize; 4]> for Shape {
162 fn from(value: [usize; 4]) -> Self {
163 Self(value)
164 }
165}
166
167impl From<Shape> for [usize; 4] {
168 fn from(value: Shape) -> Self {
169 value.0
170 }
171}
172
impl IntoBytes for Shape {
    /// Packs the 4 extents as `u32`s and returns their raw bytes
    /// (16 bytes, native endianness).
    ///
    /// NOTE(review): the `as u32` cast silently truncates extents above
    /// `u32::MAX` on 64-bit targets — presumably shapes never get that
    /// large; confirm at call sites.
    fn into_bytes(self) -> Vec<u8> {
        let data = self.0.map(|x| x as u32);
        bytemuck::pod_collect_to_vec(&data)
    }
}
179
180impl std::cmp::PartialOrd for Shape {
181 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
182 use Ordering::Equal;
183 match (
184 self[0].cmp(&other[0]),
185 self[1].cmp(&other[1]),
186 self[2].cmp(&other[2]),
187 self[3].cmp(&other[3]),
188 ) {
189 (x, y, z, w) if x == y && y == z && z == w => Some(x),
190 (x, y, z, Equal) if x == y && y == z => Some(x),
191 (x, y, Equal, w) if x == y && y == w => Some(y),
192 (x, Equal, z, w) if x == z && z == w => Some(z),
193 (Equal, y, z, w) if y == z && z == w => Some(w),
194 (x, y, Equal, Equal) if x == y => Some(x),
195 (x, Equal, z, Equal) if x == z => Some(x),
196 (x, Equal, Equal, w) if x == w => Some(x),
197 (Equal, y, z, Equal) if y == z => Some(y),
198 (Equal, y, Equal, w) if y == w => Some(y),
199 (Equal, Equal, z, w) if z == w => Some(z),
200 (x, Equal, Equal, Equal) => Some(x),
201 (Equal, y, Equal, Equal) => Some(y),
202 (Equal, Equal, z, Equal) => Some(z),
203 (Equal, Equal, Equal, w) => Some(w),
204 _ => None,
205 }
206 }
207}
208
209impl std::fmt::Display for Shape {
210 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
211 write!(f, "({}, {}, {}, {})", self[0], self[1], self[2], self[3])
212 }
213}
214
215impl std::ops::Index<usize> for Shape {
216 type Output = usize;
217
218 fn index(&self, index: usize) -> &Self::Output {
219 &self.0[index]
220 }
221}
222
223impl std::ops::IndexMut<usize> for Shape {
224 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
225 &mut self.0[index]
226 }
227}
228
/// A 4-D slice specification that can be resolved against a concrete [`Shape`].
pub trait TensorSlice {
    /// Resolves the slice into start (inclusive) and end (exclusive) 4-D
    /// corner indices.
    fn shaped_bounds(&self, shape: Shape) -> Result<(ShapedIndex, ShapedIndex), TensorError>;
    /// Resolves the slice into a contiguous `[start, end)` linear range.
    fn linear_bounds(&self, shape: Shape) -> Result<(usize, usize), TensorError>;
}
233
/// A selector along a single tensor axis (a bare index or any integer range).
pub trait TensorAxis: Clone + PartialEq + Eq + Hash {
    /// Start (inclusive) and end (exclusive) of the selection along an axis
    /// of extent `dim`; errors when the selection falls outside the axis.
    fn bounds(&self, dim: usize) -> Result<(usize, usize), TensorError>;
}
237
238#[inline]
239fn checked_bounds(dim: usize, start: usize, end: usize) -> Result<(usize, usize), TensorError> {
240 if start > end || end - start > dim || end > dim {
241 Err(TensorErrorKind::SliceOutOfRange { dim, start, end })?
242 } else {
243 Ok((start, end))
244 }
245}
246
247impl TensorAxis for usize {
248 fn bounds(&self, dim: usize) -> Result<(usize, usize), TensorError> {
249 let start = *self;
250 let end = start + 1;
251 checked_bounds(dim, start, end)
252 }
253}
254
impl TensorAxis for std::ops::RangeFull {
    /// `..` selects the whole axis; this can never fail.
    fn bounds(&self, dim: usize) -> Result<(usize, usize), TensorError> {
        Ok((0, dim))
    }
}
260
261impl TensorAxis for std::ops::Range<usize> {
262 fn bounds(&self, dim: usize) -> Result<(usize, usize), TensorError> {
263 checked_bounds(dim, self.start, self.end)
264 }
265}
266
impl TensorAxis for std::ops::RangeInclusive<usize> {
    /// `start..=end` converts to the half-open `start..end + 1`.
    ///
    /// NOTE(review): `end + 1` overflows (panics in debug builds) for
    /// `end == usize::MAX`; assumed unreachable for realistic axis sizes.
    fn bounds(&self, dim: usize) -> Result<(usize, usize), TensorError> {
        let start = *self.start();
        let end = self.end() + 1;
        checked_bounds(dim, start, end)
    }
}
274
275impl TensorAxis for std::ops::RangeFrom<usize> {
276 fn bounds(&self, dim: usize) -> Result<(usize, usize), TensorError> {
277 checked_bounds(dim, self.start, dim)
278 }
279}
280
281impl TensorAxis for std::ops::RangeTo<usize> {
282 fn bounds(&self, dim: usize) -> Result<(usize, usize), TensorError> {
283 checked_bounds(dim, 0, self.end)
284 }
285}
286
impl TensorAxis for std::ops::RangeToInclusive<usize> {
    /// `..=end` selects from the beginning of the axis through `end`.
    ///
    /// NOTE(review): `end + 1` overflows (panics in debug builds) for
    /// `end == usize::MAX`; assumed unreachable for realistic axis sizes.
    fn bounds(&self, dim: usize) -> Result<(usize, usize), TensorError> {
        checked_bounds(dim, 0, self.end + 1)
    }
}
292
/// How many indices a resolved axis selection covers.
/// Ordered so that `Zero < One < Plural` (used as `< Plural` below).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum SliceQuantState {
    // Empty selection (`start == end`).
    Zero,
    // Exactly one index.
    One,
    // Two or more indices.
    Plural,
}
319
/// Whether an axis selection spans its entire axis.
enum SliceFillState {
    // Covers only part of the axis.
    NotFull,
    // Covers the whole axis, or nothing at all.
    Full,
}
324
impl<X, Y, Z, W> TensorSlice for (X, Y, Z, W)
where
    X: TensorAxis,
    Y: TensorAxis,
    Z: TensorAxis,
    W: TensorAxis,
{
    /// Resolves each of the 4 axis selectors against the matching dimension
    /// of `shape`, failing if any selector is out of range.
    fn shaped_bounds(&self, shape: Shape) -> Result<(ShapedIndex, ShapedIndex), TensorError> {
        let mut start = ShapedIndex::default();
        let mut end = ShapedIndex::default();
        (start[0], end[0]) = self.0.bounds(shape[0])?;
        (start[1], end[1]) = self.1.bounds(shape[1])?;
        (start[2], end[2]) = self.2.bounds(shape[2])?;
        (start[3], end[3]) = self.3.bounds(shape[3])?;
        Ok((start, end))
    }

    /// Converts the slice to a linear `[start, end)` range, failing with
    /// [`TensorErrorKind::SliceInvalid`] when the selected region is not
    /// contiguous in linear storage.
    ///
    /// Contiguity rule: scanning axes from fastest (`x`) to slowest (`w`),
    /// a leading run of axes may be fully covered (or empty); every axis
    /// after the first partially-covered one must select at most one index.
    fn linear_bounds(&self, shape: Shape) -> Result<(usize, usize), TensorError> {
        use SliceFillState::{Full, NotFull};
        use SliceQuantState::{One, Plural, Zero};

        // How many indices the selection covers on one axis.
        let quant_state = |start, end| match end - start {
            0 => Zero,
            1 => One,
            _ => Plural,
        };

        // An axis is "full" when it spans the whole dimension, or nothing.
        let fill_state = |start, end, dim| match (start, end) {
            (0, end) if end == dim => Full,
            (start, end) if start == end => Full,
            _ => NotFull,
        };

        let (start, end) = self.shaped_bounds(shape)?;
        // State machine over (fill state, still valid): once an axis is not
        // full, any later axis selecting more than one index would make the
        // region non-contiguous.
        let (_, valid) = itertools::multizip((start.iter(), end.iter(), shape.iter())).fold(
            (Full, true),
            |(state, valid), (start, end, dim)| match (state, valid) {
                (Full, valid) => (fill_state(start, end, dim), valid),
                (NotFull, true) => (NotFull, quant_state(start, end) < Plural),
                (NotFull, false) => (NotFull, false),
            },
        );
        if !valid {
            Err(TensorErrorKind::SliceInvalid)?;
        }

        // The region is contiguous: its length is the product of per-axis
        // extents, starting at the linearized start corner.
        let len = Shape::from(end - start).len();
        let start = shape.linear_index(start);
        Ok((start, start + len))
    }
}
376
/// A possibly underspecified dimension used when re-deducing a [`Shape`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum TensorDimension {
    /// Keep the corresponding dimension of the source shape.
    #[default]
    Full,
    /// Deduce this dimension from the total element count.
    Auto,
    /// Force the given fixed size.
    Size(usize),
}
384
385impl TensorDimension {
386 pub fn deduce(shape: Shape, x: Self, y: Self, z: Self, w: Self) -> Result<Shape, TensorError> {
387 use TensorDimension::{Auto, Full, Size};
388 let len = shape.len();
389
390 let deduced = [x, y, z, w]
391 .into_iter()
392 .enumerate()
393 .map(|(index, dim)| match dim {
394 Full => Some(shape[index]),
395 Auto => None,
396 Size(dim) => Some(dim),
397 });
398 let remain: usize = deduced.clone().flatten().product();
399
400 if remain == 0 || deduced.clone().filter(|x| x.is_none()).count() > 1 {
401 Err(TensorErrorKind::Deduce)?;
402 };
403
404 let deduced = deduced.map(|x| x.unwrap_or(len / remain)).collect_vec();
405 let deduced = Shape::from_slice(&deduced);
406
407 if deduced.len() != len {
408 Err(TensorErrorKind::Size(deduced.len(), len))?
409 } else {
410 Ok(deduced)
411 }
412 }
413}
414
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use itertools::Itertools;
    use wgpu::{Instance, PowerPreference};

    use super::{Shape, TensorSlice};
    use crate::{
        context::{Context, ContextBuilder, InstanceExt},
        tensor::{shape::ShapedIndex, TensorCpu, TensorInit},
    };

    // Builds a GPU context on the high-performance adapter for tests that
    // need one (e.g. `test_slice`).
    async fn create_context() -> Result<Context> {
        let instance = Instance::default();
        let adapter = instance.adapter(PowerPreference::HighPerformance).await?;
        let context = ContextBuilder::new(adapter)
            .build()
            .await?;
        Ok(context)
    }

    // `linear_index` treats x as the fastest-varying axis.
    #[test]
    fn test_shaped_index() {
        let shape = Shape::new(1024, 768, 12, 1);
        let index = ShapedIndex::new(35, 42, 9, 0);
        let index = shape.linear_index(index);
        assert_eq!(index, 35 + 42 * 1024 + 9 * 1024 * 768);
    }

    // Exercises `shaped_bounds` / `linear_bounds` on contiguous and
    // non-contiguous slices, plus `TensorCpu::slice` round trips.
    #[cfg(feature = "tokio")]
    #[tokio::test]
    async fn test_slice() -> Result<()> {
        let context = create_context().await?;

        let x: TensorCpu<f32> = context.tensor_init([1024, 768, 3, 1]);
        assert_eq!(
            (12..42, 7..8, 1, 0).linear_bounds(x.shape)?,
            (793612, 793642)
        );
        assert_eq!(
            (.., 42..56, 2..=2, ..).shaped_bounds(x.shape)?,
            (
                ShapedIndex::new(0, 42, 2, 0),
                ShapedIndex::new(1024, 56, 3, 1)
            )
        );
        // Contiguous: partial axes only after fully-covered faster axes.
        assert!((.., 42..56, 2..3, ..).linear_bounds(x.shape).is_ok());
        assert!((0..1, 0..1, 0..1, ..).linear_bounds(x.shape).is_ok());
        // Non-contiguous: a plural axis after a partial one is rejected.
        assert!((.., 42..56, 0..2, ..).linear_bounds(x.shape).is_err());
        assert!((0, 0..2, 1..2, ..).linear_bounds(x.shape).is_err());

        let x: TensorCpu<f32> = context.tensor_init([1, 1024, 6, 1]);
        assert_eq!(
            (.., 0..256, 3..=3, ..).linear_bounds(x.shape)?,
            (3072, 3328)
        );

        let x: TensorCpu<f32> = context.tensor_init([1024, 768, 1, 1]);
        assert!((.., 0..256, .., ..).linear_bounds(x.shape).is_ok());

        let x: TensorCpu<f32> = context.tensor_init([1, 768, 1, 1]);
        assert!((.., 256..512, .., ..).linear_bounds(x.shape).is_ok());

        let shape = Shape::new(4, 2, 3, 1);
        let x = (0..shape.len()).map(|x| x as f32).collect_vec();
        let x = TensorCpu::from_data(shape, x)?;

        let y: Vec<_> = x.slice(.., 1..2, 1..2, ..)?.into();
        assert_eq!(y, vec![12.0, 13.0, 14.0, 15.0]);

        let y: Vec<_> = x.slice(.., .., 1..2, ..)?.into();
        assert_eq!(y, vec![8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0]);

        // An empty slice (`..0`) yields an empty tensor.
        let y: Vec<_> = x.slice(2.., 1.., ..0, ..)?.into();
        assert_eq!(y, Vec::<f32>::new());

        Ok(())
    }

    // `cartesian_product` enumerates indices with x fastest, w slowest.
    #[test]
    fn test_cartesian_product() -> Result<()> {
        let shape = Shape::new(4, 3, 2, 1);
        let indices = shape.cartesian_product().collect_vec();
        assert_eq!(
            indices,
            vec![
                [0, 0, 0, 0],
                [1, 0, 0, 0],
                [2, 0, 0, 0],
                [3, 0, 0, 0],
                [0, 1, 0, 0],
                [1, 1, 0, 0],
                [2, 1, 0, 0],
                [3, 1, 0, 0],
                [0, 2, 0, 0],
                [1, 2, 0, 0],
                [2, 2, 0, 0],
                [3, 2, 0, 0],
                [0, 0, 1, 0],
                [1, 0, 1, 0],
                [2, 0, 1, 0],
                [3, 0, 1, 0],
                [0, 1, 1, 0],
                [1, 1, 1, 0],
                [2, 1, 1, 0],
                [3, 1, 1, 0],
                [0, 2, 1, 0],
                [1, 2, 1, 0],
                [2, 2, 1, 0],
                [3, 2, 1, 0],
            ]
            .into_iter()
            .map(ShapedIndex::from)
            .collect_vec()
        );

        Ok(())
    }
}