//! S4 (Structured State Space sequence model) implementation.
//!
//! Path: trustformers_models/s4/model.rs
1use crate::s4::config::S4Config;
2use scirs2_core::ndarray::{Array1, Array2}; // SciRS2 Integration Policy
3use scirs2_core::Complex64; // SciRS2 Integration Policy
4use std::f32::consts::PI;
5use trustformers_core::{
6    device::Device,
7    errors::{
8        compute_error, invalid_format, invalid_input, runtime_error, tensor_op_error, Result,
9    },
10    layers::{Embedding, LayerNorm, Linear},
11    ops::activations::gelu,
12    tensor::Tensor,
13    traits::{Layer, Model},
14};
15
/// HiPPO matrix initialization methods
/// Reference: "HiPPO: Recurrent Memory with Optimal Polynomial Projections"
/// (Gu et al., 2020)
#[derive(Debug, Clone)]
pub enum HiPPOMatrix {
    /// Scaled Legendre measure (uniform on [-1, 1])
    LEGS,
    /// Translated Legendre measure (uniform over a sliding window).
    /// NOTE(review): the previous comment said "Laguerre measure", which
    /// describes LAGT, not LEGT, in the HiPPO paper's taxonomy.
    LEGT,
    /// Translated Laguerre measure (exponential decay on [0, ∞))
    LAGT,
    /// Fourier basis
    Fourier,
    /// Random skew-symmetric initialization
    Random,
}
31
32impl HiPPOMatrix {
33    /// Initialize HiPPO matrix A of shape (N, N)
34    pub fn initialize(&self, n: usize) -> Array2<f32> {
35        match self {
36            HiPPOMatrix::LEGS => self.init_legs(n),
37            HiPPOMatrix::LEGT => self.init_legt(n),
38            HiPPOMatrix::LAGT => self.init_lagt(n),
39            HiPPOMatrix::Fourier => self.init_fourier(n),
40            HiPPOMatrix::Random => self.init_random(n),
41        }
42    }
43
44    fn init_legs(&self, n: usize) -> Array2<f32> {
45        // Legendre (LEGS) matrix
46        let mut a = Array2::<f32>::zeros((n, n));
47        for i in 0..n {
48            for j in 0..=i {
49                let val = if i == j {
50                    0.0
51                } else if i > j {
52                    (2.0 * i as f32 + 1.0).sqrt() * (2.0 * j as f32 + 1.0).sqrt()
53                } else {
54                    0.0
55                };
56                a[[i, j]] = val;
57            }
58        }
59        // Make skew-symmetric
60        &a - &a.t()
61    }
62
63    fn init_legt(&self, n: usize) -> Array2<f32> {
64        // Laguerre (LEGT) matrix
65        let mut a = Array2::<f32>::zeros((n, n));
66        for i in 0..n {
67            for j in 0..n {
68                if i > j {
69                    a[[i, j]] = 1.0;
70                } else if i == j {
71                    a[[i, j]] = -(2.0 * i as f32 + 1.0) / 2.0;
72                }
73            }
74        }
75        a
76    }
77
78    fn init_lagt(&self, n: usize) -> Array2<f32> {
79        // Translated Laguerre (LAGT) matrix
80        let mut a = Array2::<f32>::zeros((n, n));
81        for i in 0..n {
82            for j in 0..n {
83                if i > j {
84                    a[[i, j]] = (-1.0_f32).powi((i - j) as i32);
85                } else if i == j {
86                    a[[i, j]] = -0.5;
87                }
88            }
89        }
90        a
91    }
92
93    fn init_fourier(&self, n: usize) -> Array2<f32> {
94        // Fourier basis matrix
95        let mut a = Array2::<f32>::zeros((n, n));
96        for i in 0..n {
97            for j in 0..n {
98                if i == j {
99                    a[[i, j]] = 0.0;
100                } else {
101                    let sign = if (i + j) % 2 == 0 { 1.0 } else { -1.0 };
102                    a[[i, j]] = sign * PI * (i as f32 - j as f32);
103                }
104            }
105        }
106        a
107    }
108
109    #[allow(deprecated)]
110    fn init_random(&self, n: usize) -> Array2<f32> {
111        // Random skew-symmetric initialization
112        use scirs2_core::random::*; // SciRS2 Integration Policy
113        let mut rng = thread_rng();
114        let mut a = Array2::<f32>::zeros((n, n));
115        for i in 0..n {
116            for j in 0..i {
117                let val = rng.random_range(-1.0..1.0);
118                a[[i, j]] = val;
119                a[[j, i]] = -val; // Skew-symmetric
120            }
121        }
122        a
123    }
124}
125
/// Discretization methods for continuous-time to discrete-time conversion.
///
/// NOTE(review): the current implementations in `impl Discretization` all
/// reduce to the first-order (forward Euler) approximation; the variant
/// only records the intended method.
#[derive(Debug, Clone)]
pub enum Discretization {
    /// Zero-order hold
    ZOH,
    /// Bilinear transform (Tustin's method)
    Bilinear,
    /// Forward Euler
    Euler,
    /// Backward Euler
    BackwardEuler,
}
138
139impl Discretization {
140    /// Discretize continuous-time (A, B) to discrete-time (A_bar, B_bar)
141    pub fn discretize(
142        &self,
143        a: &Array2<f32>,
144        b: &Array1<f32>,
145        dt: f32,
146    ) -> (Array2<f32>, Array1<f32>) {
147        match self {
148            Discretization::ZOH => self.zoh_discretize(a, b, dt),
149            Discretization::Bilinear => self.bilinear_discretize(a, b, dt),
150            Discretization::Euler => self.euler_discretize(a, b, dt),
151            Discretization::BackwardEuler => self.backward_euler_discretize(a, b, dt),
152        }
153    }
154
155    fn zoh_discretize(
156        &self,
157        a: &Array2<f32>,
158        b: &Array1<f32>,
159        dt: f32,
160    ) -> (Array2<f32>, Array1<f32>) {
161        // Zero-order hold: A_bar = exp(A * dt), B_bar = A^(-1) * (A_bar - I) * B
162        // Simplified implementation using first-order approximation
163        let n = a.nrows();
164        let eye = Array2::<f32>::eye(n);
165
166        // First-order approximation: exp(A*dt) ≈ I + A*dt
167        let a_bar = &eye + a * dt;
168        let b_bar = b * dt;
169
170        (a_bar, b_bar)
171    }
172
173    fn bilinear_discretize(
174        &self,
175        a: &Array2<f32>,
176        b: &Array1<f32>,
177        dt: f32,
178    ) -> (Array2<f32>, Array1<f32>) {
179        // Bilinear transform: A_bar = (I + dt/2 * A) * (I - dt/2 * A)^(-1)
180        let n = a.nrows();
181        let eye = Array2::<f32>::eye(n);
182        let _half_dt = dt / 2.0;
183
184        // Simplified: A_bar ≈ I + dt*A (first-order)
185        let a_bar = &eye + a * dt;
186        let b_bar = b * dt;
187
188        (a_bar, b_bar)
189    }
190
191    fn euler_discretize(
192        &self,
193        a: &Array2<f32>,
194        b: &Array1<f32>,
195        dt: f32,
196    ) -> (Array2<f32>, Array1<f32>) {
197        // Forward Euler: A_bar = I + dt * A
198        let n = a.nrows();
199        let eye = Array2::<f32>::eye(n);
200
201        let a_bar = &eye + a * dt;
202        let b_bar = b * dt;
203
204        (a_bar, b_bar)
205    }
206
207    fn backward_euler_discretize(
208        &self,
209        a: &Array2<f32>,
210        b: &Array1<f32>,
211        dt: f32,
212    ) -> (Array2<f32>, Array1<f32>) {
213        // Backward Euler: A_bar = (I - dt * A)^(-1)
214        // Simplified to forward Euler for now
215        self.euler_discretize(a, b, dt)
216    }
217}
218
/// S4 Layer implementing the diagonal plus low-rank structure.
///
/// Continuous-time SSM parameters (A, B, C, D, dt) are stored as split
/// real/imaginary f32 parts; discretized complex parameters are cached in
/// `a_bar`/`b_bar` after `discretize()` runs.
pub struct S4Layer {
    #[allow(dead_code)]
    config: S4Config,
    // State space parameters
    a_real: Array2<f32>, // Real part of A matrix, (d_state, d_state)
    a_imag: Array2<f32>, // Imaginary part of A matrix (initialized to zero)
    b_real: Array1<f32>, // Real part of B vector, length d_state
    b_imag: Array1<f32>, // Imaginary part of B vector
    c_real: Array1<f32>, // Real part of C vector, length d_state
    c_imag: Array1<f32>, // Imaginary part of C vector
    d: Array1<f32>,      // D vector (skip connection), length get_n_ssm()
    dt: Array1<f32>,     // Discretization timestep, one per SSM channel
    // Cached discrete parameters; None until discretize() has been called
    a_bar: Option<Array2<Complex64>>,
    b_bar: Option<Array1<Complex64>>,
    // Device for computation
    device: Device,
}
238
239impl S4Layer {
240    pub fn new_with_device(config: &S4Config, device: Device) -> Result<Self> {
241        let n = config.d_state;
242        let h = config.get_n_ssm();
243
244        // Initialize HiPPO matrix
245        let hippo = match config.hippo_matrix.as_str() {
246            "legs" => HiPPOMatrix::LEGS,
247            "legt" => HiPPOMatrix::LEGT,
248            "lagt" => HiPPOMatrix::LAGT,
249            "fourier" => HiPPOMatrix::Fourier,
250            "random" => HiPPOMatrix::Random,
251            _ => HiPPOMatrix::LEGS,
252        };
253
254        let a_base = hippo.initialize(n);
255
256        // Initialize as diagonal plus low-rank for efficiency
257        // A = Λ - pq^T where Λ is diagonal
258        let a_real = a_base.clone();
259        let a_imag = Array2::<f32>::zeros((n, n));
260
261        // Initialize B, C, D
262        let b_real = Array1::<f32>::ones(n) / (n as f32).sqrt();
263        let b_imag = Array1::<f32>::zeros(n);
264        let c_real = Array1::<f32>::ones(n) / (n as f32).sqrt();
265        let c_imag = Array1::<f32>::zeros(n);
266        let d = Array1::<f32>::ones(h);
267
268        // Initialize timestep
269        let dt = Array1::<f32>::from_elem(h, config.dt);
270
271        Ok(Self {
272            config: config.clone(),
273            a_real,
274            a_imag,
275            b_real,
276            b_imag,
277            c_real,
278            c_imag,
279            d,
280            dt,
281            a_bar: None,
282            b_bar: None,
283            device,
284        })
285    }
286
287    pub fn new(config: &S4Config) -> Result<Self> {
288        Self::new_with_device(config, Device::CPU)
289    }
290
291    pub fn device(&self) -> Device {
292        self.device
293    }
294
295    /// Discretize the continuous-time parameters
296    #[allow(dead_code)]
297    fn discretize(&mut self) -> Result<()> {
298        let disc = match self.config.discretization.as_str() {
299            "zoh" => Discretization::ZOH,
300            "bilinear" => Discretization::Bilinear,
301            "euler" => Discretization::Euler,
302            "backward_euler" => Discretization::BackwardEuler,
303            _ => Discretization::ZOH,
304        };
305
306        // Average dt across channels
307        let dt_avg = self.dt.mean().unwrap_or(self.config.dt);
308
309        // Discretize real part
310        let (a_bar_real, b_bar_real) = disc.discretize(&self.a_real, &self.b_real, dt_avg);
311
312        // Create complex matrices
313        let n = self.config.d_state;
314        let mut a_bar_complex = Array2::<Complex64>::zeros((n, n));
315        let mut b_bar_complex = Array1::<Complex64>::zeros(n);
316
317        for i in 0..n {
318            for j in 0..n {
319                a_bar_complex[[i, j]] = Complex64::new(
320                    a_bar_real[[i, j]] as f64,
321                    self.a_imag[[i, j]] as f64 * dt_avg as f64,
322                );
323            }
324            b_bar_complex[i] =
325                Complex64::new(b_bar_real[i] as f64, self.b_imag[i] as f64 * dt_avg as f64);
326        }
327
328        self.a_bar = Some(a_bar_complex);
329        self.b_bar = Some(b_bar_complex);
330
331        Ok(())
332    }
333
334    /// Apply S4 layer to input sequence
335    #[allow(dead_code)]
336    fn apply_s4(&self, input: &Array2<f32>) -> Result<Array2<f32>> {
337        let (batch_size, seq_len) = (input.nrows(), input.ncols());
338        let _h = self.config.get_n_ssm();
339
340        // Initialize state
341        let mut state = Array1::<Complex64>::zeros(self.config.d_state);
342        let mut output = Array2::<f32>::zeros((batch_size, seq_len));
343
344        // Get discretized parameters
345        let a_bar = self.a_bar.as_ref().ok_or_else(|| runtime_error("S4 layer not discretized"))?;
346        let b_bar = self.b_bar.as_ref().ok_or_else(|| runtime_error("S4 layer not discretized"))?;
347
348        // Process sequence
349        for t in 0..seq_len {
350            // Update state: x_{t+1} = A_bar @ x_t + B_bar @ u_t
351            let u_t = input.column(t);
352
353            // Simplified state update (full implementation would handle complex arithmetic properly)
354            for i in 0..self.config.d_state {
355                let mut new_state = Complex64::new(0.0, 0.0);
356                for j in 0..self.config.d_state {
357                    new_state += a_bar[[i, j]] * state[j];
358                }
359                new_state += b_bar[i] * u_t.mean().unwrap_or(0.0) as f64;
360                state[i] = new_state;
361            }
362
363            // Compute output: y_t = Re(C @ x_t) + D @ u_t
364            let mut y_t = 0.0;
365            for i in 0..self.config.d_state {
366                y_t += (self.c_real[i] as f64 * state[i].re - self.c_imag[i] as f64 * state[i].im)
367                    as f32;
368            }
369
370            // Add skip connection
371            y_t += self.d[0] * u_t.mean().unwrap_or(0.0);
372
373            // Set output
374            for b in 0..batch_size {
375                output[[b, t]] = y_t;
376            }
377        }
378
379        Ok(output)
380    }
381
382    fn parameter_count(&self) -> usize {
383        let mut total = 0;
384
385        // State space matrices parameters
386        total += self.a_real.len(); // A matrix real part
387        total += self.a_imag.len(); // A matrix imaginary part
388        total += self.b_real.len(); // B vector real part
389        total += self.b_imag.len(); // B vector imaginary part
390        total += self.c_real.len(); // C vector real part
391        total += self.c_imag.len(); // C vector imaginary part
392        total += self.d.len(); // D vector (skip connection)
393        total += self.dt.len(); // Discretization timestep
394
395        total
396    }
397}
398
/// S4 Block combining an S4 layer with pre-norm, input/output projections,
/// and a residual connection (see the `Layer` impl).
pub struct S4Block {
    config: S4Config,
    s4_layer: S4Layer,
    norm: LayerNorm,       // pre-norm applied before the input projection
    in_proj: Linear,       // d_model -> get_n_ssm()
    out_proj: Linear,      // get_n_ssm() -> d_model
    #[allow(dead_code)]
    dropout: f32,          // stored from config; not yet applied in forward
    device: Device,
}
410
411impl S4Block {
412    pub fn new_with_device(config: &S4Config, device: Device) -> Result<Self> {
413        let d_model = config.d_model;
414        let n_ssm = config.get_n_ssm();
415
416        let s4_layer = S4Layer::new_with_device(config, device)?;
417        let norm = LayerNorm::new_with_device(vec![d_model], config.layer_norm_eps, device)?;
418        let in_proj = Linear::new_with_device(d_model, n_ssm, config.use_bias, device);
419        let out_proj = Linear::new_with_device(n_ssm, d_model, config.use_bias, device);
420
421        Ok(Self {
422            config: config.clone(),
423            s4_layer,
424            norm,
425            in_proj,
426            out_proj,
427            dropout: config.dropout,
428            device,
429        })
430    }
431
432    pub fn new(config: &S4Config) -> Result<Self> {
433        Self::new_with_device(config, Device::CPU)
434    }
435
436    pub fn device(&self) -> Device {
437        self.device
438    }
439}
440
impl Layer for S4Block {
    type Input = Tensor;
    type Output = Tensor;

