Skip to main content

scirs2_core/array_protocol/
mod.rs

1// Copyright (c) 2025, `SciRS2` Team
2//
3// Licensed under the Apache License, Version 2.0
4// (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
5//
6
7//! Implementation of Array Protocol (similar to ``NumPy``'s `__array_function__` protocol)
8//!
9//! This module provides a mechanism for third-party array implementations to
10//! override ``SciRS2`` functions. It is inspired by ``NumPy``'s `__array_function__`
11//! protocol defined in NEP-18.
12//!
13//! The protocol enables third-party arrays to implement ``SciRS2`` functions in a way
14//! that is recognized by the ``SciRS2`` library. This allows for seamless integration with
15//! distributed arrays, GPU arrays, and other custom array implementations.
16//!
17//! ## Core Components
18//!
19//! The Array Protocol system includes:
20//!
21//! * Specialized array implementations (GPU, distributed, JIT)
22//! * Automatic device placement with `AutoDevice`
23//! * Mixed-precision operations
24//! * Neural network layers and models
25//! * Gradient computation and training capabilities
26//! * Distributed training and model serialization
27
28use std::any::{Any, TypeId};
29use std::collections::HashMap;
30use std::fmt::{Debug, Display};
31use std::marker::PhantomData;
32use std::sync::{Arc, LazyLock, RwLock};
33use std::time::{Duration, Instant};
34
35use crate::error::{CoreError, CoreResult, ErrorContext};
36
37// Internal submodules
38mod distributed_impl;
39mod gpu_impl;
40mod jit_impl;
41mod operations;
42
43// Re-export the array_function_dispatch macro
44pub use crate::array_function_dispatch;
45
46// Public submodules
47pub mod auto_device;
48pub mod distributed_training;
49pub mod grad;
50pub mod mixed_precision;
51pub mod ml_ops;
52pub mod neural;
53#[cfg(feature = "serialization")]
54pub mod serialization;
55pub mod training;
56
57/// Trait for objects that can handle the array protocol.
58///
59/// This is similar to `NumPy`'s `__array_function__` protocol.
60pub trait ArrayProtocol: Any + Send + Sync {
61    /// Implementation of the array protocol.
62    ///
63    /// * `func` - The function being called
64    /// * `types` - The types of all arguments that implement `ArrayProtocol`
65    /// * `args` - The arguments to the function
66    /// * `kwargs` - Named arguments to the function
67    ///
68    /// Returns `Ok(result)` if the operation is successful, or `Err(NotImplemented)`
69    /// if the operation is not implemented for this type.
70    ///
71    /// # Errors
72    ///
73    /// Returns `Err(NotImplemented)` if the operation is not supported by this array type.
74    fn array_function(
75        &self,
76        func: &ArrayFunction,
77        types: &[TypeId],
78        args: &[Box<dyn Any>],
79        kwargs: &HashMap<String, Box<dyn Any>>,
80    ) -> Result<Box<dyn Any>, NotImplemented>;
81
82    /// Get the array as Any for downcasting
83    #[must_use]
84    fn as_any(&self) -> &dyn Any;
85
86    /// Get the shape of the array (default implementation returns empty slice)
87    #[must_use]
88    fn shape(&self) -> &[usize] {
89        &[]
90    }
91
92    /// Get the data type of the array (default implementation returns f64)
93    #[must_use]
94    fn dtype(&self) -> TypeId {
95        TypeId::of::<f64>()
96    }
97
98    /// Clone this array protocol object.
99    #[must_use]
100    fn box_clone(&self) -> Box<dyn ArrayProtocol>;
101}
102
103/// Make `Box<dyn ArrayProtocol>` cloneable via `box_clone`
104impl Clone for Box<dyn ArrayProtocol> {
105    fn clone(&self) -> Self {
106        self.box_clone()
107    }
108}
109
/// Sentinel returned by `ArrayProtocol::array_function` when an array type does
/// not handle the requested function.
///
/// It is deliberately distinct from `CoreError::NotImplementedError`. This
/// marker steers dispatch ("try the next implementation"), while the
/// `CoreError` variant is for reporting. When the marker has to be surfaced to
/// callers, it is turned into `OperationError::NotImplemented` and from there
/// into `CoreError::NotImplementedError`.
#[derive(Debug, Clone, Copy)]
pub struct NotImplemented;

