//! `web_rwkv/runtime/v6.rs` — RWKV v6 model runtime.

1use std::{collections::HashMap, marker::PhantomData, sync::Arc};
2
3#[cfg(not(target_arch = "wasm32"))]
4use futures::future::BoxFuture;
5#[cfg(target_arch = "wasm32")]
6use futures::future::LocalBoxFuture;
7use half::f16;
8use itertools::Itertools;
9use serde::{Deserialize, Serialize};
10use web_rwkv_derive::DeserializeSeed;
11use wgpu::CommandBuffer;
12
13use super::{
14    infer::{RnnChunk, RnnInfo, RnnInput, RnnOutput, RnnOutputBatch, RnnRedirect, Token},
15    loader::{Loader, LoaderError, Reader},
16    model::{AsAny, ModelBuilder, ModelCustomInfo, ModelInfo, State as _},
17    Dispatcher, Job, RuntimeError,
18};
19use crate::{
20    context::Context,
21    num::Float,
22    runtime::model::Quant,
23    tensor::{
24        cache::ResourceCache,
25        kind::ReadWrite,
26        matrix::Matrix,
27        ops::{Activation, TensorCommand, TensorOp},
28        serialization::Seed,
29        shape::{Shape, TensorDimension},
30        DeepClone, IntoPackedCursors, TensorCpu, TensorError, TensorGpu, TensorGpuView, TensorInit,
31        TensorReshape, TensorShape, TensorStack,
32    },
33};
34
/// A loaded RWKV v6 model: context handle, metadata and all weight tensors.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Model {
    pub context: Context,
    pub info: ModelInfo,
    /// Every `rescale` layers the residual activations are halved
    /// (see the `affine` op at the end of `dispatch_layer`).
    pub rescale: usize,
    /// Every `sep` layers a `TensorOp::Sep` is inserted into the op list.
    pub sep: usize,
    pub tensor: ModelTensor,
}
44
impl Model {
    /// Epsilon used by every layer normalization in this model.
    pub const LN_EPS: f32 = 1.0e-5;
    /// Epsilon used by the attention group normalization.
    pub const GN_EPS: f32 = 64.0e-5;

    /// Default interval (in layers) for halving the residual activations.
    pub const DEFAULT_RESCALE: usize = 6;
    /// Default interval (in layers) for inserting a `TensorOp::Sep`.
    pub const DEFAULT_SEP: usize = 1024;
}
52
/// V6-specific model metadata: the dimensions of the two LoRA adapters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct CustomInfo {
    /// Token shift LoRA adapter size.
    pub time_mix: usize,
    /// Time decay LoRA adapter size.
    pub time_decay: usize,
}
60
/// All weight tensors of the model: embedding, per-layer weights and head.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct ModelTensor {
    pub embed: Embed,
    pub head: Head,
    pub layers: Vec<Layer>,
}
68
/// Layer-norm parameters: weight (scale) and bias.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct LayerNorm {
    pub w: TensorGpu<f16, ReadWrite>,
    pub b: TensorGpu<f16, ReadWrite>,
}
75
/// Attention (time-mix) weights of one layer.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Att {
    /// Base time decay; added to the time-decay LoRA output in `dispatch_layer`.
    pub time_decay: TensorGpu<f16, ReadWrite>,
    /// Per-head "first" term consumed by `TensorOp::time_mix_v6`.
    pub time_first: TensorGpu<f32, ReadWrite>,

    /// Mix factor for the initial (un-gated) token shift.
    pub time_mix_x: TensorGpu<f16, ReadWrite>,
    /// Base mix factors; added to the token-shift LoRA output.
    pub time_mix: TensorGpu<f16, ReadWrite>,

    // Low-rank adapters producing the data-dependent decay and mix factors.
    pub time_decay_w1: Matrix,
    pub time_decay_w2: Matrix,
    pub time_mix_w1: Matrix,
    pub time_mix_w2: Matrix,

    // Projections: key, value, receptance, gate and output.
    pub w_k: Matrix,
    pub w_v: Matrix,
    pub w_r: Matrix,
    pub w_g: Matrix,
    pub w_o: Matrix,

    /// Group norm applied to the time-mix output (eps = `Model::GN_EPS`).
    pub group_norm: LayerNorm,
}
98
/// Feed-forward (channel-mix) weights of one layer.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Ffn {
    /// Token-shift mix factor for the key path.
    pub time_mix_k: TensorGpu<f16, ReadWrite>,
    /// Token-shift mix factor for the receptance path.
    pub time_mix_r: TensorGpu<f16, ReadWrite>,

    // Projections: key, value and receptance.
    pub w_k: Matrix,
    pub w_v: Matrix,
    pub w_r: Matrix,
}
109
/// One transformer layer: pre-norms plus attention and FFN weights.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Layer {
    pub att_layer_norm: LayerNorm,
    pub ffn_layer_norm: LayerNorm,
    pub att: Att,
    pub ffn: Ffn,
}
118
/// Embedding: the matrix stays on the CPU — rows are gathered per token in
/// `RnnJob::load` — followed by a GPU layer norm.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Embed {
    pub layer_norm: LayerNorm,
    pub w: TensorCpu<f16>,
}
125
/// Output head: final layer norm plus the vocabulary projection.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Head {
    pub layer_norm: LayerNorm,
    pub w: Matrix,
}
132
/// Recurrent inference state.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct State {
    pub context: Context,
    pub info: ModelInfo,
    /// One tensor per layer, shaped `[num_emb, head_size + 2, num_batch]`
    /// (see `Bundle::new`); axis 2 indexes the batch.
    pub data: Vec<TensorGpu<f32, ReadWrite>>,
}
140
141impl State {
142    async fn back(&self, batch: usize) -> Result<TensorCpu<f32>, TensorError> {
143        let context = &self.context;
144        let mut tensors = Vec::with_capacity(self.info.num_layer);
145        let mut encoder = context.device.create_command_encoder(&Default::default());
146        for data in self.data.iter() {
147            let shape = data.shape();
148            let destination = context.tensor_init([shape[0], shape[1], 1, 1]);
149            encoder.copy_tensor_batch(data, &destination, batch, 0)?;
150            tensors.push(destination);
151        }
152        context.queue.submit(Some(encoder.finish()));
153
154        let mut backed = Vec::with_capacity(tensors.len());
155        for tensor in tensors {
156            backed.push(tensor.back().await);
157        }
158        TensorCpu::stack(backed, 2)
159    }
160}
161
impl AsAny for State {
    /// Upcast to `&dyn Any` so callers can downcast back to the concrete `State`.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
167
impl super::model::State for State {
    #[inline]
    fn num_batch(&self) -> usize {
        // Axis 2 of each per-layer tensor is the batch dimension.
        self.data[0].shape()[2]
    }

