1use crate::gpu::{GpuBuffer, GpuError};
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::{Arc, Mutex, Weak};
11use std::time::{Duration, Instant};
12use thiserror::Error;
13
/// A one-shot completion hook that may be moved to another thread.
///
/// A boxed trait object without an explicit lifetime defaults to a `'static`
/// bound, so the hook cannot borrow from the registering caller's stack.
type CallbackFn = Box<dyn FnOnce() + Send>;

/// Shared, mutex-protected queue of hooks waiting for an event to complete.
type CallbackList = Arc<Mutex<Vec<CallbackFn>>>;
19
/// Process-wide unique identifier for a [`GpuEvent`].
///
/// Values come from a global counter starting at 1, so two calls to
/// [`EventId::new`] in the same process never return equal ids (barring
/// `u64` wrap-around).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct EventId(u64);

impl EventId {
    /// Draws the next unused identifier from the global counter.
    pub fn new() -> Self {
        // Relaxed suffices: only uniqueness of the returned values matters.
        static NEXT: AtomicU64 = AtomicU64::new(1);
        let raw = NEXT.fetch_add(1, Ordering::Relaxed);
        EventId(raw)
    }
}

impl Default for EventId {
    /// Same as [`EventId::new`]: every default value is a fresh id.
    fn default() -> Self {
        EventId::new()
    }
}
37
/// Process-wide unique identifier for a [`GpuStream`].
///
/// Uses its own counter (independent of event ids), starting at 1.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StreamId(u64);

impl StreamId {
    /// Draws the next unused stream identifier from the global counter.
    pub fn new() -> Self {
        // Relaxed suffices: only uniqueness of the returned values matters.
        static NEXT: AtomicU64 = AtomicU64::new(1);
        let raw = NEXT.fetch_add(1, Ordering::Relaxed);
        StreamId(raw)
    }
}

