scirs2_core/array_protocol/
mod.rs

// Copyright (c) 2025, `SciRS2` Team
//
// Licensed under either of
//
// * Apache License, Version 2.0
//   (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
// * MIT license
//   (LICENSE-MIT or http://opensource.org/licenses/MIT)
//
// at your option.
//

//! Implementation of the Array Protocol (similar to `NumPy`'s `__array_function__` protocol).
//!
//! This module provides a mechanism for third-party array implementations to
//! override `SciRS2` functions. It is inspired by `NumPy`'s `__array_function__`
//! protocol defined in NEP-18.
//!
//! The protocol enables third-party arrays to implement `SciRS2` functions in a way
//! that is recognized by the `SciRS2` library. This allows seamless integration with
//! distributed arrays, GPU arrays, and other custom array implementations.
//!
//! ## Core Components
//!
//! The Array Protocol system includes:
//!
//! * Specialized array implementations (GPU, distributed, JIT)
//! * Automatic device placement with `AutoDevice`
//! * Mixed-precision operations
//! * Neural network layers and models
//! * Gradient computation and training capabilities
//! * Distributed training and model serialization

34use std::any::{Any, TypeId};
35use std::collections::HashMap;
36use std::fmt::{Debug, Display};
37use std::marker::PhantomData;
38use std::sync::{Arc, LazyLock, RwLock};
39use std::time::{Duration, Instant};
40
41use crate::error::{CoreError, CoreResult, ErrorContext};
42
43// Internal submodules
44mod distributed_impl;
45mod gpu_impl;
46mod jit_impl;
47mod operations;
48
49// Re-export the array_function_dispatch macro
50pub use crate::array_function_dispatch;
51
52// Public submodules
53pub mod auto_device;
54pub mod distributed_training;
55pub mod grad;
56pub mod mixed_precision;
57pub mod ml_ops;
58pub mod neural;
59#[cfg(feature = "serialization")]
60pub mod serialization;
61pub mod training;
62
63/// Trait for objects that can handle the array protocol.
64///
65/// This is similar to `NumPy`'s `__array_function__` protocol.
66pub trait ArrayProtocol: Any + Send + Sync {
67    /// Implementation of the array protocol.
68    ///
69    /// * `func` - The function being called
70    /// * `types` - The types of all arguments that implement `ArrayProtocol`
71    /// * `args` - The arguments to the function
72    /// * `kwargs` - Named arguments to the function
73    ///
74    /// Returns `Ok(result)` if the operation is successful, or `Err(NotImplemented)`
75    /// if the operation is not implemented for this type.
76    ///
77    /// # Errors
78    ///
79    /// Returns `Err(NotImplemented)` if the operation is not supported by this array type.
80    fn array_function(
81        &self,
82        func: &ArrayFunction,
83        types: &[TypeId],
84        args: &[Box<dyn Any>],
85        kwargs: &HashMap<String, Box<dyn Any>>,
86    ) -> Result<Box<dyn Any>, NotImplemented>;
87
88    /// Get the array as Any for downcasting
89    #[must_use]
90    fn as_any(&self) -> &dyn Any;
91
92    /// Get the shape of the array (default implementation returns empty slice)
93    #[must_use]
94    fn shape(&self) -> &[usize] {
95        &[]
96    }
97
98    /// Get the data type of the array (default implementation returns f64)
99    #[must_use]
100    fn dtype(&self) -> TypeId {
101        TypeId::of::<f64>()
102    }
103
104    /// Clone this array protocol object.
105    #[must_use]
106    fn box_clone(&self) -> Box<dyn ArrayProtocol>;
107}
108
109/// Make `Box<dyn ArrayProtocol>` cloneable via `box_clone`
110impl Clone for Box<dyn ArrayProtocol> {
111    fn clone(&self) -> Self {
112        self.box_clone()
113    }
114}
115
/// Signal returned by [`ArrayProtocol::array_function`] to decline a call.
///
/// Returning this value tells the dispatcher "this array type does not handle
/// the requested function", so the dispatcher can try the other candidates. It
/// is a protocol signal, not a user-facing error, which is why it is distinct
/// from `CoreError::NotImplementedError`. When every candidate declines,
/// `array_function_dispatch` reports a `CoreError::DispatchError`. Any other
/// conversion to a `CoreError` is presumably done elsewhere in the crate; it is
/// not part of this definition.
#[derive(Clone, Copy, Debug)]
pub struct NotImplemented;

