1use super::{Task, TaskResult, TaskStatus};
3use crate::Result;
4use crate::config::{Config, OverflowStrategy};
5use crate::core::parallel::config::ParallelConfig;
6use crate::error::SubXError;
7use std::collections::VecDeque;
8use std::sync::{Arc, Mutex};
9use tokio::sync::{Semaphore, oneshot};
10
/// A task waiting in the queue, bundled with the one-shot channel used to
/// deliver its result back to the submitter.
struct PendingTask {
    // The task to execute; boxed trait object so heterogeneous tasks queue together.
    task: Box<dyn Task + Send + Sync>,
    // Completion channel; send may fail if the submitter has gone away.
    result_sender: oneshot::Sender<TaskResult>,
    // Id reported by `Task::task_id`, used to key `active_tasks`.
    task_id: String,
    // Scheduling priority; drives queue ordering when priorities are enabled.
    priority: TaskPriority,
}
17
/// RAII guard that removes a task's entry from the shared `active_tasks` map
/// when dropped, so bookkeeping is cleaned up on every exit path of the
/// submitting future (normal return, error return, or cancellation).
struct ActiveTaskGuard {
    active_tasks: Arc<Mutex<std::collections::HashMap<String, TaskInfo>>>,
    task_id: String,
}

impl Drop for ActiveTaskGuard {
    fn drop(&mut self) {
        // Best-effort cleanup: a poisoned mutex is silently ignored so that
        // dropping the guard never panics.
        if let Ok(mut active) = self.active_tasks.lock() {
            active.remove(&self.task_id);
        }
    }
}
35
36impl PartialEq for PendingTask {
37 fn eq(&self, other: &Self) -> bool {
38 self.priority == other.priority
39 }
40}
41
42impl Eq for PendingTask {}
43
44impl PartialOrd for PendingTask {
45 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
46 Some(self.cmp(other))
47 }
48}
49
50impl Ord for PendingTask {
51 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
52 self.priority.cmp(&other.priority)
53 }
54}
55
/// Scheduling priority of a task; a higher numeric value is dequeued first
/// when task priorities are enabled in the configuration.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum TaskPriority {
    /// Background work, served last.
    Low = 0,
    /// Default priority for regular submissions.
    Normal = 1,
    /// Preferred over normal work.
    High = 2,
    /// Highest priority, served first.
    Critical = 3,
}
71
/// Bookkeeping record for a submitted task, kept in the scheduler's
/// `active_tasks` map from submission until the result is consumed.
#[derive(Debug, Clone)]
pub struct TaskInfo {
    /// Identifier reported by `Task::task_id`.
    pub task_id: String,
    /// Type label reported by `Task::task_type`.
    pub task_type: String,
    /// Current lifecycle state (Pending -> Running -> Completed).
    pub status: TaskStatus,
    /// Instant at which the task was registered (not when it started running).
    pub start_time: std::time::Instant,
    /// Completion fraction; set to 0.0 on submission and 1.0 on completion.
    pub progress: f32,
}
89
/// Priority-aware task scheduler: a semaphore bounds concurrency while a
/// background Tokio task dispatches queued work.
pub struct TaskScheduler {
    /// Parallelism configuration derived from the application config.
    _config: ParallelConfig,
    /// Created when `auto_balance_workers` is set.
    /// NOTE(review): stored but never consulted by the scheduling logic in
    /// this file — confirm whether it is used elsewhere or is dead weight.
    load_balancer: Option<crate::core::parallel::load_balancer::LoadBalancer>,
    /// Maximum wall-clock time a task may run before it is failed with a
    /// timeout result.
    task_timeout: std::time::Duration,
    /// Idle period after which the background loop exits; it is restarted
    /// lazily by `ensure_scheduler_running` on the next submission.
    worker_idle_timeout: std::time::Duration,
    /// Pending tasks awaiting a free worker slot.
    task_queue: Arc<Mutex<VecDeque<PendingTask>>>,
    /// Bounds the number of concurrently executing tasks.
    semaphore: Arc<Semaphore>,
    /// Bookkeeping for queued/running tasks, keyed by task id.
    active_tasks: Arc<Mutex<std::collections::HashMap<String, TaskInfo>>>,
    /// Handle of the background scheduling loop, if one has been spawned.
    scheduler_handle: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
}
105
106impl TaskScheduler {
107 pub fn new_with_config(app_config: &Config) -> Result<Self> {
109 let config = ParallelConfig::from_app_config(app_config);
110 config.validate()?;
111 let semaphore = Arc::new(Semaphore::new(config.max_concurrent_jobs));
112 let task_queue = Arc::new(Mutex::new(VecDeque::new()));
113 let active_tasks = Arc::new(Mutex::new(std::collections::HashMap::new()));
114
115 let general = &app_config.general;
117 let scheduler = Self {
118 _config: config.clone(),
119 task_queue: task_queue.clone(),
120 semaphore: semaphore.clone(),
121 active_tasks: active_tasks.clone(),
122 scheduler_handle: Arc::new(Mutex::new(None)),
123 load_balancer: if config.auto_balance_workers {
124 Some(crate::core::parallel::load_balancer::LoadBalancer::new())
125 } else {
126 None
127 },
128 task_timeout: std::time::Duration::from_secs(general.task_timeout_seconds),
129 worker_idle_timeout: std::time::Duration::from_secs(
130 general.worker_idle_timeout_seconds,
131 ),
132 };
133
134 scheduler.start_scheduler_loop();
136 Ok(scheduler)
137 }
138
139 pub fn new_with_defaults() -> Self {
141 let default_app_config = Config::default();
143 let config = ParallelConfig::from_app_config(&default_app_config);
144 let _ = config.validate();
145 let semaphore = Arc::new(Semaphore::new(config.max_concurrent_jobs));
146 let task_queue = Arc::new(Mutex::new(VecDeque::new()));
147 let active_tasks = Arc::new(Mutex::new(std::collections::HashMap::new()));
148
149 let general = &default_app_config.general;
150 let scheduler = Self {
151 _config: config.clone(),
152 task_queue: task_queue.clone(),
153 semaphore: semaphore.clone(),
154 active_tasks: active_tasks.clone(),
155 scheduler_handle: Arc::new(Mutex::new(None)),
156 load_balancer: if config.auto_balance_workers {
157 Some(crate::core::parallel::load_balancer::LoadBalancer::new())
158 } else {
159 None
160 },
161 task_timeout: std::time::Duration::from_secs(general.task_timeout_seconds),
162 worker_idle_timeout: std::time::Duration::from_secs(
163 general.worker_idle_timeout_seconds,
164 ),
165 };
166
167 scheduler.start_scheduler_loop();
169 scheduler
170 }
171
172 pub fn new() -> Result<Self> {
174 let default_config = Config::default();
175 Self::new_with_config(&default_config)
176 }
177
    /// Spawns the background dispatch loop and stores its `JoinHandle`.
    ///
    /// Each iteration the loop:
    /// 1. tracks the last time there was pending or active work; once idle
    ///    for longer than `worker_idle_timeout` it exits (to be restarted
    ///    lazily by `ensure_scheduler_running`),
    /// 2. tries to grab a worker permit; on success pops the next pending
    ///    task (highest priority when `enable_task_priorities` is set,
    ///    otherwise FIFO) and runs it on its own Tokio task under a
    ///    `task_timeout` deadline,
    /// 3. sleeps 10ms when there is no permit or no pending task.
    fn start_scheduler_loop(&self) {
        let task_queue = Arc::clone(&self.task_queue);
        let semaphore = Arc::clone(&self.semaphore);
        let active_tasks = Arc::clone(&self.active_tasks);
        let config = self._config.clone();
        let task_timeout = self.task_timeout;
        let worker_idle_timeout = self.worker_idle_timeout;

        let handle = tokio::spawn(async move {
            let mut last_active = std::time::Instant::now();
            loop {
                // Idle detection: any queued or tracked task counts as activity.
                let has_pending = {
                    let q = task_queue.lock().unwrap();
                    !q.is_empty()
                };
                let has_active = {
                    let a = active_tasks.lock().unwrap();
                    !a.is_empty()
                };
                if has_pending || has_active {
                    last_active = std::time::Instant::now();
                } else if last_active.elapsed() > worker_idle_timeout {
                    // No work for the whole idle window: stop this loop.
                    break;
                }
                if let Ok(permit) = semaphore.clone().try_acquire_owned() {
                    let pending = {
                        let mut queue = task_queue.lock().unwrap();
                        if config.enable_task_priorities {
                            // Pick the highest-priority entry.
                            // NOTE(review): `max_by_key` returns the *last*
                            // maximum, so equal-priority tasks are served
                            // LIFO here — confirm whether FIFO was intended.
                            if let Some((idx, _)) =
                                queue.iter().enumerate().max_by_key(|(_, t)| t.priority)
                            {
                                queue.remove(idx)
                            } else {
                                None
                            }
                        } else {
                            queue.pop_front()
                        }
                    };
                    if let Some(p) = pending {
                        {
                            let mut active = active_tasks.lock().unwrap();
                            if let Some(info) = active.get_mut(&p.task_id) {
                                info.status = TaskStatus::Running;
                            }
                        }

                        let task_id = p.task_id.clone();
                        let active_tasks_clone = Arc::clone(&active_tasks);

                        // Execute on a dedicated Tokio task; the permit moves
                        // into the task and is released when it finishes.
                        tokio::spawn(async move {
                            let result = match tokio::time::timeout(task_timeout, p.task.execute())
                                .await
                            {
                                Ok(res) => res,
                                Err(_) => TaskResult::Failed("Task execution timeout".to_string()),
                            };

                            {
                                let mut at = active_tasks_clone.lock().unwrap();
                                if let Some(info) = at.get_mut(&task_id) {
                                    info.status = TaskStatus::Completed(result.clone());
                                    info.progress = 1.0;
                                }
                            }

                            // The receiver may already be gone (submitter
                            // dropped); that is not an error here.
                            let _ = p.result_sender.send(result);

                            drop(permit);
                        });
                    } else {
                        // Nothing queued: return the permit and back off.
                        drop(permit);
                        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                    }
                } else {
                    // All worker slots busy: back off briefly.
                    tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                }
            }
        });

        *self.scheduler_handle.lock().unwrap() = Some(handle);
    }
275
276 pub async fn submit_task(&self, task: Box<dyn Task + Send + Sync>) -> Result<TaskResult> {
278 self.submit_task_with_priority(task, TaskPriority::Normal)
279 .await
280 }
281
    /// Submits a task at the given priority and awaits its result.
    ///
    /// The task is registered in `active_tasks` first; an [`ActiveTaskGuard`]
    /// guarantees that entry is removed when this call returns, whatever the
    /// outcome. When the queue is full, the configured overflow strategy
    /// decides whether to wait, evict, reject, drop, or expand.
    ///
    /// # Errors
    /// - `Reject` overflow strategy: the submission is refused.
    /// - The result channel closes before a result arrives (interrupted).
    pub async fn submit_task_with_priority(
        &self,
        task: Box<dyn Task + Send + Sync>,
        priority: TaskPriority,
    ) -> Result<TaskResult> {
        let task_id = task.task_id();
        let task_type = task.task_type().to_string();
        let (tx, rx) = oneshot::channel();

        // Register the task as pending before it enters the queue.
        {
            let mut active = self.active_tasks.lock().unwrap();
            active.insert(
                task_id.clone(),
                TaskInfo {
                    task_id: task_id.clone(),
                    task_type,
                    status: TaskStatus::Pending,
                    start_time: std::time::Instant::now(),
                    progress: 0.0,
                },
            );
        }

        // Removes the `active_tasks` entry on every exit path below.
        let _guard = ActiveTaskGuard {
            active_tasks: Arc::clone(&self.active_tasks),
            task_id: task_id.clone(),
        };

        let pending = PendingTask {
            task,
            result_sender: tx,
            task_id: task_id.clone(),
            priority,
        };
        if self.get_queue_size() >= self._config.task_queue_size {
            match self._config.queue_overflow_strategy {
                OverflowStrategy::Block => {
                    // Poll until the queue has room again.
                    while self.get_queue_size() >= self._config.task_queue_size {
                        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                    }
                }
                OverflowStrategy::DropOldest => {
                    // Evict the oldest queued task, fail it, and drop its
                    // bookkeeping entry.
                    let evicted_id = {
                        let mut q = self.task_queue.lock().unwrap();
                        if let Some(evicted) = q.pop_front() {
                            let id = evicted.task_id.clone();
                            let _ = evicted.result_sender.send(TaskResult::Failed(
                                "Task dropped due to queue overflow".to_string(),
                            ));
                            Some(id)
                        } else {
                            None
                        }
                    };
                    if let Some(id) = evicted_id {
                        let mut active = self.active_tasks.lock().unwrap();
                        active.remove(&id);
                    }
                }
                OverflowStrategy::Reject => {
                    // `_guard` removes this task's bookkeeping entry.
                    return Err(SubXError::parallel_processing(
                        "Task queue is full".to_string(),
                    ));
                }
                OverflowStrategy::Drop => {
                    // The new task itself is dropped; reported as a failed
                    // result rather than an error.
                    return Ok(TaskResult::Failed(
                        "Task dropped due to queue overflow".to_string(),
                    ));
                }
                OverflowStrategy::Expand => {
                    // Allow the queue to grow beyond the configured size.
                }
            }
        }
        {
            let mut q = self.task_queue.lock().unwrap();
            if self._config.enable_task_priorities {
                // Insert before the first lower-priority entry so the queue
                // stays ordered high -> low, FIFO within a priority.
                let pos = q
                    .iter()
                    .position(|t| t.priority < pending.priority)
                    .unwrap_or(q.len());
                q.insert(pos, pending);
            } else {
                q.push_back(pending);
            }
        }

        // The background loop may have exited on idle; restart if needed.
        self.ensure_scheduler_running();

        let result = rx.await.map_err(|_| {
            crate::error::SubXError::parallel_processing("Task execution interrupted".to_string())
        })?;

        Ok(result)
    }
394
395 fn ensure_scheduler_running(&self) {
401 let needs_restart = {
402 let handle = self.scheduler_handle.lock().unwrap();
403 match handle.as_ref() {
404 Some(h) => h.is_finished(),
405 None => true,
406 }
407 };
408 if needs_restart {
409 self.start_scheduler_loop();
410 }
411 }
412
    /// Currently a no-op: dispatch is performed entirely by the background
    /// loop spawned in `start_scheduler_loop`, not by this method.
    async fn try_execute_next_task(&self) {
    }
417
418 pub async fn submit_batch_tasks(
420 &self,
421 tasks: Vec<Box<dyn Task + Send + Sync>>,
422 ) -> Vec<TaskResult> {
423 let mut receivers = Vec::new();
424
425 for task in tasks {
427 let task_id = task.task_id();
428 let task_type = task.task_type().to_string();
429 let (tx, rx) = oneshot::channel();
430
431 {
433 let mut active = self.active_tasks.lock().unwrap();
434 active.insert(
435 task_id.clone(),
436 TaskInfo {
437 task_id: task_id.clone(),
438 task_type,
439 status: TaskStatus::Pending,
440 start_time: std::time::Instant::now(),
441 progress: 0.0,
442 },
443 );
444 }
445
446 let pending = PendingTask {
448 task,
449 result_sender: tx,
450 task_id: task_id.clone(),
451 priority: TaskPriority::Normal,
452 };
453 if self.get_queue_size() >= self._config.task_queue_size {
454 match self._config.queue_overflow_strategy {
455 OverflowStrategy::Block => {
456 while self.get_queue_size() >= self._config.task_queue_size {
458 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
459 }
460 }
461 OverflowStrategy::DropOldest => {
462 let evicted_id = {
463 let mut q = self.task_queue.lock().unwrap();
464 if let Some(evicted) = q.pop_front() {
465 let id = evicted.task_id.clone();
466 let _ = evicted.result_sender.send(TaskResult::Failed(
467 "Task dropped due to queue overflow".to_string(),
468 ));
469 Some(id)
470 } else {
471 None
472 }
473 };
474 if let Some(id) = evicted_id {
475 let mut active = self.active_tasks.lock().unwrap();
476 active.remove(&id);
477 }
478 }
479 OverflowStrategy::Reject => {
480 return Vec::new();
482 }
483 OverflowStrategy::Drop => {
484 continue;
486 }
487 OverflowStrategy::Expand => {
488 }
491 }
492 }
493 {
495 let mut q = self.task_queue.lock().unwrap();
496 if self._config.enable_task_priorities {
497 let pos = q
498 .iter()
499 .position(|t| t.priority < pending.priority)
500 .unwrap_or(q.len());
501 q.insert(pos, pending);
502 } else {
503 q.push_back(pending);
504 }
505 }
506
507 receivers.push((task_id, rx));
508 }
509
510 self.ensure_scheduler_running();
512
513 let mut results = Vec::new();
515 for (task_id, rx) in receivers {
516 match rx.await {
517 Ok(result) => results.push(result),
518 Err(_) => {
519 results.push(TaskResult::Failed("Task execution interrupted".to_string()))
520 }
521 }
522
523 {
525 let mut active = self.active_tasks.lock().unwrap();
526 active.remove(&task_id);
527 }
528 }
529
530 results
531 }
532
533 pub fn get_queue_size(&self) -> usize {
535 self.task_queue.lock().unwrap().len()
536 }
537
538 pub fn get_active_workers(&self) -> usize {
540 self._config.max_concurrent_jobs - self.semaphore.available_permits()
541 }
542
543 pub fn get_task_status(&self, task_id: &str) -> Option<TaskInfo> {
545 self.active_tasks.lock().unwrap().get(task_id).cloned()
546 }
547
548 pub fn list_active_tasks(&self) -> Vec<TaskInfo> {
550 self.active_tasks
551 .lock()
552 .unwrap()
553 .values()
554 .cloned()
555 .collect()
556 }
557}
558
559impl Clone for TaskScheduler {
560 fn clone(&self) -> Self {
561 Self {
562 _config: self._config.clone(),
563 task_queue: Arc::clone(&self.task_queue),
564 semaphore: Arc::clone(&self.semaphore),
565 active_tasks: Arc::clone(&self.active_tasks),
566 scheduler_handle: Arc::clone(&self.scheduler_handle),
567 load_balancer: self.load_balancer.clone(),
568 task_timeout: self.task_timeout,
569 worker_idle_timeout: self.worker_idle_timeout,
570 }
571 }
572}
573
574#[cfg(test)]
575mod tests {
576 use super::{Task, TaskPriority, TaskResult, TaskScheduler};
577 use std::sync::atomic::{AtomicUsize, Ordering};
578 use std::sync::{Arc, Mutex};
579 use tokio::time::Duration;
580 use uuid::Uuid;
581
    /// Test task that sleeps for a configurable duration, then succeeds.
    struct MockTask {
        name: String,
        duration: Duration,
    }

    #[async_trait::async_trait]
    impl Task for MockTask {
        async fn execute(&self) -> TaskResult {
            tokio::time::sleep(self.duration).await;
            TaskResult::Success(format!("Task completed: {}", self.name))
        }
        fn task_type(&self) -> &'static str {
            "mock"
        }
        fn task_id(&self) -> String {
            format!("mock_{}", self.name)
        }
    }
