use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Untrained},
    types::Float,
};
use std::collections::{HashMap, HashSet, VecDeque};
use std::fmt::Debug;

use crate::{PipelinePredictor, PipelineStep};

/// A node in a DAG pipeline: a component plus its graph wiring and config.
#[derive(Debug)]
pub struct DAGNode {
    /// Unique identifier of the node within the DAG.
    pub id: String,
    /// Human-readable name.
    pub name: String,
    /// The component this node executes.
    pub component: NodeComponent,
    /// IDs of upstream nodes whose outputs feed this node.
    pub dependencies: Vec<String>,
    /// IDs of downstream nodes that consume this node's output.
    pub consumers: Vec<String>,
    /// Free-form metadata attached to the node.
    pub metadata: HashMap<String, String>,
    /// Per-node execution configuration.
    pub config: NodeConfig,
}

/// The kind of component a DAG node wraps.
pub enum NodeComponent {
    /// A feature transformer.
    Transformer(Box<dyn PipelineStep>),
    /// A predictive estimator.
    Estimator(Box<dyn PipelinePredictor>),
    /// A source that injects data (and optionally targets) into the graph.
    DataSource {
        data: Option<Array2<f64>>,
        targets: Option<Array1<f64>>,
    },
    /// A terminal sink that forwards its input unchanged.
    DataSink,
    /// A branch that routes execution depending on a runtime condition.
    ConditionalBranch {
        condition: BranchCondition,
        true_path: String,
        false_path: String,
    },
    /// A node that merges the outputs of several upstream nodes.
    DataMerger { merge_strategy: MergeStrategy },
    /// A user-supplied function over the collected upstream outputs.
    CustomFunction {
        function: Box<dyn Fn(&[NodeOutput]) -> SklResult<NodeOutput> + Send + Sync>,
    },
}

impl Debug for NodeComponent {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            NodeComponent::Transformer(_) => f
                .debug_tuple("Transformer")
                .field(&"<transformer>")
                .finish(),
            NodeComponent::Estimator(_) => {
                f.debug_tuple("Estimator").field(&"<estimator>").finish()
            }
            NodeComponent::DataSource { data, targets } => f
                .debug_struct("DataSource")
                .field(
                    "data",
                    &data
                        .as_ref()
                        .map(|d| format!("Array2<f64>({}, {})", d.nrows(), d.ncols())),
                )
                .field(
                    "targets",
                    &targets
                        .as_ref()
                        .map(|t| format!("Array1<f64>({})", t.len())),
                )
                .finish(),
            NodeComponent::DataSink => f.debug_tuple("DataSink").finish(),
            NodeComponent::ConditionalBranch {
                condition,
                true_path,
                false_path,
            } => f
                .debug_struct("ConditionalBranch")
                .field("condition", condition)
                .field("true_path", true_path)
                .field("false_path", false_path)
                .finish(),
            NodeComponent::DataMerger { merge_strategy } => f
                .debug_struct("DataMerger")
                .field("merge_strategy", merge_strategy)
                .finish(),
            NodeComponent::CustomFunction { .. } => f
                .debug_struct("CustomFunction")
                .field("function", &"<function>")
                .finish(),
        }
    }
}

/// A runtime condition evaluated by a `ConditionalBranch` node.
#[derive(Debug)]
pub enum BranchCondition {
    /// Compare the mean of one feature column against a threshold.
    FeatureThreshold {
        feature_idx: usize,
        threshold: f64,
        comparison: ComparisonOp,
    },
    /// Check that the number of samples falls within the given bounds.
    DataSize {
        min_samples: Option<usize>,
        max_samples: Option<usize>,
    },
    /// A user-supplied predicate over the node's input.
    Custom {
        condition_fn: fn(&NodeOutput) -> bool,
    },
}

/// Comparison operators used by `BranchCondition::FeatureThreshold`.
#[derive(Debug, Clone)]
pub enum ComparisonOp {
    GreaterThan,
    LessThan,
    GreaterEqual,
    LessEqual,
    Equal,
    NotEqual,
}

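/// Strategy used by a `DataMerger` node to combine multiple upstream outputs.
///
/// A minimal selection sketch (the weights here are illustrative, not taken
/// from the library):
///
/// ```ignore
/// let strategy = MergeStrategy::WeightedAverage {
///     weights: vec![0.3, 0.7],
/// };
/// ```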
#[derive(Debug)]
pub enum MergeStrategy {
    /// Concatenate outputs column-wise (sample counts must match).
    HorizontalConcat,
    /// Concatenate outputs row-wise (feature counts must match).
    VerticalConcat,
    /// Element-wise mean across all inputs.
    Average,
    /// Element-wise weighted mean; one weight per input.
    WeightedAverage { weights: Vec<f64> },
    /// Element-wise maximum across all inputs.
    Maximum,
    /// Element-wise minimum across all inputs.
    Minimum,
    /// A user-supplied merge function.
    Custom {
        merge_fn: fn(&[NodeOutput]) -> SklResult<NodeOutput>,
    },
}

/// Execution configuration for a single DAG node.
#[derive(Debug, Clone)]
pub struct NodeConfig {
    /// Whether this node may run in parallel with others in its level.
    pub parallel_execution: bool,
    /// Optional timeout in seconds.
    pub timeout: Option<f64>,
    /// Number of retry attempts on failure.
    pub retry_attempts: usize,
    /// Whether to cache this node's output.
    pub cache_output: bool,
    /// Resource requirements for scheduling.
    pub resource_requirements: ResourceRequirements,
}

impl Default for NodeConfig {
    fn default() -> Self {
        Self {
            parallel_execution: true,
            timeout: None,
            retry_attempts: 0,
            cache_output: false,
            resource_requirements: ResourceRequirements::default(),
        }
    }
}

/// Resource requirements declared by a node.
#[derive(Debug, Clone, Default)]
pub struct ResourceRequirements {
    /// Required memory in megabytes, if known.
    pub memory_mb: Option<usize>,
    /// Required CPU cores, if known.
    pub cpu_cores: Option<usize>,
    /// Whether a GPU is required.
    pub gpu_required: bool,
}

/// The output produced by executing a DAG node.
#[derive(Debug, Clone)]
pub struct NodeOutput {
    /// Transformed feature matrix.
    pub data: Array2<f64>,
    /// Optional target values carried alongside the data.
    pub targets: Option<Array1<f64>>,
    /// Metadata accumulated during execution.
    pub metadata: HashMap<String, String>,
    /// Statistics recorded for this execution.
    pub execution_stats: ExecutionStats,
}

/// Statistics recorded for a single node execution.
#[derive(Debug, Clone)]
pub struct ExecutionStats {
    /// Wall-clock execution time in seconds.
    pub execution_time: f64,
    /// Approximate memory usage.
    pub memory_usage: f64,
    /// Whether the execution succeeded.
    pub success: bool,
    /// Error message if the execution failed.
    pub error_message: Option<String>,
}

impl Default for ExecutionStats {
    fn default() -> Self {
        Self {
            execution_time: 0.0,
            memory_usage: 0.0,
            success: true,
            error_message: None,
        }
    }
}

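/// A pipeline whose steps form a directed acyclic graph (DAG) rather than a
/// linear chain, allowing branching, merging, and level-wise parallel groups.
///
/// A minimal construction sketch; `step_a` stands in for any boxed
/// `PipelineStep` implementation:
///
/// ```ignore
/// let dag = DAGPipeline::new()
///     .add_node(DAGNode {
///         id: "a".to_string(),
///         name: "Step A".to_string(),
///         component: NodeComponent::Transformer(step_a),
///         dependencies: vec![],
///         consumers: vec![],
///         metadata: HashMap::new(),
///         config: NodeConfig::default(),
///     })?
///     .add_node(DAGNode { id: "b".to_string(), /* ... */ })?
///     .add_edge("a", "b")?;
/// ```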
#[derive(Debug)]
pub struct DAGPipeline<S = Untrained> {
    state: S,
    nodes: HashMap<String, DAGNode>,
    edges: HashMap<String, HashSet<String>>,
    execution_order: Vec<String>,
    parallel_groups: Vec<Vec<String>>,
    cache: HashMap<String, NodeOutput>,
}

/// State held by a fitted `DAGPipeline`.
#[derive(Debug)]
pub struct DAGPipelineTrained {
    fitted_nodes: HashMap<String, DAGNode>,
    edges: HashMap<String, HashSet<String>>,
    execution_order: Vec<String>,
    parallel_groups: Vec<Vec<String>>,
    cache: HashMap<String, NodeOutput>,
    execution_history: Vec<ExecutionRecord>,
    n_features_in: usize,
    feature_names_in: Option<Vec<String>>,
}

/// A record of one end-to-end execution of the DAG.
#[derive(Debug, Clone)]
pub struct ExecutionRecord {
    /// Unix timestamp (seconds) when execution started.
    pub timestamp: f64,
    /// IDs of the nodes that were executed.
    pub executed_nodes: Vec<String>,
    /// Total wall-clock time in seconds.
    pub total_time: f64,
    /// Whether every node succeeded.
    pub success: bool,
    /// Pairs of (node ID, error message) for failed nodes.
    pub errors: Vec<(String, String)>,
}

impl DAGPipeline<Untrained> {
    /// Creates an empty, untrained DAG pipeline.
    #[must_use]
    pub fn new() -> Self {
        Self {
            state: Untrained,
            nodes: HashMap::new(),
            edges: HashMap::new(),
            execution_order: Vec::new(),
            parallel_groups: Vec::new(),
            cache: HashMap::new(),
        }
    }

    /// Adds a node to the DAG, rejecting duplicate IDs.
    pub fn add_node(mut self, node: DAGNode) -> SklResult<Self> {
        if self.nodes.contains_key(&node.id) {
            return Err(SklearsError::InvalidInput(format!(
                "Node with ID '{}' already exists",
                node.id
            )));
        }

        let node_id = node.id.clone();

        // Record the node's declared dependencies as incoming edges.
        self.edges
            .insert(node_id.clone(), node.dependencies.iter().cloned().collect());

        self.nodes.insert(node_id, node);

        // Recompute the topological order with the new node included.
        self.compute_execution_order()?;

        Ok(self)
    }

    /// Adds a directed edge from `from_node` to `to_node`, rejecting cycles.
    pub fn add_edge(mut self, from_node: &str, to_node: &str) -> SklResult<Self> {
        if !self.nodes.contains_key(from_node) {
            return Err(SklearsError::InvalidInput(format!(
                "Source node '{from_node}' does not exist"
            )));
        }
        if !self.nodes.contains_key(to_node) {
            return Err(SklearsError::InvalidInput(format!(
                "Target node '{to_node}' does not exist"
            )));
        }

        self.edges
            .entry(to_node.to_string())
            .or_default()
            .insert(from_node.to_string());

        // Keep the nodes' dependency/consumer lists in sync with the edge map.
        if let Some(to_node_obj) = self.nodes.get_mut(to_node) {
            if !to_node_obj.dependencies.contains(&from_node.to_string()) {
                to_node_obj.dependencies.push(from_node.to_string());
            }
        }

        if let Some(from_node_obj) = self.nodes.get_mut(from_node) {
            if !from_node_obj.consumers.contains(&to_node.to_string()) {
                from_node_obj.consumers.push(to_node.to_string());
            }
        }

        if self.has_cycles()? {
            return Err(SklearsError::InvalidInput(
                "Adding edge would create a cycle in the DAG".to_string(),
            ));
        }

        self.compute_execution_order()?;

        Ok(self)
    }

    /// Returns `true` if the current edge set contains a cycle.
    fn has_cycles(&self) -> SklResult<bool> {
        let mut visited = HashSet::new();
        let mut rec_stack = HashSet::new();

        for node_id in self.nodes.keys() {
            if !visited.contains(node_id)
                && self.dfs_cycle_check(node_id, &mut visited, &mut rec_stack)?
            {
                return Ok(true);
            }
        }

        Ok(false)
    }

    /// Depth-first search that reports a cycle when a node is reached again
    /// while it is still on the recursion stack.
    fn dfs_cycle_check(
        &self,
        node_id: &str,
        visited: &mut HashSet<String>,
        rec_stack: &mut HashSet<String>,
    ) -> SklResult<bool> {
        visited.insert(node_id.to_string());
        rec_stack.insert(node_id.to_string());

        if let Some(dependencies) = self.edges.get(node_id) {
            for dep in dependencies {
                if !visited.contains(dep) {
                    if self.dfs_cycle_check(dep, visited, rec_stack)? {
                        return Ok(true);
                    }
                } else if rec_stack.contains(dep) {
                    return Ok(true);
                }
            }
        }

        rec_stack.remove(node_id);
        Ok(false)
    }

    /// Computes a topological execution order via Kahn's algorithm, grouping
    /// nodes with no remaining dependencies into levels that may run in
    /// parallel.
    fn compute_execution_order(&mut self) -> SklResult<()> {
        let mut in_degree = HashMap::new();
        let mut queue = VecDeque::new();
        let mut order = Vec::new();
        let mut parallel_groups = Vec::new();

        // Every node starts with in-degree zero ...
        for node_id in self.nodes.keys() {
            in_degree.insert(node_id.clone(), 0);
        }

        // ... and is then overwritten with its actual dependency count.
        for (node_id, dependencies) in &self.edges {
            in_degree.insert(node_id.clone(), dependencies.len());
        }

        // Seed the queue with nodes that have no dependencies.
        for (node_id, &degree) in &in_degree {
            if degree == 0 {
                queue.push_back(node_id.clone());
            }
        }

        while !queue.is_empty() {
            // All nodes currently in the queue form one parallel level.
            let current_level: Vec<String> = queue.drain(..).collect();
            parallel_groups.push(current_level.clone());
            order.extend(current_level.iter().cloned());

            for node_id in &current_level {
                if let Some(node) = self.nodes.get(node_id) {
                    for consumer in &node.consumers {
                        if let Some(degree) = in_degree.get_mut(consumer) {
                            *degree -= 1;
                            if *degree == 0 {
                                queue.push_back(consumer.clone());
                            }
                        }
                    }
                }
            }
        }

        // If some nodes were never released, the graph contains a cycle.
        if order.len() != self.nodes.len() {
            return Err(SklearsError::InvalidInput(
                "DAG contains cycles".to_string(),
            ));
        }

        self.execution_order = order;
        self.parallel_groups = parallel_groups;

        Ok(())
    }

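    /// Builds a linear chain `node_0 -> node_1 -> ...` from the given steps.
    ///
    /// A minimal usage sketch; `MyTransformer` stands in for any type
    /// implementing `PipelineStep`:
    ///
    /// ```ignore
    /// let dag = DAGPipeline::linear(vec![
    ///     ("scale".to_string(), Box::new(MyTransformer::new()) as Box<dyn PipelineStep>),
    ///     ("reduce".to_string(), Box::new(MyTransformer::new()) as Box<dyn PipelineStep>),
    /// ])?;
    /// ```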
    pub fn linear(components: Vec<(String, Box<dyn PipelineStep>)>) -> SklResult<Self> {
        let mut dag = Self::new();
        let num_components = components.len();

        for (i, (name, component)) in components.into_iter().enumerate() {
            let dependencies = if i == 0 {
                Vec::new()
            } else {
                vec![format!("node_{}", i - 1)]
            };

            let node = DAGNode {
                id: format!("node_{i}"),
                name,
                component: NodeComponent::Transformer(component),
                dependencies,
                consumers: if i == num_components - 1 {
                    Vec::new()
                } else {
                    vec![format!("node_{}", i + 1)]
                },
                metadata: HashMap::new(),
                config: NodeConfig::default(),
            };

            dag = dag.add_node(node)?;
        }

        Ok(dag)
    }

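    /// Builds a fan-out/fan-in DAG: every step runs on the same input and a
    /// final `merger` node combines their outputs with `merge_strategy`.
    ///
    /// A minimal usage sketch; `MyTransformer` stands in for any type
    /// implementing `PipelineStep`:
    ///
    /// ```ignore
    /// let dag = DAGPipeline::parallel(
    ///     vec![
    ///         ("a".to_string(), Box::new(MyTransformer::new()) as Box<dyn PipelineStep>),
    ///         ("b".to_string(), Box::new(MyTransformer::new()) as Box<dyn PipelineStep>),
    ///     ],
    ///     MergeStrategy::HorizontalConcat,
    /// )?;
    /// ```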
    pub fn parallel(
        components: Vec<(String, Box<dyn PipelineStep>)>,
        merge_strategy: MergeStrategy,
    ) -> SklResult<Self> {
        let mut dag = Self::new();

        let num_components = components.len();

        for (i, (name, component)) in components.into_iter().enumerate() {
            let node = DAGNode {
                id: format!("parallel_{i}"),
                name,
                component: NodeComponent::Transformer(component),
                dependencies: Vec::new(),
                consumers: vec!["merger".to_string()],
                metadata: HashMap::new(),
                config: NodeConfig::default(),
            };

            dag = dag.add_node(node)?;
        }

        let merger_dependencies: Vec<String> = (0..num_components)
            .map(|i| format!("parallel_{i}"))
            .collect();

        let merger_node = DAGNode {
            id: "merger".to_string(),
            name: "Data Merger".to_string(),
            component: NodeComponent::DataMerger { merge_strategy },
            dependencies: merger_dependencies,
            consumers: Vec::new(),
            metadata: HashMap::new(),
            config: NodeConfig::default(),
        };

        dag = dag.add_node(merger_node)?;

        Ok(dag)
    }
}

impl Default for DAGPipeline<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for DAGPipeline<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

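/// Fitting executes the DAG level by level, caching each node's output and
/// recording an `ExecutionRecord` for the run.
///
/// A minimal usage sketch (array shapes are illustrative):
///
/// ```ignore
/// let x = array![[1.0, 2.0], [3.0, 4.0]];
/// let y = array![0.0, 1.0];
/// let fitted = dag.fit(&x.view(), &Some(&y.view()))?;
/// ```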
impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>> for DAGPipeline<Untrained> {
    type Fitted = DAGPipeline<DAGPipelineTrained>;

    fn fit(
        mut self,
        x: &ArrayView2<'_, Float>,
        y: &Option<&ArrayView1<'_, Float>>,
    ) -> SklResult<Self::Fitted> {
        let mut fitted_nodes = HashMap::new();
        let mut execution_errors = Vec::new();
        let start_time = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs_f64();

        // Seed the cache with the raw input so dependency-less nodes can read it.
        let initial_output = NodeOutput {
            data: x.mapv(|v| v),
            targets: y.as_ref().map(|y_vals| y_vals.mapv(|v| v)),
            metadata: HashMap::new(),
            execution_stats: ExecutionStats::default(),
        };
        self.cache.insert("input".to_string(), initial_output);

        // Execute the DAG one topological level at a time.
        let parallel_groups = std::mem::take(&mut self.parallel_groups);
        for group in &parallel_groups {
            let group_results = self.execute_parallel_group(group)?;

            for (node_id, result) in group_results {
                match result {
                    Ok(output) => {
                        self.cache.insert(node_id.clone(), output);
                        if let Some(node) = self.nodes.remove(&node_id) {
                            fitted_nodes.insert(node_id, node);
                        }
                    }
                    Err(e) => {
                        execution_errors.push((node_id, e.to_string()));
                    }
                }
            }
        }

        let end_time = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs_f64();

        let execution_record = ExecutionRecord {
            timestamp: start_time,
            executed_nodes: fitted_nodes.keys().cloned().collect(),
            total_time: end_time - start_time,
            success: execution_errors.is_empty(),
            errors: execution_errors,
        };

        Ok(DAGPipeline {
            state: DAGPipelineTrained {
                fitted_nodes,
                edges: self.edges,
                execution_order: self.execution_order,
                parallel_groups,
                cache: self.cache,
                execution_history: vec![execution_record],
                n_features_in: x.ncols(),
                feature_names_in: None,
            },
            nodes: HashMap::new(),
            edges: HashMap::new(),
            execution_order: Vec::new(),
            parallel_groups: Vec::new(),
            cache: HashMap::new(),
        })
    }
}

impl DAGPipeline<Untrained> {
    /// Executes one topological level of nodes and collects their results.
    fn execute_parallel_group(
        &mut self,
        group: &[String],
    ) -> SklResult<Vec<(String, SklResult<NodeOutput>)>> {
        let mut results = Vec::new();

        for node_id in group {
            if let Some(node) = self.nodes.remove(node_id) {
                let result = self.execute_node(&node);
                results.push((node_id.clone(), result));
                // Re-insert the node; it was removed only to avoid borrowing
                // `self.nodes` while `execute_node` borrows `self` mutably.
                self.nodes.insert(node_id.clone(), node);
            }
        }

        Ok(results)
    }

    /// Executes a single node against the cached outputs of its dependencies.
    fn execute_node(&mut self, node: &DAGNode) -> SklResult<NodeOutput> {
        let start_time = std::time::SystemTime::now();

        // Collect the outputs of all dependencies from the cache.
        let mut inputs = Vec::new();
        for dep_id in &node.dependencies {
            if let Some(output) = self.cache.get(dep_id) {
                inputs.push(output.clone());
            } else if dep_id == "input" {
                // The synthetic "input" entry is handled by the fallback below.
            } else {
                return Err(SklearsError::InvalidInput(format!(
                    "Missing input from dependency: {dep_id}"
                )));
            }
        }

        // Nodes without dependencies fall back to the raw pipeline input.
        if inputs.is_empty() && self.cache.contains_key("input") {
            inputs.push(self.cache["input"].clone());
        }

        let mut result = match &node.component {
            NodeComponent::Transformer(transformer) => {
                if let Some(input) = inputs.first() {
                    let mapped_data = input.data.view().mapv(|v| v as Float);
                    let transformed = transformer.transform(&mapped_data.view())?;
                    Ok(NodeOutput {
                        data: transformed,
                        targets: input.targets.clone(),
                        metadata: HashMap::new(),
                        execution_stats: ExecutionStats::default(),
                    })
                } else {
                    Err(SklearsError::InvalidInput(
                        "No input data for transformer".to_string(),
                    ))
                }
            }
            NodeComponent::DataMerger { merge_strategy } => {
                self.execute_data_merger(&inputs, merge_strategy)
            }
            NodeComponent::ConditionalBranch {
                condition,
                true_path,
                false_path,
            } => self.execute_conditional_branch(&inputs, condition, true_path, false_path),
            NodeComponent::DataSource { data, targets } => {
                if let Some(ref source_data) = data {
                    Ok(NodeOutput {
                        data: source_data.clone(),
                        targets: targets.clone(),
                        metadata: HashMap::new(),
                        execution_stats: ExecutionStats::default(),
                    })
                } else {
                    Err(SklearsError::InvalidInput(
                        "No data in data source".to_string(),
                    ))
                }
            }
            NodeComponent::DataSink => {
                // A sink simply forwards its (single) input.
                inputs
                    .into_iter()
                    .next()
                    .ok_or_else(|| SklearsError::InvalidInput("No input for data sink".to_string()))
            }
            NodeComponent::Estimator(_) => {
                // Estimator nodes currently pass their input through unchanged.
                if let Some(input) = inputs.first() {
                    Ok(input.clone())
                } else {
                    Err(SklearsError::InvalidInput(
                        "No input data for estimator".to_string(),
                    ))
                }
            }
            NodeComponent::CustomFunction { function } => function(&inputs),
        };

        // Record the elapsed time on the successful output itself.
        let execution_time = start_time.elapsed().unwrap().as_secs_f64();
        if let Ok(ref mut output) = result {
            output.execution_stats.execution_time = execution_time;
        }

        result
    }

    /// Merges multiple node outputs according to the given strategy.
    fn execute_data_merger(
        &self,
        inputs: &[NodeOutput],
        strategy: &MergeStrategy,
    ) -> SklResult<NodeOutput> {
        if inputs.is_empty() {
            return Err(SklearsError::InvalidInput("No inputs to merge".to_string()));
        }

        let merged_data = match strategy {
            MergeStrategy::HorizontalConcat => {
                let total_cols: usize = inputs.iter().map(|inp| inp.data.ncols()).sum();
                let n_rows = inputs[0].data.nrows();

                let mut merged = Array2::zeros((n_rows, total_cols));
                let mut col_offset = 0;

                for input in inputs {
                    let cols = input.data.ncols();
                    merged
                        .slice_mut(s![.., col_offset..col_offset + cols])
                        .assign(&input.data);
                    col_offset += cols;
                }

                merged
            }
            MergeStrategy::VerticalConcat => {
                let n_cols = inputs[0].data.ncols();
                let total_rows: usize = inputs.iter().map(|inp| inp.data.nrows()).sum();

                let mut merged = Array2::zeros((total_rows, n_cols));
                let mut row_offset = 0;

                for input in inputs {
                    let rows = input.data.nrows();
                    merged
                        .slice_mut(s![row_offset..row_offset + rows, ..])
                        .assign(&input.data);
                    row_offset += rows;
                }

                merged
            }
            MergeStrategy::Average => {
                let mut sum = inputs[0].data.clone();
                for input in inputs.iter().skip(1) {
                    sum += &input.data;
                }
                sum / inputs.len() as f64
            }
            MergeStrategy::WeightedAverage { weights } => {
                if weights.len() != inputs.len() {
                    return Err(SklearsError::InvalidInput(
                        "Number of weights must match number of inputs".to_string(),
                    ));
                }

                let mut weighted_sum = &inputs[0].data * weights[0];
                for (input, &weight) in inputs.iter().skip(1).zip(weights.iter().skip(1)) {
                    weighted_sum += &(&input.data * weight);
                }

                weighted_sum
            }
            MergeStrategy::Maximum => {
                let mut max_data = inputs[0].data.clone();
                for input in inputs.iter().skip(1) {
                    for ((i, j), &val) in input.data.indexed_iter() {
                        if val > max_data[(i, j)] {
                            max_data[(i, j)] = val;
                        }
                    }
                }
                max_data
            }
            MergeStrategy::Minimum => {
                let mut min_data = inputs[0].data.clone();
                for input in inputs.iter().skip(1) {
                    for ((i, j), &val) in input.data.indexed_iter() {
                        if val < min_data[(i, j)] {
                            min_data[(i, j)] = val;
                        }
                    }
                }
                min_data
            }
            MergeStrategy::Custom { merge_fn } => {
                return merge_fn(inputs);
            }
        };

        Ok(NodeOutput {
            data: merged_data,
            targets: inputs[0].targets.clone(),
            metadata: HashMap::new(),
            execution_stats: ExecutionStats::default(),
        })
    }

    /// Evaluates a branch condition and records which path was taken.
    fn execute_conditional_branch(
        &self,
        inputs: &[NodeOutput],
        condition: &BranchCondition,
        true_path: &str,
        false_path: &str,
    ) -> SklResult<NodeOutput> {
        if inputs.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No input for conditional branch".to_string(),
            ));
        }

        let input = &inputs[0];
        let condition_result = match condition {
            BranchCondition::FeatureThreshold {
                feature_idx,
                threshold,
                comparison,
            } => {
                if *feature_idx >= input.data.ncols() {
                    return Err(SklearsError::InvalidInput(
                        "Feature index out of bounds".to_string(),
                    ));
                }

                // The decision is based on the mean of the selected column.
                let feature_values = input.data.column(*feature_idx);
                let mean_value = feature_values.mean().unwrap_or(0.0);

                match comparison {
                    ComparisonOp::GreaterThan => mean_value > *threshold,
                    ComparisonOp::LessThan => mean_value < *threshold,
                    ComparisonOp::GreaterEqual => mean_value >= *threshold,
                    ComparisonOp::LessEqual => mean_value <= *threshold,
                    ComparisonOp::Equal => (mean_value - *threshold).abs() < 1e-8,
                    ComparisonOp::NotEqual => (mean_value - *threshold).abs() >= 1e-8,
                }
            }
            BranchCondition::DataSize {
                min_samples,
                max_samples,
            } => {
                let n_samples = input.data.nrows();
                let min_ok = min_samples.map_or(true, |min| n_samples >= min);
                let max_ok = max_samples.map_or(true, |max| n_samples <= max);
                min_ok && max_ok
            }
            BranchCondition::Custom { condition_fn } => condition_fn(input),
        };

        // Forward the input, annotated with the chosen path.
        let mut output = input.clone();
        output.metadata.insert(
            "branch_taken".to_string(),
            if condition_result {
                true_path.to_string()
            } else {
                false_path.to_string()
            },
        );

        Ok(output)
    }
}

impl DAGPipeline<DAGPipelineTrained> {
    /// Transforms new data with the fitted pipeline.
    ///
    /// Currently a simplified pass-through that returns the input unchanged;
    /// it does not yet replay the fitted DAG.
    pub fn transform(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
        Ok(x.mapv(|v| v))
    }

    /// Returns the history of recorded executions.
    #[must_use]
    pub fn execution_history(&self) -> &[ExecutionRecord] {
        &self.state.execution_history
    }

    /// Returns summary statistics about the fitted DAG and its last run.
    #[must_use]
    pub fn statistics(&self) -> HashMap<String, f64> {
        let mut stats = HashMap::new();
        stats.insert(
            "total_nodes".to_string(),
            self.state.fitted_nodes.len() as f64,
        );
        stats.insert(
            "parallel_groups".to_string(),
            self.state.parallel_groups.len() as f64,
        );

        if let Some(last_execution) = self.state.execution_history.last() {
            stats.insert("last_execution_time".to_string(), last_execution.total_time);
            stats.insert(
                "last_execution_success".to_string(),
                if last_execution.success { 1.0 } else { 0.0 },
            );
        }

        stats
    }

    /// Renders the DAG in Graphviz DOT format.
    #[must_use]
    pub fn to_dot(&self) -> String {
        let mut dot = String::from("digraph DAG {\n");

        // One labelled vertex per node.
        for (node_id, node) in &self.state.fitted_nodes {
            dot.push_str(&format!(" \"{}\" [label=\"{}\"];\n", node_id, node.name));
        }

        // Edges point from each dependency to its consumer.
        for (to_node, dependencies) in &self.state.edges {
            for from_node in dependencies {
                dot.push_str(&format!(" \"{from_node}\" -> \"{to_node}\";\n"));
            }
        }

        dot.push_str("}\n");
        dot
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MockTransformer;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_dag_node_creation() {
        let node = DAGNode {
            id: "test_node".to_string(),
            name: "Test Node".to_string(),
            component: NodeComponent::DataSource {
                data: Some(array![[1.0, 2.0], [3.0, 4.0]]),
                targets: Some(array![1.0, 0.0]),
            },
            dependencies: Vec::new(),
            consumers: Vec::new(),
            metadata: HashMap::new(),
            config: NodeConfig::default(),
        };

        assert_eq!(node.id, "test_node");
        assert_eq!(node.name, "Test Node");
    }

    #[test]
    fn test_linear_dag() {
        let components = vec![
            (
                "transformer1".to_string(),
                Box::new(MockTransformer::new()) as Box<dyn PipelineStep>,
            ),
            (
                "transformer2".to_string(),
                Box::new(MockTransformer::new()) as Box<dyn PipelineStep>,
            ),
        ];

        let dag = DAGPipeline::linear(components).unwrap();
        assert_eq!(dag.nodes.len(), 2);
        assert_eq!(dag.execution_order.len(), 2);
    }

    #[test]
    fn test_parallel_dag() {
        let components = vec![
            (
                "transformer1".to_string(),
                Box::new(MockTransformer::new()) as Box<dyn PipelineStep>,
            ),
            (
                "transformer2".to_string(),
                Box::new(MockTransformer::new()) as Box<dyn PipelineStep>,
            ),
        ];

        let dag = DAGPipeline::parallel(components, MergeStrategy::HorizontalConcat).unwrap();
        // Two parallel transformers plus the merger node.
        assert_eq!(dag.nodes.len(), 3);
    }

    #[test]
    fn test_cycle_detection() {
        let mut dag = DAGPipeline::new();

        let node1 = DAGNode {
            id: "node1".to_string(),
            name: "Node 1".to_string(),
            component: NodeComponent::DataSource {
                data: None,
                targets: None,
            },
            dependencies: vec![],
            consumers: vec![],
            metadata: HashMap::new(),
            config: NodeConfig::default(),
        };

        let node2 = DAGNode {
            id: "node2".to_string(),
            name: "Node 2".to_string(),
            component: NodeComponent::DataSource {
                data: None,
                targets: None,
            },
            dependencies: vec![],
            consumers: vec![],
            metadata: HashMap::new(),
            config: NodeConfig::default(),
        };

        dag = dag.add_node(node1).unwrap();
        dag = dag.add_node(node2).unwrap();

        dag = dag.add_edge("node1", "node2").unwrap();

        // The reverse edge would close a cycle and must be rejected.
        assert!(dag.add_edge("node2", "node1").is_err());
    }

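    // Added check (sketch): exercises `BranchCondition::DataSize` through the
    // same internal entry point the merge test uses; values are illustrative.
    #[test]
    fn test_data_size_branch_condition() {
        let input = NodeOutput {
            data: array![[1.0, 2.0], [3.0, 4.0]],
            targets: None,
            metadata: HashMap::new(),
            execution_stats: ExecutionStats::default(),
        };
        let dag = DAGPipeline::new();

        // Two samples fall below the minimum of three, so the false path is taken.
        let condition = BranchCondition::DataSize {
            min_samples: Some(3),
            max_samples: None,
        };
        let output = dag
            .execute_conditional_branch(&[input], &condition, "big", "small")
            .unwrap();
        assert_eq!(output.metadata.get("branch_taken").unwrap(), "small");
    }
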
    #[test]
    fn test_merge_strategies() {
        let input1 = NodeOutput {
            data: array![[1.0, 2.0], [3.0, 4.0]],
            targets: None,
            metadata: HashMap::new(),
            execution_stats: ExecutionStats::default(),
        };

        let input2 = NodeOutput {
            data: array![[5.0, 6.0], [7.0, 8.0]],
            targets: None,
            metadata: HashMap::new(),
            execution_stats: ExecutionStats::default(),
        };

        let inputs = vec![input1, input2];
        let dag = DAGPipeline::new();

        // Horizontal concatenation doubles the column count.
        let result = dag
            .execute_data_merger(&inputs, &MergeStrategy::HorizontalConcat)
            .unwrap();
        assert_eq!(result.data.ncols(), 4);
        assert_eq!(result.data.nrows(), 2);

        // Averaging: (1 + 5) / 2 = 3 and (2 + 6) / 2 = 4.
        let result = dag
            .execute_data_merger(&inputs, &MergeStrategy::Average)
            .unwrap();
        assert_eq!(result.data[[0, 0]], 3.0);
        assert_eq!(result.data[[0, 1]], 4.0);
    }

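    // Added checks (sketch): cover the remaining element-wise merge
    // strategies with the same fixture values used above.
    #[test]
    fn test_extreme_and_weighted_merges() {
        let make = |data: Array2<f64>| NodeOutput {
            data,
            targets: None,
            metadata: HashMap::new(),
            execution_stats: ExecutionStats::default(),
        };
        let inputs = vec![
            make(array![[1.0, 2.0], [3.0, 4.0]]),
            make(array![[5.0, 6.0], [7.0, 8.0]]),
        ];
        let dag = DAGPipeline::new();

        // Element-wise maximum picks the second input everywhere.
        let result = dag
            .execute_data_merger(&inputs, &MergeStrategy::Maximum)
            .unwrap();
        assert_eq!(result.data[[0, 0]], 5.0);

        // Element-wise minimum picks the first input everywhere.
        let result = dag
            .execute_data_merger(&inputs, &MergeStrategy::Minimum)
            .unwrap();
        assert_eq!(result.data[[1, 1]], 4.0);

        // Weighted average: 0.25 * 1 + 0.75 * 5 = 4.
        let result = dag
            .execute_data_merger(
                &inputs,
                &MergeStrategy::WeightedAverage {
                    weights: vec![0.25, 0.75],
                },
            )
            .unwrap();
        assert_eq!(result.data[[0, 0]], 4.0);

        // A mismatched weight count is rejected.
        assert!(dag
            .execute_data_merger(&inputs, &MergeStrategy::WeightedAverage { weights: vec![1.0] })
            .is_err());
    }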
}