// rustkernel_core/memory/reduction.rs

//! Reduction Primitives
//!
//! Provides GPU reduction operations and multi-phase kernel synchronization
//! for iterative algorithms like PageRank, K-Means, and graph analytics.
//!
//! # Reduction Modes
//!
//! - **Single-pass**: Simple reductions completed in one kernel launch
//! - **Multi-phase**: Complex reductions requiring intermediate storage
//! - **Cooperative**: GPU-wide synchronization using cooperative groups
//!
//! # Sync Modes
//!
//! - **Cooperative**: Use CUDA cooperative groups for grid-wide sync
//! - **SoftwareBarrier**: Software-based barrier with atomic operations
//! - **MultiLaunch**: Separate kernel launches per phase
//!
//! # Example
//!
//! ```rust,ignore
//! use rustkernel_core::memory::reduction::{InterPhaseReduction, SyncMode};
//!
//! let reduction = InterPhaseReduction::<f64>::new(1000, SyncMode::Cooperative);
//!
//! // Phase 0: local (block-level) reduction
//! reduction.phase_start(0).unwrap();
//! // ... kernel execution ...
//! reduction.phase_complete(0).unwrap();
//!
//! // Phase 1: global reduction
//! reduction.phase_start(1).unwrap();
//! // ... kernel execution ...
//! reduction.phase_complete(1).unwrap();
//! assert!(reduction.is_complete());
//! ```

use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};

/// Synchronization mode for multi-phase reductions
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum SyncMode {
    /// Use cooperative groups for GPU-wide synchronization
    Cooperative,
    /// Software barrier using atomic operations
    SoftwareBarrier,
    /// Separate kernel launches per phase
    #[default]
    MultiLaunch,
}

/// Reduction operation type
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReductionOp {
    /// Sum all values
    Sum,
    /// Product of all values
    Product,
    /// Maximum value
    Max,
    /// Minimum value
    Min,
    /// Count of true values
    Count,
    /// Logical AND
    All,
    /// Logical OR
    Any,
}

/// Phase state for multi-phase reductions
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PhaseState {
    /// Phase not started
    Pending,
    /// Phase in progress
    Running,
    /// Phase completed
    Complete,
    /// Phase failed
    Failed,
}

/// Configuration for inter-phase reduction
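///
/// Selected fields can be overridden via struct-update syntax (values here
/// are illustrative):
///
/// ```rust,ignore
/// let config = ReductionConfig {
///     num_phases: 3,
///     block_size: 512,
///     ..Default::default()
/// };
/// ```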
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReductionConfig {
    /// Synchronization mode
    pub sync_mode: SyncMode,
    /// Number of phases
    pub num_phases: u32,
    /// Elements per block for block-level reduction
    pub block_size: u32,
    /// Number of blocks
    pub grid_size: u32,
    /// Enable convergence checking between phases
    pub convergence_check: bool,
    /// Convergence threshold
    pub convergence_threshold: f64,
}

impl Default for ReductionConfig {
    fn default() -> Self {
        Self {
            sync_mode: SyncMode::MultiLaunch,
            num_phases: 2,
            block_size: 256,
            grid_size: 1024,
            convergence_check: false,
            convergence_threshold: 1e-6,
        }
    }
}

/// Inter-phase reduction state
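///
/// Phase buffers shrink by a factor of `block_size` per phase (buffer 0
/// matches the input size). A minimal sketch of host-side buffer staging,
/// with kernel launches elided:
///
/// ```rust,ignore
/// let mut reduction = InterPhaseReduction::<f64>::new(1000, SyncMode::MultiLaunch);
///
/// // Stage input into the phase-0 buffer.
/// if let Some(buf) = reduction.get_buffer_mut(0) {
///     buf.fill(1.0);
/// }
///
/// // ... kernel execution writes partials for the next phase ...
///
/// // With the default block_size of 256, phase 1 holds ceil(1000 / 256) = 4 partials.
/// assert_eq!(reduction.buffer_size(1), 4);
/// ```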
pub struct InterPhaseReduction<T> {
    /// Configuration
    config: ReductionConfig,
    /// Input size
    input_size: usize,
    /// Phase buffers
    phase_buffers: Vec<Vec<T>>,
    /// Current phase
    current_phase: AtomicU32,
    /// Phase states
    phase_states: Vec<AtomicU32>,
    /// Whether reduction is complete
    is_complete: AtomicBool,
    /// Convergence value (if tracked)
    convergence_value: AtomicU64,
}

impl<T: Default + Clone + Copy> InterPhaseReduction<T> {
    /// Create a new inter-phase reduction
    pub fn new(input_size: usize, sync_mode: SyncMode) -> Self {
        Self::with_config(
            input_size,
            ReductionConfig {
                sync_mode,
                ..Default::default()
            },
        )
    }

    /// Create with full configuration
    pub fn with_config(input_size: usize, config: ReductionConfig) -> Self {
        let num_phases = config.num_phases as usize;

        // Calculate buffer sizes for each phase
        let mut phase_buffers = Vec::with_capacity(num_phases);
        let mut size = input_size;
        for _ in 0..num_phases {
            phase_buffers.push(vec![T::default(); size]);
            // Each phase shrinks the output by a factor of block_size
            size = size.div_ceil(config.block_size as usize);
            size = size.max(1);
        }

        let phase_states: Vec<_> = (0..num_phases)
            .map(|_| AtomicU32::new(PhaseState::Pending as u32))
            .collect();

        Self {
            config,
            input_size,
            phase_buffers,
            current_phase: AtomicU32::new(0),
            phase_states,
            is_complete: AtomicBool::new(false),
            convergence_value: AtomicU64::new(0),
        }
    }

    /// Get configuration
    pub fn config(&self) -> &ReductionConfig {
        &self.config
    }

    /// Get input size
    pub fn input_size(&self) -> usize {
        self.input_size
    }

    /// Get current phase
    pub fn current_phase(&self) -> u32 {
        self.current_phase.load(Ordering::Relaxed)
    }

    /// Start a phase, transitioning it from `Pending` to `Running`
    pub fn phase_start(&self, phase: u32) -> Result<(), ReductionError> {
        if phase >= self.config.num_phases {
            return Err(ReductionError::InvalidPhase {
                phase,
                max_phases: self.config.num_phases,
            });
        }

        let expected = PhaseState::Pending as u32;
        let new = PhaseState::Running as u32;

        match self.phase_states[phase as usize].compare_exchange(
            expected,
            new,
            Ordering::SeqCst,
            Ordering::SeqCst,
        ) {
            Ok(_) => {
                self.current_phase.store(phase, Ordering::Relaxed);
                Ok(())
            }
            Err(current) => Err(ReductionError::InvalidPhaseState {
                phase,
                current: phase_state_from_u32(current),
            }),
        }
    }

    /// Complete a phase, transitioning it from `Running` to `Complete`
    pub fn phase_complete(&self, phase: u32) -> Result<(), ReductionError> {
        if phase >= self.config.num_phases {
            return Err(ReductionError::InvalidPhase {
                phase,
                max_phases: self.config.num_phases,
            });
        }

        let expected = PhaseState::Running as u32;
        let new = PhaseState::Complete as u32;

        match self.phase_states[phase as usize].compare_exchange(
            expected,
            new,
            Ordering::SeqCst,
            Ordering::SeqCst,
        ) {
            Ok(_) => {
                // Completing the final phase marks the whole reduction complete
                if phase == self.config.num_phases - 1 {
                    self.is_complete.store(true, Ordering::Release);
                }
                Ok(())
            }
            Err(current) => Err(ReductionError::InvalidPhaseState {
                phase,
                current: phase_state_from_u32(current),
            }),
        }
    }

    /// Mark a phase as failed
    pub fn phase_failed(&self, phase: u32) {
        if (phase as usize) < self.phase_states.len() {
            self.phase_states[phase as usize].store(PhaseState::Failed as u32, Ordering::Release);
        }
    }

    /// Get phase state (out-of-range phases report `Pending`)
    pub fn phase_state(&self, phase: u32) -> PhaseState {
        if phase >= self.config.num_phases {
            return PhaseState::Pending;
        }
        phase_state_from_u32(self.phase_states[phase as usize].load(Ordering::Acquire))
    }

    /// Check if reduction is complete
    pub fn is_complete(&self) -> bool {
        self.is_complete.load(Ordering::Acquire)
    }

    /// Get buffer for a phase (for reading previous phase results)
    pub fn get_buffer(&self, phase: u32) -> Option<&[T]> {
        self.phase_buffers.get(phase as usize).map(|v| v.as_slice())
    }

    /// Get mutable buffer for a phase (for writing current phase results)
    pub fn get_buffer_mut(&mut self, phase: u32) -> Option<&mut [T]> {
        self.phase_buffers
            .get_mut(phase as usize)
            .map(|v| v.as_mut_slice())
    }

    /// Get buffer size for a phase
    pub fn buffer_size(&self, phase: u32) -> usize {
        self.phase_buffers
            .get(phase as usize)
            .map(|v| v.len())
            .unwrap_or(0)
    }

    /// Set convergence value (stored as raw `f64` bits)
    pub fn set_convergence(&self, value: f64) {
        self.convergence_value
            .store(value.to_bits(), Ordering::Release);
    }

    /// Get convergence value
    pub fn convergence(&self) -> f64 {
        f64::from_bits(self.convergence_value.load(Ordering::Acquire))
    }

    /// Check if converged (always `false` if convergence checking is disabled)
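    ///
    /// A minimal sketch of the convergence gate (the residual value is
    /// illustrative):
    ///
    /// ```rust,ignore
    /// let reduction = InterPhaseReduction::<f64>::with_config(
    ///     100,
    ///     ReductionConfig {
    ///         convergence_check: true,
    ///         convergence_threshold: 1e-6,
    ///         ..Default::default()
    ///     },
    /// );
    /// reduction.set_convergence(1e-8); // residual from the latest iteration
    /// assert!(reduction.is_converged());
    /// ```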
    pub fn is_converged(&self) -> bool {
        if !self.config.convergence_check {
            return false;
        }
        self.convergence() < self.config.convergence_threshold
    }

    /// Reset for reuse
    pub fn reset(&mut self) {
        self.current_phase.store(0, Ordering::Relaxed);
        self.is_complete.store(false, Ordering::Release);
        self.convergence_value.store(0, Ordering::Release);

        for state in &self.phase_states {
            state.store(PhaseState::Pending as u32, Ordering::Release);
        }

        for buffer in &mut self.phase_buffers {
            for item in buffer.iter_mut() {
                *item = T::default();
            }
        }
    }
}

fn phase_state_from_u32(value: u32) -> PhaseState {
    match value {
        0 => PhaseState::Pending,
        1 => PhaseState::Running,
        2 => PhaseState::Complete,
        _ => PhaseState::Failed,
    }
}

/// Reduction errors
#[derive(Debug, thiserror::Error)]
pub enum ReductionError {
    /// Invalid phase number
    #[error("Invalid phase {phase}, max phases: {max_phases}")]
    InvalidPhase {
        /// Requested phase
        phase: u32,
        /// Maximum phases
        max_phases: u32,
    },

    /// Invalid phase state transition
    #[error("Invalid phase state for phase {phase}: {current:?}")]
    InvalidPhaseState {
        /// Phase number
        phase: u32,
        /// Current state
        current: PhaseState,
    },

    /// Reduction not complete
    #[error("Reduction not complete, current phase: {current_phase}")]
    NotComplete {
        /// Current phase
        current_phase: u32,
    },

    /// Buffer size mismatch
    #[error("Buffer size mismatch: expected {expected}, got {actual}")]
    BufferSizeMismatch {
        /// Expected size
        expected: usize,
        /// Actual size
        actual: usize,
    },
}

/// Global reduction tracker for kernel-to-kernel (K2K) coordination
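///
/// A sketch of one aggregation round with two participants (values are
/// illustrative):
///
/// ```rust,ignore
/// let reduction = GlobalReduction::new(2);
/// assert!(!reduction.submit(0, 1.5));
/// assert!(reduction.submit(1, 2.5)); // the final submitter completes the round
/// assert_eq!(reduction.finalize_sum(), Some(4.0));
/// ```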
pub struct GlobalReduction {
    /// Total number of participants
    pub total_participants: u32,
    /// Number of participants that have completed
    pub completed: AtomicU32,
    /// Whether all participants are done
    pub all_complete: AtomicBool,
    /// Partial results (one per participant)
    pub partial_results: Vec<AtomicU64>,
}

impl GlobalReduction {
    /// Create a new global reduction
    pub fn new(participants: u32) -> Self {
        let partial_results = (0..participants).map(|_| AtomicU64::new(0)).collect();

        Self {
            total_participants: participants,
            completed: AtomicU32::new(0),
            all_complete: AtomicBool::new(false),
            partial_results,
        }
    }

    /// Submit a partial result, returning `true` if this call completes the
    /// reduction. Each participant must submit exactly once per round;
    /// out-of-range participant IDs are rejected with `false`.
    pub fn submit(&self, participant_id: u32, value: f64) -> bool {
        if participant_id >= self.total_participants {
            return false;
        }

        self.partial_results[participant_id as usize].store(value.to_bits(), Ordering::Release);

        let count = self.completed.fetch_add(1, Ordering::AcqRel) + 1;
        if count == self.total_participants {
            self.all_complete.store(true, Ordering::Release);
            return true;
        }

        false
    }

    /// Check if all participants have submitted
    pub fn is_complete(&self) -> bool {
        self.all_complete.load(Ordering::Acquire)
    }

    /// Get completion count
    pub fn completion_count(&self) -> u32 {
        self.completed.load(Ordering::Acquire)
    }

    /// Compute final result (sum of partials)
    pub fn finalize_sum(&self) -> Option<f64> {
        if !self.is_complete() {
            return None;
        }

        let sum: f64 = self
            .partial_results
            .iter()
            .map(|v| f64::from_bits(v.load(Ordering::Acquire)))
            .sum();

        Some(sum)
    }

    /// Compute final result (max of partials)
    pub fn finalize_max(&self) -> Option<f64> {
        if !self.is_complete() {
            return None;
        }

        self.partial_results
            .iter()
            .map(|v| f64::from_bits(v.load(Ordering::Acquire)))
            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
    }

    /// Compute final result (min of partials)
    pub fn finalize_min(&self) -> Option<f64> {
        if !self.is_complete() {
            return None;
        }

        self.partial_results
            .iter()
            .map(|v| f64::from_bits(v.load(Ordering::Acquire)))
            .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
    }

    /// Reset for reuse
    pub fn reset(&self) {
        self.completed.store(0, Ordering::Release);
        self.all_complete.store(false, Ordering::Release);
        for partial in &self.partial_results {
            partial.store(0, Ordering::Release);
        }
    }
}

/// Cooperative synchronization barrier, reusable across generations
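///
/// A host-side, spin-waiting analogue of a grid-wide barrier. A minimal
/// sketch with three threads meeting at the barrier:
///
/// ```rust,ignore
/// use std::sync::Arc;
/// use std::thread;
///
/// let barrier = Arc::new(CooperativeBarrier::new(3));
/// let workers: Vec<_> = (0..3)
///     .map(|_| {
///         let b = barrier.clone();
///         thread::spawn(move || b.wait()) // blocks until all 3 arrive
///     })
///     .collect();
/// for w in workers {
///     assert_eq!(w.join().unwrap(), 0); // all observed generation 0
/// }
/// ```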
pub struct CooperativeBarrier {
    /// Expected thread count
    expected: u32,
    /// Arrived count
    arrived: AtomicU32,
    /// Generation counter
    generation: AtomicU32,
}

impl CooperativeBarrier {
    /// Create a new barrier
    pub fn new(expected: u32) -> Self {
        Self {
            expected,
            arrived: AtomicU32::new(0),
            generation: AtomicU32::new(0),
        }
    }

    /// Wait at the barrier; returns the generation number observed on entry
    pub fn wait(&self) -> u32 {
        let generation_num = self.generation.load(Ordering::Acquire);
        let arrived = self.arrived.fetch_add(1, Ordering::AcqRel) + 1;

        if arrived == self.expected {
            // Last one to arrive: reset the count and advance the generation,
            // releasing all spinning waiters
            self.arrived.store(0, Ordering::Release);
            self.generation.fetch_add(1, Ordering::Release);
        } else {
            // Spin wait for the generation to change
            while self.generation.load(Ordering::Acquire) == generation_num {
                std::hint::spin_loop();
            }
        }

        generation_num
    }

    /// Reset the barrier
    pub fn reset(&self) {
        self.arrived.store(0, Ordering::Release);
        self.generation.store(0, Ordering::Release);
    }
}

/// Builder for reduction operations
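///
/// A sketch of assembling a three-phase cooperative reduction:
///
/// ```rust,ignore
/// let reduction = ReductionBuilder::new()
///     .sync_mode(SyncMode::Cooperative)
///     .phases(3)
///     .block_size(512)
///     .with_convergence(1e-8)
///     .build_reduction::<f64>(10_000);
/// assert_eq!(reduction.config().num_phases, 3);
/// ```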
pub struct ReductionBuilder {
    config: ReductionConfig,
}

impl ReductionBuilder {
    /// Create a new builder
    pub fn new() -> Self {
        Self {
            config: ReductionConfig::default(),
        }
    }

    /// Set sync mode
    pub fn sync_mode(mut self, mode: SyncMode) -> Self {
        self.config.sync_mode = mode;
        self
    }

    /// Set number of phases
    pub fn phases(mut self, num: u32) -> Self {
        self.config.num_phases = num;
        self
    }

    /// Set block size
    pub fn block_size(mut self, size: u32) -> Self {
        self.config.block_size = size;
        self
    }

    /// Set grid size
    pub fn grid_size(mut self, size: u32) -> Self {
        self.config.grid_size = size;
        self
    }

    /// Enable convergence checking
    pub fn with_convergence(mut self, threshold: f64) -> Self {
        self.config.convergence_check = true;
        self.config.convergence_threshold = threshold;
        self
    }

    /// Build the configuration
    pub fn build(self) -> ReductionConfig {
        self.config
    }

    /// Build an InterPhaseReduction
    pub fn build_reduction<T: Default + Clone + Copy>(
        self,
        input_size: usize,
    ) -> InterPhaseReduction<T> {
        InterPhaseReduction::with_config(input_size, self.config)
    }
}

impl Default for ReductionBuilder {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_inter_phase_reduction() {
        let reduction = InterPhaseReduction::<f64>::new(1024, SyncMode::MultiLaunch);

        assert_eq!(reduction.current_phase(), 0);
        assert!(!reduction.is_complete());

        // Phase 0
        reduction.phase_start(0).unwrap();
        assert_eq!(reduction.phase_state(0), PhaseState::Running);
        reduction.phase_complete(0).unwrap();
        assert_eq!(reduction.phase_state(0), PhaseState::Complete);

        // Phase 1
        reduction.phase_start(1).unwrap();
        reduction.phase_complete(1).unwrap();

        assert!(reduction.is_complete());
    }

    #[test]
    fn test_phase_buffers() {
        let mut reduction = InterPhaseReduction::<f64>::with_config(
            1000,
            ReductionConfig {
                block_size: 256,
                num_phases: 3,
                ..Default::default()
            },
        );

        // First phase buffer should be 1000 elements
        assert_eq!(reduction.buffer_size(0), 1000);

        // Subsequent buffers are reduced
        assert!(reduction.buffer_size(1) < reduction.buffer_size(0));

        // Can write to buffers
        if let Some(buf) = reduction.get_buffer_mut(0) {
            buf[0] = 42.0;
        }

        assert_eq!(reduction.get_buffer(0).unwrap()[0], 42.0);
    }

    #[test]
    fn test_global_reduction() {
        let reduction = GlobalReduction::new(4);

        assert!(!reduction.is_complete());

        reduction.submit(0, 1.0);
        reduction.submit(1, 2.0);
        reduction.submit(2, 3.0);

        assert!(!reduction.is_complete());
        assert_eq!(reduction.completion_count(), 3);

        reduction.submit(3, 4.0);

        assert!(reduction.is_complete());
        assert_eq!(reduction.finalize_sum(), Some(10.0));
    }

    #[test]
    fn test_cooperative_barrier() {
        use std::sync::Arc;
        use std::thread;

        let barrier = Arc::new(CooperativeBarrier::new(3));
        let handles: Vec<_> = (0..3)
            .map(|_| {
                let b = barrier.clone();
                thread::spawn(move || b.wait())
            })
            .collect();

        for h in handles {
            let generation_num = h.join().unwrap();
            assert_eq!(generation_num, 0);
        }
    }

    #[test]
    fn test_reduction_builder() {
        let config = ReductionBuilder::new()
            .sync_mode(SyncMode::Cooperative)
            .phases(3)
            .block_size(512)
            .with_convergence(1e-8)
            .build();

        assert_eq!(config.sync_mode, SyncMode::Cooperative);
        assert_eq!(config.num_phases, 3);
        assert_eq!(config.block_size, 512);
        assert!(config.convergence_check);
    }

    #[test]
    fn test_convergence_tracking() {
        let reduction = InterPhaseReduction::<f64>::with_config(
            100,
            ReductionConfig {
                convergence_check: true,
                convergence_threshold: 1e-6,
                ..Default::default()
            },
        );

        reduction.set_convergence(1e-3);
        assert!(!reduction.is_converged());

        reduction.set_convergence(1e-8);
        assert!(reduction.is_converged());
    }
}