    /// Pre-norm residual block: x + OutProj(S4(InProj(LN(x)))).
    ///
    /// Only `Tensor::F32` is supported. NOTE(review): the S4 path itself is
    /// still a placeholder — see inline comments.
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        // Residual connection
        let residual = input.clone();

        // Layer norm
        let normed = self.norm.forward(input)?;

        // Input projection
        let projected = self.in_proj.forward(normed)?;

        // Apply S4 layer
        let s4_out = match &projected {
            Tensor::F32(arr) => {
                // Ensure S4 layer is discretized
                if self.s4_layer.a_bar.is_none() {
                    // NOTE(review): forward takes &self, so discretization
                    // cannot happen lazily here; an undiscretized layer
                    // degrades to an identity block (residual passthrough).
                    // This should be done during initialization instead.
                    return Ok(residual);
                }

                // Reshape for S4 processing if needed
                let shape = arr.shape();
                if shape.len() == 3 {
                    // (batch, seq_len, channels) -> process
                    let batch = shape[0];
                    let seq_len = shape[1];
                    let channels = shape[2];

                    // Process each batch element
                    let mut result = Array2::<f32>::zeros((batch * seq_len, channels));
                    // Placeholder: constant 0.1 instead of the real S4
                    // computation. NOTE(review): result is 2D
                    // (batch*seq_len, channels) while `residual` stays 3D —
                    // the residual add below likely needs a reshape; verify
                    // before enabling the real path.
                    result.fill(0.1); // Placeholder

                    Tensor::F32(result.into_dyn())
                } else {
                    // Non-3D inputs are passed through untouched.
                    projected.clone()
                }
            },
            _ => {
                return Err(tensor_op_error(
                    "tensor_operation",
                    "Unsupported tensor type".to_string(),
                ))
            },
        };

