1use crate::performer::config::PerformerConfig;
2use std::io::Read;
3use trustformers_core::{
4 device::Device,
5 errors::Result,
6 layers::{Embedding, LayerNorm, Linear},
7 tensor::Tensor,
8 traits::{Config, Layer, Model},
9};
10
/// Multi-head self-attention approximated with FAVOR+ random features —
/// the linear-time attention mechanism of the Performer architecture.
pub struct FavorPlusAttention {
    // Input projections for queries/keys/values and the output projection.
    query: Linear,
    key: Linear,
    value: Linear,
    output: Linear,

    // Attention geometry.
    num_attention_heads: usize,
    attention_head_size: usize,
    // Number of random features used to approximate the attention kernel.
    num_random_features: usize,
    // Kernel name: "relu", "exp", or "softmax+"; any other value falls back to relu.
    kernel_type: String,
    // When true, attention is computed causally via running prefix sums.
    causal: bool,
    // When true, the sampled random projection matrix is norm-scaled.
    normalize_features: bool,
    // Small constant added to attention denominators to avoid division by zero.
    numerical_stabilizer: f32,

    // Cached random projection matrix; `None` means sample during forward.
    random_features: Option<Tensor>,

    device: Device,
}
32
impl FavorPlusAttention {
    /// Convenience constructor targeting the CPU device.
    pub fn new(config: &PerformerConfig) -> Result<Self> {
        Self::new_with_device(config, Device::CPU)
    }

    /// Builds the attention module on `device`.
    ///
    /// Creates the Q/K/V projections (`hidden_size -> heads * head_dim`)
    /// and the output projection (`heads * head_dim -> hidden_size`), all
    /// with bias. The random feature matrix is not sampled here; it starts
    /// as `None` and is generated lazily during the forward pass.
    pub fn new_with_device(config: &PerformerConfig, device: Device) -> Result<Self> {
        let attention_head_size = config.head_dim();
        let all_head_size = config.num_attention_heads * attention_head_size;

        let query = Linear::new_with_device(config.hidden_size, all_head_size, true, device);
        let key = Linear::new_with_device(config.hidden_size, all_head_size, true, device);
        let value = Linear::new_with_device(config.hidden_size, all_head_size, true, device);
        let output = Linear::new_with_device(all_head_size, config.hidden_size, true, device);

        Ok(Self {
            query,
            key,
            value,
            output,
            num_attention_heads: config.num_attention_heads,
            attention_head_size,
            num_random_features: config.num_random_features,
            kernel_type: config.kernel_type.clone(),
            causal: config.causal_attention,
            normalize_features: config.normalize_features,
            numerical_stabilizer: config.numerical_stabilizer,
            random_features: None,
            device,
        })
    }

    /// Device this module's parameters live on.
    pub fn device(&self) -> Device {
        self.device
    }

    /// Total learnable parameters across the four linear projections.
    pub fn parameter_count(&self) -> usize {
        self.query.parameter_count()
            + self.key.parameter_count()
            + self.value.parameter_count()
            + self.output.parameter_count()
    }

    /// Samples the random projection matrix used by the feature map.
    ///
    /// Returns a `[attention_head_size, num_random_features]` tensor drawn
    /// from a standard normal. `_device` is accepted but unused; the tensor
    /// is created wherever `Tensor::randn` places it by default.
    fn generate_random_features(&self, _device: &str) -> Result<Tensor> {
        let random_matrix = Tensor::randn(&[self.attention_head_size, self.num_random_features])?;

        if self.normalize_features {
            // L2 norm: sqrt(sum of squared entries).
            // NOTE(review): `sum(None, false)` reduces over every axis, so
            // this divides the whole matrix by one scalar norm rather than
            // normalizing each feature column separately — confirm this
            // matches the intended FAVOR+ normalization scheme.
            let squared = random_matrix.mul(&random_matrix)?;
            let sum_squared = squared.sum(None, false)?;
            let norm = sum_squared.sqrt()?;

            // Guard against dividing by a (near-)zero norm.
            let eps = Tensor::scalar(1e-8)?;
            let stable_norm = norm.add(&eps)?;

            random_matrix.div(&stable_norm)
        } else {
            Ok(random_matrix)
        }
    }