600
    /// Test task that increments a shared counter on execution; ids are
    /// random UUIDs so repeated submissions never collide in `active_tasks`.
    struct CounterTask {
        counter: Arc<AtomicUsize>,
    }
    impl CounterTask {
        fn new(counter: Arc<AtomicUsize>) -> Self {
            Self { counter }
        }
    }
    #[async_trait::async_trait]
    impl Task for CounterTask {
        async fn execute(&self) -> TaskResult {
            self.counter.fetch_add(1, Ordering::SeqCst);
            TaskResult::Success("Counter task completed".to_string())
        }
        fn task_type(&self) -> &'static str {
            "counter"
        }
        fn task_id(&self) -> String {
            Uuid::new_v4().to_string()
        }
    }
622
    /// Test task that appends its name to a shared vector, letting tests
    /// observe execution order.
    struct OrderTask {
        name: String,
        order: Arc<Mutex<Vec<String>>>,
    }
    impl OrderTask {
        fn new(name: &str, order: Arc<Mutex<Vec<String>>>) -> Self {
            Self {
                name: name.to_string(),
                order,
            }
        }
    }
    #[async_trait::async_trait]
    impl Task for OrderTask {
        async fn execute(&self) -> TaskResult {
            let mut v = self.order.lock().unwrap();
            v.push(self.name.clone());
            TaskResult::Success(format!("Order task completed: {}", self.name))
        }
        fn task_type(&self) -> &'static str {
            "order"
        }
        fn task_id(&self) -> String {
            format!("order_{}", self.name)
        }
    }
