//! Performer model implementation (`trustformers_models/performer/model.rs`).
use crate::performer::config::PerformerConfig;
use std::io::Read;
use trustformers_core::{
    device::Device,
    errors::Result,
    layers::{Embedding, LayerNorm, Linear},
    tensor::Tensor,
    traits::{Config, Layer, Model},
};
10
/// FAVOR+ attention mechanism for linear complexity
/// Approximates softmax attention using positive random features
pub struct FavorPlusAttention {
    // Learned Q/K/V projections and the final output projection.
    query: Linear,
    key: Linear,
    value: Linear,
    output: Linear,

    num_attention_heads: usize,
    // Dimension of each attention head.
    attention_head_size: usize,
    // Number of random features m used by the FAVOR+ approximation.
    num_random_features: usize,
    // Feature-map kernel: "relu", "exp" or "softmax+"; any other value
    // falls back to the ReLU kernel (see `apply_feature_map`).
    kernel_type: String,
    // When true, each position attends only causally (prefix-sum path).
    causal: bool,
    // When true, the sampled random projection matrix is normalized.
    normalize_features: bool,
    // Small constant added to attention denominators to avoid division by ~0.
    numerical_stabilizer: f32,

    // Random feature matrices (would be redrawn periodically in training)
    random_features: Option<Tensor>,

