Skip to main content

scirs2_sparse/neural_adaptive_sparse/
neural_network.rs

1//! Neural network components for sparse matrix optimization
2//!
3//! This module contains the neural network architectures used in the adaptive
4//! sparse matrix processing system.
5
6use scirs2_core::numeric::{Float, SparseElement};
7use scirs2_core::random::{Rng, RngExt};
8use std::collections::HashMap;
9
/// Neural network layer for sparse matrix optimization
///
/// A fully connected layer: one weight row and one bias per output neuron,
/// followed by an element-wise activation.
#[derive(Debug, Clone)]
pub(crate) struct NeuralLayer {
    /// Weight matrix indexed as `weights[neuron][input]`.
    pub weights: Vec<Vec<f64>>,
    /// One bias per output neuron; its length determines the layer's output width.
    pub biases: Vec<f64>,
    /// Activation applied element-wise to each neuron's weighted sum.
    pub activation: ActivationFunction,
}
17
/// Activation functions for neural network layers
///
/// The enum is a plain tag: the formulas themselves live in
/// `NeuralNetwork::apply_activation` and `NeuralNetwork::activation_derivative`.
/// It derives the common comparison and hashing traits so that variants can be
/// compared with `==` and used as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ActivationFunction {
    /// Rectified linear unit, `max(x, 0)`.
    ReLU,
    /// Logistic sigmoid, `1 / (1 + e^-x)`.
    Sigmoid,
    /// Hyperbolic tangent.
    // Not selected by `NeuralNetwork::new`, hence the lint allowance.
    #[allow(dead_code)]
    Tanh,
    /// Swish, `x * sigmoid(x)`.
    // Not selected by `NeuralNetwork::new`, hence the lint allowance.
    #[allow(dead_code)]
    Swish,
    /// GELU (tanh-based approximation).
    // Not selected by `NeuralNetwork::new`, hence the lint allowance.
    #[allow(dead_code)]
    Gelu,
}
30
/// Neural network for sparse matrix optimization
///
/// Bundles the dense layers with the attention heads and per-layer
/// normalization parameters that the forward passes consume.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub(crate) struct NeuralNetwork {
    /// Dense layers in evaluation order (input layer first, output layer last).
    pub layers: Vec<NeuralLayer>,
    /// One weight vector per attention head; `new` fills these with ones.
    /// Not read by the methods visible in this module.
    pub attention_weights: Vec<Vec<f64>>,
    /// Multi-head attention mechanisms
    pub attention_heads: Vec<AttentionHead>,
    /// Layer normalization parameters
    /// (entry `i` is applied to the output of `layers[i]`).
    pub layer_norms: Vec<LayerNorm>,
}
42
/// Multi-head attention mechanism
///
/// A single head: three projections from the model width down to `head_dim`
/// and one projection back up to the model width.
#[derive(Debug, Clone)]
pub(crate) struct AttentionHead {
    /// Query projection, one row per head dimension (`head_dim` x `model_dim`).
    pub query_weights: Vec<Vec<f64>>,
    /// Key projection, same layout as `query_weights`.
    pub key_weights: Vec<Vec<f64>>,
    /// Value projection, same layout as `query_weights`.
    pub value_weights: Vec<Vec<f64>>,
    /// Output projection back to the model width (`model_dim` x `head_dim`).
    pub output_weights: Vec<Vec<f64>>,
    /// Width of the projected query/key/value vectors; also used to scale the score.
    pub head_dim: usize,
}
52
/// Layer normalization
///
/// Parameters for normalizing a vector to zero mean and unit variance and then
/// applying a per-element affine transform.
#[derive(Debug, Clone)]
pub(crate) struct LayerNorm {
    /// Per-element scale applied after normalization.
    pub gamma: Vec<f64>,
    /// Per-element shift added after scaling.
    pub beta: Vec<f64>,
    /// Small constant added to the variance before taking the square root.
    pub eps: f64,
}
60
/// Forward cache for neural network computations
///
/// Intermediate values recorded by `NeuralNetwork::forward_with_cache`.
#[derive(Debug, Clone)]
pub(crate) struct ForwardCache {
    /// Pre-activation output (weighted sum plus bias) of every layer, in layer order.
    pub layer_outputs: Vec<Vec<f64>>,
    /// Output of the attention step for each layer where attention was applied.
    pub attention_outputs: Vec<Vec<f64>>,
    /// Output of each layer after layer normalization.
    pub normalized_outputs: Vec<Vec<f64>>,
}
68
/// Network gradients for backpropagation
///
/// Mirrors the shape of the network's parameters so that
/// `NeuralNetwork::update_weights` can apply them element by element.
#[derive(Debug, Clone)]
pub(crate) struct NetworkGradients {
    /// Weight gradients indexed as `[layer][neuron][input]`.
    pub weight_gradients: Vec<Vec<Vec<f64>>>,
    /// Bias gradients indexed as `[layer][neuron]`.
    pub bias_gradients: Vec<Vec<f64>>,
}
75
76impl NeuralNetwork {
77    /// Create a new neural network with specified architecture
78    pub fn new(
79        input_size: usize,
80        hidden_layers: usize,
81        neurons_per_layer: usize,
82        output_size: usize,
83        attention_heads: usize,
84    ) -> Self {
85        let mut layers = Vec::new();
86        let mut layer_norms = Vec::new();
87
88        // Input layer
89        let input_layer = NeuralLayer {
90            weights: Self::initialize_weights(input_size, neurons_per_layer),
91            biases: vec![0.0; neurons_per_layer],
92            activation: ActivationFunction::ReLU,
93        };
94        layers.push(input_layer);
95        layer_norms.push(LayerNorm::new(neurons_per_layer));
96
97        // Hidden layers
98        for _ in 0..hidden_layers.saturating_sub(1) {
99            let layer = NeuralLayer {
100                weights: Self::initialize_weights(neurons_per_layer, neurons_per_layer),
101                biases: vec![0.0; neurons_per_layer],
102                activation: ActivationFunction::ReLU,
103            };
104            layers.push(layer);
105            layer_norms.push(LayerNorm::new(neurons_per_layer));
106        }
107
108        // Output layer
109        let output_layer = NeuralLayer {
110            weights: Self::initialize_weights(neurons_per_layer, output_size),
111            biases: vec![0.0; output_size],
112            activation: ActivationFunction::Sigmoid,
113        };
114        layers.push(output_layer);
115        layer_norms.push(LayerNorm::new(output_size));
116
117        // Initialize attention heads
118        let mut attention_heads_vec = Vec::new();
119        for _ in 0..attention_heads {
120            attention_heads_vec.push(AttentionHead::new(neurons_per_layer));
121        }
122
123        Self {
124            layers,
125            attention_weights: vec![vec![1.0; neurons_per_layer]; attention_heads],
126            attention_heads: attention_heads_vec,
127            layer_norms,
128        }
129    }
130
131    /// Initialize weights using Xavier initialization
132    fn initialize_weights(input_size: usize, output_size: usize) -> Vec<Vec<f64>> {
133        let mut rng = scirs2_core::random::thread_rng();
134        let bound = (6.0 / (input_size + output_size) as f64).sqrt();
135
136        (0..output_size)
137            .map(|_| {
138                (0..input_size)
139                    .map(|_| rng.random_range(-bound..bound))
140                    .collect()
141            })
142            .collect()
143    }
144
145    /// Forward pass through the network
146    pub fn forward(&self, input: &[f64]) -> Vec<f64> {
147        let mut current_input = input.to_vec();
148
149        for (i, layer) in self.layers.iter().enumerate() {
150            let mut output = vec![0.0; layer.biases.len()];
151
152            // Linear transformation
153            for (j, neuron_weights) in layer.weights.iter().enumerate() {
154                let mut sum = layer.biases[j];
155                for (k, &input_val) in current_input.iter().enumerate() {
156                    sum += neuron_weights[k] * input_val;
157                }
158                output[j] = sum;
159            }
160
161            // Apply activation function
162            for val in &mut output {
163                *val = Self::apply_activation(*val, layer.activation);
164            }
165
166            // Apply layer normalization
167            if i < self.layer_norms.len() {
168                output = self.layer_norms[i].normalize(&output);
169            }
170
171            current_input = output;
172        }
173
174        current_input
175    }
176
177    /// Forward pass with caching for backpropagation
178    pub fn forward_with_cache(&self, input: &[f64]) -> (Vec<f64>, ForwardCache) {
179        let mut layer_outputs = Vec::new();
180        let mut attention_outputs = Vec::new();
181        let mut normalized_outputs = Vec::new();
182        let mut current_input = input.to_vec();
183
184        for (i, layer) in self.layers.iter().enumerate() {
185            let mut output = vec![0.0; layer.biases.len()];
186
187            // Linear transformation
188            for (j, neuron_weights) in layer.weights.iter().enumerate() {
189                let mut sum = layer.biases[j];
190                for (k, &input_val) in current_input.iter().enumerate() {
191                    sum += neuron_weights[k] * input_val;
192                }
193                output[j] = sum;
194            }
195
196            layer_outputs.push(output.clone());
197
198            // Apply activation function
199            for val in &mut output {
200                *val = Self::apply_activation(*val, layer.activation);
201            }
202
203            // Apply attention if not the last layer
204            if i < self.layers.len() - 1 && !self.attention_heads.is_empty() {
205                let attention_output = self.apply_attention(&output, i);
206                attention_outputs.push(attention_output.clone());
207                output = attention_output;
208            }
209
210            // Apply layer normalization
211            if i < self.layer_norms.len() {
212                output = self.layer_norms[i].normalize(&output);
213                normalized_outputs.push(output.clone());
214            }
215
216            current_input = output;
217        }
218
219        let cache = ForwardCache {
220            layer_outputs,
221            attention_outputs,
222            normalized_outputs,
223        };
224
225        (current_input, cache)
226    }
227
228    /// Apply attention mechanism
229    fn apply_attention(&self, input: &[f64], layer_idx: usize) -> Vec<f64> {
230        if self.attention_heads.is_empty() {
231            return input.to_vec();
232        }
233
234        let mut attention_output = vec![0.0; input.len()];
235        let num_heads = self.attention_heads.len();
236
237        for head in &self.attention_heads {
238            let head_output = head.forward(input);
239            for (i, &val) in head_output.iter().enumerate() {
240                if i < attention_output.len() {
241                    attention_output[i] += val / num_heads as f64;
242                }
243            }
244        }
245
246        // Add residual connection
247        for (i, &input_val) in input.iter().enumerate() {
248            if i < attention_output.len() {
249                attention_output[i] += input_val;
250            }
251        }
252
253        attention_output
254    }
255
256    /// Apply activation function
257    fn apply_activation(x: f64, activation: ActivationFunction) -> f64 {
258        match activation {
259            ActivationFunction::ReLU => x.max(0.0),
260            ActivationFunction::Sigmoid => 1.0 / (1.0 + (-x).exp()),
261            ActivationFunction::Tanh => x.tanh(),
262            ActivationFunction::Swish => x * (1.0 / (1.0 + (-x).exp())),
263            ActivationFunction::Gelu => 0.5 * x * (1.0 + (x * 0.7978845608028654).tanh()),
264        }
265    }
266
267    /// Update weights using gradients
268    pub fn update_weights(&mut self, gradients: &NetworkGradients, learning_rate: f64) {
269        for (layer_idx, layer) in self.layers.iter_mut().enumerate() {
270            if layer_idx < gradients.weight_gradients.len() {
271                let layer_weight_grads = &gradients.weight_gradients[layer_idx];
272                for (neuron_idx, neuron_weights) in layer.weights.iter_mut().enumerate() {
273                    if neuron_idx < layer_weight_grads.len() {
274                        let neuron_grads = &layer_weight_grads[neuron_idx];
275                        for (weight_idx, weight) in neuron_weights.iter_mut().enumerate() {
276                            if weight_idx < neuron_grads.len() {
277                                *weight -= learning_rate * neuron_grads[weight_idx];
278                            }
279                        }
280                    }
281                }
282            }
283
284            if layer_idx < gradients.bias_gradients.len() {
285                let bias_grads = &gradients.bias_gradients[layer_idx];
286                for (bias_idx, bias) in layer.biases.iter_mut().enumerate() {
287                    if bias_idx < bias_grads.len() {
288                        *bias -= learning_rate * bias_grads[bias_idx];
289                    }
290                }
291            }
292        }
293    }
294
295    /// Derivative of the activation function evaluated at pre-activation value `x`.
296    fn activation_derivative(x: f64, activation: ActivationFunction) -> f64 {
297        match activation {
298            ActivationFunction::ReLU => {
299                if x > 0.0 {
300                    1.0
301                } else {
302                    0.0
303                }
304            }
305            ActivationFunction::Sigmoid => {
306                let s = 1.0 / (1.0 + (-x).exp());
307                s * (1.0 - s)
308            }
309            ActivationFunction::Tanh => {
310                let t = x.tanh();
311                1.0 - t * t
312            }
313            ActivationFunction::Swish => {
314                let s = 1.0 / (1.0 + (-x).exp());
315                s + x * s * (1.0 - s)
316            }
317            ActivationFunction::Gelu => {
318                let c = 0.7978845608028654;
319                let t = (c * x).tanh();
320                0.5 * (1.0 + t) + 0.5 * x * c * (1.0 - t * t)
321            }
322        }
323    }
324
325    /// Compute network gradients via backpropagation.
326    ///
327    /// Uses cached pre-activation outputs to compute dL/dW and dL/db
328    /// for each layer, where L = 0.5 * ||output - target||^2.
329    pub fn compute_gradients(
330        &self,
331        input: &[f64],
332        target: &[f64],
333        cache: &ForwardCache,
334    ) -> NetworkGradients {
335        let num_layers = self.layers.len();
336        let mut weight_gradients: Vec<Vec<Vec<f64>>> = Vec::with_capacity(num_layers);
337        let mut bias_gradients: Vec<Vec<f64>> = Vec::with_capacity(num_layers);
338
339        // Collect the input to each layer during the forward pass.
340        let mut layer_inputs: Vec<Vec<f64>> = Vec::with_capacity(num_layers);
341        {
342            let mut current = input.to_vec();
343            for (l, layer) in self.layers.iter().enumerate() {
344                layer_inputs.push(current.clone());
345                let mut output = vec![0.0; layer.biases.len()];
346                for (j, neuron_w) in layer.weights.iter().enumerate() {
347                    let mut s = layer.biases[j];
348                    for (k, &iv) in current.iter().enumerate() {
349                        if k < neuron_w.len() {
350                            s += neuron_w[k] * iv;
351                        }
352                    }
353                    output[j] = Self::apply_activation(s, layer.activation);
354                }
355                if l < self.layer_norms.len() {
356                    output = self.layer_norms[l].normalize(&output);
357                }
358                current = output;
359            }
360        }
361
362        // Compute output from cache
363        let last_output = if !cache.layer_outputs.is_empty() {
364            let pre_act = &cache.layer_outputs[num_layers - 1];
365            let act = self.layers[num_layers - 1].activation;
366            pre_act
367                .iter()
368                .map(|&z| Self::apply_activation(z, act))
369                .collect::<Vec<_>>()
370        } else {
371            self.forward(input)
372        };
373
374        let output_pre_act = if num_layers <= cache.layer_outputs.len() {
375            cache.layer_outputs[num_layers - 1].clone()
376        } else {
377            last_output.clone()
378        };
379
380        let out_activation = self.layers[num_layers - 1].activation;
381        let output_size = self.layers[num_layers - 1].biases.len();
382        let mut delta = vec![0.0; output_size];
383        for i in 0..output_size {
384            let z = if i < output_pre_act.len() {
385                output_pre_act[i]
386            } else {
387                0.0
388            };
389            let o = if i < last_output.len() {
390                last_output[i]
391            } else {
392                0.0
393            };
394            let t = if i < target.len() { target[i] } else { 0.0 };
395            delta[i] = (o - t) * Self::activation_derivative(z, out_activation);
396        }
397
398        // Pre-allocate gradient storage
399        for layer in &self.layers {
400            let n_out = layer.biases.len();
401            let mut wg = Vec::with_capacity(n_out);
402            for neuron_w in &layer.weights {
403                wg.push(vec![0.0; neuron_w.len()]);
404            }
405            weight_gradients.push(wg);
406            bias_gradients.push(vec![0.0; n_out]);
407        }
408
409        // Backpropagate through layers in reverse
410        for l in (0..num_layers).rev() {
411            let layer = &self.layers[l];
412            let layer_in = &layer_inputs[l];
413
414            for j in 0..layer.biases.len() {
415                if j < delta.len() {
416                    bias_gradients[l][j] = delta[j];
417                    for k in 0..layer.weights[j].len() {
418                        let inp_val = if k < layer_in.len() { layer_in[k] } else { 0.0 };
419                        weight_gradients[l][j][k] = delta[j] * inp_val;
420                    }
421                }
422            }
423
424            if l > 0 {
425                let prev_layer = &self.layers[l - 1];
426                let prev_pre_act = if l - 1 < cache.layer_outputs.len() {
427                    &cache.layer_outputs[l - 1]
428                } else {
429                    &layer_inputs[l]
430                };
431                let prev_activation = prev_layer.activation;
432                let prev_size = prev_layer.biases.len();
433
434                let mut new_delta = vec![0.0; prev_size];
435                for k in 0..prev_size {
436                    let mut sum = 0.0;
437                    for j in 0..layer.biases.len() {
438                        if j < delta.len() && k < layer.weights[j].len() {
439                            sum += delta[j] * layer.weights[j][k];
440                        }
441                    }
442                    let z = if k < prev_pre_act.len() {
443                        prev_pre_act[k]
444                    } else {
445                        0.0
446                    };
447                    new_delta[k] = sum * Self::activation_derivative(z, prev_activation);
448                }
449                delta = new_delta;
450            }
451        }
452
453        NetworkGradients {
454            weight_gradients,
455            bias_gradients,
456        }
457    }
458
459    /// Get network parameters for serialization
460    pub fn get_parameters(&self) -> HashMap<String, Vec<f64>> {
461        let mut params = HashMap::new();
462
463        for (i, layer) in self.layers.iter().enumerate() {
464            // Flatten weights
465            let mut weights = Vec::new();
466            for neuron_weights in &layer.weights {
467                weights.extend(neuron_weights.iter());
468            }
469            params.insert(format!("layer_{}_weights", i), weights);
470            params.insert(format!("layer_{}_biases", i), layer.biases.clone());
471        }
472
473        params
474    }
475
476    /// Set network parameters from serialized data
477    pub fn set_parameters(&mut self, params: &HashMap<String, Vec<f64>>) {
478        for (i, layer) in self.layers.iter_mut().enumerate() {
479            if let Some(weights) = params.get(&format!("layer_{}_weights", i)) {
480                let mut weight_idx = 0;
481                for neuron_weights in &mut layer.weights {
482                    for weight in neuron_weights {
483                        if weight_idx < weights.len() {
484                            *weight = weights[weight_idx];
485                            weight_idx += 1;
486                        }
487                    }
488                }
489            }
490
491            if let Some(biases) = params.get(&format!("layer_{}_biases", i)) {
492                for (j, bias) in layer.biases.iter_mut().enumerate() {
493                    if j < biases.len() {
494                        *bias = biases[j];
495                    }
496                }
497            }
498        }
499    }
500}
501
502impl AttentionHead {
503    /// Create a new attention head
504    pub fn new(model_dim: usize) -> Self {
505        let head_dim = model_dim / 8; // Typical head dimension
506
507        Self {
508            query_weights: NeuralNetwork::initialize_weights(model_dim, head_dim),
509            key_weights: NeuralNetwork::initialize_weights(model_dim, head_dim),
510            value_weights: NeuralNetwork::initialize_weights(model_dim, head_dim),
511            output_weights: NeuralNetwork::initialize_weights(head_dim, model_dim),
512            head_dim,
513        }
514    }
515
516    /// Forward pass through attention head
517    pub fn forward(&self, input: &[f64]) -> Vec<f64> {
518        // Simplified attention mechanism
519        let query = self.linear_transform(input, &self.query_weights);
520        let key = self.linear_transform(input, &self.key_weights);
521        let value = self.linear_transform(input, &self.value_weights);
522
523        // Compute attention scores (simplified)
524        let attention_score = self.dot_product(&query, &key) / (self.head_dim as f64).sqrt();
525        let attention_weight = (attention_score).exp() / (1.0 + (attention_score).exp());
526
527        // Apply attention to values
528        let mut attended_value = value;
529        for val in &mut attended_value {
530            *val *= attention_weight;
531        }
532
533        // Output projection
534        self.linear_transform(&attended_value, &self.output_weights)
535    }
536
537    /// Linear transformation
538    fn linear_transform(&self, input: &[f64], weights: &[Vec<f64>]) -> Vec<f64> {
539        let mut output = vec![0.0; weights.len()];
540
541        for (i, neuron_weights) in weights.iter().enumerate() {
542            let mut sum = 0.0;
543            for (j, &input_val) in input.iter().enumerate() {
544                if j < neuron_weights.len() {
545                    sum += neuron_weights[j] * input_val;
546                }
547            }
548            output[i] = sum;
549        }
550
551        output
552    }
553
554    /// Dot product of two vectors
555    fn dot_product(&self, a: &[f64], b: &[f64]) -> f64 {
556        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
557    }
558}
559
560impl LayerNorm {
561    /// Create a new layer normalization
562    pub fn new(size: usize) -> Self {
563        Self {
564            gamma: vec![1.0; size],
565            beta: vec![0.0; size],
566            eps: 1e-5,
567        }
568    }
569
570    /// Normalize input
571    pub fn normalize(&self, input: &[f64]) -> Vec<f64> {
572        let mean = input.iter().sum::<f64>() / input.len() as f64;
573        let variance = input.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / input.len() as f64;
574        let std_dev = (variance + self.eps).sqrt();
575
576        input
577            .iter()
578            .zip(&self.gamma)
579            .zip(&self.beta)
580            .map(|((x, gamma), beta)| gamma * ((x - mean) / std_dev) + beta)
581            .collect()
582    }
583}