1use std::{collections::HashMap, marker::PhantomData, sync::Arc};
2
3#[cfg(not(target_arch = "wasm32"))]
4use futures::future::BoxFuture;
5#[cfg(target_arch = "wasm32")]
6use futures::future::LocalBoxFuture;
7use half::f16;
8use itertools::Itertools;
9use serde::{Deserialize, Serialize};
10use web_rwkv_derive::DeserializeSeed;
11use wgpu::CommandBuffer;
12
13use super::{
14 infer::{RnnChunk, RnnInfo, RnnInput, RnnOutput, RnnOutputBatch, RnnRedirect, Token},
15 loader::{Loader, LoaderError, Reader},
16 model::{AsAny, ModelBuilder, ModelInfo, Quant, State as _},
17 Dispatcher, Job, RuntimeError,
18};
19use crate::{
20 context::Context,
21 num::Float,
22 tensor::{
23 cache::ResourceCache,
24 kind::ReadWrite,
25 matrix::Matrix,
26 ops::{Activation, TensorCommand, TensorOp},
27 serialization::Seed,
28 shape::{Shape, TensorDimension},
29 DeepClone, IntoPackedCursors, TensorCpu, TensorError, TensorGpu, TensorGpuView, TensorInit,
30 TensorReshape, TensorShape, TensorStack,
31 },
32};
33
/// A loaded V5 model: metadata, dispatch settings and weight tensors.
///
/// Instances are produced by `ModelBuilder::build_v5`.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Model {
    /// Context the weights were loaded with; `dispatch` encodes its ops through it.
    pub context: Context,
    /// Model metadata (sizes such as `num_emb`, `num_head`, `num_layer`).
    pub info: ModelInfo,
    /// Rescale interval in layers: `dispatch_layer` appends an `affine(x, 0.5, 0.0)` op after
    /// every `rescale`-th layer, and `build_v5` discounts output weights by
    /// `2^-(layer / rescale)`.
    pub rescale: usize,
    /// Separator interval in layers: `dispatch` pushes a `TensorOp::Sep` after every
    /// `sep`-th layer.
    pub sep: usize,
    /// All weight tensors of the model.
    pub tensor: ModelTensor,
}
43
impl Model {
    /// Epsilon passed to every `TensorOp::layer_norm` in this file.
    pub const LN_EPS: f32 = 1.0e-5;
    /// Epsilon passed to `TensorOp::group_norm` in the attention block.
    pub const GN_EPS: f32 = 64.0e-5;

