use crate::stablelm::config::StableLMConfig;
use scirs2_core::ndarray::{Array1, Array2, ArrayD, Axis, IxDyn};
use trustformers_core::{
    device::Device,
    errors::{tensor_op_error, Result, TrustformersError},
    layers::{Embedding, Linear},
    ops::activations::{silu, swiglu},
    tensor::Tensor,
    traits::{Layer, Model},
};
11
/// Root-mean-square layer normalization (scale-only, no mean subtraction),
/// as used by LLaMA-style decoder stacks.
pub struct RMSNorm {
    /// Learnable per-channel scale, initialized to all ones.
    weight: Tensor,
    /// Small constant added to the mean square for numerical stability.
    eps: f32,
    /// Device the parameters live on.
    device: Device,
}
18
19impl RMSNorm {
20 pub fn new(hidden_size: usize, eps: f32) -> Result<Self> {
21 Self::new_with_device(hidden_size, eps, Device::CPU)
22 }
23
24 pub fn new_with_device(hidden_size: usize, eps: f32, device: Device) -> Result<Self> {
25 let weight = Tensor::ones(&[hidden_size])?.to_device_enum(&device)?;
26 Ok(Self {
27 weight,
28 eps,
29 device,
30 })
31 }
32
33 pub fn device(&self) -> &Device {
34 &self.device
35 }
36
37 pub fn parameter_count(&self) -> usize {
38 self.weight.shape().iter().product()
39 }
40}
41
42impl Layer for RMSNorm {
43 type Input = Tensor;
44 type Output = Tensor;
45
46 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
47 match &input {
48 Tensor::F32(arr) => {
49 let mean_sq = arr.mapv(|x| x * x).mean().unwrap_or(0.0);
50 let rms = (mean_sq + self.eps).sqrt();
51 let normalized = arr.mapv(|x| x / rms);
52
53 match &self.weight {
54 Tensor::F32(weight_arr) => {
55 let result = &normalized * weight_arr;
56 Ok(Tensor::F32(result))
57 },
58 _ => Err(tensor_op_error(
59 "tensor_operation",
60 "Unsupported tensor type".to_string(),
61 )),
62 }
63 },
64 _ => Err(tensor_op_error(
65 "tensor_operation",
66 "Unsupported tensor type".to_string(),
67 )),
68 }
69 }
70}
71
/// Rotary position embedding (RoPE) with optional partial rotation:
/// only `head_dim * partial_rotary_factor` feature dimensions are rotated.
pub struct RotaryEmbedding {
    /// Precomputed sin table, shape [max_seq_len, rotary_dim / 2].
    sin_cached: Tensor,
    /// Precomputed cos table, shape [max_seq_len, rotary_dim / 2].
    cos_cached: Tensor,
    /// Longest sequence the cached tables cover.
    max_seq_len: usize,
    /// Per-head embedding dimension.
    head_dim: usize,
    /// Frequency base used to build the tables (kept for reference only).
    #[allow(dead_code)]
    base: f32,
    /// Fraction of head_dim that receives rotation (0.0..=1.0).
    partial_rotary_factor: f32,
    /// Device the cached tables live on.
    device: Device,
}
83
impl RotaryEmbedding {
    /// Builds the sin/cos caches on the default (CPU) device.
    pub fn new(
        head_dim: usize,
        max_seq_len: usize,
        base: f32,
        partial_rotary_factor: f32,
    ) -> Result<Self> {
        Self::new_with_device(
            head_dim,
            max_seq_len,
            base,
            partial_rotary_factor,
            Device::CPU,
        )
    }

    /// Precomputes sin/cos tables for every position up to `max_seq_len`.
    ///
    /// Frequencies follow the usual RoPE schedule `1 / base^(2i / rotary_dim)`
    /// where `rotary_dim = head_dim * partial_rotary_factor`.
    pub fn new_with_device(
        head_dim: usize,
        max_seq_len: usize,
        base: f32,
        partial_rotary_factor: f32,
        device: Device,
    ) -> Result<Self> {
        let rotary_dim = ((head_dim as f32) * partial_rotary_factor) as usize;

        // Inverse frequencies for even dimension indices 0, 2, 4, ...
        let inv_freq = Array1::range(0.0, rotary_dim as f32, 2.0)
            .mapv(|i| 1.0 / base.powf(i / rotary_dim as f32));

        // Outer product: positions x inv_freq -> one angle per (position, freq).
        let t = Array1::range(0.0, max_seq_len as f32, 1.0);
        let freqs = t.view().insert_axis(Axis(1)).dot(&inv_freq.view().insert_axis(Axis(0)));

        let sin_arr =
            Array2::from_shape_fn((max_seq_len, rotary_dim / 2), |(i, j)| freqs[[i, j]].sin());
        let cos_arr =
            Array2::from_shape_fn((max_seq_len, rotary_dim / 2), |(i, j)| freqs[[i, j]].cos());

        let sin_cached = Tensor::F32(sin_arr.into_dyn()).to_device_enum(&device)?;
        let cos_cached = Tensor::F32(cos_arr.into_dyn()).to_device_enum(&device)?;

        Ok(Self {
            sin_cached,
            cos_cached,
            max_seq_len,
            head_dim,
            base,
            partial_rotary_factor,
            device,
        })
    }

