1use std::{collections::HashMap, marker::PhantomData, sync::Arc};
2
3#[cfg(not(target_arch = "wasm32"))]
4use futures::future::BoxFuture;
5#[cfg(target_arch = "wasm32")]
6use futures::future::LocalBoxFuture;
7use half::f16;
8use itertools::Itertools;
9use serde::{Deserialize, Serialize};
10use web_rwkv_derive::DeserializeSeed;
11use wgpu::CommandBuffer;
12
13use super::{
14 infer::{RnnChunk, RnnInfo, RnnInput, RnnOutput, RnnOutputBatch, RnnRedirect, Token},
15 loader::{Loader, LoaderError, Reader},
16 model::{AsAny, ModelBuilder, ModelInfo, Quant, State as _},
17 Dispatcher, Job, RuntimeError,
18};
19use crate::{
20 context::Context,
21 num::Float,
22 tensor::{
23 cache::ResourceCache,
24 kind::ReadWrite,
25 matrix::Matrix,
26 ops::{Activation, TensorCommand, TensorOp},
27 serialization::Seed,
28 shape::Shape,
29 DeepClone, IntoPackedCursors, TensorCpu, TensorError, TensorGpu, TensorGpuView, TensorInit,
30 TensorShape, TensorStack,
31 },
32};
33
/// An RWKV v4 model: the GPU context it lives on, its metadata, and its weights.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Model {
    /// GPU context that owns all tensors below.
    pub context: Context,
    /// Model metadata (layer count, embedding size, vocab size, ...).
    pub info: ModelInfo,
    /// Every `rescale` layers the residual stream is halved (see `dispatch_layer`),
    /// and output weights are pre-discounted accordingly in `build_v4` —
    /// presumably to keep fp16 activations in range.
    pub rescale: usize,
    /// A `TensorOp::Sep` command separator is inserted every `sep` layers
    /// when dispatching (see `Dispatcher::dispatch`).
    pub sep: usize,
    /// The loaded weight tensors.
    pub tensor: ModelTensor,
}
43
impl Model {
    /// Epsilon for layer norms (used with `TensorOp::layer_norm` in this file).
    pub const LN_EPS: f32 = 1.0e-5;
    /// Epsilon for group norms (not referenced in this file; presumably used
    /// by other model versions).
    pub const GN_EPS: f32 = 64.0e-5;

    /// Default number of layers between residual rescales.
    pub const DEFAULT_RESCALE: usize = 6;
    /// Default number of layers between command separators.
    pub const DEFAULT_SEP: usize = 1024;
}
51
/// All weights of a v4 model, grouped by where they are used in the forward pass.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct ModelTensor {
    /// Token embedding table plus its post-embedding layer norm.
    pub embed: Embed,
    /// Output head (final layer norm plus the vocabulary projection).
    pub head: Head,
    /// Per-layer attention and FFN weights, in layer order.
    pub layers: Vec<Layer>,
}
59
/// Layer-norm parameters, stored in half precision.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct LayerNorm {
    /// Scale (weight) vector.
    pub w: TensorGpu<f16, ReadWrite>,
    /// Shift (bias) vector.
    pub b: TensorGpu<f16, ReadWrite>,
}
66
/// Attention (time-mix) weights of one layer.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Att {
    /// Per-channel decay; loaded via `load_vector_exp_f32` in `build_v4`.
    pub time_decay: TensorGpu<f32, ReadWrite>,
    /// Per-channel bonus for the current token.
    pub time_first: TensorGpu<f32, ReadWrite>,

    // Token-shift interpolation factors for the k/v/r branches.
    pub time_mix_k: TensorGpu<f16, ReadWrite>,
    pub time_mix_v: TensorGpu<f16, ReadWrite>,
    pub time_mix_r: TensorGpu<f16, ReadWrite>,

    // Projection matrices: key, value, receptance, output.
    // `w_o` is loaded with a rescale discount (see `build_v4`).
    pub w_k: Matrix,
    pub w_v: Matrix,
    pub w_r: Matrix,
    pub w_o: Matrix,
}
82
/// Feed-forward (channel-mix) weights of one layer.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Ffn {
    // Token-shift interpolation factors for the k/r branches.
    pub time_mix_k: TensorGpu<f16, ReadWrite>,
    pub time_mix_r: TensorGpu<f16, ReadWrite>,

    // Projections: key (into the hidden dimension, squared-ReLU activated),
    // value (back to the embedding dimension, rescale-discounted), receptance.
    pub w_k: Matrix,
    pub w_v: Matrix,
    pub w_r: Matrix,
}
93
/// All weights of a single transformer-style block.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Layer {
    /// `ln1`: applied before the attention branch.
    pub att_layer_norm: LayerNorm,
    /// `ln2`: applied before the FFN branch.
    pub ffn_layer_norm: LayerNorm,
    pub att: Att,
    pub ffn: Ffn,
}
102
/// Token embedding table.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Embed {
    /// `ln0`: layer norm applied right after embedding lookup.
    pub layer_norm: LayerNorm,
    /// Embedding matrix kept on the CPU: rows are gathered host-side per token
    /// in `RnnJob::load` before being uploaded.
    pub w: TensorCpu<f16>,
}
109
/// Output head: final layer norm plus the vocabulary projection.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Head {
    /// `ln_out`: applied before the head matmul.
    pub layer_norm: LayerNorm,
    /// `head.weight`, stored as an fp16 matrix (see `build_v4`).
    pub w: Matrix,
}
116
/// Recurrent state of a v4 model for a number of independent batches.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct State {
    pub context: Context,
    pub info: ModelInfo,
    /// Shape `[num_emb, 5 * num_layer, num_batch, 1]`: five state rows per
    /// layer (four for attention, one for the FFN — see `att`/`ffn` below).
    pub data: TensorGpu<f32, ReadWrite>,
}
124
125impl State {
126 async fn back(&self, batch: usize) -> Result<TensorCpu<f32>, TensorError> {
127 let context = &self.context;
128
129 let shape = self.data.shape();
130 let tensor: TensorGpu<f32, ReadWrite> = context.tensor_init([shape[0], shape[1], 1, 1]);
131 let mut encoder = context.device.create_command_encoder(&Default::default());
132 encoder.copy_tensor_batch(&self.data, &tensor, batch, 0)?;
133 context.queue.submit(Some(encoder.finish()));
134
135 Ok(tensor.back().await)
136 }
137}
138
impl AsAny for State {
    /// Allows downcasting a type-erased state back to this concrete `State`.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
144
impl super::model::State for State {
    /// Number of independent batches this state holds (axis 2 of `data`).
    #[inline]
    fn num_batch(&self) -> usize {
        self.data.shape()[2]
    }