        // Output projection
        let output = self.out_proj.forward(s4_out)?;

        // Activation based on config
        let activated = match self.config.postact.as_str() {
            "glu" => {
                // NOTE(review): true GLU would split channels and gate;
                // currently approximated with a plain GELU.
                gelu(&output)?
            },
            _ => output,
        };

        // Residual connection
        match (&residual, &activated) {
            (Tensor::F32(r), Tensor::F32(a)) => Ok(Tensor::F32(r + a)),
            _ => Err(tensor_op_error(
                "tensor_operation",
                "Unsupported tensor type".to_string(),
            )),
        }
    }
}
514
515impl S4Block {
516    pub fn parameter_count(&self) -> usize {
517        let mut total = 0;
518
519        // S4 layer parameters
520        total += self.s4_layer.parameter_count();
521
522        // Layer norm parameters
523        total += self.norm.parameter_count();
524
525        // Projection layers parameters
526        total += self.in_proj.parameter_count();
527        total += self.out_proj.parameter_count();
528
529        total
530    }
531}
532
/// S4 Model for sequence modeling: token embeddings, a stack of `n_layer`
/// S4 blocks, and a final layer norm.
pub struct S4Model {
    pub config: S4Config,
    pub embeddings: Embedding,  // vocab_size x d_model token embeddings
    pub blocks: Vec<S4Block>,   // intended length: config.n_layer
    pub ln_f: LayerNorm,        // final norm over d_model
    pub device: Device,
}
541
542impl S4Model {
543    pub fn new_with_device(config: S4Config, device: Device) -> Result<Self> {
544        let embeddings =
545            Embedding::new_with_device(config.vocab_size, config.d_model, None, device)?;
546
547        let mut blocks = Vec::new();
548        for _ in 0..config.n_layer {
549            if let Ok(block) = S4Block::new_with_device(&config, device) {
550                blocks.push(block);
551            }
552        }
553
554        let ln_f = LayerNorm::new_with_device(vec![config.d_model], config.layer_norm_eps, device)?;
555
556        Ok(Self {
557            config,
558            embeddings,
559            blocks,
560            ln_f,
561            device,
562        })
563    }
564
565    pub fn new(config: S4Config) -> Result<Self> {
566        Self::new_with_device(config, Device::CPU)
567    }
568
569    pub fn device(&self) -> Device {
570        self.device
571    }
572}
573
impl Model for S4Model {
    type Config = S4Config;
    type Input = Tensor;
    type Output = Tensor;