649
    /// A single submitted task completes successfully.
    #[tokio::test]
    async fn test_task_scheduler_basic() {
        let scheduler = TaskScheduler::new_with_defaults();
        let task = Box::new(MockTask {
            name: "test".to_string(),
            duration: Duration::from_millis(10),
        });
        let result = scheduler.submit_task(task).await.unwrap();
        assert!(matches!(result, TaskResult::Success(_)));
    }
660
    /// Sequentially submitted counter tasks each run exactly once.
    #[tokio::test]
    async fn test_concurrent_task_execution() {
        let scheduler = TaskScheduler::new_with_defaults();
        let counter = Arc::new(AtomicUsize::new(0));

        let task = Box::new(CounterTask::new(counter.clone()));
        let result = scheduler.submit_task(task).await.unwrap();
        assert!(matches!(result, TaskResult::Success(_)));
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        for _ in 0..4 {
            let task = Box::new(CounterTask::new(counter.clone()));
            let _result = scheduler.submit_task(task).await.unwrap();
        }
        assert_eq!(counter.load(Ordering::SeqCst), 5);
    }
679
    /// All four priorities complete; only membership (not relative order) is
    /// asserted, since the tasks are submitted from concurrent Tokio tasks.
    #[tokio::test]
    async fn test_task_priority_ordering() {
        let scheduler = TaskScheduler::new_with_defaults();
        let order = Arc::new(Mutex::new(Vec::new()));

        let tasks = vec![
            (TaskPriority::Low, "low"),
            (TaskPriority::High, "high"),
            (TaskPriority::Normal, "normal"),
            (TaskPriority::Critical, "critical"),
        ];

        let mut handles = Vec::new();
        for (prio, name) in tasks {
            let task = Box::new(OrderTask::new(name, order.clone()));
            let scheduler_clone = scheduler.clone();
            let handle = tokio::spawn(async move {
                scheduler_clone
                    .submit_task_with_priority(task, prio)
                    .await
                    .unwrap()
            });
            handles.push(handle);
        }

        for handle in handles {
            let _ = handle.await.unwrap();
        }

        let v = order.lock().unwrap();
        assert_eq!(v.len(), 4);
        assert!(v.contains(&"critical".to_string()));
        assert!(v.contains(&"high".to_string()));
        assert!(v.contains(&"normal".to_string()));
        assert!(v.contains(&"low".to_string()));
    }
