1use crate::s4::config::S4Config;
2use scirs2_core::ndarray::{Array1, Array2}; use scirs2_core::Complex64; use std::f32::consts::PI;
5use trustformers_core::{
6 device::Device,
7 errors::{
8 compute_error, invalid_format, invalid_input, runtime_error, tensor_op_error, Result,
9 },
10 layers::{Embedding, LayerNorm, Linear},
11 ops::activations::gelu,
12 tensor::Tensor,
13 traits::{Layer, Model},
14};
15
/// Families of HiPPO ("High-order Polynomial Projection Operator") matrices
/// used to initialize the S4 state-transition matrix A (see [`HiPPOMatrix::initialize`]).
#[derive(Debug, Clone)]
pub enum HiPPOMatrix {
    /// Legendre (scaled) variant; produces a skew-symmetric matrix here.
    LEGS,
    /// Legendre (translated) variant: lower triangle of ones, negative diagonal.
    LEGT,
    /// Laguerre (translated) variant: alternating-sign lower triangle.
    LAGT,
    /// Fourier-style initialization with alternating-sign off-diagonals.
    Fourier,
    /// Random skew-symmetric initialization (non-deterministic).
    Random,
}
31
32impl HiPPOMatrix {
33 pub fn initialize(&self, n: usize) -> Array2<f32> {
35 match self {
36 HiPPOMatrix::LEGS => self.init_legs(n),
37 HiPPOMatrix::LEGT => self.init_legt(n),
38 HiPPOMatrix::LAGT => self.init_lagt(n),
39 HiPPOMatrix::Fourier => self.init_fourier(n),
40 HiPPOMatrix::Random => self.init_random(n),
41 }
42 }
43
44 fn init_legs(&self, n: usize) -> Array2<f32> {
45 let mut a = Array2::<f32>::zeros((n, n));
47 for i in 0..n {
48 for j in 0..=i {
49 let val = if i == j {
50 0.0
51 } else if i > j {
52 (2.0 * i as f32 + 1.0).sqrt() * (2.0 * j as f32 + 1.0).sqrt()
53 } else {
54 0.0
55 };
56 a[[i, j]] = val;
57 }
58 }
59 &a - &a.t()
61 }
62
63 fn init_legt(&self, n: usize) -> Array2<f32> {
64 let mut a = Array2::<f32>::zeros((n, n));
66 for i in 0..n {
67 for j in 0..n {
68 if i > j {
69 a[[i, j]] = 1.0;
70 } else if i == j {
71 a[[i, j]] = -(2.0 * i as f32 + 1.0) / 2.0;
72 }
73 }
74 }
75 a
76 }
77
78 fn init_lagt(&self, n: usize) -> Array2<f32> {
79 let mut a = Array2::<f32>::zeros((n, n));
81 for i in 0..n {
82 for j in 0..n {
83 if i > j {
84 a[[i, j]] = (-1.0_f32).powi((i - j) as i32);
85 } else if i == j {
86 a[[i, j]] = -0.5;
87 }
88 }
89 }
90 a
91 }
92
93 fn init_fourier(&self, n: usize) -> Array2<f32> {
94 let mut a = Array2::<f32>::zeros((n, n));
96 for i in 0..n {
97 for j in 0..n {
98 if i == j {
99 a[[i, j]] = 0.0;
100 } else {
101 let sign = if (i + j) % 2 == 0 { 1.0 } else { -1.0 };
102 a[[i, j]] = sign * PI * (i as f32 - j as f32);
103 }
104 }
105 }
106 a
107 }
108
109 #[allow(deprecated)]
110 fn init_random(&self, n: usize) -> Array2<f32> {
111 use scirs2_core::random::*; let mut rng = thread_rng();
114 let mut a = Array2::<f32>::zeros((n, n));
115 for i in 0..n {
116 for j in 0..i {
117 let val = rng.random_range(-1.0..1.0);
118 a[[i, j]] = val;
119 a[[j, i]] = -val; }
121 }
122 a
123 }
124}
125
/// Schemes for converting the continuous-time SSM pair (A, B) into a
/// discrete-time pair (A_bar, B_bar) for a given step size `dt`.
#[derive(Debug, Clone)]
pub enum Discretization {
    /// Zero-order hold.
    ZOH,
    /// Bilinear (Tustin) transform.
    Bilinear,
    /// Forward (explicit) Euler.
    Euler,
    /// Backward (implicit) Euler.
    BackwardEuler,
}
138
139impl Discretization {
140 pub fn discretize(
142 &self,
143 a: &Array2<f32>,
144 b: &Array1<f32>,
145 dt: f32,
146 ) -> (Array2<f32>, Array1<f32>) {
147 match self {
148 Discretization::ZOH => self.zoh_discretize(a, b, dt),
149 Discretization::Bilinear => self.bilinear_discretize(a, b, dt),
150 Discretization::Euler => self.euler_discretize(a, b, dt),
151 Discretization::BackwardEuler => self.backward_euler_discretize(a, b, dt),
152 }
153 }
154
155 fn zoh_discretize(
156 &self,
157 a: &Array2<f32>,
158 b: &Array1<f32>,
159 dt: f32,
160 ) -> (Array2<f32>, Array1<f32>) {
161 let n = a.nrows();
164 let eye = Array2::<f32>::eye(n);
165
166 let a_bar = &eye + a * dt;
168 let b_bar = b * dt;
169
170 (a_bar, b_bar)
171 }
172
173 fn bilinear_discretize(
174 &self,
175 a: &Array2<f32>,
176 b: &Array1<f32>,
177 dt: f32,
178 ) -> (Array2<f32>, Array1<f32>) {
179 let n = a.nrows();
181 let eye = Array2::<f32>::eye(n);
182 let _half_dt = dt / 2.0;
183
184 let a_bar = &eye + a * dt;
186 let b_bar = b * dt;
187
188 (a_bar, b_bar)
189 }
190
191 fn euler_discretize(
192 &self,
193 a: &Array2<f32>,
194 b: &Array1<f32>,
195 dt: f32,
196 ) -> (Array2<f32>, Array1<f32>) {
197 let n = a.nrows();
199 let eye = Array2::<f32>::eye(n);
200
201 let a_bar = &eye + a * dt;
202 let b_bar = b * dt;
203
204 (a_bar, b_bar)
205 }
206
207 fn backward_euler_discretize(
208 &self,
209 a: &Array2<f32>,
210 b: &Array1<f32>,
211 dt: f32,
212 ) -> (Array2<f32>, Array1<f32>) {
213 self.euler_discretize(a, b, dt)
216 }
217}
218
219pub struct S4Layer {
221 #[allow(dead_code)]
222 config: S4Config,
223 a_real: Array2<f32>, a_imag: Array2<f32>, b_real: Array1<f32>, b_imag: Array1<f32>, c_real: Array1<f32>, c_imag: Array1<f32>, d: Array1<f32>, dt: Array1<f32>, a_bar: Option<Array2<Complex64>>,
234 b_bar: Option<Array1<Complex64>>,
235 device: Device,
237}
238
239impl S4Layer {
240 pub fn new_with_device(config: &S4Config, device: Device) -> Result<Self> {
241 let n = config.d_state;
242 let h = config.get_n_ssm();
243
244 let hippo = match config.hippo_matrix.as_str() {
246 "legs" => HiPPOMatrix::LEGS,
247 "legt" => HiPPOMatrix::LEGT,
248 "lagt" => HiPPOMatrix::LAGT,
249 "fourier" => HiPPOMatrix::Fourier,
250 "random" => HiPPOMatrix::Random,
251 _ => HiPPOMatrix::LEGS,
252 };
253
254 let a_base = hippo.initialize(n);
255
256 let a_real = a_base.clone();
259 let a_imag = Array2::<f32>::zeros((n, n));
260
261 let b_real = Array1::<f32>::ones(n) / (n as f32).sqrt();
263 let b_imag = Array1::<f32>::zeros(n);
264 let c_real = Array1::<f32>::ones(n) / (n as f32).sqrt();
265 let c_imag = Array1::<f32>::zeros(n);
266 let d = Array1::<f32>::ones(h);
267
268 let dt = Array1::<f32>::from_elem(h, config.dt);
270
271 Ok(Self {
272 config: config.clone(),
273 a_real,
274 a_imag,
275 b_real,
276 b_imag,
277 c_real,
278 c_imag,
279 d,
280 dt,
281 a_bar: None,
282 b_bar: None,
283 device,
284 })
285 }
286
287 pub fn new(config: &S4Config) -> Result<Self> {
288 Self::new_with_device(config, Device::CPU)
289 }
290
291 pub fn device(&self) -> Device {
292 self.device
293 }
294
295 #[allow(dead_code)]
297 fn discretize(&mut self) -> Result<()> {
298 let disc = match self.config.discretization.as_str() {
299 "zoh" => Discretization::ZOH,
300 "bilinear" => Discretization::Bilinear,
301 "euler" => Discretization::Euler,
302 "backward_euler" => Discretization::BackwardEuler,
303 _ => Discretization::ZOH,
304 };
305
306 let dt_avg = self.dt.mean().unwrap_or(self.config.dt);
308
309 let (a_bar_real, b_bar_real) = disc.discretize(&self.a_real, &self.b_real, dt_avg);
311
312 let n = self.config.d_state;
314 let mut a_bar_complex = Array2::<Complex64>::zeros((n, n));
315 let mut b_bar_complex = Array1::<Complex64>::zeros(n);
316
317 for i in 0..n {
318 for j in 0..n {
319 a_bar_complex[[i, j]] = Complex64::new(
320 a_bar_real[[i, j]] as f64,
321 self.a_imag[[i, j]] as f64 * dt_avg as f64,
322 );
323 }
324 b_bar_complex[i] =
325 Complex64::new(b_bar_real[i] as f64, self.b_imag[i] as f64 * dt_avg as f64);
326 }
327
328 self.a_bar = Some(a_bar_complex);
329 self.b_bar = Some(b_bar_complex);
330
331 Ok(())
332 }
333
334 #[allow(dead_code)]
336 fn apply_s4(&self, input: &Array2<f32>) -> Result<Array2<f32>> {
337 let (batch_size, seq_len) = (input.nrows(), input.ncols());
338 let _h = self.config.get_n_ssm();
339
340 let mut state = Array1::<Complex64>::zeros(self.config.d_state);
342 let mut output = Array2::<f32>::zeros((batch_size, seq_len));
343
344 let a_bar = self.a_bar.as_ref().ok_or_else(|| runtime_error("S4 layer not discretized"))?;
346 let b_bar = self.b_bar.as_ref().ok_or_else(|| runtime_error("S4 layer not discretized"))?;
347
348 for t in 0..seq_len {
350 let u_t = input.column(t);
352
353 for i in 0..self.config.d_state {
355 let mut new_state = Complex64::new(0.0, 0.0);
356 for j in 0..self.config.d_state {
357 new_state += a_bar[[i, j]] * state[j];
358 }
359 new_state += b_bar[i] * u_t.mean().unwrap_or(0.0) as f64;
360 state[i] = new_state;
361 }
362
363 let mut y_t = 0.0;
365 for i in 0..self.config.d_state {
366 y_t += (self.c_real[i] as f64 * state[i].re - self.c_imag[i] as f64 * state[i].im)
367 as f32;
368 }
369
370 y_t += self.d[0] * u_t.mean().unwrap_or(0.0);
372
373 for b in 0..batch_size {
375 output[[b, t]] = y_t;
376 }
377 }
378
379 Ok(output)
380 }
381
382 fn parameter_count(&self) -> usize {
383 let mut total = 0;
384
385 total += self.a_real.len(); total += self.a_imag.len(); total += self.b_real.len(); total += self.b_imag.len(); total += self.c_real.len(); total += self.c_imag.len(); total += self.d.len(); total += self.dt.len(); total
396 }
397}
398
/// One S4 residual block: pre-LayerNorm, input projection, S4 core,
/// output projection, optional activation, and a residual connection.
pub struct S4Block {
    /// Block configuration (activation choice, bias usage, dropout rate, ...).
    config: S4Config,
    /// The state-space core.
    s4_layer: S4Layer,
    /// Pre-normalization applied before the projections.
    norm: LayerNorm,
    /// d_model -> n_ssm projection.
    in_proj: Linear,
    /// n_ssm -> d_model projection.
    out_proj: Linear,
    /// Dropout probability; stored but not applied in `forward`.
    #[allow(dead_code)]
    dropout: f32,
    /// Device the block's parameters live on.
    device: Device,
}
410
411impl S4Block {
412 pub fn new_with_device(config: &S4Config, device: Device) -> Result<Self> {
413 let d_model = config.d_model;
414 let n_ssm = config.get_n_ssm();
415
416 let s4_layer = S4Layer::new_with_device(config, device)?;
417 let norm = LayerNorm::new_with_device(vec![d_model], config.layer_norm_eps, device)?;
418 let in_proj = Linear::new_with_device(d_model, n_ssm, config.use_bias, device);
419 let out_proj = Linear::new_with_device(n_ssm, d_model, config.use_bias, device);
420
421 Ok(Self {
422 config: config.clone(),
423 s4_layer,
424 norm,
425 in_proj,
426 out_proj,
427 dropout: config.dropout,
428 device,
429 })
430 }
431
432 pub fn new(config: &S4Config) -> Result<Self> {
433 Self::new_with_device(config, Device::CPU)
434 }
435
436 pub fn device(&self) -> Device {
437 self.device
438 }
439}
440
impl Layer for S4Block {
    type Input = Tensor;
    type Output = Tensor;

