web_rwkv/tensor/mod.rs

1use std::{borrow::Cow, marker::PhantomData, sync::Arc};
2
3use itertools::Itertools;
4use shape::ShapedIndex;
5use thiserror::Error;
6use wgpu::{
7    BindGroupLayoutEntry, BindingResource, BindingType, Buffer, BufferBinding, BufferBindingType,
8    BufferUsages, ShaderStages,
9};
10
11use self::{
12    kind::{Kind, ReadWrite, Uniform},
13    shape::{IntoBytes, Shape, TensorAxis, TensorDimension, TensorSlice},
14};
15use crate::{
16    context::Context,
17    num::{Float, Scalar},
18};
19
20pub mod cache;
21pub mod matrix;
22pub mod ops;
23pub mod serialization;
24pub mod shape;
25
/// Data defining a tensor view in shader.
///
/// Serialized (via the [`IntoBytes`] impl below) as `shape`, then `stride`,
/// then `offset`, matching the uniform layout the shaders read.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct View {
    /// Shape of the viewed region.
    pub shape: Shape,
    // Presumably the per-axis stride of the underlying tensor; for a
    // full-tensor view it is set to the tensor's own shape (see
    // `TensorResource::resource_key`) — confirm against the shader side.
    pub stride: Shape,
    /// Offset of the view within the underlying tensor.
    pub offset: Shape,
}
33
34impl IntoBytes for View {
35    fn into_bytes(self) -> Vec<u8> {
36        [
37            self.shape.into_bytes(),
38            self.stride.into_bytes(),
39            self.offset.into_bytes(),
40        ]
41        .concat()
42    }
43}
44
/// A record in order to separate different batches of input of various lengths.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Cursor {
    pub batch: usize,
    pub token: usize,
    pub len: usize,
}

impl Cursor {
    /// Pack the cursor into a single `u32` in native byte order:
    /// byte 0 is `batch`, bytes 1-2 are `token`, byte 3 is `len`.
    /// Out-of-range values are truncated (`batch`/`len` to 8 bits,
    /// `token` to 16 bits).
    pub fn pack(self) -> u32 {
        let [token_lo, token_hi] = (self.token as u16).to_ne_bytes();
        u32::from_ne_bytes([self.batch as u8, token_lo, token_hi, self.len as u8])
    }
}
61
/// Conversion of a list of [`Cursor`]s into GPU-consumable packed forms.
pub trait IntoPackedCursors {
    /// One packed `u32` per non-empty cursor.
    fn into_stack(self) -> Vec<u32>;
    /// One packed `u32` per token: each non-empty cursor repeated `len` times.
    fn into_cursors(self) -> Vec<u32>;
}
66
67impl IntoPackedCursors for Vec<Cursor> {
68    fn into_stack(self) -> Vec<u32> {
69        self.into_iter()
70            .filter(|cursor| cursor.len > 0)
71            .map(Cursor::pack)
72            .collect()
73    }
74
75    fn into_cursors(self) -> Vec<u32> {
76        self.into_iter()
77            .filter(|cursor| cursor.len > 0)
78            .map(|cursor| {
79                let repeat = cursor.len;
80                vec![cursor.pack(); repeat]
81            })
82            .collect_vec()
83            .concat()
84    }
85}
86
/// Error type for tensor operations, pairing a [`TensorErrorKind`] with an
/// optionally captured backtrace (enabled by the `backtrace` feature).
#[derive(Debug)]
pub struct TensorError {
    pub error: TensorErrorKind,
    // Captured at construction time when the `backtrace` feature is on.
    #[cfg(feature = "backtrace")]
    pub backtrace: std::backtrace::Backtrace,
}

impl TensorError {
    /// Wrap a [`TensorErrorKind`], capturing a backtrace at the call site.
    #[cfg(feature = "backtrace")]
    pub fn new(error: TensorErrorKind) -> Self {
        // `capture()` respects RUST_BACKTRACE/RUST_LIB_BACKTRACE and may
        // yield a disabled placeholder instead of a real trace.
        let backtrace = std::backtrace::Backtrace::capture();
        Self { error, backtrace }
    }

    /// Wrap a [`TensorErrorKind`] (backtrace support not compiled in).
    #[cfg(not(feature = "backtrace"))]
    pub fn new(error: TensorErrorKind) -> Self {
        Self { error }
    }
}
106
107impl From<TensorErrorKind> for TensorError {
108    fn from(value: TensorErrorKind) -> Self {
109        Self::new(value)
110    }
111}
112
impl std::fmt::Display for TensorError {
    /// With the `backtrace` feature, append the captured backtrace after the error.
    #[cfg(feature = "backtrace")]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // NOTE(review): `writeln!` leaves a trailing newline in the `Display`
        // output — presumably intentional for log readability; confirm.
        writeln!(f, "{}\n\nBacktrace:\n{}", self.error, self.backtrace)
    }

    #[cfg(not(feature = "backtrace"))]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "{}", self.error)
    }
}

// Marker impl only: the default `source()` (returning `None`) is used; the
// kind itself carries all detail.
impl std::error::Error for TensorError {}
126
/// The concrete kinds of tensor errors; convertible into [`TensorError`]
/// (which optionally attaches a backtrace).
#[derive(Debug, Error)]
pub enum TensorErrorKind {
    /// An input list (e.g. to `TensorCpu::stack`) was empty.
    #[error("list must not be empty")]
    Empty,
    #[error("data type mismatch")]
    Type,
    /// Element-count mismatch: expected vs. provided.
    #[error("data size not match: {0} vs. {1}")]
    Size(usize, usize),
    #[error("batch size not match: {0} vs. {1}")]
    Batch(usize, usize),
    #[error("tensor shape not match: {0} vs. {1}")]
    Shape(Shape, Shape),
    /// A dimension could not be inferred (see `TensorDimension::deduce`).
    #[error("cannot deduce dimension")]
    Deduce,
    #[error("batch {batch} out of range of max {max}")]
    BatchOutOfRange { batch: usize, max: usize },
    #[error("slice {start}..{end} out of range for dimension size {dim}")]
    SliceOutOfRange {
        dim: usize,
        start: usize,
        end: usize,
    },
    /// The requested slice does not map to one contiguous memory range.
    #[error("slice not contiguous")]
    SliceInvalid,
    #[error("cannot split along the axis {0}")]
    SplitInvalid(usize),
    /// Several candidate errors collected at once (see [`AnyTensorError`]).
    #[error("possible tensor error(s):\n{0}")]
    Any(#[from] AnyTensorError),
}
156
157#[derive(Debug, Error)]
158pub struct AnyTensorError(pub Vec<Box<dyn std::error::Error + Send + Sync>>);
159
160impl std::fmt::Display for AnyTensorError {
161    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162        self.0
163            .iter()
164            .enumerate()
165            .try_for_each(|(index, error)| writeln!(f, "{index}. {error}"))
166    }
167}
168
/// A clone intended to be independent of the original ("deep").
// NOTE(review): the `TensorCpu` impl below simply `clone()`s, which shares
// the immutable `Arc` storage — verify this is the intended semantics.
pub trait DeepClone: Sized {
    fn deep_clone(&self) -> Self;
}