    #[inline]
    fn init_shape(&self) -> Shape {
        // Per layer: `head_size` rows of attention state plus one extra row
        // used by `att` and one used by `ffn` (hence `head_size + 2`).
        let info = &self.info;
        let head_size = info.num_emb / info.num_head;
        [info.num_emb, head_size + 2, info.num_layer, 1].into()
    }

    /// Zero-initialized CPU state for a single batch.
    fn init(&self) -> TensorCpu<f32> {
        let shape = self.init_shape();
        let data = vec![0.0; shape.len()];
        TensorCpu::from_data(shape, data).unwrap()
    }

    /// Attention state of `layer`: rows `0..head_size + 1` over all batches.
    fn att(&self, layer: usize) -> Result<TensorGpuView<'_, f32>, TensorError> {
        let head_size = self.info.num_emb / self.info.num_head;
        let end = head_size + 1;
        self.data[layer].view(.., 0..end, .., ..)
    }

    /// FFN state of `layer`: the single row `head_size + 1` over all batches.
    fn ffn(&self, layer: usize) -> Result<TensorGpuView<'_, f32>, TensorError> {
        let head_size = self.info.num_emb / self.info.num_head;
        let start = head_size + 1;
        self.data[layer].view(.., start, .., ..)
    }

    /// Upload a full CPU state (one tensor slice per layer) into `batch`.
    fn load(&self, tensor: TensorCpu<f32>, batch: usize) -> Result<(), TensorError> {
        let head_size = self.info.num_emb / self.info.num_head;
        tensor.check_shape([self.info.num_emb, head_size + 2, self.info.num_layer, 1])?;
        for (data, source) in self.data.iter().zip(tensor.split(2)?.into_iter()) {
            data.load_batch(&source, batch)?;
        }
        Ok(())
    }

    /// Read the state of `batch` back to the CPU; delegates to the inherent
    /// async `State::back`.
    #[cfg(not(target_arch = "wasm32"))]
    fn back(&self, batch: usize) -> BoxFuture<'_, Result<TensorCpu<f32>, TensorError>> {
        Box::pin(self.back(batch))
    }

    /// Non-`Send` variant of `back` for wasm targets.
    #[cfg(target_arch = "wasm32")]
    fn back(&self, batch: usize) -> LocalBoxFuture<'_, Result<TensorCpu<f32>, TensorError>> {
        Box::pin(self.back(batch))
    }

    /// Copy a full single-batch state tensor into the `batch` slot, layer by
    /// layer, entirely on the GPU.
    fn write(&self, tensor: TensorGpu<f32, ReadWrite>, batch: usize) -> Result<(), TensorError> {
        let head_size = self.info.num_emb / self.info.num_head;
        tensor.check_shape([self.info.num_emb, head_size + 2, self.info.num_layer, 1])?;

        let context = &self.context;
        let mut ops = Vec::with_capacity(self.data.len());
        for (layer, data) in self.data.iter().enumerate() {
            ops.push(TensorOp::blit(
                tensor.view(.., .., layer, ..)?,
                data.view(.., .., batch, ..)?,
            )?);
        }
        context.queue.submit(context.encode(&TensorOp::List(ops)));

        Ok(())
    }

    /// Gather the state of `batch` into a fresh GPU tensor (inverse of `write`).
    fn read(&self, batch: usize) -> Result<TensorGpu<f32, ReadWrite>, TensorError> {
        let context = &self.context;
        let head_size = self.info.num_emb / self.info.num_head;
        let shape = [self.info.num_emb, head_size + 2, self.info.num_layer, 1];
        let tensor: TensorGpu<_, _> = context.tensor_init(shape);

        let mut ops = Vec::with_capacity(self.data.len());
        for (layer, data) in self.data.iter().enumerate() {
            ops.push(TensorOp::blit(
                data.view(.., .., batch, ..)?,
                tensor.view(.., .., layer, ..)?,
            )?);
        }
        context.queue.submit(context.encode(&TensorOp::List(ops)));

        Ok(tensor)
    }

    /// Slice row 0 of `layer` out of a backed state — presumably the
    /// token-shift row used as an embedding-like feature; NOTE(review):
    /// confirm against the state row layout.
    fn embed(&self, layer: usize, backed: TensorCpu<f32>) -> Result<TensorCpu<f32>, TensorError> {
        backed.slice(.., 0, layer, ..)
    }
}
257
258impl DeepClone for State {
259    fn deep_clone(&self) -> Self {
260        let data = self.data.iter().map(|tensor| tensor.deep_clone()).collect();
261        Self {
262            data,
263            ..self.clone()
264        }
265    }
266}
267
/// All intermediate GPU buffers needed for a forward pass over one token chunk.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Runtime<F: Float> {
    /// Packed batch cursors, `[T]`.
    pub cursors: TensorGpu<u32, ReadWrite>,
    /// Embedded input tokens, `[C, T]`.
    pub input: TensorGpu<f16, ReadWrite>,

    /// Residual stream, `[C, T]`.
    pub x: TensorGpu<F, ReadWrite>,
    /// f32 scratch used for the time-mix kernel and group norm, `[C, T]`.
    pub aux_x: TensorGpu<f32, ReadWrite>,

    pub att_x: TensorGpu<F, ReadWrite>,
    pub att_xx: TensorGpu<F, ReadWrite>,
    /// Token shifted time decay input, `[C, T, 5]`.
    pub att_sx: TensorGpu<F, ReadWrite>,
    /// Time decay LoRA intermediate, `[64, T]`.
    pub att_w: TensorGpu<F, ReadWrite>,
    pub att_k: TensorGpu<f32, ReadWrite>,
    pub att_v: TensorGpu<f32, ReadWrite>,
    pub att_r: TensorGpu<f32, ReadWrite>,
    pub att_g: TensorGpu<F, ReadWrite>,
    pub att_o: TensorGpu<F, ReadWrite>,

    /// Token shift LoRA intermediate, `[32, 5, T]`.
    pub time_mix_x: TensorGpu<F, ReadWrite>,
    /// Token shift LoRA intermediate transposed, `[32, T, 5]`.
    pub time_mix_t: TensorGpu<F, ReadWrite>,
    /// Token shift LoRA output, `[C, T, 5]`.
    pub time_mix: TensorGpu<F, ReadWrite>,
    /// Time decay LoRA output, `[C, T]`.
    pub time_decay: TensorGpu<f32, ReadWrite>,

    pub ffn_x: TensorGpu<F, ReadWrite>,
    pub ffn_kx: TensorGpu<F, ReadWrite>,
    pub ffn_rx: TensorGpu<F, ReadWrite>,
    /// FFN key projection output, `[num_hidden, T]`.
    pub ffn_k: TensorGpu<F, ReadWrite>,
    pub ffn_v: TensorGpu<F, ReadWrite>,
    pub ffn_r: TensorGpu<F, ReadWrite>,
}
304
305impl<F: Float> Runtime<F> {
306    pub fn new(context: &Context, info: &ModelInfo, num_token: usize) -> Self {
307        let ModelCustomInfo::V6(custom) = info.custom else {
308            unreachable!()
309        };
310
311        let shape = Shape::new(info.num_emb, num_token, 1, 1);
312        let cursors_shape = Shape::new(num_token, 1, 1, 1);
313        let hidden_shape = Shape::new(info.num_hidden, num_token, 1, 1);
314        let time_mix_shape = Shape::new(info.num_emb, num_token, 5, 1);
315        let time_mix_x_shape = Shape::new(custom.time_mix, 5, num_token, 1);
316        let time_mix_t_shape = Shape::new(custom.time_mix, num_token, 5, 1);
317        let time_decay_shape = Shape::new(custom.time_decay, num_token, 1, 1);
318
319        Self {
320            cursors: context.tensor_init(cursors_shape),
321            input: context.tensor_init(shape),
322            x: context.tensor_init(shape),
323            aux_x: context.tensor_init(shape),
324            att_x: context.tensor_init(shape),
325            att_xx: context.tensor_init(shape),
326            att_sx: context.tensor_init(time_mix_shape),
327            att_w: context.tensor_init(time_decay_shape),
328            att_k: context.tensor_init(shape),
329            att_v: context.tensor_init(shape),
330            att_r: context.tensor_init(shape),
331            att_g: context.tensor_init(shape),
332            att_o: context.tensor_init(shape),
333            time_mix_x: context.tensor_init(time_mix_x_shape),
334            time_mix_t: context.tensor_init(time_mix_t_shape),
335            time_mix: context.tensor_init(time_mix_shape),
336            time_decay: context.tensor_init(shape),
337            ffn_x: context.tensor_init(shape),
338            ffn_kx: context.tensor_init(shape),
339            ffn_rx: context.tensor_init(shape),
340            ffn_k: context.tensor_init(hidden_shape),
341            ffn_v: context.tensor_init(shape),
342            ffn_r: context.tensor_init(shape),
343        }
344    }
345}
346
/// Output-head buffers for `num_header` output positions.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Header<F: Float> {
    /// Gathered inputs to the head, `[C, num_header]`.
    pub head_x: TensorGpu<F, ReadWrite>,
    /// Output logits, `[num_vocab_padded, num_header]`.
    pub head_o: TensorGpu<f32, ReadWrite>,
}
353
354impl<F: Float> Header<F> {
355    pub fn new(context: &Context, info: &ModelInfo, num_header: usize) -> Self {
356        let head_shape = Shape::new(info.num_emb, num_header, 1, 1);
357        let output_shape = Shape::new(info.num_vocab_padded(), num_header, 1, 1);
358
359        Self {
360            head_x: context.tensor_init(head_shape),
361            head_o: context.tensor_init(output_shape),
362        }
363    }
364}
365
/// Injection points for user hooks, listed in execution order; the `usize`
/// payload is the layer index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Hook {
    // Embedding stage.
    PostEmbedLoaded,
    PostEmbedLayerNorm,
    // Attention (time-mix) stage, per layer.
    PreAtt(usize),
    PostAttLayerNorm(usize),
    PreAttTokenShift(usize),
    PostAttTokenShift(usize),
    PreAttTokenShiftAdapt(usize),
    PostAttTokenShiftAdapt(usize),
    PostAttTokenShiftAdaptActivate(usize),
    PreAttGatedTokenShift(usize),
    PostAttGatedTokenShift(usize),
    PreAttTimeDecayAdapt(usize),
    PostAttTimeDecayAdapt(usize),
    PostAttTimeDecayAdaptActivate(usize),
    PreAttTimeDecayActivate(usize),
    PostAttTimeDecayActivate(usize),
    PreAttLinear(usize),
    PostAttLinear(usize),
    PreAttTimeMix(usize),
    PostAttTimeMix(usize),
    PreAttGate(usize),
    PostAttGate(usize),
    PreAttOut(usize),
    PostAttOut(usize),
    PostAtt(usize),
    // FFN (channel-mix) stage, per layer.
    PreFfn(usize),
    PostFfnLayerNorm(usize),
    PreFfnTokenShift(usize),
    PostFfnTokenShift(usize),
    PreFfnLinear(usize),
    PostFfnLinear(usize),
    PostFfnActivate(usize),
    PreFfnChannelMix(usize),
    PostFfnChannelMix(usize),
    PostFfn(usize),
    // Output head.
    PreHead,
    PostHeadLayerNorm,
    PostHead,
}
407
/// A fully recorded inference job: the command buffers to submit plus the
/// tensors to load before submission and to read back afterwards.
pub struct RnnJob {
    commands: Vec<CommandBuffer>,
    /// Maps output rows back to per-batch slices (see `RnnJob::back`).
    redirect: RnnRedirect,

