1use crate::falcon::config::FalconConfig;
2use scirs2_core::ndarray::{s, ArrayD, IxDyn}; use std::io::Read;
4use trustformers_core::{
5 device::Device,
6 errors::{tensor_op_error, Result, TrustformersError},
7 layers::{Embedding, LayerNorm, Linear},
8 ops::activations::{gelu, silu},
9 tensor::Tensor,
10 traits::{Config, Layer, Model},
11};
12
/// Attention with Linear Biases (ALiBi) positional encoding.
///
/// Holds one precomputed slope per attention head; the additive bias matrix
/// itself is materialized on demand by `apply_bias`.
pub struct ALiBi {
    /// Per-head slope values (one f32 per head).
    slopes: Tensor,
    /// Number of attention heads the slopes were generated for.
    num_heads: usize,
    /// Device this module is associated with.
    device: Device,
}
20
21impl ALiBi {
22 pub fn new(num_heads: usize) -> Result<Self> {
23 Self::new_with_device(num_heads, Device::CPU)
24 }
25
26 pub fn new_with_device(num_heads: usize, device: Device) -> Result<Self> {
27 let mut slopes = Vec::new();
29 let ratio = 2.0_f32.powf(-8.0 / num_heads as f32);
30
31 if num_heads % 2 == 0 {
32 for i in 0..num_heads / 2 {
34 slopes.push(ratio.powf((2 * i + 1) as f32));
35 }
36 for i in 0..num_heads / 2 {
37 slopes.push(ratio.powf((2 * i + 2) as f32));
38 }
39 } else {
40 for i in 0..num_heads {
42 slopes.push(ratio.powf((i + 1) as f32));
43 }
44 }
45
46 let slopes_tensor = Tensor::new(slopes)?;
47
48 Ok(Self {
49 slopes: slopes_tensor,
50 num_heads,
51 device,
52 })
53 }
54
55 pub fn device(&self) -> Device {
56 self.device
57 }
58
59 pub fn apply_bias(&self, attention_scores: &Tensor, seq_len: usize) -> Result<Tensor> {
61 let mut bias_data = Vec::new();
63
64 for head_idx in 0..self.num_heads {
66 for i in 0..seq_len {
67 for j in 0..seq_len {
68 if j > i {
69 bias_data.push(-10000.0);
71 } else {
72 let distance = (i - j) as f32;
74 let slope = if let Ok(slopes_data) = self.slopes.data() {
75 if head_idx < slopes_data.len() {
76 slopes_data[head_idx]
77 } else {
78 1.0
79 }
80 } else {
81 1.0
82 };
83 bias_data.push(-distance * slope);
84 }
85 }
86 }
87 }
88
89 let bias_tensor = Tensor::from_vec(bias_data, &[seq_len, seq_len])?;
91
92 let biased_scores = attention_scores.add(&bias_tensor)?;
94 Ok(biased_scores)
95 }
96}
97
/// Multi-head self-attention for Falcon.
///
/// Supports grouped-/multi-query attention (fewer KV heads than query heads,
/// see `num_kv_heads`) and an optional ALiBi positional bias.
pub struct FalconAttention {
    /// Query projection: hidden_size -> num_heads * head_dim.
    q_proj: Linear,
    /// Key projection: hidden_size -> num_kv_heads * head_dim.
    k_proj: Linear,
    /// Value projection: hidden_size -> num_kv_heads * head_dim.
    v_proj: Linear,
    /// Output projection back to hidden_size.
    dense: Linear,
    /// Present only when the config enables ALiBi.
    alibi: Option<ALiBi>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    // Stored from config but not yet used in forward().
    #[allow(dead_code)]
    attention_dropout: f32,
    // Stored from config but not yet used in forward().
    #[allow(dead_code)]
    use_flash_attention: bool,
    device: Device,
}
116
117impl FalconAttention {
118 pub fn new(config: &FalconConfig) -> Result<Self> {
119 Self::new_with_device(config, Device::CPU)
120 }
121
122 pub fn new_with_device(config: &FalconConfig, device: Device) -> Result<Self> {
123 let head_dim = config.head_dim();
124 let num_kv_heads = config.num_kv_heads();
125
126 let q_proj = Linear::new(
127 config.hidden_size,
128 config.num_attention_heads * head_dim,
129 config.bias,
130 );
131 let k_proj = Linear::new(config.hidden_size, num_kv_heads * head_dim, config.bias);
132 let v_proj = Linear::new(config.hidden_size, num_kv_heads * head_dim, config.bias);
133 let dense = Linear::new(
134 config.num_attention_heads * head_dim,
135 config.hidden_size,
136 config.bias,
137 );
138
139 let alibi = if config.alibi {
140 Some(ALiBi::new_with_device(config.num_attention_heads, device)?)
141 } else {
142 None
143 };
144
145 Ok(Self {
146 q_proj,
147 k_proj,
148 v_proj,
149 dense,
150 alibi,
151 num_heads: config.num_attention_heads,
152 num_kv_heads,
153 head_dim,
154 attention_dropout: config.attention_dropout,
155 use_flash_attention: config.use_flash_attention.unwrap_or(false),
156 device,
157 })
158 }
159
160 pub fn device(&self) -> Device {
161 self.device
162 }
163
164 fn create_causal_mask(&self, seq_len: usize) -> Result<Tensor> {
166 let mut mask_data = vec![0.0f32; seq_len * seq_len];
168 for i in 0..seq_len {
169 for j in (i + 1)..seq_len {
170 mask_data[i * seq_len + j] = f32::NEG_INFINITY;
171 }
172 }
173 Tensor::from_vec(mask_data, &[seq_len, seq_len])
174 }
175
176 pub fn parameter_count(&self) -> usize {
177 self.q_proj.parameter_count()
178 + self.k_proj.parameter_count()
179 + self.v_proj.parameter_count()
180 + self.dense.parameter_count()
181 }
182}
183
impl Layer for FalconAttention {
    type Input = Tensor;
    type Output = Tensor;

