1use std::f32::NEG_INFINITY;
5
6use burn::{
7 config::Config,
8 module::{Module, Param},
9 nn::{
10 self, PaddingConfig1d,
11 conv::{Conv1d, Conv1dConfig},
12 },
13 tensor::{Distribution, Int, Tensor, activation::softmax, backend::Backend},
14};
15
16#[derive(Config, Debug)]
18pub struct WhisperConfig {
19 pub audio_encoder_config: AudioEncoderConfig,
20 pub text_decoder_config: TextDecoderConfig,
21}
22
23impl WhisperConfig {
24 pub fn tiny_en() -> Self {
26 Self {
27 audio_encoder_config: AudioEncoderConfig {
28 n_mels: 80,
29 n_audio_ctx: 1500,
30 n_audio_state: 384,
31 n_audio_head: 6,
32 n_audio_layer: 4,
33 },
34 text_decoder_config: TextDecoderConfig {
35 n_vocab: 51864,
36 n_text_ctx: 448,
37 n_text_state: 384,
38 n_text_head: 6,
39 n_text_layer: 4,
40 },
41 }
42 }
43
44 pub fn base() -> Self {
46 Self {
47 audio_encoder_config: AudioEncoderConfig {
48 n_mels: 80,
49 n_audio_ctx: 1500,
50 n_audio_state: 512,
51 n_audio_head: 8,
52 n_audio_layer: 6,
53 },
54 text_decoder_config: TextDecoderConfig {
55 n_vocab: 51864,
56 n_text_ctx: 448,
57 n_text_state: 512,
58 n_text_head: 8,
59 n_text_layer: 6,
60 },
61 }
62 }
63
64 pub fn small() -> Self {
66 Self {
67 audio_encoder_config: AudioEncoderConfig {
68 n_mels: 80,
69 n_audio_ctx: 1500,
70 n_audio_state: 768,
71 n_audio_head: 12,
72 n_audio_layer: 12,
73 },
74 text_decoder_config: TextDecoderConfig {
75 n_vocab: 51864,
76 n_text_ctx: 448,
77 n_text_state: 768,
78 n_text_head: 12,
79 n_text_layer: 12,
80 },
81 }
82 }
83
84 pub fn medium() -> Self {
86 Self {
87 audio_encoder_config: AudioEncoderConfig {
88 n_mels: 80,
89 n_audio_ctx: 1500,
90 n_audio_state: 1024,
91 n_audio_head: 16,
92 n_audio_layer: 24,
93 },
94 text_decoder_config: TextDecoderConfig {
95 n_vocab: 51864,
96 n_text_ctx: 448,
97 n_text_state: 1024,
98 n_text_head: 16,
99 n_text_layer: 24,
100 },
101 }
102 }
103
104 pub fn large_v2() -> Self {
106 Self {
107 audio_encoder_config: AudioEncoderConfig {
108 n_mels: 128,
109 n_audio_ctx: 1500,
110 n_audio_state: 1280,
111 n_audio_head: 20,
112 n_audio_layer: 32,
113 },
114 text_decoder_config: TextDecoderConfig {
115 n_vocab: 51864,
116 n_text_ctx: 448,
117 n_text_state: 1280,
118 n_text_head: 20,
119 n_text_layer: 32,
120 },
121 }
122 }
123
124 pub fn large_v3() -> Self {
126 Self {
127 audio_encoder_config: AudioEncoderConfig {
128 n_mels: 128,
129 n_audio_ctx: 1500,
130 n_audio_state: 1280,
131 n_audio_head: 20,
132 n_audio_layer: 32,
133 },
134 text_decoder_config: TextDecoderConfig {
135 n_vocab: 51865,
136 n_text_ctx: 448,
137 n_text_state: 1280,
138 n_text_head: 20,
139 n_text_layer: 32,
140 },
141 }
142 }
143
144 pub fn init<B: Backend>(&self, device: &B::Device) -> Whisper<B> {
145 let n_audio_state = self.audio_encoder_config.n_audio_state;
146 let n_text_state = self.text_decoder_config.n_text_state;
147
148 assert!(
149 n_audio_state == n_text_state,
150 "Audio encoder state size {} must be equal to text decoder state size {}.",
151 n_audio_state,
152 n_text_state
153 );
154
155 let encoder = self.audio_encoder_config.init(device);
156 let decoder = self.text_decoder_config.init(device);
157
158 Whisper { encoder, decoder }
159 }
160}
161
162#[derive(Module, Debug)]
164pub struct Whisper<B: Backend> {
165 pub encoder: AudioEncoder<B>,
166 pub decoder: TextDecoder<B>,
167}
168
169impl<B: Backend> Whisper<B> {
170 pub fn forward(&self, mel: Tensor<B, 3>, tokens: Tensor<B, 2, Int>) -> Tensor<B, 3> {
172 self.decoder.forward(tokens, self.encoder.forward(mel))
173 }
174
175 pub fn forward_encoder(&self, mel: Tensor<B, 3>) -> Tensor<B, 3> {
177 self.encoder.forward(mel)
178 }
179
180 pub fn forward_decoder(
182 &self,
183 tokens: Tensor<B, 2, Int>,
184 encoder_output: Tensor<B, 3>,
185 ) -> Tensor<B, 3> {
186 self.decoder.forward(tokens, encoder_output)
187 }
188
189 pub fn encoder_ctx_size(&self) -> usize {
190 self.encoder.ctx_size()
191 }
192
193 pub fn decoder_ctx_size(&self) -> usize {
194 self.decoder.ctx_size()
195 }
196}
197
198#[derive(Config, Debug)]
203pub struct TextDecoderConfig {
204 pub n_vocab: usize,
205 pub n_text_ctx: usize,
206 pub n_text_state: usize,
207 pub n_text_head: usize,
208 pub n_text_layer: usize,
209}
210
211impl TextDecoderConfig {
212 pub fn init<B: Backend>(&self, device: &B::Device) -> TextDecoder<B> {
213 let token_embedding = Param::from_tensor(Tensor::random(
214 [self.n_vocab, self.n_text_state],
215 Distribution::Normal(0.0, 0.02),
216 device,
217 ));
218 let positional_embedding = Param::from_tensor(Tensor::random(
219 [self.n_text_ctx, self.n_text_state],
220 Distribution::Normal(0.0, 0.01),
221 device,
222 ));
223 let blocks: Vec<_> = (0..self.n_text_layer)
224 .map(|_| {
225 ResidualDecoderAttentionBlockConfig::new(self.n_text_state, self.n_text_head)
226 .init(device)
227 })
228 .collect();
229 let ln = nn::LayerNormConfig::new(self.n_text_state).init(device);
230
231 let mask = Param::from_tensor(attn_decoder_mask(self.n_text_ctx, device));
232
233 let n_vocab = self.n_vocab;
234 let n_text_ctx = self.n_text_ctx;
235
236 TextDecoder {
237 token_embedding,
238 positional_embedding,
239 blocks,
240 ln,
241 mask,
242 n_vocab,
243 n_text_ctx,
244 }
245 }
246}
247
248#[derive(Module, Debug)]
249pub struct TextDecoder<B: Backend> {
250 pub token_embedding: Param<Tensor<B, 2>>,
251 pub positional_embedding: Param<Tensor<B, 2>>,
252 pub blocks: Vec<ResidualDecoderAttentionBlock<B>>,
253 pub ln: nn::LayerNorm<B>,
254 pub mask: Param<Tensor<B, 2>>,
255 pub n_vocab: usize,
256 pub n_text_ctx: usize,
257}
258
259impl<B: Backend> TextDecoder<B> {
260 pub fn forward(&self, x: Tensor<B, 2, Int>, xa: Tensor<B, 3>) -> Tensor<B, 3> {
261 let [_n_batch, seq_len] = x.dims();
262
263 assert!(
264 seq_len <= self.n_text_ctx,
265 "Token sequence length {} must not exceed {}.",
266 seq_len,
267 self.n_text_ctx
268 );
269
270 let x = burn::tensor::module::embedding(self.token_embedding.val(), x)
272 + self
273 .positional_embedding
274 .val()
275 .slice([0..seq_len])
276 .unsqueeze::<3>();
277
278 let mut x = x;
279 for block in &self.blocks {
280 x = block.forward(x, xa.clone(), self.mask.val());
281 }
282
283 let x = self.ln.forward(x);
284 x.matmul(self.token_embedding.val().transpose().unsqueeze::<3>())
286 }
287
288 pub fn ctx_size(&self) -> usize {
289 self.n_text_ctx
290 }
291}
292
293#[derive(Config, Debug)]
298pub struct AudioEncoderConfig {
299 pub n_mels: usize,
300 pub n_audio_ctx: usize,
301 pub n_audio_state: usize,
302 pub n_audio_head: usize,
303 pub n_audio_layer: usize,
304}
305
306impl AudioEncoderConfig {
307 pub fn init<B: Backend>(&self, device: &B::Device) -> AudioEncoder<B> {
308 let conv1 = Conv1dConfig::new(self.n_mels, self.n_audio_state, 3)
309 .with_padding(PaddingConfig1d::Explicit(1, 1))
310 .init(device);
311 let gelu1 = nn::Gelu::new();
312 let conv2 = Conv1dConfig::new(self.n_audio_state, self.n_audio_state, 3)
313 .with_padding(PaddingConfig1d::Explicit(1, 1))
314 .with_stride(2)
315 .init(device);
316 let gelu2 = nn::Gelu::new();
317 let blocks: Vec<_> = (0..self.n_audio_layer)
318 .map(|_| {
319 ResidualEncoderAttentionBlockConfig::new(self.n_audio_state, self.n_audio_head)
320 .init(device)
321 })
322 .collect();
323 let ln_post = nn::LayerNormConfig::new(self.n_audio_state).init(device);
324 let positional_embedding = Param::from_tensor(Tensor::random(
325 [self.n_audio_ctx, self.n_audio_state],
326 Distribution::Normal(0.0, 0.01),
327 device,
328 ));
329 let n_mels = self.n_mels;
330 let n_audio_ctx = self.n_audio_ctx;
331
332 AudioEncoder {
333 conv1,
334 gelu1,
335 conv2,
336 gelu2,
337 blocks,
338 ln_post,
339 positional_embedding,
340 n_mels,
341 n_audio_ctx,
342 }
343 }
344}
345
346#[derive(Module, Debug)]
347pub struct AudioEncoder<B: Backend> {
348 pub conv1: Conv1d<B>,
349 pub gelu1: nn::Gelu,
350 pub conv2: Conv1d<B>,
351 pub gelu2: nn::Gelu,
352 pub blocks: Vec<ResidualEncoderAttentionBlock<B>>,
353 pub ln_post: nn::LayerNorm<B>,
354 pub positional_embedding: Param<Tensor<B, 2>>,
355 pub n_mels: usize,
356 pub n_audio_ctx: usize,
357}
358
359impl<B: Backend> AudioEncoder<B> {
360 pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
361 let [_, n_mels, n_ctx] = x.dims();
362
363 assert!(
364 n_mels == self.n_mels,
365 "Audio mel spectrum size must be {}.",
366 self.n_mels
367 );
368 assert!(
369 n_ctx <= 2 * self.n_audio_ctx,
370 "Audio length {} cannot exceed {}.",
371 n_ctx,
372 2 * self.n_audio_ctx
373 );
374
375 let x = self.gelu1.forward(self.conv1.forward(x));
376 let x = self.gelu2.forward(self.conv2.forward(x));
377
378 let x = x.swap_dims(1, 2);
379 let k = x.dims()[1];
380 let x = x + self
381 .positional_embedding
382 .val()
383 .slice([0..k])
384 .unsqueeze::<3>();
385
386 let mut x = x;
387 for block in &self.blocks {
388 x = block.forward(x);
389 }
390
391 self.ln_post.forward(x)
392 }
393
394 pub fn ctx_size(&self) -> usize {
395 self.n_audio_ctx
396 }
397}
398
399#[derive(Config, Debug)]
404pub struct ResidualEncoderAttentionBlockConfig {
405 n_state: usize,
406 n_head: usize,
407}
408
409impl ResidualEncoderAttentionBlockConfig {
410 pub fn init<B: Backend>(&self, device: &B::Device) -> ResidualEncoderAttentionBlock<B> {
411 let attn = MultiHeadSelfAttentionConfig::new(self.n_state, self.n_head).init(device);
412 let attn_ln = nn::LayerNormConfig::new(self.n_state).init(device);
413
414 let mlp = MLPConfig::new(self.n_state).init(device);
415 let mlp_ln = nn::LayerNormConfig::new(self.n_state).init(device);
416
417 ResidualEncoderAttentionBlock {
418 attn,
419 attn_ln,
420 mlp,
421 mlp_ln,
422 }
423 }
424}
425
426#[derive(Module, Debug)]
427pub struct ResidualEncoderAttentionBlock<B: Backend> {
428 pub attn: MultiHeadSelfAttention<B>,
429 pub attn_ln: nn::LayerNorm<B>,
430 pub mlp: MLP<B>,
431 pub mlp_ln: nn::LayerNorm<B>,
432}
433
434impl<B: Backend> ResidualEncoderAttentionBlock<B> {
435 pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
436 let x = x.clone() + self.attn.forward(self.attn_ln.forward(x), None);
437 let x = x.clone() + self.mlp.forward(self.mlp_ln.forward(x));
438 x
439 }
440}
441
442#[derive(Config, Debug)]
443pub struct ResidualDecoderAttentionBlockConfig {
444 n_state: usize,
445 n_head: usize,
446}
447
448impl ResidualDecoderAttentionBlockConfig {
449 pub fn init<B: Backend>(&self, device: &B::Device) -> ResidualDecoderAttentionBlock<B> {
450 let attn = MultiHeadSelfAttentionConfig::new(self.n_state, self.n_head).init(device);
451 let attn_ln = nn::LayerNormConfig::new(self.n_state).init(device);
452
453 let cross_attn = MultiHeadCrossAttentionConfig::new(self.n_state, self.n_head).init(device);
454 let cross_attn_ln = nn::LayerNormConfig::new(self.n_state).init(device);
455
456 let mlp = MLPConfig::new(self.n_state).init(device);
457 let mlp_ln = nn::LayerNormConfig::new(self.n_state).init(device);
458
459 ResidualDecoderAttentionBlock {
460 attn,
461 attn_ln,
462 cross_attn,
463 cross_attn_ln,
464 mlp,
465 mlp_ln,
466 }
467 }
468}
469
470#[derive(Module, Debug)]
471pub struct ResidualDecoderAttentionBlock<B: Backend> {
472 pub attn: MultiHeadSelfAttention<B>,
473 pub attn_ln: nn::LayerNorm<B>,
474 pub cross_attn: MultiHeadCrossAttention<B>,
475 pub cross_attn_ln: nn::LayerNorm<B>,
476 pub mlp: MLP<B>,
477 pub mlp_ln: nn::LayerNorm<B>,
478}
479
480impl<B: Backend> ResidualDecoderAttentionBlock<B> {
481 pub fn forward(&self, x: Tensor<B, 3>, xa: Tensor<B, 3>, mask: Tensor<B, 2>) -> Tensor<B, 3> {
482 let x = x.clone() + self.attn.forward(self.attn_ln.forward(x), Some(mask));
483 let x = x.clone() + self.cross_attn.forward(self.cross_attn_ln.forward(x), xa);
484 let x = x.clone() + self.mlp.forward(self.mlp_ln.forward(x));
485 x
486 }
487}
488
489#[derive(Config, Debug)]
494pub struct MLPConfig {
495 n_state: usize,
496}
497
498impl MLPConfig {
499 pub fn init<B: Backend>(&self, device: &B::Device) -> MLP<B> {
500 let lin1 = nn::LinearConfig::new(self.n_state, 4 * self.n_state).init(device);
501 let gelu = nn::Gelu::new();
502 let lin2 = nn::LinearConfig::new(4 * self.n_state, self.n_state).init(device);
503
504 MLP { lin1, gelu, lin2 }
505 }
506}
507
508#[derive(Module, Debug)]
509pub struct MLP<B: Backend> {
510 pub lin1: nn::Linear<B>,
511 pub gelu: nn::Gelu,
512 pub lin2: nn::Linear<B>,
513}
514
515impl<B: Backend> MLP<B> {
516 pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
517 let x = self.lin1.forward(x);
518 let x = self.gelu.forward(x);
519 self.lin2.forward(x)
520 }
521}
522
523#[derive(Config, Debug)]
528pub struct MultiHeadSelfAttentionConfig {
529 n_state: usize,
530 n_head: usize,
531}
532
533impl MultiHeadSelfAttentionConfig {
534 pub fn init<B: Backend>(&self, device: &B::Device) -> MultiHeadSelfAttention<B> {
535 assert!(
536 self.n_state.is_multiple_of(self.n_head),
537 "State size {} must be a multiple of head size {}",
538 self.n_state,
539 self.n_head
540 );
541
542 let n_head = self.n_head;
543 let query = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
544 let key = nn::LinearConfig::new(self.n_state, self.n_state)
545 .with_bias(false)
546 .init(device);
547 let value = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
548 let out = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
549
550 MultiHeadSelfAttention {
551 n_head,
552 query,
553 key,
554 value,
555 out,
556 }
557 }
558}
559
560#[derive(Module, Debug)]
561pub struct MultiHeadSelfAttention<B: Backend> {
562 pub n_head: usize,
563 pub query: nn::Linear<B>,
564 pub key: nn::Linear<B>,
565 pub value: nn::Linear<B>,
566 pub out: nn::Linear<B>,
567}
568
569impl<B: Backend> MultiHeadSelfAttention<B> {
570 pub fn forward(&self, x: Tensor<B, 3>, mask: Option<Tensor<B, 2>>) -> Tensor<B, 3> {
571 let q = self.query.forward(x.clone());
572 let k = self.key.forward(x.clone());
573 let v = self.value.forward(x);
574
575 let wv = qkv_attention(q, k, v, mask, self.n_head);
576
577 self.out.forward(wv)
578 }
579}
580
581#[derive(Config, Debug)]
582pub struct MultiHeadCrossAttentionConfig {
583 n_state: usize,
584 n_head: usize,
585}
586
587impl MultiHeadCrossAttentionConfig {
588 pub fn init<B: Backend>(&self, device: &B::Device) -> MultiHeadCrossAttention<B> {
589 assert!(
590 self.n_state.is_multiple_of(self.n_head),
591 "State size {} must be a multiple of head size {}",
592 self.n_state,
593 self.n_head
594 );
595
596 let n_head = self.n_head;
597 let query = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
598 let key = nn::LinearConfig::new(self.n_state, self.n_state)
599 .with_bias(false)
600 .init(device);
601 let value = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
602 let out = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
603
604 MultiHeadCrossAttention {
605 n_head,
606 query,
607 key,
608 value,
609 out,
610 }
611 }
612}
613
614#[derive(Module, Debug)]
615pub struct MultiHeadCrossAttention<B: Backend> {
616 pub n_head: usize,
617 pub query: nn::Linear<B>,
618 pub key: nn::Linear<B>,
619 pub value: nn::Linear<B>,
620 pub out: nn::Linear<B>,
621}
622
623impl<B: Backend> MultiHeadCrossAttention<B> {
624 pub fn forward(&self, x: Tensor<B, 3>, xa: Tensor<B, 3>) -> Tensor<B, 3> {
625 let q = self.query.forward(x);
626 let k = self.key.forward(xa.clone());
627 let v = self.value.forward(xa);
628
629 let wv = qkv_attention(q, k, v, None, self.n_head);
630
631 self.out.forward(wv)
632 }
633}
634
635pub fn qkv_attention<B: Backend>(
640 q: Tensor<B, 3>,
641 k: Tensor<B, 3>,
642 v: Tensor<B, 3>,
643 mask: Option<Tensor<B, 2>>,
644 n_head: usize,
645) -> Tensor<B, 3> {
646 let [n_batch, n_qctx, n_state] = q.dims();
647 let [_, n_ctx, _] = k.dims();
648
649 let scale = (n_state as f64 / n_head as f64).powf(-0.25);
650 let n_hstate = n_state / n_head;
651
652 let q = q
653 .reshape([n_batch, n_qctx, n_head, n_hstate])
654 .swap_dims(1, 2)
655 * scale;
656 let k = k
657 .reshape([n_batch, n_ctx, n_head, n_hstate])
658 .swap_dims(1, 2)
659 .transpose()
660 * scale;
661 let v = v
662 .reshape([n_batch, n_ctx, n_head, n_hstate])
663 .swap_dims(1, 2);
664
665 let qk = q.matmul(k);
666
667 let qk = if let Some(mask) = mask {
669 qk + mask.slice([0..n_qctx, 0..n_ctx]).unsqueeze::<4>()
670 } else {
671 qk
672 };
673
674 let w = softmax(qk, 3);
676 let o = w.matmul(v).swap_dims(1, 2).flatten(2, 3);
677
678 o
679}
680
681pub fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> Tensor<B, 2> {
683 let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length], device);
684
685 for i in 0..(seq_length - 1) {
686 let values =
687 Tensor::<B, 2>::zeros([1, seq_length - (i + 1)], device).add_scalar(NEG_INFINITY);
688 mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values);
689 }
690
691 mask
692}
693
#[cfg(test)]
mod tests {
    use super::*;
    use burn_flex::Flex;
    use burn_flex::FlexDevice;

