1use std::any::{Any, TypeId};
29use std::collections::HashMap;
30use std::fmt::{Debug, Display};
31use std::marker::PhantomData;
32use std::sync::{Arc, LazyLock, RwLock};
33use std::time::{Duration, Instant};
34
35use crate::error::{CoreError, CoreResult, ErrorContext};
36
37mod distributed_impl;
39mod gpu_impl;
40#[cfg(feature = "array_protocol_wgpu")]
41pub mod gpu_ndarray;
42#[cfg(feature = "array_protocol_wgpu")]
43mod gpu_ndarray_shaders;
44mod jit_impl;
45mod operations;
46
47pub use crate::array_function_dispatch;
49
50pub mod auto_device;
52pub mod distributed_training;
53pub mod grad;
54pub mod mixed_precision;
55pub mod ml_ops;
56pub mod neural;
57#[cfg(feature = "serialization")]
58pub mod serialization;
59pub mod training;
60
/// Protocol implemented by array types that take part in function dispatch.
///
/// Implementors must be `Any + Send + Sync` so they can be stored behind
/// `Box<dyn ArrayProtocol>` and downcast to their concrete type.
pub trait ArrayProtocol: Any + Send + Sync {
    /// Asks this array to execute `func`.
    ///
    /// * `func` - the function being dispatched (identified by its `name`).
    /// * `types` - type identifiers passed along by the caller.
    /// * `args` - positional arguments as type-erased boxes.
    /// * `kwargs` - keyword arguments as type-erased boxes.
    ///
    /// Returns the type-erased result, or `NotImplemented` when this
    /// implementor does not handle the request.
    fn array_function(
        &self,
        func: &ArrayFunction,
        types: &[TypeId],
        args: &[Box<dyn Any>],
        kwargs: &HashMap<String, Box<dyn Any>>,
    ) -> Result<Box<dyn Any>, NotImplemented>;

    /// Returns `self` as `&dyn Any` so callers can downcast to the concrete type.
    #[must_use]
    fn as_any(&self) -> &dyn Any;

    /// Returns the shape of the array; the default implementation returns an
    /// empty slice.
    #[must_use]
    fn shape(&self) -> &[usize] {
        &[]
    }

    /// Returns the `TypeId` of the element type; the default implementation
    /// reports `f64`.
    #[must_use]
    fn dtype(&self) -> TypeId {
        TypeId::of::<f64>()
    }

    /// Clones this value into a fresh boxed trait object.
    #[must_use]
    fn box_clone(&self) -> Box<dyn ArrayProtocol>;
}
106
107impl Clone for Box<dyn ArrayProtocol> {
109 fn clone(&self) -> Self {
110 self.box_clone()
111 }
112}
113
/// Marker error returned by `array_function` implementations that do not
/// handle the requested function.
#[derive(Debug, Clone, Copy)]
pub struct NotImplemented;

impl Display for NotImplemented {
    /// Writes the fixed text `NotImplemented`.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        out.write_str("NotImplemented")
    }
}
131
/// Signature of a type-erased array function implementation: it receives the
/// positional and keyword arguments and returns a boxed result or a
/// `CoreError`. The callable must be `Send + Sync`.
pub type ArrayFunctionImpl = dyn Fn(&[Box<dyn Any>], &HashMap<String, Box<dyn Any>>) -> CoreResult<Box<dyn Any>>
    + Send
    + Sync;
136
/// A named, dispatchable array function.
///
/// Equality and hashing (implemented below in this file) consider only
/// `name`; the implementation is shared through an `Arc`, so cloning is cheap.
#[derive(Clone)]
pub struct ArrayFunction {
    /// Name used to identify the function (also the registry key).
    pub name: &'static str,