    /// Scaled dot-product self-attention.
    ///
    /// Input is assumed to be [batch, seq, hidden] (inferred from the
    /// reshape calls below). Returns a tensor of the same shape.
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        let batch_size = input.shape()[0];
        let seq_len = input.shape()[1];

        // Project to Q/K/V; K and V may use fewer heads (grouped-query attn).
        let q = self.q_proj.forward(input.clone())?;
        let k = self.k_proj.forward(input.clone())?;
        let v = self.v_proj.forward(input)?;

        // Split the projection outputs into per-head slices:
        // [batch, seq, heads, head_dim].
        let q = q.reshape(&[batch_size, seq_len, self.num_heads, self.head_dim])?;
        let k = k.reshape(&[batch_size, seq_len, self.num_kv_heads, self.head_dim])?;
        let v = v.reshape(&[batch_size, seq_len, self.num_kv_heads, self.head_dim])?;

        // Move the head axis forward: [batch, heads, seq, head_dim].
        let q = q.transpose(1, 2)?;
        let k = k.transpose(1, 2)?;
        let v = v.transpose(1, 2)?;

        // Grouped-query attention: replicate each KV head so every query head
        // has a matching KV head (assumes num_heads % num_kv_heads == 0).
        let (k, v) = if self.num_kv_heads < self.num_heads {
            let repeats = self.num_heads / self.num_kv_heads;

            let mut k_heads = Vec::new();
            let mut v_heads = Vec::new();

            for head_idx in 0..self.num_kv_heads {
                // Slice out one KV head along axis 1 (the head axis).
                let k_head = k.slice_multi(&[
                    (0, batch_size),
                    (head_idx, head_idx + 1),
                    (0, seq_len),
                    (0, self.head_dim),
                ])?;
                let v_head = v.slice_multi(&[
                    (0, batch_size),
                    (head_idx, head_idx + 1),
                    (0, seq_len),
                    (0, self.head_dim),
                ])?;

                for _ in 0..repeats {
                    k_heads.push(k_head.clone());
                    v_heads.push(v_head.clone());
                }
            }

            // Re-assemble along the head axis.
            let k_repeated = Tensor::concat(&k_heads, 1)?;
            let v_repeated = Tensor::concat(&v_heads, 1)?;
            (k_repeated, v_repeated)
        } else {
            (k, v)
        };

        // scores = Q K^T / sqrt(head_dim): [batch, heads, seq, seq].
        let k_transposed = k.transpose(2, 3)?;
        let scores = q.matmul(&k_transposed)?;
        let scale = (self.head_dim as f32).sqrt();
        let scaled_scores = scores.div_scalar(scale)?;