719
    /// Queue/worker metrics start at zero and the queue drains after a task
    /// completes.
    #[tokio::test]
    async fn test_queue_and_active_workers_metrics() {
        let scheduler = TaskScheduler::new_with_defaults();

        assert_eq!(scheduler.get_queue_size(), 0);
        assert_eq!(scheduler.get_active_workers(), 0);

        let task = Box::new(MockTask {
            name: "long_task".to_string(),
            duration: Duration::from_millis(100),
        });

        let handle = {
            let scheduler_clone = scheduler.clone();
            tokio::spawn(async move { scheduler_clone.submit_task(task).await })
        };

        // Give the task time to start before awaiting its completion.
        tokio::time::sleep(Duration::from_millis(20)).await;

        let _result = handle.await.unwrap().unwrap();

        assert_eq!(scheduler.get_queue_size(), 0);
    }
748
    /// Tasks submitted in a staggered stream all complete.
    #[tokio::test]
    async fn test_continuous_scheduling() {
        let scheduler = TaskScheduler::new_with_defaults();
        let counter = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::new();
        for i in 0..10 {
            let task = Box::new(CounterTask::new(counter.clone()));
            let scheduler_clone = scheduler.clone();
            let handle =
                tokio::spawn(async move { scheduler_clone.submit_task(task).await.unwrap() });
            handles.push(handle);

            // Stagger every third submission to vary scheduler load.
            if i % 3 == 0 {
                tokio::time::sleep(Duration::from_millis(5)).await;
            }
        }

        for handle in handles {
            let result = handle.await.unwrap();
            assert!(matches!(result, TaskResult::Success(_)));
        }

        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }
