Skip to main content

trustformers_models/falcon/
model.rs

1use crate::falcon::config::FalconConfig;
2use scirs2_core::ndarray::{s, ArrayD, IxDyn}; // SciRS2 Integration Policy
3use std::io::Read;
4use trustformers_core::{
5    device::Device,
6    errors::{tensor_op_error, Result, TrustformersError},
7    layers::{Embedding, LayerNorm, Linear},
8    ops::activations::{gelu, silu},
9    tensor::Tensor,
10    traits::{Config, Layer, Model},
11};
12
/// ALiBi positional encoding implementation
/// Attention with Linear Biases (Press et al., 2022)
pub struct ALiBi {
    /// Per-head slope values, one entry per attention head.
    slopes: Tensor,
    /// Number of attention heads the slopes were generated for.
    num_heads: usize,
    /// Device this module is associated with.
    device: Device,
}
20
21impl ALiBi {
22    pub fn new(num_heads: usize) -> Result<Self> {
23        Self::new_with_device(num_heads, Device::CPU)
24    }
25
26    pub fn new_with_device(num_heads: usize, device: Device) -> Result<Self> {
27        // Calculate slopes based on the geometric sequence pattern
28        let mut slopes = Vec::new();
29        let ratio = 2.0_f32.powf(-8.0 / num_heads as f32);
30
31        if num_heads % 2 == 0 {
32            // Even number of heads
33            for i in 0..num_heads / 2 {
34                slopes.push(ratio.powf((2 * i + 1) as f32));
35            }
36            for i in 0..num_heads / 2 {
37                slopes.push(ratio.powf((2 * i + 2) as f32));
38            }
39        } else {
40            // Odd number of heads
41            for i in 0..num_heads {
42                slopes.push(ratio.powf((i + 1) as f32));
43            }
44        }
45
46        let slopes_tensor = Tensor::new(slopes)?;
47
48        Ok(Self {
49            slopes: slopes_tensor,
50            num_heads,
51            device,
52        })
53    }
54
55    pub fn device(&self) -> Device {
56        self.device
57    }
58
59    /// Apply ALiBi bias to attention scores
60    pub fn apply_bias(&self, attention_scores: &Tensor, seq_len: usize) -> Result<Tensor> {
61        // Create position bias matrix for causal attention
62        let mut bias_data = Vec::new();
63
64        // Create bias for each head
65        for head_idx in 0..self.num_heads {
66            for i in 0..seq_len {
67                for j in 0..seq_len {
68                    if j > i {
69                        // Future positions get large negative bias (causal mask)
70                        bias_data.push(-10000.0);
71                    } else {
72                        // Past positions get linear bias scaled by head-specific slope
73                        let distance = (i - j) as f32;
74                        let slope = if let Ok(slopes_data) = self.slopes.data() {
75                            if head_idx < slopes_data.len() {
76                                slopes_data[head_idx]
77                            } else {
78                                1.0
79                            }
80                        } else {
81                            1.0
82                        };
83                        bias_data.push(-distance * slope);
84                    }
85                }
86            }
87        }
88
89        // Create bias tensor with proper shape for broadcasting
90        let bias_tensor = Tensor::from_vec(bias_data, &[seq_len, seq_len])?;
91
92        // Add bias to attention scores with proper broadcasting
93        let biased_scores = attention_scores.add(&bias_tensor)?;
94        Ok(biased_scores)
95    }
96}
97
/// Falcon attention layer with multi-query attention and optional ALiBi
pub struct FalconAttention {
    /// Query projection: hidden_size -> num_heads * head_dim.
    q_proj: Linear,
    /// Key projection: hidden_size -> num_kv_heads * head_dim.
    k_proj: Linear,
    /// Value projection: hidden_size -> num_kv_heads * head_dim.
    v_proj: Linear,
    /// Output projection back to hidden_size.
    dense: Linear,
    /// Optional ALiBi positional bias (enabled via `config.alibi`).
    alibi: Option<ALiBi>,
    /// Number of query heads.
    num_heads: usize,
    /// Number of key/value heads (smaller than `num_heads` for multi-query attention).
    num_kv_heads: usize,
    /// Dimension of each attention head.
    head_dim: usize,
    /// Dropout rate on attention weights (stored but not yet applied).
    #[allow(dead_code)]
    attention_dropout: f32,
    /// Flash-attention flag from the config (stored but not yet applied).
    #[allow(dead_code)]
    use_flash_attention: bool,
    /// Device this module is associated with.
    device: Device,
    // Note: Multi-query attention is implemented through num_kv_heads parameter
    // Future enhancement: could add dedicated MultiQueryAttention component when needed
}
116
117impl FalconAttention {
118    pub fn new(config: &FalconConfig) -> Result<Self> {
119        Self::new_with_device(config, Device::CPU)
120    }
121
122    pub fn new_with_device(config: &FalconConfig, device: Device) -> Result<Self> {
123        let head_dim = config.head_dim();
124        let num_kv_heads = config.num_kv_heads();
125
126        let q_proj = Linear::new(
127            config.hidden_size,
128            config.num_attention_heads * head_dim,
129            config.bias,
130        );
131        let k_proj = Linear::new(config.hidden_size, num_kv_heads * head_dim, config.bias);
132        let v_proj = Linear::new(config.hidden_size, num_kv_heads * head_dim, config.bias);
133        let dense = Linear::new(
134            config.num_attention_heads * head_dim,
135            config.hidden_size,
136            config.bias,
137        );
138
139        let alibi = if config.alibi {
140            Some(ALiBi::new_with_device(config.num_attention_heads, device)?)
141        } else {
142            None
143        };
144
145        Ok(Self {
146            q_proj,
147            k_proj,
148            v_proj,
149            dense,
150            alibi,
151            num_heads: config.num_attention_heads,
152            num_kv_heads,
153            head_dim,
154            attention_dropout: config.attention_dropout,
155            use_flash_attention: config.use_flash_attention.unwrap_or(false),
156            device,
157        })
158    }
159
160    pub fn device(&self) -> Device {
161        self.device
162    }
163
164    /// Create causal mask for autoregressive attention
165    fn create_causal_mask(&self, seq_len: usize) -> Result<Tensor> {
166        // Create lower triangular mask filled with 0s and -inf
167        let mut mask_data = vec![0.0f32; seq_len * seq_len];
168        for i in 0..seq_len {
169            for j in (i + 1)..seq_len {
170                mask_data[i * seq_len + j] = f32::NEG_INFINITY;
171            }
172        }
173        Tensor::from_vec(mask_data, &[seq_len, seq_len])
174    }
175
176    pub fn parameter_count(&self) -> usize {
177        self.q_proj.parameter_count()
178            + self.k_proj.parameter_count()
179            + self.v_proj.parameter_count()
180            + self.dense.parameter_count()
181    }
182}
183
184impl Layer for FalconAttention {
185    type Input = Tensor;
186    type Output = Tensor;
187
188    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
189        let batch_size = input.shape()[0];
190        let seq_len = input.shape()[1];
191
192        // Project to query, key, value
193        let q = self.q_proj.forward(input.clone())?;
194        let k = self.k_proj.forward(input.clone())?;
195        let v = self.v_proj.forward(input)?;
196
197        // Implement proper multi-query attention
198        // Reshape q, k, v for multi-head attention
199        let q = q.reshape(&[batch_size, seq_len, self.num_heads, self.head_dim])?;
200        let k = k.reshape(&[batch_size, seq_len, self.num_kv_heads, self.head_dim])?;
201        let v = v.reshape(&[batch_size, seq_len, self.num_kv_heads, self.head_dim])?;
202
203        // Transpose to [batch, num_heads, seq_len, head_dim]
204        let q = q.transpose(1, 2)?;
205        let k = k.transpose(1, 2)?;
206        let v = v.transpose(1, 2)?;
207
208        // For multi-query attention, repeat k and v heads to match query heads
209        let (k, v) = if self.num_kv_heads < self.num_heads {
210            let repeats = self.num_heads / self.num_kv_heads;
211
212            // Manually repeat each kv head 'repeats' times
213            let mut k_heads = Vec::new();
214            let mut v_heads = Vec::new();
215
216            for head_idx in 0..self.num_kv_heads {
217                // Extract single head: [batch, 1, seq_len, head_dim]
218                let k_head = k.slice_multi(&[
219                    (0, batch_size),
220                    (head_idx, head_idx + 1),
221                    (0, seq_len),
222                    (0, self.head_dim),
223                ])?;
224                let v_head = v.slice_multi(&[
225                    (0, batch_size),
226                    (head_idx, head_idx + 1),
227                    (0, seq_len),
228                    (0, self.head_dim),
229                ])?;
230
231                // Repeat this head 'repeats' times
232                for _ in 0..repeats {
233                    k_heads.push(k_head.clone());
234                    v_heads.push(v_head.clone());
235                }
236            }
237
238            // Concatenate all repeated heads
239            let k_repeated = Tensor::concat(&k_heads, 1)?;
240            let v_repeated = Tensor::concat(&v_heads, 1)?;
241            (k_repeated, v_repeated)
242        } else {
243            (k, v)
244        };
245
246        // Compute attention scores: Q @ K.T / sqrt(d_k)
247        // Transpose last two dimensions: [batch, num_heads, seq_len, head_dim] -> [batch, num_heads, head_dim, seq_len]
248        let k_transposed = k.transpose(2, 3)?;
249        let scores = q.matmul(&k_transposed)?;
250        let scale = (self.head_dim as f32).sqrt();
251        let scaled_scores = scores.div_scalar(scale)?;
252
253        // Apply causal mask
254        let causal_mask = self.create_causal_mask(seq_len)?;
255        let masked_scores = scaled_scores.add(&causal_mask)?;
256
257        // Apply softmax
258        let attention_weights = masked_scores.softmax(-1)?;
259
260        // Apply attention to values
261        let attention_output = attention_weights.matmul(&v)?;
262
263        // Transpose back and reshape
264        let attention_output = attention_output.transpose(1, 2)?;
265        let attention_output =
266            attention_output.reshape(&[batch_size, seq_len, self.num_heads * self.head_dim])?;
267
268        // Apply ALiBi bias if enabled
269        let biased_output = if let Some(alibi) = &self.alibi {
270            alibi.apply_bias(&attention_output, seq_len)?
271        } else {
272            attention_output
273        };
274
275        // Final output projection
276        let output = self.dense.forward(biased_output)?;
277        Ok(output)
278    }
279}
280
/// Falcon MLP layer
pub struct FalconMLP {
    /// Up-projection: hidden_size -> 4 * hidden_size.
    dense_h_to_4h: Linear,
    /// Down-projection: 4 * hidden_size -> hidden_size.
    dense_4h_to_h: Linear,
    /// Activation name from the config (e.g. "gelu", "relu", "silu"/"swish").
    activation: String,
    /// Device this module is associated with.
    device: Device,
}
288
289impl FalconMLP {
290    pub fn new(config: &FalconConfig) -> Result<Self> {
291        Self::new_with_device(config, Device::CPU)
292    }
293
294    pub fn new_with_device(config: &FalconConfig, device: Device) -> Result<Self> {
295        let intermediate_size = 4 * config.hidden_size;
296
297        let dense_h_to_4h = Linear::new(config.hidden_size, intermediate_size, config.bias);
298        let dense_4h_to_h = Linear::new(intermediate_size, config.hidden_size, config.bias);
299
300        Ok(Self {
301            dense_h_to_4h,
302            dense_4h_to_h,
303            activation: config.hidden_act.clone(),
304            device,
305        })
306    }
307
308    pub fn device(&self) -> Device {
309        self.device
310    }
311
312    pub fn parameter_count(&self) -> usize {
313        self.dense_h_to_4h.parameter_count() + self.dense_4h_to_h.parameter_count()
314    }
315}
316
317impl Layer for FalconMLP {
318    type Input = Tensor;
319    type Output = Tensor;
320
321    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
322        let hidden = self.dense_h_to_4h.forward(input)?;
323
324        // Apply activation function
325        let activated = match self.activation.as_str() {
326            "gelu" => gelu(&hidden)?,
327            "relu" => hidden.relu()?,
328            "silu" | "swish" => silu(&hidden)?,
329            _ => hidden,
330        };
331
332        let output = self.dense_4h_to_h.forward(activated)?;
333        Ok(output)
334    }
335}
336
/// Falcon decoder layer
pub struct FalconDecoderLayer {
    /// Layer norm applied before attention (also reused before the MLP in the
    /// sequential path — see `forward`).
    input_layernorm: LayerNorm,
    /// Self-attention sub-layer.
    self_attention: FalconAttention,
    /// Feed-forward sub-layer.
    mlp: FalconMLP,
    /// Run attention and MLP in parallel off one layer norm (Falcon style).
    parallel_attn: bool,
    /// Use the layer-norm output (instead of the raw input) as the residual base.
    apply_residual_connection_post_layernorm: bool,
    /// Device this module is associated with.
    device: Device,
}
346
347impl FalconDecoderLayer {
348    pub fn new(config: &FalconConfig) -> Result<Self> {
349        Self::new_with_device(config, Device::CPU)
350    }
351
352    pub fn new_with_device(config: &FalconConfig, device: Device) -> Result<Self> {
353        let input_layernorm = LayerNorm::new(vec![config.hidden_size], config.layer_norm_epsilon)?;
354        let self_attention = FalconAttention::new_with_device(config, device)?;
355        let mlp = FalconMLP::new_with_device(config, device)?;
356
357        Ok(Self {
358            input_layernorm,
359            self_attention,
360            mlp,
361            parallel_attn: config.parallel_attn,
362            apply_residual_connection_post_layernorm: config
363                .apply_residual_connection_post_layernorm,
364            device,
365        })
366    }
367
368    pub fn device(&self) -> Device {
369        self.device
370    }
371
372    pub fn parameter_count(&self) -> usize {
373        self.input_layernorm.parameter_count()
374            + self.self_attention.parameter_count()
375            + self.mlp.parameter_count()
376    }
377}
378
impl Layer for FalconDecoderLayer {
    type Input = Tensor;
    type Output = Tensor;

