Skip to main content

sklears_simd/
multi_gpu.rs

1//! Multi-GPU support for parallel processing
2//!
3//! This module provides utilities for distributing computations across multiple
4//! GPU devices, load balancing, and coordinating parallel GPU operations.
5
6use crate::gpu::{GpuDevice, KernelConfig};
7use crate::gpu_memory::MultiGpuMemoryManager;
8use crate::traits::SimdError;
9
10#[cfg(not(feature = "no-std"))]
11use std::collections::{HashMap, HashSet};
12#[cfg(not(feature = "no-std"))]
13use std::sync::{Arc, Mutex};
14#[cfg(not(feature = "no-std"))]
15use std::thread;
16
17#[cfg(feature = "no-std")]
18use alloc::collections::{BTreeMap as HashMap, BTreeSet as HashSet};
19#[cfg(feature = "no-std")]
20use alloc::{
21    boxed::Box,
22    format,
23    string::{String, ToString},
24    sync::Arc,
25    vec,
26    vec::Vec,
27};
28
29#[cfg(feature = "no-std")]
30use core::mem;
31#[cfg(feature = "no-std")]
32use core::{any::Any, cmp::Ordering};
33#[cfg(feature = "no-std")]
34use spin::Mutex;
35#[cfg(not(feature = "no-std"))]
36use std::{any::Any, cmp::Ordering, string::ToString};
37
// Stand-in clock for builds with the `no-std` feature, where
// `std::time::Instant` is not available.
#[cfg(feature = "no-std")]
#[derive(Clone, Copy, Debug)]
pub struct Instant;

#[cfg(feature = "no-std")]
impl Instant {
    /// Returns a placeholder timestamp; no clock is read.
    pub fn now() -> Self {
        Self
    }

    /// Always reports `0` elapsed nanoseconds, because this mock carries no
    /// time information.
    pub fn elapsed(&self) -> u64 {
        0
    }
}
53
54/// Multi-GPU coordinator for parallel processing
55pub struct MultiGpuCoordinator {
56    devices: Vec<GpuDevice>,
57    memory_manager: Arc<Mutex<MultiGpuMemoryManager>>,
58    load_balancer: LoadBalancer,
59    task_scheduler: TaskScheduler,
60    #[allow(dead_code)] // Reserved for barrier/event synchronization when GPU backends are enabled
61    sync_manager: SynchronizationManager,
62}
63
/// Policy the load balancer follows when it has to pick a device for a task.
#[derive(Clone, Copy, Debug)]
pub enum LoadBalancingStrategy {
    /// Spread work equally over all devices.
    Equal,
    /// Weight devices by their number of compute units.
    ComputeWeighted,
    /// Weight devices by their memory bandwidth.
    BandwidthWeighted,
    /// Decide from the currently observed load / performance.
    Dynamic,
    /// Use the weights supplied by the user.
    Custom,
}
78
79/// Load balancer for distributing work across GPUs
80pub struct LoadBalancer {
81    strategy: LoadBalancingStrategy,
82    device_weights: HashMap<u32, f64>,
83    performance_history: HashMap<u32, Vec<f64>>,
84}
85
86/// Task scheduler for coordinating GPU operations
87pub struct TaskScheduler {
88    pending_tasks: Vec<GpuTask>,
89    running_tasks: HashMap<u32, Vec<GpuTask>>,
90    completed_tasks: Vec<CompletedTask>,
91}
92
93/// Synchronization manager for multi-GPU operations
94pub struct SynchronizationManager {
95    barriers: HashMap<String, GpuBarrier>,
96    events: HashMap<String, GpuEvent>,
97}
98
99/// GPU task representation
100#[derive(Debug, Clone)]
101pub struct GpuTask {
102    pub id: String,
103    pub kernel_name: String,
104    pub config: KernelConfig,
105    pub input_data: Vec<GpuTaskData>,
106    pub output_data: Vec<GpuTaskData>,
107    pub device_preference: Option<u32>,
108    pub priority: TaskPriority,
109    pub dependencies: Vec<String>,
110}
111
112/// Task data descriptor
113#[derive(Debug, Clone)]
114pub struct GpuTaskData {
115    pub name: String,
116    pub size: usize,
117    pub data_type: String, // "f32", "f64", "i32", etc.
118    pub location: DataLocation,
119}
120
/// Where the bytes of a task buffer live.
#[derive(Clone, Debug)]
pub enum DataLocation {
    /// Bytes owned on the host side.
    Host(Vec<u8>),
    /// `(device_id, pointer)` for memory on a specific device.
    Device(u32, *mut u8),
    /// Pointer into unified memory.
    Unified(*mut u8),
}

