// scirs2_text/neural_architectures.rs

//! Advanced neural architectures for text processing
//!
//! This module provides various neural network architectures optimized for
//! text processing tasks, including RNNs, CNNs, attention mechanisms, and hybrid models.

use crate::error::{Result, TextError};
use scirs2_core::ndarray::{s, Array1, Array2, Array3, ArrayView1, ArrayView2, Axis};
use scirs2_core::random::{self, Rng};
use statrs::statistics::Statistics;

/// Activation functions for neural networks
#[derive(Debug, Clone, Copy)]
pub enum ActivationFunction {
    /// Hyperbolic tangent
    Tanh,
    /// Sigmoid function
    Sigmoid,
    /// Rectified Linear Unit
    ReLU,
    /// Gaussian Error Linear Unit
    GELU,
    /// Swish activation
    Swish,
    /// Linear activation (identity)
    Linear,
}

impl ActivationFunction {
    /// Apply activation function to a value
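    ///
    /// A quick sanity check (assuming this module is exported as
    /// `scirs2_text::neural_architectures`):
    ///
    /// ```
    /// use scirs2_text::neural_architectures::ActivationFunction;
    ///
    /// let y = ActivationFunction::Sigmoid.apply(0.0);
    /// assert!((y - 0.5).abs() < 1e-12);
    /// ```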
    pub fn apply(&self, x: f64) -> f64 {
        match self {
            ActivationFunction::Tanh => x.tanh(),
            ActivationFunction::Sigmoid => 1.0 / (1.0 + (-x).exp()),
            ActivationFunction::ReLU => x.max(0.0),
            ActivationFunction::GELU => {
                0.5 * x * (1.0 + (x * 0.7978845608 * (1.0 + 0.044715 * x * x)).tanh())
            }
            ActivationFunction::Swish => x / (1.0 + (-x).exp()),
            ActivationFunction::Linear => x,
        }
    }

    /// Apply activation function to an array
    pub fn apply_array(&self, x: &Array1<f64>) -> Array1<f64> {
        x.mapv(|val| self.apply(val))
    }

    /// Compute derivative of activation function
    pub fn derivative(&self, x: f64) -> f64 {
        match self {
            ActivationFunction::Tanh => {
                let tanh_x = x.tanh();
                1.0 - tanh_x * tanh_x
            }
            ActivationFunction::Sigmoid => {
                let sig_x = self.apply(x);
                sig_x * (1.0 - sig_x)
            }
            ActivationFunction::ReLU => {
                if x > 0.0 {
                    1.0
                } else {
                    0.0
                }
            }
            ActivationFunction::GELU => {
                // Approximate derivative of GELU(x) = x * Phi(x): Phi(x) + x * phi(x),
                // with Phi approximated via tanh and phi(x) = exp(-x^2 / 2) / sqrt(2 * PI)
                let cdf = 0.5 * (1.0 + (x * 0.7978845608).tanh());
                let pdf = 0.3989422804 * (-0.5 * x * x).exp();
                cdf + x * pdf
            }
            ActivationFunction::Swish => {
                let sig_x = 1.0 / (1.0 + (-x).exp());
                sig_x + x * sig_x * (1.0 - sig_x)
            }
            ActivationFunction::Linear => 1.0,
        }
    }
}

/// Long Short-Term Memory (LSTM) cell
#[derive(Debug, Clone)]
pub struct LSTMCell {
    /// Input gate weights
    w_i: Array2<f64>,
    /// Forget gate weights
    w_f: Array2<f64>,
    /// Output gate weights
    w_o: Array2<f64>,
    /// Candidate gate weights
    w_c: Array2<f64>,
    /// Hidden state weights for input gate
    u_i: Array2<f64>,
    /// Hidden state weights for forget gate
    u_f: Array2<f64>,
    /// Hidden state weights for output gate
    u_o: Array2<f64>,
    /// Hidden state weights for candidate gate
    u_c: Array2<f64>,
    /// Bias vectors
    b_i: Array1<f64>,
    b_f: Array1<f64>,
    b_o: Array1<f64>,
    b_c: Array1<f64>,
    /// Input size
    input_size: usize,
    /// Hidden size
    hidden_size: usize,
}

impl LSTMCell {
    /// Create new LSTM cell
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        let scale = (2.0 / (input_size + hidden_size) as f64).sqrt();

        // Initialize weights with Xavier initialization
        let init = |rows: usize, cols: usize| {
            Array2::from_shape_fn((rows, cols), |_| {
                scirs2_core::random::rng().random_range(-scale..scale)
            })
        };

        let w_i = init(hidden_size, input_size);
        let w_f = init(hidden_size, input_size);
        let w_o = init(hidden_size, input_size);
        let w_c = init(hidden_size, input_size);

        let u_i = init(hidden_size, hidden_size);
        let u_f = init(hidden_size, hidden_size);
        let u_o = init(hidden_size, hidden_size);
        let u_c = init(hidden_size, hidden_size);

        // Initialize biases (forget gate bias to 1.0 for better gradient flow)
        let b_i = Array1::zeros(hidden_size);
        let b_f = Array1::ones(hidden_size);
        let b_o = Array1::zeros(hidden_size);
        let b_c = Array1::zeros(hidden_size);

        Self {
            w_i,
            w_f,
            w_o,
            w_c,
            u_i,
            u_f,
            u_o,
            u_c,
            b_i,
            b_f,
            b_o,
            b_c,
            input_size,
            hidden_size,
        }
    }

    /// Forward pass through LSTM cell
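    ///
    /// Implements the standard LSTM recurrence:
    ///
    /// ```text
    /// i_t = sigmoid(W_i x_t + U_i h_{t-1} + b_i)    // input gate
    /// f_t = sigmoid(W_f x_t + U_f h_{t-1} + b_f)    // forget gate
    /// o_t = sigmoid(W_o x_t + U_o h_{t-1} + b_o)    // output gate
    /// g_t = tanh(W_c x_t + U_c h_{t-1} + b_c)       // candidate
    /// c_t = f_t * c_{t-1} + i_t * g_t
    /// h_t = o_t * tanh(c_t)
    /// ```
    ///
    /// A minimal usage sketch (assuming this module is exported as
    /// `scirs2_text::neural_architectures`):
    ///
    /// ```
    /// use scirs2_core::ndarray::Array1;
    /// use scirs2_text::neural_architectures::LSTMCell;
    ///
    /// let cell = LSTMCell::new(4, 8);
    /// let x = Array1::zeros(4);
    /// let (h, c) = (Array1::zeros(8), Array1::zeros(8));
    /// let (h_next, c_next) = cell.forward(x.view(), h.view(), c.view()).unwrap();
    /// assert_eq!((h_next.len(), c_next.len()), (8, 8));
    /// ```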
    pub fn forward(
        &self,
        x: ArrayView1<f64>,
        h_prev: ArrayView1<f64>,
        c_prev: ArrayView1<f64>,
    ) -> Result<(Array1<f64>, Array1<f64>)> {
        if x.len() != self.input_size {
            return Err(TextError::InvalidInput(format!(
                "Expected input size {}, got {}",
                self.input_size,
                x.len()
            )));
        }

        if h_prev.len() != self.hidden_size || c_prev.len() != self.hidden_size {
            return Err(TextError::InvalidInput(format!(
                "Expected hidden size {}, got h: {}, c: {}",
                self.hidden_size,
                h_prev.len(),
                c_prev.len()
            )));
        }

        // Input gate
        let i_t = ActivationFunction::Sigmoid
            .apply_array(&(self.w_i.dot(&x) + self.u_i.dot(&h_prev) + &self.b_i));

        // Forget gate
        let f_t = ActivationFunction::Sigmoid
            .apply_array(&(self.w_f.dot(&x) + self.u_f.dot(&h_prev) + &self.b_f));

        // Output gate
        let o_t = ActivationFunction::Sigmoid
            .apply_array(&(self.w_o.dot(&x) + self.u_o.dot(&h_prev) + &self.b_o));

        // Candidate values
        let c_tilde = ActivationFunction::Tanh
            .apply_array(&(self.w_c.dot(&x) + self.u_c.dot(&h_prev) + &self.b_c));

        // Cell state
        let c_t = &f_t * &c_prev + &i_t * &c_tilde;

        // Hidden state
        let h_t = &o_t * &ActivationFunction::Tanh.apply_array(&c_t);

        Ok((h_t, c_t))
    }
}

/// Gated Recurrent Unit (GRU) cell
#[derive(Debug, Clone)]
pub struct GRUCell {
    /// Update gate weights
    w_z: Array2<f64>,
    /// Reset gate weights
    w_r: Array2<f64>,
    /// New gate weights
    w_h: Array2<f64>,
    /// Hidden state weights for update gate
    u_z: Array2<f64>,
    /// Hidden state weights for reset gate
    u_r: Array2<f64>,
    /// Hidden state weights for new gate
    u_h: Array2<f64>,
    /// Bias vectors
    b_z: Array1<f64>,
    b_r: Array1<f64>,
    b_h: Array1<f64>,
    /// Input size
    input_size: usize,
    /// Hidden size
    hidden_size: usize,
}

impl GRUCell {
    /// Create new GRU cell
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        let scale = (2.0 / (input_size + hidden_size) as f64).sqrt();

        // Initialize weights with Xavier initialization
        let init = |rows: usize, cols: usize| {
            Array2::from_shape_fn((rows, cols), |_| {
                scirs2_core::random::rng().random_range(-scale..scale)
            })
        };

        let w_z = init(hidden_size, input_size);
        let w_r = init(hidden_size, input_size);
        let w_h = init(hidden_size, input_size);

        let u_z = init(hidden_size, hidden_size);
        let u_r = init(hidden_size, hidden_size);
        let u_h = init(hidden_size, hidden_size);

        // Initialize biases
        let b_z = Array1::zeros(hidden_size);
        let b_r = Array1::zeros(hidden_size);
        let b_h = Array1::zeros(hidden_size);

        Self {
            w_z,
            w_r,
            w_h,
            u_z,
            u_r,
            u_h,
            b_z,
            b_r,
            b_h,
            input_size,
            hidden_size,
        }
    }

    /// Forward pass through GRU cell
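    ///
    /// Implements the standard GRU recurrence:
    ///
    /// ```text
    /// z_t = sigmoid(W_z x_t + U_z h_{t-1} + b_z)        // update gate
    /// r_t = sigmoid(W_r x_t + U_r h_{t-1} + b_r)        // reset gate
    /// g_t = tanh(W_h x_t + U_h (r_t * h_{t-1}) + b_h)   // candidate
    /// h_t = (1 - z_t) * h_{t-1} + z_t * g_t
    /// ```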
    pub fn forward(&self, x: ArrayView1<f64>, h_prev: ArrayView1<f64>) -> Result<Array1<f64>> {
        if x.len() != self.input_size {
            return Err(TextError::InvalidInput(format!(
                "Expected input size {}, got {}",
                self.input_size,
                x.len()
            )));
        }

        if h_prev.len() != self.hidden_size {
            return Err(TextError::InvalidInput(format!(
                "Expected hidden size {}, got {}",
                self.hidden_size,
                h_prev.len()
            )));
        }

        // Update gate
        let z_t = ActivationFunction::Sigmoid
            .apply_array(&(self.w_z.dot(&x) + self.u_z.dot(&h_prev) + &self.b_z));

        // Reset gate
        let r_t = ActivationFunction::Sigmoid
            .apply_array(&(self.w_r.dot(&x) + self.u_r.dot(&h_prev) + &self.b_r));

        // New gate (candidate activation)
        let h_tilde = ActivationFunction::Tanh
            .apply_array(&(self.w_h.dot(&x) + self.u_h.dot(&(&r_t * &h_prev)) + &self.b_h));

        // Final hidden state
        let h_t = &(&Array1::ones(self.hidden_size) - &z_t) * &h_prev + &z_t * &h_tilde;

        Ok(h_t)
    }
}

/// Bidirectional LSTM layer
pub struct BiLSTM {
    /// Forward LSTM cells
    forward_cells: Vec<LSTMCell>,
    /// Backward LSTM cells
    backward_cells: Vec<LSTMCell>,
    /// Number of layers
    num_layers: usize,
    /// Hidden size
    hidden_size: usize,
}

impl BiLSTM {
    /// Create new bidirectional LSTM
    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        let mut forward_cells = Vec::new();
        let mut backward_cells = Vec::new();

        for i in 0..num_layers {
            let layer_input_size = if i == 0 { input_size } else { hidden_size * 2 };
            forward_cells.push(LSTMCell::new(layer_input_size, hidden_size));
            backward_cells.push(LSTMCell::new(layer_input_size, hidden_size));
        }

        Self {
            forward_cells,
            backward_cells,
            num_layers,
            hidden_size,
        }
    }

    /// Forward pass through bidirectional LSTM
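    ///
    /// Input shape is `(seq_len, input_size)`; output shape is
    /// `(seq_len, hidden_size * 2)`, since the forward and backward hidden
    /// states are concatenated at every timestep.
    ///
    /// A minimal sketch (assuming this module is exported as
    /// `scirs2_text::neural_architectures`):
    ///
    /// ```
    /// use scirs2_core::ndarray::Array2;
    /// use scirs2_text::neural_architectures::BiLSTM;
    ///
    /// let bilstm = BiLSTM::new(10, 20, 1);
    /// let sequence = Array2::zeros((5, 10)); // 5 timesteps, 10 features
    /// let output = bilstm.forward(sequence.view()).unwrap();
    /// assert_eq!(output.shape(), &[5, 40]);
    /// ```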
    pub fn forward(&self, sequence: ArrayView2<f64>) -> Result<Array2<f64>> {
        let (seq_len, _input_size) = sequence.dim();
        let output_size = self.hidden_size * 2; // Concatenated forward and backward

        let mut current_input = sequence.to_owned();

        for layer in 0..self.num_layers {
            let mut forward_outputs = Vec::new();
            let mut backward_outputs = Vec::new();

            // Forward direction
            let mut h_forward = Array1::zeros(self.hidden_size);
            let mut c_forward = Array1::zeros(self.hidden_size);

            for t in 0..seq_len {
                let (h_new, c_new) = self.forward_cells[layer].forward(
                    current_input.row(t),
                    h_forward.view(),
                    c_forward.view(),
                )?;
                h_forward = h_new;
                c_forward = c_new;
                forward_outputs.push(h_forward.clone());
            }

            // Backward direction
            let mut h_backward = Array1::zeros(self.hidden_size);
            let mut c_backward = Array1::zeros(self.hidden_size);

            for t in (0..seq_len).rev() {
                let (h_new, c_new) = self.backward_cells[layer].forward(
                    current_input.row(t),
                    h_backward.view(),
                    c_backward.view(),
                )?;
                h_backward = h_new;
                c_backward = c_new;
                backward_outputs.push(h_backward.clone());
            }

            // Reverse backward outputs to match forward order
            backward_outputs.reverse();

            // Concatenate forward and backward outputs
            let mut layer_output = Array2::zeros((seq_len, output_size));
            for t in 0..seq_len {
                let mut concat_output = Array1::zeros(output_size);
                concat_output
                    .slice_mut(s![..self.hidden_size])
                    .assign(&forward_outputs[t]);
                concat_output
                    .slice_mut(s![self.hidden_size..])
                    .assign(&backward_outputs[t]);
                layer_output.row_mut(t).assign(&concat_output);
            }

            current_input = layer_output;
        }

        Ok(current_input)
    }
}

/// Convolutional layer for text processing
#[derive(Debug, Clone)]
pub struct Conv1D {
    /// Convolution filters
    filters: Array3<f64>,
    /// Bias terms
    bias: Array1<f64>,
    /// Number of filters
    num_filters: usize,
    /// Kernel size
    kernel_size: usize,
    /// Input channels
    input_channels: usize,
    /// Activation function
    activation: ActivationFunction,
}

impl Conv1D {
    /// Create new 1D convolutional layer
    pub fn new(
        input_channels: usize,
        num_filters: usize,
        kernel_size: usize,
        activation: ActivationFunction,
    ) -> Self {
        let scale = (2.0 / (input_channels * kernel_size) as f64).sqrt();

        // Initialize filters with Xavier initialization
        let filters = Array3::from_shape_fn((num_filters, input_channels, kernel_size), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });

        let bias = Array1::zeros(num_filters);

        Self {
            filters,
            bias,
            num_filters,
            kernel_size,
            input_channels,
            activation,
        }
    }

    /// Forward pass through convolution layer
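    ///
    /// This is a "valid" (unpadded) convolution, so the output length is
    /// `seq_len - kernel_size + 1`:
    ///
    /// ```
    /// use scirs2_core::ndarray::Array2;
    /// use scirs2_text::neural_architectures::{ActivationFunction, Conv1D};
    ///
    /// let conv = Conv1D::new(5, 10, 3, ActivationFunction::ReLU);
    /// let input = Array2::ones((8, 5)); // sequence length 8, 5 channels
    /// let output = conv.forward(input.view()).unwrap();
    /// assert_eq!(output.shape(), &[6, 10]); // (8 - 3 + 1, 10)
    /// ```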
    pub fn forward(&self, input: ArrayView2<f64>) -> Result<Array2<f64>> {
        let (seq_len, input_dim) = input.dim();

        if input_dim != self.input_channels {
            return Err(TextError::InvalidInput(format!(
                "Expected {} input channels, got {}",
                self.input_channels, input_dim
            )));
        }

        let output_len = seq_len.saturating_sub(self.kernel_size - 1);
        let mut output = Array2::zeros((output_len, self.num_filters));

        for filter_idx in 0..self.num_filters {
            for pos in 0..output_len {
                let mut conv_sum = 0.0;

                for ch in 0..self.input_channels {
                    for k in 0..self.kernel_size {
                        if pos + k < seq_len {
                            conv_sum += input[[pos + k, ch]] * self.filters[[filter_idx, ch, k]];
                        }
                    }
                }

                conv_sum += self.bias[filter_idx];
                output[[pos, filter_idx]] = self.activation.apply(conv_sum);
            }
        }

        Ok(output)
    }
}

/// Max pooling layer for 1D data
#[derive(Debug)]
pub struct MaxPool1D {
    /// Pool size
    pool_size: usize,
    /// Stride
    stride: usize,
}

impl MaxPool1D {
    /// Create new max pooling layer
    pub fn new(pool_size: usize, stride: usize) -> Self {
        Self { pool_size, stride }
    }

    /// Forward pass through max pooling
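    ///
    /// Output length is `(seq_len - pool_size) / stride + 1` (integer
    /// division), so the input must be at least `pool_size` rows long:
    ///
    /// ```
    /// use scirs2_core::ndarray::Array2;
    /// use scirs2_text::neural_architectures::MaxPool1D;
    ///
    /// let pool = MaxPool1D::new(2, 2);
    /// let input = Array2::ones((8, 3));
    /// assert_eq!(pool.forward(input.view()).shape(), &[4, 3]);
    /// ```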
    pub fn forward(&self, input: ArrayView2<f64>) -> Array2<f64> {
        let (seq_len, channels) = input.dim();
        let output_len = (seq_len - self.pool_size) / self.stride + 1;

        let mut output = Array2::zeros((output_len, channels));

        for ch in 0..channels {
            for i in 0..output_len {
                let start = i * self.stride;
                let end = (start + self.pool_size).min(seq_len);

                let mut max_val = f64::NEG_INFINITY;
                for j in start..end {
                    max_val = max_val.max(input[[j, ch]]);
                }

                output[[i, ch]] = max_val;
            }
        }

        output
    }
}

/// Residual block for CNNs
#[derive(Debug, Clone)]
pub struct ResidualBlock1D {
    /// First convolution layer
    conv1: Conv1D,
    /// Second convolution layer
    conv2: Conv1D,
    /// Skip connection projection (for dimension matching)
    skip_projection: Option<Array2<f64>>,
    /// Batch normalization parameters
    bn1_scale: Array1<f64>,
    bn1_shift: Array1<f64>,
    bn2_scale: Array1<f64>,
    bn2_shift: Array1<f64>,
}

impl ResidualBlock1D {
    /// Create new residual block
    pub fn new(input_channels: usize, output_channels: usize, kernel_size: usize) -> Self {
        let conv1 = Conv1D::new(
            input_channels,
            output_channels,
            kernel_size,
            ActivationFunction::Linear,
        );
        let conv2 = Conv1D::new(
            output_channels,
            output_channels,
            kernel_size,
            ActivationFunction::Linear,
        );

        // Skip projection if input and output channels differ
        let skip_projection = if input_channels != output_channels {
            let scale = (2.0 / input_channels as f64).sqrt();
            Some(Array2::from_shape_fn(
                (output_channels, input_channels),
                |_| scirs2_core::random::rng().random_range(-scale..scale),
            ))
        } else {
            None
        };

        // Batch normalization parameters
        let bn1_scale = Array1::ones(output_channels);
        let bn1_shift = Array1::zeros(output_channels);
        let bn2_scale = Array1::ones(output_channels);
        let bn2_shift = Array1::zeros(output_channels);

        Self {
            conv1,
            conv2,
            skip_projection,
            bn1_scale,
            bn1_shift,
            bn2_scale,
            bn2_shift,
        }
    }

    /// Forward pass through residual block
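    ///
    /// Each valid convolution shortens the sequence by `kernel_size - 1`, so
    /// the skip path is center-cropped to the post-convolution length before
    /// the element-wise addition:
    ///
    /// ```
    /// use scirs2_core::ndarray::Array2;
    /// use scirs2_text::neural_architectures::ResidualBlock1D;
    ///
    /// let block = ResidualBlock1D::new(4, 8, 3);
    /// let input = Array2::ones((10, 4));
    /// let output = block.forward(input.view()).unwrap();
    /// assert_eq!(output.shape(), &[6, 8]); // 10 -> 8 -> 6 after two convs
    /// ```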
    pub fn forward(&self, input: ArrayView2<f64>) -> Result<Array2<f64>> {
        // First convolution + batch norm + ReLU
        let conv1_out = self.conv1.forward(input)?;
        let bn1_out = self.batch_norm(&conv1_out, &self.bn1_scale, &self.bn1_shift);
        let relu1_out = bn1_out.mapv(|x| ActivationFunction::ReLU.apply(x));

        // Second convolution + batch norm
        let conv2_out = self.conv2.forward(relu1_out.view())?;
        let bn2_out = self.batch_norm(&conv2_out, &self.bn2_scale, &self.bn2_shift);

        // Skip connection
        let skip_out = if let Some(ref projection) = self.skip_projection {
            // Project input to match output channels: (seq_len, input_channels) -> (seq_len, output_channels)
            let projected = input.dot(&projection.t());

            // Handle sequence length mismatch due to convolutions
            // Each convolution reduces length by (kernel_size - 1), so total reduction is 2 * (kernel_size - 1)
            let conv_output_len = bn2_out.shape()[0];
            let skip_len = projected.shape()[0];

            if conv_output_len < skip_len {
                // Take center slice of skip connection to match conv output length
                let start = (skip_len - conv_output_len) / 2;
                let end = start + conv_output_len;
                projected.slice(s![start..end, ..]).to_owned()
            } else {
                projected
            }
        } else {
            // Direct skip connection - handle sequence length mismatch
            let conv_output_len = bn2_out.shape()[0];
            let skip_len = input.shape()[0];

            if conv_output_len < skip_len {
                // Take center slice of input to match conv output length
                let start = (skip_len - conv_output_len) / 2;
                let end = start + conv_output_len;
                input.slice(s![start..end, ..]).to_owned()
            } else {
                input.to_owned()
            }
        };

        // Add skip connection and apply ReLU
        let output = &bn2_out + &skip_out;
        Ok(output.mapv(|x| ActivationFunction::ReLU.apply(x)))
    }

    /// Simple batch normalization (simplified implementation)
    fn batch_norm(
        &self,
        input: &Array2<f64>,
        scale: &Array1<f64>,
        shift: &Array1<f64>,
    ) -> Array2<f64> {
        let mut result = input.clone();
        let eps = 1e-5;

        // Normalize over the sequence dimension for each channel
        for ch in 0..input.shape()[1] {
            let channel_data = input.column(ch);
            let mean = channel_data.mean();
            let var = channel_data.mapv(|x| (x - mean).powi(2)).mean();
            let std = (var + eps).sqrt();

            let mut normalized = channel_data.mapv(|x| (x - mean) / std);
            normalized = normalized * scale[ch] + shift[ch];

            result.column_mut(ch).assign(&normalized);
        }

        result
    }
}

/// Multi-scale CNN for text processing
#[derive(Debug)]
pub struct MultiScaleCNN {
    /// Parallel convolution branches with different kernel sizes
    conv_branches: Vec<Conv1D>,
    /// Batch normalization for each branch
    bn_branches: Vec<(Array1<f64>, Array1<f64>)>,
    /// Combination weights
    combination_weights: Array2<f64>,
    /// Global max pooling
    #[allow(dead_code)]
    global_pool: MaxPool1D,
}

impl MultiScaleCNN {
    /// Create new multi-scale CNN
    pub fn new(
        input_channels: usize,
        num_filters_per_scale: usize,
        kernel_sizes: Vec<usize>,
        output_size: usize,
    ) -> Self {
        let mut conv_branches = Vec::new();
        let mut bn_branches = Vec::new();

        // Create convolution branches for different scales
        for &kernel_size in &kernel_sizes {
            conv_branches.push(Conv1D::new(
                input_channels,
                num_filters_per_scale,
                kernel_size,
                ActivationFunction::ReLU,
            ));

            // Batch normalization parameters
            bn_branches.push((
                Array1::ones(num_filters_per_scale),
                Array1::zeros(num_filters_per_scale),
            ));
        }

        // Combination layer
        let total_features = kernel_sizes.len() * num_filters_per_scale;
        let scale = (2.0 / total_features as f64).sqrt();
        let combination_weights = Array2::from_shape_fn((output_size, total_features), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });

        let global_pool = MaxPool1D::new(2, 2);

        Self {
            conv_branches,
            bn_branches,
            combination_weights,
            global_pool,
        }
    }

    /// Forward pass through multi-scale CNN
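    ///
    /// Each branch output is reduced with global max pooling over the
    /// sequence dimension, the pooled vectors are concatenated, and a final
    /// linear layer maps the result to a vector of length `output_size`.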
    pub fn forward(&self, input: ArrayView2<f64>) -> Result<Array1<f64>> {
        let mut branch_outputs = Vec::new();

        // Process each scale branch
        for (i, conv) in self.conv_branches.iter().enumerate() {
            let conv_out = conv.forward(input)?;

            // Apply batch normalization
            let (scale, shift) = &self.bn_branches[i];
            let bn_out = self.batch_norm_branch(&conv_out, scale, shift);

            // Global max pooling over sequence dimension
            let global_max = bn_out.map_axis(Axis(0), |row| {
                row.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b))
            });

            branch_outputs.push(global_max);
        }

        // Concatenate all branch outputs
        let mut concatenated = Array1::zeros(branch_outputs.iter().map(|x| x.len()).sum::<usize>());
        let mut offset = 0;
        for branch_output in branch_outputs {
            let end = offset + branch_output.len();
            concatenated
                .slice_mut(s![offset..end])
                .assign(&branch_output);
            offset = end;
        }

        // Final combination layer
        Ok(self.combination_weights.dot(&concatenated))
    }

    /// Batch normalization for a single branch
    fn batch_norm_branch(
        &self,
        input: &Array2<f64>,
        scale: &Array1<f64>,
        shift: &Array1<f64>,
    ) -> Array2<f64> {
        let mut result = input.clone();
        let eps = 1e-5;

        for ch in 0..input.shape()[1] {
            let channel_data = input.column(ch);
            let mean = channel_data.mean();
            let var = channel_data.mapv(|x| (x - mean).powi(2)).mean();
            let std = (var + eps).sqrt();

            let mut normalized = channel_data.mapv(|x| (x - mean) / std);
            normalized = normalized * scale[ch] + shift[ch];

            result.column_mut(ch).assign(&normalized);
        }

        result
    }
}

/// Attention mechanism for sequence-to-sequence models
pub struct AdditiveAttention {
    /// Attention weights
    w_a: Array2<f64>,
    /// Query projection
    #[allow(dead_code)]
    w_q: Array2<f64>,
    /// Key projection
    #[allow(dead_code)]
    w_k: Array2<f64>,
    /// Value projection
    #[allow(dead_code)]
    w_v: Array2<f64>,
    /// Attention vector
    v_a: Array1<f64>,
}

impl AdditiveAttention {
    /// Create new additive attention mechanism
    pub fn new(encoder_dim: usize, decoder_dim: usize, attention_dim: usize) -> Self {
        let scale = (2.0 / attention_dim as f64).sqrt();

        let w_a = Array2::from_shape_fn((attention_dim, encoder_dim + decoder_dim), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });

        let w_q = Array2::from_shape_fn((attention_dim, decoder_dim), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });

        let w_k = Array2::from_shape_fn((attention_dim, encoder_dim), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });

        let w_v = Array2::from_shape_fn((encoder_dim, encoder_dim), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });

        let v_a = Array1::from_shape_fn(attention_dim, |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });

        Self {
            w_a,
            w_q,
            w_k,
            w_v,
            v_a,
        }
    }

    /// Compute attention scores
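    ///
    /// Bahdanau-style additive attention over the encoder outputs:
    ///
    /// ```text
    /// e_i     = v_a^T tanh(W_a [query; h_i])
    /// alpha   = softmax(e)
    /// context = sum_i alpha_i * h_i
    /// ```
    ///
    /// Returns `(context, attention_weights)`.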
    pub fn forward(
        &self,
        query: ArrayView1<f64>,
        encoder_outputs: ArrayView2<f64>,
    ) -> Result<(Array1<f64>, Array1<f64>)> {
        let seq_len = encoder_outputs.shape()[0];
        let mut attention_scores = Array1::zeros(seq_len);

        // Compute attention scores for each encoder output
        for i in 0..seq_len {
            let encoder_output = encoder_outputs.row(i);

            // Concatenate query and encoder output
            let mut combined = Array1::zeros(query.len() + encoder_output.len());
            combined.slice_mut(s![..query.len()]).assign(&query);
            combined
                .slice_mut(s![query.len()..])
                .assign(&encoder_output);

            // Compute attention score
            let attention_input = self.w_a.dot(&combined);
            let activated = ActivationFunction::Tanh.apply_array(&attention_input);
            attention_scores[i] = self.v_a.dot(&activated);
        }

        // Apply softmax to get attention weights
        let attention_weights = self.softmax(&attention_scores);

        // Compute context vector
        let context = encoder_outputs.t().dot(&attention_weights);

        Ok((context, attention_weights))
    }

    /// Apply softmax to scores
    fn softmax(&self, scores: &Array1<f64>) -> Array1<f64> {
        let max_score = scores.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let exp_scores = scores.mapv(|x| (x - max_score).exp());
        let sum_exp = exp_scores.sum();
        exp_scores / sum_exp
    }
}

/// Self-attention mechanism (simplified transformer-style)
#[derive(Debug)]
pub struct SelfAttention {
    /// Query projection
    w_q: Array2<f64>,
    /// Key projection
    w_k: Array2<f64>,
    /// Value projection
    w_v: Array2<f64>,
    /// Output projection
    w_o: Array2<f64>,
    /// Attention dimension
    d_k: usize,
    /// Dropout rate
    #[allow(dead_code)]
    dropout: f64,
}

impl SelfAttention {
    /// Create new self-attention layer
    pub fn new(d_model: usize, dropout: f64) -> Self {
        let d_k = d_model;
        let scale = (2.0 / d_model as f64).sqrt();

        let w_q = Array2::from_shape_fn((d_model, d_k), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let w_k = Array2::from_shape_fn((d_model, d_k), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let w_v = Array2::from_shape_fn((d_model, d_k), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let w_o = Array2::from_shape_fn((d_k, d_model), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });

        Self {
            w_q,
            w_k,
            w_v,
            w_o,
            d_k,
            dropout,
        }
    }

    /// Forward pass through self-attention
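    ///
    /// Single-head scaled dot-product attention over one sequence:
    ///
    /// ```text
    /// Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
    /// ```
    ///
    /// Where `mask[[i, j]]` is `true`, position `i` is blocked from attending
    /// to position `j` (the score is set to negative infinity before softmax).
    ///
    /// ```
    /// use scirs2_core::ndarray::Array2;
    /// use scirs2_text::neural_architectures::SelfAttention;
    ///
    /// let attention = SelfAttention::new(8, 0.1);
    /// let input = Array2::ones((4, 8)); // 4 tokens, d_model = 8
    /// let output = attention.forward(input.view(), None).unwrap();
    /// assert_eq!(output.shape(), &[4, 8]);
    /// ```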
    pub fn forward(
        &self,
        input: ArrayView2<f64>,
        mask: Option<ArrayView2<bool>>,
    ) -> Result<Array2<f64>> {
        // Compute Q, K, V
        let q = input.dot(&self.w_q);
        let k = input.dot(&self.w_k);
        let v = input.dot(&self.w_v);

        // Scaled dot-product attention
        let attention_output =
            self.scaled_dot_product_attention(q.view(), k.view(), v.view(), mask)?;

        // Output projection
        Ok(attention_output.dot(&self.w_o))
    }

    /// Scaled dot-product attention computation
    fn scaled_dot_product_attention(
        &self,
        q: ArrayView2<f64>,
        k: ArrayView2<f64>,
        v: ArrayView2<f64>,
        mask: Option<ArrayView2<bool>>,
    ) -> Result<Array2<f64>> {
        let d_k = self.d_k as f64;

        // Compute attention scores: Q * K^T / sqrt(d_k)
        let scores = q.dot(&k.t()) / d_k.sqrt();

        // Apply mask if provided
        let mut masked_scores = scores;
        if let Some(mask) = mask {
            for ((i, j), &should_mask) in mask.indexed_iter() {
                if should_mask {
                    masked_scores[[i, j]] = f64::NEG_INFINITY;
                }
            }
        }

        // Apply softmax
        let attention_weights = self.softmax_2d(&masked_scores)?;

        // Apply attention to values
        Ok(attention_weights.dot(&v))
    }

    /// Apply softmax to 2D array along last axis
    fn softmax_2d(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
        let mut result = x.clone();

        for mut row in result.rows_mut() {
            let max_val = row.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
            row.mapv_inplace(|x| (x - max_val).exp());
            let sum: f64 = row.sum();
            if sum > 0.0 {
                row /= sum;
            }
        }

        Ok(result)
    }
}

/// Cross-attention mechanism for encoder-decoder architectures
#[derive(Debug)]
pub struct CrossAttention {
    /// Query projection
    w_q: Array2<f64>,
    /// Key projection
    w_k: Array2<f64>,
    /// Value projection
    w_v: Array2<f64>,
    /// Output projection
    w_o: Array2<f64>,
    /// Attention dimension
    d_k: usize,
}

impl CrossAttention {
    /// Create new cross-attention layer
    pub fn new(d_model: usize) -> Self {
        let d_k = d_model;
        let scale = (2.0 / d_model as f64).sqrt();

        let w_q = Array2::from_shape_fn((d_model, d_k), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let w_k = Array2::from_shape_fn((d_model, d_k), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let w_v = Array2::from_shape_fn((d_model, d_k), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let w_o = Array2::from_shape_fn((d_k, d_model), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });

        Self {
            w_q,
            w_k,
            w_v,
            w_o,
            d_k,
        }
    }

    /// Forward pass through cross-attention
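    ///
    /// `query` typically comes from the decoder while `key` and `value` come
    /// from the encoder; `key` and `value` must share the same sequence
    /// length. The output has the query's sequence length.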
    pub fn forward(
        &self,
        query: ArrayView2<f64>,
        key: ArrayView2<f64>,
        value: ArrayView2<f64>,
        mask: Option<ArrayView2<bool>>,
    ) -> Result<Array2<f64>> {
        // Compute Q, K, V
        let q = query.dot(&self.w_q);
        let k = key.dot(&self.w_k);
        let v = value.dot(&self.w_v);

        // Scaled dot-product attention
        let attention_output =
            self.scaled_dot_product_attention(q.view(), k.view(), v.view(), mask)?;

        // Output projection
        Ok(attention_output.dot(&self.w_o))
    }

    /// Scaled dot-product attention computation
    fn scaled_dot_product_attention(
        &self,
        q: ArrayView2<f64>,
        k: ArrayView2<f64>,
        v: ArrayView2<f64>,
        mask: Option<ArrayView2<bool>>,
    ) -> Result<Array2<f64>> {
        let d_k = self.d_k as f64;

        // Compute attention scores: Q * K^T / sqrt(d_k)
        let scores = q.dot(&k.t()) / d_k.sqrt();

        // Apply mask if provided
        let mut masked_scores = scores;
        if let Some(mask) = mask {
            for ((i, j), &should_mask) in mask.indexed_iter() {
                if should_mask {
                    masked_scores[[i, j]] = f64::NEG_INFINITY;
                }
            }
        }

        // Apply softmax
        let attention_weights = self.softmax_2d(&masked_scores)?;

        // Apply attention to values
        Ok(attention_weights.dot(&v))
    }

    /// Apply softmax to 2D array along last axis
    fn softmax_2d(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
        let mut result = x.clone();

        for mut row in result.rows_mut() {
            let max_val = row.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
            row.mapv_inplace(|x| (x - max_val).exp());
            let sum: f64 = row.sum();
            if sum > 0.0 {
                row /= sum;
            }
        }

        Ok(result)
    }
}

/// Position-wise feed-forward network with GELU activation
#[derive(Debug)]
pub struct PositionwiseFeedForward {
    /// First linear transformation
    w1: Array2<f64>,
    /// Second linear transformation
    w2: Array2<f64>,
    /// Bias vectors
    b1: Array1<f64>,
    b2: Array1<f64>,
    /// Dropout rate
    dropout: f64,
}

impl PositionwiseFeedForward {
    /// Create new position-wise feed-forward network
    pub fn new(d_model: usize, d_ff: usize, dropout: f64) -> Self {
        let scale1 = (2.0 / d_model as f64).sqrt();
        let scale2 = (2.0 / d_ff as f64).sqrt();

        let w1 = Array2::from_shape_fn((d_ff, d_model), |_| {
            scirs2_core::random::rng().random_range(-scale1..scale1)
        });
        let w2 = Array2::from_shape_fn((d_model, d_ff), |_| {
            scirs2_core::random::rng().random_range(-scale2..scale2)
        });
        let b1 = Array1::zeros(d_ff);
        let b2 = Array1::zeros(d_model);

        Self {
            w1,
            w2,
            b1,
            b2,
            dropout,
        }
    }

    /// Forward pass through feed-forward network
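    ///
    /// ```text
    /// FFN(x) = Dropout(GELU(x W1^T + b1)) W2^T + b2
    /// ```
    ///
    /// Note that dropout is approximated here by a constant `1 - dropout`
    /// scaling rather than a random mask (see the inline comment below).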
    pub fn forward(&self, x: ArrayView2<f64>) -> Array2<f64> {
        // First linear transformation + GELU
        let hidden = x.dot(&self.w1.t()) + &self.b1;
        let activated = hidden.mapv(|x| ActivationFunction::GELU.apply(x));

        // Apply dropout (simplified - in practice would use random mask)
        let dropout_mask = if self.dropout > 0.0 {
            1.0 - self.dropout
        } else {
            1.0
        };
        let dropped = activated * dropout_mask;

        // Second linear transformation
        dropped.dot(&self.w2.t()) + &self.b2
    }
}

/// Text CNN architecture for classification
pub struct TextCNN {
    /// Convolutional layers with different kernel sizes
    conv_layers: Vec<Conv1D>,
    /// Max pooling layers
    pool_layers: Vec<MaxPool1D>,
    /// Fully connected layer weights
    fc_weights: Array2<f64>,
    /// Fully connected layer bias
    fc_bias: Array1<f64>,
    /// Dropout rate
    dropout_rate: f64,
}

impl TextCNN {
    /// Create new Text CNN
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        _vocab_size: usize,
        embedding_dim: usize,
        num_filters: usize,
        filter_sizes: Vec<usize>,
        num_classes: usize,
        dropout_rate: f64,
    ) -> Self {
        let mut conv_layers = Vec::new();
        let mut pool_layers = Vec::new();

        // Create convolutional layers with different filter sizes
        for &filter_size in &filter_sizes {
            conv_layers.push(Conv1D::new(
                embedding_dim,
                num_filters,
                filter_size,
                ActivationFunction::ReLU,
            ));
            pool_layers.push(MaxPool1D::new(2, 2));
        }

        // Fully connected layer
        let fc_input_size = num_filters * filter_sizes.len();
        let scale = (2.0 / fc_input_size as f64).sqrt();

        let fc_weights = Array2::from_shape_fn((num_classes, fc_input_size), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let fc_bias = Array1::zeros(num_classes);

        Self {
            conv_layers,
            pool_layers,
            fc_weights,
            fc_bias,
            dropout_rate,
        }
    }

    /// Forward pass through Text CNN
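    ///
    /// Expects `embeddings` of shape `(seq_len, embedding_dim)` and returns
    /// unnormalized class scores of length `num_classes`.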
    pub fn forward(&self, embeddings: ArrayView2<f64>) -> Result<Array1<f64>> {
        let mut feature_maps = Vec::new();

        // Apply each convolutional layer
        for (conv_layer, pool_layer) in self.conv_layers.iter().zip(&self.pool_layers) {
            let conv_output = conv_layer.forward(embeddings)?;
            let pooled_output = pool_layer.forward(conv_output.view());

            // Global max pooling over sequence dimension
            let global_max = pooled_output.map_axis(Axis(0), |row| {
                row.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b))
            });

            feature_maps.push(global_max);
        }

        // Concatenate all feature maps
        let mut concatenated_features =
            Array1::zeros(feature_maps.iter().map(|fm| fm.len()).sum::<usize>());
        let mut offset = 0;
        for feature_map in feature_maps {
            let end = offset + feature_map.len();
            concatenated_features
                .slice_mut(s![offset..end])
                .assign(&feature_map);
            offset = end;
        }

        // Apply dropout (simplified - in practice would use random mask)
        let dropout_mask = if self.dropout_rate > 0.0 {
            1.0 - self.dropout_rate
        } else {
            1.0
        };
        concatenated_features *= dropout_mask;

        // Fully connected layer
        let output = self.fc_weights.dot(&concatenated_features) + &self.fc_bias;

        Ok(output)
    }
}

/// Hybrid CNN-LSTM architecture
pub struct CNNLSTMHybrid {
    /// CNN feature extractor
    cnn: TextCNN,
    /// LSTM sequence processor
    lstm: BiLSTM,
    /// Final classification layer
    classifier: Array2<f64>,
    /// Classifier bias
    classifier_bias: Array1<f64>,
}

impl CNNLSTMHybrid {
    /// Create new CNN-LSTM hybrid model
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        embedding_dim: usize,
        cnn_filters: usize,
        filter_sizes: Vec<usize>,
        lstm_hidden_size: usize,
        lstm_layers: usize,
        num_classes: usize,
    ) -> Self {
        // CNN for local feature extraction
        let cnn = TextCNN::new(
            0, // vocab_size not needed for feature extraction
            embedding_dim,
            cnn_filters,
            filter_sizes.clone(),
            cnn_filters * filter_sizes.len(),
            0.0, // No dropout in feature extraction
        );

        // BiLSTM for sequence modeling
        let lstm_input_size = cnn_filters * filter_sizes.len();
        let lstm = BiLSTM::new(lstm_input_size, lstm_hidden_size, lstm_layers);

        // Final classifier
        let classifier_input_size = lstm_hidden_size * 2; // Bidirectional
        let scale = (2.0 / classifier_input_size as f64).sqrt();

        let classifier = Array2::from_shape_fn((num_classes, classifier_input_size), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let classifier_bias = Array1::zeros(num_classes);

        Self {
            cnn,
            lstm,
            classifier,
            classifier_bias,
        }
    }

    /// Forward pass through hybrid model
    pub fn forward(&self, embeddings: ArrayView2<f64>) -> Result<Array1<f64>> {
        // Extract CNN features (this is simplified - would need proper implementation)
        let cnn_features = self.cnn.forward(embeddings)?;

        // Reshape for LSTM input (simplified)
        let lstm_input = Array2::from_shape_vec((1, cnn_features.len()), cnn_features.to_vec())
            .map_err(|e| TextError::InvalidInput(format!("Reshape error: {e}")))?;

        // Process through LSTM
        let lstm_output = self.lstm.forward(lstm_input.view())?;

        // Take last timestep output
        let final_hidden = lstm_output.row(lstm_output.shape()[0] - 1);

        // Final classification
        let output = self.classifier.dot(&final_hidden) + &self.classifier_bias;

        Ok(output)
    }
}

/// Layer normalization for neural networks
pub struct LayerNorm {
    /// Learnable scale parameters
    weight: Array1<f64>,
    /// Learnable bias parameters
    bias: Array1<f64>,
    /// Small epsilon for numerical stability
    eps: f64,
}

impl LayerNorm {
    /// Create new layer normalization layer
    pub fn new(normalized_shape: usize) -> Self {
        Self {
            weight: Array1::ones(normalized_shape),
            bias: Array1::zeros(normalized_shape),
            eps: 1e-6,
        }
    }

    /// Forward pass with layer normalization
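    ///
    /// Each row (the last dimension) is normalized independently:
    ///
    /// ```text
    /// y = (x - mean(x)) / sqrt(var(x) + eps) * weight + bias
    /// ```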
    pub fn forward(&self, x: ArrayView2<f64>) -> Result<Array2<f64>> {
        let mut output = Array2::zeros(x.raw_dim());

        // Normalize along the last dimension for each sample
        for (i, row) in x.outer_iter().enumerate() {
            let mean = row.mean();
            let variance = row.mapv(|v| (v - mean).powi(2)).mean();
            let std = (variance + self.eps).sqrt();

            // Apply normalization and learned parameters
            for (j, &val) in row.iter().enumerate() {
                let normalized = (val - mean) / std;
                output[[i, j]] = normalized * self.weight[j] + self.bias[j];
            }
        }

        Ok(output)
    }
}

/// Dropout layer for regularization
pub struct Dropout {
    /// Dropout probability
    p: f64,
    /// Whether the layer is in training mode
    training: bool,
}

impl Dropout {
    /// Create new dropout layer
    pub fn new(p: f64) -> Self {
        Self {
            p: p.clamp(0.0, 1.0),
            training: true,
        }
    }

    /// Set training mode
    pub fn set_training(&mut self, training: bool) {
        self.training = training;
    }

    /// Forward pass with dropout
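    ///
    /// Uses inverted dropout: each element is zeroed with probability `p`,
    /// and survivors are scaled by `1 / (1 - p)` so the expected value of
    /// every activation is unchanged. In evaluation mode
    /// (`training == false`) the input passes through untouched.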
    pub fn forward(&self, x: ArrayView2<f64>) -> Array2<f64> {
        if !self.training || self.p == 0.0 {
            return x.to_owned();
        }

        let mut output = x.to_owned();
        let scale = 1.0 / (1.0 - self.p);

        for elem in output.iter_mut() {
            if scirs2_core::random::rng().random_range(0.0..1.0) < self.p {
                *elem = 0.0; // Drop the element
            } else {
                *elem *= scale; // Scale to maintain expected value
            }
        }

        output
    }
}

/// Multi-head attention mechanism
pub struct MultiHeadAttention {
    /// Number of attention heads
    num_heads: usize,
    /// Model dimension
    d_model: usize,
    /// Dimension per head
    d_k: usize,
    /// Query projection weights
    w_q: Array2<f64>,
    /// Key projection weights
    w_k: Array2<f64>,
    /// Value projection weights
    w_v: Array2<f64>,
    /// Output projection weights
    w_o: Array2<f64>,
    /// Dropout layer
    dropout: Dropout,
}

impl MultiHeadAttention {
    /// Create new multi-head attention layer
    pub fn new(d_model: usize, num_heads: usize, dropout_p: f64) -> Result<Self> {
        if !d_model.is_multiple_of(num_heads) {
            return Err(TextError::InvalidInput(
                "Model dimension must be divisible by number of heads".to_string(),
            ));
        }

        let d_k = d_model / num_heads;
        let scale = (2.0 / d_model as f64).sqrt();

        let w_q = Array2::from_shape_fn((d_model, d_model), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let w_k = Array2::from_shape_fn((d_model, d_model), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let w_v = Array2::from_shape_fn((d_model, d_model), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });
        let w_o = Array2::from_shape_fn((d_model, d_model), |_| {
            scirs2_core::random::rng().random_range(-scale..scale)
        });

        Ok(Self {
            num_heads,
            d_model,
            d_k,
            w_q,
            w_k,
            w_v,
            w_o,
            dropout: Dropout::new(dropout_p),
        })
    }

    /// Forward pass through multi-head attention
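    ///
    /// The inputs are projected once, split into `num_heads` heads of size
    /// `d_k = d_model / num_heads`, attended per head with
    /// `softmax(Q K^T / sqrt(d_k)) V`, then concatenated and projected back
    /// to `d_model`:
    ///
    /// ```
    /// use scirs2_core::ndarray::Array2;
    /// use scirs2_text::neural_architectures::MultiHeadAttention;
    ///
    /// let mha = MultiHeadAttention::new(8, 2, 0.0).unwrap();
    /// let x = Array2::ones((4, 8)); // 4 tokens, d_model = 8
    /// let out = mha.forward(x.view(), x.view(), x.view(), None).unwrap();
    /// assert_eq!(out.shape(), &[4, 8]);
    /// ```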
    pub fn forward(
        &self,
        query: ArrayView2<f64>,
        key: ArrayView2<f64>,
        value: ArrayView2<f64>,
        mask: Option<ArrayView2<bool>>,
    ) -> Result<Array2<f64>> {
        let seq_len = query.shape()[0]; // operates on a single (unbatched) sequence

        // Linear projections
        let q = query.dot(&self.w_q);
        let k = key.dot(&self.w_k);
        let v = value.dot(&self.w_v);

        // Reshape for multi-head attention [seq_len, d_model] -> [seq_len, num_heads, d_k]
        let mut q_heads = Array3::zeros((seq_len, self.num_heads, self.d_k));
        let mut k_heads = Array3::zeros((seq_len, self.num_heads, self.d_k));
        let mut v_heads = Array3::zeros((seq_len, self.num_heads, self.d_k));

        for i in 0..seq_len {
            for h in 0..self.num_heads {
                let start = h * self.d_k;

                for j in 0..self.d_k {
                    q_heads[[i, h, j]] = q[[i, start + j]];
                    k_heads[[i, h, j]] = k[[i, start + j]];
                    v_heads[[i, h, j]] = v[[i, start + j]];
                }
            }
        }

        // Apply scaled dot-product attention for each head
        let mut attention_outputs = Array3::zeros((seq_len, self.num_heads, self.d_k));

        for h in 0..self.num_heads {
            let q_h = q_heads.slice(s![.., h, ..]);
            let k_h = k_heads.slice(s![.., h, ..]);
            let v_h = v_heads.slice(s![.., h, ..]);

            // Compute attention scores: Q * K^T / sqrt(d_k)
            let scores = q_h.dot(&k_h.t()) / (self.d_k as f64).sqrt();

            // Apply mask if provided
            let mut masked_scores = scores;
            if let Some(mask) = mask {
                for i in 0..seq_len {
                    for j in 0..seq_len {
                        if mask[[i, j]] {
                            masked_scores[[i, j]] = f64::NEG_INFINITY;
                        }
                    }
                }
            }

            // Apply softmax
            let mut attention_weights = Array2::zeros((seq_len, seq_len));
            for i in 0..seq_len {
                let row = masked_scores.row(i);
                let max_val = row.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
                let exp_sum: f64 = row.iter().map(|&x| (x - max_val).exp()).sum();

                for j in 0..seq_len {
                    attention_weights[[i, j]] = (masked_scores[[i, j]] - max_val).exp() / exp_sum;
                }
            }

            // Apply dropout to attention weights
            let attention_weights_dropped = self.dropout.forward(attention_weights.view());

            // Apply attention to values
            let attended = attention_weights_dropped.dot(&v_h);

            // Store result for this head
            for i in 0..seq_len {
                for j in 0..self.d_k {
                    attention_outputs[[i, h, j]] = attended[[i, j]];
                }
            }
        }

        // Concatenate heads and reshape back to [seq_len, d_model]
        let mut concatenated = Array2::zeros((seq_len, self.d_model));
        for i in 0..seq_len {
            for h in 0..self.num_heads {
                let start = h * self.d_k;
                for j in 0..self.d_k {
                    concatenated[[i, start + j]] = attention_outputs[[i, h, j]];
                }
            }
        }

        // Final output projection
        Ok(concatenated.dot(&self.w_o))
    }

    /// Set training mode for dropout
    pub fn set_training(&mut self, training: bool) {
        self.dropout.set_training(training);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_activation_functions() {
        let x = 0.5;

        // Test that activations produce reasonable outputs
        assert!(ActivationFunction::Sigmoid.apply(x) > 0.0);
        assert!(ActivationFunction::Sigmoid.apply(x) < 1.0);
        assert!(ActivationFunction::Tanh.apply(x) > -1.0);
        assert!(ActivationFunction::Tanh.apply(x) < 1.0);
        assert_eq!(ActivationFunction::ReLU.apply(-1.0), 0.0);
        assert_eq!(ActivationFunction::ReLU.apply(1.0), 1.0);
    }

    #[test]
    fn test_lstm_cell() {
        let lstm = LSTMCell::new(10, 20);
        let input = Array1::ones(10);
        let h_prev = Array1::zeros(20);
        let c_prev = Array1::zeros(20);

        let (h_new, c_new) = lstm
            .forward(input.view(), h_prev.view(), c_prev.view())
            .unwrap();

        assert_eq!(h_new.len(), 20);
        assert_eq!(c_new.len(), 20);
    }

    #[test]
    fn test_conv1d() {
        let conv = Conv1D::new(5, 10, 3, ActivationFunction::ReLU);
        let input = Array2::ones((8, 5)); // Sequence length 8, 5 channels

        let output = conv.forward(input.view()).unwrap();
        assert_eq!(output.shape(), &[6, 10]); // (8-3+1, 10)
    }

    #[test]
    fn test_bilstm() {
        let bilstm = BiLSTM::new(10, 20, 2);
        let input = Array2::ones((5, 10)); // 5 timesteps, 10 features

        let output = bilstm.forward(input.view()).unwrap();
        assert_eq!(output.shape(), &[5, 40]); // Bidirectional doubles output size
    }

    #[test]
    fn test_gru_cell() {
        let gru = GRUCell::new(10, 20);
        let input = Array1::ones(10);
        let h_prev = Array1::zeros(20);

        let h_new = gru.forward(input.view(), h_prev.view()).unwrap();

        assert_eq!(h_new.len(), 20);
        // Check that output is not all zeros (some processing happened)
        assert!(h_new.iter().any(|&x| x != 0.0));
    }

    #[test]
    fn test_self_attention() {
        let attention = SelfAttention::new(8, 0.1);
        let input = Array2::ones((4, 8)); // 4 tokens, 8 dimensions

        let output = attention.forward(input.view(), None).unwrap();
        assert_eq!(output.shape(), &[4, 8]);
    }

    #[test]
    fn test_cross_attention() {
        let attention = CrossAttention::new(8);
        let query = Array2::ones((3, 8));
        let key = Array2::ones((5, 8));
        let value = Array2::ones((5, 8));

        let output = attention
            .forward(query.view(), key.view(), value.view(), None)
            .unwrap();
        assert_eq!(output.shape(), &[3, 8]);
    }

    #[test]
    fn test_residual_block() {
        let block = ResidualBlock1D::new(4, 8, 3);
        let input = Array2::ones((10, 4)); // 10 sequence length, 4 channels

        let output = block.forward(input.view()).unwrap();
        // Two conv layers with kernel_size=3 reduce sequence: 10 -> 8 -> 6
        assert_eq!(output.shape(), &[6, 8]); // Correct convolution output shape
    }

    #[test]
    fn test_multi_scale_cnn() {
        let cnn = MultiScaleCNN::new(
            5,             // input channels
            10,            // filters per scale
            vec![2, 3, 4], // kernel sizes
            30,            // output size
        );
        let input = Array2::ones((8, 5)); // 8 sequence length, 5 channels

        let output = cnn.forward(input.view()).unwrap();
        assert_eq!(output.len(), 30);
    }

    #[test]
    fn test_positionwise_feedforward() {
        let ff = PositionwiseFeedForward::new(8, 16, 0.1);
        let input = Array2::ones((4, 8)); // 4 tokens, 8 dimensions

        let output = ff.forward(input.view());
        assert_eq!(output.shape(), &[4, 8]);
    }
}