scirs2_series/advanced_training_modules/
memory_augmented.rs

1//! Memory-Augmented Neural Networks
2//!
3//! This module implements Memory-Augmented Neural Networks (MANNs) for few-shot learning
4//! and other tasks that require external memory capabilities.
5
6use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::numeric::{Float, FromPrimitive};
8use std::fmt::Debug;
9
10use super::few_shot::FewShotEpisode;
11use crate::error::Result;
12
13/// Memory-Augmented Neural Network (MANN)
14#[derive(Debug)]
15pub struct MANN<F: Float + Debug + scirs2_core::ndarray::ScalarOperand> {
16    /// Controller network parameters
17    controller_params: Array2<F>,
18    /// External memory matrix
19    memory: Array2<F>,
20    /// Memory dimensions
21    memory_size: usize,
22    memory_width: usize,
23    /// Controller dimensions
24    controller_input_dim: usize,
25    controller_hidden_dim: usize,
26    controller_output_dim: usize,
27    /// Read/write head parameters
28    #[allow(dead_code)]
29    read_head_params: Array2<F>,
30    #[allow(dead_code)]
31    write_head_params: Array2<F>,
32}
33
34impl<F: Float + Debug + Clone + FromPrimitive + scirs2_core::ndarray::ScalarOperand> MANN<F> {
35    /// Create new Memory-Augmented Neural Network
36    pub fn new(
37        memory_size: usize,
38        memory_width: usize,
39        controller_input_dim: usize,
40        controller_hidden_dim: usize,
41        controller_output_dim: usize,
42    ) -> Self {
43        // Initialize controller parameters
44        let controller_param_count = controller_input_dim * controller_hidden_dim
45            + controller_hidden_dim
46            + controller_hidden_dim * controller_output_dim
47            + controller_output_dim;
48
49        let mut controller_params = Array2::zeros((1, controller_param_count));
50        let scale =
51            F::from(2.0).unwrap() / F::from(controller_input_dim + controller_output_dim).unwrap();
52        let std_dev = scale.sqrt();
53
54        for i in 0..controller_param_count {
55            let val = ((i * 67) % 1000) as f64 / 1000.0 - 0.5;
56            controller_params[[0, i]] = F::from(val).unwrap() * std_dev;
57        }
58
59        // Initialize memory
60        let memory = Array2::zeros((memory_size, memory_width));
61
62        // Initialize read/write head parameters
63        let head_param_count = memory_width * 2 + 3; // key, beta, gate, shift, gamma
64        let mut read_head_params = Array2::zeros((1, head_param_count));
65        let mut write_head_params = Array2::zeros((1, head_param_count));
66
67        for i in 0..head_param_count {
68            let val1 = ((i * 71) % 1000) as f64 / 1000.0 - 0.5;
69            let val2 = ((i * 73) % 1000) as f64 / 1000.0 - 0.5;
70            read_head_params[[0, i]] = F::from(val1).unwrap() * F::from(0.1).unwrap();
71            write_head_params[[0, i]] = F::from(val2).unwrap() * F::from(0.1).unwrap();
72        }
73
74        Self {
75            controller_params,
76            memory,
77            memory_size,
78            memory_width,
79            controller_input_dim,
80            controller_hidden_dim,
81            controller_output_dim,
82            read_head_params,
83            write_head_params,
84        }
85    }
86
87    /// Forward pass through MANN
88    pub fn forward(&mut self, input: &Array1<F>) -> Result<Array1<F>> {
89        // Read from memory
90        let read_vector = self.memory_read()?;
91
92        // Combine input with read vector
93        let mut controller_input = Array1::zeros(self.controller_input_dim);
94        for i in 0..input.len().min(self.controller_input_dim) {
95            controller_input[i] = input[i];
96        }
97
98        // Add read vector to controller input
99        let read_start = input.len().min(self.controller_input_dim);
100        for i in 0..read_vector.len() {
101            if read_start + i < self.controller_input_dim {
102                controller_input[read_start + i] = read_vector[i];
103            }
104        }
105
106        // Controller forward pass
107        let controller_output = self.controller_forward(&controller_input)?;
108
109        // Write to memory
110        self.memory_write(&controller_output)?;
111
112        Ok(controller_output)
113    }
114
115    /// Controller neural network forward pass
116    fn controller_forward(&self, input: &Array1<F>) -> Result<Array1<F>> {
117        let (w1, b1, w2, b2) = self.extract_controller_weights();
118
119        // Hidden layer
120        let mut hidden = Array1::zeros(self.controller_hidden_dim);
121        for i in 0..self.controller_hidden_dim {
122            let mut sum = b1[i];
123            for j in 0..input.len().min(w1.ncols()) {
124                sum = sum + input[j] * w1[[i, j]];
125            }
126            hidden[i] = self.tanh(sum);
127        }
128
129        // Output layer
130        let mut output = Array1::zeros(self.controller_output_dim);
131        for i in 0..self.controller_output_dim {
132            let mut sum = b2[i];
133            for j in 0..self.controller_hidden_dim {
134                sum = sum + hidden[j] * w2[[i, j]];
135            }
136            output[i] = sum;
137        }
138
139        Ok(output)
140    }
141
142    /// Read from external memory
143    fn memory_read(&self) -> Result<Array1<F>> {
144        // Simplified memory read - return average of memory rows
145        let mut read_vector = Array1::zeros(self.memory_width);
146
147        for i in 0..self.memory_size {
148            for j in 0..self.memory_width {
149                read_vector[j] = read_vector[j] + self.memory[[i, j]];
150            }
151        }
152
153        let size = F::from(self.memory_size).unwrap();
154        for j in 0..self.memory_width {
155            read_vector[j] = read_vector[j] / size;
156        }
157
158        Ok(read_vector)
159    }
160
161    /// Write to external memory
162    fn memory_write(&mut self, controller_output: &Array1<F>) -> Result<()> {
163        // Simplified memory write - update first row with controller _output
164        for i in 0..controller_output.len().min(self.memory_width) {
165            self.memory[[0, i]] = controller_output[i];
166        }
167
168        Ok(())
169    }
170
171    /// Extract controller weights from parameters
172    fn extract_controller_weights(&self) -> (Array2<F>, Array1<F>, Array2<F>, Array1<F>) {
173        let param_vec = self.controller_params.row(0);
174        let mut idx = 0;
175
176        // W1: controller_input_dim x controller_hidden_dim
177        let mut w1 = Array2::zeros((self.controller_hidden_dim, self.controller_input_dim));
178        for i in 0..self.controller_hidden_dim {
179            for j in 0..self.controller_input_dim {
180                if idx < param_vec.len() {
181                    w1[[i, j]] = param_vec[idx];
182                    idx += 1;
183                }
184            }
185        }
186
187        // b1: controller_hidden_dim
188        let mut b1 = Array1::zeros(self.controller_hidden_dim);
189        for i in 0..self.controller_hidden_dim {
190            if idx < param_vec.len() {
191                b1[i] = param_vec[idx];
192                idx += 1;
193            }
194        }
195
196        // W2: controller_hidden_dim x controller_output_dim
197        let mut w2 = Array2::zeros((self.controller_output_dim, self.controller_hidden_dim));
198        for i in 0..self.controller_output_dim {
199            for j in 0..self.controller_hidden_dim {
200                if idx < param_vec.len() {
201                    w2[[i, j]] = param_vec[idx];
202                    idx += 1;
203                }
204            }
205        }
206
207        // b2: controller_output_dim
208        let mut b2 = Array1::zeros(self.controller_output_dim);
209        for i in 0..self.controller_output_dim {
210            if idx < param_vec.len() {
211                b2[i] = param_vec[idx];
212                idx += 1;
213            }
214        }
215
216        (w1, b1, w2, b2)
217    }
218
219    /// Reset memory
220    pub fn reset_memory(&mut self) {
221        self.memory = Array2::zeros((self.memory_size, self.memory_width));
222    }
223
224    /// Train MANN on few-shot learning task
225    pub fn train_few_shot(&mut self, episodes: &[FewShotEpisode<F>]) -> Result<F> {
226        let mut total_loss = F::zero();
227
228        for episode in episodes {
229            self.reset_memory();
230
231            // Present support set
232            for i in 0..episode.support_x.nrows() {
233                let input_row = episode.support_x.row(i).to_owned();
234                let _output = self.forward(&input_row)?;
235            }
236
237            // Test on query set
238            let mut episode_loss = F::zero();
239            for i in 0..episode.query_x.nrows() {
240                let input_row = episode.query_x.row(i).to_owned();
241                let prediction = self.forward(&input_row)?;
242
243                // Compute loss (simplified)
244                if i < episode.query_y.len() {
245                    let target = F::from(episode.query_y[i]).unwrap();
246                    if !prediction.is_empty() {
247                        let diff = prediction[0] - target;
248                        episode_loss = episode_loss + diff * diff;
249                    }
250                }
251            }
252
253            total_loss = total_loss + episode_loss;
254        }
255
256        Ok(total_loss / F::from(episodes.len()).unwrap())
257    }
258
259    /// Get current memory state
260    pub fn get_memory(&self) -> &Array2<F> {
261        &self.memory
262    }
263
264    /// Set memory state
265    pub fn set_memory(&mut self, memory: Array2<F>) -> Result<()> {
266        if memory.dim() != (self.memory_size, self.memory_width) {
267            return Err(crate::error::TimeSeriesError::InvalidOperation(
268                "Memory dimensions do not match".to_string(),
269            ));
270        }
271        self.memory = memory;
272        Ok(())
273    }
274
275    /// Get controller parameters
276    pub fn get_controller_params(&self) -> &Array2<F> {
277        &self.controller_params
278    }
279
280    /// Set controller parameters
281    pub fn set_controller_params(&mut self, params: Array2<F>) -> Result<()> {
282        if params.dim() != self.controller_params.dim() {
283            return Err(crate::error::TimeSeriesError::InvalidOperation(
284                "Controller parameter dimensions do not match".to_string(),
285            ));
286        }
287        self.controller_params = params;
288        Ok(())
289    }
290
291    /// Get memory dimensions
292    pub fn memory_dimensions(&self) -> (usize, usize) {
293        (self.memory_size, self.memory_width)
294    }
295
296    /// Get controller dimensions
297    pub fn controller_dimensions(&self) -> (usize, usize, usize) {
298        (
299            self.controller_input_dim,
300            self.controller_hidden_dim,
301            self.controller_output_dim,
302        )
303    }
304
305    /// Process a sequence of inputs
306    pub fn process_sequence(&mut self, inputs: &[Array1<F>]) -> Result<Vec<Array1<F>>> {
307        let mut outputs = Vec::new();
308
309        for input in inputs {
310            let output = self.forward(input)?;
311            outputs.push(output);
312        }
313
314        Ok(outputs)
315    }
316
317    /// Compute attention weights for memory addressing (simplified)
318    pub fn compute_attention_weights(&self, key: &Array1<F>) -> Result<Array1<F>> {
319        let mut weights = Array1::zeros(self.memory_size);
320
321        for i in 0..self.memory_size {
322            let memory_row = self.memory.row(i);
323            let mut similarity = F::zero();
324
325            for j in 0..key.len().min(memory_row.len()) {
326                similarity = similarity + key[j] * memory_row[j];
327            }
328
329            weights[i] = similarity;
330        }
331
332        // Apply softmax
333        let max_weight = weights.iter().fold(F::neg_infinity(), |a, &b| a.max(b));
334        let mut sum = F::zero();
335
336        for weight in weights.iter_mut() {
337            *weight = (*weight - max_weight).exp();
338            sum = sum + *weight;
339        }
340
341        for weight in weights.iter_mut() {
342            *weight = *weight / sum;
343        }
344
345        Ok(weights)
346    }
347
348    /// Hyperbolic tangent activation
349    fn tanh(&self, x: F) -> F {
350        x.tanh()
351    }
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Assert that every element of `values` is finite.
    fn assert_all_finite(values: &Array1<f64>) {
        for &value in values.iter() {
            assert!(value.is_finite());
        }
    }

    /// The constructor stores the dimensions it was given.
    #[test]
    fn test_mann_creation() {
        let mann = MANN::<f64>::new(10, 8, 12, 16, 6);

        let (memory_size, memory_width) = mann.memory_dimensions();
        assert_eq!(memory_size, 10);
        assert_eq!(memory_width, 8);

        let (input_dim, hidden_dim, output_dim) = mann.controller_dimensions();
        assert_eq!(input_dim, 12);
        assert_eq!(hidden_dim, 16);
        assert_eq!(output_dim, 6);
    }

    /// A forward pass yields a finite vector of the controller output size.
    #[test]
    fn test_mann_forward() {
        let mut mann = MANN::<f64>::new(5, 4, 8, 10, 3);
        let input = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);

        let output = mann.forward(&input).unwrap();

        assert_eq!(output.len(), 3);
        assert_all_finite(&output);
    }

    /// Fresh memory reads as zeros; a write lands in the first row.
    #[test]
    fn test_mann_memory_operations() {
        let mut mann = MANN::<f64>::new(3, 2, 4, 6, 2);

        let initial_read = mann.memory_read().unwrap();
        assert_eq!(initial_read.len(), 2);
        for &value in initial_read.iter() {
            assert_abs_diff_eq!(value, 0.0, epsilon = 1e-10);
        }

        mann.memory_write(&Array1::from_vec(vec![1.0, 2.0]))
            .unwrap();

        let memory = mann.get_memory();
        assert_abs_diff_eq!(memory[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(memory[[0, 1]], 2.0, epsilon = 1e-10);
    }

    /// Resetting clears previously written values.
    #[test]
    fn test_mann_reset_memory() {
        let mut mann = MANN::<f64>::new(3, 2, 4, 6, 2);

        mann.memory_write(&Array1::from_vec(vec![5.0, 10.0]))
            .unwrap();
        mann.reset_memory();

        for &value in mann.get_memory().iter() {
            assert_abs_diff_eq!(value, 0.0, epsilon = 1e-10);
        }
    }

    /// A sequence produces one finite output per input.
    #[test]
    fn test_mann_process_sequence() {
        let mut mann = MANN::<f64>::new(4, 3, 6, 8, 2);
        let inputs = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]),
            Array1::from_vec(vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0]),
            Array1::from_vec(vec![3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
        ];

        let outputs = mann.process_sequence(&inputs).unwrap();

        assert_eq!(outputs.len(), 3);
        for output in &outputs {
            assert_eq!(output.len(), 2);
            assert_all_finite(output);
        }
    }

    /// Attention weights form a distribution peaked at the matching row.
    #[test]
    fn test_mann_attention_weights() {
        let mut mann = MANN::<f64>::new(3, 4, 6, 8, 2);

        // Three one-hot rows, each with the 1.0 in a different column.
        let one_hot_rows = vec![
            1.0, 0.0, 0.0, 0.0, //
            0.0, 1.0, 0.0, 0.0, //
            0.0, 0.0, 1.0, 0.0,
        ];
        let memory_data = Array2::from_shape_vec((3, 4), one_hot_rows).unwrap();
        mann.set_memory(memory_data).unwrap();

        // The key equals the first memory row.
        let key = Array1::from_vec(vec![1.0, 0.0, 0.0, 0.0]);
        let weights = mann.compute_attention_weights(&key).unwrap();

        assert_eq!(weights.len(), 3);

        let total: f64 = weights.iter().sum();
        assert_abs_diff_eq!(total, 1.0, epsilon = 1e-10);

        for &weight in weights.iter() {
            assert!(weight >= 0.0);
        }

        let largest = weights.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        assert_abs_diff_eq!(weights[0], largest, epsilon = 1e-10);
    }

    /// The controller alone yields a finite vector of the output size.
    #[test]
    fn test_mann_controller_forward() {
        let mann = MANN::<f64>::new(4, 3, 6, 8, 2);
        let input = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

        let output = mann.controller_forward(&input).unwrap();

        assert_eq!(output.len(), 2);
        assert_all_finite(&output);
    }

    /// Parameters set through the setter are returned by the getter.
    #[test]
    fn test_mann_set_get_params() {
        let mut mann = MANN::<f64>::new(2, 2, 4, 4, 2);

        let original_params = mann.get_controller_params().clone();
        let new_params = Array2::<f64>::zeros(original_params.dim());

        mann.set_controller_params(new_params.clone()).unwrap();
        let retrieved_params = mann.get_controller_params();

        assert_eq!(retrieved_params.dim(), new_params.dim());
        for (&actual, &expected) in retrieved_params.iter().zip(new_params.iter()) {
            assert_abs_diff_eq!(actual, expected, epsilon = 1e-10);
        }
    }

    /// `set_memory` rejects a wrong shape and accepts the right one.
    #[test]
    fn test_mann_memory_dimensions_validation() {
        let mut mann = MANN::<f64>::new(3, 2, 4, 6, 2);

        // Transposed shape: 2 x 3 instead of 3 x 2.
        let wrong_memory = Array2::<f64>::zeros((2, 3));
        assert!(mann.set_memory(wrong_memory).is_err());

        let correct_memory = Array2::<f64>::zeros((3, 2));
        assert!(mann.set_memory(correct_memory).is_ok());
    }
}