    /// Device the cached tables live on.
    pub fn device(&self) -> &Device {
        &self.device
    }

    /// Rotates the first `rotary_dim` feature dimensions of `q` and `k`,
    /// pairing adjacent elements (2k, 2k+1) per position.
    ///
    /// The indexing below assumes 4-D inputs laid out as
    /// [batch, head, seq, head_dim] — TODO confirm against callers; the
    /// current attention stub feeds Linear outputs directly.
    ///
    /// NOTE(review): `k` is iterated with `q`'s shape; if `k` has fewer heads
    /// (grouped-query attention) this would index out of bounds — confirm.
    /// If `seq_len > max_seq_len` or `rotary_dim == 0`, inputs are returned
    /// unchanged rather than erroring.
    pub fn forward(&self, q: &Tensor, k: &Tensor, seq_len: usize) -> Result<(Tensor, Tensor)> {
        let rotary_dim = ((self.head_dim as f32) * self.partial_rotary_factor) as usize;

        match (q, k, &self.sin_cached, &self.cos_cached) {
            (
                Tensor::F32(q_arr),
                Tensor::F32(k_arr),
                Tensor::F32(sin_arr),
                Tensor::F32(cos_arr),
            ) => {
                // Rotation is applied in place on owned copies.
                let mut q_rot = q_arr.clone();
                let mut k_rot = k_arr.clone();

                if rotary_dim > 0 && seq_len <= self.max_seq_len {
                    let q_shape = q_rot.shape().to_vec();
                    let _k_shape = k_rot.shape().to_vec();

                    for seq_idx in 0..seq_len {
                        for dim_idx in 0..(rotary_dim / 2) {
                            let cos_val = cos_arr[[seq_idx, dim_idx]];
                            let sin_val = sin_arr[[seq_idx, dim_idx]];

                            for batch in 0..q_shape[0] {
                                for head in 0..q_shape[1] {
                                    // Guard against inputs shorter than seq_len.
                                    if seq_idx < q_shape[2] && dim_idx < rotary_dim / 2 {
                                        let x1_idx = [batch, head, seq_idx, dim_idx * 2];
                                        let x2_idx = [batch, head, seq_idx, dim_idx * 2 + 1];

                                        if x2_idx[3] < q_shape[3] {
                                            // Standard 2-D rotation of the (x1, x2) pair.
                                            let q_x1 = q_rot[x1_idx];
                                            let q_x2 = q_rot[x2_idx];
                                            let k_x1 = k_rot[x1_idx];
                                            let k_x2 = k_rot[x2_idx];

                                            q_rot[x1_idx] = q_x1 * cos_val - q_x2 * sin_val;
                                            q_rot[x2_idx] = q_x1 * sin_val + q_x2 * cos_val;

                                            k_rot[x1_idx] = k_x1 * cos_val - k_x2 * sin_val;
                                            k_rot[x2_idx] = k_x1 * sin_val + k_x2 * cos_val;
                                        }
                                    }
                                }
                            }
                        }
                    }
                }

                Ok((Tensor::F32(q_rot), Tensor::F32(k_rot)))
            },
            _ => Err(tensor_op_error(
                "tensor_operation",
                "Unsupported tensor type".to_string(),
            )),
        }
    }
}
202}
203
/// Multi-head self-attention with rotary embeddings and (optionally
/// grouped-query) key/value heads.
pub struct StableLMAttention {
    #[allow(dead_code)]
    config: StableLMConfig,
    /// Query projection: hidden -> hidden.
    q_proj: Linear,
    /// Key projection: hidden -> num_kv_heads * head_dim.
    k_proj: Linear,
    /// Value projection: hidden -> num_kv_heads * head_dim.
    v_proj: Linear,
    /// Output projection: hidden -> hidden.
    o_proj: Linear,
    rotary_emb: RotaryEmbedding,
    #[allow(dead_code)]
    head_dim: usize,
    /// Number of query heads.
    num_heads: usize,
    /// Number of key/value heads (<= num_heads; equal unless GQA).
    num_kv_heads: usize,
    device: Device,
}
219
220impl StableLMAttention {
221 pub fn new(config: &StableLMConfig) -> Result<Self> {
222 Self::new_with_device(config, Device::CPU)
223 }
224
225 pub fn new_with_device(config: &StableLMConfig, device: Device) -> Result<Self> {
226 let hidden_size = config.hidden_size;
227 let num_heads = config.num_attention_heads;
228 let num_kv_heads = config.num_key_value_heads.unwrap_or(num_heads);
229 let head_dim = hidden_size / num_heads;
230
231 let q_proj =
232 Linear::new_with_device(hidden_size, hidden_size, config.attention_bias, device);
233 let k_proj = Linear::new_with_device(
234 hidden_size,
235 num_kv_heads * head_dim,
236 config.attention_bias,
237 device,
238 );
239 let v_proj = Linear::new_with_device(
240 hidden_size,
241 num_kv_heads * head_dim,
242 config.attention_bias,
243 device,
244 );
245 let o_proj =
246 Linear::new_with_device(hidden_size, hidden_size, config.attention_bias, device);
247
248 let rotary_emb = RotaryEmbedding::new_with_device(
249 head_dim,
250 config.max_position_embeddings,
251 config.rope_theta,
252 config.partial_rotary_factor,
253 device,
254 )?;
255
256 Ok(Self {
257 config: config.clone(),
258 q_proj,
259 k_proj,
260 v_proj,
261 o_proj,
262 rotary_emb,
263 head_dim,
264 num_heads,
265 num_kv_heads,
266 device,
267 })
268 }
269
270 pub fn device(&self) -> &Device {
271 &self.device
272 }
273
274 fn repeat_kv(&self, hidden_states: &Tensor, n_rep: usize) -> Result<Tensor> {
275 if n_rep == 1 {
276 return Ok(hidden_states.clone());
277 }
278
279 match hidden_states {
280 Tensor::F32(arr) => {
281 let _shape = arr.shape();
283 let mut repeated = arr.clone();
284
285 for _ in 1..n_rep {
287 repeated = repeated.clone(); }
289
290 Ok(Tensor::F32(repeated))
291 },
292 _ => Err(tensor_op_error(
293 "tensor_operation",
294 "Unsupported tensor type".to_string(),
295 )),
296 }
297 }
298
299 pub fn parameter_count(&self) -> usize {
300 self.q_proj.parameter_count()
301 + self.k_proj.parameter_count()
302 + self.v_proj.parameter_count()
303 + self.o_proj.parameter_count()
304 }
306}
307
impl Layer for StableLMAttention {
    type Input = Tensor;
    type Output = Tensor;

