//! `web_rwkv/runtime/v4.rs` — RWKV v4 model runtime.

1use std::{collections::HashMap, marker::PhantomData, sync::Arc};
2
3#[cfg(not(target_arch = "wasm32"))]
4use futures::future::BoxFuture;
5#[cfg(target_arch = "wasm32")]
6use futures::future::LocalBoxFuture;
7use half::f16;
8use itertools::Itertools;
9use serde::{Deserialize, Serialize};
10use web_rwkv_derive::DeserializeSeed;
11use wgpu::CommandBuffer;
12
13use super::{
14    infer::{RnnChunk, RnnInfo, RnnInput, RnnOutput, RnnOutputBatch, RnnRedirect, Token},
15    loader::{Loader, LoaderError, Reader},
16    model::{AsAny, ModelBuilder, ModelInfo, Quant, State as _},
17    Dispatcher, Job, RuntimeError,
18};
19use crate::{
20    context::Context,
21    num::Float,
22    tensor::{
23        cache::ResourceCache,
24        kind::ReadWrite,
25        matrix::Matrix,
26        ops::{Activation, TensorCommand, TensorOp},
27        serialization::Seed,
28        shape::Shape,
29        DeepClone, IntoPackedCursors, TensorCpu, TensorError, TensorGpu, TensorGpuView, TensorInit,
30        TensorShape, TensorStack,
31    },
32};
33
/// A fully loaded RWKV v4 model: GPU context, metadata and all weight tensors.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Model {
    /// GPU context all device tensors live on.
    pub context: Context,
    /// Model metadata (layer count, embedding size, vocab size, ...).
    pub info: ModelInfo,
    /// Halve the residual stream every `rescale` layers at runtime
    /// (f16 overflow guard; the loader pre-scales weights to match).
    pub rescale: usize,
    /// Emit a command separator (`TensorOp::Sep`) every `sep` layers.
    pub sep: usize,
    /// The actual weight tensors.
    pub tensor: ModelTensor,
}
43
impl Model {
    /// Epsilon used by all layer-norm ops in this module.
    pub const LN_EPS: f32 = 1.0e-5;
    /// Group-norm epsilon — not referenced in this file; presumably used by
    /// other model versions sharing these constants.
    pub const GN_EPS: f32 = 64.0e-5;