        // Mask is [seq, seq]; assumes `add` broadcasts it over the leading
        // [batch, heads] dimensions — TODO confirm Tensor::add semantics.
        let causal_mask = self.create_causal_mask(seq_len)?;
        let masked_scores = scaled_scores.add(&causal_mask)?;

        let attention_weights = masked_scores.softmax(-1)?;

        // Weighted sum of values: [batch, heads, seq, head_dim].
        let attention_output = attention_weights.matmul(&v)?;

        // Back to [batch, seq, heads * head_dim].
        let attention_output = attention_output.transpose(1, 2)?;
        let attention_output =
            attention_output.reshape(&[batch_size, seq_len, self.num_heads * self.head_dim])?;

        // NOTE(review): ALiBi is applied here to the attention *output*
        // (shape [batch, seq, hidden]), whereas standard ALiBi adds biases to
        // the pre-softmax scores; `apply_bias` also builds a
        // [num_heads, seq, seq] bias, which does not match this shape —
        // confirm intended behavior.
        let biased_output = if let Some(alibi) = &self.alibi {
            alibi.apply_bias(&attention_output, seq_len)?
        } else {
            attention_output
        };

        // Final output projection.
        let output = self.dense.forward(biased_output)?;
        Ok(output)
    }
}
280
/// Falcon feed-forward block: hidden -> 4*hidden -> hidden with a
/// configurable activation in between.
pub struct FalconMLP {
    /// Up-projection to the 4x intermediate size.
    dense_h_to_4h: Linear,
    /// Down-projection back to hidden size.
    dense_4h_to_h: Linear,
    /// Activation name from config ("gelu", "relu", "silu"/"swish");
    /// anything else acts as identity in forward().
    activation: String,
    device: Device,
}
288
289impl FalconMLP {
290 pub fn new(config: &FalconConfig) -> Result<Self> {
291 Self::new_with_device(config, Device::CPU)
292 }
293
294 pub fn new_with_device(config: &FalconConfig, device: Device) -> Result<Self> {
295 let intermediate_size = 4 * config.hidden_size;
296
297 let dense_h_to_4h = Linear::new(config.hidden_size, intermediate_size, config.bias);
298 let dense_4h_to_h = Linear::new(intermediate_size, config.hidden_size, config.bias);
299
300 Ok(Self {
301 dense_h_to_4h,
302 dense_4h_to_h,
303 activation: config.hidden_act.clone(),
304 device,
305 })
306 }
307
308 pub fn device(&self) -> Device {
309 self.device
310 }
311
312 pub fn parameter_count(&self) -> usize {
313 self.dense_h_to_4h.parameter_count() + self.dense_4h_to_h.parameter_count()
314 }
315}
316
317impl Layer for FalconMLP {
318 type Input = Tensor;
319 type Output = Tensor;
320
321 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
322 let hidden = self.dense_h_to_4h.forward(input)?;
323
324 let activated = match self.activation.as_str() {
326 "gelu" => gelu(&hidden)?,
327 "relu" => hidden.relu()?,
328 "silu" | "swish" => silu(&hidden)?,
329 _ => hidden,
330 };
331
332 let output = self.dense_4h_to_h.forward(activated)?;
333 Ok(output)
334 }
335}
336
/// One Falcon decoder block: layer norm, self-attention, and MLP, combined
/// either in parallel (Falcon-style) or sequentially depending on config.
pub struct FalconDecoderLayer {
    /// Pre-attention layer norm (also reused before the MLP in the
    /// sequential path of forward()).
    input_layernorm: LayerNorm,
    self_attention: FalconAttention,
    mlp: FalconMLP,
    /// When true, attention and MLP run on the same normed input and their
    /// outputs are summed (parallel residual).
    parallel_attn: bool,
    /// When true, the residual starts from the layer-norm output instead of
    /// the raw block input.
    apply_residual_connection_post_layernorm: bool,
    device: Device,
}
346
347impl FalconDecoderLayer {
348 pub fn new(config: &FalconConfig) -> Result<Self> {
349 Self::new_with_device(config, Device::CPU)
350 }
351
352 pub fn new_with_device(config: &FalconConfig, device: Device) -> Result<Self> {
353 let input_layernorm = LayerNorm::new(vec![config.hidden_size], config.layer_norm_epsilon)?;
354 let self_attention = FalconAttention::new_with_device(config, device)?;
355 let mlp = FalconMLP::new_with_device(config, device)?;
356
357 Ok(Self {
358 input_layernorm,
359 self_attention,
360 mlp,
361 parallel_attn: config.parallel_attn,
362 apply_residual_connection_post_layernorm: config
363 .apply_residual_connection_post_layernorm,
364 device,
365 })
366 }
367
368 pub fn device(&self) -> Device {
369 self.device
370 }
371
372 pub fn parameter_count(&self) -> usize {
373 self.input_layernorm.parameter_count()
374 + self.self_attention.parameter_count()
375 + self.mlp.parameter_count()
376 }
377}
378
impl Layer for FalconDecoderLayer {
    type Input = Tensor;
    type Output = Tensor;