/// Exposes the scalar element type of a tensor-like type.
pub trait TensorScalar {
    /// The scalar element type.
    type T: Scalar;
}
176
/// Construction of tensors that require a [`Context`] (e.g. GPU tensors).
pub trait TensorInitContext<T: Scalar>: Sized {
    /// Init the tensor with given shape and contents.
    ///
    /// # Errors
    /// Fails when `data`'s length does not match `shape`'s element count
    /// (see the size check in `TensorInit::from_data` for `TensorCpu`).
    fn from_data<'a, S, D>(context: &Context, shape: S, data: D) -> Result<Self, TensorError>
    where
        S: Into<Shape>,
        D: Into<Cow<'a, [T]>>;
    /// Init the tensor with given shape.
    fn init(context: &Context, shape: impl Into<Shape>) -> Self;
}
186
/// Construction of tensors that need no GPU context (CPU tensors).
pub trait TensorInit<T: Scalar>: Sized {
    /// Init the tensor with given shape and contents.
    ///
    /// # Errors
    /// Returns [`TensorErrorKind::Size`] when `data`'s length does not match
    /// the element count of `shape` (see the `TensorCpu` impl).
    fn from_data<'a, S, D>(shape: S, data: D) -> Result<Self, TensorError>
    where
        S: Into<Shape>,
        D: Into<Cow<'a, [T]>>;
    /// Init the tensor with given shape (zero-filled in the `TensorCpu` impl).
    fn init(shape: impl Into<Shape>) -> Self;

    /// Init an 1-D tensor from data.
    ///
    /// # Panics
    /// Never in practice: the shape is derived from `data.len()`, so the
    /// size check in `from_data` cannot fail.
    fn from_data_1d<'a>(data: impl Into<Cow<'a, [T]>>) -> Self {
        let data: Cow<'_, [T]> = data.into();
        let shape = [data.len(), 1, 1, 1];
        Self::from_data(shape, data).expect("tensor 1d from data")
    }
}
203
/// Conversion of a tensor into another representation under a [`Context`].
// NOTE(review): the generic parameter is named `Into`, shadowing
// `std::convert::Into` inside this trait's scope — consider renaming.
pub trait TensorInto<Into> {
    fn to(self, context: &Context) -> Into;
}
207
208pub trait TensorShape: Sized {
209    /// Get the shape of the tensor.
210    fn shape(&self) -> Shape;
211
212    /// Check if the tensor's shape is the same as what expected.
213    fn check_shape(&self, shape: impl Into<Shape>) -> Result<(), TensorError> {
214        let shape = shape.into();
215        (self.shape() == shape)
216            .then_some(())
217            .ok_or(TensorErrorKind::Shape(self.shape(), shape))
218            .map_err(Into::into)
219    }
220
221    /// Check if the tensor's shape matches any of the expected ones.
222    fn check_shape_any<S>(&self, shapes: &[S]) -> Result<(), TensorError>
223    where
224        S: Into<Shape> + ToOwned<Owned = S>,
225    {
226        let (oks, errors): (Vec<_>, Vec<_>) = shapes
227            .iter()
228            .map(|shape| self.check_shape(shape.to_owned()).map_err(Into::into))
229            .partition_result();
230        match oks.is_empty() {
231            true => Err(TensorErrorKind::Any(AnyTensorError(errors)))?,
232            false => Ok(()),
233        }
234    }
235}
236
/// Reshaping without touching the underlying data.
pub trait TensorReshape: Sized {
    /// Build a tensor over the same data with dimensions deduced from the
    /// four per-axis specifications.
    ///
    /// # Errors
    /// Fails when the dimensions cannot be deduced into a consistent shape.
    fn reshape(
        &self,
        x: TensorDimension,
        y: TensorDimension,
        z: TensorDimension,
        w: TensorDimension,
    ) -> Result<Self, TensorError>;
}

/// What a tensor must provide to be bound as a shader resource.
pub trait TensorResource {
    /// Retrieve the key identifying a resource.
    fn resource_key(&self) -> ResourceKey;
    /// Binding for metadata of the tensor (shape, stride, etc.).
    fn meta_binding(&self) -> BindingResource<'_>;
    /// Binding for actual data of the tensor.
    fn binding(&self) -> BindingResource<'_>;
}
255
/// A tensor on either CPU or GPU.
///
/// `D` selects the storage backend ([`Cpu`] or [`Gpu`]); `T` is the scalar
/// element type. Each tensor carries a unique `id` used in resource caching.
#[derive(Debug)]
pub struct Tensor<D: Device, T: Scalar> {
    shape: Shape,
    data: D::Data,
    // Unique per tensor: `Clone` keeps the id, while constructors and most
    // transforming methods mint a fresh one via `uid::Id::new()`.
    id: uid::Id<TensorId>,
    phantom: PhantomData<T>,
}

/// Marker type parameterizing [`uid::Id`] for tensor ids.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TensorId;

/// A unique identifier of tensor views. Useful in caches.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ResourceKey {
    pub id: uid::Id<TensorId>,
    pub view: View,
}
274
/// A storage backend for tensors, fixing the concrete data representation.
/// Sealed: only [`Cpu`] and [`Gpu`] implement it.
pub trait Device: sealed::Sealed {
    /// Backing storage type for this device.
    type Data: Clone;
}

/// CPU backend marker; data lives in a shared `Arc<[T]>`.
#[derive(Debug)]
pub struct Cpu<T: Scalar>(PhantomData<T>);

/// GPU backend marker; `K` is the buffer [`Kind`] (uniform or storage).
#[derive(Debug)]
pub struct Gpu<K: Kind>(PhantomData<K>);

impl<T: Scalar> Device for Cpu<T> {
    type Data = Arc<[T]>;
}

