// trustformers_models/stablelm/model.rs
use crate::stablelm::config::StableLMConfig;
use scirs2_core::ndarray::{Array1, Array2, ArrayD, Axis, IxDyn}; // SciRS2 Integration Policy (Array2 in tests)
use trustformers_core::{
    device::Device,
    errors::{tensor_op_error, Result, TrustformersError},
    layers::{Embedding, Linear},
    ops::activations::{silu, swiglu},
    tensor::Tensor,
    traits::{Layer, Model},
};
11
/// Root Mean Square Layer Normalization
///
/// Scales each input vector by the reciprocal of its root mean square and
/// applies a learned per-channel gain.
pub struct RMSNorm {
    weight: Tensor, // learned gain, shape [hidden_size], ones-initialized
    eps: f32,       // numerical-stability term added to the mean square
    device: Device, // device the gain tensor lives on
}
18
19impl RMSNorm {
20    pub fn new(hidden_size: usize, eps: f32) -> Result<Self> {
21        Self::new_with_device(hidden_size, eps, Device::CPU)
22    }
23
24    pub fn new_with_device(hidden_size: usize, eps: f32, device: Device) -> Result<Self> {
25        let weight = Tensor::ones(&[hidden_size])?.to_device_enum(&device)?;
26        Ok(Self {
27            weight,
28            eps,
29            device,
30        })
31    }
32
33    pub fn device(&self) -> &Device {
34        &self.device
35    }
36
37    pub fn parameter_count(&self) -> usize {
38        self.weight.shape().iter().product()
39    }
40}
41
42impl Layer for RMSNorm {
43    type Input = Tensor;
44    type Output = Tensor;
45
46    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
47        match &input {
48            Tensor::F32(arr) => {
49                let mean_sq = arr.mapv(|x| x * x).mean().unwrap_or(0.0);
50                let rms = (mean_sq + self.eps).sqrt();
51                let normalized = arr.mapv(|x| x / rms);
52
53                match &self.weight {
54                    Tensor::F32(weight_arr) => {
55                        let result = &normalized * weight_arr;
56                        Ok(Tensor::F32(result))
57                    },
58                    _ => Err(tensor_op_error(
59                        "tensor_operation",
60                        "Unsupported tensor type".to_string(),
61                    )),
62                }
63            },
64            _ => Err(tensor_op_error(
65                "tensor_operation",
66                "Unsupported tensor type".to_string(),
67            )),
68        }
69    }
70}
71
/// Rotary Position Embeddings with partial rotary factor
///
/// Sin/cos tables are pre-computed at construction; `forward` applies them
/// to query/key tensors. Only a fraction of each head's channels is rotated.
pub struct RotaryEmbedding {
    sin_cached: Tensor, // [max_seq_len, rotary_dim / 2] sine table
    cos_cached: Tensor, // [max_seq_len, rotary_dim / 2] cosine table
    max_seq_len: usize, // number of pre-computed positions
    head_dim: usize,    // full per-head dimension (rotated part may be smaller)
    #[allow(dead_code)]
    base: f32, // RoPE frequency base (e.g. 10000.0); kept for reference
    partial_rotary_factor: f32, // fraction of head_dim that gets rotated
    device: Device,             // device the cached tables live on
}
83
impl RotaryEmbedding {
    /// CPU-resident constructor; see `new_with_device`.
    pub fn new(
        head_dim: usize,
        max_seq_len: usize,
        base: f32,
        partial_rotary_factor: f32,
    ) -> Result<Self> {
        Self::new_with_device(
            head_dim,
            max_seq_len,
            base,
            partial_rotary_factor,
            Device::CPU,
        )
    }

    /// Pre-computes sin/cos tables for positions `0..max_seq_len`.
    ///
    /// Only the first `head_dim * partial_rotary_factor` channels are
    /// rotated (StableLM-style partial RoPE); the cached tables have shape
    /// [max_seq_len, rotary_dim / 2], one column per rotated channel pair.
    pub fn new_with_device(
        head_dim: usize,
        max_seq_len: usize,
        base: f32,
        partial_rotary_factor: f32,
        device: Device,
    ) -> Result<Self> {
        let rotary_dim = ((head_dim as f32) * partial_rotary_factor) as usize;

        // Pre-compute sin and cos values.
        // inv_freq[j] = 1 / base^(2j / rotary_dim), one entry per channel pair.
        let inv_freq = Array1::range(0.0, rotary_dim as f32, 2.0)
            .mapv(|i| 1.0 / base.powf(i / rotary_dim as f32));

        // Outer product: position x inverse frequency -> angle table.
        let t = Array1::range(0.0, max_seq_len as f32, 1.0);
        let freqs = t.view().insert_axis(Axis(1)).dot(&inv_freq.view().insert_axis(Axis(0)));

        let sin_arr =
            Array2::from_shape_fn((max_seq_len, rotary_dim / 2), |(i, j)| freqs[[i, j]].sin());
        let cos_arr =
            Array2::from_shape_fn((max_seq_len, rotary_dim / 2), |(i, j)| freqs[[i, j]].cos());

        let sin_cached = Tensor::F32(sin_arr.into_dyn()).to_device_enum(&device)?;
        let cos_cached = Tensor::F32(cos_arr.into_dyn()).to_device_enum(&device)?;

        Ok(Self {
            sin_cached,
            cos_cached,
            max_seq_len,
            head_dim,
            base,
            partial_rotary_factor,
            device,
        })
    }