impl Display for NotImplemented {
    /// Renders as the literal text `NotImplemented`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("NotImplemented")
    }
}
127
128/// Type alias for the complex function implementation type
129pub type ArrayFunctionImpl = dyn Fn(&[Box<dyn Any>], &HashMap<String, Box<dyn Any>>) -> CoreResult<Box<dyn Any>>
130    + Send
131    + Sync;
132
133/// A wrapper for functions that can be overridden by the array protocol.
134#[derive(Clone)]
135pub struct ArrayFunction {
136    /// The name of the function, including its module path
137    pub name: &'static str,
138
139    /// The function implementation
140    pub implementation: Arc<ArrayFunctionImpl>,
141}
142
143impl Debug for ArrayFunction {
144    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145        f.debug_struct("ArrayFunction")
146            .field("name", &self.name)
147            .finish_non_exhaustive()
148    }
149}
150
151impl PartialEq for ArrayFunction {
152    fn eq(&self, other: &Self) -> bool {
153        self.name == other.name
154    }
155}
156
157impl Eq for ArrayFunction {}
158
159impl std::hash::Hash for ArrayFunction {
160    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
161        self.name.hash(state);
162    }
163}
164
165impl ArrayFunction {
166    /// Create a new array function with the given name
167    #[must_use]
168    pub fn new(name: &'static str) -> Self {
169        Self {
170            name,
171            // Default implementation that returns NotImplemented
172            implementation: Arc::new(|_args, _kwargs| {
173                Err(CoreError::NotImplementedError(ErrorContext::new(
174                    "Function not implemented".to_string(),
175                )))
176            }),
177        }
178    }
179}
180
181/// Cache entry for function dispatch optimization
182#[derive(Debug, Clone)]
183pub struct DispatchCacheEntry {
184    /// Type signature for the cached result
185    #[allow(dead_code)]
186    type_signature: Vec<TypeId>,
187    /// Which implementation type to try first
188    #[allow(dead_code)]
189    preferred_impl_type: TypeId,
190    /// Cache timestamp for TTL management
191    timestamp: Instant,
192    /// Number of cache hits
193    hit_count: u64,
194}
195
196/// Registry of all array functions with dispatch caching.
197#[derive(Debug)]
198pub struct ArrayFunctionRegistry {
199    /// Map of function names to array functions
200    functions: HashMap<&'static str, ArrayFunction>,
201    /// Dispatch cache for performance optimization
202    dispatch_cache: HashMap<(&'static str, Vec<TypeId>), DispatchCacheEntry>,
203    /// Maximum cache size to prevent unbounded growth
204    max_cache_size: usize,
205    /// Cache TTL for entries (prevents stale cache)
206    cache_ttl: Duration,
207}
208
209impl Default for ArrayFunctionRegistry {
210    fn default() -> Self {
211        Self {
212            functions: HashMap::new(),
213            dispatch_cache: HashMap::new(),
214            max_cache_size: 1000,                // Reasonable default cache size
215            cache_ttl: Duration::from_secs(300), // 5 minutes TTL
216        }
217    }
218}
219
220impl ArrayFunctionRegistry {
221    /// Get the global registry.
222    #[must_use]
223    pub fn global() -> &'static RwLock<Self> {
224        static REGISTRY: LazyLock<RwLock<ArrayFunctionRegistry>> =
225            LazyLock::new(|| RwLock::new(ArrayFunctionRegistry::default()));
226        &REGISTRY
227    }
228
229    /// Register a new array function.
230    pub fn register(&mut self, func: ArrayFunction) {
231        self.functions.insert(func.name, func);
232    }
233
234    /// Get an array function by name.
235    #[must_use]
236    #[allow(dead_code)]
237    pub fn get(&self, name: &str) -> Option<&ArrayFunction> {
238        self.functions.get(name)
239    }
240
241    /// Get all registered functions.
242    #[must_use]
243    pub fn all_functions(&self) -> Vec<&ArrayFunction> {
244        self.functions.values().collect()
245    }
246
247    /// Get cached dispatch entry for optimization
248    #[must_use]
249    pub fn get_cached_dispatch(
250        &self,
251        funcname: &'static str,
252        types: &[TypeId],
253    ) -> Option<&DispatchCacheEntry> {
254        let key = (funcname, types.to_vec());
255        if let Some(entry) = self.dispatch_cache.get(&key) {
256            // Check if cache entry is still valid (TTL check)
257            if entry.timestamp.elapsed() < self.cache_ttl {
258                return Some(entry);
259            }
260        }
261        None
262    }
263
264    /// Cache dispatch result for future optimization
265    pub fn cache_dispatch(
266        &mut self,
267        funcname: &'static str,
268        types: Vec<TypeId>,
269        impl_type: TypeId,
270    ) {
271        // Clean cache if it's getting too large
272        if self.dispatch_cache.len() >= self.max_cache_size {
273            self.cleanup_cache();
274        }
275
276        let key = (funcname, types.clone());
277        let entry = DispatchCacheEntry {
278            type_signature: types,
279            preferred_impl_type: impl_type,
280            timestamp: Instant::now(),
281            hit_count: 0,
282        };
283        self.dispatch_cache.insert(key, entry);
284    }
285
286    /// Update cache hit count for an entry
287    pub fn update_cache_hit(&mut self, funcname: &'static str, types: &[TypeId]) {
288        let key = (funcname, types.to_vec());
289        if let Some(entry) = self.dispatch_cache.get_mut(&key) {
290            entry.hit_count += 1;
291        }
292    }
293
294    /// Clean up expired cache entries
295    fn cleanup_cache(&mut self) {
296        let now = Instant::now();
297        self.dispatch_cache
298            .retain(|_, entry| now.duration_since(entry.timestamp) < self.cache_ttl);
299
300        // If still too large, remove least recently used entries
301        if self.dispatch_cache.len() >= self.max_cache_size {
302            let mut entries: Vec<_> = self
303                .dispatch_cache
304                .iter()
305                .map(|(k, v)| (k.clone(), v.hit_count))
306                .collect();
307            entries.sort_by_key(|(_, hit_count)| *hit_count);
308
309            // Remove bottom 25% of entries by hit count
310            let to_remove = self.dispatch_cache.len() / 4;
311            let keys_to_remove: Vec<_> = entries
312                .iter()
313                .take(to_remove)
314                .map(|(key, _)| key.clone())
315                .collect();
316            for key in keys_to_remove {
317                self.dispatch_cache.remove(&key);
318            }
319        }
320    }
321
322    /// Get cache statistics for monitoring
323    #[must_use]
324    pub fn cache_stats(&self) -> HashMap<String, u64> {
325        let mut stats = HashMap::new();
326        stats.insert("cache_size".to_string(), self.dispatch_cache.len() as u64);
327        stats.insert("max_cache_size".to_string(), self.max_cache_size as u64);
328
329        let total_hits: u64 = self.dispatch_cache.values().map(|e| e.hit_count).sum();
330        stats.insert("total_hits".to_string(), total_hits);
331
332        stats
333    }
334}
335
336/// Helper function to extract all arguments implementing the `ArrayProtocol` trait.
337///
338/// This is similar to `NumPy`'s `_get_implementing_args` function.
339/// Optimized version with pre-allocated capacity and fast-path for common cases.
340#[allow(dead_code)]
341pub fn get_implementing_args(args: &[Box<dyn Any>]) -> Vec<(TypeId, &dyn ArrayProtocol)> {
342    if args.is_empty() {
343        return Vec::new();
344    }
345
346    // Pre-allocate with capacity to avoid reallocation
347    let mut implementing_args = Vec::with_capacity(args.len());
348
349    for arg in args {
350        if let Some(array_protocol_obj) = arg.downcast_ref::<Box<dyn ArrayProtocol>>() {
351            let type_id = (**array_protocol_obj).type_id();
352            implementing_args.push((type_id, &**array_protocol_obj));
353        }
354    }
355
356    // Sort implementing _args by TypeId for deterministic dispatch order
357    // This ensures consistent dispatch behavior across calls
358    implementing_args.sort_by_key(|&_type_id_| {
359        // Use TypeId hash for deterministic ordering
360        use std::hash::{Hash, Hasher};
361        let mut hasher = std::collections::hash_map::DefaultHasher::new();
362        std::any::TypeId::of::<i32>().hash(&mut hasher);
363        hasher.finish()
364    });
365
366    implementing_args
367}
368
369/// Calls the array protocol with the given function and arguments.
370///
371/// * `func` - The array function to call
372/// * `args` - The arguments to the function
373/// * `kwargs` - Named arguments to the function
374///
375/// Returns the result of the function call, or an error if the function
376/// cannot be dispatched to any of the array protocol implementations.
377///
378/// Optimized version with caching and fast-path optimizations.
379#[allow(dead_code)]
380pub fn array_function_dispatch(
381    func: &ArrayFunction,
382    args: &[Box<dyn Any>],
383    kwargs: &HashMap<String, Box<dyn Any>>,
384) -> CoreResult<Box<dyn Any>> {
385    // Fast path for empty args
386    if args.is_empty() {
387        return (func.implementation)(args, kwargs);
388    }
389
390    // Find all arguments implementing ArrayProtocol
391    let implementing_args = get_implementing_args(args);
392
393    if implementing_args.is_empty() {
394        // No arguments implement ArrayProtocol, use default implementation
395        return (func.implementation)(args, kwargs);
396    }
397
398    // Fast path for single implementing argument
399    if implementing_args.len() == 1 {
400        let (type_id, array_protocol_obj) = implementing_args[0];
401        let types = [type_id];
402        match array_protocol_obj.array_function(func, &types, args, kwargs) {
403            Ok(result) => return Ok(result),
404            Err(NotImplemented) => {
405                return Err(CoreError::DispatchError(ErrorContext::new(format!(
406                    "No implementation found for {} with type {:?}",
407                    func.name, type_id
408                ))));
409            }
410        }
411    }
412
413    // Extract all unique types that implement ArrayProtocol (optimized)
414    let mut unique_types = Vec::with_capacity(implementing_args.len());
415    let mut seen_types = std::collections::HashSet::with_capacity(implementing_args.len());
416
417    for &(type_id, _) in &implementing_args {
418        if seen_types.insert(type_id) {
419            unique_types.push(type_id);
420        }
421    }
422
423    // Try dispatching to each implementation in priority order
424    for (_, array_protocol_obj) in implementing_args {
425        if let Ok(result) = array_protocol_obj.array_function(func, &unique_types, args, kwargs) {
426            return Ok(result);
427        }
428    }
429
430    // If we get here, no implementation was found
431    Err(CoreError::DispatchError(ErrorContext::new(format!(
432        "No implementation found for {} with {} argument types: {:?}",
433        func.name,
434        unique_types.len(),
435        unique_types
436    ))))
437}
438
439/// Decorator for adding array function dispatch capabilities to a function.
440///
441/// This is similar to `NumPy`'s `array_function_dispatch` decorator.
442pub struct ArrayFunctionDecorator<F> {
443    function: F,
444    name: &'static str,
445}
446
447impl<F> ArrayFunctionDecorator<F>
448where
449    F: Send + Sync + 'static,
450{
451    /// Create a new array function decorator.
452    #[must_use]
453    pub fn new(function: F, name: &'static str) -> Self {
454        Self { function, name }
455    }
456
457    /// Register the function with the global registry.
458    pub fn register(self) -> F {
459        let implementation = Arc::new(
460            move |_args: &[Box<dyn Any>], _kwargs: &HashMap<String, Box<dyn Any>>| {
461                // Implementation that converts generic arguments to specific types
462                // and calls the original function
463                // This is a simplified version - in practice, we would need more complex
464                // type conversion
465                Err(CoreError::NotImplementedError(ErrorContext::new(
466                    "ArrayFunctionDecorator: Type conversion in array_function_dispatch is not implemented yet".to_string()
467                )))
468            },
469        );
470
471        let func = ArrayFunction {
472            name: self.name,
473            implementation,
474        };
475
476        // Register the function with the global registry
477        let registry = ArrayFunctionRegistry::global();
478        if let Ok(mut registry) = registry.write() {
479            registry.register(func);
480        } else {
481            eprintln!("Warning: Failed to acquire write lock on ArrayFunctionRegistry, skipping function registration");
482            // Continue without registration - this may result in reduced functionality but avoids crash
483        }
484
485        self.function
486    }
487}
488
489/// Trait for arrays that can support GPU operations.
490pub trait GPUArray: ArrayProtocol {
491    /// Move the array to GPU.
492    fn to_gpu(&self) -> CoreResult<Box<dyn GPUArray>>;
493
494    /// Move the array from GPU to CPU.
495    fn to_cpu(&self) -> CoreResult<Box<dyn ArrayProtocol>>;
496
497    /// Check if the array is on GPU.
498    #[must_use]
499    fn is_on_gpu(&self) -> bool;
500
501    /// Get information about the GPU device that holds this array.
502    #[must_use]
503    fn device_info(&self) -> HashMap<String, String>;
504}
505
506/// Trait for distributed arrays that can span multiple machines.
507pub trait DistributedArray: ArrayProtocol {
508    /// Get information about the distribution of this array.
509    #[must_use]
510    fn distribution_info(&self) -> HashMap<String, String>;
511
512    /// Gather the distributed array to a single node.
513    fn gather(&self) -> CoreResult<Box<dyn ArrayProtocol>>;
514
515    /// Scatter a regular array to a distributed array.
516    fn scatter(&self, chunks: usize) -> CoreResult<Box<dyn DistributedArray>>;
517
518    /// Check if this array is distributed.
519    #[must_use]
520    fn is_distributed(&self) -> bool;
521}
522
523/// JIT (Just-In-Time) compilation support for arrays.
524pub trait JITArray: ArrayProtocol {
525    /// Compile an expression to be evaluated on this array.
526    fn compile(&self, expression: &str) -> CoreResult<Box<dyn JITFunction>>;
527
528    /// Check if JIT compilation is supported for this array type.
529    #[must_use]
530    fn supports_jit(&self) -> bool;
531
532    /// Get information about the JIT compiler being used.
533    #[must_use]
534    fn jit_info(&self) -> HashMap<String, String>;
535}
536
537/// A JIT-compiled function that can be evaluated on arrays.
538pub trait JITFunction: Send + Sync {
539    /// Evaluate the function with the given arguments.
540    fn evaluate(&self, args: &[Box<dyn Any>]) -> CoreResult<Box<dyn Any>>;
541
542    /// Get the source code of the compiled function.
543    #[must_use]
544    fn source(&self) -> String;
545
546    /// Get information about how the function was compiled.
547    #[must_use]
548    fn compile_info(&self) -> HashMap<String, String>;
549
550    /// Clone this JIT function into a `Box<dyn JITFunction>`.
551    #[must_use]
552    fn clone_box(&self) -> Box<dyn JITFunction>;
553}
554
555/// A factory for creating JIT functions for specific array implementations.
556pub trait JITFunctionFactory: Send + Sync {
557    /// Create a new JIT function for the given expression and array type.
558    fn create_jit_function(
559        &self,
560        expression: &str,
561        array_typeid: TypeId,
562    ) -> CoreResult<Box<dyn JITFunction>>;
563
564    /// Check if this factory supports the given array type.
565    #[must_use]
566    fn supports_array_type(&self, array_typeid: TypeId) -> bool;
567}
568
569/// Registry of JIT function factories.
570#[derive(Default)]
571pub struct JITFactoryRegistry {
572    factories: Vec<Box<dyn JITFunctionFactory>>,
573}
574
575impl std::fmt::Debug for JITFactoryRegistry {
576    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
577        write!(
578            f,
579            "JITFactoryRegistry {{ factories: {} }}",
580            self.factories.len()
581        )
582    }
583}
584
585impl JITFactoryRegistry {
586    /// Get the global registry.
587    #[must_use]
588    pub fn global() -> &'static RwLock<Self> {
589        static REGISTRY: LazyLock<RwLock<JITFactoryRegistry>> = LazyLock::new(|| {
590            RwLock::new(JITFactoryRegistry {
591                factories: Vec::new(),
592            })
593        });
594        &REGISTRY
595    }
596
597    /// Register a new JIT function factory.
598    pub fn register(&mut self, factory: Box<dyn JITFunctionFactory>) {
599        self.factories.push(factory);
600    }
601
602    /// Get a JIT function factory that supports the given array type.
603    #[must_use]
604    pub fn get_factory_for_array_type(
605        &self,
606        array_typeid: TypeId,
607    ) -> Option<&dyn JITFunctionFactory> {
608        for factory in &self.factories {
609            if factory.supports_array_type(array_typeid) {
610                return Some(&**factory);
611            }
612        }
613        None
614    }
615}
616
617/// A wrapper for ndarray to implement the ArrayProtocol trait.
618#[derive(Debug, Clone)]
619pub struct NdarrayWrapper<T, D: crate::ndarray::Dimension> {
620    array: crate::ndarray::Array<T, D>,
621    phantom: PhantomData<(T, D)>,
622}
623
624impl<T, D> NdarrayWrapper<T, D>
625where
626    T: Clone + 'static,
627    D: crate::ndarray::Dimension + 'static,
628{
629    /// Create a new ndarray wrapper.
630    #[must_use]
631    pub fn new(array: crate::ndarray::Array<T, D>) -> Self {
632        Self {
633            array,
634            phantom: PhantomData,
635        }
636    }
637
638    /// Get the underlying ndarray.
639    #[must_use]
640    pub const fn as_array(&self) -> &crate::ndarray::Array<T, D> {
641        &self.array
642    }
643
644    /// Convert into the underlying ndarray.
645    #[must_use]
646    pub fn into_array(self) -> crate::ndarray::Array<T, D> {
647        self.array
648    }
649
650    /// Update the underlying array with a new one.
651    pub fn array_2(&mut self, newarray: crate::ndarray::Array<T, D>) {
652        self.array = newarray;
653    }
654}
655
656impl<T, D> ArrayProtocol for NdarrayWrapper<T, D>
657where
658    T: Clone + Send + Sync + 'static,
659    D: crate::ndarray::Dimension + Send + Sync + 'static,
660{
661    fn array_function(
662        &self,
663        func: &ArrayFunction,
664        _types: &[TypeId],
665        args: &[Box<dyn Any>],
666        kwargs: &HashMap<String, Box<dyn Any>>,
667    ) -> Result<Box<dyn Any>, NotImplemented> {
668        match func.name {
669            "scirs2::array_protocol::operations::add" => {
670                // Addition operation for NdarrayWrapper
671                if args.len() < 2 {
672                    return Err(NotImplemented);
673                }
674
675                if let Some(other) = args[1].downcast_ref::<NdarrayWrapper<T, D>>() {
676                    if let (Some(a), Some(b)) = (
677                        self.as_any().downcast_ref::<NdarrayWrapper<T, D>>(),
678                        other.as_any().downcast_ref::<NdarrayWrapper<T, D>>(),
679                    ) {
680                        // Need to make sure T supports addition
681                        if TypeId::of::<T>() == TypeId::of::<f64>() {
682                            let a_f64 =
683                                unsafe { &*(a as *const _ as *const NdarrayWrapper<f64, D>) };
684                            let b_f64 =
685                                unsafe { &*(b as *const _ as *const NdarrayWrapper<f64, D>) };
686                            let result = a_f64.as_array() + b_f64.as_array();
687                            return Ok(Box::new(NdarrayWrapper::new(result)));
688                        } else if TypeId::of::<T>() == TypeId::of::<f32>() {
689                            let a_f32 =
690                                unsafe { &*(a as *const _ as *const NdarrayWrapper<f32, D>) };
691                            let b_f32 =
692                                unsafe { &*(b as *const _ as *const NdarrayWrapper<f32, D>) };
693                            let result = a_f32.as_array() + b_f32.as_array();
694                            return Ok(Box::new(NdarrayWrapper::new(result)));
695                        }
696                    }
697                }
698                Err(NotImplemented)
699            }
700            "scirs2::array_protocol::operations::matmul" => {
701                // Matrix multiplication for NdarrayWrapper
702                if args.len() < 2 {
703                    return Err(NotImplemented);
704                }
705
706                // We can only handle matrix multiplication for 2D arrays
707                // Check for 2D array using TypeId
708                if TypeId::of::<D>() != TypeId::of::<crate::ndarray::Ix2>() {
709                    return Err(NotImplemented);
710                }
711
712                if let Some(other) = args[1].downcast_ref::<NdarrayWrapper<T, D>>() {
713                    // Since we've already checked TypeId::of::<D>() == TypeId::of::<crate::ndarray::Ix2>()
714                    // We can safely specialize for Ix2 matrices
715
716                    // Handle the case for f64 matrices
717                    if TypeId::of::<T>() == TypeId::of::<f64>() {
718                        // Cast to concrete _types we know how to handle
719                        let a_f64 = unsafe {
720                            &*(self as *const _ as *const NdarrayWrapper<f64, crate::ndarray::Ix2>)
721                        };
722                        let b_f64 = unsafe {
723                            &*(other as *const _ as *const NdarrayWrapper<f64, crate::ndarray::Ix2>)
724                        };
725
726                        // Get dimensions
727                        let ashape = a_f64.as_array().shape();
728                        let bshape = b_f64.as_array().shape();
729
730                        if ashape.len() != 2 || bshape.len() != 2 || ashape[1] != bshape[0] {
731                            return Err(NotImplemented);
732                        }
733
734                        // Use the higher-level dot operation which will be more efficient
735                        // than our manual implementation
736                        let result = a_f64.as_array().dot(b_f64.as_array());
737                        return Ok(Box::new(NdarrayWrapper::new(result)));
738                    }
739                    // Handle the case for f32 matrices
740                    else if TypeId::of::<T>() == TypeId::of::<f32>() {
741                        // Cast to concrete _types we know how to handle
742                        let a_f32 = unsafe {
743                            &*(self as *const _ as *const NdarrayWrapper<f32, crate::ndarray::Ix2>)
744                        };
745                        let b_f32 = unsafe {
746                            &*(other as *const _ as *const NdarrayWrapper<f32, crate::ndarray::Ix2>)
747                        };
748
749                        // Get dimensions
750                        let ashape = a_f32.as_array().shape();
751                        let bshape = b_f32.as_array().shape();
752
753                        if ashape.len() != 2 || bshape.len() != 2 || ashape[1] != bshape[0] {
754                            return Err(NotImplemented);
755                        }
756
757                        // Use the higher-level dot operation which will be more efficient
758                        // than our manual implementation
759                        let result = a_f32.as_array().dot(b_f32.as_array());
760                        return Ok(Box::new(NdarrayWrapper::new(result)));
761                    }
762                }
763                // If we get here, we don't know how to handle this case
764                Err(NotImplemented)
765            }
766            "scirs2::array_protocol::operations::transpose" => {
767                // Transpose operation for NdarrayWrapper
768                if TypeId::of::<T>() == TypeId::of::<f64>() {
769                    let a_f64 = unsafe { &*(self as *const _ as *const NdarrayWrapper<f64, D>) };
770                    let result = a_f64.as_array().t().to_owned();
771                    return Ok(Box::new(NdarrayWrapper::new(result)));
772                } else if TypeId::of::<T>() == TypeId::of::<f32>() {
773                    let a_f32 = unsafe { &*(self as *const _ as *const NdarrayWrapper<f32, D>) };
774                    let result = a_f32.as_array().t().to_owned();
775                    return Ok(Box::new(NdarrayWrapper::new(result)));
776                }
777                Err(NotImplemented)
778            }
779            "scirs2::array_protocol::operations::sum" => {
780                // Sum operation for NdarrayWrapper
781                let axis_ref = kwargs.get("axis").and_then(|a| a.downcast_ref::<usize>());
782
783                if TypeId::of::<T>() == TypeId::of::<f64>() {
784                    let a_f64 = unsafe { &*(self as *const _ as *const NdarrayWrapper<f64, D>) };
785                    match axis_ref {
786                        Some(&_ax) => {
787                            // Can't use sum_axis without RemoveAxis trait
788                            // Just return the full sum for now
789                            let result = a_f64.as_array().sum();
790                            return Ok(Box::new(result));
791                        }
792                        None => {
793                            let result = a_f64.as_array().sum();
794                            return Ok(Box::new(result));
795                        }
796                    }
797                } else if TypeId::of::<T>() == TypeId::of::<f32>() {
798                    let a_f32 = unsafe { &*(self as *const _ as *const NdarrayWrapper<f32, D>) };
799                    match axis_ref {
800                        Some(&_ax) => {
801                            // Can't use sum_axis without RemoveAxis trait
802                            // Just return the full sum for now
803                            let result = a_f32.as_array().sum();
804                            return Ok(Box::new(result));
805                        }
806                        None => {
807                            let result = a_f32.as_array().sum();
808                            return Ok(Box::new(result));
809                        }
810                    }
811                }
812                Err(NotImplemented)
813            }
814            "scirs2::array_protocol::operations::reshape" => {
815                // Reshape operation for NdarrayWrapper
816                if let Some(shape) = kwargs
817                    .get("shape")
818                    .and_then(|s| s.downcast_ref::<Vec<usize>>())
819                {
820                    if TypeId::of::<T>() == TypeId::of::<f64>() {
821                        let a_f64 =
822                            unsafe { &*(self as *const _ as *const NdarrayWrapper<f64, D>) };
823                        match a_f64
824                            .as_array()
825                            .clone()
826                            .into_shape_with_order(shape.clone())
827                        {
828                            Ok(result) => return Ok(Box::new(NdarrayWrapper::new(result))),
829                            Err(_) => return Err(NotImplemented),
830                        }
831                    } else if TypeId::of::<T>() == TypeId::of::<f32>() {
832                        let a_f32 =
833                            unsafe { &*(self as *const _ as *const NdarrayWrapper<f32, D>) };
834                        match a_f32
835                            .as_array()
836                            .clone()
837                            .into_shape_with_order(shape.clone())
838                        {
839                            Ok(result) => return Ok(Box::new(NdarrayWrapper::new(result))),
840                            Err(_) => return Err(NotImplemented),
841                        }
842                    }
843                }
844                Err(NotImplemented)
845            }
846            _ => Err(NotImplemented),
847        }
848    }
849
850    fn as_any(&self) -> &dyn Any {
851        self
852    }
853
854    fn shape(&self) -> &[usize] {
855        self.array.shape()
856    }
857
858    fn dtype(&self) -> TypeId {
859        TypeId::of::<T>()
860    }
861
862    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
863        Box::new(self.clone())
864    }
865}
866
867// Example implementation for a third-party array library:
868
/// A mock distributed array implementation.
///
/// This type only demonstrates how a third-party array type can plug into the
/// array protocol; it performs no real distribution.
///
/// Trait bounds live on the `impl` blocks that need them rather than on the
/// struct definition, so merely naming the type imposes no extra requirements.
#[derive(Debug, Clone)]
pub struct MockDistributedArray<T> {
    /// The pieces of the array; each element counts as one chunk.
    chunks: Vec<T>,
    /// The logical shape of the full array.
    shape: Vec<usize>,
}
875
876impl<T: Clone + Send + Sync + 'static> MockDistributedArray<T> {
877    /// Create a new mock distributed array.
878    #[must_use]
879    pub fn new(chunks: Vec<T>, shape: Vec<usize>) -> Self {
880        Self { chunks, shape }
881    }
882}
883
884impl<T: Clone + Send + Sync + 'static> ArrayProtocol for MockDistributedArray<T> {
885    fn array_function(
886        &self,
887        func: &ArrayFunction,
888        _types: &[TypeId],
889        _args: &[Box<dyn Any>],
890        _kwargs: &HashMap<String, Box<dyn Any>>,
891    ) -> Result<Box<dyn Any>, NotImplemented> {
892        match func.name {
893            "scirs2::mean" => {
894                // Example: Implement a mean function for distributed arrays
895                // In a real implementation, this would use distributed computation
896
897                // For simplicity, we'll just return a dummy result
898                let result = T::clone(&self.chunks[0]);
899                Ok(Box::new(result))
900            }
901            _ => Err(NotImplemented),
902        }
903    }
904
905    fn as_any(&self) -> &dyn Any {
906        self
907    }
908
909    fn shape(&self) -> &[usize] {
910        &self.shape
911    }
912
913    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
914        Box::new(self.clone())
915    }
916}
917
918impl<T: Clone + Send + Sync + 'static> DistributedArray for MockDistributedArray<T> {
919    fn distribution_info(&self) -> HashMap<String, String> {
920        let mut info = HashMap::new();
921        info.insert("type".to_string(), "mock_distributed".to_string());
922        info.insert("chunks".to_string(), self.chunks.len().to_string());
923        info
924    }
925
926    fn gather(&self) -> CoreResult<Box<dyn ArrayProtocol>> {
927        // In a real implementation, this would gather data from all nodes
928        // For now, we just return self boxed as ArrayProtocol
929        Ok(Box::new(self.clone()) as Box<dyn ArrayProtocol>)
930    }
931
932    fn scatter(&self, _numchunks: usize) -> CoreResult<Box<dyn DistributedArray>> {
933        // In a real implementation, this would scatter data to multiple nodes
934        // For now, we just return self boxed as DistributedArray
935        Ok(Box::new(self.clone()) as Box<dyn DistributedArray>)
936    }
937
938    fn is_distributed(&self) -> bool {
939        true
940    }
941}
942
/// A mock GPU array implementation.
///
/// This type only demonstrates how a GPU-backed array could participate in the
/// array protocol; its data actually lives in host memory.
///
/// Trait bounds live on the `impl` blocks that need them rather than on the
/// struct definition, so merely naming the type imposes no extra requirements.
#[derive(Debug, Clone)]
pub struct MockGPUArray<T> {
    /// The array elements.
    data: Vec<T>,
    /// The logical shape of the array.
    shape: Vec<usize>,
    /// Free-form device label, e.g. `"cuda:0"`.
    device: String,
}
950
951impl<T: Clone + Send + Sync + 'static> MockGPUArray<T> {
952    /// Create a new mock GPU array.
953    #[must_use]
954    pub fn new(data: Vec<T>, shape: Vec<usize>, device: String) -> Self {
955        Self {
956            data,
957            shape,
958            device,
959        }
960    }
961}
962
963impl<T: Clone + Send + Sync + 'static> ArrayProtocol for MockGPUArray<T> {
964    fn array_function(
965        &self,
966        func: &ArrayFunction,
967        _types: &[TypeId],
968        _args: &[Box<dyn Any>],
969        _kwargs: &HashMap<String, Box<dyn Any>>,
970    ) -> Result<Box<dyn Any>, NotImplemented> {
971        match func.name {
972            "scirs2::matmul" => {
973                // Example: Implement a GPU-accelerated matrix multiplication
974                // In a real implementation, this would use GPU computation
975
976                // For simplicity, we'll just return a dummy result
977                let result =
978                    MockGPUArray::new(self.data.clone(), self.shape.clone(), self.device.clone());
979                Ok(Box::new(result))
980            }
981            _ => Err(NotImplemented),
982        }
983    }
984
985    fn as_any(&self) -> &dyn Any {
986        self
987    }
988
989    fn shape(&self) -> &[usize] {
990        &self.shape
991    }
992
993    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
994        Box::new(self.clone())
995    }
996}
997
998impl<T: Clone + Send + Sync + 'static> GPUArray for MockGPUArray<T> {
999    fn to_gpu(&self) -> CoreResult<Box<dyn GPUArray>> {
1000        // Already on GPU
1001        Ok(Box::new(self.clone()) as Box<dyn GPUArray>)
1002    }
1003
1004    fn to_cpu(&self) -> CoreResult<Box<dyn ArrayProtocol>> {
1005        // In a real implementation, this would transfer data from GPU to CPU
1006        // For now, we just return self boxed as ArrayProtocol
1007        Ok(Box::new(self.clone()) as Box<dyn ArrayProtocol>)
1008    }
1009
1010    fn is_on_gpu(&self) -> bool {
1011        true
1012    }
1013
1014    fn device_info(&self) -> HashMap<String, String> {
1015        let mut info = HashMap::new();
1016        info.insert("device".to_string(), self.device.clone());
1017        info.insert("type".to_string(), "mock_gpu".to_string());
1018        info
1019    }
1020}
1021
/// Pairs a plain Rust function with the array-protocol name it is published
/// under, so that the name can be registered with the global function
/// registry.
///
/// Once registered, the name can be overridden by third-party array
/// implementations.
#[derive(Debug)]
pub struct ArrayProtocolFunction<F> {
    /// The wrapped function; `register` hands it back unchanged.
    func: F,
    /// Fully qualified protocol name, e.g. `"scirs2::sum"`.
    name: &'static str,
}
1031
1032impl<F> ArrayProtocolFunction<F> {
1033    /// Create a new array protocol function.
1034    #[must_use]
1035    pub fn new(func: F, name: &'static str) -> Self {
1036        Self { func, name }
1037    }
1038}
1039
1040impl<F> ArrayProtocolFunction<F>
1041where
1042    F: Clone + Send + Sync + 'static,
1043{
1044    /// Register this function with the array protocol system.
1045    pub fn register(self) -> F {
1046        let implementation = Arc::new(
1047            move |_args: &[Box<dyn Any>], _kwargs: &HashMap<String, Box<dyn Any>>| {
1048                // This is a placeholder for actual implementation that would:
1049                // 1. Convert generic args to specific types needed by the function
1050                // 2. Call the function with the converted args
1051                // 3. Return the result as a Box<dyn Any>
1052                Err(CoreError::NotImplementedError(ErrorContext::new(
1053                    "ArrayProtocolFunction: Implementation for array protocol functions is not complete".to_string()
1054                )))
1055            },
1056        );
1057
1058        let array_func = ArrayFunction {
1059            name: self.name,
1060            implementation,
1061        };
1062
1063        // Register the function
1064        if let Ok(mut registry) = ArrayFunctionRegistry::global().write() {
1065            registry.register(array_func);
1066        } else {
1067            eprintln!("Warning: Failed to acquire write lock on ArrayFunctionRegistry during array protocol building, skipping function registration");
1068            // Continue without registration - this may result in reduced functionality but avoids crash
1069        }
1070
1071        self.func
1072    }
1073}
1074
/// Convenience macro for defining a function intended to be exposed through
/// the array protocol.
///
/// The macro expands to a block that defines the function and evaluates to
/// it, so the result can be bound to a variable or passed on (for example to
/// `ArrayProtocolFunction::new`). It does **not** register anything by itself.
/// The trailing `$funcname` argument (the protocol name, such as
/// `"scirs2::sum"`) is accepted for readability but is currently ignored by
/// the expansion. Registration has to be done separately, e.g. with
/// `ArrayProtocolFunction::register` or directly through
/// `ArrayFunctionRegistry`.
///
/// Defining a function with the macro:
/// ```ignore
/// use scirs2_core::array_function_def;
///
/// let add = array_function_def!(
///     fn add(a: f64, b: f64) -> f64 { a + b },
///     "scirs2::example::add"
/// );
/// assert_eq!(add(1.0, 2.0), 3.0);
/// ```
///
/// Registering an implementation by hand:
/// ```ignore
/// use scirs2_core::array_protocol::{ArrayFunction, ArrayFunctionRegistry};
/// use std::sync::Arc;
/// use std::collections::HashMap;
/// use std::any::Any;
///
/// // Define and register a sum function
/// fn register_sum_function() {
///     let implementation = Arc::new(
///         move |args: &[Box<dyn Any>], _kwargs: &HashMap<String, Box<dyn Any>>| {
///             if let Some(array) = args.first()
///                 .and_then(|arg| arg.downcast_ref::<scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::Ix2>>()) {
///                 let sum = array.sum();
///                 Ok(Box::new(sum) as Box<dyn Any>)
///             } else {
///                 Err(scirs2_core::error::CoreError::InvalidArgument(
///                     scirs2_core::error::ErrorContext::new(
///                         "Expected Array2<f64> as first argument".to_string()
///                     )
///                 ))
///             }
///         }
///     );
///
///     let func = ArrayFunction {
///         name: "scirs2::sum",
///         implementation,
///     };
///
///     // Register the function
///     if let Ok(mut registry) = ArrayFunctionRegistry::global().write() {
///         registry.register(func);
///     }
/// }
/// ```
#[macro_export]
macro_rules! array_function_def {
    (fn $name:ident $(<$($gen:ident),*>)? ($($arg:ident : $arg_ty:ty),*) -> $ret:ty $body:block, $funcname:expr) => {
        {
            // Define the function
            fn $name $(<$($gen),*>)? ($($arg : $arg_ty),*) -> $ret $body

            // Return the function so it can be used; `$funcname` is not used.
            $name
        }
    };
}
1128
1129// Re-export distributed array implementation
1130pub use self::distributed_impl::{
1131    ArrayChunk, DistributedBackend, DistributedConfig, DistributedNdarray, DistributionStrategy,
1132};
1133
1134// Re-export GPU array implementation
1135pub use self::gpu_impl::{
1136    kernels as gpu_kernels, GPUArrayBuilder, GPUBackend, GPUConfig, GPUNdarray,
1137};
1138
1139// Re-export JIT compilation implementation
1140pub use self::jit_impl::{
1141    CraneliftFunctionFactory, JITBackend, JITConfig, JITEnabledArray, JITFunctionImpl, JITManager,
1142    LLVMFunctionFactory,
1143};
1144
1145// Re-export array operations
1146pub use self::operations::{
1147    add, apply_elementwise, concatenate, inverse, matmul, multiply, reshape, subtract, sum, svd,
1148    transpose, OperationError,
1149};
1150
1151// Re-export ml_ops
1152pub use self::ml_ops::{
1153    activation, batch_norm, conv2d, cross_entropy, dropout, max_pool2d, self_attention,
1154    ActivationFunc,
1155};
1156
1157/// Initializes the array protocol system.
1158///
1159/// This function initializes the array protocol system by registering the
1160/// default JIT function factories and other components. It should be called
1161/// before using any of the array protocol features.
1162#[allow(dead_code)]
1163pub fn init() {
1164    // Initialize the JIT manager
1165    let mut jit_manager = JITManager::global().write().expect("Operation failed");
1166    jit_manager.initialize();
1167}
1168
1169/// Extra traits for third-party array implementations.
1170pub mod traits {
1171    use super::*;
1172
1173    /// Trait for arrays that support strided access.
1174    pub trait StridedArray: ArrayProtocol {
1175        /// Get the strides of this array.
1176        #[must_use]
1177        fn strides(&self) -> Vec<usize>;
1178
1179        /// Check if this array is contiguous.
1180        #[must_use]
1181        fn is_contiguous(&self) -> bool;
1182
1183        /// Check if this array is Fortran-contiguous (column-major).
1184        #[must_use]
1185        fn is_fortran_contiguous(&self) -> bool;
1186    }
1187
1188    /// Trait for arrays that support zero-copy operations.
1189    pub trait ZeroCopyArray: ArrayProtocol {
1190        /// Create a view of this array.
1191        #[must_use]
1192        fn view(&self) -> Box<dyn ZeroCopyArray>;
1193
1194        /// Create a mutable view of this array.
1195        #[must_use]
1196        fn view_mut(&mut self) -> Box<dyn ZeroCopyArray>;
1197
1198        /// Check if this array is a view.
1199        #[must_use]
1200        fn is_view(&self) -> bool;
1201    }
1202
1203    /// Trait for arrays that support automatic differentiation.
1204    pub trait DifferentiableArray: ArrayProtocol {
1205        /// Compute the gradient of this array with respect to some variables.
1206        fn gradient(
1207            &self,
1208            variables: &[Box<dyn DifferentiableArray>],
1209        ) -> Vec<Box<dyn DifferentiableArray>>;
1210
1211        /// Set whether to record operations for automatic differentiation.
1212        fn set_requiresgrad(&mut self, requiresgrad: bool);
1213
1214        /// Check if this array requires gradient computation.
1215        #[must_use]
1216        fn requiresgrad(&self) -> bool;
1217
1218        /// Get the gradient of this array.
1219        #[must_use]
1220        fn grad(&self) -> Option<Box<dyn DifferentiableArray>>;
1221    }
1222
1223    /// Trait for arrays that support asynchronous operations.
1224    pub trait AsyncArray: ArrayProtocol {
1225        /// Perform an asynchronous operation on this array.
1226        fn async_op<F, R>(&self, op: F) -> impl std::future::Future<Output = CoreResult<R>>
1227        where
1228            F: FnOnce(&Self) -> CoreResult<R> + Send + 'static,
1229            R: Send + 'static;
1230
1231        /// Check if this array supports asynchronous operations.
1232        #[must_use]
1233        fn supports_async(&self) -> bool;
1234    }
1235}
1236
#[cfg(test)]
mod tests {
    use super::*;