    /// Runs one decoder block, in either the parallel (Falcon-style) or the
    /// sequential (classic two-residual) formulation.
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        if self.parallel_attn {
            // Parallel form: attention and MLP both read the same normed
            // input; their outputs are added to a single residual.
            let layernorm_output = self.input_layernorm.forward(input.clone())?;

            let attention_output = self.self_attention.forward(layernorm_output.clone())?;
            let mlp_output = self.mlp.forward(layernorm_output.clone())?;

            // Residual either starts from the normed input or the raw input.
            let residual_input = if self.apply_residual_connection_post_layernorm {
                layernorm_output
            } else {
                input
            };

            let output = residual_input.add(&attention_output)?.add(&mlp_output)?;
            Ok(output)
        } else {
            // Sequential form: attention block then MLP block, each with its
            // own residual connection.
            let layernorm_output = self.input_layernorm.forward(input.clone())?;
            let attention_output = self.self_attention.forward(layernorm_output)?;

            let residual_output = input.add(&attention_output)?;

            // NOTE(review): the same `input_layernorm` weights are reused for
            // the pre-MLP normalization; reference Falcon uses a separate
            // post_attention_layernorm here — confirm this is intended.
            let layernorm_output2 = self.input_layernorm.forward(residual_output.clone())?;
            let mlp_output = self.mlp.forward(layernorm_output2)?;

            let output = residual_output.add(&mlp_output)?;
            Ok(output)
        }
    }
}
419
/// Falcon transformer backbone: token embeddings, a stack of decoder layers,
/// and a final layer norm. Produces hidden states, not logits.
pub struct FalconModel {
    /// Token-id -> hidden-vector lookup table.
    word_embeddings: Embedding,
    /// `config.num_hidden_layers` identical decoder blocks.
    layers: Vec<FalconDecoderLayer>,
    /// Final layer norm applied after the last decoder block.
    ln_f: LayerNorm,
    config: FalconConfig,
    device: Device,
}
428
429impl FalconModel {
430 pub fn new(config: FalconConfig) -> Result<Self> {
431 Self::new_with_device(config, Device::CPU)
432 }
433
434 pub fn new_with_device(config: FalconConfig, device: Device) -> Result<Self> {
435 config.validate()?;
436
437 let word_embeddings = Embedding::new(
438 config.vocab_size,
439 config.hidden_size,
440 config.pad_token_id.map(|id| id as usize),
441 )?;
442
443 let mut layers = Vec::new();
444 for _ in 0..config.num_hidden_layers {
445 layers.push(FalconDecoderLayer::new_with_device(&config, device)?);
446 }
447
448 let ln_f = LayerNorm::new(vec![config.hidden_size], config.layer_norm_epsilon)?;
449
450 Ok(Self {
451 word_embeddings,
452 layers,
453 ln_f,
454 config,
455 device,
456 })
457 }
458
459 pub fn device(&self) -> Device {
460 self.device
461 }
462
463 pub fn config(&self) -> &FalconConfig {
464 &self.config
465 }
466}
467
468impl Model for FalconModel {
469 type Config = FalconConfig;
470 type Input = Tensor;
471 type Output = Tensor;
472
473 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
474 Layer::forward(self, input)
475 }
476
477 fn load_pretrained(&mut self, _reader: &mut dyn Read) -> Result<()> {
478 Err(TrustformersError::not_implemented(
480 "Use load_from_path or load_from_huggingface for enhanced weight loading".to_string(),
481 ))
482 }
483
484 fn get_config(&self) -> &Self::Config {
485 &self.config
486 }
487
488 fn num_parameters(&self) -> usize {
489 let embeddings_params = self.word_embeddings.parameter_count();
490 let layers_params: usize = self.layers.iter().map(|layer| layer.parameter_count()).sum();
491 let norm_params = self.ln_f.parameter_count();
492
493 embeddings_params + layers_params + norm_params
494 }
495}
496
497impl Layer for FalconModel {
498 type Input = Tensor;
499 type Output = Tensor;
500
501 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
502 let token_ids = match &input {
504 Tensor::F32(arr) => {
505 arr.iter().map(|&x| x as u32).collect::<Vec<u32>>()
507 },
508 _ => {
509 return Err(tensor_op_error(
510 "tensor_operation",
511 "Input must be F32 tensor",
512 ))
513 },
514 };
515
516 if token_ids.is_empty() {
517 return Err(TrustformersError::model_error(
518 "Empty token_ids provided".to_string(),
519 ));
520 }
521
522 let mut hidden_states = self.word_embeddings.forward(token_ids)?;
523
524 for layer in &self.layers {
526 hidden_states = layer.forward(hidden_states)?;
527 }
528
529 let output = self.ln_f.forward(hidden_states)?;
531 Ok(output)
532 }
533}
534
/// Falcon with a language-modeling head: transformer backbone plus a linear
/// projection from hidden states to vocabulary logits.
pub struct FalconForCausalLM {
    /// The Falcon backbone producing hidden states.
    transformer: FalconModel,
    /// Projection hidden_size -> vocab_size (no bias).
    lm_head: Linear,
    device: Device,
}
541
impl FalconForCausalLM {
    /// Builds the causal-LM model (backbone + LM head) on the CPU.
    pub fn new(config: FalconConfig) -> Result<Self> {
        Self::new_with_device(config, Device::CPU)
    }