    /// Run one decoder block.
    ///
    /// With `parallel_attn` (Falcon's design) attention and MLP both read the
    /// same layer-norm output and are summed onto a single residual stream;
    /// otherwise the standard sequential attention -> MLP path is used.
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        if self.parallel_attn {
            // Parallel attention and MLP computation (Falcon's innovation)
            let layernorm_output = self.input_layernorm.forward(input.clone())?;

            // Compute attention and MLP in parallel
            let attention_output = self.self_attention.forward(layernorm_output.clone())?;
            let mlp_output = self.mlp.forward(layernorm_output.clone())?;

            // Choose the residual base: post-layernorm activations or raw input.
            let residual_input = if self.apply_residual_connection_post_layernorm {
                layernorm_output
            } else {
                input
            };

            // Add both outputs to input (residual connections)
            let output = residual_input.add(&attention_output)?.add(&mlp_output)?;
            Ok(output)
        } else {
            // Sequential attention -> MLP (standard transformer)
            let layernorm_output = self.input_layernorm.forward(input.clone())?;
            let attention_output = self.self_attention.forward(layernorm_output)?;

            // Add residual connection
            let residual_output = input.add(&attention_output)?;

            // NOTE(review): this reuses `input_layernorm` (same weights) as the
            // post-attention norm; HF Falcon checkpoints have a separate
            // `post_attention_layernorm` — confirm this weight sharing is intended.
            let layernorm_output2 = self.input_layernorm.forward(residual_output.clone())?;
            let mlp_output = self.mlp.forward(layernorm_output2)?;

            // Add residual connection
            let output = residual_output.add(&mlp_output)?;
            Ok(output)
        }
    }
}
419
/// Falcon transformer model
pub struct FalconModel {
    /// Token-ID -> hidden-state embedding table.
    word_embeddings: Embedding,
    /// Stack of decoder layers, applied in order.
    layers: Vec<FalconDecoderLayer>,
    /// Final layer norm applied after the last decoder layer.
    ln_f: LayerNorm,
    /// Configuration this model was built from.
    config: FalconConfig,
    /// Device this model is associated with.
    device: Device,
}
428
429impl FalconModel {
430    pub fn new(config: FalconConfig) -> Result<Self> {
431        Self::new_with_device(config, Device::CPU)
432    }
433
434    pub fn new_with_device(config: FalconConfig, device: Device) -> Result<Self> {
435        config.validate()?;
436
437        let word_embeddings = Embedding::new(
438            config.vocab_size,
439            config.hidden_size,
440            config.pad_token_id.map(|id| id as usize),
441        )?;
442
443        let mut layers = Vec::new();
444        for _ in 0..config.num_hidden_layers {
445            layers.push(FalconDecoderLayer::new_with_device(&config, device)?);
446        }
447
448        let ln_f = LayerNorm::new(vec![config.hidden_size], config.layer_norm_epsilon)?;
449
450        Ok(Self {
451            word_embeddings,
452            layers,
453            ln_f,
454            config,
455            device,
456        })
457    }
458
459    pub fn device(&self) -> Device {
460        self.device
461    }
462
463    pub fn config(&self) -> &FalconConfig {
464        &self.config
465    }
466}
467
468impl Model for FalconModel {
469    type Config = FalconConfig;
470    type Input = Tensor;
471    type Output = Tensor;
472
473    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
474        Layer::forward(self, input)
475    }
476
477    fn load_pretrained(&mut self, _reader: &mut dyn Read) -> Result<()> {
478        // Legacy interface - use enhanced weight loading methods for production
479        Err(TrustformersError::not_implemented(
480            "Use load_from_path or load_from_huggingface for enhanced weight loading".to_string(),
481        ))
482    }
483
484    fn get_config(&self) -> &Self::Config {
485        &self.config
486    }
487
488    fn num_parameters(&self) -> usize {
489        let embeddings_params = self.word_embeddings.parameter_count();
490        let layers_params: usize = self.layers.iter().map(|layer| layer.parameter_count()).sum();
491        let norm_params = self.ln_f.parameter_count();
492
493        embeddings_params + layers_params + norm_params
494    }
495}
496
497impl Layer for FalconModel {
498    type Input = Tensor;
499    type Output = Tensor;
500
501    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
502        // Convert input tensor to token IDs
503        let token_ids = match &input {
504            Tensor::F32(arr) => {
505                // Convert F32 tensor to u32 token IDs
506                arr.iter().map(|&x| x as u32).collect::<Vec<u32>>()
507            },
508            _ => {
509                return Err(tensor_op_error(
510                    "tensor_operation",
511                    "Input must be F32 tensor",
512                ))
513            },
514        };
515
516        if token_ids.is_empty() {
517            return Err(TrustformersError::model_error(
518                "Empty token_ids provided".to_string(),
519            ));
520        }
521
522        let mut hidden_states = self.word_embeddings.forward(token_ids)?;
523
524        // Pass through transformer layers
525        for layer in &self.layers {
526            hidden_states = layer.forward(hidden_states)?;
527        }
528
529        // Final layer norm
530        let output = self.ln_f.forward(hidden_states)?;
531        Ok(output)
532    }
533}
534
/// Falcon model for causal language modeling
pub struct FalconForCausalLM {
    /// Underlying Falcon transformer backbone.
    transformer: FalconModel,
    /// Projection from hidden states to vocabulary logits (no bias).
    lm_head: Linear,
    /// Device this model is associated with.
    device: Device,
}
541
542impl FalconForCausalLM {
543    pub fn new(config: FalconConfig) -> Result<Self> {
544        Self::new_with_device(config, Device::CPU)
545    }
546
547    pub fn new_with_device(config: FalconConfig, device: Device) -> Result<Self> {
548        let transformer = FalconModel::new_with_device(config.clone(), device)?;
549        let lm_head = Linear::new(
550            config.hidden_size,
551            config.vocab_size,
552            false, // No bias in language modeling head
553        );
554
555        Ok(Self {
556            transformer,
557            lm_head,
558            device,
559        })
560    }
561
562    pub fn device(&self) -> Device {
563        self.device
564    }
565
566    /// Load model weights from a directory containing HuggingFace format weights
567    pub fn load_from_path(&mut self, model_path: impl AsRef<std::path::Path>) -> Result<()> {
568        use crate::weight_loading::{auto_create_loader, WeightLoadingConfig};
569
570        let config = WeightLoadingConfig {
571            lazy_loading: true,
572            memory_mapped: false,
573            ..Default::default()
574        };
575
576        let mut loader = auto_create_loader(model_path, Some(config))?;
577
578        // Load word embeddings
579        if let Ok(embed_weights) = loader.load_tensor("transformer.word_embeddings.weight") {
580            self.transformer.word_embeddings.set_weight(embed_weights)?;
581        }
582
583        // Load layer weights
584        for (i, layer) in self.transformer.layers.iter_mut().enumerate() {
585            // Load attention weights
586            let attn_prefix = format!("transformer.h.{}.self_attention", i);
587
588            if let Ok(qkv_weight) =
589                loader.load_tensor(&format!("{}.query_key_value.weight", attn_prefix))
590            {
591                // Falcon uses combined QKV projection - split into Q, K, V
592                match &qkv_weight {
593                    Tensor::F32(arr) => {
594                        let shape = arr.shape();
595                        let combined_size = shape[0];
596                        let _hidden_size = shape[1];
597
598                        // Assuming equal sizes for Q, K, V (though Falcon may use different ratios)
599                        let head_dim = combined_size / 3;
600
601                        // Split the combined weight tensor
602                        let q_slice = arr.slice(s![0..head_dim, ..]).to_owned();
603                        let k_slice = arr.slice(s![head_dim..2 * head_dim, ..]).to_owned();
604                        let v_slice = arr.slice(s![2 * head_dim..3 * head_dim, ..]).to_owned();
605
606                        // Convert to dynamic arrays and set individual weights
607                        let q_dyn = q_slice.into_dyn();
608                        let k_dyn = k_slice.into_dyn();
609                        let v_dyn = v_slice.into_dyn();
610
611                        layer.self_attention.q_proj.set_weight(Tensor::F32(q_dyn))?;
612                        layer.self_attention.k_proj.set_weight(Tensor::F32(k_dyn))?;
613                        layer.self_attention.v_proj.set_weight(Tensor::F32(v_dyn))?;
614                    },
615                    _ => {
616                        // Fallback: use the same weight for all (not ideal but better than crashing)
617                        layer.self_attention.q_proj.set_weight(qkv_weight.clone())?;
618                    },
619                }
620            }
621            if let Ok(o_weight) = loader.load_tensor(&format!("{}.dense.weight", attn_prefix)) {
622                layer.self_attention.dense.set_weight(o_weight)?;
623            }
624
625            // Load MLP weights
626            let mlp_prefix = format!("transformer.h.{}.mlp", i);
627
628            if let Ok(up_weight) =
629                loader.load_tensor(&format!("{}.dense_h_to_4h.weight", mlp_prefix))
630            {
631                layer.mlp.dense_h_to_4h.set_weight(up_weight)?;
632            }
633            if let Ok(down_weight) =
634                loader.load_tensor(&format!("{}.dense_4h_to_h.weight", mlp_prefix))
635            {
636                layer.mlp.dense_4h_to_h.set_weight(down_weight)?;
637            }
638
639            // Load layer norm weights
640            if let Ok(ln_weight) =
641                loader.load_tensor(&format!("transformer.h.{}.input_layernorm.weight", i))
642            {
643                layer.input_layernorm.set_weight(ln_weight)?;
644            }
645            if let Ok(ln_bias) =
646                loader.load_tensor(&format!("transformer.h.{}.input_layernorm.bias", i))
647            {
648                layer.input_layernorm.set_bias(ln_bias)?;
649            }
650        }
651
652        // Load final layer norm
653        if let Ok(norm_weight) = loader.load_tensor("transformer.ln_f.weight") {
654            self.transformer.ln_f.set_weight(norm_weight)?;
655        }
656        if let Ok(norm_bias) = loader.load_tensor("transformer.ln_f.bias") {
657            self.transformer.ln_f.set_bias(norm_bias)?;
658        }
659
660        // Load LM head weights
661        if let Ok(lm_head_weight) = loader.load_tensor("lm_head.weight") {
662            self.lm_head.set_weight(lm_head_weight)?;
663        }
664
665        Ok(())
666    }
667
668    /// Load from HuggingFace Hub model name
669    pub fn load_from_huggingface(&mut self, model_name: &str) -> Result<()> {
670        // Check if model is cached locally
671        let cache_dir = std::env::var("HF_HOME")
672            .or_else(|_| std::env::var("HUGGINGFACE_HUB_CACHE"))
673            .unwrap_or_else(|_| {
674                std::env::var("HOME").unwrap_or_else(|_| ".".to_string())
675                    + "/.cache/huggingface/hub"
676            });
677
678        let model_path = std::path::Path::new(&cache_dir)
679            .join(format!("models--{}", model_name.replace("/", "--")));
680
681        if model_path.exists() {
682            self.load_from_path(&model_path)
683        } else {
684            // Attempt to download the model from HuggingFace Hub
685            self.download_from_huggingface_hub(model_name, &model_path)?;
686            self.load_from_path(&model_path)
687        }
688    }
689
690    /// Download model from HuggingFace Hub
691    fn download_from_huggingface_hub(
692        &self,
693        model_name: &str,
694        model_path: &std::path::Path,
695    ) -> Result<()> {
696        use std::process::Command;
697
698        println!(
699            "Downloading model {} from HuggingFace Hub to {:?}",
700            model_name, model_path
701        );
702
703        // Create the model directory
704        std::fs::create_dir_all(model_path).map_err(|e| {
705            TrustformersError::io_error(format!("Failed to create model directory: {}", e))
706        })?;
707
708        // List of essential files for Falcon models
709        let essential_files = vec![
710            "config.json",
711            "tokenizer.json",
712            "tokenizer_config.json",
713            "pytorch_model.bin", // Try .bin first
714            "model.safetensors", // Fall back to safetensors
715        ];
716
717        let base_url = format!("https://huggingface.co/{}/resolve/main", model_name);
718
719        // Try to download each essential file
720        for file_name in &essential_files {
721            let file_url = format!("{}/{}", base_url, file_name);
722            let file_path = model_path.join(file_name);
723
724            println!("Attempting to download {}", file_url);
725
726            // Try using curl first
727            let curl_result = Command::new("curl")
728                .args([
729                    "-L", // Follow redirects
730                    "-f", // Fail on HTTP errors
731                    "-o",
732                    file_path.to_str().expect("operation failed"),
733                    &file_url,
734                ])
735                .output();
736
737            match curl_result {
738                Ok(output) if output.status.success() => {
739                    println!("Successfully downloaded {}", file_name);
740                    continue;
741                },
742                Ok(output) => {
743                    eprintln!(
744                        "Failed to download {} with curl: {}",
745                        file_name,
746                        String::from_utf8_lossy(&output.stderr)
747                    );
748                },
749                Err(e) => {
750                    println!("curl not available: {}", e);
751                },
752            }
753
754            // Try using wget as fallback
755            let wget_result = Command::new("wget")
756                .args([
757                    "-O",
758                    file_path.to_str().expect("operation failed"),
759                    &file_url,
760                ])
761                .output();
762
763            match wget_result {
764                Ok(output) if output.status.success() => {
765                    println!("Successfully downloaded {} with wget", file_name);
766                    continue;
767                },
768                Ok(output) => {
769                    eprintln!(
770                        "Failed to download {} with wget: {}",
771                        file_name,
772                        String::from_utf8_lossy(&output.stderr)
773                    );
774                },
775                Err(e) => {
776                    println!("wget not available: {}", e);
777                },
778            }
779
780            // If essential files like config.json or pytorch_model.bin fail, return error
781            if matches!(file_name, &"config.json" | &"pytorch_model.bin") {
782                return Err(TrustformersError::io_error(format!(
783                    "Failed to download essential file {} for model {}. Please ensure curl or wget is installed and you have internet access.",
784                    file_name, model_name
785                )));
786            }
787        }
788
789        println!(
790            "Successfully downloaded model {} from HuggingFace Hub",
791            model_name
792        );
793        Ok(())
794    }
795
796    /// Legacy method name for backward compatibility
797    pub fn load_from_hub(&mut self, model_name: &str) -> Result<()> {
798        self.load_from_huggingface(model_name)
799    }
800
801    /// Generate text using the model
802    pub fn generate(&self, input_ids: Tensor, max_length: usize) -> Result<Tensor> {
803        let mut current_ids = input_ids;
804        let current_length = current_ids.shape()[current_ids.shape().len() - 1];
805
806        // Autoregressive generation
807        for _ in current_length..max_length {
808            // Forward pass through the model
809            let logits = <Self as Model>::forward(self, current_ids.clone())?;
810
811            // Get the last token logits
812            let last_logits = match &logits {
813                Tensor::F32(arr) => {
814                    let shape = arr.shape();
815                    let seq_len = shape[shape.len() - 2];
816                    let _vocab_size = shape[shape.len() - 1];
817
818                    // Extract last token logits
819                    let last_token_slice = if shape.len() == 3 {
820                        arr.slice(s![0, seq_len - 1, ..])
821                    } else {
822                        arr.slice(s![seq_len - 1, ..])
823                    };
824                    last_token_slice.to_owned()
825                },
826                _ => {
827                    return Err(tensor_op_error(
828                        "tensor_operation",
829                        "Logits must be F32 tensor",
830                    ))
831                },
832            };
833
834            // Greedy decoding: select token with highest probability
835            let next_token_id = last_logits
836                .iter()
837                .enumerate()
838                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
839                .map(|(idx, _)| idx as u32)
840                .ok_or_else(|| {
841                    TrustformersError::model_error("Failed to find next token".to_string())
842                })?;
843
844            // Check for EOS token (commonly ID 2 for Falcon models)
845            if next_token_id == 2 {
846                break;
847            }
848
849            // Append next token to sequence
850            current_ids = match &current_ids {
851                Tensor::F32(arr) => {
852                    // Convert token ID to f32 tensor and concatenate
853                    let mut new_shape = arr.shape().to_vec();
854                    let last_idx = new_shape.len() - 1;
855                    new_shape[last_idx] += 1;
856
857                    let mut new_arr = ArrayD::<f32>::zeros(IxDyn(&new_shape));
858
859                    // Copy existing data
860                    if arr.ndim() == 2 {
861                        for i in 0..arr.shape()[0] {
862                            for j in 0..arr.shape()[1] {
863                                new_arr[[i, j]] = arr[[i, j]];
864                            }
865                            new_arr[[i, arr.shape()[1]]] = next_token_id as f32;
866                        }
867                    } else if arr.ndim() == 1 {
868                        for i in 0..arr.shape()[0] {
869                            new_arr[[i]] = arr[[i]];
870                        }
871                        new_arr[[arr.shape()[0]]] = next_token_id as f32;
872                    }
873
874                    Tensor::F32(new_arr)
875                },
876                _ => {
877                    return Err(tensor_op_error(
878                        "tensor_operation",
879                        "Input must be F32 tensor",
880                    ))
881                },
882            };
883        }
884
885        Ok(current_ids)
886    }
887
    /// Returns a shared reference to the underlying transformer backbone
    /// (`self.transformer`), i.e. the Falcon model without the LM head.
    ///
    /// Useful when callers need hidden states or backbone metadata rather
    /// than vocabulary logits.
    pub fn model(&self) -> &FalconModel {
        &self.transformer
    }
891}
892
893impl Model for FalconForCausalLM {
894    type Config = FalconConfig;
895    type Input = Tensor;
896    type Output = Tensor;
897
898    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
899        Layer::forward(self, input)
900    }
901
902    fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
903        self.transformer.load_pretrained(reader)
904    }
905
906    fn get_config(&self) -> &Self::Config {
907        self.transformer.get_config()
908    }
909
910    fn num_parameters(&self) -> usize {
911        self.transformer.num_parameters() + self.lm_head.parameter_count()
912    }
913}
914
915impl Layer for FalconForCausalLM {
916    type Input = Tensor;
917    type Output = Tensor;
918
919    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
920        let hidden_states = Layer::forward(&self.transformer, input)?;
921        let logits = self.lm_head.forward(hidden_states)?;
922        Ok(logits)
923    }
924}
925
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[ignore] // Very heavy test - Falcon 7B model (SIGKILL risk), run with --ignored
    fn test_falcon_model_creation() {
        let config = FalconConfig::falcon_7b();
        let model = FalconModel::new(config);
        assert!(model.is_ok());
    }

    #[test]
    #[ignore] // Very heavy test - Falcon 7B CausalLM (SIGKILL risk), run with --ignored
    fn test_falcon_causal_lm_creation() {
        let config = FalconConfig::falcon_7b();
        let model = FalconForCausalLM::new(config);
        assert!(model.is_ok());
    }

    #[test]
    fn test_falcon_config_variants() {
        // 7B: original ALiBi-based architecture.
        let config_7b = FalconConfig::falcon_7b();
        assert_eq!(config_7b.hidden_size, 4544);
        assert_eq!(config_7b.num_hidden_layers, 32);
        assert!(config_7b.uses_alibi());

        // 40B: larger ALiBi variant.
        let config_40b = FalconConfig::falcon_40b();
        assert_eq!(config_40b.hidden_size, 8192);
        assert_eq!(config_40b.num_hidden_layers, 60);
        assert!(config_40b.uses_alibi());

        // 180B: new architecture, no ALiBi.
        let config_180b = FalconConfig::falcon_180b();
        assert_eq!(config_180b.hidden_size, 14848);
        assert_eq!(config_180b.num_hidden_layers, 80);
        assert!(!config_180b.uses_alibi());
        assert!(config_180b.uses_new_architecture());
    }

    #[test]
    fn test_alibi_creation() {
        let alibi = ALiBi::new(8);
        assert!(alibi.is_ok());

        // expect() message names the failing operation (was the generic
        // "operation failed") so a test failure is immediately actionable.
        let alibi = alibi.expect("ALiBi::new(8) should succeed");
        assert_eq!(alibi.num_heads, 8);
    }

    #[test]
    fn test_falcon_attention_creation() {
        let config = FalconConfig::falcon_7b();
        let attention = FalconAttention::new(&config);
        assert!(attention.is_ok());
    }

    #[test]
    fn test_falcon_mlp_creation() {
        let config = FalconConfig::falcon_7b();
        let mlp = FalconMLP::new(&config);
        assert!(mlp.is_ok());
    }
}