sklears_multioutput/
hierarchical.rs

//! Hierarchical classification and graph neural network models
//!
//! This module provides algorithms for hierarchical multi-label classification and
//! graph-based structured prediction tasks. It includes ontology-aware classifiers,
//! cost-sensitive hierarchical methods, and graph neural networks.

// Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::random::thread_rng;
use scirs2_core::random::RandNormal;
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};
use std::collections::HashMap;

/// Consistency enforcement strategies for hierarchical classification
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum ConsistencyEnforcement {
    /// Post-processing approach that corrects predictions after classification
    #[default]
    PostProcessing,
    /// Training-time approach that enforces constraints during optimization
    ConstrainedTraining,
    /// Bayesian inference approach using probabilistic dependencies
    BayesianInference,
}
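
// Illustrative sketch (hypothetical values, not part of the public API): under
// `PostProcessing`, a confident child prediction is propagated upward after the
// base classifiers run, so P(parent) is raised to at least P(child).
#[allow(dead_code)]
fn demo_postprocessing_rule() {
    // index 0 = parent, index 1 = child; ontology edge: 1 -> 0
    let mut probs: [f64; 2] = [0.3, 0.9];
    if probs[1] > 0.5 && probs[0] < probs[1] {
        probs[0] = probs[1]; // parent inherits the child's confidence
    }
    assert_eq!(probs, [0.9, 0.9]);
}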

/// Ontology-Aware Hierarchical Classifier
///
/// A hierarchical multi-label classifier that incorporates domain ontology knowledge
/// to ensure taxonomically consistent predictions. This method enforces that if a child
/// concept is predicted, its parent concepts are also predicted according to the
/// provided hierarchical structure.
///
/// # Examples
///
/// ```
/// use sklears_multioutput::{OntologyAwareClassifier, ConsistencyEnforcement};
/// use sklears_core::traits::{Predict, Fit};
/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
/// use scirs2_core::ndarray::array;
/// use std::collections::HashMap;
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
/// let y = array![[1, 0, 1, 0], [0, 1, 0, 1], [1, 1, 0, 0], [0, 0, 1, 1]];
///
/// // Define ontology: child -> parent relationships
/// let mut ontology = HashMap::new();
/// ontology.insert(2, vec![0]); // concept 2 is child of concept 0
/// ontology.insert(3, vec![1]); // concept 3 is child of concept 1
///
/// let classifier = OntologyAwareClassifier::new()
///     .ontology(ontology)
///     .consistency_enforcement(ConsistencyEnforcement::PostProcessing)
///     .base_classifier_learning_rate(0.01);
/// let trained_classifier = classifier.fit(&X.view(), &y).unwrap();
/// let predictions = trained_classifier.predict(&X.view()).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct OntologyAwareClassifier<S = Untrained> {
    state: S,
    ontology: HashMap<usize, Vec<usize>>,
    consistency_enforcement: ConsistencyEnforcement,
    base_classifier_learning_rate: Float,
    max_iterations: usize,
}

/// Trained state for OntologyAwareClassifier
#[derive(Debug, Clone)]
pub struct OntologyAwareClassifierTrained {
    weights: Array2<Float>,
    biases: Array1<Float>,
    ontology: HashMap<usize, Vec<usize>>,
    consistency_enforcement: ConsistencyEnforcement,
    n_features: usize,
    n_labels: usize,
}

impl OntologyAwareClassifier<Untrained> {
    /// Create a new OntologyAwareClassifier
    pub fn new() -> Self {
        Self {
            state: Untrained,
            ontology: HashMap::new(),
            consistency_enforcement: ConsistencyEnforcement::PostProcessing,
            base_classifier_learning_rate: 0.01,
            max_iterations: 100,
        }
    }

    /// Set the ontology (child -> parent relationships)
    pub fn ontology(mut self, ontology: HashMap<usize, Vec<usize>>) -> Self {
        self.ontology = ontology;
        self
    }

    /// Set the consistency enforcement strategy
    pub fn consistency_enforcement(mut self, enforcement: ConsistencyEnforcement) -> Self {
        self.consistency_enforcement = enforcement;
        self
    }

    /// Set the learning rate for the base classifier
    pub fn base_classifier_learning_rate(mut self, learning_rate: Float) -> Self {
        self.base_classifier_learning_rate = learning_rate;
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iterations(mut self, max_iterations: usize) -> Self {
        self.max_iterations = max_iterations;
        self
    }
}

impl Default for OntologyAwareClassifier<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for OntologyAwareClassifier<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, Array2<i32>> for OntologyAwareClassifier<Untrained> {
    type Fitted = OntologyAwareClassifier<OntologyAwareClassifierTrained>;

    fn fit(
        self,
        X: &ArrayView2<'_, Float>,
        y: &Array2<i32>,
    ) -> SklResult<OntologyAwareClassifier<OntologyAwareClassifierTrained>> {
        let (n_samples, n_features) = X.dim();
        let n_labels = y.ncols();

        if n_samples != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        // Initialize weights and biases
        let mut weights = Array2::<Float>::zeros((n_features, n_labels));
        let mut biases = Array1::<Float>::zeros(n_labels);

        // Train base classifiers for each label
        for iteration in 0..self.max_iterations {
            let mut total_loss = 0.0;

            for sample_idx in 0..n_samples {
                let x = X.row(sample_idx);
                let y_true = y.row(sample_idx);

                // Forward pass
                let logits = x.dot(&weights) + &biases;
                let probabilities = logits.mapv(|z| 1.0 / (1.0 + (-z).exp()));

                // Apply consistency constraints during training
                let consistent_probabilities = match self.consistency_enforcement {
                    ConsistencyEnforcement::ConstrainedTraining => {
                        self.enforce_consistency_training(&probabilities)?
                    }
                    _ => probabilities.clone(),
                };

                // Calculate loss and gradients
                for label_idx in 0..n_labels {
                    let y_label = y_true[label_idx] as Float;
                    let prob = consistent_probabilities[label_idx];
                    let error = prob - y_label;

                    total_loss += if y_label == 1.0 {
                        -prob.ln()
                    } else {
                        -(1.0 - prob).ln()
                    };

                    // Update weights and biases
                    for feat_idx in 0..n_features {
                        weights[[feat_idx, label_idx]] -=
                            self.base_classifier_learning_rate * error * x[feat_idx];
                    }
                    biases[label_idx] -= self.base_classifier_learning_rate * error;
                }
            }

            if iteration > 0 && total_loss < 1e-6 {
                break;
            }
        }

        Ok(OntologyAwareClassifier {
            state: OntologyAwareClassifierTrained {
                weights,
                biases,
                ontology: self.ontology,
                consistency_enforcement: self.consistency_enforcement,
                n_features,
                n_labels,
            },
            // The ontology has been moved into the trained state above.
            ontology: HashMap::new(),
            consistency_enforcement: self.consistency_enforcement,
            base_classifier_learning_rate: self.base_classifier_learning_rate,
            max_iterations: self.max_iterations,
        })
    }
}

impl OntologyAwareClassifier<Untrained> {
    /// Enforce consistency during training
    fn enforce_consistency_training(
        &self,
        probabilities: &Array1<Float>,
    ) -> SklResult<Array1<Float>> {
        let mut consistent_probs = probabilities.clone();

        // For training, we enforce that parent probabilities are at least as high as child probabilities
        for (&child, parents) in &self.ontology {
            if child < probabilities.len() {
                for &parent in parents {
                    if parent < probabilities.len() {
                        let child_prob = probabilities[child];
                        if consistent_probs[parent] < child_prob {
                            consistent_probs[parent] = child_prob;
                        }
                    }
                }
            }
        }

        Ok(consistent_probs)
    }
}

impl Predict<ArrayView2<'_, Float>, Array2<i32>>
    for OntologyAwareClassifier<OntologyAwareClassifierTrained>
{
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "X has different number of features than training data".to_string(),
            ));
        }

        let mut predictions = Array2::<i32>::zeros((n_samples, self.state.n_labels));

        for sample_idx in 0..n_samples {
            let x = X.row(sample_idx);

            // Forward pass
            let logits = x.dot(&self.state.weights) + &self.state.biases;
            let probabilities = logits.mapv(|z| 1.0 / (1.0 + (-z).exp()));

            // Apply consistency enforcement
            let consistent_probs = match self.state.consistency_enforcement {
                ConsistencyEnforcement::PostProcessing => {
                    self.enforce_consistency_postprocessing(&probabilities)?
                }
                ConsistencyEnforcement::BayesianInference => {
                    self.enforce_consistency_bayesian(&probabilities)?
                }
                _ => probabilities,
            };

            // Convert probabilities to binary predictions
            for label_idx in 0..self.state.n_labels {
                predictions[[sample_idx, label_idx]] = if consistent_probs[label_idx] > 0.5 {
                    1
                } else {
                    0
                };
            }
        }

        Ok(predictions)
    }
}

impl OntologyAwareClassifier<OntologyAwareClassifierTrained> {
    /// Get the learned weights
    pub fn weights(&self) -> &Array2<Float> {
        &self.state.weights
    }

    /// Get the learned biases
    pub fn biases(&self) -> &Array1<Float> {
        &self.state.biases
    }

    /// Get the ontology
    pub fn ontology(&self) -> &HashMap<usize, Vec<usize>> {
        &self.state.ontology
    }

    /// Enforce consistency using post-processing
    fn enforce_consistency_postprocessing(
        &self,
        probabilities: &Array1<Float>,
    ) -> SklResult<Array1<Float>> {
        let mut consistent_probs = probabilities.clone();

        // Enforce hierarchical constraints: if child is predicted, parent must be predicted
        for (&child, parents) in &self.state.ontology {
            if child < probabilities.len() && probabilities[child] > 0.5 {
                for &parent in parents {
                    if parent < probabilities.len() {
                        consistent_probs[parent] =
                            consistent_probs[parent].max(probabilities[child]);
                    }
                }
            }
        }

        Ok(consistent_probs)
    }

    /// Enforce consistency using Bayesian inference
    fn enforce_consistency_bayesian(
        &self,
        probabilities: &Array1<Float>,
    ) -> SklResult<Array1<Float>> {
        let mut consistent_probs = probabilities.clone();

        // Heuristic Bayesian-style consistency: a confident child raises its parent
        for (&child, parents) in &self.state.ontology {
            if child < probabilities.len() {
                let child_prob = probabilities[child];
                for &parent in parents {
                    if parent < probabilities.len() {
                        // Propagate evidence upward with a damping factor of 0.8:
                        // P(parent) <- max(P(parent), 0.8 * P(child))
                        consistent_probs[parent] = consistent_probs[parent].max(child_prob * 0.8);
                    }
                }
            }
        }

        Ok(consistent_probs)
    }
}
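
// Worked example (hypothetical values) of the Bayesian-style propagation above:
// a child at P = 0.9 pulls its parent up to max(0.4, 0.8 * 0.9) = 0.72, enough
// to cross the 0.5 decision threshold even though the parent's own classifier
// was unsure.
#[allow(dead_code)]
fn demo_bayesian_propagation() {
    let parent_prob: f64 = 0.4;
    let child_prob: f64 = 0.9;
    let updated_parent = parent_prob.max(child_prob * 0.8);
    assert!((updated_parent - 0.72).abs() < 1e-12);
    assert!(updated_parent > 0.5); // the parent is now predicted alongside the child
}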

/// Cost strategy for hierarchical classification
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum CostStrategy {
    /// Uniform misclassification costs
    #[default]
    Uniform,
    /// Distance-based costs (closer nodes have lower cost)
    DistanceBased,
    /// Custom cost matrix
    Custom,
}
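
// Illustrative sketch of the DistanceBased strategy used by
// `generate_cost_matrix` below: diagonal entries (correct labels) carry cost
// 1.0 and off-diagonal confusions a reduced cost of 0.5. The helper
// `toy_distance_cost_matrix` is hypothetical and exists only for illustration.
#[allow(dead_code)]
fn toy_distance_cost_matrix(n_labels: usize) -> Array2<Float> {
    Array2::from_shape_fn((n_labels, n_labels), |(i, j)| if i == j { 1.0 } else { 0.5 })
}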

/// Cost-Sensitive Hierarchical Classifier
///
/// A hierarchical multi-label classifier that incorporates misclassification costs
/// and hierarchical relationships to optimize cost-sensitive predictions. This method
/// can handle different cost strategies including uniform costs, distance-based costs,
/// and custom cost matrices.
///
/// # Examples
///
/// ```
/// use sklears_multioutput::{CostSensitiveHierarchicalClassifier, CostStrategy};
/// use sklears_core::traits::{Predict, Fit};
/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
/// use scirs2_core::ndarray::array;
/// use std::collections::HashMap;
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
/// let y = array![[1, 0, 1, 0], [0, 1, 0, 1], [1, 1, 0, 0], [0, 0, 1, 1]];
///
/// // Define hierarchical structure and costs
/// let mut hierarchy = HashMap::new();
/// hierarchy.insert(0, vec![2, 3]); // concept 0 has children 2, 3
/// hierarchy.insert(1, vec![2, 3]); // concept 1 has children 2, 3
///
/// let classifier = CostSensitiveHierarchicalClassifier::new()
///     .hierarchy(hierarchy)
///     .cost_strategy(CostStrategy::DistanceBased)
///     .learning_rate(0.01);
/// let trained_classifier = classifier.fit(&X.view(), &y).unwrap();
/// let predictions = trained_classifier.predict(&X.view()).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct CostSensitiveHierarchicalClassifier<S = Untrained> {
    state: S,
    hierarchy: HashMap<usize, Vec<usize>>,
    cost_strategy: CostStrategy,
    cost_matrix: Option<Array2<Float>>,
    learning_rate: Float,
    max_iterations: usize,
    lambda_hierarchy: Float,
    lambda_cost: Float,
}

/// Trained state for CostSensitiveHierarchicalClassifier
#[derive(Debug, Clone)]
pub struct CostSensitiveHierarchicalClassifierTrained {
    weights: Array2<Float>,
    /// Hierarchical structure
    hierarchy: HashMap<usize, Vec<usize>>,
    cost_strategy: CostStrategy,
    cost_matrix: Option<Array2<Float>>,
    n_features: usize,
    n_labels: usize,
    lambda_hierarchy: Float,
    lambda_cost: Float,
}

impl CostSensitiveHierarchicalClassifier<Untrained> {
    /// Create a new CostSensitiveHierarchicalClassifier
    pub fn new() -> Self {
        Self {
            state: Untrained,
            hierarchy: HashMap::new(),
            cost_strategy: CostStrategy::Uniform,
            cost_matrix: None,
            learning_rate: 0.01,
            max_iterations: 100,
            lambda_hierarchy: 1.0,
            lambda_cost: 1.0,
        }
    }

    /// Set the hierarchical structure (parent -> children relationships)
    pub fn hierarchy(mut self, hierarchy: HashMap<usize, Vec<usize>>) -> Self {
        self.hierarchy = hierarchy;
        self
    }

    /// Set the cost strategy
    pub fn cost_strategy(mut self, strategy: CostStrategy) -> Self {
        self.cost_strategy = strategy;
        self
    }

    /// Set a custom cost matrix
    pub fn cost_matrix(mut self, cost_matrix: Array2<Float>) -> Self {
        self.cost_matrix = Some(cost_matrix);
        self
    }

    /// Set the learning rate
    pub fn learning_rate(mut self, learning_rate: Float) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iterations(mut self, max_iterations: usize) -> Self {
        self.max_iterations = max_iterations;
        self
    }

    /// Set the hierarchical constraint weight
    pub fn lambda_hierarchy(mut self, lambda: Float) -> Self {
        self.lambda_hierarchy = lambda;
        self
    }

    /// Set the cost constraint weight
    pub fn lambda_cost(mut self, lambda: Float) -> Self {
        self.lambda_cost = lambda;
        self
    }
}

impl Default for CostSensitiveHierarchicalClassifier<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for CostSensitiveHierarchicalClassifier<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, Array2<i32>> for CostSensitiveHierarchicalClassifier<Untrained> {
    type Fitted = CostSensitiveHierarchicalClassifier<CostSensitiveHierarchicalClassifierTrained>;

    fn fit(
        self,
        X: &ArrayView2<'_, Float>,
        y: &Array2<i32>,
    ) -> SklResult<CostSensitiveHierarchicalClassifier<CostSensitiveHierarchicalClassifierTrained>>
    {
        let (n_samples, n_features) = X.dim();
        let n_labels = y.ncols();

        if n_samples != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        // Initialize cost matrix if not provided
        let cost_matrix = match &self.cost_matrix {
            Some(matrix) => matrix.clone(),
            None => self.generate_cost_matrix(n_labels)?,
        };

        // Initialize weights
        let mut weights = Array2::<Float>::zeros((n_features, n_labels));

        // Training loop with cost-sensitive and hierarchical constraints
        for _iteration in 0..self.max_iterations {
            for sample_idx in 0..n_samples {
                let x = X.row(sample_idx);
                let y_true = y.row(sample_idx);

                // Forward pass
                let scores = x.dot(&weights);
                let probabilities = scores.mapv(|z| 1.0 / (1.0 + (-z).exp()));

                // Calculate gradients with cost-sensitive and hierarchical terms
                for label_idx in 0..n_labels {
                    let y_label = y_true[label_idx] as Float;
                    let prob = probabilities[label_idx];

                    // Standard logistic loss gradient
                    let mut gradient = prob - y_label;

                    // Add cost-sensitive term
                    let cost_weight = cost_matrix[[label_idx, label_idx]];
                    gradient *= cost_weight * self.lambda_cost;

                    // Add hierarchical constraint term
                    gradient += self.lambda_hierarchy
                        * self.hierarchical_gradient(label_idx, &probabilities, &y_true)?;

                    // Update weights
                    for feat_idx in 0..n_features {
                        weights[[feat_idx, label_idx]] -=
                            self.learning_rate * gradient * x[feat_idx];
                    }
                }
            }
        }

        Ok(CostSensitiveHierarchicalClassifier {
            state: CostSensitiveHierarchicalClassifierTrained {
                weights,
                hierarchy: self.hierarchy,
                cost_strategy: self.cost_strategy,
                cost_matrix: Some(cost_matrix),
                n_features,
                n_labels,
                lambda_hierarchy: self.lambda_hierarchy,
                lambda_cost: self.lambda_cost,
            },
            // The hierarchy and cost matrix have been moved into the trained state above.
            hierarchy: HashMap::new(),
            cost_strategy: self.cost_strategy,
            cost_matrix: None,
            learning_rate: self.learning_rate,
            max_iterations: self.max_iterations,
            lambda_hierarchy: self.lambda_hierarchy,
            lambda_cost: self.lambda_cost,
        })
    }
}

impl CostSensitiveHierarchicalClassifier<Untrained> {
    /// Generate cost matrix based on strategy
    fn generate_cost_matrix(&self, n_labels: usize) -> SklResult<Array2<Float>> {
        match self.cost_strategy {
            CostStrategy::Uniform => Ok(Array2::eye(n_labels)),
            CostStrategy::DistanceBased => {
                let mut cost_matrix = Array2::<Float>::zeros((n_labels, n_labels));
                // Simple distance-based costs (can be enhanced with actual hierarchy distances)
                for i in 0..n_labels {
                    for j in 0..n_labels {
                        cost_matrix[[i, j]] = if i == j { 1.0 } else { 0.5 };
                    }
                }
                Ok(cost_matrix)
            }
            CostStrategy::Custom => Err(SklearsError::InvalidInput(
                "Custom cost strategy requires a cost matrix".to_string(),
            )),
        }
    }

    /// Calculate hierarchical gradient term
    fn hierarchical_gradient(
        &self,
        label_idx: usize,
        probabilities: &Array1<Float>,
        y_true: &ArrayView1<i32>,
    ) -> SklResult<Float> {
        let mut gradient = 0.0;

        // If this label has children, enforce that children can't be more probable than the parent
        if let Some(children) = self.hierarchy.get(&label_idx) {
            for &child in children {
                if child < probabilities.len() {
                    let parent_prob = probabilities[label_idx];
                    let child_prob = probabilities[child];
                    let child_true = y_true[child] as Float;

                    // Penalty when a true child's probability exceeds the parent's
                    if child_true > 0.5 && child_prob > parent_prob {
                        gradient += child_prob - parent_prob;
                    }
                }
            }
        }

        // If this label is a child, enforce consistency with its parents
        for (&parent, children) in &self.hierarchy {
            if children.contains(&label_idx) && parent < probabilities.len() {
                let parent_prob = probabilities[parent];
                let child_prob = probabilities[label_idx];
                let label_true = y_true[label_idx] as Float;

                // Opposing penalty when this true child's probability exceeds its parent's
                if label_true > 0.5 && child_prob > parent_prob {
                    gradient -= child_prob - parent_prob;
                }
            }
        }

        Ok(gradient)
    }
}
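
// Worked example (hypothetical values) of the hierarchical penalty computed by
// `hierarchical_gradient` above: with a true child at P = 0.8 and its parent
// at P = 0.6, the violated constraint contributes |0.8 - 0.6| = 0.2 to the
// gradient term for the labels involved.
#[allow(dead_code)]
fn demo_hierarchical_penalty() {
    let parent_prob: f64 = 0.6;
    let child_prob: f64 = 0.8;
    let child_is_true = true;
    let penalty = if child_is_true && child_prob > parent_prob {
        child_prob - parent_prob
    } else {
        0.0
    };
    assert!((penalty - 0.2).abs() < 1e-12);
}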

impl Predict<ArrayView2<'_, Float>, Array2<i32>>
    for CostSensitiveHierarchicalClassifier<CostSensitiveHierarchicalClassifierTrained>
{
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "X has different number of features than training data".to_string(),
            ));
        }

        let mut predictions = Array2::<i32>::zeros((n_samples, self.state.n_labels));

        for sample_idx in 0..n_samples {
            let x = X.row(sample_idx);

            // Forward pass
            let scores = x.dot(&self.state.weights);
            let probabilities = scores.mapv(|z| 1.0 / (1.0 + (-z).exp()));

            // Apply hierarchical constraints and cost-sensitive thresholding
            let final_predictions = self.apply_constraints(&probabilities)?;

            for label_idx in 0..self.state.n_labels {
                predictions[[sample_idx, label_idx]] = final_predictions[label_idx];
            }
        }

        Ok(predictions)
    }
}

impl CostSensitiveHierarchicalClassifier<CostSensitiveHierarchicalClassifierTrained> {
    /// Get the learned weights
    pub fn weights(&self) -> &Array2<Float> {
        &self.state.weights
    }

    /// Get the cost matrix
    pub fn cost_matrix(&self) -> Option<&Array2<Float>> {
        self.state.cost_matrix.as_ref()
    }

    /// Apply constraints to get final predictions
    fn apply_constraints(&self, probabilities: &Array1<Float>) -> SklResult<Array1<i32>> {
        let mut binary_predictions = Array1::<i32>::zeros(probabilities.len());

        // Convert probabilities to binary predictions with cost-sensitive thresholds
        for i in 0..probabilities.len() {
            let threshold = if let Some(cost_matrix) = &self.state.cost_matrix {
                // Adjust threshold based on cost
                let cost = cost_matrix[[i, i]];
                0.5 / cost.max(0.1) // Higher cost = lower threshold
            } else {
                0.5
            };

            binary_predictions[i] = if probabilities[i] > threshold { 1 } else { 0 };
        }

        // Enforce hierarchical constraints
        for (&parent, children) in &self.state.hierarchy {
            if parent < binary_predictions.len() {
                // If any child is predicted, parent must be predicted
                let mut any_child_predicted = false;
                for &child in children {
                    if child < binary_predictions.len() && binary_predictions[child] == 1 {
                        any_child_predicted = true;
                        break;
                    }
                }
                if any_child_predicted {
                    binary_predictions[parent] = 1;
                }
            }
        }

        Ok(binary_predictions)
    }
}
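
// Worked example (hypothetical numbers) of the cost-sensitive threshold used
// in `apply_constraints` above: threshold = 0.5 / max(cost, 0.1), so a label
// with diagonal cost 2.0 is accepted already at P > 0.25, while cost 1.0 keeps
// the usual 0.5 cutoff.
#[allow(dead_code)]
fn demo_cost_sensitive_threshold() {
    let threshold = |cost: f64| 0.5 / cost.max(0.1);
    assert!((threshold(2.0) - 0.25).abs() < 1e-12);
    assert!((threshold(1.0) - 0.5).abs() < 1e-12);
}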

// Graph Neural Networks for Structured Output Prediction

/// Aggregation functions for Graph Neural Networks
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AggregationFunction {
    /// Mean aggregation
    Mean,
    /// Sum aggregation
    Sum,
    /// Max aggregation
    Max,
    /// Attention-based aggregation
    Attention,
}
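
// Illustrative sketch of the aggregation functions on a toy neighborhood of
// three incoming messages (hypothetical data). Attention is omitted because it
// additionally requires learned attention scores.
#[allow(dead_code)]
fn demo_aggregation_functions() {
    let messages: [f64; 3] = [1.0, 4.0, 1.0];
    let sum: f64 = messages.iter().sum();
    let mean = sum / messages.len() as f64;
    let max = messages.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    assert_eq!(sum, 6.0);
    assert_eq!(mean, 2.0);
    assert_eq!(max, 4.0);
}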

/// Message passing variants for Graph Neural Networks
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MessagePassingVariant {
    /// Graph Convolutional Network (GCN)
    GCN,
    /// Graph Attention Network (GAT)
    GAT,
    /// GraphSAGE
    GraphSAGE,
    /// Graph Isomorphism Network (GIN)
    GIN,
}

/// Graph Neural Network for Structured Output Prediction
///
/// A graph neural network implementation for multi-output prediction tasks where
/// the outputs have structural relationships represented as a graph. This method
/// can leverage node features, edge information, and graph topology to make
/// predictions that respect the underlying graph structure.
///
/// # Examples
///
/// ```
/// use sklears_multioutput::{GraphNeuralNetwork, MessagePassingVariant, AggregationFunction};
/// use sklears_core::traits::{Predict, Fit};
/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
/// use scirs2_core::ndarray::array;
///
/// // Node features and adjacency matrix
/// let node_features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0], [1.0, 3.0]];
/// let adjacency = array![[0, 1, 1, 0, 0], [1, 0, 1, 1, 0], [1, 1, 0, 0, 1],
///                        [0, 1, 0, 0, 1], [0, 0, 1, 1, 0]];
/// let node_labels = array![[1, 0, 1], [0, 1, 0], [1, 1, 0], [0, 0, 1], [1, 0, 0]];
///
/// let gnn = GraphNeuralNetwork::new()
///     .hidden_dim(16)
///     .num_layers(2)
///     .message_passing_variant(MessagePassingVariant::GCN)
///     .aggregation_function(AggregationFunction::Mean);
/// let trained_gnn = gnn.fit_graph(&adjacency.view(), &node_features.view(), &node_labels).unwrap();
/// let predictions = trained_gnn.predict_graph(&adjacency.view(), &node_features.view()).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct GraphNeuralNetwork<S = Untrained> {
    state: S,
    hidden_dim: usize,
    num_layers: usize,
    message_passing_variant: MessagePassingVariant,
    aggregation_function: AggregationFunction,
    learning_rate: Float,
    max_iter: usize,
    dropout_rate: Float,
    random_state: Option<u64>,
}

/// Trained state for GraphNeuralNetwork
#[derive(Debug, Clone)]
pub struct GraphNeuralNetworkTrained {
    /// Layer weights: Vec of (input_dim x output_dim) matrices
    layer_weights: Vec<Array2<Float>>,
    /// Layer biases: Vec of output_dim vectors
    layer_biases: Vec<Array1<Float>>,
    /// Attention weights for GAT (if applicable)
    attention_weights: Option<Vec<Array2<Float>>>,
    /// Model configuration
    hidden_dim: usize,
    num_layers: usize,
    message_passing_variant: MessagePassingVariant,
    aggregation_function: AggregationFunction,
    n_features: usize,
    n_outputs: usize,
    dropout_rate: Float,
}

impl GraphNeuralNetwork<Untrained> {
    /// Create a new GraphNeuralNetwork instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            hidden_dim: 32,
            num_layers: 2,
            message_passing_variant: MessagePassingVariant::GCN,
            aggregation_function: AggregationFunction::Mean,
            learning_rate: 0.01,
            max_iter: 100,
            dropout_rate: 0.0,
            random_state: None,
        }
    }

    /// Set the hidden dimension
    pub fn hidden_dim(mut self, hidden_dim: usize) -> Self {
        self.hidden_dim = hidden_dim;
        self
    }

    /// Set the number of layers
    pub fn num_layers(mut self, num_layers: usize) -> Self {
        self.num_layers = num_layers;
        self
    }

    /// Set the message passing variant
    pub fn message_passing_variant(mut self, variant: MessagePassingVariant) -> Self {
        self.message_passing_variant = variant;
        self
    }

    /// Set the aggregation function
    pub fn aggregation_function(mut self, function: AggregationFunction) -> Self {
        self.aggregation_function = function;
        self
    }

    /// Set the learning rate
    pub fn learning_rate(mut self, learning_rate: Float) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the dropout rate
    pub fn dropout_rate(mut self, dropout_rate: Float) -> Self {
        self.dropout_rate = dropout_rate;
        self
    }
    /// Set the random state for reproducible results
    ///
    /// Note: the current simplified trainer does not yet consult this value;
    /// it is stored for forward compatibility.
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }
}

impl Default for GraphNeuralNetwork<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for GraphNeuralNetwork<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

/// Fit method for Graph Neural Networks with graph structure
impl GraphNeuralNetwork<Untrained> {
    /// Fit the GNN using graph structure, node features, and node labels
    pub fn fit_graph(
        self,
        adjacency: &ArrayView2<'_, i32>,
        node_features: &ArrayView2<'_, Float>,
        node_labels: &Array2<i32>,
    ) -> SklResult<GraphNeuralNetwork<GraphNeuralNetworkTrained>> {
        let (n_nodes, n_features) = node_features.dim();
        let n_outputs = node_labels.ncols();

        if adjacency.dim() != (n_nodes, n_nodes) {
            return Err(SklearsError::InvalidInput(
                "Adjacency matrix must be n_nodes x n_nodes".to_string(),
            ));
        }

        if node_labels.nrows() != n_nodes {
            return Err(SklearsError::InvalidInput(
                "Node labels must have same number of rows as nodes".to_string(),
            ));
        }

        // Initialize parameters. Note: `self.random_state` is not consulted by
        // this simplified trainer; a thread-local RNG seeds the weights.
        let mut rng_instance = thread_rng();
        let (layer_weights, layer_biases, attention_weights) =
            self.initialize_gnn_parameters(n_features, n_outputs, &mut rng_instance)?;

        // Training loop (simplified: weight decay only, no true backpropagation)
        let mut weights = layer_weights;
        let biases = layer_biases;

        for _iteration in 0..self.max_iter {
            // Forward pass
            let (node_embeddings, _) = self.forward_pass_graph(
                adjacency,
                node_features,
                &weights,
                &biases,
                &attention_weights,
            )?;

            // Binary predictions from the current embeddings; unused by this
            // simplified trainer but illustrates the decision rule.
            let _predictions = node_embeddings.mapv(|v| if v > 0.0 { 1 } else { 0 });

            // Simple gradient update (a full implementation would backpropagate)
            for weight in &mut weights {
                for i in 0..weight.nrows() {
                    for j in 0..weight.ncols() {
                        weight[[i, j]] *= 0.999; // Simple weight decay
                    }
                }
            }
        }

        let trained_state = GraphNeuralNetworkTrained {
            layer_weights: weights,
            layer_biases: biases,
            attention_weights,
            hidden_dim: self.hidden_dim,
            num_layers: self.num_layers,
            message_passing_variant: self.message_passing_variant,
            aggregation_function: self.aggregation_function,
            n_features,
            n_outputs,
            dropout_rate: self.dropout_rate,
        };

        Ok(GraphNeuralNetwork {
            state: trained_state,
            hidden_dim: self.hidden_dim,
            num_layers: self.num_layers,
            message_passing_variant: self.message_passing_variant,
            aggregation_function: self.aggregation_function,
            learning_rate: self.learning_rate,
            max_iter: self.max_iter,
            dropout_rate: self.dropout_rate,
            random_state: self.random_state,
        })
    }

    /// Initialize GNN parameters
    #[allow(clippy::type_complexity)]
    fn initialize_gnn_parameters(
        &self,
        n_features: usize,
        n_outputs: usize,
        rng: &mut scirs2_core::random::CoreRandom,
    ) -> SklResult<(
        Vec<Array2<Float>>,
        Vec<Array1<Float>>,
        Option<Vec<Array2<Float>>>,
    )> {
        let mut layer_weights = Vec::new();
        let mut layer_biases = Vec::new();
        let mut attention_weights = None;

        // Input layer
        let input_dim = match self.message_passing_variant {
            MessagePassingVariant::GraphSAGE => n_features * 2, // Concatenated features
            _ => n_features,
        };

        // Hidden layers
        let hidden_dim = match self.message_passing_variant {
            MessagePassingVariant::GraphSAGE => self.hidden_dim * 2, // Concatenated features
            _ => self.hidden_dim,
        };

        // Initialize weights for each layer
        for layer_idx in 0..self.num_layers {
            let (in_dim, out_dim) = if layer_idx == 0 {
                (input_dim, self.hidden_dim)
            } else if layer_idx == self.num_layers - 1 {
                (hidden_dim, n_outputs)
            } else {
                (hidden_dim, self.hidden_dim)
            };

            let normal_dist = RandNormal::new(0.0, (2.0 / in_dim as Float).sqrt()).unwrap();
            let mut input_weight = Array2::<Float>::zeros((in_dim, out_dim));
            for i in 0..in_dim {
                for j in 0..out_dim {
                    input_weight[[i, j]] = rng.sample(normal_dist);
                }
            }
            let bias = Array1::<Float>::zeros(out_dim);

            layer_weights.push(input_weight);
            layer_biases.push(bias);
        }

        // Initialize attention weights for GAT
        if self.message_passing_variant == MessagePassingVariant::GAT {
            let mut att_weights = Vec::new();
            for layer_idx in 0..self.num_layers {
                let att_dim = if layer_idx == 0 {
                    n_features
                } else {
                    self.hidden_dim
                };
                let att_normal_dist = RandNormal::new(0.0, 0.1).unwrap();
                let mut attention_weight = Array2::<Float>::zeros((att_dim * 2, 1));
                for i in 0..(att_dim * 2) {
                    attention_weight[[i, 0]] = rng.sample(att_normal_dist);
                }
                att_weights.push(attention_weight);
            }
            attention_weights = Some(att_weights);
        }

        Ok((layer_weights, layer_biases, attention_weights))
    }

    /// Forward pass through the graph neural network
    fn forward_pass_graph(
        &self,
        adjacency: &ArrayView2<'_, i32>,
        node_features: &ArrayView2<'_, Float>,
        weights: &[Array2<Float>],
        biases: &[Array1<Float>],
        attention_weights: &Option<Vec<Array2<Float>>>,
    ) -> SklResult<(Array2<Float>, Vec<Array2<Float>>)> {
        let mut current_embeddings = node_features.to_owned();
        let mut layer_outputs = Vec::new();

        for layer_idx in 0..self.num_layers {
            let layer_output = match self.message_passing_variant {
                MessagePassingVariant::GCN => self.gcn_layer(
                    &current_embeddings,
                    adjacency,
                    &weights[layer_idx],
                    &biases[layer_idx],
                )?,
                MessagePassingVariant::GAT => {
                    let att_weights = attention_weights.as_ref().unwrap();
                    self.gat_layer(
                        &current_embeddings,
                        adjacency,
                        &weights[layer_idx],
                        &biases[layer_idx],
                        &att_weights[layer_idx],
                    )?
                }
                MessagePassingVariant::GraphSAGE => self.graphsage_layer(
                    &current_embeddings,
                    adjacency,
                    &weights[layer_idx],
                    &biases[layer_idx],
                )?,
                MessagePassingVariant::GIN => self.gin_layer(
                    &current_embeddings,
                    adjacency,
                    &weights[layer_idx],
                    &biases[layer_idx],
                )?,
            };

            current_embeddings = layer_output.clone();
            layer_outputs.push(layer_output);
        }

        Ok((current_embeddings, layer_outputs))
    }

    /// Graph Convolutional Network layer
    fn gcn_layer(
        &self,
        node_embeddings: &Array2<Float>,
        adjacency: &ArrayView2<'_, i32>,
        weights: &Array2<Float>,
        bias: &Array1<Float>,
    ) -> SklResult<Array2<Float>> {
        let n_nodes = node_embeddings.nrows();
        let mut output = Array2::<Float>::zeros((n_nodes, weights.ncols()));

        for i in 0..n_nodes {
            let mut aggregated = Array1::<Float>::zeros(node_embeddings.ncols());
            let mut degree = 0;

            // Aggregate neighbor features
            for j in 0..n_nodes {
                if adjacency[[i, j]] == 1 {
                    aggregated += &node_embeddings.row(j).to_owned();
                    degree += 1;
                }
            }

            // Add self-loop
            aggregated += &node_embeddings.row(i).to_owned();
            degree += 1;

            // Normalize by degree
            if degree > 0 {
                aggregated /= degree as Float;
            }

            // Apply linear transformation
            let transformed = aggregated.dot(weights) + bias;
            let activated = transformed.mapv(|x| x.max(0.0)); // ReLU activation

            output.row_mut(i).assign(&activated);
        }

        Ok(output)
    }

    /// Graph Attention Network layer (simplified)
    fn gat_layer(
        &self,
        node_embeddings: &Array2<Float>,
        adjacency: &ArrayView2<'_, i32>,
        weights: &Array2<Float>,
        bias: &Array1<Float>,
        attention_weights: &Array2<Float>,
    ) -> SklResult<Array2<Float>> {
        let n_nodes = node_embeddings.nrows();
        let mut output = Array2::<Float>::zeros((n_nodes, weights.ncols()));

        for i in 0..n_nodes {
            let mut attention_scores = Array1::<Float>::zeros(n_nodes);
            let mut valid_neighbors = Vec::new();

            // Calculate attention scores
            for j in 0..n_nodes {
                if adjacency[[i, j]] == 1 || i == j {
                    // Concatenate node features for attention computation
                    let concat_features = Array1::from_iter(
                        node_embeddings
                            .row(i)
                            .iter()
                            .chain(node_embeddings.row(j).iter())
                            .cloned(),
                    );

                    if concat_features.len() == attention_weights.nrows() {
                        let score = concat_features.dot(&attention_weights.column(0));
                        attention_scores[j] = score.exp();
                        valid_neighbors.push(j);
                    }
                }
            }

            // Normalize attention scores
            let total_attention: Float = valid_neighbors.iter().map(|&j| attention_scores[j]).sum();
            if total_attention > 0.0 {
                for &j in &valid_neighbors {
                    attention_scores[j] /= total_attention;
                }
            }

            // Aggregate features using attention weights
            let mut aggregated = Array1::<Float>::zeros(node_embeddings.ncols());
            for &j in &valid_neighbors {
                let weighted_features = &node_embeddings.row(j).to_owned() * attention_scores[j];
                aggregated += &weighted_features;
            }

            // Apply linear transformation
            let transformed = aggregated.dot(weights) + bias;
            let activated = transformed.mapv(|x| x.max(0.0)); // ReLU activation

            output.row_mut(i).assign(&activated);
        }

        Ok(output)
    }

    /// GraphSAGE layer (simplified)
    fn graphsage_layer(
        &self,
        node_embeddings: &Array2<Float>,
        adjacency: &ArrayView2<'_, i32>,
        weights: &Array2<Float>,
        bias: &Array1<Float>,
    ) -> SklResult<Array2<Float>> {
        let n_nodes = node_embeddings.nrows();
        let embedding_dim = node_embeddings.ncols();
        let output_dim = weights.ncols();
        let mut output = Array2::<Float>::zeros((n_nodes, output_dim));

        for i in 0..n_nodes {
            // Aggregate neighbor features
            let mut neighbor_sum = Array1::<Float>::zeros(embedding_dim);
            let mut neighbor_count = 0;

            for j in 0..n_nodes {
                if adjacency[[i, j]] == 1 && i != j {
                    neighbor_sum += &node_embeddings.row(j).to_owned();
                    neighbor_count += 1;
                }
            }

            // Average pooling of neighbors
            if neighbor_count > 0 {
                neighbor_sum /= neighbor_count as Float;
            }

            // Concatenate self and neighbor representations
            let self_features = node_embeddings.row(i).to_owned();
            let concatenated =
                Array1::from_iter(self_features.iter().chain(neighbor_sum.iter()).cloned());

            // Apply linear transformation (note: weights must match the concatenated dimension)
            if concatenated.len() == weights.nrows() {
                let transformed = concatenated.dot(weights) + bias;
                let activated = transformed.mapv(|x| x.max(0.0)); // ReLU activation
                output.row_mut(i).assign(&activated);
            }
        }

        Ok(output)
    }

    /// Graph Isomorphism Network layer
    fn gin_layer(
        &self,
        node_embeddings: &Array2<Float>,
        adjacency: &ArrayView2<'_, i32>,
        weights: &Array2<Float>,
        bias: &Array1<Float>,
    ) -> SklResult<Array2<Float>> {
        let n_nodes = node_embeddings.nrows();
        let mut output = Array2::<Float>::zeros((n_nodes, weights.ncols()));
        let epsilon = 0.0; // Learnable parameter, simplified as 0

        for i in 0..n_nodes {
            // Sum neighbor features
            let mut neighbor_sum = Array1::<Float>::zeros(node_embeddings.ncols());

            for j in 0..n_nodes {
                if adjacency[[i, j]] == 1 && i != j {
                    neighbor_sum += &node_embeddings.row(j).to_owned();
                }
            }

            // GIN update: (1 + epsilon) * h_i + sum(h_j for j in neighbors)
            let updated = &node_embeddings.row(i).to_owned() * (1.0 + epsilon) + &neighbor_sum;

            // Apply MLP (simplified as single linear layer)
            let transformed = updated.dot(weights) + bias;
            let activated = transformed.mapv(|x| x.max(0.0)); // ReLU activation

            output.row_mut(i).assign(&activated);
        }

        Ok(output)
    }
}
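
// Worked example (hypothetical values) of the GIN update used by `gin_layer`:
// h_i' = MLP((1 + epsilon) * h_i + sum_{j in N(i)} h_j). With epsilon = 0 and
// the MLP omitted, only the aggregation step is shown here.
#[allow(dead_code)]
fn demo_gin_aggregation() {
    let h_i: f64 = 1.0;
    let neighbor_features = [2.0_f64, 3.0];
    let epsilon = 0.0;
    let aggregated = (1.0 + epsilon) * h_i + neighbor_features.iter().sum::<f64>();
    assert_eq!(aggregated, 6.0);
}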

impl GraphNeuralNetwork<GraphNeuralNetworkTrained> {
    /// Predict node labels using the trained GNN
    pub fn predict_graph(
        &self,
        adjacency: &ArrayView2<'_, i32>,
        node_features: &ArrayView2<'_, Float>,
    ) -> SklResult<Array2<i32>> {
        let (n_nodes, n_features) = node_features.dim();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "Node features have different dimensionality than training data".to_string(),
            ));
        }

        if adjacency.dim() != (n_nodes, n_nodes) {
            return Err(SklearsError::InvalidInput(
                "Adjacency matrix must be n_nodes x n_nodes".to_string(),
            ));
        }

        // Forward pass
        let (final_embeddings, _) = self.forward_pass_trained(adjacency, node_features)?;

        // Convert to binary predictions. The output layer is sigmoid-activated,
        // so the decision threshold is 0.5 (a threshold of 0.0 would mark every
        // label positive).
        let predictions = final_embeddings.mapv(|p| if p > 0.5 { 1 } else { 0 });

        Ok(predictions)
    }

    /// Get the hidden dimension
    pub fn hidden_dim(&self) -> usize {
        self.state.hidden_dim
    }

    /// Get the number of layers
    pub fn num_layers(&self) -> usize {
        self.state.num_layers
    }

    /// Forward pass for trained model
    fn forward_pass_trained(
        &self,
        adjacency: &ArrayView2<'_, i32>,
        node_features: &ArrayView2<'_, Float>,
    ) -> SklResult<(Array2<Float>, Vec<Array2<Float>>)> {
        let mut current_embeddings = node_features.to_owned();
        let mut layer_outputs = Vec::new();

        for layer_idx in 0..self.state.num_layers {
            let layer_output = match self.state.message_passing_variant {
                MessagePassingVariant::GCN => {
                    self.gcn_layer_trained(&current_embeddings, adjacency, layer_idx)?
                }
                MessagePassingVariant::GAT => {
                    self.gat_layer_trained(&current_embeddings, adjacency, layer_idx)?
                }
                MessagePassingVariant::GraphSAGE => {
                    self.graphsage_layer_trained(&current_embeddings, adjacency, layer_idx)?
                }
                MessagePassingVariant::GIN => {
                    self.gin_layer_trained(&current_embeddings, adjacency, layer_idx)?
                }
            };

            current_embeddings = layer_output.clone();
            layer_outputs.push(layer_output);
        }

        Ok((current_embeddings, layer_outputs))
    }
1361
1362    /// GCN layer for trained model
1363    fn gcn_layer_trained(
1364        &self,
1365        node_embeddings: &Array2<Float>,
1366        adjacency: &ArrayView2<'_, i32>,
1367        layer_idx: usize,
1368    ) -> SklResult<Array2<Float>> {
1369        let weights = &self.state.layer_weights[layer_idx];
1370        let bias = &self.state.layer_biases[layer_idx];
1371        let n_nodes = node_embeddings.nrows();
1372        let mut output = Array2::<Float>::zeros((n_nodes, weights.ncols()));
1373
1374        for i in 0..n_nodes {
1375            let mut aggregated = Array1::<Float>::zeros(node_embeddings.ncols());
1376            let mut degree = 0;
1377
1378            // Aggregate neighbor features
1379            for j in 0..n_nodes {
1380                if adjacency[[i, j]] == 1 {
1381                    aggregated += &node_embeddings.row(j).to_owned();
1382                    degree += 1;
1383                }
1384            }
1385
1386            // Add self-loop
1387            aggregated += &node_embeddings.row(i).to_owned();
1388            degree += 1;
1389
1390            // Normalize by degree
1391            if degree > 0 {
1392                aggregated /= degree as Float;
1393            }
1394
1395            // Apply linear transformation
1396            let transformed = aggregated.dot(weights) + bias;
1397            let activated = if layer_idx == self.state.num_layers - 1 {
1398                // Output layer: sigmoid activation for binary classification
1399                transformed.mapv(|x| 1.0 / (1.0 + (-x).exp()))
1400            } else {
1401                // Hidden layers: ReLU activation
1402                transformed.mapv(|x| x.max(0.0))
1403            };
1404
1405            output.row_mut(i).assign(&activated);
1406        }
1407
1408        Ok(output)
1409    }
1410
1411    /// GAT layer for trained model
1412    fn gat_layer_trained(
1413        &self,
1414        node_embeddings: &Array2<Float>,
1415        adjacency: &ArrayView2<'_, i32>,
1416        layer_idx: usize,
1417    ) -> SklResult<Array2<Float>> {
1418        let weights = &self.state.layer_weights[layer_idx];
1419        let bias = &self.state.layer_biases[layer_idx];
1420        let attention_weights = self.state.attention_weights.as_ref().unwrap();
1421        let att_weights = &attention_weights[layer_idx];
1422
1423        let n_nodes = node_embeddings.nrows();
1424        let mut output = Array2::<Float>::zeros((n_nodes, weights.ncols()));
1425
1426        for i in 0..n_nodes {
1427            let mut attention_scores = Array1::<Float>::zeros(n_nodes);
1428            let mut valid_neighbors = Vec::new();
1429
1430            // Calculate attention scores
1431            for j in 0..n_nodes {
1432                if adjacency[[i, j]] == 1 || i == j {
1433                    let concat_features = Array1::from_iter(
1434                        node_embeddings
1435                            .row(i)
1436                            .iter()
1437                            .chain(node_embeddings.row(j).iter())
1438                            .cloned(),
1439                    );
1440
1441                    if concat_features.len() == att_weights.nrows() {
1442                        let score = concat_features.dot(&att_weights.column(0));
1443                        attention_scores[j] = score.exp();
1444                        valid_neighbors.push(j);
1445                    }
1446                }
1447            }
1448
1449            // Normalize attention scores
1450            let total_attention: Float = valid_neighbors.iter().map(|&j| attention_scores[j]).sum();
1451            if total_attention > 0.0 {
1452                for &j in &valid_neighbors {
1453                    attention_scores[j] /= total_attention;
1454                }
1455            }
1456
1457            // Aggregate features using attention weights
1458            let mut aggregated = Array1::<Float>::zeros(node_embeddings.ncols());
1459            for &j in &valid_neighbors {
1460                let weighted_features = &node_embeddings.row(j).to_owned() * attention_scores[j];
1461                aggregated += &weighted_features;
1462            }
1463
1464            // Apply linear transformation
1465            let transformed = aggregated.dot(weights) + bias;
1466            let activated = if layer_idx == self.state.num_layers - 1 {
1467                transformed.mapv(|x| 1.0 / (1.0 + (-x).exp()))
1468            } else {
1469                transformed.mapv(|x| x.max(0.0))
1470            };
1471
1472            output.row_mut(i).assign(&activated);
1473        }
1474
1475        Ok(output)
1476    }
1477
1478    /// GraphSAGE layer for trained model
    fn graphsage_layer_trained(
        &self,
        node_embeddings: &Array2<Float>,
        adjacency: &ArrayView2<'_, i32>,
        layer_idx: usize,
    ) -> SklResult<Array2<Float>> {
        let weights = &self.state.layer_weights[layer_idx];
        let bias = &self.state.layer_biases[layer_idx];
        let n_nodes = node_embeddings.nrows();
        let embedding_dim = node_embeddings.ncols();
        let output_dim = weights.ncols();
        let mut output = Array2::<Float>::zeros((n_nodes, output_dim));

        for i in 0..n_nodes {
            // Aggregate neighbor features (self-loops excluded)
            let mut neighbor_sum = Array1::<Float>::zeros(embedding_dim);
            let mut neighbor_count = 0;

            for j in 0..n_nodes {
                if adjacency[[i, j]] == 1 && i != j {
                    neighbor_sum += &node_embeddings.row(j).to_owned();
                    neighbor_count += 1;
                }
            }

            // Mean pooling of neighbors; isolated nodes keep a zero vector
            if neighbor_count > 0 {
                neighbor_sum /= neighbor_count as Float;
            }

            // Concatenate self and neighbor representations
            let self_features = node_embeddings.row(i).to_owned();
            let concatenated =
                Array1::from_iter(self_features.iter().chain(neighbor_sum.iter()).cloned());

            // Apply linear transformation; rows whose concatenated dimension
            // does not match the layer weights are skipped defensively
            if concatenated.len() == weights.nrows() {
                let transformed = concatenated.dot(weights) + bias;
                let activated = if layer_idx == self.state.num_layers - 1 {
                    // Output layer: sigmoid activation
                    transformed.mapv(|x| 1.0 / (1.0 + (-x).exp()))
                } else {
                    // Hidden layers: ReLU activation
                    transformed.mapv(|x| x.max(0.0))
                };
                output.row_mut(i).assign(&activated);
            }
        }

        Ok(output)
    }

    /// GIN layer for trained model
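    ///
    /// Implements the GIN update
    /// `h_i' = σ(W · ((1 + ε) · h_i + Σ_{j ∈ N(i)} h_j) + b)` with ε fixed at
    /// 0 (the GIN-0 variant); a single linear layer stands in for the MLP of
    /// the original GIN formulation.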
    fn gin_layer_trained(
        &self,
        node_embeddings: &Array2<Float>,
        adjacency: &ArrayView2<'_, i32>,
        layer_idx: usize,
    ) -> SklResult<Array2<Float>> {
        let weights = &self.state.layer_weights[layer_idx];
        let bias = &self.state.layer_biases[layer_idx];
        let n_nodes = node_embeddings.nrows();
        let mut output = Array2::<Float>::zeros((n_nodes, weights.ncols()));
        // GIN epsilon fixed at 0 (the GIN-0 variant); a learnable epsilon is
        // not implemented
        let epsilon = 0.0;

        for i in 0..n_nodes {
            // Sum neighbor features (self-loops excluded)
            let mut neighbor_sum = Array1::<Float>::zeros(node_embeddings.ncols());

            for j in 0..n_nodes {
                if adjacency[[i, j]] == 1 && i != j {
                    neighbor_sum += &node_embeddings.row(j).to_owned();
                }
            }

            // GIN update: (1 + epsilon) * h_i + sum of neighbor embeddings
            let updated = &node_embeddings.row(i).to_owned() * (1.0 + epsilon) + &neighbor_sum;

            // Apply linear transformation (a single layer stands in for the
            // MLP of the original GIN formulation)
            let transformed = updated.dot(weights) + bias;
            let activated = if layer_idx == self.state.num_layers - 1 {
                // Output layer: sigmoid activation
                transformed.mapv(|x| 1.0 / (1.0 + (-x).exp()))
            } else {
                // Hidden layers: ReLU activation
                transformed.mapv(|x| x.max(0.0))
            };

            output.row_mut(i).assign(&activated);
        }

        Ok(output)
    }
}

// Tests for Graph Neural Networks
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
    use scirs2_core::ndarray::array;

    #[test]
    fn test_gnn_basic_functionality() {
        let node_features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
        let adjacency = array![[0, 1, 0], [1, 0, 1], [0, 1, 0]];
        let node_labels = array![[1, 0], [0, 1], [1, 1]];

        let gnn = GraphNeuralNetwork::new()
            .hidden_dim(4)
            .num_layers(2)
            .max_iter(5);

        let trained_gnn = gnn
            .fit_graph(&adjacency.view(), &node_features.view(), &node_labels)
            .unwrap();

        let predictions = trained_gnn
            .predict_graph(&adjacency.view(), &node_features.view())
            .unwrap();

        assert_eq!(predictions.dim(), (3, 2));
        assert!(predictions.iter().all(|&x| x == 0 || x == 1));
    }

    #[test]
    fn test_gnn_different_variants() {
        let node_features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
        let adjacency = array![[0, 1, 0], [1, 0, 1], [0, 1, 0]];
        let node_labels = array![[1, 0], [0, 1], [1, 1]];

        // Test GCN
        let gnn_gcn = GraphNeuralNetwork::new()
            .message_passing_variant(MessagePassingVariant::GCN)
            .max_iter(5);
        let trained_gcn = gnn_gcn
            .fit_graph(&adjacency.view(), &node_features.view(), &node_labels)
            .unwrap();

        // Test GAT
        let gnn_gat = GraphNeuralNetwork::new()
            .message_passing_variant(MessagePassingVariant::GAT)
            .max_iter(5);
        let trained_gat = gnn_gat
            .fit_graph(&adjacency.view(), &node_features.view(), &node_labels)
            .unwrap();

        // Test GraphSAGE
        let gnn_sage = GraphNeuralNetwork::new()
            .message_passing_variant(MessagePassingVariant::GraphSAGE)
            .max_iter(5);
        let trained_sage = gnn_sage
            .fit_graph(&adjacency.view(), &node_features.view(), &node_labels)
            .unwrap();

        assert_eq!(
            trained_gcn.state.message_passing_variant,
            MessagePassingVariant::GCN
        );
        assert_eq!(
            trained_gat.state.message_passing_variant,
            MessagePassingVariant::GAT
        );
        assert_eq!(
            trained_sage.state.message_passing_variant,
            MessagePassingVariant::GraphSAGE
        );
    }
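
    // Minimal smoke test for the GIN variant, mirroring the variant tests
    // above: fitting should succeed, the variant should be recorded, and
    // predictions should be binary with the expected shape.
    #[test]
    fn test_gnn_gin_variant() {
        let node_features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
        let adjacency = array![[0, 1, 0], [1, 0, 1], [0, 1, 0]];
        let node_labels = array![[1, 0], [0, 1], [1, 1]];

        let gnn_gin = GraphNeuralNetwork::new()
            .message_passing_variant(MessagePassingVariant::GIN)
            .max_iter(5);
        let trained_gin = gnn_gin
            .fit_graph(&adjacency.view(), &node_features.view(), &node_labels)
            .unwrap();

        assert_eq!(
            trained_gin.state.message_passing_variant,
            MessagePassingVariant::GIN
        );

        let predictions = trained_gin
            .predict_graph(&adjacency.view(), &node_features.view())
            .unwrap();
        assert_eq!(predictions.dim(), (3, 2));
        assert!(predictions.iter().all(|&x| x == 0 || x == 1));
    }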

    #[test]
    fn test_gnn_parameter_settings() {
        let gnn = GraphNeuralNetwork::new()
            .hidden_dim(16)
            .num_layers(3)
            .learning_rate(0.001)
            .max_iter(50)
            .dropout_rate(0.1);

        assert_eq!(gnn.hidden_dim, 16);
        assert_eq!(gnn.num_layers, 3);
        assert!((gnn.learning_rate - 0.001).abs() < 1e-10);
        assert_eq!(gnn.max_iter, 50);
        assert!((gnn.dropout_rate - 0.1).abs() < 1e-10);
    }

    #[test]
    fn test_gnn_default_settings() {
        let gnn = GraphNeuralNetwork::new();

        assert_eq!(gnn.hidden_dim, 32);
        assert_eq!(gnn.num_layers, 2);
        assert_eq!(gnn.message_passing_variant, MessagePassingVariant::GCN);
        assert_eq!(gnn.aggregation_function, AggregationFunction::Mean);
    }

    #[test]
    fn test_gnn_builder_pattern() {
        let gnn1 = GraphNeuralNetwork::new();
        let gnn2 = GraphNeuralNetwork::new();

        assert_eq!(gnn1.hidden_dim, gnn2.hidden_dim);
        assert_eq!(gnn1.num_layers, gnn2.num_layers);

        let gnn3 = GraphNeuralNetwork::new().max_iter(1);
        assert_eq!(gnn3.max_iter, 1);
    }

    #[test]
    fn test_message_passing_variants() {
        assert_eq!(MessagePassingVariant::GCN, MessagePassingVariant::GCN);
        assert_ne!(MessagePassingVariant::GCN, MessagePassingVariant::GAT);

        let variants = [
            MessagePassingVariant::GCN,
            MessagePassingVariant::GAT,
            MessagePassingVariant::GraphSAGE,
            MessagePassingVariant::GIN,
        ];

        let gnn1 = GraphNeuralNetwork::new()
            .message_passing_variant(variants[0])
            .hidden_dim(8)
            .max_iter(3);

        let gnn2 = GraphNeuralNetwork::new()
            .message_passing_variant(variants[1])
            .hidden_dim(8)
            .max_iter(3);

        assert_eq!(gnn1.message_passing_variant, MessagePassingVariant::GCN);
        assert_eq!(gnn2.message_passing_variant, MessagePassingVariant::GAT);
    }

    #[test]
    fn test_gnn_larger_graph() {
        let node_features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0], [1.0, 3.0]];
        let adjacency = array![
            [0, 1, 1, 0, 0],
            [1, 0, 1, 1, 0],
            [1, 1, 0, 0, 1],
            [0, 1, 0, 0, 1],
            [0, 0, 1, 1, 0]
        ];
        let node_labels = array![[1, 0, 1], [0, 1, 0], [1, 1, 0], [0, 0, 1], [1, 0, 0]];

        let gnn = GraphNeuralNetwork::new()
            .hidden_dim(10)
            .num_layers(2)
            .message_passing_variant(MessagePassingVariant::GCN)
            .max_iter(10);

        let trained_gnn = gnn
            .fit_graph(&adjacency.view(), &node_features.view(), &node_labels)
            .unwrap();

        let predictions = trained_gnn
            .predict_graph(&adjacency.view(), &node_features.view())
            .unwrap();

        assert_eq!(predictions.dim(), (5, 3));
        assert!(predictions.iter().all(|&x| x == 0 || x == 1));
        assert_eq!(trained_gnn.hidden_dim(), 10);
    }

    #[test]
    fn test_aggregation_functions() {
        assert_ne!(AggregationFunction::Mean, AggregationFunction::Max);
        assert_eq!(AggregationFunction::Sum, AggregationFunction::Sum);
        assert_ne!(MessagePassingVariant::GraphSAGE, MessagePassingVariant::GIN);
    }

    // Smoke test: a seeded model should train and predict with the expected
    // shape (determinism across runs is not asserted here)
    #[test]
    fn test_gnn_reproducibility() {
        let node_features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
        let adjacency = array![[0, 1, 0], [1, 0, 1], [0, 1, 0]];
        let node_labels = array![[1, 0], [0, 1], [1, 1]];

        let gnn = GraphNeuralNetwork::new()
            .hidden_dim(4)
            .num_layers(2)
            .max_iter(5)
            .random_state(42);

        let trained_gnn = gnn
            .fit_graph(&adjacency.view(), &node_features.view(), &node_labels)
            .unwrap();

        let predictions = trained_gnn
            .predict_graph(&adjacency.view(), &node_features.view())
            .unwrap();

        assert_eq!(predictions.dim(), (3, 2));
    }

    // Longer seeded training run on the larger graph
    #[test]
    fn test_gnn_edge_cases() {
        let node_features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0], [1.0, 3.0]];
        let adjacency = array![
            [0, 1, 1, 0, 0],
            [1, 0, 1, 1, 0],
            [1, 1, 0, 0, 1],
            [0, 1, 0, 0, 1],
            [0, 0, 1, 1, 0]
        ];
        let node_labels = array![[1, 0, 1], [0, 1, 0], [1, 1, 0], [0, 0, 1], [1, 0, 0]];

        let gnn = GraphNeuralNetwork::new()
            .hidden_dim(10)
            .num_layers(2)
            .message_passing_variant(MessagePassingVariant::GCN)
            .max_iter(15)
            .random_state(42);

        let trained_gnn = gnn
            .fit_graph(&adjacency.view(), &node_features.view(), &node_labels)
            .unwrap();
        let predictions = trained_gnn
            .predict_graph(&adjacency.view(), &node_features.view())
            .unwrap();

        assert_eq!(predictions.dim(), (5, 3));
        assert!(predictions.iter().all(|&x| x == 0 || x == 1));
        assert_eq!(trained_gnn.hidden_dim(), 10);
    }
}