1use crate::{Device, Result, Shape, Tensor, TensorError};
7use std::collections::HashMap;
8use std::sync::{Arc, RwLock};
9
/// Compute backends that tensor kernels can target.
///
/// Variants exist only when the corresponding Cargo feature (and target OS,
/// where noted) is enabled at compile time, so matching on this enum is
/// feature-dependent.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BackendType {
    /// Baseline scalar CPU implementation; always compiled in.
    Cpu,
    /// SIMD-accelerated CPU kernels.
    #[cfg(feature = "simd")]
    SimdCpu,
    /// BLAS/LAPACK-backed kernels.
    #[cfg(feature = "blas")]
    Blas,
    /// Generic GPU backend.
    #[cfg(feature = "gpu")]
    Gpu,
    /// NVIDIA CUDA backend (Linux and Windows targets only).
    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
    Cuda,
    /// AMD ROCm backend.
    #[cfg(feature = "rocm")]
    Rocm,
    /// Apple Metal backend (macOS targets only).
    #[cfg(all(feature = "metal", target_os = "macos"))]
    Metal,
}
34
35impl BackendType {
36 pub fn is_available(&self) -> bool {
38 match self {
39 BackendType::Cpu => true,
40 #[cfg(feature = "simd")]
41 BackendType::SimdCpu => true,
42 #[cfg(feature = "blas")]
43 BackendType::Blas => crate::ops::lapack::is_lapack_available(),
44 #[cfg(feature = "gpu")]
45 BackendType::Gpu => true, #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
47 BackendType::Cuda => crate::gpu::cuda_kernels::is_cuda_available(),
48 #[cfg(feature = "rocm")]
49 BackendType::Rocm => false, #[cfg(all(feature = "metal", target_os = "macos"))]
51 BackendType::Metal => true,
52 }
53 }
54
55 pub fn priority(&self) -> u8 {
57 match self {
58 BackendType::Cpu => 0,
59 #[cfg(feature = "simd")]
60 BackendType::SimdCpu => 10,
61 #[cfg(feature = "blas")]
62 BackendType::Blas => 20,
63 #[cfg(feature = "gpu")]
64 BackendType::Gpu => 30,
65 #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
66 BackendType::Cuda => 40,
67 #[cfg(feature = "rocm")]
68 BackendType::Rocm => 40,
69 #[cfg(all(feature = "metal", target_os = "macos"))]
70 BackendType::Metal => 50,
71 }
72 }
73
74 pub fn from_device(device: &Device) -> Self {
76 match device {
77 Device::Cpu => BackendType::Cpu,
78 #[cfg(feature = "gpu")]
79 Device::Gpu(_) => BackendType::Gpu,
80 #[cfg(feature = "rocm")]
81 Device::Rocm(_) => BackendType::Rocm,
82 }
83 }
84}
85
/// Metadata describing a registered operation, independent of any concrete
/// kernel implementation.
#[derive(Debug, Clone)]
pub struct OperationDescriptor {
    /// Unique operation name; used as the registry key.
    pub name: String,
    /// Free-form grouping label (e.g. "unary", "binary").
    pub category: String,
    /// Descriptor version string; defaults to "1.0.0".
    pub version: String,
    /// Element dtypes the operation accepts.
    pub supported_dtypes: Vec<crate::DType>,
    /// Minimum tensor rank accepted, if restricted.
    pub min_rank: Option<usize>,
    /// Maximum tensor rank accepted, if restricted.
    pub max_rank: Option<usize>,
    /// Whether the operation supports broadcasting of operands.
    pub supports_broadcast: bool,
    /// Whether the operation can run in place on its input.
    pub supports_inplace: bool,
}
106
107impl OperationDescriptor {
108 pub fn new(name: impl Into<String>, category: impl Into<String>) -> Self {
110 Self {
111 name: name.into(),
112 category: category.into(),
113 version: "1.0.0".to_string(),
114 supported_dtypes: vec![crate::DType::Float32, crate::DType::Float64],
115 min_rank: None,
116 max_rank: None,
117 supports_broadcast: false,
118 supports_inplace: false,
119 }
120 }
121
122 pub fn with_dtypes(mut self, dtypes: Vec<crate::DType>) -> Self {
124 self.supported_dtypes = dtypes;
125 self
126 }
127
128 pub fn with_rank_range(mut self, min: Option<usize>, max: Option<usize>) -> Self {
130 self.min_rank = min;
131 self.max_rank = max;
132 self
133 }
134
135 pub fn with_broadcast(mut self) -> Self {
137 self.supports_broadcast = true;
138 self
139 }
140
141 pub fn with_inplace(mut self) -> Self {
143 self.supports_inplace = true;
144 self
145 }
146}
147
/// Kernel entry point for single-input operations.
pub type UnaryKernelFn<T> = fn(&Tensor<T>) -> Result<Tensor<T>>;