    /// Maps `x` into random-feature space: project through
    /// `random_features` (`[head_dim, num_random_features]`), then apply
    /// the nonlinearity selected by `kernel_type`. `x` is indexed as
    /// `[batch, heads, seq, head_dim]` elsewhere in this impl.
    fn apply_feature_map(&self, x: &Tensor, random_features: &Tensor) -> Result<Tensor> {
        // Kept only as shape documentation; not used in the computation.
        let _batch_size = x.shape()[0];
        let _num_heads = x.shape()[1];
        let _seq_len = x.shape()[2];

        let projections = x.matmul(random_features)?;

        match self.kernel_type.as_str() {
            "relu" => {
                // ReLU kernel: max(0, xW) * sqrt(2 / m).
                let scale = (2.0 / self.num_random_features as f32).sqrt();
                let features = projections.relu()?.mul_scalar(scale)?;
                Ok(features)
            },
            "exp" => {
                // Exponential kernel: exp(xW - |x|^2 / 2) / sqrt(m).
                // The |x|^2/2 shift keeps the exponent bounded.
                let x_norm_sq = x.pow(2.0)?.sum(Some(vec![x.shape().len() - 1]), true)?;
                let scaled_proj = projections.sub(&x_norm_sq.mul_scalar(0.5)?)?;
                let features = scaled_proj
                    .exp()?
                    .mul_scalar(1.0 / (self.num_random_features as f32).sqrt())?;
                Ok(features)
            },
            "softmax+" => {
                // Positive softmax kernel: like "exp" but scaled by
                // sqrt(head_dim / m) instead of 1 / sqrt(m).
                let x_norm_sq = x.pow(2.0)?.sum(Some(vec![x.shape().len() - 1]), true)?;
                let h = self.attention_head_size as f32;

                let scaled_proj = projections.sub(&x_norm_sq.mul_scalar(0.5)?)?;
                let features =
                    scaled_proj.exp()?.mul_scalar((h / self.num_random_features as f32).sqrt())?;
                Ok(features)
            },
            _ => {
                // Unknown kernel names fall back to the ReLU kernel.
                let scale = (2.0 / self.num_random_features as f32).sqrt();
                let features = projections.relu()?.mul_scalar(scale)?;
                Ok(features)
            },
        }
    }

    /// Dispatches to the causal or non-causal linear-attention core.
    fn favor_attention(
        &self,
        query_features: &Tensor,
        key_features: &Tensor,
        values: &Tensor,
    ) -> Result<Tensor> {
        if self.causal {
            self.causal_favor_attention(query_features, key_features, values)
        } else {
            self.non_causal_favor_attention(query_features, key_features, values)
        }
    }

    /// Non-causal FAVOR+ attention, linear in sequence length:
    ///   out = phi(Q) (phi(K)^T V) / (phi(Q) . sum_s phi(K)_s + eps)
    fn non_causal_favor_attention(
        &self,
        query_features: &Tensor,
        key_features: &Tensor,
        values: &Tensor,
    ) -> Result<Tensor> {
        // d: key features summed over the sequence axis (dim 2).
        let d = key_features.sum(Some(vec![2]), false)?;

        // kv = phi(K)^T V, contracting over the sequence axis.
        let key_features_t = key_features.transpose(
            key_features.shape().len() - 2,
            key_features.shape().len() - 1,
        )?;

        let kv = key_features_t.matmul(values)?;

        let numerator = query_features.matmul(&kv)?;

        // Denominator: phi(Q) . d, stabilized against division by zero.
        let denominator = query_features.matmul(&d.unsqueeze(d.shape().len())?)?;
        let denominator = denominator.add_scalar(self.numerical_stabilizer)?;

        numerator.div(&denominator)
    }