    type TestBackend = Flex<f32>;

    #[test]
    fn test_config_creation() {
        let config = WhisperConfig::tiny_en();
        assert_eq!(config.audio_encoder_config.n_audio_state, 384);
        assert_eq!(config.text_decoder_config.n_text_state, 384);
        assert_eq!(config.audio_encoder_config.n_mels, 80);
    }

    #[test]
    fn test_model_init() {
        let device = FlexDevice;
        let config = WhisperConfig::tiny_en();
        let model = config.init::<TestBackend>(&device);

        assert_eq!(model.encoder.n_mels, 80);
        assert_eq!(model.decoder.n_vocab, 51864);
    }

    #[test]
    fn test_encoder_forward() {
        let device = FlexDevice;
        let config = WhisperConfig::tiny_en();
        let model = config.init::<TestBackend>(&device);

        let mel = Tensor::random([1, 80, 100], Distribution::Normal(0.0, 1.0), &device);
        let output = model.encoder.forward(mel);

        assert_eq!(output.dims()[0], 1);
        // The stride-2 second convolution halves the 100 input frames.
        assert_eq!(output.dims()[1], 50);
        assert_eq!(output.dims()[2], 384);
    }

    #[test]
    fn test_decoder_forward() {
        let device = FlexDevice;
        let config = WhisperConfig::tiny_en();
        let model = config.init::<TestBackend>(&device);

        let encoder_output = Tensor::random([1, 50, 384], Distribution::Normal(0.0, 1.0), &device);
        let tokens = Tensor::<TestBackend, 2, Int>::zeros([1, 5], &device);

        let logits = model.decoder.forward(tokens, encoder_output);

        assert_eq!(logits.dims()[0], 1);
        assert_eq!(logits.dims()[1], 5);
        assert_eq!(logits.dims()[2], 51864);
    }

    #[test]
    fn test_attention_mask() {
        let device = FlexDevice;
        let n = 4;
        let mask = attn_decoder_mask::<TestBackend>(n, &device);

        assert_eq!(mask.dims(), [n, n]);

        // Strictly-upper-triangular entries are -inf; the diagonal and below are zero.
        let values = mask
            .into_data()
            .to_vec::<f32>()
            .expect("mask should hold f32 values");
        for (idx, value) in values.into_iter().enumerate() {
            let (row, col) = (idx / n, idx % n);
            let expected = if col > row { f32::NEG_INFINITY } else { 0.0 };
            assert_eq!(value, expected, "unexpected mask value at ({row}, {col})");
        }
    }
}
764}