impl Display for NotImplemented {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("NotImplemented")
    }
}
133
134/// Type alias for the complex function implementation type
135pub type ArrayFunctionImpl = dyn Fn(&[Box<dyn Any>], &HashMap<String, Box<dyn Any>>) -> CoreResult<Box<dyn Any>>
136    + Send
137    + Sync;
138
139/// A wrapper for functions that can be overridden by the array protocol.
140#[derive(Clone)]
141pub struct ArrayFunction {
142    /// The name of the function, including its module path
143    pub name: &'static str,
144
145    /// The function implementation
146    pub implementation: Arc<ArrayFunctionImpl>,
147}
148
149impl Debug for ArrayFunction {
150    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151        f.debug_struct("ArrayFunction")
152            .field("name", &self.name)
153            .finish_non_exhaustive()
154    }
155}
156
157impl PartialEq for ArrayFunction {
158    fn eq(&self, other: &Self) -> bool {
159        self.name == other.name
160    }
161}
162
163impl Eq for ArrayFunction {}
164
165impl std::hash::Hash for ArrayFunction {
166    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
167        self.name.hash(state);
168    }
169}
170
171impl ArrayFunction {
172    /// Create a new array function with the given name
173    #[must_use]
174    pub fn new(name: &'static str) -> Self {
175        Self {
176            name,
177            // Default implementation that returns NotImplemented
178            implementation: Arc::new(|_args, _kwargs| {
179                Err(CoreError::NotImplementedError(ErrorContext::new(
180                    "Function not implemented".to_string(),
181                )))
182            }),
183        }
184    }
185}
186
/// A single memoised dispatch decision held in `ArrayFunctionRegistry`'s cache.
#[derive(Debug, Clone)]
pub struct DispatchCacheEntry {
    /// Argument type signature that this decision applies to.
    // Written on insert but not read anywhere visible here, hence the allow.
    #[allow(dead_code)]
    type_signature: Vec<TypeId>,
    /// Implementation type to try first for this signature.
    // Also recorded but not read anywhere visible here.
    #[allow(dead_code)]
    preferred_impl_type: TypeId,
    /// Creation time. The registry compares it against its TTL.
    timestamp: Instant,
    /// Number of cache hits. Low counts are evicted first.
    hit_count: u64,
}
201
202/// Registry of all array functions with dispatch caching.
203#[derive(Debug)]
204pub struct ArrayFunctionRegistry {
205    /// Map of function names to array functions
206    functions: HashMap<&'static str, ArrayFunction>,
207    /// Dispatch cache for performance optimization
208    dispatch_cache: HashMap<(&'static str, Vec<TypeId>), DispatchCacheEntry>,
209    /// Maximum cache size to prevent unbounded growth
210    max_cache_size: usize,
211    /// Cache TTL for entries (prevents stale cache)
212    cache_ttl: Duration,
213}
214
215impl Default for ArrayFunctionRegistry {
216    fn default() -> Self {
217        Self {
218            functions: HashMap::new(),
219            dispatch_cache: HashMap::new(),
220            max_cache_size: 1000,                // Reasonable default cache size
221            cache_ttl: Duration::from_secs(300), // 5 minutes TTL
222        }
223    }
224}
225
226impl ArrayFunctionRegistry {
227    /// Get the global registry.
228    #[must_use]
229    pub fn global() -> &'static RwLock<Self> {
230        static REGISTRY: LazyLock<RwLock<ArrayFunctionRegistry>> =
231            LazyLock::new(|| RwLock::new(ArrayFunctionRegistry::default()));
232        &REGISTRY
233    }
234
235    /// Register a new array function.
236    pub fn register(&mut self, func: ArrayFunction) {
237        self.functions.insert(func.name, func);
238    }
239
240    /// Get an array function by name.
241    #[must_use]
242    #[allow(dead_code)]
243    pub fn get(&self, name: &str) -> Option<&ArrayFunction> {
244        self.functions.get(name)
245    }
246
247    /// Get all registered functions.
248    #[must_use]
249    pub fn all_functions(&self) -> Vec<&ArrayFunction> {
250        self.functions.values().collect()
251    }
252
253    /// Get cached dispatch entry for optimization
254    #[must_use]
255    pub fn get_cached_dispatch(
256        &self,
257        funcname: &'static str,
258        types: &[TypeId],
259    ) -> Option<&DispatchCacheEntry> {
260        let key = (funcname, types.to_vec());
261        if let Some(entry) = self.dispatch_cache.get(&key) {
262            // Check if cache entry is still valid (TTL check)
263            if entry.timestamp.elapsed() < self.cache_ttl {
264                return Some(entry);
265            }
266        }
267        None
268    }
269
270    /// Cache dispatch result for future optimization
271    pub fn cache_dispatch(
272        &mut self,
273        funcname: &'static str,
274        types: Vec<TypeId>,
275        impl_type: TypeId,
276    ) {
277        // Clean cache if it's getting too large
278        if self.dispatch_cache.len() >= self.max_cache_size {
279            self.cleanup_cache();
280        }
281
282        let key = (funcname, types.clone());
283        let entry = DispatchCacheEntry {
284            type_signature: types,
285            preferred_impl_type: impl_type,
286            timestamp: Instant::now(),
287            hit_count: 0,
288        };
289        self.dispatch_cache.insert(key, entry);
290    }
291
292    /// Update cache hit count for an entry
293    pub fn update_cache_hit(&mut self, funcname: &'static str, types: &[TypeId]) {
294        let key = (funcname, types.to_vec());
295        if let Some(entry) = self.dispatch_cache.get_mut(&key) {
296            entry.hit_count += 1;
297        }
298    }
299
300    /// Clean up expired cache entries
301    fn cleanup_cache(&mut self) {
302        let now = Instant::now();
303        self.dispatch_cache
304            .retain(|_, entry| now.duration_since(entry.timestamp) < self.cache_ttl);
305
306        // If still too large, remove least recently used entries
307        if self.dispatch_cache.len() >= self.max_cache_size {
308            let mut entries: Vec<_> = self
309                .dispatch_cache
310                .iter()
311                .map(|(k, v)| (k.clone(), v.hit_count))
312                .collect();
313            entries.sort_by_key(|(_, hit_count)| *hit_count);
314
315            // Remove bottom 25% of entries by hit count
316            let to_remove = self.dispatch_cache.len() / 4;
317            let keys_to_remove: Vec<_> = entries
318                .iter()
319                .take(to_remove)
320                .map(|(key, _)| key.clone())
321                .collect();
322            for key in keys_to_remove {
323                self.dispatch_cache.remove(&key);
324            }
325        }
326    }
327
328    /// Get cache statistics for monitoring
329    #[must_use]
330    pub fn cache_stats(&self) -> HashMap<String, u64> {
331        let mut stats = HashMap::new();
332        stats.insert("cache_size".to_string(), self.dispatch_cache.len() as u64);
333        stats.insert("max_cache_size".to_string(), self.max_cache_size as u64);
334
335        let total_hits: u64 = self.dispatch_cache.values().map(|e| e.hit_count).sum();
336        stats.insert("total_hits".to_string(), total_hits);
337
338        stats
339    }
340}
341
342/// Helper function to extract all arguments implementing the `ArrayProtocol` trait.
343///
344/// This is similar to `NumPy`'s `_get_implementing_args` function.
345/// Optimized version with pre-allocated capacity and fast-path for common cases.
346#[allow(dead_code)]
347pub fn get_implementing_args(args: &[Box<dyn Any>]) -> Vec<(TypeId, &dyn ArrayProtocol)> {
348    if args.is_empty() {
349        return Vec::new();
350    }
351
352    // Pre-allocate with capacity to avoid reallocation
353    let mut implementing_args = Vec::with_capacity(args.len());
354
355    for arg in args {
356        if let Some(array_protocol_obj) = arg.downcast_ref::<Box<dyn ArrayProtocol>>() {
357            let type_id = (**array_protocol_obj).type_id();
358            implementing_args.push((type_id, &**array_protocol_obj));
359        }
360    }
361
362    // Sort implementing _args by TypeId for deterministic dispatch order
363    // This ensures consistent dispatch behavior across calls
364    implementing_args.sort_by_key(|&_type_id_| {
365        // Use TypeId hash for deterministic ordering
366        use std::hash::{Hash, Hasher};
367        let mut hasher = std::collections::hash_map::DefaultHasher::new();
368        std::any::TypeId::of::<i32>().hash(&mut hasher);
369        hasher.finish()
370    });
371
372    implementing_args
373}
374
375/// Calls the array protocol with the given function and arguments.
376///
377/// * `func` - The array function to call
378/// * `args` - The arguments to the function
379/// * `kwargs` - Named arguments to the function
380///
381/// Returns the result of the function call, or an error if the function
382/// cannot be dispatched to any of the array protocol implementations.
383///
384/// Optimized version with caching and fast-path optimizations.
385#[allow(dead_code)]
386pub fn array_function_dispatch(
387    func: &ArrayFunction,
388    args: &[Box<dyn Any>],
389    kwargs: &HashMap<String, Box<dyn Any>>,
390) -> CoreResult<Box<dyn Any>> {
391    // Fast path for empty args
392    if args.is_empty() {
393        return (func.implementation)(args, kwargs);
394    }
395
396    // Find all arguments implementing ArrayProtocol
397    let implementing_args = get_implementing_args(args);
398
399    if implementing_args.is_empty() {
400        // No arguments implement ArrayProtocol, use default implementation
401        return (func.implementation)(args, kwargs);
402    }
403
404    // Fast path for single implementing argument
405    if implementing_args.len() == 1 {
406        let (type_id, array_protocol_obj) = implementing_args[0];
407        let types = [type_id];
408        match array_protocol_obj.array_function(func, &types, args, kwargs) {
409            Ok(result) => return Ok(result),
410            Err(NotImplemented) => {
411                return Err(CoreError::DispatchError(ErrorContext::new(format!(
412                    "No implementation found for {} with type {:?}",
413                    func.name, type_id
414                ))));
415            }
416        }
417    }
418
419    // Extract all unique types that implement ArrayProtocol (optimized)
420    let mut unique_types = Vec::with_capacity(implementing_args.len());
421    let mut seen_types = std::collections::HashSet::with_capacity(implementing_args.len());
422
423    for &(type_id, _) in &implementing_args {
424        if seen_types.insert(type_id) {
425            unique_types.push(type_id);
426        }
427    }
428
429    // Try dispatching to each implementation in priority order
430    for (_, array_protocol_obj) in implementing_args {
431        if let Ok(result) = array_protocol_obj.array_function(func, &unique_types, args, kwargs) {
432            return Ok(result);
433        }
434    }
435
436    // If we get here, no implementation was found
437    Err(CoreError::DispatchError(ErrorContext::new(format!(
438        "No implementation found for {} with {} argument types: {:?}",
439        func.name,
440        unique_types.len(),
441        unique_types
442    ))))
443}
444
445/// Decorator for adding array function dispatch capabilities to a function.
446///
447/// This is similar to `NumPy`'s `array_function_dispatch` decorator.
448pub struct ArrayFunctionDecorator<F> {
449    function: F,
450    name: &'static str,
451}
452
453impl<F> ArrayFunctionDecorator<F>
454where
455    F: Send + Sync + 'static,
456{
457    /// Create a new array function decorator.
458    #[must_use]
459    pub fn new(function: F, name: &'static str) -> Self {
460        Self { function, name }
461    }
462
463    /// Register the function with the global registry.
464    pub fn register(self) -> F {
465        let implementation = Arc::new(
466            move |_args: &[Box<dyn Any>], _kwargs: &HashMap<String, Box<dyn Any>>| {
467                // Implementation that converts generic arguments to specific types
468                // and calls the original function
469                // This is a simplified version - in practice, we would need more complex
470                // type conversion
471                Err(CoreError::NotImplementedError(ErrorContext::new(
472                    "ArrayFunctionDecorator: Type conversion in array_function_dispatch is not implemented yet".to_string()
473                )))
474            },
475        );
476
477        let func = ArrayFunction {
478            name: self.name,
479            implementation,
480        };
481
482        // Register the function with the global registry
483        let registry = ArrayFunctionRegistry::global();
484        if let Ok(mut registry) = registry.write() {
485            registry.register(func);
486        } else {
487            eprintln!("Warning: Failed to acquire write lock on ArrayFunctionRegistry, skipping function registration");
488            // Continue without registration - this may result in reduced functionality but avoids crash
489        }
490
491        self.function
492    }
493}
494
495/// Trait for arrays that can support GPU operations.
496pub trait GPUArray: ArrayProtocol {
497    /// Move the array to GPU.
498    fn to_gpu(&self) -> CoreResult<Box<dyn GPUArray>>;
499
500    /// Move the array from GPU to CPU.
501    fn to_cpu(&self) -> CoreResult<Box<dyn ArrayProtocol>>;
502
503    /// Check if the array is on GPU.
504    #[must_use]
505    fn is_on_gpu(&self) -> bool;
506
507    /// Get information about the GPU device that holds this array.
508    #[must_use]
509    fn device_info(&self) -> HashMap<String, String>;
510}
511
512/// Trait for distributed arrays that can span multiple machines.
513pub trait DistributedArray: ArrayProtocol {
514    /// Get information about the distribution of this array.
515    #[must_use]
516    fn distribution_info(&self) -> HashMap<String, String>;
517
518    /// Gather the distributed array to a single node.
519    fn gather(&self) -> CoreResult<Box<dyn ArrayProtocol>>;
520
521    /// Scatter a regular array to a distributed array.
522    fn scatter(&self, chunks: usize) -> CoreResult<Box<dyn DistributedArray>>;
523
524    /// Check if this array is distributed.
525    #[must_use]
526    fn is_distributed(&self) -> bool;
527}
528
529/// JIT (Just-In-Time) compilation support for arrays.
530pub trait JITArray: ArrayProtocol {
531    /// Compile an expression to be evaluated on this array.
532    fn compile(&self, expression: &str) -> CoreResult<Box<dyn JITFunction>>;
533
534    /// Check if JIT compilation is supported for this array type.
535    #[must_use]
536    fn supports_jit(&self) -> bool;
537
538    /// Get information about the JIT compiler being used.
539    #[must_use]
540    fn jit_info(&self) -> HashMap<String, String>;
541}
542
543/// A JIT-compiled function that can be evaluated on arrays.
544pub trait JITFunction: Send + Sync {
545    /// Evaluate the function with the given arguments.
546    fn evaluate(&self, args: &[Box<dyn Any>]) -> CoreResult<Box<dyn Any>>;
547
548    /// Get the source code of the compiled function.
549    #[must_use]
550    fn source(&self) -> String;
551
552    /// Get information about how the function was compiled.
553    #[must_use]
554    fn compile_info(&self) -> HashMap<String, String>;
555
556    /// Clone this JIT function into a `Box<dyn JITFunction>`.
557    #[must_use]
558    fn clone_box(&self) -> Box<dyn JITFunction>;
559}
560
561/// A factory for creating JIT functions for specific array implementations.
562pub trait JITFunctionFactory: Send + Sync {
563    /// Create a new JIT function for the given expression and array type.
564    fn create_jit_function(
565        &self,
566        expression: &str,
567        array_typeid: TypeId,
568    ) -> CoreResult<Box<dyn JITFunction>>;
569
570    /// Check if this factory supports the given array type.
571    #[must_use]
572    fn supports_array_type(&self, array_typeid: TypeId) -> bool;
573}
574
575/// Registry of JIT function factories.
576#[derive(Default)]
577pub struct JITFactoryRegistry {
578    factories: Vec<Box<dyn JITFunctionFactory>>,
579}
580
581impl std::fmt::Debug for JITFactoryRegistry {
582    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
583        write!(
584            f,
585            "JITFactoryRegistry {{ factories: {} }}",
586            self.factories.len()
587        )
588    }
589}
590
591impl JITFactoryRegistry {
592    /// Get the global registry.
593    #[must_use]
594    pub fn global() -> &'static RwLock<Self> {
595        static REGISTRY: LazyLock<RwLock<JITFactoryRegistry>> = LazyLock::new(|| {
596            RwLock::new(JITFactoryRegistry {
597                factories: Vec::new(),
598            })
599        });
600        &REGISTRY
601    }
602
603    /// Register a new JIT function factory.
604    pub fn register(&mut self, factory: Box<dyn JITFunctionFactory>) {
605        self.factories.push(factory);
606    }
607
608    /// Get a JIT function factory that supports the given array type.
609    #[must_use]
610    pub fn get_factory_for_array_type(
611        &self,
612        array_typeid: TypeId,
613    ) -> Option<&dyn JITFunctionFactory> {
614        for factory in &self.factories {
615            if factory.supports_array_type(array_typeid) {
616                return Some(&**factory);
617            }
618        }
619        None
620    }
621}
622
623/// A wrapper for ndarray to implement the ArrayProtocol trait.
624#[derive(Debug, Clone)]
625pub struct NdarrayWrapper<T, D: ndarray::Dimension> {
626    array: ndarray::Array<T, D>,
627    phantom: PhantomData<(T, D)>,
628}
629
630impl<T, D> NdarrayWrapper<T, D>
631where
632    T: Clone + 'static,
633    D: ndarray::Dimension + 'static,
634{
635    /// Create a new ndarray wrapper.
636    #[must_use]
637    pub fn new(array: ndarray::Array<T, D>) -> Self {
638        Self {
639            array,
640            phantom: PhantomData,
641        }
642    }
643
644    /// Get the underlying ndarray.
645    #[must_use]
646    pub const fn as_array(&self) -> &ndarray::Array<T, D> {
647        &self.array
648    }
649
650    /// Convert into the underlying ndarray.
651    #[must_use]
652    pub fn into_array(self) -> ndarray::Array<T, D> {
653        self.array
654    }
655
656    /// Update the underlying array with a new one.
657    pub fn array_2(&mut self, newarray: ndarray::Array<T, D>) {
658        self.array = newarray;
659    }
660}
661
662impl<T, D> ArrayProtocol for NdarrayWrapper<T, D>
663where
664    T: Clone + Send + Sync + 'static,
665    D: ndarray::Dimension + Send + Sync + 'static,
666{
667    fn array_function(
668        &self,
669        func: &ArrayFunction,
670        _types: &[TypeId],
671        args: &[Box<dyn Any>],
672        kwargs: &HashMap<String, Box<dyn Any>>,
673    ) -> Result<Box<dyn Any>, NotImplemented> {
674        match func.name {
675            "scirs2::array_protocol::operations::add" => {
676                // Addition operation for NdarrayWrapper
677                if args.len() < 2 {
678                    return Err(NotImplemented);
679                }
680
681                if let Some(other) = args[1].downcast_ref::<NdarrayWrapper<T, D>>() {
682                    if let (Some(a), Some(b)) = (
683                        self.as_any().downcast_ref::<NdarrayWrapper<T, D>>(),
684                        other.as_any().downcast_ref::<NdarrayWrapper<T, D>>(),
685                    ) {
686                        // Need to make sure T supports addition
687                        if TypeId::of::<T>() == TypeId::of::<f64>() {
688                            let a_f64 =
689                                unsafe { &*(a as *const _ as *const NdarrayWrapper<f64, D>) };
690                            let b_f64 =
691                                unsafe { &*(b as *const _ as *const NdarrayWrapper<f64, D>) };
692                            let result = a_f64.as_array() + b_f64.as_array();
693                            return Ok(Box::new(NdarrayWrapper::new(result)));
694                        } else if TypeId::of::<T>() == TypeId::of::<f32>() {
695                            let a_f32 =
696                                unsafe { &*(a as *const _ as *const NdarrayWrapper<f32, D>) };
697                            let b_f32 =
698                                unsafe { &*(b as *const _ as *const NdarrayWrapper<f32, D>) };
699                            let result = a_f32.as_array() + b_f32.as_array();
700                            return Ok(Box::new(NdarrayWrapper::new(result)));
701                        }
702                    }
703                }
704                Err(NotImplemented)
705            }
706            "scirs2::array_protocol::operations::matmul" => {
707                // Matrix multiplication for NdarrayWrapper
708                if args.len() < 2 {
709                    return Err(NotImplemented);
710                }
711
712                // We can only handle matrix multiplication for 2D arrays
713                // Check for 2D array using TypeId
714                if TypeId::of::<D>() != TypeId::of::<ndarray::Ix2>() {
715                    return Err(NotImplemented);
716                }
717
718                if let Some(other) = args[1].downcast_ref::<NdarrayWrapper<T, D>>() {
719                    // Since we've already checked TypeId::of::<D>() == TypeId::of::<ndarray::Ix2>()
720                    // We can safely specialize for Ix2 matrices
721
722                    // Handle the case for f64 matrices
723                    if TypeId::of::<T>() == TypeId::of::<f64>() {
724                        // Cast to concrete _types we know how to handle
725                        let a_f64 = unsafe {
726                            &*(self as *const _ as *const NdarrayWrapper<f64, ndarray::Ix2>)
727                        };
728                        let b_f64 = unsafe {
729                            &*(other as *const _ as *const NdarrayWrapper<f64, ndarray::Ix2>)
730                        };
731
732                        // Get dimensions
733                        let ashape = a_f64.as_array().shape();
734                        let bshape = b_f64.as_array().shape();
735
736                        if ashape.len() != 2 || bshape.len() != 2 || ashape[1] != bshape[0] {
737                            return Err(NotImplemented);
738                        }
739
740                        // Use the higher-level dot operation which will be more efficient
741                        // than our manual implementation
742                        let result = a_f64.as_array().dot(b_f64.as_array());
743                        return Ok(Box::new(NdarrayWrapper::new(result)));
744                    }
745                    // Handle the case for f32 matrices
746                    else if TypeId::of::<T>() == TypeId::of::<f32>() {
747                        // Cast to concrete _types we know how to handle
748                        let a_f32 = unsafe {
749                            &*(self as *const _ as *const NdarrayWrapper<f32, ndarray::Ix2>)
750                        };
751                        let b_f32 = unsafe {
752                            &*(other as *const _ as *const NdarrayWrapper<f32, ndarray::Ix2>)
753                        };
754
755                        // Get dimensions
756                        let ashape = a_f32.as_array().shape();
757                        let bshape = b_f32.as_array().shape();
758
759                        if ashape.len() != 2 || bshape.len() != 2 || ashape[1] != bshape[0] {
760                            return Err(NotImplemented);
761                        }
762
763                        // Use the higher-level dot operation which will be more efficient
764                        // than our manual implementation
765                        let result = a_f32.as_array().dot(b_f32.as_array());
766                        return Ok(Box::new(NdarrayWrapper::new(result)));
767                    }
768                }
769                // If we get here, we don't know how to handle this case
770                Err(NotImplemented)
771            }
772            "scirs2::array_protocol::operations::transpose" => {
773                // Transpose operation for NdarrayWrapper
774                if TypeId::of::<T>() == TypeId::of::<f64>() {
775                    let a_f64 = unsafe { &*(self as *const _ as *const NdarrayWrapper<f64, D>) };
776                    let result = a_f64.as_array().t().to_owned();
777                    return Ok(Box::new(NdarrayWrapper::new(result)));
778                } else if TypeId::of::<T>() == TypeId::of::<f32>() {
779                    let a_f32 = unsafe { &*(self as *const _ as *const NdarrayWrapper<f32, D>) };
780                    let result = a_f32.as_array().t().to_owned();
781                    return Ok(Box::new(NdarrayWrapper::new(result)));
782                }
783                Err(NotImplemented)
784            }
785            "scirs2::array_protocol::operations::sum" => {
786                // Sum operation for NdarrayWrapper
787                let axis_ref = kwargs.get("axis").and_then(|a| a.downcast_ref::<usize>());
788
789                if TypeId::of::<T>() == TypeId::of::<f64>() {
790                    let a_f64 = unsafe { &*(self as *const _ as *const NdarrayWrapper<f64, D>) };
791                    match axis_ref {
792                        Some(&_ax) => {
793                            // Can't use sum_axis without RemoveAxis trait
794                            // Just return the full sum for now
795                            let result = a_f64.as_array().sum();
796                            return Ok(Box::new(result));
797                        }
798                        None => {
799                            let result = a_f64.as_array().sum();
800                            return Ok(Box::new(result));
801                        }
802                    }
803                } else if TypeId::of::<T>() == TypeId::of::<f32>() {
804                    let a_f32 = unsafe { &*(self as *const _ as *const NdarrayWrapper<f32, D>) };
805                    match axis_ref {
806                        Some(&_ax) => {
807                            // Can't use sum_axis without RemoveAxis trait
808                            // Just return the full sum for now
809                            let result = a_f32.as_array().sum();
810                            return Ok(Box::new(result));
811                        }
812                        None => {
813                            let result = a_f32.as_array().sum();
814                            return Ok(Box::new(result));
815                        }
816                    }
817                }
818                Err(NotImplemented)
819            }
820            "scirs2::array_protocol::operations::reshape" => {
821                // Reshape operation for NdarrayWrapper
822                if let Some(shape) = kwargs
823                    .get("shape")
824                    .and_then(|s| s.downcast_ref::<Vec<usize>>())
825                {
826                    if TypeId::of::<T>() == TypeId::of::<f64>() {
827                        let a_f64 =
828                            unsafe { &*(self as *const _ as *const NdarrayWrapper<f64, D>) };
829                        match a_f64
830                            .as_array()
831                            .clone()
832                            .into_shape_with_order(shape.clone())
833                        {
834                            Ok(result) => return Ok(Box::new(NdarrayWrapper::new(result))),
835                            Err(_) => return Err(NotImplemented),
836                        }
837                    } else if TypeId::of::<T>() == TypeId::of::<f32>() {
838                        let a_f32 =
839                            unsafe { &*(self as *const _ as *const NdarrayWrapper<f32, D>) };
840                        match a_f32
841                            .as_array()
842                            .clone()
843                            .into_shape_with_order(shape.clone())
844                        {
845                            Ok(result) => return Ok(Box::new(NdarrayWrapper::new(result))),
846                            Err(_) => return Err(NotImplemented),
847                        }
848                    }
849                }
850                Err(NotImplemented)
851            }
852            _ => Err(NotImplemented),
853        }
854    }
855
856    fn as_any(&self) -> &dyn Any {
857        self
858    }
859
860    fn shape(&self) -> &[usize] {
861        self.array.shape()
862    }
863
864    fn dtype(&self) -> TypeId {
865        TypeId::of::<T>()
866    }
867
868    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
869        Box::new(self.clone())
870    }
871}
872
873// Example implementation for a third-party array library:
874
/// A mock distributed array implementation.
///
/// Holds its elements ("chunks") in a local `Vec` together with a logical
/// shape; nothing is actually distributed. Trait bounds live on the `impl`
/// blocks that need them rather than on the type itself.
#[derive(Debug, Clone)]
pub struct MockDistributedArray<T> {
    /// The locally held chunks of the mock array.
    chunks: Vec<T>,
    /// The logical shape reported through `ArrayProtocol::shape`.
    shape: Vec<usize>,
}