    /// Default layer interval for the residual rescale.
    pub const DEFAULT_RESCALE: usize = 6;
    /// Default layer interval for command separators.
    pub const DEFAULT_SEP: usize = 1024;
}
51
/// All weights of a v4 model, grouped by role.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct ModelTensor {
    /// Token embedding table plus its pre-layer-norm.
    pub embed: Embed,
    /// Output head (final layer-norm plus projection to the vocabulary).
    pub head: Head,
    /// Per-layer attention / FFN weights, in layer order.
    pub layers: Vec<Layer>,
}
59
/// Layer-norm parameters: per-channel weight (scale) and bias.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct LayerNorm {
    pub w: TensorGpu<f16, ReadWrite>,
    pub b: TensorGpu<f16, ReadWrite>,
}
66
/// Attention (time-mix) weights for one layer.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Att {
    /// Per-channel decay for the WKV kernel (loaded via `load_vector_exp_f32`).
    pub time_decay: TensorGpu<f32, ReadWrite>,
    /// Per-channel term fed to the WKV kernel alongside `time_decay`.
    pub time_first: TensorGpu<f32, ReadWrite>,

    // Token-shift interpolation factors for the K/V/R projection inputs.
    pub time_mix_k: TensorGpu<f16, ReadWrite>,
    pub time_mix_v: TensorGpu<f16, ReadWrite>,
    pub time_mix_r: TensorGpu<f16, ReadWrite>,

    // Projection matrices: key, value, receptance and output.
    pub w_k: Matrix,
    pub w_v: Matrix,
    pub w_r: Matrix,
    pub w_o: Matrix,
}
82
/// Feed-forward (channel-mix) weights for one layer.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Ffn {
    // Token-shift interpolation factors for the K/R projection inputs.
    pub time_mix_k: TensorGpu<f16, ReadWrite>,
    pub time_mix_r: TensorGpu<f16, ReadWrite>,

    /// Key projection; its output has the wider hidden dimension
    /// (`info.num_hidden`, see `Runtime::ffn_k`).
    pub w_k: Matrix,
    /// Value projection applied sparsely to the squared-ReLU activations.
    pub w_v: Matrix,
    /// Receptance projection.
    pub w_r: Matrix,
}
93
/// All weights for a single transformer layer.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Layer {
    /// Layer norm applied before the attention block (`ln1`).
    pub att_layer_norm: LayerNorm,
    /// Layer norm applied before the FFN block (`ln2`).
    pub ffn_layer_norm: LayerNorm,
    pub att: Att,
    pub ffn: Ffn,
}
102
/// Token embedding: pre-layer-norm (`ln0`) plus the embedding table.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Embed {
    pub layer_norm: LayerNorm,
    /// Embedding matrix kept on the CPU: token lookup happens host-side in
    /// `RnnJob::load`, so it is never uploaded wholesale.
    pub w: TensorCpu<f16>,
}
109
/// Output head: final layer-norm (`ln_out`) plus the vocabulary projection.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Head {
    pub layer_norm: LayerNorm,
    pub w: Matrix,
}
116
/// Recurrent state for all batches.
///
/// `data` has shape `[num_emb, 5 * num_layer, num_batch, 1]`: each layer owns
/// 5 rows — 4 for attention and 1 for the FFN (see `att`/`ffn` below).
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct State {
    pub context: Context,
    pub info: ModelInfo,
    pub data: TensorGpu<f32, ReadWrite>,
}
124
impl State {
    /// Copy one batch of the state back to the CPU.
    ///
    /// Copies batch `batch` of `data` into a fresh single-batch GPU tensor,
    /// submits the copy to the queue, then maps that tensor back to host
    /// memory. Errors if the batch copy is out of range.
    async fn back(&self, batch: usize) -> Result<TensorCpu<f32>, TensorError> {
        let context = &self.context;

        let shape = self.data.shape();
        let tensor: TensorGpu<f32, ReadWrite> = context.tensor_init([shape[0], shape[1], 1, 1]);
        let mut encoder = context.device.create_command_encoder(&Default::default());
        encoder.copy_tensor_batch(&self.data, &tensor, batch, 0)?;
        context.queue.submit(Some(encoder.finish()));

        Ok(tensor.back().await)
    }
}
138
impl AsAny for State {
    // Enables downcasting a type-erased `dyn State` back to this concrete type.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
144
145impl super::model::State for State {
146    #[inline]
147    fn num_batch(&self) -> usize {
148        self.data.shape()[2]
149    }
150
151    #[inline]
152    fn init_shape(&self) -> Shape {
153        let info = &self.info;
154        [info.num_emb, 5 * info.num_layer, 1, 1].into()
155    }
156
157    fn init(&self) -> TensorCpu<f32> {
158        let info = &self.info;
159        let data = (0..info.num_layer)
160            .map(|_| {
161                [
162                    vec![0.0; info.num_emb],
163                    vec![0.0; info.num_emb],
164                    vec![0.0; info.num_emb],
165                    vec![f32::MIN; info.num_emb],
166                    vec![0.0; info.num_emb],
167                ]
168                .concat()
169            })
170            .collect_vec()
171            .concat();
172        TensorCpu::from_data(self.init_shape(), data).unwrap()
173    }
174
175    fn att(&self, layer: usize) -> Result<TensorGpuView<'_, f32>, TensorError> {
176        let start = 5 * layer;
177        let end = start + 4;
178        self.data.view(.., start..end, .., ..)
179    }
180
181    fn ffn(&self, layer: usize) -> Result<TensorGpuView<'_, f32>, TensorError> {
182        let start = 5 * layer + 4;
183        self.data.view(.., start..=start, .., ..)
184    }
185
186    fn load(&self, tensor: TensorCpu<f32>, batch: usize) -> Result<(), TensorError> {
187        tensor.check_shape([self.info.num_emb, self.info.num_layer * 5, 1, 1])?;
188        self.data.load_batch(&tensor, batch)?;
189        Ok(())
190    }
191
192    #[cfg(not(target_arch = "wasm32"))]
193    fn back(&self, batch: usize) -> BoxFuture<'_, Result<TensorCpu<f32>, TensorError>> {
194        Box::pin(self.back(batch))
195    }
196
197    #[cfg(target_arch = "wasm32")]
198    fn back(&self, batch: usize) -> LocalBoxFuture<'_, Result<TensorCpu<f32>, TensorError>> {
199        Box::pin(self.back(batch))
200    }
201
202    fn write(&self, tensor: TensorGpu<f32, ReadWrite>, batch: usize) -> Result<(), TensorError> {
203        tensor.check_shape([self.info.num_emb, self.info.num_layer * 5, 1, 1])?;
204        let op = TensorOp::blit(&tensor, self.data.view(.., .., batch, ..)?)?;
205        self.context.queue.submit(self.context.encode(&op));
206        Ok(())
207    }
208
209    fn read(&self, batch: usize) -> Result<TensorGpu<f32, ReadWrite>, TensorError> {
210        let shape = [self.info.num_emb, self.info.num_layer * 5, 1, 1];
211        let tensor: TensorGpu<_, _> = self.context.tensor_init(shape);
212        let op = TensorOp::blit(self.data.view(.., .., batch, ..)?, &tensor)?;
213        self.context.queue.submit(self.context.encode(&op));
214        Ok(tensor)
215    }
216
217    fn embed(&self, layer: usize, backed: TensorCpu<f32>) -> Result<TensorCpu<f32>, TensorError> {
218        backed.slice(.., layer, .., ..)
219    }
220}
221
222impl DeepClone for State {
223    fn deep_clone(&self) -> Self {
224        let data = self.data.deep_clone();
225        Self {
226            data,
227            ..self.clone()
228        }
229    }
230}
231
/// Per-token intermediate buffers for one forward pass, sized by token count.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Runtime<F: Float> {
    /// Packed batch cursors, one `u32` per token.
    pub cursors: TensorGpu<u32, ReadWrite>,
    /// Embedded input tokens (f16), layer-normed in place before the first layer.
    pub input: TensorGpu<f16, ReadWrite>,

    /// The residual stream.
    pub x: TensorGpu<F, ReadWrite>,
    /// f32 scratch used to round-trip `att_x` through the WKV kernel.
    pub aux_x: TensorGpu<f32, ReadWrite>,

    // Attention working buffers: normed input, token-shifted K/V/R inputs,
    // projected K/V/R, and the attention output.
    pub att_x: TensorGpu<F, ReadWrite>,
    pub att_kx: TensorGpu<F, ReadWrite>,
    pub att_vx: TensorGpu<F, ReadWrite>,
    pub att_rx: TensorGpu<F, ReadWrite>,
    pub att_k: TensorGpu<f32, ReadWrite>,
    pub att_v: TensorGpu<f32, ReadWrite>,
    pub att_r: TensorGpu<f32, ReadWrite>,
    pub att_o: TensorGpu<F, ReadWrite>,

    // FFN working buffers; `ffn_k` has the wider hidden dimension.
    pub ffn_x: TensorGpu<F, ReadWrite>,
    pub ffn_kx: TensorGpu<F, ReadWrite>,
    pub ffn_rx: TensorGpu<F, ReadWrite>,
    pub ffn_k: TensorGpu<F, ReadWrite>,
    pub ffn_v: TensorGpu<F, ReadWrite>,
    pub ffn_r: TensorGpu<F, ReadWrite>,
}
257
258impl<F: Float> Runtime<F> {
259    pub fn new(context: &Context, info: &ModelInfo, num_token: usize) -> Self {
260        let shape = Shape::new(info.num_emb, num_token, 1, 1);
261        let cursors_shape = Shape::new(num_token, 1, 1, 1);
262        let hidden_shape = Shape::new(info.num_hidden, num_token, 1, 1);
263
264        Self {
265            cursors: context.tensor_init(cursors_shape),
266            input: context.tensor_init(shape),
267            x: context.tensor_init(shape),
268            aux_x: context.tensor_init(shape),
269            att_x: context.tensor_init(shape),
270            att_kx: context.tensor_init(shape),
271            att_vx: context.tensor_init(shape),
272            att_rx: context.tensor_init(shape),
273            att_k: context.tensor_init(shape),
274            att_v: context.tensor_init(shape),
275            att_r: context.tensor_init(shape),
276            att_o: context.tensor_init(shape),
277            ffn_x: context.tensor_init(shape),
278            ffn_kx: context.tensor_init(shape),
279            ffn_rx: context.tensor_init(shape),
280            ffn_k: context.tensor_init(hidden_shape),
281            ffn_v: context.tensor_init(shape),
282            ffn_r: context.tensor_init(shape),
283        }
284    }
285}
286
/// Output-head buffers, sized by the number of header (logit-producing) tokens.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Header<F: Float> {
    /// Header tokens gathered out of the residual stream.
    pub head_x: TensorGpu<F, ReadWrite>,
    /// Final logits, one padded-vocab-sized column per header token.
    pub head_o: TensorGpu<f32, ReadWrite>,
}
293
294impl<F: Float> Header<F> {
295    pub fn new(context: &Context, info: &ModelInfo, num_header: usize) -> Self {
296        let head_shape = Shape::new(info.num_emb, num_header, 1, 1);
297        let output_shape = Shape::new(info.num_vocab_padded(), num_header, 1, 1);
298
299        Self {
300            head_x: context.tensor_init(head_shape),
301            head_o: context.tensor_init(output_shape),
302        }
303    }
304}
305
/// Injection points for custom [`TensorOp`]s during a v4 forward pass.
///
/// `Pre*` hooks run just before the named stage and `Post*` just after; the
/// `usize` payload is the layer index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Hook {
    // Embedding stage.
    PostEmbedLoaded,
    PostEmbedLayerNorm,
    // Attention (time-mix) stages, per layer.
    PreAtt(usize),
    PostAttLayerNorm(usize),
    PreAttTokenShift(usize),
    PostAttTokenShift(usize),
    PreAttLinear(usize),
    PostAttLinear(usize),
    PreAttTimeMix(usize),
    PostAttTimeMix(usize),
    PreAttOut(usize),
    PostAttOut(usize),
    PostAtt(usize),
    // FFN (channel-mix) stages, per layer.
    PreFfn(usize),
    PostFfnLayerNorm(usize),
    PreFfnTokenShift(usize),
    PostFfnTokenShift(usize),
    PreFfnLinear(usize),
    PostFfnLinear(usize),
    PostFfnActivate(usize),
    PreFfnChannelMix(usize),
    PostFfnChannelMix(usize),
    PostFfn(usize),
    // Output head stage.
    PreHead,
    PostHeadLayerNorm,
    PostHead,
}
335
/// One encoded inference job: `load` uploads tokens, `submit` runs the
/// pre-encoded command buffers, and `back` reads the logits out.
pub struct RnnJob {
    /// Pre-encoded GPU commands for the whole forward pass.
    commands: Vec<CommandBuffer>,
    /// Maps packed header rows back to per-batch output slices.
    redirect: RnnRedirect,