    /// Pre-norm residual forward pass:
    /// `out = residual + postact(out_proj(s4(in_proj(norm(input)))))`.
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        // Keep the raw input for the residual connection.
        let residual = input.clone();

        let normed = self.norm.forward(input)?;

        let projected = self.in_proj.forward(normed)?;

        let s4_out = match &projected {
            Tensor::F32(arr) => {
                // `a_bar` is only ever set by `S4Layer::discretize`, which
                // nothing in this file calls — so for a freshly constructed
                // block this early return is taken and the input passes
                // through unchanged.
                if self.s4_layer.a_bar.is_none() {
                    return Ok(residual);
                }

                let shape = arr.shape();
                if shape.len() == 3 {
                    let batch = shape[0];
                    let seq_len = shape[1];
                    let channels = shape[2];

                    // NOTE(review): placeholder — the S4 kernel is not
                    // applied; the output is a constant 0.1 tensor flattened
                    // to (batch * seq_len, channels). Confirm before relying
                    // on numerical outputs.
                    let mut result = Array2::<f32>::zeros((batch * seq_len, channels));
                    result.fill(0.1); Tensor::F32(result.into_dyn())
                } else {
                    // Non-3D inputs bypass the S4 core entirely.
                    projected.clone()
                }
            },
            _ => {
                return Err(tensor_op_error(
                    "tensor_operation",
                    "Unsupported tensor type".to_string(),
                ))
            },
        };

