web_rwkv/runtime/v7.rs

1use std::{collections::HashMap, marker::PhantomData, sync::Arc};
2
3#[cfg(not(target_arch = "wasm32"))]
4use futures::future::BoxFuture;
5#[cfg(target_arch = "wasm32")]
6use futures::future::LocalBoxFuture;
7use half::f16;
8use itertools::Itertools;
9use serde::{Deserialize, Serialize};
10use web_rwkv_derive::DeserializeSeed;
11use wgpu::CommandBuffer;
12
13use super::{
14    infer::{RnnChunk, RnnInfo, RnnInput, RnnOutput, RnnOutputBatch, RnnRedirect, Token},
15    loader::{Loader, LoaderError, Reader},
16    model::{AsAny, ModelBuilder, ModelCustomInfo, ModelInfo, State as _},
17    Dispatcher, Job, RuntimeError,
18};
19use crate::{
20    context::Context,
21    num::Float,
22    runtime::model::Quant,
23    tensor::{
24        cache::ResourceCache,
25        kind::ReadWrite,
26        matrix::Matrix,
27        ops::{Activation, TensorCommand, TensorOp},
28        serialization::Seed,
29        shape::{Shape, TensorDimension},
30        DeepClone, IntoPackedCursors, TensorCpu, TensorError, TensorGpu, TensorGpuView, TensorInit,
31        TensorReshape, TensorShape, TensorStack,
32    },
33};
34
/// A loaded RWKV-7 model: the GPU context it lives on, its hyper-parameters and its weights.
///
/// Serialized with `Serialize`; deserialized through `DeserializeSeed` with the `Seed` seed and
/// a `Context` (see the `serde_seed` attribute).
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Model {
    pub context: Context,
    pub info: ModelInfo,
    /// After every `rescale`-th layer the residual stream goes through
    /// `TensorOp::affine(.., 0.5, 0.0)` (see `dispatch_layer`).
    pub rescale: usize,
    /// After every `sep`-th layer a `TensorOp::Sep` is inserted into the op list (see
    /// `Dispatcher::dispatch`).
    pub sep: usize,
    pub tensor: ModelTensor,
}
44
45impl Model {
46    pub const L2_EPS: f32 = 1.0e-12;
47    pub const LN_EPS: f32 = 1.0e-5;
48    pub const GN_EPS: f32 = 64.0e-5;
49
50    pub const DEFAULT_RESCALE: usize = 1024;
51    pub const DEFAULT_SEP: usize = 1024;
52}
53
/// RWKV-7-specific hyper-parameters: the widths of the low-rank (LoRA) intermediates.
///
/// Each value is the row count of the matching `Runtime::aux_*` buffer (see `Runtime::new`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct CustomInfo {
    /// Width of `aux_w`, the time-decay LoRA intermediate.
    pub w: usize,
    /// Width of `aux_a`.
    pub a: usize,
    /// Width of `aux_g`, the gate LoRA intermediate.
    pub g: usize,
    /// Width of `aux_v`, the value-residual LoRA intermediate.
    pub v: usize,
}
61
/// All weights of an RWKV-7 model: embedding, output head and the per-layer blocks.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct ModelTensor {
    pub embed: Embed,
    pub head: Head,
    /// One entry per layer, in execution order.
    pub layers: Vec<Layer>,
}
69
/// Weight (`w`) and bias (`b`) of a normalization op; passed to `TensorOp::layer_norm` or
/// `TensorOp::group_norm`.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct LayerNorm {
    pub w: TensorGpu<f16, ReadWrite>,
    pub b: TensorGpu<f16, ReadWrite>,
}
76
/// Attention (time-mix) weights of one RWKV-7 layer.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Att {
    // Token-shift mixing coefficients, one per projection input (each passed to
    // `TensorOp::token_shift`).
    pub x_r: TensorGpu<f16, ReadWrite>,
    pub x_w: TensorGpu<f16, ReadWrite>,
    pub x_k: TensorGpu<f16, ReadWrite>,
    pub x_v: TensorGpu<f16, ReadWrite>,
    pub x_a: TensorGpu<f16, ReadWrite>,
    pub x_g: TensorGpu<f16, ReadWrite>,

    // Biases combined with the LoRA outputs of w, a and v respectively.
    pub w0: TensorGpu<f16, ReadWrite>,
    pub a0: TensorGpu<f16, ReadWrite>,
    pub v0: TensorGpu<f16, ReadWrite>,

    // Low-rank (LoRA) down/up pairs: time decay (w), `a` (presumably the in-context learning
    // rate), gate (g) and value residual (v).
    pub w1: Matrix,
    pub w2: Matrix,
    pub a1: Matrix,
    pub a2: Matrix,
    pub g1: Matrix,
    pub g2: Matrix,
    pub v1: Matrix,
    pub v2: Matrix,

    // `r_k` feeds `time_first_v7`, `k_k` scales k before the l2 norm that yields kk, and
    // `k_a` feeds `control_k_v7`.
    pub r_k: TensorGpu<f16, ReadWrite>,
    pub k_k: TensorGpu<f16, ReadWrite>,
    pub k_a: TensorGpu<f16, ReadWrite>,

    // Full-rank projections: key, value, receptance and output.
    pub w_k: Matrix,
    pub w_v: Matrix,
    pub w_r: Matrix,
    pub w_o: Matrix,

    /// Per-head group norm applied to the time-mix output.
    pub gn: LayerNorm,
}
111
/// Feed-forward (channel-mix) weights of one RWKV-7 layer.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Ffn {
    /// Token-shift mixing coefficient of the key input.
    pub x_k: TensorGpu<f16, ReadWrite>,

    /// Key projection (followed by squared ReLU in `dispatch_layer`).
    pub w_k: Matrix,
    /// Value projection (applied with `matmul_op_sparse`).
    pub w_v: Matrix,
}
120
/// One RWKV-7 block: pre-norms plus the attention and FFN sub-blocks.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Layer {
    /// Layer norm applied to the attention input.
    pub att_ln: LayerNorm,
    /// Layer norm applied to the FFN input.
    pub ffn_ln: LayerNorm,
    pub att: Att,
    pub ffn: Ffn,
}
129
/// Token embedding: the lookup table plus the layer norm applied to looked-up rows.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Embed {
    pub ln: LayerNorm,
    /// Embedding table kept on the CPU. Token `t` is the `num_emb` values starting at
    /// `t * num_emb` (see `RnnJob::load`), where `num_emb = shape()[0]`.
    pub w: TensorCpu<f16>,
}
136
/// Output head: final layer norm and the output projection (presumably to vocabulary
/// logits, cf. the `num_vocab_padded()`-wide `Header::head_o`).
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Head {
    pub ln: LayerNorm,
    pub w: Matrix,
}
143
/// Batched recurrent state of an RWKV-7 model, stored on the GPU.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct State {
    pub context: Context,
    pub info: ModelInfo,
    /// One tensor per layer, shaped `[num_emb, head_size + 2, num_batch, 1]` (see
    /// `Bundle::new`). Rows `0..head_size + 1` are exposed as the attention state (`att`), and
    /// row `head_size + 1` as the FFN state (`ffn`).
    pub data: Vec<TensorGpu<f32, ReadWrite>>,
}
151
152impl State {
153    async fn back(&self, batch: usize) -> Result<TensorCpu<f32>, TensorError> {
154        let context = &self.context;
155        let mut tensors = Vec::with_capacity(self.info.num_layer);
156        let mut encoder = context.device.create_command_encoder(&Default::default());
157        for data in self.data.iter() {
158            let shape = data.shape();
159            let destination = context.tensor_init([shape[0], shape[1], 1, 1]);
160            encoder.copy_tensor_batch(data, &destination, batch, 0)?;
161            tensors.push(destination);
162        }
163        context.queue.submit(Some(encoder.finish()));
164
165        let mut backed = Vec::with_capacity(tensors.len());
166        for tensor in tensors {
167            backed.push(tensor.back().await);
168        }
169        TensorCpu::stack(backed, 2)
170    }
171}
172
173impl AsAny for State {
174    fn as_any(&self) -> &dyn std::any::Any {
175        self
176    }
177}
178
179impl super::model::State for State {
180    #[inline]
181    fn num_batch(&self) -> usize {
182        self.data[0].shape()[2]
183    }
184
185    #[inline]
186    fn init_shape(&self) -> Shape {
187        let info = &self.info;
188        let head_size = info.num_emb / info.num_head;
189        [info.num_emb, head_size + 2, info.num_layer, 1].into()
190    }
191
192    fn init(&self) -> TensorCpu<f32> {
193        let shape = self.init_shape();
194        let data = vec![0.0; shape.len()];
195        TensorCpu::from_data(shape, data).unwrap()
196    }
197
198    fn att(&self, layer: usize) -> Result<TensorGpuView<'_, f32>, TensorError> {
199        let head_size = self.info.num_emb / self.info.num_head;
200        let end = head_size + 1;
201        self.data[layer].view(.., 0..end, .., ..)
202    }
203
204    fn ffn(&self, layer: usize) -> Result<TensorGpuView<'_, f32>, TensorError> {
205        let head_size = self.info.num_emb / self.info.num_head;
206        let start = head_size + 1;
207        self.data[layer].view(.., start, .., ..)
208    }
209
210    fn load(&self, tensor: TensorCpu<f32>, batch: usize) -> Result<(), TensorError> {
211        let head_size = self.info.num_emb / self.info.num_head;
212        tensor.check_shape([self.info.num_emb, head_size + 2, self.info.num_layer, 1])?;
213        for (data, source) in self.data.iter().zip(tensor.split(2)?.into_iter()) {
214            data.load_batch(&source, batch)?;
215        }
216        Ok(())
217    }
218
219    #[cfg(not(target_arch = "wasm32"))]
220    fn back(&self, batch: usize) -> BoxFuture<'_, Result<TensorCpu<f32>, TensorError>> {
221        Box::pin(self.back(batch))
222    }
223
224    #[cfg(target_arch = "wasm32")]
225    fn back(&self, batch: usize) -> LocalBoxFuture<'_, Result<TensorCpu<f32>, TensorError>> {
226        Box::pin(self.back(batch))
227    }
228
229    fn write(&self, tensor: TensorGpu<f32, ReadWrite>, batch: usize) -> Result<(), TensorError> {
230        let head_size = self.info.num_emb / self.info.num_head;
231        tensor.check_shape([self.info.num_emb, head_size + 2, self.info.num_layer, 1])?;
232
233        let context = &self.context;
234        let mut ops = Vec::with_capacity(self.data.len());
235        for (layer, data) in self.data.iter().enumerate() {
236            ops.push(TensorOp::blit(
237                tensor.view(.., .., layer, ..)?,
238                data.view(.., .., batch, ..)?,
239            )?);
240        }
241        context.queue.submit(context.encode(&TensorOp::List(ops)));
242
243        Ok(())
244    }
245
246    fn read(&self, batch: usize) -> Result<TensorGpu<f32, ReadWrite>, TensorError> {
247        let context = &self.context;
248        let head_size = self.info.num_emb / self.info.num_head;
249        let shape = [self.info.num_emb, head_size + 2, self.info.num_layer, 1];
250        let tensor: TensorGpu<_, _> = context.tensor_init(shape);
251
252        let mut ops = Vec::with_capacity(self.data.len());
253        for (layer, data) in self.data.iter().enumerate() {
254            ops.push(TensorOp::blit(
255                data.view(.., .., batch, ..)?,
256                tensor.view(.., .., layer, ..)?,
257            )?);
258        }
259        context.queue.submit(context.encode(&TensorOp::List(ops)));
260
261        Ok(tensor)
262    }
263
264    fn embed(&self, layer: usize, backed: TensorCpu<f32>) -> Result<TensorCpu<f32>, TensorError> {
265        backed.slice(.., 0, layer, ..)
266    }
267}
268
269impl DeepClone for State {
270    fn deep_clone(&self) -> Self {
271        let data = self.data.iter().map(|tensor| tensor.deep_clone()).collect();
272        Self {
273            data,
274            ..self.clone()
275        }
276    }
277}
278
/// Scratch buffers for one inference step of `num_token` tokens (shapes in `Runtime::new`).
///
/// The same buffers are reused by every layer of the step (see `dispatch_layer`).
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Runtime<F: Float> {
    /// Per-token cursors, uploaded by `RnnJob::load`.
    pub cursors: TensorGpu<u32, ReadWrite>,
    /// Token embeddings, uploaded by `RnnJob::load`.
    pub input: TensorGpu<f16, ReadWrite>,

    /// Residual stream.
    pub x: TensorGpu<F, ReadWrite>,

    /// Normalized attention input; later overwritten by the time-mix output.
    pub att_x: TensorGpu<F, ReadWrite>,
    /// Value from layer 0, kept as the value-residual source for later layers.
    pub att_v0: TensorGpu<F, ReadWrite>,

    // Token-shifted inputs of each attention projection.
    pub att_rx: TensorGpu<F, ReadWrite>,
    pub att_wx: TensorGpu<F, ReadWrite>,
    pub att_kx: TensorGpu<F, ReadWrite>,
    pub att_vx: TensorGpu<F, ReadWrite>,
    pub att_ax: TensorGpu<F, ReadWrite>,
    pub att_gx: TensorGpu<F, ReadWrite>,

    // Attention projections / LoRA outputs, and the output projection `att_o`.
    pub att_r: TensorGpu<F, ReadWrite>,
    pub att_w: TensorGpu<F, ReadWrite>,
    pub att_k: TensorGpu<F, ReadWrite>,
    pub att_v: TensorGpu<F, ReadWrite>,
    pub att_a: TensorGpu<F, ReadWrite>,
    pub att_g: TensorGpu<F, ReadWrite>,
    pub att_o: TensorGpu<F, ReadWrite>,

    /// L2-normalized `k * k_k`.
    pub att_kk: TensorGpu<F, ReadWrite>,
    /// Value-residual mixing factor (layers after the first).
    pub att_vv: TensorGpu<F, ReadWrite>,

    /// Four slots along axis 2 holding k, v, a and kk for `time_mix_v7`.
    pub att_n: TensorGpu<F, ReadWrite>,

    /// Time decay LoRA intermediate.
    pub aux_w: TensorGpu<F, ReadWrite>,
    /// LoRA intermediates of `a`, the gate and the value residual.
    pub aux_a: TensorGpu<F, ReadWrite>,
    pub aux_g: TensorGpu<F, ReadWrite>,
    pub aux_v: TensorGpu<F, ReadWrite>,

    // FFN: normalized input, token-shifted input, hidden activation (`num_hidden` rows), output.
    pub ffn_x: TensorGpu<F, ReadWrite>,
    pub ffn_kx: TensorGpu<F, ReadWrite>,
    pub ffn_k: TensorGpu<F, ReadWrite>,
    pub ffn_v: TensorGpu<F, ReadWrite>,
}
321
322impl<F: Float> Runtime<F> {
323    pub fn new(context: &Context, info: &ModelInfo, num_token: usize) -> Self {
324        let ModelCustomInfo::V7(custom) = info.custom else {
325            unreachable!()
326        };
327
328        let shape = Shape::new(info.num_emb, num_token, 1, 1);
329        let cursors_shape = Shape::new(num_token, 1, 1, 1);
330        let hidden_shape = Shape::new(info.num_hidden, num_token, 1, 1);
331
332        Self {
333            cursors: context.tensor_init(cursors_shape),
334            input: context.tensor_init(shape),
335            x: context.tensor_init(shape),
336            att_x: context.tensor_init(shape),
337            att_v0: context.tensor_init(shape),
338            att_rx: context.tensor_init(shape),
339            att_wx: context.tensor_init(shape),
340            att_kx: context.tensor_init(shape),
341            att_vx: context.tensor_init(shape),
342            att_ax: context.tensor_init(shape),
343            att_gx: context.tensor_init(shape),
344            att_r: context.tensor_init(shape),
345            att_w: context.tensor_init(shape),
346            att_k: context.tensor_init(shape),
347            att_v: context.tensor_init(shape),
348            att_a: context.tensor_init(shape),
349            att_g: context.tensor_init(shape),
350            att_o: context.tensor_init(shape),
351            att_kk: context.tensor_init(shape),
352            att_vv: context.tensor_init(shape),
353            att_n: context.tensor_init([shape[0], shape[1], 4, 1]),
354            aux_w: context.tensor_init([custom.w, shape[1], 1, 1]),
355            aux_a: context.tensor_init([custom.a, shape[1], 1, 1]),
356            aux_g: context.tensor_init([custom.g, shape[1], 1, 1]),
357            aux_v: context.tensor_init([custom.v, shape[1], 1, 1]),
358            ffn_x: context.tensor_init(shape),
359            ffn_kx: context.tensor_init(shape),
360            ffn_k: context.tensor_init(hidden_shape),
361            ffn_v: context.tensor_init(shape),
362        }
363    }
364}
365
/// Buffers of the output head for `num_header` output rows (see `Header::new`).
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Header<F: Float> {
    /// `[num_emb, num_header, 1, 1]` head input, filled from the residual stream by the
    /// `RnnRedirect::op` built in `dispatch`.
    pub head_x: TensorGpu<F, ReadWrite>,
    /// `[num_vocab_padded, num_header, 1, 1]` head output (presumably logits); read back by
    /// `RnnJob::back`.
    pub head_o: TensorGpu<f32, ReadWrite>,
}
372
373impl<F: Float> Header<F> {
374    pub fn new(context: &Context, info: &ModelInfo, num_header: usize) -> Self {
375        let head_shape = Shape::new(info.num_emb, num_header, 1, 1);
376        let output_shape = Shape::new(info.num_vocab_padded(), num_header, 1, 1);
377
378        Self {
379            head_x: context.tensor_init(head_shape),
380            head_o: context.tensor_init(output_shape),
381        }
382    }
383}
384
/// Points in the forward pass at which a registered [`HookFn`] may inject ops.
///
/// Variants carrying a `usize` are per-layer; `dispatch_layer` passes the layer index.
/// `Pre*`/`Post*` variants bracket the named step. The enum is serde-derived, so variant order
/// is part of the serialized form for index-based formats.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Hook {
    PostEmbedLoaded,
    PostEmbedLayerNorm,
    PreAtt(usize),
    PostAttLayerNorm(usize),
    PreAttTokenShift(usize),
    PostAttTokenShift(usize),
    PreAttLinear(usize),
    PostAttLinear(usize),
    PreAttAdapt(usize),
    PostAttAdapt(usize),
    PreAttControl(usize),
    PostAttControl(usize),
    PreAttValueResidual(usize),
    PostAttValueResidual(usize),
    PreAttTimeMix(usize),
    PostAttTimeMix(usize),
    PreAttGate(usize),
    PostAttGate(usize),
    PreAttOut(usize),
    PostAttOut(usize),
    PostAtt(usize),
    PreFfn(usize),
    PostFfnLayerNorm(usize),
    PreFfnTokenShift(usize),
    PostFfnTokenShift(usize),
    PreFfnLinear(usize),
    PostFfnLinear(usize),
    PostFfnActivate(usize),
    PreFfnChannelMix(usize),
    PostFfnChannelMix(usize),
    PostFfn(usize),
    PreHead,
    PostHeadLayerNorm,
    PostHead,
}
422
/// A dispatched RWKV-7 inference step.
///
/// Holds the pre-encoded GPU commands plus the buffers needed to upload its input and read
/// back its output.
pub struct RnnJob {
    /// Encoded ops; drained and submitted by `Job::submit`.
    commands: Vec<CommandBuffer>,
    /// Row ranges of `output` belonging to each batch (used by `Job::back`).
    redirect: RnnRedirect,

