1use crate::linformer::config::LinformerConfig;
2use scirs2_core::ndarray::{ArrayD, IxDyn}; use std::io::Read;
4use trustformers_core::{
5 device::Device,
6 errors::{Result, TrustformersError},
7 layers::{Embedding, LayerNorm, Linear},
8 tensor::Tensor,
9 traits::{Config, Layer, Model},
10};
11
/// Multi-head self-attention with optional Linformer low-rank sequence
/// projections.
///
/// When the config enables efficient attention, keys (and optionally values)
/// are projected along the sequence axis down to `projected_size`, which is
/// the Linformer trick for reducing attention cost.
pub struct LinformerAttention {
    /// Query projection: hidden_size -> num_heads * head_size.
    query: Linear,
    /// Key projection: hidden_size -> num_heads * head_size.
    key: Linear,
    /// Value projection: hidden_size -> num_heads * head_size.
    value: Linear,
    /// Output projection back to hidden_size.
    output: Linear,
    /// Sequence-length projection for keys; `None` when efficient attention is off.
    key_projection: Option<Linear>,
    /// Sequence-length projection for values; `None` when efficient attention is
    /// off or when `share_projection` makes values reuse `key_projection`.
    value_projection: Option<Linear>,
    num_attention_heads: usize,
    attention_head_size: usize,
    /// Projected (compressed) sequence length `k`.
    projected_size: usize,
    // Stored from config but not applied anywhere in this file.
    #[allow(dead_code)]
    dropout: f32,
    /// If true, keys and values share a single projection matrix.
    share_projection: bool,
    device: Device,
}
32
33impl LinformerAttention {
34 pub fn new(config: &LinformerConfig) -> Result<Self> {
35 Self::new_with_device(config, Device::CPU)
36 }
37
38 pub fn new_with_device(config: &LinformerConfig, device: Device) -> Result<Self> {
39 let attention_head_size = config.head_dim();
40 let all_head_size = config.num_attention_heads * attention_head_size;
41
42 let query = Linear::new_with_device(config.hidden_size, all_head_size, true, device);
43 let key = Linear::new_with_device(config.hidden_size, all_head_size, true, device);
44 let value = Linear::new_with_device(config.hidden_size, all_head_size, true, device);
45 let output = Linear::new_with_device(all_head_size, config.hidden_size, true, device);
46
47 let (key_projection, value_projection) = if config.use_efficient_attention {
49 let key_proj = Linear::new_with_device(
50 config.max_position_embeddings,
51 config.projected_attention_size,
52 false,
53 device,
54 );
55 let value_proj = if config.share_projection {
56 None } else {
58 Some(Linear::new_with_device(
59 config.max_position_embeddings,
60 config.projected_attention_size,
61 false,
62 device,
63 ))
64 };
65 (Some(key_proj), value_proj)
66 } else {
67 (None, None)
68 };
69
70 Ok(Self {
71 query,
72 key,
73 value,
74 output,
75 key_projection,
76 value_projection,
77 num_attention_heads: config.num_attention_heads,
78 attention_head_size,
79 projected_size: config.projected_attention_size,
80 dropout: config.attention_probs_dropout_prob,
81 share_projection: config.share_projection,
82 device,
83 })
84 }
85
86 pub fn device(&self) -> Device {
87 self.device
88 }
89
90 fn transpose_for_scores(&self, x: &Tensor) -> Result<Tensor> {
92 let batch_size = x.shape()[0];
93 let seq_len = x.shape()[1];
94
95 let reshaped = x.reshape(&[
97 batch_size,
98 seq_len,
99 self.num_attention_heads,
100 self.attention_head_size,
101 ])?;
102
103 reshaped.permute(&[0, 2, 1, 3])
105 }
106
107 fn apply_linear_projection(&self, x: &Tensor, is_key: bool) -> Result<Tensor> {
109 if let Some(ref projection) =
110 if is_key { &self.key_projection } else { &self.value_projection }
111 {
112 let batch_size = x.shape()[0];
114 let num_heads = x.shape()[1];
115 let seq_len = x.shape()[2];
116 let head_dim = x.shape()[3];
117
118 let transposed = x.permute(&[0, 1, 3, 2])?;
120
121 let reshaped = transposed.reshape(&[batch_size * num_heads * head_dim, seq_len])?;
123
124 let projected = projection.forward(reshaped)?;
126
127 let reshaped_back =
129 projected.reshape(&[batch_size, num_heads, head_dim, self.projected_size])?;
130
131 reshaped_back.permute(&[0, 1, 3, 2])
133 } else if is_key && self.share_projection {
134 self.apply_linear_projection(x, true)
136 } else {
137 Ok(x.clone())
139 }
140 }
141}
142
impl Layer for LinformerAttention {
    type Input = Tensor;
    type Output = Tensor;