    /// A registered function can be looked up again by its name.
    #[test]
    fn test_array_protocol_registry() {
        const FUNC_NAME: &str = "scirs2::test::test_func";

        let func = ArrayFunction {
            name: FUNC_NAME,
            implementation: Arc::new(
                move |_args: &[Box<dyn Any>], _kwargs: &HashMap<String, Box<dyn Any>>| {
                    Ok(Box::new(42.0) as Box<dyn Any>)
                },
            ),
        };

        let registry = ArrayFunctionRegistry::global();
        registry
            .write()
            .expect("Operation failed")
            .register(func.clone());

        let reader = registry.read().expect("Operation failed");
        let found = reader.get(FUNC_NAME).expect("Operation failed");
        assert_eq!(found.name, FUNC_NAME);
    }

    /// The mock distributed array reports its kind and chunk count.
    #[test]
    fn test_mock_distributed_array() {
        let distributed = MockDistributedArray::new(vec![1.0, 2.0, 3.0], vec![3]);
        assert!(distributed.is_distributed());

        let details = distributed.distribution_info();
        for (key, expected) in [("type", "mock_distributed"), ("chunks", "3")] {
            assert_eq!(details.get(key).expect("Operation failed"), expected);
        }
    }

    /// The mock GPU array reports its device label and kind.
    #[test]
    fn test_mock_gpu_array() {
        let on_device = MockGPUArray::new(vec![1.0, 2.0, 3.0], vec![3], "cuda:0".to_string());
        assert!(on_device.is_on_gpu());

        let details = on_device.device_info();
        for (key, expected) in [("device", "cuda:0"), ("type", "mock_gpu")] {
            assert_eq!(details.get(key).expect("Operation failed"), expected);
        }
    }