778
    /// `submit_batch_tasks` returns one success per submitted task.
    #[tokio::test]
    async fn test_batch_task_execution() {
        let scheduler = TaskScheduler::new_with_defaults();
        let counter = Arc::new(AtomicUsize::new(0));

        let mut tasks: Vec<Box<dyn Task + Send + Sync>> = Vec::new();
        for _ in 0..3 {
            tasks.push(Box::new(CounterTask::new(counter.clone())));
        }

        let results = scheduler.submit_batch_tasks(tasks).await;
        assert_eq!(results.len(), 3);
        assert_eq!(counter.load(Ordering::SeqCst), 3);
        for result in results {
            assert!(matches!(result, TaskResult::Success(_)));
        }
    }
798
    /// 50 mixed-priority tasks submitted concurrently all complete and the
    /// scheduler returns to an idle state.
    #[tokio::test]
    async fn test_high_concurrency_stress() {
        let scheduler = TaskScheduler::new_with_defaults();
        let counter = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::new();
        for i in 0..50 {
            let task = Box::new(CounterTask::new(counter.clone()));
            let scheduler_clone = scheduler.clone();
            // Cycle through all four priorities.
            let priority = match i % 4 {
                0 => TaskPriority::Low,
                1 => TaskPriority::Normal,
                2 => TaskPriority::High,
                3 => TaskPriority::Critical,
                _ => TaskPriority::Normal,
            };

            let handle = tokio::spawn(async move {
                scheduler_clone
                    .submit_task_with_priority(task, priority)
                    .await
                    .unwrap()
            });
            handles.push(handle);

            if i % 5 == 0 {
                tokio::time::sleep(Duration::from_millis(1)).await;
            }
        }

        for handle in handles {
            let result = handle.await.unwrap();
            assert!(matches!(result, TaskResult::Success(_)));
        }

        assert_eq!(counter.load(Ordering::SeqCst), 50);

        assert_eq!(scheduler.get_queue_size(), 0);
        assert_eq!(scheduler.get_active_workers(), 0);
    }
