1use crate::command_r::config::CommandRConfig;
2use scirs2_core::ndarray::{ArrayD, IxDyn}; use trustformers_core::{
4 errors::{invalid_config, tensor_op_error, Result, TrustformersError},
5 layers::{Embedding, LayerNorm, Linear},
6 ops::activations::silu,
7 tensor::Tensor,
8 traits::{Config, Layer, Model},
9};
10
/// Rotary position embedding (RoPE) state for Command-R attention heads.
///
/// Cos/sin tables are built lazily on first use and cached.
#[derive(Debug, Clone)]
pub struct CommandRRoPE {
    /// Rotary dimension (the per-head dimension the rotation spans).
    dim: usize,
    /// Maximum sequence length the module was configured for.
    #[allow(dead_code)]
    max_seq_len: usize,
    /// Frequency base (rope_theta).
    #[allow(dead_code)]
    base: f32,
    /// Precomputed inverse frequencies, one per even index of `dim`.
    inv_freq: Tensor,
    /// Cached cosine table (`[seq_len, dim / 2]`), built lazily.
    cos_cache: Option<Tensor>,
    /// Cached sine table (`[seq_len, dim / 2]`), built lazily.
    sin_cache: Option<Tensor>,
}
23
24impl CommandRRoPE {
25 pub fn new(dim: usize, max_seq_len: usize, base: f32) -> Result<Self> {
26 let mut inv_freq = Vec::new();
27 for i in (0..dim).step_by(2) {
28 inv_freq.push(1.0 / base.powf(i as f32 / dim as f32));
29 }
30
31 Ok(Self {
32 dim,
33 max_seq_len,
34 base,
35 inv_freq: Tensor::new(inv_freq)?,
36 cos_cache: None,
37 sin_cache: None,
38 })
39 }
40
41 pub fn forward(&mut self, x: &Tensor, _position_ids: &Tensor) -> Result<(Tensor, Tensor)> {
42 let seq_len = x.shape()[1];
44
45 if self.cos_cache.is_none() || self.sin_cache.is_none() {
46 self.create_cache(seq_len)?;
47 }
48
49 let cos = self.cos_cache.as_ref().ok_or_else(|| {
50 TrustformersError::runtime_error(
51 "cos_cache not initialized after create_cache".to_string(),
52 )
53 })?;
54 let sin = self.sin_cache.as_ref().ok_or_else(|| {
55 TrustformersError::runtime_error(
56 "sin_cache not initialized after create_cache".to_string(),
57 )
58 })?;
59
60 Ok((cos.clone(), sin.clone()))
61 }
62
63 fn create_cache(&mut self, seq_len: usize) -> Result<()> {
64 let mut cos_vals = Vec::new();
65 let mut sin_vals = Vec::new();
66
67 for pos in 0..seq_len {
68 for i in 0..self.dim / 2 {
69 let freq = if let Ok(inv_freq_data) = self.inv_freq.data() {
70 inv_freq_data[i]
71 } else {
72 1.0 / (10000.0_f32.powf(2.0 * i as f32 / self.dim as f32))
73 };
74 let angle = pos as f32 * freq;
75 cos_vals.push(angle.cos());
76 sin_vals.push(angle.sin());
77 }
78 }
79
80 self.cos_cache = Some(Tensor::new(cos_vals)?.reshape(&[seq_len, self.dim / 2])?);
81 self.sin_cache = Some(Tensor::new(sin_vals)?.reshape(&[seq_len, self.dim / 2])?);
82
83 Ok(())
84 }
85}
86
/// Multi-head self-attention block for Command-R decoder layers.
#[derive(Debug, Clone)]
pub struct CommandRAttention {
    /// Full model configuration (retained for future use).
    #[allow(dead_code)]
    config: CommandRConfig,
    /// Model hidden size.
    hidden_size: usize,
    /// Number of query heads.
    num_heads: usize,
    /// Number of key/value heads (GQA-style configuration).
    num_key_value_heads: usize,
    /// Dimension of each attention head.
    head_dim: usize,

    // Q/K/V input projections and the output projection.
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,