    /// Builds the causal-LM model on `device`.
    pub fn new_with_device(config: FalconConfig, device: Device) -> Result<Self> {
        let transformer = FalconModel::new_with_device(config.clone(), device)?;
        let lm_head = Linear::new(
            config.hidden_size,
            config.vocab_size,
            false, // LM head has no bias
        );

        Ok(Self {
            transformer,
            lm_head,
            device,
        })
    }

    /// Device this model is associated with.
    pub fn device(&self) -> Device {
        self.device
    }

    /// Loads weights from a local checkpoint via the crate's weight loaders.
    ///
    /// Every tensor load is wrapped in `if let Ok(..)`, so missing tensors
    /// are skipped silently and the corresponding modules keep their
    /// initialized weights — a partially matching checkpoint will NOT error.
    pub fn load_from_path(&mut self, model_path: impl AsRef<std::path::Path>) -> Result<()> {
        use crate::weight_loading::{auto_create_loader, WeightLoadingConfig};

        let config = WeightLoadingConfig {
            lazy_loading: true,
            memory_mapped: false,
            ..Default::default()
        };

        let mut loader = auto_create_loader(model_path, Some(config))?;

        // Token embedding table.
        if let Ok(embed_weights) = loader.load_tensor("transformer.word_embeddings.weight") {
            self.transformer.word_embeddings.set_weight(embed_weights)?;
        }

        // Per-layer weights, keyed with the HF-style "transformer.h.{i}" names.
        for (i, layer) in self.transformer.layers.iter_mut().enumerate() {
            let attn_prefix = format!("transformer.h.{}.self_attention", i);

            // Falcon checkpoints store Q/K/V fused into one matrix.
            if let Ok(qkv_weight) =
                loader.load_tensor(&format!("{}.query_key_value.weight", attn_prefix))
            {
                match &qkv_weight {
                    Tensor::F32(arr) => {
                        let shape = arr.shape();
                        let combined_size = shape[0];
                        let _hidden_size = shape[1];

                        // NOTE(review): despite the name, this is the row count
                        // of each fused section, not a per-head dimension. The
                        // equal three-way split assumes Q, K, and V sections are
                        // the same size, which does not hold for Falcon's
                        // multi-query fused QKV layout (K/V are smaller) —
                        // TODO confirm against the checkpoint format.
                        let head_dim = combined_size / 3;

                        // Split rows into Q / K / V sections.
                        let q_slice = arr.slice(s![0..head_dim, ..]).to_owned();
                        let k_slice = arr.slice(s![head_dim..2 * head_dim, ..]).to_owned();
                        let v_slice = arr.slice(s![2 * head_dim..3 * head_dim, ..]).to_owned();

                        let q_dyn = q_slice.into_dyn();
                        let k_dyn = k_slice.into_dyn();
                        let v_dyn = v_slice.into_dyn();

                        layer.self_attention.q_proj.set_weight(Tensor::F32(q_dyn))?;
                        layer.self_attention.k_proj.set_weight(Tensor::F32(k_dyn))?;
                        layer.self_attention.v_proj.set_weight(Tensor::F32(v_dyn))?;
                    },
                    _ => {
                        // NOTE(review): for non-F32 tensors the whole fused
                        // matrix is assigned to q_proj only and k/v are left
                        // untouched — looks like a placeholder, confirm.
                        layer.self_attention.q_proj.set_weight(qkv_weight.clone())?;
                    },
                }
            }
            // Attention output projection.
            if let Ok(o_weight) = loader.load_tensor(&format!("{}.dense.weight", attn_prefix)) {
                layer.self_attention.dense.set_weight(o_weight)?;
            }

            let mlp_prefix = format!("transformer.h.{}.mlp", i);

            // MLP up- and down-projections.
            if let Ok(up_weight) =
                loader.load_tensor(&format!("{}.dense_h_to_4h.weight", mlp_prefix))
            {
                layer.mlp.dense_h_to_4h.set_weight(up_weight)?;
            }
            if let Ok(down_weight) =
                loader.load_tensor(&format!("{}.dense_4h_to_h.weight", mlp_prefix))
            {
                layer.mlp.dense_4h_to_h.set_weight(down_weight)?;
            }

            // Input layer norm (scale and bias).
            if let Ok(ln_weight) =
                loader.load_tensor(&format!("transformer.h.{}.input_layernorm.weight", i))
            {
                layer.input_layernorm.set_weight(ln_weight)?;
            }
            if let Ok(ln_bias) =
                loader.load_tensor(&format!("transformer.h.{}.input_layernorm.bias", i))
            {
                layer.input_layernorm.set_bias(ln_bias)?;
            }
        }

        // Final layer norm.
        if let Ok(norm_weight) = loader.load_tensor("transformer.ln_f.weight") {
            self.transformer.ln_f.set_weight(norm_weight)?;
        }
        if let Ok(norm_bias) = loader.load_tensor("transformer.ln_f.bias") {
            self.transformer.ln_f.set_bias(norm_bias)?;
        }

        // LM head projection.
        if let Ok(lm_head_weight) = loader.load_tensor("lm_head.weight") {
            self.lm_head.set_weight(lm_head_weight)?;
        }

        Ok(())
    }

