web_rwkv/runtime/
v5.rs

1use std::{collections::HashMap, marker::PhantomData, sync::Arc};
2
3#[cfg(not(target_arch = "wasm32"))]
4use futures::future::BoxFuture;
5#[cfg(target_arch = "wasm32")]
6use futures::future::LocalBoxFuture;
7use half::f16;
8use itertools::Itertools;
9use serde::{Deserialize, Serialize};
10use web_rwkv_derive::DeserializeSeed;
11use wgpu::CommandBuffer;
12
13use super::{
14    infer::{RnnChunk, RnnInfo, RnnInput, RnnOutput, RnnOutputBatch, RnnRedirect, Token},
15    loader::{Loader, LoaderError, Reader},
16    model::{AsAny, ModelBuilder, ModelInfo, Quant, State as _},
17    Dispatcher, Job, RuntimeError,
18};
19use crate::{
20    context::Context,
21    num::Float,
22    tensor::{
23        cache::ResourceCache,
24        kind::ReadWrite,
25        matrix::Matrix,
26        ops::{Activation, TensorCommand, TensorOp},
27        serialization::Seed,
28        shape::{Shape, TensorDimension},
29        DeepClone, IntoPackedCursors, TensorCpu, TensorError, TensorGpu, TensorGpuView, TensorInit,
30        TensorReshape, TensorShape, TensorStack,
31    },
32};
33
/// A loaded model together with the settings that control how it is dispatched.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Model {
    /// Context used to allocate tensors and to encode and submit GPU work.
    pub context: Context,
    /// Model dimensions (this file reads `num_emb`, `num_head`, `num_layer`, `num_hidden`).
    pub info: ModelInfo,
    /// Layer interval at which `dispatch_layer` scales the hidden state `x` by 0.5.
    pub rescale: usize,
    /// Layer interval at which `dispatch` inserts a `TensorOp::Sep` between layer ops.
    pub sep: usize,
    /// The model weights.
    pub tensor: ModelTensor,
}
43
impl Model {
    /// Epsilon passed to every `TensorOp::layer_norm` built in this file.
    pub const LN_EPS: f32 = 1.0e-5;
    /// Epsilon passed to the attention `TensorOp::group_norm`.
    pub const GN_EPS: f32 = 64.0e-5;

