use sklears_core::error::{Result as SklResult, SklearsError};
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
use std::sync::{Arc, Condvar, Mutex, RwLock};
use std::thread::{self, JoinHandle};
use std::time::{Duration, SystemTime};

use crate::distributed::{NodeId, ResourceRequirements, TaskId, TaskPriority};

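/// A schedulable unit of work, with its dependencies, resource requirements, and retry policy.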
#[derive(Debug, Clone)]
pub struct ScheduledTask {
    pub id: TaskId,
    pub name: String,
    pub component_type: ComponentType,
    pub dependencies: Vec<TaskId>,
    pub resource_requirements: ResourceRequirements,
    pub priority: TaskPriority,
    pub estimated_duration: Duration,
    pub submitted_at: SystemTime,
    pub deadline: Option<SystemTime>,
    pub metadata: HashMap<String, String>,
    pub retry_config: RetryConfig,
}

#[derive(Debug, Clone)]
pub enum ComponentType {
    Transformer,
    Predictor,
    DataProcessor,
    CustomFunction,
}

#[derive(Debug, Clone)]
pub struct RetryConfig {
    pub max_retries: usize,
    pub delay_strategy: RetryDelayStrategy,
    pub backoff_multiplier: f64,
    pub max_delay: Duration,
}

#[derive(Debug, Clone)]
pub enum RetryDelayStrategy {
    Fixed(Duration),
    Linear(Duration),
    Exponential(Duration),
    Custom(fn(usize) -> Duration),
}

impl Default for RetryConfig {
    fn default() -> Self {
        Self {
            max_retries: 3,
            delay_strategy: RetryDelayStrategy::Exponential(Duration::from_millis(100)),
            backoff_multiplier: 2.0,
            max_delay: Duration::from_secs(60),
        }
    }
}
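
// Illustrative sketch only: one plausible way to turn a `RetryConfig` into a concrete delay
// for the nth retry attempt. This module does not define such a helper itself, so
// `compute_retry_delay` and the exact backoff formula are assumptions, not scheduler API.
#[allow(dead_code)]
fn compute_retry_delay(config: &RetryConfig, retry_count: usize) -> Duration {
    let base = match config.delay_strategy {
        RetryDelayStrategy::Fixed(delay) => delay,
        RetryDelayStrategy::Linear(step) => step * (retry_count as u32 + 1),
        RetryDelayStrategy::Exponential(start) => {
            start.mul_f64(config.backoff_multiplier.powi(retry_count as i32))
        }
        RetryDelayStrategy::Custom(delay_fn) => delay_fn(retry_count),
    };
    // Never exceed the configured ceiling.
    base.min(config.max_delay)
}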

#[derive(Debug, Clone, PartialEq)]
pub enum TaskState {
    Pending,
    Ready,
    Running {
        started_at: SystemTime,
        node_id: Option<NodeId>,
    },
    Completed {
        completed_at: SystemTime,
        execution_time: Duration,
    },
    Failed {
        failed_at: SystemTime,
        error: String,
        retry_count: usize,
    },
    Cancelled { cancelled_at: SystemTime },
    Retrying {
        next_retry_at: SystemTime,
        retry_count: usize,
    },
}

#[derive(Debug)]
struct PriorityTask {
    task: ScheduledTask,
    priority_score: i64,
}

impl PartialEq for PriorityTask {
    fn eq(&self, other: &Self) -> bool {
        self.priority_score == other.priority_score
    }
}

impl Eq for PriorityTask {}

impl PartialOrd for PriorityTask {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

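// `TaskScheduler` stores these in a `BinaryHeap`, which is a max-heap, so a larger
// `priority_score` means the task is popped (scheduled) earlier.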
impl Ord for PriorityTask {
    fn cmp(&self, other: &Self) -> Ordering {
        self.priority_score.cmp(&other.priority_score)
    }
}

#[derive(Debug, Clone)]
pub enum SchedulingStrategy {
    FIFO,
    Priority,
    ShortestJobFirst,
    EarliestDeadlineFirst,
    FairShare {
        time_quantum: Duration,
    },
    ResourceAware,
    Custom {
        schedule_fn: fn(&[ScheduledTask], &ResourcePool) -> Option<TaskId>,
    },
}

#[derive(Debug, Clone)]
pub struct ResourcePool {
    pub available_cpu: u32,
    pub available_memory: u64,
    pub available_disk: u64,
    pub available_gpu: u32,
    pub utilization_history: Vec<ResourceUtilization>,
}

#[derive(Debug, Clone)]
pub struct ResourceUtilization {
    pub timestamp: SystemTime,
    pub cpu_usage: f64,
    pub memory_usage: f64,
    pub disk_usage: f64,
    pub gpu_usage: f64,
}

impl Default for ResourcePool {
    fn default() -> Self {
        Self {
            available_cpu: 4,
            available_memory: 8192,
            available_disk: 100_000,
            available_gpu: 0,
            utilization_history: Vec::new(),
        }
    }
}

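/// Interface for swapping in custom scheduling policies at runtime.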
pub trait PluggableScheduler: Send + Sync + std::fmt::Debug {
    fn name(&self) -> &str;

    fn description(&self) -> &str;

    fn initialize(&mut self, config: &SchedulerConfig) -> SklResult<()>;

    fn select_next_task(
        &self,
        available_tasks: &[ScheduledTask],
        resource_pool: &ResourcePool,
        current_time: SystemTime,
    ) -> Option<TaskId>;

    fn calculate_priority(&self, task: &ScheduledTask, context: &SchedulingContext) -> i64;

    fn can_schedule_task(&self, task: &ScheduledTask, resource_pool: &ResourcePool) -> bool;

    fn get_metrics(&self) -> SchedulerMetrics;

    fn on_task_completed(&mut self, task_id: &TaskId, execution_time: Duration) -> SklResult<()>;

    fn on_task_failed(&mut self, task_id: &TaskId, error: &str) -> SklResult<()>;

    fn cleanup(&mut self) -> SklResult<()>;
}
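
// Minimal sketch of a custom policy behind `PluggableScheduler`, assuming a plain FIFO rule
// (oldest submission first). `FifoScheduler` is a hypothetical example type, not part of this
// module's API, and the metrics it reports are placeholders.
#[allow(dead_code)]
#[derive(Debug, Default)]
struct FifoScheduler {
    tasks_scheduled: u64,
}

impl PluggableScheduler for FifoScheduler {
    fn name(&self) -> &str {
        "fifo"
    }

    fn description(&self) -> &str {
        "Runs the oldest submitted task that currently fits the resource pool"
    }

    fn initialize(&mut self, _config: &SchedulerConfig) -> SklResult<()> {
        Ok(())
    }

    fn select_next_task(
        &self,
        available_tasks: &[ScheduledTask],
        resource_pool: &ResourcePool,
        _current_time: SystemTime,
    ) -> Option<TaskId> {
        available_tasks
            .iter()
            .filter(|task| self.can_schedule_task(task, resource_pool))
            .min_by_key(|task| task.submitted_at)
            .map(|task| task.id.clone())
    }

    fn calculate_priority(&self, task: &ScheduledTask, _context: &SchedulingContext) -> i64 {
        // Earlier submissions get a higher score under FIFO.
        -(task
            .submitted_at
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap_or(Duration::ZERO)
            .as_secs() as i64)
    }

    fn can_schedule_task(&self, _task: &ScheduledTask, resource_pool: &ResourcePool) -> bool {
        resource_pool.available_cpu > 0 && resource_pool.available_memory > 0
    }

    fn get_metrics(&self) -> SchedulerMetrics {
        SchedulerMetrics {
            tasks_scheduled: self.tasks_scheduled,
            avg_scheduling_latency: Duration::ZERO,
            resource_efficiency: 0.0,
            deadline_miss_rate: 0.0,
            fairness_index: 1.0,
            custom_metrics: HashMap::new(),
        }
    }

    fn on_task_completed(&mut self, _task_id: &TaskId, _execution_time: Duration) -> SklResult<()> {
        // Placeholder bookkeeping.
        self.tasks_scheduled += 1;
        Ok(())
    }

    fn on_task_failed(&mut self, _task_id: &TaskId, _error: &str) -> SklResult<()> {
        Ok(())
    }

    fn cleanup(&mut self) -> SklResult<()> {
        Ok(())
    }
}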

#[derive(Debug, Clone, Default)]
pub struct SchedulingContext {
    pub system_load: SystemLoad,
    pub execution_history: Vec<TaskExecutionHistory>,
    pub resource_constraints: ResourceConstraints,
    pub temporal_context: TemporalContext,
    pub custom_data: HashMap<String, String>,
}

#[derive(Debug, Clone)]
pub struct SystemLoad {
    pub cpu_utilization: f64,
    pub memory_utilization: f64,
    pub io_wait: f64,
    pub network_utilization: f64,
    pub load_average: (f64, f64, f64),
}

impl Default for SystemLoad {
    fn default() -> Self {
        Self {
            cpu_utilization: 0.0,
            memory_utilization: 0.0,
            io_wait: 0.0,
            network_utilization: 0.0,
            load_average: (0.0, 0.0, 0.0),
        }
    }
}

#[derive(Debug, Clone)]
pub struct TaskExecutionHistory {
    pub task_type: ComponentType,
    pub execution_time: Duration,
    pub resource_usage: ResourceUsage,
    pub success_rate: f64,
    pub timestamp: SystemTime,
}

#[derive(Debug, Clone)]
pub struct ResourceUsage {
    pub peak_cpu: f64,
    pub peak_memory: u64,
    pub io_operations: u64,
    pub network_bytes: u64,
}

#[derive(Debug, Clone)]
pub struct ResourceConstraints {
    pub max_cpu_per_task: f64,
    pub max_memory_per_task: u64,
    pub max_concurrent_io: u32,
    pub network_bandwidth_limit: u64,
}

impl Default for ResourceConstraints {
    fn default() -> Self {
        Self {
            max_cpu_per_task: 1.0,
            max_memory_per_task: 1024,
            max_concurrent_io: 10,
            network_bandwidth_limit: 100_000_000,
        }
    }
}

#[derive(Debug, Clone)]
pub struct TemporalContext {
    pub current_time: SystemTime,
    pub business_hours: Option<BusinessHours>,
    pub maintenance_windows: Vec<MaintenanceWindow>,
    pub peak_periods: Vec<PeakPeriod>,
}

impl Default for TemporalContext {
    fn default() -> Self {
        Self {
            current_time: SystemTime::now(),
            business_hours: None,
            maintenance_windows: Vec::new(),
            peak_periods: Vec::new(),
        }
    }
}

#[derive(Debug, Clone)]
pub struct BusinessHours {
    pub start: (u8, u8),
    pub end: (u8, u8),
    pub business_days: Vec<u8>,
    pub timezone_offset: i8,
}

#[derive(Debug, Clone)]
pub struct MaintenanceWindow {
    pub name: String,
    pub start: SystemTime,
    pub end: SystemTime,
    pub severity: MaintenanceSeverity,
}

#[derive(Debug, Clone)]
pub enum MaintenanceSeverity {
    Normal,
    Critical,
    Emergency,
}

#[derive(Debug, Clone)]
pub struct PeakPeriod {
    pub name: String,
    pub start: (u8, u8),
    pub end: (u8, u8),
    pub peak_factor: f64,
}

#[derive(Debug, Clone)]
pub struct SchedulerMetrics {
    pub tasks_scheduled: u64,
    pub avg_scheduling_latency: Duration,
    pub resource_efficiency: f64,
    pub deadline_miss_rate: f64,
    pub fairness_index: f64,
    pub custom_metrics: HashMap<String, f64>,
}

#[derive(Debug, Clone)]
pub enum AdvancedSchedulingStrategy {
    MLAdaptive {
        model_path: String,
        feature_extractors: Vec<String>,
    },
    GeneticOptimization {
        population_size: usize,
        generations: usize,
        mutation_rate: f64,
    },
    MultiObjective {
        objectives: Vec<SchedulingObjective>,
        weights: Vec<f64>,
    },
    ReinforcementLearning {
        agent_type: String,
        learning_rate: f64,
        exploration_rate: f64,
    },
    GameTheory {
        strategy_type: GameTheoryStrategy,
        coalition_formation: bool,
    },
    QuantumInspired {
        quantum_operators: Vec<String>,
        entanglement_depth: usize,
    },
}

#[derive(Debug, Clone)]
pub enum SchedulingObjective {
    MinimizeMakespan,
    MinimizeResourceUsage,
    MaximizeThroughput,
    MinimizeEnergy,
    MaximizeFairness,
    MinimizeDeadlineViolations,
    Custom {
        name: String,
        objective_fn: fn(&[ScheduledTask], &ResourcePool) -> f64,
    },
}

#[derive(Debug, Clone)]
pub enum GameTheoryStrategy {
    NashEquilibrium,
    Stackelberg,
    Cooperative,
    Auction,
}

pub struct MultiLevelFeedbackScheduler {
    name: String,
    queues: Vec<PriorityQueue>,
    time_quantum: Vec<Duration>,
    promotion_threshold: Vec<u32>,
    demotion_threshold: Vec<u32>,
    aging_factor: f64,
    metrics: SchedulerMetrics,
}

#[derive(Debug)]
struct PriorityQueue {
    tasks: VecDeque<ScheduledTask>,
    priority_level: u8,
    time_slice: Duration,
}

pub struct FairShareScheduler {
    name: String,
    user_shares: HashMap<String, f64>,
    group_shares: HashMap<String, f64>,
    usage_history: HashMap<String, Vec<ResourceUsage>>,
    decay_factor: f64,
    metrics: SchedulerMetrics,
}

pub struct DeadlineAwareScheduler {
    name: String,
    deadline_weight: f64,
    urgency_factor: f64,
    preemption_enabled: bool,
    grace_period: Duration,
    metrics: SchedulerMetrics,
}

pub struct ResourceAwareScheduler {
    name: String,
    resource_weights: HashMap<String, f64>,
    load_balancing_strategy: LoadBalancingStrategy,
    prediction_window: Duration,
    efficiency_threshold: f64,
    metrics: SchedulerMetrics,
}

#[derive(Debug, Clone)]
pub enum LoadBalancingStrategy {
    RoundRobin,
    LeastLoaded,
    WeightedRoundRobin { weights: HashMap<String, f64> },
    Random,
    ConsistentHashing { virtual_nodes: usize },
}

pub struct MLAdaptiveScheduler {
    name: String,
    model_type: MLModelType,
    feature_extractors: Vec<Box<dyn FeatureExtractor>>,
    training_data: Vec<SchedulingDecision>,
    prediction_accuracy: f64,
    retraining_threshold: usize,
    metrics: SchedulerMetrics,
}

#[derive(Debug, Clone)]
pub enum MLModelType {
    DecisionTree,
    RandomForest { n_trees: usize },
    NeuralNetwork { layers: Vec<usize> },
    SVM { kernel: String },
    ReinforcementLearning { algorithm: String },
}

pub trait FeatureExtractor: Send + Sync {
    fn extract_features(&self, context: &SchedulingContext) -> Vec<f64>;

    fn feature_names(&self) -> Vec<String>;
}
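
// Illustrative sketch of a `FeatureExtractor` that flattens the current `SystemLoad` into a
// feature vector. `SystemLoadFeatures` is a hypothetical example type, not part of this
// module's API.
#[allow(dead_code)]
#[derive(Debug)]
struct SystemLoadFeatures;

impl FeatureExtractor for SystemLoadFeatures {
    fn extract_features(&self, context: &SchedulingContext) -> Vec<f64> {
        let load = &context.system_load;
        vec![
            load.cpu_utilization,
            load.memory_utilization,
            load.io_wait,
            load.network_utilization,
            load.load_average.0,
        ]
    }

    fn feature_names(&self) -> Vec<String> {
        [
            "cpu_utilization",
            "memory_utilization",
            "io_wait",
            "network_utilization",
            "load_average_1m",
        ]
        .iter()
        .map(|name| (*name).to_string())
        .collect()
    }
}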

#[derive(Debug, Clone)]
pub struct SchedulingDecision {
    pub features: Vec<f64>,
    pub chosen_task: TaskId,
    pub outcome: DecisionOutcome,
    pub timestamp: SystemTime,
}

#[derive(Debug, Clone)]
pub struct DecisionOutcome {
    pub completion_time: Duration,
    pub resource_utilization: f64,
    pub deadline_met: bool,
    pub satisfaction_score: f64,
}

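/// Central task scheduler: maintains the priority queue, task states, dependency graph, and
/// the background scheduling thread.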
#[derive(Debug)]
pub struct TaskScheduler {
    strategy: SchedulingStrategy,
    pluggable_schedulers: HashMap<String, Box<dyn PluggableScheduler>>,
    active_scheduler: Option<String>,
    task_queue: Arc<Mutex<BinaryHeap<PriorityTask>>>,
    task_states: Arc<RwLock<HashMap<TaskId, TaskState>>>,
    resource_pool: Arc<RwLock<ResourcePool>>,
    dependency_graph: Arc<RwLock<HashMap<TaskId, HashSet<TaskId>>>>,
    config: SchedulerConfig,
    context: Arc<RwLock<SchedulingContext>>,
    task_notification: Arc<Condvar>,
    scheduler_thread: Option<JoinHandle<()>>,
    is_running: Arc<Mutex<bool>>,
}

#[derive(Debug, Clone)]
pub struct SchedulerConfig {
    pub max_concurrent_tasks: usize,
    pub scheduling_interval: Duration,
    pub monitoring_interval: Duration,
    pub default_task_timeout: Duration,
    pub cleanup_interval: Duration,
    pub max_task_history: usize,
}

impl Default for SchedulerConfig {
    fn default() -> Self {
        Self {
            max_concurrent_tasks: 10,
            scheduling_interval: Duration::from_millis(100),
            monitoring_interval: Duration::from_secs(1),
            default_task_timeout: Duration::from_secs(3600),
            cleanup_interval: Duration::from_secs(300),
            max_task_history: 10000,
        }
    }
}

impl TaskScheduler {
    #[must_use]
    pub fn new(strategy: SchedulingStrategy, config: SchedulerConfig) -> Self {
        Self {
            strategy,
            pluggable_schedulers: HashMap::new(),
            active_scheduler: None,
            task_queue: Arc::new(Mutex::new(BinaryHeap::new())),
            task_states: Arc::new(RwLock::new(HashMap::new())),
            resource_pool: Arc::new(RwLock::new(ResourcePool::default())),
            dependency_graph: Arc::new(RwLock::new(HashMap::new())),
            config,
            context: Arc::new(RwLock::new(SchedulingContext::default())),
            task_notification: Arc::new(Condvar::new()),
            scheduler_thread: None,
            is_running: Arc::new(Mutex::new(false)),
        }
    }

    pub fn submit_task(&self, task: ScheduledTask) -> SklResult<()> {
        let task_id = task.id.clone();

        {
            let mut graph = self.dependency_graph.write().unwrap();
            graph.insert(task_id.clone(), task.dependencies.iter().cloned().collect());
        }

        {
            let mut states = self.task_states.write().unwrap();
            states.insert(task_id, TaskState::Pending);
        }

        let priority_score = self.calculate_priority_score(&task);

        {
            let mut queue = self.task_queue.lock().unwrap();
            queue.push(PriorityTask {
                task,
                priority_score,
            });
        }

        self.task_notification.notify_one();

        Ok(())
    }

    fn calculate_priority_score(&self, task: &ScheduledTask) -> i64 {
        let mut score = match task.priority {
            TaskPriority::Low => 1,
            TaskPriority::Normal => 10,
            TaskPriority::High => 100,
            TaskPriority::Critical => 1000,
        };

        // Tighter deadlines contribute a larger urgency bonus.
        if let Some(deadline) = task.deadline {
            let time_to_deadline = deadline
                .duration_since(SystemTime::now())
                .unwrap_or(Duration::ZERO)
                .as_secs() as i64;
            score += 1_000_000 / (time_to_deadline + 1);
        }

        // Aging: add one point per minute of queue wait so old tasks are not starved.
        let age = SystemTime::now()
            .duration_since(task.submitted_at)
            .unwrap_or(Duration::ZERO)
            .as_secs() as i64;
        score += age / 60;

        score
    }

    pub fn start(&mut self) -> SklResult<()> {
        {
            let mut running = self.is_running.lock().unwrap();
            *running = true;
        }

        let task_queue = Arc::clone(&self.task_queue);
        let task_states = Arc::clone(&self.task_states);
        let resource_pool = Arc::clone(&self.resource_pool);
        let dependency_graph = Arc::clone(&self.dependency_graph);
        let task_notification = Arc::clone(&self.task_notification);
        let is_running = Arc::clone(&self.is_running);
        let config = self.config.clone();
        let strategy = self.strategy.clone();

        let handle = thread::spawn(move || {
            Self::scheduler_loop(
                task_queue,
                task_states,
                resource_pool,
                dependency_graph,
                task_notification,
                is_running,
                config,
                strategy,
            );
        });

        self.scheduler_thread = Some(handle);
        Ok(())
    }

    pub fn stop(&mut self) -> SklResult<()> {
        {
            let mut running = self.is_running.lock().unwrap();
            *running = false;
        }

        self.task_notification.notify_all();

        if let Some(handle) = self.scheduler_thread.take() {
            handle.join().map_err(|_| SklearsError::InvalidData {
                reason: "Failed to join scheduler thread".to_string(),
            })?;
        }

        Ok(())
    }

    fn scheduler_loop(
        task_queue: Arc<Mutex<BinaryHeap<PriorityTask>>>,
        task_states: Arc<RwLock<HashMap<TaskId, TaskState>>>,
        resource_pool: Arc<RwLock<ResourcePool>>,
        dependency_graph: Arc<RwLock<HashMap<TaskId, HashSet<TaskId>>>>,
        task_notification: Arc<Condvar>,
        is_running: Arc<Mutex<bool>>,
        config: SchedulerConfig,
        strategy: SchedulingStrategy,
    ) {
        // The queue lock doubles as the condition-variable lock; it is released while waiting.
        // Note: this simplified loop starts tasks in the order `find_ready_tasks` returns them
        // and does not yet consult `strategy`.
        let mut lock = task_queue.lock().unwrap();

        while *is_running.lock().unwrap() {
            let ready_tasks = Self::find_ready_tasks(&task_queue, &task_states, &dependency_graph);

            for task_id in ready_tasks {
                if Self::count_running_tasks(&task_states) >= config.max_concurrent_tasks {
                    break;
                }

                if Self::can_allocate_resources(&task_id, &task_states, &resource_pool) {
                    Self::start_task_execution(&task_id, &task_states, &resource_pool);
                }
            }

            Self::cleanup_tasks(&task_states, &config);

            Self::update_resource_monitoring(&resource_pool);

            // Sleep until new work arrives or the scheduling interval elapses.
            let (guard, _timeout) = task_notification
                .wait_timeout(lock, config.scheduling_interval)
                .unwrap();
            lock = guard;
        }
    }

    fn find_ready_tasks(
        task_queue: &Arc<Mutex<BinaryHeap<PriorityTask>>>,
        task_states: &Arc<RwLock<HashMap<TaskId, TaskState>>>,
        dependency_graph: &Arc<RwLock<HashMap<TaskId, HashSet<TaskId>>>>,
    ) -> Vec<TaskId> {
        let mut ready_tasks = Vec::new();
        let states = task_states.read().unwrap();
        let graph = dependency_graph.read().unwrap();

        for (task_id, state) in states.iter() {
            if *state == TaskState::Pending {
                if let Some(dependencies) = graph.get(task_id) {
                    let all_deps_completed = dependencies.iter().all(|dep_id| {
                        if let Some(dep_state) = states.get(dep_id) {
                            matches!(dep_state, TaskState::Completed { .. })
                        } else {
                            false
                        }
                    });

                    if all_deps_completed {
                        ready_tasks.push(task_id.clone());
                    }
                }
            }
        }

        ready_tasks
    }

    fn count_running_tasks(task_states: &Arc<RwLock<HashMap<TaskId, TaskState>>>) -> usize {
        let states = task_states.read().unwrap();
        states
            .values()
            .filter(|state| matches!(state, TaskState::Running { .. }))
            .count()
    }

    fn can_allocate_resources(
        task_id: &TaskId,
        task_states: &Arc<RwLock<HashMap<TaskId, TaskState>>>,
        resource_pool: &Arc<RwLock<ResourcePool>>,
    ) -> bool {
        let pool = resource_pool.read().unwrap();
        // Simplified check: ignores the task's declared requirements and only verifies that the
        // pool still has at least one CPU and a little memory headroom.
        pool.available_cpu > 0 && pool.available_memory > 100
    }

    fn start_task_execution(
        task_id: &TaskId,
        task_states: &Arc<RwLock<HashMap<TaskId, TaskState>>>,
        resource_pool: &Arc<RwLock<ResourcePool>>,
    ) {
        let mut states = task_states.write().unwrap();
        states.insert(
            task_id.clone(),
            TaskState::Running {
                started_at: SystemTime::now(),
                node_id: Some("local".to_string()),
            },
        );

        // Reserve a fixed placeholder amount of resources (1 CPU core, 100 units of memory).
        let mut pool = resource_pool.write().unwrap();
        pool.available_cpu = pool.available_cpu.saturating_sub(1);
        pool.available_memory = pool.available_memory.saturating_sub(100);
    }

    fn cleanup_tasks(
        task_states: &Arc<RwLock<HashMap<TaskId, TaskState>>>,
        config: &SchedulerConfig,
    ) {
        let mut states = task_states.write().unwrap();

        // Collect finished tasks that are older than the cleanup interval.
        let cutoff_time = SystemTime::now() - config.cleanup_interval;
        let mut to_remove = Vec::new();

        for (task_id, state) in states.iter() {
            let should_remove = match state {
                TaskState::Completed { completed_at, .. } => *completed_at < cutoff_time,
                TaskState::Failed { failed_at, .. } => *failed_at < cutoff_time,
                TaskState::Cancelled { cancelled_at } => *cancelled_at < cutoff_time,
                _ => false,
            };

            if should_remove {
                to_remove.push(task_id.clone());
            }
        }

        // If the history exceeds the limit, evict expired entries immediately; the remainder of
        // the removal list is cleared below.
        if states.len() > config.max_task_history {
            let excess = states.len() - config.max_task_history;
            for _ in 0..excess {
                if let Some(oldest_id) = to_remove.first().cloned() {
                    to_remove.remove(0);
                    states.remove(&oldest_id);
                }
            }
        }

        for task_id in to_remove {
            states.remove(&task_id);
        }
    }

    fn update_resource_monitoring(resource_pool: &Arc<RwLock<ResourcePool>>) {
        let mut pool = resource_pool.write().unwrap();

        // Utilization is derived from the default pool capacities (4 CPUs, 8192 memory units);
        // disk usage is a fixed placeholder value.
        let utilization = ResourceUtilization {
            timestamp: SystemTime::now(),
            cpu_usage: 1.0 - (f64::from(pool.available_cpu) / 4.0),
            memory_usage: 1.0 - (pool.available_memory as f64 / 8192.0),
            disk_usage: 0.5,
            gpu_usage: 0.0,
        };

        pool.utilization_history.push(utilization);

        // Keep only the most recent 100 samples.
        if pool.utilization_history.len() > 100 {
            pool.utilization_history.remove(0);
        }
    }

    #[must_use]
    pub fn get_task_state(&self, task_id: &TaskId) -> Option<TaskState> {
        let states = self.task_states.read().unwrap();
        states.get(task_id).cloned()
    }

    #[must_use]
    pub fn get_statistics(&self) -> SchedulerStatistics {
        let states = self.task_states.read().unwrap();
        let queue = self.task_queue.lock().unwrap();
        let pool = self.resource_pool.read().unwrap();

        let pending_count = states
            .values()
            .filter(|s| matches!(s, TaskState::Pending))
            .count();
        let running_count = states
            .values()
            .filter(|s| matches!(s, TaskState::Running { .. }))
            .count();
        let completed_count = states
            .values()
            .filter(|s| matches!(s, TaskState::Completed { .. }))
            .count();
        let failed_count = states
            .values()
            .filter(|s| matches!(s, TaskState::Failed { .. }))
            .count();

        SchedulerStatistics {
            total_tasks: states.len(),
            pending_tasks: pending_count,
            running_tasks: running_count,
            completed_tasks: completed_count,
            failed_tasks: failed_count,
            queued_tasks: queue.len(),
            resource_utilization: pool.utilization_history.last().cloned(),
        }
    }

    pub fn cancel_task(&self, task_id: &TaskId) -> SklResult<()> {
        let mut states = self.task_states.write().unwrap();

        if let Some(current_state) = states.get(task_id) {
            match current_state {
                TaskState::Pending | TaskState::Ready => {
                    states.insert(
                        task_id.clone(),
                        TaskState::Cancelled {
                            cancelled_at: SystemTime::now(),
                        },
                    );
                    Ok(())
                }
                TaskState::Running { .. } => {
                    // Mark the task cancelled; interrupting the in-flight execution itself is
                    // outside the scope of this scheduler.
                    states.insert(
                        task_id.clone(),
                        TaskState::Cancelled {
                            cancelled_at: SystemTime::now(),
                        },
                    );
                    Ok(())
                }
                _ => Err(SklearsError::InvalidInput(format!(
                    "Cannot cancel task {task_id} in state {current_state:?}"
                ))),
            }
        } else {
            Err(SklearsError::InvalidInput(format!(
                "Task {task_id} not found"
            )))
        }
    }

    #[must_use]
    pub fn list_tasks(&self) -> HashMap<TaskId, TaskState> {
        let states = self.task_states.read().unwrap();
        states.clone()
    }

    #[must_use]
    pub fn get_resource_utilization(&self) -> ResourceUtilization {
        let pool = self.resource_pool.read().unwrap();
        pool.utilization_history
            .last()
            .cloned()
            .unwrap_or_else(|| ResourceUtilization {
                timestamp: SystemTime::now(),
                cpu_usage: 0.0,
                memory_usage: 0.0,
                disk_usage: 0.0,
                gpu_usage: 0.0,
            })
    }
}

#[derive(Debug, Clone)]
pub struct SchedulerStatistics {
    pub total_tasks: usize,
    pub pending_tasks: usize,
    pub running_tasks: usize,
    pub completed_tasks: usize,
    pub failed_tasks: usize,
    pub queued_tasks: usize,
    pub resource_utilization: Option<ResourceUtilization>,
}

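/// Coordinates multi-task workflows on top of `TaskScheduler`, tracking workflow definitions
/// and their running instances.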
#[derive(Debug)]
pub struct WorkflowManager {
    scheduler: TaskScheduler,
    workflows: Arc<RwLock<HashMap<String, Workflow>>>,
    workflow_instances: Arc<RwLock<HashMap<String, WorkflowInstance>>>,
}

#[derive(Debug, Clone)]
pub struct Workflow {
    pub id: String,
    pub name: String,
    pub tasks: Vec<WorkflowTask>,
    pub config: WorkflowConfig,
}

#[derive(Debug, Clone)]
pub struct WorkflowTask {
    pub id: String,
    pub template: TaskTemplate,
    pub depends_on: Vec<String>,
    pub config_overrides: HashMap<String, String>,
}

#[derive(Debug, Clone)]
pub struct TaskTemplate {
    pub name: String,
    pub component_type: ComponentType,
    pub default_resources: ResourceRequirements,
    pub default_config: HashMap<String, String>,
}

#[derive(Debug, Clone)]
pub struct WorkflowConfig {
    pub max_parallelism: usize,
    pub timeout: Duration,
    pub failure_strategy: WorkflowFailureStrategy,
    pub retry_config: RetryConfig,
}

#[derive(Debug, Clone)]
pub enum WorkflowFailureStrategy {
    StopOnFailure,
    ContinueOnFailure,
    RetryFailedTasks,
    UseFallbackTasks,
}

#[derive(Debug, Clone)]
pub struct WorkflowInstance {
    pub id: String,
    pub workflow_id: String,
    pub state: WorkflowState,
    pub task_instances: HashMap<String, TaskId>,
    pub started_at: SystemTime,
    pub ended_at: Option<SystemTime>,
    pub context: HashMap<String, String>,
}

#[derive(Debug, Clone, PartialEq)]
pub enum WorkflowState {
    Starting,
    Running,
    Completed,
    Failed { error: String },
    Cancelled,
    Paused,
}

impl WorkflowManager {
    #[must_use]
    pub fn new(scheduler: TaskScheduler) -> Self {
        Self {
            scheduler,
            workflows: Arc::new(RwLock::new(HashMap::new())),
            workflow_instances: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    pub fn register_workflow(&self, workflow: Workflow) -> SklResult<()> {
        let mut workflows = self.workflows.write().unwrap();
        workflows.insert(workflow.id.clone(), workflow);
        Ok(())
    }

    pub fn start_workflow(
        &self,
        workflow_id: &str,
        context: HashMap<String, String>,
    ) -> SklResult<String> {
        let workflows = self.workflows.read().unwrap();
        let workflow = workflows.get(workflow_id).ok_or_else(|| {
            SklearsError::InvalidInput(format!("Workflow {workflow_id} not found"))
        })?;

        let instance_id = format!(
            "{}_{}",
            workflow_id,
            SystemTime::now()
                .duration_since(SystemTime::UNIX_EPOCH)
                .unwrap()
                .as_millis()
        );

        let instance = WorkflowInstance {
            id: instance_id.clone(),
            workflow_id: workflow_id.to_string(),
            state: WorkflowState::Starting,
            task_instances: HashMap::new(),
            started_at: SystemTime::now(),
            ended_at: None,
            context,
        };

        {
            let mut instances = self.workflow_instances.write().unwrap();
            instances.insert(instance_id.clone(), instance);
        }

        self.submit_ready_tasks(&instance_id, workflow)?;

        Ok(instance_id)
    }

    fn submit_ready_tasks(&self, instance_id: &str, workflow: &Workflow) -> SklResult<()> {
        let ready_tasks: Vec<_> = workflow
            .tasks
            .iter()
            .filter(|task| task.depends_on.is_empty())
            .collect();

        for task in ready_tasks {
            let scheduled_task = self.create_scheduled_task(instance_id, task)?;
            self.scheduler.submit_task(scheduled_task)?;
        }

        Ok(())
    }

    fn create_scheduled_task(
        &self,
        instance_id: &str,
        workflow_task: &WorkflowTask,
    ) -> SklResult<ScheduledTask> {
        let task_id = format!("{}_{}", instance_id, workflow_task.id);

        Ok(ScheduledTask {
            id: task_id,
            name: workflow_task.template.name.clone(),
            component_type: workflow_task.template.component_type.clone(),
            dependencies: workflow_task
                .depends_on
                .iter()
                .map(|dep| format!("{instance_id}_{dep}"))
                .collect(),
            resource_requirements: workflow_task.template.default_resources.clone(),
            priority: TaskPriority::Normal,
            estimated_duration: Duration::from_secs(60),
            submitted_at: SystemTime::now(),
            deadline: None,
            metadata: HashMap::new(),
            retry_config: RetryConfig::default(),
        })
    }

    #[must_use]
    pub fn get_workflow_status(&self, instance_id: &str) -> Option<WorkflowInstance> {
        let instances = self.workflow_instances.read().unwrap();
        instances.get(instance_id).cloned()
    }

    pub fn cancel_workflow(&self, instance_id: &str) -> SklResult<()> {
        let mut instances = self.workflow_instances.write().unwrap();

        if let Some(instance) = instances.get_mut(instance_id) {
            instance.state = WorkflowState::Cancelled;
            instance.ended_at = Some(SystemTime::now());

            for task_id in instance.task_instances.values() {
                let _ = self.scheduler.cancel_task(task_id);
            }

            Ok(())
        } else {
            Err(SklearsError::InvalidInput(format!(
                "Workflow instance {instance_id} not found"
            )))
        }
    }

    #[must_use]
    pub fn list_workflow_instances(&self) -> HashMap<String, WorkflowInstance> {
        let instances = self.workflow_instances.read().unwrap();
        instances.clone()
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scheduled_task_creation() {
        let task = ScheduledTask {
            id: "test_task".to_string(),
            name: "Test Task".to_string(),
            component_type: ComponentType::Transformer,
            dependencies: vec!["dep1".to_string()],
            resource_requirements: ResourceRequirements {
                cpu_cores: 1,
                memory_mb: 512,
                disk_mb: 100,
                gpu_required: false,
                estimated_duration: Duration::from_secs(60),
                priority: TaskPriority::Normal,
            },
            priority: TaskPriority::Normal,
            estimated_duration: Duration::from_secs(60),
            submitted_at: SystemTime::now(),
            deadline: None,
            metadata: HashMap::new(),
            retry_config: RetryConfig::default(),
        };

        assert_eq!(task.id, "test_task");
        assert_eq!(task.dependencies.len(), 1);
        assert_eq!(task.priority, TaskPriority::Normal);
    }

    #[test]
    fn test_task_scheduler_creation() {
        let config = SchedulerConfig::default();
        let scheduler = TaskScheduler::new(SchedulingStrategy::FIFO, config);

        let stats = scheduler.get_statistics();
        assert_eq!(stats.total_tasks, 0);
        assert_eq!(stats.pending_tasks, 0);
    }

    #[test]
    fn test_priority_task_ordering() {
        let task1 = PriorityTask {
            task: ScheduledTask {
                id: "task1".to_string(),
                name: "Task 1".to_string(),
                component_type: ComponentType::Transformer,
                dependencies: Vec::new(),
                resource_requirements: ResourceRequirements {
                    cpu_cores: 1,
                    memory_mb: 512,
                    disk_mb: 100,
                    gpu_required: false,
                    estimated_duration: Duration::from_secs(60),
                    priority: TaskPriority::Normal,
                },
                priority: TaskPriority::Normal,
                estimated_duration: Duration::from_secs(60),
                submitted_at: SystemTime::now(),
                deadline: None,
                metadata: HashMap::new(),
                retry_config: RetryConfig::default(),
            },
            priority_score: 10,
        };

        let task2 = PriorityTask {
            task: ScheduledTask {
                id: "task2".to_string(),
                name: "Task 2".to_string(),
                component_type: ComponentType::Transformer,
                dependencies: Vec::new(),
                resource_requirements: ResourceRequirements {
                    cpu_cores: 1,
                    memory_mb: 512,
                    disk_mb: 100,
                    gpu_required: false,
                    estimated_duration: Duration::from_secs(60),
                    priority: TaskPriority::High,
                },
                priority: TaskPriority::High,
                estimated_duration: Duration::from_secs(60),
                submitted_at: SystemTime::now(),
                deadline: None,
                metadata: HashMap::new(),
                retry_config: RetryConfig::default(),
            },
            priority_score: 100,
        };

        // The higher-scoring task compares as greater, so it is scheduled first.
        assert!(task2 > task1);
    }

    #[test]
    fn test_workflow_creation() {
        let workflow = Workflow {
            id: "test_workflow".to_string(),
            name: "Test Workflow".to_string(),
            tasks: vec![WorkflowTask {
                id: "task1".to_string(),
                template: TaskTemplate {
                    name: "Task 1".to_string(),
                    component_type: ComponentType::Transformer,
                    default_resources: ResourceRequirements {
                        cpu_cores: 1,
                        memory_mb: 512,
                        disk_mb: 100,
                        gpu_required: false,
                        estimated_duration: Duration::from_secs(60),
                        priority: TaskPriority::Normal,
                    },
                    default_config: HashMap::new(),
                },
                depends_on: Vec::new(),
                config_overrides: HashMap::new(),
            }],
            config: WorkflowConfig {
                max_parallelism: 5,
                timeout: Duration::from_secs(3600),
                failure_strategy: WorkflowFailureStrategy::StopOnFailure,
                retry_config: RetryConfig::default(),
            },
        };

        assert_eq!(workflow.id, "test_workflow");
        assert_eq!(workflow.tasks.len(), 1);
        assert_eq!(workflow.config.max_parallelism, 5);
    }

    #[test]
    fn test_resource_utilization() {
        let utilization = ResourceUtilization {
            timestamp: SystemTime::now(),
            cpu_usage: 0.5,
            memory_usage: 0.7,
            disk_usage: 0.3,
            gpu_usage: 0.0,
        };

        assert_eq!(utilization.cpu_usage, 0.5);
        assert_eq!(utilization.memory_usage, 0.7);
    }
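
    // Extra illustrative test: submitting a task should make it visible as `Pending` in the
    // scheduler's bookkeeping, even before the background scheduling thread is started.
    #[test]
    fn test_submit_task_tracks_pending_state() {
        let scheduler =
            TaskScheduler::new(SchedulingStrategy::Priority, SchedulerConfig::default());

        let task = ScheduledTask {
            id: "pending_task".to_string(),
            name: "Pending Task".to_string(),
            component_type: ComponentType::DataProcessor,
            dependencies: Vec::new(),
            resource_requirements: ResourceRequirements {
                cpu_cores: 1,
                memory_mb: 256,
                disk_mb: 50,
                gpu_required: false,
                estimated_duration: Duration::from_secs(30),
                priority: TaskPriority::Normal,
            },
            priority: TaskPriority::Normal,
            estimated_duration: Duration::from_secs(30),
            submitted_at: SystemTime::now(),
            deadline: None,
            metadata: HashMap::new(),
            retry_config: RetryConfig::default(),
        };

        scheduler.submit_task(task).unwrap();

        assert_eq!(
            scheduler.get_task_state(&"pending_task".to_string()),
            Some(TaskState::Pending)
        );
        let stats = scheduler.get_statistics();
        assert_eq!(stats.total_tasks, 1);
        assert_eq!(stats.queued_tasks, 1);
    }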
}