Skip to main content

whisperforge_core/
model.rs

1// Whisper model implementation adapted from whisper-burn (MIT License)
2// https://github.com/Gadersd/whisper-burn
3
4use std::f32::NEG_INFINITY;
5
6use burn::{
7    config::Config,
8    module::{Module, Param},
9    nn::{
10        self, PaddingConfig1d,
11        conv::{Conv1d, Conv1dConfig},
12    },
13    tensor::{Distribution, Int, Tensor, activation::softmax, backend::Backend},
14};
15
/// Top-level configuration for a [`Whisper`] model.
///
/// Bundles the encoder and decoder settings. [`WhisperConfig::init`] requires
/// both halves to use the same state width.
#[derive(Config, Debug)]
pub struct WhisperConfig {
    /// Settings used to build the audio encoder.
    pub audio_encoder_config: AudioEncoderConfig,
    /// Settings used to build the text decoder.
    pub text_decoder_config: TextDecoderConfig,
}
22
23impl WhisperConfig {
24    /// Create config for tiny.en model
25    pub fn tiny_en() -> Self {
26        Self {
27            audio_encoder_config: AudioEncoderConfig {
28                n_mels: 80,
29                n_audio_ctx: 1500,
30                n_audio_state: 384,
31                n_audio_head: 6,
32                n_audio_layer: 4,
33            },
34            text_decoder_config: TextDecoderConfig {
35                n_vocab: 51864,
36                n_text_ctx: 448,
37                n_text_state: 384,
38                n_text_head: 6,
39                n_text_layer: 4,
40            },
41        }
42    }
43
44    /// Create config for base model
45    pub fn base() -> Self {
46        Self {
47            audio_encoder_config: AudioEncoderConfig {
48                n_mels: 80,
49                n_audio_ctx: 1500,
50                n_audio_state: 512,
51                n_audio_head: 8,
52                n_audio_layer: 6,
53            },
54            text_decoder_config: TextDecoderConfig {
55                n_vocab: 51864,
56                n_text_ctx: 448,
57                n_text_state: 512,
58                n_text_head: 8,
59                n_text_layer: 6,
60            },
61        }
62    }
63
64    /// Create config for small model
65    pub fn small() -> Self {
66        Self {
67            audio_encoder_config: AudioEncoderConfig {
68                n_mels: 80,
69                n_audio_ctx: 1500,
70                n_audio_state: 768,
71                n_audio_head: 12,
72                n_audio_layer: 12,
73            },
74            text_decoder_config: TextDecoderConfig {
75                n_vocab: 51864,
76                n_text_ctx: 448,
77                n_text_state: 768,
78                n_text_head: 12,
79                n_text_layer: 12,
80            },
81        }
82    }
83
84    /// Create config for medium model
85    pub fn medium() -> Self {
86        Self {
87            audio_encoder_config: AudioEncoderConfig {
88                n_mels: 80,
89                n_audio_ctx: 1500,
90                n_audio_state: 1024,
91                n_audio_head: 16,
92                n_audio_layer: 24,
93            },
94            text_decoder_config: TextDecoderConfig {
95                n_vocab: 51864,
96                n_text_ctx: 448,
97                n_text_state: 1024,
98                n_text_head: 16,
99                n_text_layer: 24,
100            },
101        }
102    }
103
104    /// Create config for large-v2 model
105    pub fn large_v2() -> Self {
106        Self {
107            audio_encoder_config: AudioEncoderConfig {
108                n_mels: 128,
109                n_audio_ctx: 1500,
110                n_audio_state: 1280,
111                n_audio_head: 20,
112                n_audio_layer: 32,
113            },
114            text_decoder_config: TextDecoderConfig {
115                n_vocab: 51864,
116                n_text_ctx: 448,
117                n_text_state: 1280,
118                n_text_head: 20,
119                n_text_layer: 32,
120            },
121        }
122    }
123
124    /// Create config for large-v3 model
125    pub fn large_v3() -> Self {
126        Self {
127            audio_encoder_config: AudioEncoderConfig {
128                n_mels: 128,
129                n_audio_ctx: 1500,
130                n_audio_state: 1280,
131                n_audio_head: 20,
132                n_audio_layer: 32,
133            },
134            text_decoder_config: TextDecoderConfig {
135                n_vocab: 51865,
136                n_text_ctx: 448,
137                n_text_state: 1280,
138                n_text_head: 20,
139                n_text_layer: 32,
140            },
141        }
142    }
143
144    pub fn init<B: Backend>(&self, device: &B::Device) -> Whisper<B> {
145        let n_audio_state = self.audio_encoder_config.n_audio_state;
146        let n_text_state = self.text_decoder_config.n_text_state;
147
148        assert!(
149            n_audio_state == n_text_state,
150            "Audio encoder state size {} must be equal to text decoder state size {}.",
151            n_audio_state,
152            n_text_state
153        );
154
155        let encoder = self.audio_encoder_config.init(device);
156        let decoder = self.text_decoder_config.init(device);
157
158        Whisper { encoder, decoder }
159    }
160}
161
/// The main Whisper model: an audio encoder paired with a text decoder.
#[derive(Module, Debug)]
pub struct Whisper<B: Backend> {
    /// Turns a mel spectrogram into the features the decoder attends to.
    pub encoder: AudioEncoder<B>,
    /// Produces vocabulary logits from tokens and the encoder output.
    pub decoder: TextDecoder<B>,
}
168
169impl<B: Backend> Whisper<B> {
170    /// Full forward pass: encode audio and decode tokens
171    pub fn forward(&self, mel: Tensor<B, 3>, tokens: Tensor<B, 2, Int>) -> Tensor<B, 3> {
172        self.decoder.forward(tokens, self.encoder.forward(mel))
173    }
174
175    /// Encode audio mel spectrogram
176    pub fn forward_encoder(&self, mel: Tensor<B, 3>) -> Tensor<B, 3> {
177        self.encoder.forward(mel)
178    }
179
180    /// Decode tokens given encoder output
181    pub fn forward_decoder(
182        &self,
183        tokens: Tensor<B, 2, Int>,
184        encoder_output: Tensor<B, 3>,
185    ) -> Tensor<B, 3> {
186        self.decoder.forward(tokens, encoder_output)
187    }
188
189    pub fn encoder_ctx_size(&self) -> usize {
190        self.encoder.ctx_size()
191    }
192
193    pub fn decoder_ctx_size(&self) -> usize {
194        self.decoder.ctx_size()
195    }
196}
197
198// ============================================================================
199// Text Decoder
200// ============================================================================
201
/// Configuration for [`TextDecoder`].
#[derive(Config, Debug)]
pub struct TextDecoderConfig {
    /// Number of rows in the token embedding table.
    pub n_vocab: usize,
    /// Number of rows in the positional embedding and side length of the
    /// causal mask.
    pub n_text_ctx: usize,
    /// Width of the token and positional embeddings and of every block.
    pub n_text_state: usize,
    /// Head count passed to each decoder block.
    pub n_text_head: usize,
    /// Number of decoder blocks to create.
    pub n_text_layer: usize,
}
210
211impl TextDecoderConfig {
212    pub fn init<B: Backend>(&self, device: &B::Device) -> TextDecoder<B> {
213        let token_embedding = Param::from_tensor(Tensor::random(
214            [self.n_vocab, self.n_text_state],
215            Distribution::Normal(0.0, 0.02),
216            device,
217        ));
218        let positional_embedding = Param::from_tensor(Tensor::random(
219            [self.n_text_ctx, self.n_text_state],
220            Distribution::Normal(0.0, 0.01),
221            device,
222        ));
223        let blocks: Vec<_> = (0..self.n_text_layer)
224            .map(|_| {
225                ResidualDecoderAttentionBlockConfig::new(self.n_text_state, self.n_text_head)
226                    .init(device)
227            })
228            .collect();
229        let ln = nn::LayerNormConfig::new(self.n_text_state).init(device);
230
231        let mask = Param::from_tensor(attn_decoder_mask(self.n_text_ctx, device));
232
233        let n_vocab = self.n_vocab;
234        let n_text_ctx = self.n_text_ctx;
235
236        TextDecoder {
237            token_embedding,
238            positional_embedding,
239            blocks,
240            ln,
241            mask,
242            n_vocab,
243            n_text_ctx,
244        }
245    }
246}
247
/// Text decoder: token and positional embeddings followed by a stack of
/// decoder blocks and a final layer norm.
#[derive(Module, Debug)]
pub struct TextDecoder<B: Backend> {
    /// Token embedding table; also reused (transposed) to project to logits.
    pub token_embedding: Param<Tensor<B, 2>>,
    /// Positional embedding, sliced to the current sequence length.
    pub positional_embedding: Param<Tensor<B, 2>>,
    /// Decoder blocks applied in order.
    pub blocks: Vec<ResidualDecoderAttentionBlock<B>>,
    /// Layer norm applied after the last block.
    pub ln: nn::LayerNorm<B>,
    /// Causal mask handed to every block's self-attention.
    pub mask: Param<Tensor<B, 2>>,
    /// Vocabulary size the decoder was configured with.
    pub n_vocab: usize,
    /// Maximum token sequence length accepted by `forward`.
    pub n_text_ctx: usize,
}
258
259impl<B: Backend> TextDecoder<B> {
260    pub fn forward(&self, x: Tensor<B, 2, Int>, xa: Tensor<B, 3>) -> Tensor<B, 3> {
261        let [_n_batch, seq_len] = x.dims();
262
263        assert!(
264            seq_len <= self.n_text_ctx,
265            "Token sequence length {} must not exceed {}.",
266            seq_len,
267            self.n_text_ctx
268        );
269
270        // Token embedding lookup
271        let x = burn::tensor::module::embedding(self.token_embedding.val(), x)
272            + self
273                .positional_embedding
274                .val()
275                .slice([0..seq_len])
276                .unsqueeze::<3>();
277
278        let mut x = x;
279        for block in &self.blocks {
280            x = block.forward(x, xa.clone(), self.mask.val());
281        }
282
283        let x = self.ln.forward(x);
284        // Project to vocabulary logits
285        x.matmul(self.token_embedding.val().transpose().unsqueeze::<3>())
286    }
287
288    pub fn ctx_size(&self) -> usize {
289        self.n_text_ctx
290    }
291}
292
293// ============================================================================
294// Audio Encoder
295// ============================================================================
296
/// Configuration for [`AudioEncoder`].
#[derive(Config, Debug)]
pub struct AudioEncoderConfig {
    /// Number of input channels of the first convolution.
    pub n_mels: usize,
    /// Number of rows in the encoder's positional embedding.
    pub n_audio_ctx: usize,
    /// Output width of both convolutions and width of every block.
    pub n_audio_state: usize,
    /// Head count passed to each encoder block.
    pub n_audio_head: usize,
    /// Number of encoder blocks to create.
    pub n_audio_layer: usize,
}
305
306impl AudioEncoderConfig {
307    pub fn init<B: Backend>(&self, device: &B::Device) -> AudioEncoder<B> {
308        let conv1 = Conv1dConfig::new(self.n_mels, self.n_audio_state, 3)
309            .with_padding(PaddingConfig1d::Explicit(1, 1))
310            .init(device);
311        let gelu1 = nn::Gelu::new();
312        let conv2 = Conv1dConfig::new(self.n_audio_state, self.n_audio_state, 3)
313            .with_padding(PaddingConfig1d::Explicit(1, 1))
314            .with_stride(2)
315            .init(device);
316        let gelu2 = nn::Gelu::new();
317        let blocks: Vec<_> = (0..self.n_audio_layer)
318            .map(|_| {
319                ResidualEncoderAttentionBlockConfig::new(self.n_audio_state, self.n_audio_head)
320                    .init(device)
321            })
322            .collect();
323        let ln_post = nn::LayerNormConfig::new(self.n_audio_state).init(device);
324        let positional_embedding = Param::from_tensor(Tensor::random(
325            [self.n_audio_ctx, self.n_audio_state],
326            Distribution::Normal(0.0, 0.01),
327            device,
328        ));
329        let n_mels = self.n_mels;
330        let n_audio_ctx = self.n_audio_ctx;
331
332        AudioEncoder {
333            conv1,
334            gelu1,
335            conv2,
336            gelu2,
337            blocks,
338            ln_post,
339            positional_embedding,
340            n_mels,
341            n_audio_ctx,
342        }
343    }
344}
345
/// Audio encoder: two GELU-activated convolutions, a positional embedding, a
/// stack of encoder blocks and a final layer norm.
#[derive(Module, Debug)]
pub struct AudioEncoder<B: Backend> {
    /// First convolution, applied directly to the mel input.
    pub conv1: Conv1d<B>,
    /// Activation applied after `conv1`.
    pub gelu1: nn::Gelu,
    /// Second convolution, configured with stride 2.
    pub conv2: Conv1d<B>,
    /// Activation applied after `conv2`.
    pub gelu2: nn::Gelu,
    /// Encoder blocks applied in order.
    pub blocks: Vec<ResidualEncoderAttentionBlock<B>>,
    /// Layer norm applied after the last block.
    pub ln_post: nn::LayerNorm<B>,
    /// Positional embedding, sliced to the number of frames after `conv2`.
    pub positional_embedding: Param<Tensor<B, 2>>,
    /// Mel-bin count that `forward` requires of its input.
    pub n_mels: usize,
    /// Context size; `forward` accepts inputs up to twice this length.
    pub n_audio_ctx: usize,
}
358
359impl<B: Backend> AudioEncoder<B> {
360    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
361        let [_, n_mels, n_ctx] = x.dims();
362
363        assert!(
364            n_mels == self.n_mels,
365            "Audio mel spectrum size must be {}.",
366            self.n_mels
367        );
368        assert!(
369            n_ctx <= 2 * self.n_audio_ctx,
370            "Audio length {} cannot exceed {}.",
371            n_ctx,
372            2 * self.n_audio_ctx
373        );
374
375        let x = self.gelu1.forward(self.conv1.forward(x));
376        let x = self.gelu2.forward(self.conv2.forward(x));
377
378        let x = x.swap_dims(1, 2);
379        let k = x.dims()[1];
380        let x = x + self
381            .positional_embedding
382            .val()
383            .slice([0..k])
384            .unsqueeze::<3>();
385
386        let mut x = x;
387        for block in &self.blocks {
388            x = block.forward(x);
389        }
390
391        self.ln_post.forward(x)
392    }
393
394    pub fn ctx_size(&self) -> usize {
395        self.n_audio_ctx
396    }
397}
398
399// ============================================================================
400// Attention Blocks
401// ============================================================================
402
/// Configuration for [`ResidualEncoderAttentionBlock`].
#[derive(Config, Debug)]
pub struct ResidualEncoderAttentionBlockConfig {
    /// Width shared by the attention, layer norms and MLP of the block.
    n_state: usize,
    /// Number of attention heads.
    n_head: usize,
}
408
409impl ResidualEncoderAttentionBlockConfig {
410    pub fn init<B: Backend>(&self, device: &B::Device) -> ResidualEncoderAttentionBlock<B> {
411        let attn = MultiHeadSelfAttentionConfig::new(self.n_state, self.n_head).init(device);
412        let attn_ln = nn::LayerNormConfig::new(self.n_state).init(device);
413
414        let mlp = MLPConfig::new(self.n_state).init(device);
415        let mlp_ln = nn::LayerNormConfig::new(self.n_state).init(device);
416
417        ResidualEncoderAttentionBlock {
418            attn,
419            attn_ln,
420            mlp,
421            mlp_ln,
422        }
423    }
424}
425
/// Encoder block: unmasked self-attention followed by an MLP, each applied to
/// a layer-normalised input and added back to it.
#[derive(Module, Debug)]
pub struct ResidualEncoderAttentionBlock<B: Backend> {
    /// Self-attention sub-layer.
    pub attn: MultiHeadSelfAttention<B>,
    /// Layer norm applied before `attn`.
    pub attn_ln: nn::LayerNorm<B>,
    /// Feed-forward sub-layer.
    pub mlp: MLP<B>,
    /// Layer norm applied before `mlp`.
    pub mlp_ln: nn::LayerNorm<B>,
}
433
434impl<B: Backend> ResidualEncoderAttentionBlock<B> {
435    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
436        let x = x.clone() + self.attn.forward(self.attn_ln.forward(x), None);
437        let x = x.clone() + self.mlp.forward(self.mlp_ln.forward(x));
438        x
439    }
440}
441
/// Configuration for [`ResidualDecoderAttentionBlock`].
#[derive(Config, Debug)]
pub struct ResidualDecoderAttentionBlockConfig {
    /// Width shared by both attentions, the layer norms and the MLP.
    n_state: usize,
    /// Number of attention heads.
    n_head: usize,
}
447
448impl ResidualDecoderAttentionBlockConfig {
449    pub fn init<B: Backend>(&self, device: &B::Device) -> ResidualDecoderAttentionBlock<B> {
450        let attn = MultiHeadSelfAttentionConfig::new(self.n_state, self.n_head).init(device);
451        let attn_ln = nn::LayerNormConfig::new(self.n_state).init(device);
452
453        let cross_attn = MultiHeadCrossAttentionConfig::new(self.n_state, self.n_head).init(device);
454        let cross_attn_ln = nn::LayerNormConfig::new(self.n_state).init(device);
455
456        let mlp = MLPConfig::new(self.n_state).init(device);
457        let mlp_ln = nn::LayerNormConfig::new(self.n_state).init(device);
458
459        ResidualDecoderAttentionBlock {
460            attn,
461            attn_ln,
462            cross_attn,
463            cross_attn_ln,
464            mlp,
465            mlp_ln,
466        }
467    }
468}
469
/// Decoder block: masked self-attention, cross-attention over the encoder
/// output and an MLP, each applied to a layer-normalised input and added back.
#[derive(Module, Debug)]
pub struct ResidualDecoderAttentionBlock<B: Backend> {
    /// Self-attention sub-layer (receives the mask).
    pub attn: MultiHeadSelfAttention<B>,
    /// Layer norm applied before `attn`.
    pub attn_ln: nn::LayerNorm<B>,
    /// Cross-attention sub-layer over the encoder output.
    pub cross_attn: MultiHeadCrossAttention<B>,
    /// Layer norm applied before `cross_attn`.
    pub cross_attn_ln: nn::LayerNorm<B>,
    /// Feed-forward sub-layer.
    pub mlp: MLP<B>,
    /// Layer norm applied before `mlp`.
    pub mlp_ln: nn::LayerNorm<B>,
}
479
480impl<B: Backend> ResidualDecoderAttentionBlock<B> {
481    pub fn forward(&self, x: Tensor<B, 3>, xa: Tensor<B, 3>, mask: Tensor<B, 2>) -> Tensor<B, 3> {
482        let x = x.clone() + self.attn.forward(self.attn_ln.forward(x), Some(mask));
483        let x = x.clone() + self.cross_attn.forward(self.cross_attn_ln.forward(x), xa);
484        let x = x.clone() + self.mlp.forward(self.mlp_ln.forward(x));
485        x
486    }
487}
488
489// ============================================================================
490// MLP (Feed-Forward Network)
491// ============================================================================
492
/// Configuration for the feed-forward [`MLP`].
#[derive(Config, Debug)]
pub struct MLPConfig {
    /// Input and output width; the hidden layer is four times as wide.
    n_state: usize,
}
497
498impl MLPConfig {
499    pub fn init<B: Backend>(&self, device: &B::Device) -> MLP<B> {
500        let lin1 = nn::LinearConfig::new(self.n_state, 4 * self.n_state).init(device);
501        let gelu = nn::Gelu::new();
502        let lin2 = nn::LinearConfig::new(4 * self.n_state, self.n_state).init(device);
503
504        MLP { lin1, gelu, lin2 }
505    }
506}
507
/// Feed-forward network: linear, GELU, linear.
#[derive(Module, Debug)]
pub struct MLP<B: Backend> {
    /// First linear layer (expands the width).
    pub lin1: nn::Linear<B>,
    /// Activation between the two linear layers.
    pub gelu: nn::Gelu,
    /// Second linear layer (projects back to the input width).
    pub lin2: nn::Linear<B>,
}
514
515impl<B: Backend> MLP<B> {
516    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
517        let x = self.lin1.forward(x);
518        let x = self.gelu.forward(x);
519        self.lin2.forward(x)
520    }
521}
522
523// ============================================================================
524// Multi-Head Attention
525// ============================================================================
526
/// Configuration for [`MultiHeadSelfAttention`].
#[derive(Config, Debug)]
pub struct MultiHeadSelfAttentionConfig {
    /// Width of the query, key, value and output projections.
    n_state: usize,
    /// Number of attention heads; `n_state` must be a multiple of it.
    n_head: usize,
}
532
533impl MultiHeadSelfAttentionConfig {
534    pub fn init<B: Backend>(&self, device: &B::Device) -> MultiHeadSelfAttention<B> {
535        assert!(
536            self.n_state.is_multiple_of(self.n_head),
537            "State size {} must be a multiple of head size {}",
538            self.n_state,
539            self.n_head
540        );
541
542        let n_head = self.n_head;
543        let query = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
544        let key = nn::LinearConfig::new(self.n_state, self.n_state)
545            .with_bias(false)
546            .init(device);
547        let value = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
548        let out = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
549
550        MultiHeadSelfAttention {
551            n_head,
552            query,
553            key,
554            value,
555            out,
556        }
557    }
558}
559
/// Multi-head attention in which queries, keys and values all come from the
/// same input.
#[derive(Module, Debug)]
pub struct MultiHeadSelfAttention<B: Backend> {
    /// Number of attention heads.
    pub n_head: usize,
    /// Query projection.
    pub query: nn::Linear<B>,
    /// Key projection (built without a bias).
    pub key: nn::Linear<B>,
    /// Value projection.
    pub value: nn::Linear<B>,
    /// Output projection applied to the attention result.
    pub out: nn::Linear<B>,
}
568
569impl<B: Backend> MultiHeadSelfAttention<B> {
570    pub fn forward(&self, x: Tensor<B, 3>, mask: Option<Tensor<B, 2>>) -> Tensor<B, 3> {
571        let q = self.query.forward(x.clone());
572        let k = self.key.forward(x.clone());
573        let v = self.value.forward(x);
574
575        let wv = qkv_attention(q, k, v, mask, self.n_head);
576
577        self.out.forward(wv)
578    }
579}
580
/// Configuration for [`MultiHeadCrossAttention`].
#[derive(Config, Debug)]
pub struct MultiHeadCrossAttentionConfig {
    /// Width of the query, key, value and output projections.
    n_state: usize,
    /// Number of attention heads; `n_state` must be a multiple of it.
    n_head: usize,
}
586
587impl MultiHeadCrossAttentionConfig {
588    pub fn init<B: Backend>(&self, device: &B::Device) -> MultiHeadCrossAttention<B> {
589        assert!(
590            self.n_state.is_multiple_of(self.n_head),
591            "State size {} must be a multiple of head size {}",
592            self.n_state,
593            self.n_head
594        );
595
596        let n_head = self.n_head;
597        let query = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
598        let key = nn::LinearConfig::new(self.n_state, self.n_state)
599            .with_bias(false)
600            .init(device);
601        let value = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
602        let out = nn::LinearConfig::new(self.n_state, self.n_state).init(device);
603
604        MultiHeadCrossAttention {
605            n_head,
606            query,
607            key,
608            value,
609            out,
610        }
611    }
612}
613
/// Multi-head attention in which queries come from one input and keys and
/// values from another.
#[derive(Module, Debug)]
pub struct MultiHeadCrossAttention<B: Backend> {
    /// Number of attention heads.
    pub n_head: usize,
    /// Query projection, applied to the first input.
    pub query: nn::Linear<B>,
    /// Key projection (built without a bias), applied to the second input.
    pub key: nn::Linear<B>,
    /// Value projection, applied to the second input.
    pub value: nn::Linear<B>,
    /// Output projection applied to the attention result.
    pub out: nn::Linear<B>,
}
622
623impl<B: Backend> MultiHeadCrossAttention<B> {
624    pub fn forward(&self, x: Tensor<B, 3>, xa: Tensor<B, 3>) -> Tensor<B, 3> {
625        let q = self.query.forward(x);
626        let k = self.key.forward(xa.clone());
627        let v = self.value.forward(xa);
628
629        let wv = qkv_attention(q, k, v, None, self.n_head);
630
631        self.out.forward(wv)
632    }
633}
634
635// ============================================================================
636// Attention Utilities
637// ============================================================================
638
639pub fn qkv_attention<B: Backend>(
640    q: Tensor<B, 3>,
641    k: Tensor<B, 3>,
642    v: Tensor<B, 3>,
643    mask: Option<Tensor<B, 2>>,
644    n_head: usize,
645) -> Tensor<B, 3> {
646    let [n_batch, n_qctx, n_state] = q.dims();
647    let [_, n_ctx, _] = k.dims();
648
649    let scale = (n_state as f64 / n_head as f64).powf(-0.25);
650    let n_hstate = n_state / n_head;
651
652    let q = q
653        .reshape([n_batch, n_qctx, n_head, n_hstate])
654        .swap_dims(1, 2)
655        * scale;
656    let k = k
657        .reshape([n_batch, n_ctx, n_head, n_hstate])
658        .swap_dims(1, 2)
659        .transpose()
660        * scale;
661    let v = v
662        .reshape([n_batch, n_ctx, n_head, n_hstate])
663        .swap_dims(1, 2);
664
665    let qk = q.matmul(k);
666
667    // Apply mask
668    let qk = if let Some(mask) = mask {
669        qk + mask.slice([0..n_qctx, 0..n_ctx]).unsqueeze::<4>()
670    } else {
671        qk
672    };
673
674    // Normalize value weightings
675    let w = softmax(qk, 3);
676    let o = w.matmul(v).swap_dims(1, 2).flatten(2, 3);
677
678    o
679}
680
681/// Create causal attention mask for decoder
682pub fn attn_decoder_mask<B: Backend>(seq_length: usize, device: &B::Device) -> Tensor<B, 2> {
683    let mut mask = Tensor::<B, 2>::zeros([seq_length, seq_length], device);
684
685    for i in 0..(seq_length - 1) {
686        let values =
687            Tensor::<B, 2>::zeros([1, seq_length - (i + 1)], device).add_scalar(NEG_INFINITY);
688        mask = mask.slice_assign([i..i + 1, i + 1..seq_length], values);
689    }
690
691    mask
692}
693
#[cfg(test)]
mod tests {
    use super::*;
    use burn_flex::Flex;
    use burn_flex::FlexDevice;