844
    /// Individual submissions interleaved with a batch all complete
    /// (3 + 4 + 2 = 9 counter increments).
    #[tokio::test]
    async fn test_mixed_batch_and_individual_tasks() {
        let scheduler = TaskScheduler::new_with_defaults();
        let counter = Arc::new(AtomicUsize::new(0));

        let mut individual_handles = Vec::new();
        for _ in 0..3 {
            let task = Box::new(CounterTask::new(counter.clone()));
            let scheduler_clone = scheduler.clone();
            let handle =
                tokio::spawn(async move { scheduler_clone.submit_task(task).await.unwrap() });
            individual_handles.push(handle);
        }

        let mut batch_tasks: Vec<Box<dyn Task + Send + Sync>> = Vec::new();
        for _ in 0..4 {
            batch_tasks.push(Box::new(CounterTask::new(counter.clone())));
        }

        let batch_handle = {
            let scheduler_clone = scheduler.clone();
            tokio::spawn(async move { scheduler_clone.submit_batch_tasks(batch_tasks).await })
        };

        let mut more_individual_handles = Vec::new();
        for _ in 0..2 {
            let task = Box::new(CounterTask::new(counter.clone()));
            let scheduler_clone = scheduler.clone();
            let handle =
                tokio::spawn(async move { scheduler_clone.submit_task(task).await.unwrap() });
            more_individual_handles.push(handle);
        }

        for handle in individual_handles {
            let result = handle.await.unwrap();
            assert!(matches!(result, TaskResult::Success(_)));
        }

        let batch_results = batch_handle.await.unwrap();
        assert_eq!(batch_results.len(), 4);
        for result in batch_results {
            assert!(matches!(result, TaskResult::Success(_)));
        }

        for handle in more_individual_handles {
            let result = handle.await.unwrap();
            assert!(matches!(result, TaskResult::Success(_)));
        }

        assert_eq!(counter.load(Ordering::SeqCst), 9);
    }
901
    /// Tasks of every priority, submitted sequentially, all execute.
    #[tokio::test]
    async fn test_task_scheduling_strategies() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicUsize, Ordering};

        // Task that records its id in a shared execution log.
        struct PriorityTask {
            id: String,
            priority: TaskPriority, // NOTE(review): stored but never read
            counter: Arc<AtomicUsize>,
            execution_order: Arc<Mutex<Vec<String>>>,
        }

        #[async_trait::async_trait]
        impl Task for PriorityTask {
            async fn execute(&self) -> TaskResult {
                self.counter.fetch_add(1, Ordering::SeqCst);
                self.execution_order.lock().unwrap().push(self.id.clone());
                tokio::time::sleep(Duration::from_millis(50)).await;
                TaskResult::Success(format!("Priority task {} completed", self.id))
            }
            fn task_type(&self) -> &'static str {
                "priority"
            }
            fn task_id(&self) -> String {
                self.id.clone()
            }
        }

        let scheduler = TaskScheduler::new_with_defaults();
        let counter = Arc::new(AtomicUsize::new(0));
        let execution_order = Arc::new(Mutex::new(Vec::new()));

        let priorities = vec![
            ("low", TaskPriority::Low),
            ("high", TaskPriority::High),
            ("critical", TaskPriority::Critical),
            ("normal", TaskPriority::Normal),
        ];

        for (id, priority) in priorities {
            let task = PriorityTask {
                id: id.to_string(),
                priority,
                counter: Arc::clone(&counter),
                execution_order: Arc::clone(&execution_order),
            };

            scheduler
                .submit_task_with_priority(Box::new(task), priority)
                .await
                .unwrap();
        }

        tokio::time::sleep(Duration::from_millis(200)).await;

        let final_count = counter.load(Ordering::SeqCst);
        assert_eq!(final_count, 4, "All 4 tasks should have been executed");

        let order = execution_order.lock().unwrap();
        println!("Task execution order: {:?}", *order);

        // Sequential submission means order is deterministic per-call; only
        // membership is asserted here.
        assert!(
            order.contains(&"critical".to_string()),
            "Critical task should have been executed"
        );
        assert!(
            order.contains(&"low".to_string()),
            "Low task should have been executed"
        );
        assert!(
            order.contains(&"high".to_string()),
            "High task should have been executed"
        );
        assert!(
            order.contains(&"normal".to_string()),
            "Normal task should have been executed"
        );
    }
