Skip to main content

scirs2_core/array_protocol/
mod.rs

1// Copyright (c) 2025, `SciRS2` Team
2//
3// Licensed under the Apache License, Version 2.0
4// (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
5//
6
7//! Implementation of Array Protocol (similar to ``NumPy``'s `__array_function__` protocol)
8//!
9//! This module provides a mechanism for third-party array implementations to
10//! override ``SciRS2`` functions. It is inspired by ``NumPy``'s `__array_function__`
11//! protocol defined in NEP-18.
12//!
13//! The protocol enables third-party arrays to implement ``SciRS2`` functions in a way
14//! that is recognized by the ``SciRS2`` library. This allows for seamless integration with
15//! distributed arrays, GPU arrays, and other custom array implementations.
16//!
17//! ## Core Components
18//!
19//! The Array Protocol system includes:
20//!
21//! * Specialized array implementations (GPU, distributed, JIT)
22//! * Automatic device placement with `AutoDevice`
23//! * Mixed-precision operations
24//! * Neural network layers and models
25//! * Gradient computation and training capabilities
26//! * Distributed training and model serialization
27
28use std::any::{Any, TypeId};
29use std::collections::HashMap;
30use std::fmt::{Debug, Display};
31use std::marker::PhantomData;
32use std::sync::{Arc, LazyLock, RwLock};
33use std::time::{Duration, Instant};
34
35use crate::error::{CoreError, CoreResult, ErrorContext};
36
37// Internal submodules
38mod distributed_impl;
39mod gpu_impl;
40#[cfg(feature = "array_protocol_wgpu")]
41pub mod gpu_ndarray;
42#[cfg(feature = "array_protocol_wgpu")]
43mod gpu_ndarray_shaders;
44mod jit_impl;
45mod operations;
46
47// Re-export the array_function_dispatch macro
48pub use crate::array_function_dispatch;
49
50// Public submodules
51pub mod auto_device;
52pub mod distributed_training;
53pub mod grad;
54pub mod mixed_precision;
55pub mod ml_ops;
56pub mod neural;
57#[cfg(feature = "serialization")]
58pub mod serialization;
59pub mod training;
60
61/// Trait for objects that can handle the array protocol.
62///
63/// This is similar to `NumPy`'s `__array_function__` protocol.
64pub trait ArrayProtocol: Any + Send + Sync {
65    /// Implementation of the array protocol.
66    ///
67    /// * `func` - The function being called
68    /// * `types` - The types of all arguments that implement `ArrayProtocol`
69    /// * `args` - The arguments to the function
70    /// * `kwargs` - Named arguments to the function
71    ///
72    /// Returns `Ok(result)` if the operation is successful, or `Err(NotImplemented)`
73    /// if the operation is not implemented for this type.
74    ///
75    /// # Errors
76    ///
77    /// Returns `Err(NotImplemented)` if the operation is not supported by this array type.
78    fn array_function(
79        &self,
80        func: &ArrayFunction,
81        types: &[TypeId],
82        args: &[Box<dyn Any>],
83        kwargs: &HashMap<String, Box<dyn Any>>,
84    ) -> Result<Box<dyn Any>, NotImplemented>;
85
86    /// Get the array as Any for downcasting
87    #[must_use]
88    fn as_any(&self) -> &dyn Any;
89
90    /// Get the shape of the array (default implementation returns empty slice)
91    #[must_use]
92    fn shape(&self) -> &[usize] {
93        &[]
94    }
95
96    /// Get the data type of the array (default implementation returns f64)
97    #[must_use]
98    fn dtype(&self) -> TypeId {
99        TypeId::of::<f64>()
100    }
101
102    /// Clone this array protocol object.
103    #[must_use]
104    fn box_clone(&self) -> Box<dyn ArrayProtocol>;
105}
106
107/// Make `Box<dyn ArrayProtocol>` cloneable via `box_clone`
108impl Clone for Box<dyn ArrayProtocol> {
109    fn clone(&self) -> Self {
110        self.box_clone()
111    }
112}
113
/// Sentinel an [`ArrayProtocol`] implementation returns to decline a call.
///
/// This is a protocol-level marker, distinct from `CoreError::NotImplementedError`.
/// `array_function_dispatch` treats it as "try the next implementation" and reports a
/// `CoreError::DispatchError` once every implementing argument has declined.
#[derive(Debug, Clone, Copy)]
pub struct NotImplemented;

