// web_rwkv/tensor/mod.rs

use std::{borrow::Cow, marker::PhantomData, sync::Arc};

use itertools::Itertools;
use shape::ShapedIndex;
use thiserror::Error;
use wgpu::{
    BindGroupLayoutEntry, BindingResource, BindingType, Buffer, BufferBinding, BufferBindingType,
    BufferUsages, ShaderStages,
};

use self::{
    kind::{Kind, ReadWrite, Uniform},
    shape::{IntoBytes, Shape, TensorAxis, TensorDimension, TensorSlice},
};
use crate::{
    context::Context,
    num::{Float, Scalar},
};

pub mod cache;
pub mod matrix;
pub mod ops;
pub mod serialization;
pub mod shape;

/// Data defining a tensor view in a shader.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct View {
    pub shape: Shape,
    pub stride: Shape,
    pub offset: Shape,
}

impl IntoBytes for View {
    fn into_bytes(self) -> Vec<u8> {
        [
            self.shape.into_bytes(),
            self.stride.into_bytes(),
            self.offset.into_bytes(),
        ]
        .concat()
    }
}

/// A record used to separate different batches of input of various lengths.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Cursor {
    pub batch: usize,
    pub token: usize,
    pub len: usize,
}

impl Cursor {
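    /// Pack the cursor into a single `u32`: one byte of batch index, two
    /// native-endian bytes of token offset, and one byte of length.
    ///
    /// Note: out-of-range values are truncated by the `as` casts below.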
    pub fn pack(self) -> u32 {
        let batch = self.batch as u8;
        let token = (self.token as u16).to_ne_bytes();
        let len = self.len as u8;
        bytemuck::cast([batch, token[0], token[1], len])
    }
}
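/// Convert a list of cursors into the packed `u32` forms consumed by shaders.
///
/// A minimal sketch of the two layouts (cursor values are illustrative):
/// ```
/// use web_rwkv::tensor::{Cursor, IntoPackedCursors};
///
/// let cursors = vec![
///     Cursor { batch: 0, token: 0, len: 2 },
///     Cursor { batch: 1, token: 2, len: 0 }, // empty batches are skipped
///     Cursor { batch: 2, token: 2, len: 1 },
/// ];
/// // `into_stack`: one entry per non-empty batch.
/// assert_eq!(cursors.clone().into_stack().len(), 2);
/// // `into_cursors`: one entry per token.
/// assert_eq!(cursors.into_cursors().len(), 3);
/// ```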
pub trait IntoPackedCursors {
    fn into_stack(self) -> Vec<u32>;
    fn into_cursors(self) -> Vec<u32>;
}

impl IntoPackedCursors for Vec<Cursor> {
    fn into_stack(self) -> Vec<u32> {
        self.into_iter()
            .filter(|cursor| cursor.len > 0)
            .map(Cursor::pack)
            .collect()
    }

    fn into_cursors(self) -> Vec<u32> {
        self.into_iter()
            .filter(|cursor| cursor.len > 0)
            .map(|cursor| {
                let repeat = cursor.len;
                vec![cursor.pack(); repeat]
            })
            .collect_vec()
            .concat()
    }
}

#[derive(Debug)]
pub struct TensorError {
    pub error: TensorErrorKind,
    #[cfg(feature = "backtrace")]
    pub backtrace: std::backtrace::Backtrace,
}

impl TensorError {
    #[cfg(feature = "backtrace")]
    pub fn new(error: TensorErrorKind) -> Self {
        let backtrace = std::backtrace::Backtrace::capture();
        Self { error, backtrace }
    }

    #[cfg(not(feature = "backtrace"))]
    pub fn new(error: TensorErrorKind) -> Self {
        Self { error }
    }
}

impl From<TensorErrorKind> for TensorError {
    fn from(value: TensorErrorKind) -> Self {
        Self::new(value)
    }
}

impl std::fmt::Display for TensorError {
    #[cfg(feature = "backtrace")]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "{}\n\nBacktrace:\n{}", self.error, self.backtrace)
    }

    #[cfg(not(feature = "backtrace"))]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "{}", self.error)
    }
}

impl std::error::Error for TensorError {}

#[derive(Debug, Error)]
pub enum TensorErrorKind {
    #[error("list must not be empty")]
    Empty,
    #[error("data type mismatch")]
    Type,
    #[error("data sizes do not match: {0} vs. {1}")]
    Size(usize, usize),
    #[error("batch sizes do not match: {0} vs. {1}")]
    Batch(usize, usize),
    #[error("tensor shapes do not match: {0} vs. {1}")]
    Shape(Shape, Shape),
    #[error("cannot deduce dimension")]
    Deduce,
    #[error("batch {batch} out of range (max {max})")]
    BatchOutOfRange { batch: usize, max: usize },
    #[error("slice {start}..{end} out of range for dimension of size {dim}")]
    SliceOutOfRange {
        dim: usize,
        start: usize,
        end: usize,
    },
    #[error("slice not contiguous")]
    SliceInvalid,
    #[error("cannot split along axis {0}")]
    SplitInvalid(usize),
    #[error("possible tensor error(s):\n{0}")]
    Any(#[from] AnyTensorError),
}

#[derive(Debug, Error)]
pub struct AnyTensorError(pub Vec<Box<dyn std::error::Error + Send + Sync>>);

impl std::fmt::Display for AnyTensorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.0
            .iter()
            .enumerate()
            .try_for_each(|(index, error)| writeln!(f, "{index}. {error}"))
    }
}

pub trait DeepClone: Sized {
    fn deep_clone(&self) -> Self;
}

pub trait TensorScalar {
    type T: Scalar;
}

pub trait TensorInitContext<T: Scalar>: Sized {
    /// Initialize the tensor with the given shape and contents.
    fn from_data<'a, S, D>(context: &Context, shape: S, data: D) -> Result<Self, TensorError>
    where
        S: Into<Shape>,
        D: Into<Cow<'a, [T]>>;
    /// Initialize the tensor with the given shape.
    fn init(context: &Context, shape: impl Into<Shape>) -> Self;
}
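/// A minimal sketch of building a CPU tensor from host data (shapes are
/// `[x, y, z, w]`; their product must equal the data length):
/// ```
/// use web_rwkv::tensor::{shape::Shape, TensorCpu, TensorInit, TensorShape};
///
/// let data: Vec<f32> = (0..12).map(|x| x as f32).collect();
/// let tensor = TensorCpu::from_data([3, 4, 1, 1], data).unwrap();
/// assert_eq!(tensor.shape(), Shape::new(3, 4, 1, 1));
///
/// // Mismatched sizes are rejected.
/// assert!(TensorCpu::from_data([3, 3, 1, 1], vec![0.0f32; 12]).is_err());
/// ```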
pub trait TensorInit<T: Scalar>: Sized {
    /// Initialize the tensor with the given shape and contents.
    fn from_data<'a, S, D>(shape: S, data: D) -> Result<Self, TensorError>
    where
        S: Into<Shape>,
        D: Into<Cow<'a, [T]>>;
    /// Initialize the tensor with the given shape.
    fn init(shape: impl Into<Shape>) -> Self;

    /// Initialize a 1-D tensor from data.
    fn from_data_1d<'a>(data: impl Into<Cow<'a, [T]>>) -> Self {
        let data: Cow<'_, [T]> = data.into();
        let shape = [data.len(), 1, 1, 1];
        Self::from_data(shape, data).expect("tensor 1d from data")
    }
}

pub trait TensorInto<Output> {
    fn to(self, context: &Context) -> Output;
}

pub trait TensorShape: Sized {
    /// Get the shape of the tensor.
    fn shape(&self) -> Shape;

    /// Check if the tensor's shape is the same as the expected one.
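    ///
    /// A quick sanity check on a CPU tensor:
    /// ```
    /// use web_rwkv::tensor::{TensorCpu, TensorInit, TensorShape};
    ///
    /// let tensor: TensorCpu<f32> = TensorInit::init([2, 3, 1, 1]);
    /// assert!(tensor.check_shape([2, 3, 1, 1]).is_ok());
    /// assert!(tensor.check_shape([3, 2, 1, 1]).is_err());
    /// ```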
    fn check_shape(&self, shape: impl Into<Shape>) -> Result<(), TensorError> {
        let shape = shape.into();
        (self.shape() == shape)
            .then_some(())
            .ok_or(TensorErrorKind::Shape(self.shape(), shape))
            .map_err(Into::into)
    }

    /// Check if the tensor's shape matches any of the expected ones.
    fn check_shape_any<S>(&self, shapes: &[S]) -> Result<(), TensorError>
    where
        S: Into<Shape> + ToOwned<Owned = S>,
    {
        let (oks, errors): (Vec<_>, Vec<_>) = shapes
            .iter()
            .map(|shape| self.check_shape(shape.to_owned()).map_err(Into::into))
            .partition_result();
        match oks.is_empty() {
            true => Err(TensorErrorKind::Any(AnyTensorError(errors)))?,
            false => Ok(()),
        }
    }
}

pub trait TensorReshape: Sized {
    fn reshape(
        &self,
        x: TensorDimension,
        y: TensorDimension,
        z: TensorDimension,
        w: TensorDimension,
    ) -> Result<Self, TensorError>;
}

pub trait TensorResource {
    /// Retrieve the key identifying a resource.
    fn resource_key(&self) -> ResourceKey;
    /// Binding for metadata of the tensor (shape, stride, etc.).
    fn meta_binding(&self) -> BindingResource<'_>;
    /// Binding for actual data of the tensor.
    fn binding(&self) -> BindingResource<'_>;
}

/// A tensor on either CPU or GPU.
#[derive(Debug)]
pub struct Tensor<D: Device, T: Scalar> {
    shape: Shape,
    data: D::Data,
    id: uid::Id<TensorId>,
    phantom: PhantomData<T>,
}

#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TensorId;

/// A unique identifier of tensor views. Useful in caches.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ResourceKey {
    pub id: uid::Id<TensorId>,
    pub view: View,
}

pub trait Device: sealed::Sealed {
    type Data: Clone;
}

#[derive(Debug)]
pub struct Cpu<T: Scalar>(PhantomData<T>);

#[derive(Debug)]
pub struct Gpu<K: Kind>(PhantomData<K>);

impl<T: Scalar> Device for Cpu<T> {
    type Data = Arc<[T]>;
}

impl<K: Kind> Device for Gpu<K> {
    type Data = TensorGpuData;
}

/// Buffer of the tensor on GPU.
#[derive(Debug, Clone)]
pub struct TensorGpuData {
    pub context: Context,
    pub meta: Arc<Buffer>,
    pub buffer: Arc<Buffer>,
}

pub mod kind {
    use web_rwkv_derive::Kind;
    use wgpu::BufferUsages;

    use super::sealed;

    pub trait Kind: sealed::Sealed {
        fn buffer_usages() -> BufferUsages;
    }

    /// Tensor is a uniform buffer.
    #[derive(Debug, Kind)]
    #[usage(UNIFORM, COPY_DST, COPY_SRC)]
    pub struct Uniform;

    /// Tensor is a storage buffer that can be copied to other buffers.
    #[derive(Debug, Kind)]
    #[usage(STORAGE, COPY_DST, COPY_SRC)]
    pub struct ReadWrite;
}

pub type TensorCpu<T> = Tensor<Cpu<T>, T>;
pub type TensorGpu<T, K> = Tensor<Gpu<K>, T>;

impl<D: Device, T: Scalar> Clone for Tensor<D, T> {
    fn clone(&self) -> Self {
        Self {
            shape: self.shape,
            data: self.data.clone(),
            id: self.id,
            phantom: PhantomData,
        }
    }
}

impl<D: Device, T: Scalar> std::ops::Deref for Tensor<D, T> {
    type Target = D::Data;

    #[inline]
    fn deref(&self) -> &Self::Target {
        &self.data
    }
}

impl<D: Device, T: Scalar> TensorScalar for Tensor<D, T> {
    type T = T;
}

impl<D: Device, T: Scalar> Tensor<D, T> {
    #[inline]
    pub fn len(&self) -> usize {
        self.shape.len()
    }

    #[inline]
    pub fn is_empty(&self) -> bool {
        self.shape.is_empty()
    }

    /// Size of the tensor in bytes.
    #[inline]
    pub fn size(&self) -> usize {
        self.len() * T::size()
    }

    /// The offset in bytes for a linear index.
    #[inline]
    pub fn offset(index: usize) -> usize {
        index * T::size()
    }

    #[inline]
    pub fn data(&self) -> &D::Data {
        &self.data
    }

    #[inline]
    pub fn id(&self) -> uid::Id<TensorId> {
        self.id
    }
}

impl<D: Device, F: Float> Tensor<D, F> {
    #[inline]
    pub const fn def(&self) -> &'static str {
        F::DEF
    }
}

impl<T: Scalar> TensorInit<T> for TensorCpu<T> {
    fn from_data<'a, S, D>(shape: S, data: D) -> Result<Self, TensorError>
    where
        S: Into<Shape>,
        D: Into<Cow<'a, [T]>>,
    {
        let shape = shape.into();
        let data: Cow<'_, _> = data.into();
        if shape.len() != data.len() {
            Err(TensorErrorKind::Size(shape.len(), data.len()))?;
        }
        let data = data.into_owned().into();
        Ok(Self {
            shape,
            data,
            id: uid::Id::new(),
            phantom: PhantomData,
        })
    }

    #[inline]
    fn init(shape: impl Into<Shape>) -> Self {
        let shape = shape.into();
        let data = vec![T::zero(); shape.len()].into();
        Self {
            shape,
            data,
            id: uid::Id::new(),
            phantom: PhantomData,
        }
    }
}

impl<T: Scalar> TensorInitContext<T> for TensorCpu<T> {
    fn from_data<'a, S, D>(_context: &Context, shape: S, data: D) -> Result<Self, TensorError>
    where
        S: Into<Shape>,
        D: Into<Cow<'a, [T]>>,
    {
        TensorInit::from_data(shape, data)
    }

    fn init(_context: &Context, shape: impl Into<Shape>) -> Self {
        TensorInit::init(shape)
    }
}

impl<T: Scalar> TensorInto<TensorCpu<T>> for TensorCpu<T> {
    fn to(self, _: &Context) -> Self {
        self
    }
}

impl<T: Scalar> DeepClone for TensorCpu<T> {
    fn deep_clone(&self) -> Self {
        self.clone()
    }
}

impl<T: Scalar> TensorShape for TensorCpu<T> {
    #[inline]
    fn shape(&self) -> Shape {
        self.shape
    }
}

impl<T: Scalar> TensorReshape for TensorCpu<T> {
    #[inline]
    fn reshape(
        &self,
        x: TensorDimension,
        y: TensorDimension,
        z: TensorDimension,
        w: TensorDimension,
    ) -> Result<Self, TensorError> {
        let shape = TensorDimension::deduce(self.shape, x, y, z, w)?;
        Ok(Self {
            shape,
            ..self.clone()
        })
    }
}

impl<T: Scalar, K: Kind> TensorInitContext<T> for TensorGpu<T, K> {
    fn from_data<'a, S, D>(context: &Context, shape: S, data: D) -> Result<Self, TensorError>
    where
        S: Into<Shape>,
        D: Into<Cow<'a, [T]>>,
    {
        let tensor: TensorCpu<T> = TensorInit::from_data(shape, data)?;
        Ok(tensor.to(context))
    }

    fn init(context: &Context, shape: impl Into<Shape>) -> Self {
        let context = context.clone();
        let shape = shape.into();
        let meta = context.checkout_shape_uniform(shape);
        let size = shape.len() * std::mem::size_of::<T>();
        let buffer = context.checkout_buffer(size, K::buffer_usages());
        Self {
            shape,
            data: TensorGpuData {
                context,
                meta,
                buffer,
            },
            id: uid::Id::new(),
            phantom: PhantomData,
        }
    }
}

impl<T: Scalar, K: Kind> TensorInto<TensorGpu<T, K>> for TensorCpu<T> {
    fn to(self, context: &Context) -> TensorGpu<T, K> {
        let Tensor { shape, data, .. } = self;
        let context = context.clone();
        let meta = context.checkout_shape_uniform(shape);
        let contents = bytemuck::cast_slice(&data);
        let buffer = context.checkout_buffer_init(contents, K::buffer_usages());
        TensorGpu {
            shape,
            data: TensorGpuData {
                context,
                meta,
                buffer,
            },
            id: uid::Id::new(),
            phantom: PhantomData,
        }
    }
}

#[cfg(not(target_arch = "wasm32"))]
impl<T: Scalar> TensorInto<TensorGpu<T, ReadWrite>> for TensorGpu<T, ReadWrite> {
    fn to(self, context: &Context) -> Self {
        match context {
            context if context == &self.context => self,
            _ => self.back_in_place().to(context),
        }
    }
}

#[cfg(target_arch = "wasm32")]
impl<T: Scalar> TensorInto<TensorGpu<T, ReadWrite>> for TensorGpu<T, ReadWrite> {
    fn to(self, _: &Context) -> Self {
        self
    }
}

impl<T: Scalar, K: Kind> TensorShape for TensorGpu<T, K> {
    #[inline]
    fn shape(&self) -> Shape {
        self.shape
    }
}

impl<T: Scalar, K: Kind> TensorReshape for TensorGpu<T, K> {
    #[inline]
    fn reshape(
        &self,
        x: TensorDimension,
        y: TensorDimension,
        z: TensorDimension,
        w: TensorDimension,
    ) -> Result<Self, TensorError> {
        let shape = TensorDimension::deduce(self.shape, x, y, z, w)?;
        let context = self.context.clone();
        let meta = context.checkout_shape_uniform(shape);
        let buffer = self.buffer.clone();
        Ok(Self {
            shape,
            data: TensorGpuData {
                context,
                meta,
                buffer,
            },
            ..self.clone()
        })
    }
}

impl<T: Scalar, K: Kind> TensorResource for TensorGpu<T, K> {
    #[inline]
    fn resource_key(&self) -> ResourceKey {
        let id = self.id;
        let view = View {
            shape: self.shape,
            stride: self.shape,
            offset: [0, 0, 0, 0].into(),
        };
        ResourceKey { id, view }
    }

    #[inline]
    fn meta_binding(&self) -> BindingResource<'_> {
        BindingResource::Buffer(BufferBinding {
            buffer: &self.meta,
            offset: 0,
            size: None,
        })
    }

    #[inline]
    fn binding(&self) -> BindingResource<'_> {
        BindingResource::Buffer(BufferBinding {
            buffer: &self.buffer,
            offset: 0,
            size: None,
        })
    }
}

impl<T: Scalar, K: Kind> TensorGpu<T, K> {
    pub fn from_data_u8(
        context: &Context,
        shape: impl Into<Shape>,
        contents: &[u8],
    ) -> Result<Self, TensorError> {
        let shape = shape.into();
        let size = shape.len() * size_of::<T>();
        if contents.len() != size {
            Err(TensorErrorKind::Size(size, contents.len()))?;
        }
        let buffer = context.checkout_buffer_init(contents, K::buffer_usages());
        let meta = context.checkout_shape_uniform(shape);
        Ok(Self {
            shape,
            data: TensorGpuData {
                context: context.clone(),
                meta,
                buffer,
            },
            id: uid::Id::new(),
            phantom: PhantomData,
        })
    }

    #[cfg(not(target_arch = "wasm32"))]
    pub fn back_in_place(&self) -> TensorCpu<T> {
        use crate::context::ContextEvent;

        if self.is_empty() {
            return TensorCpu {
                shape: self.shape,
                data: Arc::new([]),
                id: uid::Id::new(),
                phantom: PhantomData,
            };
        }

        let context = &self.context;
        let size = self.buffer.size();
        let buffer = context.checkout_buffer(
            size as usize,
            BufferUsages::MAP_READ | BufferUsages::COPY_DST,
        );

        let mut encoder = context.device.create_command_encoder(&Default::default());
        encoder.copy_buffer_to_buffer(&self.buffer, 0, &buffer, 0, size);
        context.queue.submit(Some(encoder.finish()));

        let (sender, receiver) = flume::bounded(1);
        let _ = context.event().send(ContextEvent { buffer, sender });
        let data = receiver.recv().expect("failed to receive read back buffer");
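        // SAFETY: reinterpret the received `Box<[u8]>` as `Box<[T]>` in place
        // to avoid a copy. This assumes the readback buffer's length is a
        // multiple of `T::size()` and its allocation is suitably aligned for `T`.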
        let data = unsafe {
            let data = Box::leak(data);
            let slice = bytemuck::cast_slice_mut::<_, T>(data);
            Box::from_raw(slice)
        };
        let data = data.into_vec().into();
        let shape = self.shape;

        TensorCpu {
            shape,
            data,
            id: uid::Id::new(),
            phantom: PhantomData,
        }
    }

    #[cfg(not(target_arch = "wasm32"))]
    pub async fn back(&self) -> TensorCpu<T> {
        if self.is_empty() {
            return TensorCpu {
                shape: self.shape,
                data: Arc::new([]),
                id: uid::Id::new(),
                phantom: PhantomData,
            };
        }

        let context = &self.context;
        let size = self.buffer.size();
        let buffer = context.checkout_buffer(
            size as usize,
            BufferUsages::MAP_READ | BufferUsages::COPY_DST,
        );

        let mut encoder = context.device.create_command_encoder(&Default::default());
        encoder.copy_buffer_to_buffer(&self.buffer, 0, &buffer, 0, size);
        context.queue.submit(Some(encoder.finish()));

        let (sender, receiver) = flume::bounded(1);

        let _ = context
            .event()
            .send(crate::context::ContextEvent { buffer, sender });
        let data = receiver
            .recv_async()
            .await
            .expect("failed to receive read back buffer");
        let data = unsafe {
            let data = Box::leak(data);
            let slice = bytemuck::cast_slice_mut::<_, T>(data);
            Box::from_raw(slice)
        };
        let data = data.into_vec().into();

        TensorCpu {
            shape: self.shape,
            data,
            id: uid::Id::new(),
            phantom: PhantomData,
        }
    }

    #[cfg(target_arch = "wasm32")]
    pub async fn back(self) -> TensorCpu<T> {
        if self.is_empty() {
            return TensorCpu {
                shape: self.shape,
                data: Arc::new([]),
                id: uid::Id::new(),
                phantom: PhantomData,
            };
        }

        let context = &self.context;
        let size = self.buffer.size();
        let buffer = context.checkout_buffer(
            size as usize,
            BufferUsages::MAP_READ | BufferUsages::COPY_DST,
        );

        let mut encoder = context.device.create_command_encoder(&Default::default());
        encoder.copy_buffer_to_buffer(&self.buffer, 0, &buffer, 0, size);
        let submission_index = Some(context.queue.submit(Some(encoder.finish())));

        let (sender, receiver) = flume::unbounded();

        let slice = buffer.slice(..);
        slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());

        _ = context.device.poll(wgpu::PollType::Wait {
            submission_index,
            timeout: None,
        });
        receiver
            .recv_async()
            .await
            .expect("failed to receive read back buffer")
            .expect("failed to map buffer");

        let data = {
            let map = slice.get_mapped_range();
            Vec::from(bytemuck::cast_slice(&map)).into()
        };
        buffer.unmap();

        TensorCpu {
            shape: self.shape,
            data,
            id: uid::Id::new(),
            phantom: PhantomData,
        }
    }
}

impl<T: Scalar, K: Kind> TensorGpu<T, K> {
    #[inline]
    pub fn context(&self) -> &Context {
        &self.context
    }

    pub fn load(&self, host: &TensorCpu<T>) -> Result<(), TensorError> {
        host.check_shape(self.shape)?;
        self.context
            .queue
            .write_buffer(&self.buffer, 0, bytemuck::cast_slice(&host.data[..]));
        Ok(())
    }

    pub fn load_batch(&self, host: &TensorCpu<T>, batch: usize) -> Result<(), TensorError> {
        host.check_shape([self.shape[0], self.shape[1], 1, 1])?;
        if batch >= self.shape[2] {
            Err(TensorErrorKind::BatchOutOfRange {
                batch,
                max: self.shape[2],
            })?;
        }
        let offset = (T::size() * self.shape[0] * self.shape[1] * batch) as u64;
        self.context
            .queue
            .write_buffer(&self.buffer, offset, bytemuck::cast_slice(&host.data[..]));
        Ok(())
    }

    pub fn destroy(self) {
        self.buffer.destroy();
    }
}

impl<T: Scalar> TensorGpu<T, Uniform> {
    #[inline]
    pub fn layout(&self, binding: u32) -> BindGroupLayoutEntry {
        BindGroupLayoutEntry {
            binding,
            visibility: ShaderStages::COMPUTE,
            ty: BindingType::Buffer {
                ty: BufferBindingType::Uniform,
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        }
    }
}

impl<T: Scalar> TensorGpu<T, ReadWrite> {
    #[inline]
    pub fn meta_layout(&self, binding: u32) -> BindGroupLayoutEntry {
        BindGroupLayoutEntry {
            binding,
            visibility: ShaderStages::COMPUTE,
            ty: BindingType::Buffer {
                ty: BufferBindingType::Uniform,
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        }
    }

    #[inline]
    pub fn layout(&self, binding: u32, read_only: bool) -> BindGroupLayoutEntry {
        BindGroupLayoutEntry {
            binding,
            visibility: ShaderStages::COMPUTE,
            ty: BindingType::Buffer {
                ty: BufferBindingType::Storage { read_only },
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        }
    }
}

impl<T: Scalar> From<TensorCpu<T>> for Vec<T> {
    #[inline]
    fn from(value: TensorCpu<T>) -> Self {
        value.to_vec()
    }
}

impl<T: Scalar, S: Into<ShapedIndex>> std::ops::Index<S> for TensorCpu<T> {
    type Output = T;

    fn index(&self, index: S) -> &Self::Output {
        &self.data[self.shape.linear_index(index)]
    }
}

impl<T: Scalar> TensorCpu<T> {
    /// Apply a map `f` to every element in the tensor.
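    ///
    /// A minimal sketch:
    /// ```
    /// use web_rwkv::tensor::{TensorCpu, TensorInit};
    ///
    /// let x = TensorCpu::from_data([2, 1, 1, 1], vec![1.0f32, 2.0]).unwrap();
    /// let y = x.map(|v| v * 2.0);
    /// assert_eq!(y.to_vec(), vec![2.0, 4.0]);
    /// ```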
    pub fn map<U: Scalar>(self, f: impl FnMut(&T) -> U) -> TensorCpu<U> {
        let Self { shape, data, .. } = self;
        let data = data.iter().map(f).collect_vec();
        TensorInit::from_data(shape, &data).unwrap()
    }

    /// Zero-pad each dimension up to the next multiple of the given value.
    pub fn pad(self, multiples: impl Into<Shape>) -> Self {
        let multiples: Shape = multiples.into();

        let mut shape = self.shape;
        for (axis, multiple) in multiples.iter().enumerate() {
            shape[axis] = shape[axis].next_multiple_of(multiple);
        }

        let mut data = vec![T::zero(); shape.len()];
        for index in self.shape.cartesian_product() {
            let value = self[index];
            data[shape.linear_index(index)] = value;
        }
        TensorInit::from_data(shape, data).unwrap()
    }

    /// Repeat the tensor along a given axis.
    pub fn repeat(self, axis: usize, repeat: usize) -> Self {
        let mut shape = self.shape;
        let data = self.data;

        let num_chunk: usize = shape.iter().skip(axis + 1).product();
        let chunk_size = data.len() / num_chunk;

        shape[axis] *= repeat;

        let data = (0..num_chunk)
            .map(|chunk| {
                let start = chunk * chunk_size;
                let end = start + chunk_size;
                let chunk = data[start..end].to_vec();
                chunk.repeat(repeat)
            })
            .collect_vec()
            .concat()
            .into();

        Self {
            shape,
            data,
            id: uid::Id::new(),
            phantom: PhantomData,
        }
    }

    /// Concatenate a batch of tensors along a given axis.
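    ///
    /// A minimal sketch, stacking along the token axis (axis 1):
    /// ```
    /// use web_rwkv::tensor::{TensorCpu, TensorInit, TensorShape};
    ///
    /// let a = TensorCpu::from_data([2, 1, 1, 1], vec![0.0f32, 1.0]).unwrap();
    /// let b = TensorCpu::from_data([2, 2, 1, 1], vec![2.0f32, 3.0, 4.0, 5.0]).unwrap();
    /// let stacked = TensorCpu::stack(vec![a, b], 1).unwrap();
    /// stacked.check_shape([2, 3, 1, 1]).unwrap();
    /// ```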
    pub fn stack(batches: Vec<Self>, axis: usize) -> Result<Self, TensorError> {
        let mut shape = match batches.first() {
            Some(batch) => batch.shape,
            None => Err(TensorErrorKind::Empty)?,
        };

        batches.iter().try_for_each(|batch| match axis {
            0 => batch.check_shape([batch.shape[0], 1, 1, 1]),
            1 => batch.check_shape([shape[0], batch.shape[1], 1, 1]),
            2 => batch.check_shape([shape[0], shape[1], batch.shape[2], 1]),
            3 => batch.check_shape([shape[0], shape[1], shape[2], batch.shape[3]]),
            _ => unreachable!(),
        })?;

        let total: usize = batches.iter().map(|batch| batch.shape[axis]).sum();
        shape[axis] = total;

        let data = batches
            .into_iter()
            .map(|batch| batch.data.to_vec())
            .collect_vec()
            .concat()
            .into();

        Ok(Self {
            shape,
            data,
            id: uid::Id::new(),
            phantom: PhantomData,
        })
    }

    /// Split the tensor along a given axis.
    pub fn split(self, axis: usize) -> Result<Vec<Self>, TensorError> {
        if self.shape.iter().skip(axis + 1).any(|dim| dim > 1) {
            Err(TensorErrorKind::SplitInvalid(axis))?;
        }

        (0..self.shape[axis])
            .map(|index| match axis {
                0 => self.slice(index, .., .., ..),
                1 => self.slice(.., index, .., ..),
                2 => self.slice(.., .., index, ..),
                3 => self.slice(.., .., .., index),
                _ => Err(TensorErrorKind::SplitInvalid(axis))?,
            })
            .try_collect()
    }
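    /// Extract a contiguous sub-tensor selected by one [`TensorAxis`] per dimension.
    ///
    /// A minimal sketch, taking the second `z` slice:
    /// ```
    /// use web_rwkv::tensor::{TensorCpu, TensorInit, TensorShape};
    ///
    /// let x: Vec<f32> = (0..10).map(|x| x as f32).collect();
    /// let x = TensorCpu::from_data([5, 1, 2, 1], x).unwrap();
    /// let y = x.slice(.., .., 1, ..).unwrap();
    /// y.check_shape([5, 1, 1, 1]).unwrap();
    /// assert_eq!(y.to_vec(), vec![5.0, 6.0, 7.0, 8.0, 9.0]);
    /// ```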
    pub fn slice(
        &self,
        x: impl TensorAxis,
        y: impl TensorAxis,
        z: impl TensorAxis,
        w: impl TensorAxis,
    ) -> Result<Self, TensorError> {
        let slice = (x, y, z, w);
        let (start, end) = slice.shaped_bounds(self.shape)?;
        let shape = (end - start).into();

        let (start, end) = slice.linear_bounds(self.shape)?;
        let data = self.data[start..end].into();

        Ok(Self {
            shape,
            data,
            id: uid::Id::new(),
            phantom: PhantomData,
        })
    }

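    /// Same as [`Self::slice`], but consumes the tensor.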
    pub fn into_slice(
        self,
        x: impl TensorAxis,
        y: impl TensorAxis,
        z: impl TensorAxis,
        w: impl TensorAxis,
    ) -> Result<Self, TensorError> {
        let slice = (x, y, z, w);
        let (start, end) = slice.shaped_bounds(self.shape)?;
        let shape = (end - start).into();

        let (start, end) = slice.linear_bounds(self.shape)?;
        let data = self.data[start..end].into();

        Ok(Self {
            shape,
            data,
            id: uid::Id::new(),
            phantom: PhantomData,
        })
    }
}

/// Like a reference to a tensor, but referring to a sub-chunk of it.
#[derive(Debug, Clone)]
pub struct TensorGpuView<'a, T: Scalar> {
    tensor: &'a TensorGpu<T, ReadWrite>,
    meta: Arc<Buffer>,
    view: View,
}

impl<T: Scalar> TensorShape for TensorGpuView<'_, T> {
    #[inline]
    fn shape(&self) -> Shape {
        self.view.shape
    }
}

impl<T: Scalar> TensorGpuView<'_, T> {
    #[inline]
    pub fn tensor(&self) -> &TensorGpu<T, ReadWrite> {
        self.tensor
    }

    #[inline]
    pub fn context(&self) -> &Context {
        self.tensor.context()
    }

    #[inline]
    pub fn data(&self) -> &TensorGpuData {
        &self.tensor.data
    }

    #[inline]
    pub fn meta_layout(&self, binding: u32) -> BindGroupLayoutEntry {
        self.tensor.meta_layout(binding)
    }

    #[inline]
    pub fn layout(&self, binding: u32, read_only: bool) -> BindGroupLayoutEntry {
        self.tensor.layout(binding, read_only)
    }
}

impl<T: Scalar> TensorResource for TensorGpuView<'_, T> {
    #[inline]
    fn resource_key(&self) -> ResourceKey {
        ResourceKey {
            id: self.tensor.id,
            view: self.view,
        }
    }

    #[inline]
    fn meta_binding(&self) -> BindingResource<'_> {
        BindingResource::Buffer(BufferBinding {
            buffer: &self.meta,
            offset: 0,
            size: None,
        })
    }

    #[inline]
    fn binding(&self) -> BindingResource<'_> {
        self.tensor.binding()
    }
}

impl<T: Scalar> TensorScalar for TensorGpuView<'_, T> {
    type T = T;
}

impl<F: Float> TensorGpuView<'_, F> {
    #[inline]
    pub const fn def(&self) -> &'static str {
        F::DEF
    }
}

impl<'a, T: Scalar> From<&'a TensorGpu<T, ReadWrite>> for TensorGpuView<'a, T> {
    fn from(value: &'a TensorGpu<T, ReadWrite>) -> Self {
        value.view(.., .., .., ..).unwrap()
    }
}

impl<T: Scalar> TensorGpu<T, ReadWrite> {
    /// Create a view for the tensor.
    pub fn view(
        &self,
        x: impl TensorAxis,
        y: impl TensorAxis,
        z: impl TensorAxis,
        w: impl TensorAxis,
    ) -> Result<TensorGpuView<'_, T>, TensorError> {
        let slice = (x, y, z, w);
        let (start, end) = slice.shaped_bounds(self.shape)?;
        let view = View {
            stride: self.shape,
            offset: start.into(),
            shape: (end - start).into(),
        };
        let meta = self.context.checkout_view_uniform(view);
        Ok(TensorGpuView {
            tensor: self,
            meta,
            view,
        })
    }
}

impl<T: Scalar> DeepClone for TensorGpu<T, ReadWrite> {
    fn deep_clone(&self) -> Self {
        let context = &self.context;
        let shape = self.shape;
        // Copy sizes are in bytes, not elements.
        let size = self.size() as u64;
        let cloned: TensorGpu<_, _> = context.tensor_init(shape);

        let mut encoder = context.device.create_command_encoder(&Default::default());
        encoder.copy_buffer_to_buffer(&self.buffer, 0, &cloned.buffer, 0, size);
        context.queue.submit(Some(encoder.finish()));

        cloned
    }
}

/// Stack a batch of tensors of shape `[C, T, 1]` into one of shape `[C, A, 1]`, along with cursor information.
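///
/// A minimal sketch (two `f32` batches of 3 and 2 tokens):
/// ```
/// use web_rwkv::tensor::{TensorCpu, TensorInit, TensorStack};
///
/// let a = TensorCpu::from_data([4, 3, 1, 1], vec![0.0f32; 12]).unwrap();
/// let b = TensorCpu::from_data([4, 2, 1, 1], vec![1.0f32; 8]).unwrap();
/// let stack = TensorStack::try_from(vec![a, b]).unwrap();
/// assert_eq!(stack.num_batch(), 2);
/// assert_eq!(stack.num_token(), 5);
/// ```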
#[derive(Debug, Clone)]
pub struct TensorStack<T: Scalar> {
    pub tensor: TensorCpu<T>,
    pub cursors: Vec<Cursor>,
}

impl<T: Scalar> TensorStack<T> {
    /// Number of input batches (including empty batches).
    #[inline]
    pub fn num_batch(&self) -> usize {
        self.cursors.len()
    }

    /// Number of non-empty input batches.
    #[inline]
    pub fn num_active_batch(&self) -> usize {
        self.cursors.iter().filter(|cursor| cursor.len > 0).count()
    }

    #[inline]
    pub fn num_token(&self) -> usize {
        self.tensor.shape[1]
    }
}

impl<T: Scalar> TryFrom<Vec<TensorCpu<T>>> for TensorStack<T> {
    type Error = TensorError;

    fn try_from(value: Vec<TensorCpu<T>>) -> Result<Self, Self::Error> {
        let shape = match value.first() {
            Some(batch) => batch.shape,
            None => Err(TensorErrorKind::Empty)?,
        };

        value
            .iter()
            .try_for_each(|batch| batch.check_shape([shape[0], batch.shape[1], 1, 1]))?;

        let cursors = value
            .iter()
            .enumerate()
            .scan(0, |token, (batch, tensor)| {
                let len = tensor.shape[1];
                let cursor = Cursor {
                    batch,
                    token: *token,
                    len,
                };
                *token += len;
                Some(cursor)
            })
            .collect_vec();

        let (shape, data) = value.into_iter().fold(
            (Shape::new(shape[0], 0, 1, 1), vec![]),
            |(mut shape, mut data), tensor| {
                shape[1] += tensor.shape[1];
                data.extend(tensor.data.to_vec());
                (shape, data)
            },
        );
        let data = data.into();

        Ok(Self {
            tensor: Tensor {
                shape,
                data,
                id: uid::Id::new(),
                phantom: PhantomData,
            },
            cursors,
        })
    }
}

impl Context {
    #[inline]
    pub fn zeros<T: Scalar, Tensor>(&self, shape: impl Into<Shape>) -> Tensor
    where
        TensorCpu<T>: TensorInto<Tensor>,
    {
        let tensor: TensorCpu<T> = TensorInit::init(shape);
        tensor.to(self)
    }

    #[inline]
    pub fn ones<T: Scalar, Tensor>(&self, shape: impl Into<Shape>) -> Tensor
    where
        TensorCpu<T>: TensorInto<Tensor>,
    {
        let shape = shape.into();
        let data = vec![T::one(); shape.len()];
        let tensor: TensorCpu<T> = TensorInit::from_data(shape, data).unwrap();
        tensor.to(self)
    }

    #[inline]
    pub fn tensor_from_data<'a, T: Scalar, Tensor: TensorInitContext<T>>(
        &self,
        shape: impl Into<Shape>,
        data: impl Into<Cow<'a, [T]>>,
    ) -> Result<Tensor, TensorError> {
        TensorInitContext::from_data(self, shape, data)
    }

    #[inline]
    pub fn tensor_init<T: Scalar, Tensor: TensorInitContext<T>>(
        &self,
        shape: impl Into<Shape>,
    ) -> Tensor {
        TensorInitContext::init(self, shape)
    }
}

mod sealed {
    use super::{Cpu, Gpu, Kind, ReadWrite, Uniform};
    use crate::num::Scalar;

    pub trait Sealed {}

    impl<T: Scalar> Sealed for Cpu<T> {}
    impl<K: Kind> Sealed for Gpu<K> {}

    impl Sealed for Uniform {}
    impl Sealed for ReadWrite {}
}

#[cfg(test)]
mod tests {
    use anyhow::Result;

    use super::Shape;
    use crate::tensor::{TensorCpu, TensorInit, TensorShape};

    #[test]
    fn test_pad_64() -> Result<()> {
        let shape = Shape::new(133, 256, 1, 1);
        let x: Vec<_> = (0..shape.len()).map(|x| x as f32).collect();
        let x = TensorCpu::from_data(shape, x)?.pad([64, 64, 1, 1]);

        assert_eq!(x.shape(), Shape::new(192, 256, 1, 1));
        assert_eq!(x[(132, 255, 0, 0)], (shape.len() - 1) as f32);
        assert_eq!(x[(133, 255, 0, 0)], 0.0);

        Ok(())
    }

    #[test]
    fn test_repeat() -> Result<()> {
        let shape = Shape::new(5, 1, 2, 1);
        let x: Vec<_> = (0..shape.len()).map(|x| x as f32).collect();
        let x = TensorCpu::from_data(shape, x)?;

        let y = x.clone().repeat(1, 3);
        let ans = [
            [0.0, 1.0, 2.0, 3.0, 4.0].repeat(3),
            [5.0, 6.0, 7.0, 8.0, 9.0].repeat(3),
        ]
        .concat();
        y.check_shape([5, 3, 2, 1])?;
        assert_eq!(y.to_vec(), ans);

        let y = x.clone().repeat(0, 3);
        y.check_shape([15, 1, 2, 1])?;
        assert_eq!(y.to_vec(), ans);

        let y = x.repeat(2, 3);
        let ans = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0].repeat(3);
        y.check_shape([5, 1, 6, 1])?;
        assert_eq!(y.to_vec(), ans);

        Ok(())
    }

    #[test]
    fn test_split() -> Result<()> {
        let shape = Shape::new(5, 1, 2, 1);
        let x: Vec<_> = (0..10).map(|x| x as f32).collect();
        let x = TensorCpu::from_data(shape, x)?;

        assert!(x.clone().split(0).is_err());
        assert!(x.clone().split(1).is_err());

        let x = x.split(2)?;
        x[0].check_shape([5, 1, 1, 1])?;
        x[1].check_shape([5, 1, 1, 1])?;
        assert_eq!(x[0].to_vec(), vec![0.0, 1.0, 2.0, 3.0, 4.0]);
        assert_eq!(x[1].to_vec(), vec![5.0, 6.0, 7.0, 8.0, 9.0]);

        Ok(())
    }
}
1351}