use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::fmt::Debug;

use ::ndarray::{Array, Dimension};

use crate::array_protocol::{ArrayFunction, ArrayProtocol, DistributedArray, NotImplemented};
use crate::error::CoreResult;
/// Configuration controlling how a [`DistributedNdarray`] is partitioned
/// and executed.
#[derive(Debug, Clone, Default)]
pub struct DistributedConfig {
    /// Number of chunks the array is split into.
    pub chunks: usize,

    /// Whether chunks should be rebalanced across nodes.
    /// NOTE(review): never read in this file — intended semantics unconfirmed.
    pub balance: bool,

    /// Partitioning strategy (row-wise, column-wise, blocks, or auto).
    pub strategy: DistributionStrategy,

    /// Execution backend used to run distributed operations.
    pub backend: DistributedBackend,
}
35
/// Strategy used to partition an array into chunks.
///
/// NOTE(review): no code in this file matches on these variants yet; the
/// per-variant descriptions below are inferred from the names — confirm
/// against the eventual partitioning implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DistributionStrategy {
    /// Split along the first axis (rows).
    RowWise,

    /// Split along the second axis (columns).
    ColumnWise,

    /// Split into rectangular blocks.
    Blocks,

    /// Let the implementation choose a strategy.
    Auto,
}
51
52impl Default for DistributionStrategy {
53 fn default() -> Self {
54 Self::Auto
55 }
56}
57
/// Backend used to execute distributed operations.
///
/// NOTE(review): no code in this file matches on these variants; descriptions
/// are inferred from the names.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DistributedBackend {
    /// In-process thread pool.
    Threaded,

    /// Message Passing Interface transport.
    MPI,

    /// TCP socket transport.
    TCP,
}
70
71impl Default for DistributedBackend {
72 fn default() -> Self {
73 Self::Threaded
74 }
75}
76
/// One piece of a distributed array, owned by a single node.
#[derive(Debug, Clone)]
pub struct ArrayChunk<T, D>
where
    T: Clone + 'static,
    D: Dimension + 'static,
{
    /// The chunk's local array data.
    pub data: Array<T, D>,

    /// Position of this chunk within the global array.
    /// NOTE(review): only ever set to `vec![0]` in this file — the intended
    /// index layout is unconfirmed.
    pub global_index: Vec<usize>,

    /// Identifier of the node that owns this chunk.
    pub nodeid: usize,
}
93
/// An n-dimensional array partitioned into chunks for distributed execution.
pub struct DistributedNdarray<T, D>
where
    T: Clone + 'static,
    D: Dimension + 'static,
{
    /// Distribution settings (chunk count, strategy, backend).
    pub config: DistributedConfig,

    /// The chunks that make up the array.
    chunks: Vec<ArrayChunk<T, D>>,

    /// Global shape of the full (logical) array.
    shape: Vec<usize>,

    /// Unique instance identifier, formatted as `"uuid_<v4-uuid>"`.
    id: String,
}
112
113impl<T, D> Debug for DistributedNdarray<T, D>
114where
115 T: Clone + Debug + 'static,
116 D: Dimension + Debug + 'static,
117{
118 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119 f.debug_struct("DistributedNdarray")
120 .field("config", &self.config)
121 .field("chunks", &self.chunks.len())
122 .field("shape", &self.shape)
123 .field("id", &self.id)
124 .finish()
125 }
126}
127
128impl<T, D> DistributedNdarray<T, D>
129where
130 T: Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T> + Default,
131 D: Dimension + Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
132{
133 #[must_use]
135 pub fn new(
136 chunks: Vec<ArrayChunk<T, D>>,
137 shape: Vec<usize>,
138 config: DistributedConfig,
139 ) -> Self {
140 let uuid = uuid::Uuid::new_v4();
141 let id = format!("uuid_{uuid}");
142 Self {
143 config,
144 chunks,
145 shape,
146 id,
147 }
148 }
149
150 #[must_use]
152 pub fn from_array(array: &Array<T, D>, config: DistributedConfig) -> Self
153 where
154 T: Clone,
155 {
156 let shape = array.shape().to_vec();
160 let total_elements = array.len();
161 let _chunk_size = total_elements.div_ceil(config.chunks);
162
163 let mut chunks = Vec::new();
165
166 for i in 0..config.chunks {
169 let chunk_data = array.clone();
172
173 chunks.push(ArrayChunk {
174 data: chunk_data,
175 global_index: vec![0],
176 nodeid: i % 3, });
178 }
179
180 Self::new(chunks, shape, config)
181 }
182
183 #[must_use]
185 pub fn num_chunks(&self) -> usize {
186 self.chunks.len()
187 }
188
189 #[must_use]
191 pub fn shape(&self) -> &[usize] {
192 &self.shape
193 }
194
195 #[must_use]
197 pub fn chunks(&self) -> &[ArrayChunk<T, D>] {
198 &self.chunks
199 }
200
201 pub fn to_array(&self) -> CoreResult<Array<T, crate::ndarray::IxDyn>>
209 where
210 T: Clone + Default + num_traits::One,
211 {
212 let result = Array::<T, crate::ndarray::IxDyn>::ones(crate::ndarray::IxDyn(&self.shape));
214
215 Ok(result)
221 }
222
223 #[must_use]
225 pub fn map<F, R>(&self, f: F) -> Vec<R>
226 where
227 F: Fn(&ArrayChunk<T, D>) -> R + Send + Sync,
228 R: Send + 'static,
229 {
230 self.chunks.iter().map(f).collect()
233 }
234
235 #[must_use]
241 pub fn map_reduce<F, R, G>(&self, map_fn: F, reducefn: G) -> R
242 where
243 F: Fn(&ArrayChunk<T, D>) -> R + Send + Sync,
244 G: Fn(R, R) -> R + Send + Sync,
245 R: Send + Clone + 'static,
246 {
247 let results = self.map(map_fn);
249
250 results
253 .into_iter()
254 .reduce(reducefn)
255 .expect("Operation failed")
256 }
257}
258
impl<T, D> ArrayProtocol for DistributedNdarray<T, D>
where
    T: Clone
        + Send
        + Sync
        + 'static
        + num_traits::Zero
        + std::ops::Div<f64, Output = T>
        + Default
        + std::ops::Add<Output = T>
        + std::ops::Mul<Output = T>,
    D: Dimension + Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
{
    /// Dispatches a named array-protocol operation against this distributed
    /// array.
    ///
    /// Supported `func.name` values are the `scirs2::array_protocol::operations`
    /// strings matched below; anything else yields `Err(NotImplemented)`.
    ///
    /// NOTE(review): several arms are placeholders that return zero-filled
    /// arrays of the correct shape rather than real results — see the per-arm
    /// comments before relying on the returned values.
    fn array_function(
        &self,
        func: &ArrayFunction,
        _types: &[TypeId],
        args: &[Box<dyn Any>],
        kwargs: &HashMap<String, Box<dyn Any>>,
    ) -> Result<Box<dyn Any>, NotImplemented> {
        match func.name {
            "scirs2::array_protocol::operations::sum" => {
                // Optional `axis` kwarg selects an axis-wise sum.
                let axis = kwargs.get("axis").and_then(|a| a.downcast_ref::<usize>());

                if let Some(&ax) = axis {
                    // Placeholder: sums only the FIRST chunk along `ax`, not
                    // the whole distributed array. Panics if there are no
                    // chunks (direct index into `chunks[0]`).
                    let dummy_array = self.chunks[0].data.clone();
                    let sum_array = dummy_array.sum_axis(crate::ndarray::Axis(ax));

                    Ok(Box::new(super::NdarrayWrapper::new(sum_array)))
                } else {
                    // Total sum: per-chunk sums combined pairwise with `+`.
                    let sum = self.map_reduce(|chunk| chunk.data.sum(), |a, b| a + b);
                    Ok(Box::new(sum))
                }
            }
            "scirs2::array_protocol::operations::mean" => {
                // Sum over all chunks, then divide by the global element count.
                let sum = self.map_reduce(|chunk| chunk.data.sum(), |a, b| a + b);

                #[allow(clippy::cast_precision_loss)]
                let count = self.shape.iter().product::<usize>() as f64;

                // NOTE(review): with the placeholder chunking in `from_array`
                // every chunk holds a full copy of the array, so `sum` is
                // scaled by the chunk count and this mean is overstated —
                // verify once real partitioning exists.
                let mean = sum / count;

                Ok(Box::new(mean))
            }
            "scirs2::array_protocol::operations::add" => {
                // Binary op: args[1] must be another DistributedNdarray of the
                // same element/dimension types and the same shape.
                if args.len() < 2 {
                    return Err(NotImplemented);
                }

                if let Some(other) = args[1].downcast_ref::<Self>() {
                    if self.shape() != other.shape() {
                        return Err(NotImplemented);
                    }

                    // Element-wise add, chunk by chunk; assumes both arrays
                    // use the same chunk layout (chunks are zipped in order).
                    let mut new_chunks = Vec::with_capacity(self.chunks.len());

                    for (self_chunk, other_chunk) in self.chunks.iter().zip(other.chunks.iter()) {
                        let result_data = &self_chunk.data + &other_chunk.data;
                        new_chunks.push(ArrayChunk {
                            data: result_data,
                            global_index: self_chunk.global_index.clone(),
                            nodeid: self_chunk.nodeid,
                        });
                    }

                    let result = Self::new(new_chunks, self.shape.clone(), self.config.clone());

                    return Ok(Box::new(result));
                }

                // args[1] was not a DistributedNdarray of matching type.
                Err(NotImplemented)
            }
            "scirs2::array_protocol::operations::multiply" => {
                // Element-wise multiply; mirrors the `add` arm above.
                if args.len() < 2 {
                    return Err(NotImplemented);
                }

                if let Some(other) = args[1].downcast_ref::<Self>() {
                    if self.shape() != other.shape() {
                        return Err(NotImplemented);
                    }

                    let mut new_chunks = Vec::with_capacity(self.chunks.len());

                    for (self_chunk, other_chunk) in self.chunks.iter().zip(other.chunks.iter()) {
                        let result_data = &self_chunk.data * &other_chunk.data;
                        new_chunks.push(ArrayChunk {
                            data: result_data,
                            global_index: self_chunk.global_index.clone(),
                            nodeid: self_chunk.nodeid,
                        });
                    }

                    let result = Self::new(new_chunks, self.shape.clone(), self.config.clone());

                    return Ok(Box::new(result));
                }

                Err(NotImplemented)
            }
            "scirs2::array_protocol::operations::matmul" => {
                // Matrix multiply: both operands must be 2-D with compatible
                // inner dimensions (self: m x k, other: k x n).
                if args.len() < 2 {
                    return Err(NotImplemented);
                }

                if self.shape.len() != 2 {
                    return Err(NotImplemented);
                }

                if let Some(other) = args[1].downcast_ref::<Self>() {
                    if self.shape.len() != 2
                        || other.shape.len() != 2
                        || self.shape[1] != other.shape[0]
                    {
                        return Err(NotImplemented);
                    }

                    let resultshape = vec![self.shape[0], other.shape[1]];

                    // Placeholder: no multiplication is performed — the result
                    // is a single all-zeros chunk of the correct m x n shape.
                    let dummyshape = crate::ndarray::IxDyn(&resultshape);
                    let dummy_array = Array::<T, crate::ndarray::IxDyn>::zeros(dummyshape);

                    let chunk = ArrayChunk {
                        data: dummy_array,
                        global_index: vec![0],
                        nodeid: 0,
                    };

                    let result =
                        DistributedNdarray::new(vec![chunk], resultshape, self.config.clone());

                    return Ok(Box::new(result));
                }

                Err(NotImplemented)
            }
            "scirs2::array_protocol::operations::transpose" => {
                // Transpose is only defined here for 2-D arrays.
                if self.shape.len() != 2 {
                    return Err(NotImplemented);
                }

                let transposedshape = vec![self.shape[1], self.shape[0]];

                // Placeholder: returns an all-zeros chunk with swapped axes
                // lengths; no data is actually transposed.
                let dummyshape = crate::ndarray::IxDyn(&transposedshape);
                let dummy_array = Array::<T, crate::ndarray::IxDyn>::zeros(dummyshape);

                let chunk = ArrayChunk {
                    data: dummy_array,
                    global_index: vec![0],
                    nodeid: 0,
                };

                let result =
                    DistributedNdarray::new(vec![chunk], transposedshape, self.config.clone());

                Ok(Box::new(result))
            }
            "scirs2::array_protocol::operations::reshape" => {
                // New shape is taken from the `shape` kwarg; the total element
                // count must be preserved.
                if let Some(shape) = kwargs
                    .get("shape")
                    .and_then(|s| s.downcast_ref::<Vec<usize>>())
                {
                    let old_size: usize = self.shape.iter().product();
                    let new_size: usize = shape.iter().product();

                    if old_size != new_size {
                        return Err(NotImplemented);
                    }

                    // Placeholder: returns an all-zeros chunk of the new
                    // shape; the existing chunk data is discarded.
                    let dummyshape = crate::ndarray::IxDyn(shape);
                    let dummy_array = Array::<T, crate::ndarray::IxDyn>::zeros(dummyshape);

                    let chunk = ArrayChunk {
                        data: dummy_array,
                        global_index: vec![0],
                        nodeid: 0,
                    };

                    let result =
                        DistributedNdarray::new(vec![chunk], shape.clone(), self.config.clone());

                    return Ok(Box::new(result));
                }

                Err(NotImplemented)
            }
            // Any operation not listed above is unsupported.
            _ => Err(NotImplemented),
        }
    }

    /// Exposes `self` for dynamic downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Global shape of the distributed array.
    fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Deep-clones this array behind the `ArrayProtocol` trait object.
    /// NOTE(review): the clone keeps the same `id` as the original.
    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
        Box::new(Self {
            config: self.config.clone(),
            chunks: self.chunks.clone(),
            shape: self.shape.clone(),
            id: self.id.clone(),
        })
    }
}
514
515impl<T, D> DistributedArray for DistributedNdarray<T, D>
516where
517 T: Clone
518 + Send
519 + Sync
520 + 'static
521 + num_traits::Zero
522 + std::ops::Div<f64, Output = T>
523 + Default
524 + num_traits::One,
525 D: Dimension + Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
526{
527 fn distribution_info(&self) -> HashMap<String, String> {
528 let mut info = HashMap::new();
529 info.insert("type".to_string(), "distributed_ndarray".to_string());
530 info.insert("chunks".to_string(), self.chunks.len().to_string());
531 info.insert("shape".to_string(), format!("{:?}", self.shape));
532 info.insert("id".to_string(), self.id.clone());
533 info.insert(
534 "strategy".to_string(),
535 format!("{:?}", self.config.strategy),
536 );
537 info.insert("backend".to_string(), format!("{:?}", self.config.backend));
538 info
539 }
540
541 fn gather(&self) -> CoreResult<Box<dyn ArrayProtocol>>
544 where
545 D: crate::ndarray::RemoveAxis,
546 T: Default + Clone + num_traits::One,
547 {
548 let array_dyn = self.to_array()?;
551
552 Ok(Box::new(super::NdarrayWrapper::new(array_dyn)))
554 }
555
556 fn scatter(&self, chunks: usize) -> CoreResult<Box<dyn DistributedArray>> {
559 let mut config = self.config.clone();
564 config.chunks = chunks;
565
566 let new_dist_array = Self {
569 config,
570 chunks: self.chunks.clone(),
571 shape: self.shape.clone(),
572 id: {
573 let uuid = uuid::Uuid::new_v4();
574 format!("uuid_{uuid}")
575 },
576 };
577
578 Ok(Box::new(new_dist_array))
579 }
580
581 fn is_distributed(&self) -> bool {
582 true
583 }
584}
585
#[cfg(test)]
mod tests {
    use super::*;
    use ::ndarray::Array2;

