scirs2_sparse/neural_adaptive_sparse/neural_network.rs

use rand::Rng;
use std::collections::HashMap;

#[derive(Debug, Clone)]
pub(crate) struct NeuralLayer {
    pub weights: Vec<Vec<f64>>,
    pub biases: Vec<f64>,
    pub activation: ActivationFunction,
}

#[derive(Debug, Clone, Copy)]
pub enum ActivationFunction {
    ReLU,
    Sigmoid,
    #[allow(dead_code)]
    Tanh,
    #[allow(dead_code)]
    Swish,
    #[allow(dead_code)]
    Gelu,
}

#[derive(Debug, Clone)]
#[allow(dead_code)]
pub(crate) struct NeuralNetwork {
    pub layers: Vec<NeuralLayer>,
    pub attention_weights: Vec<Vec<f64>>,
    pub attention_heads: Vec<AttentionHead>,
    pub layer_norms: Vec<LayerNorm>,
}

#[derive(Debug, Clone)]
pub(crate) struct AttentionHead {
    pub query_weights: Vec<Vec<f64>>,
    pub key_weights: Vec<Vec<f64>>,
    pub value_weights: Vec<Vec<f64>>,
    pub output_weights: Vec<Vec<f64>>,
    pub head_dim: usize,
}

#[derive(Debug, Clone)]
pub(crate) struct LayerNorm {
    pub gamma: Vec<f64>,
    pub beta: Vec<f64>,
    pub eps: f64,
}

#[derive(Debug, Clone)]
pub(crate) struct ForwardCache {
    /// Pre-activation sums for each layer.
    pub layer_outputs: Vec<Vec<f64>>,
    /// Post-attention activations for each hidden layer.
    pub attention_outputs: Vec<Vec<f64>>,
    /// Activations after layer normalization.
    pub normalized_outputs: Vec<Vec<f64>>,
}

#[derive(Debug, Clone)]
pub(crate) struct NetworkGradients {
    /// Indexed as [layer][neuron][weight].
    pub weight_gradients: Vec<Vec<Vec<f64>>>,
    /// Indexed as [layer][neuron].
    pub bias_gradients: Vec<Vec<f64>>,
}

impl NeuralNetwork {
    pub fn new(
        input_size: usize,
        hidden_layers: usize,
        neurons_per_layer: usize,
        output_size: usize,
        attention_heads: usize,
    ) -> Self {
        let mut layers = Vec::new();
        let mut layer_norms = Vec::new();

        // Input layer: input_size -> neurons_per_layer, ReLU.
        let input_layer = NeuralLayer {
            weights: Self::initialize_weights(input_size, neurons_per_layer),
            biases: vec![0.0; neurons_per_layer],
            activation: ActivationFunction::ReLU,
        };
        layers.push(input_layer);
        layer_norms.push(LayerNorm::new(neurons_per_layer));

        // Remaining hidden layers; the input layer already counts as the first.
        for _ in 0..hidden_layers.saturating_sub(1) {
            let layer = NeuralLayer {
                weights: Self::initialize_weights(neurons_per_layer, neurons_per_layer),
                biases: vec![0.0; neurons_per_layer],
                activation: ActivationFunction::ReLU,
            };
            layers.push(layer);
            layer_norms.push(LayerNorm::new(neurons_per_layer));
        }

        // Output layer: neurons_per_layer -> output_size, Sigmoid.
        let output_layer = NeuralLayer {
            weights: Self::initialize_weights(neurons_per_layer, output_size),
            biases: vec![0.0; output_size],
            activation: ActivationFunction::Sigmoid,
        };
        layers.push(output_layer);
        layer_norms.push(LayerNorm::new(output_size));

        let mut attention_heads_vec = Vec::new();
        for _ in 0..attention_heads {
            attention_heads_vec.push(AttentionHead::new(neurons_per_layer));
        }

        Self {
            layers,
            attention_weights: vec![vec![1.0; neurons_per_layer]; attention_heads],
            attention_heads: attention_heads_vec,
            layer_norms,
        }
    }

    /// Xavier/Glorot uniform initialization: bound = sqrt(6 / (fan_in + fan_out)).
    fn initialize_weights(input_size: usize, output_size: usize) -> Vec<Vec<f64>> {
        let mut rng = rand::thread_rng();
        let bound = (6.0 / (input_size + output_size) as f64).sqrt();

        // One row of `input_size` weights per output neuron.
        (0..output_size)
            .map(|_| {
                (0..input_size)
                    .map(|_| rng.gen_range(-bound..bound))
                    .collect()
            })
            .collect()
    }

    /// Plain forward pass with no attention and no caching; every layer's
    /// output, including the final sigmoid layer, is layer-normalized before
    /// being fed onward. See `forward_with_cache` for the attention path.
    pub fn forward(&self, input: &[f64]) -> Vec<f64> {
        let mut current_input = input.to_vec();

        for (i, layer) in self.layers.iter().enumerate() {
            let mut output = vec![0.0; layer.biases.len()];

            // Affine transform: output[j] = bias[j] + sum_k weights[j][k] * input[k].
            for (j, neuron_weights) in layer.weights.iter().enumerate() {
                let mut sum = layer.biases[j];
                for (k, &input_val) in current_input.iter().enumerate() {
                    sum += neuron_weights[k] * input_val;
                }
                output[j] = sum;
            }

            for val in &mut output {
                *val = Self::apply_activation(*val, layer.activation);
            }

            if i < self.layer_norms.len() {
                output = self.layer_norms[i].normalize(&output);
            }

            current_input = output;
        }

        current_input
    }

    /// Forward pass that also applies the attention heads between hidden
    /// layers and records intermediate values for gradient computation.
    pub fn forward_with_cache(&self, input: &[f64]) -> (Vec<f64>, ForwardCache) {
        let mut layer_outputs = Vec::new();
        let mut attention_outputs = Vec::new();
        let mut normalized_outputs = Vec::new();
        let mut current_input = input.to_vec();

        for (i, layer) in self.layers.iter().enumerate() {
            let mut output = vec![0.0; layer.biases.len()];

            for (j, neuron_weights) in layer.weights.iter().enumerate() {
                let mut sum = layer.biases[j];
                for (k, &input_val) in current_input.iter().enumerate() {
                    sum += neuron_weights[k] * input_val;
                }
                output[j] = sum;
            }

            // Cache the pre-activation sums.
            layer_outputs.push(output.clone());

            for val in &mut output {
                *val = Self::apply_activation(*val, layer.activation);
            }

            // Attention is applied after every layer except the output layer.
            if i < self.layers.len() - 1 && !self.attention_heads.is_empty() {
                let attention_output = self.apply_attention(&output, i);
                attention_outputs.push(attention_output.clone());
                output = attention_output;
            }

            if i < self.layer_norms.len() {
                output = self.layer_norms[i].normalize(&output);
                normalized_outputs.push(output.clone());
            }

            current_input = output;
        }

        let cache = ForwardCache {
            layer_outputs,
            attention_outputs,
            normalized_outputs,
        };

        (current_input, cache)
    }

    fn apply_attention(&self, input: &[f64], _layer_idx: usize) -> Vec<f64> {
        if self.attention_heads.is_empty() {
            return input.to_vec();
        }

        // Average the per-head outputs.
        let mut attention_output = vec![0.0; input.len()];
        let num_heads = self.attention_heads.len();

        for head in &self.attention_heads {
            let head_output = head.forward(input);
            for (i, &val) in head_output.iter().enumerate() {
                if i < attention_output.len() {
                    attention_output[i] += val / num_heads as f64;
                }
            }
        }

        // Residual connection: add the input back onto the attended output.
        for (i, &input_val) in input.iter().enumerate() {
            if i < attention_output.len() {
                attention_output[i] += input_val;
            }
        }

        attention_output
    }

    fn apply_activation(x: f64, activation: ActivationFunction) -> f64 {
        match activation {
            ActivationFunction::ReLU => x.max(0.0),
            ActivationFunction::Sigmoid => 1.0 / (1.0 + (-x).exp()),
            ActivationFunction::Tanh => x.tanh(),
            // Swish: x * sigmoid(x).
            ActivationFunction::Swish => x * (1.0 / (1.0 + (-x).exp())),
            // Tanh approximation of GELU:
            // 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
            ActivationFunction::Gelu => {
                0.5 * x * (1.0 + (0.7978845608028654 * (x + 0.044715 * x.powi(3))).tanh())
            }
        }
    }

    /// Vanilla SGD step: w -= lr * dw, b -= lr * db. Gradient tensors may be
    /// ragged, so every index is bounds-checked before use.
    pub fn update_weights(&mut self, gradients: &NetworkGradients, learning_rate: f64) {
        for (layer_idx, layer) in self.layers.iter_mut().enumerate() {
            if layer_idx < gradients.weight_gradients.len() {
                let layer_weight_grads = &gradients.weight_gradients[layer_idx];
                for (neuron_idx, neuron_weights) in layer.weights.iter_mut().enumerate() {
                    if neuron_idx < layer_weight_grads.len() {
                        let neuron_grads = &layer_weight_grads[neuron_idx];
                        for (weight_idx, weight) in neuron_weights.iter_mut().enumerate() {
                            if weight_idx < neuron_grads.len() {
                                *weight -= learning_rate * neuron_grads[weight_idx];
                            }
                        }
                    }
                }
            }

            if layer_idx < gradients.bias_gradients.len() {
                let bias_grads = &gradients.bias_gradients[layer_idx];
                for (bias_idx, bias) in layer.biases.iter_mut().enumerate() {
                    if bias_idx < bias_grads.len() {
                        *bias -= learning_rate * bias_grads[bias_idx];
                    }
                }
            }
        }
    }

    /// Placeholder gradient computation: true backpropagation is not
    /// implemented here, so every weight and bias receives a small constant
    /// gradient. The input, target, and cache are kept in the signature for
    /// a future backprop implementation.
    pub fn compute_gradients(
        &self,
        _input: &[f64],
        _target: &[f64],
        _cache: &ForwardCache,
    ) -> NetworkGradients {
        let mut weight_gradients = Vec::new();
        let mut bias_gradients = Vec::new();

        for layer in &self.layers {
            let mut layer_weight_grads = Vec::new();
            let mut layer_bias_grads = Vec::new();

            for neuron_weights in &layer.weights {
                layer_weight_grads.push(vec![0.001; neuron_weights.len()]);
                layer_bias_grads.push(0.001);
            }

            weight_gradients.push(layer_weight_grads);
            bias_gradients.push(layer_bias_grads);
        }

        NetworkGradients {
            weight_gradients,
            bias_gradients,
        }
    }

    /// Flatten all weights and biases into a name -> values map.
    pub fn get_parameters(&self) -> HashMap<String, Vec<f64>> {
        let mut params = HashMap::new();

        for (i, layer) in self.layers.iter().enumerate() {
            let mut weights = Vec::new();
            for neuron_weights in &layer.weights {
                weights.extend(neuron_weights.iter());
            }
            params.insert(format!("layer_{}_weights", i), weights);
            params.insert(format!("layer_{}_biases", i), layer.biases.clone());
        }

        params
    }

    /// Restore weights and biases from a map produced by `get_parameters`.
    pub fn set_parameters(&mut self, params: &HashMap<String, Vec<f64>>) {
        for (i, layer) in self.layers.iter_mut().enumerate() {
            if let Some(weights) = params.get(&format!("layer_{}_weights", i)) {
                let mut weight_idx = 0;
                for neuron_weights in &mut layer.weights {
                    for weight in neuron_weights {
                        if weight_idx < weights.len() {
                            *weight = weights[weight_idx];
                            weight_idx += 1;
                        }
                    }
                }
            }

            if let Some(biases) = params.get(&format!("layer_{}_biases", i)) {
                for (j, bias) in layer.biases.iter_mut().enumerate() {
                    if j < biases.len() {
                        *bias = biases[j];
                    }
                }
            }
        }
    }
}
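
// A minimal usage sketch (added as an illustration, not part of the original
// module): build a small network and sanity-check the shape of a forward
// pass. The sizes below are arbitrary.
#[cfg(test)]
mod network_tests {
    use super::*;

    #[test]
    fn forward_produces_output_of_expected_size() {
        // 4 inputs, 2 hidden layers of 16 neurons, 3 outputs, 2 attention heads.
        let net = NeuralNetwork::new(4, 2, 16, 3, 2);
        let output = net.forward(&[0.1, 0.2, 0.3, 0.4]);
        assert_eq!(output.len(), 3);
        assert!(output.iter().all(|v| v.is_finite()));

        // The cached variant also applies attention between hidden layers
        // and records one pre-activation vector per layer.
        let (cached_output, cache) = net.forward_with_cache(&[0.1, 0.2, 0.3, 0.4]);
        assert_eq!(cached_output.len(), 3);
        assert_eq!(cache.layer_outputs.len(), 3);
    }
}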

impl AttentionHead {
    pub fn new(model_dim: usize) -> Self {
        // Fixed down-projection factor of 8, clamped so head_dim is never
        // zero for small model dimensions.
        let head_dim = (model_dim / 8).max(1);

        Self {
            query_weights: NeuralNetwork::initialize_weights(model_dim, head_dim),
            key_weights: NeuralNetwork::initialize_weights(model_dim, head_dim),
            value_weights: NeuralNetwork::initialize_weights(model_dim, head_dim),
            output_weights: NeuralNetwork::initialize_weights(head_dim, model_dim),
            head_dim,
        }
    }

    /// Scaled dot-product attention over a single vector: with one "token",
    /// the attention matrix collapses to a single scalar gate on the value.
    pub fn forward(&self, input: &[f64]) -> Vec<f64> {
        let query = self.linear_transform(input, &self.query_weights);
        let key = self.linear_transform(input, &self.key_weights);
        let value = self.linear_transform(input, &self.value_weights);

        // Scaled score squashed through a numerically stable sigmoid
        // (1 / (1 + e^-x) avoids overflow for large positive scores).
        let attention_score = self.dot_product(&query, &key) / (self.head_dim as f64).sqrt();
        let attention_weight = 1.0 / (1.0 + (-attention_score).exp());

        let mut attended_value = value;
        for val in &mut attended_value {
            *val *= attention_weight;
        }

        self.linear_transform(&attended_value, &self.output_weights)
    }

    fn linear_transform(&self, input: &[f64], weights: &[Vec<f64>]) -> Vec<f64> {
        let mut output = vec![0.0; weights.len()];

        for (i, neuron_weights) in weights.iter().enumerate() {
            let mut sum = 0.0;
            for (j, &input_val) in input.iter().enumerate() {
                if j < neuron_weights.len() {
                    sum += neuron_weights[j] * input_val;
                }
            }
            output[i] = sum;
        }

        output
    }

    fn dot_product(&self, a: &[f64], b: &[f64]) -> f64 {
        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
    }
}
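
// A minimal shape check (an illustration, not part of the original module):
// the head projects the input down to head_dim and back up, so the output
// length should match the input length.
#[cfg(test)]
mod attention_head_tests {
    use super::*;

    #[test]
    fn forward_preserves_input_dimension() {
        let head = AttentionHead::new(16); // head_dim = 16 / 8 = 2
        let output = head.forward(&[0.5; 16]);
        assert_eq!(output.len(), 16);
        assert!(output.iter().all(|v| v.is_finite()));
    }
}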

impl LayerNorm {
    pub fn new(size: usize) -> Self {
        Self {
            gamma: vec![1.0; size],
            beta: vec![0.0; size],
            eps: 1e-5,
        }
    }

    /// Standard layer normalization:
    /// y = gamma * (x - mean) / sqrt(variance + eps) + beta.
    pub fn normalize(&self, input: &[f64]) -> Vec<f64> {
        let mean = input.iter().sum::<f64>() / input.len() as f64;
        let variance = input.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / input.len() as f64;
        let std_dev = (variance + self.eps).sqrt();

        input
            .iter()
            .zip(&self.gamma)
            .zip(&self.beta)
            .map(|((x, gamma), beta)| gamma * ((x - mean) / std_dev) + beta)
            .collect()
    }
}
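
// A minimal numeric check (an illustration, not part of the original module):
// with unit gamma and zero beta, the normalized output should have roughly
// zero mean and unit variance.
#[cfg(test)]
mod layer_norm_tests {
    use super::*;

    #[test]
    fn normalize_centers_and_scales() {
        let ln = LayerNorm::new(4);
        let out = ln.normalize(&[1.0, 2.0, 3.0, 4.0]);
        let mean = out.iter().sum::<f64>() / 4.0;
        let var = out.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / 4.0;
        assert!(mean.abs() < 1e-9);
        assert!((var - 1.0).abs() < 1e-3); // slightly below 1.0 because of eps
    }
}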