//! FNet model implementation (`trustformers_models/fnet/model.rs`).

1use crate::fnet::config::FNetConfig;
2use std::io::Read;
3use trustformers_core::{
4    device::Device,
5    errors::Result,
6    layers::{Embedding, LayerNorm, Linear},
7    tensor::Tensor,
8    traits::{Config, Layer, Model},
9};
10
/// Fourier Transform layer that replaces self-attention
/// Applies 2D DFT along sequence and feature dimensions
pub struct FourierTransform {
    /// Which transform to run: "dft", "real_dft", or "dct"; any other value
    /// falls back to the plain DFT path (see the `Layer::forward` impl).
    fourier_type: String,
    /// Mirrors `bias.is_some()`; kept for introspection only.
    #[allow(dead_code)]
    use_bias: bool,
    /// Optional learned `hidden_size x hidden_size` projection applied after the transform.
    bias: Option<Linear>,
    /// Dropout probability from config; not applied on this inference path.
    #[allow(dead_code)]
    dropout: f32,
    /// Device the layer's parameters live on.
    device: Device,
}
22
23impl FourierTransform {
24    pub fn new(config: &FNetConfig) -> Result<Self> {
25        Self::new_with_device(config, Device::CPU)
26    }
27
28    pub fn new_with_device(config: &FNetConfig, device: Device) -> Result<Self> {
29        let bias = if config.use_bias_in_fourier {
30            Some(Linear::new_with_device(
31                config.hidden_size,
32                config.hidden_size,
33                true,
34                device,
35            ))
36        } else {
37            None
38        };
39
40        Ok(Self {
41            fourier_type: config.fourier_transform_type.clone(),
42            use_bias: config.use_bias_in_fourier,
43            bias,
44            dropout: config.fourier_dropout_prob,
45            device,
46        })
47    }
48
49    pub fn device(&self) -> Device {
50        self.device
51    }
52
53    pub fn parameter_count(&self) -> usize {
54        if let Some(ref bias_layer) = self.bias {
55            bias_layer.parameter_count()
56        } else {
57            0
58        }
59    }
60
61    /// Apply Discrete Fourier Transform (DFT)
62    fn apply_dft(&self, x: &Tensor) -> Result<Tensor> {
63        // x: [batch_size, seq_len, hidden_size]
64        let _batch_size = x.shape()[0];
65        let _seq_len = x.shape()[1];
66        let _hidden_size = x.shape()[2];
67
68        // Apply DFT along sequence dimension first
69        let x_seq_dft = self.dft_1d(x, 1)?; // DFT along dimension 1 (seq_len)
70
71        // Apply DFT along hidden dimension
72        let x_both_dft = self.dft_1d(&x_seq_dft, 2)?; // DFT along dimension 2 (hidden_size)
73
74        // Take real part only (common practice in FNet)
75        self.real_part(&x_both_dft)
76    }
77
78    /// Apply Real DFT (more efficient variant)
79    fn apply_real_dft(&self, x: &Tensor) -> Result<Tensor> {
80        // Similar to DFT but optimized for real inputs
81        // For simplicity, we'll implement this as regular DFT taking real part
82        self.apply_dft(x)
83    }
84
85    /// Apply Discrete Cosine Transform (DCT)
86    fn apply_dct(&self, x: &Tensor) -> Result<Tensor> {
87        // DCT is real-valued and often more efficient than DFT
88        // For now, approximate with cosine-based transformation
89        let batch_size = x.shape()[0];
90        let seq_len = x.shape()[1];
91        let hidden_size = x.shape()[2];
92
93        // Create DCT basis matrices
94        let seq_dct_matrix = self.create_dct_matrix(seq_len)?;
95        let hidden_dct_matrix = self.create_dct_matrix(hidden_size)?;
96
97        // Apply DCT along sequence dimension
98        // x @ seq_dct_matrix^T
99        let seq_shape = seq_dct_matrix.shape();
100        let seq_dim0 = seq_shape.len().saturating_sub(2);
101        let seq_dim1 = seq_shape.len().saturating_sub(1);
102        let x_seq_dct = x.matmul(&seq_dct_matrix.transpose(seq_dim0, seq_dim1)?)?;
103
104        // Apply DCT along hidden dimension
105        // For hidden dimension: reshape, apply DCT, reshape back
106        let reshaped = x_seq_dct.reshape(&[batch_size * seq_len, hidden_size])?;
107        let hidden_shape = hidden_dct_matrix.shape();
108        let hidden_dim0 = hidden_shape.len().saturating_sub(2);
109        let hidden_dim1 = hidden_shape.len().saturating_sub(1);
110        let hidden_dct =
111            reshaped.matmul(&hidden_dct_matrix.transpose(hidden_dim0, hidden_dim1)?)?;
112        hidden_dct.reshape(&[batch_size, seq_len, hidden_size])
113    }
114
115    /// Create DCT transformation matrix
116    fn create_dct_matrix(&self, n: usize) -> Result<Tensor> {
117        let mut matrix = Vec::new();
118        let pi = std::f32::consts::PI;
119
120        for k in 0..n {
121            for i in 0..n {
122                let value = if k == 0 {
123                    (1.0 / n as f32).sqrt()
124                } else {
125                    (2.0 / n as f32).sqrt()
126                        * (pi * k as f32 * (2 * i + 1) as f32 / (2 * n) as f32).cos()
127                };
128                matrix.push(value);
129            }
130        }
131
132        Tensor::from_vec(matrix, &[n, n])
133    }
134
135    /// 1D DFT implementation (simplified)
136    fn dft_1d(&self, x: &Tensor, dim: i32) -> Result<Tensor> {
137        // This is a simplified implementation
138        // In practice, you'd use an efficient FFT library
139
140        let shape = x.shape();
141        let n = shape[dim as usize];
142
143        // For simplicity, we'll approximate DFT with a learned transformation
144        // that captures the frequency domain mixing behavior
145
146        // Create a pseudo-DFT matrix that mixes elements
147        let mut dft_matrix = Vec::new();
148        let pi = std::f32::consts::PI;
149
150        for k in 0..n {
151            for j in 0..n {
152                let angle = -2.0 * pi * (k * j) as f32 / n as f32;
153                let real_part = angle.cos() / (n as f32).sqrt();
154                dft_matrix.push(real_part);
155            }
156        }
157
158        let dft_tensor = Tensor::from_vec(dft_matrix, &[n, n])?;
159
160        // Apply transformation along the specified dimension
161        if dim == 1 {
162            // Along sequence dimension
163            let dft_shape = dft_tensor.shape();
164            let dft_dim0 = dft_shape.len().saturating_sub(2);
165            let dft_dim1 = dft_shape.len().saturating_sub(1);
166            x.matmul(&dft_tensor.transpose(dft_dim0, dft_dim1)?)
167        } else {
168            // Along hidden dimension - need to reshape
169            let batch_size = shape[0];
170            let seq_len = shape[1];
171            let hidden_size = shape[2];
172
173            let reshaped = x.reshape(&[batch_size * seq_len, hidden_size])?;
174            let dft_shape = dft_tensor.shape();
175            let dft_dim0 = dft_shape.len().saturating_sub(2);
176            let dft_dim1 = dft_shape.len().saturating_sub(1);
177            let transformed = reshaped.matmul(&dft_tensor.transpose(dft_dim0, dft_dim1)?)?;
178            transformed.reshape(&[batch_size, seq_len, hidden_size])
179        }
180    }
181
182    /// Extract real part of complex tensor
183    fn real_part(&self, x: &Tensor) -> Result<Tensor> {
184        // Since we're working with real tensors, just return as-is
185        // In a full implementation, this would handle complex numbers
186        Ok(x.clone())
187    }
188}
189
190impl Layer for FourierTransform {
191    type Input = Tensor;
192    type Output = Tensor;
193
194    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
195        // Apply the appropriate Fourier transform
196        let fourier_output = match self.fourier_type.as_str() {
197            "dft" => self.apply_dft(&input)?,
198            "real_dft" => self.apply_real_dft(&input)?,
199            "dct" => self.apply_dct(&input)?,
200            _ => self.apply_dft(&input)?, // Default to DFT
201        };
202
203        // Apply bias if configured
204        let output = if let Some(ref bias_layer) = self.bias {
205            bias_layer.forward(fourier_output)?
206        } else {
207            fourier_output
208        };
209
210        // Apply dropout if configured (in training mode)
211        // For inference, we skip dropout
212        Ok(output)
213    }
214}
215
/// FNet feed-forward network (same as BERT)
pub struct FNetFeedForward {
    /// Expansion projection: hidden_size -> intermediate_size.
    dense1: Linear,
    /// Contraction projection: intermediate_size -> hidden_size.
    dense2: Linear,
    /// Activation name from config ("gelu", "relu", "silu"/"swish"; anything else is identity).
    activation: String,
    /// Dropout probability from config; not applied on this inference path.
    #[allow(dead_code)]
    dropout: f32,
    /// Device the parameters live on.
    device: Device,
}
225
226impl FNetFeedForward {
227    pub fn new(config: &FNetConfig) -> Result<Self> {
228        Self::new_with_device(config, Device::CPU)
229    }
230
231    pub fn new_with_device(config: &FNetConfig, device: Device) -> Result<Self> {
232        let dense1 =
233            Linear::new_with_device(config.hidden_size, config.intermediate_size, true, device);
234        let dense2 =
235            Linear::new_with_device(config.intermediate_size, config.hidden_size, true, device);
236
237        Ok(Self {
238            dense1,
239            dense2,
240            activation: config.hidden_act.clone(),
241            dropout: config.hidden_dropout_prob,
242            device,
243        })
244    }
245
246    pub fn device(&self) -> Device {
247        self.device
248    }
249
250    pub fn parameter_count(&self) -> usize {
251        self.dense1.parameter_count() + self.dense2.parameter_count()
252    }
253
254    fn apply_activation(&self, x: &Tensor) -> Result<Tensor> {
255        match self.activation.as_str() {
256            "gelu" => x.gelu(),
257            "relu" => x.relu(),
258            "silu" | "swish" => x.silu(),
259            _ => Ok(x.clone()),
260        }
261    }
262}
263
264impl Layer for FNetFeedForward {
265    type Input = Tensor;
266    type Output = Tensor;
267
268    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
269        let hidden = self.dense1.forward(input)?;
270        let hidden = self.apply_activation(&hidden)?;
271        self.dense2.forward(hidden)
272    }
273}
274
/// FNet encoder layer (Fourier + FFN)
pub struct FNetLayer {
    /// Token-mixing sublayer (replaces self-attention).
    fourier_transform: FourierTransform,
    /// Position-wise feed-forward sublayer.
    feed_forward: FNetFeedForward,
    /// LayerNorm applied after the Fourier residual connection.
    fourier_norm: LayerNorm,
    /// LayerNorm applied after the feed-forward residual connection.
    output_norm: LayerNorm,
    /// Device the parameters live on.
    device: Device,
}
283
284impl FNetLayer {
285    pub fn new(config: &FNetConfig) -> Result<Self> {
286        Self::new_with_device(config, Device::CPU)
287    }
288
289    pub fn new_with_device(config: &FNetConfig, device: Device) -> Result<Self> {
290        let fourier_transform = FourierTransform::new_with_device(config, device)?;
291        let feed_forward = FNetFeedForward::new_with_device(config, device)?;
292        let fourier_norm =
293            LayerNorm::new_with_device(vec![config.hidden_size], config.layer_norm_eps, device)?;
294        let output_norm =
295            LayerNorm::new_with_device(vec![config.hidden_size], config.layer_norm_eps, device)?;
296
297        Ok(Self {
298            fourier_transform,
299            feed_forward,
300            fourier_norm,
301            output_norm,
302            device,
303        })
304    }
305
306    pub fn device(&self) -> Device {
307        self.device
308    }
309
310    pub fn parameter_count(&self) -> usize {
311        self.fourier_transform.parameter_count()
312            + self.feed_forward.parameter_count()
313            + self.fourier_norm.parameter_count()
314            + self.output_norm.parameter_count()
315    }
316}
317
318impl Layer for FNetLayer {
319    type Input = Tensor;
320    type Output = Tensor;
321
322    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
323        // Fourier transform with residual connection and layer norm
324        let fourier_output = self.fourier_transform.forward(input.clone())?;
325        let fourier_output = input.add(&fourier_output)?; // Residual
326        let fourier_output = self.fourier_norm.forward(fourier_output)?;
327
328        // Feed-forward with residual connection and layer norm
329        let ff_output = self.feed_forward.forward(fourier_output.clone())?;
330        let output = fourier_output.add(&ff_output)?; // Residual
331        self.output_norm.forward(output)
332    }
333}
334
/// FNet embeddings (same as BERT)
pub struct FNetEmbeddings {
    /// Token-id lookup table (`vocab_size` rows, `pad_token_id` used as padding index).
    word_embeddings: Embedding,
    /// Absolute position table (`max_position_embeddings` rows).
    position_embeddings: Embedding,
    /// Segment/token-type table (`type_vocab_size` rows).
    token_type_embeddings: Embedding,
    /// LayerNorm applied to the summed embeddings.
    layer_norm: LayerNorm,
    /// Dropout probability from config; not applied on this inference path.
    #[allow(dead_code)]
    dropout: f32,
    /// Device the parameters live on.
    device: Device,
}
345
346impl FNetEmbeddings {
347    pub fn new(config: &FNetConfig) -> Result<Self> {
348        Self::new_with_device(config, Device::CPU)
349    }
350
351    pub fn new_with_device(config: &FNetConfig, device: Device) -> Result<Self> {
352        let word_embeddings = Embedding::new_with_device(
353            config.vocab_size,
354            config.hidden_size,
355            Some(config.pad_token_id as usize),
356            device,
357        )?;
358        let position_embeddings = Embedding::new_with_device(
359            config.max_position_embeddings,
360            config.hidden_size,
361            None,
362            device,
363        )?;
364        let token_type_embeddings =
365            Embedding::new_with_device(config.type_vocab_size, config.hidden_size, None, device)?;
366        let layer_norm =
367            LayerNorm::new_with_device(vec![config.hidden_size], config.layer_norm_eps, device)?;
368
369        Ok(Self {
370            word_embeddings,
371            position_embeddings,
372            token_type_embeddings,
373            layer_norm,
374            dropout: config.hidden_dropout_prob,
375            device,
376        })
377    }
378
379    pub fn device(&self) -> Device {
380        self.device
381    }
382
383    pub fn parameter_count(&self) -> usize {
384        self.word_embeddings.parameter_count()
385            + self.position_embeddings.parameter_count()
386            + self.token_type_embeddings.parameter_count()
387            + self.layer_norm.parameter_count()
388    }
389}
390
391impl Layer for FNetEmbeddings {
392    type Input = (Vec<u32>, Option<Vec<u32>>, Option<Vec<u32>>);
393    type Output = Tensor;
394
395    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
396        let (input_ids, token_type_ids, position_ids) = input;
397        let seq_len = input_ids.len();
398
399        let words_embeddings = self.word_embeddings.forward(input_ids)?;
400
401        let position_ids = position_ids.unwrap_or_else(|| (0..seq_len as u32).collect());
402        let position_embeddings = self.position_embeddings.forward(position_ids)?;
403
404        let token_type_ids = token_type_ids.unwrap_or_else(|| vec![0; seq_len]);
405        let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
406
407        let embeddings = words_embeddings.add(&position_embeddings)?.add(&token_type_embeddings)?;
408        let embeddings = self.layer_norm.forward(embeddings)?;
409
410        Ok(embeddings)
411    }
412}
413
/// FNet encoder
pub struct FNetEncoder {
    /// Stack of identical FNet layers, applied in order.
    layers: Vec<FNetLayer>,
    /// Device the parameters live on.
    device: Device,
}
419
420impl FNetEncoder {
421    pub fn new(config: &FNetConfig) -> Result<Self> {
422        Self::new_with_device(config, Device::CPU)
423    }
424
425    pub fn new_with_device(config: &FNetConfig, device: Device) -> Result<Self> {
426        let mut layers = Vec::new();
427        for _ in 0..config.num_hidden_layers {
428            layers.push(FNetLayer::new_with_device(config, device)?);
429        }
430
431        Ok(Self { layers, device })
432    }
433
434    pub fn device(&self) -> Device {
435        self.device
436    }
437
438    pub fn parameter_count(&self) -> usize {
439        self.layers.iter().map(|layer| layer.parameter_count()).sum()
440    }
441}
442
443impl Layer for FNetEncoder {
444    type Input = Tensor;
445    type Output = Tensor;
446
447    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
448        let mut hidden_states = input;
449
450        for layer in &self.layers {
451            hidden_states = layer.forward(hidden_states)?;
452        }
453
454        Ok(hidden_states)
455    }
456}
457
/// FNet model
pub struct FNetModel {
    /// Validated configuration the model was built from.
    config: FNetConfig,
    /// Input embedding stage (word + position + token-type).
    embeddings: FNetEmbeddings,
    /// Stack of FNet encoder layers.
    encoder: FNetEncoder,
    /// Device the parameters live on.
    device: Device,
}
465
466impl FNetModel {
467    pub fn new(config: FNetConfig) -> Result<Self> {
468        Self::new_with_device(config, Device::CPU)
469    }
470
471    pub fn new_with_device(config: FNetConfig, device: Device) -> Result<Self> {
472        config.validate()?;
473
474        let embeddings = FNetEmbeddings::new_with_device(&config, device)?;
475        let encoder = FNetEncoder::new_with_device(&config, device)?;
476
477        Ok(Self {
478            config,
479            embeddings,
480            encoder,
481            device,
482        })
483    }
484
485    pub fn device(&self) -> Device {
486        self.device
487    }
488}
489
490impl Model for FNetModel {
491    type Config = FNetConfig;
492    type Input = (Vec<u32>, Option<Vec<u32>>, Option<Vec<u32>>);
493    type Output = Tensor;
494
495    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
496        let embeddings = self.embeddings.forward(input)?;
497        let sequence_output = self.encoder.forward(embeddings)?;
498        Ok(sequence_output)
499    }
500
501    fn load_pretrained(&mut self, _reader: &mut dyn Read) -> Result<()> {
502        Ok(())
503    }
504
505    fn get_config(&self) -> &Self::Config {
506        &self.config
507    }
508
509    fn num_parameters(&self) -> usize {
510        self.embeddings.parameter_count() + self.encoder.parameter_count()
511    }
512}
513
/// FNet for sequence classification
pub struct FNetForSequenceClassification {
    /// Shared FNet backbone producing per-token hidden states.
    fnet: FNetModel,
    /// Classification head: hidden_size -> num_labels.
    classifier: Linear,
    /// Number of output classes (also the classifier's output width).
    #[allow(dead_code)]
    num_labels: usize,
    /// Device the parameters live on.
    device: Device,
}
522
523impl FNetForSequenceClassification {
524    pub fn new(config: FNetConfig, num_labels: usize) -> Result<Self> {
525        Self::new_with_device(config, num_labels, Device::CPU)
526    }
527
528    pub fn new_with_device(config: FNetConfig, num_labels: usize, device: Device) -> Result<Self> {
529        let fnet = FNetModel::new_with_device(config.clone(), device)?;
530        let classifier = Linear::new_with_device(config.hidden_size, num_labels, true, device);
531
532        Ok(Self {
533            fnet,
534            classifier,
535            num_labels,
536            device,
537        })
538    }
539
540    pub fn device(&self) -> Device {
541        self.device
542    }
543}
544
545impl Model for FNetForSequenceClassification {
546    type Config = FNetConfig;
547    type Input = (Vec<u32>, Option<Vec<u32>>, Option<Vec<u32>>);
548    type Output = Tensor;
549
550    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
551        let sequence_output = self.fnet.forward(input)?;
552        let cls_output = sequence_output.slice(1, 0, 1)?; // Get first token (CLS) from sequence
553        let cls_output = cls_output.squeeze(1)?; // Remove singleton sequence dimension
554        self.classifier.forward(cls_output)
555    }
556
557    fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
558        self.fnet.load_pretrained(reader)
559    }
560
561    fn get_config(&self) -> &Self::Config {
562        self.fnet.get_config()
563    }
564
565    fn num_parameters(&self) -> usize {
566        self.fnet.num_parameters() + self.classifier.parameter_count()
567    }
568}
569
/// FNet for masked language modeling
pub struct FNetForMaskedLM {
    /// Shared FNet backbone producing per-token hidden states.
    fnet: FNetModel,
    /// Language-modeling head: hidden_size -> vocab_size logits per position.
    mlm_head: Linear,
    /// Device the parameters live on.
    device: Device,
}
576
577impl FNetForMaskedLM {
578    pub fn new(config: FNetConfig) -> Result<Self> {
579        Self::new_with_device(config, Device::CPU)
580    }
581
582    pub fn new_with_device(config: FNetConfig, device: Device) -> Result<Self> {
583        let fnet = FNetModel::new_with_device(config.clone(), device)?;
584        let mlm_head = Linear::new_with_device(config.hidden_size, config.vocab_size, true, device);
585
586        Ok(Self {
587            fnet,
588            mlm_head,
589            device,
590        })
591    }
592
593    pub fn device(&self) -> Device {
594        self.device
595    }
596}
597
598impl Model for FNetForMaskedLM {
599    type Config = FNetConfig;
600    type Input = (Vec<u32>, Option<Vec<u32>>, Option<Vec<u32>>);
601    type Output = Tensor;
602
603    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
604        let sequence_output = self.fnet.forward(input)?;
605        self.mlm_head.forward(sequence_output)
606    }
607
608    fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
609        self.fnet.load_pretrained(reader)
610    }
611
612    fn get_config(&self) -> &Self::Config {
613        self.fnet.get_config()
614    }
615
616    fn num_parameters(&self) -> usize {
617        self.fnet.num_parameters() + self.mlm_head.parameter_count()
618    }
619}