impl Display for NotImplemented {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("NotImplemented")
    }
}
131
132/// Type alias for the complex function implementation type
133pub type ArrayFunctionImpl = dyn Fn(&[Box<dyn Any>], &HashMap<String, Box<dyn Any>>) -> CoreResult<Box<dyn Any>>
134    + Send
135    + Sync;
136
137/// A wrapper for functions that can be overridden by the array protocol.
138#[derive(Clone)]
139pub struct ArrayFunction {
140    /// The name of the function, including its module path
141    pub name: &'static str,
142
143    /// The function implementation
144    pub implementation: Arc<ArrayFunctionImpl>,
145}
146
147impl Debug for ArrayFunction {
148    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149        f.debug_struct("ArrayFunction")
150            .field("name", &self.name)
151            .finish_non_exhaustive()
152    }
153}
154
155impl PartialEq for ArrayFunction {
156    fn eq(&self, other: &Self) -> bool {
157        self.name == other.name
158    }
159}
160
161impl Eq for ArrayFunction {}
162
163impl std::hash::Hash for ArrayFunction {
164    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
165        self.name.hash(state);
166    }
167}
168
169impl ArrayFunction {
170    /// Create a new array function with the given name
171    #[must_use]
172    pub fn new(name: &'static str) -> Self {
173        Self {
174            name,
175            // Default implementation that returns NotImplemented
176            implementation: Arc::new(|_args, _kwargs| {
177                Err(CoreError::NotImplementedError(ErrorContext::new(
178                    "Function not implemented".to_string(),
179                )))
180            }),
181        }
182    }
183}
184
185/// Cache entry for function dispatch optimization
186#[derive(Debug, Clone)]
187pub struct DispatchCacheEntry {
188    /// Type signature for the cached result
189    #[allow(dead_code)]
190    type_signature: Vec<TypeId>,
191    /// Which implementation type to try first
192    #[allow(dead_code)]
193    preferred_impl_type: TypeId,
194    /// Cache timestamp for TTL management
195    timestamp: Instant,
196    /// Number of cache hits
197    hit_count: u64,
198}
199
200/// Registry of all array functions with dispatch caching.
201#[derive(Debug)]
202pub struct ArrayFunctionRegistry {
203    /// Map of function names to array functions
204    functions: HashMap<&'static str, ArrayFunction>,
205    /// Dispatch cache for performance optimization
206    dispatch_cache: HashMap<(&'static str, Vec<TypeId>), DispatchCacheEntry>,
207    /// Maximum cache size to prevent unbounded growth
208    max_cache_size: usize,
209    /// Cache TTL for entries (prevents stale cache)
210    cache_ttl: Duration,
211}
212
213impl Default for ArrayFunctionRegistry {
214    fn default() -> Self {
215        Self {
216            functions: HashMap::new(),
217            dispatch_cache: HashMap::new(),
218            max_cache_size: 1000,                // Reasonable default cache size
219            cache_ttl: Duration::from_secs(300), // 5 minutes TTL
220        }
221    }
222}
223
224impl ArrayFunctionRegistry {
225    /// Get the global registry.
226    #[must_use]
227    pub fn global() -> &'static RwLock<Self> {
228        static REGISTRY: LazyLock<RwLock<ArrayFunctionRegistry>> =
229            LazyLock::new(|| RwLock::new(ArrayFunctionRegistry::default()));
230        &REGISTRY
231    }
232
233    /// Register a new array function.
234    pub fn register(&mut self, func: ArrayFunction) {
235        self.functions.insert(func.name, func);
236    }
237
238    /// Get an array function by name.
239    #[must_use]
240    #[allow(dead_code)]
241    pub fn get(&self, name: &str) -> Option<&ArrayFunction> {
242        self.functions.get(name)
243    }
244
245    /// Get all registered functions.
246    #[must_use]
247    pub fn all_functions(&self) -> Vec<&ArrayFunction> {
248        self.functions.values().collect()
249    }
250
251    /// Get cached dispatch entry for optimization
252    #[must_use]
253    pub fn get_cached_dispatch(
254        &self,
255        funcname: &'static str,
256        types: &[TypeId],
257    ) -> Option<&DispatchCacheEntry> {
258        let key = (funcname, types.to_vec());
259        if let Some(entry) = self.dispatch_cache.get(&key) {
260            // Check if cache entry is still valid (TTL check)
261            if entry.timestamp.elapsed() < self.cache_ttl {
262                return Some(entry);
263            }
264        }
265        None
266    }
267
268    /// Cache dispatch result for future optimization
269    pub fn cache_dispatch(
270        &mut self,
271        funcname: &'static str,
272        types: Vec<TypeId>,
273        impl_type: TypeId,
274    ) {
275        // Clean cache if it's getting too large
276        if self.dispatch_cache.len() >= self.max_cache_size {
277            self.cleanup_cache();
278        }
279
280        let key = (funcname, types.clone());
281        let entry = DispatchCacheEntry {
282            type_signature: types,
283            preferred_impl_type: impl_type,
284            timestamp: Instant::now(),
285            hit_count: 0,
286        };
287        self.dispatch_cache.insert(key, entry);
288    }
289
290    /// Update cache hit count for an entry
291    pub fn update_cache_hit(&mut self, funcname: &'static str, types: &[TypeId]) {
292        let key = (funcname, types.to_vec());
293        if let Some(entry) = self.dispatch_cache.get_mut(&key) {
294            entry.hit_count += 1;
295        }
296    }
297
298    /// Clean up expired cache entries
299    fn cleanup_cache(&mut self) {
300        let now = Instant::now();
301        self.dispatch_cache
302            .retain(|_, entry| now.duration_since(entry.timestamp) < self.cache_ttl);
303
304        // If still too large, remove least recently used entries
305        if self.dispatch_cache.len() >= self.max_cache_size {
306            let mut entries: Vec<_> = self
307                .dispatch_cache
308                .iter()
309                .map(|(k, v)| (k.clone(), v.hit_count))
310                .collect();
311            entries.sort_by_key(|(_, hit_count)| *hit_count);
312
313            // Remove bottom 25% of entries by hit count
314            let to_remove = self.dispatch_cache.len() / 4;
315            let keys_to_remove: Vec<_> = entries
316                .iter()
317                .take(to_remove)
318                .map(|(key, _)| key.clone())
319                .collect();
320            for key in keys_to_remove {
321                self.dispatch_cache.remove(&key);
322            }
323        }
324    }
325
326    /// Get cache statistics for monitoring
327    #[must_use]
328    pub fn cache_stats(&self) -> HashMap<String, u64> {
329        let mut stats = HashMap::new();
330        stats.insert("cache_size".to_string(), self.dispatch_cache.len() as u64);
331        stats.insert("max_cache_size".to_string(), self.max_cache_size as u64);
332
333        let total_hits: u64 = self.dispatch_cache.values().map(|e| e.hit_count).sum();
334        stats.insert("total_hits".to_string(), total_hits);
335
336        stats
337    }
338}
339
340/// Helper function to extract all arguments implementing the `ArrayProtocol` trait.
341///
342/// This is similar to `NumPy`'s `_get_implementing_args` function.
343/// Optimized version with pre-allocated capacity and fast-path for common cases.
344#[allow(dead_code)]
345pub fn get_implementing_args(args: &[Box<dyn Any>]) -> Vec<(TypeId, &dyn ArrayProtocol)> {
346    if args.is_empty() {
347        return Vec::new();
348    }
349
350    // Pre-allocate with capacity to avoid reallocation
351    let mut implementing_args = Vec::with_capacity(args.len());
352
353    for arg in args {
354        if let Some(array_protocol_obj) = arg.downcast_ref::<Box<dyn ArrayProtocol>>() {
355            let type_id = (**array_protocol_obj).type_id();
356            implementing_args.push((type_id, &**array_protocol_obj));
357        }
358    }
359
360    // Sort implementing _args by TypeId for deterministic dispatch order
361    // This ensures consistent dispatch behavior across calls
362    implementing_args.sort_by_key(|&_type_id_| {
363        // Use TypeId hash for deterministic ordering
364        use std::hash::{Hash, Hasher};
365        let mut hasher = std::collections::hash_map::DefaultHasher::new();
366        std::any::TypeId::of::<i32>().hash(&mut hasher);
367        hasher.finish()
368    });
369
370    implementing_args
371}
372
373/// Calls the array protocol with the given function and arguments.
374///
375/// * `func` - The array function to call
376/// * `args` - The arguments to the function
377/// * `kwargs` - Named arguments to the function
378///
379/// Returns the result of the function call, or an error if the function
380/// cannot be dispatched to any of the array protocol implementations.
381///
382/// Optimized version with caching and fast-path optimizations.
383#[allow(dead_code)]
384pub fn array_function_dispatch(
385    func: &ArrayFunction,
386    args: &[Box<dyn Any>],
387    kwargs: &HashMap<String, Box<dyn Any>>,
388) -> CoreResult<Box<dyn Any>> {
389    // Fast path for empty args
390    if args.is_empty() {
391        return (func.implementation)(args, kwargs);
392    }
393
394    // Find all arguments implementing ArrayProtocol
395    let implementing_args = get_implementing_args(args);
396
397    if implementing_args.is_empty() {
398        // No arguments implement ArrayProtocol, use default implementation
399        return (func.implementation)(args, kwargs);
400    }
401
402    // Fast path for single implementing argument
403    if implementing_args.len() == 1 {
404        let (type_id, array_protocol_obj) = implementing_args[0];
405        let types = [type_id];
406        match array_protocol_obj.array_function(func, &types, args, kwargs) {
407            Ok(result) => return Ok(result),
408            Err(NotImplemented) => {
409                return Err(CoreError::DispatchError(ErrorContext::new(format!(
410                    "No implementation found for {} with type {:?}",
411                    func.name, type_id
412                ))));
413            }
414        }
415    }
416
417    // Extract all unique types that implement ArrayProtocol (optimized)
418    let mut unique_types = Vec::with_capacity(implementing_args.len());
419    let mut seen_types = std::collections::HashSet::with_capacity(implementing_args.len());
420
421    for &(type_id, _) in &implementing_args {
422        if seen_types.insert(type_id) {
423            unique_types.push(type_id);
424        }
425    }
426
427    // Try dispatching to each implementation in priority order
428    for (_, array_protocol_obj) in implementing_args {
429        if let Ok(result) = array_protocol_obj.array_function(func, &unique_types, args, kwargs) {
430            return Ok(result);
431        }
432    }
433
434    // If we get here, no implementation was found
435    Err(CoreError::DispatchError(ErrorContext::new(format!(
436        "No implementation found for {} with {} argument types: {:?}",
437        func.name,
438        unique_types.len(),
439        unique_types
440    ))))
441}
442
443/// Decorator for adding array function dispatch capabilities to a function.
444///
445/// This is similar to `NumPy`'s `array_function_dispatch` decorator.
446pub struct ArrayFunctionDecorator<F> {
447    function: F,
448    name: &'static str,
449}
450
451impl<F> ArrayFunctionDecorator<F>
452where
453    F: Send + Sync + 'static,
454{
455    /// Create a new array function decorator.
456    #[must_use]
457    pub fn new(function: F, name: &'static str) -> Self {
458        Self { function, name }
459    }
460
461    /// Register the function with the global registry.
462    pub fn register(self) -> F {
463        let implementation = Arc::new(
464            move |_args: &[Box<dyn Any>], _kwargs: &HashMap<String, Box<dyn Any>>| {
465                // Implementation that converts generic arguments to specific types
466                // and calls the original function
467                // This is a simplified version - in practice, we would need more complex
468                // type conversion
469                Err(CoreError::NotImplementedError(ErrorContext::new(
470                    "ArrayFunctionDecorator: Type conversion in array_function_dispatch is not implemented yet".to_string()
471                )))
472            },
473        );
474
475        let func = ArrayFunction {
476            name: self.name,
477            implementation,
478        };
479
480        // Register the function with the global registry
481        let registry = ArrayFunctionRegistry::global();
482        if let Ok(mut registry) = registry.write() {
483            registry.register(func);
484        } else {
485            eprintln!("Warning: Failed to acquire write lock on ArrayFunctionRegistry, skipping function registration");
486            // Continue without registration - this may result in reduced functionality but avoids crash
487        }
488
489        self.function
490    }
491}
492
493/// Trait for arrays that can support GPU operations.
494pub trait GPUArray: ArrayProtocol {
495    /// Move the array to GPU.
496    fn to_gpu(&self) -> CoreResult<Box<dyn GPUArray>>;
497
498    /// Move the array from GPU to CPU.
499    fn to_cpu(&self) -> CoreResult<Box<dyn ArrayProtocol>>;
500
501    /// Check if the array is on GPU.
502    #[must_use]
503    fn is_on_gpu(&self) -> bool;
504
505    /// Get information about the GPU device that holds this array.
506    #[must_use]
507    fn device_info(&self) -> HashMap<String, String>;
508}
509
510/// Trait for distributed arrays that can span multiple machines.
511pub trait DistributedArray: ArrayProtocol {
512    /// Get information about the distribution of this array.
513    #[must_use]
514    fn distribution_info(&self) -> HashMap<String, String>;
515
516    /// Gather the distributed array to a single node.
517    fn gather(&self) -> CoreResult<Box<dyn ArrayProtocol>>;
518
519    /// Scatter a regular array to a distributed array.
520    fn scatter(&self, chunks: usize) -> CoreResult<Box<dyn DistributedArray>>;
521
522    /// Check if this array is distributed.
523    #[must_use]
524    fn is_distributed(&self) -> bool;
525}
526
527/// JIT (Just-In-Time) compilation support for arrays.
528pub trait JITArray: ArrayProtocol {
529    /// Compile an expression to be evaluated on this array.
530    fn compile(&self, expression: &str) -> CoreResult<Box<dyn JITFunction>>;
531
532    /// Check if JIT compilation is supported for this array type.
533    #[must_use]
534    fn supports_jit(&self) -> bool;
535
536    /// Get information about the JIT compiler being used.
537    #[must_use]
538    fn jit_info(&self) -> HashMap<String, String>;
539}
540
541/// A JIT-compiled function that can be evaluated on arrays.
542pub trait JITFunction: Send + Sync {
543    /// Evaluate the function with the given arguments.
544    fn evaluate(&self, args: &[Box<dyn Any>]) -> CoreResult<Box<dyn Any>>;
545
546    /// Get the source code of the compiled function.
547    #[must_use]
548    fn source(&self) -> String;
549
550    /// Get information about how the function was compiled.
551    #[must_use]
552    fn compile_info(&self) -> HashMap<String, String>;
553
554    /// Clone this JIT function into a `Box<dyn JITFunction>`.
555    #[must_use]
556    fn clone_box(&self) -> Box<dyn JITFunction>;
557}
558
559/// A factory for creating JIT functions for specific array implementations.
560pub trait JITFunctionFactory: Send + Sync {
561    /// Create a new JIT function for the given expression and array type.
562    fn create_jit_function(
563        &self,
564        expression: &str,
565        array_typeid: TypeId,
566    ) -> CoreResult<Box<dyn JITFunction>>;
567
568    /// Check if this factory supports the given array type.
569    #[must_use]
570    fn supports_array_type(&self, array_typeid: TypeId) -> bool;
571}
572
573/// Registry of JIT function factories.
574#[derive(Default)]
575pub struct JITFactoryRegistry {
576    factories: Vec<Box<dyn JITFunctionFactory>>,
577}
578
579impl std::fmt::Debug for JITFactoryRegistry {
580    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
581        write!(
582            f,
583            "JITFactoryRegistry {{ factories: {} }}",
584            self.factories.len()
585        )
586    }
587}
588
589impl JITFactoryRegistry {
590    /// Get the global registry.
591    #[must_use]
592    pub fn global() -> &'static RwLock<Self> {
593        static REGISTRY: LazyLock<RwLock<JITFactoryRegistry>> = LazyLock::new(|| {
594            RwLock::new(JITFactoryRegistry {
595                factories: Vec::new(),
596            })
597        });
598        &REGISTRY
599    }
600
601    /// Register a new JIT function factory.
602    pub fn register(&mut self, factory: Box<dyn JITFunctionFactory>) {
603        self.factories.push(factory);
604    }
605
606    /// Get a JIT function factory that supports the given array type.
607    #[must_use]
608    pub fn get_factory_for_array_type(
609        &self,
610        array_typeid: TypeId,
611    ) -> Option<&dyn JITFunctionFactory> {
612        for factory in &self.factories {
613            if factory.supports_array_type(array_typeid) {
614                return Some(&**factory);
615            }
616        }
617        None
618    }
619}
620
621/// A wrapper for ndarray to implement the ArrayProtocol trait.
622#[derive(Debug, Clone)]
623pub struct NdarrayWrapper<T, D: crate::ndarray::Dimension> {
624    array: crate::ndarray::Array<T, D>,
625    phantom: PhantomData<(T, D)>,
626}
627
628impl<T, D> NdarrayWrapper<T, D>
629where
630    T: Clone + 'static,
631    D: crate::ndarray::Dimension + 'static,
632{
633    /// Create a new ndarray wrapper.
634    #[must_use]
635    pub fn new(array: crate::ndarray::Array<T, D>) -> Self {
636        Self {
637            array,
638            phantom: PhantomData,
639        }
640    }
641
642    /// Get the underlying ndarray.
643    #[must_use]
644    pub const fn as_array(&self) -> &crate::ndarray::Array<T, D> {
645        &self.array
646    }
647
648    /// Convert into the underlying ndarray.
649    #[must_use]
650    pub fn into_array(self) -> crate::ndarray::Array<T, D> {
651        self.array
652    }
653
654    /// Update the underlying array with a new one.
655    pub fn array_2(&mut self, newarray: crate::ndarray::Array<T, D>) {
656        self.array = newarray;
657    }
658}
659
660impl<T, D> ArrayProtocol for NdarrayWrapper<T, D>
661where
662    T: Clone + Send + Sync + 'static,
663    D: crate::ndarray::Dimension + Send + Sync + 'static,
664{
665    fn array_function(
666        &self,
667        func: &ArrayFunction,
668        _types: &[TypeId],
669        args: &[Box<dyn Any>],
670        kwargs: &HashMap<String, Box<dyn Any>>,
671    ) -> Result<Box<dyn Any>, NotImplemented> {
672        match func.name {
673            "scirs2::array_protocol::operations::add" => {
674                // Addition operation for NdarrayWrapper
675                if args.len() < 2 {
676                    return Err(NotImplemented);
677                }
678
679                if let Some(other) = args[1].downcast_ref::<NdarrayWrapper<T, D>>() {
680                    if let (Some(a), Some(b)) = (
681                        self.as_any().downcast_ref::<NdarrayWrapper<T, D>>(),
682                        other.as_any().downcast_ref::<NdarrayWrapper<T, D>>(),
683                    ) {
684                        // Need to make sure T supports addition
685                        if TypeId::of::<T>() == TypeId::of::<f64>() {
686                            let a_f64 =
687                                unsafe { &*(a as *const _ as *const NdarrayWrapper<f64, D>) };
688                            let b_f64 =
689                                unsafe { &*(b as *const _ as *const NdarrayWrapper<f64, D>) };
690                            let result = a_f64.as_array() + b_f64.as_array();
691                            return Ok(Box::new(NdarrayWrapper::new(result)));
692                        } else if TypeId::of::<T>() == TypeId::of::<f32>() {
693                            let a_f32 =
694                                unsafe { &*(a as *const _ as *const NdarrayWrapper<f32, D>) };
695                            let b_f32 =
696                                unsafe { &*(b as *const _ as *const NdarrayWrapper<f32, D>) };
697                            let result = a_f32.as_array() + b_f32.as_array();
698                            return Ok(Box::new(NdarrayWrapper::new(result)));
699                        }
700                    }
701                }
702                Err(NotImplemented)
703            }
704            "scirs2::array_protocol::operations::matmul" => {
705                // Matrix multiplication for NdarrayWrapper
706                if args.len() < 2 {
707                    return Err(NotImplemented);
708                }
709
710                // We can only handle matrix multiplication for 2D arrays
711                // Check for 2D array using TypeId
712                if TypeId::of::<D>() != TypeId::of::<crate::ndarray::Ix2>() {
713                    return Err(NotImplemented);
714                }
715
716                if let Some(other) = args[1].downcast_ref::<NdarrayWrapper<T, D>>() {
717                    // Since we've already checked TypeId::of::<D>() == TypeId::of::<crate::ndarray::Ix2>()
718                    // We can safely specialize for Ix2 matrices
719
720                    // Handle the case for f64 matrices
721                    if TypeId::of::<T>() == TypeId::of::<f64>() {
722                        // Cast to concrete _types we know how to handle
723                        let a_f64 = unsafe {
724                            &*(self as *const _ as *const NdarrayWrapper<f64, crate::ndarray::Ix2>)
725                        };
726                        let b_f64 = unsafe {
727                            &*(other as *const _ as *const NdarrayWrapper<f64, crate::ndarray::Ix2>)
728                        };
729
730                        // Get dimensions
731                        let ashape = a_f64.as_array().shape();
732                        let bshape = b_f64.as_array().shape();
733
734                        if ashape.len() != 2 || bshape.len() != 2 || ashape[1] != bshape[0] {
735                            return Err(NotImplemented);
736                        }
737
738                        // Use the higher-level dot operation which will be more efficient
739                        // than our manual implementation
740                        let result = a_f64.as_array().dot(b_f64.as_array());
741                        return Ok(Box::new(NdarrayWrapper::new(result)));
742                    }
743                    // Handle the case for f32 matrices
744                    else if TypeId::of::<T>() == TypeId::of::<f32>() {
745                        // Cast to concrete _types we know how to handle
746                        let a_f32 = unsafe {
747                            &*(self as *const _ as *const NdarrayWrapper<f32, crate::ndarray::Ix2>)
748                        };
749                        let b_f32 = unsafe {
750                            &*(other as *const _ as *const NdarrayWrapper<f32, crate::ndarray::Ix2>)
751                        };
752
753                        // Get dimensions
754                        let ashape = a_f32.as_array().shape();
755                        let bshape = b_f32.as_array().shape();
756
757                        if ashape.len() != 2 || bshape.len() != 2 || ashape[1] != bshape[0] {
758                            return Err(NotImplemented);
759                        }
760
761                        // Use the higher-level dot operation which will be more efficient
762                        // than our manual implementation
763                        let result = a_f32.as_array().dot(b_f32.as_array());
764                        return Ok(Box::new(NdarrayWrapper::new(result)));
765                    }
766                }
767                // If we get here, we don't know how to handle this case
768                Err(NotImplemented)
769            }
770            "scirs2::array_protocol::operations::transpose" => {
771                // Transpose operation for NdarrayWrapper
772                if TypeId::of::<T>() == TypeId::of::<f64>() {
773                    let a_f64 = unsafe { &*(self as *const _ as *const NdarrayWrapper<f64, D>) };
774                    let result = a_f64.as_array().t().to_owned();
775                    return Ok(Box::new(NdarrayWrapper::new(result)));
776                } else if TypeId::of::<T>() == TypeId::of::<f32>() {
777                    let a_f32 = unsafe { &*(self as *const _ as *const NdarrayWrapper<f32, D>) };
778                    let result = a_f32.as_array().t().to_owned();
779                    return Ok(Box::new(NdarrayWrapper::new(result)));
780                }
781                Err(NotImplemented)
782            }
783            "scirs2::array_protocol::operations::sum" => {
784                // Sum operation for NdarrayWrapper
785                let axis_ref = kwargs.get("axis").and_then(|a| a.downcast_ref::<usize>());
786
787                if TypeId::of::<T>() == TypeId::of::<f64>() {
788                    let a_f64 = unsafe { &*(self as *const _ as *const NdarrayWrapper<f64, D>) };
789                    match axis_ref {
790                        Some(&_ax) => {
791                            // Can't use sum_axis without RemoveAxis trait
792                            // Just return the full sum for now
793                            let result = a_f64.as_array().sum();
794                            return Ok(Box::new(result));
795                        }
796                        None => {
797                            let result = a_f64.as_array().sum();
798                            return Ok(Box::new(result));
799                        }
800                    }
801                } else if TypeId::of::<T>() == TypeId::of::<f32>() {
802                    let a_f32 = unsafe { &*(self as *const _ as *const NdarrayWrapper<f32, D>) };
803                    match axis_ref {
804                        Some(&_ax) => {
805                            // Can't use sum_axis without RemoveAxis trait
806                            // Just return the full sum for now
807                            let result = a_f32.as_array().sum();
808                            return Ok(Box::new(result));
809                        }
810                        None => {
811                            let result = a_f32.as_array().sum();
812                            return Ok(Box::new(result));
813                        }
814                    }
815                }
816                Err(NotImplemented)
817            }
818            "scirs2::array_protocol::operations::reshape" => {
819                // Reshape operation for NdarrayWrapper
820                if let Some(shape) = kwargs
821                    .get("shape")
822                    .and_then(|s| s.downcast_ref::<Vec<usize>>())
823                {
824                    if TypeId::of::<T>() == TypeId::of::<f64>() {
825                        let a_f64 =
826                            unsafe { &*(self as *const _ as *const NdarrayWrapper<f64, D>) };
827                        match a_f64
828                            .as_array()
829                            .clone()
830                            .into_shape_with_order(shape.clone())
831                        {
832                            Ok(result) => return Ok(Box::new(NdarrayWrapper::new(result))),
833                            Err(_) => return Err(NotImplemented),
834                        }
835                    } else if TypeId::of::<T>() == TypeId::of::<f32>() {
836                        let a_f32 =
837                            unsafe { &*(self as *const _ as *const NdarrayWrapper<f32, D>) };
838                        match a_f32
839                            .as_array()
840                            .clone()
841                            .into_shape_with_order(shape.clone())
842                        {
843                            Ok(result) => return Ok(Box::new(NdarrayWrapper::new(result))),
844                            Err(_) => return Err(NotImplemented),
845                        }
846                    }
847                }
848                Err(NotImplemented)
849            }
850            _ => Err(NotImplemented),
851        }
852    }
853
854    fn as_any(&self) -> &dyn Any {
855        self
856    }
857
858    fn shape(&self) -> &[usize] {
859        self.array.shape()
860    }
861
862    fn dtype(&self) -> TypeId {
863        TypeId::of::<T>()
864    }
865
866    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
867        Box::new(self.clone())
868    }
869}
870
871// Example implementation for a third-party array library:
872
/// A mock distributed array implementation.
///
/// Stand-in for a third-party distributed array, used to exercise the array
/// protocol. No data is actually distributed: every element of `chunks` is
/// simply counted as one "chunk" when distribution info is reported.
#[derive(Debug, Clone)]
pub struct MockDistributedArray<T: Clone + 'static> {
    /// Per-chunk payloads; each element is reported as one chunk.
    chunks: Vec<T>,
    /// Logical shape reported through `ArrayProtocol::shape`.
    shape: Vec<usize>,
}

