1use crate::common::IntegrateFloat;
30use crate::error::IntegrateResult;
31use std::collections::VecDeque;
32use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
33use std::sync::{Arc, Condvar, Mutex};
34use std::thread::{self, JoinHandle};
35use std::time::{Duration, Instant};
36
/// A unit of work that can be queued on a `WorkStealingPool`.
///
/// Implementors must be `Send + 'static` so they can be moved onto worker threads.
pub trait WorkStealingTask: Send + 'static {
    /// Value produced when the task runs.
    type Output: Send;

    /// Runs the task and returns its output.
    fn execute(&mut self) -> Self::Output;

    /// Relative cost estimate; queues sum these to pick steal victims.
    /// Every task costs `1.0` unless this is overridden.
    fn estimated_cost(&self) -> f64 {
        1.0
    }

    /// Whether `subdivide` would yield useful subtasks; `false` unless overridden.
    fn can_subdivide(&self) -> bool {
        false
    }

    /// Splits the task into smaller boxed subtasks; the default yields none.
    fn subdivide(&self) -> Vec<Box<dyn WorkStealingTask<Output = Self::Output>>>
    where
        Self: Sized,
    {
        Vec::new()
    }
}
62
/// A `WorkStealingTask` wrapping a one-shot closure plus a cost estimate.
pub struct Task<F: FnOnce() -> R + Send + 'static, R: Send + 'static> {
    /// The closure to run; `None` once it has been taken for execution.
    func: Option<F>,
    /// Value reported as the task's estimated cost.
    cost_estimate: f64,
}
72
73impl<F, R> Task<F, R>
74where
75 F: FnOnce() -> R + Send + 'static,
76 R: Send + 'static,
77{
78 pub fn new(func: F) -> Self {
80 Self {
81 func: Some(func),
82 cost_estimate: 1.0,
83 }
84 }
85
86 pub fn with_cost(func: F, cost: f64) -> Self {
88 Self {
89 func: Some(func),
90 cost_estimate: cost,
91 }
92 }
93}
94
95impl<F, R> WorkStealingTask for Task<F, R>
96where
97 F: FnOnce() -> R + Send + 'static,
98 R: Send + 'static,
99{
100 type Output = R;
101
102 fn execute(&mut self) -> Self::Output {
103 (self.func.take().unwrap())()
104 }
105
106 fn estimated_cost(&self) -> f64 {
107 self.cost_estimate
108 }
109}
110
/// Double-ended task queue that also keeps the summed `estimated_cost` of its contents.
///
/// The owning worker takes from the back (`pop_back`); other workers steal from the
/// front (`steal_front`).
#[derive(Debug)]
struct WorkStealingDeque<T> {
    /// Queued tasks; `push_back` appends, so the oldest task sits at the front.
    items: VecDeque<T>,
    /// Running sum of `estimated_cost()` over the queued tasks.
    total_cost: f64,
}
117
118impl<T: WorkStealingTask> WorkStealingDeque<T> {
119 fn new() -> Self {
120 Self {
121 items: VecDeque::new(),
122 total_cost: 0.0,
123 }
124 }
125
126 fn push_back(&mut self, task: T) {
127 self.total_cost += task.estimated_cost();
128 self.items.push_back(task);
129 }
130
131 fn pop_back(&mut self) -> Option<T> {
132 if let Some(task) = self.items.pop_back() {
133 self.total_cost -= task.estimated_cost();
134 Some(task)
135 } else {
136 None
137 }
138 }
139
140 fn steal_front(&mut self) -> Option<T> {
141 if let Some(task) = self.items.pop_front() {
142 self.total_cost -= task.estimated_cost();
143 Some(task)
144 } else {
145 None
146 }
147 }
148
149 #[allow(dead_code)]
150 fn len(&self) -> usize {
151 self.items.len()
152 }
153
154 fn is_empty(&self) -> bool {
155 self.items.is_empty()
156 }
157
158 fn total_cost(&self) -> f64 {
159 self.total_cost
160 }
161}
162
163struct WorkerState<T: WorkStealingTask> {
165 local_queue: Mutex<WorkStealingDeque<T>>,
167 completed_tasks: AtomicUsize,
169 computation_time: Mutex<Duration>,
171}
172
173impl<T: WorkStealingTask> WorkerState<T> {
174 fn new() -> Self {
175 Self {
176 local_queue: Mutex::new(WorkStealingDeque::new()),
177 completed_tasks: AtomicUsize::new(0),
178 computation_time: Mutex::new(Duration::ZERO),
179 }
180 }
181}
182
183pub struct WorkStealingPool<T: WorkStealingTask> {
185 workers: Vec<JoinHandle<()>>,
187 worker_states: Arc<Vec<WorkerState<T>>>,
189 global_queue: Arc<Mutex<WorkStealingDeque<T>>>,
191 active_tasks: Arc<AtomicUsize>,
193 shutdown: Arc<AtomicBool>,
195 cv: Arc<Condvar>,
197 #[allow(dead_code)]
199 cv_mutex: Arc<Mutex<()>>,
200 stats: Arc<Mutex<PoolStatistics>>,
202}
203
/// Snapshot of pool activity, as returned by `WorkStealingPool::statistics`.
#[derive(Debug, Clone, Default)]
pub struct PoolStatistics {
    /// Tasks completed across all workers.
    pub total_tasks: usize,
    /// Sum of the time every worker spent executing tasks.
    pub total_computation_time: Duration,
    /// Times an idle worker, finding no local or global work, tried to steal.
    pub steal_attempts: usize,
    /// Steal attempts that actually obtained a task.
    pub successful_steals: usize,
    /// `1 - variance` of each worker's share of completed tasks, clamped at `0.0`.
    /// `1.0` means an even split; it stays `0.0` until some task has completed.
    pub load_balance_efficiency: f64,
}
218
219impl<T: WorkStealingTask + 'static> WorkStealingPool<T> {
220 pub fn new(_numthreads: usize) -> Self {
222 let _num_threads = _numthreads.max(1);
223
224 let worker_states = Arc::new(
225 (0.._num_threads)
226 .map(|_| WorkerState::new())
227 .collect::<Vec<_>>(),
228 );
229
230 let global_queue = Arc::new(Mutex::new(WorkStealingDeque::new()));
231 let active_tasks = Arc::new(AtomicUsize::new(0));
232 let shutdown = Arc::new(AtomicBool::new(false));
233 let cv = Arc::new(Condvar::new());
234 let cv_mutex = Arc::new(Mutex::new(()));
235 let stats = Arc::new(Mutex::new(PoolStatistics::default()));
236
237 let workers = (0.._num_threads)
238 .map(|worker_id| {
239 let worker_states = Arc::clone(&worker_states);
240 let global_queue = Arc::clone(&global_queue);
241 let active_tasks = Arc::clone(&active_tasks);
242 let shutdown = Arc::clone(&shutdown);
243 let cv = Arc::clone(&cv);
244 let cv_mutex = Arc::clone(&cv_mutex);
245 let stats = Arc::clone(&stats);
246
247 thread::spawn(move || {
248 Self::worker_thread(
249 worker_id,
250 worker_states,
251 global_queue,
252 active_tasks,
253 shutdown,
254 cv,
255 cv_mutex,
256 stats,
257 );
258 })
259 })
260 .collect();
261
262 Self {
263 workers,
264 worker_states,
265 global_queue,
266 active_tasks,
267 shutdown,
268 cv,
269 cv_mutex,
270 stats,
271 }
272 }
273
274 pub fn submit(&self, task: T) {
276 let mut global_queue = self.global_queue.lock().unwrap();
277 global_queue.push_back(task);
278 drop(global_queue);
279
280 self.cv.notify_one();
282 }
283
284 pub fn submit_all(&self, tasks: Vec<T>) {
286 let mut global_queue = self.global_queue.lock().unwrap();
287 for task in tasks {
288 global_queue.push_back(task);
289 }
290 drop(global_queue);
291
292 self.cv.notify_all();
294 }
295
296 pub fn execute_and_wait(&self) -> IntegrateResult<()> {
298 loop {
300 let global_empty = self.global_queue.lock().unwrap().is_empty();
302 let locals_empty = self
303 .worker_states
304 .iter()
305 .all(|state| state.local_queue.lock().unwrap().is_empty());
306 let no_active_tasks = self.active_tasks.load(Ordering::Relaxed) == 0;
307
308 if global_empty && locals_empty && no_active_tasks {
309 break;
310 }
311
312 thread::sleep(Duration::from_micros(100));
314 }
315
316 Ok(())
317 }
318
319 pub fn statistics(&self) -> PoolStatistics {
321 let mut stats = self.stats.lock().unwrap();
322
323 stats.total_tasks = self
325 .worker_states
326 .iter()
327 .map(|state| state.completed_tasks.load(Ordering::Relaxed))
328 .sum();
329
330 stats.total_computation_time = self
331 .worker_states
332 .iter()
333 .map(|state| *state.computation_time.lock().unwrap())
334 .sum();
335
336 if stats.total_tasks > 0 {
338 let worker_loads: Vec<f64> = self
339 .worker_states
340 .iter()
341 .map(|state| {
342 let completed = state.completed_tasks.load(Ordering::Relaxed);
343 completed as f64 / stats.total_tasks as f64
344 })
345 .collect();
346
347 let ideal_load = 1.0 / self.worker_states.len() as f64;
348 let load_variance: f64 = worker_loads
349 .iter()
350 .map(|&load| (load - ideal_load).powi(2))
351 .sum::<f64>()
352 / self.worker_states.len() as f64;
353
354 stats.load_balance_efficiency = (1.0 - load_variance).max(0.0);
355 }
356
357 stats.clone()
358 }
359
360 fn worker_thread(
362 worker_id: usize,
363 worker_states: Arc<Vec<WorkerState<T>>>,
364 global_queue: Arc<Mutex<WorkStealingDeque<T>>>,
365 active_tasks: Arc<AtomicUsize>,
366 shutdown: Arc<AtomicBool>,
367 cv: Arc<Condvar>,
368 cv_mutex: Arc<Mutex<()>>,
369 stats: Arc<Mutex<PoolStatistics>>,
370 ) {
371 let my_state = &worker_states[worker_id];
372
373 while !shutdown.load(Ordering::Relaxed) {
374 let mut task_opt = my_state.local_queue.lock().unwrap().pop_back();
376
377 if task_opt.is_none() {
379 task_opt = global_queue.lock().unwrap().pop_back();
380 }
381
382 if task_opt.is_none() {
384 task_opt = Self::try_steal_work(worker_id, &worker_states, &stats);
385 }
386
387 if let Some(mut task) = task_opt {
388 active_tasks.fetch_add(1, Ordering::Relaxed);
390
391 let start_time = Instant::now();
393 let _result = task.execute();
394 let computation_time = start_time.elapsed();
395
396 active_tasks.fetch_sub(1, Ordering::Relaxed);
398
399 my_state.completed_tasks.fetch_add(1, Ordering::Relaxed);
401 *my_state.computation_time.lock().unwrap() += computation_time;
402 } else {
403 let _guard = cv
405 .wait_timeout(cv_mutex.lock().unwrap(), Duration::from_millis(10))
406 .unwrap();
407 }
408 }
409 }
410
411 fn try_steal_work(
413 worker_id: usize,
414 worker_states: &[WorkerState<T>],
415 stats: &Arc<Mutex<PoolStatistics>>,
416 ) -> Option<T> {
417 stats.lock().unwrap().steal_attempts += 1;
419
420 let mut best_victim = None;
422 let mut best_cost = 0.0;
423
424 for (victim_id, victim_state) in worker_states.iter().enumerate() {
425 if victim_id == worker_id {
426 continue; }
428
429 let queue = victim_state.local_queue.lock().unwrap();
430 let cost = queue.total_cost();
431
432 if cost > best_cost && !queue.is_empty() {
433 best_cost = cost;
434 best_victim = Some(victim_id);
435 }
436 }
437
438 if let Some(victim_id) = best_victim {
440 let victim_state = &worker_states[victim_id];
441 let mut victim_queue = victim_state.local_queue.lock().unwrap();
442
443 if let Some(stolen_task) = victim_queue.steal_front() {
444 stats.lock().unwrap().successful_steals += 1;
446 return Some(stolen_task);
447 }
448 }
449
450 None
451 }
452}
453
454impl<T: WorkStealingTask> Drop for WorkStealingPool<T> {
455 fn drop(&mut self) {
456 self.shutdown.store(true, Ordering::Relaxed);
458 self.cv.notify_all();
459
460 while let Some(worker) = self.workers.pop() {
462 let _ = worker.join();
463 }
464 }
465}
466
467pub struct AdaptiveIntegrationTask<F: IntegrateFloat, Func> {
469 integrand: Func,
471 interval: (F, F),
473 tolerance: F,
475 depth: usize,
477 max_depth: usize,
479}
480
481impl<F: IntegrateFloat, Func> AdaptiveIntegrationTask<F, Func>
482where
483 Func: Fn(F) -> F + Send + Clone + 'static,
484{
485 pub fn new(integrand: Func, interval: (F, F), tolerance: F, max_depth: usize) -> Self {
487 Self {
488 integrand,
489 interval,
490 tolerance,
491 depth: 0,
492 max_depth,
493 }
494 }
495
496 fn integrate_region(&self) -> F {
498 let (a, b) = self.interval;
499 let h = b - a;
500 let fa = (self.integrand)(a);
501 let fb = (self.integrand)(b);
502 h * (fa + fb) / F::from(2.0).unwrap()
503 }
504
505 fn estimate_error(&self) -> F {
507 let (a, b) = self.interval;
508 let mid = (a + b) / F::from(2.0).unwrap();
509
510 let coarse = self.integrate_region();
512
513 let left_task = AdaptiveIntegrationTask {
515 integrand: self.integrand.clone(),
516 interval: (a, mid),
517 tolerance: self.tolerance,
518 depth: self.depth + 1,
519 max_depth: self.max_depth,
520 };
521
522 let right_task = AdaptiveIntegrationTask {
523 integrand: self.integrand.clone(),
524 interval: (mid, b),
525 tolerance: self.tolerance,
526 depth: self.depth + 1,
527 max_depth: self.max_depth,
528 };
529
530 let fine = left_task.integrate_region() + right_task.integrate_region();
531
532 (fine - coarse).abs()
533 }
534}
535
536impl<F: IntegrateFloat + Send, Func> WorkStealingTask for AdaptiveIntegrationTask<F, Func>
537where
538 Func: Fn(F) -> F + Send + Clone + 'static,
539{
540 type Output = IntegrateResult<F>;
541
542 fn execute(&mut self) -> Self::Output {
543 let result = self.integrate_region();
544 Ok(result)
545 }
546
547 fn estimated_cost(&self) -> f64 {
548 let (a, b) = self.interval;
549 (b - a).to_f64().unwrap_or(1.0)
550 }
551
552 fn can_subdivide(&self) -> bool {
553 self.depth < self.max_depth && self.estimate_error() > self.tolerance
554 }
555
556 fn subdivide(&self) -> Vec<Box<dyn WorkStealingTask<Output = Self::Output>>> {
557 let (a, b) = self.interval;
558 let mid = (a + b) / F::from(2.0).unwrap();
559
560 let left_task = AdaptiveIntegrationTask {
561 integrand: self.integrand.clone(),
562 interval: (a, mid),
563 tolerance: self.tolerance / F::from(2.0).unwrap(),
564 depth: self.depth + 1,
565 max_depth: self.max_depth,
566 };
567
568 let right_task = AdaptiveIntegrationTask {
569 integrand: self.integrand.clone(),
570 interval: (mid, b),
571 tolerance: self.tolerance / F::from(2.0).unwrap(),
572 depth: self.depth + 1,
573 max_depth: self.max_depth,
574 };
575
576 vec![
577 Box::new(left_task) as Box<dyn WorkStealingTask<Output = Self::Output>>,
578 Box::new(right_task) as Box<dyn WorkStealingTask<Output = Self::Output>>,
579 ]
580 }
581}
582
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicI32;

    #[test]
    fn test_work_stealing_pool_basic() {
        let pool: WorkStealingPool<Task<_, i32>> = WorkStealingPool::new(2);

        (0..10).for_each(|i| pool.submit(Task::new(move || i * 2)));

        assert!(pool.execute_and_wait().is_ok());

        let stats = pool.statistics();
        assert_eq!(stats.total_tasks, 10);
        assert!(stats.load_balance_efficiency >= 0.0);
    }

    #[test]
    fn test_task_subdivision() {
        let task = AdaptiveIntegrationTask::new(|x: f64| x * x, (0.0, 1.0), 1e-6, 5);

        assert!(task.can_subdivide());
        assert_eq!(task.subdivide().len(), 2);
    }

    #[test]
    fn test_load_balancing() {
        let pool: WorkStealingPool<Task<_, ()>> = WorkStealingPool::new(4);
        let counter = Arc::new(AtomicI32::new(0));

        for i in 0..20u64 {
            // Durations cycle through 0, 10, 20, 30 and 40 ms; the cost mirrors the duration.
            let sleep_ms = (i % 5) * 10;
            let counter = Arc::clone(&counter);
            pool.submit(Task::with_cost(
                move || {
                    thread::sleep(Duration::from_millis(sleep_ms));
                    counter.fetch_add(1, Ordering::Relaxed);
                },
                sleep_ms as f64,
            ));
        }

        pool.execute_and_wait().unwrap();
        assert_eq!(counter.load(Ordering::Relaxed), 20);

        let stats = pool.statistics();
        assert_eq!(stats.total_tasks, 20);
        assert!(stats.steal_attempts > 0);
    }
}