    /// Resolves `model_name` in the local HuggingFace cache (HF_HOME,
    /// HUGGINGFACE_HUB_CACHE, or ~/.cache/huggingface/hub) and loads it,
    /// downloading first if the cache directory does not exist.
    pub fn load_from_huggingface(&mut self, model_name: &str) -> Result<()> {
        let cache_dir = std::env::var("HF_HOME")
            .or_else(|_| std::env::var("HUGGINGFACE_HUB_CACHE"))
            .unwrap_or_else(|_| {
                std::env::var("HOME").unwrap_or_else(|_| ".".to_string())
                    + "/.cache/huggingface/hub"
            });

        // Hub cache layout: "models--{org}--{name}".
        let model_path = std::path::Path::new(&cache_dir)
            .join(format!("models--{}", model_name.replace("/", "--")));

        if model_path.exists() {
            self.load_from_path(&model_path)
        } else {
            self.download_from_huggingface_hub(model_name, &model_path)?;
            self.load_from_path(&model_path)
        }
    }

    /// Best-effort download of a model's essential files from the
    /// HuggingFace Hub by shelling out to `curl`, falling back to `wget`.
    ///
    /// Only `config.json` and `pytorch_model.bin` are treated as mandatory;
    /// failures on the other files are tolerated (e.g. repos that ship only
    /// safetensors or only a PyTorch checkpoint).
    fn download_from_huggingface_hub(
        &self,
        model_name: &str,
        model_path: &std::path::Path,
    ) -> Result<()> {
        use std::process::Command;

        println!(
            "Downloading model {} from HuggingFace Hub to {:?}",
            model_name, model_path
        );

        std::fs::create_dir_all(model_path).map_err(|e| {
            TrustformersError::io_error(format!("Failed to create model directory: {}", e))
        })?;

        let essential_files = vec![
            "config.json",
            "tokenizer.json",
            "tokenizer_config.json",
            "pytorch_model.bin",  // legacy PyTorch checkpoint
            "model.safetensors",  // safetensors checkpoint (may be absent)
        ];

        let base_url = format!("https://huggingface.co/{}/resolve/main", model_name);

        for file_name in &essential_files {
            let file_url = format!("{}/{}", base_url, file_name);
            let file_path = model_path.join(file_name);

            println!("Attempting to download {}", file_url);

            // First attempt: curl (-L follow redirects, -f fail on HTTP errors).
            let curl_result = Command::new("curl")
                .args([
                    "-L",
                    "-f",
                    "-o",
                    file_path.to_str().expect("operation failed"),
                    &file_url,
                ])
                .output();

            match curl_result {
                Ok(output) if output.status.success() => {
                    println!("Successfully downloaded {}", file_name);
                    continue;
                },
                Ok(output) => {
                    eprintln!(
                        "Failed to download {} with curl: {}",
                        file_name,
                        String::from_utf8_lossy(&output.stderr)
                    );
                },
                Err(e) => {
                    println!("curl not available: {}", e);
                },
            }

            // Fallback: wget.
            let wget_result = Command::new("wget")
                .args([
                    "-O",
                    file_path.to_str().expect("operation failed"),
                    &file_url,
                ])
                .output();

            match wget_result {
                Ok(output) if output.status.success() => {
                    println!("Successfully downloaded {} with wget", file_name);
                    continue;
                },
                Ok(output) => {
                    eprintln!(
                        "Failed to download {} with wget: {}",
                        file_name,
                        String::from_utf8_lossy(&output.stderr)
                    );
                },
                Err(e) => {
                    println!("wget not available: {}", e);
                },
            }

            // Both downloaders failed; abort only for the mandatory files.
            if matches!(file_name, &"config.json" | &"pytorch_model.bin") {
                return Err(TrustformersError::io_error(format!(
                    "Failed to download essential file {} for model {}. Please ensure curl or wget is installed and you have internet access.",
                    file_name, model_name
                )));
            }
        }

        println!(
            "Successfully downloaded model {} from HuggingFace Hub",
            model_name
        );
        Ok(())
    }