impl<T: Clone + Send + Sync + 'static> MockDistributedArray<T> {
    /// Builds a mock distributed array from its chunk payloads and logical shape.
    ///
    /// Both inputs are stored verbatim; nothing is validated.
    #[must_use]
    pub fn new(chunks: Vec<T>, shape: Vec<usize>) -> Self {
        MockDistributedArray { shape, chunks }
    }
}
887
888impl<T: Clone + Send + Sync + 'static> ArrayProtocol for MockDistributedArray<T> {
889    fn array_function(
890        &self,
891        func: &ArrayFunction,
892        _types: &[TypeId],
893        _args: &[Box<dyn Any>],
894        _kwargs: &HashMap<String, Box<dyn Any>>,
895    ) -> Result<Box<dyn Any>, NotImplemented> {
896        match func.name {
897            "scirs2::mean" => {
898                // Example: Implement a mean function for distributed arrays
899                // In a real implementation, this would use distributed computation
900
901                // For simplicity, we'll just return a dummy result
902                let result = T::clone(&self.chunks[0]);
903                Ok(Box::new(result))
904            }
905            _ => Err(NotImplemented),
906        }
907    }
908
909    fn as_any(&self) -> &dyn Any {
910        self
911    }
912
913    fn shape(&self) -> &[usize] {
914        &self.shape
915    }
916
917    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
918        Box::new(self.clone())
919    }
920}
921
922impl<T: Clone + Send + Sync + 'static> DistributedArray for MockDistributedArray<T> {
923    fn distribution_info(&self) -> HashMap<String, String> {
924        let mut info = HashMap::new();
925        info.insert("type".to_string(), "mock_distributed".to_string());
926        info.insert("chunks".to_string(), self.chunks.len().to_string());
927        info
928    }
929
930    fn gather(&self) -> CoreResult<Box<dyn ArrayProtocol>> {
931        // In a real implementation, this would gather data from all nodes
932        // For now, we just return self boxed as ArrayProtocol
933        Ok(Box::new(self.clone()) as Box<dyn ArrayProtocol>)
934    }
935
936    fn scatter(&self, _numchunks: usize) -> CoreResult<Box<dyn DistributedArray>> {
937        // In a real implementation, this would scatter data to multiple nodes
938        // For now, we just return self boxed as DistributedArray
939        Ok(Box::new(self.clone()) as Box<dyn DistributedArray>)
940    }
941
942    fn is_distributed(&self) -> bool {
943        true
944    }
945}
946
/// A mock GPU array implementation.
///
/// Stand-in for a third-party GPU-resident array, used to exercise the array
/// protocol. The elements live in an ordinary `Vec`; `device` is only a label.
#[derive(Debug, Clone)]
pub struct MockGPUArray<T: Clone + 'static> {
    /// Element storage (a plain host `Vec` in this mock).
    data: Vec<T>,
    /// Logical shape reported through `ArrayProtocol::shape`.
    shape: Vec<usize>,
    /// Free-form device label (e.g. `"cuda:0"`), echoed by `device_info`.
    device: String,
}

