1use std::any::{Any, TypeId};
29use std::collections::HashMap;
30use std::fmt::{Debug, Display};
31use std::marker::PhantomData;
32use std::sync::{Arc, LazyLock, RwLock};
33use std::time::{Duration, Instant};
34
35use crate::error::{CoreError, CoreResult, ErrorContext};
36
37mod distributed_impl;
39mod gpu_impl;
40mod jit_impl;
41mod operations;
42
43pub use crate::array_function_dispatch;
45
46pub mod auto_device;
48pub mod distributed_training;
49pub mod grad;
50pub mod mixed_precision;
51pub mod ml_ops;
52pub mod neural;
53#[cfg(feature = "serialization")]
54pub mod serialization;
55pub mod training;
56
57pub trait ArrayProtocol: Any + Send + Sync {
61 fn array_function(
75 &self,
76 func: &ArrayFunction,
77 types: &[TypeId],
78 args: &[Box<dyn Any>],
79 kwargs: &HashMap<String, Box<dyn Any>>,
80 ) -> Result<Box<dyn Any>, NotImplemented>;
81
82 #[must_use]
84 fn as_any(&self) -> &dyn Any;
85
86 #[must_use]
88 fn shape(&self) -> &[usize] {
89 &[]
90 }
91
92 #[must_use]
94 fn dtype(&self) -> TypeId {
95 TypeId::of::<f64>()
96 }
97
98 #[must_use]
100 fn box_clone(&self) -> Box<dyn ArrayProtocol>;
101}
102
103impl Clone for Box<dyn ArrayProtocol> {
105 fn clone(&self) -> Self {
106 self.box_clone()
107 }
108}
109
/// Marker returned by [`ArrayProtocol::array_function`] when an
/// implementation declines to handle a call.
#[derive(Debug, Clone, Copy)]
pub struct NotImplemented;

