use scirs2_core::ndarray::{Array2, Axis};
use sklears_core::{
    error::Result as SklResult, prelude::SklearsError, traits::Estimator, types::Float,
};
use std::collections::HashMap;

#[derive(Debug, Clone)]
pub struct ComputationNode {
    pub id: String,
    pub operation: DifferentiableOperation,
    pub inputs: Vec<String>,
    pub output_shape: Vec<usize>,
    pub gradient: Option<Array2<Float>>,
    pub value: Option<Array2<Float>>,
    pub requires_grad: bool,
}

#[derive(Debug, Clone)]
pub enum DifferentiableOperation {
    MatMul,
    Add,
    Mul,
    Sub,
    Div,
    Activation { function: ActivationFunction },
    Loss { function: LossFunction },
    Normalization { method: NormalizationMethod },
    Convolution {
        kernel_size: usize,
        stride: usize,
        padding: usize,
    },
    Pooling {
        pool_type: PoolingType,
        kernel_size: usize,
        stride: usize,
    },
    Dropout { rate: f64 },
    Reshape { shape: Vec<usize> },
    Concatenate { axis: usize },
    Slice {
        start: usize,
        end: usize,
        axis: usize,
    },
    Custom {
        name: String,
        forward: fn(&[Array2<Float>]) -> Array2<Float>,
    },
}

#[derive(Debug, Clone)]
pub enum ActivationFunction {
    ReLU,
    Sigmoid,
    Tanh,
    Softmax,
    LeakyReLU { alpha: f64 },
    ELU { alpha: f64 },
    Swish,
    GELU,
    Mish,
}

#[derive(Debug, Clone)]
pub enum LossFunction {
    MeanSquaredError,
    CrossEntropy,
    BinaryCrossEntropy,
    Huber { delta: f64 },
    Hinge,
    KLDivergence,
    L1Loss,
    SmoothL1Loss,
}

#[derive(Debug, Clone)]
pub enum NormalizationMethod {
    BatchNorm { momentum: f64, epsilon: f64 },
    LayerNorm { epsilon: f64 },
    GroupNorm { num_groups: usize, epsilon: f64 },
    InstanceNorm { epsilon: f64 },
    StandardScaler,
    MinMaxScaler,
}

#[derive(Debug, Clone)]
pub enum PoolingType {
    Max,
    Average,
    Global,
    Adaptive,
}

pub struct ComputationGraph {
    nodes: HashMap<String, ComputationNode>,
    execution_order: Vec<String>,
    input_nodes: Vec<String>,
    output_nodes: Vec<String>,
    parameters: HashMap<String, Array2<Float>>,
    parameter_gradients: HashMap<String, Array2<Float>>,
    metadata: GraphMetadata,
}

#[derive(Debug, Clone)]
pub struct GraphMetadata {
    pub name: String,
    pub created_at: std::time::SystemTime,
    pub version: String,
    pub total_parameters: usize,
    pub trainable_parameters: usize,
}

pub struct GradientContext {
    graph: ComputationGraph,
    forward_values: HashMap<String, Array2<Float>>,
    backward_gradients: HashMap<String, Array2<Float>>,
    gradient_tape: Vec<GradientRecord>,
}

#[derive(Debug, Clone)]
pub struct GradientRecord {
    pub node_id: String,
    pub operation: DifferentiableOperation,
    pub input_gradients: Vec<Array2<Float>>,
    pub output_gradient: Array2<Float>,
}

pub struct DifferentiablePipeline {
    stages: Vec<DifferentiableStage>,
    optimization_config: OptimizationConfig,
    lr_schedule: LearningRateSchedule,
    gradient_accumulation: GradientAccumulation,
    training_state: TrainingState,
    metrics: TrainingMetrics,
}

pub struct DifferentiableStage {
    pub name: String,
    pub graph: ComputationGraph,
    pub parameters: HashMap<String, Parameter>,
    pub optimizer_state: OptimizerState,
    pub config: StageConfig,
}

#[derive(Debug, Clone)]
pub struct Parameter {
    pub values: Array2<Float>,
    pub gradients: Array2<Float>,
    pub momentum: Option<Array2<Float>>,
    pub velocity: Option<Array2<Float>>,
    pub config: ParameterConfig,
}

#[derive(Debug, Clone)]
pub struct ParameterConfig {
    pub name: String,
    pub lr_multiplier: f64,
    pub weight_decay: f64,
    pub requires_grad: bool,
    pub initialization: InitializationMethod,
}

#[derive(Debug, Clone)]
pub enum InitializationMethod {
    Zero,
    Uniform { min: f64, max: f64 },
    Normal { mean: f64, std: f64 },
    Xavier,
    He,
    Orthogonal,
    Custom { method: String },
}

#[derive(Debug, Clone)]
pub struct OptimizationConfig {
    pub optimizer: OptimizerType,
    pub learning_rate: f64,
    pub momentum: f64,
    pub weight_decay: f64,
    pub gradient_clipping: Option<GradientClipping>,
    pub batch_size: usize,
    pub max_epochs: usize,
    pub early_stopping: Option<EarlyStopping>,
}

#[derive(Debug, Clone)]
pub enum OptimizerType {
    SGD,
    Adam {
        beta1: f64,
        beta2: f64,
        epsilon: f64,
    },
    RMSprop { alpha: f64, epsilon: f64 },
    AdaGrad { epsilon: f64 },
    AdaDelta { rho: f64, epsilon: f64 },
    LBfgs { history_size: usize },
    Custom { name: String },
}

#[derive(Debug, Clone)]
pub struct GradientClipping {
    pub method: ClippingMethod,
    pub threshold: f64,
}

#[derive(Debug, Clone)]
pub enum ClippingMethod {
    Norm,
    Value,
    GlobalNorm,
}

#[derive(Debug, Clone)]
pub struct EarlyStopping {
    pub patience: usize,
    pub min_delta: f64,
    pub monitor: String,
    pub restore_best_weights: bool,
}

#[derive(Debug, Clone)]
pub struct LearningRateSchedule {
    pub schedule_type: ScheduleType,
    pub initial_lr: f64,
    pub current_lr: f64,
    pub parameters: HashMap<String, f64>,
}

#[derive(Debug, Clone)]
pub enum ScheduleType {
    Constant,
    Linear { final_lr: f64 },
    Exponential { decay_rate: f64 },
    StepLR { step_size: usize, gamma: f64 },
    CosineAnnealing { t_max: usize },
    ReduceOnPlateau { factor: f64, patience: usize },
    Polynomial { power: f64, final_lr: f64 },
    Custom { name: String },
}

#[derive(Debug, Clone)]
pub struct GradientAccumulation {
    pub steps: usize,
    pub current_step: usize,
    pub accumulated_gradients: HashMap<String, Array2<Float>>,
    pub scaling_factor: f64,
}

#[derive(Debug, Clone)]
pub struct TrainingState {
    pub epoch: usize,
    pub step: usize,
    pub training: bool,
    pub best_metric: Option<f64>,
    pub best_weights: Option<HashMap<String, Array2<Float>>>,
    pub history: TrainingHistory,
}

#[derive(Debug, Clone)]
pub struct TrainingHistory {
    pub losses: Vec<f64>,
    pub metrics: HashMap<String, Vec<f64>>,
    pub learning_rates: Vec<f64>,
    pub timestamps: Vec<std::time::SystemTime>,
}

#[derive(Debug, Clone)]
pub struct TrainingMetrics {
    pub current_loss: f64,
    pub current_metrics: HashMap<String, f64>,
    pub moving_averages: HashMap<String, f64>,
    pub gradient_norms: HashMap<String, f64>,
    pub parameter_norms: HashMap<String, f64>,
}

#[derive(Debug, Clone)]
pub struct OptimizerState {
    pub optimizer_type: OptimizerType,
    pub state: HashMap<String, Array2<Float>>,
    pub step: usize,
    pub config: HashMap<String, f64>,
}

#[derive(Debug, Clone)]
pub struct StageConfig {
    pub name: String,
    pub input_shapes: Vec<Vec<usize>>,
    pub output_shapes: Vec<Vec<usize>>,
    pub regularization: RegularizationConfig,
    pub batch_processing: BatchProcessingConfig,
}

#[derive(Debug, Clone)]
pub struct RegularizationConfig {
    pub l1_lambda: f64,
    pub l2_lambda: f64,
    pub dropout_rate: f64,
    pub batch_norm: bool,
    pub layer_norm: bool,
}

#[derive(Debug, Clone)]
pub struct BatchProcessingConfig {
    pub batch_size: usize,
    pub shuffle: bool,
    pub drop_last: bool,
    pub num_workers: usize,
}

pub struct NeuralPipelineController {
    controller_network: DifferentiablePipeline,
    controlled_pipeline: Box<dyn PipelineComponent>,
    control_strategy: ControlStrategy,
    adaptation_history: Vec<AdaptationRecord>,
    performance_metrics: ControllerMetrics,
}

pub trait PipelineComponent: Send + Sync {
    fn name(&self) -> &str;

    fn process(&mut self, input: &Array2<Float>) -> SklResult<Array2<Float>>;

    fn get_parameters(&self) -> HashMap<String, f64>;

    fn set_parameters(&mut self, params: HashMap<String, f64>) -> SklResult<()>;

    fn get_metrics(&self) -> HashMap<String, f64>;
}

#[derive(Debug, Clone)]
pub enum ControlStrategy {
    Reinforcement { reward_function: RewardFunction },
    Supervised { target_performance: f64 },
    MetaLearning { adaptation_steps: usize },
    Evolutionary {
        population_size: usize,
        mutation_rate: f64,
    },
    Bayesian {
        prior_distribution: PriorDistribution,
    },
}

#[derive(Debug, Clone)]
pub enum RewardFunction {
    Performance { metric: String },
    Efficiency {
        latency_weight: f64,
        accuracy_weight: f64,
    },
    ResourceUsage { cpu_weight: f64, memory_weight: f64 },
    Custom { function: String },
}

#[derive(Debug, Clone)]
pub enum PriorDistribution {
    Normal { mean: f64, std: f64 },
    Uniform { min: f64, max: f64 },
    Beta { alpha: f64, beta: f64 },
    Custom { distribution: String },
}

#[derive(Debug, Clone)]
pub struct AdaptationRecord {
    pub timestamp: std::time::SystemTime,
    pub previous_params: HashMap<String, f64>,
    pub new_params: HashMap<String, f64>,
    pub performance_before: f64,
    pub performance_after: f64,
    pub trigger: AdaptationTrigger,
}

#[derive(Debug, Clone)]
pub enum AdaptationTrigger {
    PerformanceDrop { threshold: f64 },
    DataDrift { magnitude: f64 },
    ResourceConstraint { constraint: String },
    ScheduledUpdate,
    Manual,
}

#[derive(Debug, Clone)]
pub struct ControllerMetrics {
    pub total_adaptations: usize,
    pub successful_adaptations: usize,
    pub average_improvement: f64,
    pub adaptation_latency: std::time::Duration,
    pub control_overhead: f64,
}

pub struct AutoDiffEngine {
    graph: ComputationGraph,
    forward_mode: ForwardModeAD,
    reverse_mode: ReverseModeAD,
    mixed_mode: MixedModeAD,
    config: AutoDiffConfig,
}

pub struct ForwardModeAD {
    dual_numbers: HashMap<String, DualNumber>,
    computation_order: Vec<String>,
}

pub struct ReverseModeAD {
    adjoint_variables: HashMap<String, Array2<Float>>,
    computation_tape: Vec<TapeEntry>,
}

pub struct MixedModeAD {
    forward_nodes: Vec<String>,
    reverse_nodes: Vec<String>,
    checkpointing: CheckpointingStrategy,
}

#[derive(Debug, Clone)]
pub struct DualNumber {
    pub real: Array2<Float>,
    pub dual: Array2<Float>,
}

#[derive(Debug, Clone)]
pub struct TapeEntry {
    pub node_id: String,
    pub operation: DifferentiableOperation,
    pub input_values: Vec<Array2<Float>>,
    pub output_value: Array2<Float>,
    pub gradient_function: fn(&Array2<Float>, &[Array2<Float>]) -> Vec<Array2<Float>>,
}

#[derive(Debug, Clone)]
pub enum CheckpointingStrategy {
    None,
    Uniform { interval: usize },
    Adaptive { memory_threshold: usize },
    Custom { strategy: String },
}

#[derive(Debug, Clone)]
pub struct AutoDiffConfig {
    pub mode: DifferentiationMode,
    pub precision: f64,
    pub memory_optimization: MemoryOptimization,
    pub parallel: bool,
    pub checkpointing: CheckpointingStrategy,
}

#[derive(Debug, Clone)]
pub enum DifferentiationMode {
    Forward,
    Reverse,
    Mixed,
    Automatic,
}

#[derive(Debug, Clone)]
pub enum MemoryOptimization {
    None,
    Gradient { release_intermediate: bool },
    Checkpointing { max_checkpoints: usize },
    Streaming { chunk_size: usize },
}

impl ComputationGraph {
    #[must_use]
    pub fn new(name: String) -> Self {
        Self {
            nodes: HashMap::new(),
            execution_order: Vec::new(),
            input_nodes: Vec::new(),
            output_nodes: Vec::new(),
            parameters: HashMap::new(),
            parameter_gradients: HashMap::new(),
            metadata: GraphMetadata {
                name,
                created_at: std::time::SystemTime::now(),
                version: "1.0.0".to_string(),
                total_parameters: 0,
                trainable_parameters: 0,
            },
        }
    }

    pub fn add_node(&mut self, node: ComputationNode) -> SklResult<()> {
        let node_id = node.id.clone();
        self.nodes.insert(node_id, node);
        self.update_execution_order()?;
        Ok(())
    }

    pub fn add_input(&mut self, node_id: String, shape: Vec<usize>) {
        self.input_nodes.push(node_id.clone());
        let node = ComputationNode {
            id: node_id,
            operation: DifferentiableOperation::Custom {
                name: "input".to_string(),
                forward: |inputs| inputs[0].clone(),
            },
            inputs: Vec::new(),
            output_shape: shape,
            gradient: None,
            value: None,
            requires_grad: false,
        };
        self.nodes.insert(node.id.clone(), node);
    }

    pub fn add_output(&mut self, node_id: String) {
        self.output_nodes.push(node_id);
    }

    pub fn forward(
        &mut self,
        inputs: &HashMap<String, Array2<Float>>,
    ) -> SklResult<HashMap<String, Array2<Float>>> {
        for (input_id, input_value) in inputs {
            if let Some(node) = self.nodes.get_mut(input_id) {
                node.value = Some(input_value.clone());
            }
        }

        let execution_order = self.execution_order.clone();
        for node_id in execution_order {
            self.execute_node(&node_id)?;
        }

        let mut outputs = HashMap::new();
        for output_id in &self.output_nodes {
            if let Some(node) = self.nodes.get(output_id) {
                if let Some(value) = &node.value {
                    outputs.insert(output_id.clone(), value.clone());
                }
            }
        }

        Ok(outputs)
    }

    pub fn backward(&mut self, output_gradients: &HashMap<String, Array2<Float>>) -> SklResult<()> {
        for (output_id, grad) in output_gradients {
            if let Some(node) = self.nodes.get_mut(output_id) {
                node.gradient = Some(grad.clone());
            }
        }

        let execution_order = self.execution_order.clone();
        for node_id in execution_order.iter().rev() {
            self.backpropagate_node(node_id)?;
        }

        Ok(())
    }

    fn execute_node(&mut self, node_id: &str) -> SklResult<()> {
        let node = self
            .nodes
            .get(node_id)
            .ok_or_else(|| SklearsError::InvalidInput(format!("Unknown node: {node_id}")))?
            .clone();

        // Source nodes (e.g. graph inputs) have no upstream dependencies; their values
        // are injected before execution, so there is nothing to recompute here.
        if node.inputs.is_empty() && node.value.is_some() {
            return Ok(());
        }

        let mut input_values = Vec::new();
        for input_id in &node.inputs {
            if let Some(input_node) = self.nodes.get(input_id) {
                if let Some(value) = &input_node.value {
                    input_values.push(value.clone());
                }
            }
        }

        let output = self.execute_operation(&node.operation, &input_values)?;

        if let Some(node) = self.nodes.get_mut(node_id) {
            node.value = Some(output);
        }

        Ok(())
    }

    fn execute_operation(
        &self,
        operation: &DifferentiableOperation,
        inputs: &[Array2<Float>],
    ) -> SklResult<Array2<Float>> {
        match operation {
            DifferentiableOperation::MatMul => {
                if inputs.len() != 2 {
                    return Err(SklearsError::InvalidInput(
                        "MatMul requires 2 inputs".to_string(),
                    ));
                }
                Ok(inputs[0].dot(&inputs[1]))
            }
            DifferentiableOperation::Add => {
                if inputs.len() != 2 {
                    return Err(SklearsError::InvalidInput(
                        "Add requires 2 inputs".to_string(),
                    ));
                }
                Ok(&inputs[0] + &inputs[1])
            }
            DifferentiableOperation::Mul => {
                if inputs.len() != 2 {
                    return Err(SklearsError::InvalidInput(
                        "Mul requires 2 inputs".to_string(),
                    ));
                }
                Ok(&inputs[0] * &inputs[1])
            }
            DifferentiableOperation::Activation { function } => {
                if inputs.len() != 1 {
                    return Err(SklearsError::InvalidInput(
                        "Activation requires 1 input".to_string(),
                    ));
                }
                self.apply_activation(function, &inputs[0])
            }
            DifferentiableOperation::Custom { forward, .. } => Ok(forward(inputs)),
            _ => Err(SklearsError::NotImplemented(
                "Operation not implemented".to_string(),
            )),
        }
    }

    fn apply_activation(
        &self,
        function: &ActivationFunction,
        input: &Array2<Float>,
    ) -> SklResult<Array2<Float>> {
        match function {
            ActivationFunction::ReLU => Ok(input.mapv(|x| x.max(0.0))),
            ActivationFunction::Sigmoid => Ok(input.mapv(|x| 1.0 / (1.0 + (-x).exp()))),
            ActivationFunction::Tanh => Ok(input.mapv(f64::tanh)),
            ActivationFunction::Softmax => {
                let mut result = input.clone();
                for mut row in result.axis_iter_mut(Axis(0)) {
                    let max_val = row.fold(Float::NEG_INFINITY, |acc, &x| acc.max(x));
                    row.mapv_inplace(|x| (x - max_val).exp());
                    let sum = row.sum();
                    row.mapv_inplace(|x| x / sum);
                }
                Ok(result)
            }
            _ => Err(SklearsError::NotImplemented(
                "Activation function not implemented".to_string(),
            )),
        }
    }

    fn backpropagate_node(&mut self, _node_id: &str) -> SklResult<()> {
        // Placeholder: per-node gradient propagation is not implemented yet.
        Ok(())
    }

    fn update_execution_order(&mut self) -> SklResult<()> {
        // Placeholder ordering: node ids are collected without a topological sort,
        // so execution currently relies on input values being injected up front.
        self.execution_order = self.nodes.keys().cloned().collect();
        Ok(())
    }
}

impl DifferentiablePipeline {
    #[must_use]
    pub fn new(optimization_config: OptimizationConfig) -> Self {
        // Seed the schedule from the configured learning rate so the two stay consistent.
        let initial_lr = optimization_config.learning_rate;
        Self {
            stages: Vec::new(),
            optimization_config,
            lr_schedule: LearningRateSchedule {
                schedule_type: ScheduleType::Constant,
                initial_lr,
                current_lr: initial_lr,
                parameters: HashMap::new(),
            },
            gradient_accumulation: GradientAccumulation {
                steps: 1,
                current_step: 0,
                accumulated_gradients: HashMap::new(),
                scaling_factor: 1.0,
            },
            training_state: TrainingState {
                epoch: 0,
                step: 0,
                training: false,
                best_metric: None,
                best_weights: None,
                history: TrainingHistory {
                    losses: Vec::new(),
                    metrics: HashMap::new(),
                    learning_rates: Vec::new(),
                    timestamps: Vec::new(),
                },
            },
            metrics: TrainingMetrics {
                current_loss: 0.0,
                current_metrics: HashMap::new(),
                moving_averages: HashMap::new(),
                gradient_norms: HashMap::new(),
                parameter_norms: HashMap::new(),
            },
        }
    }

    pub fn add_stage(&mut self, stage: DifferentiableStage) {
        self.stages.push(stage);
    }

    pub fn train(&mut self, train_data: &[(Array2<Float>, Array2<Float>)]) -> SklResult<()> {
        self.training_state.training = true;

        for epoch in 0..self.optimization_config.max_epochs {
            self.training_state.epoch = epoch;

            let mut epoch_loss = 0.0;
            let mut batch_count = 0;

            for (inputs, targets) in train_data {
                let predictions = self.forward(inputs)?;

                let loss = self.compute_loss(&predictions, targets)?;
                epoch_loss += loss;
                batch_count += 1;

                self.backward(&predictions, targets)?;

                self.update_parameters()?;

                self.training_state.step += 1;
            }

            self.update_learning_rate()?;

            let avg_loss = epoch_loss / f64::from(batch_count);
            self.training_state.history.losses.push(avg_loss);
            self.training_state
                .history
                .learning_rates
                .push(self.lr_schedule.current_lr);
            self.training_state
                .history
                .timestamps
                .push(std::time::SystemTime::now());

            if let Some(early_stopping) = &self.optimization_config.early_stopping.clone() {
                if self.should_early_stop(early_stopping, avg_loss) {
                    break;
                }
            }
        }

        self.training_state.training = false;
        Ok(())
    }

    pub fn forward(&mut self, input: &Array2<Float>) -> SklResult<Array2<Float>> {
        let mut current_input = input.clone();

        for stage in &mut self.stages {
            let inputs = HashMap::from([("input".to_string(), current_input.clone())]);
            let outputs = stage.graph.forward(&inputs)?;

            if let Some(output) = outputs.get("output") {
                current_input = output.clone();
            }
        }

        Ok(current_input)
    }

    pub fn backward(
        &mut self,
        predictions: &Array2<Float>,
        targets: &Array2<Float>,
    ) -> SklResult<()> {
        let output_gradients = self.compute_output_gradients(predictions, targets)?;

        for stage in self.stages.iter_mut().rev() {
            stage.graph.backward(&output_gradients)?;
        }

        Ok(())
    }

    fn compute_loss(&self, predictions: &Array2<Float>, targets: &Array2<Float>) -> SklResult<f64> {
        let diff = predictions - targets;
        let squared_diff = diff.mapv(|x| x * x);
        Ok(squared_diff.mean().unwrap_or(0.0))
    }

    fn compute_output_gradients(
        &self,
        predictions: &Array2<Float>,
        targets: &Array2<Float>,
    ) -> SklResult<HashMap<String, Array2<Float>>> {
        let gradients = 2.0 * (predictions - targets) / (predictions.len() as f64);
        let mut gradient_map = HashMap::new();
        gradient_map.insert("output".to_string(), gradients);
        Ok(gradient_map)
    }

    fn update_parameters(&mut self) -> SklResult<()> {
        match &self.optimization_config.optimizer {
            OptimizerType::SGD => {
                for stage in &mut self.stages {
                    for param in stage.parameters.values_mut() {
                        let lr = self.lr_schedule.current_lr;
                        param.values = &param.values - &(lr * &param.gradients);
                    }
                }
            }
            OptimizerType::Adam {
                beta1,
                beta2,
                epsilon,
            } => {
                for stage in &mut self.stages {
                    for param in stage.parameters.values_mut() {
                        let lr = self.lr_schedule.current_lr;
                        let step = self.training_state.step as f64 + 1.0;

                        if let Some(momentum) = &mut param.momentum {
                            *momentum = *beta1 * &*momentum + (1.0 - beta1) * &param.gradients;
                        } else {
                            param.momentum = Some((1.0 - beta1) * &param.gradients);
                        }

                        if let Some(velocity) = &mut param.velocity {
                            *velocity = *beta2 * &*velocity
                                + (1.0 - beta2) * &param.gradients.mapv(|x| x * x);
                        } else {
                            param.velocity = Some((1.0 - beta2) * &param.gradients.mapv(|x| x * x));
                        }

                        let momentum_corrected =
                            param.momentum.as_ref().unwrap() / (1.0 - beta1.powf(step));
                        let velocity_corrected =
                            param.velocity.as_ref().unwrap() / (1.0 - beta2.powf(step));

                        param.values = &param.values
                            - &(lr * &momentum_corrected
                                / &velocity_corrected.mapv(|x| x.sqrt() + epsilon));
                    }
                }
            }
            _ => {
                return Err(SklearsError::NotImplemented(
                    "Optimizer not implemented".to_string(),
                ))
            }
        }
        Ok(())
    }

    fn update_learning_rate(&mut self) -> SklResult<()> {
        match &self.lr_schedule.schedule_type {
            ScheduleType::Constant => {}
            ScheduleType::Exponential { decay_rate } => {
                self.lr_schedule.current_lr = self.lr_schedule.initial_lr
                    * decay_rate.powf(self.training_state.epoch as f64);
            }
            ScheduleType::StepLR { step_size, gamma } => {
                if self.training_state.epoch % step_size == 0 && self.training_state.epoch > 0 {
                    self.lr_schedule.current_lr *= gamma;
                }
            }
            _ => {
                return Err(SklearsError::NotImplemented(
                    "Learning rate schedule not implemented".to_string(),
                ))
            }
        }
        Ok(())
    }

    fn should_early_stop(&mut self, early_stopping: &EarlyStopping, current_loss: f64) -> bool {
        // Simplified criterion: stop on the first epoch without sufficient improvement.
        // The configured `patience` is not tracked yet.
        if let Some(best_metric) = self.training_state.best_metric {
            if current_loss < best_metric - early_stopping.min_delta {
                self.training_state.best_metric = Some(current_loss);
                false
            } else {
                true
            }
        } else {
            self.training_state.best_metric = Some(current_loss);
            false
        }
    }
}

impl NeuralPipelineController {
    #[must_use]
    pub fn new(
        controller_network: DifferentiablePipeline,
        controlled_pipeline: Box<dyn PipelineComponent>,
        control_strategy: ControlStrategy,
    ) -> Self {
        Self {
            controller_network,
            controlled_pipeline,
            control_strategy,
            adaptation_history: Vec::new(),
            performance_metrics: ControllerMetrics {
                total_adaptations: 0,
                successful_adaptations: 0,
                average_improvement: 0.0,
                adaptation_latency: std::time::Duration::from_millis(0),
                control_overhead: 0.0,
            },
        }
    }

    pub fn adapt(&mut self, performance_data: &Array2<Float>) -> SklResult<()> {
        let start_time = std::time::Instant::now();

        let current_metrics = self.controlled_pipeline.get_metrics();
        let current_performance = current_metrics.get("performance").copied().unwrap_or(0.0);

        let control_signal = self.controller_network.forward(performance_data)?;

        let new_params = self.control_signal_to_parameters(&control_signal)?;

        let previous_params = self.controlled_pipeline.get_parameters();
        self.controlled_pipeline
            .set_parameters(new_params.clone())?;

        let new_metrics = self.controlled_pipeline.get_metrics();
        let new_performance = new_metrics.get("performance").copied().unwrap_or(0.0);

        let adaptation_record = AdaptationRecord {
            timestamp: std::time::SystemTime::now(),
            previous_params,
            new_params,
            performance_before: current_performance,
            performance_after: new_performance,
            trigger: AdaptationTrigger::ScheduledUpdate,
        };

        self.adaptation_history.push(adaptation_record);

        self.performance_metrics.total_adaptations += 1;
        if new_performance > current_performance {
            self.performance_metrics.successful_adaptations += 1;
        }

        let adaptation_latency = start_time.elapsed();
        self.performance_metrics.adaptation_latency = adaptation_latency;

        Ok(())
    }

    fn control_signal_to_parameters(
        &self,
        control_signal: &Array2<Float>,
    ) -> SklResult<HashMap<String, f64>> {
        let mut params = HashMap::new();

        for (i, &value) in control_signal.iter().enumerate() {
            params.insert(format!("param_{i}"), value);
        }

        Ok(params)
    }

    #[must_use]
    pub fn get_adaptation_history(&self) -> &[AdaptationRecord] {
        &self.adaptation_history
    }

    #[must_use]
    pub fn get_performance_metrics(&self) -> &ControllerMetrics {
        &self.performance_metrics
    }
}

impl AutoDiffEngine {
    #[must_use]
    pub fn new(config: AutoDiffConfig) -> Self {
        Self {
            graph: ComputationGraph::new("autodiff_graph".to_string()),
            forward_mode: ForwardModeAD {
                dual_numbers: HashMap::new(),
                computation_order: Vec::new(),
            },
            reverse_mode: ReverseModeAD {
                adjoint_variables: HashMap::new(),
                computation_tape: Vec::new(),
            },
            mixed_mode: MixedModeAD {
                forward_nodes: Vec::new(),
                reverse_nodes: Vec::new(),
                checkpointing: config.checkpointing.clone(),
            },
            config,
        }
    }

    pub fn compute_gradients(
        &mut self,
        function: &dyn Fn(&Array2<Float>) -> Array2<Float>,
        input: &Array2<Float>,
    ) -> SklResult<Array2<Float>> {
        match self.config.mode {
            DifferentiationMode::Forward => self.forward_mode_gradients(function, input),
            DifferentiationMode::Reverse => self.reverse_mode_gradients(function, input),
            DifferentiationMode::Mixed => self.mixed_mode_gradients(function, input),
            DifferentiationMode::Automatic => self.automatic_mode_gradients(function, input),
        }
    }

    fn forward_mode_gradients(
        &mut self,
        function: &dyn Fn(&Array2<Float>) -> Array2<Float>,
        input: &Array2<Float>,
    ) -> SklResult<Array2<Float>> {
        // Approximates the gradient of the summed output with central finite
        // differences rather than true forward-mode AD.
        let h = 1e-8;
        let mut gradients = Array2::zeros(input.dim());

        for i in 0..input.nrows() {
            for j in 0..input.ncols() {
                let mut input_plus = input.clone();
                let mut input_minus = input.clone();

                input_plus[[i, j]] += h;
                input_minus[[i, j]] -= h;

                let output_plus = function(&input_plus);
                let output_minus = function(&input_minus);

                gradients[[i, j]] = (output_plus.sum() - output_minus.sum()) / (2.0 * h);
            }
        }

        Ok(gradients)
    }

    fn reverse_mode_gradients(
        &mut self,
        function: &dyn Fn(&Array2<Float>) -> Array2<Float>,
        input: &Array2<Float>,
    ) -> SklResult<Array2<Float>> {
        // Currently falls back to the finite-difference approximation above.
        self.forward_mode_gradients(function, input)
    }

    fn mixed_mode_gradients(
        &mut self,
        function: &dyn Fn(&Array2<Float>) -> Array2<Float>,
        input: &Array2<Float>,
    ) -> SklResult<Array2<Float>> {
        // Currently falls back to the finite-difference approximation above.
        self.forward_mode_gradients(function, input)
    }

    fn automatic_mode_gradients(
        &mut self,
        function: &dyn Fn(&Array2<Float>) -> Array2<Float>,
        input: &Array2<Float>,
    ) -> SklResult<Array2<Float>> {
        // Heuristic: prefer reverse mode for larger inputs, forward mode otherwise.
        if input.len() > 1000 {
            self.reverse_mode_gradients(function, input)
        } else {
            self.forward_mode_gradients(function, input)
        }
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_computation_graph_creation() {
        let graph = ComputationGraph::new("test_graph".to_string());
        assert_eq!(graph.metadata.name, "test_graph");
        assert_eq!(graph.nodes.len(), 0);
    }

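    /// A minimal forward-pass sketch: the `x`/`double` node names and the
    /// element-wise doubling operation are illustrative only, assuming `Float`
    /// behaves as `f64` (as the rest of this file already does).
    #[test]
    fn test_computation_graph_forward_with_custom_node() {
        let mut graph = ComputationGraph::new("forward_demo".to_string());
        graph.add_input("x".to_string(), vec![2, 2]);

        // Custom node that doubles its single input.
        graph
            .add_node(ComputationNode {
                id: "double".to_string(),
                operation: DifferentiableOperation::Custom {
                    name: "double".to_string(),
                    forward: |inputs| &inputs[0] * 2.0,
                },
                inputs: vec!["x".to_string()],
                output_shape: vec![2, 2],
                gradient: None,
                value: None,
                requires_grad: false,
            })
            .unwrap();
        graph.add_output("double".to_string());

        let x = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let outputs = graph
            .forward(&HashMap::from([("x".to_string(), x)]))
            .unwrap();
        assert_eq!(outputs["double"][[1, 1]], 8.0);
    }
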
    #[test]
    fn test_differentiable_pipeline_creation() {
        let config = OptimizationConfig {
            optimizer: OptimizerType::SGD,
            learning_rate: 0.001,
            momentum: 0.9,
            weight_decay: 0.0001,
            gradient_clipping: None,
            batch_size: 32,
            max_epochs: 100,
            early_stopping: None,
        };

        let pipeline = DifferentiablePipeline::new(config);
        assert_eq!(pipeline.stages.len(), 0);
        assert!(!pipeline.training_state.training);
    }

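    /// A small numeric sketch of the built-in MSE loss and its gradient; the
    /// concrete numbers below are chosen only to make the expected values obvious.
    #[test]
    fn test_mse_loss_and_output_gradients() {
        let config = OptimizationConfig {
            optimizer: OptimizerType::SGD,
            learning_rate: 0.01,
            momentum: 0.0,
            weight_decay: 0.0,
            gradient_clipping: None,
            batch_size: 1,
            max_epochs: 1,
            early_stopping: None,
        };
        let pipeline = DifferentiablePipeline::new(config);

        let predictions = Array2::from_shape_vec((1, 2), vec![1.0, 3.0]).unwrap();
        let targets = Array2::from_shape_vec((1, 2), vec![0.0, 1.0]).unwrap();

        // MSE = ((1 - 0)^2 + (3 - 1)^2) / 2 = 2.5
        let loss = pipeline.compute_loss(&predictions, &targets).unwrap();
        assert!((loss - 2.5).abs() < 1e-12);

        // d(MSE)/d(pred) = 2 * (pred - target) / n = [1.0, 2.0]
        let grads = pipeline
            .compute_output_gradients(&predictions, &targets)
            .unwrap();
        assert!((grads["output"][[0, 1]] - 2.0).abs() < 1e-12);
    }
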
    #[test]
    fn test_activation_functions() {
        let input = Array2::from_shape_vec((2, 2), vec![1.0, -1.0, 2.0, -2.0]).unwrap();
        let graph = ComputationGraph::new("test".to_string());

        let relu_result = graph
            .apply_activation(&ActivationFunction::ReLU, &input)
            .unwrap();
        assert_eq!(relu_result[[0, 0]], 1.0);
        assert_eq!(relu_result[[0, 1]], 0.0);

        let sigmoid_result = graph
            .apply_activation(&ActivationFunction::Sigmoid, &input)
            .unwrap();
        assert!(sigmoid_result[[0, 0]] > 0.0 && sigmoid_result[[0, 0]] < 1.0);
    }

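    /// A quick sanity sketch for the softmax branch: each row of the output
    /// should form a probability distribution (non-negative, summing to one).
    #[test]
    fn test_softmax_rows_sum_to_one() {
        let graph = ComputationGraph::new("softmax_demo".to_string());
        let input = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, -1.0, 0.0, 1.0]).unwrap();

        let softmax = graph
            .apply_activation(&ActivationFunction::Softmax, &input)
            .unwrap();
        for row in softmax.axis_iter(Axis(0)) {
            assert!((row.sum() - 1.0).abs() < 1e-9);
            assert!(row.iter().all(|&p| p >= 0.0));
        }
    }
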
    #[test]
    fn test_parameter_initialization() {
        let config = ParameterConfig {
            name: "test_param".to_string(),
            lr_multiplier: 1.0,
            weight_decay: 0.0001,
            requires_grad: true,
            initialization: InitializationMethod::Xavier,
        };

        let param = Parameter {
            values: Array2::zeros((3, 3)),
            gradients: Array2::zeros((3, 3)),
            momentum: None,
            velocity: None,
            config,
        };

        assert_eq!(param.values.dim(), (3, 3));
        assert!(param.config.requires_grad);
    }

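    /// A minimal end-to-end sketch of the SGD update rule on a single 1x1
    /// parameter; the stage/parameter names ("stage", "w") are illustrative only.
    #[test]
    fn test_sgd_parameter_update() {
        let config = OptimizationConfig {
            optimizer: OptimizerType::SGD,
            learning_rate: 0.1,
            momentum: 0.0,
            weight_decay: 0.0,
            gradient_clipping: None,
            batch_size: 1,
            max_epochs: 1,
            early_stopping: None,
        };
        let mut pipeline = DifferentiablePipeline::new(config);
        pipeline.lr_schedule.current_lr = 0.1;

        let mut parameters = HashMap::new();
        parameters.insert(
            "w".to_string(),
            Parameter {
                values: Array2::from_elem((1, 1), 1.0),
                gradients: Array2::from_elem((1, 1), 0.5),
                momentum: None,
                velocity: None,
                config: ParameterConfig {
                    name: "w".to_string(),
                    lr_multiplier: 1.0,
                    weight_decay: 0.0,
                    requires_grad: true,
                    initialization: InitializationMethod::Zero,
                },
            },
        );

        pipeline.add_stage(DifferentiableStage {
            name: "stage".to_string(),
            graph: ComputationGraph::new("stage_graph".to_string()),
            parameters,
            optimizer_state: OptimizerState {
                optimizer_type: OptimizerType::SGD,
                state: HashMap::new(),
                step: 0,
                config: HashMap::new(),
            },
            config: StageConfig {
                name: "stage".to_string(),
                input_shapes: vec![vec![1, 1]],
                output_shapes: vec![vec![1, 1]],
                regularization: RegularizationConfig {
                    l1_lambda: 0.0,
                    l2_lambda: 0.0,
                    dropout_rate: 0.0,
                    batch_norm: false,
                    layer_norm: false,
                },
                batch_processing: BatchProcessingConfig {
                    batch_size: 1,
                    shuffle: false,
                    drop_last: false,
                    num_workers: 0,
                },
            },
        });

        pipeline.update_parameters().unwrap();

        // w <- w - lr * grad = 1.0 - 0.1 * 0.5 = 0.95
        let updated = &pipeline.stages[0].parameters["w"].values;
        assert!((updated[[0, 0]] - 0.95).abs() < 1e-12);
    }
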
    #[test]
    fn test_learning_rate_schedule() {
        let schedule = LearningRateSchedule {
            schedule_type: ScheduleType::Exponential { decay_rate: 0.9 },
            initial_lr: 0.001,
            current_lr: 0.001,
            parameters: HashMap::new(),
        };

        assert_eq!(schedule.initial_lr, 0.001);
        assert_eq!(schedule.current_lr, 0.001);
    }

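    /// A short sketch of the StepLR decay branch: the rate should halve every
    /// `step_size` epochs and stay unchanged in between.
    #[test]
    fn test_step_lr_schedule_update() {
        let config = OptimizationConfig {
            optimizer: OptimizerType::SGD,
            learning_rate: 0.1,
            momentum: 0.0,
            weight_decay: 0.0,
            gradient_clipping: None,
            batch_size: 1,
            max_epochs: 1,
            early_stopping: None,
        };
        let mut pipeline = DifferentiablePipeline::new(config);
        pipeline.lr_schedule.schedule_type = ScheduleType::StepLR {
            step_size: 2,
            gamma: 0.5,
        };
        pipeline.lr_schedule.current_lr = 0.1;

        pipeline.training_state.epoch = 1;
        pipeline.update_learning_rate().unwrap();
        assert!((pipeline.lr_schedule.current_lr - 0.1).abs() < 1e-12);

        pipeline.training_state.epoch = 2;
        pipeline.update_learning_rate().unwrap();
        assert!((pipeline.lr_schedule.current_lr - 0.05).abs() < 1e-12);
    }
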
    #[test]
    fn test_gradient_accumulation() {
        let accumulation = GradientAccumulation {
            steps: 4,
            current_step: 0,
            accumulated_gradients: HashMap::new(),
            scaling_factor: 1.0,
        };

        assert_eq!(accumulation.steps, 4);
        assert_eq!(accumulation.current_step, 0);
    }

    #[test]
    fn test_autodiff_engine_creation() {
        let config = AutoDiffConfig {
            mode: DifferentiationMode::Automatic,
            precision: 1e-8,
            memory_optimization: MemoryOptimization::None,
            parallel: false,
            checkpointing: CheckpointingStrategy::None,
        };

        let engine = AutoDiffEngine::new(config);
        assert!(matches!(engine.config.mode, DifferentiationMode::Automatic));
    }

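    /// A numeric sketch of the gradient engine on f(x) = sum(x^2), whose exact
    /// gradient is 2x; the current implementation approximates it with central
    /// finite differences, so a loose tolerance is used.
    #[test]
    fn test_autodiff_engine_finite_difference_gradients() {
        let config = AutoDiffConfig {
            mode: DifferentiationMode::Forward,
            precision: 1e-8,
            memory_optimization: MemoryOptimization::None,
            parallel: false,
            checkpointing: CheckpointingStrategy::None,
        };
        let mut engine = AutoDiffEngine::new(config);

        let square = |x: &Array2<Float>| x.mapv(|v| v * v);
        let input = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let grads = engine.compute_gradients(&square, &input).unwrap();

        for (g, x) in grads.iter().zip(input.iter()) {
            assert!((g - 2.0 * x).abs() < 1e-3);
        }
    }
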
    #[test]
    fn test_dual_number() {
        let dual = DualNumber {
            real: Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap(),
            dual: Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap(),
        };

        assert_eq!(dual.real.dim(), (2, 2));
        assert_eq!(dual.dual.dim(), (2, 2));
    }

    #[test]
    fn test_control_strategy() {
        let strategy = ControlStrategy::Reinforcement {
            reward_function: RewardFunction::Performance {
                metric: "accuracy".to_string(),
            },
        };

        assert!(matches!(strategy, ControlStrategy::Reinforcement { .. }));
    }

    #[test]
    fn test_adaptation_record() {
        let record = AdaptationRecord {
            timestamp: std::time::SystemTime::now(),
            previous_params: HashMap::new(),
            new_params: HashMap::new(),
            performance_before: 0.8,
            performance_after: 0.85,
            trigger: AdaptationTrigger::PerformanceDrop { threshold: 0.1 },
        };

        assert_eq!(record.performance_before, 0.8);
        assert_eq!(record.performance_after, 0.85);
    }
}