sklears_compose/
differentiable.rs

//! Differentiable Pipeline Components with Automatic Differentiation
//!
//! This module provides differentiable pipeline components supporting gradient-based
//! optimization, automatic differentiation, end-to-end learning, and neural pipeline
//! controllers for adaptive and learnable data processing workflows.
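//!
//! # Example
//!
//! A minimal, illustrative sketch of the intended flow (marked `ignore` because it
//! registers no real stages and assumes the items defined below are in scope):
//!
//! ```ignore
//! let config = OptimizationConfig {
//!     optimizer: OptimizerType::SGD,
//!     learning_rate: 0.01,
//!     momentum: 0.9,
//!     weight_decay: 0.0,
//!     gradient_clipping: None,
//!     batch_size: 32,
//!     max_epochs: 10,
//!     early_stopping: None,
//! };
//! let mut pipeline = DifferentiablePipeline::new(config);
//! // Stages are registered with `add_stage`, then trained on (input, target) pairs:
//! // pipeline.add_stage(stage);
//! // pipeline.train(&training_pairs)?;
//! ```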

use scirs2_core::ndarray::{Array2, Axis};
use sklears_core::{
    error::Result as SklResult, prelude::SklearsError, traits::Estimator, types::Float,
};
use std::collections::HashMap;

/// Differentiable computation graph node
#[derive(Debug, Clone)]
pub struct ComputationNode {
    /// Node identifier
    pub id: String,
    /// Node operation
    pub operation: DifferentiableOperation,
    /// Input nodes
    pub inputs: Vec<String>,
    /// Output shape
    pub output_shape: Vec<usize>,
    /// Gradient storage
    pub gradient: Option<Array2<Float>>,
    /// Forward pass value
    pub value: Option<Array2<Float>>,
    /// Whether gradients are needed
    pub requires_grad: bool,
}

/// Differentiable operations
#[derive(Debug, Clone)]
pub enum DifferentiableOperation {
    /// Matrix multiplication
    MatMul,
    /// Element-wise addition
    Add,
    /// Element-wise multiplication
    Mul,
    /// Element-wise subtraction
    Sub,
    /// Element-wise division
    Div,
    /// Activation functions
    Activation { function: ActivationFunction },
    /// Loss functions
    Loss { function: LossFunction },
    /// Normalization
    Normalization { method: NormalizationMethod },
    /// Convolution
    Convolution {
        kernel_size: usize,
        stride: usize,
        padding: usize,
    },
    /// Pooling
    Pooling {
        pool_type: PoolingType,
        kernel_size: usize,
        stride: usize,
    },
    /// Dropout
    Dropout { rate: f64 },
    /// Reshape
    Reshape { shape: Vec<usize> },
    /// Concatenation
    Concatenate { axis: usize },
    /// Slice
    Slice {
        start: usize,
        end: usize,
        axis: usize,
    },
    /// Custom operation
    Custom {
        name: String,
        forward: fn(&[Array2<Float>]) -> Array2<Float>,
    },
}

/// Activation functions
#[derive(Debug, Clone)]
pub enum ActivationFunction {
    /// ReLU
    ReLU,
    /// Sigmoid
    Sigmoid,
    /// Tanh
    Tanh,
    /// Softmax
    Softmax,
    /// LeakyReLU
    LeakyReLU { alpha: f64 },
    /// ELU
    ELU { alpha: f64 },
    /// Swish
    Swish,
    /// GELU
    GELU,
    /// Mish
    Mish,
}

/// Loss functions
#[derive(Debug, Clone)]
pub enum LossFunction {
    /// MeanSquaredError
    MeanSquaredError,
    /// CrossEntropy
    CrossEntropy,
    /// BinaryCrossEntropy
    BinaryCrossEntropy,
    /// Huber
    Huber { delta: f64 },
    /// Hinge
    Hinge,
    /// KLDivergence
    KLDivergence,
    /// L1Loss
    L1Loss,
    /// SmoothL1Loss
    SmoothL1Loss,
}

/// Normalization methods
#[derive(Debug, Clone)]
pub enum NormalizationMethod {
    /// BatchNorm
    BatchNorm { momentum: f64, epsilon: f64 },
    /// LayerNorm
    LayerNorm { epsilon: f64 },
    /// GroupNorm
    GroupNorm { num_groups: usize, epsilon: f64 },
    /// InstanceNorm
    InstanceNorm { epsilon: f64 },
    /// StandardScaler
    StandardScaler,
    /// MinMaxScaler
    MinMaxScaler,
}

/// Pooling types
#[derive(Debug, Clone)]
pub enum PoolingType {
    /// Max
    Max,
    /// Average
    Average,
    /// Global
    Global,
    /// Adaptive
    Adaptive,
}

/// Differentiable computation graph
pub struct ComputationGraph {
    /// Graph nodes
    nodes: HashMap<String, ComputationNode>,
    /// Execution order (topological sort)
    execution_order: Vec<String>,
    /// Input nodes
    input_nodes: Vec<String>,
    /// Output nodes
    output_nodes: Vec<String>,
    /// Parameters (learnable)
    parameters: HashMap<String, Array2<Float>>,
    /// Parameter gradients
    parameter_gradients: HashMap<String, Array2<Float>>,
    /// Graph metadata
    metadata: GraphMetadata,
}

/// Graph metadata
#[derive(Debug, Clone)]
pub struct GraphMetadata {
    /// Graph name
    pub name: String,
    /// Creation timestamp
    pub created_at: std::time::SystemTime,
    /// Version
    pub version: String,
    /// Total parameters
    pub total_parameters: usize,
    /// Trainable parameters
    pub trainable_parameters: usize,
}

/// Gradient computation context
pub struct GradientContext {
    /// Computation graph
    graph: ComputationGraph,
    /// Forward pass values
    forward_values: HashMap<String, Array2<Float>>,
    /// Backward pass gradients
    backward_gradients: HashMap<String, Array2<Float>>,
    /// Gradient tape
    gradient_tape: Vec<GradientRecord>,
}

/// Gradient record for automatic differentiation
#[derive(Debug, Clone)]
pub struct GradientRecord {
    /// Node that computed the gradient
    pub node_id: String,
    /// Operation performed
    pub operation: DifferentiableOperation,
    /// Input gradients
    pub input_gradients: Vec<Array2<Float>>,
    /// Output gradient
    pub output_gradient: Array2<Float>,
}

/// Differentiable pipeline component
pub struct DifferentiablePipeline {
    /// Pipeline stages
    stages: Vec<DifferentiableStage>,
    /// Optimization configuration
    optimization_config: OptimizationConfig,
    /// Learning rate schedule
    lr_schedule: LearningRateSchedule,
    /// Gradient accumulation
    gradient_accumulation: GradientAccumulation,
    /// Training state
    training_state: TrainingState,
    /// Metrics tracking
    metrics: TrainingMetrics,
}

/// Differentiable pipeline stage
pub struct DifferentiableStage {
    /// Stage name
    pub name: String,
    /// Computation graph
    pub graph: ComputationGraph,
    /// Parameters
    pub parameters: HashMap<String, Parameter>,
    /// Optimizer state
    pub optimizer_state: OptimizerState,
    /// Stage configuration
    pub config: StageConfig,
}

/// Parameter with gradient information
#[derive(Debug, Clone)]
pub struct Parameter {
    /// Parameter values
    pub values: Array2<Float>,
    /// Parameter gradients
    pub gradients: Array2<Float>,
    /// Parameter momentum (for optimizers)
    pub momentum: Option<Array2<Float>>,
    /// Parameter velocity (for optimizers)
    pub velocity: Option<Array2<Float>>,
    /// Parameter configuration
    pub config: ParameterConfig,
}

/// Parameter configuration
#[derive(Debug, Clone)]
pub struct ParameterConfig {
    /// Parameter name
    pub name: String,
    /// Learning rate multiplier
    pub lr_multiplier: f64,
    /// Weight decay
    pub weight_decay: f64,
    /// Requires gradient
    pub requires_grad: bool,
    /// Initialization method
    pub initialization: InitializationMethod,
}

/// Parameter initialization methods
#[derive(Debug, Clone)]
pub enum InitializationMethod {
    /// Zero
    Zero,
    /// Uniform
    Uniform { min: f64, max: f64 },
    /// Normal
    Normal { mean: f64, std: f64 },
    /// Xavier
    Xavier,
    /// He
    He,
    /// Orthogonal
    Orthogonal,
    /// Custom
    Custom { method: String },
}
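
// Illustrative sketch only: no initializer is implemented in this module. The
// Xavier/Glorot uniform variant draws weights from U(-b, b) with
// b = sqrt(6 / (fan_in + fan_out)); He initialization instead uses
// N(0, sqrt(2 / fan_in)) for ReLU layers. `xavier_uniform_bound` is an assumed
// helper name used only to demonstrate the formula.
#[cfg(test)]
mod initialization_sketch {
    #[test]
    fn xavier_uniform_bound_formula() {
        let xavier_uniform_bound =
            |fan_in: usize, fan_out: usize| (6.0 / (fan_in + fan_out) as f64).sqrt();
        // For a 3x3 weight matrix (fan_in = fan_out = 3), b = sqrt(6 / 6) = 1.
        assert!((xavier_uniform_bound(3, 3) - 1.0_f64).abs() < 1e-12);
    }
}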

/// Optimization configuration
#[derive(Debug, Clone)]
pub struct OptimizationConfig {
    /// Optimizer type
    pub optimizer: OptimizerType,
    /// Learning rate
    pub learning_rate: f64,
    /// Momentum
    pub momentum: f64,
    /// Weight decay
    pub weight_decay: f64,
    /// Gradient clipping
    pub gradient_clipping: Option<GradientClipping>,
    /// Batch size
    pub batch_size: usize,
    /// Maximum epochs
    pub max_epochs: usize,
    /// Early stopping
    pub early_stopping: Option<EarlyStopping>,
}

/// Optimizer types
#[derive(Debug, Clone)]
pub enum OptimizerType {
    /// SGD
    SGD,
    /// Adam
    Adam {
        beta1: f64,
        beta2: f64,
        epsilon: f64,
    },
    /// RMSprop
    RMSprop { alpha: f64, epsilon: f64 },
    /// AdaGrad
    AdaGrad { epsilon: f64 },
    /// AdaDelta
    AdaDelta { rho: f64, epsilon: f64 },
    /// LBfgs
    LBfgs { history_size: usize },
    /// Custom
    Custom { name: String },
}

/// Gradient clipping configuration
#[derive(Debug, Clone)]
pub struct GradientClipping {
    /// Clipping method
    pub method: ClippingMethod,
    /// Clipping threshold
    pub threshold: f64,
}

/// Gradient clipping methods
#[derive(Debug, Clone)]
pub enum ClippingMethod {
    /// Norm
    Norm,
    /// Value
    Value,
    /// GlobalNorm
    GlobalNorm,
}

/// Early stopping configuration
#[derive(Debug, Clone)]
pub struct EarlyStopping {
    /// Patience (epochs)
    pub patience: usize,
    /// Minimum delta for improvement
    pub min_delta: f64,
    /// Metric to monitor
    pub monitor: String,
    /// Restore best weights
    pub restore_best_weights: bool,
}

/// Learning rate schedule
#[derive(Debug, Clone)]
pub struct LearningRateSchedule {
    /// Schedule type
    pub schedule_type: ScheduleType,
    /// Initial learning rate
    pub initial_lr: f64,
    /// Current learning rate
    pub current_lr: f64,
    /// Schedule parameters
    pub parameters: HashMap<String, f64>,
}

/// Learning rate schedule types
#[derive(Debug, Clone)]
pub enum ScheduleType {
    /// Constant
    Constant,
    /// Linear
    Linear { final_lr: f64 },
    /// Exponential
    Exponential { decay_rate: f64 },
    /// StepLR
    StepLR { step_size: usize, gamma: f64 },
    /// CosineAnnealing
    CosineAnnealing { t_max: usize },
    /// ReduceOnPlateau
    ReduceOnPlateau { factor: f64, patience: usize },
    /// Polynomial
    Polynomial { power: f64, final_lr: f64 },
    /// Custom
    Custom { name: String },
}

/// Gradient accumulation configuration
#[derive(Debug, Clone)]
pub struct GradientAccumulation {
    /// Accumulation steps
    pub steps: usize,
    /// Current step
    pub current_step: usize,
    /// Accumulated gradients
    pub accumulated_gradients: HashMap<String, Array2<Float>>,
    /// Scaling factor
    pub scaling_factor: f64,
}

/// Training state
#[derive(Debug, Clone)]
pub struct TrainingState {
    /// Current epoch
    pub epoch: usize,
    /// Current step
    pub step: usize,
    /// Training mode
    pub training: bool,
    /// Best metric value
    pub best_metric: Option<f64>,
    /// Best weights
    pub best_weights: Option<HashMap<String, Array2<Float>>>,
    /// Training history
    pub history: TrainingHistory,
}

/// Training history
#[derive(Debug, Clone)]
pub struct TrainingHistory {
    /// Loss values
    pub losses: Vec<f64>,
    /// Metric values
    pub metrics: HashMap<String, Vec<f64>>,
    /// Learning rates
    pub learning_rates: Vec<f64>,
    /// Timestamps
    pub timestamps: Vec<std::time::SystemTime>,
}

/// Training metrics
#[derive(Debug, Clone)]
pub struct TrainingMetrics {
    /// Current loss
    pub current_loss: f64,
    /// Current metrics
    pub current_metrics: HashMap<String, f64>,
    /// Moving averages
    pub moving_averages: HashMap<String, f64>,
    /// Gradient norms
    pub gradient_norms: HashMap<String, f64>,
    /// Parameter norms
    pub parameter_norms: HashMap<String, f64>,
}

/// Optimizer state
#[derive(Debug, Clone)]
pub struct OptimizerState {
    /// Optimizer type
    pub optimizer_type: OptimizerType,
    /// State variables
    pub state: HashMap<String, Array2<Float>>,
    /// Step count
    pub step: usize,
    /// Configuration
    pub config: HashMap<String, f64>,
}

/// Stage configuration
#[derive(Debug, Clone)]
pub struct StageConfig {
    /// Stage name
    pub name: String,
    /// Input shapes
    pub input_shapes: Vec<Vec<usize>>,
    /// Output shapes
    pub output_shapes: Vec<Vec<usize>>,
    /// Regularization
    pub regularization: RegularizationConfig,
    /// Batch processing
    pub batch_processing: BatchProcessingConfig,
}

/// Regularization configuration
#[derive(Debug, Clone)]
pub struct RegularizationConfig {
    /// L1 regularization
    pub l1_lambda: f64,
    /// L2 regularization
    pub l2_lambda: f64,
    /// Dropout rate
    pub dropout_rate: f64,
    /// Batch normalization
    pub batch_norm: bool,
    /// Layer normalization
    pub layer_norm: bool,
}

/// Batch processing configuration
#[derive(Debug, Clone)]
pub struct BatchProcessingConfig {
    /// Batch size
    pub batch_size: usize,
    /// Shuffle data
    pub shuffle: bool,
    /// Drop last batch
    pub drop_last: bool,
    /// Number of workers
    pub num_workers: usize,
}

/// Neural pipeline controller
pub struct NeuralPipelineController {
    /// Controller network
    controller_network: DifferentiablePipeline,
    /// Controlled pipeline
    controlled_pipeline: Box<dyn PipelineComponent>,
    /// Control strategy
    control_strategy: ControlStrategy,
    /// Adaptation history
    adaptation_history: Vec<AdaptationRecord>,
    /// Performance metrics
    performance_metrics: ControllerMetrics,
}

/// Pipeline component trait
pub trait PipelineComponent: Send + Sync {
    /// Component name
    fn name(&self) -> &str;

    /// Process input data
    fn process(&mut self, input: &Array2<Float>) -> SklResult<Array2<Float>>;

    /// Get configurable parameters
    fn get_parameters(&self) -> HashMap<String, f64>;

    /// Set configurable parameters
    fn set_parameters(&mut self, params: HashMap<String, f64>) -> SklResult<()>;

    /// Get performance metrics
    fn get_metrics(&self) -> HashMap<String, f64>;
}
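
// Illustrative sketch of a `PipelineComponent` implementation: a single "gain"
// parameter scales the input elementwise. The struct name, the "gain" key, and the
// constant "performance" metric are assumptions made for this example only.
#[cfg(test)]
mod pipeline_component_sketch {
    use super::*;

    struct GainComponent {
        gain: f64,
    }

    impl PipelineComponent for GainComponent {
        fn name(&self) -> &str {
            "gain_component"
        }

        fn process(&mut self, input: &Array2<Float>) -> SklResult<Array2<Float>> {
            Ok(input.mapv(|x| x * self.gain))
        }

        fn get_parameters(&self) -> HashMap<String, f64> {
            HashMap::from([("gain".to_string(), self.gain)])
        }

        fn set_parameters(&mut self, params: HashMap<String, f64>) -> SklResult<()> {
            if let Some(&gain) = params.get("gain") {
                self.gain = gain;
            }
            Ok(())
        }

        fn get_metrics(&self) -> HashMap<String, f64> {
            HashMap::from([("performance".to_string(), 1.0)])
        }
    }

    #[test]
    fn gain_component_scales_input_elementwise() {
        let mut component = GainComponent { gain: 2.0 };
        let input = Array2::from_elem((1, 2), 3.0);
        let output = component.process(&input).unwrap();
        assert_eq!(output[[0, 0]], 6.0);
        assert_eq!(component.get_parameters()["gain"], 2.0);
    }
}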

/// Control strategy
#[derive(Debug, Clone)]
pub enum ControlStrategy {
    /// Reinforcement
    Reinforcement { reward_function: RewardFunction },
    /// Supervised
    Supervised { target_performance: f64 },
    /// MetaLearning
    MetaLearning { adaptation_steps: usize },
    /// Evolutionary
    Evolutionary {
        population_size: usize,
        mutation_rate: f64,
    },
    /// Bayesian
    Bayesian {
        prior_distribution: PriorDistribution,
    },
}

/// Reward function for reinforcement learning
#[derive(Debug, Clone)]
pub enum RewardFunction {
    /// Performance
    Performance { metric: String },
    /// Efficiency
    Efficiency {
        latency_weight: f64,
        accuracy_weight: f64,
    },
    /// ResourceUsage
    ResourceUsage { cpu_weight: f64, memory_weight: f64 },
    /// Custom
    Custom { function: String },
}

/// Prior distribution for Bayesian optimization
#[derive(Debug, Clone)]
pub enum PriorDistribution {
    /// Normal
    Normal { mean: f64, std: f64 },
    /// Uniform
    Uniform { min: f64, max: f64 },
    /// Beta
    Beta { alpha: f64, beta: f64 },
    /// Custom
    Custom { distribution: String },
}

/// Adaptation record
#[derive(Debug, Clone)]
pub struct AdaptationRecord {
    /// Timestamp
    pub timestamp: std::time::SystemTime,
    /// Previous parameters
    pub previous_params: HashMap<String, f64>,
    /// New parameters
    pub new_params: HashMap<String, f64>,
    /// Performance before adaptation
    pub performance_before: f64,
    /// Performance after adaptation
    pub performance_after: f64,
    /// Adaptation trigger
    pub trigger: AdaptationTrigger,
}

/// Adaptation triggers
#[derive(Debug, Clone)]
pub enum AdaptationTrigger {
    /// PerformanceDrop
    PerformanceDrop { threshold: f64 },
    /// DataDrift
    DataDrift { magnitude: f64 },
    /// ResourceConstraint
    ResourceConstraint { constraint: String },
    /// ScheduledUpdate
    ScheduledUpdate,
    /// Manual
    Manual,
}

/// Controller performance metrics
#[derive(Debug, Clone)]
pub struct ControllerMetrics {
    /// Total adaptations
    pub total_adaptations: usize,
    /// Successful adaptations
    pub successful_adaptations: usize,
    /// Average improvement
    pub average_improvement: f64,
    /// Adaptation latency
    pub adaptation_latency: std::time::Duration,
    /// Control overhead
    pub control_overhead: f64,
}

/// Automatic differentiation engine
pub struct AutoDiffEngine {
    /// Computation graph
    graph: ComputationGraph,
    /// Forward mode AD
    forward_mode: ForwardModeAD,
    /// Reverse mode AD
    reverse_mode: ReverseModeAD,
    /// Mixed mode AD
    mixed_mode: MixedModeAD,
    /// Engine configuration
    config: AutoDiffConfig,
}

/// Forward mode automatic differentiation
pub struct ForwardModeAD {
    /// Dual numbers
    dual_numbers: HashMap<String, DualNumber>,
    /// Computation order
    computation_order: Vec<String>,
}

/// Reverse mode automatic differentiation
pub struct ReverseModeAD {
    /// Adjoint variables
    adjoint_variables: HashMap<String, Array2<Float>>,
    /// Computation tape
    computation_tape: Vec<TapeEntry>,
}

/// Mixed mode automatic differentiation
pub struct MixedModeAD {
    /// Forward pass nodes
    forward_nodes: Vec<String>,
    /// Reverse pass nodes
    reverse_nodes: Vec<String>,
    /// Checkpointing strategy
    checkpointing: CheckpointingStrategy,
}

/// Dual number for forward mode AD
#[derive(Debug, Clone)]
pub struct DualNumber {
    /// Real part
    pub real: Array2<Float>,
    /// Dual part (gradient)
    pub dual: Array2<Float>,
}
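
// Illustrative sketch (not part of the original API): forward-mode AD propagates
// derivatives through arithmetic on dual numbers, e.g. for elementwise
// multiplication (a, a') * (b, b') = (a * b, a' * b + a * b'). The helper
// `dual_mul` below exists only to demonstrate that product rule.
#[cfg(test)]
mod dual_number_sketch {
    use super::*;

    fn dual_mul(lhs: &DualNumber, rhs: &DualNumber) -> DualNumber {
        DualNumber {
            real: &lhs.real * &rhs.real,
            // Product rule applied elementwise.
            dual: &lhs.dual * &rhs.real + &lhs.real * &rhs.dual,
        }
    }

    #[test]
    fn product_rule_propagates_derivatives() {
        let x = DualNumber {
            real: Array2::from_elem((1, 1), 3.0),
            dual: Array2::from_elem((1, 1), 1.0), // seed: dx/dx = 1
        };
        let y = dual_mul(&x, &x); // y = x^2, so dy/dx = 2x = 6 at x = 3
        assert_eq!(y.real[[0, 0]], 9.0);
        assert_eq!(y.dual[[0, 0]], 6.0);
    }
}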

/// Tape entry for reverse mode AD
#[derive(Debug, Clone)]
pub struct TapeEntry {
    /// Node ID
    pub node_id: String,
    /// Operation
    pub operation: DifferentiableOperation,
    /// Input values
    pub input_values: Vec<Array2<Float>>,
    /// Output value
    pub output_value: Array2<Float>,
    /// Gradient function
    pub gradient_function: fn(&Array2<Float>, &[Array2<Float>]) -> Vec<Array2<Float>>,
}

/// Checkpointing strategy
#[derive(Debug, Clone)]
pub enum CheckpointingStrategy {
    None,
    /// Uniform
    Uniform {
        interval: usize,
    },
    /// Adaptive
    Adaptive {
        memory_threshold: usize,
    },
    /// Custom
    Custom {
        strategy: String,
    },
}

/// `AutoDiff` configuration
#[derive(Debug, Clone)]
pub struct AutoDiffConfig {
    /// Differentiation mode
    pub mode: DifferentiationMode,
    /// Numerical precision
    pub precision: f64,
    /// Memory optimization
    pub memory_optimization: MemoryOptimization,
    /// Parallel computation
    pub parallel: bool,
    /// Checkpointing
    pub checkpointing: CheckpointingStrategy,
}

/// Differentiation modes
#[derive(Debug, Clone)]
pub enum DifferentiationMode {
    /// Forward
    Forward,
    /// Reverse
    Reverse,
    /// Mixed
    Mixed,
    /// Automatic
    Automatic,
}

/// Memory optimization strategies
#[derive(Debug, Clone)]
pub enum MemoryOptimization {
    None,
    /// Gradient
    Gradient {
        release_intermediate: bool,
    },
    /// Checkpointing
    Checkpointing {
        max_checkpoints: usize,
    },
    /// Streaming
    Streaming {
        chunk_size: usize,
    },
}

impl ComputationGraph {
    /// Create a new computation graph
    #[must_use]
    pub fn new(name: String) -> Self {
        Self {
            nodes: HashMap::new(),
            execution_order: Vec::new(),
            input_nodes: Vec::new(),
            output_nodes: Vec::new(),
            parameters: HashMap::new(),
            parameter_gradients: HashMap::new(),
            metadata: GraphMetadata {
                name,
                created_at: std::time::SystemTime::now(),
                version: "1.0.0".to_string(),
                total_parameters: 0,
                trainable_parameters: 0,
            },
        }
    }

    /// Add a node to the graph
    pub fn add_node(&mut self, node: ComputationNode) -> SklResult<()> {
        let node_id = node.id.clone();
        self.nodes.insert(node_id, node);
        self.update_execution_order()?;
        Ok(())
    }

    /// Add an input node
    pub fn add_input(&mut self, node_id: String, shape: Vec<usize>) {
        self.input_nodes.push(node_id.clone());
        let node = ComputationNode {
            id: node_id,
            operation: DifferentiableOperation::Custom {
                name: "input".to_string(),
                forward: |inputs| inputs[0].clone(),
            },
            inputs: Vec::new(),
            output_shape: shape,
            gradient: None,
            value: None,
            requires_grad: false,
        };
        self.nodes.insert(node.id.clone(), node);
    }

    /// Add an output node
    pub fn add_output(&mut self, node_id: String) {
        self.output_nodes.push(node_id);
    }

    /// Forward pass through the graph
    pub fn forward(
        &mut self,
        inputs: &HashMap<String, Array2<Float>>,
    ) -> SklResult<HashMap<String, Array2<Float>>> {
        // Set input values
        for (input_id, input_value) in inputs {
            if let Some(node) = self.nodes.get_mut(input_id) {
                node.value = Some(input_value.clone());
            }
        }

        // Execute nodes in topological order
        let execution_order = self.execution_order.clone();
        for node_id in execution_order {
            self.execute_node(&node_id)?;
        }

        // Collect outputs
        let mut outputs = HashMap::new();
        for output_id in &self.output_nodes {
            if let Some(node) = self.nodes.get(output_id) {
                if let Some(value) = &node.value {
                    outputs.insert(output_id.clone(), value.clone());
                }
            }
        }

        Ok(outputs)
    }

    /// Backward pass through the graph
    pub fn backward(&mut self, output_gradients: &HashMap<String, Array2<Float>>) -> SklResult<()> {
        // Initialize output gradients
        for (output_id, grad) in output_gradients {
            if let Some(node) = self.nodes.get_mut(output_id) {
                node.gradient = Some(grad.clone());
            }
        }

        // Backpropagate gradients
        let execution_order = self.execution_order.clone();
        for node_id in execution_order.iter().rev() {
            self.backpropagate_node(node_id)?;
        }

        Ok(())
    }

    /// Execute a single node
    fn execute_node(&mut self, node_id: &str) -> SklResult<()> {
        let node = self
            .nodes
            .get(node_id)
            .ok_or_else(|| SklearsError::InvalidInput(format!("Unknown node: {node_id}")))?
            .clone();

        // Input nodes receive their values directly in `forward`; skip re-execution.
        if node.inputs.is_empty() && node.value.is_some() {
            return Ok(());
        }

        // Collect input values
        let mut input_values = Vec::new();
        for input_id in &node.inputs {
            if let Some(input_node) = self.nodes.get(input_id) {
                if let Some(value) = &input_node.value {
                    input_values.push(value.clone());
                }
            }
        }

        // Execute operation
        let output = self.execute_operation(&node.operation, &input_values)?;

        // Store output
        if let Some(node) = self.nodes.get_mut(node_id) {
            node.value = Some(output);
        }

        Ok(())
    }

    /// Execute a differentiable operation
    fn execute_operation(
        &self,
        operation: &DifferentiableOperation,
        inputs: &[Array2<Float>],
    ) -> SklResult<Array2<Float>> {
        match operation {
            DifferentiableOperation::MatMul => {
                if inputs.len() != 2 {
                    return Err(SklearsError::InvalidInput(
                        "MatMul requires 2 inputs".to_string(),
                    ));
                }
                Ok(inputs[0].dot(&inputs[1]))
            }
            DifferentiableOperation::Add => {
                if inputs.len() != 2 {
                    return Err(SklearsError::InvalidInput(
                        "Add requires 2 inputs".to_string(),
                    ));
                }
                Ok(&inputs[0] + &inputs[1])
            }
            DifferentiableOperation::Mul => {
                if inputs.len() != 2 {
                    return Err(SklearsError::InvalidInput(
                        "Mul requires 2 inputs".to_string(),
                    ));
                }
                Ok(&inputs[0] * &inputs[1])
            }
            DifferentiableOperation::Activation { function } => {
                if inputs.len() != 1 {
                    return Err(SklearsError::InvalidInput(
                        "Activation requires 1 input".to_string(),
                    ));
                }
                self.apply_activation(function, &inputs[0])
            }
            DifferentiableOperation::Custom { forward, .. } => Ok(forward(inputs)),
            _ => Err(SklearsError::NotImplemented(
                "Operation not implemented".to_string(),
            )),
        }
    }

    /// Apply activation function
    fn apply_activation(
        &self,
        function: &ActivationFunction,
        input: &Array2<Float>,
    ) -> SklResult<Array2<Float>> {
        match function {
            ActivationFunction::ReLU => Ok(input.mapv(|x| x.max(0.0))),
            ActivationFunction::Sigmoid => Ok(input.mapv(|x| 1.0 / (1.0 + (-x).exp()))),
            ActivationFunction::Tanh => Ok(input.mapv(f64::tanh)),
            ActivationFunction::Softmax => {
                let mut result = input.clone();
                for mut row in result.axis_iter_mut(Axis(0)) {
                    let max_val = row.fold(Float::NEG_INFINITY, |acc, &x| acc.max(x));
                    row.mapv_inplace(|x| (x - max_val).exp());
                    let sum = row.sum();
                    row.mapv_inplace(|x| x / sum);
                }
                Ok(result)
            }
            _ => Err(SklearsError::NotImplemented(
                "Activation function not implemented".to_string(),
            )),
        }
    }

    /// Backpropagate gradients through a node
    fn backpropagate_node(&mut self, _node_id: &str) -> SklResult<()> {
        // Placeholder: a full implementation would apply the chain rule for the
        // node's specific operation and accumulate gradients into its inputs.
        Ok(())
    }

    /// Update execution order (topological sort)
    fn update_execution_order(&mut self) -> SklResult<()> {
        // Placeholder: collects node ids in arbitrary map order instead of a true
        // topological sort, which is sufficient only when every non-input node
        // depends solely on input nodes.
        self.execution_order = self.nodes.keys().cloned().collect();
        Ok(())
    }
}
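
// Illustrative sketch: wiring a one-operation graph by hand. The node ids "x" and
// "relu" are arbitrary names chosen for this example; `add_input` seeds "x" from
// the map passed to `forward`, and the "relu" node applies the activation to it.
#[cfg(test)]
mod computation_graph_sketch {
    use super::*;

    #[test]
    fn forward_pass_through_relu_node() {
        let mut graph = ComputationGraph::new("sketch".to_string());
        graph.add_input("x".to_string(), vec![1, 3]);
        graph
            .add_node(ComputationNode {
                id: "relu".to_string(),
                operation: DifferentiableOperation::Activation {
                    function: ActivationFunction::ReLU,
                },
                inputs: vec!["x".to_string()],
                output_shape: vec![1, 3],
                gradient: None,
                value: None,
                requires_grad: false,
            })
            .unwrap();
        graph.add_output("relu".to_string());

        let x = Array2::from_shape_vec((1, 3), vec![-1.0, 0.0, 2.0]).unwrap();
        let outputs = graph.forward(&HashMap::from([("x".to_string(), x)])).unwrap();
        let expected = Array2::from_shape_vec((1, 3), vec![0.0, 0.0, 2.0]).unwrap();
        assert_eq!(outputs["relu"], expected);
    }
}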

impl DifferentiablePipeline {
    /// Create a new differentiable pipeline
    #[must_use]
    pub fn new(optimization_config: OptimizationConfig) -> Self {
        Self {
            stages: Vec::new(),
            optimization_config,
            lr_schedule: LearningRateSchedule {
                schedule_type: ScheduleType::Constant,
                initial_lr: 0.001,
                current_lr: 0.001,
                parameters: HashMap::new(),
            },
            gradient_accumulation: GradientAccumulation {
                steps: 1,
                current_step: 0,
                accumulated_gradients: HashMap::new(),
                scaling_factor: 1.0,
            },
            training_state: TrainingState {
                epoch: 0,
                step: 0,
                training: false,
                best_metric: None,
                best_weights: None,
                history: TrainingHistory {
                    losses: Vec::new(),
                    metrics: HashMap::new(),
                    learning_rates: Vec::new(),
                    timestamps: Vec::new(),
                },
            },
            metrics: TrainingMetrics {
                current_loss: 0.0,
                current_metrics: HashMap::new(),
                moving_averages: HashMap::new(),
                gradient_norms: HashMap::new(),
                parameter_norms: HashMap::new(),
            },
        }
    }

    /// Add a differentiable stage
    pub fn add_stage(&mut self, stage: DifferentiableStage) {
        self.stages.push(stage);
    }

    /// Train the pipeline
    pub fn train(&mut self, train_data: &[(Array2<Float>, Array2<Float>)]) -> SklResult<()> {
        self.training_state.training = true;

        for epoch in 0..self.optimization_config.max_epochs {
            self.training_state.epoch = epoch;

            let mut epoch_loss = 0.0;
            let mut batch_count = 0;

            for (inputs, targets) in train_data {
                // Forward pass
                let predictions = self.forward(inputs)?;

                // Compute loss
                let loss = self.compute_loss(&predictions, targets)?;
                epoch_loss += loss;
                batch_count += 1;

                // Backward pass
                self.backward(&predictions, targets)?;

                // Update parameters
                self.update_parameters()?;

                self.training_state.step += 1;
            }

            // Update learning rate
            self.update_learning_rate()?;

            // Record metrics
            let avg_loss = epoch_loss / f64::from(batch_count);
            self.training_state.history.losses.push(avg_loss);
            self.training_state
                .history
                .learning_rates
                .push(self.lr_schedule.current_lr);
            self.training_state
                .history
                .timestamps
                .push(std::time::SystemTime::now());

            // Check early stopping
            if let Some(early_stopping) = &self.optimization_config.early_stopping.clone() {
                if self.should_early_stop(early_stopping, avg_loss) {
                    break;
                }
            }
        }

        self.training_state.training = false;
        Ok(())
    }

    /// Forward pass through all stages
    pub fn forward(&mut self, input: &Array2<Float>) -> SklResult<Array2<Float>> {
        let mut current_input = input.clone();

        for stage in &mut self.stages {
            let inputs = HashMap::from([("input".to_string(), current_input.clone())]);
            let outputs = stage.graph.forward(&inputs)?;

            if let Some(output) = outputs.get("output") {
                current_input = output.clone();
            }
        }

        Ok(current_input)
    }

    /// Backward pass through all stages
    pub fn backward(
        &mut self,
        predictions: &Array2<Float>,
        targets: &Array2<Float>,
    ) -> SklResult<()> {
        // Compute output gradients
        let output_gradients = self.compute_output_gradients(predictions, targets)?;

        // Backpropagate through stages in reverse order
        for stage in self.stages.iter_mut().rev() {
            stage.graph.backward(&output_gradients)?;
        }

        Ok(())
    }

    /// Compute loss
    fn compute_loss(&self, predictions: &Array2<Float>, targets: &Array2<Float>) -> SklResult<f64> {
        // Mean squared error loss
        let diff = predictions - targets;
        let squared_diff = diff.mapv(|x| x * x);
        Ok(squared_diff.mean().unwrap_or(0.0))
    }

    /// Compute output gradients
    fn compute_output_gradients(
        &self,
        predictions: &Array2<Float>,
        targets: &Array2<Float>,
    ) -> SklResult<HashMap<String, Array2<Float>>> {
        let gradients = 2.0 * (predictions - targets) / (predictions.len() as f64);
        let mut gradient_map = HashMap::new();
        gradient_map.insert("output".to_string(), gradients);
        Ok(gradient_map)
    }

    /// Update parameters
    fn update_parameters(&mut self) -> SklResult<()> {
        match &self.optimization_config.optimizer {
            OptimizerType::SGD => {
                for stage in &mut self.stages {
                    for param in stage.parameters.values_mut() {
                        let lr = self.lr_schedule.current_lr;
                        param.values = &param.values - &(lr * &param.gradients);
                    }
                }
            }
            OptimizerType::Adam {
                beta1,
                beta2,
                epsilon,
            } => {
                for stage in &mut self.stages {
                    for param in stage.parameters.values_mut() {
                        // Adam optimizer update
                        let lr = self.lr_schedule.current_lr;
                        let step = self.training_state.step as f64 + 1.0;

                        // Update momentum (first moment estimate)
                        if let Some(momentum) = &mut param.momentum {
                            *momentum = *beta1 * &*momentum + (1.0 - beta1) * &param.gradients;
                        } else {
                            param.momentum = Some((1.0 - beta1) * &param.gradients);
                        }

                        // Update velocity (second moment estimate)
                        if let Some(velocity) = &mut param.velocity {
                            *velocity = *beta2 * &*velocity
                                + (1.0 - beta2) * &param.gradients.mapv(|x| x * x);
                        } else {
                            param.velocity = Some((1.0 - beta2) * &param.gradients.mapv(|x| x * x));
                        }

                        // Bias correction
                        let momentum_corrected =
                            param.momentum.as_ref().unwrap() / (1.0 - beta1.powf(step));
                        let velocity_corrected =
                            param.velocity.as_ref().unwrap() / (1.0 - beta2.powf(step));

                        // Update parameters
                        param.values = &param.values
                            - &(lr * &momentum_corrected
                                / &velocity_corrected.mapv(|x| x.sqrt() + epsilon));
                    }
                }
            }
            _ => {
                return Err(SklearsError::NotImplemented(
                    "Optimizer not implemented".to_string(),
                ))
            }
        }
        Ok(())
    }

    /// Update learning rate
    fn update_learning_rate(&mut self) -> SklResult<()> {
        match &self.lr_schedule.schedule_type {
            ScheduleType::Constant => {
                // No update needed
            }
            ScheduleType::Exponential { decay_rate } => {
                self.lr_schedule.current_lr =
                    self.lr_schedule.initial_lr * decay_rate.powf(self.training_state.epoch as f64);
            }
            ScheduleType::StepLR { step_size, gamma } => {
                if self.training_state.epoch % step_size == 0 && self.training_state.epoch > 0 {
                    self.lr_schedule.current_lr *= gamma;
                }
            }
            _ => {
                return Err(SklearsError::NotImplemented(
                    "Learning rate schedule not implemented".to_string(),
                ))
            }
        }
        Ok(())
    }

    /// Check if early stopping should be triggered
    fn should_early_stop(&mut self, early_stopping: &EarlyStopping, current_loss: f64) -> bool {
        if let Some(best_metric) = self.training_state.best_metric {
            if current_loss < best_metric - early_stopping.min_delta {
                self.training_state.best_metric = Some(current_loss);
                false
            } else {
                // Check patience
                true // Simplified - would need to track patience counter
            }
        } else {
            self.training_state.best_metric = Some(current_loss);
            false
        }
    }
}
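
// Illustrative sketches of two pieces of the training loop above: the default
// mean-squared-error loss and the exponential learning-rate schedule. Both tests
// reach into private fields and methods, which is possible only because this
// module is a child of the defining module.
#[cfg(test)]
mod pipeline_behavior_sketch {
    use super::*;

    fn sgd_config() -> OptimizationConfig {
        OptimizationConfig {
            optimizer: OptimizerType::SGD,
            learning_rate: 0.01,
            momentum: 0.0,
            weight_decay: 0.0,
            gradient_clipping: None,
            batch_size: 1,
            max_epochs: 1,
            early_stopping: None,
        }
    }

    #[test]
    fn mse_loss_matches_hand_computation() {
        // mean((prediction - target)^2) = (1 + 4) / 2 = 2.5
        let pipeline = DifferentiablePipeline::new(sgd_config());
        let predictions = Array2::from_shape_vec((1, 2), vec![1.0, 2.0]).unwrap();
        let targets = Array2::zeros((1, 2));
        let loss = pipeline.compute_loss(&predictions, &targets).unwrap();
        assert!((loss - 2.5).abs() < 1e-12);
    }

    #[test]
    fn exponential_schedule_decays_learning_rate() {
        // lr_t = initial_lr * decay_rate^epoch, so 0.1 * 0.5^2 = 0.025 after epoch 2.
        let mut pipeline = DifferentiablePipeline::new(sgd_config());
        pipeline.lr_schedule = LearningRateSchedule {
            schedule_type: ScheduleType::Exponential { decay_rate: 0.5 },
            initial_lr: 0.1,
            current_lr: 0.1,
            parameters: HashMap::new(),
        };
        pipeline.training_state.epoch = 2;
        pipeline.update_learning_rate().unwrap();
        assert!((pipeline.lr_schedule.current_lr - 0.025).abs() < 1e-12);
    }
}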

impl NeuralPipelineController {
    /// Create a new neural pipeline controller
    #[must_use]
    pub fn new(
        controller_network: DifferentiablePipeline,
        controlled_pipeline: Box<dyn PipelineComponent>,
        control_strategy: ControlStrategy,
    ) -> Self {
        Self {
            controller_network,
            controlled_pipeline,
            control_strategy,
            adaptation_history: Vec::new(),
            performance_metrics: ControllerMetrics {
                total_adaptations: 0,
                successful_adaptations: 0,
                average_improvement: 0.0,
                adaptation_latency: std::time::Duration::from_millis(0),
                control_overhead: 0.0,
            },
        }
    }

    /// Adapt the controlled pipeline
    pub fn adapt(&mut self, performance_data: &Array2<Float>) -> SklResult<()> {
        let start_time = std::time::Instant::now();

        // Get current performance
        let current_metrics = self.controlled_pipeline.get_metrics();
        let current_performance = current_metrics.get("performance").copied().unwrap_or(0.0);

        // Generate control signal
        let control_signal = self.controller_network.forward(performance_data)?;

        // Convert control signal to parameter updates
        let new_params = self.control_signal_to_parameters(&control_signal)?;

        // Apply parameter updates
        let previous_params = self.controlled_pipeline.get_parameters();
        self.controlled_pipeline
            .set_parameters(new_params.clone())?;

        // Measure new performance
        let new_metrics = self.controlled_pipeline.get_metrics();
        let new_performance = new_metrics.get("performance").copied().unwrap_or(0.0);

        // Record adaptation
        let adaptation_record = AdaptationRecord {
            timestamp: std::time::SystemTime::now(),
            previous_params,
            new_params,
            performance_before: current_performance,
            performance_after: new_performance,
            trigger: AdaptationTrigger::ScheduledUpdate,
        };

        self.adaptation_history.push(adaptation_record);

        // Update metrics
        self.performance_metrics.total_adaptations += 1;
        if new_performance > current_performance {
            self.performance_metrics.successful_adaptations += 1;
        }

        let adaptation_latency = start_time.elapsed();
        self.performance_metrics.adaptation_latency = adaptation_latency;

        Ok(())
    }

    /// Convert control signal to parameter updates
    fn control_signal_to_parameters(
        &self,
        control_signal: &Array2<Float>,
    ) -> SklResult<HashMap<String, f64>> {
        let mut params = HashMap::new();

        // This is a simplified conversion - in practice, this would be more sophisticated
        for (i, &value) in control_signal.iter().enumerate() {
            params.insert(format!("param_{i}"), value);
        }

        Ok(params)
    }

    /// Get adaptation history
    #[must_use]
    pub fn get_adaptation_history(&self) -> &[AdaptationRecord] {
        &self.adaptation_history
    }

    /// Get performance metrics
    #[must_use]
    pub fn get_performance_metrics(&self) -> &ControllerMetrics {
        &self.performance_metrics
    }
}

impl AutoDiffEngine {
    /// Create a new automatic differentiation engine
    #[must_use]
    pub fn new(config: AutoDiffConfig) -> Self {
        Self {
            graph: ComputationGraph::new("autodiff_graph".to_string()),
            forward_mode: ForwardModeAD {
                dual_numbers: HashMap::new(),
                computation_order: Vec::new(),
            },
            reverse_mode: ReverseModeAD {
                adjoint_variables: HashMap::new(),
                computation_tape: Vec::new(),
            },
            mixed_mode: MixedModeAD {
                forward_nodes: Vec::new(),
                reverse_nodes: Vec::new(),
                checkpointing: config.checkpointing.clone(),
            },
            config,
        }
    }

    /// Compute gradients using automatic differentiation
    pub fn compute_gradients(
        &mut self,
        function: &dyn Fn(&Array2<Float>) -> Array2<Float>,
        input: &Array2<Float>,
    ) -> SklResult<Array2<Float>> {
        match self.config.mode {
            DifferentiationMode::Forward => self.forward_mode_gradients(function, input),
            DifferentiationMode::Reverse => self.reverse_mode_gradients(function, input),
            DifferentiationMode::Mixed => self.mixed_mode_gradients(function, input),
            DifferentiationMode::Automatic => self.automatic_mode_gradients(function, input),
        }
    }

    /// Forward mode automatic differentiation
    fn forward_mode_gradients(
        &mut self,
        function: &dyn Fn(&Array2<Float>) -> Array2<Float>,
        input: &Array2<Float>,
    ) -> SklResult<Array2<Float>> {
        // Placeholder: central finite differences on the summed output stand in for
        // true forward-mode (dual-number) differentiation.
        let h = 1e-8;
        let mut gradients = Array2::zeros(input.dim());

        for i in 0..input.nrows() {
            for j in 0..input.ncols() {
                let mut input_plus = input.clone();
                let mut input_minus = input.clone();

                input_plus[[i, j]] += h;
                input_minus[[i, j]] -= h;

                let output_plus = function(&input_plus);
                let output_minus = function(&input_minus);

                // d/dx_ij sum(f(x)) ~= (sum(f(x + h)) - sum(f(x - h))) / (2h)
                gradients[[i, j]] = (output_plus.sum() - output_minus.sum()) / (2.0 * h);
            }
        }

        Ok(gradients)
    }

    /// Reverse mode automatic differentiation
    fn reverse_mode_gradients(
        &mut self,
        function: &dyn Fn(&Array2<Float>) -> Array2<Float>,
        input: &Array2<Float>,
    ) -> SklResult<Array2<Float>> {
        // Placeholder: a full implementation would replay the computation tape;
        // for now this falls back to the finite-difference routine above.
        self.forward_mode_gradients(function, input)
    }

    /// Mixed mode automatic differentiation
    fn mixed_mode_gradients(
        &mut self,
        function: &dyn Fn(&Array2<Float>) -> Array2<Float>,
        input: &Array2<Float>,
    ) -> SklResult<Array2<Float>> {
        // Placeholder: falls back to the finite-difference routine for simplicity.
        self.forward_mode_gradients(function, input)
    }

    /// Automatic mode selection
    fn automatic_mode_gradients(
        &mut self,
        function: &dyn Fn(&Array2<Float>) -> Array2<Float>,
        input: &Array2<Float>,
    ) -> SklResult<Array2<Float>> {
        // Choose a mode heuristically based on the input size
        if input.len() > 1000 {
            self.reverse_mode_gradients(function, input)
        } else {
            self.forward_mode_gradients(function, input)
        }
    }
}
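
// Illustrative sketch: in automatic mode the engine currently falls back to the
// finite-difference routine for small inputs, so the gradient of f(x) = sum(x^2)
// should come out close to 2x elementwise. Tolerances are loose because the
// backing implementation is numerical, not exact.
#[cfg(test)]
mod autodiff_engine_sketch {
    use super::*;

    #[test]
    fn finite_difference_gradient_of_square() {
        let config = AutoDiffConfig {
            mode: DifferentiationMode::Automatic,
            precision: 1e-8,
            memory_optimization: MemoryOptimization::None,
            parallel: false,
            checkpointing: CheckpointingStrategy::None,
        };
        let mut engine = AutoDiffEngine::new(config);

        let input = Array2::from_shape_vec((1, 2), vec![3.0, -2.0]).unwrap();
        let gradients = engine
            .compute_gradients(&|x: &Array2<Float>| x.mapv(|v| v * v), &input)
            .unwrap();

        // d/dx sum(x^2) = 2x, so expect roughly [6.0, -4.0].
        assert!((gradients[[0, 0]] - 6.0).abs() < 1e-4);
        assert!((gradients[[0, 1]] + 4.0).abs() < 1e-4);
    }
}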

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_computation_graph_creation() {
        let graph = ComputationGraph::new("test_graph".to_string());
        assert_eq!(graph.metadata.name, "test_graph");
        assert_eq!(graph.nodes.len(), 0);
    }

    #[test]
    fn test_differentiable_pipeline_creation() {
        let config = OptimizationConfig {
            optimizer: OptimizerType::SGD,
            learning_rate: 0.001,
            momentum: 0.9,
            weight_decay: 0.0001,
            gradient_clipping: None,
            batch_size: 32,
            max_epochs: 100,
            early_stopping: None,
        };

        let pipeline = DifferentiablePipeline::new(config);
        assert_eq!(pipeline.stages.len(), 0);
        assert!(!pipeline.training_state.training);
    }

    #[test]
    fn test_activation_functions() {
        let input = Array2::from_shape_vec((2, 2), vec![1.0, -1.0, 2.0, -2.0]).unwrap();
        let graph = ComputationGraph::new("test".to_string());

        // Test ReLU
        let relu_result = graph
            .apply_activation(&ActivationFunction::ReLU, &input)
            .unwrap();
        assert_eq!(relu_result[[0, 0]], 1.0);
        assert_eq!(relu_result[[0, 1]], 0.0);

        // Test Sigmoid
        let sigmoid_result = graph
            .apply_activation(&ActivationFunction::Sigmoid, &input)
            .unwrap();
        assert!(sigmoid_result[[0, 0]] > 0.0 && sigmoid_result[[0, 0]] < 1.0);
    }

    #[test]
    fn test_parameter_initialization() {
        let config = ParameterConfig {
            name: "test_param".to_string(),
            lr_multiplier: 1.0,
            weight_decay: 0.0001,
            requires_grad: true,
            initialization: InitializationMethod::Xavier,
        };

        let param = Parameter {
            values: Array2::zeros((3, 3)),
            gradients: Array2::zeros((3, 3)),
            momentum: None,
            velocity: None,
            config,
        };

        assert_eq!(param.values.dim(), (3, 3));
        assert!(param.config.requires_grad);
    }

    #[test]
    fn test_learning_rate_schedule() {
        let schedule = LearningRateSchedule {
            schedule_type: ScheduleType::Exponential { decay_rate: 0.9 },
            initial_lr: 0.001,
            current_lr: 0.001,
            parameters: HashMap::new(),
        };

        assert_eq!(schedule.initial_lr, 0.001);
        assert_eq!(schedule.current_lr, 0.001);
    }

    #[test]
    fn test_gradient_accumulation() {
        let accumulation = GradientAccumulation {
            steps: 4,
            current_step: 0,
            accumulated_gradients: HashMap::new(),
            scaling_factor: 1.0,
        };

        assert_eq!(accumulation.steps, 4);
        assert_eq!(accumulation.current_step, 0);
    }

    #[test]
    fn test_autodiff_engine_creation() {
        let config = AutoDiffConfig {
            mode: DifferentiationMode::Automatic,
            precision: 1e-8,
            memory_optimization: MemoryOptimization::None,
            parallel: false,
            checkpointing: CheckpointingStrategy::None,
        };

        let engine = AutoDiffEngine::new(config);
        assert!(matches!(engine.config.mode, DifferentiationMode::Automatic));
    }

    #[test]
    fn test_dual_number() {
        let dual = DualNumber {
            real: Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap(),
            dual: Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap(),
        };

        assert_eq!(dual.real.dim(), (2, 2));
        assert_eq!(dual.dual.dim(), (2, 2));
    }

    #[test]
    fn test_control_strategy() {
        let strategy = ControlStrategy::Reinforcement {
            reward_function: RewardFunction::Performance {
                metric: "accuracy".to_string(),
            },
        };

        assert!(matches!(strategy, ControlStrategy::Reinforcement { .. }));
    }

    #[test]
    fn test_adaptation_record() {
        let record = AdaptationRecord {
            timestamp: std::time::SystemTime::now(),
            previous_params: HashMap::new(),
            new_params: HashMap::new(),
            performance_before: 0.8,
            performance_after: 0.85,
            trigger: AdaptationTrigger::PerformanceDrop { threshold: 0.1 },
        };

        assert_eq!(record.performance_before, 0.8);
        assert_eq!(record.performance_after, 0.85);
    }
}