    /// Alias for `load_from_huggingface`.
    pub fn load_from_hub(&mut self, model_name: &str) -> Result<()> {
        self.load_from_huggingface(model_name)
    }

    /// Greedy (argmax) autoregressive generation up to `max_length` tokens.
    ///
    /// Runs a full forward pass per step (no KV cache) and stops early when
    /// token id 2 is produced — assumed to be EOS; TODO confirm against the
    /// tokenizer configuration.
    pub fn generate(&self, input_ids: Tensor, max_length: usize) -> Result<Tensor> {
        let mut current_ids = input_ids;
        let current_length = current_ids.shape()[current_ids.shape().len() - 1];

        for _ in current_length..max_length {
            let logits = <Self as Model>::forward(self, current_ids.clone())?;

            // Extract the logits of the last sequence position, handling both
            // [batch, seq, vocab] and [seq, vocab] shaped outputs.
            let last_logits = match &logits {
                Tensor::F32(arr) => {
                    let shape = arr.shape();
                    let seq_len = shape[shape.len() - 2];
                    let _vocab_size = shape[shape.len() - 1];

                    let last_token_slice = if shape.len() == 3 {
                        arr.slice(s![0, seq_len - 1, ..])
                    } else {
                        arr.slice(s![seq_len - 1, ..])
                    };
                    last_token_slice.to_owned()
                },
                _ => {
                    return Err(tensor_op_error(
                        "tensor_operation",
                        "Logits must be F32 tensor",
                    ))
                },
            };

            // Greedy pick: index of the maximum logit (NaNs compare Equal).
            let next_token_id = last_logits
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .map(|(idx, _)| idx as u32)
                .ok_or_else(|| {
                    TrustformersError::model_error("Failed to find next token".to_string())
                })?;

            // Assumed EOS token id — stop before appending it.
            if next_token_id == 2 {
                break;
            }

            // Append the new token by copying into a tensor one element
            // longer along the last (sequence) axis.
            current_ids = match &current_ids {
                Tensor::F32(arr) => {
                    let mut new_shape = arr.shape().to_vec();
                    let last_idx = new_shape.len() - 1;
                    new_shape[last_idx] += 1;

                    let mut new_arr = ArrayD::<f32>::zeros(IxDyn(&new_shape));

                    if arr.ndim() == 2 {
                        // Copy each batch row, then write the new token at the
                        // end of that row.
                        for i in 0..arr.shape()[0] {
                            for j in 0..arr.shape()[1] {
                                new_arr[[i, j]] = arr[[i, j]];
                            }
                            new_arr[[i, arr.shape()[1]]] = next_token_id as f32;
                        }
                    } else if arr.ndim() == 1 {
                        for i in 0..arr.shape()[0] {
                            new_arr[[i]] = arr[[i]];
                        }
                        new_arr[[arr.shape()[0]]] = next_token_id as f32;
                    }
                    // NOTE(review): ndim > 2 silently yields a zero-filled
                    // tensor — confirm inputs are always rank 1 or 2.

                    Tensor::F32(new_arr)
                },
                _ => {
                    return Err(tensor_op_error(
                        "tensor_operation",
                        "Input must be F32 tensor",
                    ))
                },
            };
        }

        Ok(current_ids)
    }

