use std::{hash::Hash, sync::Arc};

use embed_doc_image::embed_doc_image;
use half::f16;
use serde::{Deserialize, Serialize};
use wgpu::{BindGroup, CommandBuffer, CommandEncoder, ComputePass};

use super::{
    kind::{Kind, ReadWrite, Uniform},
    Shape, TensorError, TensorErrorKind, TensorGpu, TensorGpuView, TensorScalar, TensorShape,
};
use crate::{
    context::{BindGroupBuilder, CachedPipeline, Macros, PipelineKey},
    num::{Float, Scalar},
    tensor::{shape::TensorDimension, TensorReshape},
};

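/// Extension methods for recording tensor-to-tensor copies on a [`CommandEncoder`].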
pub trait TensorCommand<T: Scalar, K: Kind> {
    fn copy_tensor(
        &mut self,
        source: &TensorGpu<T, K>,
        destination: &TensorGpu<T, K>,
    ) -> Result<(), TensorError>;

    fn copy_tensor_batch(
        &mut self,
        source: &TensorGpu<T, K>,
        destination: &TensorGpu<T, K>,
        from: usize,
        to: usize,
    ) -> Result<(), TensorError>;
}

impl<T: Scalar, K: Kind> TensorCommand<T, K> for CommandEncoder {
    fn copy_tensor(
        &mut self,
        source: &TensorGpu<T, K>,
        destination: &TensorGpu<T, K>,
    ) -> Result<(), TensorError> {
        destination.check_shape(source.shape())?;
        let size = destination.size() as u64;
        self.copy_buffer_to_buffer(&source.buffer, 0, &destination.buffer, 0, size);
        Ok(())
    }

    fn copy_tensor_batch(
        &mut self,
        source: &TensorGpu<T, K>,
        destination: &TensorGpu<T, K>,
        from: usize,
        to: usize,
    ) -> Result<(), TensorError> {
        source.check_shape([source.shape[0], source.shape[1], source.shape[2], 1])?;
        destination.check_shape([source.shape[0], source.shape[1], destination.shape[2], 1])?;
        if from >= source.shape[2] {
            Err(TensorErrorKind::BatchOutOfRange {
                batch: from,
                max: source.shape[2],
            })?;
        }
        if to >= destination.shape[2] {
            Err(TensorErrorKind::BatchOutOfRange {
                batch: to,
                max: destination.shape[2],
            })?;
        }
        self.copy_buffer_to_buffer(
            &source.buffer,
            (T::size() * source.shape[0] * source.shape[1] * from) as u64,
            &destination.buffer,
            (T::size() * destination.shape[0] * destination.shape[1] * to) as u64,
            (T::size() * source.shape[0] * source.shape[1]) as u64,
        );
        Ok(())
    }
}

impl crate::context::Context {
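    /// Flattens a [`TensorOp`] tree into compute passes and encodes them.
    ///
    /// Consecutive atoms are recorded into one compute pass; each
    /// [`TensorOp::Sep`] closes the current pass and starts a new command
    /// buffer, so one op tree may yield several buffers.
    ///
    /// A minimal sketch, assuming a prepared `context` and a float tensor `x`:
    ///
    /// ```ignore
    /// let op = TensorOp::softmax(&x)?;
    /// let buffers = context.encode(&op);
    /// context.queue.submit(buffers);
    /// ```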
    pub fn encode(&self, op: &TensorOp) -> Vec<CommandBuffer> {
        struct Atom<'a> {
            pipeline: &'a CachedPipeline,
            bindings: &'a [Arc<BindGroup>],
            dispatch: &'a [u32; 3],
        }

        fn dispatch<'b, 'a: 'b>(
            pass: &'b mut ComputePass<'a>,
            Atom {
                pipeline,
                bindings,
                dispatch,
            }: Atom<'a>,
        ) {
            pass.set_pipeline(&pipeline.pipeline);
            for (index, bind) in bindings.iter().enumerate() {
                pass.set_bind_group(index as u32, &**bind, &[]);
            }
            pass.dispatch_workgroups(dispatch[0], dispatch[1], dispatch[2]);
        }

        fn flatten<'b, 'a: 'b>(
            commands: &'b mut Vec<Vec<Atom<'a>>>,
            passes: &'b mut Vec<Atom<'a>>,
            op: &'a TensorOp,
        ) {
            match op {
                TensorOp::Atom {
                    pipeline,
                    bindings,
                    dispatch,
                } => passes.push(Atom {
                    pipeline,
                    bindings,
                    dispatch,
                }),
                TensorOp::List(ops) => ops.iter().for_each(|op| flatten(commands, passes, op)),
                TensorOp::Sep => {
                    let mut temp = vec![];
                    std::mem::swap(&mut temp, passes);
                    commands.push(temp);
                }
            }
        }

        let mut commands = vec![];
        let mut passes = vec![];
        flatten(&mut commands, &mut passes, op);
        commands.push(passes);

        commands
            .into_iter()
            .filter(|atoms| !atoms.is_empty())
            .map(|atoms| {
                let mut encoder = self.device.create_command_encoder(&Default::default());
                let mut pass = encoder.begin_compute_pass(&Default::default());
                for atom in atoms {
                    dispatch(&mut pass, atom);
                }
                drop(pass);
                encoder.finish()
            })
            .collect()
    }
}

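/// Element-wise activation functions available to the shaders.
///
/// The serialized variant names match the WGSL function names injected by
/// [`Macros::activate`]; [`Activation::None`] serializes to the empty string.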
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Activation {
    #[default]
    #[serde(rename = "")]
    None,
    SquaredRelu,
    #[serde(rename = "custom_tanh")]
    Tanh,
    StableExp,
    OppositeExp,
    Softplus,
    Sigmoid,
    Silu,
}

impl std::fmt::Display for Activation {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(serde_variant::to_variant_name(self).unwrap())
    }
}

impl Macros {
    pub fn nf4(mut self, block_size: u32) -> Self {
        self.insert("NF4_BLOCK_SIZE".into(), format!("{block_size}u"));
        self
    }

    pub fn int8(mut self, block_size: u32) -> Self {
        self.insert("INT8_BLOCK_SIZE".into(), format!("{block_size}u"));
        self
    }

    pub fn f32(mut self, name: impl Into<String>, value: f32) -> Self {
        self.insert(name.into(), format!("{value}"));
        self
    }

    pub fn u32(mut self, name: impl Into<String>, value: u32) -> Self {
        self.insert(name.into(), format!("{value}u"));
        self
    }

    pub fn bool(mut self, name: impl Into<String>, value: bool) -> Self {
        match value {
            true => {
                self.insert(name.into(), Default::default());
                self
            }
            false => self,
        }
    }

    pub fn activate(mut self, name: impl Into<String>, value: Activation) -> Self {
        const ACTIVATION_DEFINE: &str = "
fn squared_relu(x: vec4<f32>) -> vec4<f32> {
    let p = max(x, vec4<f32>(0.0));
    return p * p;
}

fn stable_exp(x: vec4<f32>) -> vec4<f32> {
    return exp(-exp(x));
}

fn opposite_exp(x: vec4<f32>) -> vec4<f32> {
    return -exp(x);
}

fn softplus(x: vec4<f32>) -> vec4<f32> {
    return log(1.0 + exp(x));
}

fn sigmoid(x: vec4<f32>) -> vec4<f32> {
    return 1.0 / (1.0 + exp(-x));
}

fn silu(x: vec4<f32>) -> vec4<f32> {
    return x / (1.0 + exp(-x));
}

// Metal has some trouble with `tanh`.
fn custom_tanh(x: vec4<f32>) -> vec4<f32> {
    return select(tanh(x), vec4<f32>(1.0), x > vec4<f32>(42.0));
}
";
        self.insert("ACTIVATION_DEFINE".into(), ACTIVATION_DEFINE.to_string());
        self.insert(name.into(), value.to_string());
        self
    }

    pub fn tensor<T: Float>(
        mut self,
        _tensor: &impl TensorScalar<T = T>,
        prefix: Option<&'_ str>,
    ) -> Self {
        match prefix {
            None => self.insert(T::DEF.into(), Default::default()),
            Some(prefix) => self.insert(format!("{}_{}", prefix, T::DEF), Default::default()),
        };
        self
    }

    pub fn custom(mut self, value: impl std::fmt::Display, prefix: Option<&'_ str>) -> Self {
        match prefix {
            None => self.insert(format!("{value}"), Default::default()),
            Some(prefix) => self.insert(format!("{prefix}_{value}"), Default::default()),
        };
        self
    }

    pub fn define(mut self, name: impl Into<String>, condition: bool) -> Self {
        if condition {
            self.insert(name.into(), Default::default());
        }
        self
    }

    #[cfg(feature = "subgroup-ops")]
    pub fn subgroup(self, min: u32, max: u32) -> Self {
        self.u32("MIN_SUBGROUP_SIZE", min)
            .u32("MAX_SUBGROUP_SIZE", max)
            .define(format!("SUBGROUP_SIZE_{min}_{max}"), true)
    }
}

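/// A GPU tensor operation, possibly composed of sub-operations.
///
/// An `Atom` is a single pipeline dispatch; a `List` groups operations in
/// order; a `Sep` marks a split point at which the encoder starts a new
/// command buffer.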
pub enum TensorOp {
    Atom {
        pipeline: Arc<CachedPipeline>,
        bindings: Vec<Arc<BindGroup>>,
        dispatch: [u32; 3],
    },
    List(Vec<TensorOp>),
    Sep,
}

impl TensorOp {
    pub const NF4_BLOCK_SIZE: u32 = 64;
    pub const INT8_BLOCK_SIZE: u32 = 128;

    #[inline]
    pub fn empty() -> Self {
        Self::List(vec![])
    }

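    /// In-place softmax along the first dimension of `x`.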
    pub fn softmax(x: &TensorGpu<impl Float, ReadWrite>) -> Result<Self, TensorError> {
        const BLOCK_SIZE: u32 = 128;

        let context = x.context();
        let shape = x.shape();

        #[cfg(not(feature = "subgroup-ops"))]
        let key = PipelineKey::new(
            "softmax",
            "softmax",
            Macros::new().u32("BLOCK_SIZE", BLOCK_SIZE).tensor(x, None),
        );
        #[cfg(feature = "subgroup-ops")]
        let key = PipelineKey::new(
            "softmax",
            "softmax",
            Macros::new()
                .subgroup(context.min_subgroup_size(), context.max_subgroup_size())
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .tensor(x, None),
        );

        #[cfg(not(feature = "subgroup-ops"))]
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/softmax.wgsl"),
            &[x.meta_layout(0), x.layout(1, false)],
        );
        #[cfg(feature = "subgroup-ops")]
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/subgroup/softmax.wgsl"),
            &[x.meta_layout(0), x.layout(1, false)],
        );
        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, x)
            .bind(1, x)
            .build()];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch: [1, shape[1] as u32, shape[2] as u32],
        })
    }

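    /// Embedding lookup: for each token id in `tokens`, copies the matching
    /// `C`-length vector from the `[C, V]` table `input` into the
    /// `[C, T, B]` tensor `output`.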
    pub fn embed(
        tokens: &TensorGpu<u32, ReadWrite>,
        input: &TensorGpu<f16, ReadWrite>,
        output: &TensorGpu<impl Float, ReadWrite>,
    ) -> Result<Self, TensorError> {
        const BLOCK_SIZE: u32 = 128;

        let context = output.context();
        let shape = {
            let [index, token, batch, _] = output.shape().into();
            let [_, vocab, _, _] = input.shape().into();
            tokens.check_shape([token, batch, 1, 1])?;
            input.check_shape([index, vocab, 1, 1])?;
            output.check_shape([index, token, batch, 1])?;
            output.shape()
        };

        let key = PipelineKey::new(
            "embed",
            "embed",
            Macros::new()
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .tensor(output, None),
        );
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/embed.wgsl"),
            &[
                output.meta_layout(0),
                tokens.layout(1, true),
                input.layout(2, true),
                output.layout(3, false),
            ],
        );
        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, output)
            .bind(1, tokens)
            .bind(2, input)
            .bind(3, output)
            .build()];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch: [
                u32::div_ceil(shape[0] as u32 / 4, BLOCK_SIZE),
                shape[1] as u32,
                shape[2] as u32,
            ],
        })
    }

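    /// In-place layer normalization over the first dimension of `x`:
    /// `x = (x - mean(x)) / sqrt(var(x) + EPS) * w + b`.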
    pub fn layer_norm(
        w: &TensorGpu<f16, ReadWrite>,
        b: &TensorGpu<f16, ReadWrite>,
        x: &TensorGpu<impl Float, ReadWrite>,
        eps: f32,
    ) -> Result<Self, TensorError> {
        const BLOCK_SIZE: u32 = 128;

        let context = x.context();
        let shape = {
            let [index, token, batch, _] = x.shape().into();
            x.check_shape([index, token, batch, 1])?;
            w.check_shape([index, 1, 1, 1])?;
            b.check_shape([index, 1, 1, 1])?;
            x.shape()
        };

        let key = PipelineKey::new(
            "layer_norm",
            "layer_norm",
            Macros::new()
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .tensor(x, None)
                .f32("EPS", eps),
        );
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/layer_norm.wgsl"),
            &[
                x.meta_layout(0),
                w.layout(1, true),
                b.layout(2, true),
                x.layout(3, false),
            ],
        );
        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, x)
            .bind(1, w)
            .bind(2, b)
            .bind(3, x)
            .build()];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch: [1, shape[1] as u32, shape[2] as u32],
        })
    }

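    /// In-place group normalization: layer norm applied per head, with
    /// per-head `w` and `b`. Reuses the layer-norm shader with the
    /// `GROUP_NORM` macro defined.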
    pub fn group_norm(
        w: &TensorGpu<f16, ReadWrite>,
        b: &TensorGpu<f16, ReadWrite>,
        x: &TensorGpu<impl Float, ReadWrite>,
        eps: f32,
    ) -> Result<Self, TensorError> {
        const BLOCK_SIZE: u32 = 32;

        let context = x.context();
        let shape = {
            let [index, head, token, _] = x.shape().into();
            x.check_shape([index, head, token, 1])?;
            w.check_shape([index, head, 1, 1])?;
            b.check_shape([index, head, 1, 1])?;
            x.shape()
        };

        let key = PipelineKey::new(
            "group_norm",
            "layer_norm",
            Macros::new()
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .bool("GROUP_NORM", true)
                .tensor(x, None)
                .f32("EPS", eps),
        );
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/layer_norm.wgsl"),
            &[
                x.meta_layout(0),
                w.layout(1, true),
                b.layout(2, true),
                x.layout(3, false),
            ],
        );
        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, x)
            .bind(1, w)
            .bind(2, b)
            .bind(3, x)
            .build()];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch: [1, shape[1] as u32, shape[2] as u32],
        })
    }

    pub fn recenter(x: &TensorGpu<impl Float, ReadWrite>) -> Result<Self, TensorError> {
        const BLOCK_SIZE: u32 = 128;

        let context = x.context();
        let shape = x.shape();

        #[cfg(not(feature = "subgroup-ops"))]
        let key = PipelineKey::new(
            "recenter",
            "recenter",
            Macros::new()
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .tensor(x, None)
                .f32("EPS", 0.0),
        );
        #[cfg(feature = "subgroup-ops")]
        let key = PipelineKey::new(
            "recenter",
            "recenter",
            Macros::new()
                .subgroup(context.min_subgroup_size(), context.max_subgroup_size())
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .tensor(x, None)
                .f32("EPS", 0.0),
        );

        #[cfg(not(feature = "subgroup-ops"))]
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/normalize.wgsl"),
            &[x.meta_layout(0), x.layout(3, false)],
        );
        #[cfg(feature = "subgroup-ops")]
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/subgroup/normalize.wgsl"),
            &[x.meta_layout(0), x.layout(3, false)],
        );

        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, x)
            .bind(3, x)
            .build()];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch: [1, shape[1] as u32, shape[2] as u32],
        })
    }

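    /// In-place RMS normalization over the first dimension of `x`,
    /// scaled by `w` and shifted by `b`.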
    pub fn rms_norm(
        w: &TensorGpu<f16, ReadWrite>,
        b: &TensorGpu<f16, ReadWrite>,
        x: &TensorGpu<impl Float, ReadWrite>,
        eps: f32,
    ) -> Result<Self, TensorError> {
        const BLOCK_SIZE: u32 = 128;

        let context = x.context();
        let shape = {
            let [index, token, batch, _] = x.shape().into();
            x.check_shape([index, token, batch, 1])?;
            w.check_shape([index, 1, 1, 1])?;
            b.check_shape([index, 1, 1, 1])?;
            x.shape()
        };

        #[cfg(not(feature = "subgroup-ops"))]
        let key = PipelineKey::new(
            "rms_norm",
            "rms_norm",
            Macros::new()
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .tensor(x, None)
                .f32("EPS", eps),
        );
        #[cfg(feature = "subgroup-ops")]
        let key = PipelineKey::new(
            "rms_norm",
            "rms_norm",
            Macros::new()
                .subgroup(context.min_subgroup_size(), context.max_subgroup_size())
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .tensor(x, None)
                .f32("EPS", eps),
        );

        #[cfg(not(feature = "subgroup-ops"))]
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/normalize.wgsl"),
            &[
                x.meta_layout(0),
                w.layout(1, true),
                b.layout(2, true),
                x.layout(3, false),
            ],
        );
        #[cfg(feature = "subgroup-ops")]
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/subgroup/normalize.wgsl"),
            &[
                x.meta_layout(0),
                w.layout(1, true),
                b.layout(2, true),
                x.layout(3, false),
            ],
        );

        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, x)
            .bind(1, w)
            .bind(2, b)
            .bind(3, x)
            .build()];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch: [1, shape[1] as u32, shape[2] as u32],
        })
    }

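    /// In-place L2 normalization along the first dimension of `x`.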
    pub fn l2_norm(x: &TensorGpu<impl Float, ReadWrite>, eps: f32) -> Result<Self, TensorError> {
        const BLOCK_SIZE: u32 = 128;

        let context = x.context();
        let shape = x.shape();

        #[cfg(not(feature = "subgroup-ops"))]
        let key = PipelineKey::new(
            "l2_norm",
            "l2_norm",
            Macros::new()
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .tensor(x, None)
                .f32("EPS", eps),
        );
        #[cfg(feature = "subgroup-ops")]
        let key = PipelineKey::new(
            "l2_norm",
            "l2_norm",
            Macros::new()
                .subgroup(context.min_subgroup_size(), context.max_subgroup_size())
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .tensor(x, None)
                .f32("EPS", eps),
        );

        #[cfg(not(feature = "subgroup-ops"))]
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/normalize.wgsl"),
            &[x.meta_layout(0), x.layout(3, false)],
        );
        #[cfg(feature = "subgroup-ops")]
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/subgroup/normalize.wgsl"),
            &[x.meta_layout(0), x.layout(3, false)],
        );

        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, x)
            .bind(3, x)
            .build()];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch: [1, shape[1] as u32, shape[2] as u32],
        })
    }

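    /// Fp16 matrix-vector multiplication with fused activation `act`.
    ///
    /// - `matrix` shape: `[K, M, B]`;
    /// - `input` shape: `[K, N, B]`;
    /// - `output` shape: `[M, N, B]`.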
    pub fn matmul_vec_fp16<'a, 'b, F0: Float, F1: Float>(
        matrix: &TensorGpu<f16, ReadWrite>,
        input: impl Into<TensorGpuView<'a, F0>>,
        output: impl Into<TensorGpuView<'b, F1>>,
        act: Activation,
        sparse: bool,
    ) -> Result<Self, TensorError> {
        const BLOCK_SIZE: u32 = 128;

        let input: TensorGpuView<_> = input.into();
        let output: TensorGpuView<_> = output.into();

        let context = output.context();
        let shape = {
            let [m, n, b, _] = output.shape().into();
            let [k, _, _, _] = input.shape().into();
            matrix.check_shape([k, m, b, 1])?;
            input.check_shape([k, n, b, 1])?;
            output.check_shape([m, n, b, 1])?;
            output.shape()
        };

        #[cfg(not(feature = "subgroup-ops"))]
        let key = PipelineKey::new(
            "matmul_vec_fp16",
            "matmul",
            Macros::new()
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .tensor(&input, Some("IN"))
                .tensor(&output, Some("OUT"))
                .activate("ACT", act)
                .bool("SPARSE_INPUT", sparse),
        );
        #[cfg(feature = "subgroup-ops")]
        let key = PipelineKey::new(
            "matmul_vec_fp16",
            "matmul",
            Macros::new()
                .subgroup(context.min_subgroup_size(), context.max_subgroup_size())
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .tensor(&input, Some("IN"))
                .tensor(&output, Some("OUT"))
                .activate("ACT", act)
                .bool("SPARSE_INPUT", sparse),
        );

        #[cfg(not(feature = "subgroup-ops"))]
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/matmul_vec_fp16.wgsl"),
            &[
                matrix.meta_layout(0),
                input.meta_layout(1),
                output.meta_layout(2),
                matrix.layout(3, true),
                input.layout(4, true),
                output.layout(5, false),
            ],
        );
        #[cfg(feature = "subgroup-ops")]
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/subgroup/matmul_vec_fp16.wgsl"),
            &[
                matrix.meta_layout(0),
                input.meta_layout(1),
                output.meta_layout(2),
                matrix.layout(3, true),
                input.layout(4, true),
                output.layout(5, false),
            ],
        );

        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, matrix)
            .bind_meta(1, &input)
            .bind_meta(2, &output)
            .bind(3, matrix)
            .bind(4, &input)
            .bind(5, &output)
            .build()];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch: [matrix.shape[1] as u32 / 4, shape[1] as u32, shape[2] as u32],
        })
    }

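    /// Int8 matrix-vector multiplication. Same shapes as
    /// [`Self::matmul_vec_fp16`]; `minmax` stores an `f16` pair per
    /// `INT8_BLOCK_SIZE` block of weights for dequantization.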
    #[allow(clippy::too_many_arguments)]
    pub fn matmul_vec_int8<'a, 'b, F0: Float, F1: Float>(
        matrix: &TensorGpu<u8, ReadWrite>,
        minmax: &TensorGpu<f16, ReadWrite>,
        input: impl Into<TensorGpuView<'a, F0>>,
        output: impl Into<TensorGpuView<'b, F1>>,
        act: Activation,
        sparse: bool,
    ) -> Result<Self, TensorError> {
        const BLOCK_SIZE: u32 = 128;

        let input: TensorGpuView<_> = input.into();
        let output: TensorGpuView<_> = output.into();

        let context = matrix.context();
        let shape = {
            let [m, n, b, _] = output.shape().into();
            let [k, _, _, _] = input.shape().into();
            let len = matrix.shape().len();
            minmax.check_shape([(len << 1).div_ceil(Self::INT8_BLOCK_SIZE as usize), 1, 1, 1])?;
            matrix.check_shape([k, m, b, 1])?;
            input.check_shape([k, n, b, 1])?;
            output.check_shape([m, n, b, 1])?;
            output.shape()
        };

        #[cfg(not(feature = "subgroup-ops"))]
        let key = PipelineKey::new(
            "matmul_vec_int8",
            "matmul",
            Macros::new()
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .int8(Self::INT8_BLOCK_SIZE)
                .tensor(&input, Some("IN"))
                .tensor(&output, Some("OUT"))
                .activate("ACT", act)
                .bool("SPARSE_INPUT", sparse),
        );
        #[cfg(feature = "subgroup-ops")]
        let key = PipelineKey::new(
            "matmul_vec_int8",
            "matmul",
            Macros::new()
                .subgroup(context.min_subgroup_size(), context.max_subgroup_size())
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .int8(Self::INT8_BLOCK_SIZE)
                .tensor(&input, Some("IN"))
                .tensor(&output, Some("OUT"))
                .activate("ACT", act)
                .bool("SPARSE_INPUT", sparse),
        );

        #[cfg(not(feature = "subgroup-ops"))]
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/matmul_vec_int8.wgsl"),
            &[
                matrix.meta_layout(0),
                input.meta_layout(1),
                output.meta_layout(2),
                matrix.layout(3, true),
                minmax.layout(4, true),
                input.layout(5, true),
                output.layout(6, false),
            ],
        );
        #[cfg(feature = "subgroup-ops")]
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/subgroup/matmul_vec_int8.wgsl"),
            &[
                matrix.meta_layout(0),
                input.meta_layout(1),
                output.meta_layout(2),
                matrix.layout(3, true),
                minmax.layout(4, true),
                input.layout(5, true),
                output.layout(6, false),
            ],
        );

        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, matrix)
            .bind_meta(1, &input)
            .bind_meta(2, &output)
            .bind(3, matrix)
            .bind(4, minmax)
            .bind(5, &input)
            .bind(6, &output)
            .build()];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch: [matrix.shape[1] as u32 / 4, shape[1] as u32, shape[2] as u32],
        })
    }

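    /// NF4 matrix-vector multiplication. Weights are 4-bit indices into the
    /// `quant` lookup table, two per byte (so `matrix` is `[K / 2, M, B]`),
    /// scaled by one `absmax` per `NF4_BLOCK_SIZE` weights.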
    pub fn matmul_vec_nf4<'a, 'b, F0: Float, F1: Float>(
        matrix: &TensorGpu<u8, ReadWrite>,
        quant: &TensorGpu<f32, Uniform>,
        absmax: &TensorGpu<f16, ReadWrite>,
        input: impl Into<TensorGpuView<'a, F0>>,
        output: impl Into<TensorGpuView<'b, F1>>,
        act: Activation,
        sparse: bool,
    ) -> Result<Self, TensorError> {
        const BLOCK_SIZE: u32 = 128;

        let input: TensorGpuView<_> = input.into();
        let output: TensorGpuView<_> = output.into();

        let context = matrix.context();
        let shape = {
            let [m, n, b, _] = output.shape().into();
            let [k, _, _, _] = input.shape().into();
            let len = matrix.shape().len() << 1;
            absmax.check_shape([len.div_ceil(Self::NF4_BLOCK_SIZE as usize), 1, 1, 1])?;
            matrix.check_shape([k >> 1, m, b, 1])?;
            input.check_shape([k, n, b, 1])?;
            output.check_shape([m, n, b, 1])?;
            output.shape()
        };

        #[cfg(not(feature = "subgroup-ops"))]
        let key = PipelineKey::new(
            "matmul_vec_nf4",
            "matmul",
            Macros::new()
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .nf4(Self::NF4_BLOCK_SIZE)
                .tensor(&input, Some("IN"))
                .tensor(&output, Some("OUT"))
                .activate("ACT", act)
                .bool("SPARSE_INPUT", sparse),
        );
        #[cfg(feature = "subgroup-ops")]
        let key = PipelineKey::new(
            "matmul_vec_nf4",
            "matmul",
            Macros::new()
                .subgroup(context.min_subgroup_size(), context.max_subgroup_size())
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .nf4(Self::NF4_BLOCK_SIZE)
                .tensor(&input, Some("IN"))
                .tensor(&output, Some("OUT"))
                .activate("ACT", act)
                .bool("SPARSE_INPUT", sparse),
        );

        #[cfg(not(feature = "subgroup-ops"))]
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/matmul_vec_nf4.wgsl"),
            &[
                matrix.meta_layout(0),
                input.meta_layout(1),
                output.meta_layout(2),
                quant.layout(3),
                matrix.layout(4, true),
                absmax.layout(5, true),
                input.layout(6, true),
                output.layout(7, false),
            ],
        );
        #[cfg(feature = "subgroup-ops")]
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/subgroup/matmul_vec_nf4.wgsl"),
            &[
                matrix.meta_layout(0),
                input.meta_layout(1),
                output.meta_layout(2),
                quant.layout(3),
                matrix.layout(4, true),
                absmax.layout(5, true),
                input.layout(6, true),
                output.layout(7, false),
            ],
        );

        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, matrix)
            .bind_meta(1, &input)
            .bind_meta(2, &output)
            .bind(3, quant)
            .bind(4, matrix)
            .bind(5, absmax)
            .bind(6, &input)
            .bind(7, &output)
            .build()];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch: [matrix.shape[1] as u32 / 4, shape[1] as u32, shape[2] as u32],
        })
    }

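    /// Fp16 matrix-matrix multiplication; same shape convention as
    /// [`Self::matmul_vec_fp16`], but tiled in both output dimensions, and
    /// the matrix itself may be a view.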
    pub fn matmul_mat_fp16<'a, 'b, 'c, F0: Float, F1: Float>(
        matrix: impl Into<TensorGpuView<'c, f16>>,
        input: impl Into<TensorGpuView<'a, F0>>,
        output: impl Into<TensorGpuView<'b, F1>>,
        act: Activation,
    ) -> Result<Self, TensorError> {
        const BLOCK_SIZE: u32 = 8;

        let matrix: TensorGpuView<_> = matrix.into();
        let input: TensorGpuView<_> = input.into();
        let output: TensorGpuView<_> = output.into();

        let context = output.context();
        let shape = {
            let [m, n, b, _] = output.shape().into();
            let [k, _, _, _] = input.shape().into();
            matrix.check_shape([k, m, b, 1])?;
            input.check_shape([k, n, b, 1])?;
            output.check_shape([m, n, b, 1])?;
            output.shape()
        };

        let key = PipelineKey::new(
            "matmul_mat_fp16",
            "matmul",
            Macros::new()
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .tensor(&input, Some("IN"))
                .tensor(&output, Some("OUT"))
                .activate("ACT", act),
        );
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/matmul_mat_fp16.wgsl"),
            &[
                matrix.meta_layout(0),
                input.meta_layout(1),
                output.meta_layout(2),
                matrix.layout(3, true),
                input.layout(4, true),
                output.layout(5, false),
            ],
        );

        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, &matrix)
            .bind_meta(1, &input)
            .bind_meta(2, &output)
            .bind(3, &matrix)
            .bind(4, &input)
            .bind(5, &output)
            .build()];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch: [
                u32::div_ceil(u32::div_ceil(shape[0] as u32, 4), BLOCK_SIZE),
                u32::div_ceil(u32::div_ceil(shape[1] as u32, 4), BLOCK_SIZE),
                shape[2] as u32,
            ],
        })
    }

    #[allow(clippy::too_many_arguments)]
    pub fn matmul_mat_int8<'a, 'b, 'c, F0: Float, F1: Float>(
        matrix: impl Into<TensorGpuView<'c, u8>>,
        minmax: &TensorGpu<f16, ReadWrite>,
        input: impl Into<TensorGpuView<'a, F0>>,
        output: impl Into<TensorGpuView<'b, F1>>,
        act: Activation,
    ) -> Result<Self, TensorError> {
        const BLOCK_SIZE: u32 = 8;

        let matrix: TensorGpuView<_> = matrix.into();
        let input: TensorGpuView<_> = input.into();
        let output: TensorGpuView<_> = output.into();

        let context = output.context();
        let shape = {
            let [m, n, b, _] = output.shape().into();
            let [k, _, _, _] = input.shape().into();
            let len = matrix.shape().len();
            minmax.check_shape([(len << 1).div_ceil(Self::INT8_BLOCK_SIZE as usize), 1, 1, 1])?;
            matrix.check_shape([k, m, b, 1])?;
            input.check_shape([k, n, b, 1])?;
            output.check_shape([m, n, b, 1])?;
            output.shape()
        };

        let key = PipelineKey::new(
            "matmul_mat_int8",
            "matmul",
            Macros::new()
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .int8(Self::INT8_BLOCK_SIZE)
                .tensor(&input, Some("IN"))
                .tensor(&output, Some("OUT"))
                .activate("ACT", act),
        );
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/matmul_mat_int8.wgsl"),
            &[
                matrix.meta_layout(0),
                input.meta_layout(1),
                output.meta_layout(2),
                minmax.layout(3, true),
                matrix.layout(4, true),
                input.layout(5, true),
                output.layout(6, false),
            ],
        );

        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, &matrix)
            .bind_meta(1, &input)
            .bind_meta(2, &output)
            .bind(3, minmax)
            .bind(4, &matrix)
            .bind(5, &input)
            .bind(6, &output)
            .build()];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch: [
                u32::div_ceil(u32::div_ceil(shape[0] as u32, 4), BLOCK_SIZE),
                u32::div_ceil(u32::div_ceil(shape[1] as u32, 4), BLOCK_SIZE),
                shape[2] as u32,
            ],
        })
    }

    pub fn matmul_mat_nf4<'a, 'b, 'c, F0: Float, F1: Float>(
        matrix: impl Into<TensorGpuView<'c, u8>>,
        quant: &TensorGpu<f32, Uniform>,
        absmax: &TensorGpu<f16, ReadWrite>,
        input: impl Into<TensorGpuView<'a, F0>>,
        output: impl Into<TensorGpuView<'b, F1>>,
        act: Activation,
    ) -> Result<Self, TensorError> {
        const BLOCK_SIZE: u32 = 8;

        let matrix: TensorGpuView<_> = matrix.into();
        let input: TensorGpuView<_> = input.into();
        let output: TensorGpuView<_> = output.into();

        let context = output.context();
        let shape = {
            let [m, n, b, _] = output.shape().into();
            let [k, _, _, _] = input.shape().into();
            let len = matrix.shape().len() << 1;
            absmax.check_shape([len.div_ceil(Self::NF4_BLOCK_SIZE as usize), 1, 1, 1])?;
            matrix.check_shape([k >> 1, m, b, 1])?;
            input.check_shape([k, n, b, 1])?;
            output.check_shape([m, n, b, 1])?;
            output.shape()
        };

        let key = PipelineKey::new(
            "matmul_mat_nf4",
            "matmul",
            Macros::new()
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .nf4(Self::NF4_BLOCK_SIZE)
                .tensor(&input, Some("IN"))
                .tensor(&output, Some("OUT"))
                .activate("ACT", act),
        );
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/matmul_mat_nf4.wgsl"),
            &[
                matrix.meta_layout(0),
                input.meta_layout(1),
                output.meta_layout(2),
                quant.layout(3),
                absmax.layout(4, true),
                matrix.layout(5, true),
                input.layout(6, true),
                output.layout(7, false),
            ],
        );

        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, &matrix)
            .bind_meta(1, &input)
            .bind_meta(2, &output)
            .bind(3, quant)
            .bind(4, absmax)
            .bind(5, &matrix)
            .bind(6, &input)
            .bind(7, &output)
            .build()];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch: [
                u32::div_ceil(u32::div_ceil(shape[0] as u32, 4), BLOCK_SIZE),
                u32::div_ceil(u32::div_ceil(shape[1] as u32, 4), BLOCK_SIZE),
                shape[2] as u32,
            ],
        })
    }

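    /// Broadcast addition: `output = act_out(act_x(input) + act_y(output))`,
    /// where `input` may broadcast over the token and batch dimensions.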
    pub fn add_activate<'a, 'b, F0: Float, F1: Float>(
        input: impl Into<TensorGpuView<'a, F0>>,
        output: impl Into<TensorGpuView<'b, F1>>,
        act_x: Activation,
        act_y: Activation,
        act_out: Activation,
    ) -> Result<Self, TensorError> {
        const BLOCK_SIZE: u32 = 128;

        let input: TensorGpuView<_> = input.into();
        let output: TensorGpuView<_> = output.into();

        let context = output.context();
        let shape = {
            let [index, token, batch, _] = output.shape().into();
            input.check_shape_any(&[
                [index, token, batch, 1],
                [index, token, 1, batch],
                [index, 1, batch, 1],
                [index, 1, 1, 1],
            ])?;
            output.check_shape([index, token, batch, 1])?;
            output.shape()
        };

        let key = PipelineKey::new(
            "add",
            "add",
            Macros::new()
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .tensor(&input, Some("IN"))
                .tensor(&output, Some("OUT"))
                .activate("ACT_X", act_x)
                .activate("ACT_Y", act_y)
                .activate("ACT_OUT", act_out),
        );
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/binary.wgsl"),
            &[
                input.meta_layout(0),
                output.meta_layout(1),
                input.layout(2, true),
                output.layout(3, false),
            ],
        );

        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, &input)
            .bind_meta(1, &output)
            .bind(2, &input)
            .bind(3, &output)
            .build()];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch: [
                u32::div_ceil(shape[0] as u32 / 4, BLOCK_SIZE),
                shape[1] as u32,
                shape[2] as u32,
            ],
        })
    }

    pub fn add<'a, 'b, F0: Float, F1: Float>(
        input: impl Into<TensorGpuView<'a, F0>>,
        output: impl Into<TensorGpuView<'b, F1>>,
    ) -> Result<Self, TensorError> {
        Self::add_activate(
            input,
            output,
            Activation::None,
            Activation::None,
            Activation::None,
        )
    }

    pub fn mul_activate<'a, 'b, F0: Float, F1: Float>(
        input: impl Into<TensorGpuView<'a, F0>>,
        output: impl Into<TensorGpuView<'b, F1>>,
        act_x: Activation,
        act_y: Activation,
        act_out: Activation,
    ) -> Result<Self, TensorError> {
        const BLOCK_SIZE: u32 = 128;

        let input: TensorGpuView<_> = input.into();
        let output: TensorGpuView<_> = output.into();

        let context = output.context();
        let shape = {
            let [index, token, batch, _] = output.shape().into();
            input.check_shape_any(&[
                [index, token, batch, 1],
                [index, token, 1, batch],
                [index, 1, batch, 1],
                [index, 1, 1, 1],
            ])?;
            output.check_shape([index, token, batch, 1])?;
            output.shape()
        };

        let key = PipelineKey::new(
            "mul",
            "mul",
            Macros::new()
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .tensor(&input, Some("IN"))
                .tensor(&output, Some("OUT"))
                .activate("ACT_X", act_x)
                .activate("ACT_Y", act_y)
                .activate("ACT_OUT", act_out),
        );
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/binary.wgsl"),
            &[
                input.meta_layout(0),
                output.meta_layout(1),
                input.layout(2, true),
                output.layout(3, false),
            ],
        );

        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, &input)
            .bind_meta(1, &output)
            .bind(2, &input)
            .bind(3, &output)
            .build()];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch: [
                u32::div_ceil(shape[0] as u32 / 4, BLOCK_SIZE),
                shape[1] as u32,
                shape[2] as u32,
            ],
        })
    }

    pub fn mul<'a, 'b, F0: Float, F1: Float>(
        input: impl Into<TensorGpuView<'a, F0>>,
        output: impl Into<TensorGpuView<'b, F1>>,
    ) -> Result<Self, TensorError> {
        Self::mul_activate(
            input,
            output,
            Activation::None,
            Activation::None,
            Activation::None,
        )
    }

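    /// Token shift: blends each token with the previous one using the
    /// per-channel factors in `time_mix`; `state` supplies the token that
    /// precedes the first token of each sequence in the batch.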
    pub fn token_shift<'a, 'b, F: Float>(
        cursors: &TensorGpu<u32, ReadWrite>,
        time_mix: impl Into<TensorGpuView<'a, F>>,
        state: impl Into<TensorGpuView<'b, f32>>,
        input: &TensorGpu<impl Float, ReadWrite>,
        output: &TensorGpu<impl Float, ReadWrite>,
        reversed: bool,
    ) -> Result<Self, TensorError> {
        const BLOCK_SIZE: u32 = 128;

        let time_mix: TensorGpuView<_> = time_mix.into();
        let state: TensorGpuView<_> = state.into();

        let context = output.context();
        let shape = {
            let [index, token, count, _] = output.shape().into();
            let [_, head, batch, _] = state.shape().into();
            input.check_shape_any(&[[index, token, count, 1], [index, token, 1, 1]])?;
            time_mix.check_shape_any(&[[index, token, count, 1], [index, 1, 1, 1]])?;
            state.check_shape([index, head, batch, 1])?;
            output.shape()
        };

        let key = PipelineKey::new(
            "token_shift",
            "token_shift",
            Macros::new()
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .tensor(&time_mix, Some("TIME_MIX"))
                .tensor(input, Some("IN"))
                .tensor(output, Some("OUT"))
                .bool("REVERSED", reversed),
        );
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/token_shift.wgsl"),
            &[
                output.meta_layout(0),
                time_mix.meta_layout(1),
                state.meta_layout(2),
                cursors.layout(3, true),
                time_mix.layout(4, true),
                state.layout(5, true),
                input.layout(6, true),
                output.layout(7, false),
            ],
        );

        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, output)
            .bind_meta(1, &time_mix)
            .bind_meta(2, &state)
            .bind(3, cursors)
            .bind(4, &time_mix)
            .bind(5, &state)
            .bind(6, input)
            .bind(7, output)
            .build()];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch: [
                u32::div_ceil(shape[0] as u32 / 4, BLOCK_SIZE),
                shape[1] as u32,
                shape[2] as u32,
            ],
        })
    }

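    /// The WKV kernel of the v4 architecture: reads `k`, `v` and `r`, updates
    /// `state`, and writes the mixed result to `x` in place.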
    #[allow(clippy::too_many_arguments)]
    pub fn time_mix_v4<'a, T: Float>(
        cursors: &TensorGpu<u32, ReadWrite>,
        time_decay: &TensorGpu<f32, ReadWrite>,
        time_first: &TensorGpu<f32, ReadWrite>,
        state: impl Into<TensorGpuView<'a, f32>>,
        k: &TensorGpu<T, ReadWrite>,
        v: &TensorGpu<T, ReadWrite>,
        r: &TensorGpu<T, ReadWrite>,
        x: &TensorGpu<T, ReadWrite>,
    ) -> Result<Self, TensorError> {
        const BLOCK_SIZE: u32 = 128;

        let state: TensorGpuView<_> = state.into();

        let context = x.context();
        let shape = x.shape();
        k.check_shape(shape)?;
        v.check_shape(shape)?;
        r.check_shape(shape)?;
        time_decay.check_shape([shape[0], 1, 1, 1])?;
        time_first.check_shape([shape[0], 1, 1, 1])?;
        state.check_shape([shape[0], 4, state.shape()[2], 1])?;

        let key = PipelineKey::new(
            "time_mix_v4",
            "time_mix",
            Macros::new().u32("BLOCK_SIZE", BLOCK_SIZE).tensor(x, None),
        );
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/time_mix_v4.wgsl"),
            &[
                x.meta_layout(0),
                state.meta_layout(1),
                cursors.layout(2, true),
                time_decay.layout(3, true),
                time_first.layout(4, true),
                state.layout(5, false),
                k.layout(6, true),
                v.layout(7, true),
                r.layout(8, true),
                x.layout(9, false),
            ],
        );

        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, x)
            .bind_meta(1, &state)
            .bind(2, cursors)
            .bind(3, time_decay)
            .bind(4, time_first)
            .bind(5, &state)
            .bind(6, k)
            .bind(7, v)
            .bind(8, r)
            .bind(9, x)
            .build()];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch: [u32::div_ceil(shape[0] as u32 / 4, BLOCK_SIZE), 1, 1],
        })
    }

    #[allow(clippy::too_many_arguments)]
    pub fn time_mix_v5<'a, T: Float>(
        cursors: &TensorGpu<u32, ReadWrite>,
        time_decay: &TensorGpu<f32, ReadWrite>,
        time_first: &TensorGpu<f32, ReadWrite>,
        state: impl Into<TensorGpuView<'a, f32>>,
        k: &TensorGpu<T, ReadWrite>,
        v: &TensorGpu<T, ReadWrite>,
        r: &TensorGpu<T, ReadWrite>,
        x: &TensorGpu<T, ReadWrite>,
    ) -> Result<Self, TensorError> {
        const BLOCK_SIZE: u32 = 32;

        let state: TensorGpuView<_> = state.into();

        let context = x.context();
        let shape = x.shape();
        let stride = shape[0] * shape[1];

        k.check_shape(shape)?;
        v.check_shape(shape)?;
        r.check_shape(shape)?;
        time_decay.check_shape([shape[0], shape[1], 1, 1])?;
        time_first.check_shape([shape[0], shape[1], 1, 1])?;
        state.check_shape([stride, shape[0] + 1, state.shape()[2], 1])?;

        let key = PipelineKey::new(
            "time_mix_v5",
            "time_mix",
            Macros::new()
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .u32("HEAD_SIZE", shape[0] as u32 / 4)
                .tensor(x, None),
        );
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/time_mix_v5.wgsl"),
            &[
                x.meta_layout(0),
                state.meta_layout(1),
                cursors.layout(2, true),
                time_decay.layout(3, true),
                time_first.layout(4, true),
                state.layout(5, false),
                k.layout(6, true),
                v.layout(7, true),
                r.layout(8, true),
                x.layout(9, false),
            ],
        );

        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, x)
            .bind_meta(1, &state)
            .bind(2, cursors)
            .bind(3, time_decay)
            .bind(4, time_first)
            .bind(5, &state)
            .bind(6, k)
            .bind(7, v)
            .bind(8, r)
            .bind(9, x)
            .build()];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch: [u32::div_ceil(stride as u32 / 4, BLOCK_SIZE), 1, 1],
        })
    }

    #[allow(clippy::too_many_arguments)]
    pub fn time_mix_v6<'a, T: Float>(
        cursors: &TensorGpu<u32, ReadWrite>,
        time_decay: &TensorGpu<f32, ReadWrite>,
        time_first: &TensorGpu<f32, ReadWrite>,
        state: impl Into<TensorGpuView<'a, f32>>,
        k: &TensorGpu<T, ReadWrite>,
        v: &TensorGpu<T, ReadWrite>,
        r: &TensorGpu<T, ReadWrite>,
        x: &TensorGpu<T, ReadWrite>,
    ) -> Result<Self, TensorError> {
        const BLOCK_SIZE: u32 = 32;

        let state: TensorGpuView<_> = state.into();

        let context = x.context();
        let shape = x.shape();
        let stride = shape[0] * shape[1];

        k.check_shape(shape)?;
        v.check_shape(shape)?;
        r.check_shape(shape)?;
        time_decay.check_shape(shape)?;
        time_first.check_shape([shape[0], shape[1], 1, 1])?;
        state.check_shape([stride, shape[0] + 1, state.shape()[2], 1])?;

        let key = PipelineKey::new(
            "time_mix_v6",
            "time_mix",
            Macros::new()
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .u32("HEAD_SIZE", shape[0] as u32 / 4)
                .tensor(x, None),
        );
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/time_mix_v6.wgsl"),
            &[
                x.meta_layout(0),
                state.meta_layout(1),
                cursors.layout(2, true),
                time_decay.layout(3, true),
                time_first.layout(4, true),
                state.layout(5, false),
                k.layout(6, true),
                v.layout(7, true),
                r.layout(8, true),
                x.layout(9, false),
            ],
        );

        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, x)
            .bind_meta(1, &state)
            .bind(2, cursors)
            .bind(3, time_decay)
            .bind(4, time_first)
            .bind(5, &state)
            .bind(6, k)
            .bind(7, v)
            .bind(8, r)
            .bind(9, x)
            .build()];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch: [u32::div_ceil(stride as u32 / 4, BLOCK_SIZE), 1, 1],
        })
    }

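    /// The time-mix kernel of the v7 architecture.
    ///
    /// ![time mix v7][time-mix-v7]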
    #[embed_doc_image("time-mix-v7", "src/tensor/time-mix-v7.png")]
    pub fn time_mix_v7<'a, T: Float>(
        cursors: &TensorGpu<u32, ReadWrite>,
        state: impl Into<TensorGpuView<'a, f32>>,
        r: &TensorGpu<T, ReadWrite>,
        w: &TensorGpu<T, ReadWrite>,
        n: &TensorGpu<T, ReadWrite>,
        x: &TensorGpu<T, ReadWrite>,
    ) -> Result<Self, TensorError> {
        const BLOCK_SIZE: u32 = 32;

        let state: TensorGpuView<_> = state.into();

        let context = x.context();
        let shape = x.shape();
        let stride = shape[0] * shape[1];

        r.check_shape(shape)?;
        w.check_shape(shape)?;
        n.check_shape([shape[0], shape[1], shape[2], 4])?;
        state.check_shape([stride, shape[0] + 1, state.shape()[2], 1])?;

        let key = PipelineKey::new(
            "time_mix_v7",
            "time_mix",
            Macros::new()
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .u32("HEAD_SIZE", shape[0] as u32 / 4)
                .bool("TIME_MIX", true)
                .tensor(x, None)
                .activate("ACT", Activation::None),
        );
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/time_mix_v7.wgsl"),
            &[
                x.meta_layout(0),
                state.meta_layout(1),
                cursors.layout(2, true),
                state.layout(3, false),
                r.layout(5, true),
                w.layout(6, true),
                n.layout(7, true),
                x.layout(9, false),
            ],
        );

        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, x)
            .bind_meta(1, &state)
            .bind(2, cursors)
            .bind(3, &state)
            .bind(5, r)
            .bind(6, w)
            .bind(7, n)
            .bind(9, x)
            .build()];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch: [u32::div_ceil(stride as u32 / 4, BLOCK_SIZE), 1, 1],
        })
    }

    pub fn time_first_v7<T: Float>(
        u: &TensorGpu<f16, ReadWrite>,
        r: &TensorGpu<T, ReadWrite>,
        n: &TensorGpu<T, ReadWrite>,
        x: &TensorGpu<T, ReadWrite>,
    ) -> Result<Self, TensorError> {
        const BLOCK_SIZE: u32 = 32;

        let context = x.context();
        let shape = x.shape();
        let stride = shape[0] * shape[1];

        r.check_shape(shape)?;
        u.check_shape([shape[0], shape[1], 1, 1])?;
        n.check_shape([shape[0], shape[1], shape[2], 4])?;

        let key = PipelineKey::new(
            "time_first_v7",
            "time_first",
            Macros::new()
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .u32("HEAD_SIZE", shape[0] as u32 / 4)
                .bool("TIME_FIRST", true)
                .tensor(x, None)
                .activate("ACT", Activation::None),
        );
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/time_mix_v7.wgsl"),
            &[
                x.meta_layout(0),
                u.layout(4, true),
                r.layout(5, true),
                n.layout(7, true),
                x.layout(9, false),
            ],
        );

        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, x)
            .bind(4, u)
            .bind(5, r)
            .bind(7, n)
            .bind(9, x)
            .build()];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch: [
                u32::div_ceil(stride as u32 / 4, BLOCK_SIZE),
                shape[2] as u32,
                1,
            ],
        })
    }

    pub fn control_k_v7<'a, 'b, F0: Float, F1: Float>(
        p: &TensorGpu<f16, ReadWrite>,
        a: impl Into<TensorGpuView<'a, F0>>,
        k: impl Into<TensorGpuView<'b, F1>>,
    ) -> Result<Self, TensorError> {
        const BLOCK_SIZE: u32 = 128;

        let p: TensorGpuView<_> = p.into();
        let a: TensorGpuView<_> = a.into();
        let k: TensorGpuView<_> = k.into();

        let context = k.context();
        let shape = {
            let [index, token, batch, _] = k.shape().into();
            a.check_shape([index, token, batch, 1])?;
            p.check_shape([index, 1, 1, 1])?;
            k.shape()
        };

        let key = PipelineKey::new(
            "control_k_v7",
            "main",
            Macros::new()
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .tensor(&a, Some("A"))
                .tensor(&k, Some("K")),
        );
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/control_k_v7.wgsl"),
            &[
                p.meta_layout(0),
                a.meta_layout(1),
                k.meta_layout(2),
                p.layout(3, true),
                a.layout(4, true),
                k.layout(5, false),
            ],
        );

        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, &p)
            .bind_meta(1, &a)
            .bind_meta(2, &k)
            .bind(3, &p)
            .bind(4, &a)
            .bind(5, &k)
            .build()];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch: [
                u32::div_ceil(shape[0] as u32 / 4, BLOCK_SIZE),
                shape[1] as u32,
                shape[2] as u32,
            ],
        })
    }

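    /// The channel-mix kernel shared by the pre-v7 architectures; reads `r`
    /// and `v`, updates `state`, and writes the mixed result to `x` in place.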
    pub fn channel_mix<'a, T: Float>(
        cursors: &TensorGpu<u32, ReadWrite>,
        state: impl Into<TensorGpuView<'a, f32>>,
        r: &TensorGpu<T, ReadWrite>,
        v: &TensorGpu<T, ReadWrite>,
        x: &TensorGpu<T, ReadWrite>,
    ) -> Result<Self, TensorError> {
        const BLOCK_SIZE: u32 = 128;

        let state: TensorGpuView<_> = state.into();

        let context = x.context();
        let shape = x.shape();
        v.check_shape(shape)?;
        r.check_shape(shape)?;
        state.check_shape([shape[0], 1, state.shape()[2], 1])?;

        let key = PipelineKey::new(
            "channel_mix",
            "channel_mix",
            Macros::new().u32("BLOCK_SIZE", BLOCK_SIZE).tensor(x, None),
        );
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/channel_mix.wgsl"),
            &[
                x.meta_layout(0),
                state.meta_layout(1),
                cursors.layout(2, true),
                state.layout(3, false),
                r.layout(4, true),
                v.layout(5, true),
                x.layout(6, false),
            ],
        );

        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, x)
            .bind_meta(1, &state)
            .bind(2, cursors)
            .bind(3, &state)
            .bind(4, r)
            .bind(5, v)
            .bind(6, x)
            .build()];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch: [
                u32::div_ceil(shape[0] as u32 / 4, BLOCK_SIZE),
                shape[1] as u32,
                1,
            ],
        })
    }

    pub fn channel_mix_v7<'a, T: Float>(
        cursors: &TensorGpu<u32, ReadWrite>,
        state: impl Into<TensorGpuView<'a, f32>>,
        v: &TensorGpu<T, ReadWrite>,
        x: &TensorGpu<T, ReadWrite>,
    ) -> Result<Self, TensorError> {
        const BLOCK_SIZE: u32 = 128;

        let state: TensorGpuView<_> = state.into();

        let context = x.context();
        let shape = x.shape();
        v.check_shape(shape)?;
        state.check_shape([shape[0], 1, state.shape()[2], 1])?;

        let key = PipelineKey::new(
            "channel_mix",
            "channel_mix",
            Macros::new()
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .tensor(x, None)
                .bool("V7", true),
        );
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/channel_mix.wgsl"),
            &[
                x.meta_layout(0),
                state.meta_layout(1),
                cursors.layout(2, true),
                state.layout(3, false),
                v.layout(5, true),
                x.layout(6, false),
            ],
        );

        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, x)
            .bind_meta(1, &state)
            .bind(2, cursors)
            .bind(3, &state)
            .bind(5, v)
            .bind(6, x)
            .build()];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch: [
                u32::div_ceil(shape[0] as u32 / 4, BLOCK_SIZE),
                shape[1] as u32,
                1,
            ],
        })
    }

    pub fn activate<'a, F: Float>(
        x: impl Into<TensorGpuView<'a, F>>,
        act: Activation,
    ) -> Result<Self, TensorError> {
        const BLOCK_SIZE: u32 = 128;

        let x: TensorGpuView<_> = x.into();

        let context = x.context();
        let shape = x.shape();

        let key = PipelineKey::new(
            "activate",
            "act",
            Macros::new()
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .tensor(&x, None)
                .activate("ACT", act),
        );
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/activation.wgsl"),
            &[x.meta_layout(0), x.layout(1, false)],
        );

        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, &x)
            .bind(1, &x)
            .build()];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch: [
                u32::div_ceil(shape[0] as u32 / 4, BLOCK_SIZE),
                shape[1] as u32,
                shape[2] as u32,
            ],
        })
    }

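    /// Copies `input` into `output` element-wise, converting between float
    /// types if they differ; the two shapes must match.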
    pub fn blit<'a, 'b, F0: Float, F1: Float>(
        input: impl Into<TensorGpuView<'a, F0>>,
        output: impl Into<TensorGpuView<'b, F1>>,
    ) -> Result<Self, TensorError> {
        let input: TensorGpuView<_> = input.into();
        let output: TensorGpuView<_> = output.into();

        let context = input.context();
        let shape = output.shape();
        input.check_shape(shape)?;

        let block_size = match shape[1] {
            x if x < 8 => [128, 1],
            _ => [16, 16],
        };

        let key = PipelineKey::new(
            "blit",
            "blit",
            Macros::new()
                .u32("BLOCK_SIZE_X", block_size[0])
                .u32("BLOCK_SIZE_Y", block_size[1])
                .tensor(&input, Some("IN"))
                .tensor(&output, Some("OUT")),
        );
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/blit.wgsl"),
            &[
                input.meta_layout(0),
                output.meta_layout(1),
                input.layout(2, true),
                output.layout(3, false),
            ],
        );

        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, &input)
            .bind_meta(1, &output)
            .bind(2, &input)
            .bind(3, &output)
            .build()];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch: [
                u32::div_ceil(shape[0] as u32 / 4, block_size[0]),
                u32::div_ceil(shape[1] as u32, block_size[1]),
                shape[2] as u32,
            ],
        })
    }

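    /// Broadcasts `input` over the token and batch dimensions of `output`;
    /// only the channel dimension must match.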
    pub fn broadcast<'a, 'b, F0: Float, F1: Float>(
        input: impl Into<TensorGpuView<'a, F0>>,
        output: impl Into<TensorGpuView<'b, F1>>,
    ) -> Result<Self, TensorError> {
        const BLOCK_SIZE: u32 = 128;

        let input: TensorGpuView<_> = input.into();
        let output: TensorGpuView<_> = output.into();

        let context = input.context();
        let shape = output.shape();
        input.check_shape([shape[0], input.shape()[1], input.shape()[2], 1])?;

        let key = PipelineKey::new(
            "broadcast",
            "broadcast",
            Macros::new()
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .tensor(&input, Some("IN"))
                .tensor(&output, Some("OUT")),
        );
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/reshape.wgsl"),
            &[
                input.meta_layout(0),
                output.meta_layout(1),
                input.layout(2, true),
                output.layout(3, false),
            ],
        );

        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, &input)
            .bind_meta(1, &output)
            .bind(2, &input)
            .bind(3, &output)
            .build()];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch: [
                u32::div_ceil(shape[0] as u32 / 4, BLOCK_SIZE),
                shape[1] as u32,
                shape[2] as u32,
            ],
        })
    }

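    /// Writes the transpose of `input` into `output`, swapping the token and
    /// batch dimensions (`[C, T, B]` to `[C, B, T]`).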
    pub fn transpose<'a, 'b, F0: Float, F1: Float>(
        input: impl Into<TensorGpuView<'a, F0>>,
        output: impl Into<TensorGpuView<'b, F1>>,
    ) -> Result<Self, TensorError> {
        const BLOCK_SIZE: u32 = 128;

        let input: TensorGpuView<_> = input.into();
        let output: TensorGpuView<_> = output.into();

        let context = input.context();
        let shape = input.shape();
        output.check_shape([shape[0], shape[2], shape[1], 1])?;

        let key = PipelineKey::new(
            "transpose",
            "transpose",
            Macros::new()
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .tensor(&input, Some("IN"))
                .tensor(&output, Some("OUT")),
        );
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/reshape.wgsl"),
            &[
                input.meta_layout(0),
                output.meta_layout(1),
                input.layout(2, true),
                output.layout(3, false),
            ],
        );

        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, &input)
            .bind_meta(1, &output)
            .bind(2, &input)
            .bind(3, &output)
            .build()];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch: [
                u32::div_ceil(shape[0] as u32 / 4, BLOCK_SIZE),
                shape[1] as u32,
                shape[2] as u32,
            ],
        })
    }

    pub fn blend(
        factor: &TensorGpu<f32, Uniform>,
        input: &TensorGpu<impl Float, ReadWrite>,
        output: &TensorGpu<impl Float, ReadWrite>,
    ) -> Result<Self, TensorError> {
        let context = output.context();
        let shape = output.shape();
        input.check_shape(shape)?;
        factor.check_shape([4, 1, 1, 1])?;

        let block_size = match shape[1] {
            x if x < 8 => [128, 1],
            _ => [16, 16],
        };

        let key = PipelineKey::new(
            "blend",
            "blend",
            Macros::new()
                .u32("BLOCK_SIZE_X", block_size[0])
                .u32("BLOCK_SIZE_Y", block_size[1])
                .tensor(input, Some("IN"))
                .tensor(output, Some("OUT")),
        );
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/blend.wgsl"),
            &[
                input.meta_layout(0),
                output.meta_layout(1),
                factor.layout(2),
                input.layout(3, true),
                output.layout(4, false),
            ],
        );

        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, input)
            .bind_meta(1, output)
            .bind(2, factor)
            .bind(3, input)
            .bind(4, output)
            .build()];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch: [
                u32::div_ceil(shape[0] as u32 / 4, block_size[0]),
                u32::div_ceil(shape[1] as u32, block_size[1]),
                shape[2] as u32,
            ],
        })
    }

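    /// Low-rank (LoRA) blend: accumulates the product of the two low-rank
    /// factors `xa` and `xb` into `output`, scaled by `factor`.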
    pub fn blend_lora<'a, 'b, 'c>(
        factor: &TensorGpu<f32, Uniform>,
        xa: impl Into<TensorGpuView<'a, f16>>,
        xb: impl Into<TensorGpuView<'b, f16>>,
        output: impl Into<TensorGpuView<'c, f16>>,
    ) -> Result<Self, TensorError> {
        const BLOCK_SIZE: u32 = 8;

        let xa: TensorGpuView<_> = xa.into();
        let xb: TensorGpuView<_> = xb.into();
        let output: TensorGpuView<_> = output.into();

        let context = output.context();
        let shape = output.shape();
        factor.check_shape([4, 1, 1, 1])?;
        xa.check_shape([xa.shape()[0], shape[0], shape[2], 1])?;
        xb.check_shape([xb.shape()[0], shape[1], shape[2], 1])?;

        let key = PipelineKey::new(
            "blend_lora",
            "blend_lora",
            Macros::new().u32("BLOCK_SIZE", BLOCK_SIZE),
        );
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/blend_lora.wgsl"),
            &[
                xa.meta_layout(0),
                xb.meta_layout(1),
                output.meta_layout(2),
                factor.layout(3),
                xa.layout(4, true),
                xb.layout(5, true),
                output.layout(6, false),
            ],
        );

        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, &xa)
            .bind_meta(1, &xb)
            .bind_meta(2, &output)
            .bind(3, factor)
            .bind(4, &xa)
            .bind(5, &xb)
            .bind(6, &output)
            .build()];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch: [
                u32::div_ceil(u32::div_ceil(shape[0] as u32, 4), BLOCK_SIZE),
                u32::div_ceil(u32::div_ceil(shape[1] as u32, 4), BLOCK_SIZE),
                shape[2] as u32,
            ],
        })
    }

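    /// Element-wise linear interpolation, in place on `output`. With
    /// `reversed` set to `false` this computes
    /// `output = input * (1 - factor) + output * factor` (as exercised by
    /// `test_lerp`); `reversed` flips the interpolation direction in
    /// `shaders/lerp.wgsl`. `factor` may broadcast over the token and batch
    /// dimensions.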
    pub fn lerp<'a, 'b, 'c, F0: Float, F1: Float, F2: Float>(
        input: impl Into<TensorGpuView<'a, F0>>,
        output: impl Into<TensorGpuView<'b, F1>>,
        factor: impl Into<TensorGpuView<'c, F2>>,
        reversed: bool,
    ) -> Result<Self, TensorError> {
        const BLOCK_SIZE: u32 = 128;

        let factor: TensorGpuView<_> = factor.into();
        let input: TensorGpuView<_> = input.into();
        let output: TensorGpuView<_> = output.into();

        let context = output.context();
        let shape = {
            let [index, token, batch, _] = output.shape().into();
            factor.check_shape_any(&[
                [index, token, batch, 1],
                [index, token, 1, 1],
                [index, 1, batch, 1],
                [index, 1, 1, 1],
            ])?;
            input.check_shape([index, token, batch, 1])?;
            output.shape()
        };

        let key = PipelineKey::new(
            "lerp",
            "lerp",
            Macros::new()
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .tensor(&factor, Some("FACTOR"))
                .tensor(&input, Some("IN"))
                .tensor(&output, Some("OUT"))
                .bool("REVERSED", reversed),
        );
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/lerp.wgsl"),
            &[
                factor.meta_layout(0),
                input.meta_layout(1),
                output.meta_layout(2),
                factor.layout(3, true),
                input.layout(4, true),
                output.layout(5, false),
            ],
        );

        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, &factor)
            .bind_meta(1, &input)
            .bind_meta(2, &output)
            .bind(3, &factor)
            .bind(4, &input)
            .bind(5, &output)
            .build()];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch: [
                u32::div_ceil(shape[0] as u32 / 4, BLOCK_SIZE),
                shape[1] as u32,
                shape[2] as u32,
            ],
        })
    }

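    /// Applies `x = x * scale + bias` element-wise, in place, via
    /// `shaders/affine.wgsl`.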
    pub fn affine(
        x: &TensorGpu<impl Float, ReadWrite>,
        scale: f32,
        bias: f32,
    ) -> Result<Self, TensorError> {
        const BLOCK_SIZE: u32 = 128;

        let context = x.context();
        let shape = x.shape();

        let key = PipelineKey::new(
            "affine",
            "affine",
            Macros::new()
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .tensor(x, None)
                .f32("SCALE", scale)
                .f32("BIAS", bias),
        );
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/affine.wgsl"),
            &[x.meta_layout(0), x.layout(1, false)],
        );

        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, x)
            .bind(1, x)
            .build()];

        Ok(Self::Atom {
            pipeline,
            bindings,
            dispatch: [
                u32::div_ceil(shape[0] as u32 / 4, BLOCK_SIZE),
                shape[1] as u32,
                shape[2] as u32,
            ],
        })
    }

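    /// Quantizes an `f16` matrix into `u8` in two passes: first reduce each
    /// block of [`Self::INT8_BLOCK_SIZE`] elements to an interleaved
    /// `[min, max]` pair in `minmax`, then map every element linearly into
    /// `0..=255` within its block's range.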
    pub fn quantize_mat_int8(
        input: &TensorGpu<f16, ReadWrite>,
        minmax: &TensorGpu<f16, ReadWrite>,
        output: &TensorGpu<u8, ReadWrite>,
    ) -> Result<Self, TensorError> {
        const BLOCK_SIZE: u32 = 128;

        let context = output.context();
        let shape = output.shape();
        let minmax_len = shape.len().div_ceil(Self::INT8_BLOCK_SIZE as usize);
        let minmax_shape = Shape::new(minmax_len << 1, 1, 1, 1);

        input.check_shape(shape)?;
        minmax.check_shape(minmax_shape)?;

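        // Pass 1: reduce each INT8_BLOCK_SIZE block of `input` to its [min, max] pair.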
        let key = PipelineKey::new(
            "quant_mat_int8_minmax",
            "compute_minmax",
            Macros::new()
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .int8(Self::INT8_BLOCK_SIZE),
        );
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/quant_mat_int8.wgsl"),
            &[
                minmax.meta_layout(0),
                input.meta_layout(1),
                input.layout(2, true),
                minmax.layout(3, false),
            ],
        );

        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, minmax)
            .bind_meta(1, input)
            .bind(2, input)
            .bind(3, minmax)
            .build()];

        let compute_minmax = Self::Atom {
            pipeline,
            bindings,
            dispatch: [
                u32::div_ceil(minmax_len as u32, BLOCK_SIZE * BLOCK_SIZE),
                BLOCK_SIZE,
                1,
            ],
        };

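        // Pass 2: flatten `output` so the shader can index it linearly, then quantize.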
        let output = output.reshape(
            TensorDimension::Auto,
            TensorDimension::Size(1),
            TensorDimension::Size(1),
            TensorDimension::Size(1),
        )?;

        let key = PipelineKey::new(
            "quant_mat_int8",
            "quantize",
            Macros::new()
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .int8(Self::INT8_BLOCK_SIZE),
        );
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/quant_mat_int8.wgsl"),
            &[
                output.meta_layout(0),
                input.layout(2, true),
                minmax.layout(3, false),
                output.layout(4, false),
            ],
        );

        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, &output)
            .bind(2, input)
            .bind(3, minmax)
            .bind(4, &output)
            .build()];

        let quantize = Self::Atom {
            pipeline,
            bindings,
            dispatch: [
                u32::div_ceil(shape[0] as u32, BLOCK_SIZE),
                shape[1] as u32,
                shape[2] as u32,
            ],
        };

        Ok(Self::List(vec![compute_minmax, quantize]))
    }

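    /// Quantizes an `f16` matrix into 4-bit NF4 codes, packed two per byte of
    /// `output`: compute the absolute maximum of each block of
    /// [`Self::NF4_BLOCK_SIZE`] elements, then snap every normalized element
    /// to the nearest entry of the 16-value `quant` codebook.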
    pub fn quantize_mat_nf4(
        input: &TensorGpu<f16, ReadWrite>,
        quant: &TensorGpu<f32, Uniform>,
        absmax: &TensorGpu<f16, ReadWrite>,
        output: &TensorGpu<u8, ReadWrite>,
    ) -> Result<Self, TensorError> {
        const BLOCK_SIZE: u32 = 128;

        let context = output.context();
        let shape = output.shape();
        let input_shape = Shape::new(shape[0] << 1, shape[1], shape[2], shape[3]);
        let absmax_len = input_shape.len().div_ceil(Self::NF4_BLOCK_SIZE as usize);
        let absmax_shape = Shape::new(absmax_len, 1, 1, 1);

        input.check_shape(input_shape)?;
        absmax.check_shape(absmax_shape)?;

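        // The absmax pass writes f32; the caller's tensor is f16, so stage the
        // result here and convert it with a final blit.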
        let absmax_f32: TensorGpu<f32, ReadWrite> = context.tensor_init(absmax_shape);

        let key = PipelineKey::new(
            "quant_mat_nf4_absmax",
            "compute_absmax",
            Macros::new()
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .nf4(Self::NF4_BLOCK_SIZE),
        );
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/quant_mat_nf4.wgsl"),
            &[
                absmax_f32.meta_layout(0),
                input.meta_layout(1),
                input.layout(3, true),
                absmax_f32.layout(4, false),
            ],
        );

        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, &absmax_f32)
            .bind_meta(1, input)
            .bind(3, input)
            .bind(4, &absmax_f32)
            .build()];

        let compute_absmax = Self::Atom {
            pipeline,
            bindings,
            dispatch: [
                u32::div_ceil(absmax_len as u32, BLOCK_SIZE * BLOCK_SIZE),
                BLOCK_SIZE,
                1,
            ],
        };

        let output = output.reshape(
            TensorDimension::Auto,
            TensorDimension::Size(1),
            TensorDimension::Size(1),
            TensorDimension::Size(1),
        )?;

        let key = PipelineKey::new(
            "quant_mat_nf4",
            "quantize",
            Macros::new()
                .u32("BLOCK_SIZE", BLOCK_SIZE)
                .nf4(Self::NF4_BLOCK_SIZE),
        );
        let pipeline = context.checkout_pipeline(
            &key,
            include_str!("../shaders/quant_mat_nf4.wgsl"),
            &[
                output.meta_layout(0),
                quant.layout(2),
                input.layout(3, true),
                absmax_f32.layout(4, false),
                output.layout(5, false),
            ],
        );

        let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
            .bind_meta(0, &output)
            .bind(2, quant)
            .bind(3, input)
            .bind(4, &absmax_f32)
            .bind(5, &output)
            .build()];

        let quantize = Self::Atom {
            pipeline,
            bindings,
            dispatch: [
                u32::div_ceil(shape[0] as u32, BLOCK_SIZE),
                shape[1] as u32,
                shape[2] as u32,
            ],
        };

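        // Downcast the staged f32 absmax into the caller's f16 `absmax`.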
        let quantize_absmax = Self::blit(&absmax_f32, absmax)?;

        Ok(Self::List(vec![compute_absmax, quantize, quantize_absmax]))
    }
}

#[cfg(test)]
mod tests {
    use std::f32::consts::PI;

    use anyhow::Result;
    use half::f16;
    use itertools::Itertools;
    use wgpu::{Instance, PowerPreference};

    use super::TensorOp;
    use crate::{
        context::{Context, ContextBuilder, InstanceExt},
        tensor::{ops::Activation, Shape, TensorGpu},
    };

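    /// Approximate equality with a tolerance relative to the operands'
    /// magnitude, falling back to an absolute `f32::EPSILON` near zero.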
    fn is_approx(a: impl Into<f32>, b: impl Into<f32>) -> bool {
        let a: f32 = a.into();
        let b: f32 = b.into();
        (a - b).abs() <= f32::max(f32::EPSILON, f32::max(a.abs(), b.abs()) * f32::EPSILON)
    }

    fn is_approx_eps(a: impl Into<f32>, b: impl Into<f32>, eps: f32) -> bool {
        let a: f32 = a.into();
        let b: f32 = b.into();
        (a - b).abs() <= f32::max(eps, f32::max(a.abs(), b.abs()) * eps)
    }

    async fn create_context() -> Result<Context> {
        let instance = Instance::default();
        let adapter = instance.adapter(PowerPreference::HighPerformance).await?;
        let context = ContextBuilder::new(adapter)
            .build()
            .await?;
        Ok(context)
    }

    #[cfg(feature = "tokio")]
    #[tokio::test]
    async fn test_softmax() -> Result<()> {
        let context = create_context().await?;
        fastrand::seed(42);

        const C: usize = 1000;
        const T: usize = 3;
        const B: usize = 2;

        let x = [(); C * T * B]
            .map(|_| 10.0 * (fastrand::f32() - 0.5))
            .to_vec();
        let shape = Shape::new(C, T, B, 1);

        let x_dev: TensorGpu<_, _> = context.tensor_from_data(shape, x.clone())?;
        let softmax = TensorOp::softmax(&x_dev)?;

        context.queue.submit(context.encode(&softmax));
        let x_host = x_dev.back().await.to_vec();

        let mut ans = vec![];
        for x in &x.into_iter().chunks(C) {
            let x = x.collect_vec().into_iter();
            let max = x.clone().reduce(f32::max).unwrap_or_default();
            let x = x.map(|x| (x - max).exp());
            let sum: f32 = x.clone().sum();
            let x = x.map(|x| x / sum);
            ans.extend(x);
        }

        for (index, (a, b)) in itertools::zip_eq(x_host, ans).enumerate() {
            assert!(
                is_approx(a, b),
                "Failed at index {index}, computed: {a} vs. answer: {b}"
            );
        }

        Ok(())
    }

    #[cfg(feature = "tokio")]
    #[tokio::test]
    async fn test_layer_norm() -> Result<()> {
        let context = create_context().await?;
        fastrand::seed(42);

        const C: usize = 1000;
        const T: usize = 3;
        const B: usize = 2;
        const EPS: f32 = 1.0e-5;

        let x = [(); C * T * B]
            .map(|_| 10.0 * (fastrand::f32() - 0.5))
            .to_vec();
        let w = [(); C]
            .map(|_| f16::from_f32(fastrand::f32() - 0.5))
            .repeat(T * B)
            .to_vec();
        let b = [(); C]
            .map(|_| f16::from_f32(fastrand::f32() - 0.5))
            .repeat(T * B)
            .to_vec();

        let shape = Shape::new(C, T, B, 1);
        let x_dev = context.tensor_from_data(shape, x.clone())?;

        let shape = Shape::new(C, 1, 1, 1);
        let w_dev = context.tensor_from_data(shape, &w[..C])?;
        let b_dev = context.tensor_from_data(shape, &b[..C])?;

        let layer_norm = TensorOp::layer_norm(&w_dev, &b_dev, &x_dev, EPS)?;
        context.queue.submit(context.encode(&layer_norm));

        let x_host = x_dev.back().await.to_vec();
        let shape = Shape::new(C, T, B, 1);
        let x_dev = context.tensor_from_data(shape, x.clone())?;
        let ops = TensorOp::List(vec![
            TensorOp::recenter(&x_dev)?,
            TensorOp::rms_norm(&w_dev, &b_dev, &x_dev, EPS)?,
        ]);
        context.queue.submit(context.encode(&ops));

        let x_rms_host = x_dev.back().await.to_vec();

        let mut ans = vec![];
        for chunk in &x
            .into_iter()
            .zip(w.into_iter())
            .zip(b.into_iter())
            .chunks(C)
        {
            let chunk = chunk.collect_vec();
            let x = chunk.iter().map(|((x, _), _)| x).copied();
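            // Welford's online algorithm: a single pass over the chunk yields
            // the running mean and M2 (sum of squared deviations).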
            let (mean, m2, count) = x.fold((0.0f32, 0.0f32, 0u32), |(mean, m2, count), x| {
                let count = count + 1;
                let delta = x - mean;
                let mean = mean + delta / count as f32;
                let m2 = m2 + delta * (x - mean);
                (mean, m2, count)
            });
            let variance = m2 / count as f32 + EPS;
            let deviation = 1.0 / variance.sqrt();
            let x = chunk
                .into_iter()
                .map(|((x, w), b)| (x - mean) * deviation * w.to_f32() + b.to_f32());
            ans.extend(x);
        }

        for (index, (a, &b)) in itertools::zip_eq(x_host, ans.iter()).enumerate() {
            assert!(
                is_approx_eps(a, b, 1.0e-3),
                "Failed at index {index}, computed: {a} vs. answer: {b}"
            );
        }

        for (index, (a, &b)) in itertools::zip_eq(x_rms_host, ans.iter()).enumerate() {
            assert!(
                is_approx_eps(a, b, 1.0e-3),
                "Failed at index {index}, computed: {a} vs. answer: {b}"
            );
        }

        Ok(())
    }

    #[cfg(feature = "tokio")]
    #[tokio::test]
    async fn test_l2_norm() -> Result<()> {
        let context = create_context().await?;
        fastrand::seed(42);

        const C: usize = 1000;
        const T: usize = 3;
        const B: usize = 2;
        const EPS: f32 = 1.0e-12;

        let x = [(); C * T * B]
            .map(|_| 10.0 * (fastrand::f32() - 0.5))
            .to_vec();

        let shape = Shape::new(C, T, B, 1);
        let x_dev = context.tensor_from_data(shape, x.clone())?;

        let l2_norm = TensorOp::l2_norm(&x_dev, EPS)?;
        context.queue.submit(context.encode(&l2_norm));

        let x_host = x_dev.back().await.to_vec();

        let mut ans = vec![];
        for x in &x.into_iter().chunks(C) {
            let x = x.collect_vec().into_iter();
            let norm = x.clone().map(|x| x * x).sum::<f32>().sqrt();
            let x = x.map(|x| x / (norm + EPS));
            ans.extend(x);
        }

        for (index, (a, b)) in itertools::zip_eq(x_host, ans).enumerate() {
            assert!(
                is_approx(a, b),
                "Failed at index {index}, computed: {a} vs. answer: {b}"
            );
        }

        Ok(())
    }

    #[cfg(feature = "tokio")]
    #[tokio::test]
    async fn test_matmul() -> Result<()> {
        let context = create_context().await?;
        fastrand::seed(42);

        async fn test_matmul_inner(
            context: &Context,
            c: usize,
            r: usize,
            t: usize,
            b: usize,
        ) -> Result<()> {
            let matrix = vec![(); c * r * b]
                .into_iter()
                .map(|_| 10.0 * (fastrand::f32() - 0.5))
                .map(f16::from_f32)
                .collect_vec();
            let input_f32 = vec![(); c * t * b]
                .into_iter()
                .map(|_| 10.0 * (fastrand::f32() - 0.5))
                .collect_vec();
            let input_f16 = input_f32.iter().copied().map(f16::from_f32).collect_vec();

            let matrix_shape = Shape::new(c, r, b, 1);
            let input_shape = Shape::new(c, t, b, 1);
            let output_shape = Shape::new(r, t, 2 * b, 1);

            let matrix_dev = context.tensor_from_data(matrix_shape, matrix.clone())?;
            let input_f32_dev = context.tensor_from_data(input_shape, input_f32.clone())?;
            let input_f16_dev: TensorGpu<f16, _> = context.tensor_init(input_shape);
            let output_dev: TensorGpu<_, _> = context.tensor_init(output_shape);

            let ops = TensorOp::List(vec![
                TensorOp::blit(&input_f32_dev, &input_f16_dev)?,
                TensorOp::matmul_vec_fp16(
                    &matrix_dev,
                    &input_f32_dev,
                    output_dev.view(.., .., 0..b, ..)?,
                    Activation::None,
                    false,
                )?,
                TensorOp::matmul_mat_fp16(
                    &matrix_dev,
                    &input_f16_dev,
                    output_dev.view(.., .., b.., ..)?,
                    Activation::None,
                )?,
            ]);

            context.queue.submit(context.encode(&ops));

            let output_host = output_dev.back().await;
            let output_host: Vec<f32> = Vec::from(output_host);

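            // CPU reference: the first `b` output batches check matmul_vec_fp16
            // on the f32 input; the next `b` batches check matmul_mat_fp16 on
            // the f16 input.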
            let mut ans = vec![0.0; output_host.len()];
            for ((batch, token), line) in (0..b).cartesian_product(0..t).cartesian_product(0..r) {
                let matrix = &matrix[((batch * r + line) * c)..((batch * r + line) + 1) * c];
                let input = &input_f32[(batch * t + token) * c..((batch * t + token) + 1) * c];
                let product = matrix
                    .iter()
                    .zip(input.iter())
                    .fold(0.0f32, |acc, x| acc + x.0.to_f32() * *x.1);
                ans[(batch * t + token) * r + line] = product;

                let input = &input_f16[(batch * t + token) * c..((batch * t + token) + 1) * c];
                let product = matrix
                    .iter()
                    .zip(input.iter())
                    .fold(0.0f32, |acc, x| acc + x.0.to_f32() * x.1.to_f32());
                ans[((b + batch) * t + token) * r + line] = product;
            }

            for (index, (a, b)) in itertools::zip_eq(output_host, ans).enumerate() {
                assert!(
                    is_approx_eps(a, b, 0.01),
                    "Failed at index {index}, computed: {a} vs. answer: {b}"
                );
            }

            Ok(())
        }

        test_matmul_inner(&context, 2560, 2048, 32, 2).await?;
        test_matmul_inner(&context, 320, 64, 320, 2).await?;

        Ok(())
    }

    #[cfg(feature = "tokio")]
    #[tokio::test]
    async fn test_matmul_int8() -> Result<()> {
        let context = create_context().await?;
        fastrand::seed(42);

        const INT8_BLOCK_SIZE: usize = TensorOp::INT8_BLOCK_SIZE as usize;

        async fn test_matmul_int8_inner(
            context: &Context,
            c: usize,
            r: usize,
            t: usize,
        ) -> Result<()> {
            let matrix = vec![(); c * r]
                .into_iter()
                .map(|_| 10.0 * (fastrand::f32() - 0.5))
                .map(f16::from_f32)
                .collect_vec();
            let input_f32 = vec![(); c * t]
                .into_iter()
                .map(|_| 10.0 * (fastrand::f32() - 0.5))
                .collect_vec();
            let input_f16 = input_f32.iter().copied().map(f16::from_f32).collect_vec();

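            // CPU reference quantization: per INT8_BLOCK_SIZE block, record the
            // [min, max] pair and rescale each value linearly into 0..=255.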
            let (matrix_u8, min, max) = {
                let mut matrix_u8: Vec<u8> = vec![0; matrix.len()];
                let mut min = vec![f16::MAX; matrix.len().div_ceil(INT8_BLOCK_SIZE)];
                let mut max = vec![f16::MIN; matrix.len().div_ceil(INT8_BLOCK_SIZE)];

                for (i, (min, max)) in itertools::zip_eq(&mut min, &mut max).enumerate() {
                    let start = i * INT8_BLOCK_SIZE;
                    let end = start + INT8_BLOCK_SIZE;
                    let chunk = &matrix[start..end];
                    for value in chunk.iter() {
                        *min = min.min(*value);
                        *max = max.max(*value);
                    }
                    for (j, value) in chunk.iter().enumerate() {
                        let value = value.to_f32();
                        let min = min.to_f32();
                        let max = max.to_f32();
                        let value = (value - min) / (max - min);
                        matrix_u8[start + j] = f32::round(value * 255.0) as u8;
                    }
                }

                (matrix_u8, min, max)
            };
            let minmax = itertools::zip_eq(&min, &max)
                .map(|(&min, &max)| [min, max])
                .collect_vec()
                .concat();

            let minmax_shape = Shape::new((c * r).div_ceil(INT8_BLOCK_SIZE) * 2, 1, 1, 1);
            let matrix_shape = Shape::new(c, r, 1, 1);
            let input_shape = Shape::new(c, t, 1, 1);
            let output_shape = Shape::new(r, t, 1, 1);

            let minmax_dev = context.tensor_init(minmax_shape);
            let matrix_f16_dev = context.tensor_from_data(matrix_shape, matrix.clone())?;

            let matrix_u8_dev = context.tensor_init(matrix_shape);
            let input_dev = context.tensor_from_data(input_shape, input_f16.clone())?;
            let output_dev = context.tensor_init(output_shape);

            let ops = TensorOp::List(vec![TensorOp::quantize_mat_int8(
                &matrix_f16_dev,
                &minmax_dev,
                &matrix_u8_dev,
            )?]);
            context.queue.submit(context.encode(&ops));
            let minmax_host = minmax_dev.back().await.to_vec();
            let matrix_u8_host = matrix_u8_dev.back().await.to_vec();

            for (index, (&a, &b)) in itertools::zip_eq(&minmax_host, &minmax).enumerate() {
                assert!(
                    is_approx_eps(a, b, 0.01),
                    "Failed at index {index}, computed: {a} vs. answer: {b}"
                );
            }
            for (index, (&a, &b)) in itertools::zip_eq(&matrix_u8_host, &matrix_u8).enumerate() {
                assert!(
                    a.abs_diff(b) < 2,
                    "Failed at index {index}, computed: {a} vs. answer: {b}"
                );
            }

            let mut ans = vec![0.0; t * r];
            for (token, line) in (0..t).cartesian_product(0..r) {
                let matrix = &matrix_u8_host[line * c..(line + 1) * c];
                let input = &input_f16[token * c..(token + 1) * c];
                let product =
                    matrix
                        .iter()
                        .zip_eq(input.iter())
                        .enumerate()
                        .fold(0.0f32, |acc, (i, x)| {
                            let min = min[(line * c + i) / INT8_BLOCK_SIZE].to_f32();
                            let max = max[(line * c + i) / INT8_BLOCK_SIZE].to_f32();
                            let value = (*x.0 as f32) / 255.0;
                            acc + (value * (max - min) + min) * x.1.to_f32()
                        });
                ans[token * r + line] = product;
            }

            let ops = TensorOp::List(vec![TensorOp::matmul_vec_int8(
                &matrix_u8_dev,
                &minmax_dev,
                &input_dev,
                &output_dev,
                Activation::None,
                false,
            )?]);
            context.queue.submit(context.encode(&ops));
            let output_host: Vec<f32> = output_dev.back().await.to_vec();

            for (index, (&a, &b)) in itertools::zip_eq(&output_host, &ans).enumerate() {
                assert!(
                    is_approx_eps(a, b, 0.01),
                    "Failed at index {index}, computed: {a} vs. answer: {b}"
                );
            }

            let ops = TensorOp::List(vec![TensorOp::matmul_mat_int8(
                &matrix_u8_dev,
                &minmax_dev,
                &input_dev,
                &output_dev,
                Activation::None,
            )?]);
            context.queue.submit(context.encode(&ops));
            let output_host = output_dev.back().await.to_vec();

            for (index, (&a, &b)) in itertools::zip_eq(&output_host, &ans).enumerate() {
                assert!(
                    is_approx_eps(a, b, 0.01),
                    "Failed at index {index}, computed: {a} vs. answer: {b}"
                );
            }

            Ok(())
        }

        test_matmul_int8_inner(&context, 2560, 2048, 64).await?;
        test_matmul_int8_inner(&context, 320, 64, 320).await?;

        Ok(())
    }

    #[cfg(feature = "tokio")]
    #[tokio::test]
    async fn test_matmul_nf4() -> Result<()> {
        let context = create_context().await?;
        fastrand::seed(42);

        const NF4_BLOCK_SIZE: usize = TensorOp::NF4_BLOCK_SIZE as usize;

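        /// A standard-normal sample via the Box-Muller transform.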
        fn normal() -> f32 {
            let u = fastrand::f32();
            let v = fastrand::f32();
            (-2.0 * u.ln()).sqrt() * (2.0 * PI * v).cos()
        }

        async fn test_matmul_nf4_inner(
            context: &Context,
            c: usize,
            r: usize,
            t: usize,
        ) -> Result<()> {
            let matrix = vec![(); c * r]
                .into_iter()
                .map(|_| normal())
                .map(f16::from_f32)
                .collect_vec();
            let input_f32 = vec![(); c * t]
                .into_iter()
                .map(|_| 2.0 * fastrand::f32() - 1.0)
                .collect_vec();
            let input_f16 = input_f32.iter().copied().map(f16::from_f32).collect_vec();

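            // The 16 NF4 code values (quantiles of a standard normal), matching
            // the codebook used by the QLoRA quantization scheme.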
            #[allow(clippy::excessive_precision)]
            let quant: [f32; 16] = [
                -1.0,
                -0.6961928009986877,
                -0.5250730514526367,
                -0.39491748809814453,
                -0.28444138169288635,
                -0.18477343022823334,
                -0.09105003625154495,
                0.0,
                0.07958029955625534,
                0.16093020141124725,
                0.24611230194568634,
                0.33791524171829224,
                0.44070982933044434,
                0.5626170039176941,
                0.7229568362236023,
                1.0,
            ];
            let (matrix_u8, matrix_u4, absmax) = {
                let mut matrix_u8: Vec<u8> = vec![0; matrix.len()];
                let mut matrix_u4: Vec<u8> = vec![0; matrix.len() / 2];
                let mut absmax = vec![f16::ZERO; matrix.len().div_ceil(NF4_BLOCK_SIZE)];

                for (i, absmax) in absmax.iter_mut().enumerate() {
                    let start = i * NF4_BLOCK_SIZE;
                    let end = start + NF4_BLOCK_SIZE;
                    let chunk = &matrix[start..end];
                    *absmax = chunk
                        .iter()
                        .map(|&x| if x >= f16::ZERO { x } else { -x })
                        .reduce(f16::max)
                        .unwrap();
                    for (j, value) in chunk.iter().enumerate() {
                        let value = value.to_f32() / absmax.to_f32();
                        matrix_u8[start + j] = quant
                            .iter()
                            .map(|quant| (value - quant).abs())
                            .enumerate()
                            .fold((0, f32::MAX), |acc, x| if x.1 < acc.1 { x } else { acc })
                            .0 as u8;
                    }
                }

                for (i, x) in matrix_u4.iter_mut().enumerate() {
                    *x = matrix_u8[2 * i] | matrix_u8[2 * i + 1] << 4;
                }

                (matrix_u8, matrix_u4, absmax)
            };

            let quant_shape = Shape::new(quant.len(), 1, 1, 1);
            let absmax_shape = Shape::new((c * r).div_ceil(NF4_BLOCK_SIZE), 1, 1, 1);
            let matrix_f16_shape = Shape::new(c, r, 1, 1);
            let matrix_u4_shape = Shape::new(c / 2, r, 1, 1);
            let input_shape = Shape::new(c, t, 1, 1);
            let output_shape = Shape::new(r, t, 1, 1);

            let quant_dev = context.tensor_from_data(quant_shape, quant.to_vec())?;
            let absmax_dev = context.tensor_init(absmax_shape);
            let matrix_f16_dev = context.tensor_from_data(matrix_f16_shape, matrix.clone())?;

            let matrix_u4_dev = context.tensor_init(matrix_u4_shape);
            let input_dev: TensorGpu<_, _> =
                context.tensor_from_data(input_shape, input_f16.clone())?;
            let output_dev: TensorGpu<_, _> = context.tensor_init(output_shape);

            let ops = TensorOp::List(vec![TensorOp::quantize_mat_nf4(
                &matrix_f16_dev,
                &quant_dev,
                &absmax_dev,
                &matrix_u4_dev,
            )?]);
            context.queue.submit(context.encode(&ops));
            let matrix_u4_host = matrix_u4_dev.back().await.to_vec();
            let absmax_host = absmax_dev.back().await.to_vec();

            for (index, (&a, &b)) in itertools::zip_eq(&absmax_host, &absmax).enumerate() {
                assert!(
                    is_approx_eps(a.to_f32(), b.to_f32(), 0.01),
                    "Failed at index {index}, computed: {a} vs. answer: {b}"
                );
            }

            for (index, (a, b)) in itertools::zip_eq(matrix_u4_host, matrix_u4).enumerate() {
                assert!(
                    a == b,
                    "Failed at index {index}, computed: {a} vs. answer: {b}"
                );
            }

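            // Unquantized f16 ground truth, computed for reference; the
            // assertions below compare against the dequantized `ans` instead.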
            let mut truth = vec![0.0; t * r];
            for (token, line) in (0..t).cartesian_product(0..r) {
                let matrix = &matrix[line * c..(line + 1) * c];
                let input = &input_f16[token * c..(token + 1) * c];
                let product = matrix
                    .iter()
                    .zip(input.iter())
                    .fold(0.0f32, |acc, x| acc + x.0.to_f32() * x.1.to_f32());
                truth[token * r + line] = product;
            }

            let mut ans = vec![0.0; t * r];
            for (token, line) in (0..t).cartesian_product(0..r) {
                let matrix = &matrix_u8[line * c..(line + 1) * c];
                let input = &input_f16[token * c..(token + 1) * c];
                let product =
                    matrix
                        .iter()
                        .zip(input.iter())
                        .enumerate()
                        .fold(0.0f32, |acc, (i, x)| {
                            let amp = absmax[(line * c + i) / NF4_BLOCK_SIZE];
                            acc + quant[*x.0 as usize] * amp.to_f32() * x.1.to_f32()
                        });
                ans[token * r + line] = product;
            }

            let ops = TensorOp::List(vec![TensorOp::matmul_vec_nf4(
                &matrix_u4_dev,
                &quant_dev,
                &absmax_dev,
                &input_dev,
                &output_dev,
                Activation::None,
                false,
            )?]);
            context.queue.submit(context.encode(&ops));
            let output_host: Vec<f32> = output_dev.back().await.to_vec();

            for (index, (&a, &b)) in itertools::zip_eq(&output_host, &ans).enumerate() {
                assert!(
                    is_approx_eps(a, b, 0.01),
                    "Failed at index {index}, computed: {a} vs. answer: {b}"
                );
            }

            let ops = TensorOp::List(vec![TensorOp::matmul_mat_nf4(
                &matrix_u4_dev,
                &quant_dev,
                &absmax_dev,
                &input_dev,
                &output_dev,
                Activation::None,
            )?]);
            context.queue.submit(context.encode(&ops));
            let output_host = output_dev.back().await.to_vec();

            for (index, (&a, &b)) in itertools::zip_eq(&output_host, &ans).enumerate() {
                assert!(
                    is_approx_eps(a, b, 0.01),
                    "Failed at index {index}, computed: {a} vs. answer: {b}"
                );
            }

            Ok(())
        }

        test_matmul_nf4_inner(&context, 2560, 2048, 64).await?;
        test_matmul_nf4_inner(&context, 320, 64, 320).await?;

        Ok(())
    }

    #[cfg(feature = "tokio")]
    #[tokio::test]
    async fn test_lerp() -> Result<()> {
        let context = create_context().await?;
        fastrand::seed(42);

        const C: usize = 1000;
        const T: usize = 3;
        const B: usize = 2;

        let x = [(); C * T * B].map(|_| fastrand::f32() - 0.5).to_vec();
        let y = [(); C * T * B].map(|_| fastrand::f32() - 0.5).to_vec();
        let f = [(); C * T * B].map(|_| fastrand::f32()).to_vec();

        let shape = Shape::new(C, T, B, 1);
        let x_dev = context.tensor_from_data(shape, x.clone())?;
        let y_dev = context.tensor_from_data(shape, y.clone())?;
        let f_dev = context.tensor_from_data(shape, f.clone())?;

        let lerp = TensorOp::lerp(&x_dev, &y_dev, &f_dev, false)?;
        context.queue.submit(context.encode(&lerp));

        let y_host = y_dev.back().await.to_vec();

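        // CPU reference: with `reversed = false`, lerp writes
        // x * (1 - f) + y * f back into y.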
        let mut ans = vec![];
        for chunk in &itertools::multizip((&x, &y, &f)).chunks(C) {
            for (x, y, f) in chunk {
                ans.push(x * (1.0 - f) + y * f);
            }
        }

        for (index, (a, b)) in itertools::zip_eq(y_host, ans).enumerate() {
            assert!(
                is_approx(a, b),
                "Failed at index {index}, computed: {a} vs. answer: {b}"
            );
        }

        Ok(())
    }

    #[cfg(feature = "tokio")]
    #[tokio::test]
    async fn test_blit() -> Result<()> {
        let context = create_context().await?;
        fastrand::seed(42);

        let output = vec![0.0; 24];
        let output: TensorGpu<_, _> = context.tensor_from_data([4, 3, 2, 1], output)?;

        let mut ops = vec![];

        let input = (0..8).map(|x| x as f32).collect_vec();
        let input: TensorGpu<_, _> = context.tensor_from_data([4, 1, 2, 1], input)?;
        ops.push(TensorOp::blit(&input, output.view(.., 1, .., ..)?)?);

        let input = (8..12).map(|x| x as f32).collect_vec();
        let input: TensorGpu<_, _> = context.tensor_from_data([4, 1, 1, 1], input)?;
        ops.push(TensorOp::blit(&input, output.view(.., 2.., 1..2, ..)?)?);

        let ops = TensorOp::List(ops);
        context.queue.submit(context.encode(&ops));

        let output_host = output.back().await;
        let output_host = Vec::from(output_host);

        assert_eq!(
            output_host,
            vec![
                0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0
            ]
        );

        Ok(())
    }

    #[cfg(feature = "tokio")]
    #[tokio::test]
    async fn test_transpose() -> Result<()> {
        let context = create_context().await?;
        fastrand::seed(42);

        let output = vec![0.0; 36];
        let output: TensorGpu<_, _> = context.tensor_from_data([4, 3, 3, 1], output)?;

        let input = (0..24).map(|x| x as f32).collect_vec();
        let input: TensorGpu<_, _> = context.tensor_from_data([4, 3, 2, 1], input)?;

        let ops = TensorOp::transpose(&input, output.view(.., ..2, .., ..)?)?;
        context.queue.submit(context.encode(&ops));

        let output_host = output.back().await;
        let output_host: Vec<f32> = Vec::from(output_host);

        assert_eq!(
            output_host,
            vec![
                0.0, 1.0, 2.0, 3.0, 12.0, 13.0, 14.0, 15.0, 0.0, 0.0, 0.0, 0.0, 4.0, 5.0, 6.0, 7.0,
                16.0, 17.0, 18.0, 19.0, 0.0, 0.0, 0.0, 0.0, 8.0, 9.0, 10.0, 11.0, 20.0, 21.0, 22.0,
                23.0, 0.0, 0.0, 0.0, 0.0
            ]
        );

        Ok(())
    }
}