    /// Rotary position embedding applied to queries and keys.
    rope: CommandRRoPE,
    /// Dropout probability applied to attention weights.
    attention_dropout: f32,
    /// Flash-attention toggle (not consulted by `forward` currently).
    #[allow(dead_code)]
    use_flash_attention: bool,
}
107
108impl CommandRAttention {
109 pub fn new(config: &CommandRConfig) -> Result<Self> {
110 let hidden_size = config.hidden_size;
111 let num_heads = config.num_attention_heads;
112 let num_key_value_heads = config.num_key_value_heads;
113 let head_dim = config.head_dim();
114
115 let q_proj = Linear::new(hidden_size, num_heads * head_dim, config.use_bias);
116 let k_proj = Linear::new(hidden_size, num_key_value_heads * head_dim, config.use_bias);
117 let v_proj = Linear::new(hidden_size, num_key_value_heads * head_dim, config.use_bias);
118 let o_proj = Linear::new(num_heads * head_dim, hidden_size, config.use_bias);
119
120 let rope = CommandRRoPE::new(head_dim, config.max_sequence_length, config.rope_theta)?;
121
122 Ok(Self {
123 config: config.clone(),
124 hidden_size,
125 num_heads,
126 num_key_value_heads,
127 head_dim,
128 q_proj,
129 k_proj,
130 v_proj,
131 o_proj,
132 rope,
133 attention_dropout: config.attention_dropout,
134 use_flash_attention: config.use_flash_attention,
135 })
136 }
137
138 pub fn forward(
139 &mut self,
140 hidden_states: &Tensor,
141 attention_mask: Option<&Tensor>,
142 position_ids: &Tensor,
143 past_key_value: Option<(&Tensor, &Tensor)>,
144 ) -> Result<(Tensor, Option<(Tensor, Tensor)>)> {
145 let batch_size = hidden_states.shape()[0];
146 let seq_len = hidden_states.shape()[1];
147
148 let query_states = self.q_proj.forward(hidden_states.clone())?;
150 let key_states = self.k_proj.forward(hidden_states.clone())?;
151 let value_states = self.v_proj.forward(hidden_states.clone())?;
152
153 let query_states =
155 query_states.reshape(&[batch_size, seq_len, self.num_heads, self.head_dim])?;
156 let key_states =
157 key_states.reshape(&[batch_size, seq_len, self.num_key_value_heads, self.head_dim])?;
158 let value_states = value_states.reshape(&[
159 batch_size,
160 seq_len,
161 self.num_key_value_heads,
162 self.head_dim,
163 ])?;
164
165 let (cos, sin) = self.rope.forward(&query_states, position_ids)?;
167 let query_states = self.apply_rotary_pos_emb(&query_states, &cos, &sin)?;
168 let key_states = self.apply_rotary_pos_emb(&key_states, &cos, &sin)?;
169
170 let (key_states, value_states) = if let Some((past_key, past_value)) = past_key_value {
172 (past_key.clone(), past_value.clone()) } else {
174 (key_states, value_states)
175 };
176
177 let attn_output = self.scaled_dot_product_attention(
179 &query_states,
180 &key_states,
181 &value_states,
182 attention_mask,
183 )?;
184
185 let attn_output = attn_output.reshape(&[batch_size, seq_len, self.hidden_size])?;
187 let attn_output = self.o_proj.forward(attn_output)?;
188
189 let present_key_value = Some((key_states, value_states));
191
192 Ok((attn_output, present_key_value))
193 }
194
195 fn apply_rotary_pos_emb(&self, x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
196 let shape = x.shape();
199 let d_model = shape[shape.len() - 1];
200 let half_d = d_model / 2;
201
202 let x1 = x.slice(shape.len() - 1, 0, half_d)?;
204 let x2 = x.slice(shape.len() - 1, half_d, d_model)?;
205
206 let rotated_x1 = x1.mul(cos)?.sub(&x2.mul(sin)?)?;
208 let rotated_x2 = x2.mul(cos)?.add(&x1.mul(sin)?)?;
209
210 let rotated = Tensor::concat(&[rotated_x1, rotated_x2], shape.len() - 1)?;
212 Ok(rotated)
213 }
214
215 fn scaled_dot_product_attention(
216 &self,
217 query: &Tensor,
218 key: &Tensor,
219 value: &Tensor,
220 attention_mask: Option<&Tensor>,
221 ) -> Result<Tensor> {
222 let _batch_size = query.shape()[0];
223 let _seq_len = query.shape()[1];
224 let head_dim = self.head_dim;
225
226 let query = query.transpose(1, 2)?; let key = key.transpose(1, 2)?;
229 let value = value.transpose(1, 2)?;
230
231 let scale = 1.0 / (head_dim as f32).sqrt();
233 let query = query.mul_scalar(scale)?;
234
235 let key_dims = key.shape().len();
237 let scores = query.matmul(&key.transpose(key_dims - 2, key_dims - 1)?)?;
238
239 let scores = if let Some(mask) = attention_mask { scores.add(mask)? } else { scores };
241
242 let attn_weights = scores.softmax(-1)?;
244
245 let attn_weights = if self.attention_dropout > 0.0 {
247 attn_weights.dropout(self.attention_dropout)?
248 } else {
249 attn_weights
250 };
251
252 let attn_output = attn_weights.matmul(&value)?;
254
255 let attn_output = attn_output.transpose(1, 2)?;
257
258 Ok(attn_output)
259 }
260
261 pub fn parameter_count(&self) -> usize {
262 self.q_proj.parameter_count()
263 + self.k_proj.parameter_count()
264 + self.v_proj.parameter_count()
265 + self.o_proj.parameter_count()
266 }
267}
268
/// Gated feed-forward (SwiGLU-style) block for Command-R decoder layers.
#[derive(Debug, Clone)]
pub struct CommandRMLP {
    #[allow(dead_code)]
    config: CommandRConfig,
    #[allow(dead_code)]
    hidden_size: usize,
    #[allow(dead_code)]
    intermediate_size: usize,

    /// Projection producing the gate branch (passed through the activation).
    gate_proj: Linear,
    /// Projection producing the up (value) branch.
    up_proj: Linear,
    /// Projection back down to the hidden size.
    down_proj: Linear,

    /// Activation name from the config ("silu", "gelu", "relu", ...).
    activation: String,
}
285
286impl CommandRMLP {
287 pub fn new(config: &CommandRConfig) -> Result<Self> {
288 let hidden_size = config.hidden_size;
289 let intermediate_size = config.intermediate_size;
290
291 let gate_proj = Linear::new(hidden_size, intermediate_size, config.use_bias);
292 let up_proj = Linear::new(hidden_size, intermediate_size, config.use_bias);
293 let down_proj = Linear::new(intermediate_size, hidden_size, config.use_bias);
294
295 Ok(Self {
296 config: config.clone(),
297 hidden_size,
298 intermediate_size,
299 gate_proj,
300 up_proj,
301 down_proj,
302 activation: config.activation_function.clone(),
303 })
304 }
305
306 pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
307 let gate_output = self.gate_proj.forward(x.clone())?;
309 let gate_output = match self.activation.as_str() {
310 "silu" => silu(&gate_output)?,
311 "gelu" => gate_output.gelu()?,
312 "relu" => gate_output.relu()?,
313 _ => gate_output.gelu()?, };
315
316 let up_output = self.up_proj.forward(x.clone())?;
318
319 let intermediate = gate_output.mul(&up_output)?;
321
322 let output = self.down_proj.forward(intermediate)?;
324
325 Ok(output)
326 }
327
328 pub fn parameter_count(&self) -> usize {
329 self.gate_proj.parameter_count()
330 + self.up_proj.parameter_count()
331 + self.down_proj.parameter_count()
332 }
333}
334
/// One pre-norm decoder layer: self-attention then MLP, each wrapped in a
/// residual connection.
#[derive(Debug, Clone)]
pub struct CommandRDecoderLayer {
    #[allow(dead_code)]
    config: CommandRConfig,
    #[allow(dead_code)]
    hidden_size: usize,