    /// Value used for `rescale` when the builder does not specify one.
    pub const DEFAULT_RESCALE: usize = 6;
    /// Value used for `sep` when the builder does not specify one.
    pub const DEFAULT_SEP: usize = 1024;
}
51
/// All weights of the model, grouped by the stage that consumes them.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct ModelTensor {
    /// Embedding table and the layer norm applied to the loaded input.
    pub embed: Embed,
    /// Final layer norm and output projection.
    pub head: Head,
    /// Per-layer weights, dispatched in order by index.
    pub layers: Vec<Layer>,
}
59
/// Weight/bias pair consumed by `TensorOp::layer_norm` and `TensorOp::group_norm`.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct LayerNorm {
    /// Normalization weight (loaded from a `*.weight` vector).
    pub w: TensorGpu<f16, ReadWrite>,
    /// Normalization bias (loaded from a `*.bias` vector).
    pub b: TensorGpu<f16, ReadWrite>,
}
66
/// Weights of one layer's attention (time-mix) part.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Att {
    /// Decay vector; reshaped per head and passed to `TensorOp::time_mix_v5`.
    pub time_decay: TensorGpu<f32, ReadWrite>,
    /// "Time first" vector; reshaped per head and passed to `TensorOp::time_mix_v5`.
    pub time_first: TensorGpu<f32, ReadWrite>,

    /// Token-shift mix factors producing the k/v/r/g inputs respectively.
    pub time_mix_k: TensorGpu<f16, ReadWrite>,
    pub time_mix_v: TensorGpu<f16, ReadWrite>,
    pub time_mix_r: TensorGpu<f16, ReadWrite>,
    pub time_mix_g: TensorGpu<f16, ReadWrite>,

    /// Projection matrices for key, value, receptance, gate and output.
    pub w_k: Matrix,
    pub w_v: Matrix,
    pub w_r: Matrix,
    pub w_g: Matrix,
    pub w_o: Matrix,

    /// Weight/bias for the group norm applied to the time-mix result.
    pub group_norm: LayerNorm,
}
86
/// Weights of one layer's feed-forward (channel-mix) part.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Ffn {
    /// Token-shift mix factors producing the k/r inputs respectively.
    pub time_mix_k: TensorGpu<f16, ReadWrite>,
    pub time_mix_r: TensorGpu<f16, ReadWrite>,

    /// Projection matrices for key, value and receptance.
    pub w_k: Matrix,
    pub w_v: Matrix,
    pub w_r: Matrix,
}
97
/// Weights of a single layer: one attention part and one feed-forward part,
/// each preceded by its own layer norm.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Layer {
    /// Layer norm applied to the attention input.
    pub att_layer_norm: LayerNorm,
    /// Layer norm applied to the feed-forward input.
    pub ffn_layer_norm: LayerNorm,
    /// Attention weights.
    pub att: Att,
    /// Feed-forward weights.
    pub ffn: Ffn,
}
106
/// Embedding stage weights.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Embed {
    /// Layer norm applied to the input buffer right after it is loaded.
    pub layer_norm: LayerNorm,
    /// Embedding table kept on the CPU; `RnnJob::load` copies rows out of it by token id.
    pub w: TensorCpu<f16>,
}
113
/// Output stage weights.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Head {
    /// Layer norm applied to the head input before the projection.
    pub layer_norm: LayerNorm,
    /// Output projection matrix.
    pub w: Matrix,
}
120
/// Recurrent state of the model for all batches.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct State {
    /// Context used to allocate tensors and submit copy/blit work.
    pub context: Context,
    /// Model dimensions used to derive the state layout.
    pub info: ModelInfo,
    /// One tensor per layer, created as `[num_emb, head_size + 2, num_batch, 1]`.
    /// Rows `0..=head_size` are exposed by `att`, row `head_size + 1` by `ffn`.
    pub data: Vec<TensorGpu<f32, ReadWrite>>,
}
128
129impl State {
130    async fn back(&self, batch: usize) -> Result<TensorCpu<f32>, TensorError> {
131        let context = &self.context;
132        let mut tensors = Vec::with_capacity(self.info.num_layer);
133        let mut encoder = context.device.create_command_encoder(&Default::default());
134        for data in self.data.iter() {
135            let shape = data.shape();
136            let destination = context.tensor_init([shape[0], shape[1], 1, 1]);
137            encoder.copy_tensor_batch(data, &destination, batch, 0)?;
138            tensors.push(destination);
139        }
140        context.queue.submit(Some(encoder.finish()));
141
142        let mut backed = Vec::with_capacity(tensors.len());
143        for tensor in tensors {
144            backed.push(tensor.back().await);
145        }
146        TensorCpu::stack(backed, 2)
147    }
148}
149
impl AsAny for State {
    /// Returns `self` as `&dyn Any` so callers can downcast to the concrete `State`.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
155
156impl super::model::State for State {
157    #[inline]
158    fn num_batch(&self) -> usize {
159        self.data[0].shape()[2]
160    }
161
162    #[inline]
163    fn init_shape(&self) -> Shape {
164        let info = &self.info;
165        let head_size = info.num_emb / info.num_head;
166        [info.num_emb, head_size + 2, info.num_layer, 1].into()
167    }
168
169    fn init(&self) -> TensorCpu<f32> {
170        let shape = self.init_shape();
171        let data = vec![0.0; shape.len()];
172        TensorCpu::from_data(shape, data).unwrap()
173    }
174
175    fn att(&self, layer: usize) -> Result<TensorGpuView<'_, f32>, TensorError> {
176        let head_size = self.info.num_emb / self.info.num_head;
177        let end = head_size + 1;
178        self.data[layer].view(.., 0..end, .., ..)
179    }
180
181    fn ffn(&self, layer: usize) -> Result<TensorGpuView<'_, f32>, TensorError> {
182        let head_size = self.info.num_emb / self.info.num_head;
183        let start = head_size + 1;
184        self.data[layer].view(.., start, .., ..)
185    }
186
187    fn load(&self, tensor: TensorCpu<f32>, batch: usize) -> Result<(), TensorError> {
188        let head_size = self.info.num_emb / self.info.num_head;
189        tensor.check_shape([self.info.num_emb, head_size + 2, self.info.num_layer, 1])?;
190        for (data, source) in self.data.iter().zip(tensor.split(2)?.into_iter()) {
191            data.load_batch(&source, batch)?;
192        }
193        Ok(())
194    }
195
196    #[cfg(not(target_arch = "wasm32"))]
197    fn back(&self, batch: usize) -> BoxFuture<'_, Result<TensorCpu<f32>, TensorError>> {
198        Box::pin(self.back(batch))
199    }
200
201    #[cfg(target_arch = "wasm32")]
202    fn back(&self, batch: usize) -> LocalBoxFuture<'_, Result<TensorCpu<f32>, TensorError>> {
203        Box::pin(self.back(batch))
204    }
205
206    fn write(&self, tensor: TensorGpu<f32, ReadWrite>, batch: usize) -> Result<(), TensorError> {
207        let head_size = self.info.num_emb / self.info.num_head;
208        tensor.check_shape([self.info.num_emb, head_size + 2, self.info.num_layer, 1])?;
209
210        let context = &self.context;
211        let mut ops = Vec::with_capacity(self.data.len());
212        for (layer, data) in self.data.iter().enumerate() {
213            ops.push(TensorOp::blit(
214                tensor.view(.., .., layer, ..)?,
215                data.view(.., .., batch, ..)?,
216            )?);
217        }
218        context.queue.submit(context.encode(&TensorOp::List(ops)));
219
220        Ok(())
221    }
222
223    fn read(&self, batch: usize) -> Result<TensorGpu<f32, ReadWrite>, TensorError> {
224        let context = &self.context;
225        let head_size = self.info.num_emb / self.info.num_head;
226        let shape = [self.info.num_emb, head_size + 2, self.info.num_layer, 1];
227        let tensor: TensorGpu<_, _> = context.tensor_init(shape);
228
229        let mut ops = Vec::with_capacity(self.data.len());
230        for (layer, data) in self.data.iter().enumerate() {
231            ops.push(TensorOp::blit(
232                data.view(.., .., batch, ..)?,
233                tensor.view(.., .., layer, ..)?,
234            )?);
235        }
236        context.queue.submit(context.encode(&TensorOp::List(ops)));
237
238        Ok(tensor)
239    }
240
241    fn embed(&self, layer: usize, backed: TensorCpu<f32>) -> Result<TensorCpu<f32>, TensorError> {
242        backed.slice(.., 0, layer, ..)
243    }
244}
245
246impl DeepClone for State {
247    fn deep_clone(&self) -> Self {
248        let data = self.data.iter().map(|tensor| tensor.deep_clone()).collect();
249        Self {
250            data,
251            ..self.clone()
252        }
253    }
254}
255
/// Intermediate GPU buffers used while dispatching the layers for a fixed number of tokens.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Runtime<F: Float> {
    /// Cursor data uploaded by `RnnJob::load`; passed to token-shift, time-mix and channel-mix ops.
    pub cursors: TensorGpu<u32, ReadWrite>,
    /// Embedded input tokens uploaded by `RnnJob::load`.
    pub input: TensorGpu<f16, ReadWrite>,

    /// Hidden state carried from layer to layer.
    pub x: TensorGpu<F, ReadWrite>,
    /// `f32` scratch buffer used around the time-mix and group-norm ops.
    pub aux_x: TensorGpu<f32, ReadWrite>,

    /// Attention buffers: normalized input, token-shifted inputs (`*x`),
    /// projections (`k`/`v`/`r`/`g`) and output (`o`).
    pub att_x: TensorGpu<F, ReadWrite>,
    pub att_kx: TensorGpu<F, ReadWrite>,
    pub att_vx: TensorGpu<F, ReadWrite>,
    pub att_rx: TensorGpu<F, ReadWrite>,
    pub att_gx: TensorGpu<F, ReadWrite>,
    pub att_k: TensorGpu<f32, ReadWrite>,
    pub att_v: TensorGpu<f32, ReadWrite>,
    pub att_r: TensorGpu<f32, ReadWrite>,
    pub att_g: TensorGpu<F, ReadWrite>,
    pub att_o: TensorGpu<F, ReadWrite>,

    /// Feed-forward buffers: normalized input, token-shifted inputs (`*x`)
    /// and projections (`k`/`v`/`r`). `ffn_k` is the only one sized by `num_hidden`.
    pub ffn_x: TensorGpu<F, ReadWrite>,
    pub ffn_kx: TensorGpu<F, ReadWrite>,
    pub ffn_rx: TensorGpu<F, ReadWrite>,
    pub ffn_k: TensorGpu<F, ReadWrite>,
    pub ffn_v: TensorGpu<F, ReadWrite>,
    pub ffn_r: TensorGpu<F, ReadWrite>,
}
283
284impl<F: Float> Runtime<F> {
285    pub fn new(context: &Context, info: &ModelInfo, num_token: usize) -> Self {
286        let shape = Shape::new(info.num_emb, num_token, 1, 1);
287        let cursors_shape = Shape::new(num_token, 1, 1, 1);
288        let hidden_shape = Shape::new(info.num_hidden, num_token, 1, 1);
289
290        Self {
291            cursors: context.tensor_init(cursors_shape),
292            input: context.tensor_init(shape),
293            x: context.tensor_init(shape),
294            aux_x: context.tensor_init(shape),
295            att_x: context.tensor_init(shape),
296            att_kx: context.tensor_init(shape),
297            att_vx: context.tensor_init(shape),
298            att_rx: context.tensor_init(shape),
299            att_gx: context.tensor_init(shape),
300            att_k: context.tensor_init(shape),
301            att_v: context.tensor_init(shape),
302            att_r: context.tensor_init(shape),
303            att_g: context.tensor_init(shape),
304            att_o: context.tensor_init(shape),
305            ffn_x: context.tensor_init(shape),
306            ffn_kx: context.tensor_init(shape),
307            ffn_rx: context.tensor_init(shape),
308            ffn_k: context.tensor_init(hidden_shape),
309            ffn_v: context.tensor_init(shape),
310            ffn_r: context.tensor_init(shape),
311        }
312    }
313}
314
/// GPU buffers used by the output (head) stage for a fixed number of headers.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Header<F: Float> {
    /// Head input, allocated as `[num_emb, num_header, 1, 1]`.
    pub head_x: TensorGpu<F, ReadWrite>,
    /// Head output, allocated as `[num_vocab_padded, num_header, 1, 1]`; read back by `RnnJob::back`.
    pub head_o: TensorGpu<f32, ReadWrite>,
}
321
322impl<F: Float> Header<F> {
323    pub fn new(context: &Context, info: &ModelInfo, num_header: usize) -> Self {
324        let head_shape = Shape::new(info.num_emb, num_header, 1, 1);
325        let output_shape = Shape::new(info.num_vocab_padded(), num_header, 1, 1);
326
327        Self {
328            head_x: context.tensor_init(head_shape),
329            head_o: context.tensor_init(output_shape),
330        }
331    }
332}
333
/// Points in the dispatch pipeline at which a user-registered op can be inserted.
///
/// Variants carrying a `usize` are emitted once per layer with that layer's index;
/// the others are emitted once by the embed or head stage.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Hook {
    // Embed stage.
    PostEmbedLoaded,
    PostEmbedLayerNorm,
    // Attention part of a layer.
    PreAtt(usize),
    PostAttLayerNorm(usize),
    PreAttTokenShift(usize),
    PostAttTokenShift(usize),
    PreAttLinear(usize),
    PostAttLinear(usize),
    PreAttTimeMix(usize),
    PostAttTimeMix(usize),
    PreAttGate(usize),
    PostAttGate(usize),
    PreAttOut(usize),
    PostAttOut(usize),
    PostAtt(usize),
    // Feed-forward part of a layer.
    PreFfn(usize),
    PostFfnLayerNorm(usize),
    PreFfnTokenShift(usize),
    PostFfnTokenShift(usize),
    PreFfnLinear(usize),
    PostFfnLinear(usize),
    PostFfnActivate(usize),
    PreFfnChannelMix(usize),
    PostFfnChannelMix(usize),
    PostFfn(usize),
    // Head stage.
    PreHead,
    PostHeadLayerNorm,
    PostHead,
}
365
/// A prepared inference job: pre-encoded GPU commands plus the buffers needed
/// to upload the input and read back the output.
pub struct RnnJob {
    /// Encoded command buffers; taken and submitted by `submit`.
    commands: Vec<CommandBuffer>,
    /// Redirect info; its `outputs` ranges are used to slice the read-back output.
    redirect: RnnRedirect,