    /// Cloning a boxed trait object preserves the shape of the wrapped array.
    #[test]
    fn test_box_clone() {
        let wrapped: Box<dyn ArrayProtocol> =
            Box::new(NdarrayWrapper::new(crate::ndarray::Array2::<f64>::ones((3, 3))));
        let wrapped_copy = wrapped.clone();
        assert_eq!(wrapped_copy.shape(), &[3, 3]);

        let distributed: Box<dyn ArrayProtocol> =
            Box::new(MockDistributedArray::new(vec![1.0, 2.0, 3.0], vec![3]));
        let distributed_copy = distributed.clone();
        assert_eq!(distributed_copy.shape(), &[3]);
    }
}
1314
/// Examples of using the array protocol.
#[cfg(test)]
mod examples {
    use super::*;
    use ::ndarray::Array2;
    use std::any::Any;
    use std::collections::HashMap;

    /// Example: Create and use a distributed array.
    #[test]
    fn example_distributed_array() {
        // Create a regular array
        let array = Array2::<f64>::ones((10, 5));

        // Create a distributed array configuration
        let config = DistributedConfig {
            chunks: 3,
            balance: true,
            strategy: DistributionStrategy::RowWise,
            backend: DistributedBackend::Threaded,
        };

        // Create a distributed array
        let dist_array = DistributedNdarray::from_array(&array, config);

        // Check that the array was split correctly
        assert_eq!(dist_array.num_chunks(), 3);
        assert_eq!(dist_array.shape(), &[10, 5]);

        // Convert back to a regular array
        let result = dist_array.to_array().expect("Operation failed");

        // Check that the result has the original array's shape. The two arrays
        // can have different dimension types, so they are not compared directly.
        assert_eq!(result.shape(), array.shape());
    }