    /// CPU embedding table used by `Job::load` to look up token rows.
    embed: TensorCpu<f16>,

    /// Cursor buffer written by `Job::load`.
    cursors: TensorGpu<u32, ReadWrite>,
    /// Embedding input buffer written by `Job::load`.
    input: TensorGpu<f16, ReadWrite>,
    /// Head output buffer (`Header::head_o`), read back by `Job::back`.
    output: TensorGpu<f32, ReadWrite>,
}
433
434impl Job for RnnJob {
435    type Input = RnnInput;
436    type Output = RnnOutput;
437
438    fn load(&self, input: &RnnChunk) -> Result<(), RuntimeError> {
439        if input.num_token() == 0 {
440            return Ok(());
441        }
442
443        let stack: Vec<TensorCpu<f16>> = input
444            .iter()
445            .map(|chunk| {
446                let num_emb = self.embed.shape()[0];
447                let data = self.embed.data();
448                let data = chunk
449                    .iter()
450                    .map(|token| match token {
451                        &Token::Token(token) => {
452                            let start = num_emb * token as usize;
453                            let end = start + num_emb;
454                            let data = data[start..end].to_vec();
455                            TensorCpu::from_data_1d(data)
456                        }
457                        Token::Embed(tensor) => tensor.clone(),
458                    })
459                    .collect_vec();
460                match TensorCpu::stack(data, 1) {
461                    Ok(tensor) => tensor,
462                    Err(_) => TensorCpu::init([num_emb, 0, 1, 1]),
463                }
464            })
465            .collect();
466        let stack = TensorStack::try_from(stack)?;
467
468        let cursors = stack.cursors.clone().into_cursors();
469        let cursors = TensorCpu::from_data(self.cursors.shape(), cursors)?;
470        self.cursors.load(&cursors)?;
471        self.input.load(&stack.tensor)?;
472
473        Ok(())
474    }
475
476    fn submit(&mut self) {
477        let commands = std::mem::take(&mut self.commands);
478        self.output.context.queue.submit(commands);
479    }
480
481    async fn back(self) -> Result<Self::Output, RuntimeError> {
482        let output = self.output.back().await;
483        let batches: Vec<_> = self
484            .redirect
485            .outputs
486            .into_iter()
487            .map(|(start, end)| output.slice(.., start..end, .., ..))
488            .try_collect()?;
489        let batches = batches.into_iter().map(RnnOutputBatch).collect();
490        Ok(RnnOutput(batches))
491    }
492}
493
/// What a hook receives when invoked during `dispatch`.
///
/// Contains a clone of the bundle's state and the shared scratch (`buffer`) and head
/// (`header`) buffers of the current step.
#[derive(Debug, Clone)]
pub struct Frame<F: Float> {
    pub state: State,
    pub buffer: Arc<Runtime<F>>,
    pub header: Arc<Header<F>>,
}
500
/// A user hook: given the current [`Frame`], returns an op to splice into the op list.
pub type HookFn<F> = Box<dyn Fn(Frame<F>) -> Result<TensorOp, TensorError> + Send + Sync>;
/// Hooks keyed by the point of the forward pass at which they run.
pub type HookMap<F> = HashMap<Hook, HookFn<F>>;
503
/// An RWKV-7 model bundled with its batched state, hooks and caches of scratch buffers.
///
/// `F` is the float type of the scratch buffers.
#[derive(Clone)]
pub struct Bundle<F: Float> {
    model: Model,
    state: State,
    hooks: Arc<HookMap<F>>,
    /// Scratch buffers keyed by token count (see `checkout_buffer`).
    buffers: ResourceCache<usize, Runtime<F>>,
    /// Head buffers keyed by number of output rows (see `checkout_header`).
    headers: ResourceCache<usize, Header<F>>,
    phantom: PhantomData<F>,
}
513
514impl<F: Float> Bundle<F> {
515    pub fn new(model: Model, num_batch: usize) -> Self {
516        let context = model.context.clone();
517        let info = model.info.clone();
518        let state = {
519            let head_size = info.num_emb / info.num_head;
520            let shape = Shape::new(info.num_emb, head_size + 2, num_batch, 1);
521            let data = (0..info.num_layer).map(|_| context.zeros(shape)).collect();
522            State {
523                context,
524                info,
525                data,
526            }
527        };
528        Self {
529            model,
530            state,
531            hooks: Default::default(),
532            buffers: ResourceCache::new(4),
533            headers: ResourceCache::new(4),
534            phantom: PhantomData,
535        }
536    }
537
538    pub fn new_with_hooks(model: Model, num_batch: usize, hooks: HookMap<F>) -> Self {
539        Self {
540            hooks: Arc::new(hooks),
541            ..Self::new(model, num_batch)
542        }
543    }
544
545    fn checkout_buffer(
546        &self,
547        context: &Context,
548        info: &ModelInfo,
549        num_token: usize,
550    ) -> Arc<Runtime<F>> {
551        self.buffers
552            .checkout(num_token, || Runtime::new(context, info, num_token))
553    }
554
555    fn checkout_header(
556        &self,
557        context: &Context,
558        info: &ModelInfo,
559        num_header: usize,
560    ) -> Arc<Header<F>> {
561        self.headers
562            .checkout(num_header, || Header::new(context, info, num_header))
563    }
564}
565
566impl<F: Float> super::model::Bundle for Bundle<F> {
567    #[inline]
568    fn info(&self) -> ModelInfo {
569        self.model.info.clone()
570    }
571
572    #[inline]
573    fn state(&self) -> impl super::model::State + AsAny + 'static {
574        self.state.clone()
575    }
576
577    #[inline]
578    fn model(&self) -> impl Serialize + 'static {
579        self.model.clone()
580    }
581}
582
583fn turbo(num_token: usize) -> bool {
584    num_token.is_multiple_of(super::infer::rnn::MIN_TOKEN_CHUNK_SIZE)
585}
586
587fn hook_op<F: Float>(
588    hooks: &HookMap<F>,
589    hook: &Hook,
590    frame: &Frame<F>,
591) -> Result<TensorOp, TensorError> {
592    match hooks.get(hook) {
593        Some(f) => f(frame.clone()),
594        None => Ok(TensorOp::empty()),
595    }
596}
597
598impl<F: Float> Dispatcher<RnnJob> for Bundle<F> {
599    type Info = RnnInfo;
600
601    fn dispatch(&self, seed: Self::Info) -> Result<RnnJob, RuntimeError> {
602        let model = &self.model;
603        let state = &self.state;
604        let context = &model.context;
605        let info = &model.info;
606        let tensor = &model.tensor;
607
608        let num_token = seed.num_token();
609        let head_size = info.num_emb / info.num_head;
610
611        let redirect = seed.redirect();
612        let num_header = redirect.headers.len();
613
614        let buffer = self.checkout_buffer(context, info, num_token);
615        let header = self.checkout_header(context, info, num_header);
616        let frame = Frame {
617            state: state.clone(),
618            buffer: buffer.clone(),
619            header: header.clone(),
620        };
621
622        context.maintain();
623        self.buffers.maintain();
624        self.headers.maintain();
625
626        if num_token == 0 {
627            return Ok(RnnJob {
628                commands: vec![],
629                redirect,
630                embed: model.tensor.embed.w.clone(),
631                cursors: buffer.cursors.clone(),
632                input: buffer.input.clone(),
633                output: header.head_o.clone(),
634            });
635        }
636
637        #[cfg(feature = "trace")]
638        let _span = tracing::trace_span!("build").entered();
639
640        let (head_op, head_x) = redirect.op(&buffer.x, &header.head_x)?;
641
642        let hook_op = |hook: Hook| hook_op(&self.hooks, &hook, &frame);
643        let mut ops = vec![];
644
645        {
646            #[cfg(feature = "trace")]
647            let _span = tracing::trace_span!("embed").entered();
648
649            ops.extend([
650                hook_op(Hook::PostEmbedLoaded)?,
651                TensorOp::layer_norm(
652                    &tensor.embed.ln.w,
653                    &tensor.embed.ln.b,
654                    &buffer.input,
655                    Model::LN_EPS,
656                )?,
657                TensorOp::blit(&buffer.input, &buffer.x)?,
658                hook_op(Hook::PostEmbedLayerNorm)?,
659            ]);
660        };
661
662        for (index, layer) in tensor.layers.iter().enumerate() {
663            #[cfg(feature = "trace")]
664            let _span = tracing::trace_span!("layer", index).entered();
665
666            let hooks = self.hooks.clone();
667            let frame = frame.clone();
668            let layer = layer.clone();
669
670            let op = dispatch_layer(
671                hooks,
672                frame,
673                layer,
674                index,
675                num_token,
676                head_size,
677                model.rescale,
678            )?;
679            ops.push(op);
680
681            if (index + 1) % model.sep == 0 {
682                ops.push(TensorOp::Sep);
683            }
684        }
685
686        {
687            #[cfg(feature = "trace")]
688            let _span = tracing::trace_span!("header").entered();
689
690            let hooks = self.hooks.clone();
691            let frame = frame.clone();
692            let head = model.tensor.head.clone();
693
694            let op = dispatch_header(hooks, frame, head, head_x, num_header, head_op)?;
695            ops.push(op);
696        }
697
698        let commands = {
699            #[cfg(feature = "trace")]
700            let _span = tracing::trace_span!("encode").entered();
701            context.encode(&TensorOp::List(ops))
702        };
703
704        Ok(RnnJob {
705            commands,
706            redirect,
707            embed: model.tensor.embed.w.clone(),
708            cursors: buffer.cursors.clone(),
709            input: buffer.input.clone(),
710            output: header.head_o.clone(),
711        })
712    }
713}
714
/// Builds the op list for layer `index`.
///
/// The list is the attention (time-mix) block followed by the FFN (channel-mix) block. Each
/// block's result is added to the residual stream via `TensorOp::add(.., &buffer.x)`.
///
/// All ops use the shared scratch buffers in `frame.buffer` and this layer's views of
/// `frame.state`. Any hook registered for a `Hook::*(index)` point is spliced in at that
/// point.
///
/// * `num_token` — tokens in this step; also selects the `turbo(num_token)` flag of every
///   matmul.
/// * `head_size` — per-head width used by the head-wise reshapes below.
/// * `rescale` — after every `rescale`-th layer `x` goes through
///   `TensorOp::affine(.., 0.5, 0.0)`.
///
/// # Errors
/// Propagates any `TensorError` raised while building an op or returned by a hook.
#[allow(clippy::too_many_arguments)]
fn dispatch_layer<F: Float>(
    hooks: Arc<HookMap<F>>,
    frame: Frame<F>,
    layer: Layer,
    index: usize,
    num_token: usize,
    head_size: usize,
    rescale: usize,
) -> Result<TensorOp, TensorError> {
    let hook_op = |hook: Hook| hook_op(&hooks, &hook, &frame);
    let Frame { state, buffer, .. } = &frame;

    // Head-wise views `[head_size, Auto, num_token, 1]` (Auto presumably = number of heads)
    // consumed by the per-head ops: l2 norm, time mix, group norm and time first.
    let att_kk = buffer.att_kk.reshape(
        TensorDimension::Size(head_size),
        TensorDimension::Auto,
        TensorDimension::Size(num_token),
        TensorDimension::Size(1),
    )?;
    let att_x = buffer.att_x.reshape(
        TensorDimension::Size(head_size),
        TensorDimension::Auto,
        TensorDimension::Size(num_token),
        TensorDimension::Size(1),
    )?;
    let att_r = buffer.att_r.reshape(
        TensorDimension::Size(head_size),
        TensorDimension::Auto,
        TensorDimension::Size(num_token),
        TensorDimension::Size(1),
    )?;
    let att_w = buffer.att_w.reshape(
        TensorDimension::Size(head_size),
        TensorDimension::Auto,
        TensorDimension::Size(num_token),
        TensorDimension::Size(1),
    )?;
    // Same head-wise view, keeping the 4 slots (k, v, a, kk) that are filled before time mix.
    let att_n = buffer.att_n.reshape(
        TensorDimension::Size(head_size),
        TensorDimension::Auto,
        TensorDimension::Size(num_token),
        TensorDimension::Size(4),
    )?;

    let mut ops = vec![];

    ops.extend([
        // Attention input: copy the residual stream, then layer-norm the copy in place.
        TensorOp::blit(&buffer.x, &buffer.att_x)?,
        hook_op(Hook::PreAtt(index))?,
        TensorOp::layer_norm(
            &layer.att_ln.w,
            &layer.att_ln.b,
            &buffer.att_x,
            Model::LN_EPS,
        )?,
        hook_op(Hook::PostAttLayerNorm(index))?,
        hook_op(Hook::PreAttTokenShift(index))?,
        // Token shift: one mixed input per projection, each with its own `x_*` coefficient,
        // using this layer's attention state.
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.att.x_r,
            state.att(index)?,
            &buffer.att_x,
            &buffer.att_rx,
            true,
        )?,
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.att.x_w,
            state.att(index)?,
            &buffer.att_x,
            &buffer.att_wx,
            true,
        )?,
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.att.x_k,
            state.att(index)?,
            &buffer.att_x,
            &buffer.att_kx,
            true,
        )?,
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.att.x_v,
            state.att(index)?,
            &buffer.att_x,
            &buffer.att_vx,
            true,
        )?,
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.att.x_a,
            state.att(index)?,
            &buffer.att_x,
            &buffer.att_ax,
            true,
        )?,
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.att.x_g,
            state.att(index)?,
            &buffer.att_x,
            &buffer.att_gx,
            true,
        )?,
        hook_op(Hook::PostAttTokenShift(index))?,
        hook_op(Hook::PreAttLinear(index))?,
        // Full-rank receptance / key / value projections.
        layer.att.w_r.matmul_op(
            &buffer.att_rx,
            &buffer.att_r,
            Activation::None,
            turbo(num_token),
        )?,
        layer.att.w_k.matmul_op(
            &buffer.att_kx,
            &buffer.att_k,
            Activation::None,
            turbo(num_token),
        )?,
        layer.att.w_v.matmul_op(
            &buffer.att_vx,
            &buffer.att_v,
            Activation::None,
            turbo(num_token),
        )?,
        hook_op(Hook::PostAttLinear(index))?,
        hook_op(Hook::PreAttAdapt(index))?,
        // Decay: w = w0 + w2 · tanh(w1 · xw).
        layer.att.w1.matmul_op(
            &buffer.att_wx,
            &buffer.aux_w,
            Activation::Tanh,
            turbo(num_token),
        )?,
        layer.att.w2.matmul_op(
            &buffer.aux_w,
            &buffer.att_w,
            Activation::None,
            turbo(num_token),
        )?,
        TensorOp::add(&layer.att.w0, &buffer.att_w)?,
        // a = a2 · (a1 · xa), combined with a0 by `add_activate` with a Sigmoid activation
        // (presumably sigmoid(a0 + ..)).
        layer.att.a1.matmul_op(
            &buffer.att_ax,
            &buffer.aux_a,
            Activation::None,
            turbo(num_token),
        )?,
        layer.att.a2.matmul_op(
            &buffer.aux_a,
            &buffer.att_a,
            Activation::None,
            turbo(num_token),
        )?,
        TensorOp::add_activate(
            &layer.att.a0,
            &buffer.att_a,
            Activation::None,
            Activation::None,
            Activation::Sigmoid,
        )?,
        // Gate: g = g2 · sigmoid(g1 · xg).
        layer.att.g1.matmul_op(
            &buffer.att_gx,
            &buffer.aux_g,
            Activation::Sigmoid,
            turbo(num_token),
        )?,
        layer.att.g2.matmul_op(
            &buffer.aux_g,
            &buffer.att_g,
            Activation::None,
            turbo(num_token),
        )?,
        hook_op(Hook::PostAttAdapt(index))?,
        hook_op(Hook::PreAttControl(index))?,
        // kk = per-head l2-normalized (k * k_k); then k is adjusted from k_a and a.
        TensorOp::blit(&buffer.att_k, &buffer.att_kk)?,
        TensorOp::mul(&layer.att.k_k, &buffer.att_kk)?,
        TensorOp::l2_norm(&att_kk, Model::L2_EPS)?,
        TensorOp::control_k_v7(&layer.att.k_a, &buffer.att_a, &buffer.att_k)?,
        hook_op(Hook::PostAttControl(index))?,
    ]);

    // Value residual:
    // - layer 0 stores its v in `att_v0`;
    // - later layers compute vv = v2 · (v1 · xv), combine it with v0 via Sigmoid
    //   `add_activate`, and lerp between att_v0 and att_v with it (lerp argument order:
    //   presumably v moves toward v0).
    ops.push(hook_op(Hook::PreAttValueResidual(index))?);
    match index {
        0 => ops.push(TensorOp::blit(&buffer.att_v, &buffer.att_v0)?),
        _ => ops.extend([
            layer.att.v1.matmul_op(
                &buffer.att_vx,
                &buffer.aux_v,
                Activation::None,
                turbo(num_token),
            )?,
            layer.att.v2.matmul_op(
                &buffer.aux_v,
                &buffer.att_vv,
                Activation::None,
                turbo(num_token),
            )?,
            TensorOp::add_activate(
                &layer.att.v0,
                &buffer.att_vv,
                Activation::None,
                Activation::None,
                Activation::Sigmoid,
            )?,
            TensorOp::lerp(&buffer.att_v0, &buffer.att_v, &buffer.att_vv, true)?,
        ]),
    };
    ops.push(hook_op(Hook::PostAttValueResidual(index))?);

    ops.extend([
        hook_op(Hook::PreAttTimeMix(index))?,
        // Pack k, v, a, kk into the four slots of `att_n` for the time-mix kernel.
        TensorOp::blit(&buffer.att_k, buffer.att_n.view(.., .., 0, ..)?)?,
        TensorOp::blit(&buffer.att_v, buffer.att_n.view(.., .., 1, ..)?)?,
        TensorOp::blit(&buffer.att_a, buffer.att_n.view(.., .., 2, ..)?)?,
        TensorOp::blit(&buffer.att_kk, buffer.att_n.view(.., .., 3, ..)?)?,
        // Time mix against this layer's attention state; the output lands in `att_x`.
        TensorOp::time_mix_v7(
            &buffer.cursors,
            state.att(index)?,
            &att_r,
            &att_w,
            &att_n,
            &att_x,
        )?,
        TensorOp::group_norm(&layer.att.gn.w, &layer.att.gn.b, &att_x, Model::GN_EPS)?,
        TensorOp::time_first_v7(&layer.att.r_k, &att_r, &att_n, &att_x)?,
        hook_op(Hook::PostAttTimeMix(index))?,
        hook_op(Hook::PreAttGate(index))?,
        TensorOp::mul(&buffer.att_g, &buffer.att_x)?,
        hook_op(Hook::PostAttGate(index))?,
        hook_op(Hook::PreAttOut(index))?,
        layer.att.w_o.matmul_op(
            &buffer.att_x,
            &buffer.att_o,
            Activation::None,
            turbo(num_token),
        )?,
        hook_op(Hook::PostAttOut(index))?,
        // Attention residual.
        TensorOp::add(&buffer.att_o, &buffer.x)?,
        hook_op(Hook::PostAtt(index))?,
    ]);

    // FFN (channel mix):
    // layer norm -> token shift -> squared-ReLU key projection -> sparse value projection
    // -> `channel_mix_v7` with this layer's FFN state -> residual add.
    ops.extend([
        TensorOp::blit(&buffer.x, &buffer.ffn_x)?,
        hook_op(Hook::PreFfn(index))?,
        TensorOp::layer_norm(
            &layer.ffn_ln.w,
            &layer.ffn_ln.b,
            &buffer.ffn_x,
            Model::LN_EPS,
        )?,
        hook_op(Hook::PostFfnLayerNorm(index))?,
        hook_op(Hook::PreFfnTokenShift(index))?,
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.ffn.x_k,
            state.ffn(index)?,
            &buffer.ffn_x,
            &buffer.ffn_kx,
            true,
        )?,
        hook_op(Hook::PostFfnTokenShift(index))?,
        hook_op(Hook::PreFfnLinear(index))?,
        layer.ffn.w_k.matmul_op(
            &buffer.ffn_kx,
            &buffer.ffn_k,
            Activation::SquaredRelu,
            turbo(num_token),
        )?,
        hook_op(Hook::PostFfnActivate(index))?,
        layer.ffn.w_v.matmul_op_sparse(
            &buffer.ffn_k,
            &buffer.ffn_v,
            Activation::None,
            turbo(num_token),
        )?,
        hook_op(Hook::PostFfnLinear(index))?,
        hook_op(Hook::PreFfnChannelMix(index))?,
        TensorOp::channel_mix_v7(
            &buffer.cursors,
            state.ffn(index)?,
            &buffer.ffn_v,
            &buffer.ffn_x,
        )?,
        hook_op(Hook::PostFfnChannelMix(index))?,
        TensorOp::add(&buffer.ffn_x, &buffer.x)?,
        hook_op(Hook::PostFfn(index))?,
    ]);

    // Periodic rescale of the residual stream; `is_multiple_of` never fires for rescale == 0.
    if (index + 1).is_multiple_of(rescale) {
        ops.push(TensorOp::affine(&buffer.x, 0.5, 0.0)?);
    }

    Ok(TensorOp::List(ops))
}
1008
1009fn dispatch_header<F: Float>(
1010    hooks: Arc<HookMap<F>>,
1011    frame: Frame<F>,
1012    head: Head,
1013    head_x: TensorGpu<F, ReadWrite>,
1014    num_header: usize,
1015    head_op: TensorOp,
1016) -> Result<TensorOp, TensorError> {
1017    let hook_op = |hook: Hook| hook_op(&hooks, &hook, &frame);
1018    let header = &frame.header;
1019    let mut ops = vec![head_op];
1020
1021    if num_header > 0 {
1022        ops.extend([
1023            hook_op(Hook::PreHead)?,
1024            TensorOp::layer_norm(&head.ln.w, &head.ln.b, &head_x, Model::LN_EPS)?,
1025            hook_op(Hook::PostHeadLayerNorm)?,
1026            head.w.matmul_op(
1027                head_x.view(.., .., .., ..)?,
1028                header.head_o.view(.., .., .., ..)?,
1029                Activation::None,
1030                turbo(num_header),
1031            )?,
1032            hook_op(Hook::PostHead)?,
1033        ]);
1034    }
1035    Ok(TensorOp::List(ops))
1036}
1037
1038impl<R: Reader> ModelBuilder<R> {
1039    pub async fn build_v7(self) -> Result<Model, LoaderError> {
1040        let ModelBuilder {
1041            context,
1042            model,
1043            rescale,
1044            sep,
1045            lora,
1046            quant,
1047            ..
1048        } = self;
1049
1050        let rescale = rescale.unwrap_or(Model::DEFAULT_RESCALE);
1051        let sep = sep.unwrap_or(Model::DEFAULT_SEP);
1052
1053        let info = Loader::info(&model)?;
1054        let loader = Loader {
1055            context: context.clone(),
1056            model,
1057            lora,
1058        };
1059
1060        let embed = Embed {
1061            ln: LayerNorm {
1062                w: loader.load_vector_f16("blocks.0.ln0.weight")?,
1063                b: loader.load_vector_f16("blocks.0.ln0.bias")?,
1064            },
1065            w: loader.load_matrix_f16_padded_cpu("emb.weight")?,
1066        };
1067
1068        let head = Head {
1069            ln: LayerNorm {
1070                w: loader.load_vector_f16("ln_out.weight")?,
1071                b: loader.load_vector_f16("ln_out.bias")?,
1072            },
1073            w: Matrix::Fp16(loader.load_matrix_f16_padded("head.weight")?),
1074        };
1075
1076        let submission_index = Some(context.queue.submit(None));
1077        _ = context.device.poll(wgpu::PollType::Wait {
1078            submission_index,
1079            timeout: None,
1080        });
1081
1082        let load_matrix = |name: String, quant: Quant| loader.load_matrix(name, quant);
1083        let load_matrix_discount = |name: String, quant: Quant, discount: f32| {
1084            loader.load_matrix_discount(name, quant, discount)
1085        };
1086
1087        let mut layers = vec![];
1088        for layer in 0..info.num_layer {
1089            let quant = quant.get(&layer).copied().unwrap_or_default();
1090            let discount = 2.0_f32.powi(-((layer / rescale) as i32));
1091
1092            let att_ln = LayerNorm {
1093                w: loader.load_vector_f16(format!("blocks.{layer}.ln1.weight"))?,
1094                b: loader.load_vector_f16(format!("blocks.{layer}.ln1.bias"))?,
1095            };
1096
1097            let att = format!("blocks.{layer}.att");
1098            let x_r = loader.load_vector_f16(format!("{att}.x_r"))?;
1099            let x_w = loader.load_vector_f16(format!("{att}.x_w"))?;
1100            let x_k = loader.load_vector_f16(format!("{att}.x_k"))?;
1101            let x_v = loader.load_vector_f16(format!("{att}.x_v"))?;
1102            let x_a = loader.load_vector_f16(format!("{att}.x_a"))?;
1103            let x_g = loader.load_vector_f16(format!("{att}.x_g"))?;
1104
1105            let w0 = loader.load_vector_f16(format!("{att}.w0"))?;
1106            let a0 = loader.load_vector_f16(format!("{att}.a0"))?;
1107
1108            let w1 = Matrix::Fp16(loader.load_matrix_f16(format!("{att}.w1"))?);
1109            let w2 = Matrix::Fp16(loader.load_matrix_f16(format!("{att}.w2"))?);
1110            let a1 = Matrix::Fp16(loader.load_matrix_f16(format!("{att}.a1"))?);
1111            let a2 = Matrix::Fp16(loader.load_matrix_f16(format!("{att}.a2"))?);
1112            let g1 = Matrix::Fp16(loader.load_matrix_f16(format!("{att}.g1"))?);
1113            let g2 = Matrix::Fp16(loader.load_matrix_f16(format!("{att}.g2"))?);
1114
1115            let (v0, v1, v2) = match layer {
1116                0 => (a0.clone(), a1.clone(), a2.clone()), // placeholder, actually not used
1117                _ => (
1118                    loader.load_vector_f16(format!("{att}.v0"))?,
1119                    Matrix::Fp16(loader.load_matrix_f16(format!("{att}.v1"))?),
1120                    Matrix::Fp16(loader.load_matrix_f16(format!("{att}.v2"))?),
1121                ),
1122            };
1123
1124            let r_k = loader.load_matrix_f16(format!("{att}.r_k"))?;
1125            let k_k = loader.load_vector_f16(format!("{att}.k_k"))?;
1126            let k_a = loader.load_vector_f16(format!("{att}.k_a"))?;
1127
1128            let gn = LayerNorm {
1129                w: loader
1130                    .load_vector_f16(format!("{att}.ln_x.weight"))?
1131                    .reshape(
1132                        TensorDimension::Auto,
1133                        TensorDimension::Size(info.num_head),
1134                        TensorDimension::Size(1),
1135                        TensorDimension::Size(1),
1136                    )?,
1137                b: loader
1138                    .load_vector_f16(format!("{att}.ln_x.bias"))?
1139                    .reshape(
1140                        TensorDimension::Auto,
1141                        TensorDimension::Size(info.num_head),
1142                        TensorDimension::Size(1),
1143                        TensorDimension::Size(1),
1144                    )?,
1145            };
1146
1147            let att = Att {
1148                x_r,
1149                x_w,
1150                x_k,
1151                x_v,
1152                x_a,
1153                x_g,
1154                w0,
1155                a0,
1156                v0,
1157                w1,
1158                w2,
1159                a1,
1160                a2,
1161                g1,
1162                g2,
1163                v1,
1164                v2,
1165                r_k,
1166                k_k,
1167                k_a,
1168                w_k: load_matrix(format!("{att}.key.weight"), quant)?,
1169                w_v: load_matrix(format!("{att}.value.weight"), quant)?,
1170                w_r: load_matrix(format!("{att}.receptance.weight"), quant)?,
1171                w_o: load_matrix_discount(format!("{att}.output.weight"), quant, discount)?,
1172                gn,
1173            };
1174
1175            let ffn_ln = LayerNorm {
1176                w: loader.load_vector_f16(format!("blocks.{layer}.ln2.weight"))?,
1177                b: loader.load_vector_f16(format!("blocks.{layer}.ln2.bias"))?,
1178            };
1179
1180            let ffn = format!("blocks.{layer}.ffn");
1181            let x_k = loader.load_vector_f16(format!("{ffn}.x_k"))?;
1182
1183            let ffn = Ffn {
1184                x_k,
1185                w_k: load_matrix(format!("{ffn}.key.weight"), quant)?,
1186                w_v: load_matrix_discount(format!("{ffn}.value.weight"), quant, discount)?,
1187            };
1188
1189            let submission_index = Some(context.queue.submit(None));
1190            _ = context.device.poll(wgpu::PollType::Wait {
1191                submission_index,
1192                timeout: None,
1193            });
1194
1195            layers.push(Layer {
1196                att_ln,
1197                ffn_ln,
1198                att,
1199                ffn,
1200            })
1201        }
1202
1203        let submission_index = Some(context.queue.submit(None));
1204        _ = context.device.poll(wgpu::PollType::Wait {
1205            submission_index,
1206            timeout: None,
1207        });
1208
1209        let tensor = ModelTensor {
1210            embed,
1211            head,
1212            layers,
1213        };
1214        let model = {
1215            let context = context.clone();
1216            let info = info.clone();
1217            Model {
1218                context,
1219                info,
1220                rescale,
1221                sep,
1222                tensor,
1223            }
1224        };
1225        Ok(model)
1226    }
1227}
1228
1229/// Read the pre-trained state from the file.
1230pub async fn read_state<R: Reader>(
1231    context: &Context,
1232    info: &ModelInfo,
1233    model: R,
1234) -> Result<TensorCpu<f32>, LoaderError> {
1235    let loader = Loader {
1236        context: context.clone(),
1237        model,
1238        lora: vec![],
1239    };
1240
1241    let head_size = info.num_emb / info.num_head;
1242    let data: TensorGpu<f32, _> = context.zeros([info.num_emb, head_size + 2, info.num_layer, 1]);
1243
1244    let mut ops = vec![];
1245    for layer in 0..info.num_layer {
1246        let matrix = loader.load_matrix_f16(format!("blocks.{layer}.att.time_state"))?;
1247        let state: TensorGpu<_, _> = context.tensor_init([head_size, info.num_head, head_size, 1]);
1248        let reshaped: TensorGpu<f16, _> = state.reshape(
1249            TensorDimension::Size(info.num_emb),
1250            TensorDimension::Size(head_size),
1251            TensorDimension::Size(1),
1252            TensorDimension::Auto,
1253        )?;
1254        ops.extend([
1255            TensorOp::transpose(&matrix, &state)?,
1256            TensorOp::blit(&reshaped, data.view(.., 1..head_size + 1, layer, ..)?)?,
1257        ]);
1258    }
1259    context.queue.submit(context.encode(&TensorOp::List(ops)));
1260
1261    Ok(data.back().await)
1262}