    /// Device the cached tables live on.
    pub fn device(&self) -> &Device {
        &self.device
    }

    /// Applies partial, interleaved RoPE to `q` and `k` in tandem.
    ///
    /// The indexing below assumes F32 tensors laid out as
    /// [batch, heads, seq, head_dim] — NOTE(review): layout inferred from the
    /// index arithmetic only; confirm against callers. Channel pairs
    /// (2j, 2j+1) of the first `rotary_dim` channels are rotated by a
    /// position-dependent angle; remaining channels pass through unchanged.
    /// If `seq_len` exceeds `max_seq_len`, inputs are returned un-rotated.
    pub fn forward(&self, q: &Tensor, k: &Tensor, seq_len: usize) -> Result<(Tensor, Tensor)> {
        let rotary_dim = ((self.head_dim as f32) * self.partial_rotary_factor) as usize;

        match (q, k, &self.sin_cached, &self.cos_cached) {
            (
                Tensor::F32(q_arr),
                Tensor::F32(k_arr),
                Tensor::F32(sin_arr),
                Tensor::F32(cos_arr),
            ) => {
                // Apply partial rotary embeddings on owned copies.
                let mut q_rot = q_arr.clone();
                let mut k_rot = k_arr.clone();

                // Only rotate the first rotary_dim dimensions.
                if rotary_dim > 0 && seq_len <= self.max_seq_len {
                    // Get shapes before mutable operations to avoid borrow checker issues.
                    let q_shape = q_rot.shape().to_vec();
                    let _k_shape = k_rot.shape().to_vec();

                    // Apply RoPE to query and key tensors.
                    for seq_idx in 0..seq_len {
                        for dim_idx in 0..(rotary_dim / 2) {
                            let cos_val = cos_arr[[seq_idx, dim_idx]];
                            let sin_val = sin_arr[[seq_idx, dim_idx]];

                            // Apply rotation: [x1, x2] -> [x1*cos - x2*sin, x1*sin + x2*cos]
                            // This is a simplified implementation for the core rotation logic.
                            for batch in 0..q_shape[0] {
                                for head in 0..q_shape[1] {
                                    if seq_idx < q_shape[2] && dim_idx < rotary_dim / 2 {
                                        // Interleaved pairing: channels (2j, 2j+1).
                                        let x1_idx = [batch, head, seq_idx, dim_idx * 2];
                                        let x2_idx = [batch, head, seq_idx, dim_idx * 2 + 1];

                                        // Bounds guard for odd/short last dims.
                                        if x2_idx[3] < q_shape[3] {
                                            let q_x1 = q_rot[x1_idx];
                                            let q_x2 = q_rot[x2_idx];
                                            let k_x1 = k_rot[x1_idx];
                                            let k_x2 = k_rot[x2_idx];

                                            // Query rotation
                                            q_rot[x1_idx] = q_x1 * cos_val - q_x2 * sin_val;
                                            q_rot[x2_idx] = q_x1 * sin_val + q_x2 * cos_val;

                                            // Key rotation
                                            k_rot[x1_idx] = k_x1 * cos_val - k_x2 * sin_val;
                                            k_rot[x2_idx] = k_x1 * sin_val + k_x2 * cos_val;
                                        }
                                    }
                                }
                            }
                        }
                    }
                }

                Ok((Tensor::F32(q_rot), Tensor::F32(k_rot)))
            },
            _ => Err(tensor_op_error(
                "tensor_operation",
                "Unsupported tensor type".to_string(),
            )),
        }
    }
}
203
/// Multi-Head Attention with optional grouped-query attention
pub struct StableLMAttention {
    #[allow(dead_code)]
    config: StableLMConfig, // retained for future use
    q_proj: Linear,         // hidden -> num_heads * head_dim
    k_proj: Linear,         // hidden -> num_kv_heads * head_dim
    v_proj: Linear,         // hidden -> num_kv_heads * head_dim
    o_proj: Linear,         // output projection back to hidden
    rotary_emb: RotaryEmbedding, // partial RoPE applied to Q/K
    #[allow(dead_code)]
    head_dim: usize, // hidden_size / num_attention_heads
    num_heads: usize,    // query heads
    num_kv_heads: usize, // key/value heads (< num_heads under GQA)
    device: Device,
}
219
220impl StableLMAttention {
221    pub fn new(config: &StableLMConfig) -> Result<Self> {
222        Self::new_with_device(config, Device::CPU)
223    }
224
225    pub fn new_with_device(config: &StableLMConfig, device: Device) -> Result<Self> {
226        let hidden_size = config.hidden_size;
227        let num_heads = config.num_attention_heads;
228        let num_kv_heads = config.num_key_value_heads.unwrap_or(num_heads);
229        let head_dim = hidden_size / num_heads;
230
231        let q_proj =
232            Linear::new_with_device(hidden_size, hidden_size, config.attention_bias, device);
233        let k_proj = Linear::new_with_device(
234            hidden_size,
235            num_kv_heads * head_dim,
236            config.attention_bias,
237            device,
238        );
239        let v_proj = Linear::new_with_device(
240            hidden_size,
241            num_kv_heads * head_dim,
242            config.attention_bias,
243            device,
244        );
245        let o_proj =
246            Linear::new_with_device(hidden_size, hidden_size, config.attention_bias, device);
247
248        let rotary_emb = RotaryEmbedding::new_with_device(
249            head_dim,
250            config.max_position_embeddings,
251            config.rope_theta,
252            config.partial_rotary_factor,
253            device,
254        )?;
255
256        Ok(Self {
257            config: config.clone(),
258            q_proj,
259            k_proj,
260            v_proj,
261            o_proj,
262            rotary_emb,
263            head_dim,
264            num_heads,
265            num_kv_heads,
266            device,
267        })
268    }
269
270    pub fn device(&self) -> &Device {
271        &self.device
272    }
273
274    fn repeat_kv(&self, hidden_states: &Tensor, n_rep: usize) -> Result<Tensor> {
275        if n_rep == 1 {
276            return Ok(hidden_states.clone());
277        }
278
279        match hidden_states {
280            Tensor::F32(arr) => {
281                // Repeat key/value heads for grouped-query attention
282                let _shape = arr.shape();
283                let mut repeated = arr.clone();
284
285                // Simplified - actual implementation would properly repeat along head dimension
286                for _ in 1..n_rep {
287                    repeated = repeated.clone(); // Placeholder
288                }
289
290                Ok(Tensor::F32(repeated))
291            },
292            _ => Err(tensor_op_error(
293                "tensor_operation",
294                "Unsupported tensor type".to_string(),
295            )),
296        }
297    }
298
299    pub fn parameter_count(&self) -> usize {
300        self.q_proj.parameter_count()
301            + self.k_proj.parameter_count()
302            + self.v_proj.parameter_count()
303            + self.o_proj.parameter_count()
304        // Note: rotary_emb typically doesn't have trainable parameters
305    }
306}
307
impl Layer for StableLMAttention {
    type Input = Tensor;
    type Output = Tensor;