    /// Placeholder attention forward pass.
    ///
    /// NOTE(review): this is a stub — no attention scores are computed.
    /// The output is simply o_proj applied to the rotary-rotated queries;
    /// `k_repeated`/`v_repeated` are produced but only used to type-check
    /// the match arm. `seq_len` is hard-coded to 1. Replace with a real
    /// softmax(QK^T / sqrt(d))V implementation.
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        // Hard-coded shape assumptions for the stub path.
        let _batch_size = 1;
        let seq_len = 1;
        let q = self.q_proj.forward(input.clone())?;
        let k = self.k_proj.forward(input.clone())?;
        let v = self.v_proj.forward(input)?;

        // Apply rotary position embedding to queries and keys.
        let (q_rot, k_rot) = self.rotary_emb.forward(&q, &k, seq_len)?;

        // Expand KV heads to match query heads (grouped-query attention).
        let n_rep = self.num_heads / self.num_kv_heads;
        let k_repeated = self.repeat_kv(&k_rot, n_rep)?;
        let v_repeated = self.repeat_kv(&v, n_rep)?;

        // Stub: pass rotated queries straight through (no QK^T, no softmax).
        let attn_output = match (&q_rot, &k_repeated, &v_repeated) {
            (Tensor::F32(q_arr), Tensor::F32(_k_arr), Tensor::F32(_v_arr)) => {
                Tensor::F32(q_arr.clone())
            },
            _ => {
                return Err(tensor_op_error(
                    "tensor_operation",
                    "Unsupported tensor type".to_string(),
                ))
            },
        };