    /// Callable invoked when dispatch falls back to the function itself.
    pub implementation: Arc<ArrayFunctionImpl>,
}
146
147impl Debug for ArrayFunction {
148 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149 f.debug_struct("ArrayFunction")
150 .field("name", &self.name)
151 .finish_non_exhaustive()
152 }
153}
154
155impl PartialEq for ArrayFunction {
156 fn eq(&self, other: &Self) -> bool {
157 self.name == other.name
158 }
159}
160
161impl Eq for ArrayFunction {}
162
163impl std::hash::Hash for ArrayFunction {
164 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
165 self.name.hash(state);
166 }
167}
168
169impl ArrayFunction {
170 #[must_use]
172 pub fn new(name: &'static str) -> Self {
173 Self {
174 name,
175 implementation: Arc::new(|_args, _kwargs| {
177 Err(CoreError::NotImplementedError(ErrorContext::new(
178 "Function not implemented".to_string(),
179 )))
180 }),
181 }
182 }
183}
184
/// One cached dispatch decision stored in `ArrayFunctionRegistry`.
#[derive(Debug, Clone)]
pub struct DispatchCacheEntry {
    /// Argument type identifiers this entry was recorded for.
    #[allow(dead_code)]
    type_signature: Vec<TypeId>,
    /// Type identifier of the implementation recorded for that signature.
    #[allow(dead_code)]
    preferred_impl_type: TypeId,
    /// Moment the entry was inserted; compared against the cache TTL.
    timestamp: Instant,
    /// Number of recorded hits; used to pick eviction victims.
    hit_count: u64,
}
199
/// Registry of array functions together with a bounded, time-limited cache
/// of dispatch decisions.
#[derive(Debug)]
pub struct ArrayFunctionRegistry {
    /// Registered functions keyed by name.
    functions: HashMap<&'static str, ArrayFunction>,
    /// Cached dispatch entries keyed by (function name, argument type ids).
    dispatch_cache: HashMap<(&'static str, Vec<TypeId>), DispatchCacheEntry>,
    /// Size at which inserting a new cache entry triggers a cleanup.
    max_cache_size: usize,
    /// Age after which a cache entry is treated as expired.
    cache_ttl: Duration,
}
212
213impl Default for ArrayFunctionRegistry {
214 fn default() -> Self {
215 Self {
216 functions: HashMap::new(),
217 dispatch_cache: HashMap::new(),
218 max_cache_size: 1000, cache_ttl: Duration::from_secs(300), }
221 }
222}
223
224impl ArrayFunctionRegistry {
225 #[must_use]
227 pub fn global() -> &'static RwLock<Self> {
228 static REGISTRY: LazyLock<RwLock<ArrayFunctionRegistry>> =
229 LazyLock::new(|| RwLock::new(ArrayFunctionRegistry::default()));
230 ®ISTRY
231 }
232
233 pub fn register(&mut self, func: ArrayFunction) {
235 self.functions.insert(func.name, func);
236 }
237
238 #[must_use]
240 #[allow(dead_code)]
241 pub fn get(&self, name: &str) -> Option<&ArrayFunction> {
242 self.functions.get(name)
243 }
244
245 #[must_use]
247 pub fn all_functions(&self) -> Vec<&ArrayFunction> {
248 self.functions.values().collect()
249 }
250
251 #[must_use]
253 pub fn get_cached_dispatch(
254 &self,
255 funcname: &'static str,
256 types: &[TypeId],
257 ) -> Option<&DispatchCacheEntry> {
258 let key = (funcname, types.to_vec());
259 if let Some(entry) = self.dispatch_cache.get(&key) {
260 if entry.timestamp.elapsed() < self.cache_ttl {
262 return Some(entry);
263 }
264 }
265 None
266 }
267
268 pub fn cache_dispatch(
270 &mut self,
271 funcname: &'static str,
272 types: Vec<TypeId>,
273 impl_type: TypeId,
274 ) {
275 if self.dispatch_cache.len() >= self.max_cache_size {
277 self.cleanup_cache();
278 }
279
280 let key = (funcname, types.clone());
281 let entry = DispatchCacheEntry {
282 type_signature: types,
283 preferred_impl_type: impl_type,
284 timestamp: Instant::now(),
285 hit_count: 0,
286 };
287 self.dispatch_cache.insert(key, entry);
288 }
289
290 pub fn update_cache_hit(&mut self, funcname: &'static str, types: &[TypeId]) {
292 let key = (funcname, types.to_vec());
293 if let Some(entry) = self.dispatch_cache.get_mut(&key) {
294 entry.hit_count += 1;
295 }
296 }
297
298 fn cleanup_cache(&mut self) {
300 let now = Instant::now();
301 self.dispatch_cache
302 .retain(|_, entry| now.duration_since(entry.timestamp) < self.cache_ttl);
303
304 if self.dispatch_cache.len() >= self.max_cache_size {
306 let mut entries: Vec<_> = self
307 .dispatch_cache
308 .iter()
309 .map(|(k, v)| (k.clone(), v.hit_count))
310 .collect();
311 entries.sort_by_key(|(_, hit_count)| *hit_count);
312
313 let to_remove = self.dispatch_cache.len() / 4;
315 let keys_to_remove: Vec<_> = entries
316 .iter()
317 .take(to_remove)
318 .map(|(key, _)| key.clone())
319 .collect();
320 for key in keys_to_remove {
321 self.dispatch_cache.remove(&key);
322 }
323 }
324 }
325
326 #[must_use]
328 pub fn cache_stats(&self) -> HashMap<String, u64> {
329 let mut stats = HashMap::new();
330 stats.insert("cache_size".to_string(), self.dispatch_cache.len() as u64);
331 stats.insert("max_cache_size".to_string(), self.max_cache_size as u64);
332
333 let total_hits: u64 = self.dispatch_cache.values().map(|e| e.hit_count).sum();
334 stats.insert("total_hits".to_string(), total_hits);
335
336 stats
337 }
338}
339
340#[allow(dead_code)]
345pub fn get_implementing_args(args: &[Box<dyn Any>]) -> Vec<(TypeId, &dyn ArrayProtocol)> {
346 if args.is_empty() {
347 return Vec::new();
348 }
349
350 let mut implementing_args = Vec::with_capacity(args.len());
352
353 for arg in args {
354 if let Some(array_protocol_obj) = arg.downcast_ref::<Box<dyn ArrayProtocol>>() {
355 let type_id = (**array_protocol_obj).type_id();
356 implementing_args.push((type_id, &**array_protocol_obj));
357 }
358 }
359
360 implementing_args.sort_by_key(|&_type_id_| {
363 use std::hash::{Hash, Hasher};
365 let mut hasher = std::collections::hash_map::DefaultHasher::new();
366 std::any::TypeId::of::<i32>().hash(&mut hasher);
367 hasher.finish()
368 });
369
370 implementing_args
371}
372
373#[allow(dead_code)]
384pub fn array_function_dispatch(
385 func: &ArrayFunction,
386 args: &[Box<dyn Any>],
387 kwargs: &HashMap<String, Box<dyn Any>>,
388) -> CoreResult<Box<dyn Any>> {
389 if args.is_empty() {
391 return (func.implementation)(args, kwargs);
392 }
393
394 let implementing_args = get_implementing_args(args);
396
397 if implementing_args.is_empty() {
398 return (func.implementation)(args, kwargs);
400 }
401
402 if implementing_args.len() == 1 {
404 let (type_id, array_protocol_obj) = implementing_args[0];
405 let types = [type_id];
406 match array_protocol_obj.array_function(func, &types, args, kwargs) {
407 Ok(result) => return Ok(result),
408 Err(NotImplemented) => {
409 return Err(CoreError::DispatchError(ErrorContext::new(format!(
410 "No implementation found for {} with type {:?}",
411 func.name, type_id
412 ))));
413 }
414 }
415 }
416
417 let mut unique_types = Vec::with_capacity(implementing_args.len());
419 let mut seen_types = std::collections::HashSet::with_capacity(implementing_args.len());
420
421 for &(type_id, _) in &implementing_args {
422 if seen_types.insert(type_id) {
423 unique_types.push(type_id);
424 }
425 }
426
427 for (_, array_protocol_obj) in implementing_args {
429 if let Ok(result) = array_protocol_obj.array_function(func, &unique_types, args, kwargs) {
430 return Ok(result);
431 }
432 }
433
434 Err(CoreError::DispatchError(ErrorContext::new(format!(
436 "No implementation found for {} with {} argument types: {:?}",
437 func.name,
438 unique_types.len(),
439 unique_types
440 ))))
441}
442
/// Pairs a callable with the name under which it is registered as an array
/// function.
pub struct ArrayFunctionDecorator<F> {
    /// The wrapped callable, handed back by `register`.
    function: F,
    /// Name used as the registry key.
    name: &'static str,
}
450
451impl<F> ArrayFunctionDecorator<F>
452where
453 F: Send + Sync + 'static,
454{
455 #[must_use]
457 pub fn new(function: F, name: &'static str) -> Self {
458 Self { function, name }
459 }
460
461 pub fn register(self) -> F {
463 let implementation = Arc::new(
464 move |_args: &[Box<dyn Any>], _kwargs: &HashMap<String, Box<dyn Any>>| {
465 Err(CoreError::NotImplementedError(ErrorContext::new(
470 "ArrayFunctionDecorator: Type conversion in array_function_dispatch is not implemented yet".to_string()
471 )))
472 },
473 );
474
475 let func = ArrayFunction {
476 name: self.name,
477 implementation,
478 };
479
480 let registry = ArrayFunctionRegistry::global();
482 if let Ok(mut registry) = registry.write() {
483 registry.register(func);
484 } else {
485 eprintln!("Warning: Failed to acquire write lock on ArrayFunctionRegistry, skipping function registration");
486 }
488
489 self.function
490 }
491}
492
/// Extension of `ArrayProtocol` for arrays that can be moved between a GPU
/// and the CPU.
pub trait GPUArray: ArrayProtocol {
    /// Returns a GPU-side version of this array as a boxed `GPUArray`, or an
    /// error.
    fn to_gpu(&self) -> CoreResult<Box<dyn GPUArray>>;

