use std::{collections::HashMap, marker::PhantomData, sync::Arc};

#[cfg(not(target_arch = "wasm32"))]
use futures::future::BoxFuture;
#[cfg(target_arch = "wasm32")]
use futures::future::LocalBoxFuture;
use half::f16;
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use web_rwkv_derive::DeserializeSeed;
use wgpu::CommandBuffer;

use super::{
    infer::{RnnChunk, RnnInfo, RnnInput, RnnOutput, RnnOutputBatch, RnnRedirect, Token},
    loader::{Loader, LoaderError, Reader},
    model::{AsAny, ModelBuilder, ModelCustomInfo, ModelInfo, State as _},
    Dispatcher, Job, RuntimeError,
};
use crate::{
    context::Context,
    num::Float,
    runtime::model::Quant,
    tensor::{
        cache::ResourceCache,
        kind::ReadWrite,
        matrix::Matrix,
        ops::{Activation, TensorCommand, TensorOp},
        serialization::Seed,
        shape::{Shape, TensorDimension},
        DeepClone, IntoPackedCursors, TensorCpu, TensorError, TensorGpu, TensorGpuView, TensorInit,
        TensorReshape, TensorShape, TensorStack,
    },
};

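/// An RWKV v6 model: the [`Context`] it lives on, its [`ModelInfo`], and all loaded weights.
///
/// `rescale` is the interval (in layers) at which the residual activations are halved
/// during dispatch (presumably to keep `f16` intermediates in range), and `sep` is the
/// interval at which a [`TensorOp::Sep`] is inserted into the recorded op list.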
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Model {
    pub context: Context,
    pub info: ModelInfo,
    pub rescale: usize,
    pub sep: usize,
    pub tensor: ModelTensor,
}

impl Model {
    pub const LN_EPS: f32 = 1.0e-5;
    pub const GN_EPS: f32 = 64.0e-5;

    pub const DEFAULT_RESCALE: usize = 6;
    pub const DEFAULT_SEP: usize = 1024;
}

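/// Extra shape information specific to the v6 architecture.
///
/// `time_mix` and `time_decay` appear to be the inner (low-rank) dimensions of the
/// token-shift and decay adapter matrices (`time_mix_w1/w2`, `time_decay_w1/w2`);
/// they size the corresponding intermediate buffers in [`Runtime`].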
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct CustomInfo {
    pub time_mix: usize,
    pub time_decay: usize,
}

#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct ModelTensor {
    pub embed: Embed,
    pub head: Head,
    pub layers: Vec<Layer>,
}

#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct LayerNorm {
    pub w: TensorGpu<f16, ReadWrite>,
    pub b: TensorGpu<f16, ReadWrite>,
}

#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Att {
    pub time_decay: TensorGpu<f16, ReadWrite>,
    pub time_first: TensorGpu<f32, ReadWrite>,

    pub time_mix_x: TensorGpu<f16, ReadWrite>,
    pub time_mix: TensorGpu<f16, ReadWrite>,

    pub time_decay_w1: Matrix,
    pub time_decay_w2: Matrix,
    pub time_mix_w1: Matrix,
    pub time_mix_w2: Matrix,

    pub w_k: Matrix,
    pub w_v: Matrix,
    pub w_r: Matrix,
    pub w_g: Matrix,
    pub w_o: Matrix,

    pub group_norm: LayerNorm,
}

#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Ffn {
    pub time_mix_k: TensorGpu<f16, ReadWrite>,
    pub time_mix_r: TensorGpu<f16, ReadWrite>,

    pub w_k: Matrix,
    pub w_v: Matrix,
    pub w_r: Matrix,
}

#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Layer {
    pub att_layer_norm: LayerNorm,
    pub ffn_layer_norm: LayerNorm,
    pub att: Att,
    pub ffn: Ffn,
}

#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Embed {
    pub layer_norm: LayerNorm,
    pub w: TensorCpu<f16>,
}

#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Head {
    pub layer_norm: LayerNorm,
    pub w: Matrix,
}

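/// Recurrent state of the model, one tensor per layer.
///
/// Each layer tensor has shape `[num_emb, head_size + 2, num_batch, 1]`; judging from
/// [`State::att`] and [`State::ffn`], the first `head_size + 1` slices along the second
/// dimension hold the attention token-shift and time-mix state, and the last slice holds
/// the FFN token-shift state.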
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct State {
    pub context: Context,
    pub info: ModelInfo,
    pub data: Vec<TensorGpu<f32, ReadWrite>>,
}

impl State {
    async fn back(&self, batch: usize) -> Result<TensorCpu<f32>, TensorError> {
        let context = &self.context;
        let mut tensors = Vec::with_capacity(self.info.num_layer);
        let mut encoder = context.device.create_command_encoder(&Default::default());
        for data in self.data.iter() {
            let shape = data.shape();
            let destination = context.tensor_init([shape[0], shape[1], 1, 1]);
            encoder.copy_tensor_batch(data, &destination, batch, 0)?;
            tensors.push(destination);
        }
        context.queue.submit(Some(encoder.finish()));

        let mut backed = Vec::with_capacity(tensors.len());
        for tensor in tensors {
            backed.push(tensor.back().await);
        }
        TensorCpu::stack(backed, 2)
    }
}

impl AsAny for State {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}

impl super::model::State for State {
    #[inline]
    fn num_batch(&self) -> usize {
        self.data[0].shape()[2]
    }

    #[inline]
    fn init_shape(&self) -> Shape {
        let info = &self.info;
        let head_size = info.num_emb / info.num_head;
        [info.num_emb, head_size + 2, info.num_layer, 1].into()
    }

    fn init(&self) -> TensorCpu<f32> {
        let shape = self.init_shape();
        let data = vec![0.0; shape.len()];
        TensorCpu::from_data(shape, data).unwrap()
    }

    fn att(&self, layer: usize) -> Result<TensorGpuView<'_, f32>, TensorError> {
        let head_size = self.info.num_emb / self.info.num_head;
        let end = head_size + 1;
        self.data[layer].view(.., 0..end, .., ..)
    }

    fn ffn(&self, layer: usize) -> Result<TensorGpuView<'_, f32>, TensorError> {
        let head_size = self.info.num_emb / self.info.num_head;
        let start = head_size + 1;
        self.data[layer].view(.., start, .., ..)
    }

    fn load(&self, tensor: TensorCpu<f32>, batch: usize) -> Result<(), TensorError> {
        let head_size = self.info.num_emb / self.info.num_head;
        tensor.check_shape([self.info.num_emb, head_size + 2, self.info.num_layer, 1])?;
        for (data, source) in self.data.iter().zip(tensor.split(2)?.into_iter()) {
            data.load_batch(&source, batch)?;
        }
        Ok(())
    }

    #[cfg(not(target_arch = "wasm32"))]
    fn back(&self, batch: usize) -> BoxFuture<'_, Result<TensorCpu<f32>, TensorError>> {
        Box::pin(self.back(batch))
    }

    #[cfg(target_arch = "wasm32")]
    fn back(&self, batch: usize) -> LocalBoxFuture<'_, Result<TensorCpu<f32>, TensorError>> {
        Box::pin(self.back(batch))
    }

    fn write(&self, tensor: TensorGpu<f32, ReadWrite>, batch: usize) -> Result<(), TensorError> {
        let head_size = self.info.num_emb / self.info.num_head;
        tensor.check_shape([self.info.num_emb, head_size + 2, self.info.num_layer, 1])?;

        let context = &self.context;
        let mut ops = Vec::with_capacity(self.data.len());
        for (layer, data) in self.data.iter().enumerate() {
            ops.push(TensorOp::blit(
                tensor.view(.., .., layer, ..)?,
                data.view(.., .., batch, ..)?,
            )?);
        }
        context.queue.submit(context.encode(&TensorOp::List(ops)));

        Ok(())
    }

    fn read(&self, batch: usize) -> Result<TensorGpu<f32, ReadWrite>, TensorError> {
        let context = &self.context;
        let head_size = self.info.num_emb / self.info.num_head;
        let shape = [self.info.num_emb, head_size + 2, self.info.num_layer, 1];
        let tensor: TensorGpu<_, _> = context.tensor_init(shape);

        let mut ops = Vec::with_capacity(self.data.len());
        for (layer, data) in self.data.iter().enumerate() {
            ops.push(TensorOp::blit(
                data.view(.., .., batch, ..)?,
                tensor.view(.., .., layer, ..)?,
            )?);
        }
        context.queue.submit(context.encode(&TensorOp::List(ops)));

        Ok(tensor)
    }

    fn embed(&self, layer: usize, backed: TensorCpu<f32>) -> Result<TensorCpu<f32>, TensorError> {
        backed.slice(.., 0, layer, ..)
    }
}

impl DeepClone for State {
    fn deep_clone(&self) -> Self {
        let data = self.data.iter().map(|tensor| tensor.deep_clone()).collect();
        Self {
            data,
            ..self.clone()
        }
    }
}

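/// Per-dispatch intermediate buffers, sized for a fixed number of tokens.
///
/// A [`Runtime`] is checked out from a [`ResourceCache`] keyed by `num_token`, so buffers
/// are reused across jobs that process the same number of tokens.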
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Runtime<F: Float> {
    pub cursors: TensorGpu<u32, ReadWrite>,
    pub input: TensorGpu<f16, ReadWrite>,

    pub x: TensorGpu<F, ReadWrite>,
    pub aux_x: TensorGpu<f32, ReadWrite>,

    pub att_x: TensorGpu<F, ReadWrite>,
    pub att_xx: TensorGpu<F, ReadWrite>,
    pub att_sx: TensorGpu<F, ReadWrite>,
    pub att_w: TensorGpu<F, ReadWrite>,
    pub att_k: TensorGpu<f32, ReadWrite>,
    pub att_v: TensorGpu<f32, ReadWrite>,
    pub att_r: TensorGpu<f32, ReadWrite>,
    pub att_g: TensorGpu<F, ReadWrite>,
    pub att_o: TensorGpu<F, ReadWrite>,

    pub time_mix_x: TensorGpu<F, ReadWrite>,
    pub time_mix_t: TensorGpu<F, ReadWrite>,
    pub time_mix: TensorGpu<F, ReadWrite>,
    pub time_decay: TensorGpu<f32, ReadWrite>,

    pub ffn_x: TensorGpu<F, ReadWrite>,
    pub ffn_kx: TensorGpu<F, ReadWrite>,
    pub ffn_rx: TensorGpu<F, ReadWrite>,
    pub ffn_k: TensorGpu<F, ReadWrite>,
    pub ffn_v: TensorGpu<F, ReadWrite>,
    pub ffn_r: TensorGpu<F, ReadWrite>,
}

impl<F: Float> Runtime<F> {
    pub fn new(context: &Context, info: &ModelInfo, num_token: usize) -> Self {
        let ModelCustomInfo::V6(custom) = info.custom else {
            unreachable!()
        };

        let shape = Shape::new(info.num_emb, num_token, 1, 1);
        let cursors_shape = Shape::new(num_token, 1, 1, 1);
        let hidden_shape = Shape::new(info.num_hidden, num_token, 1, 1);
        let time_mix_shape = Shape::new(info.num_emb, num_token, 5, 1);
        let time_mix_x_shape = Shape::new(custom.time_mix, 5, num_token, 1);
        let time_mix_t_shape = Shape::new(custom.time_mix, num_token, 5, 1);
        let time_decay_shape = Shape::new(custom.time_decay, num_token, 1, 1);

        Self {
            cursors: context.tensor_init(cursors_shape),
            input: context.tensor_init(shape),
            x: context.tensor_init(shape),
            aux_x: context.tensor_init(shape),
            att_x: context.tensor_init(shape),
            att_xx: context.tensor_init(shape),
            att_sx: context.tensor_init(time_mix_shape),
            att_w: context.tensor_init(time_decay_shape),
            att_k: context.tensor_init(shape),
            att_v: context.tensor_init(shape),
            att_r: context.tensor_init(shape),
            att_g: context.tensor_init(shape),
            att_o: context.tensor_init(shape),
            time_mix_x: context.tensor_init(time_mix_x_shape),
            time_mix_t: context.tensor_init(time_mix_t_shape),
            time_mix: context.tensor_init(time_mix_shape),
            time_decay: context.tensor_init(shape),
            ffn_x: context.tensor_init(shape),
            ffn_kx: context.tensor_init(shape),
            ffn_rx: context.tensor_init(shape),
            ffn_k: context.tensor_init(hidden_shape),
            ffn_v: context.tensor_init(shape),
            ffn_r: context.tensor_init(shape),
        }
    }
}

#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Header<F: Float> {
    pub head_x: TensorGpu<F, ReadWrite>,
    pub head_o: TensorGpu<f32, ReadWrite>,
}

impl<F: Float> Header<F> {
    pub fn new(context: &Context, info: &ModelInfo, num_header: usize) -> Self {
        let head_shape = Shape::new(info.num_emb, num_header, 1, 1);
        let output_shape = Shape::new(info.num_vocab_padded(), num_header, 1, 1);

        Self {
            head_x: context.tensor_init(head_shape),
            head_o: context.tensor_init(output_shape),
        }
    }
}

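/// Points in the per-token forward pass where a custom [`TensorOp`] can be injected
/// via a [`HookMap`]. Variants carrying a `usize` refer to a layer index.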
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Hook {
    PostEmbedLoaded,
    PostEmbedLayerNorm,
    PreAtt(usize),
    PostAttLayerNorm(usize),
    PreAttTokenShift(usize),
    PostAttTokenShift(usize),
    PreAttTokenShiftAdapt(usize),
    PostAttTokenShiftAdapt(usize),
    PostAttTokenShiftAdaptActivate(usize),
    PreAttGatedTokenShift(usize),
    PostAttGatedTokenShift(usize),
    PreAttTimeDecayAdapt(usize),
    PostAttTimeDecayAdapt(usize),
    PostAttTimeDecayAdaptActivate(usize),
    PreAttTimeDecayActivate(usize),
    PostAttTimeDecayActivate(usize),
    PreAttLinear(usize),
    PostAttLinear(usize),
    PreAttTimeMix(usize),
    PostAttTimeMix(usize),
    PreAttGate(usize),
    PostAttGate(usize),
    PreAttOut(usize),
    PostAttOut(usize),
    PostAtt(usize),
    PreFfn(usize),
    PostFfnLayerNorm(usize),
    PreFfnTokenShift(usize),
    PostFfnTokenShift(usize),
    PreFfnLinear(usize),
    PostFfnLinear(usize),
    PostFfnActivate(usize),
    PreFfnChannelMix(usize),
    PostFfnChannelMix(usize),
    PostFfn(usize),
    PreHead,
    PostHeadLayerNorm,
    PostHead,
}

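/// A pre-encoded inference job: the recorded command buffers plus the tensors that still
/// need to be filled ([`RnnJob::load`]) and read back ([`RnnJob::back`]).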
pub struct RnnJob {
    commands: Vec<CommandBuffer>,
    redirect: RnnRedirect,

    embed: TensorCpu<f16>,

    cursors: TensorGpu<u32, ReadWrite>,
    input: TensorGpu<f16, ReadWrite>,
    output: TensorGpu<f32, ReadWrite>,
}

impl Job for RnnJob {
    type Input = RnnInput;
    type Output = RnnOutput;

    fn load(&self, input: &RnnChunk) -> Result<(), RuntimeError> {
        if input.num_token() == 0 {
            return Ok(());
        }

        let stack: Vec<TensorCpu<f16>> = input
            .iter()
            .map(|chunk| {
                let num_emb = self.embed.shape()[0];
                let data = self.embed.data();
                let data = chunk
                    .iter()
                    .map(|token| match token {
                        &Token::Token(token) => {
                            let start = num_emb * token as usize;
                            let end = start + num_emb;
                            let data = data[start..end].to_vec();
                            TensorCpu::from_data_1d(data)
                        }
                        Token::Embed(tensor) => tensor.clone(),
                    })
                    .collect_vec();
                match TensorCpu::stack(data, 1) {
                    Ok(tensor) => tensor,
                    Err(_) => TensorCpu::init([num_emb, 0, 1, 1]),
                }
            })
            .collect();
        let stack = TensorStack::try_from(stack)?;

        let cursors = stack.cursors.clone().into_cursors();
        let cursors = TensorCpu::from_data(self.cursors.shape(), cursors)?;
        self.cursors.load(&cursors)?;
        self.input.load(&stack.tensor)?;

        Ok(())
    }

    fn submit(&mut self) {
        let commands = std::mem::take(&mut self.commands);
        self.output.context.queue.submit(commands);
    }

    async fn back(self) -> Result<Self::Output, RuntimeError> {
        let output = self.output.back().await;
        let batches: Vec<_> = self
            .redirect
            .outputs
            .into_iter()
            .map(|(start, end)| output.slice(.., start..end, .., ..))
            .try_collect()?;
        let batches = batches.into_iter().map(RnnOutputBatch).collect();
        Ok(RnnOutput(batches))
    }
}

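/// Everything a hook gets to look at: the shared [`State`] plus the [`Runtime`] and
/// [`Header`] buffers of the job being built.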
#[derive(Debug, Clone)]
pub struct Frame<F: Float> {
    pub state: State,
    pub buffer: Arc<Runtime<F>>,
    pub header: Arc<Header<F>>,
}

pub type HookFn<F> = Box<dyn Fn(Frame<F>) -> Result<TensorOp, TensorError> + Send + Sync>;
pub type HookMap<F> = HashMap<Hook, HookFn<F>>;

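/// A dispatchable v6 model bundle: the [`Model`], its [`State`], optional hooks, and
/// caches of [`Runtime`]/[`Header`] buffers keyed by token and header counts.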
#[derive(Clone)]
pub struct Bundle<F: Float> {
    model: Model,
    state: State,
    hooks: Arc<HookMap<F>>,
    buffers: ResourceCache<usize, Runtime<F>>,
    headers: ResourceCache<usize, Header<F>>,
    phantom: PhantomData<F>,
}

impl<F: Float> Bundle<F> {
    pub fn new(model: Model, num_batch: usize) -> Self {
        let context = model.context.clone();
        let info = model.info.clone();
        let state = {
            let head_size = info.num_emb / info.num_head;
            let shape = Shape::new(info.num_emb, head_size + 2, num_batch, 1);
            let data = (0..info.num_layer).map(|_| context.zeros(shape)).collect();
            State {
                context,
                info,
                data,
            }
        };
        Self {
            model,
            state,
            hooks: Default::default(),
            buffers: ResourceCache::new(4),
            headers: ResourceCache::new(4),
            phantom: PhantomData,
        }
    }

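    /// Like [`Bundle::new`], but with a custom [`HookMap`] installed.
    ///
    /// A minimal sketch of registering a hook; the hook body below is only an
    /// illustration, any op returning `Result<TensorOp, TensorError>` works:
    ///
    /// ```ignore
    /// let mut hooks: HookMap<f16> = HookMap::new();
    /// hooks.insert(
    ///     Hook::PostAtt(0),
    ///     Box::new(|frame: Frame<f16>| TensorOp::blit(&frame.buffer.att_o, &frame.buffer.aux_x)),
    /// );
    /// let bundle = Bundle::new_with_hooks(model, 1, hooks);
    /// ```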
    pub fn new_with_hooks(model: Model, num_batch: usize, hooks: HookMap<F>) -> Self {
        Self {
            hooks: Arc::new(hooks),
            ..Self::new(model, num_batch)
        }
    }

    fn checkout_buffer(
        &self,
        context: &Context,
        info: &ModelInfo,
        num_token: usize,
    ) -> Arc<Runtime<F>> {
        self.buffers
            .checkout(num_token, || Runtime::new(context, info, num_token))
    }

    fn checkout_header(
        &self,
        context: &Context,
        info: &ModelInfo,
        num_header: usize,
    ) -> Arc<Header<F>> {
        self.headers
            .checkout(num_header, || Header::new(context, info, num_header))
    }
}

impl<F: Float> super::model::Bundle for Bundle<F> {
    #[inline]
    fn info(&self) -> ModelInfo {
        self.model.info.clone()
    }

    #[inline]
    fn state(&self) -> impl super::model::State + AsAny + 'static {
        self.state.clone()
    }

    #[inline]
    fn model(&self) -> impl Serialize + 'static {
        self.model.clone()
    }
}

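/// Returns `true` when `num_token` is a whole multiple of `MIN_TOKEN_CHUNK_SIZE`;
/// passed to the matmul ops to select their turbo path.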
fn turbo(num_token: usize) -> bool {
    num_token.is_multiple_of(super::infer::rnn::MIN_TOKEN_CHUNK_SIZE)
}

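/// Resolve `hook` against the map: run the registered callback if present, otherwise
/// produce an empty op.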
fn hook_op<F: Float>(
    hooks: &HookMap<F>,
    hook: &Hook,
    frame: &Frame<F>,
) -> Result<TensorOp, TensorError> {
    match hooks.get(hook) {
        Some(f) => f(frame.clone()),
        None => Ok(TensorOp::empty()),
    }
}

impl<F: Float> Dispatcher<RnnJob> for Bundle<F> {
    type Info = RnnInfo;

    fn dispatch(&self, seed: Self::Info) -> Result<RnnJob, RuntimeError> {
        let model = &self.model;
        let state = &self.state;
        let context = &model.context;
        let info = &model.info;
        let tensor = &model.tensor;

        let num_token = seed.num_token();
        let head_size = info.num_emb / info.num_head;

        let redirect = seed.redirect();
        let num_header = redirect.headers.len();

        let buffer = self.checkout_buffer(context, info, num_token);
        let header = self.checkout_header(context, info, num_header);
        let frame = Frame {
            state: state.clone(),
            buffer: buffer.clone(),
            header: header.clone(),
        };

        context.maintain();
        self.buffers.maintain();
        self.headers.maintain();

        if num_token == 0 {
            return Ok(RnnJob {
                commands: vec![],
                redirect,
                embed: model.tensor.embed.w.clone(),
                cursors: buffer.cursors.clone(),
                input: buffer.input.clone(),
                output: header.head_o.clone(),
            });
        }

        #[cfg(feature = "trace")]
        let _span = tracing::trace_span!("build").entered();

        let (head_op, head_x) = redirect.op(&buffer.x, &header.head_x)?;

        let hook_op = |hook: Hook| hook_op(&self.hooks, &hook, &frame);
        let mut ops = vec![];

        {
            #[cfg(feature = "trace")]
            let _span = tracing::trace_span!("embed").entered();

            ops.extend([
                hook_op(Hook::PostEmbedLoaded)?,
                TensorOp::layer_norm(
                    &tensor.embed.layer_norm.w,
                    &tensor.embed.layer_norm.b,
                    &buffer.input,
                    Model::LN_EPS,
                )?,
                TensorOp::blit(&buffer.input, &buffer.x)?,
                hook_op(Hook::PostEmbedLayerNorm)?,
            ]);
        };

        for (index, layer) in tensor.layers.iter().enumerate() {
            #[cfg(feature = "trace")]
            let _span = tracing::trace_span!("layer", index).entered();

            let hooks = self.hooks.clone();
            let frame = frame.clone();
            let layer = layer.clone();

            let op = dispatch_layer(
                hooks,
                frame,
                layer,
                index,
                num_token,
                head_size,
                model.rescale,
            )?;
            ops.push(op);

            if (index + 1) % model.sep == 0 {
                ops.push(TensorOp::Sep);
            }
        }

        {
            #[cfg(feature = "trace")]
            let _span = tracing::trace_span!("header").entered();

            let hooks = self.hooks.clone();
            let frame = frame.clone();
            let head = model.tensor.head.clone();

            let op = dispatch_header(hooks, frame, head, head_x, num_header, head_op)?;
            ops.push(op);
        }

        let commands = {
            #[cfg(feature = "trace")]
            let _span = tracing::trace_span!("encode").entered();
            context.encode(&TensorOp::List(ops))
        };

        Ok(RnnJob {
            commands,
            redirect,
            embed: model.tensor.embed.w.clone(),
            cursors: buffer.cursors.clone(),
            input: buffer.input.clone(),
            output: header.head_o.clone(),
        })
    }
}

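/// Build the [`TensorOp`] list for one v6 block: attention (token shift, low-rank
/// time-mix/decay adapters, time-mix, gate, output projection) followed by the
/// channel-mix FFN. Every `rescale` layers the residual `x` is halved, matching the
/// `discount` applied to the output weights at load time.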
#[allow(clippy::too_many_arguments)]
fn dispatch_layer<F: Float>(
    hooks: Arc<HookMap<F>>,
    frame: Frame<F>,
    layer: Layer,
    index: usize,
    num_token: usize,
    head_size: usize,
    rescale: usize,
) -> Result<TensorOp, TensorError> {
    let hook_op = |hook: Hook| hook_op(&hooks, &hook, &frame);
    let Frame { state, buffer, .. } = &frame;

    let time_first = layer.att.time_first.reshape(
        TensorDimension::Size(head_size),
        TensorDimension::Auto,
        TensorDimension::Size(1),
        TensorDimension::Size(1),
    )?;
    let time_decay = buffer.time_decay.reshape(
        TensorDimension::Size(head_size),
        TensorDimension::Auto,
        TensorDimension::Size(num_token),
        TensorDimension::Size(1),
    )?;
    let time_mix_x = buffer.time_mix_x.reshape(
        TensorDimension::Auto,
        TensorDimension::Size(num_token),
        TensorDimension::Size(1),
        TensorDimension::Size(1),
    )?;
    let aux_x = buffer.aux_x.reshape(
        TensorDimension::Size(head_size),
        TensorDimension::Auto,
        TensorDimension::Size(num_token),
        TensorDimension::Size(1),
    )?;
    let att_k = buffer.att_k.reshape(
        TensorDimension::Size(head_size),
        TensorDimension::Auto,
        TensorDimension::Size(num_token),
        TensorDimension::Size(1),
    )?;
    let att_v = buffer.att_v.reshape(
        TensorDimension::Size(head_size),
        TensorDimension::Auto,
        TensorDimension::Size(num_token),
        TensorDimension::Size(1),
    )?;
    let att_r = buffer.att_r.reshape(
        TensorDimension::Size(head_size),
        TensorDimension::Auto,
        TensorDimension::Size(num_token),
        TensorDimension::Size(1),
    )?;

    let mut ops = vec![];

    ops.extend([
        TensorOp::blit(&buffer.x, &buffer.att_x)?,
        hook_op(Hook::PreAtt(index))?,
        TensorOp::layer_norm(
            &layer.att_layer_norm.w,
            &layer.att_layer_norm.b,
            &buffer.att_x,
            Model::LN_EPS,
        )?,
        hook_op(Hook::PostAttLayerNorm(index))?,
        hook_op(Hook::PreAttTokenShift(index))?,
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.att.time_mix_x,
            state.att(index)?,
            &buffer.att_x,
            &buffer.att_xx,
            true,
        )?,
        hook_op(Hook::PostAttTokenShift(index))?,
        hook_op(Hook::PreAttTokenShiftAdapt(index))?,
        layer.att.time_mix_w1.matmul_op(
            &buffer.att_xx,
            &time_mix_x,
            Activation::Tanh,
            turbo(num_token),
        )?,
        TensorOp::transpose(&buffer.time_mix_x, &buffer.time_mix_t)?,
        hook_op(Hook::PostAttTokenShiftAdaptActivate(index))?,
        layer.att.time_mix_w2.matmul_op(
            &buffer.time_mix_t,
            &buffer.time_mix,
            Activation::None,
            turbo(num_token),
        )?,
        hook_op(Hook::PostAttTokenShiftAdapt(index))?,
        TensorOp::add(&layer.att.time_mix, &buffer.time_mix)?,
        hook_op(Hook::PreAttGatedTokenShift(index))?,
        TensorOp::token_shift(
            &buffer.cursors,
            &buffer.time_mix,
            state.att(index)?,
            &buffer.att_x,
            &buffer.att_sx,
            true,
        )?,
        hook_op(Hook::PostAttGatedTokenShift(index))?,
        hook_op(Hook::PreAttLinear(index))?,
        layer.att.w_k.matmul_op(
            buffer.att_sx.view(.., .., 1, ..)?,
            &buffer.att_k,
            Activation::None,
            turbo(num_token),
        )?,
        layer.att.w_v.matmul_op(
            buffer.att_sx.view(.., .., 2, ..)?,
            &buffer.att_v,
            Activation::None,
            turbo(num_token),
        )?,
        layer.att.w_r.matmul_op(
            buffer.att_sx.view(.., .., 3, ..)?,
            &buffer.att_r,
            Activation::None,
            turbo(num_token),
        )?,
        layer.att.w_g.matmul_op(
            buffer.att_sx.view(.., .., 4, ..)?,
            &buffer.att_g,
            Activation::None,
            turbo(num_token),
        )?,
        hook_op(Hook::PostAttLinear(index))?,
        hook_op(Hook::PreAttTimeDecayAdapt(index))?,
        layer.att.time_decay_w1.matmul_op(
            buffer.att_sx.view(.., .., 0, ..)?,
            &buffer.att_w,
            Activation::Tanh,
            turbo(num_token),
        )?,
        hook_op(Hook::PostAttTimeDecayAdaptActivate(index))?,
        layer.att.time_decay_w2.matmul_op(
            &buffer.att_w,
            &buffer.time_decay,
            Activation::None,
            turbo(num_token),
        )?,
        hook_op(Hook::PostAttTimeDecayAdapt(index))?,
        TensorOp::add(&layer.att.time_decay, &buffer.time_decay)?,
        hook_op(Hook::PreAttTimeDecayActivate(index))?,
        TensorOp::activate(&buffer.time_decay, Activation::StableExp)?,
        hook_op(Hook::PostAttTimeDecayActivate(index))?,
        hook_op(Hook::PreAttTimeMix(index))?,
        TensorOp::blit(&buffer.att_x, &buffer.aux_x)?,
        TensorOp::time_mix_v6(
            &buffer.cursors,
            &time_decay,
            &time_first,
            state.att(index)?,
            &att_k,
            &att_v,
            &att_r,
            &aux_x,
        )?,
        TensorOp::group_norm(
            &layer.att.group_norm.w,
            &layer.att.group_norm.b,
            &aux_x,
            Model::GN_EPS,
        )?,
        TensorOp::blit(&buffer.aux_x, &buffer.att_x)?,
        hook_op(Hook::PostAttTimeMix(index))?,
        hook_op(Hook::PreAttGate(index))?,
        TensorOp::mul_activate(
            &buffer.att_g,
            &buffer.att_x,
            Activation::Silu,
            Activation::None,
            Activation::None,
        )?,
        hook_op(Hook::PostAttGate(index))?,
        hook_op(Hook::PreAttOut(index))?,
        layer.att.w_o.matmul_op(
            &buffer.att_x,
            &buffer.att_o,
            Activation::None,
            turbo(num_token),
        )?,
        hook_op(Hook::PostAttOut(index))?,
        TensorOp::add(&buffer.att_o, &buffer.x)?,
        hook_op(Hook::PostAtt(index))?,
    ]);

    ops.extend([
        TensorOp::blit(&buffer.x, &buffer.ffn_x)?,
        hook_op(Hook::PreFfn(index))?,
        TensorOp::layer_norm(
            &layer.ffn_layer_norm.w,
            &layer.ffn_layer_norm.b,
            &buffer.ffn_x,
            Model::LN_EPS,
        )?,
        hook_op(Hook::PostFfnLayerNorm(index))?,
        hook_op(Hook::PreFfnTokenShift(index))?,
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.ffn.time_mix_k,
            state.ffn(index)?,
            &buffer.ffn_x,
            &buffer.ffn_kx,
            true,
        )?,
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.ffn.time_mix_r,
            state.ffn(index)?,
            &buffer.ffn_x,
            &buffer.ffn_rx,
            true,
        )?,
        hook_op(Hook::PostFfnTokenShift(index))?,
        hook_op(Hook::PreFfnLinear(index))?,
        layer.ffn.w_k.matmul_op(
            &buffer.ffn_kx,
            &buffer.ffn_k,
            Activation::SquaredRelu,
            turbo(num_token),
        )?,
        hook_op(Hook::PostFfnActivate(index))?,
        layer.ffn.w_v.matmul_op_sparse(
            &buffer.ffn_k,
            &buffer.ffn_v,
            Activation::None,
            turbo(num_token),
        )?,
        layer.ffn.w_r.matmul_op(
            &buffer.ffn_rx,
            &buffer.ffn_r,
            Activation::None,
            turbo(num_token),
        )?,
        hook_op(Hook::PostFfnLinear(index))?,
        hook_op(Hook::PreFfnChannelMix(index))?,
        TensorOp::channel_mix(
            &buffer.cursors,
            state.ffn(index)?,
            &buffer.ffn_r,
            &buffer.ffn_v,
            &buffer.ffn_x,
        )?,
        hook_op(Hook::PostFfnChannelMix(index))?,
        TensorOp::add(&buffer.ffn_x, &buffer.x)?,
        hook_op(Hook::PostFfn(index))?,
    ]);

    if (index + 1).is_multiple_of(rescale) {
        ops.push(TensorOp::affine(&buffer.x, 0.5, 0.0)?);
    }

    Ok(TensorOp::List(ops))
}

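/// Build the output-head ops: gather the requested token positions into `head_x`
/// (`head_op`), then layer-norm and project them through the head matrix into `head_o`.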
fn dispatch_header<F: Float>(
    hooks: Arc<HookMap<F>>,
    frame: Frame<F>,
    head: Head,
    head_x: TensorGpu<F, ReadWrite>,
    num_header: usize,
    head_op: TensorOp,
) -> Result<TensorOp, TensorError> {
    let hook_op = |hook: Hook| hook_op(&hooks, &hook, &frame);
    let header = &frame.header;
    let mut ops = vec![head_op];

    if num_header > 0 {
        ops.extend([
            hook_op(Hook::PreHead)?,
            TensorOp::layer_norm(
                &head.layer_norm.w,
                &head.layer_norm.b,
                &head_x,
                Model::LN_EPS,
            )?,
            hook_op(Hook::PostHeadLayerNorm)?,
            head.w.matmul_op(
                head_x.view(.., .., .., ..)?,
                header.head_o.view(.., .., .., ..)?,
                Activation::None,
                turbo(num_header),
            )?,
            hook_op(Hook::PostHead)?,
        ]);
    }
    Ok(TensorOp::List(ops))
}

impl<R: Reader> ModelBuilder<R> {
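    /// Load all v6 weights from the reader and assemble a [`Model`].
    ///
    /// Per-layer matrices are quantized according to `quant`, and the `output.weight` /
    /// `ffn.value.weight` matrices are pre-scaled by `2^-(layer / rescale)`, matching the
    /// periodic halving of the residual in `dispatch_layer`.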
    pub async fn build_v6(self) -> Result<Model, LoaderError> {
        let ModelBuilder {
            context,
            model,
            rescale,
            sep,
            lora,
            quant,
            ..
        } = self;

        let rescale = rescale.unwrap_or(Model::DEFAULT_RESCALE);
        let sep = sep.unwrap_or(Model::DEFAULT_SEP);

        let info = Loader::info(&model)?;
        let loader = Loader {
            context: context.clone(),
            model,
            lora,
        };

        let embed = Embed {
            layer_norm: LayerNorm {
                w: loader.load_vector_f16("blocks.0.ln0.weight")?,
                b: loader.load_vector_f16("blocks.0.ln0.bias")?,
            },
            w: loader.load_matrix_f16_padded_cpu("emb.weight")?,
        };

        let head = Head {
            layer_norm: LayerNorm {
                w: loader.load_vector_f16("ln_out.weight")?,
                b: loader.load_vector_f16("ln_out.bias")?,
            },
            w: Matrix::Fp16(loader.load_matrix_f16_padded("head.weight")?),
        };

        let submission_index = Some(context.queue.submit(None));
        _ = context.device.poll(wgpu::PollType::Wait {
            submission_index,
            timeout: None,
        });

        let load_matrix = |name: String, quant: Quant| loader.load_matrix(name, quant);
        let load_matrix_discount = |name: String, quant: Quant, discount: f32| {
            loader.load_matrix_discount(name, quant, discount)
        };

        let mut layers = vec![];
        for layer in 0..info.num_layer {
            let quant = quant.get(&layer).copied().unwrap_or_default();
            let discount = 2.0_f32.powi(-((layer / rescale) as i32));

            let att_layer_norm = LayerNorm {
                w: loader.load_vector_f16(format!("blocks.{layer}.ln1.weight"))?,
                b: loader.load_vector_f16(format!("blocks.{layer}.ln1.bias"))?,
            };

            let att = format!("blocks.{layer}.att");
            let time_decay = loader.load_vector_f16(format!("{att}.time_decay"))?;
            let time_first = loader.load_vector_f32(format!("{att}.time_first"))?;
            let time_mix_x = loader.load_vector_f16(format!("{att}.time_mix_x"))?;
            let time_mix = {
                let time_mix: TensorGpu<_, _> = context.zeros([info.num_emb, 1, 5, 1]);
                let time_mix_w = loader.load_vector_f16(format!("{att}.time_mix_w"))?;
                let time_mix_k = loader.load_vector_f16(format!("{att}.time_mix_k"))?;
                let time_mix_v = loader.load_vector_f16(format!("{att}.time_mix_v"))?;
                let time_mix_r = loader.load_vector_f16(format!("{att}.time_mix_r"))?;
                let time_mix_g = loader.load_vector_f16(format!("{att}.time_mix_g"))?;

                let ops = TensorOp::List(vec![
                    TensorOp::blit(&time_mix_w, time_mix.view(.., .., 0, ..)?)?,
                    TensorOp::blit(&time_mix_k, time_mix.view(.., .., 1, ..)?)?,
                    TensorOp::blit(&time_mix_v, time_mix.view(.., .., 2, ..)?)?,
                    TensorOp::blit(&time_mix_r, time_mix.view(.., .., 3, ..)?)?,
                    TensorOp::blit(&time_mix_g, time_mix.view(.., .., 4, ..)?)?,
                ]);
                context.queue.submit(context.encode(&ops));
                time_mix
            };

            let time_decay_w1 = loader.load_matrix_f16(format!("{att}.time_decay_w1"))?;
            let time_decay_w2 = loader.load_matrix_f16(format!("{att}.time_decay_w2"))?;

            let time_mix_w1 = loader.load_matrix_f16(format!("{att}.time_mix_w1"))?;
            let time_mix_w2 = loader.load_matrix_f16(format!("{att}.time_mix_w2"))?;

            let group_norm = LayerNorm {
                w: loader
                    .load_vector_f16(format!("{att}.ln_x.weight"))?
                    .reshape(
                        TensorDimension::Auto,
                        TensorDimension::Size(info.num_head),
                        TensorDimension::Size(1),
                        TensorDimension::Size(1),
                    )?,
                b: loader
                    .load_vector_f16(format!("{att}.ln_x.bias"))?
                    .reshape(
                        TensorDimension::Auto,
                        TensorDimension::Size(info.num_head),
                        TensorDimension::Size(1),
                        TensorDimension::Size(1),
                    )?,
            };

            let att = Att {
                time_decay,
                time_first,
                time_mix_x,
                time_mix,
                time_decay_w1: Matrix::Fp16(time_decay_w1),
                time_decay_w2: Matrix::Fp16(time_decay_w2),
                time_mix_w1: Matrix::Fp16(time_mix_w1),
                time_mix_w2: Matrix::Fp16(time_mix_w2),
                w_k: load_matrix(format!("{att}.key.weight"), quant)?,
                w_v: load_matrix(format!("{att}.value.weight"), quant)?,
                w_r: load_matrix(format!("{att}.receptance.weight"), quant)?,
                w_g: load_matrix(format!("{att}.gate.weight"), quant)?,
                w_o: load_matrix_discount(format!("{att}.output.weight"), quant, discount)?,
                group_norm,
            };

            let ffn_layer_norm = LayerNorm {
                w: loader.load_vector_f16(format!("blocks.{layer}.ln2.weight"))?,
                b: loader.load_vector_f16(format!("blocks.{layer}.ln2.bias"))?,
            };

            let ffn = format!("blocks.{layer}.ffn");
            let time_mix_k = loader.load_vector_f16(format!("{ffn}.time_mix_k"))?;
            let time_mix_r = loader.load_vector_f16(format!("{ffn}.time_mix_r"))?;

            let ffn = Ffn {
                time_mix_k,
                time_mix_r,
                w_r: load_matrix(format!("{ffn}.receptance.weight"), quant)?,
                w_k: load_matrix(format!("{ffn}.key.weight"), quant)?,
                w_v: load_matrix_discount(format!("{ffn}.value.weight"), quant, discount)?,
            };

            let submission_index = Some(context.queue.submit(None));
            _ = context.device.poll(wgpu::PollType::Wait {
                submission_index,
                timeout: None,
            });

            layers.push(Layer {
                att_layer_norm,
                ffn_layer_norm,
                att,
                ffn,
            })
        }

        let submission_index = Some(context.queue.submit(None));
        _ = context.device.poll(wgpu::PollType::Wait {
            submission_index,
            timeout: None,
        });

        let tensor = ModelTensor {
            embed,
            head,
            layers,
        };
        let model = {
            let context = context.clone();
            let info = info.clone();
            Model {
                context,
                info,
                rescale,
                sep,
                tensor,
            }
        };
        Ok(model)
    }
}

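/// Read the initial recurrent state (`att.time_state`) embedded in some v6 model files,
/// in the same layout that [`State::load`] expects.
///
/// A minimal usage sketch (assuming `file` implements [`Reader`] and `state` came from a
/// [`Bundle`]):
///
/// ```ignore
/// let init = read_state(&context, &info, file).await?;
/// state.load(init, 0)?;
/// ```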
pub async fn read_state<R: Reader>(
    context: &Context,
    info: &ModelInfo,
    model: R,
) -> Result<TensorCpu<f32>, LoaderError> {
    let loader = Loader {
        context: context.clone(),
        model,
        lora: vec![],
    };

    let head_size = info.num_emb / info.num_head;
    let data: TensorGpu<f32, _> = context.zeros([info.num_emb, head_size + 2, info.num_layer, 1]);

    let mut ops = vec![];
    for layer in 0..info.num_layer {
        let matrix = loader.load_matrix_f16(format!("blocks.{layer}.att.time_state"))?;
        let state: TensorGpu<_, _> = context.tensor_init([head_size, info.num_head, head_size, 1]);
        let reshaped: TensorGpu<f16, _> = state.reshape(
            TensorDimension::Size(info.num_emb),
            TensorDimension::Size(head_size),
            TensorDimension::Size(1),
            TensorDimension::Auto,
        )?;
        ops.extend([
            TensorOp::transpose(&matrix, &state)?,
            TensorOp::blit(&reshaped, data.view(.., 1..head_size + 1, layer, ..)?)?,
        ]);
    }
    context.queue.submit(context.encode(&TensorOp::List(ops)));

    Ok(data.back().await)
}