impl<T> MockDistributedArray<T> {
    /// Create a new mock distributed array from its chunks and logical shape.
    ///
    /// No validation is performed: `chunks.len()` is not checked against `shape`.
    #[must_use]
    pub fn new(chunks: Vec<T>, shape: Vec<usize>) -> Self {
        Self { chunks, shape }
    }
}
889
890impl<T: Clone + Send + Sync + 'static> ArrayProtocol for MockDistributedArray<T> {
891    fn array_function(
892        &self,
893        func: &ArrayFunction,
894        _types: &[TypeId],
895        _args: &[Box<dyn Any>],
896        _kwargs: &HashMap<String, Box<dyn Any>>,
897    ) -> Result<Box<dyn Any>, NotImplemented> {
898        match func.name {
899            "scirs2::mean" => {
900                // Example: Implement a mean function for distributed arrays
901                // In a real implementation, this would use distributed computation
902
903                // For simplicity, we'll just return a dummy result
904                let result = T::clone(&self.chunks[0]);
905                Ok(Box::new(result))
906            }
907            _ => Err(NotImplemented),
908        }
909    }
910
911    fn as_any(&self) -> &dyn Any {
912        self
913    }
914
915    fn shape(&self) -> &[usize] {
916        &self.shape
917    }
918
919    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
920        Box::new(self.clone())
921    }
922}
923
924impl<T: Clone + Send + Sync + 'static> DistributedArray for MockDistributedArray<T> {
925    fn distribution_info(&self) -> HashMap<String, String> {
926        let mut info = HashMap::new();
927        info.insert("type".to_string(), "mock_distributed".to_string());
928        info.insert("chunks".to_string(), self.chunks.len().to_string());
929        info
930    }
931
932    fn gather(&self) -> CoreResult<Box<dyn ArrayProtocol>> {
933        // In a real implementation, this would gather data from all nodes
934        // For now, we just return self boxed as ArrayProtocol
935        Ok(Box::new(self.clone()) as Box<dyn ArrayProtocol>)
936    }
937
938    fn scatter(&self, _numchunks: usize) -> CoreResult<Box<dyn DistributedArray>> {
939        // In a real implementation, this would scatter data to multiple nodes
940        // For now, we just return self boxed as DistributedArray
941        Ok(Box::new(self.clone()) as Box<dyn DistributedArray>)
942    }
943
944    fn is_distributed(&self) -> bool {
945        true
946    }
947}
948
/// A mock GPU array implementation.
///
/// Stores its data in host memory alongside a logical shape and a device
/// label; no GPU is involved. Trait bounds live on the `impl` blocks that need
/// them rather than on the type itself.
#[derive(Debug, Clone)]
pub struct MockGPUArray<T> {
    /// Element storage (kept on the host in this mock).
    data: Vec<T>,
    /// The logical shape reported through `ArrayProtocol::shape`.
    shape: Vec<usize>,
    /// Free-form device label, e.g. `"cuda:0"`.
    device: String,
}