    /// Returns a CPU-side version of this array as a boxed `ArrayProtocol`,
    /// or an error.
    fn to_cpu(&self) -> CoreResult<Box<dyn ArrayProtocol>>;

    /// Reports whether the array currently lives on a GPU.
    #[must_use]
    fn is_on_gpu(&self) -> bool;

    /// Returns implementation-defined key/value information about the device.
    #[must_use]
    fn device_info(&self) -> HashMap<String, String>;
}
509
/// Extension of `ArrayProtocol` for arrays whose data is split into chunks.
pub trait DistributedArray: ArrayProtocol {
    /// Returns implementation-defined key/value information about how the
    /// array is distributed.
    #[must_use]
    fn distribution_info(&self) -> HashMap<String, String>;

    /// Returns the array as a single boxed `ArrayProtocol` object, or an error.
    fn gather(&self) -> CoreResult<Box<dyn ArrayProtocol>>;

    /// Returns a distributed array for the requested number of `chunks`, or
    /// an error.
    fn scatter(&self, chunks: usize) -> CoreResult<Box<dyn DistributedArray>>;

    /// Reports whether the array is distributed.
    #[must_use]
    fn is_distributed(&self) -> bool;
}
526
/// Extension of `ArrayProtocol` for arrays that can compile expressions into
/// `JITFunction` objects.
pub trait JITArray: ArrayProtocol {
    /// Compiles `expression` into a boxed `JITFunction`, or returns an error.
    fn compile(&self, expression: &str) -> CoreResult<Box<dyn JITFunction>>;

    /// Reports whether JIT compilation is supported by this array.
    #[must_use]
    fn supports_jit(&self) -> bool;

    /// Returns implementation-defined key/value information about JIT support.
    #[must_use]
    fn jit_info(&self) -> HashMap<String, String>;
}
540
/// A compiled function that can be evaluated on type-erased arguments.
pub trait JITFunction: Send + Sync {
    /// Evaluates the function on `args`, returning the type-erased result or
    /// an error.
    fn evaluate(&self, args: &[Box<dyn Any>]) -> CoreResult<Box<dyn Any>>;

    /// Returns the source text of the function.
    #[must_use]
    fn source(&self) -> String;

    /// Returns implementation-defined key/value information about the
    /// compilation.
    #[must_use]
    fn compile_info(&self) -> HashMap<String, String>;

    /// Clones this function into a fresh boxed trait object.
    #[must_use]
    fn clone_box(&self) -> Box<dyn JITFunction>;
}
558
/// Factory that builds `JITFunction` objects for particular array types.
pub trait JITFunctionFactory: Send + Sync {
    /// Creates a JIT function for `expression`, targeting the array type
    /// identified by `array_typeid`, or returns an error.
    fn create_jit_function(
        &self,
        expression: &str,
        array_typeid: TypeId,
    ) -> CoreResult<Box<dyn JITFunction>>;