    /// Shape of a single batch's state: five rows per layer.
    #[inline]
    fn init_shape(&self) -> Shape {
        let info = &self.info;
        [info.num_emb, 5 * info.num_layer, 1, 1].into()
    }

    /// Fresh single-batch state: rows 0..3 and 4 of each layer zeroed; row 3 is
    /// `f32::MIN`, presumably acting as -inf for a running-max/log-space
    /// accumulator in the v4 time-mix kernel — TODO confirm against the shader.
    fn init(&self) -> TensorCpu<f32> {
        let info = &self.info;
        let data = (0..info.num_layer)
            .map(|_| {
                [
                    vec![0.0; info.num_emb],
                    vec![0.0; info.num_emb],
                    vec![0.0; info.num_emb],
                    vec![f32::MIN; info.num_emb],
                    vec![0.0; info.num_emb],
                ]
                .concat()
            })
            .collect_vec()
            .concat();
        TensorCpu::from_data(self.init_shape(), data).unwrap()
    }

    /// View of a layer's four attention state rows (rows `5*layer .. 5*layer+4`).
    fn att(&self, layer: usize) -> Result<TensorGpuView<'_, f32>, TensorError> {
        let start = 5 * layer;
        let end = start + 4;
        self.data.view(.., start..end, .., ..)
    }