    /// CPU embedding table used to turn token ids into embedding rows.
    embed: TensorCpu<f16>,

    /// GPU destination for the cursor data built in `load`.
    cursors: TensorGpu<u32, ReadWrite>,
    /// GPU destination for the stacked input embeddings built in `load`.
    input: TensorGpu<f16, ReadWrite>,
    /// GPU tensor read back in `back`.
    output: TensorGpu<f32, ReadWrite>,
}
376
377impl Job for RnnJob {
378    type Input = RnnInput;
379    type Output = RnnOutput;
380
381    fn load(&self, input: &RnnChunk) -> Result<(), RuntimeError> {
382        if input.num_token() == 0 {
383            return Ok(());
384        }
385
386        let stack: Vec<TensorCpu<f16>> = input
387            .iter()
388            .map(|chunk| {
389                let num_emb = self.embed.shape()[0];
390                let data = self.embed.data();
391                let data = chunk
392                    .iter()
393                    .map(|token| match token {
394                        &Token::Token(token) => {
395                            let start = num_emb * token as usize;
396                            let end = start + num_emb;
397                            let data = data[start..end].to_vec();
398                            TensorCpu::from_data_1d(data)
399                        }
400                        Token::Embed(tensor) => tensor.clone(),
401                    })
402                    .collect_vec();
403                match TensorCpu::stack(data, 1) {
404                    Ok(tensor) => tensor,
405                    Err(_) => TensorCpu::init([num_emb, 0, 1, 1]),
406                }
407            })
408            .collect();
409        let stack = TensorStack::try_from(stack)?;
410
411        let cursors = stack.cursors.clone().into_cursors();
412        let cursors = TensorCpu::from_data(self.cursors.shape(), cursors)?;
413        self.cursors.load(&cursors)?;
414        self.input.load(&stack.tensor)?;
415
416        Ok(())
417    }
418
419    fn submit(&mut self) {
420        let commands = std::mem::take(&mut self.commands);
421        self.output.context.queue.submit(commands);
422    }
423
424    async fn back(self) -> Result<Self::Output, RuntimeError> {
425        let output = self.output.back().await;
426        let batches: Vec<_> = self
427            .redirect
428            .outputs
429            .into_iter()
430            .map(|(start, end)| output.slice(.., start..end, .., ..))
431            .try_collect()?;
432        let batches = batches.into_iter().map(RnnOutputBatch).collect();
433        Ok(RnnOutput(batches))
434    }
435}
436
/// Snapshot of the state and buffers of one dispatch, handed to hook functions.
#[derive(Debug, Clone)]
pub struct Frame<F: Float> {
    /// The bundle's recurrent state.
    pub state: State,
    /// Intermediate buffers checked out for this dispatch's token count.
    pub buffer: Arc<Runtime<F>>,
    /// Head buffers checked out for this dispatch's header count.
    pub header: Arc<Header<F>>,
}
443
/// A hook callback: receives the current [`Frame`] and returns the op to insert.
pub type HookFn<F> = Box<dyn Fn(Frame<F>) -> Result<TensorOp, TensorError> + Send + Sync>;
/// Hook callbacks keyed by the [`Hook`] point they are attached to.
pub type HookMap<F> = HashMap<Hook, HookFn<F>>;
446
/// A model bundled with its state, hooks and buffer caches; implements the dispatcher.
#[derive(Clone)]
pub struct Bundle<F: Float> {
    /// The model being run.
    model: Model,
    /// Recurrent state, allocated in `new` for `num_batch` batches.
    state: State,
    /// User-registered hooks; empty unless built with `new_with_hooks`.
    hooks: Arc<HookMap<F>>,
    /// Cache of intermediate buffers keyed by token count.
    buffers: ResourceCache<usize, Runtime<F>>,
    /// Cache of head buffers keyed by header count.
    headers: ResourceCache<usize, Header<F>>,
    phantom: PhantomData<F>,
}
456
457impl<F: Float> Bundle<F> {
458    pub fn new(model: Model, num_batch: usize) -> Self {
459        let context = model.context.clone();
460        let info = model.info.clone();
461        let state = {
462            let head_size = info.num_emb / info.num_head;
463            let shape = Shape::new(info.num_emb, head_size + 2, num_batch, 1);
464            let data = (0..info.num_layer).map(|_| context.zeros(shape)).collect();
465            State {
466                context,
467                info,
468                data,
469            }
470        };
471        Self {
472            model,
473            state,
474            hooks: Default::default(),
475            buffers: ResourceCache::new(4),
476            headers: ResourceCache::new(4),
477            phantom: PhantomData,
478        }
479    }
480
481    pub fn new_with_hooks(model: Model, num_batch: usize, hooks: HookMap<F>) -> Self {
482        Self {
483            hooks: Arc::new(hooks),
484            ..Self::new(model, num_batch)
485        }
486    }
487
488    fn checkout_buffer(
489        &self,
490        context: &Context,
491        info: &ModelInfo,
492        num_token: usize,
493    ) -> Arc<Runtime<F>> {
494        self.buffers
495            .checkout(num_token, || Runtime::new(context, info, num_token))
496    }
497
498    fn checkout_header(
499        &self,
500        context: &Context,
501        info: &ModelInfo,
502        num_header: usize,
503    ) -> Arc<Header<F>> {
504        self.headers
505            .checkout(num_header, || Header::new(context, info, num_header))
506    }
507}
508
impl<F: Float> super::model::Bundle for Bundle<F> {
    /// Returns a clone of the model's info.
    #[inline]
    fn info(&self) -> ModelInfo {
        self.model.info.clone()
    }