    /// Reports whether this factory handles the array type identified by
    /// `array_typeid`.
    #[must_use]
    fn supports_array_type(&self, array_typeid: TypeId) -> bool;
}
572
/// Ordered collection of registered `JITFunctionFactory` objects.
#[derive(Default)]
pub struct JITFactoryRegistry {
    /// Factories in registration order; lookups return the first match.
    factories: Vec<Box<dyn JITFunctionFactory>>,
}
578
579impl std::fmt::Debug for JITFactoryRegistry {
580 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
581 write!(
582 f,
583 "JITFactoryRegistry {{ factories: {} }}",
584 self.factories.len()
585 )
586 }
587}
588
589impl JITFactoryRegistry {
590 #[must_use]
592 pub fn global() -> &'static RwLock<Self> {
593 static REGISTRY: LazyLock<RwLock<JITFactoryRegistry>> = LazyLock::new(|| {
594 RwLock::new(JITFactoryRegistry {
595 factories: Vec::new(),
596 })
597 });
598 ®ISTRY
599 }
600
601 pub fn register(&mut self, factory: Box<dyn JITFunctionFactory>) {
603 self.factories.push(factory);
604 }
605
606 #[must_use]
608 pub fn get_factory_for_array_type(
609 &self,
610 array_typeid: TypeId,
611 ) -> Option<&dyn JITFunctionFactory> {
612 for factory in &self.factories {
613 if factory.supports_array_type(array_typeid) {
614 return Some(&**factory);
615 }
616 }
617 None
618 }
619}
620
/// Wrapper that lets an `ndarray` array take part in the array protocol.
#[derive(Debug, Clone)]
pub struct NdarrayWrapper<T, D: crate::ndarray::Dimension> {
    /// The wrapped array.
    array: crate::ndarray::Array<T, D>,
    /// Zero-sized marker for the element and dimension types.
    phantom: PhantomData<(T, D)>,
}
627
628impl<T, D> NdarrayWrapper<T, D>
629where
630 T: Clone + 'static,
631 D: crate::ndarray::Dimension + 'static,
632{
633 #[must_use]
635 pub fn new(array: crate::ndarray::Array<T, D>) -> Self {
636 Self {
637 array,
638 phantom: PhantomData,
639 }
640 }
641
642 #[must_use]
644 pub const fn as_array(&self) -> &crate::ndarray::Array<T, D> {
645 &self.array
646 }
647
648 #[must_use]
650 pub fn into_array(self) -> crate::ndarray::Array<T, D> {
651 self.array
652 }
653
654 pub fn array_2(&mut self, newarray: crate::ndarray::Array<T, D>) {
656 self.array = newarray;
657 }
658}
659
660impl<T, D> ArrayProtocol for NdarrayWrapper<T, D>
661where
662 T: Clone + Send + Sync + 'static,
663 D: crate::ndarray::Dimension + Send + Sync + 'static,
664{
665 fn array_function(
666 &self,
667 func: &ArrayFunction,
668 _types: &[TypeId],
669 args: &[Box<dyn Any>],
670 kwargs: &HashMap<String, Box<dyn Any>>,
671 ) -> Result<Box<dyn Any>, NotImplemented> {
672 match func.name {
673 "scirs2::array_protocol::operations::add" => {
674 if args.len() < 2 {
676 return Err(NotImplemented);
677 }
678
679 if let Some(other) = args[1].downcast_ref::<NdarrayWrapper<T, D>>() {
680 if let (Some(a), Some(b)) = (
681 self.as_any().downcast_ref::<NdarrayWrapper<T, D>>(),
682 other.as_any().downcast_ref::<NdarrayWrapper<T, D>>(),
683 ) {
684 if TypeId::of::<T>() == TypeId::of::<f64>() {
686 let a_f64 =
687 unsafe { &*(a as *const _ as *const NdarrayWrapper<f64, D>) };
688 let b_f64 =
689 unsafe { &*(b as *const _ as *const NdarrayWrapper<f64, D>) };
690 let result = a_f64.as_array() + b_f64.as_array();
691 return Ok(Box::new(NdarrayWrapper::new(result)));
692 } else if TypeId::of::<T>() == TypeId::of::<f32>() {
693 let a_f32 =
694 unsafe { &*(a as *const _ as *const NdarrayWrapper<f32, D>) };
695 let b_f32 =
696 unsafe { &*(b as *const _ as *const NdarrayWrapper<f32, D>) };
697 let result = a_f32.as_array() + b_f32.as_array();
698 return Ok(Box::new(NdarrayWrapper::new(result)));
699 }
700 }
701 }
702 Err(NotImplemented)
703 }
704 "scirs2::array_protocol::operations::matmul" => {
705 if args.len() < 2 {
707 return Err(NotImplemented);
708 }
709
710 if TypeId::of::<D>() != TypeId::of::<crate::ndarray::Ix2>() {
713 return Err(NotImplemented);
714 }
715
716 if let Some(other) = args[1].downcast_ref::<NdarrayWrapper<T, D>>() {
717 if TypeId::of::<T>() == TypeId::of::<f64>() {
722 let a_f64 = unsafe {
724 &*(self as *const _ as *const NdarrayWrapper<f64, crate::ndarray::Ix2>)
725 };
726 let b_f64 = unsafe {
727 &*(other as *const _ as *const NdarrayWrapper<f64, crate::ndarray::Ix2>)
728 };
729
730 let ashape = a_f64.as_array().shape();
732 let bshape = b_f64.as_array().shape();
733
734 if ashape.len() != 2 || bshape.len() != 2 || ashape[1] != bshape[0] {
735 return Err(NotImplemented);
736 }
737
738 let result = a_f64.as_array().dot(b_f64.as_array());
741 return Ok(Box::new(NdarrayWrapper::new(result)));
742 }
743 else if TypeId::of::<T>() == TypeId::of::<f32>() {
745 let a_f32 = unsafe {
747 &*(self as *const _ as *const NdarrayWrapper<f32, crate::ndarray::Ix2>)
748 };
749 let b_f32 = unsafe {
750 &*(other as *const _ as *const NdarrayWrapper<f32, crate::ndarray::Ix2>)
751 };
752
753 let ashape = a_f32.as_array().shape();
755 let bshape = b_f32.as_array().shape();
756
757 if ashape.len() != 2 || bshape.len() != 2 || ashape[1] != bshape[0] {
758 return Err(NotImplemented);
759 }
760
761 let result = a_f32.as_array().dot(b_f32.as_array());
764 return Ok(Box::new(NdarrayWrapper::new(result)));
765 }
766 }
767 Err(NotImplemented)
769 }
770 "scirs2::array_protocol::operations::transpose" => {
771 if TypeId::of::<T>() == TypeId::of::<f64>() {
773 let a_f64 = unsafe { &*(self as *const _ as *const NdarrayWrapper<f64, D>) };
774 let result = a_f64.as_array().t().to_owned();
775 return Ok(Box::new(NdarrayWrapper::new(result)));
776 } else if TypeId::of::<T>() == TypeId::of::<f32>() {
777 let a_f32 = unsafe { &*(self as *const _ as *const NdarrayWrapper<f32, D>) };
778 let result = a_f32.as_array().t().to_owned();
779 return Ok(Box::new(NdarrayWrapper::new(result)));
780 }
781 Err(NotImplemented)
782 }
783 "scirs2::array_protocol::operations::sum" => {
784 let axis_ref = kwargs.get("axis").and_then(|a| a.downcast_ref::<usize>());
786
787 if TypeId::of::<T>() == TypeId::of::<f64>() {
788 let a_f64 = unsafe { &*(self as *const _ as *const NdarrayWrapper<f64, D>) };
789 match axis_ref {
790 Some(&_ax) => {
791 let result = a_f64.as_array().sum();
794 return Ok(Box::new(result));
795 }
796 None => {
797 let result = a_f64.as_array().sum();
798 return Ok(Box::new(result));
799 }
800 }
801 } else if TypeId::of::<T>() == TypeId::of::<f32>() {
802 let a_f32 = unsafe { &*(self as *const _ as *const NdarrayWrapper<f32, D>) };
803 match axis_ref {
804 Some(&_ax) => {
805 let result = a_f32.as_array().sum();
808 return Ok(Box::new(result));
809 }
810 None => {
811 let result = a_f32.as_array().sum();
812 return Ok(Box::new(result));
813 }
814 }
815 }
816 Err(NotImplemented)
817 }
818 "scirs2::array_protocol::operations::reshape" => {
819 if let Some(shape) = kwargs
821 .get("shape")
822 .and_then(|s| s.downcast_ref::<Vec<usize>>())
823 {
824 if TypeId::of::<T>() == TypeId::of::<f64>() {
825 let a_f64 =
826 unsafe { &*(self as *const _ as *const NdarrayWrapper<f64, D>) };
827 match a_f64
828 .as_array()
829 .clone()
830 .into_shape_with_order(shape.clone())
831 {
832 Ok(result) => return Ok(Box::new(NdarrayWrapper::new(result))),
833 Err(_) => return Err(NotImplemented),
834 }
835 } else if TypeId::of::<T>() == TypeId::of::<f32>() {
836 let a_f32 =
837 unsafe { &*(self as *const _ as *const NdarrayWrapper<f32, D>) };
838 match a_f32
839 .as_array()
840 .clone()
841 .into_shape_with_order(shape.clone())
842 {
843 Ok(result) => return Ok(Box::new(NdarrayWrapper::new(result))),
844 Err(_) => return Err(NotImplemented),
845 }
846 }
847 }
848 Err(NotImplemented)
849 }
850 _ => Err(NotImplemented),
851 }
852 }
853
854 fn as_any(&self) -> &dyn Any {
855 self
856 }
857
858 fn shape(&self) -> &[usize] {
859 self.array.shape()
860 }
861
862 fn dtype(&self) -> TypeId {
863 TypeId::of::<T>()
864 }
865
866 fn box_clone(&self) -> Box<dyn ArrayProtocol> {
867 Box::new(self.clone())
868 }
869}
870
/// Minimal distributed array used for testing: a list of chunks plus a shape.
#[derive(Debug, Clone)]
pub struct MockDistributedArray<T: Clone + 'static> {
    /// The chunks making up the array.
    chunks: Vec<T>,
    /// Shape reported through `ArrayProtocol::shape`.
    shape: Vec<usize>,
}

