//! Thread-pool management, performance profiling, and distributed-training
//! utilities.

use crate::error::{NeuralError, Result};
use ndarray::{Array, ArrayD};
#[cfg(feature = "parallel")]
use scirs2_core::parallel_ops::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};

/// Manages a thread pool for parallel tensor operations.
pub struct ThreadPoolManager {
    #[cfg(feature = "parallel")]
    pool: ThreadPool,
    num_threads: usize,
}

impl ThreadPoolManager {
    /// Creates a manager with the requested number of worker threads,
    /// defaulting to the available parallelism (or 4 if that cannot be
    /// determined).
    pub fn new(num_threads: Option<usize>) -> Result<Self> {
        let num_threads = num_threads.unwrap_or_else(|| {
            std::thread::available_parallelism()
                .map(|n| n.get())
                .unwrap_or(4)
        });

        #[cfg(feature = "parallel")]
        let pool = ThreadPoolBuilder::new()
            .num_threads(num_threads)
            .build()
            .map_err(|e| {
                NeuralError::ComputationError(format!("Failed to create thread pool: {}", e))
            })?;

        Ok(Self {
            #[cfg(feature = "parallel")]
            pool,
            num_threads,
        })
    }

    /// Runs `f` inside the managed thread pool.
    #[cfg(feature = "parallel")]
    pub fn execute<F, R>(&self, f: F) -> R
    where
        F: FnOnce() -> R + Send,
        R: Send,
    {
        self.pool.install(f)
    }

    /// Runs `f` on the current thread (fallback when the `parallel` feature
    /// is disabled).
    #[cfg(not(feature = "parallel"))]
    pub fn execute<F, R>(&self, f: F) -> R
    where
        F: FnOnce() -> R + Send,
        R: Send,
    {
        f()
    }

    /// Multiplies two 2-D matrices, distributing output rows across the
    /// thread pool when the `parallel` feature is enabled.
    pub fn parallel_matmul(&self, a: &ArrayD<f32>, b: &ArrayD<f32>) -> Result<ArrayD<f32>> {
        if a.ndim() != 2 || b.ndim() != 2 {
            return Err(NeuralError::ComputationError(
                "Parallel matmul requires 2D arrays".to_string(),
            ));
        }

        let (m, k) = (a.shape()[0], a.shape()[1]);
        let (k2, n) = (b.shape()[0], b.shape()[1]);

        if k != k2 {
            return Err(NeuralError::ComputationError(
                "Matrix dimensions incompatible for multiplication".to_string(),
            ));
        }

        #[cfg(feature = "parallel")]
        return self.execute(|| {
            let mut result = Array::zeros((m, n));

            // Each output row depends only on one row of `a`, so rows can be
            // computed independently in parallel.
            result
                .axis_iter_mut(ndarray::Axis(0))
                .into_par_iter()
                .enumerate()
                .for_each(|(i, mut row)| {
                    for j in 0..n {
                        let mut sum = 0.0;
                        for kk in 0..k {
                            sum += a[[i, kk]] * b[[kk, j]];
                        }
                        row[j] = sum;
                    }
                });

            Ok(result.into_dyn())
        });

        #[cfg(not(feature = "parallel"))]
        {
            let mut result = Array::zeros((m, n));
            for i in 0..m {
                for j in 0..n {
                    let mut sum = 0.0;
                    for kk in 0..k {
                        sum += a[[i, kk]] * b[[kk, j]];
                    }
                    result[[i, j]] = sum;
                }
            }
            Ok(result.into_dyn())
        }
    }

    /// Performs a 2-D convolution over NCHW input with an OIHW kernel,
    /// parallelized over the batch dimension when the `parallel` feature is
    /// enabled.
    pub fn parallel_conv2d(
        &self,
        input: &ArrayD<f32>,
        kernel: &ArrayD<f32>,
        bias: Option<&[f32]>,
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> Result<ArrayD<f32>> {
        if input.ndim() != 4 || kernel.ndim() != 4 {
            return Err(NeuralError::ComputationError(
                "Input and kernel must be 4D arrays".to_string(),
            ));
        }

        let (batch_size, in_channels, in_height, in_width) = (
            input.shape()[0],
            input.shape()[1],
            input.shape()[2],
            input.shape()[3],
        );
        let (out_channels, _, kernel_height, kernel_width) = (
            kernel.shape()[0],
            kernel.shape()[1],
            kernel.shape()[2],
            kernel.shape()[3],
        );

        let out_height = (in_height + 2 * padding.0 - kernel_height) / stride.0 + 1;
        let out_width = (in_width + 2 * padding.1 - kernel_width) / stride.1 + 1;

        #[cfg(feature = "parallel")]
        return self.execute(|| {
            let mut output = Array::zeros((batch_size, out_channels, out_height, out_width));

            output
                .axis_iter_mut(ndarray::Axis(0))
                .into_par_iter()
                .enumerate()
                .for_each(|(batch, mut batch_output)| {
                    for out_ch in 0..out_channels {
                        for out_h in 0..out_height {
                            for out_w in 0..out_width {
                                let mut sum = 0.0f32;

                                for in_ch in 0..in_channels {
                                    for kh in 0..kernel_height {
                                        for kw in 0..kernel_width {
                                            let in_h = out_h * stride.0 + kh;
                                            let in_w = out_w * stride.1 + kw;

                                            // Skip positions that fall in the
                                            // zero-padded border.
                                            if in_h >= padding.0
                                                && in_w >= padding.1
                                                && in_h - padding.0 < in_height
                                                && in_w - padding.1 < in_width
                                            {
                                                let input_val = input[[
                                                    batch,
                                                    in_ch,
                                                    in_h - padding.0,
                                                    in_w - padding.1,
                                                ]];
                                                let kernel_val = kernel[[out_ch, in_ch, kh, kw]];
                                                sum += input_val * kernel_val;
                                            }
                                        }
                                    }
                                }

                                // Bias indices wrap when fewer biases than
                                // output channels are supplied; an empty slice
                                // is ignored to avoid a modulo-by-zero panic.
                                if let Some(b) = bias {
                                    if !b.is_empty() {
                                        sum += b[out_ch % b.len()];
                                    }
                                }

                                batch_output[[out_ch, out_h, out_w]] = sum;
                            }
                        }
                    }
                });

            Ok(output.into_dyn())
        });

        #[cfg(not(feature = "parallel"))]
        {
            let mut output = Array::zeros((batch_size, out_channels, out_height, out_width));

            for batch in 0..batch_size {
                for out_ch in 0..out_channels {
                    for out_h in 0..out_height {
                        for out_w in 0..out_width {
                            let mut sum = 0.0f32;

                            for in_ch in 0..in_channels {
                                for kh in 0..kernel_height {
                                    for kw in 0..kernel_width {
                                        let in_h = out_h * stride.0 + kh;
                                        let in_w = out_w * stride.1 + kw;

                                        if in_h >= padding.0
                                            && in_w >= padding.1
                                            && in_h - padding.0 < in_height
                                            && in_w - padding.1 < in_width
                                        {
                                            let input_val = input[[
                                                batch,
                                                in_ch,
                                                in_h - padding.0,
                                                in_w - padding.1,
                                            ]];
                                            let kernel_val = kernel[[out_ch, in_ch, kh, kw]];
                                            sum += input_val * kernel_val;
                                        }
                                    }
                                }
                            }

                            if let Some(b) = bias {
                                if !b.is_empty() {
                                    sum += b[out_ch % b.len()];
                                }
                            }

                            output[[batch, out_ch, out_h, out_w]] = sum;
                        }
                    }
                }
            }

            Ok(output.into_dyn())
        }
    }

    /// Returns the number of worker threads.
    pub fn num_threads(&self) -> usize {
        self.num_threads
    }

    /// Returns a snapshot of the pool's configuration.
    pub fn get_stats(&self) -> ThreadPoolStats {
        ThreadPoolStats {
            num_threads: self.num_threads,
            active: true,
        }
    }
}
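
// A minimal usage sketch (assumes the `parallel` feature is enabled; the
// shapes below are illustrative):
//
//     let pool = ThreadPoolManager::new(Some(8))?;
//     let a = Array::ones((64, 128)).into_dyn();
//     let b = Array::ones((128, 32)).into_dyn();
//     let c = pool.parallel_matmul(&a, &b)?;
//     assert_eq!(c.shape(), &[64, 32]);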

/// Statistics describing a [`ThreadPoolManager`].
#[derive(Debug, Clone)]
pub struct ThreadPoolStats {
    /// Number of worker threads in the pool.
    pub num_threads: usize,
    /// Whether the pool is currently active.
    pub active: bool,
}

/// Accumulates wall-clock timings and call counts for named operations.
pub struct PerformanceProfiler {
    enabled: bool,
    timings: HashMap<String, Duration>,
    call_counts: HashMap<String, usize>,
    active_timers: HashMap<String, Instant>,
}

impl PerformanceProfiler {
    /// Creates a profiler; when `enabled` is false, all operations are no-ops.
    pub fn new(enabled: bool) -> Self {
        Self {
            enabled,
            timings: HashMap::new(),
            call_counts: HashMap::new(),
            active_timers: HashMap::new(),
        }
    }

    /// Starts a named timer and returns its start instant (`None` when
    /// profiling is disabled).
    pub fn start_timer(&mut self, name: &str) -> Option<Instant> {
        if self.enabled {
            let start_time = Instant::now();
            self.active_timers.insert(name.to_string(), start_time);
            Some(start_time)
        } else {
            None
        }
    }

    /// Stops a named timer, accumulating its elapsed time and call count.
    pub fn end_timer(&mut self, name: &str, start_time: Option<Instant>) {
        if self.enabled {
            if let Some(start) = start_time {
                let elapsed = start.elapsed();

                *self.timings.entry(name.to_string()).or_insert(Duration::ZERO) += elapsed;
                *self.call_counts.entry(name.to_string()).or_insert(0) += 1;

                self.active_timers.remove(name);
            }
        }
    }

    /// Times a closure under the given operation name and returns its result.
    pub fn time_operation<F, R>(&mut self, name: &str, operation: F) -> R
    where
        F: FnOnce() -> R,
    {
        let timer = self.start_timer(name);
        let result = operation();
        self.end_timer(name, timer);
        result
    }

    /// Returns the accumulated total time per operation.
    pub fn get_timings(&self) -> &HashMap<String, Duration> {
        &self.timings
    }

    /// Returns the accumulated call count per operation.
    pub fn get_call_counts(&self) -> &HashMap<String, usize> {
        &self.call_counts
    }

    /// Returns the mean time per call for an operation, if any were recorded.
    pub fn get_average_time(&self, name: &str) -> Option<Duration> {
        if let (Some(&total_time), Some(&count)) =
            (self.timings.get(name), self.call_counts.get(name))
        {
            if count > 0 {
                Some(total_time / count as u32)
            } else {
                None
            }
        } else {
            None
        }
    }

    /// Clears all recorded timings, counts, and active timers.
    pub fn clear(&mut self) {
        self.timings.clear();
        self.call_counts.clear();
        self.active_timers.clear();
    }

    /// Prints a per-operation summary of total, count, and average times.
    pub fn print_summary(&self) {
        if !self.enabled {
            println!("Performance profiling is disabled");
            return;
        }

        println!("Performance Profile Summary:");
        println!("===========================");

        let mut operations: Vec<_> = self.timings.keys().collect();
        operations.sort();

        for name in operations {
            let total_time = self.timings[name];
            let count = self.call_counts.get(name).unwrap_or(&0);
            let avg_time = if *count > 0 {
                total_time / *count as u32
            } else {
                Duration::ZERO
            };

            println!(
                "{}: {:.3}ms total, {} calls, {:.3}ms avg",
                name,
                total_time.as_secs_f64() * 1000.0,
                count,
                avg_time.as_secs_f64() * 1000.0
            );
        }

        let total_time: Duration = self.timings.values().sum();
        println!(
            "\nTotal profiled time: {:.3}ms",
            total_time.as_secs_f64() * 1000.0
        );
    }

    /// Returns aggregate profiling statistics.
    pub fn get_stats(&self) -> ProfilingStats {
        let total_time: Duration = self.timings.values().sum();
        let total_calls: usize = self.call_counts.values().sum();

        ProfilingStats {
            enabled: self.enabled,
            total_operations: self.timings.len(),
            total_calls,
            total_time,
            active_timers: self.active_timers.len(),
        }
    }

    /// Enables or disables profiling; disabling discards any active timers.
    pub fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
        if !enabled {
            self.active_timers.clear();
        }
    }
}
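
// A minimal usage sketch (`expensive_forward_pass` stands in for any closure
// worth timing; the operation name "forward" is illustrative):
//
//     let mut profiler = PerformanceProfiler::new(true);
//     let out = profiler.time_operation("forward", || expensive_forward_pass());
//     if let Some(avg) = profiler.get_average_time("forward") {
//         println!("forward avg: {:?}", avg);
//     }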

/// Aggregate statistics reported by [`PerformanceProfiler::get_stats`].
#[derive(Debug, Clone)]
pub struct ProfilingStats {
    /// Whether profiling is enabled.
    pub enabled: bool,
    /// Number of distinct operations recorded.
    pub total_operations: usize,
    /// Total number of timed calls across all operations.
    pub total_calls: usize,
    /// Total accumulated time across all operations.
    pub total_time: Duration,
    /// Number of timers that were started but not yet ended.
    pub active_timers: usize,
}

/// Utilities for multi-process (distributed) training.
pub mod distributed {
    use super::*;

    /// Communication backend used for collective operations.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
    pub enum CommunicationBackend {
        /// NVIDIA NCCL, optimized for GPU-to-GPU collectives.
        NCCL,
        /// Gloo, a CPU-friendly collective library.
        Gloo,
        /// Message Passing Interface.
        MPI,
        /// Plain TCP sockets.
        TCP,
        /// In-process shared memory, mainly for testing.
        InMemory,
    }

    impl fmt::Display for CommunicationBackend {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            match self {
                CommunicationBackend::NCCL => write!(f, "NCCL"),
                CommunicationBackend::Gloo => write!(f, "Gloo"),
                CommunicationBackend::MPI => write!(f, "MPI"),
                CommunicationBackend::TCP => write!(f, "TCP"),
                CommunicationBackend::InMemory => write!(f, "InMemory"),
            }
        }
    }

    /// How the model and data are partitioned across processes.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
    pub enum DistributedStrategy {
        /// Replicate the model; split each batch across processes.
        DataParallel,
        /// Split the model's parameters across processes.
        ModelParallel,
        /// Split the model into pipeline stages.
        PipelineParallel,
        /// A combination of the above strategies.
        Hybrid,
    }

    /// Algorithm used to synchronize gradients across processes.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
    pub enum GradientSyncMethod {
        /// Standard all-reduce.
        AllReduce,
        /// Centralized parameter server.
        ParameterServer,
        /// Bandwidth-efficient ring all-reduce.
        RingAllReduce,
        /// Latency-oriented tree all-reduce.
        TreeAllReduce,
        /// Hierarchical all-reduce (intra-node first, then inter-node).
        HierarchicalAllReduce,
    }

    /// Identity of one process within the distributed job.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct ProcessInfo {
        /// Rank of this process on its node.
        pub local_rank: usize,
        /// Rank of this process across the whole job.
        pub global_rank: usize,
        /// Total number of processes in the job.
        pub world_size: usize,
        /// Identifier of the node this process runs on.
        pub node_id: usize,
        /// Number of processes on this node.
        pub local_world_size: usize,
        /// Address of the rendezvous (master) process.
        pub master_addr: String,
        /// Port of the rendezvous (master) process.
        pub master_port: u16,
    }

    /// Configuration for distributed training.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct DistributedConfig {
        /// Communication backend for collectives.
        pub backend: CommunicationBackend,
        /// Parallelization strategy.
        pub strategy: DistributedStrategy,
        /// Gradient synchronization algorithm.
        pub sync_method: GradientSyncMethod,
        /// Identity of this process.
        pub process_info: ProcessInfo,
        /// Timeout for collective operations, in seconds.
        pub timeout: u64,
        /// Whether to compress gradients before communication.
        pub enable_compression: bool,
        /// Gradient bucket size for fused communication, in megabytes.
        pub bucket_size_mb: usize,
        /// Whether to communicate gradients in reduced precision.
        pub mixed_precision: bool,
        /// Whether to overlap communication with backward computation.
        pub overlap_comm: bool,
    }

    impl Default for DistributedConfig {
        fn default() -> Self {
            Self {
                backend: CommunicationBackend::TCP,
                strategy: DistributedStrategy::DataParallel,
                sync_method: GradientSyncMethod::AllReduce,
                process_info: ProcessInfo {
                    local_rank: 0,
                    global_rank: 0,
                    world_size: 1,
                    node_id: 0,
                    local_world_size: 1,
                    master_addr: "localhost".to_string(),
                    master_port: 12345,
                },
                timeout: 300,
                enable_compression: false,
                bucket_size_mb: 25,
                mixed_precision: false,
                overlap_comm: true,
            }
        }
    }
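
    // A minimal sketch of overriding the defaults for a 4-process job (the
    // values are illustrative, not recommendations):
    //
    //     let config = DistributedConfig {
    //         backend: CommunicationBackend::InMemory,
    //         process_info: ProcessInfo {
    //             global_rank: 0,
    //             world_size: 4,
    //             local_world_size: 4,
    //             ..DistributedConfig::default().process_info
    //         },
    //         ..DistributedConfig::default()
    //     };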

    /// Counters describing communication behavior during training.
    #[derive(Debug, Clone, Default, Serialize, Deserialize)]
    pub struct DistributedStats {
        /// Total bytes sent or received in collectives.
        pub bytes_communicated: u64,
        /// Number of all-reduce operations performed.
        pub allreduce_count: u64,
        /// Total time spent in communication.
        pub communication_time: Duration,
        /// Total time spent in computation.
        pub computation_time: Duration,
        /// Measured communication efficiency.
        pub communication_efficiency: f32,
        /// Average achieved communication bandwidth.
        pub average_bandwidth: f32,
    }

    /// Coordinates process-group setup and collective operations.
    pub struct DistributedManager {
        config: DistributedConfig,
        stats: Arc<Mutex<DistributedStats>>,
        process_group: Option<Arc<dyn ProcessGroup>>,
    }

    impl DistributedManager {
        /// Creates an uninitialized manager from a configuration.
        pub fn new(config: DistributedConfig) -> Result<Self> {
            Ok(Self {
                config,
                stats: Arc::new(Mutex::new(DistributedStats::default())),
                process_group: None,
            })
        }

        /// Creates the process group for the configured backend.
        pub fn initialize(&mut self) -> Result<()> {
            match self.config.backend {
                CommunicationBackend::TCP => {
                    self.process_group = Some(Arc::new(TcpProcessGroup::new(&self.config)?));
                }
                CommunicationBackend::InMemory => {
                    self.process_group = Some(Arc::new(InMemoryProcessGroup::new(&self.config)?));
                }
                _ => {
                    return Err(NeuralError::ComputationError(format!(
                        "Backend {:?} not yet implemented",
                        self.config.backend
                    )));
                }
            }
            Ok(())
        }

        /// All-reduces `tensor` across the process group, updating the
        /// communication statistics.
        pub fn all_reduce(&self, tensor: &mut ArrayD<f32>) -> Result<()> {
            if let Some(ref pg) = self.process_group {
                let start_time = Instant::now();
                pg.all_reduce(tensor)?;

                if let Ok(mut stats) = self.stats.lock() {
                    stats.allreduce_count += 1;
                    stats.communication_time += start_time.elapsed();
                    stats.bytes_communicated += (tensor.len() * std::mem::size_of::<f32>()) as u64;
                }

                Ok(())
            } else {
                Err(NeuralError::ComputationError(
                    "Distributed training not initialized".to_string(),
                ))
            }
        }

        /// Returns a snapshot of the communication statistics.
        pub fn get_stats(&self) -> Result<DistributedStats> {
            self.stats
                .lock()
                .map(|stats| stats.clone())
                .map_err(|_| NeuralError::ComputationError("Failed to get stats".to_string()))
        }

        /// Blocks until all processes reach this point; a no-op before
        /// initialization (single-process mode).
        pub fn barrier(&self) -> Result<()> {
            if let Some(ref pg) = self.process_group {
                pg.barrier()
            } else {
                Ok(())
            }
        }

        /// Broadcasts `tensor` from the `root` rank; a no-op before
        /// initialization (single-process mode).
        pub fn broadcast(&self, tensor: &mut ArrayD<f32>, root: usize) -> Result<()> {
            if let Some(ref pg) = self.process_group {
                pg.broadcast(tensor, root)
            } else {
                Ok(())
            }
        }
    }
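
    // A minimal usage sketch (error handling elided; the gradient shape is
    // illustrative):
    //
    //     let mut manager = DistributedManager::new(DistributedConfig::default())?;
    //     manager.initialize()?;
    //     let mut grads = Array::zeros(ndarray::IxDyn(&[1024]));
    //     manager.all_reduce(&mut grads)?;
    //     manager.barrier()?;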

    /// Abstraction over a set of communicating processes.
    pub trait ProcessGroup: Send + Sync {
        /// Reduces `tensor` across all processes, in place.
        fn all_reduce(&self, tensor: &mut ArrayD<f32>) -> Result<()>;
        /// Blocks until every process reaches the barrier.
        fn barrier(&self) -> Result<()>;
        /// Broadcasts `tensor` from the `root` rank to all processes.
        fn broadcast(&self, tensor: &mut ArrayD<f32>, root: usize) -> Result<()>;
        /// Returns this process's rank.
        fn get_rank(&self) -> usize;
        /// Returns the total number of processes.
        fn get_world_size(&self) -> usize;
    }

    /// Process group communicating over TCP sockets.
    pub struct TcpProcessGroup {
        rank: usize,
        world_size: usize,
    }

    impl TcpProcessGroup {
        pub fn new(config: &DistributedConfig) -> Result<Self> {
            Ok(Self {
                rank: config.process_info.global_rank,
                world_size: config.process_info.world_size,
            })
        }
    }

    impl ProcessGroup for TcpProcessGroup {
        fn all_reduce(&self, tensor: &mut ArrayD<f32>) -> Result<()> {
            // Placeholder: no data is exchanged yet. Scaling by 1/world_size
            // simulates averaging identical replicas; a real implementation
            // would first sum the tensors from all processes.
            if self.world_size > 1 {
                tensor.mapv_inplace(|x| x / self.world_size as f32);
            }
            Ok(())
        }

        fn barrier(&self) -> Result<()> {
            // Placeholder: no synchronization is performed yet.
            Ok(())
        }

        fn broadcast(&self, _tensor: &mut ArrayD<f32>, _root: usize) -> Result<()> {
            // Placeholder: no data is exchanged yet.
            Ok(())
        }

        fn get_rank(&self) -> usize {
            self.rank
        }

        fn get_world_size(&self) -> usize {
            self.world_size
        }
    }

    /// Process group backed by in-process shared memory, for testing.
    pub struct InMemoryProcessGroup {
        rank: usize,
        world_size: usize,
        #[allow(dead_code)]
        shared_data: Arc<RwLock<HashMap<String, ArrayD<f32>>>>,
    }

    impl InMemoryProcessGroup {
        pub fn new(config: &DistributedConfig) -> Result<Self> {
            Ok(Self {
                rank: config.process_info.global_rank,
                world_size: config.process_info.world_size,
                shared_data: Arc::new(RwLock::new(HashMap::new())),
            })
        }
    }

    impl ProcessGroup for InMemoryProcessGroup {
        fn all_reduce(&self, tensor: &mut ArrayD<f32>) -> Result<()> {
            // Placeholder: mirrors TcpProcessGroup's simulated averaging.
            if self.world_size > 1 {
                tensor.mapv_inplace(|x| x / self.world_size as f32);
            }
            Ok(())
        }

        fn barrier(&self) -> Result<()> {
            // Placeholder: no synchronization is performed yet.
            Ok(())
        }

        fn broadcast(&self, _tensor: &mut ArrayD<f32>, _root: usize) -> Result<()> {
            // Placeholder: no data is exchanged yet.
            Ok(())
        }

        fn get_rank(&self) -> usize {
            self.rank
        }

        fn get_world_size(&self) -> usize {
            self.world_size
        }
    }
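
    // A minimal sanity check for the single-process fallback path; this test
    // module is an illustrative addition (assumes `NeuralError: Debug` so
    // `unwrap` compiles), not part of the original file.
    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn in_memory_all_reduce_is_identity_for_world_size_one() {
            let config = DistributedConfig::default();
            let pg = InMemoryProcessGroup::new(&config).unwrap();
            let mut t = ndarray::Array::from_elem(ndarray::IxDyn(&[4]), 2.0f32);
            pg.all_reduce(&mut t).unwrap();
            // With world_size == 1, the placeholder all-reduce leaves the
            // tensor unchanged.
            assert!(t.iter().all(|&x| (x - 2.0).abs() < 1e-6));
        }
    }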
}