/// Kernel entry point for two-input operations.
pub type BinaryKernelFn<T> = fn(&Tensor<T>, &Tensor<T>) -> Result<Tensor<T>>;
/// A concrete kernel bound to one backend.
///
/// Exactly one of `unary_fn`/`binary_fn` is populated by the constructors
/// below; dispatch picks whichever matches the call arity.
#[derive(Clone)]
pub struct KernelImplementation<T> {
    /// Backend this kernel runs on.
    pub backend: BackendType,
    /// Single-input entry point, if this is a unary kernel.
    pub unary_fn: Option<UnaryKernelFn<T>>,
    /// Two-input entry point, if this is a binary kernel.
    pub binary_fn: Option<BinaryKernelFn<T>>,
}
161
162impl<T> KernelImplementation<T> {
163 pub fn unary(backend: BackendType, func: UnaryKernelFn<T>) -> Self {
165 Self {
166 backend,
167 unary_fn: Some(func),
168 binary_fn: None,
169 }
170 }
171
172 pub fn binary(backend: BackendType, func: BinaryKernelFn<T>) -> Self {
174 Self {
175 backend,
176 unary_fn: None,
177 binary_fn: Some(func),
178 }
179 }
180}
181
/// Registry entry pairing an operation's metadata with its kernels.
struct RegisteredOperation<T> {
    // Static description of the operation (name, category, constraints).
    descriptor: OperationDescriptor,
    // All kernels registered for this operation, in registration order.
    kernels: Vec<KernelImplementation<T>>,
}
187
188impl<T> RegisteredOperation<T> {
189 fn new(descriptor: OperationDescriptor) -> Self {
190 Self {
191 descriptor,
192 kernels: Vec::new(),
193 }
194 }
195
196 fn add_kernel(&mut self, kernel: KernelImplementation<T>) {
197 self.kernels.push(kernel);
198 }
199
200 fn select_kernel(&self, device: &Device) -> Option<&KernelImplementation<T>> {
202 let preferred_backend = BackendType::from_device(device);
203
204 if let Some(kernel) = self
206 .kernels
207 .iter()
208 .find(|k| k.backend == preferred_backend && k.backend.is_available())
209 {
210 return Some(kernel);
211 }
212
213 self.kernels
215 .iter()
216 .filter(|k| k.backend.is_available())
217 .max_by_key(|k| k.backend.priority())
218 }
219}
220
/// Thread-safe registry mapping operation names to their kernels.
///
/// Cloned handles share the same underlying map via `Arc<RwLock<..>>`.
pub struct DispatchRegistry<T> {
    // Guarded map from operation name to its descriptor + kernels.
    operations: Arc<RwLock<HashMap<String, RegisteredOperation<T>>>>,
}
225
impl<T> Default for DispatchRegistry<T> {
    /// Equivalent to [`DispatchRegistry::new`]: an empty registry.
    fn default() -> Self {
        Self::new()
    }
}
231
232impl<T> DispatchRegistry<T> {
233 pub fn new() -> Self {
235 Self {
236 operations: Arc::new(RwLock::new(HashMap::new())),
237 }
238 }
239
240 pub fn register_operation(&self, descriptor: OperationDescriptor) -> Result<()> {
242 let mut ops = self
243 .operations
244 .write()
245 .expect("write lock should not be poisoned");
246
247 if ops.contains_key(&descriptor.name) {
248 return Err(TensorError::invalid_argument(format!(
249 "Operation '{}' is already registered",
250 descriptor.name
251 )));
252 }
253
254 ops.insert(
255 descriptor.name.clone(),
256 RegisteredOperation::new(descriptor),
257 );
258 Ok(())
259 }
260
261 pub fn register_kernel(
263 &self,
264 operation_name: &str,
265 kernel: KernelImplementation<T>,
266 ) -> Result<()> {
267 let mut ops = self
268 .operations
269 .write()
270 .expect("write lock should not be poisoned");
271
272 let op = ops.get_mut(operation_name).ok_or_else(|| {
273 TensorError::invalid_argument(format!(
274 "Operation '{}' not found. Register the operation first.",
275 operation_name
276 ))
277 })?;
278
279 op.add_kernel(kernel);
280 Ok(())
281 }
282
283 pub fn dispatch_unary(&self, operation_name: &str, input: &Tensor<T>) -> Result<Tensor<T>> {
285 let ops = self
286 .operations
287 .read()
288 .expect("read lock should not be poisoned");
289
290 let op = ops.get(operation_name).ok_or_else(|| {
291 TensorError::invalid_argument(format!(
292 "Operation '{}' not found in registry",
293 operation_name
294 ))
295 })?;
296
297 let kernel = op.select_kernel(input.device()).ok_or_else(|| {
298 TensorError::invalid_argument(format!(
299 "No available kernel for operation '{}' on device {:?}",
300 operation_name,
301 input.device()
302 ))
303 })?;
304
305 let kernel_fn = kernel.unary_fn.ok_or_else(|| {
306 TensorError::invalid_argument(format!(
307 "Operation '{}' does not support unary execution",
308 operation_name
309 ))
310 })?;
311
312 kernel_fn(input)
313 }
314
315 pub fn dispatch_binary(
317 &self,
318 operation_name: &str,
319 lhs: &Tensor<T>,
320 rhs: &Tensor<T>,
321 ) -> Result<Tensor<T>> {
322 if lhs.device() != rhs.device() {
324 return Err(TensorError::device_mismatch(
325 operation_name,
326 &format!("{:?}", lhs.device()),
327 &format!("{:?}", rhs.device()),
328 ));
329 }
330
331 let ops = self
332 .operations
333 .read()
334 .expect("read lock should not be poisoned");
335
336 let op = ops.get(operation_name).ok_or_else(|| {
337 TensorError::invalid_argument(format!(
338 "Operation '{}' not found in registry",
339 operation_name
340 ))
341 })?;
342
343 let kernel = op.select_kernel(lhs.device()).ok_or_else(|| {
344 TensorError::invalid_argument(format!(
345 "No available kernel for operation '{}' on device {:?}",
346 operation_name,
347 lhs.device()
348 ))
349 })?;
350
351 let kernel_fn = kernel.binary_fn.ok_or_else(|| {
352 TensorError::invalid_argument(format!(
353 "Operation '{}' does not support binary execution",
354 operation_name
355 ))
356 })?;
357
358 kernel_fn(lhs, rhs)
359 }
360
361 pub fn get_operation(&self, name: &str) -> Option<OperationDescriptor> {
363 let ops = self
364 .operations
365 .read()
366 .expect("read lock should not be poisoned");
367 ops.get(name).map(|op| op.descriptor.clone())
368 }
369
370 pub fn list_operations(&self) -> Vec<String> {
372 let ops = self
373 .operations
374 .read()
375 .expect("read lock should not be poisoned");
376 ops.keys().cloned().collect()
377 }
378
379 pub fn available_backends(&self, operation_name: &str) -> Vec<BackendType> {
381 let ops = self
382 .operations
383 .read()
384 .expect("read lock should not be poisoned");
385
386 if let Some(op) = ops.get(operation_name) {
387 op.kernels
388 .iter()
389 .filter(|k| k.backend.is_available())
390 .map(|k| k.backend)
391 .collect()
392 } else {
393 Vec::new()
394 }
395 }
396}
397
/// Registers an [`OperationDescriptor`] on a registry, panicking on failure.
///
/// Forms:
/// - `register_operation!(registry, name, category)`
/// - `register_operation!(registry, name, category, dtypes: [..])`
/// - `register_operation!(registry, name, category, rank: min, max)`
#[macro_export]
macro_rules! register_operation {
    // Basic form: name + category with default descriptor settings.
    ($registry:expr, $name:expr, $category:expr) => {
        $registry.register_operation(
            $crate::OperationDescriptor::new($name, $category)
        ).expect("operation registration should succeed");
    };
    // With an explicit dtype list.
    ($registry:expr, $name:expr, $category:expr, dtypes: [$($dtype:expr),*]) => {
        $registry.register_operation(
            $crate::OperationDescriptor::new($name, $category)
                .with_dtypes(vec![$($dtype),*])
        ).expect("operation registration with dtypes should succeed");
    };
    // With an inclusive rank range.
    ($registry:expr, $name:expr, $category:expr, rank: $min:expr, $max:expr) => {
        $registry.register_operation(
            $crate::OperationDescriptor::new($name, $category)
                .with_rank_range(Some($min), Some($max))
        ).expect("operation registration with rank range should succeed");
    };
}
419
/// Registers a unary kernel for an existing operation, panicking on failure
/// (e.g. if the operation was never registered).
#[macro_export]
macro_rules! register_unary_kernel {
    ($registry:expr, $op_name:expr, $backend:expr, $func:expr) => {
        $registry
            .register_kernel(
                $op_name,
                $crate::KernelImplementation::unary($backend, $func),
            )
            .expect("unary kernel registration should succeed");
    };
}
432
/// Registers a binary kernel for an existing operation, panicking on failure
/// (e.g. if the operation was never registered).
#[macro_export]
macro_rules! register_binary_kernel {
    ($registry:expr, $op_name:expr, $backend:expr, $func:expr) => {
        $registry
            .register_kernel(
                $op_name,
                $crate::KernelImplementation::binary($backend, $func),
            )
            .expect("binary kernel registration should succeed");
    };
}
445
/// Summary statistics over dispatch-overhead timing samples, in nanoseconds.
#[derive(Debug, Clone)]
pub struct DispatchBenchmarkResult {
    /// Fastest observed sample.
    pub min_ns: u64,
    /// Slowest observed sample.
    pub max_ns: u64,
    /// Integer mean over all samples (truncating division).
    pub avg_ns: u64,
    /// Sample at the 95th-percentile index.
    pub p95_ns: u64,
    /// Number of samples summarized.
    pub sample_count: usize,
}