impl<K: Kind> Device for Gpu<K> {
    type Data = TensorGpuData;
}

/// Buffer of the tensor on GPU.
#[derive(Debug, Clone)]
pub struct TensorGpuData {
    /// Context the buffers were checked out from.
    pub context: Context,
    /// Uniform buffer holding the tensor's shape metadata
    /// (see `Context::checkout_shape_uniform`).
    pub meta: Arc<Buffer>,
    /// The actual data buffer.
    pub buffer: Arc<Buffer>,
}
300
pub mod kind {
    //! Marker types describing what kind of GPU buffer backs a tensor.

    use web_rwkv_derive::Kind;
    use wgpu::BufferUsages;

    use super::sealed;

    /// Determines the [`BufferUsages`] a tensor's buffer is created with.
    pub trait Kind: sealed::Sealed {
        fn buffer_usages() -> BufferUsages;
    }

    /// Tensor is a uniform buffer.
    #[derive(Debug, Kind)]
    #[usage(UNIFORM, COPY_DST, COPY_SRC)]
    pub struct Uniform;

    /// Tensor is a storage buffer which can be copied to other buffers.
    #[derive(Debug, Kind)]
    #[usage(STORAGE, COPY_DST, COPY_SRC)]
    pub struct ReadWrite;
}
321
/// A tensor whose data lives on the CPU, backed by `Arc<[T]>`.
pub type TensorCpu<T> = Tensor<Cpu<T>, T>;
/// A tensor whose data lives on the GPU, backed by [`TensorGpuData`].
pub type TensorGpu<T, K> = Tensor<Gpu<K>, T>;
324
// Manual impl: `derive(Clone)` would also bound the marker parameters
// (`D: Clone`, `T: Clone`) instead of just requiring `D::Data: Clone`.
impl<D: Device, T: Scalar> Clone for Tensor<D, T> {
    /// Shallow clone: shares the storage and keeps the same tensor `id`.
    fn clone(&self) -> Self {
        Self {
            shape: self.shape,
            data: self.data.clone(),
            id: self.id,
            phantom: PhantomData,
        }
    }
}

impl<D: Device, T: Scalar> std::ops::Deref for Tensor<D, T> {
    type Target = D::Data;

    /// Expose the backing storage directly, so e.g. `tensor.context` and
    /// `tensor.buffer` resolve through [`TensorGpuData`] on GPU tensors.
    #[inline]
    fn deref(&self) -> &Self::Target {
        &self.data
    }
}

impl<D: Device, T: Scalar> TensorScalar for Tensor<D, T> {
    type T = T;
}
348
impl<D: Device, T: Scalar> Tensor<D, T> {
    /// Number of elements in the tensor.
    #[inline]
    pub fn len(&self) -> usize {
        self.shape.len()
    }

    /// Whether the tensor holds no elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.shape.is_empty()
    }

    /// Size of the tensor in bytes.
    #[inline]
    pub fn size(&self) -> usize {
        self.len() * T::size()
    }

    /// The offset in bytes for a linear index.
    #[inline]
    pub fn offset(index: usize) -> usize {
        index * T::size()
    }

    /// The backing storage (`Arc<[T]>` on CPU, [`TensorGpuData`] on GPU).
    #[inline]
    pub fn data(&self) -> &D::Data {
        &self.data
    }

    /// Unique id of this tensor (shared by clones).
    #[inline]
    pub fn id(&self) -> uid::Id<TensorId> {
        self.id
    }
}

impl<D: Device, F: Float> Tensor<D, F> {
    /// Type-definition string associated with the float type (`Float::DEF`).
    #[inline]
    pub const fn def(&self) -> &'static str {
        F::DEF
    }
}
389
390impl<T: Scalar> TensorInit<T> for TensorCpu<T> {
391    fn from_data<'a, S, D>(shape: S, data: D) -> Result<Self, TensorError>
392    where
393        S: Into<Shape>,
394        D: Into<Cow<'a, [T]>>,
395    {
396        let shape = shape.into();
397        let data: Cow<'_, _> = data.into();
398        if shape.len() != data.len() {
399            Err(TensorErrorKind::Size(shape.len(), data.len()))?;
400        }
401        let data = data.into_owned().into();
402        Ok(Self {
403            shape,
404            data,
405            id: uid::Id::new(),
406            phantom: PhantomData,
407        })
408    }
409
410    #[inline]
411    fn init(shape: impl Into<Shape>) -> Self {
412        let shape = shape.into();
413        let data = vec![T::zero(); shape.len()].into();
414        Self {
415            shape,
416            data,
417            id: uid::Id::new(),
418            phantom: PhantomData,
419        }
420    }
421}
422
423impl<T: Scalar> TensorInitContext<T> for TensorCpu<T> {
424    fn from_data<'a, S, D>(_context: &Context, shape: S, data: D) -> Result<Self, TensorError>
425    where
426        S: Into<Shape>,
427        D: Into<Cow<'a, [T]>>,
428    {
429        TensorInit::from_data(shape, data)
430    }
431
432    fn init(_context: &Context, shape: impl Into<Shape>) -> Self {
433        TensorInit::init(shape)
434    }
435}
436
impl<T: Scalar> TensorInto<TensorCpu<T>> for TensorCpu<T> {
    /// CPU → CPU conversion is the identity; the context is irrelevant.
    fn to(self, _: &Context) -> Self {
        self
    }
}

impl<T: Scalar> DeepClone for TensorCpu<T> {
    // NOTE(review): shares the `Arc` storage rather than copying it; fine as
    // long as CPU tensor data stays immutable — confirm intended semantics.
    fn deep_clone(&self) -> Self {
        self.clone()
    }
}