impl Display for NotImplemented {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("NotImplemented")
    }
}
127
128pub type ArrayFunctionImpl = dyn Fn(&[Box<dyn Any>], &HashMap<String, Box<dyn Any>>) -> CoreResult<Box<dyn Any>>
130 + Send
131 + Sync;
132
133#[derive(Clone)]
135pub struct ArrayFunction {
136 pub name: &'static str,
138
139 pub implementation: Arc<ArrayFunctionImpl>,
141}
142
143impl Debug for ArrayFunction {
144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 f.debug_struct("ArrayFunction")
146 .field("name", &self.name)
147 .finish_non_exhaustive()
148 }
149}
150
151impl PartialEq for ArrayFunction {
152 fn eq(&self, other: &Self) -> bool {
153 self.name == other.name
154 }
155}
156
157impl Eq for ArrayFunction {}
158
159impl std::hash::Hash for ArrayFunction {
160 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
161 self.name.hash(state);
162 }
163}
164
165impl ArrayFunction {
166 #[must_use]
168 pub fn new(name: &'static str) -> Self {
169 Self {
170 name,
171 implementation: Arc::new(|_args, _kwargs| {
173 Err(CoreError::NotImplementedError(ErrorContext::new(
174 "Function not implemented".to_string(),
175 )))
176 }),
177 }
178 }
179}
180
/// One remembered dispatch decision held in the registry's dispatch cache.
#[derive(Debug, Clone)]
pub struct DispatchCacheEntry {
    /// Argument type signature the entry was recorded for.
    /// Not read by the registry methods.
    #[allow(dead_code)]
    type_signature: Vec<TypeId>,
    /// Implementation type chosen for that signature.
    /// Not read by the registry methods.
    #[allow(dead_code)]
    preferred_impl_type: TypeId,
    /// Insertion time; compared against the registry's TTL.
    timestamp: Instant,
    /// Number of recorded hits; low counts are evicted first.
    hit_count: u64,
}
195
196#[derive(Debug)]
198pub struct ArrayFunctionRegistry {
199 functions: HashMap<&'static str, ArrayFunction>,
201 dispatch_cache: HashMap<(&'static str, Vec<TypeId>), DispatchCacheEntry>,
203 max_cache_size: usize,
205 cache_ttl: Duration,
207}
208
209impl Default for ArrayFunctionRegistry {
210 fn default() -> Self {
211 Self {
212 functions: HashMap::new(),
213 dispatch_cache: HashMap::new(),
214 max_cache_size: 1000, cache_ttl: Duration::from_secs(300), }
217 }
218}
219
220impl ArrayFunctionRegistry {
221 #[must_use]
223 pub fn global() -> &'static RwLock<Self> {
224 static REGISTRY: LazyLock<RwLock<ArrayFunctionRegistry>> =
225 LazyLock::new(|| RwLock::new(ArrayFunctionRegistry::default()));
226 ®ISTRY
227 }
228
229 pub fn register(&mut self, func: ArrayFunction) {
231 self.functions.insert(func.name, func);
232 }
233
234 #[must_use]
236 #[allow(dead_code)]
237 pub fn get(&self, name: &str) -> Option<&ArrayFunction> {
238 self.functions.get(name)
239 }
240
241 #[must_use]
243 pub fn all_functions(&self) -> Vec<&ArrayFunction> {
244 self.functions.values().collect()
245 }
246
247 #[must_use]
249 pub fn get_cached_dispatch(
250 &self,
251 funcname: &'static str,
252 types: &[TypeId],
253 ) -> Option<&DispatchCacheEntry> {
254 let key = (funcname, types.to_vec());
255 if let Some(entry) = self.dispatch_cache.get(&key) {
256 if entry.timestamp.elapsed() < self.cache_ttl {
258 return Some(entry);
259 }
260 }
261 None
262 }
263
264 pub fn cache_dispatch(
266 &mut self,
267 funcname: &'static str,
268 types: Vec<TypeId>,
269 impl_type: TypeId,
270 ) {
271 if self.dispatch_cache.len() >= self.max_cache_size {
273 self.cleanup_cache();
274 }
275
276 let key = (funcname, types.clone());
277 let entry = DispatchCacheEntry {
278 type_signature: types,
279 preferred_impl_type: impl_type,
280 timestamp: Instant::now(),
281 hit_count: 0,
282 };
283 self.dispatch_cache.insert(key, entry);
284 }
285
286 pub fn update_cache_hit(&mut self, funcname: &'static str, types: &[TypeId]) {
288 let key = (funcname, types.to_vec());
289 if let Some(entry) = self.dispatch_cache.get_mut(&key) {
290 entry.hit_count += 1;
291 }
292 }
293
294 fn cleanup_cache(&mut self) {
296 let now = Instant::now();
297 self.dispatch_cache
298 .retain(|_, entry| now.duration_since(entry.timestamp) < self.cache_ttl);
299
300 if self.dispatch_cache.len() >= self.max_cache_size {
302 let mut entries: Vec<_> = self
303 .dispatch_cache
304 .iter()
305 .map(|(k, v)| (k.clone(), v.hit_count))
306 .collect();
307 entries.sort_by_key(|(_, hit_count)| *hit_count);
308
309 let to_remove = self.dispatch_cache.len() / 4;
311 let keys_to_remove: Vec<_> = entries
312 .iter()
313 .take(to_remove)
314 .map(|(key, _)| key.clone())
315 .collect();
316 for key in keys_to_remove {
317 self.dispatch_cache.remove(&key);
318 }
319 }
320 }
321
322 #[must_use]
324 pub fn cache_stats(&self) -> HashMap<String, u64> {
325 let mut stats = HashMap::new();
326 stats.insert("cache_size".to_string(), self.dispatch_cache.len() as u64);
327 stats.insert("max_cache_size".to_string(), self.max_cache_size as u64);
328
329 let total_hits: u64 = self.dispatch_cache.values().map(|e| e.hit_count).sum();
330 stats.insert("total_hits".to_string(), total_hits);
331
332 stats
333 }
334}
335
336#[allow(dead_code)]
341pub fn get_implementing_args(args: &[Box<dyn Any>]) -> Vec<(TypeId, &dyn ArrayProtocol)> {
342 if args.is_empty() {
343 return Vec::new();
344 }
345
346 let mut implementing_args = Vec::with_capacity(args.len());
348
349 for arg in args {
350 if let Some(array_protocol_obj) = arg.downcast_ref::<Box<dyn ArrayProtocol>>() {
351 let type_id = (**array_protocol_obj).type_id();
352 implementing_args.push((type_id, &**array_protocol_obj));
353 }
354 }
355
356 implementing_args.sort_by_key(|&_type_id_| {
359 use std::hash::{Hash, Hasher};
361 let mut hasher = std::collections::hash_map::DefaultHasher::new();
362 std::any::TypeId::of::<i32>().hash(&mut hasher);
363 hasher.finish()
364 });
365
366 implementing_args
367}
368
369#[allow(dead_code)]
380pub fn array_function_dispatch(
381 func: &ArrayFunction,
382 args: &[Box<dyn Any>],
383 kwargs: &HashMap<String, Box<dyn Any>>,
384) -> CoreResult<Box<dyn Any>> {
385 if args.is_empty() {
387 return (func.implementation)(args, kwargs);
388 }
389
390 let implementing_args = get_implementing_args(args);
392
393 if implementing_args.is_empty() {
394 return (func.implementation)(args, kwargs);
396 }
397
398 if implementing_args.len() == 1 {
400 let (type_id, array_protocol_obj) = implementing_args[0];
401 let types = [type_id];
402 match array_protocol_obj.array_function(func, &types, args, kwargs) {
403 Ok(result) => return Ok(result),
404 Err(NotImplemented) => {
405 return Err(CoreError::DispatchError(ErrorContext::new(format!(
406 "No implementation found for {} with type {:?}",
407 func.name, type_id
408 ))));
409 }
410 }
411 }
412
413 let mut unique_types = Vec::with_capacity(implementing_args.len());
415 let mut seen_types = std::collections::HashSet::with_capacity(implementing_args.len());
416
417 for &(type_id, _) in &implementing_args {
418 if seen_types.insert(type_id) {
419 unique_types.push(type_id);
420 }
421 }
422
423 for (_, array_protocol_obj) in implementing_args {
425 if let Ok(result) = array_protocol_obj.array_function(func, &unique_types, args, kwargs) {
426 return Ok(result);
427 }
428 }
429
430 Err(CoreError::DispatchError(ErrorContext::new(format!(
432 "No implementation found for {} with {} argument types: {:?}",
433 func.name,
434 unique_types.len(),
435 unique_types
436 ))))
437}
438
439pub struct ArrayFunctionDecorator<F> {
443 function: F,
444 name: &'static str,
445}
446
447impl<F> ArrayFunctionDecorator<F>
448where
449 F: Send + Sync + 'static,
450{
451 #[must_use]
453 pub fn new(function: F, name: &'static str) -> Self {
454 Self { function, name }
455 }
456
457 pub fn register(self) -> F {
459 let implementation = Arc::new(
460 move |_args: &[Box<dyn Any>], _kwargs: &HashMap<String, Box<dyn Any>>| {
461 Err(CoreError::NotImplementedError(ErrorContext::new(
466 "ArrayFunctionDecorator: Type conversion in array_function_dispatch is not implemented yet".to_string()
467 )))
468 },
469 );
470
471 let func = ArrayFunction {
472 name: self.name,
473 implementation,
474 };
475
476 let registry = ArrayFunctionRegistry::global();
478 if let Ok(mut registry) = registry.write() {
479 registry.register(func);
480 } else {
481 eprintln!("Warning: Failed to acquire write lock on ArrayFunctionRegistry, skipping function registration");
482 }
484
485 self.function
486 }
487}
488
489pub trait GPUArray: ArrayProtocol {
491 fn to_gpu(&self) -> CoreResult<Box<dyn GPUArray>>;
493
494 fn to_cpu(&self) -> CoreResult<Box<dyn ArrayProtocol>>;
496
497 #[must_use]
499 fn is_on_gpu(&self) -> bool;
500
501 #[must_use]
503 fn device_info(&self) -> HashMap<String, String>;
504}
505
506pub trait DistributedArray: ArrayProtocol {
508 #[must_use]
510 fn distribution_info(&self) -> HashMap<String, String>;
511
512 fn gather(&self) -> CoreResult<Box<dyn ArrayProtocol>>;
514
515 fn scatter(&self, chunks: usize) -> CoreResult<Box<dyn DistributedArray>>;
517
518 #[must_use]
520 fn is_distributed(&self) -> bool;
521}
522
523pub trait JITArray: ArrayProtocol {
525 fn compile(&self, expression: &str) -> CoreResult<Box<dyn JITFunction>>;
527
528 #[must_use]
530 fn supports_jit(&self) -> bool;
531
532 #[must_use]
534 fn jit_info(&self) -> HashMap<String, String>;
535}
536
537pub trait JITFunction: Send + Sync {
539 fn evaluate(&self, args: &[Box<dyn Any>]) -> CoreResult<Box<dyn Any>>;
541
542 #[must_use]
544 fn source(&self) -> String;
545
546 #[must_use]
548 fn compile_info(&self) -> HashMap<String, String>;
549
550 #[must_use]
552 fn clone_box(&self) -> Box<dyn JITFunction>;
553}
554
555pub trait JITFunctionFactory: Send + Sync {
557 fn create_jit_function(
559 &self,
560 expression: &str,
561 array_typeid: TypeId,
562 ) -> CoreResult<Box<dyn JITFunction>>;
563
564 #[must_use]
566 fn supports_array_type(&self, array_typeid: TypeId) -> bool;
567}
568
569#[derive(Default)]
571pub struct JITFactoryRegistry {
572 factories: Vec<Box<dyn JITFunctionFactory>>,
573}
574
575impl std::fmt::Debug for JITFactoryRegistry {
576 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
577 write!(
578 f,
579 "JITFactoryRegistry {{ factories: {} }}",
580 self.factories.len()
581 )
582 }
583}
584
585impl JITFactoryRegistry {
586 #[must_use]
588 pub fn global() -> &'static RwLock<Self> {
589 static REGISTRY: LazyLock<RwLock<JITFactoryRegistry>> = LazyLock::new(|| {
590 RwLock::new(JITFactoryRegistry {
591 factories: Vec::new(),
592 })
593 });
594 ®ISTRY
595 }
596
597 pub fn register(&mut self, factory: Box<dyn JITFunctionFactory>) {
599 self.factories.push(factory);
600 }
601
602 #[must_use]
604 pub fn get_factory_for_array_type(
605 &self,
606 array_typeid: TypeId,
607 ) -> Option<&dyn JITFunctionFactory> {
608 for factory in &self.factories {
609 if factory.supports_array_type(array_typeid) {
610 return Some(&**factory);
611 }
612 }
613 None
614 }
615}
616
617#[derive(Debug, Clone)]
619pub struct NdarrayWrapper<T, D: crate::ndarray::Dimension> {
620 array: crate::ndarray::Array<T, D>,
621 phantom: PhantomData<(T, D)>,
622}
623
624impl<T, D> NdarrayWrapper<T, D>
625where
626 T: Clone + 'static,
627 D: crate::ndarray::Dimension + 'static,
628{
629 #[must_use]
631 pub fn new(array: crate::ndarray::Array<T, D>) -> Self {
632 Self {
633 array,
634 phantom: PhantomData,
635 }
636 }
637
638 #[must_use]
640 pub const fn as_array(&self) -> &crate::ndarray::Array<T, D> {
641 &self.array
642 }
643
644 #[must_use]
646 pub fn into_array(self) -> crate::ndarray::Array<T, D> {
647 self.array
648 }
649
650 pub fn array_2(&mut self, newarray: crate::ndarray::Array<T, D>) {
652 self.array = newarray;
653 }
654}
655
656impl<T, D> ArrayProtocol for NdarrayWrapper<T, D>
657where
658 T: Clone + Send + Sync + 'static,
659 D: crate::ndarray::Dimension + Send + Sync + 'static,
660{
661 fn array_function(
662 &self,
663 func: &ArrayFunction,
664 _types: &[TypeId],
665 args: &[Box<dyn Any>],
666 kwargs: &HashMap<String, Box<dyn Any>>,
667 ) -> Result<Box<dyn Any>, NotImplemented> {
668 match func.name {
669 "scirs2::array_protocol::operations::add" => {
670 if args.len() < 2 {
672 return Err(NotImplemented);
673 }
674
675 if let Some(other) = args[1].downcast_ref::<NdarrayWrapper<T, D>>() {
676 if let (Some(a), Some(b)) = (
677 self.as_any().downcast_ref::<NdarrayWrapper<T, D>>(),
678 other.as_any().downcast_ref::<NdarrayWrapper<T, D>>(),
679 ) {
680 if TypeId::of::<T>() == TypeId::of::<f64>() {
682 let a_f64 =
683 unsafe { &*(a as *const _ as *const NdarrayWrapper<f64, D>) };
684 let b_f64 =
685 unsafe { &*(b as *const _ as *const NdarrayWrapper<f64, D>) };
686 let result = a_f64.as_array() + b_f64.as_array();
687 return Ok(Box::new(NdarrayWrapper::new(result)));
688 } else if TypeId::of::<T>() == TypeId::of::<f32>() {
689 let a_f32 =
690 unsafe { &*(a as *const _ as *const NdarrayWrapper<f32, D>) };
691 let b_f32 =
692 unsafe { &*(b as *const _ as *const NdarrayWrapper<f32, D>) };
693 let result = a_f32.as_array() + b_f32.as_array();
694 return Ok(Box::new(NdarrayWrapper::new(result)));
695 }
696 }
697 }
698 Err(NotImplemented)
699 }
700 "scirs2::array_protocol::operations::matmul" => {
701 if args.len() < 2 {
703 return Err(NotImplemented);
704 }
705
706 if TypeId::of::<D>() != TypeId::of::<crate::ndarray::Ix2>() {
709 return Err(NotImplemented);
710 }
711
712 if let Some(other) = args[1].downcast_ref::<NdarrayWrapper<T, D>>() {
713 if TypeId::of::<T>() == TypeId::of::<f64>() {
718 let a_f64 = unsafe {
720 &*(self as *const _ as *const NdarrayWrapper<f64, crate::ndarray::Ix2>)
721 };
722 let b_f64 = unsafe {
723 &*(other as *const _ as *const NdarrayWrapper<f64, crate::ndarray::Ix2>)
724 };
725
726 let ashape = a_f64.as_array().shape();
728 let bshape = b_f64.as_array().shape();
729
730 if ashape.len() != 2 || bshape.len() != 2 || ashape[1] != bshape[0] {
731 return Err(NotImplemented);
732 }
733
734 let result = a_f64.as_array().dot(b_f64.as_array());
737 return Ok(Box::new(NdarrayWrapper::new(result)));
738 }
739 else if TypeId::of::<T>() == TypeId::of::<f32>() {
741 let a_f32 = unsafe {
743 &*(self as *const _ as *const NdarrayWrapper<f32, crate::ndarray::Ix2>)
744 };
745 let b_f32 = unsafe {
746 &*(other as *const _ as *const NdarrayWrapper<f32, crate::ndarray::Ix2>)
747 };
748
749 let ashape = a_f32.as_array().shape();
751 let bshape = b_f32.as_array().shape();
752
753 if ashape.len() != 2 || bshape.len() != 2 || ashape[1] != bshape[0] {
754 return Err(NotImplemented);
755 }
756
757 let result = a_f32.as_array().dot(b_f32.as_array());
760 return Ok(Box::new(NdarrayWrapper::new(result)));
761 }
762 }
763 Err(NotImplemented)
765 }
766 "scirs2::array_protocol::operations::transpose" => {
767 if TypeId::of::<T>() == TypeId::of::<f64>() {
769 let a_f64 = unsafe { &*(self as *const _ as *const NdarrayWrapper<f64, D>) };
770 let result = a_f64.as_array().t().to_owned();
771 return Ok(Box::new(NdarrayWrapper::new(result)));
772 } else if TypeId::of::<T>() == TypeId::of::<f32>() {
773 let a_f32 = unsafe { &*(self as *const _ as *const NdarrayWrapper<f32, D>) };
774 let result = a_f32.as_array().t().to_owned();
775 return Ok(Box::new(NdarrayWrapper::new(result)));
776 }
777 Err(NotImplemented)
778 }
779 "scirs2::array_protocol::operations::sum" => {
780 let axis_ref = kwargs.get("axis").and_then(|a| a.downcast_ref::<usize>());
782
783 if TypeId::of::<T>() == TypeId::of::<f64>() {
784 let a_f64 = unsafe { &*(self as *const _ as *const NdarrayWrapper<f64, D>) };
785 match axis_ref {
786 Some(&_ax) => {
787 let result = a_f64.as_array().sum();
790 return Ok(Box::new(result));
791 }
792 None => {
793 let result = a_f64.as_array().sum();
794 return Ok(Box::new(result));
795 }
796 }
797 } else if TypeId::of::<T>() == TypeId::of::<f32>() {
798 let a_f32 = unsafe { &*(self as *const _ as *const NdarrayWrapper<f32, D>) };
799 match axis_ref {
800 Some(&_ax) => {
801 let result = a_f32.as_array().sum();
804 return Ok(Box::new(result));
805 }
806 None => {
807 let result = a_f32.as_array().sum();
808 return Ok(Box::new(result));
809 }
810 }
811 }
812 Err(NotImplemented)
813 }
814 "scirs2::array_protocol::operations::reshape" => {
815 if let Some(shape) = kwargs
817 .get("shape")
818 .and_then(|s| s.downcast_ref::<Vec<usize>>())
819 {
820 if TypeId::of::<T>() == TypeId::of::<f64>() {
821 let a_f64 =
822 unsafe { &*(self as *const _ as *const NdarrayWrapper<f64, D>) };
823 match a_f64
824 .as_array()
825 .clone()
826 .into_shape_with_order(shape.clone())
827 {
828 Ok(result) => return Ok(Box::new(NdarrayWrapper::new(result))),
829 Err(_) => return Err(NotImplemented),
830 }
831 } else if TypeId::of::<T>() == TypeId::of::<f32>() {
832 let a_f32 =
833 unsafe { &*(self as *const _ as *const NdarrayWrapper<f32, D>) };
834 match a_f32
835 .as_array()
836 .clone()
837 .into_shape_with_order(shape.clone())
838 {
839 Ok(result) => return Ok(Box::new(NdarrayWrapper::new(result))),
840 Err(_) => return Err(NotImplemented),
841 }
842 }
843 }
844 Err(NotImplemented)
845 }
846 _ => Err(NotImplemented),
847 }
848 }
849
850 fn as_any(&self) -> &dyn Any {
851 self
852 }
853
854 fn shape(&self) -> &[usize] {
855 self.array.shape()
856 }
857
858 fn dtype(&self) -> TypeId {
859 TypeId::of::<T>()
860 }
861
862 fn box_clone(&self) -> Box<dyn ArrayProtocol> {
863 Box::new(self.clone())
864 }
865}
866
/// Minimal in-memory stand-in for a distributed array.
#[derive(Debug, Clone)]
pub struct MockDistributedArray<T: Clone + 'static> {
    /// One value per simulated chunk.
    chunks: Vec<T>,
    /// Logical shape reported via `ArrayProtocol::shape`.
    shape: Vec<usize>,
}