impl DispatchBenchmarkResult {
    /// Builds summary statistics from samples sorted in ascending order.
    ///
    /// Returns `None` for an empty slice. Sorting is the caller's
    /// responsibility: min/max/p95 are meaningless on unsorted input.
    pub fn from_sorted_samples(samples: &[u64]) -> Option<Self> {
        if samples.is_empty() {
            return None;
        }

        // The guard above guarantees non-emptiness, so index directly rather
        // than masking a broken invariant with `unwrap_or(&0)`.
        let min_ns = samples[0];
        let max_ns = samples[samples.len() - 1];

        // Truncating integer mean; u64 overflow of the sum would require an
        // astronomically large sample set, so a plain `sum` is acceptable.
        let sum: u64 = samples.iter().sum();
        let avg_ns = sum / samples.len() as u64;

        // 95th percentile by truncated index, clamped to the last element.
        // Kept as a float product to preserve the existing rounding behavior
        // exactly (0.95 is not exactly representable in f64).
        let p95_idx = ((samples.len() as f64 * 0.95) as usize).min(samples.len() - 1);
        let p95_ns = samples[p95_idx];

        Some(Self {
            min_ns,
            max_ns,
            avg_ns,
            p95_ns,
            sample_count: samples.len(),
        })
    }
}
492
493impl<T> DispatchRegistry<T> {
494 pub fn benchmark_overhead(&self) -> DispatchBenchmarkResult {
503 const SAMPLE_COUNT: usize = 1_000;
504 const PROBE_NAME: &str = "__overhead_probe__";
505
506 let mut samples: Vec<u64> = Vec::with_capacity(SAMPLE_COUNT);
507
508 for _ in 0..SAMPLE_COUNT {
509 let start = std::time::Instant::now();
510 let _ = self.get_operation(PROBE_NAME);
512 let elapsed_ns = start.elapsed().as_nanos() as u64;
513 samples.push(elapsed_ns);
514 }
515
516 samples.sort_unstable();
517
518 DispatchBenchmarkResult::from_sorted_samples(&samples).unwrap_or(DispatchBenchmarkResult {
520 min_ns: 0,
521 max_ns: 0,
522 avg_ns: 0,
523 p95_ns: 0,
524 sample_count: 0,
525 })
526 }
527}
528
use lazy_static::lazy_static;