    /// View of a layer's single FFN state row (row `5*layer + 4`).
    fn ffn(&self, layer: usize) -> Result<TensorGpuView<'_, f32>, TensorError> {
        let start = 5 * layer + 4;
        self.data.view(.., start..=start, .., ..)
    }

    /// Upload a full single-batch state into slot `batch`.
    fn load(&self, tensor: TensorCpu<f32>, batch: usize) -> Result<(), TensorError> {
        tensor.check_shape([self.info.num_emb, self.info.num_layer * 5, 1, 1])?;
        self.data.load_batch(&tensor, batch)?;
        Ok(())
    }

    // Both cfg variants delegate to the inherent `State::back` above; only the
    // returned future type differs (wasm futures are not `Send`).
    #[cfg(not(target_arch = "wasm32"))]
    fn back(&self, batch: usize) -> BoxFuture<'_, Result<TensorCpu<f32>, TensorError>> {
        Box::pin(self.back(batch))
    }

    #[cfg(target_arch = "wasm32")]
    fn back(&self, batch: usize) -> LocalBoxFuture<'_, Result<TensorCpu<f32>, TensorError>> {
        Box::pin(self.back(batch))
    }

    /// GPU-to-GPU write of a full single-batch state into slot `batch`.
    fn write(&self, tensor: TensorGpu<f32, ReadWrite>, batch: usize) -> Result<(), TensorError> {
        tensor.check_shape([self.info.num_emb, self.info.num_layer * 5, 1, 1])?;
        let op = TensorOp::blit(&tensor, self.data.view(.., .., batch, ..)?)?;
        self.context.queue.submit(self.context.encode(&op));
        Ok(())
    }

    /// GPU-to-GPU read of slot `batch` into a new tensor (stays on the GPU).
    fn read(&self, batch: usize) -> Result<TensorGpu<f32, ReadWrite>, TensorError> {
        let shape = [self.info.num_emb, self.info.num_layer * 5, 1, 1];
        let tensor: TensorGpu<_, _> = self.context.tensor_init(shape);
        let op = TensorOp::blit(self.data.view(.., .., batch, ..)?, &tensor)?;
        self.context.queue.submit(self.context.encode(&op));
        Ok(tensor)
    }

    /// Slice one state row (`layer` indexes rows here, not blocks) out of a
    /// CPU-backed state previously obtained via `back`.
    fn embed(&self, layer: usize, backed: TensorCpu<f32>) -> Result<TensorCpu<f32>, TensorError> {
        backed.slice(.., layer, .., ..)
    }
}
221
222impl DeepClone for State {
223 fn deep_clone(&self) -> Self {
224 let data = self.data.deep_clone();
225 Self {
226 data,
227 ..self.clone()
228 }
229 }
230}
231
/// Per-chunk intermediate buffers, sized for a fixed number of tokens and
/// cached by token count (see `Bundle::checkout_buffer`).
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Runtime<F: Float> {
    /// Packed batch cursors uploaded by `RnnJob::load`.
    pub cursors: TensorGpu<u32, ReadWrite>,
    /// Raw embedded tokens, layer-normed in place during dispatch.
    pub input: TensorGpu<f16, ReadWrite>,

    /// The residual stream.
    pub x: TensorGpu<F, ReadWrite>,
    /// f32 scratch used around the time-mix kernel (which runs in f32).
    pub aux_x: TensorGpu<f32, ReadWrite>,

    // Attention branch: normed input, token-shifted k/v/r inputs,
    // projected k/v/r, and the output projection result.
    pub att_x: TensorGpu<F, ReadWrite>,
    pub att_kx: TensorGpu<F, ReadWrite>,
    pub att_vx: TensorGpu<F, ReadWrite>,
    pub att_rx: TensorGpu<F, ReadWrite>,
    pub att_k: TensorGpu<f32, ReadWrite>,
    pub att_v: TensorGpu<f32, ReadWrite>,
    pub att_r: TensorGpu<f32, ReadWrite>,
    pub att_o: TensorGpu<F, ReadWrite>,

    // FFN branch; `ffn_k` lives in the hidden dimension, everything else in
    // the embedding dimension.
    pub ffn_x: TensorGpu<F, ReadWrite>,
    pub ffn_kx: TensorGpu<F, ReadWrite>,
    pub ffn_rx: TensorGpu<F, ReadWrite>,
    pub ffn_k: TensorGpu<F, ReadWrite>,
    pub ffn_v: TensorGpu<F, ReadWrite>,
    pub ffn_r: TensorGpu<F, ReadWrite>,
}
257
258impl<F: Float> Runtime<F> {
259 pub fn new(context: &Context, info: &ModelInfo, num_token: usize) -> Self {
260 let shape = Shape::new(info.num_emb, num_token, 1, 1);
261 let cursors_shape = Shape::new(num_token, 1, 1, 1);
262 let hidden_shape = Shape::new(info.num_hidden, num_token, 1, 1);
263
264 Self {
265 cursors: context.tensor_init(cursors_shape),
266 input: context.tensor_init(shape),
267 x: context.tensor_init(shape),
268 aux_x: context.tensor_init(shape),
269 att_x: context.tensor_init(shape),
270 att_kx: context.tensor_init(shape),
271 att_vx: context.tensor_init(shape),
272 att_rx: context.tensor_init(shape),
273 att_k: context.tensor_init(shape),
274 att_v: context.tensor_init(shape),
275 att_r: context.tensor_init(shape),
276 att_o: context.tensor_init(shape),
277 ffn_x: context.tensor_init(shape),
278 ffn_kx: context.tensor_init(shape),
279 ffn_rx: context.tensor_init(shape),
280 ffn_k: context.tensor_init(hidden_shape),
281 ffn_v: context.tensor_init(shape),
282 ffn_r: context.tensor_init(shape),
283 }
284 }
285}
286
/// Buffers for the output head, sized for the number of tokens that actually
/// produce logits and cached by that count (see `Bundle::checkout_header`).
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Header<F: Float> {
    /// Gathered residual-stream rows fed into the head.
    pub head_x: TensorGpu<F, ReadWrite>,
    /// Output logits, `[num_vocab_padded, num_header, 1, 1]`.
    pub head_o: TensorGpu<f32, ReadWrite>,
}
293
294impl<F: Float> Header<F> {
295 pub fn new(context: &Context, info: &ModelInfo, num_header: usize) -> Self {
296 let head_shape = Shape::new(info.num_emb, num_header, 1, 1);
297 let output_shape = Shape::new(info.num_vocab_padded(), num_header, 1, 1);
298
299 Self {
300 head_x: context.tensor_init(head_shape),
301 head_o: context.tensor_init(output_shape),
302 }
303 }
304}
305
/// Injection points in the forward pass; a `HookFn` registered for a hook is
/// turned into a `TensorOp` and spliced into the op list at that point.
/// The `usize` payload is the layer index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Hook {
    // Embedding.
    PostEmbedLoaded,
    PostEmbedLayerNorm,
    // Attention (time-mix) branch.
    PreAtt(usize),
    PostAttLayerNorm(usize),
    PreAttTokenShift(usize),
    PostAttTokenShift(usize),
    PreAttLinear(usize),
    PostAttLinear(usize),
    PreAttTimeMix(usize),
    PostAttTimeMix(usize),
    PreAttOut(usize),
    PostAttOut(usize),
    PostAtt(usize),
    // FFN (channel-mix) branch.
    PreFfn(usize),
    PostFfnLayerNorm(usize),
    PreFfnTokenShift(usize),
    PostFfnTokenShift(usize),
    PreFfnLinear(usize),
    PostFfnLinear(usize),
    PostFfnActivate(usize),
    PreFfnChannelMix(usize),
    PostFfnChannelMix(usize),
    PostFfn(usize),
    // Output head.
    PreHead,
    PostHeadLayerNorm,
    PostHead,
}
335
/// A pre-recorded forward pass: command buffers are built by
/// `Dispatcher::dispatch`, inputs are uploaded by `load`, executed by
/// `submit`, and logits retrieved by `back`.
pub struct RnnJob {
    /// Recorded GPU commands; drained on `submit`.
    commands: Vec<CommandBuffer>,
    /// Maps chunk output ranges back to per-batch outputs.
    redirect: RnnRedirect,