    /// Runs (optionally Linformer-compressed) multi-head self-attention.
    ///
    /// Input and output shape: `[batch, seq_len, hidden_size]`.
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        let batch_size = input.shape()[0];
        let seq_len = input.shape()[1];

        // Q/K/V input projections over the hidden dimension.
        let query_layer = self.query.forward(input.clone())?;
        let key_layer = self.key.forward(input.clone())?;
        let value_layer = self.value.forward(input)?;

        // Split heads: [batch, heads, seq, head_dim].
        let query_layer = self.transpose_for_scores(&query_layer)?;
        let mut key_layer = self.transpose_for_scores(&key_layer)?;
        let mut value_layer = self.transpose_for_scores(&value_layer)?;

        // Linformer step: compress the sequence axis of K and V to
        // projected_size (only when efficient attention is configured).
        if self.key_projection.is_some() {
            key_layer = self.apply_linear_projection(&key_layer, true)?;
            value_layer = self.apply_linear_projection(&value_layer, false)?;
        }

        // Scores = Q x K^T over the last two axes.
        let attention_scores = query_layer.matmul(
            &key_layer.transpose(key_layer.shape().len() - 2, key_layer.shape().len() - 1)?,
        )?;

        // Scaled dot-product: divide by sqrt(head_dim) for stable softmax.
        let scale = 1.0 / (self.attention_head_size as f32).sqrt();
        let attention_scores = attention_scores.mul_scalar(scale)?;

        let attention_probs = attention_scores.softmax(-1)?;

        // Weighted sum of values: [batch, heads, seq, head_dim].
        let context_layer = attention_probs.matmul(&value_layer)?;

        // Merge heads back: [batch, seq, heads * head_dim].
        let context_layer = context_layer.permute(&[0, 2, 1, 3])?;

        let context_layer = context_layer.reshape(&[
            batch_size,
            seq_len,
            self.num_attention_heads * self.attention_head_size,
        ])?;