// NOTE(review): std::sync::LazyLock (stable since Rust 1.80) could replace
// lazy_static here with no API change for callers — confirm the crate's MSRV
// permits it before switching.
lazy_static! {
    /// Process-wide dispatch registry for `f32` kernels.
    pub static ref F32_REGISTRY: DispatchRegistry<f32> = DispatchRegistry::new();

    /// Process-wide dispatch registry for `f64` kernels.
    pub static ref F64_REGISTRY: DispatchRegistry<f64> = DispatchRegistry::new();

    /// Process-wide dispatch registry for `i32` kernels.
    pub static ref I32_REGISTRY: DispatchRegistry<i32> = DispatchRegistry::new();
}
542
/// Returns the process-wide registry for element type `T`, if one exists.
///
/// Only `f32`, `f64`, and `i32` registries are provided; any other `T`
/// yields `None`.
pub fn get_registry<T: 'static>() -> Option<&'static DispatchRegistry<T>> {
    use std::any::TypeId;

    let type_id = TypeId::of::<T>();

    if type_id == TypeId::of::<f32>() {
        // SAFETY: the TypeId comparison above proves T == f32 (both are
        // 'static), so DispatchRegistry<f32> and DispatchRegistry<T> are the
        // same type and the pointer cast is an identity conversion of a
        // reference that lives for 'static.
        Some(unsafe {
            &*(&*F32_REGISTRY as *const DispatchRegistry<f32> as *const DispatchRegistry<T>)
        })
    } else if type_id == TypeId::of::<f64>() {
        // SAFETY: TypeId check proves T == f64; identity pointer cast as above.
        Some(unsafe {
            &*(&*F64_REGISTRY as *const DispatchRegistry<f64> as *const DispatchRegistry<T>)
        })
    } else if type_id == TypeId::of::<i32>() {
        // SAFETY: TypeId check proves T == i32; identity pointer cast as above.
        Some(unsafe {
            &*(&*I32_REGISTRY as *const DispatchRegistry<i32> as *const DispatchRegistry<T>)
        })
    } else {
        None
    }
}
568
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_backend_type_priority() {
        // CPU is the baseline tier; pin its documented priority instead of
        // the previous vacuous `x < x + 1` assertion.
        assert_eq!(BackendType::Cpu.priority(), 0);

        #[cfg(feature = "simd")]
        assert!(BackendType::SimdCpu.priority() > BackendType::Cpu.priority());
    }

    #[test]
    fn test_operation_descriptor() {
        let desc = OperationDescriptor::new("test_op", "binary")
            .with_dtypes(vec![crate::DType::Float32])
            .with_broadcast()
            .with_rank_range(Some(1), Some(4));

        assert_eq!(desc.name, "test_op");
        assert_eq!(desc.category, "binary");
        assert!(desc.supports_broadcast);
        assert_eq!(desc.min_rank, Some(1));
        assert_eq!(desc.max_rank, Some(4));
    }

    #[test]
    fn test_registry_creation() {
        let registry: DispatchRegistry<f32> = DispatchRegistry::new();
        assert_eq!(registry.list_operations().len(), 0);
    }

    #[test]
    fn test_operation_registration() {
        let registry: DispatchRegistry<f32> = DispatchRegistry::new();

        let desc = OperationDescriptor::new("add", "binary");
        registry
            .register_operation(desc)
            .expect("test: register_operation should succeed");

        assert_eq!(registry.list_operations().len(), 1);
        assert!(registry.get_operation("add").is_some());
    }

    #[test]
    fn test_duplicate_registration_fails() {
        let registry: DispatchRegistry<f32> = DispatchRegistry::new();

        let desc1 = OperationDescriptor::new("add", "binary");
        let desc2 = OperationDescriptor::new("add", "binary");

        registry
            .register_operation(desc1)
            .expect("test: register_operation should succeed");
        assert!(registry.register_operation(desc2).is_err());
    }

    #[test]
    fn test_kernel_registration() {
        let registry: DispatchRegistry<f32> = DispatchRegistry::new();

        let desc = OperationDescriptor::new("abs", "unary");
        registry
            .register_operation(desc)
            .expect("test: register_operation should succeed");

        // Minimal CPU reference kernel used only by this test.
        fn abs_cpu(x: &Tensor<f32>) -> Result<Tensor<f32>> {
            let data = x.data();
            let abs_data: Vec<f32> = data.iter().map(|v| v.abs()).collect();
            let array = scirs2_core::ndarray::ArrayD::from_shape_vec(x.shape().dims(), abs_data)
                .expect("test: operation should succeed");
            Ok(Tensor::from_array(array))
        }

        let kernel = KernelImplementation::unary(BackendType::Cpu, abs_cpu);
        registry
            .register_kernel("abs", kernel)
            .expect("test: register_kernel should succeed");

        assert_eq!(registry.available_backends("abs").len(), 1);
    }

    #[test]
    fn test_unary_dispatch() {
        let registry: DispatchRegistry<f32> = DispatchRegistry::new();

        let desc = OperationDescriptor::new("negate", "unary");
        registry
            .register_operation(desc)
            .expect("test: register_operation should succeed");

        fn negate_cpu(x: &Tensor<f32>) -> Result<Tensor<f32>> {
            let data = x.data();
            let neg_data: Vec<f32> = data.iter().map(|v| -v).collect();
            let array = scirs2_core::ndarray::ArrayD::from_shape_vec(x.shape().dims(), neg_data)
                .expect("test: operation should succeed");
            Ok(Tensor::from_array(array))
        }

        let kernel = KernelImplementation::unary(BackendType::Cpu, negate_cpu);
        registry
            .register_kernel("negate", kernel)
            .expect("test: register_kernel should succeed");

        let input = Tensor::from_array(array![1.0f32, 2.0, 3.0].into_dyn());
        let result = registry
            .dispatch_unary("negate", &input)
            .expect("test: dispatch_unary should succeed");

        assert_eq!(result.data(), &[-1.0f32, -2.0, -3.0]);
    }

    #[test]
    fn test_binary_dispatch() {
        let registry: DispatchRegistry<f32> = DispatchRegistry::new();

        let desc = OperationDescriptor::new("add", "binary");
        registry
            .register_operation(desc)
            .expect("test: register_operation should succeed");

        fn add_cpu(a: &Tensor<f32>, b: &Tensor<f32>) -> Result<Tensor<f32>> {
            let a_data = a.data();
            let b_data = b.data();
            let sum_data: Vec<f32> = a_data
                .iter()
                .zip(b_data.iter())
                .map(|(x, y)| x + y)
                .collect();
            let array = scirs2_core::ndarray::ArrayD::from_shape_vec(a.shape().dims(), sum_data)
                .expect("test: operation should succeed");
            Ok(Tensor::from_array(array))
        }

        let kernel = KernelImplementation::binary(BackendType::Cpu, add_cpu);
        registry
            .register_kernel("add", kernel)
            .expect("test: register_kernel should succeed");

        let a = Tensor::from_array(array![1.0f32, 2.0, 3.0].into_dyn());
        let b = Tensor::from_array(array![4.0f32, 5.0, 6.0].into_dyn());
        let result = registry
            .dispatch_binary("add", &a, &b)
            .expect("test: dispatch_binary should succeed");

        assert_eq!(result.data(), &[5.0f32, 7.0, 9.0]);
    }

    #[test]
    fn test_same_device_binary_dispatch_succeeds() {
        // Renamed from `test_device_mismatch_error`: both tensors live on the
        // CPU, so this verifies the mismatch guard does NOT trigger. A real
        // mismatch cannot be constructed without GPU features enabled.
        let registry: DispatchRegistry<f32> = DispatchRegistry::new();

        let desc = OperationDescriptor::new("add", "binary");
        registry
            .register_operation(desc)
            .expect("test: register_operation should succeed");

        fn add_cpu(a: &Tensor<f32>, _b: &Tensor<f32>) -> Result<Tensor<f32>> {
            Ok(a.clone())
        }

        let kernel = KernelImplementation::binary(BackendType::Cpu, add_cpu);
        registry
            .register_kernel("add", kernel)
            .expect("test: register_kernel should succeed");

        let a = Tensor::from_array(array![1.0f32].into_dyn());
        let b = Tensor::from_array(array![2.0f32].into_dyn());

        let result = registry.dispatch_binary("add", &a, &b);
        assert!(result.is_ok());
    }

    #[test]
    fn test_global_registry_access() {
        let registry = get_registry::<f32>();
        assert!(registry.is_some());

        let registry = get_registry::<f64>();
        assert!(registry.is_some());

        let registry = get_registry::<i32>();
        assert!(registry.is_some());
    }
}