    /// Multi-head self-attention sub-block.
    self_attn: CommandRAttention,
    /// Feed-forward sub-block.
    mlp: CommandRMLP,
    /// Norm applied before self-attention.
    input_layernorm: LayerNorm,
    /// Norm applied before the MLP.
    post_attention_layernorm: LayerNorm,
}
348
349impl CommandRDecoderLayer {
350 pub fn new(config: &CommandRConfig) -> Result<Self> {
351 let hidden_size = config.hidden_size;
352
353 let self_attn = CommandRAttention::new(config)?;
354 let mlp = CommandRMLP::new(config)?;
355
356 let input_layernorm = LayerNorm::new(vec![hidden_size], config.rms_norm_eps)?;
357 let post_attention_layernorm = LayerNorm::new(vec![hidden_size], config.rms_norm_eps)?;
358
359 Ok(Self {
360 config: config.clone(),
361 hidden_size,
362 self_attn,
363 mlp,
364 input_layernorm,
365 post_attention_layernorm,
366 })
367 }
368
369 pub fn forward(
370 &mut self,
371 hidden_states: &Tensor,
372 attention_mask: Option<&Tensor>,
373 position_ids: &Tensor,
374 past_key_value: Option<(&Tensor, &Tensor)>,
375 ) -> Result<(Tensor, Option<(Tensor, Tensor)>)> {
376 let residual = hidden_states.clone();
377
378 let hidden_states = self.input_layernorm.forward(hidden_states.clone())?;
380
381 let (attn_output, present_key_value) =
383 self.self_attn
384 .forward(&hidden_states, attention_mask, position_ids, past_key_value)?;
385
386 let hidden_states = residual.add(&attn_output)?;
388 let residual = hidden_states.clone();
389
390 let hidden_states = self.post_attention_layernorm.forward(hidden_states)?;
392
393 let mlp_output = self.mlp.forward(&hidden_states)?;
395
396 let hidden_states = residual.add(&mlp_output)?;
398
399 Ok((hidden_states, present_key_value))
400 }
401
402 pub fn parameter_count(&self) -> usize {
403 self.self_attn.parameter_count()
404 + self.mlp.parameter_count()
405 + self.input_layernorm.parameter_count()
406 + self.post_attention_layernorm.parameter_count()
407 }
408}
409
/// Command-R decoder-only transformer backbone: token embedding, a stack of
/// decoder layers, and a final norm.
#[derive(Debug, Clone)]
pub struct CommandRModel {
    config: CommandRConfig,
    #[allow(dead_code)]
    vocab_size: usize,
    #[allow(dead_code)]
    hidden_size: usize,
    #[allow(dead_code)]
    num_hidden_layers: usize,

    /// Token-id -> hidden-state embedding table.
    embed_tokens: Embedding,
    /// Stack of decoder layers.
    layers: Vec<CommandRDecoderLayer>,
    /// Final norm applied after the last decoder layer.
    norm: LayerNorm,

    // Special token ids mirrored from the config.
    #[allow(dead_code)]
    pad_token_id: Option<usize>,
    #[allow(dead_code)]
    bos_token_id: Option<usize>,
    #[allow(dead_code)]
    eos_token_id: Option<usize>,
}
432
433impl CommandRModel {
434 pub fn new(config: &CommandRConfig) -> Result<Self> {
435 config.validate().map_err(|e| invalid_config("config_validation", &e))?;
436
437 let vocab_size = config.vocab_size;
438 let hidden_size = config.hidden_size;
439 let num_hidden_layers = config.num_hidden_layers;
440
441 let embed_tokens = Embedding::new(vocab_size, hidden_size, None)?;
442
443 let mut layers = Vec::new();
444 for _ in 0..num_hidden_layers {
445 layers.push(CommandRDecoderLayer::new(config)?);
446 }
447
448 let norm = LayerNorm::new(vec![hidden_size], config.rms_norm_eps)?;
449
450 Ok(Self {
451 config: config.clone(),
452 vocab_size,
453 hidden_size,
454 num_hidden_layers,
455 embed_tokens,
456 layers,
457 norm,
458 pad_token_id: config.pad_token_id,
459 bos_token_id: config.bos_token_id,
460 eos_token_id: config.eos_token_id,
461 })
462 }
463
464 pub fn forward(
465 &mut self,
466 input_ids: &Tensor,
467 attention_mask: Option<&Tensor>,
468 position_ids: Option<&Tensor>,
469 past_key_values: Option<&[(Tensor, Tensor)]>,
470 ) -> Result<CommandRModelOutput> {
471 let _batch_size = input_ids.shape()[0];
472 let seq_len = input_ids.shape()[1];
473
474 let position_ids = if let Some(pos_ids) = position_ids {
476 pos_ids.clone()
477 } else {
478 let mut pos_ids = Vec::new();
479 for i in 0..seq_len {
480 pos_ids.push(i as f32);
481 }
482 Tensor::new(pos_ids)?.reshape(&[1, seq_len])?
483 };
484
485 let input_ids_vec = match input_ids {
488 Tensor::I64(arr) => arr.iter().map(|&x| x as u32).collect::<Vec<u32>>(),
489 _ => {
490 return Err(tensor_op_error(
491 "CommandRModel::forward",
492 "Input IDs must be integer tensor",
493 ))
494 },
495 };
496 let mut hidden_states = self.embed_tokens.forward(input_ids_vec)?;
497
498 let mut present_key_values = Vec::new();
500 for (layer_idx, layer) in self.layers.iter_mut().enumerate() {
501 let past_key_value = past_key_values.map(|pkv| (&pkv[layer_idx].0, &pkv[layer_idx].1));
502
503 let (layer_output, present_key_value) = layer.forward(
504 &hidden_states,
505 attention_mask,
506 &position_ids,
507 past_key_value,
508 )?;
509
510 hidden_states = layer_output;
511 if let Some(pkv) = present_key_value {
512 present_key_values.push(pkv);
513 }
514 }
515
516 let hidden_states = self.norm.forward(hidden_states)?;
518
519 Ok(CommandRModelOutput {
520 last_hidden_state: hidden_states,
521 past_key_values: if present_key_values.is_empty() {
522 None
523 } else {
524 Some(present_key_values)
525 },
526 hidden_states: None,
527 attentions: None,
528 })
529 }
530}
531
impl Model for CommandRModel {
    type Config = CommandRConfig;
    type Input = Tensor;
    type Output = Tensor;

    /// Trait-level forward over pre-embedded hidden states.
    ///
    /// NOTE(review): the trait takes `&self` but each layer's forward needs
    /// `&mut`, so every layer is cloned per call and any internal state
    /// updates (e.g. RoPE caches) are discarded. Position ids are a zero
    /// tensor and no attention mask or KV cache is used — confirm this path
    /// is only meant for stateless inference.
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        let mut hidden_states = input;

        for layer in &self.layers {
            let mut layer_mut = layer.clone();
            let (new_hidden_states, _) = layer_mut.forward(
                &hidden_states,
                None,
                &Tensor::zeros(&[hidden_states.shape()[0], hidden_states.shape()[1]])?,
                None,
            )?;
            hidden_states = new_hidden_states;
        }

        hidden_states = self.norm.forward(hidden_states)?;