    /// Causal FAVOR+ attention: iterates over positions maintaining prefix
    /// sums of phi(K)^T V (`running_kv`) and phi(K) (`running_k`), so each
    /// position only sees earlier context.
    ///
    /// NOTE(review): position `i`'s output is computed from the running
    /// sums *before* key/value `i` is folded in, so each token attends to
    /// strictly earlier positions (position 0 sees an empty prefix and
    /// yields ~0) — confirm this exclusive-prefix behavior is intended.
    fn causal_favor_attention(
        &self,
        query_features: &Tensor,
        key_features: &Tensor,
        values: &Tensor,
    ) -> Result<Tensor> {
        let batch_size = query_features.shape()[0];
        let num_heads = query_features.shape()[1];
        let seq_len = query_features.shape()[2];
        let head_dim = values.shape()[3];

        // This initial value is replaced at i == 0; it only exists so
        // `output` is initialized on the zero-length-sequence path.
        let mut output = Tensor::zeros(&[batch_size, num_heads, seq_len, head_dim])?;

        // Prefix-sum state: running phi(K)^T V and running sum of phi(K).
        let mut running_kv =
            Tensor::zeros(&[batch_size, num_heads, self.num_random_features, head_dim])?;
        let mut running_k = Tensor::zeros(&[batch_size, num_heads, self.num_random_features])?;

        for i in 0..seq_len {
            // Slice out position i (kept as a length-1 sequence dimension).
            let q_i = query_features.slice_multi(&[
                (0, batch_size),
                (0, num_heads),
                (i, i + 1),
                (0, self.num_random_features),
            ])?;
            let k_i = key_features.slice_multi(&[
                (0, batch_size),
                (0, num_heads),
                (i, i + 1),
                (0, self.num_random_features),
            ])?;
            let v_i = values.slice_multi(&[
                (0, batch_size),
                (0, num_heads),
                (i, i + 1),
                (0, head_dim),
            ])?;

            // Attention output for position i from the current prefix sums.
            let numerator = q_i.matmul(&running_kv)?;
            let denominator = q_i.matmul(&running_k.unsqueeze(running_k.shape().len())?)?;
            let denominator = denominator.add_scalar(self.numerical_stabilizer)?;

            let att_output = numerator.div(&denominator)?;

            // Accumulate per-position outputs along the sequence axis.
            if i == 0 {
                output = att_output.clone();
            } else {
                output = Tensor::concat(&[output, att_output], 2)?;
            }

            // Fold position i's key/value into the prefix sums.
            let shape = k_i.shape();
            let dim0 = shape.len().saturating_sub(2);
            let dim1 = shape.len().saturating_sub(1);
            let k_i_t = k_i.transpose(dim0, dim1)?;
            let kv_update = k_i_t.matmul(&v_i)?;
            running_kv = running_kv.add(&kv_update)?;
            let shape = k_i.shape();
            let squeeze_dim = shape.len().saturating_sub(2);
            running_k = running_k.add(&k_i.squeeze(squeeze_dim)?)?;
        }

        Ok(output)
    }