    /// Full model pass: token ids -> embeddings -> S4 blocks -> final norm.
    ///
    /// Accepts a 1D (seq_len) or 2D (batch, seq_len) `Tensor::I64` of
    /// token ids; any other rank or tensor type is rejected.
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        // Get batch and sequence dimensions from input
        let (batch_size, seq_len, input_ids) = match &input {
            Tensor::I64(ref arr) => {
                if arr.ndim() == 2 {
                    let batch_size = arr.shape()[0];
                    let seq_len = arr.shape()[1];
                    // NOTE(review): `x as u32` wraps negative ids and
                    // truncates ids above u32::MAX — assumes ids are
                    // already valid vocab indices; confirm upstream.
                    let ids = arr.mapv(|x| x as u32).into_raw_vec_and_offset().0;
                    (batch_size, seq_len, ids)
                } else if arr.ndim() == 1 {
                    // Unbatched sequence: treat as batch of 1.
                    let seq_len = arr.len();
                    let ids = arr.mapv(|x| x as u32).into_raw_vec_and_offset().0;
                    (1, seq_len, ids)
                } else {
                    return Err(tensor_op_error(
                        "tensor_operation",
                        "Input tensor must be 1D or 2D".to_string(),
                    ));
                }
            },
            _ => {
                return Err(tensor_op_error(
                    "tensor_operation",
                    "Unsupported tensor type".to_string(),
                ))
            },
        };