        Ok(hidden_states)
    }

    /// Reads the whole weight payload, sanity-checks its size, sniffs the
    /// format (SafeTensors / PyTorch pickle / JSON), and dispatches to the
    /// matching loader.
    fn load_pretrained(&mut self, reader: &mut dyn std::io::Read) -> Result<()> {
        let mut buffer = Vec::new();
        reader.read_to_end(&mut buffer).map_err(|e| {
            trustformers_core::errors::TrustformersError::io_error(format!(
                "Failed to read pretrained weights: {}",
                e
            ))
        })?;

        if buffer.is_empty() {
            return Err(
                trustformers_core::errors::TrustformersError::invalid_input_simple(
                    "Pretrained weight data is empty".to_string(),
                ),
            );
        }

        // Anything under 1 KiB cannot plausibly hold real model weights.
        if buffer.len() < 1024 {
            return Err(
                trustformers_core::errors::TrustformersError::invalid_input_simple(
                    "Weight file appears too small to contain valid Command-R model weights"
                        .to_string(),
                ),
            );
        }

        println!(
            "Successfully read {} bytes of Command-R model weights",
            buffer.len()
        );

        // Format sniffing order: SafeTensors header, pickle magic, then JSON.
        if self.is_safetensors_format(&buffer) {
            self.load_safetensors_weights(&buffer)?;
        } else if self.is_pytorch_format(&buffer) {
            self.load_pytorch_weights(&buffer)?;
        } else {
            if let Ok(json_str) = std::str::from_utf8(&buffer) {
                if let Ok(json_data) = serde_json::from_str::<serde_json::Value>(json_str) {
                    self.load_json_weights(&json_data)?;
                } else {
                    return Err(
                        trustformers_core::errors::TrustformersError::invalid_input_simple(
                            "Unable to parse weight data as SafeTensors, PyTorch, or JSON format"
                                .to_string(),
                        ),
                    );
                }
            } else {
                return Err(
                    trustformers_core::errors::TrustformersError::invalid_input_simple(
                        "Weight data appears to be in an unsupported binary format".to_string(),
                    ),
                );
            }
        }

        println!("Successfully loaded Command-R model weights");
        Ok(())
    }

    fn get_config(&self) -> &Self::Config {
        &self.config
    }

    /// Embedding + all decoder layers + final norm.
    fn num_parameters(&self) -> usize {
        let embed_params = self.embed_tokens.parameter_count();
        let layers_params: usize = self.layers.iter().map(|layer| layer.parameter_count()).sum();
        let norm_params = self.norm.parameter_count();

        embed_params + layers_params + norm_params
    }
}
643
impl CommandRModel {
    /// Heuristically checks whether `buffer` looks like a SafeTensors file:
    /// an 8-byte little-endian header length followed by a JSON header.
    fn is_safetensors_format(&self, buffer: &[u8]) -> bool {
        if buffer.len() < 8 {
            return false;
        }

        // First 8 bytes: little-endian u64 length of the JSON header.
        let header = &buffer[0..8];
        let header_len = u64::from_le_bytes(header.try_into().unwrap_or([0; 8]));
        if header_len > 0 && header_len < (buffer.len() as u64 - 8) {
            let start_idx = 8;
            let end_idx = std::cmp::min(start_idx + header_len as usize, buffer.len());
            if let Ok(json_str) = std::str::from_utf8(&buffer[start_idx..end_idx]) {
                // A SafeTensors header is a JSON object.
                return json_str.trim_start().starts_with('{');
            }
        }

        false
    }

    /// Heuristically checks whether `buffer` starts with a Python pickle
    /// protocol marker, as used by PyTorch checkpoint files.
    fn is_pytorch_format(&self, buffer: &[u8]) -> bool {
        if buffer.len() < 4 {
            return false;
        }

        // Pickle protocol 2/3/4 magic prefixes.
        let pickle_markers = [
            b"\x80\x02",
            b"\x80\x03",
            b"\x80\x04",
        ];

        for marker in &pickle_markers {
            if buffer.starts_with(*marker) {
                return true;
            }
        }

        false
    }

    /// Placeholder SafeTensors loader.
    ///
    /// NOTE(review): real parsing is not implemented yet; this only logs and
    /// installs mock tensors via `create_mock_tensor_assignments`.
    fn load_safetensors_weights(&mut self, buffer: &[u8]) -> Result<()> {
        println!("Detected SafeTensors format ({} bytes)", buffer.len());
        println!("SafeTensors weight loading functionality would be implemented here");

        self.create_mock_tensor_assignments()?;

        Ok(())
    }

    /// Placeholder PyTorch loader (see note on `load_safetensors_weights`).
    fn load_pytorch_weights(&mut self, buffer: &[u8]) -> Result<()> {
        println!("Detected PyTorch format ({} bytes)", buffer.len());
        println!("PyTorch weight loading functionality would be implemented here");

        self.create_mock_tensor_assignments()?;

        Ok(())
    }

    /// Loads weights from a JSON document shaped like
    /// `{"tensors": {name: {"shape": [...], "data": [...]}}}`.
    ///
    /// Individual tensor failures are logged and skipped rather than
    /// aborting the whole load.
    fn load_json_weights(&mut self, json_data: &serde_json::Value) -> Result<()> {
        let tensors_obj = json_data.get("tensors").ok_or_else(|| {
            trustformers_core::errors::TrustformersError::weight_load_error(
                "Missing 'tensors' field in JSON data".to_string(),
            )
        })?;

        if let Some(tensors) = tensors_obj.as_object() {
            for (tensor_name, tensor_info) in tensors {
                if let Err(e) = self.load_single_tensor_from_json(tensor_name, tensor_info) {
                    eprintln!("Warning: Failed to load tensor '{}': {}", tensor_name, e);
                }
            }
        }

        Ok(())
    }

    /// Parses one `{"shape": [...], "data": [...]}` entry into an f32 tensor
    /// and routes it to the matching model component.
    fn load_single_tensor_from_json(
        &mut self,
        name: &str,
        tensor_info: &serde_json::Value,
    ) -> Result<()> {
        let shape = tensor_info.get("shape").and_then(|s| s.as_array()).ok_or_else(|| {
            trustformers_core::errors::TrustformersError::weight_load_error(
                "Missing or invalid 'shape' field".to_string(),
            )
        })?;

        // Every dimension must be a non-negative integer.
        let shape_vec: Result<Vec<usize>> = shape
            .iter()
            .map(|v| {
                v.as_u64().map(|u| u as usize).ok_or_else(|| {
                    trustformers_core::errors::TrustformersError::weight_load_error(
                        "Invalid shape dimension".to_string(),
                    )
                })
            })
            .collect();
        let shape_vec = shape_vec?;

        let data = tensor_info.get("data").and_then(|d| d.as_array()).ok_or_else(|| {
            trustformers_core::errors::TrustformersError::weight_load_error(
                "Missing or invalid 'data' field".to_string(),
            )
        })?;

        // JSON numbers arrive as f64; narrow to f32 for the tensor.
        let data_vec: Result<Vec<f32>> = data
            .iter()
            .map(|v| {
                v.as_f64().map(|f| f as f32).ok_or_else(|| {
                    trustformers_core::errors::TrustformersError::weight_load_error(
                        "Invalid tensor data value".to_string(),
                    )
                })
            })
            .collect();
        let data_vec = data_vec?;

        let arr = ArrayD::from_shape_vec(IxDyn(&shape_vec), data_vec).map_err(|e| {
            trustformers_core::errors::TrustformersError::shape_error(e.to_string())
        })?;
        let tensor = trustformers_core::tensor::Tensor::F32(arr);

        self.assign_tensor_to_component(name, tensor)
    }