impl<T> MockGPUArray<T> {
    /// Create a new mock GPU array from its data, logical shape, and device label.
    ///
    /// No validation is performed: `data.len()` is not checked against `shape`.
    #[must_use]
    pub fn new(data: Vec<T>, shape: Vec<usize>, device: String) -> Self {
        Self {
            data,
            shape,
            device,
        }
    }
}
968
969impl<T: Clone + Send + Sync + 'static> ArrayProtocol for MockGPUArray<T> {
970    fn array_function(
971        &self,
972        func: &ArrayFunction,
973        _types: &[TypeId],
974        _args: &[Box<dyn Any>],
975        _kwargs: &HashMap<String, Box<dyn Any>>,
976    ) -> Result<Box<dyn Any>, NotImplemented> {
977        match func.name {
978            "scirs2::matmul" => {
979                // Example: Implement a GPU-accelerated matrix multiplication
980                // In a real implementation, this would use GPU computation
981
982                // For simplicity, we'll just return a dummy result
983                let result =
984                    MockGPUArray::new(self.data.clone(), self.shape.clone(), self.device.clone());
985                Ok(Box::new(result))
986            }
987            _ => Err(NotImplemented),
988        }
989    }
990
991    fn as_any(&self) -> &dyn Any {
992        self
993    }
994
995    fn shape(&self) -> &[usize] {
996        &self.shape
997    }
998
999    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
1000        Box::new(self.clone())
1001    }
1002}
1003
1004impl<T: Clone + Send + Sync + 'static> GPUArray for MockGPUArray<T> {
1005    fn to_gpu(&self) -> CoreResult<Box<dyn GPUArray>> {
1006        // Already on GPU
1007        Ok(Box::new(self.clone()) as Box<dyn GPUArray>)
1008    }
1009
1010    fn to_cpu(&self) -> CoreResult<Box<dyn ArrayProtocol>> {
1011        // In a real implementation, this would transfer data from GPU to CPU
1012        // For now, we just return self boxed as ArrayProtocol
1013        Ok(Box::new(self.clone()) as Box<dyn ArrayProtocol>)
1014    }
1015
1016    fn is_on_gpu(&self) -> bool {
1017        true
1018    }
1019
1020    fn device_info(&self) -> HashMap<String, String> {
1021        let mut info = HashMap::new();
1022        info.insert("device".to_string(), self.device.clone());
1023        info.insert("type".to_string(), "mock_gpu".to_string());
1024        info
1025    }
1026}
1027
/// A factory for creating and registering array protocol enabled functions.
///
/// It pairs a Rust function (or closure) with the protocol name under which it
/// should be registered, so the function can later be overridden by
/// third-party array implementations.
#[derive(Debug)]
pub struct ArrayProtocolFunction<F> {
    /// The wrapped function, handed back by `register`.
    func: F,
    /// The protocol-level name, e.g. `"scirs2::sum"`.
    name: &'static str,
}