impl<T: Scalar> TensorShape for TensorCpu<T> {
    #[inline]
    fn shape(&self) -> Shape {
        self.shape
    }
}
455
456impl<T: Scalar> TensorReshape for TensorCpu<T> {
457    #[inline]
458    fn reshape(
459        &self,
460        x: TensorDimension,
461        y: TensorDimension,
462        z: TensorDimension,
463        w: TensorDimension,
464    ) -> Result<Self, TensorError> {
465        let shape = TensorDimension::deduce(self.shape, x, y, z, w)?;
466        Ok(Self {
467            shape,
468            ..self.clone()
469        })
470    }
471}
472
473impl<T: Scalar, K: Kind> TensorInitContext<T> for TensorGpu<T, K> {
474    fn from_data<'a, S, D>(context: &Context, shape: S, data: D) -> Result<Self, TensorError>
475    where
476        S: Into<Shape>,
477        D: Into<Cow<'a, [T]>>,
478    {
479        let tensor: TensorCpu<T> = TensorInit::from_data(shape, data)?;
480        Ok(tensor.to(context))
481    }
482
483    fn init(context: &Context, shape: impl Into<Shape>) -> Self {
484        let context = context.clone();
485        let shape = shape.into();
486        let meta = context.checkout_shape_uniform(shape);
487        let size = shape.len() * std::mem::size_of::<T>();
488        let buffer = context.checkout_buffer(size, K::buffer_usages());
489        Self {
490            shape,
491            data: TensorGpuData {
492                context,
493                meta,
494                buffer,
495            },
496            id: uid::Id::new(),
497            phantom: PhantomData,
498        }
499    }
500}
501
502impl<T: Scalar, K: Kind> TensorInto<TensorGpu<T, K>> for TensorCpu<T> {
503    fn to(self, context: &Context) -> TensorGpu<T, K> {
504        let Tensor { shape, data, .. } = self;
505        let context = context.clone();
506        let meta = context.checkout_shape_uniform(shape);
507        let contents = bytemuck::cast_slice(&data);
508        let buffer = context.checkout_buffer_init(contents, K::buffer_usages());
509        TensorGpu {
510            shape,
511            data: TensorGpuData {
512                context,
513                meta,
514                buffer,
515            },
516            id: uid::Id::new(),
517            phantom: PhantomData,
518        }
519    }
520}
521
522#[cfg(not(target_arch = "wasm32"))]
523impl<T: Scalar> TensorInto<TensorGpu<T, ReadWrite>> for TensorGpu<T, ReadWrite> {
524    fn to(self, context: &Context) -> Self {
525        match context {
526            context if context == &self.context => self,
527            _ => self.back_in_place().to(context),
528        }
529    }
530}
531
#[cfg(target_arch = "wasm32")]
impl<T: Scalar> TensorInto<TensorGpu<T, ReadWrite>> for TensorGpu<T, ReadWrite> {
    /// On wasm the tensor is returned as-is: no cross-context transfer.
    // NOTE(review): presumably assumes a single context on wasm — confirm.
    fn to(self, _: &Context) -> Self {
        self
    }
}
538
impl<T: Scalar, K: Kind> TensorShape for TensorGpu<T, K> {
    /// Shape of the tensor as tracked on the host side.
    #[inline]
    fn shape(&self) -> Shape {
        self.shape
    }
}
545
546impl<T: Scalar, K: Kind> TensorReshape for TensorGpu<T, K> {
547    #[inline]
548    fn reshape(
549        &self,
550        x: TensorDimension,
551        y: TensorDimension,
552        z: TensorDimension,
553        w: TensorDimension,
554    ) -> Result<Self, TensorError> {
555        let shape = TensorDimension::deduce(self.shape, x, y, z, w)?;
556        let context = self.context.clone();
557        let meta = context.checkout_shape_uniform(shape);
558        let buffer = self.buffer.clone();
559        Ok(Self {
560            shape,
561            data: TensorGpuData {
562                context,
563                meta,
564                buffer,
565            },
566            ..self.clone()
567        })
568    }
569}
570
571impl<T: Scalar, K: Kind> TensorResource for TensorGpu<T, K> {
572    #[inline]
573    fn resource_key(&self) -> ResourceKey {
574        let id = self.id;
575        let view = View {
576            shape: self.shape,
577            stride: self.shape,
578            offset: [0, 0, 0, 0].into(),
579        };
580        ResourceKey { id, view }
581    }
582
583    #[inline]
584    fn meta_binding(&self) -> BindingResource<'_> {
585        BindingResource::Buffer(BufferBinding {
586            buffer: &self.meta,
587            offset: 0,
588            size: None,
589        })
590    }
591
592    #[inline]
593    fn binding(&self) -> BindingResource<'_> {
594        BindingResource::Buffer(BufferBinding {
595            buffer: &self.buffer,
596            offset: 0,
597            size: None,
598        })
599    }
600}
601
impl<T: Scalar, K: Kind> TensorGpu<T, K> {
    /// Build a GPU tensor of `shape` directly from raw bytes.
    ///
    /// # Errors
    /// Returns [`TensorErrorKind::Size`] when `contents` is not exactly
    /// `shape.len() * size_of::<T>()` bytes long.
    pub fn from_data_u8(
        context: &Context,
        shape: impl Into<Shape>,
        contents: &[u8],
    ) -> Result<Self, TensorError> {
        let shape = shape.into();
        let size = shape.len() * size_of::<T>();
        if contents.len() != size {
            Err(TensorErrorKind::Size(size, contents.len()))?;
        }
        let buffer = context.checkout_buffer_init(contents, K::buffer_usages());
        let meta = context.checkout_shape_uniform(shape);
        Ok(Self {
            shape,
            data: TensorGpuData {
                context: context.clone(),
                meta,
                buffer,
            },
            id: uid::Id::new(),
            phantom: PhantomData,
        })
    }

    /// Synchronously read the tensor back to the CPU.
    ///
    /// Copies the data into a `MAP_READ` staging buffer, submits the copy,
    /// then blocks on the context's event channel until the bytes arrive.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn back_in_place(&self) -> TensorCpu<T> {
        use crate::context::ContextEvent;

        // Fast path: nothing to copy for an empty tensor.
        if self.is_empty() {
            return TensorCpu {
                shape: self.shape,
                data: Arc::new([]),
                id: uid::Id::new(),
                phantom: PhantomData,
            };
        }

        let context = &self.context;
        let size = self.buffer.size();
        let buffer = context.checkout_buffer(
            size as usize,
            BufferUsages::MAP_READ | BufferUsages::COPY_DST,
        );

        let mut encoder = context.device.create_command_encoder(&Default::default());
        encoder.copy_buffer_to_buffer(&self.buffer, 0, &buffer, 0, size);
        context.queue.submit(Some(encoder.finish()));