    /// Simplified attention forward pass.
    ///
    /// NOTE(review): the score/softmax/value stage is still a placeholder —
    /// after the projections and RoPE, the rotated query is forwarded
    /// directly to the output projection; the repeated K/V tensors are only
    /// type-checked, never consumed. `seq_len` is hard-coded to 1.
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        let _batch_size = 1; // Simplified
        let seq_len = 1; // Simplified: hard-coded sequence length

        // Query, Key, Value projections
        let q = self.q_proj.forward(input.clone())?;
        let k = self.k_proj.forward(input.clone())?;
        let v = self.v_proj.forward(input)?;

        // Apply rotary embeddings
        let (q_rot, k_rot) = self.rotary_emb.forward(&q, &k, seq_len)?;

        // Repeat KV heads if using grouped-query attention
        let n_rep = self.num_heads / self.num_kv_heads;
        let k_repeated = self.repeat_kv(&k_rot, n_rep)?;
        let v_repeated = self.repeat_kv(&v, n_rep)?;

        // Compute attention scores
        // Simplified - actual implementation would compute proper attention
        let attn_output = match (&q_rot, &k_repeated, &v_repeated) {
            (Tensor::F32(q_arr), Tensor::F32(_k_arr), Tensor::F32(_v_arr)) => {
                // Placeholder for attention computation
                Tensor::F32(q_arr.clone())
            },
            _ => {
                return Err(tensor_op_error(
                    "tensor_operation",
                    "Unsupported tensor type".to_string(),
                ))
            },
        };