impl<T: Clone + Send + Sync + 'static> MockDistributedArray<T> {
    /// Builds a mock array from its `chunks` and `shape`; the two are stored
    /// as given and not validated against each other.
    #[must_use]
    pub fn new(chunks: Vec<T>, shape: Vec<usize>) -> Self {
        MockDistributedArray { shape, chunks }
    }
}
887
888impl<T: Clone + Send + Sync + 'static> ArrayProtocol for MockDistributedArray<T> {
889 fn array_function(
890 &self,
891 func: &ArrayFunction,
892 _types: &[TypeId],
893 _args: &[Box<dyn Any>],
894 _kwargs: &HashMap<String, Box<dyn Any>>,
895 ) -> Result<Box<dyn Any>, NotImplemented> {
896 match func.name {
897 "scirs2::mean" => {
898 let result = T::clone(&self.chunks[0]);
903 Ok(Box::new(result))
904 }
905 _ => Err(NotImplemented),
906 }
907 }
908
909 fn as_any(&self) -> &dyn Any {
910 self
911 }
912
913 fn shape(&self) -> &[usize] {
914 &self.shape
915 }
916
917 fn box_clone(&self) -> Box<dyn ArrayProtocol> {
918 Box::new(self.clone())
919 }
920}
921
922impl<T: Clone + Send + Sync + 'static> DistributedArray for MockDistributedArray<T> {
923 fn distribution_info(&self) -> HashMap<String, String> {
924 let mut info = HashMap::new();
925 info.insert("type".to_string(), "mock_distributed".to_string());
926 info.insert("chunks".to_string(), self.chunks.len().to_string());
927 info
928 }
929
930 fn gather(&self) -> CoreResult<Box<dyn ArrayProtocol>> {
931 Ok(Box::new(self.clone()) as Box<dyn ArrayProtocol>)
934 }
935
936 fn scatter(&self, _numchunks: usize) -> CoreResult<Box<dyn DistributedArray>> {
937 Ok(Box::new(self.clone()) as Box<dyn DistributedArray>)
940 }
941
942 fn is_distributed(&self) -> bool {
943 true
944 }
945}
946
/// Minimal GPU array used for testing: flat data, a shape and a device name.
#[derive(Debug, Clone)]
pub struct MockGPUArray<T: Clone + 'static> {
    /// Element data.
    data: Vec<T>,
    /// Shape reported through `ArrayProtocol::shape`.
    shape: Vec<usize>,
    /// Device label reported by `device_info`.
    device: String,
}