    /// CPU-side embedding table used to gather token rows in `load`.
    embed: TensorCpu<f16>,

    // Upload targets (shared with the cached `Runtime`) and the logits buffer
    // (shared with the cached `Header`).
    cursors: TensorGpu<u32, ReadWrite>,
    input: TensorGpu<f16, ReadWrite>,
    output: TensorGpu<f32, ReadWrite>,
}
346
impl Job for RnnJob {
    type Input = RnnInput;
    type Output = RnnOutput;

    /// Gather token embeddings on the CPU and upload them, together with the
    /// packed batch cursors, into the pre-allocated GPU buffers.
    fn load(&self, input: &RnnChunk) -> Result<(), RuntimeError> {
        // Nothing to upload for an empty chunk.
        if input.num_token() == 0 {
            return Ok(());
        }

        // One stacked tensor per batch in the chunk.
        let stack: Vec<TensorCpu<f16>> = input
            .iter()
            .map(|chunk| {
                let num_emb = self.embed.shape()[0];
                let data = self.embed.data();
                let data = chunk
                    .iter()
                    .map(|token| match token {
                        // Plain token id: copy its row out of the embedding table.
                        &Token::Token(token) => {
                            let start = num_emb * token as usize;
                            let end = start + num_emb;
                            let data = data[start..end].to_vec();
                            TensorCpu::from_data_1d(data)
                        }
                        // Caller-provided embedding: use it as-is.
                        Token::Embed(tensor) => tensor.clone(),
                    })
                    .collect_vec();
                match TensorCpu::stack(data, 1) {
                    Ok(tensor) => tensor,
                    // Empty batch: keep a zero-token placeholder so batch
                    // indices stay aligned.
                    Err(_) => TensorCpu::init([num_emb, 0, 1, 1]),
                }
            })
            .collect();
        let stack = TensorStack::try_from(stack)?;

        let cursors = stack.cursors.clone().into_cursors();
        let cursors = TensorCpu::from_data(self.cursors.shape(), cursors)?;
        self.cursors.load(&cursors)?;
        self.input.load(&stack.tensor)?;

        Ok(())
    }

    /// Submit the recorded commands; `commands` is drained so a second call
    /// submits nothing.
    fn submit(&mut self) {
        let commands = std::mem::take(&mut self.commands);
        self.output.context.queue.submit(commands);
    }