    /// Example: Create and use a GPU array.
    #[test]
    fn example_gpu_array() {
        // Create a regular array
        let array = Array2::<f64>::ones((10, 5));

        // Create a GPU array configuration
        let config = GPUConfig {
            backend: GPUBackend::CUDA,
            device_id: 0,
            async_ops: true,
            mixed_precision: false,
            memory_fraction: 0.9,
        };

        // Create a GPU array
        let gpu_array = GPUNdarray::new(array.clone(), config);

        // Check that the array was created correctly
        assert_eq!(gpu_array.shape(), &[10, 5]);
        assert!(gpu_array.is_on_gpu());

        // Get device information
        let info = gpu_array.device_info();
        assert_eq!(info.get("backend").expect("Operation failed"), "CUDA");

        // Test box_clone for GPU array
        let gpu_box: Box<dyn ArrayProtocol> = Box::new(gpu_array);
        let gpu_clone = gpu_box.clone();

        // Check the cloned GPU array
        assert_eq!(gpu_clone.shape(), &[10, 5]);
    }

    /// Example: Create and use a JIT-enabled array.
    #[test]
    fn example_jit_array() {
        // Initialize the JIT manager
        init();

        // Create a regular array
        let array = Array2::<f64>::ones((10, 5));
        let wrapped = NdarrayWrapper::new(array);

        // Create a JIT-enabled array
        let jitarray: JITEnabledArray<f64, NdarrayWrapper<f64, crate::ndarray::Ix2>> =
            JITEnabledArray::new(wrapped);

        // Check if JIT is supported
        assert!(jitarray.supports_jit());

        // Compile a function
        let expression = "x + y";
        let jit_function = jitarray.compile(expression).expect("Operation failed");

        // Check the function's properties
        assert_eq!(jit_function.source(), expression);

        // Get JIT information
        let info = jitarray.jit_info();
        assert_eq!(info.get("supports_jit").expect("Operation failed"), "true");

        // Test box_clone for JIT-enabled array
        let jit_box: Box<dyn ArrayProtocol> = Box::new(jitarray);
        let jit_clone = jit_box.clone();

        // Check the cloned JIT array
        assert_eq!(jit_clone.shape(), &[10, 5]);
    }