        // Output projection
        self.o_proj.forward(attn_output)
    }
}
348
/// MLP with SwiGLU activation
///
/// Gated feed-forward block: `down(act(gate(x)) * up(x))`.
pub struct StableLMMLP {
    config: StableLMConfig, // consulted at runtime for `hidden_act`
    gate_proj: Linear,      // hidden -> intermediate (gating branch)
    up_proj: Linear,        // hidden -> intermediate (value branch)
    down_proj: Linear,      // intermediate -> hidden
    device: Device,
}
357
358impl StableLMMLP {
359    pub fn new(config: &StableLMConfig) -> Self {
360        Self::new_with_device(config, Device::CPU)
361    }
362
363    pub fn new_with_device(config: &StableLMConfig, device: Device) -> Self {
364        let hidden_size = config.hidden_size;
365        let intermediate_size = config.intermediate_size;
366
367        Self {
368            config: config.clone(),
369            gate_proj: Linear::new_with_device(
370                hidden_size,
371                intermediate_size,
372                config.mlp_bias,
373                device,
374            ),
375            up_proj: Linear::new_with_device(
376                hidden_size,
377                intermediate_size,
378                config.mlp_bias,
379                device,
380            ),
381            down_proj: Linear::new_with_device(
382                intermediate_size,
383                hidden_size,
384                config.mlp_bias,
385                device,
386            ),
387            device,
388        }
389    }
390
391    pub fn device(&self) -> &Device {
392        &self.device
393    }
394
395    pub fn parameter_count(&self) -> usize {
396        self.gate_proj.parameter_count()
397            + self.up_proj.parameter_count()
398            + self.down_proj.parameter_count()
399    }
400}
401
402impl Layer for StableLMMLP {
403    type Input = Tensor;
404    type Output = Tensor;
405
406    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
407        let gate = self.gate_proj.forward(input.clone())?;
408        let up = self.up_proj.forward(input)?;
409
410        // Apply activation based on config
411        let activated = match self.config.hidden_act.as_str() {
412            "silu" => {
413                let gate_act = silu(&gate)?;
414                match (&gate_act, &up) {
415                    (Tensor::F32(g), Tensor::F32(u)) => Tensor::F32(g * u),
416                    _ => {
417                        return Err(tensor_op_error(
418                            "tensor_operation",
419                            "Unsupported tensor type".to_string(),
420                        ))
421                    },
422                }
423            },
424            "swiglu" => swiglu(&gate, &up)?,
425            _ => silu(&gate)?, // Default to SiLU
426        };
427
428        self.down_proj.forward(activated)
429    }
430}
431
/// StableLM Decoder Layer
///
/// Pre-norm transformer block: RMSNorm -> attention -> residual add, then
/// RMSNorm -> MLP -> residual add.
pub struct StableLMDecoderLayer {
    #[allow(dead_code)]
    config: StableLMConfig,
    self_attn: StableLMAttention,
    mlp: StableLMMLP,
    input_layernorm: RMSNorm,          // norm applied before attention
    post_attention_layernorm: RMSNorm, // norm applied before the MLP
    device: Device,
}
442
443impl StableLMDecoderLayer {
444    pub fn new(config: &StableLMConfig) -> Result<Self> {
445        Self::new_with_device(config, Device::CPU)
446    }
447
448    pub fn new_with_device(config: &StableLMConfig, device: Device) -> Result<Self> {
449        Ok(Self {
450            config: config.clone(),
451            self_attn: StableLMAttention::new_with_device(config, device)?,
452            mlp: StableLMMLP::new_with_device(config, device),
453            input_layernorm: RMSNorm::new_with_device(
454                config.hidden_size,
455                config.rms_norm_eps,
456                device,
457            )?,
458            post_attention_layernorm: RMSNorm::new_with_device(
459                config.hidden_size,
460                config.rms_norm_eps,
461                device,
462            )?,
463            device,
464        })
465    }
466
467    pub fn device(&self) -> &Device {
468        &self.device
469    }
470
471    pub fn parameter_count(&self) -> usize {
472        self.self_attn.parameter_count()
473            + self.mlp.parameter_count()
474            + self.input_layernorm.parameter_count()
475            + self.post_attention_layernorm.parameter_count()
476    }
477}
478
479impl Layer for StableLMDecoderLayer {
480    type Input = Tensor;
481    type Output = Tensor;
482
483    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
484        // Pre-norm architecture
485        let residual = input.clone();
486        let hidden_states = self.input_layernorm.forward(input)?;
487        let attn_output = self.self_attn.forward(hidden_states)?;
488
489        // First residual connection
490        let hidden_states = match (&residual, &attn_output) {
491            (Tensor::F32(r), Tensor::F32(a)) => Tensor::F32(r + a),
492            _ => {
493                return Err(tensor_op_error(
494                    "tensor_operation",
495                    "Unsupported tensor type".to_string(),
496                ))
497            },
498        };
499
500        // MLP block
501        let residual = hidden_states.clone();
502        let hidden_states = self.post_attention_layernorm.forward(hidden_states)?;
503        let mlp_output = self.mlp.forward(hidden_states)?;
504
505        // Second residual connection
506        match (&residual, &mlp_output) {
507            (Tensor::F32(r), Tensor::F32(m)) => Ok(Tensor::F32(r + m)),
508            _ => Err(tensor_op_error(
509                "tensor_operation",
510                "Unsupported tensor type".to_string(),
511            )),
512        }
513    }
514}
515
/// StableLM Embeddings
///
/// Thin wrapper around the token-embedding lookup table.
pub struct StableLMEmbeddings {
    word_embeddings: Embedding, // vocab_size x hidden_size lookup table
    device: Device,
}
521
522impl StableLMEmbeddings {
523    pub fn new(config: &StableLMConfig) -> Result<Self> {
524        Self::new_with_device(config, Device::CPU)
525    }
526
527    pub fn new_with_device(config: &StableLMConfig, device: Device) -> Result<Self> {
528        Ok(Self {
529            word_embeddings: Embedding::new_with_device(
530                config.vocab_size,
531                config.hidden_size,
532                config.pad_token_id.map(|x| x as usize),
533                device,
534            )?,
535            device,
536        })
537    }
538
539    pub fn device(&self) -> &Device {
540        &self.device
541    }
542
543    pub fn parameter_count(&self) -> usize {
544        self.word_embeddings.parameter_count()
545    }
546}
547
impl Layer for StableLMEmbeddings {
    type Input = Vec<u32>;
    type Output = Tensor;

