1use std::any::{Any, TypeId};
19use std::collections::HashMap;
20use std::fmt::Debug;
21
22use crate::array_protocol::{ArrayFunction, ArrayProtocol, DistributedArray, NotImplemented};
23use crate::error::CoreResult;
24use ndarray::{Array, Dimension};
25
/// Configuration controlling how a `DistributedNdarray` is partitioned and run.
#[derive(Debug, Clone, Default)]
pub struct DistributedConfig {
    /// Number of chunks to split the array into.
    ///
    /// NOTE(review): the derived `Default` yields `chunks == 0`, which produces
    /// an array with no chunks — confirm that is the intended default.
    pub chunks: usize,

    /// Whether to balance work across nodes.
    // NOTE(review): this flag is never read anywhere in this module — presumably
    // toggles load balancing; confirm the intended semantics before relying on it.
    pub balance: bool,

    /// How array elements are assigned to chunks (rows, columns, blocks, auto).
    pub strategy: DistributionStrategy,

    /// Execution backend used for distributed operations.
    pub backend: DistributedBackend,
}
41
/// Strategy used to partition an array into chunks.
///
/// NOTE(review): the chunking code in this module does not yet consult this
/// value — all variants currently behave the same; confirm before relying on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DistributionStrategy {
    /// Split along the first axis (rows).
    RowWise,

    /// Split along the second axis (columns).
    ColumnWise,

    /// Split into rectangular blocks.
    Blocks,

    /// Let the implementation choose a strategy.
    Auto,
}
57
58impl Default for DistributionStrategy {
59 fn default() -> Self {
60 Self::Auto
61 }
62}
63
/// Backend used to execute distributed operations.
///
/// NOTE(review): this value is only recorded and reported (see
/// `distribution_info`); no backend-specific execution path exists yet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DistributedBackend {
    /// Run chunks on local threads.
    Threaded,

    /// Communicate between nodes via MPI.
    MPI,

    /// Communicate between nodes over raw TCP.
    TCP,
}
76
77impl Default for DistributedBackend {
78 fn default() -> Self {
79 Self::Threaded
80 }
81}
82
/// One piece of a distributed array together with its placement metadata.
#[derive(Debug, Clone)]
pub struct ArrayChunk<T, D>
where
    T: Clone + 'static,
    D: Dimension + 'static,
{
    /// The chunk's local data.
    pub data: Array<T, D>,

    /// Position of this chunk in the global chunk layout.
    // NOTE(review): currently always a single-element vec set to the chunk's
    // ordinal (see `from_array`) — confirm intended meaning for block layouts.
    pub global_index: Vec<usize>,

    /// Identifier of the node that holds this chunk.
    pub node_id: usize,
}
99
/// An n-dimensional array split into chunks, intended to be spread across nodes.
///
/// NOTE(review): several operations in this module are placeholders that
/// duplicate or fabricate data instead of truly distributing it (see
/// `from_array`, `to_array`, and the `ArrayProtocol` impl). Confirm before
/// relying on numeric results.
pub struct DistributedNdarray<T, D>
where
    T: Clone + 'static,
    D: Dimension + 'static,
{
    /// Distribution configuration this array was built with.
    pub config: DistributedConfig,

    /// The chunks making up the array.
    chunks: Vec<ArrayChunk<T, D>>,

    /// Global (logical) shape of the full array.
    shape: Vec<usize>,

    /// Unique identifier for this distributed array.
    id: String,
}
118
119impl<T, D> Debug for DistributedNdarray<T, D>
120where
121 T: Clone + Debug + 'static,
122 D: Dimension + Debug + 'static,
123{
124 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
125 f.debug_struct("DistributedNdarray")
126 .field("config", &self.config)
127 .field("chunks", &self.chunks.len())
128 .field("shape", &self.shape)
129 .field("id", &self.id)
130 .finish()
131 }
132}
133
134impl<T, D> DistributedNdarray<T, D>
135where
136 T: Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T> + Default,
137 D: Dimension + Clone + Send + Sync + 'static + ndarray::RemoveAxis,
138{
139 #[must_use]
141 pub fn new(
142 chunks: Vec<ArrayChunk<T, D>>,
143 shape: Vec<usize>,
144 config: DistributedConfig,
145 ) -> Self {
146 let id = format!("dist_array_{}", uuid::Uuid::new_v4());
147 Self {
148 config,
149 chunks,
150 shape,
151 id,
152 }
153 }
154
155 #[must_use]
157 pub fn from_array(array: &Array<T, D>, config: DistributedConfig) -> Self
158 where
159 T: Clone,
160 {
161 let shape = array.shape().to_vec();
165 let total_elements = array.len();
166 let _chunk_size = total_elements.div_ceil(config.chunks);
167
168 let mut chunks = Vec::new();
170
171 for i in 0..config.chunks {
174 let chunk_data = array.clone();
177
178 chunks.push(ArrayChunk {
179 data: chunk_data,
180 global_index: vec![i],
181 node_id: i % 3, });
183 }
184
185 Self::new(chunks, shape, config)
186 }
187
188 #[must_use]
190 pub fn num_chunks(&self) -> usize {
191 self.chunks.len()
192 }
193
194 #[must_use]
196 pub fn shape(&self) -> &[usize] {
197 &self.shape
198 }
199
200 #[must_use]
202 pub fn chunks(&self) -> &[ArrayChunk<T, D>] {
203 &self.chunks
204 }
205
206 pub fn to_array(&self) -> CoreResult<Array<T, ndarray::IxDyn>>
214 where
215 T: Clone + Default + num_traits::One,
216 {
217 let result = Array::<T, ndarray::IxDyn>::ones(ndarray::IxDyn(&self.shape));
219
220 Ok(result)
226 }
227
228 #[must_use]
230 pub fn map<F, R>(&self, f: F) -> Vec<R>
231 where
232 F: Fn(&ArrayChunk<T, D>) -> R + Send + Sync,
233 R: Send + 'static,
234 {
235 self.chunks.iter().map(f).collect()
238 }
239
240 #[must_use]
246 pub fn map_reduce<F, R, G>(&self, map_fn: F, reduce_fn: G) -> R
247 where
248 F: Fn(&ArrayChunk<T, D>) -> R + Send + Sync,
249 G: Fn(R, R) -> R + Send + Sync,
250 R: Send + Clone + 'static,
251 {
252 let results = self.map(map_fn);
254
255 results.into_iter().reduce(reduce_fn).unwrap()
258 }
259}
260
impl<T, D> ArrayProtocol for DistributedNdarray<T, D>
where
    T: Clone
        + Send
        + Sync
        + 'static
        + num_traits::Zero
        + std::ops::Div<f64, Output = T>
        + Default
        + std::ops::Add<Output = T>
        + std::ops::Mul<Output = T>,
    D: Dimension + Clone + Send + Sync + 'static + ndarray::RemoveAxis,
{
    /// Dispatches a named array-protocol operation against this distributed array.
    ///
    /// Supported operations are matched by their fully qualified name; anything
    /// else returns `NotImplemented`. NOTE(review): several arms below are
    /// placeholders that return dummy data (zeros / single-chunk results)
    /// rather than performing a true distributed computation — see the
    /// per-arm comments. Confirm before relying on numeric results.
    fn array_function(
        &self,
        func: &ArrayFunction,
        _types: &[TypeId],
        args: &[Box<dyn Any>],
        kwargs: &HashMap<String, Box<dyn Any>>,
    ) -> Result<Box<dyn Any>, NotImplemented> {
        match func.name {
            "scirs2::array_protocol::operations::sum" => {
                // Optional "axis" kwarg selects an axis-wise sum.
                let axis = kwargs.get("axis").and_then(|a| a.downcast_ref::<usize>());

                if let Some(&ax) = axis {
                    // Placeholder: sums only the FIRST chunk along the axis,
                    // not the whole distributed array. Panics if there are no
                    // chunks (direct index into `self.chunks[0]`).
                    let dummy_array = self.chunks[0].data.clone();
                    let sum_array = dummy_array.sum_axis(ndarray::Axis(ax));

                    Ok(Box::new(super::NdarrayWrapper::new(sum_array)))
                } else {
                    // Full sum: per-chunk sums folded with `+`. Panics if the
                    // array has no chunks (see `map_reduce`).
                    let sum = self.map_reduce(|chunk| chunk.data.sum(), |a, b| a + b);
                    Ok(Box::new(sum))
                }
            }
            "scirs2::array_protocol::operations::mean" => {
                // Sum over all chunks, then divide by the LOGICAL element
                // count (product of the global shape).
                let sum = self.map_reduce(|chunk| chunk.data.sum(), |a, b| a + b);

                #[allow(clippy::cast_precision_loss)]
                let count = self.shape.iter().product::<usize>() as f64;

                // NOTE(review): with the current `from_array` duplication,
                // chunks each hold the full array, so this mean is scaled by
                // the chunk count — confirm once real chunking lands.
                let mean = sum / count;

                Ok(Box::new(mean))
            }
            "scirs2::array_protocol::operations::add" => {
                // Element-wise addition; requires a second argument of the
                // same distributed type and identical shape.
                if args.len() < 2 {
                    return Err(NotImplemented);
                }

                if let Some(other) = args[1].downcast_ref::<Self>() {
                    if self.shape() != other.shape() {
                        return Err(NotImplemented);
                    }

                    let mut new_chunks = Vec::with_capacity(self.chunks.len());

                    // Add corresponding chunks pairwise, keeping the left-hand
                    // side's placement metadata. Assumes both arrays have
                    // matching chunk layouts (zip silently truncates otherwise).
                    for (self_chunk, other_chunk) in self.chunks.iter().zip(other.chunks.iter()) {
                        let result_data = &self_chunk.data + &other_chunk.data;
                        new_chunks.push(ArrayChunk {
                            data: result_data,
                            global_index: self_chunk.global_index.clone(),
                            node_id: self_chunk.node_id,
                        });
                    }

                    let result = Self::new(new_chunks, self.shape.clone(), self.config.clone());

                    return Ok(Box::new(result));
                }

                // Second operand was not a DistributedNdarray of the same type.
                Err(NotImplemented)
            }
            "scirs2::array_protocol::operations::multiply" => {
                // Element-wise multiplication; mirrors the `add` arm above.
                if args.len() < 2 {
                    return Err(NotImplemented);
                }

                if let Some(other) = args[1].downcast_ref::<Self>() {
                    if self.shape() != other.shape() {
                        return Err(NotImplemented);
                    }

                    let mut new_chunks = Vec::with_capacity(self.chunks.len());

                    for (self_chunk, other_chunk) in self.chunks.iter().zip(other.chunks.iter()) {
                        let result_data = &self_chunk.data * &other_chunk.data;
                        new_chunks.push(ArrayChunk {
                            data: result_data,
                            global_index: self_chunk.global_index.clone(),
                            node_id: self_chunk.node_id,
                        });
                    }

                    let result = Self::new(new_chunks, self.shape.clone(), self.config.clone());

                    return Ok(Box::new(result));
                }

                Err(NotImplemented)
            }
            "scirs2::array_protocol::operations::matmul" => {
                // Matrix multiplication: only 2-D x 2-D with compatible inner
                // dimensions is accepted.
                if args.len() < 2 {
                    return Err(NotImplemented);
                }

                if self.shape.len() != 2 {
                    return Err(NotImplemented);
                }

                if let Some(other) = args[1].downcast_ref::<Self>() {
                    if self.shape.len() != 2
                        || other.shape.len() != 2
                        || self.shape[1] != other.shape[0]
                    {
                        return Err(NotImplemented);
                    }

                    // Result is (self rows) x (other columns).
                    let result_shape = vec![self.shape[0], other.shape[1]];

                    // Placeholder: returns an all-zeros array of the correct
                    // shape in a single chunk — no actual multiplication yet.
                    let dummy_shape = ndarray::IxDyn(&result_shape);
                    let dummy_array = Array::<T, ndarray::IxDyn>::zeros(dummy_shape);

                    let chunk = ArrayChunk {
                        data: dummy_array,
                        global_index: vec![0],
                        node_id: 0,
                    };

                    let result =
                        DistributedNdarray::new(vec![chunk], result_shape, self.config.clone());

                    return Ok(Box::new(result));
                }

                Err(NotImplemented)
            }
            "scirs2::array_protocol::operations::transpose" => {
                // Transpose: only 2-D arrays are supported.
                if self.shape.len() != 2 {
                    return Err(NotImplemented);
                }

                // Swap the two axes of the logical shape.
                let transposed_shape = vec![self.shape[1], self.shape[0]];

                // Placeholder: zeros of the transposed shape in one chunk —
                // the chunk data is not actually transposed.
                let dummy_shape = ndarray::IxDyn(&transposed_shape);
                let dummy_array = Array::<T, ndarray::IxDyn>::zeros(dummy_shape);

                let chunk = ArrayChunk {
                    data: dummy_array,
                    global_index: vec![0],
                    node_id: 0,
                };

                let result =
                    DistributedNdarray::new(vec![chunk], transposed_shape, self.config.clone());

                Ok(Box::new(result))
            }
            "scirs2::array_protocol::operations::reshape" => {
                // Reshape: requires a "shape" kwarg (Vec<usize>) whose element
                // count matches the current logical size.
                if let Some(shape) = kwargs
                    .get("shape")
                    .and_then(|s| s.downcast_ref::<Vec<usize>>())
                {
                    let old_size: usize = self.shape.iter().product();
                    let new_size: usize = shape.iter().product();

                    if old_size != new_size {
                        return Err(NotImplemented);
                    }

                    // Placeholder: zeros of the new shape in one chunk — the
                    // original data is not carried over.
                    let dummy_shape = ndarray::IxDyn(shape);
                    let dummy_array = Array::<T, ndarray::IxDyn>::zeros(dummy_shape);

                    let chunk = ArrayChunk {
                        data: dummy_array,
                        global_index: vec![0],
                        node_id: 0,
                    };

                    let result =
                        DistributedNdarray::new(vec![chunk], shape.clone(), self.config.clone());

                    return Ok(Box::new(result));
                }

                Err(NotImplemented)
            }
            // Any operation not listed above is unsupported.
            _ => Err(NotImplemented),
        }
    }

    /// Upcast for dynamic dispatch.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Global (logical) shape of the array.
    fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Deep clone behind the `ArrayProtocol` trait object.
    // Note: clones field-by-field, preserving the `id` (unlike `Self::new`,
    // which would generate a fresh one).
    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
        Box::new(Self {
            config: self.config.clone(),
            chunks: self.chunks.clone(),
            shape: self.shape.clone(),
            id: self.id.clone(),
        })
    }
}
516
517impl<T, D> DistributedArray for DistributedNdarray<T, D>
518where
519 T: Clone
520 + Send
521 + Sync
522 + 'static
523 + num_traits::Zero
524 + std::ops::Div<f64, Output = T>
525 + Default
526 + num_traits::One,
527 D: Dimension + Clone + Send + Sync + 'static + ndarray::RemoveAxis,
528{
529 fn distribution_info(&self) -> HashMap<String, String> {
530 let mut info = HashMap::new();
531 info.insert("type".to_string(), "distributed_ndarray".to_string());
532 info.insert("chunks".to_string(), self.chunks.len().to_string());
533 info.insert(
534 "shape".to_string(),
535 format!("{shape:?}", shape = self.shape),
536 );
537 info.insert("id".to_string(), self.id.clone());
538 info.insert(
539 "strategy".to_string(),
540 format!("{strategy:?}", strategy = self.config.strategy),
541 );
542 info.insert(
543 "backend".to_string(),
544 format!("{backend:?}", backend = self.config.backend),
545 );
546 info
547 }
548
549 fn gather(&self) -> CoreResult<Box<dyn ArrayProtocol>>
552 where
553 D: ndarray::RemoveAxis,
554 T: Default + Clone + num_traits::One,
555 {
556 let array_dyn = self.to_array()?;
559
560 Ok(Box::new(super::NdarrayWrapper::new(array_dyn)))
562 }
563
564 fn scatter(&self, chunks: usize) -> CoreResult<Box<dyn DistributedArray>> {
567 let mut config = self.config.clone();
572 config.chunks = chunks;
573
574 let new_dist_array = Self {
577 config,
578 chunks: self.chunks.clone(),
579 shape: self.shape.clone(),
580 id: format!("dist_array_{}", uuid::Uuid::new_v4()),
581 };
582
583 Ok(Box::new(new_dist_array))
584 }
585
586 fn is_distributed(&self) -> bool {
587 true
588 }
589}
590
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    #[test]
    fn test_distributed_ndarray_creation() {
        let source = Array2::<f64>::ones((10, 5));
        let config = DistributedConfig {
            chunks: 3,
            ..Default::default()
        };

        let distributed = DistributedNdarray::from_array(&source, config);

        assert_eq!(distributed.num_chunks(), 3);
        assert_eq!(distributed.shape(), &[10, 5]);

        // `from_array` currently clones the full array into every chunk, so
        // the total stored element count is `source.len()` per chunk.
        let stored: usize = distributed
            .chunks()
            .iter()
            .map(|chunk| chunk.data.len())
            .sum();
        assert_eq!(stored, source.len() * distributed.num_chunks());
    }

    #[test]
    fn test_distributed_ndarray_to_array() {
        let source = Array2::<f64>::ones((10, 5));
        let config = DistributedConfig {
            chunks: 3,
            ..Default::default()
        };

        let distributed = DistributedNdarray::from_array(&source, config);
        let reassembled = distributed.to_array().unwrap();

        // Only the shape is checked: `to_array` is a placeholder that returns
        // ones rather than reassembling the real chunk contents.
        assert_eq!(reassembled.shape(), source.shape());
    }

    #[test]
    fn test_distributed_ndarray_map_reduce() {
        let source = Array2::<f64>::ones((10, 5));
        let config = DistributedConfig {
            chunks: 3,
            ..Default::default()
        };

        let distributed = DistributedNdarray::from_array(&source, config);

        // Every chunk is a full copy of the source, so the distributed sum is
        // the local sum scaled by the chunk count.
        let expected = source.sum() * (distributed.num_chunks() as f64);
        let total = distributed.map_reduce(|chunk| chunk.data.sum(), |a, b| a + b);

        assert_eq!(total, expected);
    }
}
665}