impl<T: Clone + Send + Sync + 'static> MockDistributedArray<T> {
    /// Builds a mock from its chunk values and logical shape.
    #[must_use]
    pub fn new(chunks: Vec<T>, shape: Vec<usize>) -> Self {
        MockDistributedArray { chunks, shape }
    }
}
883
884impl<T: Clone + Send + Sync + 'static> ArrayProtocol for MockDistributedArray<T> {
885 fn array_function(
886 &self,
887 func: &ArrayFunction,
888 _types: &[TypeId],
889 _args: &[Box<dyn Any>],
890 _kwargs: &HashMap<String, Box<dyn Any>>,
891 ) -> Result<Box<dyn Any>, NotImplemented> {
892 match func.name {
893 "scirs2::mean" => {
894 let result = T::clone(&self.chunks[0]);
899 Ok(Box::new(result))
900 }
901 _ => Err(NotImplemented),
902 }
903 }
904
905 fn as_any(&self) -> &dyn Any {
906 self
907 }
908
909 fn shape(&self) -> &[usize] {
910 &self.shape
911 }
912
913 fn box_clone(&self) -> Box<dyn ArrayProtocol> {
914 Box::new(self.clone())
915 }
916}
917
918impl<T: Clone + Send + Sync + 'static> DistributedArray for MockDistributedArray<T> {
919 fn distribution_info(&self) -> HashMap<String, String> {
920 let mut info = HashMap::new();
921 info.insert("type".to_string(), "mock_distributed".to_string());
922 info.insert("chunks".to_string(), self.chunks.len().to_string());
923 info
924 }
925
926 fn gather(&self) -> CoreResult<Box<dyn ArrayProtocol>> {
927 Ok(Box::new(self.clone()) as Box<dyn ArrayProtocol>)
930 }
931
932 fn scatter(&self, _numchunks: usize) -> CoreResult<Box<dyn DistributedArray>> {
933 Ok(Box::new(self.clone()) as Box<dyn DistributedArray>)
936 }
937
938 fn is_distributed(&self) -> bool {
939 true
940 }
941}
942
/// Minimal in-memory stand-in for a GPU array.
#[derive(Debug, Clone)]
pub struct MockGPUArray<T: Clone + 'static> {
    /// Element data (kept in host memory by the mock).
    data: Vec<T>,
    /// Logical shape reported via `ArrayProtocol::shape`.
    shape: Vec<usize>,
    /// Device label, e.g. `"cuda:0"`.
    device: String,
}