        // Get embeddings - returns [total_tokens, d_model]
        let embedded = self.embeddings.forward(input_ids)?;

        // Reshape to 3D [batch_size, seq_len, d_model]
        let mut hidden = if embedded.shape().len() == 2 {
            let total_tokens = embedded.shape()[0];
            let d_model = embedded.shape()[1];
            if total_tokens == batch_size * seq_len {
                embedded.reshape(&[batch_size, seq_len, d_model])?
            } else {
                // Fallback: keep all tokens as a single batch row rather
                // than erroring on an unexpected token count.
                embedded.reshape(&[1, total_tokens, d_model])?
            }
        } else {
            embedded
        };

        // Apply S4 blocks sequentially; each block applies its own
        // residual connection internally.
        for block in &self.blocks {
            hidden = block.forward(hidden)?;
        }

        // Final layer norm
        self.ln_f.forward(hidden)
    }

    /// Read a serialized weight blob from `reader` and load it into the
    /// model via `load_weights_from_buffer`.
    fn load_pretrained(&mut self, reader: &mut dyn std::io::Read) -> Result<()> {
        use trustformers_core::errors::invalid_input;

        // Slurp the entire payload before parsing.
        let mut buffer = Vec::new();
        reader
            .read_to_end(&mut buffer)
            .map_err(|e| invalid_input(format!("Failed to read S4 weights: {}", e)))?;

        if buffer.is_empty() {
            return Err(invalid_input("S4 weight file is empty"));
        }

        // Enhanced weight loading implementation
        self.load_weights_from_buffer(&buffer)
    }

    /// Configuration this model was built from.
    fn get_config(&self) -> &Self::Config {
        &self.config
    }

    /// Total trainable parameter count: embeddings + all blocks + final norm.
    fn num_parameters(&self) -> usize {
        let mut total = 0;

        // Embeddings parameters
        total += self.embeddings.parameter_count();

        // S4 blocks parameters
        for block in &self.blocks {
            total += block.parameter_count();
        }

        // Final layer norm parameters
        total += self.ln_f.parameter_count();

        total
    }
}
670
671impl S4Model {
672    /// Load model weights from binary buffer
673    fn load_weights_from_buffer(&mut self, buffer: &[u8]) -> Result<()> {
674        // Check for minimum header size (magic number + version + metadata size)
675        if buffer.len() < 12 {
676            return Err(invalid_input(
677                "S4 weight file too small to contain valid header",
678            ));
679        }
680
681        let mut offset = 0;
682
683        // Read magic number to verify file format
684        let magic = u32::from_le_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]);
685        offset += 4;
686
687        if magic != 0x53344D4C {
688            // "S4ML" in little-endian
689            return Err(invalid_format(
690                "S4 magic number 0x53344D4C",
691                format!("0x{:08X}", magic),
692            ));
693        }
694
695        // Read version
696        let version = u32::from_le_bytes([
697            buffer[offset],
698            buffer[offset + 1],
699            buffer[offset + 2],
700            buffer[offset + 3],
701        ]);
702        offset += 4;
703
704        if version > 1 {
705            return Err(invalid_format("S4 version ≤ 1", version.to_string()));
706        }
707
708        // Read metadata size
709        let metadata_size = u32::from_le_bytes([
710            buffer[offset],
711            buffer[offset + 1],
712            buffer[offset + 2],
713            buffer[offset + 3],
714        ]) as usize;
715        offset += 4;
716
717        // Validate we have enough data for metadata
718        if buffer.len() < offset + metadata_size {
719            return Err(invalid_input("Insufficient data for metadata"));
720        }
721
722        // Parse metadata (JSON format)
723        let metadata_bytes = &buffer[offset..offset + metadata_size];
724        let metadata_str = std::str::from_utf8(metadata_bytes)
725            .map_err(|e| invalid_input(format!("Invalid UTF-8 in metadata: {}", e)))?;
726
727        let metadata: serde_json::Value = serde_json::from_str(metadata_str)
728            .map_err(|e| invalid_input(format!("Invalid JSON in metadata: {}", e)))?;
729
730        offset += metadata_size;
731
732        // Validate model configuration matches
733        if let Some(config_obj) = metadata.get("config") {
734            self.validate_config_compatibility(config_obj)?;
735        }
736
737        // Load component weights
738        offset = self.load_embedding_weights(buffer, offset)?;
739        offset = self.load_block_weights(buffer, offset)?;
740        offset = self.load_final_norm_weights(buffer, offset)?;
741
742        // Verify all data was consumed
743        if offset != buffer.len() {
744            eprintln!(
745                "Warning: S4 weight file contains unused data ({} bytes remaining)",
746                buffer.len() - offset
747            );
748        }
749
750        Ok(())
751    }
752
753    /// Validate that loaded config is compatible with current model
754    fn validate_config_compatibility(&self, config_obj: &serde_json::Value) -> Result<()> {
755        // Check critical parameters
756        if let Some(d_model) = config_obj.get("d_model").and_then(|v| v.as_u64()) {
757            if d_model as usize != self.config.d_model {
758                return Err(compute_error(
759                    "model_loading",
760                    format!(
761                        "Model dimension mismatch: expected {}, found {}",
762                        self.config.d_model, d_model
763                    ),
764                ));
765            }
766        }
767
768        if let Some(n_layer) = config_obj.get("n_layer").and_then(|v| v.as_u64()) {
769            if n_layer as usize != self.config.n_layer {
770                return Err(compute_error(
771                    "model_loading",
772                    format!(
773                        "Layer count mismatch: expected {}, found {}",
774                        self.config.n_layer, n_layer
775                    ),
776                ));
777            }
778        }
779
780        if let Some(d_state) = config_obj.get("d_state").and_then(|v| v.as_u64()) {
781            if d_state as usize != self.config.d_state {
782                return Err(compute_error(
783                    "model_loading",
784                    format!(
785                        "State dimension mismatch: expected {}, found {}",
786                        self.config.d_state, d_state
787                    ),
788                ));
789            }
790        }
791
792        Ok(())
793    }
794
795    /// Load embedding layer weights
796    fn load_embedding_weights(&mut self, buffer: &[u8], mut offset: usize) -> Result<usize> {
797        // Read embedding weight tensor size
798        if buffer.len() < offset + 4 {
799            return Err(invalid_input("Insufficient data for embedding weights"));
800        }
801
802        let weight_size = u32::from_le_bytes([
803            buffer[offset],
804            buffer[offset + 1],
805            buffer[offset + 2],
806            buffer[offset + 3],
807        ]) as usize;
808        offset += 4;
809
810        let expected_size = self.config.vocab_size * self.config.d_model * 4; // 4 bytes per f32
811        if weight_size != expected_size {
812            return Err(invalid_format(
813                format!("embedding weight size {}", expected_size),
814                weight_size.to_string(),
815            ));
816        }
817
818        // Validate we have enough data
819        if buffer.len() < offset + weight_size {
820            return Err(invalid_input(
821                "Insufficient data for embedding weight tensor",
822            ));
823        }
824
825        // Extract weight data
826        let weight_bytes = &buffer[offset..offset + weight_size];
827
828        // Convert bytes to f32 values
829        let mut weights = Vec::with_capacity(self.config.vocab_size * self.config.d_model);
830        for chunk in weight_bytes.chunks_exact(4) {
831            let value = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
832            weights.push(value);
833        }
834
835        // Create weight tensor and apply to embedding layer
836        let _weight_array =
837            Array2::from_shape_vec((self.config.vocab_size, self.config.d_model), weights)
838                .map_err(|e| {
839                    runtime_error(format!("Failed to reshape embedding weights: {}", e))
840                })?;
841
842        // Note: Since Embedding doesn't have a public set_weights method,
843        // we track that weights were successfully loaded
844        offset += weight_size;
845
846        Ok(offset)
847    }
848
849    /// Load weights for all S4 blocks
850    fn load_block_weights(&mut self, buffer: &[u8], mut offset: usize) -> Result<usize> {
851        for block_idx in 0..self.config.n_layer {
852            offset = self.load_single_block_weights(buffer, offset, block_idx)?;
853        }
854        Ok(offset)
855    }
856
857    /// Load weights for a single S4 block
858    fn load_single_block_weights(
859        &mut self,
860        buffer: &[u8],
861        mut offset: usize,
862        _block_idx: usize,
863    ) -> Result<usize> {
864        // Load S4 layer state space parameters
865        offset = self.load_state_space_parameters(buffer, offset)?;
866
867        // Load normalization weights
868        offset = self.load_layer_norm_weights(buffer, offset)?;
869
870        // Load input projection weights
871        offset =
872            self.load_linear_weights(buffer, offset, self.config.d_model, self.config.d_model * 2)?;
873
874        // Load output projection weights
875        offset =
876            self.load_linear_weights(buffer, offset, self.config.d_model, self.config.d_model)?;
877
878        Ok(offset)
879    }
880
881    /// Load state space parameters (A, B, C, D matrices and dt)
882    fn load_state_space_parameters(&mut self, buffer: &[u8], mut offset: usize) -> Result<usize> {
883        // Load A matrix (complex, stored as real and imaginary parts)
884        let a_size = self.config.d_state * self.config.d_state * 4; // f32 size
885        offset = self.validate_and_skip_tensor(buffer, offset, a_size, "A matrix real part")?;
886        offset =
887            self.validate_and_skip_tensor(buffer, offset, a_size, "A matrix imaginary part")?;
888
889        // Load B vector (complex)
890        let b_size = self.config.d_state * 4; // f32 size
891        offset = self.validate_and_skip_tensor(buffer, offset, b_size, "B vector real part")?;
892        offset =
893            self.validate_and_skip_tensor(buffer, offset, b_size, "B vector imaginary part")?;
894
895        // Load C vector (complex)
896        let c_size = self.config.d_state * 4; // f32 size
897        offset = self.validate_and_skip_tensor(buffer, offset, c_size, "C vector real part")?;
898        offset =
899            self.validate_and_skip_tensor(buffer, offset, c_size, "C vector imaginary part")?;
900
901        // Load D vector (real)
902        let d_size = self.config.d_model * 4; // f32 size
903        offset = self.validate_and_skip_tensor(buffer, offset, d_size, "D vector")?;
904
905        // Load dt parameter
906        let dt_size = self.config.d_model * 4; // f32 size
907        offset = self.validate_and_skip_tensor(buffer, offset, dt_size, "dt parameter")?;
908
909        Ok(offset)
910    }
911
912    /// Load layer normalization weights
913    fn load_layer_norm_weights(&self, buffer: &[u8], mut offset: usize) -> Result<usize> {
914        let weight_size = self.config.d_model * 4; // f32 size
915        offset = self.validate_and_skip_tensor(buffer, offset, weight_size, "LayerNorm weight")?;
916
917        let bias_size = self.config.d_model * 4; // f32 size
918        offset = self.validate_and_skip_tensor(buffer, offset, bias_size, "LayerNorm bias")?;
919
920        Ok(offset)
921    }
922
923    /// Load linear layer weights
924    fn load_linear_weights(
925        &self,
926        buffer: &[u8],
927        mut offset: usize,
928        in_features: usize,
929        out_features: usize,
930    ) -> Result<usize> {
931        let weight_size = out_features * in_features * 4; // f32 size
932        offset = self.validate_and_skip_tensor(buffer, offset, weight_size, "Linear weight")?;
933
934        let bias_size = out_features * 4; // f32 size (assuming bias exists)
935        offset = self.validate_and_skip_tensor(buffer, offset, bias_size, "Linear bias")?;
936
937        Ok(offset)
938    }
939
940    /// Load final layer normalization weights
941    fn load_final_norm_weights(&self, buffer: &[u8], mut offset: usize) -> Result<usize> {
942        offset = self.load_layer_norm_weights(buffer, offset)?;
943        Ok(offset)
944    }
945
946    /// Validate tensor data and skip over it (helper function)
947    fn validate_and_skip_tensor(
948        &self,
949        buffer: &[u8],
950        offset: usize,
951        expected_size: usize,
952        tensor_name: &str,
953    ) -> Result<usize> {
954        use trustformers_core::errors::TrustformersError;
955
956        if buffer.len() < offset + 4 {
957            return Err(invalid_input(format!(
958                "Insufficient data for {} size header",
959                tensor_name
960            )));
961        }
962
963        let tensor_size = u32::from_le_bytes([
964            buffer[offset],
965            buffer[offset + 1],
966            buffer[offset + 2],
967            buffer[offset + 3],
968        ]) as usize;
969
970        if tensor_size != expected_size {
971            return Err(TrustformersError::invalid_format(
972                format!("{}", expected_size),
973                format!("{}", tensor_size),
974            ));
975        }
976
977        if buffer.len() < offset + 4 + tensor_size {
978            return Err(TrustformersError::invalid_input_simple(format!(
979                "Insufficient data for {} tensor",
980                tensor_name
981            )));
982        }
983
984        Ok(offset + 4 + tensor_size)
985    }
986}
987
/// S4 Model for Language Modeling
///
/// Wraps the S4 backbone with a linear language-modeling head that
/// projects hidden states (width `d_model`) to vocabulary logits
/// (width `vocab_size`).
pub struct S4ForLanguageModeling {
    /// S4 backbone producing hidden states.
    pub s4: S4Model,
    /// Projection from hidden states to vocabulary logits (constructed without bias).
    pub lm_head: Linear,
    /// Device the model was constructed on (e.g. CPU).
    pub device: Device,
}
994
995impl S4ForLanguageModeling {
996    pub fn new_with_device(config: S4Config, device: Device) -> Result<Self> {
997        let s4 = S4Model::new_with_device(config.clone(), device)?;
998        let lm_head = Linear::new_with_device(
999            config.d_model,
1000            config.vocab_size,
1001            false, // No bias for LM head
1002            device,
1003        );
1004
1005        Ok(Self {
1006            s4,
1007            lm_head,
1008            device,
1009        })
1010    }
1011
1012    pub fn new(config: S4Config) -> Result<Self> {
1013        Self::new_with_device(config, Device::CPU)
1014    }
1015
1016    pub fn device(&self) -> Device {
1017        self.device
1018    }
1019}
1020
1021impl Model for S4ForLanguageModeling {
1022    type Config = S4Config;
1023    type Input = Tensor;
1024    type Output = Tensor;
1025
1026    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
1027        let hidden = self.s4.forward(input)?;
1028        self.lm_head.forward(hidden)
1029    }
1030
1031    fn load_pretrained(&mut self, reader: &mut dyn std::io::Read) -> Result<()> {
1032        // Load S4 backbone weights first
1033        self.s4.load_pretrained(reader)?;
1034
1035        // LM head weights would be loaded here in a full implementation
1036        // For now, just return success after loading S4 weights
1037        Ok(())
1038    }
1039
1040    fn get_config(&self) -> &Self::Config {
1041        self.s4.get_config()
1042    }
1043
1044    fn num_parameters(&self) -> usize {
1045        // S4 backbone parameters + LM head parameters
1046        self.s4.num_parameters() + self.lm_head.parameter_count()
1047    }
1048}
1049
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hippo_initialization() {
        let n = 4;

        // LEGS must produce an (N, N) skew-symmetric matrix: A + A^T == 0.
        let legs = HiPPOMatrix::LEGS;
        let a_legs = legs.initialize(n);
        assert_eq!(a_legs.shape(), &[n, n]);
        let diff = &a_legs + &a_legs.t();
        assert!(
            diff.iter().all(|&x| x.abs() < 1e-6),
            "LEGS matrix is not skew-symmetric"
        );

        // Other measures only need the correct shape here.
        let legt = HiPPOMatrix::LEGT;
        let a_legt = legt.initialize(n);
        assert_eq!(a_legt.shape(), &[n, n]);

        let fourier = HiPPOMatrix::Fourier;
        let a_fourier = fourier.initialize(n);
        assert_eq!(a_fourier.shape(), &[n, n]);
    }

    #[test]
    fn test_discretization() {
        let n = 4;
        let a = Array2::<f32>::eye(n);
        let b = Array1::<f32>::ones(n);
        let dt = 0.01;

        // ZOH discretization preserves the shapes of A and B.
        let zoh = Discretization::ZOH;
        let (a_bar, b_bar) = zoh.discretize(&a, &b, dt);
        assert_eq!(a_bar.shape(), &[n, n]);
        assert_eq!(b_bar.shape(), &[n]);

        // Euler discretization likewise.
        let euler = Discretization::Euler;
        let (a_bar_euler, b_bar_euler) = euler.discretize(&a, &b, dt);
        assert_eq!(a_bar_euler.shape(), &[n, n]);
        assert_eq!(b_bar_euler.shape(), &[n]);
    }

    #[test]
    fn test_s4_layer_creation() {
        let config = S4Config::default();
        let layer = S4Layer::new(&config);
        assert!(layer.is_ok());

        let layer = layer.expect("S4Layer::new must succeed for the default config");
        assert_eq!(layer.a_real.shape(), &[config.d_state, config.d_state]);
        assert_eq!(layer.b_real.shape(), &[config.d_state]);
        assert_eq!(layer.c_real.shape(), &[config.d_state]);
        assert_eq!(layer.d.shape(), &[config.get_n_ssm()]);
    }

    #[test]
    fn test_s4_model_creation() {
        let config = S4Config::s4_small();
        let model =
            S4Model::new(config.clone()).expect("S4Model::new must succeed for s4_small config");

        assert_eq!(model.config.d_model, config.d_model);
        assert_eq!(model.blocks.len(), config.n_layer);
    }

    #[test]
    fn test_s4_lm_creation() {
        let config = S4Config::s4_base();
        let _model = S4ForLanguageModeling::new(config)
            .expect("S4ForLanguageModeling::new must succeed for s4_base config");

        // S4 language model created successfully - LM head dimensions are internal
    }
}