impl<T: Clone + Send + Sync + 'static> MockGPUArray<T> {
    /// Builds a mock GPU array; the arguments are stored as given and not
    /// validated against each other.
    #[must_use]
    pub fn new(data: Vec<T>, shape: Vec<usize>, device: String) -> Self {
        MockGPUArray {
            device,
            shape,
            data,
        }
    }
}
966
967impl<T: Clone + Send + Sync + 'static> ArrayProtocol for MockGPUArray<T> {
968 fn array_function(
969 &self,
970 func: &ArrayFunction,
971 _types: &[TypeId],
972 _args: &[Box<dyn Any>],
973 _kwargs: &HashMap<String, Box<dyn Any>>,
974 ) -> Result<Box<dyn Any>, NotImplemented> {
975 match func.name {
976 "scirs2::matmul" => {
977 let result =
982 MockGPUArray::new(self.data.clone(), self.shape.clone(), self.device.clone());
983 Ok(Box::new(result))
984 }
985 _ => Err(NotImplemented),
986 }
987 }
988
989 fn as_any(&self) -> &dyn Any {
990 self
991 }
992
993 fn shape(&self) -> &[usize] {
994 &self.shape
995 }
996
997 fn box_clone(&self) -> Box<dyn ArrayProtocol> {
998 Box::new(self.clone())
999 }
1000}
1001
1002impl<T: Clone + Send + Sync + 'static> GPUArray for MockGPUArray<T> {
1003 fn to_gpu(&self) -> CoreResult<Box<dyn GPUArray>> {
1004 Ok(Box::new(self.clone()) as Box<dyn GPUArray>)
1006 }
1007
1008 fn to_cpu(&self) -> CoreResult<Box<dyn ArrayProtocol>> {
1009 Ok(Box::new(self.clone()) as Box<dyn ArrayProtocol>)
1012 }
1013
1014 fn is_on_gpu(&self) -> bool {
1015 true
1016 }
1017
1018 fn device_info(&self) -> HashMap<String, String> {
1019 let mut info = HashMap::new();
1020 info.insert("device".to_string(), self.device.clone());
1021 info.insert("type".to_string(), "mock_gpu".to_string());
1022 info
1023 }
1024}
1025
/// Pairs a callable with the name under which it is registered as an array
/// protocol function.
#[derive(Debug)]
pub struct ArrayProtocolFunction<F> {
    /// The wrapped callable.
    func: F,
    /// Name used as the registry key.
    name: &'static str,
}

impl<F> ArrayProtocolFunction<F> {
    /// Wraps `func` so it can be registered under `name`.
    #[must_use]
    pub fn new(func: F, name: &'static str) -> Self {
        ArrayProtocolFunction { name, func }
    }
}
1043
1044impl<F> ArrayProtocolFunction<F>
1045where
1046 F: Clone + Send + Sync + 'static,
1047{
1048 pub fn register(self) -> F {
1050 let implementation = Arc::new(
1051 move |_args: &[Box<dyn Any>], _kwargs: &HashMap<String, Box<dyn Any>>| {
1052 Err(CoreError::NotImplementedError(ErrorContext::new(
1057 "ArrayProtocolFunction: Implementation for array protocol functions is not complete".to_string()
1058 )))
1059 },
1060 );
1061
1062 let array_func = ArrayFunction {
1063 name: self.name,
1064 implementation,
1065 };
1066
1067 if let Ok(mut registry) = ArrayFunctionRegistry::global().write() {
1069 registry.register(array_func);
1070 } else {
1071 eprintln!("Warning: Failed to acquire write lock on ArrayFunctionRegistry during array protocol building, skipping function registration");
1072 }
1074
1075 self.func
1076 }
1077}
1078
/// Defines the function `$name` inside a block expression and evaluates to
/// that function item.
///
/// The macro takes a function definition (optional generic identifiers,
/// typed arguments, return type and body) followed by a comma and an
/// expression.
///
/// NOTE(review): the trailing `$funcname` expression is matched but never
/// used in the expansion, so nothing is registered under that name here.
/// Confirm whether that is intentional.
#[macro_export]
macro_rules! array_function_def {
    (fn $name:ident $(<$($gen:ident),*>)? ($($arg:ident : $arg_ty:ty),*) -> $ret:ty $body:block, $funcname:expr) => {
        {
            // Emit the function verbatim...
            fn $name $(<$($gen),*>)? ($($arg : $arg_ty),*) -> $ret $body

            // ...and make it the value of the block.
            $name
        }
    };
}
1132
1133pub use self::distributed_impl::{
1135 ArrayChunk, DistributedBackend, DistributedConfig, DistributedNdarray, DistributionStrategy,
1136};
1137
1138pub use self::gpu_impl::{
1140 kernels as gpu_kernels, GPUArrayBuilder, GPUBackend, GPUConfig, GPUNdarray,
1141};
1142
1143pub use self::jit_impl::{
1145 CraneliftFunctionFactory, JITBackend, JITConfig, JITEnabledArray, JITFunctionImpl, JITManager,
1146 LLVMFunctionFactory,
1147};
1148
1149pub use self::operations::{
1151 add, apply_elementwise, concatenate, inverse, matmul, multiply, reshape, subtract, sum, svd,
1152 transpose, OperationError,
1153};
1154
1155pub use self::ml_ops::{
1157 activation, batch_norm, conv2d, cross_entropy, dropout, max_pool2d, self_attention,
1158 ActivationFunc,
1159};
1160
1161#[allow(dead_code)]
1167pub fn init() {
1168 let mut jit_manager = JITManager::global().write().expect("Operation failed");
1170 jit_manager.initialize();
1171}
1172
/// Optional capability traits that refine `ArrayProtocol`.
pub mod traits {
    use super::*;

    /// Arrays that expose their memory layout.
    pub trait StridedArray: ArrayProtocol {
        /// Returns the strides of the array.
        #[must_use]
        fn strides(&self) -> Vec<usize>;

        /// Reports whether the array is contiguous.
        // NOTE(review): presumably row-major (C) contiguity, in contrast to
        // `is_fortran_contiguous`; confirm with implementors.
        #[must_use]
        fn is_contiguous(&self) -> bool;

        /// Reports whether the array is Fortran-contiguous.
        #[must_use]
        fn is_fortran_contiguous(&self) -> bool;
    }

    /// Arrays that can hand out views of themselves.
    pub trait ZeroCopyArray: ArrayProtocol {
        /// Returns a view of this array as a boxed `ZeroCopyArray`.
        #[must_use]
        fn view(&self) -> Box<dyn ZeroCopyArray>;