impl<T: Clone + Send + Sync + 'static> MockGPUArray<T> {
    /// Builds a mock GPU array from element data, a logical shape and a device label.
    ///
    /// Inputs are stored verbatim; nothing is validated or transferred anywhere.
    #[must_use]
    pub fn new(data: Vec<T>, shape: Vec<usize>, device: String) -> Self {
        MockGPUArray {
            device,
            shape,
            data,
        }
    }
}
966
967impl<T: Clone + Send + Sync + 'static> ArrayProtocol for MockGPUArray<T> {
968    fn array_function(
969        &self,
970        func: &ArrayFunction,
971        _types: &[TypeId],
972        _args: &[Box<dyn Any>],
973        _kwargs: &HashMap<String, Box<dyn Any>>,
974    ) -> Result<Box<dyn Any>, NotImplemented> {
975        match func.name {
976            "scirs2::matmul" => {
977                // Example: Implement a GPU-accelerated matrix multiplication
978                // In a real implementation, this would use GPU computation
979
980                // For simplicity, we'll just return a dummy result
981                let result =
982                    MockGPUArray::new(self.data.clone(), self.shape.clone(), self.device.clone());
983                Ok(Box::new(result))
984            }
985            _ => Err(NotImplemented),
986        }
987    }
988
989    fn as_any(&self) -> &dyn Any {
990        self
991    }
992
993    fn shape(&self) -> &[usize] {
994        &self.shape
995    }
996
997    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
998        Box::new(self.clone())
999    }
1000}
1001
1002impl<T: Clone + Send + Sync + 'static> GPUArray for MockGPUArray<T> {
1003    fn to_gpu(&self) -> CoreResult<Box<dyn GPUArray>> {
1004        // Already on GPU
1005        Ok(Box::new(self.clone()) as Box<dyn GPUArray>)
1006    }
1007
1008    fn to_cpu(&self) -> CoreResult<Box<dyn ArrayProtocol>> {
1009        // In a real implementation, this would transfer data from GPU to CPU
1010        // For now, we just return self boxed as ArrayProtocol
1011        Ok(Box::new(self.clone()) as Box<dyn ArrayProtocol>)
1012    }
1013
1014    fn is_on_gpu(&self) -> bool {
1015        true
1016    }
1017
1018    fn device_info(&self) -> HashMap<String, String> {
1019        let mut info = HashMap::new();
1020        info.insert("device".to_string(), self.device.clone());
1021        info.insert("type".to_string(), "mock_gpu".to_string());
1022        info
1023    }
1024}
1025
/// A factory for creating and registering array protocol enabled functions.
///
/// Pairs a Rust callable with the protocol name it is registered under, so
/// that third-party array implementations can override it by that name.
#[derive(Debug)]
pub struct ArrayProtocolFunction<F> {
    /// The wrapped callable, handed back unchanged by `register`.
    func: F,
    /// Registry key, e.g. `"scirs2::sum"`.
    name: &'static str,
}