impl<T: Clone + Send + Sync + 'static> MockGPUArray<T> {
    /// Builds a mock from its data, shape and device label.
    #[must_use]
    pub fn new(data: Vec<T>, shape: Vec<usize>, device: String) -> Self {
        MockGPUArray {
            device,
            shape,
            data,
        }
    }
}
962
963impl<T: Clone + Send + Sync + 'static> ArrayProtocol for MockGPUArray<T> {
964 fn array_function(
965 &self,
966 func: &ArrayFunction,
967 _types: &[TypeId],
968 _args: &[Box<dyn Any>],
969 _kwargs: &HashMap<String, Box<dyn Any>>,
970 ) -> Result<Box<dyn Any>, NotImplemented> {
971 match func.name {
972 "scirs2::matmul" => {
973 let result =
978 MockGPUArray::new(self.data.clone(), self.shape.clone(), self.device.clone());
979 Ok(Box::new(result))
980 }
981 _ => Err(NotImplemented),
982 }
983 }
984
985 fn as_any(&self) -> &dyn Any {
986 self
987 }
988
989 fn shape(&self) -> &[usize] {
990 &self.shape
991 }
992
993 fn box_clone(&self) -> Box<dyn ArrayProtocol> {
994 Box::new(self.clone())
995 }
996}
997
998impl<T: Clone + Send + Sync + 'static> GPUArray for MockGPUArray<T> {
999 fn to_gpu(&self) -> CoreResult<Box<dyn GPUArray>> {
1000 Ok(Box::new(self.clone()) as Box<dyn GPUArray>)
1002 }
1003
1004 fn to_cpu(&self) -> CoreResult<Box<dyn ArrayProtocol>> {
1005 Ok(Box::new(self.clone()) as Box<dyn ArrayProtocol>)
1008 }
1009
1010 fn is_on_gpu(&self) -> bool {
1011 true
1012 }
1013
1014 fn device_info(&self) -> HashMap<String, String> {
1015 let mut info = HashMap::new();
1016 info.insert("device".to_string(), self.device.clone());
1017 info.insert("type".to_string(), "mock_gpu".to_string());
1018 info
1019 }
1020}
1021
/// Pairs a function value with the dispatch name it will be registered under.
#[derive(Debug)]
pub struct ArrayProtocolFunction<F> {
    /// The wrapped function.
    func: F,
    /// Dispatch name.
    name: &'static str,
}