        // The context's event handler is expected to map the staging buffer
        // and send its contents back over `sender` (see `crate::context`).
        let (sender, receiver) = flume::bounded(1);
        let _ = context.event().send(ContextEvent { buffer, sender });
        let data = receiver.recv().expect("failed to receive read back buffer");
        // SAFETY: reinterprets the received boxed byte slice as `Box<[T]>`
        // without copying. Relies on the byte length being a multiple of
        // `size_of::<T>()` (enforced by `cast_slice_mut`) and `T` being
        // plain-old-data (`Scalar`).
        // NOTE(review): deallocating a `[u8]` allocation as `[T]` also
        // requires matching alignment — verify how the event loop allocates
        // the box.
        let data = unsafe {
            let data = Box::leak(data);
            let slice = bytemuck::cast_slice_mut::<_, T>(data);
            Box::from_raw(slice)
        };
        let data = data.into_vec().into();
        let shape = self.shape;

        TensorCpu {
            shape,
            data,
            id: uid::Id::new(),
            phantom: PhantomData,
        }
    }

    /// Asynchronously read the tensor back to the CPU.
    ///
    /// Same staging-copy scheme as [`Self::back_in_place`], but awaits the
    /// event channel instead of blocking the thread.
    #[cfg(not(target_arch = "wasm32"))]
    pub async fn back(&self) -> TensorCpu<T> {
        // Fast path: nothing to copy for an empty tensor.
        if self.is_empty() {
            return TensorCpu {
                shape: self.shape,
                data: Arc::new([]),
                id: uid::Id::new(),
                phantom: PhantomData,
            };
        }

        let context = &self.context;
        let size = self.buffer.size();
        let buffer = context.checkout_buffer(
            size as usize,
            BufferUsages::MAP_READ | BufferUsages::COPY_DST,
        );

        let mut encoder = context.device.create_command_encoder(&Default::default());
        encoder.copy_buffer_to_buffer(&self.buffer, 0, &buffer, 0, size);
        context.queue.submit(Some(encoder.finish()));

        let (sender, receiver) = flume::bounded(1);

        let _ = context
            .event()
            .send(crate::context::ContextEvent { buffer, sender });
        let data = receiver
            .recv_async()
            .await
            .expect("failed to receive read back buffer");
        // SAFETY: see `back_in_place` — same reinterpretation of the boxed
        // byte slice as `Box<[T]>`, with the same length/alignment
        // requirements.
        let data = unsafe {
            let data = Box::leak(data);
            let slice = bytemuck::cast_slice_mut::<_, T>(data);
            Box::from_raw(slice)
        };
        let data = data.into_vec().into();

        TensorCpu {
            shape: self.shape,
            data,
            id: uid::Id::new(),
            phantom: PhantomData,
        }
    }

    /// Read the tensor back to the CPU on wasm: the staging buffer is mapped
    /// directly here instead of going through the context's event channel.
    #[cfg(target_arch = "wasm32")]
    pub async fn back(self) -> TensorCpu<T> {
        // Fast path: nothing to copy for an empty tensor.
        if self.is_empty() {
            return TensorCpu {
                shape: self.shape,
                data: Arc::new([]),
                id: uid::Id::new(),
                phantom: PhantomData,
            };
        }

        let context = &self.context;
        let size = self.buffer.size();
        let buffer = context.checkout_buffer(
            size as usize,
            BufferUsages::MAP_READ | BufferUsages::COPY_DST,
        );

        let mut encoder = context.device.create_command_encoder(&Default::default());
        encoder.copy_buffer_to_buffer(&self.buffer, 0, &buffer, 0, size);
        context.queue.submit(Some(encoder.finish()));

        let (sender, receiver) = flume::unbounded();

        let slice = buffer.slice(..);
        slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());

        // Drive the device until the map callback has fired.
        _ = context.device.poll(wgpu::PollType::Wait);
        receiver
            .recv_async()
            .await
            .expect("failed to receive read back buffer")
            .expect("failed to map buffer");

        // Copy the mapped bytes out, then release the mapping.
        let data = {
            let map = slice.get_mapped_range();
            Vec::from(bytemuck::cast_slice(&map)).into()
        };
        buffer.unmap();

        TensorCpu {
            shape: self.shape,
            data,
            id: uid::Id::new(),
            phantom: PhantomData,
        }
    }
}
764
765impl<T: Scalar, K: Kind> TensorGpu<T, K> {
766    #[inline]
767    pub fn context(&self) -> &Context {
768        &self.context
769    }
770
771    pub fn load(&self, host: &TensorCpu<T>) -> Result<(), TensorError> {
772        host.check_shape(self.shape)?;
773        self.context
774            .queue
775            .write_buffer(&self.buffer, 0, bytemuck::cast_slice(&host.data[..]));
776        Ok(())
777    }
778
779    pub fn load_batch(&self, host: &TensorCpu<T>, batch: usize) -> Result<(), TensorError> {
780        host.check_shape([self.shape[0], self.shape[1], 1, 1])?;
781        if batch >= self.shape[2] {
782            Err(TensorErrorKind::BatchOutOfRange {
783                batch,
784                max: self.shape[2],
785            })?;
786        }
787        let offset = (T::size() * self.shape[0] * self.shape[1] * batch) as u64;
788        self.context
789            .queue
790            .write_buffer(&self.buffer, offset, bytemuck::cast_slice(&host.data[..]));
791        Ok(())
792    }
793
794    pub fn destroy(self) {
795        self.buffer.destroy();
796    }
797}
798
impl<T: Scalar> TensorGpu<T, Uniform> {
    /// Bind-group layout entry for this uniform tensor at `binding`
    /// (compute-stage visibility).
    #[inline]
    pub fn layout(&self, binding: u32) -> BindGroupLayoutEntry {
        BindGroupLayoutEntry {
            binding,
            visibility: ShaderStages::COMPUTE,
            ty: BindingType::Buffer {
                ty: BufferBindingType::Uniform,
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        }
    }
}

impl<T: Scalar> TensorGpu<T, ReadWrite> {
    /// Layout entry for the shape-uniform (`meta`) buffer at `binding`.
    #[inline]
    pub fn meta_layout(&self, binding: u32) -> BindGroupLayoutEntry {
        BindGroupLayoutEntry {
            binding,
            visibility: ShaderStages::COMPUTE,
            ty: BindingType::Buffer {
                ty: BufferBindingType::Uniform,
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        }
    }

    /// Layout entry for the storage data buffer at `binding`;
    /// `read_only` selects the storage access mode.
    #[inline]
    pub fn layout(&self, binding: u32, read_only: bool) -> BindGroupLayoutEntry {
        BindGroupLayoutEntry {
            binding,
            visibility: ShaderStages::COMPUTE,
            ty: BindingType::Buffer {
                ty: BufferBindingType::Storage { read_only },
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        }
    }
}
844
845impl<T: Scalar> From<TensorCpu<T>> for Vec<T> {
846    #[inline]
847    fn from(value: TensorCpu<T>) -> Self {
848        // match Arc::get_mut(&mut value.data) {
849        //     Some(data) => {
850        //         // SAFETY: if `data` is unique, it stays unique in the scope of this function since we own the `Arc`.
851        //         unsafe {
852        //             let len = data.len();
853        //             let data = Arc::into_raw(value.data) as *mut T;
854        //             let slice = core::slice::from_raw_parts_mut(data, len);
855        //             let boxed = Box::from_raw(slice);
856        //             boxed.into_vec()
857        //         }
858        //     }
859        //     None => value.data.to_vec(),
860        // }
861        value.to_vec()
862    }
863}
864
impl<T: Scalar, S: Into<ShapedIndex>> std::ops::Index<S> for TensorCpu<T> {
    type Output = T;

    /// Index by anything convertible to a [`ShapedIndex`], linearized
    /// through the tensor's shape.
    ///
    /// # Panics
    /// Panics when the linearized index is out of bounds.
    fn index(&self, index: S) -> &Self::Output {
        &self.data[self.shape.linear_index(index)]
    }
}
872
873impl<T: Scalar> TensorCpu<T> {
874    /// Apply a map `f` to every element in the tensor.
875    pub fn map<U: Scalar>(self, f: impl FnMut(&T) -> U) -> TensorCpu<U> {
876        let Self { shape, data, .. } = self;
877        let data = data.iter().map(f).collect_vec();
878        TensorInit::from_data(shape, &data).unwrap()
879    }
880
881    /// Pad each dimension to multiples with zeros.
882    pub fn pad(self, multiples: impl Into<Shape>) -> Self {
883        // let shape = Shape::new(
884        //     self.shape[0].next_multiple_of(64),
885        //     self.shape[1].next_multiple_of(64),
886        //     self.shape[2].next_multiple_of(64),
887        //     self.shape[3].next_multiple_of(64),
888        // );
889
890        let multiples: Shape = multiples.into();
891
892        let mut shape = self.shape;
893        for (axis, multiple) in multiples.iter().enumerate() {
894            shape[axis] = shape[axis].next_multiple_of(multiple);
895        }
896
897        let mut data = vec![T::zero(); shape.len()];
898        for index in self.shape.cartesian_product() {
899            let value = self[index];
900            data[shape.linear_index(index)] = value;
901        }
902        TensorInit::from_data(shape, data).unwrap()
903    }
904
905    /// Repeat the tensor along a given axis.
906    pub fn repeat(self, axis: usize, repeat: usize) -> Self {
907        let mut shape = self.shape;
908        let data = self.data;
909
910        let num_chunk: usize = shape.iter().skip(axis + 1).product();
911        let chunk_size = data.len() / num_chunk;
912
913        shape[axis] *= repeat;
914
915        let data = (0..num_chunk)
916            .map(|chunk| {
917                let start = chunk * chunk_size;
918                let end = start + chunk_size;
919                let chunk = data[start..end].to_vec();
920                chunk.repeat(repeat)
921            })
922            .collect_vec()
923            .concat()
924            .into();
925
926        Self {
927            shape,
928            data,
929            id: uid::Id::new(),
930            phantom: PhantomData,
931        }
932    }
933
934    /// Concat a batch of tensors.
935    pub fn stack(batches: Vec<Self>, axis: usize) -> Result<Self, TensorError> {
936        let mut shape = match batches.first() {
937            Some(batch) => batch.shape,
938            None => Err(TensorErrorKind::Empty)?,
939        };
940
941        // batches
942        //     .iter()
943        //     .try_for_each(|batch| batch.check_shape([shape[0], shape[1], batch.shape[2], 1]))?;
944
945        batches.iter().try_for_each(|batch| match axis {
946            0 => batch.check_shape([batch.shape[0], 1, 1, 1]),
947            1 => batch.check_shape([shape[0], batch.shape[1], 1, 1]),
948            2 => batch.check_shape([shape[0], shape[1], batch.shape[2], 1]),
949            3 => batch.check_shape([shape[0], shape[1], shape[2], batch.shape[3]]),
950            _ => unreachable!(),
951        })?;
952
953        let num_batch: usize = batches.iter().map(|batch| batch.shape[axis]).sum();
954        shape[axis] = num_batch;
955
956        let data = batches
957            .into_iter()
958            .map(|batch| batch.data.to_vec())
959            .collect_vec()
960            .concat()
961            .into();
962
963        Ok(Self {
964            shape,
965            data,
966            id: uid::Id::new(),
967            phantom: PhantomData,
968        })
969    }
970
971    /// Split the tensor along the batch axis.
972    pub fn split(self, axis: usize) -> Result<Vec<Self>, TensorError> {
973        if self.shape.iter().skip(axis + 1).any(|dim| dim > 1) {
974            Err(TensorErrorKind::SplitInvalid(axis))?;
975        }
976
977        (0..self.shape[axis])
978            .map(|index| match axis {
979                0 => self.slice(index, .., .., ..),
980                1 => self.slice(.., index, .., ..),
981                2 => self.slice(.., .., index, ..),
982                3 => self.slice(.., .., .., ..),
983                _ => Err(TensorErrorKind::SplitInvalid(axis))?,
984            })
985            .try_collect()
986    }
987
988    pub fn slice(
989        &self,
990        x: impl TensorAxis,
991        y: impl TensorAxis,
992        z: impl TensorAxis,
993        w: impl TensorAxis,
994    ) -> Result<Self, TensorError> {
995        let slice = (x, y, z, w);
996        let (start, end) = slice.shaped_bounds(self.shape)?;
997        let shape = (end - start).into();
998
999        let (start, end) = slice.linear_bounds(self.shape)?;
1000        let data = self.data[start..end].into();
1001
1002        Ok(Self {
1003            shape,
1004            data,
1005            id: uid::Id::new(),
1006            phantom: PhantomData,
1007        })
1008    }
1009
1010    pub fn into_slice(
1011        self,
1012        x: impl TensorAxis,
1013        y: impl TensorAxis,
1014        z: impl TensorAxis,
1015        w: impl TensorAxis,
1016    ) -> Result<Self, TensorError> {
1017        let slice = (x, y, z, w);
1018        let (start, end) = slice.shaped_bounds(self.shape)?;
1019        let shape = (end - start).into();
1020
1021        let (start, end) = slice.linear_bounds(self.shape)?;
1022        let data = self.data[start..end].into();
1023
1024        Ok(Self {
1025            shape,
1026            data,
1027            id: uid::Id::new(),
1028            phantom: PhantomData,
1029        })
1030    }
1031}
1032
/// Like a reference to a tensor, but refer to a sub-chunk of it.
#[derive(Debug, Clone)]
pub struct TensorGpuView<'a, T: Scalar> {
    // The tensor this view borrows from; the view cannot outlive it.
    tensor: &'a TensorGpu<T, ReadWrite>,
    // Buffer holding the packed [`View`] metadata for shaders, obtained via
    // `Context::checkout_view_uniform` (see `TensorGpu::view`).
    meta: Arc<Buffer>,
    // Shape, stride and offset describing the viewed sub-chunk.
    view: View,
}
1040
impl<T: Scalar> TensorShape for TensorGpuView<'_, T> {
    // A view reports the shape of the viewed sub-chunk, not of the full tensor.
    #[inline]
    fn shape(&self) -> Shape {
        self.view.shape
    }
}
1047
impl<T: Scalar> TensorGpuView<'_, T> {
    /// The tensor this view refers into.
    #[inline]
    pub fn tensor(&self) -> &TensorGpu<T, ReadWrite> {
        self.tensor
    }