    /// Value of `rescale` used by `build_v5` when the builder does not specify one.
    pub const DEFAULT_RESCALE: usize = 6;
    /// Value of `sep` used by `build_v5` when the builder does not specify one.
    pub const DEFAULT_SEP: usize = 1024;
}
51
/// The complete set of weights of a [`Model`].
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct ModelTensor {
    /// Embedding table and its layer norm.
    pub embed: Embed,
    /// Output head and its layer norm.
    pub head: Head,
    /// Per-layer weights, in layer order (index 0 is the first layer dispatched).
    pub layers: Vec<Layer>,
}
59
/// Weight and bias vectors of a normalization layer.
///
/// Used both for `TensorOp::layer_norm` and, as `Att::group_norm`, for
/// `TensorOp::group_norm`.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct LayerNorm {
    /// Weight vector (loaded from a `*.weight` tensor).
    pub w: TensorGpu<f16, ReadWrite>,
    /// Bias vector (loaded from a `*.bias` tensor).
    pub b: TensorGpu<f16, ReadWrite>,
}
66
/// Weights of one attention (time-mix) block.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Att {
    /// Loaded from `{att}.time_decay` via `load_vector_exp_exp_f32`.
    pub time_decay: TensorGpu<f32, ReadWrite>,
    /// Loaded from `{att}.time_first`.
    pub time_first: TensorGpu<f32, ReadWrite>,

    /// Token-shift mix factor producing `att_kx`.
    pub time_mix_k: TensorGpu<f16, ReadWrite>,
    /// Token-shift mix factor producing `att_vx`.
    pub time_mix_v: TensorGpu<f16, ReadWrite>,
    /// Token-shift mix factor producing `att_rx`.
    pub time_mix_r: TensorGpu<f16, ReadWrite>,
    /// Token-shift mix factor producing `att_gx`.
    pub time_mix_g: TensorGpu<f16, ReadWrite>,

    /// Key projection (`{att}.key.weight`).
    pub w_k: Matrix,
    /// Value projection (`{att}.value.weight`).
    pub w_v: Matrix,
    /// Receptance projection (`{att}.receptance.weight`).
    pub w_r: Matrix,
    /// Gate projection (`{att}.gate.weight`).
    pub w_g: Matrix,
    /// Output projection (`{att}.output.weight`), loaded with the rescale discount applied.
    pub w_o: Matrix,

    /// Group norm parameters (`{att}.ln_x.*`), reshaped to `[auto, num_head, 1, 1]` at load.
    pub group_norm: LayerNorm,
}
86
/// Weights of one feed-forward (channel-mix) block.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Ffn {
    /// Token-shift mix factor producing `ffn_kx`.
    pub time_mix_k: TensorGpu<f16, ReadWrite>,
    /// Token-shift mix factor producing `ffn_rx`.
    pub time_mix_r: TensorGpu<f16, ReadWrite>,

    /// Key projection (`{ffn}.key.weight`); applied with a squared-ReLU activation.
    pub w_k: Matrix,
    /// Value projection (`{ffn}.value.weight`), loaded with the rescale discount applied.
    pub w_v: Matrix,
    /// Receptance projection (`{ffn}.receptance.weight`).
    pub w_r: Matrix,
}
97
/// Weights of a single model layer: an attention block followed by a feed-forward block.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Layer {
    /// Layer norm applied before attention (`blocks.{layer}.ln1`).
    pub att_layer_norm: LayerNorm,
    /// Layer norm applied before the feed-forward block (`blocks.{layer}.ln2`).
    pub ffn_layer_norm: LayerNorm,
    /// Attention weights.
    pub att: Att,
    /// Feed-forward weights.
    pub ffn: Ffn,
}
106
/// Embedding weights.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Embed {
    /// Layer norm applied to the embedded input (`blocks.0.ln0`).
    pub layer_norm: LayerNorm,
    /// Embedding table kept on the CPU (`emb.weight`); `RnnJob::load` copies one
    /// `num_emb`-wide row per token id out of it.
    pub w: TensorCpu<f16>,
}
113
/// Output head weights.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Head {
    /// Final layer norm (`ln_out`).
    pub layer_norm: LayerNorm,
    /// Output projection (`head.weight`), always loaded as an `Fp16` matrix by `build_v5`.
    pub w: Matrix,
}
120
/// Recurrent state of the model, stored on the GPU.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct State {
    /// Context that owns the state tensors.
    pub context: Context,
    /// Metadata of the model this state belongs to.
    pub info: ModelInfo,
    /// One tensor per layer. `Bundle::new` creates each with shape
    /// `[num_emb, head_size + 2, num_batch, 1]`; rows `0..head_size + 1` are viewed by
    /// `att` and row `head_size + 1` by `ffn`.
    pub data: Vec<TensorGpu<f32, ReadWrite>>,
}
128
129impl State {
130 async fn back(&self, batch: usize) -> Result<TensorCpu<f32>, TensorError> {
131 let context = &self.context;
132 let mut tensors = Vec::with_capacity(self.info.num_layer);
133 let mut encoder = context.device.create_command_encoder(&Default::default());
134 for data in self.data.iter() {
135 let shape = data.shape();
136 let destination = context.tensor_init([shape[0], shape[1], 1, 1]);
137 encoder.copy_tensor_batch(data, &destination, batch, 0)?;
138 tensors.push(destination);
139 }
140 context.queue.submit(Some(encoder.finish()));
141
142 let mut backed = Vec::with_capacity(tensors.len());
143 for tensor in tensors {
144 backed.push(tensor.back().await);
145 }
146 TensorCpu::stack(backed, 2)
147 }
148}
149
impl AsAny for State {
    /// Exposes the state as `&dyn Any` so callers can downcast to the concrete [`State`].
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
155
156impl super::model::State for State {
157 #[inline]
158 fn num_batch(&self) -> usize {
159 self.data[0].shape()[2]
160 }
161
162 #[inline]
163 fn init_shape(&self) -> Shape {
164 let info = &self.info;
165 let head_size = info.num_emb / info.num_head;
166 [info.num_emb, head_size + 2, info.num_layer, 1].into()
167 }
168
169 fn init(&self) -> TensorCpu<f32> {
170 let shape = self.init_shape();
171 let data = vec![0.0; shape.len()];
172 TensorCpu::from_data(shape, data).unwrap()
173 }
174
175 fn att(&self, layer: usize) -> Result<TensorGpuView<'_, f32>, TensorError> {
176 let head_size = self.info.num_emb / self.info.num_head;
177 let end = head_size + 1;
178 self.data[layer].view(.., 0..end, .., ..)
179 }
180
181 fn ffn(&self, layer: usize) -> Result<TensorGpuView<'_, f32>, TensorError> {
182 let head_size = self.info.num_emb / self.info.num_head;
183 let start = head_size + 1;
184 self.data[layer].view(.., start, .., ..)
185 }
186
187 fn load(&self, tensor: TensorCpu<f32>, batch: usize) -> Result<(), TensorError> {
188 let head_size = self.info.num_emb / self.info.num_head;
189 tensor.check_shape([self.info.num_emb, head_size + 2, self.info.num_layer, 1])?;
190 for (data, source) in self.data.iter().zip(tensor.split(2)?.into_iter()) {
191 data.load_batch(&source, batch)?;
192 }
193 Ok(())
194 }
195
196 #[cfg(not(target_arch = "wasm32"))]
197 fn back(&self, batch: usize) -> BoxFuture<'_, Result<TensorCpu<f32>, TensorError>> {
198 Box::pin(self.back(batch))
199 }
200
201 #[cfg(target_arch = "wasm32")]
202 fn back(&self, batch: usize) -> LocalBoxFuture<'_, Result<TensorCpu<f32>, TensorError>> {
203 Box::pin(self.back(batch))
204 }
205
206 fn write(&self, tensor: TensorGpu<f32, ReadWrite>, batch: usize) -> Result<(), TensorError> {
207 let head_size = self.info.num_emb / self.info.num_head;
208 tensor.check_shape([self.info.num_emb, head_size + 2, self.info.num_layer, 1])?;
209
210 let context = &self.context;
211 let mut ops = Vec::with_capacity(self.data.len());
212 for (layer, data) in self.data.iter().enumerate() {
213 ops.push(TensorOp::blit(
214 tensor.view(.., .., layer, ..)?,
215 data.view(.., .., batch, ..)?,
216 )?);
217 }
218 context.queue.submit(context.encode(&TensorOp::List(ops)));
219
220 Ok(())
221 }
222
223 fn read(&self, batch: usize) -> Result<TensorGpu<f32, ReadWrite>, TensorError> {
224 let context = &self.context;
225 let head_size = self.info.num_emb / self.info.num_head;
226 let shape = [self.info.num_emb, head_size + 2, self.info.num_layer, 1];
227 let tensor: TensorGpu<_, _> = context.tensor_init(shape);
228
229 let mut ops = Vec::with_capacity(self.data.len());
230 for (layer, data) in self.data.iter().enumerate() {
231 ops.push(TensorOp::blit(
232 data.view(.., .., batch, ..)?,
233 tensor.view(.., .., layer, ..)?,
234 )?);
235 }
236 context.queue.submit(context.encode(&TensorOp::List(ops)));
237
238 Ok(tensor)
239 }
240
241 fn embed(&self, layer: usize, backed: TensorCpu<f32>) -> Result<TensorCpu<f32>, TensorError> {
242 backed.slice(.., 0, layer, ..)
243 }
244}
245
246impl DeepClone for State {
247 fn deep_clone(&self) -> Self {
248 let data = self.data.iter().map(|tensor| tensor.deep_clone()).collect();
249 Self {
250 data,
251 ..self.clone()
252 }
253 }
254}
255
/// Scratch tensors used while running the layers for a fixed number of tokens.
///
/// `Runtime::new` gives every tensor the shape `[num_emb, num_token, 1, 1]`, except
/// `cursors` (`[num_token, 1, 1, 1]`) and `ffn_k` (`[num_hidden, num_token, 1, 1]`).
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Runtime<F: Float> {
    /// Packed cursors uploaded by `RnnJob::load`.
    pub cursors: TensorGpu<u32, ReadWrite>,
    /// Embedded input tokens uploaded by `RnnJob::load`.
    pub input: TensorGpu<f16, ReadWrite>,

    /// Hidden state carried from layer to layer.
    pub x: TensorGpu<F, ReadWrite>,
    /// `f32` scratch buffer used around the time-mix and group-norm ops.
    pub aux_x: TensorGpu<f32, ReadWrite>,

    /// Attention input (layer-normed copy of `x`); later reused for the gated result.
    pub att_x: TensorGpu<F, ReadWrite>,
    /// Token-shifted input of the key projection.
    pub att_kx: TensorGpu<F, ReadWrite>,
    /// Token-shifted input of the value projection.
    pub att_vx: TensorGpu<F, ReadWrite>,
    /// Token-shifted input of the receptance projection.
    pub att_rx: TensorGpu<F, ReadWrite>,
    /// Token-shifted input of the gate projection.
    pub att_gx: TensorGpu<F, ReadWrite>,
    /// Output of the key projection.
    pub att_k: TensorGpu<f32, ReadWrite>,
    /// Output of the value projection.
    pub att_v: TensorGpu<f32, ReadWrite>,
    /// Output of the receptance projection.
    pub att_r: TensorGpu<f32, ReadWrite>,
    /// Output of the gate projection.
    pub att_g: TensorGpu<F, ReadWrite>,
    /// Output of the attention output projection; added onto `x`.
    pub att_o: TensorGpu<F, ReadWrite>,

    /// Feed-forward input (layer-normed copy of `x`); also receives the channel-mix result.
    pub ffn_x: TensorGpu<F, ReadWrite>,
    /// Token-shifted input of the key projection.
    pub ffn_kx: TensorGpu<F, ReadWrite>,
    /// Token-shifted input of the receptance projection.
    pub ffn_rx: TensorGpu<F, ReadWrite>,
    /// Output of the key projection (hidden size).
    pub ffn_k: TensorGpu<F, ReadWrite>,
    /// Output of the value projection.
    pub ffn_v: TensorGpu<F, ReadWrite>,
    /// Output of the receptance projection.
    pub ffn_r: TensorGpu<F, ReadWrite>,
}
283
284impl<F: Float> Runtime<F> {
285 pub fn new(context: &Context, info: &ModelInfo, num_token: usize) -> Self {
286 let shape = Shape::new(info.num_emb, num_token, 1, 1);
287 let cursors_shape = Shape::new(num_token, 1, 1, 1);
288 let hidden_shape = Shape::new(info.num_hidden, num_token, 1, 1);
289
290 Self {
291 cursors: context.tensor_init(cursors_shape),
292 input: context.tensor_init(shape),
293 x: context.tensor_init(shape),
294 aux_x: context.tensor_init(shape),
295 att_x: context.tensor_init(shape),
296 att_kx: context.tensor_init(shape),
297 att_vx: context.tensor_init(shape),
298 att_rx: context.tensor_init(shape),
299 att_gx: context.tensor_init(shape),
300 att_k: context.tensor_init(shape),
301 att_v: context.tensor_init(shape),
302 att_r: context.tensor_init(shape),
303 att_g: context.tensor_init(shape),
304 att_o: context.tensor_init(shape),
305 ffn_x: context.tensor_init(shape),
306 ffn_kx: context.tensor_init(shape),
307 ffn_rx: context.tensor_init(shape),
308 ffn_k: context.tensor_init(hidden_shape),
309 ffn_v: context.tensor_init(shape),
310 ffn_r: context.tensor_init(shape),
311 }
312 }
313}
314
/// Scratch tensors used by the output head for a fixed number of header tokens.
#[derive(Debug, Clone, Serialize, DeserializeSeed)]
#[serde_seed(seed = "Seed", context = "Context")]
pub struct Header<F: Float> {
    /// Head input, shaped `[num_emb, num_header, 1, 1]` by `Header::new`.
    pub head_x: TensorGpu<F, ReadWrite>,
    /// Head output, shaped `[num_vocab_padded, num_header, 1, 1]` by `Header::new`; this is
    /// the tensor `RnnJob::back` reads back.
    pub head_o: TensorGpu<f32, ReadWrite>,
}
321
322impl<F: Float> Header<F> {
323 pub fn new(context: &Context, info: &ModelInfo, num_header: usize) -> Self {
324 let head_shape = Shape::new(info.num_emb, num_header, 1, 1);
325 let output_shape = Shape::new(info.num_vocab_padded(), num_header, 1, 1);
326
327 Self {
328 head_x: context.tensor_init(head_shape),
329 head_o: context.tensor_init(output_shape),
330 }
331 }
332}
333
/// Injection points in the dispatched op list.
///
/// For each point, the op returned by the hook registered in the [`HookMap`] (or an empty op
/// when none is registered) is inserted at the matching position while the op list is built.
/// Variants carrying a `usize` are per-layer; the payload is the layer index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Hook {
    // Embedding stage.
    PostEmbedLoaded,
    PostEmbedLayerNorm,
    // Attention block of a layer.
    PreAtt(usize),
    PostAttLayerNorm(usize),
    PreAttTokenShift(usize),
    PostAttTokenShift(usize),
    PreAttLinear(usize),
    PostAttLinear(usize),
    PreAttTimeMix(usize),
    PostAttTimeMix(usize),
    PreAttGate(usize),
    PostAttGate(usize),
    PreAttOut(usize),
    PostAttOut(usize),
    PostAtt(usize),
    // Feed-forward block of a layer.
    PreFfn(usize),
    PostFfnLayerNorm(usize),
    PreFfnTokenShift(usize),
    PostFfnTokenShift(usize),
    PreFfnLinear(usize),
    PostFfnLinear(usize),
    PostFfnActivate(usize),
    PreFfnChannelMix(usize),
    PostFfnChannelMix(usize),
    PostFfn(usize),
    // Output head; only emitted when at least one header token is requested.
    PreHead,
    PostHeadLayerNorm,
    PostHead,
}
365
/// A dispatched inference job: pre-encoded GPU commands plus the tensors needed to feed
/// input in and read output back.
pub struct RnnJob {
    /// Command buffers encoded by `dispatch`; taken and submitted by `submit`.
    commands: Vec<CommandBuffer>,
    /// Output redirection; `back` slices the output by its `outputs` ranges.
    redirect: RnnRedirect,

