use super::{Task, TaskResult, TaskStatus};
use crate::Result;
use crate::config::{Config, OverflowStrategy, load_config};
use crate::core::parallel::config::ParallelConfig;
use crate::error::SubXError;
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
use tokio::sync::{Semaphore, oneshot};

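/// A task waiting in the scheduler queue, paired with the oneshot channel
/// that delivers its `TaskResult` back to the submitter.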
struct PendingTask {
    task: Box<dyn Task + Send + Sync>,
    result_sender: oneshot::Sender<TaskResult>,
    task_id: String,
    priority: TaskPriority,
}

impl PartialEq for PendingTask {
    fn eq(&self, other: &Self) -> bool {
        self.priority == other.priority
    }
}

impl Eq for PendingTask {}

impl PartialOrd for PendingTask {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for PendingTask {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.priority.cmp(&other.priority)
    }
}

#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
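/// Priority levels for scheduled tasks; higher values are dequeued first.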
pub enum TaskPriority {
    Low = 0,
    Normal = 1,
    High = 2,
    Critical = 3,
}

#[derive(Debug, Clone)]
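/// Snapshot of a task's lifecycle while it is queued, running, or finishing.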
pub struct TaskInfo {
    pub task_id: String,
    pub task_type: String,
    pub status: TaskStatus,
    pub start_time: std::time::Instant,
    pub progress: f32,
}

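/// Priority-aware task scheduler. Submitted tasks are queued as
/// `PendingTask`s; a background loop acquires semaphore permits (bounding
/// concurrency to `max_concurrent_jobs`) and spawns each task, shutting
/// itself down after `worker_idle_timeout` with no queued or active work.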
pub struct TaskScheduler {
    /// Parallelism settings derived from the application config.
    _config: ParallelConfig,
    /// Optional load balancer, enabled via `auto_balance_workers`.
    load_balancer: Option<crate::core::parallel::load_balancer::LoadBalancer>,
    /// Maximum wall-clock time a single task may run.
    task_timeout: std::time::Duration,
    /// Idle period after which the scheduler loop exits.
    worker_idle_timeout: std::time::Duration,
    /// Queue of tasks waiting for a free permit.
    task_queue: Arc<Mutex<VecDeque<PendingTask>>>,
    /// Bounds the number of concurrently running tasks.
    semaphore: Arc<Semaphore>,
    /// Bookkeeping for queued and running tasks, keyed by task id.
    active_tasks: Arc<Mutex<std::collections::HashMap<String, TaskInfo>>>,
    /// Handle to the background scheduler loop.
    scheduler_handle: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
}

impl TaskScheduler {
    pub fn new() -> Result<Self> {
        // Load the application config and derive the parallelism settings.
        let app_config = load_config()?;
        let config = ParallelConfig::from_app_config(&app_config);
        config.validate()?;
        let semaphore = Arc::new(Semaphore::new(config.max_concurrent_jobs));
        let task_queue = Arc::new(Mutex::new(VecDeque::new()));
        let active_tasks = Arc::new(Mutex::new(std::collections::HashMap::new()));

        let general = &app_config.general;
        let scheduler = Self {
            _config: config.clone(),
            task_queue: task_queue.clone(),
            semaphore: semaphore.clone(),
            active_tasks: active_tasks.clone(),
            scheduler_handle: Arc::new(Mutex::new(None)),
            load_balancer: if config.auto_balance_workers {
                Some(crate::core::parallel::load_balancer::LoadBalancer::new())
            } else {
                None
            },
            task_timeout: std::time::Duration::from_secs(general.task_timeout_seconds),
            worker_idle_timeout: std::time::Duration::from_secs(
                general.worker_idle_timeout_seconds,
            ),
        };

        // Spawn the background loop before handing the scheduler out.
        scheduler.start_scheduler_loop();
        Ok(scheduler)
    }

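    /// Builds a scheduler from `Config::default()`, for tests and callers
    /// that do not need an on-disk configuration.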
    pub fn new_with_defaults() -> Self {
        let default_app_config = Config::default();
        let config = ParallelConfig::from_app_config(&default_app_config);
        // Defaults are assumed valid; a validation failure is ignored here.
        let _ = config.validate();
        let semaphore = Arc::new(Semaphore::new(config.max_concurrent_jobs));
        let task_queue = Arc::new(Mutex::new(VecDeque::new()));
        let active_tasks = Arc::new(Mutex::new(std::collections::HashMap::new()));

        let general = &default_app_config.general;
        let scheduler = Self {
            _config: config.clone(),
            task_queue: task_queue.clone(),
            semaphore: semaphore.clone(),
            active_tasks: active_tasks.clone(),
            scheduler_handle: Arc::new(Mutex::new(None)),
            load_balancer: if config.auto_balance_workers {
                Some(crate::core::parallel::load_balancer::LoadBalancer::new())
            } else {
                None
            },
            task_timeout: std::time::Duration::from_secs(general.task_timeout_seconds),
            worker_idle_timeout: std::time::Duration::from_secs(
                general.worker_idle_timeout_seconds,
            ),
        };

        scheduler.start_scheduler_loop();
        scheduler
    }

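    /// Spawns the background loop that moves tasks from the queue into
    /// execution whenever a semaphore permit is available.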
    fn start_scheduler_loop(&self) {
        let task_queue = Arc::clone(&self.task_queue);
        let semaphore = Arc::clone(&self.semaphore);
        let active_tasks = Arc::clone(&self.active_tasks);
        let config = self._config.clone();
        let task_timeout = self.task_timeout;
        let worker_idle_timeout = self.worker_idle_timeout;

        let handle = tokio::spawn(async move {
            let mut last_active = std::time::Instant::now();
            loop {
                let has_pending = {
                    let q = task_queue.lock().unwrap();
                    !q.is_empty()
                };
                let has_active = {
                    let a = active_tasks.lock().unwrap();
                    !a.is_empty()
                };
                if has_pending || has_active {
                    last_active = std::time::Instant::now();
                } else if last_active.elapsed() > worker_idle_timeout {
                    // No queued or running work for the idle period: shut down.
                    break;
                }
                if let Ok(permit) = semaphore.clone().try_acquire_owned() {
                    let pending = {
                        let mut queue = task_queue.lock().unwrap();
                        if config.enable_task_priorities {
                            // Pick the highest-priority pending task.
                            if let Some((idx, _)) =
                                queue.iter().enumerate().max_by_key(|(_, t)| t.priority)
                            {
                                queue.remove(idx)
                            } else {
                                None
                            }
                        } else {
                            queue.pop_front()
                        }
                    };
                    if let Some(p) = pending {
                        {
                            let mut active = active_tasks.lock().unwrap();
                            if let Some(info) = active.get_mut(&p.task_id) {
                                info.status = TaskStatus::Running;
                            }
                        }

                        let task_id = p.task_id.clone();
                        let active_tasks_clone = Arc::clone(&active_tasks);

                        tokio::spawn(async move {
                            // Enforce the per-task timeout; on expiry the task
                            // future is dropped and a failure is reported.
                            let result =
                                match tokio::time::timeout(task_timeout, p.task.execute()).await {
                                    Ok(res) => res,
                                    Err(_) => {
                                        TaskResult::Failed("task execution timed out".to_string())
                                    }
                                };

                            {
                                let mut at = active_tasks_clone.lock().unwrap();
                                if let Some(info) = at.get_mut(&task_id) {
                                    info.status = TaskStatus::Completed(result.clone());
                                    info.progress = 1.0;
                                }
                            }

                            // The submitter may have gone away; ignore send errors.
                            let _ = p.result_sender.send(result);

                            // Release the permit only after the task finishes.
                            drop(permit);
                        });
                    } else {
                        // Permit acquired but the queue was empty: back off.
                        drop(permit);
                        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                    }
                } else {
                    // All permits busy: back off before polling again.
                    tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                }
            }
        });

        *self.scheduler_handle.lock().unwrap() = Some(handle);
    }

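    /// Submits a single task at `Normal` priority and awaits its result.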
    pub async fn submit_task(&self, task: Box<dyn Task + Send + Sync>) -> Result<TaskResult> {
        self.submit_task_with_priority(task, TaskPriority::Normal)
            .await
    }

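    /// Submits a task with an explicit priority, applying the configured
    /// queue-overflow strategy, and awaits its result.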
    pub async fn submit_task_with_priority(
        &self,
        task: Box<dyn Task + Send + Sync>,
        priority: TaskPriority,
    ) -> Result<TaskResult> {
        let task_id = task.task_id();
        let task_type = task.task_type().to_string();
        let (tx, rx) = oneshot::channel();

        // Register the task as pending before it enters the queue.
        {
            let mut active = self.active_tasks.lock().unwrap();
            active.insert(
                task_id.clone(),
                TaskInfo {
                    task_id: task_id.clone(),
                    task_type,
                    status: TaskStatus::Pending,
                    start_time: std::time::Instant::now(),
                    progress: 0.0,
                },
            );
        }

        let pending = PendingTask {
            task,
            result_sender: tx,
            task_id: task_id.clone(),
            priority,
        };
        if self.get_queue_size() >= self._config.task_queue_size {
            match self._config.queue_overflow_strategy {
                OverflowStrategy::Block => {
                    // Wait until the queue has room again.
                    while self.get_queue_size() >= self._config.task_queue_size {
                        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                    }
                }
                OverflowStrategy::DropOldest => {
                    let mut q = self.task_queue.lock().unwrap();
                    q.pop_front();
                }
                OverflowStrategy::Reject => {
                    return Err(SubXError::parallel_processing("task queue is full".to_string()));
                }
            }
        }
        {
            let mut q = self.task_queue.lock().unwrap();
            if self._config.enable_task_priorities {
                // Insert before the first lower-priority task to keep the
                // queue sorted in descending priority order.
                let pos = q
                    .iter()
                    .position(|t| t.priority < pending.priority)
                    .unwrap_or(q.len());
                q.insert(pos, pending);
            } else {
                q.push_back(pending);
            }
        }

        let result = rx.await.map_err(|_| {
            crate::error::SubXError::parallel_processing("task execution was interrupted".to_string())
        })?;

        {
            let mut active = self.active_tasks.lock().unwrap();
            active.remove(&task_id);
        }
        Ok(result)
    }

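    /// Currently a no-op: dispatch is driven entirely by the background loop
    /// spawned in `start_scheduler_loop`.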
    async fn try_execute_next_task(&self) {}

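    /// Submits a batch of tasks at `Normal` priority and collects their
    /// results in submission order. Returns an empty `Vec` if the `Reject`
    /// overflow strategy refuses a task mid-batch.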
    pub async fn submit_batch_tasks(
        &self,
        tasks: Vec<Box<dyn Task + Send + Sync>>,
    ) -> Vec<TaskResult> {
        let mut receivers = Vec::new();

        for task in tasks {
            let task_id = task.task_id();
            let task_type = task.task_type().to_string();
            let (tx, rx) = oneshot::channel();

            {
                let mut active = self.active_tasks.lock().unwrap();
                active.insert(
                    task_id.clone(),
                    TaskInfo {
                        task_id: task_id.clone(),
                        task_type,
                        status: TaskStatus::Pending,
                        start_time: std::time::Instant::now(),
                        progress: 0.0,
                    },
                );
            }

            let pending = PendingTask {
                task,
                result_sender: tx,
                task_id: task_id.clone(),
                priority: TaskPriority::Normal,
            };
            if self.get_queue_size() >= self._config.task_queue_size {
                match self._config.queue_overflow_strategy {
                    OverflowStrategy::Block => {
                        while self.get_queue_size() >= self._config.task_queue_size {
                            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                        }
                    }
                    OverflowStrategy::DropOldest => {
                        let mut q = self.task_queue.lock().unwrap();
                        q.pop_front();
                    }
                    OverflowStrategy::Reject => {
                        return Vec::new();
                    }
                }
            }
            {
                let mut q = self.task_queue.lock().unwrap();
                if self._config.enable_task_priorities {
                    let pos = q
                        .iter()
                        .position(|t| t.priority < pending.priority)
                        .unwrap_or(q.len());
                    q.insert(pos, pending);
                } else {
                    q.push_back(pending);
                }
            }

            receivers.push((task_id, rx));
        }

        // Await every result in submission order.
        let mut results = Vec::new();
        for (task_id, rx) in receivers {
            match rx.await {
                Ok(result) => results.push(result),
                Err(_) => {
                    results.push(TaskResult::Failed("task execution was interrupted".to_string()))
                }
            }

            {
                let mut active = self.active_tasks.lock().unwrap();
                active.remove(&task_id);
            }
        }

        results
    }

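    /// Number of tasks currently waiting in the queue.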
    pub fn get_queue_size(&self) -> usize {
        self.task_queue.lock().unwrap().len()
    }

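    /// Number of workers currently holding a permit, i.e. running tasks.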
    pub fn get_active_workers(&self) -> usize {
        self._config.max_concurrent_jobs - self.semaphore.available_permits()
    }

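    /// Looks up the bookkeeping entry for a queued or running task.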
    pub fn get_task_status(&self, task_id: &str) -> Option<TaskInfo> {
        self.active_tasks.lock().unwrap().get(task_id).cloned()
    }

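    /// Lists all tasks that are currently queued or running.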
    pub fn list_active_tasks(&self) -> Vec<TaskInfo> {
        self.active_tasks
            .lock()
            .unwrap()
            .values()
            .cloned()
            .collect()
    }
}

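// Clones share the queue, semaphore, and task map through `Arc`, so every
// clone observes and drives the same scheduler state.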
impl Clone for TaskScheduler {
    fn clone(&self) -> Self {
        Self {
            _config: self._config.clone(),
            task_queue: Arc::clone(&self.task_queue),
            semaphore: Arc::clone(&self.semaphore),
            active_tasks: Arc::clone(&self.active_tasks),
            scheduler_handle: Arc::clone(&self.scheduler_handle),
            load_balancer: self.load_balancer.clone(),
            task_timeout: self.task_timeout,
            worker_idle_timeout: self.worker_idle_timeout,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::{Task, TaskPriority, TaskResult, TaskScheduler};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};
    use tokio::time::Duration;
    use uuid::Uuid;

    struct MockTask {
        name: String,
        duration: Duration,
    }

    #[async_trait::async_trait]
    impl Task for MockTask {
        async fn execute(&self) -> TaskResult {
            tokio::time::sleep(self.duration).await;
            TaskResult::Success(format!("task completed: {}", self.name))
        }
        fn task_type(&self) -> &'static str {
            "mock"
        }
        fn task_id(&self) -> String {
            format!("mock_{}", self.name)
        }
    }

    struct CounterTask {
        counter: Arc<AtomicUsize>,
    }
    impl CounterTask {
        fn new(counter: Arc<AtomicUsize>) -> Self {
            Self { counter }
        }
    }
    #[async_trait::async_trait]
    impl Task for CounterTask {
        async fn execute(&self) -> TaskResult {
            self.counter.fetch_add(1, Ordering::SeqCst);
            TaskResult::Success("counter task completed".to_string())
        }
        fn task_type(&self) -> &'static str {
            "counter"
        }
        fn task_id(&self) -> String {
            Uuid::new_v4().to_string()
        }
    }

    struct OrderTask {
        name: String,
        order: Arc<Mutex<Vec<String>>>,
    }
    impl OrderTask {
        fn new(name: &str, order: Arc<Mutex<Vec<String>>>) -> Self {
            Self {
                name: name.to_string(),
                order,
            }
        }
    }
    #[async_trait::async_trait]
    impl Task for OrderTask {
        async fn execute(&self) -> TaskResult {
            let mut v = self.order.lock().unwrap();
            v.push(self.name.clone());
            TaskResult::Success(format!("order task completed: {}", self.name))
        }
        fn task_type(&self) -> &'static str {
            "order"
        }
        fn task_id(&self) -> String {
            format!("order_{}", self.name)
        }
    }

    #[tokio::test]
    async fn test_task_scheduler_basic() {
        let scheduler = TaskScheduler::new_with_defaults();
        let task = Box::new(MockTask {
            name: "test".to_string(),
            duration: Duration::from_millis(10),
        });
        let result = scheduler.submit_task(task).await.unwrap();
        assert!(matches!(result, TaskResult::Success(_)));
    }

    #[tokio::test]
    async fn test_concurrent_task_execution() {
        let scheduler = TaskScheduler::new_with_defaults();
        let counter = Arc::new(AtomicUsize::new(0));

        let task = Box::new(CounterTask::new(counter.clone()));
        let result = scheduler.submit_task(task).await.unwrap();
        assert!(matches!(result, TaskResult::Success(_)));
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        for _ in 0..4 {
            let task = Box::new(CounterTask::new(counter.clone()));
            let _result = scheduler.submit_task(task).await.unwrap();
        }
        assert_eq!(counter.load(Ordering::SeqCst), 5);
    }

    #[tokio::test]
    async fn test_task_priority_ordering() {
        let scheduler = TaskScheduler::new_with_defaults();
        let order = Arc::new(Mutex::new(Vec::new()));

        let tasks = vec![
            (TaskPriority::Low, "low"),
            (TaskPriority::High, "high"),
            (TaskPriority::Normal, "normal"),
            (TaskPriority::Critical, "critical"),
        ];

        let mut handles = Vec::new();
        for (prio, name) in tasks {
            let task = Box::new(OrderTask::new(name, order.clone()));
            let scheduler_clone = scheduler.clone();
            let handle = tokio::spawn(async move {
                scheduler_clone
                    .submit_task_with_priority(task, prio)
                    .await
                    .unwrap()
            });
            handles.push(handle);
        }

        for handle in handles {
            let _ = handle.await.unwrap();
        }

        // Execution may begin as soon as permits free up, so only completion
        // of every task is asserted, not a strict priority order.
        let v = order.lock().unwrap();
        assert_eq!(v.len(), 4);
        assert!(v.contains(&"critical".to_string()));
        assert!(v.contains(&"high".to_string()));
        assert!(v.contains(&"normal".to_string()));
        assert!(v.contains(&"low".to_string()));
    }

    #[tokio::test]
    async fn test_queue_and_active_workers_metrics() {
        let scheduler = TaskScheduler::new_with_defaults();

        // Nothing submitted yet: both metrics start at zero.
        assert_eq!(scheduler.get_queue_size(), 0);
        assert_eq!(scheduler.get_active_workers(), 0);

        let task = Box::new(MockTask {
            name: "long_task".to_string(),
            duration: Duration::from_millis(100),
        });

        let handle = {
            let scheduler_clone = scheduler.clone();
            tokio::spawn(async move { scheduler_clone.submit_task(task).await })
        };

        // Give the scheduler loop time to pick the task up.
        tokio::time::sleep(Duration::from_millis(20)).await;

        let _result = handle.await.unwrap().unwrap();

        // Once the task completes, the queue is drained again.
        assert_eq!(scheduler.get_queue_size(), 0);
    }

    #[tokio::test]
    async fn test_continuous_scheduling() {
        let scheduler = TaskScheduler::new_with_defaults();
        let counter = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::new();
        for i in 0..10 {
            let task = Box::new(CounterTask::new(counter.clone()));
            let scheduler_clone = scheduler.clone();
            let handle =
                tokio::spawn(async move { scheduler_clone.submit_task(task).await.unwrap() });
            handles.push(handle);

            // Stagger some submissions to exercise continuous scheduling.
            if i % 3 == 0 {
                tokio::time::sleep(Duration::from_millis(5)).await;
            }
        }

        for handle in handles {
            let result = handle.await.unwrap();
            assert!(matches!(result, TaskResult::Success(_)));
        }

        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    #[tokio::test]
    async fn test_batch_task_execution() {
        let scheduler = TaskScheduler::new_with_defaults();
        let counter = Arc::new(AtomicUsize::new(0));

        let mut tasks: Vec<Box<dyn Task + Send + Sync>> = Vec::new();
        for _ in 0..3 {
            tasks.push(Box::new(CounterTask::new(counter.clone())));
        }

        let results = scheduler.submit_batch_tasks(tasks).await;
        assert_eq!(results.len(), 3);
        assert_eq!(counter.load(Ordering::SeqCst), 3);
        for result in results {
            assert!(matches!(result, TaskResult::Success(_)));
        }
    }

    #[tokio::test]
    async fn test_high_concurrency_stress() {
        let scheduler = TaskScheduler::new_with_defaults();
        let counter = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::new();
        for i in 0..50 {
            let task = Box::new(CounterTask::new(counter.clone()));
            let scheduler_clone = scheduler.clone();
            let priority = match i % 4 {
                0 => TaskPriority::Low,
                1 => TaskPriority::Normal,
                2 => TaskPriority::High,
                3 => TaskPriority::Critical,
                _ => TaskPriority::Normal,
            };

            let handle = tokio::spawn(async move {
                scheduler_clone
                    .submit_task_with_priority(task, priority)
                    .await
                    .unwrap()
            });
            handles.push(handle);

            if i % 5 == 0 {
                tokio::time::sleep(Duration::from_millis(1)).await;
            }
        }

        for handle in handles {
            let result = handle.await.unwrap();
            assert!(matches!(result, TaskResult::Success(_)));
        }

        assert_eq!(counter.load(Ordering::SeqCst), 50);

        // All work done: queue drained and every permit returned.
        assert_eq!(scheduler.get_queue_size(), 0);
        assert_eq!(scheduler.get_active_workers(), 0);
    }

    #[tokio::test]
    async fn test_mixed_batch_and_individual_tasks() {
        let scheduler = TaskScheduler::new_with_defaults();
        let counter = Arc::new(AtomicUsize::new(0));

        let mut individual_handles = Vec::new();
        for _ in 0..3 {
            let task = Box::new(CounterTask::new(counter.clone()));
            let scheduler_clone = scheduler.clone();
            let handle =
                tokio::spawn(async move { scheduler_clone.submit_task(task).await.unwrap() });
            individual_handles.push(handle);
        }

        let mut batch_tasks: Vec<Box<dyn Task + Send + Sync>> = Vec::new();
        for _ in 0..4 {
            batch_tasks.push(Box::new(CounterTask::new(counter.clone())));
        }

        let batch_handle = {
            let scheduler_clone = scheduler.clone();
            tokio::spawn(async move { scheduler_clone.submit_batch_tasks(batch_tasks).await })
        };

        let mut more_individual_handles = Vec::new();
        for _ in 0..2 {
            let task = Box::new(CounterTask::new(counter.clone()));
            let scheduler_clone = scheduler.clone();
            let handle =
                tokio::spawn(async move { scheduler_clone.submit_task(task).await.unwrap() });
            more_individual_handles.push(handle);
        }

        for handle in individual_handles {
            let result = handle.await.unwrap();
            assert!(matches!(result, TaskResult::Success(_)));
        }

        let batch_results = batch_handle.await.unwrap();
        assert_eq!(batch_results.len(), 4);
        for result in batch_results {
            assert!(matches!(result, TaskResult::Success(_)));
        }

        for handle in more_individual_handles {
            let result = handle.await.unwrap();
            assert!(matches!(result, TaskResult::Success(_)));
        }

        // 3 individual + 4 batch + 2 individual tasks in total.
        assert_eq!(counter.load(Ordering::SeqCst), 9);
    }
}