use scirs2_core::ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    types::Float,
};
use std::collections::{HashMap, HashSet, VecDeque};
use std::time::{Duration, Instant};

use super::component_registry::ComponentRegistry;
use super::workflow_definitions::{Connection, ExecutionMode, StepDefinition, WorkflowDefinition};

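/// Executes workflow definitions: validates them against the component
/// registry, resolves a step order, and runs each step while tracking data
/// flow and execution statistics.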
#[derive(Debug)]
pub struct WorkflowExecutor {
    registry: ComponentRegistry,
    context: ExecutionContext,
    stats: ExecutionStatistics,
}

#[derive(Debug, Clone)]
pub struct ExecutionContext {
    pub execution_id: String,
    pub workflow: WorkflowDefinition,
    pub data_flow: HashMap<String, StepData>,
    pub start_time: Instant,
    pub current_step: Option<String>,
    pub execution_mode: ExecutionMode,
}

#[derive(Debug, Clone)]
pub struct StepData {
    pub step_id: String,
    pub port_name: String,
    pub matrices: HashMap<String, Array2<Float>>,
    pub arrays: HashMap<String, Array1<Float>>,
    pub metadata: HashMap<String, String>,
    pub timestamp: Instant,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionResult {
    pub execution_id: String,
    pub success: bool,
    pub duration: Duration,
    pub step_results: Vec<StepExecutionResult>,
    pub outputs: HashMap<String, String>,
    pub error: Option<String>,
    pub performance: PerformanceMetrics,
}

impl Default for ExecutionResult {
    fn default() -> Self {
        Self {
            execution_id: "unknown".to_string(),
            success: false,
            duration: Duration::from_secs(0),
            step_results: Vec::new(),
            outputs: HashMap::new(),
            error: Some("Execution failed".to_string()),
            performance: PerformanceMetrics::default(),
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepExecutionResult {
    pub step_id: String,
    pub algorithm: String,
    pub success: bool,
    pub duration: Duration,
    pub memory_usage: u64,
    pub output_sizes: HashMap<String, usize>,
    pub error: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceMetrics {
    pub total_time: Duration,
    pub peak_memory: u64,
    pub cpu_utilization: f64,
    pub throughput: f64,
    pub parallelism_efficiency: f64,
}

#[derive(Debug, Clone)]
pub struct ExecutionStatistics {
    pub total_executions: u64,
    pub successful_executions: u64,
    pub failed_executions: u64,
    pub average_execution_time: Duration,
    pub step_execution_counts: HashMap<String, u64>,
}

#[derive(Debug, Clone)]
pub struct ValidationResult {
    pub is_valid: bool,
    pub errors: Vec<ValidationError>,
    pub warnings: Vec<ValidationWarning>,
    pub execution_order: Option<Vec<String>>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationError {
    pub error_type: String,
    pub message: String,
    pub step_id: Option<String>,
    pub connection: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationWarning {
    pub warning_type: String,
    pub message: String,
    pub step_id: Option<String>,
}

impl WorkflowExecutor {
    #[must_use]
    pub fn new() -> Self {
        Self {
            registry: ComponentRegistry::new(),
            context: ExecutionContext::new(),
            stats: ExecutionStatistics::new(),
        }
    }

    #[must_use]
    pub fn with_registry(registry: ComponentRegistry) -> Self {
        Self {
            registry,
            context: ExecutionContext::new(),
            stats: ExecutionStatistics::new(),
        }
    }

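    /// Validates a workflow definition: checks each step against the component
    /// registry, verifies that every connection references existing steps and
    /// ports, and rejects circular dependencies. When no errors are found, a
    /// topological execution order is included in the result.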
    #[must_use]
    pub fn validate_workflow(&self, workflow: &WorkflowDefinition) -> ValidationResult {
        let mut errors = Vec::new();
        let mut warnings = Vec::new();

        if workflow.steps.is_empty() {
            errors.push(ValidationError {
                error_type: "EmptyWorkflow".to_string(),
                message: "Workflow has no steps".to_string(),
                step_id: None,
                connection: None,
            });
            return ValidationResult {
                is_valid: false,
                errors,
                warnings,
                execution_order: None,
            };
        }

        for step in &workflow.steps {
            self.validate_step(step, &mut errors, &mut warnings);
        }

        for connection in &workflow.connections {
            self.validate_connection(connection, workflow, &mut errors, &mut warnings);
        }

        if let Err(cycle_error) = self.check_circular_dependencies(workflow) {
            errors.push(ValidationError {
                error_type: "CircularDependency".to_string(),
                message: cycle_error,
                step_id: None,
                connection: None,
            });
        }

        let execution_order = if errors.is_empty() {
            self.determine_execution_order(workflow).ok()
        } else {
            None
        };

        ValidationResult {
            is_valid: errors.is_empty(),
            errors,
            warnings,
            execution_order,
        }
    }

    fn validate_step(
        &self,
        step: &StepDefinition,
        errors: &mut Vec<ValidationError>,
        warnings: &mut Vec<ValidationWarning>,
    ) {
        if !self.registry.has_component(&step.algorithm) {
            errors.push(ValidationError {
                error_type: "UnknownComponent".to_string(),
                message: format!("Component '{}' not found in registry", step.algorithm),
                step_id: Some(step.id.clone()),
                connection: None,
            });
            return;
        }

        if let Err(param_error) = self
            .registry
            .validate_parameters(&step.algorithm, &step.parameters)
        {
            errors.push(ValidationError {
                error_type: "InvalidParameters".to_string(),
                message: param_error.to_string(),
                step_id: Some(step.id.clone()),
                connection: None,
            });
        }

        if let Some(component) = self.registry.get_component(&step.algorithm) {
            if component.deprecated {
                warnings.push(ValidationWarning {
                    warning_type: "DeprecatedComponent".to_string(),
                    message: format!("Component '{}' is deprecated", step.algorithm),
                    step_id: Some(step.id.clone()),
                });
            }
        }
    }

    fn validate_connection(
        &self,
        connection: &Connection,
        workflow: &WorkflowDefinition,
        errors: &mut Vec<ValidationError>,
        _warnings: &mut Vec<ValidationWarning>,
    ) {
        // Shared human-readable label for this connection, reused in every error.
        let connection_desc = format!(
            "{}:{} -> {}:{}",
            connection.from_step, connection.from_output, connection.to_step, connection.to_input
        );

        let source_step = workflow.steps.iter().find(|s| s.id == connection.from_step);
        if source_step.is_none() {
            errors.push(ValidationError {
                error_type: "InvalidConnection".to_string(),
                message: format!("Source step '{}' not found", connection.from_step),
                step_id: None,
                connection: Some(connection_desc),
            });
            return;
        }

        let target_step = workflow.steps.iter().find(|s| s.id == connection.to_step);
        if target_step.is_none() {
            errors.push(ValidationError {
                error_type: "InvalidConnection".to_string(),
                message: format!("Target step '{}' not found", connection.to_step),
                step_id: None,
                connection: Some(connection_desc),
            });
            return;
        }

        let source = source_step.unwrap();
        if !source.outputs.contains(&connection.from_output) {
            errors.push(ValidationError {
                error_type: "InvalidConnection".to_string(),
                message: format!(
                    "Step '{}' does not have output '{}'",
                    connection.from_step, connection.from_output
                ),
                step_id: Some(source.id.clone()),
                connection: Some(connection_desc.clone()),
            });
        }

        let target = target_step.unwrap();
        if !target.inputs.contains(&connection.to_input) {
            errors.push(ValidationError {
                error_type: "InvalidConnection".to_string(),
                message: format!(
                    "Step '{}' does not have input '{}'",
                    connection.to_step, connection.to_input
                ),
                step_id: Some(target.id.clone()),
                connection: Some(connection_desc),
            });
        }
    }

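    /// Builds a dependency graph from the workflow connections and runs a
    /// depth-first search with a recursion stack to detect cycles, returning a
    /// description of the offending step if one is found.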
    pub fn check_circular_dependencies(&self, workflow: &WorkflowDefinition) -> Result<(), String> {
        let mut graph = HashMap::new();

        for step in &workflow.steps {
            graph.insert(step.id.clone(), HashSet::new());
        }

        for connection in &workflow.connections {
            if let Some(dependencies) = graph.get_mut(&connection.to_step) {
                dependencies.insert(connection.from_step.clone());
            }
        }

        let mut visited = HashSet::new();
        let mut rec_stack = HashSet::new();

        for step_id in graph.keys() {
            if !visited.contains(step_id)
                && self.has_cycle_dfs(step_id, &graph, &mut visited, &mut rec_stack)
            {
                return Err(format!(
                    "Circular dependency detected involving step '{step_id}'"
                ));
            }
        }

        Ok(())
    }

    fn has_cycle_dfs(
        &self,
        step_id: &str,
        graph: &HashMap<String, HashSet<String>>,
        visited: &mut HashSet<String>,
        rec_stack: &mut HashSet<String>,
    ) -> bool {
        visited.insert(step_id.to_string());
        rec_stack.insert(step_id.to_string());

        if let Some(dependencies) = graph.get(step_id) {
            for dep in dependencies {
                if !visited.contains(dep) {
                    if self.has_cycle_dfs(dep, graph, visited, rec_stack) {
                        return true;
                    }
                } else if rec_stack.contains(dep) {
                    return true;
                }
            }
        }

        rec_stack.remove(step_id);
        false
    }

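    /// Computes a topological execution order with Kahn's algorithm: steps with
    /// no unresolved dependencies are dequeued first, and an error is returned
    /// if a cycle prevents every step from being scheduled.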
    pub fn determine_execution_order(
        &self,
        workflow: &WorkflowDefinition,
    ) -> SklResult<Vec<String>> {
        let mut in_degree = HashMap::new();
        let mut adj_list = HashMap::new();

        for step in &workflow.steps {
            in_degree.insert(step.id.clone(), 0);
            adj_list.insert(step.id.clone(), Vec::new());
        }

        for connection in &workflow.connections {
            *in_degree.get_mut(&connection.to_step).unwrap() += 1;
            adj_list
                .get_mut(&connection.from_step)
                .unwrap()
                .push(connection.to_step.clone());
        }

        let mut queue = VecDeque::new();
        let mut result = Vec::new();

        for (step_id, degree) in &in_degree {
            if *degree == 0 {
                queue.push_back(step_id.clone());
            }
        }

        while let Some(current) = queue.pop_front() {
            result.push(current.clone());

            for neighbor in &adj_list[&current] {
                *in_degree.get_mut(neighbor).unwrap() -= 1;
                if in_degree[neighbor] == 0 {
                    queue.push_back(neighbor.clone());
                }
            }
        }

        if result.len() != workflow.steps.len() {
            return Err(SklearsError::InvalidInput(
                "Circular dependency detected".to_string(),
            ));
        }

        Ok(result)
    }

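    /// Validates and then executes the given workflow, returning an
    /// `ExecutionResult` that records per-step outcomes, aggregate timing, and
    /// any error encountered. Validation failures are reported as an
    /// unsuccessful result rather than as an `Err`.
    ///
    /// A minimal usage sketch (not compiled as a doctest), reusing the
    /// constructors exercised in the tests at the bottom of this file:
    ///
    /// ```ignore
    /// let mut executor = WorkflowExecutor::new();
    /// let mut workflow = WorkflowDefinition::default();
    /// workflow.steps.push(StepDefinition::new(
    ///     "scale",
    ///     StepType::Transformer,
    ///     "StandardScaler",
    /// ));
    /// let result = executor.execute_workflow(workflow).expect("execution failed");
    /// assert!(result.success);
    /// ```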
    pub fn execute_workflow(&mut self, workflow: WorkflowDefinition) -> SklResult<ExecutionResult> {
        let execution_start = Instant::now();
        let execution_id = uuid::Uuid::new_v4().to_string();

        let validation = self.validate_workflow(&workflow);
        if !validation.is_valid {
            return Ok(ExecutionResult {
                execution_id,
                success: false,
                duration: execution_start.elapsed(),
                step_results: Vec::new(),
                outputs: HashMap::new(),
                error: Some(format!(
                    "Workflow validation failed: {:?}",
                    validation.errors
                )),
                performance: PerformanceMetrics::default(),
            });
        }

        self.context = ExecutionContext {
            execution_id: execution_id.clone(),
            workflow: workflow.clone(),
            data_flow: HashMap::new(),
            start_time: execution_start,
            current_step: None,
            execution_mode: workflow.execution.mode.clone(),
        };

        let execution_order = validation.execution_order.unwrap();
        let mut step_results = Vec::new();
        let mut success = true;
        let mut error_message = None;

        for step_id in execution_order {
            let step = workflow.steps.iter().find(|s| s.id == step_id).unwrap();
            self.context.current_step = Some(step_id.clone());

            match self.execute_step(step) {
                Ok(step_result) => {
                    step_results.push(step_result);
                }
                Err(e) => {
                    success = false;
                    error_message = Some(e.to_string());
                    step_results.push(StepExecutionResult {
                        step_id: step_id.clone(),
                        algorithm: step.algorithm.clone(),
                        success: false,
                        duration: Duration::from_millis(0),
                        memory_usage: 0,
                        output_sizes: HashMap::new(),
                        error: Some(e.to_string()),
                    });
                    break;
                }
            }
        }

        self.stats.total_executions += 1;
        if success {
            self.stats.successful_executions += 1;
        } else {
            self.stats.failed_executions += 1;
        }

        let total_duration = execution_start.elapsed();
        self.stats.average_execution_time = Duration::from_millis(
            (((self.stats.average_execution_time.as_millis()
                * u128::from(self.stats.total_executions - 1))
                + total_duration.as_millis())
                / u128::from(self.stats.total_executions))
            .try_into()
            .unwrap_or(u64::MAX),
        );

        Ok(ExecutionResult {
            execution_id,
            success,
            duration: total_duration,
            step_results,
            outputs: self.extract_final_outputs(&workflow),
            error: error_message,
            performance: self.calculate_performance_metrics(execution_start),
        })
    }

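    /// Executes a single step: looks up its component, gathers data produced by
    /// upstream steps, runs the simulated execution, stores the outputs for
    /// downstream consumers, and updates per-algorithm statistics.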
    fn execute_step(&mut self, step: &StepDefinition) -> SklResult<StepExecutionResult> {
        let step_start = Instant::now();

        // Ensure the component exists before doing any work for this step.
        self.registry
            .get_component(&step.algorithm)
            .ok_or_else(|| {
                SklearsError::InvalidInput(format!("Component '{}' not found", step.algorithm))
            })?;

        let input_data = self.prepare_step_input(step)?;

        let output_data = self.simulate_step_execution(step, &input_data)?;

        self.store_step_output(step, output_data.clone());

        *self
            .stats
            .step_execution_counts
            .entry(step.algorithm.clone())
            .or_insert(0) += 1;

        Ok(StepExecutionResult {
            step_id: step.id.clone(),
            algorithm: step.algorithm.clone(),
            success: true,
            duration: step_start.elapsed(),
            memory_usage: self.estimate_memory_usage(&output_data),
            output_sizes: output_data
                .matrices
                .iter()
                .map(|(k, v)| (k.clone(), v.len()))
                .collect(),
            error: None,
        })
    }

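    /// Collects the data produced by upstream steps that are connected to this
    /// step's inputs, merging their matrices and arrays into a single payload.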
    fn prepare_step_input(&self, step: &StepDefinition) -> SklResult<StepData> {
        let mut input_data = StepData {
            step_id: step.id.clone(),
            port_name: "input".to_string(),
            matrices: HashMap::new(),
            arrays: HashMap::new(),
            metadata: HashMap::new(),
            timestamp: Instant::now(),
        };

        for connection in &self.context.workflow.connections {
            if connection.to_step == step.id {
                let source_data_key =
                    format!("{}:{}", connection.from_step, connection.from_output);
                if let Some(source_data) = self.context.data_flow.get(&source_data_key) {
                    for (key, matrix) in &source_data.matrices {
                        input_data.matrices.insert(key.clone(), matrix.clone());
                    }
                    for (key, array) in &source_data.arrays {
                        input_data.arrays.insert(key.clone(), array.clone());
                    }
                }
            }
        }

        Ok(input_data)
    }

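    /// Simulates execution of a step: a few known algorithms get bespoke
    /// handling (`StandardScaler`, `LinearRegression`), and everything else
    /// passes its input data through unchanged.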
    fn simulate_step_execution(
        &self,
        step: &StepDefinition,
        input_data: &StepData,
    ) -> SklResult<StepData> {
        let mut output_data = StepData {
            step_id: step.id.clone(),
            port_name: "output".to_string(),
            matrices: HashMap::new(),
            arrays: HashMap::new(),
            metadata: HashMap::new(),
            timestamp: Instant::now(),
        };

        match step.algorithm.as_str() {
            "StandardScaler" => {
                if let Some(input_matrix) = input_data.matrices.get("X") {
                    // Simulated scaling: pass the matrix through unchanged.
                    let scaled_matrix = input_matrix.clone();
                    output_data
                        .matrices
                        .insert("X_scaled".to_string(), scaled_matrix);
                }
            }
            "LinearRegression" => {
                if input_data.matrices.contains_key("X") && input_data.arrays.contains_key("y") {
                    // Simulated training: record model metadata only.
                    output_data
                        .metadata
                        .insert("model_type".to_string(), "LinearRegression".to_string());
                    output_data
                        .metadata
                        .insert("trained".to_string(), "true".to_string());
                }
            }
            _ => {
                // Unknown algorithms pass their inputs through unchanged.
                output_data.matrices = input_data.matrices.clone();
                output_data.arrays = input_data.arrays.clone();
            }
        }

        Ok(output_data)
    }

    fn store_step_output(&mut self, step: &StepDefinition, output_data: StepData) {
        for output_name in &step.outputs {
            let key = format!("{}:{}", step.id, output_name);
            self.context.data_flow.insert(key, output_data.clone());
        }
    }

    fn extract_final_outputs(&self, workflow: &WorkflowDefinition) -> HashMap<String, String> {
        let mut outputs = HashMap::new();

        for output in &workflow.outputs {
            for step in &workflow.steps {
                if step.outputs.contains(&output.name) {
                    let key = format!("{}:{}", step.id, output.name);
                    if self.context.data_flow.contains_key(&key) {
                        outputs.insert(
                            output.name.clone(),
                            format!("Data from step '{}' port '{}'", step.id, output.name),
                        );
                    }
                }
            }
        }

        outputs
    }

    fn calculate_performance_metrics(&self, start_time: Instant) -> PerformanceMetrics {
        PerformanceMetrics {
            total_time: start_time.elapsed(),
            // Placeholder values: detailed resource metrics are not yet measured.
            peak_memory: 0,
            cpu_utilization: 0.0,
            throughput: 0.0,
            parallelism_efficiency: 1.0,
        }
    }

    fn estimate_memory_usage(&self, _data: &StepData) -> u64 {
        // Placeholder estimate: 1 MiB per step output.
        1024 * 1024
    }

    #[must_use]
    pub fn get_statistics(&self) -> &ExecutionStatistics {
        &self.stats
    }
}

impl ExecutionContext {
    fn new() -> Self {
        Self {
            execution_id: String::new(),
            workflow: WorkflowDefinition::default(),
            data_flow: HashMap::new(),
            start_time: Instant::now(),
            current_step: None,
            execution_mode: ExecutionMode::Sequential,
        }
    }
}

impl ExecutionStatistics {
    fn new() -> Self {
        Self {
            total_executions: 0,
            successful_executions: 0,
            failed_executions: 0,
            average_execution_time: Duration::from_secs(0),
            step_execution_counts: HashMap::new(),
        }
    }
}

impl Default for PerformanceMetrics {
    fn default() -> Self {
        Self {
            total_time: Duration::from_secs(0),
            peak_memory: 0,
            cpu_utilization: 0.0,
            throughput: 0.0,
            parallelism_efficiency: 0.0,
        }
    }
}

impl Default for WorkflowExecutor {
    fn default() -> Self {
        Self::new()
    }
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ExecutionState {
    Initializing,
    Preparing,
    Running,
    Paused,
    Completed,
    Failed,
    Cancelled,
    TimedOut,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionTracker {
    pub state: ExecutionState,
    pub progress: f64,
    pub current_step: Option<String>,
    pub completed_steps: Vec<String>,
    pub failed_steps: Vec<String>,
    pub start_time: String,
    pub estimated_completion: Option<String>,
    pub errors: Vec<String>,
    pub warnings: Vec<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelExecutionConfig {
    pub max_workers: usize,
    pub queue_size: usize,
    pub load_balancing: LoadBalancingStrategy,
    pub thread_pool: ThreadPoolConfig,
    pub resource_sharing: ResourceSharingStrategy,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LoadBalancingStrategy {
    RoundRobin,
    LeastLoaded,
    Random,
    WorkStealing,
    Custom(String),
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThreadPoolConfig {
    pub core_threads: usize,
    pub max_threads: usize,
    pub keep_alive_sec: u64,
    pub stack_size: Option<usize>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ResourceSharingStrategy {
    Exclusive,
    Shared,
    CopyOnWrite,
    MemoryMapped,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceAllocation {
    pub cpu: CpuAllocation,
    pub memory: MemoryAllocation,
    pub gpu: Option<GpuAllocation>,
    pub disk: Option<DiskAllocation>,
    pub network: Option<NetworkAllocation>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CpuAllocation {
    pub cores: usize,
    pub utilization_limit: f64,
    pub affinity: Vec<usize>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryAllocation {
    pub max_memory_mb: usize,
    pub memory_type: MemoryType,
    pub allow_swap: bool,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MemoryType {
    Ram,
    Hbm,
    Nvram,
    Any,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuAllocation {
    pub device_ids: Vec<usize>,
    pub memory_per_gpu_mb: usize,
    pub min_compute_capability: f64,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiskAllocation {
    pub temp_storage_mb: usize,
    pub storage_paths: Vec<String>,
    pub io_bandwidth_mbs: Option<f64>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkAllocation {
    pub bandwidth_mbps: f64,
    pub max_connections: usize,
    pub interfaces: Vec<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceManager {
    pub available_resources: ResourcePool,
    pub allocations: HashMap<String, ResourceAllocation>,
    pub monitoring: ResourceMonitoring,
    pub scheduling_strategy: ResourceSchedulingStrategy,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourcePool {
    pub total_cpu_cores: usize,
    pub total_memory_mb: usize,
    pub gpus: Vec<GpuInfo>,
    pub disk_space_mb: usize,
    pub network_bandwidth_mbps: f64,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuInfo {
    pub id: usize,
    pub name: String,
    pub memory_mb: usize,
    pub compute_capability: f64,
    pub available: bool,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceMonitoring {
    pub enabled: bool,
    pub interval_sec: u64,
    pub thresholds: ResourceThresholds,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceThresholds {
    pub cpu_warning: f64,
    pub memory_warning: f64,
    pub disk_warning: f64,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ResourceSchedulingStrategy {
    Fcfs,
    Sjf,
    RoundRobin,
    Priority,
    FairShare,
    Custom(String),
}

#[derive(Debug, Clone, Serialize, Deserialize, thiserror::Error)]
pub enum WorkflowExecutionError {
    #[error("Workflow validation error: {0}")]
    ValidationError(String),
    #[error("Resource allocation failed: {0}")]
    ResourceAllocationError(String),
    #[error("Step execution failed for '{0}': {1}")]
    StepExecutionError(String, String),
    #[error("Dependency resolution failed: {0}")]
    DependencyError(String),
    #[error("Workflow timeout: {0}")]
    TimeoutError(String),
    #[error("Workflow cancelled: {0}")]
    CancellationError(String),
    #[error("Configuration error: {0}")]
    ConfigurationError(String),
    #[error("Runtime error: {0}")]
    RuntimeError(String),
    #[error("System error: {0}")]
    SystemError(String),
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::workflow_language::workflow_definitions::{DataType, StepType};

    #[test]
    fn test_workflow_executor_creation() {
        let executor = WorkflowExecutor::new();
        assert_eq!(executor.stats.total_executions, 0);
    }

    #[test]
    fn test_empty_workflow_validation() {
        let executor = WorkflowExecutor::new();
        let workflow = WorkflowDefinition::default();

        let validation = executor.validate_workflow(&workflow);
        assert!(!validation.is_valid);
        assert!(!validation.errors.is_empty());
        assert_eq!(validation.errors[0].error_type, "EmptyWorkflow");
    }

    #[test]
    fn test_valid_workflow_validation() {
        let executor = WorkflowExecutor::new();
        let mut workflow = WorkflowDefinition::default();

        workflow.steps.push(StepDefinition::new(
            "step1",
            StepType::Transformer,
            "StandardScaler",
        ));

        let validation = executor.validate_workflow(&workflow);
        assert!(validation.is_valid);
        assert!(validation.errors.is_empty());
        assert!(validation.execution_order.is_some());
    }

    #[test]
    fn test_unknown_component_validation() {
        let executor = WorkflowExecutor::new();
        let mut workflow = WorkflowDefinition::default();

        workflow.steps.push(StepDefinition::new(
            "step1",
            StepType::Transformer,
            "UnknownComponent",
        ));

        let validation = executor.validate_workflow(&workflow);
        assert!(!validation.is_valid);
        assert!(!validation.errors.is_empty());
        assert_eq!(validation.errors[0].error_type, "UnknownComponent");
    }

    #[test]
    fn test_execution_order_determination() {
        let executor = WorkflowExecutor::new();
        let mut workflow = WorkflowDefinition::default();

        workflow.steps.push(
            StepDefinition::new("step1", StepType::Transformer, "StandardScaler")
                .with_output("X_scaled"),
        );
        workflow.steps.push(
            StepDefinition::new("step2", StepType::Trainer, "LinearRegression").with_input("X"),
        );

        workflow
            .connections
            .push(Connection::direct("step1", "X_scaled", "step2", "X"));

        let order = executor.determine_execution_order(&workflow).unwrap();
        assert_eq!(order, vec!["step1".to_string(), "step2".to_string()]);
    }
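
    // A hedged end-to-end sketch: runs the same two-step pipeline used in
    // `test_execution_order_determination`, assuming (as `test_valid_workflow_validation`
    // above does) that the default registry recognizes "StandardScaler" and
    // "LinearRegression". With no input data wired in, `simulate_step_execution`
    // simply records both steps as successful.
    #[test]
    fn test_execute_workflow_two_step_sketch() {
        let mut executor = WorkflowExecutor::new();
        let mut workflow = WorkflowDefinition::default();

        workflow.steps.push(
            StepDefinition::new("step1", StepType::Transformer, "StandardScaler")
                .with_output("X_scaled"),
        );
        workflow.steps.push(
            StepDefinition::new("step2", StepType::Trainer, "LinearRegression").with_input("X"),
        );
        workflow
            .connections
            .push(Connection::direct("step1", "X_scaled", "step2", "X"));

        let result = executor.execute_workflow(workflow).unwrap();
        assert!(result.success, "expected the simulated two-step run to succeed");
        assert_eq!(result.step_results.len(), 2);
        assert_eq!(executor.get_statistics().total_executions, 1);
    }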

    #[test]
    fn test_circular_dependency_detection() {
        let executor = WorkflowExecutor::new();
        let mut workflow = WorkflowDefinition::default();

        workflow.steps.push(StepDefinition::new(
            "step1",
            StepType::Transformer,
            "StandardScaler",
        ));
        workflow.steps.push(StepDefinition::new(
            "step2",
            StepType::Trainer,
            "LinearRegression",
        ));

        workflow
            .connections
            .push(Connection::direct("step1", "output", "step2", "input"));
        workflow
            .connections
            .push(Connection::direct("step2", "output", "step1", "input"));

        let result = executor.check_circular_dependencies(&workflow);
        assert!(result.is_err());
    }
}
1222}