        let output = self.out_proj.forward(s4_out)?;

        // NOTE(review): the "glu" setting applies GELU, not a gated linear
        // unit — confirm whether this naming is intentional.
        let activated = match self.config.postact.as_str() {
            "glu" => {
                gelu(&output)?
            },
            _ => output,
        };

        // Residual add; both sides must be F32 tensors.
        match (&residual, &activated) {
            (Tensor::F32(r), Tensor::F32(a)) => Ok(Tensor::F32(r + a)),
            _ => Err(tensor_op_error(
                "tensor_operation",
                "Unsupported tensor type".to_string(),
            )),
        }
    }
}
514
515impl S4Block {
516 pub fn parameter_count(&self) -> usize {
517 let mut total = 0;
518
519 total += self.s4_layer.parameter_count();
521
522 total += self.norm.parameter_count();
524
525 total += self.in_proj.parameter_count();
527 total += self.out_proj.parameter_count();
528
529 total
530 }
531}
532
/// Full S4 backbone: token embeddings, a stack of S4 residual blocks, and a
/// final layer norm.
pub struct S4Model {
    /// Model hyperparameters.
    pub config: S4Config,
    /// Token embedding table (vocab_size x d_model).
    pub embeddings: Embedding,
    /// Stacked S4 residual blocks.
    pub blocks: Vec<S4Block>,
    /// Final layer normalization applied after the block stack.
    pub ln_f: LayerNorm,
    /// Device the model was created on.
    pub device: Device,
}
541
542impl S4Model {
543 pub fn new_with_device(config: S4Config, device: Device) -> Result<Self> {
544 let embeddings =
545 Embedding::new_with_device(config.vocab_size, config.d_model, None, device)?;
546
547 let mut blocks = Vec::new();
548 for _ in 0..config.n_layer {
549 if let Ok(block) = S4Block::new_with_device(&config, device) {
550 blocks.push(block);
551 }
552 }
553
554 let ln_f = LayerNorm::new_with_device(vec![config.d_model], config.layer_norm_eps, device)?;
555
556 Ok(Self {
557 config,
558 embeddings,
559 blocks,
560 ln_f,
561 device,
562 })
563 }
564
565 pub fn new(config: S4Config) -> Result<Self> {
566 Self::new_with_device(config, Device::CPU)
567 }
568
569 pub fn device(&self) -> Device {
570 self.device
571 }
572}
573
impl Model for S4Model {
    type Config = S4Config;
    type Input = Tensor;
    type Output = Tensor;

