// Source file: trustformers_models/linformer/model.rs
1use crate::linformer::config::LinformerConfig;
2use scirs2_core::ndarray::{ArrayD, IxDyn}; // SciRS2 Integration Policy
3use std::io::Read;
4use trustformers_core::{
5    device::Device,
6    errors::{Result, TrustformersError},
7    layers::{Embedding, LayerNorm, Linear},
8    tensor::Tensor,
9    traits::{Config, Layer, Model},
10};
11
/// Linformer attention layer with linear complexity.
///
/// Standard multi-head attention, except that keys and values can be
/// projected along the sequence dimension (n -> k) so the attention matmul
/// costs O(n·k) instead of O(n²).
pub struct LinformerAttention {
    query: Linear,  // hidden_size -> num_heads * head_dim
    key: Linear,    // hidden_size -> num_heads * head_dim
    value: Linear,  // hidden_size -> num_heads * head_dim
    output: Linear, // num_heads * head_dim -> hidden_size

    // Projection matrices for linear complexity
    key_projection: Option<Linear>,   // Projects keys from n -> k
    value_projection: Option<Linear>, // Projects values from n -> k; None when sharing the key projection

    num_attention_heads: usize,
    attention_head_size: usize,
    projected_size: usize, // k: the reduced sequence dimension
    #[allow(dead_code)] // retained for future training-mode dropout
    dropout: f32,
    share_projection: bool, // when true, values reuse key_projection
    device: Device,
}
32
33impl LinformerAttention {
34    pub fn new(config: &LinformerConfig) -> Result<Self> {
35        Self::new_with_device(config, Device::CPU)
36    }
37
38    pub fn new_with_device(config: &LinformerConfig, device: Device) -> Result<Self> {
39        let attention_head_size = config.head_dim();
40        let all_head_size = config.num_attention_heads * attention_head_size;
41
42        let query = Linear::new_with_device(config.hidden_size, all_head_size, true, device);
43        let key = Linear::new_with_device(config.hidden_size, all_head_size, true, device);
44        let value = Linear::new_with_device(config.hidden_size, all_head_size, true, device);
45        let output = Linear::new_with_device(all_head_size, config.hidden_size, true, device);
46
47        // Create projection matrices for linear complexity
48        let (key_projection, value_projection) = if config.use_efficient_attention {
49            let key_proj = Linear::new_with_device(
50                config.max_position_embeddings,
51                config.projected_attention_size,
52                false,
53                device,
54            );
55            let value_proj = if config.share_projection {
56                None // Will reuse key projection
57            } else {
58                Some(Linear::new_with_device(
59                    config.max_position_embeddings,
60                    config.projected_attention_size,
61                    false,
62                    device,
63                ))
64            };
65            (Some(key_proj), value_proj)
66        } else {
67            (None, None)
68        };
69
70        Ok(Self {
71            query,
72            key,
73            value,
74            output,
75            key_projection,
76            value_projection,
77            num_attention_heads: config.num_attention_heads,
78            attention_head_size,
79            projected_size: config.projected_attention_size,
80            dropout: config.attention_probs_dropout_prob,
81            share_projection: config.share_projection,
82            device,
83        })
84    }
85
86    pub fn device(&self) -> Device {
87        self.device
88    }
89
90    /// Transpose tensor for multi-head attention
91    fn transpose_for_scores(&self, x: &Tensor) -> Result<Tensor> {
92        let batch_size = x.shape()[0];
93        let seq_len = x.shape()[1];
94
95        // Reshape: [batch, seq, heads * head_dim] -> [batch, seq, heads, head_dim]
96        let reshaped = x.reshape(&[
97            batch_size,
98            seq_len,
99            self.num_attention_heads,
100            self.attention_head_size,
101        ])?;
102
103        // Permute: [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim]
104        reshaped.permute(&[0, 2, 1, 3])
105    }
106
107    /// Apply linear projection to achieve O(n) complexity
108    fn apply_linear_projection(&self, x: &Tensor, is_key: bool) -> Result<Tensor> {
109        if let Some(ref projection) =
110            if is_key { &self.key_projection } else { &self.value_projection }
111        {
112            // x shape: [batch, heads, seq_len, head_dim]
113            let batch_size = x.shape()[0];
114            let num_heads = x.shape()[1];
115            let seq_len = x.shape()[2];
116            let head_dim = x.shape()[3];
117
118            // Transpose to [batch, heads, head_dim, seq_len] for projection
119            let transposed = x.permute(&[0, 1, 3, 2])?;
120
121            // Reshape for projection: [batch * heads * head_dim, seq_len]
122            let reshaped = transposed.reshape(&[batch_size * num_heads * head_dim, seq_len])?;
123
124            // Apply projection: [batch * heads * head_dim, seq_len] -> [batch * heads * head_dim, k]
125            let projected = projection.forward(reshaped)?;
126
127            // Reshape back: [batch * heads * head_dim, k] -> [batch, heads, head_dim, k]
128            let reshaped_back =
129                projected.reshape(&[batch_size, num_heads, head_dim, self.projected_size])?;
130
131            // Transpose back: [batch, heads, head_dim, k] -> [batch, heads, k, head_dim]
132            reshaped_back.permute(&[0, 1, 3, 2])
133        } else if is_key && self.share_projection {
134            // Use key projection for values when sharing
135            self.apply_linear_projection(x, true)
136        } else {
137            // No projection, return as-is
138            Ok(x.clone())
139        }
140    }
141}
142
143impl Layer for LinformerAttention {
144    type Input = Tensor;
145    type Output = Tensor;
146
147    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
148        let batch_size = input.shape()[0];
149        let seq_len = input.shape()[1];
150
151        // Linear projections
152        let query_layer = self.query.forward(input.clone())?;
153        let key_layer = self.key.forward(input.clone())?;
154        let value_layer = self.value.forward(input)?;
155
156        // Transpose for multi-head attention
157        let query_layer = self.transpose_for_scores(&query_layer)?;
158        let mut key_layer = self.transpose_for_scores(&key_layer)?;
159        let mut value_layer = self.transpose_for_scores(&value_layer)?;
160
161        // Apply linear projections for efficiency (key innovation of Linformer)
162        if self.key_projection.is_some() {
163            key_layer = self.apply_linear_projection(&key_layer, true)?;
164            value_layer = self.apply_linear_projection(&value_layer, false)?;
165        }
166
167        // Compute attention scores
168        // Query: [batch, heads, seq_len, head_dim]
169        // Key: [batch, heads, projected_size or seq_len, head_dim]
170        let attention_scores = query_layer.matmul(
171            &key_layer.transpose(key_layer.shape().len() - 2, key_layer.shape().len() - 1)?,
172        )?;
173
174        // Scale by head dimension
175        let scale = 1.0 / (self.attention_head_size as f32).sqrt();
176        let attention_scores = attention_scores.mul_scalar(scale)?;
177
178        // Apply softmax
179        let attention_probs = attention_scores.softmax(-1)?;
180
181        // Apply dropout (would be implemented in training mode)
182        // let attention_probs = dropout(attention_probs, self.dropout);
183
184        // Apply attention to values
185        // Attention: [batch, heads, seq_len, projected_size or seq_len]
186        // Value: [batch, heads, projected_size or seq_len, head_dim]
187        let context_layer = attention_probs.matmul(&value_layer)?;
188
189        // Transpose back: [batch, heads, seq_len, head_dim] -> [batch, seq_len, heads, head_dim]
190        let context_layer = context_layer.permute(&[0, 2, 1, 3])?;
191
192        // Reshape: [batch, seq_len, heads, head_dim] -> [batch, seq_len, heads * head_dim]
193        let context_layer = context_layer.reshape(&[
194            batch_size,
195            seq_len,
196            self.num_attention_heads * self.attention_head_size,
197        ])?;
198
199        // Apply output projection
200        self.output.forward(context_layer)
201    }
202}
203
204impl LinformerAttention {
205    pub fn parameter_count(&self) -> usize {
206        let base_params = self.query.parameter_count()
207            + self.key.parameter_count()
208            + self.value.parameter_count()
209            + self.output.parameter_count();
210
211        let projection_params =
212            self.key_projection.as_ref().map(|kp| kp.parameter_count()).unwrap_or(0)
213                + self.value_projection.as_ref().map(|vp| vp.parameter_count()).unwrap_or(0);
214
215        base_params + projection_params
216    }
217}
218
/// Linformer feed-forward network (same as BERT)
pub struct LinformerFeedForward {
    dense1: Linear,     // hidden_size -> intermediate_size
    dense2: Linear,     // intermediate_size -> hidden_size
    activation: String, // activation name applied between the two dense layers
    #[allow(dead_code)] // retained for future training-mode dropout
    dropout: f32,
    device: Device,
}
228
229impl LinformerFeedForward {
230    pub fn new(config: &LinformerConfig) -> Result<Self> {
231        Self::new_with_device(config, Device::CPU)
232    }
233
234    pub fn new_with_device(config: &LinformerConfig, device: Device) -> Result<Self> {
235        let dense1 =
236            Linear::new_with_device(config.hidden_size, config.intermediate_size, true, device);
237        let dense2 =
238            Linear::new_with_device(config.intermediate_size, config.hidden_size, true, device);
239
240        Ok(Self {
241            dense1,
242            dense2,
243            activation: config.hidden_act.clone(),
244            dropout: config.hidden_dropout_prob,
245            device,
246        })
247    }
248
249    pub fn device(&self) -> Device {
250        self.device
251    }
252
253    fn apply_activation(&self, x: &Tensor) -> Result<Tensor> {
254        match self.activation.as_str() {
255            "gelu" => x.gelu(),
256            "relu" => x.relu(),
257            "silu" | "swish" => x.silu(),
258            _ => Ok(x.clone()),
259        }
260    }
261}
262
263impl Layer for LinformerFeedForward {
264    type Input = Tensor;
265    type Output = Tensor;
266
267    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
268        let hidden = self.dense1.forward(input)?;
269        let hidden = self.apply_activation(&hidden)?;
270        // Apply dropout here in training mode
271        self.dense2.forward(hidden)
272    }
273}
274
275impl LinformerFeedForward {
276    pub fn parameter_count(&self) -> usize {
277        self.dense1.parameter_count() + self.dense2.parameter_count()
278    }
279}
280
/// Linformer encoder layer: self-attention then feed-forward, each followed
/// by a residual connection and layer norm (post-norm ordering; see
/// `Layer::forward`).
pub struct LinformerLayer {
    attention: LinformerAttention,
    feed_forward: LinformerFeedForward,
    attention_norm: LayerNorm, // applied after the attention residual
    output_norm: LayerNorm,    // applied after the feed-forward residual
    device: Device,
}
289
290impl LinformerLayer {
291    pub fn new(config: &LinformerConfig) -> Result<Self> {
292        Self::new_with_device(config, Device::CPU)
293    }
294
295    pub fn new_with_device(config: &LinformerConfig, device: Device) -> Result<Self> {
296        let attention = LinformerAttention::new_with_device(config, device)?;
297        let feed_forward = LinformerFeedForward::new_with_device(config, device)?;
298        let attention_norm =
299            LayerNorm::new_with_device(vec![config.hidden_size], config.layer_norm_eps, device)?;
300        let output_norm =
301            LayerNorm::new_with_device(vec![config.hidden_size], config.layer_norm_eps, device)?;
302
303        Ok(Self {
304            attention,
305            feed_forward,
306            attention_norm,
307            output_norm,
308            device,
309        })
310    }
311
312    pub fn device(&self) -> Device {
313        self.device
314    }
315}
316
317impl Layer for LinformerLayer {
318    type Input = Tensor;
319    type Output = Tensor;
320
321    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
322        // Multi-head attention with residual connection and layer norm
323        let attention_output = self.attention.forward(input.clone())?;
324        let attention_output = input.add(&attention_output)?; // Residual
325        let attention_output = self.attention_norm.forward(attention_output)?;
326
327        // Feed-forward with residual connection and layer norm
328        let ff_output = self.feed_forward.forward(attention_output.clone())?;
329        let output = attention_output.add(&ff_output)?; // Residual
330        self.output_norm.forward(output)
331    }
332}
333
334impl LinformerLayer {
335    pub fn parameter_count(&self) -> usize {
336        self.attention.parameter_count()
337            + self.feed_forward.parameter_count()
338            + self.attention_norm.parameter_count()
339            + self.output_norm.parameter_count()
340    }
341}
342
/// Linformer embeddings: word + position + token-type lookups summed and
/// layer-normalized (see `Layer::forward`).
pub struct LinformerEmbeddings {
    word_embeddings: Embedding,       // vocab_size x hidden_size, pad_token_id as padding index
    position_embeddings: Embedding,   // max_position_embeddings x hidden_size
    token_type_embeddings: Embedding, // type_vocab_size x hidden_size
    layer_norm: LayerNorm,
    #[allow(dead_code)] // retained for future training-mode dropout
    dropout: f32,
    device: Device,
}
353
354impl LinformerEmbeddings {
355    pub fn new(config: &LinformerConfig) -> Result<Self> {
356        Self::new_with_device(config, Device::CPU)
357    }
358
359    pub fn new_with_device(config: &LinformerConfig, device: Device) -> Result<Self> {
360        let word_embeddings = Embedding::new_with_device(
361            config.vocab_size,
362            config.hidden_size,
363            Some(config.pad_token_id as usize),
364            device,
365        )?;
366        let position_embeddings = Embedding::new_with_device(
367            config.max_position_embeddings,
368            config.hidden_size,
369            None,
370            device,
371        )?;
372        let token_type_embeddings =
373            Embedding::new_with_device(config.type_vocab_size, config.hidden_size, None, device)?;
374        let layer_norm =
375            LayerNorm::new_with_device(vec![config.hidden_size], config.layer_norm_eps, device)?;
376
377        Ok(Self {
378            word_embeddings,
379            position_embeddings,
380            token_type_embeddings,
381            layer_norm,
382            dropout: config.hidden_dropout_prob,
383            device,
384        })
385    }
386
387    pub fn device(&self) -> Device {
388        self.device
389    }
390}
391
392impl Layer for LinformerEmbeddings {
393    type Input = (Vec<u32>, Option<Vec<u32>>, Option<Vec<u32>>); // (input_ids, token_type_ids, position_ids)
394    type Output = Tensor;
395
396    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
397        let (input_ids, token_type_ids, position_ids) = input;
398        let seq_len = input_ids.len();
399
400        // Word embeddings
401        let words_embeddings = self.word_embeddings.forward(input_ids)?;
402
403        // Position embeddings
404        let position_ids = position_ids.unwrap_or_else(|| (0..seq_len as u32).collect());
405        let position_embeddings = self.position_embeddings.forward(position_ids)?;
406
407        // Token type embeddings
408        let token_type_ids = token_type_ids.unwrap_or_else(|| vec![0; seq_len]);
409        let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
410
411        // Combine embeddings
412        let embeddings = words_embeddings.add(&position_embeddings)?.add(&token_type_embeddings)?;
413
414        // Apply layer norm and dropout
415        let embeddings = self.layer_norm.forward(embeddings)?;
416        // Apply dropout here in training mode
417
418        Ok(embeddings)
419    }
420}
421
422impl LinformerEmbeddings {
423    pub fn parameter_count(&self) -> usize {
424        self.word_embeddings.parameter_count()
425            + self.position_embeddings.parameter_count()
426            + self.token_type_embeddings.parameter_count()
427            + self.layer_norm.parameter_count()
428    }
429}
430
/// Linformer encoder: a stack of `LinformerLayer`s applied in sequence.
pub struct LinformerEncoder {
    layers: Vec<LinformerLayer>,
    // Shared across layers if enabled.
    // NOTE(review): these projections are constructed in `new_with_device` but
    // never passed into the individual layers in this file (each layer builds
    // its own) — confirm whether sharing is wired up elsewhere or still TODO.
    shared_projections: Option<(Linear, Option<Linear>)>,
    device: Device,
}
437
438impl LinformerEncoder {
439    pub fn new(config: &LinformerConfig) -> Result<Self> {
440        Self::new_with_device(config, Device::CPU)
441    }
442
443    pub fn new_with_device(config: &LinformerConfig, device: Device) -> Result<Self> {
444        let mut layers = Vec::new();
445        for _ in 0..config.num_hidden_layers {
446            layers.push(LinformerLayer::new_with_device(config, device)?);
447        }
448
449        // Create shared projections if enabled
450        let shared_projections = if config.share_layers && config.use_efficient_attention {
451            let key_proj = Linear::new_with_device(
452                config.max_position_embeddings,
453                config.projected_attention_size,
454                false,
455                device,
456            );
457            let value_proj = if config.share_projection {
458                None
459            } else {
460                Some(Linear::new_with_device(
461                    config.max_position_embeddings,
462                    config.projected_attention_size,
463                    false,
464                    device,
465                ))
466            };
467            Some((key_proj, value_proj))
468        } else {
469            None
470        };
471
472        Ok(Self {
473            layers,
474            shared_projections,
475            device,
476        })
477    }
478
479    pub fn device(&self) -> Device {
480        self.device
481    }
482}
483
484impl Layer for LinformerEncoder {
485    type Input = Tensor;
486    type Output = Tensor;
487
488    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
489        let mut hidden_states = input;
490
491        for layer in &self.layers {
492            hidden_states = layer.forward(hidden_states)?;
493        }
494
495        Ok(hidden_states)
496    }
497}
498
499impl LinformerEncoder {
500    pub fn parameter_count(&self) -> usize {
501        let layers_params: usize = self.layers.iter().map(|layer| layer.parameter_count()).sum();
502        let shared_proj_params = if let Some((key_proj, value_proj)) = &self.shared_projections {
503            key_proj.parameter_count()
504                + value_proj.as_ref().map(|vp| vp.parameter_count()).unwrap_or(0)
505        } else {
506            0
507        };
508        layers_params + shared_proj_params
509    }
510}
511
/// Linformer model: input embeddings followed by the encoder stack.
pub struct LinformerModel {
    config: LinformerConfig,
    embeddings: LinformerEmbeddings,
    encoder: LinformerEncoder,
    device: Device,
}
519
520impl LinformerModel {
521    pub fn new(config: LinformerConfig) -> Result<Self> {
522        Self::new_with_device(config, Device::CPU)
523    }
524
525    pub fn new_with_device(config: LinformerConfig, device: Device) -> Result<Self> {
526        config.validate()?;
527
528        let embeddings = LinformerEmbeddings::new_with_device(&config, device)?;
529        let encoder = LinformerEncoder::new_with_device(&config, device)?;
530
531        Ok(Self {
532            config,
533            embeddings,
534            encoder,
535            device,
536        })
537    }
538
539    pub fn device(&self) -> Device {
540        self.device
541    }
542}
543
544impl Model for LinformerModel {
545    type Config = LinformerConfig;
546    type Input = (Vec<u32>, Option<Vec<u32>>, Option<Vec<u32>>);
547    type Output = Tensor;
548
549    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
550        let embeddings = self.embeddings.forward(input)?;
551        let sequence_output = self.encoder.forward(embeddings)?;
552        Ok(sequence_output)
553    }
554
555    fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
556        // Read all data from the reader
557        let mut buffer = Vec::new();
558        let reader = reader;
559        reader.read_to_end(&mut buffer).map_err(|e| {
560            trustformers_core::errors::TrustformersError::io_error(format!(
561                "Failed to read weight data: {}",
562                e
563            ))
564        })?;
565
566        // Validate that we have reasonable weight data
567        if buffer.len() < 1024 {
568            return Err(trustformers_core::errors::TrustformersError::io_error(
569                "Weight data appears to be too small".to_string(),
570            ));
571        }
572
573        // Create a temporary file for the weight loading system
574        let temp_file =
575            std::env::temp_dir().join(format!("linformer_weights_{}.bin", std::process::id()));
576        std::fs::write(&temp_file, &buffer).map_err(|e| {
577            trustformers_core::errors::TrustformersError::io_error(format!(
578                "Failed to write temporary weights: {}",
579                e
580            ))
581        })?;
582
583        // Use the enhanced loading system
584        let result = self.load_from_path(&temp_file);
585
586        // Clean up temporary file
587        let _ = std::fs::remove_file(&temp_file);
588
589        result
590    }
591
592    fn get_config(&self) -> &Self::Config {
593        &self.config
594    }
595
596    fn num_parameters(&self) -> usize {
597        self.embeddings.parameter_count() + self.encoder.parameter_count()
598    }
599}
600
601impl LinformerModel {
602    /// Enhanced weight loading from local path with support for multiple formats
603    pub fn load_from_path(&mut self, model_path: impl AsRef<std::path::Path>) -> Result<()> {
604        use crate::weight_loading::{auto_create_loader, WeightLoadingConfig};
605
606        let config = WeightLoadingConfig {
607            lazy_loading: true,
608            memory_mapped: false,
609            ..Default::default()
610        };
611
612        let mut loader = auto_create_loader(model_path, Some(config))?;
613
614        // Load embeddings
615        if let Ok(embeddings_weight) = loader.load_tensor("embeddings.word_embeddings.weight") {
616            // Assign to word embeddings
617            println!(
618                "Loaded embeddings.word_embeddings.weight: {:?}",
619                embeddings_weight.shape()
620            );
621        }
622
623        if let Ok(position_embeddings) = loader.load_tensor("embeddings.position_embeddings.weight")
624        {
625            println!(
626                "Loaded embeddings.position_embeddings.weight: {:?}",
627                position_embeddings.shape()
628            );
629        }
630
631        if let Ok(token_type_embeddings) =
632            loader.load_tensor("embeddings.token_type_embeddings.weight")
633        {
634            println!(
635                "Loaded embeddings.token_type_embeddings.weight: {:?}",
636                token_type_embeddings.shape()
637            );
638        }
639
640        // Load layer normalization
641        if let Ok(layernorm_weight) = loader.load_tensor("embeddings.LayerNorm.weight") {
642            println!(
643                "Loaded embeddings.LayerNorm.weight: {:?}",
644                layernorm_weight.shape()
645            );
646        }
647
648        if let Ok(layernorm_bias) = loader.load_tensor("embeddings.LayerNorm.bias") {
649            println!(
650                "Loaded embeddings.LayerNorm.bias: {:?}",
651                layernorm_bias.shape()
652            );
653        }
654
655        // Load transformer layers
656        let num_layers = self.config.num_hidden_layers;
657        for layer_idx in 0..num_layers {
658            let layer_prefix = format!("encoder.layer.{}", layer_idx);
659
660            // Attention weights
661            let attention_prefix = format!("{}.attention.self", layer_prefix);
662            for weight_type in &["query", "key", "value"] {
663                let weight_name = format!("{}.{}.weight", attention_prefix, weight_type);
664                let bias_name = format!("{}.{}.bias", attention_prefix, weight_type);
665
666                if let Ok(weight) = loader.load_tensor(&weight_name) {
667                    println!("Loaded {}: {:?}", weight_name, weight.shape());
668                }
669                if let Ok(bias) = loader.load_tensor(&bias_name) {
670                    println!("Loaded {}: {:?}", bias_name, bias.shape());
671                }
672            }
673
674            // Projection weights for Linformer
675            if self.config.use_efficient_attention {
676                let proj_prefix = format!("{}.attention.linformer", layer_prefix);
677                for proj_type in &["key_projection", "value_projection"] {
678                    let weight_name = format!("{}.{}.weight", proj_prefix, proj_type);
679                    if let Ok(weight) = loader.load_tensor(&weight_name) {
680                        println!("Loaded {}: {:?}", weight_name, weight.shape());
681                    }
682                }
683            }
684
685            // Output weights
686            let output_weight = format!("{}.attention.output.dense.weight", layer_prefix);
687            let output_bias = format!("{}.attention.output.dense.bias", layer_prefix);
688            if let Ok(weight) = loader.load_tensor(&output_weight) {
689                println!("Loaded {}: {:?}", output_weight, weight.shape());
690            }
691            if let Ok(bias) = loader.load_tensor(&output_bias) {
692                println!("Loaded {}: {:?}", output_bias, bias.shape());
693            }
694
695            // Attention LayerNorm
696            let attention_layernorm_weight =
697                format!("{}.attention.output.LayerNorm.weight", layer_prefix);
698            let attention_layernorm_bias =
699                format!("{}.attention.output.LayerNorm.bias", layer_prefix);
700            if let Ok(weight) = loader.load_tensor(&attention_layernorm_weight) {
701                println!(
702                    "Loaded {}: {:?}",
703                    attention_layernorm_weight,
704                    weight.shape()
705                );
706            }
707            if let Ok(bias) = loader.load_tensor(&attention_layernorm_bias) {
708                println!("Loaded {}: {:?}", attention_layernorm_bias, bias.shape());
709            }
710
711            // Feed forward weights
712            let intermediate_weight = format!("{}.intermediate.dense.weight", layer_prefix);
713            let intermediate_bias = format!("{}.intermediate.dense.bias", layer_prefix);
714            if let Ok(weight) = loader.load_tensor(&intermediate_weight) {
715                println!("Loaded {}: {:?}", intermediate_weight, weight.shape());
716            }
717            if let Ok(bias) = loader.load_tensor(&intermediate_bias) {
718                println!("Loaded {}: {:?}", intermediate_bias, bias.shape());
719            }
720
721            let output_dense_weight = format!("{}.output.dense.weight", layer_prefix);
722            let output_dense_bias = format!("{}.output.dense.bias", layer_prefix);
723            if let Ok(weight) = loader.load_tensor(&output_dense_weight) {
724                println!("Loaded {}: {:?}", output_dense_weight, weight.shape());
725            }
726            if let Ok(bias) = loader.load_tensor(&output_dense_bias) {
727                println!("Loaded {}: {:?}", output_dense_bias, bias.shape());
728            }
729
730            // Output LayerNorm
731            let output_layernorm_weight = format!("{}.output.LayerNorm.weight", layer_prefix);
732            let output_layernorm_bias = format!("{}.output.LayerNorm.bias", layer_prefix);
733            if let Ok(weight) = loader.load_tensor(&output_layernorm_weight) {
734                println!("Loaded {}: {:?}", output_layernorm_weight, weight.shape());
735            }
736            if let Ok(bias) = loader.load_tensor(&output_layernorm_bias) {
737                println!("Loaded {}: {:?}", output_layernorm_bias, bias.shape());
738            }
739        }
740
741        println!("Successfully loaded Linformer model weights from path");
742        Ok(())
743    }
744
745    /// Enhanced weight loading from HuggingFace Hub with automatic download
746    pub fn load_from_huggingface(&mut self, model_name: &str) -> Result<()> {
747        let cache_dir = std::env::temp_dir().join("huggingface_cache");
748        let model_path = cache_dir.join(format!("models--{}", model_name.replace("/", "--")));
749
750        if model_path.exists() {
751            self.load_from_path(&model_path)
752        } else {
753            // Attempt to download the model from HuggingFace Hub
754            self.download_from_huggingface_hub(model_name, &model_path)?;
755            self.load_from_path(&model_path)
756        }
757    }
758
759    /// Download model from HuggingFace Hub
760    fn download_from_huggingface_hub(
761        &self,
762        model_name: &str,
763        model_path: &std::path::Path,
764    ) -> Result<()> {
765        use std::process::Command;
766
767        println!(
768            "Downloading Linformer model {} from HuggingFace Hub to {:?}",
769            model_name, model_path
770        );
771
772        // Create the model directory
773        std::fs::create_dir_all(model_path).map_err(|e| {
774            trustformers_core::errors::TrustformersError::io_error(format!(
775                "Failed to create model directory: {}",
776                e
777            ))
778        })?;
779
780        // List of essential files for Linformer models
781        let essential_files = vec![
782            "config.json",
783            "pytorch_model.bin",
784            "model.safetensors",
785            "tokenizer.json",
786            "tokenizer_config.json",
787            "vocab.txt",
788        ];
789
790        let mut successful_downloads = 0;
791
792        for file in &essential_files {
793            let url = format!(
794                "https://huggingface.co/{}/resolve/main/{}",
795                model_name, file
796            );
797            let output_path = model_path.join(file);
798
799            // Convert path to string once for both commands
800            let output_path_str = output_path.to_str().ok_or_else(|| {
801                TrustformersError::invalid_config(format!(
802                    "Invalid UTF-8 in path: {:?}",
803                    output_path
804                ))
805            })?;
806
807            // Try curl first
808            let curl_result = Command::new("curl")
809                .args([
810                    "-L", // Follow redirects
811                    "-f", // Fail silently on HTTP errors
812                    "-o",
813                    output_path_str,
814                    &url,
815                ])
816                .output();
817
818            let success = match curl_result {
819                Ok(output) => output.status.success(),
820                Err(_) => {
821                    // Fallback to wget if curl is not available
822                    let wget_result = Command::new("wget")
823                        .args([
824                            "-q", // Quiet mode
825                            "-O",
826                            output_path_str,
827                            &url,
828                        ])
829                        .output();
830
831                    match wget_result {
832                        Ok(output) => output.status.success(),
833                        Err(_) => false,
834                    }
835                },
836            };
837
838            if success {
839                successful_downloads += 1;
840                println!("Downloaded {}", file);
841            } else {
842                eprintln!(
843                    "Failed to download {} (this may be normal if the file doesn't exist)",
844                    file
845                );
846            }
847        }
848
849        if successful_downloads == 0 {
850            return Err(trustformers_core::errors::TrustformersError::io_error(
851                "Failed to download any files from HuggingFace Hub. Please check the model name and your internet connection.".to_string()
852            ));
853        }
854
855        println!(
856            "Successfully downloaded {}/{} files for Linformer model",
857            successful_downloads,
858            essential_files.len()
859        );
860        Ok(())
861    }
862}
863
/// Linformer for sequence classification
///
/// Wraps a base [`LinformerModel`] encoder and adds a linear classification
/// head that maps the hidden state of the first ([CLS]) token to
/// `num_labels` logits (see the `Model::forward` impl below).
pub struct LinformerForSequenceClassification {
    // Base Linformer encoder producing per-token hidden states.
    linformer: LinformerModel,
    // Linear head projecting hidden_size -> num_labels.
    classifier: Linear,
    // Number of target classes; stored for reference, only used at construction.
    #[allow(dead_code)]
    num_labels: usize,
    // Device the model's parameters were created on.
    device: Device,
}
872
873impl LinformerForSequenceClassification {
874    pub fn new(config: LinformerConfig, num_labels: usize) -> Result<Self> {
875        Self::new_with_device(config, num_labels, Device::CPU)
876    }
877
878    pub fn new_with_device(
879        config: LinformerConfig,
880        num_labels: usize,
881        device: Device,
882    ) -> Result<Self> {
883        let linformer = LinformerModel::new_with_device(config.clone(), device)?;
884        let classifier = Linear::new_with_device(config.hidden_size, num_labels, true, device);
885
886        Ok(Self {
887            linformer,
888            classifier,
889            num_labels,
890            device,
891        })
892    }
893
894    pub fn device(&self) -> Device {
895        self.device
896    }
897}
898
899impl Model for LinformerForSequenceClassification {
900    type Config = LinformerConfig;
901    type Input = (Vec<u32>, Option<Vec<u32>>, Option<Vec<u32>>);
902    type Output = Tensor;
903
904    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
905        let sequence_output = self.linformer.forward(input)?;
906
907        // Use [CLS] token (first token) for classification
908        // Extract the first token (CLS token) from the sequence output
909        let cls_output = match &sequence_output {
910            Tensor::F32(arr) => {
911                let shape = arr.shape();
912                if shape.len() >= 3 {
913                    // Shape: [batch_size, seq_len, hidden_size]
914                    // Extract first token: [batch_size, 1, hidden_size] -> [batch_size, hidden_size]
915                    let batch_size = shape[0];
916                    let hidden_size = shape[2];
917
918                    let arr_slice = arr.as_slice().ok_or_else(|| {
919                        TrustformersError::tensor_op_error(
920                            "extract_cls_embeddings",
921                            "Tensor is not contiguous in memory",
922                        )
923                    })?;
924
925                    let mut cls_data = Vec::with_capacity(batch_size * hidden_size);
926                    for b in 0..batch_size {
927                        for h in 0..hidden_size {
928                            // Take first token (index 0) for each batch
929                            let idx = (b * shape[1]) * hidden_size + h;
930                            cls_data.push(arr_slice[idx]);
931                        }
932                    }
933
934                    let cls_array =
935                        ArrayD::from_shape_vec(IxDyn(&[batch_size, hidden_size]), cls_data)
936                            .map_err(|_| {
937                                trustformers_core::errors::TrustformersError::shape_error(
938                                    "Failed to create CLS token tensor".to_string(),
939                                )
940                            })?;
941
942                    Tensor::F32(cls_array)
943                } else {
944                    sequence_output.clone()
945                }
946            },
947            _ => sequence_output.clone(),
948        };
949
950        self.classifier.forward(cls_output)
951    }
952
953    fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
954        self.linformer.load_pretrained(reader)
955    }
956
957    fn get_config(&self) -> &Self::Config {
958        self.linformer.get_config()
959    }
960
961    fn num_parameters(&self) -> usize {
962        self.linformer.num_parameters() + self.classifier.parameter_count()
963    }
964}
965
/// Linformer for masked language modeling
///
/// Wraps a base [`LinformerModel`] encoder and adds a linear MLM head that
/// projects every token's hidden state to vocabulary logits.
pub struct LinformerForMaskedLM {
    // Base Linformer encoder producing per-token hidden states.
    linformer: LinformerModel,
    // Linear head projecting hidden_size -> vocab_size.
    mlm_head: Linear,
    // Device the model's parameters were created on.
    device: Device,
}
972
973impl LinformerForMaskedLM {
974    pub fn new(config: LinformerConfig) -> Result<Self> {
975        Self::new_with_device(config, Device::CPU)
976    }
977
978    pub fn new_with_device(config: LinformerConfig, device: Device) -> Result<Self> {
979        let linformer = LinformerModel::new_with_device(config.clone(), device)?;
980        let mlm_head = Linear::new_with_device(config.hidden_size, config.vocab_size, true, device);
981
982        Ok(Self {
983            linformer,
984            mlm_head,
985            device,
986        })
987    }
988
989    pub fn device(&self) -> Device {
990        self.device
991    }
992}
993
994impl Model for LinformerForMaskedLM {
995    type Config = LinformerConfig;
996    type Input = (Vec<u32>, Option<Vec<u32>>, Option<Vec<u32>>);
997    type Output = Tensor;
998
999    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
1000        let sequence_output = self.linformer.forward(input)?;
1001        self.mlm_head.forward(sequence_output)
1002    }
1003
1004    fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
1005        self.linformer.load_pretrained(reader)
1006    }
1007
1008    fn get_config(&self) -> &Self::Config {
1009        self.linformer.get_config()
1010    }
1011
1012    fn num_parameters(&self) -> usize {
1013        self.linformer.num_parameters() + self.mlm_head.parameter_count()
1014    }
1015}