    /// Looks up embedding vectors for a flat list of token ids.
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        self.word_embeddings.forward(input)
    }
}
556
/// StableLM Model Output
#[derive(Debug)]
pub struct StableLMOutputs {
    /// Hidden states after the final RMSNorm.
    pub last_hidden_state: Tensor,
}
562
/// StableLM Base Model
///
/// Token embeddings -> stack of decoder layers -> final RMSNorm.
pub struct StableLMModel {
    pub config: StableLMConfig,
    pub embeddings: StableLMEmbeddings,
    pub layers: Vec<StableLMDecoderLayer>,
    pub norm: RMSNorm, // final normalization over the last layer's output
    device: Device,
}
571
572impl StableLMModel {
573    pub fn new(config: StableLMConfig) -> Result<Self> {
574        Self::new_with_device(config, Device::CPU)
575    }
576
577    pub fn new_with_device(config: StableLMConfig, device: Device) -> Result<Self> {
578        let embeddings = StableLMEmbeddings::new_with_device(&config, device)?;
579
580        let mut layers = Vec::new();
581        for _ in 0..config.num_hidden_layers {
582            layers.push(StableLMDecoderLayer::new_with_device(&config, device)?);
583        }
584
585        let norm = RMSNorm::new_with_device(config.hidden_size, config.rms_norm_eps, device)?;
586
587        Ok(Self {
588            config,
589            embeddings,
590            layers,
591            norm,
592            device,
593        })
594    }
595
596    pub fn device(&self) -> &Device {
597        &self.device
598    }
599
600    pub fn forward_with_outputs(&self, input_ids: &Tensor) -> Result<StableLMOutputs> {
601        // Convert tensor to token IDs
602        let input_ids_vec = match input_ids {
603            Tensor::I64(ref arr) => arr.mapv(|x| x as u32).into_raw_vec_and_offset().0,
604            _ => {
605                return Err(tensor_op_error(
606                    "tensor_operation",
607                    "Unsupported tensor type".to_string(),
608                ))
609            },
610        };
611        let mut hidden_states = self.embeddings.forward(input_ids_vec)?;
612
613        for layer in &self.layers {
614            hidden_states = layer.forward(hidden_states)?;
615        }
616
617        let last_hidden_state = self.norm.forward(hidden_states)?;
618
619        Ok(StableLMOutputs { last_hidden_state })
620    }
621}
622
623impl Model for StableLMModel {
624    type Config = StableLMConfig;
625    type Input = Tensor;
626    type Output = Tensor;
627
628    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
629        let outputs = self.forward_with_outputs(&input)?;
630        Ok(outputs.last_hidden_state)
631    }
632
633    fn load_pretrained(&mut self, _reader: &mut dyn std::io::Read) -> Result<()> {
634        // Legacy interface - use load_from_path instead for new weight loading
635        Err(
636            trustformers_core::errors::TrustformersError::not_implemented(
637                "Use load_from_path or load_from_huggingface for enhanced weight loading"
638                    .to_string(),
639            ),
640        )
641    }
642
643    fn get_config(&self) -> &Self::Config {
644        &self.config
645    }
646
647    fn num_parameters(&self) -> usize {
648        let embeddings_params = self.embeddings.parameter_count();
649        let layers_params: usize = self.layers.iter().map(|layer| layer.parameter_count()).sum();
650        let norm_params = self.norm.parameter_count();
651
652        embeddings_params + layers_params + norm_params
653    }
654}
655
/// StableLM Causal LM Output
#[derive(Debug)]
pub struct StableLMCausalLMOutputs {
    /// Vocabulary logits from the LM head.
    pub logits: Tensor,
    /// Final hidden states, when the caller wants them alongside logits.
    pub hidden_states: Option<Tensor>,
}
662
/// StableLM for Causal Language Modeling
///
/// Base model plus a linear LM head projecting to the vocabulary.
pub struct StableLMForCausalLM {
    pub model: StableLMModel,
    pub lm_head: Linear, // hidden_size -> vocab_size, no bias
    device: Device,
}
669
670impl StableLMForCausalLM {
671    pub fn new(config: StableLMConfig) -> Result<Self> {
672        Self::new_with_device(config, Device::CPU)
673    }
674
675    pub fn new_with_device(config: StableLMConfig, device: Device) -> Result<Self> {
676        let model = StableLMModel::new_with_device(config.clone(), device)?;
677        let lm_head = Linear::new_with_device(config.hidden_size, config.vocab_size, false, device);
678
679        Ok(Self {
680            model,
681            lm_head,
682            device,
683        })
684    }
685
686    pub fn device(&self) -> &Device {
687        &self.device
688    }
689
690    pub fn forward_with_outputs(&self, input_ids: &Tensor) -> Result<StableLMCausalLMOutputs> {
691        let outputs = self.model.forward_with_outputs(input_ids)?;
692        let logits = self.lm_head.forward(outputs.last_hidden_state.clone())?;
693
694        Ok(StableLMCausalLMOutputs {
695            logits,
696            hidden_states: Some(outputs.last_hidden_state),
697        })
698    }
699}
700
701impl Model for StableLMForCausalLM {
702    type Config = StableLMConfig;
703    type Input = Tensor;
704    type Output = Tensor;
705
706    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
707        let outputs = self.forward_with_outputs(&input)?;
708        Ok(outputs.logits)
709    }
710
711    fn load_pretrained(&mut self, _reader: &mut dyn std::io::Read) -> Result<()> {
712        // Legacy interface - use load_from_path instead for new weight loading
713        Err(
714            trustformers_core::errors::TrustformersError::not_implemented(
715                "Use load_from_path or load_from_huggingface for enhanced weight loading"
716                    .to_string(),
717            ),
718        )
719    }
720
721    fn get_config(&self) -> &Self::Config {
722        self.model.get_config()
723    }
724
725    fn num_parameters(&self) -> usize {
726        self.model.num_parameters() + self.lm_head.parameter_count()
727    }
728}
729
730impl StableLMForCausalLM {
731    /// Load model weights from a directory containing HuggingFace format weights
732    pub fn load_from_path(&mut self, model_path: impl AsRef<std::path::Path>) -> Result<()> {
733        use crate::weight_loading::{auto_create_loader, WeightLoadingConfig};
734
735        let config = WeightLoadingConfig {
736            lazy_loading: true,
737            memory_mapped: false,
738            ..Default::default()
739        };
740
741        let mut loader = auto_create_loader(model_path, Some(config))?;
742
743        // Load embedding weights
744        if let Ok(embed_weights) = loader.load_tensor("model.embed_tokens.weight") {
745            self.model.embeddings.word_embeddings.set_weight(embed_weights)?;
746        }
747
748        // Load layer weights
749        for (i, layer) in self.model.layers.iter_mut().enumerate() {
750            // Load attention weights
751            let attn_prefix = format!("model.layers.{}.self_attn", i);
752
753            if let Ok(q_weight) = loader.load_tensor(&format!("{}.q_proj.weight", attn_prefix)) {
754                layer.self_attn.q_proj.set_weight(q_weight)?;
755            }
756            if let Ok(k_weight) = loader.load_tensor(&format!("{}.k_proj.weight", attn_prefix)) {
757                layer.self_attn.k_proj.set_weight(k_weight)?;
758            }
759            if let Ok(v_weight) = loader.load_tensor(&format!("{}.v_proj.weight", attn_prefix)) {
760                layer.self_attn.v_proj.set_weight(v_weight)?;
761            }
762            if let Ok(o_weight) = loader.load_tensor(&format!("{}.o_proj.weight", attn_prefix)) {
763                layer.self_attn.o_proj.set_weight(o_weight)?;
764            }
765
766            // Load MLP weights
767            let mlp_prefix = format!("model.layers.{}.mlp", i);
768
769            if let Ok(gate_weight) = loader.load_tensor(&format!("{}.gate_proj.weight", mlp_prefix))
770            {
771                layer.mlp.gate_proj.set_weight(gate_weight)?;
772            }
773            if let Ok(up_weight) = loader.load_tensor(&format!("{}.up_proj.weight", mlp_prefix)) {
774                layer.mlp.up_proj.set_weight(up_weight)?;
775            }
776            if let Ok(down_weight) = loader.load_tensor(&format!("{}.down_proj.weight", mlp_prefix))
777            {
778                layer.mlp.down_proj.set_weight(down_weight)?;
779            }
780
781            // Layer norm weights would be loaded here if RMSNorm supported set_weight
782            // For now, skipping layer norm weight loading
783        }
784
785        // Load LM head weights
786        if let Ok(lm_head_weight) = loader.load_tensor("lm_head.weight") {
787            self.lm_head.set_weight(lm_head_weight)?;
788        }
789
790        Ok(())
791    }
792
793    /// Load from HuggingFace Hub model name
794    pub fn load_from_huggingface(&mut self, model_name: &str) -> Result<()> {
795        // Check if model is cached locally
796        let cache_dir = std::env::var("HF_HOME")
797            .or_else(|_| std::env::var("HUGGINGFACE_HUB_CACHE"))
798            .unwrap_or_else(|_| {
799                std::env::var("HOME").unwrap_or_else(|_| ".".to_string())
800                    + "/.cache/huggingface/hub"
801            });
802
803        let model_path = std::path::Path::new(&cache_dir)
804            .join(format!("models--{}", model_name.replace("/", "--")));
805
806        if model_path.exists() {
807            self.load_from_path(&model_path)
808        } else {
809            // Attempt to download the model from HuggingFace Hub
810            self.download_from_huggingface_hub(model_name, &model_path)?;
811            self.load_from_path(&model_path)
812        }
813    }
814
815    /// Download model from HuggingFace Hub
816    fn download_from_huggingface_hub(
817        &self,
818        model_name: &str,
819        model_path: &std::path::Path,
820    ) -> Result<()> {
821        use std::process::Command;
822
823        println!(
824            "Downloading model {} from HuggingFace Hub to {:?}",
825            model_name, model_path
826        );
827
828        // Create the model directory
829        std::fs::create_dir_all(model_path).map_err(|e| {
830            trustformers_core::errors::TrustformersError::io_error(format!(
831                "Failed to create model directory: {}",
832                e
833            ))
834        })?;
835
836        // List of essential files for StableLM models
837        let essential_files = vec![
838            "config.json",
839            "tokenizer.json",
840            "tokenizer_config.json",
841            "pytorch_model.bin", // Try .bin first
842            "model.safetensors", // Fall back to safetensors
843        ];
844
845        let base_url = format!("https://huggingface.co/{}/resolve/main", model_name);
846
847        // Try to download each essential file
848        for file_name in &essential_files {
849            let file_url = format!("{}/{}", base_url, file_name);
850            let file_path = model_path.join(file_name);
851
852            println!("Attempting to download {}", file_url);
853
854            // Convert path to string once for both commands
855            let file_path_str = file_path.to_str().ok_or_else(|| {
856                TrustformersError::invalid_config(format!("Invalid UTF-8 in path: {:?}", file_path))
857            })?;
858
859            // Try using curl first
860            let curl_result = Command::new("curl")
861                .args([
862                    "-L", // Follow redirects
863                    "-f", // Fail on HTTP errors
864                    "-o",
865                    file_path_str,
866                    &file_url,
867                ])
868                .output();
869
870            match curl_result {
871                Ok(output) if output.status.success() => {
872                    println!("Successfully downloaded {}", file_name);
873                    continue;
874                },
875                Ok(output) => {
876                    eprintln!(
877                        "Failed to download {} with curl: {}",
878                        file_name,
879                        String::from_utf8_lossy(&output.stderr)
880                    );
881                },
882                Err(e) => {
883                    println!("curl not available: {}", e);
884                },
885            }
886
887            // Try using wget as fallback
888            let wget_result = Command::new("wget").args(["-O", file_path_str, &file_url]).output();
889
890            match wget_result {
891                Ok(output) if output.status.success() => {
892                    println!("Successfully downloaded {} with wget", file_name);
893                    continue;
894                },
895                Ok(output) => {
896                    eprintln!(
897                        "Failed to download {} with wget: {}",
898                        file_name,
899                        String::from_utf8_lossy(&output.stderr)
900                    );
901                },
902                Err(e) => {
903                    println!("wget not available: {}", e);
904                },
905            }
906
907            // If essential files like config.json or pytorch_model.bin fail, return error
908            if matches!(file_name, &"config.json" | &"pytorch_model.bin") {
909                return Err(trustformers_core::errors::TrustformersError::io_error(format!(
910                    "Failed to download essential file {} for model {}. Please ensure curl or wget is installed and you have internet access.",
911                    file_name, model_name
912                )));
913            }
914        }
915
916        println!(
917            "Successfully downloaded model {} from HuggingFace Hub",
918            model_name
919        );
920        Ok(())
921    }
922}
923
#[cfg(test)]
mod tests {
    use super::*;
    // Array2 already imported via scirs2_core at top

