// tenflowers_core/dispatch_registry.rs
//! Unified Operation Dispatch Registry for TenfloweRS
//!
//! This module provides a centralized system for registering and dispatching
//! tensor operations across different backends (CPU, GPU, BLAS, etc.) with
//! feature gating and capability detection.
6use crate::{Device, Result, Shape, Tensor, TensorError};
7use std::collections::HashMap;
8use std::sync::{Arc, RwLock};
9
/// Backend type for kernel implementations.
///
/// Variants are compile-time gated by cargo features (and, where relevant,
/// target OS), so only backends built into this binary exist as variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BackendType {
    /// CPU implementation (always available)
    Cpu,
    /// SIMD-optimized CPU implementation
    #[cfg(feature = "simd")]
    SimdCpu,
    /// BLAS-accelerated implementation
    #[cfg(feature = "blas")]
    Blas,
    /// GPU implementation via WGPU
    #[cfg(feature = "gpu")]
    Gpu,
    /// CUDA implementation
    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
    Cuda,
    /// ROCm implementation
    #[cfg(feature = "rocm")]
    Rocm,
    /// Metal Performance Shaders
    #[cfg(all(feature = "metal", target_os = "macos"))]
    Metal,
}
34
impl BackendType {
    /// Check if this backend is available at runtime.
    ///
    /// Feature gating already removes unbuilt variants at compile time; this
    /// adds the runtime probe where one exists (LAPACK, CUDA).
    pub fn is_available(&self) -> bool {
        match self {
            BackendType::Cpu => true,
            #[cfg(feature = "simd")]
            BackendType::SimdCpu => true,
            #[cfg(feature = "blas")]
            BackendType::Blas => crate::ops::lapack::is_lapack_available(),
            #[cfg(feature = "gpu")]
            BackendType::Gpu => true, // WGPU availability checked at context creation
            #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
            BackendType::Cuda => crate::gpu::cuda_kernels::is_cuda_available(),
            #[cfg(feature = "rocm")]
            BackendType::Rocm => false, // NOTE(v0.2): Implement ROCm availability check
            #[cfg(all(feature = "metal", target_os = "macos"))]
            BackendType::Metal => true,
        }
    }

    /// Get priority for backend selection (higher = preferred).
    ///
    /// NOTE(review): Cuda and Rocm share priority 40; if both are ever built
    /// and available, tie-breaking falls to kernel-list iteration order in
    /// `select_kernel` — confirm that is intended.
    pub fn priority(&self) -> u8 {
        match self {
            BackendType::Cpu => 0,
            #[cfg(feature = "simd")]
            BackendType::SimdCpu => 10,
            #[cfg(feature = "blas")]
            BackendType::Blas => 20,
            #[cfg(feature = "gpu")]
            BackendType::Gpu => 30,
            #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
            BackendType::Cuda => 40,
            #[cfg(feature = "rocm")]
            BackendType::Rocm => 40,
            #[cfg(all(feature = "metal", target_os = "macos"))]
            BackendType::Metal => 50,
        }
    }

