//! Device placement, cross-device transfer, and operation scheduling utilities for tensors.

use crate::Tensor;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use torsh_core::{device::DeviceType, dtype::TensorElement, error::Result};

/// Placeholder GPU context used while native GPU support is unavailable.
#[cfg(feature = "gpu")]
pub struct GpuContext;

/// Placeholder GPU kernel handle used while native GPU support is unavailable.
#[cfg(feature = "gpu")]
pub struct GpuKernel;

#[cfg(feature = "gpu")]
impl GpuContext {
    pub fn new() -> Result<Self> {
        Err(torsh_core::error::TorshError::InvalidArgument(
            "GPU support temporarily unavailable".to_string(),
        ))
    }
}

#[cfg(feature = "gpu")]
impl GpuKernel {
    pub fn load(_context: &GpuContext, _name: &str) -> Result<Self> {
        Err(torsh_core::error::TorshError::InvalidArgument(
            "GPU support temporarily unavailable".to_string(),
        ))
    }

    pub fn auto_tune(&mut self, _tuning_params: &[(String, f32)]) -> Result<()> {
        Err(torsh_core::error::TorshError::InvalidArgument(
            "GPU support temporarily unavailable".to_string(),
        ))
    }

    pub fn enable_fusion(&mut self, _enable: bool) -> Result<()> {
        Err(torsh_core::error::TorshError::InvalidArgument(
            "GPU support temporarily unavailable".to_string(),
        ))
    }

    pub fn enable_tensor_cores(&mut self, _enable: bool) -> Result<()> {
        Err(torsh_core::error::TorshError::InvalidArgument(
            "GPU support temporarily unavailable".to_string(),
        ))
    }

    pub fn supports_tensor_cores(&self) -> bool {
        false
    }

    pub fn execute<T>(&self, _input: &[T], _output: &mut [T]) -> Result<()> {
        Err(torsh_core::error::TorshError::InvalidArgument(
            "GPU support temporarily unavailable".to_string(),
        ))
    }
}

/// Device-specific optimization settings applied during tensor transfer and execution.
#[derive(Debug, Clone)]
pub enum DeviceOptimization {
    Cpu(CpuOptimization),
    Gpu(GpuOptimization),
    Metal(MetalOptimization),
    WebGpu(WebGpuOptimization),
}

/// Tuning knobs for CPU execution and host-side transfers.
#[derive(Debug, Clone)]
pub struct CpuOptimization {
    pub use_simd: bool,
    pub thread_count: Option<usize>,
    pub cache_friendly: bool,
    pub numa_aware: bool,
}

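/// Tuning knobs for GPU execution and device transfers.
///
/// Illustrative customization via struct-update syntax (only types defined in this
/// module are used; the block is marked `ignore` because the importing path is
/// assumed, not confirmed):
///
/// ```ignore
/// let opt = GpuOptimization {
///     mixed_precision: true,
///     stream_count: 8,
///     ..GpuOptimization::default()
/// };
/// assert!(opt.use_pinned_memory);
/// ```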
#[derive(Debug, Clone)]
pub struct GpuOptimization {
    pub use_pinned_memory: bool,
    pub stream_count: u32,
    pub mixed_precision: bool,
    pub memory_pool_size: Option<usize>,
    pub use_tensor_cores: bool,
    pub auto_kernel_tuning: bool,
    pub use_unified_memory: bool,
    pub multi_gpu_strategy: MultiGpuStrategy,
    pub backend_preference: Vec<GpuBackendType>,
    pub memory_coalescing: bool,
    pub kernel_fusion_level: u8,
    pub dynamic_batching: bool,
}

/// How work is split across multiple GPUs.
#[derive(Debug, Clone)]
pub enum MultiGpuStrategy {
    Single,
    DataParallel,
    ModelParallel,
    PipelineParallel,
    Auto,
}

/// Supported GPU backend families.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GpuBackendType {
    Cuda,
    Metal,
    WebGpu,
    Rocm,
    OpenCl,
}

/// Tuning knobs for Apple Metal execution.
#[derive(Debug, Clone)]
pub struct MetalOptimization {
    pub use_mps: bool,
    pub command_buffer_count: u32,
    pub auto_memory_management: bool,
}

/// Tuning knobs for WebGPU execution.
#[derive(Debug, Clone)]
pub struct WebGpuOptimization {
    pub use_compute_shaders: bool,
    pub buffer_pool_size: Option<usize>,
    pub pipeline_caching: bool,
}

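/// Priority-ordered, per-device queue of scheduled operations.
///
/// Minimal usage sketch (mirrors the unit tests below; marked `ignore` because the
/// importing path is assumed):
///
/// ```ignore
/// let mut scheduler = OperationScheduler::new();
/// let id = scheduler
///     .schedule_operation(DeviceType::Cpu, OperationType::Compute, 5, vec![])
///     .expect("scheduling should succeed");
/// assert_eq!(scheduler.execute_next_operation(DeviceType::Cpu).unwrap(), Some(id));
/// ```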
#[derive(Debug)]
pub struct OperationScheduler {
    device_queues: HashMap<DeviceType, Vec<ScheduledOperation>>,
    sync_state: HashMap<DeviceType, SyncState>,
    operation_counter: Arc<RwLock<u64>>,
}

/// A single operation queued for execution on a device.
#[derive(Debug)]
pub struct ScheduledOperation {
    pub id: u64,
    pub operation: OperationType,
    pub priority: u8,
    pub dependencies: Vec<DeviceType>,
}

/// Kind of work a scheduled operation performs.
#[derive(Debug)]
pub enum OperationType {
    Compute,
    Transfer,
    Synchronization,
}

/// Per-device synchronization bookkeeping.
#[derive(Debug)]
pub struct SyncState {
    pub last_operation: std::time::Instant,
    pub pending_transfers: usize,
    pub available: bool,
}

impl Default for CpuOptimization {
    fn default() -> Self {
        Self {
            use_simd: true,
            thread_count: None,
            cache_friendly: true,
            numa_aware: true,
        }
    }
}

impl Default for GpuOptimization {
    fn default() -> Self {
        Self {
            use_pinned_memory: true,
            stream_count: 4,
            mixed_precision: false,
            memory_pool_size: Some(1024 * 1024 * 1024),
            use_tensor_cores: true,
            auto_kernel_tuning: true,
            use_unified_memory: true,
            multi_gpu_strategy: MultiGpuStrategy::Auto,
            backend_preference: vec![
                GpuBackendType::Cuda,
                GpuBackendType::Metal,
                GpuBackendType::Rocm,
                GpuBackendType::WebGpu,
                GpuBackendType::OpenCl,
            ],
            memory_coalescing: true,
            kernel_fusion_level: 2,
            dynamic_batching: true,
        }
    }
}

impl Default for MetalOptimization {
    fn default() -> Self {
        Self {
            use_mps: true,
            command_buffer_count: 8,
            auto_memory_management: true,
        }
    }
}

impl Default for WebGpuOptimization {
    fn default() -> Self {
        Self {
            use_compute_shaders: true,
            buffer_pool_size: Some(256 * 1024 * 1024),
            pipeline_caching: true,
        }
    }
}

impl<T: TensorElement + Copy> Tensor<T> {
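    /// Moves the tensor to `target_device`, picking a device-specific transfer path
    /// where one exists and falling back to a generic copy otherwise.
    ///
    /// Illustrative example (mirrors the unit tests; marked `ignore` because the
    /// importing path is assumed):
    ///
    /// ```ignore
    /// let t = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2], DeviceType::Cpu)?;
    /// let moved = t.to_device(DeviceType::Cpu)?; // same device: returns a clone
    /// ```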
    pub fn to_device(&self, target_device: DeviceType) -> Result<Self> {
        if self.device == target_device {
            return Ok(self.clone());
        }

        let optimization = self.get_device_optimization(target_device);

        match (self.device, target_device) {
            (DeviceType::Cpu, DeviceType::Cuda(gpu_id)) => {
                self.cpu_to_gpu_transfer(gpu_id as u32, optimization)
            }
            (DeviceType::Cuda(gpu_id), DeviceType::Cpu) => {
                self.gpu_to_cpu_transfer(gpu_id as u32, optimization)
            }
            (DeviceType::Cpu, DeviceType::Metal(metal_id)) => {
                self.cpu_to_metal_transfer(metal_id as u32, optimization)
            }
            (DeviceType::Metal(metal_id), DeviceType::Cpu) => {
                self.metal_to_cpu_transfer(metal_id as u32, optimization)
            }
            _ => self.generic_device_transfer(target_device),
        }
    }

    fn get_device_optimization(&self, device: DeviceType) -> DeviceOptimization {
        match device {
            DeviceType::Cpu => DeviceOptimization::Cpu(CpuOptimization::default()),
            DeviceType::Cuda(_) => DeviceOptimization::Gpu(GpuOptimization::default()),
            DeviceType::Metal(_) => DeviceOptimization::Metal(MetalOptimization::default()),
            DeviceType::Wgpu(_) => DeviceOptimization::Gpu(GpuOptimization::default()),
        }
    }

    fn cpu_to_gpu_transfer(&self, gpu_id: u32, optimization: DeviceOptimization) -> Result<Self> {
        let data = self.to_vec()?;

        if let DeviceOptimization::Gpu(gpu_opt) = optimization {
            if gpu_opt.use_pinned_memory {
                self.transfer_with_pinned_memory(data, DeviceType::Cuda(gpu_id as usize))
            } else {
                Self::from_data(
                    data,
                    self.shape().dims().to_vec(),
                    DeviceType::Cuda(gpu_id as usize),
                )
            }
        } else {
            Self::from_data(
                data,
                self.shape().dims().to_vec(),
                DeviceType::Cuda(gpu_id as usize),
            )
        }
    }

    fn gpu_to_cpu_transfer(&self, _gpu_id: u32, optimization: DeviceOptimization) -> Result<Self> {
        let data = self.to_vec()?;

        if let DeviceOptimization::Cpu(cpu_opt) = optimization {
            if cpu_opt.numa_aware {
                self.transfer_with_numa_awareness(data, DeviceType::Cpu)
            } else {
                Self::from_data(data, self.shape().dims().to_vec(), DeviceType::Cpu)
            }
        } else {
            Self::from_data(data, self.shape().dims().to_vec(), DeviceType::Cpu)
        }
    }

    fn cpu_to_metal_transfer(
        &self,
        metal_id: u32,
        optimization: DeviceOptimization,
    ) -> Result<Self> {
        let data = self.to_vec()?;

        if let DeviceOptimization::Metal(metal_opt) = optimization {
            if metal_opt.use_mps {
                self.transfer_with_mps(data, DeviceType::Metal(metal_id as usize))
            } else {
                Self::from_data(
                    data,
                    self.shape().dims().to_vec(),
                    DeviceType::Metal(metal_id as usize),
                )
            }
        } else {
            Self::from_data(
                data,
                self.shape().dims().to_vec(),
                DeviceType::Metal(metal_id as usize),
            )
        }
    }

    fn metal_to_cpu_transfer(
        &self,
        _metal_id: u32,
        optimization: DeviceOptimization,
    ) -> Result<Self> {
        let data = self.to_vec()?;

        if let DeviceOptimization::Cpu(cpu_opt) = optimization {
            if cpu_opt.cache_friendly {
                self.transfer_with_cache_optimization(data, DeviceType::Cpu)
            } else {
                Self::from_data(data, self.shape().dims().to_vec(), DeviceType::Cpu)
            }
        } else {
            Self::from_data(data, self.shape().dims().to_vec(), DeviceType::Cpu)
        }
    }

    fn generic_device_transfer(&self, target_device: DeviceType) -> Result<Self> {
        let data = self.to_vec()?;
        Self::from_data(data, self.shape().dims().to_vec(), target_device)
    }

    fn transfer_with_pinned_memory(&self, data: Vec<T>, target_device: DeviceType) -> Result<Self> {
        Self::from_data(data, self.shape().dims().to_vec(), target_device)
    }

    fn transfer_with_numa_awareness(
        &self,
        data: Vec<T>,
        target_device: DeviceType,
    ) -> Result<Self> {
        Self::from_data(data, self.shape().dims().to_vec(), target_device)
    }

    fn transfer_with_mps(&self, data: Vec<T>, target_device: DeviceType) -> Result<Self> {
        Self::from_data(data, self.shape().dims().to_vec(), target_device)
    }

    fn transfer_with_cache_optimization(
        &self,
        data: Vec<T>,
        target_device: DeviceType,
    ) -> Result<Self> {
        let optimized_data = self.optimize_for_cache(data)?;
        Self::from_data(optimized_data, self.shape().dims().to_vec(), target_device)
    }

    fn optimize_for_cache(&self, data: Vec<T>) -> Result<Vec<T>> {
        Ok(data)
    }

    pub fn synchronize_devices(&self, devices: &[DeviceType]) -> Result<()> {
        for device in devices {
            self.synchronize_device(*device)?;
        }
        Ok(())
    }

    fn synchronize_device(&self, _device: DeviceType) -> Result<()> {
        Ok(())
    }

    pub fn can_transfer_efficiently(&self, target_device: DeviceType) -> bool {
        match (self.device, target_device) {
            (a, b) if a == b => true,
            (DeviceType::Cpu, DeviceType::Cuda(_)) | (DeviceType::Cuda(_), DeviceType::Cpu) => {
                true
            }
            (DeviceType::Cpu, DeviceType::Metal(_)) | (DeviceType::Metal(_), DeviceType::Cpu) => {
                true
            }
            _ => false,
        }
    }

    pub fn get_transfer_strategy(&self, target_device: DeviceType) -> TransferStrategy {
        match (self.device, target_device) {
            (a, b) if a == b => TransferStrategy::NoTransfer,
            (DeviceType::Cpu, DeviceType::Cuda(_)) => TransferStrategy::DirectTransfer,
            (DeviceType::Cuda(_), DeviceType::Cpu) => TransferStrategy::DirectTransfer,
            (DeviceType::Cpu, DeviceType::Metal(_)) => TransferStrategy::DirectTransfer,
            (DeviceType::Metal(_), DeviceType::Cpu) => TransferStrategy::DirectTransfer,
            _ => TransferStrategy::ThroughCpu,
        }
    }
}

/// Strategy selected for moving a tensor between two devices.
#[derive(Debug, Clone, PartialEq)]
pub enum TransferStrategy {
    NoTransfer,
    DirectTransfer,
    ThroughCpu,
}

impl OperationScheduler {
    pub fn new() -> Self {
        Self {
            device_queues: HashMap::new(),
            sync_state: HashMap::new(),
            operation_counter: Arc::new(RwLock::new(0)),
        }
    }

    pub fn schedule_operation(
        &mut self,
        device: DeviceType,
        operation: OperationType,
        priority: u8,
        dependencies: Vec<DeviceType>,
    ) -> Result<u64> {
        let mut counter = self
            .operation_counter
            .write()
            .expect("lock should not be poisoned");
        *counter += 1;
        let op_id = *counter;
        drop(counter);

        let scheduled_op = ScheduledOperation {
            id: op_id,
            operation,
            priority,
            dependencies,
        };

        self.device_queues
            .entry(device)
            .or_default()
            .push(scheduled_op);

        if let Some(queue) = self.device_queues.get_mut(&device) {
            queue.sort_by(|a, b| b.priority.cmp(&a.priority));
        }

        self.sync_state.entry(device).or_insert_with(|| SyncState {
            last_operation: std::time::Instant::now(),
            pending_transfers: 0,
            available: true,
        });

        Ok(op_id)
    }

    pub fn execute_next_operation(&mut self, device: DeviceType) -> Result<Option<u64>> {
        let op = if let Some(queue) = self.device_queues.get_mut(&device) {
            if queue.is_empty() {
                None
            } else {
                Some(queue.remove(0))
            }
        } else {
            None
        };

        if let Some(op) = op {
            let dependencies_satisfied = self.check_dependencies(&op.dependencies)?;

            if dependencies_satisfied {
                self.execute_operation(&op)?;

                if let Some(sync_state) = self.sync_state.get_mut(&device) {
                    sync_state.last_operation = std::time::Instant::now();
                }

                Ok(Some(op.id))
            } else {
                if let Some(queue) = self.device_queues.get_mut(&device) {
                    queue.insert(0, op);
                }
                Ok(None)
            }
        } else {
            Ok(None)
        }
    }

    fn check_dependencies(&self, dependencies: &[DeviceType]) -> Result<bool> {
        for &dep_device in dependencies {
            if let Some(sync_state) = self.sync_state.get(&dep_device) {
                if !sync_state.available {
                    return Ok(false);
                }
            }
        }
        Ok(true)
    }

    fn execute_operation(&self, _operation: &ScheduledOperation) -> Result<()> {
        std::thread::sleep(std::time::Duration::from_millis(1));
        Ok(())
    }

    pub fn get_queue_length(&self, device: DeviceType) -> usize {
        self.device_queues
            .get(&device)
            .map_or(0, |queue| queue.len())
    }

    pub fn clear_device_queue(&mut self, device: DeviceType) {
        self.device_queues.remove(&device);
    }
}

impl Default for OperationScheduler {
    fn default() -> Self {
        Self::new()
    }
}

/// Process-wide operation scheduler, created lazily on first access.
static GLOBAL_SCHEDULER: parking_lot::Mutex<Option<OperationScheduler>> =
    parking_lot::Mutex::new(None);

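/// Returns a guard over the global [`OperationScheduler`], creating it on first access.
///
/// Minimal usage sketch (mirrors `test_global_scheduler`; marked `ignore` because the
/// importing path is assumed):
///
/// ```ignore
/// let mut guard = get_global_scheduler();
/// let scheduler = guard.as_mut().expect("initialized on first access");
/// scheduler.schedule_operation(DeviceType::Cpu, OperationType::Compute, 1, vec![])?;
/// ```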
pub fn get_global_scheduler() -> parking_lot::MutexGuard<'static, Option<OperationScheduler>> {
    let mut guard = GLOBAL_SCHEDULER.lock();
    if guard.is_none() {
        *guard = Some(OperationScheduler::new());
    }
    guard
}

/// Replaces the global scheduler with a freshly initialized instance.
pub fn initialize_global_scheduler() -> Result<()> {
    let mut guard = GLOBAL_SCHEDULER.lock();
    *guard = Some(OperationScheduler::new());
    Ok(())
}

#[cfg(feature = "gpu")]
impl<T: TensorElement + Copy + Default> Tensor<T> {
    /// Runs a named GPU kernel over this tensor's data and returns the result as a
    /// new tensor on the same device.
    pub fn execute_gpu_kernel(&self, kernel_name: &str, _params: Vec<T>) -> Result<Self> {
        let gpu_opt = match self.get_device_optimization(self.device) {
            DeviceOptimization::Gpu(opt) => opt,
            _ => {
                return Err(torsh_core::error::TorshError::InvalidArgument(
                    "GPU kernel execution requires GPU device".to_string(),
                ))
            }
        };

        let gpu_context = self.create_optimal_gpu_context(&gpu_opt)?;

        let input_buffer = self.create_gpu_buffer(&gpu_context, &gpu_opt)?;

        let kernel = self.select_optimal_kernel(&gpu_context, kernel_name, &gpu_opt)?;

        let mut output_buffer = vec![T::default(); input_buffer.len()];
        kernel.execute(&input_buffer, &mut output_buffer)?;

        self.gpu_buffer_to_tensor(output_buffer, &gpu_context, &gpu_opt)
    }

    #[allow(dead_code)]
    fn create_optimal_gpu_context(&self, _gpu_opt: &GpuOptimization) -> Result<GpuContext> {
        Err(torsh_core::error::TorshError::InvalidArgument(
            "GPU backend creation temporarily disabled".to_string(),
        ))
    }

    #[allow(dead_code)]
    fn create_gpu_buffer(&self, _context: &GpuContext, _gpu_opt: &GpuOptimization) -> Result<Vec<T>>
    where
        T: Copy,
    {
        let data = self.to_vec()?;
        Ok(data)
    }

    fn select_optimal_kernel(
        &self,
        context: &GpuContext,
        kernel_name: &str,
        gpu_opt: &GpuOptimization,
    ) -> Result<GpuKernel> {
        let mut kernel = GpuKernel::load(context, kernel_name).map_err(|e| {
            torsh_core::error::TorshError::InvalidArgument(format!(
                "Failed to load kernel '{}': {}",
                kernel_name, e
            ))
        })?;

        if gpu_opt.auto_kernel_tuning {
            kernel.auto_tune(&[])?;
        }

        if gpu_opt.use_tensor_cores && kernel.supports_tensor_cores() {
            kernel.enable_tensor_cores(true)?;
        }

        if gpu_opt.kernel_fusion_level > 0 {
            kernel.enable_fusion(true)?;
        }

        Ok(kernel)
    }

    #[allow(dead_code)]
    fn gpu_buffer_to_tensor(
        &self,
        buffer: Vec<T>,
        _context: &GpuContext,
        _gpu_opt: &GpuOptimization,
    ) -> Result<Self>
    where
        T: Copy,
    {
        Self::from_data(buffer, self.shape().dims().to_vec(), self.device)
    }

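    /// Splits the tensor across `gpu_count` CUDA devices using the given strategy,
    /// or an automatically selected one when `strategy` is `None` or `Auto`.
    ///
    /// Illustrative example (marked `ignore`: it assumes the `gpu` feature and a
    /// tensor whose leading batch dimension can be split across two devices):
    ///
    /// ```ignore
    /// let shards = tensor.distribute_multi_gpu(2, Some(MultiGpuStrategy::DataParallel))?;
    /// for shard in &shards {
    ///     // each shard carries a contiguous slice of the batch dimension
    /// }
    /// ```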
    pub fn distribute_multi_gpu(
        &self,
        gpu_count: usize,
        strategy: Option<MultiGpuStrategy>,
    ) -> Result<Vec<Self>> {
        if gpu_count <= 1 {
            return Ok(vec![self.clone()]);
        }

        let strategy = strategy.unwrap_or(MultiGpuStrategy::Auto);
        let effective_strategy = match strategy {
            MultiGpuStrategy::Auto => self.select_optimal_multi_gpu_strategy(gpu_count),
            s => s,
        };

        match effective_strategy {
            MultiGpuStrategy::DataParallel => self.data_parallel_distribution(gpu_count),
            MultiGpuStrategy::ModelParallel => self.model_parallel_distribution(gpu_count),
            MultiGpuStrategy::PipelineParallel => self.pipeline_parallel_distribution(gpu_count),
            _ => Ok(vec![self.clone()]),
        }
    }

    fn select_optimal_multi_gpu_strategy(&self, gpu_count: usize) -> MultiGpuStrategy {
        let _total_elements = self.numel();
        let shape = self.shape();
        let dims = shape.dims();

        if !dims.is_empty() && dims[0] >= gpu_count * 4 {
            return MultiGpuStrategy::DataParallel;
        }

        if dims.len() > 1 && dims.iter().skip(1).product::<usize>() > 1024 * 1024 {
            return MultiGpuStrategy::ModelParallel;
        }

        if dims.len() > 3 {
            return MultiGpuStrategy::PipelineParallel;
        }

        MultiGpuStrategy::DataParallel
    }

    fn data_parallel_distribution(&self, gpu_count: usize) -> Result<Vec<Self>> {
        let shape = self.shape();
        let dims = shape.dims();
        if dims.is_empty() {
            return Err(torsh_core::error::TorshError::InvalidArgument(
                "Cannot distribute scalar tensor".to_string(),
            ));
        }

        let batch_size = dims[0];
        // Ceiling division so every batch element lands in exactly one chunk.
        let chunk_size = (batch_size + gpu_count - 1) / gpu_count;
        let mut distributed_tensors = Vec::with_capacity(gpu_count);
        let data = self.to_vec()?;
        let elements_per_batch = dims.iter().skip(1).product::<usize>();

        for gpu_id in 0..gpu_count {
            let start_batch = gpu_id * chunk_size;
            let end_batch = ((gpu_id + 1) * chunk_size).min(batch_size);

            if start_batch >= batch_size {
                break;
            }

            let start_idx = start_batch * elements_per_batch;
            let end_idx = end_batch * elements_per_batch;
            let chunk_data = data[start_idx..end_idx].to_vec();

            let mut chunk_dims = dims.to_vec();
            chunk_dims[0] = end_batch - start_batch;

            let chunk_tensor = Self::from_data(chunk_data, chunk_dims, DeviceType::Cuda(gpu_id))?;

            distributed_tensors.push(chunk_tensor);
        }

        Ok(distributed_tensors)
    }

    fn model_parallel_distribution(&self, gpu_count: usize) -> Result<Vec<Self>> {
        let shape = self.shape();
        let dims = shape.dims();
        if dims.len() < 2 {
            return Err(torsh_core::error::TorshError::InvalidArgument(
                "Model parallel requires at least 2D tensor".to_string(),
            ));
        }

        let feature_dim = dims.len() - 1;
        let feature_size = dims[feature_dim];
        let chunk_size = (feature_size + gpu_count - 1) / gpu_count;

        let mut distributed_tensors = Vec::with_capacity(gpu_count);
        let _data = self.to_vec()?;

        for gpu_id in 0..gpu_count {
            let start_feature = gpu_id * chunk_size;
            let end_feature = ((gpu_id + 1) * chunk_size).min(feature_size);

            if start_feature >= feature_size {
                break;
            }

            let mut chunk_dims = dims.to_vec();
            chunk_dims[feature_dim] = end_feature - start_feature;

            // Placeholder contents: a full implementation would slice the source data
            // along the feature dimension instead of filling with default values.
            let chunk_size_total: usize = chunk_dims.iter().product();
            let chunk_data = vec![T::default(); chunk_size_total];

            let chunk_tensor = Self::from_data(chunk_data, chunk_dims, DeviceType::Cuda(gpu_id))?;

            distributed_tensors.push(chunk_tensor);
        }

        Ok(distributed_tensors)
    }

    fn pipeline_parallel_distribution(&self, gpu_count: usize) -> Result<Vec<Self>> {
        let mut distributed_tensors = Vec::with_capacity(gpu_count);

        // Pipeline parallelism replicates the full tensor on every device; each stage
        // operates on its own copy.
        for gpu_id in 0..gpu_count {
            let pipeline_tensor = Self::from_data(
                self.to_vec()?,
                self.shape().dims().to_vec(),
                DeviceType::Cuda(gpu_id),
            )?;
            distributed_tensors.push(pipeline_tensor);
        }

        Ok(distributed_tensors)
    }

    #[allow(dead_code)]
    pub fn enable_mixed_precision(&mut self, _precision: i32) -> Result<()> {
        Err(torsh_core::error::TorshError::InvalidArgument(
            "Mixed precision temporarily disabled".to_string(),
        ))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::Tensor;

    #[test]
    fn test_device_transfer() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let same_device = tensor
            .to_device(DeviceType::Cpu)
            .expect("device transfer should succeed");
        assert_eq!(same_device.device(), DeviceType::Cpu);

        assert_eq!(
            tensor.get_transfer_strategy(DeviceType::Cpu),
            TransferStrategy::NoTransfer
        );
        assert_eq!(
            tensor.get_transfer_strategy(DeviceType::Cuda(0)),
            TransferStrategy::DirectTransfer
        );
    }

    #[test]
    fn test_operation_scheduler() {
        let mut scheduler = OperationScheduler::new();

        let op1 = scheduler
            .schedule_operation(DeviceType::Cpu, OperationType::Compute, 5, vec![])
            .expect("operation should succeed");

        let op2 = scheduler
            .schedule_operation(DeviceType::Cpu, OperationType::Compute, 10, vec![])
            .expect("operation should succeed");

        assert_eq!(
            scheduler
                .execute_next_operation(DeviceType::Cpu)
                .expect("operation execution should succeed"),
            Some(op2)
        );
        assert_eq!(
            scheduler
                .execute_next_operation(DeviceType::Cpu)
                .expect("operation execution should succeed"),
            Some(op1)
        );
    }

    #[test]
    fn test_transfer_efficiency() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        assert!(tensor.can_transfer_efficiently(DeviceType::Cpu));
        assert!(tensor.can_transfer_efficiently(DeviceType::Cuda(0)));
        assert!(tensor.can_transfer_efficiently(DeviceType::Metal(0)));
    }
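
    // Additional check of the fall-through path: WebGPU is not matched explicitly in
    // `can_transfer_efficiently` or `get_transfer_strategy`, so CPU-to-WebGPU transfers
    // report the generic strategy. Assumes `DeviceType::Wgpu` carries a device index
    // like the other variants.
    #[test]
    fn test_wgpu_transfer_uses_generic_path() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        assert!(!tensor.can_transfer_efficiently(DeviceType::Wgpu(0)));
        assert_eq!(
            tensor.get_transfer_strategy(DeviceType::Wgpu(0)),
            TransferStrategy::ThroughCpu
        );
    }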

    #[test]
    fn test_device_optimization_defaults() {
        let cpu_opt = CpuOptimization::default();
        assert!(cpu_opt.use_simd);
        assert!(cpu_opt.cache_friendly);
        assert!(cpu_opt.numa_aware);

        let gpu_opt = GpuOptimization::default();
        assert!(gpu_opt.use_pinned_memory);
        assert_eq!(gpu_opt.stream_count, 4);
        assert!(!gpu_opt.mixed_precision);
    }
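
    // Small sanity check of the queue helpers: scheduling grows the per-device queue
    // and `clear_device_queue` resets the reported length to zero. Uses only the
    // types defined in this module.
    #[test]
    fn test_scheduler_queue_helpers() {
        let mut scheduler = OperationScheduler::new();
        assert_eq!(scheduler.get_queue_length(DeviceType::Cpu), 0);

        scheduler
            .schedule_operation(DeviceType::Cpu, OperationType::Transfer, 1, vec![])
            .expect("scheduling should succeed");
        assert_eq!(scheduler.get_queue_length(DeviceType::Cpu), 1);

        scheduler.clear_device_queue(DeviceType::Cpu);
        assert_eq!(scheduler.get_queue_length(DeviceType::Cpu), 0);
    }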

    #[test]
    fn test_global_scheduler() {
        initialize_global_scheduler().expect("scheduler initialization should succeed");

        {
            let mut scheduler = get_global_scheduler();
            let scheduler = scheduler
                .as_mut()
                .expect("mutable reference should be available");

            let op_id = scheduler
                .schedule_operation(DeviceType::Cpu, OperationType::Compute, 5, vec![])
                .expect("scheduling should succeed");

            assert_eq!(scheduler.get_queue_length(DeviceType::Cpu), 1);
            assert_eq!(
                scheduler
                    .execute_next_operation(DeviceType::Cpu)
                    .expect("operation execution should succeed"),
                Some(op_id)
            );
        }
    }
}