        self.o_proj.forward(attn_output)
    }
}
348
/// Gated feed-forward block: down_proj(act(gate_proj(x)) * up_proj(x)).
pub struct StableLMMLP {
    /// Kept for the activation-function name lookup in `forward`.
    config: StableLMConfig,
    /// hidden -> intermediate (gated branch).
    gate_proj: Linear,
    /// hidden -> intermediate (linear branch).
    up_proj: Linear,
    /// intermediate -> hidden.
    down_proj: Linear,
    device: Device,
}
357
358impl StableLMMLP {
359 pub fn new(config: &StableLMConfig) -> Self {
360 Self::new_with_device(config, Device::CPU)
361 }
362
363 pub fn new_with_device(config: &StableLMConfig, device: Device) -> Self {
364 let hidden_size = config.hidden_size;
365 let intermediate_size = config.intermediate_size;
366
367 Self {
368 config: config.clone(),
369 gate_proj: Linear::new_with_device(
370 hidden_size,
371 intermediate_size,
372 config.mlp_bias,
373 device,
374 ),
375 up_proj: Linear::new_with_device(
376 hidden_size,
377 intermediate_size,
378 config.mlp_bias,
379 device,
380 ),
381 down_proj: Linear::new_with_device(
382 intermediate_size,
383 hidden_size,
384 config.mlp_bias,
385 device,
386 ),
387 device,
388 }
389 }
390
391 pub fn device(&self) -> &Device {
392 &self.device
393 }
394
395 pub fn parameter_count(&self) -> usize {
396 self.gate_proj.parameter_count()
397 + self.up_proj.parameter_count()
398 + self.down_proj.parameter_count()
399 }
400}
401
402impl Layer for StableLMMLP {
403 type Input = Tensor;
404 type Output = Tensor;
405
406 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
407 let gate = self.gate_proj.forward(input.clone())?;
408 let up = self.up_proj.forward(input)?;
409
410 let activated = match self.config.hidden_act.as_str() {
412 "silu" => {
413 let gate_act = silu(&gate)?;
414 match (&gate_act, &up) {
415 (Tensor::F32(g), Tensor::F32(u)) => Tensor::F32(g * u),
416 _ => {
417 return Err(tensor_op_error(
418 "tensor_operation",
419 "Unsupported tensor type".to_string(),
420 ))
421 },
422 }
423 },
424 "swiglu" => swiglu(&gate, &up)?,
425 _ => silu(&gate)?, };
427
428 self.down_proj.forward(activated)
429 }
430}
431
/// One pre-norm transformer decoder block:
/// x + Attn(LN(x)) followed by y + MLP(LN(y)).
pub struct StableLMDecoderLayer {
    #[allow(dead_code)]
    config: StableLMConfig,
    self_attn: StableLMAttention,
    mlp: StableLMMLP,
    /// Norm applied before the attention sub-block.
    input_layernorm: RMSNorm,
    /// Norm applied before the MLP sub-block.
    post_attention_layernorm: RMSNorm,
    device: Device,
}
442
443impl StableLMDecoderLayer {
444 pub fn new(config: &StableLMConfig) -> Result<Self> {
445 Self::new_with_device(config, Device::CPU)
446 }
447
448 pub fn new_with_device(config: &StableLMConfig, device: Device) -> Result<Self> {
449 Ok(Self {
450 config: config.clone(),
451 self_attn: StableLMAttention::new_with_device(config, device)?,
452 mlp: StableLMMLP::new_with_device(config, device),
453 input_layernorm: RMSNorm::new_with_device(
454 config.hidden_size,
455 config.rms_norm_eps,
456 device,
457 )?,
458 post_attention_layernorm: RMSNorm::new_with_device(
459 config.hidden_size,
460 config.rms_norm_eps,
461 device,
462 )?,
463 device,
464 })
465 }
466
467 pub fn device(&self) -> &Device {
468 &self.device
469 }
470
471 pub fn parameter_count(&self) -> usize {
472 self.self_attn.parameter_count()
473 + self.mlp.parameter_count()
474 + self.input_layernorm.parameter_count()
475 + self.post_attention_layernorm.parameter_count()
476 }
477}
478
479impl Layer for StableLMDecoderLayer {
480 type Input = Tensor;
481 type Output = Tensor;
482
483 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
484 let residual = input.clone();
486 let hidden_states = self.input_layernorm.forward(input)?;
487 let attn_output = self.self_attn.forward(hidden_states)?;
488
489 let hidden_states = match (&residual, &attn_output) {
491 (Tensor::F32(r), Tensor::F32(a)) => Tensor::F32(r + a),
492 _ => {
493 return Err(tensor_op_error(
494 "tensor_operation",
495 "Unsupported tensor type".to_string(),
496 ))
497 },
498 };
499
500 let residual = hidden_states.clone();
502 let hidden_states = self.post_attention_layernorm.forward(hidden_states)?;
503 let mlp_output = self.mlp.forward(hidden_states)?;
504
505 match (&residual, &mlp_output) {
507 (Tensor::F32(r), Tensor::F32(m)) => Ok(Tensor::F32(r + m)),
508 _ => Err(tensor_op_error(
509 "tensor_operation",
510 "Unsupported tensor type".to_string(),
511 )),
512 }
513 }
514}
515
/// Token-embedding front end of the model (no positional table — positions
/// are handled by rotary embeddings inside attention).
pub struct StableLMEmbeddings {
    word_embeddings: Embedding,
    device: Device,
}
521
522impl StableLMEmbeddings {
523 pub fn new(config: &StableLMConfig) -> Result<Self> {
524 Self::new_with_device(config, Device::CPU)
525 }
526
527 pub fn new_with_device(config: &StableLMConfig, device: Device) -> Result<Self> {
528 Ok(Self {
529 word_embeddings: Embedding::new_with_device(
530 config.vocab_size,
531 config.hidden_size,
532 config.pad_token_id.map(|x| x as usize),
533 device,
534 )?,
535 device,
536 })
537 }
538
539 pub fn device(&self) -> &Device {
540 &self.device
541 }
542
543 pub fn parameter_count(&self) -> usize {
544 self.word_embeddings.parameter_count()
545 }
546}
547
548impl Layer for StableLMEmbeddings {
549 type Input = Vec<u32>;
550 type Output = Tensor;
551
552 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
553 self.word_embeddings.forward(input)
554 }
555}
556
/// Output bundle for the base model.
#[derive(Debug)]
pub struct StableLMOutputs {
    /// Final hidden states after the last decoder layer and final norm.
    pub last_hidden_state: Tensor,
}
562
/// StableLM base model: embeddings -> N decoder layers -> final RMSNorm.
pub struct StableLMModel {
    pub config: StableLMConfig,
    pub embeddings: StableLMEmbeddings,
    pub layers: Vec<StableLMDecoderLayer>,
    /// Final norm applied after the last decoder layer.
    pub norm: RMSNorm,
    device: Device,
}
571
572impl StableLMModel {
573 pub fn new(config: StableLMConfig) -> Result<Self> {
574 Self::new_with_device(config, Device::CPU)
575 }
576
577 pub fn new_with_device(config: StableLMConfig, device: Device) -> Result<Self> {
578 let embeddings = StableLMEmbeddings::new_with_device(&config, device)?;
579
580 let mut layers = Vec::new();
581 for _ in 0..config.num_hidden_layers {
582 layers.push(StableLMDecoderLayer::new_with_device(&config, device)?);
583 }
584
585 let norm = RMSNorm::new_with_device(config.hidden_size, config.rms_norm_eps, device)?;
586
587 Ok(Self {
588 config,
589 embeddings,
590 layers,
591 norm,
592 device,
593 })
594 }
595
596 pub fn device(&self) -> &Device {
597 &self.device
598 }
599
600 pub fn forward_with_outputs(&self, input_ids: &Tensor) -> Result<StableLMOutputs> {
601 let input_ids_vec = match input_ids {
603 Tensor::I64(ref arr) => arr.mapv(|x| x as u32).into_raw_vec_and_offset().0,
604 _ => {
605 return Err(tensor_op_error(
606 "tensor_operation",
607 "Unsupported tensor type".to_string(),
608 ))
609 },
610 };
611 let mut hidden_states = self.embeddings.forward(input_ids_vec)?;
612
613 for layer in &self.layers {
614 hidden_states = layer.forward(hidden_states)?;
615 }
616
617 let last_hidden_state = self.norm.forward(hidden_states)?;
618
619 Ok(StableLMOutputs { last_hidden_state })
620 }
621}
622
623impl Model for StableLMModel {
624 type Config = StableLMConfig;
625 type Input = Tensor;
626 type Output = Tensor;
627
628 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
629 let outputs = self.forward_with_outputs(&input)?;
630 Ok(outputs.last_hidden_state)
631 }
632
633 fn load_pretrained(&mut self, _reader: &mut dyn std::io::Read) -> Result<()> {
634 Err(
636 trustformers_core::errors::TrustformersError::not_implemented(
637 "Use load_from_path or load_from_huggingface for enhanced weight loading"
638 .to_string(),
639 ),
640 )
641 }
642
643 fn get_config(&self) -> &Self::Config {
644 &self.config
645 }
646
647 fn num_parameters(&self) -> usize {
648 let embeddings_params = self.embeddings.parameter_count();
649 let layers_params: usize = self.layers.iter().map(|layer| layer.parameter_count()).sum();
650 let norm_params = self.norm.parameter_count();
651
652 embeddings_params + layers_params + norm_params
653 }
654}
655
/// Output bundle for the causal-LM head.
#[derive(Debug)]
pub struct StableLMCausalLMOutputs {
    /// Vocabulary logits from the LM head.
    pub logits: Tensor,
    /// Final hidden states from the base model, when requested.
    pub hidden_states: Option<Tensor>,
}
662
/// StableLM with a (untied) language-modeling head on top of the base model.
pub struct StableLMForCausalLM {
    pub model: StableLMModel,
    /// hidden_size -> vocab_size projection, no bias.
    pub lm_head: Linear,
    device: Device,
}
669
670impl StableLMForCausalLM {
671 pub fn new(config: StableLMConfig) -> Result<Self> {
672 Self::new_with_device(config, Device::CPU)
673 }
674
675 pub fn new_with_device(config: StableLMConfig, device: Device) -> Result<Self> {
676 let model = StableLMModel::new_with_device(config.clone(), device)?;
677 let lm_head = Linear::new_with_device(config.hidden_size, config.vocab_size, false, device);
678
679 Ok(Self {
680 model,
681 lm_head,
682 device,
683 })
684 }
685
686 pub fn device(&self) -> &Device {
687 &self.device
688 }
689
690 pub fn forward_with_outputs(&self, input_ids: &Tensor) -> Result<StableLMCausalLMOutputs> {
691 let outputs = self.model.forward_with_outputs(input_ids)?;
692 let logits = self.lm_head.forward(outputs.last_hidden_state.clone())?;
693
694 Ok(StableLMCausalLMOutputs {
695 logits,
696 hidden_states: Some(outputs.last_hidden_state),
697 })
698 }
699}
700
701impl Model for StableLMForCausalLM {
702 type Config = StableLMConfig;
703 type Input = Tensor;
704 type Output = Tensor;
705
706 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
707 let outputs = self.forward_with_outputs(&input)?;
708 Ok(outputs.logits)
709 }
710
711 fn load_pretrained(&mut self, _reader: &mut dyn std::io::Read) -> Result<()> {
712 Err(
714 trustformers_core::errors::TrustformersError::not_implemented(
715 "Use load_from_path or load_from_huggingface for enhanced weight loading"
716 .to_string(),
717 ),
718 )
719 }
720
721 fn get_config(&self) -> &Self::Config {
722 self.model.get_config()
723 }
724
725 fn num_parameters(&self) -> usize {
726 self.model.num_parameters() + self.lm_head.parameter_count()
727 }
728}
729
impl StableLMForCausalLM {
    /// Loads weights from a local checkpoint (HuggingFace naming scheme).
    ///
    /// Loading is best-effort: each tensor that cannot be found is silently
    /// skipped and the randomly initialized weight is kept.
    ///
    /// NOTE(review): RMSNorm weights (per-layer input/post-attention norms
    /// and the final `model.norm`) and projection biases are never loaded —
    /// `RMSNorm` exposes no weight setter. Confirm whether that is intended.
    pub fn load_from_path(&mut self, model_path: impl AsRef<std::path::Path>) -> Result<()> {
        use crate::weight_loading::{auto_create_loader, WeightLoadingConfig};

        // Lazy, non-memory-mapped loading keeps peak memory low.
        let config = WeightLoadingConfig {
            lazy_loading: true,
            memory_mapped: false,
            ..Default::default()
        };

        let mut loader = auto_create_loader(model_path, Some(config))?;

        // Token embedding table.
        if let Ok(embed_weights) = loader.load_tensor("model.embed_tokens.weight") {
            self.model.embeddings.word_embeddings.set_weight(embed_weights)?;
        }

        // Per-layer attention and MLP projection weights.
        for (i, layer) in self.model.layers.iter_mut().enumerate() {
            let attn_prefix = format!("model.layers.{}.self_attn", i);

            if let Ok(q_weight) = loader.load_tensor(&format!("{}.q_proj.weight", attn_prefix)) {
                layer.self_attn.q_proj.set_weight(q_weight)?;
            }
            if let Ok(k_weight) = loader.load_tensor(&format!("{}.k_proj.weight", attn_prefix)) {
                layer.self_attn.k_proj.set_weight(k_weight)?;
            }
            if let Ok(v_weight) = loader.load_tensor(&format!("{}.v_proj.weight", attn_prefix)) {
                layer.self_attn.v_proj.set_weight(v_weight)?;
            }
            if let Ok(o_weight) = loader.load_tensor(&format!("{}.o_proj.weight", attn_prefix)) {
                layer.self_attn.o_proj.set_weight(o_weight)?;
            }

            let mlp_prefix = format!("model.layers.{}.mlp", i);

            if let Ok(gate_weight) = loader.load_tensor(&format!("{}.gate_proj.weight", mlp_prefix))
            {
                layer.mlp.gate_proj.set_weight(gate_weight)?;
            }
            if let Ok(up_weight) = loader.load_tensor(&format!("{}.up_proj.weight", mlp_prefix)) {
                layer.mlp.up_proj.set_weight(up_weight)?;
            }
            if let Ok(down_weight) = loader.load_tensor(&format!("{}.down_proj.weight", mlp_prefix))
            {
                layer.mlp.down_proj.set_weight(down_weight)?;
            }
        }

        // Output head. NOTE(review): no fallback to tied embeddings when
        // `lm_head.weight` is absent — confirm against target checkpoints.
        if let Ok(lm_head_weight) = loader.load_tensor("lm_head.weight") {
            self.lm_head.set_weight(lm_head_weight)?;
        }

        Ok(())
    }

