1use std::any::{Any, TypeId};
19use std::collections::HashMap;
20use std::fmt::Debug;
21
22use crate::array_protocol::{ArrayFunction, ArrayProtocol, DistributedArray, NotImplemented};
23use crate::error::CoreResult;
24use ::ndarray::{Array, Dimension};
25
/// Configuration describing how an array is split and distributed across nodes.
#[derive(Debug, Clone, Default)]
pub struct DistributedConfig {
    /// Number of chunks to split the array into.
    pub chunks: usize,

    /// Whether chunks should be balanced across nodes.
    /// NOTE(review): not read anywhere in this file — confirm intended semantics.
    pub balance: bool,

    /// Partitioning strategy (row-wise, column-wise, blocks, or auto).
    pub strategy: DistributionStrategy,

    /// Communication backend used for distribution (threaded, MPI, or TCP).
    pub backend: DistributedBackend,
}
41
/// Strategy used to partition an array across nodes.
///
/// The manual `Default` impl was replaced by `#[derive(Default)]` with the
/// `#[default]` variant attribute (stable since Rust 1.62) — same behavior,
/// less code.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DistributionStrategy {
    /// Split along the first axis (rows).
    RowWise,

    /// Split along the second axis (columns).
    ColumnWise,

    /// Split into rectangular blocks.
    Blocks,

    /// Let the implementation choose a strategy.
    #[default]
    Auto,
}
63
/// Communication backend used to coordinate distributed chunks.
///
/// The manual `Default` impl was replaced by `#[derive(Default)]` with the
/// `#[default]` variant attribute (stable since Rust 1.62) — same behavior,
/// less code.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DistributedBackend {
    /// In-process threads (the default; no external services required).
    #[default]
    Threaded,

    /// Message Passing Interface.
    MPI,

    /// Plain TCP sockets.
    TCP,
}
82
/// One chunk of a distributed array, owned by a single node.
#[derive(Debug, Clone)]
pub struct ArrayChunk<T, D>
where
    T: Clone + 'static,
    D: Dimension + 'static,
{
    /// The chunk's local array data.
    pub data: Array<T, D>,

    /// Position of this chunk within the global array.
    /// NOTE(review): every constructor in this file sets this to `vec![0]` —
    /// confirm the intended indexing scheme.
    pub global_index: Vec<usize>,

    /// Identifier of the node that owns this chunk.
    pub nodeid: usize,
}
99
/// An ndarray logically partitioned into chunks that may live on different nodes.
///
/// The chunk list, shape, and id are private; read access goes through the
/// inherent accessors (`chunks()`, `shape()`, `num_chunks()`).
pub struct DistributedNdarray<T, D>
where
    T: Clone + 'static,
    D: Dimension + 'static,
{
    /// Distribution configuration this array was built with.
    pub config: DistributedConfig,

    /// The chunks making up the array.
    chunks: Vec<ArrayChunk<T, D>>,

    /// Logical global shape of the array.
    shape: Vec<usize>,

    /// Unique identifier of the form `uuid_<v4>`, assigned at construction.
    id: String,
}
118
impl<T, D> Debug for DistributedNdarray<T, D>
where
    T: Clone + Debug + 'static,
    D: Dimension + Debug + 'static,
{
    /// Compact debug output: shows the chunk *count* rather than the chunk
    /// contents, which may be arbitrarily large.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DistributedNdarray")
            .field("config", &self.config)
            .field("chunks", &self.chunks.len())
            .field("shape", &self.shape)
            .field("id", &self.id)
            .finish()
    }
}
133
134impl<T, D> DistributedNdarray<T, D>
135where
136 T: Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T> + Default,
137 D: Dimension + Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
138{
139 #[must_use]
141 pub fn new(
142 chunks: Vec<ArrayChunk<T, D>>,
143 shape: Vec<usize>,
144 config: DistributedConfig,
145 ) -> Self {
146 let uuid = uuid::Uuid::new_v4();
147 let id = format!("uuid_{uuid}");
148 Self {
149 config,
150 chunks,
151 shape,
152 id,
153 }
154 }
155
156 #[must_use]
158 pub fn from_array(array: &Array<T, D>, config: DistributedConfig) -> Self
159 where
160 T: Clone,
161 {
162 let shape = array.shape().to_vec();
166 let total_elements = array.len();
167 let _chunk_size = total_elements.div_ceil(config.chunks);
168
169 let mut chunks = Vec::new();
171
172 for i in 0..config.chunks {
175 let chunk_data = array.clone();
178
179 chunks.push(ArrayChunk {
180 data: chunk_data,
181 global_index: vec![0],
182 nodeid: i % 3, });
184 }
185
186 Self::new(chunks, shape, config)
187 }
188
189 #[must_use]
191 pub fn num_chunks(&self) -> usize {
192 self.chunks.len()
193 }
194
195 #[must_use]
197 pub fn shape(&self) -> &[usize] {
198 &self.shape
199 }
200
201 #[must_use]
203 pub fn chunks(&self) -> &[ArrayChunk<T, D>] {
204 &self.chunks
205 }
206
207 pub fn to_array(&self) -> CoreResult<Array<T, crate::ndarray::IxDyn>>
215 where
216 T: Clone + Default + num_traits::One,
217 {
218 let result = Array::<T, crate::ndarray::IxDyn>::ones(crate::ndarray::IxDyn(&self.shape));
220
221 Ok(result)
227 }
228
229 #[must_use]
231 pub fn map<F, R>(&self, f: F) -> Vec<R>
232 where
233 F: Fn(&ArrayChunk<T, D>) -> R + Send + Sync,
234 R: Send + 'static,
235 {
236 self.chunks.iter().map(f).collect()
239 }
240
241 #[must_use]
247 pub fn map_reduce<F, R, G>(&self, map_fn: F, reducefn: G) -> R
248 where
249 F: Fn(&ArrayChunk<T, D>) -> R + Send + Sync,
250 G: Fn(R, R) -> R + Send + Sync,
251 R: Send + Clone + 'static,
252 {
253 let results = self.map(map_fn);
255
256 results
259 .into_iter()
260 .reduce(reducefn)
261 .expect("Operation failed")
262 }
263}
264
/// Array-protocol dispatch for distributed arrays.
///
/// Recognizes a small set of `scirs2::array_protocol::operations::*`
/// functions; anything else returns `NotImplemented` so the caller can fall
/// back to another implementation.
impl<T, D> ArrayProtocol for DistributedNdarray<T, D>
where
    T: Clone
        + Send
        + Sync
        + 'static
        + num_traits::Zero
        + std::ops::Div<f64, Output = T>
        + Default
        + std::ops::Add<Output = T>
        + std::ops::Mul<Output = T>,
    D: Dimension + Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
{
    fn array_function(
        &self,
        func: &ArrayFunction,
        _types: &[TypeId],
        args: &[Box<dyn Any>],
        kwargs: &HashMap<String, Box<dyn Any>>,
    ) -> Result<Box<dyn Any>, NotImplemented> {
        match func.name {
            "scirs2::array_protocol::operations::sum" => {
                // Optional `axis` kwarg selects an axis-wise sum.
                let axis = kwargs.get("axis").and_then(|a| a.downcast_ref::<usize>());

                if let Some(&ax) = axis {
                    // TODO(placeholder): the axis-wise sum only looks at the
                    // first chunk instead of combining all chunks.
                    let dummy_array = self.chunks[0].data.clone();
                    let sum_array = dummy_array.sum_axis(crate::ndarray::Axis(ax));

                    Ok(Box::new(super::NdarrayWrapper::new(sum_array)))
                } else {
                    // Full sum: sum each chunk locally, then add partial sums.
                    let sum = self.map_reduce(|chunk| chunk.data.sum(), |a, b| a + b);
                    Ok(Box::new(sum))
                }
            }
            "scirs2::array_protocol::operations::mean" => {
                // Sum over all chunks, divided by the global element count.
                // NOTE(review): `from_array` currently duplicates the full
                // array into each chunk, which would inflate this sum while
                // the divisor stays the global count — confirm once real
                // partitioning lands.
                let sum = self.map_reduce(|chunk| chunk.data.sum(), |a, b| a + b);

                #[allow(clippy::cast_precision_loss)]
                let count = self.shape.iter().product::<usize>() as f64;

                let mean = sum / count;

                Ok(Box::new(mean))
            }
            "scirs2::array_protocol::operations::add" => {
                // Element-wise addition; `args[1]` must be another
                // DistributedNdarray with the same shape.
                if args.len() < 2 {
                    return Err(NotImplemented);
                }

                if let Some(other) = args[1].downcast_ref::<Self>() {
                    if self.shape() != other.shape() {
                        return Err(NotImplemented);
                    }

                    // Add chunk-by-chunk, assuming both arrays are chunked
                    // identically (chunks are zipped pairwise).
                    let mut new_chunks = Vec::with_capacity(self.chunks.len());

                    for (self_chunk, other_chunk) in self.chunks.iter().zip(other.chunks.iter()) {
                        let result_data = &self_chunk.data + &other_chunk.data;
                        new_chunks.push(ArrayChunk {
                            data: result_data,
                            global_index: self_chunk.global_index.clone(),
                            nodeid: self_chunk.nodeid,
                        });
                    }

                    let result = Self::new(new_chunks, self.shape.clone(), self.config.clone());

                    return Ok(Box::new(result));
                }

                // The other operand is not a DistributedNdarray of this type.
                Err(NotImplemented)
            }
            "scirs2::array_protocol::operations::multiply" => {
                // Element-wise multiplication, mirroring the `add` branch.
                if args.len() < 2 {
                    return Err(NotImplemented);
                }

                if let Some(other) = args[1].downcast_ref::<Self>() {
                    if self.shape() != other.shape() {
                        return Err(NotImplemented);
                    }

                    let mut new_chunks = Vec::with_capacity(self.chunks.len());

                    for (self_chunk, other_chunk) in self.chunks.iter().zip(other.chunks.iter()) {
                        let result_data = &self_chunk.data * &other_chunk.data;
                        new_chunks.push(ArrayChunk {
                            data: result_data,
                            global_index: self_chunk.global_index.clone(),
                            nodeid: self_chunk.nodeid,
                        });
                    }

                    let result = Self::new(new_chunks, self.shape.clone(), self.config.clone());

                    return Ok(Box::new(result));
                }

                Err(NotImplemented)
            }
            "scirs2::array_protocol::operations::matmul" => {
                // 2-D matrix multiplication with compatible inner dimensions.
                if args.len() < 2 {
                    return Err(NotImplemented);
                }

                if self.shape.len() != 2 {
                    return Err(NotImplemented);
                }

                if let Some(other) = args[1].downcast_ref::<Self>() {
                    // NOTE(review): `self.shape.len() != 2` is re-checked here
                    // although it was already rejected above.
                    if self.shape.len() != 2
                        || other.shape.len() != 2
                        || self.shape[1] != other.shape[0]
                    {
                        return Err(NotImplemented);
                    }

                    let resultshape = vec![self.shape[0], other.shape[1]];

                    // TODO(placeholder): no multiplication is performed — the
                    // result is a single zero-filled chunk of the right shape.
                    let dummyshape = crate::ndarray::IxDyn(&resultshape);
                    let dummy_array = Array::<T, crate::ndarray::IxDyn>::zeros(dummyshape);

                    let chunk = ArrayChunk {
                        data: dummy_array,
                        global_index: vec![0],
                        nodeid: 0,
                    };

                    let result =
                        DistributedNdarray::new(vec![chunk], resultshape, self.config.clone());

                    return Ok(Box::new(result));
                }

                Err(NotImplemented)
            }
            "scirs2::array_protocol::operations::transpose" => {
                // 2-D transpose only.
                if self.shape.len() != 2 {
                    return Err(NotImplemented);
                }

                let transposedshape = vec![self.shape[1], self.shape[0]];

                // TODO(placeholder): data is not transposed — the result is a
                // zero-filled chunk with the swapped shape.
                let dummyshape = crate::ndarray::IxDyn(&transposedshape);
                let dummy_array = Array::<T, crate::ndarray::IxDyn>::zeros(dummyshape);

                let chunk = ArrayChunk {
                    data: dummy_array,
                    global_index: vec![0],
                    nodeid: 0,
                };

                let result =
                    DistributedNdarray::new(vec![chunk], transposedshape, self.config.clone());

                Ok(Box::new(result))
            }
            "scirs2::array_protocol::operations::reshape" => {
                // Requires a `shape` kwarg of type Vec<usize> whose element
                // count matches the current array's.
                if let Some(shape) = kwargs
                    .get("shape")
                    .and_then(|s| s.downcast_ref::<Vec<usize>>())
                {
                    let old_size: usize = self.shape.iter().product();
                    let new_size: usize = shape.iter().product();

                    if old_size != new_size {
                        return Err(NotImplemented);
                    }

                    // TODO(placeholder): data is not reshaped — the result is
                    // a zero-filled chunk with the requested shape.
                    let dummyshape = crate::ndarray::IxDyn(shape);
                    let dummy_array = Array::<T, crate::ndarray::IxDyn>::zeros(dummyshape);

                    let chunk = ArrayChunk {
                        data: dummy_array,
                        global_index: vec![0],
                        nodeid: 0,
                    };

                    let result =
                        DistributedNdarray::new(vec![chunk], shape.clone(), self.config.clone());

                    return Ok(Box::new(result));
                }

                Err(NotImplemented)
            }
            // Unknown operation: let the protocol machinery fall back.
            _ => Err(NotImplemented),
        }
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn shape(&self) -> &[usize] {
        &self.shape
    }

    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
        Box::new(Self {
            config: self.config.clone(),
            chunks: self.chunks.clone(),
            shape: self.shape.clone(),
            id: self.id.clone(),
        })
    }
}
520
521impl<T, D> DistributedArray for DistributedNdarray<T, D>
522where
523 T: Clone
524 + Send
525 + Sync
526 + 'static
527 + num_traits::Zero
528 + std::ops::Div<f64, Output = T>
529 + Default
530 + num_traits::One,
531 D: Dimension + Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
532{
533 fn distribution_info(&self) -> HashMap<String, String> {
534 let mut info = HashMap::new();
535 info.insert("type".to_string(), "distributed_ndarray".to_string());
536 info.insert("chunks".to_string(), self.chunks.len().to_string());
537 info.insert("shape".to_string(), format!("{:?}", self.shape));
538 info.insert("id".to_string(), self.id.clone());
539 info.insert(
540 "strategy".to_string(),
541 format!("{:?}", self.config.strategy),
542 );
543 info.insert("backend".to_string(), format!("{:?}", self.config.backend));
544 info
545 }
546
547 fn gather(&self) -> CoreResult<Box<dyn ArrayProtocol>>
550 where
551 D: crate::ndarray::RemoveAxis,
552 T: Default + Clone + num_traits::One,
553 {
554 let array_dyn = self.to_array()?;
557
558 Ok(Box::new(super::NdarrayWrapper::new(array_dyn)))
560 }
561
562 fn scatter(&self, chunks: usize) -> CoreResult<Box<dyn DistributedArray>> {
565 let mut config = self.config.clone();
570 config.chunks = chunks;
571
572 let new_dist_array = Self {
575 config,
576 chunks: self.chunks.clone(),
577 shape: self.shape.clone(),
578 id: {
579 let uuid = uuid::Uuid::new_v4();
580 format!("uuid_{uuid}")
581 },
582 };
583
584 Ok(Box::new(new_dist_array))
585 }
586
587 fn is_distributed(&self) -> bool {
588 true
589 }
590}
591
#[cfg(test)]
mod tests {
    use super::*;
    use ::ndarray::Array2;