impl<F> ArrayProtocolFunction<F> {
    /// Wrap `func` under the protocol name `name`. Nothing is registered yet.
    #[must_use]
    pub fn new(func: F, name: &'static str) -> Self {
        ArrayProtocolFunction { name, func }
    }
}
1045
1046impl<F> ArrayProtocolFunction<F>
1047where
1048    F: Clone + Send + Sync + 'static,
1049{
1050    /// Register this function with the array protocol system.
1051    pub fn register(self) -> F {
1052        let implementation = Arc::new(
1053            move |_args: &[Box<dyn Any>], _kwargs: &HashMap<String, Box<dyn Any>>| {
1054                // This is a placeholder for actual implementation that would:
1055                // 1. Convert generic args to specific types needed by the function
1056                // 2. Call the function with the converted args
1057                // 3. Return the result as a Box<dyn Any>
1058                Err(CoreError::NotImplementedError(ErrorContext::new(
1059                    "ArrayProtocolFunction: Implementation for array protocol functions is not complete".to_string()
1060                )))
1061            },
1062        );
1063
1064        let array_func = ArrayFunction {
1065            name: self.name,
1066            implementation,
1067        };
1068
1069        // Register the function
1070        if let Ok(mut registry) = ArrayFunctionRegistry::global().write() {
1071            registry.register(array_func);
1072        } else {
1073            eprintln!("Warning: Failed to acquire write lock on ArrayFunctionRegistry during array protocol building, skipping function registration");
1074            // Continue without registration - this may result in reduced functionality but avoids crash
1075        }
1076
1077        self.func
1078    }
1079}
1080
/// Convenience macro for defining array protocol functions.
///
/// The macro expands to a block that declares `fn $name` with the given
/// (optionally generic) signature and body, then evaluates to that function
/// item. The result can be bound to a variable or passed on like any other
/// function.
///
/// **Note:** the trailing protocol name (`$funcname`) is accepted but currently
/// unused. The macro does **not** register anything with the array protocol
/// system. To make a function dispatchable, build an `ArrayFunction` and
/// register it with `ArrayFunctionRegistry` yourself, as shown below.
///
/// Example usage:
/// ```ignore
/// use scirs2_core::array_function_def;
/// use scirs2_core::array_protocol::{ArrayFunction, ArrayFunctionRegistry};
/// use std::any::Any;
/// use std::collections::HashMap;
/// use std::sync::Arc;
///
/// // Define a plain function via the macro (no registration happens here).
/// let sum = array_function_def!(
///     fn sum(array: &ndarray::Array2<f64>) -> f64 { array.sum() },
///     "scirs2::sum"
/// );
///
/// // Register a type-erased implementation under the same protocol name.
/// let implementation = Arc::new(
///     move |args: &[Box<dyn Any>], _kwargs: &HashMap<String, Box<dyn Any>>| {
///         if let Some(array) = args
///             .first()
///             .and_then(|arg| arg.downcast_ref::<ndarray::Array2<f64>>())
///         {
///             Ok(Box::new(sum(array)) as Box<dyn Any>)
///         } else {
///             Err(scirs2_core::error::CoreError::InvalidArgument(
///                 scirs2_core::error::ErrorContext::new(
///                     "Expected Array2<f64> as first argument".to_string(),
///                 ),
///             ))
///         }
///     },
/// );
///
/// let func = ArrayFunction {
///     name: "scirs2::sum",
///     implementation,
/// };
///
/// if let Ok(mut registry) = ArrayFunctionRegistry::global().write() {
///     registry.register(func);
/// }
/// ```
#[macro_export]
macro_rules! array_function_def {
    (fn $name:ident $(<$($gen:ident),*>)? ($($arg:ident : $arg_ty:ty),*) -> $ret:ty $body:block, $funcname:expr) => {
        {
            // Define the function
            fn $name $(<$($gen),*>)? ($($arg : $arg_ty),*) -> $ret $body

            // Return the function so it can be used
            $name
        }
    };
}
1134
1135// Re-export distributed array implementation
1136pub use self::distributed_impl::{
1137    ArrayChunk, DistributedBackend, DistributedConfig, DistributedNdarray, DistributionStrategy,
1138};
1139
1140// Re-export GPU array implementation
1141pub use self::gpu_impl::{
1142    kernels as gpu_kernels, GPUArrayBuilder, GPUBackend, GPUConfig, GPUNdarray,
1143};
1144
1145// Re-export JIT compilation implementation
1146pub use self::jit_impl::{
1147    CraneliftFunctionFactory, JITBackend, JITConfig, JITEnabledArray, JITFunctionImpl, JITManager,
1148    LLVMFunctionFactory,
1149};
1150
1151// Re-export array operations
1152pub use self::operations::{
1153    add, apply_elementwise, concatenate, inverse, matmul, multiply, reshape, subtract, sum, svd,
1154    transpose, OperationError,
1155};
1156
1157// Re-export ml_ops
1158pub use self::ml_ops::{
1159    activation, batch_norm, conv2d, cross_entropy, dropout, max_pool2d, self_attention,
1160    ActivationFunc,
1161};
1162
1163/// Initializes the array protocol system.
1164///
1165/// This function initializes the array protocol system by registering the
1166/// default JIT function factories and other components. It should be called
1167/// before using any of the array protocol features.
1168#[allow(dead_code)]
1169pub fn init() {
1170    // Initialize the JIT manager
1171    let mut jit_manager = JITManager::global().write().unwrap();
1172    jit_manager.initialize();
1173}
1174
1175/// Extra traits for third-party array implementations.
1176pub mod traits {
1177    use super::*;
1178
1179    /// Trait for arrays that support strided access.
1180    pub trait StridedArray: ArrayProtocol {
1181        /// Get the strides of this array.
1182        #[must_use]
1183        fn strides(&self) -> Vec<usize>;
1184
1185        /// Check if this array is contiguous.
1186        #[must_use]
1187        fn is_contiguous(&self) -> bool;
1188
1189        /// Check if this array is Fortran-contiguous (column-major).
1190        #[must_use]
1191        fn is_fortran_contiguous(&self) -> bool;
1192    }
1193
1194    /// Trait for arrays that support zero-copy operations.
1195    pub trait ZeroCopyArray: ArrayProtocol {
1196        /// Create a view of this array.
1197        #[must_use]
1198        fn view(&self) -> Box<dyn ZeroCopyArray>;
1199
1200        /// Create a mutable view of this array.
1201        #[must_use]
1202        fn view_mut(&mut self) -> Box<dyn ZeroCopyArray>;
1203
1204        /// Check if this array is a view.
1205        #[must_use]
1206        fn is_view(&self) -> bool;
1207    }
1208
1209    /// Trait for arrays that support automatic differentiation.
1210    pub trait DifferentiableArray: ArrayProtocol {
1211        /// Compute the gradient of this array with respect to some variables.
1212        fn gradient(
1213            &self,
1214            variables: &[Box<dyn DifferentiableArray>],
1215        ) -> Vec<Box<dyn DifferentiableArray>>;
1216
1217        /// Set whether to record operations for automatic differentiation.
1218        fn set_requiresgrad(&mut self, requiresgrad: bool);
1219
1220        /// Check if this array requires gradient computation.
1221        #[must_use]
1222        fn requiresgrad(&self) -> bool;
1223
1224        /// Get the gradient of this array.
1225        #[must_use]
1226        fn grad(&self) -> Option<Box<dyn DifferentiableArray>>;
1227    }
1228
1229    /// Trait for arrays that support asynchronous operations.
1230    pub trait AsyncArray: ArrayProtocol {
1231        /// Perform an asynchronous operation on this array.
1232        fn async_op<F, R>(&self, op: F) -> impl std::future::Future<Output = CoreResult<R>>
1233        where
1234            F: FnOnce(&Self) -> CoreResult<R> + Send + 'static,
1235            R: Send + 'static;
1236
1237        /// Check if this array supports asynchronous operations.
1238        #[must_use]
1239        fn supports_async(&self) -> bool;
1240    }
1241}
1242
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_array_protocol_registry() {
        const FUNC_NAME: &str = "scirs2::test::test_func";

        let func = ArrayFunction {
            name: FUNC_NAME,
            implementation: Arc::new(
                move |_args: &[Box<dyn Any>], _kwargs: &HashMap<String, Box<dyn Any>>| {
                    Ok(Box::new(42.0) as Box<dyn Any>)
                },
            ),
        };

        // Register through a short-lived write guard (dropped at end of statement).
        let registry = ArrayFunctionRegistry::global();
        registry.write().unwrap().register(func);

        // The function should now be retrievable under its protocol name.
        let guard = registry.read().unwrap();
        let found = guard.get(FUNC_NAME).expect("function should be registered");
        assert_eq!(found.name, FUNC_NAME);
    }

    #[test]
    fn test_mock_distributed_array() {
        let dist = MockDistributedArray::new(vec![1.0, 2.0, 3.0], vec![3]);
        assert!(dist.is_distributed());

        let details = dist.distribution_info();
        assert_eq!(details["type"], "mock_distributed");
        assert_eq!(details["chunks"], "3");
    }

    #[test]
    fn test_mock_gpu_array() {
        let gpu = MockGPUArray::new(vec![1.0, 2.0, 3.0], vec![3], "cuda:0".to_string());
        assert!(gpu.is_on_gpu());

        let details = gpu.device_info();
        assert_eq!(details["device"], "cuda:0");
        assert_eq!(details["type"], "mock_gpu");
    }

    #[test]
    fn test_box_clone() {
        // Cloning a boxed NdarrayWrapper keeps its shape.
        let ndarray_box: Box<dyn ArrayProtocol> =
            Box::new(NdarrayWrapper::new(ndarray::Array2::<f64>::ones((3, 3))));
        assert_eq!(ndarray_box.clone().shape(), &[3, 3]);

        // Cloning a boxed MockDistributedArray keeps its shape.
        let mock_box: Box<dyn ArrayProtocol> =
            Box::new(MockDistributedArray::new(vec![1.0, 2.0, 3.0], vec![3]));
        assert_eq!(mock_box.clone().shape(), &[3]);
    }
}
1315
/// Examples of using the array protocol.
#[cfg(test)]
mod examples {
    use super::*;
    use ndarray::Array2;
    use std::any::Any;
    use std::collections::HashMap;