    /// The context the underlying tensor belongs to.
    #[inline]
    pub fn context(&self) -> &Context {
        self.tensor.context()
    }

    /// The underlying tensor's GPU data (the full buffer, not just the
    /// viewed sub-chunk).
    #[inline]
    pub fn data(&self) -> &TensorGpuData {
        &self.tensor.data
    }

    /// Bind-group layout entry for the view metadata; delegates to the tensor.
    #[inline]
    pub fn meta_layout(&self, binding: u32) -> BindGroupLayoutEntry {
        self.tensor.meta_layout(binding)
    }

    /// Bind-group layout entry for the storage buffer; delegates to the tensor.
    #[inline]
    pub fn layout(&self, binding: u32, read_only: bool) -> BindGroupLayoutEntry {
        self.tensor.layout(binding, read_only)
    }
}
1074
1075impl<T: Scalar> TensorResource for TensorGpuView<'_, T> {
1076    #[inline]
1077    fn resource_key(&self) -> ResourceKey {
1078        ResourceKey {
1079            id: self.tensor.id,
1080            view: self.view,
1081        }
1082    }
1083
1084    #[inline]
1085    fn meta_binding(&self) -> BindingResource<'_> {
1086        BindingResource::Buffer(BufferBinding {
1087            buffer: &self.meta,
1088            offset: 0,
1089            size: None,
1090        })
1091    }
1092
1093    #[inline]
1094    fn binding(&self) -> BindingResource<'_> {
1095        self.tensor.binding()
1096    }
1097}
1098
impl<T: Scalar> TensorScalar for TensorGpuView<'_, T> {
    // A view's scalar element type is that of the viewed tensor.
    type T = T;
}
1102
impl<F: Float> TensorGpuView<'_, F> {
    /// The [`Float::DEF`] string associated with the element type `F`.
    #[inline]
    pub const fn def(&self) -> &'static str {
        F::DEF
    }
}
1109
impl<'a, T: Scalar> From<&'a TensorGpu<T, ReadWrite>> for TensorGpuView<'a, T> {
    // A full-range view (`..` on every axis) over an existing tensor is
    // always in bounds, hence the `unwrap` cannot panic.
    fn from(value: &'a TensorGpu<T, ReadWrite>) -> Self {
        value.view(.., .., .., ..).unwrap()
    }
}
1115
1116impl<T: Scalar> TensorGpu<T, ReadWrite> {
1117    /// Create a view for the tensor.
1118    pub fn view(
1119        &self,
1120        x: impl TensorAxis,
1121        y: impl TensorAxis,
1122        z: impl TensorAxis,
1123        w: impl TensorAxis,
1124    ) -> Result<TensorGpuView<'_, T>, TensorError> {
1125        let slice = (x, y, z, w);
1126        let (start, end) = slice.shaped_bounds(self.shape)?;
1127        let view = View {
1128            stride: self.shape,
1129            offset: start.into(),
1130            shape: (end - start).into(),
1131        };
1132        let meta = self.context.checkout_view_uniform(view);
1133        Ok(TensorGpuView {
1134            tensor: self,
1135            meta,
1136            view,
1137        })
1138    }
1139}
1140
1141impl<T: Scalar> DeepClone for TensorGpu<T, ReadWrite> {
1142    fn deep_clone(&self) -> Self {
1143        let context = &self.context;
1144        let shape = self.shape;
1145        let size = shape.len() as u64;
1146        let cloned: TensorGpu<_, _> = context.tensor_init(shape);
1147
1148        let mut encoder = context.device.create_command_encoder(&Default::default());
1149        encoder.copy_buffer_to_buffer(&self.buffer, 0, &cloned.buffer, 0, size);
1150        context.queue.submit(Some(encoder.finish()));
1151
1152        cloned
1153    }
1154}
1155
/// Stack a batch of tensors of shape `[C, T, 1]` to one with shape `[C, A, 1]`, with cursors information.
#[derive(Debug, Clone)]
pub struct TensorStack<T: Scalar> {
    // The concatenated tensor; axis 1 holds all batches' tokens back to back.
    pub tensor: TensorCpu<T>,
    // One cursor per input batch (including empty ones), recording where each
    // batch's tokens start and how many there are.
    pub cursors: Vec<Cursor>,
}
1162
impl<T: Scalar> TensorStack<T> {
    /// Number of input batches (including empty batches).
    #[inline]
    pub fn num_batch(&self) -> usize {
        self.cursors.len()
    }

