1use std::any::{Any, TypeId};
19use std::collections::HashMap;
20use std::fmt::Debug;
21
22use crate::array_protocol::{ArrayFunction, ArrayProtocol, DistributedArray, NotImplemented};
23use crate::error::CoreResult;
24use ndarray::{Array, Dimension};
25
/// Configuration controlling how a `DistributedNdarray` is partitioned
/// and which backend executes it.
#[derive(Debug, Clone, Default)]
pub struct DistributedConfig {
    /// Number of chunks to split an array into.
    /// NOTE(review): the derived `Default` yields 0 here, which makes
    /// `DistributedNdarray::from_array` panic (`div_ceil` by zero) — callers
    /// should always set this explicitly.
    pub chunks: usize,

    /// Whether to rebalance chunks across nodes.
    /// NOTE(review): not read anywhere in this file — semantics TBD/confirm.
    pub balance: bool,

    /// Partitioning strategy (defaults to `DistributionStrategy::Auto`).
    pub strategy: DistributionStrategy,

    /// Execution backend (defaults to `DistributedBackend::Threaded`).
    pub backend: DistributedBackend,
}
41
/// How a distributed array's elements are assigned to chunks.
///
/// NOTE(review): only used for display in `distribution_info` so far; the
/// visible code never branches on the variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DistributionStrategy {
    /// Split along rows.
    RowWise,

    /// Split along columns.
    ColumnWise,

    /// Split into rectangular blocks.
    Blocks,

    /// Let the implementation choose a strategy.
    Auto,
}
57
58impl Default for DistributionStrategy {
59 fn default() -> Self {
60 Self::Auto
61 }
62}
63
/// Transport/execution backend for distributed operations.
///
/// NOTE(review): only used for display in `distribution_info` so far; no
/// backend-specific code paths are visible in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DistributedBackend {
    /// In-process threads (the default).
    Threaded,

    /// Message Passing Interface.
    MPI,

    /// Plain TCP sockets.
    TCP,
}
76
77impl Default for DistributedBackend {
78 fn default() -> Self {
79 Self::Threaded
80 }
81}
82
/// One piece of a `DistributedNdarray`: local data plus placement metadata.
#[derive(Debug, Clone)]
pub struct ArrayChunk<T, D>
where
    T: Clone + 'static,
    D: Dimension + 'static,
{
    /// The chunk's local array data.
    pub data: Array<T, D>,

    /// Position of this chunk within the global array.
    /// NOTE(review): `from_array` always fills this with `vec![0]`, so the
    /// intended multi-element encoding is not exercised here — confirm.
    pub global_index: Vec<usize>,

    /// Identifier of the node this chunk is assigned to.
    pub nodeid: usize,
}
99
/// An n-dimensional array split into chunks for distributed processing.
///
/// NOTE(review): several operations are placeholders — `to_array` returns a
/// ones-filled array, and `from_array` clones the full source into every
/// chunk. See the method comments before relying on numeric results.
pub struct DistributedNdarray<T, D>
where
    T: Clone + 'static,
    D: Dimension + 'static,
{
    /// Distribution configuration this array was built with.
    pub config: DistributedConfig,

    /// The chunks that make up the array (access via `chunks()`).
    chunks: Vec<ArrayChunk<T, D>>,

    /// Global logical shape of the array (access via `shape()`).
    shape: Vec<usize>,

    /// Unique identifier: a random UUID prefixed with `uuid_`.
    id: String,
}
118
119impl<T, D> Debug for DistributedNdarray<T, D>
120where
121 T: Clone + Debug + 'static,
122 D: Dimension + Debug + 'static,
123{
124 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
125 f.debug_struct("DistributedNdarray")
126 .field("config", &self.config)
127 .field("chunks", &self.chunks.len())
128 .field("shape", &self.shape)
129 .field("id", &self.id)
130 .finish()
131 }
132}
133
134impl<T, D> DistributedNdarray<T, D>
135where
136 T: Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T> + Default,
137 D: Dimension + Clone + Send + Sync + 'static + ndarray::RemoveAxis,
138{
139 #[must_use]
141 pub fn new(
142 chunks: Vec<ArrayChunk<T, D>>,
143 shape: Vec<usize>,
144 config: DistributedConfig,
145 ) -> Self {
146 let uuid = uuid::Uuid::new_v4();
147 let id = format!("uuid_{uuid}");
148 Self {
149 config,
150 chunks,
151 shape,
152 id,
153 }
154 }
155
156 #[must_use]
158 pub fn from_array(array: &Array<T, D>, config: DistributedConfig) -> Self
159 where
160 T: Clone,
161 {
162 let shape = array.shape().to_vec();
166 let total_elements = array.len();
167 let _chunk_size = total_elements.div_ceil(config.chunks);
168
169 let mut chunks = Vec::new();
171
172 for i in 0..config.chunks {
175 let chunk_data = array.clone();
178
179 chunks.push(ArrayChunk {
180 data: chunk_data,
181 global_index: vec![0],
182 nodeid: i % 3, });
184 }
185
186 Self::new(chunks, shape, config)
187 }
188
189 #[must_use]
191 pub fn num_chunks(&self) -> usize {
192 self.chunks.len()
193 }
194
195 #[must_use]
197 pub fn shape(&self) -> &[usize] {
198 &self.shape
199 }
200
201 #[must_use]
203 pub fn chunks(&self) -> &[ArrayChunk<T, D>] {
204 &self.chunks
205 }
206
207 pub fn to_array(&self) -> CoreResult<Array<T, ndarray::IxDyn>>
215 where
216 T: Clone + Default + num_traits::One,
217 {
218 let result = Array::<T, ndarray::IxDyn>::ones(ndarray::IxDyn(&self.shape));
220
221 Ok(result)
227 }
228
229 #[must_use]
231 pub fn map<F, R>(&self, f: F) -> Vec<R>
232 where
233 F: Fn(&ArrayChunk<T, D>) -> R + Send + Sync,
234 R: Send + 'static,
235 {
236 self.chunks.iter().map(f).collect()
239 }
240
241 #[must_use]
247 pub fn map_reduce<F, R, G>(&self, map_fn: F, reducefn: G) -> R
248 where
249 F: Fn(&ArrayChunk<T, D>) -> R + Send + Sync,
250 G: Fn(R, R) -> R + Send + Sync,
251 R: Send + Clone + 'static,
252 {
253 let results = self.map(map_fn);
255
256 results.into_iter().reduce(reducefn).unwrap()
259 }
260}
261
impl<T, D> ArrayProtocol for DistributedNdarray<T, D>
where
    T: Clone
        + Send
        + Sync
        + 'static
        + num_traits::Zero
        + std::ops::Div<f64, Output = T>
        + Default
        + std::ops::Add<Output = T>
        + std::ops::Mul<Output = T>,
    D: Dimension + Clone + Send + Sync + 'static + ndarray::RemoveAxis,
{
    /// Dispatches an array-protocol operation by its fully-qualified name.
    ///
    /// Supported operations: sum, mean, add, multiply, matmul, transpose,
    /// reshape. Any other name returns `NotImplemented`. Several branches
    /// are placeholders that return dummy (zero-filled) results — see the
    /// per-branch comments.
    fn array_function(
        &self,
        func: &ArrayFunction,
        _types: &[TypeId],
        args: &[Box<dyn Any>],
        kwargs: &HashMap<String, Box<dyn Any>>,
    ) -> Result<Box<dyn Any>, NotImplemented> {
        match func.name {
            "scirs2::array_protocol::operations::sum" => {
                // Optional `axis` kwarg selects an axis-wise sum.
                let axis = kwargs.get("axis").and_then(|a| a.downcast_ref::<usize>());

                if let Some(&ax) = axis {
                    // NOTE(review): placeholder — sums only the FIRST chunk
                    // along the axis, not the whole distributed array.
                    let dummy_array = self.chunks[0].data.clone();
                    let sum_array = dummy_array.sum_axis(ndarray::Axis(ax));

                    Ok(Box::new(super::NdarrayWrapper::new(sum_array)))
                } else {
                    // Total sum: per-chunk sums combined pairwise with `+`.
                    let sum = self.map_reduce(|chunk| chunk.data.sum(), |a, b| a + b);
                    Ok(Box::new(sum))
                }
            }
            "scirs2::array_protocol::operations::mean" => {
                // Sum across all chunks, then divide by the global element
                // count. NOTE(review): with the current duplicating
                // `from_array`, chunk sums overcount by a factor of
                // `num_chunks` — confirm once real partitioning lands.
                let sum = self.map_reduce(|chunk| chunk.data.sum(), |a, b| a + b);

                // Element count as f64; precision loss is acceptable for a
                // mean denominator.
                #[allow(clippy::cast_precision_loss)]
                let count = self.shape.iter().product::<usize>() as f64;

                let mean = sum / count;

                Ok(Box::new(mean))
            }
            "scirs2::array_protocol::operations::add" => {
                // Binary op: the second operand is expected in `args[1]`.
                if args.len() < 2 {
                    return Err(NotImplemented);
                }

                // Only distributed + distributed with equal shapes is
                // supported.
                if let Some(other) = args[1].downcast_ref::<Self>() {
                    if self.shape() != other.shape() {
                        return Err(NotImplemented);
                    }

                    let mut new_chunks = Vec::with_capacity(self.chunks.len());

                    // Chunk-wise addition; assumes both arrays were chunked
                    // identically (same order and per-chunk sizes).
                    for (self_chunk, other_chunk) in self.chunks.iter().zip(other.chunks.iter()) {
                        let result_data = &self_chunk.data + &other_chunk.data;
                        new_chunks.push(ArrayChunk {
                            data: result_data,
                            global_index: self_chunk.global_index.clone(),
                            nodeid: self_chunk.nodeid,
                        });
                    }

                    let result = Self::new(new_chunks, self.shape.clone(), self.config.clone());

                    return Ok(Box::new(result));
                }

                Err(NotImplemented)
            }
            "scirs2::array_protocol::operations::multiply" => {
                // Element-wise multiply; mirrors the `add` branch above.
                if args.len() < 2 {
                    return Err(NotImplemented);
                }

                if let Some(other) = args[1].downcast_ref::<Self>() {
                    if self.shape() != other.shape() {
                        return Err(NotImplemented);
                    }

                    let mut new_chunks = Vec::with_capacity(self.chunks.len());

                    for (self_chunk, other_chunk) in self.chunks.iter().zip(other.chunks.iter()) {
                        let result_data = &self_chunk.data * &other_chunk.data;
                        new_chunks.push(ArrayChunk {
                            data: result_data,
                            global_index: self_chunk.global_index.clone(),
                            nodeid: self_chunk.nodeid,
                        });
                    }

                    let result = Self::new(new_chunks, self.shape.clone(), self.config.clone());

                    return Ok(Box::new(result));
                }

                Err(NotImplemented)
            }
            "scirs2::array_protocol::operations::matmul" => {
                if args.len() < 2 {
                    return Err(NotImplemented);
                }

                // Matrix multiply is only defined for 2-D arrays here.
                if self.shape.len() != 2 {
                    return Err(NotImplemented);
                }

                if let Some(other) = args[1].downcast_ref::<Self>() {
                    // Require 2-D x 2-D with compatible inner dimensions.
                    if self.shape.len() != 2
                        || other.shape.len() != 2
                        || self.shape[1] != other.shape[0]
                    {
                        return Err(NotImplemented);
                    }

                    let resultshape = vec![self.shape[0], other.shape[1]];

                    // NOTE(review): placeholder — returns a zero-filled array
                    // of the correct result shape; no multiplication occurs.
                    let dummyshape = ndarray::IxDyn(&resultshape);
                    let dummy_array = Array::<T, ndarray::IxDyn>::zeros(dummyshape);

                    let chunk = ArrayChunk {
                        data: dummy_array,
                        global_index: vec![0],
                        nodeid: 0,
                    };

                    let result =
                        DistributedNdarray::new(vec![chunk], resultshape, self.config.clone());

                    return Ok(Box::new(result));
                }

                Err(NotImplemented)
            }
            "scirs2::array_protocol::operations::transpose" => {
                // Transpose is only supported for 2-D arrays here.
                if self.shape.len() != 2 {
                    return Err(NotImplemented);
                }

                let transposedshape = vec![self.shape[1], self.shape[0]];

                // NOTE(review): placeholder — zero-filled result with swapped
                // dimensions; the data is not actually transposed.
                let dummyshape = ndarray::IxDyn(&transposedshape);
                let dummy_array = Array::<T, ndarray::IxDyn>::zeros(dummyshape);

                let chunk = ArrayChunk {
                    data: dummy_array,
                    global_index: vec![0],
                    nodeid: 0,
                };

                let result =
                    DistributedNdarray::new(vec![chunk], transposedshape, self.config.clone());

                Ok(Box::new(result))
            }
            "scirs2::array_protocol::operations::reshape" => {
                // The target shape arrives via the `shape` kwarg.
                if let Some(shape) = kwargs
                    .get("shape")
                    .and_then(|s| s.downcast_ref::<Vec<usize>>())
                {
                    // A reshape must preserve the total element count.
                    let old_size: usize = self.shape.iter().product();
                    let new_size: usize = shape.iter().product();

                    if old_size != new_size {
                        return Err(NotImplemented);
                    }

                    // NOTE(review): placeholder — zero-filled result of the
                    // requested shape; chunk data is not actually reshaped.
                    let dummyshape = ndarray::IxDyn(shape);
                    let dummy_array = Array::<T, ndarray::IxDyn>::zeros(dummyshape);

                    let chunk = ArrayChunk {
                        data: dummy_array,
                        global_index: vec![0],
                        nodeid: 0,
                    };

                    let result =
                        DistributedNdarray::new(vec![chunk], shape.clone(), self.config.clone());

                    return Ok(Box::new(result));
                }

                Err(NotImplemented)
            }
            _ => Err(NotImplemented),
        }
    }

    /// Upcast for dynamic downcasting by protocol consumers.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Global logical shape of the distributed array.
    fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Deep-clones the array behind the trait object. Note the `id` is
    /// cloned too, so the copy shares the original's identifier.
    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
        Box::new(Self {
            config: self.config.clone(),
            chunks: self.chunks.clone(),
            shape: self.shape.clone(),
            id: self.id.clone(),
        })
    }
}
517
518impl<T, D> DistributedArray for DistributedNdarray<T, D>
519where
520 T: Clone
521 + Send
522 + Sync
523 + 'static
524 + num_traits::Zero
525 + std::ops::Div<f64, Output = T>
526 + Default
527 + num_traits::One,
528 D: Dimension + Clone + Send + Sync + 'static + ndarray::RemoveAxis,
529{
530 fn distribution_info(&self) -> HashMap<String, String> {
531 let mut info = HashMap::new();
532 info.insert("type".to_string(), "distributed_ndarray".to_string());
533 info.insert("chunks".to_string(), self.chunks.len().to_string());
534 info.insert("shape".to_string(), format!("{:?}", self.shape));
535 info.insert("id".to_string(), self.id.clone());
536 info.insert(
537 "strategy".to_string(),
538 format!("{:?}", self.config.strategy),
539 );
540 info.insert("backend".to_string(), format!("{:?}", self.config.backend));
541 info
542 }
543
544 fn gather(&self) -> CoreResult<Box<dyn ArrayProtocol>>
547 where
548 D: ndarray::RemoveAxis,
549 T: Default + Clone + num_traits::One,
550 {
551 let array_dyn = self.to_array()?;
554
555 Ok(Box::new(super::NdarrayWrapper::new(array_dyn)))
557 }
558
559 fn scatter(&self, chunks: usize) -> CoreResult<Box<dyn DistributedArray>> {
562 let mut config = self.config.clone();
567 config.chunks = chunks;
568
569 let new_dist_array = Self {
572 config,
573 chunks: self.chunks.clone(),
574 shape: self.shape.clone(),
575 id: {
576 let uuid = uuid::Uuid::new_v4();
577 format!("uuid_{uuid}")
578 },
579 };
580
581 Ok(Box::new(new_dist_array))
582 }
583
584 fn is_distributed(&self) -> bool {
585 true
586 }
587}
588
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Verifies chunk count, shape, and the CURRENT placeholder behavior of
    /// `from_array` (every chunk is a full copy of the source array).
    #[test]
    fn test_distributed_ndarray_creation() {
        let array = Array2::<f64>::ones((10, 5));
        let config = DistributedConfig {
            chunks: 3,
            ..Default::default()
        };

        let dist_array = DistributedNdarray::from_array(&array, config);

        assert_eq!(dist_array.num_chunks(), 3);
        assert_eq!(dist_array.shape(), &[10, 5]);

        // `from_array` currently clones the full array into every chunk, so
        // the summed chunk sizes are len * num_chunks, not len. Update this
        // expectation when real partitioning is implemented.
        let expected_total_elements = array.len() * dist_array.num_chunks();

        let total_elements: usize = dist_array
            .chunks()
            .iter()
            .map(|chunk| chunk.data.len())
            .sum();
        assert_eq!(total_elements, expected_total_elements);
    }

    /// `to_array` is a placeholder returning ones of the global shape, so
    /// only the shape (not the contents) is checked here.
    #[test]
    fn test_distributed_ndarray_to_array() {
        let array = Array2::<f64>::ones((10, 5));
        let config = DistributedConfig {
            chunks: 3,
            ..Default::default()
        };

        let dist_array = DistributedNdarray::from_array(&array, config);

        let result = dist_array.to_array().unwrap();

        assert_eq!(result.shape(), array.shape());
    }

    /// Exercises `map_reduce` against the current duplicating `from_array`:
    /// each chunk holds a full copy, so the reduced sum equals the original
    /// sum times the number of chunks.
    #[test]
    fn test_distributed_ndarray_map_reduce() {
        let array = Array2::<f64>::ones((10, 5));
        let config = DistributedConfig {
            chunks: 3,
            ..Default::default()
        };

        let dist_array = DistributedNdarray::from_array(&array, config);

        let expected_sum = array.sum() * (dist_array.num_chunks() as f64);

        let sum = dist_array.map_reduce(|chunk| chunk.data.sum(), |a, b| a + b);

        assert_eq!(sum, expected_sum);
    }
}