    /// Reshapes `[batch, seq, heads * head_dim]` into
    /// `[batch, heads, seq, head_dim]` for per-head attention.
    fn transpose_for_scores(&self, x: &Tensor) -> Result<Tensor> {
        let batch_size = x.shape()[0];
        let seq_len = x.shape()[1];

        let reshaped = x.reshape(&[
            batch_size,
            seq_len,
            self.num_attention_heads,
            self.attention_head_size,
        ])?;

        reshaped.permute(&[0, 2, 1, 3])
    }
}
284
285impl Layer for FavorPlusAttention {
286 type Input = Tensor;
287 type Output = Tensor;
288
289 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
290 let batch_size = input.shape()[0];
291 let seq_len = input.shape()[1];
292
293 let query_layer = self.query.forward(input.clone())?;
295 let key_layer = self.key.forward(input.clone())?;
296 let value_layer = self.value.forward(input)?;
297
298 let query_layer = self.transpose_for_scores(&query_layer)?;
300 let key_layer = self.transpose_for_scores(&key_layer)?;
301 let value_layer = self.transpose_for_scores(&value_layer)?;
302
303 let random_features = if let Some(ref features) = self.random_features {
305 features.clone()
306 } else {
307 self.generate_random_features("cpu")?
308 };
309
310 let query_features = self.apply_feature_map(&query_layer, &random_features)?;
312 let key_features = self.apply_feature_map(&key_layer, &random_features)?;
313
314 let context_layer = self.favor_attention(&query_features, &key_features, &value_layer)?;
316
317 let context_layer = context_layer.permute(&[0, 2, 1, 3])?;
319
320 let context_layer = context_layer.reshape(&[
322 batch_size,
323 seq_len,
324 self.num_attention_heads * self.attention_head_size,
325 ])?;
326
327 self.output.forward(context_layer)
329 }
330}
331
/// Position-wise feed-forward network (two linear layers with an
/// activation between them) used inside each Performer layer.
pub struct PerformerFeedForward {
    dense1: Linear,
    dense2: Linear,
    // Activation name: "gelu", "relu", "silu"/"swish"; unknown names act as identity.
    activation: String,
    // Stored from config but not applied anywhere in this file.
    #[allow(dead_code)]
    dropout: f32,
    device: Device,
}
341
342impl PerformerFeedForward {
343 pub fn new(config: &PerformerConfig) -> Result<Self> {
344 Self::new_with_device(config, Device::CPU)
345 }
346
347 pub fn new_with_device(config: &PerformerConfig, device: Device) -> Result<Self> {
348 let dense1 =
349 Linear::new_with_device(config.hidden_size, config.intermediate_size, true, device);
350 let dense2 =
351 Linear::new_with_device(config.intermediate_size, config.hidden_size, true, device);
352
353 Ok(Self {
354 dense1,
355 dense2,
356 activation: config.hidden_act.clone(),
357 dropout: config.hidden_dropout_prob,
358 device,
359 })
360 }
361
362 pub fn device(&self) -> Device {
363 self.device
364 }
365
366 pub fn parameter_count(&self) -> usize {
367 self.dense1.parameter_count() + self.dense2.parameter_count()
368 }
369
370 fn apply_activation(&self, x: &Tensor) -> Result<Tensor> {
371 match self.activation.as_str() {
372 "gelu" => x.gelu(),
373 "relu" => x.relu(),
374 "silu" | "swish" => x.silu(),
375 _ => Ok(x.clone()),
376 }
377 }
378}
379
380impl Layer for PerformerFeedForward {
381 type Input = Tensor;
382 type Output = Tensor;
383
384 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
385 let hidden = self.dense1.forward(input);
386 let hidden = hidden?;
387 let hidden = self.apply_activation(&hidden)?;
388 self.dense2.forward(hidden)
389 }
390}
391
/// One Performer transformer block: FAVOR+ attention followed by a
/// feed-forward network, each wrapped in a residual connection with a
/// post-LayerNorm.
pub struct PerformerLayer {
    attention: FavorPlusAttention,
    feed_forward: PerformerFeedForward,
    // Norm applied after the attention residual.
    attention_norm: LayerNorm,
    // Norm applied after the feed-forward residual.
    output_norm: LayerNorm,
    device: Device,
}
400
401impl PerformerLayer {
402 pub fn new(config: &PerformerConfig) -> Result<Self> {
403 Self::new_with_device(config, Device::CPU)
404 }
405
406 pub fn new_with_device(config: &PerformerConfig, device: Device) -> Result<Self> {
407 let attention = FavorPlusAttention::new_with_device(config, device)?;
408 let feed_forward = PerformerFeedForward::new_with_device(config, device)?;
409 let attention_norm =
410 LayerNorm::new_with_device(vec![config.hidden_size], config.layer_norm_eps, device)?;
411 let output_norm =
412 LayerNorm::new_with_device(vec![config.hidden_size], config.layer_norm_eps, device)?;
413
414 Ok(Self {
415 attention,
416 feed_forward,
417 attention_norm,
418 output_norm,
419 device,
420 })
421 }
422
423 pub fn device(&self) -> Device {
424 self.device
425 }
426
427 pub fn parameter_count(&self) -> usize {
428 self.attention.parameter_count()
429 + self.feed_forward.parameter_count()
430 + self.attention_norm.parameter_count()
431 + self.output_norm.parameter_count()
432 }
433}
434
435impl Layer for PerformerLayer {
436 type Input = Tensor;
437 type Output = Tensor;
438
439 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
440 let attention_output = self.attention.forward(input.clone())?;
442 let attention_output = input.add(&attention_output)?;
443 let attention_output = self.attention_norm.forward(attention_output)?;
444
445 let ff_output = self.feed_forward.forward(attention_output.clone())?;
447 let output = attention_output.add(&ff_output)?;
448 self.output_norm.forward(output)
449 }
450}
451
/// Input embedding stage: word + position + token-type embeddings summed
/// and layer-normalized (BERT-style).
pub struct PerformerEmbeddings {
    word_embeddings: Embedding,
    position_embeddings: Embedding,
    token_type_embeddings: Embedding,
    layer_norm: LayerNorm,
    // Stored from config but not applied anywhere in this file.
    #[allow(dead_code)]
    dropout: f32,
    device: Device,
}
462
463impl PerformerEmbeddings {
464 pub fn new(config: &PerformerConfig) -> Result<Self> {
465 Self::new_with_device(config, Device::CPU)
466 }
467
468 pub fn new_with_device(config: &PerformerConfig, device: Device) -> Result<Self> {
469 let word_embeddings = Embedding::new_with_device(
470 config.vocab_size,
471 config.hidden_size,
472 Some(config.pad_token_id as usize),
473 device,
474 )?;
475 let position_embeddings = Embedding::new_with_device(
476 config.max_position_embeddings,
477 config.hidden_size,
478 None,
479 device,
480 )?;
481 let token_type_embeddings =
482 Embedding::new_with_device(config.type_vocab_size, config.hidden_size, None, device)?;
483 let layer_norm =
484 LayerNorm::new_with_device(vec![config.hidden_size], config.layer_norm_eps, device)?;
485
486 Ok(Self {
487 word_embeddings,
488 position_embeddings,
489 token_type_embeddings,
490 layer_norm,
491 dropout: config.hidden_dropout_prob,
492 device,
493 })
494 }
495
496 pub fn device(&self) -> Device {
497 self.device
498 }
499
500 pub fn parameter_count(&self) -> usize {
501 self.word_embeddings.parameter_count()
502 + self.position_embeddings.parameter_count()
503 + self.token_type_embeddings.parameter_count()
504 + self.layer_norm.parameter_count()
505 }
506}
507
508impl Layer for PerformerEmbeddings {
509 type Input = (Vec<u32>, Option<Vec<u32>>, Option<Vec<u32>>);
510 type Output = Tensor;
511
512 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
513 let (input_ids, token_type_ids, position_ids) = input;
514 let seq_len = input_ids.len();
515
516 let words_embeddings = self.word_embeddings.forward(input_ids)?;
517
518 let position_ids = position_ids.unwrap_or_else(|| (0..seq_len as u32).collect());
519 let position_embeddings = self.position_embeddings.forward(position_ids)?;
520
521 let token_type_ids = token_type_ids.unwrap_or_else(|| vec![0; seq_len]);
522 let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
523
524 let embeddings = words_embeddings.add(&position_embeddings)?.add(&token_type_embeddings)?;
525 let embeddings = self.layer_norm.forward(embeddings)?;
526
527 Ok(embeddings)
528 }
529}
530
/// Stack of identical Performer layers applied in sequence.
pub struct PerformerEncoder {
    layers: Vec<PerformerLayer>,
    device: Device,
}
536
537impl PerformerEncoder {
538 pub fn new(config: &PerformerConfig) -> Result<Self> {
539 Self::new_with_device(config, Device::CPU)
540 }
541
542 pub fn new_with_device(config: &PerformerConfig, device: Device) -> Result<Self> {
543 let mut layers = Vec::new();
544 for _ in 0..config.num_hidden_layers {
545 layers.push(PerformerLayer::new_with_device(config, device)?);
546 }
547
548 Ok(Self { layers, device })
549 }
550
551 pub fn device(&self) -> Device {
552 self.device
553 }
554
555 pub fn parameter_count(&self) -> usize {
556 self.layers.iter().map(|layer| layer.parameter_count()).sum()
557 }
558}
559
560impl Layer for PerformerEncoder {
561 type Input = Tensor;
562 type Output = Tensor;
563
564 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
565 let mut hidden_states = input;
566
567 for layer in &self.layers {
568 hidden_states = layer.forward(hidden_states)?;
569 }
570
571 Ok(hidden_states)
572 }
573}
574
/// Full Performer backbone: embedding stage followed by the encoder stack.
pub struct PerformerModel {
    config: PerformerConfig,
    embeddings: PerformerEmbeddings,
    encoder: PerformerEncoder,
    device: Device,
}
582
583impl PerformerModel {
584 pub fn new(config: PerformerConfig) -> Result<Self> {
585 Self::new_with_device(config, Device::CPU)
586 }
587
588 pub fn new_with_device(config: PerformerConfig, device: Device) -> Result<Self> {
589 config.validate()?;
590
591 let embeddings = PerformerEmbeddings::new_with_device(&config, device)?;
592 let encoder = PerformerEncoder::new_with_device(&config, device)?;
593
594 Ok(Self {
595 config,
596 embeddings,
597 encoder,
598 device,
599 })
600 }
601
602 pub fn device(&self) -> Device {
603 self.device
604 }
605}
606
607impl Model for PerformerModel {
608 type Config = PerformerConfig;
609 type Input = (Vec<u32>, Option<Vec<u32>>, Option<Vec<u32>>);
610 type Output = Tensor;
611
612 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
613 let embeddings = self.embeddings.forward(input)?;
614 let sequence_output = self.encoder.forward(embeddings)?;
615 Ok(sequence_output)
616 }
617
618 fn load_pretrained(&mut self, _reader: &mut dyn Read) -> Result<()> {
619 Ok(())
620 }
621
622 fn get_config(&self) -> &Self::Config {
623 &self.config
624 }
625
626 fn num_parameters(&self) -> usize {
627 self.embeddings.parameter_count() + self.encoder.parameter_count()
628 }
629}
630
/// Performer backbone with a linear classification head applied to the
/// first ([CLS]) sequence position.
pub struct PerformerForSequenceClassification {
    performer: PerformerModel,
    classifier: Linear,
    // Output dimension of the classifier head; kept for introspection.
    #[allow(dead_code)]
    num_labels: usize,
    device: Device,
}
639
640impl PerformerForSequenceClassification {
641 pub fn new(config: PerformerConfig, num_labels: usize) -> Result<Self> {
642 Self::new_with_device(config, num_labels, Device::CPU)
643 }
644
645 pub fn new_with_device(
646 config: PerformerConfig,
647 num_labels: usize,
648 device: Device,
649 ) -> Result<Self> {
650 let performer = PerformerModel::new_with_device(config.clone(), device)?;
651 let classifier = Linear::new_with_device(config.hidden_size, num_labels, true, device);
652
653 Ok(Self {
654 performer,
655 classifier,
656 num_labels,
657 device,
658 })
659 }
660
661 pub fn device(&self) -> Device {
662 self.device
663 }
664}
665
666impl Model for PerformerForSequenceClassification {
667 type Config = PerformerConfig;
668 type Input = (Vec<u32>, Option<Vec<u32>>, Option<Vec<u32>>);
669 type Output = Tensor;
670
671 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
672 let sequence_output = self.performer.forward(input)?;
673 let cls_output = sequence_output.slice(1, 0, 1)?; let cls_output = cls_output.squeeze(1)?; self.classifier.forward(cls_output)
676 }
677
678 fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
679 self.performer.load_pretrained(reader)
680 }
681
682 fn get_config(&self) -> &Self::Config {
683 self.performer.get_config()
684 }
685
686 fn num_parameters(&self) -> usize {
687 self.performer.num_parameters() + self.classifier.parameter_count()
688 }
689}
690
/// Performer backbone with a linear masked-language-modeling head that
/// projects hidden states to vocabulary logits.
pub struct PerformerForMaskedLM {
    performer: PerformerModel,
    mlm_head: Linear,
    device: Device,
}
697
698impl PerformerForMaskedLM {
699 pub fn new(config: PerformerConfig) -> Result<Self> {
700 Self::new_with_device(config, Device::CPU)
701 }
702
703 pub fn new_with_device(config: PerformerConfig, device: Device) -> Result<Self> {
704 let performer = PerformerModel::new_with_device(config.clone(), device)?;
705 let mlm_head = Linear::new_with_device(config.hidden_size, config.vocab_size, true, device);
706
707 Ok(Self {
708 performer,
709 mlm_head,
710 device,
711 })
712 }
713
714 pub fn device(&self) -> Device {
715 self.device
716 }
717}
718
719impl Model for PerformerForMaskedLM {
720 type Config = PerformerConfig;
721 type Input = (Vec<u32>, Option<Vec<u32>>, Option<Vec<u32>>);
722 type Output = Tensor;
723
724 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
725 let sequence_output = self.performer.forward(input)?;
726 self.mlm_head.forward(sequence_output)
727 }
728
729 fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
730 self.performer.load_pretrained(reader)
731 }
732
733 fn get_config(&self) -> &Self::Config {
734 self.performer.get_config()
735 }
736
737 fn num_parameters(&self) -> usize {
738 self.performer.num_parameters() + self.mlm_head.parameter_count()
739 }
740}