    /// CPU copy of the embedding matrix, used to resolve token ids in `load`.
    embed: TensorCpu<f16>,

    cursors: TensorGpu<u32, ReadWrite>,
    input: TensorGpu<f16, ReadWrite>,
    output: TensorGpu<f32, ReadWrite>,
}
418
impl Job for RnnJob {
    type Input = RnnInput;
    type Output = RnnOutput;

    /// Resolve the chunk's tokens into embeddings on the CPU and upload the
    /// cursors and input tensor to the GPU.
    fn load(&self, input: &RnnChunk) -> Result<(), RuntimeError> {
        // Nothing to upload for an empty chunk.
        if input.num_token() == 0 {
            return Ok(());
        }

        // Build one CPU tensor per batch: token ids are looked up in the
        // embedding matrix, pre-computed embeddings pass through unchanged.
        let stack: Vec<TensorCpu<f16>> = input
            .iter()
            .map(|chunk| {
                let num_emb = self.embed.shape()[0];
                let data = self.embed.data();
                let data = chunk
                    .iter()
                    .map(|token| match token {
                        &Token::Token(token) => {
                            // Row `token` of the embedding matrix.
                            let start = num_emb * token as usize;
                            let end = start + num_emb;
                            let data = data[start..end].to_vec();
                            TensorCpu::from_data_1d(data)
                        }
                        Token::Embed(tensor) => tensor.clone(),
                    })
                    .collect_vec();
                match TensorCpu::stack(data, 1) {
                    Ok(tensor) => tensor,
                    // An empty batch stacks into a zero-token tensor.
                    Err(_) => TensorCpu::init([num_emb, 0, 1, 1]),
                }
            })
            .collect();
        let stack = TensorStack::try_from(stack)?;

        // Upload the packed batch cursors and the stacked embeddings.
        let cursors = stack.cursors.clone().into_cursors();
        let cursors = TensorCpu::from_data(self.cursors.shape(), cursors)?;
        self.cursors.load(&cursors)?;
        self.input.load(&stack.tensor)?;

        Ok(())
    }