    /// Example: Test cloning Box<dyn ArrayProtocol>
    #[test]
    fn example_cloning_array_protocol_objects() {
        // Create a GPU array with box_clone support
        let array = Array2::<f64>::ones((10, 5));
        let config = GPUConfig::default();
        let gpu_array = GPUNdarray::new(array.clone(), config);

        // Box the array as ArrayProtocol and clone it
        let boxed: Box<dyn ArrayProtocol> = Box::new(gpu_array);
        let cloned = boxed.clone();

        // Verify the clone works correctly
        assert_eq!(cloned.shape(), &[10, 5]);

        // Create a distributed array and test box_clone
        let config = DistributedConfig {
            chunks: 3,
            balance: true,
            strategy: DistributionStrategy::RowWise,
            backend: DistributedBackend::Threaded,
        };
        let dist_array = DistributedNdarray::from_array(&array, config);

        // Box the array as ArrayProtocol and clone it
        let boxed: Box<dyn ArrayProtocol> = Box::new(dist_array);
        let cloned = boxed.clone();

        // Verify the clone works correctly
        assert_eq!(cloned.shape(), &[10, 5]);
    }

    /// Example: Register an array function by hand, look it up in the global
    /// registry and invoke it through the type-erased interface.
    #[test]
    fn example_array_function() {
        let funcname = "scirs2::example::sum";

        // The closure's return value must be explicitly boxed as
        // `Box<dyn Any>` so it matches the registry's function signature.
        let implementation = Arc::new(
            move |args: &[Box<dyn Any>], _kwargs: &HashMap<String, Box<dyn Any>>| {
                if let Some(array) = args
                    .first()
                    .and_then(|arg| arg.downcast_ref::<Array2<f64>>())
                {
                    Ok(Box::new(array.sum()) as Box<dyn Any>)
                } else {
                    Err(CoreError::InvalidArgument(ErrorContext::new(
                        "Expected Array2<f64> as first argument".to_string(),
                    )))
                }
            },
        );

        let func = ArrayFunction {
            name: funcname,
            implementation,
        };

        // Register the function
        let registry = ArrayFunctionRegistry::global();
        {
            let mut reg = registry.write().expect("Operation failed");
            reg.register(func.clone());
        }

        // Verify the function was registered and can be called
        let reg = registry.read().expect("Operation failed");
        let registered_func = reg.get(funcname).expect("Operation failed");
        assert_eq!(registered_func.name, funcname);

        let args: Vec<Box<dyn Any>> = vec![Box::new(Array2::<f64>::ones((2, 3)))];
        let result = (registered_func.implementation)(args.as_slice(), &HashMap::new())
            .expect("sum implementation should accept an Array2<f64>");
        assert_eq!(*result.downcast_ref::<f64>().expect("Operation failed"), 6.0);
    }