    /// Get backend from device.
    ///
    /// NOTE(review): only Cpu/Gpu/Rocm device variants are mapped here, so
    /// Cuda/Metal backends are never the "preferred" backend — they can only
    /// be chosen via the priority fallback. Confirm `Device` has no
    /// Cuda/Metal variants.
    pub fn from_device(device: &Device) -> Self {
        match device {
            Device::Cpu => BackendType::Cpu,
            #[cfg(feature = "gpu")]
            Device::Gpu(_) => BackendType::Gpu,
            #[cfg(feature = "rocm")]
            Device::Rocm(_) => BackendType::Rocm,
        }
    }
}
85
/// Operation metadata and description.
///
/// Built with [`OperationDescriptor::new`] plus the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct OperationDescriptor {
    /// Unique operation name (used as the registry key)
    pub name: String,
    /// Operation category (e.g., "binary", "reduction", "matmul")
    pub category: String,
    /// Semantic version (defaults to "1.0.0")
    pub version: String,
    /// Supported data types
    pub supported_dtypes: Vec<crate::DType>,
    /// Minimum supported rank (`None` = unbounded)
    pub min_rank: Option<usize>,
    /// Maximum supported rank (`None` = unbounded)
    pub max_rank: Option<usize>,
    /// Whether operation supports broadcasting
    pub supports_broadcast: bool,
    /// Whether operation is in-place capable
    pub supports_inplace: bool,
}
106
107impl OperationDescriptor {
108    /// Create a new operation descriptor
109    pub fn new(name: impl Into<String>, category: impl Into<String>) -> Self {
110        Self {
111            name: name.into(),
112            category: category.into(),
113            version: "1.0.0".to_string(),
114            supported_dtypes: vec![crate::DType::Float32, crate::DType::Float64],
115            min_rank: None,
116            max_rank: None,
117            supports_broadcast: false,
118            supports_inplace: false,
119        }
120    }
121
122    /// Set supported data types
123    pub fn with_dtypes(mut self, dtypes: Vec<crate::DType>) -> Self {
124        self.supported_dtypes = dtypes;
125        self
126    }
127
128    /// Set rank constraints
129    pub fn with_rank_range(mut self, min: Option<usize>, max: Option<usize>) -> Self {
130        self.min_rank = min;
131        self.max_rank = max;
132        self
133    }
134
135    /// Enable broadcasting support
136    pub fn with_broadcast(mut self) -> Self {
137        self.supports_broadcast = true;
138        self
139    }
140
141    /// Enable in-place support
142    pub fn with_inplace(mut self) -> Self {
143        self.supports_inplace = true;
144        self
145    }
146}
147
/// Kernel implementation function signature for unary operations.
/// Plain `fn` pointers (not closures) keep kernel tables `Copy`-cheap.
pub type UnaryKernelFn<T> = fn(&Tensor<T>) -> Result<Tensor<T>>;

/// Kernel implementation function signature for binary operations.
pub type BinaryKernelFn<T> = fn(&Tensor<T>, &Tensor<T>) -> Result<Tensor<T>>;
153
/// Kernel implementation for a specific backend.
///
/// As built by [`KernelImplementation::unary`] / [`KernelImplementation::binary`],
/// exactly one of `unary_fn` / `binary_fn` is set.
#[derive(Clone)]
pub struct KernelImplementation<T> {
    /// Backend this kernel targets.
    pub backend: BackendType,
    /// Entry point for unary dispatch, if this is a unary kernel.
    pub unary_fn: Option<UnaryKernelFn<T>>,
    /// Entry point for binary dispatch, if this is a binary kernel.
    pub binary_fn: Option<BinaryKernelFn<T>>,
}
161
162impl<T> KernelImplementation<T> {
163    /// Create a new unary kernel implementation
164    pub fn unary(backend: BackendType, func: UnaryKernelFn<T>) -> Self {
165        Self {
166            backend,
167            unary_fn: Some(func),
168            binary_fn: None,
169        }
170    }
171
172    /// Create a new binary kernel implementation
173    pub fn binary(backend: BackendType, func: BinaryKernelFn<T>) -> Self {
174        Self {
175            backend,
176            unary_fn: None,
177            binary_fn: Some(func),
178        }
179    }
180}
181
/// Registered operation with all backend implementations.
struct RegisteredOperation<T> {
    /// Metadata describing the operation.
    descriptor: OperationDescriptor,
    /// Kernels in registration order; selection filters by availability.
    kernels: Vec<KernelImplementation<T>>,
}
187
impl<T> RegisteredOperation<T> {
    /// Wrap a descriptor with an initially empty kernel list.
    fn new(descriptor: OperationDescriptor) -> Self {
        Self {
            descriptor,
            kernels: Vec::new(),
        }
    }

    /// Append a kernel; duplicates for the same backend are not rejected here,
    /// selection simply scans the whole list.
    fn add_kernel(&mut self, kernel: KernelImplementation<T>) {
        self.kernels.push(kernel);
    }