        // Final output projection back to hidden_size.
        self.output.forward(context_layer)
    }
}
203
204impl LinformerAttention {
205 pub fn parameter_count(&self) -> usize {
206 let base_params = self.query.parameter_count()
207 + self.key.parameter_count()
208 + self.value.parameter_count()
209 + self.output.parameter_count();
210
211 let projection_params =
212 self.key_projection.as_ref().map(|kp| kp.parameter_count()).unwrap_or(0)
213 + self.value_projection.as_ref().map(|vp| vp.parameter_count()).unwrap_or(0);
214
215 base_params + projection_params
216 }
217}
218
/// Position-wise feed-forward block: dense expansion, activation, dense
/// contraction.
pub struct LinformerFeedForward {
    /// Expansion: hidden_size -> intermediate_size.
    dense1: Linear,
    /// Contraction: intermediate_size -> hidden_size.
    dense2: Linear,
    /// Activation name taken from the config ("gelu", "relu", "silu"/"swish").
    activation: String,
    // Stored from config but never read in this file.
    #[allow(dead_code)]
    dropout: f32,
    device: Device,
}
228
229impl LinformerFeedForward {
230 pub fn new(config: &LinformerConfig) -> Result<Self> {
231 Self::new_with_device(config, Device::CPU)
232 }
233
234 pub fn new_with_device(config: &LinformerConfig, device: Device) -> Result<Self> {
235 let dense1 =
236 Linear::new_with_device(config.hidden_size, config.intermediate_size, true, device);
237 let dense2 =
238 Linear::new_with_device(config.intermediate_size, config.hidden_size, true, device);
239
240 Ok(Self {
241 dense1,
242 dense2,
243 activation: config.hidden_act.clone(),
244 dropout: config.hidden_dropout_prob,
245 device,
246 })
247 }
248
249 pub fn device(&self) -> Device {
250 self.device
251 }
252
253 fn apply_activation(&self, x: &Tensor) -> Result<Tensor> {
254 match self.activation.as_str() {
255 "gelu" => x.gelu(),
256 "relu" => x.relu(),
257 "silu" | "swish" => x.silu(),
258 _ => Ok(x.clone()),
259 }
260 }
261}
262
263impl Layer for LinformerFeedForward {
264 type Input = Tensor;
265 type Output = Tensor;
266
267 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
268 let hidden = self.dense1.forward(input)?;
269 let hidden = self.apply_activation(&hidden)?;
270 self.dense2.forward(hidden)
272 }
273}
274
275impl LinformerFeedForward {
276 pub fn parameter_count(&self) -> usize {
277 self.dense1.parameter_count() + self.dense2.parameter_count()
278 }
279}
280
/// One transformer encoder layer: self-attention and feed-forward sub-layers,
/// each wrapped in a residual connection followed by LayerNorm (post-norm).
pub struct LinformerLayer {
    attention: LinformerAttention,
    feed_forward: LinformerFeedForward,
    /// LayerNorm applied after the attention residual.
    attention_norm: LayerNorm,
    /// LayerNorm applied after the feed-forward residual.
    output_norm: LayerNorm,
    device: Device,
}
289
290impl LinformerLayer {
291 pub fn new(config: &LinformerConfig) -> Result<Self> {
292 Self::new_with_device(config, Device::CPU)
293 }
294
295 pub fn new_with_device(config: &LinformerConfig, device: Device) -> Result<Self> {
296 let attention = LinformerAttention::new_with_device(config, device)?;
297 let feed_forward = LinformerFeedForward::new_with_device(config, device)?;
298 let attention_norm =
299 LayerNorm::new_with_device(vec![config.hidden_size], config.layer_norm_eps, device)?;
300 let output_norm =
301 LayerNorm::new_with_device(vec![config.hidden_size], config.layer_norm_eps, device)?;
302
303 Ok(Self {
304 attention,
305 feed_forward,
306 attention_norm,
307 output_norm,
308 device,
309 })
310 }
311
312 pub fn device(&self) -> Device {
313 self.device
314 }
315}
316
317impl Layer for LinformerLayer {
318 type Input = Tensor;
319 type Output = Tensor;
320
321 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
322 let attention_output = self.attention.forward(input.clone())?;
324 let attention_output = input.add(&attention_output)?; let attention_output = self.attention_norm.forward(attention_output)?;
326
327 let ff_output = self.feed_forward.forward(attention_output.clone())?;
329 let output = attention_output.add(&ff_output)?; self.output_norm.forward(output)
331 }
332}
333
334impl LinformerLayer {
335 pub fn parameter_count(&self) -> usize {
336 self.attention.parameter_count()
337 + self.feed_forward.parameter_count()
338 + self.attention_norm.parameter_count()
339 + self.output_norm.parameter_count()
340 }
341}
342
/// BERT-style input embeddings: word, position, and token-type tables,
/// summed and layer-normalized.
pub struct LinformerEmbeddings {
    word_embeddings: Embedding,
    position_embeddings: Embedding,
    token_type_embeddings: Embedding,
    /// Normalization applied to the summed embeddings.
    layer_norm: LayerNorm,
    // Stored from config but never read in this file.
    #[allow(dead_code)]
    dropout: f32,
    device: Device,
}
353
354impl LinformerEmbeddings {
355 pub fn new(config: &LinformerConfig) -> Result<Self> {
356 Self::new_with_device(config, Device::CPU)
357 }
358
359 pub fn new_with_device(config: &LinformerConfig, device: Device) -> Result<Self> {
360 let word_embeddings = Embedding::new_with_device(
361 config.vocab_size,
362 config.hidden_size,
363 Some(config.pad_token_id as usize),
364 device,
365 )?;
366 let position_embeddings = Embedding::new_with_device(
367 config.max_position_embeddings,
368 config.hidden_size,
369 None,
370 device,
371 )?;
372 let token_type_embeddings =
373 Embedding::new_with_device(config.type_vocab_size, config.hidden_size, None, device)?;
374 let layer_norm =
375 LayerNorm::new_with_device(vec![config.hidden_size], config.layer_norm_eps, device)?;
376
377 Ok(Self {
378 word_embeddings,
379 position_embeddings,
380 token_type_embeddings,
381 layer_norm,
382 dropout: config.hidden_dropout_prob,
383 device,
384 })
385 }
386
387 pub fn device(&self) -> Device {
388 self.device
389 }
390}
391
392impl Layer for LinformerEmbeddings {
393 type Input = (Vec<u32>, Option<Vec<u32>>, Option<Vec<u32>>); type Output = Tensor;
395
396 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
397 let (input_ids, token_type_ids, position_ids) = input;
398 let seq_len = input_ids.len();
399
400 let words_embeddings = self.word_embeddings.forward(input_ids)?;
402
403 let position_ids = position_ids.unwrap_or_else(|| (0..seq_len as u32).collect());
405 let position_embeddings = self.position_embeddings.forward(position_ids)?;
406
407 let token_type_ids = token_type_ids.unwrap_or_else(|| vec![0; seq_len]);
409 let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
410
411 let embeddings = words_embeddings.add(&position_embeddings)?.add(&token_type_embeddings)?;
413
414 let embeddings = self.layer_norm.forward(embeddings)?;
416 Ok(embeddings)
419 }
420}
421
422impl LinformerEmbeddings {
423 pub fn parameter_count(&self) -> usize {
424 self.word_embeddings.parameter_count()
425 + self.position_embeddings.parameter_count()
426 + self.token_type_embeddings.parameter_count()
427 + self.layer_norm.parameter_count()
428 }
429}
430
/// Stack of Linformer transformer layers.
pub struct LinformerEncoder {
    layers: Vec<LinformerLayer>,
    /// Optional `(key, value)` projections intended to be shared across layers.
    /// NOTE(review): these are constructed and counted in `parameter_count`,
    /// but `forward` never wires them into the per-layer attention — confirm
    /// the intended layer-sharing behavior.
    shared_projections: Option<(Linear, Option<Linear>)>,
    device: Device,
}
437
438impl LinformerEncoder {
439 pub fn new(config: &LinformerConfig) -> Result<Self> {
440 Self::new_with_device(config, Device::CPU)
441 }
442
443 pub fn new_with_device(config: &LinformerConfig, device: Device) -> Result<Self> {
444 let mut layers = Vec::new();
445 for _ in 0..config.num_hidden_layers {
446 layers.push(LinformerLayer::new_with_device(config, device)?);
447 }
448
449 let shared_projections = if config.share_layers && config.use_efficient_attention {
451 let key_proj = Linear::new_with_device(
452 config.max_position_embeddings,
453 config.projected_attention_size,
454 false,
455 device,
456 );
457 let value_proj = if config.share_projection {
458 None
459 } else {
460 Some(Linear::new_with_device(
461 config.max_position_embeddings,
462 config.projected_attention_size,
463 false,
464 device,
465 ))
466 };
467 Some((key_proj, value_proj))
468 } else {
469 None
470 };
471
472 Ok(Self {
473 layers,
474 shared_projections,
475 device,
476 })
477 }
478
479 pub fn device(&self) -> Device {
480 self.device
481 }
482}
483
484impl Layer for LinformerEncoder {
485 type Input = Tensor;
486 type Output = Tensor;
487
488 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
489 let mut hidden_states = input;
490
491 for layer in &self.layers {
492 hidden_states = layer.forward(hidden_states)?;
493 }
494
495 Ok(hidden_states)
496 }
497}
498
499impl LinformerEncoder {
500 pub fn parameter_count(&self) -> usize {
501 let layers_params: usize = self.layers.iter().map(|layer| layer.parameter_count()).sum();
502 let shared_proj_params = if let Some((key_proj, value_proj)) = &self.shared_projections {
503 key_proj.parameter_count()
504 + value_proj.as_ref().map(|vp| vp.parameter_count()).unwrap_or(0)
505 } else {
506 0
507 };
508 layers_params + shared_proj_params
509 }
510}
511
/// The base Linformer model: input embeddings followed by the encoder stack.
pub struct LinformerModel {
    config: LinformerConfig,
    embeddings: LinformerEmbeddings,
    encoder: LinformerEncoder,
    device: Device,
}
519
520impl LinformerModel {
521 pub fn new(config: LinformerConfig) -> Result<Self> {
522 Self::new_with_device(config, Device::CPU)
523 }
524
525 pub fn new_with_device(config: LinformerConfig, device: Device) -> Result<Self> {
526 config.validate()?;
527
528 let embeddings = LinformerEmbeddings::new_with_device(&config, device)?;
529 let encoder = LinformerEncoder::new_with_device(&config, device)?;
530
531 Ok(Self {
532 config,
533 embeddings,
534 encoder,
535 device,
536 })
537 }
538
539 pub fn device(&self) -> Device {
540 self.device
541 }
542}
543
544impl Model for LinformerModel {
545 type Config = LinformerConfig;
546 type Input = (Vec<u32>, Option<Vec<u32>>, Option<Vec<u32>>);
547 type Output = Tensor;
548
549 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
550 let embeddings = self.embeddings.forward(input)?;
551 let sequence_output = self.encoder.forward(embeddings)?;
552 Ok(sequence_output)
553 }
554
555 fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
556 let mut buffer = Vec::new();
558 let reader = reader;
559 reader.read_to_end(&mut buffer).map_err(|e| {
560 trustformers_core::errors::TrustformersError::io_error(format!(
561 "Failed to read weight data: {}",
562 e
563 ))
564 })?;
565
566 if buffer.len() < 1024 {
568 return Err(trustformers_core::errors::TrustformersError::io_error(
569 "Weight data appears to be too small".to_string(),
570 ));
571 }
572
573 let temp_file =
575 std::env::temp_dir().join(format!("linformer_weights_{}.bin", std::process::id()));
576 std::fs::write(&temp_file, &buffer).map_err(|e| {
577 trustformers_core::errors::TrustformersError::io_error(format!(
578 "Failed to write temporary weights: {}",
579 e
580 ))
581 })?;
582
583 let result = self.load_from_path(&temp_file);
585
586 let _ = std::fs::remove_file(&temp_file);
588
589 result
590 }
591
592 fn get_config(&self) -> &Self::Config {
593 &self.config
594 }
595
596 fn num_parameters(&self) -> usize {
597 self.embeddings.parameter_count() + self.encoder.parameter_count()
598 }
599}
600
601impl LinformerModel {
602 pub fn load_from_path(&mut self, model_path: impl AsRef<std::path::Path>) -> Result<()> {
604 use crate::weight_loading::{auto_create_loader, WeightLoadingConfig};
605
606 let config = WeightLoadingConfig {
607 lazy_loading: true,
608 memory_mapped: false,
609 ..Default::default()
610 };
611
612 let mut loader = auto_create_loader(model_path, Some(config))?;
613
614 if let Ok(embeddings_weight) = loader.load_tensor("embeddings.word_embeddings.weight") {
616 println!(
618 "Loaded embeddings.word_embeddings.weight: {:?}",
619 embeddings_weight.shape()
620 );
621 }
622
623 if let Ok(position_embeddings) = loader.load_tensor("embeddings.position_embeddings.weight")
624 {
625 println!(
626 "Loaded embeddings.position_embeddings.weight: {:?}",
627 position_embeddings.shape()
628 );
629 }
630
631 if let Ok(token_type_embeddings) =
632 loader.load_tensor("embeddings.token_type_embeddings.weight")
633 {
634 println!(
635 "Loaded embeddings.token_type_embeddings.weight: {:?}",
636 token_type_embeddings.shape()
637 );
638 }
639
640 if let Ok(layernorm_weight) = loader.load_tensor("embeddings.LayerNorm.weight") {
642 println!(
643 "Loaded embeddings.LayerNorm.weight: {:?}",
644 layernorm_weight.shape()
645 );
646 }
647
648 if let Ok(layernorm_bias) = loader.load_tensor("embeddings.LayerNorm.bias") {
649 println!(
650 "Loaded embeddings.LayerNorm.bias: {:?}",
651 layernorm_bias.shape()
652 );
653 }
654
655 let num_layers = self.config.num_hidden_layers;
657 for layer_idx in 0..num_layers {
658 let layer_prefix = format!("encoder.layer.{}", layer_idx);
659
660 let attention_prefix = format!("{}.attention.self", layer_prefix);
662 for weight_type in &["query", "key", "value"] {
663 let weight_name = format!("{}.{}.weight", attention_prefix, weight_type);
664 let bias_name = format!("{}.{}.bias", attention_prefix, weight_type);
665
666 if let Ok(weight) = loader.load_tensor(&weight_name) {
667 println!("Loaded {}: {:?}", weight_name, weight.shape());
668 }
669 if let Ok(bias) = loader.load_tensor(&bias_name) {
670 println!("Loaded {}: {:?}", bias_name, bias.shape());
671 }
672 }
673
674 if self.config.use_efficient_attention {
676 let proj_prefix = format!("{}.attention.linformer", layer_prefix);
677 for proj_type in &["key_projection", "value_projection"] {
678 let weight_name = format!("{}.{}.weight", proj_prefix, proj_type);
679 if let Ok(weight) = loader.load_tensor(&weight_name) {
680 println!("Loaded {}: {:?}", weight_name, weight.shape());
681 }
682 }
683 }
684
685 let output_weight = format!("{}.attention.output.dense.weight", layer_prefix);
687 let output_bias = format!("{}.attention.output.dense.bias", layer_prefix);
688 if let Ok(weight) = loader.load_tensor(&output_weight) {
689 println!("Loaded {}: {:?}", output_weight, weight.shape());
690 }
691 if let Ok(bias) = loader.load_tensor(&output_bias) {
692 println!("Loaded {}: {:?}", output_bias, bias.shape());
693 }
694
695 let attention_layernorm_weight =
697 format!("{}.attention.output.LayerNorm.weight", layer_prefix);
698 let attention_layernorm_bias =
699 format!("{}.attention.output.LayerNorm.bias", layer_prefix);
700 if let Ok(weight) = loader.load_tensor(&attention_layernorm_weight) {
701 println!(
702 "Loaded {}: {:?}",
703 attention_layernorm_weight,
704 weight.shape()
705 );
706 }
707 if let Ok(bias) = loader.load_tensor(&attention_layernorm_bias) {
708 println!("Loaded {}: {:?}", attention_layernorm_bias, bias.shape());
709 }
710
711 let intermediate_weight = format!("{}.intermediate.dense.weight", layer_prefix);
713 let intermediate_bias = format!("{}.intermediate.dense.bias", layer_prefix);
714 if let Ok(weight) = loader.load_tensor(&intermediate_weight) {
715 println!("Loaded {}: {:?}", intermediate_weight, weight.shape());
716 }
717 if let Ok(bias) = loader.load_tensor(&intermediate_bias) {
718 println!("Loaded {}: {:?}", intermediate_bias, bias.shape());
719 }
720
721 let output_dense_weight = format!("{}.output.dense.weight", layer_prefix);
722 let output_dense_bias = format!("{}.output.dense.bias", layer_prefix);
723 if let Ok(weight) = loader.load_tensor(&output_dense_weight) {
724 println!("Loaded {}: {:?}", output_dense_weight, weight.shape());
725 }
726 if let Ok(bias) = loader.load_tensor(&output_dense_bias) {
727 println!("Loaded {}: {:?}", output_dense_bias, bias.shape());
728 }
729
730 let output_layernorm_weight = format!("{}.output.LayerNorm.weight", layer_prefix);
732 let output_layernorm_bias = format!("{}.output.LayerNorm.bias", layer_prefix);
733 if let Ok(weight) = loader.load_tensor(&output_layernorm_weight) {
734 println!("Loaded {}: {:?}", output_layernorm_weight, weight.shape());
735 }
736 if let Ok(bias) = loader.load_tensor(&output_layernorm_bias) {
737 println!("Loaded {}: {:?}", output_layernorm_bias, bias.shape());
738 }
739 }
740
741 println!("Successfully loaded Linformer model weights from path");
742 Ok(())
743 }
744
745 pub fn load_from_huggingface(&mut self, model_name: &str) -> Result<()> {
747 let cache_dir = std::env::temp_dir().join("huggingface_cache");
748 let model_path = cache_dir.join(format!("models--{}", model_name.replace("/", "--")));
749
750 if model_path.exists() {
751 self.load_from_path(&model_path)
752 } else {
753 self.download_from_huggingface_hub(model_name, &model_path)?;
755 self.load_from_path(&model_path)
756 }
757 }
758
759 fn download_from_huggingface_hub(
761 &self,
762 model_name: &str,
763 model_path: &std::path::Path,
764 ) -> Result<()> {
765 use std::process::Command;
766
767 println!(
768 "Downloading Linformer model {} from HuggingFace Hub to {:?}",
769 model_name, model_path
770 );
771
772 std::fs::create_dir_all(model_path).map_err(|e| {
774 trustformers_core::errors::TrustformersError::io_error(format!(
775 "Failed to create model directory: {}",
776 e
777 ))
778 })?;
779
780 let essential_files = vec![
782 "config.json",
783 "pytorch_model.bin",
784 "model.safetensors",
785 "tokenizer.json",
786 "tokenizer_config.json",
787 "vocab.txt",
788 ];
789
790 let mut successful_downloads = 0;
791
792 for file in &essential_files {
793 let url = format!(
794 "https://huggingface.co/{}/resolve/main/{}",
795 model_name, file
796 );
797 let output_path = model_path.join(file);
798
799 let output_path_str = output_path.to_str().ok_or_else(|| {
801 TrustformersError::invalid_config(format!(
802 "Invalid UTF-8 in path: {:?}",
803 output_path
804 ))
805 })?;
806
807 let curl_result = Command::new("curl")
809 .args([
810 "-L", "-f", "-o",
813 output_path_str,
814 &url,
815 ])
816 .output();
817
818 let success = match curl_result {
819 Ok(output) => output.status.success(),
820 Err(_) => {
821 let wget_result = Command::new("wget")
823 .args([
824 "-q", "-O",
826 output_path_str,
827 &url,
828 ])
829 .output();
830
831 match wget_result {
832 Ok(output) => output.status.success(),
833 Err(_) => false,
834 }
835 },
836 };
837
838 if success {
839 successful_downloads += 1;
840 println!("Downloaded {}", file);
841 } else {
842 eprintln!(
843 "Failed to download {} (this may be normal if the file doesn't exist)",
844 file
845 );
846 }
847 }
848
849 if successful_downloads == 0 {
850 return Err(trustformers_core::errors::TrustformersError::io_error(
851 "Failed to download any files from HuggingFace Hub. Please check the model name and your internet connection.".to_string()
852 ));
853 }
854
855 println!(
856 "Successfully downloaded {}/{} files for Linformer model",
857 successful_downloads,
858 essential_files.len()
859 );
860 Ok(())
861 }
862}
863
/// Linformer backbone with a linear classification head applied to the
/// first ([CLS]) token embedding.
pub struct LinformerForSequenceClassification {
    linformer: LinformerModel,
    /// Classification head: hidden_size -> num_labels.
    classifier: Linear,
    // Retained for introspection; the label count is baked into `classifier`.
    #[allow(dead_code)]
    num_labels: usize,
    device: Device,
}
872
873impl LinformerForSequenceClassification {
874 pub fn new(config: LinformerConfig, num_labels: usize) -> Result<Self> {
875 Self::new_with_device(config, num_labels, Device::CPU)
876 }
877
878 pub fn new_with_device(
879 config: LinformerConfig,
880 num_labels: usize,
881 device: Device,
882 ) -> Result<Self> {
883 let linformer = LinformerModel::new_with_device(config.clone(), device)?;
884 let classifier = Linear::new_with_device(config.hidden_size, num_labels, true, device);
885
886 Ok(Self {
887 linformer,
888 classifier,
889 num_labels,
890 device,
891 })
892 }
893
894 pub fn device(&self) -> Device {
895 self.device
896 }
897}
898
impl Model for LinformerForSequenceClassification {
    type Config = LinformerConfig;
    type Input = (Vec<u32>, Option<Vec<u32>>, Option<Vec<u32>>);
    type Output = Tensor;