    /// Read logits back and split them into per-batch outputs using the
    /// redirect's output ranges.
    async fn back(self) -> Result<Self::Output, RuntimeError> {
        let output = self.output.back().await;
        let batches: Vec<_> = self
            .redirect
            .outputs
            .into_iter()
            .map(|(start, end)| output.slice(.., start..end, .., ..))
            .try_collect()?;
        let batches = batches.into_iter().map(RnnOutputBatch).collect();
        Ok(RnnOutput(batches))
    }
}
406
/// Everything a hook or dispatch helper needs to build ops for one pass:
/// the recurrent state plus the cached per-chunk buffers.
#[derive(Debug, Clone)]
pub struct Frame<F: Float> {
    pub state: State,
    pub buffer: Arc<Runtime<F>>,
    pub header: Arc<Header<F>>,
}
413
/// A user hook: builds a `TensorOp` from the current frame; spliced into the
/// op list at its registered `Hook` point.
pub type HookFn<F> = Box<dyn Fn(Frame<F>) -> Result<TensorOp, TensorError> + Send + Sync>;
/// Registered hooks, keyed by injection point.
pub type HookMap<F> = HashMap<Hook, HookFn<F>>;
416
/// A model plus its mutable runtime pieces: recurrent state, user hooks, and
/// caches of per-chunk buffers keyed by token/header count.
#[derive(Clone)]
pub struct Bundle<F: Float> {
    model: Model,
    state: State,
    hooks: Arc<HookMap<F>>,
    buffers: ResourceCache<usize, Runtime<F>>,
    headers: ResourceCache<usize, Header<F>>,
    phantom: PhantomData<F>,
}
426
impl<F: Float> super::model::Bundle for Bundle<F> {
    #[inline]
    fn info(&self) -> ModelInfo {
        self.model.info.clone()
    }

    /// Cheap clone: `State` shares its GPU buffer (use `DeepClone` to copy).
    #[inline]
    fn state(&self) -> impl super::model::State + AsAny + 'static {
        self.state.clone()
    }

    #[inline]
    fn model(&self) -> impl Serialize + 'static {
        self.model.clone()
    }
}
443
444impl<F: Float> Bundle<F> {
445 pub fn new(model: Model, num_batch: usize) -> Self {
446 let context = model.context.clone();
447 let info = model.info.clone();
448 let state = {
449 let shape = Shape::new(info.num_emb, 5 * info.num_layer, num_batch, 1);
450 let data = (0..info.num_layer * num_batch)
451 .map(|_| {
452 [
453 vec![0.0; info.num_emb],
454 vec![0.0; info.num_emb],
455 vec![0.0; info.num_emb],
456 vec![f32::MIN; info.num_emb],
457 vec![0.0; info.num_emb],
458 ]
459 .concat()
460 })
461 .collect_vec()
462 .concat();
463 let data = context.tensor_from_data(shape, data).unwrap();
464 State {
465 context,
466 info,
467 data,
468 }
469 };
470 Self {
471 model,
472 state,
473 hooks: Default::default(),
474 buffers: ResourceCache::new(4),
475 headers: ResourceCache::new(4),
476 phantom: PhantomData,
477 }
478 }
479
480 pub fn new_with_hooks(model: Model, num_batch: usize, hooks: HookMap<F>) -> Self {
481 Self {
482 hooks: Arc::new(hooks),
483 ..Self::new(model, num_batch)
484 }
485 }
486
487 fn checkout_buffer(
488 &self,
489 context: &Context,
490 info: &ModelInfo,
491 num_token: usize,
492 ) -> Arc<Runtime<F>> {
493 self.buffers
494 .checkout(num_token, || Runtime::new(context, info, num_token))
495 }
496
497 fn checkout_header(
498 &self,
499 context: &Context,
500 info: &ModelInfo,
501 num_header: usize,
502 ) -> Arc<Header<F>> {
503 self.headers
504 .checkout(num_header, || Header::new(context, info, num_header))
505 }
506}
507
/// Use the faster matmul path only when the workload is an exact multiple of
/// the minimum token chunk size (so kernels need no ragged-edge handling).
fn turbo(num_token: usize) -> bool {
    num_token.is_multiple_of(super::infer::rnn::MIN_TOKEN_CHUNK_SIZE)
}
511
512fn hook_op<F: Float>(
513 hooks: &HookMap<F>,
514 hook: &Hook,
515 frame: &Frame<F>,
516) -> Result<TensorOp, TensorError> {
517 match hooks.get(hook) {
518 Some(f) => f(frame.clone()),
519 None => Ok(TensorOp::empty()),
520 }
521}
522
impl<F: Float> Dispatcher<RnnJob> for Bundle<F> {
    type Info = RnnInfo;