    type TestBackend = Flex<f32>;

    /// The tiny.en preset exposes the expected widths and mel-bin count.
    #[test]
    fn test_config_creation() {
        let config = WhisperConfig::tiny_en();
        assert_eq!(config.audio_encoder_config.n_audio_state, 384);
        assert_eq!(config.text_decoder_config.n_text_state, 384);
        assert_eq!(config.audio_encoder_config.n_mels, 80);
    }

    /// Initialising the model carries the config values into the modules.
    #[test]
    fn test_model_init() {
        let device = FlexDevice;
        let config = WhisperConfig::tiny_en();
        let model = config.init::<TestBackend>(&device);

        assert_eq!(model.encoder.n_mels, 80);
        assert_eq!(model.decoder.n_vocab, 51864);
    }

    /// The encoder halves the time axis and emits `n_state` features.
    #[test]
    fn test_encoder_forward() {
        let device = FlexDevice;
        let config = WhisperConfig::tiny_en();
        let model = config.init::<TestBackend>(&device);

        // Input: [batch=1, n_mels=80, time=100]
        let mel = Tensor::random([1, 80, 100], Distribution::Normal(0.0, 1.0), &device);
        let output = model.encoder.forward(mel);

        // Output: [batch=1, time/2=50, n_state=384]
        assert_eq!(output.dims()[0], 1);
        assert_eq!(output.dims()[1], 50); // Stride 2 halves the time dimension
        assert_eq!(output.dims()[2], 384);
    }