    /// RMSNorm over an all-ones (2, 768) input must produce `1/sqrt(1 + eps)`
    /// elementwise: the mean of squares is exactly 1.0, so after dividing by
    /// `rms = sqrt(1 + eps)` and scaling by the default unit weight, every
    /// element equals that constant. Asserting the values (not just `is_ok`)
    /// catches regressions in the normalization math and in the
    /// weight-broadcast step.
    #[test]
    fn test_rms_norm() -> Result<()> {
        let eps = 1e-5f32;
        let norm = RMSNorm::new(768, eps)?;
        let input = Tensor::F32(Array2::ones((2, 768)).into_dyn());
        let output = norm.forward(input)?;

        match output {
            Tensor::F32(arr) => {
                // Shape must be preserved by normalization + weight scaling.
                assert_eq!(arr.shape(), &[2, 768]);
                let expected = 1.0f32 / (1.0f32 + eps).sqrt();
                for &v in arr.iter() {
                    assert!(
                        (v - expected).abs() < 1e-6,
                        "expected {expected}, got {v}"
                    );
                }
            },
            _ => panic!("RMSNorm forward should return an F32 tensor"),
        }
        Ok(())
    }

    /// Construction stores the rotary-embedding hyperparameters verbatim.
    #[test]
    fn test_rotary_embedding() -> Result<()> {
        let rope = RotaryEmbedding::new(64, 512, 10000.0, 0.25)?;
        assert_eq!(rope.head_dim, 64);
        assert_eq!(rope.max_seq_len, 512);
        assert_eq!(rope.partial_rotary_factor, 0.25);
        Ok(())
    }