    /// CPU-side embedding matrix used to turn token ids into vectors.
    embed: TensorCpu<f16>,

    // Shared handles into the `Runtime`/`Header` buffers this job writes/reads.
    cursors: TensorGpu<u32, ReadWrite>,
    input: TensorGpu<f16, ReadWrite>,
    output: TensorGpu<f32, ReadWrite>,
}
346
347impl Job for RnnJob {
348    type Input = RnnInput;
349    type Output = RnnOutput;
350
351    fn load(&self, input: &RnnChunk) -> Result<(), RuntimeError> {
352        if input.num_token() == 0 {
353            return Ok(());
354        }
355
356        let stack: Vec<TensorCpu<f16>> = input
357            .iter()
358            .map(|chunk| {
359                let num_emb = self.embed.shape()[0];
360                let data = self.embed.data();
361                let data = chunk
362                    .iter()
363                    .map(|token| match token {
364                        &Token::Token(token) => {
365                            let start = num_emb * token as usize;
366                            let end = start + num_emb;
367                            let data = data[start..end].to_vec();
368                            TensorCpu::from_data_1d(data)
369                        }
370                        Token::Embed(tensor) => tensor.clone(),
371                    })
372                    .collect_vec();
373                match TensorCpu::stack(data, 1) {
374                    Ok(tensor) => tensor,
375                    Err(_) => TensorCpu::init([num_emb, 0, 1, 1]),
376                }
377            })
378            .collect();
379        let stack = TensorStack::try_from(stack)?;
380
381        let cursors = stack.cursors.clone().into_cursors();
382        let cursors = TensorCpu::from_data(self.cursors.shape(), cursors)?;
383        self.cursors.load(&cursors)?;
384        self.input.load(&stack.tensor)?;
385
386        Ok(())
387    }
388
389    fn submit(&mut self) {
390        let commands = std::mem::take(&mut self.commands);
391        self.output.context.queue.submit(commands);
392    }
393
394    async fn back(self) -> Result<Self::Output, RuntimeError> {
395        let output = self.output.back().await;
396        let batches: Vec<_> = self
397            .redirect
398            .outputs
399            .into_iter()
400            .map(|(start, end)| output.slice(.., start..end, .., ..))
401            .try_collect()?;
402        let batches = batches.into_iter().map(RnnOutputBatch).collect();
403        Ok(RnnOutput(batches))
404    }
405}
406
/// Everything a hook or dispatch function needs for one pass: the recurrent
/// state plus the shared per-token and per-header buffers.
#[derive(Debug, Clone)]
pub struct Frame<F: Float> {
    pub state: State,
    pub buffer: Arc<Runtime<F>>,
    pub header: Arc<Header<F>>,
}
413
/// A hook callback: builds an extra [`TensorOp`] from the current [`Frame`].
pub type HookFn<F> = Box<dyn Fn(Frame<F>) -> Result<TensorOp, TensorError> + Send + Sync>;
/// Registered hooks, keyed by injection point.
pub type HookMap<F> = HashMap<Hook, HookFn<F>>;
416
/// A dispatchable v4 model instance: weights, a multi-batch recurrent state,
/// user hooks, and caches of size-keyed runtime buffers.
#[derive(Clone)]
pub struct Bundle<F: Float> {
    model: Model,
    state: State,
    hooks: Arc<HookMap<F>>,
    /// `Runtime` buffers cached by token count.
    buffers: ResourceCache<usize, Runtime<F>>,
    /// `Header` buffers cached by header count.
    headers: ResourceCache<usize, Header<F>>,
    // NOTE(review): `F` already appears in `hooks`/`buffers`/`headers`, so
    // this marker is redundant; kept so construction sites stay valid.
    phantom: PhantomData<F>,
}
426
impl<F: Float> super::model::Bundle for Bundle<F> {
    /// The model's metadata, cloned.
    #[inline]
    fn info(&self) -> ModelInfo {
        self.model.info.clone()
    }