    /// The decoder emits one logit vector of vocabulary size per token.
    #[test]
    fn test_decoder_forward() {
        let device = FlexDevice;
        let config = WhisperConfig::tiny_en();
        let model = config.init::<TestBackend>(&device);

        // Encoder output: [batch=1, time=50, n_state=384]
        let encoder_output = Tensor::random([1, 50, 384], Distribution::Normal(0.0, 1.0), &device);

        // Tokens: [batch=1, seq_len=5]
        let tokens = Tensor::<TestBackend, 2, Int>::zeros([1, 5], &device);

        let logits = model.decoder.forward(tokens, encoder_output);

        // Output: [batch=1, seq_len=5, vocab=51864]
        assert_eq!(logits.dims()[0], 1);
        assert_eq!(logits.dims()[1], 5);
        assert_eq!(logits.dims()[2], 51864);
    }

    /// The causal mask is zero on and below the diagonal and -inf above it.
    #[test]
    fn test_attention_mask() {
        const N: usize = 4;

        let device = FlexDevice;
        let mask = attn_decoder_mask::<TestBackend>(N, &device);

        assert_eq!(mask.dims(), [N, N]);

        // Read the mask back in row-major order and check every element.
        let values = mask
            .into_data()
            .convert::<f32>()
            .to_vec::<f32>()
            .expect("mask data should be readable as f32");
        assert_eq!(values.len(), N * N);

        for row in 0..N {
            for col in 0..N {
                let value = values[row * N + col];
                if col > row {
                    assert_eq!(
                        value,
                        f32::NEG_INFINITY,
                        "position ({row}, {col}) must be masked"
                    );
                } else {
                    assert_eq!(value, 0.0, "position ({row}, {col}) must be visible");
                }
            }
        }
    }
}