    device: Device,
}
32
33impl FavorPlusAttention {
34    pub fn new(config: &PerformerConfig) -> Result<Self> {
35        Self::new_with_device(config, Device::CPU)
36    }
37
38    pub fn new_with_device(config: &PerformerConfig, device: Device) -> Result<Self> {
39        let attention_head_size = config.head_dim();
40        let all_head_size = config.num_attention_heads * attention_head_size;
41
42        let query = Linear::new_with_device(config.hidden_size, all_head_size, true, device);
43        let key = Linear::new_with_device(config.hidden_size, all_head_size, true, device);
44        let value = Linear::new_with_device(config.hidden_size, all_head_size, true, device);
45        let output = Linear::new_with_device(all_head_size, config.hidden_size, true, device);
46
47        Ok(Self {
48            query,
49            key,
50            value,
51            output,
52            num_attention_heads: config.num_attention_heads,
53            attention_head_size,
54            num_random_features: config.num_random_features,
55            kernel_type: config.kernel_type.clone(),
56            causal: config.causal_attention,
57            normalize_features: config.normalize_features,
58            numerical_stabilizer: config.numerical_stabilizer,
59            random_features: None,
60            device,
61        })
62    }
63
64    pub fn device(&self) -> Device {
65        self.device
66    }
67
68    pub fn parameter_count(&self) -> usize {
69        self.query.parameter_count()
70            + self.key.parameter_count()
71            + self.value.parameter_count()
72            + self.output.parameter_count()
73    }
74
75    /// Generate random features for FAVOR+ approximation
76    fn generate_random_features(&self, _device: &str) -> Result<Tensor> {
77        // Generate random Gaussian matrix: [head_dim, num_random_features]
78        let random_matrix = Tensor::randn(&[self.attention_head_size, self.num_random_features])?;
79
80        if self.normalize_features {
81            // Normalize to unit length along the feature dimension
82            // Compute L2 norm across the feature dimension (axis 1)
83            let squared = random_matrix.mul(&random_matrix)?;
84            let sum_squared = squared.sum(None, false)?; // Sum across all dimensions
85            let norm = sum_squared.sqrt()?;
86
87            // Add small epsilon for numerical stability
88            let eps = Tensor::scalar(1e-8)?;
89            let stable_norm = norm.add(&eps)?;
90
91            // Normalize by broadcasting the norm
92            random_matrix.div(&stable_norm)
93        } else {
94            Ok(random_matrix)
95        }
96    }
97
98    /// Apply feature map function φ(x) based on kernel type
99    fn apply_feature_map(&self, x: &Tensor, random_features: &Tensor) -> Result<Tensor> {
100        // x: [batch, heads, seq_len, head_dim]
101        // random_features: [head_dim, num_random_features]
102
103        let _batch_size = x.shape()[0];
104        let _num_heads = x.shape()[1];
105        let _seq_len = x.shape()[2];
106
107        // Project: x @ random_features -> [batch, heads, seq_len, num_random_features]
108        let projections = x.matmul(random_features)?;
109
110        match self.kernel_type.as_str() {
111            "relu" => {
112                // ReLU kernel: φ(x) = sqrt(2/m) * max(0, x @ w)
113                let scale = (2.0 / self.num_random_features as f32).sqrt();
114                let features = projections.relu()?.mul_scalar(scale)?;
115                Ok(features)
116            },
117            "exp" => {
118                // Exponential kernel: φ(x) = exp(x @ w - ||x||²/2) / sqrt(m)
119                let x_norm_sq = x.pow(2.0)?.sum(Some(vec![x.shape().len() - 1]), true)?; // [batch, heads, seq_len, 1]
120                let scaled_proj = projections.sub(&x_norm_sq.mul_scalar(0.5)?)?;
121                let features = scaled_proj
122                    .exp()?
123                    .mul_scalar(1.0 / (self.num_random_features as f32).sqrt())?;
124                Ok(features)
125            },
126            "softmax+" => {
127                // Positive features for softmax approximation
128                let x_norm_sq = x.pow(2.0)?.sum(Some(vec![x.shape().len() - 1]), true)?;
129                let h = self.attention_head_size as f32;
130
131                // φ(x) = exp(x @ w - ||x||²/2) / sqrt(m) for better softmax approximation
132                let scaled_proj = projections.sub(&x_norm_sq.mul_scalar(0.5)?)?;
133                let features =
134                    scaled_proj.exp()?.mul_scalar((h / self.num_random_features as f32).sqrt())?;
135                Ok(features)
136            },
137            _ => {
138                // Default to ReLU
139                let scale = (2.0 / self.num_random_features as f32).sqrt();
140                let features = projections.relu()?.mul_scalar(scale)?;
141                Ok(features)
142            },
143        }
144    }
145
146    /// Compute FAVOR+ attention
147    fn favor_attention(
148        &self,
149        query_features: &Tensor,
150        key_features: &Tensor,
151        values: &Tensor,
152    ) -> Result<Tensor> {
153        // query_features, key_features: [batch, heads, seq_len, num_random_features]
154        // values: [batch, heads, seq_len, head_dim]
155
156        if self.causal {
157            // Causal attention: use cumulative sums
158            self.causal_favor_attention(query_features, key_features, values)
159        } else {
160            // Non-causal attention: use matrix multiplication
161            self.non_causal_favor_attention(query_features, key_features, values)
162        }
163    }
164
165    fn non_causal_favor_attention(
166        &self,
167        query_features: &Tensor,
168        key_features: &Tensor,
169        values: &Tensor,
170    ) -> Result<Tensor> {
171        // Compute D = sum(key_features, dim=seq_len)
172        // D: [batch, heads, num_random_features]
173        let d = key_features.sum(Some(vec![2]), false)?;
174
175        // Compute numerator: query_features @ (key_features^T @ values)
176        // key_features^T: [batch, heads, num_random_features, seq_len]
177        let key_features_t = key_features.transpose(
178            key_features.shape().len() - 2,
179            key_features.shape().len() - 1,
180        )?;
181
182        // kv: [batch, heads, num_random_features, head_dim]
183        let kv = key_features_t.matmul(values)?;
184
185        // numerator: [batch, heads, seq_len, head_dim]
186        let numerator = query_features.matmul(&kv)?;
187
188        // Compute denominator: query_features @ D
189        // denominator: [batch, heads, seq_len, 1]
190        let denominator = query_features.matmul(&d.unsqueeze(d.shape().len())?)?;
191        let denominator = denominator.add_scalar(self.numerical_stabilizer)?;
192
193        // Final attention output
194        numerator.div(&denominator)
195    }
196
197    fn causal_favor_attention(
198        &self,
199        query_features: &Tensor,
200        key_features: &Tensor,
201        values: &Tensor,
202    ) -> Result<Tensor> {
203        let batch_size = query_features.shape()[0];
204        let num_heads = query_features.shape()[1];
205        let seq_len = query_features.shape()[2];
206        let head_dim = values.shape()[3];
207
208        // Initialize output
209        let mut output = Tensor::zeros(&[batch_size, num_heads, seq_len, head_dim])?;
210
211        // Running sums for causal attention
212        let mut running_kv =
213            Tensor::zeros(&[batch_size, num_heads, self.num_random_features, head_dim])?;
214        let mut running_k = Tensor::zeros(&[batch_size, num_heads, self.num_random_features])?;
215
216        // Process each position causally
217        for i in 0..seq_len {
218            // Get current query, key, value using proper tensor slicing
219            let q_i = query_features.slice_multi(&[
220                (0, batch_size),
221                (0, num_heads),
222                (i, i + 1),
223                (0, self.num_random_features),
224            ])?;
225            let k_i = key_features.slice_multi(&[
226                (0, batch_size),
227                (0, num_heads),
228                (i, i + 1),
229                (0, self.num_random_features),
230            ])?;
231            let v_i = values.slice_multi(&[
232                (0, batch_size),
233                (0, num_heads),
234                (i, i + 1),
235                (0, head_dim),
236            ])?;
237
238            // Compute attention output for position i
239            let numerator = q_i.matmul(&running_kv)?;
240            let denominator = q_i.matmul(&running_k.unsqueeze(running_k.shape().len())?)?;
241            let denominator = denominator.add_scalar(self.numerical_stabilizer)?;
242
243            let att_output = numerator.div(&denominator)?;
244
245            // Build output tensor by concatenating position outputs
246            if i == 0 {
247                output = att_output.clone();
248            } else {
249                output = Tensor::concat(&[output, att_output], 2)?;
250            }
251
252            // Update running sums
253            let shape = k_i.shape();
254            let dim0 = shape.len().saturating_sub(2);
255            let dim1 = shape.len().saturating_sub(1);
256            let k_i_t = k_i.transpose(dim0, dim1)?; // [batch, heads, num_random_features, 1]
257            let kv_update = k_i_t.matmul(&v_i)?; // [batch, heads, num_random_features, head_dim]
258            running_kv = running_kv.add(&kv_update)?;
259            let shape = k_i.shape();
260            let squeeze_dim = shape.len().saturating_sub(2);
261            running_k = running_k.add(&k_i.squeeze(squeeze_dim)?)?;
262        }
263
264        Ok(output)
265    }
266
267    /// Transpose tensor for multi-head attention
268    fn transpose_for_scores(&self, x: &Tensor) -> Result<Tensor> {
269        let batch_size = x.shape()[0];
270        let seq_len = x.shape()[1];
271
272        // Reshape: [batch, seq, heads * head_dim] -> [batch, seq, heads, head_dim]
273        let reshaped = x.reshape(&[
274            batch_size,
275            seq_len,
276            self.num_attention_heads,
277            self.attention_head_size,
278        ])?;
279
280        // Permute: [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim]
281        reshaped.permute(&[0, 2, 1, 3])
282    }
283}
284
285impl Layer for FavorPlusAttention {
286    type Input = Tensor;
287    type Output = Tensor;
288
289    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
290        let batch_size = input.shape()[0];
291        let seq_len = input.shape()[1];
292
293        // Linear projections
294        let query_layer = self.query.forward(input.clone())?;
295        let key_layer = self.key.forward(input.clone())?;
296        let value_layer = self.value.forward(input)?;
297
298        // Transpose for multi-head attention
299        let query_layer = self.transpose_for_scores(&query_layer)?;
300        let key_layer = self.transpose_for_scores(&key_layer)?;
301        let value_layer = self.transpose_for_scores(&value_layer)?;
302
303        // Generate or reuse random features
304        let random_features = if let Some(ref features) = self.random_features {
305            features.clone()
306        } else {
307            self.generate_random_features("cpu")?
308        };
309
310        // Apply feature maps
311        let query_features = self.apply_feature_map(&query_layer, &random_features)?;
312        let key_features = self.apply_feature_map(&key_layer, &random_features)?;
313
314        // Compute FAVOR+ attention
315        let context_layer = self.favor_attention(&query_features, &key_features, &value_layer)?;
316
317        // Transpose back: [batch, heads, seq_len, head_dim] -> [batch, seq_len, heads, head_dim]
318        let context_layer = context_layer.permute(&[0, 2, 1, 3])?;
319
320        // Reshape: [batch, seq_len, heads, head_dim] -> [batch, seq_len, heads * head_dim]
321        let context_layer = context_layer.reshape(&[
322            batch_size,
323            seq_len,
324            self.num_attention_heads * self.attention_head_size,
325        ])?;
326
327        // Apply output projection
328        self.output.forward(context_layer)
329    }
330}
331
/// Performer feed-forward network (same as BERT)
pub struct PerformerFeedForward {
    // hidden_size -> intermediate_size projection.
    dense1: Linear,
    // intermediate_size -> hidden_size projection.
    dense2: Linear,
    // Activation name: "gelu", "relu", "silu"/"swish"; anything else is identity.
    activation: String,
    // Stored from config but not applied anywhere in this file.
    #[allow(dead_code)]
    dropout: f32,
    device: Device,
}
341
342impl PerformerFeedForward {
343    pub fn new(config: &PerformerConfig) -> Result<Self> {
344        Self::new_with_device(config, Device::CPU)
345    }
346
347    pub fn new_with_device(config: &PerformerConfig, device: Device) -> Result<Self> {
348        let dense1 =
349            Linear::new_with_device(config.hidden_size, config.intermediate_size, true, device);
350        let dense2 =
351            Linear::new_with_device(config.intermediate_size, config.hidden_size, true, device);
352
353        Ok(Self {
354            dense1,
355            dense2,
356            activation: config.hidden_act.clone(),
357            dropout: config.hidden_dropout_prob,
358            device,
359        })
360    }
361
362    pub fn device(&self) -> Device {
363        self.device
364    }
365
366    pub fn parameter_count(&self) -> usize {
367        self.dense1.parameter_count() + self.dense2.parameter_count()
368    }
369
370    fn apply_activation(&self, x: &Tensor) -> Result<Tensor> {
371        match self.activation.as_str() {
372            "gelu" => x.gelu(),
373            "relu" => x.relu(),
374            "silu" | "swish" => x.silu(),
375            _ => Ok(x.clone()),
376        }
377    }
378}
379
380impl Layer for PerformerFeedForward {
381    type Input = Tensor;
382    type Output = Tensor;
383
384    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
385        let hidden = self.dense1.forward(input);
386        let hidden = hidden?;
387        let hidden = self.apply_activation(&hidden)?;
388        self.dense2.forward(hidden)
389    }
390}
391
/// Performer encoder layer
///
/// One transformer block: FAVOR+ attention and a feed-forward network,
/// each followed by a residual connection and a post-layer-norm.
pub struct PerformerLayer {
    attention: FavorPlusAttention,
    feed_forward: PerformerFeedForward,
    // Layer norm applied after the attention residual.
    attention_norm: LayerNorm,
    // Layer norm applied after the feed-forward residual.
    output_norm: LayerNorm,
    device: Device,
}
400
401impl PerformerLayer {
402    pub fn new(config: &PerformerConfig) -> Result<Self> {
403        Self::new_with_device(config, Device::CPU)
404    }
405
406    pub fn new_with_device(config: &PerformerConfig, device: Device) -> Result<Self> {
407        let attention = FavorPlusAttention::new_with_device(config, device)?;
408        let feed_forward = PerformerFeedForward::new_with_device(config, device)?;
409        let attention_norm =
410            LayerNorm::new_with_device(vec![config.hidden_size], config.layer_norm_eps, device)?;
411        let output_norm =
412            LayerNorm::new_with_device(vec![config.hidden_size], config.layer_norm_eps, device)?;
413
414        Ok(Self {
415            attention,
416            feed_forward,
417            attention_norm,
418            output_norm,
419            device,
420        })
421    }
422
423    pub fn device(&self) -> Device {
424        self.device
425    }
426
427    pub fn parameter_count(&self) -> usize {
428        self.attention.parameter_count()
429            + self.feed_forward.parameter_count()
430            + self.attention_norm.parameter_count()
431            + self.output_norm.parameter_count()
432    }
433}
434
435impl Layer for PerformerLayer {
436    type Input = Tensor;
437    type Output = Tensor;
438
439    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
440        // Multi-head attention with residual connection and layer norm
441        let attention_output = self.attention.forward(input.clone())?;
442        let attention_output = input.add(&attention_output)?;
443        let attention_output = self.attention_norm.forward(attention_output)?;
444
445        // Feed-forward with residual connection and layer norm
446        let ff_output = self.feed_forward.forward(attention_output.clone())?;
447        let output = attention_output.add(&ff_output)?;
448        self.output_norm.forward(output)
449    }
450}
451
/// Performer embeddings (same as BERT)
///
/// Sums word, absolute-position and token-type embeddings, then applies
/// layer norm.
pub struct PerformerEmbeddings {
    word_embeddings: Embedding,
    position_embeddings: Embedding,
    token_type_embeddings: Embedding,
    layer_norm: LayerNorm,
    // Stored from config but not applied anywhere in this file.
    #[allow(dead_code)]
    dropout: f32,
    device: Device,
}
462
463impl PerformerEmbeddings {
464    pub fn new(config: &PerformerConfig) -> Result<Self> {
465        Self::new_with_device(config, Device::CPU)
466    }
467
468    pub fn new_with_device(config: &PerformerConfig, device: Device) -> Result<Self> {
469        let word_embeddings = Embedding::new_with_device(
470            config.vocab_size,
471            config.hidden_size,
472            Some(config.pad_token_id as usize),
473            device,
474        )?;
475        let position_embeddings = Embedding::new_with_device(
476            config.max_position_embeddings,
477            config.hidden_size,
478            None,
479            device,
480        )?;
481        let token_type_embeddings =
482            Embedding::new_with_device(config.type_vocab_size, config.hidden_size, None, device)?;
483        let layer_norm =
484            LayerNorm::new_with_device(vec![config.hidden_size], config.layer_norm_eps, device)?;
485
486        Ok(Self {
487            word_embeddings,
488            position_embeddings,
489            token_type_embeddings,
490            layer_norm,
491            dropout: config.hidden_dropout_prob,
492            device,
493        })
494    }
495
496    pub fn device(&self) -> Device {
497        self.device
498    }
499
500    pub fn parameter_count(&self) -> usize {
501        self.word_embeddings.parameter_count()
502            + self.position_embeddings.parameter_count()
503            + self.token_type_embeddings.parameter_count()
504            + self.layer_norm.parameter_count()
505    }
506}
507
508impl Layer for PerformerEmbeddings {
509    type Input = (Vec<u32>, Option<Vec<u32>>, Option<Vec<u32>>);
510    type Output = Tensor;
511
512    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
513        let (input_ids, token_type_ids, position_ids) = input;
514        let seq_len = input_ids.len();
515
516        let words_embeddings = self.word_embeddings.forward(input_ids)?;
517
518        let position_ids = position_ids.unwrap_or_else(|| (0..seq_len as u32).collect());
519        let position_embeddings = self.position_embeddings.forward(position_ids)?;
520
521        let token_type_ids = token_type_ids.unwrap_or_else(|| vec![0; seq_len]);
522        let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
523
524        let embeddings = words_embeddings.add(&position_embeddings)?.add(&token_type_embeddings)?;
525        let embeddings = self.layer_norm.forward(embeddings)?;
526
527        Ok(embeddings)
528    }
529}
530
/// Performer encoder
///
/// A stack of `num_hidden_layers` identical `PerformerLayer`s applied in order.
pub struct PerformerEncoder {
    layers: Vec<PerformerLayer>,
    device: Device,
}
536
537impl PerformerEncoder {
538    pub fn new(config: &PerformerConfig) -> Result<Self> {
539        Self::new_with_device(config, Device::CPU)
540    }
541
542    pub fn new_with_device(config: &PerformerConfig, device: Device) -> Result<Self> {
543        let mut layers = Vec::new();
544        for _ in 0..config.num_hidden_layers {
545            layers.push(PerformerLayer::new_with_device(config, device)?);
546        }
547
548        Ok(Self { layers, device })
549    }
550
551    pub fn device(&self) -> Device {
552        self.device
553    }
554
555    pub fn parameter_count(&self) -> usize {
556        self.layers.iter().map(|layer| layer.parameter_count()).sum()
557    }
558}
559
560impl Layer for PerformerEncoder {
561    type Input = Tensor;
562    type Output = Tensor;
563
564    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
565        let mut hidden_states = input;
566
567        for layer in &self.layers {
568            hidden_states = layer.forward(hidden_states)?;
569        }
570
571        Ok(hidden_states)
572    }
573}
574
/// Performer model
///
/// Backbone combining the embedding layer with the encoder stack; the task
/// heads in this file wrap this type.
pub struct PerformerModel {
    config: PerformerConfig,
    embeddings: PerformerEmbeddings,
    encoder: PerformerEncoder,
    device: Device,
}
582
583impl PerformerModel {
584    pub fn new(config: PerformerConfig) -> Result<Self> {
585        Self::new_with_device(config, Device::CPU)
586    }
587
588    pub fn new_with_device(config: PerformerConfig, device: Device) -> Result<Self> {
589        config.validate()?;
590
591        let embeddings = PerformerEmbeddings::new_with_device(&config, device)?;
592        let encoder = PerformerEncoder::new_with_device(&config, device)?;
593
594        Ok(Self {
595            config,
596            embeddings,
597            encoder,
598            device,
599        })
600    }
601
602    pub fn device(&self) -> Device {
603        self.device
604    }
605}
606
607impl Model for PerformerModel {
608    type Config = PerformerConfig;
609    type Input = (Vec<u32>, Option<Vec<u32>>, Option<Vec<u32>>);
610    type Output = Tensor;
611
612    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
613        let embeddings = self.embeddings.forward(input)?;
614        let sequence_output = self.encoder.forward(embeddings)?;
615        Ok(sequence_output)
616    }
617
618    fn load_pretrained(&mut self, _reader: &mut dyn Read) -> Result<()> {
619        Ok(())
620    }
621
622    fn get_config(&self) -> &Self::Config {
623        &self.config
624    }
625
626    fn num_parameters(&self) -> usize {
627        self.embeddings.parameter_count() + self.encoder.parameter_count()
628    }
629}
630
/// Performer for sequence classification
///
/// Backbone plus a linear head over the first ([CLS]) token.
pub struct PerformerForSequenceClassification {
    performer: PerformerModel,
    classifier: Linear,
    // Kept for introspection; the classifier output size already encodes it.
    #[allow(dead_code)]
    num_labels: usize,
    device: Device,
}
639
640impl PerformerForSequenceClassification {
641    pub fn new(config: PerformerConfig, num_labels: usize) -> Result<Self> {
642        Self::new_with_device(config, num_labels, Device::CPU)
643    }
644
645    pub fn new_with_device(
646        config: PerformerConfig,
647        num_labels: usize,
648        device: Device,
649    ) -> Result<Self> {
650        let performer = PerformerModel::new_with_device(config.clone(), device)?;
651        let classifier = Linear::new_with_device(config.hidden_size, num_labels, true, device);
652
653        Ok(Self {
654            performer,
655            classifier,
656            num_labels,
657            device,
658        })
659    }
660
661    pub fn device(&self) -> Device {
662        self.device
663    }
664}
665
666impl Model for PerformerForSequenceClassification {
667    type Config = PerformerConfig;
668    type Input = (Vec<u32>, Option<Vec<u32>>, Option<Vec<u32>>);
669    type Output = Tensor;
670
671    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
672        let sequence_output = self.performer.forward(input)?;
673        let cls_output = sequence_output.slice(1, 0, 1)?; // Get first token (CLS) from sequence
674        let cls_output = cls_output.squeeze(1)?; // Remove singleton sequence dimension
675        self.classifier.forward(cls_output)
676    }
677
678    fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
679        self.performer.load_pretrained(reader)
680    }
681
682    fn get_config(&self) -> &Self::Config {
683        self.performer.get_config()
684    }
685
686    fn num_parameters(&self) -> usize {
687        self.performer.num_parameters() + self.classifier.parameter_count()
688    }
689}
690
/// Performer for masked language modeling
///
/// Backbone plus a linear head projecting every position onto the vocabulary.
pub struct PerformerForMaskedLM {
    performer: PerformerModel,
    mlm_head: Linear,
    device: Device,
}
697
698impl PerformerForMaskedLM {
699    pub fn new(config: PerformerConfig) -> Result<Self> {
700        Self::new_with_device(config, Device::CPU)
701    }
702
703    pub fn new_with_device(config: PerformerConfig, device: Device) -> Result<Self> {
704        let performer = PerformerModel::new_with_device(config.clone(), device)?;
705        let mlm_head = Linear::new_with_device(config.hidden_size, config.vocab_size, true, device);
706
707        Ok(Self {
708            performer,
709            mlm_head,
710            device,
711        })
712    }
713
714    pub fn device(&self) -> Device {
715        self.device
716    }
717}
718
719impl Model for PerformerForMaskedLM {
720    type Config = PerformerConfig;
721    type Input = (Vec<u32>, Option<Vec<u32>>, Option<Vec<u32>>);
722    type Output = Tensor;
723
724    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
725        let sequence_output = self.performer.forward(input)?;
726        self.mlm_head.forward(sequence_output)
727    }
728
729    fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
730        self.performer.load_pretrained(reader)
731    }
732
733    fn get_config(&self) -> &Self::Config {
734        self.performer.get_config()
735    }
736
737    fn num_parameters(&self) -> usize {
738        self.performer.num_parameters() + self.mlm_head.parameter_count()
739    }
740}