1use crate::fnet::config::FNetConfig;
2use std::io::Read;
3use trustformers_core::{
4 device::Device,
5 errors::Result,
6 layers::{Embedding, LayerNorm, Linear},
7 tensor::Tensor,
8 traits::{Config, Layer, Model},
9};
10
11pub struct FourierTransform {
14 fourier_type: String,
15 #[allow(dead_code)]
16 use_bias: bool,
17 bias: Option<Linear>,
18 #[allow(dead_code)]
19 dropout: f32,
20 device: Device,
21}
22
23impl FourierTransform {
24 pub fn new(config: &FNetConfig) -> Result<Self> {
25 Self::new_with_device(config, Device::CPU)
26 }
27
28 pub fn new_with_device(config: &FNetConfig, device: Device) -> Result<Self> {
29 let bias = if config.use_bias_in_fourier {
30 Some(Linear::new_with_device(
31 config.hidden_size,
32 config.hidden_size,
33 true,
34 device,
35 ))
36 } else {
37 None
38 };
39
40 Ok(Self {
41 fourier_type: config.fourier_transform_type.clone(),
42 use_bias: config.use_bias_in_fourier,
43 bias,
44 dropout: config.fourier_dropout_prob,
45 device,
46 })
47 }
48
49 pub fn device(&self) -> Device {
50 self.device
51 }
52
53 pub fn parameter_count(&self) -> usize {
54 if let Some(ref bias_layer) = self.bias {
55 bias_layer.parameter_count()
56 } else {
57 0
58 }
59 }
60
61 fn apply_dft(&self, x: &Tensor) -> Result<Tensor> {
63 let _batch_size = x.shape()[0];
65 let _seq_len = x.shape()[1];
66 let _hidden_size = x.shape()[2];
67
68 let x_seq_dft = self.dft_1d(x, 1)?; let x_both_dft = self.dft_1d(&x_seq_dft, 2)?; self.real_part(&x_both_dft)
76 }
77
78 fn apply_real_dft(&self, x: &Tensor) -> Result<Tensor> {
80 self.apply_dft(x)
83 }
84
85 fn apply_dct(&self, x: &Tensor) -> Result<Tensor> {
87 let batch_size = x.shape()[0];
90 let seq_len = x.shape()[1];
91 let hidden_size = x.shape()[2];
92
93 let seq_dct_matrix = self.create_dct_matrix(seq_len)?;
95 let hidden_dct_matrix = self.create_dct_matrix(hidden_size)?;
96
97 let seq_shape = seq_dct_matrix.shape();
100 let seq_dim0 = seq_shape.len().saturating_sub(2);
101 let seq_dim1 = seq_shape.len().saturating_sub(1);
102 let x_seq_dct = x.matmul(&seq_dct_matrix.transpose(seq_dim0, seq_dim1)?)?;
103
104 let reshaped = x_seq_dct.reshape(&[batch_size * seq_len, hidden_size])?;
107 let hidden_shape = hidden_dct_matrix.shape();
108 let hidden_dim0 = hidden_shape.len().saturating_sub(2);
109 let hidden_dim1 = hidden_shape.len().saturating_sub(1);
110 let hidden_dct =
111 reshaped.matmul(&hidden_dct_matrix.transpose(hidden_dim0, hidden_dim1)?)?;
112 hidden_dct.reshape(&[batch_size, seq_len, hidden_size])
113 }
114
115 fn create_dct_matrix(&self, n: usize) -> Result<Tensor> {
117 let mut matrix = Vec::new();
118 let pi = std::f32::consts::PI;
119
120 for k in 0..n {
121 for i in 0..n {
122 let value = if k == 0 {
123 (1.0 / n as f32).sqrt()
124 } else {
125 (2.0 / n as f32).sqrt()
126 * (pi * k as f32 * (2 * i + 1) as f32 / (2 * n) as f32).cos()
127 };
128 matrix.push(value);
129 }
130 }
131
132 Tensor::from_vec(matrix, &[n, n])
133 }
134
135 fn dft_1d(&self, x: &Tensor, dim: i32) -> Result<Tensor> {
137 let shape = x.shape();
141 let n = shape[dim as usize];
142
143 let mut dft_matrix = Vec::new();
148 let pi = std::f32::consts::PI;
149
150 for k in 0..n {
151 for j in 0..n {
152 let angle = -2.0 * pi * (k * j) as f32 / n as f32;
153 let real_part = angle.cos() / (n as f32).sqrt();
154 dft_matrix.push(real_part);
155 }
156 }
157
158 let dft_tensor = Tensor::from_vec(dft_matrix, &[n, n])?;
159
160 if dim == 1 {
162 let dft_shape = dft_tensor.shape();
164 let dft_dim0 = dft_shape.len().saturating_sub(2);
165 let dft_dim1 = dft_shape.len().saturating_sub(1);
166 x.matmul(&dft_tensor.transpose(dft_dim0, dft_dim1)?)
167 } else {
168 let batch_size = shape[0];
170 let seq_len = shape[1];
171 let hidden_size = shape[2];
172
173 let reshaped = x.reshape(&[batch_size * seq_len, hidden_size])?;
174 let dft_shape = dft_tensor.shape();
175 let dft_dim0 = dft_shape.len().saturating_sub(2);
176 let dft_dim1 = dft_shape.len().saturating_sub(1);
177 let transformed = reshaped.matmul(&dft_tensor.transpose(dft_dim0, dft_dim1)?)?;
178 transformed.reshape(&[batch_size, seq_len, hidden_size])
179 }
180 }
181
182 fn real_part(&self, x: &Tensor) -> Result<Tensor> {
184 Ok(x.clone())
187 }
188}
189
190impl Layer for FourierTransform {
191 type Input = Tensor;
192 type Output = Tensor;
193
194 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
195 let fourier_output = match self.fourier_type.as_str() {
197 "dft" => self.apply_dft(&input)?,
198 "real_dft" => self.apply_real_dft(&input)?,
199 "dct" => self.apply_dct(&input)?,
200 _ => self.apply_dft(&input)?, };
202
203 let output = if let Some(ref bias_layer) = self.bias {
205 bias_layer.forward(fourier_output)?
206 } else {
207 fourier_output
208 };
209
210 Ok(output)
213 }
214}
215
216pub struct FNetFeedForward {
218 dense1: Linear,
219 dense2: Linear,
220 activation: String,
221 #[allow(dead_code)]
222 dropout: f32,
223 device: Device,
224}
225
226impl FNetFeedForward {
227 pub fn new(config: &FNetConfig) -> Result<Self> {
228 Self::new_with_device(config, Device::CPU)
229 }
230
231 pub fn new_with_device(config: &FNetConfig, device: Device) -> Result<Self> {
232 let dense1 =
233 Linear::new_with_device(config.hidden_size, config.intermediate_size, true, device);
234 let dense2 =
235 Linear::new_with_device(config.intermediate_size, config.hidden_size, true, device);
236
237 Ok(Self {
238 dense1,
239 dense2,
240 activation: config.hidden_act.clone(),
241 dropout: config.hidden_dropout_prob,
242 device,
243 })
244 }
245
246 pub fn device(&self) -> Device {
247 self.device
248 }
249
250 pub fn parameter_count(&self) -> usize {
251 self.dense1.parameter_count() + self.dense2.parameter_count()
252 }
253
254 fn apply_activation(&self, x: &Tensor) -> Result<Tensor> {
255 match self.activation.as_str() {
256 "gelu" => x.gelu(),
257 "relu" => x.relu(),
258 "silu" | "swish" => x.silu(),
259 _ => Ok(x.clone()),
260 }
261 }
262}
263
264impl Layer for FNetFeedForward {
265 type Input = Tensor;
266 type Output = Tensor;
267
268 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
269 let hidden = self.dense1.forward(input)?;
270 let hidden = self.apply_activation(&hidden)?;
271 self.dense2.forward(hidden)
272 }
273}
274
275pub struct FNetLayer {
277 fourier_transform: FourierTransform,
278 feed_forward: FNetFeedForward,
279 fourier_norm: LayerNorm,
280 output_norm: LayerNorm,
281 device: Device,
282}
283
284impl FNetLayer {
285 pub fn new(config: &FNetConfig) -> Result<Self> {
286 Self::new_with_device(config, Device::CPU)
287 }
288
289 pub fn new_with_device(config: &FNetConfig, device: Device) -> Result<Self> {
290 let fourier_transform = FourierTransform::new_with_device(config, device)?;
291 let feed_forward = FNetFeedForward::new_with_device(config, device)?;
292 let fourier_norm =
293 LayerNorm::new_with_device(vec![config.hidden_size], config.layer_norm_eps, device)?;
294 let output_norm =
295 LayerNorm::new_with_device(vec![config.hidden_size], config.layer_norm_eps, device)?;
296
297 Ok(Self {
298 fourier_transform,
299 feed_forward,
300 fourier_norm,
301 output_norm,
302 device,
303 })
304 }
305
306 pub fn device(&self) -> Device {
307 self.device
308 }
309
310 pub fn parameter_count(&self) -> usize {
311 self.fourier_transform.parameter_count()
312 + self.feed_forward.parameter_count()
313 + self.fourier_norm.parameter_count()
314 + self.output_norm.parameter_count()
315 }
316}
317
318impl Layer for FNetLayer {
319 type Input = Tensor;
320 type Output = Tensor;
321
322 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
323 let fourier_output = self.fourier_transform.forward(input.clone())?;
325 let fourier_output = input.add(&fourier_output)?; let fourier_output = self.fourier_norm.forward(fourier_output)?;
327
328 let ff_output = self.feed_forward.forward(fourier_output.clone())?;
330 let output = fourier_output.add(&ff_output)?; self.output_norm.forward(output)
332 }
333}
334
335pub struct FNetEmbeddings {
337 word_embeddings: Embedding,
338 position_embeddings: Embedding,
339 token_type_embeddings: Embedding,
340 layer_norm: LayerNorm,
341 #[allow(dead_code)]
342 dropout: f32,
343 device: Device,
344}
345
346impl FNetEmbeddings {
347 pub fn new(config: &FNetConfig) -> Result<Self> {
348 Self::new_with_device(config, Device::CPU)
349 }
350
351 pub fn new_with_device(config: &FNetConfig, device: Device) -> Result<Self> {
352 let word_embeddings = Embedding::new_with_device(
353 config.vocab_size,
354 config.hidden_size,
355 Some(config.pad_token_id as usize),
356 device,
357 )?;
358 let position_embeddings = Embedding::new_with_device(
359 config.max_position_embeddings,
360 config.hidden_size,
361 None,
362 device,
363 )?;
364 let token_type_embeddings =
365 Embedding::new_with_device(config.type_vocab_size, config.hidden_size, None, device)?;
366 let layer_norm =
367 LayerNorm::new_with_device(vec![config.hidden_size], config.layer_norm_eps, device)?;
368
369 Ok(Self {
370 word_embeddings,
371 position_embeddings,
372 token_type_embeddings,
373 layer_norm,
374 dropout: config.hidden_dropout_prob,
375 device,
376 })
377 }
378
379 pub fn device(&self) -> Device {
380 self.device
381 }
382
383 pub fn parameter_count(&self) -> usize {
384 self.word_embeddings.parameter_count()
385 + self.position_embeddings.parameter_count()
386 + self.token_type_embeddings.parameter_count()
387 + self.layer_norm.parameter_count()
388 }
389}
390
391impl Layer for FNetEmbeddings {
392 type Input = (Vec<u32>, Option<Vec<u32>>, Option<Vec<u32>>);
393 type Output = Tensor;
394
395 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
396 let (input_ids, token_type_ids, position_ids) = input;
397 let seq_len = input_ids.len();
398
399 let words_embeddings = self.word_embeddings.forward(input_ids)?;
400
401 let position_ids = position_ids.unwrap_or_else(|| (0..seq_len as u32).collect());
402 let position_embeddings = self.position_embeddings.forward(position_ids)?;
403
404 let token_type_ids = token_type_ids.unwrap_or_else(|| vec![0; seq_len]);
405 let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
406
407 let embeddings = words_embeddings.add(&position_embeddings)?.add(&token_type_embeddings)?;
408 let embeddings = self.layer_norm.forward(embeddings)?;
409
410 Ok(embeddings)
411 }
412}
413
414pub struct FNetEncoder {
416 layers: Vec<FNetLayer>,
417 device: Device,
418}
419
420impl FNetEncoder {
421 pub fn new(config: &FNetConfig) -> Result<Self> {
422 Self::new_with_device(config, Device::CPU)
423 }
424
425 pub fn new_with_device(config: &FNetConfig, device: Device) -> Result<Self> {
426 let mut layers = Vec::new();
427 for _ in 0..config.num_hidden_layers {
428 layers.push(FNetLayer::new_with_device(config, device)?);
429 }
430
431 Ok(Self { layers, device })
432 }
433
434 pub fn device(&self) -> Device {
435 self.device
436 }
437
438 pub fn parameter_count(&self) -> usize {
439 self.layers.iter().map(|layer| layer.parameter_count()).sum()
440 }
441}
442
443impl Layer for FNetEncoder {
444 type Input = Tensor;
445 type Output = Tensor;
446
447 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
448 let mut hidden_states = input;
449
450 for layer in &self.layers {
451 hidden_states = layer.forward(hidden_states)?;
452 }
453
454 Ok(hidden_states)
455 }
456}
457
458pub struct FNetModel {
460 config: FNetConfig,
461 embeddings: FNetEmbeddings,
462 encoder: FNetEncoder,
463 device: Device,
464}
465
466impl FNetModel {
467 pub fn new(config: FNetConfig) -> Result<Self> {
468 Self::new_with_device(config, Device::CPU)
469 }
470
471 pub fn new_with_device(config: FNetConfig, device: Device) -> Result<Self> {
472 config.validate()?;
473
474 let embeddings = FNetEmbeddings::new_with_device(&config, device)?;
475 let encoder = FNetEncoder::new_with_device(&config, device)?;
476
477 Ok(Self {
478 config,
479 embeddings,
480 encoder,
481 device,
482 })
483 }
484
485 pub fn device(&self) -> Device {
486 self.device
487 }
488}
489
490impl Model for FNetModel {
491 type Config = FNetConfig;
492 type Input = (Vec<u32>, Option<Vec<u32>>, Option<Vec<u32>>);
493 type Output = Tensor;
494
495 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
496 let embeddings = self.embeddings.forward(input)?;
497 let sequence_output = self.encoder.forward(embeddings)?;
498 Ok(sequence_output)
499 }
500
501 fn load_pretrained(&mut self, _reader: &mut dyn Read) -> Result<()> {
502 Ok(())
503 }
504
505 fn get_config(&self) -> &Self::Config {
506 &self.config
507 }
508
509 fn num_parameters(&self) -> usize {
510 self.embeddings.parameter_count() + self.encoder.parameter_count()
511 }
512}
513
514pub struct FNetForSequenceClassification {
516 fnet: FNetModel,
517 classifier: Linear,
518 #[allow(dead_code)]
519 num_labels: usize,
520 device: Device,
521}
522
523impl FNetForSequenceClassification {
524 pub fn new(config: FNetConfig, num_labels: usize) -> Result<Self> {
525 Self::new_with_device(config, num_labels, Device::CPU)
526 }
527
528 pub fn new_with_device(config: FNetConfig, num_labels: usize, device: Device) -> Result<Self> {
529 let fnet = FNetModel::new_with_device(config.clone(), device)?;
530 let classifier = Linear::new_with_device(config.hidden_size, num_labels, true, device);
531
532 Ok(Self {
533 fnet,
534 classifier,
535 num_labels,
536 device,
537 })
538 }
539
540 pub fn device(&self) -> Device {
541 self.device
542 }
543}
544
545impl Model for FNetForSequenceClassification {
546 type Config = FNetConfig;
547 type Input = (Vec<u32>, Option<Vec<u32>>, Option<Vec<u32>>);
548 type Output = Tensor;
549
550 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
551 let sequence_output = self.fnet.forward(input)?;
552 let cls_output = sequence_output.slice(1, 0, 1)?; let cls_output = cls_output.squeeze(1)?; self.classifier.forward(cls_output)
555 }
556
557 fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
558 self.fnet.load_pretrained(reader)
559 }
560
561 fn get_config(&self) -> &Self::Config {
562 self.fnet.get_config()
563 }
564
565 fn num_parameters(&self) -> usize {
566 self.fnet.num_parameters() + self.classifier.parameter_count()
567 }
568}
569
570pub struct FNetForMaskedLM {
572 fnet: FNetModel,
573 mlm_head: Linear,
574 device: Device,
575}
576
577impl FNetForMaskedLM {
578 pub fn new(config: FNetConfig) -> Result<Self> {
579 Self::new_with_device(config, Device::CPU)
580 }
581
582 pub fn new_with_device(config: FNetConfig, device: Device) -> Result<Self> {
583 let fnet = FNetModel::new_with_device(config.clone(), device)?;
584 let mlm_head = Linear::new_with_device(config.hidden_size, config.vocab_size, true, device);
585
586 Ok(Self {
587 fnet,
588 mlm_head,
589 device,
590 })
591 }
592
593 pub fn device(&self) -> Device {
594 self.device
595 }
596}
597
598impl Model for FNetForMaskedLM {
599 type Config = FNetConfig;
600 type Input = (Vec<u32>, Option<Vec<u32>>, Option<Vec<u32>>);
601 type Output = Tensor;
602
603 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
604 let sequence_output = self.fnet.forward(input)?;
605 self.mlm_head.forward(sequence_output)
606 }
607
608 fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
609 self.fnet.load_pretrained(reader)
610 }
611
612 fn get_config(&self) -> &Self::Config {
613 self.fnet.get_config()
614 }
615
616 fn num_parameters(&self) -> usize {
617 self.fnet.num_parameters() + self.mlm_head.parameter_count()
618 }
619}