    /// Submit the pre-recorded command buffers; `take` empties them so the
    /// job can only be submitted once.
    fn submit(&mut self) {
        let commands = std::mem::take(&mut self.commands);
        self.output.context.queue.submit(commands);
    }

    /// Read the output tensor back and split it into per-batch outputs
    /// according to the redirect map.
    async fn back(self) -> Result<Self::Output, RuntimeError> {
        let output = self.output.back().await;
        let batches: Vec<_> = self
            .redirect
            .outputs
            .into_iter()
            .map(|(start, end)| output.slice(.., start..end, .., ..))
            .try_collect()?;
        let batches = batches.into_iter().map(RnnOutputBatch).collect();
        Ok(RnnOutput(batches))
    }
}
478
/// Everything a hook or layer dispatch needs: the shared recurrent state plus
/// the per-token and per-header buffer sets.
#[derive(Debug, Clone)]
pub struct Frame<F: Float> {
    pub state: State,
    pub buffer: Arc<Runtime<F>>,
    pub header: Arc<Header<F>>,
}
485
/// A user hook: maps the current frame to an extra tensor op to splice in.
pub type HookFn<F> = Box<dyn Fn(Frame<F>) -> Result<TensorOp, TensorError> + Send + Sync>;
/// Registered hooks, keyed by their injection point.
pub type HookMap<F> = HashMap<Hook, HookFn<F>>;
488
/// A model together with its inference state, hooks and cached buffers.
/// Cloning is cheap: GPU resources are shared.
#[derive(Clone)]
pub struct Bundle<F: Float> {
    model: Model,
    state: State,
    hooks: Arc<HookMap<F>>,
    /// Runtime buffers cached by token count.
    buffers: ResourceCache<usize, Runtime<F>>,
    /// Header buffers cached by header count.
    headers: ResourceCache<usize, Header<F>>,
    phantom: PhantomData<F>,
}
498
impl<F: Float> Bundle<F> {
    /// Create a bundle for `num_batch` parallel sequences with a
    /// zero-initialized state and no hooks.
    pub fn new(model: Model, num_batch: usize) -> Self {
        let context = model.context.clone();
        let info = model.info.clone();
        let state = {
            let head_size = info.num_emb / info.num_head;
            // One state tensor per layer: [C, head_size + 2, num_batch].
            let shape = Shape::new(info.num_emb, head_size + 2, num_batch, 1);
            let data = (0..info.num_layer).map(|_| context.zeros(shape)).collect();
            State {
                context,
                info,
                data,
            }
        };
        Self {
            model,
            state,
            hooks: Default::default(),
            buffers: ResourceCache::new(4),
            headers: ResourceCache::new(4),
            phantom: PhantomData,
        }
    }

    /// Same as [`Self::new`], but with user hooks installed.
    pub fn new_with_hooks(model: Model, num_batch: usize, hooks: HookMap<F>) -> Self {
        Self {
            hooks: Arc::new(hooks),
            ..Self::new(model, num_batch)
        }
    }

    /// Fetch (or allocate) the runtime buffers for this token count.
    fn checkout_buffer(
        &self,
        context: &Context,
        info: &ModelInfo,
        num_token: usize,
    ) -> Arc<Runtime<F>> {
        self.buffers
            .checkout(num_token, || Runtime::new(context, info, num_token))
    }

    /// Fetch (or allocate) the header buffers for this header count.
    fn checkout_header(
        &self,
        context: &Context,
        info: &ModelInfo,
        num_header: usize,
    ) -> Arc<Header<F>> {
        self.headers
            .checkout(num_header, || Header::new(context, info, num_header))
    }
}
550
impl<F: Float> super::model::Bundle for Bundle<F> {
    #[inline]
    fn info(&self) -> ModelInfo {
        self.model.info.clone()
    }