    /// Embed token ids, run the block stack, and apply the final norm.
    ///
    /// Accepts `Tensor::I64` ids shaped `(batch, seq)` or `(seq,)` (treated
    /// as batch size 1); any other tensor type or rank is an error.
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        let (batch_size, seq_len, input_ids) = match &input {
            Tensor::I64(ref arr) => {
                if arr.ndim() == 2 {
                    let batch_size = arr.shape()[0];
                    let seq_len = arr.shape()[1];
                    // NOTE(review): i64 -> u32 via `as` silently truncates
                    // negative or oversized ids — confirm inputs are valid.
                    let ids = arr.mapv(|x| x as u32).into_raw_vec_and_offset().0;
                    (batch_size, seq_len, ids)
                } else if arr.ndim() == 1 {
                    let seq_len = arr.len();
                    let ids = arr.mapv(|x| x as u32).into_raw_vec_and_offset().0;
                    (1, seq_len, ids)
                } else {
                    return Err(tensor_op_error(
                        "tensor_operation",
                        "Input tensor must be 1D or 2D".to_string(),
                    ));
                }
            },
            _ => {
                return Err(tensor_op_error(
                    "tensor_operation",
                    "Unsupported tensor type".to_string(),
                ))
            },
        };

        let embedded = self.embeddings.forward(input_ids)?;

