scirs2_core/gpu/async_execution.rs

//! Asynchronous execution and event-based synchronization for GPU operations
//!
//! This module provides comprehensive support for asynchronous GPU operations with
//! event-based synchronization, enabling efficient overlapping of computation and
//! memory transfers.
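//!
//! # Examples
//!
//! A minimal usage sketch (marked `ignore` because it assumes this module is
//! reachable as `scirs2_core::gpu::async_execution` and that a backend completes
//! the recorded events):
//!
//! ```ignore
//! use scirs2_core::gpu::async_execution::{AsyncGpuManager, StreamPriority};
//!
//! let manager = AsyncGpuManager::new();
//! let stream = manager.create_stream_with_priority(StreamPriority::High);
//!
//! // Record an event on the stream; a backend would complete it when the
//! // associated GPU work finishes.
//! let event = manager.record_event(&stream);
//!
//! // Wait for every operation recorded on the stream.
//! stream.synchronize().expect("stream synchronization failed");
//! ```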

use crate::gpu::{GpuBuffer, GpuError};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex, Weak};
use std::time::{Duration, Instant};
use thiserror::Error;

/// Type alias for a callback function
type CallbackFn = Box<dyn FnOnce() + Send + 'static>;

/// Type alias for a list of callbacks
type CallbackList = Arc<Mutex<Vec<CallbackFn>>>;

/// Unique identifier for GPU events
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct EventId(u64);

impl EventId {
    /// Create a new unique event ID
    pub fn new() -> Self {
        static COUNTER: AtomicU64 = AtomicU64::new(1);
        Self(COUNTER.fetch_add(1, Ordering::Relaxed))
    }
}

impl Default for EventId {
    fn default() -> Self {
        Self::new()
    }
}

/// Unique identifier for GPU streams
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StreamId(u64);

impl StreamId {
    /// Create a new unique stream ID
    pub fn new() -> Self {
        static COUNTER: AtomicU64 = AtomicU64::new(1);
        Self(COUNTER.fetch_add(1, Ordering::Relaxed))
    }
}

impl Default for StreamId {
    fn default() -> Self {
        Self::new()
    }
}

/// Event state for synchronization
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EventState {
    /// Event has been recorded but not yet completed
    Pending,
    /// Event has completed successfully
    Completed,
    /// Event has failed
    Failed,
    /// Event was cancelled
    Cancelled,
}

/// GPU event for synchronization
pub struct GpuEvent {
    id: EventId,
    state: Arc<Mutex<EventState>>,
    timestamp: Option<Instant>,
    duration: Arc<Mutex<Option<Duration>>>,
    dependencies: Vec<EventId>,
    callbacks: CallbackList,
}

impl GpuEvent {
    /// Create a new GPU event
    pub fn new() -> Self {
        Self {
            id: EventId::new(),
            state: Arc::new(Mutex::new(EventState::Pending)),
            timestamp: Some(Instant::now()),
            duration: Arc::new(Mutex::new(None)),
            dependencies: Vec::new(),
            callbacks: Arc::new(Mutex::new(Vec::new())),
        }
    }

    /// Create a new event with dependencies
    pub fn with_dependencies(dependencies: Vec<EventId>) -> Self {
        Self {
            id: EventId::new(),
            state: Arc::new(Mutex::new(EventState::Pending)),
            timestamp: Some(Instant::now()),
            duration: Arc::new(Mutex::new(None)),
            dependencies,
            callbacks: Arc::new(Mutex::new(Vec::new())),
        }
    }

    /// Get the event ID
    pub fn id(&self) -> EventId {
        self.id
    }

    /// Get the current state of the event
    pub fn state(&self) -> EventState {
        *self.state.lock().expect("Operation failed")
    }

    /// Check if the event has completed
    pub fn is_completed(&self) -> bool {
        self.state() == EventState::Completed
    }

    /// Check if the event has failed
    pub fn is_failed(&self) -> bool {
        self.state() == EventState::Failed
    }

    /// Wait for the event to complete (uses a default 30-second timeout)
    pub fn wait(&self) -> Result<(), GpuError> {
        self.wait_timeout(Duration::from_secs(30))
    }

    /// Wait for the event to complete with a timeout
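    ///
    /// # Examples
    ///
    /// A usage sketch (ignored doc test; `event` is assumed to be an existing
    /// [`GpuEvent`] that a backend completes elsewhere):
    ///
    /// ```ignore
    /// use std::time::Duration;
    ///
    /// match event.wait_timeout(Duration::from_millis(100)) {
    ///     Ok(()) => println!("event completed"),
    ///     Err(e) => eprintln!("event did not complete in time: {e}"),
    /// }
    /// ```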
    pub fn wait_timeout(&self, timeout: Duration) -> Result<(), GpuError> {
        let start = Instant::now();
        while start.elapsed() < timeout {
            match self.state() {
                EventState::Completed => return Ok(()),
                EventState::Failed => {
                    return Err(GpuError::KernelExecutionError(
                        "Event execution failed".to_string(),
                    ))
                }
                EventState::Cancelled => {
                    return Err(GpuError::Other("Event was cancelled".to_string()))
                }
                EventState::Pending => {
                    std::thread::sleep(Duration::from_millis(1));
                }
            }
        }
        Err(GpuError::Other("Event wait timeout".to_string()))
    }

    /// Get the execution duration if completed
    pub fn duration(&self) -> Option<Duration> {
        *self.duration.lock().expect("Operation failed")
    }

    /// Add a callback to be executed when the event completes, or immediately if it has already completed
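    ///
    /// # Examples
    ///
    /// A usage sketch (ignored doc test; `event` is assumed to be an existing
    /// [`GpuEvent`]):
    ///
    /// ```ignore
    /// event.add_callback(|| println!("GPU work finished"));
    /// ```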
    pub fn add_callback<F>(&self, callback: F)
    where
        F: FnOnce() + Send + 'static,
    {
        // Hold the callback list lock while checking the state so a concurrent
        // `complete()` cannot drain the list between the check and the push.
        let mut callbacks = self.callbacks.lock().expect("Operation failed");
        if self.is_completed() {
            // The event already finished; run the callback immediately instead
            // of queueing it, since `complete()` will not drain the list again.
            drop(callbacks);
            callback();
        } else {
            callbacks.push(Box::new(callback));
        }
    }

    /// Get dependencies
    pub fn dependencies(&self) -> &[EventId] {
        &self.dependencies
    }

    /// Mark the event as completed
    #[allow(dead_code)]
    pub(crate) fn complete(&self) {
        let start_time = self.timestamp.unwrap_or_else(Instant::now);
        let duration = start_time.elapsed();

        *self.duration.lock().expect("Operation failed") = Some(duration);
        *self.state.lock().expect("Operation failed") = EventState::Completed;

        // Execute callbacks
        let callbacks = std::mem::take(&mut *self.callbacks.lock().expect("Operation failed"));
        for callback in callbacks {
            callback();
        }
    }

    /// Mark the event as failed
    #[allow(dead_code)]
    pub(crate) fn fail(&self) {
        *self.state.lock().expect("Operation failed") = EventState::Failed;
    }

    /// Cancel the event
    #[allow(dead_code)]
    pub(crate) fn cancel(&self) {
        *self.state.lock().expect("Operation failed") = EventState::Cancelled;
    }
}

impl Default for GpuEvent {
    fn default() -> Self {
        Self::new()
    }
}

impl std::fmt::Debug for GpuEvent {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GpuEvent")
            .field("id", &self.id)
            .field("state", &self.state)
            .field("timestamp", &self.timestamp)
            .field("duration", &self.duration)
            .field("dependencies", &self.dependencies)
            .field(
                "callbacks",
                &format!(
                    "{} callbacks",
                    self.callbacks.lock().expect("Operation failed").len()
                ),
            )
            .finish()
    }
}

/// Priority levels for stream operations
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum StreamPriority {
    /// Low priority for background operations
    Low = 0,
    /// Normal priority for regular operations
    Normal = 1,
    /// High priority for critical operations
    High = 2,
}

impl Default for StreamPriority {
    fn default() -> Self {
        Self::Normal
    }
}

/// GPU stream for organizing operations
#[derive(Debug)]
pub struct GpuStream {
    id: StreamId,
    priority: StreamPriority,
    events: Arc<Mutex<Vec<Weak<GpuEvent>>>>,
    operations_count: Arc<Mutex<usize>>,
}

impl GpuStream {
    /// Create a new GPU stream
    pub fn new() -> Self {
        Self {
            id: StreamId::new(),
            priority: StreamPriority::Normal,
            events: Arc::new(Mutex::new(Vec::new())),
            operations_count: Arc::new(Mutex::new(0)),
        }
    }

    /// Create a new GPU stream with priority
    pub fn with_priority(priority: StreamPriority) -> Self {
        Self {
            id: StreamId::new(),
            priority,
            events: Arc::new(Mutex::new(Vec::new())),
            operations_count: Arc::new(Mutex::new(0)),
        }
    }

    /// Get the stream ID
    pub fn id(&self) -> StreamId {
        self.id
    }

    /// Get the stream priority
    pub fn priority(&self) -> StreamPriority {
        self.priority
    }

    /// Add an event to this stream
    pub fn add_event(&self, event: &Arc<GpuEvent>) {
        self.events
            .lock()
            .expect("Operation failed")
            .push(Arc::downgrade(event));
        *self.operations_count.lock().expect("Operation failed") += 1;
    }

    /// Wait for all operations in this stream to complete
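    ///
    /// # Examples
    ///
    /// A usage sketch (ignored doc test; assumes the recorded events are
    /// completed by a backend):
    ///
    /// ```ignore
    /// let stream = GpuStream::new();
    /// // ... enqueue work and record events on the stream ...
    /// stream.synchronize().expect("stream synchronization failed");
    /// ```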
    pub fn synchronize(&self) -> Result<(), GpuError> {
        let events = self.events.lock().expect("Operation failed").clone();
        for weak_event in events {
            if let Some(event) = weak_event.upgrade() {
                event.wait()?;
            }
        }
        Ok(())
    }

    /// Get the number of operations in this stream
    pub fn operations_count(&self) -> usize {
        *self.operations_count.lock().expect("Operation failed")
    }

    /// Check if the stream is idle (all operations completed)
    pub fn is_idle(&self) -> bool {
        let events = self.events.lock().expect("Operation failed");
        events.iter().all(|weak_event| {
            weak_event
                .upgrade()
                .map(|event| event.is_completed())
                .unwrap_or(true)
        })
    }

    /// Clean up completed events
    pub fn cleanup(&self) {
        let mut events = self.events.lock().expect("Operation failed");
        events.retain(|weak_event| {
            weak_event
                .upgrade()
                .is_some_and(|event| !event.is_completed())
        });
    }
}

impl Default for GpuStream {
    fn default() -> Self {
        Self::new()
    }
}

/// Error types for asynchronous GPU operations
#[derive(Error, Debug)]
pub enum AsyncGpuError {
    /// Stream not found
    #[error("Stream not found: {0:?}")]
    StreamNotFound(StreamId),

    /// Event not found
    #[error("Event not found: {0:?}")]
    EventNotFound(EventId),

    /// Operation timeout
    #[error("Operation timeout after {0:?}")]
    Timeout(Duration),

    /// Dependency cycle detected
    #[error("Dependency cycle detected in events")]
    DependencyCycle,

    /// Underlying GPU error
    #[error("GPU error: {0}")]
    GpuError(#[from] GpuError),
}

/// Asynchronous GPU operation result
pub type AsyncResult<T> = Result<T, AsyncGpuError>;

/// Manager for asynchronous GPU operations
#[derive(Debug)]
pub struct AsyncGpuManager {
    streams: Arc<Mutex<HashMap<StreamId, Arc<GpuStream>>>>,
    events: Arc<Mutex<HashMap<EventId, Arc<GpuEvent>>>>,
    default_stream: Arc<GpuStream>,
}

impl AsyncGpuManager {
    /// Create a new async GPU manager
    pub fn new() -> Self {
        let default_stream = Arc::new(GpuStream::new());
        let mut streams = HashMap::new();
        streams.insert(default_stream.id(), default_stream.clone());

        Self {
            streams: Arc::new(Mutex::new(streams)),
            events: Arc::new(Mutex::new(HashMap::new())),
            default_stream,
        }
    }

    /// Create a new stream
    pub fn create_stream(&self) -> Arc<GpuStream> {
        self.create_stream_with_priority(StreamPriority::Normal)
    }

    /// Create a new stream with priority
    pub fn create_stream_with_priority(&self, priority: StreamPriority) -> Arc<GpuStream> {
        let stream = Arc::new(GpuStream::with_priority(priority));
        self.streams
            .lock()
            .expect("Operation failed")
            .insert(stream.id(), stream.clone());
        stream
    }

    /// Get the default stream
    pub fn default_stream(&self) -> Arc<GpuStream> {
        self.default_stream.clone()
    }

    /// Get a stream by ID
    pub fn get_stream(&self, id: StreamId) -> Option<Arc<GpuStream>> {
        self.streams
            .lock()
            .expect("Operation failed")
            .get(&id)
            .cloned()
    }

    /// Record an event in a stream
    pub fn record_event(&self, stream: &Arc<GpuStream>) -> Arc<GpuEvent> {
        let event = Arc::new(GpuEvent::new());
        stream.add_event(&event);
        self.events
            .lock()
            .expect("Operation failed")
            .insert(event.id(), event.clone());
        event
    }

    /// Record an event with dependencies
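    ///
    /// # Examples
    ///
    /// A usage sketch (ignored doc test) showing how one event can be ordered
    /// after another:
    ///
    /// ```ignore
    /// let manager = AsyncGpuManager::new();
    /// let stream = manager.create_stream();
    ///
    /// let first = manager.record_event(&stream);
    /// let second = manager
    ///     .record_event_with_dependencies(&stream, vec![first.id()])
    ///     .expect("no dependency cycle");
    /// assert_eq!(second.dependencies(), &[first.id()]);
    /// ```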
    pub fn record_event_with_dependencies(
        &self,
        stream: &Arc<GpuStream>,
        dependencies: Vec<EventId>,
    ) -> AsyncResult<Arc<GpuEvent>> {
        // Check for dependency cycles
        self.check_dependency_cycles(&dependencies)?;

        let event = Arc::new(GpuEvent::with_dependencies(dependencies));
        stream.add_event(&event);
        self.events
            .lock()
            .expect("Operation failed")
            .insert(event.id(), event.clone());
        Ok(event)
    }

    /// Wait for multiple events
    pub fn wait_for_events(&self, event_ids: &[EventId]) -> AsyncResult<()> {
        for &event_id in event_ids {
            if let Some(event) = self
                .events
                .lock()
                .expect("Operation failed")
                .get(&event_id)
                .cloned()
            {
                event.wait()?;
            } else {
                return Err(AsyncGpuError::EventNotFound(event_id));
            }
        }
        Ok(())
    }

    /// Synchronize all streams
    pub fn synchronize_all(&self) -> AsyncResult<()> {
        let streams = self
            .streams
            .lock()
            .expect("Operation failed")
            .values()
            .cloned()
            .collect::<Vec<_>>();
        for stream in streams {
            stream.synchronize()?;
        }
        Ok(())
    }

    /// Clean up completed events from all streams and the event registry
    pub fn cleanup(&self) {
        // Clean up streams
        let stream_ids: Vec<_> = self
            .streams
            .lock()
            .expect("Operation failed")
            .keys()
            .cloned()
            .collect();
        for stream_id in stream_ids {
            if let Some(stream) = self
                .streams
                .lock()
                .expect("Operation failed")
                .get(&stream_id)
                .cloned()
            {
                stream.cleanup();
            }
        }

        // Drop completed and failed events from the registry
        let mut events = self.events.lock().expect("Operation failed");
        events.retain(|_, event| !event.is_completed() && !event.is_failed());
    }

    /// Get statistics about async operations
    pub fn get_statistics(&self) -> AsyncGpuStatistics {
        let streams = self.streams.lock().expect("Operation failed");
        let events = self.events.lock().expect("Operation failed");

        let total_streams = streams.len();
        let total_events = events.len();
        let completed_events = events.values().filter(|e| e.is_completed()).count();
        let failed_events = events.values().filter(|e| e.is_failed()).count();
        let pending_events = events
            .values()
            .filter(|e| e.state() == EventState::Pending)
            .count();

        AsyncGpuStatistics {
            total_streams,
            total_events,
            completed_events,
            failed_events,
            pending_events,
        }
    }

    /// Check for dependency cycles in events
    fn check_dependency_cycles(&self, dependencies: &[EventId]) -> AsyncResult<()> {
        let events = self.events.lock().expect("Operation failed");

        // Simple cycle detection using DFS
        fn has_cycle(
            event_id: EventId,
            events: &HashMap<EventId, Arc<GpuEvent>>,
            visited: &mut std::collections::HashSet<EventId>,
            rec_stack: &mut std::collections::HashSet<EventId>,
        ) -> bool {
            visited.insert(event_id);
            rec_stack.insert(event_id);

            if let Some(event) = events.get(&event_id) {
                for &dep_id in event.dependencies() {
                    if !visited.contains(&dep_id) {
                        if has_cycle(dep_id, events, visited, rec_stack) {
                            return true;
                        }
                    } else if rec_stack.contains(&dep_id) {
                        return true;
                    }
                }
            }

            rec_stack.remove(&event_id);
            false
        }

        let mut visited = std::collections::HashSet::new();
        let mut rec_stack = std::collections::HashSet::new();

        for &dep_id in dependencies {
            if !visited.contains(&dep_id)
                && has_cycle(dep_id, &events, &mut visited, &mut rec_stack)
            {
                return Err(AsyncGpuError::DependencyCycle);
            }
        }

        Ok(())
    }
}

impl Default for AsyncGpuManager {
    fn default() -> Self {
        Self::new()
    }
}

/// Statistics for asynchronous GPU operations
#[derive(Debug, Clone)]
pub struct AsyncGpuStatistics {
    /// Total number of streams
    pub total_streams: usize,
    /// Total number of events
    pub total_events: usize,
    /// Number of completed events
    pub completed_events: usize,
    /// Number of failed events
    pub failed_events: usize,
    /// Number of pending events
    pub pending_events: usize,
}

/// Extension trait for adding async capabilities to GPU operations
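///
/// # Examples
///
/// A hypothetical call pattern (ignored doc test): `context` stands in for a
/// backend type implementing this trait, and `host_data`/`device_buffer` are
/// assumed to exist; neither is defined in this module.
///
/// ```ignore
/// let stream = Arc::new(GpuStream::new());
/// let event = context.copy_from_host_async(&host_data, &device_buffer, &stream);
/// event.wait().expect("async host-to-device copy failed");
/// ```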
pub trait AsyncGpuOps {
    /// Launch a kernel asynchronously
    fn launch_async(&self, workgroups: [u32; 3], stream: &Arc<GpuStream>) -> Arc<GpuEvent>;

    /// Copy data asynchronously
    fn copy_async<T: crate::gpu::GpuDataType>(
        &self,
        src: &GpuBuffer<T>,
        dst: &GpuBuffer<T>,
        stream: &Arc<GpuStream>,
    ) -> Arc<GpuEvent>;

    /// Copy from host asynchronously
    fn copy_from_host_async<T: crate::gpu::GpuDataType>(
        &self,
        src: &[T],
        dst: &GpuBuffer<T>,
        stream: &Arc<GpuStream>,
    ) -> Arc<GpuEvent>;

    /// Copy to host asynchronously
    fn copy_to_host_async<T: crate::gpu::GpuDataType>(
        &self,
        src: &GpuBuffer<T>,
        dst: &mut [T],
        stream: &Arc<GpuStream>,
    ) -> Arc<GpuEvent>;
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_event_creation() {
        let event = GpuEvent::new();
        assert_eq!(event.state(), EventState::Pending);
        assert!(!event.is_completed());
        assert!(!event.is_failed());
    }

    #[test]
    fn test_event_completion() {
        let event = GpuEvent::new();
        event.complete();
        assert_eq!(event.state(), EventState::Completed);
        assert!(event.is_completed());
        assert!(!event.is_failed());
        assert!(event.duration().is_some());
    }

    #[test]
    fn test_stream_creation() {
        let stream = GpuStream::new();
        assert_eq!(stream.priority(), StreamPriority::Normal);
        assert_eq!(stream.operations_count(), 0);
        assert!(stream.is_idle());
    }

    #[test]
    fn test_async_manager() {
        let manager = AsyncGpuManager::new();
        let stream = manager.create_stream();
        let event = manager.record_event(&stream);

        assert_eq!(stream.operations_count(), 1);
        assert!(!stream.is_idle());

        event.complete();
        assert!(event.is_completed());
    }

    #[test]
    fn test_event_dependencies() {
        let event1 = GpuEvent::new();
        let event2 = GpuEvent::with_dependencies(vec![event1.id()]);

        assert_eq!(event2.dependencies().len(), 1);
        assert_eq!(event2.dependencies()[0], event1.id());
    }

    #[test]
    fn test_stream_priority() {
        let low_stream = GpuStream::with_priority(StreamPriority::Low);
        let high_stream = GpuStream::with_priority(StreamPriority::High);

        assert_eq!(low_stream.priority(), StreamPriority::Low);
        assert_eq!(high_stream.priority(), StreamPriority::High);
        assert!(high_stream.priority() > low_stream.priority());
    }
}