988
    /// Ten sequential tasks all execute and leave the queue empty.
    #[tokio::test]
    async fn test_load_balancing() {
        let scheduler = TaskScheduler::new_with_defaults();
        let task_counter = Arc::new(AtomicUsize::new(0));

        for _i in 0..10 {
            let task = CounterTask::new(Arc::clone(&task_counter));
            let result = scheduler.submit_task(Box::new(task)).await.unwrap();
            assert!(matches!(result, TaskResult::Success(_)));
        }

        let final_count = task_counter.load(Ordering::SeqCst);
        assert_eq!(final_count, 10);

        assert_eq!(scheduler.get_queue_size(), 0);
    }
1009
    /// `TaskPriority` ordering is total and both high- and low-priority
    /// submissions succeed.
    #[tokio::test]
    async fn test_task_priority_processing() {
        let scheduler = TaskScheduler::new_with_defaults();

        assert!(TaskPriority::Critical > TaskPriority::High);
        assert!(TaskPriority::High > TaskPriority::Normal);
        assert!(TaskPriority::Normal > TaskPriority::Low);

        let high_task = MockTask {
            name: "high_priority".to_string(),
            duration: Duration::from_millis(5),
        };

        let low_task = MockTask {
            name: "low_priority".to_string(),
            duration: Duration::from_millis(5),
        };

        let high_result = scheduler
            .submit_task_with_priority(Box::new(high_task), TaskPriority::High)
            .await
            .unwrap();
        let low_result = scheduler
            .submit_task_with_priority(Box::new(low_task), TaskPriority::Low)
            .await
            .unwrap();

        assert!(matches!(high_result, TaskResult::Success(_)));
        assert!(matches!(low_result, TaskResult::Success(_)));
    }
1043
    /// Scheduler state is clean before and after a task runs.
    #[tokio::test]
    async fn test_scheduler_state_management() {
        let scheduler = TaskScheduler::new_with_defaults();

        assert_eq!(scheduler.get_queue_size(), 0);
        assert_eq!(scheduler.get_active_workers(), 0);

        let task = MockTask {
            name: "state_test".to_string(),
            duration: Duration::from_millis(50),
        };

        let result = scheduler.submit_task(Box::new(task)).await.unwrap();

        tokio::time::sleep(Duration::from_millis(5)).await;

        assert!(matches!(result, TaskResult::Success(_)));

        assert_eq!(scheduler.get_queue_size(), 0);
    }
1070
    /// Submitting many tasks tolerates whichever overflow strategy the
    /// default configuration uses; the queue drains afterwards.
    #[tokio::test]
    async fn test_overflow_strategy_handling() {
        let scheduler = TaskScheduler::new_with_defaults();

        for i in 0..20 {
            let task = MockTask {
                name: format!("overflow_test_{}", i),
                duration: Duration::from_millis(20),
            };

            match scheduler.submit_task(Box::new(task)).await {
                Ok(result) => {
                    assert!(matches!(result, TaskResult::Success(_)));
                }
                Err(_) => {
                    // Rejection is acceptable once the queue is full.
                    break;
                }
            }
        }

        tokio::time::sleep(Duration::from_millis(100)).await;

        assert_eq!(scheduler.get_queue_size(), 0);
    }
1100
    /// Eight tasks submitted from parallel Tokio tasks all complete.
    #[tokio::test]
    async fn test_concurrent_task_submission() {
        let scheduler = TaskScheduler::new_with_defaults();
        let completion_counter = Arc::new(AtomicUsize::new(0));
        let mut submission_handles = Vec::new();

        for _i in 0..8 {
            let scheduler_clone = scheduler.clone();
            let counter_clone = Arc::clone(&completion_counter);

            let submission_handle = tokio::spawn(async move {
                let task = CounterTask::new(counter_clone);
                scheduler_clone.submit_task(Box::new(task)).await.unwrap()
            });

            submission_handles.push(submission_handle);
        }

        for handle in submission_handles {
            let result = handle.await.unwrap();
            assert!(matches!(result, TaskResult::Success(_)));
        }

        let final_count = completion_counter.load(Ordering::SeqCst);
        assert_eq!(final_count, 8);
    }
1131
    /// Five short tasks finish well within a generous wall-clock budget and
    /// leave the scheduler idle.
    #[tokio::test]
    async fn test_scheduler_performance_metrics() {
        let scheduler = TaskScheduler::new_with_defaults();
        let start_time = std::time::Instant::now();
        let task_count = 5;

        for i in 0..task_count {
            let task = MockTask {
                name: format!("perf_test_{}", i),
                duration: Duration::from_millis(10),
            };
            let result = scheduler.submit_task(Box::new(task)).await.unwrap();
            assert!(matches!(result, TaskResult::Success(_)));
        }

        let total_time = start_time.elapsed();

        assert!(
            total_time < Duration::from_millis(500),
            "Tasks took too long: {:?}",
            total_time
        );

        assert_eq!(scheduler.get_queue_size(), 0);
        assert_eq!(scheduler.get_active_workers(), 0);
    }
