trustformers_training/few_shot/meta_learning.rs

use anyhow::Result;
use scirs2_core::ndarray::{s, Array2}; // SciRS2 Integration Policy
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

/// Meta-learning algorithm trait
pub trait MetaLearningAlgorithm {
    fn meta_update(&mut self, task_batch: &TaskBatch) -> Result<MetaUpdateResult>;
    fn adapt(&self, support_set: &TaskData, adaptation_steps: usize) -> Result<ModelParameters>;
    fn evaluate(&self, params: &ModelParameters, query_set: &TaskData) -> Result<f32>;
}
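
// Illustrative sketch, not part of the original API: any implementor of
// `MetaLearningAlgorithm` can be driven by a generic episode loop like the one
// below. `sample_batch` stands in for a hypothetical task sampler supplied by
// the caller.
#[allow(dead_code)]
fn meta_train_sketch<A, F>(algo: &mut A, mut sample_batch: F, episodes: usize) -> Result<()>
where
    A: MetaLearningAlgorithm,
    F: FnMut() -> Result<TaskBatch>,
{
    for episode in 0..episodes {
        let batch = sample_batch()?;
        let result = algo.meta_update(&batch)?;
        // A real training loop would log to a metrics tracker instead of stdout
        println!("episode {episode}: meta-loss = {:.4}", result.meta_loss);
    }
    Ok(())
}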

/// Configuration for MAML (Model-Agnostic Meta-Learning)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MAMLConfig {
    /// Meta-learning rate (outer loop)
    pub meta_lr: f32,
    /// Task-specific learning rate (inner loop)
    pub inner_lr: f32,
    /// Number of gradient steps for adaptation
    pub adaptation_steps: usize,
    /// Number of tasks per meta-batch
    pub meta_batch_size: usize,
    /// Whether to use first-order approximation
    pub first_order: bool,
    /// Gradient clipping threshold
    pub grad_clip: f32,
    /// Whether to learn inner learning rates
    pub learn_inner_lrs: bool,
}

impl Default for MAMLConfig {
    fn default() -> Self {
        Self {
            meta_lr: 0.001,
            inner_lr: 0.01,
            adaptation_steps: 5,
            meta_batch_size: 16,
            first_order: false,
            grad_clip: 10.0,
            learn_inner_lrs: false,
        }
    }
}
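
// Illustrative sketch: customizing the config with struct-update syntax, e.g.
// for a cheaper first-order run. The specific values are assumptions for this
// example only.
//
// let config = MAMLConfig {
//     first_order: true,    // skip the expensive second-order terms
//     adaptation_steps: 1,
//     ..MAMLConfig::default()
// };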

/// Configuration for Reptile
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReptileConfig {
    /// Meta-learning rate
    pub meta_lr: f32,
    /// Task-specific learning rate
    pub inner_lr: f32,
    /// Number of gradient steps for adaptation
    pub adaptation_steps: usize,
    /// Number of tasks per meta-batch
    pub meta_batch_size: usize,
    /// Gradient clipping threshold
    pub grad_clip: f32,
}

impl Default for ReptileConfig {
    fn default() -> Self {
        Self {
            meta_lr: 0.001,
            inner_lr: 0.01,
            adaptation_steps: 5,
            meta_batch_size: 16,
            grad_clip: 10.0,
        }
    }
}

/// Model parameters representation
#[derive(Debug, Clone)]
pub struct ModelParameters {
    /// Parameter tensors by layer name
    pub parameters: HashMap<String, Array2<f32>>,
    /// Parameter shapes
    pub shapes: HashMap<String, Vec<usize>>,
}

impl Default for ModelParameters {
    fn default() -> Self {
        Self::new()
    }
}

impl ModelParameters {
    pub fn new() -> Self {
        Self {
            parameters: HashMap::new(),
            shapes: HashMap::new(),
        }
    }

    /// Add a parameter tensor
    pub fn add_parameter(&mut self, name: String, tensor: Array2<f32>) {
        let shape = tensor.shape().to_vec();
        self.shapes.insert(name.clone(), shape);
        self.parameters.insert(name, tensor);
    }

    /// Get parameter by name
    pub fn get_parameter(&self, name: &str) -> Option<&Array2<f32>> {
        self.parameters.get(name)
    }

    /// Get mutable parameter by name
    pub fn get_parameter_mut(&mut self, name: &str) -> Option<&mut Array2<f32>> {
        self.parameters.get_mut(name)
    }

    /// Clone parameters
    pub fn clone_parameters(&self) -> Self {
        Self {
            parameters: self.parameters.clone(),
            shapes: self.shapes.clone(),
        }
    }

    /// Update parameters in place with a gradient step: `param -= lr * grad`
    pub fn update_with_gradients(&mut self, gradients: &Self, learning_rate: f32) -> Result<()> {
        for (name, param) in &mut self.parameters {
            if let Some(grad) = gradients.get_parameter(name) {
                param.zip_mut_with(grad, |p, &g| *p -= learning_rate * g);
            }
        }
        Ok(())
    }

    /// Compute parameter difference (self - other)
    pub fn subtract(&self, other: &Self) -> Result<Self> {
        let mut result = Self::new();

        for (name, param) in &self.parameters {
            if let Some(other_param) = other.get_parameter(name) {
                let diff = param - other_param;
                result.add_parameter(name.clone(), diff);
            }
        }

        Ok(result)
    }

    /// Add parameters (element-wise)
    pub fn add(&self, other: &Self) -> Result<Self> {
        let mut result = Self::new();

        for (name, param) in &self.parameters {
            if let Some(other_param) = other.get_parameter(name) {
                let sum = param + other_param;
                result.add_parameter(name.clone(), sum);
            }
        }

        Ok(result)
    }

    /// Scale parameters by a scalar
    pub fn scale(&self, scalar: f32) -> Self {
        let mut result = Self::new();

        for (name, param) in &self.parameters {
            let scaled = param * scalar;
            result.add_parameter(name.clone(), scaled);
        }

        result
    }
}
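
// Illustrative sketch: a single gradient-descent step via
// `update_with_gradients`. The parameter name "w" is an assumption for this
// example only.
//
// let mut params = ModelParameters::new();
// params.add_parameter("w".to_string(), Array2::<f32>::zeros((2, 2)));
// let mut grads = ModelParameters::new();
// grads.add_parameter("w".to_string(), Array2::<f32>::ones((2, 2)));
// params.update_with_gradients(&grads, 0.1)?; // each element becomes -0.1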

/// Task data (support and query sets)
#[derive(Debug, Clone)]
pub struct TaskData {
    /// Input features
    pub inputs: Array2<f32>,
    /// Target outputs
    pub targets: Array2<f32>,
    /// Task identifier
    pub task_id: String,
}

impl TaskData {
    pub fn new(inputs: Array2<f32>, targets: Array2<f32>, task_id: String) -> Self {
        Self {
            inputs,
            targets,
            task_id,
        }
    }

    /// Get batch size
    pub fn batch_size(&self) -> usize {
        self.inputs.nrows()
    }

    /// Split into mini-batches (panics if `batch_size` is 0)
    pub fn split_batches(&self, batch_size: usize) -> Vec<TaskData> {
        let total_samples = self.batch_size();
        let mut batches = Vec::new();

        for start in (0..total_samples).step_by(batch_size) {
            let end = (start + batch_size).min(total_samples);
            let batch_inputs = self.inputs.slice(s![start..end, ..]).to_owned();
            let batch_targets = self.targets.slice(s![start..end, ..]).to_owned();

            batches.push(TaskData::new(
                batch_inputs,
                batch_targets,
                format!("{}_batch_{}", self.task_id, start / batch_size),
            ));
        }

        batches
    }
}
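
// Illustrative sketch: splitting a task into mini-batches. The shapes and task
// id are assumptions for this example only.
//
// let inputs = Array2::from_shape_fn((10, 4), |(i, j)| (i + j) as f32);
// let targets = Array2::<f32>::zeros((10, 1));
// let task = TaskData::new(inputs, targets, "sine_task_0".to_string());
// let batches = task.split_batches(4); // three batches of sizes 4, 4, 2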

/// Batch of tasks for meta-learning
#[derive(Debug)]
pub struct TaskBatch {
    /// Support sets for each task
    pub support_sets: Vec<TaskData>,
    /// Query sets for each task
    pub query_sets: Vec<TaskData>,
}

impl TaskBatch {
    pub fn new(support_sets: Vec<TaskData>, query_sets: Vec<TaskData>) -> Result<Self> {
        if support_sets.len() != query_sets.len() {
            return Err(anyhow::anyhow!(
                "Support and query sets must have same length"
            ));
        }
        Ok(Self {
            support_sets,
            query_sets,
        })
    }

    /// Number of tasks in batch
    pub fn num_tasks(&self) -> usize {
        self.support_sets.len()
    }
}

/// Result of meta-update
#[derive(Debug)]
pub struct MetaUpdateResult {
    /// Meta-loss across all tasks
    pub meta_loss: f32,
    /// Per-task losses
    pub task_losses: Vec<f32>,
    /// Gradient norm
    pub grad_norm: f32,
    /// Updated parameters
    pub updated_parameters: ModelParameters,
}

/// MAML (Model-Agnostic Meta-Learning) implementation
pub struct MAMLTrainer {
    config: MAMLConfig,
    meta_parameters: Arc<RwLock<ModelParameters>>,
    #[allow(dead_code)]
    optimizer_state: HashMap<String, Array2<f32>>, // For Adam/RMSprop
    meta_step: usize,
}

impl MAMLTrainer {
    pub fn new(config: MAMLConfig, initial_parameters: ModelParameters) -> Self {
        Self {
            config,
            meta_parameters: Arc::new(RwLock::new(initial_parameters)),
            optimizer_state: HashMap::new(),
            meta_step: 0,
        }
    }

    /// Compute gradients for inner-loop adaptation
    fn compute_inner_gradients(
        &self,
        parameters: &ModelParameters,
        task_data: &TaskData,
    ) -> Result<ModelParameters> {
        // Simplified gradient computation (in practice, this would use
        // automatic differentiation)
        let mut gradients = ModelParameters::new();

        for (name, param) in &parameters.parameters {
            let grad = self.compute_parameter_gradient(param, task_data)?;
            gradients.add_parameter(name.clone(), grad);
        }

        Ok(gradients)
    }

    /// Compute gradients for a parameter using a finite-difference approximation
    fn compute_parameter_gradient(
        &self,
        param: &Array2<f32>,
        data: &TaskData,
    ) -> Result<Array2<f32>> {
        let eps = 1e-5f32;
        let mut gradients = Array2::zeros(param.raw_dim());

        // Central differences: evaluate the loss at +eps and -eps around each element
        for ((i, j), param_val) in param.indexed_iter() {
            let mut param_plus = param.clone();
            param_plus[[i, j]] = param_val + eps;
            let loss_plus = self.compute_loss_for_parameter(&param_plus, data)?;

            let mut param_minus = param.clone();
            param_minus[[i, j]] = param_val - eps;
            let loss_minus = self.compute_loss_for_parameter(&param_minus, data)?;

            gradients[[i, j]] = (loss_plus - loss_minus) / (2.0 * eps);
        }

        Ok(gradients)
    }
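
    // Cost note (illustrative arithmetic, not from the original source): central
    // differences need two loss evaluations per scalar parameter, so a 4x8
    // weight matrix costs 2 * 4 * 8 = 64 forward passes per gradient, per inner
    // step. This is why real implementations use automatic differentiation.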

    /// Compute the loss for a single parameter (helper for gradient computation)
    fn compute_loss_for_parameter(&self, param: &Array2<f32>, data: &TaskData) -> Result<f32> {
        // Simplified forward pass for a linear model: y = xW (+ b)
        let predictions = if param.nrows() == data.inputs.ncols() {
            // Weight matrix of shape (input_dim, output_dim)
            data.inputs.dot(param)
        } else if param.shape() == [1, data.targets.ncols()] {
            // Bias vector - broadcast across the batch
            Array2::from_shape_fn((data.inputs.nrows(), param.ncols()), |(_, j)| param[[0, j]])
        } else {
            // Default: pass the inputs through unchanged
            data.inputs.clone()
        };

        // Mean squared error loss
        let diff = &predictions - &data.targets;
        let mse = diff.mapv(|x| x * x).mean().unwrap_or(0.0);

        Ok(mse)
    }

    /// Perform inner-loop adaptation on a support set
    fn inner_loop_adaptation(
        &self,
        initial_params: &ModelParameters,
        support_set: &TaskData,
    ) -> Result<ModelParameters> {
        let mut adapted_params = initial_params.clone_parameters();

        for _step in 0..self.config.adaptation_steps {
            let gradients = self.compute_inner_gradients(&adapted_params, support_set)?;
            // Learned per-parameter inner learning rates are not implemented, so
            // the configured `inner_lr` is used whether or not
            // `config.learn_inner_lrs` is set.
            adapted_params.update_with_gradients(&gradients, self.config.inner_lr)?;
        }

        Ok(adapted_params)
    }

    /// Compute meta-gradients across a batch of tasks
    fn compute_meta_gradients(&self, task_batch: &TaskBatch) -> Result<(ModelParameters, f32)> {
        let meta_params = self.meta_parameters.read().expect("lock should not be poisoned");
        let mut meta_gradients = ModelParameters::new();
        let mut total_meta_loss = 0.0;

        // Initialize meta-gradients to zero
        for (name, param) in &meta_params.parameters {
            meta_gradients.add_parameter(name.clone(), Array2::zeros(param.raw_dim()));
        }

        // Accumulate gradients across tasks
        for (support_set, query_set) in task_batch.support_sets.iter().zip(&task_batch.query_sets) {
            // Inner-loop adaptation
            let adapted_params = self.inner_loop_adaptation(&meta_params, support_set)?;

            // Compute the query loss
            let query_loss = self.compute_query_loss(&adapted_params, query_set)?;
            total_meta_loss += query_loss;

            // Compute gradients w.r.t. the meta-parameters
            let task_meta_grads = if self.config.first_order {
                // First-order approximation: use the Reptile-style direction
                // (meta - adapted) instead of differentiating through adaptation
                meta_params.subtract(&adapted_params)?
            } else {
                // Full second-order gradients (expensive)
                self.compute_second_order_gradients(&meta_params, &adapted_params, query_set)?
            };

            // Accumulate meta-gradients
            for (name, grad) in &task_meta_grads.parameters {
                if let Some(meta_grad) = meta_gradients.get_parameter_mut(name) {
                    *meta_grad += grad;
                }
            }
        }

        // Average gradients and loss over tasks
        let num_tasks = task_batch.num_tasks() as f32;
        for grad in meta_gradients.parameters.values_mut() {
            *grad /= num_tasks;
        }

        total_meta_loss /= num_tasks;

        Ok((meta_gradients, total_meta_loss))
    }

    /// Compute query loss using current model parameters
    fn compute_query_loss(&self, params: &ModelParameters, query_set: &TaskData) -> Result<f32> {
        // Perform forward pass through the network
        let predictions = self.forward_pass(params, &query_set.inputs)?;

        // Choose the loss based on the target shape
        let loss = if query_set.targets.ncols() == 1 {
            // Regression task - use MSE
            self.compute_mse_loss(&predictions, &query_set.targets)?
        } else {
            // Classification task - use cross-entropy
            self.compute_cross_entropy_loss(&predictions, &query_set.targets)?
        };

        Ok(loss)
    }

    /// Forward pass through a simple neural network
    fn forward_pass(&self, params: &ModelParameters, inputs: &Array2<f32>) -> Result<Array2<f32>> {
        let mut activations = inputs.clone();

        // Hidden layer
        if let Some(layer1_weights) = params.get_parameter("layer1_weight") {
            activations = activations.dot(layer1_weights);

            // Add bias if present
            if let Some(layer1_bias) = params.get_parameter("layer1_bias") {
                for mut row in activations.rows_mut() {
                    for (i, &bias) in layer1_bias.row(0).iter().enumerate() {
                        if i < row.len() {
                            row[i] += bias;
                        }
                    }
                }
            }

            // Apply ReLU activation
            activations.mapv_inplace(|x| x.max(0.0));
        }

        // Output layer
        if let Some(output_weights) = params.get_parameter("output_weight") {
            activations = activations.dot(output_weights);

            if let Some(output_bias) = params.get_parameter("output_bias") {
                for mut row in activations.rows_mut() {
                    for (i, &bias) in output_bias.row(0).iter().enumerate() {
                        if i < row.len() {
                            row[i] += bias;
                        }
                    }
                }
            }
        }

        Ok(activations)
    }

    /// Compute mean squared error loss
    fn compute_mse_loss(&self, predictions: &Array2<f32>, targets: &Array2<f32>) -> Result<f32> {
        let diff = predictions - targets;
        let mse = diff.mapv(|x| x * x).mean().unwrap_or(0.0);
        Ok(mse)
    }

    /// Compute cross-entropy loss over softmax probabilities
    fn compute_cross_entropy_loss(
        &self,
        predictions: &Array2<f32>,
        targets: &Array2<f32>,
    ) -> Result<f32> {
        let batch_size = predictions.nrows();
        let mut total_loss = 0.0;

        for i in 0..batch_size {
            let pred_row = predictions.row(i);
            let target_row = targets.row(i);

            // Numerically stable softmax: subtract the row max before exponentiating
            let max_pred = pred_row.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
            let exp_preds: Vec<f32> = pred_row.iter().map(|&x| (x - max_pred).exp()).collect();
            let sum_exp: f32 = exp_preds.iter().sum();
            let softmax_preds: Vec<f32> = exp_preds.iter().map(|&x| x / sum_exp).collect();

            // Cross-entropy: -sum(target * ln(prob)), clamped away from ln(0)
            let mut row_loss = 0.0;
            for (&pred, &target) in softmax_preds.iter().zip(target_row.iter()) {
                if target > 0.0 {
                    row_loss -= target * pred.max(1e-15).ln();
                }
            }
            total_loss += row_loss;
        }

        Ok(total_loss / batch_size as f32)
    }

    /// Approximate second-order gradients via finite differences of gradients
    fn compute_second_order_gradients(
        &self,
        meta_params: &ModelParameters,
        adapted_params: &ModelParameters,
        query_set: &TaskData,
    ) -> Result<ModelParameters> {
        let eps = 1e-4f32;
        let mut second_order_grads = ModelParameters::new();

        // For each parameter in meta_params, compute a second-order term
        for (param_name, meta_param) in &meta_params.parameters {
            if adapted_params.get_parameter(param_name).is_none() {
                continue;
            }

            let mut param_grad = Array2::zeros(meta_param.raw_dim());

            // Per-element second-derivative estimate: perturb the meta-parameter
            // and take a central difference of the resulting meta-gradients
            for ((i, j), _) in meta_param.indexed_iter() {
                let mut meta_plus = meta_params.clone_parameters();
                let mut meta_minus = meta_params.clone_parameters();

                if let (Some(param_plus), Some(param_minus)) = (
                    meta_plus.get_parameter_mut(param_name),
                    meta_minus.get_parameter_mut(param_name),
                ) {
                    param_plus[[i, j]] += eps;
                    param_minus[[i, j]] -= eps;

                    // Compute gradients at the perturbed points
                    let grad_plus =
                        self.compute_meta_gradient_at_point(&meta_plus, adapted_params, query_set)?;
                    let grad_minus = self.compute_meta_gradient_at_point(
                        &meta_minus,
                        adapted_params,
                        query_set,
                    )?;

                    // Second-order term via central difference
                    if let (Some(g_plus), Some(g_minus)) = (
                        grad_plus.get_parameter(param_name),
                        grad_minus.get_parameter(param_name),
                    ) {
                        param_grad[[i, j]] = (g_plus[[i, j]] - g_minus[[i, j]]) / (2.0 * eps);
                    }
                }
            }

            second_order_grads.add_parameter(param_name.clone(), param_grad);
        }

        Ok(second_order_grads)
    }

    /// Compute the meta-gradient at a specific parameter point
    fn compute_meta_gradient_at_point(
        &self,
        meta_params: &ModelParameters,
        adapted_params: &ModelParameters,
        query_set: &TaskData,
    ) -> Result<ModelParameters> {
        // Gradient of the query loss w.r.t. the adapted parameters
        let query_loss_grad = self.compute_query_loss_gradients(adapted_params, query_set)?;

        // Chain rule: grad w.r.t. meta_params = grad w.r.t. adapted_params * Jacobian
        let jacobian = self.compute_adaptation_jacobian(meta_params, adapted_params)?;

        // Apply the chain rule
        let mut meta_grad = ModelParameters::new();
        for (param_name, loss_grad) in &query_loss_grad.parameters {
            if let Some(jac) = jacobian.get_parameter(param_name) {
                // Simplified: element-wise product (a full implementation would
                // contract against the true Jacobian)
                let meta_gradient = loss_grad * jac;
                meta_grad.add_parameter(param_name.clone(), meta_gradient);
            }
        }

        Ok(meta_grad)
    }

    /// Compute gradients of the query loss w.r.t. the adapted parameters
    fn compute_query_loss_gradients(
        &self,
        params: &ModelParameters,
        query_set: &TaskData,
    ) -> Result<ModelParameters> {
        let mut gradients = ModelParameters::new();

        for param_name in params.parameters.keys() {
            let grad = self.compute_parameter_gradient_for_query(params, param_name, query_set)?;
            gradients.add_parameter(param_name.clone(), grad);
        }

        Ok(gradients)
    }

    /// Compute the gradient of the query loss w.r.t. a single named parameter,
    /// perturbing it inside a full copy of the parameter set so that the
    /// forward pass still sees every other layer
    fn compute_parameter_gradient_for_query(
        &self,
        params: &ModelParameters,
        param_name: &str,
        query_set: &TaskData,
    ) -> Result<Array2<f32>> {
        let eps = 1e-5f32;
        let param = params
            .get_parameter(param_name)
            .ok_or_else(|| anyhow::anyhow!("unknown parameter: {param_name}"))?;
        let mut gradients = Array2::zeros(param.raw_dim());

        for ((i, j), param_val) in param.indexed_iter() {
            let mut params_plus = params.clone_parameters();
            let mut params_minus = params.clone_parameters();
            if let Some(p) = params_plus.get_parameter_mut(param_name) {
                p[[i, j]] = param_val + eps;
            }
            if let Some(p) = params_minus.get_parameter_mut(param_name) {
                p[[i, j]] = param_val - eps;
            }

            let loss_plus = self.compute_query_loss(&params_plus, query_set)?;
            let loss_minus = self.compute_query_loss(&params_minus, query_set)?;

            // Central difference
            gradients[[i, j]] = (loss_plus - loss_minus) / (2.0 * eps);
        }

        Ok(gradients)
    }

    /// Compute Jacobian of the adaptation process (simplified)
    fn compute_adaptation_jacobian(
        &self,
        meta_params: &ModelParameters,
        adapted_params: &ModelParameters,
    ) -> Result<ModelParameters> {
        let mut jacobian = ModelParameters::new();

        // Simplified: assume an element-wise identity Jacobian, i.e. each adapted
        // parameter depends only on its own meta-parameter with derivative 1.
        // In practice, this would compute d(adapted_params)/d(meta_params).
        for (param_name, meta_param) in &meta_params.parameters {
            if adapted_params.get_parameter(param_name).is_some() {
                let identity_jac = Array2::ones(meta_param.raw_dim());
                jacobian.add_parameter(param_name.clone(), identity_jac);
            }
        }

        Ok(jacobian)
    }

    /// Clip gradients by global norm; returns the pre-clip norm
    fn clip_gradients(&self, gradients: &mut ModelParameters) -> f32 {
        let mut total_norm = 0.0;

        // Compute the total gradient norm across all parameters
        for grad in gradients.parameters.values() {
            total_norm += grad.mapv(|x| x * x).sum();
        }
        total_norm = total_norm.sqrt();

        // Rescale if the norm exceeds the threshold
        if total_norm > self.config.grad_clip {
            let clip_coef = self.config.grad_clip / total_norm;
            for grad in gradients.parameters.values_mut() {
                *grad *= clip_coef;
            }
        }

        total_norm
    }
}

impl MetaLearningAlgorithm for MAMLTrainer {
    fn meta_update(&mut self, task_batch: &TaskBatch) -> Result<MetaUpdateResult> {
        let (mut meta_gradients, meta_loss) = self.compute_meta_gradients(task_batch)?;
        let grad_norm = self.clip_gradients(&mut meta_gradients);

        // Update the meta-parameters
        {
            let mut meta_params =
                self.meta_parameters.write().expect("lock should not be poisoned");
            meta_params.update_with_gradients(&meta_gradients, self.config.meta_lr)?;
        }

        self.meta_step += 1;

        Ok(MetaUpdateResult {
            meta_loss,
            task_losses: vec![meta_loss; task_batch.num_tasks()], // Simplified: mean loss repeated per task
            grad_norm,
            updated_parameters: self
                .meta_parameters
                .read()
                .expect("lock should not be poisoned")
                .clone_parameters(),
        })
    }

    fn adapt(&self, support_set: &TaskData, adaptation_steps: usize) -> Result<ModelParameters> {
        let meta_params = self.meta_parameters.read().expect("lock should not be poisoned");
        let mut adapted_params = meta_params.clone_parameters();

        for _ in 0..adaptation_steps {
            let gradients = self.compute_inner_gradients(&adapted_params, support_set)?;
            adapted_params.update_with_gradients(&gradients, self.config.inner_lr)?;
        }

        Ok(adapted_params)
    }

    fn evaluate(&self, params: &ModelParameters, query_set: &TaskData) -> Result<f32> {
        self.compute_query_loss(params, query_set)
    }
}
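
// Illustrative usage sketch (parameter names and shapes are assumptions chosen
// to match `forward_pass`; `task_batch`, `support_set`, and `query_set` are
// placeholders the caller would construct):
//
// let mut init = ModelParameters::new();
// init.add_parameter("layer1_weight".to_string(), Array2::zeros((4, 8)));
// init.add_parameter("output_weight".to_string(), Array2::zeros((8, 1)));
// let mut maml = MAMLTrainer::new(MAMLConfig::default(), init);
// let result = maml.meta_update(&task_batch)?;    // outer-loop step
// let adapted = maml.adapt(&support_set, 5)?;     // inner-loop adaptation
// let query_loss = maml.evaluate(&adapted, &query_set)?;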

/// Reptile algorithm implementation
pub struct ReptileTrainer {
    config: ReptileConfig,
    meta_parameters: Arc<RwLock<ModelParameters>>,
    meta_step: usize,
}

impl ReptileTrainer {
    pub fn new(config: ReptileConfig, initial_parameters: ModelParameters) -> Self {
        Self {
            config,
            meta_parameters: Arc::new(RwLock::new(initial_parameters)),
            meta_step: 0,
        }
    }

    /// Perform SGD on a single task
    fn sgd_on_task(&self, task_data: &TaskData) -> Result<ModelParameters> {
        let meta_params = self.meta_parameters.read().expect("lock should not be poisoned");
        let mut task_params = meta_params.clone_parameters();

        for _ in 0..self.config.adaptation_steps {
            let gradients = self.compute_task_gradients(&task_params, task_data)?;
            task_params.update_with_gradients(&gradients, self.config.inner_lr)?;
        }

        Ok(task_params)
    }

    /// Compute gradients for a task across all parameters
    fn compute_task_gradients(
        &self,
        params: &ModelParameters,
        data: &TaskData,
    ) -> Result<ModelParameters> {
        let mut gradients = ModelParameters::new();

        for (param_name, param) in &params.parameters {
            let grad = self.compute_task_parameter_gradient(param, data, param_name)?;
            gradients.add_parameter(param_name.clone(), grad);
        }

        Ok(gradients)
    }

    /// Compute the gradient for a specific parameter using finite differences
    fn compute_task_parameter_gradient(
        &self,
        param: &Array2<f32>,
        data: &TaskData,
        param_name: &str,
    ) -> Result<Array2<f32>> {
        let eps = 1e-5f32;
        let mut gradients = Array2::zeros(param.raw_dim());

        for ((i, j), param_val) in param.indexed_iter() {
            // Perturb the element in both directions
            let mut param_plus = param.clone();
            let mut param_minus = param.clone();
            param_plus[[i, j]] = param_val + eps;
            param_minus[[i, j]] = param_val - eps;

            // Evaluate the loss at the perturbed points
            let loss_plus = self.compute_task_loss_for_param(&param_plus, data, param_name)?;
            let loss_minus = self.compute_task_loss_for_param(&param_minus, data, param_name)?;

            // Gradient via central difference
            gradients[[i, j]] = (loss_plus - loss_minus) / (2.0 * eps);
        }

        Ok(gradients)
    }

    /// Compute the task loss for a specific parameter
    fn compute_task_loss_for_param(
        &self,
        param: &Array2<f32>,
        data: &TaskData,
        param_name: &str,
    ) -> Result<f32> {
        // Perform a forward pass, interpreting the parameter by its name
        let predictions = if param_name.contains("weight") {
            // Weight matrix: matrix multiplication
            if param.nrows() == data.inputs.ncols() {
                // (input_dim, output_dim) layout
                data.inputs.dot(param)
            } else if param.ncols() == data.inputs.ncols() {
                // (output_dim, input_dim) layout - transpose first
                data.inputs.dot(&param.t())
            } else {
                // Fallback for mismatched dimensions
                data.inputs.clone()
            }
        } else if param_name.contains("bias") {
            // Bias vector: broadcast addition across the batch
            let mut result = data.inputs.clone();
            for mut row in result.rows_mut() {
                for (k, &bias) in param.row(0).iter().enumerate() {
                    if k < row.len() {
                        row[k] += bias;
                    }
                }
            }
            result
        } else {
            // Default: identity operation
            data.inputs.clone()
        };

        // Mean squared error loss
        let diff = &predictions - &data.targets;
        let mse = diff.mapv(|x| x * x).mean().unwrap_or(0.0);

        Ok(mse)
    }
}

impl MetaLearningAlgorithm for ReptileTrainer {
    fn meta_update(&mut self, task_batch: &TaskBatch) -> Result<MetaUpdateResult> {
        let mut total_update = ModelParameters::new();
        let mut total_loss = 0.0;

        // Initialize the accumulated update to zero
        {
            let meta_params = self.meta_parameters.read().expect("lock should not be poisoned");
            for (name, param) in &meta_params.parameters {
                total_update.add_parameter(name.clone(), Array2::zeros(param.raw_dim()));
            }
        }

        // Accumulate updates from all tasks
        for (support_set, query_set) in task_batch.support_sets.iter().zip(&task_batch.query_sets) {
            // Train on the support set
            let task_params = self.sgd_on_task(support_set)?;

            // Update direction: task-adapted parameters minus current meta-parameters
            let meta_params = self.meta_parameters.read().expect("lock should not be poisoned");
            let update = task_params.subtract(&meta_params)?;

            // Accumulate the update
            for (name, param_update) in &update.parameters {
                if let Some(total_param_update) = total_update.get_parameter_mut(name) {
                    *total_param_update += param_update;
                }
            }

            // Evaluate on the query set
            let loss = self.evaluate(&task_params, query_set)?;
            total_loss += loss;
        }

        // Average the updates and the loss over tasks
        let num_tasks = task_batch.num_tasks() as f32;
        for param_update in total_update.parameters.values_mut() {
            *param_update /= num_tasks;
        }
        total_loss /= num_tasks;

        // Apply the meta-update: move toward the average task optimum
        {
            let mut meta_params =
                self.meta_parameters.write().expect("lock should not be poisoned");
            let scaled_update = total_update.scale(self.config.meta_lr);
            *meta_params = meta_params.add(&scaled_update)?;
        }

        self.meta_step += 1;

        Ok(MetaUpdateResult {
            meta_loss: total_loss,
            task_losses: vec![total_loss; task_batch.num_tasks()], // Simplified: mean loss repeated per task
            grad_norm: 0.0,                                        // Not computed for Reptile
            updated_parameters: self
                .meta_parameters
                .read()
                .expect("lock should not be poisoned")
                .clone_parameters(),
        })
    }

    fn adapt(&self, support_set: &TaskData, adaptation_steps: usize) -> Result<ModelParameters> {
        let meta_params = self.meta_parameters.read().expect("lock should not be poisoned");
        let mut adapted_params = meta_params.clone_parameters();

        for _ in 0..adaptation_steps {
            let gradients = self.compute_task_gradients(&adapted_params, support_set)?;
            adapted_params.update_with_gradients(&gradients, self.config.inner_lr)?;
        }

        Ok(adapted_params)
    }

    fn evaluate(&self, params: &ModelParameters, query_set: &TaskData) -> Result<f32> {
        // Perform a forward pass through the model. Note: HashMap iteration
        // order is arbitrary, so this treats the layers as order-independent;
        // a real model would keep an explicit layer ordering.
        let mut predictions = query_set.inputs.clone();

        for (param_name, param) in &params.parameters {
            if param_name.contains("weight") {
                // Apply the weight matrix, transposing if the layout requires it
                if param.nrows() == predictions.ncols() {
                    predictions = predictions.dot(param);
                } else if param.ncols() == predictions.ncols() {
                    predictions = predictions.dot(&param.t());
                }

                // Apply ReLU activation after hidden weight layers
                if !param_name.contains("output") {
                    predictions.mapv_inplace(|x| x.max(0.0));
                }
            } else if param_name.contains("bias") {
                // Apply the bias
                for mut row in predictions.rows_mut() {
                    for (k, &bias) in param.row(0).iter().enumerate() {
                        if k < row.len() {
                            row[k] += bias;
                        }
                    }
                }
            }
        }

        // Mean squared error loss
        let diff = &predictions - &query_set.targets;
        let mse = diff.mapv(|x| x * x).mean().unwrap_or(0.0);

        Ok(mse)
    }
}
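
// Illustrative usage sketch (the parameter name and shapes are assumptions for
// this example; `task_batch` is a placeholder the caller would construct).
// Unlike MAML, Reptile needs no second-order terms: its meta-update simply
// interpolates the meta-parameters toward the task-adapted weights.
//
// let mut init = ModelParameters::new();
// init.add_parameter("layer1_weight".to_string(), Array2::zeros((4, 4)));
// let mut reptile = ReptileTrainer::new(ReptileConfig::default(), init);
// let result = reptile.meta_update(&task_batch)?;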

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_parameters() {
        let mut params = ModelParameters::new();
        let tensor = Array2::ones((2, 3));
        params.add_parameter("layer1".to_string(), tensor.clone());

        assert_eq!(
            params.get_parameter("layer1").expect("parameter should exist"),
            &tensor
        );
        assert_eq!(
            params.shapes.get("layer1").expect("shape should exist"),
            &vec![2, 3]
        );
    }

    #[test]
    fn test_task_data() {
        let inputs = Array2::ones((10, 5));
        let targets = Array2::zeros((10, 2));
        let task_data = TaskData::new(inputs, targets, "test_task".to_string());

        assert_eq!(task_data.batch_size(), 10);
        assert_eq!(task_data.task_id, "test_task");
    }

    #[test]
    fn test_task_batch() {
        let support = vec![
            TaskData::new(
                Array2::ones((5, 3)),
                Array2::zeros((5, 1)),
                "task1".to_string(),
            ),
            TaskData::new(
                Array2::ones((5, 3)),
                Array2::zeros((5, 1)),
                "task2".to_string(),
            ),
        ];
        let query = vec![
            TaskData::new(
                Array2::ones((3, 3)),
                Array2::zeros((3, 1)),
                "task1".to_string(),
            ),
            TaskData::new(
                Array2::ones((3, 3)),
                Array2::zeros((3, 1)),
                "task2".to_string(),
            ),
        ];

        let batch = TaskBatch::new(support, query).expect("batch construction should succeed");
        assert_eq!(batch.num_tasks(), 2);
    }

    #[test]
    fn test_maml_trainer_creation() {
        let config = MAMLConfig::default();
        let mut params = ModelParameters::new();
        params.add_parameter("test".to_string(), Array2::<f32>::ones((2, 2)));

        let trainer = MAMLTrainer::new(config, params);
        assert_eq!(trainer.meta_step, 0);
    }

    #[test]
    fn test_reptile_trainer_creation() {
        let config = ReptileConfig::default();
        let mut params = ModelParameters::new();
        params.add_parameter("test".to_string(), Array2::<f32>::ones((2, 2)));

        let trainer = ReptileTrainer::new(config, params);
        assert_eq!(trainer.meta_step, 0);
    }

    #[test]
    fn test_parameter_operations() {
        let mut params1 = ModelParameters::new();
        let mut params2 = ModelParameters::new();

        params1.add_parameter("layer1".to_string(), Array2::<f32>::ones((2, 2)));
        params2.add_parameter("layer1".to_string(), Array2::<f32>::ones((2, 2)) * 2.0);

        let diff = params2.subtract(&params1).expect("subtract should succeed");
        let sum = params1.add(&params2).expect("add should succeed");
        let scaled = params1.scale(2.0);

        assert_eq!(
            diff.get_parameter("layer1").expect("parameter should exist"),
            &Array2::<f32>::ones((2, 2))
        );
        assert_eq!(
            sum.get_parameter("layer1").expect("parameter should exist"),
            &(Array2::<f32>::ones((2, 2)) * 3.0)
        );
        assert_eq!(
            scaled.get_parameter("layer1").expect("parameter should exist"),
            &(Array2::<f32>::ones((2, 2)) * 2.0)
        );
    }
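
    // Additional illustrative tests, not from the original source; the shapes
    // are arbitrary assumptions.

    #[test]
    fn test_split_batches() {
        let inputs = Array2::<f32>::ones((10, 3));
        let targets = Array2::<f32>::zeros((10, 1));
        let task = TaskData::new(inputs, targets, "t".to_string());

        let batches = task.split_batches(4);
        assert_eq!(batches.len(), 3); // sizes 4, 4, 2
        assert_eq!(batches[2].batch_size(), 2);
        assert_eq!(batches[0].task_id, "t_batch_0");
    }

    #[test]
    fn test_task_batch_length_mismatch() {
        let support = vec![TaskData::new(
            Array2::<f32>::ones((2, 2)),
            Array2::<f32>::zeros((2, 1)),
            "a".to_string(),
        )];
        // Mismatched support/query lengths must be rejected
        assert!(TaskBatch::new(support, Vec::new()).is_err());
    }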
}