    /// Resolves (and if necessary downloads) a HuggingFace Hub model into
    /// the local HF cache layout, then loads it via `load_from_path`.
    ///
    /// Cache directory precedence: HF_HOME, then HUGGINGFACE_HUB_CACHE,
    /// then ~/.cache/huggingface/hub (falling back to "." if HOME is unset).
    pub fn load_from_huggingface(&mut self, model_name: &str) -> Result<()> {
        let cache_dir = std::env::var("HF_HOME")
            .or_else(|_| std::env::var("HUGGINGFACE_HUB_CACHE"))
            .unwrap_or_else(|_| {
                std::env::var("HOME").unwrap_or_else(|_| ".".to_string())
                    + "/.cache/huggingface/hub"
            });

        // Matches the hub cache convention: "models--org--name".
        let model_path = std::path::Path::new(&cache_dir)
            .join(format!("models--{}", model_name.replace("/", "--")));

        if model_path.exists() {
            self.load_from_path(&model_path)
        } else {
            self.download_from_huggingface_hub(model_name, &model_path)?;
            self.load_from_path(&model_path)
        }
    }

    /// Best-effort download of a model's files from the HuggingFace Hub
    /// using external `curl` (preferred) or `wget` subprocesses.
    ///
    /// NOTE(review): `config.json` and `pytorch_model.bin` are treated as
    /// hard requirements — a safetensors-only repository will fail here even
    /// after `model.safetensors` downloaded successfully. Confirm intended.
    fn download_from_huggingface_hub(
        &self,
        model_name: &str,
        model_path: &std::path::Path,
    ) -> Result<()> {
        use std::process::Command;

        println!(
            "Downloading model {} from HuggingFace Hub to {:?}",
            model_name, model_path
        );

        std::fs::create_dir_all(model_path).map_err(|e| {
            trustformers_core::errors::TrustformersError::io_error(format!(
                "Failed to create model directory: {}",
                e
            ))
        })?;

        // Files attempted for every model; weight files may legitimately be
        // absent depending on the serialization format used by the repo.
        let essential_files = vec![
            "config.json",
            "tokenizer.json",
            "tokenizer_config.json",
            "pytorch_model.bin",
            "model.safetensors",
        ];

        let base_url = format!("https://huggingface.co/{}/resolve/main", model_name);

        for file_name in &essential_files {
            let file_url = format!("{}/{}", base_url, file_name);
            let file_path = model_path.join(file_name);

            println!("Attempting to download {}", file_url);

            let file_path_str = file_path.to_str().ok_or_else(|| {
                TrustformersError::invalid_config(format!("Invalid UTF-8 in path: {:?}", file_path))
            })?;

            // First attempt: curl (-L follow redirects, -f fail on HTTP errors).
            let curl_result = Command::new("curl")
                .args([
                    "-L",
                    "-f",
                    "-o",
                    file_path_str,
                    &file_url,
                ])
                .output();

            match curl_result {
                Ok(output) if output.status.success() => {
                    println!("Successfully downloaded {}", file_name);
                    continue;
                },
                Ok(output) => {
                    eprintln!(
                        "Failed to download {} with curl: {}",
                        file_name,
                        String::from_utf8_lossy(&output.stderr)
                    );
                },
                Err(e) => {
                    println!("curl not available: {}", e);
                },
            }

            // Second attempt: wget.
            let wget_result = Command::new("wget").args(["-O", file_path_str, &file_url]).output();

            match wget_result {
                Ok(output) if output.status.success() => {
                    println!("Successfully downloaded {} with wget", file_name);
                    continue;
                },
                Ok(output) => {
                    eprintln!(
                        "Failed to download {} with wget: {}",
                        file_name,
                        String::from_utf8_lossy(&output.stderr)
                    );
                },
                Err(e) => {
                    println!("wget not available: {}", e);
                },
            }

            // Both download tools failed for this file; only some files are fatal.
            if matches!(file_name, &"config.json" | &"pytorch_model.bin") {
                return Err(trustformers_core::errors::TrustformersError::io_error(format!(
                    "Failed to download essential file {} for model {}. Please ensure curl or wget is installed and you have internet access.",
                    file_name, model_name
                )));
            }
        }

        println!(
            "Successfully downloaded model {} from HuggingFace Hub",
            model_name
        );
        Ok(())
    }
}
923
#[cfg(test)]
mod tests {
    use super::*;