impl<F> ArrayProtocolFunction<F> {
    /// Creates a wrapper around `func` named `name`.
    #[must_use]
    pub fn new(func: F, name: &'static str) -> Self {
        Self { name, func }
    }
}
1039
1040impl<F> ArrayProtocolFunction<F>
1041where
1042 F: Clone + Send + Sync + 'static,
1043{
1044 pub fn register(self) -> F {
1046 let implementation = Arc::new(
1047 move |_args: &[Box<dyn Any>], _kwargs: &HashMap<String, Box<dyn Any>>| {
1048 Err(CoreError::NotImplementedError(ErrorContext::new(
1053 "ArrayProtocolFunction: Implementation for array protocol functions is not complete".to_string()
1054 )))
1055 },
1056 );
1057
1058 let array_func = ArrayFunction {
1059 name: self.name,
1060 implementation,
1061 };
1062
1063 if let Ok(mut registry) = ArrayFunctionRegistry::global().write() {
1065 registry.register(array_func);
1066 } else {
1067 eprintln!("Warning: Failed to acquire write lock on ArrayFunctionRegistry during array protocol building, skipping function registration");
1068 }
1070
1071 self.func
1072 }
1073}
1074
/// Defines a function from the given signature and body, and evaluates to
/// that function item.
///
/// Accepted form:
/// `fn NAME<GENERICS>(ARG: TYPE, ...) -> RET { BODY }, FUNCNAME`.
/// The generic list is optional.
///
/// The trailing `$funcname` expression is matched but not used by the
/// expansion. Nothing is registered with the dispatch registry.
#[macro_export]
macro_rules! array_function_def {
    (fn $name:ident $(<$($gen:ident),*>)? ($($arg:ident : $arg_ty:ty),*) -> $ret:ty $body:block, $funcname:expr) => {
        {
            // Define the function locally, then yield it as the block's value.
            fn $name $(<$($gen),*>)? ($($arg : $arg_ty),*) -> $ret $body

            $name
        }
    };
}
1128
1129pub use self::distributed_impl::{
1131 ArrayChunk, DistributedBackend, DistributedConfig, DistributedNdarray, DistributionStrategy,
1132};
1133
1134pub use self::gpu_impl::{
1136 kernels as gpu_kernels, GPUArrayBuilder, GPUBackend, GPUConfig, GPUNdarray,
1137};
1138
1139pub use self::jit_impl::{
1141 CraneliftFunctionFactory, JITBackend, JITConfig, JITEnabledArray, JITFunctionImpl, JITManager,
1142 LLVMFunctionFactory,
1143};
1144
1145pub use self::operations::{
1147 add, apply_elementwise, concatenate, inverse, matmul, multiply, reshape, subtract, sum, svd,
1148 transpose, OperationError,
1149};
1150
1151pub use self::ml_ops::{
1153 activation, batch_norm, conv2d, cross_entropy, dropout, max_pool2d, self_attention,
1154 ActivationFunc,
1155};
1156
1157#[allow(dead_code)]
1163pub fn init() {
1164 let mut jit_manager = JITManager::global().write().expect("Operation failed");
1166 jit_manager.initialize();
1167}
1168
1169pub mod traits {
1171 use super::*;
1172
1173 pub trait StridedArray: ArrayProtocol {
1175 #[must_use]
1177 fn strides(&self) -> Vec<usize>;
1178
1179 #[must_use]
1181 fn is_contiguous(&self) -> bool;
1182
1183 #[must_use]
1185 fn is_fortran_contiguous(&self) -> bool;
1186 }
1187
1188 pub trait ZeroCopyArray: ArrayProtocol {
1190 #[must_use]
1192 fn view(&self) -> Box<dyn ZeroCopyArray>;
1193
1194 #[must_use]
1196 fn view_mut(&mut self) -> Box<dyn ZeroCopyArray>;
1197
1198 #[must_use]
1200 fn is_view(&self) -> bool;
1201 }
1202
1203 pub trait DifferentiableArray: ArrayProtocol {
1205 fn gradient(
1207 &self,
1208 variables: &[Box<dyn DifferentiableArray>],
1209 ) -> Vec<Box<dyn DifferentiableArray>>;
1210
1211 fn set_requiresgrad(&mut self, requiresgrad: bool);
1213
1214 #[must_use]
1216 fn requiresgrad(&self) -> bool;
1217
1218 #[must_use]
1220 fn grad(&self) -> Option<Box<dyn DifferentiableArray>>;
1221 }
1222
1223 pub trait AsyncArray: ArrayProtocol {
1225 fn async_op<F, R>(&self, op: F) -> impl std::future::Future<Output = CoreResult<R>>
1227 where
1228 F: FnOnce(&Self) -> CoreResult<R> + Send + 'static,
1229 R: Send + 'static;
1230
1231 #[must_use]
1233 fn supports_async(&self) -> bool;
1234 }
1235}
1236
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_array_protocol_registry() {
        const NAME: &str = "scirs2::test::test_func";

        let always_42 = Arc::new(
            move |_args: &[Box<dyn Any>], _kwargs: &HashMap<String, Box<dyn Any>>| {
                Ok(Box::new(42.0) as Box<dyn Any>)
            },
        );
        let test_fn = ArrayFunction {
            name: NAME,
            implementation: always_42,
        };

        let global = ArrayFunctionRegistry::global();
        global
            .write()
            .expect("Operation failed")
            .register(test_fn.clone());

        let reader = global.read().expect("Operation failed");
        let found = reader.get(NAME).expect("Operation failed");
        assert_eq!(found.name, NAME);
    }

    #[test]
    fn test_mock_distributed_array() {
        let dist = MockDistributedArray::new(vec![1.0, 2.0, 3.0], vec![3]);
        assert!(dist.is_distributed());

        let details = dist.distribution_info();
        assert_eq!(
            details.get("type").expect("Operation failed"),
            "mock_distributed"
        );
        assert_eq!(details.get("chunks").expect("Operation failed"), "3");
    }

    #[test]
    fn test_mock_gpu_array() {
        let gpu = MockGPUArray::new(vec![1.0, 2.0, 3.0], vec![3], "cuda:0".to_string());
        assert!(gpu.is_on_gpu());

        let details = gpu.device_info();
        assert_eq!(details.get("device").expect("Operation failed"), "cuda:0");
        assert_eq!(details.get("type").expect("Operation failed"), "mock_gpu");
    }

    #[test]
    fn test_box_clone() {
        let ndarray_box: Box<dyn ArrayProtocol> =
            Box::new(NdarrayWrapper::new(crate::ndarray::Array2::<f64>::ones((3, 3))));
        let ndarray_copy = ndarray_box.clone();
        assert_eq!(ndarray_copy.shape(), &[3, 3]);

        let dist_box: Box<dyn ArrayProtocol> =
            Box::new(MockDistributedArray::new(vec![1.0, 2.0, 3.0], vec![3]));
        let dist_copy = dist_box.clone();
        assert_eq!(dist_copy.shape(), &[3]);
    }
}
1314
/// Usage examples for the array protocol, run as tests.
#[cfg(test)]
mod examples {
    use super::*;
    use ::ndarray::Array2;
    use std::any::Any;
    use std::collections::HashMap;