    /// Installs small placeholder tensors under the standard Command-R
    /// weight names (used by the stub SafeTensors/PyTorch loaders).
    fn create_mock_tensor_assignments(&mut self) -> Result<()> {
        let mock_tensor_names = vec![
            "embed_tokens.weight",
            "layers.0.self_attn.q_proj.weight",
            "layers.0.self_attn.k_proj.weight",
            "layers.0.self_attn.v_proj.weight",
            "layers.0.self_attn.o_proj.weight",
            "layers.0.mlp.gate_proj.weight",
            "layers.0.mlp.up_proj.weight",
            "layers.0.mlp.down_proj.weight",
            "layers.0.input_layernorm.weight",
            "layers.0.post_attention_layernorm.weight",
            "norm.weight",
        ];

        for tensor_name in mock_tensor_names {
            // 128 constant values — the shape does not match real weights;
            // this is purely a smoke-test placeholder.
            let mock_data = vec![0.1f32; 128];
            let arr = ArrayD::from_shape_vec(IxDyn(&[128]), mock_data).map_err(|e| {
                trustformers_core::errors::TrustformersError::shape_error(e.to_string())
            })?;
            let mock_tensor = trustformers_core::tensor::Tensor::F32(arr);

            self.assign_tensor_to_component(tensor_name, mock_tensor)?;
        }

        Ok(())
    }

    /// Routes a named tensor to embeddings, a decoder layer, or the final
    /// norm based on its name; unknown names are logged and skipped.
    ///
    /// NOTE(review): the embedding and norm branches currently only log —
    /// the tensor is dropped, not stored.
    fn assign_tensor_to_component(
        &mut self,
        name: &str,
        tensor: trustformers_core::tensor::Tensor,
    ) -> Result<()> {
        if name.contains("embed_tokens") || name == "embeddings.word_embeddings.weight" {
            println!("Loading embedding weights from tensor: {}", name);
        } else if name.starts_with("layers.") || name.contains("transformer.h.") {
            println!("Loading layer weights from tensor: {}", name);
            self.load_layer_tensor(name, tensor)?;
        } else if name.contains("norm") || name.contains("ln_f") {
            println!("Loading normalization weights from tensor: {}", name);
        } else {
            println!("Skipping unknown tensor: {}", name);
        }

        Ok(())
    }

    /// Logs which sub-component of a decoder layer a named tensor targets.
    ///
    /// NOTE(review): this is diagnostic only — `_tensor` is never stored
    /// into the layer.
    fn load_layer_tensor(
        &mut self,
        name: &str,
        _tensor: trustformers_core::tensor::Tensor,
    ) -> Result<()> {
        if let Some(layer_idx) = self.extract_layer_index(name) {
            if layer_idx < self.layers.len() {
                println!("Loading tensor '{}' into layer {}", name, layer_idx);

                if name.contains("self_attn") || name.contains("attention") {
                    if name.contains("q_proj") || name.contains("query") {
                        println!(" -> Query projection weights");
                    } else if name.contains("k_proj") || name.contains("key") {
                        println!(" -> Key projection weights");
                    } else if name.contains("v_proj") || name.contains("value") {
                        println!(" -> Value projection weights");
                    } else if name.contains("o_proj") || name.contains("out") {
                        println!(" -> Output projection weights");
                    }
                } else if name.contains("mlp") || name.contains("feed_forward") {
                    if name.contains("gate_proj") || name.contains("w1") {
                        println!(" -> Gate projection weights");
                    } else if name.contains("up_proj") || name.contains("w3") {
                        println!(" -> Up projection weights");
                    } else if name.contains("down_proj") || name.contains("w2") {
                        println!(" -> Down projection weights");
                    }
                } else if name.contains("input_layernorm") || name.contains("ln_1") {
                    println!(" -> Input layer norm weights");
                } else if name.contains("post_attention_layernorm") || name.contains("ln_2") {
                    println!(" -> Post-attention layer norm weights");
                }

            }
        }

        Ok(())
    }