    /// Borrow the underlying transformer backbone.
    pub fn model(&self) -> &FalconModel {
        &self.transformer
    }
}
892
893impl Model for FalconForCausalLM {
894 type Config = FalconConfig;
895 type Input = Tensor;
896 type Output = Tensor;
897
898 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
899 Layer::forward(self, input)
900 }
901
902 fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
903 self.transformer.load_pretrained(reader)
904 }
905
906 fn get_config(&self) -> &Self::Config {
907 self.transformer.get_config()
908 }
909
910 fn num_parameters(&self) -> usize {
911 self.transformer.num_parameters() + self.lm_head.parameter_count()
912 }
913}
914
915impl Layer for FalconForCausalLM {
916 type Input = Tensor;
917 type Output = Tensor;
918
919 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
920 let hidden_states = Layer::forward(&self.transformer, input)?;
921 let logits = self.lm_head.forward(hidden_states)?;
922 Ok(logits)
923 }
924}
925
#[cfg(test)]
mod tests {
    use super::*;

    // Ignored by default — presumably because instantiating the full 7B
    // configuration allocates too much memory for routine test runs; confirm.
    #[test]
    #[ignore]
    fn test_falcon_model_creation() {
        let config = FalconConfig::falcon_7b();
        let model = FalconModel::new(config);
        assert!(model.is_ok());
    }

    // Same rationale as above: full 7B causal-LM construction.
    #[test]
    #[ignore]
    fn test_falcon_causal_lm_creation() {
        let config = FalconConfig::falcon_7b();
        let model = FalconForCausalLM::new(config);
        assert!(model.is_ok());
    }

    // Spot-checks the preset configs for the three published Falcon sizes.
    #[test]
    fn test_falcon_config_variants() {
        let config_7b = FalconConfig::falcon_7b();
        assert_eq!(config_7b.hidden_size, 4544);
        assert_eq!(config_7b.num_hidden_layers, 32);
        assert!(config_7b.uses_alibi());

        let config_40b = FalconConfig::falcon_40b();
        assert_eq!(config_40b.hidden_size, 8192);
        assert_eq!(config_40b.num_hidden_layers, 60);
        assert!(config_40b.uses_alibi());

        // 180B uses the newer architecture without ALiBi.
        let config_180b = FalconConfig::falcon_180b();
        assert_eq!(config_180b.hidden_size, 14848);
        assert_eq!(config_180b.num_hidden_layers, 80);
        assert!(!config_180b.uses_alibi());
        assert!(config_180b.uses_new_architecture());
    }

    // ALiBi construction records the requested head count.
    #[test]
    fn test_alibi_creation() {
        let alibi = ALiBi::new(8);
        assert!(alibi.is_ok());

        let alibi = alibi.expect("operation failed");
        assert_eq!(alibi.num_heads, 8);
    }

    #[test]
    fn test_falcon_attention_creation() {
        let config = FalconConfig::falcon_7b();
        let attention = FalconAttention::new(&config);
        assert!(attention.is_ok());
    }

    #[test]
    fn test_falcon_mlp_creation() {
        let config = FalconConfig::falcon_7b();
        let mlp = FalconMLP::new(&config);
        assert!(mlp.is_ok());
    }
}