web_rwkv/tensor/
ops.rs

1use std::{hash::Hash, sync::Arc};
2
3use embed_doc_image::embed_doc_image;
4use half::f16;
5use serde::{Deserialize, Serialize};
6use wgpu::{BindGroup, CommandBuffer, CommandEncoder, ComputePass};
7
8use super::{
9    kind::{Kind, ReadWrite, Uniform},
10    Shape, TensorError, TensorErrorKind, TensorGpu, TensorGpuView, TensorScalar, TensorShape,
11};
12use crate::{
13    context::{BindGroupBuilder, CachedPipeline, Macros, PipelineKey},
14    num::{Float, Scalar},
15    tensor::{shape::TensorDimension, TensorReshape},
16};
17
18pub trait TensorCommand<T: Scalar, K: Kind> {
19    fn copy_tensor(
20        &mut self,
21        source: &TensorGpu<T, K>,
22        destination: &TensorGpu<T, K>,
23    ) -> Result<(), TensorError>;
24
25    fn copy_tensor_batch(
26        &mut self,
27        source: &TensorGpu<T, K>,
28        destination: &TensorGpu<T, K>,
29        from: usize,
30        to: usize,
31    ) -> Result<(), TensorError>;
32}
33
34impl<T: Scalar, K: Kind> TensorCommand<T, K> for CommandEncoder {
35    fn copy_tensor(
36        &mut self,
37        source: &TensorGpu<T, K>,
38        destination: &TensorGpu<T, K>,
39    ) -> Result<(), TensorError> {
40        destination.check_shape(source.shape())?;
41        let size = destination.size() as u64;
42        self.copy_buffer_to_buffer(&source.buffer, 0, &destination.buffer, 0, size);
43        Ok(())
44    }
45
46    fn copy_tensor_batch(
47        &mut self,
48        source: &TensorGpu<T, K>,
49        destination: &TensorGpu<T, K>,
50        from: usize,
51        to: usize,
52    ) -> Result<(), TensorError> {
53        source.check_shape([source.shape[0], source.shape[1], source.shape[2], 1])?;
54        destination.check_shape([source.shape[0], source.shape[1], destination.shape[2], 1])?;
55        if from >= source.shape[2] {
56            Err(TensorErrorKind::BatchOutOfRange {
57                batch: from,
58                max: source.shape[2],
59            })?;
60        }
61        if to > destination.shape[2] {
62            Err(TensorErrorKind::BatchOutOfRange {
63                batch: to,
64                max: destination.shape[2],
65            })?;
66        }
67        self.copy_buffer_to_buffer(
68            &source.buffer,
69            (T::size() * source.shape[0] * source.shape[1] * from) as u64,
70            &destination.buffer,
71            (T::size() * destination.shape[0] * destination.shape[1] * to) as u64,
72            (T::size() * source.shape[0] * source.shape[1]) as u64,
73        );
74        Ok(())
75    }
76}
77
78impl crate::context::Context {
79    pub fn encode(&self, op: &TensorOp) -> Vec<CommandBuffer> {
80        struct Atom<'a> {
81            pipeline: &'a CachedPipeline,
82            bindings: &'a [Arc<BindGroup>],
83            dispatch: &'a [u32; 3],
84        }
85
86        fn dispatch<'b, 'a: 'b>(
87            pass: &'b mut ComputePass<'a>,
88            Atom {
89                pipeline,
90                bindings,
91                dispatch,
92            }: Atom<'a>,
93        ) {
94            pass.set_pipeline(&pipeline.pipeline);
95            for (index, bind) in bindings.iter().enumerate() {
96                pass.set_bind_group(index as u32, &**bind, &[]);
97            }
98            pass.dispatch_workgroups(dispatch[0], dispatch[1], dispatch[2]);
99        }
100
101        fn flatten<'b, 'a: 'b>(
102            commands: &'b mut Vec<Vec<Atom<'a>>>,
103            passes: &'b mut Vec<Atom<'a>>,
104            op: &'a TensorOp,
105        ) {
106            match op {
107                TensorOp::Atom {
108                    pipeline,
109                    bindings,
110                    dispatch,
111                } => passes.push(Atom {
112                    pipeline,
113                    bindings,
114                    dispatch,
115                }),
116                TensorOp::List(ops) => ops.iter().for_each(|op| flatten(commands, passes, op)),
117                TensorOp::Sep => {
118                    let mut temp = vec![];
119                    std::mem::swap(&mut temp, passes);
120                    commands.push(temp);
121                }
122            }
123        }
124
125        let mut commands = vec![];
126        let mut passes = vec![];
127        flatten(&mut commands, &mut passes, op);
128        commands.push(passes);
129
130        commands
131            .into_iter()
132            .filter(|atoms| !atoms.is_empty())
133            .map(|atoms| {
134                let mut encoder = self.device.create_command_encoder(&Default::default());
135                let mut pass = encoder.begin_compute_pass(&Default::default());
136                for atom in atoms {
137                    dispatch(&mut pass, atom);
138                }
139                drop(pass);
140                encoder.finish()
141            })
142            .collect()
143    }
144}
145
146#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
147#[serde(rename_all = "snake_case")]
148pub enum Activation {
149    #[default]
150    #[serde(rename = "")]
151    None,
152    SquaredRelu,
153    #[serde(rename = "custom_tanh")]
154    Tanh,
155    StableExp,
156    OppositeExp,
157    Softplus,
158    Sigmoid,
159    Silu,
160}
161
162impl std::fmt::Display for Activation {
163    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164        f.write_str(serde_variant::to_variant_name(self).unwrap())
165    }
166}
167
168impl Macros {
169    /// Define a `u32` macro `NF4_BLOCK_SIZE`.
170    pub fn nf4(mut self, block_size: u32) -> Self {
171        self.insert("NF4_BLOCK_SIZE".into(), format!("{block_size}u"));
172        self
173    }
174
175    /// Define a `u32` macro `NF4_BLOCK_SIZE`.
176    pub fn int8(mut self, block_size: u32) -> Self {
177        self.insert("INT8_BLOCK_SIZE".into(), format!("{block_size}u"));
178        self
179    }
180
181    /// Define a `f32` macro with a given name.
182    pub fn f32(mut self, name: impl Into<String>, value: f32) -> Self {
183        self.insert(name.into(), format!("{value}"));
184        self
185    }
186
187    /// Define a `usize` macro with a given name.
188    pub fn u32(mut self, name: impl Into<String>, value: u32) -> Self {
189        self.insert(name.into(), format!("{value}u"));
190        self
191    }
192
193    /// Define a `bool` macro with a given name.
194    pub fn bool(mut self, name: impl Into<String>, value: bool) -> Self {
195        match value {
196            true => {
197                self.insert(name.into(), Default::default());
198                self
199            }
200            false => self,
201        }
202    }
203
204    pub fn activate(mut self, name: impl Into<String>, value: Activation) -> Self {
205        const ACTIVATION_DEFINE: &str = "
206fn squared_relu(x: vec4<f32>) -> vec4<f32> {
207    let p = max(x, vec4<f32>(0.0));
208    return p * p;
209}
210
211fn stable_exp(x: vec4<f32>) -> vec4<f32> {
212    return exp(-exp(x));
213}
214
215fn opposite_exp(x: vec4<f32>) -> vec4<f32> {
216    return -exp(x);
217}
218
219fn softplus(x: vec4<f32>) -> vec4<f32> {
220    return log(1.0 + exp(x));
221}
222
223fn sigmoid(x: vec4<f32>) -> vec4<f32> {
224    return 1.0 / (1.0 + exp(-x));
225}
226
227fn silu(x: vec4<f32>) -> vec4<f32> {
228    return x / (1.0 + exp(-x));
229}
230
231// Metal has some trouble with `tanh`.
232fn custom_tanh(x: vec4<f32>) -> vec4<f32> {
233    return select(tanh(x), vec4<f32>(1.0), x > vec4<f32>(42.0));
234}
235";
236        self.insert("ACTIVATION_DEFINE".into(), ACTIVATION_DEFINE.to_string());
237        self.insert(name.into(), value.to_string());
238        self
239    }
240
241    /// Define the macro specifies input/output tensor data type.
242    pub fn tensor<T: Float>(
243        mut self,
244        _tensor: &impl TensorScalar<T = T>,
245        prefix: Option<&'_ str>,
246    ) -> Self {
247        match prefix {
248            None => self.insert(T::DEF.into(), Default::default()),
249            Some(prefix) => self.insert(format!("{}_{}", prefix, T::DEF), Default::default()),
250        };
251        self
252    }
253
254    /// Define a macro with custom display name and prefix.
255    pub fn custom(mut self, value: impl std::fmt::Display, prefix: Option<&'_ str>) -> Self {
256        match prefix {
257            None => self.insert(format!("{value}"), Default::default()),
258            Some(prefix) => self.insert(format!("{prefix}_{value}"), Default::default()),
259        };
260        self
261    }
262
263    /// Add a define when `condition` is true.
264    pub fn define(mut self, name: impl Into<String>, condition: bool) -> Self {
265        if condition {
266            self.insert(name.into(), Default::default());
267        }
268        self
269    }
270
271    /// Add subgroup defines.
272    #[cfg(feature = "subgroup-ops")]
273    pub fn subgroup(self, min: u32, max: u32) -> Self {
274        self.u32("MIN_SUBGROUP_SIZE", min)
275            .u32("MAX_SUBGROUP_SIZE", max)
276            .define(format!("SUBGROUP_SIZE_{min}_{max}"), true)
277    }
278}
279
280pub enum TensorOp {
281    Atom {
282        pipeline: Arc<CachedPipeline>,
283        bindings: Vec<Arc<BindGroup>>,
284        dispatch: [u32; 3],
285    },
286    List(Vec<TensorOp>),
287    Sep,
288}
289
290impl TensorOp {
291    pub const NF4_BLOCK_SIZE: u32 = 64;
292    pub const INT8_BLOCK_SIZE: u32 = 128;
293
294    #[inline]
295    pub fn empty() -> Self {
296        Self::List(vec![])
297    }
298
299    /// Softmax operator applied on `x`.
300    pub fn softmax(x: &TensorGpu<impl Float, ReadWrite>) -> Result<Self, TensorError> {
301        const BLOCK_SIZE: u32 = 128;
302
303        let context = x.context();
304        let shape = x.shape();
305
306        #[cfg(not(feature = "subgroup-ops"))]
307        let key = PipelineKey::new(
308            "softmax",
309            "softmax",
310            Macros::new().u32("BLOCK_SIZE", BLOCK_SIZE).tensor(x, None),
311        );
312        #[cfg(feature = "subgroup-ops")]
313        let key = PipelineKey::new(
314            "softmax",
315            "softmax",
316            Macros::new()
317                .subgroup(context.min_subgroup_size(), context.max_subgroup_size())
318                .u32("BLOCK_SIZE", BLOCK_SIZE)
319                .tensor(x, None),
320        );
321
322        #[cfg(not(feature = "subgroup-ops"))]
323        let pipeline = context.checkout_pipeline(
324            &key,
325            include_str!("../shaders/softmax.wgsl"),
326            &[x.meta_layout(0), x.layout(1, false)],
327        );
328        #[cfg(feature = "subgroup-ops")]
329        let pipeline = context.checkout_pipeline(
330            &key,
331            include_str!("../shaders/subgroup/softmax.wgsl"),
332            &[x.meta_layout(0), x.layout(1, false)],
333        );
334        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
335            .bind_meta(0, x)
336            .bind(1, x)
337            .build()];
338
339        Ok(Self::Atom {
340            pipeline,
341            bindings,
342            dispatch: [1, shape[1] as u32, shape[2] as u32],
343        })
344    }
345
346    /// Embedding on GPU.
347    /// - `tokens` shape: `[T, B]`.
348    /// - `input` shape: `[C, V]`.
349    /// - `output` shape: `[C, T, B]`.
350    pub fn embed(
351        tokens: &TensorGpu<u32, ReadWrite>,
352        input: &TensorGpu<f16, ReadWrite>,
353        output: &TensorGpu<impl Float, ReadWrite>,
354    ) -> Result<Self, TensorError> {
355        const BLOCK_SIZE: u32 = 128;
356
357        let context = output.context();
358        let shape = {
359            let [index, token, batch, _] = output.shape().into();
360            let [_, vocab, _, _] = input.shape().into();
361            tokens.check_shape([token, batch, 1, 1])?;
362            input.check_shape([index, vocab, 1, 1])?;
363            output.check_shape([index, token, batch, 1])?;
364            output.shape()
365        };
366
367        let key = PipelineKey::new(
368            "embed",
369            "embed",
370            Macros::new()
371                .u32("BLOCK_SIZE", BLOCK_SIZE)
372                .tensor(output, None),
373        );
374        let pipeline = context.checkout_pipeline(
375            &key,
376            include_str!("../shaders/embed.wgsl"),
377            &[
378                output.meta_layout(0),
379                tokens.layout(1, true),
380                input.layout(2, true),
381                output.layout(3, false),
382            ],
383        );
384        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
385            .bind_meta(0, output)
386            .bind(1, tokens)
387            .bind(2, input)
388            .bind(3, output)
389            .build()];
390
391        Ok(Self::Atom {
392            pipeline,
393            bindings,
394            dispatch: [
395                u32::div_ceil(shape[0] as u32 / 4, BLOCK_SIZE),
396                shape[1] as u32,
397                shape[2] as u32,
398            ],
399        })
400    }
401
402    /// Layer normalization applied on `x`, with weight `w` and bias `b`.
403    /// - `x` shape: `[C, T, B]`.
404    /// - `w` shape: `[C, 1, 1]`.
405    /// - `b` shape: `[C, 1, 1]`.
406    /// - `s` shape: `[4, T, B]`, mean and inverse std of `x`.
407    pub fn layer_norm(
408        w: &TensorGpu<f16, ReadWrite>,
409        b: &TensorGpu<f16, ReadWrite>,
410        x: &TensorGpu<impl Float, ReadWrite>,
411        eps: f32,
412    ) -> Result<Self, TensorError> {
413        const BLOCK_SIZE: u32 = 128;
414
415        let context = x.context();
416        let shape = {
417            let [index, token, batch, _] = x.shape().into();
418            x.check_shape([index, token, batch, 1])?;
419            w.check_shape([index, 1, 1, 1])?;
420            b.check_shape([index, 1, 1, 1])?;
421            x.shape()
422        };
423
424        let key = PipelineKey::new(
425            "layer_norm",
426            "layer_norm",
427            Macros::new()
428                .u32("BLOCK_SIZE", BLOCK_SIZE)
429                .tensor(x, None)
430                .f32("EPS", eps),
431        );
432        let pipeline = context.checkout_pipeline(
433            &key,
434            include_str!("../shaders/layer_norm.wgsl"),
435            &[
436                x.meta_layout(0),
437                w.layout(1, true),
438                b.layout(2, true),
439                x.layout(3, false),
440            ],
441        );
442        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
443            .bind_meta(0, x)
444            .bind(1, w)
445            .bind(2, b)
446            .bind(3, x)
447            .build()];
448
449        Ok(Self::Atom {
450            pipeline,
451            bindings,
452            dispatch: [1, shape[1] as u32, shape[2] as u32],
453        })
454    }
455
456    /// Group normalization applied on `x`, with weight `w` and bias `b`.
457    /// - `x` shape: `[S, H, A]`.
458    /// - `w` shape: `[S, H, 1]`.
459    /// - `b` shape: `[S, H, 1]`.
460    pub fn group_norm(
461        w: &TensorGpu<f16, ReadWrite>,
462        b: &TensorGpu<f16, ReadWrite>,
463        x: &TensorGpu<impl Float, ReadWrite>,
464        eps: f32,
465    ) -> Result<Self, TensorError> {
466        const BLOCK_SIZE: u32 = 32;
467
468        let context = x.context();
469        let shape = {
470            let [index, head, token, _] = x.shape().into();
471            x.check_shape([index, head, token, 1])?;
472            w.check_shape([index, head, 1, 1])?;
473            b.check_shape([index, head, 1, 1])?;
474            x.shape()
475        };
476
477        let key = PipelineKey::new(
478            "group_norm",
479            "layer_norm",
480            Macros::new()
481                .u32("BLOCK_SIZE", BLOCK_SIZE)
482                .bool("GROUP_NORM", true)
483                .tensor(x, None)
484                .f32("EPS", eps),
485        );
486        let pipeline = context.checkout_pipeline(
487            &key,
488            include_str!("../shaders/layer_norm.wgsl"),
489            &[
490                x.meta_layout(0),
491                w.layout(1, true),
492                b.layout(2, true),
493                x.layout(3, false),
494            ],
495        );
496        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
497            .bind_meta(0, x)
498            .bind(1, w)
499            .bind(2, b)
500            .bind(3, x)
501            .build()];
502
503        Ok(Self::Atom {
504            pipeline,
505            bindings,
506            dispatch: [1, shape[1] as u32, shape[2] as u32],
507        })
508    }
509
510    /// Recenter `x` to be zero-mean.
511    pub fn recenter(x: &TensorGpu<impl Float, ReadWrite>) -> Result<Self, TensorError> {
512        const BLOCK_SIZE: u32 = 128;
513
514        let context = x.context();
515        let shape = x.shape();
516
517        #[cfg(not(feature = "subgroup-ops"))]
518        let key = PipelineKey::new(
519            "recenter",
520            "recenter",
521            Macros::new()
522                .u32("BLOCK_SIZE", BLOCK_SIZE)
523                .tensor(x, None)
524                .f32("EPS", 0.0),
525        );
526        #[cfg(feature = "subgroup-ops")]
527        let key = PipelineKey::new(
528            "recenter",
529            "recenter",
530            Macros::new()
531                .subgroup(context.min_subgroup_size(), context.max_subgroup_size())
532                .u32("BLOCK_SIZE", BLOCK_SIZE)
533                .tensor(x, None)
534                .f32("EPS", 0.0),
535        );
536
537        #[cfg(not(feature = "subgroup-ops"))]
538        let pipeline = context.checkout_pipeline(
539            &key,
540            include_str!("../shaders/normalize.wgsl"),
541            &[x.meta_layout(0), x.layout(3, false)],
542        );
543        #[cfg(feature = "subgroup-ops")]
544        let pipeline = context.checkout_pipeline(
545            &key,
546            include_str!("../shaders/subgroup/normalize.wgsl"),
547            &[x.meta_layout(0), x.layout(3, false)],
548        );
549
550        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
551            .bind_meta(0, x)
552            .bind(3, x)
553            .build()];
554
555        Ok(Self::Atom {
556            pipeline,
557            bindings,
558            dispatch: [1, shape[1] as u32, shape[2] as u32],
559        })
560    }
561
562    /// Root-mean-square normalization applied on `x`, with weight `w` and bias `b`.
563    /// - `x` shape: `[C, T, B]`.
564    /// - `w` shape: `[C, 1, 1]`.
565    /// - `b` shape: `[C, 1, 1]`.
566    pub fn rms_norm(
567        w: &TensorGpu<f16, ReadWrite>,
568        b: &TensorGpu<f16, ReadWrite>,
569        x: &TensorGpu<impl Float, ReadWrite>,
570        eps: f32,
571    ) -> Result<Self, TensorError> {
572        const BLOCK_SIZE: u32 = 128;
573
574        let context = x.context();
575        let shape = {
576            let [index, token, batch, _] = x.shape().into();
577            x.check_shape([index, token, batch, 1])?;
578            w.check_shape([index, 1, 1, 1])?;
579            b.check_shape([index, 1, 1, 1])?;
580            x.shape()
581        };
582
583        #[cfg(not(feature = "subgroup-ops"))]
584        let key = PipelineKey::new(
585            "rms_norm",
586            "rms_norm",
587            Macros::new()
588                .u32("BLOCK_SIZE", BLOCK_SIZE)
589                .tensor(x, None)
590                .f32("EPS", eps),
591        );
592        #[cfg(feature = "subgroup-ops")]
593        let key = PipelineKey::new(
594            "rms_norm",
595            "rms_norm",
596            Macros::new()
597                .subgroup(context.min_subgroup_size(), context.max_subgroup_size())
598                .u32("BLOCK_SIZE", BLOCK_SIZE)
599                .tensor(x, None)
600                .f32("EPS", eps),
601        );
602
603        #[cfg(not(feature = "subgroup-ops"))]
604        let pipeline = context.checkout_pipeline(
605            &key,
606            include_str!("../shaders/normalize.wgsl"),
607            &[
608                x.meta_layout(0),
609                w.layout(1, true),
610                b.layout(2, true),
611                x.layout(3, false),
612            ],
613        );
614        #[cfg(feature = "subgroup-ops")]
615        let pipeline = context.checkout_pipeline(
616            &key,
617            include_str!("../shaders/subgroup/normalize.wgsl"),
618            &[
619                x.meta_layout(0),
620                w.layout(1, true),
621                b.layout(2, true),
622                x.layout(3, false),
623            ],
624        );
625
626        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
627            .bind_meta(0, x)
628            .bind(1, w)
629            .bind(2, b)
630            .bind(3, x)
631            .build()];
632
633        Ok(Self::Atom {
634            pipeline,
635            bindings,
636            dispatch: [1, shape[1] as u32, shape[2] as u32],
637        })
638    }
639
640    /// L2 normalization applied on `x`.
641    /// - `x` shape: `[C, T, B]`.
642    pub fn l2_norm(x: &TensorGpu<impl Float, ReadWrite>, eps: f32) -> Result<Self, TensorError> {
643        const BLOCK_SIZE: u32 = 128;
644
645        let context = x.context();
646        let shape = x.shape();
647
648        #[cfg(not(feature = "subgroup-ops"))]
649        let key = PipelineKey::new(
650            "l2_norm",
651            "l2_norm",
652            Macros::new()
653                .u32("BLOCK_SIZE", BLOCK_SIZE)
654                .tensor(x, None)
655                .f32("EPS", eps),
656        );
657        #[cfg(feature = "subgroup-ops")]
658        let key = PipelineKey::new(
659            "l2_norm",
660            "l2_norm",
661            Macros::new()
662                .subgroup(context.min_subgroup_size(), context.max_subgroup_size())
663                .u32("BLOCK_SIZE", BLOCK_SIZE)
664                .tensor(x, None)
665                .f32("EPS", eps),
666        );
667
668        #[cfg(not(feature = "subgroup-ops"))]
669        let pipeline = context.checkout_pipeline(
670            &key,
671            include_str!("../shaders/normalize.wgsl"),
672            &[x.meta_layout(0), x.layout(3, false)],
673        );
674        #[cfg(feature = "subgroup-ops")]
675        let pipeline = context.checkout_pipeline(
676            &key,
677            include_str!("../shaders/subgroup/normalize.wgsl"),
678            &[x.meta_layout(0), x.layout(3, false)],
679        );
680
681        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
682            .bind_meta(0, x)
683            .bind(3, x)
684            .build()];
685
686        Ok(Self::Atom {
687            pipeline,
688            bindings,
689            dispatch: [1, shape[1] as u32, shape[2] as u32],
690        })
691    }
692
693    /// Fp32 matrix-vector multiplication.
694    /// - `matrix` shape: `[C, R, B]`.
695    /// - `input` shape: `[C, T, B]`.
696    /// - `output` shape: `[R, T, B]`.
697    pub fn matmul_vec_fp16<'a, 'b, F0: Float, F1: Float>(
698        matrix: &TensorGpu<f16, ReadWrite>,
699        input: impl Into<TensorGpuView<'a, F0>>,
700        output: impl Into<TensorGpuView<'b, F1>>,
701        act: Activation,
702        sparse: bool,
703    ) -> Result<Self, TensorError> {
704        const BLOCK_SIZE: u32 = 128;
705
706        let input: TensorGpuView<_> = input.into();
707        let output: TensorGpuView<_> = output.into();
708
709        let context = output.context();
710        let shape = {
711            let [m, n, b, _] = output.shape().into();
712            let [k, _, _, _] = input.shape().into();
713            matrix.check_shape([k, m, b, 1])?;
714            input.check_shape([k, n, b, 1])?;
715            output.check_shape([m, n, b, 1])?;
716            output.shape()
717        };
718
719        #[cfg(not(feature = "subgroup-ops"))]
720        let key = PipelineKey::new(
721            "matmul_vec_fp16",
722            "matmul",
723            Macros::new()
724                .u32("BLOCK_SIZE", BLOCK_SIZE)
725                .tensor(&input, Some("IN"))
726                .tensor(&output, Some("OUT"))
727                .activate("ACT", act)
728                .bool("SPARSE_INPUT", sparse),
729        );
730        #[cfg(feature = "subgroup-ops")]
731        let key = PipelineKey::new(
732            "matmul_vec_fp16",
733            "matmul",
734            Macros::new()
735                .subgroup(context.min_subgroup_size(), context.max_subgroup_size())
736                .u32("BLOCK_SIZE", BLOCK_SIZE)
737                .tensor(&input, Some("IN"))
738                .tensor(&output, Some("OUT"))
739                .activate("ACT", act)
740                .bool("SPARSE_INPUT", sparse),
741        );
742
743        #[cfg(not(feature = "subgroup-ops"))]
744        let pipeline = context.checkout_pipeline(
745            &key,
746            include_str!("../shaders/matmul_vec_fp16.wgsl"),
747            &[
748                matrix.meta_layout(0),
749                input.meta_layout(1),
750                output.meta_layout(2),
751                matrix.layout(3, true),
752                input.layout(4, true),
753                output.layout(5, false),
754            ],
755        );
756        #[cfg(feature = "subgroup-ops")]
757        let pipeline = context.checkout_pipeline(
758            &key,
759            include_str!("../shaders/subgroup/matmul_vec_fp16.wgsl"),
760            &[
761                matrix.meta_layout(0),
762                input.meta_layout(1),
763                output.meta_layout(2),
764                matrix.layout(3, true),
765                input.layout(4, true),
766                output.layout(5, false),
767            ],
768        );
769
770        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
771            .bind_meta(0, matrix)
772            .bind_meta(1, &input)
773            .bind_meta(2, &output)
774            .bind(3, matrix)
775            .bind(4, &input)
776            .bind(5, &output)
777            .build()];
778
779        Ok(Self::Atom {
780            pipeline,
781            bindings,
782            dispatch: [matrix.shape[1] as u32 / 4, shape[1] as u32, shape[2] as u32],
783        })
784    }
785
786    /// Int8 matrix-vector multiplication.
787    /// - `matrix` shape: `[C, R, B]`.
788    /// - `input` shape: `[C, T, B]`.
789    /// - `output` shape: `[R, T, B]`.
790    #[allow(clippy::too_many_arguments)]
791    pub fn matmul_vec_int8<'a, 'b, F0: Float, F1: Float>(
792        matrix: &TensorGpu<u8, ReadWrite>,
793        minmax: &TensorGpu<f16, ReadWrite>,
794        input: impl Into<TensorGpuView<'a, F0>>,
795        output: impl Into<TensorGpuView<'b, F1>>,
796        act: Activation,
797        sparse: bool,
798    ) -> Result<Self, TensorError> {
799        const BLOCK_SIZE: u32 = 128;
800
801        let input: TensorGpuView<_> = input.into();
802        let output: TensorGpuView<_> = output.into();
803
804        let context = matrix.context();
805        let shape = {
806            let [m, n, b, _] = output.shape().into();
807            let [k, _, _, _] = input.shape().into();
808            let len = matrix.shape().len();
809            minmax.check_shape([(len << 1).div_ceil(Self::INT8_BLOCK_SIZE as usize), 1, 1, 1])?;
810            matrix.check_shape([k, m, b, 1])?;
811            input.check_shape([k, n, b, 1])?;
812            output.check_shape([m, n, b, 1])?;
813            output.shape()
814        };
815
816        #[cfg(not(feature = "subgroup-ops"))]
817        let key = PipelineKey::new(
818            "matmul_vec_int8",
819            "matmul",
820            Macros::new()
821                .u32("BLOCK_SIZE", BLOCK_SIZE)
822                .int8(Self::INT8_BLOCK_SIZE)
823                .tensor(&input, Some("IN"))
824                .tensor(&output, Some("OUT"))
825                .activate("ACT", act)
826                .bool("SPARSE_INPUT", sparse),
827        );
828        #[cfg(feature = "subgroup-ops")]
829        let key = PipelineKey::new(
830            "matmul_vec_int8",
831            "matmul",
832            Macros::new()
833                .subgroup(context.min_subgroup_size(), context.max_subgroup_size())
834                .u32("BLOCK_SIZE", BLOCK_SIZE)
835                .int8(Self::INT8_BLOCK_SIZE)
836                .tensor(&input, Some("IN"))
837                .tensor(&output, Some("OUT"))
838                .activate("ACT", act)
839                .bool("SPARSE_INPUT", sparse),
840        );
841
842        #[cfg(not(feature = "subgroup-ops"))]
843        let pipeline = context.checkout_pipeline(
844            &key,
845            include_str!("../shaders/matmul_vec_int8.wgsl"),
846            &[
847                matrix.meta_layout(0),
848                input.meta_layout(1),
849                output.meta_layout(2),
850                matrix.layout(3, true),
851                minmax.layout(4, true),
852                input.layout(5, true),
853                output.layout(6, false),
854            ],
855        );
856        #[cfg(feature = "subgroup-ops")]
857        let pipeline = context.checkout_pipeline(
858            &key,
859            include_str!("../shaders/subgroup/matmul_vec_int8.wgsl"),
860            &[
861                matrix.meta_layout(0),
862                input.meta_layout(1),
863                output.meta_layout(2),
864                matrix.layout(3, true),
865                minmax.layout(4, true),
866                input.layout(5, true),
867                output.layout(6, false),
868            ],
869        );
870
871        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
872            .bind_meta(0, matrix)
873            .bind_meta(1, &input)
874            .bind_meta(2, &output)
875            .bind(3, matrix)
876            .bind(4, minmax)
877            .bind(5, &input)
878            .bind(6, &output)
879            .build()];
880
881        Ok(Self::Atom {
882            pipeline,
883            bindings,
884            dispatch: [matrix.shape[1] as u32 / 4, shape[1] as u32, shape[2] as u32],
885        })
886    }
887
888    /// NFloat4 matrix-vector multiplication.
889    /// - `matrix` shape: `[C, R, B]`.
890    /// - `input` shape: `[C, T, B]`.
891    /// - `output` shape: `[R, T, B]`.
892    pub fn matmul_vec_nf4<'a, 'b, F0: Float, F1: Float>(
893        matrix: &TensorGpu<u8, ReadWrite>,
894        quant: &TensorGpu<f32, Uniform>,
895        absmax: &TensorGpu<f16, ReadWrite>,
896        input: impl Into<TensorGpuView<'a, F0>>,
897        output: impl Into<TensorGpuView<'b, F1>>,
898        act: Activation,
899        sparse: bool,
900    ) -> Result<Self, TensorError> {
901        const BLOCK_SIZE: u32 = 128;
902
903        let input: TensorGpuView<_> = input.into();
904        let output: TensorGpuView<_> = output.into();
905
906        let context = matrix.context();
907        let shape = {
908            let [m, n, b, _] = output.shape().into();
909            let [k, _, _, _] = input.shape().into();
910            let len = matrix.shape().len() << 1;
911            absmax.check_shape([len.div_ceil(Self::NF4_BLOCK_SIZE as usize), 1, 1, 1])?;
912            matrix.check_shape([k >> 1, m, b, 1])?;
913            input.check_shape([k, n, b, 1])?;
914            output.check_shape([m, n, b, 1])?;
915            output.shape()
916        };
917
918        #[cfg(not(feature = "subgroup-ops"))]
919        let key = PipelineKey::new(
920            "matmul_vec_nf4",
921            "matmul",
922            Macros::new()
923                .u32("BLOCK_SIZE", BLOCK_SIZE)
924                .nf4(Self::NF4_BLOCK_SIZE)
925                .tensor(&input, Some("IN"))
926                .tensor(&output, Some("OUT"))
927                .activate("ACT", act)
928                .bool("SPARSE_INPUT", sparse),
929        );
930        #[cfg(feature = "subgroup-ops")]
931        let key = PipelineKey::new(
932            "matmul_vec_nf4",
933            "matmul",
934            Macros::new()
935                .subgroup(context.min_subgroup_size(), context.max_subgroup_size())
936                .u32("BLOCK_SIZE", BLOCK_SIZE)
937                .nf4(Self::NF4_BLOCK_SIZE)
938                .tensor(&input, Some("IN"))
939                .tensor(&output, Some("OUT"))
940                .activate("ACT", act)
941                .bool("SPARSE_INPUT", sparse),
942        );
943
944        #[cfg(not(feature = "subgroup-ops"))]
945        let pipeline = context.checkout_pipeline(
946            &key,
947            include_str!("../shaders/matmul_vec_nf4.wgsl"),
948            &[
949                matrix.meta_layout(0),
950                input.meta_layout(1),
951                output.meta_layout(2),
952                quant.layout(3),
953                matrix.layout(4, true),
954                absmax.layout(5, true),
955                input.layout(6, true),
956                output.layout(7, false),
957            ],
958        );
959        #[cfg(feature = "subgroup-ops")]
960        let pipeline = context.checkout_pipeline(
961            &key,
962            include_str!("../shaders/subgroup/matmul_vec_nf4.wgsl"),
963            &[
964                matrix.meta_layout(0),
965                input.meta_layout(1),
966                output.meta_layout(2),
967                quant.layout(3),
968                matrix.layout(4, true),
969                absmax.layout(5, true),
970                input.layout(6, true),
971                output.layout(7, false),
972            ],
973        );
974
975        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
976            .bind_meta(0, matrix)
977            .bind_meta(1, &input)
978            .bind_meta(2, &output)
979            .bind(3, quant)
980            .bind(4, matrix)
981            .bind(5, absmax)
982            .bind(6, &input)
983            .bind(7, &output)
984            .build()];
985
986        Ok(Self::Atom {
987            pipeline,
988            bindings,
989            dispatch: [matrix.shape[1] as u32 / 4, shape[1] as u32, shape[2] as u32],
990        })
991    }
992
993    /// Fp16 matrix-matrix multiplication.
994    /// - `matrix` shape: `[K, M, B]`.
995    /// - `input` shape: `[K, N, B]`.
996    /// - `output` shape: `[M, N, B]`.
997    ///
998    /// Note: `K` must be multiples of 4; `M` and `N` must be multiples of 4.
999    pub fn matmul_mat_fp16<'a, 'b, 'c, F0: Float, F1: Float>(
1000        matrix: impl Into<TensorGpuView<'c, f16>>,
1001        input: impl Into<TensorGpuView<'a, F0>>,
1002        output: impl Into<TensorGpuView<'b, F1>>,
1003        act: Activation,
1004    ) -> Result<Self, TensorError> {
1005        const BLOCK_SIZE: u32 = 8;
1006
1007        let matrix: TensorGpuView<_> = matrix.into();
1008        let input: TensorGpuView<_> = input.into();
1009        let output: TensorGpuView<_> = output.into();
1010
1011        let context = output.context();
1012        let shape = {
1013            let [m, n, b, _] = output.shape().into();
1014            let [k, _, _, _] = input.shape().into();
1015            matrix.check_shape([k, m, b, 1])?;
1016            input.check_shape([k, n, b, 1])?;
1017            output.check_shape([m, n, b, 1])?;
1018            output.shape()
1019        };
1020
1021        let key = PipelineKey::new(
1022            "matmul_mat_fp16",
1023            "matmul",
1024            Macros::new()
1025                .u32("BLOCK_SIZE", BLOCK_SIZE)
1026                .tensor(&input, Some("IN"))
1027                .tensor(&output, Some("OUT"))
1028                .activate("ACT", act),
1029        );
1030        let pipeline = context.checkout_pipeline(
1031            &key,
1032            include_str!("../shaders/matmul_mat_fp16.wgsl"),
1033            &[
1034                matrix.meta_layout(0),
1035                input.meta_layout(1),
1036                output.meta_layout(2),
1037                matrix.layout(3, true),
1038                input.layout(4, true),
1039                output.layout(5, false),
1040            ],
1041        );
1042
1043        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
1044            .bind_meta(0, &matrix)
1045            .bind_meta(1, &input)
1046            .bind_meta(2, &output)
1047            .bind(3, &matrix)
1048            .bind(4, &input)
1049            .bind(5, &output)
1050            .build()];
1051
1052        Ok(Self::Atom {
1053            pipeline,
1054            bindings,
1055            dispatch: [
1056                u32::div_ceil(u32::div_ceil(shape[0] as u32, 4), BLOCK_SIZE),
1057                u32::div_ceil(u32::div_ceil(shape[1] as u32, 4), BLOCK_SIZE),
1058                shape[2] as u32,
1059            ],
1060        })
1061    }
1062
1063    /// Int8 matrix-matrix multiplication.
1064    /// - `matrix` shape: `[K, M, B]`.
1065    /// - `input` shape: `[K, N, B]`.
1066    /// - `output` shape: `[M, N, B]`.
1067    ///
1068    /// Notes:
1069    /// 1. `K` must be multiples of 4; `M` and `N` must be multiples of 4.
1070    /// 2. The total size of `matrix` must be multiples of 128.
1071    #[allow(clippy::too_many_arguments)]
1072    pub fn matmul_mat_int8<'a, 'b, 'c, F0: Float, F1: Float>(
1073        matrix: impl Into<TensorGpuView<'c, u8>>,
1074        minmax: &TensorGpu<f16, ReadWrite>,
1075        input: impl Into<TensorGpuView<'a, F0>>,
1076        output: impl Into<TensorGpuView<'b, F1>>,
1077        act: Activation,
1078    ) -> Result<Self, TensorError> {
1079        const BLOCK_SIZE: u32 = 8;
1080
1081        let matrix: TensorGpuView<_> = matrix.into();
1082        let input: TensorGpuView<_> = input.into();
1083        let output: TensorGpuView<_> = output.into();
1084
1085        let context = output.context();
1086        let shape = {
1087            let [m, n, b, _] = output.shape().into();
1088            let [k, _, _, _] = input.shape().into();
1089            let len = matrix.shape().len();
1090            minmax.check_shape([(len << 1).div_ceil(Self::INT8_BLOCK_SIZE as usize), 1, 1, 1])?;
1091            matrix.check_shape([k, m, b, 1])?;
1092            input.check_shape([k, n, b, 1])?;
1093            output.check_shape([m, n, b, 1])?;
1094            output.shape()
1095        };
1096
1097        let key = PipelineKey::new(
1098            "matmul_mat_int8",
1099            "matmul",
1100            Macros::new()
1101                .u32("BLOCK_SIZE", BLOCK_SIZE)
1102                .int8(Self::INT8_BLOCK_SIZE)
1103                .tensor(&input, Some("IN"))
1104                .tensor(&output, Some("OUT"))
1105                .activate("ACT", act),
1106        );
1107        let pipeline = context.checkout_pipeline(
1108            &key,
1109            include_str!("../shaders/matmul_mat_int8.wgsl"),
1110            &[
1111                matrix.meta_layout(0),
1112                input.meta_layout(1),
1113                output.meta_layout(2),
1114                minmax.layout(3, true),
1115                matrix.layout(4, true),
1116                input.layout(5, true),
1117                output.layout(6, false),
1118            ],
1119        );
1120
1121        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
1122            .bind_meta(0, &matrix)
1123            .bind_meta(1, &input)
1124            .bind_meta(2, &output)
1125            .bind(3, minmax)
1126            .bind(4, &matrix)
1127            .bind(5, &input)
1128            .bind(6, &output)
1129            .build()];
1130
1131        Ok(Self::Atom {
1132            pipeline,
1133            bindings,
1134            dispatch: [
1135                u32::div_ceil(u32::div_ceil(shape[0] as u32, 4), BLOCK_SIZE),
1136                u32::div_ceil(u32::div_ceil(shape[1] as u32, 4), BLOCK_SIZE),
1137                shape[2] as u32,
1138            ],
1139        })
1140    }
1141
1142    /// NFloat4 matrix-matrix multiplication.
1143    /// - `matrix` shape: `[K, M, B]`.
1144    /// - `input` shape: `[K, N, B]`.
1145    /// - `output` shape: `[M, N, B]`.
1146    ///
1147    /// Notes:
1148    /// 1. `K` must be multiples of 8; `M` and `N` must be multiples of 8.
1149    /// 2. The total size of `matrix` must be multiples of 256.
1150    pub fn matmul_mat_nf4<'a, 'b, 'c, F0: Float, F1: Float>(
1151        matrix: impl Into<TensorGpuView<'c, u8>>,
1152        quant: &TensorGpu<f32, Uniform>,
1153        absmax: &TensorGpu<f16, ReadWrite>,
1154        input: impl Into<TensorGpuView<'a, F0>>,
1155        output: impl Into<TensorGpuView<'b, F1>>,
1156        act: Activation,
1157    ) -> Result<Self, TensorError> {
1158        const BLOCK_SIZE: u32 = 8;
1159
1160        let matrix: TensorGpuView<_> = matrix.into();
1161        let input: TensorGpuView<_> = input.into();
1162        let output: TensorGpuView<_> = output.into();
1163
1164        let context = output.context();
1165        let shape = {
1166            let [m, n, b, _] = output.shape().into();
1167            let [k, _, _, _] = input.shape().into();
1168            let len = matrix.shape().len() << 1;
1169            absmax.check_shape([len.div_ceil(Self::NF4_BLOCK_SIZE as usize), 1, 1, 1])?;
1170            matrix.check_shape([k >> 1, m, b, 1])?;
1171            input.check_shape([k, n, b, 1])?;
1172            output.check_shape([m, n, b, 1])?;
1173            output.shape()
1174        };
1175
1176        let key = PipelineKey::new(
1177            "matmul_mat_nf4",
1178            "matmul",
1179            Macros::new()
1180                .u32("BLOCK_SIZE", BLOCK_SIZE)
1181                .nf4(Self::NF4_BLOCK_SIZE)
1182                .tensor(&input, Some("IN"))
1183                .tensor(&output, Some("OUT"))
1184                .activate("ACT", act),
1185        );
1186        let pipeline = context.checkout_pipeline(
1187            &key,
1188            include_str!("../shaders/matmul_mat_nf4.wgsl"),
1189            &[
1190                matrix.meta_layout(0),
1191                input.meta_layout(1),
1192                output.meta_layout(2),
1193                quant.layout(3),
1194                absmax.layout(4, true),
1195                matrix.layout(5, true),
1196                input.layout(6, true),
1197                output.layout(7, false),
1198            ],
1199        );
1200
1201        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
1202            .bind_meta(0, &matrix)
1203            .bind_meta(1, &input)
1204            .bind_meta(2, &output)
1205            .bind(3, quant)
1206            .bind(4, absmax)
1207            .bind(5, &matrix)
1208            .bind(6, &input)
1209            .bind(7, &output)
1210            .build()];
1211
1212        Ok(Self::Atom {
1213            pipeline,
1214            bindings,
1215            dispatch: [
1216                u32::div_ceil(u32::div_ceil(shape[0] as u32, 4), BLOCK_SIZE),
1217                u32::div_ceil(u32::div_ceil(shape[1] as u32, 4), BLOCK_SIZE),
1218                shape[2] as u32,
1219            ],
1220        })
1221    }
1222
1223    /// Add `input` to `output`.
1224    /// - `input` shape: `[C, 1, B]` or `[C, T, B]`.
1225    /// - `output` shape: `[C, T, B]`.
1226    /// - Activations may be applied to `input`, `output` and the final result.
1227    pub fn add_activate<'a, 'b, F0: Float, F1: Float>(
1228        input: impl Into<TensorGpuView<'a, F0>>,
1229        output: impl Into<TensorGpuView<'b, F1>>,
1230        act_x: Activation,
1231        act_y: Activation,
1232        act_out: Activation,
1233    ) -> Result<Self, TensorError> {
1234        const BLOCK_SIZE: u32 = 128;
1235
1236        let input: TensorGpuView<_> = input.into();
1237        let output: TensorGpuView<_> = output.into();
1238
1239        let context = output.context();
1240        let shape = {
1241            let [index, token, batch, _] = output.shape().into();
1242            input.check_shape_any(&[
1243                [index, token, batch, 1],
1244                [index, token, 1, batch],
1245                [index, 1, batch, 1],
1246                [index, 1, 1, 1],
1247            ])?;
1248            output.check_shape([index, token, batch, 1])?;
1249            output.shape()
1250        };
1251
1252        let key = PipelineKey::new(
1253            "add",
1254            "add",
1255            Macros::new()
1256                .u32("BLOCK_SIZE", BLOCK_SIZE)
1257                .tensor(&input, Some("IN"))
1258                .tensor(&output, Some("OUT"))
1259                .activate("ACT_X", act_x)
1260                .activate("ACT_Y", act_y)
1261                .activate("ACT_OUT", act_out),
1262        );
1263        let pipeline = context.checkout_pipeline(
1264            &key,
1265            include_str!("../shaders/binary.wgsl"),
1266            &[
1267                input.meta_layout(0),
1268                output.meta_layout(1),
1269                input.layout(2, true),
1270                output.layout(3, false),
1271            ],
1272        );
1273
1274        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
1275            .bind_meta(0, &input)
1276            .bind_meta(1, &output)
1277            .bind(2, &input)
1278            .bind(3, &output)
1279            .build()];
1280
1281        Ok(Self::Atom {
1282            pipeline,
1283            bindings,
1284            dispatch: [
1285                u32::div_ceil(shape[0] as u32 / 4, BLOCK_SIZE),
1286                shape[1] as u32,
1287                shape[2] as u32,
1288            ],
1289        })
1290    }
1291
1292    /// Add `input` to `output`.
1293    /// - `input` shape: `[C, 1, B]` or `[C, T, B]`.
1294    /// - `output` shape: `[C, T, B]`.
1295    pub fn add<'a, 'b, F0: Float, F1: Float>(
1296        input: impl Into<TensorGpuView<'a, F0>>,
1297        output: impl Into<TensorGpuView<'b, F1>>,
1298    ) -> Result<Self, TensorError> {
1299        Self::add_activate(
1300            input,
1301            output,
1302            Activation::None,
1303            Activation::None,
1304            Activation::None,
1305        )
1306    }
1307
1308    /// Multiply `input` to `output`.
1309    /// - `input` shape: `[C, 1, B]` or `[C, T, B]`.
1310    /// - `output` shape: `[C, T, B]`.
1311    /// - Activations may be applied to `input`, `output` and the final result.
1312    pub fn mul_activate<'a, 'b, F0: Float, F1: Float>(
1313        input: impl Into<TensorGpuView<'a, F0>>,
1314        output: impl Into<TensorGpuView<'b, F1>>,
1315        act_x: Activation,
1316        act_y: Activation,
1317        act_out: Activation,
1318    ) -> Result<Self, TensorError> {
1319        const BLOCK_SIZE: u32 = 128;
1320
1321        let input: TensorGpuView<_> = input.into();
1322        let output: TensorGpuView<_> = output.into();
1323
1324        let context = output.context();
1325        let shape = {
1326            let [index, token, batch, _] = output.shape().into();
1327            input.check_shape_any(&[
1328                [index, token, batch, 1],
1329                [index, token, 1, batch],
1330                [index, 1, batch, 1],
1331                [index, 1, 1, 1],
1332            ])?;
1333            output.check_shape([index, token, batch, 1])?;
1334            output.shape()
1335        };
1336
1337        let key = PipelineKey::new(
1338            "mul",
1339            "mul",
1340            Macros::new()
1341                .u32("BLOCK_SIZE", BLOCK_SIZE)
1342                .tensor(&input, Some("IN"))
1343                .tensor(&output, Some("OUT"))
1344                .activate("ACT_X", act_x)
1345                .activate("ACT_Y", act_y)
1346                .activate("ACT_OUT", act_out),
1347        );
1348        let pipeline = context.checkout_pipeline(
1349            &key,
1350            include_str!("../shaders/binary.wgsl"),
1351            &[
1352                input.meta_layout(0),
1353                output.meta_layout(1),
1354                input.layout(2, true),
1355                output.layout(3, false),
1356            ],
1357        );
1358
1359        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
1360            .bind_meta(0, &input)
1361            .bind_meta(1, &output)
1362            .bind(2, &input)
1363            .bind(3, &output)
1364            .build()];
1365
1366        Ok(Self::Atom {
1367            pipeline,
1368            bindings,
1369            dispatch: [
1370                u32::div_ceil(shape[0] as u32 / 4, BLOCK_SIZE),
1371                shape[1] as u32,
1372                shape[2] as u32,
1373            ],
1374        })
1375    }
1376
1377    /// Multiply `input` to `output`.
1378    /// - `input` shape: `[C, 1, B]` or `[C, T, B]`.
1379    /// - `output` shape: `[C, T, B]`.
1380    pub fn mul<'a, 'b, F0: Float, F1: Float>(
1381        input: impl Into<TensorGpuView<'a, F0>>,
1382        output: impl Into<TensorGpuView<'b, F1>>,
1383    ) -> Result<Self, TensorError> {
1384        Self::mul_activate(
1385            input,
1386            output,
1387            Activation::None,
1388            Activation::None,
1389            Activation::None,
1390        )
1391    }
1392
1393    pub fn token_shift<'a, 'b, F: Float>(
1394        cursors: &TensorGpu<u32, ReadWrite>,
1395        time_mix: impl Into<TensorGpuView<'a, F>>,
1396        state: impl Into<TensorGpuView<'b, f32>>,
1397        input: &TensorGpu<impl Float, ReadWrite>,
1398        output: &TensorGpu<impl Float, ReadWrite>,
1399        reversed: bool,
1400    ) -> Result<Self, TensorError> {
1401        const BLOCK_SIZE: u32 = 128;
1402
1403        let time_mix: TensorGpuView<_> = time_mix.into();
1404        let state: TensorGpuView<_> = state.into();
1405
1406        let context = output.context();
1407        let shape = {
1408            let [index, token, count, _] = output.shape().into();
1409            let [_, head, batch, _] = state.shape().into();
1410            input.check_shape_any(&[[index, token, count, 1], [index, token, 1, 1]])?;
1411            time_mix.check_shape_any(&[[index, token, count, 1], [index, 1, 1, 1]])?;
1412            state.check_shape([index, head, batch, 1])?;
1413            output.shape()
1414        };
1415
1416        let key = PipelineKey::new(
1417            "token_shift",
1418            "token_shift",
1419            Macros::new()
1420                .u32("BLOCK_SIZE", BLOCK_SIZE)
1421                .tensor(&time_mix, Some("TIME_MIX"))
1422                .tensor(input, Some("IN"))
1423                .tensor(output, Some("OUT"))
1424                .bool("REVERSED", reversed),
1425        );
1426        let pipeline = context.checkout_pipeline(
1427            &key,
1428            include_str!("../shaders/token_shift.wgsl"),
1429            &[
1430                output.meta_layout(0),
1431                time_mix.meta_layout(1),
1432                state.meta_layout(2),
1433                cursors.layout(3, true),
1434                time_mix.layout(4, true),
1435                state.layout(5, true),
1436                input.layout(6, true),
1437                output.layout(7, false),
1438            ],
1439        );
1440
1441        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
1442            .bind_meta(0, output)
1443            .bind_meta(1, &time_mix)
1444            .bind_meta(2, &state)
1445            .bind(3, cursors)
1446            .bind(4, &time_mix)
1447            .bind(5, &state)
1448            .bind(6, input)
1449            .bind(7, output)
1450            .build()];
1451
1452        Ok(Self::Atom {
1453            pipeline,
1454            bindings,
1455            dispatch: [
1456                u32::div_ceil(shape[0] as u32 / 4, BLOCK_SIZE),
1457                shape[1] as u32,
1458                shape[2] as u32,
1459            ],
1460        })
1461    }
1462
1463    #[allow(clippy::too_many_arguments)]
1464    pub fn time_mix_v4<'a, T: Float>(
1465        cursors: &TensorGpu<u32, ReadWrite>,
1466        time_decay: &TensorGpu<f32, ReadWrite>,
1467        time_first: &TensorGpu<f32, ReadWrite>,
1468        state: impl Into<TensorGpuView<'a, f32>>,
1469        k: &TensorGpu<T, ReadWrite>,
1470        v: &TensorGpu<T, ReadWrite>,
1471        r: &TensorGpu<T, ReadWrite>,
1472        x: &TensorGpu<T, ReadWrite>,
1473    ) -> Result<Self, TensorError> {
1474        const BLOCK_SIZE: u32 = 128;
1475
1476        let state: TensorGpuView<_> = state.into();
1477
1478        let context = x.context();
1479        let shape = x.shape();
1480        k.check_shape(shape)?;
1481        v.check_shape(shape)?;
1482        r.check_shape(shape)?;
1483        time_decay.check_shape([shape[0], 1, 1, 1])?;
1484        time_first.check_shape([shape[0], 1, 1, 1])?;
1485        state.check_shape([shape[0], 4, state.shape()[2], 1])?;
1486
1487        let key = PipelineKey::new(
1488            "time_mix_v4",
1489            "time_mix",
1490            Macros::new().u32("BLOCK_SIZE", BLOCK_SIZE).tensor(x, None),
1491        );
1492        let pipeline = context.checkout_pipeline(
1493            &key,
1494            include_str!("../shaders/time_mix_v4.wgsl"),
1495            &[
1496                x.meta_layout(0),
1497                state.meta_layout(1),
1498                cursors.layout(2, true),
1499                time_decay.layout(3, true),
1500                time_first.layout(4, true),
1501                state.layout(5, false),
1502                k.layout(6, true),
1503                v.layout(7, true),
1504                r.layout(8, true),
1505                x.layout(9, false),
1506            ],
1507        );
1508
1509        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
1510            .bind_meta(0, x)
1511            .bind_meta(1, &state)
1512            .bind(2, cursors)
1513            .bind(3, time_decay)
1514            .bind(4, time_first)
1515            .bind(5, &state)
1516            .bind(6, k)
1517            .bind(7, v)
1518            .bind(8, r)
1519            .bind(9, x)
1520            .build()];
1521
1522        Ok(Self::Atom {
1523            pipeline,
1524            bindings,
1525            dispatch: [u32::div_ceil(shape[0] as u32 / 4, BLOCK_SIZE), 1, 1],
1526        })
1527    }
1528
1529    #[allow(clippy::too_many_arguments)]
1530    pub fn time_mix_v5<'a, T: Float>(
1531        cursors: &TensorGpu<u32, ReadWrite>,
1532        time_decay: &TensorGpu<f32, ReadWrite>,
1533        time_first: &TensorGpu<f32, ReadWrite>,
1534        state: impl Into<TensorGpuView<'a, f32>>,
1535        k: &TensorGpu<T, ReadWrite>,
1536        v: &TensorGpu<T, ReadWrite>,
1537        r: &TensorGpu<T, ReadWrite>,
1538        x: &TensorGpu<T, ReadWrite>,
1539    ) -> Result<Self, TensorError> {
1540        const BLOCK_SIZE: u32 = 32;
1541
1542        let state: TensorGpuView<_> = state.into();
1543
1544        let context = x.context();
1545        let shape = x.shape();
1546        let stride = shape[0] * shape[1];
1547
1548        k.check_shape(shape)?;
1549        v.check_shape(shape)?;
1550        r.check_shape(shape)?;
1551        time_decay.check_shape([shape[0], shape[1], 1, 1])?;
1552        time_first.check_shape([shape[0], shape[1], 1, 1])?;
1553        state.check_shape([stride, shape[0] + 1, state.shape()[2], 1])?;
1554
1555        let key = PipelineKey::new(
1556            "time_mix_v5",
1557            "time_mix",
1558            Macros::new()
1559                .u32("BLOCK_SIZE", BLOCK_SIZE)
1560                .u32("HEAD_SIZE", shape[0] as u32 / 4)
1561                .tensor(x, None),
1562        );
1563        let pipeline = context.checkout_pipeline(
1564            &key,
1565            include_str!("../shaders/time_mix_v5.wgsl"),
1566            &[
1567                x.meta_layout(0),
1568                state.meta_layout(1),
1569                cursors.layout(2, true),
1570                time_decay.layout(3, true),
1571                time_first.layout(4, true),
1572                state.layout(5, false),
1573                k.layout(6, true),
1574                v.layout(7, true),
1575                r.layout(8, true),
1576                x.layout(9, false),
1577            ],
1578        );
1579
1580        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
1581            .bind_meta(0, x)
1582            .bind_meta(1, &state)
1583            .bind(2, cursors)
1584            .bind(3, time_decay)
1585            .bind(4, time_first)
1586            .bind(5, &state)
1587            .bind(6, k)
1588            .bind(7, v)
1589            .bind(8, r)
1590            .bind(9, x)
1591            .build()];
1592
1593        Ok(Self::Atom {
1594            pipeline,
1595            bindings,
1596            dispatch: [u32::div_ceil(stride as u32 / 4, BLOCK_SIZE), 1, 1],
1597        })
1598    }
1599
1600    #[allow(clippy::too_many_arguments)]
1601    pub fn time_mix_v6<'a, T: Float>(
1602        cursors: &TensorGpu<u32, ReadWrite>,
1603        time_decay: &TensorGpu<f32, ReadWrite>,
1604        time_first: &TensorGpu<f32, ReadWrite>,
1605        state: impl Into<TensorGpuView<'a, f32>>,
1606        k: &TensorGpu<T, ReadWrite>,
1607        v: &TensorGpu<T, ReadWrite>,
1608        r: &TensorGpu<T, ReadWrite>,
1609        x: &TensorGpu<T, ReadWrite>,
1610    ) -> Result<Self, TensorError> {
1611        const BLOCK_SIZE: u32 = 32;
1612
1613        let state: TensorGpuView<_> = state.into();
1614
1615        let context = x.context();
1616        let shape = x.shape();
1617        let stride = shape[0] * shape[1];
1618
1619        k.check_shape(shape)?;
1620        v.check_shape(shape)?;
1621        r.check_shape(shape)?;
1622        time_decay.check_shape(shape)?;
1623        time_first.check_shape([shape[0], shape[1], 1, 1])?;
1624        state.check_shape([stride, shape[0] + 1, state.shape()[2], 1])?;
1625
1626        let key = PipelineKey::new(
1627            "time_mix_v6",
1628            "time_mix",
1629            Macros::new()
1630                .u32("BLOCK_SIZE", BLOCK_SIZE)
1631                .u32("HEAD_SIZE", shape[0] as u32 / 4)
1632                .tensor(x, None),
1633        );
1634        let pipeline = context.checkout_pipeline(
1635            &key,
1636            include_str!("../shaders/time_mix_v6.wgsl"),
1637            &[
1638                x.meta_layout(0),
1639                state.meta_layout(1),
1640                cursors.layout(2, true),
1641                time_decay.layout(3, true),
1642                time_first.layout(4, true),
1643                state.layout(5, false),
1644                k.layout(6, true),
1645                v.layout(7, true),
1646                r.layout(8, true),
1647                x.layout(9, false),
1648            ],
1649        );
1650
1651        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
1652            .bind_meta(0, x)
1653            .bind_meta(1, &state)
1654            .bind(2, cursors)
1655            .bind(3, time_decay)
1656            .bind(4, time_first)
1657            .bind(5, &state)
1658            .bind(6, k)
1659            .bind(7, v)
1660            .bind(8, r)
1661            .bind(9, x)
1662            .build()];
1663
1664        Ok(Self::Atom {
1665            pipeline,
1666            bindings,
1667            dispatch: [u32::div_ceil(stride as u32 / 4, BLOCK_SIZE), 1, 1],
1668        })
1669    }
1670
1671    /// The V7 WKV kernel.
1672    /// - `n`: Stack of `k`, `v`, `a`, `kk`.
1673    ///
1674    /// Note that the state layout is different from the official implementation.
1675    /// Here is an illustration of each head's layout:
1676    ///
1677    /// ![time-mix-v7][time-mix-v7]
1678    #[embed_doc_image("time-mix-v7", "src/tensor/time-mix-v7.png")]
1679    pub fn time_mix_v7<'a, T: Float>(
1680        cursors: &TensorGpu<u32, ReadWrite>,
1681        state: impl Into<TensorGpuView<'a, f32>>,
1682        r: &TensorGpu<T, ReadWrite>,
1683        w: &TensorGpu<T, ReadWrite>,
1684        n: &TensorGpu<T, ReadWrite>,
1685        x: &TensorGpu<T, ReadWrite>,
1686    ) -> Result<Self, TensorError> {
1687        const BLOCK_SIZE: u32 = 32;
1688
1689        let state: TensorGpuView<_> = state.into();
1690
1691        let context = x.context();
1692        let shape = x.shape();
1693        let stride = shape[0] * shape[1];
1694
1695        r.check_shape(shape)?;
1696        w.check_shape(shape)?;
1697        n.check_shape([shape[0], shape[1], shape[2], 4])?;
1698        state.check_shape([stride, shape[0] + 1, state.shape()[2], 1])?;
1699
1700        let key = PipelineKey::new(
1701            "time_mix_v7",
1702            "time_mix",
1703            Macros::new()
1704                .u32("BLOCK_SIZE", BLOCK_SIZE)
1705                .u32("HEAD_SIZE", shape[0] as u32 / 4)
1706                .bool("TIME_MIX", true)
1707                .tensor(x, None)
1708                .activate("ACT", Activation::None),
1709        );
1710        let pipeline = context.checkout_pipeline(
1711            &key,
1712            include_str!("../shaders/time_mix_v7.wgsl"),
1713            &[
1714                x.meta_layout(0),
1715                state.meta_layout(1),
1716                cursors.layout(2, true),
1717                state.layout(3, false),
1718                r.layout(5, true),
1719                w.layout(6, true),
1720                n.layout(7, true),
1721                x.layout(9, false),
1722            ],
1723        );
1724
1725        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
1726            .bind_meta(0, x)
1727            .bind_meta(1, &state)
1728            .bind(2, cursors)
1729            .bind(3, &state)
1730            .bind(5, r)
1731            .bind(6, w)
1732            .bind(7, n)
1733            .bind(9, x)
1734            .build()];
1735
1736        Ok(Self::Atom {
1737            pipeline,
1738            bindings,
1739            dispatch: [u32::div_ceil(stride as u32 / 4, BLOCK_SIZE), 1, 1],
1740        })
1741    }
1742
1743    pub fn time_first_v7<T: Float>(
1744        u: &TensorGpu<f16, ReadWrite>,
1745        r: &TensorGpu<T, ReadWrite>,
1746        n: &TensorGpu<T, ReadWrite>,
1747        x: &TensorGpu<T, ReadWrite>,
1748    ) -> Result<Self, TensorError> {
1749        const BLOCK_SIZE: u32 = 32;
1750
1751        let context = x.context();
1752        let shape = x.shape();
1753        let stride = shape[0] * shape[1];
1754
1755        r.check_shape(shape)?;
1756        u.check_shape([shape[0], shape[1], 1, 1])?;
1757        n.check_shape([shape[0], shape[1], shape[2], 4])?;
1758
1759        let key = PipelineKey::new(
1760            "time_first_v7",
1761            "time_first",
1762            Macros::new()
1763                .u32("BLOCK_SIZE", BLOCK_SIZE)
1764                .u32("HEAD_SIZE", shape[0] as u32 / 4)
1765                .bool("TIME_FIRST", true)
1766                .tensor(x, None)
1767                .activate("ACT", Activation::None),
1768        );
1769        let pipeline = context.checkout_pipeline(
1770            &key,
1771            include_str!("../shaders/time_mix_v7.wgsl"),
1772            &[
1773                x.meta_layout(0),
1774                u.layout(4, true),
1775                r.layout(5, true),
1776                n.layout(7, true),
1777                x.layout(9, false),
1778            ],
1779        );
1780
1781        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
1782            .bind_meta(0, x)
1783            .bind(4, u)
1784            .bind(5, r)
1785            .bind(7, n)
1786            .bind(9, x)
1787            .build()];
1788
1789        Ok(Self::Atom {
1790            pipeline,
1791            bindings,
1792            dispatch: [
1793                u32::div_ceil(stride as u32 / 4, BLOCK_SIZE),
1794                shape[2] as u32,
1795                1,
1796            ],
1797        })
1798    }
1799
1800    pub fn control_k_v7<'a, 'b, F0: Float, F1: Float>(
1801        p: &TensorGpu<f16, ReadWrite>,
1802        a: impl Into<TensorGpuView<'a, F0>>,
1803        k: impl Into<TensorGpuView<'b, F1>>,
1804    ) -> Result<Self, TensorError> {
1805        const BLOCK_SIZE: u32 = 128;
1806
1807        let p: TensorGpuView<_> = p.into();
1808        let a: TensorGpuView<_> = a.into();
1809        let k: TensorGpuView<_> = k.into();
1810
1811        let context = k.context();
1812        let shape = {
1813            let [index, token, batch, _] = k.shape().into();
1814            a.check_shape([index, token, batch, 1])?;
1815            p.check_shape([index, 1, 1, 1])?;
1816            k.shape()
1817        };
1818
1819        let key = PipelineKey::new(
1820            "control_k_v7",
1821            "main",
1822            Macros::new()
1823                .u32("BLOCK_SIZE", BLOCK_SIZE)
1824                .tensor(&a, Some("A"))
1825                .tensor(&k, Some("K")),
1826        );
1827        let pipeline = context.checkout_pipeline(
1828            &key,
1829            include_str!("../shaders/control_k_v7.wgsl"),
1830            &[
1831                p.meta_layout(0),
1832                a.meta_layout(1),
1833                k.meta_layout(2),
1834                p.layout(3, true),
1835                a.layout(4, true),
1836                k.layout(5, false),
1837            ],
1838        );
1839
1840        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
1841            .bind_meta(0, &p)
1842            .bind_meta(1, &a)
1843            .bind_meta(2, &k)
1844            .bind(3, &p)
1845            .bind(4, &a)
1846            .bind(5, &k)
1847            .build()];
1848
1849        Ok(Self::Atom {
1850            pipeline,
1851            bindings,
1852            dispatch: [
1853                u32::div_ceil(shape[0] as u32 / 4, BLOCK_SIZE),
1854                shape[1] as u32,
1855                shape[2] as u32,
1856            ],
1857        })
1858    }
1859
1860    pub fn channel_mix<'a, T: Float>(
1861        cursors: &TensorGpu<u32, ReadWrite>,
1862        state: impl Into<TensorGpuView<'a, f32>>,
1863        r: &TensorGpu<T, ReadWrite>,
1864        v: &TensorGpu<T, ReadWrite>,
1865        x: &TensorGpu<T, ReadWrite>,
1866    ) -> Result<Self, TensorError> {
1867        const BLOCK_SIZE: u32 = 128;
1868
1869        let state: TensorGpuView<_> = state.into();
1870
1871        let context = x.context();
1872        let shape = x.shape();
1873        v.check_shape(shape)?;
1874        r.check_shape(shape)?;
1875        state.check_shape([shape[0], 1, state.shape()[2], 1])?;
1876
1877        let key = PipelineKey::new(
1878            "channel_mix",
1879            "channel_mix",
1880            Macros::new().u32("BLOCK_SIZE", BLOCK_SIZE).tensor(x, None),
1881        );
1882        let pipeline = context.checkout_pipeline(
1883            &key,
1884            include_str!("../shaders/channel_mix.wgsl"),
1885            &[
1886                x.meta_layout(0),
1887                state.meta_layout(1),
1888                cursors.layout(2, true),
1889                state.layout(3, false),
1890                r.layout(4, true),
1891                v.layout(5, true),
1892                x.layout(6, false),
1893            ],
1894        );
1895
1896        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
1897            .bind_meta(0, x)
1898            .bind_meta(1, &state)
1899            .bind(2, cursors)
1900            .bind(3, &state)
1901            .bind(4, r)
1902            .bind(5, v)
1903            .bind(6, x)
1904            .build()];
1905
1906        Ok(Self::Atom {
1907            pipeline,
1908            bindings,
1909            dispatch: [
1910                u32::div_ceil(shape[0] as u32 / 4, BLOCK_SIZE),
1911                shape[1] as u32,
1912                1,
1913            ],
1914        })
1915    }
1916
1917    pub fn channel_mix_v7<'a, T: Float>(
1918        cursors: &TensorGpu<u32, ReadWrite>,
1919        state: impl Into<TensorGpuView<'a, f32>>,
1920        v: &TensorGpu<T, ReadWrite>,
1921        x: &TensorGpu<T, ReadWrite>,
1922    ) -> Result<Self, TensorError> {
1923        const BLOCK_SIZE: u32 = 128;
1924
1925        let state: TensorGpuView<_> = state.into();
1926
1927        let context = x.context();
1928        let shape = x.shape();
1929        v.check_shape(shape)?;
1930        state.check_shape([shape[0], 1, state.shape()[2], 1])?;
1931
1932        let key = PipelineKey::new(
1933            "channel_mix",
1934            "channel_mix",
1935            Macros::new()
1936                .u32("BLOCK_SIZE", BLOCK_SIZE)
1937                .tensor(x, None)
1938                .bool("V7", true),
1939        );
1940        let pipeline = context.checkout_pipeline(
1941            &key,
1942            include_str!("../shaders/channel_mix.wgsl"),
1943            &[
1944                x.meta_layout(0),
1945                state.meta_layout(1),
1946                cursors.layout(2, true),
1947                state.layout(3, false),
1948                v.layout(5, true),
1949                x.layout(6, false),
1950            ],
1951        );
1952
1953        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
1954            .bind_meta(0, x)
1955            .bind_meta(1, &state)
1956            .bind(2, cursors)
1957            .bind(3, &state)
1958            .bind(5, v)
1959            .bind(6, x)
1960            .build()];
1961
1962        Ok(Self::Atom {
1963            pipeline,
1964            bindings,
1965            dispatch: [
1966                u32::div_ceil(shape[0] as u32 / 4, BLOCK_SIZE),
1967                shape[1] as u32,
1968                1,
1969            ],
1970        })
1971    }
1972
1973    pub fn activate<'a, F: Float>(
1974        x: impl Into<TensorGpuView<'a, F>>,
1975        act: Activation,
1976    ) -> Result<Self, TensorError> {
1977        const BLOCK_SIZE: u32 = 128;
1978
1979        let x: TensorGpuView<_> = x.into();
1980
1981        let context = x.context();
1982        let shape = x.shape();
1983
1984        let key = PipelineKey::new(
1985            "activate",
1986            "act",
1987            Macros::new()
1988                .u32("BLOCK_SIZE", BLOCK_SIZE)
1989                .tensor(&x, None)
1990                .activate("ACT", act),
1991        );
1992        let pipeline = context.checkout_pipeline(
1993            &key,
1994            include_str!("../shaders/activation.wgsl"),
1995            &[x.meta_layout(0), x.layout(1, false)],
1996        );
1997
1998        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
1999            .bind_meta(0, &x)
2000            .bind(1, &x)
2001            .build()];
2002
2003        Ok(Self::Atom {
2004            pipeline,
2005            bindings,
2006            dispatch: [
2007                u32::div_ceil(shape[0] as u32 / 4, BLOCK_SIZE),
2008                shape[1] as u32,
2009                shape[2] as u32,
2010            ],
2011        })
2012    }
2013
2014    /// Copy the content of `input` into `output` of the same shape.
2015    pub fn blit<'a, 'b, F0: Float, F1: Float>(
2016        input: impl Into<TensorGpuView<'a, F0>>,
2017        output: impl Into<TensorGpuView<'b, F1>>,
2018    ) -> Result<Self, TensorError> {
2019        let input: TensorGpuView<_> = input.into();
2020        let output: TensorGpuView<_> = output.into();
2021
2022        let context = input.context();
2023        let shape = output.shape();
2024        input.check_shape(shape)?;
2025
2026        let block_size = match shape[1] {
2027            x if x < 8 => [128, 1],
2028            _ => [16, 16],
2029        };
2030
2031        let key = PipelineKey::new(
2032            "blit",
2033            "blit",
2034            Macros::new()
2035                .u32("BLOCK_SIZE_X", block_size[0])
2036                .u32("BLOCK_SIZE_Y", block_size[1])
2037                .tensor(&input, Some("IN"))
2038                .tensor(&output, Some("OUT")),
2039        );
2040        let pipeline = context.checkout_pipeline(
2041            &key,
2042            include_str!("../shaders/blit.wgsl"),
2043            &[
2044                input.meta_layout(0),
2045                output.meta_layout(1),
2046                input.layout(2, true),
2047                output.layout(3, false),
2048            ],
2049        );
2050
2051        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
2052            .bind_meta(0, &input)
2053            .bind_meta(1, &output)
2054            .bind(2, &input)
2055            .bind(3, &output)
2056            .build()];
2057
2058        Ok(Self::Atom {
2059            pipeline,
2060            bindings,
2061            dispatch: [
2062                u32::div_ceil(shape[0] as u32 / 4, block_size[0]),
2063                u32::div_ceil(shape[1] as u32, block_size[1]),
2064                shape[2] as u32,
2065            ],
2066        })
2067    }
2068
2069    /// Repeat the content of `input` into `output` along the token and batch axes.
2070    pub fn broadcast<'a, 'b, F0: Float, F1: Float>(
2071        input: impl Into<TensorGpuView<'a, F0>>,
2072        output: impl Into<TensorGpuView<'b, F1>>,
2073    ) -> Result<Self, TensorError> {
2074        const BLOCK_SIZE: u32 = 128;
2075
2076        let input: TensorGpuView<_> = input.into();
2077        let output: TensorGpuView<_> = output.into();
2078
2079        let context = input.context();
2080        let shape = output.shape();
2081        input.check_shape([shape[0], input.shape()[1], input.shape()[2], 1])?;
2082
2083        let key = PipelineKey::new(
2084            "broadcast",
2085            "broadcast",
2086            Macros::new()
2087                .u32("BLOCK_SIZE", BLOCK_SIZE)
2088                .tensor(&input, Some("IN"))
2089                .tensor(&output, Some("OUT")),
2090        );
2091        let pipeline = context.checkout_pipeline(
2092            &key,
2093            include_str!("../shaders/reshape.wgsl"),
2094            &[
2095                input.meta_layout(0),
2096                output.meta_layout(1),
2097                input.layout(2, true),
2098                output.layout(3, false),
2099            ],
2100        );
2101
2102        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
2103            .bind_meta(0, &input)
2104            .bind_meta(1, &output)
2105            .bind(2, &input)
2106            .bind(3, &output)
2107            .build()];
2108
2109        Ok(Self::Atom {
2110            pipeline,
2111            bindings,
2112            dispatch: [
2113                u32::div_ceil(shape[0] as u32 / 4, BLOCK_SIZE),
2114                shape[1] as u32,
2115                shape[2] as u32,
2116            ],
2117        })
2118    }
2119
2120    /// Swap the `token` and `batch` axes.
2121    pub fn transpose<'a, 'b, F0: Float, F1: Float>(
2122        input: impl Into<TensorGpuView<'a, F0>>,
2123        output: impl Into<TensorGpuView<'b, F1>>,
2124    ) -> Result<Self, TensorError> {
2125        const BLOCK_SIZE: u32 = 128;
2126
2127        let input: TensorGpuView<_> = input.into();
2128        let output: TensorGpuView<_> = output.into();
2129
2130        let context = input.context();
2131        let shape = input.shape();
2132        output.check_shape([shape[0], shape[2], shape[1], 1])?;
2133
2134        let key = PipelineKey::new(
2135            "transpose",
2136            "transpose",
2137            Macros::new()
2138                .u32("BLOCK_SIZE", BLOCK_SIZE)
2139                .tensor(&input, Some("IN"))
2140                .tensor(&output, Some("OUT")),
2141        );
2142        let pipeline = context.checkout_pipeline(
2143            &key,
2144            include_str!("../shaders/reshape.wgsl"),
2145            &[
2146                input.meta_layout(0),
2147                output.meta_layout(1),
2148                input.layout(2, true),
2149                output.layout(3, false),
2150            ],
2151        );
2152
2153        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
2154            .bind_meta(0, &input)
2155            .bind_meta(1, &output)
2156            .bind(2, &input)
2157            .bind(3, &output)
2158            .build()];
2159
2160        Ok(Self::Atom {
2161            pipeline,
2162            bindings,
2163            dispatch: [
2164                u32::div_ceil(shape[0] as u32 / 4, BLOCK_SIZE),
2165                shape[1] as u32,
2166                shape[2] as u32,
2167            ],
2168        })
2169    }
2170
2171    pub fn blend(
2172        factor: &TensorGpu<f32, Uniform>,
2173        input: &TensorGpu<impl Float, ReadWrite>,
2174        output: &TensorGpu<impl Float, ReadWrite>,
2175    ) -> Result<Self, TensorError> {
2176        let context = output.context();
2177        let shape = output.shape();
2178        input.check_shape(shape)?;
2179        factor.check_shape([4, 1, 1, 1])?;
2180
2181        let block_size = match shape[1] {
2182            x if x < 8 => [128, 1],
2183            _ => [16, 16],
2184        };
2185
2186        let key = PipelineKey::new(
2187            "blend",
2188            "blend",
2189            Macros::new()
2190                .u32("BLOCK_SIZE_X", block_size[0])
2191                .u32("BLOCK_SIZE_Y", block_size[1])
2192                .tensor(input, Some("IN"))
2193                .tensor(output, Some("OUT")),
2194        );
2195        let pipeline = context.checkout_pipeline(
2196            &key,
2197            include_str!("../shaders/blend.wgsl"),
2198            &[
2199                input.meta_layout(0),
2200                output.meta_layout(1),
2201                factor.layout(2),
2202                input.layout(3, true),
2203                output.layout(4, false),
2204            ],
2205        );
2206
2207        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
2208            .bind_meta(0, input)
2209            .bind_meta(1, output)
2210            .bind(2, factor)
2211            .bind(3, input)
2212            .bind(4, output)
2213            .build()];
2214
2215        Ok(Self::Atom {
2216            pipeline,
2217            bindings,
2218            dispatch: [
2219                u32::div_ceil(shape[0] as u32 / 4, block_size[0]),
2220                u32::div_ceil(shape[1] as u32, block_size[1]),
2221                shape[2] as u32,
2222            ],
2223        })
2224    }
2225
2226    pub fn blend_lora<'a, 'b, 'c>(
2227        factor: &TensorGpu<f32, Uniform>,
2228        xa: impl Into<TensorGpuView<'a, f16>>,
2229        xb: impl Into<TensorGpuView<'b, f16>>,
2230        output: impl Into<TensorGpuView<'c, f16>>,
2231    ) -> Result<Self, TensorError> {
2232        const BLOCK_SIZE: u32 = 8;
2233
2234        let xa: TensorGpuView<_> = xa.into();
2235        let xb: TensorGpuView<_> = xb.into();
2236        let output: TensorGpuView<_> = output.into();
2237
2238        let context = output.context();
2239        let shape = output.shape();
2240        factor.check_shape([4, 1, 1, 1])?;
2241        xa.check_shape([xa.shape()[0], shape[0], shape[2], 1])?;
2242        xb.check_shape([xb.shape()[0], shape[1], shape[2], 1])?;
2243
2244        let key = PipelineKey::new(
2245            "blend_lora",
2246            "blend_lora",
2247            Macros::new().u32("BLOCK_SIZE", BLOCK_SIZE),
2248        );
2249        let pipeline = context.checkout_pipeline(
2250            &key,
2251            include_str!("../shaders/blend_lora.wgsl"),
2252            &[
2253                xa.meta_layout(0),
2254                xb.meta_layout(1),
2255                output.meta_layout(2),
2256                factor.layout(3),
2257                xa.layout(4, true),
2258                xb.layout(5, true),
2259                output.layout(6, false),
2260            ],
2261        );
2262
2263        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
2264            .bind_meta(0, &xa)
2265            .bind_meta(1, &xb)
2266            .bind_meta(2, &output)
2267            .bind(3, factor)
2268            .bind(4, &xa)
2269            .bind(5, &xb)
2270            .bind(6, &output)
2271            .build()];
2272
2273        Ok(Self::Atom {
2274            pipeline,
2275            bindings,
2276            dispatch: [
2277                u32::div_ceil(u32::div_ceil(shape[0] as u32, 4), BLOCK_SIZE),
2278                u32::div_ceil(u32::div_ceil(shape[1] as u32, 4), BLOCK_SIZE),
2279                shape[2] as u32,
2280            ],
2281        })
2282    }
2283
2284    pub fn lerp<'a, 'b, 'c, F0: Float, F1: Float, F2: Float>(
2285        input: impl Into<TensorGpuView<'a, F0>>,
2286        output: impl Into<TensorGpuView<'b, F1>>,
2287        factor: impl Into<TensorGpuView<'c, F2>>,
2288        reversed: bool,
2289    ) -> Result<Self, TensorError> {
2290        const BLOCK_SIZE: u32 = 128;
2291
2292        let factor: TensorGpuView<_> = factor.into();
2293        let input: TensorGpuView<_> = input.into();
2294        let output: TensorGpuView<_> = output.into();
2295
2296        let context = output.context();
2297        let shape = {
2298            let [index, token, batch, _] = output.shape().into();
2299            factor.check_shape_any(&[
2300                [index, token, batch, 1],
2301                [index, token, 1, 1],
2302                [index, 1, batch, 1],
2303                [index, 1, 1, 1],
2304            ])?;
2305            input.check_shape([index, token, batch, 1])?;
2306            output.shape()
2307        };
2308
2309        let key = PipelineKey::new(
2310            "lerp",
2311            "lerp",
2312            Macros::new()
2313                .u32("BLOCK_SIZE", BLOCK_SIZE)
2314                .tensor(&factor, Some("FACTOR"))
2315                .tensor(&input, Some("IN"))
2316                .tensor(&output, Some("OUT"))
2317                .bool("REVERSED", reversed),
2318        );
2319        let pipeline = context.checkout_pipeline(
2320            &key,
2321            include_str!("../shaders/lerp.wgsl"),
2322            &[
2323                factor.meta_layout(0),
2324                input.meta_layout(1),
2325                output.meta_layout(2),
2326                factor.layout(3, true),
2327                input.layout(4, true),
2328                output.layout(5, false),
2329            ],
2330        );
2331
2332        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
2333            .bind_meta(0, &factor)
2334            .bind_meta(1, &input)
2335            .bind_meta(2, &output)
2336            .bind(3, &factor)
2337            .bind(4, &input)
2338            .bind(5, &output)
2339            .build()];
2340
2341        Ok(Self::Atom {
2342            pipeline,
2343            bindings,
2344            dispatch: [
2345                u32::div_ceil(shape[0] as u32 / 4, BLOCK_SIZE),
2346                shape[1] as u32,
2347                shape[2] as u32,
2348            ],
2349        })
2350    }
2351
2352    pub fn affine(
2353        x: &TensorGpu<impl Float, ReadWrite>,
2354        scale: f32,
2355        bias: f32,
2356    ) -> Result<Self, TensorError> {
2357        const BLOCK_SIZE: u32 = 128;
2358
2359        let context = x.context();
2360        let shape = x.shape();
2361
2362        let key = PipelineKey::new(
2363            "affine",
2364            "affine",
2365            Macros::new()
2366                .u32("BLOCK_SIZE", BLOCK_SIZE)
2367                .tensor(x, None)
2368                .f32("SCALE", scale)
2369                .f32("BIAS", bias),
2370        );
2371        let pipeline = context.checkout_pipeline(
2372            &key,
2373            include_str!("../shaders/affine.wgsl"),
2374            &[x.meta_layout(0), x.layout(1, false)],
2375        );
2376
2377        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
2378            .bind_meta(0, x)
2379            .bind(1, x)
2380            .build()];
2381
2382        Ok(Self::Atom {
2383            pipeline,
2384            bindings,
2385            dispatch: [
2386                u32::div_ceil(shape[0] as u32 / 4, BLOCK_SIZE),
2387                shape[1] as u32,
2388                shape[2] as u32,
2389            ],
2390        })
2391    }
2392
2393    pub fn quantize_mat_int8(
2394        input: &TensorGpu<f16, ReadWrite>,
2395        minmax: &TensorGpu<f16, ReadWrite>,
2396        output: &TensorGpu<u8, ReadWrite>,
2397    ) -> Result<Self, TensorError> {
2398        const BLOCK_SIZE: u32 = 128;
2399
2400        let context = output.context();
2401        let shape = output.shape();
2402        let minmax_len = shape.len().div_ceil(Self::INT8_BLOCK_SIZE as usize);
2403        let minmax_shape = Shape::new(minmax_len << 1, 1, 1, 1);
2404
2405        input.check_shape(shape)?;
2406        minmax.check_shape(minmax_shape)?;
2407
2408        let key = PipelineKey::new(
2409            "quant_mat_int8_minmax",
2410            "compute_minmax",
2411            Macros::new()
2412                .u32("BLOCK_SIZE", BLOCK_SIZE)
2413                .int8(Self::INT8_BLOCK_SIZE),
2414        );
2415        let pipeline = context.checkout_pipeline(
2416            &key,
2417            include_str!("../shaders/quant_mat_int8.wgsl"),
2418            &[
2419                minmax.meta_layout(0),
2420                input.meta_layout(1),
2421                input.layout(2, true),
2422                minmax.layout(3, false),
2423            ],
2424        );
2425
2426        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
2427            .bind_meta(0, minmax)
2428            .bind_meta(1, input)
2429            .bind(2, input)
2430            .bind(3, minmax)
2431            .build()];
2432
2433        let compute_minmax = Self::Atom {
2434            pipeline,
2435            bindings,
2436            dispatch: [
2437                u32::div_ceil(minmax_len as u32, BLOCK_SIZE * BLOCK_SIZE),
2438                BLOCK_SIZE,
2439                1,
2440            ],
2441        };
2442
2443        let output = output.reshape(
2444            TensorDimension::Auto,
2445            TensorDimension::Size(1),
2446            TensorDimension::Size(1),
2447            TensorDimension::Size(1),
2448        )?;
2449
2450        let key = PipelineKey::new(
2451            "quant_mat_int8",
2452            "quantize",
2453            Macros::new()
2454                .u32("BLOCK_SIZE", BLOCK_SIZE)
2455                .int8(Self::INT8_BLOCK_SIZE),
2456        );
2457        let pipeline = context.checkout_pipeline(
2458            &key,
2459            include_str!("../shaders/quant_mat_int8.wgsl"),
2460            &[
2461                output.meta_layout(0),
2462                input.layout(2, true),
2463                minmax.layout(3, false),
2464                output.layout(4, false),
2465            ],
2466        );
2467
2468        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
2469            .bind_meta(0, &output)
2470            .bind(2, input)
2471            .bind(3, minmax)
2472            .bind(4, &output)
2473            .build()];
2474
2475        let quantize = Self::Atom {
2476            pipeline,
2477            bindings,
2478            dispatch: [
2479                u32::div_ceil(shape[0] as u32, BLOCK_SIZE),
2480                shape[1] as u32,
2481                shape[2] as u32,
2482            ],
2483        };
2484
2485        Ok(Self::List(vec![compute_minmax, quantize]))
2486    }
2487
2488    pub fn quantize_mat_nf4(
2489        input: &TensorGpu<f16, ReadWrite>,
2490        quant: &TensorGpu<f32, Uniform>,
2491        absmax: &TensorGpu<f16, ReadWrite>,
2492        output: &TensorGpu<u8, ReadWrite>,
2493    ) -> Result<Self, TensorError> {
2494        const BLOCK_SIZE: u32 = 128;
2495
2496        let context = output.context();
2497        let shape = output.shape();
2498        let input_shape = Shape::new(shape[0] << 1, shape[1], shape[2], shape[3]);
2499        let absmax_len = input_shape.len().div_ceil(Self::NF4_BLOCK_SIZE as usize);
2500        let absmax_shape = Shape::new(absmax_len, 1, 1, 1);
2501
2502        input.check_shape(input_shape)?;
2503        absmax.check_shape(absmax_shape)?;
2504
2505        let absmax_f32: TensorGpu<f32, ReadWrite> = context.tensor_init(absmax_shape);
2506
2507        let key = PipelineKey::new(
2508            "quant_mat_nf4_absmax",
2509            "compute_absmax",
2510            Macros::new()
2511                .u32("BLOCK_SIZE", BLOCK_SIZE)
2512                .nf4(Self::NF4_BLOCK_SIZE),
2513        );
2514        let pipeline = context.checkout_pipeline(
2515            &key,
2516            include_str!("../shaders/quant_mat_nf4.wgsl"),
2517            &[
2518                absmax_f32.meta_layout(0),
2519                input.meta_layout(1),
2520                input.layout(3, true),
2521                absmax_f32.layout(4, false),
2522            ],
2523        );
2524
2525        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
2526            .bind_meta(0, &absmax_f32)
2527            .bind_meta(1, input)
2528            .bind(3, input)
2529            .bind(4, &absmax_f32)
2530            .build()];
2531
2532        let compute_absmax = Self::Atom {
2533            pipeline,
2534            bindings,
2535            dispatch: [
2536                u32::div_ceil(absmax_len as u32, BLOCK_SIZE * BLOCK_SIZE),
2537                BLOCK_SIZE,
2538                1,
2539            ],
2540        };
2541
2542        let output = output.reshape(
2543            TensorDimension::Auto,
2544            TensorDimension::Size(1),
2545            TensorDimension::Size(1),
2546            TensorDimension::Size(1),
2547        )?;
2548
2549        let key = PipelineKey::new(
2550            "quant_mat_nf4",
2551            "quantize",
2552            Macros::new()
2553                .u32("BLOCK_SIZE", BLOCK_SIZE)
2554                .nf4(Self::NF4_BLOCK_SIZE),
2555        );
2556        let pipeline = context.checkout_pipeline(
2557            &key,
2558            include_str!("../shaders/quant_mat_nf4.wgsl"),
2559            &[
2560                output.meta_layout(0),
2561                quant.layout(2),
2562                input.layout(3, true),
2563                absmax_f32.layout(4, false),
2564                output.layout(5, false),
2565            ],
2566        );
2567
2568        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
2569            .bind_meta(0, &output)
2570            .bind(2, quant)
2571            .bind(3, input)
2572            .bind(4, &absmax_f32)
2573            .bind(5, &output)
2574            .build()];
2575
2576        let quantize = Self::Atom {
2577            pipeline,
2578            bindings,
2579            dispatch: [
2580                u32::div_ceil((shape[0]) as u32, BLOCK_SIZE),
2581                shape[1] as u32,
2582                shape[2] as u32,
2583            ],
2584        };
2585
2586        let quantize_absmax = Self::blit(&absmax_f32, absmax)?;
2587
2588        Ok(Self::List(vec![compute_absmax, quantize, quantize_absmax]))
2589    }
2590}
2591
2592#[cfg(test)]
2593mod tests {
2594    use std::f32::consts::PI;
2595
2596    use anyhow::Result;
2597    use half::f16;
2598    use itertools::Itertools;
2599    use wgpu::{Instance, PowerPreference};
2600    // use wgpu_profiler::GpuProfiler;
2601
2602    use super::TensorOp;
2603    use crate::{
2604        context::{Context, ContextBuilder, InstanceExt},
2605        tensor::{ops::Activation, Shape, TensorGpu},
2606    };
2607
2608    fn is_approx(a: impl Into<f32>, b: impl Into<f32>) -> bool {
2609        let a: f32 = a.into();
2610        let b: f32 = b.into();
2611        (a - b).abs() <= f32::max(f32::EPSILON, f32::max(a.abs(), b.abs()) * f32::EPSILON)
2612    }
2613
2614    fn is_approx_eps(a: impl Into<f32>, b: impl Into<f32>, eps: f32) -> bool {
2615        let a: f32 = a.into();
2616        let b: f32 = b.into();
2617        (a - b).abs() <= f32::max(eps, f32::max(a.abs(), b.abs()) * eps)
2618    }
2619
2620    async fn create_context() -> Result<Context> {
2621        let instance = Instance::default();
2622        let adapter = instance.adapter(PowerPreference::HighPerformance).await?;
2623        let context = ContextBuilder::new(adapter)
2624            // .features(Features::TIMESTAMP_QUERY | Features::TIMESTAMP_QUERY_INSIDE_PASSES)
2625            .build()
2626            .await?;
2627        Ok(context)
2628    }
2629
2630    #[cfg(feature = "tokio")]
2631    #[tokio::test]
2632    async fn test_softmax() -> Result<()> {
2633        let context = create_context().await?;
2634        fastrand::seed(42);
2635
2636        const C: usize = 1000;
2637        const T: usize = 3;
2638        const B: usize = 2;
2639
2640        let x = [(); C * T * B]
2641            .map(|_| 10.0 * (fastrand::f32() - 0.5))
2642            .to_vec();
2643        let shape = Shape::new(C, T, B, 1);
2644
2645        let x_dev: TensorGpu<_, _> = context.tensor_from_data(shape, x.clone())?;
2646        let softmax = TensorOp::softmax(&x_dev)?;
2647
2648        context.queue.submit(context.encode(&softmax));
2649        let x_host = x_dev.back().await.to_vec();
2650
2651        let mut ans = vec![];
2652        for x in &x.into_iter().chunks(C) {
2653            let x = x.collect_vec().into_iter();
2654            let max = x.clone().reduce(f32::max).unwrap_or_default();
2655            let x = x.map(|x| (x - max).exp());
2656            let sum: f32 = x.clone().sum();
2657            let x = x.map(|x| x / sum);
2658            ans.extend(x);
2659        }
2660
2661        for (index, (a, b)) in itertools::zip_eq(x_host, ans).enumerate() {
2662            assert!(
2663                is_approx(a, b),
2664                "Failed at index {index}, computed: {a} vs. answer: {b}"
2665            );
2666        }
2667
2668        Ok(())
2669    }
2670
2671    #[cfg(feature = "tokio")]
2672    #[tokio::test]
2673    async fn test_layer_norm() -> Result<()> {
2674        let context = create_context().await?;
2675        fastrand::seed(42);
2676
2677        const C: usize = 1000;
2678        const T: usize = 3;
2679        const B: usize = 2;
2680        const EPS: f32 = 1.0e-5;
2681
2682        let x = [(); C * T * B]
2683            .map(|_| 10.0 * (fastrand::f32() - 0.5))
2684            .to_vec();
2685        let w = [(); C]
2686            .map(|_| f16::from_f32(fastrand::f32() - 0.5))
2687            .repeat(T * B)
2688            .to_vec();
2689        let b = [(); C]
2690            .map(|_| f16::from_f32(fastrand::f32() - 0.5))
2691            .repeat(T * B)
2692            .to_vec();
2693
2694        let shape = Shape::new(C, T, B, 1);
2695        let x_dev = context.tensor_from_data(shape, x.clone())?;
2696
2697        let shape = Shape::new(C, 1, 1, 1);
2698        let w_dev = context.tensor_from_data(shape, &w[..1000])?;
2699        let b_dev = context.tensor_from_data(shape, &b[..1000])?;
2700
2701        // let shape = Shape::new(4, T, B, 1);
2702        // let s_dev = context.tensor_init(shape);
2703
2704        let layer_norm = TensorOp::layer_norm(&w_dev, &b_dev, &x_dev, EPS)?;
2705        context.queue.submit(context.encode(&layer_norm));
2706
2707        let x_host = x_dev.back().await.to_vec();
2708        // let s_host = s_dev.back().await.to_vec();
2709
2710        // test recenter and rms norm
2711        let shape = Shape::new(C, T, B, 1);
2712        let x_dev = context.tensor_from_data(shape, x.clone())?;
2713        let ops = TensorOp::List(vec![
2714            TensorOp::recenter(&x_dev)?,
2715            TensorOp::rms_norm(&w_dev, &b_dev, &x_dev, EPS)?,
2716        ]);
2717        context.queue.submit(context.encode(&ops));
2718
2719        let x_rms_host = x_dev.back().await.to_vec();
2720
2721        let mut ans = vec![];
2722        // let mut ans_stats = vec![];
2723        for chunk in &x
2724            .into_iter()
2725            .zip(w.into_iter())
2726            .zip(b.into_iter())
2727            .chunks(C)
2728        {
2729            let chunk = chunk.collect_vec();
2730            let x = chunk.iter().map(|((x, _), _)| x).copied();
2731            // let sum: f32 = x.clone().sum();
2732            // let squared_sum: f32 = x.clone().map(|x| x.powi(2)).sum();
2733
2734            // let mean = sum / C as f32;
2735            // let deviation = ((squared_sum / C as f32) - mean.powi(2)).sqrt();
2736            let (mean, m2, count) = x.fold((0.0f32, 0.0f32, 0u32), |(mean, m2, count), x| {
2737                let count = count + 1;
2738                let delta = x - mean;
2739                let mean = mean + delta / count as f32;
2740                let m2 = m2 + delta * (x - mean);
2741                (mean, m2, count)
2742            });
2743            let variance = m2 / count as f32 + EPS;
2744            let deviation = 1.0 / variance.sqrt();
2745            // ans_stats.append(&mut vec![mean, deviation, variance, 0.0]);
2746
2747            let x = chunk
2748                .into_iter()
2749                .map(|((x, w), b)| (x - mean) * deviation * w.to_f32() + b.to_f32());
2750            ans.extend(x);
2751        }
2752
2753        for (index, (a, &b)) in itertools::zip_eq(x_host, ans.iter()).enumerate() {
2754            assert!(
2755                is_approx_eps(a, b, 1.0e-3),
2756                "Failed at index {index}, computed: {a} vs. answer: {b}"
2757            );
2758        }
2759
2760        for (index, (a, &b)) in itertools::zip_eq(x_rms_host, ans.iter()).enumerate() {
2761            assert!(
2762                is_approx_eps(a, b, 1.0e-3),
2763                "Failed at index {index}, computed: {a} vs. answer: {b}"
2764            );
2765        }
2766
2767        Ok(())
2768    }
2769
2770    #[cfg(feature = "tokio")]
2771    #[tokio::test]
2772    async fn test_l2_norm() -> Result<()> {
2773        let context = create_context().await?;
2774        fastrand::seed(42);
2775
2776        const C: usize = 1000;
2777        const T: usize = 3;
2778        const B: usize = 2;
2779        const EPS: f32 = 1.0e-12;
2780
2781        let x = [(); C * T * B]
2782            .map(|_| 10.0 * (fastrand::f32() - 0.5))
2783            .to_vec();
2784
2785        let shape = Shape::new(C, T, B, 1);
2786        let x_dev = context.tensor_from_data(shape, x.clone())?;
2787
2788        let l2_norm = TensorOp::l2_norm(&x_dev, EPS)?;
2789        context.queue.submit(context.encode(&l2_norm));
2790
2791        let x_host = x_dev.back().await.to_vec();
2792
2793        let mut ans = vec![];
2794        for x in &x.into_iter().chunks(C) {
2795            let x = x.collect_vec().into_iter();
2796            let norm = x.clone().map(|x| x * x).sum::<f32>().sqrt();
2797            let x = x.map(|x| x / (norm + EPS));
2798            ans.extend(x);
2799        }
2800
2801        for (index, (a, b)) in itertools::zip_eq(x_host, ans).enumerate() {
2802            assert!(
2803                is_approx(a, b),
2804                "Failed at index {index}, computed: {a} vs. answer: {b}"
2805            );
2806        }
2807
2808        Ok(())
2809    }
2810
2811    #[cfg(feature = "tokio")]
2812    #[tokio::test]
2813    async fn test_matmul() -> Result<()> {
2814        let context = create_context().await?;
2815        fastrand::seed(42);
2816
2817        async fn test_matmul_inner(
2818            context: &Context,
2819            c: usize,
2820            r: usize,
2821            t: usize,
2822            b: usize,
2823        ) -> Result<()> {
2824            // let mut profiler = GpuProfiler::new(&context.adapter, &context.device, &context.queue, 1);
2825
2826            let matrix = vec![(); c * r * b]
2827                .into_iter()
2828                .map(|_| 10.0 * (fastrand::f32() - 0.5))
2829                .map(f16::from_f32)
2830                .collect_vec();
2831            let input_f32 = vec![(); c * t * b]
2832                .into_iter()
2833                .map(|_| 10.0 * (fastrand::f32() - 0.5))
2834                .collect_vec();
2835            let input_f16 = input_f32.iter().copied().map(f16::from_f32).collect_vec();
2836
2837            let matrix_shape = Shape::new(c, r, b, 1);
2838            let input_shape = Shape::new(c, t, b, 1);
2839            let output_shape = Shape::new(r, t, 2 * b, 1);
2840
2841            let matrix_dev = context.tensor_from_data(matrix_shape, matrix.clone())?;
2842            let input_f32_dev = context.tensor_from_data(input_shape, input_f32.clone())?;
2843            let input_f16_dev: TensorGpu<f16, _> = context.tensor_init(input_shape);
2844            let output_dev: TensorGpu<_, _> = context.tensor_init(output_shape);
2845
2846            let ops = TensorOp::List(vec![
2847                TensorOp::blit(&input_f32_dev, &input_f16_dev)?,
2848                TensorOp::matmul_vec_fp16(
2849                    &matrix_dev,
2850                    &input_f32_dev,
2851                    output_dev.view(.., .., 0..b, ..)?,
2852                    Activation::None,
2853                    false,
2854                )?,
2855                TensorOp::matmul_mat_fp16(
2856                    &matrix_dev,
2857                    &input_f16_dev,
2858                    output_dev.view(.., .., b.., ..)?,
2859                    Activation::None,
2860                )?,
2861            ]);
2862
2863            // profiler.resolve_queries(&mut encoder);
2864            context.queue.submit(context.encode(&ops));
2865
2866            let output_host = output_dev.back().await;
2867            let output_host: Vec<f32> = Vec::from(output_host);
2868
2869            // profiler.end_frame().unwrap();
2870            // _ = context.device.poll(wgpu::PollType::Wait);
2871
2872            // if let Some(results) = profiler.process_finished_frame() {
2873            //     wgpu_profiler::chrometrace::write_chrometrace(
2874            //         std::path::Path::new(&format!("./trace/matmul_{T}.json")),
2875            //         &results,
2876            //     )
2877            //     .expect("failed to write trace");
2878            // }
2879
2880            let mut ans = vec![0.0; output_host.len()];
2881            for ((batch, token), line) in (0..b).cartesian_product(0..t).cartesian_product(0..r) {
2882                let matrix = &matrix[((batch * r + line) * c)..((batch * r + line) + 1) * c];
2883                let input = &input_f32[(batch * t + token) * c..((batch * t + token) + 1) * c];
2884                let product = matrix
2885                    .iter()
2886                    .zip(input.iter())
2887                    .fold(0.0f32, |acc, x| acc + x.0.to_f32() * *x.1);
2888                ans[(batch * t + token) * r + line] = product;
2889
2890                let input = &input_f16[(batch * t + token) * c..((batch * t + token) + 1) * c];
2891                let product = matrix
2892                    .iter()
2893                    .zip(input.iter())
2894                    .fold(0.0f32, |acc, x| acc + x.0.to_f32() * x.1.to_f32());
2895                ans[((b + batch) * t + token) * r + line] = product;
2896            }
2897
2898            for (index, (a, b)) in itertools::zip_eq(output_host, ans).enumerate() {
2899                assert!(
2900                    is_approx_eps(a, b, 0.01),
2901                    "Failed at index {index}, computed: {a} vs. answer: {b}"
2902                );
2903            }
2904
2905            Ok(())
2906        }
2907
2908        test_matmul_inner(&context, 2560, 2048, 32, 2).await?;
2909        test_matmul_inner(&context, 320, 64, 320, 2).await?;
2910
2911        Ok(())
2912    }
2913
2914    #[cfg(feature = "tokio")]
2915    #[tokio::test]
2916    async fn test_matmul_int8() -> Result<()> {
2917        let context = create_context().await?;
2918        fastrand::seed(42);
2919
2920        const INT8_BLOCK_SIZE: usize = TensorOp::INT8_BLOCK_SIZE as usize;
2921
2922        async fn test_matmul_int8_inner(
2923            context: &Context,
2924            c: usize,
2925            r: usize,
2926            t: usize,
2927        ) -> Result<()> {
2928            let matrix = vec![(); c * r]
2929                .into_iter()
2930                .map(|_| 10.0 * (fastrand::f32() - 0.5))
2931                .map(f16::from_f32)
2932                .collect_vec();
2933            let input_f32 = vec![(); c * t]
2934                .into_iter()
2935                .map(|_| 10.0 * (fastrand::f32() - 0.5))
2936                .collect_vec();
2937            let input_f16 = input_f32.iter().copied().map(f16::from_f32).collect_vec();
2938
2939            let (matrix_u8, min, max) = {
2940                let mut matrix_u8: Vec<u8> = vec![0; matrix.len()];
2941                let mut min = vec![f16::MAX; matrix.len().div_ceil(INT8_BLOCK_SIZE)];
2942                let mut max = vec![f16::MIN; matrix.len().div_ceil(INT8_BLOCK_SIZE)];
2943
2944                for (i, (min, max)) in itertools::zip_eq(&mut min, &mut max).enumerate() {
2945                    let start = i * INT8_BLOCK_SIZE;
2946                    let end = start + INT8_BLOCK_SIZE;
2947                    let chunk = &matrix[start..end];
2948                    for value in chunk.iter() {
2949                        *min = min.min(*value);
2950                        *max = max.max(*value);
2951                    }
2952                    for (j, value) in chunk.iter().enumerate() {
2953                        let value = value.to_f32();
2954                        let min = min.to_f32();
2955                        let max = max.to_f32();
2956                        let value = (value - min) / (max - min);
2957                        matrix_u8[start + j] = f32::round(value * 255.0) as u8;
2958                    }
2959                }
2960
2961                (matrix_u8, min, max)
2962            };
2963            let minmax = itertools::zip_eq(&min, &max)
2964                .map(|(&min, &max)| [min, max])
2965                .collect_vec()
2966                .concat();
2967
2968            let minmax_shape = Shape::new((c * r).div_ceil(INT8_BLOCK_SIZE) * 2, 1, 1, 1);
2969            let matrix_shape = Shape::new(c, r, 1, 1);
2970            let input_shape = Shape::new(c, t, 1, 1);
2971            let output_shape = Shape::new(r, t, 1, 1);
2972
2973            let minmax_dev = context.tensor_init(minmax_shape);
2974            let matrix_f16_dev = context.tensor_from_data(matrix_shape, matrix.clone())?;
2975
2976            let matrix_u8_dev = context.tensor_init(matrix_shape);
2977            let input_dev = context.tensor_from_data(input_shape, input_f16.clone())?;
2978            let output_dev = context.tensor_init(output_shape);
2979
2980            let ops = TensorOp::List(vec![TensorOp::quantize_mat_int8(
2981                &matrix_f16_dev,
2982                &minmax_dev,
2983                &matrix_u8_dev,
2984            )?]);
2985            context.queue.submit(context.encode(&ops));
2986            let minmax_host = minmax_dev.back().await.to_vec();
2987            let matrix_u8_host = matrix_u8_dev.back().await.to_vec();
2988
2989            for (index, (&a, &b)) in itertools::zip_eq(&minmax_host, &minmax).enumerate() {
2990                assert!(
2991                    is_approx_eps(a, b, 0.01),
2992                    "Failed at index {index}, computed: {a} vs. answer: {b}"
2993                );
2994            }
2995            for (index, (&a, &b)) in itertools::zip_eq(&matrix_u8_host, &matrix_u8).enumerate() {
2996                assert!(
2997                    a.abs_diff(b) < 2,
2998                    "Failed at index {index}, computed: {a} vs. answer: {b}"
2999                );
3000            }
3001
3002            let mut ans = vec![0.0; t * r];
3003            for (token, line) in (0..t).cartesian_product(0..r) {
3004                let matrix = &matrix_u8_host[line * c..(line + 1) * c];
3005                let input = &input_f16[token * c..(token + 1) * c];
3006                let product =
3007                    matrix
3008                        .iter()
3009                        .zip_eq(input.iter())
3010                        .enumerate()
3011                        .fold(0.0f32, |acc, (i, x)| {
3012                            let min = min[(line * c + i) / INT8_BLOCK_SIZE].to_f32();
3013                            let max = max[(line * c + i) / INT8_BLOCK_SIZE].to_f32();
3014                            let value = (*x.0 as f32) / 255.0;
3015                            acc + (value * (max - min) + min) * x.1.to_f32()
3016                        });
3017                ans[token * r + line] = product;
3018            }
3019
3020            let ops = TensorOp::List(vec![TensorOp::matmul_vec_int8(
3021                &matrix_u8_dev,
3022                &minmax_dev,
3023                &input_dev,
3024                &output_dev,
3025                Activation::None,
3026                false,
3027            )?]);
3028            context.queue.submit(context.encode(&ops));
3029            let output_host: Vec<f32> = output_dev.back().await.to_vec();
3030
3031            for (index, (&a, &b)) in itertools::zip_eq(&output_host, &ans).enumerate() {
3032                assert!(
3033                    is_approx_eps(a, b, 0.01),
3034                    "Failed at index {index}, computed: {a} vs. answer: {b}"
3035                );
3036            }
3037
3038            let ops = TensorOp::List(vec![TensorOp::matmul_mat_int8(
3039                &matrix_u8_dev,
3040                &minmax_dev,
3041                &input_dev,
3042                &output_dev,
3043                Activation::None,
3044            )?]);
3045            context.queue.submit(context.encode(&ops));
3046            let output_host = output_dev.back().await.to_vec();
3047
3048            for (index, (&a, &b)) in itertools::zip_eq(&output_host, &ans).enumerate() {
3049                assert!(
3050                    is_approx_eps(a, b, 0.01),
3051                    "Failed at index {index}, computed: {a} vs. answer: {b}"
3052                );
3053            }
3054
3055            Ok(())
3056        }
3057
3058        test_matmul_int8_inner(&context, 2560, 2048, 64).await?;
3059        test_matmul_int8_inner(&context, 320, 64, 320).await?;
3060
3061        Ok(())
3062    }
3063
3064    #[cfg(feature = "tokio")]
3065    #[tokio::test]
3066    async fn test_matmul_nf4() -> Result<()> {
3067        let context = create_context().await?;
3068        fastrand::seed(42);
3069
3070        const NF4_BLOCK_SIZE: usize = TensorOp::NF4_BLOCK_SIZE as usize;
3071
3072        fn normal() -> f32 {
3073            let u = fastrand::f32();
3074            let v = fastrand::f32();
3075            (-2.0 * u.ln()).sqrt() * (2.0 * PI * v).cos()
3076        }
3077
3078        async fn test_matmul_nf4_inner(
3079            context: &Context,
3080            c: usize,
3081            r: usize,
3082            t: usize,
3083        ) -> Result<()> {
3084            let matrix = vec![(); c * r]
3085                .into_iter()
3086                .map(|_| normal())
3087                .map(f16::from_f32)
3088                .collect_vec();
3089            let input_f32 = vec![(); c * t]
3090                .into_iter()
3091                .map(|_| 2.0 * fastrand::f32() - 1.0)
3092                .collect_vec();
3093            let input_f16 = input_f32.iter().copied().map(f16::from_f32).collect_vec();
3094
3095            #[allow(clippy::excessive_precision)]
3096            let quant: [f32; 16] = [
3097                -1.0,
3098                -0.6961928009986877,
3099                -0.5250730514526367,
3100                -0.39491748809814453,
3101                -0.28444138169288635,
3102                -0.18477343022823334,
3103                -0.09105003625154495,
3104                0.0,
3105                0.07958029955625534,
3106                0.16093020141124725,
3107                0.24611230194568634,
3108                0.33791524171829224,
3109                0.44070982933044434,
3110                0.5626170039176941,
3111                0.7229568362236023,
3112                1.0,
3113            ];
3114            let (matrix_u8, matrix_u4, absmax) = {
3115                let mut matrix_u8: Vec<u8> = vec![0; matrix.len()];
3116                let mut matrix_u4: Vec<u8> = vec![0; matrix.len() / 2];
3117                let mut absmax = vec![f16::ZERO; matrix.len().div_ceil(NF4_BLOCK_SIZE)];
3118
3119                for (i, absmax) in absmax.iter_mut().enumerate() {
3120                    let start = i * NF4_BLOCK_SIZE;
3121                    let end = start + NF4_BLOCK_SIZE;
3122                    let chunk = &matrix[start..end];
3123                    *absmax = chunk
3124                        .iter()
3125                        .map(|&x| if x >= f16::ZERO { x } else { -x })
3126                        .reduce(f16::max)
3127                        .unwrap();
3128                    for (j, value) in chunk.iter().enumerate() {
3129                        let value = value.to_f32() / absmax.to_f32();
3130                        matrix_u8[start + j] = quant
3131                            .iter()
3132                            .map(|quant| (value - quant).abs())
3133                            .enumerate()
3134                            .fold((0, f32::MAX), |acc, x| if x.1 < acc.1 { x } else { acc })
3135                            .0 as u8;
3136                    }
3137                }
3138
3139                for (i, x) in matrix_u4.iter_mut().enumerate() {
3140                    *x = matrix_u8[2 * i] | matrix_u8[2 * i + 1] << 4;
3141                }
3142
3143                (matrix_u8, matrix_u4, absmax)
3144            };
3145
3146            let quant_shape = Shape::new(quant.len(), 1, 1, 1);
3147            let absmax_shape = Shape::new((c * r).div_ceil(NF4_BLOCK_SIZE), 1, 1, 1);
3148            let matrix_f16_shape = Shape::new(c, r, 1, 1);
3149            let matrix_u4_shape = Shape::new(c / 2, r, 1, 1);
3150            let input_shape = Shape::new(c, t, 1, 1);
3151            let output_shape = Shape::new(r, t, 1, 1);
3152
3153            let quant_dev = context.tensor_from_data(quant_shape, quant.to_vec())?;
3154            let absmax_dev = context.tensor_init(absmax_shape);
3155            let matrix_f16_dev = context.tensor_from_data(matrix_f16_shape, matrix.clone())?;
3156
3157            let matrix_u4_dev = context.tensor_init(matrix_u4_shape);
3158            let input_dev: TensorGpu<_, _> =
3159                context.tensor_from_data(input_shape, input_f16.clone())?;
3160            let output_dev: TensorGpu<_, _> = context.tensor_init(output_shape);
3161
3162            let ops = TensorOp::List(vec![TensorOp::quantize_mat_nf4(
3163                &matrix_f16_dev,
3164                &quant_dev,
3165                &absmax_dev,
3166                &matrix_u4_dev,
3167            )?]);
3168            context.queue.submit(context.encode(&ops));
3169            let matrix_u4_host = matrix_u4_dev.back().await.to_vec();
3170            let absmax_host = absmax_dev.back().await.to_vec();
3171
3172            for (index, (&a, &b)) in itertools::zip_eq(&absmax_host, &absmax).enumerate() {
3173                assert!(
3174                    is_approx_eps(a.to_f32(), b.to_f32(), 0.01),
3175                    "Failed at index {index}, computed: {a} vs. answer: {b}"
3176                );
3177            }
3178
3179            for (index, (a, b)) in itertools::zip_eq(matrix_u4_host, matrix_u4).enumerate() {
3180                assert!(
3181                    a == b,
3182                    "Failed at index {index}, computed: {a} vs. answer: {b}"
3183                );
3184            }
3185
3186            let mut truth = vec![0.0; t * r];
3187            for (token, line) in (0..t).cartesian_product(0..r) {
3188                let matrix = &matrix[line * c..(line + 1) * c];
3189                let input = &input_f16[token * c..(token + 1) * c];
3190                let product = matrix
3191                    .iter()
3192                    .zip(input.iter())
3193                    .fold(0.0f32, |acc, x| acc + x.0.to_f32() * x.1.to_f32());
3194                truth[token * r + line] = product;
3195            }
3196
3197            let mut ans = vec![0.0; t * r];
3198            for (token, line) in (0..t).cartesian_product(0..r) {
3199                let matrix = &matrix_u8[line * c..(line + 1) * c];
3200                let input = &input_f16[token * c..(token + 1) * c];
3201                let product =
3202                    matrix
3203                        .iter()
3204                        .zip(input.iter())
3205                        .enumerate()
3206                        .fold(0.0f32, |acc, (i, x)| {
3207                            let amp = absmax[(line * c + i) / NF4_BLOCK_SIZE];
3208                            acc + quant[*x.0 as usize] * amp.to_f32() * x.1.to_f32()
3209                        });
3210                ans[token * r + line] = product;
3211            }
3212
3213            let ops = TensorOp::List(vec![TensorOp::matmul_vec_nf4(
3214                &matrix_u4_dev,
3215                &quant_dev,
3216                &absmax_dev,
3217                &input_dev,
3218                &output_dev,
3219                Activation::None,
3220                false,
3221            )?]);
3222            context.queue.submit(context.encode(&ops));
3223            let output_host: Vec<f32> = output_dev.back().await.to_vec();
3224
3225            for (index, (&a, &b)) in itertools::zip_eq(&output_host, &ans).enumerate() {
3226                assert!(
3227                    is_approx_eps(a, b, 0.01),
3228                    "Failed at index {index}, computed: {a} vs. answer: {b}"
3229                );
3230            }
3231
3232            let ops = TensorOp::List(vec![TensorOp::matmul_mat_nf4(
3233                &matrix_u4_dev,
3234                &quant_dev,
3235                &absmax_dev,
3236                &input_dev,
3237                &output_dev,
3238                Activation::None,
3239            )?]);
3240            context.queue.submit(context.encode(&ops));
3241            let output_host = output_dev.back().await.to_vec();
3242
3243            for (index, (&a, &b)) in itertools::zip_eq(&output_host, &ans).enumerate() {
3244                assert!(
3245                    is_approx_eps(a, b, 0.01),
3246                    "Failed at index {index}, computed: {a} vs. answer: {b}"
3247                );
3248            }
3249
3250            Ok(())
3251        }
3252
3253        test_matmul_nf4_inner(&context, 2560, 2048, 64).await?;
3254        test_matmul_nf4_inner(&context, 320, 64, 320).await?;
3255
3256        Ok(())
3257    }
3258
3259    #[cfg(feature = "tokio")]
3260    #[tokio::test]
3261    async fn test_lerp() -> Result<()> {
3262        let context = create_context().await?;
3263        fastrand::seed(42);
3264
3265        const C: usize = 1000;
3266        const T: usize = 3;
3267        const B: usize = 2;
3268
3269        let x = [(); C * T * B].map(|_| fastrand::f32() - 0.5).to_vec();
3270        let y = [(); C * T * B].map(|_| fastrand::f32() - 0.5).to_vec();
3271        let f = [(); C * T * B].map(|_| fastrand::f32()).to_vec();
3272
3273        let shape = Shape::new(C, T, B, 1);
3274        let x_dev = context.tensor_from_data(shape, x.clone())?;
3275        let y_dev = context.tensor_from_data(shape, y.clone())?;
3276        let f_dev = context.tensor_from_data(shape, f.clone())?;
3277
3278        let lerp = TensorOp::lerp(&x_dev, &y_dev, &f_dev, false)?;
3279        context.queue.submit(context.encode(&lerp));
3280
3281        let y_host = y_dev.back().await.to_vec();
3282
3283        let mut ans = vec![];
3284        for chunk in &itertools::multizip((&x, &y, &f)).chunks(C) {
3285            for (x, y, f) in chunk {
3286                ans.push(x * (1.0 - f) + y * f);
3287            }
3288        }
3289
3290        for (index, (a, b)) in itertools::zip_eq(y_host, ans).enumerate() {
3291            assert!(
3292                is_approx(a, b),
3293                "Failed at index {index}, computed: {a} vs. answer: {b}"
3294            );
3295        }
3296
3297        Ok(())
3298    }
3299
3300    #[cfg(feature = "tokio")]
3301    #[tokio::test]
3302    async fn test_blit() -> Result<()> {
3303        let context = create_context().await?;
3304        fastrand::seed(42);
3305
3306        let output = vec![0.0; 24];
3307        let output: TensorGpu<_, _> = context.tensor_from_data([4, 3, 2, 1], output)?;
3308
3309        let mut ops = vec![];
3310
3311        let input = (0..8).map(|x| x as f32).collect_vec();
3312        let input: TensorGpu<_, _> = context.tensor_from_data([4, 1, 2, 1], input)?;
3313        ops.push(TensorOp::blit(&input, output.view(.., 1, .., ..)?)?);
3314
3315        let input = (8..12).map(|x| x as f32).collect_vec();
3316        let input: TensorGpu<_, _> = context.tensor_from_data([4, 1, 1, 1], input)?;
3317        ops.push(TensorOp::blit(&input, output.view(.., 2.., 1..2, ..)?)?);
3318
3319        let ops = TensorOp::List(ops);
3320        context.queue.submit(context.encode(&ops));
3321
3322        let output_host = output.back().await;
3323        let output_host = Vec::from(output_host);
3324
3325        assert_eq!(
3326            output_host,
3327            vec![
3328                0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
3329                4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0
3330            ]
3331        );
3332
3333        Ok(())
3334    }
3335
3336    #[cfg(feature = "tokio")]
3337    #[tokio::test]
3338    async fn test_transpose() -> Result<()> {
3339        let context = create_context().await?;
3340        fastrand::seed(42);
3341
3342        let output = vec![0.0; 36];
3343        let output: TensorGpu<_, _> = context.tensor_from_data([4, 3, 3, 1], output)?;
3344
3345        let input = (0..24).map(|x| x as f32).collect_vec();
3346        let input: TensorGpu<_, _> = context.tensor_from_data([4, 3, 2, 1], input)?;
3347
3348        let ops = TensorOp::transpose(&input, output.view(.., ..2, .., ..)?)?;
3349        context.queue.submit(context.encode(&ops));
3350
3351        let output_host = output.back().await;
3352        let output_host: Vec<f32> = Vec::from(output_host);
3353
3354        assert_eq!(
3355            output_host,
3356            vec![
3357                0.0, 1.0, 2.0, 3.0, 12.0, 13.0, 14.0, 15.0, 0.0, 0.0, 0.0, 0.0, 4.0, 5.0, 6.0, 7.0,
3358                16.0, 17.0, 18.0, 19.0, 0.0, 0.0, 0.0, 0.0, 8.0, 9.0, 10.0, 11.0, 20.0, 21.0, 22.0,
3359                23.0, 0.0, 0.0, 0.0, 0.0
3360            ]
3361        );
3362
3363        Ok(())
3364    }
3365}