1use crate::error::{CoreError, CoreResult, ErrorContext, ErrorLocation};
9use crossbeam_deque::{Injector, Steal, Stealer, Worker};
10use crossbeam_utils::sync::{Parker, Unparker};
11use num_cpus;
12use std::cell::UnsafeCell;
13use std::collections::HashMap;
14use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
15use std::sync::{Arc, Condvar, Mutex, RwLock};
16use std::thread::{self, JoinHandle};
17use std::time::{Duration, Instant};
18
/// Shared completion signal for a single task.
///
/// The `bool` is set to `true` once the task has been executed or cancelled, and
/// the `Condvar` is notified so that threads blocked on the flag can re-check it.
type TaskCompletionNotify = Arc<(Mutex<bool>, Condvar)>;

/// Mutex-protected map from task id to that task's completion signal.
type TaskCompletionMap = Arc<Mutex<HashMap<usize, TaskCompletionNotify>>>;
24
/// Priority level reported by a task.
///
/// The derived ordering follows the declaration order, so
/// `Background < Normal < High < Critical`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default)]
pub enum TaskPriority {
    /// Lowest priority.
    Background = 0,
    /// Default priority.
    #[default]
    Normal = 1,
    /// Elevated priority.
    High = 2,
    /// Highest priority.
    Critical = 3,
}
38
/// Strategy used when a submitted task is placed on a queue.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SchedulingPolicy {
    /// Tasks go to the shared global queue.
    Fifo,
    /// Handled exactly like `Fifo` by the submission code in this file.
    Lifo,
    /// Tasks with priority `High` or above go straight to a worker's local queue;
    /// everything else goes to the shared global queue. This is the default.
    #[default]
    Priority,
    /// Tasks with a weight above 1 go to the worker with the shortest local queue;
    /// everything else goes to the shared global queue.
    WeightedFair,
}
52
/// Tunable settings for a `WorkStealingScheduler`.
#[derive(Debug, Clone)]
pub struct SchedulerConfig {
    /// Number of worker threads to spawn.
    pub numworkers: usize,
    /// Queue-selection policy applied at submission time.
    pub policy: SchedulingPolicy,
    /// Maximum queue size. NOTE(review): not consulted by the code visible in this file.
    pub max_queue_size: usize,
    /// Whether each worker adapts its batch size from its steal success rate.
    pub adaptive: bool,
    /// NOTE(review): not consulted by the code visible in this file.
    pub enable_stealing_heuristics: bool,
    /// NOTE(review): not consulted by the code visible in this file.
    pub enable_priorities: bool,
    /// NOTE(review): not consulted by the code visible in this file.
    pub stealing_threshold: usize,
    /// How long an idle worker parks, in milliseconds; `0` makes it yield instead.
    pub sleep_ms: u64,
    /// Lower bound for the adaptive batch size (also the fixed value when
    /// `adaptive` is off).
    pub min_batch_size: usize,
    /// Upper bound for the adaptive batch size.
    pub max_batch_size: usize,
    /// Task timeout in milliseconds. NOTE(review): not consulted by the code
    /// visible in this file.
    pub task_timeout_ms: u64,
    /// Maximum number of times a failing task is re-queued.
    pub maxretries: usize,
}
81
82impl Default for SchedulerConfig {
83 fn default() -> Self {
84 Self {
85 numworkers: num_cpus::get(),
86 policy: SchedulingPolicy::Priority,
87 max_queue_size: 10000,
88 adaptive: true,
89 enable_stealing_heuristics: true,
90 enable_priorities: true,
91 stealing_threshold: 4,
92 sleep_ms: 1,
93 min_batch_size: 1,
94 max_batch_size: 100,
95 task_timeout_ms: 0,
96 maxretries: 3,
97 }
98 }
99}
100
101#[derive(Debug, Clone, Default)]
103pub struct SchedulerConfigBuilder {
104 config: SchedulerConfig,
105}
106
107impl SchedulerConfigBuilder {
108 pub fn new() -> Self {
110 Self::default()
111 }
112
113 pub const fn workers(mut self, numworkers: usize) -> Self {
115 self.config.numworkers = numworkers;
116 self
117 }
118
119 pub const fn policy(mut self, policy: SchedulingPolicy) -> Self {
121 self.config.policy = policy;
122 self
123 }
124
125 pub const fn max_queue_size(mut self, size: usize) -> Self {
127 self.config.max_queue_size = size;
128 self
129 }
130
131 pub const fn adaptive(mut self, enable: bool) -> Self {
133 self.config.adaptive = enable;
134 self
135 }
136
137 pub const fn enable_stealing_heuristics(mut self, enable: bool) -> Self {
139 self.config.enable_stealing_heuristics = enable;
140 self
141 }
142
143 pub const fn enable_priorities(mut self, enable: bool) -> Self {
145 self.config.enable_priorities = enable;
146 self
147 }
148
149 pub const fn stealing_threshold(mut self, threshold: usize) -> Self {
151 self.config.stealing_threshold = threshold;
152 self
153 }
154
155 pub const fn sleep_ms(mut self, ms: u64) -> Self {
157 self.config.sleep_ms = ms;
158 self
159 }
160
161 pub const fn min_batch_size(mut self, size: usize) -> Self {
163 self.config.min_batch_size = size;
164 self
165 }
166
167 pub const fn max_batch_size(mut self, size: usize) -> Self {
169 self.config.max_batch_size = size;
170 self
171 }
172
173 pub const fn task_timeout_ms(mut self, timeout: u64) -> Self {
175 self.config.task_timeout_ms = timeout;
176 self
177 }
178
179 pub const fn maxretries(mut self, retries: usize) -> Self {
181 self.config.maxretries = retries;
182 self
183 }
184
185 pub fn build(self) -> SchedulerConfig {
187 self.config
188 }
189}
190
/// A unit of work that can be handed to the scheduler.
///
/// Implementors must be `Send + 'static` because tasks are executed on worker threads.
pub trait Task: Send + 'static {
    /// Runs the task.
    ///
    /// # Errors
    ///
    /// Returns the task's own error; the scheduler may re-run a failing task.
    fn execute(&mut self) -> Result<(), CoreError>;

    /// Priority of the task; defaults to `TaskPriority::Normal`.
    fn priority(&self) -> TaskPriority {
        TaskPriority::Normal
    }

    /// Weight of the task; defaults to `1`.
    fn weight(&self) -> usize {
        1
    }

    /// Estimated cost of the task; defaults to `1`.
    fn estimated_cost(&self) -> usize {
        1
    }

    /// Returns a boxed copy of this task.
    fn clone_task(&self) -> Box<dyn Task>;

    /// Human-readable name of the task; defaults to `"unnamed"`.
    fn name(&self) -> &str {
        "unnamed"
    }
}
219
/// Cloneable handle used to observe, wait for, or cancel a submitted task.
#[derive(Clone)]
pub struct TaskHandle {
    /// Identifier assigned to the task at submission.
    id: usize,
    /// Status cell shared with the queued task.
    status: Arc<Mutex<TaskStatus>>,
    /// Completion signal shared with the queued task.
    result_notify: TaskCompletionNotify,
}
230
/// Lifecycle state of a task.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskStatus {
    /// Queued and not yet started (also the state a task returns to before a retry).
    Pending,
    /// Currently executing on a worker.
    Running,
    /// Finished successfully.
    Completed,
    /// Execution returned an error; the payload is the retry count at the time of
    /// the failure.
    Failed(usize),
    /// Cancelled through `TaskHandle::cancel` while still pending.
    Cancelled,
    /// Timed out. NOTE(review): never assigned by the code visible in this file.
    TimedOut,
}
247
248impl TaskHandle {
249 #[allow(dead_code)]
250 fn new(id: usize) -> Self {
252 Self {
253 id,
254 status: Arc::new(Mutex::new(TaskStatus::Pending)),
255 result_notify: Arc::new((Mutex::new(false), Condvar::new())),
256 }
257 }
258
259 pub fn id(&self) -> usize {
261 self.id
262 }
263
264 pub fn status(&self) -> TaskStatus {
266 *self.status.lock().expect("Operation failed")
267 }
268
269 pub fn wait(&self) -> TaskStatus {
271 let (lock, cvar) = &*self.result_notify;
272 let completed = lock.lock().expect("Operation failed");
273
274 if !*completed {
276 let _completed = cvar.wait(completed).expect("Operation failed");
277 }
278
279 self.status()
280 }
281
282 pub fn wait_timeout(&self, timeout: Duration) -> Result<TaskStatus, CoreError> {
284 let (lock, cvar) = &*self.result_notify;
285 let completed = lock.lock().expect("Operation failed");
286
287 if !*completed {
289 let result = cvar
290 .wait_timeout(completed, timeout)
291 .expect("Operation failed");
292
293 if result.1.timed_out() {
294 return Err(CoreError::TimeoutError(
295 ErrorContext::new(format!("{}", self.id))
296 .with_location(ErrorLocation::new(file!(), line!())),
297 ));
298 }
299 }
300
301 Ok(self.status())
302 }
303
304 pub fn cancel(&self) -> bool {
306 let mut status = self.status.lock().expect("Operation failed");
307
308 if *status == TaskStatus::Pending {
309 *status = TaskStatus::Cancelled;
310
311 let (lock, cvar) = &*self.result_notify;
313 let mut completed = lock.lock().expect("Operation failed");
314 *completed = true;
315 cvar.notify_all();
316
317 true
318 } else {
319 false
320 }
321 }
322}
323
/// Internal envelope that pairs a boxed task with its scheduling metadata.
struct TaskWrapper {
    /// Identifier assigned at submission.
    id: usize,
    /// The user-supplied task.
    task: Box<dyn Task>,
    /// Priority captured from the task when the wrapper was created.
    priority: TaskPriority,
    /// Weight captured from the task when the wrapper was created.
    weight: usize,
    /// Estimated cost captured from the task when the wrapper was created.
    #[allow(dead_code)]
    cost: usize,
    /// Status cell shared with every `TaskHandle` for this task.
    status: Arc<Mutex<TaskStatus>>,
    /// Completion signal shared with every `TaskHandle` for this task.
    result_notify: TaskCompletionNotify,
    /// Instant at which the wrapper was created.
    #[allow(dead_code)]
    submission_time: Instant,
    /// Number of retries performed so far.
    retry_count: usize,
    /// Name captured from the task when the wrapper was created.
    #[allow(dead_code)]
    name: String,
}
350
351impl TaskWrapper {
352 fn new(id: usize, task: Box<dyn Task>) -> Self {
354 let priority = task.priority();
355 let weight = task.weight();
356 let cost = task.estimated_cost();
357 let name = task.name().to_string();
358
359 Self {
360 id,
361 task,
362 priority,
363 weight,
364 cost,
365 status: Arc::new(Mutex::new(TaskStatus::Pending)),
366 result_notify: Arc::new((Mutex::new(false), Condvar::new())),
367 submission_time: Instant::now(),
368 retry_count: 0,
369 name,
370 }
371 }
372
373 fn create_handle(&self) -> TaskHandle {
375 TaskHandle {
376 id: self.id,
377 status: self.status.clone(),
378 result_notify: self.result_notify.clone(),
379 }
380 }
381
382 fn execute(&mut self) -> Result<(), CoreError> {
384 {
386 let mut status = self.status.lock().expect("Operation failed");
387 *status = TaskStatus::Running;
388 }
389
390 let result = self.task.execute();
392
393 {
395 let mut status = self.status.lock().expect("Operation failed");
396 *status = match result {
397 Ok(_) => TaskStatus::Completed,
398 Err(_) => TaskStatus::Failed(self.retry_count),
399 };
400 }
401
402 let (lock, cvar) = &*self.result_notify;
404 let mut completed = lock.lock().expect("Operation failed");
405 *completed = true;
406 cvar.notify_all();
407
408 result
409 }
410
411 fn increment_retry(&mut self) {
413 self.retry_count += 1;
414 }
415}
416
thread_local! {
    /// Index of the scheduler worker running on the current thread, if any.
    ///
    /// A `Cell` is sufficient because the value is `Copy` and only ever touched
    /// from its own thread, so no `unsafe` is needed.
    static WORKER_ID: std::cell::Cell<Option<usize>> = const { std::cell::Cell::new(None) };
}

