trustformers_models/fnet/
config.rs1use serde::{Deserialize, Serialize};
2use trustformers_core::{errors::invalid_config, traits::Config};
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct FNetConfig {
8 pub vocab_size: usize,
9 pub hidden_size: usize,
10 pub num_hidden_layers: usize,
11 pub intermediate_size: usize,
12 pub hidden_act: String,
13 pub hidden_dropout_prob: f32,
14 pub max_position_embeddings: usize,
15 pub type_vocab_size: usize,
16 pub initializer_range: f32,
17 pub layer_norm_eps: f32,
18 pub pad_token_id: u32,
19 pub position_embedding_type: String,
20
21 pub use_fourier_transform: bool, pub use_tpu_optimized_fft: bool, pub fourier_transform_type: String, pub use_bias_in_fourier: bool, pub fourier_dropout_prob: f32, }
28
29impl Default for FNetConfig {
30 fn default() -> Self {
31 Self {
32 vocab_size: 32000, hidden_size: 768,
34 num_hidden_layers: 12,
35 intermediate_size: 3072,
36 hidden_act: "gelu".to_string(),
37 hidden_dropout_prob: 0.1,
38 max_position_embeddings: 512,
39 type_vocab_size: 4, initializer_range: 0.02,
41 layer_norm_eps: 1e-12,
42 pad_token_id: 0,
43 position_embedding_type: "absolute".to_string(),
44
45 use_fourier_transform: true,
47 use_tpu_optimized_fft: false,
48 fourier_transform_type: "dft".to_string(),
49 use_bias_in_fourier: true,
50 fourier_dropout_prob: 0.0, }
52 }
53}
54
55impl Config for FNetConfig {
56 fn validate(&self) -> trustformers_core::errors::Result<()> {
57 if !["dft", "real_dft", "dct"].contains(&self.fourier_transform_type.as_str()) {
60 return Err(trustformers_core::errors::invalid_config(
61 "fourier_transform_type",
62 "fourier_transform_type must be one of: dft, real_dft, dct",
63 ));
64 }
65
66 if self.max_position_embeddings > 8192 {
68 return Err(invalid_config(
69 "config_field",
70 "max_position_embeddings > 8192 may be inefficient for FFT. Consider chunking."
71 .to_string(),
72 ));
73 }
74
75 Ok(())
76 }
77
78 fn architecture(&self) -> &'static str {
79 "FNet"
80 }
81}
82
83impl FNetConfig {
84 pub fn fnet_base() -> Self {
86 Self::default()
87 }
88
89 pub fn fnet_large() -> Self {
91 Self {
92 hidden_size: 1024,
93 num_hidden_layers: 24,
94 intermediate_size: 4096,
95 ..Self::default()
96 }
97 }
98
99 pub fn fnet_tpu() -> Self {
101 Self {
102 use_tpu_optimized_fft: true,
103 fourier_transform_type: "real_dft".to_string(), max_position_embeddings: 1024, ..Self::default()
106 }
107 }
108
109 pub fn fnet_dct() -> Self {
111 Self {
112 fourier_transform_type: "dct".to_string(),
113 max_position_embeddings: 1024,
114 ..Self::default()
115 }
116 }
117
118 pub fn fnet_long() -> Self {
120 Self {
121 max_position_embeddings: 4096,
122 fourier_transform_type: "real_dft".to_string(), ..Self::default()
124 }
125 }
126
127 pub fn complexity_advantage(&self) -> f32 {
129 let n = self.max_position_embeddings as f32;
130 let attention_complexity = n * n; let fourier_complexity = n * n.log2(); attention_complexity / fourier_complexity
133 }
134
135 pub fn is_efficient_config(&self) -> bool {
137 let n = self.max_position_embeddings;
139 n > 0 && (n & (n - 1)) == 0
140 }
141
142 pub fn recommended_batch_size(&self) -> usize {
144 match self.hidden_size {
146 768 => 64, 1024 => 32, _ => 16, }
150 }
151}