    /// Example: Create and use a distributed array.
    #[test]
    fn example_distributed_array() {
        // Create a regular array
        let array = Array2::<f64>::ones((10, 5));

        // Create a distributed array configuration
        let config = DistributedConfig {
            chunks: 3,
            balance: true,
            strategy: DistributionStrategy::RowWise,
            backend: DistributedBackend::Threaded,
        };

        // Create a distributed array
        let dist_array = DistributedNdarray::from_array(&array, config);

        // Check that the array was split correctly
        assert_eq!(dist_array.num_chunks(), 3);
        assert_eq!(dist_array.shape(), &[10, 5]);

        // Convert back to a regular array
        let result = dist_array.to_array().unwrap();

        // The gathered array's dimension type differs from `array`'s, so the two
        // cannot be compared with `==`; compare shapes instead.
        assert_eq!(result.shape(), array.shape());
    }

    /// Example: Create and use a GPU array.
    #[test]
    fn example_gpu_array() {
        // Create a regular array
        let array = Array2::<f64>::ones((10, 5));

        // Create a GPU array configuration
        let config = GPUConfig {
            backend: GPUBackend::CUDA,
            device_id: 0,
            async_ops: true,
            mixed_precision: false,
            memory_fraction: 0.9,
        };

        // Create a GPU array (`array` is not used afterwards, so it is moved in)
        let gpu_array = GPUNdarray::new(array, config);

        // Check that the array was created correctly
        assert_eq!(gpu_array.shape(), &[10, 5]);
        assert!(gpu_array.is_on_gpu());

        // Get device information
        let info = gpu_array.device_info();
        assert_eq!(info.get("backend").unwrap(), "CUDA");

        // Test box_clone for GPU array
        let gpu_box: Box<dyn ArrayProtocol> = Box::new(gpu_array);
        let gpu_clone = gpu_box.clone();

        // Check the cloned GPU array
        assert_eq!(gpu_clone.shape(), &[10, 5]);
    }

