use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::fmt::{Debug, Display};
use std::marker::PhantomData;
use std::sync::{Arc, LazyLock, RwLock};
use std::time::{Duration, Instant};

use crate::error::{CoreError, CoreResult, ErrorContext};

mod distributed_impl;
mod gpu_impl;
mod jit_impl;
mod operations;

pub use crate::array_function_dispatch;

pub mod auto_device;
pub mod distributed_training;
pub mod grad;
pub mod mixed_precision;
pub mod ml_ops;
pub mod neural;
#[cfg(feature = "serialization")]
pub mod serialization;
pub mod training;

/// Core trait for objects that participate in array function dispatch,
/// similar in spirit to NumPy's `__array_function__` protocol.
pub trait ArrayProtocol: Any + Send + Sync {
    /// Handle an array function call, returning `NotImplemented` when this type
    /// cannot service the given function and argument types.
    fn array_function(
        &self,
        func: &ArrayFunction,
        types: &[TypeId],
        args: &[Box<dyn Any>],
        kwargs: &HashMap<String, Box<dyn Any>>,
    ) -> Result<Box<dyn Any>, NotImplemented>;

    /// Access the underlying value as `&dyn Any` for downcasting.
    #[must_use]
    fn as_any(&self) -> &dyn Any;

    /// Shape of the array. The default implementation reports an empty shape.
    #[must_use]
    fn shape(&self) -> &[usize] {
        &[]
    }

    /// Element type of the array. Defaults to `f64`.
    #[must_use]
    fn dtype(&self) -> TypeId {
        TypeId::of::<f64>()
    }

    /// Clone this array into a new boxed trait object.
    #[must_use]
    fn box_clone(&self) -> Box<dyn ArrayProtocol>;
}

impl Clone for Box<dyn ArrayProtocol> {
    fn clone(&self) -> Self {
        self.box_clone()
    }
}

/// Marker error returned by `ArrayProtocol::array_function` when an
/// implementation does not handle the requested function.
#[derive(Debug, Clone, Copy)]
pub struct NotImplemented;

impl Display for NotImplemented {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "NotImplemented")
    }
}

/// Signature of an array function implementation: positional arguments plus
/// keyword arguments, returning a type-erased result.
pub type ArrayFunctionImpl = dyn Fn(&[Box<dyn Any>], &HashMap<String, Box<dyn Any>>) -> CoreResult<Box<dyn Any>>
    + Send
    + Sync;

/// A named, dispatchable array function with a default implementation.
#[derive(Clone)]
pub struct ArrayFunction {
    /// Fully qualified name used as the dispatch key,
    /// e.g. `"scirs2::array_protocol::operations::add"`.
    pub name: &'static str,

    /// Default implementation invoked when no argument overrides the function.
    pub implementation: Arc<ArrayFunctionImpl>,
}

impl Debug for ArrayFunction {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ArrayFunction")
            .field("name", &self.name)
            .finish_non_exhaustive()
    }
}

impl PartialEq for ArrayFunction {
    fn eq(&self, other: &Self) -> bool {
        self.name == other.name
    }
}

impl Eq for ArrayFunction {}

impl std::hash::Hash for ArrayFunction {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.name.hash(state);
    }
}

impl ArrayFunction {
    /// Create a new `ArrayFunction` with the given name and a default
    /// implementation that always returns `NotImplementedError`.
    #[must_use]
    pub fn new(name: &'static str) -> Self {
        Self {
            name,
            implementation: Arc::new(|_args, _kwargs| {
                Err(CoreError::NotImplementedError(ErrorContext::new(
                    "Function not implemented".to_string(),
                )))
            }),
        }
    }
}
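
// A minimal sketch (not part of the original code) showing how a caller might build an
// `ArrayFunction` with a working implementation and invoke it directly; the function
// name and the argument layout below are illustrative assumptions.
#[cfg(test)]
mod array_function_usage_example {
    use super::*;

    #[test]
    fn build_and_call_an_array_function() {
        let func = ArrayFunction {
            name: "scirs2::example::sum_f64_args",
            implementation: Arc::new(
                |args: &[Box<dyn Any>],
                 _kwargs: &HashMap<String, Box<dyn Any>>|
                 -> CoreResult<Box<dyn Any>> {
                    // Sum every positional argument that downcasts to f64.
                    let total: f64 = args.iter().filter_map(|a| a.downcast_ref::<f64>()).sum();
                    Ok(Box::new(total) as Box<dyn Any>)
                },
            ),
        };

        let args: Vec<Box<dyn Any>> = vec![Box::new(1.5_f64), Box::new(2.5_f64)];
        let result = (func.implementation)(&args, &HashMap::new()).unwrap();
        assert_eq!(*result.downcast_ref::<f64>().unwrap(), 4.0);
    }
}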

/// Cached dispatch decision for a (function, argument types) pair.
#[derive(Debug, Clone)]
pub struct DispatchCacheEntry {
    #[allow(dead_code)]
    type_signature: Vec<TypeId>,
    #[allow(dead_code)]
    preferred_impl_type: TypeId,
    timestamp: Instant,
    hit_count: u64,
}

/// Registry of array functions plus a small TTL-bounded dispatch cache.
#[derive(Debug)]
pub struct ArrayFunctionRegistry {
    functions: HashMap<&'static str, ArrayFunction>,
    dispatch_cache: HashMap<(&'static str, Vec<TypeId>), DispatchCacheEntry>,
    max_cache_size: usize,
    cache_ttl: Duration,
}

impl Default for ArrayFunctionRegistry {
    fn default() -> Self {
        Self {
            functions: HashMap::new(),
            dispatch_cache: HashMap::new(),
            max_cache_size: 1000,
            cache_ttl: Duration::from_secs(300),
        }
    }
}

impl ArrayFunctionRegistry {
    /// Access the global, process-wide registry.
    #[must_use]
    pub fn global() -> &'static RwLock<Self> {
        static REGISTRY: LazyLock<RwLock<ArrayFunctionRegistry>> =
            LazyLock::new(|| RwLock::new(ArrayFunctionRegistry::default()));
        &REGISTRY
    }

    /// Register a function, replacing any existing function with the same name.
    pub fn register(&mut self, func: ArrayFunction) {
        self.functions.insert(func.name, func);
    }

    /// Look up a function by name.
    #[must_use]
    #[allow(dead_code)]
    pub fn get(&self, name: &str) -> Option<&ArrayFunction> {
        self.functions.get(name)
    }

    /// Return all registered functions.
    #[must_use]
    pub fn all_functions(&self) -> Vec<&ArrayFunction> {
        self.functions.values().collect()
    }

    /// Return the cached dispatch entry for a call signature, if it has not expired.
    #[must_use]
    pub fn get_cached_dispatch(
        &self,
        funcname: &'static str,
        types: &[TypeId],
    ) -> Option<&DispatchCacheEntry> {
        let key = (funcname, types.to_vec());
        if let Some(entry) = self.dispatch_cache.get(&key) {
            if entry.timestamp.elapsed() < self.cache_ttl {
                return Some(entry);
            }
        }
        None
    }

    /// Cache the preferred implementation type for a call signature.
    pub fn cache_dispatch(
        &mut self,
        funcname: &'static str,
        types: Vec<TypeId>,
        impl_type: TypeId,
    ) {
        if self.dispatch_cache.len() >= self.max_cache_size {
            self.cleanup_cache();
        }

        let key = (funcname, types.clone());
        let entry = DispatchCacheEntry {
            type_signature: types,
            preferred_impl_type: impl_type,
            timestamp: Instant::now(),
            hit_count: 0,
        };
        self.dispatch_cache.insert(key, entry);
    }

    /// Record a cache hit for a call signature.
    pub fn update_cache_hit(&mut self, funcname: &'static str, types: &[TypeId]) {
        let key = (funcname, types.to_vec());
        if let Some(entry) = self.dispatch_cache.get_mut(&key) {
            entry.hit_count += 1;
        }
    }

    /// Drop expired entries; if the cache is still full, evict the least-used quarter.
    fn cleanup_cache(&mut self) {
        let now = Instant::now();
        self.dispatch_cache
            .retain(|_, entry| now.duration_since(entry.timestamp) < self.cache_ttl);

        if self.dispatch_cache.len() >= self.max_cache_size {
            let mut entries: Vec<_> = self
                .dispatch_cache
                .iter()
                .map(|(k, v)| (k.clone(), v.hit_count))
                .collect();
            entries.sort_by_key(|(_, hit_count)| *hit_count);

            let to_remove = self.dispatch_cache.len() / 4;
            let keys_to_remove: Vec<_> = entries
                .iter()
                .take(to_remove)
                .map(|(key, _)| key.clone())
                .collect();
            for key in keys_to_remove {
                self.dispatch_cache.remove(&key);
            }
        }
    }

    /// Summary statistics about the dispatch cache.
    #[must_use]
    pub fn cache_stats(&self) -> HashMap<String, u64> {
        let mut stats = HashMap::new();
        stats.insert("cache_size".to_string(), self.dispatch_cache.len() as u64);
        stats.insert("max_cache_size".to_string(), self.max_cache_size as u64);

        let total_hits: u64 = self.dispatch_cache.values().map(|e| e.hit_count).sum();
        stats.insert("total_hits".to_string(), total_hits);

        stats
    }
}
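
// Illustrative sketch (not original code) of the dispatch-cache API exercised on a
// local registry rather than the global one; the function names used here are arbitrary.
#[cfg(test)]
mod dispatch_cache_example {
    use super::*;

    #[test]
    fn cache_roundtrip_and_stats() {
        let mut registry = ArrayFunctionRegistry::default();
        let types = vec![TypeId::of::<f64>()];

        // Nothing cached yet for this signature.
        assert!(registry
            .get_cached_dispatch("scirs2::example::add", &types)
            .is_none());

        // Cache a preferred implementation type and record one hit.
        registry.cache_dispatch("scirs2::example::add", types.clone(), TypeId::of::<f64>());
        registry.update_cache_hit("scirs2::example::add", &types);
        assert!(registry
            .get_cached_dispatch("scirs2::example::add", &types)
            .is_some());

        let stats = registry.cache_stats();
        assert_eq!(stats["cache_size"], 1);
        assert_eq!(stats["total_hits"], 1);
    }
}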

/// Extract all arguments that implement `ArrayProtocol`, paired with the
/// `TypeId` of their concrete type, in a deterministic order.
#[allow(dead_code)]
pub fn get_implementing_args(args: &[Box<dyn Any>]) -> Vec<(TypeId, &dyn ArrayProtocol)> {
    if args.is_empty() {
        return Vec::new();
    }

    let mut implementing_args = Vec::with_capacity(args.len());

    for arg in args {
        if let Some(array_protocol_obj) = arg.downcast_ref::<Box<dyn ArrayProtocol>>() {
            let type_id = (**array_protocol_obj).type_id();
            implementing_args.push((type_id, &**array_protocol_obj));
        }
    }

    // Sort by a hash of each argument's concrete type so that dispatch order is deterministic.
    implementing_args.sort_by_key(|&(type_id, _)| {
        use std::hash::{Hash, Hasher};
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        type_id.hash(&mut hasher);
        hasher.finish()
    });

    implementing_args
}
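
// Sketch (not original code): `get_implementing_args` only recognises values that are
// stored as `Box<dyn ArrayProtocol>` inside the `Box<dyn Any>` argument slots.
#[cfg(test)]
mod implementing_args_example {
    use super::*;

    #[test]
    fn detects_boxed_array_protocol_arguments() {
        let protocol_arg: Box<dyn ArrayProtocol> =
            Box::new(MockDistributedArray::new(vec![1.0_f64, 2.0, 3.0], vec![3]));
        // The second argument is a plain integer and is ignored by the scan.
        let args: Vec<Box<dyn Any>> = vec![Box::new(protocol_arg), Box::new(42_i32)];

        let implementing = get_implementing_args(&args);
        assert_eq!(implementing.len(), 1);
        assert_eq!(implementing[0].1.shape(), &[3]);
    }
}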

/// Dispatch an array function to the first argument implementation that handles it,
/// falling back to the function's default implementation when no argument
/// implements `ArrayProtocol`.
#[allow(dead_code)]
pub fn array_function_dispatch(
    func: &ArrayFunction,
    args: &[Box<dyn Any>],
    kwargs: &HashMap<String, Box<dyn Any>>,
) -> CoreResult<Box<dyn Any>> {
    if args.is_empty() {
        return (func.implementation)(args, kwargs);
    }

    let implementing_args = get_implementing_args(args);

    if implementing_args.is_empty() {
        return (func.implementation)(args, kwargs);
    }

    if implementing_args.len() == 1 {
        let (type_id, array_protocol_obj) = implementing_args[0];
        let types = [type_id];
        match array_protocol_obj.array_function(func, &types, args, kwargs) {
            Ok(result) => return Ok(result),
            Err(NotImplemented) => {
                return Err(CoreError::DispatchError(ErrorContext::new(format!(
                    "No implementation found for {} with type {:?}",
                    func.name, type_id
                ))));
            }
        }
    }

    let mut unique_types = Vec::with_capacity(implementing_args.len());
    let mut seen_types = std::collections::HashSet::with_capacity(implementing_args.len());

    for &(type_id, _) in &implementing_args {
        if seen_types.insert(type_id) {
            unique_types.push(type_id);
        }
    }

    for (_, array_protocol_obj) in implementing_args {
        if let Ok(result) = array_protocol_obj.array_function(func, &unique_types, args, kwargs) {
            return Ok(result);
        }
    }

    Err(CoreError::DispatchError(ErrorContext::new(format!(
        "No implementation found for {} with {} argument types: {:?}",
        func.name,
        unique_types.len(),
        unique_types
    ))))
}
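
// Sketch (not original code): when no argument implements `ArrayProtocol`, dispatch
// falls straight through to the function's default implementation. The function name
// and the constant it returns are illustrative.
#[cfg(test)]
mod dispatch_fallback_example {
    use super::*;

    #[test]
    fn falls_back_to_default_implementation() {
        let func = ArrayFunction {
            name: "scirs2::example::constant",
            implementation: Arc::new(
                |_args: &[Box<dyn Any>],
                 _kwargs: &HashMap<String, Box<dyn Any>>|
                 -> CoreResult<Box<dyn Any>> { Ok(Box::new(7_i32) as Box<dyn Any>) },
            ),
        };

        // A plain `i32` does not implement `ArrayProtocol`, so the default runs.
        let args: Vec<Box<dyn Any>> = vec![Box::new(1_i32)];
        let result = array_function_dispatch(&func, &args, &HashMap::new()).unwrap();
        assert_eq!(*result.downcast_ref::<i32>().unwrap(), 7);
    }
}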

/// Helper that registers a function name in the global registry and hands the
/// wrapped Rust function back to the caller.
pub struct ArrayFunctionDecorator<F> {
    function: F,
    name: &'static str,
}

impl<F> ArrayFunctionDecorator<F>
where
    F: Send + Sync + 'static,
{
    #[must_use]
    pub fn new(function: F, name: &'static str) -> Self {
        Self { function, name }
    }

    /// Register the function name globally and return the original function.
    pub fn register(self) -> F {
        let implementation = Arc::new(
            move |_args: &[Box<dyn Any>], _kwargs: &HashMap<String, Box<dyn Any>>| {
                Err(CoreError::NotImplementedError(ErrorContext::new(
                    "ArrayFunctionDecorator: Type conversion in array_function_dispatch is not implemented yet".to_string()
                )))
            },
        );

        let func = ArrayFunction {
            name: self.name,
            implementation,
        };

        let registry = ArrayFunctionRegistry::global();
        if let Ok(mut registry) = registry.write() {
            registry.register(func);
        } else {
            eprintln!("Warning: Failed to acquire write lock on ArrayFunctionRegistry, skipping function registration");
        }

        self.function
    }
}
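
// Sketch (not original code) of the decorator-style registration flow: the plain Rust
// function is returned unchanged, while its name becomes visible in the global registry.
// The function and its registered name are illustrative.
#[cfg(test)]
mod decorator_example {
    use super::*;

    #[test]
    fn decorator_returns_the_wrapped_function() {
        fn triple(x: i32) -> i32 {
            x * 3
        }

        let registered =
            ArrayFunctionDecorator::new(triple, "scirs2::example::triple").register();
        assert_eq!(registered(4), 12);

        let registry = ArrayFunctionRegistry::global().read().unwrap();
        assert!(registry.get("scirs2::example::triple").is_some());
    }
}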

/// Arrays that can move between CPU and GPU memory.
pub trait GPUArray: ArrayProtocol {
    /// Transfer this array to the GPU.
    fn to_gpu(&self) -> CoreResult<Box<dyn GPUArray>>;

    /// Transfer this array back to the CPU.
    fn to_cpu(&self) -> CoreResult<Box<dyn ArrayProtocol>>;

    /// Whether the data currently resides on the GPU.
    #[must_use]
    fn is_on_gpu(&self) -> bool;

    /// Device metadata such as backend and device id.
    #[must_use]
    fn device_info(&self) -> HashMap<String, String>;
}

/// Arrays whose data is partitioned across multiple chunks or nodes.
pub trait DistributedArray: ArrayProtocol {
    /// Metadata describing how the array is distributed.
    #[must_use]
    fn distribution_info(&self) -> HashMap<String, String>;

    /// Gather the distributed chunks into a single local array.
    fn gather(&self) -> CoreResult<Box<dyn ArrayProtocol>>;

    /// Re-partition the array into the given number of chunks.
    fn scatter(&self, chunks: usize) -> CoreResult<Box<dyn DistributedArray>>;

    /// Whether the array is actually distributed.
    #[must_use]
    fn is_distributed(&self) -> bool;
}

/// Arrays that can JIT-compile expressions operating on them.
pub trait JITArray: ArrayProtocol {
    /// Compile an expression into a callable JIT function.
    fn compile(&self, expression: &str) -> CoreResult<Box<dyn JITFunction>>;

    /// Whether JIT compilation is available for this array.
    #[must_use]
    fn supports_jit(&self) -> bool;

    /// Metadata about the JIT backend.
    #[must_use]
    fn jit_info(&self) -> HashMap<String, String>;
}

/// A compiled, callable function produced by a JIT backend.
pub trait JITFunction: Send + Sync {
    /// Evaluate the compiled function with type-erased arguments.
    fn evaluate(&self, args: &[Box<dyn Any>]) -> CoreResult<Box<dyn Any>>;

    /// Source expression the function was compiled from.
    #[must_use]
    fn source(&self) -> String;

    /// Metadata about the compilation.
    #[must_use]
    fn compile_info(&self) -> HashMap<String, String>;

    /// Clone this function into a new boxed trait object.
    #[must_use]
    fn clone_box(&self) -> Box<dyn JITFunction>;
}

/// Factory that produces JIT functions for supported array types.
pub trait JITFunctionFactory: Send + Sync {
    fn create_jit_function(
        &self,
        expression: &str,
        array_typeid: TypeId,
    ) -> CoreResult<Box<dyn JITFunction>>;

    #[must_use]
    fn supports_array_type(&self, array_typeid: TypeId) -> bool;
}

/// Registry of JIT function factories.
#[derive(Default)]
pub struct JITFactoryRegistry {
    factories: Vec<Box<dyn JITFunctionFactory>>,
}

impl std::fmt::Debug for JITFactoryRegistry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "JITFactoryRegistry {{ factories: {} }}",
            self.factories.len()
        )
    }
}

impl JITFactoryRegistry {
    /// Access the global, process-wide factory registry.
    #[must_use]
    pub fn global() -> &'static RwLock<Self> {
        static REGISTRY: LazyLock<RwLock<JITFactoryRegistry>> = LazyLock::new(|| {
            RwLock::new(JITFactoryRegistry {
                factories: Vec::new(),
            })
        });
        &REGISTRY
    }

    /// Register a new factory.
    pub fn register(&mut self, factory: Box<dyn JITFunctionFactory>) {
        self.factories.push(factory);
    }

    /// Find the first factory that supports the given array type.
    #[must_use]
    pub fn get_factory_for_array_type(
        &self,
        array_typeid: TypeId,
    ) -> Option<&dyn JITFunctionFactory> {
        for factory in &self.factories {
            if factory.supports_array_type(array_typeid) {
                return Some(&**factory);
            }
        }
        None
    }
}

/// Adapter that exposes an `ndarray::Array` through the `ArrayProtocol` trait.
#[derive(Debug, Clone)]
pub struct NdarrayWrapper<T, D: ndarray::Dimension> {
    array: ndarray::Array<T, D>,
    phantom: PhantomData<(T, D)>,
}

impl<T, D> NdarrayWrapper<T, D>
where
    T: Clone + 'static,
    D: ndarray::Dimension + 'static,
{
    #[must_use]
    pub fn new(array: ndarray::Array<T, D>) -> Self {
        Self {
            array,
            phantom: PhantomData,
        }
    }

    /// Borrow the wrapped ndarray.
    #[must_use]
    pub const fn as_array(&self) -> &ndarray::Array<T, D> {
        &self.array
    }

    /// Consume the wrapper and return the ndarray.
    #[must_use]
    pub fn into_array(self) -> ndarray::Array<T, D> {
        self.array
    }

    /// Replace the wrapped array with a new one.
    pub fn array_2(&mut self, newarray: ndarray::Array<T, D>) {
        self.array = newarray;
    }
}

impl<T, D> ArrayProtocol for NdarrayWrapper<T, D>
where
    T: Clone + Send + Sync + 'static,
    D: ndarray::Dimension + Send + Sync + 'static,
{
    fn array_function(
        &self,
        func: &ArrayFunction,
        _types: &[TypeId],
        args: &[Box<dyn Any>],
        kwargs: &HashMap<String, Box<dyn Any>>,
    ) -> Result<Box<dyn Any>, NotImplemented> {
        match func.name {
            "scirs2::array_protocol::operations::add" => {
                if args.len() < 2 {
                    return Err(NotImplemented);
                }

                if let Some(other) = args[1].downcast_ref::<NdarrayWrapper<T, D>>() {
                    if let (Some(a), Some(b)) = (
                        self.as_any().downcast_ref::<NdarrayWrapper<T, D>>(),
                        other.as_any().downcast_ref::<NdarrayWrapper<T, D>>(),
                    ) {
                        // The element type is only known at runtime, so the wrappers are
                        // reinterpreted for the concrete types supported here (f64 and f32).
                        // The preceding `TypeId` check is what makes the cast sound.
                        if TypeId::of::<T>() == TypeId::of::<f64>() {
                            let a_f64 =
                                unsafe { &*(a as *const _ as *const NdarrayWrapper<f64, D>) };
                            let b_f64 =
                                unsafe { &*(b as *const _ as *const NdarrayWrapper<f64, D>) };
                            let result = a_f64.as_array() + b_f64.as_array();
                            return Ok(Box::new(NdarrayWrapper::new(result)));
                        } else if TypeId::of::<T>() == TypeId::of::<f32>() {
                            let a_f32 =
                                unsafe { &*(a as *const _ as *const NdarrayWrapper<f32, D>) };
                            let b_f32 =
                                unsafe { &*(b as *const _ as *const NdarrayWrapper<f32, D>) };
                            let result = a_f32.as_array() + b_f32.as_array();
                            return Ok(Box::new(NdarrayWrapper::new(result)));
                        }
                    }
                }
                Err(NotImplemented)
            }
            "scirs2::array_protocol::operations::matmul" => {
                if args.len() < 2 {
                    return Err(NotImplemented);
                }

                // Matrix multiplication is only supported for 2-D arrays here.
                if TypeId::of::<D>() != TypeId::of::<ndarray::Ix2>() {
                    return Err(NotImplemented);
                }

                if let Some(other) = args[1].downcast_ref::<NdarrayWrapper<T, D>>() {
                    if TypeId::of::<T>() == TypeId::of::<f64>() {
                        let a_f64 = unsafe {
                            &*(self as *const _ as *const NdarrayWrapper<f64, ndarray::Ix2>)
                        };
                        let b_f64 = unsafe {
                            &*(other as *const _ as *const NdarrayWrapper<f64, ndarray::Ix2>)
                        };

                        let ashape = a_f64.as_array().shape();
                        let bshape = b_f64.as_array().shape();

                        if ashape.len() != 2 || bshape.len() != 2 || ashape[1] != bshape[0] {
                            return Err(NotImplemented);
                        }

                        let result = a_f64.as_array().dot(b_f64.as_array());
                        return Ok(Box::new(NdarrayWrapper::new(result)));
                    } else if TypeId::of::<T>() == TypeId::of::<f32>() {
                        let a_f32 = unsafe {
                            &*(self as *const _ as *const NdarrayWrapper<f32, ndarray::Ix2>)
                        };
                        let b_f32 = unsafe {
                            &*(other as *const _ as *const NdarrayWrapper<f32, ndarray::Ix2>)
                        };

                        let ashape = a_f32.as_array().shape();
                        let bshape = b_f32.as_array().shape();

                        if ashape.len() != 2 || bshape.len() != 2 || ashape[1] != bshape[0] {
                            return Err(NotImplemented);
                        }

                        let result = a_f32.as_array().dot(b_f32.as_array());
                        return Ok(Box::new(NdarrayWrapper::new(result)));
                    }
                }
                Err(NotImplemented)
            }
            "scirs2::array_protocol::operations::transpose" => {
                if TypeId::of::<T>() == TypeId::of::<f64>() {
                    let a_f64 = unsafe { &*(self as *const _ as *const NdarrayWrapper<f64, D>) };
                    let result = a_f64.as_array().t().to_owned();
                    return Ok(Box::new(NdarrayWrapper::new(result)));
                } else if TypeId::of::<T>() == TypeId::of::<f32>() {
                    let a_f32 = unsafe { &*(self as *const _ as *const NdarrayWrapper<f32, D>) };
                    let result = a_f32.as_array().t().to_owned();
                    return Ok(Box::new(NdarrayWrapper::new(result)));
                }
                Err(NotImplemented)
            }
            "scirs2::array_protocol::operations::sum" => {
                // Note: the `axis` keyword is accepted but currently ignored; both branches
                // compute the sum over all elements.
                let axis_ref = kwargs.get("axis").and_then(|a| a.downcast_ref::<usize>());

                if TypeId::of::<T>() == TypeId::of::<f64>() {
                    let a_f64 = unsafe { &*(self as *const _ as *const NdarrayWrapper<f64, D>) };
                    match axis_ref {
                        Some(&_ax) => {
                            let result = a_f64.as_array().sum();
                            return Ok(Box::new(result));
                        }
                        None => {
                            let result = a_f64.as_array().sum();
                            return Ok(Box::new(result));
                        }
                    }
                } else if TypeId::of::<T>() == TypeId::of::<f32>() {
                    let a_f32 = unsafe { &*(self as *const _ as *const NdarrayWrapper<f32, D>) };
                    match axis_ref {
                        Some(&_ax) => {
                            let result = a_f32.as_array().sum();
                            return Ok(Box::new(result));
                        }
                        None => {
                            let result = a_f32.as_array().sum();
                            return Ok(Box::new(result));
                        }
                    }
                }
                Err(NotImplemented)
            }
            "scirs2::array_protocol::operations::reshape" => {
                if let Some(shape) = kwargs
                    .get("shape")
                    .and_then(|s| s.downcast_ref::<Vec<usize>>())
                {
                    if TypeId::of::<T>() == TypeId::of::<f64>() {
                        let a_f64 =
                            unsafe { &*(self as *const _ as *const NdarrayWrapper<f64, D>) };
                        match a_f64
                            .as_array()
                            .clone()
                            .into_shape_with_order(shape.clone())
                        {
                            Ok(result) => return Ok(Box::new(NdarrayWrapper::new(result))),
                            Err(_) => return Err(NotImplemented),
                        }
                    } else if TypeId::of::<T>() == TypeId::of::<f32>() {
                        let a_f32 =
                            unsafe { &*(self as *const _ as *const NdarrayWrapper<f32, D>) };
                        match a_f32
                            .as_array()
                            .clone()
                            .into_shape_with_order(shape.clone())
                        {
                            Ok(result) => return Ok(Box::new(NdarrayWrapper::new(result))),
                            Err(_) => return Err(NotImplemented),
                        }
                    }
                }
                Err(NotImplemented)
            }
            _ => Err(NotImplemented),
        }
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn shape(&self) -> &[usize] {
        self.array.shape()
    }

    fn dtype(&self) -> TypeId {
        TypeId::of::<T>()
    }

    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
        Box::new(self.clone())
    }
}
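
// Sketch (not original code) of calling `array_function` directly on an `NdarrayWrapper`
// for the element-wise add operation. The convention that args[0] is the receiver and
// args[1] the second operand is an assumption made for illustration.
#[cfg(test)]
mod ndarray_wrapper_example {
    use super::*;
    use ndarray::Array2;

    #[test]
    fn elementwise_add_via_array_function() {
        let a = NdarrayWrapper::new(Array2::<f64>::ones((2, 2)));
        let b = NdarrayWrapper::new(Array2::<f64>::ones((2, 2)));

        let func = ArrayFunction::new("scirs2::array_protocol::operations::add");
        let args: Vec<Box<dyn Any>> = vec![Box::new(a.clone()), Box::new(b)];

        let result = a
            .array_function(&func, &[TypeId::of::<f64>()], &args, &HashMap::new())
            .unwrap();
        let sum = result
            .downcast_ref::<NdarrayWrapper<f64, ndarray::Ix2>>()
            .unwrap();
        assert_eq!(sum.as_array()[[0, 0]], 2.0);
    }
}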

/// Minimal in-memory mock of a distributed array, used for tests and examples.
#[derive(Debug, Clone)]
pub struct MockDistributedArray<T: Clone + 'static> {
    chunks: Vec<T>,
    shape: Vec<usize>,
}

impl<T: Clone + Send + Sync + 'static> MockDistributedArray<T> {
    #[must_use]
    pub fn new(chunks: Vec<T>, shape: Vec<usize>) -> Self {
        Self { chunks, shape }
    }
}

impl<T: Clone + Send + Sync + 'static> ArrayProtocol for MockDistributedArray<T> {
    fn array_function(
        &self,
        func: &ArrayFunction,
        _types: &[TypeId],
        _args: &[Box<dyn Any>],
        _kwargs: &HashMap<String, Box<dyn Any>>,
    ) -> Result<Box<dyn Any>, NotImplemented> {
        match func.name {
            "scirs2::mean" => {
                // Mock behaviour: return the first chunk instead of computing a real mean.
                let result = T::clone(&self.chunks[0]);
                Ok(Box::new(result))
            }
            _ => Err(NotImplemented),
        }
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn shape(&self) -> &[usize] {
        &self.shape
    }

    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
        Box::new(self.clone())
    }
}

impl<T: Clone + Send + Sync + 'static> DistributedArray for MockDistributedArray<T> {
    fn distribution_info(&self) -> HashMap<String, String> {
        let mut info = HashMap::new();
        info.insert("type".to_string(), "mock_distributed".to_string());
        info.insert("chunks".to_string(), self.chunks.len().to_string());
        info
    }

    fn gather(&self) -> CoreResult<Box<dyn ArrayProtocol>> {
        Ok(Box::new(self.clone()) as Box<dyn ArrayProtocol>)
    }

    fn scatter(&self, _numchunks: usize) -> CoreResult<Box<dyn DistributedArray>> {
        Ok(Box::new(self.clone()) as Box<dyn DistributedArray>)
    }

    fn is_distributed(&self) -> bool {
        true
    }
}

/// Minimal in-memory mock of a GPU-resident array, used for tests and examples.
#[derive(Debug, Clone)]
pub struct MockGPUArray<T: Clone + 'static> {
    data: Vec<T>,
    shape: Vec<usize>,
    device: String,
}

impl<T: Clone + Send + Sync + 'static> MockGPUArray<T> {
    #[must_use]
    pub fn new(data: Vec<T>, shape: Vec<usize>, device: String) -> Self {
        Self {
            data,
            shape,
            device,
        }
    }
}

impl<T: Clone + Send + Sync + 'static> ArrayProtocol for MockGPUArray<T> {
    fn array_function(
        &self,
        func: &ArrayFunction,
        _types: &[TypeId],
        _args: &[Box<dyn Any>],
        _kwargs: &HashMap<String, Box<dyn Any>>,
    ) -> Result<Box<dyn Any>, NotImplemented> {
        match func.name {
            "scirs2::matmul" => {
                // Mock behaviour: return a copy of this array instead of a real product.
                let result =
                    MockGPUArray::new(self.data.clone(), self.shape.clone(), self.device.clone());
                Ok(Box::new(result))
            }
            _ => Err(NotImplemented),
        }
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn shape(&self) -> &[usize] {
        &self.shape
    }

    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
        Box::new(self.clone())
    }
}

impl<T: Clone + Send + Sync + 'static> GPUArray for MockGPUArray<T> {
    fn to_gpu(&self) -> CoreResult<Box<dyn GPUArray>> {
        Ok(Box::new(self.clone()) as Box<dyn GPUArray>)
    }

    fn to_cpu(&self) -> CoreResult<Box<dyn ArrayProtocol>> {
        Ok(Box::new(self.clone()) as Box<dyn ArrayProtocol>)
    }

    fn is_on_gpu(&self) -> bool {
        true
    }

    fn device_info(&self) -> HashMap<String, String> {
        let mut info = HashMap::new();
        info.insert("device".to_string(), self.device.clone());
        info.insert("type".to_string(), "mock_gpu".to_string());
        info
    }
}

/// Helper that registers a function name in the global registry and returns the
/// original function to the caller.
#[derive(Debug)]
pub struct ArrayProtocolFunction<F> {
    func: F,
    name: &'static str,
}

impl<F> ArrayProtocolFunction<F> {
    #[must_use]
    pub fn new(func: F, name: &'static str) -> Self {
        Self { func, name }
    }
}

impl<F> ArrayProtocolFunction<F>
where
    F: Clone + Send + Sync + 'static,
{
    /// Register the function name globally and return the original function.
    pub fn register(self) -> F {
        let implementation = Arc::new(
            move |_args: &[Box<dyn Any>], _kwargs: &HashMap<String, Box<dyn Any>>| {
                Err(CoreError::NotImplementedError(ErrorContext::new(
                    "ArrayProtocolFunction: Implementation for array protocol functions is not complete".to_string()
                )))
            },
        );

        let array_func = ArrayFunction {
            name: self.name,
            implementation,
        };

        if let Ok(mut registry) = ArrayFunctionRegistry::global().write() {
            registry.register(array_func);
        } else {
            eprintln!("Warning: Failed to acquire write lock on ArrayFunctionRegistry during array protocol building, skipping function registration");
        }

        self.func
    }
}

/// Define a function and evaluate to it; the trailing function-name string is
/// currently accepted but unused by the expansion.
#[macro_export]
macro_rules! array_function_def {
    (fn $name:ident $(<$($gen:ident),*>)? ($($arg:ident : $arg_ty:ty),*) -> $ret:ty $body:block, $funcname:expr) => {
        {
            fn $name $(<$($gen),*>)? ($($arg : $arg_ty),*) -> $ret $body

            $name
        }
    };
}
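
// Sketch (not original code) of invoking the `array_function_def!` macro; the function
// it defines and the name string are illustrative. The macro simply defines the function
// and evaluates to it.
#[cfg(test)]
mod array_function_def_example {
    #[test]
    fn macro_defines_and_returns_a_function() {
        let double = array_function_def!(
            fn double(x: f64) -> f64 {
                x * 2.0
            },
            "scirs2::example::double"
        );
        assert_eq!(double(21.0), 42.0);
    }
}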

pub use self::distributed_impl::{
    ArrayChunk, DistributedBackend, DistributedConfig, DistributedNdarray, DistributionStrategy,
};

pub use self::gpu_impl::{
    kernels as gpu_kernels, GPUArrayBuilder, GPUBackend, GPUConfig, GPUNdarray,
};

pub use self::jit_impl::{
    CraneliftFunctionFactory, JITBackend, JITConfig, JITEnabledArray, JITFunctionImpl, JITManager,
    LLVMFunctionFactory,
};

pub use self::operations::{
    add, apply_elementwise, concatenate, inverse, matmul, multiply, reshape, subtract, sum, svd,
    transpose, OperationError,
};

pub use self::ml_ops::{
    activation, batch_norm, conv2d, cross_entropy, dropout, max_pool2d, self_attention,
    ActivationFunc,
};

/// Initialize the array protocol subsystem (currently just the JIT manager).
#[allow(dead_code)]
pub fn init() {
    let mut jit_manager = JITManager::global().write().unwrap();
    jit_manager.initialize();
}

/// Additional, more specialised array capability traits.
pub mod traits {
    use super::*;

    /// Arrays that expose their memory strides.
    pub trait StridedArray: ArrayProtocol {
        #[must_use]
        fn strides(&self) -> Vec<usize>;

        #[must_use]
        fn is_contiguous(&self) -> bool;

        #[must_use]
        fn is_fortran_contiguous(&self) -> bool;
    }

    /// Arrays that can produce zero-copy views of themselves.
    pub trait ZeroCopyArray: ArrayProtocol {
        #[must_use]
        fn view(&self) -> Box<dyn ZeroCopyArray>;

        #[must_use]
        fn view_mut(&mut self) -> Box<dyn ZeroCopyArray>;

        #[must_use]
        fn is_view(&self) -> bool;
    }

    /// Arrays that participate in automatic differentiation.
    pub trait DifferentiableArray: ArrayProtocol {
        fn gradient(
            &self,
            variables: &[Box<dyn DifferentiableArray>],
        ) -> Vec<Box<dyn DifferentiableArray>>;

        fn set_requiresgrad(&mut self, requiresgrad: bool);

        #[must_use]
        fn requiresgrad(&self) -> bool;

        #[must_use]
        fn grad(&self) -> Option<Box<dyn DifferentiableArray>>;
    }

    /// Arrays that support asynchronous operations.
    pub trait AsyncArray: ArrayProtocol {
        fn async_op<F, R>(&self, op: F) -> impl std::future::Future<Output = CoreResult<R>>
        where
            F: FnOnce(&Self) -> CoreResult<R> + Send + 'static,
            R: Send + 'static;

        #[must_use]
        fn supports_async(&self) -> bool;
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_array_protocol_registry() {
        let implementation = Arc::new(
            move |_args: &[Box<dyn Any>], _kwargs: &HashMap<String, Box<dyn Any>>| {
                Ok(Box::new(42.0) as Box<dyn Any>)
            },
        );

        let func = ArrayFunction {
            name: "scirs2::test::test_func",
            implementation,
        };

        let registry = ArrayFunctionRegistry::global();
        {
            let mut reg = registry.write().unwrap();
            reg.register(func.clone());
        }

        {
            let reg = registry.read().unwrap();
            let registered_func = reg.get("scirs2::test::test_func").unwrap();
            assert_eq!(registered_func.name, "scirs2::test::test_func");
        }
    }

    #[test]
    fn test_mock_distributed_array() {
        let array = MockDistributedArray::new(vec![1.0, 2.0, 3.0], vec![3]);
        assert!(array.is_distributed());

        let info = array.distribution_info();
        assert_eq!(info.get("type").unwrap(), "mock_distributed");
        assert_eq!(info.get("chunks").unwrap(), "3");
    }

    #[test]
    fn test_mock_gpu_array() {
        let array = MockGPUArray::new(vec![1.0, 2.0, 3.0], vec![3], "cuda:0".to_string());
        assert!(array.is_on_gpu());

        let info = array.device_info();
        assert_eq!(info.get("device").unwrap(), "cuda:0");
        assert_eq!(info.get("type").unwrap(), "mock_gpu");
    }

    #[test]
    fn test_box_clone() {
        let array = ndarray::Array2::<f64>::ones((3, 3));
        let wrapped = NdarrayWrapper::new(array);
        let boxed: Box<dyn ArrayProtocol> = Box::new(wrapped);
        let cloned = boxed.clone();

        assert_eq!(cloned.shape(), &[3, 3]);

        let array = MockDistributedArray::new(vec![1.0, 2.0, 3.0], vec![3]);
        let boxed: Box<dyn ArrayProtocol> = Box::new(array);
        let cloned = boxed.clone();

        assert_eq!(cloned.shape(), &[3]);
    }
}

#[cfg(test)]
mod examples {
    use super::*;
    use ndarray::Array2;
    use std::any::Any;
    use std::collections::HashMap;

    #[test]
    fn example_distributed_array() {
        let array = Array2::<f64>::ones((10, 5));

        let config = DistributedConfig {
            chunks: 3,
            balance: true,
            strategy: DistributionStrategy::RowWise,
            backend: DistributedBackend::Threaded,
        };

        let dist_array = DistributedNdarray::from_array(&array, config);

        assert_eq!(dist_array.num_chunks(), 3);
        assert_eq!(dist_array.shape(), &[10, 5]);

        let result = dist_array.to_array().unwrap();

        assert_eq!(result.shape(), array.shape());
    }

    #[test]
    fn example_gpu_array() {
        let array = Array2::<f64>::ones((10, 5));

        let config = GPUConfig {
            backend: GPUBackend::CUDA,
            device_id: 0,
            async_ops: true,
            mixed_precision: false,
            memory_fraction: 0.9,
        };

        let gpu_array = GPUNdarray::new(array.clone(), config);

        assert_eq!(gpu_array.shape(), &[10, 5]);
        assert!(gpu_array.is_on_gpu());

        let info = gpu_array.device_info();
        assert_eq!(info.get("backend").unwrap(), "CUDA");

        let gpu_box: Box<dyn ArrayProtocol> = Box::new(gpu_array);
        let gpu_clone = gpu_box.clone();

        assert_eq!(gpu_clone.shape(), &[10, 5]);
    }

    #[test]
    fn example_jit_array() {
        init();

        let array = Array2::<f64>::ones((10, 5));
        let wrapped = NdarrayWrapper::new(array);

        let jitarray: JITEnabledArray<f64, NdarrayWrapper<f64, ndarray::Ix2>> =
            JITEnabledArray::new(wrapped);

        assert!(jitarray.supports_jit());

        let expression = "x + y";
        let jit_function = jitarray.compile(expression).unwrap();

        assert_eq!(jit_function.source(), expression);

        let info = jitarray.jit_info();
        assert_eq!(info.get("supports_jit").unwrap(), "true");

        let jit_box: Box<dyn ArrayProtocol> = Box::new(jitarray);
        let jit_clone = jit_box.clone();

        assert_eq!(jit_clone.shape(), &[10, 5]);
    }

    #[test]
    fn example_cloning_array_protocol_objects() {
        let array = Array2::<f64>::ones((10, 5));
        let config = GPUConfig::default();
        let gpu_array = GPUNdarray::new(array.clone(), config);

        let boxed: Box<dyn ArrayProtocol> = Box::new(gpu_array);
        let cloned = boxed.clone();

        assert_eq!(cloned.shape(), &[10, 5]);

        let config = DistributedConfig {
            chunks: 3,
            balance: true,
            strategy: DistributionStrategy::RowWise,
            backend: DistributedBackend::Threaded,
        };
        let dist_array = DistributedNdarray::from_array(&array, config);

        let boxed: Box<dyn ArrayProtocol> = Box::new(dist_array);
        let cloned = boxed.clone();

        assert_eq!(cloned.shape(), &[10, 5]);
    }

    #[test]
    fn example_array_interoperability() {
        init();

        let cpu_array = Array2::<f64>::ones((5, 5));

        let gpu_config = GPUConfig {
            backend: GPUBackend::CUDA,
            device_id: 0,
            async_ops: false,
            mixed_precision: false,
            memory_fraction: 0.9,
        };
        let gpu_array = GPUNdarray::new(cpu_array.clone(), gpu_config);

        let dist_config = DistributedConfig {
            chunks: 2,
            balance: true,
            strategy: DistributionStrategy::RowWise,
            backend: DistributedBackend::Threaded,
        };
        let dist_array = DistributedNdarray::from_array(&cpu_array, dist_config);

        let gpu_wrapper: Box<dyn ArrayProtocol> = Box::new(gpu_array);
        let dist_wrapper: Box<dyn ArrayProtocol> = Box::new(dist_array);

        let gpu_clone = gpu_wrapper.clone();
        let dist_clone = dist_wrapper.clone();

        assert_eq!(gpu_clone.shape(), &[5, 5]);
        assert_eq!(dist_clone.shape(), &[5, 5]);
    }

    #[test]
    fn example_custom_array_type() {
        use std::sync::Arc;

        struct MyCustomArray<T> {
            data: Vec<T>,
            shape: Vec<usize>,
        }

        impl<T: Clone + 'static> MyCustomArray<T> {
            fn new(data: Vec<T>, shape: Vec<usize>) -> Self {
                Self { data, shape }
            }
        }

        impl<T: Clone + Send + Sync + 'static> ArrayProtocol for MyCustomArray<T> {
            fn array_function(
                &self,
                func: &ArrayFunction,
                _types: &[TypeId],
                _args: &[Box<dyn Any>],
                _kwargs: &HashMap<String, Box<dyn Any>>,
            ) -> Result<Box<dyn Any>, NotImplemented> {
                if func.name == "scirs2::example::custom_sum" {
                    match std::any::TypeId::of::<T>() {
                        tid if tid == std::any::TypeId::of::<f64>() => {
                            let f64_data = unsafe {
                                std::slice::from_raw_parts(
                                    self.data.as_ptr() as *const f64,
                                    self.data.len(),
                                )
                            };
                            let sum = f64_data.iter().sum::<f64>();
                            Ok(Box::new(sum))
                        }
                        tid if tid == std::any::TypeId::of::<f32>() => {
                            let f32_data = unsafe {
                                std::slice::from_raw_parts(
                                    self.data.as_ptr() as *const f32,
                                    self.data.len(),
                                )
                            };
                            let sum = f32_data.iter().sum::<f32>();
                            Ok(Box::new(sum))
                        }
                        _ => Err(NotImplemented),
                    }
                } else {
                    Err(NotImplemented)
                }
            }

            fn as_any(&self) -> &dyn Any {
                self
            }

            fn shape(&self) -> &[usize] {
                &self.shape
            }

            fn box_clone(&self) -> Box<dyn ArrayProtocol> {
                Box::new(MyCustomArray {
                    data: self.data.clone(),
                    shape: self.shape.clone(),
                })
            }
        }

        let custom_array = MyCustomArray::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);

        let boxed: Box<dyn ArrayProtocol> = Box::new(custom_array);
        let cloned = boxed.clone();

        assert_eq!(cloned.shape(), &[2, 2]);

        let func = ArrayFunction {
            name: "scirs2::example::custom_sum",
            implementation: Arc::new(move |_args, _kwargs| {
                Ok(Box::new(42.0) as Box<dyn Any>)
            }),
        };

        let result = cloned.array_function(
            &func,
            &[std::any::TypeId::of::<f64>()],
            &[],
            &HashMap::new(),
        );

        assert!(result.is_ok());
        if let Ok(value) = result {
            let sum = *value.downcast_ref::<f64>().unwrap();
            assert_eq!(sum, 10.0);
        }
    }
}

impl Clone for Box<dyn JITFunction> {
    fn clone(&self) -> Self {
        self.clone_box()
    }
}