    /// Distributing an array into chunks and gathering it back.
    #[test]
    fn example_distributed_array() {
        let array = Array2::<f64>::ones((10, 5));

        let config = DistributedConfig {
            chunks: 3,
            balance: true,
            strategy: DistributionStrategy::RowWise,
            backend: DistributedBackend::Threaded,
        };

        let dist_array = DistributedNdarray::from_array(&array, config);

        assert_eq!(dist_array.num_chunks(), 3);
        assert_eq!(dist_array.shape(), &[10, 5]);

        let result = dist_array.to_array().expect("Operation failed");

        assert_eq!(result.shape(), array.shape());
    }

    /// Wrapping an array for a GPU backend and inspecting it.
    #[test]
    fn example_gpu_array() {
        let array = Array2::<f64>::ones((10, 5));

        let config = GPUConfig {
            backend: GPUBackend::CUDA,
            device_id: 0,
            async_ops: true,
            mixed_precision: false,
            memory_fraction: 0.9,
        };

        let gpu_array = GPUNdarray::new(array.clone(), config);

        assert_eq!(gpu_array.shape(), &[10, 5]);
        assert!(gpu_array.is_on_gpu());

        let info = gpu_array.device_info();
        assert_eq!(info.get("backend").expect("Operation failed"), "CUDA");

        let gpu_box: Box<dyn ArrayProtocol> = Box::new(gpu_array);
        let gpu_clone = gpu_box.clone();

        assert_eq!(gpu_clone.shape(), &[10, 5]);
    }

