1use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::numeric::{Float, FromPrimitive};
8use std::fmt::Debug;
9
10use super::few_shot::FewShotEpisode;
11use crate::error::Result;
12
13#[derive(Debug)]
15pub struct MANN<F: Float + Debug + scirs2_core::ndarray::ScalarOperand> {
16 controller_params: Array2<F>,
18 memory: Array2<F>,
20 memory_size: usize,
22 memory_width: usize,
23 controller_input_dim: usize,
25 controller_hidden_dim: usize,
26 controller_output_dim: usize,
27 #[allow(dead_code)]
29 read_head_params: Array2<F>,
30 #[allow(dead_code)]
31 write_head_params: Array2<F>,
32}
33
34impl<F: Float + Debug + Clone + FromPrimitive + scirs2_core::ndarray::ScalarOperand> MANN<F> {
35 pub fn new(
37 memory_size: usize,
38 memory_width: usize,
39 controller_input_dim: usize,
40 controller_hidden_dim: usize,
41 controller_output_dim: usize,
42 ) -> Self {
43 let controller_param_count = controller_input_dim * controller_hidden_dim
45 + controller_hidden_dim
46 + controller_hidden_dim * controller_output_dim
47 + controller_output_dim;
48
49 let mut controller_params = Array2::zeros((1, controller_param_count));
50 let scale =
51 F::from(2.0).unwrap() / F::from(controller_input_dim + controller_output_dim).unwrap();
52 let std_dev = scale.sqrt();
53
54 for i in 0..controller_param_count {
55 let val = ((i * 67) % 1000) as f64 / 1000.0 - 0.5;
56 controller_params[[0, i]] = F::from(val).unwrap() * std_dev;
57 }
58
59 let memory = Array2::zeros((memory_size, memory_width));
61
62 let head_param_count = memory_width * 2 + 3; let mut read_head_params = Array2::zeros((1, head_param_count));
65 let mut write_head_params = Array2::zeros((1, head_param_count));
66
67 for i in 0..head_param_count {
68 let val1 = ((i * 71) % 1000) as f64 / 1000.0 - 0.5;
69 let val2 = ((i * 73) % 1000) as f64 / 1000.0 - 0.5;
70 read_head_params[[0, i]] = F::from(val1).unwrap() * F::from(0.1).unwrap();
71 write_head_params[[0, i]] = F::from(val2).unwrap() * F::from(0.1).unwrap();
72 }
73
74 Self {
75 controller_params,
76 memory,
77 memory_size,
78 memory_width,
79 controller_input_dim,
80 controller_hidden_dim,
81 controller_output_dim,
82 read_head_params,
83 write_head_params,
84 }
85 }
86
87 pub fn forward(&mut self, input: &Array1<F>) -> Result<Array1<F>> {
89 let read_vector = self.memory_read()?;
91
92 let mut controller_input = Array1::zeros(self.controller_input_dim);
94 for i in 0..input.len().min(self.controller_input_dim) {
95 controller_input[i] = input[i];
96 }
97
98 let read_start = input.len().min(self.controller_input_dim);
100 for i in 0..read_vector.len() {
101 if read_start + i < self.controller_input_dim {
102 controller_input[read_start + i] = read_vector[i];
103 }
104 }
105
106 let controller_output = self.controller_forward(&controller_input)?;
108
109 self.memory_write(&controller_output)?;
111
112 Ok(controller_output)
113 }
114
115 fn controller_forward(&self, input: &Array1<F>) -> Result<Array1<F>> {
117 let (w1, b1, w2, b2) = self.extract_controller_weights();
118
119 let mut hidden = Array1::zeros(self.controller_hidden_dim);
121 for i in 0..self.controller_hidden_dim {
122 let mut sum = b1[i];
123 for j in 0..input.len().min(w1.ncols()) {
124 sum = sum + input[j] * w1[[i, j]];
125 }
126 hidden[i] = self.tanh(sum);
127 }
128
129 let mut output = Array1::zeros(self.controller_output_dim);
131 for i in 0..self.controller_output_dim {
132 let mut sum = b2[i];
133 for j in 0..self.controller_hidden_dim {
134 sum = sum + hidden[j] * w2[[i, j]];
135 }
136 output[i] = sum;
137 }
138
139 Ok(output)
140 }
141
142 fn memory_read(&self) -> Result<Array1<F>> {
144 let mut read_vector = Array1::zeros(self.memory_width);
146
147 for i in 0..self.memory_size {
148 for j in 0..self.memory_width {
149 read_vector[j] = read_vector[j] + self.memory[[i, j]];
150 }
151 }
152
153 let size = F::from(self.memory_size).unwrap();
154 for j in 0..self.memory_width {
155 read_vector[j] = read_vector[j] / size;
156 }
157
158 Ok(read_vector)
159 }
160
161 fn memory_write(&mut self, controller_output: &Array1<F>) -> Result<()> {
163 for i in 0..controller_output.len().min(self.memory_width) {
165 self.memory[[0, i]] = controller_output[i];
166 }
167
168 Ok(())
169 }
170
171 fn extract_controller_weights(&self) -> (Array2<F>, Array1<F>, Array2<F>, Array1<F>) {
173 let param_vec = self.controller_params.row(0);
174 let mut idx = 0;
175
176 let mut w1 = Array2::zeros((self.controller_hidden_dim, self.controller_input_dim));
178 for i in 0..self.controller_hidden_dim {
179 for j in 0..self.controller_input_dim {
180 if idx < param_vec.len() {
181 w1[[i, j]] = param_vec[idx];
182 idx += 1;
183 }
184 }
185 }
186
187 let mut b1 = Array1::zeros(self.controller_hidden_dim);
189 for i in 0..self.controller_hidden_dim {
190 if idx < param_vec.len() {
191 b1[i] = param_vec[idx];
192 idx += 1;
193 }
194 }
195
196 let mut w2 = Array2::zeros((self.controller_output_dim, self.controller_hidden_dim));
198 for i in 0..self.controller_output_dim {
199 for j in 0..self.controller_hidden_dim {
200 if idx < param_vec.len() {
201 w2[[i, j]] = param_vec[idx];
202 idx += 1;
203 }
204 }
205 }
206
207 let mut b2 = Array1::zeros(self.controller_output_dim);
209 for i in 0..self.controller_output_dim {
210 if idx < param_vec.len() {
211 b2[i] = param_vec[idx];
212 idx += 1;
213 }
214 }
215
216 (w1, b1, w2, b2)
217 }
218
219 pub fn reset_memory(&mut self) {
221 self.memory = Array2::zeros((self.memory_size, self.memory_width));
222 }
223
224 pub fn train_few_shot(&mut self, episodes: &[FewShotEpisode<F>]) -> Result<F> {
226 let mut total_loss = F::zero();
227
228 for episode in episodes {
229 self.reset_memory();
230
231 for i in 0..episode.support_x.nrows() {
233 let input_row = episode.support_x.row(i).to_owned();
234 let _output = self.forward(&input_row)?;
235 }
236
237 let mut episode_loss = F::zero();
239 for i in 0..episode.query_x.nrows() {
240 let input_row = episode.query_x.row(i).to_owned();
241 let prediction = self.forward(&input_row)?;
242
243 if i < episode.query_y.len() {
245 let target = F::from(episode.query_y[i]).unwrap();
246 if !prediction.is_empty() {
247 let diff = prediction[0] - target;
248 episode_loss = episode_loss + diff * diff;
249 }
250 }
251 }
252
253 total_loss = total_loss + episode_loss;
254 }
255
256 Ok(total_loss / F::from(episodes.len()).unwrap())
257 }
258
259 pub fn get_memory(&self) -> &Array2<F> {
261 &self.memory
262 }
263
264 pub fn set_memory(&mut self, memory: Array2<F>) -> Result<()> {
266 if memory.dim() != (self.memory_size, self.memory_width) {
267 return Err(crate::error::TimeSeriesError::InvalidOperation(
268 "Memory dimensions do not match".to_string(),
269 ));
270 }
271 self.memory = memory;
272 Ok(())
273 }
274
275 pub fn get_controller_params(&self) -> &Array2<F> {
277 &self.controller_params
278 }
279
280 pub fn set_controller_params(&mut self, params: Array2<F>) -> Result<()> {
282 if params.dim() != self.controller_params.dim() {
283 return Err(crate::error::TimeSeriesError::InvalidOperation(
284 "Controller parameter dimensions do not match".to_string(),
285 ));
286 }
287 self.controller_params = params;
288 Ok(())
289 }
290
291 pub fn memory_dimensions(&self) -> (usize, usize) {
293 (self.memory_size, self.memory_width)
294 }
295
296 pub fn controller_dimensions(&self) -> (usize, usize, usize) {
298 (
299 self.controller_input_dim,
300 self.controller_hidden_dim,
301 self.controller_output_dim,
302 )
303 }
304
305 pub fn process_sequence(&mut self, inputs: &[Array1<F>]) -> Result<Vec<Array1<F>>> {
307 let mut outputs = Vec::new();
308
309 for input in inputs {
310 let output = self.forward(input)?;
311 outputs.push(output);
312 }
313
314 Ok(outputs)
315 }
316
317 pub fn compute_attention_weights(&self, key: &Array1<F>) -> Result<Array1<F>> {
319 let mut weights = Array1::zeros(self.memory_size);
320
321 for i in 0..self.memory_size {
322 let memory_row = self.memory.row(i);
323 let mut similarity = F::zero();
324
325 for j in 0..key.len().min(memory_row.len()) {
326 similarity = similarity + key[j] * memory_row[j];
327 }
328
329 weights[i] = similarity;
330 }
331
332 let max_weight = weights.iter().fold(F::neg_infinity(), |a, &b| a.max(b));
334 let mut sum = F::zero();
335
336 for weight in weights.iter_mut() {
337 *weight = (*weight - max_weight).exp();
338 sum = sum + *weight;
339 }
340
341 for weight in weights.iter_mut() {
342 *weight = *weight / sum;
343 }
344
345 Ok(weights)
346 }
347
348 fn tanh(&self, x: F) -> F {
350 x.tanh()
351 }
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Fails the test if any element is NaN or infinite.
    fn assert_all_finite(values: &Array1<f64>) {
        assert!(values.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_mann_creation() {
        let model = MANN::<f64>::new(10, 8, 12, 16, 6);

        assert_eq!(model.memory_dimensions(), (10, 8));
        assert_eq!(model.controller_dimensions(), (12, 16, 6));
    }

    #[test]
    fn test_mann_forward() {
        let mut model = MANN::<f64>::new(5, 4, 8, 10, 3);
        let x: Array1<f64> = (1..=8).map(f64::from).collect();

        let y = model.forward(&x).unwrap();

        assert_eq!(y.len(), 3);
        assert_all_finite(&y);
    }

    #[test]
    fn test_mann_memory_operations() {
        let mut model = MANN::<f64>::new(3, 2, 4, 6, 2);

        // Fresh memory is all zeros, so the mean read is zero too.
        let initial_read = model.memory_read().unwrap();
        assert_eq!(initial_read.len(), 2);
        for &v in initial_read.iter() {
            assert_abs_diff_eq!(v, 0.0, epsilon = 1e-10);
        }

        model
            .memory_write(&Array1::from_vec(vec![1.0, 2.0]))
            .unwrap();

        let mem = model.get_memory();
        assert_abs_diff_eq!(mem[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(mem[[0, 1]], 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_mann_reset_memory() {
        let mut model = MANN::<f64>::new(3, 2, 4, 6, 2);
        model
            .memory_write(&Array1::from_vec(vec![5.0, 10.0]))
            .unwrap();

        model.reset_memory();

        for &v in model.get_memory().iter() {
            assert_abs_diff_eq!(v, 0.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_mann_process_sequence() {
        let mut model = MANN::<f64>::new(4, 3, 6, 8, 2);
        let sequence: Vec<Array1<f64>> = (1..=3)
            .map(|start| (start..start + 6).map(f64::from).collect())
            .collect();

        let results = model.process_sequence(&sequence).unwrap();

        assert_eq!(results.len(), 3);
        for step in &results {
            assert_eq!(step.len(), 2);
            assert_all_finite(step);
        }
    }

    #[test]
    fn test_mann_attention_weights() {
        let mut model = MANN::<f64>::new(3, 4, 6, 8, 2);

        // Row r has a single 1.0 in column r; column 3 stays zero.
        let memory_data = Array2::from_shape_fn((3, 4), |(r, c)| if r == c { 1.0 } else { 0.0 });
        model.set_memory(memory_data).unwrap();

        let probe = Array1::from_vec(vec![1.0, 0.0, 0.0, 0.0]);
        let attn = model.compute_attention_weights(&probe).unwrap();

        assert_eq!(attn.len(), 3);

        let total: f64 = attn.iter().sum();
        assert_abs_diff_eq!(total, 1.0, epsilon = 1e-10);

        assert!(attn.iter().all(|&w| w >= 0.0));

        // The probe matches row 0 exactly, so row 0 must receive the largest weight.
        let peak = attn.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        assert_abs_diff_eq!(attn[0], peak, epsilon = 1e-10);
    }

    #[test]
    fn test_mann_controller_forward() {
        let model = MANN::<f64>::new(4, 3, 6, 8, 2);
        let x: Array1<f64> = (1..=6).map(f64::from).collect();

        let y = model.controller_forward(&x).unwrap();

        assert_eq!(y.len(), 2);
        assert_all_finite(&y);
    }

    #[test]
    fn test_mann_set_get_params() {
        let mut model = MANN::<f64>::new(2, 2, 4, 4, 2);

        let shape = model.get_controller_params().dim();
        let zeros = Array2::<f64>::zeros(shape);

        model.set_controller_params(zeros.clone()).unwrap();
        let stored = model.get_controller_params();

        assert_eq!(stored.dim(), zeros.dim());
        stored
            .iter()
            .zip(zeros.iter())
            .for_each(|(&got, &want)| assert_abs_diff_eq!(got, want, epsilon = 1e-10));
    }

    #[test]
    fn test_mann_memory_dimensions_validation() {
        let mut model = MANN::<f64>::new(3, 2, 4, 6, 2);

        assert!(model.set_memory(Array2::zeros((2, 3))).is_err());
        assert!(model.set_memory(Array2::zeros((3, 2))).is_ok());
    }
}