impl<F> ArrayProtocolFunction<F> {
    /// Wraps `func` under the protocol name `name`; nothing is registered yet.
    #[must_use]
    pub fn new(func: F, name: &'static str) -> Self {
        ArrayProtocolFunction { name, func }
    }
}
1043
1044impl<F> ArrayProtocolFunction<F>
1045where
1046    F: Clone + Send + Sync + 'static,
1047{
1048    /// Register this function with the array protocol system.
1049    pub fn register(self) -> F {
1050        let implementation = Arc::new(
1051            move |_args: &[Box<dyn Any>], _kwargs: &HashMap<String, Box<dyn Any>>| {
1052                // This is a placeholder for actual implementation that would:
1053                // 1. Convert generic args to specific types needed by the function
1054                // 2. Call the function with the converted args
1055                // 3. Return the result as a Box<dyn Any>
1056                Err(CoreError::NotImplementedError(ErrorContext::new(
1057                    "ArrayProtocolFunction: Implementation for array protocol functions is not complete".to_string()
1058                )))
1059            },
1060        );
1061
1062        let array_func = ArrayFunction {
1063            name: self.name,
1064            implementation,
1065        };
1066
1067        // Register the function
1068        if let Ok(mut registry) = ArrayFunctionRegistry::global().write() {
1069            registry.register(array_func);
1070        } else {
1071            eprintln!("Warning: Failed to acquire write lock on ArrayFunctionRegistry during array protocol building, skipping function registration");
1072            // Continue without registration - this may result in reduced functionality but avoids crash
1073        }
1074
1075        self.func
1076    }
1077}
1078
/// Convenience macro for defining a function inline and getting it back as a value.
///
/// `array_function_def!(fn NAME<GENERICS>(ARGS) -> RET BODY, PROTOCOL_NAME)`
/// expands to a block that declares an ordinary `fn NAME` with the given
/// generic parameters, arguments, return type and body. The block evaluates to
/// that function item, so it can be bound to a variable or passed along.
///
/// Limitations of the matcher:
/// * generic parameters must be bare identifiers (no bounds, no lifetimes);
/// * every argument must be a plain `ident: Type` pair (no patterns);
/// * an explicit `-> RET` return type is required.
///
/// The trailing `PROTOCOL_NAME` expression (e.g. `"scirs2::sum"`) is accepted
/// but currently **unused**. This macro does *not* register anything with
/// `ArrayFunctionRegistry`, so types implementing `ArrayProtocol` cannot
/// override the generated function. To make a function overridable, build an
/// `ArrayFunction` and register it explicitly, as in the second example below.
///
/// Defining a function with the macro:
/// ```ignore
/// use scirs2_core::array_function_def;
///
/// let add = array_function_def!(
///     fn add(a: f64, b: f64) -> f64 { a + b },
///     "scirs2::example::add"
/// );
/// assert_eq!(add(1.0, 2.0), 3.0);
/// ```
///
/// Registering an overridable function by hand:
/// ```ignore
/// use scirs2_core::array_protocol::{ArrayFunction, ArrayFunctionRegistry};
/// use std::sync::Arc;
/// use std::collections::HashMap;
/// use std::any::Any;
///
/// // Define and register a sum function
/// fn register_sum_function() {
///     let implementation = Arc::new(
///         move |args: &[Box<dyn Any>], kwargs: &HashMap<String, Box<dyn Any>>| {
///             if let Some(array) = args.get(0)
///                 .and_then(|arg| arg.downcast_ref::<crate::ndarray::Array<f64, crate::ndarray::Ix2>>()) {
///                 let sum = array.sum();
///                 Ok(Box::new(sum) as Box<dyn Any>)
///             } else {
///                 Err(scirs2_core::error::CoreError::InvalidArgument(
///                     scirs2_core::error::ErrorContext::new(
///                         "Expected Array2<f64> as first argument".to_string()
///                     )
///                 ))
///             }
///         }
///     );
///
///     let func = ArrayFunction {
///         name: "scirs2::sum",
///         implementation,
///     };
///
///     // Register the function
///     if let Ok(mut registry) = ArrayFunctionRegistry::global().write() {
///         registry.register(func);
///     }
/// }
/// ```
#[macro_export]
macro_rules! array_function_def {
    (fn $name:ident $(<$($gen:ident),*>)? ($($arg:ident : $arg_ty:ty),*) -> $ret:ty $body:block, $funcname:expr) => {
        {
            // Define the function
            fn $name $(<$($gen),*>)? ($($arg : $arg_ty),*) -> $ret $body

            // Return the function so it can be used
            $name
        }
    };
}
1132
1133// Re-export distributed array implementation
1134pub use self::distributed_impl::{
1135    ArrayChunk, DistributedBackend, DistributedConfig, DistributedNdarray, DistributionStrategy,
1136};
1137
1138// Re-export GPU array implementation
1139pub use self::gpu_impl::{
1140    kernels as gpu_kernels, GPUArrayBuilder, GPUBackend, GPUConfig, GPUNdarray,
1141};
1142
1143// Re-export JIT compilation implementation
1144pub use self::jit_impl::{
1145    CraneliftFunctionFactory, JITBackend, JITConfig, JITEnabledArray, JITFunctionImpl, JITManager,
1146    LLVMFunctionFactory,
1147};
1148
1149// Re-export array operations
1150pub use self::operations::{
1151    add, apply_elementwise, concatenate, inverse, matmul, multiply, reshape, subtract, sum, svd,
1152    transpose, OperationError,
1153};
1154
1155// Re-export ml_ops
1156pub use self::ml_ops::{
1157    activation, batch_norm, conv2d, cross_entropy, dropout, max_pool2d, self_attention,
1158    ActivationFunc,
1159};
1160
1161/// Initializes the array protocol system.
1162///
1163/// This function initializes the array protocol system by registering the
1164/// default JIT function factories and other components. It should be called
1165/// before using any of the array protocol features.
1166#[allow(dead_code)]
1167pub fn init() {
1168    // Initialize the JIT manager
1169    let mut jit_manager = JITManager::global().write().expect("Operation failed");
1170    jit_manager.initialize();
1171}
1172
/// Extra traits for third-party array implementations.
///
/// Optional capability traits layered on top of `ArrayProtocol`. This module
/// only declares them; it provides no implementations.
pub mod traits {
    use super::*;