        // If the embedding comes back 2-D (tokens, d_model), restore the
        // (batch, seq, d_model) layout when the token count matches;
        // otherwise fall back to a single batch of all tokens.
        let mut hidden = if embedded.shape().len() == 2 {
            let total_tokens = embedded.shape()[0];
            let d_model = embedded.shape()[1];
            if total_tokens == batch_size * seq_len {
                embedded.reshape(&[batch_size, seq_len, d_model])?
            } else {
                embedded.reshape(&[1, total_tokens, d_model])?
            }
        } else {
            embedded
        };

        for block in &self.blocks {
            hidden = block.forward(hidden)?;
        }

        self.ln_f.forward(hidden)
    }

    /// Read an entire weight stream into memory and parse it.
    fn load_pretrained(&mut self, reader: &mut dyn std::io::Read) -> Result<()> {
        // Local import shadows the top-of-file one; kept as-is.
        use trustformers_core::errors::invalid_input;

        let mut buffer = Vec::new();
        reader
            .read_to_end(&mut buffer)
            .map_err(|e| invalid_input(format!("Failed to read S4 weights: {}", e)))?;

        if buffer.is_empty() {
            return Err(invalid_input("S4 weight file is empty"));
        }

        self.load_weights_from_buffer(&buffer)
    }

    fn get_config(&self) -> &Self::Config {
        &self.config
    }

    /// Sum of parameter counts over embeddings, all blocks, and the final norm.
    fn num_parameters(&self) -> usize {
        let mut total = 0;

        total += self.embeddings.parameter_count();

        for block in &self.blocks {
            total += block.parameter_count();
        }

        total += self.ln_f.parameter_count();

        total
    }
}
670
impl S4Model {
    /// Parse and validate an S4 checkpoint held entirely in `buffer`.
    ///
    /// Layout: 4-byte magic `0x53344D4C`, 4-byte version (must be <= 1),
    /// 4-byte metadata length, a JSON metadata blob, then the tensor
    /// sections for embeddings, each block, and the final norm. All
    /// integers are little-endian.
    ///
    /// NOTE(review): the tensor sections are currently only validated and
    /// skipped — no parsed weights are installed into the model (see
    /// `load_embedding_weights` / `validate_and_skip_tensor`).
    fn load_weights_from_buffer(&mut self, buffer: &[u8]) -> Result<()> {
        // Header is three u32 fields = 12 bytes minimum.
        if buffer.len() < 12 {
            return Err(invalid_input(
                "S4 weight file too small to contain valid header",
            ));
        }

        let mut offset = 0;

        let magic = u32::from_le_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]);
        offset += 4;

        if magic != 0x53344D4C {
            return Err(invalid_format(
                "S4 magic number 0x53344D4C",
                format!("0x{:08X}", magic),
            ));
        }

        let version = u32::from_le_bytes([
            buffer[offset],
            buffer[offset + 1],
            buffer[offset + 2],
            buffer[offset + 3],
        ]);
        offset += 4;

        if version > 1 {
            return Err(invalid_format("S4 version ≤ 1", version.to_string()));
        }

        let metadata_size = u32::from_le_bytes([
            buffer[offset],
            buffer[offset + 1],
            buffer[offset + 2],
            buffer[offset + 3],
        ]) as usize;
        offset += 4;

        if buffer.len() < offset + metadata_size {
            return Err(invalid_input("Insufficient data for metadata"));
        }

        // Metadata is a UTF-8 JSON document.
        let metadata_bytes = &buffer[offset..offset + metadata_size];
        let metadata_str = std::str::from_utf8(metadata_bytes)
            .map_err(|e| invalid_input(format!("Invalid UTF-8 in metadata: {}", e)))?;

        let metadata: serde_json::Value = serde_json::from_str(metadata_str)
            .map_err(|e| invalid_input(format!("Invalid JSON in metadata: {}", e)))?;

        offset += metadata_size;

        // A missing "config" object is tolerated; a present one must match.
        if let Some(config_obj) = metadata.get("config") {
            self.validate_config_compatibility(config_obj)?;
        }

        offset = self.load_embedding_weights(buffer, offset)?;
        offset = self.load_block_weights(buffer, offset)?;
        offset = self.load_final_norm_weights(buffer, offset)?;

        // Trailing bytes are not an error, only a diagnostic.
        if offset != buffer.len() {
            eprintln!(
                "Warning: S4 weight file contains unused data ({} bytes remaining)",
                buffer.len() - offset
            );
        }

        Ok(())
    }

    /// Compare `d_model`, `n_layer`, and `d_state` in the checkpoint's
    /// metadata against this model's config; keys absent from the metadata
    /// are skipped.
    fn validate_config_compatibility(&self, config_obj: &serde_json::Value) -> Result<()> {
        if let Some(d_model) = config_obj.get("d_model").and_then(|v| v.as_u64()) {
            if d_model as usize != self.config.d_model {
                return Err(compute_error(
                    "model_loading",
                    format!(
                        "Model dimension mismatch: expected {}, found {}",
                        self.config.d_model, d_model
                    ),
                ));
            }
        }

        if let Some(n_layer) = config_obj.get("n_layer").and_then(|v| v.as_u64()) {
            if n_layer as usize != self.config.n_layer {
                return Err(compute_error(
                    "model_loading",
                    format!(
                        "Layer count mismatch: expected {}, found {}",
                        self.config.n_layer, n_layer
                    ),
                ));
            }
        }

        if let Some(d_state) = config_obj.get("d_state").and_then(|v| v.as_u64()) {
            if d_state as usize != self.config.d_state {
                return Err(compute_error(
                    "model_loading",
                    format!(
                        "State dimension mismatch: expected {}, found {}",
                        self.config.d_state, d_state
                    ),
                ));
            }
        }

        Ok(())
    }

    /// Read the embedding section: u32 byte-length header followed by
    /// `vocab_size * d_model` little-endian f32 values.
    ///
    /// NOTE(review): the parsed matrix is bound to `_weight_array` and then
    /// dropped — the values never reach `self.embeddings`. Confirm whether
    /// an embedding setter is still missing.
    fn load_embedding_weights(&mut self, buffer: &[u8], mut offset: usize) -> Result<usize> {
        if buffer.len() < offset + 4 {
            return Err(invalid_input("Insufficient data for embedding weights"));
        }

        let weight_size = u32::from_le_bytes([
            buffer[offset],
            buffer[offset + 1],
            buffer[offset + 2],
            buffer[offset + 3],
        ]) as usize;
        offset += 4;

        // Sizes throughout are byte counts: 4 bytes per f32 element.
        let expected_size = self.config.vocab_size * self.config.d_model * 4; if weight_size != expected_size {
            return Err(invalid_format(
                format!("embedding weight size {}", expected_size),
                weight_size.to_string(),
            ));
        }

        if buffer.len() < offset + weight_size {
            return Err(invalid_input(
                "Insufficient data for embedding weight tensor",
            ));
        }

        let weight_bytes = &buffer[offset..offset + weight_size];

        let mut weights = Vec::with_capacity(self.config.vocab_size * self.config.d_model);
        for chunk in weight_bytes.chunks_exact(4) {
            let value = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
            weights.push(value);
        }

        let _weight_array =
            Array2::from_shape_vec((self.config.vocab_size, self.config.d_model), weights)
                .map_err(|e| {
                    runtime_error(format!("Failed to reshape embedding weights: {}", e))
                })?;

        offset += weight_size;

        Ok(offset)
    }

    /// Walk the per-block sections in layer order.
    fn load_block_weights(&mut self, buffer: &[u8], mut offset: usize) -> Result<usize> {
        for block_idx in 0..self.config.n_layer {
            offset = self.load_single_block_weights(buffer, offset, block_idx)?;
        }
        Ok(offset)
    }

    /// One block = SSM parameters, layer norm, in-projection, out-projection.
    ///
    /// NOTE(review): the projections are validated as d_model -> 2*d_model
    /// and d_model -> d_model, but `S4Block::new_with_device` constructs
    /// them as d_model -> `get_n_ssm()` and `get_n_ssm()` -> d_model —
    /// confirm the file format whenever n_ssm != 2*d_model.
    fn load_single_block_weights(
        &mut self,
        buffer: &[u8],
        mut offset: usize,
        _block_idx: usize,
    ) -> Result<usize> {
        offset = self.load_state_space_parameters(buffer, offset)?;

        offset = self.load_layer_norm_weights(buffer, offset)?;

        offset =
            self.load_linear_weights(buffer, offset, self.config.d_model, self.config.d_model * 2)?;

        offset =
            self.load_linear_weights(buffer, offset, self.config.d_model, self.config.d_model)?;

        Ok(offset)
    }

    /// SSM tensors: A (real+imag, d_state x d_state), B and C (real+imag,
    /// d_state each), then D and dt.
    ///
    /// NOTE(review): D and dt are sized by `d_model` here, but
    /// `S4Layer::new_with_device` allocates them with `get_n_ssm()` —
    /// confirm the two agree for every config.
    fn load_state_space_parameters(&mut self, buffer: &[u8], mut offset: usize) -> Result<usize> {
        let a_size = self.config.d_state * self.config.d_state * 4; offset = self.validate_and_skip_tensor(buffer, offset, a_size, "A matrix real part")?;
        offset =
            self.validate_and_skip_tensor(buffer, offset, a_size, "A matrix imaginary part")?;

        let b_size = self.config.d_state * 4; offset = self.validate_and_skip_tensor(buffer, offset, b_size, "B vector real part")?;
        offset =
            self.validate_and_skip_tensor(buffer, offset, b_size, "B vector imaginary part")?;

        let c_size = self.config.d_state * 4; offset = self.validate_and_skip_tensor(buffer, offset, c_size, "C vector real part")?;
        offset =
            self.validate_and_skip_tensor(buffer, offset, c_size, "C vector imaginary part")?;

        let d_size = self.config.d_model * 4; offset = self.validate_and_skip_tensor(buffer, offset, d_size, "D vector")?;

        let dt_size = self.config.d_model * 4; offset = self.validate_and_skip_tensor(buffer, offset, dt_size, "dt parameter")?;

        Ok(offset)
    }

    /// LayerNorm section: weight then bias, both of length d_model.
    fn load_layer_norm_weights(&self, buffer: &[u8], mut offset: usize) -> Result<usize> {
        let weight_size = self.config.d_model * 4; offset = self.validate_and_skip_tensor(buffer, offset, weight_size, "LayerNorm weight")?;

        let bias_size = self.config.d_model * 4; offset = self.validate_and_skip_tensor(buffer, offset, bias_size, "LayerNorm bias")?;

        Ok(offset)
    }

    /// Linear section: (out_features x in_features) weight then bias.
    fn load_linear_weights(
        &self,
        buffer: &[u8],
        mut offset: usize,
        in_features: usize,
        out_features: usize,
    ) -> Result<usize> {
        let weight_size = out_features * in_features * 4; offset = self.validate_and_skip_tensor(buffer, offset, weight_size, "Linear weight")?;

        let bias_size = out_features * 4; offset = self.validate_and_skip_tensor(buffer, offset, bias_size, "Linear bias")?;

        Ok(offset)
    }

    /// Final layer norm uses the same layout as per-block norms.
    fn load_final_norm_weights(&self, buffer: &[u8], mut offset: usize) -> Result<usize> {
        offset = self.load_layer_norm_weights(buffer, offset)?;
        Ok(offset)
    }

    /// Check one tensor record (u32 length header + payload) against
    /// `expected_size` bytes and return the offset just past it. The
    /// payload itself is not read.
    fn validate_and_skip_tensor(
        &self,
        buffer: &[u8],
        offset: usize,
        expected_size: usize,
        tensor_name: &str,
    ) -> Result<usize> {
        use trustformers_core::errors::TrustformersError;

        if buffer.len() < offset + 4 {
            return Err(invalid_input(format!(
                "Insufficient data for {} size header",
                tensor_name
            )));
        }

        let tensor_size = u32::from_le_bytes([
            buffer[offset],
            buffer[offset + 1],
            buffer[offset + 2],
            buffer[offset + 3],
        ]) as usize;

        if tensor_size != expected_size {
            return Err(TrustformersError::invalid_format(
                format!("{}", expected_size),
                format!("{}", tensor_size),
            ));
        }

        if buffer.len() < offset + 4 + tensor_size {
            return Err(TrustformersError::invalid_input_simple(format!(
                "Insufficient data for {} tensor",
                tensor_name
            )));
        }

        Ok(offset + 4 + tensor_size)
    }
}
987
/// S4 backbone plus a linear language-modeling head over the vocabulary.
pub struct S4ForLanguageModeling {
    /// The S4 backbone.
    pub s4: S4Model,
    /// d_model -> vocab_size projection (bias-free; see constructor).
    pub lm_head: Linear,
    /// Device the model was created on.
    pub device: Device,
}
994
995impl S4ForLanguageModeling {
996 pub fn new_with_device(config: S4Config, device: Device) -> Result<Self> {
997 let s4 = S4Model::new_with_device(config.clone(), device)?;
998 let lm_head = Linear::new_with_device(
999 config.d_model,
1000 config.vocab_size,
1001 false, device,
1003 );
1004
1005 Ok(Self {
1006 s4,
1007 lm_head,
1008 device,
1009 })
1010 }
1011
1012 pub fn new(config: S4Config) -> Result<Self> {
1013 Self::new_with_device(config, Device::CPU)
1014 }
1015
1016 pub fn device(&self) -> Device {
1017 self.device
1018 }
1019}
1020
1021impl Model for S4ForLanguageModeling {
1022 type Config = S4Config;
1023 type Input = Tensor;
1024 type Output = Tensor;
1025
1026 fn forward(&self, input: Self::Input) -> Result<Self::Output> {
1027 let hidden = self.s4.forward(input)?;
1028 self.lm_head.forward(hidden)
1029 }
1030
1031 fn load_pretrained(&mut self, reader: &mut dyn std::io::Read) -> Result<()> {
1032 self.s4.load_pretrained(reader)?;
1034
1035 Ok(())
1038 }
1039
1040 fn get_config(&self) -> &Self::Config {
1041 self.s4.get_config()
1042 }
1043
1044 fn num_parameters(&self) -> usize {
1045 self.s4.num_parameters() + self.lm_head.parameter_count()
1047 }
1048}
1049
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hippo_initialization() {
        let n = 4;

        // LegS is antisymmetrized, so A + A^T must vanish elementwise.
        let a_legs = HiPPOMatrix::LEGS.initialize(n);
        assert_eq!(a_legs.shape(), &[n, n]);
        let symmetric_part = &a_legs + &a_legs.t();
        assert!(symmetric_part.iter().all(|&x| x.abs() < 1e-6));

        // LegT and Fourier just need to come out square with the right size.
        let a_legt = HiPPOMatrix::LEGT.initialize(n);
        assert_eq!(a_legt.shape(), &[n, n]);

        let a_fourier = HiPPOMatrix::Fourier.initialize(n);
        assert_eq!(a_fourier.shape(), &[n, n]);
    }

    #[test]
    fn test_discretization() {
        let n = 4;
        let a = Array2::<f32>::eye(n);
        let b = Array1::<f32>::ones(n);
        let dt = 0.01;

        // Both schemes must preserve the system's dimensions.
        let (a_zoh, b_zoh) = Discretization::ZOH.discretize(&a, &b, dt);
        assert_eq!(a_zoh.shape(), &[n, n]);
        assert_eq!(b_zoh.shape(), &[n]);

        let (a_euler, b_euler) = Discretization::Euler.discretize(&a, &b, dt);
        assert_eq!(a_euler.shape(), &[n, n]);
        assert_eq!(b_euler.shape(), &[n]);
    }

    #[test]
    fn test_s4_layer_creation() {
        let config = S4Config::default();
        let layer = S4Layer::new(&config);
        assert!(layer.is_ok());

        // Parameter shapes must follow d_state / n_ssm from the config.
        let layer = layer.expect("operation failed");
        assert_eq!(layer.a_real.shape(), &[config.d_state, config.d_state]);
        assert_eq!(layer.b_real.shape(), &[config.d_state]);
        assert_eq!(layer.c_real.shape(), &[config.d_state]);
        assert_eq!(layer.d.shape(), &[config.get_n_ssm()]);
    }

    #[test]
    fn test_s4_model_creation() {
        let config = S4Config::s4_small();
        let model = S4Model::new(config.clone()).expect("operation failed");

        // The model must mirror the config's dimensions and layer count.
        assert_eq!(model.config.d_model, config.d_model);
        assert_eq!(model.blocks.len(), config.n_layer);
    }

    #[test]
    fn test_s4_lm_creation() {
        // Construction alone must succeed for the base configuration.
        let config = S4Config::s4_base();
        let _model = S4ForLanguageModeling::new(config).expect("operation failed");
    }
}