    /// Example: Interoperability between different array types
    #[test]
    fn example_array_interoperability() {
        // Initialize the system
        init();

        // Create arrays of different types
        let cpu_array = Array2::<f64>::ones((5, 5));

        // Create a GPU array
        let gpu_config = GPUConfig {
            backend: GPUBackend::CUDA,
            device_id: 0,
            async_ops: false,
            mixed_precision: false,
            memory_fraction: 0.9,
        };
        let gpu_array = GPUNdarray::new(cpu_array.clone(), gpu_config);

        // Create a distributed array
        let dist_config = DistributedConfig {
            chunks: 2,
            balance: true,
            strategy: DistributionStrategy::RowWise,
            backend: DistributedBackend::Threaded,
        };
        let dist_array = DistributedNdarray::from_array(&cpu_array, dist_config);

        // Simple test of interoperability: convert both to Box<dyn ArrayProtocol>
        let gpu_wrapper: Box<dyn ArrayProtocol> = Box::new(gpu_array);
        let dist_wrapper: Box<dyn ArrayProtocol> = Box::new(dist_array);

        // Verify the clones work correctly
        let gpu_clone = gpu_wrapper.clone();
        let dist_clone = dist_wrapper.clone();

        assert_eq!(gpu_clone.shape(), &[5, 5]);
        assert_eq!(dist_clone.shape(), &[5, 5]);
    }

