use std::{borrow::Cow, marker::PhantomData, sync::Arc};

use itertools::Itertools;
use shape::ShapedIndex;
use thiserror::Error;
use wgpu::{
    BindGroupLayoutEntry, BindingResource, BindingType, Buffer, BufferBinding, BufferBindingType,
    BufferUsages, ShaderStages,
};

use self::{
    kind::{Kind, ReadWrite, Uniform},
    shape::{IntoBytes, Shape, TensorAxis, TensorDimension, TensorSlice},
};
use crate::{
    context::Context,
    num::{Float, Scalar},
};

pub mod cache;
pub mod matrix;
pub mod ops;
pub mod serialization;
pub mod shape;

#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct View {
    pub shape: Shape,
    pub stride: Shape,
    pub offset: Shape,
}

impl IntoBytes for View {
    fn into_bytes(self) -> Vec<u8> {
        [
            self.shape.into_bytes(),
            self.stride.into_bytes(),
            self.offset.into_bytes(),
        ]
        .concat()
    }
}

#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Cursor {
    pub batch: usize,
    pub token: usize,
    pub len: usize,
}

impl Cursor {
    pub fn pack(self) -> u32 {
        let batch = self.batch as u8;
        let token = (self.token as u16).to_ne_bytes();
        let len = self.len as u8;
        bytemuck::cast([batch, token[0], token[1], len])
    }
}

pub trait IntoPackedCursors {
    fn into_stack(self) -> Vec<u32>;
    fn into_cursors(self) -> Vec<u32>;
}

impl IntoPackedCursors for Vec<Cursor> {
    fn into_stack(self) -> Vec<u32> {
        self.into_iter()
            .filter(|cursor| cursor.len > 0)
            .map(Cursor::pack)
            .collect()
    }

    fn into_cursors(self) -> Vec<u32> {
        self.into_iter()
            .filter(|cursor| cursor.len > 0)
            .map(|cursor| {
                let repeat = cursor.len;
                vec![cursor.pack(); repeat]
            })
            .collect_vec()
            .concat()
    }
}

#[derive(Debug)]
pub struct TensorError {
    pub error: TensorErrorKind,
    #[cfg(feature = "backtrace")]
    pub backtrace: std::backtrace::Backtrace,
}

impl TensorError {
    #[cfg(feature = "backtrace")]
    pub fn new(error: TensorErrorKind) -> Self {
        let backtrace = std::backtrace::Backtrace::capture();
        Self { error, backtrace }
    }

    #[cfg(not(feature = "backtrace"))]
    pub fn new(error: TensorErrorKind) -> Self {
        Self { error }
    }
}

impl From<TensorErrorKind> for TensorError {
    fn from(value: TensorErrorKind) -> Self {
        Self::new(value)
    }
}

impl std::fmt::Display for TensorError {
    #[cfg(feature = "backtrace")]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "{}\n\nBacktrace:\n{}", self.error, self.backtrace)
    }

    #[cfg(not(feature = "backtrace"))]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "{}", self.error)
    }
}

impl std::error::Error for TensorError {}

#[derive(Debug, Error)]
pub enum TensorErrorKind {
129 #[error("list must not be empty")]
130 Empty,
131 #[error("data type mismatch")]
132 Type,
133 #[error("data size not match: {0} vs. {1}")]
134 Size(usize, usize),
135 #[error("batch size not match: {0} vs. {1}")]
136 Batch(usize, usize),
137 #[error("tensor shape not match: {0} vs. {1}")]
138 Shape(Shape, Shape),
139 #[error("cannot deduce dimension")]
140 Deduce,
141 #[error("batch {batch} out of range of max {max}")]
142 BatchOutOfRange { batch: usize, max: usize },
143 #[error("slice {start}..{end} out of range for dimension size {dim}")]
144 SliceOutOfRange {
145 dim: usize,
146 start: usize,
147 end: usize,
148 },
149 #[error("slice not contiguous")]
150 SliceInvalid,
151 #[error("cannot split along the axis {0}")]
152 SplitInvalid(usize),
153 #[error("possible tensor error(s):\n{0}")]
154 Any(#[from] AnyTensorError),
}

#[derive(Debug, Error)]
pub struct AnyTensorError(pub Vec<Box<dyn std::error::Error + Send + Sync>>);

impl std::fmt::Display for AnyTensorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.0
            .iter()
            .enumerate()
            .try_for_each(|(index, error)| writeln!(f, "{index}. {error}"))
    }
}

pub trait DeepClone: Sized {
    fn deep_clone(&self) -> Self;
}

pub trait TensorScalar {
    type T: Scalar;
}

pub trait TensorInitContext<T: Scalar>: Sized {
    fn from_data<'a, S, D>(context: &Context, shape: S, data: D) -> Result<Self, TensorError>
    where
        S: Into<Shape>,
        D: Into<Cow<'a, [T]>>;
    fn init(context: &Context, shape: impl Into<Shape>) -> Self;
}

pub trait TensorInit<T: Scalar>: Sized {
    fn from_data<'a, S, D>(shape: S, data: D) -> Result<Self, TensorError>
    where
        S: Into<Shape>,
        D: Into<Cow<'a, [T]>>;
    fn init(shape: impl Into<Shape>) -> Self;

    fn from_data_1d<'a>(data: impl Into<Cow<'a, [T]>>) -> Self {
        let data: Cow<'_, [T]> = data.into();
        let shape = [data.len(), 1, 1, 1];
        Self::from_data(shape, data).expect("tensor 1d from data")
    }
}

pub trait TensorInto<Into> {
    fn to(self, context: &Context) -> Into;
}

pub trait TensorShape: Sized {
    fn shape(&self) -> Shape;

    fn check_shape(&self, shape: impl Into<Shape>) -> Result<(), TensorError> {
        let shape = shape.into();
        (self.shape() == shape)
            .then_some(())
            .ok_or(TensorErrorKind::Shape(self.shape(), shape))
            .map_err(Into::into)
    }

    fn check_shape_any<S>(&self, shapes: &[S]) -> Result<(), TensorError>
    where
        S: Into<Shape> + ToOwned<Owned = S>,
    {
        let (oks, errors): (Vec<_>, Vec<_>) = shapes
            .iter()
            .map(|shape| self.check_shape(shape.to_owned()).map_err(Into::into))
            .partition_result();
        match oks.is_empty() {
            true => Err(TensorErrorKind::Any(AnyTensorError(errors)))?,
            false => Ok(()),
        }
    }
}

pub trait TensorReshape: Sized {
    fn reshape(
        &self,
        x: TensorDimension,
        y: TensorDimension,
        z: TensorDimension,
        w: TensorDimension,
    ) -> Result<Self, TensorError>;
}

pub trait TensorResource {
    fn resource_key(&self) -> ResourceKey;
    fn meta_binding(&self) -> BindingResource<'_>;
    fn binding(&self) -> BindingResource<'_>;
}

#[derive(Debug)]
pub struct Tensor<D: Device, T: Scalar> {
    shape: Shape,
    data: D::Data,
    id: uid::Id<TensorId>,
    phantom: PhantomData<T>,
}

#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TensorId;

#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ResourceKey {
    pub id: uid::Id<TensorId>,
    pub view: View,
}

pub trait Device: sealed::Sealed {
    type Data: Clone;
}

#[derive(Debug)]
pub struct Cpu<T: Scalar>(PhantomData<T>);

#[derive(Debug)]
pub struct Gpu<K: Kind>(PhantomData<K>);

impl<T: Scalar> Device for Cpu<T> {
    type Data = Arc<[T]>;
}

impl<K: Kind> Device for Gpu<K> {
    type Data = TensorGpuData;
}

#[derive(Debug, Clone)]
pub struct TensorGpuData {
    pub context: Context,
    pub meta: Arc<Buffer>,
    pub buffer: Arc<Buffer>,
}

pub mod kind {
    use web_rwkv_derive::Kind;
    use wgpu::BufferUsages;

    use super::sealed;

    pub trait Kind: sealed::Sealed {
        fn buffer_usages() -> BufferUsages;
    }

    #[derive(Debug, Kind)]
    #[usage(UNIFORM, COPY_DST, COPY_SRC)]
    pub struct Uniform;

    #[derive(Debug, Kind)]
    #[usage(STORAGE, COPY_DST, COPY_SRC)]
    pub struct ReadWrite;
}

pub type TensorCpu<T> = Tensor<Cpu<T>, T>;
pub type TensorGpu<T, K> = Tensor<Gpu<K>, T>;

impl<D: Device, T: Scalar> Clone for Tensor<D, T> {
    fn clone(&self) -> Self {
        Self {
            shape: self.shape,
            data: self.data.clone(),
            id: self.id,
            phantom: PhantomData,
        }
    }
}

impl<D: Device, T: Scalar> std::ops::Deref for Tensor<D, T> {
    type Target = D::Data;

    #[inline]
    fn deref(&self) -> &Self::Target {
        &self.data
    }
}

impl<D: Device, T: Scalar> TensorScalar for Tensor<D, T> {
    type T = T;
}

impl<D: Device, T: Scalar> Tensor<D, T> {
    #[inline]
    pub fn len(&self) -> usize {
        self.shape.len()
    }

    #[inline]
    pub fn is_empty(&self) -> bool {
        self.shape.is_empty()
    }

    #[inline]
    pub fn size(&self) -> usize {
        self.len() * T::size()
    }

    #[inline]
    pub fn offset(index: usize) -> usize {
        index * T::size()
    }

    #[inline]
    pub fn data(&self) -> &D::Data {
        &self.data
    }

    #[inline]
    pub fn id(&self) -> uid::Id<TensorId> {
        self.id
    }
}

impl<D: Device, F: Float> Tensor<D, F> {
    #[inline]
    pub const fn def(&self) -> &'static str {
        F::DEF
    }
}

impl<T: Scalar> TensorInit<T> for TensorCpu<T> {
    fn from_data<'a, S, D>(shape: S, data: D) -> Result<Self, TensorError>
    where
        S: Into<Shape>,
        D: Into<Cow<'a, [T]>>,
    {
        let shape = shape.into();
        let data: Cow<'_, _> = data.into();
        if shape.len() != data.len() {
            Err(TensorErrorKind::Size(shape.len(), data.len()))?;
        }
        let data = data.into_owned().into();
        Ok(Self {
            shape,
            data,
            id: uid::Id::new(),
            phantom: PhantomData,
        })
    }

    #[inline]
    fn init(shape: impl Into<Shape>) -> Self {
        let shape = shape.into();
        let data = vec![T::zero(); shape.len()].into();
        Self {
            shape,
            data,
            id: uid::Id::new(),
            phantom: PhantomData,
        }
    }
}

impl<T: Scalar> TensorInitContext<T> for TensorCpu<T> {
    fn from_data<'a, S, D>(_context: &Context, shape: S, data: D) -> Result<Self, TensorError>
    where
        S: Into<Shape>,
        D: Into<Cow<'a, [T]>>,
    {
        TensorInit::from_data(shape, data)
    }

    fn init(_context: &Context, shape: impl Into<Shape>) -> Self {
        TensorInit::init(shape)
    }
}

impl<T: Scalar> TensorInto<TensorCpu<T>> for TensorCpu<T> {
    fn to(self, _: &Context) -> Self {
        self
    }
}

impl<T: Scalar> DeepClone for TensorCpu<T> {
    fn deep_clone(&self) -> Self {
        self.clone()
    }
}

impl<T: Scalar> TensorShape for TensorCpu<T> {
    #[inline]
    fn shape(&self) -> Shape {
        self.shape
    }
}

impl<T: Scalar> TensorReshape for TensorCpu<T> {
    #[inline]
    fn reshape(
        &self,
        x: TensorDimension,
        y: TensorDimension,
        z: TensorDimension,
        w: TensorDimension,
    ) -> Result<Self, TensorError> {
        let shape = TensorDimension::deduce(self.shape, x, y, z, w)?;
        Ok(Self {
            shape,
            ..self.clone()
        })
    }
}

impl<T: Scalar, K: Kind> TensorInitContext<T> for TensorGpu<T, K> {
    fn from_data<'a, S, D>(context: &Context, shape: S, data: D) -> Result<Self, TensorError>
    where
        S: Into<Shape>,
        D: Into<Cow<'a, [T]>>,
    {
        let tensor: TensorCpu<T> = TensorInit::from_data(shape, data)?;
        Ok(tensor.to(context))
    }

    fn init(context: &Context, shape: impl Into<Shape>) -> Self {
        let context = context.clone();
        let shape = shape.into();
        let meta = context.checkout_shape_uniform(shape);
        let size = shape.len() * std::mem::size_of::<T>();
        let buffer = context.checkout_buffer(size, K::buffer_usages());
        Self {
            shape,
            data: TensorGpuData {
                context,
                meta,
                buffer,
            },
            id: uid::Id::new(),
            phantom: PhantomData,
        }
    }
}

impl<T: Scalar, K: Kind> TensorInto<TensorGpu<T, K>> for TensorCpu<T> {
    fn to(self, context: &Context) -> TensorGpu<T, K> {
        let Tensor { shape, data, .. } = self;
        let context = context.clone();
        let meta = context.checkout_shape_uniform(shape);
        let contents = bytemuck::cast_slice(&data);
        let buffer = context.checkout_buffer_init(contents, K::buffer_usages());
        TensorGpu {
            shape,
            data: TensorGpuData {
                context,
                meta,
                buffer,
            },
            id: uid::Id::new(),
            phantom: PhantomData,
        }
    }
}

#[cfg(not(target_arch = "wasm32"))]
impl<T: Scalar> TensorInto<TensorGpu<T, ReadWrite>> for TensorGpu<T, ReadWrite> {
    fn to(self, context: &Context) -> Self {
        match context {
            context if context == &self.context => self,
            _ => self.back_in_place().to(context),
        }
    }
}

#[cfg(target_arch = "wasm32")]
impl<T: Scalar> TensorInto<TensorGpu<T, ReadWrite>> for TensorGpu<T, ReadWrite> {
    fn to(self, _: &Context) -> Self {
        self
    }
}

impl<T: Scalar, K: Kind> TensorShape for TensorGpu<T, K> {
    #[inline]
    fn shape(&self) -> Shape {
        self.shape
    }
}

impl<T: Scalar, K: Kind> TensorReshape for TensorGpu<T, K> {
    #[inline]
    fn reshape(
        &self,
        x: TensorDimension,
        y: TensorDimension,
        z: TensorDimension,
        w: TensorDimension,
    ) -> Result<Self, TensorError> {
        let shape = TensorDimension::deduce(self.shape, x, y, z, w)?;
        let context = self.context.clone();
        let meta = context.checkout_shape_uniform(shape);
        let buffer = self.buffer.clone();
        Ok(Self {
            shape,
            data: TensorGpuData {
                context,
                meta,
                buffer,
            },
            ..self.clone()
        })
    }
}

impl<T: Scalar, K: Kind> TensorResource for TensorGpu<T, K> {
    #[inline]
    fn resource_key(&self) -> ResourceKey {
        let id = self.id;
        let view = View {
            shape: self.shape,
            stride: self.shape,
            offset: [0, 0, 0, 0].into(),
        };
        ResourceKey { id, view }
    }

    #[inline]
    fn meta_binding(&self) -> BindingResource<'_> {
        BindingResource::Buffer(BufferBinding {
            buffer: &self.meta,
            offset: 0,
            size: None,
        })
    }

    #[inline]
    fn binding(&self) -> BindingResource<'_> {
        BindingResource::Buffer(BufferBinding {
            buffer: &self.buffer,
            offset: 0,
            size: None,
        })
    }
}

impl<T: Scalar, K: Kind> TensorGpu<T, K> {
    pub fn from_data_u8(
        context: &Context,
        shape: impl Into<Shape>,
        contents: &[u8],
    ) -> Result<Self, TensorError> {
        let shape = shape.into();
        let size = shape.len() * size_of::<T>();
        if contents.len() != size {
            Err(TensorErrorKind::Size(size, contents.len()))?;
        }
        let buffer = context.checkout_buffer_init(contents, K::buffer_usages());
        let meta = context.checkout_shape_uniform(shape);
        Ok(Self {
            shape,
            data: TensorGpuData {
                context: context.clone(),
                meta,
                buffer,
            },
            id: uid::Id::new(),
            phantom: PhantomData,
        })
    }

    #[cfg(not(target_arch = "wasm32"))]
    pub fn back_in_place(&self) -> TensorCpu<T> {
        use crate::context::ContextEvent;

        if self.is_empty() {
            return TensorCpu {
                shape: self.shape,
                data: Arc::new([]),
                id: uid::Id::new(),
                phantom: PhantomData,
            };
        }

        let context = &self.context;
        let size = self.buffer.size();
        let buffer = context.checkout_buffer(
            size as usize,
            BufferUsages::MAP_READ | BufferUsages::COPY_DST,
        );

        let mut encoder = context.device.create_command_encoder(&Default::default());
        encoder.copy_buffer_to_buffer(&self.buffer, 0, &buffer, 0, size);
        context.queue.submit(Some(encoder.finish()));

        let (sender, receiver) = flume::bounded(1);
        let _ = context.event().send(ContextEvent { buffer, sender });
        let data = receiver.recv().expect("failed to receive read back buffer");
        let data = unsafe {
            let data = Box::leak(data);
            let slice = bytemuck::cast_slice_mut::<_, T>(data);
            Box::from_raw(slice)
        };
        let data = data.into_vec().into();
        let shape = self.shape;

        TensorCpu {
            shape,
            data,
            id: uid::Id::new(),
            phantom: PhantomData,
        }
    }

    #[cfg(not(target_arch = "wasm32"))]
    pub async fn back(&self) -> TensorCpu<T> {
        if self.is_empty() {
            return TensorCpu {
                shape: self.shape,
                data: Arc::new([]),
                id: uid::Id::new(),
                phantom: PhantomData,
            };
        }

        let context = &self.context;
        let size = self.buffer.size();
        let buffer = context.checkout_buffer(
            size as usize,
            BufferUsages::MAP_READ | BufferUsages::COPY_DST,
        );

        let mut encoder = context.device.create_command_encoder(&Default::default());
        encoder.copy_buffer_to_buffer(&self.buffer, 0, &buffer, 0, size);
        context.queue.submit(Some(encoder.finish()));

        let (sender, receiver) = flume::bounded(1);

        let _ = context
            .event()
            .send(crate::context::ContextEvent { buffer, sender });
        let data = receiver
            .recv_async()
            .await
            .expect("failed to receive read back buffer");
        let data = unsafe {
            let data = Box::leak(data);
            let slice = bytemuck::cast_slice_mut::<_, T>(data);
            Box::from_raw(slice)
        };
        let data = data.into_vec().into();

        TensorCpu {
            shape: self.shape,
            data,
            id: uid::Id::new(),
            phantom: PhantomData,
        }
    }

    #[cfg(target_arch = "wasm32")]
    pub async fn back(self) -> TensorCpu<T> {
        if self.is_empty() {
            return TensorCpu {
                shape: self.shape,
                data: Arc::new([]),
                id: uid::Id::new(),
                phantom: PhantomData,
            };
        }

        let context = &self.context;
        let size = self.buffer.size();
        let buffer = context.checkout_buffer(
            size as usize,
            BufferUsages::MAP_READ | BufferUsages::COPY_DST,
        );

        let mut encoder = context.device.create_command_encoder(&Default::default());
        encoder.copy_buffer_to_buffer(&self.buffer, 0, &buffer, 0, size);
        context.queue.submit(Some(encoder.finish()));

        let (sender, receiver) = flume::unbounded();

        let slice = buffer.slice(..);
        slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());

        _ = context.device.poll(wgpu::PollType::Wait);
        receiver
            .recv_async()
            .await
            .expect("failed to receive read back buffer")
            .expect("failed to map buffer");

        let data = {
            let map = slice.get_mapped_range();
            Vec::from(bytemuck::cast_slice(&map)).into()
        };
        buffer.unmap();

        TensorCpu {
            shape: self.shape,
            data,
            id: uid::Id::new(),
            phantom: PhantomData,
        }
    }
}

impl<T: Scalar, K: Kind> TensorGpu<T, K> {
    #[inline]
    pub fn context(&self) -> &Context {
        &self.context
    }

    pub fn load(&self, host: &TensorCpu<T>) -> Result<(), TensorError> {
        host.check_shape(self.shape)?;
        self.context
            .queue
            .write_buffer(&self.buffer, 0, bytemuck::cast_slice(&host.data[..]));
        Ok(())
    }

    pub fn load_batch(&self, host: &TensorCpu<T>, batch: usize) -> Result<(), TensorError> {
        host.check_shape([self.shape[0], self.shape[1], 1, 1])?;
        if batch >= self.shape[2] {
            Err(TensorErrorKind::BatchOutOfRange {
                batch,
                max: self.shape[2],
            })?;
        }
        let offset = (T::size() * self.shape[0] * self.shape[1] * batch) as u64;
        self.context
            .queue
            .write_buffer(&self.buffer, offset, bytemuck::cast_slice(&host.data[..]));
        Ok(())
    }

    pub fn destroy(self) {
        self.buffer.destroy();
    }
}

impl<T: Scalar> TensorGpu<T, Uniform> {
    #[inline]
    pub fn layout(&self, binding: u32) -> BindGroupLayoutEntry {
        BindGroupLayoutEntry {
            binding,
            visibility: ShaderStages::COMPUTE,
            ty: BindingType::Buffer {
                ty: BufferBindingType::Uniform,
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        }
    }
}

impl<T: Scalar> TensorGpu<T, ReadWrite> {
    #[inline]
    pub fn meta_layout(&self, binding: u32) -> BindGroupLayoutEntry {
        BindGroupLayoutEntry {
            binding,
            visibility: ShaderStages::COMPUTE,
            ty: BindingType::Buffer {
                ty: BufferBindingType::Uniform,
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        }
    }

    #[inline]
    pub fn layout(&self, binding: u32, read_only: bool) -> BindGroupLayoutEntry {
        BindGroupLayoutEntry {
            binding,
            visibility: ShaderStages::COMPUTE,
            ty: BindingType::Buffer {
                ty: BufferBindingType::Storage { read_only },
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        }
    }
}

impl<T: Scalar> From<TensorCpu<T>> for Vec<T> {
    #[inline]
    fn from(value: TensorCpu<T>) -> Self {
        value.to_vec()
    }
}

impl<T: Scalar, S: Into<ShapedIndex>> std::ops::Index<S> for TensorCpu<T> {
    type Output = T;

    fn index(&self, index: S) -> &Self::Output {
        &self.data[self.shape.linear_index(index)]
    }
}

impl<T: Scalar> TensorCpu<T> {
    pub fn map<U: Scalar>(self, f: impl FnMut(&T) -> U) -> TensorCpu<U> {
        let Self { shape, data, .. } = self;
        let data = data.iter().map(f).collect_vec();
        TensorInit::from_data(shape, &data).unwrap()
    }

    pub fn pad(self, multiples: impl Into<Shape>) -> Self {
        let multiples: Shape = multiples.into();

        let mut shape = self.shape;
        for (axis, multiple) in multiples.iter().enumerate() {
            shape[axis] = shape[axis].next_multiple_of(multiple);
        }

        let mut data = vec![T::zero(); shape.len()];
        for index in self.shape.cartesian_product() {
            let value = self[index];
            data[shape.linear_index(index)] = value;
        }
        TensorInit::from_data(shape, data).unwrap()
    }

    pub fn repeat(self, axis: usize, repeat: usize) -> Self {
        let mut shape = self.shape;
        let data = self.data;

        let num_chunk: usize = shape.iter().skip(axis + 1).product();
        let chunk_size = data.len() / num_chunk;

        shape[axis] *= repeat;

        let data = (0..num_chunk)
            .map(|chunk| {
                let start = chunk * chunk_size;
                let end = start + chunk_size;
                let chunk = data[start..end].to_vec();
                chunk.repeat(repeat)
            })
            .collect_vec()
            .concat()
            .into();

        Self {
            shape,
            data,
            id: uid::Id::new(),
            phantom: PhantomData,
        }
    }

    pub fn stack(batches: Vec<Self>, axis: usize) -> Result<Self, TensorError> {
        let mut shape = match batches.first() {
            Some(batch) => batch.shape,
            None => Err(TensorErrorKind::Empty)?,
        };

        batches.iter().try_for_each(|batch| match axis {
            0 => batch.check_shape([batch.shape[0], 1, 1, 1]),
            1 => batch.check_shape([shape[0], batch.shape[1], 1, 1]),
            2 => batch.check_shape([shape[0], shape[1], batch.shape[2], 1]),
            3 => batch.check_shape([shape[0], shape[1], shape[2], batch.shape[3]]),
            _ => unreachable!(),
        })?;

        let num_batch: usize = batches.iter().map(|batch| batch.shape[axis]).sum();
        shape[axis] = num_batch;

        let data = batches
            .into_iter()
            .map(|batch| batch.data.to_vec())
            .collect_vec()
            .concat()
            .into();

        Ok(Self {
            shape,
            data,
            id: uid::Id::new(),
            phantom: PhantomData,
        })
    }

    pub fn split(self, axis: usize) -> Result<Vec<Self>, TensorError> {
        if self.shape.iter().skip(axis + 1).any(|dim| dim > 1) {
            Err(TensorErrorKind::SplitInvalid(axis))?;
        }

        (0..self.shape[axis])
            .map(|index| match axis {
                0 => self.slice(index, .., .., ..),
                1 => self.slice(.., index, .., ..),
                2 => self.slice(.., .., index, ..),
                3 => self.slice(.., .., .., index),
                _ => Err(TensorErrorKind::SplitInvalid(axis))?,
            })
            .try_collect()
    }

    pub fn slice(
        &self,
        x: impl TensorAxis,
        y: impl TensorAxis,
        z: impl TensorAxis,
        w: impl TensorAxis,
    ) -> Result<Self, TensorError> {
        let slice = (x, y, z, w);
        let (start, end) = slice.shaped_bounds(self.shape)?;
        let shape = (end - start).into();

        let (start, end) = slice.linear_bounds(self.shape)?;
        let data = self.data[start..end].into();

        Ok(Self {
            shape,
            data,
            id: uid::Id::new(),
            phantom: PhantomData,
        })
    }

    pub fn into_slice(
        self,
        x: impl TensorAxis,
        y: impl TensorAxis,
        z: impl TensorAxis,
        w: impl TensorAxis,
    ) -> Result<Self, TensorError> {
        let slice = (x, y, z, w);
        let (start, end) = slice.shaped_bounds(self.shape)?;
        let shape = (end - start).into();

        let (start, end) = slice.linear_bounds(self.shape)?;
        let data = self.data[start..end].into();

        Ok(Self {
            shape,
            data,
            id: uid::Id::new(),
            phantom: PhantomData,
        })
    }
}

#[derive(Debug, Clone)]
pub struct TensorGpuView<'a, T: Scalar> {
    tensor: &'a TensorGpu<T, ReadWrite>,
    meta: Arc<Buffer>,
    view: View,
}

impl<T: Scalar> TensorShape for TensorGpuView<'_, T> {
    #[inline]
    fn shape(&self) -> Shape {
        self.view.shape
    }
}

impl<T: Scalar> TensorGpuView<'_, T> {
    #[inline]
    pub fn tensor(&self) -> &TensorGpu<T, ReadWrite> {
        self.tensor
    }

    #[inline]
    pub fn context(&self) -> &Context {
        self.tensor.context()
    }

    #[inline]
    pub fn data(&self) -> &TensorGpuData {
        &self.tensor.data
    }

    #[inline]
    pub fn meta_layout(&self, binding: u32) -> BindGroupLayoutEntry {
        self.tensor.meta_layout(binding)
    }

    #[inline]
    pub fn layout(&self, binding: u32, read_only: bool) -> BindGroupLayoutEntry {
        self.tensor.layout(binding, read_only)
    }
}

impl<T: Scalar> TensorResource for TensorGpuView<'_, T> {
    #[inline]
    fn resource_key(&self) -> ResourceKey {
        ResourceKey {
            id: self.tensor.id,
            view: self.view,
        }
    }

    #[inline]
    fn meta_binding(&self) -> BindingResource<'_> {
        BindingResource::Buffer(BufferBinding {
            buffer: &self.meta,
            offset: 0,
            size: None,
        })
    }

    #[inline]
    fn binding(&self) -> BindingResource<'_> {
        self.tensor.binding()
    }
}

impl<T: Scalar> TensorScalar for TensorGpuView<'_, T> {
    type T = T;
}

impl<F: Float> TensorGpuView<'_, F> {
    #[inline]
    pub const fn def(&self) -> &'static str {
        F::DEF
    }
}

impl<'a, T: Scalar> From<&'a TensorGpu<T, ReadWrite>> for TensorGpuView<'a, T> {
    fn from(value: &'a TensorGpu<T, ReadWrite>) -> Self {
        value.view(.., .., .., ..).unwrap()
    }
}

impl<T: Scalar> TensorGpu<T, ReadWrite> {
    pub fn view(
        &self,
        x: impl TensorAxis,
        y: impl TensorAxis,
        z: impl TensorAxis,
        w: impl TensorAxis,
    ) -> Result<TensorGpuView<'_, T>, TensorError> {
        let slice = (x, y, z, w);
        let (start, end) = slice.shaped_bounds(self.shape)?;
        let view = View {
            stride: self.shape,
            offset: start.into(),
            shape: (end - start).into(),
        };
        let meta = self.context.checkout_view_uniform(view);
        Ok(TensorGpuView {
            tensor: self,
            meta,
            view,
        })
    }
}

impl<T: Scalar> DeepClone for TensorGpu<T, ReadWrite> {
    fn deep_clone(&self) -> Self {
        let context = &self.context;
        let shape = self.shape;
        // `copy_buffer_to_buffer` takes a size in bytes, not a number of elements.
        let size = self.size() as u64;
        let cloned: TensorGpu<_, _> = context.tensor_init(shape);

        let mut encoder = context.device.create_command_encoder(&Default::default());
        encoder.copy_buffer_to_buffer(&self.buffer, 0, &cloned.buffer, 0, size);
        context.queue.submit(Some(encoder.finish()));

        cloned
    }
}

#[derive(Debug, Clone)]
pub struct TensorStack<T: Scalar> {
    pub tensor: TensorCpu<T>,
    pub cursors: Vec<Cursor>,
}

impl<T: Scalar> TensorStack<T> {
    #[inline]
    pub fn num_batch(&self) -> usize {
        self.cursors.len()
    }

    #[inline]
    pub fn num_active_batch(&self) -> usize {
        self.cursors.iter().filter(|cursor| cursor.len > 0).count()
    }

    #[inline]
    pub fn num_token(&self) -> usize {
        self.tensor.shape[1]
    }
}

impl<T: Scalar> TryFrom<Vec<TensorCpu<T>>> for TensorStack<T> {
    type Error = TensorError;

    fn try_from(value: Vec<TensorCpu<T>>) -> Result<Self, Self::Error> {
        let shape = match value.first() {
            Some(batch) => batch.shape,
            None => Err(TensorErrorKind::Empty)?,
        };

        value
            .iter()
            .try_for_each(|batch| batch.check_shape([shape[0], batch.shape[1], 1, 1]))?;

        let cursors = value
            .iter()
            .enumerate()
            .scan(0, |token, (batch, tensor)| {
                let len = tensor.shape[1];
                let cursor = Cursor {
                    batch,
                    token: *token,
                    len,
                };
                *token += len;
                Some(cursor)
            })
            .collect_vec();

        let (shape, data) = value.into_iter().fold(
            (Shape::new(shape[0], 0, 1, 1), vec![]),
            |(mut shape, mut data), tensor| {
                shape[1] += tensor.shape[1];
                data.extend(tensor.data.to_vec());
                (shape, data)
            },
        );
        let data = data.into();

        Ok(Self {
            tensor: Tensor {
                shape,
                data,
                id: uid::Id::new(),
                phantom: PhantomData,
            },
            cursors,
        })
    }
}

impl Context {
    #[inline]
    pub fn zeros<T: Scalar, Tensor>(&self, shape: impl Into<Shape>) -> Tensor
    where
        TensorCpu<T>: TensorInto<Tensor>,
    {
        let tensor: TensorCpu<T> = TensorInit::init(shape);
        tensor.to(self)
    }

    #[inline]
    pub fn ones<T: Scalar, Tensor>(&self, shape: impl Into<Shape>) -> Tensor
    where
        TensorCpu<T>: TensorInto<Tensor>,
    {
        let shape = shape.into();
        let data = vec![T::one(); shape.len()];
        let tensor: TensorCpu<T> = TensorInit::from_data(shape, data).unwrap();
        tensor.to(self)
    }

    #[inline]
    pub fn tensor_from_data<'a, T: Scalar, Tensor: TensorInitContext<T>>(
        &self,
        shape: impl Into<Shape>,
        data: impl Into<Cow<'a, [T]>>,
    ) -> Result<Tensor, TensorError> {
        TensorInitContext::from_data(self, shape, data)
    }

    #[inline]
    pub fn tensor_init<T: Scalar, Tensor: TensorInitContext<T>>(
        &self,
        shape: impl Into<Shape>,
    ) -> Tensor {
        TensorInitContext::init(self, shape)
    }
}

mod sealed {
    use super::{Cpu, Gpu, Kind, ReadWrite, Uniform};
    use crate::num::Scalar;

    pub trait Sealed {}

    impl<T: Scalar> Sealed for Cpu<T> {}
    impl<K: Kind> Sealed for Gpu<K> {}

    impl Sealed for Uniform {}
    impl Sealed for ReadWrite {}
}

#[cfg(test)]
mod tests {
    use anyhow::Result;

    use super::Shape;
    use crate::tensor::{TensorCpu, TensorInit, TensorShape};

    #[test]
    fn test_pad_64() -> Result<()> {
        let shape = Shape::new(133, 256, 1, 1);
        let x: Vec<_> = (0..shape.len()).map(|x| x as f32).collect();
        let x = TensorCpu::from_data(shape, x)?.pad([64, 64, 1, 1]);

        assert_eq!(x.shape(), Shape::new(192, 256, 1, 1));
        assert_eq!(x[(132, 255, 0, 0)], (shape.len() - 1) as f32);
        assert_eq!(x[(133, 255, 0, 0)], 0.0);

        Ok(())
    }
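
    // Added example (not part of the original test suite): a minimal sketch
    // checking that `Cursor::pack` lays the packed `u32` out as
    // [batch, token lo, token hi, len] in native byte order, matching the
    // `to_ne_bytes`/`bytemuck::cast` implementation above, and that the
    // `IntoPackedCursors` helpers drop zero-length cursors.
    #[test]
    fn test_cursor_pack() {
        use super::{Cursor, IntoPackedCursors};

        let cursor = Cursor {
            batch: 3,
            token: 517,
            len: 2,
        };
        let bytes = cursor.pack().to_ne_bytes();
        assert_eq!(bytes[0], 3);
        assert_eq!(u16::from_ne_bytes([bytes[1], bytes[2]]), 517);
        assert_eq!(bytes[3], 2);

        let cursors = vec![
            Cursor { batch: 0, token: 0, len: 2 },
            Cursor { batch: 1, token: 2, len: 0 },
            Cursor { batch: 2, token: 2, len: 3 },
        ];
        // `into_stack` keeps one entry per non-empty cursor; `into_cursors`
        // repeats each packed cursor `len` times.
        assert_eq!(cursors.clone().into_stack().len(), 2);
        assert_eq!(cursors.into_cursors().len(), 5);
    }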

    #[test]
    fn test_repeat() -> Result<()> {
        let shape = Shape::new(5, 1, 2, 1);
        let x: Vec<_> = (0..shape.len()).map(|x| x as f32).collect();
        let x = TensorCpu::from_data(shape, x)?;

        let y = x.clone().repeat(1, 3);
        let ans = [
            [0.0, 1.0, 2.0, 3.0, 4.0].repeat(3),
            [5.0, 6.0, 7.0, 8.0, 9.0].repeat(3),
        ]
        .concat();
        y.check_shape([5, 3, 2, 1])?;
        assert_eq!(y.to_vec(), ans);

        let y = x.clone().repeat(0, 3);
        y.check_shape([15, 1, 2, 1])?;
        assert_eq!(y.to_vec(), ans);

        let y = x.repeat(2, 3);
        let ans = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0].repeat(3);
        y.check_shape([5, 1, 6, 1])?;
        assert_eq!(y.to_vec(), ans);

        Ok(())
    }
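
    // Added example (not part of the original test suite): a minimal sketch of
    // `TensorCpu::stack`, assuming it concatenates batch data in order along
    // the chosen axis as written above, and rejects an empty batch list.
    #[test]
    fn test_stack() -> Result<()> {
        let a = TensorCpu::from_data([5, 1, 1, 1], (0..5).map(|x| x as f32).collect::<Vec<_>>())?;
        let b = TensorCpu::from_data([5, 1, 1, 1], (5..10).map(|x| x as f32).collect::<Vec<_>>())?;

        let x = TensorCpu::stack(vec![a, b], 2)?;
        x.check_shape([5, 1, 2, 1])?;
        assert_eq!(x.to_vec(), (0..10).map(|x| x as f32).collect::<Vec<_>>());

        // Stacking an empty list is rejected with `TensorErrorKind::Empty`.
        assert!(TensorCpu::<f32>::stack(vec![], 2).is_err());

        Ok(())
    }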

    #[test]
    fn test_split() -> Result<()> {
        let shape = Shape::new(5, 1, 2, 1);
        let x: Vec<_> = (0..10).map(|x| x as f32).collect();
        let x = TensorCpu::from_data(shape, x)?;

        assert!(x.clone().split(0).is_err());
        assert!(x.clone().split(1).is_err());

        let x = x.split(2)?;
        x[0].check_shape([5, 1, 1, 1])?;
        x[1].check_shape([5, 1, 1, 1])?;
        assert_eq!(x[0].to_vec(), vec![0.0, 1.0, 2.0, 3.0, 4.0]);
        assert_eq!(x[1].to_vec(), vec![5.0, 6.0, 7.0, 8.0, 9.0]);

        Ok(())
    }
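
    // Added examples (not part of the original test suite): minimal sketches
    // of `map` and `slice` on the CPU tensor, and of building a `TensorStack`
    // from per-batch tensors; the expected values follow the implementations
    // above and are assumptions rather than documented behavior.
    #[test]
    fn test_map_slice() -> Result<()> {
        let shape = Shape::new(5, 1, 2, 1);
        let x: Vec<_> = (0..shape.len()).map(|x| x as f32).collect();
        let x = TensorCpu::from_data(shape, x)?;

        // `map` converts element-wise and keeps the shape.
        let y = x.clone().map(|x| x * 2.0);
        y.check_shape([5, 1, 2, 1])?;
        assert_eq!(y[(2, 0, 1, 0)], 14.0);

        // Slicing a single index along axis 2 yields a contiguous sub-tensor.
        let y = x.slice(.., .., 1, ..)?;
        y.check_shape([5, 1, 1, 1])?;
        assert_eq!(y.to_vec(), vec![5.0, 6.0, 7.0, 8.0, 9.0]);

        Ok(())
    }

    #[test]
    fn test_tensor_stack_cursors() -> Result<()> {
        use super::TensorStack;

        let a = TensorCpu::from_data([2, 3, 1, 1], vec![0.0f32; 6])?;
        let b = TensorCpu::from_data([2, 0, 1, 1], vec![0.0f32; 0])?;
        let c = TensorCpu::from_data([2, 2, 1, 1], vec![0.0f32; 4])?;

        // Tokens are concatenated along axis 1; one cursor is recorded per
        // batch, and empty batches stay in the list but count as inactive.
        let stack = TensorStack::try_from(vec![a, b, c])?;
        assert_eq!(stack.num_batch(), 3);
        assert_eq!(stack.num_active_batch(), 2);
        assert_eq!(stack.num_token(), 5);
        assert_eq!(stack.cursors[2].token, 3);

        Ok(())
    }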
}