    /// Number of non-empty input batches.
    #[inline]
    pub fn num_active_batch(&self) -> usize {
        self.cursors.iter().filter(|cursor| cursor.len > 0).count()
    }

    /// Total number of stacked tokens (length of axis 1).
    #[inline]
    pub fn num_token(&self) -> usize {
        self.tensor.shape[1]
    }
}
1181
impl<T: Scalar> TryFrom<Vec<TensorCpu<T>>> for TensorStack<T> {
    type Error = TensorError;

    /// Concatenate a batch of `[C, T_i, 1, 1]` tensors along axis 1,
    /// recording a [`Cursor`] per batch. Fails on an empty input vector or
    /// when any batch disagrees on `C` (or has a non-unit axis 2 or 3).
    fn try_from(value: Vec<TensorCpu<T>>) -> Result<Self, Self::Error> {
        // Take the channel count `C` from the first batch.
        let shape = match value.first() {
            Some(batch) => batch.shape,
            None => Err(TensorErrorKind::Empty)?,
        };

        // Every batch must be `[C, T_i, 1, 1]`; only axis 1 may vary.
        value
            .iter()
            .try_for_each(|batch| batch.check_shape([shape[0], batch.shape[1], 1, 1]))?;

        // Running-sum scan: each cursor records its batch index, the token
        // offset where the batch starts, and the batch's token count.
        let cursors = value
            .iter()
            .enumerate()
            .scan(0, |token, (batch, tensor)| {
                let len = tensor.shape[1];
                let cursor = Cursor {
                    batch,
                    token: *token,
                    len,
                };
                *token += len;
                Some(cursor)
            })
            .collect_vec();

        // Concatenate the data, accumulating the total token count in axis 1.
        let (shape, data) = value.into_iter().fold(
            (Shape::new(shape[0], 0, 1, 1), vec![]),
            |(mut shape, mut data), tensor| {
                shape[1] += tensor.shape[1];
                data.extend(tensor.data.to_vec());
                (shape, data)
            },
        );
        let data = data.into();

        Ok(Self {
            tensor: Tensor {
                shape,
                data,
                id: uid::Id::new(),
                phantom: PhantomData,
            },
            cursors,
        })
    }
}
1231
1232impl Context {
1233    #[inline]
1234    pub fn zeros<T: Scalar, Tensor>(&self, shape: impl Into<Shape>) -> Tensor
1235    where
1236        TensorCpu<T>: TensorInto<Tensor>,
1237    {
1238        let tensor: TensorCpu<T> = TensorInit::init(shape);
1239        tensor.to(self)
1240    }
1241
1242    #[inline]
1243    pub fn ones<T: Scalar, Tensor>(&self, shape: impl Into<Shape>) -> Tensor
1244    where
1245        TensorCpu<T>: TensorInto<Tensor>,
1246    {
1247        let shape = shape.into();
1248        let data = vec![T::one(); shape.len()];
1249        let tensor: TensorCpu<T> = TensorInit::from_data(shape, data).unwrap();
1250        tensor.to(self)
1251    }
1252
1253    #[inline]
1254    pub fn tensor_from_data<'a, T: Scalar, Tensor: TensorInitContext<T>>(
1255        &self,
1256        shape: impl Into<Shape>,
1257        data: impl Into<Cow<'a, [T]>>,
1258    ) -> Result<Tensor, TensorError> {
1259        TensorInitContext::from_data(self, shape, data)
1260    }
1261
1262    #[inline]
1263    pub fn tensor_init<T: Scalar, Tensor: TensorInitContext<T>>(
1264        &self,
1265        shape: impl Into<Shape>,
1266    ) -> Tensor {
1267        TensorInitContext::init(self, shape)
1268    }
1269}
1270
// Sealed-trait pattern: downstream crates can name the device/kind traits
// bounded by `Sealed`, but cannot implement them, because `Sealed` itself is
// private to this module.
mod sealed {
    use super::{Cpu, Gpu, Kind, ReadWrite, Uniform};
    use crate::num::Scalar;