    /// A shallow clone of the recurrent state (shares the GPU buffer; use
    /// `DeepClone` for an independent copy).
    #[inline]
    fn state(&self) -> impl super::model::State + AsAny + 'static {
        self.state.clone()
    }

    /// A serializable clone of the whole model.
    #[inline]
    fn model(&self) -> impl Serialize + 'static {
        self.model.clone()
    }
}
443
444impl<F: Float> Bundle<F> {
445    pub fn new(model: Model, num_batch: usize) -> Self {
446        let context = model.context.clone();
447        let info = model.info.clone();
448        let state = {
449            let shape = Shape::new(info.num_emb, 5 * info.num_layer, num_batch, 1);
450            let data = (0..info.num_layer * num_batch)
451                .map(|_| {
452                    [
453                        vec![0.0; info.num_emb],
454                        vec![0.0; info.num_emb],
455                        vec![0.0; info.num_emb],
456                        vec![f32::MIN; info.num_emb],
457                        vec![0.0; info.num_emb],
458                    ]
459                    .concat()
460                })
461                .collect_vec()
462                .concat();
463            let data = context.tensor_from_data(shape, data).unwrap();
464            State {
465                context,
466                info,
467                data,
468            }
469        };
470        Self {
471            model,
472            state,
473            hooks: Default::default(),
474            buffers: ResourceCache::new(4),
475            headers: ResourceCache::new(4),
476            phantom: PhantomData,
477        }
478    }
479
480    pub fn new_with_hooks(model: Model, num_batch: usize, hooks: HookMap<F>) -> Self {
481        Self {
482            hooks: Arc::new(hooks),
483            ..Self::new(model, num_batch)
484        }
485    }
486
487    fn checkout_buffer(
488        &self,
489        context: &Context,
490        info: &ModelInfo,
491        num_token: usize,
492    ) -> Arc<Runtime<F>> {
493        self.buffers
494            .checkout(num_token, || Runtime::new(context, info, num_token))
495    }
496
497    fn checkout_header(
498        &self,
499        context: &Context,
500        info: &ModelInfo,
501        num_header: usize,
502    ) -> Arc<Header<F>> {
503        self.headers
504            .checkout(num_header, || Header::new(context, info, num_header))
505    }
506}
507
508fn turbo(num_token: usize) -> bool {
509    num_token.is_multiple_of(super::infer::rnn::MIN_TOKEN_CHUNK_SIZE)
510}
511
512fn hook_op<F: Float>(
513    hooks: &HookMap<F>,
514    hook: &Hook,
515    frame: &Frame<F>,
516) -> Result<TensorOp, TensorError> {
517    match hooks.get(hook) {
518        Some(f) => f(frame.clone()),
519        None => Ok(TensorOp::empty()),
520    }
521}
522
impl<F: Float> Dispatcher<RnnJob> for Bundle<F> {
    type Info = RnnInfo;

