use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::random::thread_rng;
use scirs2_core::random::RandNormal;
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};
use std::collections::HashMap;

/// Strategy used to keep parent/child label probabilities consistent with the
/// ontology.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum ConsistencyEnforcement {
    /// Adjust probabilities after prediction (default).
    #[default]
    PostProcessing,
    /// Enforce consistency while the model is being trained.
    ConstrainedTraining,
    /// Propagate child evidence to parents with a damped, Bayesian-style update.
    BayesianInference,
}

/// Multi-label classifier that respects an ontology mapping each child label
/// index to its parent label indices.
#[derive(Debug, Clone)]
pub struct OntologyAwareClassifier<S = Untrained> {
    state: S,
    ontology: HashMap<usize, Vec<usize>>,
    consistency_enforcement: ConsistencyEnforcement,
    base_classifier_learning_rate: Float,
    max_iterations: usize,
}

/// Trained state of [`OntologyAwareClassifier`].
#[derive(Debug, Clone)]
pub struct OntologyAwareClassifierTrained {
    weights: Array2<Float>,
    biases: Array1<Float>,
    ontology: HashMap<usize, Vec<usize>>,
    consistency_enforcement: ConsistencyEnforcement,
    n_features: usize,
    n_labels: usize,
}

impl OntologyAwareClassifier<Untrained> {
    /// Create a classifier with default settings.
    pub fn new() -> Self {
        Self {
            state: Untrained,
            ontology: HashMap::new(),
            consistency_enforcement: ConsistencyEnforcement::PostProcessing,
            base_classifier_learning_rate: 0.01,
            max_iterations: 100,
        }
    }

    /// Set the ontology as a map from child label index to parent label indices.
    pub fn ontology(mut self, ontology: HashMap<usize, Vec<usize>>) -> Self {
        self.ontology = ontology;
        self
    }

    /// Set the consistency enforcement strategy.
    pub fn consistency_enforcement(mut self, enforcement: ConsistencyEnforcement) -> Self {
        self.consistency_enforcement = enforcement;
        self
    }

    /// Set the learning rate of the underlying logistic classifiers.
    pub fn base_classifier_learning_rate(mut self, learning_rate: Float) -> Self {
        self.base_classifier_learning_rate = learning_rate;
        self
    }

    /// Set the maximum number of training iterations.
    pub fn max_iterations(mut self, max_iterations: usize) -> Self {
        self.max_iterations = max_iterations;
        self
    }
}

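// Illustrative builder usage (a sketch; the ontology shape below is an
// assumption for the example, not part of the API):
//
//     let mut ontology = HashMap::new();
//     ontology.insert(2, vec![0]); // label 2's parent is label 0
//     let clf = OntologyAwareClassifier::new()
//         .ontology(ontology)
//         .consistency_enforcement(ConsistencyEnforcement::ConstrainedTraining)
//         .base_classifier_learning_rate(0.05)
//         .max_iterations(200);
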
impl Default for OntologyAwareClassifier<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for OntologyAwareClassifier<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, Array2<i32>> for OntologyAwareClassifier<Untrained> {
    type Fitted = OntologyAwareClassifier<OntologyAwareClassifierTrained>;

    fn fit(
        self,
        X: &ArrayView2<'_, Float>,
        y: &Array2<i32>,
    ) -> SklResult<OntologyAwareClassifier<OntologyAwareClassifierTrained>> {
        let (n_samples, n_features) = X.dim();
        let n_labels = y.ncols();

        if n_samples != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        let mut weights = Array2::<Float>::zeros((n_features, n_labels));
        let mut biases = Array1::<Float>::zeros(n_labels);

        for iteration in 0..self.max_iterations {
            let mut total_loss = 0.0;

            for sample_idx in 0..n_samples {
                let x = X.row(sample_idx);
                let y_true = y.row(sample_idx);

                // One independent logistic model per label.
                let logits = x.dot(&weights) + &biases;
                let probabilities = logits.mapv(|z| 1.0 / (1.0 + (-z).exp()));

                let consistent_probabilities = match self.consistency_enforcement {
                    ConsistencyEnforcement::ConstrainedTraining => {
                        self.enforce_consistency_training(&probabilities)?
                    }
                    _ => probabilities.clone(),
                };

                for label_idx in 0..n_labels {
                    let y_label = y_true[label_idx] as Float;
                    let prob = consistent_probabilities[label_idx];
                    let error = prob - y_label;

                    // Accumulate the binary cross-entropy loss.
                    total_loss += if y_label == 1.0 {
                        -prob.ln()
                    } else {
                        -(1.0 - prob).ln()
                    };

                    // Gradient-descent update for this label's weights and bias.
                    for feat_idx in 0..n_features {
                        weights[[feat_idx, label_idx]] -=
                            self.base_classifier_learning_rate * error * x[feat_idx];
                    }
                    biases[label_idx] -= self.base_classifier_learning_rate * error;
                }
            }

            // Early stopping once the total loss is negligible.
            if iteration > 0 && total_loss < 1e-6 {
                break;
            }
        }

        Ok(OntologyAwareClassifier {
            state: OntologyAwareClassifierTrained {
                weights,
                biases,
                ontology: self.ontology,
                consistency_enforcement: self.consistency_enforcement,
                n_features,
                n_labels,
            },
            ontology: HashMap::new(),
            consistency_enforcement: self.consistency_enforcement,
            base_classifier_learning_rate: self.base_classifier_learning_rate,
            max_iterations: self.max_iterations,
        })
    }
}

impl OntologyAwareClassifier<Untrained> {
    /// Lift each parent's probability to at least its children's, so that a
    /// confident child never exceeds its parent during training.
    fn enforce_consistency_training(
        &self,
        probabilities: &Array1<Float>,
    ) -> SklResult<Array1<Float>> {
        let mut consistent_probs = probabilities.clone();

        for (&child, parents) in &self.ontology {
            if child < probabilities.len() {
                for &parent in parents {
                    if parent < probabilities.len() {
                        let child_prob = probabilities[child];
                        if consistent_probs[parent] < child_prob {
                            consistent_probs[parent] = child_prob;
                        }
                    }
                }
            }
        }

        Ok(consistent_probs)
    }
}

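// Worked example of the rule above (illustrative numbers): with ontology
// {1: [0]} and probabilities [p0 = 0.3, p1 = 0.7], the parent is lifted to
// max(0.3, 0.7) = 0.7, so the returned vector is [0.7, 0.7].
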
impl Predict<ArrayView2<'_, Float>, Array2<i32>>
    for OntologyAwareClassifier<OntologyAwareClassifierTrained>
{
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "X has a different number of features than the training data".to_string(),
            ));
        }

        let mut predictions = Array2::<i32>::zeros((n_samples, self.state.n_labels));

        for sample_idx in 0..n_samples {
            let x = X.row(sample_idx);

            let logits = x.dot(&self.state.weights) + &self.state.biases;
            let probabilities = logits.mapv(|z| 1.0 / (1.0 + (-z).exp()));

            let consistent_probs = match self.state.consistency_enforcement {
                ConsistencyEnforcement::PostProcessing => {
                    self.enforce_consistency_postprocessing(&probabilities)?
                }
                ConsistencyEnforcement::BayesianInference => {
                    self.enforce_consistency_bayesian(&probabilities)?
                }
                _ => probabilities,
            };

            // Threshold each label's probability at 0.5.
            for label_idx in 0..self.state.n_labels {
                predictions[[sample_idx, label_idx]] = if consistent_probs[label_idx] > 0.5 {
                    1
                } else {
                    0
                };
            }
        }

        Ok(predictions)
    }
}

impl OntologyAwareClassifier<OntologyAwareClassifierTrained> {
    /// Learned weight matrix of shape `(n_features, n_labels)`.
    pub fn weights(&self) -> &Array2<Float> {
        &self.state.weights
    }

    /// Learned per-label biases.
    pub fn biases(&self) -> &Array1<Float> {
        &self.state.biases
    }

    /// Ontology used during training.
    pub fn ontology(&self) -> &HashMap<usize, Vec<usize>> {
        &self.state.ontology
    }

    /// Post-processing rule: when a child is predicted positive (> 0.5), raise
    /// each of its parents to at least the child's probability.
    fn enforce_consistency_postprocessing(
        &self,
        probabilities: &Array1<Float>,
    ) -> SklResult<Array1<Float>> {
        let mut consistent_probs = probabilities.clone();

        for (&child, parents) in &self.state.ontology {
            if child < probabilities.len() && probabilities[child] > 0.5 {
                for &parent in parents {
                    if parent < probabilities.len() {
                        consistent_probs[parent] =
                            consistent_probs[parent].max(probabilities[child]);
                    }
                }
            }
        }

        Ok(consistent_probs)
    }

    /// Bayesian-style rule: propagate a damped fraction (0.8) of each child's
    /// probability up to its parents, regardless of the decision threshold.
    fn enforce_consistency_bayesian(
        &self,
        probabilities: &Array1<Float>,
    ) -> SklResult<Array1<Float>> {
        let mut consistent_probs = probabilities.clone();

        for (&child, parents) in &self.state.ontology {
            if child < probabilities.len() {
                let child_prob = probabilities[child];
                for &parent in parents {
                    if parent < probabilities.len() {
                        consistent_probs[parent] = consistent_probs[parent].max(child_prob * 0.8);
                    }
                }
            }
        }

        Ok(consistent_probs)
    }
}

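// Contrast between the two prediction-time rules (illustrative numbers): with
// a child probability of 0.9 and a parent probability of 0.5, post-processing
// lifts the parent to 0.9, while the Bayesian-style rule lifts it to
// max(0.5, 0.9 * 0.8) = 0.72, a damped propagation.
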
/// Strategy for building the misclassification cost matrix.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum CostStrategy {
    /// Uniform costs (identity matrix, default).
    #[default]
    Uniform,
    /// Simple distance-based costs: 1.0 on the diagonal, 0.5 elsewhere.
    DistanceBased,
    /// User-supplied cost matrix (see [`CostSensitiveHierarchicalClassifier::cost_matrix`]).
    Custom,
}

/// Hierarchical multi-label classifier with per-label misclassification costs
/// and a hierarchy-consistency regularization term.
#[derive(Debug, Clone)]
pub struct CostSensitiveHierarchicalClassifier<S = Untrained> {
    state: S,
    hierarchy: HashMap<usize, Vec<usize>>,
    cost_strategy: CostStrategy,
    cost_matrix: Option<Array2<Float>>,
    learning_rate: Float,
    max_iterations: usize,
    lambda_hierarchy: Float,
    lambda_cost: Float,
}

/// Trained state of [`CostSensitiveHierarchicalClassifier`].
#[derive(Debug, Clone)]
pub struct CostSensitiveHierarchicalClassifierTrained {
    weights: Array2<Float>,
    hierarchy: HashMap<usize, Vec<usize>>,
    cost_strategy: CostStrategy,
    cost_matrix: Option<Array2<Float>>,
    n_features: usize,
    n_labels: usize,
    lambda_hierarchy: Float,
    lambda_cost: Float,
}

impl CostSensitiveHierarchicalClassifier<Untrained> {
    /// Create a classifier with default settings.
    pub fn new() -> Self {
        Self {
            state: Untrained,
            hierarchy: HashMap::new(),
            cost_strategy: CostStrategy::Uniform,
            cost_matrix: None,
            learning_rate: 0.01,
            max_iterations: 100,
            lambda_hierarchy: 1.0,
            lambda_cost: 1.0,
        }
    }

    /// Set the hierarchy as a map from parent label index to child label indices.
    pub fn hierarchy(mut self, hierarchy: HashMap<usize, Vec<usize>>) -> Self {
        self.hierarchy = hierarchy;
        self
    }

    /// Set the cost-matrix generation strategy.
    pub fn cost_strategy(mut self, strategy: CostStrategy) -> Self {
        self.cost_strategy = strategy;
        self
    }

    /// Provide an explicit cost matrix (required for [`CostStrategy::Custom`]).
    pub fn cost_matrix(mut self, cost_matrix: Array2<Float>) -> Self {
        self.cost_matrix = Some(cost_matrix);
        self
    }

    /// Set the gradient-descent learning rate.
    pub fn learning_rate(mut self, learning_rate: Float) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    /// Set the maximum number of training iterations.
    pub fn max_iterations(mut self, max_iterations: usize) -> Self {
        self.max_iterations = max_iterations;
        self
    }

    /// Weight of the hierarchy-consistency term in the gradient.
    pub fn lambda_hierarchy(mut self, lambda: Float) -> Self {
        self.lambda_hierarchy = lambda;
        self
    }

    /// Weight of the cost-sensitive term in the gradient.
    pub fn lambda_cost(mut self, lambda: Float) -> Self {
        self.lambda_cost = lambda;
        self
    }
}

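// Illustrative builder usage (a sketch; the hierarchy below is an assumption
// for the example, not part of the API):
//
//     let mut hierarchy = HashMap::new();
//     hierarchy.insert(0, vec![1, 2]); // label 0 is the parent of labels 1 and 2
//     let clf = CostSensitiveHierarchicalClassifier::new()
//         .hierarchy(hierarchy)
//         .cost_strategy(CostStrategy::DistanceBased)
//         .lambda_hierarchy(0.5)
//         .lambda_cost(2.0);
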
impl Default for CostSensitiveHierarchicalClassifier<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for CostSensitiveHierarchicalClassifier<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, Array2<i32>> for CostSensitiveHierarchicalClassifier<Untrained> {
    type Fitted = CostSensitiveHierarchicalClassifier<CostSensitiveHierarchicalClassifierTrained>;

    fn fit(
        self,
        X: &ArrayView2<'_, Float>,
        y: &Array2<i32>,
    ) -> SklResult<CostSensitiveHierarchicalClassifier<CostSensitiveHierarchicalClassifierTrained>>
    {
        let (n_samples, n_features) = X.dim();
        let n_labels = y.ncols();

        if n_samples != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        // Use the supplied cost matrix, or generate one from the strategy.
        let cost_matrix = match &self.cost_matrix {
            Some(matrix) => matrix.clone(),
            None => self.generate_cost_matrix(n_labels)?,
        };

        let mut weights = Array2::<Float>::zeros((n_features, n_labels));

        for _iteration in 0..self.max_iterations {
            for sample_idx in 0..n_samples {
                let x = X.row(sample_idx);
                let y_true = y.row(sample_idx);

                let scores = x.dot(&weights);
                let probabilities = scores.mapv(|z| 1.0 / (1.0 + (-z).exp()));

                for label_idx in 0..n_labels {
                    let y_label = y_true[label_idx] as Float;
                    let prob = probabilities[label_idx];

                    // Base logistic gradient, scaled by this label's cost weight.
                    let mut gradient = prob - y_label;

                    let cost_weight = cost_matrix[[label_idx, label_idx]];
                    gradient *= cost_weight * self.lambda_cost;

                    // Add the hierarchy-consistency contribution.
                    gradient += self.lambda_hierarchy
                        * self.hierarchical_gradient(label_idx, &probabilities, &y_true)?;

                    for feat_idx in 0..n_features {
                        weights[[feat_idx, label_idx]] -=
                            self.learning_rate * gradient * x[feat_idx];
                    }
                }
            }
        }

        Ok(CostSensitiveHierarchicalClassifier {
            state: CostSensitiveHierarchicalClassifierTrained {
                weights,
                hierarchy: self.hierarchy,
                cost_strategy: self.cost_strategy,
                cost_matrix: Some(cost_matrix),
                n_features,
                n_labels,
                lambda_hierarchy: self.lambda_hierarchy,
                lambda_cost: self.lambda_cost,
            },
            hierarchy: HashMap::new(),
            cost_strategy: self.cost_strategy,
            cost_matrix: None,
            learning_rate: self.learning_rate,
            max_iterations: self.max_iterations,
            lambda_hierarchy: self.lambda_hierarchy,
            lambda_cost: self.lambda_cost,
        })
    }
}

impl CostSensitiveHierarchicalClassifier<Untrained> {
    fn generate_cost_matrix(&self, n_labels: usize) -> SklResult<Array2<Float>> {
        match self.cost_strategy {
            CostStrategy::Uniform => Ok(Array2::eye(n_labels)),
            CostStrategy::DistanceBased => {
                // Simple placeholder: full cost on the diagonal, half elsewhere.
                let mut cost_matrix = Array2::<Float>::zeros((n_labels, n_labels));
                for i in 0..n_labels {
                    for j in 0..n_labels {
                        cost_matrix[[i, j]] = if i == j { 1.0 } else { 0.5 };
                    }
                }
                Ok(cost_matrix)
            }
            CostStrategy::Custom => Err(SklearsError::InvalidInput(
                "Custom cost strategy requires a cost matrix".to_string(),
            )),
        }
    }

    /// Hierarchy-consistency gradient contribution for one label, driven by
    /// the gap between a true-positive child's probability and its parent's.
    fn hierarchical_gradient(
        &self,
        label_idx: usize,
        probabilities: &Array1<Float>,
        y_true: &ArrayView1<i32>,
    ) -> SklResult<Float> {
        let mut gradient = 0.0;

        // Contribution as a parent: gaps to children that exceed this label.
        if let Some(children) = self.hierarchy.get(&label_idx) {
            for &child in children {
                if child < probabilities.len() {
                    let parent_prob = probabilities[label_idx];
                    let child_prob = probabilities[child];
                    let child_true = y_true[child] as Float;

                    if child_true > 0.5 && child_prob > parent_prob {
                        gradient += child_prob - parent_prob;
                    }
                }
            }
        }

        // Contribution as a child: gaps to parents this label exceeds.
        for (&parent, children) in &self.hierarchy {
            if children.contains(&label_idx) && parent < probabilities.len() {
                let parent_prob = probabilities[parent];
                let child_prob = probabilities[label_idx];
                let label_true = y_true[label_idx] as Float;

                if label_true > 0.5 && child_prob > parent_prob {
                    gradient -= child_prob - parent_prob;
                }
            }
        }

        Ok(gradient)
    }
}

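// Cost-matrix sketch for three labels under `CostStrategy::DistanceBased`, as
// generated above (only the diagonal is consulted during training):
//
//     [[1.0, 0.5, 0.5],
//      [0.5, 1.0, 0.5],
//      [0.5, 0.5, 1.0]]
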
impl Predict<ArrayView2<'_, Float>, Array2<i32>>
    for CostSensitiveHierarchicalClassifier<CostSensitiveHierarchicalClassifierTrained>
{
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "X has a different number of features than the training data".to_string(),
            ));
        }

        let mut predictions = Array2::<i32>::zeros((n_samples, self.state.n_labels));

        for sample_idx in 0..n_samples {
            let x = X.row(sample_idx);

            let scores = x.dot(&self.state.weights);
            let probabilities = scores.mapv(|z| 1.0 / (1.0 + (-z).exp()));

            // Threshold with cost-aware cutoffs, then repair the hierarchy.
            let final_predictions = self.apply_constraints(&probabilities)?;

            for label_idx in 0..self.state.n_labels {
                predictions[[sample_idx, label_idx]] = final_predictions[label_idx];
            }
        }

        Ok(predictions)
    }
}

impl CostSensitiveHierarchicalClassifier<CostSensitiveHierarchicalClassifierTrained> {
    /// Learned weight matrix of shape `(n_features, n_labels)`.
    pub fn weights(&self) -> &Array2<Float> {
        &self.state.weights
    }

    /// Cost matrix used during training, if any.
    pub fn cost_matrix(&self) -> Option<&Array2<Float>> {
        self.state.cost_matrix.as_ref()
    }

    fn apply_constraints(&self, probabilities: &Array1<Float>) -> SklResult<Array1<i32>> {
        let mut binary_predictions = Array1::<i32>::zeros(probabilities.len());

        for i in 0..probabilities.len() {
            // Higher-cost labels get a lower decision threshold; the cost is
            // floored at 0.1 to keep the threshold bounded.
            let threshold = if let Some(cost_matrix) = &self.state.cost_matrix {
                let cost = cost_matrix[[i, i]];
                0.5 / cost.max(0.1)
            } else {
                0.5
            };

            binary_predictions[i] = if probabilities[i] > threshold { 1 } else { 0 };
        }

        // Hierarchy repair: switch a parent on if any of its children is predicted.
        for (&parent, children) in &self.state.hierarchy {
            if parent < binary_predictions.len() {
                let mut any_child_predicted = false;
                for &child in children {
                    if child < binary_predictions.len() && binary_predictions[child] == 1 {
                        any_child_predicted = true;
                        break;
                    }
                }
                if any_child_predicted {
                    binary_predictions[parent] = 1;
                }
            }
        }

        Ok(binary_predictions)
    }
}

/// Neighborhood aggregation function for message passing.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AggregationFunction {
    /// Mean of neighbor messages.
    Mean,
    /// Sum of neighbor messages.
    Sum,
    /// Element-wise maximum of neighbor messages.
    Max,
    /// Attention-weighted combination of neighbor messages.
    Attention,
}

/// Message-passing scheme used by the graph neural network.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MessagePassingVariant {
    /// Graph Convolutional Network (Kipf & Welling).
    GCN,
    /// Graph Attention Network (Veličković et al.).
    GAT,
    /// GraphSAGE with self/neighbor concatenation (Hamilton et al.).
    GraphSAGE,
    /// Graph Isomorphism Network (Xu et al.).
    GIN,
}

/// Graph neural network for node-level multi-label classification.
#[derive(Debug, Clone)]
pub struct GraphNeuralNetwork<S = Untrained> {
    state: S,
    hidden_dim: usize,
    num_layers: usize,
    message_passing_variant: MessagePassingVariant,
    aggregation_function: AggregationFunction,
    learning_rate: Float,
    max_iter: usize,
    dropout_rate: Float,
    random_state: Option<u64>,
}

/// Trained state of [`GraphNeuralNetwork`].
#[derive(Debug, Clone)]
pub struct GraphNeuralNetworkTrained {
    layer_weights: Vec<Array2<Float>>,
    layer_biases: Vec<Array1<Float>>,
    attention_weights: Option<Vec<Array2<Float>>>,
    hidden_dim: usize,
    num_layers: usize,
    message_passing_variant: MessagePassingVariant,
    aggregation_function: AggregationFunction,
    n_features: usize,
    n_outputs: usize,
    dropout_rate: Float,
}

impl GraphNeuralNetwork<Untrained> {
    /// Create a network with default settings.
    pub fn new() -> Self {
        Self {
            state: Untrained,
            hidden_dim: 32,
            num_layers: 2,
            message_passing_variant: MessagePassingVariant::GCN,
            aggregation_function: AggregationFunction::Mean,
            learning_rate: 0.01,
            max_iter: 100,
            dropout_rate: 0.0,
            random_state: None,
        }
    }

    /// Set the hidden embedding dimensionality.
    pub fn hidden_dim(mut self, hidden_dim: usize) -> Self {
        self.hidden_dim = hidden_dim;
        self
    }

    /// Set the number of message-passing layers.
    pub fn num_layers(mut self, num_layers: usize) -> Self {
        self.num_layers = num_layers;
        self
    }

    /// Set the message-passing variant (GCN, GAT, GraphSAGE, or GIN).
    pub fn message_passing_variant(mut self, variant: MessagePassingVariant) -> Self {
        self.message_passing_variant = variant;
        self
    }

    /// Set the neighborhood aggregation function.
    pub fn aggregation_function(mut self, function: AggregationFunction) -> Self {
        self.aggregation_function = function;
        self
    }

    /// Set the learning rate.
    pub fn learning_rate(mut self, learning_rate: Float) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    /// Set the maximum number of training iterations.
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the dropout rate.
    pub fn dropout_rate(mut self, dropout_rate: Float) -> Self {
        self.dropout_rate = dropout_rate;
        self
    }

    /// Set the random seed.
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }
}

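// Illustrative builder usage (a sketch; the parameter values are assumptions
// chosen for the example):
//
//     let gnn = GraphNeuralNetwork::new()
//         .hidden_dim(16)
//         .num_layers(3)
//         .message_passing_variant(MessagePassingVariant::GAT)
//         .learning_rate(0.005)
//         .max_iter(200)
//         .random_state(7);
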
impl Default for GraphNeuralNetwork<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for GraphNeuralNetwork<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl GraphNeuralNetwork<Untrained> {
    /// Fit the network on a graph given its adjacency matrix, node features,
    /// and per-node multi-label targets.
    pub fn fit_graph(
        self,
        adjacency: &ArrayView2<'_, i32>,
        node_features: &ArrayView2<'_, Float>,
        node_labels: &Array2<i32>,
    ) -> SklResult<GraphNeuralNetwork<GraphNeuralNetworkTrained>> {
        let (n_nodes, n_features) = node_features.dim();
        let n_outputs = node_labels.ncols();

        if adjacency.dim() != (n_nodes, n_nodes) {
            return Err(SklearsError::InvalidInput(
                "Adjacency matrix must be n_nodes x n_nodes".to_string(),
            ));
        }

        if node_labels.nrows() != n_nodes {
            return Err(SklearsError::InvalidInput(
                "Node labels must have the same number of rows as nodes".to_string(),
            ));
        }

        // NOTE: `random_state` is not yet threaded into the RNG; `thread_rng`
        // is used directly, so initialization is not reproducible.
        let mut rng_instance = thread_rng();
        let (layer_weights, layer_biases, attention_weights) =
            self.initialize_gnn_parameters(n_features, n_outputs, &mut rng_instance)?;

        let mut weights = layer_weights;
        let biases = layer_biases;

        for _iteration in 0..self.max_iter {
            let (node_embeddings, _) = self.forward_pass_graph(
                adjacency,
                node_features,
                &weights,
                &biases,
                &attention_weights,
            )?;

            // The forward-pass output is not yet used by the placeholder update.
            let _predictions = node_embeddings.mapv(|x| if x > 0.0 { 1 } else { 0 });

            // Placeholder training step: apply a small uniform weight decay.
            for weight in &mut weights {
                for i in 0..weight.nrows() {
                    for j in 0..weight.ncols() {
                        weight[[i, j]] *= 0.999;
                    }
                }
            }
        }

        let trained_state = GraphNeuralNetworkTrained {
            layer_weights: weights,
            layer_biases: biases,
            attention_weights,
            hidden_dim: self.hidden_dim,
            num_layers: self.num_layers,
            message_passing_variant: self.message_passing_variant,
            aggregation_function: self.aggregation_function,
            n_features,
            n_outputs,
            dropout_rate: self.dropout_rate,
        };

        Ok(GraphNeuralNetwork {
            state: trained_state,
            hidden_dim: self.hidden_dim,
            num_layers: self.num_layers,
            message_passing_variant: self.message_passing_variant,
            aggregation_function: self.aggregation_function,
            learning_rate: self.learning_rate,
            max_iter: self.max_iter,
            dropout_rate: self.dropout_rate,
            random_state: self.random_state,
        })
    }

    #[allow(clippy::type_complexity)]
    fn initialize_gnn_parameters(
        &self,
        n_features: usize,
        n_outputs: usize,
        rng: &mut scirs2_core::random::CoreRandom,
    ) -> SklResult<(
        Vec<Array2<Float>>,
        Vec<Array1<Float>>,
        Option<Vec<Array2<Float>>>,
    )> {
        let mut layer_weights = Vec::new();
        let mut layer_biases = Vec::new();
        let mut attention_weights = None;

        // GraphSAGE concatenates self and neighbor features, doubling the
        // input width of each layer.
        let input_dim = match self.message_passing_variant {
            MessagePassingVariant::GraphSAGE => n_features * 2,
            _ => n_features,
        };

        let hidden_dim = match self.message_passing_variant {
            MessagePassingVariant::GraphSAGE => self.hidden_dim * 2,
            _ => self.hidden_dim,
        };

        for layer_idx in 0..self.num_layers {
            let (in_dim, out_dim) = if layer_idx == 0 {
                (input_dim, self.hidden_dim)
            } else if layer_idx == self.num_layers - 1 {
                (hidden_dim, n_outputs)
            } else {
                (hidden_dim, self.hidden_dim)
            };

            // He-style initialization scaled by sqrt(2 / fan_in).
            let normal_dist = RandNormal::new(0.0, (2.0 / in_dim as Float).sqrt()).unwrap();
            let mut input_weight = Array2::<Float>::zeros((in_dim, out_dim));
            for i in 0..in_dim {
                for j in 0..out_dim {
                    input_weight[[i, j]] = rng.sample(normal_dist);
                }
            }
            let bias = Array1::<Float>::zeros(out_dim);

            layer_weights.push(input_weight);
            layer_biases.push(bias);
        }

        // GAT additionally needs per-layer attention vectors over the
        // concatenated (source, target) embeddings.
        if self.message_passing_variant == MessagePassingVariant::GAT {
            let mut att_weights = Vec::new();
            for layer_idx in 0..self.num_layers {
                let att_dim = if layer_idx == 0 {
                    n_features
                } else {
                    self.hidden_dim
                };
                let att_normal_dist = RandNormal::new(0.0, 0.1).unwrap();
                let mut attention_weight = Array2::<Float>::zeros((att_dim * 2, 1));
                for i in 0..(att_dim * 2) {
                    attention_weight[[i, 0]] = rng.sample(att_normal_dist);
                }
                att_weights.push(attention_weight);
            }
            attention_weights = Some(att_weights);
        }

        Ok((layer_weights, layer_biases, attention_weights))
    }

    fn forward_pass_graph(
        &self,
        adjacency: &ArrayView2<'_, i32>,
        node_features: &ArrayView2<'_, Float>,
        weights: &[Array2<Float>],
        biases: &[Array1<Float>],
        attention_weights: &Option<Vec<Array2<Float>>>,
    ) -> SklResult<(Array2<Float>, Vec<Array2<Float>>)> {
        let mut current_embeddings = node_features.to_owned();
        let mut layer_outputs = Vec::new();

        for layer_idx in 0..self.num_layers {
            let layer_output = match self.message_passing_variant {
                MessagePassingVariant::GCN => self.gcn_layer(
                    &current_embeddings,
                    adjacency,
                    &weights[layer_idx],
                    &biases[layer_idx],
                )?,
                MessagePassingVariant::GAT => {
                    let att_weights = attention_weights.as_ref().unwrap();
                    self.gat_layer(
                        &current_embeddings,
                        adjacency,
                        &weights[layer_idx],
                        &biases[layer_idx],
                        &att_weights[layer_idx],
                    )?
                }
                MessagePassingVariant::GraphSAGE => self.graphsage_layer(
                    &current_embeddings,
                    adjacency,
                    &weights[layer_idx],
                    &biases[layer_idx],
                )?,
                MessagePassingVariant::GIN => self.gin_layer(
                    &current_embeddings,
                    adjacency,
                    &weights[layer_idx],
                    &biases[layer_idx],
                )?,
            };

            current_embeddings = layer_output.clone();
            layer_outputs.push(layer_output);
        }

        Ok((current_embeddings, layer_outputs))
    }

    fn gcn_layer(
        &self,
        node_embeddings: &Array2<Float>,
        adjacency: &ArrayView2<'_, i32>,
        weights: &Array2<Float>,
        bias: &Array1<Float>,
    ) -> SklResult<Array2<Float>> {
        let n_nodes = node_embeddings.nrows();
        let mut output = Array2::<Float>::zeros((n_nodes, weights.ncols()));

        for i in 0..n_nodes {
            let mut aggregated = Array1::<Float>::zeros(node_embeddings.ncols());
            let mut degree = 0;

            // Aggregate neighbor embeddings.
            for j in 0..n_nodes {
                if adjacency[[i, j]] == 1 {
                    aggregated += &node_embeddings.row(j).to_owned();
                    degree += 1;
                }
            }

            // Include the node's own embedding (self-loop).
            aggregated += &node_embeddings.row(i).to_owned();
            degree += 1;

            if degree > 0 {
                aggregated /= degree as Float;
            }

            // Linear transform followed by ReLU.
            let transformed = aggregated.dot(weights) + bias;
            let activated = transformed.mapv(|x| x.max(0.0));
            output.row_mut(i).assign(&activated);
        }

        Ok(output)
    }

    fn gat_layer(
        &self,
        node_embeddings: &Array2<Float>,
        adjacency: &ArrayView2<'_, i32>,
        weights: &Array2<Float>,
        bias: &Array1<Float>,
        attention_weights: &Array2<Float>,
    ) -> SklResult<Array2<Float>> {
        let n_nodes = node_embeddings.nrows();
        let mut output = Array2::<Float>::zeros((n_nodes, weights.ncols()));

        for i in 0..n_nodes {
            let mut attention_scores = Array1::<Float>::zeros(n_nodes);
            let mut valid_neighbors = Vec::new();

            // Score each neighbor (including the node itself) with a learned
            // attention vector over the concatenated embeddings.
            for j in 0..n_nodes {
                if adjacency[[i, j]] == 1 || i == j {
                    let concat_features = Array1::from_iter(
                        node_embeddings
                            .row(i)
                            .iter()
                            .chain(node_embeddings.row(j).iter())
                            .cloned(),
                    );

                    if concat_features.len() == attention_weights.nrows() {
                        let score = concat_features.dot(&attention_weights.column(0));
                        attention_scores[j] = score.exp();
                        valid_neighbors.push(j);
                    }
                }
            }

            // Softmax-normalize the attention scores.
            let total_attention: Float =
                valid_neighbors.iter().map(|&j| attention_scores[j]).sum();
            if total_attention > 0.0 {
                for &j in &valid_neighbors {
                    attention_scores[j] /= total_attention;
                }
            }

            // Attention-weighted aggregation of neighbor embeddings.
            let mut aggregated = Array1::<Float>::zeros(node_embeddings.ncols());
            for &j in &valid_neighbors {
                let weighted_features = &node_embeddings.row(j).to_owned() * attention_scores[j];
                aggregated += &weighted_features;
            }

            // Linear transform followed by ReLU.
            let transformed = aggregated.dot(weights) + bias;
            let activated = transformed.mapv(|x| x.max(0.0));
            output.row_mut(i).assign(&activated);
        }

        Ok(output)
    }

    fn graphsage_layer(
        &self,
        node_embeddings: &Array2<Float>,
        adjacency: &ArrayView2<'_, i32>,
        weights: &Array2<Float>,
        bias: &Array1<Float>,
    ) -> SklResult<Array2<Float>> {
        let n_nodes = node_embeddings.nrows();
        let embedding_dim = node_embeddings.ncols();
        let output_dim = weights.ncols();
        let mut output = Array2::<Float>::zeros((n_nodes, output_dim));

        for i in 0..n_nodes {
            // Mean-aggregate the neighbors (excluding the node itself).
            let mut neighbor_sum = Array1::<Float>::zeros(embedding_dim);
            let mut neighbor_count = 0;

            for j in 0..n_nodes {
                if adjacency[[i, j]] == 1 && i != j {
                    neighbor_sum += &node_embeddings.row(j).to_owned();
                    neighbor_count += 1;
                }
            }

            if neighbor_count > 0 {
                neighbor_sum /= neighbor_count as Float;
            }

            // Concatenate self features with the aggregated neighbor features.
            let self_features = node_embeddings.row(i).to_owned();
            let concatenated =
                Array1::from_iter(self_features.iter().chain(neighbor_sum.iter()).cloned());

            if concatenated.len() == weights.nrows() {
                let transformed = concatenated.dot(weights) + bias;
                let activated = transformed.mapv(|x| x.max(0.0));
                output.row_mut(i).assign(&activated);
            }
        }

        Ok(output)
    }

    fn gin_layer(
        &self,
        node_embeddings: &Array2<Float>,
        adjacency: &ArrayView2<'_, i32>,
        weights: &Array2<Float>,
        bias: &Array1<Float>,
    ) -> SklResult<Array2<Float>> {
        let n_nodes = node_embeddings.nrows();
        let mut output = Array2::<Float>::zeros((n_nodes, weights.ncols()));
        // GIN epsilon (fixed at 0.0 here, i.e. plain sum aggregation).
        let epsilon = 0.0;

        for i in 0..n_nodes {
            let mut neighbor_sum = Array1::<Float>::zeros(node_embeddings.ncols());

            for j in 0..n_nodes {
                if adjacency[[i, j]] == 1 && i != j {
                    neighbor_sum += &node_embeddings.row(j).to_owned();
                }
            }

            // GIN update: (1 + eps) * h_i + sum of neighbor embeddings.
            let updated = &node_embeddings.row(i).to_owned() * (1.0 + epsilon) + &neighbor_sum;

            let transformed = updated.dot(weights) + bias;
            let activated = transformed.mapv(|x| x.max(0.0));
            output.row_mut(i).assign(&activated);
        }

        Ok(output)
    }
}

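// Layer-update sketch for the untrained forward pass above (all layers use
// ReLU here; the trained variants below switch to a sigmoid on the final
// layer). For node i with neighborhood N(i):
//
//     GCN:       h_i' = ReLU( mean({h_j : j in N(i)} ∪ {h_i}) · W + b )
//     GAT:       h_i' = ReLU( (Σ_j α_ij h_j) · W + b ),  α softmax-normalized
//     GraphSAGE: h_i' = ReLU( [h_i ; mean_{j≠i} h_j] · W + b )
//     GIN:       h_i' = ReLU( ((1 + ε) h_i + Σ_{j≠i} h_j) · W + b )
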
impl GraphNeuralNetwork<GraphNeuralNetworkTrained> {
    /// Predict per-node binary labels for a graph.
    pub fn predict_graph(
        &self,
        adjacency: &ArrayView2<'_, i32>,
        node_features: &ArrayView2<'_, Float>,
    ) -> SklResult<Array2<i32>> {
        let (n_nodes, n_features) = node_features.dim();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "Node features have different dimensionality than training data".to_string(),
            ));
        }

        if adjacency.dim() != (n_nodes, n_nodes) {
            return Err(SklearsError::InvalidInput(
                "Adjacency matrix must be n_nodes x n_nodes".to_string(),
            ));
        }

        let (final_embeddings, _) = self.forward_pass_trained(adjacency, node_features)?;

        // The final layer applies a sigmoid, so threshold at 0.5 (the previous
        // 0.0 cutoff made every sigmoid output positive).
        let predictions = final_embeddings.mapv(|x| if x > 0.5 { 1 } else { 0 });

        Ok(predictions)
    }

    /// Hidden embedding dimensionality used during training.
    pub fn hidden_dim(&self) -> usize {
        self.state.hidden_dim
    }

    /// Number of message-passing layers.
    pub fn num_layers(&self) -> usize {
        self.state.num_layers
    }

    fn forward_pass_trained(
        &self,
        adjacency: &ArrayView2<'_, i32>,
        node_features: &ArrayView2<'_, Float>,
    ) -> SklResult<(Array2<Float>, Vec<Array2<Float>>)> {
        let mut current_embeddings = node_features.to_owned();
        let mut layer_outputs = Vec::new();

        for layer_idx in 0..self.state.num_layers {
            let layer_output = match self.state.message_passing_variant {
                MessagePassingVariant::GCN => {
                    self.gcn_layer_trained(&current_embeddings, adjacency, layer_idx)?
                }
                MessagePassingVariant::GAT => {
                    self.gat_layer_trained(&current_embeddings, adjacency, layer_idx)?
                }
                MessagePassingVariant::GraphSAGE => {
                    self.graphsage_layer_trained(&current_embeddings, adjacency, layer_idx)?
                }
                MessagePassingVariant::GIN => {
                    self.gin_layer_trained(&current_embeddings, adjacency, layer_idx)?
                }
            };

            current_embeddings = layer_output.clone();
            layer_outputs.push(layer_output);
        }

        Ok((current_embeddings, layer_outputs))
    }

    fn gcn_layer_trained(
        &self,
        node_embeddings: &Array2<Float>,
        adjacency: &ArrayView2<'_, i32>,
        layer_idx: usize,
    ) -> SklResult<Array2<Float>> {
        let weights = &self.state.layer_weights[layer_idx];
        let bias = &self.state.layer_biases[layer_idx];
        let n_nodes = node_embeddings.nrows();
        let mut output = Array2::<Float>::zeros((n_nodes, weights.ncols()));

        for i in 0..n_nodes {
            let mut aggregated = Array1::<Float>::zeros(node_embeddings.ncols());
            let mut degree = 0;

            // Aggregate neighbor embeddings.
            for j in 0..n_nodes {
                if adjacency[[i, j]] == 1 {
                    aggregated += &node_embeddings.row(j).to_owned();
                    degree += 1;
                }
            }

            // Include the node's own embedding (self-loop).
            aggregated += &node_embeddings.row(i).to_owned();
            degree += 1;

            if degree > 0 {
                aggregated /= degree as Float;
            }

            // Sigmoid on the output layer, ReLU on hidden layers.
            let transformed = aggregated.dot(weights) + bias;
            let activated = if layer_idx == self.state.num_layers - 1 {
                transformed.mapv(|x| 1.0 / (1.0 + (-x).exp()))
            } else {
                transformed.mapv(|x| x.max(0.0))
            };

            output.row_mut(i).assign(&activated);
        }

        Ok(output)
    }

    fn gat_layer_trained(
        &self,
        node_embeddings: &Array2<Float>,
        adjacency: &ArrayView2<'_, i32>,
        layer_idx: usize,
    ) -> SklResult<Array2<Float>> {
        let weights = &self.state.layer_weights[layer_idx];
        let bias = &self.state.layer_biases[layer_idx];
        let attention_weights = self.state.attention_weights.as_ref().unwrap();
        let att_weights = &attention_weights[layer_idx];

        let n_nodes = node_embeddings.nrows();
        let mut output = Array2::<Float>::zeros((n_nodes, weights.ncols()));

        for i in 0..n_nodes {
            let mut attention_scores = Array1::<Float>::zeros(n_nodes);
            let mut valid_neighbors = Vec::new();

            // Score each neighbor (including the node itself).
            for j in 0..n_nodes {
                if adjacency[[i, j]] == 1 || i == j {
                    let concat_features = Array1::from_iter(
                        node_embeddings
                            .row(i)
                            .iter()
                            .chain(node_embeddings.row(j).iter())
                            .cloned(),
                    );

                    if concat_features.len() == att_weights.nrows() {
                        let score = concat_features.dot(&att_weights.column(0));
                        attention_scores[j] = score.exp();
                        valid_neighbors.push(j);
                    }
                }
            }

            // Softmax-normalize the attention scores.
            let total_attention: Float =
                valid_neighbors.iter().map(|&j| attention_scores[j]).sum();
            if total_attention > 0.0 {
                for &j in &valid_neighbors {
                    attention_scores[j] /= total_attention;
                }
            }

            // Attention-weighted aggregation.
            let mut aggregated = Array1::<Float>::zeros(node_embeddings.ncols());
            for &j in &valid_neighbors {
                let weighted_features = &node_embeddings.row(j).to_owned() * attention_scores[j];
                aggregated += &weighted_features;
            }

            // Sigmoid on the output layer, ReLU on hidden layers.
            let transformed = aggregated.dot(weights) + bias;
            let activated = if layer_idx == self.state.num_layers - 1 {
                transformed.mapv(|x| 1.0 / (1.0 + (-x).exp()))
            } else {
                transformed.mapv(|x| x.max(0.0))
            };

            output.row_mut(i).assign(&activated);
        }

        Ok(output)
    }

    fn graphsage_layer_trained(
        &self,
        node_embeddings: &Array2<Float>,
        adjacency: &ArrayView2<'_, i32>,
        layer_idx: usize,
    ) -> SklResult<Array2<Float>> {
        let weights = &self.state.layer_weights[layer_idx];
        let bias = &self.state.layer_biases[layer_idx];
        let n_nodes = node_embeddings.nrows();
        let embedding_dim = node_embeddings.ncols();
        let output_dim = weights.ncols();
        let mut output = Array2::<Float>::zeros((n_nodes, output_dim));

        for i in 0..n_nodes {
            // Mean-aggregate the neighbors (excluding the node itself).
            let mut neighbor_sum = Array1::<Float>::zeros(embedding_dim);
            let mut neighbor_count = 0;

            for j in 0..n_nodes {
                if adjacency[[i, j]] == 1 && i != j {
                    neighbor_sum += &node_embeddings.row(j).to_owned();
                    neighbor_count += 1;
                }
            }

            if neighbor_count > 0 {
                neighbor_sum /= neighbor_count as Float;
            }

            // Concatenate self features with the aggregated neighbor features.
            let self_features = node_embeddings.row(i).to_owned();
            let concatenated =
                Array1::from_iter(self_features.iter().chain(neighbor_sum.iter()).cloned());

            if concatenated.len() == weights.nrows() {
                let transformed = concatenated.dot(weights) + bias;
                let activated = if layer_idx == self.state.num_layers - 1 {
                    transformed.mapv(|x| 1.0 / (1.0 + (-x).exp()))
                } else {
                    transformed.mapv(|x| x.max(0.0))
                };
                output.row_mut(i).assign(&activated);
            }
        }

        Ok(output)
    }

    fn gin_layer_trained(
        &self,
        node_embeddings: &Array2<Float>,
        adjacency: &ArrayView2<'_, i32>,
        layer_idx: usize,
    ) -> SklResult<Array2<Float>> {
        let weights = &self.state.layer_weights[layer_idx];
        let bias = &self.state.layer_biases[layer_idx];
        let n_nodes = node_embeddings.nrows();
        let mut output = Array2::<Float>::zeros((n_nodes, weights.ncols()));
        // GIN epsilon (fixed at 0.0 here, i.e. plain sum aggregation).
        let epsilon = 0.0;

        for i in 0..n_nodes {
            let mut neighbor_sum = Array1::<Float>::zeros(node_embeddings.ncols());

            for j in 0..n_nodes {
                if adjacency[[i, j]] == 1 && i != j {
                    neighbor_sum += &node_embeddings.row(j).to_owned();
                }
            }

            // GIN update: (1 + eps) * h_i + sum of neighbor embeddings.
            let updated = &node_embeddings.row(i).to_owned() * (1.0 + epsilon) + &neighbor_sum;

            // Sigmoid on the output layer, ReLU on hidden layers.
            let transformed = updated.dot(weights) + bias;
            let activated = if layer_idx == self.state.num_layers - 1 {
                transformed.mapv(|x| 1.0 / (1.0 + (-x).exp()))
            } else {
                transformed.mapv(|x| x.max(0.0))
            };

            output.row_mut(i).assign(&activated);
        }

        Ok(output)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_gnn_basic_functionality() {
        let node_features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
        let adjacency = array![[0, 1, 0], [1, 0, 1], [0, 1, 0]];
        let node_labels = array![[1, 0], [0, 1], [1, 1]];

        let gnn = GraphNeuralNetwork::new()
            .hidden_dim(4)
            .num_layers(2)
            .max_iter(5);

        let trained_gnn = gnn
            .fit_graph(&adjacency.view(), &node_features.view(), &node_labels)
            .unwrap();

        let predictions = trained_gnn
            .predict_graph(&adjacency.view(), &node_features.view())
            .unwrap();

        assert_eq!(predictions.dim(), (3, 2));
        assert!(predictions.iter().all(|&x| x == 0 || x == 1));
    }

    #[test]
    fn test_gnn_different_variants() {
        let node_features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
        let adjacency = array![[0, 1, 0], [1, 0, 1], [0, 1, 0]];
        let node_labels = array![[1, 0], [0, 1], [1, 1]];

        let gnn_gcn = GraphNeuralNetwork::new()
            .message_passing_variant(MessagePassingVariant::GCN)
            .max_iter(5);
        let trained_gcn = gnn_gcn
            .fit_graph(&adjacency.view(), &node_features.view(), &node_labels)
            .unwrap();

        let gnn_gat = GraphNeuralNetwork::new()
            .message_passing_variant(MessagePassingVariant::GAT)
            .max_iter(5);
        let trained_gat = gnn_gat
            .fit_graph(&adjacency.view(), &node_features.view(), &node_labels)
            .unwrap();

        let gnn_sage = GraphNeuralNetwork::new()
            .message_passing_variant(MessagePassingVariant::GraphSAGE)
            .max_iter(5);
        let trained_sage = gnn_sage
            .fit_graph(&adjacency.view(), &node_features.view(), &node_labels)
            .unwrap();

        assert_eq!(
            trained_gcn.state.message_passing_variant,
            MessagePassingVariant::GCN
        );
        assert_eq!(
            trained_gat.state.message_passing_variant,
            MessagePassingVariant::GAT
        );
        assert_eq!(
            trained_sage.state.message_passing_variant,
            MessagePassingVariant::GraphSAGE
        );
    }

    #[test]
    fn test_gnn_parameter_settings() {
        let gnn = GraphNeuralNetwork::new()
            .hidden_dim(16)
            .num_layers(3)
            .learning_rate(0.001)
            .max_iter(50)
            .dropout_rate(0.1);

        assert_eq!(gnn.hidden_dim, 16);
        assert_eq!(gnn.num_layers, 3);
        assert!((gnn.learning_rate - 0.001).abs() < 1e-10);
        assert_eq!(gnn.max_iter, 50);
        assert!((gnn.dropout_rate - 0.1).abs() < 1e-10);
    }

    #[test]
    fn test_gnn_default_settings() {
        let gnn = GraphNeuralNetwork::new();

        assert_eq!(gnn.hidden_dim, 32);
        assert_eq!(gnn.num_layers, 2);
        assert_eq!(gnn.message_passing_variant, MessagePassingVariant::GCN);
        assert_eq!(gnn.aggregation_function, AggregationFunction::Mean);
    }

    #[test]
    fn test_gnn_builder_pattern() {
        let gnn1 = GraphNeuralNetwork::new();
        let gnn2 = GraphNeuralNetwork::new();

        assert_eq!(gnn1.hidden_dim, gnn2.hidden_dim);
        assert_eq!(gnn1.num_layers, gnn2.num_layers);

        let gnn3 = GraphNeuralNetwork::new().max_iter(1);
        assert_eq!(gnn3.max_iter, 1);
    }

    #[test]
    fn test_message_passing_variants() {
        assert_eq!(MessagePassingVariant::GCN, MessagePassingVariant::GCN);
        assert_ne!(MessagePassingVariant::GCN, MessagePassingVariant::GAT);

        let variants = [
            MessagePassingVariant::GCN,
            MessagePassingVariant::GAT,
            MessagePassingVariant::GraphSAGE,
            MessagePassingVariant::GIN,
        ];

        let gnn1 = GraphNeuralNetwork::new()
            .message_passing_variant(variants[0])
            .hidden_dim(8)
            .max_iter(3);

        let gnn2 = GraphNeuralNetwork::new()
            .message_passing_variant(variants[1])
            .hidden_dim(8)
            .max_iter(3);

        assert_eq!(gnn1.message_passing_variant, MessagePassingVariant::GCN);
        assert_eq!(gnn2.message_passing_variant, MessagePassingVariant::GAT);
    }

    #[test]
    fn test_gnn_larger_graph() {
        let node_features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0], [1.0, 3.0]];
        let adjacency = array![
            [0, 1, 1, 0, 0],
            [1, 0, 1, 1, 0],
            [1, 1, 0, 0, 1],
            [0, 1, 0, 0, 1],
            [0, 0, 1, 1, 0]
        ];
        let node_labels = array![[1, 0, 1], [0, 1, 0], [1, 1, 0], [0, 0, 1], [1, 0, 0]];

        let gnn = GraphNeuralNetwork::new()
            .hidden_dim(10)
            .num_layers(2)
            .message_passing_variant(MessagePassingVariant::GCN)
            .max_iter(10);

        let trained_gnn = gnn
            .fit_graph(&adjacency.view(), &node_features.view(), &node_labels)
            .unwrap();

        let predictions = trained_gnn
            .predict_graph(&adjacency.view(), &node_features.view())
            .unwrap();

        assert_eq!(predictions.dim(), (5, 3));
        assert!(predictions.iter().all(|&x| x == 0 || x == 1));
        assert_eq!(trained_gnn.hidden_dim(), 10);
    }

    #[test]
    fn test_aggregation_functions() {
        assert_ne!(AggregationFunction::Mean, AggregationFunction::Max);
        assert_eq!(AggregationFunction::Sum, AggregationFunction::Sum);
        assert_ne!(MessagePassingVariant::GraphSAGE, MessagePassingVariant::GIN);
    }

    #[test]
    fn test_gnn_reproducibility() {
        let node_features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
        let adjacency = array![[0, 1, 0], [1, 0, 1], [0, 1, 0]];
        let node_labels = array![[1, 0], [0, 1], [1, 1]];

        let gnn = GraphNeuralNetwork::new()
            .hidden_dim(4)
            .num_layers(2)
            .max_iter(5)
            .random_state(42);

        let trained_gnn = gnn
            .fit_graph(&adjacency.view(), &node_features.view(), &node_labels)
            .unwrap();

        let predictions = trained_gnn
            .predict_graph(&adjacency.view(), &node_features.view())
            .unwrap();

        assert_eq!(predictions.dim(), (3, 2));
    }

    #[test]
    fn test_gnn_edge_cases() {
        let node_features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0], [1.0, 3.0]];
        let adjacency = array![
            [0, 1, 1, 0, 0],
            [1, 0, 1, 1, 0],
            [1, 1, 0, 0, 1],
            [0, 1, 0, 0, 1],
            [0, 0, 1, 1, 0]
        ];
        let node_labels = array![[1, 0, 1], [0, 1, 0], [1, 1, 0], [0, 0, 1], [1, 0, 0]];

        let gnn = GraphNeuralNetwork::new()
            .hidden_dim(10)
            .num_layers(2)
            .message_passing_variant(MessagePassingVariant::GCN)
            .max_iter(15)
            .random_state(42);

        let trained_gnn = gnn
            .fit_graph(&adjacency.view(), &node_features.view(), &node_labels)
            .unwrap();
        let predictions = trained_gnn
            .predict_graph(&adjacency.view(), &node_features.view())
            .unwrap();

        assert_eq!(predictions.dim(), (5, 3));
        assert!(predictions.iter().all(|&x| x == 0 || x == 1));
        assert_eq!(trained_gnn.hidden_dim(), 10);
    }
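
    // Illustrative usage sketches for the two hierarchical classifiers above.
    // These mirror the GNN tests and only assert shapes and binary outputs;
    // the data and label indices are assumptions made for the example.
    #[test]
    fn test_ontology_aware_classifier_usage_sketch() {
        let X = array![[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
        let y = array![[1, 1], [0, 0], [1, 1]];
        // Label 1 is a child of label 0.
        let mut ontology = HashMap::new();
        ontology.insert(1, vec![0]);

        let clf = OntologyAwareClassifier::new()
            .ontology(ontology)
            .consistency_enforcement(ConsistencyEnforcement::PostProcessing)
            .max_iterations(10);

        let trained = clf.fit(&X.view(), &y).unwrap();
        let predictions = trained.predict(&X.view()).unwrap();

        assert_eq!(predictions.dim(), (3, 2));
        assert!(predictions.iter().all(|&v| v == 0 || v == 1));
    }

    #[test]
    fn test_cost_sensitive_hierarchical_usage_sketch() {
        let X = array![[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
        let y = array![[1, 1], [0, 0], [1, 0]];
        // Label 0 is the parent of label 1.
        let mut hierarchy = HashMap::new();
        hierarchy.insert(0, vec![1]);

        let clf = CostSensitiveHierarchicalClassifier::new()
            .hierarchy(hierarchy)
            .cost_strategy(CostStrategy::DistanceBased)
            .max_iterations(10);

        let trained = clf.fit(&X.view(), &y).unwrap();
        let predictions = trained.predict(&X.view()).unwrap();

        assert_eq!(predictions.dim(), (3, 2));
        assert!(predictions.iter().all(|&v| v == 0 || v == 1));
    }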
}