scirs2_sparse/neural_adaptive_sparse/neural_network.rs

//! Neural network components for sparse matrix optimization
//!
//! This module contains the neural network architectures used in the adaptive
//! sparse matrix processing system.

use rand::Rng;
use std::collections::HashMap;

/// Neural network layer for sparse matrix optimization
#[derive(Debug, Clone)]
pub(crate) struct NeuralLayer {
    pub weights: Vec<Vec<f64>>,
    pub biases: Vec<f64>,
    pub activation: ActivationFunction,
}

/// Activation functions for neural network layers
#[derive(Debug, Clone, Copy)]
pub enum ActivationFunction {
    ReLU,
    Sigmoid,
    #[allow(dead_code)]
    Tanh,
    #[allow(dead_code)]
    Swish,
    #[allow(dead_code)]
    Gelu,
}

/// Neural network for sparse matrix optimization
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub(crate) struct NeuralNetwork {
    pub layers: Vec<NeuralLayer>,
    pub attention_weights: Vec<Vec<f64>>,
    /// Multi-head attention mechanisms
    pub attention_heads: Vec<AttentionHead>,
    /// Layer normalization parameters
    pub layer_norms: Vec<LayerNorm>,
}

/// Multi-head attention mechanism
#[derive(Debug, Clone)]
pub(crate) struct AttentionHead {
    pub query_weights: Vec<Vec<f64>>,
    pub key_weights: Vec<Vec<f64>>,
    pub value_weights: Vec<Vec<f64>>,
    pub output_weights: Vec<Vec<f64>>,
    pub head_dim: usize,
}

/// Layer normalization
#[derive(Debug, Clone)]
pub(crate) struct LayerNorm {
    pub gamma: Vec<f64>,
    pub beta: Vec<f64>,
    pub eps: f64,
}

/// Forward cache for neural network computations
#[derive(Debug, Clone)]
pub(crate) struct ForwardCache {
    pub layer_outputs: Vec<Vec<f64>>,
    pub attention_outputs: Vec<Vec<f64>>,
    pub normalized_outputs: Vec<Vec<f64>>,
}

/// Network gradients for backpropagation
#[derive(Debug, Clone)]
pub(crate) struct NetworkGradients {
    pub weight_gradients: Vec<Vec<Vec<f64>>>,
    pub bias_gradients: Vec<Vec<f64>>,
}

impl NeuralNetwork {
    /// Create a new neural network with specified architecture
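    ///
    /// The first layer maps `input_size` features to `neurons_per_layer` ReLU
    /// units, `hidden_layers.saturating_sub(1)` further hidden layers follow,
    /// and a sigmoid output layer of `output_size` neurons is appended.
    ///
    /// A minimal construction/usage sketch (crate-internal API; the values
    /// below are illustrative only):
    ///
    /// ```ignore
    /// // 8 input features, 2 hidden layers of 16 neurons, 4 outputs, 2 attention heads
    /// let net = NeuralNetwork::new(8, 2, 16, 4, 2);
    /// let scores = net.forward(&[0.0; 8]);
    /// assert_eq!(scores.len(), 4);
    /// ```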
    pub fn new(
        input_size: usize,
        hidden_layers: usize,
        neurons_per_layer: usize,
        output_size: usize,
        attention_heads: usize,
    ) -> Self {
        let mut layers = Vec::new();
        let mut layer_norms = Vec::new();

        // First layer: input_size -> neurons_per_layer (counts as the first hidden layer)
        let input_layer = NeuralLayer {
            weights: Self::initialize_weights(input_size, neurons_per_layer),
            biases: vec![0.0; neurons_per_layer],
            activation: ActivationFunction::ReLU,
        };
        layers.push(input_layer);
        layer_norms.push(LayerNorm::new(neurons_per_layer));

        // Remaining hidden layers (the layer above already provides one)
        for _ in 0..hidden_layers.saturating_sub(1) {
            let layer = NeuralLayer {
                weights: Self::initialize_weights(neurons_per_layer, neurons_per_layer),
                biases: vec![0.0; neurons_per_layer],
                activation: ActivationFunction::ReLU,
            };
            layers.push(layer);
            layer_norms.push(LayerNorm::new(neurons_per_layer));
        }

        // Output layer
        let output_layer = NeuralLayer {
            weights: Self::initialize_weights(neurons_per_layer, output_size),
            biases: vec![0.0; output_size],
            activation: ActivationFunction::Sigmoid,
        };
        layers.push(output_layer);
        layer_norms.push(LayerNorm::new(output_size));

        // Initialize attention heads
        let mut attention_heads_vec = Vec::new();
        for _ in 0..attention_heads {
            attention_heads_vec.push(AttentionHead::new(neurons_per_layer));
        }

        Self {
            layers,
            attention_weights: vec![vec![1.0; neurons_per_layer]; attention_heads],
            attention_heads: attention_heads_vec,
            layer_norms,
        }
    }

    /// Initialize weights using Xavier initialization
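    ///
    /// Each weight is drawn uniformly from `[-b, b]` with
    /// `b = sqrt(6 / (input_size + output_size))` (Glorot/Xavier uniform).
    /// The result is row-major: one inner `Vec` of length `input_size` per
    /// output neuron.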
    fn initialize_weights(input_size: usize, output_size: usize) -> Vec<Vec<f64>> {
        let mut rng = rand::thread_rng();
        let bound = (6.0 / (input_size + output_size) as f64).sqrt();

        (0..output_size)
            .map(|_| {
                (0..input_size)
                    .map(|_| rng.gen_range(-bound..bound))
                    .collect()
            })
            .collect()
    }

    /// Forward pass through the network
    ///
    /// Note: unlike `forward_with_cache`, this pass applies only the linear
    /// layers, activation functions, and layer normalization; the attention
    /// heads are not used here.
    pub fn forward(&self, input: &[f64]) -> Vec<f64> {
        let mut current_input = input.to_vec();

        for (i, layer) in self.layers.iter().enumerate() {
            let mut output = vec![0.0; layer.biases.len()];

            // Linear transformation
            for (j, neuron_weights) in layer.weights.iter().enumerate() {
                let mut sum = layer.biases[j];
                for (k, &input_val) in current_input.iter().enumerate() {
                    sum += neuron_weights[k] * input_val;
                }
                output[j] = sum;
            }

            // Apply activation function
            for val in &mut output {
                *val = Self::apply_activation(*val, layer.activation);
            }

            // Apply layer normalization
            if i < self.layer_norms.len() {
                output = self.layer_norms[i].normalize(&output);
            }

            current_input = output;
        }

        current_input
    }

    /// Forward pass with caching for backpropagation
    pub fn forward_with_cache(&self, input: &[f64]) -> (Vec<f64>, ForwardCache) {
        let mut layer_outputs = Vec::new();
        let mut attention_outputs = Vec::new();
        let mut normalized_outputs = Vec::new();
        let mut current_input = input.to_vec();

        for (i, layer) in self.layers.iter().enumerate() {
            let mut output = vec![0.0; layer.biases.len()];

            // Linear transformation
            for (j, neuron_weights) in layer.weights.iter().enumerate() {
                let mut sum = layer.biases[j];
                for (k, &input_val) in current_input.iter().enumerate() {
                    sum += neuron_weights[k] * input_val;
                }
                output[j] = sum;
            }

            // Cache pre-activation outputs for backpropagation
            layer_outputs.push(output.clone());

            // Apply activation function
            for val in &mut output {
                *val = Self::apply_activation(*val, layer.activation);
            }

            // Apply attention if not the last layer
            if i < self.layers.len() - 1 && !self.attention_heads.is_empty() {
                let attention_output = self.apply_attention(&output, i);
                attention_outputs.push(attention_output.clone());
                output = attention_output;
            }

            // Apply layer normalization
            if i < self.layer_norms.len() {
                output = self.layer_norms[i].normalize(&output);
                normalized_outputs.push(output.clone());
            }

            current_input = output;
        }

        let cache = ForwardCache {
            layer_outputs,
            attention_outputs,
            normalized_outputs,
        };

        (current_input, cache)
    }

    /// Apply the multi-head attention mechanism: head outputs are averaged
    /// and a residual connection is added (the layer index is currently unused)
    fn apply_attention(&self, input: &[f64], _layer_idx: usize) -> Vec<f64> {
        if self.attention_heads.is_empty() {
            return input.to_vec();
        }

        let mut attention_output = vec![0.0; input.len()];
        let num_heads = self.attention_heads.len();

        // Average the outputs of all attention heads
        for head in &self.attention_heads {
            let head_output = head.forward(input);
            for (i, &val) in head_output.iter().enumerate() {
                if i < attention_output.len() {
                    attention_output[i] += val / num_heads as f64;
                }
            }
        }

        // Add residual connection
        for (i, &input_val) in input.iter().enumerate() {
            if i < attention_output.len() {
                attention_output[i] += input_val;
            }
        }

        attention_output
    }

    /// Apply activation function
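    ///
    /// - `ReLU(x) = max(0, x)`
    /// - `Sigmoid(x) = 1 / (1 + e^(-x))`
    /// - `Tanh(x) = tanh(x)`
    /// - `Swish(x) = x * Sigmoid(x)`
    /// - `Gelu(x)` uses the tanh-based approximation of the Gaussian error
    ///   linear unit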
    fn apply_activation(x: f64, activation: ActivationFunction) -> f64 {
        match activation {
            ActivationFunction::ReLU => x.max(0.0),
            ActivationFunction::Sigmoid => 1.0 / (1.0 + (-x).exp()),
            ActivationFunction::Tanh => x.tanh(),
            ActivationFunction::Swish => x * (1.0 / (1.0 + (-x).exp())),
            // Tanh approximation of GELU; 0.7978845608028654 = sqrt(2 / pi)
            ActivationFunction::Gelu => {
                0.5 * x * (1.0 + (0.7978845608028654 * (x + 0.044715 * x.powi(3))).tanh())
            }
        }
    }

    /// Update weights using gradients
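    ///
    /// Performs a plain gradient-descent step on every weight and bias,
    /// `w <- w - learning_rate * grad`, skipping entries for which no
    /// gradient is provided.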
    pub fn update_weights(&mut self, gradients: &NetworkGradients, learning_rate: f64) {
        for (layer_idx, layer) in self.layers.iter_mut().enumerate() {
            if layer_idx < gradients.weight_gradients.len() {
                let layer_weight_grads = &gradients.weight_gradients[layer_idx];
                for (neuron_idx, neuron_weights) in layer.weights.iter_mut().enumerate() {
                    if neuron_idx < layer_weight_grads.len() {
                        let neuron_grads = &layer_weight_grads[neuron_idx];
                        for (weight_idx, weight) in neuron_weights.iter_mut().enumerate() {
                            if weight_idx < neuron_grads.len() {
                                *weight -= learning_rate * neuron_grads[weight_idx];
                            }
                        }
                    }
                }
            }

            if layer_idx < gradients.bias_gradients.len() {
                let bias_grads = &gradients.bias_gradients[layer_idx];
                for (bias_idx, bias) in layer.biases.iter_mut().enumerate() {
                    if bias_idx < bias_grads.len() {
                        *bias -= learning_rate * bias_grads[bias_idx];
                    }
                }
            }
        }
    }

    /// Compute network gradients
    ///
    /// Placeholder implementation: every parameter receives a small constant
    /// gradient. A full implementation would backpropagate the error between
    /// `target` and the cached forward pass, which is why the currently
    /// unused arguments are kept in the signature.
    pub fn compute_gradients(
        &self,
        _input: &[f64],
        _target: &[f64],
        _cache: &ForwardCache,
    ) -> NetworkGradients {
        let mut weight_gradients = Vec::new();
        let mut bias_gradients = Vec::new();

        // Simplified gradient computation (a real implementation would be more complex)
        for layer in &self.layers {
            let mut layer_weight_grads = Vec::new();
            let mut layer_bias_grads = Vec::new();

            for neuron_weights in &layer.weights {
                // Placeholder gradients
                layer_weight_grads.push(vec![0.001; neuron_weights.len()]);
                layer_bias_grads.push(0.001);
            }

            weight_gradients.push(layer_weight_grads);
            bias_gradients.push(layer_bias_grads);
        }

        NetworkGradients {
            weight_gradients,
            bias_gradients,
        }
    }

    /// Get network parameters for serialization
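    ///
    /// Weights are flattened row by row under the key `layer_{i}_weights`
    /// (one entry per layer) and biases are stored under `layer_{i}_biases`,
    /// matching the layout expected by `set_parameters`.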
    pub fn get_parameters(&self) -> HashMap<String, Vec<f64>> {
        let mut params = HashMap::new();

        for (i, layer) in self.layers.iter().enumerate() {
            // Flatten weights
            let mut weights = Vec::new();
            for neuron_weights in &layer.weights {
                weights.extend(neuron_weights.iter());
            }
            params.insert(format!("layer_{}_weights", i), weights);
            params.insert(format!("layer_{}_biases", i), layer.biases.clone());
        }

        params
    }

    /// Set network parameters from serialized data
    pub fn set_parameters(&mut self, params: &HashMap<String, Vec<f64>>) {
        for (i, layer) in self.layers.iter_mut().enumerate() {
            if let Some(weights) = params.get(&format!("layer_{}_weights", i)) {
                let mut weight_idx = 0;
                for neuron_weights in &mut layer.weights {
                    for weight in neuron_weights {
                        if weight_idx < weights.len() {
                            *weight = weights[weight_idx];
                            weight_idx += 1;
                        }
                    }
                }
            }

            if let Some(biases) = params.get(&format!("layer_{}_biases", i)) {
                for (j, bias) in layer.biases.iter_mut().enumerate() {
                    if j < biases.len() {
                        *bias = biases[j];
                    }
                }
            }
        }
    }
}

impl AttentionHead {
    /// Create a new attention head
    pub fn new(model_dim: usize) -> Self {
        // Typical transformer-style head dimension; clamped to at least 1 so
        // small model dimensions do not produce an empty (zero-sized) head
        let head_dim = (model_dim / 8).max(1);

        Self {
            query_weights: NeuralNetwork::initialize_weights(model_dim, head_dim),
            key_weights: NeuralNetwork::initialize_weights(model_dim, head_dim),
            value_weights: NeuralNetwork::initialize_weights(model_dim, head_dim),
            output_weights: NeuralNetwork::initialize_weights(head_dim, model_dim),
            head_dim,
        }
    }

    /// Forward pass through attention head
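    ///
    /// Simplified single-vector attention: a scalar score
    /// `s = (q . k) / sqrt(head_dim)` is squashed through a sigmoid and used
    /// to scale the value vector before the output projection. This is a
    /// sketch of scaled dot-product attention, not the full
    /// softmax-over-sequence form.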
    pub fn forward(&self, input: &[f64]) -> Vec<f64> {
        // Simplified attention mechanism
        let query = self.linear_transform(input, &self.query_weights);
        let key = self.linear_transform(input, &self.key_weights);
        let value = self.linear_transform(input, &self.value_weights);

        // Scaled attention score, squashed with a numerically stable sigmoid
        // (equivalent to exp(score) / (1 + exp(score)))
        let attention_score = self.dot_product(&query, &key) / (self.head_dim as f64).sqrt();
        let attention_weight = 1.0 / (1.0 + (-attention_score).exp());

        // Apply attention to values
        let mut attended_value = value;
        for val in &mut attended_value {
            *val *= attention_weight;
        }

        // Output projection
        self.linear_transform(&attended_value, &self.output_weights)
    }

    /// Linear transformation
    fn linear_transform(&self, input: &[f64], weights: &[Vec<f64>]) -> Vec<f64> {
        let mut output = vec![0.0; weights.len()];

        for (i, neuron_weights) in weights.iter().enumerate() {
            let mut sum = 0.0;
            for (j, &input_val) in input.iter().enumerate() {
                if j < neuron_weights.len() {
                    sum += neuron_weights[j] * input_val;
                }
            }
            output[i] = sum;
        }

        output
    }

    /// Dot product of two vectors
    fn dot_product(&self, a: &[f64], b: &[f64]) -> f64 {
        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
    }
}

impl LayerNorm {
    /// Create a new layer normalization
    pub fn new(size: usize) -> Self {
        Self {
            gamma: vec![1.0; size],
            beta: vec![0.0; size],
            eps: 1e-5,
        }
    }

    /// Normalize input
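    ///
    /// Standard layer normalization over the feature dimension:
    /// `y_i = gamma_i * (x_i - mean) / sqrt(var + eps) + beta_i`,
    /// where `mean` and `var` are computed over the entries of `input`.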
    pub fn normalize(&self, input: &[f64]) -> Vec<f64> {
        let mean = input.iter().sum::<f64>() / input.len() as f64;
        let variance = input.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / input.len() as f64;
        let std_dev = (variance + self.eps).sqrt();

        input
            .iter()
            .zip(&self.gamma)
            .zip(&self.beta)
            .map(|((x, gamma), beta)| gamma * ((x - mean) / std_dev) + beta)
            .collect()
    }
}
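
// Minimal test sketch for the components above. These checks (output shapes,
// activation values, layer-norm statistics) are illustrative additions rather
// than the crate's own test suite; the test names and tolerances are chosen
// here, not taken from the original code.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn forward_produces_output_of_expected_size() {
        // 4 inputs, 2 hidden layers of 8 neurons, 3 outputs, 2 attention heads
        let net = NeuralNetwork::new(4, 2, 8, 3, 2);
        let out = net.forward(&[0.1, 0.2, 0.3, 0.4]);
        assert_eq!(out.len(), 3);

        let (out_cached, cache) = net.forward_with_cache(&[0.1, 0.2, 0.3, 0.4]);
        assert_eq!(out_cached.len(), 3);
        // One cached pre-activation vector per layer
        assert_eq!(cache.layer_outputs.len(), net.layers.len());
    }

    #[test]
    fn activations_match_their_definitions() {
        assert_eq!(
            NeuralNetwork::apply_activation(-1.0, ActivationFunction::ReLU),
            0.0
        );
        let sigmoid_zero = NeuralNetwork::apply_activation(0.0, ActivationFunction::Sigmoid);
        assert!((sigmoid_zero - 0.5).abs() < 1e-12);
    }

    #[test]
    fn layer_norm_centers_the_output() {
        let norm = LayerNorm::new(4);
        let out = norm.normalize(&[1.0, 2.0, 3.0, 4.0]);
        let mean: f64 = out.iter().sum::<f64>() / out.len() as f64;
        // With gamma = 1 and beta = 0 the normalized output is zero-mean
        assert!(mean.abs() < 1e-9);
    }
}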