    /// Compiling an expression through a JIT-enabled array.
    #[test]
    fn example_jit_array() {
        init();

        let array = Array2::<f64>::ones((10, 5));
        let wrapped = NdarrayWrapper::new(array);

        let jitarray: JITEnabledArray<f64, NdarrayWrapper<f64, crate::ndarray::Ix2>> =
            JITEnabledArray::new(wrapped);

        assert!(jitarray.supports_jit());

        let expression = "x + y";
        let jit_function = jitarray.compile(expression).expect("Operation failed");

        assert_eq!(jit_function.source(), expression);

        let info = jitarray.jit_info();
        assert_eq!(info.get("supports_jit").expect("Operation failed"), "true");

        let jit_box: Box<dyn ArrayProtocol> = Box::new(jitarray);
        let jit_clone = jit_box.clone();

        assert_eq!(jit_clone.shape(), &[10, 5]);
    }

    /// Cloning heterogeneous protocol objects through `Box<dyn ArrayProtocol>`.
    #[test]
    fn example_cloning_array_protocol_objects() {
        let array = Array2::<f64>::ones((10, 5));
        let config = GPUConfig::default();
        let gpu_array = GPUNdarray::new(array.clone(), config);

        let boxed: Box<dyn ArrayProtocol> = Box::new(gpu_array);
        let cloned = boxed.clone();

        assert_eq!(cloned.shape(), &[10, 5]);

        let config = DistributedConfig {
            chunks: 3,
            balance: true,
            strategy: DistributionStrategy::RowWise,
            backend: DistributedBackend::Threaded,
        };
        let dist_array = DistributedNdarray::from_array(&array, config);

        let boxed: Box<dyn ArrayProtocol> = Box::new(dist_array);
        let cloned = boxed.clone();

        assert_eq!(cloned.shape(), &[10, 5]);
    }