        /// Returns a view obtained through a mutable borrow of this array.
        #[must_use]
        fn view_mut(&mut self) -> Box<dyn ZeroCopyArray>;

        /// Reports whether this array is itself a view.
        #[must_use]
        fn is_view(&self) -> bool;
    }

    /// Arrays that support gradient computation.
    pub trait DifferentiableArray: ArrayProtocol {
        /// Computes gradients with respect to `variables`, returning a list
        /// of differentiable arrays.
        // NOTE(review): presumably one result per entry in `variables`; the
        // signature does not enforce it.
        fn gradient(
            &self,
            variables: &[Box<dyn DifferentiableArray>],
        ) -> Vec<Box<dyn DifferentiableArray>>;

        /// Sets whether this array requires gradients.
        fn set_requiresgrad(&mut self, requiresgrad: bool);

        /// Reports whether this array requires gradients.
        #[must_use]
        fn requiresgrad(&self) -> bool;

        /// Returns the gradient of this array, if there is one.
        #[must_use]
        fn grad(&self) -> Option<Box<dyn DifferentiableArray>>;
    }

    /// Arrays that can run operations asynchronously.
    pub trait AsyncArray: ArrayProtocol {
        /// Returns a future that resolves to the result of applying `op` to
        /// this array.
        fn async_op<F, R>(&self, op: F) -> impl std::future::Future<Output = CoreResult<R>>
        where
            F: FnOnce(&Self) -> CoreResult<R> + Send + 'static,
            R: Send + 'static;

        /// Reports whether asynchronous operations are supported.
        #[must_use]
        fn supports_async(&self) -> bool;
    }
}
1240
/// Unit tests for the registry, the mock arrays and boxed cloning.
#[cfg(test)]
mod tests {
    use super::*;

    /// A function registered in the global registry can be looked up by name.
    #[test]
    fn test_array_protocol_registry() {
        // Implementation that ignores its arguments and returns a constant.
        let implementation = Arc::new(
            move |_args: &[Box<dyn Any>], _kwargs: &HashMap<String, Box<dyn Any>>| {
                Ok(Box::new(42.0) as Box<dyn Any>)
            },
        );

        let func = ArrayFunction {
            name: "scirs2::test::test_func",
            implementation,
        };

        let registry = ArrayFunctionRegistry::global();
        {
            // Scope the write guard so it is released before reading.
            let mut reg = registry.write().expect("Operation failed");
            reg.register(func.clone());
        }

        {
            let reg = registry.read().expect("Operation failed");
            let registered_func = reg
                .get("scirs2::test::test_func")
                .expect("Operation failed");
            assert_eq!(registered_func.name, "scirs2::test::test_func");
        }
    }

    /// The mock distributed array reports its type tag and chunk count.
    #[test]
    fn test_mock_distributed_array() {
        let array = MockDistributedArray::new(vec![1.0, 2.0, 3.0], vec![3]);
        assert!(array.is_distributed());

        let info = array.distribution_info();
        assert_eq!(
            info.get("type").expect("Operation failed"),
            "mock_distributed"
        );
        assert_eq!(info.get("chunks").expect("Operation failed"), "3");
    }

    /// The mock GPU array reports its device label and type tag.
    #[test]
    fn test_mock_gpu_array() {
        let array = MockGPUArray::new(vec![1.0, 2.0, 3.0], vec![3], "cuda:0".to_string());
        assert!(array.is_on_gpu());

        let info = array.device_info();
        assert_eq!(info.get("device").expect("Operation failed"), "cuda:0");
        assert_eq!(info.get("type").expect("Operation failed"), "mock_gpu");
    }

    /// Cloning a `Box<dyn ArrayProtocol>` preserves the reported shape.
    #[test]
    fn test_box_clone() {
        // Ndarray-backed wrapper.
        let array = crate::ndarray::Array2::<f64>::ones((3, 3));
        let wrapped = NdarrayWrapper::new(array);
        let boxed: Box<dyn ArrayProtocol> = Box::new(wrapped);
        let cloned = boxed.clone();

        assert_eq!(cloned.shape(), &[3, 3]);

        // Mock distributed array.
        let array = MockDistributedArray::new(vec![1.0, 2.0, 3.0], vec![3]);
        let boxed: Box<dyn ArrayProtocol> = Box::new(array);
        let cloned = boxed.clone();

        assert_eq!(cloned.shape(), &[3]);
    }
}
1318
/// Example-style tests showing how the array protocol types fit together.
#[cfg(test)]
mod examples {
    use super::*;
    use ::ndarray::Array2;
    use std::any::Any;
    use std::collections::HashMap;

    /// Distributing an array keeps its shape and round-trips through
    /// `to_array`.
    #[test]
    fn example_distributed_array() {
        let array = Array2::<f64>::ones((10, 5));

        let config = DistributedConfig {
            chunks: 3,
            balance: true,
            strategy: DistributionStrategy::RowWise,
            backend: DistributedBackend::Threaded,
        };

        let dist_array = DistributedNdarray::from_array(&array, config);

        assert_eq!(dist_array.num_chunks(), 3);
        assert_eq!(dist_array.shape(), &[10, 5]);

        let result = dist_array.to_array().expect("Operation failed");

        assert_eq!(result.shape(), array.shape());
    }

    /// A GPU array reports its shape and backend and can be cloned as a
    /// boxed protocol object.
    #[test]
    fn example_gpu_array() {
        let array = Array2::<f64>::ones((10, 5));

        let config = GPUConfig {
            backend: GPUBackend::CUDA,
            device_id: 0,
            async_ops: true,
            mixed_precision: false,
            memory_fraction: 0.9,
        };

        let gpu_array = GPUNdarray::new(array.clone(), config);

        assert_eq!(gpu_array.shape(), &[10, 5]);
        assert!(gpu_array.is_on_gpu());

        let info = gpu_array.device_info();
        assert_eq!(info.get("backend").expect("Operation failed"), "CUDA");

        let gpu_box: Box<dyn ArrayProtocol> = Box::new(gpu_array);
        let gpu_clone = gpu_box.clone();

        assert_eq!(gpu_clone.shape(), &[10, 5]);
    }

