use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};

/// Synchronization strategy used between reduction phases.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum SyncMode {
    /// Device-wide synchronization via a cooperative launch.
    Cooperative,
    /// A software barrier between phases within a single launch.
    SoftwareBarrier,
    /// One launch per phase; the launch boundary is the synchronization point.
    #[default]
    MultiLaunch,
}

/// Reduction operator applied across elements.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReductionOp {
    Sum,
    Product,
    Max,
    Min,
    Count,
    All,
    Any,
}

/// Lifecycle state of a single reduction phase.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PhaseState {
    Pending,
    Running,
    Complete,
    Failed,
}

/// Configuration for a multi-phase reduction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReductionConfig {
    /// How phases are synchronized.
    pub sync_mode: SyncMode,
    /// Number of reduction phases.
    pub num_phases: u32,
    /// Threads per block.
    pub block_size: u32,
    /// Blocks per grid.
    pub grid_size: u32,
    /// Whether convergence is checked between phases.
    pub convergence_check: bool,
    /// Threshold below which the reduction counts as converged.
    pub convergence_threshold: f64,
}

impl Default for ReductionConfig {
    fn default() -> Self {
        Self {
            sync_mode: SyncMode::MultiLaunch,
            num_phases: 2,
            block_size: 256,
            grid_size: 1024,
            convergence_check: false,
            convergence_threshold: 1e-6,
        }
    }
}

/// Bookkeeping for a reduction that runs as a sequence of phases, each
/// writing into a progressively smaller intermediate buffer.
pub struct InterPhaseReduction<T> {
    config: ReductionConfig,
    input_size: usize,
    /// One intermediate buffer per phase.
    phase_buffers: Vec<Vec<T>>,
    current_phase: AtomicU32,
    /// Per-phase `PhaseState` values, stored as `u32` for atomic access.
    phase_states: Vec<AtomicU32>,
    is_complete: AtomicBool,
    /// Latest convergence value, stored as `f64` bits.
    convergence_value: AtomicU64,
}

impl<T: Default + Clone + Copy> InterPhaseReduction<T> {
    pub fn new(input_size: usize, sync_mode: SyncMode) -> Self {
        Self::with_config(
            input_size,
            ReductionConfig {
                sync_mode,
                ..Default::default()
            },
        )
    }

    pub fn with_config(input_size: usize, config: ReductionConfig) -> Self {
        let num_phases = config.num_phases as usize;

        // Each phase shrinks the previous buffer by a factor of `block_size`,
        // clamped so every phase holds at least one element.
        let mut phase_buffers = Vec::with_capacity(num_phases);
        let mut size = input_size;
        for _ in 0..num_phases {
            phase_buffers.push(vec![T::default(); size]);
            size = size.div_ceil(config.block_size as usize);
            size = size.max(1);
        }

        let phase_states: Vec<_> = (0..num_phases)
            .map(|_| AtomicU32::new(PhaseState::Pending as u32))
            .collect();

        Self {
            config,
            input_size,
            phase_buffers,
            current_phase: AtomicU32::new(0),
            phase_states,
            is_complete: AtomicBool::new(false),
            convergence_value: AtomicU64::new(0),
        }
    }

    pub fn config(&self) -> &ReductionConfig {
        &self.config
    }

    pub fn input_size(&self) -> usize {
        self.input_size
    }

    pub fn current_phase(&self) -> u32 {
        self.current_phase.load(Ordering::Relaxed)
    }

    /// Marks `phase` as running. Fails if the phase index is out of range
    /// or the phase is not currently `Pending`.
    pub fn phase_start(&self, phase: u32) -> Result<(), ReductionError> {
        if phase >= self.config.num_phases {
            return Err(ReductionError::InvalidPhase {
                phase,
                max_phases: self.config.num_phases,
            });
        }

        let expected = PhaseState::Pending as u32;
        let new = PhaseState::Running as u32;

        match self.phase_states[phase as usize].compare_exchange(
            expected,
            new,
            Ordering::SeqCst,
            Ordering::SeqCst,
        ) {
            Ok(_) => {
                self.current_phase.store(phase, Ordering::Relaxed);
                Ok(())
            }
            Err(current) => Err(ReductionError::InvalidPhaseState {
                phase,
                current: phase_state_from_u32(current),
            }),
        }
    }

    /// Marks `phase` as complete. Completing the final phase marks the
    /// whole reduction as complete.
    pub fn phase_complete(&self, phase: u32) -> Result<(), ReductionError> {
        if phase >= self.config.num_phases {
            return Err(ReductionError::InvalidPhase {
                phase,
                max_phases: self.config.num_phases,
            });
        }

        let expected = PhaseState::Running as u32;
        let new = PhaseState::Complete as u32;

        match self.phase_states[phase as usize].compare_exchange(
            expected,
            new,
            Ordering::SeqCst,
            Ordering::SeqCst,
        ) {
            Ok(_) => {
                if phase == self.config.num_phases - 1 {
                    self.is_complete.store(true, Ordering::Release);
                }
                Ok(())
            }
            Err(current) => Err(ReductionError::InvalidPhaseState {
                phase,
                current: phase_state_from_u32(current),
            }),
        }
    }

    /// Unconditionally marks `phase` as failed; out-of-range phases are ignored.
    pub fn phase_failed(&self, phase: u32) {
        if (phase as usize) < self.phase_states.len() {
            self.phase_states[phase as usize].store(PhaseState::Failed as u32, Ordering::Release);
        }
    }

    /// Returns the state of `phase`, or `Pending` for out-of-range indices.
    pub fn phase_state(&self, phase: u32) -> PhaseState {
        if phase >= self.config.num_phases {
            return PhaseState::Pending;
        }
        phase_state_from_u32(self.phase_states[phase as usize].load(Ordering::Acquire))
    }

    pub fn is_complete(&self) -> bool {
        self.is_complete.load(Ordering::Acquire)
    }

    pub fn get_buffer(&self, phase: u32) -> Option<&[T]> {
        self.phase_buffers.get(phase as usize).map(|v| v.as_slice())
    }

    pub fn get_buffer_mut(&mut self, phase: u32) -> Option<&mut [T]> {
        self.phase_buffers
            .get_mut(phase as usize)
            .map(|v| v.as_mut_slice())
    }

    pub fn buffer_size(&self, phase: u32) -> usize {
        self.phase_buffers
            .get(phase as usize)
            .map(|v| v.len())
            .unwrap_or(0)
    }

    /// Records a convergence value, stored atomically as `f64` bits.
    pub fn set_convergence(&self, value: f64) {
        self.convergence_value
            .store(value.to_bits(), Ordering::Release);
    }

    pub fn convergence(&self) -> f64 {
        f64::from_bits(self.convergence_value.load(Ordering::Acquire))
    }

    /// Always `false` unless `convergence_check` is enabled in the config.
    pub fn is_converged(&self) -> bool {
        if !self.config.convergence_check {
            return false;
        }
        self.convergence() < self.config.convergence_threshold
    }

    /// Resets all phase states, the completion flag, the convergence value,
    /// and the buffer contents.
    pub fn reset(&mut self) {
        self.current_phase.store(0, Ordering::Relaxed);
        self.is_complete.store(false, Ordering::Release);
        self.convergence_value.store(0, Ordering::Release);

        for state in &self.phase_states {
            state.store(PhaseState::Pending as u32, Ordering::Release);
        }

        for buffer in &mut self.phase_buffers {
            for item in buffer.iter_mut() {
                *item = T::default();
            }
        }
    }
}

fn phase_state_from_u32(value: u32) -> PhaseState {
    match value {
        0 => PhaseState::Pending,
        1 => PhaseState::Running,
        2 => PhaseState::Complete,
        _ => PhaseState::Failed,
    }
}

#[derive(Debug, thiserror::Error)]
pub enum ReductionError {
    #[error("Invalid phase {phase}, max phases: {max_phases}")]
    InvalidPhase { phase: u32, max_phases: u32 },

    #[error("Invalid phase state for phase {phase}: {current:?}")]
    InvalidPhaseState { phase: u32, current: PhaseState },

    #[error("Reduction not complete, current phase: {current_phase}")]
    NotComplete { current_phase: u32 },

    #[error("Buffer size mismatch: expected {expected}, got {actual}")]
    BufferSizeMismatch { expected: usize, actual: usize },
}

/// Collects one partial result per participant; the final submission flips
/// `all_complete`, after which the partials can be folded into a result.
pub struct GlobalReduction {
    pub total_participants: u32,
    pub completed: AtomicU32,
    pub all_complete: AtomicBool,
    /// One partial result per participant, stored as `f64` bits.
    pub partial_results: Vec<AtomicU64>,
}

impl GlobalReduction {
    pub fn new(participants: u32) -> Self {
        let partial_results = (0..participants).map(|_| AtomicU64::new(0)).collect();

        Self {
            total_participants: participants,
            completed: AtomicU32::new(0),
            all_complete: AtomicBool::new(false),
            partial_results,
        }
    }

    /// Publishes a partial result. Returns `true` only for the submission
    /// that completes the reduction. Each participant must submit exactly
    /// once per round; repeated submissions inflate the completion count.
    pub fn submit(&self, participant_id: u32, value: f64) -> bool {
        if participant_id >= self.total_participants {
            return false;
        }

        self.partial_results[participant_id as usize].store(value.to_bits(), Ordering::Release);

        let count = self.completed.fetch_add(1, Ordering::AcqRel) + 1;
        if count == self.total_participants {
            self.all_complete.store(true, Ordering::Release);
            return true;
        }

        false
    }

    pub fn is_complete(&self) -> bool {
        self.all_complete.load(Ordering::Acquire)
    }

    pub fn completion_count(&self) -> u32 {
        self.completed.load(Ordering::Acquire)
    }

    /// Sums all partial results; `None` until every participant has submitted.
    pub fn finalize_sum(&self) -> Option<f64> {
        if !self.is_complete() {
            return None;
        }

        let sum: f64 = self
            .partial_results
            .iter()
            .map(|v| f64::from_bits(v.load(Ordering::Acquire)))
            .sum();

        Some(sum)
    }

    pub fn finalize_max(&self) -> Option<f64> {
        if !self.is_complete() {
            return None;
        }

        self.partial_results
            .iter()
            .map(|v| f64::from_bits(v.load(Ordering::Acquire)))
            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
    }

    pub fn finalize_min(&self) -> Option<f64> {
        if !self.is_complete() {
            return None;
        }

        self.partial_results
            .iter()
            .map(|v| f64::from_bits(v.load(Ordering::Acquire)))
            .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
    }

    pub fn reset(&self) {
        self.completed.store(0, Ordering::Release);
        self.all_complete.store(false, Ordering::Release);
        for partial in &self.partial_results {
            partial.store(0, Ordering::Release);
        }
    }
}

/// A reusable spin barrier: the generation counter advances each time all
/// `expected` participants arrive, which lets waiters from one round
/// distinguish themselves from the next.
pub struct CooperativeBarrier {
    expected: u32,
    arrived: AtomicU32,
    generation: AtomicU32,
}

impl CooperativeBarrier {
    pub fn new(expected: u32) -> Self {
        Self {
            expected,
            arrived: AtomicU32::new(0),
            generation: AtomicU32::new(0),
        }
    }

    /// Blocks (spin-waiting) until all participants have arrived, then
    /// returns the generation number observed on entry.
    pub fn wait(&self) -> u32 {
        let generation_num = self.generation.load(Ordering::Acquire);
        let arrived = self.arrived.fetch_add(1, Ordering::AcqRel) + 1;

        if arrived == self.expected {
            // Last arriver: reset the count for the next round, then release
            // the spinners by advancing the generation.
            self.arrived.store(0, Ordering::Release);
            self.generation.fetch_add(1, Ordering::Release);
        } else {
            while self.generation.load(Ordering::Acquire) == generation_num {
                std::hint::spin_loop();
            }
        }

        generation_num
    }

    /// Resets the barrier; only safe when no thread is currently waiting.
    pub fn reset(&self) {
        self.arrived.store(0, Ordering::Release);
        self.generation.store(0, Ordering::Release);
    }
}

/// Fluent builder for `ReductionConfig`.
pub struct ReductionBuilder {
    config: ReductionConfig,
}

impl ReductionBuilder {
    pub fn new() -> Self {
        Self {
            config: ReductionConfig::default(),
        }
    }

    pub fn sync_mode(mut self, mode: SyncMode) -> Self {
        self.config.sync_mode = mode;
        self
    }

    pub fn phases(mut self, num: u32) -> Self {
        self.config.num_phases = num;
        self
    }

    pub fn block_size(mut self, size: u32) -> Self {
        self.config.block_size = size;
        self
    }

    pub fn grid_size(mut self, size: u32) -> Self {
        self.config.grid_size = size;
        self
    }

    /// Enables convergence checking with the given threshold.
    pub fn with_convergence(mut self, threshold: f64) -> Self {
        self.config.convergence_check = true;
        self.config.convergence_threshold = threshold;
        self
    }

    pub fn build(self) -> ReductionConfig {
        self.config
    }

    /// Builds an `InterPhaseReduction` directly from this configuration.
    pub fn build_reduction<T: Default + Clone + Copy>(
        self,
        input_size: usize,
    ) -> InterPhaseReduction<T> {
        InterPhaseReduction::with_config(input_size, self.config)
    }
}

impl Default for ReductionBuilder {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_inter_phase_reduction() {
        let reduction = InterPhaseReduction::<f64>::new(1024, SyncMode::MultiLaunch);

        assert_eq!(reduction.current_phase(), 0);
        assert!(!reduction.is_complete());

        reduction.phase_start(0).unwrap();
        assert_eq!(reduction.phase_state(0), PhaseState::Running);
        reduction.phase_complete(0).unwrap();
        assert_eq!(reduction.phase_state(0), PhaseState::Complete);

        reduction.phase_start(1).unwrap();
        reduction.phase_complete(1).unwrap();

        assert!(reduction.is_complete());
    }
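
    // Sketch of phase_start's two failure modes, assuming the default
    // two-phase config: an out-of-range phase index, and starting a phase
    // that is not Pending.
    #[test]
    fn test_phase_start_errors() {
        let reduction = InterPhaseReduction::<f64>::new(1024, SyncMode::MultiLaunch);

        // The default config has two phases, so phase 5 is out of range.
        assert!(matches!(
            reduction.phase_start(5),
            Err(ReductionError::InvalidPhase { phase: 5, max_phases: 2 })
        ));

        reduction.phase_start(0).unwrap();
        // Phase 0 is now Running, not Pending, so a second start is rejected.
        assert!(matches!(
            reduction.phase_start(0),
            Err(ReductionError::InvalidPhaseState { phase: 0, current: PhaseState::Running })
        ));
    }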

    #[test]
    fn test_phase_buffers() {
        let mut reduction = InterPhaseReduction::<f64>::with_config(
            1000,
            ReductionConfig {
                block_size: 256,
                num_phases: 3,
                ..Default::default()
            },
        );

        // Phase 0 holds the full input; each later phase is smaller.
        assert_eq!(reduction.buffer_size(0), 1000);
        assert!(reduction.buffer_size(1) < reduction.buffer_size(0));

        if let Some(buf) = reduction.get_buffer_mut(0) {
            buf[0] = 42.0;
        }

        assert_eq!(reduction.get_buffer(0).unwrap()[0], 42.0);
    }

    #[test]
    fn test_global_reduction() {
        let reduction = GlobalReduction::new(4);

        assert!(!reduction.is_complete());

        reduction.submit(0, 1.0);
        reduction.submit(1, 2.0);
        reduction.submit(2, 3.0);

        assert!(!reduction.is_complete());
        assert_eq!(reduction.completion_count(), 3);

        reduction.submit(3, 4.0);

        assert!(reduction.is_complete());
        assert_eq!(reduction.finalize_sum(), Some(10.0));
    }
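
    // Sketch: finalize_max/finalize_min stay None until every participant
    // has submitted, and reset() restores the initial state.
    #[test]
    fn test_global_reduction_min_max_reset() {
        let reduction = GlobalReduction::new(3);

        reduction.submit(0, 2.0);
        assert_eq!(reduction.finalize_max(), None);

        reduction.submit(1, 5.0);
        reduction.submit(2, -1.0);

        assert_eq!(reduction.finalize_max(), Some(5.0));
        assert_eq!(reduction.finalize_min(), Some(-1.0));

        reduction.reset();
        assert!(!reduction.is_complete());
        assert_eq!(reduction.completion_count(), 0);
    }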

    #[test]
    fn test_cooperative_barrier() {
        use std::sync::Arc;
        use std::thread;

        let barrier = Arc::new(CooperativeBarrier::new(3));
        let handles: Vec<_> = (0..3)
            .map(|_| {
                let b = barrier.clone();
                thread::spawn(move || b.wait())
            })
            .collect();

        for h in handles {
            let generation_num = h.join().unwrap();
            assert_eq!(generation_num, 0);
        }
    }
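
    // Sketch: the barrier is reusable. Each full round of arrivals advances
    // the generation, so a second wait() on each thread returns generation 1.
    #[test]
    fn test_cooperative_barrier_reuse() {
        use std::sync::Arc;
        use std::thread;

        let barrier = Arc::new(CooperativeBarrier::new(2));
        let handles: Vec<_> = (0..2)
            .map(|_| {
                let b = barrier.clone();
                thread::spawn(move || (b.wait(), b.wait()))
            })
            .collect();

        for h in handles {
            let (first, second) = h.join().unwrap();
            assert_eq!(first, 0);
            assert_eq!(second, 1);
        }
    }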

    #[test]
    fn test_reduction_builder() {
        let config = ReductionBuilder::new()
            .sync_mode(SyncMode::Cooperative)
            .phases(3)
            .block_size(512)
            .with_convergence(1e-8)
            .build();

        assert_eq!(config.sync_mode, SyncMode::Cooperative);
        assert_eq!(config.num_phases, 3);
        assert_eq!(config.block_size, 512);
        assert!(config.convergence_check);
    }
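
    // Sketch: build_reduction() threads the builder's config through to the
    // typed reduction, including the geometric buffer sizing.
    #[test]
    fn test_builder_build_reduction() {
        let reduction: InterPhaseReduction<f32> = ReductionBuilder::new()
            .phases(3)
            .block_size(128)
            .build_reduction(1024);

        assert_eq!(reduction.input_size(), 1024);
        assert_eq!(reduction.config().num_phases, 3);
        assert_eq!(reduction.buffer_size(0), 1024);
        assert_eq!(reduction.buffer_size(1), 8); // ceil(1024 / 128)
        assert_eq!(reduction.buffer_size(2), 1); // ceil(8 / 128), clamped to 1
    }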

    #[test]
    fn test_convergence_tracking() {
        let reduction = InterPhaseReduction::<f64>::with_config(
            100,
            ReductionConfig {
                convergence_check: true,
                convergence_threshold: 1e-6,
                ..Default::default()
            },
        );

        reduction.set_convergence(1e-3);
        assert!(!reduction.is_converged());

        reduction.set_convergence(1e-8);
        assert!(reduction.is_converged());
    }
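
    // Sketch: reset() returns an InterPhaseReduction to its initial state,
    // clearing phase states, convergence, and buffer contents.
    #[test]
    fn test_reduction_reset() {
        let mut reduction = InterPhaseReduction::<f64>::new(16, SyncMode::MultiLaunch);

        reduction.phase_start(0).unwrap();
        reduction.phase_complete(0).unwrap();
        if let Some(buf) = reduction.get_buffer_mut(0) {
            buf[0] = 7.0;
        }
        reduction.set_convergence(1e-9);

        reduction.reset();

        assert_eq!(reduction.current_phase(), 0);
        assert!(!reduction.is_complete());
        assert_eq!(reduction.phase_state(0), PhaseState::Pending);
        assert_eq!(reduction.convergence(), 0.0);
        assert_eq!(reduction.get_buffer(0).unwrap()[0], 0.0);
    }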
}