    /// CPU embedding table used by `load` to turn token ids into vectors.
    embed: TensorCpu<f16>,

    /// GPU cursor buffer written by `load`.
    cursors: TensorGpu<u32, ReadWrite>,
    /// GPU input buffer written by `load`.
    input: TensorGpu<f16, ReadWrite>,
    /// GPU output buffer read by `back`.
    output: TensorGpu<f32, ReadWrite>,
}
376
377impl Job for RnnJob {
378 type Input = RnnInput;
379 type Output = RnnOutput;
380
381 fn load(&self, input: &RnnChunk) -> Result<(), RuntimeError> {
382 if input.num_token() == 0 {
383 return Ok(());
384 }
385
386 let stack: Vec<TensorCpu<f16>> = input
387 .iter()
388 .map(|chunk| {
389 let num_emb = self.embed.shape()[0];
390 let data = self.embed.data();
391 let data = chunk
392 .iter()
393 .map(|token| match token {
394 &Token::Token(token) => {
395 let start = num_emb * token as usize;
396 let end = start + num_emb;
397 let data = data[start..end].to_vec();
398 TensorCpu::from_data_1d(data)
399 }
400 Token::Embed(tensor) => tensor.clone(),
401 })
402 .collect_vec();
403 match TensorCpu::stack(data, 1) {
404 Ok(tensor) => tensor,
405 Err(_) => TensorCpu::init([num_emb, 0, 1, 1]),
406 }
407 })
408 .collect();
409 let stack = TensorStack::try_from(stack)?;
410
411 let cursors = stack.cursors.clone().into_cursors();
412 let cursors = TensorCpu::from_data(self.cursors.shape(), cursors)?;
413 self.cursors.load(&cursors)?;
414 self.input.load(&stack.tensor)?;
415
416 Ok(())
417 }
418
419 fn submit(&mut self) {
420 let commands = std::mem::take(&mut self.commands);
421 self.output.context.queue.submit(commands);
422 }
423
424 async fn back(self) -> Result<Self::Output, RuntimeError> {
425 let output = self.output.back().await;
426 let batches: Vec<_> = self
427 .redirect
428 .outputs
429 .into_iter()
430 .map(|(start, end)| output.slice(.., start..end, .., ..))
431 .try_collect()?;
432 let batches = batches.into_iter().map(RnnOutputBatch).collect();
433 Ok(RnnOutput(batches))
434 }
435}
436
/// Snapshot of the tensors handed to a hook: the state plus the scratch buffers of the
/// current dispatch.
#[derive(Debug, Clone)]
pub struct Frame<F: Float> {
    /// The recurrent state of the bundle.
    pub state: State,
    /// Per-layer scratch buffers of the current dispatch.
    pub buffer: Arc<Runtime<F>>,
    /// Output head buffers of the current dispatch.
    pub header: Arc<Header<F>>,
}
443
/// A hook callback: receives the current [`Frame`] and returns the op to insert.
pub type HookFn<F> = Box<dyn Fn(Frame<F>) -> Result<TensorOp, TensorError> + Send + Sync>;
/// Hook callbacks keyed by their injection point.
pub type HookMap<F> = HashMap<Hook, HookFn<F>>;
446
/// A model together with its state, hooks and cached scratch buffers; dispatches
/// [`RnnJob`]s.
#[derive(Clone)]
pub struct Bundle<F: Float> {
    /// The model weights and settings.
    model: Model,
    /// The recurrent state, created zeroed by `Bundle::new`.
    state: State,
    /// Hooks injected into the op list during dispatch.
    hooks: Arc<HookMap<F>>,
    /// Cache of `Runtime` buffers keyed by token count.
    buffers: ResourceCache<usize, Runtime<F>>,
    /// Cache of `Header` buffers keyed by header count.
    headers: ResourceCache<usize, Header<F>>,
    phantom: PhantomData<F>,
}
456
457impl<F: Float> Bundle<F> {
458 pub fn new(model: Model, num_batch: usize) -> Self {
459 let context = model.context.clone();
460 let info = model.info.clone();
461 let state = {
462 let head_size = info.num_emb / info.num_head;
463 let shape = Shape::new(info.num_emb, head_size + 2, num_batch, 1);
464 let data = (0..info.num_layer).map(|_| context.zeros(shape)).collect();
465 State {
466 context,
467 info,
468 data,
469 }
470 };
471 Self {
472 model,
473 state,
474 hooks: Default::default(),
475 buffers: ResourceCache::new(4),
476 headers: ResourceCache::new(4),
477 phantom: PhantomData,
478 }
479 }
480
481 pub fn new_with_hooks(model: Model, num_batch: usize, hooks: HookMap<F>) -> Self {
482 Self {
483 hooks: Arc::new(hooks),
484 ..Self::new(model, num_batch)
485 }
486 }
487
488 fn checkout_buffer(
489 &self,
490 context: &Context,
491 info: &ModelInfo,
492 num_token: usize,
493 ) -> Arc<Runtime<F>> {
494 self.buffers
495 .checkout(num_token, || Runtime::new(context, info, num_token))
496 }
497
498 fn checkout_header(
499 &self,
500 context: &Context,
501 info: &ModelInfo,
502 num_header: usize,
503 ) -> Arc<Header<F>> {
504 self.headers
505 .checkout(num_header, || Header::new(context, info, num_header))
506 }
507}
508
impl<F: Float> super::model::Bundle for Bundle<F> {
    /// Returns a copy of the model metadata.
    #[inline]
    fn info(&self) -> ModelInfo {
        self.model.info.clone()
    }