    /// A JIT-enabled array compiles an expression and can be cloned as a
    /// boxed protocol object.
    #[test]
    fn example_jit_array() {
        init();

        let array = Array2::<f64>::ones((10, 5));
        let wrapped = NdarrayWrapper::new(array);

        let jitarray: JITEnabledArray<f64, NdarrayWrapper<f64, crate::ndarray::Ix2>> =
            JITEnabledArray::new(wrapped);

        assert!(jitarray.supports_jit());

        let expression = "x + y";
        let jit_function = jitarray.compile(expression).expect("Operation failed");

        assert_eq!(jit_function.source(), expression);

        let info = jitarray.jit_info();
        assert_eq!(info.get("supports_jit").expect("Operation failed"), "true");

        let jit_box: Box<dyn ArrayProtocol> = Box::new(jitarray);
        let jit_clone = jit_box.clone();

        assert_eq!(jit_clone.shape(), &[10, 5]);
    }

    /// Boxed GPU and distributed arrays keep their shape when cloned.
    #[test]
    fn example_cloning_array_protocol_objects() {
        let array = Array2::<f64>::ones((10, 5));
        let config = GPUConfig::default();
        let gpu_array = GPUNdarray::new(array.clone(), config);

        let boxed: Box<dyn ArrayProtocol> = Box::new(gpu_array);
        let cloned = boxed.clone();

        assert_eq!(cloned.shape(), &[10, 5]);

        let config = DistributedConfig {
            chunks: 3,
            balance: true,
            strategy: DistributionStrategy::RowWise,
            backend: DistributedBackend::Threaded,
        };
        let dist_array = DistributedNdarray::from_array(&array, config);

        let boxed: Box<dyn ArrayProtocol> = Box::new(dist_array);
        let cloned = boxed.clone();

        assert_eq!(cloned.shape(), &[10, 5]);
    }

    /// Different array kinds can be handled uniformly through
    /// `Box<dyn ArrayProtocol>`.
    #[test]
    fn example_array_interoperability() {
        init();

        let cpu_array = Array2::<f64>::ones((5, 5));

        let gpu_config = GPUConfig {
            backend: GPUBackend::CUDA,
            device_id: 0,
            async_ops: false,
            mixed_precision: false,
            memory_fraction: 0.9,
        };
        let gpu_array = GPUNdarray::new(cpu_array.clone(), gpu_config);

        let dist_config = DistributedConfig {
            chunks: 2,
            balance: true,
            strategy: DistributionStrategy::RowWise,
            backend: DistributedBackend::Threaded,
        };
        let dist_array = DistributedNdarray::from_array(&cpu_array, dist_config);

        let gpu_wrapper: Box<dyn ArrayProtocol> = Box::new(gpu_array);
        let dist_wrapper: Box<dyn ArrayProtocol> = Box::new(dist_array);

        let gpu_clone = gpu_wrapper.clone();
        let dist_clone = dist_wrapper.clone();

        assert_eq!(gpu_clone.shape(), &[5, 5]);
        assert_eq!(dist_clone.shape(), &[5, 5]);
    }

    /// A user-defined array type can implement the protocol and answer a
    /// custom function.
    #[test]
    fn example_custom_array_type() {
        use std::sync::Arc;

        struct MyCustomArray<T> {
            data: Vec<T>,
            shape: Vec<usize>,
        }

        impl<T: Clone + 'static> MyCustomArray<T> {
            fn new(data: Vec<T>, shape: Vec<usize>) -> Self {
                Self { data, shape }
            }
        }

        impl<T: Clone + Send + Sync + 'static> ArrayProtocol for MyCustomArray<T> {
            fn array_function(
                &self,
                func: &ArrayFunction,
                _types: &[TypeId],
                _args: &[Box<dyn Any>],
                _kwargs: &HashMap<String, Box<dyn Any>>,
            ) -> Result<Box<dyn Any>, NotImplemented> {
                if func.name != "scirs2::example::custom_sum" {
                    return Err(NotImplemented);
                }

                // Specialise on the element type with a checked downcast of
                // the data vector rather than reinterpreting its raw pointer.
                let data: &dyn Any = &self.data;
                if let Some(values) = data.downcast_ref::<Vec<f64>>() {
                    let sum = values.iter().sum::<f64>();
                    Ok(Box::new(sum))
                } else if let Some(values) = data.downcast_ref::<Vec<f32>>() {
                    let sum = values.iter().sum::<f32>();
                    Ok(Box::new(sum))
                } else {
                    Err(NotImplemented)
                }
            }

            fn as_any(&self) -> &dyn Any {
                self
            }

            fn shape(&self) -> &[usize] {
                &self.shape
            }

            fn box_clone(&self) -> Box<dyn ArrayProtocol> {
                Box::new(MyCustomArray {
                    data: self.data.clone(),
                    shape: self.shape.clone(),
                })
            }
        }

        let custom_array = MyCustomArray::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);

        let boxed: Box<dyn ArrayProtocol> = Box::new(custom_array);
        let cloned = boxed.clone();

        assert_eq!(cloned.shape(), &[2, 2]);

        // The function's own implementation returns 42.0, but the array's
        // `array_function` answers instead, so the sum below is 10.0.
        let func = ArrayFunction {
            name: "scirs2::example::custom_sum",
            implementation: Arc::new(move |_args, _kwargs| {
                Ok(Box::new(42.0) as Box<dyn Any>)
            }),
        };

        let result = cloned.array_function(
            &func,
            &[std::any::TypeId::of::<f64>()],
            &[],
            &HashMap::new(),
        );

        assert!(result.is_ok());
        if let Ok(value) = result {
            let sum = *value.downcast_ref::<f64>().expect("Operation failed");
            assert_eq!(sum, 10.0);
        }
    }
}
1657impl Clone for Box<dyn JITFunction> {
1659 fn clone(&self) -> Self {
1660 self.clone_box()
1661 }
1662}