    /// RMSNorm accepts a batched 2-D F32 tensor without error.
    #[test]
    fn test_rms_norm() -> Result<()> {
        let norm = RMSNorm::new(768, 1e-5)?;
        let input = Tensor::F32(Array2::ones((2, 768)).into_dyn());
        let output = norm.forward(input);
        assert!(output.is_ok());
        Ok(())
    }

    /// Rotary cache construction records the requested geometry.
    #[test]
    fn test_rotary_embedding() -> Result<()> {
        let rope = RotaryEmbedding::new(64, 512, 10000.0, 0.25)?;
        assert_eq!(rope.head_dim, 64);
        assert_eq!(rope.max_seq_len, 512);
        assert_eq!(rope.partial_rotary_factor, 0.25);
        Ok(())
    }

    /// Full 3B model construction (ignored: slow / memory heavy).
    #[test]
    #[ignore]
    fn test_stablelm_model_creation() -> Result<()> {
        let config = StableLMConfig::stablelm_3b();
        let model = StableLMModel::new(config.clone())?;

        assert_eq!(model.layers.len(), config.num_hidden_layers);
        assert_eq!(model.config.hidden_size, 2560);
        Ok(())
    }

    /// Causal-LM wrapper construction (ignored: slow / memory heavy).
    #[test]
    #[ignore]
    fn test_stablelm_causal_lm() -> Result<()> {
        let config = StableLMConfig::stablelm_3b();
        let _model = StableLMForCausalLM::new(config.clone())?;

        Ok(())
    }