    /// Select the best available kernel for the given device.
    ///
    /// Preference order: (1) the first kernel whose backend matches the
    /// device and reports itself available; (2) otherwise the available
    /// kernel with the highest backend priority.
    ///
    /// NOTE(review): the fallback in step (2) ignores the device entirely —
    /// e.g. a CPU tensor could be routed to a GPU kernel when no CPU kernel
    /// is registered. Confirm kernels handle (or transfer) cross-device
    /// inputs before relying on the fallback.
    fn select_kernel(&self, device: &Device) -> Option<&KernelImplementation<T>> {
        let preferred_backend = BackendType::from_device(device);

        // First, try to find the preferred backend
        if let Some(kernel) = self
            .kernels
            .iter()
            .find(|k| k.backend == preferred_backend && k.backend.is_available())
        {
            return Some(kernel);
        }

        // Fall back to the highest priority available backend
        self.kernels
            .iter()
            .filter(|k| k.backend.is_available())
            .max_by_key(|k| k.backend.priority())
    }
}
220
/// Global operation dispatch registry.
///
/// Thread-safe: the operation table lives behind `Arc<RwLock<..>>`, so
/// lookups take a shared read lock and registrations take the write lock.
pub struct DispatchRegistry<T> {
    /// Operation name -> registered operation (descriptor + kernels).
    operations: Arc<RwLock<HashMap<String, RegisteredOperation<T>>>>,
}
225
impl<T> Default for DispatchRegistry<T> {
    /// Equivalent to [`DispatchRegistry::new`]: an empty registry.
    fn default() -> Self {
        Self::new()
    }
}
231
232impl<T> DispatchRegistry<T> {
233    /// Create a new dispatch registry
234    pub fn new() -> Self {
235        Self {
236            operations: Arc::new(RwLock::new(HashMap::new())),
237        }
238    }
239
240    /// Register a new operation
241    pub fn register_operation(&self, descriptor: OperationDescriptor) -> Result<()> {
242        let mut ops = self
243            .operations
244            .write()
245            .expect("write lock should not be poisoned");
246
247        if ops.contains_key(&descriptor.name) {
248            return Err(TensorError::invalid_argument(format!(
249                "Operation '{}' is already registered",
250                descriptor.name
251            )));
252        }
253
254        ops.insert(
255            descriptor.name.clone(),
256            RegisteredOperation::new(descriptor),
257        );
258        Ok(())
259    }
260
261    /// Register a kernel implementation for an operation
262    pub fn register_kernel(
263        &self,
264        operation_name: &str,
265        kernel: KernelImplementation<T>,
266    ) -> Result<()> {
267        let mut ops = self
268            .operations
269            .write()
270            .expect("write lock should not be poisoned");
271
272        let op = ops.get_mut(operation_name).ok_or_else(|| {
273            TensorError::invalid_argument(format!(
274                "Operation '{}' not found. Register the operation first.",
275                operation_name
276            ))
277        })?;
278
279        op.add_kernel(kernel);
280        Ok(())
281    }
282
283    /// Dispatch a unary operation
284    pub fn dispatch_unary(&self, operation_name: &str, input: &Tensor<T>) -> Result<Tensor<T>> {
285        let ops = self
286            .operations
287            .read()
288            .expect("read lock should not be poisoned");
289
290        let op = ops.get(operation_name).ok_or_else(|| {
291            TensorError::invalid_argument(format!(
292                "Operation '{}' not found in registry",
293                operation_name
294            ))
295        })?;
296
297        let kernel = op.select_kernel(input.device()).ok_or_else(|| {
298            TensorError::invalid_argument(format!(
299                "No available kernel for operation '{}' on device {:?}",
300                operation_name,
301                input.device()
302            ))
303        })?;
304
305        let kernel_fn = kernel.unary_fn.ok_or_else(|| {
306            TensorError::invalid_argument(format!(
307                "Operation '{}' does not support unary execution",
308                operation_name
309            ))
310        })?;
311
312        kernel_fn(input)
313    }
314
315    /// Dispatch a binary operation
316    pub fn dispatch_binary(
317        &self,
318        operation_name: &str,
319        lhs: &Tensor<T>,
320        rhs: &Tensor<T>,
321    ) -> Result<Tensor<T>> {
322        // Check device compatibility
323        if lhs.device() != rhs.device() {
324            return Err(TensorError::device_mismatch(
325                operation_name,
326                &format!("{:?}", lhs.device()),
327                &format!("{:?}", rhs.device()),
328            ));
329        }
330
331        let ops = self
332            .operations
333            .read()
334            .expect("read lock should not be poisoned");
335
336        let op = ops.get(operation_name).ok_or_else(|| {
337            TensorError::invalid_argument(format!(
338                "Operation '{}' not found in registry",
339                operation_name
340            ))
341        })?;
342
343        let kernel = op.select_kernel(lhs.device()).ok_or_else(|| {
344            TensorError::invalid_argument(format!(
345                "No available kernel for operation '{}' on device {:?}",
346                operation_name,
347                lhs.device()
348            ))
349        })?;
350
351        let kernel_fn = kernel.binary_fn.ok_or_else(|| {
352            TensorError::invalid_argument(format!(
353                "Operation '{}' does not support binary execution",
354                operation_name
355            ))
356        })?;
357
358        kernel_fn(lhs, rhs)
359    }
360
361    /// Get operation descriptor
362    pub fn get_operation(&self, name: &str) -> Option<OperationDescriptor> {
363        let ops = self
364            .operations
365            .read()
366            .expect("read lock should not be poisoned");
367        ops.get(name).map(|op| op.descriptor.clone())
368    }
369
370    /// List all registered operations
371    pub fn list_operations(&self) -> Vec<String> {
372        let ops = self
373            .operations
374            .read()
375            .expect("read lock should not be poisoned");
376        ops.keys().cloned().collect()
377    }
378
379    /// Get available backends for an operation
380    pub fn available_backends(&self, operation_name: &str) -> Vec<BackendType> {
381        let ops = self
382            .operations
383            .read()
384            .expect("read lock should not be poisoned");
385
386        if let Some(op) = ops.get(operation_name) {
387            op.kernels
388                .iter()
389                .filter(|k| k.backend.is_available())
390                .map(|k| k.backend)
391                .collect()
392        } else {
393            Vec::new()
394        }
395    }
396}
397
/// Macro to simplify operation registration.
///
/// Forms:
/// - `register_operation!(reg, name, category)` — default descriptor.
/// - `register_operation!(reg, name, category, dtypes: [..])` — override dtypes.
/// - `register_operation!(reg, name, category, rank: min, max)` — rank bounds.
///
/// Panics (via `expect`) if registration fails, e.g. on a duplicate name.
/// Assumes `OperationDescriptor` is re-exported at the crate root
/// (referenced as `$crate::OperationDescriptor`).
#[macro_export]
macro_rules! register_operation {
    ($registry:expr, $name:expr, $category:expr) => {
        $registry.register_operation(
            $crate::OperationDescriptor::new($name, $category)
        ).expect("operation registration should succeed");
    };
    ($registry:expr, $name:expr, $category:expr, dtypes: [$($dtype:expr),*]) => {
        $registry.register_operation(
            $crate::OperationDescriptor::new($name, $category)
                .with_dtypes(vec![$($dtype),*])
        ).expect("operation registration with dtypes should succeed");
    };
    ($registry:expr, $name:expr, $category:expr, rank: $min:expr, $max:expr) => {
        $registry.register_operation(
            $crate::OperationDescriptor::new($name, $category)
                .with_rank_range(Some($min), Some($max))
        ).expect("operation registration with rank range should succeed");
    };
}
419
/// Macro to register a unary kernel for an already-registered operation.
///
/// Panics (via `expect`) if the operation name is unknown.
#[macro_export]
macro_rules! register_unary_kernel {
    ($registry:expr, $op_name:expr, $backend:expr, $func:expr) => {
        $registry
            .register_kernel(
                $op_name,
                $crate::KernelImplementation::unary($backend, $func),
            )
            .expect("unary kernel registration should succeed");
    };
}
432
/// Macro to register a binary kernel for an already-registered operation.
///
/// Panics (via `expect`) if the operation name is unknown.
#[macro_export]
macro_rules! register_binary_kernel {
    ($registry:expr, $op_name:expr, $backend:expr, $func:expr) => {
        $registry
            .register_kernel(
                $op_name,
                $crate::KernelImplementation::binary($backend, $func),
            )
            .expect("binary kernel registration should succeed");
    };
}
445
/// Benchmark result for dispatch overhead measurements.
///
/// Produced by [`DispatchRegistry::benchmark_overhead`], which runs a fixed
/// number of no-op dispatches and collects per-sample nanosecond timings.
#[derive(Debug, Clone)]
pub struct DispatchBenchmarkResult {
    /// Minimum observed latency in nanoseconds.
    pub min_ns: u64,
    /// Maximum observed latency in nanoseconds.
    pub max_ns: u64,
    /// Arithmetic mean latency in nanoseconds (truncated to integer).
    pub avg_ns: u64,
    /// 95th-percentile latency in nanoseconds.
    pub p95_ns: u64,
    /// Number of samples collected.
    pub sample_count: usize,
}