    /// A freshly created distributed array reports the requested chunk count
    /// and the source shape; with the current placeholder partitioning each
    /// chunk holds a full copy of the source data.
    #[test]
    fn test_distributed_ndarray_creation() {
        let source = Array2::<f64>::ones((10, 5));
        let config = DistributedConfig {
            chunks: 3,
            ..Default::default()
        };

        let dist = DistributedNdarray::from_array(&source, config);

        assert_eq!(dist.num_chunks(), 3);
        assert_eq!(dist.shape(), &[10, 5]);

        // Placeholder chunking duplicates the full array into every chunk.
        let stored: usize = dist.chunks().iter().map(|chunk| chunk.data.len()).sum();
        assert_eq!(stored, source.len() * dist.num_chunks());
    }

    /// `to_array` reproduces the global shape (its contents are placeholder
    /// ones, so only the shape is asserted).
    #[test]
    fn test_distributed_ndarray_to_array() {
        let source = Array2::<f64>::ones((10, 5));
        let config = DistributedConfig {
            chunks: 3,
            ..Default::default()
        };

        let dist = DistributedNdarray::from_array(&source, config);
        let rebuilt = dist.to_array().expect("Operation failed");

        assert_eq!(rebuilt.shape(), source.shape());
    }

    /// Summing via map_reduce counts every chunk; since chunks duplicate the
    /// source, the expected total is the source sum times the chunk count.
    #[test]
    fn test_distributed_ndarray_map_reduce() {
        let source = Array2::<f64>::ones((10, 5));
        let config = DistributedConfig {
            chunks: 3,
            ..Default::default()
        };

        let dist = DistributedNdarray::from_array(&source, config);

        let total = dist.map_reduce(|chunk| chunk.data.sum(), |a, b| a + b);

        assert_eq!(total, source.sum() * (dist.num_chunks() as f64));
    }
}