    /// Trait for arrays that support strided access.
    pub trait StridedArray: ArrayProtocol {
        /// Get the strides of this array.
        ///
        /// NOTE(review): the unit (elements vs. bytes) is not fixed by this
        /// signature — TODO confirm and document for implementors.
        #[must_use]
        fn strides(&self) -> Vec<usize>;

        /// Check if this array is contiguous.
        #[must_use]
        fn is_contiguous(&self) -> bool;

        /// Check if this array is Fortran-contiguous (column-major).
        #[must_use]
        fn is_fortran_contiguous(&self) -> bool;
    }

    /// Trait for arrays that support zero-copy operations.
    pub trait ZeroCopyArray: ArrayProtocol {
        /// Create a view of this array.
        #[must_use]
        fn view(&self) -> Box<dyn ZeroCopyArray>;

        /// Create a mutable view of this array.
        ///
        /// NOTE(review): the returned box is not lifetime-bound to `self`, so
        /// the signature cannot express a borrow of the original storage —
        /// confirm this is intended.
        #[must_use]
        fn view_mut(&mut self) -> Box<dyn ZeroCopyArray>;

        /// Check if this array is a view.
        #[must_use]
        fn is_view(&self) -> bool;
    }

    /// Trait for arrays that support automatic differentiation.
    pub trait DifferentiableArray: ArrayProtocol {
        /// Compute the gradient of this array with respect to some variables.
        ///
        /// Presumably returns one gradient per entry of `variables`, in the
        /// same order — TODO confirm the contract.
        fn gradient(
            &self,
            variables: &[Box<dyn DifferentiableArray>],
        ) -> Vec<Box<dyn DifferentiableArray>>;

