use crate::analysis::types::*;
use crate::error::{IntegrateError, IntegrateResult};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::Rng;
use std::collections::HashMap;

/// Neural network model for predicting bifurcations in dynamical systems.
#[derive(Debug, Clone)]
pub struct BifurcationPredictionNetwork {
    /// Layer sizes, activations, and connectivity of the network
    pub architecture: NetworkArchitecture,
    /// Learned weights, biases, and normalization statistics
    pub model_parameters: ModelParameters,
    /// Training hyperparameters (epochs, batch size, loss, learning-rate schedule)
    pub training_config: super::training::TrainingConfiguration,
    /// Feature-extraction configuration used to build network inputs
    pub feature_extraction: super::features::FeatureExtraction,
    /// Metrics collected during training and validation
    pub performance_metrics: super::uncertainty::PerformanceMetrics,
    /// Configuration for uncertainty quantification of predictions
    pub uncertainty_quantification: super::uncertainty::UncertaintyQuantification,
}

/// Feed-forward architecture description for the prediction network.
#[derive(Debug, Clone)]
pub struct NetworkArchitecture {
    /// Number of input features
    pub input_size: usize,
    /// Sizes of the hidden layers, in order
    pub hidden_layers: Vec<usize>,
    /// Number of output units
    pub output_size: usize,
    /// Activation function applied after each layer
    pub activation_functions: Vec<ActivationFunction>,
    /// Dropout rate applied after each layer (0.0 disables dropout)
    pub dropoutrates: Vec<f64>,
    /// Whether batch normalization is applied after each layer
    pub batch_normalization: Vec<bool>,
    /// Optional skip connections between layers
    pub skip_connections: Vec<SkipConnection>,
}

/// Supported activation functions.
#[derive(Debug, Clone, Copy)]
pub enum ActivationFunction {
    /// max(0, x)
    ReLU,
    /// x for x > 0, alpha * x otherwise
    LeakyReLU(f64),
    /// Hyperbolic tangent
    Tanh,
    /// 1 / (1 + e^(-x))
    Sigmoid,
    /// Normalized exponential over the whole vector
    Softmax,
    /// x * sigmoid(x)
    Swish,
    /// Gaussian error linear unit (tanh approximation)
    GELU,
    /// x for x > 0, alpha * (e^x - 1) otherwise
    ELU(f64),
}

/// A skip (residual) connection between two layers.
#[derive(Debug, Clone)]
pub struct SkipConnection {
    /// Index of the source layer
    pub from_layer: usize,
    /// Index of the destination layer
    pub to_layer: usize,
    /// How the skipped activation is merged into the destination
    pub connection_type: ConnectionType,
}

/// How a skip connection is combined with the destination layer.
#[derive(Debug, Clone, Copy)]
pub enum ConnectionType {
    /// Element-wise addition (residual)
    Addition,
    /// Concatenation along the feature dimension
    Concatenation,
    /// Learned gating of the skipped activation
    Gated,
}

/// Trainable parameters of the network.
#[derive(Debug, Clone)]
pub struct ModelParameters {
    /// Weight matrix for each layer, shaped (inputs, outputs)
    pub weights: Vec<Array2<f64>>,
    /// Bias vector for each layer
    pub biases: Vec<Array1<f64>>,
    /// Batch-normalization parameters for layers that use it
    pub batch_norm_params: Vec<BatchNormParams>,
    /// Cached dropout masks (populated during training)
    pub dropout_masks: Vec<Array1<bool>>,
}

/// Batch-normalization parameters for a single layer.
#[derive(Debug, Clone)]
pub struct BatchNormParams {
    /// Learned scale (gamma)
    pub scale: Array1<f64>,
    /// Learned shift (beta)
    pub shift: Array1<f64>,
    /// Running mean used at inference time
    pub running_mean: Array1<f64>,
    /// Running variance used at inference time
    pub running_var: Array1<f64>,
}

/// Result of a bifurcation prediction.
#[derive(Debug, Clone)]
pub struct BifurcationPrediction {
    /// Most likely bifurcation type
    pub bifurcation_type: BifurcationType,
    /// Predicted bifurcation parameter value
    pub predicted_parameter: f64,
    /// Confidence in the classification (largest output score)
    pub confidence: f64,
    /// Raw network output vector
    pub raw_output: Array1<f64>,
    /// Optional uncertainty estimate for the prediction
    pub uncertainty_estimate: Option<UncertaintyEstimate>,
}

/// Decomposition of predictive uncertainty.
#[derive(Debug, Clone)]
pub struct UncertaintyEstimate {
    /// Model (epistemic) uncertainty
    pub epistemic_uncertainty: f64,
    /// Data (aleatoric) uncertainty
    pub aleatoric_uncertainty: f64,
    /// Combined uncertainty
    pub total_uncertainty: f64,
    /// Confidence interval (lower, upper)
    pub confidence_interval: (f64, f64),
}

impl BifurcationPredictionNetwork {
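    /// Creates a network with the given layer sizes.
    ///
    /// All layers default to ReLU activation, no dropout, and no batch
    /// normalization; weights and biases start at zero. Illustrative usage
    /// (a minimal sketch; not compiled as a doc test):
    ///
    /// ```ignore
    /// let network = BifurcationPredictionNetwork::new(4, vec![16, 8], 7);
    /// let output = network.forward(&Array1::zeros(4))?;
    /// assert_eq!(output.len(), 7);
    /// ```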
    pub fn new(input_size: usize, hidden_layers: Vec<usize>, output_size: usize) -> Self {
        let architecture = NetworkArchitecture {
            input_size,
            hidden_layers: hidden_layers.clone(),
            output_size,
            activation_functions: vec![ActivationFunction::ReLU; hidden_layers.len() + 1],
            dropoutrates: vec![0.0; hidden_layers.len() + 1],
            batch_normalization: vec![false; hidden_layers.len() + 1],
            skip_connections: Vec::new(),
        };

        let model_parameters = Self::initialize_parameters(&architecture);

        Self {
            architecture,
            model_parameters,
            training_config: super::training::TrainingConfiguration::default(),
            feature_extraction: super::features::FeatureExtraction::default(),
            performance_metrics: super::uncertainty::PerformanceMetrics::default(),
            uncertainty_quantification: super::uncertainty::UncertaintyQuantification::default(),
        }
    }

    fn initialize_parameters(arch: &NetworkArchitecture) -> ModelParameters {
        let mut weights = Vec::new();
        let mut biases = Vec::new();

        // Zero initialization is a placeholder: a trained model should use a
        // proper scheme (e.g. Xavier/He) or load learned parameters, since
        // all-zero weights leave gradient-based training symmetric.
        let mut prev_size = arch.input_size;
        for &layer_size in &arch.hidden_layers {
            weights.push(Array2::zeros((prev_size, layer_size)));
            biases.push(Array1::zeros(layer_size));
            prev_size = layer_size;
        }

        // Output layer
        weights.push(Array2::zeros((prev_size, arch.output_size)));
        biases.push(Array1::zeros(arch.output_size));

        ModelParameters {
            weights,
            biases,
            batch_norm_params: Vec::new(),
            dropout_masks: Vec::new(),
        }
    }

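    /// Runs a forward pass through the network.
    ///
    /// Each layer applies an affine transform followed by its activation
    /// function and, when configured, dropout. Dropout is applied whenever its
    /// rate is non-zero, so callers performing pure inference should keep the
    /// dropout rates at 0.0.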
    pub fn forward(&self, input: &Array1<f64>) -> IntegrateResult<Array1<f64>> {
        let mut activation = input.clone();

        for (i, (weights, bias)) in self
            .model_parameters
            .weights
            .iter()
            .zip(&self.model_parameters.biases)
            .enumerate()
        {
            // Affine transform: W^T x + b
            activation = weights.t().dot(&activation) + bias;

            // Layer activation
            activation = self.apply_activation_function(
                &activation,
                self.architecture.activation_functions[i],
            )?;

            // Optional dropout
            if self.architecture.dropoutrates[i] > 0.0 {
                activation = Self::apply_dropout(&activation, self.architecture.dropoutrates[i])?;
            }
        }

        Ok(activation)
    }

    fn apply_activation_function(
        &self,
        x: &Array1<f64>,
        func: ActivationFunction,
    ) -> IntegrateResult<Array1<f64>> {
        let result = match func {
            ActivationFunction::ReLU => x.mapv(|v| v.max(0.0)),
            ActivationFunction::LeakyReLU(alpha) => x.mapv(|v| if v > 0.0 { v } else { alpha * v }),
            ActivationFunction::Tanh => x.mapv(|v| v.tanh()),
            ActivationFunction::Sigmoid => x.mapv(|v| 1.0 / (1.0 + (-v).exp())),
            ActivationFunction::Softmax => {
                // Subtract the maximum before exponentiating for numerical stability
                let max = x.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                let exp_x = x.mapv(|v| (v - max).exp());
                let sum = exp_x.sum();
                exp_x / sum
            }
            ActivationFunction::Swish => x.mapv(|v| v / (1.0 + (-v).exp())),
            ActivationFunction::GELU => x.mapv(|v| {
                // Standard tanh-based approximation of GELU
                let inner = (2.0 / std::f64::consts::PI).sqrt() * (v + 0.044715 * v.powi(3));
                0.5 * v * (1.0 + inner.tanh())
            }),
            ActivationFunction::ELU(alpha) => {
                x.mapv(|v| if v > 0.0 { v } else { alpha * (v.exp() - 1.0) })
            }
        };

        Ok(result)
    }

    fn apply_dropout(x: &Array1<f64>, dropout_rate: f64) -> IntegrateResult<Array1<f64>> {
        if dropout_rate == 0.0 {
            return Ok(x.clone());
        }

        // Inverted dropout: surviving units are scaled by 1 / (1 - rate) so
        // the expected activation is unchanged.
        let mut rng = scirs2_core::random::rng();
        let mask: Array1<f64> = Array1::from_shape_fn(x.len(), |_| {
            if rng.random::<f64>() < dropout_rate {
                0.0
            } else {
                1.0 / (1.0 - dropout_rate)
            }
        });

        Ok(x * &mask)
    }

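    /// Trains the network for the configured number of epochs.
    ///
    /// Per-epoch loss is recorded and, when validation data is supplied, the
    /// validation loss is tracked as well. Note that `backward` is currently a
    /// placeholder, so parameters are not yet updated by this loop.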
    pub fn train(
        &mut self,
        training_data: &[(Array1<f64>, Array1<f64>)],
        validation_data: Option<&[(Array1<f64>, Array1<f64>)]>,
    ) -> IntegrateResult<()> {
        let mut training_metrics = Vec::new();
        let mut validation_metrics = Vec::new();

        for epoch in 0..self.training_config.epochs {
            let epoch_loss = self.train_epoch(training_data)?;

            let epoch_metric = super::uncertainty::EpochMetrics {
                epoch,
                loss: epoch_loss,
                accuracy: None,
                precision: None,
                recall: None,
                f1_score: None,
                learning_rate: self.get_current_learning_rate(epoch),
            };

            training_metrics.push(epoch_metric.clone());

            if let Some(val_data) = validation_data {
                let val_loss = self.evaluate(val_data)?;
                let val_metric = super::uncertainty::EpochMetrics {
                    epoch,
                    loss: val_loss,
                    accuracy: None,
                    precision: None,
                    recall: None,
                    f1_score: None,
                    learning_rate: epoch_metric.learning_rate,
                };
                validation_metrics.push(val_metric);
            }

            if self.should_early_stop(&training_metrics, &validation_metrics) {
                break;
            }
        }

        self.performance_metrics.training_metrics = training_metrics;
        self.performance_metrics.validation_metrics = validation_metrics;

        Ok(())
    }

    fn train_epoch(
        &mut self,
        training_data: &[(Array1<f64>, Array1<f64>)],
    ) -> IntegrateResult<f64> {
        let mut total_loss = 0.0;
        let mut num_batches = 0usize;
        let batch_size = self.training_config.batch_size;

        for batch_start in (0..training_data.len()).step_by(batch_size) {
            let batch_end = (batch_start + batch_size).min(training_data.len());
            let batch = &training_data[batch_start..batch_end];

            let batch_loss = self.train_batch(batch)?;
            total_loss += batch_loss;
            num_batches += 1;
        }

        // Average the per-batch losses over the actual number of batches
        Ok(total_loss / num_batches.max(1) as f64)
    }

    fn train_batch(&mut self, batch: &[(Array1<f64>, Array1<f64>)]) -> IntegrateResult<f64> {
        let mut total_loss = 0.0;

        for (input, target) in batch {
            let prediction = self.forward(input)?;
            let loss = self.calculate_loss(&prediction, target)?;
            total_loss += loss;

            // Backpropagation (currently a no-op placeholder)
            self.backward(&prediction, target, input)?;
        }

        Ok(total_loss / batch.len() as f64)
    }

    fn calculate_loss(
        &self,
        prediction: &Array1<f64>,
        target: &Array1<f64>,
    ) -> IntegrateResult<f64> {
        match self.training_config.loss_function {
            super::training::LossFunction::MSE => {
                let diff = prediction - target;
                Ok(diff.dot(&diff) / prediction.len() as f64)
            }
            super::training::LossFunction::CrossEntropy => {
                // Clip predictions away from 0 and 1 to avoid ln(0)
                let epsilon = 1e-15;
                let pred_clipped = prediction.mapv(|p| p.clamp(epsilon, 1.0 - epsilon));
                let loss = -target
                    .iter()
                    .zip(pred_clipped.iter())
                    .map(|(&t, &p)| t * p.ln())
                    .sum::<f64>();
                Ok(loss)
            }
            super::training::LossFunction::FocalLoss(alpha, gamma) => {
                let epsilon = 1e-15;
                let pred_clipped = prediction.mapv(|p| p.clamp(epsilon, 1.0 - epsilon));
                let loss = -alpha
                    * target
                        .iter()
                        .zip(pred_clipped.iter())
                        .map(|(&t, &p)| t * (1.0 - p).powf(gamma) * p.ln())
                        .sum::<f64>();
                Ok(loss)
            }
            super::training::LossFunction::HuberLoss(delta) => {
                let diff = prediction - target;
                let abs_diff = diff.mapv(|d| d.abs());
                let loss = abs_diff
                    .iter()
                    .map(|&d| {
                        if d <= delta {
                            0.5 * d * d
                        } else {
                            delta * d - 0.5 * delta * delta
                        }
                    })
                    .sum::<f64>();
                Ok(loss / prediction.len() as f64)
            }
            super::training::LossFunction::WeightedMSE => {
                // Per-sample weights are not configured here, so this currently
                // reduces to the unweighted MSE.
                let diff = prediction - target;
                Ok(diff.dot(&diff) / prediction.len() as f64)
            }
        }
    }

    fn backward(
        &mut self,
        _prediction: &Array1<f64>,
        _target: &Array1<f64>,
        _input: &Array1<f64>,
    ) -> IntegrateResult<()> {
        // Gradient computation and parameter updates are not implemented yet;
        // this placeholder lets the training loop run end to end.
        Ok(())
    }

    pub fn evaluate(&self, test_data: &[(Array1<f64>, Array1<f64>)]) -> IntegrateResult<f64> {
        let mut total_loss = 0.0;

        for (input, target) in test_data {
            let prediction = self.forward(input)?;
            let loss = self.calculate_loss(&prediction, target)?;
            total_loss += loss;
        }

        Ok(total_loss / test_data.len() as f64)
    }

    fn get_current_learning_rate(&self, epoch: usize) -> f64 {
        match &self.training_config.learning_rate {
            super::training::LearningRateSchedule::Constant(lr) => *lr,
            super::training::LearningRateSchedule::ExponentialDecay {
                initial_lr,
                decay_rate,
                decay_steps,
            } => initial_lr * decay_rate.powf(epoch as f64 / *decay_steps as f64),
            super::training::LearningRateSchedule::CosineAnnealing {
                initial_lr,
                min_lr,
                cycle_length,
            } => {
                let cycle_pos = (epoch % cycle_length) as f64 / *cycle_length as f64;
                min_lr
                    + (initial_lr - min_lr) * (1.0 + (cycle_pos * std::f64::consts::PI).cos()) / 2.0
            }
            super::training::LearningRateSchedule::StepDecay {
                initial_lr,
                drop_rate,
                epochs_drop,
            } => initial_lr * drop_rate.powf((epoch / epochs_drop) as f64),
            super::training::LearningRateSchedule::Adaptive { initial_lr, .. } => {
                // Adaptive adjustment is not handled here; fall back to the initial rate
                *initial_lr
            }
        }
    }

    fn should_early_stop(
        &self,
        _training_metrics: &[super::uncertainty::EpochMetrics],
        _validation_metrics: &[super::uncertainty::EpochMetrics],
    ) -> bool {
        if !self.training_config.early_stopping.enabled {
            return false;
        }

        // Patience-based stopping criteria are not implemented yet
        false
    }

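    /// Predicts the bifurcation type and parameter from a feature vector.
    ///
    /// The raw network output is interpreted as class scores for the
    /// bifurcation type, with the first component reused as the predicted
    /// parameter value; uncertainty estimation is left unset here.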
    pub fn predict_bifurcation(
        &self,
        features: &Array1<f64>,
    ) -> IntegrateResult<BifurcationPrediction> {
        let raw_output = self.forward(features)?;

        let bifurcation_type = self.classify_bifurcation_type(&raw_output)?;
        let confidence = self.calculate_confidence(&raw_output)?;
        let predicted_parameter = raw_output[0];

        Ok(BifurcationPrediction {
            bifurcation_type,
            predicted_parameter,
            confidence,
            raw_output,
            uncertainty_estimate: None,
        })
    }

    fn classify_bifurcation_type(&self, output: &Array1<f64>) -> IntegrateResult<BifurcationType> {
        // Index of the largest output score; NaN comparisons fall back to Equal
        // instead of panicking.
        let max_idx = output
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(idx, _)| idx)
            .unwrap_or(0);

        let bifurcation_type = match max_idx {
            0 => BifurcationType::Fold,
            1 => BifurcationType::Transcritical,
            2 => BifurcationType::Pitchfork,
            3 => BifurcationType::Hopf,
            4 => BifurcationType::PeriodDoubling,
            5 => BifurcationType::Homoclinic,
            _ => BifurcationType::Unknown,
        };

        Ok(bifurcation_type)
    }

    fn calculate_confidence(&self, output: &Array1<f64>) -> IntegrateResult<f64> {
        // Use the largest output score as the confidence; with a softmax output
        // this is the probability of the predicted class.
        let max_prob = output.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        Ok(max_prob)
    }
}
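
// A minimal usage sketch: builds a small network with the default (zero)
// parameters and checks the shape of the forward pass and the prediction
// output. This assumes the sibling `training`, `features`, and `uncertainty`
// modules provide the `Default` implementations used by `new`.
#[cfg(test)]
mod network_shape_tests {
    use super::*;

    #[test]
    fn forward_pass_has_output_size() {
        let network = BifurcationPredictionNetwork::new(4, vec![8, 8], 7);
        let input = Array1::zeros(4);

        // With zero weights and ReLU activations the output is all zeros,
        // but its length must match the configured output size.
        let output = network.forward(&input).expect("forward pass failed");
        assert_eq!(output.len(), 7);

        // The prediction carries the raw output and reuses its first
        // component as the predicted parameter.
        let prediction = network
            .predict_bifurcation(&input)
            .expect("prediction failed");
        assert_eq!(prediction.raw_output.len(), 7);
        assert_eq!(prediction.predicted_parameter, 0.0);
    }
}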