    /// Returns a clone of the bundle's state.
    #[inline]
    fn state(&self) -> impl super::model::State + AsAny + 'static {
        self.state.clone()
    }

    /// Returns a clone of the model as a serializable value.
    fn model(&self) -> impl Serialize + 'static {
        self.model.clone()
    }
}
524
/// Returns `true` when `num_token` is a multiple of `MIN_TOKEN_CHUNK_SIZE`.
/// The result is passed as the last argument of the `matmul_op` calls in this file.
fn turbo(num_token: usize) -> bool {
    num_token.is_multiple_of(super::infer::rnn::MIN_TOKEN_CHUNK_SIZE)
}
528
529fn hook_op<F: Float>(
530    hooks: &HookMap<F>,
531    hook: &Hook,
532    frame: &Frame<F>,
533) -> Result<TensorOp, TensorError> {
534    match hooks.get(hook) {
535        Some(f) => f(frame.clone()),
536        None => Ok(TensorOp::empty()),
537    }
538}
539
540impl<F: Float> Dispatcher<RnnJob> for Bundle<F> {
541    type Info = RnnInfo;
542
543    fn dispatch(&self, seed: Self::Info) -> Result<RnnJob, RuntimeError> {
544        let model = &self.model;
545        let state = &self.state;
546        let context = &model.context;
547        let info = &model.info;
548        let tensor = &model.tensor;
549
550        let num_token = seed.num_token();
551        let head_size = info.num_emb / info.num_head;
552
553        let redirect = seed.redirect();
554        let num_header = redirect.headers.len();
555
556        let buffer = self.checkout_buffer(context, info, num_token);
557        let header = self.checkout_header(context, info, num_header);
558        let frame = Frame {
559            state: state.clone(),
560            buffer: buffer.clone(),
561            header: header.clone(),
562        };
563
564        context.maintain();
565
566        if num_token == 0 {
567            return Ok(RnnJob {
568                commands: vec![],
569                redirect,
570                embed: model.tensor.embed.w.clone(),
571                cursors: buffer.cursors.clone(),
572                input: buffer.input.clone(),
573                output: header.head_o.clone(),
574            });
575        }
576
577        #[cfg(feature = "trace")]
578        let _span = tracing::trace_span!("build").entered();
579
580        let (head_op, head_x) = redirect.op(&buffer.x, &header.head_x)?;
581
582        let hook_op = |hook: Hook| hook_op(&self.hooks, &hook, &frame);
583        let mut ops = vec![];
584
585        {
586            #[cfg(feature = "trace")]
587            let _span = tracing::trace_span!("embed").entered();
588
589            ops.extend([
590                hook_op(Hook::PostEmbedLoaded)?,
591                TensorOp::layer_norm(
592                    &tensor.embed.layer_norm.w,
593                    &tensor.embed.layer_norm.b,
594                    &buffer.input,
595                    Model::LN_EPS,
596                )?,
597                TensorOp::blit(&buffer.input, &buffer.x)?,
598                hook_op(Hook::PostEmbedLayerNorm)?,
599            ]);
600        };
601
602        for (index, layer) in tensor.layers.iter().enumerate() {
603            #[cfg(feature = "trace")]
604            let _span = tracing::trace_span!("layer", index).entered();
605
606            let hooks = self.hooks.clone();
607            let frame = frame.clone();
608            let layer = layer.clone();
609
610            let op = dispatch_layer(
611                hooks,
612                frame,
613                layer,
614                index,
615                num_token,
616                head_size,
617                model.rescale,
618            )?;
619            ops.push(op);
620
621            if (index + 1) % model.sep == 0 {
622                ops.push(TensorOp::Sep);
623            }
624        }
625
626        {
627            #[cfg(feature = "trace")]
628            let _span = tracing::trace_span!("header").entered();
629
630            let hooks = self.hooks.clone();
631            let frame = frame.clone();
632            let head = model.tensor.head.clone();
633
634            let op = dispatch_header(hooks, frame, head, head_x, num_header, head_op)?;
635            ops.push(op);
636        }
637
638        let commands = {
639            #[cfg(feature = "trace")]
640            let _span = tracing::trace_span!("encode").entered();
641            context.encode(&TensorOp::List(ops))
642        };
643
644        Ok(RnnJob {
645            commands,
646            redirect,
647            embed: model.tensor.embed.w.clone(),
648            cursors: buffer.cursors.clone(),
649            input: buffer.input.clone(),
650            output: header.head_o.clone(),
651        })
652    }
653}
654
/// Builds the op list for one layer: the attention part followed by the
/// feed-forward part, with the registered hook ops interleaved.
///
/// * `index` - layer index; selects the state views and is carried by the hooks.
/// * `num_token` - number of tokens in this dispatch; used for the per-head
///   reshapes and to choose the `turbo` flag of the matmuls.
/// * `head_size` - first dimension of the per-head reshapes.
/// * `rescale` - after every `rescale`-th layer, `x` is scaled by 0.5.
///
/// Returns all ops wrapped in a single `TensorOp::List`, or the first error
/// raised while building an op or running a hook.
// The argument list mirrors everything the per-layer build needs; it is kept flat
// rather than bundled into a struct.
#[allow(clippy::too_many_arguments)]
fn dispatch_layer<F: Float>(
    hooks: Arc<HookMap<F>>,
    frame: Frame<F>,
    layer: Layer,
    index: usize,
    num_token: usize,
    head_size: usize,
    rescale: usize,
) -> Result<TensorOp, TensorError> {
    let hook_op = |hook: Hook| hook_op(&hooks, &hook, &frame);
    let Frame { state, buffer, .. } = &frame;

    // Per-head views of the time-mix parameters: [head_size, auto, 1, 1].
    let time_first = layer.att.time_first.reshape(
        TensorDimension::Size(head_size),
        TensorDimension::Auto,
        TensorDimension::Size(1),
        TensorDimension::Size(1),
    )?;
    let time_decay = layer.att.time_decay.reshape(
        TensorDimension::Size(head_size),
        TensorDimension::Auto,
        TensorDimension::Size(1),
        TensorDimension::Size(1),
    )?;
    // Per-head views of the f32 attention buffers: [head_size, auto, num_token, 1].
    let aux_x = buffer.aux_x.reshape(
        TensorDimension::Size(head_size),
        TensorDimension::Auto,
        TensorDimension::Size(num_token),
        TensorDimension::Size(1),
    )?;
    let att_k = buffer.att_k.reshape(
        TensorDimension::Size(head_size),
        TensorDimension::Auto,
        TensorDimension::Size(num_token),
        TensorDimension::Size(1),
    )?;
    let att_v = buffer.att_v.reshape(
        TensorDimension::Size(head_size),
        TensorDimension::Auto,
        TensorDimension::Size(num_token),
        TensorDimension::Size(1),
    )?;
    let att_r = buffer.att_r.reshape(
        TensorDimension::Size(head_size),
        TensorDimension::Auto,
        TensorDimension::Size(num_token),
        TensorDimension::Size(1),
    )?;

    let mut ops = vec![];

    // Attention part.
    ops.extend([
        // Copy `x` into `att_x` and normalize the copy.
        TensorOp::blit(&buffer.x, &buffer.att_x)?,
        hook_op(Hook::PreAtt(index))?,
        TensorOp::layer_norm(
            &layer.att_layer_norm.w,
            &layer.att_layer_norm.b,
            &buffer.att_x,
            Model::LN_EPS,
        )?,
        hook_op(Hook::PostAttLayerNorm(index))?,
        hook_op(Hook::PreAttTokenShift(index))?,
        // One token shift per projection input (k, v, r, g), all reading `att_x`.
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.att.time_mix_k,
            state.att(index)?,
            &buffer.att_x,
            &buffer.att_kx,
            false,
        )?,
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.att.time_mix_v,
            state.att(index)?,
            &buffer.att_x,
            &buffer.att_vx,
            false,
        )?,
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.att.time_mix_r,
            state.att(index)?,
            &buffer.att_x,
            &buffer.att_rx,
            false,
        )?,
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.att.time_mix_g,
            state.att(index)?,
            &buffer.att_x,
            &buffer.att_gx,
            false,
        )?,
        hook_op(Hook::PostAttTokenShift(index))?,
        hook_op(Hook::PreAttLinear(index))?,
        // Linear projections of the shifted inputs.
        layer.att.w_k.matmul_op(
            &buffer.att_kx,
            &buffer.att_k,
            Activation::None,
            turbo(num_token),
        )?,
        layer.att.w_v.matmul_op(
            &buffer.att_vx,
            &buffer.att_v,
            Activation::None,
            turbo(num_token),
        )?,
        layer.att.w_r.matmul_op(
            &buffer.att_rx,
            &buffer.att_r,
            Activation::None,
            turbo(num_token),
        )?,
        layer.att.w_g.matmul_op(
            &buffer.att_gx,
            &buffer.att_g,
            Activation::None,
            turbo(num_token),
        )?,
        hook_op(Hook::PostAttLinear(index))?,
        hook_op(Hook::PreAttTimeMix(index))?,
        // Time mix and group norm run on the f32 `aux_x` buffer; the result is
        // copied back into `att_x` afterwards.
        TensorOp::blit(&buffer.att_x, &buffer.aux_x)?,
        TensorOp::time_mix_v5(
            &buffer.cursors,
            &time_decay,
            &time_first,
            state.att(index)?,
            &att_k,
            &att_v,
            &att_r,
            &aux_x,
        )?,
        TensorOp::group_norm(
            &layer.att.group_norm.w,
            &layer.att.group_norm.b,
            &aux_x,
            Model::GN_EPS,
        )?,
        TensorOp::blit(&buffer.aux_x, &buffer.att_x)?,
        hook_op(Hook::PostAttTimeMix(index))?,
        hook_op(Hook::PreAttGate(index))?,
        // Gate: combines `att_g` (with SiLU) and `att_x`.
        // NOTE(review): presumably the product is written to `att_x` — confirm in `mul_activate`.
        TensorOp::mul_activate(
            &buffer.att_g,
            &buffer.att_x,
            Activation::Silu,
            Activation::None,
            Activation::None,
        )?,
        hook_op(Hook::PostAttGate(index))?,
        hook_op(Hook::PreAttOut(index))?,
        // Output projection.
        layer.att.w_o.matmul_op(
            &buffer.att_x,
            &buffer.att_o,
            Activation::None,
            turbo(num_token),
        )?,
        hook_op(Hook::PostAttOut(index))?,
        // NOTE(review): presumably accumulates `att_o` into `x` (residual) — confirm in `TensorOp::add`.
        TensorOp::add(&buffer.att_o, &buffer.x)?,
        hook_op(Hook::PostAtt(index))?,
    ]);

    // Feed-forward part.
    ops.extend([
        // Copy `x` into `ffn_x` and normalize the copy.
        TensorOp::blit(&buffer.x, &buffer.ffn_x)?,
        hook_op(Hook::PreFfn(index))?,
        TensorOp::layer_norm(
            &layer.ffn_layer_norm.w,
            &layer.ffn_layer_norm.b,
            &buffer.ffn_x,
            Model::LN_EPS,
        )?,
        hook_op(Hook::PostFfnLayerNorm(index))?,
        hook_op(Hook::PreFfnTokenShift(index))?,
        // One token shift per projection input (k, r), both reading `ffn_x`.
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.ffn.time_mix_k,
            state.ffn(index)?,
            &buffer.ffn_x,
            &buffer.ffn_kx,
            false,
        )?,
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.ffn.time_mix_r,
            state.ffn(index)?,
            &buffer.ffn_x,
            &buffer.ffn_rx,
            false,
        )?,
        hook_op(Hook::PostFfnTokenShift(index))?,
        hook_op(Hook::PreFfnLinear(index))?,
        // Key projection with squared-ReLU activation.
        layer.ffn.w_k.matmul_op(
            &buffer.ffn_kx,
            &buffer.ffn_k,
            Activation::SquaredRelu,
            turbo(num_token),
        )?,
        hook_op(Hook::PostFfnActivate(index))?,
        // Value projection uses the sparse matmul variant on the activated keys.
        layer.ffn.w_v.matmul_op_sparse(
            &buffer.ffn_k,
            &buffer.ffn_v,
            Activation::None,
            turbo(num_token),
        )?,
        layer.ffn.w_r.matmul_op(
            &buffer.ffn_rx,
            &buffer.ffn_r,
            Activation::None,
            turbo(num_token),
        )?,
        hook_op(Hook::PostFfnLinear(index))?,
        hook_op(Hook::PreFfnChannelMix(index))?,
        TensorOp::channel_mix(
            &buffer.cursors,
            state.ffn(index)?,
            &buffer.ffn_r,
            &buffer.ffn_v,
            &buffer.ffn_x,
        )?,
        hook_op(Hook::PostFfnChannelMix(index))?,
        // NOTE(review): presumably accumulates `ffn_x` into `x` (residual) — confirm in `TensorOp::add`.
        TensorOp::add(&buffer.ffn_x, &buffer.x)?,
        hook_op(Hook::PostFfn(index))?,
    ]);

    // Scale `x` by 0.5 after every `rescale`-th layer.
    if (index + 1).is_multiple_of(rescale) {
        ops.push(TensorOp::affine(&buffer.x, 0.5, 0.0)?);
    }

    Ok(TensorOp::List(ops))
}
886
887fn dispatch_header<F: Float>(
888    hooks: Arc<HookMap<F>>,
889    frame: Frame<F>,
890    head: Head,
891    head_x: TensorGpu<F, ReadWrite>,
892    num_header: usize,
893    head_op: TensorOp,
894) -> Result<TensorOp, TensorError> {
895    let hook_op = |hook: Hook| hook_op(&hooks, &hook, &frame);
896    let header = &frame.header;
897    let mut ops = vec![head_op];
898
899    if num_header > 0 {
900        ops.extend([
901            hook_op(Hook::PreHead)?,
902            TensorOp::layer_norm(
903                &head.layer_norm.w,
904                &head.layer_norm.b,
905                &head_x,
906                Model::LN_EPS,
907            )?,
908            hook_op(Hook::PostHeadLayerNorm)?,
909            head.w.matmul_op(
910                head_x.view(.., .., .., ..)?,
911                header.head_o.view(.., .., .., ..)?,
912                Activation::None,
913                turbo(num_header),
914            )?,
915            hook_op(Hook::PostHead)?,
916        ]);
917    }
918    Ok(TensorOp::List(ops))
919}
920
921impl<R: Reader> ModelBuilder<R> {
922    pub async fn build_v5(self) -> Result<Model, LoaderError> {
923        let ModelBuilder {
924            context,
925            model,
926            rescale,
927            sep,
928            lora,
929            quant,
930            ..
931        } = self;
932
933        let rescale = rescale.unwrap_or(Model::DEFAULT_RESCALE);
934        let sep = sep.unwrap_or(Model::DEFAULT_SEP);
935
936        let info = Loader::info(&model)?;
937        let loader = Loader {
938            context: context.clone(),
939            model,
940            lora,
941        };
942
943        let embed = Embed {
944            layer_norm: LayerNorm {
945                w: loader.load_vector_f16("blocks.0.ln0.weight")?,
946                b: loader.load_vector_f16("blocks.0.ln0.bias")?,
947            },
948            w: loader.load_matrix_f16_padded_cpu("emb.weight")?,
949        };
950
951        let head = Head {
952            layer_norm: LayerNorm {
953                w: loader.load_vector_f16("ln_out.weight")?,
954                b: loader.load_vector_f16("ln_out.bias")?,
955            },
956            w: Matrix::Fp16(loader.load_matrix_f16("head.weight")?),
957        };
958
959        let submission_index = Some(context.queue.submit(None));
960        _ = context.device.poll(wgpu::PollType::Wait {
961            submission_index,
962            timeout: None,
963        });
964
965        let load_matrix = |name: String, quant: Quant| loader.load_matrix(name, quant);
966        let load_matrix_discount = |name: String, quant: Quant, discount: f32| {
967            loader.load_matrix_discount(name, quant, discount)
968        };
969
970        let mut layers = vec![];
971        for layer in 0..info.num_layer {
972            let quant = quant.get(&layer).copied().unwrap_or_default();
973            let discount = 2.0_f32.powi(-((layer / rescale) as i32));
974
975            let att_layer_norm = LayerNorm {
976                w: loader.load_vector_f16(format!("blocks.{layer}.ln1.weight"))?,
977                b: loader.load_vector_f16(format!("blocks.{layer}.ln1.bias"))?,
978            };
979
980            let att = format!("blocks.{layer}.att");
981            let time_decay = loader.load_vector_exp_exp_f32(format!("{att}.time_decay"))?;
982            let time_first = loader.load_vector_f32(format!("{att}.time_first"))?;
983            let time_mix_k = loader.load_vector_f16(format!("{att}.time_mix_k"))?;
984            let time_mix_v = loader.load_vector_f16(format!("{att}.time_mix_v"))?;
985            let time_mix_r = loader.load_vector_f16(format!("{att}.time_mix_r"))?;
986            let time_mix_g = loader.load_vector_f16(format!("{att}.time_mix_g"))?;
987
988            let group_norm = LayerNorm {
989                w: loader
990                    .load_vector_f16(format!("{att}.ln_x.weight"))?
991                    .reshape(
992                        TensorDimension::Auto,
993                        TensorDimension::Size(info.num_head),
994                        TensorDimension::Size(1),
995                        TensorDimension::Size(1),
996                    )?,
997                b: loader
998                    .load_vector_f16(format!("{att}.ln_x.bias"))?
999                    .reshape(
1000                        TensorDimension::Auto,
1001                        TensorDimension::Size(info.num_head),
1002                        TensorDimension::Size(1),
1003                        TensorDimension::Size(1),
1004                    )?,
1005            };
1006
1007            let att = Att {
1008                time_decay,
1009                time_first,
1010                time_mix_k,
1011                time_mix_v,
1012                time_mix_r,
1013                time_mix_g,
1014                w_k: load_matrix(format!("{att}.key.weight"), quant)?,
1015                w_v: load_matrix(format!("{att}.value.weight"), quant)?,
1016                w_r: load_matrix(format!("{att}.receptance.weight"), quant)?,
1017                w_g: load_matrix(format!("{att}.gate.weight"), quant)?,
1018                w_o: load_matrix_discount(format!("{att}.output.weight"), quant, discount)?,
1019                group_norm,
1020            };
1021
1022            let ffn_layer_norm = LayerNorm {
1023                w: loader.load_vector_f16(format!("blocks.{layer}.ln2.weight"))?,
1024                b: loader.load_vector_f16(format!("blocks.{layer}.ln2.bias"))?,
1025            };
1026
1027            let ffn = format!("blocks.{layer}.ffn");
1028            let time_mix_k = loader.load_vector_f16(format!("{ffn}.time_mix_k"))?;
1029            let time_mix_r = loader.load_vector_f16(format!("{ffn}.time_mix_r"))?;
1030
1031            let ffn = Ffn {
1032                time_mix_k,
1033                time_mix_r,
1034                w_r: load_matrix(format!("{ffn}.receptance.weight"), quant)?,
1035                w_k: load_matrix(format!("{ffn}.key.weight"), quant)?,
1036                w_v: load_matrix_discount(format!("{ffn}.value.weight"), quant, discount)?,
1037            };
1038
1039            let submission_index = Some(context.queue.submit(None));
1040            _ = context.device.poll(wgpu::PollType::Wait {
1041                submission_index,
1042                timeout: None,
1043            });
1044
1045            layers.push(Layer {
1046                att_layer_norm,
1047                ffn_layer_norm,
1048                att,
1049                ffn,
1050            })
1051        }
1052
1053        let submission_index = Some(context.queue.submit(None));
1054        _ = context.device.poll(wgpu::PollType::Wait {
1055            submission_index,
1056            timeout: None,
1057        });
1058
1059        let tensor = ModelTensor {
1060            embed,
1061            head,
1062            layers,
1063        };
1064        let model = {
1065            let context = context.clone();
1066            let info = info.clone();
1067            Model {
1068                context,
1069                info,
1070                rescale,
1071                sep,
1072                tensor,
1073            }
1074        };
1075        Ok(model)
1076    }
1077}
1078
1079/// Read the pre-trained state from the file.
1080pub async fn read_state<R: Reader>(
1081    context: &Context,
1082    info: &ModelInfo,
1083    model: R,
1084) -> Result<TensorCpu<f32>, LoaderError> {
1085    let loader = Loader {
1086        context: context.clone(),
1087        model,
1088        lora: vec![],
1089    };
1090
1091    let head_size = info.num_emb / info.num_head;
1092    let data: TensorGpu<f32, _> = context.zeros([info.num_emb, head_size + 2, info.num_layer, 1]);
1093
1094    let mut ops = vec![];
1095    for layer in 0..info.num_layer {
1096        let matrix = loader.load_matrix_f16(format!("blocks.{layer}.att.time_state"))?;
1097        let state: TensorGpu<_, _> = context.tensor_init([head_size, info.num_head, head_size, 1]);
1098        let reshaped: TensorGpu<f16, _> = state.reshape(
1099            TensorDimension::Size(info.num_emb),
1100            TensorDimension::Size(head_size),
1101            TensorDimension::Size(1),
1102            TensorDimension::Auto,
1103        )?;
1104        ops.extend([
1105            TensorOp::transpose(&matrix, &state)?,
1106            TensorOp::blit(&reshaped, data.view(.., 1..head_size + 1, layer, ..)?)?,
1107        ]);
1108    }
1109    context.queue.submit(context.encode(&TensorOp::List(ops)));
1110
1111    Ok(data.back().await)
1112}