    /// Extracts the numeric layer index from names like `layers.12.…` or
    /// `transformer.h.12.…`; returns `None` when no index is present.
    fn extract_layer_index(&self, name: &str) -> Option<usize> {
        // "layers.<idx>." style (LLaMA/Command-R convention).
        if let Some(captures) = name.find("layers.") {
            let start = captures + "layers.".len();
            if let Some(end) = name[start..].find('.') {
                if let Ok(idx) = name[start..start + end].parse::<usize>() {
                    return Some(idx);
                }
            }
        }

        // "transformer.h.<idx>." style (GPT-2 convention).
        if let Some(captures) = name.find("transformer.h.") {
            let start = captures + "transformer.h.".len();
            if let Some(end) = name[start..].find('.') {
                if let Ok(idx) = name[start..start + end].parse::<usize>() {
                    return Some(idx);
                }
            }
        }

        None
    }
}
928
/// Output bundle of `CommandRModel::forward`.
#[derive(Debug, Clone)]
pub struct CommandRModelOutput {
    /// Final hidden states after the last norm.
    pub last_hidden_state: Tensor,
    /// Per-layer (key, value) tensors for incremental decoding, if any.
    pub past_key_values: Option<Vec<(Tensor, Tensor)>>,
    /// Per-layer hidden states (not currently populated by `forward`).
    pub hidden_states: Option<Vec<Tensor>>,
    /// Per-layer attention weights (not currently populated by `forward`).
    pub attentions: Option<Vec<Tensor>>,
}
937
/// Command-R with a language-modeling head for next-token prediction.
#[derive(Debug, Clone)]
pub struct CommandRForCausalLM {
    /// Transformer backbone.
    model: CommandRModel,
    /// Projection from hidden states to vocabulary logits.
    lm_head: Linear,
    /// Model configuration (also consulted for the EOS token id).
    config: CommandRConfig,
}
945
946impl CommandRForCausalLM {
947 pub fn new(config: &CommandRConfig) -> Result<Self> {
948 let model = CommandRModel::new(config)?;
949 let lm_head = Linear::new(config.hidden_size, config.vocab_size, config.use_bias);
950
951 Ok(Self {
952 model,
953 lm_head,
954 config: config.clone(),
955 })
956 }
957
958 pub fn forward(
959 &mut self,
960 input_ids: &Tensor,
961 attention_mask: Option<&Tensor>,
962 position_ids: Option<&Tensor>,
963 past_key_values: Option<&[(Tensor, Tensor)]>,
964 labels: Option<&Tensor>,
965 ) -> Result<CommandRCausalLMOutput> {
966 let mut model_mut = self.model.clone();
967 let outputs = CommandRModel::forward(
968 &mut model_mut,
969 input_ids,
970 attention_mask,
971 position_ids,
972 past_key_values,
973 )?;
974
975 let logits = self.lm_head.forward(outputs.last_hidden_state)?;
976
977 let loss = if let Some(labels) = labels {
978 let vocab_size = logits.shape()[logits.shape().len() - 1];
981 let seq_len = logits.shape()[logits.shape().len() - 2];
982
983 let batch_size = logits.shape()[0];
985 let flat_logits = logits.reshape(&[batch_size * seq_len, vocab_size])?;
986 let _flat_labels = labels.reshape(&[batch_size * seq_len])?;
987
988 let _log_probs = flat_logits.softmax(-1)?.log()?;
990
991 let target_probs = Tensor::zeros(&flat_logits.shape())?;
994 let diff = flat_logits.sub(&target_probs)?;
997 let squared = diff.mul(&diff)?;
998 Some(squared.mean()?)
999 } else {
1000 None
1001 };
1002
1003 Ok(CommandRCausalLMOutput {
1004 loss,
1005 logits,
1006 past_key_values: outputs.past_key_values,
1007 hidden_states: outputs.hidden_states,
1008 attentions: outputs.attentions,
1009 })
1010 }
1011
1012 pub fn generate(
1013 &mut self,
1014 input_ids: &Tensor,
1015 max_length: usize,
1016 temperature: f32,
1017 top_k: Option<usize>,
1018 top_p: Option<f32>,
1019 ) -> Result<Tensor> {
1020 let mut current_ids = input_ids.clone();
1021 let mut past_key_values = None;
1022
1023 for _ in 0..max_length {
1024 let outputs =
1025 self.forward(¤t_ids, None, None, past_key_values.as_deref(), None)?;
1026
1027 let seq_len = outputs.logits.shape()[1];
1028 let next_token_logits = outputs.logits.slice(1, seq_len - 1, seq_len)?;
1029 let next_token_logits = next_token_logits.div_scalar(temperature)?;
1030
1031 let next_token = self.sample_next_token(&next_token_logits, top_k, top_p)?;
1033
1034 current_ids = Tensor::concat(&[current_ids, next_token.clone()], 1)?;
1036 past_key_values = outputs.past_key_values;
1037
1038 if let Some(eos_id) = self.config.eos_token_id {
1040 if let Ok(data) = next_token.data() {
1041 if data[0] as usize == eos_id {
1042 break;
1043 }
1044 }
1045 }
1046 }
1047
1048 Ok(current_ids)
1049 }
1050
1051 fn sample_next_token(
1052 &self,
1053 logits: &Tensor,
1054 top_k: Option<usize>,
1055 top_p: Option<f32>,
1056 ) -> Result<Tensor> {
1057 let mut probs = logits.softmax(-1)?;
1058
1059 if let Some(k) = top_k {
1061 probs = self.top_k_sampling(&probs, k)?;
1062 }
1063
1064 if let Some(p) = top_p {
1066 probs = self.top_p_sampling(&probs, p)?;
1067 }
1068
1069 let sampled_idx = self.categorical_sample(&probs)?;
1071
1072 Tensor::new(vec![sampled_idx as f32])?.reshape(&[1, 1])
1073 }
1074
1075 fn top_k_sampling(&self, probs: &Tensor, _k: usize) -> Result<Tensor> {
1076 Ok(probs.clone())
1079 }
1080
1081 fn top_p_sampling(&self, probs: &Tensor, _p: f32) -> Result<Tensor> {
1082 Ok(probs.clone())
1085 }
1086
1087 fn categorical_sample(&self, probs: &Tensor) -> Result<usize> {
1088 let data = probs.data()?;
1091 let mut max_idx = 0;
1092 let mut max_prob = data[0];
1093
1094 for (i, &prob) in data.iter().enumerate() {
1095 if prob > max_prob {
1096 max_prob = prob;
1097 max_idx = i;
1098 }
1099 }
1100
1101 Ok(max_idx)
1102 }
1103}
1104
/// Output bundle of `CommandRForCausalLM::forward`.
#[derive(Debug, Clone)]
pub struct CommandRCausalLMOutput {
    /// Optional training loss (only present when labels were provided).
    pub loss: Option<Tensor>,
    /// Vocabulary logits for every position.
    pub logits: Tensor,
    /// Per-layer (key, value) tensors for incremental decoding, if any.
    pub past_key_values: Option<Vec<(Tensor, Tensor)>>,
    /// Per-layer hidden states (not currently populated).
    pub hidden_states: Option<Vec<Tensor>>,
    /// Per-layer attention weights (not currently populated).
    pub attentions: Option<Vec<Tensor>>,
}
1114
impl Model for CommandRForCausalLM {
    type Config = CommandRConfig;
    type Input = Tensor;
    type Output = Tensor;

    /// Trait-level forward: backbone hidden states projected through the LM
    /// head into vocabulary logits.
    ///
    /// NOTE(review): this delegates to the trait `Model::forward` on the
    /// backbone, so no attention mask, position ids, or KV cache are used.
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        let hidden_states = self.model.forward(input)?;

        let logits = self.lm_head.forward(hidden_states)?;