// SAFETY (NOTE(review)): the raw pointers are only stored and copied by this
// type; the enum itself never dereferences them. Soundness therefore depends on
// whatever code eventually dereferences them synchronizing access -- confirm
// against the backend implementations.
unsafe impl Send for DataLocation {}
// SAFETY (NOTE(review)): same reasoning as for `Send` above -- confirm.
unsafe impl Sync for DataLocation {}
131
/// Scheduling priority of a task.
///
/// The explicit discriminants fix the derived ordering:
/// `Low < Normal < High < Critical`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum TaskPriority {
    /// Lowest priority.
    Low = 0,
    /// Default priority.
    Normal = 1,
    /// Elevated priority.
    High = 2,
    /// Highest priority.
    Critical = 3,
}
140
/// Outcome record for a task that has finished executing.
#[derive(Clone, Debug)]
pub struct CompletedTask {
    /// Id of the task this record belongs to.
    pub task_id: String,
    /// Id of the device the task ran on.
    pub device_id: u32,
    /// Execution time in milliseconds.
    pub execution_time_ms: f64,
    /// Amount of memory attributed to the task.
    pub memory_used: usize,
    /// Whether the task finished successfully.
    pub success: bool,
    /// Error description for a failed task, if any.
    pub error: Option<String>,
}
151
/// Barrier at which a fixed number of participants rendezvous.
pub struct GpuBarrier {
    #[allow(dead_code)] // Identifies barrier in debug/logging output
    name: String,
    /// Number of arrivals that releases the barrier.
    expected_participants: u32,
    /// Arrivals counted since the barrier was last released.
    current_participants: u32,
    /// Device ids that arrived since the barrier was last released.
    waiting_devices: Vec<u32>,
}
160
/// Named event used to track asynchronous GPU operations.
pub struct GpuEvent {
    #[allow(dead_code)] // Identifies event in debug/logging output
    name: String,
    #[allow(dead_code)] // Stores which device owns this event for routing
    device_id: u32,
    /// Whether the event has been recorded.
    is_recorded: bool,
    #[allow(dead_code)] // Reserved for native GPU event handle (CUDA event / OpenCL marker)
    backend_event: Option<Box<dyn Any>>,
}
171
172impl MultiGpuCoordinator {
173    /// Create a new multi-GPU coordinator
174    pub fn new(devices: Vec<GpuDevice>) -> Self {
175        let memory_manager = Arc::new(Mutex::new(MultiGpuMemoryManager::new()));
176
177        // Add devices to memory manager
178        #[cfg(not(feature = "no-std"))]
179        {
180            if let Ok(mut manager) = memory_manager.lock() {
181                for device in &devices {
182                    manager.add_device(device.clone());
183                }
184            }
185        }
186        #[cfg(feature = "no-std")]
187        {
188            let mut manager = memory_manager.lock();
189            for device in &devices {
190                manager.add_device(device.clone());
191            }
192        }
193
194        Self {
195            devices,
196            memory_manager,
197            load_balancer: LoadBalancer::new(LoadBalancingStrategy::ComputeWeighted),
198            task_scheduler: TaskScheduler::new(),
199            sync_manager: SynchronizationManager::new(),
200        }
201    }
202
203    /// Add a task to the scheduler
204    pub fn submit_task(&mut self, task: GpuTask) -> Result<(), SimdError> {
205        self.task_scheduler.add_task(task);
206        Ok(())
207    }
208
209    /// Execute all pending tasks
210    pub fn execute_all(&mut self) -> Result<Vec<CompletedTask>, SimdError> {
211        let mut results = Vec::new();
212
213        // Schedule tasks based on dependencies and load balancing
214        let scheduled_tasks = self.schedule_tasks()?;
215
216        // Execute tasks in parallel (or sequentially in no-std)
217        #[cfg(not(feature = "no-std"))]
218        {
219            let handles: Vec<_> = scheduled_tasks
220                .into_iter()
221                .map(|(device_id, tasks)| {
222                    let memory_manager = Arc::clone(&self.memory_manager);
223                    thread::spawn(move || {
224                        Self::execute_device_tasks(device_id, tasks, memory_manager)
225                    })
226                })
227                .collect();
228
229            // Collect results
230            for handle in handles {
231                match handle.join() {
232                    Ok(device_results) => results.extend(device_results),
233                    Err(_) => {
234                        return Err(SimdError::ExternalLibraryError(
235                            "Thread execution failed".to_string(),
236                        ))
237                    }
238                }
239            }
240        }
241
242        #[cfg(feature = "no-std")]
243        {
244            // Sequential execution for no-std
245            for (device_id, tasks) in scheduled_tasks {
246                let memory_manager = Arc::clone(&self.memory_manager);
247                let device_results = Self::execute_device_tasks(device_id, tasks, memory_manager);
248                results.extend(device_results);
249            }
250        }
251
252        // Update performance history
253        self.update_performance_history(&results);
254
255        Ok(results)
256    }
257
258    /// Execute a distributed matrix multiplication
259    pub fn distributed_matrix_multiply(
260        &mut self,
261        a: &[f32],
262        b: &[f32],
263        a_rows: usize,
264        a_cols: usize,
265        b_cols: usize,
266    ) -> Result<Vec<f32>, SimdError> {
267        let num_devices = self.devices.len();
268        if num_devices == 0 {
269            return Err(SimdError::ExternalLibraryError(
270                "No GPU devices available".to_string(),
271            ));
272        }
273
274        // Distribute rows across devices
275        let rows_per_device = a_rows / num_devices;
276        let mut tasks = Vec::new();
277
278        for (i, device) in self.devices.iter().enumerate() {
279            let start_row = i * rows_per_device;
280            let end_row = if i == num_devices - 1 {
281                a_rows
282            } else {
283                (i + 1) * rows_per_device
284            };
285            let device_rows = end_row - start_row;
286
287            if device_rows == 0 {
288                continue;
289            }
290
291            // Create task for this device
292            let task = GpuTask {
293                id: format!("matmul_device_{}", i),
294                kernel_name: "matrix_mul".to_string(),
295                config: KernelConfig {
296                    grid_size: (
297                        b_cols.div_ceil(16) as u32,
298                        device_rows.div_ceil(16) as u32,
299                        1,
300                    ),
301                    block_size: (16, 16, 1),
302                    shared_memory: 0,
303                    stream: None,
304                },
305                input_data: vec![
306                    GpuTaskData {
307                        name: "matrix_a".to_string(),
308                        #[cfg(not(feature = "no-std"))]
309                        size: device_rows * a_cols * std::mem::size_of::<f32>(),
310                        #[cfg(feature = "no-std")]
311                        size: device_rows * a_cols * mem::size_of::<f32>(),
312                        data_type: "f32".to_string(),
313                        location: DataLocation::Host(
314                            a[start_row * a_cols..end_row * a_cols]
315                                .iter()
316                                .flat_map(|&x| x.to_ne_bytes())
317                                .collect(),
318                        ),
319                    },
320                    GpuTaskData {
321                        name: "matrix_b".to_string(),
322                        #[cfg(not(feature = "no-std"))]
323                        size: a_cols * b_cols * std::mem::size_of::<f32>(),
324                        #[cfg(feature = "no-std")]
325                        size: a_cols * b_cols * mem::size_of::<f32>(),
326                        data_type: "f32".to_string(),
327                        location: DataLocation::Host(
328                            b.iter().flat_map(|&x| x.to_ne_bytes()).collect(),
329                        ),
330                    },
331                ],
332                output_data: vec![GpuTaskData {
333                    name: "matrix_c".to_string(),
334                    #[cfg(not(feature = "no-std"))]
335                    size: device_rows * b_cols * std::mem::size_of::<f32>(),
336                    #[cfg(feature = "no-std")]
337                    size: device_rows * b_cols * mem::size_of::<f32>(),
338                    data_type: "f32".to_string(),
339                    location: DataLocation::Host(Vec::new()),
340                }],
341                device_preference: Some(device.id),
342                priority: TaskPriority::High,
343                dependencies: Vec::new(),
344            };
345
346            tasks.push(task);
347        }
348
349        // Submit and execute tasks
350        for task in tasks {
351            self.submit_task(task)?;
352        }
353
354        let results = self.execute_all()?;
355
356        // Combine results
357        let output = vec![0.0f32; a_rows * b_cols];
358        let mut _current_row = 0;
359
360        for result in results {
361            if result.success {
362                // Extract result data from completed task
363                // This would involve copying from GPU memory
364                let device_rows = rows_per_device;
365                _current_row += device_rows;
366            }
367        }
368
369        Ok(output)
370    }
371
372    /// Set load balancing strategy
373    pub fn set_load_balancing(&mut self, strategy: LoadBalancingStrategy) {
374        self.load_balancer.set_strategy(strategy);
375    }
376
377    /// Get device utilization statistics
378    pub fn get_device_stats(&self) -> HashMap<u32, DeviceStats> {
379        let mut stats = HashMap::new();
380
381        for device in &self.devices {
382            let device_stats = DeviceStats {
383                device_id: device.id,
384                name: device.name.clone(),
385                compute_units: device.compute_units,
386                memory_mb: device.memory_mb,
387                current_tasks: self.task_scheduler.get_device_task_count(device.id),
388                average_performance: self.load_balancer.get_average_performance(device.id),
389            };
390            stats.insert(device.id, device_stats);
391        }
392
393        stats
394    }
395
396    fn schedule_tasks(&mut self) -> Result<HashMap<u32, Vec<GpuTask>>, SimdError> {
397        let mut scheduled = HashMap::new();
398
399        // Get available tasks (no pending dependencies)
400        let available_tasks = self.task_scheduler.get_available_tasks();
401
402        for task in available_tasks {
403            let device_id = if let Some(preferred) = task.device_preference {
404                preferred
405            } else {
406                self.load_balancer.select_device(&self.devices, &task)?
407            };
408
409            scheduled
410                .entry(device_id)
411                .or_insert_with(Vec::new)
412                .push(task);
413        }
414
415        Ok(scheduled)
416    }
417
418    fn execute_device_tasks(
419        device_id: u32,
420        tasks: Vec<GpuTask>,
421        _memory_manager: Arc<Mutex<MultiGpuMemoryManager>>,
422    ) -> Vec<CompletedTask> {
423        let mut results = Vec::new();
424
425        for task in tasks {
426            #[cfg(not(feature = "no-std"))]
427            let start_time = std::time::Instant::now();
428            #[cfg(feature = "no-std")]
429            let start_time = Instant::now();
430
431            // Execute task (simplified)
432            let result = CompletedTask {
433                task_id: task.id.clone(),
434                device_id,
435                #[cfg(not(feature = "no-std"))]
436                execution_time_ms: start_time.elapsed().as_millis() as f64,
437                #[cfg(feature = "no-std")]
438                execution_time_ms: start_time.elapsed() as f64 / 1_000_000.0, // Convert nanoseconds to milliseconds
439                memory_used: task.input_data.iter().map(|d| d.size).sum(),
440                success: true, // Placeholder
441                error: None,
442            };
443
444            results.push(result);
445        }
446
447        results
448    }
449
450    fn update_performance_history(&mut self, results: &[CompletedTask]) {
451        for result in results {
452            self.load_balancer.add_performance_sample(
453                result.device_id,
454                1.0 / result.execution_time_ms, // Operations per ms
455            );
456        }
457    }
458}
459
/// Snapshot of one device's properties and utilization.
#[derive(Clone, Debug)]
pub struct DeviceStats {
    /// Id of the device.
    pub device_id: u32,
    /// Name of the device.
    pub name: String,
    /// Number of compute units the device reports.
    pub compute_units: u32,
    /// Device memory in megabytes.
    pub memory_mb: u64,
    /// Number of tasks currently attributed to the device.
    pub current_tasks: usize,
    /// Average of the device's recorded performance samples.
    pub average_performance: f64,
}
470
471impl LoadBalancer {
472    pub fn new(strategy: LoadBalancingStrategy) -> Self {
473        Self {
474            strategy,
475            device_weights: HashMap::new(),
476            performance_history: HashMap::new(),
477        }
478    }
479
480    pub fn set_strategy(&mut self, strategy: LoadBalancingStrategy) {
481        self.strategy = strategy;
482    }
483
484    pub fn select_device(&self, devices: &[GpuDevice], _task: &GpuTask) -> Result<u32, SimdError> {
485        if devices.is_empty() {
486            return Err(SimdError::ExternalLibraryError(
487                "No devices available".to_string(),
488            ));
489        }
490
491        match self.strategy {
492            LoadBalancingStrategy::Equal => Ok(devices[0].id),
493            LoadBalancingStrategy::ComputeWeighted => {
494                // Select device with most compute units
495                let best_device = devices
496                    .iter()
497                    .max_by_key(|d| d.compute_units)
498                    .expect("operation should succeed");
499                Ok(best_device.id)
500            }
501            LoadBalancingStrategy::BandwidthWeighted => {
502                // Select device with most memory (proxy for bandwidth)
503                let best_device = devices
504                    .iter()
505                    .max_by_key(|d| d.memory_mb)
506                    .expect("operation should succeed");
507                Ok(best_device.id)
508            }
509            LoadBalancingStrategy::Dynamic => {
510                // Select device with best recent performance
511                let best_device = devices
512                    .iter()
513                    .max_by(|a, b| {
514                        let a_perf = self.get_average_performance(a.id);
515                        let b_perf = self.get_average_performance(b.id);
516                        a_perf.partial_cmp(&b_perf).unwrap_or(Ordering::Equal)
517                    })
518                    .expect("operation should succeed");
519                Ok(best_device.id)
520            }
521            LoadBalancingStrategy::Custom => {
522                // Use custom weights
523                let best_device = devices
524                    .iter()
525                    .max_by(|a, b| {
526                        let a_weight = self.device_weights.get(&a.id).unwrap_or(&1.0);
527                        let b_weight = self.device_weights.get(&b.id).unwrap_or(&1.0);
528                        a_weight.partial_cmp(b_weight).unwrap_or(Ordering::Equal)
529                    })
530                    .expect("operation should succeed");
531                Ok(best_device.id)
532            }
533        }
534    }
535
536    pub fn add_performance_sample(&mut self, device_id: u32, performance: f64) {
537        let history = self.performance_history.entry(device_id).or_default();
538        history.push(performance);
539
540        // Keep only recent samples
541        if history.len() > 100 {
542            history.remove(0);
543        }
544    }
545
546    pub fn get_average_performance(&self, device_id: u32) -> f64 {
547        if let Some(history) = self.performance_history.get(&device_id) {
548            if history.is_empty() {
549                1.0
550            } else {
551                history.iter().sum::<f64>() / history.len() as f64
552            }
553        } else {
554            1.0
555        }
556    }
557
558    pub fn set_custom_weight(&mut self, device_id: u32, weight: f64) {
559        self.device_weights.insert(device_id, weight);
560    }
561}
562
563impl TaskScheduler {
564    pub fn new() -> Self {
565        Self {
566            pending_tasks: Vec::new(),
567            running_tasks: HashMap::new(),
568            completed_tasks: Vec::new(),
569        }
570    }
571
572    pub fn add_task(&mut self, task: GpuTask) {
573        self.pending_tasks.push(task);
574    }
575
576    pub fn get_available_tasks(&mut self) -> Vec<GpuTask> {
577        let completed_ids: HashSet<_> = self.completed_tasks.iter().map(|t| &t.task_id).collect();
578
579        let mut available = Vec::new();
580        let mut remaining = Vec::new();
581
582        for task in self.pending_tasks.drain(..) {
583            let deps_satisfied = task
584                .dependencies
585                .iter()
586                .all(|dep| completed_ids.contains(dep));
587
588            if deps_satisfied {
589                available.push(task);
590            } else {
591                remaining.push(task);
592            }
593        }
594
595        self.pending_tasks = remaining;
596        available.sort_by_key(|b| core::cmp::Reverse(b.priority));
597        available
598    }
599
600    pub fn get_device_task_count(&self, device_id: u32) -> usize {
601        self.running_tasks
602            .get(&device_id)
603            .map_or(0, |tasks| tasks.len())
604    }
605
606    pub fn mark_task_completed(&mut self, task_id: String) {
607        // Remove from running tasks
608        for tasks in self.running_tasks.values_mut() {
609            tasks.retain(|t| t.id != task_id);
610        }
611    }
612}
613
614impl SynchronizationManager {
615    pub fn new() -> Self {
616        Self {
617            barriers: HashMap::new(),
618            events: HashMap::new(),
619        }
620    }
621
622    pub fn create_barrier(
623        &mut self,
624        name: String,
625        participant_count: u32,
626    ) -> Result<(), SimdError> {
627        let barrier = GpuBarrier {
628            name: name.clone(),
629            expected_participants: participant_count,
630            current_participants: 0,
631            waiting_devices: Vec::new(),
632        };
633
634        self.barriers.insert(name, barrier);
635        Ok(())
636    }
637
638    pub fn wait_barrier(&mut self, name: &str, device_id: u32) -> Result<(), SimdError> {
639        let should_synchronize = if let Some(barrier) = self.barriers.get_mut(name) {
640            barrier.current_participants += 1;
641            barrier.waiting_devices.push(device_id);
642
643            if barrier.current_participants >= barrier.expected_participants {
644                // All devices reached barrier, synchronize
645                let waiting_devices = barrier.waiting_devices.clone();
646                barrier.current_participants = 0;
647                barrier.waiting_devices.clear();
648                Some(waiting_devices)
649            } else {
650                None
651            }
652        } else {
653            return Err(SimdError::InvalidParameter {
654                name: "name".to_string(),
655                value: name.to_string(),
656            });
657        };
658
659        if let Some(waiting_devices) = should_synchronize {
660            self.synchronize_devices(&waiting_devices)?;
661        }
662
663        Ok(())
664    }
665
666    pub fn create_event(&mut self, name: String, device_id: u32) -> Result<(), SimdError> {
667        let event = GpuEvent {
668            name: name.clone(),
669            device_id,
670            is_recorded: false,
671            backend_event: None,
672        };
673
674        self.events.insert(name, event);
675        Ok(())
676    }
677
678    pub fn record_event(&mut self, name: &str) -> Result<(), SimdError> {
679        if let Some(event) = self.events.get_mut(name) {
680            event.is_recorded = true;
681            // Record backend-specific event
682            Ok(())
683        } else {
684            Err(SimdError::InvalidParameter {
685                name: "event".to_string(),
686                value: format!("Event '{}' not found", name),
687            })
688        }
689    }
690
691    fn synchronize_devices(&self, device_ids: &[u32]) -> Result<(), SimdError> {
692        // Synchronize all specified devices
693        for &_device_id in device_ids {
694            // Device-specific synchronization
695        }
696        Ok(())
697    }
698}
699
700impl Default for TaskScheduler {
701    fn default() -> Self {
702        Self::new()
703    }
704}
705
706impl Default for SynchronizationManager {
707    fn default() -> Self {
708        Self::new()
709    }
710}
711
#[allow(non_snake_case)]
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;
    use crate::gpu::GpuBackend;

    /// Build a CUDA test device named "Device <id>" that supports f64 and f16.
    fn cuda_device(id: u32, compute_units: u32, memory_mb: u64) -> GpuDevice {
        GpuDevice {
            id,
            name: format!("Device {}", id),
            backend: GpuBackend::Cuda,
            compute_units,
            memory_mb,
            supports_f64: true,
            supports_f16: true,
        }
    }

    /// The two-device fixture shared by the tests: device 0 is the larger one.
    fn two_devices() -> Vec<GpuDevice> {
        vec![cuda_device(0, 80, 8192), cuda_device(1, 40, 4096)]
    }

    /// Build a task with default launch config, no data and no device preference.
    fn make_task(
        id: &str,
        kernel_name: &str,
        priority: TaskPriority,
        dependencies: Vec<String>,
    ) -> GpuTask {
        GpuTask {
            id: id.to_string(),
            kernel_name: kernel_name.to_string(),
            config: KernelConfig::default(),
            input_data: Vec::new(),
            output_data: Vec::new(),
            device_preference: None,
            priority,
            dependencies,
        }
    }

    #[test]
    fn test_multi_gpu_coordinator_creation() {
        let coordinator = MultiGpuCoordinator::new(two_devices());
        assert_eq!(coordinator.devices.len(), 2);
    }

    #[test]
    fn test_load_balancer() {
        let balancer = LoadBalancer::new(LoadBalancingStrategy::ComputeWeighted);
        let devices = two_devices();
        let task = make_task("test_task", "test_kernel", TaskPriority::Normal, Vec::new());

        let selected = balancer
            .select_device(&devices, &task)
            .expect("operation should succeed");
        assert_eq!(selected, 0); // Device 0 has more compute units
    }

    #[test]
    fn test_task_scheduler() {
        let mut scheduler = TaskScheduler::new();
        scheduler.add_task(make_task(
            "test_task",
            "test_kernel",
            TaskPriority::High,
            Vec::new(),
        ));

        let available = scheduler.get_available_tasks();
        assert_eq!(available.len(), 1);
        assert_eq!(available[0].priority, TaskPriority::High);
    }

    #[test]
    fn test_task_dependencies() {
        let mut scheduler = TaskScheduler::new();

        scheduler.add_task(make_task(
            "task1",
            "kernel1",
            TaskPriority::Normal,
            Vec::new(),
        ));
        scheduler.add_task(make_task(
            "task2",
            "kernel2",
            TaskPriority::Normal,
            vec!["task1".to_string()],
        ));

        let available = scheduler.get_available_tasks();
        assert_eq!(available.len(), 1); // Only task1 should be available
        assert_eq!(available[0].id, "task1");
    }

    #[test]
    fn test_synchronization_manager() {
        let mut sync_manager = SynchronizationManager::new();

        sync_manager
            .create_barrier("test_barrier".to_string(), 2)
            .expect("operation should succeed");
        sync_manager
            .create_event("test_event".to_string(), 0)
            .expect("operation should succeed");

        assert!(sync_manager.barriers.contains_key("test_barrier"));
        assert!(sync_manager.events.contains_key("test_event"));
    }

    #[test]
    fn test_device_stats() {
        let stats = DeviceStats {
            device_id: 0,
            name: "Test Device".to_string(),
            compute_units: 80,
            memory_mb: 8192,
            current_tasks: 3,
            average_performance: 1.5,
        };

        assert_eq!(stats.device_id, 0);
        assert_eq!(stats.current_tasks, 3);
        assert!((stats.average_performance - 1.5).abs() < 0.001);
    }
}
879}