    /// Full StableLM-3B model builds with the configured layer count and
    /// hidden size.
    #[test]
    #[ignore] // Heavy test - StableLM 3B model creation, run with --ignored
    fn test_stablelm_model_creation() -> Result<()> {
        let config = StableLMConfig::stablelm_3b();
        let model = StableLMModel::new(config.clone())?;

        assert_eq!(model.layers.len(), config.num_hidden_layers);
        assert_eq!(model.config.hidden_size, 2560);
        Ok(())
    }

    /// CausalLM wrapper (base model + LM head) constructs without error.
    #[test]
    #[ignore] // Heavy test - StableLM 3B CausalLM, run with --ignored
    fn test_stablelm_causal_lm() -> Result<()> {
        let config = StableLMConfig::stablelm_3b();
        let _model = StableLMForCausalLM::new(config.clone())?;

        // StableLM for CausalLM created successfully - LM head dimensions are internal
        Ok(())
    }

    /// Grouped-query attention: fewer KV heads than query heads must be
    /// accepted and recorded on the attention module.
    #[test]
    fn test_grouped_query_attention() -> Result<()> {
        let mut config = StableLMConfig::stablelm_2_1_6b();
        config.num_key_value_heads = Some(4);

        let attn = StableLMAttention::new(&config)?;
        assert_eq!(attn.num_heads, 32);
        assert_eq!(attn.num_kv_heads, 4);

        // Grouped query attention created successfully - projection dimensions are internal
        Ok(())
    }

