1#![allow(clippy::await_holding_lock)]
17#![allow(dead_code)]
18use scirs2_core::random::thread_rng;
19use serde::{Deserialize, Serialize};
20use std::collections::HashMap;
21use std::sync::{Arc, Mutex, RwLock};
22use std::time::{Duration, Instant};
23use tokio::sync::{mpsc, oneshot};
24
/// RDMA transport protocols supported by this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RdmaProtocol {
    /// Native InfiniBand fabric.
    InfiniBand,
    /// RDMA over Converged Ethernet v1 (L2-scoped).
    RoCEv1,
    /// RDMA over Converged Ethernet v2 (routable, UDP/IP encapsulated).
    RoCEv2,
    /// Internet Wide Area RDMA Protocol (RDMA over TCP).
    IWARP,
}
37
/// RDMA verbs that can be posted as work requests.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum RdmaOperation {
    /// One-sided read from remote memory.
    Read,
    /// One-sided write to remote memory.
    Write,
    /// Write that also delivers 32-bit immediate data to the remote side.
    WriteImmediate,
    /// Two-sided send (consumed by a posted receive).
    Send,
    /// Two-sided receive.
    Recv,
    /// Atomic compare-and-swap on remote memory.
    CompareSwap,
    /// Atomic fetch-and-add on remote memory.
    FetchAdd,
}
56
/// Quality-of-service profiles used to tune connection behavior.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RdmaQoS {
    /// No guarantees; default traffic class.
    BestEffort,
    /// Optimize for minimal latency.
    LowLatency,
    /// Optimize for maximum throughput.
    HighBandwidth,
    /// Strict timing requirements.
    RealTime,
}
69
/// Strategies for registering memory with the RDMA device.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MemoryRegistration {
    /// Conventional per-region registration.
    Standard,
    /// Fast registration work requests (FRWR-style).
    FastReg,
    /// Memory windows bound to an existing region.
    MemoryWindow,
    /// Globally shared registration.
    Global,
}
82
/// Tunable parameters for an RDMA transport instance.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RdmaConfig {
    /// Wire protocol to use.
    pub protocol: RdmaProtocol,
    /// Quality-of-service profile.
    pub qos: RdmaQoS,
    /// Largest single message, in bytes.
    pub max_message_size: usize,
    /// Send/receive queue depth per queue pair.
    pub queue_depth: u32,
    /// Completion queue capacity.
    pub cq_size: u32,
    /// Memory registration strategy.
    pub memory_registration: MemoryRegistration,
    /// Offload checksums to the NIC.
    pub hardware_checksum: bool,
    /// Allow the fabric to route around congestion.
    pub adaptive_routing: bool,
    /// How long to wait for connection establishment.
    pub connection_timeout: Duration,
    /// Transport-level retry attempts.
    pub retry_count: u8,
    /// Path MTU in bytes.
    pub path_mtu: u32,
}
109
110impl Default for RdmaConfig {
111 fn default() -> Self {
112 Self {
113 protocol: RdmaProtocol::RoCEv2,
114 qos: RdmaQoS::HighBandwidth,
115 max_message_size: 4 * 1024 * 1024, queue_depth: 256,
117 cq_size: 512,
118 memory_registration: MemoryRegistration::FastReg,
119 hardware_checksum: true,
120 adaptive_routing: true,
121 connection_timeout: Duration::from_secs(30),
122 retry_count: 7,
123 path_mtu: 4096,
124 }
125 }
126}
127
128#[derive(Debug, Clone)]
130pub struct MemoryRegion {
131 pub addr: u64,
133 pub size: usize,
135 pub rkey: u32,
137 pub lkey: u32,
139 pub access: MemoryAccess,
141 pub registration_type: MemoryRegistration,
143}
144
/// Access permission flags for a registered memory region.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MemoryAccess {
    pub read: bool,
    pub write: bool,
    pub atomic: bool,
    pub remote_read: bool,
    pub remote_write: bool,
    pub remote_atomic: bool,
}
155
156impl Default for MemoryAccess {
157 fn default() -> Self {
158 Self {
159 read: true,
160 write: true,
161 atomic: false,
162 remote_read: true,
163 remote_write: true,
164 remote_atomic: false,
165 }
166 }
167}
168
/// Addressing information for one side of an RDMA connection.
#[derive(Debug, Clone)]
pub struct RdmaEndpoint {
    /// Logical node identifier; also used as the connection id.
    pub node_id: usize,
    /// IP address (RoCE/iWARP) of the endpoint.
    pub address: String,
    /// Port number.
    pub port: u16,
    /// Global identifier, when available (RoCE/InfiniBand).
    pub gid: Option<[u8; 16]>,
    /// Local identifier, when available (InfiniBand).
    pub lid: Option<u16>,
    /// Queue pair number.
    pub qp_num: u32,
    /// Initial packet sequence number.
    pub psn: u32,
}
187
/// A single unit of work posted to a queue pair.
#[derive(Debug)]
pub struct WorkRequest {
    /// Caller-chosen identifier, echoed back in the completion.
    pub id: u64,
    /// Which RDMA verb to execute.
    pub operation: RdmaOperation,
    /// Local buffer address.
    pub local_addr: u64,
    /// Local key of the buffer's memory region.
    pub lkey: u32,
    /// Remote address for one-sided operations; `None` for send/recv.
    pub remote_addr: Option<u64>,
    /// Remote key for one-sided operations; `None` for send/recv.
    pub rkey: Option<u32>,
    /// Transfer length in bytes.
    pub length: usize,
    /// Optional 32-bit immediate data (write-with-immediate).
    pub immediate: Option<u32>,
    /// Channel for delivering the completion. NOTE(review): the current
    /// submit path returns completions directly and never fires this sender.
    pub completion: oneshot::Sender<RdmaResult<WorkCompletion>>,
}
210
/// Result of a finished work request, as reported by the completion queue.
#[derive(Debug, Clone)]
pub struct WorkCompletion {
    /// Identifier of the originating work request.
    pub wr_id: u64,
    /// Outcome of the operation.
    pub status: CompletionStatus,
    /// Which verb completed.
    pub operation: RdmaOperation,
    /// Number of bytes actually transferred.
    pub bytes_transferred: usize,
    /// Immediate data received, if any.
    pub immediate: Option<u32>,
    /// When the completion was generated.
    pub timestamp: Instant,
}
227
/// Completion status codes, mirroring the ibverbs work-completion statuses.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompletionStatus {
    Success,
    LocalLengthError,
    LocalQpOperationError,
    LocalProtectionError,
    WorkRequestFlushed,
    MemoryManagementError,
    BadResponseError,
    LocalAccessError,
    RemoteInvalidRequestError,
    RemoteAccessError,
    RemoteOperationError,
    RetryExceededError,
    RnrRetryExceededError,
    LocalRddViolationError,
    RemoteInvalidRdRequest,
    RemoteAborted,
    InvalidEecnError,
    InvalidEecStateError,
    Fatal,
}
251
/// Errors produced by the RDMA layer.
#[derive(Debug, thiserror::Error)]
pub enum RdmaError {
    /// Connection establishment or lookup failed.
    #[error("Connection failed: {0}")]
    ConnectionFailed(String),
    /// Registering a memory region with the device failed.
    #[error("Memory registration failed: {0}")]
    MemoryRegistrationFailed(String),
    /// A posted operation did not complete successfully.
    #[error("Operation failed: {0}")]
    OperationFailed(String),
    /// An operation exceeded its deadline.
    #[error("Timeout: {0}")]
    Timeout(String),
    /// The supplied configuration is invalid.
    #[error("Invalid configuration: {0}")]
    InvalidConfig(String),
    /// The device reported a hardware fault.
    #[error("Hardware error: {0}")]
    HardwareError(String),
    /// A protocol-level violation was detected.
    #[error("Protocol error: {0}")]
    ProtocolError(String),
}
270
/// Convenience alias for results returned throughout this module.
pub type RdmaResult<T> = Result<T, RdmaError>;
272
/// Aggregate transfer statistics.
///
/// NOTE(review): only `total_operations`, `operations_by_type` and
/// `bytes_transferred` are updated by the current submit path; the remaining
/// fields are placeholders for future instrumentation.
#[derive(Debug, Clone, Default)]
pub struct RdmaStatistics {
    /// Total work requests completed.
    pub total_operations: u64,
    /// Completed operations broken down by verb.
    pub operations_by_type: HashMap<RdmaOperation, u64>,
    /// Total payload bytes moved.
    pub bytes_transferred: u64,
    /// Average completion latency, microseconds.
    pub avg_latency_us: f64,
    /// Highest observed bandwidth, Gbit/s.
    pub peak_bandwidth_gbps: f64,
    /// Most recent bandwidth sample, Gbit/s.
    pub current_bandwidth_gbps: f64,
    /// Operations that completed with an error status.
    pub error_count: u64,
    /// Transport-level retries performed.
    pub retry_count: u64,
    /// Time since the transport was created.
    pub uptime: Duration,
    /// CPU utilization attributed to RDMA processing.
    pub cpu_usage_percent: f64,
}
297
/// Pool of pre-registered memory regions, bucketed by size class.
pub struct RdmaMemoryPool {
    /// Cached regions keyed by region size in bytes.
    regions: RwLock<HashMap<usize, Vec<MemoryRegion>>>,
    /// Pool sizing and registration options.
    config: RdmaMemoryPoolConfig,
    /// Allocation counters, shared with statistics consumers.
    stats: Arc<Mutex<MemoryPoolStats>>,
}
307
/// Sizing and registration options for [`RdmaMemoryPool`].
#[derive(Debug, Clone)]
pub struct RdmaMemoryPoolConfig {
    /// Regions pre-allocated per size class at pool creation.
    pub min_pool_size: usize,
    /// Maximum cached regions per size class; excess is dropped.
    pub max_pool_size: usize,
    /// Size classes (bytes) the pool maintains.
    pub region_sizes: Vec<usize>,
    /// Pre-fault pages at registration time.
    pub prefault: bool,
    /// Back regions with huge pages.
    pub huge_pages: bool,
}
321
/// Counters describing memory-pool activity.
#[derive(Debug, Default, Clone)]
pub struct MemoryPoolStats {
    // Total allocate() calls.
    allocations: u64,
    // Total deallocate() calls.
    deallocations: u64,
    // Allocations served from the cache.
    cache_hits: u64,
    // Allocations requiring a fresh registration.
    cache_misses: u64,
    // NOTE(review): the two fields below are never updated by the current code.
    total_memory_allocated: usize,
    peak_memory_usage: usize,
}
331
332impl RdmaMemoryPool {
333 pub fn new(config: RdmaMemoryPoolConfig) -> RdmaResult<Self> {
335 let mut regions = HashMap::new();
336
337 for &size in &config.region_sizes {
339 let mut size_regions = Vec::new();
340 for _ in 0..config.min_pool_size {
341 let region = Self::allocate_region(size, &config)?;
342 size_regions.push(region);
343 }
344 regions.insert(size, size_regions);
345 }
346
347 Ok(Self {
348 regions: RwLock::new(regions),
349 config,
350 stats: Arc::new(Mutex::new(MemoryPoolStats::default())),
351 })
352 }
353
354 pub fn allocate(&self, size: usize) -> RdmaResult<MemoryRegion> {
356 let mut stats = self.stats.lock().expect("lock should not be poisoned");
357 stats.allocations += 1;
358
359 let region_size = self
361 .config
362 .region_sizes
363 .iter()
364 .find(|&&s| s >= size)
365 .copied()
366 .unwrap_or_else(|| {
367 size.next_power_of_two()
369 });
370
371 let mut regions = self.regions.write().expect("lock should not be poisoned");
372
373 if let Some(size_regions) = regions.get_mut(®ion_size) {
374 if let Some(region) = size_regions.pop() {
375 stats.cache_hits += 1;
376 return Ok(region);
377 }
378 }
379
380 stats.cache_misses += 1;
382 let region = Self::allocate_region(region_size, &self.config)?;
383 Ok(region)
384 }
385
386 pub fn deallocate(&self, mut region: MemoryRegion) {
388 let mut stats = self.stats.lock().expect("lock should not be poisoned");
389 stats.deallocations += 1;
390
391 let mut regions = self.regions.write().expect("lock should not be poisoned");
392 let size_regions = regions.entry(region.size).or_default();
393
394 if size_regions.len() < self.config.max_pool_size {
395 region.addr = 0; size_regions.push(region);
398 }
399 }
401
402 fn allocate_region(size: usize, _config: &RdmaMemoryPoolConfig) -> RdmaResult<MemoryRegion> {
403 Ok(MemoryRegion {
409 addr: 0x1000_0000, size,
411 rkey: thread_rng().random::<u32>(),
412 lkey: thread_rng().random::<u32>(),
413 access: MemoryAccess::default(),
414 registration_type: MemoryRegistration::FastReg,
415 })
416 }
417
418 pub fn statistics(&self) -> MemoryPoolStats {
420 (*self.stats.lock().expect("lock should not be poisoned")).clone()
421 }
422}
423
/// Owns all RDMA connections and the shared memory pool, and provides the
/// public read/write/atomic operation API.
pub struct RdmaConnectionManager {
    /// Established connections keyed by connection id (remote node id).
    connections: RwLock<HashMap<usize, RdmaConnection>>,
    /// Transport configuration.
    config: RdmaConfig,
    /// Aggregate statistics across all connections.
    stats: Arc<Mutex<RdmaStatistics>>,
    /// Shared pool of registered memory regions.
    memory_pool: Arc<RdmaMemoryPool>,
    /// Queue for handing work requests to a background processor.
    /// NOTE(review): the matching receiver is dropped in `new`, so sends
    /// would fail; the current code completes work inline instead.
    work_sender: mpsc::UnboundedSender<WorkRequest>,
}
437
/// State for a single established RDMA connection.
pub struct RdmaConnection {
    /// This side's addressing information.
    pub local_endpoint: RdmaEndpoint,
    /// The peer's addressing information.
    pub remote_endpoint: RdmaEndpoint,
    /// Current lifecycle state.
    pub state: ConnectionState,
    /// Opaque handle to the queue pair.
    pub qp_handle: u64,
    /// Opaque handle to the completion queue.
    pub cq_handle: u64,
    /// Per-connection statistics.
    pub stats: RdmaStatistics,
}
453
/// Lifecycle states of an RDMA connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    Disconnected,
    Connecting,
    Connected,
    Error,
}
461
462impl RdmaConnectionManager {
463 pub fn new(config: RdmaConfig) -> RdmaResult<Self> {
465 let memory_pool_config = RdmaMemoryPoolConfig {
466 min_pool_size: 16,
467 max_pool_size: 256,
468 region_sizes: vec![4096, 65536, 1048576, 16777216], prefault: true,
470 huge_pages: config.max_message_size > 2 * 1024 * 1024,
471 };
472
473 let memory_pool = Arc::new(RdmaMemoryPool::new(memory_pool_config)?);
474 let (work_sender, _work_receiver) = mpsc::unbounded_channel();
475
476 Ok(Self {
477 connections: RwLock::new(HashMap::new()),
478 config,
479 stats: Arc::new(Mutex::new(RdmaStatistics::default())),
480 memory_pool,
481 work_sender,
482 })
483 }
484
485 pub async fn connect(&self, remote_endpoint: RdmaEndpoint) -> RdmaResult<usize> {
487 let connection_id = remote_endpoint.node_id;
488
489 let local_endpoint = RdmaEndpoint {
491 node_id: 0, address: "0.0.0.0".to_string(),
493 port: 0,
494 gid: None,
495 lid: None,
496 qp_num: thread_rng().random::<u32>(),
497 psn: thread_rng().random::<u32>(),
498 };
499
500 let connection = RdmaConnection {
501 local_endpoint,
502 remote_endpoint,
503 state: ConnectionState::Connected,
504 qp_handle: thread_rng().random::<u64>(),
505 cq_handle: thread_rng().random::<u64>(),
506 stats: RdmaStatistics::default(),
507 };
508
509 self.connections
510 .write()
511 .expect("lock should not be poisoned")
512 .insert(connection_id, connection);
513 Ok(connection_id)
514 }
515
516 pub async fn rdma_read(
518 &self,
519 _connection_id: usize,
520 local_addr: u64,
521 remote_addr: u64,
522 length: usize,
523 lkey: u32,
524 rkey: u32,
525 ) -> RdmaResult<WorkCompletion> {
526 self.submit_work_request(WorkRequest {
527 id: thread_rng().random::<u64>(),
528 operation: RdmaOperation::Read,
529 local_addr,
530 lkey,
531 remote_addr: Some(remote_addr),
532 rkey: Some(rkey),
533 length,
534 immediate: None,
535 completion: oneshot::channel().0,
536 })
537 .await
538 }
539
540 pub async fn rdma_write(
542 &self,
543 _connection_id: usize,
544 local_addr: u64,
545 remote_addr: u64,
546 length: usize,
547 lkey: u32,
548 rkey: u32,
549 ) -> RdmaResult<WorkCompletion> {
550 self.submit_work_request(WorkRequest {
551 id: thread_rng().random::<u64>(),
552 operation: RdmaOperation::Write,
553 local_addr,
554 lkey,
555 remote_addr: Some(remote_addr),
556 rkey: Some(rkey),
557 length,
558 immediate: None,
559 completion: oneshot::channel().0,
560 })
561 .await
562 }
563
564 pub async fn atomic_compare_swap(
566 &self,
567 _connection_id: usize,
568 remote_addr: u64,
569 compare: u64,
570 _swap: u64,
571 rkey: u32,
572 ) -> RdmaResult<u64> {
573 let _completion = self
576 .submit_work_request(WorkRequest {
577 id: thread_rng().random::<u64>(),
578 operation: RdmaOperation::CompareSwap,
579 local_addr: 0,
580 lkey: 0,
581 remote_addr: Some(remote_addr),
582 rkey: Some(rkey),
583 length: 8,
584 immediate: None,
585 completion: oneshot::channel().0,
586 })
587 .await?;
588
589 Ok(compare) }
592
593 async fn submit_work_request(&self, work_request: WorkRequest) -> RdmaResult<WorkCompletion> {
594 tokio::time::sleep(Duration::from_micros(1)).await; let completion = WorkCompletion {
598 wr_id: work_request.id,
599 status: CompletionStatus::Success,
600 operation: work_request.operation,
601 bytes_transferred: work_request.length,
602 immediate: work_request.immediate,
603 timestamp: Instant::now(),
604 };
605
606 let mut stats = self.stats.lock().expect("lock should not be poisoned");
608 stats.total_operations += 1;
609 *stats
610 .operations_by_type
611 .entry(work_request.operation)
612 .or_insert(0) += 1;
613 stats.bytes_transferred += work_request.length as u64;
614
615 Ok(completion)
616 }
617
618 pub fn statistics(&self) -> RdmaStatistics {
620 self.stats
621 .lock()
622 .expect("lock should not be poisoned")
623 .clone()
624 }
625
626 pub fn memory_pool_statistics(&self) -> MemoryPoolStats {
628 self.memory_pool.statistics()
629 }
630}
631
/// Schedules tensor collective operations over RDMA connections.
pub struct RdmaTensorScheduler {
    /// Transport used to execute operations.
    connection_manager: Arc<RdmaConnectionManager>,
    /// Pending operations awaiting dispatch.
    operation_queue: Arc<Mutex<Vec<TensorOperation>>>,
    /// Link-level bandwidth model used for scheduling decisions.
    bandwidth_optimizer: BandwidthOptimizer,
}
641
/// A pending tensor collective to be scheduled.
#[derive(Debug)]
pub struct TensorOperation {
    /// Identifier of the tensor being moved.
    tensor_id: String,
    /// Which collective to perform.
    operation_type: TensorOperationType,
    /// Node originating the data.
    source_node: usize,
    /// Nodes receiving data.
    target_nodes: Vec<usize>,
    /// Payload size in bytes.
    data_size: usize,
    /// Scheduling priority (higher runs first).
    priority: OperationPriority,
    /// Optional completion deadline; earlier deadlines run first.
    deadline: Option<Instant>,
}
652
/// Supported collective communication patterns.
#[derive(Debug, Clone, Copy)]
enum TensorOperationType {
    AllReduce,
    AllGather,
    ReduceScatter,
    Broadcast,
    AllToAll,
}
661
/// Scheduling priority. The derived `Ord` makes `Low < Normal < High <
/// Critical`, which the scheduler's sort relies on.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum OperationPriority {
    Low,
    Normal,
    High,
    Critical,
}
669
/// Per-link bandwidth model used to inform scheduling decisions.
/// NOTE(review): populated but not yet consulted by the scheduling code.
#[derive(Debug)]
struct BandwidthOptimizer {
    /// Nominal bandwidth per (source, target) link.
    link_bandwidth: HashMap<(usize, usize), f64>,
    /// Current utilization per (source, target) link.
    link_utilization: HashMap<(usize, usize), f64>,
    /// Strategy governing trade-offs.
    optimization_strategy: BandwidthStrategy,
}
676
/// Trade-off strategies for the bandwidth optimizer.
#[derive(Debug, Clone, Copy)]
enum BandwidthStrategy {
    MinimizeLatency,
    MaximizeThroughput,
    BalanceLatencyThroughput,
    AdaptiveDynamic,
}
684
685impl RdmaTensorScheduler {
686 pub fn new(connection_manager: Arc<RdmaConnectionManager>) -> Self {
688 Self {
689 connection_manager,
690 operation_queue: Arc::new(Mutex::new(Vec::new())),
691 bandwidth_optimizer: BandwidthOptimizer {
692 link_bandwidth: HashMap::new(),
693 link_utilization: HashMap::new(),
694 optimization_strategy: BandwidthStrategy::AdaptiveDynamic,
695 },
696 }
697 }
698
699 pub async fn schedule_operation(&self, operation: TensorOperation) -> RdmaResult<()> {
701 self.operation_queue
702 .lock()
703 .expect("lock should not be poisoned")
704 .push(operation);
705 self.optimize_scheduling().await
706 }
707
708 async fn optimize_scheduling(&self) -> RdmaResult<()> {
709 #[allow(clippy::await_holding_lock)]
710 let mut queue = self
711 .operation_queue
712 .lock()
713 .expect("lock should not be poisoned");
714
715 queue.sort_by(|a, b| {
717 a.priority
718 .cmp(&b.priority)
719 .reverse()
720 .then_with(|| match (a.deadline, b.deadline) {
721 (Some(da), Some(db)) => da.cmp(&db),
722 (Some(_), None) => std::cmp::Ordering::Less,
723 (None, Some(_)) => std::cmp::Ordering::Greater,
724 (None, None) => std::cmp::Ordering::Equal,
725 })
726 });
727
728 if let Some(operation) = queue.pop() {
730 self.execute_tensor_operation(operation).await?;
731 }
732
733 Ok(())
734 }
735
736 async fn execute_tensor_operation(&self, operation: TensorOperation) -> RdmaResult<()> {
737 match operation.operation_type {
738 TensorOperationType::AllReduce => self.execute_all_reduce(&operation).await,
739 TensorOperationType::AllGather => self.execute_all_gather(&operation).await,
740 TensorOperationType::ReduceScatter => self.execute_reduce_scatter(&operation).await,
741 TensorOperationType::Broadcast => self.execute_broadcast(&operation).await,
742 TensorOperationType::AllToAll => self.execute_all_to_all(&operation).await,
743 }
744 }
745
746 async fn execute_all_reduce(&self, _operation: &TensorOperation) -> RdmaResult<()> {
747 Ok(())
750 }
751
752 async fn execute_all_gather(&self, _operation: &TensorOperation) -> RdmaResult<()> {
753 Ok(())
755 }
756
757 async fn execute_reduce_scatter(&self, _operation: &TensorOperation) -> RdmaResult<()> {
758 Ok(())
760 }
761
762 async fn execute_broadcast(&self, _operation: &TensorOperation) -> RdmaResult<()> {
763 Ok(())
765 }
766
767 async fn execute_all_to_all(&self, _operation: &TensorOperation) -> RdmaResult<()> {
768 Ok(())
770 }
771}
772
#[cfg(test)]
mod tests {
    use super::*;