    /// Encodes the sequence, extracts the first-token ([CLS]) embedding for
    /// each batch element, and projects it to classification logits.
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        let sequence_output = self.linformer.forward(input)?;

        let cls_output = match &sequence_output {
            Tensor::F32(arr) => {
                let shape = arr.shape();
                if shape.len() >= 3 {
                    // Expected layout: [batch, seq_len, hidden_size].
                    let batch_size = shape[0];
                    let hidden_size = shape[2];

                    let arr_slice = arr.as_slice().ok_or_else(|| {
                        TrustformersError::tensor_op_error(
                            "extract_cls_embeddings",
                            "Tensor is not contiguous in memory",
                        )
                    })?;

                    // Copy out sequence position 0 of every batch element:
                    // flat index = b * seq_len * hidden_size + 0 * hidden_size + h.
                    let mut cls_data = Vec::with_capacity(batch_size * hidden_size);
                    for b in 0..batch_size {
                        for h in 0..hidden_size {
                            let idx = (b * shape[1]) * hidden_size + h;
                            cls_data.push(arr_slice[idx]);
                        }
                    }

                    let cls_array =
                        ArrayD::from_shape_vec(IxDyn(&[batch_size, hidden_size]), cls_data)
                            .map_err(|_| {
                                trustformers_core::errors::TrustformersError::shape_error(
                                    "Failed to create CLS token tensor".to_string(),
                                )
                            })?;

                    Tensor::F32(cls_array)
                } else {
                    // Fewer than 3 dims: nothing to pool, pass through as-is.
                    sequence_output.clone()
                }
            },
            // Non-f32 tensors are passed through unpooled.
            _ => sequence_output.clone(),
        };