1162
    /// Dropping an `ActiveTaskGuard` removes the task's map entry.
    #[tokio::test]
    async fn test_active_task_guard_cleanup() {
        use super::{ActiveTaskGuard, TaskInfo};
        use std::collections::HashMap;

        let active_tasks = Arc::new(Mutex::new(HashMap::<String, TaskInfo>::new()));
        let task_id = "guard_test_task".to_string();

        active_tasks.lock().unwrap().insert(
            task_id.clone(),
            TaskInfo {
                task_id: task_id.clone(),
                task_type: "mock".to_string(),
                status: crate::core::parallel::TaskStatus::Pending,
                start_time: std::time::Instant::now(),
                progress: 0.0,
            },
        );
        assert!(active_tasks.lock().unwrap().contains_key(&task_id));

        // Entry survives while the guard is alive, disappears on drop.
        {
            let _guard = ActiveTaskGuard {
                active_tasks: Arc::clone(&active_tasks),
                task_id: task_id.clone(),
            };
            assert!(active_tasks.lock().unwrap().contains_key(&task_id));
        }

        assert!(!active_tasks.lock().unwrap().contains_key(&task_id));
    }
1197
    /// With `DropOldest` and a queue of size 1, the first queued task is
    /// evicted with a Failed(overflow) result when a second task is queued.
    #[tokio::test]
    async fn test_drop_oldest_sends_failed() {
        use crate::config::{Config, OverflowStrategy};

        let mut config = Config::default();
        config.parallel.task_queue_size = 1;
        config.general.max_concurrent_jobs = 1;
        config.parallel.overflow_strategy = OverflowStrategy::DropOldest;
        config.parallel.enable_task_priorities = false;
        config.parallel.auto_balance_workers = false;

        let scheduler = TaskScheduler::new_with_config(&config).unwrap();

        // Occupies the single worker slot so later tasks stay queued.
        let blocker = Box::new(MockTask {
            name: "blocker".to_string(),
            duration: Duration::from_millis(300),
        });
        let blocker_scheduler = scheduler.clone();
        let blocker_handle =
            tokio::spawn(async move { blocker_scheduler.submit_task(blocker).await });

        tokio::time::sleep(Duration::from_millis(30)).await;

        let first = Box::new(MockTask {
            name: "first_queued".to_string(),
            duration: Duration::from_millis(50),
        });
        let first_scheduler = scheduler.clone();
        let first_handle = tokio::spawn(async move { first_scheduler.submit_task(first).await });

        tokio::time::sleep(Duration::from_millis(30)).await;

        // Queue is full (size 1): submitting this should evict `first`.
        let second = Box::new(MockTask {
            name: "second_queued".to_string(),
            duration: Duration::from_millis(10),
        });
        let second_scheduler = scheduler.clone();
        let second_handle = tokio::spawn(async move { second_scheduler.submit_task(second).await });

        let first_result = first_handle.await.unwrap().unwrap();
        match first_result {
            TaskResult::Failed(msg) => {
                assert!(
                    msg.contains("overflow"),
                    "expected overflow-related failure message, got: {}",
                    msg
                );
            }
            other => panic!("expected Failed for evicted task, got {:?}", other),
        }

        let blocker_result = blocker_handle.await.unwrap().unwrap();
        assert!(matches!(blocker_result, TaskResult::Success(_)));
        let second_result = second_handle.await.unwrap().unwrap();
        assert!(matches!(second_result, TaskResult::Success(_)));
    }
1264
    /// The scheduler loop exits after the idle timeout and is transparently
    /// restarted by the next submission.
    #[tokio::test]
    async fn test_scheduler_restart_after_idle() {
        let mut scheduler = TaskScheduler::new_with_defaults();

        // Abort the default loop so a short idle timeout can be installed.
        {
            let mut handle = scheduler.scheduler_handle.lock().unwrap();
            if let Some(h) = handle.take() {
                h.abort();
            }
        }
        tokio::time::sleep(Duration::from_millis(30)).await;

        scheduler.worker_idle_timeout = Duration::from_millis(100);
        scheduler.start_scheduler_loop();

        let t1 = Box::new(MockTask {
            name: "before_idle".to_string(),
            duration: Duration::from_millis(10),
        });
        let r1 = scheduler.submit_task(t1).await.unwrap();
        assert!(matches!(r1, TaskResult::Success(_)));

        // Wait well past the idle timeout so the loop exits.
        tokio::time::sleep(Duration::from_millis(350)).await;

        let loop_finished = {
            let handle = scheduler.scheduler_handle.lock().unwrap();
            handle.as_ref().map(|h| h.is_finished()).unwrap_or(true)
        };
        assert!(
            loop_finished,
            "scheduler loop should have exited after idle timeout"
        );

        // Submitting again must restart the loop via ensure_scheduler_running.
        let t2 = Box::new(MockTask {
            name: "after_idle".to_string(),
            duration: Duration::from_millis(10),
        });
        let r2 = scheduler.submit_task(t2).await.unwrap();
        assert!(matches!(r2, TaskResult::Success(_)));

        let still_running = {
            let handle = scheduler.scheduler_handle.lock().unwrap();
            handle.as_ref().map(|h| !h.is_finished()).unwrap_or(false)
        };
        assert!(
            still_running,
            "scheduler loop should be running after restart"
        );
    }
1324}