impl Default for StreamId {
    /// Same as [`StreamId::new`]: every default value is a fresh id.
    fn default() -> Self {
        StreamId::new()
    }
}
55
/// Lifecycle state of a [`GpuEvent`].
///
/// Every event starts as `Pending`; the other three variants are set by the
/// event's `complete`, `fail` and `cancel` methods respectively.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EventState {
    /// Work has not finished yet.
    Pending,
    /// Work finished successfully.
    Completed,
    /// Work finished with an error.
    Failed,
    /// Work was cancelled before finishing.
    Cancelled,
}
68
69pub struct GpuEvent {
71 id: EventId,
72 state: Arc<Mutex<EventState>>,
73 timestamp: Option<Instant>,
74 duration: Arc<Mutex<Option<Duration>>>,
75 dependencies: Vec<EventId>,
76 callbacks: CallbackList,
77}
78
79impl GpuEvent {
80 pub fn new() -> Self {
82 Self {
83 id: EventId::new(),
84 state: Arc::new(Mutex::new(EventState::Pending)),
85 timestamp: Some(Instant::now()),
86 duration: Arc::new(Mutex::new(None)),
87 dependencies: Vec::new(),
88 callbacks: Arc::new(Mutex::new(Vec::new())),
89 }
90 }
91
92 pub fn with_dependencies(dependencies: Vec<EventId>) -> Self {
94 Self {
95 id: EventId::new(),
96 state: Arc::new(Mutex::new(EventState::Pending)),
97 timestamp: Some(Instant::now()),
98 duration: Arc::new(Mutex::new(None)),
99 dependencies,
100 callbacks: Arc::new(Mutex::new(Vec::new())),
101 }
102 }
103
104 pub fn id(&self) -> EventId {
106 self.id
107 }
108
109 pub fn state(&self) -> EventState {
111 *self.state.lock().expect("Operation failed")
112 }
113
114 pub fn is_completed(&self) -> bool {
116 self.state() == EventState::Completed
117 }
118
119 pub fn is_failed(&self) -> bool {
121 self.state() == EventState::Failed
122 }
123
124 pub fn wait(&self) -> Result<(), GpuError> {
126 self.wait_timeout(Duration::from_secs(30))
127 }
128
129 pub fn wait_timeout(&self, timeout: Duration) -> Result<(), GpuError> {
131 let start = Instant::now();
132 while start.elapsed() < timeout {
133 match self.state() {
134 EventState::Completed => return Ok(()),
135 EventState::Failed => {
136 return Err(GpuError::KernelExecutionError(
137 "Event execution failed".to_string(),
138 ))
139 }
140 EventState::Cancelled => {
141 return Err(GpuError::Other("Event was cancelled".to_string()))
142 }
143 EventState::Pending => {
144 std::thread::sleep(Duration::from_millis(1));
145 }
146 }
147 }
148 Err(GpuError::Other("Event wait timeout".to_string()))
149 }
150
151 pub fn duration(&self) -> Option<Duration> {
153 *self.duration.lock().expect("Operation failed")
154 }
155
156 pub fn add_callback<F>(&self, callback: F)
158 where
159 F: FnOnce() + Send + 'static,
160 {
161 self.callbacks
162 .lock()
163 .expect("Operation failed")
164 .push(Box::new(callback));
165 }
166
167 pub fn dependencies(&self) -> &[EventId] {
169 &self.dependencies
170 }
171
172 #[allow(dead_code)]
174 pub(crate) fn complete(&self) {
175 let start_time = self.timestamp.unwrap_or_else(Instant::now);
176 let duration = start_time.elapsed();
177
178 *self.duration.lock().expect("Operation failed") = Some(duration);
179 *self.state.lock().expect("Operation failed") = EventState::Completed;
180
181 let callbacks = std::mem::take(&mut *self.callbacks.lock().expect("Operation failed"));
183 for callback in callbacks {
184 callback();
185 }
186 }
187
188 #[allow(dead_code)]
190 pub(crate) fn fail(&self) {
191 *self.state.lock().expect("Operation failed") = EventState::Failed;
192 }
193
194 #[allow(dead_code)]
196 pub(crate) fn cancel(&self) {
197 *self.state.lock().expect("Operation failed") = EventState::Cancelled;
198 }
199}
200
201impl Default for GpuEvent {
202 fn default() -> Self {
203 Self::new()
204 }
205}
206
207impl std::fmt::Debug for GpuEvent {
208 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209 f.debug_struct("GpuEvent")
210 .field("id", &self.id)
211 .field("state", &self.state)
212 .field("timestamp", &self.timestamp)
213 .field("duration", &self.duration)
214 .field("dependencies", &self.dependencies)
215 .field(
216 "callbacks",
217 &format!(
218 "{} callbacks",
219 self.callbacks.lock().expect("Operation failed").len()
220 ),
221 )
222 .finish()
223 }
224}
225
/// Relative priority of a [`GpuStream`].
///
/// Variants compare as `Low < Normal < High` (derived `Ord`) and carry the
/// explicit discriminants 0, 1 and 2. The default is `Normal`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default)]
pub enum StreamPriority {
    /// Lowest priority.
    Low = 0,
    /// Middle priority; returned by `StreamPriority::default()`.
    #[default]
    Normal = 1,
    /// Highest priority.
    High = 2,
}
242
/// A stream of GPU operations, tracked through the events recorded on it.
///
/// Events are held only as weak references, so the stream does not keep them
/// alive; another owner (in this file, `AsyncGpuManager`'s event map) must
/// hold the strong reference.
#[derive(Debug)]
pub struct GpuStream {
    // Unique id; also the key under which `AsyncGpuManager` stores the stream.
    id: StreamId,
    // Priority fixed at construction.
    priority: StreamPriority,
    // Weak references to recorded events. Entries whose event has been
    // dropped are skipped by `synchronize`/`is_idle` and removed by `cleanup`.
    events: Arc<Mutex<Vec<Weak<GpuEvent>>>>,
    // Number of events ever recorded; only ever incremented.
    operations_count: Arc<Mutex<usize>>,
}
251
252impl GpuStream {
253 pub fn new() -> Self {
255 Self {
256 id: StreamId::new(),
257 priority: StreamPriority::Normal,
258 events: Arc::new(Mutex::new(Vec::new())),
259 operations_count: Arc::new(Mutex::new(0)),
260 }
261 }
262
263 pub fn with_priority(priority: StreamPriority) -> Self {
265 Self {
266 id: StreamId::new(),
267 priority,
268 events: Arc::new(Mutex::new(Vec::new())),
269 operations_count: Arc::new(Mutex::new(0)),
270 }
271 }
272
273 pub fn id(&self) -> StreamId {
275 self.id
276 }
277
278 pub fn priority(&self) -> StreamPriority {
280 self.priority
281 }
282
283 pub fn add_event(&self, event: &Arc<GpuEvent>) {
285 self.events
286 .lock()
287 .expect("Operation failed")
288 .push(Arc::downgrade(event));
289 *self.operations_count.lock().expect("Operation failed") += 1;
290 }
291
292 pub fn synchronize(&self) -> Result<(), GpuError> {
294 let events = self.events.lock().expect("Operation failed").clone();
295 for weak_event in events {
296 if let Some(event) = weak_event.upgrade() {
297 event.wait()?;
298 }
299 }
300 Ok(())
301 }
302
303 pub fn operations_count(&self) -> usize {
305 *self.operations_count.lock().expect("Operation failed")
306 }
307
308 pub fn is_idle(&self) -> bool {
310 let events = self.events.lock().expect("Operation failed");
311 events.iter().all(|weak_event| {
312 weak_event
313 .upgrade()
314 .map(|event| event.is_completed())
315 .unwrap_or(true)
316 })
317 }
318
319 pub fn cleanup(&self) {
321 let mut events = self.events.lock().expect("Operation failed");
322 events.retain(|weak_event| {
323 weak_event
324 .upgrade()
325 .is_some_and(|event| !event.is_completed())
326 });
327 }
328}
329
330impl Default for GpuStream {
331 fn default() -> Self {
332 Self::new()
333 }
334}
335
/// Errors produced by [`AsyncGpuManager`] operations.
#[derive(Error, Debug)]
pub enum AsyncGpuError {
    /// A stream id could not be resolved.
    #[error("Stream not found: {0:?}")]
    StreamNotFound(StreamId),