        Ok(logits)
    }

    /// Reads the full weight payload from `reader`, stages it in a
    /// process-id-suffixed temp file, and defers to `load_from_path`.
    ///
    /// The temp file is removed afterwards regardless of the load result.
    fn load_pretrained(&mut self, reader: &mut dyn std::io::Read) -> Result<()> {
        use std::io::Write;

        let mut buffer = Vec::new();
        reader.read_to_end(&mut buffer).map_err(|e| {
            trustformers_core::errors::TrustformersError::io_error(format!(
                "Failed to read pretrained weights: {}",
                e
            ))
        })?;

        if buffer.is_empty() {
            return Err(
                trustformers_core::errors::TrustformersError::invalid_input_simple(
                    "Pretrained weight data is empty".to_string(),
                ),
            );
        }

        let temp_dir = std::env::temp_dir();
        let temp_file_path = temp_dir.join(format!(
            "command_r_causal_weights_{}.bin",
            std::process::id()
        ));

        // Scope the file handle so it is closed before loading from it.
        {
            let mut temp_file = std::fs::File::create(&temp_file_path).map_err(|e| {
                trustformers_core::errors::TrustformersError::io_error(format!(
                    "Failed to create temporary file: {}",
                    e
                ))
            })?;
            temp_file.write_all(&buffer).map_err(|e| {
                trustformers_core::errors::TrustformersError::io_error(format!(
                    "Failed to write to temporary file: {}",
                    e
                ))
            })?;
        }

        let result = self.load_from_path(&temp_file_path);

        // Best-effort cleanup; a failed removal must not mask the result.
        let _ = std::fs::remove_file(&temp_file_path);

        result
    }

    fn get_config(&self) -> &Self::Config {
        &self.config
    }

    /// Backbone parameters plus the LM head.
    fn num_parameters(&self) -> usize {
        self.model.num_parameters() + self.lm_head.parameter_count()
    }
}
1190
1191impl CommandRForCausalLM {
    /// Loads model weights from a local checkpoint path using the shared
    /// weight-loading machinery (lazy, non-memory-mapped).
    ///
    /// Tensor names follow the HuggingFace `model.*` convention; any tensor
    /// the loader cannot provide is silently skipped, leaving that component
    /// with its initialized weights.
    pub fn load_from_path(&mut self, model_path: impl AsRef<std::path::Path>) -> Result<()> {
        use crate::weight_loading::{auto_create_loader, WeightLoadingConfig};

        let config = WeightLoadingConfig {
            lazy_loading: true,
            memory_mapped: false,
            ..Default::default()
        };

        let mut loader = auto_create_loader(model_path, Some(config))?;

        // Token embedding table.
        if let Ok(embed_weights) = loader.load_tensor("model.embed_tokens.weight") {
            self.model.embed_tokens.set_weight(embed_weights)?;
        }

        // Per-layer attention, MLP, and norm weights.
        for (i, layer) in self.model.layers.iter_mut().enumerate() {
            let attn_prefix = format!("model.layers.{}.self_attn", i);

            if let Ok(q_weight) = loader.load_tensor(&format!("{}.q_proj.weight", attn_prefix)) {
                layer.self_attn.q_proj.set_weight(q_weight)?;
            }
            if let Ok(k_weight) = loader.load_tensor(&format!("{}.k_proj.weight", attn_prefix)) {
                layer.self_attn.k_proj.set_weight(k_weight)?;
            }
            if let Ok(v_weight) = loader.load_tensor(&format!("{}.v_proj.weight", attn_prefix)) {
                layer.self_attn.v_proj.set_weight(v_weight)?;
            }
            if let Ok(o_weight) = loader.load_tensor(&format!("{}.o_proj.weight", attn_prefix)) {
                layer.self_attn.o_proj.set_weight(o_weight)?;
            }

            let mlp_prefix = format!("model.layers.{}.mlp", i);

            if let Ok(gate_weight) = loader.load_tensor(&format!("{}.gate_proj.weight", mlp_prefix))
            {
                layer.mlp.gate_proj.set_weight(gate_weight)?;
            }
            if let Ok(up_weight) = loader.load_tensor(&format!("{}.up_proj.weight", mlp_prefix)) {
                layer.mlp.up_proj.set_weight(up_weight)?;
            }
            if let Ok(down_weight) = loader.load_tensor(&format!("{}.down_proj.weight", mlp_prefix))
            {
                layer.mlp.down_proj.set_weight(down_weight)?;
            }

            if let Ok(ln1_weight) =
                loader.load_tensor(&format!("model.layers.{}.input_layernorm.weight", i))
            {
                layer.input_layernorm.set_weight(ln1_weight)?;
            }
            if let Ok(ln2_weight) = loader.load_tensor(&format!(
                "model.layers.{}.post_attention_layernorm.weight",
                i
            )) {
                layer.post_attention_layernorm.set_weight(ln2_weight)?;
            }
        }

        // Final norm and LM head.
        if let Ok(norm_weight) = loader.load_tensor("model.norm.weight") {
            self.model.norm.set_weight(norm_weight)?;
        }

        if let Ok(lm_head_weight) = loader.load_tensor("lm_head.weight") {
            self.lm_head.set_weight(lm_head_weight)?;
        }

        Ok(())
    }
1268
1269 pub fn load_from_huggingface(&mut self, model_name: &str) -> Result<()> {
1271 let cache_dir = std::env::var("HF_HOME")
1273 .or_else(|_| std::env::var("HUGGINGFACE_HUB_CACHE"))
1274 .unwrap_or_else(|_| {
1275 std::env::var("HOME").unwrap_or_else(|_| ".".to_string())
1276 + "/.cache/huggingface/hub"
1277 });
1278
1279 let model_path = std::path::Path::new(&cache_dir)
1280 .join(format!("models--{}", model_name.replace("/", "--")));
1281
1282 if model_path.exists() {
1283 self.load_from_path(&model_path)
1284 } else {
1285 self.download_from_huggingface_hub(model_name, &model_path)?;
1287 self.load_from_path(&model_path)
1288 }
1289 }
1290
1291 fn download_from_huggingface_hub(
1293 &self,
1294 model_name: &str,
1295 model_path: &std::path::Path,
1296 ) -> Result<()> {
1297 use std::process::Command;
1298
1299 println!(
1300 "Downloading model {} from HuggingFace Hub to {:?}",
1301 model_name, model_path
1302 );
1303
1304 std::fs::create_dir_all(model_path).map_err(|e| {
1306 trustformers_core::errors::TrustformersError::io_error(format!(
1307 "Failed to create model directory: {}",
1308 e
1309 ))
1310 })?;
1311
1312 let essential_files = vec![
1314 "config.json",
1315 "tokenizer.json",
1316 "tokenizer_config.json",
1317 "pytorch_model.bin", "model.safetensors", ];
1320
1321 let base_url = format!("https://huggingface.co/{}/resolve/main", model_name);
1322
1323 for file_name in &essential_files {
1325 let file_url = format!("{}/{}", base_url, file_name);
1326 let file_path = model_path.join(file_name);
1327
1328 println!("Attempting to download {}", file_url);
1329
1330 let file_path_str = file_path.to_str().ok_or_else(|| {
1332 TrustformersError::invalid_config(format!("Invalid UTF-8 in path: {:?}", file_path))
1333 })?;
1334
1335 let curl_result = Command::new("curl")
1337 .args([
1338 "-L", "-f", "-o",
1341 file_path_str,
1342 &file_url,
1343 ])
1344 .output();
1345
1346 match curl_result {
1347 Ok(output) if output.status.success() => {
1348 println!("Successfully downloaded {}", file_name);
1349 continue;
1350 },
1351 Ok(output) => {
1352 eprintln!(
1353 "Failed to download {} with curl: {}",
1354 file_name,
1355 String::from_utf8_lossy(&output.stderr)
1356 );
1357 },
1358 Err(e) => {
1359 println!("curl not available: {}", e);
1360 },
1361 }
1362
1363 let wget_result = Command::new("wget").args(["-O", file_path_str, &file_url]).output();
1365
1366 match wget_result {
1367 Ok(output) if output.status.success() => {
1368 println!("Successfully downloaded {} with wget", file_name);
1369 continue;
1370 },
1371 Ok(output) => {
1372 eprintln!(
1373 "Failed to download {} with wget: {}",
1374 file_name,
1375 String::from_utf8_lossy(&output.stderr)
1376 );
1377 },
1378 Err(e) => {
1379 println!("wget not available: {}", e);
1380 },
1381 }
1382
1383 if matches!(file_name, &"config.json" | &"pytorch_model.bin") {
1385 return Err(trustformers_core::errors::TrustformersError::io_error(format!(
1386 "Failed to download essential file {} for model {}. Please ensure curl or wget is installed and you have internet access.",
1387 file_name, model_name
1388 )));
1389 }
1390 }
1391
1392 println!(
1393 "Successfully downloaded model {} to {:?}",
1394 model_name, model_path
1395 );
1396 Ok(())
1397 }
1398
1399 pub fn load_with_lazy_loading(
1401 &mut self,
1402 model_path: impl AsRef<std::path::Path>,
1403 ) -> Result<()> {
1404 use crate::weight_loading::{auto_create_loader, WeightLoadingConfig};
1405
1406 let config = WeightLoadingConfig {
1407 lazy_loading: true,
1408 memory_mapped: true,
1409 streaming: false,
1410 ..Default::default()
1411 };
1412
1413 let _loader = auto_create_loader(&model_path, Some(config))?;
1414
1415 self.load_from_path(model_path)
1422 }
1423}
1424
impl Config for CommandRConfig {
    // Trait-level validation entry point: wraps the config's own validation
    // error into the crate's unified error type.
    fn validate(&self) -> Result<()> {
        // NOTE(review): `self.validate()` relies on Rust preferring the
        // *inherent* `CommandRConfig::validate` over this trait method —
        // presumably the inherent method returns a different error type
        // that `invalid_config` wraps here. Confirm the inherent method
        // exists; without it this call would recurse infinitely.
        self.validate().map_err(|e| invalid_config("config_validation", &e))
    }