    /// Hand out a clone of the state; its GPU buffers stay shared with this
    /// bundle, so updates remain visible.
    #[inline]
    fn state(&self) -> impl super::model::State + AsAny + 'static {
        self.state.clone()
    }

    #[inline]
    fn model(&self) -> impl Serialize + 'static {
        self.model.clone()
    }
}
567
568fn turbo(num_token: usize) -> bool {
569    num_token.is_multiple_of(super::infer::rnn::MIN_TOKEN_CHUNK_SIZE)
570}
571
572fn hook_op<F: Float>(
573    hooks: &HookMap<F>,
574    hook: &Hook,
575    frame: &Frame<F>,
576) -> Result<TensorOp, TensorError> {
577    match hooks.get(hook) {
578        Some(f) => f(frame.clone()),
579        None => Ok(TensorOp::empty()),
580    }
581}
582
impl<F: Float> Dispatcher<RnnJob> for Bundle<F> {
    type Info = RnnInfo;

    /// Record the full forward pass for one chunk into a submittable [`RnnJob`].
    fn dispatch(&self, seed: Self::Info) -> Result<RnnJob, RuntimeError> {
        let model = &self.model;
        let state = &self.state;
        let context = &model.context;
        let info = &model.info;
        let tensor = &model.tensor;

        let num_token = seed.num_token();
        let head_size = info.num_emb / info.num_head;

        // The redirect determines which token positions produce output rows.
        let redirect = seed.redirect();
        let num_header = redirect.headers.len();

        let buffer = self.checkout_buffer(context, info, num_token);
        let header = self.checkout_header(context, info, num_header);
        let frame = Frame {
            state: state.clone(),
            buffer: buffer.clone(),
            header: header.clone(),
        };

        // Evict stale cached resources.
        context.maintain();
        self.buffers.maintain();
        self.headers.maintain();

        // Empty chunk: return a job with no commands to submit.
        if num_token == 0 {
            return Ok(RnnJob {
                commands: vec![],
                redirect,
                embed: model.tensor.embed.w.clone(),
                cursors: buffer.cursors.clone(),
                input: buffer.input.clone(),
                output: header.head_o.clone(),
            });
        }

        #[cfg(feature = "trace")]
        let _span = tracing::trace_span!("build").entered();

        // Op that gathers the output-producing rows of `x` into `head_x`.
        let (head_op, head_x) = redirect.op(&buffer.x, &header.head_x)?;

        let hook_op = |hook: Hook| hook_op(&self.hooks, &hook, &frame);
        let mut ops = vec![];

        {
            #[cfg(feature = "trace")]
            let _span = tracing::trace_span!("embed").entered();

            // Embedding layer norm; the result seeds the residual stream `x`.
            ops.extend([
                hook_op(Hook::PostEmbedLoaded)?,
                TensorOp::layer_norm(
                    &tensor.embed.layer_norm.w,
                    &tensor.embed.layer_norm.b,
                    &buffer.input,
                    Model::LN_EPS,
                )?,
                TensorOp::blit(&buffer.input, &buffer.x)?,
                hook_op(Hook::PostEmbedLayerNorm)?,
            ]);
        };

        for (index, layer) in tensor.layers.iter().enumerate() {
            #[cfg(feature = "trace")]
            let _span = tracing::trace_span!("layer", index).entered();

            let hooks = self.hooks.clone();
            let frame = frame.clone();
            let layer = layer.clone();

            let op = dispatch_layer(
                hooks,
                frame,
                layer,
                index,
                num_token,
                head_size,
                model.rescale,
            )?;
            ops.push(op);

            // Periodically insert a separator into the op stream.
            if (index + 1) % model.sep == 0 {
                ops.push(TensorOp::Sep);
            }
        }

        {
            #[cfg(feature = "trace")]
            let _span = tracing::trace_span!("header").entered();

            let hooks = self.hooks.clone();
            let frame = frame.clone();
            let head = model.tensor.head.clone();

            let op = dispatch_header(hooks, frame, head, head_x, num_header, head_op)?;
            ops.push(op);
        }

        // Encode everything into reusable command buffers.
        let commands = {
            #[cfg(feature = "trace")]
            let _span = tracing::trace_span!("encode").entered();
            context.encode(&TensorOp::List(ops))
        };

        Ok(RnnJob {
            commands,
            redirect,
            embed: model.tensor.embed.w.clone(),
            cursors: buffer.cursors.clone(),
            input: buffer.input.clone(),
            output: header.head_o.clone(),
        })
    }
}
699
/// Record every tensor op of one v6 layer: time mix (attention) followed by
/// channel mix (FFN), with user hooks spliced in between the stages.
#[allow(clippy::too_many_arguments)]
fn dispatch_layer<F: Float>(
    hooks: Arc<HookMap<F>>,
    frame: Frame<F>,
    layer: Layer,
    index: usize,
    num_token: usize,
    head_size: usize,
    rescale: usize,
) -> Result<TensorOp, TensorError> {
    let hook_op = |hook: Hook| hook_op(&hooks, &hook, &frame);
    let Frame { state, buffer, .. } = &frame;

    // Per-head views for the time-mix kernel: split the channel dimension
    // into `[head_size, num_head]`.
    let time_first = layer.att.time_first.reshape(
        TensorDimension::Size(head_size),
        TensorDimension::Auto,
        TensorDimension::Size(1),
        TensorDimension::Size(1),
    )?;
    let time_decay = buffer.time_decay.reshape(
        TensorDimension::Size(head_size),
        TensorDimension::Auto,
        TensorDimension::Size(num_token),
        TensorDimension::Size(1),
    )?;
    // Flatten the five token-shift channels for the LoRA matmul.
    let time_mix_x = buffer.time_mix_x.reshape(
        TensorDimension::Auto,
        TensorDimension::Size(num_token),
        TensorDimension::Size(1),
        TensorDimension::Size(1),
    )?;
    let aux_x = buffer.aux_x.reshape(
        TensorDimension::Size(head_size),
        TensorDimension::Auto,
        TensorDimension::Size(num_token),
        TensorDimension::Size(1),
    )?;
    let att_k = buffer.att_k.reshape(
        TensorDimension::Size(head_size),
        TensorDimension::Auto,
        TensorDimension::Size(num_token),
        TensorDimension::Size(1),
    )?;
    let att_v = buffer.att_v.reshape(
        TensorDimension::Size(head_size),
        TensorDimension::Auto,
        TensorDimension::Size(num_token),
        TensorDimension::Size(1),
    )?;
    let att_r = buffer.att_r.reshape(
        TensorDimension::Size(head_size),
        TensorDimension::Auto,
        TensorDimension::Size(num_token),
        TensorDimension::Size(1),
    )?;

    let mut ops = vec![];

    // --- Attention (time mix) ---
    ops.extend([
        TensorOp::blit(&buffer.x, &buffer.att_x)?,
        hook_op(Hook::PreAtt(index))?,
        TensorOp::layer_norm(
            &layer.att_layer_norm.w,
            &layer.att_layer_norm.b,
            &buffer.att_x,
            Model::LN_EPS,
        )?,
        hook_op(Hook::PostAttLayerNorm(index))?,
        hook_op(Hook::PreAttTokenShift(index))?,
        // First token shift with the static mix factor.
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.att.time_mix_x,
            state.att(index)?,
            &buffer.att_x,
            &buffer.att_xx,
            true,
        )?,
        hook_op(Hook::PostAttTokenShift(index))?,
        hook_op(Hook::PreAttTokenShiftAdapt(index))?,
        // Token-shift LoRA: w1 (tanh) -> transpose -> w2, then add the base.
        layer.att.time_mix_w1.matmul_op(
            &buffer.att_xx,
            &time_mix_x,
            Activation::Tanh,
            turbo(num_token),
        )?,
        TensorOp::transpose(&buffer.time_mix_x, &buffer.time_mix_t)?,
        hook_op(Hook::PostAttTokenShiftAdaptActivate(index))?,
        layer.att.time_mix_w2.matmul_op(
            &buffer.time_mix_t,
            &buffer.time_mix,
            Activation::None,
            turbo(num_token),
        )?,
        hook_op(Hook::PostAttTokenShiftAdapt(index))?,
        TensorOp::add(&layer.att.time_mix, &buffer.time_mix)?,
        hook_op(Hook::PreAttGatedTokenShift(index))?,
        // Second token shift, gated by the data-dependent mix factors.
        TensorOp::token_shift(
            &buffer.cursors,
            &buffer.time_mix,
            state.att(index)?,
            &buffer.att_x,
            &buffer.att_sx,
            true,
        )?,
        hook_op(Hook::PostAttGatedTokenShift(index))?,
        hook_op(Hook::PreAttLinear(index))?,
        // Projections from the token-shift channels: 1 -> k, 2 -> v, 3 -> r,
        // 4 -> g (channel 0 feeds the time-decay adapter below).
        layer.att.w_k.matmul_op(
            buffer.att_sx.view(.., .., 1, ..)?,
            &buffer.att_k,
            Activation::None,
            turbo(num_token),
        )?,
        layer.att.w_v.matmul_op(
            buffer.att_sx.view(.., .., 2, ..)?,
            &buffer.att_v,
            Activation::None,
            turbo(num_token),
        )?,
        layer.att.w_r.matmul_op(
            buffer.att_sx.view(.., .., 3, ..)?,
            &buffer.att_r,
            Activation::None,
            turbo(num_token),
        )?,
        layer.att.w_g.matmul_op(
            buffer.att_sx.view(.., .., 4, ..)?,
            &buffer.att_g,
            Activation::None,
            turbo(num_token),
        )?,
        hook_op(Hook::PostAttLinear(index))?,
        hook_op(Hook::PreAttTimeDecayAdapt(index))?,
        // Time-decay LoRA: w1 (tanh) -> w2, add the base, then StableExp.
        layer.att.time_decay_w1.matmul_op(
            buffer.att_sx.view(.., .., 0, ..)?,
            &buffer.att_w,
            Activation::Tanh,
            turbo(num_token),
        )?,
        hook_op(Hook::PostAttTimeDecayAdaptActivate(index))?,
        layer.att.time_decay_w2.matmul_op(
            &buffer.att_w,
            &buffer.time_decay,
            Activation::None,
            turbo(num_token),
        )?,
        hook_op(Hook::PostAttTimeDecayAdapt(index))?,
        TensorOp::add(&layer.att.time_decay, &buffer.time_decay)?,
        hook_op(Hook::PreAttTimeDecayActivate(index))?,
        TensorOp::activate(&buffer.time_decay, Activation::StableExp)?,
        hook_op(Hook::PostAttTimeDecayActivate(index))?,
        hook_op(Hook::PreAttTimeMix(index))?,
        // Time-mix kernel runs in f32 on `aux_x`, then group norm per head.
        TensorOp::blit(&buffer.att_x, &buffer.aux_x)?,
        TensorOp::time_mix_v6(
            &buffer.cursors,
            &time_decay,
            &time_first,
            state.att(index)?,
            &att_k,
            &att_v,
            &att_r,
            &aux_x,
        )?,
        TensorOp::group_norm(
            &layer.att.group_norm.w,
            &layer.att.group_norm.b,
            &aux_x,
            Model::GN_EPS,
        )?,
        TensorOp::blit(&buffer.aux_x, &buffer.att_x)?,
        hook_op(Hook::PostAttTimeMix(index))?,
        hook_op(Hook::PreAttGate(index))?,
        // Gate: att_x *= silu(att_g).
        TensorOp::mul_activate(
            &buffer.att_g,
            &buffer.att_x,
            Activation::Silu,
            Activation::None,
            Activation::None,
        )?,
        hook_op(Hook::PostAttGate(index))?,
        hook_op(Hook::PreAttOut(index))?,
        layer.att.w_o.matmul_op(
            &buffer.att_x,
            &buffer.att_o,
            Activation::None,
            turbo(num_token),
        )?,
        hook_op(Hook::PostAttOut(index))?,
        // Residual connection back into `x`.
        TensorOp::add(&buffer.att_o, &buffer.x)?,
        hook_op(Hook::PostAtt(index))?,
    ]);

    // --- FFN (channel mix) ---
    ops.extend([
        TensorOp::blit(&buffer.x, &buffer.ffn_x)?,
        hook_op(Hook::PreFfn(index))?,
        TensorOp::layer_norm(
            &layer.ffn_layer_norm.w,
            &layer.ffn_layer_norm.b,
            &buffer.ffn_x,
            Model::LN_EPS,
        )?,
        hook_op(Hook::PostFfnLayerNorm(index))?,
        hook_op(Hook::PreFfnTokenShift(index))?,
        // Separate token shifts for the key and receptance paths.
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.ffn.time_mix_k,
            state.ffn(index)?,
            &buffer.ffn_x,
            &buffer.ffn_kx,
            true,
        )?,
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.ffn.time_mix_r,
            state.ffn(index)?,
            &buffer.ffn_x,
            &buffer.ffn_rx,
            true,
        )?,
        hook_op(Hook::PostFfnTokenShift(index))?,
        hook_op(Hook::PreFfnLinear(index))?,
        // k = sq_relu(W_k @ kx); v exploits the sparsity of k.
        layer.ffn.w_k.matmul_op(
            &buffer.ffn_kx,
            &buffer.ffn_k,
            Activation::SquaredRelu,
            turbo(num_token),
        )?,
        hook_op(Hook::PostFfnActivate(index))?,
        layer.ffn.w_v.matmul_op_sparse(
            &buffer.ffn_k,
            &buffer.ffn_v,
            Activation::None,
            turbo(num_token),
        )?,
        layer.ffn.w_r.matmul_op(
            &buffer.ffn_rx,
            &buffer.ffn_r,
            Activation::None,
            turbo(num_token),
        )?,
        hook_op(Hook::PostFfnLinear(index))?,
        hook_op(Hook::PreFfnChannelMix(index))?,
        TensorOp::channel_mix(
            &buffer.cursors,
            state.ffn(index)?,
            &buffer.ffn_r,
            &buffer.ffn_v,
            &buffer.ffn_x,
        )?,
        hook_op(Hook::PostFfnChannelMix(index))?,
        // Residual connection back into `x`.
        TensorOp::add(&buffer.ffn_x, &buffer.x)?,
        hook_op(Hook::PostFfn(index))?,
    ]);

    // Halve the residual activations every `rescale` layers to keep f16
    // magnitudes in range.
    if (index + 1).is_multiple_of(rescale) {
        ops.push(TensorOp::affine(&buffer.x, 0.5, 0.0)?);
    }

    Ok(TensorOp::List(ops))
}
959
960fn dispatch_header<F: Float>(
961    hooks: Arc<HookMap<F>>,
962    frame: Frame<F>,
963    head: Head,
964    head_x: TensorGpu<F, ReadWrite>,
965    num_header: usize,
966    head_op: TensorOp,
967) -> Result<TensorOp, TensorError> {
968    let hook_op = |hook: Hook| hook_op(&hooks, &hook, &frame);
969    let header = &frame.header;
970    let mut ops = vec![head_op];
971
972    if num_header > 0 {
973        ops.extend([
974            hook_op(Hook::PreHead)?,
975            TensorOp::layer_norm(
976                &head.layer_norm.w,
977                &head.layer_norm.b,
978                &head_x,
979                Model::LN_EPS,
980            )?,
981            hook_op(Hook::PostHeadLayerNorm)?,
982            head.w.matmul_op(
983                head_x.view(.., .., .., ..)?,
984                header.head_o.view(.., .., .., ..)?,
985                Activation::None,
986                turbo(num_header),
987            )?,
988            hook_op(Hook::PostHead)?,
989        ]);
990    }
991    Ok(TensorOp::List(ops))
992}
993
994impl<R: Reader> ModelBuilder<R> {
995    pub async fn build_v6(self) -> Result<Model, LoaderError> {
996        let ModelBuilder {
997            context,
998            model,
999            rescale,
1000            sep,
1001            lora,
1002            quant,
1003            ..
1004        } = self;
1005
1006        let rescale = rescale.unwrap_or(Model::DEFAULT_RESCALE);
1007        let sep = sep.unwrap_or(Model::DEFAULT_SEP);
1008
1009        let info = Loader::info(&model)?;
1010        let loader = Loader {
1011            context: context.clone(),
1012            model,
1013            lora,
1014        };
1015
1016        let embed = Embed {
1017            layer_norm: LayerNorm {
1018                w: loader.load_vector_f16("blocks.0.ln0.weight")?,
1019                b: loader.load_vector_f16("blocks.0.ln0.bias")?,
1020            },
1021            w: loader.load_matrix_f16_padded_cpu("emb.weight")?,
1022        };
1023
1024        let head = Head {
1025            layer_norm: LayerNorm {
1026                w: loader.load_vector_f16("ln_out.weight")?,
1027                b: loader.load_vector_f16("ln_out.bias")?,
1028            },
1029            w: Matrix::Fp16(loader.load_matrix_f16_padded("head.weight")?),
1030        };
1031
1032        let submission_index = Some(context.queue.submit(None));
1033        _ = context.device.poll(wgpu::PollType::Wait {
1034            submission_index,
1035            timeout: None,
1036        });
1037
1038        let load_matrix = |name: String, quant: Quant| loader.load_matrix(name, quant);
1039        let load_matrix_discount = |name: String, quant: Quant, discount: f32| {
1040            loader.load_matrix_discount(name, quant, discount)
1041        };
1042
1043        let mut layers = vec![];
1044        for layer in 0..info.num_layer {
1045            let quant = quant.get(&layer).copied().unwrap_or_default();
1046            let discount = 2.0_f32.powi(-((layer / rescale) as i32));
1047
1048            let att_layer_norm = LayerNorm {
1049                w: loader.load_vector_f16(format!("blocks.{layer}.ln1.weight"))?,
1050                b: loader.load_vector_f16(format!("blocks.{layer}.ln1.bias"))?,
1051            };
1052
1053            let att = format!("blocks.{layer}.att");
1054            let time_decay = loader.load_vector_f16(format!("{att}.time_decay"))?;
1055            let time_first = loader.load_vector_f32(format!("{att}.time_first"))?;
1056            let time_mix_x = loader.load_vector_f16(format!("{att}.time_mix_x"))?;
1057            let time_mix = {
1058                let time_mix: TensorGpu<_, _> = context.zeros([info.num_emb, 1, 5, 1]);
1059                let time_mix_w = loader.load_vector_f16(format!("{att}.time_mix_w"))?;
1060                let time_mix_k = loader.load_vector_f16(format!("{att}.time_mix_k"))?;
1061                let time_mix_v = loader.load_vector_f16(format!("{att}.time_mix_v"))?;
1062                let time_mix_r = loader.load_vector_f16(format!("{att}.time_mix_r"))?;
1063                let time_mix_g = loader.load_vector_f16(format!("{att}.time_mix_g"))?;
1064
1065                let ops = TensorOp::List(vec![
1066                    TensorOp::blit(&time_mix_w, time_mix.view(.., .., 0, ..)?)?,
1067                    TensorOp::blit(&time_mix_k, time_mix.view(.., .., 1, ..)?)?,
1068                    TensorOp::blit(&time_mix_v, time_mix.view(.., .., 2, ..)?)?,
1069                    TensorOp::blit(&time_mix_r, time_mix.view(.., .., 3, ..)?)?,
1070                    TensorOp::blit(&time_mix_g, time_mix.view(.., .., 4, ..)?)?,
1071                ]);
1072                context.queue.submit(context.encode(&ops));
1073                time_mix
1074            };
1075
1076            let time_decay_w1 = loader.load_matrix_f16(format!("{att}.time_decay_w1"))?;
1077            let time_decay_w2 = loader.load_matrix_f16(format!("{att}.time_decay_w2"))?;
1078
1079            let time_mix_w1 = loader.load_matrix_f16(format!("{att}.time_mix_w1"))?;
1080            let time_mix_w2 = loader.load_matrix_f16(format!("{att}.time_mix_w2"))?;
1081
1082            let group_norm = LayerNorm {
1083                w: loader
1084                    .load_vector_f16(format!("{att}.ln_x.weight"))?
1085                    .reshape(
1086                        TensorDimension::Auto,
1087                        TensorDimension::Size(info.num_head),
1088                        TensorDimension::Size(1),
1089                        TensorDimension::Size(1),
1090                    )?,
1091                b: loader
1092                    .load_vector_f16(format!("{att}.ln_x.bias"))?
1093                    .reshape(
1094                        TensorDimension::Auto,
1095                        TensorDimension::Size(info.num_head),
1096                        TensorDimension::Size(1),
1097                        TensorDimension::Size(1),
1098                    )?,
1099            };
1100
1101            let att = Att {
1102                time_decay,
1103                time_first,
1104                time_mix_x,
1105                time_mix,
1106                time_decay_w1: Matrix::Fp16(time_decay_w1),
1107                time_decay_w2: Matrix::Fp16(time_decay_w2),
1108                time_mix_w1: Matrix::Fp16(time_mix_w1),
1109                time_mix_w2: Matrix::Fp16(time_mix_w2),
1110                w_k: load_matrix(format!("{att}.key.weight"), quant)?,
1111                w_v: load_matrix(format!("{att}.value.weight"), quant)?,
1112                w_r: load_matrix(format!("{att}.receptance.weight"), quant)?,
1113                w_g: load_matrix(format!("{att}.gate.weight"), quant)?,
1114                w_o: load_matrix_discount(format!("{att}.output.weight"), quant, discount)?,
1115                group_norm,
1116            };
1117
1118            let ffn_layer_norm = LayerNorm {
1119                w: loader.load_vector_f16(format!("blocks.{layer}.ln2.weight"))?,
1120                b: loader.load_vector_f16(format!("blocks.{layer}.ln2.bias"))?,
1121            };
1122
1123            let ffn = format!("blocks.{layer}.ffn");
1124            let time_mix_k = loader.load_vector_f16(format!("{ffn}.time_mix_k"))?;
1125            let time_mix_r = loader.load_vector_f16(format!("{ffn}.time_mix_r"))?;
1126
1127            let ffn = Ffn {
1128                time_mix_k,
1129                time_mix_r,
1130                w_r: load_matrix(format!("{ffn}.receptance.weight"), quant)?,
1131                w_k: load_matrix(format!("{ffn}.key.weight"), quant)?,
1132                w_v: load_matrix_discount(format!("{ffn}.value.weight"), quant, discount)?,
1133            };
1134
1135            let submission_index = Some(context.queue.submit(None));
1136            _ = context.device.poll(wgpu::PollType::Wait {
1137                submission_index,
1138                timeout: None,
1139            });
1140
1141            layers.push(Layer {
1142                att_layer_norm,
1143                ffn_layer_norm,
1144                att,
1145                ffn,
1146            })
1147        }
1148
1149        let submission_index = Some(context.queue.submit(None));
1150        _ = context.device.poll(wgpu::PollType::Wait {
1151            submission_index,
1152            timeout: None,
1153        });
1154
1155        let tensor = ModelTensor {
1156            embed,
1157            head,
1158            layers,
1159        };
1160        let model = {
1161            let context = context.clone();
1162            let info = info.clone();
1163            Model {
1164                context,
1165                info,
1166                rescale,
1167                sep,
1168                tensor,
1169            }
1170        };
1171        Ok(model)
1172    }
1173}
1174
1175/// Read the pre-trained state from the file.
1176pub async fn read_state<R: Reader>(
1177    context: &Context,
1178    info: &ModelInfo,
1179    model: R,
1180) -> Result<TensorCpu<f32>, LoaderError> {
1181    let loader = Loader {
1182        context: context.clone(),
1183        model,
1184        lora: vec![],
1185    };
1186
1187    let head_size = info.num_emb / info.num_head;
1188    let data: TensorGpu<f32, _> = context.zeros([info.num_emb, head_size + 2, info.num_layer, 1]);
1189
1190    let mut ops = vec![];
1191    for layer in 0..info.num_layer {
1192        let matrix = loader.load_matrix_f16(format!("blocks.{layer}.att.time_state"))?;
1193        let state: TensorGpu<_, _> = context.tensor_init([head_size, info.num_head, head_size, 1]);
1194        let reshaped: TensorGpu<f16, _> = state.reshape(
1195            TensorDimension::Size(info.num_emb),
1196            TensorDimension::Size(head_size),
1197            TensorDimension::Size(1),
1198            TensorDimension::Auto,
1199        )?;
1200        ops.extend([
1201            TensorOp::transpose(&matrix, &state)?,
1202            TensorOp::blit(&reshaped, data.view(.., 1..head_size + 1, layer, ..)?)?,
1203        ]);
1204    }
1205    context.queue.submit(context.encode(&TensorOp::List(ops)));
1206
1207    Ok(data.back().await)
1208}