    /// Endpoint fixture shared by the connection-oriented tests.
    fn sample_endpoint() -> RdmaEndpoint {
        RdmaEndpoint {
            node_id: 1,
            address: "192.168.1.100".to_string(),
            port: 18515,
            gid: None,
            lid: None,
            qp_num: 12345,
            psn: 67890,
        }
    }

    /// Allocations round up to a configured size class and the pool's
    /// counters reflect every allocate/deallocate call.
    #[tokio::test]
    async fn test_rdma_memory_pool() {
        let pool = RdmaMemoryPool::new(RdmaMemoryPoolConfig {
            min_pool_size: 2,
            max_pool_size: 10,
            region_sizes: vec![4096, 65536],
            prefault: true,
            huge_pages: false,
        })
        .unwrap();

        let small = pool.allocate(2048).unwrap();
        assert!(small.size >= 2048);

        let large = pool.allocate(8192).unwrap();
        assert!(large.size >= 8192);

        pool.deallocate(small);
        pool.deallocate(large);

        let counters = pool.statistics();
        assert_eq!(counters.allocations, 2);
        assert_eq!(counters.deallocations, 2);
    }

    /// Connecting returns the remote node id and a read completes with the
    /// requested byte count.
    #[tokio::test]
    async fn test_rdma_connection_manager() {
        let manager = RdmaConnectionManager::new(RdmaConfig::default()).unwrap();

        let conn = manager.connect(sample_endpoint()).await.unwrap();
        assert_eq!(conn, 1);

        let completion = manager
            .rdma_read(conn, 0x1000, 0x2000, 1024, 0x12345678, 0x87654321)
            .await
            .unwrap();

        assert_eq!(completion.status, CompletionStatus::Success);
        assert_eq!(completion.operation, RdmaOperation::Read);
        assert_eq!(completion.bytes_transferred, 1024);
    }

    /// The configuration survives a JSON round-trip.
    #[test]
    fn test_rdma_config_serialization() {
        let original = RdmaConfig::default();
        let json = serde_json::to_string(&original).unwrap();
        let restored: RdmaConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(original.protocol, restored.protocol);
        assert_eq!(original.qos, restored.qos);
        assert_eq!(original.max_message_size, restored.max_message_size);
    }

    /// Compare-and-swap reports the compared value as the previous value
    /// (simulated success).
    #[tokio::test]
    async fn test_atomic_operations() {
        let manager = RdmaConnectionManager::new(RdmaConfig::default()).unwrap();
        let conn = manager.connect(sample_endpoint()).await.unwrap();

        let previous = manager
            .atomic_compare_swap(conn, 0x3000, 42, 84, 0x12345678)
            .await
            .unwrap();

        assert_eq!(previous, 42);
    }
}