/// Records `id` as the worker index of the calling thread.
#[allow(dead_code)]
fn set_workerid(id: usize) {
    WORKER_ID.with(|cell| cell.set(Some(id)));
}

/// Returns the worker index recorded for the calling thread, or `None` when no
/// index has been set on this thread.
#[allow(dead_code)]
pub fn get_workerid() -> Option<usize> {
    WORKER_ID.with(|cell| cell.get())
}
435
436struct WorkerState {
438 #[allow(dead_code)]
440 id: usize,
441 #[allow(clippy::type_complexity)]
443 local_queue: UnsafeCell<Worker<TaskWrapper>>,
444 stealers: Vec<Stealer<TaskWrapper>>,
446 injector: Arc<Injector<TaskWrapper>>,
448 #[allow(dead_code)]
450 active: Arc<AtomicBool>,
451 parker: UnsafeCell<Parker>,
453 unparker: Unparker,
455 tasks_processed: AtomicUsize,
457 tasks_stolen: AtomicUsize,
459 failed_steals: AtomicUsize,
461 last_active: Mutex<Instant>,
463 local_queue_size: AtomicUsize,
465 adaptive_batch_size: AtomicUsize,
467}
468
469unsafe impl Send for WorkerState {}
472unsafe impl Sync for WorkerState {}
473
474impl WorkerState {
475 fn new(
477 id: usize,
478 stealers: Vec<Stealer<TaskWrapper>>,
479 injector: Arc<Injector<TaskWrapper>>,
480 ) -> Self {
481 let parker = Parker::new();
482 let unparker = parker.unparker().clone();
483
484 Self {
485 id,
486 local_queue: UnsafeCell::new(Worker::new_fifo()),
487 stealers,
488 injector,
489 active: Arc::new(AtomicBool::new(true)),
490 parker: UnsafeCell::new(parker),
491 unparker,
492 tasks_processed: AtomicUsize::new(0),
493 tasks_stolen: AtomicUsize::new(0),
494 failed_steals: AtomicUsize::new(0),
495 last_active: Mutex::new(Instant::now()),
496 local_queue_size: AtomicUsize::new(0),
497 adaptive_batch_size: AtomicUsize::new(1),
498 }
499 }
500
501 #[allow(dead_code)]
503 fn id(&self) -> usize {
504 self.id
505 }
506
507 fn local_queue_size(&self) -> usize {
509 self.local_queue_size.load(Ordering::Relaxed)
510 }
511
512 fn push_local(&self, task: TaskWrapper) {
514 unsafe {
516 (*self.local_queue.get()).push(task);
517 }
518 self.local_queue_size.fetch_add(1, Ordering::Relaxed);
519 }
520
521 fn pop_local(&self) -> Option<TaskWrapper> {
523 let result = unsafe { (*self.local_queue.get()).pop() };
525
526 if result.is_some() {
527 self.local_queue_size.fetch_sub(1, Ordering::Relaxed);
528 }
529
530 result
531 }
532
533 fn steal(&self) -> Option<TaskWrapper> {
535 match self.injector.steal() {
537 Steal::Success(task) => {
538 self.tasks_stolen.fetch_add(1, Ordering::Relaxed);
539 return Some(task);
540 }
541 Steal::Empty => {}
542 Steal::Retry => {}
543 }
544
545 for stealer in &self.stealers {
547 match stealer.steal() {
548 Steal::Success(task) => {
549 self.tasks_stolen.fetch_add(1, Ordering::Relaxed);
550 return Some(task);
551 }
552 Steal::Empty => {}
553 Steal::Retry => {}
554 }
555 }
556
557 self.failed_steals.fetch_add(1, Ordering::Relaxed);
559 None
560 }
561
562 fn update_last_active(&self) {
564 let mut last_active = self.last_active.lock().expect("Operation failed");
565 *last_active = Instant::now();
566 }
567
568 fn time_since_last_active(&self) -> Duration {
570 let last_active = self.last_active.lock().expect("Operation failed");
571 last_active.elapsed()
572 }
573
574 fn update_adaptive_batch_size(&self, config: &SchedulerConfig) {
576 if !config.adaptive {
577 self.adaptive_batch_size
579 .store(config.min_batch_size, Ordering::Relaxed);
580 return;
581 }
582
583 let _tasks_processed = self.tasks_processed.load(Ordering::Relaxed);
585 let tasks_stolen = self.tasks_stolen.load(Ordering::Relaxed);
586 let failed_steals = self.failed_steals.load(Ordering::Relaxed);
587
588 let steal_attempts = tasks_stolen + failed_steals;
590 let steal_success_rate = if steal_attempts > 0 {
591 tasks_stolen as f64 / steal_attempts as f64
592 } else {
593 0.0
594 };
595
596 let current_batch_size = self.adaptive_batch_size.load(Ordering::Relaxed);
598 let new_batch_size = if steal_success_rate > 0.8 {
599 (current_batch_size * 2).min(config.max_batch_size)
601 } else if steal_success_rate < 0.2 {
602 (current_batch_size / 2).max(config.min_batch_size)
604 } else {
605 current_batch_size
607 };
608
609 self.adaptive_batch_size
611 .store(new_batch_size, Ordering::Relaxed);
612 }
613}
614
/// Snapshot of scheduler counters and derived metrics.
///
/// `Default` yields all-zero counters and an empty utilization vector, which is
/// exactly what the derive produces for these field types.
#[derive(Debug, Clone, Default)]
pub struct SchedulerStats {
    /// Number of tasks submitted so far.
    pub tasks_submitted: usize,
    /// Number of tasks with a recorded execution time.
    pub tasks_completed: usize,
    /// Number of failed tasks.
    pub tasks_failed: usize,
    /// Number of cancelled tasks.
    pub tasks_cancelled: usize,
    /// Number of timed-out tasks.
    pub tasks_timed_out: usize,
    /// Number of task retries.
    pub task_retries: usize,
    /// Number of worker threads.
    pub numworkers: usize,
    /// Average local queue length across workers.
    pub avg_queue_size: f64,
    /// Average task latency in milliseconds.
    pub avg_task_latency_ms: f64,
    /// Average task execution time in milliseconds.
    pub avg_task_execution_ms: f64,
    /// Total successful steals across workers.
    pub successful_steals: usize,
    /// Total failed steal attempts across workers.
    pub failed_steals: usize,
    /// One entry per worker.
    pub worker_utilization: Vec<f64>,
    /// Seconds since the scheduler was created.
    pub uptime_seconds: f64,
    /// Completed tasks divided by uptime.
    pub tasks_per_second: f64,
}
671
/// Thread-pool scheduler that distributes tasks over a global injector queue and
/// per-worker local queues.
pub struct WorkStealingScheduler {
    /// Configuration the scheduler was created with.
    config: SchedulerConfig,
    /// Global queue shared by all workers.
    injector: Arc<Injector<TaskWrapper>>,
    /// Join handles of the spawned worker threads.
    workers: Vec<JoinHandle<()>>,
    /// Shared state of each worker, indexed by worker id.
    worker_states: Vec<Arc<WorkerState>>,
    /// Lifecycle state observed by the worker threads.
    state: Arc<RwLock<SchedulerState>>,
    /// Next task id to hand out; starts at 1.
    next_taskid: Arc<AtomicUsize>,
    /// Completion signal of every submitted task, keyed by task id.
    task_completion: TaskCompletionMap,
    /// Submission instant of every submitted task, keyed by task id.
    task_submissions: Arc<Mutex<HashMap<usize, Instant>>>,
    /// Measured execution time of every executed task, keyed by task id.
    task_executions: Arc<Mutex<HashMap<usize, Duration>>>,
    /// Instant at which the scheduler was created.
    start_time: Instant,
}
695
/// Lifecycle state of the scheduler.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SchedulerState {
    /// Accepting and executing tasks.
    Running,
    /// `shutdown` has started; workers leave their main loop.
    ShuttingDown,
    /// All worker threads have been joined.
    ShutDown,
}
706
707impl WorkStealingScheduler {
708 pub fn new(config: SchedulerConfig) -> Self {
710 let injector = Arc::new(Injector::new());
712 let state = Arc::new(RwLock::new(SchedulerState::Running));
713 let next_taskid = Arc::new(AtomicUsize::new(1));
714 let task_completion = Arc::new(Mutex::new(HashMap::new()));
715 let task_submissions = Arc::new(Mutex::new(HashMap::new()));
716 let task_executions = Arc::new(Mutex::new(HashMap::new()));
717
718 let mut worker_states = Vec::with_capacity(config.numworkers);
720 let mut workers = Vec::with_capacity(config.numworkers);
721
722 let worker_queues: Vec<_> = (0..config.numworkers).map(|_| Worker::new_fifo()).collect();
724
725 let stealers: Vec<_> = worker_queues
726 .iter()
727 .map(|worker| worker.stealer())
728 .collect();
729
730 for i in 0..config.numworkers {
732 let worker_state = Arc::new(WorkerState::new(i, stealers.clone(), injector.clone()));
734
735 worker_states.push(worker_state.clone());
736
737 let state_clone = state.clone();
739 let config_clone = config.clone();
740 let task_completion_clone = task_completion.clone();
741 let task_executions_clone = task_executions.clone();
742
743 let worker_thread = thread::spawn(move || {
744 set_workerid(i);
746
747 Self::worker_loop(
749 worker_state,
750 state_clone,
751 config_clone,
752 task_completion_clone, task_executions_clone,
754 );
755 });
756
757 workers.push(worker_thread);
758 }
759
760 Self {
761 config,
762 injector,
763 workers,
764 worker_states,
765 state,
766 next_taskid,
767 task_completion,
768 task_submissions,
769 task_executions,
770 start_time: Instant::now(),
771 }
772 }
773
774 fn worker_loop(
776 worker_state: Arc<WorkerState>,
777 state: Arc<RwLock<SchedulerState>>,
778 config: SchedulerConfig,
779 _task_completion: TaskCompletionMap,
780 task_executions: Arc<Mutex<HashMap<usize, Duration>>>,
781 ) {
782 while let SchedulerState::Running = *state.read().expect("Operation failed") {
784 let task = worker_state.pop_local().or_else(|| worker_state.steal());
786
787 if let Some(mut task) = task {
788 worker_state.update_last_active();
790
791 let start_time = Instant::now();
793 let result = task.execute();
794 let execution_time = start_time.elapsed();
795
796 let taskid = task.id;
798 task_executions
799 .lock()
800 .expect("Test: operation failed")
801 .insert(taskid, execution_time);
802
803 worker_state.tasks_processed.fetch_add(1, Ordering::Relaxed);
805
806 match result {
808 Ok(_) => {
809 }
811 Err(_) => {
812 if task.retry_count < config.maxretries {
814 task.increment_retry();
816
817 {
819 let mut status = task.status.lock().expect("Operation failed");
820 *status = TaskStatus::Pending;
821 }
822
823 worker_state.push_local(task);
825 }
826 }
827 }
828
829 worker_state.update_adaptive_batch_size(&config);
831 } else {
832 if config.sleep_ms > 0 {
834 unsafe {
836 (*worker_state.parker.get())
837 .park_timeout(Duration::from_millis(config.sleep_ms));
838 }
839 } else {
840 thread::yield_now();
842 }
843 }
844 }
845
846 while let Some(mut task) = worker_state.pop_local() {
848 let start_time = Instant::now();
850 let _ = task.execute();
851 let execution_time = start_time.elapsed();
852
853 let taskid = task.id;
855 task_executions
856 .lock()
857 .expect("Test: operation failed")
858 .insert(taskid, execution_time);
859
860 worker_state.tasks_processed.fetch_add(1, Ordering::Relaxed);
862 }
863 }
864
    /// Boxes `task` and submits it; see `submit_boxed` for the queueing rules.
    pub fn submit<T: Task>(&self, task: T) -> TaskHandle {
        self.submit_boxed(Box::new(task))
    }
869
870 pub fn submit_boxed(&self, task: Box<dyn Task>) -> TaskHandle {
872 if *self.state.read().expect("Operation failed") != SchedulerState::Running {
874 panic!("Cannot submit tasks to a stopped scheduler");
875 }
876
877 let taskid = self.next_taskid.fetch_add(1, Ordering::SeqCst);
879
880 let wrapper = TaskWrapper::new(taskid, task);
882
883 let handle = wrapper.create_handle();
885
886 self.task_completion
888 .lock()
889 .expect("Test: operation failed")
890 .insert(taskid, wrapper.result_notify.clone());
891
892 self.task_submissions
894 .lock()
895 .expect("Test: operation failed")
896 .insert(taskid, Instant::now());
897
898 match self.config.policy {
900 SchedulingPolicy::Fifo | SchedulingPolicy::Lifo => {
901 self.injector.push(wrapper);
903 }
904 SchedulingPolicy::Priority => {
905 if wrapper.priority >= TaskPriority::High {
907 let queue_idx = taskid % self.worker_states.len();
909 self.worker_states[queue_idx].push_local(wrapper);
910 self.worker_states[queue_idx].unparker.unpark();
911 } else {
912 self.injector.push(wrapper);
914 }
915 }
916 SchedulingPolicy::WeightedFair => {
917 if wrapper.weight > 1 {
919 let min_queue_idx = self
921 .worker_states
922 .iter()
923 .enumerate()
924 .min_by_key(|(_, state)| state.local_queue_size())
925 .map(|(idx, _)| idx)
926 .unwrap_or(0);
927
928 self.worker_states[min_queue_idx].push_local(wrapper);
929 self.worker_states[min_queue_idx].unparker.unpark();
930 } else {
931 self.injector.push(wrapper);
933 }
934 }
935 }
936
937 self.wake_idle_workers();
939
940 handle
941 }
942
943 pub fn submit_batch<T: Task + Clone>(&self, tasks: &[T]) -> Vec<TaskHandle> {
945 let mut handles = Vec::with_capacity(tasks.len());
946
947 for task in tasks {
948 handles.push(self.submit(task.clone()));
949 }
950
951 handles
952 }
953
954 pub fn submit_fn<F, R>(&self, f: F) -> TaskHandle
956 where
957 F: FnOnce() -> Result<R, CoreError> + Send + 'static,
958 R: Send + 'static,
959 {
960 struct FnTask<F, R> {
962 f: Option<F>,
963 phantom: std::marker::PhantomData<R>,
964 }
965
966 impl<F, R> Task for FnTask<F, R>
967 where
968 F: FnOnce() -> Result<R, CoreError> + Send + 'static,
969 R: Send + 'static,
970 {
971 fn execute(&mut self) -> Result<(), CoreError> {
972 if let Some(f) = self.f.take() {
973 f()?;
974 Ok(())
975 } else {
976 Err(CoreError::SchedulerError(
977 ErrorContext::new("Task function was already called".to_string())
978 .with_location(ErrorLocation::new(file!(), line!())),
979 ))
980 }
981 }
982
983 fn clone_task(&self) -> Box<dyn Task> {
984 panic!("FnTask cannot be cloned")
985 }
986 }
987
988 self.submit_boxed(Box::new(FnTask {
989 f: Some(f),
990 phantom: std::marker::PhantomData,
991 }))
992 }
993
994 fn wake_idle_workers(&self) {
996 for worker in &self.worker_states {
997 if worker.time_since_last_active() > Duration::from_millis(self.config.sleep_ms) {
998 worker.unparker.unpark();
999 }
1000 }
1001 }
1002
1003 pub fn wait_all(&self) {
1005 let taskids: Vec<_> = self
1007 .task_completion
1008 .lock()
1009 .expect("Test: operation failed")
1010 .keys()
1011 .copied()
1012 .collect();
1013
1014 for id in taskids {
1016 if let Some(notify) = self
1017 .task_completion
1018 .lock()
1019 .expect("Operation failed")
1020 .get(&id)
1021 {
1022 let (lock, cvar) = &**notify;
1023 let completed = lock.lock().expect("Operation failed");
1024
1025 if !*completed {
1026 let _completed = cvar.wait(completed).expect("Operation failed");
1027 }
1028 }
1029 }
1030 }
1031
1032 pub fn wait_all_timeout(&self, timeout: Duration) -> Result<(), CoreError> {
1034 let deadline = Instant::now() + timeout;
1035
1036 let taskids: Vec<_> = self
1038 .task_completion
1039 .lock()
1040 .expect("Test: operation failed")
1041 .keys()
1042 .copied()
1043 .collect();
1044
1045 for id in taskids {
1047 let remaining = deadline.saturating_duration_since(Instant::now());
1048
1049 if remaining.as_secs() == 0 && remaining.subsec_nanos() == 0 {
1050 return Err(CoreError::TimeoutError(
1051 ErrorContext::new("Timeout waiting for tasks".to_string())
1052 .with_location(ErrorLocation::new(file!(), line!())),
1053 ));
1054 }
1055
1056 if let Some(notify) = self
1057 .task_completion
1058 .lock()
1059 .expect("Operation failed")
1060 .get(&id)
1061 {
1062 let (lock, cvar) = &**notify;
1063 let completed = lock.lock().expect("Operation failed");
1064
1065 if !*completed {
1066 let result = cvar
1067 .wait_timeout(completed, remaining)
1068 .expect("Operation failed");
1069
1070 if result.1.timed_out() && !*result.0 {
1071 return Err(CoreError::TimeoutError(
1072 ErrorContext::new("Timeout waiting for tasks".to_string())
1073 .with_location(ErrorLocation::new(file!(), line!())),
1074 ));
1075 }
1076 }
1077 }
1078 }
1079
1080 Ok(())
1081 }
1082
1083 pub fn stats(&self) -> SchedulerStats {
1085 let mut stats = SchedulerStats {
1086 tasks_submitted: self.next_taskid.load(Ordering::Relaxed) - 1,
1087 numworkers: self.config.numworkers,
1088 ..SchedulerStats::default()
1089 };
1090
1091 let mut total_latency = Duration::from_secs(0);
1093 let mut total_execution = Duration::from_secs(0);
1094 let mut completed_tasks = 0;
1095
1096 let submissions = self.task_submissions.lock().expect("Operation failed");
1097 let executions = self.task_executions.lock().expect("Operation failed");
1098
1099 for (id, submission_time) in submissions.iter() {
1100 if let Some(execution_time) = executions.get(id) {
1101 let latency = submission_time.elapsed() - *execution_time;
1103
1104 total_latency += latency;
1105 total_execution += *execution_time;
1106 completed_tasks += 1;
1107 }
1108 }
1109
1110 stats.tasks_completed = completed_tasks;
1112
1113 if completed_tasks > 0 {
1114 stats.avg_task_latency_ms = total_latency.as_millis() as f64 / completed_tasks as f64;
1115 stats.avg_task_execution_ms =
1116 total_execution.as_millis() as f64 / completed_tasks as f64;
1117 }
1118
1119 let mut total_queue_size = 0;
1121 let mut total_successful_steals = 0;
1122 let mut total_failed_steals = 0;
1123 let mut worker_utils = Vec::with_capacity(self.worker_states.len());
1124
1125 for worker in &self.worker_states {
1126 total_queue_size += worker.local_queue_size();
1127 total_successful_steals += worker.tasks_stolen.load(Ordering::Relaxed);
1128 total_failed_steals += worker.failed_steals.load(Ordering::Relaxed);
1129
1130 let tasks_processed = worker.tasks_processed.load(Ordering::Relaxed);
1132 let utilization = if stats.tasks_submitted > 0 {
1133 tasks_processed as f64 / stats.tasks_submitted as f64
1134 } else {
1135 0.0
1136 };
1137
1138 worker_utils.push(utilization);
1139 }
1140
1141 stats.avg_queue_size = total_queue_size as f64 / self.worker_states.len() as f64;
1142 stats.successful_steals = total_successful_steals;
1143 stats.failed_steals = total_failed_steals;
1144 stats.worker_utilization = worker_utils;
1145
1146 stats.uptime_seconds = self.start_time.elapsed().as_secs_f64();
1148
1149 if stats.uptime_seconds > 0.0 {
1150 stats.tasks_per_second = stats.tasks_completed as f64 / stats.uptime_seconds;
1151 }
1152
1153 stats
1154 }
1155
1156 pub fn shutdown(&mut self) {
1158 {
1160 let mut state = self.state.write().expect("Operation failed");
1161 *state = SchedulerState::ShuttingDown;
1162 }
1163
1164 for worker in &self.worker_states {
1166 worker.unparker.unpark();
1167 }
1168
1169 while let Some(worker) = self.workers.pop() {
1171 let _ = worker.join();
1172 }
1173
1174 {
1176 let mut state = self.state.write().expect("Operation failed");
1177 *state = SchedulerState::ShutDown;
1178 }
1179 }
1180
    /// Returns the number of worker states held by the scheduler.
    pub fn numworkers(&self) -> usize {
        self.worker_states.len()
    }
1185
    /// Returns the worker index of the calling thread, or `None` when the caller
    /// is not a thread on which a worker index was set.
    pub fn current_workerid(&self) -> Option<usize> {
        get_workerid()
    }
1190
1191 pub fn pending_tasks(&self) -> usize {
1193 let mut total = 0;
1195
1196 for worker in &self.worker_states {
1197 total += worker.local_queue_size();
1198 }
1199
1200 total
1201 }
1202}
1203
1204impl Drop for WorkStealingScheduler {
1205 fn drop(&mut self) {
1206 if *self.state.read().expect("Operation failed") != SchedulerState::ShutDown {
1207 self.shutdown();
1208 }
1209 }
1210}
1211
/// Task backed by a cloneable, re-callable closure.
///
/// The bound on `F` stays on the struct because `R` appears only in it.
pub struct CloneableTask<F, R>
where
    F: Fn() -> Result<R, CoreError> + Send + Sync + Clone + 'static,
    R: Send + 'static,
{
    /// Closure invoked on every execution.
    func: F,
    /// Name reported through `Task::name`.
    name: String,
    /// Priority reported through `Task::priority`.
    priority: TaskPriority,
    /// Weight reported through `Task::weight`.
    weight: usize,
}
1227
1228impl<F, R> CloneableTask<F, R>
1229where
1230 F: Fn() -> Result<R, CoreError> + Send + Sync + Clone + 'static,
1231 R: Send + 'static,
1232{
1233 pub fn new(func: F) -> Self {
1235 Self {
1236 func,
1237 name: "unnamed".to_string(),
1238 priority: TaskPriority::Normal,
1239 weight: 1,
1240 }
1241 }
1242
1243 pub fn with_name(mut self, name: &str) -> Self {
1245 self.name = name.to_string();
1246 self
1247 }
1248
1249 pub fn with_priority(mut self, priority: TaskPriority) -> Self {
1251 self.priority = priority;
1252 self
1253 }
1254
1255 pub fn with_weight(mut self, weight: usize) -> Self {
1257 self.weight = weight;
1258 self
1259 }
1260}
1261
1262impl<F, R> Task for CloneableTask<F, R>
1263where
1264 F: Fn() -> Result<R, CoreError> + Send + Sync + Clone + 'static,
1265 R: Send + 'static,
1266{
1267 fn execute(&mut self) -> Result<(), CoreError> {
1268 (self.func)().map(|_| ())
1269 }
1270
1271 fn priority(&self) -> TaskPriority {
1272 self.priority
1273 }
1274
1275 fn weight(&self) -> usize {
1276 self.weight
1277 }
1278
1279 fn clone_task(&self) -> Box<dyn Task> {
1280 Box::new(Self {
1281 func: self.func.clone(),
1282 name: self.name.clone(),
1283 priority: self.priority,
1284 weight: self.weight,
1285 })
1286 }
1287
1288 fn name(&self) -> &str {
1289 &self.name
1290 }
1291}
1292
/// Creates a scheduler using `SchedulerConfig::default()`.
#[allow(dead_code)]
pub fn create_work_stealing_scheduler() -> WorkStealingScheduler {
    WorkStealingScheduler::new(SchedulerConfig::default())
}
1298
1299#[allow(dead_code)]
1301pub fn create_work_stealing_scheduler_with_workers(numworkers: usize) -> WorkStealingScheduler {
1302 let config = SchedulerConfigBuilder::new().workers(numworkers).build();
1303
1304 WorkStealingScheduler::new(config)
1305}
1306
/// Applies a function to every item of a collection, one scheduler task per item.
///
/// The bound on `F` stays on the struct because `R` appears only in it.
pub struct ParallelTask<T, F, R>
where
    T: Clone + Send + Sync + 'static,
    F: Fn(&T) -> Result<R, CoreError> + Send + Sync + Clone + 'static,
    R: Send + 'static,
{
    /// Items to process.
    items: Vec<T>,
    /// Function applied to each item.
    func: F,
    /// Prefix for the per-item task names.
    name: String,
    /// Priority given to every per-item task.
    priority: TaskPriority,
    /// Whether a failed item is tolerated instead of aborting the whole run.
    continue_onerror: bool,
}
1325
1326impl<T, F, R> ParallelTask<T, F, R>
1327where
1328 T: Clone + Send + Sync + 'static,
1329 F: Fn(&T) -> Result<R, CoreError> + Send + Sync + Clone + 'static,
1330 R: Send + 'static,
1331{
1332 pub fn new(items: Vec<T>, func: F) -> Self {
1334 Self {
1335 items,
1336 func,
1337 name: "parallel".to_string(),
1338 priority: TaskPriority::Normal,
1339 continue_onerror: false,
1340 }
1341 }
1342
1343 pub fn with_name(mut self, name: &str) -> Self {
1345 self.name = name.to_string();
1346 self
1347 }
1348
1349 pub fn with_priority(mut self, priority: TaskPriority) -> Self {
1351 self.priority = priority;
1352 self
1353 }
1354
1355 pub fn continue_onerror(mut self, continue_onerror: bool) -> Self {
1357 self.continue_onerror = continue_onerror;
1358 self
1359 }
1360
1361 pub fn execute(self) -> Result<Vec<R>, CoreError>
1363 where
1364 R: Clone,
1365 {
1366 let scheduler = create_work_stealing_scheduler();
1368
1369 let items_len = self.items.len();
1371
1372 let mut handles = Vec::with_capacity(items_len);
1374 let results = Arc::new(Mutex::new(Vec::with_capacity(items_len)));
1375
1376 for (i, item) in self.items.into_iter().enumerate() {
1377 let func = self.func.clone();
1378 let results_clone = results.clone();
1379 let task_name = format!("{}_{}", self.name, i);
1380 let priority = self.priority;
1381
1382 let task = CloneableTask::new(move || {
1384 let result = func(&item)?;
1385 results_clone
1386 .lock()
1387 .expect("Operation failed")
1388 .push((i, result));
1389 Ok(())
1390 })
1391 .with_name(&task_name)
1392 .with_priority(priority);
1393
1394 handles.push(scheduler.submit(task));
1395 }
1396
1397 for handle in &handles {
1399 match handle.wait() {
1400 TaskStatus::Completed => {}
1401 TaskStatus::Failed(_) if self.continue_onerror => {}
1402 status => {
1403 return Err(CoreError::SchedulerError(
1404 ErrorContext::new(format!(
1405 "Task {} failed with status {:?}",
1406 handle.id(),
1407 status
1408 ))
1409 .with_location(ErrorLocation::new(file!(), line!())),
1410 ));
1411 }
1412 }
1413 }
1414
1415 let mut result_map = Vec::with_capacity(items_len);
1417 {
1418 let results_guard = results.lock().expect("Operation failed");
1419 for (i, result) in results_guard.iter() {
1420 result_map.push((*i, result.clone()));
1421 }
1422 }
1423
1424 result_map.sort_by_key(|(i, _)| *i);
1425
1426 let results = result_map.into_iter().map(|(_, r)| r).collect();
1427
1428 Ok(results)
1429 }
1430}
1431
1432pub mod parallel {
1434 use super::*;
1435 use crate::error::CoreResult;
1436
1437 pub fn par_map<T, U, F>(items: &[T], f: F) -> CoreResult<Vec<U>>
1439 where
1440 T: Clone + Send + Sync + 'static,
1441 U: Clone + Send + 'static,
1442 F: Fn(&T) -> Result<U, CoreError> + Send + Sync + Clone + 'static,
1443 {
1444 let owned_items = items.to_vec();
1446 let task = ParallelTask::new(owned_items, f);
1447 task.execute()
1448 }
1449
1450 #[allow(dead_code)]
1452 pub fn par_filter<T, F>(items: &[T], predicate: F) -> CoreResult<Vec<T>>
1453 where
1454 T: Clone + Send + Sync + 'static,
1455 F: Fn(&T) -> Result<bool, CoreError> + Send + Sync + Clone + 'static,
1456 {
1457 let task = ParallelTask::new(items.to_vec(), move |item| {
1458 let include = predicate(item)?;
1459 if include {
1460 Ok(Some(item.clone()))
1461 } else {
1462 Ok(None)
1463 }
1464 });
1465
1466 let results = task.execute()?;
1467
1468 let filtered: Vec<_> = results.into_iter().flatten().collect();
1470
1471 Ok(filtered)
1472 }
1473
1474 #[allow(dead_code)]
1476 pub fn par_for_each<T, F>(items: &[T], f: F) -> CoreResult<()>
1477 where
1478 T: Clone + Send + Sync + 'static,
1479 F: Fn(&T) -> Result<(), CoreError> + Send + Sync + Clone + 'static,
1480 {
1481 let task = ParallelTask::new(items.to_vec(), f);
1482 task.execute()?;
1483 Ok(())
1484 }
1485
1486 #[allow(dead_code)]
1488 pub fn par_reduce<T, F>(items: &[T], init: T, f: F) -> CoreResult<T>
1489 where
1490 T: Clone + Send + Sync + 'static,
1491 F: Fn(T, &T) -> Result<T, CoreError> + Send + Sync + Clone + 'static,
1492 {
1493 if items.is_empty() {
1494 return Ok(init);
1495 }
1496
1497 let items_owned: Vec<T> = items.to_vec();
1499
1500 let num_chunks = std::cmp::min(items_owned.len(), num_cpus::get() * 4);
1502 let chunk_size = std::cmp::max(1, items_owned.len() / num_chunks);
1503
1504 let mut chunks = Vec::with_capacity(num_chunks);
1505 for chunk_start in (0..items_owned.len()).step_by(chunk_size) {
1506 let chunk_end = std::cmp::min(chunk_start + chunk_size, items_owned.len());
1507 chunks.push(items_owned[chunk_start..chunk_end].to_vec());
1508 }
1509
1510 let f_clone = f.clone();
1512 let init_clone = init.clone();
1513 let chunk_results = par_map(&chunks, move |chunk| {
1514 let mut result = init_clone.clone();
1515 for item in chunk {
1516 result = f_clone(result, item)?;
1517 }
1518 Ok(result)
1519 })?;
1520
1521 let mut final_result = init;
1523 for result in chunk_results {
1524 final_result = f(final_result, &result)?;
1525 }
1526
1527 Ok(final_result)
1528 }
1529}
1530
/// Extension trait adding a scheduler-backed element-wise map to ndarray arrays.
pub trait WorkStealingArray<A, S, D>
where
    A: Clone + Send + Sync + 'static,
    S: crate::ndarray::RawData<Elem = A>,
    D: crate::ndarray::Dimension,
{
    /// Applies `f` to every element and returns a new owned array of the same
    /// dimension type holding the results.
    ///
    /// # Errors
    ///
    /// Returns an error if `f` fails for an element or the result array cannot be
    /// assembled.
    fn work_stealing_map<F, B>(&self, f: F) -> CoreResult<crate::ndarray::Array<B, D>>
    where
        B: Clone + Send + 'static,
        F: Fn(&A) -> Result<B, CoreError> + Send + Sync + Clone + 'static;
}
1544
1545impl<A, S, D> WorkStealingArray<A, S, D> for crate::ndarray::ArrayBase<S, D>
1546where
1547 A: Clone + Send + Sync + 'static,
1548 S: crate::ndarray::RawData<Elem = A> + crate::ndarray::Data,
1549 D: crate::ndarray::Dimension + Clone + Send + 'static,
1550{
1551 fn work_stealing_map<F, B>(&self, f: F) -> CoreResult<crate::ndarray::Array<B, D>>
1552 where
1553 B: Clone + Send + 'static,
1554 F: Fn(&A) -> Result<B, CoreError> + Send + Sync + Clone + 'static,
1555 {
1556 let shape = self.raw_dim();
1558 let flat_view = self
1559 .view()
1560 .into_shape_with_order(crate::ndarray::IxDyn(&[self.len()]))
1561 .expect("Test: operation failed");
1562 let flat = flat_view.to_slice().expect("Operation failed");
1563
1564 let results = parallel::par_map(flat, f)?;
1566
1567 let result_array = crate::ndarray::Array::from_shape_vec(shape, results).map_err(|e| {
1569 CoreError::DimensionError(
1570 ErrorContext::new(format!("{e}"))
1571 .with_location(ErrorLocation::new(file!(), line!())),
1572 )
1573 })?;
1574
1575 Ok(result_array)
1576 }
1577}
1578
1579impl CoreError {
1581 pub fn schedulererror(message: &str) -> Self {
1583 CoreError::SchedulerError(
1584 ErrorContext::new(message.to_string())
1585 .with_location(ErrorLocation::new(file!(), line!())),
1586 )
1587 }
1588}