    // NOTE(review): these tests pin the current *placeholder* behavior of
    // `from_array` (each chunk is a full clone of the source array) and
    // `to_array` (returns ones). They will need updating when real
    // partitioning/gathering is implemented.

    #[test]
    fn test_distributed_ndarray_creation() {
        let array = Array2::<f64>::ones((10, 5));
        let config = DistributedConfig {
            chunks: 3,
            ..Default::default()
        };

        let dist_array = DistributedNdarray::from_array(&array, config);

        assert_eq!(dist_array.num_chunks(), 3);
        assert_eq!(dist_array.shape(), &[10, 5]);

        // Each chunk duplicates the whole array, so the combined length is
        // len * num_chunks rather than len.
        let expected_total_elements = array.len() * dist_array.num_chunks();

        let total_elements: usize = dist_array
            .chunks()
            .iter()
            .map(|chunk| chunk.data.len())
            .sum();
        assert_eq!(total_elements, expected_total_elements);
    }

    #[test]
    fn test_distributed_ndarray_to_array() {
        let array = Array2::<f64>::ones((10, 5));
        let config = DistributedConfig {
            chunks: 3,
            ..Default::default()
        };

        let dist_array = DistributedNdarray::from_array(&array, config);

        let result = dist_array.to_array().expect("Operation failed");

        // Only the shape is checked: `to_array` is a placeholder returning
        // ones, so an element-wise comparison would not be meaningful.
        assert_eq!(result.shape(), array.shape());

    }

    #[test]
    fn test_distributed_ndarray_map_reduce() {
        let array = Array2::<f64>::ones((10, 5));
        let config = DistributedConfig {
            chunks: 3,
            ..Default::default()
        };

        let dist_array = DistributedNdarray::from_array(&array, config);

        // Every chunk holds a full copy of `array`, so the distributed sum is
        // the plain sum multiplied by the chunk count.
        let expected_sum = array.sum() * (dist_array.num_chunks() as f64);

        let sum = dist_array.map_reduce(|chunk| chunk.data.sum(), |a, b| a + b);

        assert_eq!(sum, expected_sum);
    }
}