        self.classifier.forward(cls_output)
    }

    /// Delegates weight loading to the Linformer backbone.
    /// NOTE(review): the classifier head's weights are not loaded here.
    fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
        self.linformer.load_pretrained(reader)
    }

    fn get_config(&self) -> &Self::Config {
        self.linformer.get_config()
    }

    /// Backbone parameters plus the classifier head.
    fn num_parameters(&self) -> usize {
        self.linformer.num_parameters() + self.classifier.parameter_count()
    }
}
965
/// Linformer backbone with a masked-language-modeling head
/// (hidden_size -> vocab_size) over every sequence position.
pub struct LinformerForMaskedLM {
    linformer: LinformerModel,
    /// Projection from hidden states to vocabulary logits.
    mlm_head: Linear,
    device: Device,
}
972
973impl LinformerForMaskedLM {
974 pub fn new(config: LinformerConfig) -> Result<Self> {
975 Self::new_with_device(config, Device::CPU)
976 }
977
978 pub fn new_with_device(config: LinformerConfig, device: Device) -> Result<Self> {
979 let linformer = LinformerModel::new_with_device(config.clone(), device)?;
980 let mlm_head = Linear::new_with_device(config.hidden_size, config.vocab_size, true, device);
981
982 Ok(Self {
983 linformer,
984 mlm_head,
985 device,
986 })
987 }
988
989 pub fn device(&self) -> Device {
990 self.device
991 }
992}
993
994impl Model for LinformerForMaskedLM {
995 type Config = LinformerConfig;
996 type Input = (Vec<u32>, Option<Vec<u32>>, Option<Vec<u32>>);
997 type Output = Tensor;
998
999 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
1000 let sequence_output = self.linformer.forward(input)?;
1001 self.mlm_head.forward(sequence_output)
1002 }
1003
1004 fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
1005 self.linformer.load_pretrained(reader)
1006 }
1007
1008 fn get_config(&self) -> &Self::Config {
1009 self.linformer.get_config()
1010 }
1011
1012 fn num_parameters(&self) -> usize {
1013 self.linformer.num_parameters() + self.mlm_head.parameter_count()
1014 }
1015}