    /// Stable architecture identifier for this model family.
    fn architecture(&self) -> &'static str {
        "command-r"
    }
}
1434
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Fast construction smoke tests on the `tiny` config ----

    #[test]
    fn test_command_r_model_creation_tiny() {
        assert!(CommandRModel::new(&CommandRConfig::tiny()).is_ok());
    }

    #[test]
    fn test_command_r_causal_lm_creation_tiny() {
        assert!(CommandRForCausalLM::new(&CommandRConfig::tiny()).is_ok());
    }

    #[test]
    #[ignore = "Forward pass requires proper hidden state input - model's forward method is shadowed by Model trait"]
    fn test_command_r_forward_pass_tiny() {
        let config = CommandRConfig::tiny();
        let model = CommandRModel::new(&config).expect("operation failed");

        // Minimal hidden-state input: [batch=1, seq=4, hidden_size].
        let hidden_states =
            Tensor::zeros(&[1, 4, config.hidden_size]).expect("operation failed");

        let result = model.forward(hidden_states);
        assert!(result.is_ok(), "Forward pass failed: {:?}", result.err());
    }

    #[test]
    fn test_command_r_attention_creation_tiny() {
        assert!(CommandRAttention::new(&CommandRConfig::tiny()).is_ok());
    }

    #[test]
    fn test_command_r_mlp_creation_tiny() {
        assert!(CommandRMLP::new(&CommandRConfig::tiny()).is_ok());
    }

    #[test]
    fn test_command_r_decoder_layer_creation_tiny() {
        assert!(CommandRDecoderLayer::new(&CommandRConfig::tiny()).is_ok());
    }

    #[test]
    fn test_rope_creation() {
        assert!(CommandRRoPE::new(128, 4096, 10000.0).is_ok());
    }

    // ---- Full-size configs: ignored by default (memory/time heavy) ----

    #[test]
    #[ignore = "Full model size test - requires significant memory and time"]
    fn test_command_r_model_creation() {
        assert!(CommandRModel::new(&CommandRConfig::command_r()).is_ok());
    }

    #[test]
    #[ignore = "Full model size test - requires significant memory and time"]
    fn test_command_r_plus_model_creation() {
        assert!(CommandRModel::new(&CommandRConfig::command_r_plus()).is_ok());
    }

    #[test]
    #[ignore = "Full model size test - requires significant memory and time"]
    fn test_command_r_causal_lm_creation() {
        assert!(CommandRForCausalLM::new(&CommandRConfig::command_r()).is_ok());
    }

    #[test]
    #[ignore = "Full model size test - requires significant memory and time"]
    fn test_command_r_forward_pass() {
        let config = CommandRConfig::command_r();
        let model = CommandRModel::new(&config).expect("operation failed");

        let input_ids = Tensor::from_vec_i64(vec![1, 2, 3, 4], &[1, 4]).expect("operation failed");

        assert!(model.forward(input_ids).is_ok());
    }

    #[test]
    #[ignore = "Full model size test - requires significant memory and time"]
    fn test_command_r_attention_creation() {
        assert!(CommandRAttention::new(&CommandRConfig::command_r()).is_ok());
    }

    #[test]
    #[ignore = "Full model size test - requires significant memory and time"]
    fn test_command_r_mlp_creation() {
        assert!(CommandRMLP::new(&CommandRConfig::command_r()).is_ok());
    }

    #[test]
    #[ignore = "Full model size test - requires significant memory and time"]
    fn test_command_r_decoder_layer_creation() {
        assert!(CommandRDecoderLayer::new(&CommandRConfig::command_r()).is_ok());
    }
}