impl DispatchBenchmarkResult {
    /// Build a `DispatchBenchmarkResult` from a **pre-sorted** slice of nanosecond samples.
    ///
    /// Returns `None` when `samples` is empty.
    pub fn from_sorted_samples(samples: &[u64]) -> Option<Self> {
        // `?` on first()/last() handles the empty slice directly; the
        // previous explicit `is_empty` check plus `unwrap_or(&0)` fallbacks
        // were redundant — the 0 defaults were unreachable dead code that
        // would silently mask a logic error if ever taken.
        let min_ns = *samples.first()?;
        let max_ns = *samples.last()?;

        // NOTE(review): assumes the sum fits in u64; fine for the
        // 1_000-sample benchmark driver, confirm for larger sample sets.
        let sum: u64 = samples.iter().sum();
        let avg_ns = sum / samples.len() as u64;

        // p95: index = floor(0.95 * n), clamped to valid range.
        let p95_idx = ((samples.len() as f64 * 0.95) as usize).min(samples.len() - 1);
        let p95_ns = samples[p95_idx];

        Some(Self {
            min_ns,
            max_ns,
            avg_ns,
            p95_ns,
            sample_count: samples.len(),
        })
    }
}
492
493impl<T> DispatchRegistry<T> {
494    /// Measure the per-call overhead of a registry read-lock + no-op lookup.
495    ///
496    /// Runs 1 000 no-op dispatches ("__overhead_probe__"), collects the timing
497    /// for each call, sorts the samples, and returns the aggregated statistics.
498    ///
499    /// The probe operation is expected to be absent from the registry, so each
500    /// call exercises the read-lock acquisition plus a single failed hash-map
501    /// probe — which is the hot path for every real dispatch.
502    pub fn benchmark_overhead(&self) -> DispatchBenchmarkResult {
503        const SAMPLE_COUNT: usize = 1_000;
504        const PROBE_NAME: &str = "__overhead_probe__";
505
506        let mut samples: Vec<u64> = Vec::with_capacity(SAMPLE_COUNT);
507
508        for _ in 0..SAMPLE_COUNT {
509            let start = std::time::Instant::now();
510            // Perform a failed lookup to measure lock + probe cost.
511            let _ = self.get_operation(PROBE_NAME);
512            let elapsed_ns = start.elapsed().as_nanos() as u64;
513            samples.push(elapsed_ns);
514        }
515
516        samples.sort_unstable();
517
518        // SAFETY: samples is always non-empty (SAMPLE_COUNT > 0).
519        DispatchBenchmarkResult::from_sorted_samples(&samples).unwrap_or(DispatchBenchmarkResult {
520            min_ns: 0,
521            max_ns: 0,
522            avg_ns: 0,
523            p95_ns: 0,
524            sample_count: 0,
525        })
526    }
527}
528
529/// Global registry instance (lazily initialized)
530use lazy_static::lazy_static;
531
532lazy_static! {
533    /// Global f32 dispatch registry
534    pub static ref F32_REGISTRY: DispatchRegistry<f32> = DispatchRegistry::new();
535
536    /// Global f64 dispatch registry
537    pub static ref F64_REGISTRY: DispatchRegistry<f64> = DispatchRegistry::new();
538
539    /// Global i32 dispatch registry
540    pub static ref I32_REGISTRY: DispatchRegistry<i32> = DispatchRegistry::new();
541}
542
/// Get the global registry for a specific type.
///
/// Supported element types: `f32`, `f64`, `i32`; any other `T` yields `None`.
pub fn get_registry<T: 'static>() -> Option<&'static DispatchRegistry<T>> {
    use std::any::TypeId;

    let type_id = TypeId::of::<T>();

    if type_id == TypeId::of::<f32>() {
        // SAFETY: We've checked that T is f32, so `DispatchRegistry<T>` and
        // `DispatchRegistry<f32>` are the same monomorphized type and the
        // pointer cast is an identity conversion of a 'static reference.
        Some(unsafe {
            &*(&*F32_REGISTRY as *const DispatchRegistry<f32> as *const DispatchRegistry<T>)
        })
    } else if type_id == TypeId::of::<f64>() {
        // SAFETY: We've checked that T is f64; same identity-cast reasoning
        // as the f32 branch above.
        Some(unsafe {
            &*(&*F64_REGISTRY as *const DispatchRegistry<f64> as *const DispatchRegistry<T>)
        })
    } else if type_id == TypeId::of::<i32>() {
        // SAFETY: We've checked that T is i32; same identity-cast reasoning
        // as the f32 branch above.
        Some(unsafe {
            &*(&*I32_REGISTRY as *const DispatchRegistry<i32> as *const DispatchRegistry<T>)
        })
    } else {
        None
    }
}
568
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_backend_type_priority() {
        // CPU is the baseline (lowest-priority) backend. The previous
        // assertion (`x < x + 1`) was a tautology that could never fail.
        assert_eq!(BackendType::Cpu.priority(), 0);

        #[cfg(feature = "simd")]
        assert!(BackendType::SimdCpu.priority() > BackendType::Cpu.priority());
    }

    #[test]
    fn test_operation_descriptor() {
        let desc = OperationDescriptor::new("test_op", "binary")
            .with_dtypes(vec![crate::DType::Float32])
            .with_broadcast()
            .with_rank_range(Some(1), Some(4));

        assert_eq!(desc.name, "test_op");
        assert_eq!(desc.category, "binary");
        assert!(desc.supports_broadcast);
        assert_eq!(desc.min_rank, Some(1));
        assert_eq!(desc.max_rank, Some(4));
    }

    #[test]
    fn test_registry_creation() {
        let registry: DispatchRegistry<f32> = DispatchRegistry::new();
        assert_eq!(registry.list_operations().len(), 0);
    }

    #[test]
    fn test_operation_registration() {
        let registry: DispatchRegistry<f32> = DispatchRegistry::new();

        let desc = OperationDescriptor::new("add", "binary");
        registry
            .register_operation(desc)
            .expect("test: register_operation should succeed");

        assert_eq!(registry.list_operations().len(), 1);
        assert!(registry.get_operation("add").is_some());
    }

    #[test]
    fn test_duplicate_registration_fails() {
        let registry: DispatchRegistry<f32> = DispatchRegistry::new();

        let desc1 = OperationDescriptor::new("add", "binary");
        let desc2 = OperationDescriptor::new("add", "binary");

        registry
            .register_operation(desc1)
            .expect("test: register_operation should succeed");
        assert!(registry.register_operation(desc2).is_err());
    }

    #[test]
    fn test_kernel_registration() {
        let registry: DispatchRegistry<f32> = DispatchRegistry::new();

        // Register operation
        let desc = OperationDescriptor::new("abs", "unary");
        registry
            .register_operation(desc)
            .expect("test: register_operation should succeed");

        // Register CPU kernel
        fn abs_cpu(x: &Tensor<f32>) -> Result<Tensor<f32>> {
            let data = x.data();
            let abs_data: Vec<f32> = data.iter().map(|v| v.abs()).collect();
            let array = scirs2_core::ndarray::ArrayD::from_shape_vec(x.shape().dims(), abs_data)
                .expect("test: operation should succeed");
            Ok(Tensor::from_array(array))
        }

        let kernel = KernelImplementation::unary(BackendType::Cpu, abs_cpu);
        registry
            .register_kernel("abs", kernel)
            .expect("test: register_kernel should succeed");

        assert_eq!(registry.available_backends("abs").len(), 1);
    }

    #[test]
    fn test_unary_dispatch() {
        let registry: DispatchRegistry<f32> = DispatchRegistry::new();

        // Register operation
        let desc = OperationDescriptor::new("negate", "unary");
        registry
            .register_operation(desc)
            .expect("test: register_operation should succeed");

        // Register CPU kernel
        fn negate_cpu(x: &Tensor<f32>) -> Result<Tensor<f32>> {
            let data = x.data();
            let neg_data: Vec<f32> = data.iter().map(|v| -v).collect();
            let array = scirs2_core::ndarray::ArrayD::from_shape_vec(x.shape().dims(), neg_data)
                .expect("test: operation should succeed");
            Ok(Tensor::from_array(array))
        }

        let kernel = KernelImplementation::unary(BackendType::Cpu, negate_cpu);
        registry
            .register_kernel("negate", kernel)
            .expect("test: register_kernel should succeed");

        // Test dispatch
        let input = Tensor::from_array(array![1.0f32, 2.0, 3.0].into_dyn());
        let result = registry
            .dispatch_unary("negate", &input)
            .expect("test: dispatch_unary should succeed");

        assert_eq!(result.data(), &[-1.0f32, -2.0, -3.0]);
    }

    #[test]
    fn test_binary_dispatch() {
        let registry: DispatchRegistry<f32> = DispatchRegistry::new();

        // Register operation
        let desc = OperationDescriptor::new("add", "binary");
        registry
            .register_operation(desc)
            .expect("test: register_operation should succeed");

        // Register CPU kernel
        fn add_cpu(a: &Tensor<f32>, b: &Tensor<f32>) -> Result<Tensor<f32>> {
            let a_data = a.data();
            let b_data = b.data();
            let sum_data: Vec<f32> = a_data
                .iter()
                .zip(b_data.iter())
                .map(|(x, y)| x + y)
                .collect();
            let array = scirs2_core::ndarray::ArrayD::from_shape_vec(a.shape().dims(), sum_data)
                .expect("test: operation should succeed");
            Ok(Tensor::from_array(array))
        }

        let kernel = KernelImplementation::binary(BackendType::Cpu, add_cpu);
        registry
            .register_kernel("add", kernel)
            .expect("test: register_kernel should succeed");

        // Test dispatch
        let a = Tensor::from_array(array![1.0f32, 2.0, 3.0].into_dyn());
        let b = Tensor::from_array(array![4.0f32, 5.0, 6.0].into_dyn());
        let result = registry
            .dispatch_binary("add", &a, &b)
            .expect("test: dispatch_binary should succeed");

        assert_eq!(result.data(), &[5.0f32, 7.0, 9.0]);
    }

    #[test]
    fn test_device_mismatch_error() {
        // NOTE: without a GPU feature enabled we cannot construct tensors on
        // different devices, so this only exercises the same-device success
        // path of the mismatch check.
        let registry: DispatchRegistry<f32> = DispatchRegistry::new();

        let desc = OperationDescriptor::new("add", "binary");
        registry
            .register_operation(desc)
            .expect("test: register_operation should succeed");

        // `_b` is deliberately unused: the kernel body is irrelevant here.
        fn add_cpu(a: &Tensor<f32>, _b: &Tensor<f32>) -> Result<Tensor<f32>> {
            Ok(a.clone())
        }

        let kernel = KernelImplementation::binary(BackendType::Cpu, add_cpu);
        registry
            .register_kernel("add", kernel)
            .expect("test: register_kernel should succeed");

        let a = Tensor::from_array(array![1.0f32].into_dyn());
        let b = Tensor::from_array(array![2.0f32].into_dyn());

        // Both on CPU, should work
        let result = registry.dispatch_binary("add", &a, &b);
        assert!(result.is_ok());
    }

    #[test]
    fn test_global_registry_access() {
        let registry = get_registry::<f32>();
        assert!(registry.is_some());

        let registry = get_registry::<f64>();
        assert!(registry.is_some());

        let registry = get_registry::<i32>();
        assert!(registry.is_some());
    }
}