    /// Example: Advanced usage with custom array type
    #[test]
    fn example_custom_array_type() {
        use std::sync::Arc;

        // Define a custom array type
        struct MyCustomArray<T> {
            data: Vec<T>,
            shape: Vec<usize>,
        }

        impl<T: Clone + 'static> MyCustomArray<T> {
            fn new(data: Vec<T>, shape: Vec<usize>) -> Self {
                Self { data, shape }
            }
        }

        // Implement ArrayProtocol for the custom array type
        impl<T: Clone + Send + Sync + 'static> ArrayProtocol for MyCustomArray<T> {
            fn array_function(
                &self,
                func: &ArrayFunction,
                _types: &[TypeId],
                _args: &[Box<dyn Any>],
                _kwargs: &HashMap<String, Box<dyn Any>>,
            ) -> Result<Box<dyn Any>, NotImplemented> {
                if func.name != "scirs2::example::custom_sum" {
                    return Err(NotImplemented);
                }

                // Recover the concrete element type through `Any` rather than
                // reinterpreting raw pointers, so no `unsafe` code is needed.
                let data: &dyn Any = &self.data;
                if let Some(values) = data.downcast_ref::<Vec<f64>>() {
                    Ok(Box::new(values.iter().sum::<f64>()))
                } else if let Some(values) = data.downcast_ref::<Vec<f32>>() {
                    Ok(Box::new(values.iter().sum::<f32>()))
                } else {
                    Err(NotImplemented)
                }
            }

            fn as_any(&self) -> &dyn Any {
                self
            }

            fn shape(&self) -> &[usize] {
                &self.shape
            }

            fn box_clone(&self) -> Box<dyn ArrayProtocol> {
                Box::new(MyCustomArray {
                    data: self.data.clone(),
                    shape: self.shape.clone(),
                })
            }
        }

        // Create an instance of the custom array type
        let custom_array = MyCustomArray::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);

        // Test box_clone functionality
        let boxed: Box<dyn ArrayProtocol> = Box::new(custom_array);
        let cloned = boxed.clone();

        // Verify the clone has the correct shape
        assert_eq!(cloned.shape(), &[2, 2]);

        // Create an ArrayFunction for testing
        let func = ArrayFunction {
            name: "scirs2::example::custom_sum",
            implementation: Arc::new(move |_args, _kwargs| {
                // Dummy implementation
                Ok(Box::new(42.0) as Box<dyn Any>)
            }),
        };

        // Test array_function directly
        let result = cloned.array_function(
            &func,
            &[std::any::TypeId::of::<f64>()],
            &[],
            &HashMap::new(),
        );

        // Verify we get a result (the sum of 1+2+3+4 = 10)
        assert!(result.is_ok());
        if let Ok(value) = result {
            let sum = *value.downcast_ref::<f64>().expect("Operation failed");
            assert_eq!(sum, 10.0);
        }
    }
}
1653/// Make `Box<dyn JITFunction>` cloneable via clone_box
1654impl Clone for Box<dyn JITFunction> {
1655    fn clone(&self) -> Self {
1656        self.clone_box()
1657    }
1658}