    pub trait Sealed {}

    // The only permitted devices.
    impl<T: Scalar> Sealed for Cpu<T> {}
    impl<K: Kind> Sealed for Gpu<K> {}

    // The only permitted GPU buffer kinds.
    impl Sealed for Uniform {}
    impl Sealed for ReadWrite {}
}
1283
#[cfg(test)]
mod tests {
    use anyhow::Result;

    use super::Shape;
    use crate::tensor::{TensorCpu, TensorInit, TensorShape};

    // Padding to a multiple of 64 along axis 0: existing values keep their
    // shaped positions and the new tail is zero-filled.
    #[test]
    fn test_pad_64() -> Result<()> {
        let shape = Shape::new(133, 256, 1, 1);
        let x: Vec<_> = (0..shape.len()).map(|x| x as f32).collect();
        let x = TensorCpu::from_data(shape, x)?.pad([64, 64, 1, 1]);

        // 133 rounds up to 192; 256 is already a multiple of 64.
        assert_eq!(x.shape(), Shape::new(192, 256, 1, 1));
        assert_eq!(x[(132, 255, 0, 0)], (shape.len() - 1) as f32);
        assert_eq!(x[(133, 255, 0, 0)], 0.0);

        Ok(())
    }

    // Repeating along each axis of a [5, 1, 2, 1] tensor; note that repeating
    // axes 0 and 1 here yield the same flat data layout.
    #[test]
    fn test_repeat() -> Result<()> {
        let shape = Shape::new(5, 1, 2, 1);
        let x: Vec<_> = (0..shape.len()).map(|x| x as f32).collect();
        let x = TensorCpu::from_data(shape, x)?;

        let y = x.clone().repeat(1, 3);
        let ans = [
            [0.0, 1.0, 2.0, 3.0, 4.0].repeat(3),
            [5.0, 6.0, 7.0, 8.0, 9.0].repeat(3),
        ]
        .concat();
        y.check_shape([5, 3, 2, 1])?;
        assert_eq!(y.to_vec(), ans);

        let y = x.clone().repeat(0, 3);
        y.check_shape([15, 1, 2, 1])?;
        assert_eq!(y.to_vec(), ans);

        let y = x.repeat(2, 3);
        let ans = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0].repeat(3);
        y.check_shape([5, 1, 6, 1])?;
        assert_eq!(y.to_vec(), ans);

        Ok(())
    }

    // Splitting is only valid along an axis whose outer axes are all unit;
    // here only axis 2 qualifies for a [5, 1, 2, 1] tensor.
    #[test]
    fn test_split() -> Result<()> {
        let shape = Shape::new(5, 1, 2, 1);
        let x: Vec<_> = (0..10).map(|x| x as f32).collect();
        let x = TensorCpu::from_data(shape, x)?;

        assert!(x.clone().split(0).is_err());
        assert!(x.clone().split(1).is_err());

        let x = x.split(2)?;
        x[0].check_shape([5, 1, 1, 1])?;
        x[1].check_shape([5, 1, 1, 1])?;
        assert_eq!(x[0].to_vec(), vec![0.0, 1.0, 2.0, 3.0, 4.0]);
        assert_eq!(x[1].to_vec(), vec![5.0, 6.0, 7.0, 8.0, 9.0]);

        Ok(())
    }
}