    /// GPU and distributed arrays built from the same data coexist as protocol
    /// objects.
    #[test]
    fn example_array_interoperability() {
        init();

        let cpu_array = Array2::<f64>::ones((5, 5));

        let gpu_config = GPUConfig {
            backend: GPUBackend::CUDA,
            device_id: 0,
            async_ops: false,
            mixed_precision: false,
            memory_fraction: 0.9,
        };
        let gpu_array = GPUNdarray::new(cpu_array.clone(), gpu_config);

        let dist_config = DistributedConfig {
            chunks: 2,
            balance: true,
            strategy: DistributionStrategy::RowWise,
            backend: DistributedBackend::Threaded,
        };
        let dist_array = DistributedNdarray::from_array(&cpu_array, dist_config);

        let gpu_wrapper: Box<dyn ArrayProtocol> = Box::new(gpu_array);
        let dist_wrapper: Box<dyn ArrayProtocol> = Box::new(dist_array);

        let gpu_clone = gpu_wrapper.clone();
        let dist_clone = dist_wrapper.clone();

        assert_eq!(gpu_clone.shape(), &[5, 5]);
        assert_eq!(dist_clone.shape(), &[5, 5]);
    }

    /// Implementing `ArrayProtocol` for a user-defined array type.
    #[test]
    fn example_custom_array_type() {
        use std::sync::Arc;

        struct MyCustomArray<T> {
            data: Vec<T>,
            shape: Vec<usize>,
        }

        impl<T: Clone + 'static> MyCustomArray<T> {
            fn new(data: Vec<T>, shape: Vec<usize>) -> Self {
                Self { data, shape }
            }
        }

        impl<T: Clone + Send + Sync + 'static> ArrayProtocol for MyCustomArray<T> {
            fn array_function(
                &self,
                func: &ArrayFunction,
                _types: &[TypeId],
                _args: &[Box<dyn Any>],
                _kwargs: &HashMap<String, Box<dyn Any>>,
            ) -> Result<Box<dyn Any>, NotImplemented> {
                if func.name != "scirs2::example::custom_sum" {
                    return Err(NotImplemented);
                }
                // Specialise on the element type with a safe `Any` downcast:
                // it succeeds only when `T` is exactly the requested type, so
                // no raw-pointer reinterpretation of the buffer is needed.
                let data: &dyn Any = &self.data;
                if let Some(values) = data.downcast_ref::<Vec<f64>>() {
                    return Ok(Box::new(values.iter().sum::<f64>()));
                }
                if let Some(values) = data.downcast_ref::<Vec<f32>>() {
                    return Ok(Box::new(values.iter().sum::<f32>()));
                }
                Err(NotImplemented)
            }

            fn as_any(&self) -> &dyn Any {
                self
            }

            fn shape(&self) -> &[usize] {
                &self.shape
            }

            fn box_clone(&self) -> Box<dyn ArrayProtocol> {
                Box::new(MyCustomArray {
                    data: self.data.clone(),
                    shape: self.shape.clone(),
                })
            }
        }

        let custom_array = MyCustomArray::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);

        let boxed: Box<dyn ArrayProtocol> = Box::new(custom_array);
        let cloned = boxed.clone();

        assert_eq!(cloned.shape(), &[2, 2]);

        let func = ArrayFunction {
            name: "scirs2::example::custom_sum",
            implementation: Arc::new(move |_args, _kwargs| {
                Ok(Box::new(42.0) as Box<dyn Any>)
            }),
        };

        let result = cloned.array_function(
            &func,
            &[std::any::TypeId::of::<f64>()],
            &[],
            &HashMap::new(),
        );

        assert!(result.is_ok());
        if let Ok(value) = result {
            let sum = *value.downcast_ref::<f64>().expect("Operation failed");
            assert_eq!(sum, 10.0);
        }
    }
}
1653impl Clone for Box<dyn JITFunction> {
1655 fn clone(&self) -> Self {
1656 self.clone_box()
1657 }
1658}