    /// Build a complete GPU job for one chunk of RNN input.
    ///
    /// Encodes the embedding layer norm, every transformer layer, and the
    /// output head into command buffers; the returned job only needs `load`
    /// (token upload) and `submit`. A zero-token seed yields an empty job.
    fn dispatch(&self, seed: Self::Info) -> Result<RnnJob, RuntimeError> {
        let model = &self.model;
        let state = &self.state;
        let context = &model.context;
        let info = &model.info;
        let tensor = &model.tensor;

        let num_token = seed.num_token();

        // `redirect` tracks which tokens produce logits ("headers") and how
        // packed header rows map back to per-batch outputs.
        let redirect = seed.redirect();
        let num_header = redirect.headers.len();

        let buffer = self.checkout_buffer(context, info, num_token);
        let header = self.checkout_header(context, info, num_header);
        let frame = Frame {
            state: state.clone(),
            buffer: buffer.clone(),
            header: header.clone(),
        };

        context.maintain();

        // Nothing to encode: return a job with no commands.
        if num_token == 0 {
            return Ok(RnnJob {
                commands: vec![],
                redirect,
                embed: model.tensor.embed.w.clone(),
                cursors: buffer.cursors.clone(),
                input: buffer.input.clone(),
                output: header.head_o.clone(),
            });
        }

        #[cfg(feature = "trace")]
        let _span = tracing::trace_span!("build").entered();

        // Op that gathers only the header tokens out of `x` into `head_x`.
        let (head_op, head_x) = redirect.op(&buffer.x, &header.head_x)?;

        let hook_op = |hook: Hook| hook_op(&self.hooks, &hook, &frame);
        let mut ops = vec![];

        {
            #[cfg(feature = "trace")]
            let _span = tracing::trace_span!("embed").entered();

            // Layer-norm the embedded input in place, then copy it into the
            // residual stream `x` (converting f16 -> F via the blit).
            ops.extend([
                hook_op(Hook::PostEmbedLoaded)?,
                TensorOp::layer_norm(
                    &tensor.embed.layer_norm.w,
                    &tensor.embed.layer_norm.b,
                    &buffer.input,
                    Model::LN_EPS,
                )?,
                TensorOp::blit(&buffer.input, &buffer.x)?,
                hook_op(Hook::PostEmbedLayerNorm)?,
            ]);
        };

        for (index, layer) in tensor.layers.iter().enumerate() {
            #[cfg(feature = "trace")]
            let _span = tracing::trace_span!("layer", index).entered();

            let hooks = self.hooks.clone();
            let frame = frame.clone();
            let layer = layer.clone();

            let op = dispatch_layer(hooks, frame, layer, index, num_token, model.rescale)?;
            ops.push(op);

            // Emit a separator op every `sep` layers (see `Model::sep`).
            if (index + 1) % model.sep == 0 {
                ops.push(TensorOp::Sep);
            }
        }

        {
            #[cfg(feature = "trace")]
            let _span = tracing::trace_span!("header").entered();

            let hooks = self.hooks.clone();
            let frame = frame.clone();
            let head = model.tensor.head.clone();

            let op = dispatch_header(hooks, frame, head, head_x, num_header, head_op)?;
            ops.push(op);
        }

        let commands = {
            #[cfg(feature = "trace")]
            let _span = tracing::trace_span!("encode").entered();
            context.encode(&TensorOp::List(ops))
        };

        Ok(RnnJob {
            commands,
            redirect,
            embed: model.tensor.embed.w.clone(),
            cursors: buffer.cursors.clone(),
            input: buffer.input.clone(),
            output: header.head_o.clone(),
        })
    }
}
628
/// Encode all ops for one transformer layer: attention (time-mix) followed by
/// the channel-mix FFN, with every hook spliced in at its injection point.
///
/// `rescale` halves the residual stream every `rescale` layers — presumably
/// an f16 overflow guard matched by the loader's weight discount.
#[allow(clippy::too_many_arguments)]
fn dispatch_layer<F: Float>(
    hooks: Arc<HookMap<F>>,
    frame: Frame<F>,
    layer: Layer,
    index: usize,
    num_token: usize,
    rescale: usize,
) -> Result<TensorOp, TensorError> {
    let hook_op = |hook: Hook| hook_op(&hooks, &hook, &frame);
    let Frame { state, buffer, .. } = &frame;

    let mut ops = vec![];

    // --- Attention (time-mix) block ---
    ops.extend([
        // Work on a copy of the residual stream.
        TensorOp::blit(&buffer.x, &buffer.att_x)?,
        hook_op(Hook::PreAtt(index))?,
        TensorOp::layer_norm(
            &layer.att_layer_norm.w,
            &layer.att_layer_norm.b,
            &buffer.att_x,
            Model::LN_EPS,
        )?,
        hook_op(Hook::PostAttLayerNorm(index))?,
        hook_op(Hook::PreAttTokenShift(index))?,
        // Blend each token with the previous one (weights `time_mix_*`) to
        // form the K/V/R inputs; the previous token comes from the layer state.
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.att.time_mix_k,
            state.att(index)?,
            &buffer.att_x,
            &buffer.att_kx,
            false,
        )?,
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.att.time_mix_v,
            state.att(index)?,
            &buffer.att_x,
            &buffer.att_vx,
            false,
        )?,
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.att.time_mix_r,
            state.att(index)?,
            &buffer.att_x,
            &buffer.att_rx,
            false,
        )?,
        hook_op(Hook::PostAttTokenShift(index))?,
        hook_op(Hook::PreAttLinear(index))?,
        // K/V/R projections.
        layer.att.w_k.matmul_op(
            &buffer.att_kx,
            &buffer.att_k,
            Activation::None,
            turbo(num_token),
        )?,
        layer.att.w_v.matmul_op(
            &buffer.att_vx,
            &buffer.att_v,
            Activation::None,
            turbo(num_token),
        )?,
        layer.att.w_r.matmul_op(
            &buffer.att_rx,
            &buffer.att_r,
            Activation::None,
            turbo(num_token),
        )?,
        hook_op(Hook::PostAttLinear(index))?,
        hook_op(Hook::PreAttTimeMix(index))?,
        // The WKV kernel operates on f32: round-trip `att_x` through `aux_x`.
        TensorOp::blit(&buffer.att_x, &buffer.aux_x)?,
        TensorOp::time_mix_v4(
            &buffer.cursors,
            &layer.att.time_decay,
            &layer.att.time_first,
            state.att(index)?,
            &buffer.att_k,
            &buffer.att_v,
            &buffer.att_r,
            &buffer.aux_x,
        )?,
        TensorOp::blit(&buffer.aux_x, &buffer.att_x)?,
        hook_op(Hook::PostAttTimeMix(index))?,
        hook_op(Hook::PreAttOut(index))?,
        // Output projection.
        layer.att.w_o.matmul_op(
            &buffer.att_x,
            &buffer.att_o,
            Activation::None,
            turbo(num_token),
        )?,
        hook_op(Hook::PostAttOut(index))?,
        // Residual connection back into `x`.
        TensorOp::add(&buffer.att_o, &buffer.x)?,
        hook_op(Hook::PostAtt(index))?,
    ]);

    // --- FFN (channel-mix) block ---
    ops.extend([
        TensorOp::blit(&buffer.x, &buffer.ffn_x)?,
        hook_op(Hook::PreFfn(index))?,
        TensorOp::layer_norm(
            &layer.ffn_layer_norm.w,
            &layer.ffn_layer_norm.b,
            &buffer.ffn_x,
            Model::LN_EPS,
        )?,
        hook_op(Hook::PostFfnLayerNorm(index))?,
        hook_op(Hook::PreFfnTokenShift(index))?,
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.ffn.time_mix_k,
            state.ffn(index)?,
            &buffer.ffn_x,
            &buffer.ffn_kx,
            false,
        )?,
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.ffn.time_mix_r,
            state.ffn(index)?,
            &buffer.ffn_x,
            &buffer.ffn_rx,
            false,
        )?,
        hook_op(Hook::PostFfnTokenShift(index))?,
        hook_op(Hook::PreFfnLinear(index))?,
        // Key projection with fused squared-ReLU activation into the wider
        // hidden dimension.
        layer.ffn.w_k.matmul_op(
            &buffer.ffn_kx,
            &buffer.ffn_k,
            Activation::SquaredRelu,
            turbo(num_token),
        )?,
        hook_op(Hook::PostFfnActivate(index))?,
        // Value projection exploits sparsity of the squared-ReLU output.
        layer.ffn.w_v.matmul_op_sparse(
            &buffer.ffn_k,
            &buffer.ffn_v,
            Activation::None,
            turbo(num_token),
        )?,
        layer.ffn.w_r.matmul_op(
            &buffer.ffn_rx,
            &buffer.ffn_r,
            Activation::None,
            turbo(num_token),
        )?,
        hook_op(Hook::PostFfnLinear(index))?,
        hook_op(Hook::PreFfnChannelMix(index))?,
        TensorOp::channel_mix(
            &buffer.cursors,
            state.ffn(index)?,
            &buffer.ffn_r,
            &buffer.ffn_v,
            &buffer.ffn_x,
        )?,
        hook_op(Hook::PostFfnChannelMix(index))?,
        // Residual connection back into `x`.
        TensorOp::add(&buffer.ffn_x, &buffer.x)?,
        hook_op(Hook::PostFfn(index))?,
    ]);

    // Halve the residual stream every `rescale` layers.
    if (index + 1).is_multiple_of(rescale) {
        ops.push(TensorOp::affine(&buffer.x, 0.5, 0.0)?);
    }

    Ok(TensorOp::List(ops))
}
793
794fn dispatch_header<F: Float>(
795    hooks: Arc<HookMap<F>>,
796    frame: Frame<F>,
797    head: Head,
798    head_x: TensorGpu<F, ReadWrite>,
799    num_header: usize,
800    head_op: TensorOp,
801) -> Result<TensorOp, TensorError> {
802    let hook_op = |hook: Hook| hook_op(&hooks, &hook, &frame);
803    let header = &frame.header;
804    let mut ops = vec![head_op];
805
806    if num_header > 0 {
807        ops.extend([
808            hook_op(Hook::PreHead)?,
809            TensorOp::layer_norm(
810                &head.layer_norm.w,
811                &head.layer_norm.b,
812                &head_x,
813                Model::LN_EPS,
814            )?,
815            hook_op(Hook::PostHeadLayerNorm)?,
816            head.w.matmul_op(
817                head_x.view(.., .., .., ..)?,
818                header.head_o.view(.., .., .., ..)?,
819                Activation::None,
820                turbo(num_header),
821            )?,
822            hook_op(Hook::PostHead)?,
823        ]);
824    }
825    Ok(TensorOp::List(ops))
826}
827
impl<R: Reader> ModelBuilder<R> {
    /// Load an RWKV v4 model: read every tensor via the [`Loader`], quantize
    /// layer matrices as configured, and upload the weights to the GPU.
    ///
    /// # Errors
    /// Returns a [`LoaderError`] when a required tensor is missing or malformed.
    pub async fn build_v4(self) -> Result<Model, LoaderError> {
        let ModelBuilder {
            context,
            model,
            rescale,
            sep,
            lora,
            quant,
            ..
        } = self;

        let info = Loader::info(&model)?;
        let loader = Loader {
            context: context.clone(),
            model,
            lora,
        };

        let rescale = rescale.unwrap_or(Model::DEFAULT_RESCALE);
        let sep = sep.unwrap_or(Model::DEFAULT_SEP);

        // The embedding table stays on the CPU — token lookup happens
        // host-side in `RnnJob::load`.
        let embed = Embed {
            layer_norm: LayerNorm {
                w: loader.load_vector_f16("blocks.0.ln0.weight")?,
                b: loader.load_vector_f16("blocks.0.ln0.bias")?,
            },
            w: loader.load_matrix_f16_padded_cpu("emb.weight")?,
        };

        let head = Head {
            layer_norm: LayerNorm {
                w: loader.load_vector_f16("ln_out.weight")?,
                b: loader.load_vector_f16("ln_out.bias")?,
            },
            w: Matrix::Fp16(loader.load_matrix_f16("head.weight")?),
        };

        // Submit and wait so queued uploads are flushed before the layer
        // loop — presumably to bound staging memory; confirm against the
        // `Loader`/wgpu internals.
        let submission_index = Some(context.queue.submit(None));
        _ = context.device.poll(wgpu::PollType::Wait {
            submission_index,
            timeout: None,
        });

        let load_matrix = |name: String, quant: Quant| loader.load_matrix(name, quant);
        let load_matrix_discount = |name: String, quant: Quant, discount: f32| {
            loader.load_matrix_discount(name, quant, discount)
        };

        let mut layers = vec![];
        for layer in 0..info.num_layer {
            // Per-layer quantization setting; falls back to the default.
            let quant = quant.get(&layer).copied().unwrap_or_default();
            // Pre-scale output-side weights by 2^-(layer / rescale) so the
            // runtime's 0.5x residual rescale every `rescale` layers
            // (see `dispatch_layer`) keeps magnitudes consistent.
            let discount = 2.0_f32.powi(-((layer / rescale) as i32));

            let att_layer_norm = LayerNorm {
                w: loader.load_vector_f16(format!("blocks.{layer}.ln1.weight"))?,
                b: loader.load_vector_f16(format!("blocks.{layer}.ln1.bias"))?,
            };

            let att = format!("blocks.{layer}.att");
            let time_decay = loader.load_vector_exp_f32(format!("{att}.time_decay"))?;
            let time_first = loader.load_vector_f32(format!("{att}.time_first"))?;
            let time_mix_k = loader.load_vector_f16(format!("{att}.time_mix_k"))?;
            let time_mix_v = loader.load_vector_f16(format!("{att}.time_mix_v"))?;
            let time_mix_r = loader.load_vector_f16(format!("{att}.time_mix_r"))?;

            let att = Att {
                time_decay,
                time_first,
                time_mix_k,
                time_mix_v,
                time_mix_r,
                w_k: load_matrix(format!("{att}.key.weight"), quant)?,
                w_v: load_matrix(format!("{att}.value.weight"), quant)?,
                w_r: load_matrix(format!("{att}.receptance.weight"), quant)?,
                w_o: load_matrix_discount(format!("{att}.output.weight"), quant, discount)?,
            };

            let ffn_layer_norm = LayerNorm {
                w: loader.load_vector_f16(format!("blocks.{layer}.ln2.weight"))?,
                b: loader.load_vector_f16(format!("blocks.{layer}.ln2.bias"))?,
            };

            let ffn = format!("blocks.{layer}.ffn");
            let time_mix_k = loader.load_vector_f16(format!("{ffn}.time_mix_k"))?;
            let time_mix_r = loader.load_vector_f16(format!("{ffn}.time_mix_r"))?;

            let ffn = Ffn {
                time_mix_k,
                time_mix_r,
                w_r: load_matrix(format!("{ffn}.receptance.weight"), quant)?,
                w_k: load_matrix(format!("{ffn}.key.weight"), quant)?,
                w_v: load_matrix_discount(format!("{ffn}.value.weight"), quant, discount)?,
            };

            // Flush the queue after each layer so uploads don't pile up.
            let submission_index = Some(context.queue.submit(None));
            _ = context.device.poll(wgpu::PollType::Wait {
                submission_index,
                timeout: None,
            });

            layers.push(Layer {
                att_layer_norm,
                ffn_layer_norm,
                att,
                ffn,
            })
        }

        // Final flush before handing the model out.
        let submission_index = Some(context.queue.submit(None));
        _ = context.device.poll(wgpu::PollType::Wait {
            submission_index,
            timeout: None,
        });

        let tensor = ModelTensor {
            embed,
            head,
            layers,
        };
        let model = {
            let context = context.clone();
            let info = info.clone();
            Model {
                context,
                info,
                rescale,
                sep,
                tensor,
            }
        };
        Ok(model)
    }
}
961}