    /// Returns a clone of the bundle's state handle.
    #[inline]
    fn state(&self) -> impl super::model::State + AsAny + 'static {
        self.state.clone()
    }

    /// Returns a clone of the model, usable for serialization.
    fn model(&self) -> impl Serialize + 'static {
        self.model.clone()
    }
}
524
/// Decides the `turbo` flag passed to the matmul ops: true when `num_token` is a multiple
/// of `MIN_TOKEN_CHUNK_SIZE`.
fn turbo(num_token: usize) -> bool {
    num_token.is_multiple_of(super::infer::rnn::MIN_TOKEN_CHUNK_SIZE)
}
528
529fn hook_op<F: Float>(
530 hooks: &HookMap<F>,
531 hook: &Hook,
532 frame: &Frame<F>,
533) -> Result<TensorOp, TensorError> {
534 match hooks.get(hook) {
535 Some(f) => f(frame.clone()),
536 None => Ok(TensorOp::empty()),
537 }
538}
539
540impl<F: Float> Dispatcher<RnnJob> for Bundle<F> {
541 type Info = RnnInfo;
542
543 fn dispatch(&self, seed: Self::Info) -> Result<RnnJob, RuntimeError> {
544 let model = &self.model;
545 let state = &self.state;
546 let context = &model.context;
547 let info = &model.info;
548 let tensor = &model.tensor;
549
550 let num_token = seed.num_token();
551 let head_size = info.num_emb / info.num_head;
552
553 let redirect = seed.redirect();
554 let num_header = redirect.headers.len();
555
556 let buffer = self.checkout_buffer(context, info, num_token);
557 let header = self.checkout_header(context, info, num_header);
558 let frame = Frame {
559 state: state.clone(),
560 buffer: buffer.clone(),
561 header: header.clone(),
562 };
563
564 context.maintain();
565
566 if num_token == 0 {
567 return Ok(RnnJob {
568 commands: vec![],
569 redirect,
570 embed: model.tensor.embed.w.clone(),
571 cursors: buffer.cursors.clone(),
572 input: buffer.input.clone(),
573 output: header.head_o.clone(),
574 });
575 }
576
577 #[cfg(feature = "trace")]
578 let _span = tracing::trace_span!("build").entered();
579
580 let (head_op, head_x) = redirect.op(&buffer.x, &header.head_x)?;
581
582 let hook_op = |hook: Hook| hook_op(&self.hooks, &hook, &frame);
583 let mut ops = vec![];
584
585 {
586 #[cfg(feature = "trace")]
587 let _span = tracing::trace_span!("embed").entered();
588
589 ops.extend([
590 hook_op(Hook::PostEmbedLoaded)?,
591 TensorOp::layer_norm(
592 &tensor.embed.layer_norm.w,
593 &tensor.embed.layer_norm.b,
594 &buffer.input,
595 Model::LN_EPS,
596 )?,
597 TensorOp::blit(&buffer.input, &buffer.x)?,
598 hook_op(Hook::PostEmbedLayerNorm)?,
599 ]);
600 };
601
602 for (index, layer) in tensor.layers.iter().enumerate() {
603 #[cfg(feature = "trace")]
604 let _span = tracing::trace_span!("layer", index).entered();
605
606 let hooks = self.hooks.clone();
607 let frame = frame.clone();
608 let layer = layer.clone();
609
610 let op = dispatch_layer(
611 hooks,
612 frame,
613 layer,
614 index,
615 num_token,
616 head_size,
617 model.rescale,
618 )?;
619 ops.push(op);
620
621 if (index + 1) % model.sep == 0 {
622 ops.push(TensorOp::Sep);
623 }
624 }
625
626 {
627 #[cfg(feature = "trace")]
628 let _span = tracing::trace_span!("header").entered();
629
630 let hooks = self.hooks.clone();
631 let frame = frame.clone();
632 let head = model.tensor.head.clone();
633
634 let op = dispatch_header(hooks, frame, head, head_x, num_header, head_op)?;
635 ops.push(op);
636 }
637
638 let commands = {
639 #[cfg(feature = "trace")]
640 let _span = tracing::trace_span!("encode").entered();
641 context.encode(&TensorOp::List(ops))
642 };
643
644 Ok(RnnJob {
645 commands,
646 redirect,
647 embed: model.tensor.embed.w.clone(),
648 cursors: buffer.cursors.clone(),
649 input: buffer.input.clone(),
650 output: header.head_o.clone(),
651 })
652 }
653}
654
/// Builds the op list for layer `index`: the attention block followed by the feed-forward
/// block, with hook ops inserted at every injection point.
///
/// * `hooks` / `frame` - hook callbacks and the tensors they (and this function) operate on.
/// * `layer` - weights of the layer.
/// * `num_token` - number of tokens in this dispatch; also decides the matmul `turbo` flag.
/// * `head_size` - size of one attention head (`num_emb / num_head` at the call site).
/// * `rescale` - after every `rescale`-th layer an `affine(x, 0.5, 0.0)` op is appended.
///
/// Returns the ops wrapped in a single `TensorOp::List`, or the first error raised while
/// building an op.
// The function needs all seven inputs; bundling them would only move the argument list.
#[allow(clippy::too_many_arguments)]
fn dispatch_layer<F: Float>(
    hooks: Arc<HookMap<F>>,
    frame: Frame<F>,
    layer: Layer,
    index: usize,
    num_token: usize,
    head_size: usize,
    rescale: usize,
) -> Result<TensorOp, TensorError> {
    let hook_op = |hook: Hook| hook_op(&hooks, &hook, &frame);
    let Frame { state, buffer, .. } = &frame;

    // Per-head views of the time-mix parameters: `[head_size, auto, 1, 1]`.
    let time_first = layer.att.time_first.reshape(
        TensorDimension::Size(head_size),
        TensorDimension::Auto,
        TensorDimension::Size(1),
        TensorDimension::Size(1),
    )?;
    let time_decay = layer.att.time_decay.reshape(
        TensorDimension::Size(head_size),
        TensorDimension::Auto,
        TensorDimension::Size(1),
        TensorDimension::Size(1),
    )?;
    // Per-head views of the `f32` scratch buffers: `[head_size, auto, num_token, 1]`.
    let aux_x = buffer.aux_x.reshape(
        TensorDimension::Size(head_size),
        TensorDimension::Auto,
        TensorDimension::Size(num_token),
        TensorDimension::Size(1),
    )?;
    let att_k = buffer.att_k.reshape(
        TensorDimension::Size(head_size),
        TensorDimension::Auto,
        TensorDimension::Size(num_token),
        TensorDimension::Size(1),
    )?;
    let att_v = buffer.att_v.reshape(
        TensorDimension::Size(head_size),
        TensorDimension::Auto,
        TensorDimension::Size(num_token),
        TensorDimension::Size(1),
    )?;
    let att_r = buffer.att_r.reshape(
        TensorDimension::Size(head_size),
        TensorDimension::Auto,
        TensorDimension::Size(num_token),
        TensorDimension::Size(1),
    )?;

    let mut ops = vec![];

    // Attention block: layer norm -> token shift -> linear projections -> time mix and
    // group norm -> gate -> output projection -> residual add onto `x`.
    ops.extend([
        TensorOp::blit(&buffer.x, &buffer.att_x)?,
        hook_op(Hook::PreAtt(index))?,
        TensorOp::layer_norm(
            &layer.att_layer_norm.w,
            &layer.att_layer_norm.b,
            &buffer.att_x,
            Model::LN_EPS,
        )?,
        hook_op(Hook::PostAttLayerNorm(index))?,
        hook_op(Hook::PreAttTokenShift(index))?,
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.att.time_mix_k,
            state.att(index)?,
            &buffer.att_x,
            &buffer.att_kx,
            false,
        )?,
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.att.time_mix_v,
            state.att(index)?,
            &buffer.att_x,
            &buffer.att_vx,
            false,
        )?,
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.att.time_mix_r,
            state.att(index)?,
            &buffer.att_x,
            &buffer.att_rx,
            false,
        )?,
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.att.time_mix_g,
            state.att(index)?,
            &buffer.att_x,
            &buffer.att_gx,
            false,
        )?,
        hook_op(Hook::PostAttTokenShift(index))?,
        hook_op(Hook::PreAttLinear(index))?,
        layer.att.w_k.matmul_op(
            &buffer.att_kx,
            &buffer.att_k,
            Activation::None,
            turbo(num_token),
        )?,
        layer.att.w_v.matmul_op(
            &buffer.att_vx,
            &buffer.att_v,
            Activation::None,
            turbo(num_token),
        )?,
        layer.att.w_r.matmul_op(
            &buffer.att_rx,
            &buffer.att_r,
            Activation::None,
            turbo(num_token),
        )?,
        layer.att.w_g.matmul_op(
            &buffer.att_gx,
            &buffer.att_g,
            Activation::None,
            turbo(num_token),
        )?,
        hook_op(Hook::PostAttLinear(index))?,
        hook_op(Hook::PreAttTimeMix(index))?,
        // Time mix and group norm run on the `f32` scratch buffer `aux_x`; `att_x` is
        // copied into it first and the result is copied back afterwards.
        TensorOp::blit(&buffer.att_x, &buffer.aux_x)?,
        TensorOp::time_mix_v5(
            &buffer.cursors,
            &time_decay,
            &time_first,
            state.att(index)?,
            &att_k,
            &att_v,
            &att_r,
            &aux_x,
        )?,
        TensorOp::group_norm(
            &layer.att.group_norm.w,
            &layer.att.group_norm.b,
            &aux_x,
            Model::GN_EPS,
        )?,
        TensorOp::blit(&buffer.aux_x, &buffer.att_x)?,
        hook_op(Hook::PostAttTimeMix(index))?,
        hook_op(Hook::PreAttGate(index))?,
        TensorOp::mul_activate(
            &buffer.att_g,
            &buffer.att_x,
            Activation::Silu,
            Activation::None,
            Activation::None,
        )?,
        hook_op(Hook::PostAttGate(index))?,
        hook_op(Hook::PreAttOut(index))?,
        layer.att.w_o.matmul_op(
            &buffer.att_x,
            &buffer.att_o,
            Activation::None,
            turbo(num_token),
        )?,
        hook_op(Hook::PostAttOut(index))?,
        TensorOp::add(&buffer.att_o, &buffer.x)?,
        hook_op(Hook::PostAtt(index))?,
    ]);

    // Feed-forward block: layer norm -> token shift -> linear projections -> channel mix
    // -> residual add onto `x`.
    ops.extend([
        TensorOp::blit(&buffer.x, &buffer.ffn_x)?,
        hook_op(Hook::PreFfn(index))?,
        TensorOp::layer_norm(
            &layer.ffn_layer_norm.w,
            &layer.ffn_layer_norm.b,
            &buffer.ffn_x,
            Model::LN_EPS,
        )?,
        hook_op(Hook::PostFfnLayerNorm(index))?,
        hook_op(Hook::PreFfnTokenShift(index))?,
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.ffn.time_mix_k,
            state.ffn(index)?,
            &buffer.ffn_x,
            &buffer.ffn_kx,
            false,
        )?,
        TensorOp::token_shift(
            &buffer.cursors,
            &layer.ffn.time_mix_r,
            state.ffn(index)?,
            &buffer.ffn_x,
            &buffer.ffn_rx,
            false,
        )?,
        hook_op(Hook::PostFfnTokenShift(index))?,
        hook_op(Hook::PreFfnLinear(index))?,
        layer.ffn.w_k.matmul_op(
            &buffer.ffn_kx,
            &buffer.ffn_k,
            Activation::SquaredRelu,
            turbo(num_token),
        )?,
        hook_op(Hook::PostFfnActivate(index))?,
        layer.ffn.w_v.matmul_op_sparse(
            &buffer.ffn_k,
            &buffer.ffn_v,
            Activation::None,
            turbo(num_token),
        )?,
        layer.ffn.w_r.matmul_op(
            &buffer.ffn_rx,
            &buffer.ffn_r,
            Activation::None,
            turbo(num_token),
        )?,
        hook_op(Hook::PostFfnLinear(index))?,
        hook_op(Hook::PreFfnChannelMix(index))?,
        TensorOp::channel_mix(
            &buffer.cursors,
            state.ffn(index)?,
            &buffer.ffn_r,
            &buffer.ffn_v,
            &buffer.ffn_x,
        )?,
        hook_op(Hook::PostFfnChannelMix(index))?,
        TensorOp::add(&buffer.ffn_x, &buffer.x)?,
        hook_op(Hook::PostFfn(index))?,
    ]);

    // NOTE(review): presumably halves `x` to pair with the `2^-(layer / rescale)` discount
    // applied to the output weights in `build_v5` — confirm against `TensorOp::affine`.
    if (index + 1).is_multiple_of(rescale) {
        ops.push(TensorOp::affine(&buffer.x, 0.5, 0.0)?);
    }

    Ok(TensorOp::List(ops))
}
886
887fn dispatch_header<F: Float>(
888 hooks: Arc<HookMap<F>>,
889 frame: Frame<F>,
890 head: Head,
891 head_x: TensorGpu<F, ReadWrite>,
892 num_header: usize,
893 head_op: TensorOp,
894) -> Result<TensorOp, TensorError> {
895 let hook_op = |hook: Hook| hook_op(&hooks, &hook, &frame);
896 let header = &frame.header;
897 let mut ops = vec![head_op];
898
899 if num_header > 0 {
900 ops.extend([
901 hook_op(Hook::PreHead)?,
902 TensorOp::layer_norm(
903 &head.layer_norm.w,
904 &head.layer_norm.b,
905 &head_x,
906 Model::LN_EPS,
907 )?,
908 hook_op(Hook::PostHeadLayerNorm)?,
909 head.w.matmul_op(
910 head_x.view(.., .., .., ..)?,
911 header.head_o.view(.., .., .., ..)?,
912 Activation::None,
913 turbo(num_header),
914 )?,
915 hook_op(Hook::PostHead)?,
916 ]);
917 }
918 Ok(TensorOp::List(ops))
919}
920
impl<R: Reader> ModelBuilder<R> {
    /// Loads all weights of a V5 model and assembles them into a [`Model`].
    ///
    /// `rescale` and `sep` fall back to `Model::DEFAULT_RESCALE` and `Model::DEFAULT_SEP`
    /// when unset. Per-layer quantization is looked up in `quant`, defaulting when a layer
    /// has no entry.
    ///
    /// # Errors
    ///
    /// Returns a `LoaderError` if the model info cannot be read or any tensor fails to load
    /// or reshape.
    pub async fn build_v5(self) -> Result<Model, LoaderError> {
        let ModelBuilder {
            context,
            model,
            rescale,
            sep,
            lora,
            quant,
            ..
        } = self;

        let rescale = rescale.unwrap_or(Model::DEFAULT_RESCALE);
        let sep = sep.unwrap_or(Model::DEFAULT_SEP);

        let info = Loader::info(&model)?;
        let loader = Loader {
            context: context.clone(),
            model,
            lora,
        };

        let embed = Embed {
            layer_norm: LayerNorm {
                w: loader.load_vector_f16("blocks.0.ln0.weight")?,
                b: loader.load_vector_f16("blocks.0.ln0.bias")?,
            },
            w: loader.load_matrix_f16_padded_cpu("emb.weight")?,
        };

        let head = Head {
            layer_norm: LayerNorm {
                w: loader.load_vector_f16("ln_out.weight")?,
                b: loader.load_vector_f16("ln_out.bias")?,
            },
            w: Matrix::Fp16(loader.load_matrix_f16("head.weight")?),
        };

        // Submit an empty batch and wait on its submission index; the poll result is
        // deliberately ignored. Presumably this drains pending upload work before the
        // per-layer loads start — confirm against the loader.
        let submission_index = Some(context.queue.submit(None));
        _ = context.device.poll(wgpu::PollType::Wait {
            submission_index,
            timeout: None,
        });

        let load_matrix = |name: String, quant: Quant| loader.load_matrix(name, quant);
        let load_matrix_discount = |name: String, quant: Quant, discount: f32| {
            loader.load_matrix_discount(name, quant, discount)
        };

        let mut layers = vec![];
        for layer in 0..info.num_layer {
            let quant = quant.get(&layer).copied().unwrap_or_default();
            // Discount factor for the output weights: 2^-(layer / rescale).
            let discount = 2.0_f32.powi(-((layer / rescale) as i32));

            let att_layer_norm = LayerNorm {
                w: loader.load_vector_f16(format!("blocks.{layer}.ln1.weight"))?,
                b: loader.load_vector_f16(format!("blocks.{layer}.ln1.bias"))?,
            };

            let att = format!("blocks.{layer}.att");
            let time_decay = loader.load_vector_exp_exp_f32(format!("{att}.time_decay"))?;
            let time_first = loader.load_vector_f32(format!("{att}.time_first"))?;
            let time_mix_k = loader.load_vector_f16(format!("{att}.time_mix_k"))?;
            let time_mix_v = loader.load_vector_f16(format!("{att}.time_mix_v"))?;
            let time_mix_r = loader.load_vector_f16(format!("{att}.time_mix_r"))?;
            let time_mix_g = loader.load_vector_f16(format!("{att}.time_mix_g"))?;

            // Group norm parameters are reshaped to `[auto, num_head, 1, 1]`.
            let group_norm = LayerNorm {
                w: loader
                    .load_vector_f16(format!("{att}.ln_x.weight"))?
                    .reshape(
                        TensorDimension::Auto,
                        TensorDimension::Size(info.num_head),
                        TensorDimension::Size(1),
                        TensorDimension::Size(1),
                    )?,
                b: loader
                    .load_vector_f16(format!("{att}.ln_x.bias"))?
                    .reshape(
                        TensorDimension::Auto,
                        TensorDimension::Size(info.num_head),
                        TensorDimension::Size(1),
                        TensorDimension::Size(1),
                    )?,
            };

            // `att` is shadowed: the name prefix above becomes the assembled block here.
            let att = Att {
                time_decay,
                time_first,
                time_mix_k,
                time_mix_v,
                time_mix_r,
                time_mix_g,
                w_k: load_matrix(format!("{att}.key.weight"), quant)?,
                w_v: load_matrix(format!("{att}.value.weight"), quant)?,
                w_r: load_matrix(format!("{att}.receptance.weight"), quant)?,
                w_g: load_matrix(format!("{att}.gate.weight"), quant)?,
                w_o: load_matrix_discount(format!("{att}.output.weight"), quant, discount)?,
                group_norm,
            };

            let ffn_layer_norm = LayerNorm {
                w: loader.load_vector_f16(format!("blocks.{layer}.ln2.weight"))?,
                b: loader.load_vector_f16(format!("blocks.{layer}.ln2.bias"))?,
            };

            let ffn = format!("blocks.{layer}.ffn");
            let time_mix_k = loader.load_vector_f16(format!("{ffn}.time_mix_k"))?;
            let time_mix_r = loader.load_vector_f16(format!("{ffn}.time_mix_r"))?;

            // `ffn` is shadowed in the same way as `att` above.
            let ffn = Ffn {
                time_mix_k,
                time_mix_r,
                w_r: load_matrix(format!("{ffn}.receptance.weight"), quant)?,
                w_k: load_matrix(format!("{ffn}.key.weight"), quant)?,
                w_v: load_matrix_discount(format!("{ffn}.value.weight"), quant, discount)?,
            };

            // Same empty-submit-and-wait as above, once per layer.
            let submission_index = Some(context.queue.submit(None));
            _ = context.device.poll(wgpu::PollType::Wait {
                submission_index,
                timeout: None,
            });

            layers.push(Layer {
                att_layer_norm,
                ffn_layer_norm,
                att,
                ffn,
            })
        }

        // Final empty-submit-and-wait after all layers are loaded.
        let submission_index = Some(context.queue.submit(None));
        _ = context.device.poll(wgpu::PollType::Wait {
            submission_index,
            timeout: None,
        });

        let tensor = ModelTensor {
            embed,
            head,
            layers,
        };
        let model = {
            let context = context.clone();
            let info = info.clone();
            Model {
                context,
                info,
                rescale,
                sep,
                tensor,
            }
        };
        Ok(model)
    }
}
1078
1079pub async fn read_state<R: Reader>(
1081 context: &Context,
1082 info: &ModelInfo,
1083 model: R,
1084) -> Result<TensorCpu<f32>, LoaderError> {
1085 let loader = Loader {
1086 context: context.clone(),
1087 model,
1088 lora: vec![],
1089 };
1090
1091 let head_size = info.num_emb / info.num_head;
1092 let data: TensorGpu<f32, _> = context.zeros([info.num_emb, head_size + 2, info.num_layer, 1]);
1093
1094 let mut ops = vec![];
1095 for layer in 0..info.num_layer {
1096 let matrix = loader.load_matrix_f16(format!("blocks.{layer}.att.time_state"))?;
1097 let state: TensorGpu<_, _> = context.tensor_init([head_size, info.num_head, head_size, 1]);
1098 let reshaped: TensorGpu<f16, _> = state.reshape(
1099 TensorDimension::Size(info.num_emb),
1100 TensorDimension::Size(head_size),
1101 TensorDimension::Size(1),
1102 TensorDimension::Auto,
1103 )?;
1104 ops.extend([
1105 TensorOp::transpose(&matrix, &state)?,
1106 TensorOp::blit(&reshaped, data.view(.., 1..head_size + 1, layer, ..)?)?,
1107 ]);
1108 }
1109 context.queue.submit(context.encode(&TensorOp::List(ops)));
1110
1111 Ok(data.back().await)
1112}