    /// GQA config propagates separate query / key-value head counts.
    #[test]
    fn test_grouped_query_attention() -> Result<()> {
        let mut config = StableLMConfig::stablelm_2_1_6b();
        config.num_key_value_heads = Some(4);

        let attn = StableLMAttention::new(&config)?;
        assert_eq!(attn.num_heads, 32);
        assert_eq!(attn.num_kv_heads, 4);

        Ok(())
    }

    /// Every sub-module reports the device it was constructed on
    /// (ignored: builds a full 3B model).
    #[test]
    #[ignore]
    fn test_device_support() -> Result<()> {
        let config = StableLMConfig::stablelm_3b();

        let model_cpu = StableLMModel::new(config.clone())?;
        assert_eq!(*model_cpu.device(), Device::CPU);

        let model_cpu_explicit = StableLMModel::new_with_device(config.clone(), Device::CPU)?;
        assert_eq!(*model_cpu_explicit.device(), Device::CPU);

        assert_eq!(*model_cpu.embeddings.device(), Device::CPU);
        assert_eq!(*model_cpu.norm.device(), Device::CPU);
        for layer in &model_cpu.layers {
            assert_eq!(*layer.device(), Device::CPU);
            assert_eq!(*layer.self_attn.device(), Device::CPU);
            assert_eq!(*layer.mlp.device(), Device::CPU);
        }
        Ok(())
    }

    /// Causal-LM wrapper reports the same device as the inner model
    /// (ignored: builds a full 3B model).
    #[test]
    #[ignore]
    fn test_causal_lm_device_support() -> Result<()> {
        let config = StableLMConfig::stablelm_3b();

        let model = StableLMForCausalLM::new(config.clone())?;
        assert_eq!(*model.device(), Device::CPU);
        assert_eq!(*model.model.device(), Device::CPU);

        let model_explicit = StableLMForCausalLM::new_with_device(config, Device::CPU)?;
        assert_eq!(*model_explicit.device(), Device::CPU);
        Ok(())
    }
}
1020}