1use std::collections::{HashMap, VecDeque};
7use std::time::{Duration, SystemTime};
8
9use sklears_core::error::{Result as SklResult, SklearsError};
10
11use super::tasks::{ExecutionTask, TaskPriority};
12
/// Common interface for the task schedulers in this module.
///
/// A scheduler accepts [`ExecutionTask`]s, holds them in one or more queues, hands them out via
/// [`TaskScheduler::get_next_task`], and tracks dispatched tasks until they are reported through
/// [`TaskScheduler::mark_completed`] or [`TaskScheduler::mark_failed`].
pub trait TaskScheduler: Send + Sync {
    /// Enqueue a single task and return a handle describing it.
    ///
    /// # Errors
    /// Implementation-defined (e.g. the task cannot be accepted because a queue is full).
    fn schedule_task(&mut self, task: ExecutionTask) -> SklResult<TaskHandle>;

    /// Enqueue several tasks, returning one handle per task in input order.
    fn schedule_batch(&mut self, tasks: Vec<ExecutionTask>) -> SklResult<Vec<TaskHandle>>;

    /// Withdraw the task identified by `handle`, whether it is still queued or already running.
    fn cancel_task(&mut self, handle: TaskHandle) -> SklResult<()>;

    /// Snapshot of queued/running/completed/failed counts plus an overall health verdict.
    fn get_status(&self) -> SchedulerStatus;

    /// Replace the scheduler configuration.
    fn update_config(&mut self, config: SchedulerConfig) -> SklResult<()>;

    /// Dispatch the next task, or return `None` when nothing is ready to run.
    fn get_next_task(&mut self) -> Option<ExecutionTask>;

    /// Record that the dispatched task `task_id` finished successfully.
    fn mark_completed(&mut self, task_id: &str) -> SklResult<()>;

    /// Record that the dispatched task `task_id` failed with `error`.
    fn mark_failed(&mut self, task_id: &str, error: String) -> SklResult<()>;

