// trustformers_models/fnet/config.rs
1use serde::{Deserialize, Serialize};
2use trustformers_core::{errors::invalid_config, traits::Config};
3
4/// FNet model configuration
5/// Reference: "FNet: Mixing Tokens with Fourier Transforms" (Lee-Thorp et al., 2021)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FNetConfig {
    /// Number of entries in the token embedding table.
    pub vocab_size: usize,
    /// Width of the hidden representations throughout the encoder.
    pub hidden_size: usize,
    /// Number of stacked encoder layers.
    pub num_hidden_layers: usize,
    /// Width of the feed-forward (intermediate) projection in each layer.
    pub intermediate_size: usize,
    /// Name of the feed-forward activation function (e.g. "gelu").
    pub hidden_act: String,
    /// Dropout probability applied to hidden states.
    pub hidden_dropout_prob: f32,
    /// Maximum sequence length the position embeddings support.
    pub max_position_embeddings: usize,
    /// Number of token-type (segment) embeddings.
    pub type_vocab_size: usize,
    /// Std-dev used when initializing weights.
    pub initializer_range: f32,
    /// Epsilon added inside layer normalization for numerical stability.
    pub layer_norm_eps: f32,
    /// Token id used for padding.
    pub pad_token_id: u32,
    /// Position embedding scheme (e.g. "absolute").
    pub position_embedding_type: String,

    // FNet-specific parameters.
    // Note: there is no `num_attention_heads` — FNet mixes tokens with
    // Fourier transforms instead of self-attention.
    pub use_fourier_transform: bool, // Use DFT instead of attention
    pub use_tpu_optimized_fft: bool, // Use TPU-optimized FFT variants
    pub fourier_transform_type: String, // "dft", "real_dft", "dct"
    pub use_bias_in_fourier: bool,   // Add bias after Fourier transform
    pub fourier_dropout_prob: f32,   // Dropout after Fourier layer
}
28
29impl Default for FNetConfig {
30    fn default() -> Self {
31        Self {
32            vocab_size: 32000, // Use larger vocab like T5
33            hidden_size: 768,
34            num_hidden_layers: 12,
35            intermediate_size: 3072,
36            hidden_act: "gelu".to_string(),
37            hidden_dropout_prob: 0.1,
38            max_position_embeddings: 512,
39            type_vocab_size: 4, // FNet often uses more token types
40            initializer_range: 0.02,
41            layer_norm_eps: 1e-12,
42            pad_token_id: 0,
43            position_embedding_type: "absolute".to_string(),
44
45            // FNet defaults
46            use_fourier_transform: true,
47            use_tpu_optimized_fft: false,
48            fourier_transform_type: "dft".to_string(),
49            use_bias_in_fourier: true,
50            fourier_dropout_prob: 0.0, // Usually no dropout on Fourier layer
51        }
52    }
53}
54
55impl Config for FNetConfig {
56    fn validate(&self) -> trustformers_core::errors::Result<()> {
57        // No head constraints since FNet doesn't use attention heads
58
59        if !["dft", "real_dft", "dct"].contains(&self.fourier_transform_type.as_str()) {
60            return Err(trustformers_core::errors::invalid_config(
61                "fourier_transform_type",
62                "fourier_transform_type must be one of: dft, real_dft, dct",
63            ));
64        }
65
66        // Check if sequence length is reasonable for FFT
67        if self.max_position_embeddings > 8192 {
68            return Err(invalid_config(
69                "config_field",
70                "max_position_embeddings > 8192 may be inefficient for FFT. Consider chunking."
71                    .to_string(),
72            ));
73        }
74
75        Ok(())
76    }
77
78    fn architecture(&self) -> &'static str {
79        "FNet"
80    }
81}
82
83impl FNetConfig {
84    /// FNet-Base configuration
85    pub fn fnet_base() -> Self {
86        Self::default()
87    }
88
89    /// FNet-Large configuration
90    pub fn fnet_large() -> Self {
91        Self {
92            hidden_size: 1024,
93            num_hidden_layers: 24,
94            intermediate_size: 4096,
95            ..Self::default()
96        }
97    }
98
99    /// FNet optimized for TPU training
100    pub fn fnet_tpu() -> Self {
101        Self {
102            use_tpu_optimized_fft: true,
103            fourier_transform_type: "real_dft".to_string(), // More efficient on TPU
104            max_position_embeddings: 1024,                  // Power of 2 for efficiency
105            ..Self::default()
106        }
107    }
108
109    /// FNet with DCT (Discrete Cosine Transform) instead of DFT
110    pub fn fnet_dct() -> Self {
111        Self {
112            fourier_transform_type: "dct".to_string(),
113            max_position_embeddings: 1024,
114            ..Self::default()
115        }
116    }
117
118    /// Long-sequence FNet (up to 4K tokens)
119    pub fn fnet_long() -> Self {
120        Self {
121            max_position_embeddings: 4096,
122            fourier_transform_type: "real_dft".to_string(), // More efficient for long sequences
123            ..Self::default()
124        }
125    }
126
127    /// Get theoretical complexity advantage over attention
128    pub fn complexity_advantage(&self) -> f32 {
129        let n = self.max_position_embeddings as f32;
130        let attention_complexity = n * n; // O(n²)
131        let fourier_complexity = n * n.log2(); // O(n log n)
132        attention_complexity / fourier_complexity
133    }
134
135    /// Check if configuration is optimized for efficiency
136    pub fn is_efficient_config(&self) -> bool {
137        // Check if sequence length is power of 2 (optimal for FFT)
138        let n = self.max_position_embeddings;
139        n > 0 && (n & (n - 1)) == 0
140    }
141
142    /// Get recommended batch size for efficiency
143    pub fn recommended_batch_size(&self) -> usize {
144        // Fourier transforms are more batch-friendly than attention
145        match self.hidden_size {
146            768 => 64,  // Base model
147            1024 => 32, // Large model
148            _ => 16,    // Conservative default
149        }
150    }
151}