    /// Record the full forward pass for one chunk into a `RnnJob`.
    ///
    /// Nothing executes here: ops are assembled in order (embed layer norm,
    /// each layer's attention + FFN, then the head), encoded once, and stored
    /// as command buffers that run on `RnnJob::submit`.
    fn dispatch(&self, seed: Self::Info) -> Result<RnnJob, RuntimeError> {
        let model = &self.model;
        let state = &self.state;
        let context = &model.context;
        let info = &model.info;
        let tensor = &model.tensor;

        let num_token = seed.num_token();

        // Redirect decides which token positions produce logits.
        let redirect = seed.redirect();
        let num_header = redirect.headers.len();

        // Buffers are cached by size, so repeated chunk sizes reuse memory.
        let buffer = self.checkout_buffer(context, info, num_token);
        let header = self.checkout_header(context, info, num_header);
        let frame = Frame {
            state: state.clone(),
            buffer: buffer.clone(),
            header: header.clone(),
        };

        context.maintain();

        // Empty chunk: a job with no commands, so submit/back are no-ops.
        if num_token == 0 {
            return Ok(RnnJob {
                commands: vec![],
                redirect,
                embed: model.tensor.embed.w.clone(),
                cursors: buffer.cursors.clone(),
                input: buffer.input.clone(),
                output: header.head_o.clone(),
            });
        }

        #[cfg(feature = "trace")]
        let _span = tracing::trace_span!("build").entered();

        // Gather op copying selected rows of `x` into the head input.
        let (head_op, head_x) = redirect.op(&buffer.x, &header.head_x)?;

        let hook_op = |hook: Hook| hook_op(&self.hooks, &hook, &frame);
        let mut ops = vec![];

        {
            #[cfg(feature = "trace")]
            let _span = tracing::trace_span!("embed").entered();

            // Normalize the uploaded embeddings in place, then seed the
            // residual stream `x` from them.
            ops.extend([
                hook_op(Hook::PostEmbedLoaded)?,
                TensorOp::layer_norm(
                    &tensor.embed.layer_norm.w,
                    &tensor.embed.layer_norm.b,
                    &buffer.input,
                    Model::LN_EPS,
                )?,
                TensorOp::blit(&buffer.input, &buffer.x)?,
                hook_op(Hook::PostEmbedLayerNorm)?,
            ]);
        };

        for (index, layer) in tensor.layers.iter().enumerate() {
            #[cfg(feature = "trace")]
            let _span = tracing::trace_span!("layer", index).entered();

            let hooks = self.hooks.clone();
            let frame = frame.clone();
            let layer = layer.clone();

            let op = dispatch_layer(hooks, frame, layer, index, num_token, model.rescale)?;
            ops.push(op);

            // Periodic command separator every `sep` layers.
            if (index + 1) % model.sep == 0 {
                ops.push(TensorOp::Sep);
            }
        }

        {
            #[cfg(feature = "trace")]
            let _span = tracing::trace_span!("header").entered();

            let hooks = self.hooks.clone();
            let frame = frame.clone();
            let head = model.tensor.head.clone();

            let op = dispatch_header(hooks, frame, head, head_x, num_header, head_op)?;
            ops.push(op);
        }

        let commands = {
            #[cfg(feature = "trace")]
            let _span = tracing::trace_span!("encode").entered();
            context.encode(&TensorOp::List(ops))
        };

        Ok(RnnJob {
            commands,
            redirect,
            embed: model.tensor.embed.w.clone(),
            cursors: buffer.cursors.clone(),
            input: buffer.input.clone(),
            output: header.head_o.clone(),
        })
    }
}
628
/// Build the op list for one v4 layer: attention (time-mix) branch, then FFN
/// (channel-mix) branch, each added back into the residual stream `x`.
/// Op order is significant — each op reads buffers written by earlier ones.
#[allow(clippy::too_many_arguments)]
fn dispatch_layer<F: Float>(
    hooks: Arc<HookMap<F>>,
    frame: Frame<F>,
    layer: Layer,
    index: usize,
    num_token: usize,
    rescale: usize,
) -> Result<TensorOp, TensorError> {
    let hook_op = |hook: Hook| hook_op(&hooks, &hook, &frame);
    let Frame { state, buffer, .. } = &frame;

    let mut ops = vec![];

    // --- Attention branch ---
    ops.extend([
        // Work on a copy of the residual stream; `x` stays intact for the add.
        TensorOp::blit(&buffer.x, &buffer.att_x)?,
        hook_op(Hook::PreAtt(index))?,
        TensorOp::layer_norm(
            &layer.att_layer_norm.w,
            &layer.att_layer_norm.b,
            &buffer.att_x,
            Model::LN_EPS,
        )?,
        hook_op(Hook::PostAttLayerNorm(index))?,
        hook_op(Hook::PreAttTokenShift(index))?,
        // Token shift: mix each token with the previous one (from state at
        // sequence starts) per branch, using the learned mix factors.
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.att.time_mix_k,
            state.att(index)?,
            &buffer.att_x,
            &buffer.att_kx,
            false,
        )?,
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.att.time_mix_v,
            state.att(index)?,
            &buffer.att_x,
            &buffer.att_vx,
            false,
        )?,
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.att.time_mix_r,
            state.att(index)?,
            &buffer.att_x,
            &buffer.att_rx,
            false,
        )?,
        hook_op(Hook::PostAttTokenShift(index))?,
        hook_op(Hook::PreAttLinear(index))?,
        // k/v/r projections.
        layer.att.w_k.matmul_op(
            &buffer.att_kx,
            &buffer.att_k,
            Activation::None,
            turbo(num_token),
        )?,
        layer.att.w_v.matmul_op(
            &buffer.att_vx,
            &buffer.att_v,
            Activation::None,
            turbo(num_token),
        )?,
        layer.att.w_r.matmul_op(
            &buffer.att_rx,
            &buffer.att_r,
            Activation::None,
            turbo(num_token),
        )?,
        hook_op(Hook::PostAttLinear(index))?,
        hook_op(Hook::PreAttTimeMix(index))?,
        // The v4 time-mix kernel runs in f32: round-trip through `aux_x`.
        TensorOp::blit(&buffer.att_x, &buffer.aux_x)?,
        TensorOp::time_mix_v4(
            &buffer.cursors,
            &layer.att.time_decay,
            &layer.att.time_first,
            state.att(index)?,
            &buffer.att_k,
            &buffer.att_v,
            &buffer.att_r,
            &buffer.aux_x,
        )?,
        TensorOp::blit(&buffer.aux_x, &buffer.att_x)?,
        hook_op(Hook::PostAttTimeMix(index))?,
        hook_op(Hook::PreAttOut(index))?,
        layer.att.w_o.matmul_op(
            &buffer.att_x,
            &buffer.att_o,
            Activation::None,
            turbo(num_token),
        )?,
        hook_op(Hook::PostAttOut(index))?,
        // Residual connection.
        TensorOp::add(&buffer.att_o, &buffer.x)?,
        hook_op(Hook::PostAtt(index))?,
    ]);

    // --- FFN branch ---
    ops.extend([
        TensorOp::blit(&buffer.x, &buffer.ffn_x)?,
        hook_op(Hook::PreFfn(index))?,
        TensorOp::layer_norm(
            &layer.ffn_layer_norm.w,
            &layer.ffn_layer_norm.b,
            &buffer.ffn_x,
            Model::LN_EPS,
        )?,
        hook_op(Hook::PostFfnLayerNorm(index))?,
        hook_op(Hook::PreFfnTokenShift(index))?,
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.ffn.time_mix_k,
            state.ffn(index)?,
            &buffer.ffn_x,
            &buffer.ffn_kx,
            false,
        )?,
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.ffn.time_mix_r,
            state.ffn(index)?,
            &buffer.ffn_x,
            &buffer.ffn_rx,
            false,
        )?,
        hook_op(Hook::PostFfnTokenShift(index))?,
        hook_op(Hook::PreFfnLinear(index))?,
        // Key projection with fused squared-ReLU into the hidden dimension.
        layer.ffn.w_k.matmul_op(
            &buffer.ffn_kx,
            &buffer.ffn_k,
            Activation::SquaredRelu,
            turbo(num_token),
        )?,
        hook_op(Hook::PostFfnActivate(index))?,
        // Value projection exploits sparsity of the squared-ReLU output.
        layer.ffn.w_v.matmul_op_sparse(
            &buffer.ffn_k,
            &buffer.ffn_v,
            Activation::None,
            turbo(num_token),
        )?,
        layer.ffn.w_r.matmul_op(
            &buffer.ffn_rx,
            &buffer.ffn_r,
            Activation::None,
            turbo(num_token),
        )?,
        hook_op(Hook::PostFfnLinear(index))?,
        hook_op(Hook::PreFfnChannelMix(index))?,
        TensorOp::channel_mix(
            &buffer.cursors,
            state.ffn(index)?,
            &buffer.ffn_r,
            &buffer.ffn_v,
            &buffer.ffn_x,
        )?,
        hook_op(Hook::PostFfnChannelMix(index))?,
        // Residual connection.
        TensorOp::add(&buffer.ffn_x, &buffer.x)?,
        hook_op(Hook::PostFfn(index))?,
    ]);

    // Halve the residual stream every `rescale` layers; the matching weight
    // discount is applied at load time in `build_v4`.
    if (index + 1).is_multiple_of(rescale) {
        ops.push(TensorOp::affine(&buffer.x, 0.5, 0.0)?);
    }

    Ok(TensorOp::List(ops))
}
793
794fn dispatch_header<F: Float>(
795 hooks: Arc<HookMap<F>>,
796 frame: Frame<F>,
797 head: Head,
798 head_x: TensorGpu<F, ReadWrite>,
799 num_header: usize,
800 head_op: TensorOp,
801) -> Result<TensorOp, TensorError> {
802 let hook_op = |hook: Hook| hook_op(&hooks, &hook, &frame);
803 let header = &frame.header;
804 let mut ops = vec![head_op];
805
806 if num_header > 0 {
807 ops.extend([
808 hook_op(Hook::PreHead)?,
809 TensorOp::layer_norm(
810 &head.layer_norm.w,
811 &head.layer_norm.b,
812 &head_x,
813 Model::LN_EPS,
814 )?,
815 hook_op(Hook::PostHeadLayerNorm)?,
816 head.w.matmul_op(
817 head_x.view(.., .., .., ..)?,
818 header.head_o.view(.., .., .., ..)?,
819 Activation::None,
820 turbo(num_header),
821 )?,
822 hook_op(Hook::PostHead)?,
823 ]);
824 }
825 Ok(TensorOp::List(ops))
826}
827
828impl<R: Reader> ModelBuilder<R> {
829 pub async fn build_v4(self) -> Result<Model, LoaderError> {
830 let ModelBuilder {
831 context,
832 model,
833 rescale,
834 sep,
835 lora,
836 quant,
837 ..
838 } = self;
839
840 let info = Loader::info(&model)?;
841 let loader = Loader {
842 context: context.clone(),
843 model,
844 lora,
845 };
846
847 let rescale = rescale.unwrap_or(Model::DEFAULT_RESCALE);
848 let sep = sep.unwrap_or(Model::DEFAULT_SEP);
849
850 let embed = Embed {
851 layer_norm: LayerNorm {
852 w: loader.load_vector_f16("blocks.0.ln0.weight")?,
853 b: loader.load_vector_f16("blocks.0.ln0.bias")?,
854 },
855 w: loader.load_matrix_f16_padded_cpu("emb.weight")?,
856 };
857
858 let head = Head {
859 layer_norm: LayerNorm {
860 w: loader.load_vector_f16("ln_out.weight")?,
861 b: loader.load_vector_f16("ln_out.bias")?,
862 },
863 w: Matrix::Fp16(loader.load_matrix_f16("head.weight")?),
864 };
865
866 let submission_index = Some(context.queue.submit(None));
867 _ = context.device.poll(wgpu::PollType::Wait {
868 submission_index,
869 timeout: None,
870 });
871
872 let load_matrix = |name: String, quant: Quant| loader.load_matrix(name, quant);
873 let load_matrix_discount = |name: String, quant: Quant, discount: f32| {
874 loader.load_matrix_discount(name, quant, discount)
875 };
876
877 let mut layers = vec![];
878 for layer in 0..info.num_layer {
879 let quant = quant.get(&layer).copied().unwrap_or_default();
880 let discount = 2.0_f32.powi(-((layer / rescale) as i32));
881
882 let att_layer_norm = LayerNorm {
883 w: loader.load_vector_f16(format!("blocks.{layer}.ln1.weight"))?,
884 b: loader.load_vector_f16(format!("blocks.{layer}.ln1.bias"))?,
885 };
886
887 let att = format!("blocks.{layer}.att");
888 let time_decay = loader.load_vector_exp_f32(format!("{att}.time_decay"))?;
889 let time_first = loader.load_vector_f32(format!("{att}.time_first"))?;
890 let time_mix_k = loader.load_vector_f16(format!("{att}.time_mix_k"))?;
891 let time_mix_v = loader.load_vector_f16(format!("{att}.time_mix_v"))?;
892 let time_mix_r = loader.load_vector_f16(format!("{att}.time_mix_r"))?;
893
894 let att = Att {
895 time_decay,
896 time_first,
897 time_mix_k,
898 time_mix_v,
899 time_mix_r,
900 w_k: load_matrix(format!("{att}.key.weight"), quant)?,
901 w_v: load_matrix(format!("{att}.value.weight"), quant)?,
902 w_r: load_matrix(format!("{att}.receptance.weight"), quant)?,
903 w_o: load_matrix_discount(format!("{att}.output.weight"), quant, discount)?,
904 };
905
906 let ffn_layer_norm = LayerNorm {
907 w: loader.load_vector_f16(format!("blocks.{layer}.ln2.weight"))?,
908 b: loader.load_vector_f16(format!("blocks.{layer}.ln2.bias"))?,
909 };
910
911 let ffn = format!("blocks.{layer}.ffn");
912 let time_mix_k = loader.load_vector_f16(format!("{ffn}.time_mix_k"))?;
913 let time_mix_r = loader.load_vector_f16(format!("{ffn}.time_mix_r"))?;
914
915 let ffn = Ffn {
916 time_mix_k,
917 time_mix_r,
918 w_r: load_matrix(format!("{ffn}.receptance.weight"), quant)?,
919 w_k: load_matrix(format!("{ffn}.key.weight"), quant)?,
920 w_v: load_matrix_discount(format!("{ffn}.value.weight"), quant, discount)?,
921 };
922
923 let submission_index = Some(context.queue.submit(None));
924 _ = context.device.poll(wgpu::PollType::Wait {
925 submission_index,
926 timeout: None,
927 });
928
929 layers.push(Layer {
930 att_layer_norm,
931 ffn_layer_norm,
932 att,
933 ffn,
934 })
935 }
936
937 let submission_index = Some(context.queue.submit(None));
938 _ = context.device.poll(wgpu::PollType::Wait {
939 submission_index,
940 timeout: None,
941 });
942
943 let tensor = ModelTensor {
944 embed,
945 head,
946 layers,
947 };
948 let model = {
949 let context = context.clone();
950 let info = info.clone();
951 Model {
952 context,
953 info,
954 rescale,
955 sep,
956 tensor,
957 }
958 };
959 Ok(model)
960 }
961}