        /// Set whether to record operations for automatic differentiation.
        fn set_requiresgrad(&mut self, requiresgrad: bool);

        /// Check if this array requires gradient computation.
        #[must_use]
        fn requiresgrad(&self) -> bool;

        /// Get the gradient of this array.
        ///
        /// `None` when no gradient is available.
        #[must_use]
        fn grad(&self) -> Option<Box<dyn DifferentiableArray>>;
    }

    /// Trait for arrays that support asynchronous operations.
    ///
    /// `async_op` is generic and returns `impl Future`, so this trait is not
    /// object-safe: it cannot be used as `dyn AsyncArray`.
    pub trait AsyncArray: ArrayProtocol {
        /// Perform an asynchronous operation on this array.
        ///
        /// `op` receives `&Self` and its `CoreResult<R>` becomes the output of
        /// the returned future.
        fn async_op<F, R>(&self, op: F) -> impl std::future::Future<Output = CoreResult<R>>
        where
            F: FnOnce(&Self) -> CoreResult<R> + Send + 'static,
            R: Send + 'static;

        /// Check if this array supports asynchronous operations.
        #[must_use]
        fn supports_async(&self) -> bool;
    }
}
1240
#[cfg(test)]
mod tests {
    use super::*;

    /// A registered function can be looked up again by its name.
    #[test]
    fn test_array_protocol_registry() {
        const NAME: &str = "scirs2::test::test_func";

        let answer = Arc::new(
            move |_args: &[Box<dyn Any>], _kwargs: &HashMap<String, Box<dyn Any>>| {
                Ok(Box::new(42.0) as Box<dyn Any>)
            },
        );
        let func = ArrayFunction {
            name: NAME,
            implementation: answer,
        };

        let registry = ArrayFunctionRegistry::global();

        // The write guard is a temporary and is released at the end of this
        // statement, before the read lock below is taken.
        registry
            .write()
            .expect("Operation failed")
            .register(func);

        let reader = registry.read().expect("Operation failed");
        let found = reader.get(NAME).expect("Operation failed");
        assert_eq!(found.name, NAME);
    }

    /// The distributed mock reports its type and chunk count.
    #[test]
    fn test_mock_distributed_array() {
        let mock = MockDistributedArray::new(vec![1.0, 2.0, 3.0], vec![3]);
        assert!(mock.is_distributed());

        let details = mock.distribution_info();
        for (key, expected) in [("type", "mock_distributed"), ("chunks", "3")] {
            assert_eq!(details.get(key).expect("Operation failed"), expected);
        }
    }

    /// The GPU mock reports its device label and type.
    #[test]
    fn test_mock_gpu_array() {
        let mock = MockGPUArray::new(vec![1.0, 2.0, 3.0], vec![3], "cuda:0".to_string());
        assert!(mock.is_on_gpu());

        let details = mock.device_info();
        for (key, expected) in [("device", "cuda:0"), ("type", "mock_gpu")] {
            assert_eq!(details.get(key).expect("Operation failed"), expected);
        }
    }

    /// Cloning a `Box<dyn ArrayProtocol>` preserves the reported shape.
    #[test]
    fn test_box_clone() {
        let dense: Box<dyn ArrayProtocol> = Box::new(NdarrayWrapper::new(
            crate::ndarray::Array2::<f64>::ones((3, 3)),
        ));
        let dense_copy = dense.clone();
        assert_eq!(dense_copy.shape(), &[3, 3]);

        let mock: Box<dyn ArrayProtocol> =
            Box::new(MockDistributedArray::new(vec![1.0, 2.0, 3.0], vec![3]));
        let mock_copy = mock.clone();
        assert_eq!(mock_copy.shape(), &[3]);
    }
}
1318
/// Examples of using the array protocol.
#[cfg(test)]
mod examples {
    use super::*;
    use ::ndarray::Array2;
    use std::any::Any;
    use std::collections::HashMap;

    /// Example: Create and use a distributed array.
    #[test]
    fn example_distributed_array() {
        // Create a regular array
        let array = Array2::<f64>::ones((10, 5));

        // Create a distributed array configuration
        let config = DistributedConfig {
            chunks: 3,
            balance: true,
            strategy: DistributionStrategy::RowWise,
            backend: DistributedBackend::Threaded,
        };

        // Create a distributed array
        let dist_array = DistributedNdarray::from_array(&array, config);

        // Check that the array was split correctly
        assert_eq!(dist_array.num_chunks(), 3);
        assert_eq!(dist_array.shape(), &[10, 5]);

        // Convert back to a regular array
        let result = dist_array.to_array().expect("Operation failed");

        // Only the shapes are compared: `to_array` does not necessarily return
        // the same dimension type as `array`, so the two values cannot be
        // compared with `assert_eq!` directly.
        assert_eq!(result.shape(), array.shape());
    }

    /// Example: Create and use a GPU array.
    #[test]
    fn example_gpu_array() {
        // Create a regular array
        let array = Array2::<f64>::ones((10, 5));

        // Create a GPU array configuration
        let config = GPUConfig {
            backend: GPUBackend::CUDA,
            device_id: 0,
            async_ops: true,
            mixed_precision: false,
            memory_fraction: 0.9,
        };

        // Create a GPU array
        let gpu_array = GPUNdarray::new(array.clone(), config);

        // Check that the array was created correctly
        assert_eq!(gpu_array.shape(), &[10, 5]);
        assert!(gpu_array.is_on_gpu());

        // Get device information
        let info = gpu_array.device_info();
        assert_eq!(info.get("backend").expect("Operation failed"), "CUDA");

        // Test box_clone for GPU array
        let gpu_box: Box<dyn ArrayProtocol> = Box::new(gpu_array);
        let gpu_clone = gpu_box.clone();

        // Check the cloned GPU array
        assert_eq!(gpu_clone.shape(), &[10, 5]);
    }