    /// An event id is not (or no longer) present in the manager's event map,
    /// e.g. as reported by `AsyncGpuManager::wait_for_events`.
    #[error("Event not found: {0:?}")]
    EventNotFound(EventId),

    /// An operation did not finish within the given duration.
    #[error("Operation timeout after {0:?}")]
    Timeout(Duration),

    /// The dependency graph reachable from a new event's dependencies
    /// contains a cycle.
    #[error("Dependency cycle detected in events")]
    DependencyCycle,

    /// A lower-level GPU error; `#[from]` lets `?` convert it automatically.
    #[error("GPU error: {0}")]
    GpuError(#[from] GpuError),
}

/// Result alias used by the asynchronous GPU API.
pub type AsyncResult<T> = Result<T, AsyncGpuError>;
362
/// Registry of streams and events for asynchronous GPU work.
///
/// Holds strong references to every stream it creates and every event it
/// records; streams themselves only keep weak references to their events.
#[derive(Debug)]
pub struct AsyncGpuManager {
    // All streams created through this manager (including `default_stream`), keyed by id.
    streams: Arc<Mutex<HashMap<StreamId, Arc<GpuStream>>>>,
    // Events recorded through this manager, keyed by id; pruned by `cleanup`.
    events: Arc<Mutex<HashMap<EventId, Arc<GpuEvent>>>>,
    // Stream created in `new`; also registered in `streams`.
    default_stream: Arc<GpuStream>,
}
370
371impl AsyncGpuManager {
372 pub fn new() -> Self {
374 let default_stream = Arc::new(GpuStream::new());
375 let mut streams = HashMap::new();
376 streams.insert(default_stream.id(), default_stream.clone());
377
378 Self {
379 streams: Arc::new(Mutex::new(streams)),
380 events: Arc::new(Mutex::new(HashMap::new())),
381 default_stream,
382 }
383 }
384
385 pub fn create_stream(&self) -> Arc<GpuStream> {
387 self.create_stream_with_priority(StreamPriority::Normal)
388 }
389
390 pub fn create_stream_with_priority(&self, priority: StreamPriority) -> Arc<GpuStream> {
392 let stream = Arc::new(GpuStream::with_priority(priority));
393 self.streams
394 .lock()
395 .expect("Operation failed")
396 .insert(stream.id(), stream.clone());
397 stream
398 }
399
400 pub fn default_stream(&self) -> Arc<GpuStream> {
402 self.default_stream.clone()
403 }
404
405 pub fn get_stream(&self, id: StreamId) -> Option<Arc<GpuStream>> {
407 self.streams
408 .lock()
409 .expect("Operation failed")
410 .get(&id)
411 .cloned()
412 }
413
414 pub fn record_event(&self, stream: &Arc<GpuStream>) -> Arc<GpuEvent> {
416 let event = Arc::new(GpuEvent::new());
417 stream.add_event(&event);
418 self.events
419 .lock()
420 .expect("Operation failed")
421 .insert(event.id(), event.clone());
422 event
423 }
424
425 pub fn record_event_with_dependencies(
427 &self,
428 stream: &Arc<GpuStream>,
429 dependencies: Vec<EventId>,
430 ) -> AsyncResult<Arc<GpuEvent>> {
431 self.check_dependency_cycles(&dependencies)?;
433
434 let event = Arc::new(GpuEvent::with_dependencies(dependencies));
435 stream.add_event(&event);
436 self.events
437 .lock()
438 .expect("Operation failed")
439 .insert(event.id(), event.clone());
440 Ok(event)
441 }
442
443 pub fn wait_for_events(&self, eventids: &[EventId]) -> AsyncResult<()> {
445 for &event_id in eventids {
446 if let Some(event) = self
447 .events
448 .lock()
449 .expect("Operation failed")
450 .get(&event_id)
451 .cloned()
452 {
453 event.wait()?;
454 } else {
455 return Err(AsyncGpuError::EventNotFound(event_id));
456 }
457 }
458 Ok(())
459 }
460
461 pub fn synchronize_all(&self) -> AsyncResult<()> {
463 let streams = self
464 .streams
465 .lock()
466 .expect("Operation failed")
467 .values()
468 .cloned()
469 .collect::<Vec<_>>();
470 for stream in streams {
471 stream.synchronize()?;
472 }
473 Ok(())
474 }
475
476 pub fn cleanup(&self) {
478 let stream_ids: Vec<_> = self
480 .streams
481 .lock()
482 .expect("Operation failed")
483 .keys()
484 .cloned()
485 .collect();
486 for stream_id in stream_ids {
487 if let Some(stream) = self
488 .streams
489 .lock()
490 .expect("Operation failed")
491 .get(&stream_id)
492 .cloned()
493 {
494 stream.cleanup();
495 }
496 }
497
498 let mut events = self.events.lock().expect("Operation failed");
500 events.retain(|_, event| !event.is_completed() && !event.is_failed());
501 }
502
503 pub fn get_statistics(&self) -> AsyncGpuStatistics {
505 let streams = self.streams.lock().expect("Operation failed");
506 let events = self.events.lock().expect("Operation failed");
507
508 let total_streams = streams.len();
509 let total_events = events.len();
510 let completed_events = events.values().filter(|e| e.is_completed()).count();
511 let failed_events = events.values().filter(|e| e.is_failed()).count();
512 let pending_events = events
513 .values()
514 .filter(|e| e.state() == EventState::Pending)
515 .count();
516
517 AsyncGpuStatistics {
518 total_streams,
519 total_events,
520 completed_events,
521 failed_events,
522 pending_events,
523 }
524 }
525
526 fn check_dependency_cycles(&self, dependencies: &[EventId]) -> AsyncResult<()> {
528 let events = self.events.lock().expect("Operation failed");
529
530 fn has_cycle(
532 event_id: EventId,
533 events: &HashMap<EventId, Arc<GpuEvent>>,
534 visited: &mut std::collections::HashSet<EventId>,
535 rec_stack: &mut std::collections::HashSet<EventId>,
536 ) -> bool {
537 visited.insert(event_id);
538 rec_stack.insert(event_id);
539
540 if let Some(event) = events.get(&event_id) {
541 for &dep_id in event.dependencies() {
542 if !visited.contains(&dep_id) {
543 if has_cycle(dep_id, events, visited, rec_stack) {
544 return true;
545 }
546 } else if rec_stack.contains(&dep_id) {
547 return true;
548 }
549 }
550 }
551
552 rec_stack.remove(&event_id);
553 false
554 }
555
556 let mut visited = std::collections::HashSet::new();
557 let mut rec_stack = std::collections::HashSet::new();
558
559 for &dep_id in dependencies {
560 if !visited.contains(&dep_id)
561 && has_cycle(dep_id, &events, &mut visited, &mut rec_stack)
562 {
563 return Err(AsyncGpuError::DependencyCycle);
564 }
565 }
566
567 Ok(())
568 }
569}
570
571impl Default for AsyncGpuManager {
572 fn default() -> Self {
573 Self::new()
574 }
575}
576
/// Snapshot of counts returned by `AsyncGpuManager::get_statistics`.
///
/// Cancelled events are counted in `total_events` but in none of the
/// per-state fields, so the three state counts may sum to less than the total.
#[derive(Debug, Clone)]
pub struct AsyncGpuStatistics {
    /// Number of streams registered with the manager.
    pub total_streams: usize,
    /// Number of events currently held in the manager's event map.
    pub total_events: usize,
    /// Events in the `Completed` state.
    pub completed_events: usize,
    /// Events in the `Failed` state.
    pub failed_events: usize,
    /// Events in the `Pending` state.
    pub pending_events: usize,
}
591
592pub trait AsyncGpuOps {
594 fn launch_async(&self, workgroups: [u32; 3], stream: &Arc<GpuStream>) -> Arc<GpuEvent>;
596
597 fn copy_async<T: crate::gpu::GpuDataType>(
599 &self,
600 src: &GpuBuffer<T>,
601 dst: &GpuBuffer<T>,
602 stream: &Arc<GpuStream>,
603 ) -> Arc<GpuEvent>;
604
605 fn copy_from_host_async<T: crate::gpu::GpuDataType>(
607 &self,
608 src: &[T],
609 dst: &GpuBuffer<T>,
610 stream: &Arc<GpuStream>,
611 ) -> Arc<GpuEvent>;
612
613 fn copy_to_host_async<T: crate::gpu::GpuDataType>(
615 &self,
616 src: &GpuBuffer<T>,
617 dst: &mut [T],
618 stream: &Arc<GpuStream>,
619 ) -> Arc<GpuEvent>;
620}
621
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_event_creation() {
        let fresh = GpuEvent::new();
        assert_eq!(fresh.state(), EventState::Pending);
        assert!(!fresh.is_completed() && !fresh.is_failed());
    }