    /// Device placement propagates to every submodule (embeddings, norm,
    /// each decoder layer and its attention/MLP) for both the default and
    /// an explicitly requested CPU device.
    #[test]
    #[ignore] // Heavy test - StableLM 3B device support, run with --ignored
    fn test_device_support() -> Result<()> {
        let config = StableLMConfig::stablelm_3b();

        // Test CPU device (default)
        let model_cpu = StableLMModel::new(config.clone())?;
        assert_eq!(*model_cpu.device(), Device::CPU);

        // Test explicit CPU device
        let model_cpu_explicit = StableLMModel::new_with_device(config.clone(), Device::CPU)?;
        assert_eq!(*model_cpu_explicit.device(), Device::CPU);

        // Test that all components have the correct device
        assert_eq!(*model_cpu.embeddings.device(), Device::CPU);
        assert_eq!(*model_cpu.norm.device(), Device::CPU);
        for layer in &model_cpu.layers {
            assert_eq!(*layer.device(), Device::CPU);
            assert_eq!(*layer.self_attn.device(), Device::CPU);
            assert_eq!(*layer.mlp.device(), Device::CPU);
        }
        Ok(())
    }

    /// CausalLM device placement: both the wrapper and the wrapped base
    /// model report the requested device.
    #[test]
    #[ignore] // Heavy test - StableLM 3B CausalLM device support (SIGKILL risk), run with --ignored
    fn test_causal_lm_device_support() -> Result<()> {
        let config = StableLMConfig::stablelm_3b();

        // Test CPU device
        let model = StableLMForCausalLM::new(config.clone())?;
        assert_eq!(*model.device(), Device::CPU);
        assert_eq!(*model.model.device(), Device::CPU);

        // Test explicit device
        let model_explicit = StableLMForCausalLM::new_with_device(config, Device::CPU)?;
        assert_eq!(*model_explicit.device(), Device::CPU);
        Ok(())
    }
}