    /// Example: Create and use a JIT-enabled array.
    #[test]
    fn example_jit_array() {
        // Initialize the JIT manager
        init();

        // Create a regular array
        let array = Array2::<f64>::ones((10, 5));
        let wrapped = NdarrayWrapper::new(array);

        // Create a JIT-enabled array
        let jitarray: JITEnabledArray<f64, NdarrayWrapper<f64, crate::ndarray::Ix2>> =
            JITEnabledArray::new(wrapped);

        // Check if JIT is supported
        assert!(jitarray.supports_jit());

        // Compile a function
        let expression = "x + y";
        let jit_function = jitarray.compile(expression).expect("Operation failed");

        // Check the function's properties
        assert_eq!(jit_function.source(), expression);

        // Get JIT information
        let info = jitarray.jit_info();
        assert_eq!(info.get("supports_jit").expect("Operation failed"), "true");

        // Test box_clone for JIT-enabled array
        let jit_box: Box<dyn ArrayProtocol> = Box::new(jitarray);
        let jit_clone = jit_box.clone();

        // Check the cloned JIT array
        assert_eq!(jit_clone.shape(), &[10, 5]);
    }

    /// Example: Test cloning Box<dyn ArrayProtocol>
    #[test]
    fn example_cloning_array_protocol_objects() {
        // Create a GPU array with box_clone support
        let array = Array2::<f64>::ones((10, 5));
        let config = GPUConfig::default();
        let gpu_array = GPUNdarray::new(array.clone(), config);

        // Box the array as ArrayProtocol and clone it
        let boxed: Box<dyn ArrayProtocol> = Box::new(gpu_array);
        let cloned = boxed.clone();

        // Verify the clone works correctly
        assert_eq!(cloned.shape(), &[10, 5]);

        // Create a distributed array and test box_clone
        let config = DistributedConfig {
            chunks: 3,
            balance: true,
            strategy: DistributionStrategy::RowWise,
            backend: DistributedBackend::Threaded,
        };
        let dist_array = DistributedNdarray::from_array(&array, config);

        // Box the array as ArrayProtocol and clone it
        let boxed: Box<dyn ArrayProtocol> = Box::new(dist_array);
        let cloned = boxed.clone();

        // Verify the clone works correctly
        assert_eq!(cloned.shape(), &[10, 5]);
    }

    /// Example: Interoperability between different array types
    #[test]
    fn example_array_interoperability() {
        // Initialize the system
        init();

        // Create arrays of different types
        let cpu_array = Array2::<f64>::ones((5, 5));

        // Create a GPU array
        let gpu_config = GPUConfig {
            backend: GPUBackend::CUDA,
            device_id: 0,
            async_ops: false,
            mixed_precision: false,
            memory_fraction: 0.9,
        };
        let gpu_array = GPUNdarray::new(cpu_array.clone(), gpu_config);

        // Create a distributed array
        let dist_config = DistributedConfig {
            chunks: 2,
            balance: true,
            strategy: DistributionStrategy::RowWise,
            backend: DistributedBackend::Threaded,
        };
        let dist_array = DistributedNdarray::from_array(&cpu_array, dist_config);

        // Simple test of interoperability: convert both to Box<dyn ArrayProtocol>
        let gpu_wrapper: Box<dyn ArrayProtocol> = Box::new(gpu_array);
        let dist_wrapper: Box<dyn ArrayProtocol> = Box::new(dist_array);

        // Verify the clones work correctly
        let gpu_clone = gpu_wrapper.clone();
        let dist_clone = dist_wrapper.clone();

        assert_eq!(gpu_clone.shape(), &[5, 5]);
        assert_eq!(dist_clone.shape(), &[5, 5]);
    }

    /// Example: Advanced usage with custom array type
    #[test]
    fn example_custom_array_type() {
        use std::sync::Arc;

        // A minimal third-party array: flat element storage plus a logical shape.
        struct MyCustomArray<T> {
            data: Vec<T>,
            shape: Vec<usize>,
        }

        impl<T: Clone + 'static> MyCustomArray<T> {
            fn new(data: Vec<T>, shape: Vec<usize>) -> Self {
                Self { data, shape }
            }
        }

        // Implement ArrayProtocol for the custom array type
        impl<T: Clone + Send + Sync + 'static> ArrayProtocol for MyCustomArray<T> {
            fn array_function(
                &self,
                func: &ArrayFunction,
                _types: &[TypeId],
                _args: &[Box<dyn Any>],
                _kwargs: &HashMap<String, Box<dyn Any>>,
            ) -> Result<Box<dyn Any>, NotImplemented> {
                if func.name != "scirs2::example::custom_sum" {
                    return Err(NotImplemented);
                }
                // Recover the concrete element type through `Any` instead of
                // reinterpreting raw pointers: `downcast_ref` only succeeds when
                // `T` really is the requested type, so no `unsafe` is needed.
                let data: &dyn Any = &self.data;
                if let Some(values) = data.downcast_ref::<Vec<f64>>() {
                    Ok(Box::new(values.iter().sum::<f64>()))
                } else if let Some(values) = data.downcast_ref::<Vec<f32>>() {
                    Ok(Box::new(values.iter().sum::<f32>()))
                } else {
                    Err(NotImplemented)
                }
            }

            fn as_any(&self) -> &dyn Any {
                self
            }

            fn shape(&self) -> &[usize] {
                &self.shape
            }

            fn box_clone(&self) -> Box<dyn ArrayProtocol> {
                Box::new(MyCustomArray {
                    data: self.data.clone(),
                    shape: self.shape.clone(),
                })
            }
        }

        // Create an instance of the custom array type
        let custom_array = MyCustomArray::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);

        // Test box_clone functionality
        let boxed: Box<dyn ArrayProtocol> = Box::new(custom_array);
        let cloned = boxed.clone();

        // Verify the clone has the correct shape
        assert_eq!(cloned.shape(), &[2, 2]);

        // An ArrayFunction whose own implementation is never called here: the
        // custom array handles the request itself in `array_function`.
        let func = ArrayFunction {
            name: "scirs2::example::custom_sum",
            implementation: Arc::new(move |_args, _kwargs| {
                // Dummy implementation
                Ok(Box::new(42.0) as Box<dyn Any>)
            }),
        };

        // Test array_function directly
        let result = cloned.array_function(
            &func,
            &[std::any::TypeId::of::<f64>()],
            &[],
            &HashMap::new(),
        );

        // The custom array computes 1 + 2 + 3 + 4 = 10.
        let Ok(value) = result else {
            panic!("custom_sum should be handled by MyCustomArray");
        };
        let sum = *value.downcast_ref::<f64>().expect("Operation failed");
        assert_eq!(sum, 10.0);
    }
}
1657/// Make `Box<dyn JITFunction>` cloneable via clone_box
1658impl Clone for Box<dyn JITFunction> {
1659    fn clone(&self) -> Self {
1660        self.clone_box()
1661    }
1662}