use std::{borrow::Cow, marker::PhantomData, sync::Arc};

use itertools::Itertools;
use shape::ShapedIndex;
use thiserror::Error;
use wgpu::{
    BindGroupLayoutEntry, BindingResource, BindingType, Buffer, BufferBinding, BufferBindingType,
    BufferUsages, ShaderStages,
};

use self::{
    kind::{Kind, ReadWrite, Uniform},
    shape::{IntoBytes, Shape, TensorAxis, TensorDimension, TensorSlice},
};
use crate::{
    context::Context,
    num::{Float, Scalar},
};

pub mod cache;
pub mod matrix;
pub mod ops;
pub mod serialization;
pub mod shape;

#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct View {
    pub shape: Shape,
    pub stride: Shape,
    pub offset: Shape,
}

impl IntoBytes for View {
    fn into_bytes(self) -> Vec<u8> {
        [
            self.shape.into_bytes(),
            self.stride.into_bytes(),
            self.offset.into_bytes(),
        ]
        .concat()
    }
}

#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Cursor {
    pub batch: usize,
    pub token: usize,
    pub len: usize,
}

impl Cursor {
    pub fn pack(self) -> u32 {
        let batch = self.batch as u8;
        let token = (self.token as u16).to_ne_bytes();
        let len = self.len as u8;
        bytemuck::cast([batch, token[0], token[1], len])
    }
}

pub trait IntoPackedCursors {
    fn into_stack(self) -> Vec<u32>;
    fn into_cursors(self) -> Vec<u32>;
}

impl IntoPackedCursors for Vec<Cursor> {
    fn into_stack(self) -> Vec<u32> {
        self.into_iter()
            .filter(|cursor| cursor.len > 0)
            .map(Cursor::pack)
            .collect()
    }

    fn into_cursors(self) -> Vec<u32> {
        self.into_iter()
            .filter(|cursor| cursor.len > 0)
            .map(|cursor| {
                let repeat = cursor.len;
                vec![cursor.pack(); repeat]
            })
            .collect_vec()
            .concat()
    }
}
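
// A minimal sketch (not part of the original test suite) of the packed cursor
// layout: `pack` stores `batch` in the low byte, the token offset in the next
// two bytes, and `len` in the high byte of the resulting `u32`. The assertions
// below only illustrate that layout.
#[cfg(test)]
#[test]
fn sketch_cursor_pack_layout() {
    let cursor = Cursor {
        batch: 2,
        token: 300,
        len: 5,
    };
    let bytes = cursor.pack().to_ne_bytes();
    assert_eq!(bytes[0], 2);
    assert_eq!(u16::from_ne_bytes([bytes[1], bytes[2]]), 300);
    assert_eq!(bytes[3], 5);
}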

#[derive(Debug)]
pub struct TensorError {
    pub error: TensorErrorKind,
    #[cfg(feature = "backtrace")]
    pub backtrace: std::backtrace::Backtrace,
}

impl TensorError {
    #[cfg(feature = "backtrace")]
    pub fn new(error: TensorErrorKind) -> Self {
        let backtrace = std::backtrace::Backtrace::capture();
        Self { error, backtrace }
    }

    #[cfg(not(feature = "backtrace"))]
    pub fn new(error: TensorErrorKind) -> Self {
        Self { error }
    }
}

impl From<TensorErrorKind> for TensorError {
    fn from(value: TensorErrorKind) -> Self {
        Self::new(value)
    }
}

impl std::fmt::Display for TensorError {
    #[cfg(feature = "backtrace")]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "{}\n\nBacktrace:\n{}", self.error, self.backtrace)
    }

    #[cfg(not(feature = "backtrace"))]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "{}", self.error)
    }
}

impl std::error::Error for TensorError {}

#[derive(Debug, Error)]
pub enum TensorErrorKind {
    #[error("list must not be empty")]
    Empty,
    #[error("data type mismatch")]
    Type,
133 #[error("data size not match: {0} vs. {1}")]
134 Size(usize, usize),
135 #[error("batch size not match: {0} vs. {1}")]
136 Batch(usize, usize),
137 #[error("tensor shape not match: {0} vs. {1}")]
138 Shape(Shape, Shape),
139 #[error("cannot deduce dimension")]
140 Deduce,
141 #[error("batch {batch} out of range of max {max}")]
142 BatchOutOfRange { batch: usize, max: usize },
143 #[error("slice {start}..{end} out of range for dimension size {dim}")]
144 SliceOutOfRange {
145 dim: usize,
146 start: usize,
147 end: usize,
148 },
149 #[error("slice not contiguous")]
150 SliceInvalid,
151 #[error("cannot split along the axis {0}")]
152 SplitInvalid(usize),
153 #[error("possible tensor error(s):\n{0}")]
154 Any(#[from] AnyTensorError),
155}
156
157#[derive(Debug, Error)]
158pub struct AnyTensorError(pub Vec<Box<dyn std::error::Error + Send + Sync>>);
159
160impl std::fmt::Display for AnyTensorError {
161 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162 self.0
163 .iter()
164 .enumerate()
165 .try_for_each(|(index, error)| writeln!(f, "{index}. {error}"))
166 }
167}
168
169pub trait DeepClone: Sized {
170 fn deep_clone(&self) -> Self;
171}
172
173pub trait TensorScalar {
174 type T: Scalar;
175}
176
177pub trait TensorInitContext<T: Scalar>: Sized {
178 fn from_data<'a, S, D>(context: &Context, shape: S, data: D) -> Result<Self, TensorError>
180 where
181 S: Into<Shape>,
182 D: Into<Cow<'a, [T]>>;
183 fn init(context: &Context, shape: impl Into<Shape>) -> Self;
185}
186
187pub trait TensorInit<T: Scalar>: Sized {
188 fn from_data<'a, S, D>(shape: S, data: D) -> Result<Self, TensorError>
190 where
191 S: Into<Shape>,
192 D: Into<Cow<'a, [T]>>;
193 fn init(shape: impl Into<Shape>) -> Self;
195
196 fn from_data_1d<'a>(data: impl Into<Cow<'a, [T]>>) -> Self {
198 let data: Cow<'_, [T]> = data.into();
199 let shape = [data.len(), 1, 1, 1];
200 Self::from_data(shape, data).expect("tensor 1d from data")
201 }
202}
203
204pub trait TensorInto<Into> {
205 fn to(self, context: &Context) -> Into;
206}
207
208pub trait TensorShape: Sized {
209 fn shape(&self) -> Shape;
211
212 fn check_shape(&self, shape: impl Into<Shape>) -> Result<(), TensorError> {
214 let shape = shape.into();
215 (self.shape() == shape)
216 .then_some(())
217 .ok_or(TensorErrorKind::Shape(self.shape(), shape))
218 .map_err(Into::into)
219 }
220
221 fn check_shape_any<S>(&self, shapes: &[S]) -> Result<(), TensorError>
223 where
224 S: Into<Shape> + ToOwned<Owned = S>,
225 {
226 let (oks, errors): (Vec<_>, Vec<_>) = shapes
227 .iter()
228 .map(|shape| self.check_shape(shape.to_owned()).map_err(Into::into))
229 .partition_result();
230 match oks.is_empty() {
231 true => Err(TensorErrorKind::Any(AnyTensorError(errors)))?,
232 false => Ok(()),
233 }
234 }
235}
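
// A minimal sketch (not part of the original test suite): `check_shape`
// requires an exact shape match, while `check_shape_any` succeeds as soon as
// one of the candidate shapes matches and otherwise collects every mismatch
// into a single `TensorErrorKind::Any`.
#[cfg(test)]
#[test]
fn sketch_check_shape() -> Result<(), TensorError> {
    let x: TensorCpu<f32> = TensorInit::init([5, 2, 1, 1]);
    x.check_shape([5, 2, 1, 1])?;
    assert!(x.check_shape([5, 1, 1, 1]).is_err());
    x.check_shape_any(&[[5, 1, 1, 1], [5, 2, 1, 1]])
}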

pub trait TensorReshape: Sized {
    fn reshape(
        &self,
        x: TensorDimension,
        y: TensorDimension,
        z: TensorDimension,
        w: TensorDimension,
    ) -> Result<Self, TensorError>;
}

pub trait TensorResource {
    fn resource_key(&self) -> ResourceKey;
    fn meta_binding(&self) -> BindingResource<'_>;
    fn binding(&self) -> BindingResource<'_>;
}

#[derive(Debug)]
pub struct Tensor<D: Device, T: Scalar> {
    shape: Shape,
    data: D::Data,
    id: uid::Id<TensorId>,
    phantom: PhantomData<T>,
}

#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TensorId;

#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ResourceKey {
    pub id: uid::Id<TensorId>,
    pub view: View,
}

pub trait Device: sealed::Sealed {
    type Data: Clone;
}

#[derive(Debug)]
pub struct Cpu<T: Scalar>(PhantomData<T>);

#[derive(Debug)]
pub struct Gpu<K: Kind>(PhantomData<K>);

impl<T: Scalar> Device for Cpu<T> {
    type Data = Arc<[T]>;
}

impl<K: Kind> Device for Gpu<K> {
    type Data = TensorGpuData;
}

#[derive(Debug, Clone)]
pub struct TensorGpuData {
    pub context: Context,
    pub meta: Arc<Buffer>,
    pub buffer: Arc<Buffer>,
}

pub mod kind {
    use web_rwkv_derive::Kind;
    use wgpu::BufferUsages;

    use super::sealed;

    pub trait Kind: sealed::Sealed {
        fn buffer_usages() -> BufferUsages;
    }

    #[derive(Debug, Kind)]
    #[usage(UNIFORM, COPY_DST, COPY_SRC)]
    pub struct Uniform;

    #[derive(Debug, Kind)]
    #[usage(STORAGE, COPY_DST, COPY_SRC)]
    pub struct ReadWrite;
}

pub type TensorCpu<T> = Tensor<Cpu<T>, T>;
pub type TensorGpu<T, K> = Tensor<Gpu<K>, T>;

impl<D: Device, T: Scalar> Clone for Tensor<D, T> {
    fn clone(&self) -> Self {
        Self {
            shape: self.shape,
            data: self.data.clone(),
            id: self.id,
            phantom: PhantomData,
        }
    }
}

impl<D: Device, T: Scalar> std::ops::Deref for Tensor<D, T> {
    type Target = D::Data;

    #[inline]
    fn deref(&self) -> &Self::Target {
        &self.data
    }
}

impl<D: Device, T: Scalar> TensorScalar for Tensor<D, T> {
    type T = T;
}

impl<D: Device, T: Scalar> Tensor<D, T> {
    #[inline]
    pub fn len(&self) -> usize {
        self.shape.len()
    }

    #[inline]
    pub fn is_empty(&self) -> bool {
        self.shape.is_empty()
    }

    #[inline]
    pub fn size(&self) -> usize {
        self.len() * T::size()
    }

    #[inline]
    pub fn offset(index: usize) -> usize {
        index * T::size()
    }

    #[inline]
    pub fn data(&self) -> &D::Data {
        &self.data
    }

    #[inline]
    pub fn id(&self) -> uid::Id<TensorId> {
        self.id
    }
}

impl<D: Device, F: Float> Tensor<D, F> {
    #[inline]
    pub const fn def(&self) -> &'static str {
        F::DEF
    }
}

impl<T: Scalar> TensorInit<T> for TensorCpu<T> {
    fn from_data<'a, S, D>(shape: S, data: D) -> Result<Self, TensorError>
    where
        S: Into<Shape>,
        D: Into<Cow<'a, [T]>>,
    {
        let shape = shape.into();
        let data: Cow<'_, _> = data.into();
        if shape.len() != data.len() {
            Err(TensorErrorKind::Size(shape.len(), data.len()))?;
        }
        let data = data.into_owned().into();
        Ok(Self {
            shape,
            data,
            id: uid::Id::new(),
            phantom: PhantomData,
        })
    }

    #[inline]
    fn init(shape: impl Into<Shape>) -> Self {
        let shape = shape.into();
        let data = vec![T::zero(); shape.len()].into();
        Self {
            shape,
            data,
            id: uid::Id::new(),
            phantom: PhantomData,
        }
    }
}

impl<T: Scalar> TensorInitContext<T> for TensorCpu<T> {
    fn from_data<'a, S, D>(_context: &Context, shape: S, data: D) -> Result<Self, TensorError>
    where
        S: Into<Shape>,
        D: Into<Cow<'a, [T]>>,
    {
        TensorInit::from_data(shape, data)
    }

    fn init(_context: &Context, shape: impl Into<Shape>) -> Self {
        TensorInit::init(shape)
    }
}

impl<T: Scalar> TensorInto<TensorCpu<T>> for TensorCpu<T> {
    fn to(self, _: &Context) -> Self {
        self
    }
}

impl<T: Scalar> DeepClone for TensorCpu<T> {
    fn deep_clone(&self) -> Self {
        self.clone()
    }
}

impl<T: Scalar> TensorShape for TensorCpu<T> {
    #[inline]
    fn shape(&self) -> Shape {
        self.shape
    }
}

impl<T: Scalar> TensorReshape for TensorCpu<T> {
    #[inline]
    fn reshape(
        &self,
        x: TensorDimension,
        y: TensorDimension,
        z: TensorDimension,
        w: TensorDimension,
    ) -> Result<Self, TensorError> {
        let shape = TensorDimension::deduce(self.shape, x, y, z, w)?;
        Ok(Self {
            shape,
            ..self.clone()
        })
    }
}

impl<T: Scalar, K: Kind> TensorInitContext<T> for TensorGpu<T, K> {
    fn from_data<'a, S, D>(context: &Context, shape: S, data: D) -> Result<Self, TensorError>
    where
        S: Into<Shape>,
        D: Into<Cow<'a, [T]>>,
    {
        let tensor: TensorCpu<T> = TensorInit::from_data(shape, data)?;
        Ok(tensor.to(context))
    }

    fn init(context: &Context, shape: impl Into<Shape>) -> Self {
        let context = context.clone();
        let shape = shape.into();
        let meta = context.checkout_shape_uniform(shape);
        let size = shape.len() * std::mem::size_of::<T>();
        let buffer = context.checkout_buffer(size, K::buffer_usages());
        Self {
            shape,
            data: TensorGpuData {
                context,
                meta,
                buffer,
            },
            id: uid::Id::new(),
            phantom: PhantomData,
        }
    }
}

impl<T: Scalar, K: Kind> TensorInto<TensorGpu<T, K>> for TensorCpu<T> {
    fn to(self, context: &Context) -> TensorGpu<T, K> {
        let Tensor { shape, data, .. } = self;
        let context = context.clone();
        let meta = context.checkout_shape_uniform(shape);
        let contents = bytemuck::cast_slice(&data);
        let buffer = context.checkout_buffer_init(contents, K::buffer_usages());
        TensorGpu {
            shape,
            data: TensorGpuData {
                context,
                meta,
                buffer,
            },
            id: uid::Id::new(),
            phantom: PhantomData,
        }
    }
}

#[cfg(not(target_arch = "wasm32"))]
impl<T: Scalar> TensorInto<TensorGpu<T, ReadWrite>> for TensorGpu<T, ReadWrite> {
    fn to(self, context: &Context) -> Self {
        match context {
            context if context == &self.context => self,
            _ => self.back_in_place().to(context),
        }
    }
}

#[cfg(target_arch = "wasm32")]
impl<T: Scalar> TensorInto<TensorGpu<T, ReadWrite>> for TensorGpu<T, ReadWrite> {
    fn to(self, _: &Context) -> Self {
        self
    }
}

impl<T: Scalar, K: Kind> TensorShape for TensorGpu<T, K> {
    #[inline]
    fn shape(&self) -> Shape {
        self.shape
    }
}

impl<T: Scalar, K: Kind> TensorReshape for TensorGpu<T, K> {
    #[inline]
    fn reshape(
        &self,
        x: TensorDimension,
        y: TensorDimension,
        z: TensorDimension,
        w: TensorDimension,
    ) -> Result<Self, TensorError> {
        let shape = TensorDimension::deduce(self.shape, x, y, z, w)?;
        let context = self.context.clone();
        let meta = context.checkout_shape_uniform(shape);
        let buffer = self.buffer.clone();
        Ok(Self {
            shape,
            data: TensorGpuData {
                context,
                meta,
                buffer,
            },
            ..self.clone()
        })
    }
}

impl<T: Scalar, K: Kind> TensorResource for TensorGpu<T, K> {
    #[inline]
    fn resource_key(&self) -> ResourceKey {
        let id = self.id;
        let view = View {
            shape: self.shape,
            stride: self.shape,
            offset: [0, 0, 0, 0].into(),
        };
        ResourceKey { id, view }
    }

    #[inline]
    fn meta_binding(&self) -> BindingResource<'_> {
        BindingResource::Buffer(BufferBinding {
            buffer: &self.meta,
            offset: 0,
            size: None,
        })
    }

    #[inline]
    fn binding(&self) -> BindingResource<'_> {
        BindingResource::Buffer(BufferBinding {
            buffer: &self.buffer,
            offset: 0,
            size: None,
        })
    }
}

impl<T: Scalar, K: Kind> TensorGpu<T, K> {
    pub fn from_data_u8(
        context: &Context,
        shape: impl Into<Shape>,
        contents: &[u8],
    ) -> Result<Self, TensorError> {
        let shape = shape.into();
        let size = shape.len() * size_of::<T>();
        if contents.len() != size {
            Err(TensorErrorKind::Size(size, contents.len()))?;
        }
        let buffer = context.checkout_buffer_init(contents, K::buffer_usages());
        let meta = context.checkout_shape_uniform(shape);
        Ok(Self {
            shape,
            data: TensorGpuData {
                context: context.clone(),
                meta,
                buffer,
            },
            id: uid::Id::new(),
            phantom: PhantomData,
        })
    }

    #[cfg(not(target_arch = "wasm32"))]
    pub fn back_in_place(&self) -> TensorCpu<T> {
        use crate::context::ContextEvent;

        if self.is_empty() {
            return TensorCpu {
                shape: self.shape,
                data: Arc::new([]),
                id: uid::Id::new(),
                phantom: PhantomData,
            };
        }

        let context = &self.context;
        let size = self.buffer.size();
        let buffer = context.checkout_buffer(
            size as usize,
            BufferUsages::MAP_READ | BufferUsages::COPY_DST,
        );

        let mut encoder = context.device.create_command_encoder(&Default::default());
        encoder.copy_buffer_to_buffer(&self.buffer, 0, &buffer, 0, size);
        context.queue.submit(Some(encoder.finish()));

        let (sender, receiver) = flume::bounded(1);
        let _ = context.event().send(ContextEvent { buffer, sender });
        let data = receiver.recv().expect("failed to receive read back buffer");
        let data = unsafe {
            let data = Box::leak(data);
            let slice = bytemuck::cast_slice_mut::<_, T>(data);
            Box::from_raw(slice)
        };
        let data = data.into_vec().into();
        let shape = self.shape;

        TensorCpu {
            shape,
            data,
            id: uid::Id::new(),
            phantom: PhantomData,
        }
    }

    #[cfg(not(target_arch = "wasm32"))]
    pub async fn back(&self) -> TensorCpu<T> {
        if self.is_empty() {
            return TensorCpu {
                shape: self.shape,
                data: Arc::new([]),
                id: uid::Id::new(),
                phantom: PhantomData,
            };
        }

        let context = &self.context;
        let size = self.buffer.size();
        let buffer = context.checkout_buffer(
            size as usize,
            BufferUsages::MAP_READ | BufferUsages::COPY_DST,
        );

        let mut encoder = context.device.create_command_encoder(&Default::default());
        encoder.copy_buffer_to_buffer(&self.buffer, 0, &buffer, 0, size);
        context.queue.submit(Some(encoder.finish()));

        let (sender, receiver) = flume::bounded(1);

        let _ = context
            .event()
            .send(crate::context::ContextEvent { buffer, sender });
        let data = receiver
            .recv_async()
            .await
            .expect("failed to receive read back buffer");
        let data = unsafe {
            let data = Box::leak(data);
            let slice = bytemuck::cast_slice_mut::<_, T>(data);
            Box::from_raw(slice)
        };
        let data = data.into_vec().into();

        TensorCpu {
            shape: self.shape,
            data,
            id: uid::Id::new(),
            phantom: PhantomData,
        }
    }

    #[cfg(target_arch = "wasm32")]
    pub async fn back(self) -> TensorCpu<T> {
        if self.is_empty() {
            return TensorCpu {
                shape: self.shape,
                data: Arc::new([]),
                id: uid::Id::new(),
                phantom: PhantomData,
            };
        }

        let context = &self.context;
        let size = self.buffer.size();
        let buffer = context.checkout_buffer(
            size as usize,
            BufferUsages::MAP_READ | BufferUsages::COPY_DST,
        );

        let mut encoder = context.device.create_command_encoder(&Default::default());
        encoder.copy_buffer_to_buffer(&self.buffer, 0, &buffer, 0, size);
        let submission_index = Some(context.queue.submit(Some(encoder.finish())));

        let (sender, receiver) = flume::unbounded();

        let slice = buffer.slice(..);
        slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());

        _ = context.device.poll(wgpu::PollType::Wait {
            submission_index,
            timeout: None,
        });
        receiver
            .recv_async()
            .await
            .expect("failed to receive read back buffer")
            .expect("failed to map buffer");

        let data = {
            let map = slice.get_mapped_range();
            Vec::from(bytemuck::cast_slice(&map)).into()
        };
        buffer.unmap();

        TensorCpu {
            shape: self.shape,
            data,
            id: uid::Id::new(),
            phantom: PhantomData,
        }
    }
}
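
// A minimal sketch (not part of the original API) of a round trip: upload host
// data to a storage buffer with `TensorInto::to`, then copy it back through the
// staging path driven by the context's event loop. An already-created
// `Context` is assumed; the shape and data are illustrative only.
#[cfg(not(target_arch = "wasm32"))]
#[allow(dead_code)]
async fn sketch_upload_read_back(context: &Context) -> Result<(), TensorError> {
    let host: TensorCpu<f32> = TensorInit::from_data([4, 2, 1, 1], vec![0.5; 8])?;
    let device: TensorGpu<f32, ReadWrite> = host.to(context);
    let back = device.back().await;
    back.check_shape([4, 2, 1, 1])
}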

impl<T: Scalar, K: Kind> TensorGpu<T, K> {
    #[inline]
    pub fn context(&self) -> &Context {
        &self.context
    }

    pub fn load(&self, host: &TensorCpu<T>) -> Result<(), TensorError> {
        host.check_shape(self.shape)?;
        self.context
            .queue
            .write_buffer(&self.buffer, 0, bytemuck::cast_slice(&host.data[..]));
        Ok(())
    }

    pub fn load_batch(&self, host: &TensorCpu<T>, batch: usize) -> Result<(), TensorError> {
        host.check_shape([self.shape[0], self.shape[1], 1, 1])?;
        if batch >= self.shape[2] {
            Err(TensorErrorKind::BatchOutOfRange {
                batch,
                max: self.shape[2],
            })?;
        }
        let offset = (T::size() * self.shape[0] * self.shape[1] * batch) as u64;
        self.context
            .queue
            .write_buffer(&self.buffer, offset, bytemuck::cast_slice(&host.data[..]));
        Ok(())
    }

    pub fn destroy(self) {
        self.buffer.destroy();
    }
}
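
// A minimal sketch (not part of the original API) of `load`: allocate a
// storage tensor of the right shape, then overwrite its buffer with host data
// of the same shape. An already-created `Context` is assumed.
#[allow(dead_code)]
fn sketch_load(context: &Context) -> Result<(), TensorError> {
    let device: TensorGpu<f32, ReadWrite> = context.tensor_init([4, 2, 1, 1]);
    let host: TensorCpu<f32> = TensorInit::from_data([4, 2, 1, 1], vec![1.0; 8])?;
    device.load(&host)
}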

impl<T: Scalar> TensorGpu<T, Uniform> {
    #[inline]
    pub fn layout(&self, binding: u32) -> BindGroupLayoutEntry {
        BindGroupLayoutEntry {
            binding,
            visibility: ShaderStages::COMPUTE,
            ty: BindingType::Buffer {
                ty: BufferBindingType::Uniform,
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        }
    }
}

impl<T: Scalar> TensorGpu<T, ReadWrite> {
    #[inline]
    pub fn meta_layout(&self, binding: u32) -> BindGroupLayoutEntry {
        BindGroupLayoutEntry {
            binding,
            visibility: ShaderStages::COMPUTE,
            ty: BindingType::Buffer {
                ty: BufferBindingType::Uniform,
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        }
    }

    #[inline]
    pub fn layout(&self, binding: u32, read_only: bool) -> BindGroupLayoutEntry {
        BindGroupLayoutEntry {
            binding,
            visibility: ShaderStages::COMPUTE,
            ty: BindingType::Buffer {
                ty: BufferBindingType::Storage { read_only },
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        }
    }
}

impl<T: Scalar> From<TensorCpu<T>> for Vec<T> {
    #[inline]
    fn from(value: TensorCpu<T>) -> Self {
        value.to_vec()
    }
}

impl<T: Scalar, S: Into<ShapedIndex>> std::ops::Index<S> for TensorCpu<T> {
    type Output = T;

    fn index(&self, index: S) -> &Self::Output {
        &self.data[self.shape.linear_index(index)]
    }
}

impl<T: Scalar> TensorCpu<T> {
    pub fn map<U: Scalar>(self, f: impl FnMut(&T) -> U) -> TensorCpu<U> {
        let Self { shape, data, .. } = self;
        let data = data.iter().map(f).collect_vec();
        TensorInit::from_data(shape, &data).unwrap()
    }

    pub fn pad(self, multiples: impl Into<Shape>) -> Self {
        let multiples: Shape = multiples.into();

        let mut shape = self.shape;
        for (axis, multiple) in multiples.iter().enumerate() {
            shape[axis] = shape[axis].next_multiple_of(multiple);
        }

        let mut data = vec![T::zero(); shape.len()];
        for index in self.shape.cartesian_product() {
            let value = self[index];
            data[shape.linear_index(index)] = value;
        }
        TensorInit::from_data(shape, data).unwrap()
    }

    pub fn repeat(self, axis: usize, repeat: usize) -> Self {
        let mut shape = self.shape;
        let data = self.data;

        let num_chunk: usize = shape.iter().skip(axis + 1).product();
        let chunk_size = data.len() / num_chunk;

        shape[axis] *= repeat;

        let data = (0..num_chunk)
            .map(|chunk| {
                let start = chunk * chunk_size;
                let end = start + chunk_size;
                let chunk = data[start..end].to_vec();
                chunk.repeat(repeat)
            })
            .collect_vec()
            .concat()
            .into();

        Self {
            shape,
            data,
            id: uid::Id::new(),
            phantom: PhantomData,
        }
    }

    pub fn stack(batches: Vec<Self>, axis: usize) -> Result<Self, TensorError> {
        let mut shape = match batches.first() {
            Some(batch) => batch.shape,
            None => Err(TensorErrorKind::Empty)?,
        };

        batches.iter().try_for_each(|batch| match axis {
            0 => batch.check_shape([batch.shape[0], 1, 1, 1]),
            1 => batch.check_shape([shape[0], batch.shape[1], 1, 1]),
            2 => batch.check_shape([shape[0], shape[1], batch.shape[2], 1]),
            3 => batch.check_shape([shape[0], shape[1], shape[2], batch.shape[3]]),
            _ => unreachable!(),
        })?;

        let num_batch: usize = batches.iter().map(|batch| batch.shape[axis]).sum();
        shape[axis] = num_batch;

        let data = batches
            .into_iter()
            .map(|batch| batch.data.to_vec())
            .collect_vec()
            .concat()
            .into();

        Ok(Self {
            shape,
            data,
            id: uid::Id::new(),
            phantom: PhantomData,
        })
    }

    pub fn split(self, axis: usize) -> Result<Vec<Self>, TensorError> {
        if self.shape.iter().skip(axis + 1).any(|dim| dim > 1) {
            Err(TensorErrorKind::SplitInvalid(axis))?;
        }

        (0..self.shape[axis])
            .map(|index| match axis {
                0 => self.slice(index, .., .., ..),
                1 => self.slice(.., index, .., ..),
                2 => self.slice(.., .., index, ..),
                3 => self.slice(.., .., .., index),
                _ => Err(TensorErrorKind::SplitInvalid(axis))?,
            })
            .try_collect()
    }

    pub fn slice(
        &self,
        x: impl TensorAxis,
        y: impl TensorAxis,
        z: impl TensorAxis,
        w: impl TensorAxis,
    ) -> Result<Self, TensorError> {
        let slice = (x, y, z, w);
        let (start, end) = slice.shaped_bounds(self.shape)?;
        let shape = (end - start).into();

        let (start, end) = slice.linear_bounds(self.shape)?;
        let data = self.data[start..end].into();

        Ok(Self {
            shape,
            data,
            id: uid::Id::new(),
            phantom: PhantomData,
        })
    }

    pub fn into_slice(
        self,
        x: impl TensorAxis,
        y: impl TensorAxis,
        z: impl TensorAxis,
        w: impl TensorAxis,
    ) -> Result<Self, TensorError> {
        let slice = (x, y, z, w);
        let (start, end) = slice.shaped_bounds(self.shape)?;
        let shape = (end - start).into();

        let (start, end) = slice.linear_bounds(self.shape)?;
        let data = self.data[start..end].into();

        Ok(Self {
            shape,
            data,
            id: uid::Id::new(),
            phantom: PhantomData,
        })
    }
}

#[derive(Debug, Clone)]
pub struct TensorGpuView<'a, T: Scalar> {
    tensor: &'a TensorGpu<T, ReadWrite>,
    meta: Arc<Buffer>,
    view: View,
}

impl<T: Scalar> TensorShape for TensorGpuView<'_, T> {
    #[inline]
    fn shape(&self) -> Shape {
        self.view.shape
    }
}

impl<T: Scalar> TensorGpuView<'_, T> {
    #[inline]
    pub fn tensor(&self) -> &TensorGpu<T, ReadWrite> {
        self.tensor
    }

    #[inline]
    pub fn context(&self) -> &Context {
        self.tensor.context()
    }

    #[inline]
    pub fn data(&self) -> &TensorGpuData {
        &self.tensor.data
    }

    #[inline]
    pub fn meta_layout(&self, binding: u32) -> BindGroupLayoutEntry {
        self.tensor.meta_layout(binding)
    }

    #[inline]
    pub fn layout(&self, binding: u32, read_only: bool) -> BindGroupLayoutEntry {
        self.tensor.layout(binding, read_only)
    }
}

impl<T: Scalar> TensorResource for TensorGpuView<'_, T> {
    #[inline]
    fn resource_key(&self) -> ResourceKey {
        ResourceKey {
            id: self.tensor.id,
            view: self.view,
        }
    }

    #[inline]
    fn meta_binding(&self) -> BindingResource<'_> {
        BindingResource::Buffer(BufferBinding {
            buffer: &self.meta,
            offset: 0,
            size: None,
        })
    }

    #[inline]
    fn binding(&self) -> BindingResource<'_> {
        self.tensor.binding()
    }
}

impl<T: Scalar> TensorScalar for TensorGpuView<'_, T> {
    type T = T;
}

impl<F: Float> TensorGpuView<'_, F> {
    #[inline]
    pub const fn def(&self) -> &'static str {
        F::DEF
    }
}

impl<'a, T: Scalar> From<&'a TensorGpu<T, ReadWrite>> for TensorGpuView<'a, T> {
    fn from(value: &'a TensorGpu<T, ReadWrite>) -> Self {
        value.view(.., .., .., ..).unwrap()
    }
}

impl<T: Scalar> TensorGpu<T, ReadWrite> {
    pub fn view(
        &self,
        x: impl TensorAxis,
        y: impl TensorAxis,
        z: impl TensorAxis,
        w: impl TensorAxis,
    ) -> Result<TensorGpuView<'_, T>, TensorError> {
        let slice = (x, y, z, w);
        let (start, end) = slice.shaped_bounds(self.shape)?;
        let view = View {
            stride: self.shape,
            offset: start.into(),
            shape: (end - start).into(),
        };
        let meta = self.context.checkout_view_uniform(view);
        Ok(TensorGpuView {
            tensor: self,
            meta,
            view,
        })
    }
}
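
// A minimal sketch (not part of the original API): a view shares the tensor's
// buffer and only checks out a new metadata uniform describing shape, stride,
// and offset, so a full-range view preserves the original shape. The tensor
// parameter is assumed to come from elsewhere.
#[allow(dead_code)]
fn sketch_full_view(tensor: &TensorGpu<f32, ReadWrite>) -> Result<(), TensorError> {
    let view = tensor.view(.., .., .., ..)?;
    assert_eq!(view.shape(), tensor.shape());
    Ok(())
}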

impl<T: Scalar> DeepClone for TensorGpu<T, ReadWrite> {
    fn deep_clone(&self) -> Self {
        let context = &self.context;
        let shape = self.shape;
        // the copy size is in bytes, not elements
        let size = self.size() as u64;
        let cloned: TensorGpu<_, _> = context.tensor_init(shape);

        let mut encoder = context.device.create_command_encoder(&Default::default());
        encoder.copy_buffer_to_buffer(&self.buffer, 0, &cloned.buffer, 0, size);
        context.queue.submit(Some(encoder.finish()));

        cloned
    }
}

#[derive(Debug, Clone)]
pub struct TensorStack<T: Scalar> {
    pub tensor: TensorCpu<T>,
    pub cursors: Vec<Cursor>,
}

impl<T: Scalar> TensorStack<T> {
    #[inline]
    pub fn num_batch(&self) -> usize {
        self.cursors.len()
    }

    #[inline]
    pub fn num_active_batch(&self) -> usize {
        self.cursors.iter().filter(|cursor| cursor.len > 0).count()
    }

    #[inline]
    pub fn num_token(&self) -> usize {
        self.tensor.shape[1]
    }
}

impl<T: Scalar> TryFrom<Vec<TensorCpu<T>>> for TensorStack<T> {
    type Error = TensorError;

    fn try_from(value: Vec<TensorCpu<T>>) -> Result<Self, Self::Error> {
        let shape = match value.first() {
            Some(batch) => batch.shape,
            None => Err(TensorErrorKind::Empty)?,
        };

        value
            .iter()
            .try_for_each(|batch| batch.check_shape([shape[0], batch.shape[1], 1, 1]))?;

        let cursors = value
            .iter()
            .enumerate()
            .scan(0, |token, (batch, tensor)| {
                let len = tensor.shape[1];
                let cursor = Cursor {
                    batch,
                    token: *token,
                    len,
                };
                *token += len;
                Some(cursor)
            })
            .collect_vec();

        let (shape, data) = value.into_iter().fold(
            (Shape::new(shape[0], 0, 1, 1), vec![]),
            |(mut shape, mut data), tensor| {
                shape[1] += tensor.shape[1];
                data.extend(tensor.data.to_vec());
                (shape, data)
            },
        );
        let data = data.into();

        Ok(Self {
            tensor: Tensor {
                shape,
                data,
                id: uid::Id::new(),
                phantom: PhantomData,
            },
            cursors,
        })
    }
}
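
// A minimal sketch (not part of the original test suite): building a
// `TensorStack` from per-batch tensors concatenates them along the token axis
// and records one cursor per batch; empty batches stay in the cursor list but
// do not count as active.
#[cfg(test)]
#[test]
fn sketch_tensor_stack() -> Result<(), TensorError> {
    let a: TensorCpu<f32> = TensorInit::from_data([4, 2, 1, 1], vec![0.0; 8])?;
    let b: TensorCpu<f32> = TensorInit::from_data([4, 0, 1, 1], vec![])?;
    let c: TensorCpu<f32> = TensorInit::from_data([4, 3, 1, 1], vec![0.0; 12])?;

    let stack = TensorStack::try_from(vec![a, b, c])?;
    assert_eq!(stack.num_batch(), 3);
    assert_eq!(stack.num_active_batch(), 2);
    assert_eq!(stack.num_token(), 5);
    Ok(())
}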

impl Context {
    #[inline]
    pub fn zeros<T: Scalar, Tensor>(&self, shape: impl Into<Shape>) -> Tensor
    where
        TensorCpu<T>: TensorInto<Tensor>,
    {
        let tensor: TensorCpu<T> = TensorInit::init(shape);
        tensor.to(self)
    }

    #[inline]
    pub fn ones<T: Scalar, Tensor>(&self, shape: impl Into<Shape>) -> Tensor
    where
        TensorCpu<T>: TensorInto<Tensor>,
    {
        let shape = shape.into();
        let data = vec![T::one(); shape.len()];
        let tensor: TensorCpu<T> = TensorInit::from_data(shape, data).unwrap();
        tensor.to(self)
    }

    #[inline]
    pub fn tensor_from_data<'a, T: Scalar, Tensor: TensorInitContext<T>>(
        &self,
        shape: impl Into<Shape>,
        data: impl Into<Cow<'a, [T]>>,
    ) -> Result<Tensor, TensorError> {
        TensorInitContext::from_data(self, shape, data)
    }

    #[inline]
    pub fn tensor_init<T: Scalar, Tensor: TensorInitContext<T>>(
        &self,
        shape: impl Into<Shape>,
    ) -> Tensor {
        TensorInitContext::init(self, shape)
    }
}
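
// A minimal sketch (not part of the original API) of the `Context` helpers
// above, assuming an already-created `Context`; the shapes and the `f32` data
// type are illustrative only.
#[allow(dead_code)]
fn sketch_context_helpers(context: &Context) -> Result<(), TensorError> {
    // Zero- and one-filled tensors initialized on the CPU and uploaded to the GPU.
    let zeros: TensorGpu<f32, ReadWrite> = context.zeros([4, 4, 1, 1]);
    let ones: TensorGpu<f32, ReadWrite> = context.ones([4, 4, 1, 1]);
    // A tensor created directly from host data; the data length must match the
    // shape or `from_data` reports a size mismatch.
    let data: TensorGpu<f32, ReadWrite> =
        context.tensor_from_data([2, 2, 1, 1], vec![0.0f32, 1.0, 2.0, 3.0])?;
    zeros.check_shape([4, 4, 1, 1])?;
    ones.check_shape([4, 4, 1, 1])?;
    data.check_shape([2, 2, 1, 1])
}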

mod sealed {
    use super::{Cpu, Gpu, Kind, ReadWrite, Uniform};
    use crate::num::Scalar;

    pub trait Sealed {}

    impl<T: Scalar> Sealed for Cpu<T> {}
    impl<K: Kind> Sealed for Gpu<K> {}

    impl Sealed for Uniform {}
    impl Sealed for ReadWrite {}
}

#[cfg(test)]
mod tests {
    use anyhow::Result;

    use super::Shape;
    use crate::tensor::{TensorCpu, TensorInit, TensorShape};

    #[test]
    fn test_pad_64() -> Result<()> {
        let shape = Shape::new(133, 256, 1, 1);
        let x: Vec<_> = (0..shape.len()).map(|x| x as f32).collect();
        let x = TensorCpu::from_data(shape, x)?.pad([64, 64, 1, 1]);

        assert_eq!(x.shape(), Shape::new(192, 256, 1, 1));
        assert_eq!(x[(132, 255, 0, 0)], (shape.len() - 1) as f32);
        assert_eq!(x[(133, 255, 0, 0)], 0.0);

        Ok(())
    }

    #[test]
    fn test_repeat() -> Result<()> {
        let shape = Shape::new(5, 1, 2, 1);
        let x: Vec<_> = (0..shape.len()).map(|x| x as f32).collect();
        let x = TensorCpu::from_data(shape, x)?;

        let y = x.clone().repeat(1, 3);
        let ans = [
            [0.0, 1.0, 2.0, 3.0, 4.0].repeat(3),
            [5.0, 6.0, 7.0, 8.0, 9.0].repeat(3),
        ]
        .concat();
        y.check_shape([5, 3, 2, 1])?;
        assert_eq!(y.to_vec(), ans);

        let y = x.clone().repeat(0, 3);
        y.check_shape([15, 1, 2, 1])?;
        assert_eq!(y.to_vec(), ans);

        let y = x.repeat(2, 3);
        let ans = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0].repeat(3);
        y.check_shape([5, 1, 6, 1])?;
        assert_eq!(y.to_vec(), ans);

        Ok(())
    }

    #[test]
    fn test_split() -> Result<()> {
        let shape = Shape::new(5, 1, 2, 1);
        let x: Vec<_> = (0..10).map(|x| x as f32).collect();
        let x = TensorCpu::from_data(shape, x)?;

        assert!(x.clone().split(0).is_err());
        assert!(x.clone().split(1).is_err());

        let x = x.split(2)?;
        x[0].check_shape([5, 1, 1, 1])?;
        x[1].check_shape([5, 1, 1, 1])?;
        assert_eq!(x[0].to_vec(), vec![0.0, 1.0, 2.0, 3.0, 4.0]);
        assert_eq!(x[1].to_vec(), vec![5.0, 6.0, 7.0, 8.0, 9.0]);

        Ok(())
    }
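
    // A minimal sketch added alongside the existing tests (not part of the
    // original suite): `map` applies an element-wise function, and `stack`
    // concatenates tensors along a chosen axis, here the batch axis (2).
    #[test]
    fn test_map_stack() -> Result<()> {
        let x = TensorCpu::from_data([2, 2, 1, 1], vec![0.0f32, 1.0, 2.0, 3.0])?;

        let y = x.clone().map(|v| v + 1.0);
        y.check_shape([2, 2, 1, 1])?;
        assert_eq!(y.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);

        let z = TensorCpu::stack(vec![x.clone(), x], 2)?;
        z.check_shape([2, 2, 2, 1])?;
        assert_eq!(z.to_vec(), vec![0.0, 1.0, 2.0, 3.0, 0.0, 1.0, 2.0, 3.0]);

        Ok(())
    }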
}