    /// Example: Create and use a JIT-enabled array.
    #[test]
    fn example_jit_array() {
        // Initialize the JIT manager
        init();

        // Create a regular array
        let array = Array2::<f64>::ones((10, 5));
        let wrapped = NdarrayWrapper::new(array);

        // Create a JIT-enabled array
        let jitarray: JITEnabledArray<f64, NdarrayWrapper<f64, ndarray::Ix2>> =
            JITEnabledArray::new(wrapped);

        // Check if JIT is supported
        assert!(jitarray.supports_jit());

        // Compile a function
        let expression = "x + y";
        let jit_function = jitarray.compile(expression).unwrap();

        // Check the function's properties
        assert_eq!(jit_function.source(), expression);

        // Get JIT information
        let info = jitarray.jit_info();
        assert_eq!(info.get("supports_jit").unwrap(), "true");

        // Test box_clone for JIT-enabled array
        let jit_box: Box<dyn ArrayProtocol> = Box::new(jitarray);
        let jit_clone = jit_box.clone();

        // Check the cloned JIT array
        assert_eq!(jit_clone.shape(), &[10, 5]);
    }

    /// Example: Test cloning Box<dyn ArrayProtocol>
    #[test]
    fn example_cloning_array_protocol_objects() {
        // Create a GPU array with box_clone support
        let array = Array2::<f64>::ones((10, 5));
        let config = GPUConfig::default();
        let gpu_array = GPUNdarray::new(array.clone(), config);

        // Box the array as ArrayProtocol and clone it
        let boxed: Box<dyn ArrayProtocol> = Box::new(gpu_array);
        let cloned = boxed.clone();

        // Verify the clone works correctly
        assert_eq!(cloned.shape(), &[10, 5]);

        // Create a distributed array and test box_clone
        let config = DistributedConfig {
            chunks: 3,
            balance: true,
            strategy: DistributionStrategy::RowWise,
            backend: DistributedBackend::Threaded,
        };
        let dist_array = DistributedNdarray::from_array(&array, config);

        // Box the array as ArrayProtocol and clone it
        let boxed: Box<dyn ArrayProtocol> = Box::new(dist_array);
        let cloned = boxed.clone();

        // Verify the clone works correctly
        assert_eq!(cloned.shape(), &[10, 5]);
    }