    /// Snapshot of the scheduler's metrics.
    fn get_metrics(&self) -> SchedulingMetrics;
}
42
/// Receipt returned when a task is scheduled; pass it back to `cancel_task` to withdraw the task.
#[derive(Debug, Clone)]
pub struct TaskHandle {
    /// ID of the scheduled task (copied from `ExecutionTask::id`).
    pub task_id: String,
    /// Wall-clock time at which the task was enqueued.
    pub scheduled_at: SystemTime,
    /// Estimated run time copied from the task metadata, if one was given.
    pub estimated_duration: Option<Duration>,
    /// Priority copied from the task metadata.
    pub priority: TaskPriority,
    /// IDs of the tasks this task depends on (copied from the task metadata).
    pub dependencies: Vec<String>,
}
57
/// Complete configuration for a scheduler.
#[derive(Debug, Clone)]
pub struct SchedulerConfig {
    /// Ordering policy applied to queued tasks.
    pub algorithm: SchedulingAlgorithm,
    /// Queue capacity, overflow behaviour and persistence settings.
    pub queue_management: QueueManagement,
    /// Priority levels and anti-starvation settings.
    pub priority_handling: PriorityHandling,
    /// Dependency-tracking settings.
    pub dependency_resolution: DependencyResolution,
}
70
/// Ordering policy for queued tasks.
///
/// `DefaultTaskScheduler` re-sorts its queue for `Priority`, `ShortestJobFirst` and
/// `EarliestDeadlineFirst`; the remaining variants leave the queue in insertion order.
#[derive(Debug, Clone, PartialEq)]
pub enum SchedulingAlgorithm {
    /// First in, first out (insertion order).
    FIFO,
    /// Greatest `TaskPriority` (per its `Ord`) first; equal priorities keep submission order.
    Priority,
    /// Smallest `estimated_duration` first; tasks without an estimate go last.
    ShortestJobFirst,
    /// Round-robin time slicing.
    RoundRobin {
        /// Time slice per task. NOTE(review): not read anywhere in this module.
        quantum: Duration,
    },
    /// Presumably "Completely Fair Scheduler"-style scheduling (TODO confirm); no dedicated
    /// ordering is implemented in this module.
    CFS,
    /// Multi-level queue scheduling; no dedicated ordering is implemented in this module.
    MultilevelQueue,
    /// Weighted fair queuing; no dedicated ordering is implemented in this module.
    WeightedFairQueuing,
    /// Earliest `metadata.deadline` first; tasks without a deadline go last.
    EarliestDeadlineFirst,
}
89
/// Queue capacity and overflow settings.
#[derive(Debug, Clone)]
pub struct QueueManagement {
    /// Capacity of the pending-task queue; `get_status` reports `Overloaded` above half of it.
    pub max_queue_size: usize,
    /// What `DefaultTaskScheduler::schedule_task` does when the queue is at capacity.
    pub overflow_strategy: QueueOverflowStrategy,
    /// Where queued tasks would be persisted. NOTE(review): not read by the schedulers in this
    /// module, which keep their queues in memory.
    pub persistence: QueuePersistence,
}
100
/// Behaviour when a task is submitted to a full queue.
#[derive(Debug, Clone, PartialEq)]
pub enum QueueOverflowStrategy {
    /// Intended to wait for space; a synchronous `schedule_task` cannot wait, so
    /// `DefaultTaskScheduler` returns an error instead.
    Block,
    /// Evict the oldest already-queued task to make room for the new one.
    DropOldest,
    /// Evict the most recently queued task (not the incoming one) to make room.
    DropNewest,
    /// Evict the queued task with the smallest `TaskPriority` to make room.
    DropLowestPriority,
    /// Refuse the new task with an error.
    Reject,
}
115
/// Storage backend for queued tasks. NOTE(review): not consulted by the schedulers in this module.
#[derive(Debug, Clone, PartialEq)]
pub enum QueuePersistence {
    /// Keep queued tasks in process memory only.
    Memory,
    /// Persist queued tasks under the given filesystem path.
    Disk { path: String },
    /// Persist queued tasks in the database reachable via `connection_string`.
    Database { connection_string: String },
}
126
/// Priority-level configuration.
#[derive(Debug, Clone)]
pub struct PriorityHandling {
    /// Priority levels for which `PriorityScheduler` creates queues; tasks with any other
    /// priority are not queued by it.
    pub levels: Vec<TaskPriority>,
    /// How waiting tasks age. NOTE(review): not read by the schedulers in this module.
    pub aging_strategy: AgingStrategy,
    /// Whether starvation prevention is requested. NOTE(review): not read in this module.
    pub starvation_prevention: bool,
}
137
/// How a waiting task's effective priority would be raised over time.
///
/// NOTE(review): no scheduler in this module applies aging; these values are configuration only.
#[derive(Debug, Clone)]
pub enum AgingStrategy {
    /// No aging.
    None,
    /// Presumably one priority step per `increment_interval` of waiting — TODO confirm.
    Linear { increment_interval: Duration },
    /// Presumably intervals growing geometrically from `base_interval` by `multiplier` —
    /// TODO confirm.
    Exponential {
        base_interval: Duration,
        multiplier: f64,
    },
    /// Aging delegated to a function identified by name; resolution happens outside this module.
    Custom { function_name: String },
}
153
/// Dependency-tracking configuration.
#[derive(Debug, Clone)]
pub struct DependencyResolution {
    /// When true, `DefaultTaskScheduler` holds back tasks whose dependencies are still queued
    /// or running.
    pub enable_tracking: bool,
    /// Whether dependency cycles should be detected. NOTE(review): not read in this module.
    pub cycle_detection: bool,
    /// Whether deadlocks should be prevented. NOTE(review): not read in this module.
    pub deadlock_prevention: bool,
    /// Time limit for resolving dependencies. NOTE(review): not read in this module.
    pub resolution_timeout: Duration,
}
166
/// Point-in-time summary returned by `TaskScheduler::get_status`.
#[derive(Debug, Clone)]
pub struct SchedulerStatus {
    /// Tasks waiting to be dispatched.
    pub queued_tasks: usize,
    /// Tasks dispatched but not yet marked completed or failed.
    pub running_tasks: usize,
    /// Tasks marked completed since the scheduler was created.
    pub completed_tasks: u64,
    /// Tasks marked failed since the scheduler was created.
    pub failed_tasks: u64,
    /// Overall health verdict.
    pub health: SchedulerHealth,
}
181
/// Health verdict reported in [`SchedulerStatus`].
///
/// The schedulers in this module only ever report `Healthy` or `Overloaded`.
#[derive(Debug, Clone, PartialEq)]
pub enum SchedulerHealth {
    /// Operating normally.
    Healthy,
    /// More than half of `max_queue_size` tasks are queued.
    Overloaded,
    /// Operating with reduced capability, with a human-readable `reason`.
    Degraded { reason: String },
    /// Not operating, with a human-readable `reason`.
    Down { reason: String },
}
194
/// Aggregate scheduling statistics returned by `TaskScheduler::get_metrics`.
#[derive(Debug, Clone)]
pub struct SchedulingMetrics {
    /// Number of tasks accepted by `schedule_task`.
    pub tasks_scheduled: u64,
    /// Average time spent scheduling a task. NOTE(review): never updated in this module.
    pub avg_scheduling_time: Duration,
    /// Number of queued tasks at the last refresh.
    pub queue_length: usize,
    /// Success ratio `completed / (completed + failed)`; defaults to 1.0.
    pub efficiency: f64,
    /// Number of scheduled tasks per priority level.
    pub tasks_by_priority: HashMap<TaskPriority, u64>,
    /// Average time tasks wait before dispatch. NOTE(review): never updated in this module.
    pub avg_wait_time: Duration,
    /// Tasks per second since the scheduler was created; stays at 0.0 until at least one
    /// whole second has elapsed.
    pub throughput: f64,
    /// When these metrics were last refreshed.
    pub last_updated: SystemTime,
}
215
216impl Default for SchedulingMetrics {
217 fn default() -> Self {
218 Self {
219 tasks_scheduled: 0,
220 avg_scheduling_time: Duration::ZERO,
221 queue_length: 0,
222 efficiency: 1.0,
223 tasks_by_priority: HashMap::new(),
224 avg_wait_time: Duration::ZERO,
225 throughput: 0.0,
226 last_updated: SystemTime::now(),
227 }
228 }
229}
230
/// Single-queue scheduler whose queue is re-ordered according to `config.algorithm`.
pub struct DefaultTaskScheduler {
    /// Active configuration.
    config: SchedulerConfig,
    /// Pending tasks with their handles; the front is dispatched first (subject to
    /// dependency checks).
    queue: VecDeque<(ExecutionTask, TaskHandle)>,
    /// Dispatched tasks keyed by task ID, awaiting `mark_completed` / `mark_failed`.
    running: HashMap<String, (ExecutionTask, TaskHandle)>,
    /// Number of tasks marked completed.
    completed: u64,
    /// Number of tasks marked failed.
    failed: u64,
    /// Metrics, refreshed after each state change.
    metrics: SchedulingMetrics,
    /// Creation time; the reference point for throughput.
    start_time: SystemTime,
}
248
249impl DefaultTaskScheduler {
250 #[must_use]
252 pub fn new(config: SchedulerConfig) -> Self {
253 Self {
254 config,
255 queue: VecDeque::new(),
256 running: HashMap::new(),
257 completed: 0,
258 failed: 0,
259 metrics: SchedulingMetrics::default(),
260 start_time: SystemTime::now(),
261 }
262 }
263
264 fn sort_queue(&mut self) {
266 match self.config.algorithm {
267 SchedulingAlgorithm::Priority => {
268 let mut tasks: Vec<_> = self.queue.drain(..).collect();
269 tasks.sort_by(|(_, handle_a), (_, handle_b)| {
270 handle_b.priority.cmp(&handle_a.priority)
271 });
272 self.queue.extend(tasks);
273 }
274 SchedulingAlgorithm::ShortestJobFirst => {
275 let mut tasks: Vec<_> = self.queue.drain(..).collect();
276 tasks.sort_by(|(_, handle_a), (_, handle_b)| {
277 match (handle_a.estimated_duration, handle_b.estimated_duration) {
278 (Some(a), Some(b)) => a.cmp(&b),
279 (Some(_), None) => std::cmp::Ordering::Less,
280 (None, Some(_)) => std::cmp::Ordering::Greater,
281 (None, None) => std::cmp::Ordering::Equal,
282 }
283 });
284 self.queue.extend(tasks);
285 }
286 SchedulingAlgorithm::EarliestDeadlineFirst => {
287 let mut tasks: Vec<_> = self.queue.drain(..).collect();
288 tasks.sort_by(|(task_a, _), (task_b, _)| {
289 match (task_a.metadata.deadline, task_b.metadata.deadline) {
290 (Some(a), Some(b)) => a.cmp(&b),
291 (Some(_), None) => std::cmp::Ordering::Less,
292 (None, Some(_)) => std::cmp::Ordering::Greater,
293 (None, None) => std::cmp::Ordering::Equal,
294 }
295 });
296 self.queue.extend(tasks);
297 }
298 _ => {} }
300 }
301
302 fn check_dependencies(&self, task: &ExecutionTask) -> bool {
304 if !self.config.dependency_resolution.enable_tracking {
305 return true;
306 }
307
308 for dependency in &task.metadata.dependencies {
309 if self.running.contains_key(dependency) {
311 return false;
312 }
313 if self.queue.iter().any(|(t, _)| t.id == *dependency) {
314 return false;
315 }
316 }
317 true
318 }
319
320 fn update_metrics(&mut self) {
322 self.metrics.queue_length = self.queue.len();
323 self.metrics.last_updated = SystemTime::now();
324
325 if let Ok(elapsed) = self.start_time.elapsed() {
326 let total_tasks =
327 self.completed + self.failed + self.running.len() as u64 + self.queue.len() as u64;
328 if elapsed.as_secs() > 0 {
329 self.metrics.throughput = total_tasks as f64 / elapsed.as_secs_f64();
330 }
331 }
332
333 let total_processed = self.completed + self.failed;
335 if total_processed > 0 {
336 self.metrics.efficiency = self.completed as f64 / total_processed as f64;
337 }
338 }
339}
340
341impl TaskScheduler for DefaultTaskScheduler {
342 fn schedule_task(&mut self, task: ExecutionTask) -> SklResult<TaskHandle> {
343 if self.queue.len() >= self.config.queue_management.max_queue_size {
345 match self.config.queue_management.overflow_strategy {
346 QueueOverflowStrategy::Block => {
347 return Err(SklearsError::InvalidInput(
348 "Queue is full and blocking new tasks".to_string(),
349 ));
350 }
351 QueueOverflowStrategy::Reject => {
352 return Err(SklearsError::InvalidInput(
353 "Queue is full, rejecting new task".to_string(),
354 ));
355 }
356 QueueOverflowStrategy::DropOldest => {
357 self.queue.pop_front();
358 }
359 QueueOverflowStrategy::DropNewest => {
360 self.queue.pop_back();
361 }
362 QueueOverflowStrategy::DropLowestPriority => {
363 if let Some(min_idx) = self
365 .queue
366 .iter()
367 .enumerate()
368 .min_by_key(|(_, (_, handle))| &handle.priority)
369 .map(|(idx, _)| idx)
370 {
371 self.queue.remove(min_idx);
372 }
373 }
374 }
375 }
376
377 let handle = TaskHandle {
378 task_id: task.id.clone(),
379 scheduled_at: SystemTime::now(),
380 estimated_duration: task.metadata.estimated_duration,
381 priority: task.metadata.priority.clone(),
382 dependencies: task.metadata.dependencies.clone(),
383 };
384
385 self.queue.push_back((task, handle.clone()));
386 self.sort_queue();
387
388 self.metrics.tasks_scheduled += 1;
389 *self
390 .metrics
391 .tasks_by_priority
392 .entry(handle.priority.clone())
393 .or_insert(0) += 1;
394
395 self.update_metrics();
396
397 Ok(handle)
398 }
399
400 fn schedule_batch(&mut self, tasks: Vec<ExecutionTask>) -> SklResult<Vec<TaskHandle>> {
401 let mut handles = Vec::new();
402 for task in tasks {
403 let handle = self.schedule_task(task)?;
404 handles.push(handle);
405 }
406 Ok(handles)
407 }
408
409 fn cancel_task(&mut self, handle: TaskHandle) -> SklResult<()> {
410 self.queue.retain(|(_, h)| h.task_id != handle.task_id);
412
413 self.running.remove(&handle.task_id);
415
416 self.update_metrics();
417 Ok(())
418 }
419
420 fn get_status(&self) -> SchedulerStatus {
421 SchedulerStatus {
422 queued_tasks: self.queue.len(),
423 running_tasks: self.running.len(),
424 completed_tasks: self.completed,
425 failed_tasks: self.failed,
426 health: if self.queue.len() > self.config.queue_management.max_queue_size / 2 {
427 SchedulerHealth::Overloaded
428 } else {
429 SchedulerHealth::Healthy
430 },
431 }
432 }
433
434 fn update_config(&mut self, config: SchedulerConfig) -> SklResult<()> {
435 self.config = config;
436 self.sort_queue(); Ok(())
438 }
439
440 fn get_next_task(&mut self) -> Option<ExecutionTask> {
441 let mut task_index = None;
443 for (idx, (task, _)) in self.queue.iter().enumerate() {
444 if self.check_dependencies(task) {
445 task_index = Some(idx);
446 break;
447 }
448 }
449
450 if let Some(idx) = task_index {
451 if let Some((task, handle)) = self.queue.remove(idx) {
452 let task_id = task.id.clone();
453 self.running.insert(task_id.clone(), (task, handle));
454 self.update_metrics();
455 return self.running.get(&task_id).map(|(t, _)| {
457 ExecutionTask {
459 id: t.id.clone(),
460 task_type: t.task_type.clone(),
461 metadata: t.metadata.clone(),
462 requirements: t.requirements.clone(),
463 input_data: None, configuration: t.configuration.clone(),
465 }
466 });
467 }
468 }
469
470 None
471 }
472
473 fn mark_completed(&mut self, task_id: &str) -> SklResult<()> {
474 if self.running.remove(task_id).is_some() {
475 self.completed += 1;
476 self.update_metrics();
477 }
478 Ok(())
479 }
480
481 fn mark_failed(&mut self, task_id: &str, _error: String) -> SklResult<()> {
482 if self.running.remove(task_id).is_some() {
483 self.failed += 1;
484 self.update_metrics();
485 }
486 Ok(())
487 }
488
489 fn get_metrics(&self) -> SchedulingMetrics {
490 self.metrics.clone()
491 }
492}
493
494impl Default for SchedulerConfig {
495 fn default() -> Self {
496 Self {
497 algorithm: SchedulingAlgorithm::Priority,
498 queue_management: QueueManagement {
499 max_queue_size: 1000,
500 overflow_strategy: QueueOverflowStrategy::Block,
501 persistence: QueuePersistence::Memory,
502 },
503 priority_handling: PriorityHandling {
504 levels: vec![
505 TaskPriority::Low,
506 TaskPriority::Normal,
507 TaskPriority::High,
508 TaskPriority::Critical,
509 ],
510 aging_strategy: AgingStrategy::Linear {
511 increment_interval: Duration::from_secs(60),
512 },
513 starvation_prevention: true,
514 },
515 dependency_resolution: DependencyResolution {
516 enable_tracking: true,
517 cycle_detection: true,
518 deadlock_prevention: true,
519 resolution_timeout: Duration::from_secs(30),
520 },
521 }
522 }
523}
524
/// Scheduler that keeps one FIFO queue per priority level.
pub struct PriorityScheduler {
    /// Active configuration.
    config: SchedulerConfig,
    /// One queue per priority level, created from `config.priority_handling.levels`.
    queues: HashMap<TaskPriority, VecDeque<(ExecutionTask, TaskHandle)>>,
    /// Dispatched tasks keyed by task ID, awaiting `mark_completed` / `mark_failed`.
    running: HashMap<String, (ExecutionTask, TaskHandle)>,
    /// Number of tasks marked completed.
    completed: u64,
    /// Number of tasks marked failed.
    failed: u64,
    /// Accumulated metrics.
    metrics: SchedulingMetrics,
    /// Creation time of the scheduler.
    start_time: SystemTime,
}
535
536impl PriorityScheduler {
537 #[must_use]
539 pub fn new(config: SchedulerConfig) -> Self {
540 let mut queues = HashMap::new();
541 for priority in &config.priority_handling.levels {
542 queues.insert(priority.clone(), VecDeque::new());
543 }
544
545 Self {
546 config,
547 queues,
548 running: HashMap::new(),
549 completed: 0,
550 failed: 0,
551 metrics: SchedulingMetrics::default(),
552 start_time: SystemTime::now(),
553 }
554 }
555}
556
557impl TaskScheduler for PriorityScheduler {
558 fn schedule_task(&mut self, task: ExecutionTask) -> SklResult<TaskHandle> {
559 let priority = task.metadata.priority.clone();
560 let handle = TaskHandle {
561 task_id: task.id.clone(),
562 scheduled_at: SystemTime::now(),
563 estimated_duration: task.metadata.estimated_duration,
564 priority: priority.clone(),
565 dependencies: task.metadata.dependencies.clone(),
566 };
567
568 if let Some(queue) = self.queues.get_mut(&priority) {
569 queue.push_back((task, handle.clone()));
570 self.metrics.tasks_scheduled += 1;
571 *self.metrics.tasks_by_priority.entry(priority).or_insert(0) += 1;
572 }
573
574 Ok(handle)
575 }
576
577 fn schedule_batch(&mut self, tasks: Vec<ExecutionTask>) -> SklResult<Vec<TaskHandle>> {
578 let mut handles = Vec::new();
579 for task in tasks {
580 let handle = self.schedule_task(task)?;
581 handles.push(handle);
582 }
583 Ok(handles)
584 }
585
586 fn cancel_task(&mut self, handle: TaskHandle) -> SklResult<()> {
587 if let Some(queue) = self.queues.get_mut(&handle.priority) {
589 queue.retain(|(_, h)| h.task_id != handle.task_id);
590 }
591
592 self.running.remove(&handle.task_id);
594
595 Ok(())
596 }
597
598 fn get_status(&self) -> SchedulerStatus {
599 let total_queued: usize = self
600 .queues
601 .values()
602 .map(std::collections::VecDeque::len)
603 .sum();
604
605 SchedulerStatus {
606 queued_tasks: total_queued,
607 running_tasks: self.running.len(),
608 completed_tasks: self.completed,
609 failed_tasks: self.failed,
610 health: if total_queued > self.config.queue_management.max_queue_size / 2 {
611 SchedulerHealth::Overloaded
612 } else {
613 SchedulerHealth::Healthy
614 },
615 }
616 }
617
618 fn update_config(&mut self, config: SchedulerConfig) -> SklResult<()> {
619 self.config = config;
620 Ok(())
621 }
622
623 fn get_next_task(&mut self) -> Option<ExecutionTask> {
624 for priority in &self.config.priority_handling.levels {
626 if let Some(queue) = self.queues.get_mut(priority) {
627 if let Some((task, handle)) = queue.pop_front() {
628 let task_id = task.id.clone();
629 let result_task = ExecutionTask {
630 id: task.id.clone(),
631 task_type: task.task_type.clone(),
632 metadata: task.metadata.clone(),
633 requirements: task.requirements.clone(),
634 input_data: None, configuration: task.configuration.clone(),
636 };
637 self.running.insert(task_id, (task, handle));
638 return Some(result_task);
639 }
640 }
641 }
642 None
643 }
644
645 fn mark_completed(&mut self, task_id: &str) -> SklResult<()> {
646 if self.running.remove(task_id).is_some() {
647 self.completed += 1;
648 }
649 Ok(())
650 }
651
652 fn mark_failed(&mut self, task_id: &str, _error: String) -> SklResult<()> {
653 if self.running.remove(task_id).is_some() {
654 self.failed += 1;
655 }
656 Ok(())
657 }
658
659 fn get_metrics(&self) -> SchedulingMetrics {
660 self.metrics.clone()
661 }
662}
663
// NOTE(review): no item in this module appears to need `non_snake_case`; consider removing.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::tasks::*;

    /// Assert that `status` describes a scheduler with no work and a healthy verdict.
    fn assert_idle(status: &SchedulerStatus) {
        assert_eq!(status.queued_tasks, 0);
        assert_eq!(status.running_tasks, 0);
        assert_eq!(status.completed_tasks, 0);
        assert_eq!(status.failed_tasks, 0);
        assert_eq!(status.health, SchedulerHealth::Healthy);
    }

    #[test]
    fn test_default_scheduler_creation() {
        let scheduler = DefaultTaskScheduler::new(SchedulerConfig::default());
        assert_idle(&scheduler.get_status());
    }

    #[test]
    fn test_priority_scheduler_creation() {
        let scheduler = PriorityScheduler::new(SchedulerConfig::default());
        assert_idle(&scheduler.get_status());
    }

    #[test]
    fn test_task_scheduling() {
        let mut scheduler = DefaultTaskScheduler::new(SchedulerConfig::default());

        let handle = scheduler.schedule_task(create_test_task()).unwrap();
        assert!(!handle.task_id.is_empty());
        assert_eq!(scheduler.get_status().queued_tasks, 1);
    }

    /// Build a small, dependency-free computation task with `Normal` priority.
    fn create_test_task() -> ExecutionTask {
        let metadata = TaskMetadata {
            name: "Test Task".to_string(),
            description: "A test task".to_string(),
            priority: TaskPriority::Normal,
            estimated_duration: Some(Duration::from_secs(10)),
            deadline: None,
            dependencies: Vec::new(),
            tags: Vec::new(),
            created_at: SystemTime::now(),
        };
        let requirements = ResourceRequirements {
            cpu_cores: 1.0,
            memory_bytes: 1024 * 1024,
            disk_bytes: 0,
            network_bandwidth: 0,
            gpu_memory_bytes: 0,
            special_resources: Vec::new(),
        };

        ExecutionTask {
            id: "test_task_1".to_string(),
            task_type: TaskType::Computation,
            metadata,
            requirements,
            input_data: None,
            configuration: TaskConfiguration::default(),
        }
    }
}