    #[test]
    fn test_event_completion() {
        let finished = GpuEvent::new();
        finished.complete();

        assert_eq!(finished.state(), EventState::Completed);
        assert!(finished.is_completed());
        assert!(!finished.is_failed());
        assert!(finished.duration().is_some());
    }

    #[test]
    fn test_stream_creation() {
        let stream = GpuStream::new();
        assert_eq!(
            (stream.priority(), stream.operations_count()),
            (StreamPriority::Normal, 0)
        );
        assert!(stream.is_idle());
    }

    #[test]
    fn test_async_manager() {
        let manager = AsyncGpuManager::new();
        let stream = manager.create_stream();
        let recorded = manager.record_event(&stream);

        assert_eq!(stream.operations_count(), 1);
        assert!(!stream.is_idle(), "a pending event keeps the stream busy");

        recorded.complete();
        assert!(recorded.is_completed());
    }

    #[test]
    fn test_event_dependencies() {
        let parent = GpuEvent::new();
        let child = GpuEvent::with_dependencies(vec![parent.id()]);

        assert_eq!(child.dependencies(), &[parent.id()]);
    }

    #[test]
    fn test_stream_priority() {
        let low = GpuStream::with_priority(StreamPriority::Low);
        let high = GpuStream::with_priority(StreamPriority::High);

        assert_eq!(low.priority(), StreamPriority::Low);
        assert_eq!(high.priority(), StreamPriority::High);
        assert!(low.priority() < high.priority());
    }
}