    /// Example: Interoperability between different array types
    #[test]
    fn example_array_interoperability() {
        // Initialize the system
        init();

        // Create arrays of different types
        let cpu_array = Array2::<f64>::ones((5, 5));

        // Create a GPU array
        let gpu_config = GPUConfig {
            backend: GPUBackend::CUDA,
            device_id: 0,
            async_ops: false,
            mixed_precision: false,
            memory_fraction: 0.9,
        };
        let gpu_array = GPUNdarray::new(cpu_array.clone(), gpu_config);

        // Create a distributed array
        let dist_config = DistributedConfig {
            chunks: 2,
            balance: true,
            strategy: DistributionStrategy::RowWise,
            backend: DistributedBackend::Threaded,
        };
        let dist_array = DistributedNdarray::from_array(&cpu_array, dist_config);

        // Simple test of interoperability: convert both to Box<dyn ArrayProtocol>
        let gpu_wrapper: Box<dyn ArrayProtocol> = Box::new(gpu_array);
        let dist_wrapper: Box<dyn ArrayProtocol> = Box::new(dist_array);

        // Verify the clones work correctly
        let gpu_clone = gpu_wrapper.clone();
        let dist_clone = dist_wrapper.clone();

        assert_eq!(gpu_clone.shape(), &[5, 5]);
        assert_eq!(dist_clone.shape(), &[5, 5]);
    }

    /// Example: Advanced usage with custom array type
    #[test]
    fn example_custom_array_type() {
        // Define a custom array type
        struct MyCustomArray<T> {
            data: Vec<T>,
            shape: Vec<usize>,
        }

        impl<T: Clone + 'static> MyCustomArray<T> {
            fn new(data: Vec<T>, shape: Vec<usize>) -> Self {
                Self { data, shape }
            }
        }

        // Implement ArrayProtocol for the custom array type
        impl<T: Clone + Send + Sync + 'static> ArrayProtocol for MyCustomArray<T> {
            fn array_function(
                &self,
                func: &ArrayFunction,
                _types: &[TypeId],
                _args: &[Box<dyn Any>],
                _kwargs: &HashMap<String, Box<dyn Any>>,
            ) -> Result<Box<dyn Any>, NotImplemented> {
                if func.name != "scirs2::example::custom_sum" {
                    return Err(NotImplemented);
                }

                // Recover the concrete element type with a safe `Any` downcast
                // (`T: 'static`, so `Vec<T>: Any`) instead of casting raw pointers.
                let data: &dyn Any = &self.data;
                if let Some(values) = data.downcast_ref::<Vec<f64>>() {
                    Ok(Box::new(values.iter().sum::<f64>()))
                } else if let Some(values) = data.downcast_ref::<Vec<f32>>() {
                    Ok(Box::new(values.iter().sum::<f32>()))
                } else {
                    Err(NotImplemented)
                }
            }

            fn as_any(&self) -> &dyn Any {
                self
            }

            fn shape(&self) -> &[usize] {
                &self.shape
            }

            fn box_clone(&self) -> Box<dyn ArrayProtocol> {
                Box::new(MyCustomArray {
                    data: self.data.clone(),
                    shape: self.shape.clone(),
                })
            }
        }

        // Create an instance of the custom array type
        let custom_array = MyCustomArray::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);

        // Test box_clone functionality
        let boxed: Box<dyn ArrayProtocol> = Box::new(custom_array);
        let cloned = boxed.clone();

        // Verify the clone has the correct shape
        assert_eq!(cloned.shape(), &[2, 2]);

        // Create an ArrayFunction for testing
        let func = ArrayFunction {
            name: "scirs2::example::custom_sum",
            implementation: Arc::new(move |_args, _kwargs| {
                // Dummy implementation
                Ok(Box::new(42.0) as Box<dyn Any>)
            }),
        };

        // Test array_function directly
        let result = cloned.array_function(
            &func,
            &[std::any::TypeId::of::<f64>()],
            &[],
            &HashMap::new(),
        );

        // 1 + 2 + 3 + 4 = 10; fail loudly if dispatch was rejected or the
        // result has an unexpected type.
        let Ok(value) = result else {
            panic!("MyCustomArray should handle scirs2::example::custom_sum");
        };
        let sum = *value
            .downcast_ref::<f64>()
            .expect("custom_sum over f64 data should return an f64");
        assert_eq!(sum, 10.0);
    }
}
1654/// Make `Box<dyn JITFunction>` cloneable via clone_box
1655impl Clone for Box<dyn JITFunction> {
1656    fn clone(&self) -> Self {
1657        self.clone_box()
1658    }
1659}